1use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13 Clock,
14 user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::Users,
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod session;
36mod terms;
37
38#[cfg(test)]
39mod tests;
40
41pub use self::{
42 email::PgUserEmailRepository, password::PgUserPasswordRepository,
43 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
44 session::PgBrowserSessionRepository, terms::PgUserTermsRepository,
45};
46
47pub struct PgUserRepository<'c> {
49 conn: &'c mut PgConnection,
50}
51
52impl<'c> PgUserRepository<'c> {
53 pub fn new(conn: &'c mut PgConnection) -> Self {
55 Self { conn }
56 }
57}
58
/// Holds the row type used to read users out of the `users` table.
///
/// The whole module allows `missing_docs` — presumably because the items
/// generated by `#[enum_def]` (such as `UserLookupIden`) carry no
/// documentation of their own; TODO confirm.
mod priv_ {
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// One row of the `users` table, as selected by the repository queries.
    ///
    /// `#[enum_def]` generates the `UserLookupIden` enum. Its variants are
    /// used as column aliases in the dynamically built listing query so
    /// that `sqlx::FromRow` can map the selected columns back onto these
    /// fields by name.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // `users.user_id`; converted into the user's ULID.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // Set while the user is locked, `NULL` otherwise.
        pub(super) locked_at: Option<DateTime<Utc>>,
        // Set once the user has been deactivated, `NULL` otherwise.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
    }
}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83 fn from(value: UserLookup) -> Self {
84 let id = value.user_id.into();
85 Self {
86 id,
87 username: value.username,
88 sub: id.to_string(),
89 created_at: value.created_at,
90 locked_at: value.locked_at,
91 deactivated_at: value.deactivated_at,
92 can_request_admin: value.can_request_admin,
93 }
94 }
95}
96
97impl Filter for UserFilter<'_> {
98 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
99 sea_query::Condition::all()
100 .add_option(self.state().map(|state| {
101 match state {
102 mas_storage::user::UserState::Deactivated => {
103 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
104 }
105 mas_storage::user::UserState::Locked => {
106 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
107 }
108 mas_storage::user::UserState::Active => {
109 Expr::col((Users::Table, Users::LockedAt))
110 .is_null()
111 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
112 }
113 }
114 }))
115 .add_option(self.can_request_admin().map(|can_request_admin| {
116 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
117 }))
118 }
119}
120
121#[async_trait]
122impl UserRepository for PgUserRepository<'_> {
123 type Error = DatabaseError;
124
125 #[tracing::instrument(
126 name = "db.user.lookup",
127 skip_all,
128 fields(
129 db.query.text,
130 user.id = %id,
131 ),
132 err,
133 )]
134 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
135 let res = sqlx::query_as!(
136 UserLookup,
137 r#"
138 SELECT user_id
139 , username
140 , created_at
141 , locked_at
142 , deactivated_at
143 , can_request_admin
144 FROM users
145 WHERE user_id = $1
146 "#,
147 Uuid::from(id),
148 )
149 .traced()
150 .fetch_optional(&mut *self.conn)
151 .await?;
152
153 let Some(res) = res else { return Ok(None) };
154
155 Ok(Some(res.into()))
156 }
157
158 #[tracing::instrument(
159 name = "db.user.find_by_username",
160 skip_all,
161 fields(
162 db.query.text,
163 user.username = username,
164 ),
165 err,
166 )]
167 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
168 let res = sqlx::query_as!(
172 UserLookup,
173 r#"
174 SELECT user_id
175 , username
176 , created_at
177 , locked_at
178 , deactivated_at
179 , can_request_admin
180 FROM users
181 WHERE LOWER(username) = LOWER($1)
182 "#,
183 username,
184 )
185 .traced()
186 .fetch_all(&mut *self.conn)
187 .await?;
188
189 match &res[..] {
190 [user] => Ok(Some(user.clone().into())),
192 [] => Ok(None),
194 list => {
195 if let Some(user) = list.iter().find(|user| user.username == username) {
198 Ok(Some(user.clone().into()))
199 } else {
200 Ok(None)
202 }
203 }
204 }
205 }
206
207 #[tracing::instrument(
208 name = "db.user.add",
209 skip_all,
210 fields(
211 db.query.text,
212 user.username = username,
213 user.id,
214 ),
215 err,
216 )]
217 async fn add(
218 &mut self,
219 rng: &mut (dyn RngCore + Send),
220 clock: &dyn Clock,
221 username: String,
222 ) -> Result<User, Self::Error> {
223 let created_at = clock.now();
224 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
225 tracing::Span::current().record("user.id", tracing::field::display(id));
226
227 let res = sqlx::query!(
228 r#"
229 INSERT INTO users (user_id, username, created_at)
230 VALUES ($1, $2, $3)
231 ON CONFLICT (username) DO NOTHING
232 "#,
233 Uuid::from(id),
234 username,
235 created_at,
236 )
237 .traced()
238 .execute(&mut *self.conn)
239 .await?;
240
241 DatabaseError::ensure_affected_rows(&res, 1)?;
244
245 Ok(User {
246 id,
247 username,
248 sub: id.to_string(),
249 created_at,
250 locked_at: None,
251 deactivated_at: None,
252 can_request_admin: false,
253 })
254 }
255
256 #[tracing::instrument(
257 name = "db.user.exists",
258 skip_all,
259 fields(
260 db.query.text,
261 user.username = username,
262 ),
263 err,
264 )]
265 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
266 let exists = sqlx::query_scalar!(
267 r#"
268 SELECT EXISTS(
269 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
270 ) AS "exists!"
271 "#,
272 username
273 )
274 .traced()
275 .fetch_one(&mut *self.conn)
276 .await?;
277
278 Ok(exists)
279 }
280
281 #[tracing::instrument(
282 name = "db.user.lock",
283 skip_all,
284 fields(
285 db.query.text,
286 %user.id,
287 ),
288 err,
289 )]
290 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
291 if user.locked_at.is_some() {
292 return Ok(user);
293 }
294
295 let locked_at = clock.now();
296 let res = sqlx::query!(
297 r#"
298 UPDATE users
299 SET locked_at = $1
300 WHERE user_id = $2
301 "#,
302 locked_at,
303 Uuid::from(user.id),
304 )
305 .traced()
306 .execute(&mut *self.conn)
307 .await?;
308
309 DatabaseError::ensure_affected_rows(&res, 1)?;
310
311 user.locked_at = Some(locked_at);
312
313 Ok(user)
314 }
315
316 #[tracing::instrument(
317 name = "db.user.unlock",
318 skip_all,
319 fields(
320 db.query.text,
321 %user.id,
322 ),
323 err,
324 )]
325 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
326 if user.locked_at.is_none() {
327 return Ok(user);
328 }
329
330 let res = sqlx::query!(
331 r#"
332 UPDATE users
333 SET locked_at = NULL
334 WHERE user_id = $1
335 "#,
336 Uuid::from(user.id),
337 )
338 .traced()
339 .execute(&mut *self.conn)
340 .await?;
341
342 DatabaseError::ensure_affected_rows(&res, 1)?;
343
344 user.locked_at = None;
345
346 Ok(user)
347 }
348
349 #[tracing::instrument(
350 name = "db.user.deactivate",
351 skip_all,
352 fields(
353 db.query.text,
354 %user.id,
355 ),
356 err,
357 )]
358 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
359 if user.deactivated_at.is_some() {
360 return Ok(user);
361 }
362
363 let deactivated_at = clock.now();
364 let res = sqlx::query!(
365 r#"
366 UPDATE users
367 SET deactivated_at = $2
368 WHERE user_id = $1
369 AND deactivated_at IS NULL
370 "#,
371 Uuid::from(user.id),
372 deactivated_at,
373 )
374 .traced()
375 .execute(&mut *self.conn)
376 .await?;
377
378 DatabaseError::ensure_affected_rows(&res, 1)?;
379
380 user.deactivated_at = Some(user.created_at);
381
382 Ok(user)
383 }
384
385 #[tracing::instrument(
386 name = "db.user.set_can_request_admin",
387 skip_all,
388 fields(
389 db.query.text,
390 %user.id,
391 user.can_request_admin = can_request_admin,
392 ),
393 err,
394 )]
395 async fn set_can_request_admin(
396 &mut self,
397 mut user: User,
398 can_request_admin: bool,
399 ) -> Result<User, Self::Error> {
400 let res = sqlx::query!(
401 r#"
402 UPDATE users
403 SET can_request_admin = $2
404 WHERE user_id = $1
405 "#,
406 Uuid::from(user.id),
407 can_request_admin,
408 )
409 .traced()
410 .execute(&mut *self.conn)
411 .await?;
412
413 DatabaseError::ensure_affected_rows(&res, 1)?;
414
415 user.can_request_admin = can_request_admin;
416
417 Ok(user)
418 }
419
420 #[tracing::instrument(
421 name = "db.user.list",
422 skip_all,
423 fields(
424 db.query.text,
425 ),
426 err,
427 )]
428 async fn list(
429 &mut self,
430 filter: UserFilter<'_>,
431 pagination: mas_storage::Pagination,
432 ) -> Result<mas_storage::Page<User>, Self::Error> {
433 let (sql, arguments) = Query::select()
434 .expr_as(
435 Expr::col((Users::Table, Users::UserId)),
436 UserLookupIden::UserId,
437 )
438 .expr_as(
439 Expr::col((Users::Table, Users::Username)),
440 UserLookupIden::Username,
441 )
442 .expr_as(
443 Expr::col((Users::Table, Users::CreatedAt)),
444 UserLookupIden::CreatedAt,
445 )
446 .expr_as(
447 Expr::col((Users::Table, Users::LockedAt)),
448 UserLookupIden::LockedAt,
449 )
450 .expr_as(
451 Expr::col((Users::Table, Users::DeactivatedAt)),
452 UserLookupIden::DeactivatedAt,
453 )
454 .expr_as(
455 Expr::col((Users::Table, Users::CanRequestAdmin)),
456 UserLookupIden::CanRequestAdmin,
457 )
458 .from(Users::Table)
459 .apply_filter(filter)
460 .generate_pagination((Users::Table, Users::UserId), pagination)
461 .build_sqlx(PostgresQueryBuilder);
462
463 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
464 .traced()
465 .fetch_all(&mut *self.conn)
466 .await?;
467
468 let page = pagination.process(edges).map(User::from);
469
470 Ok(page)
471 }
472
473 #[tracing::instrument(
474 name = "db.user.count",
475 skip_all,
476 fields(
477 db.query.text,
478 ),
479 err,
480 )]
481 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
482 let (sql, arguments) = Query::select()
483 .expr(Expr::col((Users::Table, Users::UserId)).count())
484 .from(Users::Table)
485 .apply_filter(filter)
486 .build_sqlx(PostgresQueryBuilder);
487
488 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
489 .traced()
490 .fetch_one(&mut *self.conn)
491 .await?;
492
493 count
494 .try_into()
495 .map_err(DatabaseError::to_invalid_operation)
496 }
497
498 #[tracing::instrument(
499 name = "db.user.acquire_lock_for_sync",
500 skip_all,
501 fields(
502 db.query.text,
503 user.id = %user.id,
504 ),
505 err,
506 )]
507 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
508 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
516
517 sqlx::query!(
520 r#"
521 SELECT pg_advisory_xact_lock($1)
522 "#,
523 lock_id,
524 )
525 .traced()
526 .execute(&mut *self.conn)
527 .await?;
528
529 Ok(())
530 }
531}