Add some nice code cleanup

This commit is contained in:
spsobole 2021-06-18 13:43:52 -06:00
parent e59a973228
commit 8151a24cc3
5 changed files with 158 additions and 4 deletions

View File

@ -20,7 +20,6 @@ type Authenticator interface {
DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool) DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool)
} }
//----------------------------------------------------------------------------------------------------------------------
func NewAuth(kind string) AuthManager { func NewAuth(kind string) AuthManager {
switch kind { switch kind {
case "basic": case "basic":

View File

@ -2,6 +2,7 @@ package snap
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
@ -96,6 +97,12 @@ func (c *Context) RenderEx(tmpl string, content interface{}) {
Meta: c.srv.meta, Meta: c.srv.meta,
Content: content, Content: content,
} }
if c.auth != nil {
fmt.Printf("Properties: %+v\n", c.auth.Properties)
cnt.UserProperties = c.auth.Properties
}
c.srv.render(c.w, tmpl, &cnt) c.srv.render(c.w, tmpl, &cnt)
} }
@ -119,6 +126,7 @@ func (c *Context) RenderWithMeta(tmpl string, meta map[string]string, content in
} }
if c.auth != nil { if c.auth != nil {
fmt.Printf("Properties: %+v\n", c.auth.Properties)
cnt.UserProperties = c.auth.Properties cnt.UserProperties = c.auth.Properties
} }

57
pkg/autocert/certs.go Normal file
View File

@ -0,0 +1,57 @@
package autocert
import (
"crypto/tls"
"sync"
"time"
)
type CertManager struct {
lock sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
}
// NewManager refreshes certificates before they expire
func NewManager(certPath, keyPath string) (*CertManager, error) {
m := &CertManager{
certPath: certPath,
keyPath: keyPath,
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}
m.cert = &cert
go m.updater()
return m, nil
}
func (m *CertManager) tryRefresh() time.Duration {
newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath)
if err != nil {
return time.Hour
}
m.lock.Lock()
defer m.lock.Unlock()
m.cert = &newCert
return time.Hour * 24
}
func (m *CertManager) updater() {
for {
sleep := m.tryRefresh()
time.Sleep(sleep)
}
}
func (m *CertManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
m.lock.RLock()
defer m.lock.RUnlock()
return m.cert, nil
}
}

70
route.go Normal file
View File

@ -0,0 +1,70 @@
package snap
import (
"fmt"
"net/http"
"git.thirdmartini.com/pub/snap/auth"
)
type RouteBuilder struct {
auth auth.Authenticator
contentHandler func(c *Context)
loginHandler func(c *Context) bool
}
func (r *RouteBuilder) WithLoginHandler(loginHandler func(c *Context) bool) *RouteBuilder {
return &RouteBuilder{
auth: r.auth,
loginHandler: loginHandler,
contentHandler: r.contentHandler,
}
return r
}
func (r *RouteBuilder) WithAuth(auth auth.Authenticator) *RouteBuilder {
return &RouteBuilder{
auth: auth,
loginHandler: r.loginHandler,
contentHandler: r.contentHandler,
}
}
func (r *RouteBuilder) WithContent(contentHandler func(c *Context)) *RouteBuilder {
return &RouteBuilder{
auth: r.auth,
loginHandler: r.loginHandler,
contentHandler: contentHandler,
}
}
func (r *RouteBuilder) BuildRoute(s *Server) http.HandlerFunc {
if r.auth == nil {
return func(w http.ResponseWriter, req *http.Request) {
c := s.makeContext(nil, w, req)
r.contentHandler(c)
}
}
return func(w http.ResponseWriter, req *http.Request) {
rec, ok := r.auth.DoAuth(w, req)
if !ok {
if r.loginHandler != nil {
c := s.makeContext(rec, w, req)
if r.loginHandler(c) == false {
return
}
}
fmt.Printf("Sending to: %s\n", req.URL.Path)
req.Method = http.MethodGet
http.Redirect(w, req, req.URL.Path, http.StatusSeeOther)
return
}
c := s.makeContext(rec, w, req)
r.contentHandler(c)
}
}
func NewRouteBuilder() *RouteBuilder {
return &RouteBuilder{}
}

View File

@ -1,6 +1,7 @@
package snap package snap
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"html/template" "html/template"
@ -12,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
"git.thirdmartini.com/pub/snap/pkg/autocert"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"git.thirdmartini.com/pub/fancylog" "git.thirdmartini.com/pub/fancylog"
@ -222,7 +224,6 @@ func (s *Server) loadTemplates() *template.Template {
if err != nil { if err != nil {
log.Fatal("loadTemplates", err, s.templates) log.Fatal("loadTemplates", err, s.templates)
} }
return tmpl return tmpl
} }
@ -288,6 +289,10 @@ func (s *Server) HandleFunc(path string, f func(c *Context)) *mux.Route {
return s.router.HandleFunc(path, s.wrapper(f)) return s.router.HandleFunc(path, s.wrapper(f))
} }
func (s *Server) AddRoute(path string, r *RouteBuilder) *mux.Route {
return s.router.HandleFunc(path, r.BuildRoute(s))
}
func (s *Server) SetDebug(enable bool) { func (s *Server) SetDebug(enable bool) {
s.debug = enable s.debug = enable
} }
@ -308,15 +313,31 @@ func (s *Server) Router() *mux.Router {
} }
func (s *Server) ServeTLS(keyPath string, certPath string) error { func (s *Server) ServeTLS(keyPath string, certPath string) error {
kpr, err := autocert.NewManager(certPath, keyPath)
if err != nil {
log.Fatal(err)
}
srv := &http.Server{ srv := &http.Server{
Handler: s.router, Handler: s.router,
Addr: s.address, Addr: s.address,
// Good practice: enforce timeouts for servers you create! // Good practice: enforce timeouts for servers you create!
WriteTimeout: 120 * time.Second, WriteTimeout: 120 * time.Second,
ReadTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second,
TLSConfig: &tls.Config{},
}
srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc()
return srv.ListenAndServeTLS("", "")
} }
return srv.ListenAndServeTLS(certPath, keyPath) func (s *Server) ServeTLSRedirect(address string) error {
srv := &http.Server{
Addr: address,
// Good practice: enforce timeouts for servers you create!
WriteTimeout: 120 * time.Second,
ReadTimeout: 120 * time.Second,
}
return srv.ListenAndServe()
} }
// Serve serve content forever // Serve serve content forever
@ -328,7 +349,6 @@ func (s *Server) Serve() error {
WriteTimeout: 120 * time.Second, WriteTimeout: 120 * time.Second,
ReadTimeout: 120 * time.Second, ReadTimeout: 120 * time.Second,
} }
return srv.ListenAndServe() return srv.ListenAndServe()
} }