diff --git a/internal/pipewire/pipewire.go b/internal/pipewire/pipewire.go index 9af142b..bd1d4e8 100644 --- a/internal/pipewire/pipewire.go +++ b/internal/pipewire/pipewire.go @@ -188,63 +188,57 @@ func (conn *unixConn) wantsEpoll() error { if !conn.epoll { conn.epoll = true conn.epollFd, conn.epollErr = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) + if conn.epollErr == nil { + if conn.epollErr = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, conn.fd, &syscall.EpollEvent{ + Events: syscall.EPOLLERR | syscall.EPOLLHUP, + Fd: int32(conn.fd), + }); conn.epollErr != nil { + _ = syscall.Close(conn.epollFd) + } + } } return conn.epollErr } -// wait waits for a specific I/O event on fd and returns a function -// that must be deferred by the caller regardless of error. -func (conn *unixConn) wait(event uint32, errP *error) (cleanupFunc func()) { +// wait waits for a specific I/O event on fd. Caller must arrange for wantsEpoll +// to be called somewhere before wait is called. +func (conn *unixConn) wait(event uint32) (err error) { if conn.timeout == 0 { - return func() {} + return nil } deadline := time.Now().Add(conn.timeout) conn.timeout = 0 - if *errP = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_ADD, conn.fd, &syscall.EpollEvent{ + if err = syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_MOD, conn.fd, &syscall.EpollEvent{ Events: event | syscall.EPOLLERR | syscall.EPOLLHUP, Fd: int32(conn.fd), - }); *errP != nil { - return func() {} - } else { - cleanupFunc = func() { - // fd is guaranteed to remain valid while f executes but not after f returns - if epDelErr := syscall.EpollCtl(conn.epollFd, syscall.EPOLL_CTL_DEL, conn.fd, nil); epDelErr != nil && *errP == nil { - *errP = epDelErr - return - } - } + }); err != nil { + return } for timeout := deadline.Sub(time.Now()); timeout > 0; timeout = deadline.Sub(time.Now()) { - if n, err := syscall.EpollWait(conn.epollFd, conn.epollBuf[:], int(timeout/time.Millisecond)); err != nil { - *errP = err + var n int + if n, err = syscall.EpollWait(conn.epollFd, conn.epollBuf[:], int(timeout/time.Millisecond)); err != nil { return - } else { - switch n { - case 1: // only the socket fd is ever added - if conn.epollBuf[0].Fd != int32(conn.fd) { // unreachable - err = syscall.ENOTRECOVERABLE - break - } - if conn.epollBuf[0].Events&event == event || - conn.epollBuf[0].Events&syscall.EPOLLERR|syscall.EPOLLHUP != 0 { - break - } - *errP = syscall.ETIME - continue + } - case 0: // timeout - err = syscall.ETIMEDOUT - break - - default: // unreachable - err = syscall.ENOTRECOVERABLE - break + switch n { + case 1: // only the socket fd is ever added + if conn.epollBuf[0].Fd != int32(conn.fd) { // unreachable + return syscall.ENOTRECOVERABLE } + if conn.epollBuf[0].Events&event == event || + conn.epollBuf[0].Events&syscall.EPOLLERR|syscall.EPOLLHUP != 0 { + return nil + } + err = syscall.ETIME + continue - *errP = err - break + case 0: // timeout + return syscall.ETIMEDOUT + + default: // unreachable + return syscall.ENOTRECOVERABLE } } return @@ -254,9 +248,7 @@ func (conn *unixConn) wait(event uint32, errP *error) (cleanupFunc func()) { func (conn *unixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, err error) { if err = conn.wantsEpoll(); err != nil { return - } - defer conn.wait(syscall.EPOLLIN, &err)() - if err != nil { + } else if err = conn.wait(syscall.EPOLLIN); err != nil { return } @@ -268,9 +260,7 @@ func (conn *unixConn) Recvmsg(p, oob []byte, flags int) (n, oobn, recvflags int, func (conn *unixConn) Sendmsg(p, oob []byte, flags int) (n int, err error) { if err = conn.wantsEpoll(); err != nil { return - } - defer conn.wait(syscall.EPOLLOUT, &err)() - if err != nil { + } else if err = conn.wait(syscall.EPOLLOUT); err != nil { return }