Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/plugins/registry.py: 97%
245 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 17:59 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 17:59 +0000
1# SPDX-License-Identifier: Apache-2.0
2# Copyright © 2025 Bijan Mousavi
4"""Provides a concrete plugin registry service using the `pluggy` framework.
6This module defines the `Registry` class, which implements the
7`RegistryProtocol`. It serves as the central manager for the entire plugin
8lifecycle, including registration, aliasing, metadata storage, and the
9invocation of plugin hooks. It is built on top of the `pluggy` library to
10provide a robust and extensible plugin architecture.
11"""
13from __future__ import annotations
15import asyncio
16from collections.abc import AsyncIterable, Callable
17import contextlib
18import importlib.metadata as im
19import logging
20import traceback
21from types import MappingProxyType
22from typing import Any
24from injector import inject
25from packaging.specifiers import SpecifierSet
26from packaging.version import Version as PkgVersion
27import pluggy
29from bijux_cli.core.di import DIContainer
30from bijux_cli.plugins.contracts import PluginState, RegistryProtocol
31from bijux_cli.services.contracts import ObservabilityProtocol, TelemetryProtocol
32from bijux_cli.services.errors import ServiceError
# Hook names for the command-execution hooks declared on CoreSpec.
PRE_EXECUTE, POST_EXECUTE = "pre_execute", "post_execute"

# The `version` attribute of the top-level `bijux_cli` package.
SPEC_VERSION = __import__("bijux_cli").version

# Marker that tags CoreSpec methods as hook specifications for project "bijux".
hookspec = pluggy.HookspecMarker("bijux")
class CoreSpec:
    """Hook specifications that CLI plugins may implement.

    Every method marked with ``hookspec`` declares one hook; its signature
    defines the keyword arguments the hook is called with. The bodies only
    emit a debug log through the observability service resolved at
    construction time.
    """

    def __init__(self, dependency_injector: DIContainer) -> None:
        """Resolve the observability service from ``dependency_injector``."""
        self._log = dependency_injector.resolve(ObservabilityProtocol)

    def _trace(self, message: str, **extra: Any) -> None:
        """Log ``message`` at debug level with ``extra`` as structured context."""
        self._log.log("debug", message, extra=extra)

    @hookspec
    async def startup(self) -> None:
        """Spec for the hook fired at CLI startup."""
        self._trace("Hook startup called")

    @hookspec
    async def shutdown(self) -> None:
        """Spec for the hook fired at CLI shutdown."""
        self._trace("Hook shutdown called")

    @hookspec
    async def pre_execute(
        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        """Spec for the hook fired before command ``name`` runs."""
        self._trace("Hook pre_execute called", name=name, args=args, kwargs=kwargs)

    @hookspec
    async def post_execute(self, name: str, result: Any) -> None:
        """Spec for the hook fired after command ``name`` produced ``result``."""
        self._trace("Hook post_execute called", name=name, result=repr(result))

    @hookspec
    def health(self) -> bool | str:
        """Spec for the health-check hook; this body reports healthy (True)."""
        self._trace("Hook health called")
        return True
85def command_group(
86 name: str,
87 *,
88 version: str | None = None,
89) -> Callable[[str], Callable[[Callable[..., Any]], Callable[..., Any]]]:
90 """Decorator factory for registering plugin subcommands under a group."""
92 def with_sub(sub: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
93 if " " in sub:
94 raise ValueError("subcommand may not contain spaces")
95 full = f"{name} {sub}"
97 def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
98 try:
99 di = DIContainer.current()
100 reg: RegistryProtocol = di.resolve(RegistryProtocol)
101 except KeyError as exc:
102 raise RuntimeError("RegistryProtocol is not initialized") from exc
104 reg.register(full, fn, alias=None, version=version)
106 try:
107 obs: ObservabilityProtocol = di.resolve(ObservabilityProtocol)
108 obs.log(
109 "info",
110 "Registered command group",
111 extra={"cmd": full, "version": version},
112 )
113 except KeyError:
114 pass
116 try:
117 tel: TelemetryProtocol = di.resolve(TelemetryProtocol)
118 tel.event(
119 "command_group_registered", {"command": full, "version": version}
120 )
121 except KeyError:
122 pass
124 return fn
126 return decorator
128 return with_sub
def dynamic_choices(
    callback: Callable[[], list[str]],
    *,
    case_sensitive: bool = True,
) -> Callable[[Any, Any, str], list[str]]:
    """Create a Typer completer whose candidates come from ``callback``.

    ``callback`` is invoked on every completion request.

    Args:
        callback: Zero-argument callable returning the candidate strings.
        case_sensitive: When False, candidates and the typed prefix are
            compared after Unicode case folding (``str.casefold``). For
            example, ``"STRASS"`` matches ``"Straße"``.

    Returns:
        A ``(ctx, param, incomplete) -> list[str]`` completer. It returns the
        candidates that start with ``incomplete``, in callback order.
    """

    def completer(_ctx: Any, _param: Any, incomplete: str) -> list[str]:
        choices = callback()
        if case_sensitive:
            return [c for c in choices if c.startswith(incomplete)]
        # casefold() is the Unicode-correct transform for caseless matching;
        # lower() misses cases such as "ß" vs "SS".
        prefix = incomplete.casefold()
        return [c for c in choices if c.casefold().startswith(prefix)]

    return completer
147def _iter_plugin_eps() -> list[im.EntryPoint]:
148 """Returns all entry points in the 'bijux_cli.plugins' group."""
149 try:
150 eps = im.entry_points()
151 return list(eps.select(group="bijux_cli.plugins"))
152 except Exception:
153 return []
def _compatible(plugin: Any) -> bool:
    """Report whether ``plugin`` accepts the host's CLI API version.

    The plugin states its requirement in an optional ``requires_api_version``
    attribute holding a PEP 440 specifier (default ``">=0.0.0"``). The host
    version is ``bijux_cli.api_version``.

    Pre-release and dev host versions are matched as well.
    ``SpecifierSet.contains`` excludes them by default, which would reject
    every plugin, even the default specifier, while the host runs a version
    such as ``1.2.0rc1``.

    Args:
        plugin: The instantiated plugin object.

    Returns:
        True if the host API version satisfies the plugin's specifier. False
        if it does not, or if the specifier or host version cannot be read
        or parsed.
    """
    import bijux_cli

    spec = getattr(plugin, "requires_api_version", ">=0.0.0")
    try:
        host_api_version = PkgVersion(str(bijux_cli.api_version))
        # str() also accepts a SpecifierSet object given in place of a string.
        return SpecifierSet(str(spec)).contains(host_api_version, prereleases=True)
    except Exception:
        # An unreadable or unparseable version/specifier counts as incompatible.
        return False
async def load_entrypoints(
    di: DIContainer | None = None,
    registry: RegistryProtocol | None = None,
) -> None:
    """Discover, instantiate, and register all entry-point plugins.

    For each entry point returned by ``_iter_plugin_eps()``, this:

    - loads and instantiates the plugin class in worker threads;
    - rejects plugins that fail ``_compatible``;
    - converts any non-string ``version`` attribute to ``str``;
    - moves the plugin through DISCOVERED -> INSTALLED and registers it;
    - runs its optional ``startup(di)`` hook.

    Each entry point is handled independently. A failure is logged and
    reported, and any registration made for that entry point is rolled
    back, before continuing with the next one.

    Args:
        di: Container used to resolve services. Defaults to
            ``DIContainer.current()``.
        registry: Registry to populate. Defaults to the container's
            ``RegistryProtocol``.
    """
    import inspect

    import bijux_cli

    di = di or DIContainer.current()
    registry = registry or di.resolve(RegistryProtocol)

    obs = di.resolve(ObservabilityProtocol, None)
    tel = di.resolve(TelemetryProtocol, None)

    for ep in _iter_plugin_eps():
        # Roll back only a registration made in this iteration. Deregistering
        # unconditionally would evict an already-loaded plugin of the same
        # name whenever a later load of that name fails (e.g. "already
        # registered" or an invalid state transition).
        registered = False
        try:
            plugin_class = await asyncio.to_thread(ep.load)
            plugin = await asyncio.to_thread(plugin_class)

            if not _compatible(plugin):
                raise RuntimeError(
                    f"Plugin '{ep.name}' requires API {getattr(plugin, 'requires_api_version', 'N/A')}, "
                    f"host is {bijux_cli.api_version}"
                )

            for tgt in (plugin_class, plugin):
                raw = getattr(tgt, "version", None)
                if raw is not None and not isinstance(raw, str):
                    tgt.version = str(raw)

            registry.transition(ep.name, PluginState.DISCOVERED)
            registry.transition(ep.name, PluginState.INSTALLED)
            # `version` is optional on plugins (see the loop above); pass None
            # when it is absent instead of failing the load.
            registry.register(
                ep.name, plugin, alias=None, version=getattr(plugin, "version", None)
            )
            registered = True

            startup = getattr(plugin, "startup", None)
            if inspect.iscoroutinefunction(startup):
                await startup(di)
            elif callable(startup):
                outcome = await asyncio.to_thread(startup, di)
                # A callable that is not a coroutine function may still return
                # an awaitable (e.g. an object with an async __call__). Await it
                # rather than dropping it.
                if inspect.isawaitable(outcome):
                    await outcome

            if obs:
                obs.log("info", f"Loaded plugin '{ep.name}'", extra={})
            if tel:
                tel.event("entrypoint_plugin_loaded", {"name": ep.name})

        except Exception as exc:
            if registered:
                with contextlib.suppress(Exception):
                    registry.deregister(ep.name)

            if obs:
                obs.log(
                    "error",
                    f"Failed to load plugin '{ep.name}'",
                    extra={"trace": traceback.format_exc(limit=5)},
                )
            if tel:
                tel.event(
                    "entrypoint_plugin_failed", {"name": ep.name, "error": str(exc)}
                )

            _LOG.debug("Skipped plugin %s: %s", ep.name, exc, exc_info=True)
class Registry(RegistryProtocol):
    """A `pluggy`-based registry for managing CLI plugins.

    This class adds aliasing, metadata storage, lifecycle-state tracking,
    and telemetry on top of the core `pluggy` plugin manager.

    Attributes:
        _telemetry (TelemetryProtocol): The telemetry service for events.
        _pm (pluggy.PluginManager): The underlying `pluggy` plugin manager.
        _plugins (dict): A mapping of canonical plugin names to plugin objects.
        _aliases (dict): A mapping of alias names to canonical plugin names.
        _meta (dict): A mapping of canonical plugin names to their metadata.
        _states (dict): A mapping of canonical plugin names to their
            `PluginState`. Entries remain after deregistration, set to REMOVED.
        mapping (MappingProxyType): A read-only view of the `_plugins` mapping.
    """

    # Allowed lifecycle moves, keyed by current state. Built once for the
    # class instead of on every `transition` call.
    _ALLOWED_TRANSITIONS = MappingProxyType(
        {
            PluginState.DISCOVERED: frozenset(
                {PluginState.INSTALLED, PluginState.ACTIVE, PluginState.REMOVED}
            ),
            PluginState.INSTALLED: frozenset(
                {PluginState.ACTIVE, PluginState.INACTIVE}
            ),
            PluginState.ACTIVE: frozenset({PluginState.INACTIVE, PluginState.REMOVED}),
            PluginState.INACTIVE: frozenset(
                {PluginState.ACTIVE, PluginState.REMOVED}
            ),
            PluginState.REMOVED: frozenset(),
        }
    )

    @inject
    def __init__(self, telemetry: TelemetryProtocol):
        """Initializes the `Registry` service.

        Args:
            telemetry (TelemetryProtocol): The telemetry service for tracking
                registry events.
        """
        self._telemetry = telemetry
        self._pm = pluggy.PluginManager("bijux")
        self._pm.add_hookspecs(CoreSpec)
        self._plugins: dict[str, object] = {}
        self._aliases: dict[str, str] = {}
        self._meta: dict[str, dict[str, str]] = {}
        self._states: dict[str, PluginState] = {}
        self.mapping = MappingProxyType(self._plugins)

    def _emit(self, event: str, payload: dict[str, Any], operation: str) -> None:
        """Send a telemetry event without letting telemetry break the registry.

        If sending `event` raises `RuntimeError`, a `registry_telemetry_failed`
        event naming `operation` is attempted instead. A `RuntimeError` from
        that fallback is suppressed so the calling registry operation still
        completes (or raises its own error).

        Args:
            event (str): The telemetry event name.
            payload (dict[str, Any]): The event payload.
            operation (str): The operation label reported on failure.
        """
        try:
            self._telemetry.event(event, payload)
        except RuntimeError as error:
            with contextlib.suppress(RuntimeError):
                self._telemetry.event(
                    "registry_telemetry_failed",
                    {"operation": operation, "error": str(error)},
                )

    def _validate_transition(self, canonical: str, state: PluginState) -> None:
        """Check that the plugin `canonical` may move to `state`.

        A plugin with no recorded state may enter any state.

        Raises:
            ServiceError: If the move from the current state is not allowed.
        """
        current = self._states.get(canonical)
        if current is None:
            return
        if state not in self._ALLOWED_TRANSITIONS.get(current, frozenset()):
            raise ServiceError(
                f"Invalid plugin state transition {current.value} -> {state.value}",
                http_status=400,
            )

    def state(self, name: str) -> PluginState | None:
        """Return the lifecycle state for a plugin name or alias, if any."""
        canonical = self._aliases.get(name, name)
        return self._states.get(canonical)

    def transition(self, name: str, state: PluginState) -> None:
        """Transition a plugin to a new lifecycle state.

        Args:
            name (str): The canonical name or alias of the plugin.
            state (PluginState): The target state.

        Raises:
            ServiceError: If the move from the current state is not allowed.
        """
        canonical = self._aliases.get(name, name)
        self._validate_transition(canonical, state)
        self._states[canonical] = state

    def register(
        self,
        name: str,
        plugin: object,
        *,
        alias: str | None = None,
        version: str | None = None,
    ) -> None:
        """Registers a plugin with the registry and marks it ACTIVE.

        All checks run before any state is mutated, so a rejected
        registration leaves the registry and `pluggy` unchanged.

        Args:
            name (str): The canonical name of the plugin.
            plugin (object): The plugin object to register.
            alias (str | None): An optional alias for the plugin.
            version (str | None): An optional version string for the plugin.

        Returns:
            None:

        Raises:
            ServiceError: If the name, alias, or plugin object is already in
                use, if the plugin's current lifecycle state does not allow
                becoming ACTIVE, or if the underlying `pluggy` registration
                fails.
        """
        if name in self._plugins:
            raise ServiceError(f"Plugin {name!r} already registered", http_status=400)
        if name in self._aliases:
            # Lookups resolve aliases first, so a plugin registered under a
            # name that is already an alias would be unreachable.
            raise ServiceError(
                f"Plugin name {name!r} is already in use as an alias",
                http_status=400,
            )
        if plugin in self._plugins.values():
            raise ServiceError(
                "Plugin object already registered under a different name",
                http_status=400,
            )
        if alias and (alias in self._plugins or alias in self._aliases):
            raise ServiceError(f"Alias {alias!r} already in use", http_status=400)
        # Validate the lifecycle move before touching pluggy or internal
        # state, so a rejected registration cannot leave the plugin half
        # registered.
        self._validate_transition(name, PluginState.ACTIVE)
        try:
            self._pm.register(plugin, name)
        except ValueError as error:
            raise ServiceError(
                f"Pluggy failed to register {name}: {error}", http_status=500
            ) from error
        self._plugins[name] = plugin
        self._meta[name] = {"version": version or "unknown"}
        self._states[name] = PluginState.ACTIVE
        if alias:
            self._aliases[alias] = name
        self._emit(
            "registry_plugin_registered",
            {"name": name, "alias": alias, "version": version},
            "register",
        )

    def deregister(self, name: str) -> None:
        """Deregisters a plugin from the registry.

        Unknown names are ignored. On success the plugin's metadata and
        aliases are dropped and its lifecycle state becomes REMOVED.

        Args:
            name (str): The name or alias of the plugin to deregister.

        Returns:
            None:

        Raises:
            ServiceError: If the underlying `pluggy` deregistration fails. In
                that case the registry is left unchanged.
        """
        canonical = self._aliases.get(name, name)
        plugin = self._plugins.get(canonical)
        # Compare against None rather than using truthiness: a plugin object
        # may be falsy (e.g. defines __len__ or __bool__) and must still be
        # unregistered.
        if plugin is None:
            return
        try:
            self._pm.unregister(plugin)
        except ValueError as error:
            raise ServiceError(
                f"Pluggy failed to deregister {canonical}: {error}", http_status=500
            ) from error
        # Drop bookkeeping only after pluggy has released the plugin.
        del self._plugins[canonical]
        self._meta.pop(canonical, None)
        self._states[canonical] = PluginState.REMOVED
        self._aliases = {a: n for a, n in self._aliases.items() if n != canonical}
        self._emit("registry_plugin_deregistered", {"name": canonical}, "deregister")

    def get(self, name: str) -> object:
        """Retrieves a plugin by its name or alias.

        Args:
            name (str): The name or alias of the plugin to retrieve.

        Returns:
            object: The registered plugin object.

        Raises:
            ServiceError: If the plugin is not found.
        """
        canonical = self._aliases.get(name, name)
        try:
            plugin = self._plugins[canonical]
        except KeyError as key_error:
            self._emit(
                "registry_plugin_retrieve_failed",
                {"name": name, "error": str(key_error)},
                "retrieve_failed",
            )
            raise ServiceError(
                f"Plugin {name!r} not found", http_status=404
            ) from key_error
        self._emit("registry_plugin_retrieved", {"name": canonical}, "retrieve")
        return plugin

    def names(self) -> list[str]:
        """Returns a list of all registered plugin names.

        Returns:
            list[str]: The canonical names of all registered plugins, in
                registration order.
        """
        names = list(self._plugins.keys())
        self._emit("registry_list", {"names": names}, "list")
        return names

    def has(self, name: str) -> bool:
        """Checks if a plugin is registered under a given name or alias.

        Args:
            name (str): The name or alias of the plugin to check.

        Returns:
            bool: True if the plugin is registered, otherwise False.
        """
        exists = name in self._plugins or name in self._aliases
        self._emit("registry_contains", {"name": name, "result": exists}, "contains")
        return exists

    def meta(self, name: str) -> dict[str, str]:
        """Retrieves metadata for a specific plugin.

        Args:
            name (str): The name or alias of the plugin.

        Returns:
            dict[str, str]: A copy of the plugin's metadata (empty if unknown).
        """
        canonical = self._aliases.get(name, name)
        info = dict(self._meta.get(canonical, {}))
        self._emit("registry_meta_retrieved", {"name": canonical}, "meta_retrieved")
        return info

    async def call_hook(self, hook: str, *args: Any, **kwargs: Any) -> list[Any]:
        """Invokes a hook on all registered plugins that implement it.

        Coroutine results (from async implementations) are awaited. Only
        final values that are not `None` are collected.

        Args:
            hook (str): The name of the hook to invoke.
            *args (Any): Positional arguments to pass to the hook.
            **kwargs (Any): Keyword arguments to pass to the hook.

        Returns:
            list[Any]: The non-`None` results of all hook implementations.

        Raises:
            ServiceError: If the specified hook does not exist.
        """
        try:
            hook_fn = getattr(self._pm.hook, hook)
        except AttributeError as error:
            raise ServiceError(f"Hook {hook!r} not found", http_status=404) from error
        # Invoked outside the lookup `try`, so an AttributeError raised inside
        # a hook implementation propagates instead of being reported as a
        # missing hook.
        results = hook_fn(*args, **kwargs)

        collected: list[Any] = []

        async def _collect(result: Any) -> None:
            # Await first, then filter: an async implementation returning
            # None must be dropped like a sync one.
            if asyncio.iscoroutine(result):
                result = await result
            if result is not None:
                collected.append(result)

        if isinstance(results, AsyncIterable):
            async for result in results:
                await _collect(result)
        else:
            for result in results:
                await _collect(result)
        self._emit("registry_hook_called", {"hook": hook}, "hook_called")
        return collected
# Logger used by load_entrypoints for debug output about skipped plugins.
_LOG = logging.getLogger("bijux_cli.plugin_loader")

# Names exported by star-imports. The two underscore-prefixed discovery
# helpers are listed explicitly, so they are included too.
__all__ = [
    "Registry",
    "CoreSpec",
    "SPEC_VERSION",
    "PRE_EXECUTE",
    "POST_EXECUTE",
    "command_group",
    "dynamic_choices",
    "load_entrypoints",
    "_iter_plugin_eps",
    "_compatible",
]