Added graceful shutdown

This commit is contained in:
Eduard Urbach 2024-03-12 22:31:45 +01:00
parent d47604606f
commit 46177ccd42
Signed by: akyoto
GPG Key ID: C874F672B1AF20C0
8 changed files with 116 additions and 31 deletions

View File

@ -27,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) {
s := server.New()
for _, route := range loadRoutes("testdata/github.txt") {
s.Router().Add(route.Method, route.Path, func(server.Context) error {
s.Router.Add(route.Method, route.Path, func(server.Context) error {
return nil
})
}

30
Configuration.go Normal file
View File

@ -0,0 +1,30 @@
package server
import "time"
// Configuration represents the server configuration.
type Configuration struct {
Timeout TimeoutConfiguration `json:"timeouts"`
}
// TimeoutConfiguration lets you configure the different timeout durations.
type TimeoutConfiguration struct {
Idle time.Duration `json:"idle"`
Read time.Duration `json:"read"`
ReadHeader time.Duration `json:"readHeader"`
Write time.Duration `json:"write"`
Shutdown time.Duration `json:"shutdown"`
}
// Reset resets all fields to the default configuration.
func defaultConfig() Configuration {
return Configuration{
Timeout: TimeoutConfiguration{
Idle: 3 * time.Minute,
Write: 2 * time.Minute,
Read: 5 * time.Second,
ReadHeader: 5 * time.Second,
Shutdown: 250 * time.Millisecond,
},
}
}

View File

@ -21,8 +21,8 @@ type Context interface {
String(string) error
}
// context represents a request & response context.
type context struct {
// ctx represents a request & response context.
type ctx struct {
request request
response response
paramNames [maxParams]string
@ -31,8 +31,8 @@ type context struct {
}
// newContext returns a new context from the pool.
func newContext(req *http.Request, res http.ResponseWriter) *context {
ctx := contextPool.Get().(*context)
func newContext(req *http.Request, res http.ResponseWriter) *ctx {
ctx := contextPool.Get().(*ctx)
ctx.request.Request = req
ctx.response.ResponseWriter = res
ctx.paramCount = 0
@ -40,13 +40,13 @@ func newContext(req *http.Request, res http.ResponseWriter) *context {
}
// Bytes responds with a raw byte slice.
func (ctx *context) Bytes(body []byte) error {
func (ctx *ctx) Bytes(body []byte) error {
_, err := ctx.response.Write(body)
return err
}
// Error is used for sending error messages to the client.
func (ctx *context) Error(status int, messages ...any) error {
func (ctx *ctx) Error(status int, messages ...any) error {
var combined []error
for _, msg := range messages {
@ -63,7 +63,7 @@ func (ctx *context) Error(status int, messages ...any) error {
}
// Get retrieves a parameter.
func (ctx *context) Get(param string) string {
func (ctx *ctx) Get(param string) string {
for i := 0; i < ctx.paramCount; i++ {
if ctx.paramNames[i] == param {
return ctx.paramValues[i]
@ -74,29 +74,29 @@ func (ctx *context) Get(param string) string {
}
// Reader sends the contents of the io.Reader without creating an in-memory copy.
func (ctx *context) Reader(reader io.Reader) error {
func (ctx *ctx) Reader(reader io.Reader) error {
_, err := io.Copy(ctx.response, reader)
return err
}
// Request returns the HTTP request.
func (ctx *context) Request() Request {
func (ctx *ctx) Request() Request {
return &ctx.request
}
// Response returns the HTTP response.
func (ctx *context) Response() Response {
func (ctx *ctx) Response() Response {
return &ctx.response
}
// String responds with the given string.
func (ctx *context) String(body string) error {
func (ctx *ctx) String(body string) error {
slice := unsafe.Slice(unsafe.StringData(body), len(body))
return ctx.Bytes(slice)
}
// addParameter adds a new parameter to the context.
func (ctx *context) addParameter(name string, value string) {
func (ctx *ctx) addParameter(name string, value string) {
ctx.paramNames[ctx.paramCount] = name
ctx.paramValues[ctx.paramCount] = value
ctx.paramCount++

View File

@ -33,7 +33,7 @@ s.Get("/images/*file", func(ctx server.Context) error {
return ctx.String(ctx.Get("file"))
})
http.ListenAndServe(":8080", s)
s.Run(":8080")
```
## Tests

View File

@ -1,48 +1,51 @@
package server
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"git.akyoto.dev/go/router"
)
// Server represents a single web service.
type Server struct {
router *router.Router[Handler]
Router *router.Router[Handler]
Config Configuration
}
// New creates a new server.
func New() *Server {
return &Server{
router: router.New[Handler](),
Router: router.New[Handler](),
Config: defaultConfig(),
}
}
// Get registers your function to be called when the given GET path has been requested.
func (server *Server) Get(path string, handler Handler) {
server.router.Add(http.MethodGet, path, handler)
server.Router.Add(http.MethodGet, path, handler)
}
// Post registers your function to be called when the given POST path has been requested.
func (server *Server) Post(path string, handler Handler) {
server.router.Add(http.MethodPost, path, handler)
server.Router.Add(http.MethodPost, path, handler)
}
// Delete registers your function to be called when the given DELETE path has been requested.
func (server *Server) Delete(path string, handler Handler) {
server.router.Add(http.MethodDelete, path, handler)
server.Router.Add(http.MethodDelete, path, handler)
}
// Put registers your function to be called when the given PUT path has been requested.
func (server *Server) Put(path string, handler Handler) {
server.router.Add(http.MethodPut, path, handler)
}
// Router returns the router used by the server.
func (server *Server) Router() *router.Router[Handler] {
return server.router
server.Router.Add(http.MethodPut, path, handler)
}
// ServeHTTP responds to the given request.
@ -50,7 +53,7 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ
ctx := newContext(request, response)
defer contextPool.Put(ctx)
handler := server.router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter)
handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter)
if handler == nil {
response.WriteHeader(http.StatusNotFound)
@ -65,3 +68,33 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ
log.Println(request.URL, err)
}
}
// Run starts the server on the given address.
func (server *Server) Run(address string) {
srv := &http.Server{
Addr: address,
Handler: server,
ReadTimeout: server.Config.Timeout.Read,
WriteTimeout: server.Config.Timeout.Write,
IdleTimeout: server.Config.Timeout.Idle,
ReadHeaderTimeout: server.Config.Timeout.ReadHeader,
}
listener, err := net.Listen("tcp", address)
if err != nil {
fmt.Println(err)
return
}
go srv.Serve(listener)
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
ctx, cancel := context.WithTimeout(context.Background(), server.Config.Timeout.Shutdown)
defer cancel()
srv.Shutdown(ctx)
}

View File

@ -4,9 +4,11 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"syscall"
"testing"
"git.akyoto.dev/go/assert"
@ -121,7 +123,7 @@ func TestRouter(t *testing.T) {
func TestPanic(t *testing.T) {
s := server.New()
s.Router().Add(http.MethodGet, "/panic", func(ctx server.Context) error {
s.Router.Add(http.MethodGet, "/panic", func(ctx server.Context) error {
panic("Something unbelievable happened")
})
@ -139,3 +141,25 @@ func TestPanic(t *testing.T) {
s.ServeHTTP(response, request)
})
}
func TestRun(t *testing.T) {
s := server.New()
go func() {
_, err := http.Get("http://127.0.0.1:8080/")
assert.Nil(t, err)
err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
assert.Nil(t, err)
}()
s.Run(":8080")
}
func TestUnavailablePort(t *testing.T) {
listener, err := net.Listen("tcp", ":8080")
assert.Nil(t, err)
defer listener.Close()
s := server.New()
s.Run(":8080")
}

View File

@ -1,8 +1,6 @@
package main
import (
"net/http"
"git.akyoto.dev/go/server"
)
@ -13,5 +11,5 @@ func main() {
return ctx.String("Hello")
})
http.ListenAndServe(":8080", s)
s.Run(":8080")
}

View File

@ -4,6 +4,6 @@ import "sync"
var contextPool = sync.Pool{
New: func() any {
return &context{}
return &ctx{}
},
}