diff --git a/internal/app/state/data.go b/internal/app/state/data.go new file mode 100644 index 0000000..3c289f9 --- /dev/null +++ b/internal/app/state/data.go @@ -0,0 +1,43 @@ +package state + +import ( + "encoding/gob" + "fmt" + "io" + "os" + + "hakurei.app/hst" +) + +// entryEncode encodes [hst.State] into [io.Writer] with the state entry header. +// entryEncode does not validate the embedded [hst.Config] value. +// +// A non-nil error returned by entryEncode is of type [hst.AppError]. +func entryEncode(w io.Writer, s *hst.State) error { + if err := entryWriteHeader(w, s.Enablements.Unwrap()); err != nil { + return &hst.AppError{Step: "encode state header", Err: err} + } else if err = gob.NewEncoder(w).Encode(s); err != nil { + return &hst.AppError{Step: "encode state body", Err: err} + } else { + return nil + } +} + +// entryDecode decodes [hst.State] from [io.Reader] and stores the result in the value pointed to by p. +// entryDecode validates the embedded [hst.Config] value. +// +// A non-nil error returned by entryDecode is of type [hst.AppError]. +func entryDecode(r io.Reader, p *hst.State) error { + if et, err := entryReadHeader(r); err != nil { + return &hst.AppError{Step: "decode state header", Err: err} + } else if err = gob.NewDecoder(r).Decode(&p); err != nil { + return &hst.AppError{Step: "decode state body", Err: err} + } else if err = p.Config.Validate(); err != nil { + return err + } else if p.Enablements.Unwrap() != et { + return &hst.AppError{Step: "validate state enablement", Err: os.ErrInvalid, + Msg: fmt.Sprintf("state entry %s has unexpected enablement byte %#x, %#x", p.ID.String(), byte(p.Enablements.Unwrap()), byte(et))} + } else { + return nil + } +} diff --git a/internal/app/state/data_test.go b/internal/app/state/data_test.go new file mode 100644 index 0000000..49f0f52 --- /dev/null +++ b/internal/app/state/data_test.go @@ -0,0 +1,139 @@ +package state + +import ( + "bytes" + "encoding/gob" + "errors" + "io" + "os" + "reflect" + "strings" + "testing" + "time" + + "hakurei.app/container/stub" + "hakurei.app/hst" +) + +func TestEntryData(t *testing.T) { + t.Parallel() + newTemplateState := func() *hst.State { + return &hst.State{ + ID: hst.ID(bytes.Repeat([]byte{0xaa}, len(hst.ID{}))), + PID: 0xcafebabe, + ShimPID: 0xdeadbeef, + Config: hst.Template(), + Time: time.Unix(0, 0), + } + } + + mustEncodeGob := func(e any) string { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(e); err != nil { + t.Fatalf("cannot encode invalid state: %v", err) + return "\x00" // not reached + } else { + return buf.String() + } + } + templateStateGob := mustEncodeGob(newTemplateState()) + + testCases := []struct { + name string + data string + s *hst.State + err error + }{ + {"invalid header", "\x00\xff\xca\xfe\xff\xff\xff\x00", nil, &hst.AppError{ + Step: "decode state header", Err: errors.New("unexpected revision ffff")}}, + + {"invalid gob", "\x00\xff\xca\xfe\x00\x00\xff\x00", nil, &hst.AppError{ + Step: "decode state body", Err: io.EOF}}, + + {"invalid config", "\x00\xff\xca\xfe\x00\x00\xff\x00" + mustEncodeGob(new(hst.State)), new(hst.State), &hst.AppError{ + Step: "validate configuration", Err: hst.ErrConfigNull, + Msg: "invalid configuration"}}, + + {"inconsistent enablement", "\x00\xff\xca\xfe\x00\x00\xff\x00" + templateStateGob, newTemplateState(), &hst.AppError{ + Step: "validate state enablement", Err: os.ErrInvalid, + Msg: "state entry aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa has unexpected enablement byte 0xd, 0xff"}}, + + {"template", "\x00\xff\xca\xfe\x00\x00\x0d\xf2" + templateStateGob, newTemplateState(), nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + t.Run("encode", func(t *testing.T) { + if tc.s == nil || tc.s.Config == nil { + return + } + t.Parallel() + + var buf bytes.Buffer + if err := entryEncode(&buf, tc.s); err != nil { + t.Fatalf("entryEncode: error = %v", err) + } + + if tc.err == nil { + // Gob encoding is not guaranteed to be deterministic. + // While the current implementation mostly is, it has randomised order + // for iterating over maps, and hst.Config holds a map for environ. + var got hst.State + if err := entryDecode(&buf, &got); err != nil { + t.Fatalf("entryDecode: error = %v", err) + } + if !reflect.DeepEqual(&got, tc.s) { + t.Errorf("entryEncode: %x", buf.Bytes()) + } + } else if testing.Verbose() { + t.Logf("%x", buf.String()) + } + }) + + t.Run("decode", func(t *testing.T) { + t.Parallel() + + var got hst.State + if err := entryDecode(strings.NewReader(tc.data), &got); !reflect.DeepEqual(err, tc.err) { + t.Fatalf("entryDecode: error = %#v, want %#v", err, tc.err) + } else if err != nil { + return + } + + if !reflect.DeepEqual(&got, tc.s) { + t.Errorf("entryDecode: %#v, want %#v", &got, tc.s) + } + }) + }) + } + + t.Run("encode fault", func(t *testing.T) { + t.Parallel() + s := newTemplateState() + + t.Run("gob", func(t *testing.T) { + var want = &hst.AppError{Step: "encode state body", Err: stub.UniqueError(0xcafe)} + if err := entryEncode(stubNErrorWriter(entryHeaderSize), s); !reflect.DeepEqual(err, want) { + t.Errorf("entryEncode: error = %#v, want %#v", err, want) + } + }) + + t.Run("header", func(t *testing.T) { + var want = &hst.AppError{Step: "encode state header", Err: stub.UniqueError(0xcafe)} + if err := entryEncode(stubNErrorWriter(entryHeaderSize-1), s); !reflect.DeepEqual(err, want) { + t.Errorf("entryEncode: error = %#v, want %#v", err, want) + } + }) + }) +} + +// stubNErrorWriter returns an error for writes above a certain size. +type stubNErrorWriter int + +func (w stubNErrorWriter) Write(p []byte) (n int, err error) { + if len(p) > int(w) { + return int(w), stub.UniqueError(0xcafe) + } + return io.Discard.Write(p) +} diff --git a/internal/app/state/multi.go b/internal/app/state/multi.go index 695bff1..446d6ce 100644 --- a/internal/app/state/multi.go +++ b/internal/app/state/multi.go @@ -1,7 +1,6 @@ package state import ( - "encoding/gob" "errors" "fmt" "io/fs" @@ -161,25 +160,18 @@ func (b *multiBackend) load(decode bool) (map[hst.ID]*hst.State, error) { // append regardless, but only parse if required, implements Len if decode { - var et hst.Enablement - if et, err = entryReadHeader(f); err != nil { + if err = entryDecode(f, &s); err != nil { _ = f.Close() - return &hst.AppError{Step: "decode state header", Err: err} - } else if err = gob.NewDecoder(f).Decode(&s); err != nil { - _ = f.Close() - return &hst.AppError{Step: "decode state body", Err: err} - } else if s.ID != id { - _ = f.Close() - return fmt.Errorf("state entry %s has unexpected id %s", id, &s.ID) - } else if err = f.Close(); err != nil { - return &hst.AppError{Step: "close state file", Err: err} - } else if err = s.Config.Validate(); err != nil { return err - } else if s.Enablements.Unwrap() != et { - return fmt.Errorf("state entry %s has unexpected enablement byte %x, %x", id, s.Enablements, et) + } else if s.ID != id { + return &hst.AppError{Step: "validate state identifier", Err: os.ErrInvalid, + Msg: fmt.Sprintf("state entry %s has unexpected id %s", id, &s.ID)} } } + if err = f.Close(); err != nil { + return &hst.AppError{Step: "close state file", Err: err} + } return nil } }(); err != nil { @@ -202,12 +194,9 @@ func (b *multiBackend) Save(state *hst.State) error { statePath := b.filename(&state.ID) if f, err := os.OpenFile(statePath, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600); err != nil { return &hst.AppError{Step: "create state file", Err: err} - } else if err = entryWriteHeader(f, state.Enablements.Unwrap()); err != nil { + } else if err = entryEncode(f, state); err != nil { _ = f.Close() - return &hst.AppError{Step: "encode state header", Err: err} - } else if err = gob.NewEncoder(f).Encode(state); err != nil { - _ = f.Close() - return &hst.AppError{Step: "encode state body", Err: err} + return err } else if err = f.Close(); err != nil { return &hst.AppError{Step: "close state file", Err: err} }