From c34439fc5fe52db3f346131bd6398da8a69fe0bf Mon Sep 17 00:00:00 2001 From: Ophestra Date: Wed, 3 Dec 2025 01:29:19 +0900 Subject: [PATCH] internal/pipewire: collect non-protocol errors These errors are recoverable and should not terminate event handling. Only terminate event handling for protocol errors or inconsistent state that makes further event handling impossible. Signed-off-by: Ophestra --- internal/pipewire/client.go | 10 +- internal/pipewire/core.go | 36 +++---- internal/pipewire/pipewire.go | 154 ++++++++++++++++++++++----- internal/pipewire/pipewire_test.go | 65 +++++++++++ internal/pipewire/securitycontext.go | 7 +- 5 files changed, 211 insertions(+), 61 deletions(-) diff --git a/internal/pipewire/client.go b/internal/pipewire/client.go index 7c619ca..8e901d8 100644 --- a/internal/pipewire/client.go +++ b/internal/pipewire/client.go @@ -94,14 +94,12 @@ type Client struct { Properties SPADict `json:"props"` } -func (client *Client) consume(opcode byte, files []int, unmarshal func(v any) error) error { - if err := closeReceivedFiles(files...); err != nil { - return err - } - +func (client *Client) consume(opcode byte, files []int, unmarshal func(v any)) error { + closeReceivedFiles(files...) switch opcode { case PW_CLIENT_EVENT_INFO: - return unmarshal(&client.Info) + unmarshal(&client.Info) + return nil default: return &UnsupportedOpcodeError{opcode, client.String()} diff --git a/internal/pipewire/core.go b/internal/pipewire/core.go index 671bbd5..435aa89 100644 --- a/internal/pipewire/core.go +++ b/internal/pipewire/core.go @@ -525,20 +525,16 @@ func (e *UnknownBoundIdError[E]) Error() string { return "unknown bound proxy id " + strconv.Itoa(int(e.Id)) } -func (core *Core) consume(opcode byte, files []int, unmarshal func(v any) error) error { - if err := closeReceivedFiles(files...); err != nil { - return err - } - +func (core *Core) consume(opcode byte, files []int, unmarshal func(v any)) error { + closeReceivedFiles(files...) switch opcode { case PW_CORE_EVENT_INFO: - return unmarshal(&core.Info) + unmarshal(&core.Info) + return nil case PW_CORE_EVENT_DONE: var done CoreDone - if err := unmarshal(&done); err != nil { - return err - } + unmarshal(&done) if done.ID == roundtripSyncID && done.Sequence == CoreSyncSequenceOffset+core.ctx.sequence-1 { if core.done { return ErrUnexpectedDone @@ -553,16 +549,12 @@ func (core *Core) consume(opcode byte, files []int, unmarshal func(v any) error) case PW_CORE_EVENT_ERROR: var coreError CoreError - if err := unmarshal(&coreError); err != nil { - return err - } + unmarshal(&coreError) return &coreError case PW_CORE_EVENT_BOUND_PROPS: var boundProps CoreBoundProps - if err := unmarshal(&boundProps); err != nil { - return err - } + unmarshal(&boundProps) delete(core.ctx.pendingIds, boundProps.ID) proxy, ok := core.ctx.proxy[boundProps.ID] @@ -606,19 +598,15 @@ func (e *GlobalIDCollisionError) Error() string { " stepping on previous id " + strconv.Itoa(int(e.ID)) + " for " + e.Previous.Type } -func (registry *Registry) consume(opcode byte, files []int, unmarshal func(v any) error) error { - if err := closeReceivedFiles(files...); err != nil { - return err - } - +func (registry *Registry) consume(opcode byte, files []int, unmarshal func(v any)) error { + closeReceivedFiles(files...) switch opcode { case PW_REGISTRY_EVENT_GLOBAL: var global RegistryGlobal - if err := unmarshal(&global); err != nil { - return err - } + unmarshal(&global) if object, ok := registry.Objects[global.ID]; ok { - return &GlobalIDCollisionError{global.ID, &object, &global} + // this should never happen so is non-recoverable if it does + panic(&GlobalIDCollisionError{global.ID, &object, &global}) } registry.Objects[global.ID] = global return nil diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index f1a950c..0f0f98f 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -16,10 +16,12 @@ package pipewire import ( "encoding/binary" + "errors" "fmt" "io" "maps" "net" + "runtime" "slices" "strconv" "syscall" @@ -310,27 +312,35 @@ func (e UnsupportedFooterOpcodeError) Error() string { return "unsupported footer opcode " + strconv.Itoa(int(e)) } -// RoundtripUnexpectedEOFError is returned when EOF was unexpectedly encountered during [Context.Roundtrip]. +// A RoundtripUnexpectedEOFError describes an unexpected EOF encountered during [Context.Roundtrip]. type RoundtripUnexpectedEOFError uintptr const ( - roundtripEOFHeader RoundtripUnexpectedEOFError = iota - roundtripEOFBody - roundtripEOFFooter - roundtripEOFFooterOpcode + // ErrRoundtripEOFHeader is returned when unexpectedly encountering EOF + // decoding the message header. + ErrRoundtripEOFHeader RoundtripUnexpectedEOFError = iota + // ErrRoundtripEOFBody is returned when unexpectedly encountering EOF + // establishing message body bounds. + ErrRoundtripEOFBody + // ErrRoundtripEOFFooter is like [ErrRoundtripEOFBody], but for when establishing + // bounds for the footer instead. + ErrRoundtripEOFFooter + // ErrRoundtripEOFFooterOpcode is returned when unexpectedly encountering EOF + // during the footer opcode hack. + ErrRoundtripEOFFooterOpcode ) func (RoundtripUnexpectedEOFError) Unwrap() error { return io.ErrUnexpectedEOF } func (e RoundtripUnexpectedEOFError) Error() string { var suffix string switch e { - case roundtripEOFHeader: + case ErrRoundtripEOFHeader: suffix = "decoding message header" - case roundtripEOFBody: + case ErrRoundtripEOFBody: suffix = "establishing message body bounds" - case roundtripEOFFooter: + case ErrRoundtripEOFFooter: suffix = "establishing message footer bounds" - case roundtripEOFFooterOpcode: + case ErrRoundtripEOFFooterOpcode: suffix = "decoding message footer opcode" default: @@ -343,7 +353,7 @@ func (e RoundtripUnexpectedEOFError) Error() string { // eventProxy consumes events during a [Context.Roundtrip]. type eventProxy interface { // consume consumes an event and its optional footer. - consume(opcode byte, files []int, unmarshal func(v any) error) error + consume(opcode byte, files []int, unmarshal func(v any)) error // setBoundProps stores a [CoreBoundProps] event received from the server. setBoundProps(event *CoreBoundProps) error @@ -358,7 +368,7 @@ func (ctx *Context) unmarshal(header *Header, data []byte, v any) error { return err } if len(data) < int(header.Size) || header.Size < n { - return roundtripEOFFooter + return ErrRoundtripEOFFooter } isLastMessage := len(data) == int(header.Size) @@ -369,7 +379,7 @@ func (ctx *Context) unmarshal(header *Header, data []byte, v any) error { skip the struct prefix, then the integer prefix, and the next SizeId bytes are the encoded opcode value */ if len(data) < int(SizePrefix*2+SizeId) { - return roundtripEOFFooterOpcode + return ErrRoundtripEOFFooterOpcode } switch opcode := binary.NativeEndian.Uint32(data[SizePrefix*2:]); opcode { case FOOTER_CORE_OPCODE_GENERATION: @@ -433,6 +443,43 @@ func (e UnacknowledgedProxyError) Error() string { return "server did not acknowledge " + strconv.Itoa(len(e)) + " proxies" } +// A ProxyFatalError describes an error that terminates event handling during a +// [Context.Roundtrip] and makes further event processing no longer possible. +type ProxyFatalError struct { + // The fatal error causing the termination of event processing. + Err error + // Previous non-fatal proxy errors. + ProxyErrs []error +} + +func (e *ProxyFatalError) Unwrap() []error { return append(e.ProxyErrs, e.Err) } +func (e *ProxyFatalError) Error() string { + s := e.Err.Error() + if len(e.ProxyErrs) > 0 { + s += "; " + strconv.Itoa(len(e.ProxyErrs)) + " additional proxy errors occurred before this point" + } + return s +} + +// A ProxyConsumeError is a collection of non-protocol errors returned by proxies +// during event processing. These do not prevent event handling from continuing but +// may be considered fatal to the application. +type ProxyConsumeError []error + +func (e ProxyConsumeError) Unwrap() []error { return e } +func (e ProxyConsumeError) Error() string { + if len(e) == 0 { + return "invalid proxy consume error" + } + + // first error is usually the most relevant one + s := e[0].Error() + if len(e) > 1 { + s += "; " + strconv.Itoa(len(e)) + " additional proxy errors occurred after this point" + } + return s +} + // roundtripSyncID is the id passed to Context.coreSync during a [Context.Roundtrip]. const roundtripSyncID = 0 @@ -447,6 +494,52 @@ func (ctx *Context) Roundtrip() (err error) { if _, _, err = ctx.conn.WriteMsgUnix(ctx.buf, syscall.UnixRights(ctx.pendingFiles...), nil); err != nil { return } + + var ( + // this holds onto non-protocol errors encountered during event handling; + // errors that prevent event processing from continuing must be panicked + proxyErrors ProxyConsumeError + + // current position of processed events in ctx.receivedFiles, anything + // beyond this is closed if event processing is terminated + receivedHeaderFiles int + ) + defer func() { + // anything before this has already been processed and must not be closed + // here, as anything holding onto them will end up with a dangling fd that + // can be reused and cause serious problems + if len(ctx.receivedFiles) > receivedHeaderFiles { + for _, fd := range ctx.receivedFiles[receivedHeaderFiles:] { + _ = syscall.Close(fd) + } + + // this catches cases where Roundtrip somehow returns without processing + // all received files or preparing an error for dangling files, this is + // always overwritten by the fatal error being processed below or made + // inaccessible due to repanicking, so if this ends up returned to the + // caller it indicates something has gone seriously wrong in Roundtrip + if err == nil { + err = syscall.ENOTRECOVERABLE + } + } + + r := recover() + if r == nil { + return + } + + recoveredErr, ok := r.(error) + if !ok { + panic(r) + } + if recoveredErr == nil { + panic(&runtime.PanicNilError{}) + } + + err = &ProxyFatalError{Err: recoveredErr, ProxyErrs: proxyErrors} + return + }() + ctx.buf = ctx.buf[:0] ctx.pendingFiles = ctx.pendingFiles[:0] ctx.headerFiles = 0 @@ -457,10 +550,9 @@ func (ctx *Context) Roundtrip() (err error) { } var header Header - var receivedHeaderFiles int for len(data) > 0 { if len(data) < SizeHeader { - return roundtripEOFHeader + return ErrRoundtripEOFHeader } if err = header.UnmarshalBinary(data[:SizeHeader]); err != nil { @@ -472,7 +564,7 @@ func (ctx *Context) Roundtrip() (err error) { ctx.remoteSequence++ if len(data) < int(SizeHeader+header.Size) { - return roundtripEOFBody + return ErrRoundtripEOFBody } proxy, ok := ctx.proxy[header.ID] @@ -488,15 +580,26 @@ func (ctx *Context) Roundtrip() (err error) { receivedHeaderFiles = nextReceivedHeaderFiles data = data[SizeHeader:] - err = proxy.consume(header.Opcode, files, func(v any) error { return ctx.unmarshal(&header, data, v) }) + proxyErr := proxy.consume(header.Opcode, files, func(v any) { + if unmarshalErr := ctx.unmarshal(&header, data, v); unmarshalErr != nil { + panic(unmarshalErr) + } + }) data = data[header.Size:] - if err != nil { - return + if proxyErr != nil { + proxyErrors = append(proxyErrors, proxyErr) } } + var joinError []error + if len(proxyErrors) > 0 { + joinError = append(joinError, proxyErrors) + } if len(ctx.receivedFiles) < receivedHeaderFiles { - return DanglingFilesError(ctx.receivedFiles[len(ctx.receivedFiles)-receivedHeaderFiles:]) + joinError = append(joinError, DanglingFilesError(ctx.receivedFiles[len(ctx.receivedFiles)-receivedHeaderFiles:])) + } + if len(joinError) > 0 { + return errors.Join(joinError...) } if len(ctx.pendingIds) != 0 { @@ -505,24 +608,23 @@ func (ctx *Context) Roundtrip() (err error) { return } -// An UnexpectedFileCountError is returned for an event that received an unexpected -// number of files. The proxy closes these extra files before returning +// An UnexpectedFileCountError is returned as part of a [ProxyFatalError] for an event +// that received an unexpected number of files. type UnexpectedFileCountError [2]int func (e *UnexpectedFileCountError) Error() string { return "received " + strconv.Itoa(e[1]) + " files instead of the expected " + strconv.Itoa(e[0]) } -// closeReceivedFiles closes all received files and returns [UnexpectedFileCountError] +// closeReceivedFiles closes all received files and panics with [UnexpectedFileCountError] // if one or more files are passed. This is used with events that do not expect files. -func closeReceivedFiles(fds ...int) error { +func closeReceivedFiles(fds ...int) { for _, fd := range fds { _ = syscall.Close(fd) } - if len(fds) == 0 { - return nil + if len(fds) > 0 { + panic(&UnexpectedFileCountError{0, len(fds)}) } - return &UnexpectedFileCountError{0, len(fds)} } // Close frees the underlying buffer and closes the connection. diff --git a/internal/pipewire/pipewire_test.go b/internal/pipewire/pipewire_test.go index 950311d..f7df3ab 100644 --- a/internal/pipewire/pipewire_test.go +++ b/internal/pipewire/pipewire_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "hakurei.app/container/stub" "hakurei.app/internal/pipewire" ) @@ -862,3 +863,67 @@ func splitMessages(iovec string) (messages [][3][]byte) { } return } + +func TestContextErrors(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + err error + want string + }{ + {"ProxyConsumeError invalid", pipewire.ProxyConsumeError{}, "invalid proxy consume error"}, + {"ProxyConsumeError single", pipewire.ProxyConsumeError{ + stub.UniqueError(0), + }, "unique error 0 injected by the test suite"}, + {"ProxyConsumeError multiple", pipewire.ProxyConsumeError{ + stub.UniqueError(1), + stub.UniqueError(2), + stub.UniqueError(3), + stub.UniqueError(4), + stub.UniqueError(5), + stub.UniqueError(6), + stub.UniqueError(7), + }, "unique error 1 injected by the test suite; 7 additional proxy errors occurred after this point"}, + + {"ProxyFatalError", &pipewire.ProxyFatalError{ + Err: stub.UniqueError(8), + }, "unique error 8 injected by the test suite"}, + {"ProxyFatalError proxy errors", &pipewire.ProxyFatalError{ + Err: stub.UniqueError(9), + ProxyErrs: make([]error, 1<<4), + }, "unique error 9 injected by the test suite; 16 additional proxy errors occurred before this point"}, + + {"UnexpectedFileCountError", &pipewire.UnexpectedFileCountError{0, -1}, "received -1 files instead of the expected 0"}, + {"UnacknowledgedProxyError", make(pipewire.UnacknowledgedProxyError, 1<<4), "server did not acknowledge 16 proxies"}, + {"DanglingFilesError", make(pipewire.DanglingFilesError, 1<<4), "received 16 dangling files"}, + {"UnexpectedFilesError", pipewire.UnexpectedFilesError(1 << 4), "server message headers claim to have sent more than 16 files"}, + {"UnexpectedSequenceError", pipewire.UnexpectedSequenceError(1 << 4), "unexpected seq 16"}, + {"UnsupportedFooterOpcodeError", pipewire.UnsupportedFooterOpcodeError(1 << 4), "unsupported footer opcode 16"}, + + {"RoundtripUnexpectedEOFError ErrRoundtripEOFHeader", pipewire.ErrRoundtripEOFHeader, "unexpected EOF decoding message header"}, + {"RoundtripUnexpectedEOFError ErrRoundtripEOFBody", pipewire.ErrRoundtripEOFBody, "unexpected EOF establishing message body bounds"}, + {"RoundtripUnexpectedEOFError ErrRoundtripEOFFooter", pipewire.ErrRoundtripEOFFooter, "unexpected EOF establishing message footer bounds"}, + {"RoundtripUnexpectedEOFError ErrRoundtripEOFFooterOpcode", pipewire.ErrRoundtripEOFFooterOpcode, "unexpected EOF decoding message footer opcode"}, + {"RoundtripUnexpectedEOFError invalid", pipewire.RoundtripUnexpectedEOFError(0xbad), "unexpected EOF"}, + + {"UnsupportedOpcodeError", &pipewire.UnsupportedOpcodeError{ + Opcode: 0xff, + Interface: pipewire.PW_TYPE_INFO_INTERFACE_BASE + "Invalid", + }, "unsupported PipeWire:Interface:Invalid opcode 255"}, + + {"UnknownIdError", &pipewire.UnknownIdError{ + Id: -1, + Data: "\x00", + }, "unknown proxy id -1"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if got := tc.err.Error(); got != tc.want { + t.Errorf("Error: %q, want %q", got, tc.want) + } + }) + } +} diff --git a/internal/pipewire/securitycontext.go b/internal/pipewire/securitycontext.go index 958673e..42eb16c 100644 --- a/internal/pipewire/securitycontext.go +++ b/internal/pipewire/securitycontext.go @@ -104,11 +104,8 @@ func (securityContext *SecurityContext) Create(listenFd, closeFd int, props SPAD ) } -func (securityContext *SecurityContext) consume(opcode byte, files []int, _ func(v any) error) error { - if err := closeReceivedFiles(files...); err != nil { - return err - } - +func (securityContext *SecurityContext) consume(opcode byte, files []int, _ func(v any)) error { + closeReceivedFiles(files...) switch opcode { // SecurityContext does not receive any events