forked from rosa/hakurei
internal/netlink: wrap netpoll via context
This removes netpoll boilerplate for the most common use case. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package container
|
package container
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
"net"
|
||||||
@@ -66,7 +67,7 @@ type syscallDispatcher interface {
|
|||||||
// ensureFile provides ensureFile.
|
// ensureFile provides ensureFile.
|
||||||
ensureFile(name string, perm, pperm os.FileMode) error
|
ensureFile(name string, perm, pperm os.FileMode) error
|
||||||
// mustLoopback provides mustLoopback.
|
// mustLoopback provides mustLoopback.
|
||||||
mustLoopback(msg message.Msg)
|
mustLoopback(ctx context.Context, msg message.Msg)
|
||||||
|
|
||||||
// seccompLoad provides [seccomp.Load].
|
// seccompLoad provides [seccomp.Load].
|
||||||
seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error
|
seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error
|
||||||
@@ -170,7 +171,7 @@ func (k direct) mountTmpfs(fsname, target string, flags uintptr, size int, perm
|
|||||||
func (direct) ensureFile(name string, perm, pperm os.FileMode) error {
|
func (direct) ensureFile(name string, perm, pperm os.FileMode) error {
|
||||||
return ensureFile(name, perm, pperm)
|
return ensureFile(name, perm, pperm)
|
||||||
}
|
}
|
||||||
func (direct) mustLoopback(msg message.Msg) {
|
func (direct) mustLoopback(ctx context.Context, msg message.Msg) {
|
||||||
var lo int
|
var lo int
|
||||||
if ifi, err := net.InterfaceByName("lo"); err != nil {
|
if ifi, err := net.InterfaceByName("lo"); err != nil {
|
||||||
msg.GetLogger().Fatalln(err)
|
msg.GetLogger().Fatalln(err)
|
||||||
@@ -199,11 +200,14 @@ func (direct) mustLoopback(msg message.Msg) {
|
|||||||
msg.GetLogger().Fatalf("RTNETLINK answers: %v", err)
|
msg.GetLogger().Fatalf("RTNETLINK answers: %v", err)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
msg.GetLogger().Fatalf("RTNETLINK answers with malformed message")
|
if err == context.DeadlineExceeded || err == context.Canceled {
|
||||||
|
msg.GetLogger().Fatalf("interrupted RTNETLINK operation")
|
||||||
|
}
|
||||||
|
msg.GetLogger().Fatal("RTNETLINK answers with malformed message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
must(c.SendNewaddrLo(uint32(lo)))
|
must(c.SendNewaddrLo(ctx, uint32(lo)))
|
||||||
must(c.SendIfInfomsg(syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{
|
must(c.SendIfInfomsg(ctx, syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{
|
||||||
Family: syscall.AF_UNSPEC,
|
Family: syscall.AF_UNSPEC,
|
||||||
Index: int32(lo),
|
Index: int32(lo),
|
||||||
Flags: syscall.IFF_UP,
|
Flags: syscall.IFF_UP,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package container
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@@ -468,7 +469,7 @@ func (k *kstub) ensureFile(name string, perm, pperm os.FileMode) error {
|
|||||||
stub.CheckArg(k.Stub, "pperm", pperm, 2))
|
stub.CheckArg(k.Stub, "pperm", pperm, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*kstub) mustLoopback(message.Msg) { /* noop */ }
|
func (*kstub) mustLoopback(context.Context, message.Msg) { /* noop */ }
|
||||||
|
|
||||||
func (k *kstub) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error {
|
func (k *kstub) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error {
|
||||||
k.Helper()
|
k.Helper()
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
"path"
|
"path"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -175,7 +176,11 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !params.HostNet {
|
if !params.HostNet {
|
||||||
k.mustLoopback(msg)
|
ctx, cancel := signal.NotifyContext(context.Background(), CancelSignal,
|
||||||
|
os.Interrupt, SIGTERM, SIGQUIT)
|
||||||
|
defer cancel() // for panics
|
||||||
|
k.mustLoopback(ctx, msg)
|
||||||
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// write uid/gid map here so parent does not need to set dumpable
|
// write uid/gid map here so parent does not need to set dumpable
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,6 +39,10 @@ type conn struct {
|
|||||||
pos int
|
pos int
|
||||||
// A page holding incoming and outgoing messages.
|
// A page holding incoming and outgoing messages.
|
||||||
buf []byte
|
buf []byte
|
||||||
|
// An instant some time after conn was established, but before the first
|
||||||
|
// I/O operation on f through raw. This serves as a cached deadline to
|
||||||
|
// cancel blocking I/O.
|
||||||
|
t time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial returns the address of a newly connected conn of specified family.
|
// dial returns the address of a newly connected conn of specified family.
|
||||||
@@ -65,6 +71,7 @@ func dial(family int) (*conn, error) {
|
|||||||
|
|
||||||
c.pos = syscall.NLMSG_HDRLEN
|
c.pos = syscall.NLMSG_HDRLEN
|
||||||
c.buf = make([]byte, os.Getpagesize())
|
c.buf = make([]byte, os.Getpagesize())
|
||||||
|
c.t = time.Now().UTC()
|
||||||
return &c, nil
|
return &c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,35 +86,79 @@ func (c *conn) Close() error {
|
|||||||
|
|
||||||
// recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller.
|
// recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller.
|
||||||
func (c *conn) recvfrom(
|
func (c *conn) recvfrom(
|
||||||
|
ctx context.Context,
|
||||||
p []byte,
|
p []byte,
|
||||||
flags int,
|
flags int,
|
||||||
) (n int, from syscall.Sockaddr, err error) {
|
) (n int, from syscall.Sockaddr, err error) {
|
||||||
rcErr := c.raw.Read(func(fd uintptr) (done bool) {
|
if err = c.f.SetReadDeadline(time.Time{}); err != nil {
|
||||||
n, from, err = syscall.Recvfrom(int(fd), p, flags)
|
return
|
||||||
return err != syscall.EWOULDBLOCK
|
}
|
||||||
})
|
|
||||||
if err != nil {
|
done := make(chan error, 1)
|
||||||
err = os.NewSyscallError("recvfrom", err)
|
go func() {
|
||||||
} else {
|
done <- c.raw.Read(func(fd uintptr) (done bool) {
|
||||||
err = rcErr
|
n, from, err = syscall.Recvfrom(int(fd), p, flags)
|
||||||
|
return err != syscall.EWOULDBLOCK
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case rcErr := <-done:
|
||||||
|
if err != nil {
|
||||||
|
err = os.NewSyscallError("recvfrom", err)
|
||||||
|
} else {
|
||||||
|
err = rcErr
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
cancelErr := c.f.SetReadDeadline(c.t)
|
||||||
|
<-done
|
||||||
|
if cancelErr != nil {
|
||||||
|
err = cancelErr
|
||||||
|
} else {
|
||||||
|
err = ctx.Err()
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendto wraps send(2) with nonblocking behaviour via the runtime network poller.
|
// sendto wraps send(2) with nonblocking behaviour via the runtime network poller.
|
||||||
func (c *conn) sendto(
|
func (c *conn) sendto(
|
||||||
|
ctx context.Context,
|
||||||
p []byte,
|
p []byte,
|
||||||
flags int,
|
flags int,
|
||||||
to syscall.Sockaddr,
|
to syscall.Sockaddr,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
rcErr := c.raw.Write(func(fd uintptr) (done bool) {
|
if err = c.f.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
err = syscall.Sendto(int(fd), p, flags, to)
|
return
|
||||||
return err != syscall.EWOULDBLOCK
|
}
|
||||||
})
|
|
||||||
if err != nil {
|
done := make(chan error, 1)
|
||||||
err = os.NewSyscallError("sendto", err)
|
go func() {
|
||||||
} else {
|
done <- c.raw.Write(func(fd uintptr) (done bool) {
|
||||||
err = rcErr
|
err = syscall.Sendto(int(fd), p, flags, to)
|
||||||
|
return err != syscall.EWOULDBLOCK
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case rcErr := <-done:
|
||||||
|
if err != nil {
|
||||||
|
err = os.NewSyscallError("sendto", err)
|
||||||
|
} else {
|
||||||
|
err = rcErr
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-ctx.Done():
|
||||||
|
cancelErr := c.f.SetWriteDeadline(c.t)
|
||||||
|
<-done
|
||||||
|
if cancelErr != nil {
|
||||||
|
err = cancelErr
|
||||||
|
} else {
|
||||||
|
err = ctx.Err()
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -192,10 +243,10 @@ type HandlerFunc func(resp []syscall.NetlinkMessage) error
|
|||||||
|
|
||||||
// receive receives from a socket with specified flags until a non-nil error is
|
// receive receives from a socket with specified flags until a non-nil error is
|
||||||
// returned by f. An error of type [Complete] is returned as nil.
|
// returned by f. An error of type [Complete] is returned as nil.
|
||||||
func (c *conn) receive(f HandlerFunc, flags int) error {
|
func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
||||||
for {
|
for {
|
||||||
buf := c.buf
|
buf := c.buf
|
||||||
if n, _, err := c.recvfrom(buf, flags); err != nil {
|
if n, _, err := c.recvfrom(ctx, buf, flags); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if n < syscall.NLMSG_HDRLEN {
|
} else if n < syscall.NLMSG_HDRLEN {
|
||||||
return syscall.EBADE
|
return syscall.EBADE
|
||||||
@@ -224,17 +275,17 @@ func (c *conn) receive(f HandlerFunc, flags int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Roundtrip sends the pending message and handles the reply.
|
// Roundtrip sends the pending message and handles the reply.
|
||||||
func (c *conn) Roundtrip(f HandlerFunc) error {
|
func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error {
|
||||||
if c.buf == nil {
|
if c.buf == nil {
|
||||||
return syscall.EINVAL
|
return syscall.EINVAL
|
||||||
}
|
}
|
||||||
defer func() { c.seq++ }()
|
defer func() { c.seq++ }()
|
||||||
|
|
||||||
if err := c.sendto(c.pending(), 0, &syscall.SockaddrNetlink{
|
if err := c.sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{
|
||||||
Family: syscall.AF_NETLINK,
|
Family: syscall.AF_NETLINK,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.receive(f, 0)
|
return c.receive(ctx, f, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@@ -72,6 +73,7 @@ func (c *RouteConn) writeIfAddrmsg(
|
|||||||
|
|
||||||
// SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink.
|
// SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink.
|
||||||
func (c *RouteConn) SendIfAddrmsg(
|
func (c *RouteConn) SendIfAddrmsg(
|
||||||
|
ctx context.Context,
|
||||||
typ, flags uint16,
|
typ, flags uint16,
|
||||||
msg *syscall.IfAddrmsg,
|
msg *syscall.IfAddrmsg,
|
||||||
attrs ...RtAttrMsg[InAddr],
|
attrs ...RtAttrMsg[InAddr],
|
||||||
@@ -79,7 +81,7 @@ func (c *RouteConn) SendIfAddrmsg(
|
|||||||
if !c.writeIfAddrmsg(typ, flags, msg, attrs...) {
|
if !c.writeIfAddrmsg(typ, flags, msg, attrs...) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(rtnlConsume)
|
return c.Roundtrip(ctx, rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address.
|
// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address.
|
||||||
@@ -104,11 +106,11 @@ func (c *RouteConn) writeNewaddrLo(lo uint32) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel.
|
// SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel.
|
||||||
func (c *RouteConn) SendNewaddrLo(lo uint32) error {
|
func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error {
|
||||||
if !c.writeNewaddrLo(lo) {
|
if !c.writeNewaddrLo(lo) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(rtnlConsume)
|
return c.Roundtrip(ctx, rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeIfInfomsg writes an ifinfomsg structure to conn.
|
// writeIfInfomsg writes an ifinfomsg structure to conn.
|
||||||
@@ -122,11 +124,12 @@ func (c *RouteConn) writeIfInfomsg(
|
|||||||
|
|
||||||
// SendIfInfomsg sends an ifinfomsg structure to rtnetlink.
|
// SendIfInfomsg sends an ifinfomsg structure to rtnetlink.
|
||||||
func (c *RouteConn) SendIfInfomsg(
|
func (c *RouteConn) SendIfInfomsg(
|
||||||
|
ctx context.Context,
|
||||||
typ, flags uint16,
|
typ, flags uint16,
|
||||||
msg *syscall.IfInfomsg,
|
msg *syscall.IfInfomsg,
|
||||||
) error {
|
) error {
|
||||||
if !c.writeIfInfomsg(typ, flags, msg) {
|
if !c.writeIfInfomsg(typ, flags, msg) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(rtnlConsume)
|
return c.Roundtrip(ctx, rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user