Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/core/di.py: 98%
305 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 17:59 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 17:59 +0000
1# SPDX-License-Identifier: Apache-2.0
2# Copyright © 2025 Bijan Mousavi
4"""Provides the central dependency injection container for the Bijux CLI.
6This module defines the `DIContainer` class, a thread-safe singleton that
7manages the registration, resolution, and lifecycle of all services within the
8application. It allows components to be loosely coupled by requesting
9dependencies based on abstract protocols rather than concrete classes.
11Key features include:
12 * Singleton pattern for global access via `DIContainer.current()`.
13 * Thread-safe operations for concurrent environments.
14 * Lazy instantiation of services upon first request.
15 * Support for named registrations to allow multiple implementations of the
16 same protocol.
17 * Both synchronous (`resolve`) and asynchronous (`resolve_async`) service
18 resolution.
19 * Circular dependency detection.
20"""
22from __future__ import annotations
24import asyncio
25from collections.abc import Awaitable, Callable, Coroutine, Iterator, Sequence
26from contextlib import contextmanager, suppress
27from contextvars import ContextVar
28import inspect
29import logging
30from threading import RLock
31from typing import Any, Literal, TypeVar, cast, overload
33from injector import Injector
35from bijux_cli.core.errors import BijuxError
36from bijux_cli.core.precedence import LogPolicy
37from bijux_cli.core.runtime import run_awaitable
38from bijux_cli.services.contracts import ObservabilityProtocol
# Type variable standing for the service type requested by a caller.
T = TypeVar("T")
# Unique marker object; the container stores it in its service cache for a
# key while that key's factory is still being executed.
_SENTINEL = object()
44def _key_name(key: type[Any] | str) -> str:
45 """Returns a human-readable name for a DI service key.
47 Args:
48 key (type[Any] | str): The service key, which can be a type or a string.
50 Returns:
51 str: The string representation of the key.
52 """
53 if isinstance(key, str):
54 return key
55 try:
56 return key.__name__
57 except AttributeError:
58 return str(key)
class DIContainer:
    """A thread-safe, singleton dependency injection container.

    This class manages the lifecycle of services, including registration of
    factories, lazy instantiation, and resolution. It integrates with an
    underlying `injector` for basic services and handles custom named
    registrations and circular dependency detection.

    Attributes:
        _instance (DIContainer | None): The singleton instance of the container.
        _lock (RLock): A reentrant lock guarding the registry and the cache.
        _resolving (ContextVar): A context variable to track services currently
            being resolved, used for circular dependency detection.
        _obs (ObservabilityProtocol | None): A cached reference to the logging
            service for internal use.
        _log_policy (LogPolicy | None): Policy consulted before emitting
            DEBUG-level container logs.
        _injector (Injector): The underlying `injector` library instance.
        _store (dict): A mapping of (key, name) tuples to registered factories
            or values.
        _services (dict): A cache of resolved service instances.
    """

    _instance: DIContainer | None = None
    _lock = RLock()
    _resolving: ContextVar[set[str] | None] = ContextVar("resolving", default=None)
    _obs: ObservabilityProtocol | None = None
    _log_policy: LogPolicy | None = None

    @classmethod
    def current(cls) -> DIContainer:
        """Returns the active singleton instance of the `DIContainer`.

        The instance is created on first use.

        Returns:
            DIContainer: The singleton instance.
        """
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
                cls._log_static(
                    logging.DEBUG, "DIContainer.current auto-initialized singleton"
                )
            return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Resets the singleton instance, shutting down all services.

        This method is intended for use in testing environments to ensure a
        clean state between tests. It clears all registered services and
        factories. Errors raised by the shutdown are logged, not propagated.
        """
        inst = None
        with cls._lock:
            # Detach the instance first so that new callers get a fresh one.
            inst = cls._instance
            cls._instance = None
            cls._obs = None
            cls._log_policy = None
        if inst is None:
            cls._log_static(logging.DEBUG, "DIContainer reset (no instance)")
            return
        try:
            run_awaitable(inst.shutdown())
        except Exception as exc:
            cls._log_static(logging.ERROR, f"Error during shutdown: {exc}")
        inst._services.clear()
        inst._store.clear()
        inst._obs = None
        cls._log_static(logging.DEBUG, "DIContainer reset")

    @classmethod
    async def reset_async(cls) -> None:
        """Asynchronously resets the singleton instance.

        This method is intended for use in testing environments. All services
        and factories are cleared. Unlike `reset`, an exception raised by
        `shutdown()` propagates to the caller.
        """
        instance = None
        with cls._lock:
            if cls._instance is not None:
                instance = cls._instance
                cls._instance = None
                cls._obs = None
                cls._log_policy = None
        if instance is not None:
            await instance.shutdown()
            instance._services.clear()
            instance._store.clear()
            instance._obs = None
        cls._log_static(logging.DEBUG, "DIContainer reset")

    def __new__(cls) -> DIContainer:
        """Creates or returns the singleton instance of the container."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self) -> None:
        """Initializes the container's internal stores.

        This method is idempotent; it does nothing if the container has already
        been initialized.
        """
        if getattr(self, "_initialised", False):
            return
        self._injector = Injector()
        self._store: dict[
            tuple[type[Any] | str, str | None], Callable[[], Any | Awaitable[Any]] | Any
        ] = {}
        self._services: dict[tuple[type[Any] | str, str | None], Any] = {}
        self._obs: ObservabilityProtocol | None = None
        self._initialised = True
        self._log_static(logging.DEBUG, "DIContainer initialised")

    @classmethod
    def set_log_policy(cls, policy: LogPolicy) -> None:
        """Attach a log policy for DI logging."""
        cls._log_policy = policy

    def register(
        self,
        key: type[T] | str,
        factory_or_value: Callable[[], T | Awaitable[T]] | T,
        name: str | None = None,
    ) -> None:
        """Registers a factory or a pre-resolved value for a given service key.

        Args:
            key (type[T] | str): The service key, which can be a protocol type
                or a unique string identifier.
            factory_or_value: The factory function that creates the service,
                or the service instance itself.
            name (str | None): An optional name for the registration, allowing
                multiple implementations of the same key.

        Returns:
            None:

        Raises:
            BijuxError: If the registration key is invalid or conflicts with an
                existing registration.
        """
        if not (isinstance(key, str) or inspect.isclass(key)):
            raise BijuxError("Service key must be a type or str", http_status=400)
        try:
            # The collision scan iterates the registry and the assignment
            # mutates it, so both happen under the same lock the other
            # mutators use. The lock is re-entrant, which keeps `override`
            # (which calls this while already holding it) working.
            with self._lock:
                store_key = (key, name)
                if isinstance(key, str) and any(
                    isinstance(k, type) and k.__name__ == key for k, _ in self._store
                ):
                    raise BijuxError(
                        f"Key {key} conflicts with existing type name", http_status=400
                    )
                if isinstance(key, type) and any(
                    k == key.__name__ for k, _ in self._store
                ):
                    raise BijuxError(
                        f"Type {key.__name__} conflicts with existing string key",
                        http_status=400,
                    )
                self._store[store_key] = factory_or_value
                # A registered observability instance (not a class) becomes
                # the container's own log sink.
                if isinstance(
                    factory_or_value, ObservabilityProtocol
                ) and not isinstance(factory_or_value, type):
                    self._obs = factory_or_value
            self._log(
                logging.DEBUG,
                "Registered service",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
        except (TypeError, KeyError) as exc:
            self._log(
                logging.ERROR,
                f"Failed to register service: {exc}",
                extra={"service_name": _key_name(key), "name": name},
            )
            raise BijuxError(
                f"Failed to register service {_key_name(key)}: {exc}", http_status=400
            ) from exc

    def _discard_placeholder(self, store_key: tuple[type[Any] | str, str | None]) -> None:
        """Removes the in-progress marker for `store_key`, if still present.

        A real cached instance stored under the same key is left untouched.
        """
        with self._lock:
            if self._services.get(store_key, None) is _SENTINEL:
                del self._services[store_key]

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: Literal[False]
    ) -> T:  # pragma: no cover
        ...

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: Literal[True]
    ) -> T | Awaitable[T]:  # pragma: no cover
        ...

    @overload
    def _resolve_common(
        self, key: type[T] | str, name: str | None, *, async_mode: bool
    ) -> T | Awaitable[T]:  # pragma: no cover
        ...

    def _resolve_common(
        self,
        key: type[T] | str,
        name: str | None = None,
        *,
        async_mode: bool = False,
    ) -> T | Awaitable[T]:
        """Handles the core logic for resolving a service instance.

        This internal method implements the resolution strategy:
            1. Check for a cached instance.
            2. If not cached, check for a registered factory.
            3. If no factory, attempt resolution via the underlying `injector`.
            4. If a factory is found, execute it, handling circular dependencies
               and both sync/async factories.
            5. Cache and return the result.

        In `async_mode`, an awaitable produced by the factory is handed back
        un-awaited and nothing is cached here; the caller is responsible for
        awaiting it and caching the outcome.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.
            async_mode (bool): If True, allows returning an awaitable if the
                factory is async.

        Returns:
            T | Awaitable[T]: The resolved service instance, or an awaitable
                that will resolve to the instance if `async_mode` is True.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If a circular dependency is detected or the factory fails.
        """
        name_str = f"{_key_name(key)}:{name}" if name else _key_name(key)
        resolving = self._resolving.get() or set()
        if name_str in resolving:
            self._log(
                logging.ERROR,
                f"Circular dependency detected for {name_str}",
                extra={"service_name": name_str},
            )
            raise BijuxError(
                f"Circular dependency detected for {name_str}", http_status=400
            )
        with self._lock:
            store_key = (key, name)
            if (
                store_key in self._services
                and self._services[store_key] is not _SENTINEL
            ):
                self._log(
                    logging.DEBUG,
                    f"Resolved service: {type(self._services[store_key]).__name__}",
                    extra={"service_name": name_str},
                )
                return cast(T, self._services[store_key])
            if store_key not in self._store:
                if isinstance(key, type):
                    # No explicit registration: let the injector try to build it.
                    try:
                        resolved: T = self._injector.get(key)
                        self._services[store_key] = resolved
                        self._log(
                            logging.DEBUG,
                            f"Resolved service via injector: {type(resolved).__name__}",
                            extra={"service_name": name_str},
                        )
                        return resolved
                    except Exception as exc:
                        self._log(
                            logging.ERROR,
                            "Service not registered via injector",
                            extra={"service_name": name_str},
                        )
                        raise KeyError(f"Service not registered: {name_str}") from exc
                else:
                    self._log(
                        logging.ERROR,
                        "Service not registered",
                        extra={"service_name": name_str},
                    )
                    raise KeyError(f"Service not registered: {name_str}")
            # Mark the key as "being built" until the factory has finished.
            self._services[store_key] = _SENTINEL
        token = self._resolving.set(resolving | {name_str})
        try:
            factory = self._store[store_key]
            # Only plain functions, bound methods and coroutine functions are
            # invoked; anything else is treated as an already-built value.
            is_function_like = (
                inspect.isfunction(factory)
                or inspect.ismethod(factory)
                or inspect.iscoroutinefunction(factory)
            )
            result: T | None
            raw = factory() if is_function_like else factory
            if inspect.isawaitable(raw):
                if async_mode:
                    # Resolution is handed over to the caller; do not leave the
                    # in-progress marker behind in the cache.
                    self._discard_placeholder(store_key)
                    return cast(Awaitable[T], raw)
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None
                if loop is not None:
                    if not hasattr(loop, "run_until_complete"):
                        # Close the coroutine so it is not left un-awaited.
                        with suppress(Exception):
                            if asyncio.iscoroutine(raw) and hasattr(raw, "close"):
                                cast(Any, raw).close()
                        raise RuntimeError(
                            "Cannot sync-resolve while an event loop is running"
                        )

                    is_coro = asyncio.iscoroutine(raw)
                    try:
                        if is_coro:
                            result = loop.run_until_complete(
                                cast(Coroutine[Any, Any, T], raw)
                            )
                        else:

                            async def _await(a: Awaitable[T]) -> T:
                                return await a

                            result = loop.run_until_complete(
                                _await(cast(Awaitable[T], raw))
                            )
                    finally:
                        if is_coro:
                            with suppress(Exception):
                                if hasattr(raw, "close"):
                                    cast(Any, raw).close()
                    # NOTE(review): this path returns without the None check
                    # and without caching the result; kept as in the original.
                    return result
                if asyncio.iscoroutine(raw):
                    coro = cast(Coroutine[Any, Any, T], raw)
                    try:
                        result = run_awaitable(coro, want_result=True)
                    finally:
                        with suppress(Exception):
                            if hasattr(coro, "close"):
                                coro.close()
                else:

                    async def _await(a: Awaitable[T]) -> T:
                        return await a

                    result = run_awaitable(
                        _await(cast(Awaitable[T], raw)), want_result=True
                    )
            else:
                result = cast(T, raw)
            if result is None:
                self._log(
                    logging.ERROR,
                    "Factory returned None",
                    extra={"service_name": name_str},
                )
                raise BijuxError(
                    f"Factory for {name_str} returned None", http_status=424
                )
            with self._lock:
                self._services[store_key] = result
                if isinstance(result, ObservabilityProtocol) and not isinstance(
                    result, type
                ):
                    self._obs = result
            self._log(
                logging.DEBUG,
                f"Resolved service: {type(result).__name__}",
                extra={"service_name": name_str},
            )
            return result
        except (KeyError, TypeError, RuntimeError):
            # These types are passed through unchanged after cleanup.
            with self._lock:
                self._services.pop(store_key, None)
            self._log(
                logging.ERROR,
                f"Service resolution failed: {_key_name(key)}",
                extra={"service_name": name_str},
            )
            raise
        except BaseException as exc:
            # Everything else is re-raised as BijuxError.
            # NOTE(review): unless BijuxError derives from one of the types
            # handled above, this also re-wraps the "returned None" error
            # raised a few lines up — confirm that this is intended.
            with self._lock:
                self._services.pop(store_key, None)
            self._log(
                logging.ERROR,
                f"Factory failed: {exc}",
                extra={"service_name": name_str},
            )
            raise BijuxError(
                f"Factory for {name_str} raised: {exc}", http_status=400
            ) from exc
        finally:
            self._resolving.reset(token)

    def resolve(self, key: type[T] | str, name: str | None = None) -> T:
        """Resolves and returns a service instance synchronously.

        If the service factory is asynchronous, this method will run the
        async factory to completion.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.

        Returns:
            T: The resolved service instance.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If the factory fails, returns None, or a circular
                dependency is detected.
        """
        return self._resolve_common(key, name, async_mode=False)

    async def resolve_async(self, key: type[T] | str, name: str | None = None) -> T:
        """Resolves and returns a service instance asynchronously.

        This method should be used when the caller is in an async context. It
        can resolve both synchronous and asynchronous factories. The awaited
        result of an asynchronous factory is cached, so later resolutions of
        the same key return the same instance.

        Args:
            key (type[T] | str): The service key to resolve.
            name (str | None): An optional name for the registration.

        Returns:
            T: The resolved service instance.

        Raises:
            KeyError: If the service is not registered.
            BijuxError: If the factory fails, returns None, or a circular
                dependency is detected.
        """
        outcome = self._resolve_common(key, name, async_mode=True)
        # Same test `_resolve_common` uses to decide what to hand back
        # un-awaited, so futures and tasks are awaited too, not only coroutines.
        if not inspect.isawaitable(outcome):
            return cast(T, outcome)
        instance = await cast(Awaitable[T], outcome)
        label = f"{_key_name(key)}:{name}" if name else _key_name(key)
        if instance is None:
            self._log(
                logging.ERROR,
                "Factory returned None",
                extra={"service_name": label},
            )
            raise BijuxError(f"Factory for {label} returned None", http_status=424)
        with self._lock:
            self._services[(key, name)] = instance
            if isinstance(instance, ObservabilityProtocol) and not isinstance(
                instance, type
            ):
                self._obs = instance
        self._log(
            logging.DEBUG,
            f"Resolved service: {type(instance).__name__}",
            extra={"service_name": label},
        )
        return instance

    def unregister(self, key: type[Any] | str, name: str | None = None) -> bool:
        """Unregisters a service factory and removes any cached instance.

        Args:
            key (type[Any] | str): The service key to unregister.
            name (str | None): An optional name for the registration.

        Returns:
            bool: True if a service was found and unregistered, otherwise False.
        """
        with self._lock:
            store_key = (key, name)
            removed = self._store.pop(store_key, None) is not None
            # Dropping the cached observability service also drops the
            # container's reference to it.
            if store_key in self._services and isinstance(
                self._services[store_key], ObservabilityProtocol
            ):
                self._obs = None
            self._services.pop(store_key, None)
            if removed:
                self._log(
                    logging.INFO,
                    "Unregistered service",
                    extra={"service_name": _key_name(key), "svc_alias": name},
                )
            return removed

    @contextmanager
    def override(
        self,
        key: type[T] | str,
        factory_or_value: Callable[[], T | Awaitable[T]] | T,
        name: str | None = None,
    ) -> Iterator[None]:
        """Temporarily overrides a service registration within a context block.

        This is primarily useful for testing, allowing a service to be replaced
        with a mock or stub. The original registration is restored upon exiting
        the context; if there was none, the override is unregistered.

        Args:
            key (type[T] | str): The service key to override.
            factory_or_value: The temporary factory or value.
            name (str | None): An optional name for the registration.

        Yields:
            None:
        """
        with self._lock:
            store_key = (key, name)
            original_factory = self._store.get(store_key)
            original_instance = self._services.get(store_key)
            self.register(key, factory_or_value, name)
            # Drop the cached instance so the override is what gets resolved.
            if store_key in self._services:
                del self._services[store_key]
            self._log(
                logging.DEBUG,
                "Overriding service",
                extra={"service_name": _key_name(key), "svc_alias": name},
            )
        try:
            yield
        finally:
            with self._lock:
                if original_factory is not None:
                    self._store[store_key] = original_factory
                    if original_instance is not None:
                        self._services[store_key] = original_instance
                    else:
                        self._services.pop(store_key, None)
                    self._log(
                        logging.DEBUG,
                        "Restored service",
                        extra={"service_name": _key_name(key), "svc_alias": name},
                    )
                else:
                    self.unregister(key, name)
                    self._log(
                        logging.DEBUG,
                        "Removed service override",
                        extra={"service_name": _key_name(key), "svc_alias": name},
                    )

    async def shutdown(self) -> None:
        """Shuts down all resolved services that have a cleanup method.

        Iterates through all cached services and calls a `shutdown()` method
        if one exists (awaiting it, with a 5 second limit, when it is a
        coroutine function); observability services without `shutdown()` get
        `close()` instead. The registry and cache are emptied before any
        service is touched. `RuntimeError`, `TypeError` and timeouts from a
        single service are logged and do not stop the remaining services.
        """
        services = []
        with self._lock:
            # Snapshot and clear under the lock; the cleanup calls run outside it.
            services = list(self._services.items())
            obs_ref = self._obs
            self._services.clear()
            self._store.clear()
            self._obs = None
        for key, instance in services:
            try:
                shutdown_func = getattr(instance, "shutdown", None)
                if shutdown_func and callable(shutdown_func):
                    is_async_shutdown = asyncio.iscoroutinefunction(shutdown_func)
                    if is_async_shutdown:
                        await asyncio.wait_for(shutdown_func(), timeout=5.0)
                    else:
                        shutdown_func()
                    self._log(
                        logging.DEBUG,
                        "Shutting down service",
                        extra={"service_name": _key_name(key[0]), "svc_alias": key[1]},
                    )
                elif isinstance(instance, ObservabilityProtocol) and not isinstance(
                    instance, type
                ):
                    instance.close()
                    self._log(
                        logging.DEBUG,
                        "Closing observability service",
                        extra={"service_name": _key_name(key[0]), "svc_alias": key[1]},
                    )
            # asyncio.TimeoutError is listed explicitly: before Python 3.11 it
            # is a separate class from the builtin TimeoutError, and it is what
            # asyncio.wait_for raises there.
            except (RuntimeError, TypeError, TimeoutError, asyncio.TimeoutError) as exc:
                self._log(
                    logging.ERROR,
                    f"Shutdown failed: {exc}",
                    extra={"service_name": _key_name(key[0]), "svc_alias": key[1]},
                )
        # May repeat a close() already issued in the loop above; any error it
        # raises is suppressed.
        if obs_ref and hasattr(obs_ref, "close"):
            with suppress(Exception):
                obs_ref.close()
        self._log(logging.DEBUG, "DIContainer shutdown", extra={})

    def services(self) -> Sequence[tuple[type[Any] | str, str | None]]:
        """Returns a list of all resolved and cached service keys.

        Returns:
            Sequence[tuple[type[Any] | str, str | None]]: A sequence of
                (key, name) tuples for all currently resolved services.
        """
        with self._lock:
            return list(self._services.keys())

    def factories(self) -> Sequence[tuple[type[Any] | str, str | None]]:
        """Returns a list of all registered factory keys.

        Returns:
            Sequence[tuple[type[Any] | str, str | None]]: A sequence of
                (key, name) tuples for all registered factories.
        """
        with self._lock:
            return list(self._store.keys())

    @staticmethod
    def _emit_fallback(level: int, msg: str, extra: dict[str, Any] | None) -> None:
        """Writes a record to the ``bijux_cli.di`` stdlib logger.

        A ``name`` entry in `extra` is renamed to ``svc_alias`` before logging.
        If the logger still rejects the extras with `KeyError`, a warning is
        logged and the message is emitted again without them.
        """
        logger = logging.getLogger("bijux_cli.di")
        log_extra: dict[str, Any] = dict(extra) if extra else {}
        if "name" in log_extra:
            log_extra["svc_alias"] = log_extra.pop("name")
        try:
            logger.log(level, msg, extra=log_extra)
        except KeyError:
            logger.warning(
                "Failed to log with extra=%s – retrying without it", log_extra
            )
            logger.log(level, msg)

    def _log(
        self, level: int, msg: str, *, extra: dict[str, Any] | None = None
    ) -> None:
        """Logs a message via the resolved observability service or a fallback.

        DEBUG-and-below messages are dropped unless the log policy has
        `show_internal` set; when kept, they go to the observability service
        if one is known. Everything else goes to the stdlib logger.

        Args:
            level (int): The logging level (e.g., `logging.INFO`).
            msg (str): The message to log.
            extra (dict[str, Any] | None): Additional context for the log entry.

        Returns:
            None:
        """
        if level <= logging.DEBUG and not (
            self._log_policy and self._log_policy.show_internal
        ):
            return

        if self._obs and level <= logging.DEBUG:
            self._obs.log(logging.getLevelName(level).lower(), msg, extra=extra or {})
            return

        self._emit_fallback(level, msg, extra)

    @classmethod
    def _log_static(
        cls, level: int, msg: str, *, extra: dict[str, Any] | None = None
    ) -> None:
        """Logs a message from a class method context.

        This method attempts to use a statically cached observability service
        to prevent re-initialization loops; it falls back to the one held by
        the current instance, if any. Filtering matches `_log`.

        Args:
            level (int): The logging level (e.g., `logging.INFO`).
            msg (str): The message to log.
            extra (dict[str, Any] | None): Additional context for the log entry.

        Returns:
            None:
        """
        if level <= logging.DEBUG and not (
            cls._log_policy and cls._log_policy.show_internal
        ):
            return

        obs = cls._obs or (cls._instance._obs if cls._instance else None)
        if obs and level <= logging.DEBUG:
            obs.log(logging.getLevelName(level).lower(), msg, extra=extra or {})
            return

        cls._emit_fallback(level, msg, extra)

    @classmethod
    def _reset_for_tests(cls) -> None:
        """Fully tears down the singleton instance for testing.

        This method shuts down all services and clears all internal state of
        the singleton, including the log policy. It is intended exclusively
        for test suite cleanup. Errors raised by the shutdown are logged, not
        propagated.
        """
        with cls._lock:
            # Same detach-then-shutdown order as `reset`.
            inst = cls._instance
            cls._instance = None
            cls._obs = None
            cls._log_policy = None
        if inst is not None:
            try:
                run_awaitable(inst.shutdown())
            except Exception as exc:
                cls._log_static(logging.ERROR, f"Error during test shutdown: {exc}")
        cls._log_static(logging.DEBUG, "DIContainer reset for tests")
# Public names of this module.
__all__ = ["DIContainer"]