diff --git a/Context.go b/Context.go index bd6cecf..ab68f05 100644 --- a/Context.go +++ b/Context.go @@ -1,6 +1,7 @@ package server import ( + "errors" "io" "net/http" "unsafe" @@ -12,8 +13,9 @@ const maxParams = 16 // Context represents the interface for a request & response context. type Context interface { Bytes([]byte) error - Error(int, error) error + Error(status int, messages ...any) error Reader(io.Reader) error + SetStatus(status int) String(string) error } @@ -42,9 +44,20 @@ func (ctx *context) Bytes(body []byte) error { } // Error is used for sending error messages to the client. -func (ctx *context) Error(status int, err error) error { - ctx.response.WriteHeader(status) - return err +func (ctx *context) Error(status int, messages ...any) error { + var combined []error + + for _, msg := range messages { + switch err := msg.(type) { + case error: + combined = append(combined, err) + case string: + combined = append(combined, errors.New(err)) + } + } + + ctx.SetStatus(status) + return errors.Join(combined...) } // Reader sends the contents of the io.Reader without creating an in-memory copy. @@ -53,6 +66,11 @@ func (ctx *context) Reader(reader io.Reader) error { return err } +// SetStatus writes the header with the given HTTP status code. +func (ctx *context) SetStatus(status int) { + ctx.response.WriteHeader(status) +} + // String responds with the given string. func (ctx *context) String(body string) error { slice := unsafe.Slice(unsafe.StringData(body), len(body)) diff --git a/Server.go b/Server.go index 51eaa78..13544a7 100644 --- a/Server.go +++ b/Server.go @@ -1,7 +1,8 @@ package server import ( - "fmt" + "io" + "log" "net/http" "git.akyoto.dev/go/router" @@ -34,7 +35,7 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ if handler == nil { response.WriteHeader(http.StatusNotFound) - fmt.Fprint(response, http.StatusText(http.StatusNotFound)) + response.(io.StringWriter).WriteString(http.StatusText(http.StatusNotFound)) contextPool.Put(ctx) return } @@ -42,7 +43,8 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ err := handler(ctx) if err != nil { - fmt.Fprint(response, err.Error()) + response.(io.StringWriter).WriteString(err.Error()) + log.Println(request.URL, err) } contextPool.Put(ctx) diff --git a/Server_test.go b/Server_test.go index 880d87a..e274bfd 100644 --- a/Server_test.go +++ b/Server_test.go @@ -24,7 +24,11 @@ func TestRouter(t *testing.T) { }) s.Get("/error", func(ctx server.Context) error { - return ctx.Error(http.StatusUnauthorized, errors.New("Not logged in")) + return ctx.Error(http.StatusUnauthorized, "Not logged in") + }) + + s.Get("/error2", func(ctx server.Context) error { + return ctx.Error(http.StatusUnauthorized, "Not logged in", errors.New("Missing auth token")) }) s.Get("/reader", func(ctx server.Context) error { @@ -43,6 +47,7 @@ func TestRouter(t *testing.T) { {URL: "/", Status: http.StatusOK, Body: "Hello"}, {URL: "/blog/post", Status: http.StatusOK, Body: "Hello"}, {URL: "/error", Status: http.StatusUnauthorized, Body: "Not logged in"}, + {URL: "/error2", Status: http.StatusUnauthorized, Body: "Not logged in\nMissing auth token"}, {URL: "/not-found", Status: http.StatusNotFound, Body: http.StatusText(http.StatusNotFound)}, {URL: "/reader", Status: http.StatusOK, Body: "Hello"}, {URL: "/string", Status: http.StatusOK, Body: "Hello"},