diff --git a/internal/kobject/kobject.go b/internal/kobject/kobject.go index f8f040e3..72d8f052 100644 --- a/internal/kobject/kobject.go +++ b/internal/kobject/kobject.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "maps" + "slices" "strconv" "sync" @@ -96,6 +97,13 @@ func (o *Object) update(env map[string]string, strip bool) { } } +// A pendingIterator is a callback currently iterating through objects targeted +// by ongoing events. +type pendingIterator struct { + f func(o *Object) bool + done chan<- struct{} +} + // State processes a stream of [Event] populated from [uevent.Message] received // from a NETLINK_KOBJECT_UEVENT socket and presents an efficient representation // of kernel state. @@ -106,6 +114,10 @@ type State struct { uevent map[string]*Object // Synchronises access to uevent and its objects. ueventMu sync.RWMutex + // Alive iterators. + iter []*pendingIterator + // Synchronises access to iter. + iterMu sync.Mutex // UUID for synthetic [uevent.Coldboot] events. coldboot uevent.UUID // Called on [uevent.KOBJ_CHANGE] with stripped environment variables. @@ -129,6 +141,59 @@ func New( } } +// deleteIter removes an iterator from s. Must be called after acquiring iterMu. +func (s *State) deleteIter(p *pendingIterator) { + s.iter = slices.DeleteFunc(s.iter, func(v *pendingIterator) bool { + return p == v + }) +} + +// dispatchIter broadcasts an [Object] to all alive iterators. +func (s *State) dispatchIter(o *Object) { + s.iterMu.Lock() + defer s.iterMu.Unlock() + + for _, p := range s.iter { + if !p.f(o) { + s.deleteIter(p) + close(p.done) + } + } +} + +// Range calls f on all current and upcoming [Object] values tracked by s until +// f returns false or the context is cancelled. f must not retain o or modify +// the value it points to. +func (s *State) Range(ctx context.Context, f func(o *Object) bool) { + done := make(chan struct{}) + p := pendingIterator{f, done} + + s.iterMu.Lock() + s.ueventMu.RLock() + for _, o := range s.uevent { + if !f(o) { + s.ueventMu.RUnlock() + s.iterMu.Unlock() + return + } + } + s.ueventMu.RUnlock() + s.iter = append(s.iter, &p) + s.iterMu.Unlock() + + select { + case <-ctx.Done(): + s.iterMu.Lock() + s.deleteIter(&p) + s.iterMu.Unlock() + return + + case <-done: + // deregistered by dispatchIter + return + } +} + // UnexpectedColdbootError is reported by [State.Consume] for a coldboot event // with action other than the expected [uevent.KOBJ_ADD]. type UnexpectedColdbootError Event @@ -219,6 +284,7 @@ func (s *State) processEvent(e *Event) { if o, ok := s.uevent[e.DevPath]; ok { s.reportErr(DuplicateAddError(e.Clone())) o.merge(e.Env) + s.dispatchIter(o) return } } @@ -228,6 +294,7 @@ func (s *State) processEvent(e *Event) { } o.merge(e.Env) s.uevent[e.DevPath] = o + s.dispatchIter(o) return case uevent.KOBJ_REMOVE: @@ -253,12 +320,14 @@ func (s *State) processEvent(e *Event) { o = e.makeColdboot() o.merge(e.Env) s.uevent[e.DevPath] = o + s.dispatchIter(o) return } o.update(e.Env, true) if s.handleChange != nil { s.handleChange(o, e.Env) } + s.dispatchIter(o) return case uevent.KOBJ_MOVE: @@ -281,6 +350,7 @@ func (s *State) processEvent(e *Event) { o.merge(e.Env) s.uevent[e.DevPath] = o o.DevPath = e.DevPath + s.dispatchIter(o) return case uevent.KOBJ_ONLINE: @@ -296,6 +366,7 @@ func (s *State) processEvent(e *Event) { s.reportErr(UnexpectedOfflineError(o.Clone())) } o.Offline = false + s.dispatchIter(o) return case uevent.KOBJ_OFFLINE: @@ -311,6 +382,7 @@ func (s *State) processEvent(e *Event) { s.reportErr(UnexpectedOfflineError(o.Clone())) } o.Offline = true + s.dispatchIter(o) return case uevent.KOBJ_BIND: @@ -326,6 +398,7 @@ func (s *State) processEvent(e *Event) { } o.State = StateBound o.merge(e.Env) + s.dispatchIter(o) return case uevent.KOBJ_UNBIND: @@ -341,6 +414,7 @@ func (s *State) processEvent(e *Event) { } o.State = StateNew o.Driver = "" + s.dispatchIter(o) return default: // not reached diff --git a/internal/kobject/kobject_test.go b/internal/kobject/kobject_test.go index 892ff241..f04db2d1 100644 --- a/internal/kobject/kobject_test.go +++ b/internal/kobject/kobject_test.go @@ -2,6 +2,7 @@ package kobject_test import ( "bytes" + "context" _ "embed" "encoding/json" "fmt" @@ -10,7 +11,9 @@ import ( "reflect" "slices" "strings" + "sync" "testing" + "testing/synctest" "unsafe" . "hakurei.app/internal/kobject" @@ -335,6 +338,90 @@ func TestConsume(t *testing.T) { } } +func TestIter(t *testing.T) { + t.Parallel() + synctest.Test(t, func(t *testing.T) { + s := New(uevent.UUID{}, nil, nil) + var wg sync.WaitGroup + defer wg.Wait() + events := make(chan *uevent.Message) + defer close(events) + wg.Go(func() { s.Consume(t.Context(), events) }) + + events <- &uevent.Message{Action: uevent.KOBJ_ADD, DevPath: "\x00", Env: []string{ + "V=\xfd", + "SEQNUM=0", + }} + events <- &uevent.Message{Action: uevent.KOBJ_ADD, DevPath: "\x01", Env: []string{ + "V=\xfc", + "SEQNUM=1", + }} + synctest.Wait() + s.Range(t.Context(), func(o *Object) bool { return false }) + + var got []*Object + check := func(want []*Object) { + slices.SortFunc(got, func(a, b *Object) int { + return strings.Compare(a.DevPath, b.DevPath) + }) + if !reflect.DeepEqual(got, want) { + t.Fatalf("Range: %#v, want %#v", got, want) + } + got = nil + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + var done bool + wg.Go(func() { + s.Range(ctx, func(o *Object) bool { + got = append(got, new(o.Clone())) + return !done + }) + }) + synctest.Wait() + check([]*Object{ + { + State: StateNew, + DevPath: "\x00", + Env: map[string]string{"V": "\xfd"}, + }, + { + State: StateNew, + DevPath: "\x01", + Env: map[string]string{"V": "\xfc"}, + }, + }) + + done = true + events <- &uevent.Message{Action: uevent.KOBJ_MOVE, DevPath: " ", Env: []string{ + "DEVPATH_OLD=\x01", + }} + synctest.Wait() + check([]*Object{ + { + State: StateNew, + DevPath: " ", + Env: map[string]string{"V": "\xfc"}, + }, + }) + + wg.Go(func() { s.Range(ctx, func(*Object) bool { return true }) }) + synctest.Wait() + + iter := reflect.ValueOf(s).Elem().FieldByName("iter") + if l := iter.Len(); l != 1 { + t.Errorf("len(s.iter): %d", l) + } + + cancel() + synctest.Wait() + if l := iter.Len(); l != 0 { + t.Errorf("len(s.iter): %d", l) + } + }) +} + func TestErrors(t *testing.T) { t.Parallel()