NexaPG/backend/app/services/collector.py
nessi d9dfde1c87 [NX-102 Issue] Add exponential backoff with jitter for retry logic
Introduced an exponential backoff mechanism with a configurable base delay, maximum delay, and jitter factor for retrying failed targets. This improves resilience by reducing load on unreachable targets during repeated failures and by avoiding synchronized retry storms. Stale-target cleanup has also been added so that retry state is not retained for targets that no longer exist.
2026-02-14 11:44:49 +01:00
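
As a rough illustration of the retry schedule this commit produces, the sketch below recomputes the delay in isolation. It assumes a 15-second poll interval as the backoff base (the collector itself derives the base from settings.poll_interval_seconds with a floor of 3 seconds); the retry_delay helper is illustrative only and is not part of the module.

from random import uniform

BASE_SECONDS = 15      # assumed poll interval; the collector uses max(3, poll_interval_seconds)
MAX_SECONDS = 300      # cap on the retry delay, matching _backoff_max_seconds
JITTER_FACTOR = 0.15   # +/-15% randomization, matching _backoff_jitter_factor

def retry_delay(consecutive_failures: int) -> int:
    # Doubles per consecutive failure (exponent capped at 10), clamps to the
    # maximum delay, then applies +/-15% jitter so retries do not line up.
    raw = min(MAX_SECONDS, BASE_SECONDS * (2 ** min(consecutive_failures - 1, 10)))
    jitter = raw * JITTER_FACTOR
    return max(1, int(raw + uniform(-jitter, jitter)))

if __name__ == "__main__":
    # Ignoring jitter, the schedule is roughly 15s, 30s, 60s, 120s, 240s, then 300s thereafter.
    for n in range(1, 8):
        print(n, retry_delay(n))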

279 lines
11 KiB
Python

import asyncio
import logging
from datetime import datetime, timezone
from random import uniform

import asyncpg
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import get_settings
from app.core.db import SessionLocal
from app.models.models import Metric, QueryStat, Target
from app.services.crypto import decrypt_secret

logger = logging.getLogger(__name__)
settings = get_settings()

# Per-target retry state keyed by target id: consecutive failure count, first
# failure time, last logged error, and the earliest time of the next attempt.
_failure_state: dict[int, dict[str, object]] = {}
_failure_log_interval_seconds = 300
_backoff_base_seconds = max(3, int(settings.poll_interval_seconds))
_backoff_max_seconds = 300
_backoff_jitter_factor = 0.15


def build_target_dsn(target: Target) -> str:
    password = decrypt_secret(target.encrypted_password)
    return (
        f"postgresql://{target.username}:{password}"
        f"@{target.host}:{target.port}/{target.dbname}?sslmode={target.sslmode}"
    )


async def _store_metric(db: AsyncSession, target_id: int, name: str, value: float, labels: dict | None = None) -> None:
    db.add(
        Metric(
            target_id=target_id,
            ts=datetime.now(timezone.utc),
            metric_name=name,
            value=float(value),
            labels=labels or {},
        )
    )


async def collect_target(target: Target) -> None:
    dsn = build_target_dsn(target)
    conn = await asyncpg.connect(dsn=dsn)
    try:
        stat_db = await conn.fetchrow(
            """
            SELECT
                numbackends,
                xact_commit,
                xact_rollback,
                deadlocks,
                temp_files,
                temp_bytes,
                blk_read_time,
                blk_write_time,
                blks_hit,
                blks_read,
                tup_returned,
                tup_fetched
            FROM pg_stat_database
            WHERE datname = current_database()
            """
        )
        activity = await conn.fetchrow(
            """
            SELECT
                count(*) FILTER (WHERE state = 'active') AS active_connections,
                count(*) AS total_connections
            FROM pg_stat_activity
            WHERE datname = current_database()
            """
        )
        # PostgreSQL 17+ exposes checkpoint counters in pg_stat_checkpointer;
        # older versions keep them in pg_stat_bgwriter (queried below as a fallback).
        checkpointer_view_exists = await conn.fetchval(
            "SELECT to_regclass('pg_catalog.pg_stat_checkpointer') IS NOT NULL"
        )
        bgwriter = None
        if checkpointer_view_exists:
            try:
                bgwriter = await conn.fetchrow(
                    """
                    SELECT
                        num_timed AS checkpoints_timed,
                        num_requested AS checkpoints_req,
                        0::bigint AS buffers_checkpoint,
                        0::bigint AS buffers_clean,
                        0::bigint AS maxwritten_clean
                    FROM pg_stat_checkpointer
                    """
                )
            except Exception:
                bgwriter = None
        if bgwriter is None:
            try:
                bgwriter = await conn.fetchrow(
                    """
                    SELECT checkpoints_timed, checkpoints_req, buffers_checkpoint, buffers_clean, maxwritten_clean
                    FROM pg_stat_bgwriter
                    """
                )
            except Exception:
                bgwriter = None
        if stat_db is None:
            stat_db = {
                "numbackends": 0,
                "xact_commit": 0,
                "xact_rollback": 0,
                "deadlocks": 0,
                "temp_files": 0,
                "temp_bytes": 0,
                "blk_read_time": 0,
                "blk_write_time": 0,
                "blks_hit": 0,
                "blks_read": 0,
                "tup_returned": 0,
                "tup_fetched": 0,
            }
        if activity is None:
            activity = {"active_connections": 0, "total_connections": 0}
        if bgwriter is None:
            bgwriter = {
                "checkpoints_timed": 0,
                "checkpoints_req": 0,
                "buffers_checkpoint": 0,
                "buffers_clean": 0,
                "maxwritten_clean": 0,
            }
        lock_count = await conn.fetchval("SELECT count(*) FROM pg_locks")
        cache_hit_ratio = 0.0
        if stat_db and (stat_db["blks_hit"] + stat_db["blks_read"]) > 0:
            cache_hit_ratio = stat_db["blks_hit"] / (stat_db["blks_hit"] + stat_db["blks_read"])
        query_rows = []
        if target.use_pg_stat_statements:
            try:
                query_rows = await conn.fetch(
                    """
                    SELECT queryid::text, calls, total_exec_time, mean_exec_time, rows, left(query, 2000) AS query_text
                    FROM pg_stat_statements
                    ORDER BY total_exec_time DESC
                    LIMIT 20
                    """
                )
            except Exception:
                # Extension may be disabled on the monitored instance.
                query_rows = []
        async with SessionLocal() as db:
            await _store_metric(db, target.id, "connections_total", activity["total_connections"], {})
            await _store_metric(db, target.id, "connections_active", activity["active_connections"], {})
            await _store_metric(db, target.id, "xacts_total", stat_db["xact_commit"] + stat_db["xact_rollback"], {})
            await _store_metric(db, target.id, "xact_commit", stat_db["xact_commit"], {})
            await _store_metric(db, target.id, "xact_rollback", stat_db["xact_rollback"], {})
            await _store_metric(db, target.id, "deadlocks", stat_db["deadlocks"], {})
            await _store_metric(db, target.id, "temp_files", stat_db["temp_files"], {})
            await _store_metric(db, target.id, "temp_bytes", stat_db["temp_bytes"], {})
            await _store_metric(db, target.id, "blk_read_time", stat_db["blk_read_time"], {})
            await _store_metric(db, target.id, "blk_write_time", stat_db["blk_write_time"], {})
            await _store_metric(db, target.id, "cache_hit_ratio", cache_hit_ratio, {})
            await _store_metric(db, target.id, "locks_total", lock_count, {})
            await _store_metric(db, target.id, "checkpoints_timed", bgwriter["checkpoints_timed"], {})
            await _store_metric(db, target.id, "checkpoints_req", bgwriter["checkpoints_req"], {})
            for row in query_rows:
                db.add(
                    QueryStat(
                        target_id=target.id,
                        ts=datetime.now(timezone.utc),
                        queryid=row["queryid"] or "0",
                        calls=row["calls"] or 0,
                        total_time=row["total_exec_time"] or 0.0,
                        mean_time=row["mean_exec_time"] or 0.0,
                        rows=row["rows"] or 0,
                        query_text=row["query_text"],
                    )
                )
            await db.commit()
    finally:
        await conn.close()


async def collect_once() -> None:
    async with SessionLocal() as db:
        targets = (await db.scalars(select(Target))).all()
    # Drop retry state for targets that have been removed.
    active_target_ids = {target.id for target in targets}
    stale_target_ids = [target_id for target_id in _failure_state.keys() if target_id not in active_target_ids]
    for stale_target_id in stale_target_ids:
        _failure_state.pop(stale_target_id, None)
    for target in targets:
        now = datetime.now(timezone.utc)
        state = _failure_state.get(target.id)
        if state:
            # Skip this target until its backoff window has elapsed.
            next_attempt_at = state.get("next_attempt_at")
            if isinstance(next_attempt_at, datetime) and now < next_attempt_at:
                continue
        try:
            await collect_target(target)
            prev = _failure_state.pop(target.id, None)
            if prev:
                first_failure_at = prev.get("first_failure_at")
                downtime_seconds = None
                if isinstance(first_failure_at, datetime):
                    downtime_seconds = max(0, int((now - first_failure_at).total_seconds()))
                logger.info(
                    "collector_target_recovered target=%s after_failures=%s downtime_seconds=%s last_error=%s",
                    target.id,
                    prev.get("count", 0),
                    downtime_seconds,
                    prev.get("error"),
                )
        except (OSError, SQLAlchemyError, asyncpg.PostgresError, Exception) as exc:
            # Broad catch (Exception included) so one unreachable target cannot
            # abort the collection cycle for the remaining targets.
            current_error = str(exc)
            error_class = exc.__class__.__name__
            state = _failure_state.get(target.id)
            if state is None:
                # First failure: retry after the base delay, with jitter applied.
                next_delay = min(_backoff_max_seconds, _backoff_base_seconds)
                jitter = next_delay * _backoff_jitter_factor
                next_delay = max(1, int(next_delay + uniform(-jitter, jitter)))
                next_attempt_at = now.timestamp() + next_delay
                _failure_state[target.id] = {
                    "count": 1,
                    "first_failure_at": now,
                    "last_log_at": now,
                    "error": current_error,
                    "next_attempt_at": datetime.fromtimestamp(next_attempt_at, tz=timezone.utc),
                }
                logger.warning(
                    "collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s retry_in_seconds=%s",
                    target.id,
                    error_class,
                    current_error,
                    1,
                    next_delay,
                )
                continue
            count = int(state.get("count", 0)) + 1
            # Exponential backoff: double per consecutive failure (exponent capped at 10),
            # clamp to the maximum delay, then apply jitter to avoid synchronized retries.
            raw_backoff = min(_backoff_max_seconds, _backoff_base_seconds * (2 ** min(count - 1, 10)))
            jitter = raw_backoff * _backoff_jitter_factor
            next_delay = max(1, int(raw_backoff + uniform(-jitter, jitter)))
            state["next_attempt_at"] = datetime.fromtimestamp(now.timestamp() + next_delay, tz=timezone.utc)
            # Log at most once per _failure_log_interval_seconds unless the error changes.
            last_log_at = state.get("last_log_at")
            last_logged_error = str(state.get("error", ""))
            should_log = False
            if current_error != last_logged_error:
                should_log = True
            elif isinstance(last_log_at, datetime):
                should_log = (now - last_log_at).total_seconds() >= _failure_log_interval_seconds
            else:
                should_log = True
            state["count"] = count
            if should_log:
                state["last_log_at"] = now
                state["error"] = current_error
                logger.warning(
                    "collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s retry_in_seconds=%s",
                    target.id,
                    error_class,
                    current_error,
                    count,
                    next_delay,
                )


async def collector_loop(stop_event: asyncio.Event) -> None:
    while not stop_event.is_set():
        await collect_once()
        try:
            # Sleep until the next poll interval, waking early if shutdown is requested.
            await asyncio.wait_for(stop_event.wait(), timeout=settings.poll_interval_seconds)
        except asyncio.TimeoutError:
            pass