diff --git a/dbus/run.go b/dbus/run.go index fc3943a..ae2d9e0 100644 --- a/dbus/run.go +++ b/dbus/run.go @@ -39,9 +39,14 @@ func (p *Proxy) Start(ctx context.Context, output io.Writer, sandbox bool) error c, cancel := context.WithCancelCause(ctx) if !sandbox { - h = helper.New(c, p.seal, p.name, argF) - // xdg-dbus-proxy does not need to inherit the environment - h.SetEnv(make([]string, 0)) + h = helper.NewDirect(c, p.seal, p.name, argF, func(cmd *exec.Cmd) { + if output != nil { + cmd.Stdout, cmd.Stderr = output, output + } + + // xdg-dbus-proxy does not need to inherit the environment + cmd.Env = make([]string, 0) + }, true) } else { // look up absolute path if name is just a file name toolPath := p.name @@ -111,14 +116,15 @@ func (p *Proxy) Start(ctx context.Context, output io.Writer, sandbox bool) error bc.Bind(k, k) } - h = helper.MustNewBwrap(c, bc, toolPath, true, p.seal, argF, nil, nil) + h = helper.MustNewBwrap(c, bc, toolPath, true, p.seal, argF, func(cmd *exec.Cmd) { + if output != nil { + cmd.Stdout, cmd.Stderr = output, output + } + }, nil, nil, true) p.bwrap = bc } - if output != nil { - h.SetStdout(output).SetStderr(output) - } - if err := h.Start(true); err != nil { + if err := h.Start(); err != nil { cancel(err) return err } diff --git a/helper/bwrap.go b/helper/bwrap.go index a1f6e42..7ced938 100644 --- a/helper/bwrap.go +++ b/helper/bwrap.go @@ -5,6 +5,7 @@ import ( "errors" "io" "os" + "os/exec" "slices" "strconv" "sync" @@ -24,14 +25,11 @@ type bubblewrap struct { // name of the command to run in bwrap name string - // whether to set process group id - setpgid bool - lock sync.RWMutex *helperCmd } -func (b *bubblewrap) Start(stat bool) error { +func (b *bubblewrap) Start() error { b.lock.Lock() defer b.lock.Unlock() @@ -41,7 +39,7 @@ func (b *bubblewrap) Start(stat bool) error { return errors.New("exec: already started") } - args := b.finalise(stat) + args := b.finalise() b.Cmd.Args = slices.Grow(b.Cmd.Args, 4+len(args)) b.Cmd.Args = append(b.Cmd.Args, "--args", strconv.Itoa(int(b.argsFd)), "--", b.name) b.Cmd.Args = append(b.Cmd.Args, args...) @@ -53,12 +51,17 @@ func (b *bubblewrap) Start(stat bool) error { // Function argF returns an array of arguments passed directly to the child process. func MustNewBwrap( ctx context.Context, - conf *bwrap.Config, name string, setpgid bool, - wt io.WriterTo, argF func(argsFD, statFD int) []string, + conf *bwrap.Config, + name string, + setpgid bool, + wt io.WriterTo, + argF func(argsFD, statFD int) []string, + cmdF func(cmd *exec.Cmd), extraFiles []*os.File, syncFd *os.File, + stat bool, ) Helper { - b, err := NewBwrap(ctx, conf, name, setpgid, wt, argF, extraFiles, syncFd) + b, err := NewBwrap(ctx, conf, name, setpgid, wt, argF, cmdF, extraFiles, syncFd, stat) if err != nil { panic(err.Error()) } else { @@ -71,19 +74,26 @@ func MustNewBwrap( // Function argF returns an array of arguments passed directly to the child process. func NewBwrap( ctx context.Context, - conf *bwrap.Config, name string, setpgid bool, - wt io.WriterTo, argF func(argsFd, statFd int) []string, + conf *bwrap.Config, + name string, + setpgid bool, + wt io.WriterTo, + argF func(argsFd, statFd int) []string, + cmdF func(cmd *exec.Cmd), extraFiles []*os.File, syncFd *os.File, + stat bool, ) (Helper, error) { b := new(bubblewrap) b.name = name - b.setpgid = setpgid - b.helperCmd = newHelperCmd(b, ctx, BubblewrapName, wt, argF, extraFiles) - if b.setpgid { + b.helperCmd = newHelperCmd(ctx, BubblewrapName, wt, argF, extraFiles, stat) + if setpgid { b.Cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} } + if cmdF != nil { + cmdF(b.helperCmd.Cmd) + } if v, err := NewCheckedArgs(conf.Args(syncFd, b.extraFiles, &b.files)); err != nil { return nil, err diff --git a/helper/bwrap_test.go b/helper/bwrap_test.go index a14a417..4b42356 100644 --- a/helper/bwrap_test.go +++ b/helper/bwrap_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "os" + "os/exec" "strings" "testing" "time" @@ -33,11 +34,11 @@ func TestBwrap(t *testing.T) { h := helper.MustNewBwrap( context.Background(), sc, "fortify", false, - argsWt, argF, - nil, nil, + argsWt, argF, nil, + nil, nil, false, ) - if err := h.Start(false); !errors.Is(err, os.ErrNotExist) { + if err := h.Start(); !errors.Is(err, os.ErrNotExist) { t.Errorf("Start: error = %v, wantErr %v", err, os.ErrNotExist) } @@ -47,8 +48,8 @@ func TestBwrap(t *testing.T) { if got := helper.MustNewBwrap( context.TODO(), sc, "fortify", false, - argsWt, argF, - nil, nil, + argsWt, argF, nil, + nil, nil, false, ); got == nil { t.Errorf("MustNewBwrap(%#v, %#v, %#v) got nil", sc, argsWt, "fortify") @@ -68,8 +69,8 @@ func TestBwrap(t *testing.T) { helper.MustNewBwrap( context.TODO(), &bwrap.Config{Hostname: "\x00"}, "fortify", false, - nil, argF, - nil, nil, + nil, argF, nil, + nil, nil, false, ) }) @@ -78,17 +79,14 @@ func TestBwrap(t *testing.T) { c, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + stdout, stderr := new(strings.Builder), new(strings.Builder) h := helper.MustNewBwrap( - c, - sc, "crash-test-dummy", false, - nil, argFChecked, - nil, nil, + c, sc, "crash-test-dummy", false, + nil, argFChecked, func(cmd *exec.Cmd) { cmd.Stdout, cmd.Stderr = stdout, stderr }, + nil, nil, false, ) - stdout, stderr := new(strings.Builder), new(strings.Builder) - h.SetStdout(stdout).SetStderr(stderr) - - if err := h.Start(false); err != nil { + if err := h.Start(); err != nil { t.Errorf("Start: error = %v", err) return @@ -101,11 +99,10 @@ func TestBwrap(t *testing.T) { }) t.Run("implementation compliance", func(t *testing.T) { - testHelper(t, func(ctx context.Context) helper.Helper { + testHelper(t, func(ctx context.Context, cmdF func(cmd *exec.Cmd), stat bool) helper.Helper { return helper.MustNewBwrap( - ctx, - sc, "crash-test-dummy", false, - argsWt, argF, nil, nil, + ctx, sc, "crash-test-dummy", false, + argsWt, argF, cmdF, nil, nil, stat, ) }) }) diff --git a/helper/direct.go b/helper/direct.go index 9e27913..93eae22 100644 --- a/helper/direct.go +++ b/helper/direct.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "os/exec" "sync" "git.gensokyo.uk/security/fortify/helper/proc" @@ -15,7 +16,7 @@ type direct struct { *helperCmd } -func (h *direct) Start(stat bool) error { +func (h *direct) Start() error { h.lock.Lock() defer h.lock.Unlock() @@ -25,15 +26,25 @@ func (h *direct) Start(stat bool) error { return errors.New("exec: already started") } - args := h.finalise(stat) + args := h.finalise() h.Cmd.Args = append(h.Cmd.Args, args...) return proc.Fulfill(h.ctx, &h.ExtraFiles, h.Cmd.Start, h.files, h.extraFiles) } -// New initialises a new direct Helper instance with wt as the null-terminated argument writer. +// NewDirect initialises a new direct Helper instance with wt as the null-terminated argument writer. // Function argF returns an array of arguments passed directly to the child process. -func New(ctx context.Context, wt io.WriterTo, name string, argF func(argsFd, statFd int) []string) Helper { +func NewDirect( + ctx context.Context, + wt io.WriterTo, + name string, + argF func(argsFd, statFd int) []string, + cmdF func(cmd *exec.Cmd), + stat bool, +) Helper { d := new(direct) - d.helperCmd = newHelperCmd(d, ctx, name, wt, argF, nil) + d.helperCmd = newHelperCmd(ctx, name, wt, argF, nil, stat) + if cmdF != nil { + cmdF(d.helperCmd.Cmd) + } return d } diff --git a/helper/direct_test.go b/helper/direct_test.go index e7bda79..7af0e3a 100644 --- a/helper/direct_test.go +++ b/helper/direct_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "os" + "os/exec" "testing" "git.gensokyo.uk/security/fortify/helper" @@ -11,16 +12,16 @@ import ( func TestDirect(t *testing.T) { t.Run("start non-existent helper path", func(t *testing.T) { - h := helper.New(context.Background(), argsWt, "/nonexistent", argF) + h := helper.NewDirect(context.Background(), argsWt, "/nonexistent", argF, nil, false) - if err := h.Start(false); !errors.Is(err, os.ErrNotExist) { + if err := h.Start(); !errors.Is(err, os.ErrNotExist) { t.Errorf("Start: error = %v, wantErr %v", err, os.ErrNotExist) } }) t.Run("valid new helper nil check", func(t *testing.T) { - if got := helper.New(context.TODO(), argsWt, "fortify", argF); got == nil { + if got := helper.NewDirect(context.TODO(), argsWt, "fortify", argF, nil, false); got == nil { t.Errorf("New(%q, %q) got nil", argsWt, "fortify") return @@ -28,6 +29,8 @@ func TestDirect(t *testing.T) { }) t.Run("implementation compliance", func(t *testing.T) { - testHelper(t, func(ctx context.Context) helper.Helper { return helper.New(ctx, argsWt, "crash-test-dummy", argF) }) + testHelper(t, func(ctx context.Context, cmdF func(cmd *exec.Cmd), stat bool) helper.Helper { + return helper.NewDirect(ctx, argsWt, "crash-test-dummy", argF, cmdF, stat) + }) }) } diff --git a/helper/helper.go b/helper/helper.go index f065d0e..c500c63 100644 --- a/helper/helper.go +++ b/helper/helper.go @@ -26,32 +26,22 @@ const ( ) type Helper interface { - // SetStdin sets the standard input of Helper. - SetStdin(r io.Reader) Helper - // SetStdout sets the standard output of Helper. - SetStdout(w io.Writer) Helper - // SetStderr sets the standard error of Helper. - SetStderr(w io.Writer) Helper - // SetEnv sets the environment of Helper. - SetEnv(env []string) Helper - // Start starts the helper process. - // A status pipe is passed to the helper if stat is true. - Start(stat bool) error - // Wait blocks until Helper exits and releases all its resources. + Start() error + // Wait blocks until Helper exits. Wait() error fmt.Stringer } func newHelperCmd( - h Helper, ctx context.Context, name string, + ctx context.Context, name string, wt io.WriterTo, argF func(argsFd, statFd int) []string, - extraFiles []*os.File, + extraFiles []*os.File, stat bool, ) (cmd *helperCmd) { cmd = new(helperCmd) - cmd.r = h cmd.ctx = ctx + cmd.hasStatFd = stat cmd.Cmd = commandContext(ctx, name) cmd.Cmd.Cancel = func() error { return cmd.Process.Signal(syscall.SIGTERM) } @@ -77,14 +67,13 @@ func newHelperCmd( // helperCmd wraps Cmd and implements methods shared across all Helper implementations. type helperCmd struct { - // ref to parent - r Helper - // returns an array of arguments passed directly // to the helper process argF func(statFd int) []string // whether argsFd is present hasArgsFd bool + // whether statFd is present + hasStatFd bool // closes statFd stat io.Closer @@ -97,13 +86,8 @@ type helperCmd struct { *exec.Cmd } -func (h *helperCmd) SetStdin(r io.Reader) Helper { h.Stdin = r; return h.r } -func (h *helperCmd) SetStdout(w io.Writer) Helper { h.Stdout = w; return h.r } -func (h *helperCmd) SetStderr(w io.Writer) Helper { h.Stderr = w; return h.r } -func (h *helperCmd) SetEnv(env []string) Helper { h.Env = env; return h.r } - // finalise sets up the underlying [exec.Cmd] object. -func (h *helperCmd) finalise(stat bool) (args []string) { +func (h *helperCmd) finalise() (args []string) { h.Env = slices.Grow(h.Env, 2) if h.hasArgsFd { h.Cmd.Env = append(h.Env, FortifyHelper+"=1") @@ -112,7 +96,7 @@ func (h *helperCmd) finalise(stat bool) (args []string) { } statFd := -1 - if stat { + if h.hasStatFd { f := proc.NewStat(&h.stat) statFd = int(proc.InitFile(f, h.extraFiles)) h.files = append(h.files, f) diff --git a/helper/helper_test.go b/helper/helper_test.go index 363cc04..3274fe5 100644 --- a/helper/helper_test.go +++ b/helper/helper_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os/exec" "strconv" "strings" "testing" @@ -46,15 +47,13 @@ func argFChecked(argsFd, statFd int) (args []string) { } // this function tests an implementation of the helper.Helper interface -func testHelper(t *testing.T, createHelper func(ctx context.Context) helper.Helper) { +func testHelper(t *testing.T, createHelper func(ctx context.Context, cmdF func(cmd *exec.Cmd), stat bool) helper.Helper) { helper.InternalReplaceExecCommand(t) t.Run("start helper with status channel and wait", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - h := createHelper(ctx) - stdout, stderr := new(strings.Builder), new(strings.Builder) - h.SetStdout(stdout).SetStderr(stderr) + h := createHelper(ctx, func(cmd *exec.Cmd) { cmd.Stdout, cmd.Stderr = stdout, stderr }, true) t.Run("wait not yet started helper", func(t *testing.T) { defer func() { @@ -67,7 +66,7 @@ func testHelper(t *testing.T, createHelper func(ctx context.Context) helper.Help }) t.Log("starting helper stub") - if err := h.Start(true); err != nil { + if err := h.Start(); err != nil { t.Errorf("Start: error = %v", err) cancel() return @@ -77,7 +76,7 @@ func testHelper(t *testing.T, createHelper func(ctx context.Context) helper.Help t.Run("start already started helper", func(t *testing.T) { wantErr := "exec: already started" - if err := h.Start(true); err != nil && err.Error() != wantErr { + if err := h.Start(); err != nil && err.Error() != wantErr { t.Errorf("Start: error = %v, wantErr %v", err, wantErr) return @@ -108,12 +107,10 @@ func testHelper(t *testing.T, createHelper func(ctx context.Context) helper.Help t.Run("start helper and wait", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - h := createHelper(ctx) - stdout, stderr := new(strings.Builder), new(strings.Builder) - h.SetStdout(stdout).SetStderr(stderr) + h := createHelper(ctx, func(cmd *exec.Cmd) { cmd.Stdout, cmd.Stderr = stdout, stderr }, false) - if err := h.Start(false); err != nil { + if err := h.Start(); err != nil { t.Errorf("Start() error = %v", err) return diff --git a/internal/app/shim/main.go b/internal/app/shim/main.go index 9855535..5690f8f 100644 --- a/internal/app/shim/main.go +++ b/internal/app/shim/main.go @@ -131,15 +131,15 @@ func Main() { ctx, conf, path.Join(fst.Tmp, "sbin/init0"), false, nil, func(int, int) []string { return make([]string, 0) }, + func(cmd *exec.Cmd) { cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr }, extraFiles, syncFd, + false, ); err != nil { log.Fatalf("malformed sandbox config: %v", err) } else { - b.SetStdin(os.Stdin).SetStdout(os.Stdout).SetStderr(os.Stderr) - // run and pass through exit code - if err = b.Start(false); err != nil { + if err = b.Start(); err != nil { log.Fatalf("cannot start target process: %v", err) } else if err = b.Wait(); err != nil { var exitError *exec.ExitError