From 08c35ca24fabc41268411cec14ea386a6d87ece9 Mon Sep 17 00:00:00 2001 From: Ophestra Date: Mon, 16 Mar 2026 23:32:42 +0900 Subject: [PATCH] container: use new netlink implementation This is adapted from the container netlink implementation and is much more reusable. Signed-off-by: Ophestra --- container/dispatcher.go | 44 ++++++- container/netlink.go | 269 -------------------------------------- container/netlink_test.go | 72 ---------- 3 files changed, 43 insertions(+), 342 deletions(-) delete mode 100644 container/netlink.go delete mode 100644 container/netlink_test.go diff --git a/container/dispatcher.go b/container/dispatcher.go index 23786b5..710e439 100644 --- a/container/dispatcher.go +++ b/container/dispatcher.go @@ -3,6 +3,7 @@ package container import ( "io" "io/fs" + "net" "os" "os/exec" "os/signal" @@ -12,6 +13,7 @@ import ( "hakurei.app/container/seccomp" "hakurei.app/container/std" + "hakurei.app/internal/netlink" "hakurei.app/message" ) @@ -167,7 +169,47 @@ func (k direct) mountTmpfs(fsname, target string, flags uintptr, size int, perm func (direct) ensureFile(name string, perm, pperm os.FileMode) error { return ensureFile(name, perm, pperm) } -func (direct) mustLoopback(msg message.Msg) { mustLoopback(msg) } +func (direct) mustLoopback(msg message.Msg) { + var lo int + if ifi, err := net.InterfaceByName("lo"); err != nil { + msg.GetLogger().Fatalln(err) + } else { + lo = ifi.Index + } + + c, err := netlink.DialRoute() + if err != nil { + msg.GetLogger().Fatalln(err) + } + + must := func(err error) { + if err == nil { + return + } + if closeErr := c.Close(); closeErr != nil { + msg.Verbosef("cannot close RTNETLINK: %v", closeErr) + } + + switch err.(type) { + case *os.SyscallError: + msg.GetLogger().Fatalf("cannot %v", err) + + case syscall.Errno: + msg.GetLogger().Fatalf("RTNETLINK answers: %v", err) + + default: + msg.GetLogger().Fatalf("RTNETLINK answers with malformed message") + } + } + must(c.SendNewaddrLo(uint32(lo))) + must(c.SendIfInfomsg(syscall.RTM_NEWLINK, 0, &syscall.IfInfomsg{ + Family: syscall.AF_UNSPEC, + Index: int32(lo), + Flags: syscall.IFF_UP, + Change: syscall.IFF_UP, + })) + must(c.Close()) +} func (direct) seccompLoad(rules []std.NativeRule, flags seccomp.ExportFlag) error { return seccomp.Load(rules, flags) diff --git a/container/netlink.go b/container/netlink.go deleted file mode 100644 index 412f964..0000000 --- a/container/netlink.go +++ /dev/null @@ -1,269 +0,0 @@ -package container - -import ( - "encoding/binary" - "errors" - "net" - "os" - . "syscall" - "unsafe" - - "hakurei.app/container/std" - "hakurei.app/message" -) - -// rtnetlink represents a NETLINK_ROUTE socket. -type rtnetlink struct { - // Sent as part of rtnetlink messages. - pid uint32 - // AF_NETLINK socket. - fd int - // Whether the socket is open. - ok bool - // Message sequence number. - seq uint32 -} - -// open creates the underlying NETLINK_ROUTE socket. -func (s *rtnetlink) open() (err error) { - if s.ok || s.fd < 0 { - return os.ErrInvalid - } - - s.pid = uint32(Getpid()) - if s.fd, err = Socket( - AF_NETLINK, - SOCK_RAW|SOCK_CLOEXEC, - NETLINK_ROUTE, - ); err != nil { - return os.NewSyscallError("socket", err) - } else if err = Bind(s.fd, &SockaddrNetlink{ - Family: AF_NETLINK, - Pid: s.pid, - }); err != nil { - _ = s.close() - return os.NewSyscallError("bind", err) - } else { - s.ok = true - return nil - } -} - -// close closes the underlying NETLINK_ROUTE socket. -func (s *rtnetlink) close() error { - if !s.ok { - return os.ErrInvalid - } - - s.ok = false - err := Close(s.fd) - s.fd = -1 - return err -} - -// roundtrip sends a netlink message and handles the reply. -func (s *rtnetlink) roundtrip(data []byte) error { - if !s.ok { - return os.ErrInvalid - } - - defer func() { s.seq++ }() - - if err := Sendto(s.fd, data, 0, &SockaddrNetlink{ - Family: AF_NETLINK, - }); err != nil { - return os.NewSyscallError("sendto", err) - } - buf := make([]byte, Getpagesize()) - -done: - for { - p := buf - if n, _, err := Recvfrom(s.fd, p, 0); err != nil { - return os.NewSyscallError("recvfrom", err) - } else if n < NLMSG_HDRLEN { - return errors.ErrUnsupported - } else { - p = p[:n] - } - - if msgs, err := ParseNetlinkMessage(p); err != nil { - return err - } else { - for _, m := range msgs { - if m.Header.Seq != s.seq || m.Header.Pid != s.pid { - return errors.ErrUnsupported - } - if m.Header.Type == NLMSG_DONE { - break done - } - if m.Header.Type == NLMSG_ERROR { - if len(m.Data) >= 4 { - errno := Errno(-std.Int(binary.NativeEndian.Uint32(m.Data))) - if errno == 0 { - return nil - } - return errno - } - return errors.ErrUnsupported - } - } - } - } - - return nil -} - -// mustRoundtrip calls roundtrip and terminates via msg for a non-nil error. -func (s *rtnetlink) mustRoundtrip(msg message.Msg, data []byte) { - err := s.roundtrip(data) - if err == nil { - return - } - if closeErr := Close(s.fd); closeErr != nil { - msg.Verbosef("cannot close: %v", err) - } - - switch err.(type) { - case *os.SyscallError: - msg.GetLogger().Fatalf("cannot %v", err) - - case Errno: - msg.GetLogger().Fatalf("RTNETLINK answers: %v", err) - - default: - msg.GetLogger().Fatalln("RTNETLINK answers with unexpected message") - } -} - -// newaddrLo represents a RTM_NEWADDR message with two addresses. -type newaddrLo struct { - header NlMsghdr - data IfAddrmsg - - r0 RtAttr - a0 [4]byte // in_addr - r1 RtAttr - a1 [4]byte // in_addr -} - -// sizeofNewaddrLo is the expected size of newaddrLo. -const sizeofNewaddrLo = NLMSG_HDRLEN + SizeofIfAddrmsg + (SizeofRtAttr+4)*2 - -// newaddrLo returns the address of a populated newaddrLo. -func (s *rtnetlink) newaddrLo(lo int) *newaddrLo { - return &newaddrLo{NlMsghdr{ - Len: sizeofNewaddrLo, - Type: RTM_NEWADDR, - Flags: NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL, - Seq: s.seq, - Pid: s.pid, - }, IfAddrmsg{ - Family: AF_INET, - Prefixlen: 8, - Flags: IFA_F_PERMANENT, - Scope: RT_SCOPE_HOST, - Index: uint32(lo), - }, RtAttr{ - Len: uint16(SizeofRtAttr + len(newaddrLo{}.a0)), - Type: IFA_LOCAL, - }, [4]byte{127, 0, 0, 1}, RtAttr{ - Len: uint16(SizeofRtAttr + len(newaddrLo{}.a1)), - Type: IFA_ADDRESS, - }, [4]byte{127, 0, 0, 1}} -} - -func (msg *newaddrLo) toWireFormat() []byte { - var buf [sizeofNewaddrLo]byte - - *(*uint32)(unsafe.Pointer(&buf[0:4][0])) = msg.header.Len - *(*uint16)(unsafe.Pointer(&buf[4:6][0])) = msg.header.Type - *(*uint16)(unsafe.Pointer(&buf[6:8][0])) = msg.header.Flags - *(*uint32)(unsafe.Pointer(&buf[8:12][0])) = msg.header.Seq - *(*uint32)(unsafe.Pointer(&buf[12:16][0])) = msg.header.Pid - - buf[16] = msg.data.Family - buf[17] = msg.data.Prefixlen - buf[18] = msg.data.Flags - buf[19] = msg.data.Scope - *(*uint32)(unsafe.Pointer(&buf[20:24][0])) = msg.data.Index - - *(*uint16)(unsafe.Pointer(&buf[24:26][0])) = msg.r0.Len - *(*uint16)(unsafe.Pointer(&buf[26:28][0])) = msg.r0.Type - copy(buf[28:32], msg.a0[:]) - *(*uint16)(unsafe.Pointer(&buf[32:34][0])) = msg.r1.Len - *(*uint16)(unsafe.Pointer(&buf[34:36][0])) = msg.r1.Type - copy(buf[36:40], msg.a1[:]) - - return buf[:] -} - -// newlinkLo represents a RTM_NEWLINK message. -type newlinkLo struct { - header NlMsghdr - data IfInfomsg -} - -// sizeofNewlinkLo is the expected size of newlinkLo. -const sizeofNewlinkLo = NLMSG_HDRLEN + SizeofIfInfomsg - -// newlinkLo returns the address of a populated newlinkLo. -func (s *rtnetlink) newlinkLo(lo int) *newlinkLo { - return &newlinkLo{NlMsghdr{ - Len: sizeofNewlinkLo, - Type: RTM_NEWLINK, - Flags: NLM_F_REQUEST | NLM_F_ACK, - Seq: s.seq, - Pid: s.pid, - }, IfInfomsg{ - Family: AF_UNSPEC, - Index: int32(lo), - Flags: IFF_UP, - Change: IFF_UP, - }} -} - -func (msg *newlinkLo) toWireFormat() []byte { - var buf [sizeofNewlinkLo]byte - - *(*uint32)(unsafe.Pointer(&buf[0:4][0])) = msg.header.Len - *(*uint16)(unsafe.Pointer(&buf[4:6][0])) = msg.header.Type - *(*uint16)(unsafe.Pointer(&buf[6:8][0])) = msg.header.Flags - *(*uint32)(unsafe.Pointer(&buf[8:12][0])) = msg.header.Seq - *(*uint32)(unsafe.Pointer(&buf[12:16][0])) = msg.header.Pid - - buf[16] = msg.data.Family - *(*uint16)(unsafe.Pointer(&buf[18:20][0])) = msg.data.Type - *(*int32)(unsafe.Pointer(&buf[20:24][0])) = msg.data.Index - *(*uint32)(unsafe.Pointer(&buf[24:28][0])) = msg.data.Flags - *(*uint32)(unsafe.Pointer(&buf[28:32][0])) = msg.data.Change - - return buf[:] -} - -// mustLoopback creates the loopback address and brings the lo interface up. -// mustLoopback calls a fatal method of the underlying [log.Logger] of m with a -// user-facing error message if RTNETLINK behaves unexpectedly. -func mustLoopback(msg message.Msg) { - log := msg.GetLogger() - - var lo int - if ifi, err := net.InterfaceByName("lo"); err != nil { - log.Fatalln(err) - } else { - lo = ifi.Index - } - - var s rtnetlink - if err := s.open(); err != nil { - log.Fatalln(err) - } - defer func() { - if err := s.close(); err != nil { - msg.Verbosef("cannot close netlink: %v", err) - } - }() - - s.mustRoundtrip(msg, s.newaddrLo(lo).toWireFormat()) - s.mustRoundtrip(msg, s.newlinkLo(lo).toWireFormat()) -} diff --git a/container/netlink_test.go b/container/netlink_test.go deleted file mode 100644 index 8ecc8ea..0000000 --- a/container/netlink_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package container - -import ( - "testing" - "unsafe" -) - -func TestSizeof(t *testing.T) { - if got := unsafe.Sizeof(newaddrLo{}); got != sizeofNewaddrLo { - t.Fatalf("newaddrLo: sizeof = %#x, want %#x", got, sizeofNewaddrLo) - } - - if got := unsafe.Sizeof(newlinkLo{}); got != sizeofNewlinkLo { - t.Fatalf("newlinkLo: sizeof = %#x, want %#x", got, sizeofNewlinkLo) - } -} - -func TestRtnetlinkMessage(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - msg interface{ toWireFormat() []byte } - want []byte - }{ - {"newaddrLo", (&rtnetlink{pid: 1, seq: 0}).newaddrLo(1), []byte{ - /* Len */ 0x28, 0, 0, 0, - /* Type */ 0x14, 0, - /* Flags */ 5, 6, - /* Seq */ 0, 0, 0, 0, - /* Pid */ 1, 0, 0, 0, - - /* Family */ 2, - /* Prefixlen */ 8, - /* Flags */ 0x80, - /* Scope */ 0xfe, - /* Index */ 1, 0, 0, 0, - - /* Len */ 8, 0, - /* Type */ 2, 0, - /* in_addr */ 127, 0, 0, 1, - - /* Len */ 8, 0, - /* Type */ 1, 0, - /* in_addr */ 127, 0, 0, 1, - }}, - - {"newlinkLo", (&rtnetlink{pid: 1, seq: 1}).newlinkLo(1), []byte{ - /* Len */ 0x20, 0, 0, 0, - /* Type */ 0x10, 0, - /* Flags */ 5, 0, - /* Seq */ 1, 0, 0, 0, - /* Pid */ 1, 0, 0, 0, - - /* Family */ 0, - /* pad */ 0, - /* Type */ 0, 0, - /* Index */ 1, 0, 0, 0, - /* Flags */ 1, 0, 0, 0, - /* Change */ 1, 0, 0, 0, - }}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - if got := tc.msg.toWireFormat(); string(got) != string(tc.want) { - t.Fatalf("toWireFormat: %#v, want %#v", got, tc.want) - } - }) - } -}