diff --git a/container/dispatcher.go b/container/dispatcher.go index dce5f023..9f971fa5 100644 --- a/container/dispatcher.go +++ b/container/dispatcher.go @@ -1,6 +1,7 @@ package container import ( + "context" "io" "io/fs" "net" @@ -66,7 +67,7 @@ type syscallDispatcher interface { // ensureFile provides ensureFile. ensureFile(name string, perm, pperm os.FileMode) error // mustLoopback provides mustLoopback. - mustLoopback(msg message.Msg) + mustLoopback(ctx context.Context, msg message.Msg) // seccompLoad provides [seccomp.Load]. seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error @@ -170,7 +171,7 @@ func (k direct) mountTmpfs(fsname, target string, flags uintptr, size int, perm func (direct) ensureFile(name string, perm, pperm os.FileMode) error { return ensureFile(name, perm, pperm) } -func (direct) mustLoopback(msg message.Msg) { +func (direct) mustLoopback(ctx context.Context, msg message.Msg) { var lo int if ifi, err := net.InterfaceByName("lo"); err != nil { msg.GetLogger().Fatalln(err) @@ -199,11 +200,14 @@ func (direct) mustLoopback(msg message.Msg) { msg.GetLogger().Fatalf("RTNETLINK answers: %v", err) default: - msg.GetLogger().Fatalf("RTNETLINK answers with malformed message") + if err == context.DeadlineExceeded || err == context.Canceled { + msg.GetLogger().Fatalf("interrupted RTNETLINK operation") + } + msg.GetLogger().Fatal("RTNETLINK answers with malformed message") } } - must(c.SendNewaddrLo(uint32(lo))) - must(c.SendIfInfomsg(syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{ + must(c.SendNewaddrLo(ctx, uint32(lo))) + must(c.SendIfInfomsg(ctx, syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{ Family: syscall.AF_UNSPEC, Index: int32(lo), Flags: syscall.IFF_UP, diff --git a/container/dispatcher_test.go b/container/dispatcher_test.go index ee2c3f19..b57c53b9 100644 --- a/container/dispatcher_test.go +++ b/container/dispatcher_test.go @@ -2,6 +2,7 @@ package container import ( "bytes" + "context" "fmt" "io" "io/fs" @@ -468,7 +469,7 @@ func (k *kstub) ensureFile(name string, perm, pperm os.FileMode) error { stub.CheckArg(k.Stub, "pperm", pperm, 2)) } -func (*kstub) mustLoopback(message.Msg) { /* noop */ } +func (*kstub) mustLoopback(context.Context, message.Msg) { /* noop */ } func (k *kstub) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error { k.Helper() diff --git a/container/init.go b/container/init.go index 799bb29b..7a561733 100644 --- a/container/init.go +++ b/container/init.go @@ -7,6 +7,7 @@ import ( "log" "os" "os/exec" + "os/signal" "path" "slices" "strconv" @@ -175,7 +176,11 @@ func initEntrypoint(k syscallDispatcher, msg message.Msg) { } if !params.HostNet { - k.mustLoopback(msg) + ctx, cancel := signal.NotifyContext(context.Background(), CancelSignal, + os.Interrupt, SIGTERM, SIGQUIT) + defer cancel() // for panics + k.mustLoopback(ctx, msg) + cancel() } // write uid/gid map here so parent does not need to set dumpable diff --git a/internal/netlink/netlink.go b/internal/netlink/netlink.go index df6af840..406f567b 100644 --- a/internal/netlink/netlink.go +++ b/internal/netlink/netlink.go @@ -2,10 +2,12 @@ package netlink import ( + "context" "fmt" "os" "sync" "syscall" + "time" "unsafe" ) @@ -37,6 +39,10 @@ type conn struct { pos int // A page holding incoming and outgoing messages. buf []byte + // An instant some time after conn was established, but before the first + // I/O operation on f through raw. This serves as a cached deadline to + // cancel blocking I/O. + t time.Time } // dial returns the address of a newly connected conn of specified family. @@ -65,6 +71,7 @@ func dial(family int) (*conn, error) { c.pos = syscall.NLMSG_HDRLEN c.buf = make([]byte, os.Getpagesize()) + c.t = time.Now().UTC() return &c, nil } @@ -79,35 +86,79 @@ func (c *conn) Close() error { // recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller. func (c *conn) recvfrom( + ctx context.Context, p []byte, flags int, ) (n int, from syscall.Sockaddr, err error) { - rcErr := c.raw.Read(func(fd uintptr) (done bool) { - n, from, err = syscall.Recvfrom(int(fd), p, flags) - return err != syscall.EWOULDBLOCK - }) - if err != nil { - err = os.NewSyscallError("recvfrom", err) - } else { - err = rcErr + if err = c.f.SetReadDeadline(time.Time{}); err != nil { + return + } + + done := make(chan error, 1) + go func() { + done <- c.raw.Read(func(fd uintptr) (done bool) { + n, from, err = syscall.Recvfrom(int(fd), p, flags) + return err != syscall.EWOULDBLOCK + }) + }() + + select { + case rcErr := <-done: + if err != nil { + err = os.NewSyscallError("recvfrom", err) + } else { + err = rcErr + } + + case <-ctx.Done(): + cancelErr := c.f.SetReadDeadline(c.t) + <-done + if cancelErr != nil { + err = cancelErr + } else { + err = ctx.Err() + } + return } return } // sendto wraps send(2) with nonblocking behaviour via the runtime network poller. func (c *conn) sendto( + ctx context.Context, p []byte, flags int, to syscall.Sockaddr, ) (err error) { - rcErr := c.raw.Write(func(fd uintptr) (done bool) { - err = syscall.Sendto(int(fd), p, flags, to) - return err != syscall.EWOULDBLOCK - }) - if err != nil { - err = os.NewSyscallError("sendto", err) - } else { - err = rcErr + if err = c.f.SetWriteDeadline(time.Time{}); err != nil { + return + } + + done := make(chan error, 1) + go func() { + done <- c.raw.Write(func(fd uintptr) (done bool) { + err = syscall.Sendto(int(fd), p, flags, to) + return err != syscall.EWOULDBLOCK + }) + }() + + select { + case rcErr := <-done: + if err != nil { + err = os.NewSyscallError("sendto", err) + } else { + err = rcErr + } + + case <-ctx.Done(): + cancelErr := c.f.SetWriteDeadline(c.t) + <-done + if cancelErr != nil { + err = cancelErr + } else { + err = ctx.Err() + } + return } return } @@ -192,10 +243,10 @@ type HandlerFunc func(resp []syscall.NetlinkMessage) error // receive receives from a socket with specified flags until a non-nil error is // returned by f. An error of type [Complete] is returned as nil. -func (c *conn) receive(f HandlerFunc, flags int) error { +func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error { for { buf := c.buf - if n, _, err := c.recvfrom(buf, flags); err != nil { + if n, _, err := c.recvfrom(ctx, buf, flags); err != nil { return err } else if n < syscall.NLMSG_HDRLEN { return syscall.EBADE @@ -224,17 +275,17 @@ func (c *conn) receive(f HandlerFunc, flags int) error { } // Roundtrip sends the pending message and handles the reply. -func (c *conn) Roundtrip(f HandlerFunc) error { +func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error { if c.buf == nil { return syscall.EINVAL } defer func() { c.seq++ }() - if err := c.sendto(c.pending(), 0, &syscall.SockaddrNetlink{ + if err := c.sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{ Family: syscall.AF_NETLINK, }); err != nil { return err } - return c.receive(f, 0) + return c.receive(ctx, f, 0) } diff --git a/internal/netlink/rtnl.go b/internal/netlink/rtnl.go index e6585bf8..e824f091 100644 --- a/internal/netlink/rtnl.go +++ b/internal/netlink/rtnl.go @@ -1,6 +1,7 @@ package netlink import ( + "context" "syscall" "unsafe" ) @@ -72,6 +73,7 @@ func (c *RouteConn) writeIfAddrmsg( // SendIfAddrmsg sends an ifaddrmsg structure to rtnetlink. func (c *RouteConn) SendIfAddrmsg( + ctx context.Context, typ, flags uint16, msg *syscall.IfAddrmsg, attrs ...RtAttrMsg[InAddr], @@ -79,7 +81,7 @@ func (c *RouteConn) SendIfAddrmsg( if !c.writeIfAddrmsg(typ, flags, msg, attrs...) { return syscall.ENOMEM } - return c.Roundtrip(rtnlConsume) + return c.Roundtrip(ctx, rtnlConsume) } // writeNewaddrLo writes a RTM_NEWADDR message for the loopback address. @@ -104,11 +106,11 @@ func (c *RouteConn) writeNewaddrLo(lo uint32) bool { } // SendNewaddrLo sends a RTM_NEWADDR message for the loopback address to the kernel. -func (c *RouteConn) SendNewaddrLo(lo uint32) error { +func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error { if !c.writeNewaddrLo(lo) { return syscall.ENOMEM } - return c.Roundtrip(rtnlConsume) + return c.Roundtrip(ctx, rtnlConsume) } // writeIfInfomsg writes an ifinfomsg structure to conn. @@ -122,11 +124,12 @@ func (c *RouteConn) writeIfInfomsg( // SendIfInfomsg sends an ifinfomsg structure to rtnetlink. func (c *RouteConn) SendIfInfomsg( + ctx context.Context, typ, flags uint16, msg *syscall.IfInfomsg, ) error { if !c.writeIfInfomsg(typ, flags, msg) { return syscall.ENOMEM } - return c.Roundtrip(rtnlConsume) + return c.Roundtrip(ctx, rtnlConsume) }