cmd/earlyinit: mount system device
Some checks failed
Test / ShareFS (push) Failing after 49s
Test / Create distribution (push) Failing after 51s
Test / Hakurei (race detector) (push) Failing after 59s
Test / Sandbox (push) Failing after 1m0s
Test / Hakurei (push) Failing after 1m11s
Test / Sandbox (race detector) (push) Failing after 1m7s
Test / Flake checks (push) Has been skipped

This currently relies on a trusted bootloader to determine the boot device.

Signed-off-by: Ophestra <cat@gensokyo.uk>
This commit is contained in:
2026-05-26 19:16:32 +09:00
parent caf3e0db4b
commit 3c4ae383ab
4 changed files with 337 additions and 20 deletions

View File

@@ -5,17 +5,89 @@
package main package main
import ( import (
"context"
"crypto/rand"
"log" "log"
"os" "os"
"os/exec"
"os/signal"
"runtime" "runtime"
"runtime/pprof"
"slices"
"strings" "strings"
. "syscall" . "syscall"
"hakurei.app/internal/kobject"
"hakurei.app/internal/report"
"hakurei.app/internal/uevent"
)
var r report.Reporter
func init() {
log.SetFlags(0)
log.SetPrefix("earlyinit: ")
r.SetOutput(log.Default())
// this handles SIGQUIT to provide useful debugging information without
// terminating, and prevents the runtime from throwing on the must family
// of early error reporting functions, DO NOT REMOVE
c := make(chan os.Signal, 1)
signal.Notify(c, SIGQUIT)
go func() {
for {
<-c
if p := pprof.Lookup("goroutine"); p == nil {
log.Println("initial built-in goroutine profile does not exist")
} else if err := p.WriteTo(os.Stderr, 2); err != nil {
log.Println(err)
}
}
}()
}
// fatal calls [log.Println] with v and blocks forever. Must be called from
// main. Must not be used after error reporting is set up.
func fatal(v ...any) {
log.Println(v...)
log.Println("unable to continue, please reboot and resolve the problem manually")
select {}
}
// must calls fatal with err if it is non-nil.
func must(err error) {
if err != nil {
log.Println(err)
select {}
}
}
// mustSyscall is like must, but with an additional action name.
func mustSyscall(action string, err error) {
if err != nil {
fatal("cannot "+action+":", err)
select {}
}
}
// must1 is like must, but with an additional passed through value.
func must1[T any](v T, err error) T {
must(err)
return v
}
const (
// optionSystem specifies devpath of the system device.
optionSystem = "system"
// flagStrict sets [report.DStrict] on r.
flagStrict = "strict"
// flagNoRecover sets [report.DNoRecover] on r.
flagNoRecover = "no_recover"
) )
func main() { func main() {
runtime.LockOSThread() runtime.LockOSThread()
log.SetFlags(0)
log.SetPrefix("earlyinit: ")
var ( var (
option map[string]string option map[string]string
@@ -33,15 +105,33 @@ func main() {
} }
} }
if err := Mount( {
var flag uint64
if slices.Contains(flags, flagStrict) {
flag |= report.DStrict
}
if slices.Contains(flags, flagNoRecover) {
flag |= report.DNoRecover
}
log.Printf("reporting flags %x", flag)
r.SetFlags(flag)
}
mustSyscall("mount devtmpfs", Mount(
"devtmpfs", "devtmpfs",
"/dev/", "/dev/",
"devtmpfs", "devtmpfs",
MS_NOSUID|MS_NOEXEC, MS_NOSUID|MS_NOEXEC,
"", "",
); err != nil { ))
log.Fatalf("cannot mount devtmpfs: %v", err) must(os.Mkdir("/dev/pts/", 0))
} mustSyscall("mount devpts", Mount(
"devpts",
"/dev/pts/",
"devpts",
MS_NOSUID|MS_NOEXEC,
"mode=620,ptmxmode=666",
))
// The kernel might be unable to set up the console. When that happens, // The kernel might be unable to set up the console. When that happens,
// printk is called with "Warning: unable to open an initial console." // printk is called with "Warning: unable to open an initial console."
@@ -98,6 +188,40 @@ func main() {
"", "",
)) ))
conn := must1(uevent.Dial(-128 * 1024 * 1024))
events := make(chan *uevent.Message, 1<<10)
var uuid uevent.UUID
must1(rand.Read(uuid[:]))
ctx, cancel := context.WithCancel(context.Background())
go consume(ctx, &r, conn, uuid, events)
s := kobject.New(uuid, func(o *kobject.Object, env map[string]string) {
log.Printf("change %s: %q", o.DevPath, env)
}, func(err error) {
r.Dispatch(
report.Inconsistent,
"processed inconsistent uevent",
err,
)
})
go func() {
s.Consume(ctx, events)
log.Println("closing NETLINK_KOBJECT_UEVENT socket")
cancel()
if err := conn.Close(); err != nil {
log.Fatal(err) // not reached
}
}()
must(os.Mkdir("/system", 0))
if devpath := option[optionSystem]; devpath == "" {
fatal("system must be nonempty")
} else {
log.Printf("waiting for devpath pattern %q", devpath)
mustMountSystem(ctx, s, devpath)
}
// after top level has been set up // after top level has been set up
mustSyscall("remount root", Mount( mustSyscall("remount root", Mount(
"", "",
@@ -115,17 +239,3 @@ func main() {
)) ))
} }
// mustSyscall calls [log.Fatalln] if err is non-nil.
func mustSyscall(action string, err error) {
if err != nil {
log.Fatalln("cannot "+action+":", err)
}
}
// must calls [log.Fatal] with err if it is non-nil.
func must(err error) {
if err != nil {
log.Fatal(err)
}
}

