From ce06b7b66386f592df26ca089b470cc5b40ecd39 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Thu, 18 Dec 2025 23:51:47 +0900 Subject: [PATCH] internal/pipewire: inform conn of blocking intent The interface does not expose underlying kernel notification mechanisms. This change removes the need to poll in situations were the next call might block. This is made cumbersome by the SyscallConn interface left over from a previous implementation, it will be replaced in a later commit as the current implementation does not make use of any net.Conn methods other than Close. Signed-off-by: Ophestra --- internal/pipewire/pipewire.go | 148 +++++++++++++++++++++++++++-- internal/pipewire/pipewire_test.go | 13 +++ internal/system/pipewire_test.go | 7 ++ 3 files changed, 162 insertions(+), 6 deletions(-) diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index d190eb6..53077da 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -27,10 +27,16 @@ import ( "strconv" "strings" "syscall" + "time" ) // Conn is a low level unix socket interface used by [Context]. type Conn interface { + // MightBlock informs the implementation that the next call to + // Recvmsg or Sendmsg might block. A zero or negative timeout + // cancels this behaviour. + MightBlock(timeout time.Duration) + // Recvmsg calls syscall.Recvmsg on the underlying socket. Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) @@ -145,17 +151,117 @@ type SyscallConnCloser interface { } // A SyscallConn is a [Conn] adapter for [syscall.Conn]. -type SyscallConn struct{ SyscallConnCloser } +type SyscallConn struct { + // Whether creation of a new epoll instance was attempted. + epoll bool + // File descriptor referring to the new epoll instance. + // Valid if epoll is true and epollErr is nil. + epollFd int + // Error returned by syscall.EpollCreate1. + epollErr error + // Stores epoll events from the kernel. + epollBuf [32]syscall.EpollEvent + + // If non-zero, next call is treated as a blocking call. + timeout time.Duration + + SyscallConnCloser +} + +// MightBlock implements [Conn.MightBlock]. +func (conn *SyscallConn) MightBlock(timeout time.Duration) { + if timeout < 0 { + timeout = 0 + } + conn.timeout = timeout +} + +// wantsEpoll is called at the beginning of any method that wants to use epoll. +func (conn *SyscallConn) wantsEpoll() error { + if !conn.epoll { + conn.epoll = true + conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) + } + return conn.epollErr +} + +// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn] +// and returns a function that should be deferred by the caller regardless of error. +func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) { + if conn.timeout == 0 { + return func() {} + } + deadline := time.Now().Add(conn.timeout) + conn.timeout = 0 + + if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{ + Events: event | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + }); *errP != nil { + return func() {} + } else { + cleanupFunc = func() { + // fd is guaranteed to remain valid while f executes but not after f returns + if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil { + *errP = epDelErr + return + } + } + } + + for timeout := deadline.Sub(time.Now()); timeout > 0; timeout = deadline.Sub(time.Now()) { + if n, err := syscall.EpollWait(conn.epollFd, conn.epollBuf[:], int(timeout/time.Millisecond)); err != nil { + *errP = err + return + } else { + switch n { + case 1: // only the socket fd is ever added + if conn.epollBuf[0].Fd != int32(fd) { // unreachable + err = syscall.ENOTRECOVERABLE + break + } + if conn.epollBuf[0].Events&event == event || + conn.epollBuf[0].Events&syscall.EPOLLERR|syscall.EPOLLHUP != 0 { + break + } + *errP = syscall.ETIME + continue + + case 0: // timeout + err = syscall.ETIMEDOUT + break + + default: // unreachable + err = syscall.ENOTRECOVERABLE + break + } + + *errP = err + break + } + } + return +} // Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn]. -func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { +func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { + if err = conn.wantsEpoll(); err != nil { + return + } + var rc syscall.RawConn if rc, err = conn.SyscallConn(); err != nil { return } if controlErr := rc.Control(func(fd uintptr) { + defer conn.wait(int(fd), syscall.EPOLLIN, &err)() + if err != nil { + return + } + n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags) + }); controlErr != nil && err == nil { err = controlErr } @@ -163,13 +269,22 @@ func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags in } // Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn]. -func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { +func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { + if err = conn.wantsEpoll(); err != nil { + return + } + var rc syscall.RawConn if rc, err = conn.SyscallConn(); err != nil { return } if controlErr := rc.Control(func(fd uintptr) { + defer conn.wait(int(fd), syscall.EPOLLOUT, &err)() + if err != nil { + return + } + n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags) }); controlErr != nil && err == nil { err = controlErr @@ -177,6 +292,18 @@ func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { return } +// Close implements [Conn.Close] via [syscall.Conn.Close] but also +// closes the epoll fd if populated. +func (conn *SyscallConn) Close() (err error) { + if conn.epoll && conn.epollErr == nil { + conn.epollErr = syscall.Close(conn.epollFd) + } + if err = conn.SyscallConnCloser.Close(); err != nil { + return + } + return conn.epollErr +} + // MustNew calls [New](conn, props) and panics on error. // It is intended for use in tests with hard-coded strings. func MustNew(conn Conn, props SPADict) *Context { @@ -598,8 +725,15 @@ func (ctx *Context) Roundtrip() (err error) { return } +const ( + // roundtripTimeout is the maximum duration socket operations during + // Context.roundtrip is allowed to block for. + roundtripTimeout = 5 * time.Second +) + // roundtrip implements the Roundtrip method without checking proxyErrors. func (ctx *Context) roundtrip() (err error) { + ctx.conn.MightBlock(roundtripTimeout) if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil { return } @@ -633,6 +767,7 @@ func (ctx *Context) roundtrip() (err error) { }() var remaining []byte + ctx.conn.MightBlock(roundtripTimeout) for { remaining, err = ctx.consume(remaining) if err == nil { @@ -897,12 +1032,13 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er } } - var conn *net.UnixConn - if conn, err = connectName(name, manager); err != nil { + var unixConn *net.UnixConn + if unixConn, err = connectName(name, manager); err != nil { return } - if ctx, err = New(SyscallConn{conn}, props); err != nil { + conn := &SyscallConn{SyscallConnCloser: unixConn} + if ctx, err = New(conn, props); err != nil { ctx = nil _ = conn.Close() } diff --git a/internal/pipewire/pipewire_test.go b/internal/pipewire/pipewire_test.go index 0b65b86..9305e7e 100644 --- a/internal/pipewire/pipewire_test.go +++ b/internal/pipewire/pipewire_test.go @@ -6,6 +6,7 @@ import ( "strconv" . "syscall" "testing" + "time" "hakurei.app/container/stub" "hakurei.app/internal/pipewire" @@ -715,6 +716,18 @@ type stubUnixConn struct { current int } +func (conn *stubUnixConn) MightBlock(timeout time.Duration) { + if timeout != 5*time.Second { + panic("unexpected timeout " + timeout.String()) + } + if conn.current == 0 || + (conn.samples[conn.current-1].nr == SYS_RECVMSG && conn.samples[conn.current-1].errno == EAGAIN && conn.samples[conn.current].nr == SYS_SENDMSG) || + (conn.samples[conn.current-1].nr == SYS_SENDMSG && conn.samples[conn.current].nr == SYS_RECVMSG) { + return + } + panic("unexpected blocking hint before sample " + strconv.Itoa(conn.current)) +} + // nextSample returns the current sample and increments the counter. func (conn *stubUnixConn) nextSample(nr uintptr) (sample *stubUnixConnSample, wantOOB []byte, err error) { sample = &conn.samples[conn.current] diff --git a/internal/system/pipewire_test.go b/internal/system/pipewire_test.go index 9c4950c..a39237f 100644 --- a/internal/system/pipewire_test.go +++ b/internal/system/pipewire_test.go @@ -6,6 +6,7 @@ import ( "path" "syscall" "testing" + "time" "hakurei.app/container/stub" "hakurei.app/internal/acl" @@ -497,6 +498,12 @@ type stubPipeWireConn struct { curSendmsg int } +func (conn *stubPipeWireConn) MightBlock(timeout time.Duration) { + if timeout != 5*time.Second { + panic("unexpected timeout " + timeout.String()) + } +} + // Recvmsg marshals and copies a stubMessage prepared ahead of time. func (conn *stubPipeWireConn) Recvmsg(p, _ []byte, _ int) (n, _, recvflags int, err error) { defer func() { conn.curRecvmsg++ }()