diff --git a/Cargo.lock b/Cargo.lock index d8b36ae..e1446de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "anstream" version = "0.6.15" @@ -958,6 +967,15 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchit" version = "0.7.3" @@ -1384,6 +1402,50 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + [[package]] name = "rfc6979" version = "0.4.0" @@ -1467,6 +1529,7 @@ dependencies = [ "axum", "clap", "futures", + "http", "hyper", "hyper-util", "russh", @@ -1474,6 +1537,7 @@ dependencies = [ "tokio-stream", "tokio-util", "tower 0.5.0", + "tower-service", "tracing", "tracing-subscriber", ] @@ -2021,10 +2085,14 @@ version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] diff --git a/Cargo.toml b/Cargo.toml index 404be24..dfd06f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ async-trait = "0.1" axum = "0.7.5" clap = { version = "4.5.17", features = ["derive"] } futures = "0.3.30" +http = "1.1.0" hyper = { version = "1", features = ["full"] } hyper-util = { version = "0.1", features = ["full"] } russh = "0.45" @@ -21,5 +22,6 @@ tokio = { version = "1", features = ["full"] } tokio-stream = { version = "0.1.15", features = ["net", "sync"] } tokio-util = "0.7.11" tower = "0.5.0" +tower-service = "0.3.3" tracing = "0.1" -tracing-subscriber = { version = "0.3.18" } +tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter", "std"] } diff --git a/src/main.rs b/src/main.rs index c58ac84..56a5553 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,29 +7,23 @@ use std::{ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; -use axum::{ - extract::State, - routing::{get, RouterIntoService}, - Router, -}; +use axum::{extract::State, routing::get, Router}; use clap::Parser; -use futures::future::poll_fn; -use hyper::{body::Incoming, service::service_fn}; +use hyper::service::service_fn; use hyper_util::{ rt::{TokioExecutor, TokioIo}, server::conn::auto::Builder, }; use russh::{ - client::{self, Config, Handle, Msg, Session}, - keys::{ - decode_secret_key, - key::{self, KeyPair}, - }, + client::{self, Config, Handle, KeyboardInteractiveAuthResponse, Msg, Session}, + keys::{decode_secret_key, key}, Channel, ChannelMsg, Disconnect, }; -use tokio::{fs, time::sleep}; +use tokio::io::AsyncWriteExt; +use tokio::{fs, io::stdout, time::sleep}; use tower::Service; use tracing::{debug, debug_span, error, info, trace, warn}; +use tracing_subscriber::{fmt, prelude::*, EnvFilter}; /* Entrypoint */ @@ -47,7 +41,7 @@ struct ClapArgs { /// Identity file containing private key #[arg(short, long)] - identity_file: PathBuf, + identity_file: Option<PathBuf>, /// Remote hostname to bind to #[arg(short, long, default_value_t = String::from("localhost"))] @@ -60,21 +54,32 @@ struct ClapArgs { #[tokio::main] async fn main() -> Result<()> { - let subscriber = tracing_subscriber::FmtSubscriber::new(); - tracing::subscriber::set_global_default(subscriber)?; + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); trace!("Tracing is up!"); let args = ClapArgs::parse(); - let secret_key = fs::read_to_string(args.identity_file) - .await - .with_context(|| "Failed to open secret key")?; - let secret_key = decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?; + let secret_key = match args.identity_file { + None => None, + Some(file) => { + let secret_key = fs::read_to_string(file) + .await + .with_context(|| "Failed to open secret key")?; + Some(decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?) + } + }; let config = Arc::new(client::Config { ..Default::default() }); - let mut session = - TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key)) - .await - .with_context(|| "Initial connection failed")?; + let mut session = TcpForwardSession::connect( + &args.host, + args.port, + config, + secret_key.map(|key| Arc::new(key)), + ) + .await + .with_context(|| "Initial connection failed")?; loop { match session .start_forwarding(&args.remote_host, args.remote_port) @@ -139,7 +144,7 @@ async fn hello(State(state): State<AppState>) -> String { /// User-implemented session type as a helper for interfacing with the SSH protocol. struct TcpForwardSession { config: Arc<Config>, - secret_key: Arc<KeyPair>, + secret_key: Option<Arc<key::KeyPair>>, session: Handle<Client>, } @@ -150,7 +155,7 @@ impl TcpForwardSession { host: &str, port: u16, config: Arc<Config>, - secret_key: Arc<KeyPair>, + secret_key: Option<Arc<key::KeyPair>>, ) -> Result<Self> { let span = debug_span!("TcpForwardSession.connect"); let _enter = span; @@ -159,12 +164,57 @@ impl TcpForwardSession { let mut session = client::connect(Arc::clone(&config), (host, port), client) .await .with_context(|| "Unable to connect to remote host.")?; - if !session - .authenticate_publickey("root", Arc::clone(&secret_key)) - .await - .with_context(|| "Authentication error.")? - { - return Err(anyhow!("Authentication failed.")); + let authentication_result = match secret_key.as_ref() { + None => None, + Some(secret_key) => { + if session + .authenticate_publickey("root", Arc::clone(&secret_key)) + .await + .with_context(|| "Error while authenticating with public key.")? + { + debug!("Public key authentication succeeded!"); + Some(Ok(())) + } else { + Some(Err(anyhow!("Public key authentication failed."))) + } + } + }; + if matches!(authentication_result, None | Some(Err(_))) { + if authentication_result.is_some() { + debug!( + "Public key authentication failed; trying keyboard interactive authentication..." + ); + } + match session + .authenticate_keyboard_interactive_start("russh-axum-tcpip-forward", None) + .await + .with_context(|| "Error while authenticating with keyboard interactive session.")? + { + KeyboardInteractiveAuthResponse::Success => { + debug!("Keyboard interactive authentication succeeded!"); + } + KeyboardInteractiveAuthResponse::Failure => match authentication_result { + None => return Err(anyhow!("Keyboard interactive authentication failed.")), + Some(Err(result)) => { + debug!("Keyboard interactive authentication failed; propagating public key authentication error..."); + return Err(result); + } + _ => unreachable!(), + }, + response => match authentication_result { + None => { + return Err(anyhow!( + "Unhandled keyboard interactive authentication event {:?}", + response + )) + } + Some(Err(result)) => { + debug!("Keyboard interactive authentication failed; propagating public key authentication error..."); + return Err(result); + } + _ => unreachable!(), + }, + } } Ok(Self { config, @@ -175,7 +225,7 @@ impl TcpForwardSession { /// Sends a port forwarding request and opens a session to receive miscellaneous data. /// The function yields when the session is broken (for example, if the connection was lost). - async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<()> { + async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<u32> { let span = debug_span!("TcpForwardSession.start"); let _enter = span; self.session @@ -187,19 +237,29 @@ impl TcpForwardSession { .channel_open_session() .await .with_context(|| "channel_open_session error.")?; + debug!("Created open session channel."); + let mut stdout = stdout(); + let mut code = 0; loop { let Some(msg) = channel.wait().await else { return Err(anyhow!("Unexpected end of channel.")); }; + trace!("Got a message!"); match msg { - ChannelMsg::Data { data } => { - print!("{}", String::from_utf8_lossy(&data)); + ChannelMsg::Data { ref data } => { + stdout.write_all(data).await?; + stdout.flush().await?; } ChannelMsg::Close => break, + ChannelMsg::ExitStatus { exit_status } => { + debug!("Exited with code {exit_status}"); + channel.eof().await?; + code = exit_status; + } msg => return Err(anyhow!("Unknown message type {:?}.", msg)), } } - Ok(()) + Ok(code) } /// Attempts to reconnect to the SSH server. @@ -276,6 +336,8 @@ impl client::Handler for Client { /// To make Axum behave with streaming, we must turn it into a Tower service first. /// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our /// AsyncRead/Write stream into a `hyper` IO object. + /// + /// See also: [axum/examples/serve-with-hyper](https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs) async fn server_channel_open_forwarded_tcpip( &mut self, channel: Channel<Msg>, @@ -296,25 +358,13 @@ impl client::Handler for Client { "New connection!" ); // Get our router from the lazy static. - let mut router: RouterIntoService<Incoming> = - <Router as Clone>::clone(&*ROUTER).into_service::<Incoming>(); - poll_fn(|cx| router.poll_ready(cx)).await.unwrap(); - let service = service_fn(move |req| { - // Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut. - // This should be fine performance-wise, since RouterIntoService is a thin wrapper around Router, - // which itself is a thin wrapper around Arc<RouterInner<_>>. - let mut router = router.clone(); - async move { router.call(req).await } - }); - let socket = TokioIo::new(channel.into_stream()); + let router = &*ROUTER; + let service = service_fn(move |req| router.clone().call(req)); let server = Builder::new(TokioExecutor::new()); - // I'm not really sure why tokio::spawn is necessary here, but it doesn't work otherwise. - // My guess is that we block on TcpForwardSession.start_forwarding_with(). - // We use `serve_connection_with_upgrades` to allow upgrading to WebSocket - which will still run through - // our SSH tunnel for every message! + // tokio::spawn is required to let us reply over the data channel. tokio::spawn(async move { server - .serve_connection_with_upgrades(socket, service) + .serve_connection_with_upgrades(TokioIo::new(channel.into_stream()), service) .await .unwrap(); });