diff --git a/internal/pkg/dir.go b/internal/pkg/dir.go index 780b54e..3e7ad6c 100644 --- a/internal/pkg/dir.go +++ b/internal/pkg/dir.go @@ -8,6 +8,7 @@ import ( "io/fs" "math" "os" + "path/filepath" "syscall" "hakurei.app/container/check" @@ -52,8 +53,12 @@ func (ent *FlatEntry) Encode(w io.Writer) (n int, err error) { return w.Write(payload) } +// ErrInsecurePath is returned by [FlatEntry.Decode] if validation is requested +// and a nonlocal path is encountered in the stream. +var ErrInsecurePath = errors.New("insecure file path") + // Decode decodes the entry from its representation produced by Encode. -func (ent *FlatEntry) Decode(r io.Reader) (n int, err error) { +func (ent *FlatEntry) Decode(r io.Reader, validate bool) (n int, err error) { var nr int header := make([]byte, wordSize*2) @@ -92,6 +97,11 @@ func (ent *FlatEntry) Decode(r io.Reader) (n int, err error) { } else { ent.Data = buf[pPathSize : pPathSize+dataSize] } + + if validate && !filepath.IsLocal(ent.Path) { + err = ErrInsecurePath + } + return } @@ -108,11 +118,16 @@ type DirScanner struct { // Entry to store results in. Its address is returned by the Entry method // and is updated on every call to Scan. ent FlatEntry + + // Validate pathnames during decoding. + validate bool } // NewDirScanner returns the address of a new instance of [DirScanner] reading // from r. The caller must no longer read from r after this function returns. -func NewDirScanner(r io.Reader) *DirScanner { return &DirScanner{r: r} } +func NewDirScanner(r io.Reader, validate bool) *DirScanner { + return &DirScanner{r: r, validate: validate} +} // Err returns the first non-EOF I/O error. func (s *DirScanner) Err() error { @@ -132,7 +147,7 @@ func (s *DirScanner) Scan() bool { } var n int - n, s.err = s.ent.Decode(s.r) + n, s.err = s.ent.Decode(s.r, s.validate) if errors.Is(s.err, io.EOF) { return n != 0 } diff --git a/internal/pkg/dir_test.go b/internal/pkg/dir_test.go index 35fbbb3..9766027 100644 --- a/internal/pkg/dir_test.go +++ b/internal/pkg/dir_test.go @@ -85,7 +85,7 @@ func TestFlatten(t *testing.T) { t.Fatalf("Flatten: error = %v", err) } - s := pkg.NewDirScanner(bytes.NewReader(buf.Bytes())) + s := pkg.NewDirScanner(bytes.NewReader(buf.Bytes()), true) var got []pkg.FlatEntry for s.Scan() { got = append(got, *s.Entry())