mas_storage_pg/oauth2/
authorization_grant.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
11};
12use mas_iana::oauth::PkceCodeChallengeMethod;
13use mas_storage::{Clock, oauth2::OAuth2AuthorizationGrantRepository};
14use oauth2_types::{requests::ResponseMode, scope::Scope};
15use rand::RngCore;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use url::Url;
19use uuid::Uuid;
20
21use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
22
23/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
24/// connection
25pub struct PgOAuth2AuthorizationGrantRepository<'c> {
26    conn: &'c mut PgConnection,
27}
28
29impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
30    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
31    /// PostgreSQL connection
32    pub fn new(conn: &'c mut PgConnection) -> Self {
33        Self { conn }
34    }
35}
36
37#[allow(clippy::struct_excessive_bools)]
38struct GrantLookup {
39    oauth2_authorization_grant_id: Uuid,
40    created_at: DateTime<Utc>,
41    cancelled_at: Option<DateTime<Utc>>,
42    fulfilled_at: Option<DateTime<Utc>>,
43    exchanged_at: Option<DateTime<Utc>>,
44    scope: String,
45    state: Option<String>,
46    nonce: Option<String>,
47    redirect_uri: String,
48    response_mode: String,
49    response_type_code: bool,
50    response_type_id_token: bool,
51    authorization_code: Option<String>,
52    code_challenge: Option<String>,
53    code_challenge_method: Option<String>,
54    login_hint: Option<String>,
55    oauth2_client_id: Uuid,
56    oauth2_session_id: Option<Uuid>,
57}
58
59impl TryFrom<GrantLookup> for AuthorizationGrant {
60    type Error = DatabaseInconsistencyError;
61
62    #[allow(clippy::too_many_lines)]
63    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
64        let id = value.oauth2_authorization_grant_id.into();
65        let scope: Scope = value.scope.parse().map_err(|e| {
66            DatabaseInconsistencyError::on("oauth2_authorization_grants")
67                .column("scope")
68                .row(id)
69                .source(e)
70        })?;
71
72        let stage = match (
73            value.fulfilled_at,
74            value.exchanged_at,
75            value.cancelled_at,
76            value.oauth2_session_id,
77        ) {
78            (None, None, None, None) => AuthorizationGrantStage::Pending,
79            (Some(fulfilled_at), None, None, Some(session_id)) => {
80                AuthorizationGrantStage::Fulfilled {
81                    session_id: session_id.into(),
82                    fulfilled_at,
83                }
84            }
85            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
86                AuthorizationGrantStage::Exchanged {
87                    session_id: session_id.into(),
88                    fulfilled_at,
89                    exchanged_at,
90                }
91            }
92            (None, None, Some(cancelled_at), None) => {
93                AuthorizationGrantStage::Cancelled { cancelled_at }
94            }
95            _ => {
96                return Err(
97                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
98                        .column("stage")
99                        .row(id),
100                );
101            }
102        };
103
104        let pkce = match (value.code_challenge, value.code_challenge_method) {
105            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
106                Some(Pkce {
107                    challenge_method: PkceCodeChallengeMethod::Plain,
108                    challenge,
109                })
110            }
111            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
112                challenge_method: PkceCodeChallengeMethod::S256,
113                challenge,
114            }),
115            (None, None) => None,
116            _ => {
117                return Err(
118                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
119                        .column("code_challenge_method")
120                        .row(id),
121                );
122            }
123        };
124
125        let code: Option<AuthorizationCode> =
126            match (value.response_type_code, value.authorization_code, pkce) {
127                (false, None, None) => None,
128                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
129                _ => {
130                    return Err(
131                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
132                            .column("authorization_code")
133                            .row(id),
134                    );
135                }
136            };
137
138        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
139            DatabaseInconsistencyError::on("oauth2_authorization_grants")
140                .column("redirect_uri")
141                .row(id)
142                .source(e)
143        })?;
144
145        let response_mode = value.response_mode.parse().map_err(|e| {
146            DatabaseInconsistencyError::on("oauth2_authorization_grants")
147                .column("response_mode")
148                .row(id)
149                .source(e)
150        })?;
151
152        Ok(AuthorizationGrant {
153            id,
154            stage,
155            client_id: value.oauth2_client_id.into(),
156            code,
157            scope,
158            state: value.state,
159            nonce: value.nonce,
160            response_mode,
161            redirect_uri,
162            created_at: value.created_at,
163            response_type_id_token: value.response_type_id_token,
164            login_hint: value.login_hint,
165        })
166    }
167}
168
169#[async_trait]
170impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
171    type Error = DatabaseError;
172
173    #[tracing::instrument(
174        name = "db.oauth2_authorization_grant.add",
175        skip_all,
176        fields(
177            db.query.text,
178            grant.id,
179            grant.scope = %scope,
180            %client.id,
181        ),
182        err,
183    )]
184    async fn add(
185        &mut self,
186        rng: &mut (dyn RngCore + Send),
187        clock: &dyn Clock,
188        client: &Client,
189        redirect_uri: Url,
190        scope: Scope,
191        code: Option<AuthorizationCode>,
192        state: Option<String>,
193        nonce: Option<String>,
194        response_mode: ResponseMode,
195        response_type_id_token: bool,
196        login_hint: Option<String>,
197    ) -> Result<AuthorizationGrant, Self::Error> {
198        let code_challenge = code
199            .as_ref()
200            .and_then(|c| c.pkce.as_ref())
201            .map(|p| &p.challenge);
202        let code_challenge_method = code
203            .as_ref()
204            .and_then(|c| c.pkce.as_ref())
205            .map(|p| p.challenge_method.to_string());
206        let code_str = code.as_ref().map(|c| &c.code);
207
208        let created_at = clock.now();
209        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
210        tracing::Span::current().record("grant.id", tracing::field::display(id));
211
212        sqlx::query!(
213            r#"
214                INSERT INTO oauth2_authorization_grants (
215                     oauth2_authorization_grant_id,
216                     oauth2_client_id,
217                     redirect_uri,
218                     scope,
219                     state,
220                     nonce,
221                     response_mode,
222                     code_challenge,
223                     code_challenge_method,
224                     response_type_code,
225                     response_type_id_token,
226                     authorization_code,
227                     login_hint,
228                     created_at
229                )
230                VALUES
231                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
232            "#,
233            Uuid::from(id),
234            Uuid::from(client.id),
235            redirect_uri.to_string(),
236            scope.to_string(),
237            state,
238            nonce,
239            response_mode.to_string(),
240            code_challenge,
241            code_challenge_method,
242            code.is_some(),
243            response_type_id_token,
244            code_str,
245            login_hint,
246            created_at,
247        )
248        .traced()
249        .execute(&mut *self.conn)
250        .await?;
251
252        Ok(AuthorizationGrant {
253            id,
254            stage: AuthorizationGrantStage::Pending,
255            code,
256            redirect_uri,
257            client_id: client.id,
258            scope,
259            state,
260            nonce,
261            response_mode,
262            created_at,
263            response_type_id_token,
264            login_hint,
265        })
266    }
267
268    #[tracing::instrument(
269        name = "db.oauth2_authorization_grant.lookup",
270        skip_all,
271        fields(
272            db.query.text,
273            grant.id = %id,
274        ),
275        err,
276    )]
277    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
278        let res = sqlx::query_as!(
279            GrantLookup,
280            r#"
281                SELECT oauth2_authorization_grant_id
282                     , created_at
283                     , cancelled_at
284                     , fulfilled_at
285                     , exchanged_at
286                     , scope
287                     , state
288                     , redirect_uri
289                     , response_mode
290                     , nonce
291                     , oauth2_client_id
292                     , authorization_code
293                     , response_type_code
294                     , response_type_id_token
295                     , code_challenge
296                     , code_challenge_method
297                     , login_hint
298                     , oauth2_session_id
299                FROM
300                    oauth2_authorization_grants
301
302                WHERE oauth2_authorization_grant_id = $1
303            "#,
304            Uuid::from(id),
305        )
306        .traced()
307        .fetch_optional(&mut *self.conn)
308        .await?;
309
310        let Some(res) = res else { return Ok(None) };
311
312        Ok(Some(res.try_into()?))
313    }
314
315    #[tracing::instrument(
316        name = "db.oauth2_authorization_grant.find_by_code",
317        skip_all,
318        fields(
319            db.query.text,
320        ),
321        err,
322    )]
323    async fn find_by_code(
324        &mut self,
325        code: &str,
326    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
327        let res = sqlx::query_as!(
328            GrantLookup,
329            r#"
330                SELECT oauth2_authorization_grant_id
331                     , created_at
332                     , cancelled_at
333                     , fulfilled_at
334                     , exchanged_at
335                     , scope
336                     , state
337                     , redirect_uri
338                     , response_mode
339                     , nonce
340                     , oauth2_client_id
341                     , authorization_code
342                     , response_type_code
343                     , response_type_id_token
344                     , code_challenge
345                     , code_challenge_method
346                     , login_hint
347                     , oauth2_session_id
348                FROM
349                    oauth2_authorization_grants
350
351                WHERE authorization_code = $1
352            "#,
353            code,
354        )
355        .traced()
356        .fetch_optional(&mut *self.conn)
357        .await?;
358
359        let Some(res) = res else { return Ok(None) };
360
361        Ok(Some(res.try_into()?))
362    }
363
364    #[tracing::instrument(
365        name = "db.oauth2_authorization_grant.fulfill",
366        skip_all,
367        fields(
368            db.query.text,
369            %grant.id,
370            client.id = %grant.client_id,
371            %session.id,
372        ),
373        err,
374    )]
375    async fn fulfill(
376        &mut self,
377        clock: &dyn Clock,
378        session: &Session,
379        grant: AuthorizationGrant,
380    ) -> Result<AuthorizationGrant, Self::Error> {
381        let fulfilled_at = clock.now();
382        let res = sqlx::query!(
383            r#"
384                UPDATE oauth2_authorization_grants
385                SET fulfilled_at = $2
386                  , oauth2_session_id = $3
387                WHERE oauth2_authorization_grant_id = $1
388            "#,
389            Uuid::from(grant.id),
390            fulfilled_at,
391            Uuid::from(session.id),
392        )
393        .traced()
394        .execute(&mut *self.conn)
395        .await?;
396
397        DatabaseError::ensure_affected_rows(&res, 1)?;
398
399        // XXX: check affected rows & new methods
400        let grant = grant
401            .fulfill(fulfilled_at, session)
402            .map_err(DatabaseError::to_invalid_operation)?;
403
404        Ok(grant)
405    }
406
407    #[tracing::instrument(
408        name = "db.oauth2_authorization_grant.exchange",
409        skip_all,
410        fields(
411            db.query.text,
412            %grant.id,
413            client.id = %grant.client_id,
414        ),
415        err,
416    )]
417    async fn exchange(
418        &mut self,
419        clock: &dyn Clock,
420        grant: AuthorizationGrant,
421    ) -> Result<AuthorizationGrant, Self::Error> {
422        let exchanged_at = clock.now();
423        let res = sqlx::query!(
424            r#"
425                UPDATE oauth2_authorization_grants
426                SET exchanged_at = $2
427                WHERE oauth2_authorization_grant_id = $1
428            "#,
429            Uuid::from(grant.id),
430            exchanged_at,
431        )
432        .traced()
433        .execute(&mut *self.conn)
434        .await?;
435
436        DatabaseError::ensure_affected_rows(&res, 1)?;
437
438        let grant = grant
439            .exchange(exchanged_at)
440            .map_err(DatabaseError::to_invalid_operation)?;
441
442        Ok(grant)
443    }
444}