diff --git a/Benchmarks_test.go b/Benchmarks_test.go index 73995b6..46dc52b 100644 --- a/Benchmarks_test.go +++ b/Benchmarks_test.go @@ -2,38 +2,62 @@ package server_test import ( "net/http/httptest" + "strings" "testing" "git.akyoto.dev/go/router/testdata" "git.akyoto.dev/go/server" ) -func BenchmarkHello(b *testing.B) { - request := httptest.NewRequest("GET", "/", nil) - response := &NullResponse{} +func BenchmarkStatic(b *testing.B) { + paths := []string{ + "/", + "/hello", + "/hello/world", + } + s := server.New() - s.Get("/", func(ctx server.Context) error { - return ctx.String("Hello") - }) + for _, path := range paths { + s.Get(path, func(ctx server.Context) error { + return ctx.String("Hello") + }) + } - for range b.N { - s.ServeHTTP(response, request) + for _, path := range paths { + b.Run(strings.TrimPrefix(path, "/"), func(b *testing.B) { + request := httptest.NewRequest("GET", path, nil) + response := &NullResponse{} + + for range b.N { + s.ServeHTTP(response, request) + } + }) } } func BenchmarkGitHub(b *testing.B) { - request := httptest.NewRequest("GET", "/repos/:owner/:repo", nil) - response := &NullResponse{} + paths := []string{ + "/gists/:id", + "/repos/:a/:b", + } + s := server.New() for _, route := range testdata.Routes("testdata/github.txt") { - s.Router.Add(route.Method, route.Path, func(server.Context) error { - return nil + s.Router().Add(route.Method, route.Path, func(ctx server.Context) error { + return ctx.String("Hello") }) } - for range b.N { - s.ServeHTTP(response, request) + for _, path := range paths { + b.Run(strings.TrimPrefix(path, "/"), func(b *testing.B) { + request := httptest.NewRequest("GET", path, nil) + response := &NullResponse{} + + for range b.N { + s.ServeHTTP(response, request) + } + }) } } diff --git a/Context.go b/Context.go index 30b9358..fb41d12 100644 --- a/Context.go +++ b/Context.go @@ -24,16 +24,18 @@ type Context interface { Reader(io.Reader) error RequestHeader(key string) string ResponseHeader(key string) string + Scheme() string Status(status int) Context String(string) error Write([]byte) (int, error) + WriteString(string) (int, error) } // ctx represents a request & response context. type ctx struct { request *http.Request response http.ResponseWriter - server *Server + server *server paramNames [maxParams]string paramValues [maxParams]string paramCount int @@ -64,7 +66,7 @@ func (ctx *ctx) Error(messages ...any) error { // Get retrieves a parameter. func (ctx *ctx) Get(param string) string { - for i := 0; i < ctx.paramCount; i++ { + for i := range ctx.paramCount { if ctx.paramNames[i] == param { return ctx.paramValues[i] } @@ -120,6 +122,11 @@ func (ctx *ctx) Reader(reader io.Reader) error { return err } +// Scheme returns either `http` or `https`. +func (ctx *ctx) Scheme() string { + return ctx.request.URL.Scheme +} + // Status sets the HTTP status of the response. func (ctx *ctx) Status(status int) Context { ctx.response.WriteHeader(status) @@ -137,6 +144,11 @@ func (ctx *ctx) Write(body []byte) (int, error) { return ctx.response.Write(body) } +// WriteString implements the io.StringWriter interface. +func (ctx *ctx) WriteString(body string) (int, error) { + return ctx.response.(io.StringWriter).WriteString(body) +} + // addParameter adds a new parameter to the context. func (ctx *ctx) addParameter(name string, value string) { ctx.paramNames[ctx.paramCount] = name diff --git a/README.md b/README.md index ea7a12e..0ac38a0 100644 --- a/README.md +++ b/README.md @@ -50,8 +50,11 @@ coverage: 100.0% of statements ## Benchmarks ``` -BenchmarkHello-12 35983602 33.28 ns/op 0 B/op 0 allocs/op -BenchmarkGitHub-12 18320769 68.66 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/#00-12 33616044 33.82 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/hello-12 26220148 43.75 ns/op 0 B/op 0 allocs/op +BenchmarkStatic/hello/world-12 19777713 58.89 ns/op 0 B/op 0 allocs/op +BenchmarkGitHub/gists/:id-12 20842587 56.36 ns/op 0 B/op 0 allocs/op +BenchmarkGitHub/repos/:a/:b-12 17875575 65.04 ns/op 0 B/op 0 allocs/op ``` ## License diff --git a/Server.go b/Server.go index 80357cb..8451bbe 100644 --- a/Server.go +++ b/Server.go @@ -2,32 +2,46 @@ package server import ( "context" - "io" "log" "net" "net/http" "os" "os/signal" + "sync" "syscall" "git.akyoto.dev/go/router" ) -// Server represents a single web service. -type Server struct { - Router *router.Router[Handler] - Config Configuration - handlers []Handler +// Server is the interface for an HTTP server. +type Server interface { + http.Handler + Delete(path string, handler Handler) + Get(path string, handler Handler) + Post(path string, handler Handler) + Put(path string, handler Handler) + Router() *router.Router[Handler] + Run(address string) error + Use(handlers ...Handler) } -// New creates a new server. -func New() *Server { - return &Server{ - Router: router.New[Handler](), - Config: defaultConfig(), +// server is an HTTP server. +type server struct { + pool sync.Pool + handlers []Handler + router router.Router[Handler] + errorHandler func(Context, error) + config Configuration +} + +// New creates a new HTTP server. +func New() Server { + s := &server{ + router: router.Router[Handler]{}, + config: defaultConfig(), handlers: []Handler{ func(c Context) error { - handler := c.(*ctx).server.Router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter) + handler := c.(*ctx).server.router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter) if handler == nil { return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound)) @@ -36,56 +50,64 @@ func New() *Server { return handler(c) }, }, + errorHandler: func(ctx Context, err error) { + ctx.WriteString(err.Error()) + log.Println(ctx.Path(), err) + }, } + + s.pool.New = func() any { + return &ctx{server: s} + } + + return s } // Get registers your function to be called when the given GET path has been requested. -func (server *Server) Get(path string, handler Handler) { - server.Router.Add(http.MethodGet, path, handler) +func (s *server) Get(path string, handler Handler) { + s.Router().Add(http.MethodGet, path, handler) } // Post registers your function to be called when the given POST path has been requested. -func (server *Server) Post(path string, handler Handler) { - server.Router.Add(http.MethodPost, path, handler) +func (s *server) Post(path string, handler Handler) { + s.Router().Add(http.MethodPost, path, handler) } // Delete registers your function to be called when the given DELETE path has been requested. -func (server *Server) Delete(path string, handler Handler) { - server.Router.Add(http.MethodDelete, path, handler) +func (s *server) Delete(path string, handler Handler) { + s.Router().Add(http.MethodDelete, path, handler) } // Put registers your function to be called when the given PUT path has been requested. -func (server *Server) Put(path string, handler Handler) { - server.Router.Add(http.MethodPut, path, handler) +func (s *server) Put(path string, handler Handler) { + s.Router().Add(http.MethodPut, path, handler) } // ServeHTTP responds to the given request. -func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) { - ctx := contextPool.Get().(*ctx) +func (s *server) ServeHTTP(response http.ResponseWriter, request *http.Request) { + ctx := s.pool.Get().(*ctx) ctx.request = request ctx.response = response - ctx.server = server - err := server.handlers[0](ctx) + err := s.handlers[0](ctx) if err != nil { - response.(io.StringWriter).WriteString(err.Error()) - log.Println(request.URL, err) + s.errorHandler(ctx, err) } ctx.paramCount = 0 ctx.handlerCount = 0 - contextPool.Put(ctx) + s.pool.Put(ctx) } // Run starts the server on the given address. -func (server *Server) Run(address string) error { +func (server *server) Run(address string) error { srv := &http.Server{ Addr: address, Handler: server, - ReadTimeout: server.Config.Timeout.Read, - WriteTimeout: server.Config.Timeout.Write, - IdleTimeout: server.Config.Timeout.Idle, - ReadHeaderTimeout: server.Config.Timeout.ReadHeader, + ReadTimeout: server.config.Timeout.Read, + WriteTimeout: server.config.Timeout.Write, + IdleTimeout: server.config.Timeout.Idle, + ReadHeaderTimeout: server.config.Timeout.ReadHeader, } listener, err := net.Listen("tcp", address) @@ -100,15 +122,20 @@ func (server *Server) Run(address string) error { signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop - ctx, cancel := context.WithTimeout(context.Background(), server.Config.Timeout.Shutdown) + ctx, cancel := context.WithTimeout(context.Background(), server.config.Timeout.Shutdown) defer cancel() return srv.Shutdown(ctx) } -// Use adds handlers to your handlers chain. -func (server *Server) Use(handlers ...Handler) { - last := server.handlers[len(server.handlers)-1] - server.handlers = append(server.handlers[:len(server.handlers)-1], handlers...) - server.handlers = append(server.handlers, last) +// Router returns the router used by the server. +func (s *server) Router() *router.Router[Handler] { + return &s.router +} + +// Use adds handlers to your handlers chain. +func (s *server) Use(handlers ...Handler) { + last := s.handlers[len(s.handlers)-1] + s.handlers = append(s.handlers[:len(s.handlers)-1], handlers...) + s.handlers = append(s.handlers, last) } diff --git a/Server_test.go b/Server_test.go index 58e62a0..b7f6ea4 100644 --- a/Server_test.go +++ b/Server_test.go @@ -27,6 +27,11 @@ func TestRouter(t *testing.T) { }) s.Get("/write", func(ctx server.Context) error { + _, err := ctx.Write([]byte("Hello")) + return err + }) + + s.Get("/writestring", func(ctx server.Context) error { _, err := io.WriteString(ctx, "Hello") return err }) @@ -72,6 +77,10 @@ func TestRouter(t *testing.T) { return ctx.String(missing) }) + s.Get("/scheme", func(ctx server.Context) error { + return ctx.String(ctx.Scheme()) + }) + s.Post("/", func(ctx server.Context) error { return ctx.String("Post") }) @@ -99,7 +108,9 @@ func TestRouter(t *testing.T) { {Method: "GET", URL: "/response/header", Status: http.StatusOK, Body: "text/plain"}, {Method: "GET", URL: "/reader", Status: http.StatusOK, Body: "Hello"}, {Method: "GET", URL: "/string", Status: http.StatusOK, Body: "Hello"}, + {Method: "GET", URL: "/scheme", Status: http.StatusOK, Body: "http"}, {Method: "GET", URL: "/write", Status: http.StatusOK, Body: "Hello"}, + {Method: "GET", URL: "/writestring", Status: http.StatusOK, Body: "Hello"}, {Method: "GET", URL: "/blog/testing-my-router", Status: http.StatusOK, Body: "testing-my-router"}, {Method: "GET", URL: "/missing-parameter", Status: http.StatusOK, Body: ""}, {Method: "POST", URL: "/", Status: http.StatusOK, Body: "Post"}, @@ -109,7 +120,7 @@ func TestRouter(t *testing.T) { for _, test := range tests { t.Run("example.com"+test.URL, func(t *testing.T) { - request := httptest.NewRequest(test.Method, test.URL, nil) + request := httptest.NewRequest(test.Method, "http://example.com"+test.URL, nil) response := httptest.NewRecorder() s.ServeHTTP(response, request) @@ -141,7 +152,7 @@ func TestMiddleware(t *testing.T) { func TestPanic(t *testing.T) { s := server.New() - s.Router.Add(http.MethodGet, "/panic", func(ctx server.Context) error { + s.Router().Add(http.MethodGet, "/panic", func(ctx server.Context) error { panic("Something unbelievable happened") }) diff --git a/pool.go b/pool.go deleted file mode 100644 index dd1e032..0000000 --- a/pool.go +++ /dev/null @@ -1,9 +0,0 @@ -package server - -import "sync" - -var contextPool = sync.Pool{ - New: func() any { - return &ctx{} - }, -}