| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- use std::collections::BTreeMap;
- use std::fs::File;
- use std::io::{self, Read};
- use sha2::{Digest, Sha256};
- use crate::config::OAuthConfig;
/// A set of tokens held for an authenticated OAuth session.
///
/// Only the optional parts (`refresh_token`, `expires_at`) may be absent; the
/// access token is always present.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthTokenSet {
    /// Bearer token presented on API requests.
    pub access_token: String,
    /// Token usable to obtain a fresh access token, if the provider issued one.
    pub refresh_token: Option<String>,
    /// Expiry of `access_token`, if known.
    /// NOTE(review): unit/epoch is not visible here (presumably Unix seconds) — confirm.
    pub expires_at: Option<u64>,
    /// Scopes granted to this token set.
    pub scopes: Vec<String>,
}
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct PkceCodePair {
- pub verifier: String,
- pub challenge: String,
- pub challenge_method: PkceChallengeMethod,
- }
/// PKCE code-challenge transformation. Only SHA-256 (`S256`) is supported.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PkceChallengeMethod {
    /// BASE64URL(SHA256(verifier)).
    S256,
}

impl PkceChallengeMethod {
    /// Returns the wire name of this method, as sent in `code_challenge_method`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            PkceChallengeMethod::S256 => "S256",
        }
    }
}
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct OAuthAuthorizationRequest {
- pub authorize_url: String,
- pub client_id: String,
- pub redirect_uri: String,
- pub scopes: Vec<String>,
- pub state: String,
- pub code_challenge: String,
- pub code_challenge_method: PkceChallengeMethod,
- pub extra_params: BTreeMap<String, String>,
- }
/// Form body for exchanging an authorization code for tokens.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthTokenExchangeRequest {
    /// OAuth grant type (set to `"authorization_code"` by `from_config`).
    pub grant_type: &'static str,
    /// Authorization code returned to the redirect URI.
    pub code: String,
    /// Redirect URI used on the authorization request.
    pub redirect_uri: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// PKCE verifier matching the challenge sent earlier.
    pub code_verifier: String,
    /// State value from the authorization round-trip.
    pub state: String,
}
/// Form body for obtaining a new access token with a refresh token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthRefreshRequest {
    /// OAuth grant type (set to `"refresh_token"` by `from_config`).
    pub grant_type: &'static str,
    /// Refresh token being redeemed.
    pub refresh_token: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Scopes requested for the refreshed token; rendered space-separated.
    pub scopes: Vec<String>,
}
- impl OAuthAuthorizationRequest {
- #[must_use]
- pub fn from_config(
- config: &OAuthConfig,
- redirect_uri: impl Into<String>,
- state: impl Into<String>,
- pkce: &PkceCodePair,
- ) -> Self {
- Self {
- authorize_url: config.authorize_url.clone(),
- client_id: config.client_id.clone(),
- redirect_uri: redirect_uri.into(),
- scopes: config.scopes.clone(),
- state: state.into(),
- code_challenge: pkce.challenge.clone(),
- code_challenge_method: pkce.challenge_method,
- extra_params: BTreeMap::new(),
- }
- }
- #[must_use]
- pub fn with_extra_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
- self.extra_params.insert(key.into(), value.into());
- self
- }
- #[must_use]
- pub fn build_url(&self) -> String {
- let mut params = vec![
- ("response_type", "code".to_string()),
- ("client_id", self.client_id.clone()),
- ("redirect_uri", self.redirect_uri.clone()),
- ("scope", self.scopes.join(" ")),
- ("state", self.state.clone()),
- ("code_challenge", self.code_challenge.clone()),
- (
- "code_challenge_method",
- self.code_challenge_method.as_str().to_string(),
- ),
- ];
- params.extend(
- self.extra_params
- .iter()
- .map(|(key, value)| (key.as_str(), value.clone())),
- );
- let query = params
- .into_iter()
- .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(&value)))
- .collect::<Vec<_>>()
- .join("&");
- format!(
- "{}{}{}",
- self.authorize_url,
- if self.authorize_url.contains('?') {
- '&'
- } else {
- '?'
- },
- query
- )
- }
- }
- impl OAuthTokenExchangeRequest {
- #[must_use]
- pub fn from_config(
- config: &OAuthConfig,
- code: impl Into<String>,
- state: impl Into<String>,
- verifier: impl Into<String>,
- redirect_uri: impl Into<String>,
- ) -> Self {
- let _ = config;
- Self {
- grant_type: "authorization_code",
- code: code.into(),
- redirect_uri: redirect_uri.into(),
- client_id: config.client_id.clone(),
- code_verifier: verifier.into(),
- state: state.into(),
- }
- }
- #[must_use]
- pub fn form_params(&self) -> BTreeMap<&str, String> {
- BTreeMap::from([
- ("grant_type", self.grant_type.to_string()),
- ("code", self.code.clone()),
- ("redirect_uri", self.redirect_uri.clone()),
- ("client_id", self.client_id.clone()),
- ("code_verifier", self.code_verifier.clone()),
- ("state", self.state.clone()),
- ])
- }
- }
- impl OAuthRefreshRequest {
- #[must_use]
- pub fn from_config(
- config: &OAuthConfig,
- refresh_token: impl Into<String>,
- scopes: Option<Vec<String>>,
- ) -> Self {
- Self {
- grant_type: "refresh_token",
- refresh_token: refresh_token.into(),
- client_id: config.client_id.clone(),
- scopes: scopes.unwrap_or_else(|| config.scopes.clone()),
- }
- }
- #[must_use]
- pub fn form_params(&self) -> BTreeMap<&str, String> {
- BTreeMap::from([
- ("grant_type", self.grant_type.to_string()),
- ("refresh_token", self.refresh_token.clone()),
- ("client_id", self.client_id.clone()),
- ("scope", self.scopes.join(" ")),
- ])
- }
- }
- pub fn generate_pkce_pair() -> io::Result<PkceCodePair> {
- let verifier = generate_random_token(32)?;
- Ok(PkceCodePair {
- challenge: code_challenge_s256(&verifier),
- verifier,
- challenge_method: PkceChallengeMethod::S256,
- })
- }
- pub fn generate_state() -> io::Result<String> {
- generate_random_token(32)
- }
- #[must_use]
- pub fn code_challenge_s256(verifier: &str) -> String {
- let digest = Sha256::digest(verifier.as_bytes());
- base64url_encode(&digest)
- }
/// Builds the local callback URI `http://localhost:<port>/callback`.
#[must_use]
pub fn loopback_redirect_uri(port: u16) -> String {
    format!("http://localhost:{}/callback", port)
}
- fn generate_random_token(bytes: usize) -> io::Result<String> {
- let mut buffer = vec![0_u8; bytes];
- File::open("/dev/urandom")?.read_exact(&mut buffer)?;
- Ok(base64url_encode(&buffer))
- }
/// Encodes `bytes` as URL-safe base64 (`-`/`_` alphabet) without `=` padding.
///
/// Each 3-byte group yields 4 characters. A trailing group of 1 or 2 bytes
/// yields 2 or 3 characters respectively.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; missing bytes stay zero.
        let group = chunk
            .iter()
            .zip([16_u32, 8, 0].iter())
            .fold(0_u32, |acc, (&byte, &shift)| acc | (u32::from(byte) << shift));
        // An n-byte chunk carries 8n bits, which need n + 1 six-bit characters.
        for &shift in [18_u32, 12, 6, 0].iter().take(chunk.len() + 1) {
            encoded.push(char::from(ALPHABET[((group >> shift) & 0x3F) as usize]));
        }
    }
    encoded
}
/// Percent-encodes every byte of `value` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`).
///
/// Escapes use uppercase hex, and a space becomes `%20` rather than `+`.
fn percent_encode(value: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            out.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(out, "%{byte:02X}");
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::{
        base64url_encode, code_challenge_s256, generate_pkce_pair, generate_state,
        loopback_redirect_uri, percent_encode, OAuthAuthorizationRequest, OAuthConfig,
        OAuthRefreshRequest, OAuthTokenExchangeRequest, PkceChallengeMethod,
    };

    /// Fixed provider configuration shared by the request-building tests.
    fn sample_config() -> OAuthConfig {
        OAuthConfig {
            client_id: "runtime-client".to_string(),
            authorize_url: "https://console.test/oauth/authorize".to_string(),
            token_url: "https://console.test/oauth/token".to_string(),
            callback_port: Some(4545),
            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
            scopes: vec!["org:read".to_string(), "user:write".to_string()],
        }
    }

    /// The verifier/challenge pair is the RFC 7636 Appendix B example.
    #[test]
    fn s256_challenge_matches_expected_vector() {
        assert_eq!(
            code_challenge_s256("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn generates_pkce_pair_and_state() {
        let pair = generate_pkce_pair().expect("pkce pair");
        let state = generate_state().expect("state");
        assert!(!pair.verifier.is_empty());
        assert!(!pair.challenge.is_empty());
        assert!(!state.is_empty());
    }

    #[test]
    fn generated_pkce_pair_is_self_consistent() {
        let pair = generate_pkce_pair().expect("pkce pair");
        // 32 random bytes encode to 43 unpadded base64url characters.
        assert_eq!(pair.verifier.len(), 43);
        assert_eq!(pair.challenge, code_challenge_s256(&pair.verifier));
        assert_eq!(pair.challenge_method, PkceChallengeMethod::S256);
        assert_eq!(generate_state().expect("state").len(), 43);
    }

    #[test]
    fn base64url_encode_matches_rfc4648_vectors_without_padding() {
        assert_eq!(base64url_encode(b""), "");
        assert_eq!(base64url_encode(b"f"), "Zg");
        assert_eq!(base64url_encode(b"fo"), "Zm8");
        assert_eq!(base64url_encode(b"foo"), "Zm9v");
        assert_eq!(base64url_encode(b"foobar"), "Zm9vYmFy");
        // Bytes that map to `+` and `/` in standard base64 must use `-` and `_`.
        assert_eq!(base64url_encode(&[0xfb_u8, 0xff_u8]), "-_8");
    }

    #[test]
    fn percent_encode_escapes_everything_but_unreserved_bytes() {
        assert_eq!(percent_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(percent_encode("a b:c@d/e"), "a%20b%3Ac%40d%2Fe");
        assert_eq!(percent_encode("é"), "%C3%A9");
    }

    #[test]
    fn builds_authorize_url_and_form_requests() {
        let config = sample_config();
        let pair = generate_pkce_pair().expect("pkce");
        let url = OAuthAuthorizationRequest::from_config(
            &config,
            loopback_redirect_uri(4545),
            "state-123",
            &pair,
        )
        .with_extra_param("login_hint", "user@example.com")
        .build_url();
        assert!(url.starts_with("https://console.test/oauth/authorize?"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("client_id=runtime-client"));
        assert!(url.contains("scope=org%3Aread%20user%3Awrite"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("login_hint=user%40example.com"));

        let exchange = OAuthTokenExchangeRequest::from_config(
            &config,
            "auth-code",
            "state-123",
            pair.verifier.clone(),
            loopback_redirect_uri(4545),
        );
        let exchange_form = exchange.form_params();
        assert_eq!(
            exchange_form.get("grant_type").map(String::as_str),
            Some("authorization_code")
        );
        assert_eq!(
            exchange_form.get("code_verifier").map(String::as_str),
            Some(pair.verifier.as_str())
        );

        let refresh = OAuthRefreshRequest::from_config(&config, "refresh-token", None);
        assert_eq!(
            refresh.form_params().get("scope").map(String::as_str),
            Some("org:read user:write")
        );
    }

    #[test]
    fn refresh_request_uses_explicit_scopes_when_given() {
        let refresh = OAuthRefreshRequest::from_config(
            &sample_config(),
            "refresh-token",
            Some(vec!["custom:scope".to_string()]),
        );
        let form = refresh.form_params();
        assert_eq!(form.get("scope").map(String::as_str), Some("custom:scope"));
        assert_eq!(form.get("grant_type").map(String::as_str), Some("refresh_token"));
        assert_eq!(form.get("client_id").map(String::as_str), Some("runtime-client"));
    }
}
|