1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
//! This module is in charge of storing closures with the type `Fn(I) -> O`
//! in a static lifetime storage, supporting mixing different `I` and `O`
//! types. Because we need to merge different closures with different types in
//! the same storage, we use a `u128` as closure identification (composed of
//! the closure function pointer (`u64`) and the pointer to the closure metadata
//! (`u64`)).

use std::{
	cell::RefCell,
	collections::HashMap,
	fmt,
	sync::{Arc, Mutex},
};

use super::util::TypeSignature;

/// Identifies a call in the call storage.
///
/// Values are handed out by `register_call` and consumed by `execute_call`.
pub type CallId = u64;

/// Type-erased record of one registered closure.
struct CallInfo {
	/// Closure identification: the wide `*const dyn Fn(I) -> O` pointer
	/// reinterpreted as an integer by `register_call`.
	ptr: u128,

	/// Runtime representation of the closure type.
	/// This field is needed to ensure we are getting the correct closure type,
	/// since the type known at compile time is lost in the `u128`
	/// representation of the closure.
	type_signature: TypeSignature,
}

/// Maps every `CallId` to the shared, mutex-guarded `CallInfo` registered
/// under it.
type Registry = HashMap<CallId, Arc<Mutex<CallInfo>>>;

thread_local! {
	static CALLS: RefCell<Registry> = RefCell::new(HashMap::default());
}

/// Failures reported by `execute_call`.
#[derive(Debug, PartialEq)]
pub enum Error {
	/// No call is registered under the requested `CallId`.
	CallNotFound,
	/// A call is registered under the requested `CallId`, but its stored type
	/// signature differs from the one requested by the caller.
	TypeNotMatch {
		/// Signature built from the `I`/`O` types the caller asked for.
		expected: TypeSignature,
		/// Signature stored when the call was registered.
		found: TypeSignature,
	},
}

impl fmt::Display for Error {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			Error::CallNotFound => write!(f, "Trying to call a function that is not registered"),
			Error::TypeNotMatch { expected, found } => write!(
				f,
				"The function is registered but the type mismatches. Expected {expected}, found: {found}",
			),
		}
	}
}

/// Stores `f` in the current thread's call storage and returns the `CallId`
/// under which it can later be run through `execute_call`.
///
/// The closure is intentionally leaked: its allocation is never freed, so the
/// stored pointer stays valid for the rest of the program.
pub fn register_call<F: Fn(I) -> O + 'static, I, O>(f: F) -> CallId {
	// Erase the concrete closure type behind a trait object living at a fixed
	// heap address.
	let boxed: Box<dyn Fn(I) -> O> = Box::new(f);

	// Give up ownership of the allocation; only its address is kept from here
	// on and the box is never reconstructed or dropped.
	let raw: *const dyn Fn(I) -> O = Box::into_raw(boxed);

	// The compile-time type is about to be forgotten, so record it at runtime
	// to be able to validate later executions.
	let type_signature = TypeSignature::new::<I, O>();

	// Reinterpreting the pointer as a plain integer is what allows closures of
	// different types to share one storage.
	// SAFETY: `transmute` only compiles when both types have the same size, and
	// the integer is only ever turned back into this very pointer type after a
	// type signature check.
	let ptr: u128 = unsafe { std::mem::transmute(raw) };

	let entry = Arc::new(Mutex::new(CallInfo {
		ptr,
		type_signature,
	}));

	CALLS.with(|state| {
		let mut registry = state.borrow_mut();
		// The next id is the number of calls registered so far on this thread.
		let call_id = registry.len() as u64;
		registry.insert(call_id, entry);
		call_id
	})
}

/// Execute a call from the call storage identified by a `call_id`.
pub fn execute_call<I, O>(call_id: CallId, input: I) -> Result<O, Error> {
	let expected_type_signature = TypeSignature::new::<I, O>();

	let call = CALLS.with(|state| {
		let registry = &*state.borrow();
		let call = registry.get(&call_id).ok_or(Error::CallNotFound)?;
		Ok(call.clone())
	})?;

	let call = call.lock().unwrap();

	// We need the runtime type check since we lost the type at compile time.
	if expected_type_signature != call.type_signature {
		return Err(Error::TypeNotMatch {
			expected: expected_type_signature,
			found: call.type_signature.clone(),
		});
	}

	// SAFETY:
	// 1. The existence of this closure ptr in consequent calls is ensured
	// thanks to Box::into_raw() at register_call(),
	// which takes the Box ownership without dropping it. So, ptr exists forever.
	// 2. The type of the transmuted call is ensured in runtime by the above type
	// signature check.
	// 3. The pointer is correctly aligned because it was allocated by a Box.
	// 4. The closure is called once at the same time thanks to the mutex.
	let f: &dyn Fn(I) -> O = unsafe {
		#[allow(clippy::useless_transmute)] // Clippy hints something erroneous
		let ptr: *const dyn Fn(I) -> O = std::mem::transmute(call.ptr);
		&*ptr
	};

	Ok(f(input))
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Registers the closure shared by these tests: it takes a `u8` and
	/// returns that value multiplied by 23 as a `usize`.
	fn register_multiplier() -> CallId {
		register_call(|n: u8| -> usize { 23 * n as usize })
	}

	/// Executing with the registered input and output types runs the closure.
	#[test]
	fn correct_type() {
		let id = register_multiplier();

		assert_eq!(execute_call::<u8, usize>(id, 2), Ok(46));
	}

	/// Executing with another input type is rejected, reporting both the
	/// requested and the registered signatures.
	#[test]
	fn different_input_type() {
		let id = register_multiplier();

		let mismatch = Error::TypeNotMatch {
			expected: TypeSignature::new::<char, usize>(),
			found: TypeSignature::new::<u8, usize>(),
		};

		assert_eq!(execute_call::<char, usize>(id, 'a'), Err(mismatch));
	}

	/// Executing with another output type is rejected, reporting both the
	/// requested and the registered signatures.
	#[test]
	fn different_output_type() {
		let id = register_multiplier();

		let mismatch = Error::TypeNotMatch {
			expected: TypeSignature::new::<u8, char>(),
			found: TypeSignature::new::<u8, usize>(),
		};

		assert_eq!(execute_call::<u8, char>(id, 2), Err(mismatch));
	}

	/// Executing an id that was never registered is reported as not found.
	#[test]
	fn no_registered() {
		let unknown_id: CallId = 42;

		assert_eq!(
			execute_call::<u8, usize>(unknown_id, 2),
			Err(Error::CallNotFound)
		);
	}
}