diff --git a/Benchmarks_test.go b/Benchmarks_test.go index efcb4ff..73995b6 100644 --- a/Benchmarks_test.go +++ b/Benchmarks_test.go @@ -4,6 +4,7 @@ import ( "net/http/httptest" "testing" + "git.akyoto.dev/go/router/testdata" "git.akyoto.dev/go/server" ) @@ -26,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) { response := &NullResponse{} s := server.New() - for _, route := range loadRoutes("testdata/github.txt") { + for _, route := range testdata.Routes("testdata/github.txt") { s.Router.Add(route.Method, route.Path, func(server.Context) error { return nil }) diff --git a/Context.go b/Context.go index 22a69b6..1525c78 100644 --- a/Context.go +++ b/Context.go @@ -15,28 +15,38 @@ type Context interface { Bytes([]byte) error Error(messages ...any) error Get(param string) string + Header(key string, value string) + Host() string + Method() string + Next() error + Path() string + Protocol() string Reader(io.Reader) error - Request() Request - Response() Response + RequestHeader(key string) string + ResponseHeader(key string) string Status(status int) Context String(string) error } // ctx represents a request & response context. type ctx struct { - request request - response response - paramNames [maxParams]string - paramValues [maxParams]string - paramCount int + request *http.Request + response http.ResponseWriter + server *Server + paramNames [maxParams]string + paramValues [maxParams]string + paramCount int + handlerCount int } // newContext returns a new context from the pool. -func newContext(req *http.Request, res http.ResponseWriter) *ctx { +func newContext(req *http.Request, res http.ResponseWriter, server *Server) *ctx { ctx := contextPool.Get().(*ctx) - ctx.request.Request = req - ctx.response.ResponseWriter = res + ctx.request = req + ctx.response = res + ctx.server = server ctx.paramCount = 0 + ctx.handlerCount = 0 return ctx } @@ -73,22 +83,53 @@ func (ctx *ctx) Get(param string) string { return "" } +// Next executes the next handler in the middleware chain. +func (ctx *ctx) Next() error { + ctx.handlerCount++ + return ctx.server.handlers[ctx.handlerCount](ctx) +} + +// RequestHeader returns the request header value for the given key. +func (ctx *ctx) RequestHeader(key string) string { + return ctx.request.Header.Get(key) +} + +// ResponseHeader returns the response header value for the given key. +func (ctx *ctx) ResponseHeader(key string) string { + return ctx.response.Header().Get(key) +} + +// Header sets the header value for the given key. +func (ctx *ctx) Header(key string, value string) { + ctx.response.Header().Set(key, value) +} + +// Method returns the request method. +func (ctx *ctx) Method() string { + return ctx.request.Method +} + +// Protocol returns the request protocol. +func (ctx *ctx) Protocol() string { + return ctx.request.Proto +} + +// Host returns the requested host. +func (ctx *ctx) Host() string { + return ctx.request.Host +} + +// Path returns the requested path. +func (ctx *ctx) Path() string { + return ctx.request.URL.Path +} + // Reader sends the contents of the io.Reader without creating an in-memory copy. func (ctx *ctx) Reader(reader io.Reader) error { _, err := io.Copy(ctx.response, reader) return err } -// Request returns the HTTP request. -func (ctx *ctx) Request() Request { - return &ctx.request -} - -// Response returns the HTTP response. -func (ctx *ctx) Response() Response { - return &ctx.response -} - // Status sets the HTTP status of the response. func (ctx *ctx) Status(status int) Context { ctx.response.WriteHeader(status) diff --git a/README.md b/README.md index ab9af3b..1af75f5 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ s.Run(":8080") ``` PASS: TestRouter +PASS: TestMiddleware PASS: TestPanic PASS: TestRun PASS: TestUnavailablePort @@ -49,8 +50,8 @@ coverage: 100.0% of statements ## Benchmarks ``` -BenchmarkHello-12 33502272 33.64 ns/op 0 B/op 0 allocs/op -BenchmarkGitHub-12 17698947 65.50 ns/op 0 B/op 0 allocs/op +BenchmarkHello-12 31128832 38.41 ns/op 0 B/op 0 allocs/op +BenchmarkGitHub-12 17406580 68.56 ns/op 0 B/op 0 allocs/op ``` ## License diff --git a/Request.go b/Request.go deleted file mode 100644 index 85a9a50..0000000 --- a/Request.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -import "net/http" - -// Request is an interface for HTTP requests. -type Request interface { - Header(key string) string - Host() string - Method() string - Path() string - Protocol() string -} - -// request represents the HTTP request used in the given context. -type request struct { - *http.Request -} - -// Header returns the header value for the given key. -func (req request) Header(key string) string { - return req.Request.Header.Get(key) -} - -// Method returns the request method. -func (req request) Method() string { - return req.Request.Method -} - -// Protocol returns the request protocol. -func (req request) Protocol() string { - return req.Request.Proto -} - -// Host returns the requested host. -func (req request) Host() string { - return req.Request.Host -} - -// Path returns the requested path. -func (req request) Path() string { - return req.Request.URL.Path -} diff --git a/Response.go b/Response.go deleted file mode 100644 index db45bed..0000000 --- a/Response.go +++ /dev/null @@ -1,24 +0,0 @@ -package server - -import "net/http" - -// Response is the interface for an HTTP response. -type Response interface { - Header(key string) string - SetHeader(key string, value string) -} - -// response represents the HTTP response used in the given context. -type response struct { - http.ResponseWriter -} - -// Header returns the header value for the given key. -func (res response) Header(key string) string { - return res.ResponseWriter.Header().Get(key) -} - -// SetHeader sets the header value for the given key. -func (res response) SetHeader(key string, value string) { - res.ResponseWriter.Header().Set(key, value) -} diff --git a/Route_test.go b/Route_test.go deleted file mode 100644 index 8ca46de..0000000 --- a/Route_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package server_test - -import ( - "bufio" - "os" - "strings" -) - -// Route represents a single line in the test data. -type Route struct { - Method string - Path string -} - -// loadRoutes loads all routes from a text file. -func loadRoutes(filePath string) []Route { - var routes []Route - f, err := os.Open(filePath) - - if err != nil { - panic(err) - } - - defer f.Close() - reader := bufio.NewReader(f) - - for { - line, err := reader.ReadString('\n') - - if line != "" { - line = strings.TrimSpace(line) - parts := strings.Split(line, " ") - - routes = append(routes, Route{ - Method: parts[0], - Path: parts[1], - }) - } - - if err != nil { - break - } - } - - return routes -} diff --git a/Server.go b/Server.go index d4ffd8e..1042ee4 100644 --- a/Server.go +++ b/Server.go @@ -15,8 +15,9 @@ import ( // Server represents a single web service. type Server struct { - Router *router.Router[Handler] - Config Configuration + Router *router.Router[Handler] + Config Configuration + handlers []Handler } // New creates a new server. @@ -24,6 +25,17 @@ func New() *Server { return &Server{ Router: router.New[Handler](), Config: defaultConfig(), + handlers: []Handler{ + func(c Context) error { + handler := c.(*ctx).server.Router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter) + + if handler == nil { + return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound)) + } + + return handler(c) + }, + }, } } @@ -49,18 +61,10 @@ func (server *Server) Put(path string, handler Handler) { // ServeHTTP responds to the given request. func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) { - ctx := newContext(request, response) + ctx := newContext(request, response, server) defer contextPool.Put(ctx) - handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter) - - if handler == nil { - response.WriteHeader(http.StatusNotFound) - response.(io.StringWriter).WriteString(http.StatusText(http.StatusNotFound)) - return - } - - err := handler(ctx) + err := server.handlers[0](ctx) if err != nil { response.(io.StringWriter).WriteString(err.Error()) @@ -96,3 +100,10 @@ func (server *Server) Run(address string) error { return srv.Shutdown(ctx) } + +// Use adds handlers to your handlers chain. +func (server *Server) Use(handlers ...Handler) { + last := server.handlers[len(server.handlers)-1] + server.handlers = append(server.handlers[:len(server.handlers)-1], handlers...) + server.handlers = append(server.handlers, last) +} diff --git a/Server_test.go b/Server_test.go index e88b791..bb87387 100644 --- a/Server_test.go +++ b/Server_test.go @@ -39,24 +39,21 @@ func TestRouter(t *testing.T) { }) s.Get("/request/data", func(ctx server.Context) error { - request := ctx.Request() - method := request.Method() - protocol := request.Protocol() - host := request.Host() - path := request.Path() + method := ctx.Method() + protocol := ctx.Protocol() + host := ctx.Host() + path := ctx.Path() return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path)) }) s.Get("/request/header", func(ctx server.Context) error { - request := ctx.Request() - acceptEncoding := request.Header("Accept-Encoding") + acceptEncoding := ctx.RequestHeader("Accept-Encoding") return ctx.String(acceptEncoding) }) s.Get("/response/header", func(ctx server.Context) error { - response := ctx.Response() - response.SetHeader("Content-Type", "text/plain") - contentType := response.Header("Content-Type") + ctx.Header("Content-Type", "text/plain") + contentType := ctx.ResponseHeader("Content-Type") return ctx.String(contentType) }) @@ -120,6 +117,21 @@ func TestRouter(t *testing.T) { } } +func TestMiddleware(t *testing.T) { + s := server.New() + + s.Use(func(ctx server.Context) error { + ctx.Header("Middleware", "true") + return ctx.Next() + }) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + response := httptest.NewRecorder() + s.ServeHTTP(response, request) + + assert.Equal(t, response.Header().Get("Middleware"), "true") +} + func TestPanic(t *testing.T) { s := server.New() diff --git a/examples/logger/main.go b/examples/logger/main.go new file mode 100644 index 0000000..c1f6bed --- /dev/null +++ b/examples/logger/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "time" + + "git.akyoto.dev/go/server" +) + +func main() { + s := server.New() + + s.Use(func(ctx server.Context) error { + start := time.Now() + + defer func() { + fmt.Println(ctx.Path(), time.Since(start)) + }() + + return ctx.Next() + }) + + s.Run(":8080") +}