From c1c8cdbfcd715977748348a6a528418893e289bf Mon Sep 17 00:00:00 2001 From: trisua Date: Mon, 28 Apr 2025 21:22:20 -0400 Subject: [PATCH] fix: socket connections/closing you can no longer crash the entire server by connection to the socket 8 times --- Cargo.lock | 55 +++- crates/app/Cargo.toml | 1 + .../src/routes/api/v1/channels/messages.rs | 251 +++++++----------- crates/core/Cargo.toml | 2 +- crates/l10n/Cargo.toml | 2 +- 5 files changed, 147 insertions(+), 164 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d0678f1..9c56781 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -716,6 +716,28 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -735,6 +757,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -3240,6 +3271,7 @@ dependencies = [ "axum-extra", "cf-turnstile", "contrasted", + "crossbeam", "futures-util", "image", "mime_guess", @@ -3524,9 +3556,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +checksum = "900f6c86a685850b1bc9f6223b20125115ee3f31e01207d81655bbcc0aea9231" dependencies = [ "serde", "serde_spanned", @@ -3536,26 +3568,33 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485" dependencies = [ "indexmap", "serde", "serde_spanned", "toml_datetime", + "toml_write", "winnow", ] +[[package]] +name = "toml_write" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28391a4201ba7eb1984cfeb6862c0b3ea2cfe23332298967c749dddc0d6cd976" + [[package]] name = "totp-rs" version = "5.7.0" @@ -4296,9 +4335,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" +checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5" dependencies = [ "memchr", ] diff --git a/crates/app/Cargo.toml b/crates/app/Cargo.toml index aad52e9..00862f3 100644 --- a/crates/app/Cargo.toml +++ b/crates/app/Cargo.toml @@ -34,3 +34,4 @@ mime_guess = "2.0.5" cf-turnstile = "0.2.0" contrasted = "0.1.2" futures-util = "0.3.31" +crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] } diff --git a/crates/app/src/routes/api/v1/channels/messages.rs b/crates/app/src/routes/api/v1/channels/messages.rs index 058f9a8..9e23935 100644 --- a/crates/app/src/routes/api/v1/channels/messages.rs +++ b/crates/app/src/routes/api/v1/channels/messages.rs @@ -15,8 +15,8 @@ use tetratto_core::{ socket::{SocketMessage, SocketMethod}, ApiReturn, Error, }, + DataManager, }; -use std::{sync::mpsc, time::Duration}; use crate::{get_user_from_token, routes::api::v1::CreateMessage, State}; use serde::Deserialize; use futures_util::{sink::SinkExt, stream::StreamExt}; @@ -31,201 +31,144 @@ pub struct SocketHeaders { pub async fn subscription_handler( ws: WebSocketUpgrade, Extension(data): Extension, - Path(channel_id): Path, + Path(id): Path, ) -> Response { - ws.on_upgrade(move |socket| handle_socket(socket, data, channel_id)) + let data = &(data.read().await); + let data = data.0.clone(); + + ws.on_upgrade(|socket| async move { + tokio::spawn(async move { + handle_socket(socket, data, id).await; + }); + }) } -pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) { - let db = &(state.read().await).0; - let db = db.clone(); - +pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) { let (mut sink, mut stream) = socket.split(); - let (sender, receiver) = mpsc::channel::(); - // forward messages from mpsc to the sink - let mut forward_task = tokio::spawn(async move { - while let Ok(message) = receiver.recv() { - if message == "Close" { - let _ = sink.close().await; - break; - } - - if sink.send(message.into()).await.is_err() { - break; - } - } - - drop(receiver); - drop(sink); - }); - - // ping - let ping_sender = sender.clone(); - let mut heartbeat_task = tokio::spawn(async move { - let mut heartbeat = tokio::time::interval(Duration::from_secs(10)); - - while ping_sender.send("Ping".to_string()).is_ok() { - heartbeat.tick().await; - } - - drop(ping_sender); - }); - - // ... let mut user: Option = None; let mut con = db.2.clone().get_con().await; // handle incoming messages on socket let dbc = db.clone(); - let recv_sender = sender.clone(); - let mut recv_task = tokio::spawn(async move { - while let Some(Ok(WsMessage::Text(text))) = stream.next().await { - if text == "Pong" { - continue; + if let Some(Ok(WsMessage::Text(text))) = stream.next().await { + let data: SocketMessage = match serde_json::from_str(&text.to_string()) { + Ok(t) => t, + Err(_) => { + let _ = sink.close().await; + return; } + }; - if text == "Close" { - recv_sender.send("Close".to_string()).unwrap(); - break; - } + if data.method != SocketMethod::Headers && user.is_none() { + // we've sent something else before authenticating... that's not right + let _ = sink.close().await; + return; + } - let data: SocketMessage = match serde_json::from_str(&text.to_string()) { - Ok(t) => t, - Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; - } - }; + match data.method { + SocketMethod::Headers => { + let data: SocketHeaders = data.data(); - if data.method != SocketMethod::Headers && user.is_none() { - // we've sent something else before authenticating... that's not right - recv_sender.send("Close".to_string()).unwrap(); - break; - } - - match data.method { - SocketMethod::Headers => { - let data: SocketHeaders = data.data(); - - user = Some( - match dbc - .get_user_by_id(match data.user.parse::() { - Ok(c) => c, - Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; - } - }) - .await - { - Ok(ua) => ua, - Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; - } - }, - ); - - let channel = match dbc - .get_channel_by_id(match data.channel.parse::() { + user = Some( + match dbc + .get_user_by_id(match data.user.parse::() { Ok(c) => c, Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; + let _ = sink.close().await; + return; } }) .await - { - Ok(c) => c, - Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; - } - }; - - let user = user.as_ref().unwrap(); - - let membership = match dbc - .get_membership_by_owner_community(user.id, channel.id) - .await { Ok(ua) => ua, Err(_) => { - recv_sender.send("Close".to_string()).unwrap(); - break; + let _ = sink.close().await; + return; } - }; + }, + ); - if !channel.check_read(user.id, Some(membership.role)) { - recv_sender.send("Close".to_string()).unwrap(); - break; + let channel = match dbc + .get_channel_by_id(match data.channel.parse::() { + Ok(c) => c, + Err(_) => { + let _ = sink.close().await; + return; + } + }) + .await + { + Ok(c) => c, + Err(_) => { + let _ = sink.close().await; + return; } - } - _ => { - recv_sender.send("Close".to_string()).unwrap(); - break; + }; + + let user = user.as_ref().unwrap(); + + let membership = match dbc + .get_membership_by_owner_community(user.id, channel.id) + .await + { + Ok(ua) => ua, + Err(_) => { + let _ = sink.close().await; + return; + } + }; + + if !channel.check_read(user.id, Some(membership.role)) { + let _ = sink.close().await; + return; } } + _ => { + let _ = sink.close().await; + return; + } } + } else { + sink.close().await.unwrap(); + return; + } - drop(recv_sender); + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(WsMessage::Text(text))) = stream.next().await { + if text != "Close" { + continue; + } + + // yes, this is an "unclean" disconnection from the socket... + // i don't care, it works + drop(stream); + break; + } }); - // forward messages from redis to the mpsc - let send_task_sender = sender.clone(); - let mut send_task = tokio::spawn(async move { + let mut redis_task = tokio::spawn(async move { + // forward messages from redis to the mpsc let mut pubsub = con.as_pubsub(); pubsub.subscribe(channel_id).unwrap(); while let Ok(msg) = pubsub.get_message() { // payload is a stringified SocketMessage - if send_task_sender.send(msg.get_payload().unwrap()).is_err() { - break; + if sink + .send(WsMessage::Text(msg.get_payload::().unwrap().into())) + .await + .is_err() + { + return; } } - - drop(send_task_sender); }); - // ... - let close_sender = sender.clone(); tokio::select! { - _ = (&mut heartbeat_task) => { - let _ = close_sender.send("Close".to_string()); - forward_task.abort(); - recv_task.abort(); - send_task.abort(); - } - _ = (&mut forward_task) => { - send_task.abort(); - recv_task.abort(); - heartbeat_task.abort(); - } - _ = (&mut send_task) => { - let _ = close_sender.send("Close".to_string()); - forward_task.abort(); - recv_task.abort(); - heartbeat_task.abort(); - }, - _ = (&mut recv_task) => { - let _ = close_sender.send("Close".to_string()); - send_task.abort(); - forward_task.abort(); - heartbeat_task.abort(); - } - }; + _ = (&mut recv_task) => redis_task.abort(), + _ = (&mut redis_task) => recv_task.abort() + } - // kill - drop(sender); - drop(db); - - send_task.abort(); - recv_task.abort(); - forward_task.abort(); - heartbeat_task.abort(); - - let _ = close_sender.send("Close".to_string()); tracing::info!("socket terminate"); } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index a7e3769..8dcd51a 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -12,7 +12,7 @@ default = ["sqlite", "redis"] [dependencies] pathbufd = "0.1.4" serde = { version = "1.0.219", features = ["derive"] } -toml = "0.8.20" +toml = "0.8.21" tetratto-shared = { path = "../shared" } tetratto-l10n = { path = "../l10n" } serde_json = "1.0.140" diff --git a/crates/l10n/Cargo.toml b/crates/l10n/Cargo.toml index e4c52b4..3d60dd0 100644 --- a/crates/l10n/Cargo.toml +++ b/crates/l10n/Cargo.toml @@ -9,4 +9,4 @@ license.workspace = true [dependencies] pathbufd = "0.1.4" serde = { version = "1.0.219", features = ["derive"] } -toml = "0.8.20" +toml = "0.8.21"