package policy

import (
	"context"
	"errors"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5/pgxpool"
)

// errPolicyNotFound is returned when a policy id does not refer to a
// non-soft-deleted policy. The message is the same text getByID has always
// produced, so callers that inspect the error string keep working.
var errPolicyNotFound = errors.New("policy not found")

// Repository is the persistence contract for access policies.
type Repository interface {
	List(ctx context.Context) ([]Policy, error)
	Create(ctx context.Context, input CreateRequest, createdBy uuid.UUID) (Policy, error)
	Update(ctx context.Context, policyID uuid.UUID, input UpdateRequest) (Policy, error)
	Delete(ctx context.Context, policyID uuid.UUID) error
	ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error)
}

// PGRepository implements Repository on top of a pgx PostgreSQL pool.
type PGRepository struct {
	db *pgxpool.Pool
}

// Compile-time check that PGRepository satisfies Repository.
var _ Repository = (*PGRepository)(nil)

// NewPGRepository returns a PGRepository that runs its queries against db.
func NewPGRepository(db *pgxpool.Pool) *PGRepository {
	return &PGRepository{db: db}
}

// policySelect selects every non-soft-deleted policy together with its
// destinations aggregated into a text array ('{}' when it has none).
// Callers may append an extra "and ..." filter and must then append
// policyGroupOrder.
const policySelect = `
	select p.id, p.name, p.description, p.priority, p.effect, p.full_tunnel, p.is_active,
	       coalesce(array_agg(pd.destination::text) filter (where pd.destination is not null), '{}')
	from policies p
	left join policy_destinations pd on pd.policy_id = p.id
	where p.deleted_at is null`

// policyGroupOrder completes a policySelect query: one row per policy,
// ordered by ascending priority, then newest first.
const policyGroupOrder = `
	group by p.id, p.name, p.description, p.priority, p.effect, p.full_tunnel, p.is_active, p.created_at
	order by p.priority asc, p.created_at desc`

// List returns all non-soft-deleted policies ordered by priority (ascending)
// and creation time (newest first), each with its destinations and targets.
// It returns a nil slice when there are no policies.
func (r *PGRepository) List(ctx context.Context) ([]Policy, error) {
	return r.queryPolicies(ctx, policySelect+policyGroupOrder)
}

