Lazy router and clippy rules

This commit is contained in:
Bad Manners 2024-09-04 22:24:35 -03:00
parent baa7963a77
commit 9e046c8821
2 changed files with 37 additions and 71 deletions

View file

@ -4,7 +4,7 @@ description = "Remote port forwarding (reverse tunneling) with Russh to serve an
name = "russh-axum-tcpip-forward"
version = "0.1.0"
edition = "2021"
keywords = ["ssh"]
keywords = ["ssh", "http", "async", "demo"]
license = "MIT"
readme = "README.md"

View file

@ -1,8 +1,7 @@
use core::str;
use std::{
iter,
path::PathBuf,
sync::{Arc, Mutex},
sync::{Arc, LazyLock, Mutex},
time::Duration,
};
@ -28,11 +27,7 @@ use russh::{
},
Channel, ChannelMsg, Disconnect,
};
use tokio::{
fs,
sync::watch::{self, Receiver, Sender},
time::sleep,
};
use tokio::{fs, time::sleep};
use tower::Service;
use tracing::{debug, debug_span, error, info, trace, warn};
@ -80,22 +75,17 @@ async fn main() -> Result<()> {
TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key))
.await
.with_context(|| "Initial connection failed")?;
let state = Arc::new(Mutex::new(0));
let router = router_factory(AppState {
data: Arc::clone(&state),
});
loop {
match session
.start_forwarding_with(&args.remote_host, args.remote_port, || router.clone())
.start_forwarding(&args.remote_host, args.remote_port)
.await
{
Err(e) => error!(error = ?e, "TCP forward session failed."),
_ => info!("Connection closed."),
}
debug!("Attempting graceful disconnect.");
match session.close().await {
Err(e) => debug!(error = ?e, "Graceful disconnect failed."),
_ => (),
if let Err(e) = session.close().await {
debug!(error = ?e, "Graceful disconnect failed.")
}
debug!("Restarting connection.");
let mut reconnect_attempt = 0u64;
@ -124,11 +114,18 @@ struct AppState {
data: Arc<Mutex<usize>>,
}
/// A function that creates our Axum router. It will only be called once.
/// A function that creates our Axum router.
fn router_factory(state: AppState) -> Router {
Router::new().route("/", get(hello)).with_state(state)
}
/// A lazily-created Router, to be used by the SSH client tunnels.
static ROUTER: LazyLock<Router> = LazyLock::new(|| {
router_factory(AppState {
data: Arc::new(Mutex::new(0)),
})
});
/// A basic example endpoint that includes shared state.
async fn hello(State(state): State<AppState>) -> String {
let mut request_id = state.data.lock().unwrap();
@ -144,9 +141,6 @@ struct TcpForwardSession {
config: Arc<Config>,
secret_key: Arc<KeyPair>,
session: Handle<Client>,
/// Tokio `watch` channel to transmit our Axum router to all `Client`s once, since they will be the
/// ones who need it.
router_tx: Option<Sender<Option<Router>>>,
}
/// User-implemented session type as a helper for interfacing with the SSH protocol.
@ -161,10 +155,7 @@ impl TcpForwardSession {
let span = debug_span!("TcpForwardSession.connect");
let _enter = span;
debug!("TcpForwardSession connecting...");
let (router_fn_tx, router_fn_rx) = watch::channel(None);
let client = Client {
router_rx: router_fn_rx,
};
let client = Client {};
let mut session = client::connect(Arc::clone(&config), (host, port), client)
.await
.with_context(|| "Unable to connect to remote host.")?;
@ -179,32 +170,14 @@ impl TcpForwardSession {
config,
session,
secret_key,
router_tx: Some(router_fn_tx),
})
}
/// Sends a port forwarding request and opens a session to receive miscellaneous data.
/// The function yields when the session is broken (for example, if the connection was lost).
///
/// Here, we also pass a closure to create our Axum server, which will respond to clients connecting to the remote
/// port through a separate channel in the SSH connection. This will only be evaluated once.
async fn start_forwarding_with(
&mut self,
remote_host: &str,
remote_port: u16,
router_factory: impl FnOnce() -> Router,
) -> Result<()> {
async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<()> {
let span = debug_span!("TcpForwardSession.start");
let _enter = span;
match self
.router_tx
.take()
.map(|router_tx| router_tx.send(Some(router_factory())))
.transpose()
{
Err(_) => return Err(anyhow!("Unable to send router.")),
_ => (),
}
self.session
.tcpip_forward(remote_host, remote_port.into())
.await
@ -246,28 +219,26 @@ impl TcpForwardSession {
match TcpForwardSession::connect(host, port, config.clone(), secret_key.clone()).await {
Err(err) => {
let mut e = err;
loop {
for (i, duration) in timer_iterator.enumerate() {
sleep(duration).await;
trace!("Reconnection attempt #{}...", i + 1);
e = match TcpForwardSession::connect(
&host,
port,
config.clone(),
secret_key.clone(),
)
.await
{
Err(e) => e,
session => {
debug!(reconnection_attempts = i + 1, "Succeeded on reconnecting.");
return session;
}
for (i, duration) in timer_iterator.enumerate() {
sleep(duration).await;
trace!("Reconnection attempt #{}...", i + 1);
e = match TcpForwardSession::connect(
host,
port,
config.clone(),
secret_key.clone(),
)
.await
{
Err(e) => e,
session => {
debug!(reconnection_attempts = i + 1, "Succeeded on reconnecting.");
return session;
}
}
warn!("Backing off from reconnection attempt.");
return Err(e);
}
warn!("Backing off from reconnection attempt.");
Err(e)
}
session => {
debug!("Reconnected on first attempt.");
@ -285,9 +256,7 @@ impl TcpForwardSession {
}
/// Our SSH client implementing the `Handler` callbacks for the functions we need to use.
struct Client {
router_rx: Receiver<Option<Router>>,
}
struct Client {}
#[async_trait]
impl client::Handler for Client {
@ -326,12 +295,9 @@ impl client::Handler for Client {
originator_port = originator_port,
"New connection!"
);
let mut router: RouterIntoService<Incoming> = match self.router_rx.borrow().as_ref() {
None => {
return Err(anyhow!("No router has been set yet. Closing channel..."));
}
Some(ref router) => <Router as Clone>::clone(&router).into_service::<Incoming>(),
};
// Get our router from the lazy static.
let mut router: RouterIntoService<Incoming> =
<Router as Clone>::clone(&*ROUTER).into_service::<Incoming>();
poll_fn(|cx| router.poll_ready(cx)).await.unwrap();
let service = service_fn(move |req| {
// Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut.