add: user socket streams

add: group messages by author in ui
TODO: group messages by author in ui as they come in from socket
TODO: notifications stream connection
This commit is contained in:
trisua 2025-05-01 16:43:58 -04:00
parent c549fdd274
commit 094dd5fdb5
8 changed files with 198 additions and 40 deletions

View file

@ -1,3 +1,5 @@
use std::time::Duration;
use crate::{
get_user_from_token,
model::{ApiReturn, Error},
@ -8,19 +10,28 @@ use crate::{
State,
};
use axum::{
Extension, Json,
extract::Path,
extract::{
ws::{Message as WsMessage, WebSocket},
Path, WebSocketUpgrade,
},
response::{IntoResponse, Redirect},
Extension, Json,
};
use axum_extra::extract::CookieJar;
use futures_util::{sink::SinkExt, stream::StreamExt};
use tetratto_core::{
cache::Cache,
model::{
auth::{Token, UserSettings},
permissions::FinePermission,
socket::{PacketType, SocketMessage, SocketMethod},
},
DataManager,
};
#[cfg(feature = "redis")]
use redis::Commands;
pub async fn redirect_from_id(
Extension(data): Extension<State>,
Path(id): Path<String>,
@ -410,3 +421,114 @@ pub async fn has_totp_enabled_request(
payload: Some(!user.totp.is_empty()),
})
}
/// Handle a subscription to the websocket.
#[cfg(feature = "redis")]
pub async fn subscription_handler(
jar: CookieJar,
ws: WebSocketUpgrade,
Extension(data): Extension<State>,
Path((user_id, id)): Path<(String, String)>,
) -> impl IntoResponse {
let data = &(data.read().await).0;
let user = match get_user_from_token!(jar, data) {
Some(ua) => ua,
None => return Err("Socket refused"),
};
if user.id.to_string() != user_id {
// TODO: maybe allow moderators to connect anyway
return Err("Socket refused (auth)");
}
let data = data.clone();
Ok(ws.on_upgrade(|socket| async move {
tokio::spawn(async move {
handle_socket(socket, data, user_id, id).await;
});
}))
}
#[cfg(feature = "redis")]
pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String, stream_id: String) {
let (mut sink, mut stream) = socket.split();
// get channel permissions
let channel = format!("{user_id}_{stream_id}");
// ...
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
if text != "Close" {
continue;
}
drop(stream);
break;
}
});
let dbc = db.clone();
let channel_c = channel.clone();
let mut redis_task = tokio::spawn(async move {
// forward messages from redis to the socket
let mut con = dbc.2.get_con().await;
let mut pubsub = con.as_pubsub();
pubsub.subscribe(channel_c).unwrap();
// listen for pubsub messages
while let Ok(msg) = pubsub.get_message() {
// payload is a stringified SocketMessage
let smsg = msg.get_payload::<String>().unwrap();
let packet: SocketMessage = serde_json::from_str(&smsg).unwrap();
if packet.method == SocketMethod::Forward(PacketType::Ping) {
// forward with custom message
if sink.send(WsMessage::Text("Ping".into())).await.is_err() {
drop(sink);
break;
}
} else if packet.method == SocketMethod::Message {
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
} else {
// forward to client
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
}
}
});
let db2c = db.2.clone();
let channel_c = channel.clone();
let heartbeat_task = tokio::spawn(async move {
let mut con = db2c.get_con().await;
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
loop {
con.publish::<String, String, ()>(
format!("{channel_c}_heartbeat"),
serde_json::to_string(&SocketMessage {
method: SocketMethod::Forward(PacketType::Ping),
data: "Ping".to_string(),
})
.unwrap(),
)
.unwrap();
heartbeat.tick().await;
}
});
tokio::select! {
_ = (&mut recv_task) => redis_task.abort(),
_ = (&mut redis_task) => recv_task.abort()
}
heartbeat_task.abort(); // kill
tracing::info!("socket terminate");
}