internal/pipewire: use sendmsg/recvmsg directly
All checks were successful
Test / Create distribution (push) Successful in 34s
Test / Sandbox (push) Successful in 2m41s
Test / Sandbox (race detector) (push) Successful in 4m46s
Test / Hakurei (push) Successful in 4m58s
Test / Hpkg (push) Successful in 5m4s
Test / Hakurei (race detector) (push) Successful in 6m32s
Test / Flake checks (push) Successful in 1m29s

The PipeWire protocol does not work with Go abstractions. This change makes relevant methods call sendmsg/recvmsg directly.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-12-06 01:51:57 +09:00
parent 0d3f332d45
commit 8a2f9edcf9
3 changed files with 232 additions and 181 deletions

View File

@@ -19,59 +19,22 @@ import (
"fmt"
"io"
"maps"
"net"
"os"
"runtime"
"slices"
"strconv"
"syscall"
"time"
)
// Conn is a subset of methods of [net.UnixConn] used by [Context].
// Conn is a low level unix socket interface used by [Context].
type Conn interface {
// ReadMsgUnix reads a message from c, copying the payload into b and
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the message and the source address of the message.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// read (and discard) 1 byte from the connection.
ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *net.UnixAddr, err error)
// Recvmsg calls syscall.Recvmsg on the underlying socket.
Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error)
// WriteMsgUnix writes a message to addr via c, copying the payload
// from b and the associated out-of-band data from oob. It returns the
// number of payload and out-of-band bytes written.
//
// Note that if len(b) == 0 and len(oob) > 0, this function will still
// write 1 byte to the connection.
WriteMsgUnix(b, oob []byte, addr *net.UnixAddr) (n, oobn int, err error)
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
//
// A deadline is an absolute time after which I/O operations
// fail instead of blocking. The deadline applies to all future
// and pending I/O, not just the immediately following call to
// Read or Write. After a deadline has been exceeded, the
// connection can be refreshed by setting a deadline in the future.
//
// If the deadline is exceeded a call to Read or Write or to other
// I/O methods will return an error that wraps os.ErrDeadlineExceeded.
// This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
// The error's Timeout method will return true, but note that there
// are other possible errors for which the Timeout method will
// return true even if the deadline has not been exceeded.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful Read or Write calls.
//
// A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
// Sendmsg calls syscall.SendmsgN on the underlying socket.
Sendmsg(p, oob []byte, flags int) (n int, err error)
// Close closes the connection.
// Any blocked Read or Write operations will be unblocked and return errors.
Close() error
}
@@ -155,6 +118,45 @@ func New(conn Conn, props SPADict) (*Context, error) {
return &ctx, nil
}
// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer].
type SyscallConnCloser interface {
syscall.Conn
io.Closer
}
// A SyscallConn is a [Conn] adapter for [syscall.Conn].
type SyscallConn struct{ SyscallConnCloser }
// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return
}
// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return
}
// MustNew calls [New](conn, props) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustNew(conn Conn, props SPADict) *Context {
@@ -226,66 +228,102 @@ increment:
return newId
}
// connTimeout is the maximum duration an I/O operation is allowed for [Conn].
const connTimeout = 5 * time.Second
// receiveAll receives from conn until no more data is available.
// The returned slice is valid until the next call to receiveAll.
func (ctx *Context) receiveAll() (payload []byte, err error) {
if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
return
// closeReceivedFiles closes all receivedFiles. This is only during protocol error
// where [Context] is rendered unusable.
func (ctx *Context) closeReceivedFiles() {
slices.Sort(ctx.receivedFiles)
ctx.receivedFiles = slices.Compact(ctx.receivedFiles)
for _, fd := range ctx.receivedFiles {
_ = syscall.Close(fd)
}
var n, oobn int
ctx.receivedFiles = ctx.receivedFiles[:0]
buf := ctx.iovecBuf[:]
}
recvmsg:
buf = buf[n:]
n, oobn, _, _, err = ctx.conn.ReadMsgUnix(buf, ctx.oobBuf[:])
if err != nil {
return
}
if oobn == len(ctx.oobBuf) {
return nil, syscall.ENOMEM // unreachable
}
if oob := ctx.oobBuf[:oobn]; len(oob) > 0 {
var scm []syscall.SocketControlMessage
if scm, err = syscall.ParseSocketControlMessage(oob); err != nil {
return
}
// recvmsgFlags are flags passed to [Conn.Recvmsg] during Context.recvmsg.
const recvmsgFlags = syscall.MSG_CMSG_CLOEXEC | syscall.MSG_DONTWAIT
var fds []int
for i := range scm {
if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil {
// recvmsg receives from conn and returns the received payload backed by
// iovecBuf. The returned slice is valid until the next call to recvmsg.
func (ctx *Context) recvmsg(remaining []byte) (payload []byte, err error) {
if copy(ctx.iovecBuf[:], remaining) != len(remaining) {
// should not be reachable with correct internal usage
return remaining, syscall.ENOMEM
}
var n, oobn, recvflags int
for {
n, oobn, recvflags, err = ctx.conn.Recvmsg(ctx.iovecBuf[len(remaining):], ctx.oobBuf[:], recvmsgFlags)
if oob := ctx.oobBuf[:oobn]; len(oob) > 0 {
var scm []syscall.SocketControlMessage
if scm, err = syscall.ParseSocketControlMessage(oob); err != nil {
ctx.closeReceivedFiles()
return
}
ctx.receivedFiles = append(ctx.receivedFiles, fds...)
var fds []int
for i := range scm {
if fds, err = syscall.ParseUnixRights(&scm[i]); err != nil {
ctx.closeReceivedFiles()
return
}
ctx.receivedFiles = append(ctx.receivedFiles, fds...)
}
}
if recvflags&syscall.MSG_CTRUNC != 0 {
// unreachable
ctx.closeReceivedFiles()
return nil, syscall.ENOMEM
}
if err != nil {
if err == syscall.EINTR {
continue
}
if err != syscall.EAGAIN && err != syscall.EWOULDBLOCK {
ctx.closeReceivedFiles()
return nil, os.NewSyscallError("recvmsg", err)
}
}
break
}
// receive until buffer fills or payload is depleted
if n > 0 {
goto recvmsg
if n >= 0 {
payload = ctx.iovecBuf[:n]
}
data := ctx.iovecBuf[:len(ctx.iovecBuf)-len(buf)]
// avoids copy if payload fits in a single ctx.recvmsgBuf
if payload == nil && len(buf) > 0 {
payload = data
return
}
payload = append(payload, data...)
// this indicates a full ctx.recvmsgBuf
if len(buf) == 0 {
ctx.buf = ctx.iovecBuf[:]
goto recvmsg
}
return
}
// sendmsgFlags are flags passed to [Conn.Sendmsg] during Context.sendmsg.
const sendmsgFlags = syscall.MSG_NOSIGNAL | syscall.MSG_DONTWAIT
// sendmsg sends p to conn. sendmsg does not retain p.
func (ctx *Context) sendmsg(p []byte, fds ...int) error {
var oob []byte
if len(fds) > 0 {
oob = syscall.UnixRights(fds...)
}
for {
n, err := ctx.conn.Sendmsg(p, oob, sendmsgFlags)
if err == syscall.EINTR {
continue
}
if err == nil && n != len(p) {
err = syscall.EMSGSIZE
}
if err != nil && err != syscall.EAGAIN && err != syscall.EWOULDBLOCK {
return os.NewSyscallError("sendmsg", err)
}
return err
}
}
// An UnknownIdError describes a server message with an Id unknown to [Context].
type UnknownIdError struct {
// Offending id decoded from Data.
@@ -427,7 +465,7 @@ func (e UnexpectedSequenceError) Error() string { return "unexpected seq " + str
type UnexpectedFilesError int
func (e UnexpectedFilesError) Error() string {
return "server message headers claim to have sent more than " + strconv.Itoa(int(e)) + " files"
return "server message headers claim to have sent more files than actually received"
}
// A DanglingFilesError holds onto files that were sent by the server but no [Header]
@@ -486,39 +524,49 @@ func (e ProxyConsumeError) Error() string {
// roundtripSyncID is the id passed to Context.coreSync during a [Context.Roundtrip].
const roundtripSyncID = 0
// Roundtrip queues the [CoreSync] message and sends all pending messages to the server.
// Roundtrip sends all pending messages to the server and processes events until
// the server has no more messages.
//
// For a non-nil error, if the error happens over the network, it has concrete type
// [net.OpError].
// [os.SyscallError].
func (ctx *Context) Roundtrip() (err error) {
if err = ctx.conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
return
}
var unixRightsOob []byte
if len(ctx.pendingFiles) > 0 {
unixRightsOob = syscall.UnixRights(ctx.pendingFiles...)
}
if _, _, err = ctx.conn.WriteMsgUnix(ctx.buf, unixRightsOob, nil); err != nil {
if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil {
return
}
var remaining []byte
for {
remaining, err = ctx.roundtrip(remaining)
if err == nil {
continue
}
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
if len(remaining) == 0 {
err = nil
} else if len(remaining) < SizeHeader {
err = ErrRoundtripEOFHeader
} else {
err = ErrRoundtripEOFBody
}
}
return
}
}
// roundtrip receives messages from the server and processes events.
func (ctx *Context) roundtrip(receiveRemaining []byte) (remaining []byte, err error) {
var (
// this holds onto non-protocol errors encountered during event handling;
// errors that prevent event processing from continuing must be panicked
proxyErrors ProxyConsumeError
// current position of processed events in ctx.receivedFiles, anything
// beyond this is closed if event processing is terminated
receivedHeaderFiles int
)
defer func() {
// anything before this has already been processed and must not be closed
// here, as anything holding onto them will end up with a dangling fd that
// can be reused and cause serious problems
if len(ctx.receivedFiles) > receivedHeaderFiles {
for _, fd := range ctx.receivedFiles[receivedHeaderFiles:] {
_ = syscall.Close(fd)
}
if len(ctx.receivedFiles) > 0 {
ctx.closeReceivedFiles()
// this catches cases where Roundtrip somehow returns without processing
// all received files or preparing an error for dangling files, this is
@@ -551,48 +599,47 @@ func (ctx *Context) Roundtrip() (err error) {
ctx.pendingFiles = ctx.pendingFiles[:0]
ctx.headerFiles = 0
var data []byte
if data, err = ctx.receiveAll(); err != nil {
if remaining, err = ctx.recvmsg(receiveRemaining); err != nil {
return
}
var header Header
for len(data) > 0 {
if len(data) < SizeHeader {
return ErrRoundtripEOFHeader
for len(remaining) > 0 {
if len(remaining) < SizeHeader {
return
}
if err = header.UnmarshalBinary(data[:SizeHeader]); err != nil {
if err = header.UnmarshalBinary(remaining[:SizeHeader]); err != nil {
return
}
if header.Sequence != ctx.remoteSequence {
return UnexpectedSequenceError(header.Sequence)
return remaining, UnexpectedSequenceError(header.Sequence)
}
ctx.remoteSequence++
if len(data) < int(SizeHeader+header.Size) {
return ErrRoundtripEOFBody
if len(remaining) < int(SizeHeader+header.Size) {
return
}
proxy, ok := ctx.proxy[header.ID]
if !ok {
return &UnknownIdError{header.ID, string(data[:SizeHeader+header.Size])}
return remaining, &UnknownIdError{header.ID, string(remaining[:SizeHeader+header.Size])}
}
nextReceivedHeaderFiles := receivedHeaderFiles + int(header.FileCount)
if nextReceivedHeaderFiles > len(ctx.receivedFiles) {
return UnexpectedFilesError(len(ctx.receivedFiles))
fileCount := int(header.FileCount)
if fileCount > len(ctx.receivedFiles) {
return remaining, UnexpectedFilesError(fileCount)
}
files := ctx.receivedFiles[receivedHeaderFiles:nextReceivedHeaderFiles]
receivedHeaderFiles = nextReceivedHeaderFiles
files := ctx.receivedFiles[:fileCount]
ctx.receivedFiles = ctx.receivedFiles[fileCount:]
data = data[SizeHeader:]
remaining = remaining[SizeHeader:]
proxyErr := proxy.consume(header.Opcode, files, func(v any) {
if unmarshalErr := ctx.unmarshal(&header, data, v); unmarshalErr != nil {
if unmarshalErr := ctx.unmarshal(&header, remaining, v); unmarshalErr != nil {
panic(unmarshalErr)
}
})
data = data[header.Size:]
remaining = remaining[header.Size:]
if proxyErr != nil {
proxyErrors = append(proxyErrors, proxyErr)
}
@@ -601,36 +648,36 @@ func (ctx *Context) Roundtrip() (err error) {
// prepared here so finalizers are set up, but should not prevent proxyErrors
// from reaching the caller as those describe the cause of these dangling fds
var danglingFiles DanglingFilesError
if len(ctx.receivedFiles) > receivedHeaderFiles {
danglingFds := ctx.receivedFiles[receivedHeaderFiles:]
if len(ctx.receivedFiles) > 0 {
// having multiple *os.File with the same fd causes serious problems
slices.Sort(danglingFds)
danglingFds = slices.Compact(danglingFds)
slices.Sort(ctx.receivedFiles)
ctx.receivedFiles = slices.Compact(ctx.receivedFiles)
danglingFiles = make(DanglingFilesError, 0, len(danglingFds))
for _, fd := range danglingFds {
danglingFiles = make(DanglingFilesError, 0, len(ctx.receivedFiles))
for _, fd := range ctx.receivedFiles {
// hold these as *os.File so they are closed if this error never reaches the caller,
// or the caller discards or otherwise does not handle this error, to avoid leaking fds
danglingFiles = append(danglingFiles, os.NewFile(uintptr(fd),
"dangling fd "+strconv.Itoa(fd)+" received from PipeWire"))
}
ctx.receivedFiles = ctx.receivedFiles[:0]
}
// these are checked and made available first since they describe the cause
// of so-called symptoms checked after this point; the symptoms should only
// be made available as a catch-all if these are unavailable
if len(proxyErrors) > 0 {
return proxyErrors
return remaining, proxyErrors
}
// populated early for finalizers
if len(danglingFiles) > 0 {
return danglingFiles
return remaining, danglingFiles
}
// this check must happen after everything else passes
if len(ctx.pendingIds) != 0 {
return UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds)))
return remaining, UnacknowledgedProxyError(slices.Collect(maps.Keys(ctx.pendingIds)))
}
return
}