package server_test import ( "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "strings" "syscall" "testing" "git.akyoto.dev/go/assert" "git.akyoto.dev/go/server" ) func TestRouter(t *testing.T) { s := server.New() s.Get("/", func(ctx server.Context) error { return ctx.Bytes([]byte("Hello")) }) s.Get("/string", func(ctx server.Context) error { return ctx.String("Hello") }) s.Get("/write", func(ctx server.Context) error { _, err := ctx.Response().Write([]byte("Hello")) return err }) s.Get("/writestring", func(ctx server.Context) error { _, err := io.WriteString(ctx.Response(), "Hello") return err }) s.Get("/error", func(ctx server.Context) error { return ctx.Status(http.StatusUnauthorized).Error("Not logged in") }) s.Get("/error2", func(ctx server.Context) error { return ctx.Status(http.StatusUnauthorized).Error("Not logged in", errors.New("Missing auth token")) }) s.Get("/reader", func(ctx server.Context) error { return ctx.Copy(strings.NewReader("Hello")) }) s.Get("/file", func(ctx server.Context) error { return ctx.File("testdata/file.txt") }) s.Get("/echo", func(ctx server.Context) error { return ctx.Copy(ctx.Request()) }) s.Get("/context", func(ctx server.Context) error { return ctx.Request().Context().Err() }) s.Get("/request/data", func(ctx server.Context) error { request := ctx.Request() method := request.Method() protocol := request.Protocol() host := request.Host() path := request.Path() return ctx.String(fmt.Sprintf("%s %s %s %s", method, protocol, host, path)) }) s.Get("/request/header", func(ctx server.Context) error { acceptEncoding := ctx.Request().Header("Accept-Encoding") return ctx.String(acceptEncoding) }) s.Get("/response/header", func(ctx server.Context) error { ctx.Response().SetHeader("Content-Type", "text/plain") contentType := ctx.Response().Header("Content-Type") return ctx.String(contentType) }) s.Get("/blog/:article", func(ctx server.Context) error { article := ctx.Get("article") return ctx.String(article) }) s.Get("/missing-parameter", func(ctx server.Context) error { missing := ctx.Get("missing") return ctx.String(missing) }) s.Get("/scheme", func(ctx server.Context) error { return ctx.String(ctx.Request().Scheme()) }) s.Post("/", func(ctx server.Context) error { return ctx.String("Post") }) s.Delete("/", func(ctx server.Context) error { return ctx.String("Delete") }) s.Put("/", func(ctx server.Context) error { return ctx.String("Put") }) tests := []struct { Method string URL string Body string Status int Response string }{ {Method: "GET", URL: "/", Body: "", Status: http.StatusOK, Response: "Hello"}, {Method: "GET", URL: "/context", Body: "", Status: http.StatusOK, Response: ""}, {Method: "GET", URL: "/echo", Body: "Echo", Status: http.StatusOK, Response: "Echo"}, {Method: "GET", URL: "/error", Body: "", Status: http.StatusUnauthorized, Response: "Not logged in"}, {Method: "GET", URL: "/error2", Body: "", Status: http.StatusUnauthorized, Response: "Not logged in\nMissing auth token"}, {Method: "GET", URL: "/file", Body: "", Status: http.StatusOK, Response: "Hello File"}, {Method: "GET", URL: "/not-found", Body: "", Status: http.StatusNotFound, Response: http.StatusText(http.StatusNotFound)}, {Method: "GET", URL: "/request/data", Body: "", Status: http.StatusOK, Response: "GET HTTP/1.1 example.com /request/data"}, {Method: "GET", URL: "/request/header", Body: "", Status: http.StatusOK, Response: ""}, {Method: "GET", URL: "/response/header", Body: "", Status: http.StatusOK, Response: "text/plain"}, {Method: "GET", URL: "/reader", Body: "", Status: http.StatusOK, Response: "Hello"}, {Method: "GET", URL: "/string", Body: "", Status: http.StatusOK, Response: "Hello"}, {Method: "GET", URL: "/scheme", Body: "", Status: http.StatusOK, Response: "http"}, {Method: "GET", URL: "/write", Body: "", Status: http.StatusOK, Response: "Hello"}, {Method: "GET", URL: "/writestring", Body: "", Status: http.StatusOK, Response: "Hello"}, {Method: "GET", URL: "/blog/testing-my-router", Body: "", Status: http.StatusOK, Response: "testing-my-router"}, {Method: "GET", URL: "/missing-parameter", Body: "", Status: http.StatusOK, Response: ""}, {Method: "POST", URL: "/", Body: "", Status: http.StatusOK, Response: "Post"}, {Method: "DELETE", URL: "/", Body: "", Status: http.StatusOK, Response: "Delete"}, {Method: "PUT", URL: "/", Body: "", Status: http.StatusOK, Response: "Put"}, } for _, test := range tests { t.Run("example.com"+test.URL, func(t *testing.T) { request := httptest.NewRequest(test.Method, "http://example.com"+test.URL, strings.NewReader(test.Body)) response := httptest.NewRecorder() s.ServeHTTP(response, request) result := response.Result() assert.Equal(t, result.StatusCode, test.Status) body, err := io.ReadAll(result.Body) assert.Nil(t, err) assert.Equal(t, string(body), test.Response) }) } } func TestMiddleware(t *testing.T) { s := server.New() s.Use(func(ctx server.Context) error { ctx.Response().SetHeader("Middleware", "true") return ctx.Next() }) request := httptest.NewRequest(http.MethodGet, "/", nil) response := httptest.NewRecorder() s.ServeHTTP(response, request) assert.Equal(t, response.Header().Get("Middleware"), "true") } func TestPanic(t *testing.T) { s := server.New() s.Router().Add(http.MethodGet, "/panic", func(ctx server.Context) error { panic("Something unbelievable happened") }) t.Run("example.com/panic", func(t *testing.T) { defer func() { r := recover() if r == nil { t.Error("Didn't panic") } }() request := httptest.NewRequest(http.MethodGet, "/panic", nil) response := httptest.NewRecorder() s.ServeHTTP(response, request) }) } func TestRun(t *testing.T) { s := server.New() go func() { _, err := http.Get("http://127.0.0.1:8080/") assert.Nil(t, err) err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) assert.Nil(t, err) }() s.Run(":8080") } func TestUnavailablePort(t *testing.T) { listener, err := net.Listen("tcp", ":8080") assert.Nil(t, err) defer listener.Close() s := server.New() s.Run(":8080") }