Add some nice code cleanup
This commit is contained in:
parent
e59a973228
commit
8151a24cc3
|
@ -20,7 +20,6 @@ type Authenticator interface {
|
|||
DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool)
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------------------------------
|
||||
func NewAuth(kind string) AuthManager {
|
||||
switch kind {
|
||||
case "basic":
|
||||
|
|
|
@ -2,6 +2,7 @@ package snap
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
@ -96,6 +97,12 @@ func (c *Context) RenderEx(tmpl string, content interface{}) {
|
|||
Meta: c.srv.meta,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
if c.auth != nil {
|
||||
fmt.Printf("Properties: %+v\n", c.auth.Properties)
|
||||
cnt.UserProperties = c.auth.Properties
|
||||
}
|
||||
|
||||
c.srv.render(c.w, tmpl, &cnt)
|
||||
}
|
||||
|
||||
|
@ -119,6 +126,7 @@ func (c *Context) RenderWithMeta(tmpl string, meta map[string]string, content in
|
|||
}
|
||||
|
||||
if c.auth != nil {
|
||||
fmt.Printf("Properties: %+v\n", c.auth.Properties)
|
||||
cnt.UserProperties = c.auth.Properties
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package autocert
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CertManager struct {
|
||||
lock sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
certPath string
|
||||
keyPath string
|
||||
}
|
||||
|
||||
// NewManager refreshes certificates before they expire
|
||||
func NewManager(certPath, keyPath string) (*CertManager, error) {
|
||||
m := &CertManager{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.cert = &cert
|
||||
go m.updater()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *CertManager) tryRefresh() time.Duration {
|
||||
newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath)
|
||||
if err != nil {
|
||||
return time.Hour
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
m.cert = &newCert
|
||||
return time.Hour * 24
|
||||
}
|
||||
|
||||
func (m *CertManager) updater() {
|
||||
for {
|
||||
sleep := m.tryRefresh()
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CertManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
return m.cert, nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package snap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.thirdmartini.com/pub/snap/auth"
|
||||
)
|
||||
|
||||
type RouteBuilder struct {
|
||||
auth auth.Authenticator
|
||||
contentHandler func(c *Context)
|
||||
loginHandler func(c *Context) bool
|
||||
}
|
||||
|
||||
func (r *RouteBuilder) WithLoginHandler(loginHandler func(c *Context) bool) *RouteBuilder {
|
||||
return &RouteBuilder{
|
||||
auth: r.auth,
|
||||
loginHandler: loginHandler,
|
||||
contentHandler: r.contentHandler,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RouteBuilder) WithAuth(auth auth.Authenticator) *RouteBuilder {
|
||||
return &RouteBuilder{
|
||||
auth: auth,
|
||||
loginHandler: r.loginHandler,
|
||||
contentHandler: r.contentHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouteBuilder) WithContent(contentHandler func(c *Context)) *RouteBuilder {
|
||||
return &RouteBuilder{
|
||||
auth: r.auth,
|
||||
loginHandler: r.loginHandler,
|
||||
contentHandler: contentHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RouteBuilder) BuildRoute(s *Server) http.HandlerFunc {
|
||||
if r.auth == nil {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
c := s.makeContext(nil, w, req)
|
||||
r.contentHandler(c)
|
||||
}
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
rec, ok := r.auth.DoAuth(w, req)
|
||||
if !ok {
|
||||
if r.loginHandler != nil {
|
||||
c := s.makeContext(rec, w, req)
|
||||
if r.loginHandler(c) == false {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Printf("Sending to: %s\n", req.URL.Path)
|
||||
req.Method = http.MethodGet
|
||||
http.Redirect(w, req, req.URL.Path, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
c := s.makeContext(rec, w, req)
|
||||
r.contentHandler(c)
|
||||
}
|
||||
}
|
||||
|
||||
func NewRouteBuilder() *RouteBuilder {
|
||||
return &RouteBuilder{}
|
||||
}
|
26
server.go
26
server.go
|
@ -1,6 +1,7 @@
|
|||
package snap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"git.thirdmartini.com/pub/snap/pkg/autocert"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"git.thirdmartini.com/pub/fancylog"
|
||||
|
@ -222,7 +224,6 @@ func (s *Server) loadTemplates() *template.Template {
|
|||
if err != nil {
|
||||
log.Fatal("loadTemplates", err, s.templates)
|
||||
}
|
||||
|
||||
return tmpl
|
||||
}
|
||||
|
||||
|
@ -288,6 +289,10 @@ func (s *Server) HandleFunc(path string, f func(c *Context)) *mux.Route {
|
|||
return s.router.HandleFunc(path, s.wrapper(f))
|
||||
}
|
||||
|
||||
func (s *Server) AddRoute(path string, r *RouteBuilder) *mux.Route {
|
||||
return s.router.HandleFunc(path, r.BuildRoute(s))
|
||||
}
|
||||
|
||||
func (s *Server) SetDebug(enable bool) {
|
||||
s.debug = enable
|
||||
}
|
||||
|
@ -308,15 +313,31 @@ func (s *Server) Router() *mux.Router {
|
|||
}
|
||||
|
||||
func (s *Server) ServeTLS(keyPath string, certPath string) error {
|
||||
kpr, err := autocert.NewManager(certPath, keyPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: s.router,
|
||||
Addr: s.address,
|
||||
// Good practice: enforce timeouts for servers you create!
|
||||
WriteTimeout: 120 * time.Second,
|
||||
ReadTimeout: 120 * time.Second,
|
||||
TLSConfig: &tls.Config{},
|
||||
}
|
||||
srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc()
|
||||
return srv.ListenAndServeTLS("", "")
|
||||
}
|
||||
|
||||
return srv.ListenAndServeTLS(certPath, keyPath)
|
||||
func (s *Server) ServeTLSRedirect(address string) error {
|
||||
srv := &http.Server{
|
||||
Addr: address,
|
||||
// Good practice: enforce timeouts for servers you create!
|
||||
WriteTimeout: 120 * time.Second,
|
||||
ReadTimeout: 120 * time.Second,
|
||||
}
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
// Serve serve content forever
|
||||
|
@ -328,7 +349,6 @@ func (s *Server) Serve() error {
|
|||
WriteTimeout: 120 * time.Second,
|
||||
ReadTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue