mas_storage_pg/upstream_oauth2/
session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11    UpstreamOAuthProvider,
12};
13use mas_storage::{Clock, upstream_oauth2::UpstreamOAuthSessionRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
20
21/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
22/// connection
23pub struct PgUpstreamOAuthSessionRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgUpstreamOAuthSessionRepository<'c> {
28    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
29    /// PostgreSQL connection
30    pub fn new(conn: &'c mut PgConnection) -> Self {
31        Self { conn }
32    }
33}
34
35struct SessionLookup {
36    upstream_oauth_authorization_session_id: Uuid,
37    upstream_oauth_provider_id: Uuid,
38    upstream_oauth_link_id: Option<Uuid>,
39    state: String,
40    code_challenge_verifier: Option<String>,
41    nonce: String,
42    id_token: Option<String>,
43    userinfo: Option<serde_json::Value>,
44    created_at: DateTime<Utc>,
45    completed_at: Option<DateTime<Utc>>,
46    consumed_at: Option<DateTime<Utc>>,
47    extra_callback_parameters: Option<serde_json::Value>,
48    unlinked_at: Option<DateTime<Utc>>,
49}
50
51impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
52    type Error = DatabaseInconsistencyError;
53
54    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
55        let id = value.upstream_oauth_authorization_session_id.into();
56        let state = match (
57            value.upstream_oauth_link_id,
58            value.id_token,
59            value.extra_callback_parameters,
60            value.userinfo,
61            value.completed_at,
62            value.consumed_at,
63            value.unlinked_at,
64        ) {
65            (None, None, None, None, None, None, None) => {
66                UpstreamOAuthAuthorizationSessionState::Pending
67            }
68            (
69                Some(link_id),
70                id_token,
71                extra_callback_parameters,
72                userinfo,
73                Some(completed_at),
74                None,
75                None,
76            ) => UpstreamOAuthAuthorizationSessionState::Completed {
77                completed_at,
78                link_id: link_id.into(),
79                id_token,
80                extra_callback_parameters,
81                userinfo,
82            },
83            (
84                Some(link_id),
85                id_token,
86                extra_callback_parameters,
87                userinfo,
88                Some(completed_at),
89                Some(consumed_at),
90                None,
91            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
92                completed_at,
93                link_id: link_id.into(),
94                id_token,
95                extra_callback_parameters,
96                userinfo,
97                consumed_at,
98            },
99            (_, id_token, _, _, Some(completed_at), consumed_at, Some(unlinked_at)) => {
100                UpstreamOAuthAuthorizationSessionState::Unlinked {
101                    completed_at,
102                    id_token,
103                    consumed_at,
104                    unlinked_at,
105                }
106            }
107            _ => {
108                return Err(DatabaseInconsistencyError::on(
109                    "upstream_oauth_authorization_sessions",
110                )
111                .row(id));
112            }
113        };
114
115        Ok(Self {
116            id,
117            provider_id: value.upstream_oauth_provider_id.into(),
118            state_str: value.state,
119            nonce: value.nonce,
120            code_challenge_verifier: value.code_challenge_verifier,
121            created_at: value.created_at,
122            state,
123        })
124    }
125}
126
127#[async_trait]
128impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
129    type Error = DatabaseError;
130
131    #[tracing::instrument(
132        name = "db.upstream_oauth_authorization_session.lookup",
133        skip_all,
134        fields(
135            db.query.text,
136            upstream_oauth_provider.id = %id,
137        ),
138        err,
139    )]
140    async fn lookup(
141        &mut self,
142        id: Ulid,
143    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
144        let res = sqlx::query_as!(
145            SessionLookup,
146            r#"
147                SELECT
148                    upstream_oauth_authorization_session_id,
149                    upstream_oauth_provider_id,
150                    upstream_oauth_link_id,
151                    state,
152                    code_challenge_verifier,
153                    nonce,
154                    id_token,
155                    extra_callback_parameters,
156                    userinfo,
157                    created_at,
158                    completed_at,
159                    consumed_at,
160                    unlinked_at
161                FROM upstream_oauth_authorization_sessions
162                WHERE upstream_oauth_authorization_session_id = $1
163            "#,
164            Uuid::from(id),
165        )
166        .traced()
167        .fetch_optional(&mut *self.conn)
168        .await?;
169
170        let Some(res) = res else { return Ok(None) };
171
172        Ok(Some(res.try_into()?))
173    }
174
175    #[tracing::instrument(
176        name = "db.upstream_oauth_authorization_session.add",
177        skip_all,
178        fields(
179            db.query.text,
180            %upstream_oauth_provider.id,
181            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
182            %upstream_oauth_provider.client_id,
183            upstream_oauth_authorization_session.id,
184        ),
185        err,
186    )]
187    async fn add(
188        &mut self,
189        rng: &mut (dyn RngCore + Send),
190        clock: &dyn Clock,
191        upstream_oauth_provider: &UpstreamOAuthProvider,
192        state_str: String,
193        code_challenge_verifier: Option<String>,
194        nonce: String,
195    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
196        let created_at = clock.now();
197        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
198        tracing::Span::current().record(
199            "upstream_oauth_authorization_session.id",
200            tracing::field::display(id),
201        );
202
203        sqlx::query!(
204            r#"
205                INSERT INTO upstream_oauth_authorization_sessions (
206                    upstream_oauth_authorization_session_id,
207                    upstream_oauth_provider_id,
208                    state,
209                    code_challenge_verifier,
210                    nonce,
211                    created_at,
212                    completed_at,
213                    consumed_at,
214                    id_token,
215                    userinfo
216                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
217            "#,
218            Uuid::from(id),
219            Uuid::from(upstream_oauth_provider.id),
220            &state_str,
221            code_challenge_verifier.as_deref(),
222            nonce,
223            created_at,
224        )
225        .traced()
226        .execute(&mut *self.conn)
227        .await?;
228
229        Ok(UpstreamOAuthAuthorizationSession {
230            id,
231            state: UpstreamOAuthAuthorizationSessionState::default(),
232            provider_id: upstream_oauth_provider.id,
233            state_str,
234            code_challenge_verifier,
235            nonce,
236            created_at,
237        })
238    }
239
240    #[tracing::instrument(
241        name = "db.upstream_oauth_authorization_session.complete_with_link",
242        skip_all,
243        fields(
244            db.query.text,
245            %upstream_oauth_authorization_session.id,
246            %upstream_oauth_link.id,
247        ),
248        err,
249    )]
250    async fn complete_with_link(
251        &mut self,
252        clock: &dyn Clock,
253        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
254        upstream_oauth_link: &UpstreamOAuthLink,
255        id_token: Option<String>,
256        extra_callback_parameters: Option<serde_json::Value>,
257        userinfo: Option<serde_json::Value>,
258    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
259        let completed_at = clock.now();
260
261        sqlx::query!(
262            r#"
263                UPDATE upstream_oauth_authorization_sessions
264                SET upstream_oauth_link_id = $1,
265                    completed_at = $2,
266                    id_token = $3,
267                    extra_callback_parameters = $4,
268                    userinfo = $5
269                WHERE upstream_oauth_authorization_session_id = $6
270            "#,
271            Uuid::from(upstream_oauth_link.id),
272            completed_at,
273            id_token,
274            extra_callback_parameters,
275            userinfo,
276            Uuid::from(upstream_oauth_authorization_session.id),
277        )
278        .traced()
279        .execute(&mut *self.conn)
280        .await?;
281
282        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
283            .complete(
284                completed_at,
285                upstream_oauth_link,
286                id_token,
287                extra_callback_parameters,
288                userinfo,
289            )
290            .map_err(DatabaseError::to_invalid_operation)?;
291
292        Ok(upstream_oauth_authorization_session)
293    }
294
295    /// Mark a session as consumed
296    #[tracing::instrument(
297        name = "db.upstream_oauth_authorization_session.consume",
298        skip_all,
299        fields(
300            db.query.text,
301            %upstream_oauth_authorization_session.id,
302        ),
303        err,
304    )]
305    async fn consume(
306        &mut self,
307        clock: &dyn Clock,
308        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
309    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
310        let consumed_at = clock.now();
311        sqlx::query!(
312            r#"
313                UPDATE upstream_oauth_authorization_sessions
314                SET consumed_at = $1
315                WHERE upstream_oauth_authorization_session_id = $2
316            "#,
317            consumed_at,
318            Uuid::from(upstream_oauth_authorization_session.id),
319        )
320        .traced()
321        .execute(&mut *self.conn)
322        .await?;
323
324        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
325            .consume(consumed_at)
326            .map_err(DatabaseError::to_invalid_operation)?;
327
328        Ok(upstream_oauth_authorization_session)
329    }
330}