diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index e6f6fa26..5df7b390 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -143,8 +143,45 @@ type Complete struct{} // Error returns a hardcoded string that should never be displayed to the user. func (Complete) Error() string { return "returning from roundtrip" } +// HandlerFunc handles [syscall.NetlinkMessage] and returns a non-nil error to +// discontinue the receiving of more messages. +type HandlerFunc func(resp []syscall.NetlinkMessage) error + +// receive receives from a socket with specified flags until a non-nil error is +// returned by f. An error of type [Complete] is returned as nil. +func (c *conn) receive(f HandlerFunc, flags int) error { + for { + buf := c.buf + if n, _, err := syscall.Recvfrom(c.fd, buf, flags); err != nil { + return os.NewSyscallError("recvfrom", err) + } else if n < syscall.NLMSG_HDRLEN { + return syscall.EBADE + } else { + buf = buf[:n] + } + + resp, err := syscall.ParseNetlinkMessage(buf) + if err != nil { + return err + } + + for i := range resp { + header := &resp[i].Header + if header.Seq != c.seq || header.Pid != getpid() { + return &InconsistentError{*header, c.seq, getpid()} + } + } + if err = f(resp); err != nil { + if err == (Complete{}) { + return nil + } + return err + } + } +} + // Roundtrip sends the pending message and handles the reply. -func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error { +func (c *conn) Roundtrip(f HandlerFunc) error { if c.buf == nil { return syscall.EINVAL } @@ -158,38 +195,3 @@ func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error { return c.receive(f, 0) } - -// receive receives from a socket with specified flags until a non-nil error is -// returned by f. An error of type [Complete] is returned as nil. -func (c *conn) receive( - f func(msg *syscall.NetlinkMessage) error, - flags int, -) error { - for { - buf := c.buf - if n, _, err := syscall.Recvfrom(c.fd, buf, flags); err != nil { - return os.NewSyscallError("recvfrom", err) - } else if n < syscall.NLMSG_HDRLEN { - return syscall.EBADE - } else { - buf = buf[:n] - } - - msgs, err := syscall.ParseNetlinkMessage(buf) - if err != nil { - return err - } - - for _, msg := range msgs { - if msg.Header.Seq != c.seq || msg.Header.Pid != getpid() { - return &InconsistentError{msg.Header, c.seq, getpid()} - } - if err = f(&msg); err != nil { - if err == (Complete{}) { - return nil - } - return err - } - } - } -} diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go index 1ed463b6..e6585bf8 100644 --- a/internal/netlink/rtnl.go +++ b/internal/netlink/rtnl.go @@ -18,23 +18,23 @@ func DialRoute() (*RouteConn, error) { } // rtnlConsume consumes a message from rtnetlink. -func rtnlConsume(msg *syscall.NetlinkMessage) error { - switch msg.Header.Type { - case syscall.NLMSG_DONE: - return Complete{} +func rtnlConsume(resp []syscall.NetlinkMessage) error { + for i := range resp { + switch resp[i].Header.Type { + case syscall.NLMSG_DONE: + return Complete{} - case syscall.NLMSG_ERROR: - if e := As[syscall.NlMsgerr](msg.Data); e != nil { - if e.Error == 0 { - return Complete{} + case syscall.NLMSG_ERROR: + if e := As[syscall.NlMsgerr](resp[i].Data); e != nil { + if e.Error == 0 { + return Complete{} + } + return syscall.Errno(-e.Error) } - return syscall.Errno(-e.Error) + return syscall.EBADE } - return syscall.EBADE - - default: - return nil } + return nil } // InAddr is equivalent to struct in_addr.