diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index 89f692a5..857259f1 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -5,33 +5,29 @@ import ( "context" "fmt" "os" - "sync" "syscall" "time" "unsafe" ) -// AF_NETLINK socket is never shared -var ( - nlPid uint32 - nlPidOnce sync.Once -) - -// getpid returns a cached pid value. -func getpid() uint32 { - nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) }) - return nlPid -} - // net/netlink/af_netlink.c const maxRecvmsgLen = 32768 +const ( + // stateOpen denotes an open conn. + stateOpen uint32 = 1 << iota +) + // A conn represents resources associated to a netlink socket. type conn struct { // AF_NETLINK socket. f *os.File // For using runtime polling via f. raw syscall.RawConn + // Port ID assigned by the kernel. + port uint32 + // Internal connection status. + state uint32 // Kernel module or netlink group to communicate with. family int // Message sequence number. @@ -49,7 +45,7 @@ type conn struct { } // dial returns the address of a newly connected conn of specified family. -func dial(family int) (*conn, error) { +func dial(family int, groups uint32) (*conn, error) { var c conn if fd, err := syscall.Socket( syscall.AF_NETLINK, @@ -59,17 +55,32 @@ func dial(family int) (*conn, error) { return nil, os.NewSyscallError("socket", err) } else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{ Family: syscall.AF_NETLINK, - Pid: getpid(), + Groups: groups, }); err != nil { _ = syscall.Close(fd) return nil, os.NewSyscallError("bind", err) } else { + var addr syscall.Sockaddr + if addr, err = syscall.Getsockname(fd); err != nil { + _ = syscall.Close(fd) + return nil, os.NewSyscallError("getsockname", err) + } + switch a := addr.(type) { + case *syscall.SockaddrNetlink: + c.port = a.Pid + + default: // unreachable + _ = syscall.Close(fd) + return nil, syscall.ENOTRECOVERABLE + } + c.family = family c.f = os.NewFile(uintptr(fd), "netlink") if c.raw, err = c.f.SyscallConn(); err != nil { _ = c.f.Close() return nil, err } + c.state |= stateOpen } c.pos = syscall.NLMSG_HDRLEN @@ -78,14 +89,14 @@ func dial(family int) (*conn, error) { } // ok returns whether conn is still open. -func (c *conn) ok() bool { return c.family >= 0 } +func (c *conn) ok() bool { return c.state&stateOpen != 0 } // Close closes the underlying socket. func (c *conn) Close() error { if !c.ok() { return syscall.EINVAL } - c.family = -1 + c.state &= ^stateOpen return c.f.Close() } @@ -231,7 +242,7 @@ func (c *conn) pending() []byte { Type: c.typ, Flags: c.flags, Seq: c.seq, - Pid: getpid(), + Pid: c.port, } return buf } @@ -266,8 +277,8 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error { for i := range resp { header := &resp[i].Header - if header.Seq != c.seq || header.Pid != getpid() { - return &InconsistentError{*header, c.seq, getpid()} + if header.Seq != c.seq || header.Pid != c.port { + return &InconsistentError{*header, c.seq, c.port} } } if err = f(resp); err != nil { diff --git a/internal/netlink/netlink_test.go b/internal/netlink/netlink_test.go index 0b7df0f1..a161ef5d 100644 --- a/internal/netlink/netlink_test.go +++ b/internal/netlink/netlink_test.go @@ -5,8 +5,6 @@ import ( "testing" ) -func init() { nlPidOnce.Do(func() {}); nlPid = 1 } - type payloadTestCase struct { name string f func(c *conn) @@ -21,8 +19,9 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() + t.Helper() - c := conn{pos: syscall.NLMSG_HDRLEN} + c := conn{port: 1, pos: syscall.NLMSG_HDRLEN} tc.f(&c) if got := c.pending(); string(got) != string(tc.want) { t.Errorf("pending: %#v, want %#v", got, tc.want) diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go index e824f091..2f51fe58 100644 --- a/internal/netlink/rtnl.go +++ b/internal/netlink/rtnl.go @@ -11,7 +11,7 @@ type RouteConn struct{ *conn } // DialRoute returns the address of a newly connected [RouteConn]. func DialRoute() (*RouteConn, error) { - c, err := dial(syscall.NETLINK_ROUTE) + c, err := dial(syscall.NETLINK_ROUTE, 0) if err != nil { return nil, err }