usage.rs 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. use crate::session::Session;
  2. const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
  3. const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
  4. const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
  5. const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
  6. #[derive(Debug, Clone, Copy, PartialEq)]
  7. pub struct ModelPricing {
  8. pub input_cost_per_million: f64,
  9. pub output_cost_per_million: f64,
  10. pub cache_creation_cost_per_million: f64,
  11. pub cache_read_cost_per_million: f64,
  12. }
  13. impl ModelPricing {
  14. #[must_use]
  15. pub const fn default_sonnet_tier() -> Self {
  16. Self {
  17. input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
  18. output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
  19. cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
  20. cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
  21. }
  22. }
  23. }
  24. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
  25. pub struct TokenUsage {
  26. pub input_tokens: u32,
  27. pub output_tokens: u32,
  28. pub cache_creation_input_tokens: u32,
  29. pub cache_read_input_tokens: u32,
  30. }
  31. #[derive(Debug, Clone, Copy, PartialEq)]
  32. pub struct UsageCostEstimate {
  33. pub input_cost_usd: f64,
  34. pub output_cost_usd: f64,
  35. pub cache_creation_cost_usd: f64,
  36. pub cache_read_cost_usd: f64,
  37. }
  38. impl UsageCostEstimate {
  39. #[must_use]
  40. pub fn total_cost_usd(self) -> f64 {
  41. self.input_cost_usd
  42. + self.output_cost_usd
  43. + self.cache_creation_cost_usd
  44. + self.cache_read_cost_usd
  45. }
  46. }
  47. #[must_use]
  48. pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
  49. let normalized = model.to_ascii_lowercase();
  50. if normalized.contains("haiku") {
  51. return Some(ModelPricing {
  52. input_cost_per_million: 1.0,
  53. output_cost_per_million: 5.0,
  54. cache_creation_cost_per_million: 1.25,
  55. cache_read_cost_per_million: 0.1,
  56. });
  57. }
  58. if normalized.contains("opus") {
  59. return Some(ModelPricing {
  60. input_cost_per_million: 15.0,
  61. output_cost_per_million: 75.0,
  62. cache_creation_cost_per_million: 18.75,
  63. cache_read_cost_per_million: 1.5,
  64. });
  65. }
  66. if normalized.contains("sonnet") {
  67. return Some(ModelPricing::default_sonnet_tier());
  68. }
  69. None
  70. }
  71. impl TokenUsage {
  72. #[must_use]
  73. pub fn total_tokens(self) -> u32 {
  74. self.input_tokens
  75. + self.output_tokens
  76. + self.cache_creation_input_tokens
  77. + self.cache_read_input_tokens
  78. }
  79. #[must_use]
  80. pub fn estimate_cost_usd(self) -> UsageCostEstimate {
  81. self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
  82. }
  83. #[must_use]
  84. pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
  85. UsageCostEstimate {
  86. input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
  87. output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
  88. cache_creation_cost_usd: cost_for_tokens(
  89. self.cache_creation_input_tokens,
  90. pricing.cache_creation_cost_per_million,
  91. ),
  92. cache_read_cost_usd: cost_for_tokens(
  93. self.cache_read_input_tokens,
  94. pricing.cache_read_cost_per_million,
  95. ),
  96. }
  97. }
  98. #[must_use]
  99. pub fn summary_lines(self, label: &str) -> Vec<String> {
  100. self.summary_lines_for_model(label, None)
  101. }
  102. #[must_use]
  103. pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
  104. let pricing = model.and_then(pricing_for_model);
  105. let cost = pricing.map_or_else(
  106. || self.estimate_cost_usd(),
  107. |pricing| self.estimate_cost_usd_with_pricing(pricing),
  108. );
  109. let model_suffix =
  110. model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
  111. let pricing_suffix = if pricing.is_some() {
  112. ""
  113. } else if model.is_some() {
  114. " pricing=estimated-default"
  115. } else {
  116. ""
  117. };
  118. vec![
  119. format!(
  120. "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
  121. self.total_tokens(),
  122. self.input_tokens,
  123. self.output_tokens,
  124. self.cache_creation_input_tokens,
  125. self.cache_read_input_tokens,
  126. format_usd(cost.total_cost_usd()),
  127. model_suffix,
  128. pricing_suffix,
  129. ),
  130. format!(
  131. " cost breakdown: input={} output={} cache_write={} cache_read={}",
  132. format_usd(cost.input_cost_usd),
  133. format_usd(cost.output_cost_usd),
  134. format_usd(cost.cache_creation_cost_usd),
  135. format_usd(cost.cache_read_cost_usd),
  136. ),
  137. ]
  138. }
  139. }
  140. fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
  141. f64::from(tokens) / 1_000_000.0 * usd_per_million_tokens
  142. }
  143. #[must_use]
  144. pub fn format_usd(amount: f64) -> String {
  145. format!("${amount:.4}")
  146. }
  147. #[derive(Debug, Clone, Default, PartialEq, Eq)]
  148. pub struct UsageTracker {
  149. latest_turn: TokenUsage,
  150. cumulative: TokenUsage,
  151. turns: u32,
  152. }
  153. impl UsageTracker {
  154. #[must_use]
  155. pub fn new() -> Self {
  156. Self::default()
  157. }
  158. #[must_use]
  159. pub fn from_session(session: &Session) -> Self {
  160. let mut tracker = Self::new();
  161. for message in &session.messages {
  162. if let Some(usage) = message.usage {
  163. tracker.record(usage);
  164. }
  165. }
  166. tracker
  167. }
  168. pub fn record(&mut self, usage: TokenUsage) {
  169. self.latest_turn = usage;
  170. self.cumulative.input_tokens += usage.input_tokens;
  171. self.cumulative.output_tokens += usage.output_tokens;
  172. self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
  173. self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
  174. self.turns += 1;
  175. }
  176. #[must_use]
  177. pub fn current_turn_usage(&self) -> TokenUsage {
  178. self.latest_turn
  179. }
  180. #[must_use]
  181. pub fn cumulative_usage(&self) -> TokenUsage {
  182. self.cumulative
  183. }
  184. #[must_use]
  185. pub fn turns(&self) -> u32 {
  186. self.turns
  187. }
  188. }
  189. #[cfg(test)]
  190. mod tests {
  191. use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
  192. use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
  193. #[test]
  194. fn tracks_true_cumulative_usage() {
  195. let mut tracker = UsageTracker::new();
  196. tracker.record(TokenUsage {
  197. input_tokens: 10,
  198. output_tokens: 4,
  199. cache_creation_input_tokens: 2,
  200. cache_read_input_tokens: 1,
  201. });
  202. tracker.record(TokenUsage {
  203. input_tokens: 20,
  204. output_tokens: 6,
  205. cache_creation_input_tokens: 3,
  206. cache_read_input_tokens: 2,
  207. });
  208. assert_eq!(tracker.turns(), 2);
  209. assert_eq!(tracker.current_turn_usage().input_tokens, 20);
  210. assert_eq!(tracker.current_turn_usage().output_tokens, 6);
  211. assert_eq!(tracker.cumulative_usage().output_tokens, 10);
  212. assert_eq!(tracker.cumulative_usage().input_tokens, 30);
  213. assert_eq!(tracker.cumulative_usage().total_tokens(), 48);
  214. }
  215. #[test]
  216. fn computes_cost_summary_lines() {
  217. let usage = TokenUsage {
  218. input_tokens: 1_000_000,
  219. output_tokens: 500_000,
  220. cache_creation_input_tokens: 100_000,
  221. cache_read_input_tokens: 200_000,
  222. };
  223. let cost = usage.estimate_cost_usd();
  224. assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
  225. assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
  226. let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
  227. assert!(lines[0].contains("estimated_cost=$54.6750"));
  228. assert!(lines[0].contains("model=claude-sonnet-4-20250514"));
  229. assert!(lines[1].contains("cache_read=$0.3000"));
  230. }
  231. #[test]
  232. fn supports_model_specific_pricing() {
  233. let usage = TokenUsage {
  234. input_tokens: 1_000_000,
  235. output_tokens: 500_000,
  236. cache_creation_input_tokens: 0,
  237. cache_read_input_tokens: 0,
  238. };
  239. let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
  240. let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
  241. let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
  242. let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
  243. assert_eq!(format_usd(haiku_cost.total_cost_usd()), "$3.5000");
  244. assert_eq!(format_usd(opus_cost.total_cost_usd()), "$52.5000");
  245. }
  246. #[test]
  247. fn marks_unknown_model_pricing_as_fallback() {
  248. let usage = TokenUsage {
  249. input_tokens: 100,
  250. output_tokens: 100,
  251. cache_creation_input_tokens: 0,
  252. cache_read_input_tokens: 0,
  253. };
  254. let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
  255. assert!(lines[0].contains("pricing=estimated-default"));
  256. }
  257. #[test]
  258. fn reconstructs_usage_from_session_messages() {
  259. let mut session = Session::new();
  260. session.messages = vec![ConversationMessage {
  261. role: MessageRole::Assistant,
  262. blocks: vec![ContentBlock::Text {
  263. text: "done".to_string(),
  264. }],
  265. usage: Some(TokenUsage {
  266. input_tokens: 5,
  267. output_tokens: 2,
  268. cache_creation_input_tokens: 1,
  269. cache_read_input_tokens: 0,
  270. }),
  271. }];
  272. let tracker = UsageTracker::from_session(&session);
  273. assert_eq!(tracker.turns(), 1);
  274. assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
  275. }
  276. }