diff --git a/internal/app/hsu.go b/internal/app/hsu.go index e53d9f3..cae56c8 100644 --- a/internal/app/hsu.go +++ b/internal/app/hsu.go @@ -62,8 +62,8 @@ func (h *Hsu) ID() (int, error) { } else if errors.As(h.idErr, &exitError) && exitError != nil && exitError.ExitCode() == 1 { // hsu prints an error message in this case h.idErr = &hst.AppError{Step: step, Err: ErrHsuAccess} - } else if os.IsNotExist(h.idErr) { - h.idErr = &hst.AppError{Step: step, Err: os.ErrNotExist, + } else if errors.Is(h.idErr, os.ErrNotExist) { + h.idErr = &hst.AppError{Step: step, Err: h.idErr, Msg: fmt.Sprintf("the setuid helper is missing: %s", hsuPath)} } }) @@ -84,16 +84,16 @@ func (h *Hsu) MustID(msg message.Msg) int { msg.Verbose("*"+fallback, err) } os.Exit(1) - return -0xdeadbeef + return -0xdeadbeef // not reached } else if m, ok := message.GetMessage(err); ok { log.Fatal(m) - return -0xdeadbeef + return -0xdeadbeef // not reached } else { log.Fatalln(fallback, err) - return -0xdeadbeef + return -0xdeadbeef // not reached } } // HsuUid returns target uid for the stable hsu uid format. -// No bounds check is performed, a value retrieved from hsu is expected. +// No bounds check is performed, a value retrieved by [Hsu] is expected. func HsuUid(id, identity int) int { return 1000000 + id*10000 + identity } diff --git a/internal/app/hsu_test.go b/internal/app/hsu_test.go new file mode 100644 index 0000000..fad724d --- /dev/null +++ b/internal/app/hsu_test.go @@ -0,0 +1,84 @@ +package app + +import ( + "os" + "os/exec" + "reflect" + "strconv" + "syscall" + "testing" + "unsafe" + + "hakurei.app/container/stub" + "hakurei.app/hst" +) + +func TestHsu(t *testing.T) { + t.Parallel() + + t.Run("ensure dispatcher", func(t *testing.T) { + hsu := new(Hsu) + hsu.ensureDispatcher() + + k := direct{} + if !reflect.DeepEqual(hsu.k, k) { + t.Errorf("ensureDispatcher: k = %#v, want %#v", hsu.k, k) + } + }) + + fCheckID := func(k *kstub) error { + hsu := &Hsu{k: k} + id, err := hsu.ID() + k.Verbose(id) + if id0, err0 := hsu.ID(); id0 != id || !reflect.DeepEqual(err0, err) { + t.Fatalf("ID: id0 = %d, err0 = %#v, id = %d, err = %#v", id0, err0, id, err) + } + return err + } + + checkSimple(t, "Hsu.ID", []simpleTestCase{ + {"hsu nonexistent", fCheckID, stub.Expect{Calls: []stub.Call{ + call("mustHsuPath", stub.ExpectArgs{}, m("/run/wrappers/bin/hsu"), nil), + call("cmdOutput", stub.ExpectArgs{"/run/wrappers/bin/hsu", os.Stderr, []string{}, "/"}, ([]byte)(nil), os.ErrNotExist), + call("verbose", stub.ExpectArgs{[]any{-1}}, nil, nil), + }}, &hst.AppError{ + Step: "obtain uid from hsu", + Err: os.ErrNotExist, + Msg: "the setuid helper is missing: /run/wrappers/bin/hsu", + }}, + + {"access", fCheckID, stub.Expect{Calls: []stub.Call{ + call("mustHsuPath", stub.ExpectArgs{}, m("/run/wrappers/bin/hsu"), nil), + call("cmdOutput", stub.ExpectArgs{"/run/wrappers/bin/hsu", os.Stderr, []string{}, "/"}, ([]byte)(nil), makeExitError(1<<8)), + call("verbose", stub.ExpectArgs{[]any{-1}}, nil, nil), + }}, &hst.AppError{ + Step: "obtain uid from hsu", + Err: ErrHsuAccess, + }}, + + {"invalid output", fCheckID, stub.Expect{Calls: []stub.Call{ + call("mustHsuPath", stub.ExpectArgs{}, m("/run/wrappers/bin/hsu"), nil), + call("cmdOutput", stub.ExpectArgs{"/run/wrappers/bin/hsu", os.Stderr, []string{}, "/"}, []byte{0}, nil), + call("verbose", stub.ExpectArgs{[]any{0}}, nil, nil), + }}, &hst.AppError{ + Step: "obtain uid from hsu", + Err: &strconv.NumError{Func: "Atoi", Num: "\x00", Err: strconv.ErrSyntax}, + Msg: "invalid uid string from hsu", + }}, + + {"success", fCheckID, stub.Expect{Calls: []stub.Call{ + call("mustHsuPath", stub.ExpectArgs{}, m("/run/wrappers/bin/hsu"), nil), + call("cmdOutput", stub.ExpectArgs{"/run/wrappers/bin/hsu", os.Stderr, []string{}, "/"}, []byte{'0'}, nil), + call("verbose", stub.ExpectArgs{[]any{0}}, nil, nil), + }}, nil}, + }) +} + +// makeExitError populates syscall.WaitStatus in an [exec.ExitError]. +// Do not reuse this function in a cross-platform package. +func makeExitError(status syscall.WaitStatus) error { + ps := new(os.ProcessState) + statusV := reflect.ValueOf(ps).Elem().FieldByName("status") + *reflect.NewAt(statusV.Type(), unsafe.Pointer(statusV.UnsafeAddr())).Interface().(*syscall.WaitStatus) = status + return &exec.ExitError{ProcessState: ps} +}