diff --git a/internal/app/app_test.go b/internal/app/app_test.go index f40b092..9cab358 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -445,108 +445,75 @@ func TestApp(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - t.Run("finalise", func(t *testing.T) { - seal := outcome{syscallDispatcher: tc.k} - err := seal.finalise(t.Context(), msg, &tc.id, tc.config) - if err != nil { - if s, ok := container.GetErrorMessage(err); !ok { - t.Fatalf("outcome: error = %v", err) - } else { - t.Fatalf("outcome: %s", s) + gr, gw := io.Pipe() + + var gotSys *system.I + { + sPriv := outcomeState{ + ID: &tc.id, + Identity: tc.config.Identity, + UserID: (&Hsu{k: tc.k}).MustIDMsg(msg), + EnvPaths: copyPaths(tc.k), + Container: tc.config.Container, + } + + sPriv.populateEarly(tc.k, msg, tc.config) + if err := sPriv.populateLocal(tc.k, msg); err != nil { + t.Fatalf("populateLocal: error = %#v", err) + } + + gotSys = system.New(t.Context(), msg, sPriv.uid.unwrap()) + stateSys := outcomeStateSys{sys: gotSys, outcomeState: &sPriv} + for _, op := range sPriv.Shim.Ops { + if err := op.toSystem(&stateSys, tc.config); err != nil { + t.Fatalf("toSystem: error = %#v", err) } } - t.Run("sys", func(t *testing.T) { - if !seal.sys.Equal(tc.wantSys) { - t.Errorf("outcome: sys = %#v, want %#v", seal.sys, tc.wantSys) + go func() { + e := gob.NewEncoder(gw) + if err := errors.Join(e.Encode(&sPriv)); err != nil { + t.Errorf("Encode: error = %v", err) + panic("unexpected encode fault") } - }) + }() + } - t.Run("params", func(t *testing.T) { - if !reflect.DeepEqual(&seal.container, tc.wantParams) { - t.Errorf("outcome: container =\n%s\n, want\n%s", mustMarshal(&seal.container), mustMarshal(tc.wantParams)) + var gotParams container.Params + { + var sShim outcomeState + + d := gob.NewDecoder(gr) + if err := errors.Join(d.Decode(&sShim)); err != nil { + t.Fatalf("Decode: error = %v", err) + } + if err := sShim.populateLocal(tc.k, msg); err != nil { + t.Fatalf("populateLocal: error = %#v", err) + } + + stateParams := outcomeStateParams{params: &gotParams, outcomeState: &sShim} + if sShim.Container.Env == nil { + stateParams.env = make(map[string]string, envAllocSize) + } else { + stateParams.env = maps.Clone(sShim.Container.Env) + } + for _, op := range sShim.Shim.Ops { + if err := op.toContainer(&stateParams); err != nil { + t.Fatalf("toContainer: error = %#v", err) } - }) + } + } + + t.Run("sys", func(t *testing.T) { + if !gotSys.Equal(tc.wantSys) { + t.Errorf("toSystem: sys = %#v, want %#v", gotSys, tc.wantSys) + } }) - t.Run("ops", func(t *testing.T) { - // copied from shim - const envAllocSize = 1 << 6 - - gr, gw := io.Pipe() - - var gotSys *system.I - { - sPriv := outcomeState{ - ID: &tc.id, - Identity: tc.config.Identity, - UserID: (&Hsu{k: tc.k}).MustIDMsg(msg), - EnvPaths: copyPaths(tc.k), - Container: tc.config.Container, - } - - sPriv.populateEarly(tc.k, msg) - if err := sPriv.populateLocal(tc.k, msg); err != nil { - t.Fatalf("populateLocal: error = %#v", err) - } - - gotSys = system.New(t.Context(), msg, sPriv.uid.unwrap()) - opsPriv := fromConfig(tc.config) - stateSys := outcomeStateSys{sys: gotSys, outcomeState: &sPriv} - for _, op := range opsPriv { - if err := op.toSystem(&stateSys, tc.config); err != nil { - t.Fatalf("toSystem: error = %#v", err) - } - } - - go func() { - e := gob.NewEncoder(gw) - if err := errors.Join(e.Encode(&sPriv), e.Encode(&opsPriv)); err != nil { - t.Errorf("Encode: error = %v", err) - panic("unexpected encode fault") - } - }() + t.Run("params", func(t *testing.T) { + if !reflect.DeepEqual(&gotParams, tc.wantParams) { + t.Errorf("toContainer: params =\n%s\n, want\n%s", mustMarshal(&gotParams), mustMarshal(tc.wantParams)) } - - var gotParams container.Params - { - var ( - sShim outcomeState - opsShim []outcomeOp - ) - - d := gob.NewDecoder(gr) - if err := errors.Join(d.Decode(&sShim), d.Decode(&opsShim)); err != nil { - t.Fatalf("Decode: error = %v", err) - } - if err := sShim.populateLocal(tc.k, msg); err != nil { - t.Fatalf("populateLocal: error = %#v", err) - } - - stateParams := outcomeStateParams{params: &gotParams, outcomeState: &sShim} - if sShim.Container.Env == nil { - stateParams.env = make(map[string]string, envAllocSize) - } else { - stateParams.env = maps.Clone(sShim.Container.Env) - } - for _, op := range opsShim { - if err := op.toContainer(&stateParams); err != nil { - t.Fatalf("toContainer: error = %#v", err) - } - } - } - - t.Run("sys", func(t *testing.T) { - if !gotSys.Equal(tc.wantSys) { - t.Errorf("toSystem: sys = %#v, want %#v", gotSys, tc.wantSys) - } - }) - - t.Run("params", func(t *testing.T) { - if !reflect.DeepEqual(&gotParams, tc.wantParams) { - t.Errorf("toContainer: params =\n%s\n, want\n%s", mustMarshal(&gotParams), mustMarshal(tc.wantParams)) - } - }) }) }) } diff --git a/internal/app/finalise.go b/internal/app/finalise.go index 491a775..a5286bd 100644 --- a/internal/app/finalise.go +++ b/internal/app/finalise.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "maps" "os" "os/user" "sync/atomic" @@ -29,29 +28,26 @@ type outcome struct { // this is prepared ahead of time as config is clobbered during seal creation ct io.WriterTo + // Supplementary group ids. Populated during finalise. + supp []string + // Resolved priv side operating system interactions. Populated during finalise. sys *system.I - ctx context.Context - - container container.Params - - // Populated during outcome.finalise. - proc *finaliseProcess + // Transmitted to shim. Populated during finalise. + state *outcomeState // Whether the current process is in outcome.main. active atomic.Bool + ctx context.Context syscallDispatcher } func (k *outcome) finalise(ctx context.Context, msg container.Msg, id *state.ID, config *hst.Config) error { - // only used for a nil configured env map - const envAllocSize = 1 << 6 - if ctx == nil || id == nil { // unreachable panic("invalid call to finalise") } - if k.ctx != nil || k.sys != nil || k.proc != nil { + if k.ctx != nil || k.sys != nil || k.state != nil { // unreachable panic("attempting to finalise twice") } @@ -71,10 +67,8 @@ func (k *outcome) finalise(ctx context.Context, msg container.Msg, id *state.ID, k.ct = ct } - var kp finaliseProcess - // hsu expects numerical group ids - kp.supp = make([]string, len(config.Groups)) + supp := make([]string, len(config.Groups)) for i, name := range config.Groups { if gid, err := k.lookupGroupId(name); err != nil { var unknownGroupError user.UnknownGroupError @@ -84,7 +78,7 @@ func (k *outcome) finalise(ctx context.Context, msg container.Msg, id *state.ID, return &hst.AppError{Step: "look up group by name", Err: err} } } else { - kp.supp[i] = gid + supp[i] = gid } } @@ -96,38 +90,21 @@ func (k *outcome) finalise(ctx context.Context, msg container.Msg, id *state.ID, EnvPaths: copyPaths(k.syscallDispatcher), Container: config.Container, } - kp.waitDelay = s.populateEarly(k.syscallDispatcher, msg) - - // TODO(ophestra): duplicate in shim (params to shim) + s.populateEarly(k.syscallDispatcher, msg, config) if err := s.populateLocal(k.syscallDispatcher, msg); err != nil { return err } - kp.runDirPath, kp.identity, kp.id = s.sc.RunDirPath, s.identity, s.id + sys := system.New(k.ctx, msg, s.uid.unwrap()) - - ops := fromConfig(config) - stateSys := outcomeStateSys{sys: sys, outcomeState: &s} - for _, op := range ops { + for _, op := range s.Shim.Ops { if err := op.toSystem(&stateSys, config); err != nil { return err } } - // TODO(ophestra): move to shim - stateParams := outcomeStateParams{params: &k.container, outcomeState: &s} - if s.Container.Env == nil { - stateParams.env = make(map[string]string, envAllocSize) - } else { - stateParams.env = maps.Clone(s.Container.Env) - } - for _, op := range ops { - if err := op.toContainer(&stateParams); err != nil { - return err - } - } - k.sys = sys - k.proc = &kp + k.supp = supp + k.state = &s return nil } diff --git a/internal/app/outcome.go b/internal/app/outcome.go index ef7c780..f029a4b 100644 --- a/internal/app/outcome.go +++ b/internal/app/outcome.go @@ -1,8 +1,8 @@ package app import ( + "os" "strconv" - "time" "hakurei.app/container" "hakurei.app/container/check" @@ -26,6 +26,9 @@ func (s *stringPair[T]) String() string { return s.s } // outcomeState is copied to the shim process and available while applying outcomeOp. // This is transmitted from the priv side to the shim, so exported fields should be kept to a minimum. type outcomeState struct { + // Params only used by the shim process. Populated by populateEarly. + Shim *shimParams + // Generated and accounted for by the caller. ID *state.ID // Copied from ID. @@ -64,6 +67,7 @@ type outcomeState struct { // valid checks outcomeState to be safe for use with outcomeOp. func (s *outcomeState) valid() bool { return s != nil && + s.Shim.valid() && s.ID != nil && s.Container != nil && s.EnvPaths != nil @@ -71,14 +75,16 @@ func (s *outcomeState) valid() bool { // populateEarly populates exported fields via syscallDispatcher. // This must only be called from the priv side. -func (s *outcomeState) populateEarly(k syscallDispatcher, msg container.Msg) (waitDelay time.Duration) { +func (s *outcomeState) populateEarly(k syscallDispatcher, msg container.Msg, config *hst.Config) { + s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose(), Ops: fromConfig(config)} + // enforce bounds and default early if s.Container.WaitDelay <= 0 { - waitDelay = hst.WaitDelayDefault + s.Shim.WaitDelay = hst.WaitDelayDefault } else if s.Container.WaitDelay > hst.WaitDelayMax { - waitDelay = hst.WaitDelayMax + s.Shim.WaitDelay = hst.WaitDelayMax } else { - waitDelay = s.Container.WaitDelay + s.Shim.WaitDelay = s.Container.WaitDelay } if s.Container.MapRealUID { diff --git a/internal/app/outcome_test.go b/internal/app/outcome_test.go index a16f0a1..81fb6c7 100644 --- a/internal/app/outcome_test.go +++ b/internal/app/outcome_test.go @@ -16,10 +16,11 @@ func TestOutcomeStateValid(t *testing.T) { }{ {"nil", nil, false}, {"zero", new(outcomeState), false}, - {"id", &outcomeState{Container: new(hst.ContainerConfig), EnvPaths: new(EnvPaths)}, false}, - {"container", &outcomeState{ID: new(state.ID), EnvPaths: new(EnvPaths)}, false}, - {"envpaths", &outcomeState{ID: new(state.ID), Container: new(hst.ContainerConfig)}, false}, - {"valid", &outcomeState{ID: new(state.ID), Container: new(hst.ContainerConfig), EnvPaths: new(EnvPaths)}, true}, + {"shim", &outcomeState{Shim: &shimParams{PrivPID: -1, Ops: []outcomeOp{}}, Container: new(hst.ContainerConfig), EnvPaths: new(EnvPaths)}, false}, + {"id", &outcomeState{Shim: &shimParams{PrivPID: 1, Ops: []outcomeOp{}}, Container: new(hst.ContainerConfig), EnvPaths: new(EnvPaths)}, false}, + {"container", &outcomeState{Shim: &shimParams{PrivPID: 1, Ops: []outcomeOp{}}, ID: new(state.ID), EnvPaths: new(EnvPaths)}, false}, + {"envpaths", &outcomeState{Shim: &shimParams{PrivPID: 1, Ops: []outcomeOp{}}, ID: new(state.ID), Container: new(hst.ContainerConfig)}, false}, + {"valid", &outcomeState{Shim: &shimParams{PrivPID: 1, Ops: []outcomeOp{}}, ID: new(state.ID), Container: new(hst.ContainerConfig), EnvPaths: new(EnvPaths)}, true}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/app/process.go b/internal/app/process.go index 1e738e7..a868512 100644 --- a/internal/app/process.go +++ b/internal/app/process.go @@ -13,7 +13,6 @@ import ( "time" "hakurei.app/container" - "hakurei.app/container/check" "hakurei.app/container/fhs" "hakurei.app/hst" "hakurei.app/internal" @@ -43,7 +42,6 @@ type mainState struct { k *outcome container.Msg uintptr - *finaliseProcess } const ( @@ -83,7 +81,7 @@ func (ms mainState) beforeExit(isFault bool) { waitDone := make(chan struct{}) // this ties waitDone to ctx with the additional compensated timeout duration - go func() { <-ms.k.ctx.Done(); time.Sleep(ms.waitDelay + shimWaitTimeout); close(waitDone) }() + go func() { <-ms.k.ctx.Done(); time.Sleep(ms.k.state.Shim.WaitDelay + shimWaitTimeout); close(waitDone) }() select { case err := <-ms.cmdWait: @@ -129,9 +127,9 @@ func (ms mainState) beforeExit(isFault bool) { } if ms.uintptr&mainNeedsRevert != 0 { - if ok, err := ms.store.Do(ms.identity.unwrap(), func(c state.Cursor) { + if ok, err := ms.store.Do(ms.k.state.identity.unwrap(), func(c state.Cursor) { if ms.uintptr&mainNeedsDestroy != 0 { - if err := c.Destroy(ms.id.unwrap()); err != nil { + if err := c.Destroy(ms.k.state.id.unwrap()); err != nil { perror(err, "destroy state entry") } } @@ -208,31 +206,13 @@ func (ms mainState) fatal(fallback string, ferr error) { os.Exit(1) } -// finaliseProcess contains information collected during outcome.finalise used in outcome.main. -type finaliseProcess struct { - // Supplementary group ids. - supp []string - - // Copied from [hst.ContainerConfig], without exceeding [MaxShimWaitDelay]. - waitDelay time.Duration - - // Copied from the RunDirPath field of [hst.Paths]. - runDirPath *check.Absolute - - // Copied from outcomeState. - identity *stringPair[int] - - // Copied from outcomeState. - id *stringPair[state.ID] -} - // main carries out outcome and terminates. main does not return. func (k *outcome) main(msg container.Msg) { if !k.active.CompareAndSwap(false, true) { panic("outcome: attempted to run twice") } - if k.proc == nil { + if k.ctx == nil || k.sys == nil || k.state == nil { panic("outcome: did not finalise") } @@ -240,13 +220,13 @@ func (k *outcome) main(msg container.Msg) { hsuPath := internal.MustHsuPath() // ms.beforeExit required beyond this point - ms := &mainState{Msg: msg, k: k, finaliseProcess: k.proc} + ms := &mainState{Msg: msg, k: k} if err := k.sys.Commit(); err != nil { ms.fatal("cannot commit system setup:", err) } ms.uintptr |= mainNeedsRevert - ms.store = state.NewMulti(msg, ms.runDirPath.String()) + ms.store = state.NewMulti(msg, k.state.sc.RunDirPath.String()) ctx, cancel := context.WithCancel(k.ctx) defer cancel() @@ -267,14 +247,14 @@ func (k *outcome) main(msg container.Msg) { // passed through to shim by hsu shimEnv + "=" + strconv.Itoa(fd), // interpreted by hsu - "HAKUREI_IDENTITY=" + ms.identity.String(), + "HAKUREI_IDENTITY=" + k.state.identity.String(), } } - if len(ms.supp) > 0 { - msg.Verbosef("attaching supplementary group ids %s", ms.supp) + if len(k.supp) > 0 { + msg.Verbosef("attaching supplementary group ids %s", k.supp) // interpreted by hsu - ms.cmd.Env = append(ms.cmd.Env, "HAKUREI_GROUPS="+strings.Join(ms.supp, " ")) + ms.cmd.Env = append(ms.cmd.Env, "HAKUREI_GROUPS="+strings.Join(k.supp, " ")) } msg.Verbosef("setuid helper at %s", hsuPath) @@ -293,14 +273,7 @@ func (k *outcome) main(msg container.Msg) { select { case err := <-func() (setupErr chan error) { setupErr = make(chan error, 1) - go func() { - setupErr <- e.Encode(&shimParams{ - os.Getpid(), - ms.waitDelay, - &k.container, - msg.IsVerbose(), - }) - }() + go func() { setupErr <- e.Encode(k.state) }() return }(): if err != nil { @@ -314,9 +287,9 @@ func (k *outcome) main(msg container.Msg) { } // shim accepted setup payload, create process state - if ok, err := ms.store.Do(ms.identity.unwrap(), func(c state.Cursor) { + if ok, err := ms.store.Do(k.state.identity.unwrap(), func(c state.Cursor) { if err := c.Save(&state.State{ - ID: ms.id.unwrap(), + ID: k.state.id.unwrap(), PID: ms.cmd.Process.Pid, Time: *ms.Time, }, k.ct); err != nil { diff --git a/internal/app/shim.go b/internal/app/shim.go index 3a8a551..52aa530 100644 --- a/internal/app/shim.go +++ b/internal/app/shim.go @@ -5,6 +5,7 @@ import ( "errors" "io" "log" + "maps" "os" "os/exec" "os/signal" @@ -22,22 +23,34 @@ import ( //#include "shim-signal.h" import "C" -const shimEnv = "HAKUREI_SHIM" +const ( + // setup pipe fd for [container.Receive] + shimEnv = "HAKUREI_SHIM" + + // only used for a nil configured env map + envAllocSize = 1 << 6 +) type shimParams struct { // Priv side pid, checked against ppid in signal handler for the syscall.SIGCONT hack. - Monitor int + PrivPID int - // Duration to wait for after interrupting a container's initial process before the container is killed. + // Duration to wait for after the initial process receives os.Interrupt before the container is killed. // Limits are enforced on the priv side. WaitDelay time.Duration - // Finalised container params. - // TODO(ophestra): transmit outcomeState instead (params to shim) - Container *container.Params - - // Verbosity pass through. + // Verbosity pass through from [container.Msg]. Verbose bool + + // Outcome setup ops, contains setup state. Populated by outcome.finalise. + Ops []outcomeOp +} + +// valid checks shimParams to be safe for use. +func (p *shimParams) valid() bool { + return p != nil && + p.Ops != nil && + p.PrivPID > 0 } // ShimMain is the main function of the shim process and runs as the unconstrained target user. @@ -51,28 +64,36 @@ func ShimMain() { } var ( - params shimParams + state outcomeState closeSetup func() error ) - if f, err := container.Receive(shimEnv, ¶ms, nil); err != nil { + if f, err := container.Receive(shimEnv, &state, nil); err != nil { if errors.Is(err, syscall.EBADF) { log.Fatal("invalid config descriptor") } if errors.Is(err, container.ErrReceiveEnv) { - log.Fatal("HAKUREI_SHIM not set") + log.Fatal(shimEnv + " not set") } log.Fatalf("cannot receive shim setup params: %v", err) } else { - msg.SwapVerbose(params.Verbose) + msg.SwapVerbose(state.Shim.Verbose) closeSetup = f + + if err = state.populateLocal(direct{}, msg); err != nil { + if m, ok := container.GetErrorMessage(err); ok { + log.Fatal(m) + } else { + log.Fatalf("cannot populate local state: %v", err) + } + } } - var signalPipe io.ReadCloser // the Go runtime does not expose siginfo_t so SIGCONT is handled in C to check si_pid + var signalPipe io.ReadCloser if r, w, err := os.Pipe(); err != nil { log.Fatalf("cannot pipe: %v", err) - } else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(params.Monitor), C.int(w.Fd())); err != nil { + } else if _, err = C.hakurei_shim_setup_cont_signal(C.pid_t(state.Shim.PrivPID), C.int(w.Fd())); err != nil { log.Fatalf("cannot install SIGCONT handler: %v", err) } else { defer runtime.KeepAlive(w) @@ -84,7 +105,24 @@ func ShimMain() { log.Fatalf("cannot set parent-death signal: %v", errno) } - // signal handler outcome + var params container.Params + stateParams := outcomeStateParams{params: ¶ms, outcomeState: &state} + if state.Container.Env == nil { + stateParams.env = make(map[string]string, envAllocSize) + } else { + stateParams.env = maps.Clone(state.Container.Env) + } + for _, op := range state.Shim.Ops { + if err := op.toContainer(&stateParams); err != nil { + if m, ok := container.GetErrorMessage(err); ok { + log.Fatal(m) + } else { + log.Fatalf("cannot create container state: %v", err) + } + } + } + + // shim exit outcomes var cancelContainer atomic.Pointer[context.CancelFunc] go func() { buf := make([]byte, 1) @@ -95,7 +133,7 @@ func ShimMain() { switch buf[0] { case 0: // got SIGCONT from monitor: shim exit requested - if fp := cancelContainer.Load(); params.Container.ForwardCancel && fp != nil && *fp != nil { + if fp := cancelContainer.Load(); params.ForwardCancel && fp != nil && *fp != nil { (*fp)() // shim now bound by ShimWaitDelay, implemented below continue @@ -123,7 +161,7 @@ func ShimMain() { } }() - if params.Container == nil || params.Container.Ops == nil { + if params.Ops == nil { log.Fatal("invalid container params") } @@ -136,11 +174,11 @@ func ShimMain() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) cancelContainer.Store(&stop) z := container.New(ctx, msg) - z.Params = *params.Container + z.Params = params z.Stdin, z.Stdout, z.Stderr = os.Stdin, os.Stdout, os.Stderr // bounds and default enforced in finalise.go - z.WaitDelay = params.WaitDelay + z.WaitDelay = state.Shim.WaitDelay if err := z.Start(); err != nil { printMessageError("cannot start container:", err)