Only use private key auth

This commit is contained in:
Bad Manners 2024-09-08 12:01:14 -03:00
parent 9dc4254647
commit 68a238b31d

View file

@ -8,15 +8,18 @@ use std::{
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use axum::{extract::State, routing::get, Router};
use clap::{Args, Parser};
use clap::Parser;
use hyper::service::service_fn;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use russh::{
client::{self, Config, Handle, KeyboardInteractiveAuthResponse, Msg, Session},
keys::{decode_secret_key, key},
client::{self, Config, Handle, Msg, Session},
keys::{
decode_secret_key,
key::{self, KeyPair},
},
Channel, ChannelId, ChannelMsg, Disconnect,
};
use tokio::io::AsyncWriteExt;
@ -38,16 +41,17 @@ struct ClapArgs {
/// SSH hostname
host: String,
/// Identity file containing private key
#[arg(short, long, default_value_t = String::from(""))]
login_name: String,
/// SSH port
#[arg(short, long, default_value_t = 22)]
port: u16,
#[command(flatten)]
auth: Option<Authentication>,
/// Identity file containing private key
#[arg(short, long, default_value_t = String::from(""))]
login_name: String,
/// Identity file containing private key.
#[arg(short, long, value_name = "FILE")]
identity_file: PathBuf,
/// Remote hostname to bind to
#[arg(short = 'R', long, default_value_t = String::from(""))]
@ -62,18 +66,6 @@ struct ClapArgs {
request_pty: Option<String>,
}
#[derive(Args, Debug)]
#[group(required = false, multiple = false)]
struct Authentication {
/// Identity file containing private key.
#[arg(short, long, value_name = "FILE")]
identity_file: Option<PathBuf>,
/// Request keyboard-interactive based SSH authentication.
#[arg(long)]
keyboard_interactive: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
@ -82,35 +74,18 @@ async fn main() -> Result<()> {
.init();
trace!("Tracing is up!");
let args = ClapArgs::parse();
let session_auth = match args.auth {
None => None,
Some(auth) => {
if auth.keyboard_interactive {
Some(SessionAuth::KeyboardInteractive)
} else if let Some(file) = auth.identity_file {
let secret_key = fs::read_to_string(file)
.await
.with_context(|| "Failed to open secret key")?;
Some(SessionAuth::SecretKey(Arc::new(
decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?,
)))
} else {
unreachable!();
}
}
};
let secret_key = fs::read_to_string(args.identity_file)
.await
.with_context(|| "Failed to open secret key")?;
let secret_key =
Arc::new(decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?);
let config = Arc::new(client::Config {
..Default::default()
});
let mut session = TcpForwardSession::connect(
&args.host,
args.port,
&args.login_name,
config,
&session_auth,
)
.await
.with_context(|| "Initial connection failed")?;
let mut session =
TcpForwardSession::connect(&args.host, args.port, &args.login_name, config, secret_key)
.await
.with_context(|| "Initial connection failed")?;
loop {
match session
.start_forwarding(
@ -177,18 +152,11 @@ async fn hello(State(state): State<AppState>) -> String {
/* Russh session and client */
/// Private type to decide on the authentication method.
#[derive(Clone)]
enum SessionAuth {
SecretKey(Arc<key::KeyPair>),
KeyboardInteractive,
}
/// User-implemented session type as a helper for interfacing with the SSH protocol.
struct TcpForwardSession {
config: Arc<Config>,
session_auth: Option<SessionAuth>,
session: Handle<Client>,
secret_key: Arc<KeyPair>,
}
/// User-implemented session type as a helper for interfacing with the SSH protocol.
@ -199,7 +167,7 @@ impl TcpForwardSession {
port: u16,
login_name: &str,
config: Arc<Config>,
session_auth: &Option<SessionAuth>,
secret_key: Arc<KeyPair>,
) -> Result<Self> {
let span = debug_span!("TcpForwardSession.connect");
let _enter = span;
@ -208,60 +176,20 @@ impl TcpForwardSession {
let mut session = client::connect(Arc::clone(&config), (host, port), client)
.await
.with_context(|| "Unable to connect to remote host.")?;
let session = match session_auth {
Some(SessionAuth::SecretKey(ref secret_key)) => {
if session
.authenticate_publickey(login_name, Arc::clone(secret_key))
.await
.with_context(|| "Error while authenticating with public key.")?
{
debug!("Public key authentication succeeded!");
Ok(session)
} else {
Err(anyhow!("Public key authentication failed."))
}
}
Some(SessionAuth::KeyboardInteractive) => {
match session
.authenticate_keyboard_interactive_start(login_name, None)
.await
.with_context(|| {
"Error while authenticating with keyboard interactive session."
})? {
KeyboardInteractiveAuthResponse::Success => {
debug!("Keyboard interactive authentication succeeded!");
Ok(session)
}
KeyboardInteractiveAuthResponse::Failure => {
Err(anyhow!("Keyboard interactive authentication failed."))
}
response => Err(anyhow!(
"Unhandled keyboard interactive authentication event {:?}",
response
)),
}
}
None => {
if session
.authenticate_none(login_name)
.await
.with_context(|| "Error while authenticating without credentials.")?
{
debug!("Authentication without credentials succeeded!");
Ok(session)
} else {
Err(anyhow!("Authentication without credentials failed."))
}
}
};
match session {
Ok(session) => Ok(Self {
config,
session,
session_auth: session_auth.clone(),
}),
Err(e) => Err(e),
if session
.authenticate_publickey(login_name, Arc::clone(&secret_key))
.await
.with_context(|| "Error while authenticating with public key.")?
{
debug!("Public key authentication succeeded!");
} else {
return Err(anyhow!("Public key authentication failed."));
}
Ok(Self {
config,
session,
secret_key,
})
}
/// Sends a port forwarding request and opens a session to receive miscellaneous data.
@ -351,12 +279,16 @@ impl TcpForwardSession {
timer_iterator: impl Iterator<Item = Duration>,
) -> Result<Self> {
let TcpForwardSession {
config,
session_auth,
..
config, secret_key, ..
} = self;
match TcpForwardSession::connect(host, port, login_name, config.clone(), &session_auth)
.await
match TcpForwardSession::connect(
host,
port,
login_name,
Arc::clone(&config),
Arc::clone(&secret_key),
)
.await
{
Err(err) => {
let mut e = err;
@ -367,8 +299,8 @@ impl TcpForwardSession {
host,
port,
login_name,
config.clone(),
&session_auth,
Arc::clone(&config),
Arc::clone(&secret_key),
)
.await
{