diff --git a/server.go b/server.go index 99a2844..37f8f2b 100644 --- a/server.go +++ b/server.go @@ -8,7 +8,6 @@ import ( "net/http" "os" "path" - "path/filepath" "strings" "time" @@ -21,9 +20,9 @@ import ( type Server struct { address string - path string theme string debug bool + fs http.FileSystem auth auth.Authenticator router *mux.Router templates string @@ -94,13 +93,19 @@ func (s *Server) parseTemplates(t *template.Template, filenames ...string) (*tem return nil, fmt.Errorf("html/template: no files named in call to ParseFiles") } for _, filename := range filenames { - b, err := ioutil.ReadFile(filename) + f, err := s.fs.Open(filename) + if err != nil { + return nil, err + } + + b, err := ioutil.ReadAll(f) + f.Close() if err != nil { return nil, err } data := string(b) name := strings.TrimPrefix(filename, s.templates+"/") - // log.Println("Template:", name) + log.Println("Template:", name) // First template becomes return value if not already defined, // and we use that one for subsequent New calls to associate @@ -125,10 +130,40 @@ func (s *Server) parseTemplates(t *template.Template, filenames ...string) (*tem return t, nil } -func (s *Server) loadTemplates() *template.Template { +func Walk(fs http.FileSystem, base string, walkFunc func(path string, info os.FileInfo, err error) error) error { + f, err := fs.Open(base) + if err != nil { + return err + } + defer f.Close() + + s, err := f.Stat() + if err != nil { + return err + } + + if s.IsDir() { + // its a directory, recurse + files, err := f.Readdir(1024) + if err != nil { + return err + } + + for _, cf := range files { + if err = Walk(fs, path.Join(base, cf.Name()), walkFunc); err != nil { + return err + } + } + return nil + } + + return walkFunc(base, s, nil) +} + +func (s *Server) LoadTemplatesFS(fs http.FileSystem, base string) (*template.Template, error) { tmpl := template.New("") - err := filepath.Walk(s.templates, func(path string, info os.FileInfo, err error) error { + err := Walk(fs, base, func(path string, info os.FileInfo, err error) error { if strings.Contains(path, ".html") { _, err := s.parseTemplates(tmpl, path) if err != nil { @@ -138,6 +173,15 @@ func (s *Server) loadTemplates() *template.Template { return err }) + if err != nil { + return nil, err + } + + return tmpl, nil +} + +func (s *Server) loadTemplates() *template.Template { + tmpl, err := s.LoadTemplatesFS(s.fs, s.templates) if err != nil { log.Fatal(err) } @@ -236,8 +280,8 @@ func (s *Server) Serve() error { return srv.ListenAndServe() } -func (s *Server) WithStaticFiles(prefix string, path string) *Server { - s.router.PathPrefix(prefix).Handler(http.StripPrefix(prefix, http.FileServer(http.Dir(path)))) +func (s *Server) WithStaticFiles(prefix string) *Server { + s.router.PathPrefix(prefix).Handler(http.FileServer(s.fs)) return s } @@ -253,10 +297,8 @@ func (s *Server) WithDebug(debugURL string) *Server { } func (s *Server) Dump() { - fmt.Printf(" Path: %s\n", s.path) fmt.Printf(" Theme: %s\n", s.theme) fmt.Printf(" Templates: %s\n", s.templates) - } func New(address string, path string, auth auth.Authenticator) *Server { @@ -264,8 +306,8 @@ func New(address string, path string, auth auth.Authenticator) *Server { router: mux.NewRouter(), auth: auth, address: address, - path: path, - templates: "templates", + fs: http.FileSystem(http.Dir(path)), + templates: "/templates", theme: "/static/css/default.css", } return &s diff --git a/server_test.go b/server_test.go index 8290ac6..6f0fa77 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package snap import ( + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -46,6 +47,7 @@ func handlerLogin(c *Context) { } func handlerTest(c *Context) { + fmt.Printf(":handlerTest\n") c.RenderEx("test.html", nil) } @@ -65,10 +67,10 @@ func TestServer_BasicAuth(t *testing.T) { auth := auth.NewBasicAuth() auth.AddUser("admin", "admin", "test") - s := New("", "/test", auth) + s := New("", "test", auth) s.SetDebug(true) - s.SetTemplatePath("test/templates") - s.WithStaticFiles("/static", "test/static") + s.SetTemplatePath("/templates") + s.WithStaticFiles("/static") s.WithTheme("skin/") s.HandleFunc("/", handlerRoot) s.HandleFuncAuthenticated("/login", "", handlerLogin) @@ -126,10 +128,10 @@ func TestServer_TokenAuth(t *testing.T) { auth := auth.NewTokenAuth() auth.AddUser("admin", "admin", "1234567890") - s := New("", "/test", auth) + s := New("", "test", auth) s.SetDebug(true) - s.SetTemplatePath("test/templates") - s.WithStaticFiles("/static", "test/static") + s.SetTemplatePath("/templates") + s.WithStaticFiles("/static") s.WithTheme("skin/") s.HandleFunc("/", handlerRoot) s.HandleFuncAuthenticated("/login", "", handlerLogin) @@ -148,3 +150,12 @@ func TestServer_TokenAuth(t *testing.T) { assert.Equal(t, http.StatusOK, code) assert.Equal(t, rootExpected, data) } + +func TestFilesystem(t *testing.T) { + fs := http.FileSystem(http.FileSystem(http.Dir("./test"))) + + s := New("", "test", nil) + tmpl, err := s.LoadTemplatesFS(fs, "/templates") + assert.Nil(t, err) + assert.NotNil(t, tmpl) +}