[NX-102 Issue] Add exponential backoff with jitter for retry logic

Introduced an exponential backoff mechanism for retrying failed targets, with a base delay derived from the poll interval (minimum 3 seconds), a maximum delay of 300 seconds, and a jitter factor of ±15%. This improves resilience by reducing load during repeated failures and by avoiding synchronized retry storms. Additionally, stale target cleanup has been implemented so that failure state is no longer retained for targets that have been removed.
This commit is contained in:
2026-02-14 11:44:49 +01:00
parent 117710cc0a
commit d9dfde1c87

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from random import uniform
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@@ -15,6 +16,9 @@ logger = logging.getLogger(__name__)
settings = get_settings() settings = get_settings()
_failure_state: dict[int, dict[str, object]] = {} _failure_state: dict[int, dict[str, object]] = {}
_failure_log_interval_seconds = 300 _failure_log_interval_seconds = 300
_backoff_base_seconds = max(3, int(settings.poll_interval_seconds))
_backoff_max_seconds = 300
_backoff_jitter_factor = 0.15
def build_target_dsn(target: Target) -> str: def build_target_dsn(target: Target) -> str:
@@ -181,38 +185,66 @@ async def collect_once() -> None:
async with SessionLocal() as db: async with SessionLocal() as db:
targets = (await db.scalars(select(Target))).all() targets = (await db.scalars(select(Target))).all()
active_target_ids = {target.id for target in targets}
stale_target_ids = [target_id for target_id in _failure_state.keys() if target_id not in active_target_ids]
for stale_target_id in stale_target_ids:
_failure_state.pop(stale_target_id, None)
for target in targets: for target in targets:
now = datetime.now(timezone.utc)
state = _failure_state.get(target.id)
if state:
next_attempt_at = state.get("next_attempt_at")
if isinstance(next_attempt_at, datetime) and now < next_attempt_at:
continue
try: try:
await collect_target(target) await collect_target(target)
prev = _failure_state.pop(target.id, None) prev = _failure_state.pop(target.id, None)
if prev: if prev:
first_failure_at = prev.get("first_failure_at")
downtime_seconds = None
if isinstance(first_failure_at, datetime):
downtime_seconds = max(0, int((now - first_failure_at).total_seconds()))
logger.info( logger.info(
"collector_target_recovered target=%s after_failures=%s last_error=%s", "collector_target_recovered target=%s after_failures=%s downtime_seconds=%s last_error=%s",
target.id, target.id,
prev.get("count", 0), prev.get("count", 0),
downtime_seconds,
prev.get("error"), prev.get("error"),
) )
except (OSError, SQLAlchemyError, asyncpg.PostgresError, Exception) as exc: except (OSError, SQLAlchemyError, asyncpg.PostgresError, Exception) as exc:
now = datetime.now(timezone.utc)
current_error = str(exc) current_error = str(exc)
error_class = exc.__class__.__name__ error_class = exc.__class__.__name__
state = _failure_state.get(target.id) state = _failure_state.get(target.id)
if state is None: if state is None:
next_delay = min(_backoff_max_seconds, _backoff_base_seconds)
jitter = next_delay * _backoff_jitter_factor
next_delay = max(1, int(next_delay + uniform(-jitter, jitter)))
next_attempt_at = now.timestamp() + next_delay
_failure_state[target.id] = { _failure_state[target.id] = {
"count": 1, "count": 1,
"first_failure_at": now,
"last_log_at": now, "last_log_at": now,
"error": current_error, "error": current_error,
"next_attempt_at": datetime.fromtimestamp(next_attempt_at, tz=timezone.utc),
} }
logger.warning( logger.warning(
"collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s", "collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s retry_in_seconds=%s",
target.id, target.id,
error_class, error_class,
current_error, current_error,
1, 1,
next_delay,
) )
continue continue
count = int(state.get("count", 0)) + 1 count = int(state.get("count", 0)) + 1
raw_backoff = min(_backoff_max_seconds, _backoff_base_seconds * (2 ** min(count - 1, 10)))
jitter = raw_backoff * _backoff_jitter_factor
next_delay = max(1, int(raw_backoff + uniform(-jitter, jitter)))
state["next_attempt_at"] = datetime.fromtimestamp(now.timestamp() + next_delay, tz=timezone.utc)
last_log_at = state.get("last_log_at") last_log_at = state.get("last_log_at")
last_logged_error = str(state.get("error", "")) last_logged_error = str(state.get("error", ""))
should_log = False should_log = False
@@ -228,11 +260,12 @@ async def collect_once() -> None:
state["last_log_at"] = now state["last_log_at"] = now
state["error"] = current_error state["error"] = current_error
logger.warning( logger.warning(
"collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s", "collector_target_unreachable target=%s error_class=%s err=%s consecutive_failures=%s retry_in_seconds=%s",
target.id, target.id,
error_class, error_class,
current_error, current_error,
count, count,
next_delay,
) )