Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/infra/retry.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 23:36 +0000

1# SPDX-License-Identifier: MIT 

2# Copyright © 2025 Bijan Mousavi 

3 

4"""Provides concrete asynchronous retry policy implementations. 

5 

6This module defines classes that implement the `RetryPolicyProtocol` to 

7handle transient errors in asynchronous operations. It offers two main 

8strategies: 

9 

10 * `TimeoutRetryPolicy`: A simple policy that applies a single timeout to an 

11 operation. 

12 * `ExponentialBackoffRetryPolicy`: A more advanced policy that retries an 

13 operation multiple times with an exponentially increasing delay and 

14 random jitter between attempts. 

15 

16These components are designed to be used by services to build resilience 

17against temporary failures, such as network issues. 

18""" 

19 

20from __future__ import annotations 

21 

22import asyncio 

23from collections.abc import Awaitable, Callable 

24from contextlib import AbstractAsyncContextManager, suppress 

25import inspect 

26import secrets 

27from typing import Any, TypeVar, cast 

28 

29from injector import inject 

30 

31from bijux_cli.contracts import RetryPolicyProtocol, TelemetryProtocol 

32from bijux_cli.core.exceptions import BijuxError 

33 

34T = TypeVar("T") 

35 

36 

37def _close_awaitable(obj: Any) -> None: 

38 """Safely closes an object if it has a synchronous `close()` method. 

39 

40 This helper performs a best-effort call to `obj.close()`, suppressing any 

41 exceptions that may be raised. 

42 

43 Args: 

44 obj (Any): The object to close. 

45 

46 Returns: 

47 None: 

48 """ 

49 close = getattr(obj, "close", None) 

50 if callable(close): 

51 with suppress(Exception): 

52 close() 

53 

54 

55def _try_asyncio_timeout( 

56 seconds: float, 

57) -> AbstractAsyncContextManager[None] | None: 

58 """Returns an `asyncio.timeout` context manager if it is available and usable. 

59 

60 This function checks if `asyncio.timeout` is a real, usable implementation, 

61 avoiding mocks or incompatible objects that may exist in some test 

62 environments. 

63 

64 Args: 

65 seconds (float): The timeout duration in seconds. 

66 

67 Returns: 

68 AbstractAsyncContextManager[None] | None: A configured timeout context 

69 manager if supported, otherwise `None`. 

70 """ 

71 async_timeout = getattr(asyncio, "timeout", None) 

72 

73 if ( 

74 async_timeout is None 

75 or not callable(async_timeout) 

76 or getattr(async_timeout, "__module__", "") == "unittest.mock" 

77 ): 

78 return None 

79 

80 try: 

81 candidate = async_timeout(seconds) 

82 except (TypeError, ValueError, RuntimeError): 

83 return None 

84 

85 if inspect.isawaitable(candidate): 

86 _close_awaitable(candidate) 

87 return None 

88 

89 if hasattr(candidate, "__aenter__") and hasattr(candidate, "__aexit__"): 

90 return cast(AbstractAsyncContextManager[None], candidate) 

91 

92 return None 

93 

94 

95async def _backoff_loop( 

96 supplier: Callable[[], Awaitable[T]], 

97 *, 

98 retries: int, 

99 delay: float, 

100 backoff: float, 

101 jitter: float, 

102 retry_on: tuple[type[BaseException], ...], 

103 telemetry: TelemetryProtocol, 

104) -> T: 

105 """Executes an async operation with an exponential-backoff retry loop. 

106 

107 Args: 

108 supplier (Callable[[], Awaitable[T]]): The async function to execute. 

109 retries (int): The maximum number of retry attempts. 

110 delay (float): The initial delay in seconds before the first retry. 

111 backoff (float): The multiplier applied to the delay after each failure. 

112 jitter (float): The random fractional jitter to apply to each delay. 

113 retry_on (tuple[type[BaseException], ...]): A tuple of exception types 

114 that will trigger a retry. 

115 telemetry (TelemetryProtocol): The service for emitting telemetry events. 

116 

117 Returns: 

118 T: The result from a successful call to the `supplier`. 

119 

120 Raises: 

121 BaseException: The last exception raised by `supplier` if all retries fail. 

122 RuntimeError: If the loop finishes without returning or raising, which 

123 should be unreachable. 

124 """ 

125 attempts = max(1, retries) 

126 for attempt in range(attempts): 

127 try: 

128 result = await supplier() 

129 telemetry.event("retry_async_success", {"retries": attempt}) 

130 return result 

131 except retry_on as exc: 

132 if attempt + 1 == attempts: 

133 telemetry.event( 

134 "retry_async_failed", {"retries": retries, "error": str(exc)} 

135 ) 

136 raise 

137 sleep_for = delay * (backoff**attempt) 

138 if jitter: 

139 sleep_for += sleep_for * secrets.SystemRandom().uniform(-jitter, jitter) 

140 await asyncio.sleep(sleep_for) 

141 raise RuntimeError("Unreachable code") # pragma: no cover 

142 

143 

