oauth.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. use std::collections::BTreeMap;
  2. use std::fs::File;
  3. use std::io::{self, Read};
  4. use sha2::{Digest, Sha256};
  5. use crate::config::OAuthConfig;
  6. #[derive(Debug, Clone, PartialEq, Eq)]
  7. pub struct OAuthTokenSet {
  8. pub access_token: String,
  9. pub refresh_token: Option<String>,
  10. pub expires_at: Option<u64>,
  11. pub scopes: Vec<String>,
  12. }
  13. #[derive(Debug, Clone, PartialEq, Eq)]
  14. pub struct PkceCodePair {
  15. pub verifier: String,
  16. pub challenge: String,
  17. pub challenge_method: PkceChallengeMethod,
  18. }
  19. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  20. pub enum PkceChallengeMethod {
  21. S256,
  22. }
  23. impl PkceChallengeMethod {
  24. #[must_use]
  25. pub const fn as_str(self) -> &'static str {
  26. match self {
  27. Self::S256 => "S256",
  28. }
  29. }
  30. }
  31. #[derive(Debug, Clone, PartialEq, Eq)]
  32. pub struct OAuthAuthorizationRequest {
  33. pub authorize_url: String,
  34. pub client_id: String,
  35. pub redirect_uri: String,
  36. pub scopes: Vec<String>,
  37. pub state: String,
  38. pub code_challenge: String,
  39. pub code_challenge_method: PkceChallengeMethod,
  40. pub extra_params: BTreeMap<String, String>,
  41. }
  42. #[derive(Debug, Clone, PartialEq, Eq)]
  43. pub struct OAuthTokenExchangeRequest {
  44. pub grant_type: &'static str,
  45. pub code: String,
  46. pub redirect_uri: String,
  47. pub client_id: String,
  48. pub code_verifier: String,
  49. pub state: String,
  50. }
  51. #[derive(Debug, Clone, PartialEq, Eq)]
  52. pub struct OAuthRefreshRequest {
  53. pub grant_type: &'static str,
  54. pub refresh_token: String,
  55. pub client_id: String,
  56. pub scopes: Vec<String>,
  57. }
  58. impl OAuthAuthorizationRequest {
  59. #[must_use]
  60. pub fn from_config(
  61. config: &OAuthConfig,
  62. redirect_uri: impl Into<String>,
  63. state: impl Into<String>,
  64. pkce: &PkceCodePair,
  65. ) -> Self {
  66. Self {
  67. authorize_url: config.authorize_url.clone(),
  68. client_id: config.client_id.clone(),
  69. redirect_uri: redirect_uri.into(),
  70. scopes: config.scopes.clone(),
  71. state: state.into(),
  72. code_challenge: pkce.challenge.clone(),
  73. code_challenge_method: pkce.challenge_method,
  74. extra_params: BTreeMap::new(),
  75. }
  76. }
  77. #[must_use]
  78. pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
  79. self.extra_params.insert(key.into(), value.into());
  80. self
  81. }
  82. #[must_use]
  83. pub fn build_url(&self) -> String {
  84. let mut params = vec![
  85. ("response_type", "code".to_string()),
  86. ("client_id", self.client_id.clone()),
  87. ("redirect_uri", self.redirect_uri.clone()),
  88. ("scope", self.scopes.join(" ")),
  89. ("state", self.state.clone()),
  90. ("code_challenge", self.code_challenge.clone()),
  91. (
  92. "code_challenge_method",
  93. self.code_challenge_method.as_str().to_string(),
  94. ),
  95. ];
  96. params.extend(
  97. self.extra_params
  98. .iter()
  99. .map(|(key, value)| (key.as_str(), value.clone())),
  100. );
  101. let query = params
  102. .into_iter()
  103. .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
  104. .collect::<Vec<_>>()
  105. .join("&");
  106. format!(
  107. "{}{}{}",
  108. self.authorize_url,
  109. if self.authorize_url.contains('?') {
  110. '&'
  111. } else {
  112. '?'
  113. },
  114. query
  115. )
  116. }
  117. }
  118. impl OAuthTokenExchangeRequest {
  119. #[must_use]
  120. pub fn from_config(
  121. config: &OAuthConfig,
  122. code: impl Into<String>,
  123. state: impl Into<String>,
  124. verifier: impl Into<String>,
  125. redirect_uri: impl Into<String>,
  126. ) -> Self {
  127. let _ = config;
  128. Self {
  129. grant_type: "authorization_code",
  130. code: code.into(),
  131. redirect_uri: redirect_uri.into(),
  132. client_id: config.client_id.clone(),
  133. code_verifier: verifier.into(),
  134. state: state.into(),
  135. }
  136. }
  137. #[must_use]
  138. pub fn form_params(&self) -> BTreeMap<&str, String> {
  139. BTreeMap::from([
  140. ("grant_type", self.grant_type.to_string()),
  141. ("code", self.code.clone()),
  142. ("redirect_uri", self.redirect_uri.clone()),
  143. ("client_id", self.client_id.clone()),
  144. ("code_verifier", self.code_verifier.clone()),
  145. ("state", self.state.clone()),
  146. ])
  147. }
  148. }
  149. impl OAuthRefreshRequest {
  150. #[must_use]
  151. pub fn from_config(
  152. config: &OAuthConfig,
  153. refresh_token: impl Into<String>,
  154. scopes: Option<Vec<String>>,
  155. ) -> Self {
  156. Self {
  157. grant_type: "refresh_token",
  158. refresh_token: refresh_token.into(),
  159. client_id: config.client_id.clone(),
  160. scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
  161. }
  162. }
  163. #[must_use]
  164. pub fn form_params(&self) -> BTreeMap<&str, String> {
  165. BTreeMap::from([
  166. ("grant_type", self.grant_type.to_string()),
  167. ("refresh_token", self.refresh_token.clone()),
  168. ("client_id", self.client_id.clone()),
  169. ("scope", self.scopes.join(" ")),
  170. ])
  171. }
  172. }
  173. pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
  174. let verifier = generate_random_token(32)?;
  175. Ok(PkceCodePair {
  176. challenge: code_challenge_s256(&verifier),
  177. verifier,
  178. challenge_method: PkceChallengeMethod::S256,
  179. })
  180. }
  181. pub fn generate_state() -> io::Result<String> {
  182. generate_random_token(32)
  183. }
  184. #[must_use]
  185. pub fn code_challenge_s256(verifier: &str) -> String {
  186. let digest = Sha256::digest(verifier.as_bytes());
  187. base64url_encode(&digest)
  188. }
  189. #[must_use]
  190. pub fn loopback_redirect_uri(port: u16) -> String {
  191. format!("http://localhost:{port}/callback")
  192. }
  193. fn generate_random_token(bytes: usize) -> io::Result<String> {
  194. let mut buffer = vec![0_u8; bytes];
  195. File::open("/dev/urandom")?.read_exact(&mut buffer)?;
  196. Ok(base64url_encode(&buffer))
  197. }
  198. fn base64url_encode(bytes: &[u8]) -> String {
  199. const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
  200. let mut output = String::new();
  201. let mut index = 0;
  202. while index + 3 <= bytes.len() {
  203. let block = (u32::from(bytes[index]) << 16)
  204. | (u32::from(bytes[index + 1]) << 8)
  205. | u32::from(bytes[index + 2]);
  206. output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
  207. output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
  208. output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
  209. output.push(TABLE[(block & 0x3F) as usize] as char);
  210. index += 3;
  211. }
  212. match bytes.len().saturating_sub(index) {
  213. 1 => {
  214. let block = u32::from(bytes[index]) << 16;
  215. output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
  216. output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
  217. }
  218. 2 => {
  219. let block = (u32::from(bytes[index]) << 16) | (u32::from(bytes[index + 1]) << 8);
  220. output.push(TABLE[((block >> 18) & 0x3F) as usize] as char);
  221. output.push(TABLE[((block >> 12) & 0x3F) as usize] as char);
  222. output.push(TABLE[((block >> 6) & 0x3F) as usize] as char);
  223. }
  224. _ => {}
  225. }
  226. output
  227. }
  228. fn percent_encode(value: &str) -> String {
  229. let mut encoded = String::new();
  230. for byte in value.bytes() {
  231. match byte {
  232. b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
  233. encoded.push(char::from(byte));
  234. }
  235. _ => {
  236. use std::fmt::Write as _;
  237. let _ = write!(&mut encoded, "%{byte:02X}");
  238. }
  239. }
  240. }
  241. encoded
  242. }
  243. #[cfg(test)]
  244. mod tests {
  245. use super::{
  246. code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
  247. OAuthAuthorizationRequest, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest,
  248. };
  249. fn sample_config() -> OAuthConfig {
  250. OAuthConfig {
  251. client_id: "runtime-client".to_string(),
  252. authorize_url: "https://console.test/oauth/authorize".to_string(),
  253. token_url: "https://console.test/oauth/token".to_string(),
  254. callback_port: Some(4545),
  255. manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
  256. scopes: vec!["org:read".to_string(), "user:write".to_string()],
  257. }
  258. }
  259. #[test]
  260. fn s256_challenge_matches_expected_vector() {
  261. assert_eq!(
  262. code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
  263. "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
  264. );
  265. }
  266. #[test]
  267. fn generates_pkce_pair_and_state() {
  268. let pair = generate_pkce_pair().expect("pkce pair");
  269. let state = generate_state().expect("state");
  270. assert!(!pair.verifier.is_empty());
  271. assert!(!pair.challenge.is_empty());
  272. assert!(!state.is_empty());
  273. }
  274. #[test]
  275. fn builds_authorize_url_and_form_requests() {
  276. let config = sample_config();
  277. let pair = generate_pkce_pair().expect("pkce");
  278. let url = OAuthAuthorizationRequest::from_config(
  279. &config,
  280. loopback_redirect_uri(4545),
  281. "state-123",
  282. &pair,
  283. )
  284. .with_extra_param("login_hint", "user@example.com")
  285. .build_url();
  286. assert!(url.starts_with("https://console.test/oauth/authorize?"));
  287. assert!(url.contains("response_type=code"));
  288. assert!(url.contains("client_id=runtime-client"));
  289. assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
  290. assert!(url.contains("login_hint=user%40example.com"));
  291. let exchange = OAuthTokenExchangeRequest::from_config(
  292. &config,
  293. "auth-code",
  294. "state-123",
  295. pair.verifier,
  296. loopback_redirect_uri(4545),
  297. );
  298. assert_eq!(
  299. exchange.form_params().get("grant_type").map(String::as_str),
  300. Some("authorization_code")
  301. );
  302. let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
  303. assert_eq!(
  304. refresh.form_params().get("scope").map(String::as_str),
  305. Some("org:read user:write")
  306. );
  307. }
  308. }