package uevent_test import ( "context" "encoding" "encoding/hex" "os" "reflect" "strings" "sync" "syscall" "testing" "time" "hakurei.app/fhs" "hakurei.app/internal/uevent" ) // adeT sets up a parallel subtest for a textual appender/decoder/encoder. func adeT[V any, S interface { encoding.TextAppender encoding.TextMarshaler encoding.TextUnmarshaler *V }](t *testing.T, name string, v V, want string, wantErr, wantErrE error) { t.Helper() noEncode := strings.HasSuffix(name, "\x00") if noEncode { name = name[:len(name)-1] } f := func(t *testing.T) { if name != "" { t.Parallel() } t.Helper() t.Run("decode", func(t *testing.T) { t.Parallel() t.Helper() var got V if err := S(&got).UnmarshalText([]byte(want)); !reflect.DeepEqual(err, wantErr) { t.Fatalf("UnmarshalText: error = %v, want %v", err, wantErr) } if wantErr != nil { return } if !reflect.DeepEqual(&got, &v) { t.Errorf("UnmarshalText: %#v, want %#v", got, v) } }) if noEncode { return } t.Run("encode", func(t *testing.T) { t.Parallel() t.Helper() if got, err := S(&v).MarshalText(); !reflect.DeepEqual(err, wantErrE) { t.Fatalf("MarshalText: error = %v, want %v", err, wantErrE) } else if err == nil && string(got) != want { t.Errorf("MarshalText: %q, want %q", string(got), want) } }) } if name != "" { t.Run(name, f) } else { f(t) } } // adeT sets up a binary subtest for a textual appender/decoder/encoder. func adeB[V any, S interface { encoding.BinaryAppender encoding.BinaryMarshaler encoding.BinaryUnmarshaler *V }](t *testing.T, name string, v V, want string, wantErr, wantErrE error) { t.Helper() f := func(t *testing.T) { if name != "" { t.Parallel() } t.Helper() t.Run("decode", func(t *testing.T) { t.Parallel() t.Helper() var got V if err := S(&got).UnmarshalBinary([]byte(want)); !reflect.DeepEqual(err, wantErr) { t.Fatalf("UnmarshalBinary: error = %v, want %v", err, wantErr) } if wantErr != nil { return } if !reflect.DeepEqual(&got, &v) { t.Errorf("UnmarshalBinary: %#v, want %#v", got, v) } }) t.Run("encode", func(t *testing.T) { t.Parallel() t.Helper() if got, err := S(&v).MarshalBinary(); !reflect.DeepEqual(err, wantErrE) { t.Fatalf("MarshalBinary: error = %v, want %v", err, wantErrE) } else if err == nil && string(got) != want { t.Errorf("MarshalBinary: %q, want %q", string(got), want) } }) } if name != "" { t.Run(name, f) } else { f(t) } } func TestUUID(t *testing.T) { t.Parallel() adeT(t, "sample", uevent.UUID{ 0xfe, 0x4d, 0x7c, 0x9d, 0xb8, 0xc6, 0x4a, 0x70, 0x9e, 0xf1, 0x3d, 0x8a, 0x58, 0xd1, 0x8e, 0xed, }, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18eed", nil, nil) adeT(t, "auto\x00", uevent.UUID{}, "0", uevent.ErrAutoUUID, nil) adeT(t, "short\x00", uevent.UUID{}, "1", uevent.UUIDSizeError(1), nil) adeT(t, "bad0\x00", uevent.UUID{}, "fe4d7c9\x00-b8c6-4a70-9ef1-3d8a58d18eed", hex.InvalidByteError(0), nil) adeT(t, "sep0\x00", uevent.UUID{}, "fe4d7c9d\x00b8c6-4a70-9ef1-3d8a58d18eed", uevent.UUIDSeparatorError(0), nil) adeT(t, "bad1\x00", uevent.UUID{}, "fe4d7c9d-b8c\x00-4a70-9ef1-3d8a58d18eed", hex.InvalidByteError(0), nil) adeT(t, "sep1\x00", uevent.UUID{}, "fe4d7c9d-b8c6\x004a70-9ef1-3d8a58d18eed", uevent.UUIDSeparatorError(0), nil) adeT(t, "bad2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a7\x00-9ef1-3d8a58d18eed", hex.InvalidByteError(0), nil) adeT(t, "sep2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70\x009ef1-3d8a58d18eed", uevent.UUIDSeparatorError(0), nil) adeT(t, "bad3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef\x00-3d8a58d18eed", hex.InvalidByteError(0), nil) adeT(t, "sep3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1\x003d8a58d18eed", uevent.UUIDSeparatorError(0), nil) adeT(t, "bad4\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18ee\x00", hex.InvalidByteError(0), nil) } func TestDialConsume(t *testing.T) { t.Parallel() c, err := uevent.Dial(0) if err != nil { t.Fatalf("Dial: error = %v", err) } t.Cleanup(func() { if closeErr := c.Close(); closeErr != nil { t.Fatal(err) } }) // check kernel-assigned port id c0, err0 := uevent.Dial(0) if err0 != nil { t.Fatalf("Dial: error = %v", err) } t.Cleanup(func() { if closeErr := c0.Close(); closeErr != nil { t.Fatal(closeErr) } }) var wg sync.WaitGroup done := make(chan struct{}) events := make(chan *uevent.Message, 1<<10) go func() { defer close(done) for msg := range events { t.Log(msg) } }() t.Cleanup(func() { wg.Wait() close(events) <-done }) ctx, cancel := context.WithCancel(t.Context()) defer cancel() consume := func(c *uevent.Conn, ctx context.Context) error { return c.Consume(ctx, fhs.Sys, nil, events, false, func(path string) { t.Log("coldboot visited", path) }, func(err error) bool { t.Log(err) _, ok := err.(uevent.NeedsColdboot) return !ok }, nil) } wg.Go(func() { if err = consume(c, ctx); err != context.Canceled { panic(err) } }) wg.Go(func() { if err0 = consume(c0, ctx); err0 != context.Canceled { panic(err0) } }) if testing.Verbose() { if d, perr := time.ParseDuration(os.Getenv( "ROSA_UEVENT_TEST_DURATION", )); perr != nil { t.Logf("skipping long test: error = %v", perr) } else { time.Sleep(d) } } cancel() wg.Wait() ctx, cancel = context.WithCancel(t.Context()) defer cancel() var errs [2]error exclExit := make(chan struct{}) wg.Go(func() { defer func() { exclExit <- struct{}{} }() errs[0] = consume(c, ctx) }) wg.Go(func() { defer func() { exclExit <- struct{}{} }() errs[1] = consume(c, ctx) }) <-exclExit cancel() <-exclExit if errs[0] != syscall.EAGAIN && errs[1] != syscall.EAGAIN { t.Fatalf("enterExcl: err0 = %v, err1 = %v", errs[0], errs[1]) } } func TestErrors(t *testing.T) { t.Parallel() testCases := []struct { name string err error want string }{ {"UnsupportedActionError", uevent.UnsupportedActionError("explode"), `unsupported kobject_action "explode"`}, {"MissingHeaderError", uevent.MissingHeaderError("move"), `message "move" has no header`}, {"MessageError MErrorKindHeaderSep", &uevent.MessageError{ Data: "move\x00", Section: "move", Kind: uevent.MErrorKindHeaderSep, }, `header "move" missing separator`}, {"MessageError MErrorKindFinalNUL", &uevent.MessageError{ Data: "move\x00truncated", Section: "truncated", Kind: uevent.MErrorKindFinalNUL, }, `entry "truncated" missing NUL`}, {"MessageError bad", &uevent.MessageError{ Data: "\x00", Kind: 0xbad, }, `section "" is invalid`}, {"BadPortError", &uevent.BadPortError{ Pid: 1, }, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"}, {"UUIDSizeError", uevent.UUIDSizeError(0xbad), "got 2989 bytes instead of 36"}, {"UUIDSeparatorError", uevent.UUIDSeparatorError(0xfd), "invalid UUID separator: U+00FD 'ý'"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() if got := tc.err.Error(); got != tc.want { t.Errorf("Error: %q, want %q", got, tc.want) } }) } }