mas_storage_pg/user/
email.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11    UserRegistration,
12};
13use mas_storage::{
14    Clock, Page, Pagination,
15    user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError,
26    filter::{Filter, StatementExt},
27    iden::UserEmails,
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
33pub struct PgUserEmailRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38    /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48    user_email_id: Uuid,
49    user_id: Uuid,
50    email: String,
51    created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55    fn from(e: UserEmailLookup) -> UserEmail {
56        UserEmail {
57            id: e.user_email_id.into(),
58            user_id: e.user_id.into(),
59            email: e.email,
60            created_at: e.created_at,
61        }
62    }
63}
64
65struct UserEmailAuthenticationLookup {
66    user_email_authentication_id: Uuid,
67    user_session_id: Option<Uuid>,
68    user_registration_id: Option<Uuid>,
69    email: String,
70    created_at: DateTime<Utc>,
71    completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75    fn from(value: UserEmailAuthenticationLookup) -> Self {
76        UserEmailAuthentication {
77            id: value.user_email_authentication_id.into(),
78            user_session_id: value.user_session_id.map(Ulid::from),
79            user_registration_id: value.user_registration_id.map(Ulid::from),
80            email: value.email,
81            created_at: value.created_at,
82            completed_at: value.completed_at,
83        }
84    }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88    user_email_authentication_code_id: Uuid,
89    user_email_authentication_id: Uuid,
90    code: String,
91    created_at: DateTime<Utc>,
92    expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96    fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97        UserEmailAuthenticationCode {
98            id: value.user_email_authentication_code_id.into(),
99            user_email_authentication_id: value.user_email_authentication_id.into(),
100            code: value.code,
101            created_at: value.created_at,
102            expires_at: value.expires_at,
103        }
104    }
105}
106
107impl Filter for UserEmailFilter<'_> {
108    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109        sea_query::Condition::all()
110            .add_option(self.user().map(|user| {
111                Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112            }))
113            .add_option(
114                self.email()
115                    .map(|email| Expr::col((UserEmails::Table, UserEmails::Email)).eq(email)),
116            )
117    }
118}
119
120#[async_trait]
121impl UserEmailRepository for PgUserEmailRepository<'_> {
122    type Error = DatabaseError;
123
124    #[tracing::instrument(
125        name = "db.user_email.lookup",
126        skip_all,
127        fields(
128            db.query.text,
129            user_email.id = %id,
130        ),
131        err,
132    )]
133    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
134        let res = sqlx::query_as!(
135            UserEmailLookup,
136            r#"
137                SELECT user_email_id
138                     , user_id
139                     , email
140                     , created_at
141                FROM user_emails
142
143                WHERE user_email_id = $1
144            "#,
145            Uuid::from(id),
146        )
147        .traced()
148        .fetch_optional(&mut *self.conn)
149        .await?;
150
151        let Some(user_email) = res else {
152            return Ok(None);
153        };
154
155        Ok(Some(user_email.into()))
156    }
157
158    #[tracing::instrument(
159        name = "db.user_email.find",
160        skip_all,
161        fields(
162            db.query.text,
163            %user.id,
164            user_email.email = email,
165        ),
166        err,
167    )]
168    async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
169        let res = sqlx::query_as!(
170            UserEmailLookup,
171            r#"
172                SELECT user_email_id
173                     , user_id
174                     , email
175                     , created_at
176                FROM user_emails
177
178                WHERE user_id = $1 AND email = $2
179            "#,
180            Uuid::from(user.id),
181            email,
182        )
183        .traced()
184        .fetch_optional(&mut *self.conn)
185        .await?;
186
187        let Some(user_email) = res else {
188            return Ok(None);
189        };
190
191        Ok(Some(user_email.into()))
192    }
193
194    #[tracing::instrument(
195        name = "db.user_email.find_by_email",
196        skip_all,
197        fields(
198            db.query.text,
199            user_email.email = email,
200        ),
201        err,
202    )]
203    async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
204        let res = sqlx::query_as!(
205            UserEmailLookup,
206            r#"
207                SELECT user_email_id
208                     , user_id
209                     , email
210                     , created_at
211                FROM user_emails
212                WHERE email = $1
213            "#,
214            email,
215        )
216        .traced()
217        .fetch_all(&mut *self.conn)
218        .await?;
219
220        if res.len() != 1 {
221            return Ok(None);
222        }
223
224        let Some(user_email) = res.into_iter().next() else {
225            return Ok(None);
226        };
227
228        Ok(Some(user_email.into()))
229    }
230
231    #[tracing::instrument(
232        name = "db.user_email.all",
233        skip_all,
234        fields(
235            db.query.text,
236            %user.id,
237        ),
238        err,
239    )]
240    async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
241        let res = sqlx::query_as!(
242            UserEmailLookup,
243            r#"
244                SELECT user_email_id
245                     , user_id
246                     , email
247                     , created_at
248                FROM user_emails
249
250                WHERE user_id = $1
251
252                ORDER BY email ASC
253            "#,
254            Uuid::from(user.id),
255        )
256        .traced()
257        .fetch_all(&mut *self.conn)
258        .await?;
259
260        Ok(res.into_iter().map(Into::into).collect())
261    }
262
263    #[tracing::instrument(
264        name = "db.user_email.list",
265        skip_all,
266        fields(
267            db.query.text,
268        ),
269        err,
270    )]
271    async fn list(
272        &mut self,
273        filter: UserEmailFilter<'_>,
274        pagination: Pagination,
275    ) -> Result<Page<UserEmail>, DatabaseError> {
276        let (sql, arguments) = Query::select()
277            .expr_as(
278                Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
279                UserEmailLookupIden::UserEmailId,
280            )
281            .expr_as(
282                Expr::col((UserEmails::Table, UserEmails::UserId)),
283                UserEmailLookupIden::UserId,
284            )
285            .expr_as(
286                Expr::col((UserEmails::Table, UserEmails::Email)),
287                UserEmailLookupIden::Email,
288            )
289            .expr_as(
290                Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
291                UserEmailLookupIden::CreatedAt,
292            )
293            .from(UserEmails::Table)
294            .apply_filter(filter)
295            .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
296            .build_sqlx(PostgresQueryBuilder);
297
298        let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
299            .traced()
300            .fetch_all(&mut *self.conn)
301            .await?;
302
303        let page = pagination.process(edges).map(UserEmail::from);
304
305        Ok(page)
306    }
307
308    #[tracing::instrument(
309        name = "db.user_email.count",
310        skip_all,
311        fields(
312            db.query.text,
313        ),
314        err,
315    )]
316    async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
317        let (sql, arguments) = Query::select()
318            .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
319            .from(UserEmails::Table)
320            .apply_filter(filter)
321            .build_sqlx(PostgresQueryBuilder);
322
323        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
324            .traced()
325            .fetch_one(&mut *self.conn)
326            .await?;
327
328        count
329            .try_into()
330            .map_err(DatabaseError::to_invalid_operation)
331    }
332
333    #[tracing::instrument(
334        name = "db.user_email.add",
335        skip_all,
336        fields(
337            db.query.text,
338            %user.id,
339            user_email.id,
340            user_email.email = email,
341        ),
342        err,
343    )]
344    async fn add(
345        &mut self,
346        rng: &mut (dyn RngCore + Send),
347        clock: &dyn Clock,
348        user: &User,
349        email: String,
350    ) -> Result<UserEmail, Self::Error> {
351        let created_at = clock.now();
352        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
353        tracing::Span::current().record("user_email.id", tracing::field::display(id));
354
355        sqlx::query!(
356            r#"
357                INSERT INTO user_emails (user_email_id, user_id, email, created_at)
358                VALUES ($1, $2, $3, $4)
359            "#,
360            Uuid::from(id),
361            Uuid::from(user.id),
362            &email,
363            created_at,
364        )
365        .traced()
366        .execute(&mut *self.conn)
367        .await?;
368
369        Ok(UserEmail {
370            id,
371            user_id: user.id,
372            email,
373            created_at,
374        })
375    }
376
377    #[tracing::instrument(
378        name = "db.user_email.remove",
379        skip_all,
380        fields(
381            db.query.text,
382            user.id = %user_email.user_id,
383            %user_email.id,
384            %user_email.email,
385        ),
386        err,
387    )]
388    async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
389        let res = sqlx::query!(
390            r#"
391                DELETE FROM user_emails
392                WHERE user_email_id = $1
393            "#,
394            Uuid::from(user_email.id),
395        )
396        .traced()
397        .execute(&mut *self.conn)
398        .await?;
399
400        DatabaseError::ensure_affected_rows(&res, 1)?;
401
402        Ok(())
403    }
404
405    #[tracing::instrument(
406        name = "db.user_email.remove_bulk",
407        skip_all,
408        fields(
409            db.query.text,
410        ),
411        err,
412    )]
413    async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
414        let (sql, arguments) = Query::delete()
415            .from_table(UserEmails::Table)
416            .apply_filter(filter)
417            .build_sqlx(PostgresQueryBuilder);
418
419        let res = sqlx::query_with(&sql, arguments)
420            .traced()
421            .execute(&mut *self.conn)
422            .await?;
423
424        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
425    }
426
427    #[tracing::instrument(
428        name = "db.user_email.add_authentication_for_session",
429        skip_all,
430        fields(
431            db.query.text,
432            %session.id,
433            user_email_authentication.id,
434            user_email_authentication.email = email,
435        ),
436        err,
437    )]
438    async fn add_authentication_for_session(
439        &mut self,
440        rng: &mut (dyn RngCore + Send),
441        clock: &dyn Clock,
442        email: String,
443        session: &BrowserSession,
444    ) -> Result<UserEmailAuthentication, Self::Error> {
445        let created_at = clock.now();
446        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
447        tracing::Span::current()
448            .record("user_email_authentication.id", tracing::field::display(id));
449
450        sqlx::query!(
451            r#"
452                INSERT INTO user_email_authentications
453                  ( user_email_authentication_id
454                  , user_session_id
455                  , email
456                  , created_at
457                  )
458                VALUES ($1, $2, $3, $4)
459            "#,
460            Uuid::from(id),
461            Uuid::from(session.id),
462            &email,
463            created_at,
464        )
465        .traced()
466        .execute(&mut *self.conn)
467        .await?;
468
469        Ok(UserEmailAuthentication {
470            id,
471            user_session_id: Some(session.id),
472            user_registration_id: None,
473            email,
474            created_at,
475            completed_at: None,
476        })
477    }
478
479    #[tracing::instrument(
480        name = "db.user_email.add_authentication_for_registration",
481        skip_all,
482        fields(
483            db.query.text,
484            %user_registration.id,
485            user_email_authentication.id,
486            user_email_authentication.email = email,
487        ),
488        err,
489    )]
490    async fn add_authentication_for_registration(
491        &mut self,
492        rng: &mut (dyn RngCore + Send),
493        clock: &dyn Clock,
494        email: String,
495        user_registration: &UserRegistration,
496    ) -> Result<UserEmailAuthentication, Self::Error> {
497        let created_at = clock.now();
498        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
499        tracing::Span::current()
500            .record("user_email_authentication.id", tracing::field::display(id));
501
502        sqlx::query!(
503            r#"
504                INSERT INTO user_email_authentications
505                  ( user_email_authentication_id
506                  , user_registration_id
507                  , email
508                  , created_at
509                  )
510                VALUES ($1, $2, $3, $4)
511            "#,
512            Uuid::from(id),
513            Uuid::from(user_registration.id),
514            &email,
515            created_at,
516        )
517        .traced()
518        .execute(&mut *self.conn)
519        .await?;
520
521        Ok(UserEmailAuthentication {
522            id,
523            user_session_id: None,
524            user_registration_id: Some(user_registration.id),
525            email,
526            created_at,
527            completed_at: None,
528        })
529    }
530
531    #[tracing::instrument(
532        name = "db.user_email.add_authentication_code",
533        skip_all,
534        fields(
535            db.query.text,
536            %user_email_authentication.id,
537            %user_email_authentication.email,
538            user_email_authentication_code.id,
539            user_email_authentication_code.code = code,
540        ),
541        err,
542    )]
543    async fn add_authentication_code(
544        &mut self,
545        rng: &mut (dyn RngCore + Send),
546        clock: &dyn Clock,
547        duration: chrono::Duration,
548        user_email_authentication: &UserEmailAuthentication,
549        code: String,
550    ) -> Result<UserEmailAuthenticationCode, Self::Error> {
551        let created_at = clock.now();
552        let expires_at = created_at + duration;
553        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
554        tracing::Span::current().record(
555            "user_email_authentication_code.id",
556            tracing::field::display(id),
557        );
558
559        sqlx::query!(
560            r#"
561                INSERT INTO user_email_authentication_codes
562                  ( user_email_authentication_code_id
563                  , user_email_authentication_id
564                  , code
565                  , created_at
566                  , expires_at
567                  )
568                VALUES ($1, $2, $3, $4, $5)
569            "#,
570            Uuid::from(id),
571            Uuid::from(user_email_authentication.id),
572            &code,
573            created_at,
574            expires_at,
575        )
576        .traced()
577        .execute(&mut *self.conn)
578        .await?;
579
580        Ok(UserEmailAuthenticationCode {
581            id,
582            user_email_authentication_id: user_email_authentication.id,
583            code,
584            created_at,
585            expires_at,
586        })
587    }
588
589    #[tracing::instrument(
590        name = "db.user_email.lookup_authentication",
591        skip_all,
592        fields(
593            db.query.text,
594            user_email_authentication.id = %id,
595        ),
596        err,
597    )]
598    async fn lookup_authentication(
599        &mut self,
600        id: Ulid,
601    ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
602        let res = sqlx::query_as!(
603            UserEmailAuthenticationLookup,
604            r#"
605                SELECT user_email_authentication_id
606                     , user_session_id
607                     , user_registration_id
608                     , email
609                     , created_at
610                     , completed_at
611                FROM user_email_authentications
612                WHERE user_email_authentication_id = $1
613            "#,
614            Uuid::from(id),
615        )
616        .traced()
617        .fetch_optional(&mut *self.conn)
618        .await?;
619
620        Ok(res.map(UserEmailAuthentication::from))
621    }
622
623    #[tracing::instrument(
624        name = "db.user_email.find_authentication_by_code",
625        skip_all,
626        fields(
627            db.query.text,
628            %authentication.id,
629            user_email_authentication_code.code = code,
630        ),
631        err,
632    )]
633    async fn find_authentication_code(
634        &mut self,
635        authentication: &UserEmailAuthentication,
636        code: &str,
637    ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
638        let res = sqlx::query_as!(
639            UserEmailAuthenticationCodeLookup,
640            r#"
641                SELECT user_email_authentication_code_id
642                     , user_email_authentication_id
643                     , code
644                     , created_at
645                     , expires_at
646                FROM user_email_authentication_codes
647                WHERE user_email_authentication_id = $1
648                  AND code = $2
649            "#,
650            Uuid::from(authentication.id),
651            code,
652        )
653        .traced()
654        .fetch_optional(&mut *self.conn)
655        .await?;
656
657        Ok(res.map(UserEmailAuthenticationCode::from))
658    }
659
660    #[tracing::instrument(
661        name = "db.user_email.complete_email_authentication",
662        skip_all,
663        fields(
664            db.query.text,
665            %user_email_authentication.id,
666            %user_email_authentication.email,
667            %user_email_authentication_code.id,
668            %user_email_authentication_code.code,
669        ),
670        err,
671    )]
672    async fn complete_authentication(
673        &mut self,
674        clock: &dyn Clock,
675        mut user_email_authentication: UserEmailAuthentication,
676        user_email_authentication_code: &UserEmailAuthenticationCode,
677    ) -> Result<UserEmailAuthentication, Self::Error> {
678        // We technically don't use the authentication code here (other than
679        // recording it in the span), but this is to make sure the caller has
680        // fetched one before calling this
681        let completed_at = clock.now();
682
683        // We'll assume the caller has checked that completed_at is None, so in case
684        // they haven't, the update will not affect any rows, which will raise
685        // an error
686        let res = sqlx::query!(
687            r#"
688                UPDATE user_email_authentications
689                SET completed_at = $2
690                WHERE user_email_authentication_id = $1
691                  AND completed_at IS NULL
692            "#,
693            Uuid::from(user_email_authentication.id),
694            completed_at,
695        )
696        .traced()
697        .execute(&mut *self.conn)
698        .await?;
699
700        DatabaseError::ensure_affected_rows(&res, 1)?;
701
702        user_email_authentication.completed_at = Some(completed_at);
703        Ok(user_email_authentication)
704    }
705}