helper/args: copy args on wt creation
Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
parent
5c82f1ed3e
commit
78aaae7ee0
@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ func TestProxy_Seal(t *testing.T) {
|
|||||||
for id, tc := range testCasePairs() {
|
for id, tc := range testCasePairs() {
|
||||||
t.Run("create seal for "+id, func(t *testing.T) {
|
t.Run("create seal for "+id, func(t *testing.T) {
|
||||||
p := dbus.New(tc[0].bus, tc[1].bus)
|
p := dbus.New(tc[0].bus, tc[1].bus)
|
||||||
if err := p.Seal(tc[0].c, tc[1].c); (errors.Is(err, helper.ErrContainsNull)) != tc[0].wantErr {
|
if err := p.Seal(tc[0].c, tc[1].c); (errors.Is(err, syscall.EINVAL)) != tc[0].wantErr {
|
||||||
t.Errorf("Seal(%p, %p) error = %v, wantErr %v",
|
t.Errorf("Seal(%p, %p) error = %v, wantErr %v",
|
||||||
tc[0].c, tc[1].c,
|
tc[0].c, tc[1].c,
|
||||||
err, tc[0].wantErr)
|
err, tc[0].wantErr)
|
||||||
|
@ -1,38 +1,17 @@
|
|||||||
package helper
|
package helper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type argsWt [][]byte
|
||||||
ErrContainsNull = errors.New("argument contains null character")
|
|
||||||
)
|
|
||||||
|
|
||||||
type argsWt []string
|
|
||||||
|
|
||||||
// checks whether any element contains the null character
|
|
||||||
// must be called before args use and args must not be modified after call
|
|
||||||
func (a argsWt) check() error {
|
|
||||||
for _, arg := range a {
|
|
||||||
for _, b := range arg {
|
|
||||||
if b == '\x00' {
|
|
||||||
return ErrContainsNull
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a argsWt) WriteTo(w io.Writer) (int64, error) {
|
func (a argsWt) WriteTo(w io.Writer) (int64, error) {
|
||||||
// assuming already checked
|
|
||||||
|
|
||||||
nt := 0
|
nt := 0
|
||||||
// write null terminated arguments
|
|
||||||
for _, arg := range a {
|
for _, arg := range a {
|
||||||
n, err := w.Write([]byte(arg + "\x00"))
|
n, err := w.Write(arg)
|
||||||
nt += n
|
nt += n
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -44,18 +23,32 @@ func (a argsWt) WriteTo(w io.Writer) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a argsWt) String() string {
|
func (a argsWt) String() string {
|
||||||
return strings.Join(a, " ")
|
return string(
|
||||||
|
bytes.TrimSuffix(
|
||||||
|
bytes.ReplaceAll(
|
||||||
|
bytes.Join(a, nil),
|
||||||
|
[]byte{0}, []byte{' '},
|
||||||
|
),
|
||||||
|
[]byte{' '},
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCheckedArgs returns a checked argument writer for args.
|
// NewCheckedArgs returns a checked null-terminated argument writer for a copy of args.
|
||||||
// Callers must not retain any references to args.
|
func NewCheckedArgs(args []string) (wt io.WriterTo, err error) {
|
||||||
func NewCheckedArgs(args []string) (io.WriterTo, error) {
|
a := make(argsWt, len(args))
|
||||||
a := argsWt(args)
|
for i, arg := range args {
|
||||||
return a, a.check()
|
a[i], err = syscall.ByteSliceFromString(arg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wt = a
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustNewCheckedArgs returns a checked argument writer for args and panics if check fails.
|
// MustNewCheckedArgs returns a checked null-terminated argument writer for a copy of args.
|
||||||
// Callers must not retain any references to args.
|
// If s contains a NUL byte this function panics instead of returning an error.
|
||||||
func MustNewCheckedArgs(args []string) io.WriterTo {
|
func MustNewCheckedArgs(args []string) io.WriterTo {
|
||||||
a, err := NewCheckedArgs(args)
|
a, err := NewCheckedArgs(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,34 +4,33 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.gensokyo.uk/security/fortify/helper"
|
"git.gensokyo.uk/security/fortify/helper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_argsFd_String(t *testing.T) {
|
func TestArgsString(t *testing.T) {
|
||||||
wantString := strings.Join(wantArgs, " ")
|
wantString := strings.Join(wantArgs, " ")
|
||||||
if got := argsWt.(fmt.Stringer).String(); got != wantString {
|
if got := argsWt.(fmt.Stringer).String(); got != wantString {
|
||||||
t.Errorf("String(): got %v; want %v",
|
t.Errorf("String: %q, want %q",
|
||||||
got, wantString)
|
got, wantString)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewCheckedArgs(t *testing.T) {
|
func TestNewCheckedArgs(t *testing.T) {
|
||||||
args := []string{"\x00"}
|
args := []string{"\x00"}
|
||||||
if _, err := helper.NewCheckedArgs(args); !errors.Is(err, helper.ErrContainsNull) {
|
if _, err := helper.NewCheckedArgs(args); !errors.Is(err, syscall.EINVAL) {
|
||||||
t.Errorf("NewCheckedArgs(%q) error = %v, wantErr %v",
|
t.Errorf("NewCheckedArgs: error = %v, wantErr %v",
|
||||||
args,
|
err, syscall.EINVAL)
|
||||||
err, helper.ErrContainsNull)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("must panic", func(t *testing.T) {
|
t.Run("must panic", func(t *testing.T) {
|
||||||
badPayload := []string{"\x00"}
|
badPayload := []string{"\x00"}
|
||||||
defer func() {
|
defer func() {
|
||||||
wantPanic := "argument contains null character"
|
wantPanic := "invalid argument"
|
||||||
if r := recover(); r != wantPanic {
|
if r := recover(); r != wantPanic {
|
||||||
t.Errorf("MustNewCheckedArgs(%q) panic = %v, wantPanic %v",
|
t.Errorf("MustNewCheckedArgs: panic = %v, wantPanic %v",
|
||||||
badPayload,
|
|
||||||
r, wantPanic)
|
r, wantPanic)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
Loading…
Reference in New Issue
Block a user