1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::Ulid;
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use thiserror::Error;
18use tokio::io::{AsyncRead, AsyncReadExt};
19
20pub use self::model::{
21 AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput,
22 EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation,
23};
24
/// Errors returned while loading a policy module ([`PolicyFactory::load`]) or
/// while replacing its dynamic data ([`PolicyFactory::set_dynamic_data`]).
#[derive(Debug, Error)]
pub enum LoadError {
    /// Reading the WASM module bytes from the provided source failed.
    #[error("failed to read module")]
    Read(#[from] tokio::io::Error),

    /// The wasmtime [`Engine`] could not be created from the configuration.
    #[error("failed to create WASM engine")]
    Engine(#[source] anyhow::Error),

    /// The blocking task compiling the module could not be joined
    /// (e.g. it panicked or was cancelled).
    #[error("module compilation task crashed")]
    CompilationTask(#[from] tokio::task::JoinError),

    /// wasmtime rejected the module bytes during compilation.
    #[error("failed to compile WASM module")]
    Compilation(#[source] anyhow::Error),

    /// The policy data could not be built, i.e. merging the JSON data objects
    /// failed (for example because of mismatched value types).
    #[error("invalid policy data")]
    InvalidData(#[source] anyhow::Error),

    /// Module and data were prepared, but instantiating a policy with them failed.
    #[error("failed to instantiate a test instance")]
    Instantiate(#[source] InstantiateError),
}
45
46impl LoadError {
47 #[doc(hidden)]
50 #[must_use]
51 pub fn invalid_data_example() -> Self {
52 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
53 }
54}
55
/// Errors returned while creating a [`Policy`] instance from a [`PolicyFactory`].
#[derive(Debug, Error)]
pub enum InstantiateError {
    /// The OPA WASM runtime could not be created for the compiled module.
    #[error("failed to create WASM runtime")]
    Runtime(#[source] anyhow::Error),

    /// One of the configured [`Entrypoints`] is not exported by the module.
    #[error("missing entrypoint {entrypoint}")]
    MissingEntrypoint { entrypoint: String },

    /// Passing the merged policy data into the runtime failed.
    #[error("failed to load policy data")]
    LoadData(#[source] anyhow::Error),
}
67
/// Names of the policy entrypoints evaluated by the [`Policy`] methods.
///
/// Every entrypoint listed here must be exported by the module, otherwise
/// instantiation fails with [`InstantiateError::MissingEntrypoint`].
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint used by [`Policy::evaluate_register`].
    pub register: String,
    /// Entrypoint used by [`Policy::evaluate_client_registration`].
    pub client_registration: String,
    /// Entrypoint used by [`Policy::evaluate_authorization_grant`].
    pub authorization_grant: String,
    /// Entrypoint used by [`Policy::evaluate_email`].
    pub email: String,
}
76
77impl Entrypoints {
78 fn all(&self) -> [&str; 4] {
79 [
80 self.register.as_str(),
81 self.client_registration.as_str(),
82 self.authorization_grant.as_str(),
83 self.email.as_str(),
84 ]
85 }
86}
87
/// Static data passed to the policy: the server name, plus optional extra JSON
/// that is merged on top of it.
#[derive(Debug)]
pub struct Data {
    /// Exposed to the policy as the `server_name` key.
    server_name: String,

    /// Extra JSON merged over `{"server_name": ...}` when the data is built.
    rest: Option<serde_json::Value>,
}
94
95impl Data {
96 #[must_use]
97 pub fn new(server_name: String) -> Self {
98 Self {
99 server_name,
100 rest: None,
101 }
102 }
103
104 #[must_use]
105 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
106 self.rest = Some(rest);
107 self
108 }
109
110 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
111 let base = serde_json::json!({
112 "server_name": self.server_name,
113 });
114
115 if let Some(rest) = &self.rest {
116 merge_data(base, rest.clone())
117 } else {
118 Ok(base)
119 }
120 }
121}
122
123fn value_kind(value: &serde_json::Value) -> &'static str {
124 match value {
125 serde_json::Value::Object(_) => "object",
126 serde_json::Value::Array(_) => "array",
127 serde_json::Value::String(_) => "string",
128 serde_json::Value::Number(_) => "number",
129 serde_json::Value::Bool(_) => "boolean",
130 serde_json::Value::Null => "null",
131 }
132}
133
134fn merge_data(
135 mut left: serde_json::Value,
136 right: serde_json::Value,
137) -> Result<serde_json::Value, anyhow::Error> {
138 merge_data_rec(&mut left, right)?;
139 Ok(left)
140}
141
142fn merge_data_rec(
143 left: &mut serde_json::Value,
144 right: serde_json::Value,
145) -> Result<(), anyhow::Error> {
146 match (left, right) {
147 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
148 for (key, value) in right {
149 if let Some(left_value) = left.get_mut(&key) {
150 merge_data_rec(left_value, value)?;
151 } else {
152 left.insert(key, value);
153 }
154 }
155 }
156 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
157 left.extend(right);
158 }
159 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
161 *left = right;
162 }
163 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
164 *left = right;
165 }
166 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
167 *left = right;
168 }
169
170 (left, right) if left.is_null() => *left = right,
172
173 (left, right) if right.is_null() => *left = right,
175
176 (left, right) => anyhow::bail!(
177 "Cannot merge a {} into a {}",
178 value_kind(&right),
179 value_kind(left),
180 ),
181 }
182
183 Ok(())
184}
185
/// The data currently used to instantiate policies.
struct DynamicData {
    /// Id of the dynamic data applied via `set_dynamic_data`, or `None` when
    /// only the static data is in use.
    version: Option<Ulid>,
    /// Static data with the dynamic data (if any) merged on top.
    merged: serde_json::Value,
}
190
/// Holds a compiled policy module and its data, and creates [`Policy`]
/// instances from them.
pub struct PolicyFactory {
    /// wasmtime engine the module was compiled with.
    engine: Engine,
    /// The compiled policy module.
    module: Module,
    /// Static data, re-used as the base whenever dynamic data is applied.
    data: Data,
    /// Currently active merged data; swapped atomically by `set_dynamic_data`.
    dynamic_data: ArcSwap<DynamicData>,
    /// Entrypoint names each instance must export.
    entrypoints: Entrypoints,
}
198
199impl PolicyFactory {
200 #[tracing::instrument(name = "policy.load", skip(source))]
201 pub async fn load(
202 mut source: impl AsyncRead + std::marker::Unpin,
203 data: Data,
204 entrypoints: Entrypoints,
205 ) -> Result<Self, LoadError> {
206 let mut config = Config::default();
207 config.async_support(true);
208 config.cranelift_opt_level(OptLevel::SpeedAndSize);
209
210 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
211
212 let mut buf = Vec::new();
214 source.read_to_end(&mut buf).await?;
215 let (engine, module) = tokio::task::spawn_blocking(move || {
217 let module = Module::new(&engine, buf)?;
218 anyhow::Ok((engine, module))
219 })
220 .await?
221 .map_err(LoadError::Compilation)?;
222
223 let merged = data.to_value().map_err(LoadError::InvalidData)?;
224 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
225 version: None,
226 merged,
227 }));
228
229 let factory = Self {
230 engine,
231 module,
232 data,
233 dynamic_data,
234 entrypoints,
235 };
236
237 factory
239 .instantiate()
240 .await
241 .map_err(LoadError::Instantiate)?;
242
243 Ok(factory)
244 }
245
246 pub async fn set_dynamic_data(
259 &self,
260 dynamic_data: mas_data_model::PolicyData,
261 ) -> Result<bool, LoadError> {
262 if self.dynamic_data.load().version == Some(dynamic_data.id) {
265 return Ok(false);
267 }
268
269 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
270 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
271
272 self.instantiate_with_data(&merged)
274 .await
275 .map_err(LoadError::Instantiate)?;
276
277 self.dynamic_data.store(Arc::new(DynamicData {
279 version: Some(dynamic_data.id),
280 merged,
281 }));
282
283 Ok(true)
284 }
285
286 #[tracing::instrument(name = "policy.instantiate", skip_all)]
287 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
288 let data = self.dynamic_data.load();
289 self.instantiate_with_data(&data.merged).await
290 }
291
292 async fn instantiate_with_data(
293 &self,
294 data: &serde_json::Value,
295 ) -> Result<Policy, InstantiateError> {
296 let mut store = Store::new(&self.engine, ());
297 let runtime = Runtime::new(&mut store, &self.module)
298 .await
299 .map_err(InstantiateError::Runtime)?;
300
301 let policy_entrypoints = runtime.entrypoints();
303
304 for e in self.entrypoints.all() {
305 if !policy_entrypoints.contains(e) {
306 return Err(InstantiateError::MissingEntrypoint {
307 entrypoint: e.to_owned(),
308 });
309 }
310 }
311
312 let instance = runtime
313 .with_data(&mut store, data)
314 .await
315 .map_err(InstantiateError::LoadData)?;
316
317 Ok(Policy {
318 store,
319 instance,
320 entrypoints: self.entrypoints.clone(),
321 })
322 }
323}
324
/// An instantiated policy, ready to evaluate inputs.
///
/// The evaluation methods take `&mut self` because they need the wasmtime store.
pub struct Policy {
    /// wasmtime store owning this instance's state.
    store: Store<()>,
    /// The OPA policy instance with its data loaded.
    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
    /// Entrypoint names used by the `evaluate_*` methods.
    entrypoints: Entrypoints,
}
330
/// Error returned by the [`Policy`] evaluation methods.
///
/// All variants display the same message; the underlying cause is available
/// through [`std::error::Error::source`].
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
    /// A JSON (de)serialization error.
    Serialization(#[from] serde_json::Error),
    /// An error reported while evaluating the policy entrypoint.
    Evaluation(#[from] anyhow::Error),
}
337
338impl Policy {
339 #[tracing::instrument(
340 name = "policy.evaluate_email",
341 skip_all,
342 fields(
343 %input.email,
344 ),
345 )]
346 pub async fn evaluate_email(
347 &mut self,
348 input: EmailInput<'_>,
349 ) -> Result<EvaluationResult, EvaluationError> {
350 let [res]: [EvaluationResult; 1] = self
351 .instance
352 .evaluate(&mut self.store, &self.entrypoints.email, &input)
353 .await?;
354
355 Ok(res)
356 }
357
358 #[tracing::instrument(
359 name = "policy.evaluate.register",
360 skip_all,
361 fields(
362 ?input.registration_method,
363 input.username = input.username,
364 input.email = input.email,
365 ),
366 )]
367 pub async fn evaluate_register(
368 &mut self,
369 input: RegisterInput<'_>,
370 ) -> Result<EvaluationResult, EvaluationError> {
371 let [res]: [EvaluationResult; 1] = self
372 .instance
373 .evaluate(&mut self.store, &self.entrypoints.register, &input)
374 .await?;
375
376 Ok(res)
377 }
378
379 #[tracing::instrument(skip(self))]
380 pub async fn evaluate_client_registration(
381 &mut self,
382 input: ClientRegistrationInput<'_>,
383 ) -> Result<EvaluationResult, EvaluationError> {
384 let [res]: [EvaluationResult; 1] = self
385 .instance
386 .evaluate(
387 &mut self.store,
388 &self.entrypoints.client_registration,
389 &input,
390 )
391 .await?;
392
393 Ok(res)
394 }
395
396 #[tracing::instrument(
397 name = "policy.evaluate.authorization_grant",
398 skip_all,
399 fields(
400 %input.scope,
401 %input.client.id,
402 ),
403 )]
404 pub async fn evaluate_authorization_grant(
405 &mut self,
406 input: AuthorizationGrantInput<'_>,
407 ) -> Result<EvaluationResult, EvaluationError> {
408 let [res]: [EvaluationResult; 1] = self
409 .instance
410 .evaluate(
411 &mut self.store,
412 &self.entrypoints.authorization_grant,
413 &input,
414 )
415 .await?;
416
417 Ok(res)
418 }
419}
420
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names used by every test against `policies/policy.wasm`.
    fn test_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Load `../../policies/policy.wasm` (relative to the crate) with `data`.
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, test_entrypoints())
            .await
            .unwrap()
    }

    /// A password registration for the user `hello` with the given email.
    fn password_registration(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Run the register entrypoint and return whether the result is valid.
    async fn register_is_valid(policy: &mut Policy, email: &str) -> bool {
        policy
            .evaluate_register(password_registration(email))
            .await
            .unwrap()
            .valid()
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // example.com is not listed in allowed_domains
        assert!(!register_is_valid(&mut policy, "hello@example.com").await);
        // foo.element.io matches the "*.element.io" entry
        assert!(register_is_valid(&mut policy, "hello@foo.element.io").await);
        // staging.element.io is listed in banned_domains
        assert!(!register_is_valid(&mut policy, "hello@staging.element.io").await);
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        let mut policy = factory.instantiate().await.unwrap();
        assert!(register_is_valid(&mut policy, "hello@example.com").await);

        factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data: serde_json::json!({
                    "emails": {
                        "banned_addresses": {
                            "substrings": ["hello"]
                        }
                    }
                }),
            })
            .await
            .unwrap();

        // A fresh instance picks up the newly applied dynamic data.
        let mut policy = factory.instantiate().await.unwrap();
        assert!(!register_is_valid(&mut policy, "hello@example.com").await);
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned())).await;

        // 131072 zero-padded five-digit substrings.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });

        factory
            .set_dynamic_data(mas_data_model::PolicyData {
                id: Ulid::nil(),
                created_at: SystemTime::now().into(),
                data: json,
            })
            .await
            .unwrap();

        let mut policy = factory.instantiate().await.unwrap();
        assert!(!register_is_valid(&mut policy, "12345@example.com").await);
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        let cases = [
            (
                j!({"hello": "world"}),
                j!({"foo": "bar"}),
                j!({"hello": "world", "foo": "bar"}),
            ),
            (
                j!({"hello": "world"}),
                j!({"hello": "john"}),
                j!({"hello": "john"}),
            ),
            (
                j!({"hello": true}),
                j!({"hello": false}),
                j!({"hello": false}),
            ),
            (j!({"hello": 0}), j!({"hello": 42}), j!({"hello": 42})),
            (
                j!({"hello": ["world"]}),
                j!({"hello": ["john"]}),
                j!({"hello": ["world", "john"]}),
            ),
            (
                j!({"hello": "world"}),
                j!({"hello": null}),
                j!({"hello": null}),
            ),
            (
                j!({"hello": null}),
                j!({"hello": "world"}),
                j!({"hello": "world"}),
            ),
            (
                j!({"a": {"b": {"c": "d"}}}),
                j!({"a": {"b": {"e": "f"}}}),
                j!({"a": {"b": {"c": "d", "e": "f"}}}),
            ),
        ];

        for (left, right, expected) in cases {
            assert_eq!(merge_data(left, right).unwrap(), expected);
        }

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");
    }
}