Added graceful shutdown

This commit is contained in:
Eduard Urbach 2024-03-12 22:31:45 +01:00
parent d47604606f
commit 46177ccd42
Signed by: akyoto
GPG Key ID: C874F672B1AF20C0
8 changed files with 116 additions and 31 deletions

View File

@ -27,7 +27,7 @@ func BenchmarkGitHub(b *testing.B) {
s := server.New() s := server.New()
for _, route := range loadRoutes("testdata/github.txt") { for _, route := range loadRoutes("testdata/github.txt") {
s.Router().Add(route.Method, route.Path, func(server.Context) error { s.Router.Add(route.Method, route.Path, func(server.Context) error {
return nil return nil
}) })
} }

30
Configuration.go Normal file
View File

@ -0,0 +1,30 @@
package server
import "time"
// Configuration represents the server configuration.
type Configuration struct {
Timeout TimeoutConfiguration `json:"timeouts"`
}
// TimeoutConfiguration lets you configure the different timeout durations.
type TimeoutConfiguration struct {
Idle time.Duration `json:"idle"`
Read time.Duration `json:"read"`
ReadHeader time.Duration `json:"readHeader"`
Write time.Duration `json:"write"`
Shutdown time.Duration `json:"shutdown"`
}
// Reset resets all fields to the default configuration.
func defaultConfig() Configuration {
return Configuration{
Timeout: TimeoutConfiguration{
Idle: 3 * time.Minute,
Write: 2 * time.Minute,
Read: 5 * time.Second,
ReadHeader: 5 * time.Second,
Shutdown: 250 * time.Millisecond,
},
}
}

View File

@ -21,8 +21,8 @@ type Context interface {
String(string) error String(string) error
} }
// context represents a request & response context. // ctx represents a request & response context.
type context struct { type ctx struct {
request request request request
response response response response
paramNames [maxParams]string paramNames [maxParams]string
@ -31,8 +31,8 @@ type context struct {
} }
// newContext returns a new context from the pool. // newContext returns a new context from the pool.
func newContext(req *http.Request, res http.ResponseWriter) *context { func newContext(req *http.Request, res http.ResponseWriter) *ctx {
ctx := contextPool.Get().(*context) ctx := contextPool.Get().(*ctx)
ctx.request.Request = req ctx.request.Request = req
ctx.response.ResponseWriter = res ctx.response.ResponseWriter = res
ctx.paramCount = 0 ctx.paramCount = 0
@ -40,13 +40,13 @@ func newContext(req *http.Request, res http.ResponseWriter) *context {
} }
// Bytes responds with a raw byte slice. // Bytes responds with a raw byte slice.
func (ctx *context) Bytes(body []byte) error { func (ctx *ctx) Bytes(body []byte) error {
_, err := ctx.response.Write(body) _, err := ctx.response.Write(body)
return err return err
} }
// Error is used for sending error messages to the client. // Error is used for sending error messages to the client.
func (ctx *context) Error(status int, messages ...any) error { func (ctx *ctx) Error(status int, messages ...any) error {
var combined []error var combined []error
for _, msg := range messages { for _, msg := range messages {
@ -63,7 +63,7 @@ func (ctx *context) Error(status int, messages ...any) error {
} }
// Get retrieves a parameter. // Get retrieves a parameter.
func (ctx *context) Get(param string) string { func (ctx *ctx) Get(param string) string {
for i := 0; i < ctx.paramCount; i++ { for i := 0; i < ctx.paramCount; i++ {
if ctx.paramNames[i] == param { if ctx.paramNames[i] == param {
return ctx.paramValues[i] return ctx.paramValues[i]
@ -74,29 +74,29 @@ func (ctx *context) Get(param string) string {
} }
// Reader sends the contents of the io.Reader without creating an in-memory copy. // Reader sends the contents of the io.Reader without creating an in-memory copy.
func (ctx *context) Reader(reader io.Reader) error { func (ctx *ctx) Reader(reader io.Reader) error {
_, err := io.Copy(ctx.response, reader) _, err := io.Copy(ctx.response, reader)
return err return err
} }
// Request returns the HTTP request. // Request returns the HTTP request.
func (ctx *context) Request() Request { func (ctx *ctx) Request() Request {
return &ctx.request return &ctx.request
} }
// Response returns the HTTP response. // Response returns the HTTP response.
func (ctx *context) Response() Response { func (ctx *ctx) Response() Response {
return &ctx.response return &ctx.response
} }
// String responds with the given string. // String responds with the given string.
func (ctx *context) String(body string) error { func (ctx *ctx) String(body string) error {
slice := unsafe.Slice(unsafe.StringData(body), len(body)) slice := unsafe.Slice(unsafe.StringData(body), len(body))
return ctx.Bytes(slice) return ctx.Bytes(slice)
} }
// addParameter adds a new parameter to the context. // addParameter adds a new parameter to the context.
func (ctx *context) addParameter(name string, value string) { func (ctx *ctx) addParameter(name string, value string) {
ctx.paramNames[ctx.paramCount] = name ctx.paramNames[ctx.paramCount] = name
ctx.paramValues[ctx.paramCount] = value ctx.paramValues[ctx.paramCount] = value
ctx.paramCount++ ctx.paramCount++

View File

@ -33,7 +33,7 @@ s.Get("/images/*file", func(ctx server.Context) error {
return ctx.String(ctx.Get("file")) return ctx.String(ctx.Get("file"))
}) })
http.ListenAndServe(":8080", s) s.Run(":8080")
``` ```
## Tests ## Tests

View File

@ -1,48 +1,51 @@
package server package server
import ( import (
"context"
"fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"os"
"os/signal"
"syscall"
"git.akyoto.dev/go/router" "git.akyoto.dev/go/router"
) )
// Server represents a single web service. // Server represents a single web service.
type Server struct { type Server struct {
router *router.Router[Handler] Router *router.Router[Handler]
Config Configuration
} }
// New creates a new server. // New creates a new server.
func New() *Server { func New() *Server {
return &Server{ return &Server{
router: router.New[Handler](), Router: router.New[Handler](),
Config: defaultConfig(),
} }
} }
// Get registers your function to be called when the given GET path has been requested. // Get registers your function to be called when the given GET path has been requested.
func (server *Server) Get(path string, handler Handler) { func (server *Server) Get(path string, handler Handler) {
server.router.Add(http.MethodGet, path, handler) server.Router.Add(http.MethodGet, path, handler)
} }
// Post registers your function to be called when the given POST path has been requested. // Post registers your function to be called when the given POST path has been requested.
func (server *Server) Post(path string, handler Handler) { func (server *Server) Post(path string, handler Handler) {
server.router.Add(http.MethodPost, path, handler) server.Router.Add(http.MethodPost, path, handler)
} }
// Delete registers your function to be called when the given DELETE path has been requested. // Delete registers your function to be called when the given DELETE path has been requested.
func (server *Server) Delete(path string, handler Handler) { func (server *Server) Delete(path string, handler Handler) {
server.router.Add(http.MethodDelete, path, handler) server.Router.Add(http.MethodDelete, path, handler)
} }
// Put registers your function to be called when the given PUT path has been requested. // Put registers your function to be called when the given PUT path has been requested.
func (server *Server) Put(path string, handler Handler) { func (server *Server) Put(path string, handler Handler) {
server.router.Add(http.MethodPut, path, handler) server.Router.Add(http.MethodPut, path, handler)
}
// Router returns the router used by the server.
func (server *Server) Router() *router.Router[Handler] {
return server.router
} }
// ServeHTTP responds to the given request. // ServeHTTP responds to the given request.
@ -50,7 +53,7 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ
ctx := newContext(request, response) ctx := newContext(request, response)
defer contextPool.Put(ctx) defer contextPool.Put(ctx)
handler := server.router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter) handler := server.Router.LookupNoAlloc(request.Method, request.URL.Path, ctx.addParameter)
if handler == nil { if handler == nil {
response.WriteHeader(http.StatusNotFound) response.WriteHeader(http.StatusNotFound)
@ -65,3 +68,33 @@ func (server *Server) ServeHTTP(response http.ResponseWriter, request *http.Requ
log.Println(request.URL, err) log.Println(request.URL, err)
} }
} }
// Run starts the server on the given address.
func (server *Server) Run(address string) {
srv := &http.Server{
Addr: address,
Handler: server,
ReadTimeout: server.Config.Timeout.Read,
WriteTimeout: server.Config.Timeout.Write,
IdleTimeout: server.Config.Timeout.Idle,
ReadHeaderTimeout: server.Config.Timeout.ReadHeader,
}
listener, err := net.Listen("tcp", address)
if err != nil {
fmt.Println(err)
return
}
go srv.Serve(listener)
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
ctx, cancel := context.WithTimeout(context.Background(), server.Config.Timeout.Shutdown)
defer cancel()
srv.Shutdown(ctx)
}

View File

@ -4,9 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"syscall"
"testing" "testing"
"git.akyoto.dev/go/assert" "git.akyoto.dev/go/assert"
@ -121,7 +123,7 @@ func TestRouter(t *testing.T) {
func TestPanic(t *testing.T) { func TestPanic(t *testing.T) {
s := server.New() s := server.New()
s.Router().Add(http.MethodGet, "/panic", func(ctx server.Context) error { s.Router.Add(http.MethodGet, "/panic", func(ctx server.Context) error {
panic("Something unbelievable happened") panic("Something unbelievable happened")
}) })
@ -139,3 +141,25 @@ func TestPanic(t *testing.T) {
s.ServeHTTP(response, request) s.ServeHTTP(response, request)
}) })
} }
func TestRun(t *testing.T) {
s := server.New()
go func() {
_, err := http.Get("http://127.0.0.1:8080/")
assert.Nil(t, err)
err = syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
assert.Nil(t, err)
}()
s.Run(":8080")
}
func TestUnavailablePort(t *testing.T) {
listener, err := net.Listen("tcp", ":8080")
assert.Nil(t, err)
defer listener.Close()
s := server.New()
s.Run(":8080")
}

View File

@ -1,8 +1,6 @@
package main package main
import ( import (
"net/http"
"git.akyoto.dev/go/server" "git.akyoto.dev/go/server"
) )
@ -13,5 +11,5 @@ func main() {
return ctx.String("Hello") return ctx.String("Hello")
}) })
http.ListenAndServe(":8080", s) s.Run(":8080")
} }

View File

@ -4,6 +4,6 @@ import "sync"
var contextPool = sync.Pool{ var contextPool = sync.Pool{
New: func() any { New: func() any {
return &context{} return &ctx{}
}, },
} }