class TimeoutRetryPolicy(RetryPolicyProtocol):
    """A retry policy that applies a single, one-time timeout to an operation.

    The operation is attempted exactly once; no retries are performed.

    Attributes:
        _telemetry (TelemetryProtocol): The service for emitting telemetry events.
    """

    @inject
    def __init__(self, telemetry: TelemetryProtocol) -> None:
        """Initializes the `TimeoutRetryPolicy`.

        Args:
            telemetry (TelemetryProtocol): The service for emitting events.
        """
        self._telemetry = telemetry

    async def run(
        self,
        supplier: Callable[[], Awaitable[T]],
        seconds: float = 1.0,
    ) -> T:
        """Executes an awaitable `supplier` with a single timeout.

        This method uses the `asyncio.timeout` context manager when it is
        available and usable, and otherwise falls back to `asyncio.wait_for`.
        A `retry_timeout_success` or `retry_timeout_failed` telemetry event is
        emitted depending on the outcome.

        Args:
            supplier (Callable[[], Awaitable[T]]): The async operation to run.
            seconds (float): The timeout duration in seconds. Must be positive.

        Returns:
            T: The result of the `supplier` if it completes in time.

        Raises:
            ValueError: If `seconds` is less than or equal to 0.
            BijuxError: If the operation times out (raised with
                `http_status=504`, chained from the timeout exception).
        """
        if seconds <= 0:
            raise ValueError("seconds must be > 0")

        ctx = _try_asyncio_timeout(seconds)

        try:
            if ctx is not None:
                async with ctx:
                    result = await supplier()
            else:
                result = await asyncio.wait_for(supplier(), timeout=seconds)
        except (TimeoutError, asyncio.TimeoutError) as exc:
            # The `wait_for` fallback runs on interpreters without
            # `asyncio.timeout`; there `asyncio.TimeoutError` is a distinct
            # class from the builtin `TimeoutError`, so both must be caught.
            self._telemetry.event(
                "retry_timeout_failed", {"seconds": seconds, "error": str(exc)}
            )
            raise BijuxError(
                f"Operation timed out after {seconds}s", http_status=504
            ) from exc

        self._telemetry.event("retry_timeout_success", {"seconds": seconds})
        return result

    def reset(self) -> None:
        """Resets the retry policy state.

        This policy keeps no state, so the only effect is emitting a
        `retry_reset` telemetry event.
        """
        self._telemetry.event("retry_reset", {})

207 

208 

class ExponentialBackoffRetryPolicy(RetryPolicyProtocol):
    """A retry policy with exponential backoff, jitter, and per-attempt timeouts.

    Attributes:
        _telemetry (TelemetryProtocol): The service for emitting telemetry events.
    """

    @inject
    def __init__(self, telemetry: TelemetryProtocol) -> None:
        """Initializes the `ExponentialBackoffRetryPolicy`.

        Args:
            telemetry (TelemetryProtocol): The service for emitting events.
        """
        self._telemetry = telemetry

    async def run(
        self,
        supplier: Callable[[], Awaitable[T]],
        seconds: float = 1.0,
        retries: int = 3,
        delay: float = 1.0,
        backoff: float = 2.0,
        jitter: float = 0.3,
        retry_on: tuple[type[BaseException], ...] = (Exception,),
    ) -> T:
        """Executes a supplier with a per-attempt timeout and backoff retries.

        Every attempt gets its own `seconds` budget. The backoff sleeps between
        attempts are not counted against it. A timed-out attempt raises a
        timeout exception, which is retried like any other failure when it
        matches `retry_on`.

        Args:
            supplier (Callable[[], Awaitable[T]]): The async operation to run.
            seconds (float): The timeout for each attempt in seconds. Must be > 0.
            retries (int): The total number of attempts; at least one attempt
                is always made.
            delay (float): The sleep in seconds after the first failed attempt.
            backoff (float): The multiplier for the delay after each failure.
            jitter (float): The random fractional jitter to apply to each delay.
            retry_on (tuple[type[BaseException], ...]): A tuple of exception
                types that will trigger a retry.

        Returns:
            T: The result of the `supplier` if one of the attempts succeeds.

        Raises:
            ValueError: If `seconds` is less than or equal to 0.
            BaseException: The last exception raised by an attempt (including
                a per-attempt timeout) if all attempts fail, or any exception
                not covered by `retry_on`.
        """
        if seconds <= 0:
            raise ValueError("seconds must be > 0")

        async def timed_supplier() -> T:
            """Runs a single attempt of `supplier` under its own timeout.

            A new timeout context manager is created for each attempt, so
            every attempt starts with a full `seconds` budget. When
            `asyncio.timeout` is unusable, `asyncio.wait_for` is used instead.

            Returns:
                T: The result of the `supplier` if it completes in time.
            """
            ctx = _try_asyncio_timeout(seconds)
            if ctx is None:
                return await asyncio.wait_for(supplier(), timeout=seconds)
            async with ctx:
                return await supplier()

        # Previously the `asyncio.timeout` path wrapped the whole loop
        # (sleeps included) in one budget, so with the defaults the first
        # backoff sleep used up the timeout and no retry ever ran.
        return await _backoff_loop(
            timed_supplier,
            retries=retries,
            delay=delay,
            backoff=backoff,
            jitter=jitter,
            retry_on=retry_on,
            telemetry=self._telemetry,
        )

    def reset(self) -> None:
        """Resets the retry policy state.

        This policy keeps no state, so the only effect is emitting a
        `retry_reset` telemetry event.
        """
        self._telemetry.event("retry_reset", {})

294 

295 

# Public API of this module: the two concrete retry policies.
__all__ = [
    "TimeoutRetryPolicy",
    "ExponentialBackoffRetryPolicy",
]