use axum::{ extract::{ ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}, Path, }, response::{IntoResponse, Response}, Extension, Json, }; use axum_extra::extract::CookieJar; use tetratto_core::{ cache::Cache, model::{ auth::User, channels::Message, socket::{SocketMessage, SocketMethod}, ApiReturn, Error, }, DataManager, }; use crate::{get_user_from_token, routes::api::v1::CreateMessage, State}; use serde::Deserialize; use futures_util::{sink::SinkExt, stream::StreamExt}; #[derive(Deserialize)] pub struct SocketHeaders { pub channel: String, pub user: String, } /// Handle a subscription to the websocket. pub async fn subscription_handler( ws: WebSocketUpgrade, Extension(data): Extension, Path(id): Path, ) -> Response { let data = &(data.read().await); let data = data.0.clone(); ws.on_upgrade(|socket| async move { tokio::spawn(async move { handle_socket(socket, data, id).await; }); }) } pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) { let (mut sink, mut stream) = socket.split(); let mut user: Option = None; let mut con = db.2.clone().get_con().await; // handle incoming messages on socket let dbc = db.clone(); if let Some(Ok(WsMessage::Text(text))) = stream.next().await { let data: SocketMessage = match serde_json::from_str(&text.to_string()) { Ok(t) => t, Err(_) => { let _ = sink.close().await; return; } }; if data.method != SocketMethod::Headers && user.is_none() { // we've sent something else before authenticating... that's not right let _ = sink.close().await; return; } match data.method { SocketMethod::Headers => { let data: SocketHeaders = data.data(); user = Some( match dbc .get_user_by_id(match data.user.parse::() { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }) .await { Ok(ua) => ua, Err(_) => { let _ = sink.close().await; return; } }, ); let channel = match dbc .get_channel_by_id(match data.channel.parse::() { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }) .await { Ok(c) => c, Err(_) => { let _ = sink.close().await; return; } }; let user = user.as_ref().unwrap(); let membership = match dbc .get_membership_by_owner_community(user.id, channel.id) .await { Ok(ua) => ua, Err(_) => { let _ = sink.close().await; return; } }; if !channel.check_read(user.id, Some(membership.role)) { let _ = sink.close().await; return; } } _ => { let _ = sink.close().await; return; } } } else { sink.close().await.unwrap(); return; } let mut recv_task = tokio::spawn(async move { while let Some(Ok(WsMessage::Text(text))) = stream.next().await { if text != "Close" { continue; } // yes, this is an "unclean" disconnection from the socket... // i don't care, it works drop(stream); break; } }); let mut redis_task = tokio::spawn(async move { // forward messages from redis to the mpsc let mut pubsub = con.as_pubsub(); pubsub.subscribe(channel_id).unwrap(); while let Ok(msg) = pubsub.get_message() { // payload is a stringified SocketMessage if sink .send(WsMessage::Text(msg.get_payload::().unwrap().into())) .await .is_err() { return; } } }); tokio::select! { _ = (&mut recv_task) => redis_task.abort(), _ = (&mut redis_task) => recv_task.abort() } tracing::info!("socket terminate"); } pub async fn create_request( jar: CookieJar, Extension(data): Extension, Json(req): Json, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data .create_message(Message::new( match req.channel.parse::() { Ok(c) => c, Err(e) => return Json(Error::MiscError(e.to_string()).into()), }, user.id, req.content, )) .await { Ok(_) => Json(ApiReturn { ok: true, message: "Message created".to_string(), payload: (), }), Err(e) => Json(e.into()), } } pub async fn delete_request( jar: CookieJar, Extension(data): Extension, Path(id): Path, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data.delete_message(id, user).await { Ok(_) => Json(ApiReturn { ok: true, message: "Message deleted".to_string(), payload: (), }), Err(e) => Json(e.into()), } }