// mas_storage_pg/upstream_oauth2/link.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11    Clock, Page, Pagination,
12    upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
13};
14use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
15use rand::RngCore;
16use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
17use sea_query_binder::SqlxBinder;
18use sqlx::PgConnection;
19use tracing::Instrument;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24    DatabaseError,
25    filter::{Filter, StatementExt},
26    iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
27    pagination::QueryBuilderExt,
28    tracing::ExecuteExt,
29};
30
31/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
32/// connection
33pub struct PgUpstreamOAuthLinkRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUpstreamOAuthLinkRepository<'c> {
38    /// Create a new [`PgUpstreamOAuthLinkRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
45#[derive(sqlx::FromRow)]
46#[enum_def]
47struct LinkLookup {
48    upstream_oauth_link_id: Uuid,
49    upstream_oauth_provider_id: Uuid,
50    user_id: Option<Uuid>,
51    subject: String,
52    human_account_name: Option<String>,
53    created_at: DateTime<Utc>,
54}
55
56impl From<LinkLookup> for UpstreamOAuthLink {
57    fn from(value: LinkLookup) -> Self {
58        UpstreamOAuthLink {
59            id: Ulid::from(value.upstream_oauth_link_id),
60            provider_id: Ulid::from(value.upstream_oauth_provider_id),
61            user_id: value.user_id.map(Ulid::from),
62            subject: value.subject,
63            human_account_name: value.human_account_name,
64            created_at: value.created_at,
65        }
66    }
67}
68
69impl Filter for UpstreamOAuthLinkFilter<'_> {
70    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
71        sea_query::Condition::all()
72            .add_option(self.user().map(|user| {
73                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
74                    .eq(Uuid::from(user.id))
75            }))
76            .add_option(self.provider().map(|provider| {
77                Expr::col((
78                    UpstreamOAuthLinks::Table,
79                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
80                ))
81                .eq(Uuid::from(provider.id))
82            }))
83            .add_option(self.provider_enabled().map(|enabled| {
84                Expr::col((
85                    UpstreamOAuthLinks::Table,
86                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
87                ))
88                .eq(Expr::any(
89                    Query::select()
90                        .expr(Expr::col((
91                            UpstreamOAuthProviders::Table,
92                            UpstreamOAuthProviders::UpstreamOAuthProviderId,
93                        )))
94                        .from(UpstreamOAuthProviders::Table)
95                        .and_where(
96                            Expr::col((
97                                UpstreamOAuthProviders::Table,
98                                UpstreamOAuthProviders::DisabledAt,
99                            ))
100                            .is_null()
101                            .eq(enabled),
102                        )
103                        .take(),
104                ))
105            }))
106            .add_option(self.subject().map(|subject| {
107                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
108            }))
109    }
110}
111
112#[async_trait]
113impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
114    type Error = DatabaseError;
115
116    #[tracing::instrument(
117        name = "db.upstream_oauth_link.lookup",
118        skip_all,
119        fields(
120            db.query.text,
121            upstream_oauth_link.id = %id,
122        ),
123        err,
124    )]
125    async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
126        let res = sqlx::query_as!(
127            LinkLookup,
128            r#"
129                SELECT
130                    upstream_oauth_link_id,
131                    upstream_oauth_provider_id,
132                    user_id,
133                    subject,
134                    human_account_name,
135                    created_at
136                FROM upstream_oauth_links
137                WHERE upstream_oauth_link_id = $1
138            "#,
139            Uuid::from(id),
140        )
141        .traced()
142        .fetch_optional(&mut *self.conn)
143        .await?
144        .map(Into::into);
145
146        Ok(res)
147    }
148
149    #[tracing::instrument(
150        name = "db.upstream_oauth_link.find_by_subject",
151        skip_all,
152        fields(
153            db.query.text,
154            upstream_oauth_link.subject = subject,
155            %upstream_oauth_provider.id,
156            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
157            %upstream_oauth_provider.client_id,
158        ),
159        err,
160    )]
161    async fn find_by_subject(
162        &mut self,
163        upstream_oauth_provider: &UpstreamOAuthProvider,
164        subject: &str,
165    ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
166        let res = sqlx::query_as!(
167            LinkLookup,
168            r#"
169                SELECT
170                    upstream_oauth_link_id,
171                    upstream_oauth_provider_id,
172                    user_id,
173                    subject,
174                    human_account_name,
175                    created_at
176                FROM upstream_oauth_links
177                WHERE upstream_oauth_provider_id = $1
178                  AND subject = $2
179            "#,
180            Uuid::from(upstream_oauth_provider.id),
181            subject,
182        )
183        .traced()
184        .fetch_optional(&mut *self.conn)
185        .await?
186        .map(Into::into);
187
188        Ok(res)
189    }
190
191    #[tracing::instrument(
192        name = "db.upstream_oauth_link.add",
193        skip_all,
194        fields(
195            db.query.text,
196            upstream_oauth_link.id,
197            upstream_oauth_link.subject = subject,
198            upstream_oauth_link.human_account_name = human_account_name,
199            %upstream_oauth_provider.id,
200            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
201            %upstream_oauth_provider.client_id,
202        ),
203        err,
204    )]
205    async fn add(
206        &mut self,
207        rng: &mut (dyn RngCore + Send),
208        clock: &dyn Clock,
209        upstream_oauth_provider: &UpstreamOAuthProvider,
210        subject: String,
211        human_account_name: Option<String>,
212    ) -> Result<UpstreamOAuthLink, Self::Error> {
213        let created_at = clock.now();
214        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
215        tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
216
217        sqlx::query!(
218            r#"
219                INSERT INTO upstream_oauth_links (
220                    upstream_oauth_link_id,
221                    upstream_oauth_provider_id,
222                    user_id,
223                    subject,
224                    human_account_name,
225                    created_at
226                ) VALUES ($1, $2, NULL, $3, $4, $5)
227            "#,
228            Uuid::from(id),
229            Uuid::from(upstream_oauth_provider.id),
230            &subject,
231            human_account_name.as_deref(),
232            created_at,
233        )
234        .traced()
235        .execute(&mut *self.conn)
236        .await?;
237
238        Ok(UpstreamOAuthLink {
239            id,
240            provider_id: upstream_oauth_provider.id,
241            user_id: None,
242            subject,
243            human_account_name,
244            created_at,
245        })
246    }
247
248    #[tracing::instrument(
249        name = "db.upstream_oauth_link.associate_to_user",
250        skip_all,
251        fields(
252            db.query.text,
253            %upstream_oauth_link.id,
254            %upstream_oauth_link.subject,
255            %user.id,
256            %user.username,
257        ),
258        err,
259    )]
260    async fn associate_to_user(
261        &mut self,
262        upstream_oauth_link: &UpstreamOAuthLink,
263        user: &User,
264    ) -> Result<(), Self::Error> {
265        sqlx::query!(
266            r#"
267                UPDATE upstream_oauth_links
268                SET user_id = $1
269                WHERE upstream_oauth_link_id = $2
270            "#,
271            Uuid::from(user.id),
272            Uuid::from(upstream_oauth_link.id),
273        )
274        .traced()
275        .execute(&mut *self.conn)
276        .await?;
277
278        Ok(())
279    }
280
281    #[tracing::instrument(
282        name = "db.upstream_oauth_link.list",
283        skip_all,
284        fields(
285            db.query.text,
286        ),
287        err,
288    )]
289    async fn list(
290        &mut self,
291        filter: UpstreamOAuthLinkFilter<'_>,
292        pagination: Pagination,
293    ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
294        let (sql, arguments) = Query::select()
295            .expr_as(
296                Expr::col((
297                    UpstreamOAuthLinks::Table,
298                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
299                )),
300                LinkLookupIden::UpstreamOauthLinkId,
301            )
302            .expr_as(
303                Expr::col((
304                    UpstreamOAuthLinks::Table,
305                    UpstreamOAuthLinks::UpstreamOAuthProviderId,
306                )),
307                LinkLookupIden::UpstreamOauthProviderId,
308            )
309            .expr_as(
310                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
311                LinkLookupIden::UserId,
312            )
313            .expr_as(
314                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
315                LinkLookupIden::Subject,
316            )
317            .expr_as(
318                Expr::col((
319                    UpstreamOAuthLinks::Table,
320                    UpstreamOAuthLinks::HumanAccountName,
321                )),
322                LinkLookupIden::HumanAccountName,
323            )
324            .expr_as(
325                Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
326                LinkLookupIden::CreatedAt,
327            )
328            .from(UpstreamOAuthLinks::Table)
329            .apply_filter(filter)
330            .generate_pagination(
331                (
332                    UpstreamOAuthLinks::Table,
333                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
334                ),
335                pagination,
336            )
337            .build_sqlx(PostgresQueryBuilder);
338
339        let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
340            .traced()
341            .fetch_all(&mut *self.conn)
342            .await?;
343
344        let page = pagination.process(edges).map(UpstreamOAuthLink::from);
345
346        Ok(page)
347    }
348
349    #[tracing::instrument(
350        name = "db.upstream_oauth_link.count",
351        skip_all,
352        fields(
353            db.query.text,
354        ),
355        err,
356    )]
357    async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
358        let (sql, arguments) = Query::select()
359            .expr(
360                Expr::col((
361                    UpstreamOAuthLinks::Table,
362                    UpstreamOAuthLinks::UpstreamOAuthLinkId,
363                ))
364                .count(),
365            )
366            .from(UpstreamOAuthLinks::Table)
367            .apply_filter(filter)
368            .build_sqlx(PostgresQueryBuilder);
369
370        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
371            .traced()
372            .fetch_one(&mut *self.conn)
373            .await?;
374
375        count
376            .try_into()
377            .map_err(DatabaseError::to_invalid_operation)
378    }
379
380    #[tracing::instrument(
381        name = "db.upstream_oauth_link.remove",
382        skip_all,
383        fields(
384            db.query.text,
385            upstream_oauth_link.id,
386            upstream_oauth_link.provider_id,
387            %upstream_oauth_link.subject,
388        ),
389        err,
390    )]
391    async fn remove(
392        &mut self,
393        clock: &dyn Clock,
394        upstream_oauth_link: UpstreamOAuthLink,
395    ) -> Result<(), Self::Error> {
396        // Unlink the authorization sessions first, as they have a foreign key
397        // constraint on the links.
398        let span = tracing::info_span!(
399            "db.upstream_oauth_link.remove.unlink",
400            { DB_QUERY_TEXT } = tracing::field::Empty
401        );
402        sqlx::query!(
403            r#"
404                UPDATE upstream_oauth_authorization_sessions SET
405                    upstream_oauth_link_id = NULL,
406                    unlinked_at = $2
407                WHERE upstream_oauth_link_id = $1
408            "#,
409            Uuid::from(upstream_oauth_link.id),
410            clock.now()
411        )
412        .record(&span)
413        .execute(&mut *self.conn)
414        .instrument(span)
415        .await?;
416
417        // Then delete the link itself
418        let span = tracing::info_span!(
419            "db.upstream_oauth_link.remove.delete",
420            { DB_QUERY_TEXT } = tracing::field::Empty
421        );
422        let res = sqlx::query!(
423            r#"
424                DELETE FROM upstream_oauth_links
425                WHERE upstream_oauth_link_id = $1
426            "#,
427            Uuid::from(upstream_oauth_link.id),
428        )
429        .record(&span)
430        .execute(&mut *self.conn)
431        .instrument(span)
432        .await?;
433
434        DatabaseError::ensure_affected_rows(&res, 1)?;
435
436        Ok(())
437    }
438}