1use std::{
2 fmt::Debug,
3 future::Future,
4 sync::{Arc, Condvar, Mutex, RwLock},
5 thread::JoinHandle as ThreadJoinHandle,
6 time::Duration,
7};
8
9use ::tokio::task::JoinHandle as TokioJoinHandle;
10
11use bftgrid_core::actor::{ActorMsg, ActorRef, Joinable, Task};
12
13pub mod thread;
14pub mod tokio;
15
16#[derive(Debug)]
17struct ThreadJoinable<T> {
18 value: Option<ThreadJoinHandle<T>>,
19}
20
21impl<T> Task for ThreadJoinable<T>
22where
23 T: Debug + Send,
24{
25 fn is_finished(&self) -> bool {
26 if let Some(ref value) = self.value {
27 value.is_finished()
28 } else {
29 true
30 }
31 }
32}
33
34impl Joinable<()> for ThreadJoinable<()> {
35 fn join(&mut self) {
36 let value = self.value.take();
37 if let Some(v) = value {
38 v.join().unwrap()
39 }
40 }
41}
42
43#[derive(Debug)]
44pub struct TokioTask<T> {
45 pub value: ::tokio::task::JoinHandle<T>,
46}
47
48impl<T> Task for TokioTask<T>
49where
50 T: Debug + Send,
51{
52 fn is_finished(&self) -> bool {
53 self.value.is_finished()
54 }
55}
56
57#[derive(Debug)]
68pub struct AsyncRuntime {
69 pub name: Arc<String>,
70 tokio: Arc<RwLock<Option<::tokio::runtime::Runtime>>>,
78}
79
80impl Drop for AsyncRuntime {
81 fn drop(&mut self) {
82 log::debug!("Dropping Tokio runtime '{}'", self.name);
83 if let Some(tokio) = self.tokio.write().unwrap().take() {
85 tokio.shutdown_background();
86 }
87 }
88}
89
90impl AsyncRuntime {
91 pub fn new(name: impl Into<String>, tokio: Option<::tokio::runtime::Runtime>) -> AsyncRuntime {
95 let runtime_name = Arc::new(name.into());
96 AsyncRuntime {
97 name: runtime_name.clone(),
98 tokio: Arc::new(RwLock::new(tokio.or({
99 log::debug!("Creating new Tokio runtime as '{}'", runtime_name);
100 Some(
101 ::tokio::runtime::Builder::new_multi_thread()
102 .enable_all()
103 .build()
104 .unwrap(),
105 )
106 }))),
107 }
108 }
109
110 pub fn block_on_async<R>(&self, f: impl Future<Output = R>) -> R {
114 match ::tokio::runtime::Handle::try_current() {
115 Ok(handle) => {
116 log::debug!(
117 "Tokio runtime '{}' blocking on async inside an async context",
118 self.name,
119 );
120 let _guard = handle.enter();
121 ::tokio::task::block_in_place(|| handle.block_on(f))
122 }
123 _ => {
124 log::debug!(
125 "Tokio runtime '{}' blocking on async outside of an async context",
126 self.name,
127 );
128 self.tokio.read().unwrap().as_ref().unwrap().block_on(f)
129 }
130 }
131 }
132
133 pub fn spawn_async<R>(
134 &self,
135 f: impl Future<Output = R> + Send + 'static,
136 ) -> ::tokio::task::JoinHandle<R>
137 where
138 R: Send + 'static,
139 {
140 match ::tokio::runtime::Handle::try_current() {
141 Ok(handle) => handle.spawn(f),
142 _ => self.tokio.read().unwrap().as_ref().unwrap().spawn(f),
143 }
144 }
145
146 pub fn thread_blocking<R>(&self, f: impl FnOnce() -> R) -> R {
147 match ::tokio::runtime::Handle::try_current() {
148 Ok(handle) => {
149 log::debug!(
150 "Tokio runtime '{}' blocking thread inside an async context",
151 self.name,
152 );
153 let _guard = handle.enter();
154 ::tokio::task::block_in_place(f)
155 }
156 _ => {
157 log::debug!(
158 "Tokio runtime '{}' blocking thread outside of an async context",
159 self.name,
160 );
161 f()
162 }
163 }
164 }
165
166 pub fn spawn_thread_blocking_send<MsgT>(
175 &self,
176 f: impl FnOnce() -> MsgT + Send + 'static,
177 actor_ref: impl ActorRef<MsgT> + 'static,
178 delay: Option<Duration>,
179 ) -> ::tokio::task::JoinHandle<()>
180 where
181 MsgT: ActorMsg + 'static,
182 {
183 match ::tokio::runtime::Handle::try_current() {
184 Ok(handle) => {
185 log::debug!(
186 "Tokio runtime '{}' performing blocking and then send inside an async context",
187 self.name,
188 );
189 let _guard = handle.enter();
190 self.spawn_async_blocking_send(f, actor_ref, delay)
191 }
192 _ => {
193 log::debug!(
194 "Tokio runtime '{}' performing blocking and then send outside of an async context",
195 self.name,
196 );
197 let _guard = self.tokio.read().unwrap().as_ref().unwrap().enter();
198 self.spawn_async_blocking_send(f, actor_ref, delay)
199 }
200 }
201 }
202
203 fn spawn_async_blocking_send<MsgT>(
204 &self,
205 f: impl FnOnce() -> MsgT + Send + 'static,
206 mut actor_ref: impl ActorRef<MsgT> + 'static,
207 delay: Option<Duration>,
208 ) -> ::tokio::task::JoinHandle<()>
209 where
210 MsgT: ActorMsg + 'static,
211 {
212 let actor_system_name = self.name.clone();
213 self.spawn_async(async move {
214 match ::tokio::task::spawn_blocking(f).await {
215 Ok(result) => actor_ref.send(result, delay),
216 Err(_) => log::error!(
217 "Tokio runtime '{}': blocking send task failed",
218 actor_system_name
219 ),
220 };
221 })
222 }
223}
224
/// Sets the shared "closed" flag to `true` and wakes every thread waiting on the paired
/// condition variable. The mutex stays locked while the waiters are notified.
fn notify_close(close_cond: Arc<(Mutex<bool>, Condvar)>) {
    let mut is_closed = close_cond.0.lock().unwrap();
    *is_closed = true;
    close_cond.1.notify_all();
}
231
232fn cleanup_complete_tasks<TaskT>(tasks: &mut Vec<TaskT>) -> &mut Vec<TaskT>
233where
234 TaskT: Task,
235{
236 tasks.retain(|t| !t.is_finished());
237 tasks
238}
239
240fn spawn_async_task<T>(
241 tasks: &mut Vec<TokioTask<T>>,
242 runtime: &AsyncRuntime,
243 future: impl std::future::Future<Output = T> + Send + 'static,
244) where
245 T: Send + std::fmt::Debug + 'static,
246{
247 cleanup_complete_tasks(tasks).push(TokioTask {
248 value: runtime.spawn_async(future),
249 });
250}
251
252fn push_async_task<T>(tasks: &mut Vec<TokioTask<T>>, tokio_join_handle: TokioJoinHandle<T>)
253where
254 T: Send + std::fmt::Debug,
255{
256 cleanup_complete_tasks(tasks).push(TokioTask {
257 value: tokio_join_handle,
258 });
259}
260
261fn join_tasks(async_runtime: &AsyncRuntime, tasks: (Vec<ThreadJoinable<()>>, Vec<TokioTask<()>>)) {
262 let (thread_blocking_tasks, async_tasks) = tasks;
263 log::debug!(
264 "Thread actor system '{}' joining {} thread blocking tasks and {} async tasks",
265 async_runtime.name,
266 thread_blocking_tasks.len(),
267 async_tasks.len(),
268 );
269 for mut t in thread_blocking_tasks {
270 t.join();
271 }
272 for t in async_tasks {
273 async_runtime.block_on_async(async move { t.value.await.unwrap() });
274 }
275}