diff --git a/Context.go b/Context.go index 8c07b97..bd6cecf 100644 --- a/Context.go +++ b/Context.go @@ -1,7 +1,9 @@ package server import ( + "io" "net/http" + "unsafe" ) // maxParams defines the maximum number of parameters per route. @@ -11,6 +13,8 @@ const maxParams = 16 type Context interface { Bytes([]byte) error Error(int, error) error + Reader(io.Reader) error + String(string) error } // context represents a request & response context. @@ -43,6 +47,18 @@ func (ctx *context) Error(status int, err error) error { return err } +// Reader sends the contents of the io.Reader without creating an in-memory copy. +func (ctx *context) Reader(reader io.Reader) error { + _, err := io.Copy(ctx.response, reader) + return err +} + +// String responds with the given string. +func (ctx *context) String(body string) error { + slice := unsafe.Slice(unsafe.StringData(body), len(body)) + return ctx.Bytes(slice) +} + // addParameter adds a new parameter to the context. func (ctx *context) addParameter(name string, value string) { ctx.paramNames[ctx.paramCount] = name diff --git a/Server_test.go b/Server_test.go index aa65971..880d87a 100644 --- a/Server_test.go +++ b/Server_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "git.akyoto.dev/go/assert" @@ -26,6 +27,14 @@ func TestRouter(t *testing.T) { return ctx.Error(http.StatusUnauthorized, errors.New("Not logged in")) }) + s.Get("/reader", func(ctx server.Context) error { + return ctx.Reader(strings.NewReader("Hello")) + }) + + s.Get("/string", func(ctx server.Context) error { + return ctx.String("Hello") + }) + tests := []struct { URL string Status int @@ -35,6 +44,8 @@ func TestRouter(t *testing.T) { {URL: "/blog/post", Status: http.StatusOK, Body: "Hello"}, {URL: "/error", Status: http.StatusUnauthorized, Body: "Not logged in"}, {URL: "/not-found", Status: http.StatusNotFound, Body: http.StatusText(http.StatusNotFound)}, + {URL: "/reader", Status: http.StatusOK, Body: "Hello"}, + {URL: "/string", Status: http.StatusOK, Body: "Hello"}, } for _, test := range tests {