From 6dbdeafca282f0d315303191989dc6500a37fb6e Mon Sep 17 00:00:00 2001 From: spsobole Date: Thu, 21 Mar 2024 12:09:05 -0600 Subject: [PATCH] Add ability to shutdown the server --- options.go | 22 ++++++++-------------- server.go | 37 +++++++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/options.go b/options.go index 7b5f7ef..26fba07 100644 --- a/options.go +++ b/options.go @@ -16,7 +16,6 @@ type Options struct { kv OptionsProvider } - // FormValueProvider we only care about something that can provide us the K/V lookup type FormValueProvider interface { FormValue(string) string @@ -36,7 +35,6 @@ func (f *Form) Get(k string) (string, bool) { return v, true } - type Vars struct { vars map[string]string } @@ -59,7 +57,6 @@ func (q *Query) Get(key string) (string, bool) { return "", false } - func (o *Options) GetVar(k string) (string, bool) { return o.kv.Get(k) } @@ -85,15 +82,14 @@ func (o *Options) GetVarAsSlice(k string, delim string, def []string) []string { return r } - // StringValue returns form value as a string -func (o *Options) StringValue(k string) (string,bool) { +func (o *Options) StringValue(k string) (string, bool) { return o.kv.Get(k) } // Uint64Value returns the first value in a set as uint64 func (o *Options) Uint64Value(k string) (uint64, bool) { - str,ok := o.kv.Get(k) + str, ok := o.kv.Get(k) if !ok { return 0, ok } @@ -107,8 +103,8 @@ func (o *Options) Uint64Value(k string) (uint64, bool) { } // StringSlice returns multiple values assigned to the same key as a string slice -func (o *Options) StringSlice(k string, delim string) ([]string,bool) { - str,ok := o.kv.Get(k) +func (o *Options) StringSlice(k string, delim string) ([]string, bool) { + str, ok := o.kv.Get(k) if !ok { return nil, ok } @@ -121,10 +117,11 @@ func (o *Options) StringSlice(k string, delim string) ([]string,bool) { } // TimeValue returns a parsed time value for the specified key, -// if the key does not exist the but defaultValue was set it will return the defaultValue -// otherwise it will return the passed in time or an error if parsing failed +// +// if the key does not exist the but defaultValue was set it will return the defaultValue +// otherwise it will return the passed in time or an error if parsing failed func (o *Options) TimeValue(key, timeFormat string, defaultValue *time.Time) time.Time { - str,ok := o.kv.Get(key) + str, ok := o.kv.Get(key) if !ok { if defaultValue != nil { return *defaultValue @@ -141,9 +138,6 @@ func (o *Options) TimeValue(key, timeFormat string, defaultValue *time.Time) tim return t } - - - type CompoundOptions struct { kvs []OptionsProvider } diff --git a/server.go b/server.go index f8eac16..030a684 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package snap import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -22,12 +23,14 @@ import ( ) type Server struct { - address string - theme string - debug bool - fs http.FileSystem - auth auth.Authenticator - router *mux.Router + address string + theme string + debug bool + fs http.FileSystem + auth auth.Authenticator + router *mux.Router + http *http.Server + templates string cachedTmpl *template.Template templateFuncs template.FuncMap @@ -320,7 +323,7 @@ func (s *Server) ServeTLS(keyPath string, certPath string) error { log.Fatalf("%v\n", err) } - srv := &http.Server{ + s.http = &http.Server{ Handler: s.router, Addr: s.address, // Good practice: enforce timeouts for servers you create! @@ -328,30 +331,30 @@ func (s *Server) ServeTLS(keyPath string, certPath string) error { ReadTimeout: 120 * time.Second, TLSConfig: &tls.Config{}, } - srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc() - return srv.ListenAndServeTLS("", "") + s.http.TLSConfig.GetCertificate = kpr.GetCertificateFunc() + return s.http.ListenAndServeTLS("", "") } func (s *Server) ServeTLSRedirect(address string) error { - srv := &http.Server{ + s.http = &http.Server{ Addr: address, // Good practice: enforce timeouts for servers you create! WriteTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second, } - return srv.ListenAndServe() + return s.http.ListenAndServe() } -// Serve serve content forever +// Serve content forever func (s *Server) Serve() error { - srv := &http.Server{ + s.http = &http.Server{ Handler: s.router, Addr: s.address, // Good practice: enforce timeouts for servers you create! WriteTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second, } - return srv.ListenAndServe() + return s.http.ListenAndServe() } func (s *Server) WithStaticFiles(prefix string) *Server { @@ -424,6 +427,12 @@ func (s *Server) Dump() { fmt.Printf(" Templates: %s\n", s.templates) } +func (s *Server) Shutdown(ctx context.Context) { + if s.http != nil { + s.http.Shutdown(ctx) + } +} + func New(address string, path string, auth auth.Authenticator) *Server { s := Server{ router: mux.NewRouter(),