diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index d190eb6..53077da 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -27,10 +27,16 @@ import ( "strconv" "strings" "syscall" + "time" ) // Conn is a low level unix socket interface used by [Context]. type Conn interface { + // MightBlock informs the implementation that the next call to + // Recvmsg or Sendmsg might block. A zero or negative timeout + // cancels this behaviour. + MightBlock(timeout time.Duration) + // Recvmsg calls syscall.Recvmsg on the underlying socket. Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) @@ -145,17 +151,117 @@ type SyscallConnCloser interface { } // A SyscallConn is a [Conn] adapter for [syscall.Conn]. -type SyscallConn struct{ SyscallConnCloser } +type SyscallConn struct { + // Whether creation of a new epoll instance was attempted. + epoll bool + // File descriptor referring to the new epoll instance. + // Valid if epoll is true and epollErr is nil. + epollFd int + // Error returned by syscall.EpollCreate1. + epollErr error + // Stores epoll events from the kernel. + epollBuf [32]syscall.EpollEvent + + // If non-zero, next call is treated as a blocking call. + timeout time.Duration + + SyscallConnCloser +} + +// MightBlock implements [Conn.MightBlock]. +func (conn *SyscallConn) MightBlock(timeout time.Duration) { + if timeout < 0 { + timeout = 0 + } + conn.timeout = timeout +} + +// wantsEpoll is called at the beginning of any method that wants to use epoll. +func (conn *SyscallConn) wantsEpoll() error { + if !conn.epoll { + conn.epoll = true + conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) + } + return conn.epollErr +} + +// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn] +// and returns a function that should be deferred by the caller regardless of error. +func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) { + if conn.timeout == 0 { + return func() {} + } + deadline := time.Now().Add(conn.timeout) + conn.timeout = 0 + + if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{ + Events: event | syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(fd), + }); *errP != nil { + return func() {} + } else { + cleanupFunc = func() { + // fd is guaranteed to remain valid while f executes but not after f returns + if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil { + *errP = epDelErr + return + } + } + } + + for timeout := deadline.Sub(time.Now()); timeout > 0; timeout = deadline.Sub(time.Now()) { + if n, err := syscall.EpollWait(conn.epollFd, conn.epollBuf[:], int(timeout/time.Millisecond)); err != nil { + *errP = err + return + } else { + switch n { + case 1: // only the socket fd is ever added + if conn.epollBuf[0].Fd != int32(fd) { // unreachable + err = syscall.ENOTRECOVERABLE + break + } + if conn.epollBuf[0].Events&event == event || + conn.epollBuf[0].Events&syscall.EPOLLERR|syscall.EPOLLHUP != 0 { + break + } + *errP = syscall.ETIME + continue + + case 0: // timeout + err = syscall.ETIMEDOUT + break + + default: // unreachable + err = syscall.ENOTRECOVERABLE + break + } + + *errP = err + break + } + } + return +} // Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn]. -func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { +func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { + if err = conn.wantsEpoll(); err != nil { + return + } + var rc syscall.RawConn if rc, err = conn.SyscallConn(); err != nil { return } if controlErr := rc.Control(func(fd uintptr) { + defer conn.wait(int(fd), syscall.EPOLLIN, &err)() + if err != nil { + return + } + n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags) + }); controlErr != nil && err == nil { err = controlErr } @@ -163,13 +269,22 @@ func (conn SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags in } // Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn]. -func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { +func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { + if err = conn.wantsEpoll(); err != nil { + return + } + var rc syscall.RawConn if rc, err = conn.SyscallConn(); err != nil { return } if controlErr := rc.Control(func(fd uintptr) { + defer conn.wait(int(fd), syscall.EPOLLOUT, &err)() + if err != nil { + return + } + n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags) }); controlErr != nil && err == nil { err = controlErr @@ -177,6 +292,18 @@ func (conn SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { return } +// Close implements [Conn.Close] via [syscall.Conn.Close] but also +// closes the epoll fd if populated. +func (conn *SyscallConn) Close() (err error) { + if conn.epoll && conn.epollErr == nil { + conn.epollErr = syscall.Close(conn.epollFd) + } + if err = conn.SyscallConnCloser.Close(); err != nil { + return + } + return conn.epollErr +} + // MustNew calls [New](conn, props) and panics on error. // It is intended for use in tests with hard-coded strings. func MustNew(conn Conn, props SPADict) *Context { @@ -598,8 +725,15 @@ func (ctx *Context) Roundtrip() (err error) { return } +const ( + // roundtripTimeout is the maximum duration socket operations during + // Context.roundtrip is allowed to block for. + roundtripTimeout = 5 * time.Second +) + // roundtrip implements the Roundtrip method without checking proxyErrors. func (ctx *Context) roundtrip() (err error) { + ctx.conn.MightBlock(roundtripTimeout) if err = ctx.sendmsg(ctx.buf, ctx.pendingFiles...); err != nil { return } @@ -633,6 +767,7 @@ func (ctx *Context) roundtrip() (err error) { }() var remaining []byte + ctx.conn.MightBlock(roundtripTimeout) for { remaining, err = ctx.consume(remaining) if err == nil { @@ -897,12 +1032,13 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er } } - var conn *net.UnixConn - if conn, err = connectName(name, manager); err != nil { + var unixConn *net.UnixConn + if unixConn, err = connectName(name, manager); err != nil { return } - if ctx, err = New(SyscallConn{conn}, props); err != nil { + conn := &SyscallConn{SyscallConnCloser: unixConn} + if ctx, err = New(conn, props); err != nil { ctx = nil _ = conn.Close() } diff --git a/internal/pipewire/pipewire_test.go b/internal/pipewire/pipewire_test.go index 0b65b86..9305e7e 100644 --- a/internal/pipewire/pipewire_test.go +++ b/internal/pipewire/pipewire_test.go @@ -6,6 +6,7 @@ import ( "strconv" . "syscall" "testing" + "time" "hakurei.app/container/stub" "hakurei.app/internal/pipewire" @@ -715,6 +716,18 @@ type stubUnixConn struct { current int } +func (conn *stubUnixConn) MightBlock(timeout time.Duration) { + if timeout != 5*time.Second { + panic("unexpected timeout " + timeout.String()) + } + if conn.current == 0 || + (conn.samples[conn.current-1].nr == SYS_RECVMSG && conn.samples[conn.current-1].errno == EAGAIN && conn.samples[conn.current].nr == SYS_SENDMSG) || + (conn.samples[conn.current-1].nr == SYS_SENDMSG && conn.samples[conn.current].nr == SYS_RECVMSG) { + return + } + panic("unexpected blocking hint before sample " + strconv.Itoa(conn.current)) +} + // nextSample returns the current sample and increments the counter. func (conn *stubUnixConn) nextSample(nr uintptr) (sample *stubUnixConnSample, wantOOB []byte, err error) { sample = &conn.samples[conn.current] diff --git a/internal/system/pipewire_test.go b/internal/system/pipewire_test.go index 9c4950c..a39237f 100644 --- a/internal/system/pipewire_test.go +++ b/internal/system/pipewire_test.go @@ -6,6 +6,7 @@ import ( "path" "syscall" "testing" + "time" "hakurei.app/container/stub" "hakurei.app/internal/acl" @@ -497,6 +498,12 @@ type stubPipeWireConn struct { curSendmsg int } +func (conn *stubPipeWireConn) MightBlock(timeout time.Duration) { + if timeout != 5*time.Second { + panic("unexpected timeout " + timeout.String()) + } +} + // Recvmsg marshals and copies a stubMessage prepared ahead of time. func (conn *stubPipeWireConn) Recvmsg(p, _ []byte, _ int) (n, _, recvflags int, err error) { defer func() { conn.curRecvmsg++ }()