From 46177ccd42ba54c41bc91101b912409d7b1f80b8 Mon Sep 17 00:00:00 2001 From: Eduard Urbach Date: Tue, 12 Mar 2024 22:31:45 +0100 Subject: [PATCH] Added graceful shutdown --- Benchmarks_test.go | 2 +- Configuration.go | 30 ++++++++++++++++++++++ Context.go | 24 +++++++++--------- README.md | 2 +- Server.go | 57 +++++++++++++++++++++++++++++++++--------- Server_test.go | 26 ++++++++++++++++++- examples/hello/main.go | 4 +-- pool.go | 2 +- 8 files changed, 116 insertions(+), 31 deletions(-) create mode 100644 Configuration.go diff --git a/Benchmarks_test.go b/Benchmarks_test.go index 3ce6969..efcb4ff 100644 --- a/Benchmarks_test.go +++ b/Benchmarks_test.go @@ -27,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) { s := server.New() for _, route := range loadRoutes("testdata/github.txt") { - s.Router().Add(route.Method, route.Path, func(server.Context) error { + s.Router.Add(route.Method, route.Path, func(server.Context) error { return nil }) } diff --git a/Configuration.go b/Configuration.go new file mode 100644 index 0000000..0d9bd2a --- /dev/null +++ b/Configuration.go @@ -0,0 +1,30 @@ +package server + +import "time" + +// Configuration represents the server configuration. +type Configuration struct { + Timeout TimeoutConfiguration `json:"timeouts"` +} + +// TimeoutConfiguration lets you configure the different timeout durations. +type TimeoutConfiguration struct { + Idle time.Duration `json:"idle"` + Read time.Duration `json:"read"` + ReadHeader time.Duration `json:"readHeader"` + Write time.Duration `json:"write"` + Shutdown time.Duration `json:"shutdown"` +} + +// Reset resets all fields to the default configuration. +func defaultConfig() Configuration { + return Configuration{ + Timeout: TimeoutConfiguration{ + Idle: 3 * time.Minute, + Write: 2 * time.Minute, + Read: 5 * time.Second, + ReadHeader: 5 * time.Second, + Shutdown: 250 * time.Millisecond, + }, + } +} diff --git a/Context.go b/Context.go index 29d151f..89ded47 100644 --- a/Context.go +++ b/Context.go @@ -21,8 +21,8 @@ type Context interface { String(string) error } -// context represents a request & response context. -type context struct { +// ctx represents a request & response context. +type ctx struct { request request response response paramNames [maxParams]string @@ -31,8 +31,8 @@ type context struct { } // newContext returns a new context from the pool. -func newContext(req *http.Request, res http.ResponseWriter) *context { - ctx := contextPool.Get().(*context) +func newContext(req *http.Request, res http.ResponseWriter) *ctx { + ctx := contextPool.Get().(*ctx) ctx.request.Request = req ctx.response.ResponseWriter = res ctx.paramCount = 0 @@ -40,13 +40,13 @@ func newContext(req *http.Request, res http.ResponseWriter) *context { } // Bytes responds with a raw byte slice. -func (ctx *context) Bytes(body []byte) error { +func (ctx *ctx) Bytes(body []byte) error { _, err := ctx.response.Write(body) return err } // Error is used for sending error messages to the client. -func (ctx *context) Error(status int, messages ...any) error { +func (ctx *ctx) Error(status int, messages ...any) error { var combined []error for _, msg := range messages { @@ -63,7 +63,7 @@ func (ctx *context) Error(status int, messages ...any) error { } // Get retrieves a parameter. -func (ctx *context) Get(param string) string { +func (ctx *ctx) Get(param string) string { for i := 0; i < ctx.paramCount; i++ { if ctx.paramNames[i] == param { return ctx.paramValues[i] @@ -74,29 +74,29 @@ func (ctx *context) Get(param string) string { } // Reader sends the contents of the io.Reader without creating an in-memory copy. -func (ctx *context) Reader(reader io.Reader) error { +func (ctx *ctx) Reader(reader io.Reader) error { _, err := io.Copy(ctx.response, reader) return err } // Request returns the HTTP request. -func (ctx *context) Request() Request { +func (ctx *ctx) Request() Request { return &ctx.request } // Response returns the HTTP response. -func (ctx *context) Response() Response { +func (ctx *ctx) Response() Response { return &ctx.response } // String responds with the given string. -func (ctx *context) String(body string) error { +func (ctx *ctx) String(body string) error { slice := unsafe.Slice(unsafe.StringData(body), len(body)) return ctx.Bytes(slice) } // addParameter adds a new parameter to the context. -func (ctx *context) addParameter(name string, value string) { +func (ctx *ctx) addParameter(name string, value string) { ctx.paramNames[ctx.paramCount] = name ctx.paramValues[ctx.paramCount] = value ctx.paramCount++ diff --git a/README.md b/README.md index ac4d2bf..1b44d98 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ s.Get("/images/*file", func(ctx server.Context) error { return ctx.String(ctx.Get("file")) }) -http.ListenAndServe(":8080", s) +s.Run(":8080") ``` ## Tests diff --git a/Server.go b/Server.go index 3758672..f23d6b6 100644 --- a/Server.go +++ b/Server.go @@ -1,48 +1,51 @@ package server import ( + "context" + "fmt" "io" "log" + "net" "net/http" + "os" + "os/signal" + "syscall" "git.akyoto.dev/go/router" ) // Server represents a single web service. type Server struct { - router *router.Router[Handler] + Router *router.Router[Handler] + Config Configuration } // New creates a new server. func New() *Server { return &Server{ - router: router.New[Handler](), + Router: router.New[Handler](), + Config: defaultConfig(), } } // Get registers your function to be called when the given GET path has been requested. func (server *Server) Get(path string, handler Handler) { - server.router.Add(http.MethodGet, path, handler) + server.Router.Add(http.MethodGet, path, handler) } // Post registers your function to be called when the given POST path has been requested. func (server *Server) Post(path string, handler Handler) { - server.router.Add(http.MethodPost, path, handler) + server.Router.Add(http.MethodPost, path, handler) } // Delete registers your function to be called when the given DELETE path has been requested. func (server *Server) Delete(path string, handler Handler) { - server.router.Add(http.MethodDelete, path, handler) + server.Router.Add(http.MethodDelete, path, handler) } // Put registers your function to be called when the given PUT path has been requested. func (server *Server) Put(path string, handler Handler) { - server.router.Add(http.MethodPut, path, handler) -} - -// Router returns the router used by the server. -func (server *Server) Router() *router.Router[Handler] { - return server.router + server.Router.Add(http.MethodPut, path, handler) } // ServeHTTP responds to the given request. @@ -50,7 +53,7 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ ctx := newContext(request, response) defer contextPool.Put(ctx) - handler := server.router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter) + handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter) if handler == nil { response.WriteHeader(http.StatusNotFound) @@ -65,3 +68,33 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ log.Println(request.URL, err) } } + +// Run starts the server on the given address. +func (server *Server) Run(address string) { + srv := &http.Server{ + Addr: address, + Handler: server, + ReadTimeout: server.Config.Timeout.Read, + WriteTimeout: server.Config.Timeout.Write, + IdleTimeout: server.Config.Timeout.Idle, + ReadHeaderTimeout: server.Config.Timeout.ReadHeader, + } + + listener, err := net.Listen("tcp", address) + + if err != nil { + fmt.Println(err) + return + } + + go srv.Serve(listener) + + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + <-stop + + ctx, cancel := context.WithTimeout(context.Background(), server.Config.Timeout.Shutdown) + defer cancel() + + srv.Shutdown(ctx) +} diff --git a/Server_test.go b/Server_test.go index 557291a..4977f1f 100644 --- a/Server_test.go +++ b/Server_test.go @@ -4,9 +4,11 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "strings" + "syscall" "testing" "git.akyoto.dev/go/assert" @@ -121,7 +123,7 @@ func TestRouter(t *testing.T) { func TestPanic(t *testing.T) { s := server.New() - s.Router().Add(http.MethodGet, "/panic", func(ctx server.Context) error { + s.Router.Add(http.MethodGet, "/panic", func(ctx server.Context) error { panic("Something unbelievable happened") }) @@ -139,3 +141,25 @@ func TestPanic(t *testing.T) { s.ServeHTTP(response, request) }) } + +func TestRun(t *testing.T) { + s := server.New() + + go func() { + _, err := http.Get("http://127.0.0.1:8080/") + assert.Nil(t, err) + err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM) + assert.Nil(t, err) + }() + + s.Run(":8080") +} + +func TestUnavailablePort(t *testing.T) { + listener, err := net.Listen("tcp", ":8080") + assert.Nil(t, err) + defer listener.Close() + + s := server.New() + s.Run(":8080") +} diff --git a/examples/hello/main.go b/examples/hello/main.go index 6ac7313..215c77d 100644 --- a/examples/hello/main.go +++ b/examples/hello/main.go @@ -1,8 +1,6 @@ package main import ( - "net/http" - "git.akyoto.dev/go/server" ) @@ -13,5 +11,5 @@ func main() { return ctx.String("Hello") }) - http.ListenAndServe(":8080", s) + s.Run(":8080") } diff --git a/pool.go b/pool.go index 5652235..dd1e032 100644 --- a/pool.go +++ b/pool.go @@ -4,6 +4,6 @@ import "sync" var contextPool = sync.Pool{ New: func() any { - return &context{} + return &ctx{} }, }