142 lines
3.4 KiB
Go

package server
import (
"context"
"log"
"net"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"git.akyoto.dev/go/router"
)
// Server is the interface for an HTTP server.
//
// A Server is itself an http.Handler, so it can be handed to any net/http
// server, and it additionally offers route registration, access to the
// underlying router, middleware registration and a convenience Run method.
type Server interface {
	// ServeHTTP comes from the embedded http.Handler.
	http.Handler

	// Route registration, one method per HTTP verb.
	Get(path string, handler Handler)
	Post(path string, handler Handler)
	Put(path string, handler Handler)
	Delete(path string, handler Handler)

	// Router exposes the router that stores the registered routes.
	Router() *router.Router[Handler]

	// Use adds handlers to the handler chain.
	Use(handlers ...Handler)

	// Run starts the server on the given address.
	Run(address string) error
}
// server is an HTTP server.
type server struct {
pool sync.Pool
handlers []Handler
router router.Router[Handler]
errorHandler func(Context, error)
config Configuration
}
// New creates a new HTTP server using the default configuration.
//
// The handler chain initially holds a single dispatcher that looks up the
// route for the request's method and path and calls it. When no route
// matches, it answers with 404 Not Found. The default error handler writes
// the error text to the response and logs it together with the request path.
func New() Server {
	dispatch := func(c Context) error {
		cx := c.(*ctx)
		route := cx.server.router.LookupNoAlloc(c.Method(), c.Path(), cx.addParameter)

		if route == nil {
			return c.Status(http.StatusNotFound).String(http.StatusText(http.StatusNotFound))
		}

		return route(c)
	}

	reportError := func(c Context, err error) {
		c.WriteString(err.Error())
		log.Println(c.Path(), err)
	}

	s := &server{
		router:       router.Router[Handler]{},
		config:       defaultConfig(),
		handlers:     []Handler{dispatch},
		errorHandler: reportError,
	}

	// Every pooled context remembers the server it belongs to.
	s.pool.New = func() any {
		return &ctx{server: s}
	}

	return s
}
// Get registers handler to be called for GET requests matching path.
func (s *server) Get(path string, handler Handler) {
	s.router.Add(http.MethodGet, path, handler)
}
// Post registers handler to be called for POST requests matching path.
func (s *server) Post(path string, handler Handler) {
	s.router.Add(http.MethodPost, path, handler)
}
// Delete registers handler to be called for DELETE requests matching path.
func (s *server) Delete(path string, handler Handler) {
	s.router.Add(http.MethodDelete, path, handler)
}
// Put registers handler to be called for PUT requests matching path.
func (s *server) Put(path string, handler Handler) {
	s.router.Add(http.MethodPut, path, handler)
}
// ServeHTTP responds to the given request.
//
// A context is taken from the server's pool, bound to the request and its
// response writer, and passed to the first handler of the chain. A non-nil
// error returned by the chain is forwarded to the error handler. The context
// is then reset and put back into the pool.
func (s *server) ServeHTTP(response http.ResponseWriter, request *http.Request) {
	c := s.pool.Get().(*ctx)
	c.request = request
	c.response = response

	if err := s.handlers[0](c); err != nil {
		s.errorHandler(c, err)
	}

	// Drop the references to this request/response before pooling, so an
	// idle context does not keep the finished request (and the writer's
	// connection state) reachable until the context is reused.
	c.request = nil
	c.response = nil
	c.paramCount = 0
	c.handlerCount = 0
	s.pool.Put(c)
}
// Run starts the server on the given address.
//
// It listens on TCP, serves requests in a background goroutine and blocks
// until the process receives os.Interrupt or SIGTERM. It then shuts the
// server down gracefully, allowing at most config.Timeout.Shutdown for
// in-flight requests, and returns the result of the shutdown. If listening
// or serving fails before a signal arrives, that error is returned instead
// of blocking forever.
func (s *server) Run(address string) error {
	srv := &http.Server{
		Addr:              address,
		Handler:           s,
		ReadTimeout:       s.config.Timeout.Read,
		WriteTimeout:      s.config.Timeout.Write,
		IdleTimeout:       s.config.Timeout.Idle,
		ReadHeaderTimeout: s.config.Timeout.ReadHeader,
	}

	listener, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
	// Restore default signal handling once Run returns; otherwise SIGINT and
	// SIGTERM would stay captured by a channel nobody reads anymore.
	defer signal.Stop(stop)

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even when Run has already returned via the signal path.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- srv.Serve(listener)
	}()

	select {
	case err := <-serveErr:
		// Serve only returns http.ErrServerClosed after Shutdown, which
		// has not happened yet, so this is a genuine serving failure.
		return err
	case <-stop:
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), s.config.Timeout.Shutdown)
	defer cancel()
	return srv.Shutdown(shutdownCtx)
}
// Router returns a pointer to the router used by the server, so that routes
// registered through it become visible to request dispatching.
func (s *server) Router() *router.Router[Handler] {
	r := &s.router
	return r
}
// Use adds handlers to the handler chain.
//
// The new handlers are inserted right before the last entry of the chain,
// which New sets to the route dispatcher. Middleware therefore ends up ahead
// of the dispatcher, in the order in which it was added.
func (s *server) Use(handlers ...Handler) {
	tail := len(s.handlers) - 1
	dispatcher := s.handlers[tail]
	s.handlers = append(append(s.handlers[:tail], handlers...), dispatcher)
}