Add ability to shutdown the server

This commit is contained in:
spsobole 2024-03-21 12:09:05 -06:00
parent b312ba9cb1
commit 6dbdeafca2
2 changed files with 31 additions and 28 deletions

View File

@ -16,7 +16,6 @@ type Options struct {
kv OptionsProvider kv OptionsProvider
} }
// FormValueProvider we only care about something that can provide us the K/V lookup // FormValueProvider we only care about something that can provide us the K/V lookup
type FormValueProvider interface { type FormValueProvider interface {
FormValue(string) string FormValue(string) string
@ -36,7 +35,6 @@ func (f *Form) Get(k string) (string, bool) {
return v, true return v, true
} }
type Vars struct { type Vars struct {
vars map[string]string vars map[string]string
} }
@ -59,7 +57,6 @@ func (q *Query) Get(key string) (string, bool) {
return "", false return "", false
} }
func (o *Options) GetVar(k string) (string, bool) { func (o *Options) GetVar(k string) (string, bool) {
return o.kv.Get(k) return o.kv.Get(k)
} }
@ -85,15 +82,14 @@ func (o *Options) GetVarAsSlice(k string, delim string, def []string) []string {
return r return r
} }
// StringValue returns form value as a string // StringValue returns form value as a string
func (o *Options) StringValue(k string) (string,bool) { func (o *Options) StringValue(k string) (string, bool) {
return o.kv.Get(k) return o.kv.Get(k)
} }
// Uint64Value returns the first value in a set as uint64 // Uint64Value returns the first value in a set as uint64
func (o *Options) Uint64Value(k string) (uint64, bool) { func (o *Options) Uint64Value(k string) (uint64, bool) {
str,ok := o.kv.Get(k) str, ok := o.kv.Get(k)
if !ok { if !ok {
return 0, ok return 0, ok
} }
@ -107,8 +103,8 @@ func (o *Options) Uint64Value(k string) (uint64, bool) {
} }
// StringSlice returns multiple values assigned to the same key as a string slice // StringSlice returns multiple values assigned to the same key as a string slice
func (o *Options) StringSlice(k string, delim string) ([]string,bool) { func (o *Options) StringSlice(k string, delim string) ([]string, bool) {
str,ok := o.kv.Get(k) str, ok := o.kv.Get(k)
if !ok { if !ok {
return nil, ok return nil, ok
} }
@ -121,10 +117,11 @@ func (o *Options) StringSlice(k string, delim string) ([]string,bool) {
} }
// TimeValue returns a parsed time value for the specified key, // TimeValue returns a parsed time value for the specified key,
// if the key does not exist the but defaultValue was set it will return the defaultValue //
// otherwise it will return the passed in time or an error if parsing failed // if the key does not exist the but defaultValue was set it will return the defaultValue
// otherwise it will return the passed in time or an error if parsing failed
func (o *Options) TimeValue(key, timeFormat string, defaultValue *time.Time) time.Time { func (o *Options) TimeValue(key, timeFormat string, defaultValue *time.Time) time.Time {
str,ok := o.kv.Get(key) str, ok := o.kv.Get(key)
if !ok { if !ok {
if defaultValue != nil { if defaultValue != nil {
return *defaultValue return *defaultValue
@ -141,9 +138,6 @@ func (o *Options) TimeValue(key, timeFormat string, defaultValue *time.Time) tim
return t return t
} }
type CompoundOptions struct { type CompoundOptions struct {
kvs []OptionsProvider kvs []OptionsProvider
} }

View File

@ -1,6 +1,7 @@
package snap package snap
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -22,12 +23,14 @@ import (
) )
type Server struct { type Server struct {
address string address string
theme string theme string
debug bool debug bool
fs http.FileSystem fs http.FileSystem
auth auth.Authenticator auth auth.Authenticator
router *mux.Router router *mux.Router
http *http.Server
templates string templates string
cachedTmpl *template.Template cachedTmpl *template.Template
templateFuncs template.FuncMap templateFuncs template.FuncMap
@ -320,7 +323,7 @@ func (s *Server) ServeTLS(keyPath string, certPath string) error {
log.Fatalf("%v\n", err) log.Fatalf("%v\n", err)
} }
srv := &http.Server{ s.http = &http.Server{
Handler: s.router, Handler: s.router,
Addr: s.address, Addr: s.address,
// Good practice: enforce timeouts for servers you create! // Good practice: enforce timeouts for servers you create!
@ -328,30 +331,30 @@ func (s *Server) ServeTLS(keyPath string, certPath string) error {
ReadTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second,
TLSConfig: &tls.Config{}, TLSConfig: &tls.Config{},
} }
srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc() s.http.TLSConfig.GetCertificate = kpr.GetCertificateFunc()
return srv.ListenAndServeTLS("", "") return s.http.ListenAndServeTLS("", "")
} }
func (s *Server) ServeTLSRedirect(address string) error { func (s *Server) ServeTLSRedirect(address string) error {
srv := &http.Server{ s.http = &http.Server{
Addr: address, Addr: address,
// Good practice: enforce timeouts for servers you create! // Good practice: enforce timeouts for servers you create!
WriteTimeout: 120 * time.Second, WriteTimeout: 120 * time.Second,
ReadTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second,
} }
return srv.ListenAndServe() return s.http.ListenAndServe()
} }
// Serve serve content forever // Serve content forever
func (s *Server) Serve() error { func (s *Server) Serve() error {
srv := &http.Server{ s.http = &http.Server{
Handler: s.router, Handler: s.router,
Addr: s.address, Addr: s.address,
// Good practice: enforce timeouts for servers you create! // Good practice: enforce timeouts for servers you create!
WriteTimeout: 120 * time.Second, WriteTimeout: 120 * time.Second,
ReadTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second,
} }
return srv.ListenAndServe() return s.http.ListenAndServe()
} }
func (s *Server) WithStaticFiles(prefix string) *Server { func (s *Server) WithStaticFiles(prefix string) *Server {
@ -424,6 +427,12 @@ func (s *Server) Dump() {
fmt.Printf(" Templates: %s\n", s.templates) fmt.Printf(" Templates: %s\n", s.templates)
} }
func (s *Server) Shutdown(ctx context.Context) {
if s.http != nil {
s.http.Shutdown(ctx)
}
}
func New(address string, path string, auth auth.Authenticator) *Server { func New(address string, path string, auth auth.Authenticator) *Server {
s := Server{ s := Server{
router: mux.NewRouter(), router: mux.NewRouter(),