app: expose single run method
All checks were successful
Tests / Go tests (push) Successful in 1m1s
Nix / NixOS tests (push) Successful in 3m20s

App is no longer just a simple [exec.Cmd] wrapper, so exposing these steps separately no longer makes sense and actually hinders proper error handling, cleanup and cancellation. This change removes the five-second wait when the shim dies before receiving the payload, and provides caller the ability to gracefully stop execution of the confined process.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
Ophestra 2025-01-15 23:39:51 +09:00
parent be4d8b6300
commit 124743ffd3
Signed by: cat
SSH Key Fingerprint: SHA256:gQ67O0enBZ7UdZypgtspB2FDM1g3GVw8nX0XSdcFw8Q
5 changed files with 195 additions and 161 deletions

View File

@ -1,14 +1,13 @@
package shim package shim
import ( import (
"context"
"encoding/gob" "encoding/gob"
"errors" "errors"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
shim0 "git.gensokyo.uk/security/fortify/cmd/fshim/ipc" shim0 "git.gensokyo.uk/security/fortify/cmd/fshim/ipc"
@ -17,8 +16,6 @@ import (
"git.gensokyo.uk/security/fortify/internal/proc" "git.gensokyo.uk/security/fortify/internal/proc"
) )
const shimSetupTimeout = 5 * time.Second
// used by the parent process // used by the parent process
type Shim struct { type Shim struct {
@ -34,6 +31,8 @@ type Shim struct {
killFallback chan error killFallback chan error
// shim setup payload // shim setup payload
payload *shim0.Payload payload *shim0.Payload
// monitor to shim encoder
encoder *gob.Encoder
} }
func New(uid uint32, aid string, supp []string, payload *shim0.Payload) *Shim { func New(uid uint32, aid string, supp []string, payload *shim0.Payload) *Shim {
@ -56,7 +55,7 @@ func (s *Shim) WaitFallback() chan error {
} }
func (s *Shim) Start() (*time.Time, error) { func (s *Shim) Start() (*time.Time, error) {
// start user switcher process and save time // prepare user switcher invocation
var fsu string var fsu string
if p, ok := internal.Check(internal.Fsu); !ok { if p, ok := internal.Check(internal.Fsu); !ok {
fmsg.Fatal("invalid fsu path, this copy of fshim is not compiled correctly") fmsg.Fatal("invalid fsu path, this copy of fshim is not compiled correctly")
@ -66,18 +65,19 @@ func (s *Shim) Start() (*time.Time, error) {
} }
s.cmd = exec.Command(fsu) s.cmd = exec.Command(fsu)
var encoder *gob.Encoder // pass shim setup pipe
if fd, e, err := proc.Setup(&s.cmd.ExtraFiles); err != nil { if fd, e, err := proc.Setup(&s.cmd.ExtraFiles); err != nil {
return nil, fmsg.WrapErrorSuffix(err, return nil, fmsg.WrapErrorSuffix(err,
"cannot create shim setup pipe:") "cannot create shim setup pipe:")
} else { } else {
encoder = e s.encoder = e
s.cmd.Env = []string{ s.cmd.Env = []string{
shim0.Env + "=" + strconv.Itoa(fd), shim0.Env + "=" + strconv.Itoa(fd),
"FORTIFY_APP_ID=" + s.aid, "FORTIFY_APP_ID=" + s.aid,
} }
} }
// format fsu supplementary groups
if len(s.supp) > 0 { if len(s.supp) > 0 {
fmsg.VPrintf("attaching supplementary group ids %s", s.supp) fmsg.VPrintf("attaching supplementary group ids %s", s.supp)
s.cmd.Env = append(s.cmd.Env, "FORTIFY_GROUPS="+strings.Join(s.supp, " ")) s.cmd.Env = append(s.cmd.Env, "FORTIFY_GROUPS="+strings.Join(s.supp, " "))
@ -92,13 +92,17 @@ func (s *Shim) Start() (*time.Time, error) {
} }
fmsg.VPrintln("starting shim via fsu:", s.cmd) fmsg.VPrintln("starting shim via fsu:", s.cmd)
fmsg.Suspend() // withhold messages to stderr // withhold messages to stderr
fmsg.Suspend()
if err := s.cmd.Start(); err != nil { if err := s.cmd.Start(); err != nil {
return nil, fmsg.WrapErrorSuffix(err, return nil, fmsg.WrapErrorSuffix(err,
"cannot start fsu:") "cannot start fsu:")
} }
startTime := time.Now().UTC() startTime := time.Now().UTC()
return &startTime, nil
}
func (s *Shim) Serve(ctx context.Context) error {
// kill shim if something goes wrong and an error is returned // kill shim if something goes wrong and an error is returned
s.killFallback = make(chan error, 1) s.killFallback = make(chan error, 1)
killShim := func() { killShim := func() {
@ -108,30 +112,31 @@ func (s *Shim) Start() (*time.Time, error) {
} }
defer func() { killShim() }() defer func() { killShim() }()
// take alternative exit path on signal encodeErr := make(chan error)
sig := make(chan os.Signal, 2) go func() { encodeErr <- s.encoder.Encode(s.payload) }()
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
go func() {
v := <-sig
fmsg.Printf("got %s after program start", v)
s.killFallback <- nil
signal.Ignore(syscall.SIGINT, syscall.SIGTERM)
}()
shimErr := make(chan error)
go func() { shimErr <- encoder.Encode(s.payload) }()
select { select {
case err := <-shimErr: // encode return indicates setup completion
case err := <-encodeErr:
if err != nil { if err != nil {
return &startTime, fmsg.WrapErrorSuffix(err, return fmsg.WrapErrorSuffix(err,
"cannot transmit shim config:") "cannot transmit shim config:")
} }
killShim = func() {} killShim = func() {}
case <-time.After(shimSetupTimeout): return nil
return &startTime, fmsg.WrapError(errors.New("timed out waiting for shim"),
"timed out waiting for shim")
}
return &startTime, nil // setup canceled before payload was accepted
case <-ctx.Done():
err := ctx.Err()
if errors.Is(err, context.Canceled) {
return fmsg.WrapError(errors.New("shim setup canceled"),
"shim setup canceled")
}
if errors.Is(err, context.DeadlineExceeded) {
return fmsg.WrapError(errors.New("deadline exceeded waiting for shim"),
"deadline exceeded waiting for shim")
}
// unreachable
return err
}
} }

View File

@ -1,6 +1,7 @@
package app package app
import ( import (
"context"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -12,17 +13,22 @@ import (
type App interface { type App interface {
// ID returns a copy of App's unique ID. // ID returns a copy of App's unique ID.
ID() fst.ID ID() fst.ID
// Start sets up the system and starts the App. // Run sets up the system and runs the App.
Start() error Run(ctx context.Context, rs *RunState) error
// Wait waits for App's process to exit and reverts system setup.
Wait() (int, error)
// WaitErr returns error returned by the underlying wait syscall.
WaitErr() error
Seal(config *fst.Config) error Seal(config *fst.Config) error
String() string String() string
} }
type RunState struct {
// Start is true if fsu is successfully started.
Start bool
// ExitCode is the value returned by fshim.
ExitCode int
// WaitErr is error returned by the underlying wait syscall.
WaitErr error
}
type app struct { type app struct {
// single-use config reference // single-use config reference
ct *appCt ct *appCt
@ -35,8 +41,6 @@ type app struct {
shim *shim.Shim shim *shim.Shim
// child process related information // child process related information
seal *appSeal seal *appSeal
// error returned waiting for process
waitErr error
lock sync.RWMutex lock sync.RWMutex
} }
@ -64,10 +68,6 @@ func (a *app) String() string {
return "(unsealed fortified app)" return "(unsealed fortified app)"
} }
func (a *app) WaitErr() error {
return a.waitErr
}
func New(os linux.System) (App, error) { func New(os linux.System) (App, error) {
a := new(app) a := new(app)
a.id = new(fst.ID) a.id = new(fst.ID)

View File

@ -1,11 +1,13 @@
package app package app
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
shim0 "git.gensokyo.uk/security/fortify/cmd/fshim/ipc" shim0 "git.gensokyo.uk/security/fortify/cmd/fshim/ipc"
"git.gensokyo.uk/security/fortify/cmd/fshim/ipc/shim" "git.gensokyo.uk/security/fortify/cmd/fshim/ipc/shim"
@ -15,12 +17,16 @@ import (
"git.gensokyo.uk/security/fortify/internal/system" "git.gensokyo.uk/security/fortify/internal/system"
) )
// Start selects a user switcher and starts shim. const shimSetupTimeout = 5 * time.Second
// Note that Wait must be called regardless of error returned by Start.
func (a *app) Start() error { func (a *app) Run(ctx context.Context, rs *RunState) error {
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock() defer a.lock.Unlock()
if rs == nil {
panic("attempted to pass nil state to run")
}
// resolve exec paths // resolve exec paths
shimExec := [2]string{helper.BubblewrapName} shimExec := [2]string{helper.BubblewrapName}
if len(a.seal.command) > 0 { if len(a.seal.command) > 0 {
@ -64,10 +70,30 @@ func (a *app) Start() error {
// export sync pipe from sys // export sync pipe from sys
a.seal.sys.bwrap.SetSync(a.seal.sys.Sync()) a.seal.sys.bwrap.SetSync(a.seal.sys.Sync())
// start shim via manager
waitErr := make(chan error, 1)
if startTime, err := a.shim.Start(); err != nil { if startTime, err := a.shim.Start(); err != nil {
return err return err
} else { } else {
// shim start and setup success, create process state // shim process created
rs.Start = true
shimSetupCtx, shimSetupCancel := context.WithDeadline(ctx, time.Now().Add(shimSetupTimeout))
defer shimSetupCancel()
// start waiting for shim
go func() {
waitErr <- a.shim.Unwrap().Wait()
// cancel shim setup in case shim died before receiving payload
shimSetupCancel()
}()
// send payload
if err = a.shim.Serve(shimSetupCtx); err != nil {
return err
}
// shim accepted setup payload, create process state
sd := state.State{ sd := state.State{
ID: *a.id, ID: *a.id,
PID: a.shim.Unwrap().Process.Pid, PID: a.shim.Unwrap().Process.Pid,
@ -81,110 +107,41 @@ func (a *app) Start() error {
err0.InnerErr = c.Save(&sd) err0.InnerErr = c.Save(&sd)
}) })
a.seal.sys.saveState = true a.seal.sys.saveState = true
return err0.equiv("cannot save process state:") if err = err0.equiv("cannot save process state:"); err != nil {
return err
} }
} }
// StateStoreError is returned for a failed state save
type StateStoreError struct {
// whether inner function was called
Inner bool
// error returned by state.Store Do method
DoErr error
// error returned by state.Backend Save method
InnerErr error
// any other errors needing to be tracked
Err error
}
func (e *StateStoreError) equiv(a ...any) error {
if e.Inner && e.DoErr == nil && e.InnerErr == nil && e.Err == nil {
return nil
} else {
return fmsg.WrapErrorSuffix(e, a...)
}
}
func (e *StateStoreError) Error() string {
if e.Inner && e.InnerErr != nil {
return e.InnerErr.Error()
}
if e.DoErr != nil {
return e.DoErr.Error()
}
if e.Err != nil {
return e.Err.Error()
}
return "(nil)"
}
func (e *StateStoreError) Unwrap() (errs []error) {
errs = make([]error, 0, 3)
if e.DoErr != nil {
errs = append(errs, e.DoErr)
}
if e.InnerErr != nil {
errs = append(errs, e.InnerErr)
}
if e.Err != nil {
errs = append(errs, e.Err)
}
return
}
type RevertCompoundError interface {
Error() string
Unwrap() []error
}
func (a *app) Wait() (int, error) {
a.lock.Lock()
defer a.lock.Unlock()
if a.shim == nil {
fmsg.VPrintln("shim not initialised, skipping cleanup")
return 1, nil
}
var r int
if cmd := a.shim.Unwrap(); cmd == nil {
// failure prior to process start
r = 255
} else {
wait := make(chan error, 1)
go func() { wait <- cmd.Wait() }()
select { select {
// wait for process and resolve exit code // wait for process and resolve exit code
case err := <-wait: case err := <-waitErr:
if err != nil { if err != nil {
var exitError *exec.ExitError var exitError *exec.ExitError
if !errors.As(err, &exitError) { if !errors.As(err, &exitError) {
// should be unreachable // should be unreachable
a.waitErr = err rs.WaitErr = err
} }
// store non-zero return code // store non-zero return code
r = exitError.ExitCode() rs.ExitCode = exitError.ExitCode()
} else { } else {
r = cmd.ProcessState.ExitCode() rs.ExitCode = a.shim.Unwrap().ProcessState.ExitCode()
}
if fmsg.Verbose() {
fmsg.VPrintf("process %d exited with exit code %d", a.shim.Unwrap().Process.Pid, rs.ExitCode)
} }
fmsg.VPrintf("process %d exited with exit code %d", cmd.Process.Pid, r)
// alternative exit path when kill was unsuccessful // this is reached when a fault makes an already running shim impossible to continue execution
// however a kill signal could not be delivered (should actually always happen like that since fsu)
// the effects of this is similar to the alternative exit path and ensures shim death
case err := <-a.shim.WaitFallback(): case err := <-a.shim.WaitFallback():
r = 255 rs.ExitCode = 255
if err != nil {
fmsg.Printf("cannot terminate shim on faulted setup: %v", err) fmsg.Printf("cannot terminate shim on faulted setup: %v", err)
} else {
// alternative exit path relying on shim behaviour on monitor process exit
case <-ctx.Done():
fmsg.VPrintln("alternative exit path selected") fmsg.VPrintln("alternative exit path selected")
} }
}
}
// child process exited, resume output // child process exited, resume output
fmsg.Resume() fmsg.Resume()
@ -262,5 +219,60 @@ func (a *app) Wait() (int, error) {
}) })
e.Err = a.seal.store.Close() e.Err = a.seal.store.Close()
return r, e.equiv("error returned during cleanup:", e) return e.equiv("error returned during cleanup:", e)
}
// StateStoreError is returned for a failed state save
type StateStoreError struct {
// whether inner function was called
Inner bool
// error returned by state.Store Do method
DoErr error
// error returned by state.Backend Save method
InnerErr error
// any other errors needing to be tracked
Err error
}
func (e *StateStoreError) equiv(a ...any) error {
if e.Inner && e.DoErr == nil && e.InnerErr == nil && e.Err == nil {
return nil
} else {
return fmsg.WrapErrorSuffix(e, a...)
}
}
func (e *StateStoreError) Error() string {
if e.Inner && e.InnerErr != nil {
return e.InnerErr.Error()
}
if e.DoErr != nil {
return e.DoErr.Error()
}
if e.Err != nil {
return e.Err.Error()
}
return "(nil)"
}
func (e *StateStoreError) Unwrap() (errs []error) {
errs = make([]error, 0, 3)
if e.DoErr != nil {
errs = append(errs, e.DoErr)
}
if e.InnerErr != nil {
errs = append(errs, e.InnerErr)
}
if e.Err != nil {
errs = append(errs, e.Err)
}
return
}
type RevertCompoundError interface {
Error() string
Unwrap() []error
} }

View File

@ -12,6 +12,7 @@ var (
ErrInvalid = errors.New("bad file descriptor") ErrInvalid = errors.New("bad file descriptor")
) )
// Setup appends the read end of a pipe for payload transmission and returns its fd.
func Setup(extraFiles *[]*os.File) (int, *gob.Encoder, error) { func Setup(extraFiles *[]*os.File) (int, *gob.Encoder, error) {
if r, w, err := os.Pipe(); err != nil { if r, w, err := os.Pipe(); err != nil {
return -1, nil, err return -1, nil, err
@ -22,6 +23,8 @@ func Setup(extraFiles *[]*os.File) (int, *gob.Encoder, error) {
} }
} }
// Receive retrieves payload pipe fd from the environment,
// receives its payload and returns the Close method of the pipe.
func Receive(key string, e any) (func() error, error) { func Receive(key string, e any) (func() error, error) {
var setup *os.File var setup *os.File

42
main.go
View File

@ -1,14 +1,17 @@
package main package main
import ( import (
"context"
_ "embed" _ "embed"
"flag" "flag"
"fmt" "fmt"
"os" "os"
"os/signal"
"os/user" "os/user"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"syscall"
"text/tabwriter" "text/tabwriter"
"git.gensokyo.uk/security/fortify/dbus" "git.gensokyo.uk/security/fortify/dbus"
@ -288,27 +291,38 @@ func main() {
} }
func runApp(config *fst.Config) { func runApp(config *fst.Config) {
a, err := app.New(sys) rs := new(app.RunState)
if err != nil { ctx, cancel := context.WithCancel(context.Background())
// handle signals for graceful shutdown
sig := make(chan os.Signal, 2)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
go func() {
v := <-sig
fmsg.Printf("got %s after program start", v)
cancel()
signal.Ignore(syscall.SIGINT, syscall.SIGTERM)
}()
if a, err := app.New(sys); err != nil {
fmsg.Fatalf("cannot create app: %s\n", err) fmsg.Fatalf("cannot create app: %s\n", err)
} else if err = a.Seal(config); err != nil { } else if err = a.Seal(config); err != nil {
logBaseError(err, "cannot seal app:") logBaseError(err, "cannot seal app:")
fmsg.Exit(1) fmsg.Exit(1)
} else if err = a.Start(); err != nil { } else if err = a.Run(ctx, rs); err != nil {
if !rs.Start {
logBaseError(err, "cannot start app:") logBaseError(err, "cannot start app:")
} } else {
var r int
// wait must be called regardless of result of start
if r, err = a.Wait(); err != nil {
if r < 1 {
r = 1
}
logWaitError(err) logWaitError(err)
} }
if err = a.WaitErr(); err != nil {
fmsg.Println("inner wait failed:", err)
} }
fmsg.Exit(r) if rs.WaitErr != nil {
fmsg.Println("inner wait failed:", rs.WaitErr)
}
if rs.ExitCode < 0 {
fmsg.VPrintf("got negative exit %v", rs.ExitCode)
fmsg.Exit(1)
}
fmsg.Exit(rs.ExitCode)
panic("unreachable") panic("unreachable")
} }