package sql

import (
	"database/sql"
	"errors"
	"fmt"

	sq "github.com/Masterminds/squirrel"
	"github.com/blockloop/scan/v2"
)

// Logger is the minimal logging dependency used to emit debug output for
// generated queries.
type Logger interface {
	Debugf(format string, args ...any)
}

// Database wraps a *sql.DB together with the squirrel placeholder format used
// when building statements and an optional debug logger.
type Database struct {
	sql               *sql.DB
	log               Logger
	placeholderFormat sq.PlaceholderFormat

	// DebugEnabled turns on query logging through the configured Logger.
	// Logging is skipped when no Logger has been set.
	DebugEnabled bool
}

// builder returns a squirrel statement builder configured with the database's
// placeholder format. When no format has been configured it returns squirrel's
// default builder instead of passing a nil format through.
func (db *Database) builder() sq.StatementBuilderType {
	if db.placeholderFormat == nil {
		return sq.StatementBuilder
	}
	return sq.StatementBuilder.PlaceholderFormat(db.placeholderFormat)
}

// logQuery writes query to the debug logger when DebugEnabled is set and a
// Logger is configured. args are accepted but not currently logged.
func (db *Database) logQuery(query string, args ...any) {
	// Guard the nil Logger: New does not install one, and DebugEnabled is
	// exported, so callers can enable logging without a Logger.
	if !db.DebugEnabled || db.log == nil {
		return
	}
	db.log.Debugf("%s\n", query)
}

// selectFrom builds `SELECT * FROM table`, adding a WHERE clause when pred is
// non-empty. pred[0] is the squirrel predicate (SQL fragment, sq.Eq, ...) and
// any remaining elements are bound as its arguments.
func (db *Database) selectFrom(table string, pred []any) sq.SelectBuilder {
	sb := db.builder().Select("*").From(table)
	if len(pred) > 0 {
		sb = sb.Where(pred[0], pred[1:]...)
	}
	return sb
}

// ObjectGet retrieves one row of table matching pred and scans it into obj
// via scan.Row (obj is presumably a pointer to a struct; verify at callers).
//
// It returns (nil, true) when a row was found, (nil, false) when no row
// matched, and (err, false) on any other failure. The (error, bool) result
// order is kept for compatibility with existing callers.
func (db *Database) ObjectGet(table string, pred []any, obj any) (error, bool) {
	// Only the first row is used, so let the database stop after one match
	// instead of streaming every matching row to the client.
	query, args, err := db.selectFrom(table, pred).Limit(1).ToSql()
	if err != nil {
		return err, false
	}
	db.logQuery(query, args...)

	// Query never reports sql.ErrNoRows; an empty result surfaces from scan.Row.
	rows, err := db.sql.Query(query, args...)
	if err != nil {
		return NewError(query, err), false
	}
	defer rows.Close()

	if err = scan.Row(obj, rows); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, false
		}
		return err, false
	}
	return nil, true
}

// Insert adds record to table, using GetRecordMap(record) as the
// column/value map.
func (db *Database) Insert(table string, record any) error {
	m := GetRecordMap(record)
	query, args, err := db.builder().Insert(table).SetMap(m).ToSql()
	if err != nil {
		return NewError(query, err)
	}
	db.logQuery(query, args...)

	stmt, err := db.sql.Prepare(query)
	if err != nil {
		return NewError(query, err)
	}
	// Close the statement on every path, including a failed Exec.
	defer stmt.Close()

	// Exec never reports sql.ErrNoRows. The driver error is returned as-is
	// (not wrapped with NewError), matching the previous behavior.
	if _, err = stmt.Exec(args...); err != nil {
		return err
	}
	return nil
}

