mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-06-18 23:20:51 +00:00
199 lines
4.2 KiB
Go
199 lines
4.2 KiB
Go
|
// Copyright 2025 The Forgejo Authors. All rights reserved.
|
||
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
||
|
|
||
|
package unittest
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"encoding/hex"
|
||
|
"encoding/json" //nolint:depguard
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"strings"
|
||
|
|
||
|
"gopkg.in/yaml.v3"
|
||
|
)
|
||
|
|
||
|
type (
	// insertSQL is one parameterized INSERT statement built from a single
	// fixture record, together with the arguments bound to its placeholders,
	// listed in placeholder order.
	insertSQL struct {
		statement string
		values    []any
	}

	// fixtureFile is the parsed form of one fixture file. name is the target
	// table (the file's base name without its extension) and insertSQLs are
	// the statements that populate that table.
	fixtureFile struct {
		name       string
		insertSQLs []insertSQL
	}

	// loader loads parsed fixture files into db. dialect selects identifier
	// quoting and placeholder syntax; the recognized values are "sqlite3",
	// "mysql" and "postgres".
	loader struct {
		db      *sql.DB
		dialect string

		// fixtureFiles are applied by Load in this order.
		fixtureFiles []*fixtureFile
	}
)
|
||
|
|
||
|
func newFixtureLoader(db *sql.DB, dialect string, fixturePaths []string) (*loader, error) {
|
||
|
l := &loader{
|
||
|
db: db,
|
||
|
dialect: dialect,
|
||
|
fixtureFiles: []*fixtureFile{},
|
||
|
}
|
||
|
|
||
|
// Load fixtures
|
||
|
for _, fixturePath := range fixturePaths {
|
||
|
stat, err := os.Stat(fixturePath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// If fixture path is a directory, then read read the files of the directory
|
||
|
// and use those as fixture files.
|
||
|
if stat.IsDir() {
|
||
|
files, err := os.ReadDir(fixturePath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
for _, file := range files {
|
||
|
if !file.IsDir() {
|
||
|
fixtureFile, err := l.buildFixtureFile(filepath.Join(fixturePath, file.Name()))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
l.fixtureFiles = append(l.fixtureFiles, fixtureFile)
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
fixtureFile, err := l.buildFixtureFile(fixturePath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
l.fixtureFiles = append(l.fixtureFiles, fixtureFile)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return l, nil
|
||
|
}
|
||
|
|
||
|
// quoteKeyword returns the quoted string of keyword.
|
||
|
func (l *loader) quoteKeyword(keyword string) string {
|
||
|
switch l.dialect {
|
||
|
case "sqlite3":
|
||
|
return `"` + keyword + `"`
|
||
|
case "mysql":
|
||
|
return "`" + keyword + "`"
|
||
|
case "postgres":
|
||
|
parts := strings.Split(keyword, ".")
|
||
|
for i, p := range parts {
|
||
|
parts[i] = `"` + p + `"`
|
||
|
}
|
||
|
return strings.Join(parts, ".")
|
||
|
default:
|
||
|
return "invalid"
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// placeholder returns the placeholder string.
|
||
|
func (l *loader) placeholder(index int) string {
|
||
|
if l.dialect == "postgres" {
|
||
|
return fmt.Sprintf("$%d", index)
|
||
|
}
|
||
|
return "?"
|
||
|
}
|
||
|
|
||
|
func (l *loader) buildFixtureFile(fixturePath string) (*fixtureFile, error) {
|
||
|
f, err := os.Open(fixturePath)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
var records []map[string]any
|
||
|
if err := yaml.NewDecoder(f).Decode(&records); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
fixture := &fixtureFile{
|
||
|
name: filepath.Base(strings.TrimSuffix(f.Name(), filepath.Ext(f.Name()))),
|
||
|
insertSQLs: []insertSQL{},
|
||
|
}
|
||
|
|
||
|
for _, record := range records {
|
||
|
columns := []string{}
|
||
|
sqlValues := []string{}
|
||
|
values := []any{}
|
||
|
i := 1
|
||
|
|
||
|
for key, value := range record {
|
||
|
columns = append(columns, l.quoteKeyword(key))
|
||
|
|
||
|
switch v := value.(type) {
|
||
|
case string:
|
||
|
// Try to decode hex.
|
||
|
if strings.HasPrefix(v, "0x") {
|
||
|
value, err = hex.DecodeString(strings.TrimPrefix(v, "0x"))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
case []any:
|
||
|
// Decode array.
|
||
|
var bytes []byte
|
||
|
bytes, err = json.Marshal(v)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
value = string(bytes)
|
||
|
}
|
||
|
|
||
|
values = append(values, value)
|
||
|
|
||
|
sqlValues = append(sqlValues, l.placeholder(i))
|
||
|
i++
|
||
|
}
|
||
|
|
||
|
// Construct the insert SQL.
|
||
|
fixture.insertSQLs = append(fixture.insertSQLs, insertSQL{
|
||
|
statement: fmt.Sprintf(
|
||
|
"INSERT INTO %s (%s) VALUES (%s)",
|
||
|
l.quoteKeyword(fixture.name),
|
||
|
strings.Join(columns, ", "),
|
||
|
strings.Join(sqlValues, ", "),
|
||
|
),
|
||
|
values: values,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return fixture, nil
|
||
|
}
|
||
|
|
||
|
func (l *loader) Load() error {
|
||
|
// Start transaction.
|
||
|
tx, err := l.db.Begin()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
_ = tx.Rollback()
|
||
|
}()
|
||
|
|
||
|
// Clean the table and re-insert the fixtures.
|
||
|
tableDeleted := map[string]struct{}{}
|
||
|
for _, fixture := range l.fixtureFiles {
|
||
|
if _, ok := tableDeleted[fixture.name]; !ok {
|
||
|
if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", l.quoteKeyword(fixture.name))); err != nil {
|
||
|
return fmt.Errorf("cannot delete table %s: %w", fixture.name, err)
|
||
|
}
|
||
|
tableDeleted[fixture.name] = struct{}{}
|
||
|
}
|
||
|
|
||
|
for _, insertSQL := range fixture.insertSQLs {
|
||
|
if _, err := tx.Exec(insertSQL.statement, insertSQL.values...); err != nil {
|
||
|
return fmt.Errorf("cannot insert %q with values %q: %w", insertSQL.statement, insertSQL.values, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return tx.Commit()
|
||
|
}
|