internal/netlink: expose multicast groups
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m14s
Test / Hakurei (push) Successful in 4m19s
Test / ShareFS (push) Successful in 4m19s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m25s
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m14s
Test / Hakurei (push) Successful in 4m19s
Test / ShareFS (push) Successful in 4m19s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m25s
This also gets rid of the cached pid value for port since that prevents multiple sockets from being open at once. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
@@ -5,33 +5,29 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AF_NETLINK socket is never shared
|
|
||||||
var (
|
|
||||||
nlPid uint32
|
|
||||||
nlPidOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
// getpid returns a cached pid value.
|
|
||||||
func getpid() uint32 {
|
|
||||||
nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) })
|
|
||||||
return nlPid
|
|
||||||
}
|
|
||||||
|
|
||||||
// net/netlink/af_netlink.c
|
// net/netlink/af_netlink.c
|
||||||
const maxRecvmsgLen = 32768
|
const maxRecvmsgLen = 32768
|
||||||
|
|
||||||
|
const (
|
||||||
|
// stateOpen denotes an open conn.
|
||||||
|
stateOpen uint32 = 1 << iota
|
||||||
|
)
|
||||||
|
|
||||||
// A conn represents resources associated to a netlink socket.
|
// A conn represents resources associated to a netlink socket.
|
||||||
type conn struct {
|
type conn struct {
|
||||||
// AF_NETLINK socket.
|
// AF_NETLINK socket.
|
||||||
f *os.File
|
f *os.File
|
||||||
// For using runtime polling via f.
|
// For using runtime polling via f.
|
||||||
raw syscall.RawConn
|
raw syscall.RawConn
|
||||||
|
// Port ID assigned by the kernel.
|
||||||
|
port uint32
|
||||||
|
// Internal connection status.
|
||||||
|
state uint32
|
||||||
// Kernel module or netlink group to communicate with.
|
// Kernel module or netlink group to communicate with.
|
||||||
family int
|
family int
|
||||||
// Message sequence number.
|
// Message sequence number.
|
||||||
@@ -49,7 +45,7 @@ type conn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// dial returns the address of a newly connected conn of specified family.
|
// dial returns the address of a newly connected conn of specified family.
|
||||||
func dial(family int) (*conn, error) {
|
func dial(family int, groups uint32) (*conn, error) {
|
||||||
var c conn
|
var c conn
|
||||||
if fd, err := syscall.Socket(
|
if fd, err := syscall.Socket(
|
||||||
syscall.AF_NETLINK,
|
syscall.AF_NETLINK,
|
||||||
@@ -59,17 +55,32 @@ func dial(family int) (*conn, error) {
|
|||||||
return nil, os.NewSyscallError("socket", err)
|
return nil, os.NewSyscallError("socket", err)
|
||||||
} else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{
|
} else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{
|
||||||
Family: syscall.AF_NETLINK,
|
Family: syscall.AF_NETLINK,
|
||||||
Pid: getpid(),
|
Groups: groups,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
_ = syscall.Close(fd)
|
_ = syscall.Close(fd)
|
||||||
return nil, os.NewSyscallError("bind", err)
|
return nil, os.NewSyscallError("bind", err)
|
||||||
} else {
|
} else {
|
||||||
|
var addr syscall.Sockaddr
|
||||||
|
if addr, err = syscall.Getsockname(fd); err != nil {
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
return nil, os.NewSyscallError("getsockname", err)
|
||||||
|
}
|
||||||
|
switch a := addr.(type) {
|
||||||
|
case *syscall.SockaddrNetlink:
|
||||||
|
c.port = a.Pid
|
||||||
|
|
||||||
|
default: // unreachable
|
||||||
|
_ = syscall.Close(fd)
|
||||||
|
return nil, syscall.ENOTRECOVERABLE
|
||||||
|
}
|
||||||
|
|
||||||
c.family = family
|
c.family = family
|
||||||
c.f = os.NewFile(uintptr(fd), "netlink")
|
c.f = os.NewFile(uintptr(fd), "netlink")
|
||||||
if c.raw, err = c.f.SyscallConn(); err != nil {
|
if c.raw, err = c.f.SyscallConn(); err != nil {
|
||||||
_ = c.f.Close()
|
_ = c.f.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c.state |= stateOpen
|
||||||
}
|
}
|
||||||
|
|
||||||
c.pos = syscall.NLMSG_HDRLEN
|
c.pos = syscall.NLMSG_HDRLEN
|
||||||
@@ -78,14 +89,14 @@ func dial(family int) (*conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ok returns whether conn is still open.
|
// ok returns whether conn is still open.
|
||||||
func (c *conn) ok() bool { return c.family >= 0 }
|
func (c *conn) ok() bool { return c.state&stateOpen != 0 }
|
||||||
|
|
||||||
// Close closes the underlying socket.
|
// Close closes the underlying socket.
|
||||||
func (c *conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
if !c.ok() {
|
if !c.ok() {
|
||||||
return syscall.EINVAL
|
return syscall.EINVAL
|
||||||
}
|
}
|
||||||
c.family = -1
|
c.state &= ^stateOpen
|
||||||
return c.f.Close()
|
return c.f.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +242,7 @@ func (c *conn) pending() []byte {
|
|||||||
Type: c.typ,
|
Type: c.typ,
|
||||||
Flags: c.flags,
|
Flags: c.flags,
|
||||||
Seq: c.seq,
|
Seq: c.seq,
|
||||||
Pid: getpid(),
|
Pid: c.port,
|
||||||
}
|
}
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
@@ -266,8 +277,8 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
|||||||
|
|
||||||
for i := range resp {
|
for i := range resp {
|
||||||
header := &resp[i].Header
|
header := &resp[i].Header
|
||||||
if header.Seq != c.seq || header.Pid != getpid() {
|
if header.Seq != c.seq || header.Pid != c.port {
|
||||||
return &InconsistentError{*header, c.seq, getpid()}
|
return &InconsistentError{*header, c.seq, c.port}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = f(resp); err != nil {
|
if err = f(resp); err != nil {
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() { nlPidOnce.Do(func() {}); nlPid = 1 }
|
|
||||||
|
|
||||||
type payloadTestCase struct {
|
type payloadTestCase struct {
|
||||||
name string
|
name string
|
||||||
f func(c *conn)
|
f func(c *conn)
|
||||||
@@ -21,8 +19,9 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
c := conn{pos: syscall.NLMSG_HDRLEN}
|
c := conn{port: 1, pos: syscall.NLMSG_HDRLEN}
|
||||||
tc.f(&c)
|
tc.f(&c)
|
||||||
if got := c.pending(); string(got) != string(tc.want) {
|
if got := c.pending(); string(got) != string(tc.want) {
|
||||||
t.Errorf("pending: %#v, want %#v", got, tc.want)
|
t.Errorf("pending: %#v, want %#v", got, tc.want)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ type RouteConn struct{ *conn }
|
|||||||
|
|
||||||
// DialRoute returns the address of a newly connected [RouteConn].
|
// DialRoute returns the address of a newly connected [RouteConn].
|
||||||
func DialRoute() (*RouteConn, error) {
|
func DialRoute() (*RouteConn, error) {
|
||||||
c, err := dial(syscall.NETLINK_ROUTE)
|
c, err := dial(syscall.NETLINK_ROUTE, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user