internal/netlink: make full response available

The previous API makes it impossible to retrieve remaining messages in the current iteration.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-03-23 15:16:21 +09:00
parent d8648304bb
commit d972cffe5a
2 changed files with 51 additions and 49 deletions

View File

@@ -143,8 +143,45 @@ type Complete struct{}
// Error returns a hardcoded string that should never be displayed to the user.
func (Complete) Error() string { return "returning from roundtrip" }
// HandlerFunc handles [syscall.NetlinkMessage] and returns a non-nil error to
// discontinue the receiving of more messages.
type HandlerFunc func(resp []syscall.NetlinkMessage) error
// receive receives from a socket with specified flags until a non-nil error is
// returned by f. An error of type [Complete] is returned as nil.
func (c *conn) receive(f HandlerFunc, flags int) error {
for {
buf := c.buf
if n, _, err := syscall.Recvfrom(c.fd, buf, flags); err != nil {
return os.NewSyscallError("recvfrom", err)
} else if n < syscall.NLMSG_HDRLEN {
return syscall.EBADE
} else {
buf = buf[:n]
}
resp, err := syscall.ParseNetlinkMessage(buf)
if err != nil {
return err
}
for i := range resp {
header := &resp[i].Header
if header.Seq != c.seq || header.Pid != getpid() {
return &InconsistentError{*header, c.seq, getpid()}
}
}
if err = f(resp); err != nil {
if err == (Complete{}) {
return nil
}
return err
}
}
}
// Roundtrip sends the pending message and handles the reply.
func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error {
func (c *conn) Roundtrip(f HandlerFunc) error {
if c.buf == nil {
return syscall.EINVAL
}
@@ -158,38 +195,3 @@ func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error {
return c.receive(f, 0)
}
// receive receives from a socket with specified flags until a non-nil error is
// returned by f. An error of type [Complete] is returned as nil.
func (c *conn) receive(
f func(msg *syscall.NetlinkMessage) error,
flags int,
) error {
for {
buf := c.buf
if n, _, err := syscall.Recvfrom(c.fd, buf, flags); err != nil {
return os.NewSyscallError("recvfrom", err)
} else if n < syscall.NLMSG_HDRLEN {
return syscall.EBADE
} else {
buf = buf[:n]
}
msgs, err := syscall.ParseNetlinkMessage(buf)
if err != nil {
return err
}
for _, msg := range msgs {
if msg.Header.Seq != c.seq || msg.Header.Pid != getpid() {
return &InconsistentError{msg.Header, c.seq, getpid()}
}
if err = f(&msg); err != nil {
if err == (Complete{}) {
return nil
}
return err
}
}
}
}

View File

@@ -18,23 +18,23 @@ func DialRoute() (*RouteConn, error) {
}
// rtnlConsume consumes a message from rtnetlink.
func rtnlConsume(msg *syscall.NetlinkMessage) error {
switch msg.Header.Type {
case syscall.NLMSG_DONE:
return Complete{}
func rtnlConsume(resp []syscall.NetlinkMessage) error {
for i := range resp {
switch resp[i].Header.Type {
case syscall.NLMSG_DONE:
return Complete{}
case syscall.NLMSG_ERROR:
if e := As[syscall.NlMsgerr](msg.Data); e != nil {
if e.Error == 0 {
return Complete{}
case syscall.NLMSG_ERROR:
if e := As[syscall.NlMsgerr](resp[i].Data); e != nil {
if e.Error == 0 {
return Complete{}
}
return syscall.Errno(-e.Error)
}
return syscall.Errno(-e.Error)
return syscall.EBADE
}
return syscall.EBADE
default:
return nil
}
return nil
}
// InAddr is equivalent to struct in_addr.