fix: socket connections/closing

you can no longer crash the entire server by connection to the socket 8 times
This commit is contained in:
trisua 2025-04-28 21:22:20 -04:00
parent 3a12c0ee6c
commit c1c8cdbfcd
5 changed files with 147 additions and 164 deletions

55
Cargo.lock generated
View file

@ -716,6 +716,28 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-deque" name = "crossbeam-deque"
version = "0.8.6" version = "0.8.6"
@ -735,6 +757,15 @@ dependencies = [
"crossbeam-utils", "crossbeam-utils",
] ]
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.21" version = "0.8.21"
@ -3240,6 +3271,7 @@ dependencies = [
"axum-extra", "axum-extra",
"cf-turnstile", "cf-turnstile",
"contrasted", "contrasted",
"crossbeam",
"futures-util", "futures-util",
"image", "image",
"mime_guess", "mime_guess",
@ -3524,9 +3556,9 @@ dependencies = [
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.8.20" version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" checksum = "900f6c86a685850b1bc9f6223b20125115ee3f31e01207d81655bbcc0aea9231"
dependencies = [ dependencies = [
"serde", "serde",
"serde_spanned", "serde_spanned",
@ -3536,26 +3568,33 @@ dependencies = [
[[package]] [[package]]
name = "toml_datetime" name = "toml_datetime"
version = "0.6.8" version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3"
dependencies = [ dependencies = [
"serde", "serde",
] ]
[[package]] [[package]]
name = "toml_edit" name = "toml_edit"
version = "0.22.24" version = "0.22.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485"
dependencies = [ dependencies = [
"indexmap", "indexmap",
"serde", "serde",
"serde_spanned", "serde_spanned",
"toml_datetime", "toml_datetime",
"toml_write",
"winnow", "winnow",
] ]
[[package]]
name = "toml_write"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28391a4201ba7eb1984cfeb6862c0b3ea2cfe23332298967c749dddc0d6cd976"
[[package]] [[package]]
name = "totp-rs" name = "totp-rs"
version = "5.7.0" version = "5.7.0"
@ -4296,9 +4335,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.7.4" version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]

View file

@ -34,3 +34,4 @@ mime_guess = "2.0.5"
cf-turnstile = "0.2.0" cf-turnstile = "0.2.0"
contrasted = "0.1.2" contrasted = "0.1.2"
futures-util = "0.3.31" futures-util = "0.3.31"
crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] }

View file

@ -15,8 +15,8 @@ use tetratto_core::{
socket::{SocketMessage, SocketMethod}, socket::{SocketMessage, SocketMethod},
ApiReturn, Error, ApiReturn, Error,
}, },
DataManager,
}; };
use std::{sync::mpsc, time::Duration};
use crate::{get_user_from_token, routes::api::v1::CreateMessage, State}; use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
use serde::Deserialize; use serde::Deserialize;
use futures_util::{sink::SinkExt, stream::StreamExt}; use futures_util::{sink::SinkExt, stream::StreamExt};
@ -31,77 +31,39 @@ pub struct SocketHeaders {
pub async fn subscription_handler( pub async fn subscription_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
Extension(data): Extension<State>, Extension(data): Extension<State>,
Path(channel_id): Path<usize>, Path(id): Path<String>,
) -> Response { ) -> Response {
ws.on_upgrade(move |socket| handle_socket(socket, data, channel_id)) let data = &(data.read().await);
let data = data.0.clone();
ws.on_upgrade(|socket| async move {
tokio::spawn(async move {
handle_socket(socket, data, id).await;
});
})
} }
pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) { pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) {
let db = &(state.read().await).0;
let db = db.clone();
let (mut sink, mut stream) = socket.split(); let (mut sink, mut stream) = socket.split();
let (sender, receiver) = mpsc::channel::<String>();
// forward messages from mpsc to the sink
let mut forward_task = tokio::spawn(async move {
while let Ok(message) = receiver.recv() {
if message == "Close" {
let _ = sink.close().await;
break;
}
if sink.send(message.into()).await.is_err() {
break;
}
}
drop(receiver);
drop(sink);
});
// ping
let ping_sender = sender.clone();
let mut heartbeat_task = tokio::spawn(async move {
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
while ping_sender.send("Ping".to_string()).is_ok() {
heartbeat.tick().await;
}
drop(ping_sender);
});
// ...
let mut user: Option<User> = None; let mut user: Option<User> = None;
let mut con = db.2.clone().get_con().await; let mut con = db.2.clone().get_con().await;
// handle incoming messages on socket // handle incoming messages on socket
let dbc = db.clone(); let dbc = db.clone();
let recv_sender = sender.clone(); if let Some(Ok(WsMessage::Text(text))) = stream.next().await {
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
if text == "Pong" {
continue;
}
if text == "Close" {
recv_sender.send("Close".to_string()).unwrap();
break;
}
let data: SocketMessage = match serde_json::from_str(&text.to_string()) { let data: SocketMessage = match serde_json::from_str(&text.to_string()) {
Ok(t) => t, Ok(t) => t,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}; };
if data.method != SocketMethod::Headers && user.is_none() { if data.method != SocketMethod::Headers && user.is_none() {
// we've sent something else before authenticating... that's not right // we've sent something else before authenticating... that's not right
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
match data.method { match data.method {
@ -113,16 +75,16 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
.get_user_by_id(match data.user.parse::<usize>() { .get_user_by_id(match data.user.parse::<usize>() {
Ok(c) => c, Ok(c) => c,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}) })
.await .await
{ {
Ok(ua) => ua, Ok(ua) => ua,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}, },
); );
@ -131,16 +93,16 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
.get_channel_by_id(match data.channel.parse::<usize>() { .get_channel_by_id(match data.channel.parse::<usize>() {
Ok(c) => c, Ok(c) => c,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}) })
.await .await
{ {
Ok(c) => c, Ok(c) => c,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}; };
@ -152,80 +114,61 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
{ {
Ok(ua) => ua, Ok(ua) => ua,
Err(_) => { Err(_) => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
}; };
if !channel.check_read(user.id, Some(membership.role)) { if !channel.check_read(user.id, Some(membership.role)) {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
break; return;
} }
} }
_ => { _ => {
recv_sender.send("Close".to_string()).unwrap(); let _ = sink.close().await;
return;
}
}
} else {
sink.close().await.unwrap();
return;
}
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
if text != "Close" {
continue;
}
// yes, this is an "unclean" disconnection from the socket...
// i don't care, it works
drop(stream);
break; break;
} }
}
}
drop(recv_sender);
}); });
let mut redis_task = tokio::spawn(async move {
// forward messages from redis to the mpsc // forward messages from redis to the mpsc
let send_task_sender = sender.clone();
let mut send_task = tokio::spawn(async move {
let mut pubsub = con.as_pubsub(); let mut pubsub = con.as_pubsub();
pubsub.subscribe(channel_id).unwrap(); pubsub.subscribe(channel_id).unwrap();
while let Ok(msg) = pubsub.get_message() { while let Ok(msg) = pubsub.get_message() {
// payload is a stringified SocketMessage // payload is a stringified SocketMessage
if send_task_sender.send(msg.get_payload().unwrap()).is_err() { if sink
break; .send(WsMessage::Text(msg.get_payload::<String>().unwrap().into()))
.await
.is_err()
{
return;
} }
} }
drop(send_task_sender);
}); });
// ...
let close_sender = sender.clone();
tokio::select! { tokio::select! {
_ = (&mut heartbeat_task) => { _ = (&mut recv_task) => redis_task.abort(),
let _ = close_sender.send("Close".to_string()); _ = (&mut redis_task) => recv_task.abort()
forward_task.abort();
recv_task.abort();
send_task.abort();
} }
_ = (&mut forward_task) => {
send_task.abort();
recv_task.abort();
heartbeat_task.abort();
}
_ = (&mut send_task) => {
let _ = close_sender.send("Close".to_string());
forward_task.abort();
recv_task.abort();
heartbeat_task.abort();
},
_ = (&mut recv_task) => {
let _ = close_sender.send("Close".to_string());
send_task.abort();
forward_task.abort();
heartbeat_task.abort();
}
};
// kill
drop(sender);
drop(db);
send_task.abort();
recv_task.abort();
forward_task.abort();
heartbeat_task.abort();
let _ = close_sender.send("Close".to_string());
tracing::info!("socket terminate"); tracing::info!("socket terminate");
} }

View file

@ -12,7 +12,7 @@ default = ["sqlite", "redis"]
[dependencies] [dependencies]
pathbufd = "0.1.4" pathbufd = "0.1.4"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
toml = "0.8.20" toml = "0.8.21"
tetratto-shared = { path = "../shared" } tetratto-shared = { path = "../shared" }
tetratto-l10n = { path = "../l10n" } tetratto-l10n = { path = "../l10n" }
serde_json = "1.0.140" serde_json = "1.0.140"

View file

@ -9,4 +9,4 @@ license.workspace = true
[dependencies] [dependencies]
pathbufd = "0.1.4" pathbufd = "0.1.4"
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
toml = "0.8.20" toml = "0.8.21"