internal/pipewire: inform conn of blocking intent
All checks were successful
Test / Create distribution (push) Successful in 30s
Test / Sandbox (push) Successful in 2m31s
Test / Hakurei (push) Successful in 3m29s
Test / Hpkg (push) Successful in 4m24s
Test / Sandbox (race detector) (push) Successful in 4m30s
Test / Hakurei (race detector) (push) Successful in 5m27s
Test / Flake checks (push) Successful in 1m40s

The interface does not expose underlying kernel notification mechanisms. This change removes the need to poll in situations were the next call might block.

This is made cumbersome by the SyscallConn interface left over from a previous implementation, it will be replaced in a later commit as the current implementation does not make use of any net.Conn methods other than Close.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-12-18 23:51:47 +09:00
parent 08bdc68f3a
commit ce06b7b663
3 changed files with 162 additions and 6 deletions

View File

@@ -27,10 +27,16 @@ import (
"strconv"
"strings"
"syscall"
"time"
)
// Conn is a low level unix socket interface used by [Context].
type Conn interface {
// MightBlock informs the implementation that the next call to
// Recvmsg or Sendmsg might block. A zero or negative timeout
// cancels this behaviour.
MightBlock(timeout time.Duration)
// Recvmsg calls syscall.Recvmsg on the underlying socket.
Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error)
@@ -145,17 +151,117 @@ type SyscallConnCloser interface {
}
// A SyscallConn is a [Conn] adapter for [syscall.Conn].
type SyscallConn struct{ SyscallConnCloser }
type SyscallConn struct {
// Whether creation of a new epoll instance was attempted.
epoll bool
// File descriptor referring to the new epoll instance.
// Valid if epoll is true and epollErr is nil.
epollFd int
// Error returned by syscall.EpollCreate1.
epollErr error
// Stores epoll events from the kernel.
epollBuf [32]syscall.EpollEvent
// If non-zero, next call is treated as a blocking call.
timeout time.Duration
SyscallConnCloser
}
// MightBlock implements [Conn.MightBlock].
func (conn *SyscallConn) MightBlock(timeout time.Duration) {
if timeout < 0 {
timeout = 0
}
conn.timeout = timeout
}
// wantsEpoll is called at the beginning of any method that wants to use epoll.
func (conn *SyscallConn) wantsEpoll() error {
if !conn.epoll {
conn.epoll = true
conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
}
return conn.epollErr
}
// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn]
// and returns a function that should be deferred by the caller regardless of error.
func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) {
if conn.timeout == 0 {
return func() {}
}
deadline := time.Now().Add(conn.timeout)
conn.timeout = 0
if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{
Events: event | syscall.EPOLLERR | syscall.EPOLLHUP,
Fd: int32(fd),
}); *errP != nil {
return func() {}
} else {
cleanupFunc = func() {
// fd is guaranteed to remain valid while f executes but not after f returns
if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil {
*errP = epDelErr
return
}
}
}
for timeout := deadline.Sub(time.Now()); timeout > 0; timeout = deadline.Sub(time.Now()) {
if n, err := syscall.EpollWait(conn.epollFd, conn.epollBuf[:], int(timeout/time.Millisecond)); err != nil {
*errP = err
return
} else {
switch n {
case 1: // only the socket fd is ever added
if conn.epollBuf[0].Fd != int32(fd) { // unreachable
err = syscall.ENOTRECOVERABLE
break
}
if conn.epollBuf[0].Events&event == event ||
conn.epollBuf[0].Events&syscall.EPOLLERR|syscall.EPOLLHUP != 0 {
break
}
*errP = syscall.ETIME
continue
case 0: // timeout
err = syscall.ETIMEDOUT
break
default: // unreachable
err = syscall.ENOTRECOVERABLE
break
}
*errP = err
break
}
}
return
}
// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
if err = conn.wantsEpoll(); err != nil {
return
}
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLIN, &err)()
if err != nil {
return
}
n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
@@ -163,13 +269,22 @@ func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags in
}
// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn].
func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
if err = conn.wantsEpoll(); err != nil {
return
}
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLOUT, &err)()
if err != nil {
return
}
n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags)
}); controlErr != nil && err == nil {
err = controlErr
@@ -177,6 +292,18 @@ func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
return
}
// Close implements [Conn.Close] via [syscall.Conn.Close] but also
// closes the epoll fd if populated.
func (conn *SyscallConn) Close() (err error) {
if conn.epoll && conn.epollErr == nil {
conn.epollErr = syscall.Close(conn.epollFd)
}
if err = conn.SyscallConnCloser.Close(); err != nil {
return
}
return conn.epollErr
}
// MustNew calls [New](conn, props) and panics on error.
// It is intended for use in tests with hard-coded strings.
func MustNew(conn Conn, props SPADict) *Context {
@@ -598,8 +725,15 @@ func (ctx *Context) Roundtrip() (err error) {
return
}
const (
// roundtripTimeout is the maximum duration socket operations during
// Context.roundtrip is allowed to block for.
roundtripTimeout = 5 * time.Second
)
// roundtrip implements the Roundtrip method without checking proxyErrors.
func (ctx *Context) roundtrip() (err error) {
ctx.conn.MightBlock(roundtripTimeout)
if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil {
return
}
@@ -633,6 +767,7 @@ func (ctx *Context) roundtrip() (err error) {
}()
var remaining []byte
ctx.conn.MightBlock(roundtripTimeout)
for {
remaining, err = ctx.consume(remaining)
if err == nil {
@@ -897,12 +1032,13 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er
}
}
var conn *net.UnixConn
if conn, err = connectName(name, manager); err != nil {
var unixConn *net.UnixConn
if unixConn, err = connectName(name, manager); err != nil {
return
}
if ctx, err = New(SyscallConn{conn}, props); err != nil {
conn := &SyscallConn{SyscallConnCloser: unixConn}
if ctx, err = New(conn, props); err != nil {
ctx = nil
_ = conn.Close()
}

View File

@@ -6,6 +6,7 @@ import (
"strconv"
. "syscall"
"testing"
"time"
"hakurei.app/container/stub"
"hakurei.app/internal/pipewire"
@@ -715,6 +716,18 @@ type stubUnixConn struct {
current int
}
func (conn *stubUnixConn) MightBlock(timeout time.Duration) {
if timeout != 5*time.Second {
panic("unexpected timeout " + timeout.String())
}
if conn.current == 0 ||
(conn.samples[conn.current-1].nr == SYS_RECVMSG && conn.samples[conn.current-1].errno == EAGAIN && conn.samples[conn.current].nr == SYS_SENDMSG) ||
(conn.samples[conn.current-1].nr == SYS_SENDMSG && conn.samples[conn.current].nr == SYS_RECVMSG) {
return
}
panic("unexpected blocking hint before sample " + strconv.Itoa(conn.current))
}
// nextSample returns the current sample and increments the counter.
func (conn *stubUnixConn) nextSample(nr uintptr) (sample *stubUnixConnSample, wantOOB []byte, err error) {
sample = &conn.samples[conn.current]

View File

@@ -6,6 +6,7 @@ import (
"path"
"syscall"
"testing"
"time"
"hakurei.app/container/stub"
"hakurei.app/internal/acl"
@@ -497,6 +498,12 @@ type stubPipeWireConn struct {
curSendmsg int
}
func (conn *stubPipeWireConn) MightBlock(timeout time.Duration) {
if timeout != 5*time.Second {
panic("unexpected timeout " + timeout.String())
}
}
// Recvmsg marshals and copies a stubMessage prepared ahead of time.
func (conn *stubPipeWireConn) Recvmsg(p, _ []byte, _ int) (n, _, recvflags int, err error) {
defer func() { conn.curRecvmsg++ }()