mas_storage_pg/user/
mod.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! A module containing the PostgreSQL implementation of the user-related
//! repositories.
9
10use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13    Clock,
14    user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24    DatabaseError,
25    filter::{Filter, StatementExt},
26    iden::Users,
27    pagination::QueryBuilderExt,
28    tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod session;
36mod terms;
37
38#[cfg(test)]
39mod tests;
40
41pub use self::{
42    email::PgUserEmailRepository, password::PgUserPasswordRepository,
43    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
44    session::PgBrowserSessionRepository, terms::PgUserTermsRepository,
45};
46
47/// An implementation of [`UserRepository`] for a PostgreSQL connection
48pub struct PgUserRepository<'c> {
49    conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
54    pub fn new(conn: &'c mut PgConnection) -> Self {
55        Self { conn }
56    }
57}
58
59mod priv_ {
60    // The enum_def macro generates a public enum, which we don't want, because it
61    // triggers the missing docs warning
62    #![allow(missing_docs)]
63
64    use chrono::{DateTime, Utc};
65    use sea_query::enum_def;
66    use uuid::Uuid;
67
68    #[derive(Debug, Clone, sqlx::FromRow)]
69    #[enum_def]
70    pub(super) struct UserLookup {
71        pub(super) user_id: Uuid,
72        pub(super) username: String,
73        pub(super) created_at: DateTime<Utc>,
74        pub(super) locked_at: Option<DateTime<Utc>>,
75        pub(super) deactivated_at: Option<DateTime<Utc>>,
76        pub(super) can_request_admin: bool,
77    }
78}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83    fn from(value: UserLookup) -> Self {
84        let id = value.user_id.into();
85        Self {
86            id,
87            username: value.username,
88            sub: id.to_string(),
89            created_at: value.created_at,
90            locked_at: value.locked_at,
91            deactivated_at: value.deactivated_at,
92            can_request_admin: value.can_request_admin,
93        }
94    }
95}
96
97impl Filter for UserFilter<'_> {
98    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99        sea_query::Condition::all()
100            .add_option(self.state().map(|state| {
101                match state {
102                    mas_storage::user::UserState::Deactivated => {
103                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
104                    }
105                    mas_storage::user::UserState::Locked => {
106                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
107                    }
108                    mas_storage::user::UserState::Active => {
109                        Expr::col((Users::Table, Users::LockedAt))
110                            .is_null()
111                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
112                    }
113                }
114            }))
115            .add_option(self.can_request_admin().map(|can_request_admin| {
116                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
117            }))
118    }
119}
120
121#[async_trait]
122impl UserRepository for PgUserRepository<'_> {
123    type Error = DatabaseError;
124
125    #[tracing::instrument(
126        name = "db.user.lookup",
127        skip_all,
128        fields(
129            db.query.text,
130            user.id = %id,
131        ),
132        err,
133    )]
134    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
135        let res = sqlx::query_as!(
136            UserLookup,
137            r#"
138                SELECT user_id
139                     , username
140                     , created_at
141                     , locked_at
142                     , deactivated_at
143                     , can_request_admin
144                FROM users
145                WHERE user_id = $1
146            "#,
147            Uuid::from(id),
148        )
149        .traced()
150        .fetch_optional(&mut *self.conn)
151        .await?;
152
153        let Some(res) = res else { return Ok(None) };
154
155        Ok(Some(res.into()))
156    }
157
158    #[tracing::instrument(
159        name = "db.user.find_by_username",
160        skip_all,
161        fields(
162            db.query.text,
163            user.username = username,
164        ),
165        err,
166    )]
167    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
168        // We may have multiple users with the same username, but with a different
169        // casing. In this case, we want to return the one which matches the exact
170        // casing
171        let res = sqlx::query_as!(
172            UserLookup,
173            r#"
174                SELECT user_id
175                     , username
176                     , created_at
177                     , locked_at
178                     , deactivated_at
179                     , can_request_admin
180                FROM users
181                WHERE LOWER(username) = LOWER($1)
182            "#,
183            username,
184        )
185        .traced()
186        .fetch_all(&mut *self.conn)
187        .await?;
188
189        match &res[..] {
190            // Happy path: there is only one user matching the username…
191            [user] => Ok(Some(user.clone().into())),
192            // …or none.
193            [] => Ok(None),
194            list => {
195                // If there are multiple users with the same username, we want to
196                // return the one which matches the exact casing
197                if let Some(user) = list.iter().find(|user| user.username == username) {
198                    Ok(Some(user.clone().into()))
199                } else {
200                    // If none match exactly, we prefer to return nothing
201                    Ok(None)
202                }
203            }
204        }
205    }
206
207    #[tracing::instrument(
208        name = "db.user.add",
209        skip_all,
210        fields(
211            db.query.text,
212            user.username = username,
213            user.id,
214        ),
215        err,
216    )]
217    async fn add(
218        &mut self,
219        rng: &mut (dyn RngCore + Send),
220        clock: &dyn Clock,
221        username: String,
222    ) -> Result<User, Self::Error> {
223        let created_at = clock.now();
224        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
225        tracing::Span::current().record("user.id", tracing::field::display(id));
226
227        let res = sqlx::query!(
228            r#"
229                INSERT INTO users (user_id, username, created_at)
230                VALUES ($1, $2, $3)
231                ON CONFLICT (username) DO NOTHING
232            "#,
233            Uuid::from(id),
234            username,
235            created_at,
236        )
237        .traced()
238        .execute(&mut *self.conn)
239        .await?;
240
241        // If the user already exists, want to return an error but not poison the
242        // transaction
243        DatabaseError::ensure_affected_rows(&res, 1)?;
244
245        Ok(User {
246            id,
247            username,
248            sub: id.to_string(),
249            created_at,
250            locked_at: None,
251            deactivated_at: None,
252            can_request_admin: false,
253        })
254    }
255
256    #[tracing::instrument(
257        name = "db.user.exists",
258        skip_all,
259        fields(
260            db.query.text,
261            user.username = username,
262        ),
263        err,
264    )]
265    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
266        let exists = sqlx::query_scalar!(
267            r#"
268                SELECT EXISTS(
269                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
270                ) AS "exists!"
271            "#,
272            username
273        )
274        .traced()
275        .fetch_one(&mut *self.conn)
276        .await?;
277
278        Ok(exists)
279    }
280
281    #[tracing::instrument(
282        name = "db.user.lock",
283        skip_all,
284        fields(
285            db.query.text,
286            %user.id,
287        ),
288        err,
289    )]
290    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
291        if user.locked_at.is_some() {
292            return Ok(user);
293        }
294
295        let locked_at = clock.now();
296        let res = sqlx::query!(
297            r#"
298                UPDATE users
299                SET locked_at = $1
300                WHERE user_id = $2
301            "#,
302            locked_at,
303            Uuid::from(user.id),
304        )
305        .traced()
306        .execute(&mut *self.conn)
307        .await?;
308
309        DatabaseError::ensure_affected_rows(&res, 1)?;
310
311        user.locked_at = Some(locked_at);
312
313        Ok(user)
314    }
315
316    #[tracing::instrument(
317        name = "db.user.unlock",
318        skip_all,
319        fields(
320            db.query.text,
321            %user.id,
322        ),
323        err,
324    )]
325    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
326        if user.locked_at.is_none() {
327            return Ok(user);
328        }
329
330        let res = sqlx::query!(
331            r#"
332                UPDATE users
333                SET locked_at = NULL
334                WHERE user_id = $1
335            "#,
336            Uuid::from(user.id),
337        )
338        .traced()
339        .execute(&mut *self.conn)
340        .await?;
341
342        DatabaseError::ensure_affected_rows(&res, 1)?;
343
344        user.locked_at = None;
345
346        Ok(user)
347    }
348
349    #[tracing::instrument(
350        name = "db.user.deactivate",
351        skip_all,
352        fields(
353            db.query.text,
354            %user.id,
355        ),
356        err,
357    )]
358    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
359        if user.deactivated_at.is_some() {
360            return Ok(user);
361        }
362
363        let deactivated_at = clock.now();
364        let res = sqlx::query!(
365            r#"
366                UPDATE users
367                SET deactivated_at = $2
368                WHERE user_id = $1
369                  AND deactivated_at IS NULL
370            "#,
371            Uuid::from(user.id),
372            deactivated_at,
373        )
374        .traced()
375        .execute(&mut *self.conn)
376        .await?;
377
378        DatabaseError::ensure_affected_rows(&res, 1)?;
379
380        user.deactivated_at = Some(user.created_at);
381
382        Ok(user)
383    }
384
385    #[tracing::instrument(
386        name = "db.user.set_can_request_admin",
387        skip_all,
388        fields(
389            db.query.text,
390            %user.id,
391            user.can_request_admin = can_request_admin,
392        ),
393        err,
394    )]
395    async fn set_can_request_admin(
396        &mut self,
397        mut user: User,
398        can_request_admin: bool,
399    ) -> Result<User, Self::Error> {
400        let res = sqlx::query!(
401            r#"
402                UPDATE users
403                SET can_request_admin = $2
404                WHERE user_id = $1
405            "#,
406            Uuid::from(user.id),
407            can_request_admin,
408        )
409        .traced()
410        .execute(&mut *self.conn)
411        .await?;
412
413        DatabaseError::ensure_affected_rows(&res, 1)?;
414
415        user.can_request_admin = can_request_admin;
416
417        Ok(user)
418    }
419
420    #[tracing::instrument(
421        name = "db.user.list",
422        skip_all,
423        fields(
424            db.query.text,
425        ),
426        err,
427    )]
428    async fn list(
429        &mut self,
430        filter: UserFilter<'_>,
431        pagination: mas_storage::Pagination,
432    ) -> Result<mas_storage::Page<User>, Self::Error> {
433        let (sql, arguments) = Query::select()
434            .expr_as(
435                Expr::col((Users::Table, Users::UserId)),
436                UserLookupIden::UserId,
437            )
438            .expr_as(
439                Expr::col((Users::Table, Users::Username)),
440                UserLookupIden::Username,
441            )
442            .expr_as(
443                Expr::col((Users::Table, Users::CreatedAt)),
444                UserLookupIden::CreatedAt,
445            )
446            .expr_as(
447                Expr::col((Users::Table, Users::LockedAt)),
448                UserLookupIden::LockedAt,
449            )
450            .expr_as(
451                Expr::col((Users::Table, Users::DeactivatedAt)),
452                UserLookupIden::DeactivatedAt,
453            )
454            .expr_as(
455                Expr::col((Users::Table, Users::CanRequestAdmin)),
456                UserLookupIden::CanRequestAdmin,
457            )
458            .from(Users::Table)
459            .apply_filter(filter)
460            .generate_pagination((Users::Table, Users::UserId), pagination)
461            .build_sqlx(PostgresQueryBuilder);
462
463        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
464            .traced()
465            .fetch_all(&mut *self.conn)
466            .await?;
467
468        let page = pagination.process(edges).map(User::from);
469
470        Ok(page)
471    }
472
473    #[tracing::instrument(
474        name = "db.user.count",
475        skip_all,
476        fields(
477            db.query.text,
478        ),
479        err,
480    )]
481    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
482        let (sql, arguments) = Query::select()
483            .expr(Expr::col((Users::Table, Users::UserId)).count())
484            .from(Users::Table)
485            .apply_filter(filter)
486            .build_sqlx(PostgresQueryBuilder);
487
488        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
489            .traced()
490            .fetch_one(&mut *self.conn)
491            .await?;
492
493        count
494            .try_into()
495            .map_err(DatabaseError::to_invalid_operation)
496    }
497
498    #[tracing::instrument(
499        name = "db.user.acquire_lock_for_sync",
500        skip_all,
501        fields(
502            db.query.text,
503            user.id = %user.id,
504        ),
505        err,
506    )]
507    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
508        // XXX: this lock isn't stictly scoped to users, but as we don't use many
509        // postgres advisory locks, it's fine for now. Later on, we could use row-level
510        // locks to make sure we don't get into trouble
511
512        // Convert the user ID to a u128 and grab the lower 64 bits
513        // As this includes 64bit of the random part of the ULID, it should be random
514        // enough to not collide
515        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
516
517        // Use a PG advisory lock, which will be released when the transaction is
518        // committed or rolled back
519        sqlx::query!(
520            r#"
521                SELECT pg_advisory_xact_lock($1)
522            "#,
523            lock_id,
524        )
525        .traced()
526        .execute(&mut *self.conn)
527        .await?;
528
529        Ok(())
530    }
531}