internal/pkg: expose response body
All checks were successful
Test / Create distribution (push) Successful in 49s
Test / Sandbox (push) Successful in 2m37s
Test / Hakurei (push) Successful in 3m52s
Test / ShareFS (push) Successful in 3m56s
Test / Hpkg (push) Successful in 4m26s
Test / Sandbox (race detector) (push) Successful in 4m56s
Test / Hakurei (race detector) (push) Successful in 5m52s
Test / Flake checks (push) Successful in 1m39s

This uses the new measured reader provided by Cache. This should make httpArtifact zero-copy.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-01-25 16:07:07 +09:00
parent 334578fdde
commit 861801597d
5 changed files with 156 additions and 66 deletions

View File

@@ -2,7 +2,6 @@ package pkg
import ( import (
"bytes" "bytes"
"context"
"crypto/sha512" "crypto/sha512"
"fmt" "fmt"
"io" "io"
@@ -54,6 +53,6 @@ func (a *fileArtifact) Checksum() Checksum {
} }
// Cure returns the caller-supplied data. // Cure returns the caller-supplied data.
func (a *fileArtifact) Cure(context.Context) (io.ReadCloser, error) { func (a *fileArtifact) Cure(*RContext) (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(*a)), nil return io.NopCloser(bytes.NewReader(*a)), nil
} }

View File

@@ -1,14 +1,11 @@
package pkg package pkg
import ( import (
"bytes"
"context"
"crypto/sha512"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"path" "path"
"sync" "unique"
) )
// An httpArtifact is an [Artifact] backed by a [http] url string. The method is // An httpArtifact is an [Artifact] backed by a [http] url string. The method is
@@ -18,18 +15,12 @@ type httpArtifact struct {
// Caller-supplied url string. // Caller-supplied url string.
url string url string
// Caller-supplied checksum of the response body. This is validated during // Caller-supplied checksum of the response body. This is validated when
// curing and the first call to Data. // closing the [io.ReadCloser] returned by Cure.
checksum Checksum checksum unique.Handle[Checksum]
// doFunc is the Do method of [http.Client] supplied by the caller. // doFunc is the Do method of [http.Client] supplied by the caller.
doFunc func(req *http.Request) (*http.Response, error) doFunc func(req *http.Request) (*http.Response, error)
// Response body read to EOF.
data []byte
// Synchronises access to data.
mu sync.Mutex
} }
var _ KnownChecksum = new(httpArtifact) var _ KnownChecksum = new(httpArtifact)
@@ -45,7 +36,7 @@ func NewHTTPGet(
if c == nil { if c == nil {
c = http.DefaultClient c = http.DefaultClient
} }
return &httpArtifact{url: url, checksum: checksum, doFunc: c.Do} return &httpArtifact{url: url, checksum: unique.Make(checksum), doFunc: c.Do}
} }
// Kind returns the hardcoded [Kind] constant. // Kind returns the hardcoded [Kind] constant.
@@ -61,7 +52,7 @@ func (a *httpArtifact) Params(ctx *IContext) {
func (a *httpArtifact) Dependencies() []Artifact { return nil } func (a *httpArtifact) Dependencies() []Artifact { return nil }
// Checksum returns the caller-supplied checksum. // Checksum returns the caller-supplied checksum.
func (a *httpArtifact) Checksum() Checksum { return a.checksum } func (a *httpArtifact) Checksum() Checksum { return a.checksum.Value() }
// String returns [path.Base] over the backing url. // String returns [path.Base] over the backing url.
func (a *httpArtifact) String() string { return path.Base(a.url) } func (a *httpArtifact) String() string { return path.Base(a.url) }
@@ -74,11 +65,13 @@ func (e ResponseStatusError) Error() string {
return "the requested URL returned non-OK status: " + http.StatusText(int(e)) return "the requested URL returned non-OK status: " + http.StatusText(int(e))
} }
// do sends the caller-supplied request on the caller-supplied [http.Client] // Cure sends the http request and returns the resulting response body reader
// and reads its response body to EOF and returns the resulting bytes. // wrapped to perform checksum validation. It is valid but not encouraged to
func (a *httpArtifact) do(ctx context.Context) (data []byte, err error) { // close the resulting [io.ReadCloser] before it is read to EOF, as that causes
// Close to block until all remaining data is consumed and validated.
func (a *httpArtifact) Cure(r *RContext) (rc io.ReadCloser, err error) {
var req *http.Request var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, a.url, nil) req, err = http.NewRequestWithContext(r.Unwrap(), http.MethodGet, a.url, nil)
if err != nil { if err != nil {
return return
} }
@@ -93,37 +86,6 @@ func (a *httpArtifact) do(ctx context.Context) (data []byte, err error) {
return nil, ResponseStatusError(resp.StatusCode) return nil, ResponseStatusError(resp.StatusCode)
} }
if data, err = io.ReadAll(resp.Body); err != nil { rc = r.NewMeasuredReader(resp.Body, a.checksum)
_ = resp.Body.Close()
return
}
err = resp.Body.Close()
return
}
// Cure completes the http request and returns the resulting response body read
// to EOF. Data does not interact with the filesystem.
func (a *httpArtifact) Cure(ctx context.Context) (r io.ReadCloser, err error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.data != nil {
// validated by cache or a previous call to Cure
return io.NopCloser(bytes.NewReader(a.data)), nil
}
var data []byte
if data, err = a.do(ctx); err != nil {
return
}
h := sha512.New384()
h.Write(data)
if got := (Checksum)(h.Sum(nil)); got != a.checksum {
return nil, &ChecksumMismatchError{got, a.checksum}
}
a.data = data
r = io.NopCloser(bytes.NewReader(data))
return return
} }

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"testing/fstest" "testing/fstest"
"unique" "unique"
"unsafe"
"hakurei.app/container/check" "hakurei.app/container/check"
"hakurei.app/internal/pkg" "hakurei.app/internal/pkg"
@@ -32,6 +33,12 @@ func TestHTTPGet(t *testing.T) {
checkWithCache(t, []cacheTestCase{ checkWithCache(t, []cacheTestCase{
{"direct", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { {"direct", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
var r pkg.RContext
rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache")
reflect.NewAt(
rCacheVal.Type(),
unsafe.Pointer(rCacheVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(c))
f := pkg.NewHTTPGet( f := pkg.NewHTTPGet(
&client, &client,
@@ -39,12 +46,14 @@ func TestHTTPGet(t *testing.T) {
testdataChecksum.Value(), testdataChecksum.Value(),
) )
var got []byte var got []byte
if r, err := f.Cure(t.Context()); err != nil { if rc, err := f.Cure(&r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(r); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
} else if string(got) != testdata { } else if string(got) != testdata {
t.Fatalf("Cure: %x, want %x", got, testdata) t.Fatalf("Cure: %x, want %x", got, testdata)
} else if err = rc.Close(); err != nil {
t.Fatalf("Close: error = %v", err)
} }
// check direct validation // check direct validation
@@ -56,8 +65,21 @@ func TestHTTPGet(t *testing.T) {
wantErrMismatch := &pkg.ChecksumMismatchError{ wantErrMismatch := &pkg.ChecksumMismatchError{
Got: testdataChecksum.Value(), Got: testdataChecksum.Value(),
} }
if _, err := f.Cure(t.Context()); !reflect.DeepEqual(err, wantErrMismatch) { if rc, err := f.Cure(&r); err != nil {
t.Fatalf("Cure: error = %#v, want %#v", err, wantErrMismatch) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err)
} else if string(got) != testdata {
t.Fatalf("Cure: %x, want %x", got, testdata)
} else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) {
t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch)
}
// check fallback validation
if rc, err := f.Cure(&r); err != nil {
t.Fatalf("Cure: error = %v", err)
} else if err = rc.Close(); !reflect.DeepEqual(err, wantErrMismatch) {
t.Fatalf("Close: error = %#v, want %#v", err, wantErrMismatch)
} }
// check direct response error // check direct response error
@@ -67,12 +89,19 @@ func TestHTTPGet(t *testing.T) {
pkg.Checksum{}, pkg.Checksum{},
) )
wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound) wantErrNotFound := pkg.ResponseStatusError(http.StatusNotFound)
if _, err := f.Cure(t.Context()); !reflect.DeepEqual(err, wantErrNotFound) { if _, err := f.Cure(&r); !reflect.DeepEqual(err, wantErrNotFound) {
t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound) t.Fatalf("Cure: error = %#v, want %#v", err, wantErrNotFound)
} }
}, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")}, }, pkg.MustDecode("E4vEZKhCcL2gPZ2Tt59FS3lDng-d_2SKa2i5G_RbDfwGn6EemptFaGLPUDiOa94C")},
{"cure", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) { {"cure", nil, func(t *testing.T, base *check.Absolute, c *pkg.Cache) {
var r pkg.RContext
rCacheVal := reflect.ValueOf(&r).Elem().FieldByName("cache")
reflect.NewAt(
rCacheVal.Type(),
unsafe.Pointer(rCacheVal.UnsafeAddr()),
).Elem().Set(reflect.ValueOf(c))
f := pkg.NewHTTPGet( f := pkg.NewHTTPGet(
&client, &client,
"file:///testdata", "file:///testdata",
@@ -91,12 +120,14 @@ func TestHTTPGet(t *testing.T) {
} }
var got []byte var got []byte
if r, err := f.Cure(t.Context()); err != nil { if rc, err := f.Cure(&r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(r); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
} else if string(got) != testdata { } else if string(got) != testdata {
t.Fatalf("Cure: %x, want %x", got, testdata) t.Fatalf("Cure: %x, want %x", got, testdata)
} else if err = rc.Close(); err != nil {
t.Fatalf("Close: error = %v", err)
} }
// check load from cache // check load from cache
@@ -105,12 +136,14 @@ func TestHTTPGet(t *testing.T) {
"file:///testdata", "file:///testdata",
testdataChecksum.Value(), testdataChecksum.Value(),
) )
if r, err := f.Cure(t.Context()); err != nil { if rc, err := f.Cure(&r); err != nil {
t.Fatalf("Cure: error = %v", err) t.Fatalf("Cure: error = %v", err)
} else if got, err = io.ReadAll(r); err != nil { } else if got, err = io.ReadAll(rc); err != nil {
t.Fatalf("ReadAll: error = %v", err) t.Fatalf("ReadAll: error = %v", err)
} else if string(got) != testdata { } else if string(got) != testdata {
t.Fatalf("Cure: %x, want %x", got, testdata) t.Fatalf("Cure: %x, want %x", got, testdata)
} else if err = rc.Close(); err != nil {
t.Fatalf("Close: error = %v", err)
} }
// check error passthrough // check error passthrough

View File

@@ -215,6 +215,20 @@ func (f *FContext) GetArtifact(a Artifact) (
panic(InvalidLookupError(f.cache.Ident(a).Value())) panic(InvalidLookupError(f.cache.Ident(a).Value()))
} }
// RContext is passed to [FileArtifact.Cure] and provides helper methods useful
// for curing the [FileArtifact].
//
// Methods of RContext are safe for concurrent use. RContext is valid
// until [FileArtifact.Cure] returns.
type RContext struct {
// Address of underlying [Cache], should be zeroed or made unusable after
// [FileArtifact.Cure] returns and must not be exposed directly.
cache *Cache
}
// Unwrap returns the underlying [context.Context].
func (r *RContext) Unwrap() context.Context { return r.cache.ctx }
// An Artifact is a read-only reference to a piece of data that may be created // An Artifact is a read-only reference to a piece of data that may be created
// deterministically but might not currently be available in memory or on the // deterministically but might not currently be available in memory or on the
// filesystem. // filesystem.
@@ -323,7 +337,7 @@ type FileArtifact interface {
// Callers are responsible for closing the resulting [io.ReadCloser]. // Callers are responsible for closing the resulting [io.ReadCloser].
// //
// Result must remain identical across multiple invocations. // Result must remain identical across multiple invocations.
Cure(ctx context.Context) (io.ReadCloser, error) Cure(r *RContext) (io.ReadCloser, error)
Artifact Artifact
} }
@@ -952,7 +966,7 @@ func (c *Cache) openFile(f FileArtifact) (r io.ReadCloser, err error) {
} }
}() }()
} }
return f.Cure(c.ctx) return f.Cure(&RContext{c})
} }
return return
} }
@@ -1229,6 +1243,88 @@ func (c *Cache) getWriter(w io.Writer) *bufio.Writer {
return bw return bw
} }
// measuredReader implements [io.ReadCloser] and measures the checksum during
// Close. If the underlying reader is not read to EOF, Close blocks until all
// remaining data is consumed and validated.
type measuredReader struct {
// Underlying reader. Never exposed directly.
r io.ReadCloser
// For validating checksum. Never exposed directly.
h hash.Hash
// Buffers writes to h, initialised by [Cache]. Never exposed directly.
hbw *bufio.Writer
// Expected checksum, compared during Close.
want unique.Handle[Checksum]
// For accessing free lists.
c *Cache
// Set up via [io.TeeReader] by [Cache].
io.Reader
}
// Close reads the underlying [io.ReadCloser] to EOF, closes it and measures its
// outcome. It returns a [ChecksumMismatchError] for an unexpected checksum.
func (mr *measuredReader) Close() (err error) {
if mr.hbw == nil || mr.Reader == nil {
return os.ErrInvalid
}
err = mr.hbw.Flush()
mr.c.putWriter(mr.hbw)
mr.hbw, mr.Reader = nil, nil
if err != nil {
_ = mr.r.Close()
return
}
var n int64
if n, err = io.Copy(mr.h, mr.r); err != nil {
_ = mr.r.Close()
return
}
if n > 0 {
mr.c.msg.Verbosef("missed %d bytes on measured reader", n)
}
if err = mr.r.Close(); err != nil {
return
}
buf := mr.c.getIdentBuf()
mr.h.Sum(buf[:0])
if got := Checksum(buf[:]); got != mr.want.Value() {
err = &ChecksumMismatchError{
Got: got,
Want: mr.want.Value(),
}
}
mr.c.putIdentBuf(buf)
return
}
// newMeasuredReader implements [RContext.NewMeasuredReader].
func (c *Cache) newMeasuredReader(
r io.ReadCloser,
checksum unique.Handle[Checksum],
) io.ReadCloser {
mr := measuredReader{r: r, h: sha512.New384(), want: checksum, c: c}
mr.hbw = c.getWriter(mr.h)
mr.Reader = io.TeeReader(r, mr.hbw)
return &mr
}
// NewMeasuredReader returns an [io.ReadCloser] implementing behaviour required
// by [FileArtifact]. The resulting [io.ReadCloser] holds a buffer originating
// from [Cache] and must be closed to return this buffer.
func (r *RContext) NewMeasuredReader(
rc io.ReadCloser,
checksum unique.Handle[Checksum],
) io.ReadCloser {
return r.cache.newMeasuredReader(rc, checksum)
}
// putWriter adds bw to bufioPool. // putWriter adds bw to bufioPool.
func (c *Cache) putWriter(bw *bufio.Writer) { c.bufioPool.Put(bw) } func (c *Cache) putWriter(bw *bufio.Writer) { c.bufioPool.Put(bw) }
@@ -1363,7 +1459,7 @@ func (c *Cache) cure(a Artifact, curesExempt bool) (
if err = c.enterCure(curesExempt); err != nil { if err = c.enterCure(curesExempt); err != nil {
return return
} }
r, err = f.Cure(c.ctx) r, err = f.Cure(&RContext{c})
if err == nil { if err == nil {
if checksumPathname == nil || c.IsStrict() { if checksumPathname == nil || c.IsStrict() {
h := sha512.New384() h := sha512.New384()

View File

@@ -119,7 +119,7 @@ type stubFile struct {
stubArtifact stubArtifact
} }
func (a *stubFile) Cure(context.Context) (io.ReadCloser, error) { func (a *stubFile) Cure(*pkg.RContext) (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(a.data)), a.err return io.NopCloser(bytes.NewReader(a.data)), a.err
} }