add: user socket streams
add: group messages by author in ui TODO: group messages by author in ui as they come in from socket TODO: notifications stream connection
This commit is contained in:
parent
c549fdd274
commit
094dd5fdb5
8 changed files with 198 additions and 40 deletions
|
@ -1,3 +1,5 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use crate::{
|
||||
get_user_from_token,
|
||||
model::{ApiReturn, Error},
|
||||
|
@ -8,19 +10,28 @@ use crate::{
|
|||
State,
|
||||
};
|
||||
use axum::{
|
||||
Extension, Json,
|
||||
extract::Path,
|
||||
extract::{
|
||||
ws::{Message as WsMessage, WebSocket},
|
||||
Path, WebSocketUpgrade,
|
||||
},
|
||||
response::{IntoResponse, Redirect},
|
||||
Extension, Json,
|
||||
};
|
||||
use axum_extra::extract::CookieJar;
|
||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||
use tetratto_core::{
|
||||
cache::Cache,
|
||||
model::{
|
||||
auth::{Token, UserSettings},
|
||||
permissions::FinePermission,
|
||||
socket::{PacketType, SocketMessage, SocketMethod},
|
||||
},
|
||||
DataManager,
|
||||
};
|
||||
|
||||
#[cfg(feature = "redis")]
|
||||
use redis::Commands;
|
||||
|
||||
pub async fn redirect_from_id(
|
||||
Extension(data): Extension<State>,
|
||||
Path(id): Path<String>,
|
||||
|
@ -410,3 +421,114 @@ pub async fn has_totp_enabled_request(
|
|||
payload: Some(!user.totp.is_empty()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle a subscription to the websocket.
|
||||
#[cfg(feature = "redis")]
|
||||
pub async fn subscription_handler(
|
||||
jar: CookieJar,
|
||||
ws: WebSocketUpgrade,
|
||||
Extension(data): Extension<State>,
|
||||
Path((user_id, id)): Path<(String, String)>,
|
||||
) -> impl IntoResponse {
|
||||
let data = &(data.read().await).0;
|
||||
let user = match get_user_from_token!(jar, data) {
|
||||
Some(ua) => ua,
|
||||
None => return Err("Socket refused"),
|
||||
};
|
||||
|
||||
if user.id.to_string() != user_id {
|
||||
// TODO: maybe allow moderators to connect anyway
|
||||
return Err("Socket refused (auth)");
|
||||
}
|
||||
|
||||
let data = data.clone();
|
||||
Ok(ws.on_upgrade(|socket| async move {
|
||||
tokio::spawn(async move {
|
||||
handle_socket(socket, data, user_id, id).await;
|
||||
});
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(feature = "redis")]
|
||||
pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String, stream_id: String) {
|
||||
let (mut sink, mut stream) = socket.split();
|
||||
|
||||
// get channel permissions
|
||||
let channel = format!("{user_id}_{stream_id}");
|
||||
|
||||
// ...
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
|
||||
if text != "Close" {
|
||||
continue;
|
||||
}
|
||||
|
||||
drop(stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
let dbc = db.clone();
|
||||
let channel_c = channel.clone();
|
||||
let mut redis_task = tokio::spawn(async move {
|
||||
// forward messages from redis to the socket
|
||||
let mut con = dbc.2.get_con().await;
|
||||
let mut pubsub = con.as_pubsub();
|
||||
pubsub.subscribe(channel_c).unwrap();
|
||||
|
||||
// listen for pubsub messages
|
||||
while let Ok(msg) = pubsub.get_message() {
|
||||
// payload is a stringified SocketMessage
|
||||
let smsg = msg.get_payload::<String>().unwrap();
|
||||
let packet: SocketMessage = serde_json::from_str(&smsg).unwrap();
|
||||
|
||||
if packet.method == SocketMethod::Forward(PacketType::Ping) {
|
||||
// forward with custom message
|
||||
if sink.send(WsMessage::Text("Ping".into())).await.is_err() {
|
||||
drop(sink);
|
||||
break;
|
||||
}
|
||||
} else if packet.method == SocketMethod::Message {
|
||||
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
|
||||
drop(sink);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// forward to client
|
||||
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
|
||||
drop(sink);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let db2c = db.2.clone();
|
||||
let channel_c = channel.clone();
|
||||
let heartbeat_task = tokio::spawn(async move {
|
||||
let mut con = db2c.get_con().await;
|
||||
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
|
||||
|
||||
loop {
|
||||
con.publish::<String, String, ()>(
|
||||
format!("{channel_c}_heartbeat"),
|
||||
serde_json::to_string(&SocketMessage {
|
||||
method: SocketMethod::Forward(PacketType::Ping),
|
||||
data: "Ping".to_string(),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
heartbeat.tick().await;
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = (&mut recv_task) => redis_task.abort(),
|
||||
_ = (&mut redis_task) => recv_task.abort()
|
||||
}
|
||||
|
||||
heartbeat_task.abort(); // kill
|
||||
tracing::info!("socket terminate");
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue