use std::{collections::HashMap, time::Duration}; use redis::Commands; use axum::{ extract::{ ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}, Path, }, response::{IntoResponse, Response}, Extension, Json, }; use axum_extra::extract::CookieJar; use tetratto_core::{ cache::Cache, model::{ auth::User, channels::Message, socket::{PacketType, SocketMessage, SocketMethod}, ApiReturn, Error, }, DataManager, }; use crate::{get_user_from_token, routes::api::v1::CreateMessage, State}; use serde::Deserialize; use futures_util::{sink::SinkExt, stream::StreamExt}; #[derive(Clone, Deserialize)] pub struct SocketHeaders { pub user: String, pub is_channel: bool, } /// Handle a subscription to the websocket. pub async fn subscription_handler( ws: WebSocketUpgrade, Extension(data): Extension, Path(id): Path, ) -> Response { let data = &(data.read().await); let data = data.0.clone(); ws.on_upgrade(|socket| async move { tokio::spawn(async move { handle_socket(socket, data, id).await; }); }) } pub async fn handle_socket(socket: WebSocket, db: DataManager, community_id: String) { let (mut sink, mut stream) = socket.split(); let mut user: Option = None; let mut headers: Option = None; // handle incoming messages on socket let dbc = db.clone(); if let Some(Ok(WsMessage::Text(text))) = stream.next().await { let data: SocketMessage = match serde_json::from_str(&text.to_string()) { Ok(t) => t, Err(_) => { let _ = sink.close().await; return; } }; if data.method != SocketMethod::Headers && user.is_none() && headers.is_none() { // we've sent something else before authenticating... that's not right let _ = sink.close().await; return; } match data.method { SocketMethod::Headers => { let data: SocketHeaders = data.data(); headers = Some(data.clone()); user = Some( match dbc .get_user_by_id(match data.user.parse::() { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }) .await { Ok(ua) => ua, Err(_) => { let _ = sink.close().await; return; } }, ); if data.is_channel { // verify permissions for single channel let channel = match dbc .get_channel_by_id(match community_id.parse::() { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }) .await { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }; let user = user.as_ref().unwrap(); let membership = match dbc .get_membership_by_owner_community(user.id, channel.id) .await { Ok(ua) => ua, Err(_) => { let _ = sink.close().await; return; } }; if !channel.check_read(user.id, Some(membership.role)) { let _ = sink.close().await; return; } } } _ => { let _ = sink.close().await; return; } } } else { sink.close().await.unwrap(); return; } // get channel permissions let user = user.unwrap(); let headers = headers.unwrap(); let mut channel_read_statuses: HashMap = HashMap::new(); if !headers.is_channel { // check permissions for every channel in community let community_id = match community_id.parse::() { Ok(c) => c, Err(_) => return, }; let membership = match dbc .get_membership_by_owner_community(user.id, community_id) .await { Ok(ua) => ua, Err(_) => { return; } }; for channel in dbc.get_channels_by_community(community_id).await.unwrap() { channel_read_statuses.insert( channel.id, channel.check_read(user.id, Some(membership.role)), ); } } // ... let mut recv_task = tokio::spawn(async move { while let Some(Ok(WsMessage::Text(text))) = stream.next().await { if text != "Close" { continue; } // yes, this is an "unclean" disconnection from the socket... // i don't care, it works drop(stream); break; } }); let dbc = db.clone(); let mut redis_task = tokio::spawn(async move { // forward messages from redis to the mpsc let mut con = dbc.2.get_con().await; let mut pubsub = con.as_pubsub(); pubsub.subscribe(user.id).unwrap(); pubsub.subscribe(community_id.clone()).unwrap(); // listen for pubsub messages while let Ok(msg) = pubsub.get_message() { // payload is a stringified SocketMessage let smsg = msg.get_payload::().unwrap(); let packet: SocketMessage = serde_json::from_str(&smsg).unwrap(); if packet.method == SocketMethod::Forward(PacketType::Ping) { // forward with custom message if sink.send(WsMessage::Text("Ping".into())).await.is_err() { drop(sink); break; } } else if packet.method == SocketMethod::Message { // check perms and then forward let d: (String, Message) = packet.data(); if let Some(cs) = channel_read_statuses.get(&d.1.channel) { if !cs { continue; } } else { if !headers.is_channel { // since we didn't select by just a channel, there HAS to be // an entry for the channel for us to check this message continue; // we don't need to check messages when we're subscribed to // a channel, since that is checked on headers submission when // we subscribe to a channel } } if sink.send(WsMessage::Text(smsg.into())).await.is_err() { drop(sink); break; } } else { // forward to client if sink.send(WsMessage::Text(smsg.into())).await.is_err() { drop(sink); break; } } } }); let db2c = db.2.clone(); let heartbeat_task = tokio::spawn(async move { let mut con = db2c.get_con().await; let mut heartbeat = tokio::time::interval(Duration::from_secs(10)); loop { con.publish::( user.id, serde_json::to_string(&SocketMessage { method: SocketMethod::Forward(PacketType::Ping), data: "Ping".to_string(), }) .unwrap(), ) .unwrap(); heartbeat.tick().await; } }); tokio::select! { _ = (&mut recv_task) => redis_task.abort(), _ = (&mut redis_task) => recv_task.abort() } heartbeat_task.abort(); // kill tracing::info!("socket terminate"); } pub async fn create_request( jar: CookieJar, Extension(data): Extension, Json(req): Json, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data .create_message(Message::new( match req.channel.parse::() { Ok(c) => c, Err(e) => return Json(Error::MiscError(e.to_string()).into()), }, user.id, req.content, )) .await { Ok(_) => Json(ApiReturn { ok: true, message: "Message created".to_string(), payload: (), }), Err(e) => Json(e.into()), } } pub async fn delete_request( jar: CookieJar, Extension(data): Extension, Path(id): Path, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data.delete_message(id, user).await { Ok(_) => Json(ApiReturn { ok: true, message: "Message deleted".to_string(), payload: (), }), Err(e) => Json(e.into()), } }