Clean up and first try at Serveo support

This commit is contained in:
Bad Manners 2024-09-06 17:18:03 -03:00
parent 9e046c8821
commit 00b362621f
Signed by: badmanners
GPG key ID: 8C88292CCB075609
3 changed files with 173 additions and 53 deletions

68
Cargo.lock generated
View file

@ -58,6 +58,15 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.15" version = "0.6.15"
@ -958,6 +967,15 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.7.3" version = "0.7.3"
@ -1384,6 +1402,50 @@ dependencies = [
"bitflags", "bitflags",
] ]
[[package]]
name = "regex"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
name = "regex-automata"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.4",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]] [[package]]
name = "rfc6979" name = "rfc6979"
version = "0.4.0" version = "0.4.0"
@ -1467,6 +1529,7 @@ dependencies = [
"axum", "axum",
"clap", "clap",
"futures", "futures",
"http",
"hyper", "hyper",
"hyper-util", "hyper-util",
"russh", "russh",
@ -1474,6 +1537,7 @@ dependencies = [
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower 0.5.0", "tower 0.5.0",
"tower-service",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]
@ -2021,10 +2085,14 @@ version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [ dependencies = [
"matchers",
"nu-ansi-term", "nu-ansi-term",
"once_cell",
"regex",
"sharded-slab", "sharded-slab",
"smallvec", "smallvec",
"thread_local", "thread_local",
"tracing",
"tracing-core", "tracing-core",
"tracing-log", "tracing-log",
] ]

View file

@ -14,6 +14,7 @@ async-trait = "0.1"
axum = "0.7.5" axum = "0.7.5"
clap = { version = "4.5.17", features = ["derive"] } clap = { version = "4.5.17", features = ["derive"] }
futures = "0.3.30" futures = "0.3.30"
http = "1.1.0"
hyper = { version = "1", features = ["full"] } hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] }
russh = "0.45" russh = "0.45"
@ -21,5 +22,6 @@ tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1.15", features = ["net", "sync"] } tokio-stream = { version = "0.1.15", features = ["net", "sync"] }
tokio-util = "0.7.11" tokio-util = "0.7.11"
tower = "0.5.0" tower = "0.5.0"
tower-service = "0.3.3"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3.18" } tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter", "std"] }

View file

