1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11 UpstreamOAuthProvider,
12};
13use mas_storage::{Clock, upstream_oauth2::UpstreamOAuthSessionRepository};
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
20
/// An implementation of [`UpstreamOAuthSessionRepository`] backed by a
/// PostgreSQL connection.
pub struct PgUpstreamOAuthSessionRepository<'c> {
    // Exclusive borrow of the connection every query of this repository runs on.
    conn: &'c mut PgConnection,
}
26
27impl<'c> PgUpstreamOAuthSessionRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
/// Raw row read from `upstream_oauth_authorization_sessions` by the `lookup`
/// query; field names mirror the selected column names so `query_as!` can
/// map them.
///
/// It is turned into an [`UpstreamOAuthAuthorizationSession`] through the
/// `TryFrom` implementation below, which derives the session state from the
/// nullable columns.
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Nullable state columns: which of them are set decides the session state.
    upstream_oauth_link_id: Option<Uuid>,
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: String,
    id_token: Option<String>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
50
51impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
52 type Error = DatabaseInconsistencyError;
53
54 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
55 let id = value.upstream_oauth_authorization_session_id.into();
56 let state = match (
57 value.upstream_oauth_link_id,
58 value.id_token,
59 value.extra_callback_parameters,
60 value.userinfo,
61 value.completed_at,
62 value.consumed_at,
63 value.unlinked_at,
64 ) {
65 (None, None, None, None, None, None, None) => {
66 UpstreamOAuthAuthorizationSessionState::Pending
67 }
68 (
69 Some(link_id),
70 id_token,
71 extra_callback_parameters,
72 userinfo,
73 Some(completed_at),
74 None,
75 None,
76 ) => UpstreamOAuthAuthorizationSessionState::Completed {
77 completed_at,
78 link_id: link_id.into(),
79 id_token,
80 extra_callback_parameters,
81 userinfo,
82 },
83 (
84 Some(link_id),
85 id_token,
86 extra_callback_parameters,
87 userinfo,
88 Some(completed_at),
89 Some(consumed_at),
90 None,
91 ) => UpstreamOAuthAuthorizationSessionState::Consumed {
92 completed_at,
93 link_id: link_id.into(),
94 id_token,
95 extra_callback_parameters,
96 userinfo,
97 consumed_at,
98 },
99 (_, id_token, _, _, Some(completed_at), consumed_at, Some(unlinked_at)) => {
100 UpstreamOAuthAuthorizationSessionState::Unlinked {
101 completed_at,
102 id_token,
103 consumed_at,
104 unlinked_at,
105 }
106 }
107 _ => {
108 return Err(DatabaseInconsistencyError::on(
109 "upstream_oauth_authorization_sessions",
110 )
111 .row(id));
112 }
113 };
114
115 Ok(Self {
116 id,
117 provider_id: value.upstream_oauth_provider_id.into(),
118 state_str: value.state,
119 nonce: value.nonce,
120 code_challenge_verifier: value.code_challenge_verifier,
121 created_at: value.created_at,
122 state,
123 })
124 }
125}
126
127#[async_trait]
128impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
129 type Error = DatabaseError;
130
131 #[tracing::instrument(
132 name = "db.upstream_oauth_authorization_session.lookup",
133 skip_all,
134 fields(
135 db.query.text,
136 upstream_oauth_provider.id = %id,
137 ),
138 err,
139 )]
140 async fn lookup(
141 &mut self,
142 id: Ulid,
143 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
144 let res = sqlx::query_as!(
145 SessionLookup,
146 r#"
147 SELECT
148 upstream_oauth_authorization_session_id,
149 upstream_oauth_provider_id,
150 upstream_oauth_link_id,
151 state,
152 code_challenge_verifier,
153 nonce,
154 id_token,
155 extra_callback_parameters,
156 userinfo,
157 created_at,
158 completed_at,
159 consumed_at,
160 unlinked_at
161 FROM upstream_oauth_authorization_sessions
162 WHERE upstream_oauth_authorization_session_id = $1
163 "#,
164 Uuid::from(id),
165 )
166 .traced()
167 .fetch_optional(&mut *self.conn)
168 .await?;
169
170 let Some(res) = res else { return Ok(None) };
171
172 Ok(Some(res.try_into()?))
173 }
174
175 #[tracing::instrument(
176 name = "db.upstream_oauth_authorization_session.add",
177 skip_all,
178 fields(
179 db.query.text,
180 %upstream_oauth_provider.id,
181 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
182 %upstream_oauth_provider.client_id,
183 upstream_oauth_authorization_session.id,
184 ),
185 err,
186 )]
187 async fn add(
188 &mut self,
189 rng: &mut (dyn RngCore + Send),
190 clock: &dyn Clock,
191 upstream_oauth_provider: &UpstreamOAuthProvider,
192 state_str: String,
193 code_challenge_verifier: Option<String>,
194 nonce: String,
195 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
196 let created_at = clock.now();
197 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
198 tracing::Span::current().record(
199 "upstream_oauth_authorization_session.id",
200 tracing::field::display(id),
201 );
202
203 sqlx::query!(
204 r#"
205 INSERT INTO upstream_oauth_authorization_sessions (
206 upstream_oauth_authorization_session_id,
207 upstream_oauth_provider_id,
208 state,
209 code_challenge_verifier,
210 nonce,
211 created_at,
212 completed_at,
213 consumed_at,
214 id_token,
215 userinfo
216 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
217 "#,
218 Uuid::from(id),
219 Uuid::from(upstream_oauth_provider.id),
220 &state_str,
221 code_challenge_verifier.as_deref(),
222 nonce,
223 created_at,
224 )
225 .traced()
226 .execute(&mut *self.conn)
227 .await?;
228
229 Ok(UpstreamOAuthAuthorizationSession {
230 id,
231 state: UpstreamOAuthAuthorizationSessionState::default(),
232 provider_id: upstream_oauth_provider.id,
233 state_str,
234 code_challenge_verifier,
235 nonce,
236 created_at,
237 })
238 }
239
240 #[tracing::instrument(
241 name = "db.upstream_oauth_authorization_session.complete_with_link",
242 skip_all,
243 fields(
244 db.query.text,
245 %upstream_oauth_authorization_session.id,
246 %upstream_oauth_link.id,
247 ),
248 err,
249 )]
250 async fn complete_with_link(
251 &mut self,
252 clock: &dyn Clock,
253 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
254 upstream_oauth_link: &UpstreamOAuthLink,
255 id_token: Option<String>,
256 extra_callback_parameters: Option<serde_json::Value>,
257 userinfo: Option<serde_json::Value>,
258 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
259 let completed_at = clock.now();
260
261 sqlx::query!(
262 r#"
263 UPDATE upstream_oauth_authorization_sessions
264 SET upstream_oauth_link_id = $1,
265 completed_at = $2,
266 id_token = $3,
267 extra_callback_parameters = $4,
268 userinfo = $5
269 WHERE upstream_oauth_authorization_session_id = $6
270 "#,
271 Uuid::from(upstream_oauth_link.id),
272 completed_at,
273 id_token,
274 extra_callback_parameters,
275 userinfo,
276 Uuid::from(upstream_oauth_authorization_session.id),
277 )
278 .traced()
279 .execute(&mut *self.conn)
280 .await?;
281
282 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
283 .complete(
284 completed_at,
285 upstream_oauth_link,
286 id_token,
287 extra_callback_parameters,
288 userinfo,
289 )
290 .map_err(DatabaseError::to_invalid_operation)?;
291
292 Ok(upstream_oauth_authorization_session)
293 }
294
295 #[tracing::instrument(
297 name = "db.upstream_oauth_authorization_session.consume",
298 skip_all,
299 fields(
300 db.query.text,
301 %upstream_oauth_authorization_session.id,
302 ),
303 err,
304 )]
305 async fn consume(
306 &mut self,
307 clock: &dyn Clock,
308 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
309 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
310 let consumed_at = clock.now();
311 sqlx::query!(
312 r#"
313 UPDATE upstream_oauth_authorization_sessions
314 SET consumed_at = $1
315 WHERE upstream_oauth_authorization_session_id = $2
316 "#,
317 consumed_at,
318 Uuid::from(upstream_oauth_authorization_session.id),
319 )
320 .traced()
321 .execute(&mut *self.conn)
322 .await?;
323
324 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
325 .consume(consumed_at)
326 .map_err(DatabaseError::to_invalid_operation)?;
327
328 Ok(upstream_oauth_authorization_session)
329 }
330}