fix: socket connections/closing
you can no longer crash the entire server by connection to the socket 8 times
This commit is contained in:
parent
3a12c0ee6c
commit
c1c8cdbfcd
5 changed files with 147 additions and 164 deletions
55
Cargo.lock
generated
55
Cargo.lock
generated
|
@ -716,6 +716,28 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-channel",
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-queue",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.5.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam-deque"
|
name = "crossbeam-deque"
|
||||||
version = "0.8.6"
|
version = "0.8.6"
|
||||||
|
@ -735,6 +757,15 @@ dependencies = [
|
||||||
"crossbeam-utils",
|
"crossbeam-utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-queue"
|
||||||
|
version = "0.3.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam-utils"
|
name = "crossbeam-utils"
|
||||||
version = "0.8.21"
|
version = "0.8.21"
|
||||||
|
@ -3240,6 +3271,7 @@ dependencies = [
|
||||||
"axum-extra",
|
"axum-extra",
|
||||||
"cf-turnstile",
|
"cf-turnstile",
|
||||||
"contrasted",
|
"contrasted",
|
||||||
|
"crossbeam",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"image",
|
"image",
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
|
@ -3524,9 +3556,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.8.20"
|
version = "0.8.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
|
checksum = "900f6c86a685850b1bc9f6223b20125115ee3f31e01207d81655bbcc0aea9231"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_spanned",
|
"serde_spanned",
|
||||||
|
@ -3536,26 +3568,33 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml_datetime"
|
name = "toml_datetime"
|
||||||
version = "0.6.8"
|
version = "0.6.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
|
checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml_edit"
|
name = "toml_edit"
|
||||||
version = "0.22.24"
|
version = "0.22.25"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
|
checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"indexmap",
|
"indexmap",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_spanned",
|
"serde_spanned",
|
||||||
"toml_datetime",
|
"toml_datetime",
|
||||||
|
"toml_write",
|
||||||
"winnow",
|
"winnow",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toml_write"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "28391a4201ba7eb1984cfeb6862c0b3ea2cfe23332298967c749dddc0d6cd976"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "totp-rs"
|
name = "totp-rs"
|
||||||
version = "5.7.0"
|
version = "5.7.0"
|
||||||
|
@ -4296,9 +4335,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winnow"
|
name = "winnow"
|
||||||
version = "0.7.4"
|
version = "0.7.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36"
|
checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
|
@ -34,3 +34,4 @@ mime_guess = "2.0.5"
|
||||||
cf-turnstile = "0.2.0"
|
cf-turnstile = "0.2.0"
|
||||||
contrasted = "0.1.2"
|
contrasted = "0.1.2"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
|
crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] }
|
||||||
|
|
|
@ -15,8 +15,8 @@ use tetratto_core::{
|
||||||
socket::{SocketMessage, SocketMethod},
|
socket::{SocketMessage, SocketMethod},
|
||||||
ApiReturn, Error,
|
ApiReturn, Error,
|
||||||
},
|
},
|
||||||
|
DataManager,
|
||||||
};
|
};
|
||||||
use std::{sync::mpsc, time::Duration};
|
|
||||||
use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
|
use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||||
|
@ -31,77 +31,39 @@ pub struct SocketHeaders {
|
||||||
pub async fn subscription_handler(
|
pub async fn subscription_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
Extension(data): Extension<State>,
|
Extension(data): Extension<State>,
|
||||||
Path(channel_id): Path<usize>,
|
Path(id): Path<String>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
ws.on_upgrade(move |socket| handle_socket(socket, data, channel_id))
|
let data = &(data.read().await);
|
||||||
|
let data = data.0.clone();
|
||||||
|
|
||||||
|
ws.on_upgrade(|socket| async move {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
handle_socket(socket, data, id).await;
|
||||||
|
});
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
|
pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) {
|
||||||
let db = &(state.read().await).0;
|
|
||||||
let db = db.clone();
|
|
||||||
|
|
||||||
let (mut sink, mut stream) = socket.split();
|
let (mut sink, mut stream) = socket.split();
|
||||||
let (sender, receiver) = mpsc::channel::<String>();
|
|
||||||
|
|
||||||
// forward messages from mpsc to the sink
|
|
||||||
let mut forward_task = tokio::spawn(async move {
|
|
||||||
while let Ok(message) = receiver.recv() {
|
|
||||||
if message == "Close" {
|
|
||||||
let _ = sink.close().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if sink.send(message.into()).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(receiver);
|
|
||||||
drop(sink);
|
|
||||||
});
|
|
||||||
|
|
||||||
// ping
|
|
||||||
let ping_sender = sender.clone();
|
|
||||||
let mut heartbeat_task = tokio::spawn(async move {
|
|
||||||
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
|
|
||||||
|
|
||||||
while ping_sender.send("Ping".to_string()).is_ok() {
|
|
||||||
heartbeat.tick().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(ping_sender);
|
|
||||||
});
|
|
||||||
|
|
||||||
// ...
|
|
||||||
let mut user: Option<User> = None;
|
let mut user: Option<User> = None;
|
||||||
let mut con = db.2.clone().get_con().await;
|
let mut con = db.2.clone().get_con().await;
|
||||||
|
|
||||||
// handle incoming messages on socket
|
// handle incoming messages on socket
|
||||||
let dbc = db.clone();
|
let dbc = db.clone();
|
||||||
let recv_sender = sender.clone();
|
if let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||||
let mut recv_task = tokio::spawn(async move {
|
|
||||||
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
|
||||||
if text == "Pong" {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if text == "Close" {
|
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let data: SocketMessage = match serde_json::from_str(&text.to_string()) {
|
let data: SocketMessage = match serde_json::from_str(&text.to_string()) {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if data.method != SocketMethod::Headers && user.is_none() {
|
if data.method != SocketMethod::Headers && user.is_none() {
|
||||||
// we've sent something else before authenticating... that's not right
|
// we've sent something else before authenticating... that's not right
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
match data.method {
|
match data.method {
|
||||||
|
@ -113,16 +75,16 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
|
||||||
.get_user_by_id(match data.user.parse::<usize>() {
|
.get_user_by_id(match data.user.parse::<usize>() {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(ua) => ua,
|
Ok(ua) => ua,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
@ -131,16 +93,16 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
|
||||||
.get_channel_by_id(match data.channel.parse::<usize>() {
|
.get_channel_by_id(match data.channel.parse::<usize>() {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -152,80 +114,61 @@ pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
|
||||||
{
|
{
|
||||||
Ok(ua) => ua,
|
Ok(ua) => ua,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if !channel.check_read(user.id, Some(membership.role)) {
|
if !channel.check_read(user.id, Some(membership.role)) {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
break;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
recv_sender.send("Close".to_string()).unwrap();
|
let _ = sink.close().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sink.close().await.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut recv_task = tokio::spawn(async move {
|
||||||
|
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||||
|
if text != "Close" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// yes, this is an "unclean" disconnection from the socket...
|
||||||
|
// i don't care, it works
|
||||||
|
drop(stream);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drop(recv_sender);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let mut redis_task = tokio::spawn(async move {
|
||||||
// forward messages from redis to the mpsc
|
// forward messages from redis to the mpsc
|
||||||
let send_task_sender = sender.clone();
|
|
||||||
let mut send_task = tokio::spawn(async move {
|
|
||||||
let mut pubsub = con.as_pubsub();
|
let mut pubsub = con.as_pubsub();
|
||||||
pubsub.subscribe(channel_id).unwrap();
|
pubsub.subscribe(channel_id).unwrap();
|
||||||
|
|
||||||
while let Ok(msg) = pubsub.get_message() {
|
while let Ok(msg) = pubsub.get_message() {
|
||||||
// payload is a stringified SocketMessage
|
// payload is a stringified SocketMessage
|
||||||
if send_task_sender.send(msg.get_payload().unwrap()).is_err() {
|
if sink
|
||||||
break;
|
.send(WsMessage::Text(msg.get_payload::<String>().unwrap().into()))
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drop(send_task_sender);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// ...
|
|
||||||
let close_sender = sender.clone();
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = (&mut heartbeat_task) => {
|
_ = (&mut recv_task) => redis_task.abort(),
|
||||||
let _ = close_sender.send("Close".to_string());
|
_ = (&mut redis_task) => recv_task.abort()
|
||||||
forward_task.abort();
|
|
||||||
recv_task.abort();
|
|
||||||
send_task.abort();
|
|
||||||
}
|
}
|
||||||
_ = (&mut forward_task) => {
|
|
||||||
send_task.abort();
|
|
||||||
recv_task.abort();
|
|
||||||
heartbeat_task.abort();
|
|
||||||
}
|
|
||||||
_ = (&mut send_task) => {
|
|
||||||
let _ = close_sender.send("Close".to_string());
|
|
||||||
forward_task.abort();
|
|
||||||
recv_task.abort();
|
|
||||||
heartbeat_task.abort();
|
|
||||||
},
|
|
||||||
_ = (&mut recv_task) => {
|
|
||||||
let _ = close_sender.send("Close".to_string());
|
|
||||||
send_task.abort();
|
|
||||||
forward_task.abort();
|
|
||||||
heartbeat_task.abort();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// kill
|
|
||||||
drop(sender);
|
|
||||||
drop(db);
|
|
||||||
|
|
||||||
send_task.abort();
|
|
||||||
recv_task.abort();
|
|
||||||
forward_task.abort();
|
|
||||||
heartbeat_task.abort();
|
|
||||||
|
|
||||||
let _ = close_sender.send("Close".to_string());
|
|
||||||
tracing::info!("socket terminate");
|
tracing::info!("socket terminate");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ default = ["sqlite", "redis"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
pathbufd = "0.1.4"
|
pathbufd = "0.1.4"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
toml = "0.8.20"
|
toml = "0.8.21"
|
||||||
tetratto-shared = { path = "../shared" }
|
tetratto-shared = { path = "../shared" }
|
||||||
tetratto-l10n = { path = "../l10n" }
|
tetratto-l10n = { path = "../l10n" }
|
||||||
serde_json = "1.0.140"
|
serde_json = "1.0.140"
|
||||||
|
|
|
@ -9,4 +9,4 @@ license.workspace = true
|
||||||
[dependencies]
|
[dependencies]
|
||||||
pathbufd = "0.1.4"
|
pathbufd = "0.1.4"
|
||||||
serde = { version = "1.0.219", features = ["derive"] }
|
serde = { version = "1.0.219", features = ["derive"] }
|
||||||
toml = "0.8.20"
|
toml = "0.8.21"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue