diff --git a/dbus/dbus_test.go b/dbus/dbus_test.go index 740ab52..0917d9d 100644 --- a/dbus/dbus_test.go +++ b/dbus/dbus_test.go @@ -8,6 +8,7 @@ import ( "os" "os/exec" "strings" + "syscall" "testing" "time" @@ -71,7 +72,7 @@ func TestProxy_Seal(t *testing.T) { for id, tc := range testCasePairs() { t.Run("create seal for "+id, func(t *testing.T) { p := dbus.New(tc[0].bus, tc[1].bus) - if err := p.Seal(tc[0].c, tc[1].c); (errors.Is(err, helper.ErrContainsNull)) != tc[0].wantErr { + if err := p.Seal(tc[0].c, tc[1].c); (errors.Is(err, syscall.EINVAL)) != tc[0].wantErr { t.Errorf("Seal(%p, %p) error = %v, wantErr %v", tc[0].c, tc[1].c, err, tc[0].wantErr) diff --git a/helper/args.go b/helper/args.go index f85d57e..88b5e69 100644 --- a/helper/args.go +++ b/helper/args.go @@ -1,38 +1,17 @@ package helper import ( - "errors" + "bytes" "io" - "strings" + "syscall" ) -var ( - ErrContainsNull = errors.New("argument contains null character") -) - -type argsWt []string - -// checks whether any element contains the null character -// must be called before args use and args must not be modified after call -func (a argsWt) check() error { - for _, arg := range a { - for _, b := range arg { - if b == '\x00' { - return ErrContainsNull - } - } - } - - return nil -} +type argsWt [][]byte func (a argsWt) WriteTo(w io.Writer) (int64, error) { - // assuming already checked - nt := 0 - // write null terminated arguments for _, arg := range a { - n, err := w.Write([]byte(arg + "\x00")) + n, err := w.Write(arg) nt += n if err != nil { @@ -44,18 +23,32 @@ func (a argsWt) WriteTo(w io.Writer) (int64, error) { } func (a argsWt) String() string { - return strings.Join(a, " ") + return string( + bytes.TrimSuffix( + bytes.ReplaceAll( + bytes.Join(a, nil), + []byte{0}, []byte{' '}, + ), + []byte{' '}, + ), + ) } -// NewCheckedArgs returns a checked argument writer for args. -// Callers must not retain any references to args. -func NewCheckedArgs(args []string) (io.WriterTo, error) { - a := argsWt(args) - return a, a.check() +// NewCheckedArgs returns a checked null-terminated argument writer for a copy of args. +func NewCheckedArgs(args []string) (wt io.WriterTo, err error) { + a := make(argsWt, len(args)) + for i, arg := range args { + a[i], err = syscall.ByteSliceFromString(arg) + if err != nil { + return + } + } + wt = a + return } -// MustNewCheckedArgs returns a checked argument writer for args and panics if check fails. -// Callers must not retain any references to args. +// MustNewCheckedArgs returns a checked null-terminated argument writer for a copy of args. +// If s contains a NUL byte this function panics instead of returning an error. func MustNewCheckedArgs(args []string) io.WriterTo { a, err := NewCheckedArgs(args) if err != nil { diff --git a/helper/args_test.go b/helper/args_test.go index 9014041..7eca66b 100644 --- a/helper/args_test.go +++ b/helper/args_test.go @@ -4,34 +4,33 @@ import ( "errors" "fmt" "strings" + "syscall" "testing" "git.gensokyo.uk/security/fortify/helper" ) -func Test_argsFd_String(t *testing.T) { +func TestArgsString(t *testing.T) { wantString := strings.Join(wantArgs, " ") if got := argsWt.(fmt.Stringer).String(); got != wantString { - t.Errorf("String(): got %v; want %v", + t.Errorf("String: %q, want %q", got, wantString) } } func TestNewCheckedArgs(t *testing.T) { args := []string{"\x00"} - if _, err := helper.NewCheckedArgs(args); !errors.Is(err, helper.ErrContainsNull) { - t.Errorf("NewCheckedArgs(%q) error = %v, wantErr %v", - args, - err, helper.ErrContainsNull) + if _, err := helper.NewCheckedArgs(args); !errors.Is(err, syscall.EINVAL) { + t.Errorf("NewCheckedArgs: error = %v, wantErr %v", + err, syscall.EINVAL) } t.Run("must panic", func(t *testing.T) { badPayload := []string{"\x00"} defer func() { - wantPanic := "argument contains null character" + wantPanic := "invalid argument" if r := recover(); r != wantPanic { - t.Errorf("MustNewCheckedArgs(%q) panic = %v, wantPanic %v", - badPayload, + t.Errorf("MustNewCheckedArgs: panic = %v, wantPanic %v", r, wantPanic) } }()