//! `session.rs` — conversation session model and its JSON persistence.

  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use std::fs;
  4. use std::path::Path;
  5. use crate::json::{JsonError, JsonValue};
  6. use crate::usage::TokenUsage;
  7. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  8. pub enum MessageRole {
  9. System,
  10. User,
  11. Assistant,
  12. Tool,
  13. }
  14. #[derive(Debug, Clone, PartialEq, Eq)]
  15. pub enum ContentBlock {
  16. Text {
  17. text: String,
  18. },
  19. ToolUse {
  20. id: String,
  21. name: String,
  22. input: String,
  23. },
  24. ToolResult {
  25. tool_use_id: String,
  26. tool_name: String,
  27. output: String,
  28. is_error: bool,
  29. },
  30. }
  31. #[derive(Debug, Clone, PartialEq, Eq)]
  32. pub struct ConversationMessage {
  33. pub role: MessageRole,
  34. pub blocks: Vec<ContentBlock>,
  35. pub usage: Option<TokenUsage>,
  36. }
  37. #[derive(Debug, Clone, PartialEq, Eq)]
  38. pub struct Session {
  39. pub version: u32,
  40. pub messages: Vec<ConversationMessage>,
  41. }
  42. #[derive(Debug)]
  43. pub enum SessionError {
  44. Io(std::io::Error),
  45. Json(JsonError),
  46. Format(String),
  47. }
  48. impl Display for SessionError {
  49. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  50. match self {
  51. Self::Io(error) => write!(f, "{error}"),
  52. Self::Json(error) => write!(f, "{error}"),
  53. Self::Format(error) => write!(f, "{error}"),
  54. }
  55. }
  56. }
  57. impl std::error::Error for SessionError {}
  58. impl From<std::io::Error> for SessionError {
  59. fn from(value: std::io::Error) -> Self {
  60. Self::Io(value)
  61. }
  62. }
  63. impl From<JsonError> for SessionError {
  64. fn from(value: JsonError) -> Self {
  65. Self::Json(value)
  66. }
  67. }
  68. impl Session {
  69. #[must_use]
  70. pub fn new() -> Self {
  71. Self {
  72. version: 1,
  73. messages: Vec::new(),
  74. }
  75. }
  76. pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
  77. fs::write(path, self.to_json().render())?;
  78. Ok(())
  79. }
  80. pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
  81. let contents = fs::read_to_string(path)?;
  82. Self::from_json(&JsonValue::parse(&contents)?)
  83. }
  84. #[must_use]
  85. pub fn to_json(&self) -> JsonValue {
  86. let mut object = BTreeMap::new();
  87. object.insert(
  88. "version".to_string(),
  89. JsonValue::Number(i64::from(self.version)),
  90. );
  91. object.insert(
  92. "messages".to_string(),
  93. JsonValue::Array(
  94. self.messages
  95. .iter()
  96. .map(ConversationMessage::to_json)
  97. .collect(),
  98. ),
  99. );
  100. JsonValue::Object(object)
  101. }
  102. pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  103. let object = value
  104. .as_object()
  105. .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
  106. let version = object
  107. .get("version")
  108. .and_then(JsonValue::as_i64)
  109. .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
  110. let version = u32::try_from(version)
  111. .map_err(|_| SessionError::Format("version out of range".to_string()))?;
  112. let messages = object
  113. .get("messages")
  114. .and_then(JsonValue::as_array)
  115. .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
  116. .iter()
  117. .map(ConversationMessage::from_json)
  118. .collect::<Result<Vec<_>, _>>()?;
  119. Ok(Self { version, messages })
  120. }
  121. }
  122. impl Default for Session {
  123. fn default() -> Self {
  124. Self::new()
  125. }
  126. }
  127. impl ConversationMessage {
  128. #[must_use]
  129. pub fn user_text(text: impl Into<String>) -> Self {
  130. Self {
  131. role: MessageRole::User,
  132. blocks: vec![ContentBlock::Text { text: text.into() }],
  133. usage: None,
  134. }
  135. }
  136. #[must_use]
  137. pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
  138. Self {
  139. role: MessageRole::Assistant,
  140. blocks,
  141. usage: None,
  142. }
  143. }
  144. #[must_use]
  145. pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
  146. Self {
  147. role: MessageRole::Assistant,
  148. blocks,
  149. usage,
  150. }
  151. }
  152. #[must_use]
  153. pub fn tool_result(
  154. tool_use_id: impl Into<String>,
  155. tool_name: impl Into<String>,
  156. output: impl Into<String>,
  157. is_error: bool,
  158. ) -> Self {
  159. Self {
  160. role: MessageRole::Tool,
  161. blocks: vec![ContentBlock::ToolResult {
  162. tool_use_id: tool_use_id.into(),
  163. tool_name: tool_name.into(),
  164. output: output.into(),
  165. is_error,
  166. }],
  167. usage: None,
  168. }
  169. }
  170. #[must_use]
  171. pub fn to_json(&self) -> JsonValue {
  172. let mut object = BTreeMap::new();
  173. object.insert(
  174. "role".to_string(),
  175. JsonValue::String(
  176. match self.role {
  177. MessageRole::System => "system",
  178. MessageRole::User => "user",
  179. MessageRole::Assistant => "assistant",
  180. MessageRole::Tool => "tool",
  181. }
  182. .to_string(),
  183. ),
  184. );
  185. object.insert(
  186. "blocks".to_string(),
  187. JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
  188. );
  189. if let Some(usage) = self.usage {
  190. object.insert("usage".to_string(), usage_to_json(usage));
  191. }
  192. JsonValue::Object(object)
  193. }
  194. fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  195. let object = value
  196. .as_object()
  197. .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
  198. let role = match object
  199. .get("role")
  200. .and_then(JsonValue::as_str)
  201. .ok_or_else(|| SessionError::Format("missing role".to_string()))?
  202. {
  203. "system" => MessageRole::System,
  204. "user" => MessageRole::User,
  205. "assistant" => MessageRole::Assistant,
  206. "tool" => MessageRole::Tool,
  207. other => {
  208. return Err(SessionError::Format(format!(
  209. "unsupported message role: {other}"
  210. )))
  211. }
  212. };
  213. let blocks = object
  214. .get("blocks")
  215. .and_then(JsonValue::as_array)
  216. .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
  217. .iter()
  218. .map(ContentBlock::from_json)
  219. .collect::<Result<Vec<_>, _>>()?;
  220. let usage = object.get("usage").map(usage_from_json).transpose()?;
  221. Ok(Self {
  222. role,
  223. blocks,
  224. usage,
  225. })
  226. }
  227. }
  228. impl ContentBlock {
  229. #[must_use]
  230. pub fn to_json(&self) -> JsonValue {
  231. let mut object = BTreeMap::new();
  232. match self {
  233. Self::Text { text } => {
  234. object.insert("type".to_string(), JsonValue::String("text".to_string()));
  235. object.insert("text".to_string(), JsonValue::String(text.clone()));
  236. }
  237. Self::ToolUse { id, name, input } => {
  238. object.insert(
  239. "type".to_string(),
  240. JsonValue::String("tool_use".to_string()),
  241. );
  242. object.insert("id".to_string(), JsonValue::String(id.clone()));
  243. object.insert("name".to_string(), JsonValue::String(name.clone()));
  244. object.insert("input".to_string(), JsonValue::String(input.clone()));
  245. }
  246. Self::ToolResult {
  247. tool_use_id,
  248. tool_name,
  249. output,
  250. is_error,
  251. } => {
  252. object.insert(
  253. "type".to_string(),
  254. JsonValue::String("tool_result".to_string()),
  255. );
  256. object.insert(
  257. "tool_use_id".to_string(),
  258. JsonValue::String(tool_use_id.clone()),
  259. );
  260. object.insert(
  261. "tool_name".to_string(),
  262. JsonValue::String(tool_name.clone()),
  263. );
  264. object.insert("output".to_string(), JsonValue::String(output.clone()));
  265. object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
  266. }
  267. }
  268. JsonValue::Object(object)
  269. }
  270. fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  271. let object = value
  272. .as_object()
  273. .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
  274. match object
  275. .get("type")
  276. .and_then(JsonValue::as_str)
  277. .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
  278. {
  279. "text" => Ok(Self::Text {
  280. text: required_string(object, "text")?,
  281. }),
  282. "tool_use" => Ok(Self::ToolUse {
  283. id: required_string(object, "id")?,
  284. name: required_string(object, "name")?,
  285. input: required_string(object, "input")?,
  286. }),
  287. "tool_result" => Ok(Self::ToolResult {
  288. tool_use_id: required_string(object, "tool_use_id")?,
  289. tool_name: required_string(object, "tool_name")?,
  290. output: required_string(object, "output")?,
  291. is_error: object
  292. .get("is_error")
  293. .and_then(JsonValue::as_bool)
  294. .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
  295. }),
  296. other => Err(SessionError::Format(format!(
  297. "unsupported block type: {other}"
  298. ))),
  299. }
  300. }
  301. }
  302. fn usage_to_json(usage: TokenUsage) -> JsonValue {
  303. let mut object = BTreeMap::new();
  304. object.insert(
  305. "input_tokens".to_string(),
  306. JsonValue::Number(i64::from(usage.input_tokens)),
  307. );
  308. object.insert(
  309. "output_tokens".to_string(),
  310. JsonValue::Number(i64::from(usage.output_tokens)),
  311. );
  312. object.insert(
  313. "cache_creation_input_tokens".to_string(),
  314. JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
  315. );
  316. object.insert(
  317. "cache_read_input_tokens".to_string(),
  318. JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
  319. );
  320. JsonValue::Object(object)
  321. }
  322. fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
  323. let object = value
  324. .as_object()
  325. .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
  326. Ok(TokenUsage {
  327. input_tokens: required_u32(object, "input_tokens")?,
  328. output_tokens: required_u32(object, "output_tokens")?,
  329. cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
  330. cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
  331. })
  332. }
  333. fn required_string(
  334. object: &BTreeMap<String, JsonValue>,
  335. key: &str,
  336. ) -> Result<String, SessionError> {
  337. object
  338. .get(key)
  339. .and_then(JsonValue::as_str)
  340. .map(ToOwned::to_owned)
  341. .ok_or_else(|| SessionError::Format(format!("missing {key}")))
  342. }
  343. fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
  344. let value = object
  345. .get(key)
  346. .and_then(JsonValue::as_i64)
  347. .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
  348. u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
  349. }
  350. #[cfg(test)]
  351. mod tests {
  352. use super::{ContentBlock, ConversationMessage, MessageRole, Session};
  353. use crate::usage::TokenUsage;
  354. use std::fs;
  355. use std::time::{SystemTime, UNIX_EPOCH};
  356. #[test]
  357. fn persists_and_restores_session_json() {
  358. let mut session = Session::new();
  359. session
  360. .messages
  361. .push(ConversationMessage::user_text("hello"));
  362. session
  363. .messages
  364. .push(ConversationMessage::assistant_with_usage(
  365. vec![
  366. ContentBlock::Text {
  367. text: "thinking".to_string(),
  368. },
  369. ContentBlock::ToolUse {
  370. id: "tool-1".to_string(),
  371. name: "bash".to_string(),
  372. input: "echo hi".to_string(),
  373. },
  374. ],
  375. Some(TokenUsage {
  376. input_tokens: 10,
  377. output_tokens: 4,
  378. cache_creation_input_tokens: 1,
  379. cache_read_input_tokens: 2,
  380. }),
  381. ));
  382. session.messages.push(ConversationMessage::tool_result(
  383. "tool-1", "bash", "hi", false,
  384. ));
  385. let nanos = SystemTime::now()
  386. .duration_since(UNIX_EPOCH)
  387. .expect("system time should be after epoch")
  388. .as_nanos();
  389. let path = std::env::temp_dir().join(format!("runtime-session-{nanos}.json"));
  390. session.save_to_path(&path).expect("session should save");
  391. let restored = Session::load_from_path(&path).expect("session should load");
  392. fs::remove_file(&path).expect("temp file should be removable");
  393. assert_eq!(restored, session);
  394. assert_eq!(restored.messages[2].role, MessageRole::Tool);
  395. assert_eq!(
  396. restored.messages[1].usage.expect("usage").total_tokens(),
  397. 17
  398. );
  399. }
  400. }