Files
hakurei/internal/netlink/netlink.go
Ophestra 722c3cc54f
All checks were successful
Test / Create distribution (push) Successful in 1m1s
Test / Sandbox (push) Successful in 2m40s
Test / Hakurei (push) Successful in 3m48s
Test / ShareFS (push) Successful in 3m41s
Test / Sandbox (race detector) (push) Successful in 5m10s
Test / Hakurei (race detector) (push) Successful in 6m18s
Test / Flake checks (push) Successful in 1m19s
internal/netlink: optional check header as reply
Not every received message is a reply.

Signed-off-by: Ophestra <cat@gensokyo.uk>
2026-03-25 19:33:01 +09:00

310 lines
7.1 KiB
Go

// Package netlink is a partial implementation of the netlink protocol.
package netlink
import (
"context"
"fmt"
"os"
"syscall"
"time"
"unsafe"
)
// maxRecvmsgLen is the size in bytes of the message buffer held by conn.
//
// Reference: net/netlink/af_netlink.c
const maxRecvmsgLen = 32 << 10
// Bits of the state field of conn.
const (
	// stateOpen is set while the conn is open.
	stateOpen = uint32(1) << iota
)
// A conn holds everything associated with one netlink socket: the socket
// itself, the identifiers negotiated with the kernel, and a single buffer
// shared by outgoing and incoming messages.
type conn struct {
	// f is the AF_NETLINK socket.
	f *os.File
	// raw gives access to f through the runtime poller.
	raw syscall.RawConn

	// port is the port ID the kernel assigned to the socket.
	port uint32
	// state is the internal connection status bitfield.
	state uint32
	// family is the kernel module or netlink group to communicate with.
	family int

	// seq is the message sequence number.
	seq uint32
	// typ is the message type of the pending outgoing message.
	typ uint16
	// flags are the header flags of the pending outgoing message.
	flags uint16
	// pos is the outgoing position in buf.
	pos int
	// buf holds incoming and outgoing messages.
	buf [maxRecvmsgLen]byte

	// t is an instant taken some time after the conn was established but
	// before the first I/O operation on f through raw. It is kept as a
	// ready-made deadline for cancelling blocking I/O.
	t time.Time
}
// dial opens a nonblocking, close-on-exec AF_NETLINK socket for the given
// family, binds it to groups, and returns the address of the resulting conn.
//
// The port ID is read back from the kernel via getsockname. On any failure
// the socket is closed before the error is returned.
func dial(family int, groups uint32) (*conn, error) {
	fd, err := syscall.Socket(
		syscall.AF_NETLINK,
		syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC,
		family,
	)
	if err != nil {
		return nil, os.NewSyscallError("socket", err)
	}

	local := syscall.SockaddrNetlink{
		Family: syscall.AF_NETLINK,
		Groups: groups,
	}
	if err = syscall.Bind(fd, &local); err != nil {
		_ = syscall.Close(fd)
		return nil, os.NewSyscallError("bind", err)
	}

	addr, err := syscall.Getsockname(fd)
	if err != nil {
		_ = syscall.Close(fd)
		return nil, os.NewSyscallError("getsockname", err)
	}
	bound, isNetlink := addr.(*syscall.SockaddrNetlink)
	if !isNetlink { // unreachable
		_ = syscall.Close(fd)
		return nil, syscall.ENOTRECOVERABLE
	}

	c := new(conn)
	c.port = bound.Pid
	c.family = family
	c.f = os.NewFile(uintptr(fd), "netlink")
	if c.raw, err = c.f.SyscallConn(); err != nil {
		_ = c.f.Close()
		return nil, err
	}
	c.state |= stateOpen

	// Leave room for the header, which is filled in when the message is sent.
	c.pos = syscall.NLMSG_HDRLEN
	c.t = time.Now().UTC()
	return c, nil
}
// ok reports whether the stateOpen bit is set, that is, whether conn has not
// been closed yet.
func (c *conn) ok() bool {
	return c.state&stateOpen == stateOpen
}
// Close marks conn as closed and closes the underlying socket. It returns
// [syscall.EINVAL] if conn is not open.
func (c *conn) Close() error {
	if c.ok() {
		c.state &^= stateOpen
		return c.f.Close()
	}
	return syscall.EINVAL
}
// recvfrom wraps recvfrom(2) with nonblocking behaviour via the runtime
// network poller.
//
// It blocks until a message has been received into p, the syscall fails with
// an error other than EWOULDBLOCK, the poller gives up waiting, or ctx is done.
//
// A failing syscall is reported as an [os.SyscallError] for "recvfrom". A
// failure of the poller itself is returned as is. When ctx is done first, the
// pending read is cancelled and the error of ctx is returned.
func (c *conn) recvfrom(
	ctx context.Context,
	p []byte,
	flags int,
) (n int, from syscall.Sockaddr, err error) {
	// Clear any deadline left behind by a previously cancelled call.
	if err = c.f.SetReadDeadline(time.Time{}); err != nil {
		return
	}

	done := make(chan error, 1)
	go func() {
		done <- c.raw.Read(func(fd uintptr) (done bool) {
			n, from, err = syscall.Recvfrom(int(fd), p, flags)
			// Returning false makes the poller wait for readiness and retry.
			return err != syscall.EWOULDBLOCK
		})
	}()

	select {
	case rcErr := <-done:
		switch {
		case rcErr != nil:
			// The poller stopped before the syscall completed. Anything in
			// err at this point is only the EWOULDBLOCK of the last attempt
			// and must not mask the actual cause.
			err = rcErr
		case err != nil:
			err = os.NewSyscallError("recvfrom", err)
		}
		return

	case <-ctx.Done():
		// A deadline in the past unblocks the pending read.
		cancelErr := c.f.SetReadDeadline(c.t)
		// Wait for the goroutine so it no longer touches the results.
		<-done
		if cancelErr != nil {
			err = cancelErr
		} else {
			err = ctx.Err()
		}
		return
	}
}
// sendto wraps sendto(2) with nonblocking behaviour via the runtime network
// poller.
//
// It blocks until p has been handed to the socket, the syscall fails with an
// error other than EWOULDBLOCK, the poller gives up waiting, or ctx is done.
//
// A failing syscall is reported as an [os.SyscallError] for "sendto". A
// failure of the poller itself is returned as is. When ctx is done first, the
// pending write is cancelled and the error of ctx is returned.
func (c *conn) sendto(
	ctx context.Context,
	p []byte,
	flags int,
	to syscall.Sockaddr,
) (err error) {
	// Clear any deadline left behind by a previously cancelled call.
	if err = c.f.SetWriteDeadline(time.Time{}); err != nil {
		return
	}

	done := make(chan error, 1)
	go func() {
		done <- c.raw.Write(func(fd uintptr) (done bool) {
			err = syscall.Sendto(int(fd), p, flags, to)
			// Returning false makes the poller wait for readiness and retry.
			return err != syscall.EWOULDBLOCK
		})
	}()

	select {
	case rcErr := <-done:
		switch {
		case rcErr != nil:
			// The poller stopped before the syscall completed. Anything in
			// err at this point is only the EWOULDBLOCK of the last attempt
			// and must not mask the actual cause.
			err = rcErr
		case err != nil:
			err = os.NewSyscallError("sendto", err)
		}
		return

	case <-ctx.Done():
		// A deadline in the past unblocks the pending write.
		cancelErr := c.f.SetWriteDeadline(c.t)
		// Wait for the goroutine so it no longer touches err.
		<-done
		if cancelErr != nil {
			err = cancelErr
		} else {
			err = ctx.Err()
		}
		return
	}
}
// Msg constrains the types that may be sent over the wire via netlink.
//
// Values are copied to and from raw bytes, so neither pointer types nor
// compound types containing pointers may be listed here.
type Msg interface {
	syscall.NlMsghdr |
		syscall.NlMsgerr |
		syscall.IfAddrmsg |
		RtAttrMsg[InAddr] |
		syscall.IfInfomsg
}
// As reinterprets data as the netlink message type M.
//
// The result points into data rather than at a copy. It is nil unless the
// length of data is exactly the size of M.
func As[M Msg](data []byte) *M {
	var zero M
	if uintptr(len(data)) == unsafe.Sizeof(zero) {
		return (*M)(unsafe.Pointer(unsafe.SliceData(data)))
	}
	return nil
}
// add appends a copy of *p to the pending outgoing message of c.
//
// It reports false, leaving c untouched, if the value does not fit in what
// remains of the buffer.
func add[M Msg](c *conn, p *M) bool {
	start := c.pos
	end := start + int(unsafe.Sizeof(*p))
	if end > len(c.buf) {
		return false
	}

	*(*M)(unsafe.Pointer(&c.buf[start])) = *p
	c.pos = end
	return true
}
// InconsistentError is returned for a reply from the kernel whose header does
// not agree with the state tracked internally by this package.
type InconsistentError struct {
	// Header of the offending message.
	syscall.NlMsghdr
	// Sequence number that was expected.
	Seq uint32
	// Port ID that was expected.
	Pid uint32
}

