Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/core/di.py: 96%
298 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 23:36 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 23:36 +0000
1# SPDX-License-Identifier: MIT
2# Copyright © 2025 Bijan Mousavi
4"""Provides the central dependency injection container for the Bijux CLI.
6This module defines the `DIContainer` class, a thread-safe singleton that
7manages the registration, resolution, and lifecycle of all services within the
8application. It allows components to be loosely coupled by requesting
9dependencies based on abstract protocols rather than concrete classes.
11Key features include:
12 * Singleton pattern for global access via `DIContainer.current()`.
13 * Thread-safe operations for concurrent environments.
14 * Lazy instantiation of services upon first request.
15 * Support for named registrations to allow multiple implementations of the
16 same protocol.
17 * Both synchronous (`resolve`) and asynchronous (`resolve_async`) service
18 resolution.
19 * Circular dependency detection.
20"""
22from __future__ import annotations
24import asyncio
25from collections.abc import Awaitable, Callable, Coroutine, Iterator, Sequence
26from contextlib import contextmanager, suppress
27from contextvars import ContextVar
28import inspect
29import logging
30import os
31from threading import RLock
32from typing import Any, Literal, TypeVar, cast, overload
34from injector import Binder, Injector, Module, singleton
36from bijux_cli.contracts import ConfigProtocol, ObservabilityProtocol
37from bijux_cli.core.exceptions import BijuxError
38from bijux_cli.services.config import Config
# Generic placeholder for the concrete service type in register/resolve signatures.
T = TypeVar("T")
# Identity-only marker: the container stores it in its service cache while a
# factory is still running, and always compares against it with `is`.
_SENTINEL: object = object()
def _key_name(key: type[Any] | str) -> str:
    """Renders a DI service key as a readable label.

    Args:
        key (type[Any] | str): A service key, either a string identifier or a
            type such as a class or protocol.

    Returns:
        str: The key itself for string keys, the object's ``__name__`` when it
        has one, and ``str(key)`` as a last resort.
    """
    if not isinstance(key, str):
        try:
            label = key.__name__
        except AttributeError:
            label = str(key)
        return label
    return key
class AppConfigModule(Module):
    """`injector` module that wires the CLI's core bindings.

    It contributes a single binding: `ConfigProtocol` is served by the
    concrete `Config` class in singleton scope.
    """

    def configure(self, binder: Binder) -> None:
        """Installs this module's bindings on ``binder``.

        Args:
            binder (Binder): The binder handed over by `injector` while it
                builds the injector.

        Returns:
            None:
        """
        binder.bind(
            ConfigProtocol,
            to=Config,
            scope=singleton,
        )
