lib.rs 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003
  1. use std::collections::HashMap;
  2. use std::io;
  3. use std::sync::Arc;
  4. use std::time::{SystemTime, UNIX_EPOCH};
  5. use api::{InputContentBlock, MessageRequest, MessageResponse, OutputContentBlock, Usage};
  6. use serde_json::{json, Value};
  7. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  8. use tokio::net::TcpListener;
  9. use tokio::sync::{oneshot, Mutex};
  10. use tokio::task::JoinHandle;
/// Prefix of the whitespace-delimited token, inside a text content block of the request,
/// that selects the mock scenario (for example `PARITY_SCENARIO:streaming_text`).
pub const SCENARIO_PREFIX: &str = "PARITY_SCENARIO:";
/// Model name reported in every response the mock produces.
pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6";
/// One request received by the mock service, recorded so callers can inspect it through
/// `MockAnthropicService::captured_requests`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapturedRequest {
    /// HTTP method taken from the request line.
    pub method: String,
    /// Request target taken from the request line.
    pub path: String,
    /// Header fields keyed by lower-cased name; a repeated header keeps its last value.
    pub headers: HashMap<String, String>,
    /// Wire name of the parity scenario detected in the request body.
    pub scenario: String,
    /// The request's `stream` flag.
    pub stream: bool,
    /// The request body exactly as received, decoded as UTF-8.
    pub raw_body: String,
}
  22. pub struct MockAnthropicService {
  23. base_url: String,
  24. requests: Arc<Mutex<Vec<CapturedRequest>>>,
  25. shutdown: Option<oneshot::Sender<()>>,
  26. join_handle: JoinHandle<()>,
  27. }
  28. impl MockAnthropicService {
  29. pub async fn spawn() -> io::Result<Self> {
  30. Self::spawn_on("127.0.0.1:0").await
  31. }
  32. pub async fn spawn_on(bind_addr: &str) -> io::Result<Self> {
  33. let listener = TcpListener::bind(bind_addr).await?;
  34. let address = listener.local_addr()?;
  35. let requests = Arc::new(Mutex::new(Vec::new()));
  36. let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
  37. let request_state = Arc::clone(&requests);
  38. let join_handle = tokio::spawn(async move {
  39. loop {
  40. tokio::select! {
  41. _ = &mut shutdown_rx => break,
  42. accepted = listener.accept() => {
  43. let Ok((socket, _)) = accepted else {
  44. break;
  45. };
  46. let request_state = Arc::clone(&request_state);
  47. tokio::spawn(async move {
  48. let _ = handle_connection(socket, request_state).await;
  49. });
  50. }
  51. }
  52. }
  53. });
  54. Ok(Self {
  55. base_url: format!("http://{address}"),
  56. requests,
  57. shutdown: Some(shutdown_tx),
  58. join_handle,
  59. })
  60. }
  61. #[must_use]
  62. pub fn base_url(&self) -> String {
  63. self.base_url.clone()
  64. }
  65. pub async fn captured_requests(&self) -> Vec<CapturedRequest> {
  66. self.requests.lock().await.clone()
  67. }
  68. }
  69. impl Drop for MockAnthropicService {
  70. fn drop(&mut self) {
  71. if let Some(shutdown) = self.shutdown.take() {
  72. let _ = shutdown.send(());
  73. }
  74. self.join_handle.abort();
  75. }
  76. }
  77. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  78. enum Scenario {
  79. StreamingText,
  80. ReadFileRoundtrip,
  81. GrepChunkAssembly,
  82. WriteFileAllowed,
  83. WriteFileDenied,
  84. MultiToolTurnRoundtrip,
  85. BashStdoutRoundtrip,
  86. BashPermissionPromptApproved,
  87. BashPermissionPromptDenied,
  88. PluginToolRoundtrip,
  89. }
  90. impl Scenario {
  91. fn parse(value: &str) -> Option<Self> {
  92. match value.trim() {
  93. "streaming_text" => Some(Self::StreamingText),
  94. "read_file_roundtrip" => Some(Self::ReadFileRoundtrip),
  95. "grep_chunk_assembly" => Some(Self::GrepChunkAssembly),
  96. "write_file_allowed" => Some(Self::WriteFileAllowed),
  97. "write_file_denied" => Some(Self::WriteFileDenied),
  98. "multi_tool_turn_roundtrip" => Some(Self::MultiToolTurnRoundtrip),
  99. "bash_stdout_roundtrip" => Some(Self::BashStdoutRoundtrip),
  100. "bash_permission_prompt_approved" => Some(Self::BashPermissionPromptApproved),
  101. "bash_permission_prompt_denied" => Some(Self::BashPermissionPromptDenied),
  102. "plugin_tool_roundtrip" => Some(Self::PluginToolRoundtrip),
  103. _ => None,
  104. }
  105. }
  106. fn name(self) -> &'static str {
  107. match self {
  108. Self::StreamingText => "streaming_text",
  109. Self::ReadFileRoundtrip => "read_file_roundtrip",
  110. Self::GrepChunkAssembly => "grep_chunk_assembly",
  111. Self::WriteFileAllowed => "write_file_allowed",
  112. Self::WriteFileDenied => "write_file_denied",
  113. Self::MultiToolTurnRoundtrip => "multi_tool_turn_roundtrip",
  114. Self::BashStdoutRoundtrip => "bash_stdout_roundtrip",
  115. Self::BashPermissionPromptApproved => "bash_permission_prompt_approved",
  116. Self::BashPermissionPromptDenied => "bash_permission_prompt_denied",
  117. Self::PluginToolRoundtrip => "plugin_tool_roundtrip",
  118. }
  119. }
  120. }
  121. async fn handle_connection(
  122. mut socket: tokio::net::TcpStream,
  123. requests: Arc<Mutex<Vec<CapturedRequest>>>,
  124. ) -> io::Result<()> {
  125. let (method, path, headers, raw_body) = read_http_request(&mut socket).await?;
  126. let request: MessageRequest = serde_json::from_str(&raw_body)
  127. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  128. let scenario = detect_scenario(&request)
  129. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing parity scenario"))?;
  130. requests.lock().await.push(CapturedRequest {
  131. method,
  132. path,
  133. headers,
  134. scenario: scenario.name().to_string(),
  135. stream: request.stream,
  136. raw_body,
  137. });
  138. let response = build_http_response(&request, scenario);
  139. socket.write_all(response.as_bytes()).await?;
  140. Ok(())
  141. }
  142. async fn read_http_request(
  143. socket: &mut tokio::net::TcpStream,
  144. ) -> io::Result<(String, String, HashMap<String, String>, String)> {
  145. let mut buffer = Vec::new();
  146. let mut header_end = None;
  147. loop {
  148. let mut chunk = [0_u8; 1024];
  149. let read = socket.read(&mut chunk).await?;
  150. if read == 0 {
  151. break;
  152. }
  153. buffer.extend_from_slice(&chunk[..read]);
  154. if let Some(position) = find_header_end(&buffer) {
  155. header_end = Some(position);
  156. break;
  157. }
  158. }
  159. let header_end = header_end
  160. .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "missing http headers"))?;
  161. let (header_bytes, remaining) = buffer.split_at(header_end);
  162. let header_text = String::from_utf8(header_bytes.to_vec())
  163. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  164. let mut lines = header_text.split("\r\n");
  165. let request_line = lines
  166. .next()
  167. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing request line"))?;
  168. let mut request_parts = request_line.split_whitespace();
  169. let method = request_parts
  170. .next()
  171. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing method"))?
  172. .to_string();
  173. let path = request_parts
  174. .next()
  175. .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing path"))?
  176. .to_string();
  177. let mut headers = HashMap::new();
  178. let mut content_length = 0_usize;
  179. for line in lines {
  180. if line.is_empty() {
  181. continue;
  182. }
  183. let (name, value) = line.split_once(':').ok_or_else(|| {
  184. io::Error::new(io::ErrorKind::InvalidData, "malformed http header line")
  185. })?;
  186. let value = value.trim().to_string();
  187. if name.eq_ignore_ascii_case("content-length") {
  188. content_length = value.parse().map_err(|error| {
  189. io::Error::new(
  190. io::ErrorKind::InvalidData,
  191. format!("invalid content-length: {error}"),
  192. )
  193. })?;
  194. }
  195. headers.insert(name.to_ascii_lowercase(), value);
  196. }
  197. let mut body = remaining[4..].to_vec();
  198. while body.len() < content_length {
  199. let mut chunk = vec![0_u8; content_length - body.len()];
  200. let read = socket.read(&mut chunk).await?;
  201. if read == 0 {
  202. break;
  203. }
  204. body.extend_from_slice(&chunk[..read]);
  205. }
  206. let body = String::from_utf8(body)
  207. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?;
  208. Ok((method, path, headers, body))
  209. }
  210. fn find_header_end(bytes: &[u8]) -> Option<usize> {
  211. bytes.windows(4).position(|window| window == b"\r\n\r\n")
  212. }
  213. fn detect_scenario(request: &MessageRequest) -> Option<Scenario> {
  214. request.messages.iter().rev().find_map(|message| {
  215. message.content.iter().rev().find_map(|block| match block {
  216. InputContentBlock::Text { text } => text
  217. .split_whitespace()
  218. .find_map(|token| token.strip_prefix(SCENARIO_PREFIX))
  219. .and_then(Scenario::parse),
  220. _ => None,
  221. })
  222. })
  223. }
  224. fn latest_tool_result(request: &MessageRequest) -> Option<(String, bool)> {
  225. request.messages.iter().rev().find_map(|message| {
  226. message.content.iter().rev().find_map(|block| match block {
  227. InputContentBlock::ToolResult {
  228. content, is_error, ..
  229. } => Some((flatten_tool_result_content(content), *is_error)),
  230. _ => None,
  231. })
  232. })
  233. }
  234. fn tool_results_by_name(request: &MessageRequest) -> HashMap<String, (String, bool)> {
  235. let mut tool_names_by_id = HashMap::new();
  236. for message in &request.messages {
  237. for block in &message.content {
  238. if let InputContentBlock::ToolUse { id, name, .. } = block {
  239. tool_names_by_id.insert(id.clone(), name.clone());
  240. }
  241. }
  242. }
  243. let mut results = HashMap::new();
  244. for message in request.messages.iter().rev() {
  245. for block in message.content.iter().rev() {
  246. if let InputContentBlock::ToolResult {
  247. tool_use_id,
  248. content,
  249. is_error,
  250. } = block
  251. {
  252. let tool_name = tool_names_by_id
  253. .get(tool_use_id)
  254. .cloned()
  255. .unwrap_or_else(|| tool_use_id.clone());
  256. results
  257. .entry(tool_name)
  258. .or_insert_with(|| (flatten_tool_result_content(content), *is_error));
  259. }
  260. }
  261. }
  262. results
  263. }
  264. fn flatten_tool_result_content(content: &[api::ToolResultContentBlock]) -> String {
  265. content
  266. .iter()
  267. .map(|block| match block {
  268. api::ToolResultContentBlock::Text { text } => text.clone(),
  269. api::ToolResultContentBlock::Json { value } => value.to_string(),
  270. })
  271. .collect::<Vec<_>>()
  272. .join("\n")
  273. }
  274. #[allow(clippy::too_many_lines)]
  275. fn build_http_response(request: &MessageRequest, scenario: Scenario) -> String {
  276. let response = if request.stream {
  277. let body = build_stream_body(request, scenario);
  278. return http_response(
  279. "200 OK",
  280. "text/event-stream",
  281. &body,
  282. &[("x-request-id", request_id_for(scenario))],
  283. );
  284. } else {
  285. build_message_response(request, scenario)
  286. };
  287. http_response(
  288. "200 OK",
  289. "application/json",
  290. &serde_json::to_string(&response).expect("message response should serialize"),
  291. &[("request-id", request_id_for(scenario))],
  292. )
  293. }
  294. #[allow(clippy::too_many_lines)]
  295. fn build_stream_body(request: &MessageRequest, scenario: Scenario) -> String {
  296. match scenario {
  297. Scenario::StreamingText => streaming_text_sse(),
  298. Scenario::ReadFileRoundtrip => match latest_tool_result(request) {
  299. Some((tool_output, _)) => final_text_sse(&format!(
  300. "read_file roundtrip complete: {}",
  301. extract_read_content(&tool_output)
  302. )),
  303. None => tool_use_sse(
  304. "toolu_read_fixture",
  305. "read_file",
  306. &[r#"{"path":"fixture.txt"}"#],
  307. ),
  308. },
  309. Scenario::GrepChunkAssembly => match latest_tool_result(request) {
  310. Some((tool_output, _)) => final_text_sse(&format!(
  311. "grep_search matched {} occurrences",
  312. extract_num_matches(&tool_output)
  313. )),
  314. None => tool_use_sse(
  315. "toolu_grep_fixture",
  316. "grep_search",
  317. &[
  318. "{\"pattern\":\"par",
  319. "ity\",\"path\":\"fixture.txt\"",
  320. ",\"output_mode\":\"count\"}",
  321. ],
  322. ),
  323. },
  324. Scenario::WriteFileAllowed => match latest_tool_result(request) {
  325. Some((tool_output, _)) => final_text_sse(&format!(
  326. "write_file succeeded: {}",
  327. extract_file_path(&tool_output)
  328. )),
  329. None => tool_use_sse(
  330. "toolu_write_allowed",
  331. "write_file",
  332. &[r#"{"path":"generated/output.txt","content":"created by mock service\n"}"#],
  333. ),
  334. },
  335. Scenario::WriteFileDenied => match latest_tool_result(request) {
  336. Some((tool_output, _)) => {
  337. final_text_sse(&format!("write_file denied as expected: {tool_output}"))
  338. }
  339. None => tool_use_sse(
  340. "toolu_write_denied",
  341. "write_file",
  342. &[r#"{"path":"generated/denied.txt","content":"should not exist\n"}"#],
  343. ),
  344. },
  345. Scenario::MultiToolTurnRoundtrip => {
  346. let tool_results = tool_results_by_name(request);
  347. match (
  348. tool_results.get("read_file"),
  349. tool_results.get("grep_search"),
  350. ) {
  351. (Some((read_output, _)), Some((grep_output, _))) => final_text_sse(&format!(
  352. "multi-tool roundtrip complete: {} / {} occurrences",
  353. extract_read_content(read_output),
  354. extract_num_matches(grep_output)
  355. )),
  356. _ => tool_uses_sse(&[
  357. ToolUseSse {
  358. tool_id: "toolu_multi_read",
  359. tool_name: "read_file",
  360. partial_json_chunks: &[r#"{"path":"fixture.txt"}"#],
  361. },
  362. ToolUseSse {
  363. tool_id: "toolu_multi_grep",
  364. tool_name: "grep_search",
  365. partial_json_chunks: &[
  366. "{\"pattern\":\"par",
  367. "ity\",\"path\":\"fixture.txt\"",
  368. ",\"output_mode\":\"count\"}",
  369. ],
  370. },
  371. ]),
  372. }
  373. }
  374. Scenario::BashStdoutRoundtrip => match latest_tool_result(request) {
  375. Some((tool_output, _)) => final_text_sse(&format!(
  376. "bash completed: {}",
  377. extract_bash_stdout(&tool_output)
  378. )),
  379. None => tool_use_sse(
  380. "toolu_bash_stdout",
  381. "bash",
  382. &[r#"{"command":"printf 'alpha from bash'","timeout":1000}"#],
  383. ),
  384. },
  385. Scenario::BashPermissionPromptApproved => match latest_tool_result(request) {
  386. Some((tool_output, is_error)) => {
  387. if is_error {
  388. final_text_sse(&format!("bash approval unexpectedly failed: {tool_output}"))
  389. } else {
  390. final_text_sse(&format!(
  391. "bash approved and executed: {}",
  392. extract_bash_stdout(&tool_output)
  393. ))
  394. }
  395. }
  396. None => tool_use_sse(
  397. "toolu_bash_prompt_allow",
  398. "bash",
  399. &[r#"{"command":"printf 'approved via prompt'","timeout":1000}"#],
  400. ),
  401. },
  402. Scenario::BashPermissionPromptDenied => match latest_tool_result(request) {
  403. Some((tool_output, _)) => {
  404. final_text_sse(&format!("bash denied as expected: {tool_output}"))
  405. }
  406. None => tool_use_sse(
  407. "toolu_bash_prompt_deny",
  408. "bash",
  409. &[r#"{"command":"printf 'should not run'","timeout":1000}"#],
  410. ),
  411. },
  412. Scenario::PluginToolRoundtrip => match latest_tool_result(request) {
  413. Some((tool_output, _)) => final_text_sse(&format!(
  414. "plugin tool completed: {}",
  415. extract_plugin_message(&tool_output)
  416. )),
  417. None => tool_use_sse(
  418. "toolu_plugin_echo",
  419. "plugin_echo",
  420. &[r#"{"message":"hello from plugin parity"}"#],
  421. ),
  422. },
  423. }
  424. }
  425. #[allow(clippy::too_many_lines)]
  426. fn build_message_response(request: &MessageRequest, scenario: Scenario) -> MessageResponse {
  427. match scenario {
  428. Scenario::StreamingText => text_message_response(
  429. "msg_streaming_text",
  430. "Mock streaming says hello from the parity harness.",
  431. ),
  432. Scenario::ReadFileRoundtrip => match latest_tool_result(request) {
  433. Some((tool_output, _)) => text_message_response(
  434. "msg_read_file_final",
  435. &format!(
  436. "read_file roundtrip complete: {}",
  437. extract_read_content(&tool_output)
  438. ),
  439. ),
  440. None => tool_message_response(
  441. "msg_read_file_tool",
  442. "toolu_read_fixture",
  443. "read_file",
  444. json!({"path": "fixture.txt"}),
  445. ),
  446. },
  447. Scenario::GrepChunkAssembly => match latest_tool_result(request) {
  448. Some((tool_output, _)) => text_message_response(
  449. "msg_grep_final",
  450. &format!(
  451. "grep_search matched {} occurrences",
  452. extract_num_matches(&tool_output)
  453. ),
  454. ),
  455. None => tool_message_response(
  456. "msg_grep_tool",
  457. "toolu_grep_fixture",
  458. "grep_search",
  459. json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}),
  460. ),
  461. },
  462. Scenario::WriteFileAllowed => match latest_tool_result(request) {
  463. Some((tool_output, _)) => text_message_response(
  464. "msg_write_allowed_final",
  465. &format!("write_file succeeded: {}", extract_file_path(&tool_output)),
  466. ),
  467. None => tool_message_response(
  468. "msg_write_allowed_tool",
  469. "toolu_write_allowed",
  470. "write_file",
  471. json!({"path": "generated/output.txt", "content": "created by mock service\n"}),
  472. ),
  473. },
  474. Scenario::WriteFileDenied => match latest_tool_result(request) {
  475. Some((tool_output, _)) => text_message_response(
  476. "msg_write_denied_final",
  477. &format!("write_file denied as expected: {tool_output}"),
  478. ),
  479. None => tool_message_response(
  480. "msg_write_denied_tool",
  481. "toolu_write_denied",
  482. "write_file",
  483. json!({"path": "generated/denied.txt", "content": "should not exist\n"}),
  484. ),
  485. },
  486. Scenario::MultiToolTurnRoundtrip => {
  487. let tool_results = tool_results_by_name(request);
  488. match (
  489. tool_results.get("read_file"),
  490. tool_results.get("grep_search"),
  491. ) {
  492. (Some((read_output, _)), Some((grep_output, _))) => text_message_response(
  493. "msg_multi_tool_final",
  494. &format!(
  495. "multi-tool roundtrip complete: {} / {} occurrences",
  496. extract_read_content(read_output),
  497. extract_num_matches(grep_output)
  498. ),
  499. ),
  500. _ => tool_message_response_many(
  501. "msg_multi_tool_start",
  502. &[
  503. ToolUseMessage {
  504. tool_id: "toolu_multi_read",
  505. tool_name: "read_file",
  506. input: json!({"path": "fixture.txt"}),
  507. },
  508. ToolUseMessage {
  509. tool_id: "toolu_multi_grep",
  510. tool_name: "grep_search",
  511. input: json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}),
  512. },
  513. ],
  514. ),
  515. }
  516. }
  517. Scenario::BashStdoutRoundtrip => match latest_tool_result(request) {
  518. Some((tool_output, _)) => text_message_response(
  519. "msg_bash_stdout_final",
  520. &format!("bash completed: {}", extract_bash_stdout(&tool_output)),
  521. ),
  522. None => tool_message_response(
  523. "msg_bash_stdout_tool",
  524. "toolu_bash_stdout",
  525. "bash",
  526. json!({"command": "printf 'alpha from bash'", "timeout": 1000}),
  527. ),
  528. },
  529. Scenario::BashPermissionPromptApproved => match latest_tool_result(request) {
  530. Some((tool_output, is_error)) => {
  531. if is_error {
  532. text_message_response(
  533. "msg_bash_prompt_allow_error",
  534. &format!("bash approval unexpectedly failed: {tool_output}"),
  535. )
  536. } else {
  537. text_message_response(
  538. "msg_bash_prompt_allow_final",
  539. &format!(
  540. "bash approved and executed: {}",
  541. extract_bash_stdout(&tool_output)
  542. ),
  543. )
  544. }
  545. }
  546. None => tool_message_response(
  547. "msg_bash_prompt_allow_tool",
  548. "toolu_bash_prompt_allow",
  549. "bash",
  550. json!({"command": "printf 'approved via prompt'", "timeout": 1000}),
  551. ),
  552. },
  553. Scenario::BashPermissionPromptDenied => match latest_tool_result(request) {
  554. Some((tool_output, _)) => text_message_response(
  555. "msg_bash_prompt_deny_final",
  556. &format!("bash denied as expected: {tool_output}"),
  557. ),
  558. None => tool_message_response(
  559. "msg_bash_prompt_deny_tool",
  560. "toolu_bash_prompt_deny",
  561. "bash",
  562. json!({"command": "printf 'should not run'", "timeout": 1000}),
  563. ),
  564. },
  565. Scenario::PluginToolRoundtrip => match latest_tool_result(request) {
  566. Some((tool_output, _)) => text_message_response(
  567. "msg_plugin_tool_final",
  568. &format!(
  569. "plugin tool completed: {}",
  570. extract_plugin_message(&tool_output)
  571. ),
  572. ),
  573. None => tool_message_response(
  574. "msg_plugin_tool_start",
  575. "toolu_plugin_echo",
  576. "plugin_echo",
  577. json!({"message": "hello from plugin parity"}),
  578. ),
  579. },
  580. }
  581. }
  582. fn request_id_for(scenario: Scenario) -> &'static str {
  583. match scenario {
  584. Scenario::StreamingText => "req_streaming_text",
  585. Scenario::ReadFileRoundtrip => "req_read_file_roundtrip",
  586. Scenario::GrepChunkAssembly => "req_grep_chunk_assembly",
  587. Scenario::WriteFileAllowed => "req_write_file_allowed",
  588. Scenario::WriteFileDenied => "req_write_file_denied",
  589. Scenario::MultiToolTurnRoundtrip => "req_multi_tool_turn_roundtrip",
  590. Scenario::BashStdoutRoundtrip => "req_bash_stdout_roundtrip",
  591. Scenario::BashPermissionPromptApproved => "req_bash_permission_prompt_approved",
  592. Scenario::BashPermissionPromptDenied => "req_bash_permission_prompt_denied",
  593. Scenario::PluginToolRoundtrip => "req_plugin_tool_roundtrip",
  594. }
  595. }
  596. fn http_response(status: &str, content_type: &str, body: &str, headers: &[(&str, &str)]) -> String {
  597. let mut extra_headers = String::new();
  598. for (name, value) in headers {
  599. use std::fmt::Write as _;
  600. write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
  601. }
  602. format!(
  603. "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
  604. body.len()
  605. )
  606. }
  607. fn text_message_response(id: &str, text: &str) -> MessageResponse {
  608. MessageResponse {
  609. id: id.to_string(),
  610. kind: "message".to_string(),
  611. role: "assistant".to_string(),
  612. content: vec![OutputContentBlock::Text {
  613. text: text.to_string(),
  614. }],
  615. model: DEFAULT_MODEL.to_string(),
  616. stop_reason: Some("end_turn".to_string()),
  617. stop_sequence: None,
  618. usage: Usage {
  619. input_tokens: 10,
  620. cache_creation_input_tokens: 0,
  621. cache_read_input_tokens: 0,
  622. output_tokens: 6,
  623. },
  624. request_id: None,
  625. }
  626. }
  627. fn tool_message_response(
  628. id: &str,
  629. tool_id: &str,
  630. tool_name: &str,
  631. input: Value,
  632. ) -> MessageResponse {
  633. tool_message_response_many(
  634. id,
  635. &[ToolUseMessage {
  636. tool_id,
  637. tool_name,
  638. input,
  639. }],
  640. )
  641. }
/// One `tool_use` content block to place in a non-streaming mock response.
struct ToolUseMessage<'a> {
    /// Value of the block's `id` field.
    tool_id: &'a str,
    /// Value of the block's `name` field (the tool being invoked).
    tool_name: &'a str,
    /// Value of the block's `input` field (the tool arguments).
    input: Value,
}
  647. fn tool_message_response_many(id: &str, tool_uses: &[ToolUseMessage<'_>]) -> MessageResponse {
  648. MessageResponse {
  649. id: id.to_string(),
  650. kind: "message".to_string(),
  651. role: "assistant".to_string(),
  652. content: tool_uses
  653. .iter()
  654. .map(|tool_use| OutputContentBlock::ToolUse {
  655. id: tool_use.tool_id.to_string(),
  656. name: tool_use.tool_name.to_string(),
  657. input: tool_use.input.clone(),
  658. })
  659. .collect(),
  660. model: DEFAULT_MODEL.to_string(),
  661. stop_reason: Some("tool_use".to_string()),
  662. stop_sequence: None,
  663. usage: Usage {
  664. input_tokens: 10,
  665. cache_creation_input_tokens: 0,
  666. cache_read_input_tokens: 0,
  667. output_tokens: 3,
  668. },
  669. request_id: None,
  670. }
  671. }
  672. fn streaming_text_sse() -> String {
  673. let mut body = String::new();
  674. append_sse(
  675. &mut body,
  676. "message_start",
  677. json!({
  678. "type": "message_start",
  679. "message": {
  680. "id": "msg_streaming_text",
  681. "type": "message",
  682. "role": "assistant",
  683. "content": [],
  684. "model": DEFAULT_MODEL,
  685. "stop_reason": null,
  686. "stop_sequence": null,
  687. "usage": usage_json(11, 0)
  688. }
  689. }),
  690. );
  691. append_sse(
  692. &mut body,
  693. "content_block_start",
  694. json!({
  695. "type": "content_block_start",
  696. "index": 0,
  697. "content_block": {"type": "text", "text": ""}
  698. }),
  699. );
  700. append_sse(
  701. &mut body,
  702. "content_block_delta",
  703. json!({
  704. "type": "content_block_delta",
  705. "index": 0,
  706. "delta": {"type": "text_delta", "text": "Mock streaming "}
  707. }),
  708. );
  709. append_sse(
  710. &mut body,
  711. "content_block_delta",
  712. json!({
  713. "type": "content_block_delta",
  714. "index": 0,
  715. "delta": {"type": "text_delta", "text": "says hello from the parity harness."}
  716. }),
  717. );
  718. append_sse(
  719. &mut body,
  720. "content_block_stop",
  721. json!({
  722. "type": "content_block_stop",
  723. "index": 0
  724. }),
  725. );
  726. append_sse(
  727. &mut body,
  728. "message_delta",
  729. json!({
  730. "type": "message_delta",
  731. "delta": {"stop_reason": "end_turn", "stop_sequence": null},
  732. "usage": usage_json(11, 8)
  733. }),
  734. );
  735. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  736. body
  737. }
  738. fn tool_use_sse(tool_id: &str, tool_name: &str, partial_json_chunks: &[&str]) -> String {
  739. tool_uses_sse(&[ToolUseSse {
  740. tool_id,
  741. tool_name,
  742. partial_json_chunks,
  743. }])
  744. }
/// One `tool_use` content block to stream, with its arguments pre-split into the
/// fragments emitted as `input_json_delta` events.
struct ToolUseSse<'a> {
    /// Value of the block's `id` field.
    tool_id: &'a str,
    /// Value of the block's `name` field (the tool being invoked).
    tool_name: &'a str,
    /// Argument JSON fragments, sent in order as `partial_json`.
    partial_json_chunks: &'a [&'a str],
}
  750. fn tool_uses_sse(tool_uses: &[ToolUseSse<'_>]) -> String {
  751. let mut body = String::new();
  752. let message_id = tool_uses.first().map_or_else(
  753. || "msg_tool_use".to_string(),
  754. |tool_use| format!("msg_{}", tool_use.tool_id),
  755. );
  756. append_sse(
  757. &mut body,
  758. "message_start",
  759. json!({
  760. "type": "message_start",
  761. "message": {
  762. "id": message_id,
  763. "type": "message",
  764. "role": "assistant",
  765. "content": [],
  766. "model": DEFAULT_MODEL,
  767. "stop_reason": null,
  768. "stop_sequence": null,
  769. "usage": usage_json(12, 0)
  770. }
  771. }),
  772. );
  773. for (index, tool_use) in tool_uses.iter().enumerate() {
  774. append_sse(
  775. &mut body,
  776. "content_block_start",
  777. json!({
  778. "type": "content_block_start",
  779. "index": index,
  780. "content_block": {
  781. "type": "tool_use",
  782. "id": tool_use.tool_id,
  783. "name": tool_use.tool_name,
  784. "input": {}
  785. }
  786. }),
  787. );
  788. for chunk in tool_use.partial_json_chunks {
  789. append_sse(
  790. &mut body,
  791. "content_block_delta",
  792. json!({
  793. "type": "content_block_delta",
  794. "index": index,
  795. "delta": {"type": "input_json_delta", "partial_json": chunk}
  796. }),
  797. );
  798. }
  799. append_sse(
  800. &mut body,
  801. "content_block_stop",
  802. json!({
  803. "type": "content_block_stop",
  804. "index": index
  805. }),
  806. );
  807. }
  808. append_sse(
  809. &mut body,
  810. "message_delta",
  811. json!({
  812. "type": "message_delta",
  813. "delta": {"stop_reason": "tool_use", "stop_sequence": null},
  814. "usage": usage_json(12, 4)
  815. }),
  816. );
  817. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  818. body
  819. }
  820. fn final_text_sse(text: &str) -> String {
  821. let mut body = String::new();
  822. append_sse(
  823. &mut body,
  824. "message_start",
  825. json!({
  826. "type": "message_start",
  827. "message": {
  828. "id": unique_message_id(),
  829. "type": "message",
  830. "role": "assistant",
  831. "content": [],
  832. "model": DEFAULT_MODEL,
  833. "stop_reason": null,
  834. "stop_sequence": null,
  835. "usage": usage_json(14, 0)
  836. }
  837. }),
  838. );
  839. append_sse(
  840. &mut body,
  841. "content_block_start",
  842. json!({
  843. "type": "content_block_start",
  844. "index": 0,
  845. "content_block": {"type": "text", "text": ""}
  846. }),
  847. );
  848. append_sse(
  849. &mut body,
  850. "content_block_delta",
  851. json!({
  852. "type": "content_block_delta",
  853. "index": 0,
  854. "delta": {"type": "text_delta", "text": text}
  855. }),
  856. );
  857. append_sse(
  858. &mut body,
  859. "content_block_stop",
  860. json!({
  861. "type": "content_block_stop",
  862. "index": 0
  863. }),
  864. );
  865. append_sse(
  866. &mut body,
  867. "message_delta",
  868. json!({
  869. "type": "message_delta",
  870. "delta": {"stop_reason": "end_turn", "stop_sequence": null},
  871. "usage": usage_json(14, 7)
  872. }),
  873. );
  874. append_sse(&mut body, "message_stop", json!({"type": "message_stop"}));
  875. body
  876. }
  877. #[allow(clippy::needless_pass_by_value)]
  878. fn append_sse(buffer: &mut String, event: &str, payload: Value) {
  879. use std::fmt::Write as _;
  880. writeln!(buffer, "event: {event}").expect("event write should succeed");
  881. writeln!(buffer, "data: {payload}").expect("payload write should succeed");
  882. buffer.push('\n');
  883. }
  884. fn usage_json(input_tokens: u32, output_tokens: u32) -> Value {
  885. json!({
  886. "input_tokens": input_tokens,
  887. "cache_creation_input_tokens": 0,
  888. "cache_read_input_tokens": 0,
  889. "output_tokens": output_tokens
  890. })
  891. }
  892. fn unique_message_id() -> String {
  893. let nanos = SystemTime::now()
  894. .duration_since(UNIX_EPOCH)
  895. .expect("clock should be after epoch")
  896. .as_nanos();
  897. format!("msg_{nanos}")
  898. }
  899. fn extract_read_content(tool_output: &str) -> String {
  900. serde_json::from_str::<Value>(tool_output)
  901. .ok()
  902. .and_then(|value| {
  903. value
  904. .get("file")
  905. .and_then(|file| file.get("content"))
  906. .and_then(Value::as_str)
  907. .map(ToOwned::to_owned)
  908. })
  909. .unwrap_or_else(|| tool_output.trim().to_string())
  910. }
  911. #[allow(clippy::cast_possible_truncation)]
  912. fn extract_num_matches(tool_output: &str) -> usize {
  913. serde_json::from_str::<Value>(tool_output)
  914. .ok()
  915. .and_then(|value| value.get("numMatches").and_then(Value::as_u64))
  916. .unwrap_or(0) as usize
  917. }
  918. fn extract_file_path(tool_output: &str) -> String {
  919. serde_json::from_str::<Value>(tool_output)
  920. .ok()
  921. .and_then(|value| {
  922. value
  923. .get("filePath")
  924. .and_then(Value::as_str)
  925. .map(ToOwned::to_owned)
  926. })
  927. .unwrap_or_else(|| tool_output.trim().to_string())
  928. }
  929. fn extract_bash_stdout(tool_output: &str) -> String {
  930. serde_json::from_str::<Value>(tool_output)
  931. .ok()
  932. .and_then(|value| {
  933. value
  934. .get("stdout")
  935. .and_then(Value::as_str)
  936. .map(ToOwned::to_owned)
  937. })
  938. .unwrap_or_else(|| tool_output.trim().to_string())
  939. }
  940. fn extract_plugin_message(tool_output: &str) -> String {
  941. serde_json::from_str::<Value>(tool_output)
  942. .ok()
  943. .and_then(|value| {
  944. value
  945. .get("input")
  946. .and_then(|input| input.get("message"))
  947. .and_then(Value::as_str)
  948. .map(ToOwned::to_owned)
  949. })
  950. .unwrap_or_else(|| tool_output.trim().to_string())
  951. }