diff --git a/container/dispatcher.go b/container/dispatcher.go index 2482f0b..86837b9 100644 --- a/container/dispatcher.go +++ b/container/dispatcher.go @@ -8,6 +8,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "runtime" "syscall" "hakurei.app/container/seccomp" @@ -21,6 +22,9 @@ type osFile interface { // syscallDispatcher provides methods that make state-dependent system calls as part of their behaviour. type syscallDispatcher interface { + // lockOSThread provides [runtime.LockOSThread]. + lockOSThread() + // setPtracer provides [SetPtracer]. setPtracer(pid uintptr) error // setDumpable provides [SetDumpable]. @@ -136,6 +140,8 @@ type syscallDispatcher interface { // direct implements syscallDispatcher on the current kernel. type direct struct{} +func (direct) lockOSThread() { runtime.LockOSThread() } + func (direct) setPtracer(pid uintptr) error { return SetPtracer(pid) } func (direct) setDumpable(dumpable uintptr) error { return SetDumpable(dumpable) } func (direct) setNoNewPrivs() error { return SetNoNewPrivs() } diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index f01b6f3..1152e16 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -315,6 +315,8 @@ func checkArgReflect(k *kstub, arg string, got any, n int) bool { return true } +func (k *kstub) lockOSThread() { k.expect("lockOSThread") } + func (k *kstub) setPtracer(pid uintptr) error { return k.expect("setPtracer").error( checkArg(k, "pid", pid, 0)) diff --git a/container/init.go b/container/init.go index 6688a34..2139ccc 100644 --- a/container/init.go +++ b/container/init.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "path" - "runtime" "slices" "strconv" . "syscall" @@ -81,11 +80,11 @@ type initParams struct { } func Init(prepareLogger func(prefix string), setVerbose func(verbose bool)) { - initEntrypoint(prepareLogger, setVerbose, direct{}) + initEntrypoint(direct{}, prepareLogger, setVerbose) } -func initEntrypoint(prepareLogger func(prefix string), setVerbose func(verbose bool), k syscallDispatcher) { - runtime.LockOSThread() +func initEntrypoint(k syscallDispatcher, prepareLogger func(prefix string), setVerbose func(verbose bool)) { + k.lockOSThread() prepareLogger("init") if k.getpid() != 1 {