215 lines
4.3 KiB
Go
215 lines
4.3 KiB
Go
|
package sql
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
|
||
|
sq "github.com/Masterminds/squirrel"
|
||
|
"github.com/blockloop/scan/v2"
|
||
|
)
|
||
|
|
||
|
// Logger is the minimal logging dependency a Database needs. Debugf receives
// a printf-style format string followed by its operands.
type Logger interface {
	Debugf(msg string, v ...any)
}
|
||
|
|
||
|
type Database struct {
|
||
|
sql *sql.DB
|
||
|
log Logger
|
||
|
placeholderFormat sq.PlaceholderFormat
|
||
|
DebugEnabled bool
|
||
|
}
|
||
|
|
||
|
func (db *Database) builder() sq.StatementBuilderType {
|
||
|
return sq.StatementBuilder.PlaceholderFormat(db.placeholderFormat)
|
||
|
}
|
||
|
|
||
|
func (db *Database) logQuery(query string, args ...interface{}) {
|
||
|
if db.DebugEnabled {
|
||
|
db.log.Debugf("%s\n", query)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ObjectGet retrieves 1 object based on provided criteria
|
||
|
func (db *Database) ObjectGet(table string, pred []interface{}, obj interface{}) (error, bool) {
|
||
|
sb := db.builder().Select("*").From(table)
|
||
|
if len(pred) == 1 {
|
||
|
sb = sb.Where(pred[0])
|
||
|
} else if len(pred) > 1 {
|
||
|
sb = sb.Where(pred[0], pred[1:]...)
|
||
|
}
|
||
|
|
||
|
query, args, err := sb.ToSql()
|
||
|
if err != nil {
|
||
|
return err, false
|
||
|
}
|
||
|
|
||
|
rows, err := db.sql.Query(query, args...)
|
||
|
if err != nil {
|
||
|
if err == sql.ErrNoRows {
|
||
|
return nil, false
|
||
|
}
|
||
|
return NewError(query, err), false
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
|
||
|
err = scan.Row(obj, rows)
|
||
|
if err != nil {
|
||
|
if err == sql.ErrNoRows {
|
||
|
return nil, false
|
||
|
}
|
||
|
return err, false
|
||
|
}
|
||
|
return nil, true
|
||
|
}
|
||
|
|
||
|
func (db *Database) Insert(table string, record interface{}) error {
|
||
|
m := GetRecordMap(record)
|
||
|
|
||
|
query, args, err := db.builder().Insert(table).SetMap(m).ToSql()
|
||
|
if err != nil {
|
||
|
return NewError(query, err)
|
||
|
}
|
||
|
|
||
|
db.logQuery(query)
|
||
|
|
||
|
stmt, err := db.sql.Prepare(query)
|
||
|
if err != nil {
|
||
|
return NewError(query, err)
|
||
|
}
|
||
|
|
||
|
_, err = stmt.Exec(args...)
|
||
|
if err != nil {
|
||
|
return err // Concern maybe no rows?
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// ListRows performs a generic list from a db based on the passed in parameters
|
||
|
func (db *Database) ListRows(table string, pred []interface{}, order []string, scanFunc func(rows *sql.Rows) error) error {
|
||
|
sb := db.builder().Select("*").From(table)
|
||
|
if len(pred) == 1 {
|
||
|
sb = sb.Where(pred[0])
|
||
|
} else if len(pred) > 1 {
|
||
|
sb = sb.Where(pred[0], pred[1:]...)
|
||
|
}
|
||
|
|
||
|
if len(order) != 0 {
|
||
|
sb = sb.OrderBy(order...)
|
||
|
}
|
||
|
|
||
|
query, args, err := sb.ToSql()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
db.logQuery(query, args...)
|
||
|
|
||
|
rows, err := db.SQL().Query(query, args...)
|
||
|
if err != nil {
|
||
|
if err == sql.ErrNoRows {
|
||
|
return nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
defer rows.Close()
|
||
|
return scanFunc(rows)
|
||
|
}
|
||
|
|
||
|
func (db *Database) Update(table string, constraint string, value interface{}, record interface{}) error {
|
||
|
m := GetRecordMap(record)
|
||
|
delete(m, constraint)
|
||
|
|
||
|
query, args, err := db.builder().Update(table).SetMap(m).Where(sq.Eq{constraint: value}).ToSql()
|
||
|
if err != nil {
|
||
|
return NewError(query, err)
|
||
|
}
|
||
|
|
||
|
db.logQuery(query, args...)
|
||
|
|
||
|
stmt, err := db.sql.Prepare(query)
|
||
|
if err != nil {
|
||
|
return NewError(query, err)
|
||
|
}
|
||
|
|
||
|
result, err := stmt.Exec(args...)
|
||
|
if err != nil {
|
||
|
return NewError(query, err)
|
||
|
}
|
||
|
defer stmt.Close()
|
||
|
|
||
|
c, err := result.RowsAffected()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if c == 0 {
|
||
|
return sql.ErrNoRows
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (db *Database) InsertOrUpdate(table string, constraint string, value interface{}, record interface{}) error {
|
||
|
err := db.Update(table, constraint, value, record)
|
||
|
if err == sql.ErrNoRows {
|
||
|
return db.Insert(table, record)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// CountRows counts the number of rows matching the where clause
|
||
|
func (db *Database) CountRows(table string, where string) (int, error) {
|
||
|
query := fmt.Sprintf("select count(*) from %s", table)
|
||
|
|
||
|
if where != "" {
|
||
|
query += fmt.Sprintf(" WHERE %s", where)
|
||
|
}
|
||
|
query += ";"
|
||
|
|
||
|
db.logQuery(query)
|
||
|
|
||
|
count := 0
|
||
|
row := db.SQL().QueryRow(query)
|
||
|
err := row.Scan(&count)
|
||
|
if err != nil {
|
||
|
if err == sql.ErrNoRows {
|
||
|
return 0, nil
|
||
|
}
|
||
|
return 0, NewError(query, err)
|
||
|
}
|
||
|
return count, nil
|
||
|
}
|
||
|
|
||
|
// Close closes the db connection
|
||
|
func (db *Database) Close() error {
|
||
|
if db.sql != nil {
|
||
|
return db.sql.Close()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SQL returns the accessor to the sql driver under the covers.. used for importing data
|
||
|
func (db *Database) SQL() *sql.DB {
|
||
|
return db.sql
|
||
|
}
|
||
|
|
||
|
func New(dbURN string) (*Database, error) {
|
||
|
if dbURN == "" {
|
||
|
return nil, fmt.Errorf("no database URN provided")
|
||
|
}
|
||
|
|
||
|
sqlDB, err := sql.Open("postgres", dbURN)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
err = sqlDB.Ping()
|
||
|
if err != nil {
|
||
|
_ = sqlDB.Close()
|
||
|
return nil, err
|
||
|
}
|
||
|
return &Database{
|
||
|
sql: sqlDB,
|
||
|
}, nil
|
||
|
}
|