1
0
forked from rosa/hakurei

internal/netlink: wrap netpoll via context

This removes netpoll boilerplate for the most common use case.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-03-25 15:34:20 +09:00
parent b98c5f2e21
commit 50403e9d60
5 changed files with 96 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
package container package container
import ( import (
"context"
"io" "io"
"io/fs" "io/fs"
"net" "net"
@@ -66,7 +67,7 @@ type syscallDispatcher interface {
// ensureFile provides ensureFile. // ensureFile provides ensureFile.
ensureFile(name string, perm, pperm os.FileMode) error ensureFile(name string, perm, pperm os.FileMode) error
// mustLoopback provides mustLoopback. // mustLoopback provides mustLoopback.
mustLoopback(msg message.Msg) mustLoopback(ctx context.Context, msg message.Msg)
// seccompLoad provides [seccomp.Load]. // seccompLoad provides [seccomp.Load].
seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error
@@ -170,7 +171,7 @@ func (k direct) mountTmpfs(fsname, target string, flags uintptr, size int, perm
func (direct) ensureFile(name string, perm, pperm os.FileMode) error { func (direct) ensureFile(name string, perm, pperm os.FileMode) error {
return ensureFile(name, perm, pperm) return ensureFile(name, perm, pperm)
} }
func (direct) mustLoopback(msg message.Msg) { func (direct) mustLoopback(ctx context.Context, msg message.Msg) {
var lo int var lo int
if ifi, err := net.InterfaceByName("lo"); err != nil { if ifi, err := net.InterfaceByName("lo"); err != nil {
msg.GetLogger().Fatalln(err) msg.GetLogger().Fatalln(err)
@@ -199,11 +200,14 @@ func (direct) mustLoopback(msg message.Msg) {
msg.GetLogger().Fatalf("RTNETLINK answers: %v", err) msg.GetLogger().Fatalf("RTNETLINK answers: %v", err)
default: default:
msg.GetLogger().Fatalf("RTNETLINK answers with malformed message") if err == context.DeadlineExceeded || err == context.Canceled {
msg.GetLogger().Fatalf("interrupted RTNETLINK operation")
}
msg.GetLogger().Fatal("RTNETLINK answers with malformed message")
} }
} }
must(c.SendNewaddrLo(uint32(lo))) must(c.SendNewaddrLo(ctx, uint32(lo)))
must(c.SendIfInfomsg(syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{ must(c.SendIfInfomsg(ctx, syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{
Family: syscall.AF_UNSPEC, Family: syscall.AF_UNSPEC,
Index: int32(lo), Index: int32(lo),
Flags: syscall.IFF_UP, Flags: syscall.IFF_UP,

View File

@@ -2,6 +2,7 @@ package container
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
@@ -468,7 +469,7 @@ func (k *kstub) ensureFile(name string, perm, pperm os.FileMode) error {
stub.CheckArg(k.Stub, "pperm", pperm, 2)) stub.CheckArg(k.Stub, "pperm", pperm, 2))
} }
func (*kstub) mustLoopback(message.Msg) { /* noop */ } func (*kstub) mustLoopback(context.Context, message.Msg) { /* noop */ }
func (k *kstub) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error { func (k *kstub) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error {
k.Helper() k.Helper()

View File

@@ -7,6 +7,7 @@ import (
"log" "log"
"os" "os"
"os/exec" "os/exec"
"os/signal"
"path" "path"
"slices" "slices"
"strconv" "strconv"
@@ -175,7 +176,11 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) {
} }
if !params.HostNet { if !params.HostNet {
k.mustLoopback(msg) ctx, cancel := signal.NotifyContext(context.Background(), CancelSignal,
os.Interrupt, SIGTERM, SIGQUIT)
defer cancel() // for panics
k.mustLoopback(ctx, msg)
cancel()
} }
// write uid/gid map here so parent does not need to set dumpable // write uid/gid map here so parent does not need to set dumpable

View File

@@ -2,10 +2,12 @@
package netlink package netlink
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
) )
@@ -37,6 +39,10 @@ type conn struct {
pos int pos int
// A page holding incoming and outgoing messages. // A page holding incoming and outgoing messages.
buf []byte buf []byte
// An instant some time after conn was established, but before the first
// I/O operation on f through raw. This serves as a cached deadline to
// cancel blocking I/O.
t time.Time
} }
// dial returns the address of a newly connected conn of specified family. // dial returns the address of a newly connected conn of specified family.
@@ -65,6 +71,7 @@ func dial(family int) (*conn, error) {
c.pos = syscall.NLMSG_HDRLEN c.pos = syscall.NLMSG_HDRLEN
c.buf = make([]byte, os.Getpagesize()) c.buf = make([]byte, os.Getpagesize())
c.t = time.Now().UTC()
return &c, nil return &c, nil
} }
@@ -79,35 +86,79 @@ func (c *conn) Close() error {
// recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller. // recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller.
func (c *conn) recvfrom( func (c *conn) recvfrom(
ctx context.Context,
p []byte, p []byte,
flags int, flags int,
) (n int, from syscall.Sockaddr, err error) { ) (n int, from syscall.Sockaddr, err error) {
rcErr := c.raw.Read(func(fd uintptr) (done bool) { if err = c.f.SetReadDeadline(time.Time{}); err != nil {
n, from, err = syscall.Recvfrom(int(fd), p, flags) return
return err != syscall.EWOULDBLOCK }
})
if err != nil { done := make(chan error, 1)
err = os.NewSyscallError("recvfrom", err) go func() {
} else { done <- c.raw.Read(func(fd uintptr) (done bool) {
err = rcErr n, from, err = syscall.Recvfrom(int(fd), p, flags)
return err != syscall.EWOULDBLOCK
})
}()
select {
case rcErr := <-done:
if err != nil {
err = os.NewSyscallError("recvfrom", err)
} else {
err = rcErr
}
case <-ctx.Done():
cancelErr := c.f.SetReadDeadline(c.t)
<-done
if cancelErr != nil {
err = cancelErr
} else {
err = ctx.Err()
}
return
} }
return return
} }
// sendto wraps send(2) with nonblocking behaviour via the runtime network poller. // sendto wraps send(2) with nonblocking behaviour via the runtime network poller.
func (c *conn) sendto( func (c *conn) sendto(
ctx context.Context,
p []byte, p []byte,
flags int, flags int,
to syscall.Sockaddr, to syscall.Sockaddr,
) (err error) { ) (err error) {
rcErr := c.raw.Write(func(fd uintptr) (done bool) { if err = c.f.SetWriteDeadline(time.Time{}); err != nil {
err = syscall.Sendto(int(fd), p, flags, to) return
return err != syscall.EWOULDBLOCK }
})
if err != nil { done := make(chan error, 1)
err = os.NewSyscallError("sendto", err) go func() {
} else { done <- c.raw.Write(func(fd uintptr) (done bool) {
err = rcErr err = syscall.Sendto(int(fd), p, flags, to)
return err != syscall.EWOULDBLOCK
})
}()
select {
case rcErr := <-done:
if err != nil {
err = os.NewSyscallError("sendto", err)
} else {
err = rcErr
}
case <-ctx.Done():
cancelErr := c.f.SetWriteDeadline(c.t)
<-done
if cancelErr != nil {
err = cancelErr
} else {
err = ctx.Err()
}
return
} }
return return
} }
@@ -192,10 +243,10 @@ type HandlerFunc func(resp []syscall.NetlinkMessage) error
// receive receives from a socket with specified flags until a non-nil error is // receive receives from a socket with specified flags until a non-nil error is
// returned by f. An error of type [Complete] is returned as nil. // returned by f. An error of type [Complete] is returned as nil.
func (c *conn) receive(f HandlerFunc, flags int) error { func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
for { for {
buf := c.buf buf := c.buf
if n, _, err := c.recvfrom(buf, flags); err != nil { if n, _, err := c.recvfrom(ctx, buf, flags); err != nil {
return err return err
} else if n < syscall.NLMSG_HDRLEN { } else if n < syscall.NLMSG_HDRLEN {
return syscall.EBADE return syscall.EBADE
@@ -224,17 +275,17 @@ func (c *conn) receive(f HandlerFunc, flags int) error {
} }
// Roundtrip sends the pending message and handles the reply. // Roundtrip sends the pending message and handles the reply.
func (c *conn) Roundtrip(f HandlerFunc) error { func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error {
if c.buf == nil { if c.buf == nil {
return syscall.EINVAL return syscall.EINVAL
} }
defer func() { c.seq++ }() defer func() { c.seq++ }()
if err := c.sendto(c.pending(), 0, &syscall.SockaddrNetlink{ if err := c.sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK, Family: syscall.AF_NETLINK,
}); err != nil { }); err != nil {
return err return err
} }
return c.receive(f, 0) return c.receive(ctx, f, 0)
} }

View File

@@ -1,6 +1,7 @@
package netlink package netlink
import ( import (
"context"
"syscall" "syscall"
"unsafe" "unsafe"
) )
@@ -72,6 +73,7 @@ func (c *RouteConn) writeIfAddrmsg(
// SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink. // SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink.
func (c *RouteConn) SendIfAddrmsg( func (c *RouteConn) SendIfAddrmsg(
ctx context.Context,
typ, flags uint16, typ, flags uint16,
msg *syscall.IfAddrmsg, msg *syscall.IfAddrmsg,
attrs ...RtAttrMsg[InAddr], attrs ...RtAttrMsg[InAddr],
@@ -79,7 +81,7 @@ func (c *RouteConn) SendIfAddrmsg(
if !c.writeIfAddrmsg(typ, flags, msg, attrs...) { if !c.writeIfAddrmsg(typ, flags, msg, attrs...) {
return syscall.ENOMEM return syscall.ENOMEM
} }
return c.Roundtrip(rtnlConsume) return c.Roundtrip(ctx, rtnlConsume)
} }
// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address. // writeNewaddrLo writes a RTM_NEWADDR message for the loopback address.
@@ -104,11 +106,11 @@ func (c *RouteConn) writeNewaddrLo(lo uint32) bool {
} }
// SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel. // SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel.
func (c *RouteConn) SendNewaddrLo(lo uint32) error { func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error {
if !c.writeNewaddrLo(lo) { if !c.writeNewaddrLo(lo) {
return syscall.ENOMEM return syscall.ENOMEM
} }
return c.Roundtrip(rtnlConsume) return c.Roundtrip(ctx, rtnlConsume)
} }
// writeIfInfomsg writes an ifinfomsg structure to conn. // writeIfInfomsg writes an ifinfomsg structure to conn.
@@ -122,11 +124,12 @@ func (c *RouteConn) writeIfInfomsg(
// SendIfInfomsg sends an ifinfomsg structure to rtnetlink. // SendIfInfomsg sends an ifinfomsg structure to rtnetlink.
func (c *RouteConn) SendIfInfomsg( func (c *RouteConn) SendIfInfomsg(
ctx context.Context,
typ, flags uint16, typ, flags uint16,
msg *syscall.IfInfomsg, msg *syscall.IfInfomsg,
) error { ) error {
if !c.writeIfInfomsg(typ, flags, msg) { if !c.writeIfInfomsg(typ, flags, msg) {
return syscall.ENOMEM return syscall.ENOMEM
} }
return c.Roundtrip(rtnlConsume) return c.Roundtrip(ctx, rtnlConsume)
} }