query_engine.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from __future__ import annotations
  2. import json
  3. from dataclasses import dataclass, field
  4. from uuid import uuid4
  5. from .commands import build_command_backlog
  6. from .models import PermissionDenial, UsageSummary
  7. from .port_manifest import PortManifest, build_port_manifest
  8. from .session_store import StoredSession, load_session, save_session
  9. from .tools import build_tool_backlog
  10. from .transcript import TranscriptStore
  11. @dataclass(frozen=True)
  12. class QueryEngineConfig:
  13. max_turns: int = 8
  14. max_budget_tokens: int = 2000
  15. compact_after_turns: int = 12
  16. structured_output: bool = False
  17. structured_retry_limit: int = 2
  18. @dataclass(frozen=True)
  19. class TurnResult:
  20. prompt: str
  21. output: str
  22. matched_commands: tuple[str, ...]
  23. matched_tools: tuple[str, ...]
  24. permission_denials: tuple[PermissionDenial, ...]
  25. usage: UsageSummary
  26. stop_reason: str
  27. @dataclass
  28. class QueryEnginePort:
  29. manifest: PortManifest
  30. config: QueryEngineConfig = field(default_factory=QueryEngineConfig)
  31. session_id: str = field(default_factory=lambda: uuid4().hex)
  32. mutable_messages: list[str] = field(default_factory=list)
  33. permission_denials: list[PermissionDenial] = field(default_factory=list)
  34. total_usage: UsageSummary = field(default_factory=UsageSummary)
  35. transcript_store: TranscriptStore = field(default_factory=TranscriptStore)
  36. @classmethod
  37. def from_workspace(cls) -> 'QueryEnginePort':
  38. return cls(manifest=build_port_manifest())
  39. @classmethod
  40. def from_saved_session(cls, session_id: str) -> 'QueryEnginePort':
  41. stored = load_session(session_id)
  42. transcript = TranscriptStore(entries=list(stored.messages), flushed=True)
  43. return cls(
  44. manifest=build_port_manifest(),
  45. session_id=stored.session_id,
  46. mutable_messages=list(stored.messages),
  47. total_usage=UsageSummary(stored.input_tokens, stored.output_tokens),
  48. transcript_store=transcript,
  49. )
  50. def submit_message(
  51. self,
  52. prompt: str,
  53. matched_commands: tuple[str, ...] = (),
  54. matched_tools: tuple[str, ...] = (),
  55. denied_tools: tuple[PermissionDenial, ...] = (),
  56. ) -> TurnResult:
  57. if len(self.mutable_messages) >= self.config.max_turns:
  58. output = f'Max turns reached before processing prompt: {prompt}'
  59. return TurnResult(
  60. prompt=prompt,
  61. output=output,
  62. matched_commands=matched_commands,
  63. matched_tools=matched_tools,
  64. permission_denials=denied_tools,
  65. usage=self.total_usage,
  66. stop_reason='max_turns_reached',
  67. )
  68. summary_lines = [
  69. f'Prompt: {prompt}',
  70. f'Matched commands: {", ".join(matched_commands) if matched_commands else "none"}',
  71. f'Matched tools: {", ".join(matched_tools) if matched_tools else "none"}',
  72. f'Permission denials: {len(denied_tools)}',
  73. ]
  74. output = self._format_output(summary_lines)
  75. projected_usage = self.total_usage.add_turn(prompt, output)
  76. stop_reason = 'completed'
  77. if projected_usage.input_tokens + projected_usage.output_tokens > self.config.max_budget_tokens:
  78. stop_reason = 'max_budget_reached'
  79. self.mutable_messages.append(prompt)
  80. self.transcript_store.append(prompt)
  81. self.permission_denials.extend(denied_tools)
  82. self.total_usage = projected_usage
  83. self.compact_messages_if_needed()
  84. return TurnResult(
  85. prompt=prompt,
  86. output=output,
  87. matched_commands=matched_commands,
  88. matched_tools=matched_tools,
  89. permission_denials=denied_tools,
  90. usage=self.total_usage,
  91. stop_reason=stop_reason,
  92. )
  93. def stream_submit_message(
  94. self,
  95. prompt: str,
  96. matched_commands: tuple[str, ...] = (),
  97. matched_tools: tuple[str, ...] = (),
  98. denied_tools: tuple[PermissionDenial, ...] = (),
  99. ):
  100. yield {'type': 'message_start', 'session_id': self.session_id, 'prompt': prompt}
  101. if matched_commands:
  102. yield {'type': 'command_match', 'commands': matched_commands}
  103. if matched_tools:
  104. yield {'type': 'tool_match', 'tools': matched_tools}
  105. if denied_tools:
  106. yield {'type': 'permission_denial', 'denials': [denial.tool_name for denial in denied_tools]}
  107. result = self.submit_message(prompt, matched_commands, matched_tools, denied_tools)
  108. yield {'type': 'message_delta', 'text': result.output}
  109. yield {
  110. 'type': 'message_stop',
  111. 'usage': {'input_tokens': result.usage.input_tokens, 'output_tokens': result.usage.output_tokens},
  112. 'stop_reason': result.stop_reason,
  113. 'transcript_size': len(self.transcript_store.entries),
  114. }
  115. def compact_messages_if_needed(self) -> None:
  116. if len(self.mutable_messages) > self.config.compact_after_turns:
  117. self.mutable_messages[:] = self.mutable_messages[-self.config.compact_after_turns :]
  118. self.transcript_store.compact(self.config.compact_after_turns)
  119. def replay_user_messages(self) -> tuple[str, ...]:
  120. return self.transcript_store.replay()
  121. def flush_transcript(self) -> None:
  122. self.transcript_store.flush()
  123. def persist_session(self) -> str:
  124. self.flush_transcript()
  125. path = save_session(
  126. StoredSession(
  127. session_id=self.session_id,
  128. messages=tuple(self.mutable_messages),
  129. input_tokens=self.total_usage.input_tokens,
  130. output_tokens=self.total_usage.output_tokens,
  131. )
  132. )
  133. return str(path)
  134. def _format_output(self, summary_lines: list[str]) -> str:
  135. if self.config.structured_output:
  136. payload = {
  137. 'summary': summary_lines,
  138. 'session_id': self.session_id,
  139. }
  140. return self._render_structured_output(payload)
  141. return '\n'.join(summary_lines)
  142. def _render_structured_output(self, payload: dict[str, object]) -> str:
  143. last_error: Exception | None = None
  144. for _ in range(self.config.structured_retry_limit):
  145. try:
  146. return json.dumps(payload, indent=2)
  147. except (TypeError, ValueError) as exc: # pragma: no cover - defensive branch
  148. last_error = exc
  149. payload = {'summary': ['structured output retry'], 'session_id': self.session_id}
  150. raise RuntimeError('structured output rendering failed') from last_error
  151. def render_summary(self) -> str:
  152. command_backlog = build_command_backlog()
  153. tool_backlog = build_tool_backlog()
  154. sections = [
  155. '# Python Porting Workspace Summary',
  156. '',
  157. self.manifest.to_markdown(),
  158. '',
  159. f'Command surface: {len(command_backlog.modules)} mirrored entries',
  160. *command_backlog.summary_lines()[:10],
  161. '',
  162. f'Tool surface: {len(tool_backlog.modules)} mirrored entries',
  163. *tool_backlog.summary_lines()[:10],
  164. '',
  165. f'Session id: {self.session_id}',
  166. f'Conversation turns stored: {len(self.mutable_messages)}',
  167. f'Permission denials tracked: {len(self.permission_denials)}',
  168. f'Usage totals: in={self.total_usage.input_tokens} out={self.total_usage.output_tokens}',
  169. f'Max turns: {self.config.max_turns}',
  170. f'Max budget tokens: {self.config.max_budget_tokens}',
  171. f'Transcript flushed: {self.transcript_store.flushed}',
  172. ]
  173. return '\n'.join(sections)