diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index 0a3c6688..4c3e10cb 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -18,8 +18,8 @@ const ( stateOpen uint32 = 1 << iota ) -// A conn represents resources associated to a netlink socket. -type conn struct { +// A Conn represents resources associated to a netlink socket. +type Conn struct { // AF_NETLINK socket. f *os.File // For using runtime polling via f. @@ -44,9 +44,10 @@ type conn struct { t time.Time } -// dial returns the address of a newly connected conn of specified family. -func dial(family int, groups uint32) (*conn, error) { - var c conn +// Dial returns the address of a newly connected generic netlink connection of +// specified family and groups. +func Dial(family int, groups uint32) (*Conn, error) { + var c Conn if fd, err := syscall.Socket( syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, @@ -89,10 +90,10 @@ func dial(family int, groups uint32) (*conn, error) { } // ok returns whether conn is still open. -func (c *conn) ok() bool { return c.state&stateOpen != 0 } +func (c *Conn) ok() bool { return c.state&stateOpen != 0 } // Close closes the underlying socket. -func (c *conn) Close() error { +func (c *Conn) Close() error { if !c.ok() { return syscall.EINVAL } @@ -100,35 +101,41 @@ func (c *conn) Close() error { return c.f.Close() } -// recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller. -func (c *conn) recvfrom( +// Recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller. +// +// The returned slice is valid until the next call to Recvfrom. +func (c *Conn) Recvfrom( ctx context.Context, - p []byte, flags int, -) (n int, from syscall.Sockaddr, err error) { +) (data []byte, from syscall.Sockaddr, err error) { if err = c.f.SetReadDeadline(time.Time{}); err != nil { return } + var n int + data = c.buf[:] done := make(chan error, 1) go func() { done <- c.raw.Read(func(fd uintptr) (done bool) { - n, from, err = syscall.Recvfrom(int(fd), p, flags) + n, from, err = syscall.Recvfrom(int(fd), data, flags) return err != syscall.EWOULDBLOCK }) }() select { case rcErr := <-done: + data = data[:n] if err != nil { err = os.NewSyscallError("recvfrom", err) } else { err = rcErr } + return case <-ctx.Done(): cancelErr := c.f.SetReadDeadline(c.t) <-done + data = data[:n] if cancelErr != nil { err = cancelErr } else { @@ -136,11 +143,10 @@ func (c *conn) recvfrom( } return } - return } -// sendto wraps send(2) with nonblocking behaviour via the runtime network poller. -func (c *conn) sendto( +// Sendto wraps send(2) with nonblocking behaviour via the runtime network poller. +func (c *Conn) Sendto( ctx context.Context, p []byte, flags int, @@ -165,6 +171,7 @@ func (c *conn) sendto( } else { err = rcErr } + return case <-ctx.Done(): cancelErr := c.f.SetWriteDeadline(c.t) @@ -176,7 +183,6 @@ func (c *conn) sendto( } return } - return } // Msg is type constraint for types sent over the wire via netlink. @@ -198,7 +204,7 @@ func As[M Msg](data []byte) *M { } // add queues a value to be sent by conn. -func add[M Msg](c *conn, p *M) bool { +func add[M Msg](c *Conn, p *M) bool { pos := c.pos c.pos += int(unsafe.Sizeof(*p)) if c.pos > len(c.buf) { @@ -233,7 +239,7 @@ func (e *InconsistentError) Error() string { } // checkReply checks the message header of a reply from the kernel. -func (c *conn) checkReply(header *syscall.NlMsghdr) error { +func (c *Conn) checkReply(header *syscall.NlMsghdr) error { if header.Seq != c.seq || header.Pid != c.port { return &InconsistentError{*header, c.seq, c.port} } @@ -241,7 +247,7 @@ func (c *conn) checkReply(header *syscall.NlMsghdr) error { } // pending returns the valid slice of buf and initialises pos. -func (c *conn) pending() []byte { +func (c *Conn) pending() []byte { buf := c.buf[:c.pos] c.pos = syscall.NLMSG_HDRLEN @@ -267,23 +273,18 @@ type HandlerFunc func(resp []syscall.NetlinkMessage) error // receive receives from a socket with specified flags until a non-nil error is // returned by f. An error of type [Complete] is returned as nil. -func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error { +func (c *Conn) receive(ctx context.Context, f HandlerFunc, flags int) error { for { - buf := c.buf[:] - if n, _, err := c.recvfrom(ctx, buf, flags); err != nil { + var resp []syscall.NetlinkMessage + if data, _, err := c.Recvfrom(ctx, flags); err != nil { return err - } else if n < syscall.NLMSG_HDRLEN { + } else if len(data) < syscall.NLMSG_HDRLEN { return syscall.EBADE - } else { - buf = buf[:n] - } - - resp, err := syscall.ParseNetlinkMessage(buf) - if err != nil { + } else if resp, err = syscall.ParseNetlinkMessage(data); err != nil { return err } - if err = f(resp); err != nil { + if err := f(resp); err != nil { if err == (Complete{}) { return nil } @@ -293,13 +294,13 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error { } // Roundtrip sends the pending message and handles the reply. -func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error { +func (c *Conn) Roundtrip(ctx context.Context, f HandlerFunc) error { if !c.ok() { return syscall.EINVAL } defer func() { c.seq++ }() - if err := c.sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{ + if err := c.Sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{ Family: syscall.AF_NETLINK, }); err != nil { return err diff --git a/internal/netlink/netlink_test.go b/internal/netlink/netlink_test.go index a161ef5d..9db9fca5 100644 --- a/internal/netlink/netlink_test.go +++ b/internal/netlink/netlink_test.go @@ -7,7 +7,7 @@ import ( type payloadTestCase struct { name string - f func(c *conn) + f func(c *Conn) want []byte } @@ -21,7 +21,7 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) { t.Parallel() t.Helper() - c := conn{port: 1, pos: syscall.NLMSG_HDRLEN} + c := Conn{port: 1, pos: syscall.NLMSG_HDRLEN} tc.f(&c) if got := c.pending(); string(got) != string(tc.want) { t.Errorf("pending: %#v, want %#v", got, tc.want) diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go index 243bc27e..e370bae7 100644 --- a/internal/netlink/rtnl.go +++ b/internal/netlink/rtnl.go @@ -7,11 +7,14 @@ import ( ) // RouteConn represents a NETLINK_ROUTE socket. -type RouteConn struct{ *conn } +type RouteConn struct{ conn *Conn } + +// Close closes the underlying socket. +func (c *RouteConn) Close() error { return c.conn.Close() } // DialRoute returns the address of a newly connected [RouteConn]. func DialRoute() (*RouteConn, error) { - c, err := dial(syscall.NETLINK_ROUTE, 0) + c, err := Dial(syscall.NETLINK_ROUTE, 0) if err != nil { return nil, err } @@ -19,9 +22,9 @@ func DialRoute() (*RouteConn, error) { } // rtnlConsume consumes a message from rtnetlink. -func (c *conn) rtnlConsume(resp []syscall.NetlinkMessage) error { +func (c *RouteConn) rtnlConsume(resp []syscall.NetlinkMessage) error { for i := range resp { - if err := c.checkReply(&resp[i].Header); err != nil { + if err := c.conn.checkReply(&resp[i].Header); err != nil { return err } @@ -62,7 +65,7 @@ func (c *RouteConn) writeIfAddrmsg( msg *syscall.IfAddrmsg, attrs ...RtAttrMsg[InAddr], ) bool { - c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags + c.conn.typ, c.conn.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags if !add(c.conn, msg) { return false } @@ -85,7 +88,7 @@ func (c *RouteConn) SendIfAddrmsg( if !c.writeIfAddrmsg(typ, flags, msg, attrs...) { return syscall.ENOMEM } - return c.Roundtrip(ctx, c.conn.rtnlConsume) + return c.conn.Roundtrip(ctx, c.rtnlConsume) } // writeNewaddrLo writes a RTM_NEWADDR message for the loopback address. @@ -114,7 +117,7 @@ func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error { if !c.writeNewaddrLo(lo) { return syscall.ENOMEM } - return c.Roundtrip(ctx, c.conn.rtnlConsume) + return c.conn.Roundtrip(ctx, c.rtnlConsume) } // writeIfInfomsg writes an ifinfomsg structure to conn. @@ -122,7 +125,7 @@ func (c *RouteConn) writeIfInfomsg( typ, flags uint16, msg *syscall.IfInfomsg, ) bool { - c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags + c.conn.typ, c.conn.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags return add(c.conn, msg) } @@ -135,5 +138,5 @@ func (c *RouteConn) SendIfInfomsg( if !c.writeIfInfomsg(typ, flags, msg) { return syscall.ENOMEM } - return c.Roundtrip(ctx, c.conn.rtnlConsume) + return c.conn.Roundtrip(ctx, c.rtnlConsume) } diff --git a/internal/netlink/rtnl_test.go b/internal/netlink/rtnl_test.go index 8a281a46..0d6290e8 100644 --- a/internal/netlink/rtnl_test.go +++ b/internal/netlink/rtnl_test.go @@ -9,7 +9,7 @@ func TestPayloadRTNETLINK(t *testing.T) { t.Parallel() checkPayload(t, []payloadTestCase{ - {"RTM_NEWADDR lo", func(c *conn) { + {"RTM_NEWADDR lo", func(c *Conn) { (&RouteConn{c}).writeNewaddrLo(1) }, []byte{ /* Len */ 0x28, 0, 0, 0, @@ -33,7 +33,7 @@ func TestPayloadRTNETLINK(t *testing.T) { /* in_addr */ 127, 0, 0, 1, }}, - {"RTM_NEWLINK", func(c *conn) { + {"RTM_NEWLINK", func(c *Conn) { c.seq++ (&RouteConn{c}).writeIfInfomsg( syscall.RTM_NEWLINK, 0,