Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/infra/retry.py: 100%
82 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 23:36 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 23:36 +0000
1# SPDX-License-Identifier: MIT
2# Copyright © 2025 Bijan Mousavi
4"""Provides concrete asynchronous retry policy implementations.
6This module defines classes that implement the `RetryPolicyProtocol` to
7handle transient errors in asynchronous operations. It offers two main
8strategies:
10 * `TimeoutRetryPolicy`: A simple policy that applies a single timeout to an
11 operation.
12 * `ExponentialBackoffRetryPolicy`: A more advanced policy that retries an
13 operation multiple times with an exponentially increasing delay and
14 random jitter between attempts.
16These components are designed to be used by services to build resilience
17against temporary failures, such as network issues.
18"""
20from __future__ import annotations
22import asyncio
23from collections.abc import Awaitable, Callable
24from contextlib import AbstractAsyncContextManager, suppress
25import inspect
26import secrets
27from typing import Any, TypeVar, cast
29from injector import inject
31from bijux_cli.contracts import RetryPolicyProtocol, TelemetryProtocol
32from bijux_cli.core.exceptions import BijuxError
34T = TypeVar("T")
def _close_awaitable(obj: Any) -> None:
    """Invokes `obj.close()` on a best-effort basis.

    The `close` attribute is looked up on `obj`; when it exists and is
    callable it is called with no arguments, and any `Exception` it raises
    is swallowed. Objects without a callable `close` are left untouched.

    Args:
        obj (Any): The object to close.

    Returns:
        None:
    """
    closer = getattr(obj, "close", None)
    if not callable(closer):
        # Nothing to close (attribute missing or not callable).
        return
    with suppress(Exception):
        closer()
def _try_asyncio_timeout(
    seconds: float,
) -> AbstractAsyncContextManager[None] | None:
    """Builds an `asyncio.timeout` context manager when one is usable.

    The function returns `None` (meaning "use a fallback") when any of the
    following holds:

    * `asyncio` has no `timeout` attribute, or it is not callable;
    * the attribute's `__module__` is `"unittest.mock"` (i.e. it is a mock);
    * calling it raises `TypeError`, `ValueError` or `RuntimeError`;
    * the call returns an awaitable (which is closed before returning);
    * the returned object lacks `__aenter__` or `__aexit__`.

    Args:
        seconds (float): The timeout duration in seconds.

    Returns:
        AbstractAsyncContextManager[None] | None: A configured timeout context
            manager if supported, otherwise `None`.
    """
    factory = getattr(asyncio, "timeout", None)

    if factory is None or not callable(factory):
        return None
    if getattr(factory, "__module__", "") == "unittest.mock":
        return None

    try:
        manager = factory(seconds)
    except (TypeError, ValueError, RuntimeError):
        return None

    if inspect.isawaitable(manager):
        # An awaitable is not a context manager; close it so it is not left
        # dangling, then report "unsupported".
        _close_awaitable(manager)
        return None

    supports_async_with = hasattr(manager, "__aenter__") and hasattr(
        manager, "__aexit__"
    )
    if not supports_async_with:
        return None
    return cast(AbstractAsyncContextManager[None], manager)
async def _backoff_loop(
    supplier: Callable[[], Awaitable[T]],
    *,
    retries: int,
    delay: float,
    backoff: float,
    jitter: float,
    retry_on: tuple[type[BaseException], ...],
    telemetry: TelemetryProtocol,
) -> T:
    """Executes an async operation with an exponential-backoff retry loop.

    `supplier` is awaited up to `max(1, retries)` times. After a failed
    attempt number `attempt` (0-based) that is not the last one, the loop
    sleeps for `delay * backoff**attempt` seconds, adjusted by a random
    fraction drawn uniformly from `[-jitter, jitter]` when `jitter` is truthy.

    Only exceptions raised by `supplier` itself are retried: the success
    telemetry event is emitted outside the guarded section, so an error from
    the telemetry sink propagates instead of re-running an operation that
    already succeeded.

    Args:
        supplier (Callable[[], Awaitable[T]]): The async function to execute.
        retries (int): The total number of attempts; values below 1 are
            treated as a single attempt.
        delay (float): The initial delay in seconds before the first retry.
        backoff (float): The multiplier applied to the delay after each failure.
        jitter (float): The random fractional jitter to apply to each delay.
        retry_on (tuple[type[BaseException], ...]): A tuple of exception types
            that will trigger a retry.
        telemetry (TelemetryProtocol): The service for emitting telemetry events.

    Returns:
        T: The result from a successful call to the `supplier`.

    Raises:
        BaseException: The last exception raised by `supplier` if all attempts
            fail, or immediately any exception not listed in `retry_on`.
        RuntimeError: If the loop finishes without returning or raising, which
            should be unreachable.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            result = await supplier()
        except retry_on as exc:
            if attempt + 1 == attempts:
                # Out of attempts: report and re-raise the original exception.
                telemetry.event(
                    "retry_async_failed", {"retries": retries, "error": str(exc)}
                )
                raise
            sleep_for = delay * (backoff**attempt)
            if jitter:
                sleep_for += sleep_for * secrets.SystemRandom().uniform(-jitter, jitter)
            await asyncio.sleep(sleep_for)
        else:
            # Kept out of the `try` body so a failing telemetry sink is never
            # mistaken for a supplier failure.
            telemetry.event("retry_async_success", {"retries": attempt})
            return result
    raise RuntimeError("Unreachable code")  # pragma: no cover
class TimeoutRetryPolicy(RetryPolicyProtocol):
    """A retry policy that applies a single, one-time timeout to an operation.

    No retries are performed: the operation is awaited once and a timeout is
    converted into a `BijuxError`.

    Attributes:
        _telemetry (TelemetryProtocol): The service for emitting telemetry events.
    """

    @inject
    def __init__(self, telemetry: TelemetryProtocol) -> None:
        """Initializes the `TimeoutRetryPolicy`.

        Args:
            telemetry (TelemetryProtocol): The service for emitting events.
        """
        self._telemetry = telemetry

    async def run(
        self,
        supplier: Callable[[], Awaitable[T]],
        seconds: float = 1.0,
    ) -> T:
        """Executes an awaitable `supplier` with a single timeout.

        This method uses the modern `asyncio.timeout` context manager if
        available, otherwise it falls back to `asyncio.wait_for`. On success a
        `retry_timeout_success` event is emitted; on timeout a
        `retry_timeout_failed` event is emitted before raising.

        Args:
            supplier (Callable[[], Awaitable[T]]): The async operation to run.
            seconds (float): The timeout duration in seconds. Must be positive.

        Returns:
            T: The result of the `supplier` if it completes in time.

        Raises:
            ValueError: If `seconds` is less than or equal to 0.
            BijuxError: If the operation times out (HTTP status 504).
        """
        if seconds <= 0:
            raise ValueError("seconds must be > 0")

        ctx = _try_asyncio_timeout(seconds)

        try:
            if ctx is not None:
                async with ctx:
                    result = await supplier()
            else:
                result = await asyncio.wait_for(supplier(), timeout=seconds)

            self._telemetry.event("retry_timeout_success", {"seconds": seconds})
            return result

        # `asyncio.TimeoutError` is only an alias of the builtin `TimeoutError`
        # from Python 3.11 on. The `wait_for` fallback runs on older
        # interpreters, where it is a distinct class, so both are listed.
        except (TimeoutError, asyncio.TimeoutError) as exc:
            self._telemetry.event(
                "retry_timeout_failed", {"seconds": seconds, "error": str(exc)}
            )
            raise BijuxError(
                f"Operation timed out after {seconds}s", http_status=504
            ) from exc

    def reset(self) -> None:
        """Emits a `retry_reset` telemetry event; the policy keeps no other state."""
        self._telemetry.event("retry_reset", {})
class ExponentialBackoffRetryPolicy(RetryPolicyProtocol):
    """A retry policy with exponential backoff, jitter, and per-attempt timeouts.

    Attributes:
        _telemetry (TelemetryProtocol): The service for emitting telemetry events.
    """

    @inject
    def __init__(self, telemetry: TelemetryProtocol) -> None:
        """Initializes the `ExponentialBackoffRetryPolicy`.

        Args:
            telemetry (TelemetryProtocol): The service for emitting events.
        """
        self._telemetry = telemetry

    async def run(
        self,
        supplier: Callable[[], Awaitable[T]],
        seconds: float = 1.0,
        retries: int = 3,
        delay: float = 1.0,
        backoff: float = 2.0,
        jitter: float = 0.3,
        retry_on: tuple[type[BaseException], ...] = (Exception,),
    ) -> T:
        """Executes a supplier with a timeout and exponential-backoff retries.

        Every attempt gets its own fresh timeout of `seconds`; the backoff
        sleeps between attempts do not count against it. A timed-out attempt
        raises a timeout error inside the retry loop, so it is retried like
        any other failure when the error type matches `retry_on`.

        Args:
            supplier (Callable[[], Awaitable[T]]): The async operation to run.
            seconds (float): The timeout for each attempt in seconds. Must be > 0.
            retries (int): The attempt budget handed to the retry loop.
            delay (float): The initial delay in seconds before the first retry.
            backoff (float): The multiplier for the delay after each failure.
            jitter (float): The random fractional jitter to apply to each delay.
            retry_on (tuple[type[BaseException], ...]): A tuple of exception
                types that will trigger a retry.

        Returns:
            T: The result of the `supplier` if one of the attempts succeeds.

        Raises:
            ValueError: If `seconds` is less than or equal to 0.
            BaseException: The last exception raised by `supplier` if all
                attempts fail.
        """
        if seconds <= 0:
            raise ValueError("seconds must be > 0")

        async def timed_supplier() -> T:
            """Runs one attempt of `supplier` under its own timeout.

            Returns:
                T: The result of the `supplier` if it completes in time.
            """
            # The context manager is built per attempt: creating it once
            # around the whole loop would turn `seconds` into a total budget
            # shared by all attempts and the sleeps between them.
            ctx = _try_asyncio_timeout(seconds)
            if ctx is not None:
                async with ctx:
                    return await supplier()
            return await asyncio.wait_for(supplier(), timeout=seconds)

        return await _backoff_loop(
            timed_supplier,
            retries=retries,
            delay=delay,
            backoff=backoff,
            jitter=jitter,
            retry_on=retry_on,
            telemetry=self._telemetry,
        )

    def reset(self) -> None:
        """Emits a `retry_reset` telemetry event; the policy keeps no other state."""
        self._telemetry.event("retry_reset", {})
296__all__ = ["TimeoutRetryPolicy", "ExponentialBackoffRetryPolicy"]