add: connect to socket by community

direct messages/groups connect by channel id, everything else should connect by channel with the "is_channel" header set to true
This commit is contained in:
trisua 2025-04-29 16:53:34 -04:00
parent c1c8cdbfcd
commit 0304461389
20 changed files with 241 additions and 160 deletions

50
Cargo.lock generated
View file

@ -533,9 +533,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.40"
version = "0.4.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
dependencies = [
"android-tzdata",
"iana-time-zone",
@ -716,28 +716,6 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
@ -757,15 +735,6 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
@ -3271,11 +3240,11 @@ dependencies = [
"axum-extra",
"cf-turnstile",
"contrasted",
"crossbeam",
"futures-util",
"image",
"mime_guess",
"pathbufd",
"redis",
"regex",
"reqwest",
"serde",
@ -3299,7 +3268,6 @@ dependencies = [
"base64",
"bb8-postgres",
"bitflags 2.9.0",
"futures-util",
"md-5",
"pathbufd",
"redis",
@ -3556,9 +3524,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.8.21"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900f6c86a685850b1bc9f6223b20125115ee3f31e01207d81655bbcc0aea9231"
checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae"
dependencies = [
"serde",
"serde_spanned",
@ -3577,9 +3545,9 @@ dependencies = [
[[package]]
name = "toml_edit"
version = "0.22.25"
version = "0.22.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485"
checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e"
dependencies = [
"indexmap",
"serde",
@ -3591,9 +3559,9 @@ dependencies = [
[[package]]
name = "toml_write"
version = "0.1.0"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28391a4201ba7eb1984cfeb6862c0b3ea2cfe23332298967c749dddc0d6cd976"
checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076"
[[package]]
name = "totp-rs"

View file

@ -6,7 +6,7 @@ edition = "2024"
[features]
postgres = ["tetratto-core/postgres"]
sqlite = ["tetratto-core/sqlite"]
redis = ["tetratto-core/redis"]
redis = ["tetratto-core/redis", "dep:redis"]
default = ["sqlite", "redis"]
[dependencies]
@ -34,4 +34,5 @@ mime_guess = "2.0.5"
cf-turnstile = "0.2.0"
contrasted = "0.1.2"
futures-util = "0.3.31"
crossbeam = { version = "0.8.4", features = ["crossbeam-channel"] }
redis = { version = "0.30.0", optional = true }

View file

@ -27,6 +27,7 @@
<a
href="/chats/0/0"
class="button quaternary channel_icon {% if selected_community == 0 %}selected{% endif %}"
data-turbo="false"
>
{{ icon "message-circle" }}
</a>
@ -35,6 +36,7 @@
<a
href="/chats/{{ community.id }}/0"
class="button quaternary channel_icon {% if selected_community == community.id %}selected{% endif %}"
data-turbo="false"
>
{{ components::community_avatar(id=community.id,
community=community, size="48px") }}
@ -440,64 +442,76 @@
};
</script>
{% if selected_channel %}
<script>
setTimeout(() => {
<script id="socket_init" data-turbo-permanent="true">
window.SUBSCRIBE_CHANNEL = "{{ selected_community }}" === "0";
globalThis.socket_init = () => {
if (window.socket) {
if (window.socket_id === "{{ selected_channel }}") {
console.log("cannot open; already in session");
return;
} else {
window.socket.send("Close");
window.socket.close();
window.socket = undefined;
console.log("closed lingering");
}
}
for (const anchor of document.querySelectorAll("a")) {
if (anchor.href.includes("{{ selected_channel }}")) {
continue;
}
anchor.addEventListener("click", () => {
window.socket.close();
window.socket = undefined;
console.log("force abandon socket");
});
}
const endpoint = `${window.location.origin.replace("http", "ws")}/api/v1/channels/{{ selected_channel }}/ws`;
if ("{{ selected_community }}" !== "0") {
const endpoint = `${window.location.origin.replace("http", "ws")}/api/v1/_connect/{{ selected_community }}`;
const socket = new WebSocket(endpoint);
window.socket = socket;
window.socket_id = "{{ selected_community }}";
} else {
const endpoint = `${window.location.origin.replace("http", "ws")}/api/v1/_connect/{{ selected_channel }}`;
const socket = new WebSocket(endpoint);
window.socket = socket;
window.socket_id = "{{ selected_channel }}";
}
socket.addEventListener("close", () => {
return socket.send("Close");
});
socket.addEventListener("open", () => {
window.socket.addEventListener("open", () => {
// auth
socket.send(
window.socket.send(
JSON.stringify({
method: "Headers",
data: JSON.stringify({
// SocketHeaders
channel: "{{ selected_channel }}",
user: "{{ user.id }}",
is_channel: window.SUBSCRIBE_CHANNEL,
}),
}),
);
});
};
</script>
socket.addEventListener("message", async (event) => {
{% if selected_channel %}
<script>
setTimeout(() => {
if (!window.SUBSCRIBE_CHANNEL) {
// sub community
if (window.socket_id !== "{{ selected_community }}") {
socket_init();
}
} else {
// sub channel
if (window.socket_id !== "{{ selected_channel }}") {
socket_init();
}
}
}, 50);
setTimeout(() => {
window.socket.addEventListener("message", async (event) => {
if (event.data === "Ping") {
return socket.send("Pong");
}
const msg = JSON.parse(event.data);
const data = JSON.parse(msg.data);
const [channel_id, data] = JSON.parse(msg.data);
if (msg.method === "Message" && window.CURRENT_PAGE === 0) {
if (channel_id !== "{{ selected_channel }}") {
// message not for us... maybe send notification later
// something like /api/v1/messages/{id}/mark_unread
return;
}
if (document.getElementById("stream_body")) {
const element = document.createElement("div");
element.style.display = "contents";

View file

@ -83,7 +83,7 @@ pub async fn proxy_request(
None => return Json(Error::NotAllowed.into()),
};
if let None = user.connections.get(&ConnectionService::LastFm) {
if user.connections.get(&ConnectionService::LastFm).is_none() {
// connection doesn't exist
return Json(Error::GeneralNotFound("connection".to_string()).into());
};

View file

@ -66,11 +66,11 @@ pub async fn me_request(jar: CookieJar, Extension(data): Extension<State>) -> im
None => return Json(Error::NotAllowed.into()),
};
return Json(ApiReturn {
Json(ApiReturn {
ok: true,
message: "User exists".to_string(),
payload: Some(user),
});
})
}
/// Update the settings of the given user.

View file

@ -63,7 +63,7 @@ pub async fn create_group_request(
// check for existing
if members.len() == 1 {
let other_user = members.get(0).unwrap().to_owned();
let other_user = members.first().unwrap().to_owned();
if let Ok(channel) = data.get_channel_by_owner_member(user.id, other_user).await {
return Json(ApiReturn {
ok: true,
@ -80,21 +80,18 @@ pub async fn create_group_request(
Err(e) => return Json(e.into()),
};
if other_user.settings.private_chats {
if data
if other_user.settings.private_chats && data
.get_userfollow_by_initiator_receiver(other_user.id, user.id)
.await
.is_err()
{
.is_err() {
return Json(Error::NotAllowed.into());
}
}
}
// ...
let mut props = Channel::new(0, user.id, 0, req.title);
props.members = members;
let id = props.id.clone();
let id = props.id;
match data.create_channel(props).await {
Ok(_) => Json(ApiReturn {

View file

@ -1,3 +1,5 @@
use std::{collections::HashMap, time::Duration};
use redis::Commands;
use axum::{
extract::{
ws::{Message as WsMessage, WebSocket, WebSocketUpgrade},
@ -12,7 +14,7 @@ use tetratto_core::{
model::{
auth::User,
channels::Message,
socket::{SocketMessage, SocketMethod},
socket::{PacketType, SocketMessage, SocketMethod},
ApiReturn, Error,
},
DataManager,
@ -21,10 +23,10 @@ use crate::{get_user_from_token, routes::api::v1::CreateMessage, State};
use serde::Deserialize;
use futures_util::{sink::SinkExt, stream::StreamExt};
#[derive(Deserialize)]
#[derive(Clone, Deserialize)]
pub struct SocketHeaders {
pub channel: String,
pub user: String,
pub is_channel: bool,
}
/// Handle a subscription to the websocket.
@ -43,11 +45,11 @@ pub async fn subscription_handler(
})
}
pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: String) {
pub async fn handle_socket(socket: WebSocket, db: DataManager, community_id: String) {
let (mut sink, mut stream) = socket.split();
let mut user: Option<User> = None;
let mut con = db.2.clone().get_con().await;
let mut headers: Option<SocketHeaders> = None;
// handle incoming messages on socket
let dbc = db.clone();
@ -60,7 +62,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
}
};
if data.method != SocketMethod::Headers && user.is_none() {
if data.method != SocketMethod::Headers && user.is_none() && headers.is_none() {
// we've sent something else before authenticating... that's not right
let _ = sink.close().await;
return;
@ -70,6 +72,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
SocketMethod::Headers => {
let data: SocketHeaders = data.data();
headers = Some(data.clone());
user = Some(
match dbc
.get_user_by_id(match data.user.parse::<usize>() {
@ -89,8 +92,10 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
},
);
if data.is_channel {
// verify permissions for single channel
let channel = match dbc
.get_channel_by_id(match data.channel.parse::<usize>() {
.get_channel_by_id(match community_id.parse::<usize>() {
Ok(c) => c,
Err(_) => {
let _ = sink.close().await;
@ -124,6 +129,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
return;
}
}
}
_ => {
let _ = sink.close().await;
return;
@ -134,6 +140,37 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
return;
}
// get channel permissions
let user = user.unwrap();
let headers = headers.unwrap();
let mut channel_read_statuses: HashMap<usize, bool> = HashMap::new();
if !headers.is_channel {
// check permissions for every channel in community
let community_id = match community_id.parse::<usize>() {
Ok(c) => c,
Err(_) => return,
};
let membership = match dbc
.get_membership_by_owner_community(user.id, community_id)
.await
{
Ok(ua) => ua,
Err(_) => {
return;
}
};
for channel in dbc.get_channels_by_community(community_id).await.unwrap() {
channel_read_statuses.insert(
channel.id,
channel.check_read(user.id, Some(membership.role)),
);
}
}
// ...
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(WsMessage::Text(text))) = stream.next().await {
if text != "Close" {
@ -147,20 +184,77 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
}
});
let dbc = db.clone();
let mut redis_task = tokio::spawn(async move {
// forward messages from redis to the mpsc
let mut con = dbc.2.get_con().await;
let mut pubsub = con.as_pubsub();
pubsub.subscribe(channel_id).unwrap();
pubsub.subscribe(user.id).unwrap();
pubsub.subscribe(community_id.clone()).unwrap();
// listen for pubsub messages
while let Ok(msg) = pubsub.get_message() {
// payload is a stringified SocketMessage
if sink
.send(WsMessage::Text(msg.get_payload::<String>().unwrap().into()))
.await
.is_err()
{
return;
let smsg = msg.get_payload::<String>().unwrap();
let packet: SocketMessage = serde_json::from_str(&smsg).unwrap();
if packet.method == SocketMethod::Forward(PacketType::Ping) {
// forward with custom message
if sink.send(WsMessage::Text("Ping".into())).await.is_err() {
drop(sink);
break;
}
} else if packet.method == SocketMethod::Message {
// check perms and then forward
let d: (String, Message) = packet.data();
if let Some(cs) = channel_read_statuses.get(&d.1.channel) {
if !cs {
continue;
}
} else {
if !headers.is_channel {
// since we didn't select by just a channel, there HAS to be
// an entry for the channel for us to check this message
continue;
// we don't need to check messages when we're subscribed to
// a channel, since that is checked on headers submission when
// we subscribe to a channel
}
}
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
} else {
// forward to client
if sink.send(WsMessage::Text(smsg.into())).await.is_err() {
drop(sink);
break;
}
}
}
});
let db2c = db.2.clone();
let heartbeat_task = tokio::spawn(async move {
let mut con = db2c.get_con().await;
let mut heartbeat = tokio::time::interval(Duration::from_secs(10));
loop {
con.publish::<usize, String, ()>(
user.id,
serde_json::to_string(&SocketMessage {
method: SocketMethod::Forward(PacketType::Ping),
data: "Ping".to_string(),
})
.unwrap(),
)
.unwrap();
heartbeat.tick().await;
}
});
@ -169,6 +263,7 @@ pub async fn handle_socket(socket: WebSocket, db: DataManager, channel_id: Strin
_ = (&mut redis_task) => recv_task.abort()
}
heartbeat_task.abort(); // kill
tracing::info!("socket terminate");
}

View file

@ -290,7 +290,7 @@ pub fn routes() -> Router {
)
// messages
.route(
"/channels/{id}/ws",
"/_connect/{id}",
any(channels::messages::subscription_handler),
)
.route("/messages", post(channels::messages::create_request))

View file

@ -211,7 +211,7 @@ pub async fn message_request(
}
};
let message: Message = match serde_json::from_str(&req.data) {
let message: (String, Message) = match serde_json::from_str(&req.data) {
Ok(m) => m,
Err(e) => {
return Err(Html(
@ -220,6 +220,8 @@ pub async fn message_request(
}
};
let message = message.1;
let membership = match data
.0
.get_membership_by_owner_community(user.id, community)

View file

@ -12,7 +12,7 @@ default = ["sqlite", "redis"]
[dependencies]
pathbufd = "0.1.4"
serde = { version = "1.0.219", features = ["derive"] }
toml = "0.8.21"
toml = "0.8.22"
tetratto-shared = { path = "../shared" }
tetratto-l10n = { path = "../l10n" }
serde_json = "1.0.140"
@ -30,4 +30,3 @@ rusqlite = { version = "0.35.0", optional = true }
tokio-postgres = { version = "0.7.13", optional = true }
bb8-postgres = { version = "0.9.0", optional = true }
base64 = "0.22.1"
futures-util = "0.3.31"

View file

@ -151,6 +151,7 @@ impl Default for TurnstileConfig {
}
#[derive(Clone, Serialize, Deserialize, Debug)]
#[derive(Default)]
pub struct ConnectionsConfig {
/// <https://developer.spotify.com/documentation/web-api>
#[serde(default)]
@ -163,15 +164,6 @@ pub struct ConnectionsConfig {
pub last_fm_secret: Option<String>,
}
impl Default for ConnectionsConfig {
fn default() -> Self {
Self {
spotify_client_id: None,
last_fm_key: None,
last_fm_secret: None,
}
}
}
/// Configuration file
#[derive(Clone, Serialize, Deserialize, Debug)]

View file

@ -528,7 +528,7 @@ macro_rules! auto_method {
if !user.permissions.check(FinePermission::$permission) {
return Err(Error::NotAllowed);
} else {
self.create_audit_log_entry(crate::model::moderation::AuditLogEntry::new(
self.create_audit_log_entry($crate::model::moderation::AuditLogEntry::new(
user.id,
format!("invoked `{}` with x value `{x}`", stringify!($name)),
))
@ -607,7 +607,7 @@ macro_rules! auto_method {
if !user.permissions.check(FinePermission::$permission) {
return Err(Error::NotAllowed);
} else {
self.create_audit_log_entry(crate::model::moderation::AuditLogEntry::new(
self.create_audit_log_entry($crate::model::moderation::AuditLogEntry::new(
user.id,
format!("invoked `{}` with x value `{x:?}`", stringify!($name)),
))

View file

@ -49,7 +49,7 @@ impl DataManager {
pub async fn fill_messages(
&self,
messages: Vec<Message>,
ignore_users: &Vec<usize>,
ignore_users: &[usize],
) -> Result<Vec<(Message, User)>> {
let mut out: Vec<(Message, User)> = Vec::new();
@ -158,10 +158,16 @@ impl DataManager {
let mut con = self.2.get_con().await;
if let Err(e) = con.publish::<usize, String, ()>(
data.channel,
if channel.community != 0 {
// broadcast to community ws
channel.community
} else {
// broadcast to channel ws
channel.id
},
serde_json::to_string(&SocketMessage {
method: SocketMethod::Message,
data: serde_json::to_string(&data).unwrap(),
data: serde_json::to_string(&(data.channel.to_string(), data)).unwrap(),
})
.unwrap(),
) {
@ -211,7 +217,13 @@ impl DataManager {
let mut con = self.2.get_con().await;
if let Err(e) = con.publish::<usize, String, ()>(
message.channel,
if channel.community != 0 {
// broadcast to community ws
channel.community
} else {
// broadcast to channel ws
channel.id
},
serde_json::to_string(&SocketMessage {
method: SocketMethod::Delete,
data: serde_json::to_string(&DeleteMessageEvent { id: id.to_string() }).unwrap(),

View file

@ -121,7 +121,7 @@ impl DataManager {
pub async fn fill_posts(
&self,
posts: Vec<Post>,
ignore_users: &Vec<usize>,
ignore_users: &[usize],
) -> Result<Vec<(Post, User, Option<(User, Post)>, Option<(Question, User)>)>> {
let mut out: Vec<(Post, User, Option<(User, Post)>, Option<(Question, User)>)> = Vec::new();
@ -160,7 +160,7 @@ impl DataManager {
&self,
posts: Vec<Post>,
user_id: usize,
ignore_users: &Vec<usize>,
ignore_users: &[usize],
) -> Result<
Vec<(
Post,

View file

@ -49,7 +49,7 @@ impl DataManager {
pub async fn fill_questions(
&self,
questions: Vec<Question>,
ignore_users: &Vec<usize>,
ignore_users: &[usize],
) -> Result<Vec<(Question, User)>> {
let mut out: Vec<(Question, User)> = Vec::new();

View file

@ -371,19 +371,12 @@ pub struct ExternalConnectionInfo {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Default)]
pub struct ExternalConnectionData {
pub external_urls: HashMap<String, String>,
pub data: HashMap<String, String>,
}
impl Default for ExternalConnectionData {
fn default() -> Self {
Self {
external_urls: HashMap::new(),
data: HashMap::new(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Notification {

View file

@ -93,7 +93,7 @@ impl Message {
created: now,
edited: now,
content,
context: MessageContext::default(),
context: MessageContext,
}
}
}

View file

@ -1,5 +1,11 @@
use serde::{Serialize, Deserialize, de::DeserializeOwned};
#[derive(Serialize, Deserialize, PartialEq, Eq)]
pub enum PacketType {
/// A regular check to ensure the connection is still alive.
Ping,
}
#[derive(Serialize, Deserialize, PartialEq, Eq)]
pub enum SocketMethod {
/// Authentication and channel identification.
@ -8,6 +14,8 @@ pub enum SocketMethod {
Message,
/// A message was deleted in the channel.
Delete,
/// Forward message from server to client. (Redis pubsub to ws)
Forward(PacketType),
}
#[derive(Serialize, Deserialize)]

View file

@ -9,4 +9,4 @@ license.workspace = true
[dependencies]
pathbufd = "0.1.4"
serde = { version = "1.0.219", features = ["derive"] }
toml = "0.8.21"
toml = "0.8.22"

View file

@ -8,7 +8,7 @@ license.workspace = true
[dependencies]
ammonia = "4.1.0"
chrono = "0.4.40"
chrono = "0.4.41"
comrak = "0.38.0"
hex_fmt = "0.3.0"
num-bigint = "0.4.6"