1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11 UserRegistration,
12};
13use mas_storage::{
14 Clock, Page, Pagination,
15 user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError,
26 filter::{Filter, StatementExt},
27 iden::UserEmails,
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgUserEmailRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48 user_email_id: Uuid,
49 user_id: Uuid,
50 email: String,
51 created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55 fn from(e: UserEmailLookup) -> UserEmail {
56 UserEmail {
57 id: e.user_email_id.into(),
58 user_id: e.user_id.into(),
59 email: e.email,
60 created_at: e.created_at,
61 }
62 }
63}
64
65struct UserEmailAuthenticationLookup {
66 user_email_authentication_id: Uuid,
67 user_session_id: Option<Uuid>,
68 user_registration_id: Option<Uuid>,
69 email: String,
70 created_at: DateTime<Utc>,
71 completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75 fn from(value: UserEmailAuthenticationLookup) -> Self {
76 UserEmailAuthentication {
77 id: value.user_email_authentication_id.into(),
78 user_session_id: value.user_session_id.map(Ulid::from),
79 user_registration_id: value.user_registration_id.map(Ulid::from),
80 email: value.email,
81 created_at: value.created_at,
82 completed_at: value.completed_at,
83 }
84 }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88 user_email_authentication_code_id: Uuid,
89 user_email_authentication_id: Uuid,
90 code: String,
91 created_at: DateTime<Utc>,
92 expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96 fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97 UserEmailAuthenticationCode {
98 id: value.user_email_authentication_code_id.into(),
99 user_email_authentication_id: value.user_email_authentication_id.into(),
100 code: value.code,
101 created_at: value.created_at,
102 expires_at: value.expires_at,
103 }
104 }
105}
106
107impl Filter for UserEmailFilter<'_> {
108 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109 sea_query::Condition::all()
110 .add_option(self.user().map(|user| {
111 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112 }))
113 .add_option(
114 self.email()
115 .map(|email| Expr::col((UserEmails::Table, UserEmails::Email)).eq(email)),
116 )
117 }
118}
119
120#[async_trait]
121impl UserEmailRepository for PgUserEmailRepository<'_> {
122 type Error = DatabaseError;
123
124 #[tracing::instrument(
125 name = "db.user_email.lookup",
126 skip_all,
127 fields(
128 db.query.text,
129 user_email.id = %id,
130 ),
131 err,
132 )]
133 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
134 let res = sqlx::query_as!(
135 UserEmailLookup,
136 r#"
137 SELECT user_email_id
138 , user_id
139 , email
140 , created_at
141 FROM user_emails
142
143 WHERE user_email_id = $1
144 "#,
145 Uuid::from(id),
146 )
147 .traced()
148 .fetch_optional(&mut *self.conn)
149 .await?;
150
151 let Some(user_email) = res else {
152 return Ok(None);
153 };
154
155 Ok(Some(user_email.into()))
156 }
157
158 #[tracing::instrument(
159 name = "db.user_email.find",
160 skip_all,
161 fields(
162 db.query.text,
163 %user.id,
164 user_email.email = email,
165 ),
166 err,
167 )]
168 async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
169 let res = sqlx::query_as!(
170 UserEmailLookup,
171 r#"
172 SELECT user_email_id
173 , user_id
174 , email
175 , created_at
176 FROM user_emails
177
178 WHERE user_id = $1 AND email = $2
179 "#,
180 Uuid::from(user.id),
181 email,
182 )
183 .traced()
184 .fetch_optional(&mut *self.conn)
185 .await?;
186
187 let Some(user_email) = res else {
188 return Ok(None);
189 };
190
191 Ok(Some(user_email.into()))
192 }
193
194 #[tracing::instrument(
195 name = "db.user_email.find_by_email",
196 skip_all,
197 fields(
198 db.query.text,
199 user_email.email = email,
200 ),
201 err,
202 )]
203 async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
204 let res = sqlx::query_as!(
205 UserEmailLookup,
206 r#"
207 SELECT user_email_id
208 , user_id
209 , email
210 , created_at
211 FROM user_emails
212 WHERE email = $1
213 "#,
214 email,
215 )
216 .traced()
217 .fetch_all(&mut *self.conn)
218 .await?;
219
220 if res.len() != 1 {
221 return Ok(None);
222 }
223
224 let Some(user_email) = res.into_iter().next() else {
225 return Ok(None);
226 };
227
228 Ok(Some(user_email.into()))
229 }
230
231 #[tracing::instrument(
232 name = "db.user_email.all",
233 skip_all,
234 fields(
235 db.query.text,
236 %user.id,
237 ),
238 err,
239 )]
240 async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
241 let res = sqlx::query_as!(
242 UserEmailLookup,
243 r#"
244 SELECT user_email_id
245 , user_id
246 , email
247 , created_at
248 FROM user_emails
249
250 WHERE user_id = $1
251
252 ORDER BY email ASC
253 "#,
254 Uuid::from(user.id),
255 )
256 .traced()
257 .fetch_all(&mut *self.conn)
258 .await?;
259
260 Ok(res.into_iter().map(Into::into).collect())
261 }
262
263 #[tracing::instrument(
264 name = "db.user_email.list",
265 skip_all,
266 fields(
267 db.query.text,
268 ),
269 err,
270 )]
271 async fn list(
272 &mut self,
273 filter: UserEmailFilter<'_>,
274 pagination: Pagination,
275 ) -> Result<Page<UserEmail>, DatabaseError> {
276 let (sql, arguments) = Query::select()
277 .expr_as(
278 Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
279 UserEmailLookupIden::UserEmailId,
280 )
281 .expr_as(
282 Expr::col((UserEmails::Table, UserEmails::UserId)),
283 UserEmailLookupIden::UserId,
284 )
285 .expr_as(
286 Expr::col((UserEmails::Table, UserEmails::Email)),
287 UserEmailLookupIden::Email,
288 )
289 .expr_as(
290 Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
291 UserEmailLookupIden::CreatedAt,
292 )
293 .from(UserEmails::Table)
294 .apply_filter(filter)
295 .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
296 .build_sqlx(PostgresQueryBuilder);
297
298 let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
299 .traced()
300 .fetch_all(&mut *self.conn)
301 .await?;
302
303 let page = pagination.process(edges).map(UserEmail::from);
304
305 Ok(page)
306 }
307
308 #[tracing::instrument(
309 name = "db.user_email.count",
310 skip_all,
311 fields(
312 db.query.text,
313 ),
314 err,
315 )]
316 async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
317 let (sql, arguments) = Query::select()
318 .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
319 .from(UserEmails::Table)
320 .apply_filter(filter)
321 .build_sqlx(PostgresQueryBuilder);
322
323 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
324 .traced()
325 .fetch_one(&mut *self.conn)
326 .await?;
327
328 count
329 .try_into()
330 .map_err(DatabaseError::to_invalid_operation)
331 }
332
333 #[tracing::instrument(
334 name = "db.user_email.add",
335 skip_all,
336 fields(
337 db.query.text,
338 %user.id,
339 user_email.id,
340 user_email.email = email,
341 ),
342 err,
343 )]
344 async fn add(
345 &mut self,
346 rng: &mut (dyn RngCore + Send),
347 clock: &dyn Clock,
348 user: &User,
349 email: String,
350 ) -> Result<UserEmail, Self::Error> {
351 let created_at = clock.now();
352 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
353 tracing::Span::current().record("user_email.id", tracing::field::display(id));
354
355 sqlx::query!(
356 r#"
357 INSERT INTO user_emails (user_email_id, user_id, email, created_at)
358 VALUES ($1, $2, $3, $4)
359 "#,
360 Uuid::from(id),
361 Uuid::from(user.id),
362 &email,
363 created_at,
364 )
365 .traced()
366 .execute(&mut *self.conn)
367 .await?;
368
369 Ok(UserEmail {
370 id,
371 user_id: user.id,
372 email,
373 created_at,
374 })
375 }
376
377 #[tracing::instrument(
378 name = "db.user_email.remove",
379 skip_all,
380 fields(
381 db.query.text,
382 user.id = %user_email.user_id,
383 %user_email.id,
384 %user_email.email,
385 ),
386 err,
387 )]
388 async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
389 let res = sqlx::query!(
390 r#"
391 DELETE FROM user_emails
392 WHERE user_email_id = $1
393 "#,
394 Uuid::from(user_email.id),
395 )
396 .traced()
397 .execute(&mut *self.conn)
398 .await?;
399
400 DatabaseError::ensure_affected_rows(&res, 1)?;
401
402 Ok(())
403 }
404
405 #[tracing::instrument(
406 name = "db.user_email.remove_bulk",
407 skip_all,
408 fields(
409 db.query.text,
410 ),
411 err,
412 )]
413 async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
414 let (sql, arguments) = Query::delete()
415 .from_table(UserEmails::Table)
416 .apply_filter(filter)
417 .build_sqlx(PostgresQueryBuilder);
418
419 let res = sqlx::query_with(&sql, arguments)
420 .traced()
421 .execute(&mut *self.conn)
422 .await?;
423
424 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
425 }
426
427 #[tracing::instrument(
428 name = "db.user_email.add_authentication_for_session",
429 skip_all,
430 fields(
431 db.query.text,
432 %session.id,
433 user_email_authentication.id,
434 user_email_authentication.email = email,
435 ),
436 err,
437 )]
438 async fn add_authentication_for_session(
439 &mut self,
440 rng: &mut (dyn RngCore + Send),
441 clock: &dyn Clock,
442 email: String,
443 session: &BrowserSession,
444 ) -> Result<UserEmailAuthentication, Self::Error> {
445 let created_at = clock.now();
446 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
447 tracing::Span::current()
448 .record("user_email_authentication.id", tracing::field::display(id));
449
450 sqlx::query!(
451 r#"
452 INSERT INTO user_email_authentications
453 ( user_email_authentication_id
454 , user_session_id
455 , email
456 , created_at
457 )
458 VALUES ($1, $2, $3, $4)
459 "#,
460 Uuid::from(id),
461 Uuid::from(session.id),
462 &email,
463 created_at,
464 )
465 .traced()
466 .execute(&mut *self.conn)
467 .await?;
468
469 Ok(UserEmailAuthentication {
470 id,
471 user_session_id: Some(session.id),
472 user_registration_id: None,
473 email,
474 created_at,
475 completed_at: None,
476 })
477 }
478
479 #[tracing::instrument(
480 name = "db.user_email.add_authentication_for_registration",
481 skip_all,
482 fields(
483 db.query.text,
484 %user_registration.id,
485 user_email_authentication.id,
486 user_email_authentication.email = email,
487 ),
488 err,
489 )]
490 async fn add_authentication_for_registration(
491 &mut self,
492 rng: &mut (dyn RngCore + Send),
493 clock: &dyn Clock,
494 email: String,
495 user_registration: &UserRegistration,
496 ) -> Result<UserEmailAuthentication, Self::Error> {
497 let created_at = clock.now();
498 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
499 tracing::Span::current()
500 .record("user_email_authentication.id", tracing::field::display(id));
501
502 sqlx::query!(
503 r#"
504 INSERT INTO user_email_authentications
505 ( user_email_authentication_id
506 , user_registration_id
507 , email
508 , created_at
509 )
510 VALUES ($1, $2, $3, $4)
511 "#,
512 Uuid::from(id),
513 Uuid::from(user_registration.id),
514 &email,
515 created_at,
516 )
517 .traced()
518 .execute(&mut *self.conn)
519 .await?;
520
521 Ok(UserEmailAuthentication {
522 id,
523 user_session_id: None,
524 user_registration_id: Some(user_registration.id),
525 email,
526 created_at,
527 completed_at: None,
528 })
529 }
530
531 #[tracing::instrument(
532 name = "db.user_email.add_authentication_code",
533 skip_all,
534 fields(
535 db.query.text,
536 %user_email_authentication.id,
537 %user_email_authentication.email,
538 user_email_authentication_code.id,
539 user_email_authentication_code.code = code,
540 ),
541 err,
542 )]
543 async fn add_authentication_code(
544 &mut self,
545 rng: &mut (dyn RngCore + Send),
546 clock: &dyn Clock,
547 duration: chrono::Duration,
548 user_email_authentication: &UserEmailAuthentication,
549 code: String,
550 ) -> Result<UserEmailAuthenticationCode, Self::Error> {
551 let created_at = clock.now();
552 let expires_at = created_at + duration;
553 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
554 tracing::Span::current().record(
555 "user_email_authentication_code.id",
556 tracing::field::display(id),
557 );
558
559 sqlx::query!(
560 r#"
561 INSERT INTO user_email_authentication_codes
562 ( user_email_authentication_code_id
563 , user_email_authentication_id
564 , code
565 , created_at
566 , expires_at
567 )
568 VALUES ($1, $2, $3, $4, $5)
569 "#,
570 Uuid::from(id),
571 Uuid::from(user_email_authentication.id),
572 &code,
573 created_at,
574 expires_at,
575 )
576 .traced()
577 .execute(&mut *self.conn)
578 .await?;
579
580 Ok(UserEmailAuthenticationCode {
581 id,
582 user_email_authentication_id: user_email_authentication.id,
583 code,
584 created_at,
585 expires_at,
586 })
587 }
588
589 #[tracing::instrument(
590 name = "db.user_email.lookup_authentication",
591 skip_all,
592 fields(
593 db.query.text,
594 user_email_authentication.id = %id,
595 ),
596 err,
597 )]
598 async fn lookup_authentication(
599 &mut self,
600 id: Ulid,
601 ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
602 let res = sqlx::query_as!(
603 UserEmailAuthenticationLookup,
604 r#"
605 SELECT user_email_authentication_id
606 , user_session_id
607 , user_registration_id
608 , email
609 , created_at
610 , completed_at
611 FROM user_email_authentications
612 WHERE user_email_authentication_id = $1
613 "#,
614 Uuid::from(id),
615 )
616 .traced()
617 .fetch_optional(&mut *self.conn)
618 .await?;
619
620 Ok(res.map(UserEmailAuthentication::from))
621 }
622
623 #[tracing::instrument(
624 name = "db.user_email.find_authentication_by_code",
625 skip_all,
626 fields(
627 db.query.text,
628 %authentication.id,
629 user_email_authentication_code.code = code,
630 ),
631 err,
632 )]
633 async fn find_authentication_code(
634 &mut self,
635 authentication: &UserEmailAuthentication,
636 code: &str,
637 ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
638 let res = sqlx::query_as!(
639 UserEmailAuthenticationCodeLookup,
640 r#"
641 SELECT user_email_authentication_code_id
642 , user_email_authentication_id
643 , code
644 , created_at
645 , expires_at
646 FROM user_email_authentication_codes
647 WHERE user_email_authentication_id = $1
648 AND code = $2
649 "#,
650 Uuid::from(authentication.id),
651 code,
652 )
653 .traced()
654 .fetch_optional(&mut *self.conn)
655 .await?;
656
657 Ok(res.map(UserEmailAuthenticationCode::from))
658 }
659
660 #[tracing::instrument(
661 name = "db.user_email.complete_email_authentication",
662 skip_all,
663 fields(
664 db.query.text,
665 %user_email_authentication.id,
666 %user_email_authentication.email,
667 %user_email_authentication_code.id,
668 %user_email_authentication_code.code,
669 ),
670 err,
671 )]
672 async fn complete_authentication(
673 &mut self,
674 clock: &dyn Clock,
675 mut user_email_authentication: UserEmailAuthentication,
676 user_email_authentication_code: &UserEmailAuthenticationCode,
677 ) -> Result<UserEmailAuthentication, Self::Error> {
678 let completed_at = clock.now();
682
683 let res = sqlx::query!(
687 r#"
688 UPDATE user_email_authentications
689 SET completed_at = $2
690 WHERE user_email_authentication_id = $1
691 AND completed_at IS NULL
692 "#,
693 Uuid::from(user_email_authentication.id),
694 completed_at,
695 )
696 .traced()
697 .execute(&mut *self.conn)
698 .await?;
699
700 DatabaseError::ensure_affected_rows(&res, 1)?;
701
702 user_email_authentication.completed_at = Some(completed_at);
703 Ok(user_email_authentication)
704 }
705}