diff --git a/internal/uevent/uevent.go b/internal/uevent/uevent.go index 12754f46..61c0216e 100644 --- a/internal/uevent/uevent.go +++ b/internal/uevent/uevent.go @@ -4,6 +4,9 @@ package uevent import ( + "context" + "errors" + "strconv" "sync/atomic" "syscall" @@ -48,3 +51,55 @@ func Dial() (*Conn, error) { } return &Conn{conn: c}, err } + +var ( + // ErrBadSocket is returned by [Conn.Consume] for a reply from a + // syscall.Sockaddr with unexpected concrete type. + ErrBadSocket = errors.New("unexpected socket address") +) + +// BadPortError is returned by [Conn.Consume] upon receiving a message that did +// not come from the kernel. +type BadPortError syscall.SockaddrNetlink + +var _ Recoverable = new(BadPortError) + +func (*BadPortError) recoverable() {} +func (e *BadPortError) Error() string { + return "unexpected message from port id " + strconv.Itoa(int(e.Pid)) + + " on NETLINK_KOBJECT_UEVENT" +} + +// Consume continuously receives and parses events from the kernel. It returns +// the first error it encounters. +// +// Callers must not restart event processing after a non-nil error that does not +// satisfy [Recoverable] is returned. +func (c *Conn) Consume(ctx context.Context, events chan<- *Message) error { + if err := c.enterExcl(); err != nil { + return err + } + defer c.exitExcl() + + for { + data, from, err := c.conn.Recvfrom(ctx, 0) + if err != nil { + return err + } + + // lib/kobject_uevent.c: + // set portid 0 to inform userspace message comes from kernel + if v, ok := from.(*syscall.SockaddrNetlink); !ok { + return ErrBadSocket + } else if v.Pid != 0 { + return (*BadPortError)(v) + + } + + var msg Message + if err = msg.UnmarshalBinary(data); err != nil { + return err + } + events <- &msg + } +} diff --git a/internal/uevent/uevent_test.go b/internal/uevent/uevent_test.go index 53d9bafc..7dce4798 100644 --- a/internal/uevent/uevent_test.go +++ b/internal/uevent/uevent_test.go @@ -1,9 +1,14 @@ package uevent_test import ( + "context" "encoding" + "os" "reflect" + "sync" + "syscall" "testing" + "time" "hakurei.app/internal/uevent" ) @@ -108,6 +113,92 @@ func adeB[V any, S interface { } } +func TestDialConsume(t *testing.T) { + t.Parallel() + + c, err := uevent.Dial() + if err != nil { + t.Fatalf("Dial: error = %v", err) + } + t.Cleanup(func() { + if closeErr := c.Close(); closeErr != nil { + t.Fatal(err) + } + }) + + // check kernel-assigned port id + c0, err0 := uevent.Dial() + if err0 != nil { + t.Fatalf("Dial: error = %v", err) + } + t.Cleanup(func() { + if closeErr := c0.Close(); closeErr != nil { + t.Fatal(closeErr) + } + }) + + var wg sync.WaitGroup + done := make(chan struct{}) + events := make(chan *uevent.Message, 1<<10) + go func() { + defer close(done) + for msg := range events { + t.Log(msg) + } + }() + t.Cleanup(func() { + wg.Wait() + close(events) + <-done + }) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + wg.Go(func() { + if err = c.Consume(ctx, events); err != context.Canceled { + panic(err) + } + }) + wg.Go(func() { + if err0 = c0.Consume(ctx, events); err0 != context.Canceled { + panic(err0) + } + }) + + if testing.Verbose() { + if d, perr := time.ParseDuration(os.Getenv( + "ROSA_UEVENT_TEST_DURATION", + )); perr != nil { + t.Logf("skipping long test: error = %v", perr) + } else { + time.Sleep(d) + } + } + cancel() + wg.Wait() + + ctx, cancel = context.WithCancel(t.Context()) + defer cancel() + + var errs [2]error + exclExit := make(chan struct{}) + wg.Go(func() { + defer func() { exclExit <- struct{}{} }() + errs[0] = c.Consume(ctx, events) + }) + wg.Go(func() { + defer func() { exclExit <- struct{}{} }() + errs[1] = c.Consume(ctx, events) + }) + <-exclExit + cancel() + <-exclExit + if errs[0] != syscall.EAGAIN && errs[1] != syscall.EAGAIN { + t.Fatalf("enterExcl: err0 = %v, err1 = %v", errs[0], errs[1]) + } +} + func TestErrors(t *testing.T) { t.Parallel() @@ -138,6 +229,10 @@ func TestErrors(t *testing.T) { Data: "\x00", Kind: 0xbad, }, `section "" is invalid`}, + + {"BadPortError", &uevent.BadPortError{ + Pid: 1, + }, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) {