diff --git a/auth/auth.go b/auth/auth.go index df69983..fe96c39 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -20,7 +20,6 @@ type Authenticator interface { DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool) } -//---------------------------------------------------------------------------------------------------------------------- func NewAuth(kind string) AuthManager { switch kind { case "basic": diff --git a/context.go b/context.go index a008b12..c0d3d43 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,7 @@ package snap import ( "encoding/json" + "fmt" "io/ioutil" "net/http" "strconv" @@ -96,6 +97,12 @@ func (c *Context) RenderEx(tmpl string, content interface{}) { Meta: c.srv.meta, Content: content, } + + if c.auth != nil { + fmt.Printf("Properties: %+v\n", c.auth.Properties) + cnt.UserProperties = c.auth.Properties + } + c.srv.render(c.w, tmpl, &cnt) } @@ -119,6 +126,7 @@ func (c *Context) RenderWithMeta(tmpl string, meta map[string]string, content in } if c.auth != nil { + fmt.Printf("Properties: %+v\n", c.auth.Properties) cnt.UserProperties = c.auth.Properties } diff --git a/pkg/autocert/certs.go b/pkg/autocert/certs.go new file mode 100644 index 0000000..f5a1608 --- /dev/null +++ b/pkg/autocert/certs.go @@ -0,0 +1,57 @@ +package autocert + +import ( + "crypto/tls" + "sync" + "time" +) + +type CertManager struct { + lock sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string +} + +// NewManager refreshes certificates before they expire +func NewManager(certPath, keyPath string) (*CertManager, error) { + m := &CertManager{ + certPath: certPath, + keyPath: keyPath, + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + m.cert = &cert + go m.updater() + + return m, nil +} + +func (m *CertManager) tryRefresh() time.Duration { + newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath) + if err != nil { + return time.Hour + } + + m.lock.Lock() + defer m.lock.Unlock() + m.cert = &newCert + return time.Hour * 24 +} + +func (m *CertManager) updater() { + for { + sleep := m.tryRefresh() + time.Sleep(sleep) + } +} + +func (m *CertManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + m.lock.RLock() + defer m.lock.RUnlock() + return m.cert, nil + } +} diff --git a/route.go b/route.go new file mode 100644 index 0000000..54237d8 --- /dev/null +++ b/route.go @@ -0,0 +1,70 @@ +package snap + +import ( + "fmt" + "net/http" + + "git.thirdmartini.com/pub/snap/auth" +) + +type RouteBuilder struct { + auth auth.Authenticator + contentHandler func(c *Context) + loginHandler func(c *Context) bool +} + +func (r *RouteBuilder) WithLoginHandler(loginHandler func(c *Context) bool) *RouteBuilder { + return &RouteBuilder{ + auth: r.auth, + loginHandler: loginHandler, + contentHandler: r.contentHandler, + } + return r +} + +func (r *RouteBuilder) WithAuth(auth auth.Authenticator) *RouteBuilder { + return &RouteBuilder{ + auth: auth, + loginHandler: r.loginHandler, + contentHandler: r.contentHandler, + } +} + +func (r *RouteBuilder) WithContent(contentHandler func(c *Context)) *RouteBuilder { + return &RouteBuilder{ + auth: r.auth, + loginHandler: r.loginHandler, + contentHandler: contentHandler, + } +} + +func (r *RouteBuilder) BuildRoute(s *Server) http.HandlerFunc { + if r.auth == nil { + return func(w http.ResponseWriter, req *http.Request) { + c := s.makeContext(nil, w, req) + r.contentHandler(c) + } + } + + return func(w http.ResponseWriter, req *http.Request) { + rec, ok := r.auth.DoAuth(w, req) + if !ok { + if r.loginHandler != nil { + c := s.makeContext(rec, w, req) + if r.loginHandler(c) == false { + return + } + } + fmt.Printf("Sending to: %s\n", req.URL.Path) + req.Method = http.MethodGet + http.Redirect(w, req, req.URL.Path, http.StatusSeeOther) + return + } + c := s.makeContext(rec, w, req) + r.contentHandler(c) + } +} + +func NewRouteBuilder() *RouteBuilder { + return &RouteBuilder{} +} diff --git a/server.go b/server.go index 751d7ce..587d55d 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package snap import ( + "crypto/tls" "encoding/json" "fmt" "html/template" @@ -12,6 +13,7 @@ import ( "strings" "time" + "git.thirdmartini.com/pub/snap/pkg/autocert" "github.com/gorilla/mux" "git.thirdmartini.com/pub/fancylog" @@ -222,7 +224,6 @@ func (s *Server) loadTemplates() *template.Template { if err != nil { log.Fatal("loadTemplates", err, s.templates) } - return tmpl } @@ -288,6 +289,10 @@ func (s *Server) HandleFunc(path string, f func(c *Context)) *mux.Route { return s.router.HandleFunc(path, s.wrapper(f)) } +func (s *Server) AddRoute(path string, r *RouteBuilder) *mux.Route { + return s.router.HandleFunc(path, r.BuildRoute(s)) +} + func (s *Server) SetDebug(enable bool) { s.debug = enable } @@ -308,15 +313,31 @@ func (s *Server) Router() *mux.Router { } func (s *Server) ServeTLS(keyPath string, certPath string) error { + kpr, err := autocert.NewManager(certPath, keyPath) + if err != nil { + log.Fatal(err) + } + srv := &http.Server{ Handler: s.router, Addr: s.address, // Good practice: enforce timeouts for servers you create! WriteTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second, + TLSConfig: &tls.Config{}, } + srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc() + return srv.ListenAndServeTLS("", "") +} - return srv.ListenAndServeTLS(certPath, keyPath) +func (s *Server) ServeTLSRedirect(address string) error { + srv := &http.Server{ + Addr: address, + // Good practice: enforce timeouts for servers you create! + WriteTimeout: 120 * time.Second, + ReadTimeout: 120 * time.Second, + } + return srv.ListenAndServe() } // Serve serve content forever @@ -328,7 +349,6 @@ func (s *Server) Serve() error { WriteTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second, } - return srv.ListenAndServe() }