// echo_grpc_server/main.rs

mod conf;

use std::env::args;
use std::error::Error;
use std::io::{ErrorKind, Write};
use std::net::ToSocketAddrs;
use std::pin::Pin;

use anyhow::{Context, Result};
use clap::Parser;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tonic::transport::Server;
use tonic::{Request, Response, Status, Streaming};

use crate::pb::Tx;
use conf::Config as ServerConfig;

21pub mod pb {
22    tonic::include_proto!("grpc.echo");
23}
24
25type ResponseStream = Pin<Box<dyn Stream<Item = Result<Tx, Status>> + Send>>;
26
27#[derive(Parser)]
28#[command(author, version, about, long_about = None)]
29struct Cli {
30    #[arg(short, long, value_name = "FILE.toml")]
31    config: std::path::PathBuf,
32}
33
/// Stateless implementation of the generated `pb::service_server::Service`
/// trait: every inbound `Tx` is echoed back to the caller.
///
/// It carries no state, so `Default` and `Clone` are derived for convenience.
#[derive(Debug, Default, Clone)]
pub struct EchoServer {}

37#[tonic::async_trait]
38impl pb::service_server::Service for EchoServer {
39    type EchoStream = ResponseStream;
40
41    async fn echo(
42        &self,
43        req: Request<Streaming<Tx>>,
44    ) -> std::result::Result<Response<Self::EchoStream>, Status> {
45        let mut in_stream = req.into_inner();
46        let (sender, receiver) = mpsc::channel(128);
47
48        // this spawn here is required if you want to handle connection error.
49        // If we just map `in_stream` and write it back as `out_stream` the `out_stream`
50        // will be drooped when connection error occurs and error will never be propagated
51        // to mapped version of `in_stream`.
52        tokio::spawn(async move {
53            while let Some(result) = in_stream.next().await {
54                match result {
55                    Ok(tx) => sender
56                        .send(Ok(Tx {
57                            tx_id: tx.tx_id,
58                            value: tx.value,
59                        }))
60                        .await
61                        .expect("working rx"),
62                    Err(err) => {
63                        if let Some(io_err) = match_for_io_error(&err) {
64                            if io_err.kind() == ErrorKind::BrokenPipe {
65                                // here you can handle special case when client
66                                // disconnected in unexpected way
67                                log::info!("client disconnected: broken pipe");
68                                break;
69                            }
70                        }
71
72                        match sender.send(Err(err)).await {
73                            Ok(_) => (),
74                            Err(_err) => break, // response was dropped
75                        }
76                    }
77                }
78            }
79            log::info!("stream ended");
80        });
81
82        // echo just write the same data that was received
83        let out_stream = ReceiverStream::new(receiver);
84
85        Ok(Response::new(Box::pin(out_stream) as Self::EchoStream))
86    }
87}
88
89fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
90    let err: &(dyn Error + 'static) = err_status;
91
92    loop {
93        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
94            return Some(io_err);
95        }
96
97        // h2::Error do not expose std::io::Error with `source()`
98        // https://github.com/hyperium/h2/pull/462
99        if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
100            if let Some(io_err) = h2_err.get_io() {
101                return Some(io_err);
102            }
103        }
104
105        err.source()?;
106    }
107}
108
109#[tokio::main]
110async fn main() -> Result<()> {
111    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
112        .target(env_logger::Target::Stdout)
113        .format(|buf, record| {
114            let ts = buf.timestamp_micros();
115            writeln!(
116                buf,
117                "[{} {}{}{:#} {:?} {} {}:{}] {}",
118                ts,
119                buf.default_level_style(record.level()),
120                record.level(),
121                buf.default_level_style(record.level()),
122                std::thread::current().id(),
123                record.target(),
124                record.file().unwrap_or("<unknown>"),
125                record.line().unwrap_or(0),
126                record.args()
127            )
128        })
129        .init();
130
131    let cli = Cli::parse();
132
133    let settings = config::Config::builder()
134        .add_source(config::File::from(cli.config))
135        .add_source(config::Environment::with_prefix("BENCH"))
136        .build()?;
137
138    let config = settings.try_deserialize::<ServerConfig>()?;
139
140    let program_name = args().next().unwrap();
141    log::info!(
142        "'{}' starting, configuration loaded: {:?}",
143        program_name,
144        config
145    );
146
147    let server = EchoServer {};
148    let awaiter = Server::builder()
149        .add_service(pb::service_server::ServiceServer::new(server))
150        .serve(config.listen.to_socket_addrs().unwrap().next().unwrap());
151
152    log::info!("'{}' open for business", program_name);
153
154    awaiter.await.unwrap();
155
156    Ok(())
157}