From 8543812e02404034df5b111621da724c3a82f84b Mon Sep 17 00:00:00 2001 From: Ophestra Date: Sat, 9 Aug 2025 17:50:03 +0900 Subject: [PATCH] container/absolute: early absolute pathname check This is less error-prone, and allows pathname to be checked once. Signed-off-by: Ophestra --- container/absolute.go | 78 ++++++++++ container/absolute_test.go | 284 +++++++++++++++++++++++++++++++++++++ 2 files changed, 362 insertions(+) create mode 100644 container/absolute.go create mode 100644 container/absolute_test.go diff --git a/container/absolute.go b/container/absolute.go new file mode 100644 index 0000000..0eee497 --- /dev/null +++ b/container/absolute.go @@ -0,0 +1,78 @@ +package container + +import ( + "encoding/json" + "errors" + "fmt" + "path" + "syscall" +) + +// AbsoluteError is returned by [NewAbsolute] and holds the invalid pathname. +type AbsoluteError struct { + Pathname string +} + +func (e *AbsoluteError) Error() string { return fmt.Sprintf("path %q is not absolute", e.Pathname) } +func (e *AbsoluteError) Is(target error) bool { + var ce *AbsoluteError + if !errors.As(target, &ce) { + return errors.Is(target, syscall.EINVAL) + } + return *e == *ce +} + +// Absolute holds a pathname checked to be absolute. +type Absolute struct { + pathname string +} + +// isAbs wraps [path.IsAbs] in case additional checks are added in the future. +func isAbs(pathname string) bool { return path.IsAbs(pathname) } + +func (a *Absolute) String() string { + if a.pathname == zeroString { + panic("attempted use of zero Absolute") + } + return a.pathname +} + +// NewAbsolute checks pathname and returns a new [Absolute] if pathname is absolute. +func NewAbsolute(pathname string) (*Absolute, error) { + if !isAbs(pathname) { + return nil, &AbsoluteError{pathname} + } + return &Absolute{pathname}, nil +} + +// MustAbsolute calls [NewAbsolute] and panics on error. +func MustAbsolute(pathname string) *Absolute { + if a, err := NewAbsolute(pathname); err != nil { + panic(err.Error()) + } else { + return a + } +} + +func (a *Absolute) GobEncode() ([]byte, error) { return []byte(a.String()), nil } +func (a *Absolute) GobDecode(data []byte) error { + pathname := string(data) + if !isAbs(pathname) { + return &AbsoluteError{pathname} + } + a.pathname = pathname + return nil +} + +func (a *Absolute) MarshalJSON() ([]byte, error) { return json.Marshal(a.String()) } +func (a *Absolute) UnmarshalJSON(data []byte) error { + var pathname string + if err := json.Unmarshal(data, &pathname); err != nil { + return err + } + if !isAbs(pathname) { + return &AbsoluteError{pathname} + } + a.pathname = pathname + return nil +} diff --git a/container/absolute_test.go b/container/absolute_test.go new file mode 100644 index 0000000..8a5b8b0 --- /dev/null +++ b/container/absolute_test.go @@ -0,0 +1,284 @@ +package container + +import ( + "bytes" + "encoding/gob" + "encoding/json" + "errors" + "reflect" + "strings" + "syscall" + "testing" +) + +func TestAbsoluteError(t *testing.T) { + testCases := []struct { + name string + + err error + cmp error + ok bool + }{ + {"EINVAL", new(AbsoluteError), syscall.EINVAL, true}, + {"not EINVAL", new(AbsoluteError), syscall.EBADE, false}, + {"ne val", new(AbsoluteError), &AbsoluteError{"etc"}, false}, + {"equals", &AbsoluteError{"etc"}, &AbsoluteError{"etc"}, true}, + } + + for _, tc := range testCases { + if got := errors.Is(tc.err, tc.cmp); got != tc.ok { + t.Errorf("Is: %v, want %v", got, tc.ok) + } + } + + t.Run("string", func(t *testing.T) { + want := `path "etc" is not absolute` + if got := (&AbsoluteError{"etc"}).Error(); got != want { + t.Errorf("Error: %q, want %q", got, want) + } + }) +} + +func TestNewAbsolute(t *testing.T) { + testCases := []struct { + name string + + pathname string + want *Absolute + wantErr error + }{ + {"good", "/etc", MustAbsolute("/etc"), nil}, + {"not absolute", "etc", nil, &AbsoluteError{"etc"}}, + {"zero", "", nil, &AbsoluteError{""}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewAbsolute(tc.pathname) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("NewAbsolute: %#v, want %#v", got, tc.want) + } + if !errors.Is(err, tc.wantErr) { + t.Errorf("NewAbsolute: error = %v, want %v", err, tc.wantErr) + } + }) + } + + t.Run("must", func(t *testing.T) { + defer func() { + wantPanic := `path "etc" is not absolute` + + if r := recover(); r != wantPanic { + t.Errorf("MustAbsolute: panic = %v; want %v", r, wantPanic) + } + }() + + MustAbsolute("etc") + }) +} + +func TestAbsoluteString(t *testing.T) { + t.Run("passthrough", func(t *testing.T) { + pathname := "/etc" + if got := (&Absolute{pathname}).String(); got != pathname { + t.Errorf("String: %q, want %q", got, pathname) + } + }) + + t.Run("zero", func(t *testing.T) { + defer func() { + wantPanic := "attempted use of zero Absolute" + + if r := recover(); r != wantPanic { + t.Errorf("String: panic = %v, want %v", r, wantPanic) + } + }() + + panic(new(Absolute).String()) + }) +} + +type sCheck struct { + Pathname *Absolute `json:"val"` + Magic int `json:"magic"` +} + +func TestCodecAbsolute(t *testing.T) { + testCases := []struct { + name string + a *Absolute + + wantErr error + + gob, sGob string + json, sJson string + }{ + {"good", MustAbsolute("/etc"), + nil, + "\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\b\xff\x80\x00\x04/etc", + ",\xff\x83\x03\x01\x01\x06sCheck\x01\xff\x84\x00\x01\x02\x01\bPathname\x01\xff\x80\x00\x01\x05Magic\x01\x04\x00\x00\x00\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\x10\xff\x84\x01\x04/etc\x01\xfb\x01\x81\xda\x00\x00\x00", + + `"/etc"`, `{"val":"/etc","magic":3236757504}`}, + {"not absolute", nil, + &AbsoluteError{"etc"}, + "\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\a\xff\x80\x00\x03etc", + ",\xff\x83\x03\x01\x01\x06sCheck\x01\xff\x84\x00\x01\x02\x01\bPathname\x01\xff\x80\x00\x01\x05Magic\x01\x04\x00\x00\x00\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\x0f\xff\x84\x01\x03etc\x01\xfb\x01\x81\xda\x00\x00\x00", + + `"etc"`, `{"val":"etc","magic":3236757504}`}, + {"zero", nil, + new(AbsoluteError), + "\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\x04\xff\x80\x00\x00", + ",\xff\x83\x03\x01\x01\x06sCheck\x01\xff\x84\x00\x01\x02\x01\bPathname\x01\xff\x80\x00\x01\x05Magic\x01\x04\x00\x00\x00\t\x7f\x05\x01\x02\xff\x82\x00\x00\x00\f\xff\x84\x01\x00\x01\xfb\x01\x81\xda\x00\x00\x00", + `""`, `{"val":"","magic":3236757504}`}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("gob", func(t *testing.T) { + t.Run("encode", func(t *testing.T) { + // encode is unchecked + if errors.Is(tc.wantErr, syscall.EINVAL) { + return + } + + { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(tc.a) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Encode: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + goto checkSEncode + } + if buf.String() != tc.gob { + t.Errorf("Encode:\n%q\nwant:\n%q", buf.String(), tc.gob) + } + } + + checkSEncode: + { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(&sCheck{tc.a, syscall.MS_MGC_VAL}) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Encode: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + return + } + if buf.String() != tc.sGob { + t.Errorf("Encode:\n%q\nwant:\n%q", buf.String(), tc.sGob) + } + } + }) + + t.Run("decode", func(t *testing.T) { + { + var gotA *Absolute + err := gob.NewDecoder(strings.NewReader(tc.gob)).Decode(&gotA) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Decode: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + goto checkSDecode + } + if !reflect.DeepEqual(tc.a, gotA) { + t.Errorf("Decode: %#v, want %#v", tc.a, gotA) + } + } + + checkSDecode: + { + var gotSCheck sCheck + err := gob.NewDecoder(strings.NewReader(tc.sGob)).Decode(&gotSCheck) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Decode: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + return + } + want := sCheck{tc.a, syscall.MS_MGC_VAL} + if !reflect.DeepEqual(gotSCheck, want) { + t.Errorf("Decode: %#v, want %#v", gotSCheck, want) + } + } + }) + + }) + + t.Run("json", func(t *testing.T) { + t.Run("marshal", func(t *testing.T) { + // marshal is unchecked + if errors.Is(tc.wantErr, syscall.EINVAL) { + return + } + + { + d, err := json.Marshal(tc.a) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Marshal: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + goto checkSMarshal + } + if string(d) != tc.json { + t.Errorf("Marshal:\n%s\nwant:\n%s", string(d), tc.json) + } + } + + checkSMarshal: + { + d, err := json.Marshal(&sCheck{tc.a, syscall.MS_MGC_VAL}) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Marshal: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + return + } + if string(d) != tc.sJson { + t.Errorf("Marshal:\n%s\nwant:\n%s", string(d), tc.sJson) + } + } + }) + + t.Run("unmarshal", func(t *testing.T) { + { + var gotA *Absolute + err := json.Unmarshal([]byte(tc.json), &gotA) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Unmarshal: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + goto checkSUnmarshal + } + if !reflect.DeepEqual(tc.a, gotA) { + t.Errorf("Unmarshal: %#v, want %#v", tc.a, gotA) + } + } + + checkSUnmarshal: + { + var gotSCheck sCheck + err := json.Unmarshal([]byte(tc.sJson), &gotSCheck) + if !errors.Is(err, tc.wantErr) { + t.Errorf("Unmarshal: error = %v, want %v", err, tc.wantErr) + } + if tc.wantErr != nil { + return + } + want := sCheck{tc.a, syscall.MS_MGC_VAL} + if !reflect.DeepEqual(gotSCheck, want) { + t.Errorf("Unmarshal: %#v, want %#v", gotSCheck, want) + } + } + }) + }) + }) + } + + t.Run("json passthrough", func(t *testing.T) { + wantErr := "invalid character ':' looking for beginning of value" + if err := new(Absolute).UnmarshalJSON([]byte(":3")); err == nil || err.Error() != wantErr { + t.Errorf("UnmarshalJSON: error = %v, want %s", err, wantErr) + } + }) +}