session.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use std::fs;
  4. use std::path::Path;
  5. use crate::json::{JsonError, JsonValue};
  6. use crate::usage::TokenUsage;
  /// Identifies who authored a [`ConversationMessage`].
  ///
  /// [`ConversationMessage::to_json`] encodes these as the lowercase strings `"system"`,
  /// `"user"`, `"assistant"` and `"tool"`.
  #[derive(Clone, Copy, Debug, Eq, PartialEq)]
  pub enum MessageRole {
      /// Message with the `system` role.
      System,
      /// Message authored by the user (see [`ConversationMessage::user_text`]).
      User,
      /// Message authored by the assistant (see [`ConversationMessage::assistant`]).
      Assistant,
      /// Message carrying a tool result (see [`ConversationMessage::tool_result`]).
      Tool,
  }
  /// One unit of content inside a [`ConversationMessage`].
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub enum ContentBlock {
      /// Plain text content.
      Text {
          /// The text itself.
          text: String,
      },
      /// A request to invoke a tool.
      ToolUse {
          /// Identifier of this tool call; a matching `ToolResult` carries it as `tool_use_id`.
          id: String,
          /// Name of the tool to invoke.
          name: String,
          /// Tool input, kept as an opaque string.
          input: String,
      },
      /// The outcome reported back for a tool call.
      ToolResult {
          /// The `id` of the `ToolUse` block this result answers.
          tool_use_id: String,
          /// Name of the tool that produced the output.
          tool_name: String,
          /// Tool output, kept as an opaque string.
          output: String,
          /// `true` when the tool reported a failure.
          is_error: bool,
      },
  }
  31. #[derive(Debug, Clone, PartialEq, Eq)]
  32. pub struct ConversationMessage {
  33. pub role: MessageRole,
  34. pub blocks: Vec<ContentBlock>,
  35. pub usage: Option<TokenUsage>,
  36. }
  /// Descriptive information stored alongside a [`Session`].
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub struct SessionMetadata {
      /// Session start time, stored verbatim as a string (the tests use an RFC 3339 timestamp).
      pub started_at: String,
      /// Name of the model used for the session.
      pub model: String,
      /// Recorded message count. NOTE(review): this module never recomputes it from
      /// `Session::messages`, so callers must keep it in sync.
      pub message_count: u32,
      /// The most recent prompt, if one was recorded.
      pub last_prompt: Option<String>,
  }
  44. #[derive(Debug, Clone, PartialEq, Eq)]
  45. pub struct Session {
  46. pub version: u32,
  47. pub messages: Vec<ConversationMessage>,
  48. pub metadata: Option<SessionMetadata>,
  49. }
  50. #[derive(Debug)]
  51. pub enum SessionError {
  52. Io(std::io::Error),
  53. Json(JsonError),
  54. Format(String),
  55. }
  56. impl Display for SessionError {
  57. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  58. match self {
  59. Self::Io(error) => write!(f, "{error}"),
  60. Self::Json(error) => write!(f, "{error}"),
  61. Self::Format(error) => write!(f, "{error}"),
  62. }
  63. }
  64. }
  65. impl std::error::Error for SessionError {}
  66. impl From<std::io::Error> for SessionError {
  67. fn from(value: std::io::Error) -> Self {
  68. Self::Io(value)
  69. }
  70. }
  71. impl From<JsonError> for SessionError {
  72. fn from(value: JsonError) -> Self {
  73. Self::Json(value)
  74. }
  75. }
  76. impl Session {
  77. #[must_use]
  78. pub fn new() -> Self {
  79. Self {
  80. version: 1,
  81. messages: Vec::new(),
  82. metadata: None,
  83. }
  84. }
  85. pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
  86. fs::write(path, self.to_json().render())?;
  87. Ok(())
  88. }
  89. pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
  90. let contents = fs::read_to_string(path)?;
  91. Self::from_json(&JsonValue::parse(&contents)?)
  92. }
  93. #[must_use]
  94. pub fn to_json(&self) -> JsonValue {
  95. let mut object = BTreeMap::new();
  96. object.insert(
  97. "version".to_string(),
  98. JsonValue::Number(i64::from(self.version)),
  99. );
  100. object.insert(
  101. "messages".to_string(),
  102. JsonValue::Array(
  103. self.messages
  104. .iter()
  105. .map(ConversationMessage::to_json)
  106. .collect(),
  107. ),
  108. );
  109. if let Some(metadata) = &self.metadata {
  110. object.insert("metadata".to_string(), metadata.to_json());
  111. }
  112. JsonValue::Object(object)
  113. }
  114. pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  115. let object = value
  116. .as_object()
  117. .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
  118. let version = object
  119. .get("version")
  120. .and_then(JsonValue::as_i64)
  121. .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
  122. let version = u32::try_from(version)
  123. .map_err(|_| SessionError::Format("version out of range".to_string()))?;
  124. let messages = object
  125. .get("messages")
  126. .and_then(JsonValue::as_array)
  127. .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
  128. .iter()
  129. .map(ConversationMessage::from_json)
  130. .collect::<Result<Vec<_>, _>>()?;
  131. let metadata = object
  132. .get("metadata")
  133. .map(SessionMetadata::from_json)
  134. .transpose()?;
  135. Ok(Self {
  136. version,
  137. messages,
  138. metadata,
  139. })
  140. }
  141. }
  142. impl Default for Session {
  143. fn default() -> Self {
  144. Self::new()
  145. }
  146. }
  147. impl SessionMetadata {
  148. #[must_use]
  149. pub fn to_json(&self) -> JsonValue {
  150. let mut object = BTreeMap::new();
  151. object.insert(
  152. "started_at".to_string(),
  153. JsonValue::String(self.started_at.clone()),
  154. );
  155. object.insert("model".to_string(), JsonValue::String(self.model.clone()));
  156. object.insert(
  157. "message_count".to_string(),
  158. JsonValue::Number(i64::from(self.message_count)),
  159. );
  160. if let Some(last_prompt) = &self.last_prompt {
  161. object.insert(
  162. "last_prompt".to_string(),
  163. JsonValue::String(last_prompt.clone()),
  164. );
  165. }
  166. JsonValue::Object(object)
  167. }
  168. fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  169. let object = value.as_object().ok_or_else(|| {
  170. SessionError::Format("session metadata must be an object".to_string())
  171. })?;
  172. Ok(Self {
  173. started_at: required_string(object, "started_at")?,
  174. model: required_string(object, "model")?,
  175. message_count: required_u32(object, "message_count")?,
  176. last_prompt: optional_string(object, "last_prompt"),
  177. })
  178. }
  179. }
  180. impl ConversationMessage {
  181. #[must_use]
  182. pub fn user_text(text: impl Into<String>) -> Self {
  183. Self {
  184. role: MessageRole::User,
  185. blocks: vec![ContentBlock::Text { text: text.into() }],
  186. usage: None,
  187. }
  188. }
  189. #[must_use]
  190. pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
  191. Self {
  192. role: MessageRole::Assistant,
  193. blocks,
  194. usage: None,
  195. }
  196. }
  197. #[must_use]
  198. pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
  199. Self {
  200. role: MessageRole::Assistant,
  201. blocks,
  202. usage,
  203. }
  204. }
  205. #[must_use]
  206. pub fn tool_result(
  207. tool_use_id: impl Into<String>,
  208. tool_name: impl Into<String>,
  209. output: impl Into<String>,
  210. is_error: bool,
  211. ) -> Self {
  212. Self {
  213. role: MessageRole::Tool,
  214. blocks: vec![ContentBlock::ToolResult {
  215. tool_use_id: tool_use_id.into(),
  216. tool_name: tool_name.into(),
  217. output: output.into(),
  218. is_error,
  219. }],
  220. usage: None,
  221. }
  222. }
  223. #[must_use]
  224. pub fn to_json(&self) -> JsonValue {
  225. let mut object = BTreeMap::new();
  226. object.insert(
  227. "role".to_string(),
  228. JsonValue::String(
  229. match self.role {
  230. MessageRole::System => "system",
  231. MessageRole::User => "user",
  232. MessageRole::Assistant => "assistant",
  233. MessageRole::Tool => "tool",
  234. }
  235. .to_string(),
  236. ),
  237. );
  238. object.insert(
  239. "blocks".to_string(),
  240. JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
  241. );
  242. if let Some(usage) = self.usage {
  243. object.insert("usage".to_string(), usage_to_json(usage));
  244. }
  245. JsonValue::Object(object)
  246. }
  247. fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  248. let object = value
  249. .as_object()
  250. .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
  251. let role = match object
  252. .get("role")
  253. .and_then(JsonValue::as_str)
  254. .ok_or_else(|| SessionError::Format("missing role".to_string()))?
  255. {
  256. "system" => MessageRole::System,
  257. "user" => MessageRole::User,
  258. "assistant" => MessageRole::Assistant,
  259. "tool" => MessageRole::Tool,
  260. other => {
  261. return Err(SessionError::Format(format!(
  262. "unsupported message role: {other}"
  263. )))
  264. }
  265. };
  266. let blocks = object
  267. .get("blocks")
  268. .and_then(JsonValue::as_array)
  269. .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
  270. .iter()
  271. .map(ContentBlock::from_json)
  272. .collect::<Result<Vec<_>, _>>()?;
  273. let usage = object.get("usage").map(usage_from_json).transpose()?;
  274. Ok(Self {
  275. role,
  276. blocks,
  277. usage,
  278. })
  279. }
  280. }
  281. impl ContentBlock {
  282. #[must_use]
  283. pub fn to_json(&self) -> JsonValue {
  284. let mut object = BTreeMap::new();
  285. match self {
  286. Self::Text { text } => {
  287. object.insert("type".to_string(), JsonValue::String("text".to_string()));
  288. object.insert("text".to_string(), JsonValue::String(text.clone()));
  289. }
  290. Self::ToolUse { id, name, input } => {
  291. object.insert(
  292. "type".to_string(),
  293. JsonValue::String("tool_use".to_string()),
  294. );
  295. object.insert("id".to_string(), JsonValue::String(id.clone()));
  296. object.insert("name".to_string(), JsonValue::String(name.clone()));
  297. object.insert("input".to_string(), JsonValue::String(input.clone()));
  298. }
  299. Self::ToolResult {
  300. tool_use_id,
  301. tool_name,
  302. output,
  303. is_error,
  304. } => {
  305. object.insert(
  306. "type".to_string(),
  307. JsonValue::String("tool_result".to_string()),
  308. );
  309. object.insert(
  310. "tool_use_id".to_string(),
  311. JsonValue::String(tool_use_id.clone()),
  312. );
  313. object.insert(
  314. "tool_name".to_string(),
  315. JsonValue::String(tool_name.clone()),
  316. );
  317. object.insert("output".to_string(), JsonValue::String(output.clone()));
  318. object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
  319. }
  320. }
  321. JsonValue::Object(object)
  322. }
  323. fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
  324. let object = value
  325. .as_object()
  326. .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
  327. match object
  328. .get("type")
  329. .and_then(JsonValue::as_str)
  330. .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
  331. {
  332. "text" => Ok(Self::Text {
  333. text: required_string(object, "text")?,
  334. }),
  335. "tool_use" => Ok(Self::ToolUse {
  336. id: required_string(object, "id")?,
  337. name: required_string(object, "name")?,
  338. input: required_string(object, "input")?,
  339. }),
  340. "tool_result" => Ok(Self::ToolResult {
  341. tool_use_id: required_string(object, "tool_use_id")?,
  342. tool_name: required_string(object, "tool_name")?,
  343. output: required_string(object, "output")?,
  344. is_error: object
  345. .get("is_error")
  346. .and_then(JsonValue::as_bool)
  347. .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
  348. }),
  349. other => Err(SessionError::Format(format!(
  350. "unsupported block type: {other}"
  351. ))),
  352. }
  353. }
  354. }
  355. fn usage_to_json(usage: TokenUsage) -> JsonValue {
  356. let mut object = BTreeMap::new();
  357. object.insert(
  358. "input_tokens".to_string(),
  359. JsonValue::Number(i64::from(usage.input_tokens)),
  360. );
  361. object.insert(
  362. "output_tokens".to_string(),
  363. JsonValue::Number(i64::from(usage.output_tokens)),
  364. );
  365. object.insert(
  366. "cache_creation_input_tokens".to_string(),
  367. JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
  368. );
  369. object.insert(
  370. "cache_read_input_tokens".to_string(),
  371. JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
  372. );
  373. JsonValue::Object(object)
  374. }
  375. fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
  376. let object = value
  377. .as_object()
  378. .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
  379. Ok(TokenUsage {
  380. input_tokens: required_u32(object, "input_tokens")?,
  381. output_tokens: required_u32(object, "output_tokens")?,
  382. cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
  383. cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
  384. })
  385. }
  386. fn required_string(
  387. object: &BTreeMap<String, JsonValue>,
  388. key: &str,
  389. ) -> Result<String, SessionError> {
  390. object
  391. .get(key)
  392. .and_then(JsonValue::as_str)
  393. .map(ToOwned::to_owned)
  394. .ok_or_else(|| SessionError::Format(format!("missing {key}")))
  395. }
  396. fn optional_string(object: &BTreeMap<String, JsonValue>, key: &str) -> Option<String> {
  397. object
  398. .get(key)
  399. .and_then(JsonValue::as_str)
  400. .map(ToOwned::to_owned)
  401. }
  402. fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
  403. let value = object
  404. .get(key)
  405. .and_then(JsonValue::as_i64)
  406. .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
  407. u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
  408. }
  #[cfg(test)]
  mod tests {
      use super::{
          ContentBlock, ConversationMessage, MessageRole, Session, SessionError, SessionMetadata,
      };
      use crate::json::JsonValue;
      use crate::usage::TokenUsage;
      use std::fs;
      use std::time::{SystemTime, UNIX_EPOCH};

      /// Parses `json` and decodes it as a session.
      fn decode(json: &str) -> Result<Session, SessionError> {
          Session::from_json(&JsonValue::parse(json).expect("test json should parse"))
      }

      /// Asserts that `result` failed with a `Format` error whose message is exactly `expected`.
      fn assert_format_error(result: Result<Session, SessionError>, expected: &str) {
          match result {
              Err(SessionError::Format(message)) => assert_eq!(message, expected),
              other => panic!("expected format error {expected:?}, got {other:?}"),
          }
      }

      #[test]
      fn persists_and_restores_session_json() {
          let mut session = Session::new();
          session.metadata = Some(SessionMetadata {
              started_at: "2026-04-01T00:00:00Z".to_string(),
              model: "claude-sonnet".to_string(),
              message_count: 3,
              last_prompt: Some("hello".to_string()),
          });
          session
              .messages
              .push(ConversationMessage::user_text("hello"));
          session
              .messages
              .push(ConversationMessage::assistant_with_usage(
                  vec![
                      ContentBlock::Text {
                          text: "thinking".to_string(),
                      },
                      ContentBlock::ToolUse {
                          id: "tool-1".to_string(),
                          name: "bash".to_string(),
                          input: "echo hi".to_string(),
                      },
                  ],
                  Some(TokenUsage {
                      input_tokens: 10,
                      output_tokens: 4,
                      cache_creation_input_tokens: 1,
                      cache_read_input_tokens: 2,
                  }),
              ));
          session.messages.push(ConversationMessage::tool_result(
              "tool-1", "bash", "hi", false,
          ));
          let nanos = SystemTime::now()
              .duration_since(UNIX_EPOCH)
              .expect("system time should be after epoch")
              .as_nanos();
          let path = std::env::temp_dir().join(format!("runtime-session-{nanos}.json"));
          session.save_to_path(&path).expect("session should save");
          let loaded = Session::load_from_path(&path);
          // Remove the file before unwrapping so a failed load does not leak it.
          fs::remove_file(&path).expect("temp file should be removable");
          let restored = loaded.expect("session should load");
          assert_eq!(restored, session);
          assert_eq!(restored.messages[2].role, MessageRole::Tool);
          assert_eq!(
              restored.messages[1].usage.expect("usage").total_tokens(),
              17
          );
          assert_eq!(restored.metadata, session.metadata);
      }

      #[test]
      fn loads_legacy_session_without_metadata() {
          let legacy = r#"{
              "version": 1,
              "messages": [
                  {
                      "role": "user",
                      "blocks": [{"type": "text", "text": "hello"}]
                  }
              ]
          }"#;
          let restored = decode(legacy).expect("legacy session should parse");
          assert_eq!(restored.messages.len(), 1);
          assert!(restored.metadata.is_none());
      }

      #[test]
      fn round_trips_system_role_and_error_tool_result_in_memory() {
          let mut session = Session::new();
          session.messages.push(ConversationMessage {
              role: MessageRole::System,
              blocks: vec![ContentBlock::Text {
                  text: "be brief".to_string(),
              }],
              usage: None,
          });
          session.messages.push(ConversationMessage::tool_result(
              "tool-2", "bash", "boom", true,
          ));
          let restored =
              Session::from_json(&session.to_json()).expect("round trip should succeed");
          assert_eq!(restored, session);
          assert!(restored.metadata.is_none());
      }

      #[test]
      fn rejects_unknown_role_and_block_type() {
          assert_format_error(
              decode(
                  r#"{"version": 1, "messages": [{"role": "wizard", "blocks": [{"type": "text", "text": "x"}]}]}"#,
              ),
              "unsupported message role: wizard",
          );
          assert_format_error(
              decode(
                  r#"{"version": 1, "messages": [{"role": "user", "blocks": [{"type": "image"}]}]}"#,
              ),
              "unsupported block type: image",
          );
      }
  }
  483. }