internal/pipewire: hold socket fd directly
All checks were successful
Test / Create distribution (push) Successful in 38s
Test / Sandbox (push) Successful in 2m39s
Test / Hakurei (push) Successful in 3m31s
Test / Sandbox (race detector) (push) Successful in 4m30s
Test / Hpkg (push) Successful in 4m34s
Test / Hakurei (race detector) (push) Successful in 5m29s
Test / Flake checks (push) Successful in 1m40s

The interface provided by net is not used here and is a leftover from a previous implementation. This change removes it.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-12-19 00:28:02 +09:00
parent ce06b7b663
commit 5a50bf80ee

View File

@@ -19,7 +19,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"os" "os"
"path" "path"
"runtime" "runtime"
@@ -144,14 +143,11 @@ func New(conn Conn, props SPADict) (*Context, error) {
return &ctx, nil return &ctx, nil
} }
// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer]. // unixConn is an implementation of the [Conn] interface for connections
type SyscallConnCloser interface { // to Unix domain sockets.
syscall.Conn type unixConn struct {
io.Closer fd int
}
// A SyscallConn is a [Conn] adapter for [syscall.Conn].
type SyscallConn struct {
// Whether creation of a new epoll instance was attempted. // Whether creation of a new epoll instance was attempted.
epoll bool epoll bool
// File descriptor referring to the new epoll instance. // File descriptor referring to the new epoll instance.
@@ -164,20 +160,31 @@ type SyscallConn struct {
// If non-zero, next call is treated as a blocking call. // If non-zero, next call is treated as a blocking call.
timeout time.Duration timeout time.Duration
SyscallConnCloser
} }
// MightBlock implements [Conn.MightBlock]. // Dial connects to a Unix domain socket described by name.
func (conn *SyscallConn) MightBlock(timeout time.Duration) { func Dial(name string) (Conn, error) {
if fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0); err != nil {
return nil, os.NewSyscallError("socket", err)
} else if err = syscall.Connect(fd, &syscall.SockaddrUnix{Name: name}); err != nil {
_ = syscall.Close(fd)
return nil, os.NewSyscallError("connect", err)
} else {
return &unixConn{fd: fd}, nil
}
}
// MightBlock informs the implementation that the next call
// might block for a non-zero timeout.
func (conn *unixConn) MightBlock(timeout time.Duration) {
if timeout < 0 { if timeout < 0 {
timeout = 0 timeout = 0
} }
conn.timeout = timeout conn.timeout = timeout
} }
// wantsEpoll is called at the beginning of any method that wants to use epoll. // wantsEpoll is called at the beginning of any method that might use epoll.
func (conn *SyscallConn) wantsEpoll() error { func (conn *unixConn) wantsEpoll() error {
if !conn.epoll { if !conn.epoll {
conn.epoll = true conn.epoll = true
conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
@@ -185,24 +192,24 @@ func (conn *SyscallConn) wantsEpoll() error {
return conn.epollErr return conn.epollErr
} }
// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn] // wait waits for a specific I/O event on fd and returns a function
// and returns a function that should be deferred by the caller regardless of error. // that must be deferred by the caller regardless of error.
func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) { func (conn *unixConn) wait(event uint32, errP *error) (cleanupFunc func()) {
if conn.timeout == 0 { if conn.timeout == 0 {
return func() {} return func() {}
} }
deadline := time.Now().Add(conn.timeout) deadline := time.Now().Add(conn.timeout)
conn.timeout = 0 conn.timeout = 0
if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{ if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, conn.fd, &syscall.EpollEvent{
Events: event | syscall.EPOLLERR | syscall.EPOLLHUP, Events: event | syscall.EPOLLERR | syscall.EPOLLHUP,
Fd: int32(fd), Fd: int32(conn.fd),
}); *errP != nil { }); *errP != nil {
return func() {} return func() {}
} else { } else {
cleanupFunc = func() { cleanupFunc = func() {
// fd is guaranteed to remain valid while f executes but not after f returns // fd is guaranteed to remain valid while f executes but not after f returns
if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil { if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, conn.fd, nil); epDelErr != nil && *errP == nil {
*errP = epDelErr *errP = epDelErr
return return
} }
@@ -216,7 +223,7 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu
} else { } else {
switch n { switch n {
case 1: // only the socket fd is ever added case 1: // only the socket fd is ever added
if conn.epollBuf[0].Fd != int32(fd) { // unreachable if conn.epollBuf[0].Fd != int32(conn.fd) { // unreachable
err = syscall.ENOTRECOVERABLE err = syscall.ENOTRECOVERABLE
break break
} }
@@ -243,62 +250,40 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu
return return
} }
// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn]. // Recvmsg calls syscall.Recvmsg on the underlying socket.
func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { func (conn *unixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
if err = conn.wantsEpoll(); err != nil { if err = conn.wantsEpoll(); err != nil {
return return
} }
defer conn.wait(syscall.EPOLLIN, &err)()
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLIN, &err)()
if err != nil { if err != nil {
return return
} }
n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags) n, oobn, recvflags, _, err = syscall.Recvmsg(conn.fd, p, oob, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return return
} }
// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn]. // Sendmsg calls syscall.Sendmsg on the underlying socket.
func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { func (conn *unixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
if err = conn.wantsEpoll(); err != nil { if err = conn.wantsEpoll(); err != nil {
return return
} }
defer conn.wait(syscall.EPOLLOUT, &err)()
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLOUT, &err)()
if err != nil { if err != nil {
return return
} }
n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags) n, err = syscall.SendmsgN(conn.fd, p, oob, nil, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
return return
} }
// Close implements [Conn.Close] via [syscall.Conn.Close] but also // Close closes the underlying socket and the epoll fd if populated.
// closes the epoll fd if populated. func (conn *unixConn) Close() (err error) {
func (conn *SyscallConn) Close() (err error) {
if conn.epoll && conn.epollErr == nil { if conn.epoll && conn.epollErr == nil {
conn.epollErr = syscall.Close(conn.epollFd) conn.epollErr = syscall.Close(conn.epollFd)
} }
if err = conn.SyscallConnCloser.Close(); err != nil { if err = syscall.Close(conn.fd); err != nil {
return return
} }
return conn.epollErr return conn.epollErr
@@ -992,14 +977,14 @@ const Remote = "PIPEWIRE_REMOTE"
const DEFAULT_SYSTEM_RUNTIME_DIR = "/run/pipewire" const DEFAULT_SYSTEM_RUNTIME_DIR = "/run/pipewire"
// connectName connects to a PipeWire remote by name and returns the [net.UnixConn]. // connectName connects to a PipeWire remote by name and returns the resulting [Conn].
func connectName(name string, manager bool) (conn *net.UnixConn, err error) { func connectName(name string, manager bool) (conn Conn, err error) {
if manager && !strings.HasSuffix(name, "-manager") { if manager && !strings.HasSuffix(name, "-manager") {
return connectName(name+"-manager", false) return connectName(name+"-manager", false)
} }
if path.IsAbs(name) || (len(name) > 0 && name[0] == '@') { if path.IsAbs(name) || (len(name) > 0 && name[0] == '@') {
return net.DialUnix("unix", nil, &net.UnixAddr{Name: name, Net: "unix"}) return Dial(name)
} else { } else {
runtimeDir, ok := os.LookupEnv("PIPEWIRE_RUNTIME_DIR") runtimeDir, ok := os.LookupEnv("PIPEWIRE_RUNTIME_DIR")
if !ok || !path.IsAbs(runtimeDir) { if !ok || !path.IsAbs(runtimeDir) {
@@ -1014,7 +999,7 @@ func connectName(name string, manager bool) (conn *net.UnixConn, err error) {
if !ok || !path.IsAbs(runtimeDir) { if !ok || !path.IsAbs(runtimeDir) {
runtimeDir = DEFAULT_SYSTEM_RUNTIME_DIR runtimeDir = DEFAULT_SYSTEM_RUNTIME_DIR
} }
return net.DialUnix("unix", nil, &net.UnixAddr{Name: path.Join(runtimeDir, name), Net: "unix"}) return Dial(path.Join(runtimeDir, name))
} }
} }
@@ -1032,12 +1017,10 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er
} }
} }
var unixConn *net.UnixConn var conn Conn
if unixConn, err = connectName(name, manager); err != nil { if conn, err = connectName(name, manager); err != nil {
return return
} }
conn := &SyscallConn{SyscallConnCloser: unixConn}
if ctx, err = New(conn, props); err != nil { if ctx, err = New(conn, props); err != nil {
ctx = nil ctx = nil
_ = conn.Close() _ = conn.Close()