usage.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. use crate::session::Session;
  // Fallback per-million-token rates (USD). `ModelPricing::default_sonnet_tier`
  // is assembled from these and is what cost estimation uses when no
  // model-specific pricing is available.
  const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
  const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
  const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
  const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

  /// Per-million-token pricing used for cost estimation.
  ///
  /// Every field is a USD rate charged per 1,000,000 tokens of the matching
  /// [`TokenUsage`] counter.
  #[derive(Clone, Copy, Debug, PartialEq)]
  pub struct ModelPricing {
      /// Rate for plain input tokens.
      pub input_cost_per_million: f64,
      /// Rate for output tokens.
      pub output_cost_per_million: f64,
      /// Rate for input tokens written to the prompt cache.
      pub cache_creation_cost_per_million: f64,
      /// Rate for input tokens read back from the prompt cache.
      pub cache_read_cost_per_million: f64,
  }

  impl ModelPricing {
      /// Returns the built-in default rate card, built from the
      /// `DEFAULT_*_COST_PER_MILLION` constants.
      ///
      /// NOTE(review): these rates are identical to the `opus` entry in
      /// `pricing_for_model`, so Sonnet-family models (and unknown models that
      /// fall back to this tier) are priced like Opus — confirm this is intended.
      #[must_use]
      pub const fn default_sonnet_tier() -> Self {
          Self {
              input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
              output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
              cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
              cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
          }
      }
  }
  /// Token counters accumulated for a conversation turn or session.
  ///
  /// Every counter starts at zero via [`Default`].
  #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
  pub struct TokenUsage {
      /// Input tokens, counted separately from the two cache counters below.
      pub input_tokens: u32,
      /// Output tokens.
      pub output_tokens: u32,
      /// Input tokens written to the prompt cache (reported as `cache_write`).
      pub cache_creation_input_tokens: u32,
      /// Input tokens read from the prompt cache (reported as `cache_read`).
      pub cache_read_input_tokens: u32,
  }
  /// Estimated dollar cost derived from a [`TokenUsage`] sample.
  ///
  /// Each field is the USD cost attributed to one token category.
  #[derive(Clone, Copy, Debug, PartialEq)]
  pub struct UsageCostEstimate {
      /// Cost of plain input tokens.
      pub input_cost_usd: f64,
      /// Cost of output tokens.
      pub output_cost_usd: f64,
      /// Cost of tokens written to the prompt cache.
      pub cache_creation_cost_usd: f64,
      /// Cost of tokens read from the prompt cache.
      pub cache_read_cost_usd: f64,
  }

  impl UsageCostEstimate {
      /// Adds the four per-category costs into one USD total.
      #[must_use]
      pub fn total_cost_usd(self) -> f64 {
          // Accumulate strictly left to right (input, output, cache write,
          // cache read) so the floating-point result is reproducible.
          let mut running_total = self.input_cost_usd;
          running_total += self.output_cost_usd;
          running_total += self.cache_creation_cost_usd;
          running_total += self.cache_read_cost_usd;
          running_total
      }
  }
  50. /// Returns pricing metadata for a known model alias or family.
  51. #[must_use]
  52. pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
  53. let normalized = model.to_ascii_lowercase();
  54. if normalized.contains("haiku") {
  55. return Some(ModelPricing {
  56. input_cost_per_million: 1.0,
  57. output_cost_per_million: 5.0,
  58. cache_creation_cost_per_million: 1.25,
  59. cache_read_cost_per_million: 0.1,
  60. });
  61. }
  62. if normalized.contains("opus") {
  63. return Some(ModelPricing {
  64. input_cost_per_million: 15.0,
  65. output_cost_per_million: 75.0,
  66. cache_creation_cost_per_million: 18.75,
  67. cache_read_cost_per_million: 1.5,
  68. });
  69. }
  70. if normalized.contains("sonnet") {
  71. return Some(ModelPricing::default_sonnet_tier());
  72. }
  73. None
  74. }
  75. impl TokenUsage {
  76. #[must_use]
  77. pub fn total_tokens(self) -> u32 {
  78. self.input_tokens
  79. + self.output_tokens
  80. + self.cache_creation_input_tokens
  81. + self.cache_read_input_tokens
  82. }
  83. #[must_use]
  84. pub fn estimate_cost_usd(self) -> UsageCostEstimate {
  85. self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
  86. }
  87. #[must_use]
  88. pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
  89. UsageCostEstimate {
  90. input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
  91. output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
  92. cache_creation_cost_usd: cost_for_tokens(
  93. self.cache_creation_input_tokens,
  94. pricing.cache_creation_cost_per_million,
  95. ),
  96. cache_read_cost_usd: cost_for_tokens(
  97. self.cache_read_input_tokens,
  98. pricing.cache_read_cost_per_million,
  99. ),
  100. }
  101. }
  102. #[must_use]
  103. pub fn summary_lines(self, label: &str) -> Vec<String> {
  104. self.summary_lines_for_model(label, None)
  105. }
  106. #[must_use]
  107. pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
  108. let pricing = model.and_then(pricing_for_model);
  109. let cost = pricing.map_or_else(
  110. || self.estimate_cost_usd(),
  111. |pricing| self.estimate_cost_usd_with_pricing(pricing),
  112. );
  113. let model_suffix =
  114. model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
  115. let pricing_suffix = if pricing.is_some() {
  116. ""
  117. } else if model.is_some() {
  118. " pricing=estimated-default"
  119. } else {
  120. ""
  121. };
  122. vec![
  123. format!(
  124. "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
  125. self.total_tokens(),
  126. self.input_tokens,
  127. self.output_tokens,
  128. self.cache_creation_input_tokens,
  129. self.cache_read_input_tokens,
  130. format_usd(cost.total_cost_usd()),
  131. model_suffix,
  132. pricing_suffix,
  133. ),
  134. format!(
  135. " cost breakdown: input={} output={} cache_write={} cache_read={}",
  136. format_usd(cost.input_cost_usd),
  137. format_usd(cost.output_cost_usd),
  138. format_usd(cost.cache_creation_cost_usd),
  139. format_usd(cost.cache_read_cost_usd),
  140. ),
  141. ]
  142. }
  143. }
  /// Converts a token count into USD at a per-million-token rate.
  fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
      // Scale to millions first, then apply the rate (same operation order as
      // before; multiplication operands are merely swapped, which is exact).
      let millions_of_tokens = f64::from(tokens) / 1_000_000.0;
      usd_per_million_tokens * millions_of_tokens
  }
  /// Formats a dollar-denominated value for CLI display.
  ///
  /// The amount is prefixed with `$` and rendered with exactly four decimal
  /// places, e.g. `15.0` becomes `"$15.0000"`.
  #[must_use]
  pub fn format_usd(amount: f64) -> String {
      format!("${:.4}", amount)
  }
  152. /// Aggregates token usage across a running session.
  153. #[derive(Debug, Clone, Default, PartialEq, Eq)]
  154. pub struct UsageTracker {
  155. latest_turn: TokenUsage,
  156. cumulative: TokenUsage,
  157. turns: u32,
  158. }
  159. impl UsageTracker {
  160. #[must_use]
  161. pub fn new() -> Self {
  162. Self::default()
  163. }
  164. #[must_use]
  165. pub fn from_session(session: &Session) -> Self {
  166. let mut tracker = Self::new();
  167. for message in &session.messages {
  168. if let Some(usage) = message.usage {
  169. tracker.record(usage);
  170. }
  171. }
  172. tracker
  173. }
  174. pub fn record(&mut self, usage: TokenUsage) {
  175. self.latest_turn = usage;
  176. self.cumulative.input_tokens += usage.input_tokens;
  177. self.cumulative.output_tokens += usage.output_tokens;
  178. self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
  179. self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
  180. self.turns += 1;
  181. }
  182. #[must_use]
  183. pub fn current_turn_usage(&self) -> TokenUsage {
  184. self.latest_turn
  185. }
  186. #[must_use]
  187. pub fn cumulative_usage(&self) -> TokenUsage {
  188. self.cumulative
  189. }
  190. #[must_use]
  191. pub fn turns(&self) -> u32 {
  192. self.turns
  193. }
  194. }
  #[cfg(test)]
  mod tests {
      use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
      use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

      /// Builds a `TokenUsage` from its four counters in declaration order.
      fn usage_of(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
          TokenUsage {
              input_tokens: input,
              output_tokens: output,
              cache_creation_input_tokens: cache_write,
              cache_read_input_tokens: cache_read,
          }
      }

      #[test]
      fn tracks_true_cumulative_usage() {
          let mut tracker = UsageTracker::new();
          tracker.record(usage_of(10, 4, 2, 1));
          tracker.record(usage_of(20, 6, 3, 2));

          let latest = tracker.current_turn_usage();
          let totals = tracker.cumulative_usage();
          assert_eq!(tracker.turns(), 2);
          assert_eq!(latest.input_tokens, 20);
          assert_eq!(latest.output_tokens, 6);
          assert_eq!(totals.output_tokens, 10);
          assert_eq!(totals.input_tokens, 30);
          assert_eq!(totals.total_tokens(), 48);
      }

      #[test]
      fn computes_cost_summary_lines() {
          let sample = usage_of(1_000_000, 500_000, 100_000, 200_000);

          let cost = sample.estimate_cost_usd();
          assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
          assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

          let lines = sample.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
          let headline = &lines[0];
          assert!(headline.contains("estimated_cost=$54.6750"));
          assert!(headline.contains("model=claude-sonnet-4-20250514"));
          let breakdown = &lines[1];
          assert!(breakdown.contains("cache_read=$0.3000"));
      }

      #[test]
      fn supports_model_specific_pricing() {
          let sample = usage_of(1_000_000, 500_000, 0, 0);

          let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
          let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");

          let haiku_total = sample.estimate_cost_usd_with_pricing(haiku).total_cost_usd();
          let opus_total = sample.estimate_cost_usd_with_pricing(opus).total_cost_usd();
          assert_eq!(format_usd(haiku_total), "$3.5000");
          assert_eq!(format_usd(opus_total), "$52.5000");
      }

      #[test]
      fn marks_unknown_model_pricing_as_fallback() {
          let sample = usage_of(100, 100, 0, 0);
          let lines = sample.summary_lines_for_model("usage", Some("custom-model"));
          assert!(lines[0].contains("pricing=estimated-default"));
      }

      #[test]
      fn reconstructs_usage_from_session_messages() {
          let reply = ConversationMessage {
              role: MessageRole::Assistant,
              blocks: vec![ContentBlock::Text {
                  text: "done".to_string(),
              }],
              usage: Some(usage_of(5, 2, 1, 0)),
          };
          let mut session = Session::new();
          session.messages = vec![reply];

          let tracker = UsageTracker::from_session(&session);
          assert_eq!(tracker.turns(), 1);
          assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
      }
  }
  282. }