helper/args: copy args on wt creation
Some checks failed
Test / Create distribution (push) Successful in 25s
Test / Fpkg (push) Failing after 47s
Test / Fortify (push) Failing after 1m4s
Test / Data race detector (push) Failing after 1m3s
Test / Flake checks (push) Has been skipped

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
Ophestra 2025-03-27 17:30:40 +09:00
parent 5c82f1ed3e
commit 4ef4e13eef
Signed by: cat
SSH Key Fingerprint: SHA256:gQ67O0enBZ7UdZypgtspB2FDM1g3GVw8nX0XSdcFw8Q
2 changed files with 34 additions and 42 deletions

View File

@ -1,38 +1,17 @@
package helper package helper
import ( import (
"errors" "bytes"
"io" "io"
"strings" "syscall"
) )
var ( type argsWt [][]byte
ErrContainsNull = errors.New("argument contains null character")
)
type argsWt []string
// checks whether any element contains the null character
// must be called before args use and args must not be modified after call
func (a argsWt) check() error {
for _, arg := range a {
for _, b := range arg {
if b == '\x00' {
return ErrContainsNull
}
}
}
return nil
}
func (a argsWt) WriteTo(w io.Writer) (int64, error) { func (a argsWt) WriteTo(w io.Writer) (int64, error) {
// assuming already checked
nt := 0 nt := 0
// write null terminated arguments
for _, arg := range a { for _, arg := range a {
n, err := w.Write([]byte(arg + "\x00")) n, err := w.Write(arg)
nt += n nt += n
if err != nil { if err != nil {
@ -44,18 +23,32 @@ func (a argsWt) WriteTo(w io.Writer) (int64, error) {
} }
func (a argsWt) String() string { func (a argsWt) String() string {
return strings.Join(a, " ") return string(
bytes.TrimSuffix(
bytes.ReplaceAll(
bytes.Join(a, nil),
[]byte{0}, []byte{' '},
),
[]byte{' '},
),
)
} }
// NewCheckedArgs returns a checked argument writer for args. // NewCheckedArgs returns a checked null-terminated argument writer for a copy of args.
// Callers must not retain any references to args. func NewCheckedArgs(args []string) (wt io.WriterTo, err error) {
func NewCheckedArgs(args []string) (io.WriterTo, error) { a := make(argsWt, len(args))
a := argsWt(args) for i, arg := range args {
return a, a.check() a[i], err = syscall.ByteSliceFromString(arg)
if err != nil {
return
}
}
wt = a
return
} }
// MustNewCheckedArgs returns a checked argument writer for args and panics if check fails. // MustNewCheckedArgs returns a checked null-terminated argument writer for a copy of args.
// Callers must not retain any references to args. // If s contains a NUL byte this function panics instead of returning an error.
func MustNewCheckedArgs(args []string) io.WriterTo { func MustNewCheckedArgs(args []string) io.WriterTo {
a, err := NewCheckedArgs(args) a, err := NewCheckedArgs(args)
if err != nil { if err != nil {

View File

@ -4,34 +4,33 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"syscall"
"testing" "testing"
"git.gensokyo.uk/security/fortify/helper" "git.gensokyo.uk/security/fortify/helper"
) )
func Test_argsFd_String(t *testing.T) { func TestArgsString(t *testing.T) {
wantString := strings.Join(wantArgs, " ") wantString := strings.Join(wantArgs, " ")
if got := argsWt.(fmt.Stringer).String(); got != wantString { if got := argsWt.(fmt.Stringer).String(); got != wantString {
t.Errorf("String(): got %v; want %v", t.Errorf("String: %q, want %q",
got, wantString) got, wantString)
} }
} }
func TestNewCheckedArgs(t *testing.T) { func TestNewCheckedArgs(t *testing.T) {
args := []string{"\x00"} args := []string{"\x00"}
if _, err := helper.NewCheckedArgs(args); !errors.Is(err, helper.ErrContainsNull) { if _, err := helper.NewCheckedArgs(args); !errors.Is(err, syscall.EINVAL) {
t.Errorf("NewCheckedArgs(%q) error = %v, wantErr %v", t.Errorf("NewCheckedArgs: error = %v, wantErr %v",
args, err, syscall.EINVAL)
err, helper.ErrContainsNull)
} }
t.Run("must panic", func(t *testing.T) { t.Run("must panic", func(t *testing.T) {
badPayload := []string{"\x00"} badPayload := []string{"\x00"}
defer func() { defer func() {
wantPanic := "argument contains null character" wantPanic := "invalid argument"
if r := recover(); r != wantPanic { if r := recover(); r != wantPanic {
t.Errorf("MustNewCheckedArgs(%q) panic = %v, wantPanic %v", t.Errorf("MustNewCheckedArgs: panic = %v, wantPanic %v",
badPayload,
r, wantPanic) r, wantPanic)
} }
}() }()