A vibe-coded Tangled fork that supports Pijul.
at sl/uvpzuszrulvq 348 lines 9.1 kB view raw
1package oauth 2 3import ( 4 "errors" 5 "fmt" 6 "log/slog" 7 "net/http" 8 "sync" 9 "time" 10 11 comatproto "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/auth/oauth" 13 atpclient "github.com/bluesky-social/indigo/atproto/client" 14 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 xrpc "github.com/bluesky-social/indigo/xrpc" 17 "github.com/gorilla/sessions" 18 "github.com/posthog/posthog-go" 19 "tangled.org/core/appview/config" 20 "tangled.org/core/appview/db" 21 "tangled.org/core/idresolver" 22 "tangled.org/core/rbac" 23) 24 25type OAuth struct { 26 ClientApp *oauth.ClientApp 27 SessStore *sessions.CookieStore 28 Config *config.Config 29 JwksUri string 30 ClientName string 31 ClientUri string 32 Posthog posthog.Client 33 Db *db.DB 34 Enforcer *rbac.Enforcer 35 IdResolver *idresolver.Resolver 36 Logger *slog.Logger 37 38 appPasswordSession *AppPasswordSession 39 appPasswordSessionMu sync.Mutex 40} 41 42func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 43 var oauthConfig oauth.ClientConfig 44 var clientUri string 45 if config.Core.Dev { 46 clientUri = "http://127.0.0.1:3000" 47 callbackUri := clientUri + "/oauth/callback" 48 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 49 } else { 50 clientUri = "https://" + config.Core.AppviewHost 51 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 52 callbackUri := clientUri + "/oauth/callback" 53 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 54 } 55 56 // configure client secret 57 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 58 if err != nil { 59 return nil, err 60 } 61 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 62 return nil, err 63 } 64 65 jwksUri := clientUri + "/oauth/jwks.json" 66 67 
authStore, err := NewRedisStore(&RedisStoreConfig{ 68 RedisURL: config.Redis.ToURL(), 69 SessionExpiryDuration: time.Hour * 24 * 90, 70 SessionInactivityDuration: time.Hour * 24 * 14, 71 AuthRequestExpiryDuration: time.Minute * 30, 72 }) 73 if err != nil { 74 return nil, err 75 } 76 77 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 78 79 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 80 clientApp.Dir = res.Directory() 81 // allow non-public transports in dev mode 82 if config.Core.Dev { 83 clientApp.Resolver.Client.Transport = http.DefaultTransport 84 } 85 86 clientName := config.Core.AppviewName 87 88 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 89 return &OAuth{ 90 ClientApp: clientApp, 91 Config: config, 92 SessStore: sessStore, 93 JwksUri: jwksUri, 94 ClientName: clientName, 95 ClientUri: clientUri, 96 Posthog: ph, 97 Db: db, 98 Enforcer: enforcer, 99 IdResolver: res, 100 Logger: logger, 101 }, nil 102} 103 104func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 105 userSession, err := o.SessStore.Get(r, SessionName) 106 if err != nil { 107 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 108 } 109 110 userSession.Values[SessionDid] = sessData.AccountDID.String() 111 userSession.Values[SessionPds] = sessData.HostURL 112 userSession.Values[SessionId] = sessData.SessionID 113 userSession.Values[SessionAuthenticated] = true 114 115 if err := userSession.Save(r, w); err != nil { 116 return err 117 } 118 119 handle := "" 120 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 121 if err == nil && resolved.Handle.String() != "" { 122 handle = resolved.Handle.String() 123 } 124 125 registry := o.GetAccounts(r) 126 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 127 return err 128 } 129 return 
o.saveAccounts(w, r, registry) 130} 131 132func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 133 userSession, err := o.SessStore.Get(r, SessionName) 134 if err != nil { 135 return nil, fmt.Errorf("error getting user session: %w", err) 136 } 137 if userSession.IsNew { 138 return nil, fmt.Errorf("no session available for user") 139 } 140 141 d := userSession.Values[SessionDid].(string) 142 sessDid, err := syntax.ParseDID(d) 143 if err != nil { 144 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 145 } 146 147 sessId := userSession.Values[SessionId].(string) 148 149 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 150 if err != nil { 151 return nil, fmt.Errorf("failed to resume session: %w", err) 152 } 153 154 return clientSess, nil 155} 156 157func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 158 userSession, err := o.SessStore.Get(r, SessionName) 159 if err != nil { 160 return fmt.Errorf("error getting user session: %w", err) 161 } 162 if userSession.IsNew { 163 return fmt.Errorf("no session available for user") 164 } 165 166 d := userSession.Values[SessionDid].(string) 167 sessDid, err := syntax.ParseDID(d) 168 if err != nil { 169 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 170 } 171 172 sessId := userSession.Values[SessionId].(string) 173 174 // delete the session 175 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 176 if err1 != nil { 177 err1 = fmt.Errorf("failed to logout: %w", err1) 178 } 179 180 // remove the cookie 181 userSession.Options.MaxAge = -1 182 err2 := o.SessStore.Save(r, w, userSession) 183 if err2 != nil { 184 err2 = fmt.Errorf("failed to save into session store: %w", err2) 185 } 186 187 return errors.Join(err1, err2) 188} 189 190func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 191 registry := o.GetAccounts(r) 192 account := registry.FindAccount(targetDid) 
193 if account == nil { 194 return fmt.Errorf("account not found in registry: %s", targetDid) 195 } 196 197 did, err := syntax.ParseDID(targetDid) 198 if err != nil { 199 return fmt.Errorf("invalid DID: %w", err) 200 } 201 202 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 203 if err != nil { 204 registry.RemoveAccount(targetDid) 205 _ = o.saveAccounts(w, r, registry) 206 return fmt.Errorf("session expired for account: %w", err) 207 } 208 209 userSession, err := o.SessStore.Get(r, SessionName) 210 if err != nil { 211 return err 212 } 213 214 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 215 userSession.Values[SessionPds] = sess.Data.HostURL 216 userSession.Values[SessionId] = sess.Data.SessionID 217 userSession.Values[SessionAuthenticated] = true 218 219 return userSession.Save(r, w) 220} 221 222func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 223 registry := o.GetAccounts(r) 224 account := registry.FindAccount(targetDid) 225 if account == nil { 226 return nil 227 } 228 229 did, err := syntax.ParseDID(targetDid) 230 if err == nil { 231 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 232 } 233 234 registry.RemoveAccount(targetDid) 235 return o.saveAccounts(w, r, registry) 236} 237 238func (o *OAuth) GetDid(r *http.Request) string { 239 if u := o.GetMultiAccountUser(r); u != nil { 240 return u.Did 241 } 242 243 return "" 244} 245 246func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) { 247 session, err := o.ResumeSession(r) 248 if err != nil { 249 return nil, fmt.Errorf("error getting session: %w", err) 250 } 251 return session.APIClient(), nil 252} 253 254// this is a higher level abstraction on ServerGetServiceAuth 255type ServiceClientOpts struct { 256 service string 257 exp int64 258 lxm string 259 dev bool 260 timeout time.Duration 261} 262 263type ServiceClientOpt func(*ServiceClientOpts) 264 265func DefaultServiceClientOpts() 
ServiceClientOpts { 266 return ServiceClientOpts{ 267 timeout: time.Second * 5, 268 } 269} 270 271func WithService(service string) ServiceClientOpt { 272 return func(s *ServiceClientOpts) { 273 s.service = service 274 } 275} 276 277// Specify the Duration in seconds for the expiry of this token 278// 279// The time of expiry is calculated as time.Now().Unix() + exp 280func WithExp(exp int64) ServiceClientOpt { 281 return func(s *ServiceClientOpts) { 282 s.exp = time.Now().Unix() + exp 283 } 284} 285 286func WithLxm(lxm string) ServiceClientOpt { 287 return func(s *ServiceClientOpts) { 288 s.lxm = lxm 289 } 290} 291 292func WithDev(dev bool) ServiceClientOpt { 293 return func(s *ServiceClientOpts) { 294 s.dev = dev 295 } 296} 297 298func WithTimeout(timeout time.Duration) ServiceClientOpt { 299 return func(s *ServiceClientOpts) { 300 s.timeout = timeout 301 } 302} 303 304func (s *ServiceClientOpts) Audience() string { 305 return fmt.Sprintf("did:web:%s", s.service) 306} 307 308func (s *ServiceClientOpts) Host() string { 309 scheme := "https://" 310 if s.dev { 311 scheme = "http://" 312 } 313 314 return scheme + s.service 315} 316 317func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 318 opts := DefaultServiceClientOpts() 319 for _, o := range os { 320 o(&opts) 321 } 322 323 client, err := o.AuthorizedClient(r) 324 if err != nil { 325 return nil, err 326 } 327 328 // force expiry to atleast 60 seconds in the future 329 sixty := time.Now().Unix() + 60 330 if opts.exp < sixty { 331 opts.exp = sixty 332 } 333 334 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 335 if err != nil { 336 return nil, err 337 } 338 339 return &xrpc.Client{ 340 Auth: &xrpc.AuthInfo{ 341 AccessJwt: resp.Token, 342 }, 343 Host: opts.Host(), 344 Client: &http.Client{ 345 Timeout: opts.timeout, 346 }, 347 }, nil 348}