diff --git a/Context.go b/Context.go index 946b408..6da6e30 100644 --- a/Context.go +++ b/Context.go @@ -4,44 +4,42 @@ import ( "errors" "io" "net/http" - "unsafe" "git.akyoto.dev/go/router" ) // Context represents the interface for a request & response context. type Context interface { + Copy(io.Reader) error Bytes([]byte) error Error(messages ...any) error + File(path string) error Get(param string) string - Header(key string, value string) - Host() string - Method() string Next() error - Path() string - Protocol() string - Reader(io.Reader) error - RequestHeader(key string) string - ResponseHeader(key string) string - Scheme() string + Request() Request + Response() Response Status(status int) Context String(string) error - Write([]byte) (int, error) - WriteString(string) (int, error) } // ctx represents a request & response context. type ctx struct { - request *http.Request - response http.ResponseWriter + request request + response response server *server params []router.Parameter handlerCount uint8 } // Bytes responds with a raw byte slice. -func (c *ctx) Bytes(body []byte) error { - _, err := c.response.Write(body) +func (ctx *ctx) Bytes(body []byte) error { + _, err := ctx.response.Write(body) + return err +} + +// Copy sends the contents of the io.Reader without creating an in-memory copy. +func (ctx *ctx) Copy(reader io.Reader) error { + _, err := io.Copy(ctx.response.ResponseWriter, reader) return err } @@ -74,56 +72,26 @@ func (ctx *ctx) Get(param string) string { return "" } +// File serves the file at the given path. +func (ctx *ctx) File(path string) error { + http.ServeFile(ctx.response.ResponseWriter, ctx.request.Request, path) + return nil +} + // Next executes the next handler in the middleware chain. func (ctx *ctx) Next() error { ctx.handlerCount++ return ctx.server.handlers[ctx.handlerCount](ctx) } -// RequestHeader returns the request header value for the given key. -func (ctx *ctx) RequestHeader(key string) string { - return ctx.request.Header.Get(key) +// Request returns the HTTP request. +func (ctx *ctx) Request() Request { + return &ctx.request } -// ResponseHeader returns the response header value for the given key. -func (ctx *ctx) ResponseHeader(key string) string { - return ctx.response.Header().Get(key) -} - -// Header sets the header value for the given key. -func (ctx *ctx) Header(key string, value string) { - ctx.response.Header().Set(key, value) -} - -// Method returns the request method. -func (ctx *ctx) Method() string { - return ctx.request.Method -} - -// Protocol returns the request protocol. -func (ctx *ctx) Protocol() string { - return ctx.request.Proto -} - -// Host returns the requested host. -func (ctx *ctx) Host() string { - return ctx.request.Host -} - -// Path returns the requested path. -func (ctx *ctx) Path() string { - return ctx.request.URL.Path -} - -// Reader sends the contents of the io.Reader without creating an in-memory copy. -func (ctx *ctx) Reader(reader io.Reader) error { - _, err := io.Copy(ctx.response, reader) - return err -} - -// Scheme returns either `http` or `https`. -func (ctx *ctx) Scheme() string { - return ctx.request.URL.Scheme +// Response returns the HTTP response. +func (ctx *ctx) Response() Response { + return &ctx.response } // Status sets the HTTP status of the response. @@ -134,18 +102,8 @@ func (ctx *ctx) Status(status int) Context { // String responds with the given string. func (ctx *ctx) String(body string) error { - slice := unsafe.Slice(unsafe.StringData(body), len(body)) - return ctx.Bytes(slice) -} - -// Write implements the io.Writer interface. -func (ctx *ctx) Write(body []byte) (int, error) { - return ctx.response.Write(body) -} - -// WriteString implements the io.StringWriter interface. -func (ctx *ctx) WriteString(body string) (int, error) { - return ctx.response.(io.StringWriter).WriteString(body) + _, err := ctx.response.WriteString(body) + return err } // addParameter adds a new parameter to the context. diff --git a/README.md b/README.md index 9193167..5c57576 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,11 @@ coverage: 100.0% of statements ## Benchmarks ``` -BenchmarkStatic/#00-12 34907296 30.19 ns/op 0 B/op 0 allocs/op -BenchmarkStatic/hello-12 27628322 40.89 ns/op 0 B/op 0 allocs/op -BenchmarkStatic/hello/world-12 21330940 56.24 ns/op 0 B/op 0 allocs/op -BenchmarkGitHub/gists/:id-12 23608254 50.86 ns/op 0 B/op 0 allocs/op -BenchmarkGitHub/repos/:a/:b-12 18912602 59.02 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/#00-12 32963155 30.88 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/hello-12 31640433 37.92 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/hello/world-12 22497412 52.57 ns/op 0 B/op 0 allocs/op +BenchmarkGitHub/gists/:id-12 24162244 49.70 ns/op 0 B/op 0 allocs/op +BenchmarkGitHub/repos/:a/:b-12 18865028 59.22 ns/op 0 B/op 0 allocs/op ``` ## License diff --git a/Request.go b/Request.go new file mode 100644 index 0000000..ecce4f8 --- /dev/null +++ b/Request.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "net/http" +) + +// Request is an interface for HTTP requests. +type Request interface { + Context() context.Context + Header(key string) string + Host() string + Method() string + Path() string + Protocol() string + Read([]byte) (int, error) + Scheme() string +} + +// request represents the HTTP request used in the given context. +type request struct { + *http.Request +} + +// Context returns the request context. +func (req request) Context() context.Context { + return req.Request.Context() +} + +// Header returns the header value for the given key. +func (req request) Header(key string) string { + return req.Request.Header.Get(key) +} + +// Method returns the request method. +func (req request) Method() string { + return req.Request.Method +} + +// Protocol returns the request protocol. +func (req request) Protocol() string { + return req.Request.Proto +} + +// Host returns the requested host. +func (req request) Host() string { + return req.Request.Host +} + +// Path returns the requested path. +func (req request) Path() string { + return req.Request.URL.Path +} + +// // Read implements the io.Reader interface and reads the request body. +func (req request) Read(buffer []byte) (int, error) { + return req.Request.Body.Read(buffer) +} + +// Scheme returns either `http` or `https`. +func (req request) Scheme() string { + return req.Request.URL.Scheme +} diff --git a/Response.go b/Response.go new file mode 100644 index 0000000..49bbd33 --- /dev/null +++ b/Response.go @@ -0,0 +1,39 @@ +package server + +import ( + "io" + "net/http" +) + +// Response is the interface for an HTTP response. +type Response interface { + Header(key string) string + SetHeader(key string, value string) + Write([]byte) (int, error) + WriteString(string) (int, error) +} + +// response represents the HTTP response used in the given context. +type response struct { + http.ResponseWriter +} + +// Header returns the header value for the given key. +func (res response) Header(key string) string { + return res.ResponseWriter.Header().Get(key) +} + +// SetHeader sets the header value for the given key. +func (res response) SetHeader(key string, value string) { + res.ResponseWriter.Header().Set(key, value) +} + +// Write implements the io.Writer interface. +func (res response) Write(body []byte) (int, error) { + return res.ResponseWriter.Write(body) +} + +// WriteString implements the io.StringWriter interface. +func (res response) WriteString(body string) (int, error) { + return res.ResponseWriter.(io.StringWriter).WriteString(body) +} diff --git a/Server.go b/Server.go index a285f68..8f8949a 100644 --- a/Server.go +++ b/Server.go @@ -42,8 +42,8 @@ func New() Server { handlers: []Handler{ func(c Context) error { ctx := c.(*ctx) - method := ctx.Method() - path := ctx.Path() + method := ctx.request.Method() + path := ctx.request.Path() handler := ctx.server.router.LookupNoAlloc(method, path, ctx.addParameter) if handler == nil { @@ -54,8 +54,8 @@ func New() Server { }, }, errorHandler: func(ctx Context, err error) { - ctx.WriteString(err.Error()) - log.Println(ctx.Path(), err) + ctx.Response().WriteString(err.Error()) + log.Println(ctx.Request().Path(), err) }, } @@ -90,10 +90,10 @@ func (s *server) Put(path string, handler Handler) { } // ServeHTTP responds to the given request. -func (s *server) ServeHTTP(response http.ResponseWriter, request *http.Request) { +func (s *server) ServeHTTP(res http.ResponseWriter, req *http.Request) { ctx := s.pool.Get().(*ctx) - ctx.request = request - ctx.response = response + ctx.request = request{req} + ctx.response = response{res} err := s.handlers[0](ctx) if err != nil { diff --git a/Server_test.go b/Server_test.go index b7f6ea4..0d5b4b3 100644 --- a/Server_test.go +++ b/Server_test.go @@ -27,12 +27,12 @@ func TestRouter(t *testing.T) { }) s.Get("/write", func(ctx server.Context) error { - _, err := ctx.Write([]byte("Hello")) + _, err := ctx.Response().Write([]byte("Hello")) return err }) s.Get("/writestring", func(ctx server.Context) error { - _, err := io.WriteString(ctx, "Hello") + _, err := io.WriteString(ctx.Response(), "Hello") return err }) @@ -45,25 +45,38 @@ func TestRouter(t *testing.T) { }) s.Get("/reader", func(ctx server.Context) error { - return ctx.Reader(strings.NewReader("Hello")) + return ctx.Copy(strings.NewReader("Hello")) + }) + + s.Get("/file", func(ctx server.Context) error { + return ctx.File("testdata/file.txt") + }) + + s.Get("/echo", func(ctx server.Context) error { + return ctx.Copy(ctx.Request()) + }) + + s.Get("/context", func(ctx server.Context) error { + return ctx.Request().Context().Err() }) s.Get("/request/data", func(ctx server.Context) error { - method := ctx.Method() - protocol := ctx.Protocol() - host := ctx.Host() - path := ctx.Path() + request := ctx.Request() + method := request.Method() + protocol := request.Protocol() + host := request.Host() + path := request.Path() return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path)) }) s.Get("/request/header", func(ctx server.Context) error { - acceptEncoding := ctx.RequestHeader("Accept-Encoding") + acceptEncoding := ctx.Request().Header("Accept-Encoding") return ctx.String(acceptEncoding) }) s.Get("/response/header", func(ctx server.Context) error { - ctx.Header("Content-Type", "text/plain") - contentType := ctx.ResponseHeader("Content-Type") + ctx.Response().SetHeader("Content-Type", "text/plain") + contentType := ctx.Response().Header("Content-Type") return ctx.String(contentType) }) @@ -78,7 +91,7 @@ func TestRouter(t *testing.T) { }) s.Get("/scheme", func(ctx server.Context) error { - return ctx.String(ctx.Scheme()) + return ctx.String(ctx.Request().Scheme()) }) s.Post("/", func(ctx server.Context) error { @@ -94,33 +107,37 @@ func TestRouter(t *testing.T) { }) tests := []struct { - Method string - URL string - Status int - Body string + Method string + URL string + Body string + Status int + Response string }{ - {Method: "GET", URL: "/", Status: http.StatusOK, Body: "Hello"}, - {Method: "GET", URL: "/error", Status: http.StatusUnauthorized, Body: "Not logged in"}, - {Method: "GET", URL: "/error2", Status: http.StatusUnauthorized, Body: "Not logged in\nMissing auth token"}, - {Method: "GET", URL: "/not-found", Status: http.StatusNotFound, Body: http.StatusText(http.StatusNotFound)}, - {Method: "GET", URL: "/request/data", Status: http.StatusOK, Body: "GET HTTP/1.1 example.com /request/data"}, - {Method: "GET", URL: "/request/header", Status: http.StatusOK, Body: ""}, - {Method: "GET", URL: "/response/header", Status: http.StatusOK, Body: "text/plain"}, - {Method: "GET", URL: "/reader", Status: http.StatusOK, Body: "Hello"}, - {Method: "GET", URL: "/string", Status: http.StatusOK, Body: "Hello"}, - {Method: "GET", URL: "/scheme", Status: http.StatusOK, Body: "http"}, - {Method: "GET", URL: "/write", Status: http.StatusOK, Body: "Hello"}, - {Method: "GET", URL: "/writestring", Status: http.StatusOK, Body: "Hello"}, - {Method: "GET", URL: "/blog/testing-my-router", Status: http.StatusOK, Body: "testing-my-router"}, - {Method: "GET", URL: "/missing-parameter", Status: http.StatusOK, Body: ""}, - {Method: "POST", URL: "/", Status: http.StatusOK, Body: "Post"}, - {Method: "DELETE", URL: "/", Status: http.StatusOK, Body: "Delete"}, - {Method: "PUT", URL: "/", Status: http.StatusOK, Body: "Put"}, + {Method: "GET", URL: "/", Body: "", Status: http.StatusOK, Response: "Hello"}, + {Method: "GET", URL: "/context", Body: "", Status: http.StatusOK, Response: ""}, + {Method: "GET", URL: "/echo", Body: "Echo", Status: http.StatusOK, Response: "Echo"}, + {Method: "GET", URL: "/error", Body: "", Status: http.StatusUnauthorized, Response: "Not logged in"}, + {Method: "GET", URL: "/error2", Body: "", Status: http.StatusUnauthorized, Response: "Not logged in\nMissing auth token"}, + {Method: "GET", URL: "/file", Body: "", Status: http.StatusOK, Response: "Hello File"}, + {Method: "GET", URL: "/not-found", Body: "", Status: http.StatusNotFound, Response: http.StatusText(http.StatusNotFound)}, + {Method: "GET", URL: "/request/data", Body: "", Status: http.StatusOK, Response: "GET HTTP/1.1 example.com /request/data"}, + {Method: "GET", URL: "/request/header", Body: "", Status: http.StatusOK, Response: ""}, + {Method: "GET", URL: "/response/header", Body: "", Status: http.StatusOK, Response: "text/plain"}, + {Method: "GET", URL: "/reader", Body: "", Status: http.StatusOK, Response: "Hello"}, + {Method: "GET", URL: "/string", Body: "", Status: http.StatusOK, Response: "Hello"}, + {Method: "GET", URL: "/scheme", Body: "", Status: http.StatusOK, Response: "http"}, + {Method: "GET", URL: "/write", Body: "", Status: http.StatusOK, Response: "Hello"}, + {Method: "GET", URL: "/writestring", Body: "", Status: http.StatusOK, Response: "Hello"}, + {Method: "GET", URL: "/blog/testing-my-router", Body: "", Status: http.StatusOK, Response: "testing-my-router"}, + {Method: "GET", URL: "/missing-parameter", Body: "", Status: http.StatusOK, Response: ""}, + {Method: "POST", URL: "/", Body: "", Status: http.StatusOK, Response: "Post"}, + {Method: "DELETE", URL: "/", Body: "", Status: http.StatusOK, Response: "Delete"}, + {Method: "PUT", URL: "/", Body: "", Status: http.StatusOK, Response: "Put"}, } for _, test := range tests { t.Run("example.com"+test.URL, func(t *testing.T) { - request := httptest.NewRequest(test.Method, "http://example.com"+test.URL, nil) + request := httptest.NewRequest(test.Method, "http://example.com"+test.URL, strings.NewReader(test.Body)) response := httptest.NewRecorder() s.ServeHTTP(response, request) @@ -129,7 +146,7 @@ func TestRouter(t *testing.T) { body, err := io.ReadAll(result.Body) assert.Nil(t, err) - assert.DeepEqual(t, string(body), test.Body) + assert.Equal(t, string(body), test.Response) }) } } @@ -138,7 +155,7 @@ func TestMiddleware(t *testing.T) { s := server.New() s.Use(func(ctx server.Context) error { - ctx.Header("Middleware", "true") + ctx.Response().SetHeader("Middleware", "true") return ctx.Next() }) diff --git a/common_test.go b/common_test.go index 6b7b942..39514ca 100644 --- a/common_test.go +++ b/common_test.go @@ -6,6 +6,7 @@ import "net/http" // empty methods to better understand memory usage in benchmarks. type NullResponse struct{} -func (r *NullResponse) Header() http.Header { return nil } -func (r *NullResponse) Write([]byte) (int, error) { return 0, nil } -func (r *NullResponse) WriteHeader(int) {} +func (r *NullResponse) Header() http.Header { return nil } +func (r *NullResponse) Write([]byte) (int, error) { return 0, nil } +func (r *NullResponse) WriteString(string) (int, error) { return 0, nil } +func (r *NullResponse) WriteHeader(int) {} diff --git a/examples/logger/main.go b/examples/logger/main.go index c1f6bed..a35a4a1 100644 --- a/examples/logger/main.go +++ b/examples/logger/main.go @@ -14,7 +14,7 @@ func main() { start := time.Now() defer func() { - fmt.Println(ctx.Path(), time.Since(start)) + fmt.Println(ctx.Request().Path(), time.Since(start)) }() return ctx.Next() diff --git a/send/send.go b/send/send.go new file mode 100644 index 0000000..146c98a --- /dev/null +++ b/send/send.go @@ -0,0 +1,32 @@ +package send + +import ( + "encoding/json" + + "git.akyoto.dev/go/server" +) + +func Text(ctx server.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + return ctx.String(body) +} + +func CSS(ctx server.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/css") + return ctx.String(body) +} + +func JS(ctx server.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/javascript") + return ctx.String(body) +} + +func JSON(ctx server.Context, object any) error { + ctx.Response().SetHeader("Content-Type", "application/json") + return json.NewEncoder(ctx.Response()).Encode(object) +} + +func HTML(ctx server.Context, body string) error { + ctx.Response().SetHeader("Content-Type", "text/html") + return ctx.String(body) +} diff --git a/testdata/file.txt b/testdata/file.txt new file mode 100644 index 0000000..7d5bdbf --- /dev/null +++ b/testdata/file.txt @@ -0,0 +1 @@ +Hello File \ No newline at end of file