From 8accd3b2190d0527ec0d848c877b1a68abaac28c Mon Sep 17 00:00:00 2001 From: Ophestra Date: Wed, 22 Oct 2025 06:55:02 +0900 Subject: [PATCH] internal/app/shim: use syscall dispatcher This enables instrumented testing of the shim. Signed-off-by: Ophestra --- cmd/hakurei/command.go | 2 +- internal/app/app.go | 2 +- internal/app/app_test.go | 3 +- internal/app/dispatcher.go | 71 ++++++++++++++++-- internal/app/dispatcher_test.go | 54 ++++++++----- internal/app/hsu.go | 4 +- internal/app/shim.go | 129 +++++++++++++++++++++----------- 7 files changed, 190 insertions(+), 75 deletions(-) diff --git a/cmd/hakurei/command.go b/cmd/hakurei/command.go index 2a38760..15973cf 100644 --- a/cmd/hakurei/command.go +++ b/cmd/hakurei/command.go @@ -50,7 +50,7 @@ func buildCommand(ctx context.Context, msg message.Msg, early *earlyHardeningErr Flag(&flagVerbose, "v", command.BoolFlag(false), "Increase log verbosity"). Flag(&flagJSON, "json", command.BoolFlag(false), "Serialise output in JSON when applicable") - c.Command("shim", command.UsageInternal, func([]string) error { app.ShimMain(); return errSuccess }) + c.Command("shim", command.UsageInternal, func([]string) error { app.Shim(msg); return errSuccess }) c.Command("app", "Load and start container from configuration file", func(args []string) error { if len(args) < 1 { diff --git a/internal/app/app.go b/internal/app/app.go index 0d5f1a5..4d9d2ef 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -18,7 +18,7 @@ func Main(ctx context.Context, msg message.Msg, config *hst.Config) { log.Fatal(err) } - seal := outcome{syscallDispatcher: direct{}} + seal := outcome{syscallDispatcher: direct{msg}} if err := seal.finalise(ctx, msg, &id, config); err != nil { printMessageError("cannot seal app:", err) os.Exit(1) diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 3f3d298..cb3023b 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -574,10 +574,9 @@ func (s stubOsFileReadCloser) Stat() (fs.FileInfo, error) { panic("attempting to type stubNixOS struct { usernameErr map[string]error + panicDispatcher } -func (k *stubNixOS) new(func(k syscallDispatcher)) { panic("not implemented") } - func (k *stubNixOS) getpid() int { return 0xdeadbeef } func (k *stubNixOS) getuid() int { return 1971 } func (k *stubNixOS) getgid() int { return 100 } diff --git a/internal/app/dispatcher.go b/internal/app/dispatcher.go index d298d6b..b081ece 100644 --- a/internal/app/dispatcher.go +++ b/internal/app/dispatcher.go @@ -1,16 +1,18 @@ package app import ( + "context" "io" "io/fs" - "log" "os" "os/exec" + "os/signal" "os/user" "path/filepath" "hakurei.app/container" "hakurei.app/container/check" + "hakurei.app/container/seccomp" "hakurei.app/internal" "hakurei.app/message" "hakurei.app/system/dbus" @@ -28,7 +30,7 @@ type syscallDispatcher interface { // new starts a goroutine with a new instance of syscallDispatcher. // A syscallDispatcher must never be used in any goroutine other than the one owning it, // just synchronising access is not enough, as this is for test instrumentation. - new(f func(k syscallDispatcher)) + new(f func(k syscallDispatcher, msg message.Msg)) // getpid provides [os.Getpid]. getpid() int @@ -38,6 +40,8 @@ type syscallDispatcher interface { getgid() int // lookupEnv provides [os.LookupEnv]. lookupEnv(key string) (string, bool) + // pipe provides os.Pipe. + pipe() (r, w *os.File, err error) // stat provides [os.Stat]. stat(name string) (os.FileInfo, error) // open provides [os.Open]. @@ -46,6 +50,8 @@ type syscallDispatcher interface { readdir(name string) ([]os.DirEntry, error) // tempdir provides [os.TempDir]. tempdir() string + // exit provides [os.Exit]. + exit(code int) // evalSymlinks provides [filepath.EvalSymlinks]. evalSymlinks(path string) (string, error) @@ -56,10 +62,29 @@ type syscallDispatcher interface { // cmdOutput provides the Output method of [exec.Cmd]. cmdOutput(cmd *exec.Cmd) ([]byte, error) + // notifyContext provides [signal.NotifyContext]. + notifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) + + // prctl provides [container.Prctl]. + prctl(op, arg2, arg3 uintptr) error // overflowUid provides [container.OverflowUid]. overflowUid(msg message.Msg) int // overflowGid provides [container.OverflowGid]. overflowGid(msg message.Msg) int + // setDumpable provides [container.SetDumpable]. + setDumpable(dumpable uintptr) error + // receive provides [container.Receive]. + receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) + + // containerStart provides the Start method of [container.Container]. + containerStart(z *container.Container) error + // containerStart provides the Serve method of [container.Container]. + containerServe(z *container.Container) error + // containerStart provides the Wait method of [container.Container]. + containerWait(z *container.Container) error + + // seccompLoad provides [seccomp.Load]. + seccompLoad(rules []seccomp.NativeRule, flags seccomp.ExportFlag) error // mustHsuPath provides [internal.MustHsuPath]. mustHsuPath() *check.Absolute @@ -67,23 +92,32 @@ type syscallDispatcher interface { // dbusAddress provides [dbus.Address]. dbusAddress() (session, system string) + // setupContSignal provides setupContSignal. + setupContSignal(pid int) (io.ReadCloser, func(), error) + + // getMsg returns the [message.Msg] held by syscallDispatcher. + getMsg() message.Msg + // fatal provides [log.Fatal]. + fatal(v ...any) // fatalf provides [log.Fatalf]. fatalf(format string, v ...any) } // direct implements syscallDispatcher on the current kernel. -type direct struct{} +type direct struct{ msg message.Msg } -func (k direct) new(f func(k syscallDispatcher)) { go f(k) } +func (k direct) new(f func(k syscallDispatcher, msg message.Msg)) { go f(k, k.msg) } func (direct) getpid() int { return os.Getpid() } func (direct) getuid() int { return os.Getuid() } func (direct) getgid() int { return os.Getgid() } func (direct) lookupEnv(key string) (string, bool) { return os.LookupEnv(key) } +func (direct) pipe() (r, w *os.File, err error) { return os.Pipe() } func (direct) stat(name string) (os.FileInfo, error) { return os.Stat(name) } func (direct) open(name string) (osFile, error) { return os.Open(name) } func (direct) readdir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) } func (direct) tempdir() string { return os.TempDir() } +func (direct) exit(code int) { os.Exit(code) } func (direct) evalSymlinks(path string) (string, error) { return filepath.EvalSymlinks(path) } @@ -98,11 +132,32 @@ func (direct) lookupGroupId(name string) (gid string, err error) { func (direct) cmdOutput(cmd *exec.Cmd) ([]byte, error) { return cmd.Output() } -func (direct) overflowUid(msg message.Msg) int { return container.OverflowUid(msg) } -func (direct) overflowGid(msg message.Msg) int { return container.OverflowGid(msg) } +func (direct) notifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { + return signal.NotifyContext(parent, signals...) +} + +func (direct) prctl(op, arg2, arg3 uintptr) error { return container.Prctl(op, arg2, arg3) } +func (direct) overflowUid(msg message.Msg) int { return container.OverflowUid(msg) } +func (direct) overflowGid(msg message.Msg) int { return container.OverflowGid(msg) } +func (direct) setDumpable(dumpable uintptr) error { return container.SetDumpable(dumpable) } +func (direct) receive(key string, e any, fdp *uintptr) (func() error, error) { + return container.Receive(key, e, fdp) +} + +func (direct) containerStart(z *container.Container) error { return z.Start() } +func (direct) containerServe(z *container.Container) error { return z.Serve() } +func (direct) containerWait(z *container.Container) error { return z.Wait() } + +func (direct) seccompLoad(rules []seccomp.NativeRule, flags seccomp.ExportFlag) error { + return seccomp.Load(rules, flags) +} func (direct) mustHsuPath() *check.Absolute { return internal.MustHsuPath() } -func (k direct) dbusAddress() (session, system string) { return dbus.Address() } +func (direct) dbusAddress() (session, system string) { return dbus.Address() } -func (direct) fatalf(format string, v ...any) { log.Fatalf(format, v...) } +func (direct) setupContSignal(pid int) (io.ReadCloser, func(), error) { return setupContSignal(pid) } + +func (k direct) getMsg() message.Msg { return k.msg } +func (k direct) fatal(v ...any) { k.msg.GetLogger().Fatal(v...) } +func (k direct) fatalf(format string, v ...any) { k.msg.GetLogger().Fatalf(format, v...) } diff --git a/internal/app/dispatcher_test.go b/internal/app/dispatcher_test.go index 944d007..45d7f84 100644 --- a/internal/app/dispatcher_test.go +++ b/internal/app/dispatcher_test.go @@ -2,6 +2,7 @@ package app import ( "bytes" + "context" "io" "io/fs" "log" @@ -15,6 +16,7 @@ import ( "hakurei.app/container" "hakurei.app/container/check" + "hakurei.app/container/seccomp" "hakurei.app/container/stub" "hakurei.app/hst" "hakurei.app/internal/app/state" @@ -492,20 +494,38 @@ func (panicMsgContext) Value(any) any { panic("unreachable") } // This type is meant to be embedded in partial syscallDispatcher implementations. type panicDispatcher struct{} -func (panicDispatcher) new(func(k syscallDispatcher)) { panic("unreachable") } -func (panicDispatcher) getpid() int { panic("unreachable") } -func (panicDispatcher) getuid() int { panic("unreachable") } -func (panicDispatcher) getgid() int { panic("unreachable") } -func (panicDispatcher) lookupEnv(string) (string, bool) { panic("unreachable") } -func (panicDispatcher) stat(string) (os.FileInfo, error) { panic("unreachable") } -func (panicDispatcher) open(string) (osFile, error) { panic("unreachable") } -func (panicDispatcher) readdir(string) ([]os.DirEntry, error) { panic("unreachable") } -func (panicDispatcher) tempdir() string { panic("unreachable") } -func (panicDispatcher) evalSymlinks(string) (string, error) { panic("unreachable") } -func (panicDispatcher) lookupGroupId(string) (string, error) { panic("unreachable") } -func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error) { panic("unreachable") } -func (panicDispatcher) overflowUid(message.Msg) int { panic("unreachable") } -func (panicDispatcher) overflowGid(message.Msg) int { panic("unreachable") } -func (panicDispatcher) mustHsuPath() *check.Absolute { panic("unreachable") } -func (panicDispatcher) dbusAddress() (string, string) { panic("unreachable") } -func (panicDispatcher) fatalf(string, ...any) { panic("unreachable") } +func (panicDispatcher) new(func(k syscallDispatcher, msg message.Msg)) { panic("unreachable") } +func (panicDispatcher) getpid() int { panic("unreachable") } +func (panicDispatcher) getuid() int { panic("unreachable") } +func (panicDispatcher) getgid() int { panic("unreachable") } +func (panicDispatcher) lookupEnv(string) (string, bool) { panic("unreachable") } +func (panicDispatcher) pipe() (*os.File, *os.File, error) { panic("unreachable") } +func (panicDispatcher) stat(string) (os.FileInfo, error) { panic("unreachable") } +func (panicDispatcher) open(string) (osFile, error) { panic("unreachable") } +func (panicDispatcher) readdir(string) ([]os.DirEntry, error) { panic("unreachable") } +func (panicDispatcher) tempdir() string { panic("unreachable") } +func (panicDispatcher) exit(int) { panic("unreachable") } +func (panicDispatcher) evalSymlinks(string) (string, error) { panic("unreachable") } +func (panicDispatcher) prctl(uintptr, uintptr, uintptr) error { panic("unreachable") } +func (panicDispatcher) lookupGroupId(string) (string, error) { panic("unreachable") } +func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error) { panic("unreachable") } +func (panicDispatcher) overflowUid(message.Msg) int { panic("unreachable") } +func (panicDispatcher) overflowGid(message.Msg) int { panic("unreachable") } +func (panicDispatcher) setDumpable(uintptr) error { panic("unreachable") } +func (panicDispatcher) receive(string, any, *uintptr) (func() error, error) { panic("unreachable") } +func (panicDispatcher) containerStart(*container.Container) error { panic("unreachable") } +func (panicDispatcher) containerServe(*container.Container) error { panic("unreachable") } +func (panicDispatcher) containerWait(*container.Container) error { panic("unreachable") } +func (panicDispatcher) mustHsuPath() *check.Absolute { panic("unreachable") } +func (panicDispatcher) dbusAddress() (string, string) { panic("unreachable") } +func (panicDispatcher) setupContSignal(int) (io.ReadCloser, func(), error) { panic("unreachable") } +func (panicDispatcher) getMsg() message.Msg { panic("unreachable") } +func (panicDispatcher) fatal(...any) { panic("unreachable") } +func (panicDispatcher) fatalf(string, ...any) { panic("unreachable") } + +func (panicDispatcher) notifyContext(context.Context, ...os.Signal) (context.Context, context.CancelFunc) { + panic("unreachable") +} +func (panicDispatcher) seccompLoad([]seccomp.NativeRule, seccomp.ExportFlag) error { + panic("unreachable") +} diff --git a/internal/app/hsu.go b/internal/app/hsu.go index cae56c8..1286249 100644 --- a/internal/app/hsu.go +++ b/internal/app/hsu.go @@ -21,7 +21,9 @@ type Hsu struct { id int kOnce sync.Once - k syscallDispatcher + + // msg is not populated + k syscallDispatcher } var ErrHsuAccess = errors.New("current user is not in the hsurc file") diff --git a/internal/app/shim.go b/internal/app/shim.go index 9f85a20..11affdd 100644 --- a/internal/app/shim.go +++ b/internal/app/shim.go @@ -7,7 +7,6 @@ import ( "log" "os" "os/exec" - "os/signal" "runtime" "sync/atomic" "syscall" @@ -23,6 +22,20 @@ import ( //#include "shim-signal.h" import "C" +// setupContSignal sets up the SIGCONT signal handler for the cross-uid shim exit hack. +// The signal handler is implemented in C, signals can be processed by reading from the returned reader. +// The returned function must be called after all signal processing concludes. +func setupContSignal(pid int) (io.ReadCloser, func(), error) { + if r, w, err := os.Pipe(); err != nil { + return nil, nil, err + } else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(pid), C.int(w.Fd())); err != nil { + _, _ = r.Close(), w.Close() + return nil, nil, err + } else { + return r, func() { runtime.KeepAlive(w) }, nil + } +} + // shimEnv is the name of the environment variable storing decimal representation of // setup pipe fd for [container.Receive]. const shimEnv = "HAKUREI_SHIM" @@ -46,76 +59,102 @@ type shimParams struct { // valid checks shimParams to be safe for use. func (p *shimParams) valid() bool { return p != nil && p.PrivPID > 0 } -// ShimMain is the main function of the shim process and runs as the unconstrained target user. -func ShimMain() { - log.SetPrefix("shim: ") - log.SetFlags(0) - msg := message.NewMsg(log.Default()) +// shimName is the prefix used by log.std in the shim process. +const shimName = "shim" - if err := container.SetDumpable(container.SUID_DUMP_DISABLE); err != nil { - log.Fatalf("cannot set SUID_DUMP_DISABLE: %s", err) +// Shim is called by the main function of the shim process and runs as the unconstrained target user. +// Shim does not return. +func Shim(msg message.Msg) { + if msg == nil { + msg = message.NewMsg(log.Default()) + } + shimEntrypoint(direct{msg}) +} + +func shimEntrypoint(k syscallDispatcher) { + msg := k.getMsg() + if msg == nil { + panic("attempting to call shimEntrypoint with nil msg") + } else if logger := msg.GetLogger(); logger != nil { + logger.SetPrefix(shimName + ": ") + logger.SetFlags(0) + } + + if err := k.setDumpable(container.SUID_DUMP_DISABLE); err != nil { + k.fatalf("cannot set SUID_DUMP_DISABLE: %s", err) } var ( state outcomeState closeSetup func() error ) - if f, err := container.Receive(shimEnv, &state, nil); err != nil { + if f, err := k.receive(shimEnv, &state, nil); err != nil { if errors.Is(err, syscall.EBADF) { - log.Fatal("invalid config descriptor") + k.fatal("invalid config descriptor") } if errors.Is(err, container.ErrReceiveEnv) { - log.Fatal(shimEnv + " not set") + k.fatal(shimEnv + " not set") } - log.Fatalf("cannot receive shim setup params: %v", err) + k.fatalf("cannot receive shim setup params: %v", err) } else { msg.SwapVerbose(state.Shim.Verbose) closeSetup = f - if err = state.populateLocal(direct{}, msg); err != nil { + if err = state.populateLocal(k, msg); err != nil { if m, ok := message.GetMessage(err); ok { - log.Fatal(m) + k.fatal(m) } else { - log.Fatalf("cannot populate local state: %v", err) + k.fatalf("cannot populate local state: %v", err) } } } // the Go runtime does not expose siginfo_t so SIGCONT is handled in C to check si_pid var signalPipe io.ReadCloser - if r, w, err := os.Pipe(); err != nil { - log.Fatalf("cannot pipe: %v", err) - } else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(state.Shim.PrivPID), C.int(w.Fd())); err != nil { - log.Fatalf("cannot install SIGCONT handler: %v", err) + if r, wKeepAlive, err := k.setupContSignal(state.Shim.PrivPID); err != nil { + switch { + case errors.As(err, new(*os.SyscallError)): // returned by os.Pipe + k.fatal(err.Error()) + return + + case errors.As(err, new(syscall.Errno)): // returned by hakurei_shim_setup_cont_signal + k.fatalf("cannot install SIGCONT handler: %v", err) + return + + default: // unreachable + k.fatalf("cannot set up exit request: %v", err) + return + } + } else { - defer runtime.KeepAlive(w) + defer wKeepAlive() signalPipe = r } // pdeath_signal delivery is checked as if the dying process called kill(2), see kernel/exit.c - if _, _, errno := syscall.Syscall(syscall.SYS_PRCTL, syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGCONT), 0); errno != 0 { - log.Fatalf("cannot set parent-death signal: %v", errno) + if err := k.prctl(syscall.PR_SET_PDEATHSIG, uintptr(syscall.SIGCONT), 0); err != nil { + k.fatalf("cannot set parent-death signal: %v", err) } stateParams := state.newParams() for _, op := range state.Shim.Ops { if err := op.toContainer(stateParams); err != nil { if m, ok := message.GetMessage(err); ok { - log.Fatal(m) + k.fatal(m) } else { - log.Fatalf("cannot create container state: %v", err) + k.fatalf("cannot create container state: %v", err) } } } // shim exit outcomes var cancelContainer atomic.Pointer[context.CancelFunc] - go func() { + k.new(func(k syscallDispatcher, msg message.Msg) { buf := make([]byte, 1) for { if _, err := signalPipe.Read(buf); err != nil { - log.Fatalf("cannot read from signal pipe: %v", err) + k.fatalf("cannot read from signal pipe: %v", err) } switch buf[0] { @@ -128,37 +167,37 @@ func ShimMain() { // setup has not completed, terminate immediately msg.Resume() - os.Exit(hst.ExitRequest) + k.exit(hst.ExitRequest) return case 1: // got SIGCONT after adoption: monitor died before delivering signal msg.BeforeExit() - os.Exit(hst.ExitOrphan) + k.exit(hst.ExitOrphan) return case 2: // unreachable - log.Println("sa_sigaction got invalid siginfo") + msg.Verbose("sa_sigaction got invalid siginfo") case 3: // got SIGCONT from unexpected process: hopefully the terminal driver - log.Println("got SIGCONT from unexpected process") + msg.Verbose("got SIGCONT from unexpected process") default: // unreachable - log.Fatalf("got invalid message %d from signal handler", buf[0]) + k.fatalf("got invalid message %d from signal handler", buf[0]) } } - }() + }) if stateParams.params.Ops == nil { - log.Fatal("invalid container params") + k.fatal("invalid container params") } // close setup socket if err := closeSetup(); err != nil { - log.Printf("cannot close setup pipe: %v", err) + msg.Verbosef("cannot close setup pipe: %v", err) // not fatal } - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + ctx, stop := k.notifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) cancelContainer.Store(&stop) z := container.New(ctx, msg) z.Params = *stateParams.params @@ -167,30 +206,30 @@ func ShimMain() { // bounds and default enforced in finalise.go z.WaitDelay = state.Shim.WaitDelay - if err := z.Start(); err != nil { + if err := k.containerStart(z); err != nil { printMessageError("cannot start container:", err) - os.Exit(hst.ExitFailure) + k.exit(hst.ExitFailure) } - if err := z.Serve(); err != nil { + if err := k.containerServe(z); err != nil { printMessageError("cannot configure container:", err) } - if err := seccomp.Load( + if err := k.seccompLoad( seccomp.Preset(comp.PresetStrict, seccomp.AllowMultiarch), seccomp.AllowMultiarch, ); err != nil { - log.Fatalf("cannot load syscall filter: %v", err) + k.fatalf("cannot load syscall filter: %v", err) } - if err := z.Wait(); err != nil { + if err := k.containerWait(z); err != nil { var exitError *exec.ExitError if !errors.As(err, &exitError) { if errors.Is(err, context.Canceled) { - os.Exit(hst.ExitCancel) + k.exit(hst.ExitCancel) } - log.Printf("wait: %v", err) - os.Exit(127) + msg.Verbosef("cannot wait: %v", err) + k.exit(127) } - os.Exit(exitError.ExitCode()) + k.exit(exitError.ExitCode()) } }