diff --git a/internal/app/app_test.go b/internal/app/app_test.go index fccb7ad..34313df 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -458,17 +458,14 @@ func TestApp(t *testing.T) { Container: tc.config.Container, } - sPriv.populateEarly(tc.k, msg, tc.config) + sPriv.populateEarly(tc.k, msg) if err := sPriv.populateLocal(tc.k, msg); err != nil { t.Fatalf("populateLocal: error = %#v", err) } gotSys = system.New(t.Context(), msg, sPriv.uid.unwrap()) - stateSys := outcomeStateSys{config: tc.config, sys: gotSys, outcomeState: &sPriv} - for _, op := range sPriv.Shim.Ops { - if err := op.toSystem(&stateSys); err != nil { - t.Fatalf("toSystem: error = %#v", err) - } + if err := (&outcomeStateSys{config: tc.config, sys: gotSys, outcomeState: &sPriv}).toSystem(); err != nil { + t.Fatalf("toSystem: error = %#v", err) } go func() { diff --git a/internal/app/finalise.go b/internal/app/finalise.go index 9acc51f..62b380e 100644 --- a/internal/app/finalise.go +++ b/internal/app/finalise.go @@ -75,17 +75,14 @@ func (k *outcome) finalise(ctx context.Context, msg message.Msg, id *state.ID, c EnvPaths: copyPaths(k.syscallDispatcher), Container: config.Container, } - s.populateEarly(k.syscallDispatcher, msg, config) + s.populateEarly(k.syscallDispatcher, msg) if err := s.populateLocal(k.syscallDispatcher, msg); err != nil { return err } sys := system.New(k.ctx, msg, s.uid.unwrap()) - stateSys := outcomeStateSys{config: config, sys: sys, outcomeState: &s} - for _, op := range s.Shim.Ops { - if err := op.toSystem(&stateSys); err != nil { - return err - } + if err := (&outcomeStateSys{config: config, sys: sys, outcomeState: &s}).toSystem(); err != nil { + return err } k.sys = sys diff --git a/internal/app/outcome.go b/internal/app/outcome.go index 0161755..c5f0fa7 100644 --- a/internal/app/outcome.go +++ b/internal/app/outcome.go @@ -1,6 +1,7 @@ package app import ( + "errors" "os" "strconv" @@ -76,8 +77,8 @@ func (s *outcomeState) valid() bool { // populateEarly populates exported fields via syscallDispatcher. // This must only be called from the priv side. -func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg, config *hst.Config) { - s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose(), Ops: fromConfig(config)} +func (s *outcomeState) populateEarly(k syscallDispatcher, msg message.Msg) { + s.Shim = &shimParams{PrivPID: os.Getpid(), Verbose: msg.IsVerbose()} // enforce bounds and default early if s.Container.WaitDelay <= 0 { @@ -203,6 +204,9 @@ type outcomeStateParams struct { *outcomeState } +// errNotEnabled is returned by outcomeOp.toSystem and used internally to exclude an outcomeOp from transmission. +var errNotEnabled = errors.New("op not enabled in the configuration") + // An outcomeOp inflicts an outcome on [system.I] and contains enough information to // inflict it on [container.Params] in a separate process. // An implementation of outcomeOp must store cross-process states in exported fields only. @@ -216,11 +220,15 @@ type outcomeOp interface { toContainer(state *outcomeStateParams) error } -// fromConfig returns a corresponding slice of outcomeOp for [hst.Config]. +// toSystem calls the outcomeOp.toSystem method on all outcomeOp implementations and populates shimParams.Ops. // This function assumes the caller has already called the Validate method on [hst.Config] // and checked that it returns nil. -func fromConfig(config *hst.Config) (ops []outcomeOp) { - ops = []outcomeOp{ +func (state *outcomeStateSys) toSystem() error { + if state.Shim == nil || state.Shim.Ops != nil { + return newWithMessage("invalid ops state reached") + } + + ops := [...]outcomeOp{ // must run first &spParamsOp{}, @@ -230,22 +238,27 @@ func fromConfig(config *hst.Config) (ops []outcomeOp) { spRuntimeOp{}, spTmpdirOp{}, spAccountOp{}, + + // optional via enablements + &spWaylandOp{}, + &spX11Op{}, + &spPulseOp{}, + &spDBusOp{}, + + spFinal{}, } - et := config.Enablements.Unwrap() - if et&hst.EWayland != 0 { - ops = append(ops, &spWaylandOp{}) - } - if et&hst.EX11 != 0 { - ops = append(ops, &spX11Op{}) - } - if et&hst.EPulse != 0 { - ops = append(ops, &spPulseOp{}) - } - if et&hst.EDBus != 0 { - ops = append(ops, &spDBusOp{}) - } + state.Shim.Ops = make([]outcomeOp, 0, len(ops)) + for _, op := range ops { + if err := op.toSystem(state); err != nil { + // this error is used internally to exclude this outcomeOp from transmission + if errors.Is(err, errNotEnabled) { + continue + } - ops = append(ops, spFinal{}) - return + return err + } + state.Shim.Ops = append(state.Shim.Ops, op) + } + return nil } diff --git a/internal/app/outcome_test.go b/internal/app/outcome_test.go index 81fb6c7..183dbfd 100644 --- a/internal/app/outcome_test.go +++ b/internal/app/outcome_test.go @@ -1,7 +1,6 @@ package app import ( - "reflect" "testing" "hakurei.app/hst" @@ -30,49 +29,3 @@ func TestOutcomeStateValid(t *testing.T) { }) } } - -func TestFromConfig(t *testing.T) { - testCases := []struct { - name string - config *hst.Config - want []outcomeOp - }{ - {"ne", new(hst.Config), []outcomeOp{ - &spParamsOp{}, - spFilesystemOp{}, - spRuntimeOp{}, - spTmpdirOp{}, - spAccountOp{}, - spFinal{}, - }}, - {"wayland pulse", &hst.Config{Enablements: hst.NewEnablements(hst.EWayland | hst.EPulse)}, []outcomeOp{ - &spParamsOp{}, - spFilesystemOp{}, - spRuntimeOp{}, - spTmpdirOp{}, - spAccountOp{}, - &spWaylandOp{}, - &spPulseOp{}, - spFinal{}, - }}, - {"all", &hst.Config{Enablements: hst.NewEnablements(0xff)}, []outcomeOp{ - &spParamsOp{}, - spFilesystemOp{}, - spRuntimeOp{}, - spTmpdirOp{}, - spAccountOp{}, - &spWaylandOp{}, - &spX11Op{}, - &spPulseOp{}, - &spDBusOp{}, - spFinal{}, - }}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if got := fromConfig(tc.config); !reflect.DeepEqual(got, tc.want) { - t.Errorf("fromConfig: %#v, want %#v", got, tc.want) - } - }) - } -} diff --git a/internal/app/shim.go b/internal/app/shim.go index 27305e6..9cf515c 100644 --- a/internal/app/shim.go +++ b/internal/app/shim.go @@ -48,11 +48,7 @@ type shimParams struct { } // valid checks shimParams to be safe for use. -func (p *shimParams) valid() bool { - return p != nil && - p.Ops != nil && - p.PrivPID > 0 -} +func (p *shimParams) valid() bool { return p != nil && p.PrivPID > 0 } // ShimMain is the main function of the shim process and runs as the unconstrained target user. func ShimMain() { diff --git a/internal/app/spdbus.go b/internal/app/spdbus.go index 0d26d99..ac27eeb 100644 --- a/internal/app/spdbus.go +++ b/internal/app/spdbus.go @@ -4,6 +4,7 @@ import ( "encoding/gob" "hakurei.app/container/fhs" + "hakurei.app/hst" "hakurei.app/system/acl" "hakurei.app/system/dbus" ) @@ -18,6 +19,10 @@ type spDBusOp struct { } func (s *spDBusOp) toSystem(state *outcomeStateSys) error { + if state.config.Enablements.Unwrap()&hst.EDBus == 0 { + return errNotEnabled + } + if state.config.SessionBus == nil { state.config.SessionBus = dbus.NewConfig(state.config.ID, true, true) } diff --git a/internal/app/sppulse.go b/internal/app/sppulse.go index 5e4bfab..9fc22f6 100644 --- a/internal/app/sppulse.go +++ b/internal/app/sppulse.go @@ -24,6 +24,10 @@ type spPulseOp struct { } func (s *spPulseOp) toSystem(state *outcomeStateSys) error { + if state.config.Enablements.Unwrap()&hst.EPulse == 0 { + return errNotEnabled + } + pulseRuntimeDir, pulseSocket := s.commonPaths(state.outcomeState) if _, err := state.k.stat(pulseRuntimeDir.String()); err != nil { diff --git a/internal/app/spwayland.go b/internal/app/spwayland.go index 104f0a6..25fc118 100644 --- a/internal/app/spwayland.go +++ b/internal/app/spwayland.go @@ -18,6 +18,10 @@ type spWaylandOp struct { } func (s *spWaylandOp) toSystem(state *outcomeStateSys) error { + if state.config.Enablements.Unwrap()&hst.EWayland == 0 { + return errNotEnabled + } + // outer wayland socket (usually `/run/user/%d/wayland-%d`) var socketPath *check.Absolute if name, ok := state.k.lookupEnv(wayland.WaylandDisplay); !ok { diff --git a/internal/app/spx11.go b/internal/app/spx11.go index e0e1067..219f2cb 100644 --- a/internal/app/spx11.go +++ b/internal/app/spx11.go @@ -25,6 +25,10 @@ type spX11Op struct { } func (s *spX11Op) toSystem(state *outcomeStateSys) error { + if state.config.Enablements.Unwrap()&hst.EX11 == 0 { + return errNotEnabled + } + if d, ok := state.k.lookupEnv("DISPLAY"); !ok { return newWithMessage("DISPLAY is not set") } else {