Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/plugins/registry.py: 97%

245 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-26 17:59 +0000

1# SPDX-License-Identifier: Apache-2.0 

2# Copyright © 2025 Bijan Mousavi 

3 

4"""Provides a concrete plugin registry service using the `pluggy` framework. 

5 

6This module defines the `Registry` class, which implements the 

7`RegistryProtocol`. It serves as the central manager for the entire plugin 

8lifecycle, including registration, aliasing, metadata storage, and the 

9invocation of plugin hooks. It is built on top of the `pluggy` library to 

10provide a robust and extensible plugin architecture. 

11""" 

12 

13from __future__ import annotations 

14 

15import asyncio 

16from collections.abc import AsyncIterable, Callable 

17import contextlib 

18import importlib.metadata as im 

19import logging 

20import traceback 

21from types import MappingProxyType 

22from typing import Any 

23 

24from injector import inject 

25from packaging.specifiers import SpecifierSet 

26from packaging.version import Version as PkgVersion 

27import pluggy 

28 

29from bijux_cli.core.di import DIContainer 

30from bijux_cli.plugins.contracts import PluginState, RegistryProtocol 

31from bijux_cli.services.contracts import ObservabilityProtocol, TelemetryProtocol 

32from bijux_cli.services.errors import ServiceError 

33 

# Hook names; they match the `pre_execute` / `post_execute` hook
# specifications declared on `CoreSpec`.
PRE_EXECUTE = "pre_execute"
POST_EXECUTE = "post_execute"
# Taken from the `version` attribute of the `bijux_cli` package at import time.
SPEC_VERSION = __import__("bijux_cli").version

# Marker used to declare hook specifications for the "bijux" pluggy project.
hookspec = pluggy.HookspecMarker("bijux")

39 

40 

