From cd0beeaf8ed5fb81ba20375517c7a3bf34fd27cd Mon Sep 17 00:00:00 2001 From: Ophestra Date: Mon, 6 Apr 2026 11:43:16 +0900 Subject: [PATCH] internal/uevent: optionally pass UUID during coldboot This enables rejection of non-coldboot synthetic events. Signed-off-by: Ophestra --- internal/uevent/coldboot.go | 9 ++- internal/uevent/coldboot_test.go | 4 +- internal/uevent/uevent.go | 118 ++++++++++++++++++++++++++++++- internal/uevent/uevent_test.go | 58 ++++++++++++++- 4 files changed, 184 insertions(+), 5 deletions(-) diff --git a/internal/uevent/coldboot.go b/internal/uevent/coldboot.go index 07f6497f..660f0284 100644 --- a/internal/uevent/coldboot.go +++ b/internal/uevent/coldboot.go @@ -7,6 +7,7 @@ import ( "log" "os" "path/filepath" + "slices" ) // synthAdd is prepared bytes written to uevent to cause a synthetic add event @@ -26,6 +27,7 @@ var synthAdd = []byte(KOBJ_ADD.String()) func Coldboot( ctx context.Context, pathname string, + uuid *UUID, visited chan<- string, handleWalkErr func(error) error, ) error { @@ -39,6 +41,11 @@ func Coldboot( } } + add := synthAdd + if uuid != nil { + add = slices.Concat(add, []byte{' '}, []byte(uuid.String())) + } + return filepath.WalkDir(filepath.Join(pathname, "devices"), func( path string, d fs.DirEntry, @@ -54,7 +61,7 @@ func Coldboot( if d.IsDir() || d.Name() != "uevent" { return nil } - if err = os.WriteFile(path, synthAdd, 0); err != nil { + if err = os.WriteFile(path, add, 0); err != nil { return handleWalkErr(err) } diff --git a/internal/uevent/coldboot_test.go b/internal/uevent/coldboot_test.go index e3491e64..4b38453a 100644 --- a/internal/uevent/coldboot_test.go +++ b/internal/uevent/coldboot_test.go @@ -59,7 +59,7 @@ func TestColdboot(t *testing.T) { } }) - err := uevent.Coldboot(t.Context(), d, visited, func(err error) error { + err := uevent.Coldboot(t.Context(), d, nil, visited, func(err error) error { t.Errorf("handleWalkErr: %v", err) return err }) @@ -219,7 +219,7 @@ func TestColdbootError(t *testing.T) { } } - if err := uevent.Coldboot(ctx, d, visited, handleWalkErr); !reflect.DeepEqual(err, wantErr) { + if err := uevent.Coldboot(ctx, d, new(uevent.UUID), visited, handleWalkErr); !reflect.DeepEqual(err, wantErr) { t.Errorf("Coldboot: error = %v, want %v", err, wantErr) } }) diff --git a/internal/uevent/uevent.go b/internal/uevent/uevent.go index 2c37172f..0e85c4cf 100644 --- a/internal/uevent/uevent.go +++ b/internal/uevent/uevent.go @@ -5,10 +5,14 @@ package uevent import ( "context" + "encoding" + "encoding/hex" "errors" + "fmt" "strconv" "sync/atomic" "syscall" + "unsafe" "hakurei.app/internal/netlink" ) @@ -121,6 +125,117 @@ func (c *Conn) receiveEvent(ctx context.Context) (*Message, error) { return &msg, err } +// UUID represents the value of SYNTH_UUID. +// +// This is not a generic UUID implementation. Do not attempt to use it for +// anything other than passing and interpreting the SYNTH_UUID environment +// variable of a uevent. +type UUID [16]byte + +const ( + // SizeUUID is the fixed size of string representation of [UUID] according + // to Documentation/ABI/testing/sysfs-uevent. + SizeUUID = 4 + len(UUID{})*2 + + // UUIDSep is the separator byte of [UUID]. + UUIDSep = '-' +) + +var ( + _ encoding.TextAppender = new(UUID) + _ encoding.TextMarshaler = new(UUID) + _ encoding.TextUnmarshaler = new(UUID) +) + +// String formats uuid according to Documentation/ABI/testing/sysfs-uevent. +func (uuid *UUID) String() string { + s := make([]byte, 0, SizeUUID) + s = hex.AppendEncode(s, uuid[:4]) + s = append(s, UUIDSep) + s = hex.AppendEncode(s, uuid[4:6]) + s = append(s, UUIDSep) + s = hex.AppendEncode(s, uuid[6:8]) + s = append(s, UUIDSep) + s = hex.AppendEncode(s, uuid[8:10]) + s = append(s, UUIDSep) + s = hex.AppendEncode(s, uuid[10:16]) + return unsafe.String(unsafe.SliceData(s), len(s)) +} + +func (uuid *UUID) AppendText(data []byte) ([]byte, error) { + return append(data, uuid.String()...), nil +} +func (uuid *UUID) MarshalText() ([]byte, error) { + return uuid.AppendText(nil) +} + +var ( + // ErrAutoUUID is returned parsing a SYNTH_UUID generated by the kernel for + // a synthetic event without a UUID passed in. + ErrAutoUUID = errors.New("UUID is not passed in") +) + +// UUIDSizeError describes an incorrectly sized string representation of [UUID]. +type UUIDSizeError int + +func (e UUIDSizeError) Error() string { + return "got " + strconv.Itoa(int(e)) + " bytes " + + "instead of " + strconv.Itoa(SizeUUID) +} + +// UUIDSeparatorError is an invalid separator in a malformed string +// representation of [UUID]. +type UUIDSeparatorError byte + +func (e UUIDSeparatorError) Error() string { + return fmt.Sprintf("invalid UUID separator: %#U", rune(e)) +} + +// UnmarshalText parses data according to Documentation/ABI/testing/sysfs-uevent. +func (uuid *UUID) UnmarshalText(data []byte) (err error) { + if len(data) == 1 && data[0] == '0' { + return ErrAutoUUID + } + if len(data) != SizeUUID { + return UUIDSizeError(len(data)) + } + + if _, err = hex.Decode(uuid[:], data[:8]); err != nil { + return + } + if data[8] != UUIDSep { + return UUIDSeparatorError(data[8]) + } + data = data[9:] + + if _, err = hex.Decode(uuid[4:], data[:4]); err != nil { + return + } + if data[4] != UUIDSep { + return UUIDSeparatorError(data[4]) + } + data = data[5:] + + if _, err = hex.Decode(uuid[6:], data[:4]); err != nil { + return + } + if data[4] != UUIDSep { + return UUIDSeparatorError(data[4]) + } + data = data[5:] + + if _, err = hex.Decode(uuid[8:], data[:4]); err != nil { + return + } + if data[4] != UUIDSep { + return UUIDSeparatorError(data[4]) + } + data = data[5:] + + _, err = hex.Decode(uuid[10:], data) + return +} + // Consume continuously receives and parses events from the kernel and handles // [Recoverable] and [NeedsColdboot] errors via caller-supplied functions, // entering coldboot when required. @@ -145,6 +260,7 @@ func (c *Conn) receiveEvent(ctx context.Context) (*Message, error) { func (c *Conn) Consume( ctx context.Context, sysfs string, + uuid *UUID, events chan<- *Message, coldboot bool, @@ -193,7 +309,7 @@ coldboot: ctxColdboot, cancelColdboot := context.WithCancel(ctx) var coldbootErr error go func() { - coldbootErr = Coldboot(ctxColdboot, sysfs, visited, handleWalkErr) + coldbootErr = Coldboot(ctxColdboot, sysfs, uuid, visited, handleWalkErr) close(visited) }() for pathname := range visited { diff --git a/internal/uevent/uevent_test.go b/internal/uevent/uevent_test.go index 6420eea1..d7d938e4 100644 --- a/internal/uevent/uevent_test.go +++ b/internal/uevent/uevent_test.go @@ -3,8 +3,10 @@ package uevent_test import ( "context" "encoding" + "encoding/hex" "os" "reflect" + "strings" "sync" "syscall" "testing" @@ -23,6 +25,12 @@ func adeT[V any, S interface { *V }](t *testing.T, name string, v V, want string, wantErr, wantErrE error) { t.Helper() + + noEncode := strings.HasSuffix(name, "\x00") + if noEncode { + name = name[:len(name)-1] + } + f := func(t *testing.T) { if name != "" { t.Parallel() @@ -46,6 +54,10 @@ func adeT[V any, S interface { } }) + if noEncode { + return + } + t.Run("encode", func(t *testing.T) { t.Parallel() t.Helper() @@ -114,6 +126,45 @@ func adeB[V any, S interface { } } +func TestUUID(t *testing.T) { + t.Parallel() + + adeT(t, "sample", uevent.UUID{ + 0xfe, 0x4d, 0x7c, 0x9d, + 0xb8, 0xc6, + 0x4a, 0x70, + 0x9e, 0xf1, + 0x3d, 0x8a, 0x58, 0xd1, 0x8e, 0xed, + }, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18eed", nil, nil) + + adeT(t, "auto\x00", uevent.UUID{}, "0", uevent.ErrAutoUUID, nil) + adeT(t, "short\x00", uevent.UUID{}, "1", uevent.UUIDSizeError(1), nil) + + adeT(t, "bad0\x00", uevent.UUID{}, "fe4d7c9\x00-b8c6-4a70-9ef1-3d8a58d18eed", + hex.InvalidByteError(0), nil) + adeT(t, "sep0\x00", uevent.UUID{}, "fe4d7c9d\x00b8c6-4a70-9ef1-3d8a58d18eed", + uevent.UUIDSeparatorError(0), nil) + + adeT(t, "bad1\x00", uevent.UUID{}, "fe4d7c9d-b8c\x00-4a70-9ef1-3d8a58d18eed", + hex.InvalidByteError(0), nil) + adeT(t, "sep1\x00", uevent.UUID{}, "fe4d7c9d-b8c6\x004a70-9ef1-3d8a58d18eed", + uevent.UUIDSeparatorError(0), nil) + + adeT(t, "bad2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a7\x00-9ef1-3d8a58d18eed", + hex.InvalidByteError(0), nil) + adeT(t, "sep2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70\x009ef1-3d8a58d18eed", + uevent.UUIDSeparatorError(0), nil) + + adeT(t, "bad3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef\x00-3d8a58d18eed", + hex.InvalidByteError(0), nil) + adeT(t, "sep3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1\x003d8a58d18eed", + uevent.UUIDSeparatorError(0), nil) + + adeT(t, "bad4\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18ee\x00", + hex.InvalidByteError(0), nil) + +} + func TestDialConsume(t *testing.T) { t.Parallel() @@ -157,7 +208,7 @@ func TestDialConsume(t *testing.T) { defer cancel() consume := func(c *uevent.Conn, ctx context.Context) error { - return c.Consume(ctx, fhs.Sys, events, false, func(path string) { + return c.Consume(ctx, fhs.Sys, nil, events, false, func(path string) { t.Log("coldboot visited", path) }, func(err error) bool { t.Log(err) @@ -244,6 +295,11 @@ func TestErrors(t *testing.T) { {"BadPortError", &uevent.BadPortError{ Pid: 1, }, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"}, + + {"UUIDSizeError", uevent.UUIDSizeError(0xbad), + "got 2989 bytes instead of 36"}, + {"UUIDSeparatorError", uevent.UUIDSeparatorError(0xfd), + "invalid UUID separator: U+00FD 'ý'"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) {