class DIContainer:
    """A thread-safe, singleton dependency injection container.

    This class manages the lifecycle of services, including registration of
    factories, lazy instantiation, and resolution. It integrates with an
    underlying `injector` for basic services and handles custom named
    registrations and circular dependency detection.

    Attributes:
        _instance (DIContainer | None): The singleton instance of the container.
        _lock (RLock): A reentrant lock guarding the singleton slot and the
            internal stores.
        _resolving (ContextVar): A context variable holding the names of the
            services currently being resolved in this context, used for
            circular dependency detection.
        _obs (ObservabilityProtocol | None): A cached reference to the logging
            service for internal use.
        _injector (Injector): The underlying `injector` library instance.
        _store (dict): A mapping of (key, name) tuples to registered factories
            or values.
        _services (dict): A cache of resolved service instances. While a
            factory is running, its slot holds the private `_SENTINEL` marker.
    """

    _instance: DIContainer | None = None
    _lock = RLock()
    _resolving: ContextVar[set[str] | None] = ContextVar("resolving", default=None)
    _obs: ObservabilityProtocol | None = None

    # ------------------------------------------------------------------ #
    # Singleton lifecycle
    # ------------------------------------------------------------------ #

    @classmethod
    def current(cls) -> DIContainer:
        """Returns the active singleton instance of the `DIContainer`.

        The instance is created on first use, and a warning is logged when
        that happens.

        Returns:
            DIContainer: The singleton instance.
        """
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
                cls._log_static(
                    logging.WARNING, "DIContainer.current auto-initialized singleton"
                )
            return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Resets the singleton instance, shutting down all services.

        This method is intended for use in testing environments to ensure a
        clean state between tests. It clears all registered services and
        factories. Errors raised while shutting down are logged, not raised.
        """
        inst = cls._detach_instance()
        if inst is None:
            cls._log_static(logging.DEBUG, "DIContainer reset (no instance)")
            return
        cls._shutdown_blocking(inst, "Error during shutdown")
        cls._clear_state(inst)
        cls._log_static(logging.DEBUG, "DIContainer reset")

    @classmethod
    async def reset_async(cls) -> None:
        """Asynchronously resets the singleton instance.

        This method is intended for use in testing environments. All services
        and factories are cleared, even if shutting down a service raises.
        """
        instance = cls._detach_instance()
        if instance is not None:
            try:
                await instance.shutdown()
            finally:
                cls._clear_state(instance)
        cls._log_static(logging.DEBUG, "DIContainer reset")

    @classmethod
    def _detach_instance(cls) -> DIContainer | None:
        """Clears the singleton slot under the lock and returns the old instance.

        Returns:
            DIContainer | None: The instance that was active, if any.
        """
        with cls._lock:
            instance = cls._instance
            cls._instance = None
            cls._obs = None
        return instance

    @staticmethod
    def _clear_state(instance: DIContainer) -> None:
        """Drops every cached service, registration and logger reference.

        Args:
            instance (DIContainer): The container to wipe.
        """
        instance._services.clear()
        instance._store.clear()
        instance._obs = None

    @classmethod
    def _shutdown_blocking(cls, inst: DIContainer, error_prefix: str) -> None:
        """Runs `inst.shutdown()` to completion from synchronous code.

        Errors are logged as ``"<error_prefix>: <exc>"`` instead of being
        raised. `asyncio.run` refuses to start while another event loop is
        running in this thread. The coroutine is therefore always closed
        afterwards, which avoids a "coroutine was never awaited" warning.
        Closing a coroutine that already finished is a no-op.

        Args:
            inst (DIContainer): The container to shut down.
            error_prefix (str): Prefix for the error log message.
        """
        coro = inst.shutdown()
        try:
            asyncio.run(coro)
        except Exception as exc:
            cls._log_static(logging.ERROR, f"{error_prefix}: {exc}")
        finally:
            coro.close()

    def __new__(cls) -> DIContainer:
        """Creates or returns the singleton instance of the container."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self) -> None:
        """Initializes the container's internal stores.

        This method is idempotent; it does nothing if the container has already
        been initialized.
        """
        if getattr(self, "_initialised", False):
            return
        self._injector = Injector(AppConfigModule())
        self._store: dict[
            tuple[type[Any] | str, str | None], Callable[[], Any | Awaitable[Any]] | Any
        ] = {}
        self._services: dict[tuple[type[Any] | str, str | None], Any] = {}
        self._obs: ObservabilityProtocol | None = None
        self._initialised = True
        self._log_static(logging.INFO, "DIContainer initialised")

    # ------------------------------------------------------------------ #
    # Registration
    # ------------------------------------------------------------------ #

    def register(
        self,
        key: type[T] | str,
        factory_or_value: Callable[[], T | Awaitable[T]] | T,
        name: str | None = None,
    ) -> None:
        """Registers a factory or a pre-resolved value for a given service key.

        Re-registering a key does not evict an instance that was already
        resolved and cached for it. Use `unregister` or `override` for that.

        Args:
            key (type[T] | str): The service key, which can be a protocol type
                or a unique string identifier.
            factory_or_value: The factory function that creates the service,
                or the service instance itself.
            name (str | None): An optional name for the registration, allowing
                multiple implementations of the same key.

        Returns:
            None:

        Raises:
            BijuxError: If the registration key is invalid or conflicts with an
                existing registration.
        """
        if not (isinstance(key, str) or inspect.isclass(key)):
            raise BijuxError("Service key must be a type or str", http_status=400)
        try:
            # Hold the lock across the conflict scan *and* the insert so a
            # concurrent register/unregister cannot mutate `_store` mid-scan.
            with self._lock:
                if isinstance(key, str) and any(
                    isinstance(k, type) and k.__name__ == key for k, _ in self._store
                ):
                    raise BijuxError(
                        f"Key {key} conflicts with existing type name", http_status=400
                    )
                if isinstance(key, type) and any(
                    k == key.__name__ for k, _ in self._store
                ):
                    raise BijuxError(
                        f"Type {key.__name__} conflicts with existing string key",
                        http_status=400,
                    )
                self._store[(key, name)] = factory_or_value
                if isinstance(
                    factory_or_value, ObservabilityProtocol
                ) and not isinstance(factory_or_value, type):
                    self._obs = factory_or_value
            self._log(
                logging.DEBUG,
                "Registered service",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
        except (TypeError, KeyError) as exc:
            self._log(
                logging.ERROR,
                f"Failed to register service: {exc}",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
            raise BijuxError(
                f"Failed to register service {_key_name(key)}: {exc}", http_status=400
            ) from exc

    # ------------------------------------------------------------------ #
    # Resolution
    # ------------------------------------------------------------------ #

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: Literal[False]
    ) -> T:  # pragma: no cover
        ...

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: Literal[True]
    ) -> T | Awaitable[T]:  # pragma: no cover
        ...

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: bool
    ) -> T | Awaitable[T]:  # pragma: no cover
        ...

    def _resolve_common(
        self,
        key: type[T] | str,
        name: str | None = None,
        *,
        async_mode: bool = False,
    ) -> T | Awaitable[T]:
        """Handles the core logic for resolving a service instance.

        This internal method implements the resolution strategy:
            1. Check for a cached instance.
            2. If not cached, check for a registered factory.
            3. If no factory, attempt resolution via the underlying `injector`.
            4. If a factory is found, execute it, handling circular dependencies
               and both sync/async factories.
            5. Cache and return the result.

        In `async_mode`, an awaitable factory result is returned as a coroutine
        that finishes steps 4-5 when awaited. The result is then cached and
        validated exactly like a synchronous result.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.
            async_mode (bool): If True, allows returning an awaitable if the
                factory is async.

        Returns:
            T | Awaitable[T]: The resolved service instance, or a coroutine
                that will resolve to the instance if `async_mode` is True.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If a circular dependency is detected, the factory fails,
                or the factory returns None (status 424).
        """
        name_str = f"{_key_name(key)}:{name}" if name else _key_name(key)
        resolving = self._resolving.get() or set()
        if name_str in resolving:
            self._log(
                logging.ERROR,
                f"Circular dependency detected for {name_str}",
                extra={"service_name": name_str},
            )
            raise BijuxError(
                f"Circular dependency detected for {name_str}", http_status=400
            )
        store_key = (key, name)
        with self._lock:
            cached = self._services.get(store_key, _SENTINEL)
            if cached is not _SENTINEL:
                self._log(
                    logging.DEBUG,
                    f"Resolved service: {type(cached).__name__}",
                    extra={"service_name": name_str},
                )
                return cast(T, cached)
            if store_key not in self._store:
                return cast(T, self._resolve_via_injector(key, store_key, name_str))
            factory = self._store[store_key]
            # Mark the slot as "in flight" until the factory finishes.
            self._services[store_key] = _SENTINEL
        token = self._resolving.set(resolving | {name_str})
        try:
            # Only plain functions/methods are invoked. Classes and other
            # objects are treated as pre-built values.
            is_function_like = (
                inspect.isfunction(factory)
                or inspect.ismethod(factory)
                or inspect.iscoroutinefunction(factory)
            )
            if os.getenv("VERBOSE_DI") and not os.getenv("BIJUXCLI_TEST_MODE"):
                self._log(
                    logging.DEBUG,
                    f"Executing factory for service: {name_str}",
                    extra={"service_name": name_str},
                )
            raw = factory() if is_function_like else factory
            if inspect.isawaitable(raw):
                if async_mode:
                    return self._complete_async(key, store_key, name_str, raw)
                result = self._run_awaitable_sync(raw)
            else:
                result = raw
        except (KeyError, TypeError, RuntimeError):
            self._fail_pending(
                store_key, f"Service resolution failed: {_key_name(key)}", name_str
            )
            raise
        except Exception as exc:
            self._fail_pending(store_key, f"Factory failed: {exc}", name_str)
            raise BijuxError(
                f"Factory for {name_str} raised: {exc}", http_status=400
            ) from exc
        except BaseException:
            # KeyboardInterrupt / SystemExit / CancelledError must propagate
            # unchanged, but the in-flight placeholder still has to go.
            self._discard_pending(store_key)
            raise
        finally:
            self._resolving.reset(token)
        return cast(T, self._cache_result(store_key, name_str, result))

    def _resolve_via_injector(
        self,
        key: type[Any] | str,
        store_key: tuple[type[Any] | str, str | None],
        name_str: str,
    ) -> Any:
        """Resolves a key that has no registration through the `injector`.

        Called with `self._lock` held. Only type keys are eligible. A string
        key without a registration fails immediately.

        Raises:
            KeyError: If the key is not a type or the injector cannot build it.
        """
        if not isinstance(key, type):
            self._log(
                logging.ERROR,
                "Service not registered",
                extra={"service_name": name_str},
            )
            raise KeyError(f"Service not registered: {name_str}")
        try:
            result = self._injector.get(key)
        except Exception as exc:
            self._log(
                logging.ERROR,
                "Service not registered via injector",
                extra={"service_name": name_str},
            )
            raise KeyError(f"Service not registered: {name_str}") from exc
        self._services[store_key] = result
        self._log(
            logging.DEBUG,
            f"Resolved service via injector: {type(result).__name__}",
            extra={"service_name": name_str},
        )
        return result

    @staticmethod
    def _run_awaitable_sync(pending: Awaitable[Any]) -> Any:
        """Drives an awaitable to completion from synchronous code.

        With no running loop, `asyncio.run` is used. With a running loop, its
        `run_until_complete` is attempted, which only succeeds if the loop
        allows re-entrant runs. A loop object without `run_until_complete`
        triggers an immediate RuntimeError. The coroutine is always closed
        afterwards so an unstarted one is not reported as never awaited.

        Raises:
            RuntimeError: If the awaitable cannot be run from this context.
        """

        async def _drive(awaitable: Awaitable[Any]) -> Any:
            return await awaitable

        coro = cast(
            Coroutine[Any, Any, Any],
            pending if asyncio.iscoroutine(pending) else _drive(pending),
        )
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        try:
            if loop is None:
                return asyncio.run(coro)
            if not hasattr(loop, "run_until_complete"):
                raise RuntimeError(
                    "Cannot sync-resolve while an event loop is running"
                )
            return loop.run_until_complete(coro)
        finally:
            with suppress(Exception):
                coro.close()

    async def _complete_async(
        self,
        key: type[Any] | str,
        store_key: tuple[type[Any] | str, str | None],
        name_str: str,
        pending: Awaitable[Any],
    ) -> Any:
        """Awaits an async factory result, then validates and caches it.

        The service name stays in the resolving set while awaiting, so
        circular dependencies inside async factories are detected too.
        Failures are handled like the synchronous path in `_resolve_common`.
        """
        resolving = self._resolving.get() or set()
        token = self._resolving.set(resolving | {name_str})
        try:
            result = await pending
        except (KeyError, TypeError, RuntimeError):
            self._fail_pending(
                store_key, f"Service resolution failed: {_key_name(key)}", name_str
            )
            raise
        except Exception as exc:
            self._fail_pending(store_key, f"Factory failed: {exc}", name_str)
            raise BijuxError(
                f"Factory for {name_str} raised: {exc}", http_status=400
            ) from exc
        except BaseException:
            self._discard_pending(store_key)
            raise
        finally:
            self._resolving.reset(token)
        return self._cache_result(store_key, name_str, result)

    def _cache_result(
        self,
        store_key: tuple[type[Any] | str, str | None],
        name_str: str,
        result: Any,
    ) -> Any:
        """Validates a factory result, caches it and returns it.

        Raises:
            BijuxError: With status 424 if the factory produced None.
        """
        if result is None:
            self._discard_pending(store_key)
            self._log(
                logging.ERROR,
                "Factory returned None",
                extra={"service_name": name_str},
            )
            raise BijuxError(f"Factory for {name_str} returned None", http_status=424)
        with self._lock:
            self._services[store_key] = result
            if isinstance(result, ObservabilityProtocol) and not isinstance(
                result, type
            ):
                self._obs = result
        self._log(
            logging.DEBUG,
            f"Resolved service: {type(result).__name__}",
            extra={"service_name": name_str},
        )
        return result

    def _discard_pending(self, store_key: tuple[type[Any] | str, str | None]) -> None:
        """Removes the in-flight `_SENTINEL` placeholder for ``store_key``.

        A slot that already holds a real instance is left untouched.
        """
        with self._lock:
            if self._services.get(store_key) is _SENTINEL:
                del self._services[store_key]

    def _fail_pending(
        self,
        store_key: tuple[type[Any] | str, str | None],
        message: str,
        name_str: str,
    ) -> None:
        """Releases an in-flight slot and logs ``message`` at ERROR level."""
        self._discard_pending(store_key)
        self._log(logging.ERROR, message, extra={"service_name": name_str})

    def resolve(self, key: type[T] | str, name: str | None = None) -> T:
        """Resolves and returns a service instance synchronously.

        If the service factory is asynchronous, this method will run the
        async factory to completion.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.

        Returns:
            T: The resolved service instance.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If the factory fails, returns None, or a circular
                dependency is detected.
        """
        return self._resolve_common(key, name, async_mode=False)

    async def resolve_async(self, key: type[T] | str, name: str | None = None) -> T:
        """Resolves and returns a service instance asynchronously.

        This method should be used when the caller is in an async context. It
        can resolve both synchronous and asynchronous factories. Results of
        async factories are cached just like synchronous ones.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.

        Returns:
            T: The resolved service instance.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If the factory fails, returns None, or a circular
                dependency is detected.
        """
        outcome = self._resolve_common(key, name, async_mode=True)
        if asyncio.iscoroutine(outcome):
            return cast(T, await outcome)
        return cast(T, outcome)

    # ------------------------------------------------------------------ #
    # Mutation helpers
    # ------------------------------------------------------------------ #

    def unregister(self, key: type[Any] | str, name: str | None = None) -> bool:
        """Unregisters a service factory and removes any cached instance.

        Args:
            key (type[Any] | str): The service key to unregister.
            name (str | None): An optional name for the registration.

        Returns:
            bool: True if a registration was found and removed, otherwise False.
        """
        store_key = (key, name)
        with self._lock:
            removed_value = self._store.pop(store_key, _SENTINEL)
            removed = removed_value is not _SENTINEL
            cached = self._services.pop(store_key, None)
            # Forget the logging service if it came from this registration,
            # whether it was registered as a value or resolved from a factory.
            if (
                cached is not None and isinstance(cached, ObservabilityProtocol)
            ) or (self._obs is not None and removed_value is self._obs):
                self._obs = None
        if removed:
            self._log(
                logging.INFO,
                "Unregistered service",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
        return removed

    @contextmanager
    def override(
        self,
        key: type[T] | str,
        factory_or_value: Callable[[], T | Awaitable[T]] | T,
        name: str | None = None,
    ) -> Iterator[None]:
        """Temporarily overrides a service registration within a context block.

        This is primarily useful for testing, allowing a service to be replaced
        with a mock or stub. The original registration and cached instance are
        restored upon exiting the context. If the override became the
        container's logging service, the previous one is restored as well.

        Args:
            key (type[T] | str): The service key to override.
            factory_or_value: The temporary factory or value.
            name (str | None): An optional name for the registration.

        Yields:
            None:
        """
        store_key = (key, name)
        with self._lock:
            had_factory = store_key in self._store
            original_factory = self._store.get(store_key)
            original_instance = self._services.get(store_key)
            if original_instance is _SENTINEL:
                original_instance = None
            original_obs = self._obs
            self.register(key, factory_or_value, name)
            self._services.pop(store_key, None)
            self._log(
                logging.DEBUG,
                "Overriding service",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
        try:
            yield
        finally:
            with self._lock:
                override_instance = self._services.get(store_key)
                override_obs_active = self._obs is not None and (
                    self._obs is factory_or_value or self._obs is override_instance
                )
                if had_factory:
                    self._store[store_key] = original_factory
                    if original_instance is not None:
                        self._services[store_key] = original_instance
                    else:
                        self._services.pop(store_key, None)
                else:
                    self.unregister(key, name)
                if override_obs_active:
                    self._obs = original_obs
                self._log(
                    logging.DEBUG,
                    "Restored service" if had_factory else "Removed service override",
                    extra={"service_name": _key_name(key), "svc_alias": name},
                )

    async def shutdown(self) -> None:
        """Shuts down all resolved services that have a cleanup method.

        The cache, the registrations and the logging-service reference are
        detached under the lock first, so the container is empty from then on.

        Each cached instance is then handled in turn:
            - If it has ``shutdown()``, that is called, and any awaitable it
              returns is awaited with a 5 second timeout.
            - Otherwise, an observability service gets ``close()`` instead.

        A failure in one service is logged and does not prevent the remaining
        services from being shut down. Finally the previously active
        observability service is closed on a best-effort basis.
        """
        with self._lock:
            services = list(self._services.items())
            obs_ref = self._obs
            self._services.clear()
            self._store.clear()
            self._obs = None
        for (svc_key, svc_alias), instance in services:
            if instance is _SENTINEL:
                continue
            extra = {"service_name": _key_name(svc_key), "svc_alias": svc_alias}
            try:
                shutdown_func = getattr(instance, "shutdown", None)
                if shutdown_func and callable(shutdown_func):
                    outcome = shutdown_func()
                    if inspect.isawaitable(outcome):
                        await asyncio.wait_for(outcome, timeout=5.0)
                    self._log(logging.DEBUG, "Shutting down service", extra=extra)
                elif isinstance(instance, ObservabilityProtocol) and not isinstance(
                    instance, type
                ):
                    instance.close()
                    self._log(
                        logging.DEBUG, "Closing observability service", extra=extra
                    )
            except Exception as exc:
                self._log(logging.ERROR, f"Shutdown failed: {exc}", extra=extra)
        if obs_ref and hasattr(obs_ref, "close"):
            with suppress(Exception):
                obs_ref.close()
        self._log(logging.INFO, "DIContainer shutdown", extra={})

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    def services(self) -> Sequence[tuple[type[Any] | str, str | None]]:
        """Returns a list of all resolved and cached service keys.

        Resolutions still in progress are not included.

        Returns:
            Sequence[tuple[type[Any] | str, str | None]]: A sequence of
                (key, name) tuples for all currently resolved services.
        """
        with self._lock:
            return [k for k, v in self._services.items() if v is not _SENTINEL]

    def factories(self) -> Sequence[tuple[type[Any] | str, str | None]]:
        """Returns a list of all registered factory keys.

        Returns:
            Sequence[tuple[type[Any] | str, str | None]]: A sequence of
                (key, name) tuples for all registered factories.
        """
        with self._lock:
            return list(self._store.keys())

    # ------------------------------------------------------------------ #
    # Logging
    # ------------------------------------------------------------------ #

    def _log(
        self, level: int, msg: str, *, extra: dict[str, Any] | None = None
    ) -> None:
        """Logs a message via the resolved observability service or a fallback.

        Args:
            level (int): The logging level (e.g., `logging.INFO`).
            msg (str): The message to log.
            extra (dict[str, Any] | None): Additional context for the log entry.

        Returns:
            None:
        """
        self._emit_log(self._obs, level, msg, extra)

    @classmethod
    def _log_static(
        cls, level: int, msg: str, *, extra: dict[str, Any] | None = None
    ) -> None:
        """Logs a message from a class method context.

        This method attempts to use a statically cached observability service
        to prevent re-initialization loops.

        Args:
            level (int): The logging level (e.g., `logging.INFO`).
            msg (str): The message to log.
            extra (dict[str, Any] | None): Additional context for the log entry.

        Returns:
            None:
        """
        obs = cls._obs or (cls._instance._obs if cls._instance else None)
        cls._emit_log(obs, level, msg, extra)

    @staticmethod
    def _emit_log(
        obs: ObservabilityProtocol | None,
        level: int,
        msg: str,
        extra: dict[str, Any] | None,
    ) -> None:
        """Sends one log record to ``obs`` or to the ``bijux_cli.di`` logger.

        Args:
            obs (ObservabilityProtocol | None): Observability service to use,
                if any.
            level (int): The logging level.
            msg (str): The message to log.
            extra (dict[str, Any] | None): Additional context for the log entry.
        """
        if obs:
            obs.log(logging.getLevelName(level).lower(), msg, extra=extra or {})
            return
        logger = logging.getLogger("bijux_cli.di")
        log_extra: dict[str, Any] = dict(extra or {})
        # `name` is a reserved LogRecord attribute; the stdlib logger raises
        # KeyError when `extra` tries to overwrite it.
        if "name" in log_extra:
            log_extra["svc_alias"] = log_extra.pop("name")
        try:
            logger.log(level, msg, extra=log_extra)
        except KeyError:
            logger.warning(
                "Failed to log with extra=%s – retrying without it", log_extra
            )
            logger.log(level, msg)

    @classmethod
    def _reset_for_tests(cls) -> None:
        """Fully tears down the singleton instance for testing.

        This method shuts down all services and clears all internal state of
        the singleton. It is intended exclusively for test suite cleanup.
        Shutdown errors are logged, and state is cleared regardless.
        """
        inst = cls._detach_instance()
        if inst is not None:
            cls._shutdown_blocking(inst, "Error during test shutdown")
            cls._clear_state(inst)
        cls._log_static(logging.DEBUG, "DIContainer reset for tests")
# Public API of this module: star-imports expose only the container.
__all__: list[str] = [
    "DIContainer",
]