fix: use redis aio

This commit is contained in:
trisua 2025-05-02 22:59:44 -04:00
parent 0df2cd7b56
commit ecde5d3d46
6 changed files with 240 additions and 30 deletions

View file

@ -443,15 +443,15 @@ pub async fn subscription_handler(
let data = data.clone();
Ok(ws.on_upgrade(|socket| async move {
tokio::spawn(async move {
handle_socket(socket, data, user_id, id).await;
});
handle_socket(socket, data, user_id, id).await;
}))
}
#[cfg(feature = "redis")]
pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String, stream_id: String) {
let (mut sink, mut stream) = socket.split();
let socket_id = tetratto_shared::hash::salt();
db.2.incr("atto.active_connections:users".to_string()).await;
// get channel permissions
let channel = format!("{user_id}/{stream_id}");
@ -463,21 +463,25 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String,
continue;
}
drop(stream);
break;
}
});
let heartbeat_uri = format!("{channel}/{socket_id}");
let dbc = db.clone();
let channel_c = channel.clone();
let heartbeat_c = heartbeat_uri.clone();
let mut redis_task = tokio::spawn(async move {
// forward messages from redis to the socket
let mut con = dbc.2.get_con().await;
let mut pubsub = con.as_pubsub();
pubsub.subscribe(channel_c).unwrap();
let mut pubsub = dbc.2.client.get_async_pubsub().await.unwrap();
pubsub.subscribe(channel_c).await.unwrap();
pubsub.subscribe(heartbeat_c).await.unwrap();
// listen for pubsub messages
while let Ok(msg) = pubsub.get_message() {
let mut pubsub = pubsub.into_on_message();
while let Some(msg) = pubsub.next().await {
// payload is a stringified SocketMessage
let smsg = msg.get_payload::<String>().unwrap();
let packet: SocketMessage = serde_json::from_str(&smsg).unwrap();
@ -485,18 +489,15 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String,
if packet.method == SocketMethod::Forward(PacketType::Ping) {
// forward with custom message
if sink.send(WsMessage::Text("Ping".into())).await.is_err() {
drop(sink);
break;
}
} else if packet.method == SocketMethod::Message {
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
} else {
// forward to client
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
}
@ -504,14 +505,14 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String,
});
let db2c = db.2.clone();
let channel_c = channel.clone();
let heartbeat_c = heartbeat_uri.clone();
let heartbeat_task = tokio::spawn(async move {
let mut con = db2c.get_con().await;
let mut heartbeat = tokio::time::interval(Duration::from_secs(30));
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
loop {
con.publish::<&str, String, ()>(
&channel_c,
&heartbeat_c,
serde_json::to_string(&SocketMessage {
method: SocketMethod::Forward(PacketType::Ping),
data: "Ping".to_string(),
@ -524,7 +525,6 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, user_id: String,
}
});
db.2.incr("atto.active_connections:users".to_string()).await;
tokio::select! {
_ = (&mut recv_task) => redis_task.abort(),
_ = (&mut redis_task) => recv_task.abort()