package migrate import ( "context" "fmt" "log/slog" "os" "path/filepath" "sort" "strings" "github.com/jackc/pgx/v5/pgxpool" ) // Tracking table lives in schema "infra" so it survives migrations that run // DROP SCHEMA public CASCADE (e.g. 000001_init.sql). public.schema_migrations // from dumps is optional/redundant. const migrationTable = "infra.schema_migrations" // Run applies pending SQL migrations from dir in sorted order. // Already-applied migrations (tracked in infra.schema_migrations) are skipped. func Run(ctx context.Context, pool *pgxpool.Pool, dir string) error { if _, err := pool.Exec(ctx, `CREATE SCHEMA IF NOT EXISTS infra`); err != nil { return fmt.Errorf("migrate: create infra schema: %w", err) } if _, err := pool.Exec(ctx, `CREATE TABLE IF NOT EXISTS `+migrationTable+` ( filename TEXT PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT now() )`); err != nil { return fmt.Errorf("migrate: create tracking table: %w", err) } if err := copyLegacyPublicMigrations(ctx, pool); err != nil { return err } entries, err := os.ReadDir(dir) if err != nil { return fmt.Errorf("migrate: read dir %s: %w", dir, err) } var files []string for _, e := range entries { if !e.IsDir() && strings.HasSuffix(e.Name(), ".sql") { files = append(files, e.Name()) } } sort.Strings(files) rows, err := pool.Query(ctx, "SELECT filename FROM "+migrationTable) if err != nil { return fmt.Errorf("migrate: query applied: %w", err) } defer rows.Close() applied := make(map[string]bool) for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return fmt.Errorf("migrate: scan: %w", err) } applied[name] = true } if err := rows.Err(); err != nil { return fmt.Errorf("migrate: rows: %w", err) } for _, f := range files { if applied[f] { continue } sql, err := os.ReadFile(filepath.Join(dir, f)) if err != nil { return fmt.Errorf("migrate: read %s: %w", f, err) } tx, err := pool.Begin(ctx) if err != nil { return fmt.Errorf("migrate: begin tx for %s: %w", f, err) } if _, err := tx.Exec(ctx, string(sql)); err != nil { tx.Rollback(ctx) //nolint:errcheck return fmt.Errorf("migrate: exec %s: %w", f, err) } if _, err := tx.Exec(ctx, "INSERT INTO "+migrationTable+" (filename) VALUES ($1)", f); err != nil { tx.Rollback(ctx) //nolint:errcheck return fmt.Errorf("migrate: record %s: %w", f, err) } if err := tx.Commit(ctx); err != nil { return fmt.Errorf("migrate: commit %s: %w", f, err) } slog.Info("migrate: applied", "file", f) } return nil } // copyLegacyPublicMigrations copies rows from public.schema_migrations once, if infra was empty // and the legacy table exists (deployments from before infra.schema_migrations). func copyLegacyPublicMigrations(ctx context.Context, pool *pgxpool.Pool) error { var infraCount int if err := pool.QueryRow(ctx, `SELECT COUNT(*) FROM `+migrationTable).Scan(&infraCount); err != nil { return fmt.Errorf("migrate: count infra migrations: %w", err) } if infraCount > 0 { return nil } var legacyExists bool q := `SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'schema_migrations' )` if err := pool.QueryRow(ctx, q).Scan(&legacyExists); err != nil { return fmt.Errorf("migrate: check legacy schema_migrations: %w", err) } if !legacyExists { return nil } if _, err := pool.Exec(ctx, ` INSERT INTO `+migrationTable+` (filename, applied_at) SELECT filename, applied_at FROM public.schema_migrations ON CONFLICT (filename) DO NOTHING `); err != nil { return fmt.Errorf("migrate: copy legacy public.schema_migrations: %w", err) } slog.Info("migrate: copied applied migrations from public.schema_migrations to infra.schema_migrations") return nil }