diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 5edc9b8..bfa42eb 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -9,7 +9,6 @@ import ( "io" "io/fs" "log" - "maps" "os/exec" "os/user" "reflect" @@ -469,7 +468,7 @@ func TestApp(t *testing.T) { }() } - var gotParams container.Params + var gotParams *container.Params { var sShim outcomeState @@ -481,17 +480,13 @@ func TestApp(t *testing.T) { t.Fatalf("populateLocal: error = %#v", err) } - stateParams := outcomeStateParams{params: &gotParams, outcomeState: &sShim} - if sShim.Container.Env == nil { - stateParams.env = make(map[string]string, envAllocSize) - } else { - stateParams.env = maps.Clone(sShim.Container.Env) - } + stateParams := sShim.newParams() for _, op := range sShim.Shim.Ops { - if err := op.toContainer(&stateParams); err != nil { + if err := op.toContainer(stateParams); err != nil { t.Fatalf("toContainer: error = %#v", err) } } + gotParams = stateParams.params } t.Run("sys", func(t *testing.T) { @@ -501,8 +496,8 @@ func TestApp(t *testing.T) { }) t.Run("params", func(t *testing.T) { - if !reflect.DeepEqual(&gotParams, tc.wantParams) { - t.Errorf("toContainer: params =\n%s\n, want\n%s", mustMarshal(&gotParams), mustMarshal(tc.wantParams)) + if !reflect.DeepEqual(gotParams, tc.wantParams) { + t.Errorf("toContainer: params =\n%s\n, want\n%s", mustMarshal(gotParams), mustMarshal(tc.wantParams)) } }) }) diff --git a/internal/app/outcome.go b/internal/app/outcome.go index 4daa8da..e294d94 100644 --- a/internal/app/outcome.go +++ b/internal/app/outcome.go @@ -2,6 +2,7 @@ package app import ( "errors" + "maps" "strconv" "hakurei.app/container" @@ -163,7 +164,7 @@ type outcomeStateSys struct { *outcomeState } -// outcomeState returns the address of a new outcomeStateSys embedding the current outcomeState. +// newSys returns the address of a new outcomeStateSys embedding the current outcomeState. func (s *outcomeState) newSys(config *hst.Config, sys *system.I) *outcomeStateSys { return &outcomeStateSys{ appId: config.ID, et: config.Enablements.Unwrap(), @@ -173,6 +174,17 @@ func (s *outcomeState) newSys(config *hst.Config, sys *system.I) *outcomeStateSy } } +// newParams returns the address of a new outcomeStateParams embedding the current outcomeState. +func (s *outcomeState) newParams() *outcomeStateParams { + stateParams := outcomeStateParams{params: new(container.Params), outcomeState: s} + if s.Container.Env == nil { + stateParams.env = make(map[string]string, envAllocSize) + } else { + stateParams.env = maps.Clone(s.Container.Env) + } + return &stateParams +} + // ensureRuntimeDir must be called if access to paths within XDG_RUNTIME_DIR is required. func (state *outcomeStateSys) ensureRuntimeDir() { if state.useRuntimeDir { diff --git a/internal/app/shim.go b/internal/app/shim.go index 9cf515c..1fdedd3 100644 --- a/internal/app/shim.go +++ b/internal/app/shim.go @@ -5,7 +5,6 @@ import ( "errors" "io" "log" - "maps" "os" "os/exec" "os/signal" @@ -102,15 +101,9 @@ func ShimMain() { log.Fatalf("cannot set parent-death signal: %v", errno) } - var params container.Params - stateParams := outcomeStateParams{params: ¶ms, outcomeState: &state} - if state.Container.Env == nil { - stateParams.env = make(map[string]string, envAllocSize) - } else { - stateParams.env = maps.Clone(state.Container.Env) - } + stateParams := state.newParams() for _, op := range state.Shim.Ops { - if err := op.toContainer(&stateParams); err != nil { + if err := op.toContainer(stateParams); err != nil { if m, ok := message.GetMessage(err); ok { log.Fatal(m) } else { @@ -130,7 +123,7 @@ func ShimMain() { switch buf[0] { case 0: // got SIGCONT from monitor: shim exit requested - if fp := cancelContainer.Load(); params.ForwardCancel && fp != nil && *fp != nil { + if fp := cancelContainer.Load(); stateParams.params.ForwardCancel && fp != nil && *fp != nil { (*fp)() // shim now bound by ShimWaitDelay, implemented below continue @@ -158,7 +151,7 @@ func ShimMain() { } }() - if params.Ops == nil { + if stateParams.params.Ops == nil { log.Fatal("invalid container params") } @@ -171,7 +164,7 @@ func ShimMain() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) cancelContainer.Store(&stop) z := container.New(ctx, msg) - z.Params = params + z.Params = *stateParams.params z.Stdin, z.Stdout, z.Stderr = os.Stdin, os.Stdout, os.Stderr // bounds and default enforced in finalise.go