1use std::{collections::HashMap, sync::Arc};
8
9use mas_context::LogContext;
10use mas_data_model::{
11 UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_oidc_client::error::DiscoveryError;
15use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
16use oauth2_types::oidc::VerifiedProviderMetadata;
17use tokio::sync::RwLock;
18use url::Url;
19
20pub struct LazyProviderInfos<'a> {
23 cache: &'a MetadataCache,
24 provider: &'a UpstreamOAuthProvider,
25 client: &'a reqwest::Client,
26 loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
27}
28
29impl<'a> LazyProviderInfos<'a> {
30 pub fn new(
31 cache: &'a MetadataCache,
32 provider: &'a UpstreamOAuthProvider,
33 client: &'a reqwest::Client,
34 ) -> Self {
35 Self {
36 cache,
37 provider,
38 client,
39 loaded_metadata: None,
40 }
41 }
42
43 pub async fn maybe_discover(
46 &mut self,
47 ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
48 match self.load().await {
49 Ok(metadata) => Ok(Some(metadata)),
50 Err(DiscoveryError::Disabled) => Ok(None),
51 Err(e) => Err(e),
52 }
53 }
54
55 async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
56 if self.loaded_metadata.is_none() {
57 let verify = match self.provider.discovery_mode {
58 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
59 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
60 UpstreamOAuthProviderDiscoveryMode::Disabled => {
61 return Err(DiscoveryError::Disabled);
62 }
63 };
64
65 let Some(issuer) = &self.provider.issuer else {
66 return Err(DiscoveryError::MissingIssuer);
67 };
68
69 let metadata = self.cache.get(self.client, issuer, verify).await?;
70
71 self.loaded_metadata = Some(metadata);
72 }
73
74 Ok(self.loaded_metadata.as_ref().unwrap())
75 }
76
77 pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
82 if let Some(jwks_uri) = &self.provider.jwks_uri_override {
83 return Ok(jwks_uri);
84 }
85
86 Ok(self.load().await?.jwks_uri())
87 }
88
89 pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
94 if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
95 return Ok(authorization_endpoint);
96 }
97
98 Ok(self.load().await?.authorization_endpoint())
99 }
100
101 pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
106 if let Some(token_endpoint) = &self.provider.token_endpoint_override {
107 return Ok(token_endpoint);
108 }
109
110 Ok(self.load().await?.token_endpoint())
111 }
112
113 pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
118 if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
119 return Ok(userinfo_endpoint);
120 }
121
122 Ok(self.load().await?.userinfo_endpoint())
123 }
124
125 pub async fn pkce_methods(
130 &mut self,
131 ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
132 let methods = match self.provider.pkce_mode {
133 UpstreamOAuthProviderPkceMode::Auto => self
134 .maybe_discover()
135 .await?
136 .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
137 UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
138 UpstreamOAuthProviderPkceMode::Disabled => None,
139 };
140
141 Ok(methods)
142 }
143}
144
145#[allow(clippy::module_name_repetitions)]
151#[derive(Debug, Clone, Default)]
152pub struct MetadataCache {
153 cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154 insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
155}
156
157impl MetadataCache {
158 #[must_use]
159 pub fn new() -> Self {
160 Self::default()
161 }
162
163 #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
169 pub async fn warm_up_and_run<R: RepositoryAccess>(
170 &self,
171 client: &reqwest::Client,
172 interval: std::time::Duration,
173 repository: &mut R,
174 ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
175 let providers = repository.upstream_oauth_provider().all_enabled().await?;
176
177 for provider in providers {
178 let verify = match provider.discovery_mode {
179 UpstreamOAuthProviderDiscoveryMode::Oidc => true,
180 UpstreamOAuthProviderDiscoveryMode::Insecure => false,
181 UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
182 };
183
184 let Some(issuer) = &provider.issuer else {
185 tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
186 continue;
187 };
188
189 if let Err(e) = self.fetch(client, issuer, verify).await {
190 tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
191 }
192 }
193
194 let cache = self.clone();
196 let client = client.clone();
197 Ok(tokio::spawn(async move {
198 loop {
199 tokio::time::sleep(interval).await;
201 LogContext::new("metadata-cache-refresh")
202 .run(|| cache.refresh_all(&client))
203 .await;
204 }
205 }))
206 }
207
208 #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
209 async fn fetch(
210 &self,
211 client: &reqwest::Client,
212 issuer: &str,
213 verify: bool,
214 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
215 if verify {
216 let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
217 let metadata = Arc::new(metadata);
218
219 self.cache
220 .write()
221 .await
222 .insert(issuer.to_owned(), metadata.clone());
223
224 Ok(metadata)
225 } else {
226 let metadata =
227 mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
228 let metadata = Arc::new(metadata);
229
230 self.insecure_cache
231 .write()
232 .await
233 .insert(issuer.to_owned(), metadata.clone());
234
235 Ok(metadata)
236 }
237 }
238
239 #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
241 pub async fn get(
242 &self,
243 client: &reqwest::Client,
244 issuer: &str,
245 verify: bool,
246 ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
247 let cache = if verify {
248 self.cache.read().await
249 } else {
250 self.insecure_cache.read().await
251 };
252
253 if let Some(metadata) = cache.get(issuer) {
254 return Ok(Arc::clone(metadata));
255 }
256 drop(cache);
258
259 let metadata = self.fetch(client, issuer, verify).await?;
260 Ok(metadata)
261 }
262
263 #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
264 async fn refresh_all(&self, client: &reqwest::Client) {
265 let keys: Vec<String> = {
267 let cache = self.cache.read().await;
268 cache.keys().cloned().collect()
269 };
270
271 for issuer in keys {
272 if let Err(e) = self.fetch(client, &issuer, true).await {
273 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
274 }
275 }
276
277 let keys: Vec<String> = {
279 let cache = self.insecure_cache.read().await;
280 cache.keys().cloned().collect()
281 };
282
283 for issuer in keys {
284 if let Err(e) = self.fetch(client, &issuer, false).await {
285 tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
286 }
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    #![allow(clippy::too_many_lines)]

    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{Clock, clock::MockClock};
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// Checks that `MetadataCache` serves repeated lookups from its per-mode
    /// map, and that `refresh_all` re-fetches only the issuers it holds.
    ///
    /// `calls` counts the discovery requests each step is expected to send.
    /// The total must match the `.expect(...)` given to wiremock.
    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let cache = MetadataCache::new();

        // No mock is mounted yet, so this lookup is expected to fail. It is
        // not counted against the expectation below.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap_err();

        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Insecure discovery succeeds and is stored: one request.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 1;

        // Same issuer and mode again: served from the cache, no request.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 0;

        // Verified discovery uses its own map, so a request is sent; the test
        // expects it to fail. NOTE(review): presumably the strict checks
        // reject this mock issuer — confirm.
        cache
            .get(&http_client, &mock_server.uri(), true)
            .await
            .unwrap_err();
        calls += 1;

        // Only the insecure map holds an entry (the failed verified fetch
        // stored nothing), so refreshing sends exactly one request.
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    /// Checks that `LazyProviderInfos` honours endpoint overrides and
    /// discovery modes.
    ///
    /// Each scope below uses a fresh `MetadataCache`, so no step benefits
    /// from an earlier one's cached metadata.
    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        let clock = MockClock::default();
        // Baseline provider: insecure discovery, automatic PKCE, no overrides.
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(mock_server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
        };

        // Insecure discovery: one request. The discovered endpoint is then
        // reused from the memoized metadata.
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Every endpoint queried here is overridden, so no request is sent.
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
            calls += 0;
        }

        // Verified (OIDC) discovery sends a request, which the test expects to
        // fail. NOTE(review): presumably the strict checks reject this mock
        // issuer — confirm.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            calls += 1;
        }

        // Discovery disabled: no request is sent.
        // - `maybe_discover` yields `None`.
        // - The overridden endpoint still resolves.
        // - The non-overridden one reports `Disabled`.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}