Added middleware support
This commit is contained in:
parent
02328cb41e
commit
c13dbc55d2
@ -4,6 +4,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.akyoto.dev/go/router/testdata"
|
||||
"git.akyoto.dev/go/server"
|
||||
)
|
||||
|
||||
@ -26,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) {
|
||||
response := &NullResponse{}
|
||||
s := server.New()
|
||||
|
||||
for _, route := range loadRoutes("testdata/github.txt") {
|
||||
for _, route := range testdata.Routes("testdata/github.txt") {
|
||||
s.Router.Add(route.Method, route.Path, func(server.Context) error {
|
||||
return nil
|
||||
})
|
||||
|
75
Context.go
75
Context.go
@ -15,28 +15,38 @@ type Context interface {
|
||||
Bytes([]byte) error
|
||||
Error(messages ...any) error
|
||||
Get(param string) string
|
||||
Header(key string, value string)
|
||||
Host() string
|
||||
Method() string
|
||||
Next() error
|
||||
Path() string
|
||||
Protocol() string
|
||||
Reader(io.Reader) error
|
||||
Request() Request
|
||||
Response() Response
|
||||
RequestHeader(key string) string
|
||||
ResponseHeader(key string) string
|
||||
Status(status int) Context
|
||||
String(string) error
|
||||
}
|
||||
|
||||
// ctx represents a request & response context.
|
||||
type ctx struct {
|
||||
request request
|
||||
response response
|
||||
request *http.Request
|
||||
response http.ResponseWriter
|
||||
server *Server
|
||||
paramNames [maxParams]string
|
||||
paramValues [maxParams]string
|
||||
paramCount int
|
||||
handlerCount int
|
||||
}
|
||||
|
||||
// newContext returns a new context from the pool.
|
||||
func newContext(req *http.Request, res http.ResponseWriter) *ctx {
|
||||
func newContext(req *http.Request, res http.ResponseWriter, server *Server) *ctx {
|
||||
ctx := contextPool.Get().(*ctx)
|
||||
ctx.request.Request = req
|
||||
ctx.response.ResponseWriter = res
|
||||
ctx.request = req
|
||||
ctx.response = res
|
||||
ctx.server = server
|
||||
ctx.paramCount = 0
|
||||
ctx.handlerCount = 0
|
||||
return ctx
|
||||
}
|
||||
|
||||
@ -73,22 +83,53 @@ func (ctx *ctx) Get(param string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Next executes the next handler in the middleware chain.
|
||||
func (ctx *ctx) Next() error {
|
||||
ctx.handlerCount++
|
||||
return ctx.server.handlers[ctx.handlerCount](ctx)
|
||||
}
|
||||
|
||||
// RequestHeader returns the request header value for the given key.
|
||||
func (ctx *ctx) RequestHeader(key string) string {
|
||||
return ctx.request.Header.Get(key)
|
||||
}
|
||||
|
||||
// ResponseHeader returns the response header value for the given key.
|
||||
func (ctx *ctx) ResponseHeader(key string) string {
|
||||
return ctx.response.Header().Get(key)
|
||||
}
|
||||
|
||||
// Header sets the header value for the given key.
|
||||
func (ctx *ctx) Header(key string, value string) {
|
||||
ctx.response.Header().Set(key, value)
|
||||
}
|
||||
|
||||
// Method returns the request method.
|
||||
func (ctx *ctx) Method() string {
|
||||
return ctx.request.Method
|
||||
}
|
||||
|
||||
// Protocol returns the request protocol.
|
||||
func (ctx *ctx) Protocol() string {
|
||||
return ctx.request.Proto
|
||||
}
|
||||
|
||||
// Host returns the requested host.
|
||||
func (ctx *ctx) Host() string {
|
||||
return ctx.request.Host
|
||||
}
|
||||
|
||||
// Path returns the requested path.
|
||||
func (ctx *ctx) Path() string {
|
||||
return ctx.request.URL.Path
|
||||
}
|
||||
|
||||
// Reader sends the contents of the io.Reader without creating an in-memory copy.
|
||||
func (ctx *ctx) Reader(reader io.Reader) error {
|
||||
_, err := io.Copy(ctx.response, reader)
|
||||
return err
|
||||
}
|
||||
|
||||
// Request returns the HTTP request.
|
||||
func (ctx *ctx) Request() Request {
|
||||
return &ctx.request
|
||||
}
|
||||
|
||||
// Response returns the HTTP response.
|
||||
func (ctx *ctx) Response() Response {
|
||||
return &ctx.response
|
||||
}
|
||||
|
||||
// Status sets the HTTP status of the response.
|
||||
func (ctx *ctx) Status(status int) Context {
|
||||
ctx.response.WriteHeader(status)
|
||||
|
@ -40,6 +40,7 @@ s.Run(":8080")
|
||||
|
||||
```
|
||||
PASS: TestRouter
|
||||
PASS: TestMiddleware
|
||||
PASS: TestPanic
|
||||
PASS: TestRun
|
||||
PASS: TestUnavailablePort
|
||||
@ -49,8 +50,8 @@ coverage: 100.0% of statements
|
||||
## Benchmarks
|
||||
|
||||
```
|
||||
BenchmarkHello-12 33502272 33.64 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkGitHub-12 17698947 65.50 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkHello-12 31128832 38.41 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkGitHub-12 17406580 68.56 ns/op 0 B/op 0 allocs/op
|
||||
```
|
||||
|
||||
## License
|
||||
|
42
Request.go
42
Request.go
@ -1,42 +0,0 @@
|
||||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Request is an interface for HTTP requests.
|
||||
type Request interface {
|
||||
Header(key string) string
|
||||
Host() string
|
||||
Method() string
|
||||
Path() string
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
// request represents the HTTP request used in the given context.
|
||||
type request struct {
|
||||
*http.Request
|
||||
}
|
||||
|
||||
// Header returns the header value for the given key.
|
||||
func (req request) Header(key string) string {
|
||||
return req.Request.Header.Get(key)
|
||||
}
|
||||
|
||||
// Method returns the request method.
|
||||
func (req request) Method() string {
|
||||
return req.Request.Method
|
||||
}
|
||||
|
||||
// Protocol returns the request protocol.
|
||||
func (req request) Protocol() string {
|
||||
return req.Request.Proto
|
||||
}
|
||||
|
||||
// Host returns the requested host.
|
||||
func (req request) Host() string {
|
||||
return req.Request.Host
|
||||
}
|
||||
|
||||
// Path returns the requested path.
|
||||
func (req request) Path() string {
|
||||
return req.Request.URL.Path
|
||||
}
|
24
Response.go
24
Response.go
@ -1,24 +0,0 @@
|
||||
package server
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Response is the interface for an HTTP response.
|
||||
type Response interface {
|
||||
Header(key string) string
|
||||
SetHeader(key string, value string)
|
||||
}
|
||||
|
||||
// response represents the HTTP response used in the given context.
|
||||
type response struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// Header returns the header value for the given key.
|
||||
func (res response) Header(key string) string {
|
||||
return res.ResponseWriter.Header().Get(key)
|
||||
}
|
||||
|
||||
// SetHeader sets the header value for the given key.
|
||||
func (res response) SetHeader(key string, value string) {
|
||||
res.ResponseWriter.Header().Set(key, value)
|
||||
}
|
@ -1,46 +0,0 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Route represents a single line in the test data.
|
||||
type Route struct {
|
||||
Method string
|
||||
Path string
|
||||
}
|
||||
|
||||
// loadRoutes loads all routes from a text file.
|
||||
func loadRoutes(filePath string) []Route {
|
||||
var routes []Route
|
||||
f, err := os.Open(filePath)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
reader := bufio.NewReader(f)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
|
||||
if line != "" {
|
||||
line = strings.TrimSpace(line)
|
||||
parts := strings.Split(line, " ")
|
||||
|
||||
routes = append(routes, Route{
|
||||
Method: parts[0],
|
||||
Path: parts[1],
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
31
Server.go
31
Server.go
@ -17,6 +17,7 @@ import (
|
||||
type Server struct {
|
||||
Router *router.Router[Handler]
|
||||
Config Configuration
|
||||
handlers []Handler
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
@ -24,6 +25,17 @@ func New() *Server {
|
||||
return &Server{
|
||||
Router: router.New[Handler](),
|
||||
Config: defaultConfig(),
|
||||
handlers: []Handler{
|
||||
func(c Context) error {
|
||||
handler := c.(*ctx).server.Router.LookupNoAlloc(c.Method(), c.Path(), c.(*ctx).addParameter)
|
||||
|
||||
if handler == nil {
|
||||
return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound))
|
||||
}
|
||||
|
||||
return handler(c)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,18 +61,10 @@ func (server *Server) Put(path string, handler Handler) {
|
||||
|
||||
// ServeHTTP responds to the given request.
|
||||
func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Request) {
|
||||
ctx := newContext(request, response)
|
||||
ctx := newContext(request, response, server)
|
||||
defer contextPool.Put(ctx)
|
||||
|
||||
handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter)
|
||||
|
||||
if handler == nil {
|
||||
response.WriteHeader(http.StatusNotFound)
|
||||
response.(io.StringWriter).WriteString(http.StatusText(http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
|
||||
err := handler(ctx)
|
||||
err := server.handlers[0](ctx)
|
||||
|
||||
if err != nil {
|
||||
response.(io.StringWriter).WriteString(err.Error())
|
||||
@ -96,3 +100,10 @@ func (server *Server) Run(address string) error {
|
||||
|
||||
return srv.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// Use adds handlers to your handlers chain.
|
||||
func (server *Server) Use(handlers ...Handler) {
|
||||
last := server.handlers[len(server.handlers)-1]
|
||||
server.handlers = append(server.handlers[:len(server.handlers)-1], handlers...)
|
||||
server.handlers = append(server.handlers, last)
|
||||
}
|
||||
|
@ -39,24 +39,21 @@ func TestRouter(t *testing.T) {
|
||||
})
|
||||
|
||||
s.Get("/request/data", func(ctx server.Context) error {
|
||||
request := ctx.Request()
|
||||
method := request.Method()
|
||||
protocol := request.Protocol()
|
||||
host := request.Host()
|
||||
path := request.Path()
|
||||
method := ctx.Method()
|
||||
protocol := ctx.Protocol()
|
||||
host := ctx.Host()
|
||||
path := ctx.Path()
|
||||
return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path))
|
||||
})
|
||||
|
||||
s.Get("/request/header", func(ctx server.Context) error {
|
||||
request := ctx.Request()
|
||||
acceptEncoding := request.Header("Accept-Encoding")
|
||||
acceptEncoding := ctx.RequestHeader("Accept-Encoding")
|
||||
return ctx.String(acceptEncoding)
|
||||
})
|
||||
|
||||
s.Get("/response/header", func(ctx server.Context) error {
|
||||
response := ctx.Response()
|
||||
response.SetHeader("Content-Type", "text/plain")
|
||||
contentType := response.Header("Content-Type")
|
||||
ctx.Header("Content-Type", "text/plain")
|
||||
contentType := ctx.ResponseHeader("Content-Type")
|
||||
return ctx.String(contentType)
|
||||
})
|
||||
|
||||
@ -120,6 +117,21 @@ func TestRouter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
s := server.New()
|
||||
|
||||
s.Use(func(ctx server.Context) error {
|
||||
ctx.Header("Middleware", "true")
|
||||
return ctx.Next()
|
||||
})
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
response := httptest.NewRecorder()
|
||||
s.ServeHTTP(response, request)
|
||||
|
||||
assert.Equal(t, response.Header().Get("Middleware"), "true")
|
||||
}
|
||||
|
||||
func TestPanic(t *testing.T) {
|
||||
s := server.New()
|
||||
|
||||
|
24
examples/logger/main.go
Normal file
24
examples/logger/main.go
Normal file
@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.akyoto.dev/go/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
s := server.New()
|
||||
|
||||
s.Use(func(ctx server.Context) error {
|
||||
start := time.Now()
|
||||
|
||||
defer func() {
|
||||
fmt.Println(ctx.Path(), time.Since(start))
|
||||
}()
|
||||
|
||||
return ctx.Next()
|
||||
})
|
||||
|
||||
s.Run(":8080")
|
||||
}
|
Loading…
Reference in New Issue
Block a user