bftgrid_mt/
thread.rs

1use std::{
2    fmt::Debug,
3    future::Future,
4    mem,
5    sync::{
6        Arc, Condvar, Mutex,
7        mpsc::{self, Sender},
8    },
9    thread,
10    time::Duration,
11};
12
13use crate::{
14    AsyncRuntime, ThreadJoinable, TokioTask, cleanup_complete_tasks, join_tasks, notify_close,
15    push_async_task,
16};
17use bftgrid_core::actor::{
18    ActorControl, ActorMsg, ActorRef, ActorSystemHandle, DynMsgHandler, Joinable, Task,
19};
20use tokio::runtime::Runtime;
21use tokio::task::JoinHandle as TokioJoinHandle;
22
23#[derive(Debug)]
24pub struct ThreadActorData<MsgT>
25where
26    MsgT: ActorMsg,
27{
28    tx: Sender<MsgT>,
29    handler_tx: Sender<Arc<Mutex<DynMsgHandler<MsgT>>>>,
30    close_cond: Arc<(Mutex<bool>, Condvar)>,
31    name: Arc<String>,
32}
33
34#[derive(Debug)]
35struct ThreadActor<MsgT>
36where
37    MsgT: ActorMsg,
38{
39    data: ThreadActorData<MsgT>,
40    actor_system_handle: ThreadActorSystemHandle,
41    join_on_drop: bool,
42}
43
44impl<Msg> ThreadActor<Msg>
45where
46    Msg: ActorMsg,
47{
48    fn join(&self) {
49        let (close_mutex, cvar) = &*self.data.close_cond;
50        let mut closed = close_mutex.lock().unwrap();
51        while !*closed {
52            closed = cvar.wait(closed).unwrap();
53        }
54    }
55}
56
57impl<MsgT> Drop for ThreadActor<MsgT>
58where
59    MsgT: ActorMsg,
60{
61    fn drop(&mut self) {
62        if self.join_on_drop {
63            log::debug!("Thread actor '{}' dropping, joining", self.data.name);
64            self.join();
65        } else {
66            log::debug!("Thread actor '{}' dropping, not joining", self.data.name);
67        }
68    }
69}
70
71impl<MsgT> Task for ThreadActor<MsgT>
72where
73    MsgT: ActorMsg,
74{
75    fn is_finished(&self) -> bool {
76        *self.data.close_cond.0.lock().unwrap()
77    }
78}
79
80#[derive(Debug)]
81pub struct ThreadActorRef<MsgT>
82where
83    MsgT: ActorMsg,
84{
85    actor: Arc<ThreadActor<MsgT>>,
86}
87
88impl<MsgT> Clone for ThreadActorRef<MsgT>
89where
90    MsgT: ActorMsg,
91{
92    fn clone(&self) -> Self {
93        ThreadActorRef {
94            actor: self.actor.clone(),
95        }
96    }
97}
98
99impl<MsgT> Task for ThreadActorRef<MsgT>
100where
101    MsgT: ActorMsg,
102{
103    fn is_finished(&self) -> bool {
104        self.actor.is_finished()
105    }
106}
107
108impl<MsgT> Joinable<()> for ThreadActorRef<MsgT>
109where
110    MsgT: ActorMsg,
111{
112    fn join(&mut self) {
113        self.actor.join();
114    }
115}
116
117impl<MsgT> ActorRef<MsgT> for ThreadActorRef<MsgT>
118where
119    MsgT: ActorMsg,
120{
121    fn send(&mut self, message: MsgT, delay: Option<Duration>) {
122        let sender = self.actor.data.tx.clone();
123        if let Some(delay_duration) = delay {
124            self.actor
125                .actor_system_handle
126                .spawn_thread_blocking_task(move || {
127                    log::debug!("Delaying send by {:?}", delay_duration);
128                    thread::sleep(delay_duration);
129                    checked_send(sender, message);
130                });
131        } else {
132            // No need to spawn a thread if no delay is needed, as the sender is non-blocking
133            checked_send(sender, message);
134        }
135    }
136
137    fn set_handler(&mut self, handler: DynMsgHandler<MsgT>) {
138        self.actor
139            .data
140            .handler_tx
141            .send(Arc::new(Mutex::new(handler)))
142            .unwrap();
143    }
144
145    fn spawn_async_send(
146        &mut self,
147        f: impl Future<Output = MsgT> + Send + 'static,
148        delay: Option<Duration>,
149    ) {
150        let mut self_clone = self.clone();
151        let mut actor_system_lock_guard =
152            self.actor.actor_system_handle.actor_system.lock().unwrap();
153        let async_runtime = actor_system_lock_guard.async_runtime.clone();
154        actor_system_lock_guard.push_async_task(async_runtime.spawn_async(async move {
155            self_clone.send(f.await, delay);
156        }));
157    }
158
159    fn spawn_thread_blocking_send(
160        &mut self,
161        f: impl FnOnce() -> MsgT + Send + 'static,
162        delay: Option<Duration>,
163    ) {
164        let self_clone = self.clone();
165        let mut actor_system_lock_guard =
166            self.actor.actor_system_handle.actor_system.lock().unwrap();
167        let async_runtime = actor_system_lock_guard.async_runtime.clone();
168        actor_system_lock_guard
169            .push_async_task(async_runtime.spawn_thread_blocking_send(f, self_clone, delay));
170    }
171}
172
173fn checked_send<MsgT>(sender: Sender<MsgT>, message: MsgT)
174where
175    MsgT: ActorMsg,
176{
177    match sender.send(message) {
178        Ok(_) => {}
179        Err(e) => {
180            log::warn!("Send from thread actor failed: {:?}", e);
181        }
182    }
183}
184
/// State of a thread-based actor system: the async runtime it uses and the
/// tasks spawned through it.
#[derive(Debug)]
pub struct ThreadActorSystem {
    /// Runtime used to spawn async work; its `name` also identifies the system in logs.
    async_runtime: Arc<AsyncRuntime>,
    /// Join handles of the OS threads spawned via `spawn_thread_blocking_task`.
    thread_blocking_tasks: Vec<ThreadJoinable<()>>,
    /// Join handles of the tasks registered via `push_async_task`.
    async_tasks: Vec<TokioTask<()>>,
    /// When `true`, dropping the system joins all recorded tasks.
    join_tasks_on_drop: bool,
}
192
193impl ThreadActorSystem {
194    fn spawn_thread_blocking_task(&mut self, f: impl FnOnce() + Send + 'static) {
195        cleanup_complete_tasks(&mut self.thread_blocking_tasks).push(ThreadJoinable {
196            value: Some(thread::spawn(f)),
197        });
198    }
199
200    fn extract_tasks(&mut self) -> (Vec<ThreadJoinable<()>>, Vec<TokioTask<()>>) {
201        (
202            mem::take(&mut self.thread_blocking_tasks),
203            mem::take(&mut self.async_tasks),
204        )
205    }
206
207    fn create<MsgT>(
208        &mut self,
209        name: impl Into<String>,
210        node_id: impl Into<String>,
211    ) -> ThreadActorData<MsgT>
212    where
213        MsgT: ActorMsg,
214    {
215        let (tx, rx) = mpsc::channel();
216        let (handler_tx, handler_rx) = mpsc::channel::<Arc<Mutex<DynMsgHandler<MsgT>>>>();
217        let close_cond = Arc::new((Mutex::new(false), Condvar::new()));
218        let close_cond2 = close_cond.clone();
219        let actor_name = Arc::new(name.into());
220        let actor_name_clone = actor_name.clone();
221        let actor_node_id = node_id.into();
222        let actor_system_name = self.async_runtime.name.clone();
223        self.spawn_thread_blocking_task(move || {
224            let mut current_handler = handler_rx.recv().unwrap();
225            log::debug!("Started actor '{}' on node '{}' in thread actor system '{}'", actor_name, actor_node_id, actor_system_name);
226            loop {
227                if let Ok(new_handler) = handler_rx.try_recv() {
228                    log::debug!("Thread actor '{}' on node '{}' in thread actor system '{}': new handler received", actor_name, actor_node_id, actor_system_name);
229                    current_handler = new_handler;
230                }
231                match rx.recv() {
232                    Err(_) => {
233                        log::info!("Thread actor '{}' on node '{}' in thread actor system '{}': shutting down due to message receive channel having being closed", actor_name, actor_node_id, actor_system_name);
234                        notify_close(close_cond2);
235                        return;
236                    }
237                    Ok(m) => {
238                        if let Some(control) = current_handler.lock().unwrap().receive(m) {
239                            match control {
240                                ActorControl::Exit() => {
241                                    log::info!("Thread actor '{}' on node '{}' in thread actor system '{}': closing requested by handler, shutting it down", actor_name, actor_node_id, actor_system_name);
242                                    notify_close(close_cond2);
243                                    return;
244                                }
245                            }
246                        }
247                    }
248                }
249            }
250        });
251        ThreadActorData {
252            tx,
253            handler_tx,
254            close_cond,
255            name: actor_name_clone,
256        }
257    }
258
259    fn push_async_task(&mut self, tokio_join_handle: TokioJoinHandle<()>) {
260        push_async_task(&mut self.async_tasks, tokio_join_handle);
261    }
262}
263
264impl Drop for ThreadActorSystem {
265    fn drop(&mut self) {
266        if self.join_tasks_on_drop {
267            log::debug!(
268                "Thread actor system '{}' dropping, joining tasks",
269                self.async_runtime.name
270            );
271            join_tasks(self.async_runtime.clone().as_ref(), self.extract_tasks());
272        } else {
273            log::debug!(
274                "Thread actor system '{}' dropping, not joining tasks",
275                self.async_runtime.name
276            );
277        }
278    }
279}
280
281impl Task for ThreadActorSystem {
282    fn is_finished(&self) -> bool {
283        self.thread_blocking_tasks.iter().all(|h| h.is_finished())
284    }
285}
286
/// Cloneable handle to a `ThreadActorSystem`; all clones share the same
/// mutex-protected system.
#[derive(Clone, Debug)]
pub struct ThreadActorSystemHandle {
    /// The shared actor system state.
    actor_system: Arc<Mutex<ThreadActorSystem>>,
}
291
292impl ThreadActorSystemHandle {
293    /// Owns the passed runtime, using it only if no contextual handle is available;
294    ///  if `None` is passed, it creates a runtime with multi-threaded support,
295    ///  CPU-based thread pool size and all features enabled.
296    pub fn new_actor_system(
297        name: impl Into<String>,
298        tokio: Option<Runtime>,
299        join_tasks_on_drop: bool,
300    ) -> Self {
301        ThreadActorSystemHandle {
302            actor_system: Arc::new(Mutex::new(ThreadActorSystem {
303                async_runtime: Arc::new(AsyncRuntime::new(name, tokio)),
304                thread_blocking_tasks: Default::default(),
305                async_tasks: Default::default(),
306                join_tasks_on_drop,
307            })),
308        }
309    }
310
311    pub fn spawn_thread_blocking_task(&self, f: impl FnOnce() + Send + 'static) {
312        self.actor_system
313            .lock()
314            .unwrap()
315            .spawn_thread_blocking_task(f);
316    }
317}
318
319impl ActorSystemHandle for ThreadActorSystemHandle {
320    type ActorRefT<MsgT>
321        = ThreadActorRef<MsgT>
322    where
323        MsgT: ActorMsg;
324
325    fn create<MsgT>(
326        &self,
327        node_id: impl Into<String>,
328        name: impl Into<String>,
329        join_on_drop: bool,
330    ) -> Self::ActorRefT<MsgT>
331    where
332        MsgT: ActorMsg,
333    {
334        ThreadActorRef {
335            actor: Arc::new(ThreadActor {
336                data: self.actor_system.lock().unwrap().create(name, node_id),
337                actor_system_handle: self.clone(),
338                join_on_drop,
339            }),
340        }
341    }
342}
343
344impl Task for ThreadActorSystemHandle {
345    fn is_finished(&self) -> bool {
346        self.actor_system.lock().unwrap().is_finished()
347    }
348}
349
350impl Joinable<()> for ThreadActorSystemHandle {
351    fn join(&mut self) {
352        let mut actor_system_lock_guard = self.actor_system.lock().unwrap();
353        let async_runtime = actor_system_lock_guard.async_runtime.clone();
354        let tasks = actor_system_lock_guard.extract_tasks();
355        // Drop the lock before joining tasks to avoid deadlocks if they also lock the actor system
356        drop(actor_system_lock_guard);
357        join_tasks(async_runtime.as_ref(), tasks);
358    }
359}