Add some nice code cleanup
This commit is contained in:
parent
e59a973228
commit
8151a24cc3
|
@ -20,7 +20,6 @@ type Authenticator interface {
|
||||||
DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool)
|
DoAuth(w http.ResponseWriter, r *http.Request) (*AuthData, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
//----------------------------------------------------------------------------------------------------------------------
|
|
||||||
func NewAuth(kind string) AuthManager {
|
func NewAuth(kind string) AuthManager {
|
||||||
switch kind {
|
switch kind {
|
||||||
case "basic":
|
case "basic":
|
||||||
|
|
|
@ -2,6 +2,7 @@ package snap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -96,6 +97,12 @@ func (c *Context) RenderEx(tmpl string, content interface{}) {
|
||||||
Meta: c.srv.meta,
|
Meta: c.srv.meta,
|
||||||
Content: content,
|
Content: content,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.auth != nil {
|
||||||
|
fmt.Printf("Properties: %+v\n", c.auth.Properties)
|
||||||
|
cnt.UserProperties = c.auth.Properties
|
||||||
|
}
|
||||||
|
|
||||||
c.srv.render(c.w, tmpl, &cnt)
|
c.srv.render(c.w, tmpl, &cnt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,6 +126,7 @@ func (c *Context) RenderWithMeta(tmpl string, meta map[string]string, content in
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.auth != nil {
|
if c.auth != nil {
|
||||||
|
fmt.Printf("Properties: %+v\n", c.auth.Properties)
|
||||||
cnt.UserProperties = c.auth.Properties
|
cnt.UserProperties = c.auth.Properties
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package autocert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CertManager struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
cert *tls.Certificate
|
||||||
|
certPath string
|
||||||
|
keyPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager refreshes certificates before they expire
|
||||||
|
func NewManager(certPath, keyPath string) (*CertManager, error) {
|
||||||
|
m := &CertManager{
|
||||||
|
certPath: certPath,
|
||||||
|
keyPath: keyPath,
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.cert = &cert
|
||||||
|
go m.updater()
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CertManager) tryRefresh() time.Duration {
|
||||||
|
newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
m.cert = &newCert
|
||||||
|
return time.Hour * 24
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CertManager) updater() {
|
||||||
|
for {
|
||||||
|
sleep := m.tryRefresh()
|
||||||
|
time.Sleep(sleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CertManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
m.lock.RLock()
|
||||||
|
defer m.lock.RUnlock()
|
||||||
|
return m.cert, nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package snap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.thirdmartini.com/pub/snap/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RouteBuilder struct {
|
||||||
|
auth auth.Authenticator
|
||||||
|
contentHandler func(c *Context)
|
||||||
|
loginHandler func(c *Context) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RouteBuilder) WithLoginHandler(loginHandler func(c *Context) bool) *RouteBuilder {
|
||||||
|
return &RouteBuilder{
|
||||||
|
auth: r.auth,
|
||||||
|
loginHandler: loginHandler,
|
||||||
|
contentHandler: r.contentHandler,
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RouteBuilder) WithAuth(auth auth.Authenticator) *RouteBuilder {
|
||||||
|
return &RouteBuilder{
|
||||||
|
auth: auth,
|
||||||
|
loginHandler: r.loginHandler,
|
||||||
|
contentHandler: r.contentHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RouteBuilder) WithContent(contentHandler func(c *Context)) *RouteBuilder {
|
||||||
|
return &RouteBuilder{
|
||||||
|
auth: r.auth,
|
||||||
|
loginHandler: r.loginHandler,
|
||||||
|
contentHandler: contentHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RouteBuilder) BuildRoute(s *Server) http.HandlerFunc {
|
||||||
|
if r.auth == nil {
|
||||||
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
c := s.makeContext(nil, w, req)
|
||||||
|
r.contentHandler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
rec, ok := r.auth.DoAuth(w, req)
|
||||||
|
if !ok {
|
||||||
|
if r.loginHandler != nil {
|
||||||
|
c := s.makeContext(rec, w, req)
|
||||||
|
if r.loginHandler(c) == false {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("Sending to: %s\n", req.URL.Path)
|
||||||
|
req.Method = http.MethodGet
|
||||||
|
http.Redirect(w, req, req.URL.Path, http.StatusSeeOther)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := s.makeContext(rec, w, req)
|
||||||
|
r.contentHandler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRouteBuilder() *RouteBuilder {
|
||||||
|
return &RouteBuilder{}
|
||||||
|
}
|
26
server.go
26
server.go
|
@ -1,6 +1,7 @@
|
||||||
package snap
|
package snap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
@ -12,6 +13,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.thirdmartini.com/pub/snap/pkg/autocert"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"git.thirdmartini.com/pub/fancylog"
|
"git.thirdmartini.com/pub/fancylog"
|
||||||
|
@ -222,7 +224,6 @@ func (s *Server) loadTemplates() *template.Template {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("loadTemplates", err, s.templates)
|
log.Fatal("loadTemplates", err, s.templates)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tmpl
|
return tmpl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -288,6 +289,10 @@ func (s *Server) HandleFunc(path string, f func(c *Context)) *mux.Route {
|
||||||
return s.router.HandleFunc(path, s.wrapper(f))
|
return s.router.HandleFunc(path, s.wrapper(f))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) AddRoute(path string, r *RouteBuilder) *mux.Route {
|
||||||
|
return s.router.HandleFunc(path, r.BuildRoute(s))
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) SetDebug(enable bool) {
|
func (s *Server) SetDebug(enable bool) {
|
||||||
s.debug = enable
|
s.debug = enable
|
||||||
}
|
}
|
||||||
|
@ -308,15 +313,31 @@ func (s *Server) Router() *mux.Router {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ServeTLS(keyPath string, certPath string) error {
|
func (s *Server) ServeTLS(keyPath string, certPath string) error {
|
||||||
|
kpr, err := autocert.NewManager(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Handler: s.router,
|
Handler: s.router,
|
||||||
Addr: s.address,
|
Addr: s.address,
|
||||||
// Good practice: enforce timeouts for servers you create!
|
// Good practice: enforce timeouts for servers you create!
|
||||||
WriteTimeout: 120 * time.Second,
|
WriteTimeout: 120 * time.Second,
|
||||||
ReadTimeout: 120 * time.Second,
|
ReadTimeout: 120 * time.Second,
|
||||||
|
TLSConfig: &tls.Config{},
|
||||||
}
|
}
|
||||||
|
srv.TLSConfig.GetCertificate = kpr.GetCertificateFunc()
|
||||||
|
return srv.ListenAndServeTLS("", "")
|
||||||
|
}
|
||||||
|
|
||||||
return srv.ListenAndServeTLS(certPath, keyPath)
|
func (s *Server) ServeTLSRedirect(address string) error {
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: address,
|
||||||
|
// Good practice: enforce timeouts for servers you create!
|
||||||
|
WriteTimeout: 120 * time.Second,
|
||||||
|
ReadTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
return srv.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve serve content forever
|
// Serve serve content forever
|
||||||
|
@ -328,7 +349,6 @@ func (s *Server) Serve() error {
|
||||||
WriteTimeout: 120 * time.Second,
|
WriteTimeout: 120 * time.Second,
|
||||||
ReadTimeout: 120 * time.Second,
|
ReadTimeout: 120 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
return srv.ListenAndServe()
|
return srv.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue