A vibe-coded fork of Tangled that supports Pijul.
at master 426 lines 11 kB view raw
1package oauth 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "net/url" 10 "strings" 11 "sync" 12 "time" 13 14 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 "github.com/bluesky-social/indigo/atproto/auth/oauth" 16 atpclient "github.com/bluesky-social/indigo/atproto/client" 17 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 18 "github.com/bluesky-social/indigo/atproto/syntax" 19 xrpc "github.com/bluesky-social/indigo/xrpc" 20 "github.com/gorilla/sessions" 21 "github.com/posthog/posthog-go" 22 "tangled.org/core/appview/config" 23 "tangled.org/core/appview/db" 24 "tangled.org/core/idresolver" 25 "tangled.org/core/rbac" 26) 27 28type OAuth struct { 29 ClientApp *oauth.ClientApp 30 SessStore *sessions.CookieStore 31 Config *config.Config 32 JwksUri string 33 ClientName string 34 ClientUri string 35 Posthog posthog.Client 36 Db *db.DB 37 Enforcer *rbac.Enforcer 38 IdResolver *idresolver.Resolver 39 Logger *slog.Logger 40 41 appPasswordSession *AppPasswordSession 42 appPasswordSessionMu sync.Mutex 43} 44 45func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 46 var oauthConfig oauth.ClientConfig 47 var clientUri string 48 if config.Core.Dev { 49 clientUri = "http://127.0.0.1:3000" 50 callbackUri := clientUri + "/oauth/callback" 51 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 52 } else { 53 clientUri = "https://" + config.Core.AppviewHost 54 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 55 callbackUri := clientUri + "/oauth/callback" 56 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 57 } 58 59 // configure client secret 60 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 61 if err != nil { 62 return nil, err 63 } 64 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 65 return nil, err 66 } 67 68 jwksUri 
:= clientUri + "/oauth/jwks.json" 69 70 authStore, err := NewRedisStore(&RedisStoreConfig{ 71 RedisURL: config.Redis.ToURL(), 72 SessionExpiryDuration: time.Hour * 24 * 90, 73 SessionInactivityDuration: time.Hour * 24 * 14, 74 AuthRequestExpiryDuration: time.Minute * 30, 75 }) 76 if err != nil { 77 return nil, err 78 } 79 80 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 81 82 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 83 clientApp.Dir = res.Directory() 84 // allow non-public transports in dev mode 85 if config.Core.Dev { 86 clientApp.Resolver.Client.Transport = http.DefaultTransport 87 } 88 89 clientName := config.Core.AppviewName 90 91 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 92 return &OAuth{ 93 ClientApp: clientApp, 94 Config: config, 95 SessStore: sessStore, 96 JwksUri: jwksUri, 97 ClientName: clientName, 98 ClientUri: clientUri, 99 Posthog: ph, 100 Db: db, 101 Enforcer: enforcer, 102 IdResolver: res, 103 Logger: logger, 104 }, nil 105} 106 107func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 108 userSession, err := o.SessStore.Get(r, SessionName) 109 if err != nil { 110 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 111 } 112 113 userSession.Values[SessionDid] = sessData.AccountDID.String() 114 userSession.Values[SessionPds] = sessData.HostURL 115 userSession.Values[SessionId] = sessData.SessionID 116 userSession.Values[SessionAuthenticated] = true 117 118 if err := userSession.Save(r, w); err != nil { 119 return err 120 } 121 122 handle := "" 123 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 124 if err == nil && resolved.Handle.String() != "" { 125 handle = resolved.Handle.String() 126 } 127 128 registry := o.GetAccounts(r) 129 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 130 
return err 131 } 132 return o.SaveAccounts(w, r, registry) 133} 134 135func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 136 userSession, err := o.SessStore.Get(r, SessionName) 137 if err != nil { 138 return nil, fmt.Errorf("error getting user session: %w", err) 139 } 140 if userSession.IsNew { 141 return nil, fmt.Errorf("no session available for user") 142 } 143 144 d := userSession.Values[SessionDid].(string) 145 sessDid, err := syntax.ParseDID(d) 146 if err != nil { 147 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 148 } 149 150 sessId := userSession.Values[SessionId].(string) 151 152 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 153 if err != nil { 154 return nil, fmt.Errorf("failed to resume session: %w", err) 155 } 156 157 return clientSess, nil 158} 159 160func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 161 userSession, err := o.SessStore.Get(r, SessionName) 162 if err != nil { 163 return fmt.Errorf("error getting user session: %w", err) 164 } 165 if userSession.IsNew { 166 return fmt.Errorf("no session available for user") 167 } 168 169 d := userSession.Values[SessionDid].(string) 170 sessDid, err := syntax.ParseDID(d) 171 if err != nil { 172 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 173 } 174 175 sessId := userSession.Values[SessionId].(string) 176 177 // delete the session 178 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 179 if err1 != nil { 180 err1 = fmt.Errorf("failed to logout: %w", err1) 181 } 182 183 // remove the cookie 184 userSession.Options.MaxAge = -1 185 err2 := o.SessStore.Save(r, w, userSession) 186 if err2 != nil { 187 err2 = fmt.Errorf("failed to save into session store: %w", err2) 188 } 189 190 return errors.Join(err1, err2) 191} 192 193func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 194 registry := o.GetAccounts(r) 195 account := 
registry.FindAccount(targetDid) 196 if account == nil { 197 return fmt.Errorf("account not found in registry: %s", targetDid) 198 } 199 200 did, err := syntax.ParseDID(targetDid) 201 if err != nil { 202 return fmt.Errorf("invalid DID: %w", err) 203 } 204 205 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 206 if err != nil { 207 registry.RemoveAccount(targetDid) 208 _ = o.SaveAccounts(w, r, registry) 209 return fmt.Errorf("session expired for account: %w", err) 210 } 211 212 userSession, err := o.SessStore.Get(r, SessionName) 213 if err != nil { 214 return err 215 } 216 217 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 218 userSession.Values[SessionPds] = sess.Data.HostURL 219 userSession.Values[SessionId] = sess.Data.SessionID 220 userSession.Values[SessionAuthenticated] = true 221 222 return userSession.Save(r, w) 223} 224 225func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 226 registry := o.GetAccounts(r) 227 account := registry.FindAccount(targetDid) 228 if account == nil { 229 return nil 230 } 231 232 did, err := syntax.ParseDID(targetDid) 233 if err == nil { 234 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 235 } 236 237 registry.RemoveAccount(targetDid) 238 return o.SaveAccounts(w, r, registry) 239} 240 241type User struct { 242 Did string 243 Pds string 244} 245 246func (o *OAuth) GetUser(r *http.Request) *User { 247 sess, err := o.ResumeSession(r) 248 if err != nil { 249 return nil 250 } 251 252 return &User{ 253 Did: sess.Data.AccountDID.String(), 254 Pds: sess.Data.HostURL, 255 } 256} 257 258func (o *OAuth) GetDid(r *http.Request) string { 259 if u := o.GetMultiAccountUser(r); u != nil { 260 return u.Did() 261 } 262 263 return "" 264} 265 266func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 267 session, err := o.ResumeSession(r) 268 if err != nil { 269 return nil, fmt.Errorf("error getting session: %w", err) 270 } 271 
return session.APIClient(), nil 272} 273 274// this is a higher level abstraction on ServerGetServiceAuth 275type ServiceClientOpts struct { 276 service string 277 exp int64 278 lxm string 279 dev bool 280 timeout time.Duration 281} 282 283type ServiceClientOpt func(*ServiceClientOpts) 284 285func DefaultServiceClientOpts() ServiceClientOpts { 286 return ServiceClientOpts{ 287 timeout: time.Second * 5, 288 } 289} 290 291func WithService(service string) ServiceClientOpt { 292 return func(s *ServiceClientOpts) { 293 s.service = service 294 } 295} 296 297// Specify the Duration in seconds for the expiry of this token 298// 299// The time of expiry is calculated as time.Now().Unix() + exp 300func WithExp(exp int64) ServiceClientOpt { 301 return func(s *ServiceClientOpts) { 302 s.exp = time.Now().Unix() + exp 303 } 304} 305 306func WithLxm(lxm string) ServiceClientOpt { 307 return func(s *ServiceClientOpts) { 308 s.lxm = lxm 309 } 310} 311 312func WithDev(dev bool) ServiceClientOpt { 313 return func(s *ServiceClientOpts) { 314 s.dev = dev 315 } 316} 317 318func WithTimeout(timeout time.Duration) ServiceClientOpt { 319 return func(s *ServiceClientOpts) { 320 s.timeout = timeout 321 } 322} 323 324func (s *ServiceClientOpts) Audience() string { 325 // did:web spec requires colons to be encoded as %3A 326 encoded := strings.ReplaceAll(s.service, ":", "%3A") 327 return fmt.Sprintf("did:web:%s", encoded) 328} 329 330func (s *ServiceClientOpts) Host() string { 331 scheme := "https://" 332 if s.dev { 333 scheme = "http://" 334 } 335 336 return scheme + s.service 337} 338 339func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 340 opts := DefaultServiceClientOpts() 341 for _, o := range os { 342 o(&opts) 343 } 344 345 client, err := o.AuthorizedClient(r) 346 if err != nil { 347 return nil, err 348 } 349 350 // force expiry to atleast 60 seconds in the future 351 sixty := time.Now().Unix() + 60 352 if opts.exp < sixty { 353 opts.exp = 
sixty 354 } 355 356 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 357 if err != nil { 358 return nil, err 359 } 360 361 return &xrpc.Client{ 362 Auth: &xrpc.AuthInfo{ 363 AccessJwt: resp.Token, 364 }, 365 Host: opts.Host(), 366 Client: &http.Client{ 367 Timeout: opts.timeout, 368 }, 369 }, nil 370} 371 372func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) { 373 parsedDid, err := syntax.ParseDID(did) 374 if err != nil { 375 return "", fmt.Errorf("invalid DID: %w", err) 376 } 377 378 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier()) 379 if err != nil { 380 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err) 381 } 382 383 host := ident.PDSEndpoint() 384 if host == "" { 385 return "", fmt.Errorf("identity does not link to an atproto host (PDS)") 386 } 387 388 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host) 389 if err != nil { 390 return "", fmt.Errorf("resolving auth server: %w", err) 391 } 392 393 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL) 394 if err != nil { 395 return "", fmt.Errorf("fetching auth server metadata: %w", err) 396 } 397 398 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes)) 399 scopes = append(scopes, TangledScopes...) 400 scopes = append(scopes, extraScopes...) 
401 402 loginHint := did 403 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() { 404 loginHint = ident.Handle.String() 405 } 406 407 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint) 408 if err != nil { 409 return "", fmt.Errorf("auth request failed: %w", err) 410 } 411 412 info.AccountDID = &parsedDid 413 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info) 414 415 if err := o.SetAuthReturn(w, r, returnURL, false); err != nil { 416 return "", fmt.Errorf("failed to set auth return: %w", err) 417 } 418 419 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s", 420 authserverMeta.AuthorizationEndpoint, 421 url.QueryEscape(o.ClientApp.Config.ClientID), 422 url.QueryEscape(info.RequestURI), 423 ) 424 425 return redirectURL, nil 426}