internal/netlink: expose multicast groups
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m14s
Test / Hakurei (push) Successful in 4m19s
Test / ShareFS (push) Successful in 4m19s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m25s
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m14s
Test / Hakurei (push) Successful in 4m19s
Test / ShareFS (push) Successful in 4m19s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m25s
This also gets rid of the cached pid value for port since that prevents multiple sockets from being open at once. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
@@ -5,33 +5,29 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// AF_NETLINK socket is never shared
|
||||
var (
|
||||
nlPid uint32
|
||||
nlPidOnce sync.Once
|
||||
)
|
||||
|
||||
// getpid returns a cached pid value.
|
||||
func getpid() uint32 {
|
||||
nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) })
|
||||
return nlPid
|
||||
}
|
||||
|
||||
// net/netlink/af_netlink.c
|
||||
const maxRecvmsgLen = 32768
|
||||
|
||||
const (
|
||||
// stateOpen denotes an open conn.
|
||||
stateOpen uint32 = 1 << iota
|
||||
)
|
||||
|
||||
// A conn represents resources associated to a netlink socket.
|
||||
type conn struct {
|
||||
// AF_NETLINK socket.
|
||||
f *os.File
|
||||
// For using runtime polling via f.
|
||||
raw syscall.RawConn
|
||||
// Port ID assigned by the kernel.
|
||||
port uint32
|
||||
// Internal connection status.
|
||||
state uint32
|
||||
// Kernel module or netlink group to communicate with.
|
||||
family int
|
||||
// Message sequence number.
|
||||
@@ -49,7 +45,7 @@ type conn struct {
|
||||
}
|
||||
|
||||
// dial returns the address of a newly connected conn of specified family.
|
||||
func dial(family int) (*conn, error) {
|
||||
func dial(family int, groups uint32) (*conn, error) {
|
||||
var c conn
|
||||
if fd, err := syscall.Socket(
|
||||
syscall.AF_NETLINK,
|
||||
@@ -59,17 +55,32 @@ func dial(family int) (*conn, error) {
|
||||
return nil, os.NewSyscallError("socket", err)
|
||||
} else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{
|
||||
Family: syscall.AF_NETLINK,
|
||||
Pid: getpid(),
|
||||
Groups: groups,
|
||||
}); err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
return nil, os.NewSyscallError("bind", err)
|
||||
} else {
|
||||
var addr syscall.Sockaddr
|
||||
if addr, err = syscall.Getsockname(fd); err != nil {
|
||||
_ = syscall.Close(fd)
|
||||
return nil, os.NewSyscallError("getsockname", err)
|
||||
}
|
||||
switch a := addr.(type) {
|
||||
case *syscall.SockaddrNetlink:
|
||||
c.port = a.Pid
|
||||
|
||||
default: // unreachable
|
||||
_ = syscall.Close(fd)
|
||||
return nil, syscall.ENOTRECOVERABLE
|
||||
}
|
||||
|
||||
c.family = family
|
||||
c.f = os.NewFile(uintptr(fd), "netlink")
|
||||
if c.raw, err = c.f.SyscallConn(); err != nil {
|
||||
_ = c.f.Close()
|
||||
return nil, err
|
||||
}
|
||||
c.state |= stateOpen
|
||||
}
|
||||
|
||||
c.pos = syscall.NLMSG_HDRLEN
|
||||
@@ -78,14 +89,14 @@ func dial(family int) (*conn, error) {
|
||||
}
|
||||
|
||||
// ok returns whether conn is still open.
|
||||
func (c *conn) ok() bool { return c.family >= 0 }
|
||||
func (c *conn) ok() bool { return c.state&stateOpen != 0 }
|
||||
|
||||
// Close closes the underlying socket.
|
||||
func (c *conn) Close() error {
|
||||
if !c.ok() {
|
||||
return syscall.EINVAL
|
||||
}
|
||||
c.family = -1
|
||||
c.state &= ^stateOpen
|
||||
return c.f.Close()
|
||||
}
|
||||
|
||||
@@ -231,7 +242,7 @@ func (c *conn) pending() []byte {
|
||||
Type: c.typ,
|
||||
Flags: c.flags,
|
||||
Seq: c.seq,
|
||||
Pid: getpid(),
|
||||
Pid: c.port,
|
||||
}
|
||||
return buf
|
||||
}
|
||||
@@ -266,8 +277,8 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
||||
|
||||
for i := range resp {
|
||||
header := &resp[i].Header
|
||||
if header.Seq != c.seq || header.Pid != getpid() {
|
||||
return &InconsistentError{*header, c.seq, getpid()}
|
||||
if header.Seq != c.seq || header.Pid != c.port {
|
||||
return &InconsistentError{*header, c.seq, c.port}
|
||||
}
|
||||
}
|
||||
if err = f(resp); err != nil {
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func init() { nlPidOnce.Do(func() {}); nlPid = 1 }
|
||||
|
||||
type payloadTestCase struct {
|
||||
name string
|
||||
f func(c *conn)
|
||||
@@ -21,8 +19,9 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Helper()
|
||||
|
||||
c := conn{pos: syscall.NLMSG_HDRLEN}
|
||||
c := conn{port: 1, pos: syscall.NLMSG_HDRLEN}
|
||||
tc.f(&c)
|
||||
if got := c.pending(); string(got) != string(tc.want) {
|
||||
t.Errorf("pending: %#v, want %#v", got, tc.want)
|
||||
|
||||
@@ -11,7 +11,7 @@ type RouteConn struct{ *conn }
|
||||
|
||||
// DialRoute returns the address of a newly connected [RouteConn].
|
||||
func DialRoute() (*RouteConn, error) {
|
||||
c, err := dial(syscall.NETLINK_ROUTE)
|
||||
c, err := dial(syscall.NETLINK_ROUTE, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user