oauth.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. use std::collections::BTreeMap;
  2. use std::fs::{self, File};
  3. use std::io::{self, Read};
  4. use std::path::PathBuf;
  5. use serde::{Deserialize, Serialize};
  6. use serde_json::{Map, Value};
  7. use sha2::{Digest, Sha256};
  8. use crate::config::OAuthConfig;
  9. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  10. pub struct OAuthTokenSet {
  11. pub access_token: String,
  12. pub refresh_token: Option<String>,
  13. pub expires_at: Option<u64>,
  14. pub scopes: Vec<String>,
  15. }
  /// A PKCE verifier together with the challenge derived from it.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct PkceCodePair {
      /// Secret kept by the client and later sent as `code_verifier`.
      pub verifier: String,
      /// Value sent as `code_challenge` in the authorization request.
      pub challenge: String,
      /// Transformation used to derive `challenge` from `verifier`.
      pub challenge_method: PkceChallengeMethod,
  }

  /// Supported PKCE code-challenge transformations (only `S256`).
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub enum PkceChallengeMethod {
      /// Challenge is the unpadded base64url encoding of SHA-256 of the verifier.
      S256,
  }

  impl PkceChallengeMethod {
      /// Returns the wire name sent as `code_challenge_method`.
      #[must_use]
      pub const fn as_str(self) -> &'static str {
          match self {
              PkceChallengeMethod::S256 => "S256",
          }
      }
  }

  /// Everything needed to render the browser-facing authorization URL.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct OAuthAuthorizationRequest {
      /// Base authorization endpoint; query parameters are appended to it.
      pub authorize_url: String,
      /// OAuth client identifier.
      pub client_id: String,
      /// Where the authorization server should redirect after consent.
      pub redirect_uri: String,
      /// Requested scopes, rendered space-separated as `scope`.
      pub scopes: Vec<String>,
      /// Opaque value echoed back on the callback.
      pub state: String,
      /// PKCE challenge sent as `code_challenge`.
      pub code_challenge: String,
      /// PKCE transformation sent as `code_challenge_method`.
      pub code_challenge_method: PkceChallengeMethod,
      /// Additional query parameters, appended after the standard ones in key order.
      pub extra_params: BTreeMap<String, String>,
  }
  /// Parameters of an `authorization_code` grant sent to the token endpoint.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct OAuthTokenExchangeRequest {
      /// Grant type literal.
      pub grant_type: &'static str,
      /// Authorization code received on the callback.
      pub code: String,
      /// Redirect URI sent alongside the code.
      pub redirect_uri: String,
      /// OAuth client identifier.
      pub client_id: String,
      /// PKCE verifier sent as `code_verifier`.
      pub code_verifier: String,
      /// State value, also sent to the token endpoint.
      pub state: String,
  }

  /// Parameters of a `refresh_token` grant sent to the token endpoint.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct OAuthRefreshRequest {
      /// Grant type literal.
      pub grant_type: &'static str,
      /// Refresh token being redeemed.
      pub refresh_token: String,
      /// OAuth client identifier.
      pub client_id: String,
      /// Scopes to request, rendered space-separated as `scope`.
      pub scopes: Vec<String>,
  }

  /// Query parameters of interest parsed from the OAuth redirect callback.
  ///
  /// Each field is `None` when the corresponding parameter was absent.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct OAuthCallbackParams {
      /// `code` parameter.
      pub code: Option<String>,
      /// `state` parameter.
      pub state: Option<String>,
      /// `error` parameter.
      pub error: Option<String>,
      /// `error_description` parameter.
      pub error_description: Option<String>,
  }
  68. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  69. #[serde(rename_all = "camelCase")]
  70. struct StoredOAuthCredentials {
  71. access_token: String,
  72. #[serde(default)]
  73. refresh_token: Option<String>,
  74. #[serde(default)]
  75. expires_at: Option<u64>,
  76. #[serde(default)]
  77. scopes: Vec<String>,
  78. }
  79. impl From<OAuthTokenSet> for StoredOAuthCredentials {
  80. fn from(value: OAuthTokenSet) -> Self {
  81. Self {
  82. access_token: value.access_token,
  83. refresh_token: value.refresh_token,
  84. expires_at: value.expires_at,
  85. scopes: value.scopes,
  86. }
  87. }
  88. }
  89. impl From<StoredOAuthCredentials> for OAuthTokenSet {
  90. fn from(value: StoredOAuthCredentials) -> Self {
  91. Self {
  92. access_token: value.access_token,
  93. refresh_token: value.refresh_token,
  94. expires_at: value.expires_at,
  95. scopes: value.scopes,
  96. }
  97. }
  98. }
  99. impl OAuthAuthorizationRequest {
  100. #[must_use]
  101. pub fn from_config(
  102. config: &OAuthConfig,
  103. redirect_uri: impl Into<String>,
  104. state: impl Into<String>,
  105. pkce: &PkceCodePair,
  106. ) -> Self {
  107. Self {
  108. authorize_url: config.authorize_url.clone(),
  109. client_id: config.client_id.clone(),
  110. redirect_uri: redirect_uri.into(),
  111. scopes: config.scopes.clone(),
  112. state: state.into(),
  113. code_challenge: pkce.challenge.clone(),
  114. code_challenge_method: pkce.challenge_method,
  115. extra_params: BTreeMap::new(),
  116. }
  117. }
  118. #[must_use]
  119. pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
  120. self.extra_params.insert(key.into(), value.into());
  121. self
  122. }
  123. #[must_use]
  124. pub fn build_url(&self) -> String {
  125. let mut params = vec![
  126. ("response_type", "code".to_string()),
  127. ("client_id", self.client_id.clone()),
  128. ("redirect_uri", self.redirect_uri.clone()),
  129. ("scope", self.scopes.join(" ")),
  130. ("state", self.state.clone()),
  131. ("code_challenge", self.code_challenge.clone()),
  132. (
  133. "code_challenge_method",
  134. self.code_challenge_method.as_str().to_string(),
  135. ),
  136. ];
  137. params.extend(
  138. self.extra_params
  139. .iter()
  140. .map(|(key, value)| (key.as_str(), value.clone())),
  141. );
  142. let query = params
  143. .into_iter()
  144. .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
  145. .collect::<Vec<_>>()
  146. .join("&");
  147. format!(
  148. "{}{}{}",
  149. self.authorize_url,
  150. if self.authorize_url.contains('?') {
  151. '&'
  152. } else {
  153. '?'
  154. },
  155. query
  156. )
  157. }
  158. }
  159. impl OAuthTokenExchangeRequest {
  160. #[must_use]
  161. pub fn from_config(
  162. config: &OAuthConfig,
  163. code: impl Into<String>,
  164. state: impl Into<String>,
  165. verifier: impl Into<String>,
  166. redirect_uri: impl Into<String>,
  167. ) -> Self {
  168. Self {
  169. grant_type: "authorization_code",
  170. code: code.into(),
  171. redirect_uri: redirect_uri.into(),
  172. client_id: config.client_id.clone(),
  173. code_verifier: verifier.into(),
  174. state: state.into(),
  175. }
  176. }
  177. #[must_use]
  178. pub fn form_params(&self) -> BTreeMap<&str, String> {
  179. BTreeMap::from([
  180. ("grant_type", self.grant_type.to_string()),
  181. ("code", self.code.clone()),
  182. ("redirect_uri", self.redirect_uri.clone()),
  183. ("client_id", self.client_id.clone()),
  184. ("code_verifier", self.code_verifier.clone()),
  185. ("state", self.state.clone()),
  186. ])
  187. }
  188. }
  189. impl OAuthRefreshRequest {
  190. #[must_use]
  191. pub fn from_config(
  192. config: &OAuthConfig,
  193. refresh_token: impl Into<String>,
  194. scopes: Option<Vec<String>>,
  195. ) -> Self {
  196. Self {
  197. grant_type: "refresh_token",
  198. refresh_token: refresh_token.into(),
  199. client_id: config.client_id.clone(),
  200. scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
  201. }
  202. }
  203. #[must_use]
  204. pub fn form_params(&self) -> BTreeMap<&str, String> {
  205. BTreeMap::from([
  206. ("grant_type", self.grant_type.to_string()),
  207. ("refresh_token", self.refresh_token.clone()),
  208. ("client_id", self.client_id.clone()),
  209. ("scope", self.scopes.join(" ")),
  210. ])
  211. }
  212. }
  213. pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
  214. let verifier = generate_random_token(32)?;
  215. Ok(PkceCodePair {
  216. challenge: code_challenge_s256(&verifier),
  217. verifier,
  218. challenge_method: PkceChallengeMethod::S256,
  219. })
  220. }
  221. pub fn generate_state() -> io::Result<String> {
  222. generate_random_token(32)
  223. }
  224. #[must_use]
  225. pub fn code_challenge_s256(verifier: &str) -> String {
  226. let digest = Sha256::digest(verifier.as_bytes());
  227. base64url_encode(&digest)
  228. }
  /// Returns the local redirect URI `http://localhost:<port>/callback`.
  ///
  /// The `/callback` path is the one `parse_oauth_callback_request_target` accepts.
  #[must_use]
  pub fn loopback_redirect_uri(port: u16) -> String {
      let port = port.to_string();
      ["http://localhost:", port.as_str(), "/callback"].concat()
  }
  233. pub fn credentials_path() -> io::Result<PathBuf> {
  234. Ok(credentials_home_dir()?.join("credentials.json"))
  235. }
  236. pub fn load_oauth_credentials() -> io::Result<Option<OAuthTokenSet>> {
  237. let path = credentials_path()?;
  238. let root = read_credentials_root(&path)?;
  239. let Some(oauth) = root.get("oauth") else {
  240. return Ok(None);
  241. };
  242. if oauth.is_null() {
  243. return Ok(None);
  244. }
  245. let stored = serde_json::from_value::<StoredOAuthCredentials>(oauth.clone())
  246. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  247. Ok(Some(stored.into()))
  248. }
  249. pub fn save_oauth_credentials(token_set: &OAuthTokenSet) -> io::Result<()> {
  250. let path = credentials_path()?;
  251. let mut root = read_credentials_root(&path)?;
  252. root.insert(
  253. "oauth".to_string(),
  254. serde_json::to_value(StoredOAuthCredentials::from(token_set.clone()))
  255. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
  256. );
  257. write_credentials_root(&path, &root)
  258. }
  259. pub fn clear_oauth_credentials() -> io::Result<()> {
  260. let path = credentials_path()?;
  261. let mut root = read_credentials_root(&path)?;
  262. root.remove("oauth");
  263. write_credentials_root(&path, &root)
  264. }
  265. pub fn parse_oauth_callback_request_target(target: &str) -> Result<OAuthCallbackParams, String> {
  266. let (path, query) = target
  267. .split_once('?')
  268. .map_or((target, ""), |(path, query)| (path, query));
  269. if path != "/callback" {
  270. return Err(format!("unexpected callback path: {path}"));
  271. }
  272. parse_oauth_callback_query(query)
  273. }
  274. pub fn parse_oauth_callback_query(query: &str) -> Result<OAuthCallbackParams, String> {
  275. let mut params = BTreeMap::new();
  276. for pair in query.split('&').filter(|pair| !pair.is_empty()) {
  277. let (key, value) = pair
  278. .split_once('=')
  279. .map_or((pair, ""), |(key, value)| (key, value));
  280. params.insert(percent_decode(key)?, percent_decode(value)?);
  281. }
  282. Ok(OAuthCallbackParams {
  283. code: params.get("code").cloned(),
  284. state: params.get("state").cloned(),
  285. error: params.get("error").cloned(),
  286. error_description: params.get("error_description").cloned(),
  287. })
  288. }
  289. fn generate_random_token(bytes: usize) -> io::Result<String> {
  290. let mut buffer = vec![0_u8; bytes];
  291. File::open("/dev/urandom")?.read_exact(&mut buffer)?;
  292. Ok(base64url_encode(&buffer))
  293. }
  /// Resolves the configuration directory.
  ///
  /// Uses `$CLAW_CONFIG_HOME` verbatim when set, and `$HOME/.claw` otherwise.
  ///
  /// # Errors
  /// Fails with `NotFound` when neither variable is set.
  fn credentials_home_dir() -> io::Result<PathBuf> {
      match std::env::var_os("CLAW_CONFIG_HOME") {
          Some(explicit) => Ok(PathBuf::from(explicit)),
          None => std::env::var_os("HOME")
              .map(|home| PathBuf::from(home).join(".claw"))
              .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "HOME is not set")),
      }
  }
  302. fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
  303. match fs::read_to_string(path) {
  304. Ok(contents) => {
  305. if contents.trim().is_empty() {
  306. return Ok(Map::new());
  307. }
  308. serde_json::from_str::<Value>(&contents)
  309. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?
  310. .as_object()
  311. .cloned()
  312. .ok_or_else(|| {
  313. io::Error::new(
  314. io::ErrorKind::InvalidData,
  315. "credentials file must contain a JSON object",
  316. )
  317. })
  318. }
  319. Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Map::new()),
  320. Err(error) => Err(error),
  321. }
  322. }
  323. fn write_credentials_root(path: &PathBuf, root: &Map<String, Value>) -> io::Result<()> {
  324. if let Some(parent) = path.parent() {
  325. fs::create_dir_all(parent)?;
  326. }
  327. let rendered = serde_json::to_string_pretty(&Value::Object(root.clone()))
  328. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  329. let temp_path = path.with_extension("json.tmp");
  330. fs::write(&temp_path, format!("{rendered}\n"))?;
  331. fs::rename(temp_path, path)
  332. }
  /// Encodes `bytes` as base64url (RFC 4648 §5) without `=` padding.
  ///
  /// `-` and `_` replace `+` and `/`. A trailing 1-byte group yields 2 characters and a
  /// trailing 2-byte group yields 3.
  fn base64url_encode(bytes: &[u8]) -> String {
      const ALPHABET: &[u8; 64] =
          b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
      // Character for the 6-bit group of `word` that starts at bit `shift`.
      let sextet =
          |word: u32, shift: u32| char::from(ALPHABET[((word >> shift) & 0x3F) as usize]);

      let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);
      let mut groups = bytes.chunks_exact(3);
      for group in &mut groups {
          let word =
              (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
          for shift in [18, 12, 6, 0] {
              encoded.push(sextet(word, shift));
          }
      }
      match *groups.remainder() {
          [first] => {
              let word = u32::from(first) << 16;
              for shift in [18, 12] {
                  encoded.push(sextet(word, shift));
              }
          }
          [first, second] => {
              let word = (u32::from(first) << 16) | (u32::from(second) << 8);
              for shift in [18, 12, 6] {
                  encoded.push(sextet(word, shift));
              }
          }
          _ => {}
      }
      encoded
  }
  /// Percent-encodes every byte of `value` except RFC 3986 unreserved characters.
  ///
  /// The unreserved set is ASCII letters, digits, `-`, `_`, `.` and `~`. Other bytes become
  /// `%XX` with uppercase hex, so a space becomes `%20`, not `+`.
  fn percent_encode(value: &str) -> String {
      use std::fmt::Write as _;
      let mut out = String::with_capacity(value.len());
      for byte in value.bytes() {
          let unreserved =
              byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
          if unreserved {
              out.push(char::from(byte));
          } else {
              // Writing to a String cannot fail.
              let _ = write!(out, "%{byte:02X}");
          }
      }
      out
  }
  /// Decodes a percent-encoded query component.
  ///
  /// Decoding rules:
  /// - `+` becomes a space (form-encoding convention).
  /// - `%XX` becomes the byte `0xXX`.
  /// - A `%` followed by fewer than two more bytes is kept literally.
  /// - A `%` followed by two bytes that are not both hex digits is an error.
  ///
  /// # Errors
  /// Returns a message on a bad hex digit after `%`, or when the decoded bytes are not UTF-8.
  fn percent_decode(value: &str) -> Result<String, String> {
      let mut output = Vec::with_capacity(value.len());
      let mut remaining = value.as_bytes();
      while let Some((&head, tail)) = remaining.split_first() {
          remaining = match (head, tail) {
              (b'%', [high, low, rest @ ..]) => {
                  let high = decode_hex(*high)?;
                  let low = decode_hex(*low)?;
                  output.push((high << 4) | low);
                  rest
              }
              (b'+', _) => {
                  output.push(b' ');
                  tail
              }
              (other, _) => {
                  output.push(other);
                  tail
              }
          };
      }
      String::from_utf8(output).map_err(|error| error.to_string())
  }

  /// Returns the value of one ASCII hex digit (`0-9`, `a-f`, `A-F`).
  ///
  /// # Errors
  /// Returns a message naming the offending byte (as a decimal number) for any other byte.
  fn decode_hex(byte: u8) -> Result<u8, String> {
      char::from(byte)
          .to_digit(16)
          // `to_digit(16)` yields at most 15, so narrowing to u8 is lossless.
          .map(|digit| digit as u8)
          .ok_or_else(|| format!("invalid percent-encoding byte: {byte}"))
  }
  #[cfg(test)]
  mod tests {
      use std::time::{SystemTime, UNIX_EPOCH};

      use super::{
          base64url_encode, clear_oauth_credentials, code_challenge_s256, credentials_path,
          generate_pkce_pair, generate_state, load_oauth_credentials, loopback_redirect_uri,
          parse_oauth_callback_query, parse_oauth_callback_request_target, percent_decode,
          percent_encode, save_oauth_credentials, OAuthAuthorizationRequest, OAuthConfig,
          OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,
      };

      fn sample_config() -> OAuthConfig {
          OAuthConfig {
              client_id: "runtime-client".to_string(),
              authorize_url: "https://console.test/oauth/authorize".to_string(),
              token_url: "https://console.test/oauth/token".to_string(),
              callback_port: Some(4545),
              manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
              scopes: vec!["org:read".to_string(), "user:write".to_string()],
          }
      }

      /// Serializes tests that mutate process-wide environment variables.
      fn env_lock() -> std::sync::MutexGuard<'static, ()> {
          crate::test_env_lock()
      }

      /// Returns a unique, not-yet-created directory under the system temp dir.
      fn temp_config_home() -> std::path::PathBuf {
          std::env::temp_dir().join(format!(
              "runtime-oauth-test-{}-{}",
              std::process::id(),
              SystemTime::now()
                  .duration_since(UNIX_EPOCH)
                  .expect("time")
                  .as_nanos()
          ))
      }

      #[test]
      fn s256_challenge_matches_expected_vector() {
          // RFC 7636 Appendix B test vector.
          assert_eq!(
              code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
              "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
          );
      }

      #[test]
      fn generates_pkce_pair_and_state() {
          let pair = generate_pkce_pair().expect("pkce pair");
          let state = generate_state().expect("state");
          assert!(!pair.verifier.is_empty());
          assert!(!pair.challenge.is_empty());
          assert!(!state.is_empty());
      }

      #[test]
      fn builds_authorize_url_and_form_requests() {
          let config = sample_config();
          let pair = generate_pkce_pair().expect("pkce");
          let url = OAuthAuthorizationRequest::from_config(
              &config,
              loopback_redirect_uri(4545),
              "state-123",
              &pair,
          )
          .with_extra_param("login_hint", "user@example.com")
          .build_url();
          assert!(url.starts_with("https://console.test/oauth/authorize?"));
          assert!(url.contains("response_type=code"));
          assert!(url.contains("client_id=runtime-client"));
          assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
          assert!(url.contains("login_hint=user%40example.com"));
          let exchange = OAuthTokenExchangeRequest::from_config(
              &config,
              "auth-code",
              "state-123",
              pair.verifier,
              loopback_redirect_uri(4545),
          );
          assert_eq!(
              exchange.form_params().get("grant_type").map(String::as_str),
              Some("authorization_code")
          );
          let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
          assert_eq!(
              refresh.form_params().get("scope").map(String::as_str),
              Some("org:read user:write")
          );
      }

      #[test]
      fn authorize_url_appends_to_existing_query() {
          let mut config = sample_config();
          config.authorize_url = "https://console.test/oauth/authorize?tenant=a".to_string();
          let pair = generate_pkce_pair().expect("pkce");
          let url = OAuthAuthorizationRequest::from_config(
              &config,
              loopback_redirect_uri(4545),
              "state-xyz",
              &pair,
          )
          .build_url();
          assert!(url.starts_with(
              "https://console.test/oauth/authorize?tenant=a&response_type=code&client_id=runtime-client&"
          ));
      }

      #[test]
      fn oauth_credentials_round_trip_and_clear_preserves_other_fields() {
          let _guard = env_lock();
          let config_home = temp_config_home();
          std::env::set_var("CLAW_CONFIG_HOME", &config_home);
          let path = credentials_path().expect("credentials path");
          std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent");
          std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials");
          let token_set = OAuthTokenSet {
              access_token: "access-token".to_string(),
              refresh_token: Some("refresh-token".to_string()),
              expires_at: Some(123),
              scopes: vec!["scope:a".to_string()],
          };
          save_oauth_credentials(&token_set).expect("save credentials");
          assert_eq!(
              load_oauth_credentials().expect("load credentials"),
              Some(token_set)
          );
          let saved = std::fs::read_to_string(&path).expect("read saved file");
          assert!(saved.contains("\"other\": \"value\""));
          assert!(saved.contains("\"oauth\""));
          clear_oauth_credentials().expect("clear credentials");
          assert_eq!(load_oauth_credentials().expect("load cleared"), None);
          let cleared = std::fs::read_to_string(&path).expect("read cleared file");
          assert!(cleared.contains("\"other\": \"value\""));
          assert!(!cleared.contains("\"oauth\""));
          std::env::remove_var("CLAW_CONFIG_HOME");
          std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
      }

      #[test]
      fn load_returns_none_when_credentials_file_is_missing() {
          let _guard = env_lock();
          let config_home = temp_config_home();
          std::env::set_var("CLAW_CONFIG_HOME", &config_home);
          let loaded = load_oauth_credentials().expect("load missing credentials");
          std::env::remove_var("CLAW_CONFIG_HOME");
          assert_eq!(loaded, None);
          // Loading must not create anything on disk.
          assert!(!config_home.exists());
      }

      #[test]
      fn parses_callback_query_and_target() {
          let params =
              parse_oauth_callback_query("code=abc123&state=state-1&error_description=needs%20login")
                  .expect("parse query");
          assert_eq!(params.code.as_deref(), Some("abc123"));
          assert_eq!(params.state.as_deref(), Some("state-1"));
          assert_eq!(params.error_description.as_deref(), Some("needs login"));
          let params = parse_oauth_callback_request_target("/callback?code=abc&state=xyz")
              .expect("parse callback target");
          assert_eq!(params.code.as_deref(), Some("abc"));
          assert_eq!(params.state.as_deref(), Some("xyz"));
          assert!(parse_oauth_callback_request_target("/wrong?code=abc").is_err());
      }

      #[test]
      fn parses_callback_errors_and_bare_target() {
          let params =
              parse_oauth_callback_query("error=access_denied&error_description=user+declined")
                  .expect("parse error query");
          assert_eq!(params.error.as_deref(), Some("access_denied"));
          assert_eq!(params.error_description.as_deref(), Some("user declined"));
          assert_eq!(params.code, None);
          let bare = parse_oauth_callback_request_target("/callback").expect("bare target");
          assert_eq!(bare.code, None);
          assert_eq!(bare.state, None);
      }

      #[test]
      fn base64url_encode_matches_rfc4648_vectors_without_padding() {
          let cases = [
              ("", ""),
              ("f", "Zg"),
              ("fo", "Zm8"),
              ("foo", "Zm9v"),
              ("foob", "Zm9vYg"),
              ("fooba", "Zm9vYmE"),
              ("foobar", "Zm9vYmFy"),
          ];
          for (input, expected) in cases {
              assert_eq!(base64url_encode(input.as_bytes()), expected);
          }
          // Bytes that map to the URL-safe alphabet's `-` and `_`.
          assert_eq!(base64url_encode(&[0xfb, 0xff]), "-_8");
      }

      #[test]
      fn percent_encoding_round_trips_and_decodes_leniently() {
          let raw = "user@example.com org:read ~._-";
          let encoded = percent_encode(raw);
          assert_eq!(encoded, "user%40example.com%20org%3Aread%20~._-");
          assert_eq!(percent_decode(&encoded).expect("round trip"), raw);
          assert_eq!(percent_decode("a+b").expect("plus"), "a b");
          assert_eq!(percent_decode("100%").expect("trailing percent"), "100%");
          assert!(percent_decode("%zz").is_err());
      }
  }