mas_storage_pg/oauth2/client.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use std::{
8    collections::{BTreeMap, BTreeSet},
9    string::ToString,
10};
11
12use async_trait::async_trait;
13use mas_data_model::{Client, JwksOrJwksUri};
14use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
15use mas_jose::jwk::PublicJsonWebKeySet;
16use mas_storage::{Clock, oauth2::OAuth2ClientRepository};
17use oauth2_types::{oidc::ApplicationType, requests::GrantType};
18use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
19use rand::RngCore;
20use sqlx::PgConnection;
21use tracing::{Instrument, info_span};
22use ulid::Ulid;
23use url::Url;
24use uuid::Uuid;
25
26use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
27
28/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
29pub struct PgOAuth2ClientRepository<'c> {
30    conn: &'c mut PgConnection,
31}
32
33impl<'c> PgOAuth2ClientRepository<'c> {
34    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
35    /// connection
36    pub fn new(conn: &'c mut PgConnection) -> Self {
37        Self { conn }
38    }
39}
40
/// Raw `oauth2_clients` row, as selected by the `query_as!` calls in this file.
///
/// Columns keep their database representation (text, booleans, JSON). They are
/// parsed and validated only when the row is converted into a [`Client`].
// One boolean column per grant type, hence the many `bool` fields.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct OAuth2ClientLookup {
    // Primary key; converted to the client's `Ulid` ID
    oauth2_client_id: Uuid,
    // Value matched by `find_by_metadata_digest`
    metadata_digest: Option<String>,
    encrypted_client_secret: Option<String>,
    // Parsed into an `ApplicationType` on conversion
    application_type: Option<String>,
    // Each entry must parse as a `Url`
    redirect_uris: Vec<String>,
    // One flag per grant type; each `true` flag adds the matching `GrantType`
    grant_type_authorization_code: bool,
    grant_type_refresh_token: bool,
    grant_type_client_credentials: bool,
    grant_type_device_code: bool,
    client_name: Option<String>,
    // URL columns, each parsed into a `Url` on conversion
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    // Mutually exclusive with `jwks`; conversion fails if both are set
    jwks_uri: Option<String>,
    // Inline key set as JSON, deserialized on conversion
    jwks: Option<serde_json::Value>,
    // Algorithm / method names, parsed into their typed enums on conversion
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    initiate_login_uri: Option<String>,
}
66
67impl TryInto<Client> for OAuth2ClientLookup {
68    type Error = DatabaseInconsistencyError;
69
70    #[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
71    fn try_into(self) -> Result<Client, Self::Error> {
72        let id = Ulid::from(self.oauth2_client_id);
73
74        let redirect_uris: Result<Vec<Url>, _> =
75            self.redirect_uris.iter().map(|s| s.parse()).collect();
76        let redirect_uris = redirect_uris.map_err(|e| {
77            DatabaseInconsistencyError::on("oauth2_clients")
78                .column("redirect_uris")
79                .row(id)
80                .source(e)
81        })?;
82
83        let application_type = self
84            .application_type
85            .map(|s| s.parse())
86            .transpose()
87            .map_err(|e| {
88                DatabaseInconsistencyError::on("oauth2_clients")
89                    .column("application_type")
90                    .row(id)
91                    .source(e)
92            })?;
93
94        let mut grant_types = Vec::new();
95        if self.grant_type_authorization_code {
96            grant_types.push(GrantType::AuthorizationCode);
97        }
98        if self.grant_type_refresh_token {
99            grant_types.push(GrantType::RefreshToken);
100        }
101        if self.grant_type_client_credentials {
102            grant_types.push(GrantType::ClientCredentials);
103        }
104        if self.grant_type_device_code {
105            grant_types.push(GrantType::DeviceCode);
106        }
107
108        let logo_uri = self.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
109            DatabaseInconsistencyError::on("oauth2_clients")
110                .column("logo_uri")
111                .row(id)
112                .source(e)
113        })?;
114
115        let client_uri = self
116            .client_uri
117            .map(|s| s.parse())
118            .transpose()
119            .map_err(|e| {
120                DatabaseInconsistencyError::on("oauth2_clients")
121                    .column("client_uri")
122                    .row(id)
123                    .source(e)
124            })?;
125
126        let policy_uri = self
127            .policy_uri
128            .map(|s| s.parse())
129            .transpose()
130            .map_err(|e| {
131                DatabaseInconsistencyError::on("oauth2_clients")
132                    .column("policy_uri")
133                    .row(id)
134                    .source(e)
135            })?;
136
137        let tos_uri = self.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
138            DatabaseInconsistencyError::on("oauth2_clients")
139                .column("tos_uri")
140                .row(id)
141                .source(e)
142        })?;
143
144        let id_token_signed_response_alg = self
145            .id_token_signed_response_alg
146            .map(|s| s.parse())
147            .transpose()
148            .map_err(|e| {
149                DatabaseInconsistencyError::on("oauth2_clients")
150                    .column("id_token_signed_response_alg")
151                    .row(id)
152                    .source(e)
153            })?;
154
155        let userinfo_signed_response_alg = self
156            .userinfo_signed_response_alg
157            .map(|s| s.parse())
158            .transpose()
159            .map_err(|e| {
160                DatabaseInconsistencyError::on("oauth2_clients")
161                    .column("userinfo_signed_response_alg")
162                    .row(id)
163                    .source(e)
164            })?;
165
166        let token_endpoint_auth_method = self
167            .token_endpoint_auth_method
168            .map(|s| s.parse())
169            .transpose()
170            .map_err(|e| {
171                DatabaseInconsistencyError::on("oauth2_clients")
172                    .column("token_endpoint_auth_method")
173                    .row(id)
174                    .source(e)
175            })?;
176
177        let token_endpoint_auth_signing_alg = self
178            .token_endpoint_auth_signing_alg
179            .map(|s| s.parse())
180            .transpose()
181            .map_err(|e| {
182                DatabaseInconsistencyError::on("oauth2_clients")
183                    .column("token_endpoint_auth_signing_alg")
184                    .row(id)
185                    .source(e)
186            })?;
187
188        let initiate_login_uri = self
189            .initiate_login_uri
190            .map(|s| s.parse())
191            .transpose()
192            .map_err(|e| {
193                DatabaseInconsistencyError::on("oauth2_clients")
194                    .column("initiate_login_uri")
195                    .row(id)
196                    .source(e)
197            })?;
198
199        let jwks = match (self.jwks, self.jwks_uri) {
200            (None, None) => None,
201            (Some(jwks), None) => {
202                let jwks = serde_json::from_value(jwks).map_err(|e| {
203                    DatabaseInconsistencyError::on("oauth2_clients")
204                        .column("jwks")
205                        .row(id)
206                        .source(e)
207                })?;
208                Some(JwksOrJwksUri::Jwks(jwks))
209            }
210            (None, Some(jwks_uri)) => {
211                let jwks_uri = jwks_uri.parse().map_err(|e| {
212                    DatabaseInconsistencyError::on("oauth2_clients")
213                        .column("jwks_uri")
214                        .row(id)
215                        .source(e)
216                })?;
217
218                Some(JwksOrJwksUri::JwksUri(jwks_uri))
219            }
220            _ => {
221                return Err(DatabaseInconsistencyError::on("oauth2_clients")
222                    .column("jwks(_uri)")
223                    .row(id));
224            }
225        };
226
227        Ok(Client {
228            id,
229            client_id: id.to_string(),
230            metadata_digest: self.metadata_digest,
231            encrypted_client_secret: self.encrypted_client_secret,
232            application_type,
233            redirect_uris,
234            grant_types,
235            client_name: self.client_name,
236            logo_uri,
237            client_uri,
238            policy_uri,
239            tos_uri,
240            jwks,
241            id_token_signed_response_alg,
242            userinfo_signed_response_alg,
243            token_endpoint_auth_method,
244            token_endpoint_auth_signing_alg,
245            initiate_login_uri,
246        })
247    }
248}
249
250#[async_trait]
251impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
252    type Error = DatabaseError;
253
254    #[tracing::instrument(
255        name = "db.oauth2_client.lookup",
256        skip_all,
257        fields(
258            db.query.text,
259            oauth2_client.id = %id,
260        ),
261        err,
262    )]
263    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
264        let res = sqlx::query_as!(
265            OAuth2ClientLookup,
266            r#"
267                SELECT oauth2_client_id
268                     , metadata_digest
269                     , encrypted_client_secret
270                     , application_type
271                     , redirect_uris
272                     , grant_type_authorization_code
273                     , grant_type_refresh_token
274                     , grant_type_client_credentials
275                     , grant_type_device_code
276                     , client_name
277                     , logo_uri
278                     , client_uri
279                     , policy_uri
280                     , tos_uri
281                     , jwks_uri
282                     , jwks
283                     , id_token_signed_response_alg
284                     , userinfo_signed_response_alg
285                     , token_endpoint_auth_method
286                     , token_endpoint_auth_signing_alg
287                     , initiate_login_uri
288                FROM oauth2_clients c
289
290                WHERE oauth2_client_id = $1
291            "#,
292            Uuid::from(id),
293        )
294        .traced()
295        .fetch_optional(&mut *self.conn)
296        .await?;
297
298        let Some(res) = res else { return Ok(None) };
299
300        Ok(Some(res.try_into()?))
301    }
302
303    #[tracing::instrument(
304        name = "db.oauth2_client.find_by_metadata_digest",
305        skip_all,
306        fields(
307            db.query.text,
308        ),
309        err,
310    )]
311    async fn find_by_metadata_digest(
312        &mut self,
313        digest: &str,
314    ) -> Result<Option<Client>, Self::Error> {
315        let res = sqlx::query_as!(
316            OAuth2ClientLookup,
317            r#"
318                SELECT oauth2_client_id
319                    , metadata_digest
320                    , encrypted_client_secret
321                    , application_type
322                    , redirect_uris
323                    , grant_type_authorization_code
324                    , grant_type_refresh_token
325                    , grant_type_client_credentials
326                    , grant_type_device_code
327                    , client_name
328                    , logo_uri
329                    , client_uri
330                    , policy_uri
331                    , tos_uri
332                    , jwks_uri
333                    , jwks
334                    , id_token_signed_response_alg
335                    , userinfo_signed_response_alg
336                    , token_endpoint_auth_method
337                    , token_endpoint_auth_signing_alg
338                    , initiate_login_uri
339                FROM oauth2_clients
340                WHERE metadata_digest = $1
341            "#,
342            digest,
343        )
344        .traced()
345        .fetch_optional(&mut *self.conn)
346        .await?;
347
348        let Some(res) = res else { return Ok(None) };
349
350        Ok(Some(res.try_into()?))
351    }
352
353    #[tracing::instrument(
354        name = "db.oauth2_client.load_batch",
355        skip_all,
356        fields(
357            db.query.text,
358        ),
359        err,
360    )]
361    async fn load_batch(
362        &mut self,
363        ids: BTreeSet<Ulid>,
364    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
365        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
366        let res = sqlx::query_as!(
367            OAuth2ClientLookup,
368            r#"
369                SELECT oauth2_client_id
370                     , metadata_digest
371                     , encrypted_client_secret
372                     , application_type
373                     , redirect_uris
374                     , grant_type_authorization_code
375                     , grant_type_refresh_token
376                     , grant_type_client_credentials
377                     , grant_type_device_code
378                     , client_name
379                     , logo_uri
380                     , client_uri
381                     , policy_uri
382                     , tos_uri
383                     , jwks_uri
384                     , jwks
385                     , id_token_signed_response_alg
386                     , userinfo_signed_response_alg
387                     , token_endpoint_auth_method
388                     , token_endpoint_auth_signing_alg
389                     , initiate_login_uri
390                FROM oauth2_clients c
391
392                WHERE oauth2_client_id = ANY($1::uuid[])
393            "#,
394            &ids,
395        )
396        .traced()
397        .fetch_all(&mut *self.conn)
398        .await?;
399
400        res.into_iter()
401            .map(|r| {
402                r.try_into()
403                    .map(|c: Client| (c.id, c))
404                    .map_err(DatabaseError::from)
405            })
406            .collect()
407    }
408
409    #[tracing::instrument(
410        name = "db.oauth2_client.add",
411        skip_all,
412        fields(
413            db.query.text,
414            client.id,
415            client.name = client_name
416        ),
417        err,
418    )]
419    #[allow(clippy::too_many_lines)]
420    async fn add(
421        &mut self,
422        rng: &mut (dyn RngCore + Send),
423        clock: &dyn Clock,
424        redirect_uris: Vec<Url>,
425        metadata_digest: Option<String>,
426        encrypted_client_secret: Option<String>,
427        application_type: Option<ApplicationType>,
428        grant_types: Vec<GrantType>,
429        client_name: Option<String>,
430        logo_uri: Option<Url>,
431        client_uri: Option<Url>,
432        policy_uri: Option<Url>,
433        tos_uri: Option<Url>,
434        jwks_uri: Option<Url>,
435        jwks: Option<PublicJsonWebKeySet>,
436        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
437        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
438        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
439        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
440        initiate_login_uri: Option<Url>,
441    ) -> Result<Client, Self::Error> {
442        let now = clock.now();
443        let id = Ulid::from_datetime_with_source(now.into(), rng);
444        tracing::Span::current().record("client.id", tracing::field::display(id));
445
446        let jwks_json = jwks
447            .as_ref()
448            .map(serde_json::to_value)
449            .transpose()
450            .map_err(DatabaseError::to_invalid_operation)?;
451
452        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
453
454        sqlx::query!(
455            r#"
456                INSERT INTO oauth2_clients
457                    ( oauth2_client_id
458                    , metadata_digest
459                    , encrypted_client_secret
460                    , application_type
461                    , redirect_uris
462                    , grant_type_authorization_code
463                    , grant_type_refresh_token
464                    , grant_type_client_credentials
465                    , grant_type_device_code
466                    , client_name
467                    , logo_uri
468                    , client_uri
469                    , policy_uri
470                    , tos_uri
471                    , jwks_uri
472                    , jwks
473                    , id_token_signed_response_alg
474                    , userinfo_signed_response_alg
475                    , token_endpoint_auth_method
476                    , token_endpoint_auth_signing_alg
477                    , initiate_login_uri
478                    , is_static
479                    )
480                VALUES
481                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
482                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
483            "#,
484            Uuid::from(id),
485            metadata_digest,
486            encrypted_client_secret,
487            application_type.as_ref().map(ToString::to_string),
488            &redirect_uris_array,
489            grant_types.contains(&GrantType::AuthorizationCode),
490            grant_types.contains(&GrantType::RefreshToken),
491            grant_types.contains(&GrantType::ClientCredentials),
492            grant_types.contains(&GrantType::DeviceCode),
493            client_name,
494            logo_uri.as_ref().map(Url::as_str),
495            client_uri.as_ref().map(Url::as_str),
496            policy_uri.as_ref().map(Url::as_str),
497            tos_uri.as_ref().map(Url::as_str),
498            jwks_uri.as_ref().map(Url::as_str),
499            jwks_json,
500            id_token_signed_response_alg
501                .as_ref()
502                .map(ToString::to_string),
503            userinfo_signed_response_alg
504                .as_ref()
505                .map(ToString::to_string),
506            token_endpoint_auth_method.as_ref().map(ToString::to_string),
507            token_endpoint_auth_signing_alg
508                .as_ref()
509                .map(ToString::to_string),
510            initiate_login_uri.as_ref().map(Url::as_str),
511        )
512        .traced()
513        .execute(&mut *self.conn)
514        .await?;
515
516        let jwks = match (jwks, jwks_uri) {
517            (None, None) => None,
518            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
519            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
520            _ => return Err(DatabaseError::invalid_operation()),
521        };
522
523        Ok(Client {
524            id,
525            client_id: id.to_string(),
526            metadata_digest: None,
527            encrypted_client_secret,
528            application_type,
529            redirect_uris,
530            grant_types,
531            client_name,
532            logo_uri,
533            client_uri,
534            policy_uri,
535            tos_uri,
536            jwks,
537            id_token_signed_response_alg,
538            userinfo_signed_response_alg,
539            token_endpoint_auth_method,
540            token_endpoint_auth_signing_alg,
541            initiate_login_uri,
542        })
543    }
544
545    #[tracing::instrument(
546        name = "db.oauth2_client.upsert_static",
547        skip_all,
548        fields(
549            db.query.text,
550            client.id = %client_id,
551        ),
552        err,
553    )]
554    async fn upsert_static(
555        &mut self,
556        client_id: Ulid,
557        client_auth_method: OAuthClientAuthenticationMethod,
558        encrypted_client_secret: Option<String>,
559        jwks: Option<PublicJsonWebKeySet>,
560        jwks_uri: Option<Url>,
561        redirect_uris: Vec<Url>,
562    ) -> Result<Client, Self::Error> {
563        let jwks_json = jwks
564            .as_ref()
565            .map(serde_json::to_value)
566            .transpose()
567            .map_err(DatabaseError::to_invalid_operation)?;
568
569        let client_auth_method = client_auth_method.to_string();
570        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
571
572        sqlx::query!(
573            r#"
574                INSERT INTO oauth2_clients
575                    ( oauth2_client_id
576                    , encrypted_client_secret
577                    , redirect_uris
578                    , grant_type_authorization_code
579                    , grant_type_refresh_token
580                    , grant_type_client_credentials
581                    , grant_type_device_code
582                    , token_endpoint_auth_method
583                    , jwks
584                    , jwks_uri
585                    , is_static
586                    )
587                VALUES
588                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
589                ON CONFLICT (oauth2_client_id)
590                DO
591                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
592                             , redirect_uris = EXCLUDED.redirect_uris
593                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
594                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
595                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
596                             , grant_type_device_code = EXCLUDED.grant_type_device_code
597                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
598                             , jwks = EXCLUDED.jwks
599                             , jwks_uri = EXCLUDED.jwks_uri
600                             , is_static = TRUE
601            "#,
602            Uuid::from(client_id),
603            encrypted_client_secret,
604            &redirect_uris_array,
605            true,
606            true,
607            true,
608            true,
609            client_auth_method,
610            jwks_json,
611            jwks_uri.as_ref().map(Url::as_str),
612        )
613        .traced()
614        .execute(&mut *self.conn)
615        .await?;
616
617        let jwks = match (jwks, jwks_uri) {
618            (None, None) => None,
619            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
620            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
621            _ => return Err(DatabaseError::invalid_operation()),
622        };
623
624        Ok(Client {
625            id: client_id,
626            client_id: client_id.to_string(),
627            metadata_digest: None,
628            encrypted_client_secret,
629            application_type: None,
630            redirect_uris,
631            grant_types: vec![
632                GrantType::AuthorizationCode,
633                GrantType::RefreshToken,
634                GrantType::ClientCredentials,
635            ],
636            client_name: None,
637            logo_uri: None,
638            client_uri: None,
639            policy_uri: None,
640            tos_uri: None,
641            jwks,
642            id_token_signed_response_alg: None,
643            userinfo_signed_response_alg: None,
644            token_endpoint_auth_method: None,
645            token_endpoint_auth_signing_alg: None,
646            initiate_login_uri: None,
647        })
648    }
649
650    #[tracing::instrument(
651        name = "db.oauth2_client.all_static",
652        skip_all,
653        fields(
654            db.query.text,
655        ),
656        err,
657    )]
658    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
659        let res = sqlx::query_as!(
660            OAuth2ClientLookup,
661            r#"
662                SELECT oauth2_client_id
663                     , metadata_digest
664                     , encrypted_client_secret
665                     , application_type
666                     , redirect_uris
667                     , grant_type_authorization_code
668                     , grant_type_refresh_token
669                     , grant_type_client_credentials
670                     , grant_type_device_code
671                     , client_name
672                     , logo_uri
673                     , client_uri
674                     , policy_uri
675                     , tos_uri
676                     , jwks_uri
677                     , jwks
678                     , id_token_signed_response_alg
679                     , userinfo_signed_response_alg
680                     , token_endpoint_auth_method
681                     , token_endpoint_auth_signing_alg
682                     , initiate_login_uri
683                FROM oauth2_clients c
684                WHERE is_static = TRUE
685            "#,
686        )
687        .traced()
688        .fetch_all(&mut *self.conn)
689        .await?;
690
691        res.into_iter()
692            .map(|r| r.try_into().map_err(DatabaseError::from))
693            .collect()
694    }
695
    /// Delete a client and the rows that reference it.
    ///
    /// Dependent rows are removed first. The order is: authorization grants,
    /// consents, then the access tokens, refresh tokens and sessions tied to
    /// the client's OAuth 2 sessions. The `oauth2_clients` row is deleted last.
    /// Each dependent `DELETE` runs inside its own child span.
    ///
    /// NOTE(review): no transaction is opened here. Atomicity presumably relies
    /// on the caller running this inside one; confirm at the call sites.
    ///
    /// # Errors
    ///
    /// Propagates any database error. The result of the final `DELETE` is
    /// checked with `DatabaseError::ensure_affected_rows` against an expected
    /// count of 1, which presumably fails when the client did not exist.
    #[tracing::instrument(
        name = "db.oauth2_client.delete_by_id",
        skip_all,
        fields(
            db.query.text,
            client.id = %id,
        ),
        err,
    )]
    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
        // NOTE(review): the statement order presumably follows foreign-key
        // dependencies (children before parents). Check the schema before
        // reordering anything below.

        // Delete the authorization grants
        {
            // Child span; `.record(&span)` presumably fills its empty
            // `db.query.text` field with the SQL below
            let span = info_span!(
                "db.oauth2_client.delete_by_id.authorization_grants",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_authorization_grants
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Delete the user consents
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.consents",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_consents
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Delete the OAuth 2 sessions related data
        // Access tokens of every session that belongs to this client
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.access_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_access_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Refresh tokens of every session that belongs to this client
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.refresh_tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_refresh_tokens
                    WHERE oauth2_session_id IN (
                        SELECT oauth2_session_id
                        FROM oauth2_sessions
                        WHERE oauth2_client_id = $1
                    )
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // The sessions themselves, now that their tokens are gone
        {
            let span = info_span!(
                "db.oauth2_client.delete_by_id.sessions",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    DELETE FROM oauth2_sessions
                    WHERE oauth2_client_id = $1
                "#,
                Uuid::from(id),
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // Now delete the client itself
        let res = sqlx::query!(
            r#"
                DELETE FROM oauth2_clients
                WHERE oauth2_client_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // Expect exactly one client row to have been removed
        DatabaseError::ensure_affected_rows(&res, 1)
    }
826}