A vibe-coded fork of Tangled that supports Pijul.
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/appview/pagination"
17 "tangled.org/core/orm"
18)
19
20func NewPull(tx *sql.Tx, pull *models.Pull) error {
21 _, err := tx.Exec(`
22 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
23 values (?, 1)
24 `, pull.RepoAt)
25 if err != nil {
26 return err
27 }
28
29 var nextId int
30 err = tx.QueryRow(`
31 update repo_pull_seqs
32 set next_pull_id = next_pull_id + 1
33 where repo_at = ?
34 returning next_pull_id - 1
35 `, pull.RepoAt).Scan(&nextId)
36 if err != nil {
37 return err
38 }
39
40 pull.PullId = nextId
41 pull.State = models.PullOpen
42
43 var sourceBranch, sourceRepoAt *string
44 if pull.PullSource != nil {
45 sourceBranch = &pull.PullSource.Branch
46 if pull.PullSource.RepoAt != nil {
47 x := pull.PullSource.RepoAt.String()
48 sourceRepoAt = &x
49 }
50 }
51
52 var stackId, changeId, parentChangeId *string
53 if pull.StackId != "" {
54 stackId = &pull.StackId
55 }
56 if pull.ChangeId != "" {
57 changeId = &pull.ChangeId
58 }
59 if pull.ParentChangeId != "" {
60 parentChangeId = &pull.ParentChangeId
61 }
62
63 result, err := tx.Exec(
64 `
65 insert into pulls (
66 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
67 )
68 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
69 pull.RepoAt,
70 pull.OwnerDid,
71 pull.PullId,
72 pull.Title,
73 pull.TargetBranch,
74 pull.Body,
75 pull.Rkey,
76 pull.State,
77 sourceBranch,
78 sourceRepoAt,
79 stackId,
80 changeId,
81 parentChangeId,
82 )
83 if err != nil {
84 return err
85 }
86
87 // Set the database primary key ID
88 id, err := result.LastInsertId()
89 if err != nil {
90 return err
91 }
92 pull.ID = int(id)
93
94 _, err = tx.Exec(`
95 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
96 values (?, ?, ?, ?, ?)
97 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
98 if err != nil {
99 return err
100 }
101
102 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
103 return fmt.Errorf("put reference_links: %w", err)
104 }
105
106 return nil
107}
108
109func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
110 pull, err := GetPull(e, repoAt, pullId)
111 if err != nil {
112 return "", err
113 }
114 return pull.AtUri(), err
115}
116
117func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
118 var pullId int
119 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
120 return pullId - 1, err
121}
122
123func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
124 pulls := make(map[syntax.ATURI]*models.Pull)
125
126 var conditions []string
127 var args []any
128 for _, filter := range filters {
129 conditions = append(conditions, filter.Condition())
130 args = append(args, filter.Arg()...)
131 }
132
133 whereClause := ""
134 if conditions != nil {
135 whereClause = " where " + strings.Join(conditions, " and ")
136 }
137 pageClause := ""
138 if page.Limit != 0 {
139 pageClause = fmt.Sprintf(
140 " limit %d offset %d ",
141 page.Limit,
142 page.Offset,
143 )
144 }
145
146 query := fmt.Sprintf(`
147 select
148 id,
149 owner_did,
150 repo_at,
151 pull_id,
152 created,
153 title,
154 state,
155 target_branch,
156 body,
157 rkey,
158 source_branch,
159 source_repo_at,
160 stack_id,
161 change_id,
162 parent_change_id
163 from
164 pulls
165 %s
166 order by
167 created desc
168 %s
169 `, whereClause, pageClause)
170
171 rows, err := e.Query(query, args...)
172 if err != nil {
173 return nil, err
174 }
175 defer rows.Close()
176
177 for rows.Next() {
178 var pull models.Pull
179 var createdAt string
180 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
181 err := rows.Scan(
182 &pull.ID,
183 &pull.OwnerDid,
184 &pull.RepoAt,
185 &pull.PullId,
186 &createdAt,
187 &pull.Title,
188 &pull.State,
189 &pull.TargetBranch,
190 &pull.Body,
191 &pull.Rkey,
192 &sourceBranch,
193 &sourceRepoAt,
194 &stackId,
195 &changeId,
196 &parentChangeId,
197 )
198 if err != nil {
199 return nil, err
200 }
201
202 createdTime, err := time.Parse(time.RFC3339, createdAt)
203 if err != nil {
204 return nil, err
205 }
206 pull.Created = createdTime
207
208 if sourceBranch.Valid {
209 pull.PullSource = &models.PullSource{
210 Branch: sourceBranch.String,
211 }
212 if sourceRepoAt.Valid {
213 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
214 if err != nil {
215 return nil, err
216 }
217 pull.PullSource.RepoAt = &sourceRepoAtParsed
218 }
219 }
220
221 if stackId.Valid {
222 pull.StackId = stackId.String
223 }
224 if changeId.Valid {
225 pull.ChangeId = changeId.String
226 }
227 if parentChangeId.Valid {
228 pull.ParentChangeId = parentChangeId.String
229 }
230
231 pulls[pull.AtUri()] = &pull
232 }
233
234 var pullAts []syntax.ATURI
235 for _, p := range pulls {
236 pullAts = append(pullAts, p.AtUri())
237 }
238 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
239 if err != nil {
240 return nil, fmt.Errorf("failed to get submissions: %w", err)
241 }
242
243 for pullAt, submissions := range submissionsMap {
244 if p, ok := pulls[pullAt]; ok {
245 p.Submissions = submissions
246 }
247 }
248
249 // collect allLabels for each issue
250 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
251 if err != nil {
252 return nil, fmt.Errorf("failed to query labels: %w", err)
253 }
254 for pullAt, labels := range allLabels {
255 if p, ok := pulls[pullAt]; ok {
256 p.Labels = labels
257 }
258 }
259
260 // collect pull source for all pulls that need it
261 var sourceAts []syntax.ATURI
262 for _, p := range pulls {
263 if p.PullSource != nil && p.PullSource.RepoAt != nil {
264 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
265 }
266 }
267 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
268 if err != nil && !errors.Is(err, sql.ErrNoRows) {
269 return nil, fmt.Errorf("failed to get source repos: %w", err)
270 }
271 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
272 for _, r := range sourceRepos {
273 sourceRepoMap[r.RepoAt()] = &r
274 }
275 for _, p := range pulls {
276 if p.PullSource != nil && p.PullSource.RepoAt != nil {
277 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
278 p.PullSource.Repo = sourceRepo
279 }
280 }
281 }
282
283 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
284 if err != nil {
285 return nil, fmt.Errorf("failed to query reference_links: %w", err)
286 }
287 for pullAt, references := range allReferences {
288 if pull, ok := pulls[pullAt]; ok {
289 pull.References = references
290 }
291 }
292
293 orderedByPullId := []*models.Pull{}
294 for _, p := range pulls {
295 orderedByPullId = append(orderedByPullId, p)
296 }
297 sort.Slice(orderedByPullId, func(i, j int) bool {
298 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
299 })
300
301 return orderedByPullId, nil
302}
303
304func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
305 return GetPullsPaginated(e, pagination.Page{}, filters...)
306}
307
308func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
309 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
310 if err != nil {
311 return nil, err
312 }
313 if len(pulls) == 0 {
314 return nil, sql.ErrNoRows
315 }
316
317 return pulls[0], nil
318}
319
320// mapping from pull -> pull submissions
321func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
322 var conditions []string
323 var args []any
324 for _, filter := range filters {
325 conditions = append(conditions, filter.Condition())
326 args = append(args, filter.Arg()...)
327 }
328
329 whereClause := ""
330 if conditions != nil {
331 whereClause = " where " + strings.Join(conditions, " and ")
332 }
333
334 query := fmt.Sprintf(`
335 select
336 id,
337 pull_at,
338 round_number,
339 patch,
340 combined,
341 created,
342 source_rev
343 from
344 pull_submissions
345 %s
346 order by
347 round_number asc
348 `, whereClause)
349
350 rows, err := e.Query(query, args...)
351 if err != nil {
352 return nil, err
353 }
354 defer rows.Close()
355
356 submissionMap := make(map[int]*models.PullSubmission)
357
358 for rows.Next() {
359 var submission models.PullSubmission
360 var submissionCreatedStr string
361 var submissionSourceRev, submissionCombined sql.NullString
362 err := rows.Scan(
363 &submission.ID,
364 &submission.PullAt,
365 &submission.RoundNumber,
366 &submission.Patch,
367 &submissionCombined,
368 &submissionCreatedStr,
369 &submissionSourceRev,
370 )
371 if err != nil {
372 return nil, err
373 }
374
375 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
376 submission.Created = t
377 }
378
379 if submissionSourceRev.Valid {
380 submission.SourceRev = submissionSourceRev.String
381 }
382
383 if submissionCombined.Valid {
384 submission.Combined = submissionCombined.String
385 }
386
387 submissionMap[submission.ID] = &submission
388 }
389
390 if err := rows.Err(); err != nil {
391 return nil, err
392 }
393
394 // Get comments for all submissions using GetPullComments
395 submissionIds := slices.Collect(maps.Keys(submissionMap))
396 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
397 if err != nil {
398 return nil, fmt.Errorf("failed to get pull comments: %w", err)
399 }
400 for _, comment := range comments {
401 if submission, ok := submissionMap[comment.SubmissionId]; ok {
402 submission.Comments = append(submission.Comments, comment)
403 }
404 }
405
406 // group the submissions by pull_at
407 m := make(map[syntax.ATURI][]*models.PullSubmission)
408 for _, s := range submissionMap {
409 m[s.PullAt] = append(m[s.PullAt], s)
410 }
411
412 // sort each one by round number
413 for _, s := range m {
414 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
415 return cmp.Compare(a.RoundNumber, b.RoundNumber)
416 })
417 }
418
419 return m, nil
420}
421
422func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
423 var conditions []string
424 var args []any
425 for _, filter := range filters {
426 conditions = append(conditions, filter.Condition())
427 args = append(args, filter.Arg()...)
428 }
429
430 whereClause := ""
431 if conditions != nil {
432 whereClause = " where " + strings.Join(conditions, " and ")
433 }
434
435 query := fmt.Sprintf(`
436 select
437 id,
438 pull_id,
439 submission_id,
440 repo_at,
441 owner_did,
442 comment_at,
443 body,
444 created
445 from
446 pull_comments
447 %s
448 order by
449 created asc
450 `, whereClause)
451
452 rows, err := e.Query(query, args...)
453 if err != nil {
454 return nil, err
455 }
456 defer rows.Close()
457
458 commentMap := make(map[string]*models.PullComment)
459 for rows.Next() {
460 var comment models.PullComment
461 var createdAt string
462 err := rows.Scan(
463 &comment.ID,
464 &comment.PullId,
465 &comment.SubmissionId,
466 &comment.RepoAt,
467 &comment.OwnerDid,
468 &comment.CommentAt,
469 &comment.Body,
470 &createdAt,
471 )
472 if err != nil {
473 return nil, err
474 }
475
476 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
477 comment.Created = t
478 }
479
480 atUri := comment.AtUri().String()
481 commentMap[atUri] = &comment
482 }
483
484 if err := rows.Err(); err != nil {
485 return nil, err
486 }
487
488 // collect references for each comments
489 commentAts := slices.Collect(maps.Keys(commentMap))
490 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
491 if err != nil {
492 return nil, fmt.Errorf("failed to query reference_links: %w", err)
493 }
494 for commentAt, references := range allReferencs {
495 if comment, ok := commentMap[commentAt.String()]; ok {
496 comment.References = references
497 }
498 }
499
500 var comments []models.PullComment
501 for _, c := range commentMap {
502 comments = append(comments, *c)
503 }
504
505 sort.Slice(comments, func(i, j int) bool {
506 return comments[i].Created.Before(comments[j].Created)
507 })
508
509 return comments, nil
510}
511
512// timeframe here is directly passed into the sql query filter, and any
513// timeframe in the past should be negative; e.g.: "-3 months"
514func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
515 var pulls []models.Pull
516
517 rows, err := e.Query(`
518 select
519 p.owner_did,
520 p.repo_at,
521 p.pull_id,
522 p.created,
523 p.title,
524 p.state,
525 r.did,
526 r.name,
527 r.knot,
528 r.rkey,
529 r.created
530 from
531 pulls p
532 join
533 repos r on p.repo_at = r.at_uri
534 where
535 p.owner_did = ? and p.created >= date ('now', ?)
536 order by
537 p.created desc`, did, timeframe)
538 if err != nil {
539 return nil, err
540 }
541 defer rows.Close()
542
543 for rows.Next() {
544 var pull models.Pull
545 var repo models.Repo
546 var pullCreatedAt, repoCreatedAt string
547 err := rows.Scan(
548 &pull.OwnerDid,
549 &pull.RepoAt,
550 &pull.PullId,
551 &pullCreatedAt,
552 &pull.Title,
553 &pull.State,
554 &repo.Did,
555 &repo.Name,
556 &repo.Knot,
557 &repo.Rkey,
558 &repoCreatedAt,
559 )
560 if err != nil {
561 return nil, err
562 }
563
564 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
565 if err != nil {
566 return nil, err
567 }
568 pull.Created = pullCreatedTime
569
570 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
571 if err != nil {
572 return nil, err
573 }
574 repo.Created = repoCreatedTime
575
576 pull.Repo = &repo
577
578 pulls = append(pulls, pull)
579 }
580
581 if err := rows.Err(); err != nil {
582 return nil, err
583 }
584
585 return pulls, nil
586}
587
588func UpsertPullComment(tx *sql.Tx, comment *models.PullComment) error {
589 panic("unimplemented")
590}
591
592func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
593 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
594 res, err := tx.Exec(
595 query,
596 comment.OwnerDid,
597 comment.RepoAt,
598 comment.SubmissionId,
599 comment.CommentAt,
600 comment.PullId,
601 comment.Body,
602 )
603 if err != nil {
604 return 0, err
605 }
606
607 i, err := res.LastInsertId()
608 if err != nil {
609 return 0, err
610 }
611
612 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
613 return 0, fmt.Errorf("put reference_links: %w", err)
614 }
615
616 return i, nil
617}
618
619func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
620 _, err := e.Exec(
621 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
622 pullState,
623 repoAt,
624 pullId,
625 models.PullDeleted, // only update state of non-deleted pulls
626 models.PullMerged, // only update state of non-merged pulls
627 )
628 return err
629}
630
631func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
632 err := SetPullState(e, repoAt, pullId, models.PullClosed)
633 return err
634}
635
636func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
637 err := SetPullState(e, repoAt, pullId, models.PullOpen)
638 return err
639}
640
641func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
642 err := SetPullState(e, repoAt, pullId, models.PullMerged)
643 return err
644}
645
646func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
647 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
648 return err
649}
650
651func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
652 _, err := e.Exec(`
653 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
654 values (?, ?, ?, ?, ?)
655 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
656
657 return err
658}
659
660func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
661 var conditions []string
662 var args []any
663
664 args = append(args, parentChangeId)
665
666 for _, filter := range filters {
667 conditions = append(conditions, filter.Condition())
668 args = append(args, filter.Arg()...)
669 }
670
671 whereClause := ""
672 if conditions != nil {
673 whereClause = " where " + strings.Join(conditions, " and ")
674 }
675
676 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
677 _, err := e.Exec(query, args...)
678
679 return err
680}
681
682// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
683// otherwise submissions are immutable
684func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
685 var conditions []string
686 var args []any
687
688 args = append(args, sourceRev)
689 args = append(args, newPatch)
690
691 for _, filter := range filters {
692 conditions = append(conditions, filter.Condition())
693 args = append(args, filter.Arg()...)
694 }
695
696 whereClause := ""
697 if conditions != nil {
698 whereClause = " where " + strings.Join(conditions, " and ")
699 }
700
701 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
702 _, err := e.Exec(query, args...)
703
704 return err
705}
706
707func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
708 row := e.QueryRow(`
709 select
710 count(case when state = ? then 1 end) as open_count,
711 count(case when state = ? then 1 end) as merged_count,
712 count(case when state = ? then 1 end) as closed_count,
713 count(case when state = ? then 1 end) as deleted_count
714 from pulls
715 where repo_at = ?`,
716 models.PullOpen,
717 models.PullMerged,
718 models.PullClosed,
719 models.PullDeleted,
720 repoAt,
721 )
722
723 var count models.PullCount
724 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
725 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
726 }
727
728 return count, nil
729}
730
731// change-id parent-change-id
732//
733// 4 w ,-------- z (TOP)
734// 3 z <----',------- y
735// 2 y <-----',------ x
736// 1 x <------' nil (BOT)
737//
738// `w` is parent of none, so it is the top of the stack
739func GetStack(e Execer, stackId string) (models.Stack, error) {
740 unorderedPulls, err := GetPulls(
741 e,
742 orm.FilterEq("stack_id", stackId),
743 orm.FilterNotEq("state", models.PullDeleted),
744 )
745 if err != nil {
746 return nil, err
747 }
748 // map of parent-change-id to pull
749 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
750 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
751 for _, p := range unorderedPulls {
752 changeIdMap[p.ChangeId] = p
753 if p.ParentChangeId != "" {
754 parentMap[p.ParentChangeId] = p
755 }
756 }
757
758 // the top of the stack is the pull that is not a parent of any pull
759 var topPull *models.Pull
760 for _, maybeTop := range unorderedPulls {
761 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
762 topPull = maybeTop
763 break
764 }
765 }
766
767 pulls := []*models.Pull{}
768 for {
769 pulls = append(pulls, topPull)
770 if topPull.ParentChangeId != "" {
771 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
772 topPull = next
773 } else {
774 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
775 }
776 } else {
777 break
778 }
779 }
780
781 return pulls, nil
782}
783
784func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
785 pulls, err := GetPulls(
786 e,
787 orm.FilterEq("stack_id", stackId),
788 orm.FilterEq("state", models.PullDeleted),
789 )
790 if err != nil {
791 return nil, err
792 }
793
794 return pulls, nil
795}