Bläddra i källkod

Merge remote-tracking branch 'origin/rcc/api' into dev/rust

Yeachan-Heo 2 månader sedan
förälder
incheckning
863958b94c
3 ändrade filer med 162 tillägg och 26 borttagningar
  1. 152 9
      rust/crates/api/src/client.rs
  2. 2 2
      rust/crates/api/src/lib.rs
  3. 8 15
      rust/crates/rusty-claude-cli/src/main.rs

+ 152 - 9
rust/crates/api/src/client.rs

@@ -392,8 +392,52 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
     let Some(token_set) = load_saved_oauth_token()? else {
         return Ok(None);
     };
+    resolve_saved_oauth_token_set(config, token_set).map(Some)
+}
+
+pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
+where
+    F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
+{
+    if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
+        return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
+            Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
+                api_key,
+                bearer_token,
+            }),
+            None => Ok(AuthSource::ApiKey(api_key)),
+        };
+    }
+    if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
+        return Ok(AuthSource::BearerToken(bearer_token));
+    }
+
+    let Some(token_set) = load_saved_oauth_token()? else {
+        return Err(ApiError::MissingApiKey);
+    };
     if !oauth_token_is_expired(&token_set) {
-        return Ok(Some(token_set));
+        return Ok(AuthSource::BearerToken(token_set.access_token));
+    }
+    if token_set.refresh_token.is_none() {
+        return Err(ApiError::ExpiredOAuthToken);
+    }
+
+    let Some(config) = load_oauth_config()? else {
+        return Err(ApiError::Auth(
+            "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
+        ));
+    };
+    Ok(AuthSource::from(resolve_saved_oauth_token_set(
+        &config, token_set,
+    )?))
+}
+
+fn resolve_saved_oauth_token_set(
+    config: &OAuthConfig,
+    token_set: OAuthTokenSet,
+) -> Result<OAuthTokenSet, ApiError> {
+    if !oauth_token_is_expired(&token_set) {
+        return Ok(token_set);
     }
     let Some(refresh_token) = token_set.refresh_token.clone() else {
         return Err(ApiError::ExpiredOAuthToken);
@@ -403,18 +447,28 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
         client
             .refresh_oauth_token(
                 config,
-                &OAuthRefreshRequest::from_config(config, refresh_token, Some(token_set.scopes)),
+                &OAuthRefreshRequest::from_config(
+                    config,
+                    refresh_token,
+                    Some(token_set.scopes.clone()),
+                ),
             )
             .await
     })?;
