lib.rs 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123
  1. use std::collections::HashMap;
  2. use std::io;
  3. use std::sync::Arc;
  4. use std::time::{SystemTime, UNIX_EPOCH};
  5. use api::{InputContentBlock, MessageRequest, MessageResponse, OutputContentBlock, Usage};
  6. use serde_json::{json, Value};
  7. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  8. use tokio::net::TcpListener;
  9. use tokio::sync::{oneshot, Mutex};
  10. use tokio::task::JoinHandle;
  /// Marker a prompt token must start with (e.g. `PARITY_SCENARIO:streaming_text`) to select
  /// which scripted conversation the mock service replays.
  pub const SCENARIO_PREFIX: &str = "PARITY_SCENARIO:";

  /// Model identifier stamped into the `model` field of the mock responses built in this module.
  pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
  /// Snapshot of one HTTP request the mock service accepted, recorded before the scripted
  /// response is written back.
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub struct CapturedRequest {
      /// HTTP method from the request line (e.g. `POST`).
      pub method: String,
      /// Request target exactly as it appeared on the request line.
      pub path: String,
      /// Headers keyed by lower-cased name, with trimmed values.
      pub headers: HashMap<String, String>,
      /// Canonical name of the parity scenario detected in the request body.
      pub scenario: String,
      /// The request's `stream` flag.
      pub stream: bool,
      /// The JSON body exactly as received.
      pub raw_body: String,
  }
  22. pub struct MockAnthropicService {
  23. base_url: String,
  24. requests: Arc<Mutex<Vec<CapturedRequest>>>,
  25. shutdown: Option<oneshot::Sender<()>>,
  26. join_handle: JoinHandle<()>,
  27. }
  28. impl MockAnthropicService {
  29. pub async fn spawn() -> io::Result<Self> {
  30. Self::spawn_on("127.0.0.1:0").await
  31. }
  32. pub async fn spawn_on(bind_addr: &str) -> io::Result<Self> {
  33. let listener = TcpListener::bind(bind_addr).await?;
  34. let address = listener.local_addr()?;
  35. let requests = Arc::new(Mutex::new(Vec::new()));
  36. let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
  37. let request_state = Arc::clone(&requests);
  38. let join_handle = tokio::spawn(async move {
  39. loop {
  40. tokio::select! {
  41. _ = &mut shutdown_rx => break,
  42. accepted = listener.accept() => {
  43. let Ok((socket, _)) = accepted else {
  44. break;
  45. };
  46. let request_state = Arc::clone(&request_state);
  47. tokio::spawn(async move {
  48. let _ = handle_connection(socket, request_state).await;
  49. });
  50. }
  51. }
  52. }
  53. });
  54. Ok(Self {
  55. base_url: format!("http://{address}"),
  56. requests,
  57. shutdown: Some(shutdown_tx),
  58. join_handle,
  59. })
  60. }
  61. #[must_use]
  62. pub fn base_url(&self) -> String {
  63. self.base_url.clone()
  64. }
  65. pub async fn captured_requests(&self) -> Vec<CapturedRequest> {
  66. self.requests.lock().await.clone()
  67. }
  68. }
  69. impl Drop for MockAnthropicService {
  70. fn drop(&mut self) {
  71. if let Some(shutdown) = self.shutdown.take() {
  72. let _ = shutdown.send(());
  73. }
  74. self.join_handle.abort();
  75. }
  76. }
  /// Scripted conversations the mock service knows how to replay.
  #[derive(Clone, Copy, Debug, Eq, PartialEq)]
  enum Scenario {
      StreamingText,
      ReadFileRoundtrip,
      GrepChunkAssembly,
      WriteFileAllowed,
      WriteFileDenied,
      MultiToolTurnRoundtrip,
      BashStdoutRoundtrip,
      BashPermissionPromptApproved,
      BashPermissionPromptDenied,
      PluginToolRoundtrip,
      AutoCompactTriggered,
      TokenCostReporting,
  }

  impl Scenario {
      /// Resolves a snake_case scenario name (surrounding whitespace ignored) to its variant.
      /// Matching is case-sensitive; unknown names yield `None`.
      fn parse(value: &str) -> Option<Self> {
          let scenario = match value.trim() {
              "streaming_text" => Self::StreamingText,
              "read_file_roundtrip" => Self::ReadFileRoundtrip,
              "grep_chunk_assembly" => Self::GrepChunkAssembly,
              "write_file_allowed" => Self::WriteFileAllowed,
              "write_file_denied" => Self::WriteFileDenied,
              "multi_tool_turn_roundtrip" => Self::MultiToolTurnRoundtrip,
              "bash_stdout_roundtrip" => Self::BashStdoutRoundtrip,
              "bash_permission_prompt_approved" => Self::BashPermissionPromptApproved,
              "bash_permission_prompt_denied" => Self::BashPermissionPromptDenied,
              "plugin_tool_roundtrip" => Self::PluginToolRoundtrip,
              "auto_compact_triggered" => Self::AutoCompactTriggered,
              "token_cost_reporting" => Self::TokenCostReporting,
              _ => return None,
          };
          Some(scenario)
      }

      /// Canonical snake_case name of the scenario; the inverse of [`Scenario::parse`].
      fn name(self) -> &'static str {
          match self {
              Self::StreamingText => "streaming_text",
              Self::ReadFileRoundtrip => "read_file_roundtrip",
              Self::GrepChunkAssembly => "grep_chunk_assembly",
              Self::WriteFileAllowed => "write_file_allowed",
              Self::WriteFileDenied => "write_file_denied",
              Self::MultiToolTurnRoundtrip => "multi_tool_turn_roundtrip",
              Self::BashStdoutRoundtrip => "bash_stdout_roundtrip",
              Self::BashPermissionPromptApproved => "bash_permission_prompt_approved",
              Self::BashPermissionPromptDenied => "bash_permission_prompt_denied",
              Self::PluginToolRoundtrip => "plugin_tool_roundtrip",
              Self::AutoCompactTriggered => "auto_compact_triggered",
              Self::TokenCostReporting => "token_cost_reporting",
          }
      }
  }
  127. async fn handle_connection(
  128. mut socket: tokio::net::TcpStream,
  129. requests: Arc<Mutex<Vec<CapturedRequest>>>,
  130. ) -> io::Result<()> {
  131. let (method, path, headers, raw_body) = read_http_request(&mut socket).await?;
  132. let request: MessageRequest = serde_json::from_str(&raw_body)
  133. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  134. let scenario = detect_scenario(&request)
  135. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing parity scenario"))?;
  136. requests.lock().await.push(CapturedRequest {
  137. method,
  138. path,
  139. headers,
  140. scenario: scenario.name().to_string(),
  141. stream: request.stream,
  142. raw_body,
  143. });
  144. let response = build_http_response(&request, scenario);
  145. socket.write_all(response.as_bytes()).await?;
  146. Ok(())
  147. }
  148. async fn read_http_request(
  149. socket: &mut tokio::net::TcpStream,
  150. ) -> io::Result<(String, String, HashMap<String, String>, String)> {
  151. let mut buffer = Vec::new();
  152. let mut header_end = None;
  153. loop {
  154. let mut chunk = [0_u8; 1024];
  155. let read = socket.read(&mut chunk).await?;
  156. if read == 0 {
  157. break;
  158. }
  159. buffer.extend_from_slice(&chunk[..read]);
  160. if let Some(position) = find_header_end(&buffer) {
  161. header_end = Some(position);
  162. break;
  163. }
  164. }
  165. let header_end = header_end
  166. .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "missing http headers"))?;
  167. let (header_bytes, remaining) = buffer.split_at(header_end);
  168. let header_text = String::from_utf8(header_bytes.to_vec())
  169. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  170. let mut lines = header_text.split("\r\n");
  171. let request_line = lines
  172. .next()
  173. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing request line"))?;
  174. let mut request_parts = request_line.split_whitespace();
  175. let method = request_parts
  176. .next()
  177. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing method"))?
  178. .to_string();
  179. let path = request_parts
  180. .next()
  181. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing path"))?
  182. .to_string();
  183. let mut headers = HashMap::new();
  184. let mut content_length = 0_usize;
  185. for line in lines {
  186. if line.is_empty() {
  187. continue;
  188. }
  189. let (name, value) = line.split_once(':').ok_or_else(|| {
  190. io::Error::new(io::ErrorKind::InvalidData, "malformed http header line")
  191. })?;
  192. let value = value.trim().to_string();
  193. if name.eq_ignore_ascii_case("content-length") {
  194. content_length = value.parse().map_err(|error| {
  195. io::Error::new(
  196. io::ErrorKind::InvalidData,
  197. format!("invalid content-length: {error}"),
  198. )
  199. })?;
  200. }
  201. headers.insert(name.to_ascii_lowercase(), value);
  202. }
  203. let mut body = remaining[4..].to_vec();
  204. while body.len() < content_length {
  205. let mut chunk = vec![0_u8; content_length - body.len()];
  206. let read = socket.read(&mut chunk).await?;
  207. if read == 0 {
  208. break;
  209. }
  210. body.extend_from_slice(&chunk[..read]);
  211. }
  212. let body = String::from_utf8(body)
  213. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  214. Ok((method, path, headers, body))
  215. }
  /// Returns the byte offset of the first `\r\n\r\n` (the blank line ending an HTTP header
  /// block) in `bytes`, or `None` when it does not occur.
  fn find_header_end(bytes: &[u8]) -> Option<usize> {
      const TERMINATOR: &[u8] = b"\r\n\r\n";
      let last_start = bytes.len().checked_sub(TERMINATOR.len())?;
      (0..=last_start).find(|&start| &bytes[start..start + TERMINATOR.len()] == TERMINATOR)
  }
  219. fn detect_scenario(request: &MessageRequest) -> Option<Scenario> {
  220. request.messages.iter().rev().find_map(|message| {
  221. message.content.iter().rev().find_map(|block| match block {
  222. InputContentBlock::Text { text } => text
  223. .split_whitespace()
  224. .find_map(|token| token.strip_prefix(SCENARIO_PREFIX))
  225. .and_then(Scenario::parse),
  226. _ => None,
  227. })
  228. })
  229. }
  230. fn latest_tool_result(request: &MessageRequest) -> Option<(String, bool)> {
  231. request.messages.iter().rev().find_map(|message| {
  232. message.content.iter().rev().find_map(|block| match block {
  233. InputContentBlock::ToolResult {
  234. content, is_error, ..
  235. } => Some((flatten_tool_result_content(content), *is_error)),
  236. _ => None,
  237. })
  238. })
  239. }
  240. fn tool_results_by_name(request: &MessageRequest) -> HashMap<String, (String, bool)> {
  241. let mut tool_names_by_id = HashMap::new();
  242. for message in &request.messages {
  243. for block in &message.content {
  244. if let InputContentBlock::ToolUse { id, name, .. } = block {
  245. tool_names_by_id.insert(id.clone(), name.clone());
  246. }
  247. }
  248. }
  249. let mut results = HashMap::new();
  250. for message in request.messages.iter().rev() {
  251. for block in message.content.iter().rev() {
  252. if let InputContentBlock::ToolResult {
  253. tool_use_id,
  254. content,
  255. is_error,
  256. } = block
  257. {
  258. let tool_name = tool_names_by_id
  259. .get(tool_use_id)
  260. .cloned()
  261. .unwrap_or_else(|| tool_use_id.clone());
  262. results
  263. .entry(tool_name)
  264. .or_insert_with(|| (flatten_tool_result_content(content), *is_error));
  265. }
  266. }
  267. }
  268. results
  269. }
  270. fn flatten_tool_result_content(content: &[api::ToolResultContentBlock]) -> String {
  271. content
  272. .iter()
  273. .map(|block| match block {
  274. api::ToolResultContentBlock::Text { text } => text.clone(),
  275. api::ToolResultContentBlock::Json { value } => value.to_string(),
  276. })
  277. .collect::<Vec<_>>()
  278. .join("\n")
  279. }
  280. #[allow(clippy::too_many_lines)]
  281. fn build_http_response(request: &MessageRequest, scenario: Scenario) -> String {
  282. let response = if request.stream {
  283. let body = build_stream_body(request, scenario);
  284. return http_response(
  285. "200 OK",
  286. "text/event-stream",
  287. &body,
  288. &[("x-request-id", request_id_for(scenario))],
  289. );
  290. } else {
  291. build_message_response(request, scenario)
  292. };
  293. http_response(
  294. "200 OK",
  295. "application/json",
  296. &serde_json::to_string(&response).expect("message response should serialize"),
  297. &[("request-id", request_id_for(scenario))],
  298. )
  299. }
  300. #[allow(clippy::too_many_lines)]
  301. fn build_stream_body(request: &MessageRequest, scenario: Scenario) -> String {
  302. match scenario {
  303. Scenario::StreamingText => streaming_text_sse(),
  304. Scenario::ReadFileRoundtrip => match latest_tool_result(request) {
  305. Some((tool_output, _)) => final_text_sse(&format!(
  306. "read_file roundtrip complete: {}",
  307. extract_read_content(&tool_output)
  308. )),
  309. None => tool_use_sse(
  310. "toolu_read_fixture",
  311. "read_file",
  312. &[r#"{"path":"fixture.txt"}"#],
  313. ),
  314. },
  315. Scenario::GrepChunkAssembly => match latest_tool_result(request) {
  316. Some((tool_output, _)) => final_text_sse(&format!(
  317. "grep_search matched {} occurrences",
  318. extract_num_matches(&tool_output)
  319. )),
  320. None => tool_use_sse(
  321. "toolu_grep_fixture",
  322. "grep_search",
  323. &[
  324. "{\"pattern\":\"par",
  325. "ity\",\"path\":\"fixture.txt\"",
  326. ",\"output_mode\":\"count\"}",
  327. ],
  328. ),
  329. },
  330. Scenario::WriteFileAllowed => match latest_tool_result(request) {
  331. Some((tool_output, _)) => final_text_sse(&format!(
  332. "write_file succeeded: {}",
  333. extract_file_path(&tool_output)
  334. )),
  335. None => tool_use_sse(
  336. "toolu_write_allowed",
  337. "write_file",
  338. &[r#"{"path":"generated/output.txt","content":"created by mock service\n"}"#],
  339. ),
  340. },
  341. Scenario::WriteFileDenied => match latest_tool_result(request) {
  342. Some((tool_output, _)) => {
  343. final_text_sse(&format!("write_file denied as expected: {tool_output}"))
  344. }
  345. None => tool_use_sse(
  346. "toolu_write_denied",
  347. "write_file",
  348. &[r#"{"path":"generated/denied.txt","content":"should not exist\n"}"#],
  349. ),
  350. },
  351. Scenario::MultiToolTurnRoundtrip => {
  352. let tool_results = tool_results_by_name(request);
  353. match (
  354. tool_results.get("read_file"),
  355. tool_results.get("grep_search"),
  356. ) {
  357. (Some((read_output, _)), Some((grep_output, _))) => final_text_sse(&format!(
  358. "multi-tool roundtrip complete: {} / {} occurrences",
  359. extract_read_content(read_output),
  360. extract_num_matches(grep_output)
  361. )),
  362. _ => tool_uses_sse(&[
  363. ToolUseSse {
  364. tool_id: "toolu_multi_read",
  365. tool_name: "read_file",
  366. partial_json_chunks: &[r#"{"path":"fixture.txt"}"#],
  367. },
  368. ToolUseSse {
  369. tool_id: "toolu_multi_grep",
  370. tool_name: "grep_search",
  371. partial_json_chunks: &[
  372. "{\"pattern\":\"par",
  373. "ity\",\"path\":\"fixture.txt\"",
  374. ",\"output_mode\":\"count\"}",
  375. ],
  376. },
  377. ]),
  378. }
  379. }
  380. Scenario::BashStdoutRoundtrip => match latest_tool_result(request) {
  381. Some((tool_output, _)) => final_text_sse(&format!(
  382. "bash completed: {}",
  383. extract_bash_stdout(&tool_output)
  384. )),
  385. None => tool_use_sse(
  386. "toolu_bash_stdout",
  387. "bash",
  388. &[r#"{"command":"printf 'alpha from bash'","timeout":1000}"#],
  389. ),
  390. },
  391. Scenario::BashPermissionPromptApproved => match latest_tool_result(request) {
  392. Some((tool_output, is_error)) => {
  393. if is_error {
  394. final_text_sse(&format!("bash approval unexpectedly failed: {tool_output}"))
  395. } else {
  396. final_text_sse(&format!(
  397. "bash approved and executed: {}",
  398. extract_bash_stdout(&tool_output)
  399. ))
  400. }
  401. }
  402. None => tool_use_sse(
  403. "toolu_bash_prompt_allow",
  404. "bash",
  405. &[r#"{"command":"printf 'approved via prompt'","timeout":1000}"#],
  406. ),
  407. },
  408. Scenario::BashPermissionPromptDenied => match latest_tool_result(request) {
  409. Some((tool_output, _)) => {
  410. final_text_sse(&format!("bash denied as expected: {tool_output}"))
  411. }
  412. None => tool_use_sse(
  413. "toolu_bash_prompt_deny",
  414. "bash",
  415. &[r#"{"command":"printf 'should not run'","timeout":1000}"#],
  416. ),
  417. },
  418. Scenario::PluginToolRoundtrip => match latest_tool_result(request) {
  419. Some((tool_output, _)) => final_text_sse(&format!(
  420. "plugin tool completed: {}",
  421. extract_plugin_message(&tool_output)
  422. )),
  423. None => tool_use_sse(
  424. "toolu_plugin_echo",
  425. "plugin_echo",
  426. &[r#"{"message":"hello from plugin parity"}"#],
  427. ),
  428. },
  429. Scenario::AutoCompactTriggered => {
  430. final_text_sse_with_usage("auto compact parity complete.", 50_000, 200)
  431. }
  432. Scenario::TokenCostReporting => {
  433. final_text_sse_with_usage("token cost reporting parity complete.", 1_000, 500)
  434. }
  435. }
  436. }
  437. #[allow(clippy::too_many_lines)]
  438. fn build_message_response(request: &MessageRequest, scenario: Scenario) -> MessageResponse {
  439. match scenario {
  440. Scenario::StreamingText => text_message_response(
  441. "msg_streaming_text",
  442. "Mock streaming says hello from the parity harness.",
  443. ),
  444. Scenario::ReadFileRoundtrip => match latest_tool_result(request) {
  445. Some((tool_output, _)) => text_message_response(
  446. "msg_read_file_final",
  447. &format!(
  448. "read_file roundtrip complete: {}",
  449. extract_read_content(&tool_output)
  450. ),
  451. ),
  452. None => tool_message_response(
  453. "msg_read_file_tool",
  454. "toolu_read_fixture",
  455. "read_file",
  456. json!({"path": "fixture.txt"}),
  457. ),
  458. },
  459. Scenario::GrepChunkAssembly => match latest_tool_result(request) {
  460. Some((tool_output, _)) => text_message_response(
  461. "msg_grep_final",
  462. &format!(
  463. "grep_search matched {} occurrences",
  464. extract_num_matches(&tool_output)
  465. ),
  466. ),
  467. None => tool_message_response(
  468. "msg_grep_tool",
  469. "toolu_grep_fixture",
  470. "grep_search",
  471. json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}),
  472. ),
  473. },
  474. Scenario::WriteFileAllowed => match latest_tool_result(request) {
  475. Some((tool_output, _)) => text_message_response(
  476. "msg_write_allowed_final",
  477. &format!("write_file succeeded: {}", extract_file_path(&tool_output)),
  478. ),
  479. None => tool_message_response(
  480. "msg_write_allowed_tool",
  481. "toolu_write_allowed",
  482. "write_file",
  483. json!({"path": "generated/output.txt", "content": "created by mock service\n"}),
  484. ),
  485. },
  486. Scenario::WriteFileDenied => match latest_tool_result(request) {
  487. Some((tool_output, _)) => text_message_response(
  488. "msg_write_denied_final",
  489. &format!("write_file denied as expected: {tool_output}"),
  490. ),
  491. None => tool_message_response(
  492. "msg_write_denied_tool",
  493. "toolu_write_denied",
  494. "write_file",
  495. json!({"path": "generated/denied.txt", "content": "should not exist\n"}),
  496. ),
  497. },
  498. Scenario::MultiToolTurnRoundtrip => {
  499. let tool_results = tool_results_by_name(request);
  500. match (
  501. tool_results.get("read_file"),
  502. tool_results.get("grep_search"),
  503. ) {
  504. (Some((read_output, _)), Some((grep_output, _))) => text_message_response(
  505. "msg_multi_tool_final",
  506. &format!(
  507. "multi-tool roundtrip complete: {} / {} occurrences",
  508. extract_read_content(read_output),
  509. extract_num_matches(grep_output)
  510. ),
  511. ),
  512. _ => tool_message_response_many(
  513. "msg_multi_tool_start",
  514. &[
  515. ToolUseMessage {
  516. tool_id: "toolu_multi_read",
  517. tool_name: "read_file",
  518. input: json!({"path": "fixture.txt"}),
  519. },
  520. ToolUseMessage {
  521. tool_id: "toolu_multi_grep",
  522. tool_name: "grep_search",
  523. input: json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}),
  524. },
  525. ],
  526. ),
  527. }
  528. }
  529. Scenario::BashStdoutRoundtrip => match latest_tool_result(request) {
  530. Some((tool_output, _)) => text_message_response(
  531. "msg_bash_stdout_final",
  532. &format!("bash completed: {}", extract_bash_stdout(&tool_output)),
  533. ),
  534. None => tool_message_response(
  535. "msg_bash_stdout_tool",
  536. "toolu_bash_stdout",
  537. "bash",
  538. json!({"command": "printf 'alpha from bash'", "timeout": 1000}),
  539. ),
  540. },
  541. Scenario::BashPermissionPromptApproved => match latest_tool_result(request) {
  542. Some((tool_output, is_error)) => {
  543. if is_error {
  544. text_message_response(
  545. "msg_bash_prompt_allow_error",
  546. &format!("bash approval unexpectedly failed: {tool_output}"),
  547. )
  548. } else {
  549. text_message_response(
  550. "msg_bash_prompt_allow_final",
  551. &format!(
  552. "bash approved and executed: {}",
  553. extract_bash_stdout(&tool_output)
  554. ),
  555. )
  556. }
  557. }
  558. None => tool_message_response(
  559. "msg_bash_prompt_allow_tool",
  560. "toolu_bash_prompt_allow",
  561. "bash",
  562. json!({"command": "printf 'approved via prompt'", "timeout": 1000}),
  563. ),
  564. },
  565. Scenario::BashPermissionPromptDenied => match latest_tool_result(request) {
  566. Some((tool_output, _)) => text_message_response(
  567. "msg_bash_prompt_deny_final",
  568. &format!("bash denied as expected: {tool_output}"),
  569. ),
  570. None => tool_message_response(
  571. "msg_bash_prompt_deny_tool",
  572. "toolu_bash_prompt_deny",
  573. "bash",
  574. json!({"command": "printf 'should not run'", "timeout": 1000}),
  575. ),
  576. },
  577. Scenario::PluginToolRoundtrip => match latest_tool_result(request) {
  578. Some((tool_output, _)) => text_message_response(
  579. "msg_plugin_tool_final",
  580. &format!(
  581. "plugin tool completed: {}",
  582. extract_plugin_message(&tool_output)
  583. ),
  584. ),
  585. None => tool_message_response(
  586. "msg_plugin_tool_start",
  587. "toolu_plugin_echo",
  588. "plugin_echo",
  589. json!({"message": "hello from plugin parity"}),
  590. ),
  591. },
  592. Scenario::AutoCompactTriggered => text_message_response_with_usage(
  593. "msg_auto_compact_triggered",
  594. "auto compact parity complete.",
  595. 50_000,
  596. 200,
  597. ),
  598. Scenario::TokenCostReporting => text_message_response_with_usage(
  599. "msg_token_cost_reporting",
  600. "token cost reporting parity complete.",
  601. 1_000,
  602. 500,
  603. ),
  604. }
  605. }
  606. fn request_id_for(scenario: Scenario) -> &'static str {
  607. match scenario {
  608. Scenario::StreamingText => "req_streaming_text",
  609. Scenario::ReadFileRoundtrip => "req_read_file_roundtrip",
  610. Scenario::GrepChunkAssembly => "req_grep_chunk_assembly",
  611. Scenario::WriteFileAllowed => "req_write_file_allowed",
  612. Scenario::WriteFileDenied => "req_write_file_denied",
  613. Scenario::MultiToolTurnRoundtrip => "req_multi_tool_turn_roundtrip",
  614. Scenario::BashStdoutRoundtrip => "req_bash_stdout_roundtrip",
  615. Scenario::BashPermissionPromptApproved => "req_bash_permission_prompt_approved",
  616. Scenario::BashPermissionPromptDenied => "req_bash_permission_prompt_denied",
  617. Scenario::PluginToolRoundtrip => "req_plugin_tool_roundtrip",
  618. Scenario::AutoCompactTriggered => "req_auto_compact_triggered",
  619. Scenario::TokenCostReporting => "req_token_cost_reporting",
  620. }
  621. }
  /// Serializes a complete HTTP/1.1 response.
  ///
  /// Header order is: `content-type`, the extra `headers` in the given order, `content-length`
  /// (the body's byte length), then `connection: close`, followed by the body.
  fn http_response(status: &str, content_type: &str, body: &str, headers: &[(&str, &str)]) -> String {
      let mut response = format!("HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n");
      for &(name, value) in headers {
          response.push_str(name);
          response.push_str(": ");
          response.push_str(value);
          response.push_str("\r\n");
      }
      response += &format!(
          "content-length: {}\r\nconnection: close\r\n\r\n",
          body.len()
      );
      response.push_str(body);
      response
  }
  633. fn text_message_response(id: &str, text: &str) -> MessageResponse {
  634. MessageResponse {
  635. id: id.to_string(),
  636. kind: "message".to_string(),
  637. role: "assistant".to_string(),
  638. content: vec![OutputContentBlock::Text {
  639. text: text.to_string(),
  640. }],
  641. model: DEFAULT_MODEL.to_string(),
  642. stop_reason: Some("end_turn".to_string()),
  643. stop_sequence: None,
  644. usage: Usage {
  645. input_tokens: 10,
  646. cache_creation_input_tokens: 0,
  647. cache_read_input_tokens: 0,
  648. output_tokens: 6,
  649. },
  650. request_id: None,
  651. }
  652. }
  653. fn text_message_response_with_usage(
  654. id: &str,
  655. text: &str,
  656. input_tokens: u32,
  657. output_tokens: u32,
  658. ) -> MessageResponse {
  659. MessageResponse {
  660. id: id.to_string(),
  661. kind: "message".to_string(),
  662. role: "assistant".to_string(),
  663. content: vec![OutputContentBlock::Text {
  664. text: text.to_string(),
  665. }],
  666. model: DEFAULT_MODEL.to_string(),
  667. stop_reason: Some("end_turn".to_string()),
  668. stop_sequence: None,
  669. usage: Usage {
  670. input_tokens,
  671. cache_creation_input_tokens: 0,
  672. cache_read_input_tokens: 0,
  673. output_tokens,
  674. },
  675. request_id: None,
  676. }
  677. }
  678. fn tool_message_response(
  679. id: &str,
  680. tool_id: &str,
  681. tool_name: &str,
  682. input: Value,
  683. ) -> MessageResponse {
  684. tool_message_response_many(
  685. id,
  686. &[ToolUseMessage {
  687. tool_id,
  688. tool_name,
  689. input,
  690. }],
  691. )
  692. }
  693. struct ToolUseMessage<'a> {
  694. tool_id: &'a str,
  695. tool_name: &'a str,
  696. input: Value,
  697. }
  698. fn tool_message_response_many(id: &str, tool_uses: &[ToolUseMessage<'_>]) -> MessageResponse {
  699. MessageResponse {
  700. id: id.to_string(),
  701. kind: "message".to_string(),
  702. role: "assistant".to_string(),
  703. content: tool_uses
  704. .iter()
  705. .map(|tool_use| OutputContentBlock::ToolUse {
  706. id: tool_use.tool_id.to_string(),
  707. name: tool_use.tool_name.to_string(),
  708. input: tool_use.input.clone(),
  709. })
  710. .collect(),
  711. model: DEFAULT_MODEL.to_string(),
  712. stop_reason: Some("tool_use".to_string()),
  713. stop_sequence: None,
  714. usage: Usage {
  715. input_tokens: 10,
  716. cache_creation_input_tokens: 0,
  717. cache_read_input_tokens: 0,
  718. output_tokens: 3,
  719. },
  720. request_id: None,
  721. }
  722. }
  723. fn streaming_text_sse() -> String {
  724. let mut body = String::new();
  725. append_sse(
  726. &mut body,
  727. "message_start",
  728. json!({
  729. "type": "message_start",
  730. "message": {
  731. "id": "msg_streaming_text",
  732. "type": "message",
  733. "role": "assistant",
  734. "content": [],
  735. "model": DEFAULT_MODEL,
  736. "stop_reason": null,
  737. "stop_sequence": null,
  738. "usage": usage_json(11, 0)
  739. }
  740. }),
  741. );
  742. append_sse(
  743. &mut body,
  744. "content_block_start",
  745. json!({
  746. "type": "content_block_start",
  747. "index": 0,
  748. "content_block": {"type": "text", "text": ""}
  749. }),
  750. );
  751. append_sse(
  752. &mut body,
  753. "content_block_delta",
  754. json!({
  755. "type": "content_block_delta",
  756. "index": 0,
  757. "delta": {"type": "text_delta", "text": "Mock streaming "}
  758. }),
  759. );
  760. append_sse(
  761. &mut body,
  762. "content_block_delta",
  763. json!({
  764. "type": "content_block_delta",
  765. "index": 0,
  766. "delta": {"type": "text_delta", "text": "says hello from the parity harness."}
  767. }),
  768. );
  769. append_sse(
  770. &mut body,
  771. "content_block_stop",
  772. json!({
  773. "type": "content_block_stop",
  774. "index": 0
  775. }),
  776. );
  777. append_sse(
  778. &mut body,
  779. "message_delta",
  780. json!({
  781. "type": "message_delta",
  782. "delta": {"stop_reason": "end_turn", "stop_sequence": null},
  783. "usage": usage_json(11, 8)
  784. }),
  785. );
  786. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  787. body
  788. }
  789. fn tool_use_sse(tool_id: &str, tool_name: &str, partial_json_chunks: &[&str]) -> String {
  790. tool_uses_sse(&[ToolUseSse {
  791. tool_id,
  792. tool_name,
  793. partial_json_chunks,
  794. }])
  795. }
  796. struct ToolUseSse<'a> {
  797. tool_id: &'a str,
  798. tool_name: &'a str,
  799. partial_json_chunks: &'a [&'a str],
  800. }
  801. fn tool_uses_sse(tool_uses: &[ToolUseSse<'_>]) -> String {
  802. let mut body = String::new();
  803. let message_id = tool_uses.first().map_or_else(
  804. || "msg_tool_use".to_string(),
  805. |tool_use| format!("msg_{}", tool_use.tool_id),
  806. );
  807. append_sse(
  808. &mut body,
  809. "message_start",
  810. json!({
  811. "type": "message_start",
  812. "message": {
  813. "id": message_id,
  814. "type": "message",
  815. "role": "assistant",
  816. "content": [],
  817. "model": DEFAULT_MODEL,
  818. "stop_reason": null,
  819. "stop_sequence": null,
  820. "usage": usage_json(12, 0)
  821. }
  822. }),
  823. );
  824. for (index, tool_use) in tool_uses.iter().enumerate() {
  825. append_sse(
  826. &mut body,
  827. "content_block_start",
  828. json!({
  829. "type": "content_block_start",
  830. "index": index,
  831. "content_block": {
  832. "type": "tool_use",
  833. "id": tool_use.tool_id,
  834. "name": tool_use.tool_name,
  835. "input": {}
  836. }
  837. }),
  838. );
  839. for chunk in tool_use.partial_json_chunks {
  840. append_sse(
  841. &mut body,
  842. "content_block_delta",
  843. json!({
  844. "type": "content_block_delta",
  845. "index": index,
  846. "delta": {"type": "input_json_delta", "partial_json": chunk}
  847. }),
  848. );
  849. }
  850. append_sse(
  851. &mut body,
  852. "content_block_stop",
  853. json!({
  854. "type": "content_block_stop",
  855. "index": index
  856. }),
  857. );
  858. }
  859. append_sse(
  860. &mut body,
  861. "message_delta",
  862. json!({
  863. "type": "message_delta",
  864. "delta": {"stop_reason": "tool_use", "stop_sequence": null},
  865. "usage": usage_json(12, 4)
  866. }),
  867. );
  868. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  869. body
  870. }
  871. fn final_text_sse(text: &str) -> String {
  872. let mut body = String::new();
  873. append_sse(
  874. &mut body,
  875. "message_start",
  876. json!({
  877. "type": "message_start",
  878. "message": {
  879. "id": unique_message_id(),
  880. "type": "message",
  881. "role": "assistant",
  882. "content": [],
  883. "model": DEFAULT_MODEL,
  884. "stop_reason": null,
  885. "stop_sequence": null,
  886. "usage": usage_json(14, 0)
  887. }
  888. }),
  889. );
  890. append_sse(
  891. &mut body,
  892. "content_block_start",
  893. json!({
  894. "type": "content_block_start",
  895. "index": 0,
  896. "content_block": {"type": "text", "text": ""}
  897. }),
  898. );
  899. append_sse(
  900. &mut body,
  901. "content_block_delta",
  902. json!({
  903. "type": "content_block_delta",
  904. "index": 0,
  905. "delta": {"type": "text_delta", "text": text}
  906. }),
  907. );
  908. append_sse(
  909. &mut body,
  910. "content_block_stop",
  911. json!({
  912. "type": "content_block_stop",
  913. "index": 0
  914. }),
  915. );
  916. append_sse(
  917. &mut body,
  918. "message_delta",
  919. json!({
  920. "type": "message_delta",
  921. "delta": {"stop_reason": "end_turn", "stop_sequence": null},
  922. "usage": usage_json(14, 7)
  923. }),
  924. );
  925. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  926. body
  927. }
  928. fn final_text_sse_with_usage(text: &str, input_tokens: u32, output_tokens: u32) -> String {
  929. let mut body = String::new();
  930. append_sse(
  931. &mut body,
  932. "message_start",
  933. json!({
  934. "type": "message_start",
  935. "message": {
  936. "id": unique_message_id(),
  937. "type": "message",
  938. "role": "assistant",
  939. "content": [],
  940. "model": DEFAULT_MODEL,
  941. "stop_reason": null,
  942. "stop_sequence": null,
  943. "usage": {
  944. "input_tokens": input_tokens,
  945. "cache_creation_input_tokens": 0,
  946. "cache_read_input_tokens": 0,
  947. "output_tokens": 0
  948. }
  949. }
  950. }),
  951. );
  952. append_sse(
  953. &mut body,
  954. "content_block_start",
  955. json!({
  956. "type": "content_block_start",
  957. "index": 0,
  958. "content_block": {"type": "text", "text": ""}
  959. }),
  960. );
  961. append_sse(
  962. &mut body,
  963. "content_block_delta",
  964. json!({
  965. "type": "content_block_delta",
  966. "index": 0,
  967. "delta": {"type": "text_delta", "text": text}
  968. }),
  969. );
  970. append_sse(
  971. &mut body,
  972. "content_block_stop",
  973. json!({
  974. "type": "content_block_stop",
  975. "index": 0
  976. }),
  977. );
  978. append_sse(
  979. &mut body,
  980. "message_delta",
  981. json!({
  982. "type": "message_delta",
  983. "delta": {"stop_reason": "end_turn", "stop_sequence": null},
  984. "usage": {
  985. "input_tokens": input_tokens,
  986. "cache_creation_input_tokens": 0,
  987. "cache_read_input_tokens": 0,
  988. "output_tokens": output_tokens
  989. }
  990. }),
  991. );
  992. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  993. body
  994. }
  995. #[allow(clippy::needless_pass_by_value)]
  996. fn append_sse(buffer: &mut String, event: &str, payload: Value) {
  997. use std::fmt::Write as _;
  998. writeln!(buffer, "event: {event}").expect("event write should succeed");
  999. writeln!(buffer, "data: {payload}").expect("payload write should succeed");
  1000. buffer.push('\n');
  1001. }
  1002. fn usage_json(input_tokens: u32, output_tokens: u32) -> Value {
  1003. json!({
  1004. "input_tokens": input_tokens,
  1005. "cache_creation_input_tokens": 0,
  1006. "cache_read_input_tokens": 0,
  1007. "output_tokens": output_tokens
  1008. })
  1009. }
  /// Returns a message id of the form `msg_<unix-nanos>_<sequence>` that is unique within this
  /// process.
  ///
  /// The wall-clock part alone is not enough: `SystemTime` resolution is platform dependent
  /// (coarser than a nanosecond on some targets), so back-to-back calls could produce the same
  /// timestamp. A process-wide counter disambiguates them.
  ///
  /// # Panics
  /// Panics if the system clock reports a time before the Unix epoch.
  fn unique_message_id() -> String {
      static NEXT_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
      // `fetch_add` is atomic regardless of ordering, so `Relaxed` already guarantees each
      // caller gets a distinct value; no other memory is synchronized through this counter.
      let sequence = NEXT_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
      let nanos = SystemTime::now()
          .duration_since(UNIX_EPOCH)
          .expect("clock should be after epoch")
          .as_nanos();
      format!("msg_{nanos}_{sequence}")
  }
  1017. fn extract_read_content(tool_output: &str) -> String {
  1018. serde_json::from_str::<Value>(tool_output)
  1019. .ok()
  1020. .and_then(|value| {
  1021. value
  1022. .get("file")
  1023. .and_then(|file| file.get("content"))
  1024. .and_then(Value::as_str)
  1025. .map(ToOwned::to_owned)
  1026. })
  1027. .unwrap_or_else(|| tool_output.trim().to_string())
  1028. }
  1029. #[allow(clippy::cast_possible_truncation)]
  1030. fn extract_num_matches(tool_output: &str) -> usize {
  1031. serde_json::from_str::<Value>(tool_output)
  1032. .ok()
  1033. .and_then(|value| value.get("numMatches").and_then(Value::as_u64))
  1034. .unwrap_or(0) as usize
  1035. }
  1036. fn extract_file_path(tool_output: &str) -> String {
  1037. serde_json::from_str::<Value>(tool_output)
  1038. .ok()
  1039. .and_then(|value| {
  1040. value
  1041. .get("filePath")
  1042. .and_then(Value::as_str)
  1043. .map(ToOwned::to_owned)
  1044. })
  1045. .unwrap_or_else(|| tool_output.trim().to_string())
  1046. }
  1047. fn extract_bash_stdout(tool_output: &str) -> String {
  1048. serde_json::from_str::<Value>(tool_output)
  1049. .ok()
  1050. .and_then(|value| {
  1051. value
  1052. .get("stdout")
  1053. .and_then(Value::as_str)
  1054. .map(ToOwned::to_owned)
  1055. })
  1056. .unwrap_or_else(|| tool_output.trim().to_string())
  1057. }
  1058. fn extract_plugin_message(tool_output: &str) -> String {
  1059. serde_json::from_str::<Value>(tool_output)
  1060. .ok()
  1061. .and_then(|value| {
  1062. value
  1063. .get("input")
  1064. .and_then(|input| input.get("message"))
  1065. .and_then(Value::as_str)
  1066. .map(ToOwned::to_owned)
  1067. })
  1068. .unwrap_or_else(|| tool_output.trim().to_string())
  1069. }