internal/pipewire: hold socket fd directly
All checks were successful
Test / Create distribution (push) Successful in 38s
Test / Sandbox (push) Successful in 2m39s
Test / Hakurei (push) Successful in 3m31s
Test / Sandbox (race detector) (push) Successful in 4m30s
Test / Hpkg (push) Successful in 4m34s
Test / Hakurei (race detector) (push) Successful in 5m29s
Test / Flake checks (push) Successful in 1m40s

The interface provided by net is not used here and is a leftover from a previous implementation. This change removes it.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2025-12-19 00:28:02 +09:00
parent ce06b7b663
commit 5a50bf80ee

View File

@@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"io"
"net"
"os"
"path"
"runtime"
@@ -144,14 +143,11 @@ func New(conn Conn, props SPADict) (*Context, error) {
return &ctx, nil
}
// A SyscallConnCloser is a [syscall.Conn] that implements [io.Closer].
type SyscallConnCloser interface {
syscall.Conn
io.Closer
}
// unixConn is an implementation of the [Conn] interface for connections
// to Unix domain sockets.
type unixConn struct {
fd int
// A SyscallConn is a [Conn] adapter for [syscall.Conn].
type SyscallConn struct {
// Whether creation of a new epoll instance was attempted.
epoll bool
// File descriptor referring to the new epoll instance.
@@ -164,20 +160,31 @@ type SyscallConn struct {
// If non-zero, next call is treated as a blocking call.
timeout time.Duration
SyscallConnCloser
}
// MightBlock implements [Conn.MightBlock].
func (conn *SyscallConn) MightBlock(timeout time.Duration) {
// Dial connects to a Unix domain socket described by name.
func Dial(name string) (Conn, error) {
if fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC|syscall.SOCK_NONBLOCK, 0); err != nil {
return nil, os.NewSyscallError("socket", err)
} else if err = syscall.Connect(fd, &syscall.SockaddrUnix{Name: name}); err != nil {
_ = syscall.Close(fd)
return nil, os.NewSyscallError("connect", err)
} else {
return &unixConn{fd: fd}, nil
}
}
// MightBlock informs the implementation that the next call
// might block for a non-zero timeout.
func (conn *unixConn) MightBlock(timeout time.Duration) {
if timeout < 0 {
timeout = 0
}
conn.timeout = timeout
}
// wantsEpoll is called at the beginning of any method that wants to use epoll.
func (conn *SyscallConn) wantsEpoll() error {
// wantsEpoll is called at the beginning of any method that might use epoll.
func (conn *unixConn) wantsEpoll() error {
if !conn.epoll {
conn.epoll = true
conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
@@ -185,24 +192,24 @@ func (conn *SyscallConn) wantsEpoll() error {
return conn.epollErr
}
// wait waits for a specific I/O event for a fd passed for [syscall.Conn.SyscallConn]
// and returns a function that should be deferred by the caller regardless of error.
func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc func()) {
// wait waits for a specific I/O event on fd and returns a function
// that must be deferred by the caller regardless of error.
func (conn *unixConn) wait(event uint32, errP *error) (cleanupFunc func()) {
if conn.timeout == 0 {
return func() {}
}
deadline := time.Now().Add(conn.timeout)
conn.timeout = 0
if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, fd, &syscall.EpollEvent{
if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, conn.fd, &syscall.EpollEvent{
Events: event | syscall.EPOLLERR | syscall.EPOLLHUP,
Fd: int32(fd),
Fd: int32(conn.fd),
}); *errP != nil {
return func() {}
} else {
cleanupFunc = func() {
// fd is guaranteed to remain valid while f executes but not after f returns
if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, fd, nil); epDelErr != nil && *errP == nil {
if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, conn.fd, nil); epDelErr != nil && *errP == nil {
*errP = epDelErr
return
}
@@ -216,7 +223,7 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu
} else {
switch n {
case 1: // only the socket fd is ever added
if conn.epollBuf[0].Fd != int32(fd) { // unreachable
if conn.epollBuf[0].Fd != int32(conn.fd) { // unreachable
err = syscall.ENOTRECOVERABLE
break
}
@@ -243,62 +250,40 @@ func (conn *SyscallConn) wait(fd int, event uint32, errP *error) (cleanupFunc fu
return
}
// Recvmsg implements [Conn.Recvmsg] via [syscall.Conn.SyscallConn].
func (conn *SyscallConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
// Recvmsg calls syscall.Recvmsg on the underlying socket.
func (conn *unixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) {
if err = conn.wantsEpoll(); err != nil {
return
}
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
defer conn.wait(syscall.EPOLLIN, &err)()
if err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLIN, &err)()
if err != nil {
return
}
n, oobn, recvflags, _, err = syscall.Recvmsg(int(fd), p, oob, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
n, oobn, recvflags, _, err = syscall.Recvmsg(conn.fd, p, oob, flags)
return
}
// Sendmsg implements [Conn.Sendmsg] via [syscall.Conn.SyscallConn].
func (conn *SyscallConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
// Sendmsg calls syscall.Sendmsg on the underlying socket.
func (conn *unixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) {
if err = conn.wantsEpoll(); err != nil {
return
}
var rc syscall.RawConn
if rc, err = conn.SyscallConn(); err != nil {
defer conn.wait(syscall.EPOLLOUT, &err)()
if err != nil {
return
}
if controlErr := rc.Control(func(fd uintptr) {
defer conn.wait(int(fd), syscall.EPOLLOUT, &err)()
if err != nil {
return
}
n, err = syscall.SendmsgN(int(fd), p, oob, nil, flags)
}); controlErr != nil && err == nil {
err = controlErr
}
n, err = syscall.SendmsgN(conn.fd, p, oob, nil, flags)
return
}
// Close implements [Conn.Close] via [syscall.Conn.Close] but also
// closes the epoll fd if populated.
func (conn *SyscallConn) Close() (err error) {
// Close closes the underlying socket and the epoll fd if populated.
func (conn *unixConn) Close() (err error) {
if conn.epoll && conn.epollErr == nil {
conn.epollErr = syscall.Close(conn.epollFd)
}
if err = conn.SyscallConnCloser.Close(); err != nil {
if err = syscall.Close(conn.fd); err != nil {
return
}
return conn.epollErr
@@ -992,14 +977,14 @@ const Remote = "PIPEWIRE_REMOTE"
const DEFAULT_SYSTEM_RUNTIME_DIR = "/run/pipewire"
// connectName connects to a PipeWire remote by name and returns the [net.UnixConn].
func connectName(name string, manager bool) (conn *net.UnixConn, err error) {
// connectName connects to a PipeWire remote by name and returns the resulting [Conn].
func connectName(name string, manager bool) (conn Conn, err error) {
if manager && !strings.HasSuffix(name, "-manager") {
return connectName(name+"-manager", false)
}
if path.IsAbs(name) || (len(name) > 0 && name[0] == '@') {
return net.DialUnix("unix", nil, &net.UnixAddr{Name: name, Net: "unix"})
return Dial(name)
} else {
runtimeDir, ok := os.LookupEnv("PIPEWIRE_RUNTIME_DIR")
if !ok || !path.IsAbs(runtimeDir) {
@@ -1014,7 +999,7 @@ func connectName(name string, manager bool) (conn *net.UnixConn, err error) {
if !ok || !path.IsAbs(runtimeDir) {
runtimeDir = DEFAULT_SYSTEM_RUNTIME_DIR
}
return net.DialUnix("unix", nil, &net.UnixAddr{Name: path.Join(runtimeDir, name), Net: "unix"})
return Dial(path.Join(runtimeDir, name))
}
}
@@ -1032,12 +1017,10 @@ func ConnectName(name string, manager bool, props SPADict) (ctx *Context, err er
}
}
var unixConn *net.UnixConn
if unixConn, err = connectName(name, manager); err != nil {
var conn Conn
if conn, err = connectName(name, manager); err != nil {
return
}
conn := &SyscallConn{SyscallConnCloser: unixConn}
if ctx, err = New(conn, props); err != nil {
ctx = nil
_ = conn.Close()