A vibe-coded fork of Tangled that supports Pijul.
at 598f5cdb0f7cc53454d1bc8110e95f9b3a3fa713 432 lines 12 kB view raw
1package oauth 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "log/slog" 10 "net/http" 11 "slices" 12 "time" 13 14 comatproto "github.com/bluesky-social/indigo/api/atproto" 15 "github.com/bluesky-social/indigo/atproto/auth/oauth" 16 lexutil "github.com/bluesky-social/indigo/lex/util" 17 "github.com/go-chi/chi/v5" 18 "github.com/posthog/posthog-go" 19 "tangled.org/core/api/tangled" 20 "tangled.org/core/appview/db" 21 "tangled.org/core/appview/models" 22 "tangled.org/core/consts" 23 "tangled.org/core/idresolver" 24 "tangled.org/core/orm" 25 "tangled.org/core/tid" 26) 27 28func (o *OAuth) Router() http.Handler { 29 r := chi.NewRouter() 30 31 r.Get("/oauth/client-metadata.json", o.clientMetadata) 32 r.Get("/oauth/jwks.json", o.jwks) 33 r.Get("/oauth/callback", o.callback) 34 return r 35} 36 37func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) { 38 doc := o.ClientApp.Config.ClientMetadata() 39 doc.JWKSURI = &o.JwksUri 40 doc.ClientName = &o.ClientName 41 doc.ClientURI = &o.ClientUri 42 43 w.Header().Set("Content-Type", "application/json") 44 if err := json.NewEncoder(w).Encode(doc); err != nil { 45 http.Error(w, err.Error(), http.StatusInternalServerError) 46 return 47 } 48} 49 50func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) { 51 w.Header().Set("Content-Type", "application/json") 52 body := o.ClientApp.Config.PublicJWKS() 53 if err := json.NewEncoder(w).Encode(body); err != nil { 54 http.Error(w, err.Error(), http.StatusInternalServerError) 55 return 56 } 57} 58 59func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) { 60 ctx := r.Context() 61 l := o.Logger.With("query", r.URL.Query()) 62 63 authReturn := o.GetAuthReturn(r) 64 _ = o.ClearAuthReturn(w, r) 65 66 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query()) 67 if err != nil { 68 var callbackErr *oauth.AuthRequestCallbackError 69 if errors.As(err, &callbackErr) { 70 l.Debug("callback error", "err", callbackErr) 71 http.Redirect(w, r, 
fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound) 72 return 73 } 74 l.Error("failed to process callback", "err", err) 75 http.Redirect(w, r, "/login?error=oauth", http.StatusFound) 76 return 77 } 78 79 if err := o.SaveSession(w, r, sessData); err != nil { 80 l.Error("failed to save session", "data", sessData, "err", err) 81 errorCode := "session" 82 if errors.Is(err, ErrMaxAccountsReached) { 83 errorCode = "max_accounts" 84 } 85 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound) 86 return 87 } 88 89 o.Logger.Debug("session saved successfully") 90 91 go o.addToDefaultKnot(sessData.AccountDID.String()) 92 go o.addToDefaultSpindle(sessData.AccountDID.String()) 93 go o.ensureTangledProfile(sessData) 94 95 if !o.Config.Core.Dev { 96 err = o.Posthog.Enqueue(posthog.Capture{ 97 DistinctId: sessData.AccountDID.String(), 98 Event: "signin", 99 }) 100 if err != nil { 101 o.Logger.Error("failed to enqueue posthog event", "err", err) 102 } 103 } 104 105 redirectURL := "/" 106 if authReturn.ReturnURL != "" { 107 redirectURL = authReturn.ReturnURL 108 } 109 110 http.Redirect(w, r, redirectURL, http.StatusFound) 111} 112 113func (o *OAuth) addToDefaultSpindle(did string) { 114 l := o.Logger.With("subject", did) 115 116 // use the tangled.sh app password to get an accessJwt 117 // and create an sh.tangled.spindle.member record with that 118 spindleMembers, err := db.GetSpindleMembers( 119 o.Db, 120 orm.FilterEq("instance", "spindle.tangled.sh"), 121 orm.FilterEq("subject", did), 122 ) 123 if err != nil { 124 l.Error("failed to get spindle members", "err", err) 125 return 126 } 127 128 if len(spindleMembers) != 0 { 129 l.Warn("already a member of the default spindle") 130 return 131 } 132 133 l.Debug("adding to default spindle") 134 session, err := o.getAppPasswordSession() 135 if err != nil { 136 l.Error("failed to create session", "err", err) 137 return 138 } 139 140 record := tangled.SpindleMember{ 141 LexiconTypeID: 
tangled.SpindleMemberNSID, 142 Subject: did, 143 Instance: consts.DefaultSpindle, 144 CreatedAt: time.Now().Format(time.RFC3339), 145 } 146 147 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil { 148 l.Error("failed to add to default spindle", "err", err) 149 return 150 } 151 152 l.Debug("successfully added to default spindle", "did", did) 153} 154 155func (o *OAuth) addToDefaultKnot(did string) { 156 l := o.Logger.With("subject", did) 157 158 // use the tangled.sh app password to get an accessJwt 159 // and create an sh.tangled.spindle.member record with that 160 161 allKnots, err := o.Enforcer.GetKnotsForUser(did) 162 if err != nil { 163 l.Error("failed to get knot members for did", "err", err) 164 return 165 } 166 167 if slices.Contains(allKnots, consts.DefaultKnot) { 168 l.Warn("already a member of the default knot") 169 return 170 } 171 172 l.Debug("adding to default knot") 173 session, err := o.getAppPasswordSession() 174 if err != nil { 175 l.Error("failed to create session", "err", err) 176 return 177 } 178 179 record := tangled.KnotMember{ 180 LexiconTypeID: tangled.KnotMemberNSID, 181 Subject: did, 182 Domain: consts.DefaultKnot, 183 CreatedAt: time.Now().Format(time.RFC3339), 184 } 185 186 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil { 187 l.Error("failed to add to default knot", "err", err) 188 return 189 } 190 191 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil { 192 l.Error("failed to set up enforcer rules", "err", err) 193 return 194 } 195 196 l.Debug("successfully added to default knot") 197} 198 199func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) { 200 ctx := context.Background() 201 did := sessData.AccountDID.String() 202 l := o.Logger.With("did", did) 203 204 profile, _ := db.GetProfile(o.Db, did) 205 if profile != nil { 206 l.Debug("profile already exists in DB") 207 return 208 } 209 210 l.Debug("creating empty Tangled profile") 211 212 sess, err 
:= o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID) 213 if err != nil { 214 l.Error("failed to resume session for profile creation", "err", err) 215 return 216 } 217 client := sess.APIClient() 218 219 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{ 220 Collection: tangled.ActorProfileNSID, 221 Repo: did, 222 Rkey: "self", 223 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}}, 224 }) 225 226 if err != nil { 227 l.Error("failed to create empty profile on PDS", "err", err) 228 return 229 } 230 231 tx, err := o.Db.BeginTx(ctx, nil) 232 if err != nil { 233 l.Error("failed to start transaction", "err", err) 234 return 235 } 236 237 emptyProfile := &models.Profile{Did: did} 238 if err := db.UpsertProfile(tx, emptyProfile); err != nil { 239 l.Error("failed to create empty profile in DB", "err", err) 240 return 241 } 242 243 l.Debug("successfully created empty Tangled profile on PDS and DB") 244} 245 246// create a AppPasswordSession using apppasswords 247type AppPasswordSession struct { 248 AccessJwt string `json:"accessJwt"` 249 RefreshJwt string `json:"refreshJwt"` 250 PdsEndpoint string 251 Did string 252 Logger *slog.Logger 253 ExpiresAt time.Time 254} 255 256func CreateAppPasswordSession(res *idresolver.Resolver, appPassword, did string, logger *slog.Logger) (*AppPasswordSession, error) { 257 if appPassword == "" { 258 return nil, fmt.Errorf("no app password configured") 259 } 260 261 resolved, err := res.ResolveIdent(context.Background(), did) 262 if err != nil { 263 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err) 264 } 265 266 pdsEndpoint := resolved.PDSEndpoint() 267 if pdsEndpoint == "" { 268 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did) 269 } 270 271 sessionPayload := map[string]string{ 272 "identifier": did, 273 "password": appPassword, 274 } 275 sessionBytes, err := json.Marshal(sessionPayload) 276 if err != nil { 277 return nil, 
fmt.Errorf("failed to marshal session payload: %v", err) 278 } 279 280 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession" 281 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes)) 282 if err != nil { 283 return nil, fmt.Errorf("failed to create session request: %v", err) 284 } 285 sessionReq.Header.Set("Content-Type", "application/json") 286 287 logger.Debug("creating app password session", "url", sessionURL, "headers", sessionReq.Header) 288 289 client := &http.Client{Timeout: 30 * time.Second} 290 sessionResp, err := client.Do(sessionReq) 291 if err != nil { 292 return nil, fmt.Errorf("failed to create session: %v", err) 293 } 294 defer sessionResp.Body.Close() 295 296 if sessionResp.StatusCode != http.StatusOK { 297 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode) 298 } 299 300 var session AppPasswordSession 301 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil { 302 return nil, fmt.Errorf("failed to decode session response: %v", err) 303 } 304 305 session.PdsEndpoint = pdsEndpoint 306 session.Did = did 307 session.Logger = logger 308 session.ExpiresAt = time.Now().Add(115 * time.Minute) 309 310 return &session, nil 311} 312 313func (s *AppPasswordSession) refreshSession() error { 314 refreshURL := s.PdsEndpoint + "/xrpc/com.atproto.server.refreshSession" 315 req, err := http.NewRequestWithContext(context.Background(), "POST", refreshURL, nil) 316 if err != nil { 317 return fmt.Errorf("failed to create refresh request: %w", err) 318 } 319 320 req.Header.Set("Authorization", "Bearer "+s.RefreshJwt) 321 322 s.Logger.Debug("refreshing app password session", "url", refreshURL) 323 324 client := &http.Client{Timeout: 30 * time.Second} 325 resp, err := client.Do(req) 326 if err != nil { 327 return fmt.Errorf("failed to refresh session: %w", err) 328 } 329 defer resp.Body.Close() 330 331 if resp.StatusCode != http.StatusOK { 332 
var errorResponse map[string]any 333 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil { 334 return fmt.Errorf("failed to refresh session: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err) 335 } 336 errorBytes, _ := json.Marshal(errorResponse) 337 return fmt.Errorf("failed to refresh session: HTTP %d, response: %s", resp.StatusCode, string(errorBytes)) 338 } 339 340 var refreshResponse struct { 341 AccessJwt string `json:"accessJwt"` 342 RefreshJwt string `json:"refreshJwt"` 343 } 344 if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil { 345 return fmt.Errorf("failed to decode refresh response: %w", err) 346 } 347 348 s.AccessJwt = refreshResponse.AccessJwt 349 s.RefreshJwt = refreshResponse.RefreshJwt 350 // Set new expiry time with 5 minute buffer 351 s.ExpiresAt = time.Now().Add(115 * time.Minute) 352 353 s.Logger.Debug("successfully refreshed app password session") 354 return nil 355} 356 357func (s *AppPasswordSession) isValid() bool { 358 return time.Now().Before(s.ExpiresAt) 359} 360 361func (s *AppPasswordSession) putRecord(record any, collection string) error { 362 if !s.isValid() { 363 s.Logger.Debug("access token expired, refreshing session") 364 if err := s.refreshSession(); err != nil { 365 return fmt.Errorf("failed to refresh session: %w", err) 366 } 367 s.Logger.Debug("session refreshed") 368 } 369 370 recordBytes, err := json.Marshal(record) 371 if err != nil { 372 return fmt.Errorf("failed to marshal knot member record: %w", err) 373 } 374 375 payload := map[string]any{ 376 "repo": s.Did, 377 "collection": collection, 378 "rkey": tid.TID(), 379 "record": json.RawMessage(recordBytes), 380 } 381 382 payloadBytes, err := json.Marshal(payload) 383 if err != nil { 384 return fmt.Errorf("failed to marshal request payload: %w", err) 385 } 386 387 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord" 388 req, err := http.NewRequestWithContext(context.Background(), "POST", url, 
bytes.NewBuffer(payloadBytes)) 389 if err != nil { 390 return fmt.Errorf("failed to create HTTP request: %w", err) 391 } 392 393 req.Header.Set("Content-Type", "application/json") 394 req.Header.Set("Authorization", "Bearer "+s.AccessJwt) 395 396 s.Logger.Debug("putting record", "url", url, "collection", collection) 397 398 client := &http.Client{Timeout: 30 * time.Second} 399 resp, err := client.Do(req) 400 if err != nil { 401 return fmt.Errorf("failed to add user to default service: %w", err) 402 } 403 defer resp.Body.Close() 404 405 if resp.StatusCode != http.StatusOK { 406 var errorResponse map[string]any 407 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil { 408 return fmt.Errorf("failed to add user to default service: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err) 409 } 410 return fmt.Errorf("failed to add user to default service: HTTP %d, response: %v", resp.StatusCode, errorResponse) 411 } 412 413 return nil 414} 415 416// getAppPasswordSession returns a cached AppPasswordSession, creating one if needed. 417func (o *OAuth) getAppPasswordSession() (*AppPasswordSession, error) { 418 o.appPasswordSessionMu.Lock() 419 defer o.appPasswordSessionMu.Unlock() 420 421 if o.appPasswordSession != nil { 422 return o.appPasswordSession, nil 423 } 424 425 session, err := CreateAppPasswordSession(o.IdResolver, o.Config.Core.AppPassword, consts.TangledDid, o.Logger) 426 if err != nil { 427 return nil, err 428 } 429 430 o.appPasswordSession = session 431 return session, nil 432}