package database import ( "fmt" "github.com/Henelik/cms/pkg/config" "golang.org/x/crypto/bcrypt" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" ) func Migrate(db *gorm.DB, su config.User) error { if err := db.AutoMigrate(&User{}, &Role{}); err != nil { return fmt.Errorf("auto-migrate: %w", err) } defaultRoles := []Role{ {Name: "superadmin"}, {Name: "admin"}, {Name: "editor"}, {Name: "author"}, {Name: "viewer"}, } for _, r := range defaultRoles { if err := db.FirstOrCreate(&r, Role{Name: r.Name}).Error; err != nil { return fmt.Errorf("seed role %q: %w", r.Name, err) } } if su.Name != "" && su.Email != "" && su.Password != "" { if err := seedSuperUser(db, su); err != nil { return fmt.Errorf("seed superuser: %w", err) } } return nil } func seedSuperUser(db *gorm.DB, su config.User) error { hash, err := bcrypt.GenerateFromPassword([]byte(su.Password), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("hash password: %w", err) } user := User{ Name: su.Name, Email: su.Email, PasswordHash: string(hash), } if err := db.FirstOrCreate(&user, User{Email: su.Email}).Error; err != nil { return fmt.Errorf("create user: %w", err) } var superadmin Role if err := db.First(&superadmin, "name = ?", "superadmin").Error; err != nil { return fmt.Errorf("find superadmin role: %w", err) } // Check if user already has the superadmin role var count int64 if err := db.Table("user_roles"). Where("user_id = ? AND role_id = ?", user.ID, superadmin.ID). Count(&count).Error; err != nil { return fmt.Errorf("check role association: %w", err) } if count == 0 { if err := db.Model(&user).Association("Roles").Append(&superadmin); err != nil { return fmt.Errorf("assign superadmin role: %w", err) } } return nil } func NewDB(driver, dsn string) (*gorm.DB, error) { var dialector gorm.Dialector switch driver { case "sqlite": dialector = sqlite.Open(dsn) case "postgres": dialector = postgres.Open(dsn) default: return nil, fmt.Errorf("unsupported database driver: %s", driver) } db, err := gorm.Open(dialector, &gorm.Config{}) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } return db, nil }