internal/netlink: export generic connection
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m8s
Test / ShareFS (push) Successful in 4m18s
Test / Hakurei (push) Successful in 4m23s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m27s
All checks were successful
Test / Create distribution (push) Successful in 1m16s
Test / Sandbox (push) Successful in 3m8s
Test / ShareFS (push) Successful in 4m18s
Test / Hakurei (push) Successful in 4m23s
Test / Sandbox (race detector) (push) Successful in 5m36s
Test / Hakurei (race detector) (push) Successful in 6m43s
Test / Flake checks (push) Successful in 1m27s
This enables abstractions around some families to be implemented in a separate package. Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
@@ -18,8 +18,8 @@ const (
|
|||||||
stateOpen uint32 = 1 << iota
|
stateOpen uint32 = 1 << iota
|
||||||
)
|
)
|
||||||
|
|
||||||
// A conn represents resources associated to a netlink socket.
|
// A Conn represents resources associated to a netlink socket.
|
||||||
type conn struct {
|
type Conn struct {
|
||||||
// AF_NETLINK socket.
|
// AF_NETLINK socket.
|
||||||
f *os.File
|
f *os.File
|
||||||
// For using runtime polling via f.
|
// For using runtime polling via f.
|
||||||
@@ -44,9 +44,10 @@ type conn struct {
|
|||||||
t time.Time
|
t time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial returns the address of a newly connected conn of specified family.
|
// Dial returns the address of a newly connected generic netlink connection of
|
||||||
func dial(family int, groups uint32) (*conn, error) {
|
// specified family and groups.
|
||||||
var c conn
|
func Dial(family int, groups uint32) (*Conn, error) {
|
||||||
|
var c Conn
|
||||||
if fd, err := syscall.Socket(
|
if fd, err := syscall.Socket(
|
||||||
syscall.AF_NETLINK,
|
syscall.AF_NETLINK,
|
||||||
syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC,
|
syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC,
|
||||||
@@ -89,10 +90,10 @@ func dial(family int, groups uint32) (*conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ok returns whether conn is still open.
|
// ok returns whether conn is still open.
|
||||||
func (c *conn) ok() bool { return c.state&stateOpen != 0 }
|
func (c *Conn) ok() bool { return c.state&stateOpen != 0 }
|
||||||
|
|
||||||
// Close closes the underlying socket.
|
// Close closes the underlying socket.
|
||||||
func (c *conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
if !c.ok() {
|
if !c.ok() {
|
||||||
return syscall.EINVAL
|
return syscall.EINVAL
|
||||||
}
|
}
|
||||||
@@ -100,35 +101,41 @@ func (c *conn) Close() error {
|
|||||||
return c.f.Close()
|
return c.f.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller.
|
// Recvfrom wraps recv(2) with nonblocking behaviour via the runtime network poller.
|
||||||
func (c *conn) recvfrom(
|
//
|
||||||
|
// The returned slice is valid until the next call to Recvfrom.
|
||||||
|
func (c *Conn) Recvfrom(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
p []byte,
|
|
||||||
flags int,
|
flags int,
|
||||||
) (n int, from syscall.Sockaddr, err error) {
|
) (data []byte, from syscall.Sockaddr, err error) {
|
||||||
if err = c.f.SetReadDeadline(time.Time{}); err != nil {
|
if err = c.f.SetReadDeadline(time.Time{}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
data = c.buf[:]
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
done <- c.raw.Read(func(fd uintptr) (done bool) {
|
done <- c.raw.Read(func(fd uintptr) (done bool) {
|
||||||
n, from, err = syscall.Recvfrom(int(fd), p, flags)
|
n, from, err = syscall.Recvfrom(int(fd), data, flags)
|
||||||
return err != syscall.EWOULDBLOCK
|
return err != syscall.EWOULDBLOCK
|
||||||
})
|
})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case rcErr := <-done:
|
case rcErr := <-done:
|
||||||
|
data = data[:n]
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = os.NewSyscallError("recvfrom", err)
|
err = os.NewSyscallError("recvfrom", err)
|
||||||
} else {
|
} else {
|
||||||
err = rcErr
|
err = rcErr
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
cancelErr := c.f.SetReadDeadline(c.t)
|
cancelErr := c.f.SetReadDeadline(c.t)
|
||||||
<-done
|
<-done
|
||||||
|
data = data[:n]
|
||||||
if cancelErr != nil {
|
if cancelErr != nil {
|
||||||
err = cancelErr
|
err = cancelErr
|
||||||
} else {
|
} else {
|
||||||
@@ -136,11 +143,10 @@ func (c *conn) recvfrom(
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendto wraps send(2) with nonblocking behaviour via the runtime network poller.
|
// Sendto wraps send(2) with nonblocking behaviour via the runtime network poller.
|
||||||
func (c *conn) sendto(
|
func (c *Conn) Sendto(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
p []byte,
|
p []byte,
|
||||||
flags int,
|
flags int,
|
||||||
@@ -165,6 +171,7 @@ func (c *conn) sendto(
|
|||||||
} else {
|
} else {
|
||||||
err = rcErr
|
err = rcErr
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
cancelErr := c.f.SetWriteDeadline(c.t)
|
cancelErr := c.f.SetWriteDeadline(c.t)
|
||||||
@@ -176,7 +183,6 @@ func (c *conn) sendto(
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Msg is type constraint for types sent over the wire via netlink.
|
// Msg is type constraint for types sent over the wire via netlink.
|
||||||
@@ -198,7 +204,7 @@ func As[M Msg](data []byte) *M {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add queues a value to be sent by conn.
|
// add queues a value to be sent by conn.
|
||||||
func add[M Msg](c *conn, p *M) bool {
|
func add[M Msg](c *Conn, p *M) bool {
|
||||||
pos := c.pos
|
pos := c.pos
|
||||||
c.pos += int(unsafe.Sizeof(*p))
|
c.pos += int(unsafe.Sizeof(*p))
|
||||||
if c.pos > len(c.buf) {
|
if c.pos > len(c.buf) {
|
||||||
@@ -233,7 +239,7 @@ func (e *InconsistentError) Error() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkReply checks the message header of a reply from the kernel.
|
// checkReply checks the message header of a reply from the kernel.
|
||||||
func (c *conn) checkReply(header *syscall.NlMsghdr) error {
|
func (c *Conn) checkReply(header *syscall.NlMsghdr) error {
|
||||||
if header.Seq != c.seq || header.Pid != c.port {
|
if header.Seq != c.seq || header.Pid != c.port {
|
||||||
return &InconsistentError{*header, c.seq, c.port}
|
return &InconsistentError{*header, c.seq, c.port}
|
||||||
}
|
}
|
||||||
@@ -241,7 +247,7 @@ func (c *conn) checkReply(header *syscall.NlMsghdr) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pending returns the valid slice of buf and initialises pos.
|
// pending returns the valid slice of buf and initialises pos.
|
||||||
func (c *conn) pending() []byte {
|
func (c *Conn) pending() []byte {
|
||||||
buf := c.buf[:c.pos]
|
buf := c.buf[:c.pos]
|
||||||
c.pos = syscall.NLMSG_HDRLEN
|
c.pos = syscall.NLMSG_HDRLEN
|
||||||
|
|
||||||
@@ -267,23 +273,18 @@ type HandlerFunc func(resp []syscall.NetlinkMessage) error
|
|||||||
|
|
||||||
// receive receives from a socket with specified flags until a non-nil error is
|
// receive receives from a socket with specified flags until a non-nil error is
|
||||||
// returned by f. An error of type [Complete] is returned as nil.
|
// returned by f. An error of type [Complete] is returned as nil.
|
||||||
func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
func (c *Conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
||||||
for {
|
for {
|
||||||
buf := c.buf[:]
|
var resp []syscall.NetlinkMessage
|
||||||
if n, _, err := c.recvfrom(ctx, buf, flags); err != nil {
|
if data, _, err := c.Recvfrom(ctx, flags); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if n < syscall.NLMSG_HDRLEN {
|
} else if len(data) < syscall.NLMSG_HDRLEN {
|
||||||
return syscall.EBADE
|
return syscall.EBADE
|
||||||
} else {
|
} else if resp, err = syscall.ParseNetlinkMessage(data); err != nil {
|
||||||
buf = buf[:n]
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := syscall.ParseNetlinkMessage(buf)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = f(resp); err != nil {
|
if err := f(resp); err != nil {
|
||||||
if err == (Complete{}) {
|
if err == (Complete{}) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -293,13 +294,13 @@ func (c *conn) receive(ctx context.Context, f HandlerFunc, flags int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Roundtrip sends the pending message and handles the reply.
|
// Roundtrip sends the pending message and handles the reply.
|
||||||
func (c *conn) Roundtrip(ctx context.Context, f HandlerFunc) error {
|
func (c *Conn) Roundtrip(ctx context.Context, f HandlerFunc) error {
|
||||||
if !c.ok() {
|
if !c.ok() {
|
||||||
return syscall.EINVAL
|
return syscall.EINVAL
|
||||||
}
|
}
|
||||||
defer func() { c.seq++ }()
|
defer func() { c.seq++ }()
|
||||||
|
|
||||||
if err := c.sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{
|
if err := c.Sendto(ctx, c.pending(), 0, &syscall.SockaddrNetlink{
|
||||||
Family: syscall.AF_NETLINK,
|
Family: syscall.AF_NETLINK,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
type payloadTestCase struct {
|
type payloadTestCase struct {
|
||||||
name string
|
name string
|
||||||
f func(c *conn)
|
f func(c *Conn)
|
||||||
want []byte
|
want []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ func checkPayload(t *testing.T, testCases []payloadTestCase) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c := conn{port: 1, pos: syscall.NLMSG_HDRLEN}
|
c := Conn{port: 1, pos: syscall.NLMSG_HDRLEN}
|
||||||
tc.f(&c)
|
tc.f(&c)
|
||||||
if got := c.pending(); string(got) != string(tc.want) {
|
if got := c.pending(); string(got) != string(tc.want) {
|
||||||
t.Errorf("pending: %#v, want %#v", got, tc.want)
|
t.Errorf("pending: %#v, want %#v", got, tc.want)
|
||||||
|
|||||||
@@ -7,11 +7,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// RouteConn represents a NETLINK_ROUTE socket.
|
// RouteConn represents a NETLINK_ROUTE socket.
|
||||||
type RouteConn struct{ *conn }
|
type RouteConn struct{ conn *Conn }
|
||||||
|
|
||||||
|
// Close closes the underlying socket.
|
||||||
|
func (c *RouteConn) Close() error { return c.conn.Close() }
|
||||||
|
|
||||||
// DialRoute returns the address of a newly connected [RouteConn].
|
// DialRoute returns the address of a newly connected [RouteConn].
|
||||||
func DialRoute() (*RouteConn, error) {
|
func DialRoute() (*RouteConn, error) {
|
||||||
c, err := dial(syscall.NETLINK_ROUTE, 0)
|
c, err := Dial(syscall.NETLINK_ROUTE, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -19,9 +22,9 @@ func DialRoute() (*RouteConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// rtnlConsume consumes a message from rtnetlink.
|
// rtnlConsume consumes a message from rtnetlink.
|
||||||
func (c *conn) rtnlConsume(resp []syscall.NetlinkMessage) error {
|
func (c *RouteConn) rtnlConsume(resp []syscall.NetlinkMessage) error {
|
||||||
for i := range resp {
|
for i := range resp {
|
||||||
if err := c.checkReply(&resp[i].Header); err != nil {
|
if err := c.conn.checkReply(&resp[i].Header); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +65,7 @@ func (c *RouteConn) writeIfAddrmsg(
|
|||||||
msg *syscall.IfAddrmsg,
|
msg *syscall.IfAddrmsg,
|
||||||
attrs ...RtAttrMsg[InAddr],
|
attrs ...RtAttrMsg[InAddr],
|
||||||
) bool {
|
) bool {
|
||||||
c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags
|
c.conn.typ, c.conn.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags
|
||||||
if !add(c.conn, msg) {
|
if !add(c.conn, msg) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -85,7 +88,7 @@ func (c *RouteConn) SendIfAddrmsg(
|
|||||||
if !c.writeIfAddrmsg(typ, flags, msg, attrs...) {
|
if !c.writeIfAddrmsg(typ, flags, msg, attrs...) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(ctx, c.conn.rtnlConsume)
|
return c.conn.Roundtrip(ctx, c.rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address.
|
// writeNewaddrLo writes a RTM_NEWADDR message for the loopback address.
|
||||||
@@ -114,7 +117,7 @@ func (c *RouteConn) SendNewaddrLo(ctx context.Context, lo uint32) error {
|
|||||||
if !c.writeNewaddrLo(lo) {
|
if !c.writeNewaddrLo(lo) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(ctx, c.conn.rtnlConsume)
|
return c.conn.Roundtrip(ctx, c.rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeIfInfomsg writes an ifinfomsg structure to conn.
|
// writeIfInfomsg writes an ifinfomsg structure to conn.
|
||||||
@@ -122,7 +125,7 @@ func (c *RouteConn) writeIfInfomsg(
|
|||||||
typ, flags uint16,
|
typ, flags uint16,
|
||||||
msg *syscall.IfInfomsg,
|
msg *syscall.IfInfomsg,
|
||||||
) bool {
|
) bool {
|
||||||
c.typ, c.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags
|
c.conn.typ, c.conn.flags = typ, syscall.NLM_F_REQUEST|syscall.NLM_F_ACK|flags
|
||||||
return add(c.conn, msg)
|
return add(c.conn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,5 +138,5 @@ func (c *RouteConn) SendIfInfomsg(
|
|||||||
if !c.writeIfInfomsg(typ, flags, msg) {
|
if !c.writeIfInfomsg(typ, flags, msg) {
|
||||||
return syscall.ENOMEM
|
return syscall.ENOMEM
|
||||||
}
|
}
|
||||||
return c.Roundtrip(ctx, c.conn.rtnlConsume)
|
return c.conn.Roundtrip(ctx, c.rtnlConsume)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ func TestPayloadRTNETLINK(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
checkPayload(t, []payloadTestCase{
|
checkPayload(t, []payloadTestCase{
|
||||||
{"RTM_NEWADDR lo", func(c *conn) {
|
{"RTM_NEWADDR lo", func(c *Conn) {
|
||||||
(&RouteConn{c}).writeNewaddrLo(1)
|
(&RouteConn{c}).writeNewaddrLo(1)
|
||||||
}, []byte{
|
}, []byte{
|
||||||
/* Len */ 0x28, 0, 0, 0,
|
/* Len */ 0x28, 0, 0, 0,
|
||||||
@@ -33,7 +33,7 @@ func TestPayloadRTNETLINK(t *testing.T) {
|
|||||||
/* in_addr */ 127, 0, 0, 1,
|
/* in_addr */ 127, 0, 0, 1,
|
||||||
}},
|
}},
|
||||||
|
|
||||||
{"RTM_NEWLINK", func(c *conn) {
|
{"RTM_NEWLINK", func(c *Conn) {
|
||||||
c.seq++
|
c.seq++
|
||||||
(&RouteConn{c}).writeIfInfomsg(
|
(&RouteConn{c}).writeIfInfomsg(
|
||||||
syscall.RTM_NEWLINK, 0,
|
syscall.RTM_NEWLINK, 0,
|
||||||
|
|||||||
Reference in New Issue
Block a user