package uevent_test

import (
	"context"
	"os"
	"path/filepath"
	"reflect"
	"slices"
	"sync"
	"syscall"
	"testing"

	"hakurei.app/check"
	"hakurei.app/internal/pkg"
	"hakurei.app/internal/uevent"
)

// TestColdboot builds a small sysfs-like tree under a temporary directory and
// runs Coldboot over it. It expects every uevent file below devices (but not
// the one below block) to be reported on visited in lexical walk order, and
// compares a hash of the resulting tree against a known checksum.
func TestColdboot(t *testing.T) {
	t.Parallel()

	d := t.TempDir()
	// NOTE(review): presumably pinned so the tree checksum below does not
	// depend on the umask TempDir was created under.
	if err := os.Chmod(d, 0700); err != nil {
		t.Fatal(err)
	}

	for _, rel := range []string{
		"devices",
		"devices/sub",
		"devices/empty",
		"block",
	} {
		if err := os.MkdirAll(filepath.Join(d, rel), 0700); err != nil {
			t.Fatal(err)
		}
	}
	for _, f := range [][2]string{
		{"devices/uevent", ""},
		{"devices/sub/uevent", ""},
		{"block/uevent", ""},
	} {
		if err := os.WriteFile(filepath.Join(d, f[0]), []byte(f[1]), 0600); err != nil {
			t.Fatal(err)
		}
	}

	// visited is unbuffered, so paths must be drained concurrently while
	// Coldboot runs.
	visited := make(chan string)
	var got []string
	var wg sync.WaitGroup
	wg.Go(func() {
		for path := range visited {
			got = append(got, path)
		}
	})

	err := uevent.Coldboot(t.Context(), d, visited, func(err error) error {
		t.Errorf("handleWalkErr: %v", err)
		return err
	})
	// Close visited and wait for the collector before any Fatal: this stops
	// the goroutine, makes got safe to read, and avoids a deferred Wait that
	// would deadlock if Coldboot panicked before visited was closed.
	close(visited)
	wg.Wait()
	if err != nil {
		t.Fatalf("Coldboot: error = %v", err)
	}

	want := []string{
		"devices/sub/uevent",
		"devices/uevent",
	}
	for i, rel := range want {
		want[i] = filepath.Join(d, rel)
	}
	if !slices.Equal(got, want) {
		t.Errorf("Coldboot: %#v, want %#v", got, want)
	}

	var checksum pkg.Checksum
	if err = pkg.HashDir(&checksum, check.MustAbs(d)); err != nil {
		t.Fatalf("HashDir: error = %v", err)
	}
	wantChecksum := pkg.MustDecode("mEy_Lf5KotThm7OwMx7yTKZh5HCCyaB41pVAvI9uDMgVQFM91iosBLYsRm8bDsX8")
	if checksum != wantChecksum {
		t.Errorf(
			"Coldboot: checksum = %s, want %s",
			pkg.Encode(checksum), pkg.Encode(wantChecksum),
		)
	}
}

