Added middleware support
This commit is contained in:
parent
1bcf4794f5
commit
e69a66aa31
@ -4,6 +4,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.akyoto.dev/go/router/testdata"
|
||||||
"git.akyoto.dev/go/server"
|
"git.akyoto.dev/go/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) {
|
|||||||
response := &NullResponse{}
|
response := &NullResponse{}
|
||||||
s := server.New()
|
s := server.New()
|
||||||
|
|
||||||
for _, route := range loadRoutes("testdata/github.txt") {
|
for _, route := range testdata.Routes("testdata/github.txt") {
|
||||||
s.Router.Add(route.Method, route.Path, func(server.Context) error {
|
s.Router.Add(route.Method, route.Path, func(server.Context) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
75
Context.go
75
Context.go
@ -15,28 +15,38 @@ type Context interface {
|
|||||||
Bytes([]byte) error
|
Bytes([]byte) error
|
||||||
Error(messages ...any) error
|
Error(messages ...any) error
|
||||||
Get(param string) string
|
Get(param string) string
|
||||||
|
Header(key string, value string)
|
||||||
|
Host() string
|
||||||
|
Method() string
|
||||||
|
Next() error
|
||||||
|
Path() string
|
||||||
|
Protocol() string
|
||||||
Reader(io.Reader) error
|
Reader(io.Reader) error
|
||||||
Request() Request
|
RequestHeader(key string) string
|
||||||
Response() Response
|
ResponseHeader(key string) string
|
||||||
Status(status int) Context
|
Status(status int) Context
|
||||||
String(string) error
|
String(string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ctx represents a request & response context.
|
// ctx represents a request & response context.
|
||||||
type ctx struct {
|
type ctx struct {
|
||||||
request request
|
request *http.Request
|
||||||
response response
|
response http.ResponseWriter
|
||||||
|
server *Server
|
||||||
paramNames [maxParams]string
|
paramNames [maxParams]string
|
||||||
paramValues [maxParams]string
|
paramValues [maxParams]string
|
||||||
paramCount int
|
paramCount int
|
||||||
|
handlerCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
// newContext returns a new context from the pool.
|
// newContext returns a new context from the pool.
|
||||||
func newContext(req *http.Request, res http.ResponseWriter) *ctx {
|
func newContext(req *http.Request, res http.ResponseWriter, server *Server) *ctx {
|
||||||
ctx := contextPool.Get().(*ctx)
|
ctx := contextPool.Get().(*ctx)
|
||||||
ctx.request.Request = req
|
ctx.request = req
|
||||||
ctx.response.ResponseWriter = res
|
ctx.response = res
|
||||||
|
ctx.server = server
|
||||||
ctx.paramCount = 0
|
ctx.paramCount = 0
|
||||||
|
ctx.handlerCount = 0
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,22 +83,53 @@ func (ctx *ctx) Get(param string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Next executes the next handler in the middleware chain.
|
||||||
|
func (ctx *ctx) Next() error {
|
||||||
|
ctx.handlerCount++
|
||||||
|
return ctx.server.handlers[ctx.handlerCount](ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestHeader returns the request header value for the given key.
|
||||||
|
func (ctx *ctx) RequestHeader(key string) string {
|
||||||
|
return ctx.request.Header.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseHeader returns the response header value for the given key.
|
||||||
|
func (ctx *ctx) ResponseHeader(key string) string {
|
||||||
|
return ctx.response.Header().Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header sets the header value for the given key.
|
||||||
|
func (ctx *ctx) Header(key string, value string) {
|
||||||
|
ctx.response.Header().Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method returns the request method.
|
||||||
|
func (ctx *ctx) Method() string {
|
||||||
|
return ctx.request.Method
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol returns the request protocol.
|
||||||
|
func (ctx *ctx) Protocol() string {
|
||||||
|
return ctx.request.Proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host returns the requested host.
|
||||||
|
func (ctx *ctx) Host() string {
|
||||||
|
return ctx.request.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path returns the requested path.
|
||||||
|
func (ctx *ctx) Path() string {
|
||||||
|
return ctx.request.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
// Reader sends the contents of the io.Reader without creating an in-memory copy.
|
// Reader sends the contents of the io.Reader without creating an in-memory copy.
|
||||||
func (ctx *ctx) Reader(reader io.Reader) error {
|
func (ctx *ctx) Reader(reader io.Reader) error {
|
||||||
_, err := io.Copy(ctx.response, reader)
|
_, err := io.Copy(ctx.response, reader)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request returns the HTTP request.
|
|
||||||
func (ctx *ctx) Request() Request {
|
|
||||||
return &ctx.request
|
|
||||||
}
|
|
||||||
|
|
||||||
// Response returns the HTTP response.
|
|
||||||
func (ctx *ctx) Response() Response {
|
|
||||||
return &ctx.response
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status sets the HTTP status of the response.
|
// Status sets the HTTP status of the response.
|
||||||
func (ctx *ctx) Status(status int) Context {
|
func (ctx *ctx) Status(status int) Context {
|
||||||
ctx.response.WriteHeader(status)
|
ctx.response.WriteHeader(status)
|
||||||
|
@ -40,6 +40,7 @@ s.Run(":8080")
|
|||||||
|
|
||||||
```
|
```
|
||||||
PASS: TestRouter
|
PASS: TestRouter
|
||||||
|
PASS: TestMiddleware
|
||||||
PASS: TestPanic
|
PASS: TestPanic
|
||||||
PASS: TestRun
|
PASS: TestRun
|
||||||
PASS: TestUnavailablePort
|
PASS: TestUnavailablePort
|
||||||
@ -49,8 +50,8 @@ coverage: 100.0% of statements
|
|||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkHello-12 33502272 33.64 ns/op 0 B/op 0 allocs/op
|
BenchmarkHello-12 31128832 38.41 ns/op 0 B/op 0 allocs/op
|
||||||
BenchmarkGitHub-12 17698947 65.50 ns/op 0 B/op 0 allocs/op
|
BenchmarkGitHub-12 17406580 68.56 ns/op 0 B/op 0 allocs/op
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
42
Request.go
42
Request.go
@ -1,42 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
// Request is an interface for HTTP requests.
|
|
||||||
type Request interface {
|
|
||||||
Header(key string) string
|
|
||||||
Host() string
|
|
||||||
Method() string
|
|
||||||
Path() string
|
|
||||||
Protocol() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// request represents the HTTP request used in the given context.
|
|
||||||
type request struct {
|
|
||||||
*http.Request
|
|
||||||
}
|
|
||||||
|
|
||||||
// Header returns the header value for the given key.
|
|
||||||
func (req request) Header(key string) string {
|
|
||||||
return req.Request.Header.Get(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Method returns the request method.
|
|
||||||
func (req request) Method() string {
|
|
||||||
return req.Request.Method
|
|
||||||
}
|
|
||||||
|
|
||||||
// Protocol returns the request protocol.
|
|
||||||
func (req request) Protocol() string {
|
|
||||||
return req.Request.Proto
|
|
||||||
}
|
|
||||||
|
|
||||||
// Host returns the requested host.
|
|
||||||
func (req request) Host() string {
|
|
||||||
return req.Request.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Path returns the requested path.
|
|
||||||
func (req request) Path() string {
|
|
||||||
return req.Request.URL.Path
|
|
||||||
}
|
|
24
Response.go
24
Response.go
@ -1,24 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
// Response is the interface for an HTTP response.
|
|
||||||
type Response interface {
|
|
||||||
Header(key string) string
|
|
||||||
SetHeader(key string, value string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// response represents the HTTP response used in the given context.
|
|
||||||
type response struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
// Header returns the header value for the given key.
|
|
||||||
func (res response) Header(key string) string {
|
|
||||||
return res.ResponseWriter.Header().Get(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHeader sets the header value for the given key.
|
|
||||||
func (res response) SetHeader(key string, value string) {
|
|
||||||
res.ResponseWriter.Header().Set(key, value)
|
|
||||||
}
|
|
@ -1,46 +0,0 @@
|
|||||||
package server_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Route represents a single line in the test data.
|
|
||||||
type Route struct {
|
|
||||||
Method string
|
|
||||||
Path string
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadRoutes loads all routes from a text file.
|
|
||||||
func loadRoutes(filePath string) []Route {
|
|
||||||
var routes []Route
|
|
||||||
f, err := os.Open(filePath)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer f.Close()
|
|
||||||
reader := bufio.NewReader(f)
|
|
||||||
|
|
||||||
for {
|
|
||||||
line, err := reader.ReadString('\n')
|
|
||||||
|
|
||||||
if line != "" {
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
parts := strings.Split(line, " ")
|
|
||||||
|
|
||||||
routes = append(routes, Route{
|
|
||||||
Method: parts[0],
|
|
||||||
Path: parts[1],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes
|
|
||||||
}
|
|
31
Server.go
31
Server.go
@ -17,6 +17,7 @@ import (
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
Router *router.Router[Handler]
|
Router *router.Router[Handler]
|
||||||
Config Configuration
|
Config Configuration
|
||||||
|
handlers []Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new server.
|
// New creates a new server.
|
||||||
@ -24,6 +25,17 @@ func New() *Server {
|
|||||||
return &Server{
|
return &Server{
|
||||||
Router: router.New[Handler](),
|
Router: router.New[Handler](),
|
||||||
Config: defaultConfig(),
|
Config: defaultConfig(),
|
||||||
|
handlers: []Handler{
|
||||||
|
func(c Context) error {
|
||||||
|
handler := c.(*ctx).server.Router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter)
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound))
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(c)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,18 +61,10 @@ func (server *Server) Put(path string, handler Handler) {
|
|||||||
|
|
||||||
// ServeHTTP responds to the given request.
|
// ServeHTTP responds to the given request.
|
||||||
func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) {
|
func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) {
|
||||||
ctx := newContext(request, response)
|
ctx := newContext(request, response, server)
|
||||||
defer contextPool.Put(ctx)
|
defer contextPool.Put(ctx)
|
||||||
|
|
||||||
handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter)
|
err := server.handlers[0](ctx)
|
||||||
|
|
||||||
if handler == nil {
|
|
||||||
response.WriteHeader(http.StatusNotFound)
|
|
||||||
response.(io.StringWriter).WriteString(http.StatusText(http.StatusNotFound))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := handler(ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.(io.StringWriter).WriteString(err.Error())
|
response.(io.StringWriter).WriteString(err.Error())
|
||||||
@ -96,3 +100,10 @@ func (server *Server) Run(address string) error {
|
|||||||
|
|
||||||
return srv.Shutdown(ctx)
|
return srv.Shutdown(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use adds handlers to your handlers chain.
|
||||||
|
func (server *Server) Use(handlers ...Handler) {
|
||||||
|
last := server.handlers[len(server.handlers)-1]
|
||||||
|
server.handlers = append(server.handlers[:len(server.handlers)-1], handlers...)
|
||||||
|
server.handlers = append(server.handlers, last)
|
||||||
|
}
|
||||||
|
@ -39,24 +39,21 @@ func TestRouter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
s.Get("/request/data", func(ctx server.Context) error {
|
s.Get("/request/data", func(ctx server.Context) error {
|
||||||
request := ctx.Request()
|
method := ctx.Method()
|
||||||
method := request.Method()
|
protocol := ctx.Protocol()
|
||||||
protocol := request.Protocol()
|
host := ctx.Host()
|
||||||
host := request.Host()
|
path := ctx.Path()
|
||||||
path := request.Path()
|
|
||||||
return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path))
|
return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path))
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Get("/request/header", func(ctx server.Context) error {
|
s.Get("/request/header", func(ctx server.Context) error {
|
||||||
request := ctx.Request()
|
acceptEncoding := ctx.RequestHeader("Accept-Encoding")
|
||||||
acceptEncoding := request.Header("Accept-Encoding")
|
|
||||||
return ctx.String(acceptEncoding)
|
return ctx.String(acceptEncoding)
|
||||||
})
|
})
|
||||||
|
|
||||||
s.Get("/response/header", func(ctx server.Context) error {
|
s.Get("/response/header", func(ctx server.Context) error {
|
||||||
response := ctx.Response()
|
ctx.Header("Content-Type", "text/plain")
|
||||||
response.SetHeader("Content-Type", "text/plain")
|
contentType := ctx.ResponseHeader("Content-Type")
|
||||||
contentType := response.Header("Content-Type")
|
|
||||||
return ctx.String(contentType)
|
return ctx.String(contentType)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -120,6 +117,21 @@ func TestRouter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMiddleware(t *testing.T) {
|
||||||
|
s := server.New()
|
||||||
|
|
||||||
|
s.Use(func(ctx server.Context) error {
|
||||||
|
ctx.Header("Middleware", "true")
|
||||||
|
return ctx.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
response := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(response, request)
|
||||||
|
|
||||||
|
assert.Equal(t, response.Header().Get("Middleware"), "true")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPanic(t *testing.T) {
|
func TestPanic(t *testing.T) {
|
||||||
s := server.New()
|
s := server.New()
|
||||||
|
|
||||||
|
24
examples/logger/main.go
Normal file
24
examples/logger/main.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.akyoto.dev/go/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
s := server.New()
|
||||||
|
|
||||||
|
s.Use(func(ctx server.Context) error {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
fmt.Println(ctx.Path(), time.Since(start))
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ctx.Next()
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run(":8080")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user