internal/uevent: optionally pass UUID during coldboot
All checks were successful
Test / Create distribution (push) Successful in 1m3s
Test / Sandbox (push) Successful in 2m42s
Test / Hakurei (push) Successful in 3m49s
Test / ShareFS (push) Successful in 3m47s
Test / Sandbox (race detector) (push) Successful in 5m12s
Test / Hakurei (race detector) (push) Successful in 6m20s
Test / Flake checks (push) Successful in 1m20s

This enables rejection of non-coldboot synthetic events.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-04-06 11:43:16 +09:00
parent a69273ab2a
commit cd0beeaf8e
4 changed files with 184 additions and 5 deletions

View File

@@ -7,6 +7,7 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"slices"
) )
// synthAdd is prepared bytes written to uevent to cause a synthetic add event // synthAdd is prepared bytes written to uevent to cause a synthetic add event
@@ -26,6 +27,7 @@ var synthAdd = []byte(KOBJ_ADD.String())
func Coldboot( func Coldboot(
ctx context.Context, ctx context.Context,
pathname string, pathname string,
uuid *UUID,
visited chan<- string, visited chan<- string,
handleWalkErr func(error) error, handleWalkErr func(error) error,
) error { ) error {
@@ -39,6 +41,11 @@ func Coldboot(
} }
} }
add := synthAdd
if uuid != nil {
add = slices.Concat(add, []byte{' '}, []byte(uuid.String()))
}
return filepath.WalkDir(filepath.Join(pathname, "devices"), func( return filepath.WalkDir(filepath.Join(pathname, "devices"), func(
path string, path string,
d fs.DirEntry, d fs.DirEntry,
@@ -54,7 +61,7 @@ func Coldboot(
if d.IsDir() || d.Name() != "uevent" { if d.IsDir() || d.Name() != "uevent" {
return nil return nil
} }
if err = os.WriteFile(path, synthAdd, 0); err != nil { if err = os.WriteFile(path, add, 0); err != nil {
return handleWalkErr(err) return handleWalkErr(err)
} }

View File

@@ -59,7 +59,7 @@ func TestColdboot(t *testing.T) {
} }
}) })
err := uevent.Coldboot(t.Context(), d, visited, func(err error) error { err := uevent.Coldboot(t.Context(), d, nil, visited, func(err error) error {
t.Errorf("handleWalkErr: %v", err) t.Errorf("handleWalkErr: %v", err)
return err return err
}) })
@@ -219,7 +219,7 @@ func TestColdbootError(t *testing.T) {
} }
} }
if err := uevent.Coldboot(ctx, d, visited, handleWalkErr); !reflect.DeepEqual(err, wantErr) { if err := uevent.Coldboot(ctx, d, new(uevent.UUID), visited, handleWalkErr); !reflect.DeepEqual(err, wantErr) {
t.Errorf("Coldboot: error = %v, want %v", err, wantErr) t.Errorf("Coldboot: error = %v, want %v", err, wantErr)
} }
}) })

View File

