Initial commit

This commit is contained in:
Bad Manners 2024-09-04 20:00:47 -03:00
commit baa7963a77
Signed by: badmanners
GPG key ID: 8C88292CCB075609
5 changed files with 2601 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

2213
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

25
Cargo.toml Normal file
View file

@ -0,0 +1,25 @@
[package]
authors = ["Bad Manners <me@badmanners.xyz>"]
description = "Remote port forwarding (reverse tunneling) with Russh to serve an Axum application."
name = "russh-axum-tcpip-forward"
version = "0.1.0"
edition = "2021"
keywords = ["ssh"]
license = "MIT"
readme = "README.md"
[dependencies]
anyhow = "1.0.86"
async-trait = "0.1"
axum = "0.7.5"
clap = { version = "4.5.17", features = ["derive"] }
futures = "0.3.30"
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
russh = "0.45"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1.15", features = ["net", "sync"] }
tokio-util = "0.7.11"
tower = "0.5.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3.18" }

5
README.md Normal file
View file

@ -0,0 +1,5 @@
# Russh + Axum + tcpip_forward!
A Rust project demonstrating how to serve Axum's HTTP server on a remote host's port, using SSH tunneling and streaming to avoid opening a socket on the client.
Tokio, Tower, hyper, and `async` are responsible for gluing everything together. They are pretty awesome! The hardest part to implement was Axum's half; mainly, figuring out how to accept a streaming socket instead of the default TcpListener.

357
src/main.rs Normal file
View file

