diff --git a/Context_test.go b/Context_test.go new file mode 100644 index 0000000..c4d5b31 --- /dev/null +++ b/Context_test.go @@ -0,0 +1,57 @@ +package web_test + +import ( + "errors" + "testing" + + "git.akyoto.dev/go/assert" + "git.akyoto.dev/go/web" +) + +func TestBytes(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Bytes([]byte("Hello")) + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "Hello") +} + +func TestString(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.String("Hello") + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "Hello") +} + +func TestError(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Status(401).Error("Not logged in") + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 401) + assert.Equal(t, string(response.Body()), "") +} + +func TestErrorMultiple(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + return ctx.Status(401).Error("Not logged in", errors.New("Missing auth token")) + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 401) + assert.Equal(t, string(response.Body()), "") +} diff --git a/Request.go b/Request.go index 7c1d996..54dd6f8 100644 --- a/Request.go +++ b/Request.go @@ -8,6 +8,7 @@ type Request interface { Method() string Path() string Scheme() string + Param(string) string } // request represents the HTTP request used in the given context. @@ -16,6 +17,7 @@ type request struct { host string method string path string + query string params []router.Parameter } @@ -29,6 +31,19 @@ func (req *request) Method() string { return req.method } +// Param retrieves a parameter. +func (req *request) Param(name string) string { + for i := range len(req.params) { + p := req.params[i] + + if p.Key == name { + return p.Value + } + } + + return "" +} + // Path returns the requested path. func (req *request) Path() string { return req.path diff --git a/Request_test.go b/Request_test.go new file mode 100644 index 0000000..cbeda40 --- /dev/null +++ b/Request_test.go @@ -0,0 +1,39 @@ +package web_test + +import ( + "fmt" + "testing" + + "git.akyoto.dev/go/assert" + "git.akyoto.dev/go/web" +) + +func TestRequest(t *testing.T) { + s := web.NewServer() + + s.Get("/request", func(ctx web.Context) error { + req := ctx.Request() + method := req.Method() + scheme := req.Scheme() + host := req.Host() + path := req.Path() + return ctx.String(fmt.Sprintf("%s %s %s %s", method, scheme, host, path)) + }) + + response := s.Request("GET", "http://example.com/request?x=1", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "GET http example.com /request") +} + +func TestRequestParam(t *testing.T) { + s := web.NewServer() + + s.Get("/blog/:article", func(ctx web.Context) error { + article := ctx.Request().Param("article") + return ctx.String(article) + }) + + response := s.Request("GET", "/blog/my-article", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "my-article") +} diff --git a/Response.go b/Response.go index cc66cf8..dc46c03 100644 --- a/Response.go +++ b/Response.go @@ -12,10 +12,10 @@ type Response interface { io.Writer io.StringWriter Body() []byte - Header(key string) string + Header(string) string SetHeader(key string, value string) SetBody([]byte) - SetStatus(status int) + SetStatus(int) Status() int } @@ -44,9 +44,9 @@ func (res *response) Header(key string) string { // SetHeader sets the header value for the given key. func (res *response) SetHeader(key string, value string) { - for _, header := range res.headers { + for i, header := range res.headers { if header.Key == key { - header.Value = value + res.headers[i].Value = value return } } diff --git a/Response_test.go b/Response_test.go new file mode 100644 index 0000000..256bc32 --- /dev/null +++ b/Response_test.go @@ -0,0 +1,66 @@ +package web_test + +import ( + "io" + "testing" + + "git.akyoto.dev/go/assert" + "git.akyoto.dev/go/web" +) + +func TestWrite(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + _, err := ctx.Response().Write([]byte("Hello")) + return err + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "Hello") +} + +func TestWriteString(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + _, err := io.WriteString(ctx.Response(), "Hello") + return err + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, string(response.Body()), "Hello") +} + +func TestResponseHeader(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + contentType := ctx.Response().Header("Content-Type") + return ctx.String(contentType) + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, response.Header("Content-Type"), "text/plain") + assert.Equal(t, response.Header("Non existent header"), "") + assert.Equal(t, string(response.Body()), "text/plain") +} + +func TestResponseHeaderOverwrite(t *testing.T) { + s := web.NewServer() + + s.Get("/", func(ctx web.Context) error { + ctx.Response().SetHeader("Content-Type", "text/plain") + ctx.Response().SetHeader("Content-Type", "text/html") + return nil + }) + + response := s.Request("GET", "/", nil) + assert.Equal(t, response.Status(), 200) + assert.Equal(t, response.Header("Content-Type"), "text/html") + assert.Equal(t, string(response.Body()), "") +} diff --git a/Server.go b/Server.go index 459356f..940a2f1 100644 --- a/Server.go +++ b/Server.go @@ -160,29 +160,8 @@ func (s *server) handleConnection(conn net.Conn) { // handleRequest handles the given request. func (s *server) handleRequest(ctx *context, method string, url string, writer io.Writer) { - schemePos := strings.Index(url, "://") - schemeEnd := 0 - - if schemePos != -1 { - schemeEnd = schemePos + len("://") - } else { - schemePos = 0 - } - - pathPos := strings.IndexByte(url[schemeEnd:], '/') - - if pathPos == -1 { - return - } - - scheme := url[:schemePos] - host := url[schemeEnd : schemeEnd+pathPos] - path := url[schemeEnd+pathPos:] - ctx.method = method - ctx.scheme = scheme - ctx.host = host - ctx.path = path + ctx.scheme, ctx.host, ctx.path, ctx.query = parseURL(url) err := s.handlers[0](ctx) @@ -190,11 +169,7 @@ func (s *server) handleRequest(ctx *context, method string, url string, writer i s.errorHandler(ctx, err) } - _, err = fmt.Fprintf(writer, "HTTP/1.1 %d %s\r\nContent-Length: %d\r\n%s\r\n%s", ctx.status, "OK", len(ctx.body), ctx.response.headerText(), ctx.body) - - if err != nil { - s.errorHandler(ctx, err) - } + fmt.Fprintf(writer, "HTTP/1.1 %d %s\r\nContent-Length: %d\r\n%s\r\n%s", ctx.status, "OK", len(ctx.body), ctx.response.headerText(), ctx.body) } // newContext allocates a new context with the default state. @@ -210,3 +185,31 @@ func (s *server) newContext() *context { }, } } + +// parseURL parses a URL and returns the scheme, host, path and query. +func parseURL(url string) (scheme string, host string, path string, query string) { + schemePos := strings.Index(url, "://") + + if schemePos != -1 { + scheme = url[:schemePos] + url = url[schemePos+len("://"):] + } + + pathPos := strings.IndexByte(url, '/') + + if pathPos != -1 { + host = url[:pathPos] + url = url[pathPos:] + } + + queryPos := strings.IndexByte(url, '?') + + if queryPos != -1 { + path = url[:queryPos] + query = url[queryPos+1:] + return + } + + path = url + return +} diff --git a/Server_test.go b/Server_test.go deleted file mode 100644 index a09df47..0000000 --- a/Server_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package web_test - -import ( - "errors" - "fmt" - "io" - "strings" - "testing" - - "git.akyoto.dev/go/assert" - "git.akyoto.dev/go/web" -) - -func TestContext(t *testing.T) { - s := web.NewServer() - - s.Get("/bytes", func(ctx web.Context) error { - return ctx.Bytes([]byte("Hello")) - }) - - s.Get("/string", func(ctx web.Context) error { - return ctx.String("Hello") - }) - - s.Get("/write", func(ctx web.Context) error { - _, err := ctx.Response().Write([]byte("Hello")) - return err - }) - - s.Get("/writestring", func(ctx web.Context) error { - _, err := io.WriteString(ctx.Response(), "Hello") - return err - }) - - s.Get("/error", func(ctx web.Context) error { - return ctx.Status(401).Error("Not logged in") - }) - - s.Get("/error2", func(ctx web.Context) error { - return ctx.Status(401).Error("Not logged in", errors.New("Missing auth token")) - }) - - s.Get("/request/data", func(ctx web.Context) error { - req := ctx.Request() - method := req.Method() - scheme := req.Scheme() - host := req.Host() - path := req.Path() - return ctx.String(fmt.Sprintf("%s %s %s %s", method, scheme, host, path)) - }) - - s.Router().Add("POST", "/", func(ctx web.Context) error { - return ctx.String("Post") - }) - - s.Router().Add("DELETE", "/", func(ctx web.Context) error { - return ctx.String("Delete") - }) - - s.Router().Add("PUT", "/", func(ctx web.Context) error { - return ctx.String("Put") - }) - - tests := []struct { - Method string - Path string - Body string - Status int - Response string - }{ - {Method: "GET", Path: "/bytes", Body: "", Status: 200, Response: "Hello"}, - // {Method: "GET", Path: "/context", Body: "", Status: 200, Response: ""}, - // {Method: "GET", Path: "/echo", Body: "Echo", Status: 200, Response: "Echo"}, - {Method: "GET", Path: "/error", Body: "", Status: 401, Response: ""}, - {Method: "GET", Path: "/error2", Body: "", Status: 401, Response: ""}, - // {Method: "GET", Path: "/file", Body: "", Status: 200, Response: "Hello File"}, - // {Method: "GET", Path: "/flush", Body: "", Status: 200, Response: "Hello 1\nHello 2\n"}, - // {Method: "GET", Path: "/not-found", Body: "", Status: 404, Response: ""}, - {Method: "GET", Path: "/request/data", Body: "", Status: 200, Response: "GET http example.com /request/data"}, - // {Method: "GET", Path: "/request/header", Body: "", Status: 200, Response: ""}, - // {Method: "GET", Path: "/response/header", Body: "", Status: 200, Response: "text/plain"}, - // {Method: "GET", Path: "/reader", Body: "", Status: 200, Response: "Hello"}, - // {Method: "GET", Path: "/redirect", Body: "", Status: 307, Response: ""}, - {Method: "GET", Path: "/string", Body: "", Status: 200, Response: "Hello"}, - {Method: "GET", Path: "/write", Body: "", Status: 200, Response: "Hello"}, - {Method: "GET", Path: "/writestring", Body: "", Status: 200, Response: "Hello"}, - // {Method: "GET", Path: "/blog/testing-my-router", Body: "", Status: 200, Response: "testing-my-router"}, - // {Method: "GET", Path: "/missing-parameter", Body: "", Status: 200, Response: ""}, - {Method: "POST", Path: "/", Body: "", Status: 200, Response: "Post"}, - {Method: "DELETE", Path: "/", Body: "", Status: 200, Response: "Delete"}, - {Method: "PUT", Path: "/", Body: "", Status: 200, Response: "Put"}, - } - - for _, test := range tests { - t.Run(test.Path, func(t *testing.T) { - response := s.Request(test.Method, "http://example.com"+test.Path, strings.NewReader(test.Body)) - assert.Equal(t, response.Status(), test.Status) - assert.Equal(t, string(response.Body()), test.Response) - }) - } -} - -func TestString(t *testing.T) { - s := web.NewServer() - - s.Get("/", func(ctx web.Context) error { - return ctx.String("Hello") - }) - - response := s.Request("GET", "/", nil) - assert.Equal(t, response.Status(), 200) - assert.DeepEqual(t, response.Body(), []byte("Hello")) -}