1
0
forked from rosa/hakurei
Files
hakurei/internal/uevent/uevent_test.go
Ophestra ea014d6af2 internal/uevent: consume kernel-originated events
These cannot be covered outside the integration VM. Extreme care is required when dealing with this method, so keep it simple.

Signed-off-by: Ophestra <cat@gensokyo.uk>
2026-03-28 15:39:16 +09:00

247 lines
5.1 KiB
Go

package uevent_test
import (
"context"
"encoding"
"os"
"reflect"
"sync"
"syscall"
"testing"
"time"
"hakurei.app/internal/uevent"
)
// adeT registers a test for a value whose pointer type implements the textual
// appender/decoder/encoder interfaces.
//
// Two parallel subtests are created:
//   - "decode" unmarshals want into a zero V via UnmarshalText and requires
//     the returned error to deeply equal wantErr; when wantErr is nil the
//     decoded value must also deeply equal v.
//   - "encode" marshals v via MarshalText and requires the returned error to
//     deeply equal wantErrE; when that error is nil the output must equal want.
//
// A non-empty name groups both subtests under a parallel subtest of that name;
// an empty name attaches them to t directly.
func adeT[V any, S interface {
	encoding.TextAppender
	encoding.TextMarshaler
	encoding.TextUnmarshaler
	*V
}](t *testing.T, name string, v V, want string, wantErr, wantErrE error) {
	t.Helper()

	named := name != ""

	decode := func(t *testing.T) {
		t.Parallel()
		t.Helper()

		var decoded V
		err := S(&decoded).UnmarshalText([]byte(want))
		if !reflect.DeepEqual(err, wantErr) {
			t.Fatalf("UnmarshalText: error = %v, want %v", err, wantErr)
		}
		// the decoded value is only meaningful when decoding was expected to succeed
		if wantErr == nil && !reflect.DeepEqual(&decoded, &v) {
			t.Errorf("UnmarshalText: %#v, want %#v", decoded, v)
		}
	}

	encode := func(t *testing.T) {
		t.Parallel()
		t.Helper()

		encoded, err := S(&v).MarshalText()
		switch {
		case !reflect.DeepEqual(err, wantErrE):
			t.Fatalf("MarshalText: error = %v, want %v", err, wantErrE)
		case err == nil && string(encoded) != want:
			t.Errorf("MarshalText: %q, want %q", string(encoded), want)
		}
	}

	run := func(t *testing.T) {
		if named {
			t.Parallel()
		}
		t.Helper()
		t.Run("decode", decode)
		t.Run("encode", encode)
	}

	if !named {
		run(t)
		return
	}
	t.Run(name, run)
}
// adeB sets up a parallel subtest for a binary appender/decoder/encoder.
//
// It is the binary counterpart of adeT. Two parallel subtests are created:
//   - "decode" unmarshals want into a zero V via UnmarshalBinary and requires
//     the returned error to deeply equal wantErr; when wantErr is nil the
//     decoded value must also deeply equal v.
//   - "encode" marshals v via MarshalBinary and requires the returned error to
//     deeply equal wantErrE; when that error is nil the output must equal want.
//
// A non-empty name groups both subtests under a parallel subtest of that name;
// an empty name attaches them to t directly.
func adeB[V any, S interface {
	encoding.BinaryAppender
	encoding.BinaryMarshaler
	encoding.BinaryUnmarshaler
	*V
}](t *testing.T, name string, v V, want string, wantErr, wantErrE error) {
	t.Helper()
	f := func(t *testing.T) {
		if name != "" {
			t.Parallel()
		}
		t.Helper()
		t.Run("decode", func(t *testing.T) {
			t.Parallel()
			t.Helper()
			var got V
			if err := S(&got).UnmarshalBinary([]byte(want)); !reflect.DeepEqual(err, wantErr) {
				t.Fatalf("UnmarshalBinary: error = %v, want %v", err, wantErr)
			}
			// the decoded value is only checked when decoding is expected to succeed
			if wantErr != nil {
				return
			}
			if !reflect.DeepEqual(&got, &v) {
				t.Errorf("UnmarshalBinary: %#v, want %#v", got, v)
			}
		})
		t.Run("encode", func(t *testing.T) {
			t.Parallel()
			t.Helper()
			if got, err := S(&v).MarshalBinary(); !reflect.DeepEqual(err, wantErrE) {
				t.Fatalf("MarshalBinary: error = %v, want %v", err, wantErrE)
			} else if err == nil && string(got) != want {
				t.Errorf("MarshalBinary: %q, want %q", string(got), want)
			}
		})
	}
	if name != "" {
		t.Run(name, f)
	} else {
		f(t)
	}
}
// TestDialConsume dials two uevent connections and consumes events from both
// into a shared channel until the context is canceled. It then checks that two
// concurrent Consume calls on the same connection are mutually exclusive: at
// least one of them must fail with EAGAIN.
//
// With -v and ROSA_UEVENT_TEST_DURATION set to a time.ParseDuration string,
// events are consumed (and logged) for that long before canceling.
func TestDialConsume(t *testing.T) {
	t.Parallel()

	c, err := uevent.Dial()
	if err != nil {
		t.Fatalf("Dial: error = %v", err)
	}
	t.Cleanup(func() {
		if closeErr := c.Close(); closeErr != nil {
			t.Fatal(closeErr)
		}
	})

	// check kernel-assigned port id: a second connection from the same
	// process must dial successfully as well
	c0, err0 := uevent.Dial()
	if err0 != nil {
		t.Fatalf("Dial: error = %v", err0)
	}
	t.Cleanup(func() {
		if closeErr := c0.Close(); closeErr != nil {
			t.Fatal(closeErr)
		}
	})

	var wg sync.WaitGroup
	done := make(chan struct{})
	events := make(chan *uevent.Message, 1<<10)
	// log every received event until events is closed
	go func() {
		defer close(done)
		for msg := range events {
			t.Log(msg)
		}
	}()
	// events is only closed after every Consume goroutine has returned, and
	// the logger is drained before the connections are closed (LIFO cleanup)
	t.Cleanup(func() {
		wg.Wait()
		close(events)
		<-done
	})

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	// each goroutine keeps its own error instead of clobbering err/err0, and
	// reports through t.Errorf (safe from other goroutines) rather than panicking
	wg.Go(func() {
		if consumeErr := c.Consume(ctx, events); consumeErr != context.Canceled {
			t.Errorf("Consume: error = %v, want %v", consumeErr, context.Canceled)
		}
	})
	wg.Go(func() {
		if consumeErr := c0.Consume(ctx, events); consumeErr != context.Canceled {
			t.Errorf("Consume: error = %v, want %v", consumeErr, context.Canceled)
		}
	})

	if testing.Verbose() {
		if d, perr := time.ParseDuration(os.Getenv(
			"ROSA_UEVENT_TEST_DURATION",
		)); perr != nil {
			t.Logf("skipping long test: error = %v", perr)
		} else {
			time.Sleep(d)
		}
	}
	cancel()
	// all t.Errorf calls above happen before this returns
	wg.Wait()

	// two concurrent Consume calls on the same connection: wait for the first
	// to return, then cancel the other; at least one must have seen EAGAIN
	ctx, cancel = context.WithCancel(t.Context())
	defer cancel()
	var errs [2]error
	exclExit := make(chan struct{})
	wg.Go(func() {
		defer func() { exclExit <- struct{}{} }()
		errs[0] = c.Consume(ctx, events)
	})
	wg.Go(func() {
		defer func() { exclExit <- struct{}{} }()
		errs[1] = c.Consume(ctx, events)
	})
	<-exclExit
	cancel()
	<-exclExit
	if errs[0] != syscall.EAGAIN && errs[1] != syscall.EAGAIN {
		t.Fatalf("enterExcl: err0 = %v, err1 = %v", errs[0], errs[1])
	}
}
// TestErrors pins the Error() text of every error type exported by uevent.
func TestErrors(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		err  error
		want string
	}{
		{
			name: "UnsupportedActionError",
			err:  uevent.UnsupportedActionError("explode"),
			want: `unsupported kobject_action "explode"`,
		},
		{
			name: "MissingHeaderError",
			err:  uevent.MissingHeaderError("move"),
			want: `message "move" has no header`,
		},
		{
			name: "MessageError MErrorKindHeaderSep",
			err: &uevent.MessageError{
				Data:    "move\x00",
				Section: "move",
				Kind:    uevent.MErrorKindHeaderSep,
			},
			want: `header "move" missing separator`,
		},
		{
			name: "MessageError MErrorKindFinalNUL",
			err: &uevent.MessageError{
				Data:    "move\x00truncated",
				Section: "truncated",
				Kind:    uevent.MErrorKindFinalNUL,
			},
			want: `entry "truncated" missing NUL`,
		},
		{
			// an unknown kind falls back to a generic message
			name: "MessageError bad",
			err: &uevent.MessageError{
				Data: "\x00",
				Kind: 0xbad,
			},
			want: `section "" is invalid`,
		},
		{
			name: "BadPortError",
			err: &uevent.BadPortError{
				Pid: 1,
			},
			want: "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT",
		},
	}

	for i := range cases {
		c := &cases[i]
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			got := c.err.Error()
			if got == c.want {
				return
			}
			t.Errorf("Error: %q, want %q", got, c.want)
		})
	}
}