// Unwrap makes every InconsistentError match [os.ErrInvalid].
func (*InconsistentError) Unwrap() error {
	return os.ErrInvalid
}

// Error describes the first mismatching field, checking the sequence number
// before the port ID, as "expected != received".
func (e *InconsistentError) Error() string {
	const prefix = "netlink socket has inconsistent state"

	if e.Seq != e.NlMsghdr.Seq {
		return prefix + fmt.Sprintf(": seq %d != %d", e.Seq, e.NlMsghdr.Seq)
	}
	if e.Pid != e.NlMsghdr.Pid {
		return prefix + fmt.Sprintf(": pid %d != %d", e.Pid, e.NlMsghdr.Pid)
	}
	return prefix
}
// checkReply validates the header of a reply from the kernel against the
// current sequence number and port ID of c. A mismatch in either is reported
// as an [InconsistentError].
func (c *conn) checkReply(header *syscall.NlMsghdr) error {
	if header.Seq == c.seq && header.Pid == c.port {
		return nil
	}
	return &InconsistentError{
		NlMsghdr: *header,
		Seq:      c.seq,
		Pid:      c.port,
	}
}
// pending finalises the queued outgoing message and returns the part of buf
// holding it.
//
// The header at the start of buf is filled in from the current state of c,
// and pos is rewound to just past the header, ready for the next message.
func (c *conn) pending() []byte {
	end := c.pos
	c.pos = syscall.NLMSG_HDRLEN

	header := (*syscall.NlMsghdr)(unsafe.Pointer(&c.buf[0]))
	*header = syscall.NlMsghdr{
		Len:   uint32(end),
		Type:  c.typ,
		Flags: c.flags,
		Seq:   c.seq,
		Pid:   c.port,
	}
	return c.buf[:end]
}
// Complete is the error value that signals the completion of a roundtrip.
type Complete struct{}

