A vibe coded tangled fork which supports pijul.
at sl/lvnwqspuwzom 423 lines 11 kB view raw
1package oauth 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/auth/oauth" 15 atpclient "github.com/bluesky-social/indigo/atproto/client" 16 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 17 "github.com/bluesky-social/indigo/atproto/syntax" 18 xrpc "github.com/bluesky-social/indigo/xrpc" 19 "github.com/gorilla/sessions" 20 "github.com/posthog/posthog-go" 21 "tangled.org/core/appview/config" 22 "tangled.org/core/appview/db" 23 "tangled.org/core/idresolver" 24 "tangled.org/core/rbac" 25) 26 27type OAuth struct { 28 ClientApp *oauth.ClientApp 29 SessStore *sessions.CookieStore 30 Config *config.Config 31 JwksUri string 32 ClientName string 33 ClientUri string 34 Posthog posthog.Client 35 Db *db.DB 36 Enforcer *rbac.Enforcer 37 IdResolver *idresolver.Resolver 38 Logger *slog.Logger 39 40 appPasswordSession *AppPasswordSession 41 appPasswordSessionMu sync.Mutex 42} 43 44func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 45 var oauthConfig oauth.ClientConfig 46 var clientUri string 47 if config.Core.Dev { 48 clientUri = "http://127.0.0.1:3000" 49 callbackUri := clientUri + "/oauth/callback" 50 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 51 } else { 52 clientUri = "https://" + config.Core.AppviewHost 53 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 54 callbackUri := clientUri + "/oauth/callback" 55 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 56 } 57 58 // configure client secret 59 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 60 if err != nil { 61 return nil, err 62 } 63 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 64 return nil, err 65 } 66 67 jwksUri := clientUri + "/oauth/jwks.json" 68 69 authStore, err := NewRedisStore(&RedisStoreConfig{ 70 RedisURL: config.Redis.ToURL(), 71 SessionExpiryDuration: time.Hour * 24 * 90, 72 SessionInactivityDuration: time.Hour * 24 * 14, 73 AuthRequestExpiryDuration: time.Minute * 30, 74 }) 75 if err != nil { 76 return nil, err 77 } 78 79 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 80 81 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 82 clientApp.Dir = res.Directory() 83 // allow non-public transports in dev mode 84 if config.Core.Dev { 85 clientApp.Resolver.Client.Transport = http.DefaultTransport 86 } 87 88 clientName := config.Core.AppviewName 89 90 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 91 return &OAuth{ 92 ClientApp: clientApp, 93 Config: config, 94 SessStore: sessStore, 95 JwksUri: jwksUri, 96 ClientName: clientName, 97 ClientUri: clientUri, 98 Posthog: ph, 99 Db: db, 100 Enforcer: enforcer, 101 IdResolver: res, 102 Logger: logger, 103 }, nil 104} 105 106func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 107 userSession, err := o.SessStore.Get(r, SessionName) 108 if err != nil { 109 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 110 } 111 112 userSession.Values[SessionDid] = sessData.AccountDID.String() 113 userSession.Values[SessionPds] = sessData.HostURL 114 userSession.Values[SessionId] = sessData.SessionID 115 userSession.Values[SessionAuthenticated] = true 116 117 if err := userSession.Save(r, w); err != nil { 118 return err 119 } 120 121 handle := "" 122 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 123 if err == nil && resolved.Handle.String() != "" { 124 handle = resolved.Handle.String() 125 } 126 127 registry := o.GetAccounts(r) 128 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 129 return err 130 } 131 return o.SaveAccounts(w, r, registry) 132} 133 134func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 135 userSession, err := o.SessStore.Get(r, SessionName) 136 if err != nil { 137 return nil, fmt.Errorf("error getting user session: %w", err) 138 } 139 if userSession.IsNew { 140 return nil, fmt.Errorf("no session available for user") 141 } 142 143 d := userSession.Values[SessionDid].(string) 144 sessDid, err := syntax.ParseDID(d) 145 if err != nil { 146 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 147 } 148 149 sessId := userSession.Values[SessionId].(string) 150 151 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 152 if err != nil { 153 return nil, fmt.Errorf("failed to resume session: %w", err) 154 } 155 156 return clientSess, nil 157} 158 159func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 160 userSession, err := o.SessStore.Get(r, SessionName) 161 if err != nil { 162 return fmt.Errorf("error getting user session: %w", err) 163 } 164 if userSession.IsNew { 165 return fmt.Errorf("no session available for user") 166 } 167 168 d := userSession.Values[SessionDid].(string) 169 sessDid, err := syntax.ParseDID(d) 170 if err != nil { 171 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 172 } 173 174 sessId := userSession.Values[SessionId].(string) 175 176 // delete the session 177 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 178 if err1 != nil { 179 err1 = fmt.Errorf("failed to logout: %w", err1) 180 } 181 182 // remove the cookie 183 userSession.Options.MaxAge = -1 184 err2 := o.SessStore.Save(r, w, userSession) 185 if err2 != nil { 186 err2 = fmt.Errorf("failed to save into session store: %w", err2) 187 } 188 189 return errors.Join(err1, err2) 190} 191 192func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 193 registry := o.GetAccounts(r) 194 account := registry.FindAccount(targetDid) 195 if account == nil { 196 return fmt.Errorf("account not found in registry: %s", targetDid) 197 } 198 199 did, err := syntax.ParseDID(targetDid) 200 if err != nil { 201 return fmt.Errorf("invalid DID: %w", err) 202 } 203 204 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 205 if err != nil { 206 registry.RemoveAccount(targetDid) 207 _ = o.SaveAccounts(w, r, registry) 208 return fmt.Errorf("session expired for account: %w", err) 209 } 210 211 userSession, err := o.SessStore.Get(r, SessionName) 212 if err != nil { 213 return err 214 } 215 216 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 217 userSession.Values[SessionPds] = sess.Data.HostURL 218 userSession.Values[SessionId] = sess.Data.SessionID 219 userSession.Values[SessionAuthenticated] = true 220 221 return userSession.Save(r, w) 222} 223 224func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 225 registry := o.GetAccounts(r) 226 account := registry.FindAccount(targetDid) 227 if account == nil { 228 return nil 229 } 230 231 did, err := syntax.ParseDID(targetDid) 232 if err == nil { 233 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 234 } 235 236 registry.RemoveAccount(targetDid) 237 return o.SaveAccounts(w, r, registry) 238} 239 240type User struct { 241 Did string 242 Pds string 243} 244 245func (o *OAuth) GetUser(r *http.Request) *User { 246 sess, err := o.ResumeSession(r) 247 if err != nil { 248 return nil 249 } 250 251 return &User{ 252 Did: sess.Data.AccountDID.String(), 253 Pds: sess.Data.HostURL, 254 } 255} 256 257func (o *OAuth) GetDid(r *http.Request) string { 258 if u := o.GetMultiAccountUser(r); u != nil { 259 return u.Did() 260 } 261 262 return "" 263} 264 265func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 266 session, err := o.ResumeSession(r) 267 if err != nil { 268 return nil, fmt.Errorf("error getting session: %w", err) 269 } 270 return session.APIClient(), nil 271} 272 273// this is a higher level abstraction on ServerGetServiceAuth 274type ServiceClientOpts struct { 275 service string 276 exp int64 277 lxm string 278 dev bool 279 timeout time.Duration 280} 281 282type ServiceClientOpt func(*ServiceClientOpts) 283 284func DefaultServiceClientOpts() ServiceClientOpts { 285 return ServiceClientOpts{ 286 timeout: time.Second * 5, 287 } 288} 289 290func WithService(service string) ServiceClientOpt { 291 return func(s *ServiceClientOpts) { 292 s.service = service 293 } 294} 295 296// Specify the Duration in seconds for the expiry of this token 297// 298// The time of expiry is calculated as time.Now().Unix() + exp 299func WithExp(exp int64) ServiceClientOpt { 300 return func(s *ServiceClientOpts) { 301 s.exp = time.Now().Unix() + exp 302 } 303} 304 305func WithLxm(lxm string) ServiceClientOpt { 306 return func(s *ServiceClientOpts) { 307 s.lxm = lxm 308 } 309} 310 311func WithDev(dev bool) ServiceClientOpt { 312 return func(s *ServiceClientOpts) { 313 s.dev = dev 314 } 315} 316 317func WithTimeout(timeout time.Duration) ServiceClientOpt { 318 return func(s *ServiceClientOpts) { 319 s.timeout = timeout 320 } 321} 322 323func (s *ServiceClientOpts) Audience() string { 324 return fmt.Sprintf("did:web:%s", s.service) 325} 326 327func (s *ServiceClientOpts) Host() string { 328 scheme := "https://" 329 if s.dev { 330 scheme = "http://" 331 } 332 333 return scheme + s.service 334} 335 336func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 337 opts := DefaultServiceClientOpts() 338 for _, o := range os { 339 o(&opts) 340 } 341 342 client, err := o.AuthorizedClient(r) 343 if err != nil { 344 return nil, err 345 } 346 347 // force expiry to atleast 60 seconds in the future 348 sixty := time.Now().Unix() + 60 349 if opts.exp < sixty { 350 opts.exp = sixty 351 } 352 353 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 354 if err != nil { 355 return nil, err 356 } 357 358 return &xrpc.Client{ 359 Auth: &xrpc.AuthInfo{ 360 AccessJwt: resp.Token, 361 }, 362 Host: opts.Host(), 363 Client: &http.Client{ 364 Timeout: opts.timeout, 365 }, 366 }, nil 367} 368 369func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) { 370 parsedDid, err := syntax.ParseDID(did) 371 if err != nil { 372 return "", fmt.Errorf("invalid DID: %w", err) 373 } 374 375 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier()) 376 if err != nil { 377 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err) 378 } 379 380 host := ident.PDSEndpoint() 381 if host == "" { 382 return "", fmt.Errorf("identity does not link to an atproto host (PDS)") 383 } 384 385 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host) 386 if err != nil { 387 return "", fmt.Errorf("resolving auth server: %w", err) 388 } 389 390 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL) 391 if err != nil { 392 return "", fmt.Errorf("fetching auth server metadata: %w", err) 393 } 394 395 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes)) 396 scopes = append(scopes, TangledScopes...) 397 scopes = append(scopes, extraScopes...) 398 399 loginHint := did 400 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() { 401 loginHint = ident.Handle.String() 402 } 403 404 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint) 405 if err != nil { 406 return "", fmt.Errorf("auth request failed: %w", err) 407 } 408 409 info.AccountDID = &parsedDid 410 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info) 411 412 if err := o.SetAuthReturn(w, r, returnURL, false); err != nil { 413 return "", fmt.Errorf("failed to set auth return: %w", err) 414 } 415 416 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s", 417 authserverMeta.AuthorizationEndpoint, 418 url.QueryEscape(o.ClientApp.Config.ClientID), 419 url.QueryEscape(info.RequestURI), 420 ) 421 422 return redirectURL, nil 423}