From 5a50bf80ee08f68f6edd7015aab987d7614c0707 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Fri, 19 Dec 2025 00:28:02 +0900 Subject: [PATCH] internal/pipewire: hold socket fd directly The interface provided by net is not used here and is a leftover from a previous implementation. This change removes it. Signed-off-by: Ophestra --- internal/pipewire/pipewire.go | 111 ++++++++++++++-------------------- 1 file changed, 47 insertions(+), 64 deletions(-) diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index 53077da..9af142b 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "io" - "net" "os" "path" "runtime" @@ -144,14 +143,11 @@ func New(conn Conn, props SPADict) (*Context, error) { return &ctx, nil } -// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer]. -type SyscallConnCloser interface { - syscall.Conn - io.Closer -} +// unixConn is an implementation of the [Conn] interface for connections +// to Unix domain sockets. +type unixConn struct { + fd int -// A SyscallConn is a [Conn] adapter for [syscall.Conn]. -type SyscallConn struct { // Whether creation of a new epoll instance was attempted. epoll bool // File descriptor referring to the new epoll instance. @@ -164,20 +160,31 @@ type SyscallConn struct { // If non-zero, next call is treated as a blocking call. timeout time.Duration - - SyscallConnCloser } -// MightBlock implements [Conn.MightBlock]. -func (conn *SyscallConn) MightBlock(timeout time.Duration) { +// Dial connects to a Unix domain socket described by name. +func Dial(name string) (Conn, error) { + if fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0); err != nil { + return nil, os.NewSyscallError("socket", err) + } else if err = syscall.Connect(fd, &syscall.SockaddrUnix{Name: name}); err != nil { + _ = syscall.Close(fd) + return nil, os.NewSyscallError("connect", err) + } else { + return &unixConn{fd: fd}, nil + } +} + +// MightBlock informs the implementation that the next call +// might block for a non-zero timeout. +func (conn *unixConn) MightBlock(timeout time.Duration) { if timeout < 0 { timeout = 0 } conn.timeout = timeout } -// wantsEpoll is called at the beginning of any method that wants to use epoll. -func (conn *SyscallConn) wantsEpoll() error { +// wantsEpoll is called at the beginning of any method that might use epoll. +func (conn *unixConn) wantsEpoll() error { if !conn.epoll { conn.epoll = true conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) @@ -185,24 +192,24 @@ func (conn *SyscallConn) wantsEpoll() error { return conn.epollErr } -// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn] -// and returns a function that should be deferred by the caller regardless of error. -func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) { +// wait waits for a specific I/O event on fd and returns a function +// that must be deferred by the caller regardless of error. +func (conn *unixConn) wait(event uint32, errP *error) (cleanupFunc func()) { if conn.timeout == 0 { return func() {} } deadline := time.Now().Add(conn.timeout) conn.timeout = 0 - if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{ + if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, conn.fd, &syscall.EpollEvent{ Events: event | syscall.EPOLLERR | syscall.EPOLLHUP, - Fd: int32(fd), + Fd: int32(conn.fd), }); *errP != nil { return func() {} } else { cleanupFunc = func() { // fd is guaranteed to remain valid while f executes but not after f returns - if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil { + if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, conn.fd, nil); epDelErr != nil && *errP == nil { *errP = epDelErr return } @@ -216,7 +223,7 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu } else { switch n { case 1: // only the socket fd is ever added - if conn.epollBuf[0].Fd != int32(fd) { // unreachable + if conn.epollBuf[0].Fd != int32(conn.fd) { // unreachable err = syscall.ENOTRECOVERABLE break } @@ -243,62 +250,40 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu return } -// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn]. -func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { +// Recvmsg calls syscall.Recvmsg on the underlying socket. +func (conn *unixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { if err = conn.wantsEpoll(); err != nil { return } - - var rc syscall.RawConn - if rc, err = conn.SyscallConn(); err != nil { + defer conn.wait(syscall.EPOLLIN, &err)() + if err != nil { return } - if controlErr := rc.Control(func(fd uintptr) { - defer conn.wait(int(fd), syscall.EPOLLIN, &err)() - if err != nil { - return - } - - n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags) - - }); controlErr != nil && err == nil { - err = controlErr - } + n, oobn, recvflags, _, err = syscall.Recvmsg(conn.fd, p, oob, flags) return } -// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn]. -func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { +// Sendmsg calls syscall.Sendmsg on the underlying socket. +func (conn *unixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { if err = conn.wantsEpoll(); err != nil { return } - - var rc syscall.RawConn - if rc, err = conn.SyscallConn(); err != nil { + defer conn.wait(syscall.EPOLLOUT, &err)() + if err != nil { return } - if controlErr := rc.Control(func(fd uintptr) { - defer conn.wait(int(fd), syscall.EPOLLOUT, &err)() - if err != nil { - return - } - - n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags) - }); controlErr != nil && err == nil { - err = controlErr - } + n, err = syscall.SendmsgN(conn.fd, p, oob, nil, flags) return } -// Close implements [Conn.Close] via [syscall.Conn.Close] but also -// closes the epoll fd if populated. -func (conn *SyscallConn) Close() (err error) { +// Close closes the underlying socket and the epoll fd if populated. +func (conn *unixConn) Close() (err error) { if conn.epoll && conn.epollErr == nil { conn.epollErr = syscall.Close(conn.epollFd) } - if err = conn.SyscallConnCloser.Close(); err != nil { + if err = syscall.Close(conn.fd); err != nil { return } return conn.epollErr @@ -992,14 +977,14 @@ const Remote = "PIPEWIRE_REMOTE" const DEFAULT_SYSTEM_RUNTIME_DIR = "/run/pipewire" -// connectName connects to a PipeWire remote by name and returns the [net.UnixConn]. -func connectName(name string, manager bool) (conn *net.UnixConn, err error) { +// connectName connects to a PipeWire remote by name and returns the resulting [Conn]. +func connectName(name string, manager bool) (conn Conn, err error) { if manager && !strings.HasSuffix(name, "-manager") { return connectName(name+"-manager", false) } if path.IsAbs(name) || (len(name) > 0 && name[0] == '@') { - return net.DialUnix("unix", nil, &net.UnixAddr{Name: name, Net: "unix"}) + return Dial(name) } else { runtimeDir, ok := os.LookupEnv("PIPEWIRE_RUNTIME_DIR") if !ok || !path.IsAbs(runtimeDir) { @@ -1014,7 +999,7 @@ func connectName(name string, manager bool) (conn *net.UnixConn, err error) { if !ok || !path.IsAbs(runtimeDir) { runtimeDir = DEFAULT_SYSTEM_RUNTIME_DIR } - return net.DialUnix("unix", nil, &net.UnixAddr{Name: path.Join(runtimeDir, name), Net: "unix"}) + return Dial(path.Join(runtimeDir, name)) } } @@ -1032,12 +1017,10 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er } } - var unixConn *net.UnixConn - if unixConn, err = connectName(name, manager); err != nil { + var conn Conn + if conn, err = connectName(name, manager); err != nil { return } - - conn := &SyscallConn{SyscallConnCloser: unixConn} if ctx, err = New(conn, props); err != nil { ctx = nil _ = conn.Close()