add: connect to socket by community

direct messages/groups connect by channel id, everything else should connect by channel with the "is_channel" header set to true
This commit is contained in:
trisua 2025-04-29 16:53:34 -04:00
parent c1c8cdbfcd
commit 0304461389
20 changed files with 241 additions and 160 deletions

View file

@ -1,3 +1,5 @@
use std::{collections::HashMap, time::Duration};
use redis::Commands;
use axum::{
extract::{
ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
@ -12,7 +14,7 @@ use tetratto_core::{
model::{
auth::User,
channels::Message,
socket::{SocketMessage, SocketMethod},
socket::{PacketType, SocketMessage, SocketMethod},
ApiReturn, Error,
},
DataManager,
@ -21,10 +23,10 @@ use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
use serde::Deserialize;
use futures_util::{sink::SinkExt, stream::StreamExt};
#[derive(Deserialize)]
#[derive(Clone, Deserialize)]
pub struct SocketHeaders {
pub channel: String,
pub user: String,
pub is_channel: bool,
}
/// Handle a subscription to the websocket.
@ -43,11 +45,11 @@ pub async fn subscription_handler(
})
}
pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) {
pub async fn handle_socket(socket: WebSocket, db: DataManager, community_id: String) {
let (mut sink, mut stream) = socket.split();
let mut user: Option<User> = None;
let mut con = db.2.clone().get_con().await;
let mut headers: Option<SocketHeaders> = None;
// handle incoming messages on socket
let dbc = db.clone();
@ -60,7 +62,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
}
};
if data.method != SocketMethod::Headers && user.is_none() {
if data.method != SocketMethod::Headers && user.is_none() && headers.is_none() {
// we've sent something else before authenticating... that's not right
let _ = sink.close().await;
return;
@ -70,6 +72,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
SocketMethod::Headers => {
let data: SocketHeaders = data.data();
headers = Some(data.clone());
user = Some(
match dbc
.get_user_by_id(match data.user.parse::<usize>() {
@ -89,39 +92,42 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
},
);
let channel = match dbc
.get_channel_by_id(match data.channel.parse::<usize>() {
if data.is_channel {
// verify permissions for single channel
let channel = match dbc
.get_channel_by_id(match community_id.parse::<usize>() {
Ok(c) => c,
Err(_) => {
let _ = sink.close().await;
return;
}
})
.await
{
Ok(c) => c,
Err(_) => {
let _ = sink.close().await;
return;
}
})
.await
{
Ok(c) => c,
Err(_) => {
};
let user = user.as_ref().unwrap();
let membership = match dbc
.get_membership_by_owner_community(user.id, channel.id)
.await
{
Ok(ua) => ua,
Err(_) => {
let _ = sink.close().await;
return;
}
};
if !channel.check_read(user.id, Some(membership.role)) {
let _ = sink.close().await;
return;
}
};
let user = user.as_ref().unwrap();
let membership = match dbc
.get_membership_by_owner_community(user.id, channel.id)
.await
{
Ok(ua) => ua,
Err(_) => {
let _ = sink.close().await;
return;
}
};
if !channel.check_read(user.id, Some(membership.role)) {
let _ = sink.close().await;
return;
}
}
_ => {
@ -134,6 +140,37 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
return;
}
// get channel permissions
let user = user.unwrap();
let headers = headers.unwrap();
let mut channel_read_statuses: HashMap<usize, bool> = HashMap::new();
if !headers.is_channel {
// check permissions for every channel in community
let community_id = match community_id.parse::<usize>() {
Ok(c) => c,
Err(_) => return,
};
let membership = match dbc
.get_membership_by_owner_community(user.id, community_id)
.await
{
Ok(ua) => ua,
Err(_) => {
return;
}
};
for channel in dbc.get_channels_by_community(community_id).await.unwrap() {
channel_read_statuses.insert(
channel.id,
channel.check_read(user.id, Some(membership.role)),
);
}
}
// ...
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
if text != "Close" {
@ -147,28 +184,86 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
}
});
let dbc = db.clone();
let mut redis_task = tokio::spawn(async move {
// forward messages from redis to the mpsc
let mut con = dbc.2.get_con().await;
let mut pubsub = con.as_pubsub();
pubsub.subscribe(channel_id).unwrap();
pubsub.subscribe(user.id).unwrap();
pubsub.subscribe(community_id.clone()).unwrap();
// listen for pubsub messages
while let Ok(msg) = pubsub.get_message() {
// payload is a stringified SocketMessage
if sink
.send(WsMessage::Text(msg.get_payload::<String>().unwrap().into()))
.await
.is_err()
{
return;
let smsg = msg.get_payload::<String>().unwrap();
let packet: SocketMessage = serde_json::from_str(&smsg).unwrap();
if packet.method == SocketMethod::Forward(PacketType::Ping) {
// forward with custom message
if sink.send(WsMessage::Text("Ping".into())).await.is_err() {
drop(sink);
break;
}
} else if packet.method == SocketMethod::Message {
// check perms and then forward
let d: (String, Message) = packet.data();
if let Some(cs) = channel_read_statuses.get(&d.1.channel) {
if !cs {
continue;
}
} else {
if !headers.is_channel {
// since we didn't select by just a channel, there HAS to be
// an entry for the channel for us to check this message
continue;
// we don't need to check messages when we're subscribed to
// a channel, since that is checked on headers submission when
// we subscribe to a channel
}
}
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
} else {
// forward to client
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
}
}
});
let db2c = db.2.clone();
let heartbeat_task = tokio::spawn(async move {
let mut con = db2c.get_con().await;
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
loop {
con.publish::<usize, String, ()>(
user.id,
serde_json::to_string(&SocketMessage {
method: SocketMethod::Forward(PacketType::Ping),
data: "Ping".to_string(),
})
.unwrap(),
)
.unwrap();
heartbeat.tick().await;
}
});
tokio::select! {
_ = (&mut recv_task) => redis_task.abort(),
_ = (&mut redis_task) => recv_task.abort()
}
heartbeat_task.abort(); // kill
tracing::info!("socket terminate");
}