use crate::{Error, Result}; use tokio_tungstenite::{WebSocketStream, MaybeTlsStream}; use tokio::net::TcpStream; use tokio_stream::Stream; use crate::entities::event::Event; use std::pin::Pin; use std::task::Poll; use tokio_tungstenite::tungstenite::Message; use crate::entities::notification::Notification; use crate::entities::status::Status; use futures_util::sink::SinkExt; #[derive(Clone, Debug)] pub enum StreamKind<'a> { User, Public, PublicLocal, Direct, Hashtag(&'a str), HashtagLocal(&'a str), List(&'a str), } impl<'a> StreamKind<'a> { pub(crate) fn get_stream_name(&self) -> &'static str { match self { StreamKind::User => "user", StreamKind::Public => "public", StreamKind::PublicLocal => "public:local", StreamKind::Direct => "direct", StreamKind::Hashtag(_) => "hashtag", StreamKind::HashtagLocal(_) => "hashtag:local", StreamKind::List(_) => "list", } } #[allow(unused)] pub(crate) fn get_url_fragment(&self) -> &'static str { match self { StreamKind::User => "user", StreamKind::Public => "public", StreamKind::PublicLocal => "public/local", StreamKind::Direct => "direct", StreamKind::Hashtag(_) => "hashtag", StreamKind::HashtagLocal(_) => "hashtag/local", StreamKind::List(_) => "list", } } pub(crate) fn get_query_params(&self) -> Vec<(&str, &str)> { match self { StreamKind::User => vec![], StreamKind::Public => vec![], StreamKind::PublicLocal => vec![], StreamKind::Direct => vec![], StreamKind::Hashtag(tag) | StreamKind::HashtagLocal(tag) => vec![("tag", tag)], StreamKind::List(list) => vec![("tag", list)], } } } pub(crate) async fn do_open_streaming(url : &str) -> Result { let mut url: url::Url = reqwest::get(url).await?.url().as_str().parse()?; let new_scheme = match url.scheme() { "http" => "ws", "https" => "wss", x => return Err(Error::Other(format!("Bad URL scheme: {}", x))), }; url.set_scheme(new_scheme) .map_err(|_| Error::Other("Bad URL scheme!".to_string()))?; let (client, _response) = tokio_tungstenite::connect_async(url.as_str()).await?; Ok(EventReader::new(client)) } #[derive(Debug)] /// Iterator that produces events from a mastodon streaming API event stream pub struct EventReader { stream: WebSocketStream>, lines: Vec, } impl EventReader { fn new(stream: WebSocketStream>) -> Self { Self { stream, lines: vec![] } } pub async fn send_ping(&mut self) -> std::result::Result<(), tokio_tungstenite::tungstenite::Error> { trace!("Sending ping"); self.stream.send(Message::Ping("pleroma groups".as_bytes().to_vec())).await } } impl Stream for EventReader { type Item = Event; fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { match Pin::new(&mut self.stream).poll_next(cx) { Poll::Ready(Some(Ok(Message::Text(line)))) => { trace!("WS rx: {}", line); let line = line.trim().to_string(); if line.starts_with(':') || line.is_empty() { trace!("discard as comment"); return Poll::Pending; } self.lines.push(line); if let Ok(event) = self.make_event(&self.lines) { trace!("Parsed event"); self.lines.clear(); Poll::Ready(Some(event)) } else { trace!("Failed to parse"); Poll::Pending } } Poll::Ready(Some(Ok(Message::Ping(_)))) => { trace!("Ping"); Poll::Ready(Some(Event::Heartbeat)) } Poll::Ready(Some(Ok(Message::Pong(_)))) => { trace!("Pong"); Poll::Ready(Some(Event::Heartbeat)) } Poll::Ready(Some(Ok(Message::Binary(_)))) => { warn!("Unexpected binary msg"); Poll::Ready(Some(Event::Heartbeat)) } Poll::Ready(Some(Ok(Message::Close(_)))) => { warn!("Websocket close frame!"); Poll::Ready(None) } Poll::Ready(Some(Err(error))) => { error!("Websocket error: {:?}", error); // Close Poll::Ready(None) } Poll::Ready(None) => { // Stream is closed Poll::Ready(None) } Poll::Pending => { Poll::Pending } } } } impl EventReader { fn make_event(&self, lines: &[String]) -> Result { let event; let data; if let Some(event_line) = lines.iter().find(|line| line.starts_with("event:")) { trace!("plaintext formatted event"); event = event_line[6..].trim().to_string(); data = lines .iter() .find(|line| line.starts_with("data:")) .map(|x| x[5..].trim().to_string()); } else { trace!("JSON formatted event"); use serde::Deserialize; #[derive(Deserialize)] struct Message { pub event: String, pub payload: Option, } let message = serde_json::from_str::(&lines[0])?; event = message.event; data = message.payload; } let event: &str = &event; Ok(match event { "notification" => { let data = data.ok_or_else(|| Error::StreamingFormat)?; let notification = serde_json::from_str::(&data)?; Event::Notification(notification) } "update" => { let data = data.ok_or_else(|| Error::StreamingFormat)?; let status = serde_json::from_str::(&data)?; Event::Update(status) } "delete" => { let data = data.ok_or_else(|| Error::StreamingFormat)?; Event::Delete(data) } "filters_changed" => Event::FiltersChanged, _ => return Err(Error::Other(format!("Unknown event `{}`", event))), }) } }