// ListRows runs `SELECT * FROM table` filtered by pred (same layout as for
// ObjectGet) and ordered by order, then hands the open result set to scanFunc.
// The rows are closed when ListRows returns. If scanFunc succeeds, any error
// recorded on the result set during iteration is returned.
func (db *Database) ListRows(table string, pred []any, order []string, scanFunc func(rows *sql.Rows) error) error {
	sb := db.selectFrom(table, pred)
	if len(order) > 0 {
		sb = sb.OrderBy(order...)
	}
	query, args, err := sb.ToSql()
	if err != nil {
		return err
	}
	db.logQuery(query, args...)

	// Query never returns sql.ErrNoRows; an empty result is simply a Rows
	// with no entries.
	rows, err := db.sql.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if err = scanFunc(rows); err != nil {
		return err
	}
	// rows.Next reports a mid-iteration failure only by returning false;
	// surface it so a truncated result is not mistaken for a complete one.
	return rows.Err()
}

// Update writes record to the rows of table whose constraint column equals
// value. The constraint column itself is removed from the SET list. It
// returns sql.ErrNoRows when no row was affected, which InsertOrUpdate relies
// on.
func (db *Database) Update(table string, constraint string, value any, record any) error {
	m := GetRecordMap(record)
	delete(m, constraint)
	query, args, err := db.builder().
		Update(table).
		SetMap(m).
		Where(sq.Eq{constraint: value}).
		ToSql()
	if err != nil {
		return NewError(query, err)
	}
	db.logQuery(query, args...)

	stmt, err := db.sql.Prepare(query)
	if err != nil {
		return NewError(query, err)
	}
	// Close the statement on every path, including a failed Exec.
	defer stmt.Close()

	result, err := stmt.Exec(args...)
	if err != nil {
		return NewError(query, err)
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return sql.ErrNoRows
	}
	return nil
}

// InsertOrUpdate updates the row of table whose constraint column equals
// value, and inserts record instead when the update affected no rows.
//
// The update and the insert are separate statements with no surrounding
// transaction, so concurrent callers can race between the two steps.
func (db *Database) InsertOrUpdate(table string, constraint string, value any, record any) error {
	err := db.Update(table, constraint, value, record)
	if errors.Is(err, sql.ErrNoRows) {
		return db.Insert(table, record)
	}
	return err
}

// CountRows returns the number of rows in table that satisfy where, or the
// number of rows in the whole table when where is empty.
//
// SECURITY: table and where are interpolated into the SQL text verbatim, so
// they must never contain untrusted input.
func (db *Database) CountRows(table string, where string) (int, error) {
	query := fmt.Sprintf("select count(*) from %s", table)
	if where != "" {
		query += fmt.Sprintf(" WHERE %s", where)
	}
	query += ";"
	db.logQuery(query)

	var count int
	if err := db.sql.QueryRow(query).Scan(&count); err != nil {
		// An aggregate without GROUP BY always yields one row, so this branch
		// is purely defensive.
		if errors.Is(err, sql.ErrNoRows) {
			return 0, nil
		}
		return 0, NewError(query, err)
	}
	return count, nil
}

// Close closes the underlying connection pool. It is a no-op when no pool has
// been opened.
func (db *Database) Close() error {
	if db.sql == nil {
		return nil
	}
	return db.sql.Close()
}

// SQL returns the underlying *sql.DB for callers that need direct access,
// such as importing data.
func (db *Database) SQL() *sql.DB {
	return db.sql
}

// New opens a "postgres" connection pool for dbURN and verifies it with Ping.
// The returned Database builds statements with $N placeholders, the form that
// postgres drivers expect.
//
// NOTE(review): the "postgres" driver must be registered by an import
// elsewhere in the program; no such import is visible in this file.
func New(dbURN string) (*Database, error) {
	if dbURN == "" {
		return nil, errors.New("no database URN provided")
	}
	sqlDB, err := sql.Open("postgres", dbURN)
	if err != nil {
		return nil, err
	}
	if err = sqlDB.Ping(); err != nil {
		// Release the pool; the Ping error is the one worth reporting.
		_ = sqlDB.Close()
		return nil, err
	}
	return &Database{
		sql:               sqlDB,
		placeholderFormat: sq.Dollar,
	}, nil
}