diff --git a/internal/pkg/pkg.go b/internal/pkg/pkg.go index 4fc2d3b2..d7191907 100644 --- a/internal/pkg/pkg.go +++ b/internal/pkg/pkg.go @@ -542,6 +542,20 @@ const ( CHostAbstract ) +// toplevel holds [context.WithCancel] over caller-supplied context, where all +// [Artifact] context are derived from. +type toplevel struct { + ctx context.Context + cancel context.CancelFunc +} + +// newToplevel returns the address of a new toplevel via ctx. +func newToplevel(ctx context.Context) *toplevel { + var t toplevel + t.ctx, t.cancel = context.WithCancel(ctx) + return &t +} + // pendingCure provides synchronisation and cancellation for pending cures. type pendingCure struct { // Closed on cure completion. @@ -557,11 +571,10 @@ type Cache struct { // implementation and receives an equal amount of elements after. cures chan struct{} - // [context.WithCancel] over caller-supplied context, used by [Artifact] and - // all dependency curing goroutines. - ctx context.Context - // Cancels ctx. - cancel context.CancelFunc + // Parent context which toplevel was derived from. + parent context.Context + // For deriving curing context, must not be accessed directly. + toplevel atomic.Pointer[toplevel] // For waiting on dependency curing goroutines. wg sync.WaitGroup // Reports new cures and passed to [Artifact]. @@ -596,8 +609,10 @@ type Cache struct { // Unlocks the on-filesystem cache. Must only be called from Close. unlock func() - // Synchronises calls to Close. - closeOnce sync.Once + // Whether [Cache] is considered closed. + closed bool + // Synchronises calls to Abort and Close. + closeMu sync.Mutex // Whether EnterExec has not yet returned. inExec atomic.Bool @@ -1053,7 +1068,7 @@ func (c *Cache) loadOrStoreIdent(id unique.Handle[ID]) ( d := make(chan struct{}) pending = &pendingCure{done: d} - ctx, pending.cancel = context.WithCancel(c.ctx) + ctx, pending.cancel = context.WithCancel(c.toplevel.Load().ctx) c.identPending[id] = pending c.identMu.Unlock() done = d @@ -1295,7 +1310,7 @@ func (c *Cache) Cure(a Artifact) ( checksum unique.Handle[Checksum], err error, ) { - if err = c.ctx.Err(); err != nil { + if err = c.toplevel.Load().ctx.Err(); err != nil { return } @@ -1382,15 +1397,16 @@ func (c *Cache) enterCure(a Artifact, curesExempt bool) error { return nil } + ctx := c.toplevel.Load().ctx select { case c.cures <- struct{}{}: return nil - case <-c.ctx.Done(): + case <-ctx.Done(): if a.IsExclusive() { c.exclMu.Unlock() } - return c.ctx.Err() + return ctx.Err() } } @@ -1855,14 +1871,33 @@ func (c *Cache) OpenStatus(a Artifact) (r io.ReadSeekCloser, err error) { return } +// Abort cancels all pending cures but does not close the store. Abort does not +// wait for cures to complete. +func (c *Cache) Abort() { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.closed { + return + } + + c.toplevel.Swap(newToplevel(c.parent)).cancel() +} + // Close cancels all pending cures and waits for them to clean up. func (c *Cache) Close() { - c.closeOnce.Do(func() { - c.cancel() - c.wg.Wait() - close(c.cures) - c.unlock() - }) + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.closed { + return + } + + c.closed = true + c.toplevel.Load().cancel() + c.wg.Wait() + close(c.cures) + c.unlock() } // Open returns the address of a newly opened instance of [Cache]. @@ -1916,6 +1951,8 @@ func open( } c := Cache{ + parent: ctx, + cures: make(chan struct{}, cures), flags: flags, jobs: jobs, @@ -1932,7 +1969,7 @@ func open( brPool: sync.Pool{New: func() any { return new(bufio.Reader) }}, bwPool: sync.Pool{New: func() any { return new(bufio.Writer) }}, } - c.ctx, c.cancel = context.WithCancel(ctx) + c.toplevel.Store(newToplevel(ctx)) if lock || !testing.Testing() { if unlock, err := lockedfile.MutexAt( diff --git a/internal/pkg/pkg_test.go b/internal/pkg/pkg_test.go index dbaa0801..3110ed06 100644 --- a/internal/pkg/pkg_test.go +++ b/internal/pkg/pkg_test.go @@ -16,6 +16,7 @@ import ( "path/filepath" "reflect" "strconv" + "sync" "syscall" "testing" "unique" @@ -903,26 +904,46 @@ func TestCache(t *testing.T) { <-wCureDone }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, - {"cancel hanging", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { - started := make(chan struct{}) - go func() { - <-started - if !c.Cancel(unique.Make(pkg.ID{0xff})) { - panic("missed cancellation") + {"cancel abort block", pkg.CValidateKnown, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { + var wg sync.WaitGroup + defer wg.Wait() + + var started sync.WaitGroup + defer started.Wait() + + blockCures := func(d byte, n int) { + started.Add(n) + for i := range n { + wg.Go(func() { + if _, _, err := c.Cure(overrideIdent{pkg.ID{d, byte(i)}, &stubArtifact{ + kind: pkg.KindTar, + cure: func(t *pkg.TContext) error { + started.Done() + <-t.Unwrap().Done() + return stub.UniqueError(0xbad0 + i) + }, + }}); !reflect.DeepEqual(err, stub.UniqueError(0xbad0+i)) { + panic(err) + } + }) } - }() - if _, _, err := c.Cure(overrideIdent{pkg.ID{0xff}, &stubArtifact{ - kind: pkg.KindTar, - cure: func(t *pkg.TContext) error { - close(started) - <-t.Unwrap().Done() - return stub.UniqueError(0xbad) - }, - }}); !reflect.DeepEqual(err, stub.UniqueError(0xbad)) { - t.Fatalf("Cure: error = %v", err) + started.Wait() } + + blockCures(0xfd, 16) + c.Abort() + wg.Wait() + + blockCures(0xff, 1) + if !c.Cancel(unique.Make(pkg.ID{0xff})) { + t.Fatal("missed cancellation") + } + wg.Wait() for c.Cancel(unique.Make(pkg.ID{0xff})) { } + + c.Close() + c.Abort() }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, {"no assume checksum", 0, nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {