From 722c3cc54f764679b12daa04066a504726fce253 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Wed, 25 Mar 2026 19:33:01 +0900 Subject: [PATCH] internal/netlink: optional check header as reply Not every received message is a reply. Signed-off-by: Ophestra --- internal/netlink/netlink.go | 14 ++++++++------ internal/netlink/rtnl.go | 12 ++++++++---- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index 857259f1..0a3c6688 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -232,6 +232,14 @@ func (e *InconsistentError) Error() string { return s } +// checkReply checks the message header of a reply from the kernel. +func (c *conn) checkReply(header *syscall.NlMsghdr) error { + if header.Seq != c.seq || header.Pid != c.port { + return &InconsistentError{*header, c.seq, c.port} + } + return nil +} + // pending returns the valid slice of buf and initialises pos. func (c *conn) pending() []byte { buf := c.buf[:c.pos] @@ -275,12 +283,6 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error { return err } - for i := range resp { - header := &resp[i].Header - if header.Seq != c.seq || header.Pid != c.port { - return &InconsistentError{*header, c.seq, c.port} - } - } if err = f(resp); err != nil { if err == (Complete{}) { return nil diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go index 2f51fe58..243bc27e 100644 --- a/internal/netlink/rtnl.go +++ b/internal/netlink/rtnl.go @@ -19,8 +19,12 @@ func DialRoute() (*RouteConn, error) { } // rtnlConsume consumes a message from rtnetlink. -func rtnlConsume(resp []syscall.NetlinkMessage) error { +func (c *conn) rtnlConsume(resp []syscall.NetlinkMessage) error { for i := range resp { + if err := c.checkReply(&resp[i].Header); err != nil { + return err + } + switch resp[i].Header.Type { case syscall.NLMSG_DONE: return Complete{} @@ -81,7 +85,7 @@ func (c *RouteConn) SendIfAddrmsg( if !c.writeIfAddrmsg(typ, flags, msg, attrs...) { return syscall.ENOMEM } - return c.Roundtrip(ctx, rtnlConsume) + return c.Roundtrip(ctx, c.conn.rtnlConsume) } // writeNewaddrLo writes a RTM_NEWADDR message for the loopback address. @@ -110,7 +114,7 @@ func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error { if !c.writeNewaddrLo(lo) { return syscall.ENOMEM } - return c.Roundtrip(ctx, rtnlConsume) + return c.Roundtrip(ctx, c.conn.rtnlConsume) } // writeIfInfomsg writes an ifinfomsg structure to conn. @@ -131,5 +135,5 @@ func (c *RouteConn) SendIfInfomsg( if !c.writeIfInfomsg(typ, flags, msg) { return syscall.ENOMEM } - return c.Roundtrip(ctx, rtnlConsume) + return c.Roundtrip(ctx, c.conn.rtnlConsume) }