|
|
@@ -392,8 +392,52 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
|
|
|
let Some(token_set) = load_saved_oauth_token()? else {
|
|
|
return Ok(None);
|
|
|
};
|
|
|
+ resolve_saved_oauth_token_set(config, token_set).map(Some)
|
|
|
+}
|
|
|
+
|
|
|
+pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
|
|
+where
|
|
|
+ F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
|
|
+{
|
|
|
+ if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
|
|
+ return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
+ Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
|
|
|
+ api_key,
|
|
|
+ bearer_token,
|
|
|
+ }),
|
|
|
+ None => Ok(AuthSource::ApiKey(api_key)),
|
|
|
+ };
|
|
|
+ }
|
|
|
+ if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
+ return Ok(AuthSource::BearerToken(bearer_token));
|
|
|
+ }
|
|
|
+
|
|
|
+ let Some(token_set) = load_saved_oauth_token()? else {
|
|
|
+ return Err(ApiError::MissingApiKey);
|
|
|
+ };
|
|
|
if !oauth_token_is_expired(&token_set) {
|
|
|
- return Ok(Some(token_set));
|
|
|
+ return Ok(AuthSource::BearerToken(token_set.access_token));
|
|
|
+ }
|
|
|
+ if token_set.refresh_token.is_none() {
|
|
|
+ return Err(ApiError::ExpiredOAuthToken);
|
|
|
+ }
|
|
|
+
|
|
|
+ let Some(config) = load_oauth_config()? else {
|
|
|
+ return Err(ApiError::Auth(
|
|
|
+ "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
|
|
|
+ ));
|
|
|
+ };
|
|
|
+ Ok(AuthSource::from(resolve_saved_oauth_token_set(
|
|
|
+ &config, token_set,
|
|
|
+ )?))
|
|
|
+}
|
|
|
+
|
|
|
+fn resolve_saved_oauth_token_set(
|
|
|
+ config: &OAuthConfig,
|
|
|
+ token_set: OAuthTokenSet,
|
|
|
+) -> Result<OAuthTokenSet, ApiError> {
|
|
|
+ if !oauth_token_is_expired(&token_set) {
|
|
|
+ return Ok(token_set);
|
|
|
}
|
|
|
let Some(refresh_token) = token_set.refresh_token.clone() else {
|
|
|
return Err(ApiError::ExpiredOAuthToken);
|
|
|
@@ -403,18 +447,28 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
|
|
|
client
|
|
|
.refresh_oauth_token(
|
|
|
config,
|
|
|
- &OAuthRefreshRequest::from_config(config, refresh_token, Some(token_set.scopes)),
|
|
|
+ &OAuthRefreshRequest::from_config(
|
|
|
+ config,
|
|
|
+ refresh_token,
|
|
|
+ Some(token_set.scopes.clone()),
|
|
|
+ ),
|
|
|
)
|
|
|
.await
|
|
|
})?;
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
- access_token: refreshed.access_token.clone(),
|
|
|
- refresh_token: refreshed.refresh_token.clone(),
|
|
|
+ let resolved = OAuthTokenSet {
|
|
|
+ access_token: refreshed.access_token,
|
|
|
+ refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
|
|
|
expires_at: refreshed.expires_at,
|
|
|
- scopes: refreshed.scopes.clone(),
|
|
|
+ scopes: refreshed.scopes,
|
|
|
+ };
|
|
|
+ save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
+ access_token: resolved.access_token.clone(),
|
|
|
+ refresh_token: resolved.refresh_token.clone(),
|
|
|
+ expires_at: resolved.expires_at,
|
|
|
+ scopes: resolved.scopes.clone(),
|
|
|
})
|
|
|
.map_err(ApiError::from)?;
|
|
|
- Ok(Some(refreshed))
|
|
|
+ Ok(resolved)
|
|
|
}
|
|
|
|
|
|
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
|
|
|
@@ -571,8 +625,8 @@ mod tests {
|
|
|
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
|
|
|
|
|
|
use crate::client::{
|
|
|
- now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, AnthropicClient,
|
|
|
- AuthSource, OAuthTokenSet,
|
|
|
+ now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
|
|
+ resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
|
|
};
|
|
|
use crate::types::{ContentBlockDelta, MessageRequest};
|
|
|
|
|
|
@@ -760,6 +814,95 @@ mod tests {
|
|
|
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
}
|
|
|
|
|
|
+ #[test]
|
|
|
+ fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
|
|
+ let _guard = env_lock();
|
|
|
+ let config_home = temp_config_home();
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
+ std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
+ std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
+ save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
+ access_token: "saved-access-token".to_string(),
|
|
|
+ refresh_token: Some("refresh".to_string()),
|
|
|
+ expires_at: Some(now_unix_timestamp() + 300),
|
|
|
+ scopes: vec!["scope:a".to_string()],
|
|
|
+ })
|
|
|
+ .expect("save oauth credentials");
|
|
|
+
|
|
|
+ let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
|
|
|
+ .expect("startup auth");
|
|
|
+ assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
|
|
+
|
|
|
+ clear_oauth_credentials().expect("clear credentials");
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
+ std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
|
|
+ let _guard = env_lock();
|
|
|
+ let config_home = temp_config_home();
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
+ std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
+ std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
+ save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
+ access_token: "expired-access-token".to_string(),
|
|
|
+ refresh_token: Some("refresh-token".to_string()),
|
|
|
+ expires_at: Some(1),
|
|
|
+ scopes: vec!["scope:a".to_string()],
|
|
|
+ })
|
|
|
+ .expect("save expired oauth credentials");
|
|
|
+
|
|
|
+ let error =
|
|
|
+ resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
|
|
|
+ assert!(
|
|
|
+ matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
|
|
|
+ );
|
|
|
+
|
|
|
+ let stored = runtime::load_oauth_credentials()
|
|
|
+ .expect("load stored credentials")
|
|
|
+ .expect("stored token set");
|
|
|
+ assert_eq!(stored.access_token, "expired-access-token");
|
|
|
+ assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
+
|
|
|
+ clear_oauth_credentials().expect("clear credentials");
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
+ std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
|
|
+ let _guard = env_lock();
|
|
|
+ let config_home = temp_config_home();
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
+ std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
+ std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
+ save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
+ access_token: "expired-access-token".to_string(),
|
|
|
+ refresh_token: Some("refresh-token".to_string()),
|
|
|
+ expires_at: Some(1),
|
|
|
+ scopes: vec!["scope:a".to_string()],
|
|
|
+ })
|
|
|
+ .expect("save expired oauth credentials");
|
|
|
+
|
|
|
+ let token_url = spawn_token_server(
|
|
|
+ "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
|
|
+ );
|
|
|
+ let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
|
|
+ .expect("resolve refreshed token")
|
|
|
+ .expect("token set present");
|
|
|
+ assert_eq!(resolved.access_token, "refreshed-token");
|
|
|
+ assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
+ let stored = runtime::load_oauth_credentials()
|
|
|
+ .expect("load stored credentials")
|
|
|
+ .expect("stored token set");
|
|
|
+ assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
+
|
|
|
+ clear_oauth_credentials().expect("clear credentials");
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
+ std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
+ }
|
|
|
+
|
|
|
#[test]
|
|
|
fn message_request_stream_helper_sets_stream_true() {
|
|
|
let request = MessageRequest {
|