From c61cdc505f05decd940cb2d861241039a49251e3 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Tue, 7 Apr 2026 16:31:46 +0900 Subject: [PATCH] internal/params: relocate from package container This does not make sense as part of the public API, so make it internal. Signed-off-by: Ophestra --- container/container_test.go | 5 +- container/dispatcher.go | 7 +- container/dispatcher_test.go | 13 ++- container/init.go | 71 ++++++++------- container/init_test.go | 3 +- container/params.go | 36 -------- container/params_test.go | 53 ----------- internal/outcome/dispatcher.go | 7 +- internal/outcome/dispatcher_test.go | 68 +++++++-------- internal/outcome/shim.go | 3 +- internal/outcome/shim_test.go | 3 +- internal/params/params.go | 42 +++++++++ internal/params/params_test.go | 131 ++++++++++++++++++++++++++++ 13 files changed, 269 insertions(+), 173 deletions(-) delete mode 100644 container/params.go delete mode 100644 container/params_test.go create mode 100644 internal/params/params.go create mode 100644 internal/params/params_test.go diff --git a/container/container_test.go b/container/container_test.go index 169efd8e..092a7a37 100644 --- a/container/container_test.go +++ b/container/container_test.go @@ -26,6 +26,7 @@ import ( "hakurei.app/fhs" "hakurei.app/hst" "hakurei.app/internal/info" + "hakurei.app/internal/params" "hakurei.app/ldd" "hakurei.app/message" "hakurei.app/vfs" @@ -84,9 +85,9 @@ func TestStartError(t *testing.T) { {"params env", &container.StartError{ Fatal: true, Step: "set up params stream", - Err: container.ErrReceiveEnv, + Err: params.ErrReceiveEnv, }, "set up params stream: environment variable not set", - container.ErrReceiveEnv, syscall.EBADF, + params.ErrReceiveEnv, syscall.EBADF, "cannot set up params stream: environment variable not set"}, {"params", &container.StartError{ diff --git a/container/dispatcher.go b/container/dispatcher.go index d3ff4f00..fd885014 100644 --- a/container/dispatcher.go +++ b/container/dispatcher.go @@ -16,6 +16,7 @@ import ( "hakurei.app/container/std" "hakurei.app/ext" "hakurei.app/internal/netlink" + "hakurei.app/internal/params" "hakurei.app/message" ) @@ -56,7 +57,7 @@ type syscallDispatcher interface { // isatty provides [Isatty]. isatty(fd int) bool // receive provides [Receive]. - receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) + receive(key string, e any, fdp *int) (closeFunc func() error, err error) // bindMount provides procPaths.bindMount. bindMount(msg message.Msg, source, target string, flags uintptr) error @@ -155,8 +156,8 @@ func (direct) capBoundingSetDrop(cap uintptr) error { return capBound func (direct) capAmbientClearAll() error { return capAmbientClearAll() } func (direct) capAmbientRaise(cap uintptr) error { return capAmbientRaise(cap) } func (direct) isatty(fd int) bool { return ext.Isatty(fd) } -func (direct) receive(key string, e any, fdp *uintptr) (func() error, error) { - return Receive(key, e, fdp) +func (direct) receive(key string, e any, fdp *int) (func() error, error) { + return params.Receive(key, e, fdp) } func (direct) bindMount(msg message.Msg, source, target string, flags uintptr) error { diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index b57c53b9..c8a7b492 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -390,7 +390,7 @@ func (k *kstub) isatty(fd int) bool { return expect.Ret.(bool) } -func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) { +func (k *kstub) receive(key string, e any, fdp *int) (closeFunc func() error, err error) { k.Helper() expect := k.Expects("receive") @@ -408,10 +408,17 @@ func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error } return nil } + + // avoid changing test cases + var fdpComp *uintptr + if fdp != nil { + fdpComp = new(uintptr(*fdp)) + } + err = expect.Error( stub.CheckArg(k.Stub, "key", key, 0), stub.CheckArgReflect(k.Stub, "e", e, 1), - stub.CheckArgReflect(k.Stub, "fdp", fdp, 2)) + stub.CheckArgReflect(k.Stub, "fdp", fdpComp, 2)) // 3 is unused so stores params if expect.Args[3] != nil { @@ -426,7 +433,7 @@ func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error if expect.Args[4] != nil { if v, ok := expect.Args[4].(uintptr); ok && v >= 3 { if fdp != nil { - *fdp = v + *fdp = int(v) } } } diff --git a/container/init.go b/container/init.go index 4d5bb8a9..f4796f1d 100644 --- a/container/init.go +++ b/container/init.go @@ -19,6 +19,7 @@ import ( "hakurei.app/container/seccomp" "hakurei.app/ext" "hakurei.app/fhs" + "hakurei.app/internal/params" "hakurei.app/message" ) @@ -147,35 +148,33 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { } var ( - params initParams - closeSetup func() error - setupFd uintptr - offsetSetup int + param initParams + closeSetup func() error + setupFd int ) - if f, err := k.receive(setupEnv, ¶ms, &setupFd); err != nil { + if f, err := k.receive(setupEnv, ¶m, &setupFd); err != nil { if errors.Is(err, EBADF) { k.fatal(msg, "invalid setup descriptor") } - if errors.Is(err, ErrReceiveEnv) { + if errors.Is(err, params.ErrReceiveEnv) { k.fatal(msg, setupEnv+" not set") } k.fatalf(msg, "cannot decode init setup payload: %v", err) } else { - if params.Ops == nil { + if param.Ops == nil { k.fatal(msg, "invalid setup parameters") } - if params.ParentPerm == 0 { - params.ParentPerm = 0755 + if param.ParentPerm == 0 { + param.ParentPerm = 0755 } - msg.SwapVerbose(params.Verbose) + msg.SwapVerbose(param.Verbose) msg.Verbose("received setup parameters") closeSetup = f - offsetSetup = int(setupFd + 1) } - if !params.HostNet { + if !param.HostNet { ctx, cancel := signal.NotifyContext(context.Background(), CancelSignal, os.Interrupt, SIGTERM, SIGQUIT) defer cancel() // for panics @@ -188,7 +187,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { k.fatalf(msg, "cannot set SUID_DUMP_USER: %v", err) } if err := k.writeFile(fhs.Proc+"self/uid_map", - append([]byte{}, strconv.Itoa(params.Uid)+" "+strconv.Itoa(params.HostUid)+" 1\n"...), + append([]byte{}, strconv.Itoa(param.Uid)+" "+strconv.Itoa(param.HostUid)+" 1\n"...), 0); err != nil { k.fatalf(msg, "%v", err) } @@ -198,7 +197,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { k.fatalf(msg, "%v", err) } if err := k.writeFile(fhs.Proc+"self/gid_map", - append([]byte{}, strconv.Itoa(params.Gid)+" "+strconv.Itoa(params.HostGid)+" 1\n"...), + append([]byte{}, strconv.Itoa(param.Gid)+" "+strconv.Itoa(param.HostGid)+" 1\n"...), 0); err != nil { k.fatalf(msg, "%v", err) } @@ -207,8 +206,8 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { } oldmask := k.umask(0) - if params.Hostname != "" { - if err := k.sethostname([]byte(params.Hostname)); err != nil { + if param.Hostname != "" { + if err := k.sethostname([]byte(param.Hostname)); err != nil { k.fatalf(msg, "cannot set hostname: %v", err) } } @@ -221,7 +220,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { } ctx, cancel := context.WithCancel(context.Background()) - state := &setupState{process: make(map[int]WaitStatus), Params: ¶ms.Params, Msg: msg, Context: ctx} + state := &setupState{process: make(map[int]WaitStatus), Params: ¶m.Params, Msg: msg, Context: ctx} defer cancel() /* early is called right before pivot_root into intermediate root; @@ -229,7 +228,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { difficult to obtain via library functions after pivot_root, and implementations are expected to avoid changing the state of the mount namespace */ - for i, op := range *params.Ops { + for i, op := range *param.Ops { if op == nil || !op.Valid() { k.fatalf(msg, "invalid op at index %d", i) } @@ -272,7 +271,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { step sets up the container filesystem, and implementations are expected to keep the host root and sysroot mount points intact but otherwise can do whatever they need to. Calling chdir is allowed but discouraged. */ - for i, op := range *params.Ops { + for i, op := range *param.Ops { // ops already checked during early setup if prefix, ok := op.prefix(); ok { msg.Verbosef("%s %s", prefix, op) @@ -328,7 +327,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { k.fatalf(msg, "cannot clear the ambient capability set: %v", err) } for i := uintptr(0); i <= lastcap; i++ { - if params.Privileged && i == CAP_SYS_ADMIN { + if param.Privileged && i == CAP_SYS_ADMIN { continue } if err := k.capBoundingSetDrop(i); err != nil { @@ -337,7 +336,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { } var keep [2]uint32 - if params.Privileged { + if param.Privileged { keep[capToIndex(CAP_SYS_ADMIN)] |= capToMask(CAP_SYS_ADMIN) if err := k.capAmbientRaise(CAP_SYS_ADMIN); err != nil { @@ -351,13 +350,13 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { k.fatalf(msg, "cannot capset: %v", err) } - if !params.SeccompDisable { - rules := params.SeccompRules + if !param.SeccompDisable { + rules := param.SeccompRules if len(rules) == 0 { // non-empty rules slice always overrides presets - msg.Verbosef("resolving presets %#x", params.SeccompPresets) - rules = seccomp.Preset(params.SeccompPresets, params.SeccompFlags) + msg.Verbosef("resolving presets %#x", param.SeccompPresets) + rules = seccomp.Preset(param.SeccompPresets, param.SeccompFlags) } - if err := k.seccompLoad(rules, params.SeccompFlags); err != nil { + if err := k.seccompLoad(rules, param.SeccompFlags); err != nil { // this also indirectly asserts PR_SET_NO_NEW_PRIVS k.fatalf(msg, "cannot load syscall filter: %v", err) } @@ -366,10 +365,10 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { msg.Verbose("syscall filter not configured") } - extraFiles := make([]*os.File, params.Count) + extraFiles := make([]*os.File, param.Count) for i := range extraFiles { // setup fd is placed before all extra files - extraFiles[i] = k.newFile(uintptr(offsetSetup+i), "extra file "+strconv.Itoa(i)) + extraFiles[i] = k.newFile(uintptr(setupFd+1+i), "extra file "+strconv.Itoa(i)) } k.umask(oldmask) @@ -447,7 +446,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { // called right before startup of initial process, all state changes to the // current process is prohibited during late - for i, op := range *params.Ops { + for i, op := range *param.Ops { // ops already checked during early setup if err := op.late(state, k); err != nil { if m, ok := messageFromError(err); ok { @@ -468,14 +467,14 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { k.fatalf(msg, "cannot close setup pipe: %v", err) } - cmd := exec.Command(params.Path.String()) + cmd := exec.Command(param.Path.String()) cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr - cmd.Args = params.Args - cmd.Env = params.Env + cmd.Args = param.Args + cmd.Env = param.Env cmd.ExtraFiles = extraFiles - cmd.Dir = params.Dir.String() + cmd.Dir = param.Dir.String() - msg.Verbosef("starting initial process %s", params.Path) + msg.Verbosef("starting initial process %s", param.Path) if err := k.start(cmd); err != nil { k.fatalf(msg, "%v", err) } @@ -493,7 +492,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { for { select { case s := <-sig: - if s == CancelSignal && params.ForwardCancel && cmd.Process != nil { + if s == CancelSignal && param.ForwardCancel && cmd.Process != nil { msg.Verbose("forwarding context cancellation") if err := k.signal(cmd, os.Interrupt); err != nil && !errors.Is(err, os.ErrProcessDone) { k.printf(msg, "cannot forward cancellation: %v", err) @@ -525,7 +524,7 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { cancel() // start timeout early - go func() { time.Sleep(params.AdoptWaitDelay); close(timeout) }() + go func() { time.Sleep(param.AdoptWaitDelay); close(timeout) }() // close initial process files; this also keeps them alive for _, f := range extraFiles { diff --git a/container/init_test.go b/container/init_test.go index 02531e99..a4eb1182 100644 --- a/container/init_test.go +++ b/container/init_test.go @@ -10,6 +10,7 @@ import ( "hakurei.app/check" "hakurei.app/container/seccomp" "hakurei.app/container/std" + "hakurei.app/internal/params" "hakurei.app/internal/stub" ) @@ -40,7 +41,7 @@ func TestInitEntrypoint(t *testing.T) { call("lockOSThread", stub.ExpectArgs{}, nil, nil), call("getpid", stub.ExpectArgs{}, 1, nil), call("setPtracer", stub.ExpectArgs{uintptr(0)}, nil, nil), - call("receive", stub.ExpectArgs{"HAKUREI_SETUP", new(initParams), new(uintptr)}, nil, ErrReceiveEnv), + call("receive", stub.ExpectArgs{"HAKUREI_SETUP", new(initParams), new(uintptr)}, nil, params.ErrReceiveEnv), call("fatal", stub.ExpectArgs{[]any{"HAKUREI_SETUP not set"}}, nil, nil), }, }, nil}, diff --git a/container/params.go b/container/params.go deleted file mode 100644 index e6cd26f8..00000000 --- a/container/params.go +++ /dev/null @@ -1,36 +0,0 @@ -package container - -import ( - "encoding/gob" - "errors" - "os" - "strconv" - "syscall" -) - -var ( - ErrReceiveEnv = errors.New("environment variable not set") -) - -// Receive retrieves setup fd from the environment and receives params. -func Receive(key string, e any, fdp *uintptr) (func() error, error) { - var setup *os.File - - if s, ok := os.LookupEnv(key); !ok { - return nil, ErrReceiveEnv - } else { - if fd, err := strconv.Atoi(s); err != nil { - return nil, optionalErrorUnwrap(err) - } else { - setup = os.NewFile(uintptr(fd), "setup") - if setup == nil { - return nil, syscall.EDOM - } - if fdp != nil { - *fdp = setup.Fd() - } - } - } - - return setup.Close, gob.NewDecoder(setup).Decode(e) -} diff --git a/container/params_test.go b/container/params_test.go deleted file mode 100644 index 4d2f84f6..00000000 --- a/container/params_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package container_test - -import ( - "errors" - "os" - "strconv" - "syscall" - "testing" - - "hakurei.app/container" -) - -func TestSetupReceive(t *testing.T) { - t.Run("not set", func(t *testing.T) { - const key = "TEST_ENV_NOT_SET" - { - v, ok := os.LookupEnv(key) - t.Cleanup(func() { - if ok { - if err := os.Setenv(key, v); err != nil { - t.Fatalf("Setenv: error = %v", err) - } - } else { - if err := os.Unsetenv(key); err != nil { - t.Fatalf("Unsetenv: error = %v", err) - } - } - }) - } - - if _, err := container.Receive(key, nil, nil); !errors.Is(err, container.ErrReceiveEnv) { - t.Errorf("Receive: error = %v, want %v", err, container.ErrReceiveEnv) - } - }) - - t.Run("format", func(t *testing.T) { - const key = "TEST_ENV_FORMAT" - t.Setenv(key, "") - - if _, err := container.Receive(key, nil, nil); !errors.Is(err, strconv.ErrSyntax) { - t.Errorf("Receive: error = %v, want %v", err, strconv.ErrSyntax) - } - }) - - t.Run("range", func(t *testing.T) { - const key = "TEST_ENV_RANGE" - t.Setenv(key, "-1") - - if _, err := container.Receive(key, nil, nil); !errors.Is(err, syscall.EDOM) { - t.Errorf("Receive: error = %v, want %v", err, syscall.EDOM) - } - }) -} diff --git a/internal/outcome/dispatcher.go b/internal/outcome/dispatcher.go index a3d3dbbc..892e6baf 100644 --- a/internal/outcome/dispatcher.go +++ b/internal/outcome/dispatcher.go @@ -17,6 +17,7 @@ import ( "hakurei.app/ext" "hakurei.app/internal/dbus" "hakurei.app/internal/info" + "hakurei.app/internal/params" "hakurei.app/message" ) @@ -84,7 +85,7 @@ type syscallDispatcher interface { // setDumpable provides [container.SetDumpable]. setDumpable(dumpable uintptr) error // receive provides [container.Receive]. - receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) + receive(key string, e any, fdp *int) (closeFunc func() error, err error) // containerStart provides the Start method of [container.Container]. containerStart(z *container.Container) error @@ -154,8 +155,8 @@ func (direct) prctl(op, arg2, arg3 uintptr) error { return ext.Prctl(op, arg2, a func (direct) overflowUid(msg message.Msg) int { return container.OverflowUid(msg) } func (direct) overflowGid(msg message.Msg) int { return container.OverflowGid(msg) } func (direct) setDumpable(dumpable uintptr) error { return ext.SetDumpable(dumpable) } -func (direct) receive(key string, e any, fdp *uintptr) (func() error, error) { - return container.Receive(key, e, fdp) +func (direct) receive(key string, e any, fdp *int) (func() error, error) { + return params.Receive(key, e, fdp) } func (direct) containerStart(z *container.Container) error { return z.Start() } diff --git a/internal/outcome/dispatcher_test.go b/internal/outcome/dispatcher_test.go index a099df58..9ae2ba52 100644 --- a/internal/outcome/dispatcher_test.go +++ b/internal/outcome/dispatcher_test.go @@ -401,12 +401,12 @@ func (k *kstub) setDumpable(dumpable uintptr) error { stub.CheckArg(k.Stub, "dumpable", dumpable, 0)) } -func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) { +func (k *kstub) receive(key string, e any, fdp *int) (closeFunc func() error, err error) { k.Helper() expect := k.Expects("receive") reflect.ValueOf(e).Elem().Set(reflect.ValueOf(expect.Args[1])) if expect.Args[2] != nil { - *fdp = expect.Args[2].(uintptr) + *fdp = int(expect.Args[2].(uintptr)) } return func() error { return k.Expects("closeReceive").Err }, expect.Error( stub.CheckArg(k.Stub, "key", key, 0)) @@ -690,38 +690,38 @@ func (panicMsgContext) Value(any) any { panic("unreachable") } // This type is meant to be embedded in partial syscallDispatcher implementations. type panicDispatcher struct{} -func (panicDispatcher) new(func(k syscallDispatcher, msg message.Msg)) { panic("unreachable") } -func (panicDispatcher) getppid() int { panic("unreachable") } -func (panicDispatcher) getpid() int { panic("unreachable") } -func (panicDispatcher) getuid() int { panic("unreachable") } -func (panicDispatcher) getgid() int { panic("unreachable") } -func (panicDispatcher) lookupEnv(string) (string, bool) { panic("unreachable") } -func (panicDispatcher) pipe() (*os.File, *os.File, error) { panic("unreachable") } -func (panicDispatcher) stat(string) (os.FileInfo, error) { panic("unreachable") } -func (panicDispatcher) open(string) (osFile, error) { panic("unreachable") } -func (panicDispatcher) readdir(string) ([]os.DirEntry, error) { panic("unreachable") } -func (panicDispatcher) tempdir() string { panic("unreachable") } -func (panicDispatcher) mkdir(string, os.FileMode) error { panic("unreachable") } -func (panicDispatcher) removeAll(string) error { panic("unreachable") } -func (panicDispatcher) exit(int) { panic("unreachable") } -func (panicDispatcher) evalSymlinks(string) (string, error) { panic("unreachable") } -func (panicDispatcher) prctl(uintptr, uintptr, uintptr) error { panic("unreachable") } -func (panicDispatcher) lookupGroupId(string) (string, error) { panic("unreachable") } -func (panicDispatcher) lookPath(string) (string, error) { panic("unreachable") } -func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error) { panic("unreachable") } -func (panicDispatcher) overflowUid(message.Msg) int { panic("unreachable") } -func (panicDispatcher) overflowGid(message.Msg) int { panic("unreachable") } -func (panicDispatcher) setDumpable(uintptr) error { panic("unreachable") } -func (panicDispatcher) receive(string, any, *uintptr) (func() error, error) { panic("unreachable") } -func (panicDispatcher) containerStart(*container.Container) error { panic("unreachable") } -func (panicDispatcher) containerServe(*container.Container) error { panic("unreachable") } -func (panicDispatcher) containerWait(*container.Container) error { panic("unreachable") } -func (panicDispatcher) mustHsuPath() *check.Absolute { panic("unreachable") } -func (panicDispatcher) dbusAddress() (string, string) { panic("unreachable") } -func (panicDispatcher) setupContSignal(int) (io.ReadCloser, func(), error) { panic("unreachable") } -func (panicDispatcher) getMsg() message.Msg { panic("unreachable") } -func (panicDispatcher) fatal(...any) { panic("unreachable") } -func (panicDispatcher) fatalf(string, ...any) { panic("unreachable") } +func (panicDispatcher) new(func(k syscallDispatcher, msg message.Msg)) { panic("unreachable") } +func (panicDispatcher) getppid() int { panic("unreachable") } +func (panicDispatcher) getpid() int { panic("unreachable") } +func (panicDispatcher) getuid() int { panic("unreachable") } +func (panicDispatcher) getgid() int { panic("unreachable") } +func (panicDispatcher) lookupEnv(string) (string, bool) { panic("unreachable") } +func (panicDispatcher) pipe() (*os.File, *os.File, error) { panic("unreachable") } +func (panicDispatcher) stat(string) (os.FileInfo, error) { panic("unreachable") } +func (panicDispatcher) open(string) (osFile, error) { panic("unreachable") } +func (panicDispatcher) readdir(string) ([]os.DirEntry, error) { panic("unreachable") } +func (panicDispatcher) tempdir() string { panic("unreachable") } +func (panicDispatcher) mkdir(string, os.FileMode) error { panic("unreachable") } +func (panicDispatcher) removeAll(string) error { panic("unreachable") } +func (panicDispatcher) exit(int) { panic("unreachable") } +func (panicDispatcher) evalSymlinks(string) (string, error) { panic("unreachable") } +func (panicDispatcher) prctl(uintptr, uintptr, uintptr) error { panic("unreachable") } +func (panicDispatcher) lookupGroupId(string) (string, error) { panic("unreachable") } +func (panicDispatcher) lookPath(string) (string, error) { panic("unreachable") } +func (panicDispatcher) cmdOutput(*exec.Cmd) ([]byte, error) { panic("unreachable") } +func (panicDispatcher) overflowUid(message.Msg) int { panic("unreachable") } +func (panicDispatcher) overflowGid(message.Msg) int { panic("unreachable") } +func (panicDispatcher) setDumpable(uintptr) error { panic("unreachable") } +func (panicDispatcher) receive(string, any, *int) (func() error, error) { panic("unreachable") } +func (panicDispatcher) containerStart(*container.Container) error { panic("unreachable") } +func (panicDispatcher) containerServe(*container.Container) error { panic("unreachable") } +func (panicDispatcher) containerWait(*container.Container) error { panic("unreachable") } +func (panicDispatcher) mustHsuPath() *check.Absolute { panic("unreachable") } +func (panicDispatcher) dbusAddress() (string, string) { panic("unreachable") } +func (panicDispatcher) setupContSignal(int) (io.ReadCloser, func(), error) { panic("unreachable") } +func (panicDispatcher) getMsg() message.Msg { panic("unreachable") } +func (panicDispatcher) fatal(...any) { panic("unreachable") } +func (panicDispatcher) fatalf(string, ...any) { panic("unreachable") } func (panicDispatcher) notifyContext(context.Context, ...os.Signal) (context.Context, context.CancelFunc) { panic("unreachable") diff --git a/internal/outcome/shim.go b/internal/outcome/shim.go index d2706c8b..b2074928 100644 --- a/internal/outcome/shim.go +++ b/internal/outcome/shim.go @@ -20,6 +20,7 @@ import ( "hakurei.app/ext" "hakurei.app/fhs" "hakurei.app/hst" + "hakurei.app/internal/params" "hakurei.app/internal/pipewire" "hakurei.app/message" ) @@ -197,7 +198,7 @@ func shimEntrypoint(k syscallDispatcher) { if errors.Is(err, syscall.EBADF) { k.fatal("invalid config descriptor") } - if errors.Is(err, container.ErrReceiveEnv) { + if errors.Is(err, params.ErrReceiveEnv) { k.fatal(shimEnv + " not set") } diff --git a/internal/outcome/shim_test.go b/internal/outcome/shim_test.go index 08f037da..6d9a9f3d 100644 --- a/internal/outcome/shim_test.go +++ b/internal/outcome/shim_test.go @@ -16,6 +16,7 @@ import ( "hakurei.app/fhs" "hakurei.app/hst" "hakurei.app/internal/env" + "hakurei.app/internal/params" "hakurei.app/internal/stub" ) @@ -172,7 +173,7 @@ func TestShimEntrypoint(t *testing.T) { call("setDumpable", stub.ExpectArgs{uintptr(ext.SUID_DUMP_DISABLE)}, nil, nil), call("getppid", stub.ExpectArgs{}, 0xbad, nil), call("setupContSignal", stub.ExpectArgs{0xbad}, 0, nil), - call("receive", stub.ExpectArgs{"HAKUREI_SHIM", outcomeState{}, nil}, nil, container.ErrReceiveEnv), + call("receive", stub.ExpectArgs{"HAKUREI_SHIM", outcomeState{}, nil}, nil, params.ErrReceiveEnv), call("fatal", stub.ExpectArgs{[]any{"HAKUREI_SHIM not set"}}, nil, nil), // deferred diff --git a/internal/params/params.go b/internal/params/params.go new file mode 100644 index 00000000..6f1962af --- /dev/null +++ b/internal/params/params.go @@ -0,0 +1,42 @@ +// Package params provides helpers for receiving setup payload from parent. +package params + +import ( + "encoding/gob" + "errors" + "os" + "strconv" + "syscall" +) + +// ErrReceiveEnv is returned by [Receive] if setup fd is not present in environment. +var ErrReceiveEnv = errors.New("environment variable not set") + +// Receive retrieves setup fd from the environment and receives params. +// +// The file descriptor written to the value pointed to by fdp must not be passed +// to any system calls. It is made available for ordering file descriptor only. +func Receive(key string, v any, fdp *int) (func() error, error) { + var setup *os.File + + if s, ok := os.LookupEnv(key); !ok { + return nil, ErrReceiveEnv + } else { + if fd, err := strconv.Atoi(s); err != nil { + if _err := errors.Unwrap(err); _err != nil { + err = _err + } + return nil, err + } else { + setup = os.NewFile(uintptr(fd), "setup") + if setup == nil { + return nil, syscall.EDOM + } + if fdp != nil { + *fdp = fd + } + } + } + + return setup.Close, gob.NewDecoder(setup).Decode(v) +} diff --git a/internal/params/params_test.go b/internal/params/params_test.go new file mode 100644 index 00000000..86c165b8 --- /dev/null +++ b/internal/params/params_test.go @@ -0,0 +1,131 @@ +package params_test + +import ( + "encoding/gob" + "errors" + "os" + "slices" + "strconv" + "syscall" + "testing" + + "hakurei.app/internal/params" +) + +func TestSetupReceive(t *testing.T) { + t.Run("not set", func(t *testing.T) { + const key = "TEST_ENV_NOT_SET" + { + v, ok := os.LookupEnv(key) + t.Cleanup(func() { + if ok { + if err := os.Setenv(key, v); err != nil { + t.Fatalf("Setenv: error = %v", err) + } + } else { + if err := os.Unsetenv(key); err != nil { + t.Fatalf("Unsetenv: error = %v", err) + } + } + }) + } + + if _, err := params.Receive(key, nil, nil); !errors.Is(err, params.ErrReceiveEnv) { + t.Errorf("Receive: error = %v, want %v", err, params.ErrReceiveEnv) + } + }) + + t.Run("format", func(t *testing.T) { + const key = "TEST_ENV_FORMAT" + t.Setenv(key, "") + + if _, err := params.Receive(key, nil, nil); !errors.Is(err, strconv.ErrSyntax) { + t.Errorf("Receive: error = %v, want %v", err, strconv.ErrSyntax) + } + }) + + t.Run("range", func(t *testing.T) { + const key = "TEST_ENV_RANGE" + t.Setenv(key, "-1") + + if _, err := params.Receive(key, nil, nil); !errors.Is(err, syscall.EDOM) { + t.Errorf("Receive: error = %v, want %v", err, syscall.EDOM) + } + }) + + t.Run("setup receive", func(t *testing.T) { + check := func(t *testing.T, useNilFdp bool) { + const key = "TEST_SETUP_RECEIVE" + payload := []uint64{syscall.MS_MGC_VAL, syscall.MS_MGC_MSK, syscall.MS_ASYNC, syscall.MS_ACTIVE} + + encoderDone := make(chan error, 1) + extraFiles := make([]*os.File, 0, 1) + if r, w, err := os.Pipe(); err != nil { + t.Fatalf("Setup: error = %v", err) + } else { + t.Cleanup(func() { + if err = errors.Join(r.Close(), w.Close()); err != nil { + t.Fatal(err) + } + }) + + extraFiles = append(extraFiles, r) + if deadline, ok := t.Deadline(); ok { + if err = w.SetDeadline(deadline); err != nil { + t.Fatal(err) + } + } + go func() { encoderDone <- gob.NewEncoder(w).Encode(payload) }() + } + + if len(extraFiles) != 1 { + t.Fatalf("extraFiles: len = %v, want 1", len(extraFiles)) + } + + var dupFd int + if fd, err := syscall.Dup(int(extraFiles[0].Fd())); err != nil { + t.Fatalf("Dup: error = %v", err) + } else { + syscall.CloseOnExec(fd) + dupFd = fd + t.Setenv(key, strconv.Itoa(fd)) + } + + var ( + gotPayload []uint64 + fdp *int + ) + if !useNilFdp { + fdp = new(int) + } + var closeFile func() error + if f, err := params.Receive(key, &gotPayload, fdp); err != nil { + t.Fatalf("Receive: error = %v", err) + } else { + closeFile = f + + if !slices.Equal(payload, gotPayload) { + t.Errorf("Receive: %#v, want %#v", gotPayload, payload) + } + } + if !useNilFdp { + if *fdp != dupFd { + t.Errorf("Fd: %d, want %d", *fdp, dupFd) + } + } + + if err := <-encoderDone; err != nil { + t.Errorf("Encode: error = %v", err) + } + + if closeFile != nil { + if err := closeFile(); err != nil { + t.Errorf("Close: error = %v", err) + } + } + } + + t.Run("fp", func(t *testing.T) { check(t, false) }) + t.Run("nil", func(t *testing.T) { check(t, true) }) + }) +}