// Error returns a fixed string that is never meant to be shown to the user.
func (Complete) Error() string {
	const msg = "returning from roundtrip"
	return msg
}
// HandlerFunc handles a batch of [syscall.NetlinkMessage] parsed from a single
// receive operation.
//
// Returning nil requests more messages. Returning a non-nil error discontinues
// the receiving of more messages; the [Complete] value does so without
// reporting a failure.
type HandlerFunc func(resp []syscall.NetlinkMessage) error
// receive reads from the socket with the specified flags, passing every batch
// of parsed messages to f, until f or a receive step returns a non-nil error.
// An error equal to the [Complete] value is returned as nil.
//
// A datagram shorter than a netlink message header is rejected with
// [syscall.EBADE].
func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
	for {
		n, _, err := c.recvfrom(ctx, c.buf[:], flags)
		if err != nil {
			return err
		}
		if n < syscall.NLMSG_HDRLEN {
			return syscall.EBADE
		}

		resp, err := syscall.ParseNetlinkMessage(c.buf[:n])
		if err != nil {
			return err
		}

		if err = f(resp); err == nil {
			continue
		}
		if err == (Complete{}) {
			return nil
		}
		return err
	}
}
// Roundtrip sends the pending message and passes the replies to f.
//
// It returns [syscall.EINVAL] if conn is not open. The sequence number is
// advanced when Roundtrip returns after attempting the send, whether or not
// the exchange succeeded.
func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error {
	if !c.ok() {
		return syscall.EINVAL
	}
	defer func() { c.seq++ }()

	// Destination address with zero Pid and Groups.
	dst := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
	err := c.sendto(ctx, c.pending(), 0, &dst)
	if err != nil {
		return err
	}
	return c.receive(ctx, f, 0)
}