Added middleware support

This commit is contained in:
Eduard Urbach 2024-03-13 20:18:01 +01:00
parent 1bcf4794f5
commit e69a66aa31
Signed by: akyoto
GPG Key ID: C874F672B1AF20C0
9 changed files with 135 additions and 157 deletions

View File

@ -4,6 +4,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"git.akyoto.dev/go/router/testdata"
"git.akyoto.dev/go/server" "git.akyoto.dev/go/server"
) )
@ -26,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) {
response := &NullResponse{} response := &NullResponse{}
s := server.New() s := server.New()
for _, route := range loadRoutes("testdata/github.txt") { for _, route := range testdata.Routes("testdata/github.txt") {
s.Router.Add(route.Method, route.Path, func(server.Context) error { s.Router.Add(route.Method, route.Path, func(server.Context) error {
return nil return nil
}) })

View File

@ -15,28 +15,38 @@ type Context interface {
Bytes([]byte) error Bytes([]byte) error
Error(messages ...any) error Error(messages ...any) error
Get(param string) string Get(param string) string
Header(key string, value string)
Host() string
Method() string
Next() error
Path() string
Protocol() string
Reader(io.Reader) error Reader(io.Reader) error
Request() Request RequestHeader(key string) string
Response() Response ResponseHeader(key string) string
Status(status int) Context Status(status int) Context
String(string) error String(string) error
} }
// ctx represents a request & response context. // ctx represents a request & response context.
type ctx struct { type ctx struct {
request request request *http.Request
response response response http.ResponseWriter
server *Server
paramNames [maxParams]string paramNames [maxParams]string
paramValues [maxParams]string paramValues [maxParams]string
paramCount int paramCount int
handlerCount int
} }
// newContext returns a new context from the pool. // newContext returns a new context from the pool.
func newContext(req *http.Request, res http.ResponseWriter) *ctx { func newContext(req *http.Request, res http.ResponseWriter, server *Server) *ctx {
ctx := contextPool.Get().(*ctx) ctx := contextPool.Get().(*ctx)
ctx.request.Request = req ctx.request = req
ctx.response.ResponseWriter = res ctx.response = res
ctx.server = server
ctx.paramCount = 0 ctx.paramCount = 0
ctx.handlerCount = 0
return ctx return ctx
} }
@ -73,22 +83,53 @@ func (ctx *ctx) Get(param string) string {
return "" return ""
} }
// Next executes the next handler in the middleware chain.
func (ctx *ctx) Next() error {
ctx.handlerCount++
return ctx.server.handlers[ctx.handlerCount](ctx)
}
// RequestHeader returns the request header value for the given key.
func (ctx *ctx) RequestHeader(key string) string {
return ctx.request.Header.Get(key)
}
// ResponseHeader returns the response header value for the given key.
func (ctx *ctx) ResponseHeader(key string) string {
return ctx.response.Header().Get(key)
}
// Header sets the header value for the given key.
func (ctx *ctx) Header(key string, value string) {
ctx.response.Header().Set(key, value)
}
// Method returns the request method.
func (ctx *ctx) Method() string {
return ctx.request.Method
}
// Protocol returns the request protocol.
func (ctx *ctx) Protocol() string {
return ctx.request.Proto
}
// Host returns the requested host.
func (ctx *ctx) Host() string {
return ctx.request.Host
}
// Path returns the requested path.
func (ctx *ctx) Path() string {
return ctx.request.URL.Path
}
// Reader sends the contents of the io.Reader without creating an in-memory copy. // Reader sends the contents of the io.Reader without creating an in-memory copy.
func (ctx *ctx) Reader(reader io.Reader) error { func (ctx *ctx) Reader(reader io.Reader) error {
_, err := io.Copy(ctx.response, reader) _, err := io.Copy(ctx.response, reader)
return err return err
} }
// Request returns the HTTP request.
func (ctx *ctx) Request() Request {
return &ctx.request
}
// Response returns the HTTP response.
func (ctx *ctx) Response() Response {
return &ctx.response
}
// Status sets the HTTP status of the response. // Status sets the HTTP status of the response.
func (ctx *ctx) Status(status int) Context { func (ctx *ctx) Status(status int) Context {
ctx.response.WriteHeader(status) ctx.response.WriteHeader(status)

View File

@ -40,6 +40,7 @@ s.Run(":8080")
``` ```
PASS: TestRouter PASS: TestRouter
PASS: TestMiddleware
PASS: TestPanic PASS: TestPanic
PASS: TestRun PASS: TestRun
PASS: TestUnavailablePort PASS: TestUnavailablePort
@ -49,8 +50,8 @@ coverage: 100.0% of statements
## Benchmarks ## Benchmarks
``` ```
BenchmarkHello-12 33502272 33.64 ns/op 0 B/op 0 allocs/op BenchmarkHello-12 31128832 38.41 ns/op 0 B/op 0 allocs/op
BenchmarkGitHub-12 17698947 65.50 ns/op 0 B/op 0 allocs/op BenchmarkGitHub-12 17406580 68.56 ns/op 0 B/op 0 allocs/op
``` ```
## License ## License

View File

@ -1,42 +0,0 @@
package server
import "net/http"
// Request is an interface for HTTP requests.
type Request interface {
Header(key string) string
Host() string
Method() string
Path() string
Protocol() string
}
// request represents the HTTP request used in the given context.
type request struct {
*http.Request
}
// Header returns the header value for the given key.
func (req request) Header(key string) string {
return req.Request.Header.Get(key)
}
// Method returns the request method.
func (req request) Method() string {
return req.Request.Method
}
// Protocol returns the request protocol.
func (req request) Protocol() string {
return req.Request.Proto
}
// Host returns the requested host.
func (req request) Host() string {
return req.Request.Host
}
// Path returns the requested path.
func (req request) Path() string {
return req.Request.URL.Path
}

View File

@ -1,24 +0,0 @@
package server
import "net/http"
// Response is the interface for an HTTP response.
type Response interface {
Header(key string) string
SetHeader(key string, value string)
}
// response represents the HTTP response used in the given context.
type response struct {
http.ResponseWriter
}
// Header returns the header value for the given key.
func (res response) Header(key string) string {
return res.ResponseWriter.Header().Get(key)
}
// SetHeader sets the header value for the given key.
func (res response) SetHeader(key string, value string) {
res.ResponseWriter.Header().Set(key, value)
}

View File

@ -1,46 +0,0 @@
package server_test
import (
"bufio"
"os"
"strings"
)
// Route represents a single line in the test data.
type Route struct {
Method string
Path string
}
// loadRoutes loads all routes from a text file.
func loadRoutes(filePath string) []Route {
var routes []Route
f, err := os.Open(filePath)
if err != nil {
panic(err)
}
defer f.Close()
reader := bufio.NewReader(f)
for {
line, err := reader.ReadString('\n')
if line != "" {
line = strings.TrimSpace(line)
parts := strings.Split(line, " ")
routes = append(routes, Route{
Method: parts[0],
Path: parts[1],
})
}
if err != nil {
break
}
}
return routes
}

View File

@ -17,6 +17,7 @@ import (
type Server struct { type Server struct {
Router *router.Router[Handler] Router *router.Router[Handler]
Config Configuration Config Configuration
handlers []Handler
} }
// New creates a new server. // New creates a new server.
@ -24,6 +25,17 @@ func New() *Server {
return &Server{ return &Server{
Router: router.New[Handler](), Router: router.New[Handler](),
Config: defaultConfig(), Config: defaultConfig(),
handlers: []Handler{
func(c Context) error {
handler := c.(*ctx).server.Router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter)
if handler == nil {
return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound))
}
return handler(c)
},
},
} }
} }
@ -49,18 +61,10 @@ func (server *Server) Put(path string, handler Handler) {
// ServeHTTP responds to the given request. // ServeHTTP responds to the given request.
func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) { func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) {
ctx := newContext(request, response) ctx := newContext(request, response, server)
defer contextPool.Put(ctx) defer contextPool.Put(ctx)
handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter) err := server.handlers[0](ctx)
if handler == nil {
response.WriteHeader(http.StatusNotFound)
response.(io.StringWriter).WriteString(http.StatusText(http.StatusNotFound))
return
}
err := handler(ctx)
if err != nil { if err != nil {
response.(io.StringWriter).WriteString(err.Error()) response.(io.StringWriter).WriteString(err.Error())
@ -96,3 +100,10 @@ func (server *Server) Run(address string) error {
return srv.Shutdown(ctx) return srv.Shutdown(ctx)
} }
// Use adds handlers to your handlers chain.
func (server *Server) Use(handlers ...Handler) {
last := server.handlers[len(server.handlers)-1]
server.handlers = append(server.handlers[:len(server.handlers)-1], handlers...)
server.handlers = append(server.handlers, last)
}

View File

@ -39,24 +39,21 @@ func TestRouter(t *testing.T) {
}) })
s.Get("/request/data", func(ctx server.Context) error { s.Get("/request/data", func(ctx server.Context) error {
request := ctx.Request() method := ctx.Method()
method := request.Method() protocol := ctx.Protocol()
protocol := request.Protocol() host := ctx.Host()
host := request.Host() path := ctx.Path()
path := request.Path()
return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path)) return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path))
}) })
s.Get("/request/header", func(ctx server.Context) error { s.Get("/request/header", func(ctx server.Context) error {
request := ctx.Request() acceptEncoding := ctx.RequestHeader("Accept-Encoding")
acceptEncoding := request.Header("Accept-Encoding")
return ctx.String(acceptEncoding) return ctx.String(acceptEncoding)
}) })
s.Get("/response/header", func(ctx server.Context) error { s.Get("/response/header", func(ctx server.Context) error {
response := ctx.Response() ctx.Header("Content-Type", "text/plain")
response.SetHeader("Content-Type", "text/plain") contentType := ctx.ResponseHeader("Content-Type")
contentType := response.Header("Content-Type")
return ctx.String(contentType) return ctx.String(contentType)
}) })
@ -120,6 +117,21 @@ func TestRouter(t *testing.T) {
} }
} }
func TestMiddleware(t *testing.T) {
s := server.New()
s.Use(func(ctx server.Context) error {
ctx.Header("Middleware", "true")
return ctx.Next()
})
request := httptest.NewRequest(http.MethodGet, "/", nil)
response := httptest.NewRecorder()
s.ServeHTTP(response, request)
assert.Equal(t, response.Header().Get("Middleware"), "true")
}
func TestPanic(t *testing.T) { func TestPanic(t *testing.T) {
s := server.New() s := server.New()

24
examples/logger/main.go Normal file
View File

@ -0,0 +1,24 @@
package main
import (
"fmt"
"time"
"git.akyoto.dev/go/server"
)
func main() {
s := server.New()
s.Use(func(ctx server.Context) error {
start := time.Now()
defer func() {
fmt.Println(ctx.Path(), time.Since(start))
}()
return ctx.Next()
})
s.Run(":8080")
}