69
cmd/earlyinit/mount.go Normal file
View File

@@ -0,0 +1,69 @@
package main
import (
"context"
"errors"
"os"
"path/filepath"
"strconv"
"syscall"
"time"
"hakurei.app/check"
"hakurei.app/fhs"
"hakurei.app/internal/kobject"
)
// mustMountSystem waits for and mounts a system device matching pattern.
func mustMountSystem(
ctx context.Context,
s *kobject.State,
pattern string,
) {
c, stop := context.WithTimeout(ctx, 30*time.Second)
defer stop()
for {
var matchErr error
var systemPath *check.Absolute
s.Range(c, func(o *kobject.Object) bool {
if o.Subsystem != "block" ||
o.Env["DEVTYPE"] != "disk" {
return true
}
if ok, err := filepath.Match(pattern, o.DevPath); err != nil {
matchErr = err
return false
} else if !ok {
return true
}
name, ok := o.Env["DEVNAME"]
if !ok {
return true
}
systemPath = fhs.AbsDev.Append(name)
return false
})
if c.Err() != nil {
fatal("devpath", strconv.Quote(pattern), "never appeared")
}
if matchErr != nil {
fatal("cannot match system devpath:", matchErr)
}
err := syscall.Mount(
systemPath.String(),
"/system/",
"squashfs",
0,
"threads=multi",
)
if err == nil {
break
}
if !errors.Is(err, os.ErrNotExist) {
fatal("cannot mount system:", err)
}
}
}

103
cmd/earlyinit/uevent.go Normal file
View File

@@ -0,0 +1,103 @@
package main
import (
"context"
"log"
"time"
"hakurei.app/fhs"
"hakurei.app/internal/report"
"hakurei.app/internal/uevent"
)
// newRejectColdboot returns a function to be called on every subsequent pending
// coldboot, and returns whether coldboot should proceed. Rejection is sticky.
func newRejectColdboot() func() bool {
// one coldboot per five minutes, two consecutive coldboot
const (
coldbootInterval = 5 * time.Minute
coldbootBurst = 2
)
done := make(chan struct{})
s := make(chan struct{}, coldbootBurst)
s <- struct{}{} // for early fault before reporting is ready
go func() {
t := time.NewTicker(coldbootInterval)
for {
select {
case <-done:
return
case <-t.C:
select {
case s <- struct{}{}:
default:
}
}
}
}()
return func() bool {
select {
case <-s:
return true
case <-done:
return false
default:
close(done)
return false
}
}
}
// consume continuously consumes events from conn with retries.
func consume(
ctx context.Context,
r *report.Reporter,
conn *uevent.Conn,
uuid uevent.UUID,
events chan<- *uevent.Message,
) {
defer close(events)
nextColdboot := newRejectColdboot()
coldboot := true
retry:
if dispatchErr := conn.Consume(ctx, fhs.Sys, &uuid, events, coldboot, func(path string) {
log.Println("coldboot visited", path)
}, func(err error) bool {
if _, ok := err.(uevent.NeedsColdboot); ok && !nextColdboot() {
r.Dispatch(
report.Degraded,
"rejecting coldboot loop",
err,
)
return false
}
r.Dispatch(
report.Inconsistent,
"consumed invalid message",
err,
)
return true
}, nil); dispatchErr != nil {
if _, ok := dispatchErr.(uevent.Recoverable); !ok {
r.Dispatch(
report.Fatal,
"discontinuing uevent processing due to nonrecoverable error",
dispatchErr,
)
return
}
if _, ok := dispatchErr.(uevent.NeedsColdboot); ok {
// coldboot loop rejected by handler
coldboot = false
}
goto retry
}
}

View File

@@ -0,0 +1,35 @@
package main
import (
"testing"
"testing/synctest"
"time"
)
func TestRejectColdboot(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
nextColdboot := newRejectColdboot()
want := func(want bool) {
if got := nextColdboot(); got != want {
t.Fatalf("nextColdboot: %v, want %v", got, want)
}
}
synctest.Wait()
want(true)
time.Sleep(time.Hour)
synctest.Wait()
want(true)
want(true)
time.Sleep(5 * time.Minute)
synctest.Wait()
want(true)
want(false)
time.Sleep(time.Hour)
synctest.Wait()
want(false)
want(false)
})
}