@ -0,0 +1,357 @@
use core::str;
use std::{
iter,
path::PathBuf,
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use axum::{
extract::State,
routing::{get, RouterIntoService},
Router,
};
use clap::Parser;
use futures::future::poll_fn;
use hyper::{body::Incoming, service::service_fn};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use russh::{
client::{self, Config, Handle, Msg, Session},
keys::{
decode_secret_key,
key::{self, KeyPair},
},
Channel, ChannelMsg, Disconnect,
};
use tokio::{
fs,
sync::watch::{self, Receiver, Sender},
time::sleep,
};
use tower::Service;
use tracing::{debug, debug_span, error, info, trace, warn};
/* Entrypoint */
/// Remote port forwarding (reverse tunneling) with Russh to serve an Axum application.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct ClapArgs {
/// SSH hostname
#[arg(short = 'H', long)]
host: String,
/// SSH port
#[arg(short, long, default_value_t = 22)]
port: u16,
/// Identity file containing private key
#[arg(short, long)]
identity_file: PathBuf,
/// Remote hostname to bind to
#[arg(short, long, default_value_t = String::from("localhost"))]
remote_host: String,
/// Remote port to bind to
#[arg(short = 't', long, default_value_t = 80)]
remote_port: u16,
}
#[tokio::main]
async fn main() -> Result<()> {
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber)?;
trace!("Tracing is up!");
let args = ClapArgs::parse();
let secret_key = fs::read_to_string(args.identity_file)
.await
.with_context(|| "Failed to open secret key")?;
let secret_key = decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?;
let config = Arc::new(client::Config {
..Default::default()
});
let mut session =
TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key))
.await
.with_context(|| "Initial connection failed")?;
let state = Arc::new(Mutex::new(0));
let router = router_factory(AppState {
data: Arc::clone(&state),
});
loop {
match session
.start_forwarding_with(&args.remote_host, args.remote_port, || router.clone())
.await
{
Err(e) => error!(error = ?e, "TCP forward session failed."),
_ => info!("Connection closed."),
}
debug!("Attempting graceful disconnect.");
match session.close().await {
Err(e) => debug!(error = ?e, "Graceful disconnect failed."),
_ => (),
}
debug!("Restarting connection.");
let mut reconnect_attempt = 0u64;
session = session
.reconnect_with(
&args.host,
args.port,
iter::from_fn(move || {
reconnect_attempt += 1;
if reconnect_attempt <= 5 {
Some(Duration::from_secs(2 * reconnect_attempt))
} else {
None
}
}),
)
.await
.with_context(|| "Reconnection failed.")?;
}
}
/* Axum router */
#[derive(Clone)]
struct AppState {
data: Arc<Mutex<usize>>,
}
/// A function that creates our Axum router. It will only be called once.
fn router_factory(state: AppState) -> Router {
Router::new().route("/", get(hello)).with_state(state)
}
/// A basic example endpoint that includes shared state.
async fn hello(State(state): State<AppState>) -> String {
let mut request_id = state.data.lock().unwrap();
*request_id += 1;
debug!(id = %request_id, "GET /");
format!("Hello, request #{}!", request_id)
}
/* Russh session and client */
/// User-implemented session type as a helper for interfacing with the SSH protocol.
struct TcpForwardSession {
config: Arc<Config>,
secret_key: Arc<KeyPair>,
session: Handle<Client>,
/// Tokio `watch` channel to transmit our Axum router to all `Client`s once, since they will be the
/// ones who need it.
router_tx: Option<Sender<Option<Router>>>,
}
/// User-implemented session type as a helper for interfacing with the SSH protocol.
impl TcpForwardSession {
/// Creates a connection with the SSH server.
async fn connect(
host: &str,
port: u16,
config: Arc<Config>,
secret_key: Arc<KeyPair>,
) -> Result<Self> {
let span = debug_span!("TcpForwardSession.connect");
let _enter = span;
debug!("TcpForwardSession connecting...");
let (router_fn_tx, router_fn_rx) = watch::channel(None);
let client = Client {
router_rx: router_fn_rx,
};
let mut session = client::connect(Arc::clone(&config), (host, port), client)
.await
.with_context(|| "Unable to connect to remote host.")?;
if !session
.authenticate_publickey("root", Arc::clone(&secret_key))
.await
.with_context(|| "Authentication error.")?
{
return Err(anyhow!("Authentication failed."));
}
Ok(Self {
config,
session,
secret_key,
router_tx: Some(router_fn_tx),
})
}
/// Sends a port forwarding request and opens a session to receive miscellaneous data.
/// The function yields when the session is broken (for example, if the connection was lost).
///
/// Here, we also pass a closure to create our Axum server, which will respond to clients connecting to the remote
/// port through a separate channel in the SSH connection. This will only be evaluated once.
async fn start_forwarding_with(
&mut self,
remote_host: &str,
remote_port: u16,
router_factory: impl FnOnce() -> Router,
) -> Result<()> {
let span = debug_span!("TcpForwardSession.start");
let _enter = span;
match self
.router_tx
.take()
.map(|router_tx| router_tx.send(Some(router_factory())))
.transpose()
{
Err(_) => return Err(anyhow!("Unable to send router.")),
_ => (),
}
self.session
.tcpip_forward(remote_host, remote_port.into())
.await
.with_context(|| "tcpip_forward error.")?;
let mut channel = self
.session
.channel_open_session()
.await
.with_context(|| "channel_open_session error.")?;
loop {
let Some(msg) = channel.wait().await else {
return Err(anyhow!("Unexpected end of channel."));
};
match msg {
ChannelMsg::Data { data } => {
print!("{}", String::from_utf8_lossy(&data));
}
ChannelMsg::Close => break,
msg => return Err(anyhow!("Unknown message type {:?}.", msg)),
}
}
Ok(())
}
/// Attempts to reconnect to the SSH server.
///
/// Our reconnection strategy comes from an iterator which yields `Duration`s, which tell us how long to delay
/// our next reconnection attempt for. The function will stop attempting to reconnect once the iterator
/// stops yielding values.
async fn reconnect_with(
self,
host: &str,
port: u16,
timer_iterator: impl Iterator<Item = Duration>,
) -> Result<Self> {
let TcpForwardSession {
config, secret_key, ..
} = self;
match TcpForwardSession::connect(host, port, config.clone(), secret_key.clone()).await {
Err(err) => {
let mut e = err;
loop {
for (i, duration) in timer_iterator.enumerate() {
sleep(duration).await;
trace!("Reconnection attempt #{}...", i + 1);
e = match TcpForwardSession::connect(
&host,
port,
config.clone(),
secret_key.clone(),
)
.await
{
Err(e) => e,
session => {
debug!(reconnection_attempts = i + 1, "Succeeded on reconnecting.");
return session;
}
}
}
warn!("Backing off from reconnection attempt.");
return Err(e);
}
}
session => {
debug!("Reconnected on first attempt.");
session
}
}
}
async fn close(&mut self) -> Result<()> {
self.session
.disconnect(Disconnect::ByApplication, "", "English")
.await?;
Ok(())
}
}
/// Our SSH client implementing the `Handler` callbacks for the functions we need to use.
struct Client {
router_rx: Receiver<Option<Router>>,
}
#[async_trait]
impl client::Handler for Client {
type Error = anyhow::Error;
/// Always accept the SSH server's pubkey. Don't do this in production.
async fn check_server_key(
&mut self,
_server_public_key: &key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
/// Handle a new forwarded connection, represented by a specific `Channel`. We will create a clone of our router,
/// and forward any messages from this channel with its streaming API.
///
/// To make Axum behave with streaming, we must turn it into a Tower service first.
/// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our
/// AsyncRead/Write stream into a `hyper` IO object.
async fn server_channel_open_forwarded_tcpip(
&mut self,
channel: Channel<Msg>,
connected_address: &str,
connected_port: u32,
originator_address: &str,
originator_port: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
let span = debug_span!("server_channel_open_forwarded_tcpip",);
let _enter = span.enter();
debug!(
sshid = %String::from_utf8_lossy(session.remote_sshid()),
connected_address = connected_address,
connected_port = connected_port,
originator_address = originator_address,
originator_port = originator_port,
"New connection!"
);
let mut router: RouterIntoService<Incoming> = match self.router_rx.borrow().as_ref() {
None => {
return Err(anyhow!("No router has been set yet. Closing channel..."));
}
Some(ref router) => <Router as Clone>::clone(&router).into_service::<Incoming>(),
};
poll_fn(|cx| router.poll_ready(cx)).await.unwrap();
let service = service_fn(move |req| {
// Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut.
// This should be fine performance-wise, since RouterIntoService is a thin wrapper around Router,
// which itself is a thin wrapper around Arc<RouterInner<_>>.
let mut router = router.clone();
async move { router.call(req).await }
});
let socket = TokioIo::new(channel.into_stream());
let server = Builder::new(TokioExecutor::new());
// I'm not really sure why tokio::spawn is necessary here, but it doesn't work otherwise.
// My guess is that we block on TcpForwardSession.start_forwarding_with().
// We use `serve_connection_with_upgrades` to allow upgrading to WebSocket - which will still run through
// our SSH tunnel for every message!
tokio::spawn(async move {
server
.serve_connection_with_upgrades(socket, service)
.await
.unwrap();
});
Ok(())
}
}