diff --git a/internal/pkg/file.go b/internal/pkg/file.go index 03197c5..9e075cd 100644 --- a/internal/pkg/file.go +++ b/internal/pkg/file.go @@ -2,7 +2,6 @@ package pkg import ( "bytes" - "context" "crypto/sha512" "fmt" "io" @@ -54,6 +53,6 @@ func (a *fileArtifact) Checksum() Checksum { } // Cure returns the caller-supplied data. -func (a *fileArtifact) Cure(context.Context) (io.ReadCloser, error) { +func (a *fileArtifact) Cure(*RContext) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(*a)), nil } diff --git a/internal/pkg/net.go b/internal/pkg/net.go index 8726fca..00c2bc6 100644 --- a/internal/pkg/net.go +++ b/internal/pkg/net.go @@ -1,14 +1,11 @@ package pkg import ( - "bytes" - "context" - "crypto/sha512" "fmt" "io" "net/http" "path" - "sync" + "unique" ) // An httpArtifact is an [Artifact] backed by a [http] url string. The method is @@ -18,18 +15,12 @@ type httpArtifact struct { // Caller-supplied url string. url string - // Caller-supplied checksum of the response body. This is validated during - // curing and the first call to Data. - checksum Checksum + // Caller-supplied checksum of the response body. This is validated when + // closing the [io.ReadCloser] returned by Cure. + checksum unique.Handle[Checksum] // doFunc is the Do method of [http.Client] supplied by the caller. doFunc func(req *http.Request) (*http.Response, error) - - // Response body read to EOF. - data []byte - - // Synchronises access to data. - mu sync.Mutex } var _ KnownChecksum = new(httpArtifact) @@ -45,7 +36,7 @@ func NewHTTPGet( if c == nil { c = http.DefaultClient } - return &httpArtifact{url: url, checksum: checksum, doFunc: c.Do} + return &httpArtifact{url: url, checksum: unique.Make(checksum), doFunc: c.Do} } // Kind returns the hardcoded [Kind] constant. @@ -61,7 +52,7 @@ func (a *httpArtifact) Params(ctx *IContext) { func (a *httpArtifact) Dependencies() []Artifact { return nil } // Checksum returns the caller-supplied checksum. -func (a *httpArtifact) Checksum() Checksum { return a.checksum } +func (a *httpArtifact) Checksum() Checksum { return a.checksum.Value() } // String returns [path.Base] over the backing url. func (a *httpArtifact) String() string { return path.Base(a.url) } @@ -74,11 +65,13 @@ func (e ResponseStatusError) Error() string { return "the requested URL returned non-OK status: " + http.StatusText(int(e)) } -// do sends the caller-supplied request on the caller-supplied [http.Client] -// and reads its response body to EOF and returns the resulting bytes. -func (a *httpArtifact) do(ctx context.Context) (data []byte, err error) { +// Cure sends the http request and returns the resulting response body reader +// wrapped to perform checksum validation. It is valid but not encouraged to +// close the resulting [io.ReadCloser] before it is read to EOF, as that causes +// Close to block until all remaining data is consumed and validated. +func (a *httpArtifact) Cure(r *RContext) (rc io.ReadCloser, err error) { var req *http.Request - req, err = http.NewRequestWithContext(ctx, http.MethodGet, a.url, nil) + req, err = http.NewRequestWithContext(r.Unwrap(), http.MethodGet, a.url, nil) if err != nil { return } @@ -93,37 +86,6 @@ func (a *httpArtifact) do(ctx context.Context) (data []byte, err error) { return nil, ResponseStatusError(resp.StatusCode) } - if data, err = io.ReadAll(resp.Body); err != nil { - _ = resp.Body.Close() - return - } - - err = resp.Body.Close() - return -} - -// Cure completes the http request and returns the resulting response body read -// to EOF. Data does not interact with the filesystem. -func (a *httpArtifact) Cure(ctx context.Context) (r io.ReadCloser, err error) { - a.mu.Lock() - defer a.mu.Unlock() - - if a.data != nil { - // validated by cache or a previous call to Cure - return io.NopCloser(bytes.NewReader(a.data)), nil - } - - var data []byte - if data, err = a.do(ctx); err != nil { - return - } - - h := sha512.New384() - h.Write(data) - if got := (Checksum)(h.Sum(nil)); got != a.checksum { - return nil, &ChecksumMismatchError{got, a.checksum} - } - a.data = data - r = io.NopCloser(bytes.NewReader(data)) + rc = r.NewMeasuredReader(resp.Body, a.checksum) return } diff --git a/internal/pkg/net_test.go b/internal/pkg/net_test.go index c7c4906..afabcba 100644 --- a/internal/pkg/net_test.go +++ b/internal/pkg/net_test.go @@ -8,6 +8,7 @@ import ( "testing" "testing/fstest" "unique" + "unsafe" "hakurei.app/container/check" "hakurei.app/internal/pkg" @@ -32,6 +33,12 @@ func TestHTTPGet(t *testing.T) { checkWithCache(t, []cacheTestCase{ {"direct", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { + var r pkg.RContext + rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache") + reflect.NewAt( + rCacheVal.Type(), + unsafe.Pointer(rCacheVal.UnsafeAddr()), + ).Elem().Set(reflect.ValueOf(c)) f := pkg.NewHTTPGet( &client, @@ -39,12 +46,14 @@ func TestHTTPGet(t *testing.T) { testdataChecksum.Value(), ) var got []byte - if r, err := f.Cure(t.Context()); err != nil { + if rc, err := f.Cure(&r); err != nil { t.Fatalf("Cure: error = %v", err) - } else if got, err = io.ReadAll(r); err != nil { + } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) } else if string(got) != testdata { t.Fatalf("Cure: %x, want %x", got, testdata) + } else if err = rc.Close(); err != nil { + t.Fatalf("Close: error = %v", err) } // check direct validation @@ -56,8 +65,21 @@ func TestHTTPGet(t *testing.T) { wantErrMismatch := &pkg.ChecksumMismatchError{ Got: testdataChecksum.Value(), } - if _, err := f.Cure(t.Context()); !reflect.DeepEqual(err, wantErrMismatch) { - t.Fatalf("Cure: error = %#v, want %#v", err, wantErrMismatch) + if rc, err := f.Cure(&r); err != nil { + t.Fatalf("Cure: error = %v", err) + } else if got, err = io.ReadAll(rc); err != nil { + t.Fatalf("ReadAll: error = %v", err) + } else if string(got) != testdata { + t.Fatalf("Cure: %x, want %x", got, testdata) + } else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) { + t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch) + } + + // check fallback validation + if rc, err := f.Cure(&r); err != nil { + t.Fatalf("Cure: error = %v", err) + } else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) { + t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch) } // check direct response error @@ -67,12 +89,19 @@ func TestHTTPGet(t *testing.T) { pkg.Checksum{}, ) wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound) - if _, err := f.Cure(t.Context()); !reflect.DeepEqual(err, wantErrNotFound) { + if _, err := f.Cure(&r); !reflect.DeepEqual(err, wantErrNotFound) { t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound) } }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, {"cure", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { + var r pkg.RContext + rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache") + reflect.NewAt( + rCacheVal.Type(), + unsafe.Pointer(rCacheVal.UnsafeAddr()), + ).Elem().Set(reflect.ValueOf(c)) + f := pkg.NewHTTPGet( &client, "file:///testdata", @@ -91,12 +120,14 @@ func TestHTTPGet(t *testing.T) { } var got []byte - if r, err := f.Cure(t.Context()); err != nil { + if rc, err := f.Cure(&r); err != nil { t.Fatalf("Cure: error = %v", err) - } else if got, err = io.ReadAll(r); err != nil { + } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) } else if string(got) != testdata { t.Fatalf("Cure: %x, want %x", got, testdata) + } else if err = rc.Close(); err != nil { + t.Fatalf("Close: error = %v", err) } // check load from cache @@ -105,12 +136,14 @@ func TestHTTPGet(t *testing.T) { "file:///testdata", testdataChecksum.Value(), ) - if r, err := f.Cure(t.Context()); err != nil { + if rc, err := f.Cure(&r); err != nil { t.Fatalf("Cure: error = %v", err) - } else if got, err = io.ReadAll(r); err != nil { + } else if got, err = io.ReadAll(rc); err != nil { t.Fatalf("ReadAll: error = %v", err) } else if string(got) != testdata { t.Fatalf("Cure: %x, want %x", got, testdata) + } else if err = rc.Close(); err != nil { + t.Fatalf("Close: error = %v", err) } // check error passthrough diff --git a/internal/pkg/pkg.go b/internal/pkg/pkg.go index d029eff..b3b0851 100644 --- a/internal/pkg/pkg.go +++ b/internal/pkg/pkg.go @@ -215,6 +215,20 @@ func (f *FContext) GetArtifact(a Artifact) ( panic(InvalidLookupError(f.cache.Ident(a).Value())) } +// RContext is passed to [FileArtifact.Cure] and provides helper methods useful +// for curing the [FileArtifact]. +// +// Methods of RContext are safe for concurrent use. RContext is valid +// until [FileArtifact.Cure] returns. +type RContext struct { + // Address of underlying [Cache], should be zeroed or made unusable after + // [FileArtifact.Cure] returns and must not be exposed directly. + cache *Cache +} + +// Unwrap returns the underlying [context.Context]. +func (r *RContext) Unwrap() context.Context { return r.cache.ctx } + // An Artifact is a read-only reference to a piece of data that may be created // deterministically but might not currently be available in memory or on the // filesystem. @@ -323,7 +337,7 @@ type FileArtifact interface { // Callers are responsible for closing the resulting [io.ReadCloser]. // // Result must remain identical across multiple invocations. - Cure(ctx context.Context) (io.ReadCloser, error) + Cure(r *RContext) (io.ReadCloser, error) Artifact } @@ -952,7 +966,7 @@ func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) { } }() } - return f.Cure(c.ctx) + return f.Cure(&RContext{c}) } return } @@ -1229,6 +1243,88 @@ func (c *Cache) getWriter(w io.Writer) *bufio.Writer { return bw } +// measuredReader implements [io.ReadCloser] and measures the checksum during +// Close. If the underlying reader is not read to EOF, Close blocks until all +// remaining data is consumed and validated. +type measuredReader struct { + // Underlying reader. Never exposed directly. + r io.ReadCloser + // For validating checksum. Never exposed directly. + h hash.Hash + // Buffers writes to h, initialised by [Cache]. Never exposed directly. + hbw *bufio.Writer + // Expected checksum, compared during Close. + want unique.Handle[Checksum] + + // For accessing free lists. + c *Cache + + // Set up via [io.TeeReader] by [Cache]. + io.Reader +} + +// Close reads the underlying [io.ReadCloser] to EOF, closes it and measures its +// outcome. It returns a [ChecksumMismatchError] for an unexpected checksum. +func (mr *measuredReader) Close() (err error) { + if mr.hbw == nil || mr.Reader == nil { + return os.ErrInvalid + } + err = mr.hbw.Flush() + mr.c.putWriter(mr.hbw) + mr.hbw, mr.Reader = nil, nil + if err != nil { + _ = mr.r.Close() + return + } + var n int64 + if n, err = io.Copy(mr.h, mr.r); err != nil { + _ = mr.r.Close() + return + } + + if n > 0 { + mr.c.msg.Verbosef("missed %d bytes on measured reader", n) + } + + if err = mr.r.Close(); err != nil { + return + } + + buf := mr.c.getIdentBuf() + mr.h.Sum(buf[:0]) + + if got := Checksum(buf[:]); got != mr.want.Value() { + err = &ChecksumMismatchError{ + Got: got, + Want: mr.want.Value(), + } + } + + mr.c.putIdentBuf(buf) + return +} + +// newMeasuredReader implements [RContext.NewMeasuredReader]. +func (c *Cache) newMeasuredReader( + r io.ReadCloser, + checksum unique.Handle[Checksum], +) io.ReadCloser { + mr := measuredReader{r: r, h: sha512.New384(), want: checksum, c: c} + mr.hbw = c.getWriter(mr.h) + mr.Reader = io.TeeReader(r, mr.hbw) + return &mr +} + +// NewMeasuredReader returns an [io.ReadCloser] implementing behaviour required +// by [FileArtifact]. The resulting [io.ReadCloser] holds a buffer originating +// from [Cache] and must be closed to return this buffer. +func (r *RContext) NewMeasuredReader( + rc io.ReadCloser, + checksum unique.Handle[Checksum], +) io.ReadCloser { + return r.cache.newMeasuredReader(rc, checksum) +} + // putWriter adds bw to bufioPool. func (c *Cache) putWriter(bw *bufio.Writer) { c.bufioPool.Put(bw) } @@ -1363,7 +1459,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) ( if err = c.enterCure(curesExempt); err != nil { return } - r, err = f.Cure(c.ctx) + r, err = f.Cure(&RContext{c}) if err == nil { if checksumPathname == nil || c.IsStrict() { h := sha512.New384() diff --git a/internal/pkg/pkg_test.go b/internal/pkg/pkg_test.go index 903db56..fe618a6 100644 --- a/internal/pkg/pkg_test.go +++ b/internal/pkg/pkg_test.go @@ -119,7 +119,7 @@ type stubFile struct { stubArtifact } -func (a *stubFile) Cure(context.Context) (io.ReadCloser, error) { +func (a *stubFile) Cure(*pkg.RContext) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(a.data)), a.err }