// Create inserts a policy with its destinations and targets in a single
// transaction and returns the stored policy as read back by getByID.
// Destinations are cast to cidr by the database, so a malformed value makes
// the insert fail and the whole transaction is rolled back.
func (r *PGRepository) Create(ctx context.Context, input CreateRequest, createdBy uuid.UUID) (Policy, error) {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return Policy{}, err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback(ctx)

	policyID := uuid.New()
	_, err = tx.Exec(ctx, `
		insert into policies (id, name, description, priority, effect, full_tunnel, created_by)
		values ($1, $2, $3, $4, $5, $6, $7)
	`, policyID, input.Name, input.Description, input.Priority, input.Effect, input.FullTunnel, createdBy)
	if err != nil {
		return Policy{}, err
	}

	for _, destination := range input.Destinations {
		if _, err := tx.Exec(ctx, `
			insert into policy_destinations (id, policy_id, destination)
			values ($1, $2, $3::cidr)
		`, uuid.New(), policyID, destination); err != nil {
			return Policy{}, err
		}
	}

	for _, target := range input.Targets {
		if _, err := tx.Exec(ctx, `
			insert into policy_targets (id, policy_id, target_type, target_id)
			values ($1, $2, $3, $4)
		`, uuid.New(), policyID, target.Type, target.ID); err != nil {
			return Policy{}, err
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return Policy{}, err
	}
	return r.getByID(ctx, policyID)
}

// Update applies a partial update to a non-soft-deleted policy in one
// transaction. Scalar fields passed as NULL keep their current value (via
// coalesce). A non-nil Destinations or Targets slice replaces the existing
// set, and a nil one leaves it untouched. It returns errPolicyNotFound, and
// changes nothing, when the policy does not exist or is soft-deleted.
func (r *PGRepository) Update(ctx context.Context, policyID uuid.UUID, input UpdateRequest) (Policy, error) {
	tx, err := r.db.Begin(ctx)
	if err != nil {
		return Policy{}, err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback(ctx)

	tag, err := tx.Exec(ctx, `
		update policies
		set name = coalesce($2, name),
		    description = coalesce($3, description),
		    priority = coalesce($4, priority),
		    effect = coalesce($5, effect),
		    full_tunnel = coalesce($6, full_tunnel),
		    is_active = coalesce($7, is_active),
		    updated_at = now()
		where id = $1 and deleted_at is null
	`, policyID, input.Name, input.Description, input.Priority, input.Effect, input.FullTunnel, input.IsActive)
	if err != nil {
		return Policy{}, err
	}
	// Abort before touching child tables. Otherwise the destinations and
	// targets of a missing or soft-deleted policy would be rewritten and
	// committed even though the call then reports "policy not found".
	if tag.RowsAffected() == 0 {
		return Policy{}, errPolicyNotFound
	}

	if input.Destinations != nil {
		if _, err := tx.Exec(ctx, `delete from policy_destinations where policy_id = $1`, policyID); err != nil {
			return Policy{}, err
		}
		for _, destination := range input.Destinations {
			if _, err := tx.Exec(ctx, `
				insert into policy_destinations (id, policy_id, destination)
				values ($1, $2, $3::cidr)
			`, uuid.New(), policyID, destination); err != nil {
				return Policy{}, err
			}
		}
	}

	if input.Targets != nil {
		if _, err := tx.Exec(ctx, `delete from policy_targets where policy_id = $1`, policyID); err != nil {
			return Policy{}, err
		}
		for _, target := range input.Targets {
			if _, err := tx.Exec(ctx, `
				insert into policy_targets (id, policy_id, target_type, target_id)
				values ($1, $2, $3, $4)
			`, uuid.New(), policyID, target.Type, target.ID); err != nil {
				return Policy{}, err
			}
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return Policy{}, err
	}
	return r.getByID(ctx, policyID)
}

// Delete soft-deletes a policy by stamping deleted_at. Deleting a policy
// that does not exist or is already deleted is not an error.
func (r *PGRepository) Delete(ctx context.Context, policyID uuid.UUID) error {
	_, err := r.db.Exec(ctx, `update policies set deleted_at = now(), updated_at = now() where id = $1 and deleted_at is null`, policyID)
	return err
}

// ResolveDestinations returns the distinct destinations (as text) of every
// active, non-deleted "allow" policy that targets the user directly, targets
// a group the user belongs to, or (when deviceID is non-nil) targets the
// device. Policies with other effects are not considered here.
// NOTE(review): "deny" policies are not subtracted — confirm that callers
// handle deny/priority semantics elsewhere.
func (r *PGRepository) ResolveDestinations(ctx context.Context, userID uuid.UUID, deviceID *uuid.UUID) ([]string, error) {
	query := `
		select distinct pd.destination::text
		from policies p
		join policy_destinations pd on pd.policy_id = p.id
		join policy_targets pt on pt.policy_id = p.id
		where p.deleted_at is null
		  and p.is_active = true
		  and p.effect = 'allow'
		  and (
			(pt.target_type = 'user' and pt.target_id = $1)
			or (pt.target_type = 'group' and exists (
				select 1 from group_memberships gm
				where gm.group_id = pt.target_id and gm.user_id = $1
			))
	`
	args := []any{userID}
	// The device clause is only added when a device is known. Only fixed SQL
	// fragments are concatenated; all values remain bind parameters.
	if deviceID != nil {
		query += ` or (pt.target_type = 'device' and pt.target_id = $2)`
		args = append(args, *deviceID)
	}
	query += `)`

	rows, err := r.db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var destinations []string
	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return nil, err
		}
		destinations = append(destinations, value)
	}
	return destinations, rows.Err()
}

// getByID loads one non-soft-deleted policy with its destinations and
// targets. It returns errPolicyNotFound when no such policy exists.
func (r *PGRepository) getByID(ctx context.Context, policyID uuid.UUID) (Policy, error) {
	items, err := r.queryPolicies(ctx, policySelect+` and p.id = $1`+policyGroupOrder, policyID)
	if err != nil {
		return Policy{}, err
	}
	if len(items) == 0 {
		return Policy{}, errPolicyNotFound
	}
	return items[0], nil
}

// queryPolicies runs a policySelect-based query and then attaches each
// policy's targets.
//
// The policy rows are fully read and closed (in scanPolicies) before any
// listTargets call. A pool query keeps its connection checked out until its
// rows are closed, so querying targets while still iterating would need an
// extra connection per nesting level and can deadlock a small or saturated
// pool.
func (r *PGRepository) queryPolicies(ctx context.Context, query string, args ...any) ([]Policy, error) {
	items, err := r.scanPolicies(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	for i := range items {
		targets, err := r.listTargets(ctx, items[i].ID)
		if err != nil {
			return nil, err
		}
		items[i].Targets = targets
	}
	return items, nil
}

// scanPolicies runs a policySelect-based query and scans each row into a
// Policy without targets, closing the result set before returning.
func (r *PGRepository) scanPolicies(ctx context.Context, query string, args ...any) ([]Policy, error) {
	rows, err := r.db.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []Policy
	for rows.Next() {
		var item Policy
		if err := rows.Scan(&item.ID, &item.Name, &item.Description, &item.Priority, &item.Effect, &item.FullTunnel, &item.IsActive, &item.Destinations); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	return items, rows.Err()
}

// listTargets returns the targets of one policy in creation order. Each
// target's Name is the username, group name, or device name of the
// referenced entity, or "" when that entity is missing or soft-deleted.
func (r *PGRepository) listTargets(ctx context.Context, policyID uuid.UUID) ([]Target, error) {
	rows, err := r.db.Query(ctx, `
		select pt.target_type, pt.target_id, coalesce(u.username, g.name, d.name, '')
		from policy_targets pt
		left join users u on pt.target_type = 'user' and u.id = pt.target_id and u.deleted_at is null
		left join groups g on pt.target_type = 'group' and g.id = pt.target_id and g.deleted_at is null
		left join devices d on pt.target_type = 'device' and d.id = pt.target_id and d.deleted_at is null
		where pt.policy_id = $1
		order by pt.created_at asc
	`, policyID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var items []Target
	for rows.Next() {
		var item Target
		if err := rows.Scan(&item.Type, &item.ID, &item.Name); err != nil {
			return nil, err
		}
		items = append(items, item)
	}
	return items, rows.Err()
}