internal/netlink: expose multicast groups
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m14s
Test / Hakurei (push) Successful in 4m19s
Test / ShareFS (push) Successful in 4m19s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m25s

This also gets rid of the cached pid value for port since that prevents multiple sockets from being open at once.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-03-25 17:55:35 +09:00
parent d62516ed1e
commit 372d509e5c
3 changed files with 34 additions and 24 deletions

View File

@@ -5,33 +5,29 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
) )
// AF_NETLINK socket is never shared
var (
nlPid uint32
nlPidOnce sync.Once
)
// getpid returns a cached pid value.
func getpid() uint32 {
nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) })
return nlPid
}
// net/netlink/af_netlink.c // net/netlink/af_netlink.c
const maxRecvmsgLen = 32768 const maxRecvmsgLen = 32768
const (
// stateOpen denotes an open conn.
stateOpen uint32 = 1 << iota
)
// A conn represents resources associated to a netlink socket. // A conn represents resources associated to a netlink socket.
type conn struct { type conn struct {
// AF_NETLINK socket. // AF_NETLINK socket.
f *os.File f *os.File
// For using runtime polling via f. // For using runtime polling via f.
raw syscall.RawConn raw syscall.RawConn
// Port ID assigned by the kernel.
port uint32
// Internal connection status.
state uint32
// Kernel module or netlink group to communicate with. // Kernel module or netlink group to communicate with.
family int family int
// Message sequence number. // Message sequence number.
@@ -49,7 +45,7 @@ type conn struct {
} }
// dial returns the address of a newly connected conn of specified family. // dial returns the address of a newly connected conn of specified family.
func dial(family int) (*conn, error) { func dial(family int, groups uint32) (*conn, error) {
var c conn var c conn
if fd, err := syscall.Socket( if fd, err := syscall.Socket(
syscall.AF_NETLINK, syscall.AF_NETLINK,
@@ -59,17 +55,32 @@ func dial(family int) (*conn, error) {
return nil, os.NewSyscallError("socket", err) return nil, os.NewSyscallError("socket", err)
} else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{ } else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK, Family: syscall.AF_NETLINK,
Pid: getpid(), Groups: groups,
}); err != nil { }); err != nil {
_ = syscall.Close(fd) _ = syscall.Close(fd)
return nil, os.NewSyscallError("bind", err) return nil, os.NewSyscallError("bind", err)
} else { } else {
var addr syscall.Sockaddr
if addr, err = syscall.Getsockname(fd); err != nil {
_ = syscall.Close(fd)
return nil, os.NewSyscallError("getsockname", err)
}
switch a := addr.(type) {
case *syscall.SockaddrNetlink:
c.port = a.Pid
default: // unreachable
_ = syscall.Close(fd)
return nil, syscall.ENOTRECOVERABLE
}
c.family = family c.family = family
c.f = os.NewFile(uintptr(fd), "netlink") c.f = os.NewFile(uintptr(fd), "netlink")
if c.raw, err = c.f.SyscallConn(); err != nil { if c.raw, err = c.f.SyscallConn(); err != nil {
_ = c.f.Close() _ = c.f.Close()
return nil, err return nil, err
} }
c.state |= stateOpen
} }
c.pos = syscall.NLMSG_HDRLEN c.pos = syscall.NLMSG_HDRLEN
@@ -78,14 +89,14 @@ func dial(family int) (*conn, error) {
} }
// ok returns whether conn is still open. // ok returns whether conn is still open.
func (c *conn) ok() bool { return c.family >= 0 } func (c *conn) ok() bool { return c.state&stateOpen != 0 }
// Close closes the underlying socket. // Close closes the underlying socket.
func (c *conn) Close() error { func (c *conn) Close() error {
if !c.ok() { if !c.ok() {
return syscall.EINVAL return syscall.EINVAL
} }
c.family = -1 c.state &= ^stateOpen
return c.f.Close() return c.f.Close()
} }
@@ -231,7 +242,7 @@ func (c *conn) pending() []byte {
Type: c.typ, Type: c.typ,
Flags: c.flags, Flags: c.flags,
Seq: c.seq, Seq: c.seq,
Pid: getpid(), Pid: c.port,
} }
return buf return buf
} }
@@ -266,8 +277,8 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
for i := range resp { for i := range resp {
header := &resp[i].Header header := &resp[i].Header
if header.Seq != c.seq || header.Pid != getpid() { if header.Seq != c.seq || header.Pid != c.port {
return &InconsistentError{*header, c.seq, getpid()} return &InconsistentError{*header, c.seq, c.port}
} }
} }
if err = f(resp); err != nil { if err = f(resp); err != nil {

View File

@@ -5,8 +5,6 @@ import (
"testing" "testing"
) )
func init() { nlPidOnce.Do(func() {}); nlPid = 1 }
type payloadTestCase struct { type payloadTestCase struct {
name string name string
f func(c *conn) f func(c *conn)
@@ -21,8 +19,9 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
t.Helper()
c := conn{pos: syscall.NLMSG_HDRLEN} c := conn{port: 1, pos: syscall.NLMSG_HDRLEN}
tc.f(&c) tc.f(&c)
if got := c.pending(); string(got) != string(tc.want) { if got := c.pending(); string(got) != string(tc.want) {
t.Errorf("pending: %#v, want %#v", got, tc.want) t.Errorf("pending: %#v, want %#v", got, tc.want)

View File

@@ -11,7 +11,7 @@ type RouteConn struct{ *conn }
// DialRoute returns the address of a newly connected [RouteConn]. // DialRoute returns the address of a newly connected [RouteConn].
func DialRoute() (*RouteConn, error) { func DialRoute() (*RouteConn, error) {
c, err := dial(syscall.NETLINK_ROUTE) c, err := dial(syscall.NETLINK_ROUTE, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }