1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
10use mas_storage::{
11 Clock, Page, Pagination,
12 upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
13};
14use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
15use rand::RngCore;
16use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
17use sea_query_binder::SqlxBinder;
18use sqlx::PgConnection;
19use tracing::Instrument;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::{UpstreamOAuthLinks, UpstreamOAuthProviders},
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
/// An implementation of [`UpstreamOAuthLinkRepository`] that runs its queries
/// on a borrowed PostgreSQL connection.
pub struct PgUpstreamOAuthLinkRepository<'c> {
    /// Connection every query issued by this repository is executed on.
    conn: &'c mut PgConnection,
}
36
37impl<'c> PgUpstreamOAuthLinkRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
/// Raw row shape used when reading from the `upstream_oauth_links` table.
///
/// `#[enum_def]` generates the companion `LinkLookupIden` enum, whose variants
/// are used as column aliases when a `SELECT` is assembled with `sea_query`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
    /// Identifier of the link.
    upstream_oauth_link_id: Uuid,
    /// Identifier of the upstream provider the link belongs to.
    upstream_oauth_provider_id: Uuid,
    /// Identifier of the local user the link is associated with, if any.
    user_id: Option<Uuid>,
    /// Subject of the account on the upstream provider.
    subject: String,
    /// Optional human-readable name of the upstream account.
    human_account_name: Option<String>,
    /// When the link was created.
    created_at: DateTime<Utc>,
}
55
56impl From<LinkLookup> for UpstreamOAuthLink {
57 fn from(value: LinkLookup) -> Self {
58 UpstreamOAuthLink {
59 id: Ulid::from(value.upstream_oauth_link_id),
60 provider_id: Ulid::from(value.upstream_oauth_provider_id),
61 user_id: value.user_id.map(Ulid::from),
62 subject: value.subject,
63 human_account_name: value.human_account_name,
64 created_at: value.created_at,
65 }
66 }
67}
68
69impl Filter for UpstreamOAuthLinkFilter<'_> {
70 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
71 sea_query::Condition::all()
72 .add_option(self.user().map(|user| {
73 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
74 .eq(Uuid::from(user.id))
75 }))
76 .add_option(self.provider().map(|provider| {
77 Expr::col((
78 UpstreamOAuthLinks::Table,
79 UpstreamOAuthLinks::UpstreamOAuthProviderId,
80 ))
81 .eq(Uuid::from(provider.id))
82 }))
83 .add_option(self.provider_enabled().map(|enabled| {
84 Expr::col((
85 UpstreamOAuthLinks::Table,
86 UpstreamOAuthLinks::UpstreamOAuthProviderId,
87 ))
88 .eq(Expr::any(
89 Query::select()
90 .expr(Expr::col((
91 UpstreamOAuthProviders::Table,
92 UpstreamOAuthProviders::UpstreamOAuthProviderId,
93 )))
94 .from(UpstreamOAuthProviders::Table)
95 .and_where(
96 Expr::col((
97 UpstreamOAuthProviders::Table,
98 UpstreamOAuthProviders::DisabledAt,
99 ))
100 .is_null()
101 .eq(enabled),
102 )
103 .take(),
104 ))
105 }))
106 .add_option(self.subject().map(|subject| {
107 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)).eq(subject)
108 }))
109 }
110}
111
112#[async_trait]
113impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
114 type Error = DatabaseError;
115
116 #[tracing::instrument(
117 name = "db.upstream_oauth_link.lookup",
118 skip_all,
119 fields(
120 db.query.text,
121 upstream_oauth_link.id = %id,
122 ),
123 err,
124 )]
125 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
126 let res = sqlx::query_as!(
127 LinkLookup,
128 r#"
129 SELECT
130 upstream_oauth_link_id,
131 upstream_oauth_provider_id,
132 user_id,
133 subject,
134 human_account_name,
135 created_at
136 FROM upstream_oauth_links
137 WHERE upstream_oauth_link_id = $1
138 "#,
139 Uuid::from(id),
140 )
141 .traced()
142 .fetch_optional(&mut *self.conn)
143 .await?
144 .map(Into::into);
145
146 Ok(res)
147 }
148
149 #[tracing::instrument(
150 name = "db.upstream_oauth_link.find_by_subject",
151 skip_all,
152 fields(
153 db.query.text,
154 upstream_oauth_link.subject = subject,
155 %upstream_oauth_provider.id,
156 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
157 %upstream_oauth_provider.client_id,
158 ),
159 err,
160 )]
161 async fn find_by_subject(
162 &mut self,
163 upstream_oauth_provider: &UpstreamOAuthProvider,
164 subject: &str,
165 ) -> Result<Option<UpstreamOAuthLink>, Self::Error> {
166 let res = sqlx::query_as!(
167 LinkLookup,
168 r#"
169 SELECT
170 upstream_oauth_link_id,
171 upstream_oauth_provider_id,
172 user_id,
173 subject,
174 human_account_name,
175 created_at
176 FROM upstream_oauth_links
177 WHERE upstream_oauth_provider_id = $1
178 AND subject = $2
179 "#,
180 Uuid::from(upstream_oauth_provider.id),
181 subject,
182 )
183 .traced()
184 .fetch_optional(&mut *self.conn)
185 .await?
186 .map(Into::into);
187
188 Ok(res)
189 }
190
191 #[tracing::instrument(
192 name = "db.upstream_oauth_link.add",
193 skip_all,
194 fields(
195 db.query.text,
196 upstream_oauth_link.id,
197 upstream_oauth_link.subject = subject,
198 upstream_oauth_link.human_account_name = human_account_name,
199 %upstream_oauth_provider.id,
200 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
201 %upstream_oauth_provider.client_id,
202 ),
203 err,
204 )]
205 async fn add(
206 &mut self,
207 rng: &mut (dyn RngCore + Send),
208 clock: &dyn Clock,
209 upstream_oauth_provider: &UpstreamOAuthProvider,
210 subject: String,
211 human_account_name: Option<String>,
212 ) -> Result<UpstreamOAuthLink, Self::Error> {
213 let created_at = clock.now();
214 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
215 tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
216
217 sqlx::query!(
218 r#"
219 INSERT INTO upstream_oauth_links (
220 upstream_oauth_link_id,
221 upstream_oauth_provider_id,
222 user_id,
223 subject,
224 human_account_name,
225 created_at
226 ) VALUES ($1, $2, NULL, $3, $4, $5)
227 "#,
228 Uuid::from(id),
229 Uuid::from(upstream_oauth_provider.id),
230 &subject,
231 human_account_name.as_deref(),
232 created_at,
233 )
234 .traced()
235 .execute(&mut *self.conn)
236 .await?;
237
238 Ok(UpstreamOAuthLink {
239 id,
240 provider_id: upstream_oauth_provider.id,
241 user_id: None,
242 subject,
243 human_account_name,
244 created_at,
245 })
246 }
247
248 #[tracing::instrument(
249 name = "db.upstream_oauth_link.associate_to_user",
250 skip_all,
251 fields(
252 db.query.text,
253 %upstream_oauth_link.id,
254 %upstream_oauth_link.subject,
255 %user.id,
256 %user.username,
257 ),
258 err,
259 )]
260 async fn associate_to_user(
261 &mut self,
262 upstream_oauth_link: &UpstreamOAuthLink,
263 user: &User,
264 ) -> Result<(), Self::Error> {
265 sqlx::query!(
266 r#"
267 UPDATE upstream_oauth_links
268 SET user_id = $1
269 WHERE upstream_oauth_link_id = $2
270 "#,
271 Uuid::from(user.id),
272 Uuid::from(upstream_oauth_link.id),
273 )
274 .traced()
275 .execute(&mut *self.conn)
276 .await?;
277
278 Ok(())
279 }
280
281 #[tracing::instrument(
282 name = "db.upstream_oauth_link.list",
283 skip_all,
284 fields(
285 db.query.text,
286 ),
287 err,
288 )]
289 async fn list(
290 &mut self,
291 filter: UpstreamOAuthLinkFilter<'_>,
292 pagination: Pagination,
293 ) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
294 let (sql, arguments) = Query::select()
295 .expr_as(
296 Expr::col((
297 UpstreamOAuthLinks::Table,
298 UpstreamOAuthLinks::UpstreamOAuthLinkId,
299 )),
300 LinkLookupIden::UpstreamOauthLinkId,
301 )
302 .expr_as(
303 Expr::col((
304 UpstreamOAuthLinks::Table,
305 UpstreamOAuthLinks::UpstreamOAuthProviderId,
306 )),
307 LinkLookupIden::UpstreamOauthProviderId,
308 )
309 .expr_as(
310 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
311 LinkLookupIden::UserId,
312 )
313 .expr_as(
314 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
315 LinkLookupIden::Subject,
316 )
317 .expr_as(
318 Expr::col((
319 UpstreamOAuthLinks::Table,
320 UpstreamOAuthLinks::HumanAccountName,
321 )),
322 LinkLookupIden::HumanAccountName,
323 )
324 .expr_as(
325 Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
326 LinkLookupIden::CreatedAt,
327 )
328 .from(UpstreamOAuthLinks::Table)
329 .apply_filter(filter)
330 .generate_pagination(
331 (
332 UpstreamOAuthLinks::Table,
333 UpstreamOAuthLinks::UpstreamOAuthLinkId,
334 ),
335 pagination,
336 )
337 .build_sqlx(PostgresQueryBuilder);
338
339 let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
340 .traced()
341 .fetch_all(&mut *self.conn)
342 .await?;
343
344 let page = pagination.process(edges).map(UpstreamOAuthLink::from);
345
346 Ok(page)
347 }
348
349 #[tracing::instrument(
350 name = "db.upstream_oauth_link.count",
351 skip_all,
352 fields(
353 db.query.text,
354 ),
355 err,
356 )]
357 async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
358 let (sql, arguments) = Query::select()
359 .expr(
360 Expr::col((
361 UpstreamOAuthLinks::Table,
362 UpstreamOAuthLinks::UpstreamOAuthLinkId,
363 ))
364 .count(),
365 )
366 .from(UpstreamOAuthLinks::Table)
367 .apply_filter(filter)
368 .build_sqlx(PostgresQueryBuilder);
369
370 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
371 .traced()
372 .fetch_one(&mut *self.conn)
373 .await?;
374
375 count
376 .try_into()
377 .map_err(DatabaseError::to_invalid_operation)
378 }
379
380 #[tracing::instrument(
381 name = "db.upstream_oauth_link.remove",
382 skip_all,
383 fields(
384 db.query.text,
385 upstream_oauth_link.id,
386 upstream_oauth_link.provider_id,
387 %upstream_oauth_link.subject,
388 ),
389 err,
390 )]
391 async fn remove(
392 &mut self,
393 clock: &dyn Clock,
394 upstream_oauth_link: UpstreamOAuthLink,
395 ) -> Result<(), Self::Error> {
396 let span = tracing::info_span!(
399 "db.upstream_oauth_link.remove.unlink",
400 { DB_QUERY_TEXT } = tracing::field::Empty
401 );
402 sqlx::query!(
403 r#"
404 UPDATE upstream_oauth_authorization_sessions SET
405 upstream_oauth_link_id = NULL,
406 unlinked_at = $2
407 WHERE upstream_oauth_link_id = $1
408 "#,
409 Uuid::from(upstream_oauth_link.id),
410 clock.now()
411 )
412 .record(&span)
413 .execute(&mut *self.conn)
414 .instrument(span)
415 .await?;
416
417 let span = tracing::info_span!(
419 "db.upstream_oauth_link.remove.delete",
420 { DB_QUERY_TEXT } = tracing::field::Empty
421 );
422 let res = sqlx::query!(
423 r#"
424 DELETE FROM upstream_oauth_links
425 WHERE upstream_oauth_link_id = $1
426 "#,
427 Uuid::from(upstream_oauth_link.id),
428 )
429 .record(&span)
430 .execute(&mut *self.conn)
431 .instrument(span)
432 .await?;
433
434 DatabaseError::ensure_affected_rows(&res, 1)?;
435
436 Ok(())
437 }
438}