internal/pipewire: use sendmsg/recvmsg directly
All checks were successful
Test / Create distribution (push) Successful in 34s
Test / Sandbox (push) Successful in 2m41s
Test / Sandbox (race detector) (push) Successful in 4m46s
Test / Hakurei (push) Successful in 4m58s
Test / Hpkg (push) Successful in 5m4s
Test / Hakurei (race detector) (push) Successful in 6m32s
Test / Flake checks (push) Successful in 1m29s

The PipeWire protocol does not work with Go abstractions. This change makes relevant methods call sendmsg/recvmsg directly.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
Ophestra 2025-12-06 01:51:57 +09:00
parent 0d3f332d45
commit 8a2f9edcf9
Signed by: cat
SSH Key Fingerprint: SHA256:gQ67O0enBZ7UdZypgtspB2FDM1g3GVw8nX0XSdcFw8Q
3 changed files with 232 additions and 181 deletions

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"strconv"
"time"
)
/* pipewire/core.h */
@ -348,17 +349,28 @@ func (ctx *Context) coreSync(id Int) error {
// receiving a [CoreDone] event targeting the [CoreSync] event it delivered.
var ErrNotDone = errors.New("did not receive a Core::Done event targeting previously delivered Core::Sync")
const (
// syncTimeout is the maximum duration [Core.Sync] is allowed to take before
// receiving [CoreDone] or failing.
syncTimeout = 5 * time.Second
)
// Sync queues a [CoreSync] message for the PipeWire server and initiates a Roundtrip.
func (core *Core) Sync() error {
core.done = false
if err := core.ctx.coreSync(roundtripSyncID); err != nil {
return err
}
if err := core.ctx.Roundtrip(); err != nil {
return err
}
if !core.done {
return ErrNotDone
deadline := time.Now().Add(syncTimeout)
for !core.done {
if time.Now().After(deadline) {
return ErrNotDone
}
if err := core.ctx.Roundtrip(); err != nil {
return err
}
}
return nil
}

View File

@ -19,59 +19,22 @@ import (
"fmt"
"io"
"maps"
"net"
"os"
"runtime"
"slices"
"strconv"
"syscall"
"time"
)
// Conn is a subset of methods of [net.UnixConn] used by [Context].
// Conn is a low level unix socket interface used by [Context].
type Conn interface {
// ReadMsgUnix reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the message and the source address of the message.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// read (and discard) 1 byte from the connection.
ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error)
// Recvmsg calls syscall.Recvmsg on the underlying socket.
Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error)
// WriteMsgUnix writes a message to addr via c, copying the payload
// from b and the associated out-of-band data from oob. It returns the
// number of payload and out-of-band bytes written.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// write 1 byte to the connection.
WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error)
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail instead of blocking. The deadline applies to all future
// and pending I/O, not just the immediately following call to
// Read or Write. After a deadline has been exceeded, the
// connection can be refreshed by setting a deadline in the future.
//
// If the deadline is exceeded a call to Read or Write or to other
// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
// The error's Timeout method will return true, but note that there
// are other possible errors for which the Timeout method will
// return true even if the deadline has not been exceeded.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
// Sendmsg calls syscall.SendmsgN on the underlying socket.
Sendmsg(p, oob []byte, flags int) (n int, err error)
// Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors.
Close() error
}
@ -155,6 +118,45 @@ func New(conn Conn, props SPADict) (*Context, error) {
return &ctx, nil
}
// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer].
type SyscallConnCloser interface {
syscall.Conn
io.Closer
}
// A SyscallConn is a [Conn] adapter for [syscall.Conn].
type SyscallConn struct{ SyscallConnCloser }
// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return
}
// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return
}
// MustNew calls [New](conn, props) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustNew(conn Conn, props SPADict) *Context {
@ -226,66 +228,102 @@ increment:
return newId
}
// connTimeout is the maximum duration an I/O operation is allowed for [Conn].
const connTimeout = 5 * time.Second
// receiveAll receives from conn until no more data is available.
// The returned slice is valid until the next call to receiveAll.
func (ctx *Context) receiveAll() (payload []byte, err error) {
if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
return
// closeReceivedFiles closes all receivedFiles. This is only during protocol error
// where [Context] is rendered unusable.
func (ctx *Context) closeReceivedFiles() {
slices.Sort(ctx.receivedFiles)
ctx.receivedFiles = slices.Compact(ctx.receivedFiles)
for _, fd := range ctx.receivedFiles {
_ = syscall.Close(fd)
}
var n, oobn int
ctx.receivedFiles = ctx.receivedFiles[:0]
buf := ctx.iovecBuf[:]
}
recvmsg:
buf = buf[n:]
n, oobn, _, _, err = ctx.conn.ReadMsgUnix(buf, ctx.oobBuf[:])
if err != nil {
return
}
if oobn == len(ctx.oobBuf) {
return nil, syscall.ENOMEM // unreachable
}
if oob := ctx.oobBuf[:oobn]; len(oob) > 0 {
var scm []syscall.SocketControlMessage
if scm, err = syscall.ParseSocketControlMessage(oob); err != nil {
return
}
// recvmsgFlags are flags passed to [Conn.Recvmsg] during Context.recvmsg.
const recvmsgFlags = syscall.MSG_CMSG_CLOEXEC | syscall.MSG_DONTWAIT
var fds []int
for i := range scm {
if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil {
// recvmsg receives from conn and returns the received payload backed by
// iovecBuf. The returned slice is valid until the next call to recvmsg.
func (ctx *Context) recvmsg(remaining []byte) (payload []byte, err error) {
if copy(ctx.iovecBuf[:], remaining) != len(remaining) {
// should not be reachable with correct internal usage
return remaining, syscall.ENOMEM
}
var n, oobn, recvflags int
for {
n, oobn, recvflags, err = ctx.conn.Recvmsg(ctx.iovecBuf[len(remaining):], ctx.oobBuf[:], recvmsgFlags)
if oob := ctx.oobBuf[:oobn]; len(oob) > 0 {
var scm []syscall.SocketControlMessage
if scm, err = syscall.ParseSocketControlMessage(oob); err != nil {
ctx.closeReceivedFiles()
return
}
ctx.receivedFiles = append(ctx.receivedFiles, fds...)
var fds []int
for i := range scm {
if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil {
ctx.closeReceivedFiles()
return
}
ctx.receivedFiles = append(ctx.receivedFiles, fds...)
}
}
if recvflags&syscall.MSG_CTRUNC != 0 {
// unreachable
ctx.closeReceivedFiles()
return nil, syscall.ENOMEM
}
if err != nil {
if err == syscall.EINTR {
continue
}
if err != syscall.EAGAIN && err != syscall.EWOULDBLOCK {
ctx.closeReceivedFiles()
return nil, os.NewSyscallError("recvmsg", err)
}
}
break
}
// receive until buffer fills or payload is depleted
if n > 0 {
goto recvmsg
if n >= 0 {
payload = ctx.iovecBuf[:n]
}
data := ctx.iovecBuf[:len(ctx.iovecBuf)-len(buf)]
// avoids copy if payload fits in a single ctx.recvmsgBuf
if payload == nil && len(buf) > 0 {
payload = data
return
}
payload = append(payload, data...)
// this indicates a full ctx.recvmsgBuf
if len(buf) == 0 {
ctx.buf = ctx.iovecBuf[:]
goto recvmsg
}
return
}
// sendmsgFlags are flags passed to [Conn.Sendmsg] during Context.sendmsg.
const sendmsgFlags = syscall.MSG_NOSIGNAL | syscall.MSG_DONTWAIT
// sendmsg sends p to conn. sendmsg does not retain p.
func (ctx *Context) sendmsg(p []byte, fds ...int) error {
var oob []byte
if len(fds) > 0 {
oob = syscall.UnixRights(fds...)
}
for {
n, err := ctx.conn.Sendmsg(p, oob, sendmsgFlags)
if err == syscall.EINTR {
continue
}
if err == nil && n != len(p) {
err = syscall.EMSGSIZE
}
if err != nil && err != syscall.EAGAIN && err != syscall.EWOULDBLOCK {
return os.NewSyscallError("sendmsg", err)
}
return err
}
}
// An UnknownIdError describes a server message with an Id unknown to [Context].
type UnknownIdError struct {
// Offending id decoded from Data.
@ -427,7 +465,7 @@ func (e UnexpectedSequenceError) Error() string { return "unexpected seq " + str
type UnexpectedFilesError int
func (e UnexpectedFilesError) Error() string {
return "server message headers claim to have sent more than " + strconv.Itoa(int(e)) + " files"
return "server message headers claim to have sent more files than actually received"
}
// A DanglingFilesError holds onto files that were sent by the server but no [Header]
@ -486,39 +524,49 @@ func (e ProxyConsumeError) Error() string {
// roundtripSyncID is the id passed to Context.coreSync during a [Context.Roundtrip].
const roundtripSyncID = 0
// Roundtrip queues the [CoreSync] message and sends all pending messages to the server.
// Roundtrip sends all pending messages to the server and processes events until
// the server has no more messages.
//
// For a non-nil error, if the error happens over the network, it has concrete type
// [net.OpError].
// [os.SyscallError].
func (ctx *Context) Roundtrip() (err error) {
if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
return
}
var unixRightsOob []byte
if len(ctx.pendingFiles) > 0 {
unixRightsOob = syscall.UnixRights(ctx.pendingFiles...)
}
if _, _, err = ctx.conn.WriteMsgUnix(ctx.buf, unixRightsOob, nil); err != nil {
if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil {
return
}
var remaining []byte
for {
remaining, err = ctx.roundtrip(remaining)
if err == nil {
continue
}
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
if len(remaining) == 0 {
err = nil
} else if len(remaining) < SizeHeader {
err = ErrRoundtripEOFHeader
} else {
err = ErrRoundtripEOFBody
}
}
return
}
}
// roundtrip receives messages from the server and processes events.
func (ctx *Context) roundtrip(receiveRemaining []byte) (remaining []byte, err error) {
var (
// this holds onto non-protocol errors encountered during event handling;
// errors that prevent event processing from continuing must be panicked
proxyErrors ProxyConsumeError
// current position of processed events in ctx.receivedFiles, anything
// beyond this is closed if event processing is terminated
receivedHeaderFiles int
)
defer func() {
// anything before this has already been processed and must not be closed
// here, as anything holding onto them will end up with a dangling fd that
// can be reused and cause serious problems
if len(ctx.receivedFiles) > receivedHeaderFiles {
for _, fd := range ctx.receivedFiles[receivedHeaderFiles:] {
_ = syscall.Close(fd)
}
if len(ctx.receivedFiles) > 0 {
ctx.closeReceivedFiles()
// this catches cases where Roundtrip somehow returns without processing
// all received files or preparing an error for dangling files, this is
@ -551,48 +599,47 @@ func (ctx *Context) Roundtrip() (err error) {
ctx.pendingFiles = ctx.pendingFiles[:0]
ctx.headerFiles = 0
var data []byte
if data, err = ctx.receiveAll(); err != nil {
if remaining, err = ctx.recvmsg(receiveRemaining); err != nil {
return
}
var header Header
for len(data) > 0 {
if len(data) < SizeHeader {
return ErrRoundtripEOFHeader
for len(remaining) > 0 {
if len(remaining) < SizeHeader {
return
}
if err = header.UnmarshalBinary(data[:SizeHeader]); err != nil {
if err = header.UnmarshalBinary(remaining[:SizeHeader]); err != nil {
return
}
if header.Sequence != ctx.remoteSequence {
return UnexpectedSequenceError(header.Sequence)
return remaining, UnexpectedSequenceError(header.Sequence)
}
ctx.remoteSequence++
if len(data) < int(SizeHeader+header.Size) {
return ErrRoundtripEOFBody
if len(remaining) < int(SizeHeader+header.Size) {
return
}
proxy, ok := ctx.proxy[header.ID]
if !ok {
return &UnknownIdError{header.ID, string(data[:SizeHeader+header.Size])}
return remaining, &UnknownIdError{header.ID, string(remaining[:SizeHeader+header.Size])}
}
nextReceivedHeaderFiles := receivedHeaderFiles + int(header.FileCount)
if nextReceivedHeaderFiles > len(ctx.receivedFiles) {
return UnexpectedFilesError(len(ctx.receivedFiles))
fileCount := int(header.FileCount)
if fileCount > len(ctx.receivedFiles) {
return remaining, UnexpectedFilesError(fileCount)
}
files := ctx.receivedFiles[receivedHeaderFiles:nextReceivedHeaderFiles]
receivedHeaderFiles = nextReceivedHeaderFiles
files := ctx.receivedFiles[:fileCount]
ctx.receivedFiles = ctx.receivedFiles[fileCount:]
data = data[SizeHeader:]
remaining = remaining[SizeHeader:]
proxyErr := proxy.consume(header.Opcode, files, func(v any) {
if unmarshalErr := ctx.unmarshal(&header, data, v); unmarshalErr != nil {
if unmarshalErr := ctx.unmarshal(&header, remaining, v); unmarshalErr != nil {
panic(unmarshalErr)
}
})
data = data[header.Size:]
remaining = remaining[header.Size:]
if proxyErr != nil {
proxyErrors = append(proxyErrors, proxyErr)
}
@ -601,36 +648,36 @@ func (ctx *Context) Roundtrip() (err error) {
// prepared here so finalizers are set up, but should not prevent proxyErrors
// from reaching the caller as those describe the cause of these dangling fds
var danglingFiles DanglingFilesError
if len(ctx.receivedFiles) > receivedHeaderFiles {
danglingFds := ctx.receivedFiles[receivedHeaderFiles:]
if len(ctx.receivedFiles) > 0 {
// having multiple *os.File with the same fd causes serious problems
slices.Sort(danglingFds)
danglingFds = slices.Compact(danglingFds)
slices.Sort(ctx.receivedFiles)
ctx.receivedFiles = slices.Compact(ctx.receivedFiles)
danglingFiles = make(DanglingFilesError, 0, len(danglingFds))
for _, fd := range danglingFds {
danglingFiles = make(DanglingFilesError, 0, len(ctx.receivedFiles))
for _, fd := range ctx.receivedFiles {
// hold these as *os.File so they are closed if this error never reaches the caller,
// or the caller discards or otherwise does not handle this error, to avoid leaking fds
danglingFiles = append(danglingFiles, os.NewFile(uintptr(fd),
"dangling fd "+strconv.Itoa(fd)+" received from PipeWire"))
}
ctx.receivedFiles = ctx.receivedFiles[:0]
}
// these are checked and made available first since they describe the cause
// of so-called symptoms checked after this point; the symptoms should only
// be made available as a catch-all if these are unavailable
if len(proxyErrors) > 0 {
return proxyErrors
return remaining, proxyErrors
}
// populated early for finalizers
if len(danglingFiles) > 0 {
return danglingFiles
return remaining, danglingFiles
}
// this check must happen after everything else passes
if len(ctx.pendingIds) != 0 {
return UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds)))
return remaining, UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds)))
}
return
}

View File

@ -2,11 +2,10 @@ package pipewire_test
import (
"fmt"
"net"
"reflect"
"strconv"
. "syscall"
"testing"
"time"
"hakurei.app/container/stub"
"hakurei.app/internal/pipewire"
@ -707,7 +706,7 @@ func TestContext(t *testing.T) {
type stubUnixConnSample struct {
nr uintptr
iovec string
flags uintptr
flags int
files []int
errno Errno
}
@ -716,17 +715,6 @@ type stubUnixConnSample struct {
type stubUnixConn struct {
samples []stubUnixConnSample
current int
deadline *time.Time
}
// checkDeadline checks whether deadline is set reasonably.
func (conn *stubUnixConn) checkDeadline() error {
if conn.deadline == nil || conn.deadline.Before(time.Now()) {
return fmt.Errorf("invalid deadline %v", conn.deadline)
}
conn.deadline = nil
return nil
}
// nextSample returns the current sample and increments the counter.
@ -744,13 +732,7 @@ func (conn *stubUnixConn) nextSample(nr uintptr) (sample *stubUnixConnSample, wa
return
}
func (conn *stubUnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error) {
if conn.samples[conn.current-1].nr == SYS_SENDMSG {
if err = conn.checkDeadline(); err != nil {
return
}
}
func (conn *stubUnixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
var (
sample *stubUnixConnSample
wantOOB []byte
@ -760,28 +742,31 @@ func (conn *stubUnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *
return
}
if copy(b, sample.iovec) != len(sample.iovec) {
err = fmt.Errorf("insufficient iovec size %d, want at least %d", len(b), len(sample.iovec))
if n = copy(p, sample.iovec); n != len(sample.iovec) {
err = fmt.Errorf("insufficient iovec size %d, want at least %d", len(p), len(sample.iovec))
return
}
if copy(oob, wantOOB) != len(wantOOB) {
if oobn = copy(oob, wantOOB); oobn != len(wantOOB) {
err = fmt.Errorf("insufficient oob size %d, want at least %d", len(oob), len(wantOOB))
return
}
if flags != sample.flags {
err = fmt.Errorf("flags = %#x, want %#x", flags, sample.flags)
return
}
if sample.errno != 0 && sample.errno != EAGAIN {
recvflags = MSG_CMSG_CLOEXEC
if sample.errno != 0 {
err = sample.errno
if n != 0 {
panic("invalid recvmsg: n = " + strconv.Itoa(n))
}
n = -1
}
return len(sample.iovec), len(wantOOB), MSG_CMSG_CLOEXEC, nil, nil
return
}
func (conn *stubUnixConn) WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error) {
if addr != nil {
err = fmt.Errorf("WriteMsgUnix called with non-nil addr: %#v", addr)
return
}
if err = conn.checkDeadline(); err != nil {
return
}
func (conn *stubUnixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
var (
sample *stubUnixConnSample
wantOOB []byte
@ -791,18 +776,25 @@ func (conn *stubUnixConn) WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oo
return
}
if string(b) != sample.iovec {
err = fmt.Errorf("iovec: %#v, want %#v", b, []byte(sample.iovec))
if string(p) != sample.iovec {
err = fmt.Errorf("iovec: %#v, want %#v", p, []byte(sample.iovec))
return
}
if string(oob[:len(wantOOB)]) != string(wantOOB) {
err = fmt.Errorf("oob: %#v, want %#v", oob[:len(wantOOB)], wantOOB)
return
}
return len(sample.iovec), len(wantOOB), nil
}
if flags != sample.flags {
err = fmt.Errorf("flags = %#x, want %#x", flags, sample.flags)
return
}
func (conn *stubUnixConn) SetDeadline(t time.Time) error { conn.deadline = &t; return nil }
n = len(sample.iovec)
if sample.errno != 0 {
err = sample.errno
}
return
}
func (conn *stubUnixConn) Close() error {
if conn.current != len(conn.samples) {
@ -844,7 +836,7 @@ func TestContextErrors(t *testing.T) {
{"UnexpectedFileCountError", &pipewire.UnexpectedFileCountError{0, -1}, "received -1 files instead of the expected 0"},
{"UnacknowledgedProxyError", make(pipewire.UnacknowledgedProxyError, 1<<4), "server did not acknowledge 16 proxies"},
{"DanglingFilesError", make(pipewire.DanglingFilesError, 1<<4), "received 16 dangling files"},
{"UnexpectedFilesError", pipewire.UnexpectedFilesError(1 << 4), "server message headers claim to have sent more than 16 files"},
{"UnexpectedFilesError", pipewire.UnexpectedFilesError(1 << 4), "server message headers claim to have sent more files than actually received"},
{"UnexpectedSequenceError", pipewire.UnexpectedSequenceError(1 << 4), "unexpected seq 16"},
{"UnsupportedFooterOpcodeError", pipewire.UnsupportedFooterOpcodeError(1 << 4), "unsupported footer opcode 16"},