Initial commit - still working on MPbN

This commit is contained in:
Bad Manners 2024-09-09 11:33:04 -03:00
commit c2fcc5b210
11 changed files with 4107 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target/
/credentials/

2967
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

25
Cargo.toml Normal file
View file

@ -0,0 +1,25 @@
[package]
name = "tank-drone-test"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.86"
async-trait = "0.1"
axum = "0.7.5"
axum-macros = "0.4.1"
bitvec = "1.0.1"
clap = { version = "4.5.17", features = ["derive"] }
futures = "0.3.30"
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
maud = { version = "0.26.0", features = ["axum"] }
reqwest = "0.12.7"
russh = "0.45"
termsize = "0.1.9"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1.15", features = ["net", "sync"] }
tokio-util = "0.7.11"
tower = "0.5.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter", "std"] }

11
README.md Normal file
View file

@ -0,0 +1,11 @@
# htmx-axum-russh-games
A few silly games I made while I learn about Axum and HTMX.
## checkbox.rs
A poor man's clone of A Million Checkboxes.
## multipaint_by_numbers.rs
A multiplayer Picross/Nonogram, inspired by the project above.

83
src/entrypoint.rs Normal file
View file

@ -0,0 +1,83 @@
use std::{iter, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{Context, Result};
use axum::Router;
use russh::{client, keys::decode_secret_key};
use tokio::{fs, net::TcpListener};
use tracing::{debug, error, info};
use crate::{http::ROUTER, ssh::TcpForwardSession};
/* Local server entrypoint */
/// Spins up a local Axum server for development.
pub async fn local_server_entrypoint(hostname: &str, port: u16) -> Result<()> {
let listener = TcpListener::bind((hostname, port))
.await
.with_context(|| "Failed to bind TCP listener")?;
println!("Listening on http://{}:{}", hostname, port);
axum::serve(
listener,
Router::clone(
ROUTER
.get()
.with_context(|| "Router hasn't been initialized.")?,
),
)
.await
.with_context(|| "Server has closed.")
}
/* SSH entrypoint */
/// Begins remote port forwarding (reverse tunneling) with Russh to serve an Axum application.
pub async fn ssh_entrypoint(
host: &str,
port: u16,
login_name: &str,
identity_file: PathBuf,
remote_host: &str,
remote_port: u16,
request_pty: Option<String>,
) -> Result<()> {
let secret_key = fs::read_to_string(identity_file)
.await
.with_context(|| "Failed to open secret key")?;
let secret_key =
Arc::new(decode_secret_key(&secret_key, None).with_context(|| "Invalid secret key")?);
let config = Arc::new(client::Config {
..Default::default()
});
loop {
let mut reconnect_attempt = 0;
let mut session = TcpForwardSession::connect(
host,
port,
login_name,
Arc::clone(&config),
Arc::clone(&secret_key),
iter::from_fn(move || {
reconnect_attempt += 1;
if reconnect_attempt <= 5 {
Some(Duration::from_secs(2 * reconnect_attempt))
} else {
None
}
}),
)
.await
.with_context(|| "Connection failed.")?;
match session
.start_forwarding(remote_host, remote_port, request_pty.as_deref())
.await
{
Err(e) => error!(error = ?e, "TCP forward session failed."),
_ => info!("Connection closed."),
}
debug!("Attempting graceful disconnect.");
if let Err(e) = session.close().await {
debug!(error = ?e, "Graceful disconnect failed.")
}
debug!("Restarting connection.");
}
}

124
src/http/checkbox.rs Normal file
View file

@ -0,0 +1,124 @@
use std::sync::{Arc, Mutex};
use axum::{
extract::{Path, State},
routing::{delete, get, put},
Router,
};
use bitvec::{array::BitArray, order::Lsb0, BitArr};
use hyper::StatusCode;
use maud::{html, Markup, DOCTYPE};
#[derive(Clone)]
struct AppState {
checkboxes: Arc<Mutex<BitArr!(for CHECKBOX_WIDTH*CHECKBOX_HEIGHT, in usize, Lsb0)>>,
}
const CHECKBOX_WIDTH: usize = 20;
const CHECKBOX_HEIGHT: usize = 20;
/// A lazily-created Router, to be used by the SSH client tunnels.
pub fn get_router() -> Router {
Router::new()
.route("/", get(index))
.route("/checkboxes", get(all_checkboxes))
.route("/checkbox/:id", put(mark_checkbox))
.route("/checkbox/:id", delete(unmark_checkbox))
.with_state(AppState {
checkboxes: Arc::new(Mutex::new(BitArray::ZERO)),
})
}
fn style() -> &'static str {
r#"
body {
width: fit-content;
}
ul {
display: grid;
list-style: none;
padding-left: 0;
gap: 2px;
}
li {
width: 20px;
height: 20px;
}
"#
}
fn head() -> Markup {
html! {
(DOCTYPE)
head {
meta charset="utf-8";
title { (CHECKBOX_WIDTH*CHECKBOX_HEIGHT) " Checkboxes" }
script src="https://unpkg.com/htmx.org@2.0.2" integrity="sha384-Y7hw+L/jvKeWIRRkqWYfPcvVxHzVzn5REgzbawhxAuQGwX1XWe70vji+VSeHOThJ" crossorigin="anonymous" {}
style { (style()) }
}
}
}
async fn index() -> Markup {
html! {
(head())
body {
h1 { (CHECKBOX_WIDTH*CHECKBOX_HEIGHT) " Checkboxes" }
ul hx-get="/checkboxes" hx-trigger="load" {}
}
}
}
async fn all_checkboxes(State(state): State<AppState>) -> Markup {
html! {
ul hx-get="/checkboxes" hx-trigger="every 3s" style=(format!("grid-template-columns: repeat({}, minmax(0, 1fr));", CHECKBOX_WIDTH)) {
@for (id, checkbox) in state.checkboxes.lock().unwrap()[..CHECKBOX_WIDTH*CHECKBOX_HEIGHT].iter().by_vals().enumerate() {
li {
@if checkbox {
(checked(id))
} @else {
(unchecked(id))
}
}
}
}
}
}
fn checked(id: usize) -> Markup {
html! {
input id=(format!("cb-{}", id)) type="checkbox" hx-delete=(format!("/checkbox/{}", id)) hx-trigger="click" checked {}
}
}
fn unchecked(id: usize) -> Markup {
html! {
input id=(format!("cb-{}", id)) type="checkbox" hx-put=(format!("/checkbox/{}", id)) hx-trigger="click" {}
}
}
async fn mark_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
match state.checkboxes.lock().unwrap().get_mut(id) {
None => Err(StatusCode::NOT_FOUND),
Some(mut checkbox) => {
*checkbox = true;
Ok(checked(id))
}
}
}
async fn unmark_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
match state.checkboxes.lock().unwrap().get_mut(id) {
None => Err(StatusCode::NOT_FOUND),
Some(mut checkbox) => {
*checkbox = false;
Ok(unchecked(id))
}
}
}

9
src/http/mod.rs Normal file
View file

@ -0,0 +1,9 @@
use std::sync::OnceLock;
use axum::Router;
pub mod checkbox;
pub mod multipaint_by_numbers;
/// A lazily-created Router, to be used by the SSH client tunnels or directly by the HTTP server.
pub static ROUTER: OnceLock<Router> = OnceLock::new();

View file

@ -0,0 +1,499 @@
use std::{
borrow::BorrowMut,
mem,
ops::{BitXor, DerefMut},
sync::{Arc, Mutex},
time::Duration,
};
use anyhow::{anyhow, Result};
use axum::{
extract::{Path, State},
routing::{delete, get, put},
Router,
};
use bitvec::{bitvec, order::Lsb0, vec::BitVec};
use hyper::StatusCode;
use maud::{html, Markup, PreEscaped, DOCTYPE};
use reqwest::redirect::Policy;
use tokio::time::sleep;
use tracing::{debug, info, warn};
#[derive(Clone)]
struct Puzzle {
id: u32,
title: Option<String>,
copyright: Option<String>,
rows: Vec<Vec<u8>>,
columns: Vec<Vec<u8>>,
solution: BitVec<usize, Lsb0>,
is_solved: bool,
}
#[derive(Copy, Clone, PartialEq)]
enum CheckboxState {
Empty,
Flagged,
Marked,
}
#[derive(Clone)]
struct AppState {
checkboxes: Arc<Mutex<Vec<CheckboxState>>>,
current_puzzle: Arc<Mutex<Puzzle>>,
}
enum GetPuzzleState {
Start,
ReadingRows,
ReadingColumns,
}
async fn get_puzzle() -> Result<Puzzle> {
let client = reqwest::ClientBuilder::new()
.redirect(Policy::none())
.build()
.unwrap();
let redirect_response = client
.post("https://webpbn.com/random.cgi")
.form(&[
("sid", ""),
("go", "1"),
("psize", "1"),
("pcolor", "1"),
("pmulti", "1"),
("pguess", "1"),
("save", "1"),
])
.send()
.await
.unwrap();
let location = redirect_response.headers().get("location").unwrap();
let id = location
.to_str()
.unwrap()
.split_once("id=")
.unwrap()
.1
.split('&')
.next()
.unwrap()
.parse::<u32>()
.unwrap();
debug!(id = id, "Fetching puzzle...");
let client = reqwest::Client::new();
let export_response = client
.post(format!("https://webpbn.com/export.cgi/webpbn{:06}.non", id))
.form(&[
("go", "1"),
("sid", ""),
("id", &id.to_string()),
("xml_clue", "on"),
("xml_soln", "on"),
("fmt", "ss"),
("ss_soln", "on"),
("sg_clue", "on"),
("sg_soln", "on"),
])
.send()
.await
.unwrap()
.text()
.await
.unwrap();
let mut title = None;
let mut copyright = None;
let mut rows = vec![];
let mut columns = vec![];
let mut solution = bitvec![];
let mut state = GetPuzzleState::Start;
for line in export_response.lines() {
match state {
GetPuzzleState::Start => {
if line.starts_with("title") {
let mut iter = line.splitn(3, '"');
iter.next().unwrap();
title = Some(String::from(iter.next().unwrap()));
} else if line.starts_with("copyright") {
let mut iter = line.splitn(3, '"');
iter.next().unwrap();
copyright = Some(String::from(iter.next().unwrap()));
} else if line.starts_with("rows") {
state = GetPuzzleState::ReadingRows;
} else if line.starts_with("columns") {
state = GetPuzzleState::ReadingColumns;
} else if line.starts_with("goal") {
let mut iter = line.splitn(3, '"');
iter.next().unwrap();
solution.extend(iter.next().unwrap().chars().map(|char| char == '1'));
} else if line.starts_with("copyright") {
let mut iter = line.splitn(3, '"');
iter.next().unwrap();
title = Some(String::from(iter.next().unwrap()));
}
}
GetPuzzleState::ReadingRows => {
if line.is_empty() {
state = GetPuzzleState::Start;
} else {
let row = line
.split(',')
.map(|text| str::parse::<u8>(text).unwrap())
.filter(|&value| value > 0)
.collect::<Vec<_>>();
rows.push(row);
}
}
GetPuzzleState::ReadingColumns => {
if line.is_empty() {
state = GetPuzzleState::Start;
} else {
let column = line
.split(',')
.map(|text| str::parse::<u8>(text).unwrap())
.filter(|&value| value > 0)
.collect::<Vec<_>>();
columns.push(column);
}
}
}
}
if rows.len() == 0 || columns.len() == 0 || solution.len() == 0 {
warn!(id = id, "Invalid puzzle.");
Err(anyhow!("Invalid puzzle"))
} else {
info!(id = id, "Valid puzzle.");
Ok(Puzzle {
id,
title,
copyright,
rows,
columns,
solution,
is_solved: false,
})
}
}
/// A lazily-created Router, to be used by the SSH client tunnels.
pub async fn get_router() -> Router {
let first_puzzle = loop {
let puzzle = get_puzzle().await;
if puzzle.is_ok() {
break puzzle.unwrap();
}
};
info!("test");
Router::new()
.route("/", get(index))
.route("/nonogram", get(nonogram))
.route("/flag/:id", put(flag_checkbox))
.route("/flag/:id", delete(unflag_checkbox))
.route("/checkbox/:id", put(mark_checkbox))
.route("/checkbox/:id", delete(unmark_checkbox))
.with_state(AppState {
checkboxes: Arc::new(Mutex::new(vec![
CheckboxState::Empty;
first_puzzle.rows.len()
* first_puzzle.columns.len()
])),
current_puzzle: Arc::new(Mutex::new(first_puzzle)),
})
}
fn style() -> &'static str {
r#"
h2.congratulations {
color: darkgreen;
}
hr {
margin-top: 28px;
margin-bottom: 28px;
}
table {
border-collapse: collapse;
}
tr:nth-child(5n - 3) {
border-top: 1pt solid black;
}
tr th:nth-child(5n - 3), tr td:nth-child(5n - 3) {
border-left: 1pt solid black;
}
th[scope="col"] {
vertical-align: bottom;
}
th[scope="col"] > div {
display: flex;
flex-direction: column;
justify-content: end;
}
th[scope="row"] {
display: flex;
justify-content: end;
column-gap: 6px;
margin-right: 2px;
}
.checkbox {
position: relative;
}
.checkbox.flagged input:not(:checked) {
outline-style: solid;
outline-width: 2px;
outline-color: gray;
}
table.solved .checkbox.marked div {
position: absolute;
inset: 0;
z-index: 10;
background: black;
}
"#
}
fn head() -> Markup {
html! {
(DOCTYPE)
head {
meta charset="utf-8";
title { "Multipaint by Numbers" }
script src="https://unpkg.com/htmx.org@2.0.2" integrity="sha384-Y7hw+L/jvKeWIRRkqWYfPcvVxHzVzn5REgzbawhxAuQGwX1XWe70vji+VSeHOThJ" crossorigin="anonymous" {}
style { (PreEscaped(style())) }
}
}
}
async fn index() -> Markup {
html! {
(head())
body {
h1 { "Multipaint by Numbers" }
div #nonogram hx-get="/nonogram" hx-trigger="load, every 3s" {}
hr {}
p {
"Puzzles from "
a href="https://webpbn.com" target="_blank" {
"Web Paint-by-Number"
}
"."
}
p { "Click to mark/unmark." }
p { "Ctrl+Click to flag. (Then wait for a bit...!)" }
p style=(PreEscaped("opacity: 0")) { "Howdy from Bad Manners!" }
}
}
}
async fn nonogram_oob(state: State<AppState>) -> Markup {
html! {
div #nonogram hx-get="/nonogram" hx-trigger="load, every 3s" {
(nonogram(state).await)
}
}
}
async fn nonogram(State(state): State<AppState>) -> Markup {
let puzzle = state.current_puzzle.lock().unwrap();
let checkboxes = state.checkboxes.lock().unwrap();
let rows = &puzzle.rows;
let columns = &puzzle.columns;
let columns_len = columns.len();
let is_solved = puzzle.is_solved;
html! {
@if is_solved {
h2 class="congratulations" {
"Congratulations!!"
}
}
@if let Some(title) = &puzzle.title {
h3 {
"Puzzle: " (title) " (#" (puzzle.id) ")"
}
}
@if let Some(copyright) = &puzzle.copyright {
em .copyright {
(PreEscaped(copyright))
}
}
hr {}
table .solved[is_solved] {
tbody {
tr {
td {}
@for column in columns {
th scope="col" {
div {
@for value in column.iter() {
div {
(value.to_string())
}
}
}
}
}
}
@for (i, row) in rows.iter().enumerate() {
tr {
th scope="row" {
@for value in row.iter() {
div {
(value.to_string())
}
}
}
@let id_range = i * columns_len..(i + 1) * columns_len;
@let slice = &checkboxes[id_range.clone()];
@for (id, &state) in id_range.zip(slice) {
td {
(checkbox(id, is_solved, state))
}
}
}
}
}
}
}
}
fn checkbox(id: usize, is_solved: bool, state: CheckboxState) -> Markup {
match state {
CheckboxState::Marked => html! {
.checkbox.marked {
input id=(format!("checkbox-{id}")) type="checkbox" disabled[is_solved] checked {}
div hx-delete=(format!("/checkbox/{}", id)) hx-trigger=(format!("click from:#checkbox-{id}")) hx-swap="outerHTML" hx-target="closest .checkbox" {}
}
},
CheckboxState::Flagged if !is_solved => html! {
.checkbox.flagged {
input id=(format!("checkbox-{id}")) type="checkbox" disabled[is_solved] {}
div hx-put=(format!("/checkbox/{}", id)) hx-trigger=(format!("click from:#checkbox-{id}")) hx-swap="outerHTML" hx-target="closest .checkbox" {}
div hx-delete=(format!("/flag/{}", id)) hx-trigger=(format!("click[ctrlKey] from:#checkbox-{id}")) hx-swap="outerHTML" hx-target="closest .checkbox" {}
}
},
_ => html! {
.checkbox.empty {
input id=(format!("checkbox-{id}")) type="checkbox" disabled[is_solved] {}
div hx-put=(format!("/checkbox/{}", id)) hx-trigger=(format!("click from:#checkbox-{id}")) hx-swap="outerHTML" hx-target="closest .checkbox" {}
div hx-put=(format!("/flag/{}", id)) hx-trigger=(format!("click[ctrlKey] from:#checkbox-{id}")) hx-swap="outerHTML" hx-target="closest .checkbox" {}
}
},
}
}
async fn flag_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
let puzzle = state.current_puzzle.lock().unwrap();
if puzzle.is_solved {
Ok(checkbox(id, true, CheckboxState::Empty))
} else {
let mut checkboxes = state.checkboxes.lock().unwrap();
match checkboxes.get_mut(id) {
None => Err(StatusCode::NOT_FOUND),
Some(checkbox_state) => {
*checkbox_state = CheckboxState::Flagged;
Ok(checkbox(id, false, CheckboxState::Flagged))
}
}
}
}
async fn unflag_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
let puzzle = state.current_puzzle.lock().unwrap();
if puzzle.is_solved {
Ok(checkbox(id, true, CheckboxState::Empty))
} else {
let mut checkboxes = state.checkboxes.lock().unwrap();
match checkboxes.get_mut(id) {
None => Err(StatusCode::NOT_FOUND),
Some(checkbox_state) => {
*checkbox_state = CheckboxState::Flagged;
Ok(checkbox(id, false, CheckboxState::Flagged))
}
}
}
}
async fn mark_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
let mut puzzle = state.current_puzzle.lock().unwrap();
if puzzle.is_solved {
Ok(checkbox(id, true, CheckboxState::Empty))
} else {
let mut checkboxes = state.checkboxes.lock().unwrap();
match checkboxes.get_mut(id) {
None => return Err(StatusCode::NOT_FOUND),
Some(checkbox_state) => {
*checkbox_state = CheckboxState::Marked;
}
}
puzzle.is_solved = check_if_solved(&puzzle.solution, &checkboxes, &state);
Ok(checkbox(id, puzzle.is_solved, CheckboxState::Marked))
}
}
async fn unmark_checkbox(
State(state): State<AppState>,
Path(id): Path<usize>,
) -> Result<Markup, StatusCode> {
let mut puzzle = state.current_puzzle.lock().unwrap();
if puzzle.is_solved {
Ok(checkbox(id, true, CheckboxState::Marked))
} else {
let mut checkboxes = state.checkboxes.lock().unwrap();
match checkboxes.get_mut(id) {
None => return Err(StatusCode::NOT_FOUND),
Some(checkbox_state) => {
*checkbox_state = CheckboxState::Empty;
}
}
puzzle.is_solved = check_if_solved(&puzzle.solution, &checkboxes, &state);
Ok(checkbox(id, puzzle.is_solved, CheckboxState::Empty))
}
}
fn check_if_solved(
solution: &BitVec<usize, Lsb0>,
checkboxes: &Vec<CheckboxState>,
state: &AppState,
) -> bool {
let wrong_squares = solution
.clone()
.bitxor(
checkboxes
.iter()
.map(|&state| state == CheckboxState::Marked)
.collect::<BitVec<usize, Lsb0>>(),
)
.count_ones();
let is_solved = wrong_squares == 0;
if is_solved {
let state = state.clone();
let current_puzzle = state.current_puzzle;
let checkboxes = state.checkboxes;
tokio::spawn(async move {
sleep(Duration::from_secs(8)).await;
// Fetch next puzzle
let next_puzzle = loop {
let puzzle = get_puzzle().await;
if puzzle.is_ok() {
break puzzle.unwrap();
}
};
let _ = mem::replace(
checkboxes.lock().unwrap().as_mut(),
vec![CheckboxState::Empty; next_puzzle.rows.len() * next_puzzle.columns.len()],
);
*current_puzzle.lock().unwrap() = next_puzzle;
});
} else {
info!("Have {wrong_squares} wrong squares!");
}
is_solved
}

3
src/lib.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod entrypoint;
pub mod http;
pub mod ssh;

116
src/main.rs Normal file
View file

@ -0,0 +1,116 @@
use std::path::PathBuf;
use anyhow::Result;
use clap::{Parser, Subcommand, ValueEnum};
use tank_drone_test::{
entrypoint::{local_server_entrypoint, ssh_entrypoint},
http::{checkbox, multipaint_by_numbers, ROUTER},
};
use tracing::trace;
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
#[derive(Debug, Clone, Subcommand)]
enum OperationMode {
/// Run a conventional HTTP server locally.
LocalServer {
/// Hostname to listen to.
#[arg(short = 'H', long, default_value_t = String::from("localhost"))]
hostname: String,
/// Local port to expose our site.
#[arg(short, long, default_value_t = 5023)]
port: u16,
},
/// Expose the HTTP server through SSH remote port forwarding.
Ssh {
/// SSH hostname.
hostname: String,
/// SSH port.
#[arg(short, long, default_value_t = 22)]
port: u16,
/// Identity file containing private key.
#[arg(short, long, default_value_t = String::from(""))]
login_name: String,
/// Identity file containing private key.
#[arg(short, long, value_name = "FILE")]
identity_file: PathBuf,
/// Remote hostname to bind to.
#[arg(short = 'R', long, default_value_t = String::from(""))]
remote_host: String,
/// Remote port to bind to.
#[arg(short = 'P', long, default_value_t = 80)]
remote_port: u16,
/// Request a pseudo-terminal to be allocated with the given command.
#[arg(long)]
request_pty: Option<String>,
},
}
#[derive(Debug, Copy, Clone, ValueEnum)]
enum ActivityRouter {
/// 400 Checkboxes - A barebones clone of One Million Checkboxes.
Checkboxes,
/// Multipaint by Numbers - A multiplayer nonogram/picross.
Multipaint,
}
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct MainEntrypointArgs {
/// Which activity router to serve.
#[arg(value_enum, default_value_t = ActivityRouter::Checkboxes)]
router: ActivityRouter,
/// Which mode to run this application as.
#[command(subcommand)]
mode: OperationMode,
}
#[tokio::main]
async fn main() -> Result<()> {
let _subscriber = tracing_subscriber::registry()
.with(fmt::layer())
.with(EnvFilter::from_default_env())
.init();
trace!("Tracing is up!");
let args = MainEntrypointArgs::parse();
match args.router {
ActivityRouter::Checkboxes => ROUTER.set(checkbox::get_router()).unwrap(),
ActivityRouter::Multipaint => ROUTER
.set(multipaint_by_numbers::get_router().await)
.unwrap(),
}
match args.mode {
OperationMode::LocalServer { hostname, port } => {
local_server_entrypoint(hostname.as_str(), port).await
}
OperationMode::Ssh {
hostname,
port,
login_name,
identity_file,
remote_host,
remote_port,
request_pty,
} => {
ssh_entrypoint(
hostname.as_str(),
port,
login_name.as_str(),
identity_file,
remote_host.as_str(),
remote_port,
request_pty,
)
.await
}
}
}

