Coverage for /home/runner/work/bijux-cli/bijux-cli/src/bijux_cli/infra/process.py: 98%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-26 17:59 +0000

1# SPDX-License-Identifier: Apache-2.0 

2# Copyright © 2025 Bijan Mousavi 

3 

4"""Process execution adapters.""" 

5 

6from __future__ import annotations 

7 

8from collections import OrderedDict 

9from concurrent.futures import ProcessPoolExecutor 

10import os 

11import shutil 

12import subprocess # nosec B404 

13from typing import Any 

14 

15 

def validate_command(cmd: list[str], *, allowed_commands: list[str]) -> list[str]:
    """Validate a command and its arguments against an allow-list.

    Args:
        cmd: Command vector; ``cmd[0]`` is the program and the remaining
            items are its arguments. The list is never modified.
        allowed_commands: Program base names that may be executed.

    Returns:
        A new list whose first item is the path that ``shutil.which``
        resolved for ``cmd[0]``, followed by the original arguments.

    Raises:
        ValueError: If ``cmd`` is empty, the program's base name is not in
            ``allowed_commands``, the program cannot be resolved to an
            executable, the resolved path has a different base name, or any
            argument contains one of the characters ``;|&><`!``.
    """
    if not cmd:
        raise ValueError("invalid command: empty")

    program = cmd[0]
    # NOTE(review): only the base name is compared with the allow-list, so an
    # explicit path whose last component is an allowed name is accepted as
    # well -- confirm that this is intended.
    cmd_name = os.path.basename(program)
    if cmd_name not in allowed_commands:
        raise ValueError(
            f"invalid command {cmd_name!r} not in allowed list: {allowed_commands}"
        )

    resolved_cmd_path = shutil.which(program)
    if not resolved_cmd_path:
        raise ValueError(f"Command not found or not executable: {program!r}")
    if os.path.basename(resolved_cmd_path) != cmd_name:
        raise ValueError(f"Disallowed command path: {program!r}")

    arguments = cmd[1:]
    forbidden = set(";|&><`!")
    for arg in arguments:
        if any(ch in arg for ch in forbidden):
            raise ValueError(f"Unsafe argument: {arg!r}")

    # Build a fresh list instead of overwriting cmd[0]: the caller's list stays
    # intact, including when an argument is rejected after path resolution.
    return [resolved_cmd_path, *arguments]

37 

38 

class ProcessPool:
    """Executes validated commands and keeps their results in an LRU cache.

    Results are keyed by the exact command vector passed to :meth:`run`. At
    most ``_MAX_CACHE`` results are retained; the least recently used entry is
    evicted first. Within this class the ``ProcessPoolExecutor`` is only
    created and shut down; commands themselves are run synchronously with
    ``subprocess.run``.
    """

    # Upper bound on the number of cached command results.
    _MAX_CACHE = 1000

    def __init__(
        self,
        observability: Any,
        telemetry: Any,
        max_workers: int,
        allowed_commands: list[str],
    ) -> None:
        """Initialize the process pool executor.

        Args:
            observability: Object whose ``log(level, message, extra=...)``
                method receives log records.
            telemetry: Object whose ``event(name, payload)`` method receives
                telemetry events.
            max_workers: Worker count handed to ``ProcessPoolExecutor``.
            allowed_commands: Command names passed to validation on every run.
        """
        self._exec = ProcessPoolExecutor(max_workers=max_workers)
        self._log = observability
        self._tel = telemetry
        self._allowed_commands = allowed_commands
        # Maps the command vector to (returncode, stdout, stderr); insertion
        # order doubles as recency order for LRU eviction.
        self._cache: OrderedDict[tuple[str, ...], tuple[int, bytes, bytes]] = (
            OrderedDict()
        )

    def run(self, cmd: list[str], *, executor: str) -> tuple[int, bytes, bytes]:
        """Run a validated command, serving repeated commands from the cache.

        Args:
            cmd: Command vector to execute. The list is never modified.
            executor: Label included in every telemetry payload.

        Returns:
            ``(returncode, stdout, stderr)`` of the command, either freshly
            executed or taken from the cache.

        Raises:
            ValueError: If validation rejects the command.
            RuntimeError: If executing the command or recording its result
                fails; the original exception is chained as the cause.
        """
        key = tuple(cmd)
        if key in self._cache:
            self._log.log("debug", "Process-pool cache hit", extra={"cmd": cmd})
            self._tel.event("procpool_cache_hit", {"cmd": cmd, "executor": executor})
            self._cache.move_to_end(key)
            return self._cache[key]

        # Snapshot of the command exactly as requested, used for every log and
        # telemetry record below.
        orig_cmd = list(cmd)
        try:
            # The validator is looked up on the module object at call time.
            # NOTE(review): presumably so that it can be patched -- confirm.
            validate = __import__(
                "bijux_cli.infra.process", fromlist=["validate_command"]
            ).validate_command
            # Validate a private copy: the validator may rewrite the list it is
            # given, and neither the caller's list nor the telemetry payload
            # should reflect a partially rewritten command.
            safe_cmd = validate(
                list(orig_cmd), allowed_commands=self._allowed_commands
            )
        except ValueError:
            self._tel.event(
                "procpool_execution_failed",
                {"cmd": orig_cmd, "executor": executor, "error": "validation"},
            )
            raise

        try:
            self._log.log("info", "Process-pool executing", extra={"cmd": orig_cmd})
            self._tel.event("procpool_execute", {"cmd": orig_cmd, "executor": executor})

            result = subprocess.run(  # noqa: S603 # nosec B603
                safe_cmd,
                capture_output=True,
                check=False,
                shell=False,
            )

            self._cache[key] = (result.returncode, result.stdout, result.stderr)
            self._cache.move_to_end(key)
            if len(self._cache) > self._MAX_CACHE:
                # Drop the least recently used entry.
                self._cache.popitem(last=False)

            self._tel.event(
                "procpool_executed",
                {
                    "cmd": orig_cmd,
                    "executor": executor,
                    "returncode": result.returncode,
                },
            )
            return result.returncode, result.stdout, result.stderr
        except Exception as exc:
            self._tel.event(
                "procpool_execution_failed",
                {"cmd": orig_cmd, "executor": executor, "error": str(exc)},
            )
            raise RuntimeError(f"Process-pool execution failed: {exc}") from exc

    def shutdown(self) -> None:
        """Shut down the executor, waiting for it, then emit telemetry and a log."""
        self._exec.shutdown(wait=True)
        self._tel.event("procpool_shutdown", {})
        self._log.log("debug", "Process-pool shutdown")

    def get_status(self) -> dict[str, Any]:
        """Return basic status: the number of results currently cached."""
        return {"commands_processed": len(self._cache)}

123 

124 

125__all__ = ["ProcessPool", "validate_command"]