// TestColdbootError checks the error Coldboot returns in three situations:
// fixtures that make the walk fail, fixtures that make a write fail, and
// fixtures with a dangling symlink (with and without a handleWalkErr that
// suppresses its error). It also covers a context cancelled before or during
// the walk.
//
// The walk and write cases expect EACCES from permission bits alone. They
// therefore assume the test process cannot bypass permission checks (for
// example by holding CAP_DAC_OVERRIDE).
func TestColdbootError(t *testing.T) {
	t.Parallel()

	// Fixture helpers: each creates an entry relative to d and fails the
	// (sub)test it is given on error.
	mkdir := func(t *testing.T, d, rel string, perm os.FileMode) {
		t.Helper()
		if err := os.Mkdir(filepath.Join(d, rel), perm); err != nil {
			t.Fatal(err)
		}
	}
	writeFile := func(t *testing.T, d, rel string, perm os.FileMode) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(d, rel), nil, perm); err != nil {
			t.Fatal(err)
		}
	}
	symlink := func(t *testing.T, target, d, rel string) {
		t.Helper()
		if err := os.Symlink(target, filepath.Join(d, rel)); err != nil {
			t.Fatal(err)
		}
	}

	testCases := []struct {
		name string
		// dF populates the temporary directory d and returns the error
		// Coldboot is expected to return.
		dF func(t *testing.T, d string) (wantErr error)
		// vF, if non-nil, is first called synchronously with a nil visited
		// channel before Coldboot starts (a setup hook), then concurrently
		// with the channel passed to Coldboot. If vF is nil, Coldboot
		// receives a nil channel.
		vF func(ctx context.Context, visited <-chan string, cancel context.CancelFunc)
		// hF, if non-nil, is passed to Coldboot as handleWalkErr with d
		// bound; otherwise handleWalkErr is nil.
		hF func(d string, err error) error
	}{
		{"walk", func(t *testing.T, d string) error {
			// An unreadable devices directory cannot be opened for walking.
			mkdir(t, d, "devices", 0)
			return &os.PathError{
				Op:   "open",
				Path: filepath.Join(d, "devices"),
				Err:  syscall.EACCES,
			}
		}, nil, nil},

		{"write", func(t *testing.T, d string) error {
			// A uevent file with no permission bits cannot be opened.
			mkdir(t, d, "devices", 0700)
			writeFile(t, d, "devices/uevent", 0)
			return &os.PathError{
				Op:   "open",
				Path: filepath.Join(d, "devices/uevent"),
				Err:  syscall.EACCES,
			}
		}, nil, nil},

		{"deref", func(t *testing.T, d string) error {
			// Without a handler, a dangling uevent symlink is expected not
			// to surface as an error.
			mkdir(t, d, "devices", 0700)
			symlink(t, "/proc/nonexistent", d, "devices/uevent")
			return nil
		}, nil, nil},

		{"deref handle", func(t *testing.T, d string) error {
			mkdir(t, d, "devices", 0700)
			symlink(t, "/proc/nonexistent", d, "devices/uevent")
			return nil
		}, nil, func(d string, err error) error {
			// Suppress only the ENOENT from opening the dangling symlink.
			if reflect.DeepEqual(err, &os.PathError{
				Op:   "open",
				Path: filepath.Join(d, "devices/uevent"),
				Err:  syscall.ENOENT,
			}) {
				return nil
			}
			return err
		}},

		{"cancel early", func(t *testing.T, d string) error {
			mkdir(t, d, "devices", 0700)
			return context.Canceled
		}, func(_ context.Context, visited <-chan string, cancel context.CancelFunc) {
			// Setup phase: cancel before Coldboot is entered.
			if visited == nil {
				cancel()
			}
		}, nil},

		{"cancel", func(t *testing.T, d string) error {
			mkdir(t, d, "devices", 0700)
			writeFile(t, d, "devices/uevent", 0600)
			mkdir(t, d, "devices/sub", 0700)
			writeFile(t, d, "devices/sub/uevent", 0600)
			return context.Canceled
		}, func(_ context.Context, visited <-chan string, cancel context.CancelFunc) {
			if visited == nil {
				return
			}
			// Cancel after the first path is delivered; nothing receives the
			// second, so Coldboot has to observe the cancellation.
			<-visited
			cancel()
		}, nil},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			d := t.TempDir()
			wantErr := tc.dF(t, d)

			var wg sync.WaitGroup
			defer wg.Wait()
			ctx, cancel := context.WithCancel(t.Context())
			defer cancel()

			var visited chan string
			if tc.vF != nil {
				tc.vF(ctx, nil, cancel)
				visited = make(chan string)
				// Deferred after wg.Wait, so it runs first and unblocks a
				// consumer still waiting on visited.
				defer close(visited)
				wg.Go(func() { tc.vF(ctx, visited, cancel) })
			}

			var handleWalkErr func(error) error
			if tc.hF != nil {
				handleWalkErr = func(err error) error { return tc.hF(d, err) }
			}

			if err := uevent.Coldboot(ctx, d, visited, handleWalkErr); !reflect.DeepEqual(err, wantErr) {
				t.Errorf("Coldboot: error = %v, want %v", err, wantErr)
			}
		})
	}
}