268
src/ssh.rs Normal file
View file

@ -0,0 +1,268 @@
use std::{sync::Arc, time::Duration};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use hyper::service::service_fn;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use russh::{
client::{self, Config, Handle, Msg, Session},
keys::key::{self, KeyPair},
Channel, ChannelId, ChannelMsg, Disconnect,
};
use tokio::{
io::{stderr, stdout, AsyncWriteExt},
time::sleep,
};
use tower::Service;
use tracing::{debug, debug_span, info, trace};
use crate::http::ROUTER;
/* Russh session and client */
/// User-implemented session type as a helper for interfacing with the SSH protocol.
pub struct TcpForwardSession(Handle<Client>);
/// User-implemented session type as a helper for interfacing with the SSH protocol.
impl TcpForwardSession {
/// Attempts to connect to the SSH server. If authentication fails, it returns an error value immediately.
///
/// Our reconnection strategy comes from an iterator which yields `Duration`s. Each one tells us how long to delay
/// our next reconnection attempt. The function will stop attempting to reconnect once the iterator
/// stops yielding values.
pub async fn connect(
host: &str,
port: u16,
login_name: &str,
config: Arc<Config>,
secret_key: Arc<KeyPair>,
mut timer_iterator: impl Iterator<Item = Duration>,
) -> Result<Self> {
let span = debug_span!("TcpForwardSession.connect");
let _enter = span;
debug!("TcpForwardSession connecting...");
let mut attempts = 0u32;
let session = loop {
attempts += 1;
debug!("Connection retry #{}", attempts);
match client::connect(Arc::clone(&config), (host, port), Client {}).await {
Ok(mut session) => {
if session
.authenticate_publickey(login_name, Arc::clone(&secret_key))
.await
.with_context(|| "Error while authenticating with public key.")?
{
debug!(attempts = attempts, "Public key authentication succeeded!");
break session;
} else {
return Err(anyhow!("Public key authentication failed."));
}
}
Err(err) => {
debug!(err = ?err, "Unable to connect to remote host.");
let Some(duration) = timer_iterator.next() else {
debug!(attempts = attempts, "Failed to recconect.");
return Err(anyhow!("Gave up graceful reconnection."));
};
sleep(duration).await;
}
}
};
Ok(Self(session))
}
/// Sends a port forwarding request and opens a session to receive miscellaneous data.
/// The function yields when the session is broken (for example, if the connection was lost).
pub async fn start_forwarding(
&mut self,
remote_host: &str,
remote_port: u16,
request_pty: Option<&str>,
) -> Result<u32> {
let span = debug_span!("TcpForwardSession.start");
let _enter = span;
let session = &mut self.0;
session
.tcpip_forward(remote_host, remote_port.into())
.await
.with_context(|| "tcpip_forward error.")?;
debug!("Requested tcpip_forward session.");
let mut channel = session
.channel_open_session()
.await
.with_context(|| "channel_open_session error.")?;
debug!("Created open session channel.");
// let mut stdin = stdin();
let mut stdout = stdout();
let mut stderr = stderr();
if let Some(cmd) = request_pty {
let size = termsize::get().unwrap();
channel
.request_pty(
false,
&std::env::var("TERM").unwrap_or("xterm".into()),
size.cols.into(),
size.rows.into(),
0,
0,
&[],
)
.await
.with_context(|| "Unable to request pseudo-terminal.")?;
debug!("Requested pseudo-terminal.");
channel
.exec(true, cmd)
.await
.with_context(|| "Unable to execute command for pseudo-terminal.")?;
};
let code = loop {
let Some(msg) = channel.wait().await else {
return Err(anyhow!("Unexpected end of channel."));
};
trace!("Got a message through initial session!");
match msg {
ChannelMsg::Data { ref data } => {
stdout.write_all(data).await?;
stdout.flush().await?;
}
ChannelMsg::ExtendedData { ref data, ext: 1 } => {
stderr.write_all(data).await?;
stderr.flush().await?;
}
ChannelMsg::Success => (),
ChannelMsg::Close => break 0,
ChannelMsg::ExitStatus { exit_status } => {
debug!("Exited with code {exit_status}");
channel
.eof()
.await
.with_context(|| "Unable to close connection.")?;
break exit_status;
}
msg => return Err(anyhow!("Unknown message type {:?}.", msg)),
}
};
Ok(code)
}
pub async fn close(&mut self) -> Result<()> {
self.0
.disconnect(Disconnect::ByApplication, "", "English")
.await?;
Ok(())
}
}
/// Our SSH client implementing the `Handler` callbacks for the functions we need to use.
struct Client {}
#[async_trait]
impl client::Handler for Client {
type Error = anyhow::Error;
/// Always accept the SSH server's pubkey. Don't do this in production.
#[allow(unused_variables)]
async fn check_server_key(
&mut self,
server_public_key: &key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
/// Handle a new forwarded connection, represented by a specific `Channel`. We will create a clone of our router,
/// and forward any messages from this channel with its streaming API.
///
/// To make Axum behave with streaming, we must turn it into a Tower service first.
/// And to handle the SSH channel as a stream, we will use a utility method from Tokio that turns our
/// AsyncRead/Write stream into a `hyper` IO object.
///
/// See also: [axum/examples/serve-with-hyper](https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs)
#[allow(unused_variables)]
async fn server_channel_open_forwarded_tcpip(
&mut self,
channel: Channel<Msg>,
connected_address: &str,
connected_port: u32,
originator_address: &str,
originator_port: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
let span = debug_span!("server_channel_open_forwarded_tcpip");
let _enter = span.enter();
debug!(
sshid = %String::from_utf8_lossy(session.remote_sshid()),
connected_address = connected_address,
connected_port = connected_port,
originator_address = originator_address,
originator_port = originator_port,
"New connection!"
);
let router = ROUTER
.get()
.with_context(|| "Router hasn't been initialized.")?;
let service = service_fn(move |req| router.clone().call(req));
let server = Builder::new(TokioExecutor::new());
// tokio::spawn is required to let us reply over the data channel.
tokio::spawn(async move {
server
.serve_connection_with_upgrades(TokioIo::new(channel.into_stream()), service)
.await
.expect("Invalid request");
});
Ok(())
}
#[allow(unused_variables)]
async fn auth_banner(
&mut self,
banner: &str,
session: &mut Session,
) -> Result<(), Self::Error> {
debug!("Received auth banner.");
let mut stdout = stdout();
stdout.write_all(banner.as_bytes()).await?;
stdout.flush().await?;
Ok(())
}
#[allow(unused_variables)]
async fn exit_status(
&mut self,
channel: ChannelId,
exit_status: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
debug!(channel = ?channel, "exit_status");
if exit_status == 0 {
info!("Remote exited with status {}.", exit_status);
} else {
info!("Remote exited with status {}.", exit_status);
}
Ok(())
}
#[allow(unused_variables)]
async fn channel_open_confirmation(
&mut self,
channel: ChannelId,
max_packet_size: u32,
window_size: u32,
session: &mut Session,
) -> Result<(), Self::Error> {
debug!(channel = ?channel, max_packet_size, window_size, "channel_open_confirmation");
Ok(())
}
#[allow(unused_variables)]
async fn channel_success(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
debug!(channel = ?channel, "channel_success");
Ok(())
}
}