forked from rosa/hakurei
internal/uevent: consume kernel-originated events
These are not possible to cover outside integration vm. Extreme care is required when dealing with this method, so keep it simple. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
@@ -4,6 +4,9 @@
|
|||||||
package uevent
|
package uevent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
@@ -48,3 +51,55 @@ func Dial() (*Conn, error) {
|
|||||||
}
|
}
|
||||||
return &Conn{conn: c}, err
|
return &Conn{conn: c}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrBadSocket is returned by [Conn.Consume] for a reply from a
|
||||||
|
// syscall.Sockaddr with unexpected concrete type.
|
||||||
|
ErrBadSocket = errors.New("unexpected socket address")
|
||||||
|
)
|
||||||
|
|
||||||
|
// BadPortError is returned by [Conn.Consume] upon receiving a message that did
|
||||||
|
// not come from the kernel.
|
||||||
|
type BadPortError syscall.SockaddrNetlink
|
||||||
|
|
||||||
|
var _ Recoverable = new(BadPortError)
|
||||||
|
|
||||||
|
func (*BadPortError) recoverable() {}
|
||||||
|
func (e *BadPortError) Error() string {
|
||||||
|
return "unexpected message from port id " + strconv.Itoa(int(e.Pid)) +
|
||||||
|
" on NETLINK_KOBJECT_UEVENT"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume continuously receives and parses events from the kernel. It returns
|
||||||
|
// the first error it encounters.
|
||||||
|
//
|
||||||
|
// Callers must not restart event processing after a non-nil error that does not
|
||||||
|
// satisfy [Recoverable] is returned.
|
||||||
|
func (c *Conn) Consume(ctx context.Context, events chan<- *Message) error {
|
||||||
|
if err := c.enterExcl(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer c.exitExcl()
|
||||||
|
|
||||||
|
for {
|
||||||
|
data, from, err := c.conn.Recvfrom(ctx, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// lib/kobject_uevent.c:
|
||||||
|
// set portid 0 to inform userspace message comes from kernel
|
||||||
|
if v, ok := from.(*syscall.SockaddrNetlink); !ok {
|
||||||
|
return ErrBadSocket
|
||||||
|
} else if v.Pid != 0 {
|
||||||
|
return (*BadPortError)(v)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg Message
|
||||||
|
if err = msg.UnmarshalBinary(data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
events <- &msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
package uevent_test
|
package uevent_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding"
|
"encoding"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"hakurei.app/internal/uevent"
|
"hakurei.app/internal/uevent"
|
||||||
)
|
)
|
||||||
@@ -108,6 +113,92 @@ func adeB[V any, S interface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialConsume(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
c, err := uevent.Dial()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: error = %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if closeErr := c.Close(); closeErr != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// check kernel-assigned port id
|
||||||
|
c0, err0 := uevent.Dial()
|
||||||
|
if err0 != nil {
|
||||||
|
t.Fatalf("Dial: error = %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if closeErr := c0.Close(); closeErr != nil {
|
||||||
|
t.Fatal(closeErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
done := make(chan struct{})
|
||||||
|
events := make(chan *uevent.Message, 1<<10)
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for msg := range events {
|
||||||
|
t.Log(msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(events)
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
|
if err = c.Consume(ctx, events); err != context.Canceled {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
wg.Go(func() {
|
||||||
|
if err0 = c0.Consume(ctx, events); err0 != context.Canceled {
|
||||||
|
panic(err0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if testing.Verbose() {
|
||||||
|
if d, perr := time.ParseDuration(os.Getenv(
|
||||||
|
"ROSA_UEVENT_TEST_DURATION",
|
||||||
|
)); perr != nil {
|
||||||
|
t.Logf("skipping long test: error = %v", perr)
|
||||||
|
} else {
|
||||||
|
time.Sleep(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
ctx, cancel = context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var errs [2]error
|
||||||
|
exclExit := make(chan struct{})
|
||||||
|
wg.Go(func() {
|
||||||
|
defer func() { exclExit <- struct{}{} }()
|
||||||
|
errs[0] = c.Consume(ctx, events)
|
||||||
|
})
|
||||||
|
wg.Go(func() {
|
||||||
|
defer func() { exclExit <- struct{}{} }()
|
||||||
|
errs[1] = c.Consume(ctx, events)
|
||||||
|
})
|
||||||
|
<-exclExit
|
||||||
|
cancel()
|
||||||
|
<-exclExit
|
||||||
|
if errs[0] != syscall.EAGAIN && errs[1] != syscall.EAGAIN {
|
||||||
|
t.Fatalf("enterExcl: err0 = %v, err1 = %v", errs[0], errs[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestErrors(t *testing.T) {
|
func TestErrors(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@@ -138,6 +229,10 @@ func TestErrors(t *testing.T) {
|
|||||||
Data: "\x00",
|
Data: "\x00",
|
||||||
Kind: 0xbad,
|
Kind: 0xbad,
|
||||||
}, `section "" is invalid`},
|
}, `section "" is invalid`},
|
||||||
|
|
||||||
|
{"BadPortError", &uevent.BadPortError{
|
||||||
|
Pid: 1,
|
||||||
|
}, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"},
|
||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user