fix: socket connections/closing
you can no longer crash the entire server by connection to the socket 8 times
This commit is contained in:
parent
3a12c0ee6c
commit
c1c8cdbfcd
5 changed files with 147 additions and 164 deletions
|
@ -15,8 +15,8 @@ use tetratto_core::{
|
|||
socket::{SocketMessage, SocketMethod},
|
||||
ApiReturn, Error,
|
||||
},
|
||||
DataManager,
|
||||
};
|
||||
use std::{sync::mpsc, time::Duration};
|
||||
use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
|
||||
use serde::Deserialize;
|
||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||
|
@ -31,201 +31,144 @@ pub struct SocketHeaders {
|
|||
pub async fn subscription_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(data): Extension<State>,
|
||||
Path(channel_id): Path<usize>,
|
||||
Path(id): Path<String>,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, data, channel_id))
|
||||
let data = &(data.read().await);
|
||||
let data = data.0.clone();
|
||||
|
||||
ws.on_upgrade(|socket| async move {
|
||||
tokio::spawn(async move {
|
||||
handle_socket(socket, data, id).await;
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) {
|
||||
let db = &(state.read().await).0;
|
||||
let db = db.clone();
|
||||
|
||||
pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) {
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
let (sender, receiver) = mpsc::channel::<String>();
|
||||
|
||||
// forward messages from mpsc to the sink
|
||||
let mut forward_task = tokio::spawn(async move {
|
||||
while let Ok(message) = receiver.recv() {
|
||||
if message == "Close" {
|
||||
let _ = sink.close().await;
|
||||
break;
|
||||
}
|
||||
|
||||
if sink.send(message.into()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
drop(receiver);
|
||||
drop(sink);
|
||||
});
|
||||
|
||||
// ping
|
||||
let ping_sender = sender.clone();
|
||||
let mut heartbeat_task = tokio::spawn(async move {
|
||||
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
|
||||
|
||||
while ping_sender.send("Ping".to_string()).is_ok() {
|
||||
heartbeat.tick().await;
|
||||
}
|
||||
|
||||
drop(ping_sender);
|
||||
});
|
||||
|
||||
// ...
|
||||
let mut user: Option<User> = None;
|
||||
let mut con = db.2.clone().get_con().await;
|
||||
|
||||
// handle incoming messages on socket
|
||||
let dbc = db.clone();
|
||||
let recv_sender = sender.clone();
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||
if text == "Pong" {
|
||||
continue;
|
||||
if let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||
let data: SocketMessage = match serde_json::from_str(&text.to_string()) {
|
||||
Ok(t) => t,
|
||||
Err(_) => {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if text == "Close" {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
if data.method != SocketMethod::Headers && user.is_none() {
|
||||
// we've sent something else before authenticating... that's not right
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
|
||||
let data: SocketMessage = match serde_json::from_str(&text.to_string()) {
|
||||
Ok(t) => t,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
};
|
||||
match data.method {
|
||||
SocketMethod::Headers => {
|
||||
let data: SocketHeaders = data.data();
|
||||
|
||||
if data.method != SocketMethod::Headers && user.is_none() {
|
||||
// we've sent something else before authenticating... that's not right
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
|
||||
match data.method {
|
||||
SocketMethod::Headers => {
|
||||
let data: SocketHeaders = data.data();
|
||||
|
||||
user = Some(
|
||||
match dbc
|
||||
.get_user_by_id(match data.user.parse::<usize>() {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(ua) => ua,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let channel = match dbc
|
||||
.get_channel_by_id(match data.channel.parse::<usize>() {
|
||||
user = Some(
|
||||
match dbc
|
||||
.get_user_by_id(match data.user.parse::<usize>() {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let user = user.as_ref().unwrap();
|
||||
|
||||
let membership = match dbc
|
||||
.get_membership_by_owner_community(user.id, channel.id)
|
||||
.await
|
||||
{
|
||||
Ok(ua) => ua,
|
||||
Err(_) => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
if !channel.check_read(user.id, Some(membership.role)) {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
let channel = match dbc
|
||||
.get_channel_by_id(match data.channel.parse::<usize>() {
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
recv_sender.send("Close".to_string()).unwrap();
|
||||
break;
|
||||
};
|
||||
|
||||
let user = user.as_ref().unwrap();
|
||||
|
||||
let membership = match dbc
|
||||
.get_membership_by_owner_community(user.id, channel.id)
|
||||
.await
|
||||
{
|
||||
Ok(ua) => ua,
|
||||
Err(_) => {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !channel.check_read(user.id, Some(membership.role)) {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let _ = sink.close().await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sink.close().await.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
drop(recv_sender);
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||
if text != "Close" {
|
||||
continue;
|
||||
}
|
||||
|
||||
// yes, this is an "unclean" disconnection from the socket...
|
||||
// i don't care, it works
|
||||
drop(stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// forward messages from redis to the mpsc
|
||||
let send_task_sender = sender.clone();
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
let mut redis_task = tokio::spawn(async move {
|
||||
// forward messages from redis to the mpsc
|
||||
let mut pubsub = con.as_pubsub();
|
||||
pubsub.subscribe(channel_id).unwrap();
|
||||
|
||||
while let Ok(msg) = pubsub.get_message() {
|
||||
// payload is a stringified SocketMessage
|
||||
if send_task_sender.send(msg.get_payload().unwrap()).is_err() {
|
||||
break;
|
||||
if sink
|
||||
.send(WsMessage::Text(msg.get_payload::<String>().unwrap().into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
drop(send_task_sender);
|
||||
});
|
||||
|
||||
// ...
|
||||
let close_sender = sender.clone();
|
||||
tokio::select! {
|
||||
_ = (&mut heartbeat_task) => {
|
||||
let _ = close_sender.send("Close".to_string());
|
||||
forward_task.abort();
|
||||
recv_task.abort();
|
||||
send_task.abort();
|
||||
}
|
||||
_ = (&mut forward_task) => {
|
||||
send_task.abort();
|
||||
recv_task.abort();
|
||||
heartbeat_task.abort();
|
||||
}
|
||||
_ = (&mut send_task) => {
|
||||
let _ = close_sender.send("Close".to_string());
|
||||
forward_task.abort();
|
||||
recv_task.abort();
|
||||
heartbeat_task.abort();
|
||||
},
|
||||
_ = (&mut recv_task) => {
|
||||
let _ = close_sender.send("Close".to_string());
|
||||
send_task.abort();
|
||||
forward_task.abort();
|
||||
heartbeat_task.abort();
|
||||
}
|
||||
};
|
||||
_ = (&mut recv_task) => redis_task.abort(),
|
||||
_ = (&mut redis_task) => recv_task.abort()
|
||||
}
|
||||
|
||||
// kill
|
||||
drop(sender);
|
||||
drop(db);
|
||||
|
||||
send_task.abort();
|
||||
recv_task.abort();
|
||||
forward_task.abort();
|
||||
heartbeat_task.abort();
|
||||
|
||||
let _ = close_sender.send("Close".to_string());
|
||||
tracing::info!("socket terminate");
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue