diff --git a/Context.go b/Context.go index 460a648..b2715e0 100644 --- a/Context.go +++ b/Context.go @@ -16,24 +16,26 @@ type Context interface { Error(status int, messages ...any) error Get(param string) string Reader(io.Reader) error + Request() Request + Response() Response SetStatus(status int) String(string) error } // context represents a request & response context. type context struct { - request *http.Request - response http.ResponseWriter + request request + response response paramNames [maxParams]string paramValues [maxParams]string paramCount int } // newContext returns a new context from the pool. -func newContext(request *http.Request, response http.ResponseWriter) *context { +func newContext(req *http.Request, res http.ResponseWriter) *context { ctx := contextPool.Get().(*context) - ctx.request = request - ctx.response = response + ctx.request.Request = req + ctx.response.ResponseWriter = res ctx.paramCount = 0 return ctx } @@ -78,6 +80,16 @@ func (ctx *context) Reader(reader io.Reader) error { return err } +// Request returns the HTTP request. +func (ctx *context) Request() Request { + return &ctx.request +} + +// Response returns the HTTP response. +func (ctx *context) Response() Response { + return &ctx.response +} + // SetStatus writes the header with the given HTTP status code. func (ctx *context) SetStatus(status int) { ctx.response.WriteHeader(status) diff --git a/Request.go b/Request.go new file mode 100644 index 0000000..85a9a50 --- /dev/null +++ b/Request.go @@ -0,0 +1,42 @@ +package server + +import "net/http" + +// Request is an interface for HTTP requests. +type Request interface { + Header(key string) string + Host() string + Method() string + Path() string + Protocol() string +} + +// request represents the HTTP request used in the given context. +type request struct { + *http.Request +} + +// Header returns the header value for the given key. +func (req request) Header(key string) string { + return req.Request.Header.Get(key) +} + +// Method returns the request method. +func (req request) Method() string { + return req.Request.Method +} + +// Protocol returns the request protocol. +func (req request) Protocol() string { + return req.Request.Proto +} + +// Host returns the requested host. +func (req request) Host() string { + return req.Request.Host +} + +// Path returns the requested path. +func (req request) Path() string { + return req.Request.URL.Path +} diff --git a/Response.go b/Response.go new file mode 100644 index 0000000..db45bed --- /dev/null +++ b/Response.go @@ -0,0 +1,24 @@ +package server + +import "net/http" + +// Response is the interface for an HTTP response. +type Response interface { + Header(key string) string + SetHeader(key string, value string) +} + +// response represents the HTTP response used in the given context. +type response struct { + http.ResponseWriter +} + +// Header returns the header value for the given key. +func (res response) Header(key string) string { + return res.ResponseWriter.Header().Get(key) +} + +// SetHeader sets the header value for the given key. +func (res response) SetHeader(key string, value string) { + res.ResponseWriter.Header().Set(key, value) +} diff --git a/Server_test.go b/Server_test.go index 2823690..1bf355e 100644 --- a/Server_test.go +++ b/Server_test.go @@ -2,6 +2,7 @@ package server_test import ( "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -35,6 +36,28 @@ func TestRouter(t *testing.T) { return ctx.Reader(strings.NewReader("Hello")) }) + s.Get("/request/data", func(ctx server.Context) error { + request := ctx.Request() + method := request.Method() + protocol := request.Protocol() + host := request.Host() + path := request.Path() + return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path)) + }) + + s.Get("/request/header", func(ctx server.Context) error { + request := ctx.Request() + acceptEncoding := request.Header("Accept-Encoding") + return ctx.String(acceptEncoding) + }) + + s.Get("/response/header", func(ctx server.Context) error { + response := ctx.Response() + response.SetHeader("Content-Type", "text/plain") + contentType := response.Header("Content-Type") + return ctx.String(contentType) + }) + s.Get("/blog/:article", func(ctx server.Context) error { article := ctx.Get("article") return ctx.String(article) @@ -67,6 +90,9 @@ func TestRouter(t *testing.T) { {Method: "GET", URL: "/error", Status: http.StatusUnauthorized, Body: "Not logged in"}, {Method: "GET", URL: "/error2", Status: http.StatusUnauthorized, Body: "Not logged in\nMissing auth token"}, {Method: "GET", URL: "/not-found", Status: http.StatusNotFound, Body: http.StatusText(http.StatusNotFound)}, + {Method: "GET", URL: "/request/data", Status: http.StatusOK, Body: "GET HTTP/1.1 example.com /request/data"}, + {Method: "GET", URL: "/request/header", Status: http.StatusOK, Body: ""}, + {Method: "GET", URL: "/response/header", Status: http.StatusOK, Body: "text/plain"}, {Method: "GET", URL: "/reader", Status: http.StatusOK, Body: "Hello"}, {Method: "GET", URL: "/string", Status: http.StatusOK, Body: "Hello"}, {Method: "GET", URL: "/blog/testing-my-router", Status: http.StatusOK, Body: "testing-my-router"},