@ -7,29 +7,23 @@ use std::{
use anyhow::{anyhow, Context, Result}; use anyhow::{anyhow, Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{extract::State, routing::get, Router};
extract::State,
routing::{get, RouterIntoService},
Router,
};
use clap::Parser; use clap::Parser;
use futures::future::poll_fn; use hyper::service::service_fn;
use hyper::{body::Incoming, service::service_fn};
use hyper_util::{ use hyper_util::{
rt::{TokioExecutor, TokioIo}, rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder, server::conn::auto::Builder,
}; };
use russh::{ use russh::{
client::{self, Config, Handle, Msg, Session}, client::{self, Config, Handle, KeyboardInteractiveAuthResponse, Msg, Session},
keys::{ keys::{decode_secret_key, key},
decode_secret_key,
key::{self, KeyPair},
},
Channel, ChannelMsg, Disconnect, Channel, ChannelMsg, Disconnect,
}; };
use tokio::{fs, time::sleep}; use tokio::io::AsyncWriteExt;
use tokio::{fs, io::stdout, time::sleep};
use tower::Service; use tower::Service;
use tracing::{debug, debug_span, error, info, trace, warn}; use tracing::{debug, debug_span, error, info, trace, warn};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
/* Entrypoint */ /* Entrypoint */
@ -47,7 +41,7 @@ struct ClapArgs {
/// Identity file containing private key /// Identity file containing private key
#[arg(short, long)] #[arg(short, long)]
identity_file: PathBuf, identity_file: Option<PathBuf>,
/// Remote hostname to bind to /// Remote hostname to bind to
#[arg(short, long, default_value_t = String::from("localhost"))] #[arg(short, long, default_value_t = String::from("localhost"))]
@ -60,21 +54,32 @@ struct ClapArgs {
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
let subscriber = tracing_subscriber::FmtSubscriber::new(); tracing_subscriber::registry()
tracing::subscriber::set_global_default(subscriber)?; .with(fmt::layer())
.with(EnvFilter::from_default_env())
.init();
trace!("Tracing is up!"); trace!("Tracing is up!");
let args = ClapArgs::parse(); let args = ClapArgs::parse();
let secret_key = fs::read_to_string(args.identity_file) let secret_key = match args.identity_file {
.await None => None,
.with_context(|| "Failed to open secret key")?; Some(file) => {
let secret_key = decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?; let secret_key = fs::read_to_string(file)
.await
.with_context(|| "Failed to open secret key")?;
Some(decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?)
}
};
let config = Arc::new(client::Config { let config = Arc::new(client::Config {
..Default::default() ..Default::default()
}); });
let mut session = let mut session = TcpForwardSession::connect(
TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key)) &args.host,
.await args.port,
.with_context(|| "Initial connection failed")?; config,
secret_key.map(|key| Arc::new(key)),
)
.await
.with_context(|| "Initial connection failed")?;
loop { loop {
match session match session
.start_forwarding(&args.remote_host, args.remote_port) .start_forwarding(&args.remote_host, args.remote_port)
@ -139,7 +144,7 @@ async fn hello(State(state): State<AppState>) -> String {
/// User-implemented session type as a helper for interfacing with the SSH protocol. /// User-implemented session type as a helper for interfacing with the SSH protocol.
struct TcpForwardSession { struct TcpForwardSession {
config: Arc<Config>, config: Arc<Config>,
secret_key: Arc<KeyPair>, secret_key: Option<Arc<key::KeyPair>>,
session: Handle<Client>, session: Handle<Client>,
} }
@ -150,7 +155,7 @@ impl TcpForwardSession {
host: &str, host: &str,
port: u16, port: u16,
config: Arc<Config>, config: Arc<Config>,
secret_key: Arc<KeyPair>, secret_key: Option<Arc<key::KeyPair>>,
) -> Result<Self> { ) -> Result<Self> {
let span = debug_span!("TcpForwardSession.connect"); let span = debug_span!("TcpForwardSession.connect");
let _enter = span; let _enter = span;
@ -159,12 +164,57 @@ impl TcpForwardSession {
let mut session = client::connect(Arc::clone(&config), (host, port), client) let mut session = client::connect(Arc::clone(&config), (host, port), client)
.await .await
.with_context(|| "Unable to connect to remote host.")?; .with_context(|| "Unable to connect to remote host.")?;
if !session let authentication_result = match secret_key.as_ref() {
.authenticate_publickey("root", Arc::clone(&secret_key)) None => None,
.await Some(secret_key) => {
.with_context(|| "Authentication error.")? if session
{ .authenticate_publickey("root", Arc::clone(&secret_key))
return Err(anyhow!("Authentication failed.")); .await
.with_context(|| "Error while authenticating with public key.")?
{
debug!("Public key authentication succeeded!");
Some(Ok(()))
} else {
Some(Err(anyhow!("Public key authentication failed.")))
}
}
};
if matches!(authentication_result, None | Some(Err(_))) {
if authentication_result.is_some() {
debug!(
"Public key authentication failed; trying keyboard interactive authentication..."
);
}
match session
.authenticate_keyboard_interactive_start("russh-axum-tcpip-forward", None)
.await
.with_context(|| "Error while authenticating with keyboard interactive session.")?
{
KeyboardInteractiveAuthResponse::Success => {
debug!("Keyboard interactive authentication succeeded!");
}
KeyboardInteractiveAuthResponse::Failure => match authentication_result {
None => return Err(anyhow!("Keyboard interactive authentication failed.")),
Some(Err(result)) => {
debug!("Keyboard interactive authentication failed; propagating public key authentication error...");
return Err(result);
}
_ => unreachable!(),
},
response => match authentication_result {
None => {
return Err(anyhow!(
"Unhandled keyboard interactive authentication event {:?}",
response
))
}
Some(Err(result)) => {
debug!("Keyboard interactive authentication failed; propagating public key authentication error...");
return Err(result);
}
_ => unreachable!(),
},
}
} }
Ok(Self { Ok(Self {
config, config,
@ -175,7 +225,7 @@ impl TcpForwardSession {
/// Sends a port forwarding request and opens a session to receive miscellaneous data. /// Sends a port forwarding request and opens a session to receive miscellaneous data.
/// The function yields when the session is broken (for example, if the connection was lost). /// The function yields when the session is broken (for example, if the connection was lost).
async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<()> { async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<u32> {
let span = debug_span!("TcpForwardSession.start"); let span = debug_span!("TcpForwardSession.start");
let _enter = span; let _enter = span;
self.session self.session
@ -187,19 +237,29 @@ impl TcpForwardSession {
.channel_open_session() .channel_open_session()
.await .await
.with_context(|| "channel_open_session error.")?; .with_context(|| "channel_open_session error.")?;
debug!("Created open session channel.");
let mut stdout = stdout();
let mut code = 0;
loop { loop {
let Some(msg) = channel.wait().await else { let Some(msg) = channel.wait().await else {
return Err(anyhow!("Unexpected end of channel.")); return Err(anyhow!("Unexpected end of channel."));
}; };
trace!("Got a message!");
match msg { match msg {
ChannelMsg::Data { data } => { ChannelMsg::Data { ref data } => {
print!("{}", String::from_utf8_lossy(&data)); stdout.write_all(data).await?;
stdout.flush().await?;
} }
ChannelMsg::Close => break, ChannelMsg::Close => break,
ChannelMsg::ExitStatus { exit_status } => {
debug!("Exited with code {exit_status}");
channel.eof().await?;
code = exit_status;
}
msg => return Err(anyhow!("Unknown message type {:?}.", msg)), msg => return Err(anyhow!("Unknown message type {:?}.", msg)),
} }
} }
Ok(()) Ok(code)
} }
/// Attempts to reconnect to the SSH server. /// Attempts to reconnect to the SSH server.
@ -276,6 +336,8 @@ impl client::Handler for Client {
/// To make Axum behave with streaming, we must turn it into a Tower service first. /// To make Axum behave with streaming, we must turn it into a Tower service first.
/// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our /// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our
/// AsyncRead/Write stream into a `hyper` IO object. /// AsyncRead/Write stream into a `hyper` IO object.
///
/// See also: [axum/examples/serve-with-hyper](https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs)
async fn server_channel_open_forwarded_tcpip( async fn server_channel_open_forwarded_tcpip(
&mut self, &mut self,
channel: Channel<Msg>, channel: Channel<Msg>,
@ -296,25 +358,13 @@ impl client::Handler for Client {
"New connection!" "New connection!"
); );
// Get our router from the lazy static. // Get our router from the lazy static.
let mut router: RouterIntoService<Incoming> = let router = &*ROUTER;
<Router as Clone>::clone(&*ROUTER).into_service::<Incoming>(); let service = service_fn(move |req| router.clone().call(req));
poll_fn(|cx| router.poll_ready(cx)).await.unwrap();
let service = service_fn(move |req| {
// Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut.
// This should be fine performance-wise, since RouterIntoService is a thin wrapper around Router,
// which itself is a thin wrapper around Arc<RouterInner<_>>.
let mut router = router.clone();
async move { router.call(req).await }
});
let socket = TokioIo::new(channel.into_stream());
let server = Builder::new(TokioExecutor::new()); let server = Builder::new(TokioExecutor::new());
// I'm not really sure why tokio::spawn is necessary here, but it doesn't work otherwise. // tokio::spawn is required to let us reply over the data channel.
// My guess is that we block on TcpForwardSession.start_forwarding_with().
// We use `serve_connection_with_upgrades` to allow upgrading to WebSocket - which will still run through
// our SSH tunnel for every message!
tokio::spawn(async move { tokio::spawn(async move {
server server
.serve_connection_with_upgrades(socket, service) .serve_connection_with_upgrades(TokioIo::new(channel.into_stream()), service)
.await .await
.unwrap(); .unwrap();
}); });