class CoreSpec:
    """Hook specifications that CLI plugins may implement.

    Every hook is marked with `hookspec`. When a specification method is
    called directly it only emits a debug record through the observability
    service resolved in `__init__`.
    """

    def __init__(self, dependency_injector: DIContainer) -> None:
        """Resolve the observability service from the DI container.

        Args:
            dependency_injector (DIContainer): The container used to resolve
                `ObservabilityProtocol`.
        """
        self._log = dependency_injector.resolve(ObservabilityProtocol)

    def _debug(self, message: str, extra: dict[str, Any]) -> None:
        """Emit one debug-level record carrying the structured payload."""
        self._log.log("debug", message, extra=extra)

    @hookspec
    async def startup(self) -> None:
        """Hook called at startup."""
        self._debug("Hook startup called", {})

    @hookspec
    async def shutdown(self) -> None:
        """Hook called at shutdown."""
        self._debug("Hook shutdown called", {})

    @hookspec
    async def pre_execute(
        self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        """Hook called before command execution."""
        payload = {"name": name, "args": args, "kwargs": kwargs}
        self._debug("Hook pre_execute called", payload)

    @hookspec
    async def post_execute(self, name: str, result: Any) -> None:
        """Hook called after command execution."""
        # Only the repr of the result is logged, not the object itself.
        payload = {"name": name, "result": repr(result)}
        self._debug("Hook post_execute called", payload)

    @hookspec
    def health(self) -> bool | str:
        """Hook used for health checks; the specification itself reports True."""
        self._debug("Hook health called", {})
        return True

83 

84 

85def command_group( 

86 name: str, 

87 *, 

88 version: str | None = None, 

89) -> Callable[[str], Callable[[Callable[..., Any]], Callable[..., Any]]]: 

90 """Decorator factory for registering plugin subcommands under a group.""" 

91 

92 def with_sub(sub: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: 

93 if " " in sub: 

94 raise ValueError("subcommand may not contain spaces") 

95 full = f"{name} {sub}" 

96 

97 def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: 

98 try: 

99 di = DIContainer.current() 

100 reg: RegistryProtocol = di.resolve(RegistryProtocol) 

101 except KeyError as exc: 

102 raise RuntimeError("RegistryProtocol is not initialized") from exc 

103 

104 reg.register(full, fn, alias=None, version=version) 

105 

106 try: 

107 obs: ObservabilityProtocol = di.resolve(ObservabilityProtocol) 

108 obs.log( 

109 "info", 

110 "Registered command group", 

111 extra={"cmd": full, "version": version}, 

112 ) 

113 except KeyError: 

114 pass 

115 

116 try: 

117 tel: TelemetryProtocol = di.resolve(TelemetryProtocol) 

118 tel.event( 

119 "command_group_registered", {"command": full, "version": version} 

120 ) 

121 except KeyError: 

122 pass 

123 

124 return fn 

125 

126 return decorator 

127 

128 return with_sub 

129 

130 

def dynamic_choices(
    callback: Callable[[], list[str]],
    *,
    case_sensitive: bool = True,
) -> Callable[[Any, Any, str], list[str]]:
    """Create a Typer completer from a callback function.

    Args:
        callback (Callable[[], list[str]]): Called on every completion
            request to produce the candidate choices.
        case_sensitive (bool): When False, prefixes are compared after
            lower-casing both the choice and the incomplete text.

    Returns:
        Callable[[Any, Any, str], list[str]]: A completer that returns the
            choices starting with the incomplete text, in callback order.
    """

    def completer(_ctx: Any, _param: Any, incomplete: str) -> list[str]:
        """Return the candidates that start with `incomplete`."""
        if case_sensitive:

            def matches(choice: str) -> bool:
                return choice.startswith(incomplete)

        else:

            def matches(choice: str) -> bool:
                return choice.lower().startswith(incomplete.lower())

        return [choice for choice in callback() if matches(choice)]

    return completer

145 

146 

def _iter_plugin_eps() -> list[im.EntryPoint]:
    """Return all entry points in the 'bijux_cli.plugins' group.

    Returns:
        list[im.EntryPoint]: The matching entry points, or an empty list when
            reading the entry-point metadata raises any exception.
    """
    # Discovery is best effort: any failure yields "no plugins".
    with contextlib.suppress(Exception):
        return list(im.entry_points().select(group="bijux_cli.plugins"))
    return []

154 

155 

def _compatible(plugin: Any) -> bool:
    """Determine whether a plugin accepts the host's CLI API version.

    Args:
        plugin (Any): The plugin object; its optional `requires_api_version`
            attribute is read as a version specifier and defaults to
            ">=0.0.0" when absent.

    Returns:
        bool: True when `bijux_cli.api_version` satisfies the specifier;
            False when it does not or when parsing/reading either value
            raises.
    """
    import bijux_cli

    required = getattr(plugin, "requires_api_version", ">=0.0.0")
    try:
        host_version = PkgVersion(str(bijux_cli.api_version))
        satisfied = SpecifierSet(required).contains(host_version)
    except Exception:
        # An unreadable host version or malformed specifier counts as
        # incompatible rather than an error.
        return False
    return satisfied

167 

168 

async def load_entrypoints(
    di: DIContainer | None = None,
    registry: RegistryProtocol | None = None,
) -> None:
    """Discover, load, and register all entry point-based plugins.

    For each entry point returned by `_iter_plugin_eps` the plugin class is
    loaded and instantiated in a worker thread, checked with `_compatible`,
    moved through the DISCOVERED and INSTALLED states, registered, and then
    given the chance to run its optional `startup` callable. Any exception
    raised along the way is logged and reported, the plugin is deregistered
    on a best-effort basis, and loading continues with the next entry point.

    Args:
        di (DIContainer | None): The container used to resolve services and
            passed to each plugin's `startup`. Defaults to
            `DIContainer.current()`.
        registry (RegistryProtocol | None): The registry to populate.
            Defaults to the `RegistryProtocol` resolved from `di`.

    Returns:
        None:
    """
    import bijux_cli

    di = di or DIContainer.current()
    registry = registry or di.resolve(RegistryProtocol)

    # Both services are optional; `None` is passed as the fallback value.
    obs = di.resolve(ObservabilityProtocol, None)
    tel = di.resolve(TelemetryProtocol, None)

    for ep in _iter_plugin_eps():
        try:
            # Importing and instantiating run third-party code, so both are
            # pushed to a worker thread.
            plugin_class = await asyncio.to_thread(ep.load)
            plugin = await asyncio.to_thread(plugin_class)

            if not _compatible(plugin):
                raise RuntimeError(
                    f"Plugin '{ep.name}' requires API {getattr(plugin, 'requires_api_version', 'N/A')}, "
                    f"host is {bijux_cli.api_version}"
                )

            # Normalise a non-string `version` to a string on both the class
            # and the instance.
            for tgt in (plugin_class, plugin):
                raw = getattr(tgt, "version", None)
                if raw is not None and not isinstance(raw, str):
                    tgt.version = str(raw)

            registry.transition(ep.name, PluginState.DISCOVERED)
            registry.transition(ep.name, PluginState.INSTALLED)
            # `version` is optional (see the normalisation above), so read it
            # tolerantly instead of failing with AttributeError when a plugin
            # does not declare one.
            registry.register(
                ep.name,
                plugin,
                alias=None,
                version=getattr(plugin, "version", None),
            )

            startup = getattr(plugin, "startup", None)
            if asyncio.iscoroutinefunction(startup):
                await startup(di)
            elif callable(startup):
                await asyncio.to_thread(startup, di)

            if obs:
                obs.log("info", f"Loaded plugin '{ep.name}'", extra={})
            if tel:
                tel.event("entrypoint_plugin_loaded", {"name": ep.name})

        except Exception as exc:
            # One broken plugin must not stop the others from loading; undo
            # any partial registration and carry on.
            with contextlib.suppress(Exception):
                registry.deregister(ep.name)

            if obs:
                obs.log(
                    "error",
                    f"Failed to load plugin '{ep.name}'",
                    extra={"trace": traceback.format_exc(limit=5)},
                )
            if tel:
                tel.event(
                    "entrypoint_plugin_failed", {"name": ep.name, "error": str(exc)}
                )

            _LOG.debug("Skipped plugin %s: %s", ep.name, exc, exc_info=True)

229 

230 

class Registry(RegistryProtocol):
    """A `pluggy`-based registry for managing CLI plugins.

    This class provides aliasing, metadata storage, lifecycle state tracking,
    and telemetry integration on top of the core `pluggy` plugin management
    system.

    Attributes:
        _telemetry (TelemetryProtocol): The telemetry service for events.
        _pm (pluggy.PluginManager): The underlying `pluggy` plugin manager.
        _plugins (dict): A mapping of canonical plugin names to plugin objects.
        _aliases (dict): A mapping of alias names to canonical plugin names.
        _meta (dict): A mapping of canonical plugin names to their metadata.
        _states (dict): A mapping of canonical plugin names to their current
            lifecycle state.
        mapping (MappingProxyType): A read-only view of the `_plugins` mapping.
    """

    # Lifecycle transitions accepted by `transition`, keyed by current state.
    # Built once and immutable instead of being rebuilt on every call.
    _ALLOWED_TRANSITIONS = MappingProxyType(
        {
            PluginState.DISCOVERED: frozenset(
                {PluginState.INSTALLED, PluginState.ACTIVE, PluginState.REMOVED}
            ),
            PluginState.INSTALLED: frozenset(
                {PluginState.ACTIVE, PluginState.INACTIVE}
            ),
            PluginState.ACTIVE: frozenset(
                {PluginState.INACTIVE, PluginState.REMOVED}
            ),
            PluginState.INACTIVE: frozenset(
                {PluginState.ACTIVE, PluginState.REMOVED}
            ),
            PluginState.REMOVED: frozenset(),
        }
    )

    @inject
    def __init__(self, telemetry: TelemetryProtocol):
        """Initializes the `Registry` service.

        Args:
            telemetry (TelemetryProtocol): The telemetry service for tracking
                registry events.
        """
        self._telemetry = telemetry
        self._pm = pluggy.PluginManager("bijux")
        self._pm.add_hookspecs(CoreSpec)
        self._plugins: dict[str, object] = {}
        self._aliases: dict[str, str] = {}
        self._meta: dict[str, dict[str, str]] = {}
        self._states: dict[str, PluginState] = {}
        self.mapping = MappingProxyType(self._plugins)

    def _emit(self, event: str, payload: dict[str, Any], *, operation: str) -> None:
        """Send a telemetry event, reporting a `RuntimeError` as a failure event.

        Args:
            event (str): The telemetry event name.
            payload (dict[str, Any]): The event payload.
            operation (str): The label placed in the
                "registry_telemetry_failed" event when the first send raises
                `RuntimeError`.

        Returns:
            None:
        """
        try:
            self._telemetry.event(event, payload)
        except RuntimeError as error:
            # The fallback send is not guarded; if it raises too, the
            # exception propagates to the caller.
            self._telemetry.event(
                "registry_telemetry_failed",
                {"operation": operation, "error": str(error)},
            )

    def state(self, name: str) -> PluginState | None:
        """Return the lifecycle state for a plugin.

        Args:
            name (str): The name or alias of the plugin.

        Returns:
            PluginState | None: The recorded state, or None when no state has
                been recorded for the plugin.
        """
        canonical = self._aliases.get(name, name)
        return self._states.get(canonical)

    def transition(self, name: str, state: PluginState) -> None:
        """Transition a plugin to a new lifecycle state.

        A plugin with no recorded state accepts any initial state; afterwards
        only the moves listed in `_ALLOWED_TRANSITIONS` are accepted.

        Args:
            name (str): The name or alias of the plugin.
            state (PluginState): The state to move the plugin to.

        Returns:
            None:

        Raises:
            ServiceError: If the move from the current state to `state` is
                not allowed.
        """
        canonical = self._aliases.get(name, name)
        current = self._states.get(canonical)
        if current is not None and state not in self._ALLOWED_TRANSITIONS.get(
            current, frozenset()
        ):
            raise ServiceError(
                f"Invalid plugin state transition {current.value} -> {state.value}",
                http_status=400,
            )
        self._states[canonical] = state

    def register(
        self,
        name: str,
        plugin: object,
        *,
        alias: str | None = None,
        version: str | None = None,
    ) -> None:
        """Registers a plugin with the registry.

        Args:
            name (str): The canonical name of the plugin.
            plugin (object): The plugin object to register.
            alias (str | None): An optional alias for the plugin.
            version (str | None): An optional version string for the plugin;
                stored as "unknown" when missing or empty.

        Returns:
            None:

        Raises:
            ServiceError: If the name, alias, or plugin object is already
                registered, if the underlying `pluggy` registration fails, or
                if the plugin's lifecycle state does not allow activation.
        """
        if name in self._plugins:
            raise ServiceError(f"Plugin {name!r} already registered", http_status=400)
        if plugin in self._plugins.values():
            raise ServiceError(
                "Plugin object already registered under a different name",
                http_status=400,
            )
        if alias and (alias in self._plugins or alias in self._aliases):
            raise ServiceError(f"Alias {alias!r} already in use", http_status=400)
        try:
            self._pm.register(plugin, name)
        except ValueError as error:
            raise ServiceError(
                f"Pluggy failed to register {name}: {error}", http_status=500
            ) from error
        try:
            self.transition(name, PluginState.ACTIVE)
        except ServiceError:
            # A rejected activation must not leave the plugin known to pluggy
            # but absent from the registry's own bookkeeping.
            with contextlib.suppress(Exception):
                self._pm.unregister(plugin)
            raise
        self._plugins[name] = plugin
        self._meta[name] = {"version": version or "unknown"}
        if alias:
            self._aliases[alias] = name
        self._emit(
            "registry_plugin_registered",
            {"name": name, "alias": alias, "version": version},
            operation="register",
        )

    def deregister(self, name: str) -> None:
        """Deregisters a plugin from the registry.

        Unknown names are ignored. On success the plugin's metadata and
        aliases are dropped and its state is recorded as REMOVED.

        Args:
            name (str): The name or alias of the plugin to deregister.

        Returns:
            None:

        Raises:
            ServiceError: If the underlying `pluggy` deregistration fails.
        """
        canonical = self._aliases.get(name, name)
        plugin = self._plugins.pop(canonical, None)
        # Compare against None explicitly: a registered plugin object may be
        # falsy (e.g. define `__bool__`/`__len__`) and must still be removed.
        if plugin is None:
            return
        try:
            self._pm.unregister(plugin)
        except ValueError as error:
            raise ServiceError(
                f"Pluggy failed to deregister {canonical}: {error}", http_status=500
            ) from error
        self._meta.pop(canonical, None)
        self._states[canonical] = PluginState.REMOVED
        self._aliases = {a: n for a, n in self._aliases.items() if n != canonical}
        self._emit(
            "registry_plugin_deregistered",
            {"name": canonical},
            operation="deregister",
        )

    def get(self, name: str) -> object:
        """Retrieves a plugin by its name or alias.

        Args:
            name (str): The name or alias of the plugin to retrieve.

        Returns:
            object: The registered plugin object.

        Raises:
            ServiceError: If the plugin is not found.
        """
        canonical = self._aliases.get(name, name)
        try:
            plugin = self._plugins[canonical]
        except KeyError as key_error:
            self._emit(
                "registry_plugin_retrieve_failed",
                {"name": name, "error": str(key_error)},
                operation="retrieve_failed",
            )
            raise ServiceError(
                f"Plugin {name!r} not found", http_status=404
            ) from key_error
        self._emit(
            "registry_plugin_retrieved", {"name": canonical}, operation="retrieve"
        )
        return plugin

    def names(self) -> list[str]:
        """Returns a list of all registered plugin names.

        Returns:
            list[str]: A list of the canonical names of all registered plugins.
        """
        names = list(self._plugins.keys())
        self._emit("registry_list", {"names": names}, operation="list")
        return names

    def has(self, name: str) -> bool:
        """Checks if a plugin is registered under a given name or alias.

        Args:
            name (str): The name or alias of the plugin to check.

        Returns:
            bool: True if the plugin is registered, otherwise False.
        """
        exists = name in self._plugins or name in self._aliases
        self._emit(
            "registry_contains",
            {"name": name, "result": exists},
            operation="contains",
        )
        return exists

    def meta(self, name: str) -> dict[str, str]:
        """Retrieves metadata for a specific plugin.

        Args:
            name (str): The name or alias of the plugin.

        Returns:
            dict[str, str]: A copy of the plugin's metadata; empty when the
                plugin has no metadata recorded.
        """
        canonical = self._aliases.get(name, name)
        info = dict(self._meta.get(canonical, {}))
        self._emit(
            "registry_meta_retrieved",
            {"name": canonical},
            operation="meta_retrieved",
        )
        return info

    async def call_hook(self, hook: str, *args: Any, **kwargs: Any) -> list[Any]:
        """Invokes a hook on all registered plugins that implement it.

        This method handles results from multiple plugins, awaiting any results
        that are coroutines.

        Args:
            hook (str): The name of the hook to invoke.
            *args (Any): Positional arguments to pass to the hook.
            **kwargs (Any): Keyword arguments to pass to the hook.

        Returns:
            list[Any]: The collected results. Coroutine results are awaited
                and always included; other results are included unless they
                are `None`.

        Raises:
            ServiceError: If the specified hook does not exist.
        """
        # NOTE(review): the lookup and the call share one `try`, so an
        # AttributeError raised inside a hook implementation is also reported
        # as "not found". Left unchanged to keep the exception type callers
        # see; consider narrowing the `try` to the `getattr` alone.
        try:
            hook_fn = getattr(self._pm.hook, hook)
            results = hook_fn(*args, **kwargs)
        except AttributeError as error:
            raise ServiceError(f"Hook {hook!r} not found", http_status=404) from error
        collected: list[Any] = []
        if isinstance(results, AsyncIterable):
            async for result in results:
                if asyncio.iscoroutine(result):
                    collected.append(await result)
                elif result is not None:
                    collected.append(result)
        else:
            for result in results:
                if asyncio.iscoroutine(result):
                    collected.append(await result)
                elif result is not None:
                    collected.append(result)
        self._emit("registry_hook_called", {"hook": hook}, operation="hook_called")
        return collected

516 

517 

# Names exported by a star-import of this module. The underscore-prefixed
# helpers are listed explicitly, so they are exported as well.
__all__ = [
    "Registry",
    "CoreSpec",
    "SPEC_VERSION",
    "PRE_EXECUTE",
    "POST_EXECUTE",
    "command_group",
    "dynamic_choices",
    "load_entrypoints",
    "_iter_plugin_eps",
    "_compatible",
]
# Logger used by `load_entrypoints` to record plugins that were skipped.
_LOG = logging.getLogger("bijux_cli.plugin_loader")