From 00b362621f306803ab1acda1a0fbc7a66659bb68 Mon Sep 17 00:00:00 2001
From: Bad Manners <me@badmanners.xyz>
Date: Fri, 6 Sep 2024 17:18:03 -0300
Subject: [PATCH] Clean up and first try at Serveo support

---
 Cargo.lock  |  68 +++++++++++++++++++++++
 Cargo.toml  |   4 +-
 src/main.rs | 154 ++++++++++++++++++++++++++++++++++------------------
 3 files changed, 173 insertions(+), 53 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d8b36ae..e1446de 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -58,6 +58,15 @@ dependencies = [
  "subtle",
 ]
 
+[[package]]
+name = "aho-corasick"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "anstream"
 version = "0.6.15"
@@ -958,6 +967,15 @@ version = "0.4.22"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
 
+[[package]]
+name = "matchers"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
+dependencies = [
+ "regex-automata 0.1.10",
+]
+
 [[package]]
 name = "matchit"
 version = "0.7.3"
@@ -1384,6 +1402,50 @@ dependencies = [
  "bitflags",
 ]
 
+[[package]]
+name = "regex"
+version = "1.10.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata 0.4.7",
+ "regex-syntax 0.8.4",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
+dependencies = [
+ "regex-syntax 0.6.29",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax 0.8.4",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.6.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
+
 [[package]]
 name = "rfc6979"
 version = "0.4.0"
@@ -1467,6 +1529,7 @@ dependencies = [
  "axum",
  "clap",
  "futures",
+ "http",
  "hyper",
  "hyper-util",
  "russh",
@@ -1474,6 +1537,7 @@ dependencies = [
  "tokio-stream",
  "tokio-util",
  "tower 0.5.0",
+ "tower-service",
  "tracing",
  "tracing-subscriber",
 ]
@@ -2021,10 +2085,14 @@ version = "0.3.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
 dependencies = [
+ "matchers",
  "nu-ansi-term",
+ "once_cell",
+ "regex",
  "sharded-slab",
  "smallvec",
  "thread_local",
+ "tracing",
  "tracing-core",
  "tracing-log",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 404be24..dfd06f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,6 +14,7 @@ async-trait = "0.1"
 axum = "0.7.5"
 clap = { version = "4.5.17", features = ["derive"] }
 futures = "0.3.30"
+http = "1.1.0"
 hyper = { version = "1", features = ["full"] }
 hyper-util = { version = "0.1", features = ["full"] }
 russh = "0.45"
@@ -21,5 +22,6 @@ tokio = { version = "1", features = ["full"] }
 tokio-stream = { version = "0.1.15", features = ["net", "sync"] }
 tokio-util = "0.7.11"
 tower = "0.5.0"
+tower-service = "0.3.3"
 tracing = "0.1"
-tracing-subscriber = { version = "0.3.18" }
+tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter", "std"] }
diff --git a/src/main.rs b/src/main.rs
index c58ac84..56a5553 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,29 +7,23 @@ use std::{
 
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
-use axum::{
-    extract::State,
-    routing::{get, RouterIntoService},
-    Router,
-};
+use axum::{extract::State, routing::get, Router};
 use clap::Parser;
-use futures::future::poll_fn;
-use hyper::{body::Incoming, service::service_fn};
+use hyper::service::service_fn;
 use hyper_util::{
     rt::{TokioExecutor, TokioIo},
     server::conn::auto::Builder,
 };
 use russh::{
-    client::{self, Config, Handle, Msg, Session},
-    keys::{
-        decode_secret_key,
-        key::{self, KeyPair},
-    },
+    client::{self, Config, Handle, KeyboardInteractiveAuthResponse, Msg, Session},
+    keys::{decode_secret_key, key},
     Channel, ChannelMsg, Disconnect,
 };
-use tokio::{fs, time::sleep};
+use tokio::io::AsyncWriteExt;
+use tokio::{fs, io::stdout, time::sleep};
 use tower::Service;
 use tracing::{debug, debug_span, error, info, trace, warn};
+use tracing_subscriber::{fmt, prelude::*, EnvFilter};
 
 /* Entrypoint */
 
@@ -47,7 +41,7 @@ struct ClapArgs {
 
     /// Identity file containing private key
     #[arg(short, long)]
-    identity_file: PathBuf,
+    identity_file: Option<PathBuf>,
 
     /// Remote hostname to bind to
     #[arg(short, long, default_value_t = String::from("localhost"))]
@@ -60,21 +54,32 @@ struct ClapArgs {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let subscriber = tracing_subscriber::FmtSubscriber::new();
-    tracing::subscriber::set_global_default(subscriber)?;
+    tracing_subscriber::registry()
+        .with(fmt::layer())
+        .with(EnvFilter::from_default_env())
+        .init();
     trace!("Tracing is up!");
     let args = ClapArgs::parse();
-    let secret_key = fs::read_to_string(args.identity_file)
-        .await
-        .with_context(|| "Failed to open secret key")?;
-    let secret_key = decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?;
+    let secret_key = match args.identity_file {
+        None => None,
+        Some(file) => {
+            let secret_key = fs::read_to_string(file)
+                .await
+                .with_context(|| "Failed to open secret key")?;
+            Some(decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?)
+        }
+    };
     let config = Arc::new(client::Config {
         ..Default::default()
     });
-    let mut session =
-        TcpForwardSession::connect(&args.host, args.port, config, Arc::new(secret_key))
-            .await
-            .with_context(|| "Initial connection failed")?;
+    let mut session = TcpForwardSession::connect(
+        &args.host,
+        args.port,
+        config,
+        secret_key.map(|key| Arc::new(key)),
+    )
+    .await
+    .with_context(|| "Initial connection failed")?;
     loop {
         match session
             .start_forwarding(&args.remote_host, args.remote_port)
@@ -139,7 +144,7 @@ async fn hello(State(state): State<AppState>) -> String {
 /// User-implemented session type as a helper for interfacing with the SSH protocol.
 struct TcpForwardSession {
     config: Arc<Config>,
-    secret_key: Arc<KeyPair>,
+    secret_key: Option<Arc<key::KeyPair>>,
     session: Handle<Client>,
 }
 
@@ -150,7 +155,7 @@ impl TcpForwardSession {
         host: &str,
         port: u16,
         config: Arc<Config>,
-        secret_key: Arc<KeyPair>,
+        secret_key: Option<Arc<key::KeyPair>>,
     ) -> Result<Self> {
         let span = debug_span!("TcpForwardSession.connect");
         let _enter = span;
@@ -159,12 +164,57 @@ impl TcpForwardSession {
         let mut session = client::connect(Arc::clone(&config), (host, port), client)
             .await
             .with_context(|| "Unable to connect to remote host.")?;
-        if !session
-            .authenticate_publickey("root", Arc::clone(&secret_key))
-            .await
-            .with_context(|| "Authentication error.")?
-        {
-            return Err(anyhow!("Authentication failed."));
+        let authentication_result = match secret_key.as_ref() {
+            None => None,
+            Some(secret_key) => {
+                if session
+                    .authenticate_publickey("root", Arc::clone(&secret_key))
+                    .await
+                    .with_context(|| "Error while authenticating with public key.")?
+                {
+                    debug!("Public key authentication succeeded!");
+                    Some(Ok(()))
+                } else {
+                    Some(Err(anyhow!("Public key authentication failed.")))
+                }
+            }
+        };
+        if matches!(authentication_result, None | Some(Err(_))) {
+            if authentication_result.is_some() {
+                debug!(
+                    "Public key authentication failed; trying keyboard interactive authentication..."
+                );
+            }
+            match session
+                .authenticate_keyboard_interactive_start("russh-axum-tcpip-forward", None)
+                .await
+                .with_context(|| "Error while authenticating with keyboard interactive session.")?
+            {
+                KeyboardInteractiveAuthResponse::Success => {
+                    debug!("Keyboard interactive authentication succeeded!");
+                }
+                KeyboardInteractiveAuthResponse::Failure => match authentication_result {
+                    None => return Err(anyhow!("Keyboard interactive authentication failed.")),
+                    Some(Err(result)) => {
+                        debug!("Keyboard interactive authentication failed; propagating public key authentication error...");
+                        return Err(result);
+                    }
+                    _ => unreachable!(),
+                },
+                response => match authentication_result {
+                    None => {
+                        return Err(anyhow!(
+                            "Unhandled keyboard interactive authentication event {:?}",
+                            response
+                        ))
+                    }
+                    Some(Err(result)) => {
+                        debug!("Keyboard interactive authentication failed; propagating public key authentication error...");
+                        return Err(result);
+                    }
+                    _ => unreachable!(),
+                },
+            }
         }
         Ok(Self {
             config,
@@ -175,7 +225,7 @@ impl TcpForwardSession {
 
     /// Sends a port forwarding request and opens a session to receive miscellaneous data.
     /// The function yields when the session is broken (for example, if the connection was lost).
-    async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<()> {
+    async fn start_forwarding(&mut self, remote_host: &str, remote_port: u16) -> Result<u32> {
         let span = debug_span!("TcpForwardSession.start");
         let _enter = span;
         self.session
@@ -187,19 +237,29 @@ impl TcpForwardSession {
             .channel_open_session()
             .await
             .with_context(|| "channel_open_session error.")?;
+        debug!("Created open session channel.");
+        let mut stdout = stdout();
+        let mut code = 0;
         loop {
             let Some(msg) = channel.wait().await else {
                 return Err(anyhow!("Unexpected end of channel."));
             };
+            trace!("Got a message!");
             match msg {
-                ChannelMsg::Data { data } => {
-                    print!("{}", String::from_utf8_lossy(&data));
+                ChannelMsg::Data { ref data } => {
+                    stdout.write_all(data).await?;
+                    stdout.flush().await?;
                 }
                 ChannelMsg::Close => break,
+                ChannelMsg::ExitStatus { exit_status } => {
+                    debug!("Exited with code {exit_status}");
+                    channel.eof().await?;
+                    code = exit_status;
+                }
                 msg => return Err(anyhow!("Unknown message type {:?}.", msg)),
             }
         }
-        Ok(())
+        Ok(code)
     }
 
     /// Attempts to reconnect to the SSH server.
@@ -276,6 +336,8 @@ impl client::Handler for Client {
     /// To make Axum behave with streaming, we must turn it into a Tower service first.
     /// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our
     /// AsyncRead/Write stream into a `hyper` IO object.
+    ///
+    /// See also: [axum/examples/serve-with-hyper](https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs)
     async fn server_channel_open_forwarded_tcpip(
         &mut self,
         channel: Channel<Msg>,
@@ -296,25 +358,13 @@ impl client::Handler for Client {
             "New connection!"
         );
         // Get our router from the lazy static.
-        let mut router: RouterIntoService<Incoming> =
-            <Router as Clone>::clone(&*ROUTER).into_service::<Incoming>();
-        poll_fn(|cx| router.poll_ready(cx)).await.unwrap();
-        let service = service_fn(move |req| {
-            // Cloning our service for each call is required, given that service_fn expects Fn instead of FnMut.
-            // This should be fine performance-wise, since RouterIntoService is a thin wrapper around Router,
-            // which itself is a thin wrapper around Arc<RouterInner<_>>.
-            let mut router = router.clone();
-            async move { router.call(req).await }
-        });
-        let socket = TokioIo::new(channel.into_stream());
+        let router = &*ROUTER;
+        let service = service_fn(move |req| router.clone().call(req));
         let server = Builder::new(TokioExecutor::new());
-        // I'm not really sure why tokio::spawn is necessary here, but it doesn't work otherwise.
-        // My guess is that we block on TcpForwardSession.start_forwarding_with().
-        // We use `serve_connection_with_upgrades` to allow upgrading to WebSocket - which will still run through
-        // our SSH tunnel for every message!
+        // tokio::spawn is required to let us reply over the data channel.
         tokio::spawn(async move {
             server
-                .serve_connection_with_upgrades(socket, service)
+                .serve_connection_with_upgrades(TokioIo::new(channel.into_stream()), service)
                 .await
                 .unwrap();
         });