Initial commit
This commit is contained in:
commit
baa7963a77
5 changed files with 2601 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
/target
|
2213
Cargo.lock
generated
Normal file
2213
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
25
Cargo.toml
Normal file
25
Cargo.toml
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
[package]
|
||||||
|
authors = ["Bad Manners <me@badmanners.xyz>"]
|
||||||
|
description = "Remote port forwarding (reverse tunneling) with Russh to serve an Axum application."
|
||||||
|
name = "russh-axum-tcpip-forward"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
keywords = ["ssh"]
|
||||||
|
license = "MIT"
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0.86"
|
||||||
|
async-trait = "0.1"
|
||||||
|
axum = "0.7.5"
|
||||||
|
clap = { version = "4.5.17", features = ["derive"] }
|
||||||
|
futures = "0.3.30"
|
||||||
|
hyper = { version = "1", features = ["full"] }
|
||||||
|
hyper-util = { version = "0.1", features = ["full"] }
|
||||||
|
russh = "0.45"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-stream = { version = "0.1.15", features = ["net", "sync"] }
|
||||||
|
tokio-util = "0.7.11"
|
||||||
|
tower = "0.5.0"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3.18" }
|
5
README.md
Normal file
5
README.md
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Russh + Axum + tcpip_forward!
|
||||||
|
|
||||||
|
A Rust project demonstrating how to serve Axum's HTTP server on a remote host's port, using SSH tunneling and streaming to avoid opening a socket on the client.
|
||||||
|
|
||||||
|
Tokio, Tower, hyper, and `async` are responsible for gluing everything together. They are pretty awesome! The hardest part to implement was Axum's half; mainly, figuring out how to accept a streaming socket instead of the default TcpListener.
|
357
src/main.rs
Normal file
357
src/main.rs
Normal file
|
@ -0,0 +1,357 @@
|
||||||
|
use core::str;
|
||||||
|
use std::{
|
||||||
|
iter,
|
||||||
|
path::PathBuf,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Context, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
routing::{get, RouterIntoService},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use clap::Parser;
|
||||||
|
use futures::future::poll_fn;
|
||||||
|
use hyper::{body::Incoming, service::service_fn};
|
||||||
|
use hyper_util::{
|
||||||
|
rt::{TokioExecutor, TokioIo},
|
||||||
|
server::conn::auto::Builder,
|
||||||
|
};
|
||||||
|
use russh::{
|
||||||
|
client::{self, Config, Handle, Msg, Session},
|
||||||
|
keys::{
|
||||||
|
decode_secret_key,
|
||||||
|
key::{self, KeyPair},
|
||||||
|
},
|
||||||
|
Channel, ChannelMsg, Disconnect,
|
||||||
|
};
|
||||||
|
use tokio::{
|
||||||
|
fs,
|
||||||
|
sync::watch::{self, Receiver, Sender},
|
||||||
|
time::sleep,
|
||||||
|
};
|
||||||
|
use tower::Service;
|
||||||
|
use tracing::{debug, debug_span, error, info, trace, warn};
|
||||||
|
|
||||||
|
/* Entrypoint */
|
||||||
|
|
||||||
|
/// Remote port forwarding (reverse tunneling) with Russh to serve an Axum application.
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(version, about, long_about = None)]
|
||||||
|
struct ClapArgs {
|
||||||
|
/// SSH hostname
|
||||||
|
#[arg(short = 'H', long)]
|
||||||
|
host: String,
|
||||||
|
|
||||||
|
/// SSH port
|
||||||
|
#[arg(short, long, default_value_t = 22)]
|
||||||
|
port: u16,
|
||||||
|
|
||||||
|
/// Identity file containing private key
|
||||||
|
#[arg(short, long)]
|
||||||
|
identity_file: PathBuf,
|
||||||
|
|
||||||
|
/// Remote hostname to bind to
|
||||||
|
#[arg(short, long, default_value_t = String::from("localhost"))]
|
||||||
|
remote_host: String,
|
||||||
|
|
||||||
|
/// Remote port to bind to
|
||||||
|
#[arg(short = 't', long, default_value_t = 80)]
|
||||||
|
remote_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
let subscriber = tracing_subscriber::FmtSubscriber::new();
|
||||||
|
tracing::subscriber::set_global_default(subscriber)?;
|
||||||
|
trace!("Tracing is up!");
|
||||||
|
let args = ClapArgs::parse();
|
||||||
|
let secret_key = fs::read_to_string(args.identity_file)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to open secret key")?;
|
||||||
|
let secret_key = decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?;
|
||||||
|
let config = Arc::new(client::Config {
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
let mut session =
|
||||||
|
TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key))
|
||||||
|
.await
|
||||||
|
.with_context(|| "Initial connection failed")?;
|
||||||
|
let state = Arc::new(Mutex::new(0));
|
||||||
|
let router = router_factory(AppState {
|
||||||
|
data: Arc::clone(&state),
|
||||||
|
});
|
||||||
|
loop {
|
||||||
|
match session
|
||||||
|
.start_forwarding_with(&args.remote_host, args.remote_port, || router.clone())
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Err(e) => error!(error = ?e, "TCP forward session failed."),
|
||||||
|
_ => info!("Connection closed."),
|
||||||
|
}
|
||||||
|
debug!("Attempting graceful disconnect.");
|
||||||
|
match session.close().await {
|
||||||
|
Err(e) => debug!(error = ?e, "Graceful disconnect failed."),
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
debug!("Restarting connection.");
|
||||||
|
let mut reconnect_attempt = 0u64;
|
||||||
|
session = session
|
||||||
|
.reconnect_with(
|
||||||
|
&args.host,
|
||||||
|
args.port,
|
||||||
|
iter::from_fn(move || {
|
||||||
|
reconnect_attempt += 1;
|
||||||
|
if reconnect_attempt <= 5 {
|
||||||
|
Some(Duration::from_secs(2 * reconnect_attempt))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Reconnection failed.")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Axum router */
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {
|
||||||
|
data: Arc<Mutex<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A function that creates our Axum router. It will only be called once.
|
||||||
|
fn router_factory(state: AppState) -> Router {
|
||||||
|
Router::new().route("/", get(hello)).with_state(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A basic example endpoint that includes shared state.
|
||||||
|
async fn hello(State(state): State<AppState>) -> String {
|
||||||
|
let mut request_id = state.data.lock().unwrap();
|
||||||
|
*request_id += 1;
|
||||||
|
debug!(id = %request_id, "GET /");
|
||||||
|
format!("Hello, request #{}!", request_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Russh session and client */
|
||||||
|
|
||||||
|
/// User-implemented session type as a helper for interfacing with the SSH protocol.
|
||||||
|
struct TcpForwardSession {
|
||||||
|
config: Arc<Config>,
|
||||||
|
secret_key: Arc<KeyPair>,
|
||||||
|
session: Handle<Client>,
|
||||||
|
/// Tokio `watch` channel to transmit our Axum router to all `Client`s once, since they will be the
|
||||||
|
/// ones who need it.
|
||||||
|
router_tx: Option<Sender<Option<Router>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// User-implemented session type as a helper for interfacing with the SSH protocol.
|
||||||
|
impl TcpForwardSession {
|
||||||
|
/// Creates a connection with the SSH server.
|
||||||
|
async fn connect(
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
config: Arc<Config>,
|
||||||
|
secret_key: Arc<KeyPair>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let span = debug_span!("TcpForwardSession.connect");
|
||||||
|
let _enter = span;
|
||||||
|
debug!("TcpForwardSession connecting...");
|
||||||
|
let (router_fn_tx, router_fn_rx) = watch::channel(None);
|
||||||
|
let client = Client {
|
||||||
|
router_rx: router_fn_rx,
|
||||||
|
};
|
||||||
|
let mut session = client::connect(Arc::clone(&config), (host, port), client)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Unable to connect to remote host.")?;
|
||||||
|
if !session
|
||||||
|
.authenticate_publickey("root", Arc::clone(&secret_key))
|
||||||
|
.await
|
||||||
|
.with_context(|| "Authentication error.")?
|
||||||
|
{
|
||||||
|
return Err(anyhow!("Authentication failed."));
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
session,
|
||||||
|
secret_key,
|
||||||
|
router_tx: Some(router_fn_tx),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a port forwarding request and opens a session to receive miscellaneous data.
|
||||||
|
/// The function yields when the session is broken (for example, if the connection was lost).
|
||||||
|
///
|
||||||
|
/// Here, we also pass a closure to create our Axum server, which will respond to clients connecting to the remote
|
||||||
|
/// port through a separate channel in the SSH connection. This will only be evaluated once.
|
||||||
|
async fn start_forwarding_with(
|
||||||
|
&mut self,
|
||||||
|
remote_host: &str,
|
||||||
|
remote_port: u16,
|
||||||
|
router_factory: impl FnOnce() -> Router,
|
||||||
|
) -> Result<()> {
|
||||||
|
let span = debug_span!("TcpForwardSession.start");
|
||||||
|
let _enter = span;
|
||||||
|
match self
|
||||||
|
.router_tx
|
||||||
|
.take()
|
||||||
|
.map(|router_tx| router_tx.send(Some(router_factory())))
|
||||||
|
.transpose()
|
||||||
|
{
|
||||||
|
Err(_) => return Err(anyhow!("Unable to send router.")),
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
self.session
|
||||||
|
.tcpip_forward(remote_host, remote_port.into())
|
||||||
|
.await
|
||||||
|
.with_context(|| "tcpip_forward error.")?;
|
||||||
|
let mut channel = self
|
||||||
|
.session
|
||||||
|
.channel_open_session()
|
||||||
|
.await
|
||||||
|
.with_context(|| "channel_open_session error.")?;
|
||||||
|
loop {
|
||||||
|
let Some(msg) = channel.wait().await else {
|
||||||
|
return Err(anyhow!("Unexpected end of channel."));
|
||||||
|
};
|
||||||
|
match msg {
|
||||||
|
ChannelMsg::Data { data } => {
|
||||||
|
print!("{}", String::from_utf8_lossy(&data));
|
||||||
|
}
|
||||||
|
ChannelMsg::Close => break,
|
||||||
|
msg => return Err(anyhow!("Unknown message type {:?}.", msg)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempts to reconnect to the SSH server.
|
||||||
|
///
|
||||||
|
/// Our reconnection strategy comes from an iterator which yields `Duration`s, which tell us how long to delay
|
||||||
|
/// our next reconnection attempt for. The function will stop attempting to reconnect once the iterator
|
||||||
|
/// stops yielding values.
|
||||||
|
async fn reconnect_with(
|
||||||
|
self,
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
timer_iterator: impl Iterator<Item = Duration>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let TcpForwardSession {
|
||||||
|
config, secret_key, ..
|
||||||
|
} = self;
|
||||||
|
match TcpForwardSession::connect(host, port, config.clone(), secret_key.clone()).await {
|
||||||
|
Err(err) => {
|
||||||
|
let mut e = err;
|
||||||
|
loop {
|
||||||
|
for (i, duration) in timer_iterator.enumerate() {
|
||||||
|
sleep(duration).await;
|
||||||
|
trace!("Reconnection attempt #{}...", i + 1);
|
||||||
|
e = match TcpForwardSession::connect(
|
||||||
|
&host,
|
||||||
|
port,
|
||||||
|
config.clone(),
|
||||||
|
secret_key.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Err(e) => e,
|
||||||
|
session => {
|
||||||
|
debug!(reconnection_attempts = i + 1, "Succeeded on reconnecting.");
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
warn!("Backing off from reconnection attempt.");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
session => {
|
||||||
|
debug!("Reconnected on first attempt.");
|
||||||
|
session
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(&mut self) -> Result<()> {
|
||||||
|
self.session
|
||||||
|
.disconnect(Disconnect::ByApplication, "", "English")
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Our SSH client implementing the `Handler` callbacks for the functions we need to use.
|
||||||
|
struct Client {
|
||||||
|
router_rx: Receiver<Option<Router>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl client::Handler for Client {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
/// Always accept the SSH server's pubkey. Don't do this in production.
|
||||||
|
async fn check_server_key(
|
||||||
|
&mut self,
|
||||||
|
_server_public_key: &key::PublicKey,
|
||||||
|
) -> Result<bool, Self::Error> {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a new forwarded connection, represented by a specific `Channel`. We will create a clone of our router,
|
||||||
|
/// and forward any messages from this channel with its streaming API.
|
||||||
|
///
|
||||||
|
/// To make Axum behave with streaming, we must turn it into a Tower service first.
|
||||||
|
/// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our
|
||||||
|
/// AsyncRead/Write stream into a `hyper` IO object.
|
||||||
|
async fn server_channel_open_forwarded_tcpip(
|
||||||
|
&mut self,
|
||||||
|
channel: Channel<Msg>,
|
||||||
|
connected_address: &str,
|
||||||
|
connected_port: u32,
|
||||||
|
originator_address: &str,
|
||||||
|
originator_port: u32,
|
||||||
|
session: &mut Session,
|
||||||
|
) -> Result<(), Self::Error> {
|
||||||
|
let span = debug_span!("server_channel_open_forwarded_tcpip",);
|
||||||
|
let _enter = span.enter();
|
||||||
|
debug!(
|
||||||
|
sshid = %String::from_utf8_lossy(session.remote_sshid()),
|
||||||
|
connected_address = connected_address,
|
||||||
|
connected_port = connected_port,
|
||||||
|
originator_address = originator_address,
|
||||||
|
originator_port = originator_port,
|
||||||
|
"New connection!"
|
||||||
|
);
|
||||||
|
let mut router: RouterIntoService<Incoming> = match self.router_rx.borrow().as_ref() {
|
||||||
|
None => {
|
||||||
|
return Err(anyhow!("No router has been set yet. Closing channel..."));
|
||||||
|
}
|
||||||
|
Some(ref router) => <Router as Clone>::clone(&router).into_service::<Incoming>(),
|
||||||
|
};
|
||||||
|
poll_fn(|cx| router.poll_ready(cx)).await.unwrap();
|
||||||
|
let service = service_fn(move |req| {
|
||||||
|
// Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut.
|
||||||
|
// This should be fine performance-wise, since RouterIntoService is a thin wrapper around Router,
|
||||||
|
// which itself is a thin wrapper around Arc<RouterInner<_>>.
|
||||||
|
let mut router = router.clone();
|
||||||
|
async move { router.call(req).await }
|
||||||
|
});
|
||||||
|
let socket = TokioIo::new(channel.into_stream());
|
||||||
|
let server = Builder::new(TokioExecutor::new());
|
||||||
|
// I'm not really sure why tokio::spawn is necessary here, but it doesn't work otherwise.
|
||||||
|
// My guess is that we block on TcpForwardSession.start_forwarding_with().
|
||||||
|
// We use `serve_connection_with_upgrades` to allow upgrading to WebSocket - which will still run through
|
||||||
|
// our SSH tunnel for every message!
|
||||||
|
tokio::spawn(async move {
|
||||||
|
server
|
||||||
|
.serve_connection_with_upgrades(socket, service)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue