use axum::{ extract::{ ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}, Path, }, response::{IntoResponse, Response}, Extension, Json, }; use axum_extra::extract::CookieJar; use tetratto_core::{ cache::Cache, model::{ auth::User, channels::Message, socket::{SocketMessage, SocketMethod}, ApiReturn, Error, }, }; use std::{sync::mpsc, time::Duration}; use crate::{get_user_from_token, routes::api::v1::CreateMessage, State}; use serde::Deserialize; use futures_util::{sink::SinkExt, stream::StreamExt}; #[derive(Deserialize)] pub struct SocketHeaders { pub channel: String, pub user: String, } /// Handle a subscription to the websocket. pub async fn subscription_handler( ws: WebSocketUpgrade, Extension(data): Extension, Path(channel_id): Path, ) -> Response { ws.on_upgrade(move |socket| handle_socket(socket, data, channel_id)) } pub async fn handle_socket(socket: WebSocket, state: State, channel_id: usize) { let db = &(state.read().await).0; let db = db.clone(); let (mut sink, mut stream) = socket.split(); let (sender, receiver) = mpsc::channel::(); // forward messages from mpsc to the sink let mut forward_task = tokio::spawn(async move { while let Ok(message) = receiver.recv() { if message == "Close" { let _ = sink.close().await; drop(receiver); break; } if sink.send(message.into()).await.is_err() { break; } } }); // ping let ping_sender = sender.clone(); let mut heartbeat_task = tokio::spawn(async move { let mut heartbeat = tokio::time::interval(Duration::from_secs(30)); loop { heartbeat.tick().await; if ping_sender.send("Ping".to_string()).is_err() { // remote has abandoned us break; } } }); // ... let mut user: Option = None; let mut con = db.2.clone().get_con().await; // handle incoming messages on socket let dbc = db.clone(); let recv_sender = sender.clone(); let mut recv_task = tokio::spawn(async move { while let Some(Ok(WsMessage::Text(text))) = stream.next().await { if text == "Pong" { continue; } if text == "Close" { recv_sender.send("Close".to_string()).unwrap(); break; } let data: SocketMessage = match serde_json::from_str(&text.to_string()) { Ok(t) => t, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }; if data.method != SocketMethod::Headers && user.is_none() { // we've sent something else before authenticating... that's not right recv_sender.send("Close".to_string()).unwrap(); break; } match data.method { SocketMethod::Headers => { let data: SocketHeaders = data.data(); user = Some( match dbc .get_user_by_id(match data.user.parse::() { Ok(c) => c, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }) .await { Ok(ua) => ua, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }, ); let channel = match dbc .get_channel_by_id(match data.channel.parse::() { Ok(c) => c, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }) .await { Ok(c) => c, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }; let user = user.as_ref().unwrap(); let membership = match dbc .get_membership_by_owner_community(user.id, channel.id) .await { Ok(ua) => ua, Err(_) => { recv_sender.send("Close".to_string()).unwrap(); break; } }; if !channel.check_read(user.id, Some(membership.role)) { recv_sender.send("Close".to_string()).unwrap(); break; } } _ => { recv_sender.send("Close".to_string()).unwrap(); break; } } } }); // forward messages from redis to the mpsc let send_task_sender = sender.clone(); let mut send_task = tokio::spawn(async move { let mut pubsub = con.as_pubsub(); pubsub.subscribe(channel_id).unwrap(); loop { while let Ok(msg) = pubsub.get_message() { // payload is a stringified SocketMessage if send_task_sender.send(msg.get_payload().unwrap()).is_err() { break; } } } }); // ... let close_sender = sender.clone(); tokio::select! { _ = (&mut heartbeat_task) => { let _ = close_sender.send("Close".to_string()); forward_task.abort(); recv_task.abort(); send_task.abort(); } _ = (&mut forward_task) => { send_task.abort(); recv_task.abort(); heartbeat_task.abort(); } _ = (&mut send_task) => { let _ = close_sender.send("Close".to_string()); forward_task.abort(); recv_task.abort(); heartbeat_task.abort(); }, _ = (&mut recv_task) => { let _ = close_sender.send("Close".to_string()); send_task.abort(); forward_task.abort(); heartbeat_task.abort(); }, }; } pub async fn create_request( jar: CookieJar, Extension(data): Extension, Json(req): Json, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data .create_message(Message::new( match req.channel.parse::() { Ok(c) => c, Err(e) => return Json(Error::MiscError(e.to_string()).into()), }, user.id, req.content, )) .await { Ok(_) => Json(ApiReturn { ok: true, message: "Message created".to_string(), payload: (), }), Err(e) => Json(e.into()), } } pub async fn delete_request( jar: CookieJar, Extension(data): Extension, Path(id): Path, ) -> impl IntoResponse { let data = &(data.read().await).0; let user = match get_user_from_token!(jar, data) { Some(ua) => ua, None => return Json(Error::NotAllowed.into()), }; match data.delete_message(id, user).await { Ok(_) => Json(ApiReturn { ok: true, message: "Message deleted".to_string(), payload: (), }), Err(e) => Json(e.into()), } }