sse.rs 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. use crate::error::ApiError;
  2. use crate::types::StreamEvent;
  3. #[derive(Debug, Default)]
  4. pub struct SseParser {
  5. buffer: Vec<u8>,
  6. }
  7. impl SseParser {
  8. #[must_use]
  9. pub fn new() -> Self {
  10. Self::default()
  11. }
  12. pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
  13. self.buffer.extend_from_slice(chunk);
  14. let mut events = Vec::new();
  15. while let Some(frame) = self.next_frame() {
  16. if let Some(event) = parse_frame(&frame)? {
  17. events.push(event);
  18. }
  19. }
  20. Ok(events)
  21. }
  22. pub fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
  23. if self.buffer.is_empty() {
  24. return Ok(Vec::new());
  25. }
  26. let trailing = std::mem::take(&mut self.buffer);
  27. match parse_frame(&String::from_utf8_lossy(&trailing))? {
  28. Some(event) => Ok(vec![event]),
  29. None => Ok(Vec::new()),
  30. }
  31. }
  32. fn next_frame(&mut self) -> Option<String> {
  33. let separator = self
  34. .buffer
  35. .windows(2)
  36. .position(|window| window == b"\n\n")
  37. .map(|position| (position, 2))
  38. .or_else(|| {
  39. self.buffer
  40. .windows(4)
  41. .position(|window| window == b"\r\n\r\n")
  42. .map(|position| (position, 4))
  43. })?;
  44. let (position, separator_len) = separator;
  45. let frame = self
  46. .buffer
  47. .drain(..position + separator_len)
  48. .collect::<Vec<_>>();
  49. let frame_len = frame.len().saturating_sub(separator_len);
  50. Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
  51. }
  52. }
  53. pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
  54. let trimmed = frame.trim();
  55. if trimmed.is_empty() {
  56. return Ok(None);
  57. }
  58. let mut data_lines = Vec::new();
  59. let mut event_name: Option<&str> = None;
  60. for line in trimmed.lines() {
  61. if line.starts_with(':') {
  62. continue;
  63. }
  64. if let Some(name) = line.strip_prefix("event:") {
  65. event_name = Some(name.trim());
  66. continue;
  67. }
  68. if let Some(data) = line.strip_prefix("data:") {
  69. data_lines.push(data.trim_start());
  70. }
  71. }
  72. if matches!(event_name, Some("ping")) {
  73. return Ok(None);
  74. }
  75. if data_lines.is_empty() {
  76. return Ok(None);
  77. }
  78. let payload = data_lines.join("\n");
  79. if payload == "[DONE]" {
  80. return Ok(None);
  81. }
  82. serde_json::from_str::<StreamEvent>(&payload)
  83. .map(Some)
  84. .map_err(ApiError::from)
  85. }
  86. #[cfg(test)]
  87. mod tests {
  88. use super::{parse_frame, SseParser};
  89. use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage};
  90. #[test]
  91. fn parses_single_frame() {
  92. let frame = concat!(
  93. "event: content_block_start\n",
  94. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hi\"}}\n\n"
  95. );
  96. let event = parse_frame(frame).expect("frame should parse");
  97. assert_eq!(
  98. event,
  99. Some(StreamEvent::ContentBlockStart(
  100. crate::types::ContentBlockStartEvent {
  101. index: 0,
  102. content_block: OutputContentBlock::Text {
  103. text: "Hi".to_string(),
  104. },
  105. },
  106. ))
  107. );
  108. }
  109. #[test]
  110. fn parses_chunked_stream() {
  111. let mut parser = SseParser::new();
  112. let first = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel";
  113. let second = b"lo\"}}\n\n";
  114. assert!(parser
  115. .push(first)
  116. .expect("first chunk should buffer")
  117. .is_empty());
  118. let events = parser.push(second).expect("second chunk should parse");
  119. assert_eq!(
  120. events,
  121. vec![StreamEvent::ContentBlockDelta(
  122. crate::types::ContentBlockDeltaEvent {
  123. index: 0,
  124. delta: ContentBlockDelta::TextDelta {
  125. text: "Hello".to_string(),
  126. },
  127. }
  128. )]
  129. );
  130. }
  131. #[test]
  132. fn ignores_ping_and_done() {
  133. let mut parser = SseParser::new();
  134. let payload = concat!(
  135. ": keepalive\n",
  136. "event: ping\n",
  137. "data: {\"type\":\"ping\"}\n\n",
  138. "event: message_delta\n",
  139. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
  140. "event: message_stop\n",
  141. "data: {\"type\":\"message_stop\"}\n\n",
  142. "data: [DONE]\n\n"
  143. );
  144. let events = parser
  145. .push(payload.as_bytes())
  146. .expect("parser should succeed");
  147. assert_eq!(
  148. events,
  149. vec![
  150. StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
  151. delta: MessageDelta {
  152. stop_reason: Some("tool_use".to_string()),
  153. stop_sequence: None,
  154. },
  155. usage: Usage {
  156. input_tokens: 1,
  157. cache_creation_input_tokens: 0,
  158. cache_read_input_tokens: 0,
  159. output_tokens: 2,
  160. },
  161. }),
  162. StreamEvent::MessageStop(crate::types::MessageStopEvent {}),
  163. ]
  164. );
  165. }
  166. #[test]
  167. fn ignores_data_less_event_frames() {
  168. let frame = "event: ping\n\n";
  169. let event = parse_frame(frame).expect("frame without data should be ignored");
  170. assert_eq!(event, None);
  171. }
  172. #[test]
  173. fn parses_split_json_across_data_lines() {
  174. let frame = concat!(
  175. "event: content_block_delta\n",
  176. "data: {\"type\":\"content_block_delta\",\"index\":0,\n",
  177. "data: \"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n"
  178. );
  179. let event = parse_frame(frame).expect("frame should parse");
  180. assert_eq!(
  181. event,
  182. Some(StreamEvent::ContentBlockDelta(
  183. crate::types::ContentBlockDeltaEvent {
  184. index: 0,
  185. delta: ContentBlockDelta::TextDelta {
  186. text: "Hello".to_string(),
  187. },
  188. }
  189. ))
  190. );
  191. }
  192. }