diff --git a/internal/app/dispatcher_test.go b/internal/app/dispatcher_test.go index 45d7f84..a994d0b 100644 --- a/internal/app/dispatcher_test.go +++ b/internal/app/dispatcher_test.go @@ -11,8 +11,10 @@ import ( "os/exec" "reflect" "slices" + "sync" "testing" "time" + "unsafe" "hakurei.app/container" "hakurei.app/container/check" @@ -314,6 +316,10 @@ type kstub struct { *stub.Stub[syscallDispatcher] } +func (k *kstub) new(f func(k syscallDispatcher, msg message.Msg)) { + k.New(func(k syscallDispatcher) { f(k, k.(*kstub)) }) +} + func (k *kstub) getpid() int { k.Helper(); return k.Expects("getpid").Ret.(int) } func (k *kstub) getuid() int { k.Helper(); return k.Expects("getuid").Ret.(int) } func (k *kstub) getgid() int { k.Helper(); return k.Expects("getgid").Ret.(int) } @@ -355,6 +361,61 @@ func (k *kstub) evalSymlinks(path string) (string, error) { stub.CheckArg(k.Stub, "path", path, 0)) } +func (k *kstub) prctl(op, arg2, arg3 uintptr) error { + k.Helper() + return k.Expects("prctl").Error( + stub.CheckArg(k.Stub, "op", op, 0), + stub.CheckArg(k.Stub, "arg2", arg2, 1), + stub.CheckArg(k.Stub, "arg3", arg3, 2)) +} + +func (k *kstub) setDumpable(dumpable uintptr) error { + k.Helper() + return k.Expects("setDumpable").Error( + stub.CheckArg(k.Stub, "dumpable", dumpable, 0)) +} + +func (k *kstub) receive(key string, e any, fdp *uintptr) (closeFunc func() error, err error) { + k.Helper() + expect := k.Expects("receive") + reflect.ValueOf(e).Elem().Set(reflect.ValueOf(expect.Args[1])) + if expect.Args[2] != nil { + *fdp = expect.Args[2].(uintptr) + } + return func() error { return k.Expects("closeReceive").Err }, expect.Error( + stub.CheckArg(k.Stub, "key", key, 0)) +} + +func (k *kstub) expectCheckContainer(expect *stub.Call, z *container.Container) error { + k.Helper() + err := expect.Error( + stub.CheckArgReflect(k.Stub, "params", &z.Params, 0)) + if err != nil { + k.Errorf("params:\n%s\n%s", mustMarshal(&z.Params), mustMarshal(expect.Args[0])) + } + return err +} + +func (k *kstub) containerStart(z *container.Container) error { + k.Helper() + return k.expectCheckContainer(k.Expects("containerStart"), z) +} +func (k *kstub) containerServe(z *container.Container) error { + k.Helper() + return k.expectCheckContainer(k.Expects("containerServe"), z) +} +func (k *kstub) containerWait(z *container.Container) error { + k.Helper() + return k.expectCheckContainer(k.Expects("containerWait"), z) +} + +func (k *kstub) seccompLoad(rules []seccomp.NativeRule, flags seccomp.ExportFlag) error { + k.Helper() + return k.Expects("seccompLoad").Error( + stub.CheckArgReflect(k.Stub, "rules", rules, 0), + stub.CheckArg(k.Stub, "flags", flags, 1)) +} + func (k *kstub) cmdOutput(cmd *exec.Cmd) ([]byte, error) { k.Helper() expect := k.Expects("cmdOutput") @@ -365,6 +426,16 @@ func (k *kstub) cmdOutput(cmd *exec.Cmd) ([]byte, error) { stub.CheckArg(k.Stub, "cmd.Dir", cmd.Dir, 3)) } +func (k *kstub) notifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) { + k.Helper() + if k.Expects("notifyContext").Error( + stub.CheckArgReflect(k.Stub, "parent", parent, 0), + stub.CheckArgReflect(k.Stub, "signals", signals, 1)) != nil { + k.FailNow() + } + return k.Context(), func() { k.Helper(); k.Expects("notifyContextStop") } +} + func (k *kstub) mustHsuPath() *check.Absolute { k.Helper() return k.Expects("mustHsuPath").Ret.(*check.Absolute) @@ -376,9 +447,52 @@ func (k *kstub) dbusAddress() (session, system string) { return ret[0], ret[1] } -func (k *kstub) GetLogger() *log.Logger { panic("unreachable") } +// stubTrackReader embeds kstub but switches the underlying [stub.Stub] index to sub on its first Read. +// The resulting kstub does not share any state with the instance passed to the instrumented goroutine. +// Therefore, any method making use of such must not be called. +type stubTrackReader struct { + sub int + subOnce sync.Once -func (k *kstub) IsVerbose() bool { k.Helper(); return k.Expects("isVerbose").Ret.(bool) } + *kstub +} + +func (r *stubTrackReader) Read(p []byte) (n int, err error) { + r.subOnce.Do(func() { + subVal := reflect.ValueOf(r.kstub.Stub).Elem().FieldByName("sub") + r.kstub = &kstub{panicDispatcher{}, reflect. + NewAt(subVal.Type(), unsafe.Pointer(subVal.UnsafeAddr())).Elem(). + Interface().([]*stub.Stub[syscallDispatcher])[r.sub]} + }) + + return r.kstub.Read(p) +} + +func (k *kstub) setupContSignal(pid int) (io.ReadCloser, func(), error) { + k.Helper() + expect := k.Expects("setupContSignal") + return &stubTrackReader{sub: expect.Ret.(int), kstub: k}, func() { k.Expects("wKeepAlive") }, expect.Error( + stub.CheckArg(k.Stub, "pid", pid, 0)) +} + +func (k *kstub) getMsg() message.Msg { k.Helper(); k.Expects("getMsg"); return k } + +func (k *kstub) Close() error { k.Helper(); return k.Expects("rcClose").Err } +func (k *kstub) Read(p []byte) (n int, err error) { + k.Helper() + expect := k.Expects("rcRead") + + // special case to terminate exit outcomes goroutine + // to proceed with further testing of the entrypoint + if expect.Ret == nil { + panic(stub.PanicExit) + } + + return copy(p, expect.Ret.([]byte)), expect.Err +} + +func (k *kstub) GetLogger() *log.Logger { k.Helper(); return k.Expects("getLogger").Ret.(*log.Logger) } +func (k *kstub) IsVerbose() bool { k.Helper(); return k.Expects("isVerbose").Ret.(bool) } func (k *kstub) SwapVerbose(verbose bool) bool { k.Helper() expect := k.Expects("swapVerbose") diff --git a/internal/app/shim_test.go b/internal/app/shim_test.go new file mode 100644 index 0000000..e207dfd --- /dev/null +++ b/internal/app/shim_test.go @@ -0,0 +1,155 @@ +package app + +import ( + "bytes" + "context" + "log" + "os" + "syscall" + "testing" + + "hakurei.app/container" + "hakurei.app/container/comp" + "hakurei.app/container/fhs" + "hakurei.app/container/seccomp" + "hakurei.app/container/stub" + "hakurei.app/hst" +) + +func TestShimEntrypoint(t *testing.T) { + t.Parallel() + shimPreset := seccomp.Preset(comp.PresetStrict, seccomp.AllowMultiarch) + templateParams := &container.Params{ + Dir: m("/data/data/org.chromium.Chromium"), + Env: []string{ + "DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/1000/bus", + "DBUS_SYSTEM_BUS_ADDRESS=unix:path=/var/run/dbus/system_bus_socket", + "GOOGLE_API_KEY=AIzaSyBHDrl33hwRp4rMQY0ziRbj8K9LPA6vUCY", + "GOOGLE_DEFAULT_CLIENT_ID=77185425430.apps.googleusercontent.com", + "GOOGLE_DEFAULT_CLIENT_SECRET=OTJgUOQcT7lO7GsGZq2G4IlT", + "HOME=/data/data/org.chromium.Chromium", + "PULSE_COOKIE=/.hakurei/pulse-cookie", + "PULSE_SERVER=unix:/run/user/1000/pulse/native", + "SHELL=/run/current-system/sw/bin/zsh", + "TERM=xterm-256color", + "USER=chronos", + "WAYLAND_DISPLAY=wayland-0", + "XDG_RUNTIME_DIR=/run/user/1000", + "XDG_SESSION_CLASS=user", + "XDG_SESSION_TYPE=wayland", + }, + + // spParamsOp + Hostname: "localhost", + RetainSession: true, + HostNet: true, + HostAbstract: true, + ForwardCancel: true, + Path: m("/run/current-system/sw/bin/chromium"), + Args: []string{ + "chromium", + "--ignore-gpu-blocklist", + "--disable-smooth-scrolling", + "--enable-features=UseOzonePlatform", + "--ozone-platform=wayland", + }, + SeccompFlags: seccomp.AllowMultiarch, + Uid: 1000, + Gid: 100, + + Ops: new(container.Ops). + // resolveRoot + Root(m("/var/lib/hakurei/base/org.debian"), comp.BindWritable). + // spParamsOp + Proc(fhs.AbsProc). + Tmpfs(hst.AbsPrivateTmp, 1<<12, 0755). + Bind(fhs.AbsDev, fhs.AbsDev, comp.BindWritable|comp.BindDevice). + Tmpfs(fhs.AbsDev.Append("shm"), 0, 01777). + + // spRuntimeOp + Tmpfs(fhs.AbsRunUser, 1<<12, 0755). + Bind(m("/tmp/hakurei.10/runtime/9999"), m("/run/user/1000"), comp.BindWritable). + + // spTmpdirOp + Bind(m("/tmp/hakurei.10/tmpdir/9999"), fhs.AbsTmp, comp.BindWritable). + + // spAccountOp + Place(m("/etc/passwd"), []byte("chronos:x:1000:100:Hakurei:/data/data/org.chromium.Chromium:/run/current-system/sw/bin/zsh\n")). + Place(m("/etc/group"), []byte("hakurei:x:100:\n")). + + // spWaylandOp + Bind(m("/tmp/hakurei.10/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/wayland"), m("/run/user/1000/wayland-0"), 0). + + // spPulseOp + Bind(m("/run/user/1000/hakurei/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/pulse"), m("/run/user/1000/pulse/native"), 0). + Place(m("/.hakurei/pulse-cookie"), bytes.Repeat([]byte{0}, pulseCookieSizeMax)). + + // spDBusOp + Bind(m("/tmp/hakurei.10/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/bus"), m("/run/user/1000/bus"), 0). + Bind(m("/tmp/hakurei.10/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa/system_bus_socket"), m("/var/run/dbus/system_bus_socket"), 0). + + // spFilesystemOp + Etc(fhs.AbsEtc, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"). + Tmpfs(fhs.AbsTmp, 0, 0755). + Overlay(m("/nix/store"), + fhs.AbsVarLib.Append("hakurei/nix/u0/org.chromium.Chromium/rw-store/upper"), + fhs.AbsVarLib.Append("hakurei/nix/u0/org.chromium.Chromium/rw-store/work"), + fhs.AbsVarLib.Append("hakurei/base/org.nixos/ro-store")). + Link(m("/run/current-system"), "/run/current-system", true). + Link(m("/run/opengl-driver"), "/run/opengl-driver", true). + Bind(fhs.AbsVarLib.Append("hakurei/u0/org.chromium.Chromium"), + m("/data/data/org.chromium.Chromium"), + comp.BindWritable|comp.BindEnsure). + Bind(fhs.AbsDev.Append("dri"), fhs.AbsDev.Append("dri"), + comp.BindOptional|comp.BindWritable|comp.BindDevice). + Remount(fhs.AbsRoot, syscall.MS_RDONLY), + } + + checkSimple(t, "shimEntrypoint", []simpleTestCase{ + {"success", func(k *kstub) error { shimEntrypoint(k); return nil }, stub.Expect{Calls: []stub.Call{ + call("getMsg", stub.ExpectArgs{}, nil, nil), + call("getLogger", stub.ExpectArgs{}, (*log.Logger)(nil), nil), + call("setDumpable", stub.ExpectArgs{uintptr(container.SUID_DUMP_DISABLE)}, nil, nil), + call("receive", stub.ExpectArgs{"HAKUREI_SHIM", outcomeState{ + Shim: &shimParams{PrivPID: 0xbad, WaitDelay: 0xf, Verbose: true, Ops: []outcomeOp{ + &spParamsOp{"xterm-256color", true}, + &spRuntimeOp{sessionTypeWayland}, + spTmpdirOp{}, + spAccountOp{}, + &spWaylandOp{}, + &spPulseOp{(*[256]byte)(bytes.Repeat([]byte{0}, pulseCookieSizeMax)), pulseCookieSizeMax}, + &spDBusOp{true}, + &spFilesystemOp{}, + }}, + + ID: &checkExpectInstanceId, + Identity: hst.IdentityMax, + UserID: 10, + Container: hst.Template().Container, + Mapuid: 1000, + Mapgid: 100, + EnvPaths: &EnvPaths{TempDir: fhs.AbsTmp, RuntimePath: fhs.AbsRunUser.Append("1000")}, + }, nil}, nil, nil), + call("swapVerbose", stub.ExpectArgs{true}, false, nil), + call("verbosef", stub.ExpectArgs{"process share directory at %q, runtime directory at %q", []any{m("/tmp/hakurei.10"), m("/run/user/1000/hakurei")}}, nil, nil), + call("setupContSignal", stub.ExpectArgs{0xbad}, 0, nil), + call("prctl", stub.ExpectArgs{uintptr(syscall.PR_SET_PDEATHSIG), uintptr(syscall.SIGCONT), uintptr(0)}, nil, nil), + call("New", stub.ExpectArgs{}, nil, nil), + call("closeReceive", stub.ExpectArgs{}, nil, nil), + call("notifyContext", stub.ExpectArgs{context.Background(), []os.Signal{os.Interrupt, syscall.SIGTERM}}, nil, nil), + call("containerStart", stub.ExpectArgs{templateParams}, nil, nil), + call("containerServe", stub.ExpectArgs{templateParams}, nil, nil), + call("seccompLoad", stub.ExpectArgs{shimPreset, seccomp.AllowMultiarch}, nil, nil), + call("containerWait", stub.ExpectArgs{templateParams}, nil, nil), + + // deferred + call("wKeepAlive", stub.ExpectArgs{}, nil, nil), + }, Tracks: []stub.Expect{{Calls: []stub.Call{ + call("rcRead", stub.ExpectArgs{}, []byte{2}, nil), + call("verbose", stub.ExpectArgs{[]any{"sa_sigaction got invalid siginfo"}}, nil, nil), + call("rcRead", stub.ExpectArgs{}, []byte{3}, nil), + call("verbose", stub.ExpectArgs{[]any{"got SIGCONT from unexpected process"}}, nil, nil), + call("rcRead", stub.ExpectArgs{}, nil, nil), // stub terminates this goroutine + }}}}, nil}, + }) +}