You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

134 lines
3.7 KiB
Go

package migrate
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"sort"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
)
// migrationTable is the fully qualified name of the table that records which
// migration files have been applied. It lives in schema "infra" so it survives
// migrations that run DROP SCHEMA public CASCADE (e.g. 000001_init.sql).
// A public.schema_migrations table (for example one restored from a dump) is
// optional: its rows are only used to seed this table once, see
// copyLegacyPublicMigrations.
const migrationTable = "infra.schema_migrations"
// Run applies pending SQL migrations from dir in sorted filename order.
//
// It makes sure the infra schema and the tracking table exist, seeds the
// tracking table from a legacy public.schema_migrations table when applicable,
// and then executes every non-directory *.sql entry of dir that is not yet
// recorded in infra.schema_migrations. Each file runs in its own transaction
// together with the INSERT that records it, so a failing migration is rolled
// back and leaves no tracking row.
//
// Run stops at the first failure and returns that error wrapped with a
// "migrate:" prefix; migrations committed before the failure stay applied.
func Run(ctx context.Context, pool *pgxpool.Pool, dir string) error {
	if err := ensureTrackingTable(ctx, pool); err != nil {
		return err
	}
	if err := copyLegacyPublicMigrations(ctx, pool); err != nil {
		return err
	}

	files, err := listSQLFiles(dir)
	if err != nil {
		return err
	}
	applied, err := loadApplied(ctx, pool)
	if err != nil {
		return err
	}

	for _, f := range files {
		if applied[f] {
			continue
		}
		if err := applyOne(ctx, pool, dir, f); err != nil {
			return err
		}
		slog.Info("migrate: applied", "file", f)
	}
	return nil
}

// ensureTrackingTable creates the infra schema and the migration tracking
// table if they do not exist yet.
func ensureTrackingTable(ctx context.Context, pool *pgxpool.Pool) error {
	if _, err := pool.Exec(ctx, `CREATE SCHEMA IF NOT EXISTS infra`); err != nil {
		return fmt.Errorf("migrate: create infra schema: %w", err)
	}
	if _, err := pool.Exec(ctx, `CREATE TABLE IF NOT EXISTS `+migrationTable+` (
filename TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`); err != nil {
		return fmt.Errorf("migrate: create tracking table: %w", err)
	}
	return nil
}

// listSQLFiles returns the names (not paths) of the non-directory entries in
// dir whose name ends in ".sql", sorted lexically.
func listSQLFiles(dir string) ([]string, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("migrate: read dir %s: %w", dir, err)
	}
	var files []string
	for _, e := range entries {
		if !e.IsDir() && strings.HasSuffix(e.Name(), ".sql") {
			files = append(files, e.Name())
		}
	}
	// The apply order is defined by this sort, independent of how the
	// directory listing happened to be ordered.
	sort.Strings(files)
	return files, nil
}

// loadApplied returns the set of filenames recorded in the tracking table.
func loadApplied(ctx context.Context, pool *pgxpool.Pool) (map[string]bool, error) {
	rows, err := pool.Query(ctx, "SELECT filename FROM "+migrationTable)
	if err != nil {
		return nil, fmt.Errorf("migrate: query applied: %w", err)
	}
	defer rows.Close()

	applied := make(map[string]bool)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("migrate: scan: %w", err)
		}
		applied[name] = true
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("migrate: rows: %w", err)
	}
	return applied, nil
}

// applyOne executes the migration file dir/name and records name in the
// tracking table, both inside a single transaction.
func applyOne(ctx context.Context, pool *pgxpool.Pool, dir, name string) error {
	body, err := os.ReadFile(filepath.Join(dir, name))
	if err != nil {
		return fmt.Errorf("migrate: read %s: %w", name, err)
	}

	tx, err := pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("migrate: begin tx for %s: %w", name, err)
	}
	// One deferred rollback covers every early return below. Once Commit has
	// closed the transaction the call only reports that the tx is closed,
	// which is deliberately ignored.
	defer tx.Rollback(ctx) //nolint:errcheck

	if _, err := tx.Exec(ctx, string(body)); err != nil {
		return fmt.Errorf("migrate: exec %s: %w", name, err)
	}
	if _, err := tx.Exec(ctx, "INSERT INTO "+migrationTable+" (filename) VALUES ($1)", name); err != nil {
		return fmt.Errorf("migrate: record %s: %w", name, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("migrate: commit %s: %w", name, err)
	}
	return nil
}
// copyLegacyPublicMigrations seeds infra.schema_migrations from the legacy
// public.schema_migrations table used by deployments that predate the infra
// tracking table. The copy is attempted only while the infra table holds no
// rows and the legacy table exists; in every other case the function returns
// nil without touching the database further.
func copyLegacyPublicMigrations(ctx context.Context, pool *pgxpool.Pool) error {
	const legacyProbe = `SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = 'schema_migrations'
)`
	const copyStmt = `
INSERT INTO ` + migrationTable + ` (filename, applied_at)
SELECT filename, applied_at FROM public.schema_migrations
ON CONFLICT (filename) DO NOTHING
`

	var tracked int
	err := pool.QueryRow(ctx, `SELECT COUNT(*) FROM `+migrationTable).Scan(&tracked)
	if err != nil {
		return fmt.Errorf("migrate: count infra migrations: %w", err)
	}
	if tracked > 0 {
		// The infra table already has rows, so there is nothing to seed.
		return nil
	}

	var hasLegacy bool
	err = pool.QueryRow(ctx, legacyProbe).Scan(&hasLegacy)
	if err != nil {
		return fmt.Errorf("migrate: check legacy schema_migrations: %w", err)
	}
	if !hasLegacy {
		return nil
	}

	_, err = pool.Exec(ctx, copyStmt)
	if err != nil {
		return fmt.Errorf("migrate: copy legacy public.schema_migrations: %w", err)
	}
	slog.Info("migrate: copied applied migrations from public.schema_migrations to infra.schema_migrations")
	return nil
}