| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309 |
- use crate::session::Session;
/// Fallback USD price per million input tokens, used by
/// [`ModelPricing::default_sonnet_tier`].
///
/// NOTE(review): the four `DEFAULT_*` prices are identical to the opus tier in
/// `pricing_for_model` (15 / 75 / 18.75 / 1.5) even though the constructor is
/// named for sonnet. The unit tests in this file pin these numbers, so confirm
/// the intended tier before changing them.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
/// Fallback USD price per million output tokens.
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
/// Fallback USD price per million cache-creation (cache write) input tokens.
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
/// Fallback USD price per million cache-read input tokens.
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// USD prices, per million tokens, for each billable token category.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ModelPricing {
    /// USD per million input tokens billed at the base input rate.
    pub input_cost_per_million: f64,
    /// USD per million output tokens.
    pub output_cost_per_million: f64,
    /// USD per million cache-creation (cache write) input tokens.
    pub cache_creation_cost_per_million: f64,
    /// USD per million cache-read input tokens.
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Returns the fallback price table built from the `DEFAULT_*` constants.
    ///
    /// This is also what `pricing_for_model` returns for model names
    /// containing "sonnet".
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        ModelPricing {
            input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
            output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
        }
    }
}
/// Token counts for one turn, or the field-wise sum over several turns.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct TokenUsage {
    /// Input tokens billed at the base input rate.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Input tokens billed as cache creation (reported as `cache_write`).
    pub cache_creation_input_tokens: u32,
    /// Input tokens billed as cache reads (reported as `cache_read`).
    pub cache_read_input_tokens: u32,
}
/// Estimated USD cost of a token usage record, split by token category.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct UsageCostEstimate {
    /// Cost of base-rate input tokens, in USD.
    pub input_cost_usd: f64,
    /// Cost of output tokens, in USD.
    pub output_cost_usd: f64,
    /// Cost of cache-creation input tokens, in USD.
    pub cache_creation_cost_usd: f64,
    /// Cost of cache-read input tokens, in USD.
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Returns the sum of the four per-category costs, in USD.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        let Self {
            input_cost_usd,
            output_cost_usd,
            cache_creation_cost_usd,
            cache_read_cost_usd,
        } = self;
        // Summed strictly left to right (input, output, cache write, cache read)
        // so the floating-point rounding is deterministic.
        input_cost_usd + output_cost_usd + cache_creation_cost_usd + cache_read_cost_usd
    }
}
- #[must_use]
- pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
- let normalized = model.to_ascii_lowercase();
- if normalized.contains("haiku") {
- return Some(ModelPricing {
- input_cost_per_million: 1.0,
- output_cost_per_million: 5.0,
- cache_creation_cost_per_million: 1.25,
- cache_read_cost_per_million: 0.1,
- });
- }
- if normalized.contains("opus") {
- return Some(ModelPricing {
- input_cost_per_million: 15.0,
- output_cost_per_million: 75.0,
- cache_creation_cost_per_million: 18.75,
- cache_read_cost_per_million: 1.5,
- });
- }
- if normalized.contains("sonnet") {
- return Some(ModelPricing::default_sonnet_tier());
- }
- None
- }
- impl TokenUsage {
- #[must_use]
- pub fn total_tokens(self) -> u32 {
- self.input_tokens
- + self.output_tokens
- + self.cache_creation_input_tokens
- + self.cache_read_input_tokens
- }
- #[must_use]
- pub fn estimate_cost_usd(self) -> UsageCostEstimate {
- self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
- }
- #[must_use]
- pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
- UsageCostEstimate {
- input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
- output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
- cache_creation_cost_usd: cost_for_tokens(
- self.cache_creation_input_tokens,
- pricing.cache_creation_cost_per_million,
- ),
- cache_read_cost_usd: cost_for_tokens(
- self.cache_read_input_tokens,
- pricing.cache_read_cost_per_million,
- ),
- }
- }
- #[must_use]
- pub fn summary_lines(self, label: &str) -> Vec<String> {
- self.summary_lines_for_model(label, None)
- }
- #[must_use]
- pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
- let pricing = model.and_then(pricing_for_model);
- let cost = pricing.map_or_else(
- || self.estimate_cost_usd(),
- |pricing| self.estimate_cost_usd_with_pricing(pricing),
- );
- let model_suffix =
- model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
- let pricing_suffix = if pricing.is_some() {
- ""
- } else if model.is_some() {
- " pricing=estimated-default"
- } else {
- ""
- };
- vec![
- format!(
- "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
- self.total_tokens(),
- self.input_tokens,
- self.output_tokens,
- self.cache_creation_input_tokens,
- self.cache_read_input_tokens,
- format_usd(cost.total_cost_usd()),
- model_suffix,
- pricing_suffix,
- ),
- format!(
- " cost breakdown: input={} output={} cache_write={} cache_read={}",
- format_usd(cost.input_cost_usd),
- format_usd(cost.output_cost_usd),
- format_usd(cost.cache_creation_cost_usd),
- format_usd(cost.cache_read_cost_usd),
- ),
- ]
- }
- }
/// Converts a token count into USD given a price per million tokens.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    // Scale to millions first, then apply the price; this operation order
    // determines the exact floating-point result.
    let millions = f64::from(tokens) / 1_000_000.0;
    millions * usd_per_million_tokens
}
/// Renders a USD amount as `$` followed by the value with exactly four decimals.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    format!("${:.4}", amount)
}
- #[derive(Debug, Clone, Default, PartialEq, Eq)]
- pub struct UsageTracker {
- latest_turn: TokenUsage,
- cumulative: TokenUsage,
- turns: u32,
- }
- impl UsageTracker {
- #[must_use]
- pub fn new() -> Self {
- Self::default()
- }
- #[must_use]
- pub fn from_session(session: &Session) -> Self {
- let mut tracker = Self::new();
- for message in &session.messages {
- if let Some(usage) = message.usage {
- tracker.record(usage);
- }
- }
- tracker
- }
- pub fn record(&mut self, usage: TokenUsage) {
- self.latest_turn = usage;
- self.cumulative.input_tokens += usage.input_tokens;
- self.cumulative.output_tokens += usage.output_tokens;
- self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
- self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
- self.turns += 1;
- }
- #[must_use]
- pub fn current_turn_usage(&self) -> TokenUsage {
- self.latest_turn
- }
- #[must_use]
- pub fn cumulative_usage(&self) -> TokenUsage {
- self.cumulative
- }
- #[must_use]
- pub fn turns(&self) -> u32 {
- self.turns
- }
- }
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Builds a `TokenUsage` from (input, output, cache write, cache read) counts.
    fn tokens(input: u32, output: u32, cache_write: u32, cache_read: u32) -> TokenUsage {
        TokenUsage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
        }
    }

    #[test]
    fn tracks_true_cumulative_usage() {
        let mut tracker = UsageTracker::new();
        tracker.record(tokens(10, 4, 2, 1));
        tracker.record(tokens(20, 6, 3, 2));

        let latest = tracker.current_turn_usage();
        let total = tracker.cumulative_usage();
        assert_eq!(tracker.turns(), 2);
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);
        assert_eq!(total.output_tokens, 10);
        assert_eq!(total.input_tokens, 30);
        assert_eq!(total.total_tokens(), 48);
    }

    #[test]
    fn computes_cost_summary_lines() {
        let usage = tokens(1_000_000, 500_000, 100_000, 200_000);

        let cost = usage.estimate_cost_usd();
        assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");

        let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
        let (headline, breakdown) = (&lines[0], &lines[1]);
        assert!(headline.contains("estimated_cost=$54.6750"));
        assert!(headline.contains("model=claude-sonnet-4-20250514"));
        assert!(breakdown.contains("cache_read=$0.3000"));
    }

    #[test]
    fn supports_model_specific_pricing() {
        let usage = tokens(1_000_000, 500_000, 0, 0);

        let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
        let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");

        let haiku_total = usage.estimate_cost_usd_with_pricing(haiku).total_cost_usd();
        let opus_total = usage.estimate_cost_usd_with_pricing(opus).total_cost_usd();
        assert_eq!(format_usd(haiku_total), "$3.5000");
        assert_eq!(format_usd(opus_total), "$52.5000");
    }

    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let usage = tokens(100, 100, 0, 0);

        let lines = usage.summary_lines_for_model("usage", Some("custom-model"));
        assert!(lines[0].contains("pricing=estimated-default"));
    }

    #[test]
    fn reconstructs_usage_from_session_messages() {
        let reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: "done".to_string(),
            }],
            usage: Some(tokens(5, 2, 1, 0)),
        };
        let session = Session {
            version: 1,
            messages: vec![reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}
|