core/db/sql/db.go

215 lines
4.3 KiB
Go
Raw Normal View History

2024-01-01 16:55:54 +00:00
package sql
import (
"database/sql"
"fmt"
sq "github.com/Masterminds/squirrel"
"github.com/blockloop/scan/v2"
)
// Logger is the logging dependency of Database. It exposes a single
// printf-style debug method, which query tracing uses.
type Logger interface {
	// Debugf receives a printf-style format string followed by the values
	// that fill it.
	Debugf(msg string, v ...any)
}
// Database wraps a *sql.DB together with the settings used to build and
// trace the queries issued through it.
type Database struct {
	// sql is the underlying handle; Close and SQL operate on it.
	sql *sql.DB
	// log receives query text from logQuery when DebugEnabled is true.
	log Logger
	// placeholderFormat is installed on every squirrel builder returned by
	// builder(); it controls the bind-parameter style (e.g. "?" vs "$1").
	placeholderFormat sq.PlaceholderFormat
	// DebugEnabled turns on query logging through log.
	DebugEnabled bool
}
// builder returns a squirrel statement builder that uses this database's
// placeholder format.
//
// When placeholderFormat was never set, it falls back to sq.Dollar. Handing
// squirrel a nil PlaceholderFormat leaves it with no way to rewrite "?" bind
// markers when ToSql runs. The only driver this package opens is "postgres"
// (see New), and PostgreSQL requires "$n" placeholders.
func (db *Database) builder() sq.StatementBuilderType {
	format := db.placeholderFormat
	if format == nil {
		format = sq.Dollar
	}
	return sq.StatementBuilder.PlaceholderFormat(format)
}
// logQuery writes query (and its bind arguments, when there are any) to the
// debug logger. It does nothing unless DebugEnabled is set.
//
// A nil logger is tolerated rather than dereferenced. This lets
// DebugEnabled be toggled on a Database that was built without a Logger.
func (db *Database) logQuery(query string, args ...interface{}) {
	if !db.DebugEnabled || db.log == nil {
		return
	}
	if len(args) == 0 {
		db.log.Debugf("%s\n", query)
		return
	}
	db.log.Debugf("%s %v\n", query, args)
}
// ObjectGet retrieves one row from table matching pred and scans it into obj.
//
// pred is forwarded to squirrel's Where: pred[0] is the predicate (for
// example an sq.Eq map or a SQL fragment) and any remaining elements are its
// bind arguments. An empty pred applies no filter. obj is handed to
// scan.Row; presumably it must be a pointer to a struct (confirm against the
// scan library).
//
// The bool reports whether a row was found: (nil, false) means no row
// matched. The (error, bool) order is unusual but kept for existing callers.
func (db *Database) ObjectGet(table string, pred []interface{}, obj interface{}) (error, bool) {
	// Only the first row is ever scanned, so don't make the server produce
	// (and the driver drain) the rest of the result set.
	sb := db.builder().Select("*").From(table).Limit(1)
	switch {
	case len(pred) == 1:
		sb = sb.Where(pred[0])
	case len(pred) > 1:
		sb = sb.Where(pred[0], pred[1:]...)
	}
	query, args, err := sb.ToSql()
	if err != nil {
		return err, false
	}
	db.logQuery(query, args...)
	rows, err := db.sql.Query(query, args...)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, false
		}
		return NewError(query, err), false
	}
	defer rows.Close()
	if err := scan.Row(obj, rows); err != nil {
		// scan.Row signals an empty result set with sql.ErrNoRows.
		if err == sql.ErrNoRows {
			return nil, false
		}
		return err, false
	}
	return nil, true
}
// Insert writes record to table as a single new row.
//
// Column/value pairs come from GetRecordMap(record). Presumably that is a
// column-name -> value map; confirm against GetRecordMap. Build and prepare
// failures are wrapped with NewError; the Exec error is returned as-is.
func (db *Database) Insert(table string, record interface{}) error {
	m := GetRecordMap(record)
	query, args, err := db.builder().Insert(table).SetMap(m).ToSql()
	if err != nil {
		return NewError(query, err)
	}
	db.logQuery(query, args...)
	stmt, err := db.sql.Prepare(query)
	if err != nil {
		return NewError(query, err)
	}
	// Deferred immediately so the statement is released even when Exec fails.
	defer stmt.Close()
	if _, err := stmt.Exec(args...); err != nil {
		return err
	}
	return nil
}
// ListRows selects every column from table, optionally filtered by pred and
// ordered by order, and hands the open result set to scanFunc.
//
// pred follows the squirrel Where convention: pred[0] is the predicate and
// pred[1:] are its bind arguments; an empty pred applies no filter. The rows
// are closed after scanFunc returns, and scanFunc's error is returned
// unchanged. An sql.ErrNoRows from Query is treated as an empty result.
func (db *Database) ListRows(table string, pred []interface{}, order []string, scanFunc func(rows *sql.Rows) error) error {
	stmt := db.builder().Select("*").From(table)
	switch n := len(pred); {
	case n == 1:
		stmt = stmt.Where(pred[0])
	case n > 1:
		stmt = stmt.Where(pred[0], pred[1:]...)
	}
	if len(order) > 0 {
		stmt = stmt.OrderBy(order...)
	}

	q, qArgs, err := stmt.ToSql()
	if err != nil {
		return err
	}
	db.logQuery(q, qArgs...)

	rs, err := db.SQL().Query(q, qArgs...)
	switch {
	case err == sql.ErrNoRows:
		return nil
	case err != nil:
		return err
	}
	defer rs.Close()
	return scanFunc(rs)
}
// Update overwrites the row(s) of table whose constraint column equals
// value, using the fields of record.
//
// The record is converted with GetRecordMap. The constraint column is removed
// from the SET list so the key itself is never rewritten. Returns
// sql.ErrNoRows when the statement affected no rows; InsertOrUpdate relies on
// that to fall back to Insert.
func (db *Database) Update(table string, constraint string, value interface{}, record interface{}) error {
	m := GetRecordMap(record)
	delete(m, constraint)
	query, args, err := db.builder().Update(table).SetMap(m).Where(sq.Eq{constraint: value}).ToSql()
	if err != nil {
		return NewError(query, err)
	}
	db.logQuery(query, args...)
	stmt, err := db.sql.Prepare(query)
	if err != nil {
		return NewError(query, err)
	}
	// Deferred immediately so the statement is released even when Exec fails.
	defer stmt.Close()
	result, err := stmt.Exec(args...)
	if err != nil {
		return NewError(query, err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return sql.ErrNoRows
	}
	return nil
}
// InsertOrUpdate first tries Update. If Update reports sql.ErrNoRows (no row
// matched constraint = value), it inserts record instead. Any other Update
// result, including success, is returned directly.
//
// The update and the insert are two separate statements with no enclosing
// transaction. Concurrent callers may therefore race between them (for
// example, both see no row and both insert).
func (db *Database) InsertOrUpdate(table string, constraint string, value interface{}, record interface{}) error {
	if err := db.Update(table, constraint, value, record); err != sql.ErrNoRows {
		return err
	}
	return db.Insert(table, record)
}
// CountRows counts the number of rows matching the where clause
func (db *Database) CountRows(table string, where string) (int, error) {
query := fmt.Sprintf("select count(*) from %s", table)
if where != "" {
query += fmt.Sprintf(" WHERE %s", where)
}
query += ";"
db.logQuery(query)
count := 0
row := db.SQL().QueryRow(query)
err := row.Scan(&count)
if err != nil {
if err == sql.ErrNoRows {
return 0, nil
}
return 0, NewError(query, err)
}
return count, nil
}
// Close releases the underlying *sql.DB. When no handle was ever attached it
// does nothing and returns nil.
func (db *Database) Close() error {
	if db.sql == nil {
		return nil
	}
	return db.sql.Close()
}
// SQL exposes the raw *sql.DB behind this Database. It is intended for work
// that bypasses the query builder, such as importing data.
func (db *Database) SQL() *sql.DB {
	handle := db.sql
	return handle
}
// New opens a database handle for dbURN with the "postgres" driver and
// verifies connectivity with Ping.
//
// A driver registered under the name "postgres" must be linked in elsewhere;
// presumably this is lib/pq or pgx's stdlib shim (confirm). The returned
// Database is configured for PostgreSQL's "$n" bind placeholders. Debug
// logging is off and no Logger is attached. On a failed Ping, the handle is
// closed before returning.
func New(dbURN string) (*Database, error) {
	if dbURN == "" {
		return nil, fmt.Errorf("no database URN provided")
	}
	sqlDB, err := sql.Open("postgres", dbURN)
	if err != nil {
		return nil, fmt.Errorf("opening database: %w", err)
	}
	if err := sqlDB.Ping(); err != nil {
		// Best-effort cleanup; the Ping failure is the error worth reporting.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("connecting to database: %w", err)
	}
	return &Database{
		sql: sqlDB,
		// PostgreSQL rejects "?" markers; squirrel must emit $1, $2, ...
		placeholderFormat: sq.Dollar,
	}, nil
}