-    save_oauth_credentials(&runtime::OAuthTokenSet {
-        access_token: refreshed.access_token.clone(),
-        refresh_token: refreshed.refresh_token.clone(),
+    let resolved = OAuthTokenSet {
+        access_token: refreshed.access_token,
+        refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
         expires_at: refreshed.expires_at,
-        scopes: refreshed.scopes.clone(),
+        scopes: refreshed.scopes,
+    };
+    save_oauth_credentials(&runtime::OAuthTokenSet {
+        access_token: resolved.access_token.clone(),
+        refresh_token: resolved.refresh_token.clone(),
+        expires_at: resolved.expires_at,
+        scopes: resolved.scopes.clone(),
     })
     .map_err(ApiError::from)?;
-    Ok(Some(refreshed))
+    Ok(resolved)
 }
 
 fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
@@ -571,8 +625,8 @@ mod tests {
     use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
 
     use crate::client::{
-        now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, AnthropicClient,
-        AuthSource, OAuthTokenSet,
+        now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
+        resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
     };
     use crate::types::{ContentBlockDelta, MessageRequest};
 
@@ -760,6 +814,95 @@ mod tests {
         std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
     }
 
+    #[test]
+    fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
+        let _guard = env_lock();
+        let config_home = temp_config_home();
+        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
+        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
+        std::env::remove_var("ANTHROPIC_API_KEY");
+        save_oauth_credentials(&runtime::OAuthTokenSet {
+            access_token: "saved-access-token".to_string(),
+            refresh_token: Some("refresh".to_string()),
+            expires_at: Some(now_unix_timestamp() + 300),
+            scopes: vec!["scope:a".to_string()],
+        })
+        .expect("save oauth credentials");
+
+        let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
+            .expect("startup auth");
+        assert_eq!(auth.bearer_token(), Some("saved-access-token"));
+
+        clear_oauth_credentials().expect("clear credentials");
+        std::env::remove_var("CLAUDE_CONFIG_HOME");
+        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
+        let _guard = env_lock();
+        let config_home = temp_config_home();
+        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
+        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
+        std::env::remove_var("ANTHROPIC_API_KEY");
+        save_oauth_credentials(&runtime::OAuthTokenSet {
+            access_token: "expired-access-token".to_string(),
+            refresh_token: Some("refresh-token".to_string()),
+            expires_at: Some(1),
+            scopes: vec!["scope:a".to_string()],
+        })
+        .expect("save expired oauth credentials");
+
+        let error =
+            resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
+        assert!(
+            matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
+        );
+
+        let stored = runtime::load_oauth_credentials()
+            .expect("load stored credentials")
+            .expect("stored token set");
+        assert_eq!(stored.access_token, "expired-access-token");
+        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
+
+        clear_oauth_credentials().expect("clear credentials");
+        std::env::remove_var("CLAUDE_CONFIG_HOME");
+        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
+    }
+
+    #[test]
+    fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
+        let _guard = env_lock();
+        let config_home = temp_config_home();
+        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
+        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
+        std::env::remove_var("ANTHROPIC_API_KEY");
+        save_oauth_credentials(&runtime::OAuthTokenSet {
+            access_token: "expired-access-token".to_string(),
+            refresh_token: Some("refresh-token".to_string()),
+            expires_at: Some(1),
+            scopes: vec!["scope:a".to_string()],
+        })
+        .expect("save expired oauth credentials");
+
+        let token_url = spawn_token_server(
+            "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
+        );
+        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
+            .expect("resolve refreshed token")
+            .expect("token set present");
+        assert_eq!(resolved.access_token, "refreshed-token");
+        assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
+        let stored = runtime::load_oauth_credentials()
+            .expect("load stored credentials")
+            .expect("stored token set");
+        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
+
+        clear_oauth_credentials().expect("clear credentials");
+        std::env::remove_var("CLAUDE_CONFIG_HOME");
+        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
+    }
+
     #[test]
     fn message_request_stream_helper_sets_stream_true() {
         let request = MessageRequest {

+ 2 - 2
rust/crates/api/src/lib.rs

@@ -4,8 +4,8 @@ mod sse;
 mod types;
 
 pub use client::{
-    oauth_token_is_expired, resolve_saved_oauth_token, AnthropicClient, AuthSource, MessageStream,
-    OAuthTokenSet,
+    oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source,
+    AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
 };
 pub use error::ApiError;
 pub use sse::{parse_frame, SseParser};

+ 8 - 15
rust/crates/rusty-claude-cli/src/main.rs

@@ -11,7 +11,7 @@ use std::process::Command;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use api::{
-    resolve_saved_oauth_token, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
+    resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
     InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
     StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
 };
@@ -2006,20 +2006,13 @@ impl AnthropicRuntimeClient {
 }
 
 fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> {
-    match AuthSource::from_env() {
-        Ok(auth) => Ok(auth),
-        Err(api::ApiError::MissingApiKey) => {
-            let cwd = env::current_dir()?;
-            let config = ConfigLoader::default_for(&cwd).load()?;
-            if let Some(oauth) = config.oauth() {
-                if let Some(token_set) = resolve_saved_oauth_token(oauth)? {
-                    return Ok(AuthSource::from(token_set));
-                }
-            }
-            Ok(AuthSource::from_env_or_saved()?)
-        }
-        Err(error) => Err(Box::new(error)),
-    }
+    Ok(resolve_startup_auth_source(|| {
+        let cwd = env::current_dir().map_err(api::ApiError::from)?;
+        let config = ConfigLoader::default_for(&cwd).load().map_err(|error| {
+            api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}"))
+        })?;
+        Ok(config.oauth().cloned())
+    })?)
 }
 
 impl ApiClient for AnthropicRuntimeClient {