1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
/// An implementation of [`OAuth2AuthorizationGrantRepository`] backed by a
/// PostgreSQL connection.
pub struct PgOAuth2AuthorizationGrantRepository<'c> {
    /// Connection that every query issued by this repository is executed on.
    /// It stays mutably borrowed for the repository's lifetime `'c`.
    conn: &'c mut PgConnection,
}

impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
    /// Creates a new [`PgOAuth2AuthorizationGrantRepository`] that runs its
    /// queries on the given PostgreSQL connection.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
36
37#[allow(clippy::struct_excessive_bools)]
38struct GrantLookup {
39 oauth2_authorization_grant_id: Uuid,
40 created_at: DateTime<Utc>,
41 cancelled_at: Option<DateTime<Utc>>,
42 fulfilled_at: Option<DateTime<Utc>>,
43 exchanged_at: Option<DateTime<Utc>>,
44 scope: String,
45 state: Option<String>,
46 nonce: Option<String>,
47 redirect_uri: String,
48 response_mode: String,
49 response_type_code: bool,
50 response_type_id_token: bool,
51 authorization_code: Option<String>,
52 code_challenge: Option<String>,
53 code_challenge_method: Option<String>,
54 login_hint: Option<String>,
55 oauth2_client_id: Uuid,
56 oauth2_session_id: Option<Uuid>,
57}
58
59impl TryFrom<GrantLookup> for AuthorizationGrant {
60 type Error = DatabaseInconsistencyError;
61
62 #[allow(clippy::too_many_lines)]
63 fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
64 let id = value.oauth2_authorization_grant_id.into();
65 let scope: Scope = value.scope.parse().map_err(|e| {
66 DatabaseInconsistencyError::on("oauth2_authorization_grants")
67 .column("scope")
68 .row(id)
69 .source(e)
70 })?;
71
72 let stage = match (
73 value.fulfilled_at,
74 value.exchanged_at,
75 value.cancelled_at,
76 value.oauth2_session_id,
77 ) {
78 (None, None, None, None) => AuthorizationGrantStage::Pending,
79 (Some(fulfilled_at), None, None, Some(session_id)) => {
80 AuthorizationGrantStage::Fulfilled {
81 session_id: session_id.into(),
82 fulfilled_at,
83 }
84 }
85 (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
86 AuthorizationGrantStage::Exchanged {
87 session_id: session_id.into(),
88 fulfilled_at,
89 exchanged_at,
90 }
91 }
92 (None, None, Some(cancelled_at), None) => {
93 AuthorizationGrantStage::Cancelled { cancelled_at }
94 }
95 _ => {
96 return Err(
97 DatabaseInconsistencyError::on("oauth2_authorization_grants")
98 .column("stage")
99 .row(id),
100 );
101 }
102 };
103
104 let pkce = match (value.code_challenge, value.code_challenge_method) {
105 (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
106 Some(Pkce {
107 challenge_method: PkceCodeChallengeMethod::Plain,
108 challenge,
109 })
110 }
111 (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
112 challenge_method: PkceCodeChallengeMethod::S256,
113 challenge,
114 }),
115 (None, None) => None,
116 _ => {
117 return Err(
118 DatabaseInconsistencyError::on("oauth2_authorization_grants")
119 .column("code_challenge_method")
120 .row(id),
121 );
122 }
123 };
124
125 let code: Option<AuthorizationCode> =
126 match (value.response_type_code, value.authorization_code, pkce) {
127 (false, None, None) => None,
128 (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
129 _ => {
130 return Err(
131 DatabaseInconsistencyError::on("oauth2_authorization_grants")
132 .column("authorization_code")
133 .row(id),
134 );
135 }
136 };
137
138 let redirect_uri = value.redirect_uri.parse().map_err(|e| {
139 DatabaseInconsistencyError::on("oauth2_authorization_grants")
140 .column("redirect_uri")
141 .row(id)
142 .source(e)
143 })?;
144
145 let response_mode = value.response_mode.parse().map_err(|e| {
146 DatabaseInconsistencyError::on("oauth2_authorization_grants")
147 .column("response_mode")
148 .row(id)
149 .source(e)
150 })?;
151
152 Ok(AuthorizationGrant {
153 id,
154 stage,
155 client_id: value.oauth2_client_id.into(),
156 code,
157 scope,
158 state: value.state,
159 nonce: value.nonce,
160 response_mode,
161 redirect_uri,
162 created_at: value.created_at,
163 response_type_id_token: value.response_type_id_token,
164 login_hint: value.login_hint,
165 })
166 }
167}
168
169#[async_trait]
170impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
171 type Error = DatabaseError;
172
173 #[tracing::instrument(
174 name = "db.oauth2_authorization_grant.add",
175 skip_all,
176 fields(
177 db.query.text,
178 grant.id,
179 grant.scope = %scope,
180 %client.id,
181 ),
182 err,
183 )]
184 async fn add(
185 &mut self,
186 rng: &mut (dyn RngCore + Send),
187 clock: &dyn Clock,
188 client: &Client,
189 redirect_uri: Url,
190 scope: Scope,
191 code: Option<AuthorizationCode>,
192 state: Option<String>,
193 nonce: Option<String>,
194 response_mode: ResponseMode,
195 response_type_id_token: bool,
196 login_hint: Option<String>,
197 ) -> Result<AuthorizationGrant, Self::Error> {
198 let code_challenge = code
199 .as_ref()
200 .and_then(|c| c.pkce.as_ref())
201 .map(|p| &p.challenge);
202 let code_challenge_method = code
203 .as_ref()
204 .and_then(|c| c.pkce.as_ref())
205 .map(|p| p.challenge_method.to_string());
206 let code_str = code.as_ref().map(|c| &c.code);
207
208 let created_at = clock.now();
209 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
210 tracing::Span::current().record("grant.id", tracing::field::display(id));
211
212 sqlx::query!(
213 r#"
214 INSERT INTO oauth2_authorization_grants (
215 oauth2_authorization_grant_id,
216 oauth2_client_id,
217 redirect_uri,
218 scope,
219 state,
220 nonce,
221 response_mode,
222 code_challenge,
223 code_challenge_method,
224 response_type_code,
225 response_type_id_token,
226 authorization_code,
227 login_hint,
228 created_at
229 )
230 VALUES
231 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
232 "#,
233 Uuid::from(id),
234 Uuid::from(client.id),
235 redirect_uri.to_string(),
236 scope.to_string(),
237 state,
238 nonce,
239 response_mode.to_string(),
240 code_challenge,
241 code_challenge_method,
242 code.is_some(),
243 response_type_id_token,
244 code_str,
245 login_hint,
246 created_at,
247 )
248 .traced()
249 .execute(&mut *self.conn)
250 .await?;
251
252 Ok(AuthorizationGrant {
253 id,
254 stage: AuthorizationGrantStage::Pending,
255 code,
256 redirect_uri,
257 client_id: client.id,
258 scope,
259 state,
260 nonce,
261 response_mode,
262 created_at,
263 response_type_id_token,
264 login_hint,
265 })
266 }
267
268 #[tracing::instrument(
269 name = "db.oauth2_authorization_grant.lookup",
270 skip_all,
271 fields(
272 db.query.text,
273 grant.id = %id,
274 ),
275 err,
276 )]
277 async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
278 let res = sqlx::query_as!(
279 GrantLookup,
280 r#"
281 SELECT oauth2_authorization_grant_id
282 , created_at
283 , cancelled_at
284 , fulfilled_at
285 , exchanged_at
286 , scope
287 , state
288 , redirect_uri
289 , response_mode
290 , nonce
291 , oauth2_client_id
292 , authorization_code
293 , response_type_code
294 , response_type_id_token
295 , code_challenge
296 , code_challenge_method
297 , login_hint
298 , oauth2_session_id
299 FROM
300 oauth2_authorization_grants
301
302 WHERE oauth2_authorization_grant_id = $1
303 "#,
304 Uuid::from(id),
305 )
306 .traced()
307 .fetch_optional(&mut *self.conn)
308 .await?;
309
310 let Some(res) = res else { return Ok(None) };
311
312 Ok(Some(res.try_into()?))
313 }
314
315 #[tracing::instrument(
316 name = "db.oauth2_authorization_grant.find_by_code",
317 skip_all,
318 fields(
319 db.query.text,
320 ),
321 err,
322 )]
323 async fn find_by_code(
324 &mut self,
325 code: &str,
326 ) -> Result<Option<AuthorizationGrant>, Self::Error> {
327 let res = sqlx::query_as!(
328 GrantLookup,
329 r#"
330 SELECT oauth2_authorization_grant_id
331 , created_at
332 , cancelled_at
333 , fulfilled_at
334 , exchanged_at
335 , scope
336 , state
337 , redirect_uri
338 , response_mode
339 , nonce
340 , oauth2_client_id
341 , authorization_code
342 , response_type_code
343 , response_type_id_token
344 , code_challenge
345 , code_challenge_method
346 , login_hint
347 , oauth2_session_id
348 FROM
349 oauth2_authorization_grants
350
351 WHERE authorization_code = $1
352 "#,
353 code,
354 )
355 .traced()
356 .fetch_optional(&mut *self.conn)
357 .await?;
358
359 let Some(res) = res else { return Ok(None) };
360
361 Ok(Some(res.try_into()?))
362 }
363
364 #[tracing::instrument(
365 name = "db.oauth2_authorization_grant.fulfill",
366 skip_all,
367 fields(
368 db.query.text,
369 %grant.id,
370 client.id = %grant.client_id,
371 %session.id,
372 ),
373 err,
374 )]
375 async fn fulfill(
376 &mut self,
377 clock: &dyn Clock,
378 session: &Session,
379 grant: AuthorizationGrant,
380 ) -> Result<AuthorizationGrant, Self::Error> {
381 let fulfilled_at = clock.now();
382 let res = sqlx::query!(
383 r#"
384 UPDATE oauth2_authorization_grants
385 SET fulfilled_at = $2
386 , oauth2_session_id = $3
387 WHERE oauth2_authorization_grant_id = $1
388 "#,
389 Uuid::from(grant.id),
390 fulfilled_at,
391 Uuid::from(session.id),
392 )
393 .traced()
394 .execute(&mut *self.conn)
395 .await?;
396
397 DatabaseError::ensure_affected_rows(&res, 1)?;
398
399 let grant = grant
401 .fulfill(fulfilled_at, session)
402 .map_err(DatabaseError::to_invalid_operation)?;
403
404 Ok(grant)
405 }
406
407 #[tracing::instrument(
408 name = "db.oauth2_authorization_grant.exchange",
409 skip_all,
410 fields(
411 db.query.text,
412 %grant.id,
413 client.id = %grant.client_id,
414 ),
415 err,
416 )]
417 async fn exchange(
418 &mut self,
419 clock: &dyn Clock,
420 grant: AuthorizationGrant,
421 ) -> Result<AuthorizationGrant, Self::Error> {
422 let exchanged_at = clock.now();
423 let res = sqlx::query!(
424 r#"
425 UPDATE oauth2_authorization_grants
426 SET exchanged_at = $2
427 WHERE oauth2_authorization_grant_id = $1
428 "#,
429 Uuid::from(grant.id),
430 exchanged_at,
431 )
432 .traced()
433 .execute(&mut *self.conn)
434 .await?;
435
436 DatabaseError::ensure_affected_rows(&res, 1)?;
437
438 let grant = grant
439 .exchange(exchanged_at)
440 .map_err(DatabaseError::to_invalid_operation)?;
441
442 Ok(grant)
443 }
444}