From 8a2f9edcf963358aeeddf1ee544300e2ec72a533 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Sat, 6 Dec 2025 01:51:57 +0900 Subject: [PATCH] internal/pipewire: use sendmsg/recvmsg directly The PipeWire protocol does not work with Go abstractions. This change makes relevant methods call sendmsg/recvmsg directly. Signed-off-by: Ophestra --- internal/pipewire/core.go | 22 +- internal/pipewire/pipewire.go | 317 +++++++++++++++++------------ internal/pipewire/pipewire_test.go | 74 +++---- 3 files changed, 232 insertions(+), 181 deletions(-) diff --git a/internal/pipewire/core.go b/internal/pipewire/core.go index 435aa89..c2913b2 100644 --- a/internal/pipewire/core.go +++ b/internal/pipewire/core.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "time" ) /* pipewire/core.h */ @@ -348,17 +349,28 @@ func (ctx *Context) coreSync(id Int) error { // receiving a [CoreDone] event targeting the [CoreSync] event it delivered. var ErrNotDone = errors.New("did not receive a Core::Done event targeting previously delivered Core::Sync") +const ( + // syncTimeout is the maximum duration [Core.Sync] is allowed to take before + // receiving [CoreDone] or failing. + syncTimeout = 5 * time.Second +) + // Sync queues a [CoreSync] message for the PipeWire server and initiates a Roundtrip. func (core *Core) Sync() error { core.done = false if err := core.ctx.coreSync(roundtripSyncID); err != nil { return err } - if err := core.ctx.Roundtrip(); err != nil { - return err - } - if !core.done { - return ErrNotDone + + deadline := time.Now().Add(syncTimeout) + for !core.done { + if time.Now().After(deadline) { + return ErrNotDone + } + + if err := core.ctx.Roundtrip(); err != nil { + return err + } } return nil } diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index 00507e8..82301d0 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -19,59 +19,22 @@ import ( "fmt" "io" "maps" - "net" "os" "runtime" "slices" "strconv" "syscall" - "time" ) -// Conn is a subset of methods of [net.UnixConn] used by [Context]. +// Conn is a low level unix socket interface used by [Context]. type Conn interface { - // ReadMsgUnix reads a message from c, copying the payload into b and - // the associated out-of-band data into oob. It returns the number of - // bytes copied into b, the number of bytes copied into oob, the flags - // that were set on the message and the source address of the message. - // - // Note that if len(b) == 0 and len(oob) > 0, this function will still - // read (and discard) 1 byte from the connection. - ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) + // Recvmsg calls syscall.Recvmsg on the underlying socket. + Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) - // WriteMsgUnix writes a message to addr via c, copying the payload - // from b and the associated out-of-band data from oob. It returns the - // number of payload and out-of-band bytes written. - // - // Note that if len(b) == 0 and len(oob) > 0, this function will still - // write 1 byte to the connection. - WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) - - // SetDeadline sets the read and write deadlines associated - // with the connection. It is equivalent to calling both - // SetReadDeadline and SetWriteDeadline. - // - // A deadline is an absolute time after which I/O operations - // fail instead of blocking. The deadline applies to all future - // and pending I/O, not just the immediately following call to - // Read or Write. After a deadline has been exceeded, the - // connection can be refreshed by setting a deadline in the future. - // - // If the deadline is exceeded a call to Read or Write or to other - // I/O methods will return an error that wraps os.ErrDeadlineExceeded. - // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). - // The error's Timeout method will return true, but note that there - // are other possible errors for which the Timeout method will - // return true even if the deadline has not been exceeded. - // - // An idle timeout can be implemented by repeatedly extending - // the deadline after successful Read or Write calls. - // - // A zero value for t means I/O operations will not time out. - SetDeadline(t time.Time) error + // Sendmsg calls syscall.SendmsgN on the underlying socket. + Sendmsg(p, oob []byte, flags int) (n int, err error) // Close closes the connection. - // Any blocked Read or Write operations will be unblocked and return errors. Close() error } @@ -155,6 +118,45 @@ func New(conn Conn, props SPADict) (*Context, error) { return &ctx, nil } +// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer]. +type SyscallConnCloser interface { + syscall.Conn + io.Closer +} + +// A SyscallConn is a [Conn] adapter for [syscall.Conn]. +type SyscallConn struct{ SyscallConnCloser } + +// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn]. +func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { + var rc syscall.RawConn + if rc, err = conn.SyscallConn(); err != nil { + return + } + + if controlErr := rc.Control(func(fd uintptr) { + n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags) + }); controlErr != nil && err == nil { + err = controlErr + } + return +} + +// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn]. +func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { + var rc syscall.RawConn + if rc, err = conn.SyscallConn(); err != nil { + return + } + + if controlErr := rc.Control(func(fd uintptr) { + n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags) + }); controlErr != nil && err == nil { + err = controlErr + } + return +} + // MustNew calls [New](conn, props) and panics on error. // It is intended for use in tests with hard-coded strings. func MustNew(conn Conn, props SPADict) *Context { @@ -226,66 +228,102 @@ increment: return newId } -// connTimeout is the maximum duration an I/O operation is allowed for [Conn]. -const connTimeout = 5 * time.Second - -// receiveAll receives from conn until no more data is available. -// The returned slice is valid until the next call to receiveAll. -func (ctx *Context) receiveAll() (payload []byte, err error) { - if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil { - return +// closeReceivedFiles closes all receivedFiles. This is only during protocol error +// where [Context] is rendered unusable. +func (ctx *Context) closeReceivedFiles() { + slices.Sort(ctx.receivedFiles) + ctx.receivedFiles = slices.Compact(ctx.receivedFiles) + for _, fd := range ctx.receivedFiles { + _ = syscall.Close(fd) } - - var n, oobn int ctx.receivedFiles = ctx.receivedFiles[:0] - buf := ctx.iovecBuf[:] +} -recvmsg: - buf = buf[n:] - n, oobn, _, _, err = ctx.conn.ReadMsgUnix(buf, ctx.oobBuf[:]) - if err != nil { - return - } - if oobn == len(ctx.oobBuf) { - return nil, syscall.ENOMEM // unreachable - } - if oob := ctx.oobBuf[:oobn]; len(oob) > 0 { - var scm []syscall.SocketControlMessage - if scm, err = syscall.ParseSocketControlMessage(oob); err != nil { - return - } +// recvmsgFlags are flags passed to [Conn.Recvmsg] during Context.recvmsg. +const recvmsgFlags = syscall.MSG_CMSG_CLOEXEC | syscall.MSG_DONTWAIT - var fds []int - for i := range scm { - if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil { +// recvmsg receives from conn and returns the received payload backed by +// iovecBuf. The returned slice is valid until the next call to recvmsg. +func (ctx *Context) recvmsg(remaining []byte) (payload []byte, err error) { + if copy(ctx.iovecBuf[:], remaining) != len(remaining) { + // should not be reachable with correct internal usage + return remaining, syscall.ENOMEM + } + + var n, oobn, recvflags int + + for { + n, oobn, recvflags, err = ctx.conn.Recvmsg(ctx.iovecBuf[len(remaining):], ctx.oobBuf[:], recvmsgFlags) + + if oob := ctx.oobBuf[:oobn]; len(oob) > 0 { + var scm []syscall.SocketControlMessage + if scm, err = syscall.ParseSocketControlMessage(oob); err != nil { + ctx.closeReceivedFiles() return } - ctx.receivedFiles = append(ctx.receivedFiles, fds...) + + var fds []int + for i := range scm { + if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil { + ctx.closeReceivedFiles() + return + } + ctx.receivedFiles = append(ctx.receivedFiles, fds...) + } } + + if recvflags&syscall.MSG_CTRUNC != 0 { + // unreachable + ctx.closeReceivedFiles() + return nil, syscall.ENOMEM + } + + if err != nil { + if err == syscall.EINTR { + continue + } + if err != syscall.EAGAIN && err != syscall.EWOULDBLOCK { + ctx.closeReceivedFiles() + return nil, os.NewSyscallError("recvmsg", err) + } + } + + break } - // receive until buffer fills or payload is depleted - if n > 0 { - goto recvmsg + if n >= 0 { + payload = ctx.iovecBuf[:n] } - data := ctx.iovecBuf[:len(ctx.iovecBuf)-len(buf)] - - // avoids copy if payload fits in a single ctx.recvmsgBuf - if payload == nil && len(buf) > 0 { - payload = data - return - } - - payload = append(payload, data...) - // this indicates a full ctx.recvmsgBuf - if len(buf) == 0 { - ctx.buf = ctx.iovecBuf[:] - goto recvmsg - } - return } +// sendmsgFlags are flags passed to [Conn.Sendmsg] during Context.sendmsg. +const sendmsgFlags = syscall.MSG_NOSIGNAL | syscall.MSG_DONTWAIT + +// sendmsg sends p to conn. sendmsg does not retain p. +func (ctx *Context) sendmsg(p []byte, fds ...int) error { + var oob []byte + if len(fds) > 0 { + oob = syscall.UnixRights(fds...) + } + + for { + n, err := ctx.conn.Sendmsg(p, oob, sendmsgFlags) + if err == syscall.EINTR { + continue + } + + if err == nil && n != len(p) { + err = syscall.EMSGSIZE + } + + if err != nil && err != syscall.EAGAIN && err != syscall.EWOULDBLOCK { + return os.NewSyscallError("sendmsg", err) + } + return err + } +} + // An UnknownIdError describes a server message with an Id unknown to [Context]. type UnknownIdError struct { // Offending id decoded from Data. @@ -427,7 +465,7 @@ func (e UnexpectedSequenceError) Error() string { return "unexpected seq " + str type UnexpectedFilesError int func (e UnexpectedFilesError) Error() string { - return "server message headers claim to have sent more than " + strconv.Itoa(int(e)) + " files" + return "server message headers claim to have sent more files than actually received" } // A DanglingFilesError holds onto files that were sent by the server but no [Header] @@ -486,39 +524,49 @@ func (e ProxyConsumeError) Error() string { // roundtripSyncID is the id passed to Context.coreSync during a [Context.Roundtrip]. const roundtripSyncID = 0 -// Roundtrip queues the [CoreSync] message and sends all pending messages to the server. +// Roundtrip sends all pending messages to the server and processes events until +// the server has no more messages. // // For a non-nil error, if the error happens over the network, it has concrete type -// [net.OpError]. +// [os.SyscallError]. func (ctx *Context) Roundtrip() (err error) { - if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil { - return - } - var unixRightsOob []byte - if len(ctx.pendingFiles) > 0 { - unixRightsOob = syscall.UnixRights(ctx.pendingFiles...) - } - if _, _, err = ctx.conn.WriteMsgUnix(ctx.buf, unixRightsOob, nil); err != nil { + if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil { return } + var remaining []byte + for { + remaining, err = ctx.roundtrip(remaining) + if err == nil { + continue + } + + if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { + if len(remaining) == 0 { + err = nil + } else if len(remaining) < SizeHeader { + err = ErrRoundtripEOFHeader + } else { + err = ErrRoundtripEOFBody + } + } + return + } +} + +// roundtrip receives messages from the server and processes events. +func (ctx *Context) roundtrip(receiveRemaining []byte) (remaining []byte, err error) { var ( // this holds onto non-protocol errors encountered during event handling; // errors that prevent event processing from continuing must be panicked proxyErrors ProxyConsumeError - - // current position of processed events in ctx.receivedFiles, anything - // beyond this is closed if event processing is terminated - receivedHeaderFiles int ) defer func() { // anything before this has already been processed and must not be closed // here, as anything holding onto them will end up with a dangling fd that // can be reused and cause serious problems - if len(ctx.receivedFiles) > receivedHeaderFiles { - for _, fd := range ctx.receivedFiles[receivedHeaderFiles:] { - _ = syscall.Close(fd) - } + if len(ctx.receivedFiles) > 0 { + ctx.closeReceivedFiles() // this catches cases where Roundtrip somehow returns without processing // all received files or preparing an error for dangling files, this is @@ -551,48 +599,47 @@ func (ctx *Context) Roundtrip() (err error) { ctx.pendingFiles = ctx.pendingFiles[:0] ctx.headerFiles = 0 - var data []byte - if data, err = ctx.receiveAll(); err != nil { + if remaining, err = ctx.recvmsg(receiveRemaining); err != nil { return } var header Header - for len(data) > 0 { - if len(data) < SizeHeader { - return ErrRoundtripEOFHeader + for len(remaining) > 0 { + if len(remaining) < SizeHeader { + return } - if err = header.UnmarshalBinary(data[:SizeHeader]); err != nil { + if err = header.UnmarshalBinary(remaining[:SizeHeader]); err != nil { return } if header.Sequence != ctx.remoteSequence { - return UnexpectedSequenceError(header.Sequence) + return remaining, UnexpectedSequenceError(header.Sequence) } ctx.remoteSequence++ - if len(data) < int(SizeHeader+header.Size) { - return ErrRoundtripEOFBody + if len(remaining) < int(SizeHeader+header.Size) { + return } proxy, ok := ctx.proxy[header.ID] if !ok { - return &UnknownIdError{header.ID, string(data[:SizeHeader+header.Size])} + return remaining, &UnknownIdError{header.ID, string(remaining[:SizeHeader+header.Size])} } - nextReceivedHeaderFiles := receivedHeaderFiles + int(header.FileCount) - if nextReceivedHeaderFiles > len(ctx.receivedFiles) { - return UnexpectedFilesError(len(ctx.receivedFiles)) + fileCount := int(header.FileCount) + if fileCount > len(ctx.receivedFiles) { + return remaining, UnexpectedFilesError(fileCount) } - files := ctx.receivedFiles[receivedHeaderFiles:nextReceivedHeaderFiles] - receivedHeaderFiles = nextReceivedHeaderFiles + files := ctx.receivedFiles[:fileCount] + ctx.receivedFiles = ctx.receivedFiles[fileCount:] - data = data[SizeHeader:] + remaining = remaining[SizeHeader:] proxyErr := proxy.consume(header.Opcode, files, func(v any) { - if unmarshalErr := ctx.unmarshal(&header, data, v); unmarshalErr != nil { + if unmarshalErr := ctx.unmarshal(&header, remaining, v); unmarshalErr != nil { panic(unmarshalErr) } }) - data = data[header.Size:] + remaining = remaining[header.Size:] if proxyErr != nil { proxyErrors = append(proxyErrors, proxyErr) } @@ -601,36 +648,36 @@ func (ctx *Context) Roundtrip() (err error) { // prepared here so finalizers are set up, but should not prevent proxyErrors // from reaching the caller as those describe the cause of these dangling fds var danglingFiles DanglingFilesError - if len(ctx.receivedFiles) > receivedHeaderFiles { - danglingFds := ctx.receivedFiles[receivedHeaderFiles:] + if len(ctx.receivedFiles) > 0 { // having multiple *os.File with the same fd causes serious problems - slices.Sort(danglingFds) - danglingFds = slices.Compact(danglingFds) + slices.Sort(ctx.receivedFiles) + ctx.receivedFiles = slices.Compact(ctx.receivedFiles) - danglingFiles = make(DanglingFilesError, 0, len(danglingFds)) - for _, fd := range danglingFds { + danglingFiles = make(DanglingFilesError, 0, len(ctx.receivedFiles)) + for _, fd := range ctx.receivedFiles { // hold these as *os.File so they are closed if this error never reaches the caller, // or the caller discards or otherwise does not handle this error, to avoid leaking fds danglingFiles = append(danglingFiles, os.NewFile(uintptr(fd), "dangling fd "+strconv.Itoa(fd)+" received from PipeWire")) } + ctx.receivedFiles = ctx.receivedFiles[:0] } // these are checked and made available first since they describe the cause // of so-called symptoms checked after this point; the symptoms should only // be made available as a catch-all if these are unavailable if len(proxyErrors) > 0 { - return proxyErrors + return remaining, proxyErrors } // populated early for finalizers if len(danglingFiles) > 0 { - return danglingFiles + return remaining, danglingFiles } // this check must happen after everything else passes if len(ctx.pendingIds) != 0 { - return UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds))) + return remaining, UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds))) } return } diff --git a/internal/pipewire/pipewire_test.go b/internal/pipewire/pipewire_test.go index fb7a61e..b6de283 100644 --- a/internal/pipewire/pipewire_test.go +++ b/internal/pipewire/pipewire_test.go @@ -2,11 +2,10 @@ package pipewire_test import ( "fmt" - "net" "reflect" + "strconv" . "syscall" "testing" - "time" "hakurei.app/container/stub" "hakurei.app/internal/pipewire" @@ -707,7 +706,7 @@ func TestContext(t *testing.T) { type stubUnixConnSample struct { nr uintptr iovec string - flags uintptr + flags int files []int errno Errno } @@ -716,17 +715,6 @@ type stubUnixConnSample struct { type stubUnixConn struct { samples []stubUnixConnSample current int - - deadline *time.Time -} - -// checkDeadline checks whether deadline is set reasonably. -func (conn *stubUnixConn) checkDeadline() error { - if conn.deadline == nil || conn.deadline.Before(time.Now()) { - return fmt.Errorf("invalid deadline %v", conn.deadline) - } - conn.deadline = nil - return nil } // nextSample returns the current sample and increments the counter. @@ -744,13 +732,7 @@ func (conn *stubUnixConn) nextSample(nr uintptr) (sample *stubUnixConnSample, wa return } -func (conn *stubUnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) { - if conn.samples[conn.current-1].nr == SYS_SENDMSG { - if err = conn.checkDeadline(); err != nil { - return - } - } - +func (conn *stubUnixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { var ( sample *stubUnixConnSample wantOOB []byte @@ -760,28 +742,31 @@ func (conn *stubUnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr * return } - if copy(b, sample.iovec) != len(sample.iovec) { - err = fmt.Errorf("insufficient iovec size %d, want at least %d", len(b), len(sample.iovec)) + if n = copy(p, sample.iovec); n != len(sample.iovec) { + err = fmt.Errorf("insufficient iovec size %d, want at least %d", len(p), len(sample.iovec)) + return } - if copy(oob, wantOOB) != len(wantOOB) { + if oobn = copy(oob, wantOOB); oobn != len(wantOOB) { err = fmt.Errorf("insufficient oob size %d, want at least %d", len(oob), len(wantOOB)) + return + } + if flags != sample.flags { + err = fmt.Errorf("flags = %#x, want %#x", flags, sample.flags) + return } - if sample.errno != 0 && sample.errno != EAGAIN { + recvflags = MSG_CMSG_CLOEXEC + if sample.errno != 0 { err = sample.errno + if n != 0 { + panic("invalid recvmsg: n = " + strconv.Itoa(n)) + } + n = -1 } - return len(sample.iovec), len(wantOOB), MSG_CMSG_CLOEXEC, nil, nil + return } -func (conn *stubUnixConn) WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) { - if addr != nil { - err = fmt.Errorf("WriteMsgUnix called with non-nil addr: %#v", addr) - return - } - if err = conn.checkDeadline(); err != nil { - return - } - +func (conn *stubUnixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { var ( sample *stubUnixConnSample wantOOB []byte @@ -791,18 +776,25 @@ func (conn *stubUnixConn) WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oo return } - if string(b) != sample.iovec { - err = fmt.Errorf("iovec: %#v, want %#v", b, []byte(sample.iovec)) + if string(p) != sample.iovec { + err = fmt.Errorf("iovec: %#v, want %#v", p, []byte(sample.iovec)) return } if string(oob[:len(wantOOB)]) != string(wantOOB) { err = fmt.Errorf("oob: %#v, want %#v", oob[:len(wantOOB)], wantOOB) return } - return len(sample.iovec), len(wantOOB), nil -} + if flags != sample.flags { + err = fmt.Errorf("flags = %#x, want %#x", flags, sample.flags) + return + } -func (conn *stubUnixConn) SetDeadline(t time.Time) error { conn.deadline = &t; return nil } + n = len(sample.iovec) + if sample.errno != 0 { + err = sample.errno + } + return +} func (conn *stubUnixConn) Close() error { if conn.current != len(conn.samples) { @@ -844,7 +836,7 @@ func TestContextErrors(t *testing.T) { {"UnexpectedFileCountError", &pipewire.UnexpectedFileCountError{0, -1}, "received -1 files instead of the expected 0"}, {"UnacknowledgedProxyError", make(pipewire.UnacknowledgedProxyError, 1<<4), "server did not acknowledge 16 proxies"}, {"DanglingFilesError", make(pipewire.DanglingFilesError, 1<<4), "received 16 dangling files"}, - {"UnexpectedFilesError", pipewire.UnexpectedFilesError(1 << 4), "server message headers claim to have sent more than 16 files"}, + {"UnexpectedFilesError", pipewire.UnexpectedFilesError(1 << 4), "server message headers claim to have sent more files than actually received"}, {"UnexpectedSequenceError", pipewire.UnexpectedSequenceError(1 << 4), "unexpected seq 16"}, {"UnsupportedFooterOpcodeError", pipewire.UnsupportedFooterOpcodeError(1 << 4), "unsupported footer opcode 16"},