// Package netlink is a partial implementation of the netlink protocol. package netlink import ( "fmt" "os" "sync" "syscall" "unsafe" ) // AF_NETLINK socket is never shared var ( nlPid uint32 nlPidOnce sync.Once ) // getpid returns a cached pid value. func getpid() uint32 { nlPidOnce.Do(func() { nlPid = uint32(os.Getpid()) }) return nlPid } // A conn represents resources associated to a netlink socket. type conn struct { // AF_NETLINK socket. fd int // Kernel module or netlink group to communicate with. family int // Message sequence number. seq uint32 // For pending outgoing message. typ, flags uint16 // Outgoing position in buf. pos int // A page holding incoming and outgoing messages. buf []byte } // dial returns the address of a newly connected conn of specified family. func dial(family int) (*conn, error) { var c conn if fd, err := syscall.Socket( syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, family, ); err != nil { return nil, os.NewSyscallError("socket", err) } else if err = syscall.Bind(fd, &syscall.SockaddrNetlink{ Family: syscall.AF_NETLINK, Pid: getpid(), }); err != nil { _ = syscall.Close(fd) return nil, os.NewSyscallError("bind", err) } else { c.fd, c.family = fd, family } c.pos = syscall.NLMSG_HDRLEN c.buf = make([]byte, os.Getpagesize()) return &c, nil } // Close closes the underlying socket. func (c *conn) Close() error { if c.buf == nil { return syscall.EINVAL } c.buf = nil return syscall.Close(c.fd) } // Msg is type constraint for types sent over the wire via netlink. // // No pointer types or compound types containing pointers may appear here. type Msg interface { syscall.NlMsghdr | syscall.NlMsgerr | syscall.IfAddrmsg | RtAttrMsg[InAddr] | syscall.IfInfomsg } // As returns data as the specified netlink message type. func As[M Msg](data []byte) *M { var v M if unsafe.Sizeof(v) != uintptr(len(data)) { return nil } return (*M)(unsafe.Pointer(unsafe.SliceData(data))) } // add queues a value to be sent by conn. 
//
// The value is copied into buf at the current outgoing position, which then
// moves past it. add reports false, leaving c unchanged, when the value does
// not fit in the space remaining in buf.
func add[M Msg](c *conn, p *M) bool {
	start := c.pos
	end := start + int(unsafe.Sizeof(*p))
	if end > len(c.buf) {
		return false
	}
	*(*M)(unsafe.Pointer(&c.buf[start])) = *p
	c.pos = end
	return true
}

// InconsistentError reports a kernel reply whose header disagrees with the
// state this package tracks for the current roundtrip.
type InconsistentError struct {
	// NlMsghdr is the header of the offending message.
	syscall.NlMsghdr
	// Seq is the sequence number that was expected.
	Seq uint32
	// Pid is the port id that was expected.
	Pid uint32
}

// Unwrap makes every InconsistentError match os.ErrInvalid under errors.Is.
func (*InconsistentError) Unwrap() error { return os.ErrInvalid }

// Error names the first mismatch found, checking the sequence number before
// the port id.
func (e *InconsistentError) Error() string {
	const prefix = "netlink socket has inconsistent state"
	if e.Seq != e.NlMsghdr.Seq {
		return prefix + fmt.Sprintf(": seq %d != %d", e.Seq, e.NlMsghdr.Seq)
	}
	if e.Pid != e.NlMsghdr.Pid {
		return prefix + fmt.Sprintf(": pid %d != %d", e.Pid, e.NlMsghdr.Pid)
	}
	return prefix
}

// pending returns the queued request as the valid prefix of buf, with its
// netlink header filled in, and rewinds the outgoing position to just past
// the header so the next request can be assembled.
func (c *conn) pending() []byte {
	msg := c.buf[:c.pos]
	c.pos = syscall.NLMSG_HDRLEN

	hdr := (*syscall.NlMsghdr)(unsafe.Pointer(unsafe.SliceData(msg)))
	*hdr = syscall.NlMsghdr{
		Len:   uint32(len(msg)),
		Type:  c.typ,
		Flags: c.flags,
		Seq:   c.seq,
		Pid:   getpid(),
	}
	return msg
}

// Complete is returned by a Roundtrip callback to signal that the reply has
// been fully handled; Roundtrip then returns nil.
type Complete struct{}

// Error implements error; the text is not meant to ever reach a user.
func (Complete) Error() string { return "returning from roundtrip" }

// Roundtrip sends the pending message and handles the reply.
//
// Replies are read one datagram at a time. Every message must carry the
// sequence number of this request and this process's port id, otherwise an
// *InconsistentError is returned. Each message is then passed to f:
// returning Complete{} ends the roundtrip with a nil error, any other non-nil
// error aborts it and is returned unchanged, and nil asks for more messages.
// The sequence number is advanced when Roundtrip returns, whatever the
// outcome. Roundtrip returns syscall.EINVAL on a closed conn.
//
// System calls interrupted by a signal (EINTR) are retried, and datagrams
// whose sender is not the kernel are discarded.
func (c *conn) Roundtrip(f func(msg *syscall.NetlinkMessage) error) error {
	if c.buf == nil {
		return syscall.EINVAL
	}
	defer func() { c.seq++ }()

	req := c.pending()
	kernel := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
	for {
		err := syscall.Sendto(c.fd, req, 0, kernel)
		if err == nil {
			break
		}
		// Since Go 1.14, async preemption signals make EINTR from raw
		// syscalls common. A datagram is never partially sent, so retry.
		if err != syscall.EINTR {
			return os.NewSyscallError("sendto", err)
		}
	}

	pid := getpid()
	for {
		n, from, err := syscall.Recvfrom(c.fd, c.buf, 0)
		if err == syscall.EINTR {
			continue
		}
		if err != nil {
			return os.NewSyscallError("recvfrom", err)
		}
		// The socket is not connected, so any local process may send to our
		// port id. Only the kernel (port id 0) is trusted to reply; ignore
		// everything else instead of letting it abort or spoof the roundtrip.
		if sa, ok := from.(*syscall.SockaddrNetlink); ok && sa.Pid != 0 {
			continue
		}
		if n < syscall.NLMSG_HDRLEN {
			return syscall.EBADE
		}

		// The parsed messages alias c.buf and are only valid until the next
		// Recvfrom overwrites it.
		msgs, err := syscall.ParseNetlinkMessage(c.buf[:n])
		if err != nil {
			return err
		}
		for i := range msgs {
			msg := &msgs[i]
			if msg.Header.Seq != c.seq || msg.Header.Pid != pid {
				return &InconsistentError{msg.Header, c.seq, pid}
			}
			if err := f(msg); err != nil {
				if err == (Complete{}) {
					return nil
				}
				return err
			}
		}
	}
}