From 72bd3fb05ef796e04e3b61015f6ffac3c1a937e3 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Mon, 16 Mar 2026 22:09:12 +0900 Subject: [PATCH] internal/netlink: generalise implementation from container This is useful for uevent implementation. Signed-off-by: Ophestra --- internal/netlink/netlink.go | 186 +++++++++++++++++++++++++++++++ internal/netlink/netlink_test.go | 36 ++++++ internal/netlink/rtnl.go | 132 ++++++++++++++++++++++ internal/netlink/rtnl_test.go | 62 +++++++++++ 4 files changed, 416 insertions(+) create mode 100644 internal/netlink/netlink.go create mode 100644 internal/netlink/netlink_test.go create mode 100644 internal/netlink/rtnl.go create mode 100644 internal/netlink/rtnl_test.go diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go new file mode 100644 index 0000000..d1260c8 --- /dev/null +++ b/internal/netlink/netlink.go @@ -0,0 +1,186 @@ +// Package netlink is a partial implementation of the netlink protocol. +package netlink + +import ( + "fmt" + "os" + "sync" + "syscall" + "unsafe" +) + +// AF_NETLINK socket is never shared +var ( + nlPid uint32 + nlPidOnce sync.Once +) + +// getpid returns a cached pid value. +func getpid() uint32 { + nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) }) + return nlPid +} + +// A conn represents resources associated to a netlink socket. +type conn struct { + // AF_NETLINK socket. + fd int + // Kernel module or netlink group to communicate with. + family int + // Message sequence number. + seq uint32 + // For pending outgoing message. + typ, flags uint16 + // Outgoing position in buf. + pos int + // A page holding incoming and outgoing messages. + buf []byte +} + +// dial returns the address of a newly connected conn of specified family. +func dial(family int) (*conn, error) { + var c conn + if fd, err := syscall.Socket( + syscall.AF_NETLINK, + syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, + family, + ); err != nil { + return nil, os.NewSyscallError("socket", err) + } else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: getpid(), + }); err != nil { + _ = syscall.Close(fd) + return nil, os.NewSyscallError("bind", err) + } else { + c.fd, c.family = fd, family + } + c.pos = syscall.NLMSG_HDRLEN + c.buf = make([]byte, os.Getpagesize()) + return &c, nil +} + +// Close closes the underlying socket. +func (c *conn) Close() error { + if c.buf == nil { + return syscall.EINVAL + } + c.buf = nil + return syscall.Close(c.fd) +} + +// Msg is type constraint for types sent over the wire via netlink. +// +// No pointer types or compound types containing pointers may appear here. +type Msg interface { + syscall.NlMsghdr | syscall.NlMsgerr | + syscall.IfAddrmsg | RtAttrMsg[InAddr] | + syscall.IfInfomsg +} + +// As returns data as the specified netlink message type. +func As[M Msg](data []byte) *M { + var v M + if unsafe.Sizeof(v) != uintptr(len(data)) { + return nil + } + return (*M)(unsafe.Pointer(unsafe.SliceData(data))) +} + +// add queues a value to be sent by conn. +func add[M Msg](c *conn, p *M) bool { + pos := c.pos + c.pos += int(unsafe.Sizeof(*p)) + if c.pos > len(c.buf) { + c.pos = pos + return false + } + *(*M)(unsafe.Pointer(&c.buf[pos])) = *p + return true +} + +// InconsistentError describes a reply from the kernel that is not consistent +// with the internal state tracked by this package. +type InconsistentError struct { + // Offending header. + syscall.NlMsghdr + // Expected message sequence. + Seq uint32 + // Expected pid. + Pid uint32 +} + +func (*InconsistentError) Unwrap() error { return os.ErrInvalid } +func (e *InconsistentError) Error() string { + s := "netlink socket has inconsistent state" + switch { + case e.Seq != e.NlMsghdr.Seq: + s += fmt.Sprintf(": seq %d != %d", e.Seq, e.NlMsghdr.Seq) + case e.Pid != e.NlMsghdr.Pid: + s += fmt.Sprintf(": pid %d != %d", e.Pid, e.NlMsghdr.Pid) + } + return s +} + +// pending returns the valid slice of buf and initialises pos. +func (c *conn) pending() []byte { + buf := c.buf[:c.pos] + c.pos = syscall.NLMSG_HDRLEN + + *(*syscall.NlMsghdr)(unsafe.Pointer(unsafe.SliceData(buf))) = syscall.NlMsghdr{ + Len: uint32(len(buf)), + Type: c.typ, + Flags: c.flags, + Seq: c.seq, + Pid: getpid(), + } + return buf +} + +// Complete indicates the completion of a roundtrip. +type Complete struct{} + +// Error returns a hardcoded string that should never be displayed to the user. +func (Complete) Error() string { return "returning from roundtrip" } + +// Roundtrip sends the pending message and handles the reply. +func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error { + if c.buf == nil { + return syscall.EINVAL + } + defer func() { c.seq++ }() + + if err := syscall.Sendto(c.fd, c.pending(), 0, &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + }); err != nil { + return os.NewSyscallError("sendto", err) + } + + for { + buf := c.buf + if n, _, err := syscall.Recvfrom(c.fd, buf, 0); err != nil { + return os.NewSyscallError("recvfrom", err) + } else if n < syscall.NLMSG_HDRLEN { + return syscall.EBADE + } else { + buf = buf[:n] + } + + msgs, err := syscall.ParseNetlinkMessage(buf) + if err != nil { + return err + } + + for _, msg := range msgs { + if msg.Header.Seq != c.seq || msg.Header.Pid != getpid() { + return &InconsistentError{msg.Header, c.seq, getpid()} + } + if err = f(&msg); err != nil { + if err == (Complete{}) { + return nil + } + return err + } + } + } +} diff --git a/internal/netlink/netlink_test.go b/internal/netlink/netlink_test.go new file mode 100644 index 0000000..0926e3b --- /dev/null +++ b/internal/netlink/netlink_test.go @@ -0,0 +1,36 @@ +package netlink + +import ( + "os" + "syscall" + "testing" +) + +func init() { nlPidOnce.Do(func() {}); nlPid = 1 } + +type payloadTestCase struct { + name string + f func(c *conn) + want []byte +} + +// checkPayload runs multiple payloadTestCase against a stub conn and checks +// the outgoing message written to its buffer page. +func checkPayload(t *testing.T, testCases []payloadTestCase) { + t.Helper() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + c := conn{ + pos: syscall.NLMSG_HDRLEN, + buf: make([]byte, os.Getpagesize()), + } + tc.f(&c) + if got := c.pending(); string(got) != string(tc.want) { + t.Errorf("pending: %#v, want %#v", got, tc.want) + } + }) + } +} diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go new file mode 100644 index 0000000..1ed463b --- /dev/null +++ b/internal/netlink/rtnl.go @@ -0,0 +1,132 @@ +package netlink + +import ( + "syscall" + "unsafe" +) + +// RouteConn represents a NETLINK_ROUTE socket. +type RouteConn struct{ *conn } + +// DialRoute returns the address of a newly connected [RouteConn]. +func DialRoute() (*RouteConn, error) { + c, err := dial(syscall.NETLINK_ROUTE) + if err != nil { + return nil, err + } + return &RouteConn{c}, nil +} + +// rtnlConsume consumes a message from rtnetlink. +func rtnlConsume(msg *syscall.NetlinkMessage) error { + switch msg.Header.Type { + case syscall.NLMSG_DONE: + return Complete{} + + case syscall.NLMSG_ERROR: + if e := As[syscall.NlMsgerr](msg.Data); e != nil { + if e.Error == 0 { + return Complete{} + } + return syscall.Errno(-e.Error) + } + return syscall.EBADE + + default: + return nil + } +} + +// InAddr is equivalent to struct in_addr. +type InAddr [4]byte + +// RtAttrMsg holds syscall.RtAttr alongside its payload. +type RtAttrMsg[D any] struct { + syscall.RtAttr + Data D +} + +// populate populates the Len field of the embedded syscall.RtAttr. +func (attr *RtAttrMsg[M]) populate() { + attr.Len = syscall.SizeofRtAttr + uint16(unsafe.Sizeof(attr.Data)) +} + +// writeIfAddrmsg writes an ifaddrmsg structure to conn. +func (c *RouteConn) writeIfAddrmsg( + typ, flags uint16, + msg *syscall.IfAddrmsg, + attrs ...RtAttrMsg[InAddr], +) bool { + c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags + if !add(c.conn, msg) { + return false + } + for _, attr := range attrs { + attr.populate() + if !add(c.conn, &attr) { + return false + } + } + return true +} + +// SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink. +func (c *RouteConn) SendIfAddrmsg( + typ, flags uint16, + msg *syscall.IfAddrmsg, + attrs ...RtAttrMsg[InAddr], +) error { + if !c.writeIfAddrmsg(typ, flags, msg, attrs...) { + return syscall.ENOMEM + } + return c.Roundtrip(rtnlConsume) +} + +// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address. +func (c *RouteConn) writeNewaddrLo(lo uint32) bool { + return c.writeIfAddrmsg( + syscall.RTM_NEWADDR, + syscall.NLM_F_CREATE|syscall.NLM_F_EXCL, + &syscall.IfAddrmsg{ + Family: syscall.AF_INET, + Prefixlen: 8, + Flags: syscall.IFA_F_PERMANENT, + Scope: syscall.RT_SCOPE_HOST, + Index: lo, + }, + RtAttrMsg[InAddr]{syscall.RtAttr{ + Type: syscall.IFA_LOCAL, + }, InAddr{127, 0, 0, 1}}, + RtAttrMsg[InAddr]{syscall.RtAttr{ + Type: syscall.IFA_ADDRESS, + }, InAddr{127, 0, 0, 1}}, + ) +} + +// SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel. +func (c *RouteConn) SendNewaddrLo(lo uint32) error { + if !c.writeNewaddrLo(lo) { + return syscall.ENOMEM + } + return c.Roundtrip(rtnlConsume) +} + +// writeIfInfomsg writes an ifinfomsg structure to conn. +func (c *RouteConn) writeIfInfomsg( + typ, flags uint16, + msg *syscall.IfInfomsg, +) bool { + c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags + return add(c.conn, msg) +} + +// SendIfInfomsg sends an ifinfomsg structure to rtnetlink. +func (c *RouteConn) SendIfInfomsg( + typ, flags uint16, + msg *syscall.IfInfomsg, +) error { + if !c.writeIfInfomsg(typ, flags, msg) { + return syscall.ENOMEM + } + return c.Roundtrip(rtnlConsume) +} diff --git a/internal/netlink/rtnl_test.go b/internal/netlink/rtnl_test.go new file mode 100644 index 0000000..8a281a4 --- /dev/null +++ b/internal/netlink/rtnl_test.go @@ -0,0 +1,62 @@ +package netlink + +import ( + "syscall" + "testing" +) + +func TestPayloadRTNETLINK(t *testing.T) { + t.Parallel() + + checkPayload(t, []payloadTestCase{ + {"RTM_NEWADDR lo", func(c *conn) { + (&RouteConn{c}).writeNewaddrLo(1) + }, []byte{ + /* Len */ 0x28, 0, 0, 0, + /* Type */ 0x14, 0, + /* Flags */ 5, 6, + /* Seq */ 0, 0, 0, 0, + /* Pid */ 1, 0, 0, 0, + + /* Family */ 2, + /* Prefixlen */ 8, + /* Flags */ 0x80, + /* Scope */ 0xfe, + /* Index */ 1, 0, 0, 0, + + /* Len */ 8, 0, + /* Type */ 2, 0, + /* in_addr */ 127, 0, 0, 1, + + /* Len */ 8, 0, + /* Type */ 1, 0, + /* in_addr */ 127, 0, 0, 1, + }}, + + {"RTM_NEWLINK", func(c *conn) { + c.seq++ + (&RouteConn{c}).writeIfInfomsg( + syscall.RTM_NEWLINK, 0, + &syscall.IfInfomsg{ + Family: syscall.AF_UNSPEC, + Index: 1, + Flags: syscall.IFF_UP, + Change: syscall.IFF_UP, + }, + ) + }, []byte{ + /* Len */ 0x20, 0, 0, 0, + /* Type */ 0x10, 0, + /* Flags */ 5, 0, + /* Seq */ 1, 0, 0, 0, + /* Pid */ 1, 0, 0, 0, + + /* Family */ 0, + /* pad */ 0, + /* Type */ 0, 0, + /* Index */ 1, 0, 0, 0, + /* Flags */ 1, 0, 0, 0, + /* Change */ 1, 0, 0, 0, + }}, + }) +}