@@ -5,10 +5,14 @@ package uevent
import ( import (
"context" "context"
"encoding"
"encoding/hex"
"errors" "errors"
"fmt"
"strconv" "strconv"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"unsafe"
"hakurei.app/internal/netlink" "hakurei.app/internal/netlink"
) )
@@ -121,6 +125,117 @@ func (c *Conn) receiveEvent(ctx context.Context) (*Message, error) {
return &msg, err return &msg, err
} }
// UUID represents the value of SYNTH_UUID.
//
// This is not a generic UUID implementation. Do not attempt to use it for
// anything other than passing and interpreting the SYNTH_UUID environment
// variable of a uevent.
type UUID [16]byte
const (
// SizeUUID is the fixed size of string representation of [UUID] according
// to Documentation/ABI/testing/sysfs-uevent.
SizeUUID = 4 + len(UUID{})*2
// UUIDSep is the separator byte of [UUID].
UUIDSep = '-'
)
var (
_ encoding.TextAppender = new(UUID)
_ encoding.TextMarshaler = new(UUID)
_ encoding.TextUnmarshaler = new(UUID)
)
// String formats uuid according to Documentation/ABI/testing/sysfs-uevent.
func (uuid *UUID) String() string {
s := make([]byte, 0, SizeUUID)
s = hex.AppendEncode(s, uuid[:4])
s = append(s, UUIDSep)
s = hex.AppendEncode(s, uuid[4:6])
s = append(s, UUIDSep)
s = hex.AppendEncode(s, uuid[6:8])
s = append(s, UUIDSep)
s = hex.AppendEncode(s, uuid[8:10])
s = append(s, UUIDSep)
s = hex.AppendEncode(s, uuid[10:16])
return unsafe.String(unsafe.SliceData(s), len(s))
}
func (uuid *UUID) AppendText(data []byte) ([]byte, error) {
return append(data, uuid.String()...), nil
}
func (uuid *UUID) MarshalText() ([]byte, error) {
return uuid.AppendText(nil)
}
var (
// ErrAutoUUID is returned parsing a SYNTH_UUID generated by the kernel for
// a synthetic event without a UUID passed in.
ErrAutoUUID = errors.New("UUID is not passed in")
)
// UUIDSizeError describes an incorrectly sized string representation of [UUID].
type UUIDSizeError int
func (e UUIDSizeError) Error() string {
return "got " + strconv.Itoa(int(e)) + " bytes " +
"instead of " + strconv.Itoa(SizeUUID)
}
// UUIDSeparatorError is an invalid separator in a malformed string
// representation of [UUID].
type UUIDSeparatorError byte
func (e UUIDSeparatorError) Error() string {
return fmt.Sprintf("invalid UUID separator: %#U", rune(e))
}
// UnmarshalText parses data according to Documentation/ABI/testing/sysfs-uevent.
func (uuid *UUID) UnmarshalText(data []byte) (err error) {
if len(data) == 1 && data[0] == '0' {
return ErrAutoUUID
}
if len(data) != SizeUUID {
return UUIDSizeError(len(data))
}
if _, err = hex.Decode(uuid[:], data[:8]); err != nil {
return
}
if data[8] != UUIDSep {
return UUIDSeparatorError(data[8])
}
data = data[9:]
if _, err = hex.Decode(uuid[4:], data[:4]); err != nil {
return
}
if data[4] != UUIDSep {
return UUIDSeparatorError(data[4])
}
data = data[5:]
if _, err = hex.Decode(uuid[6:], data[:4]); err != nil {
return
}
if data[4] != UUIDSep {
return UUIDSeparatorError(data[4])
}
data = data[5:]
if _, err = hex.Decode(uuid[8:], data[:4]); err != nil {
return
}
if data[4] != UUIDSep {
return UUIDSeparatorError(data[4])
}
data = data[5:]
_, err = hex.Decode(uuid[10:], data)
return
}
// Consume continuously receives and parses events from the kernel and handles // Consume continuously receives and parses events from the kernel and handles
// [Recoverable] and [NeedsColdboot] errors via caller-supplied functions, // [Recoverable] and [NeedsColdboot] errors via caller-supplied functions,
// entering coldboot when required. // entering coldboot when required.
@@ -145,6 +260,7 @@ func (c *Conn) receiveEvent(ctx context.Context) (*Message, error) {
func (c *Conn) Consume( func (c *Conn) Consume(
ctx context.Context, ctx context.Context,
sysfs string, sysfs string,
uuid *UUID,
events chan<- *Message, events chan<- *Message,
coldboot bool, coldboot bool,
@@ -193,7 +309,7 @@ coldboot:
ctxColdboot, cancelColdboot := context.WithCancel(ctx) ctxColdboot, cancelColdboot := context.WithCancel(ctx)
var coldbootErr error var coldbootErr error
go func() { go func() {
coldbootErr = Coldboot(ctxColdboot, sysfs, visited, handleWalkErr) coldbootErr = Coldboot(ctxColdboot, sysfs, uuid, visited, handleWalkErr)
close(visited) close(visited)
}() }()
for pathname := range visited { for pathname := range visited {

View File

@@ -3,8 +3,10 @@ package uevent_test
import ( import (
"context" "context"
"encoding" "encoding"
"encoding/hex"
"os" "os"
"reflect" "reflect"
"strings"
"sync" "sync"
"syscall" "syscall"
"testing" "testing"
@@ -23,6 +25,12 @@ func adeT[V any, S interface {
*V *V
}](t *testing.T, name string, v V, want string, wantErr, wantErrE error) { }](t *testing.T, name string, v V, want string, wantErr, wantErrE error) {
t.Helper() t.Helper()
noEncode := strings.HasSuffix(name, "\x00")
if noEncode {
name = name[:len(name)-1]
}
f := func(t *testing.T) { f := func(t *testing.T) {
if name != "" { if name != "" {
t.Parallel() t.Parallel()
@@ -46,6 +54,10 @@ func adeT[V any, S interface {
} }
}) })
if noEncode {
return
}
t.Run("encode", func(t *testing.T) { t.Run("encode", func(t *testing.T) {
t.Parallel() t.Parallel()
t.Helper() t.Helper()
@@ -114,6 +126,45 @@ func adeB[V any, S interface {
} }
} }
func TestUUID(t *testing.T) {
t.Parallel()
adeT(t, "sample", uevent.UUID{
0xfe, 0x4d, 0x7c, 0x9d,
0xb8, 0xc6,
0x4a, 0x70,
0x9e, 0xf1,
0x3d, 0x8a, 0x58, 0xd1, 0x8e, 0xed,
}, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18eed", nil, nil)
adeT(t, "auto\x00", uevent.UUID{}, "0", uevent.ErrAutoUUID, nil)
adeT(t, "short\x00", uevent.UUID{}, "1", uevent.UUIDSizeError(1), nil)
adeT(t, "bad0\x00", uevent.UUID{}, "fe4d7c9\x00-b8c6-4a70-9ef1-3d8a58d18eed",
hex.InvalidByteError(0), nil)
adeT(t, "sep0\x00", uevent.UUID{}, "fe4d7c9d\x00b8c6-4a70-9ef1-3d8a58d18eed",
uevent.UUIDSeparatorError(0), nil)
adeT(t, "bad1\x00", uevent.UUID{}, "fe4d7c9d-b8c\x00-4a70-9ef1-3d8a58d18eed",
hex.InvalidByteError(0), nil)
adeT(t, "sep1\x00", uevent.UUID{}, "fe4d7c9d-b8c6\x004a70-9ef1-3d8a58d18eed",
uevent.UUIDSeparatorError(0), nil)
adeT(t, "bad2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a7\x00-9ef1-3d8a58d18eed",
hex.InvalidByteError(0), nil)
adeT(t, "sep2\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70\x009ef1-3d8a58d18eed",
uevent.UUIDSeparatorError(0), nil)
adeT(t, "bad3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef\x00-3d8a58d18eed",
hex.InvalidByteError(0), nil)
adeT(t, "sep3\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1\x003d8a58d18eed",
uevent.UUIDSeparatorError(0), nil)
adeT(t, "bad4\x00", uevent.UUID{}, "fe4d7c9d-b8c6-4a70-9ef1-3d8a58d18ee\x00",
hex.InvalidByteError(0), nil)
}
func TestDialConsume(t *testing.T) { func TestDialConsume(t *testing.T) {
t.Parallel() t.Parallel()
@@ -157,7 +208,7 @@ func TestDialConsume(t *testing.T) {
defer cancel() defer cancel()
consume := func(c *uevent.Conn, ctx context.Context) error { consume := func(c *uevent.Conn, ctx context.Context) error {
return c.Consume(ctx, fhs.Sys, events, false, func(path string) { return c.Consume(ctx, fhs.Sys, nil, events, false, func(path string) {
t.Log("coldboot visited", path) t.Log("coldboot visited", path)
}, func(err error) bool { }, func(err error) bool {
t.Log(err) t.Log(err)
@@ -244,6 +295,11 @@ func TestErrors(t *testing.T) {
{"BadPortError", &uevent.BadPortError{ {"BadPortError", &uevent.BadPortError{
Pid: 1, Pid: 1,
}, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"}, }, "unexpected message from port id 1 on NETLINK_KOBJECT_UEVENT"},
{"UUIDSizeError", uevent.UUIDSizeError(0xbad),
"got 2989 bytes instead of 36"},
{"UUIDSeparatorError", uevent.UUIDSeparatorError(0xfd),
"invalid UUID separator: U+00FD 'ý'"},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {