Skip to content

Commit 67a5c66

Browse files
EvolveArtakhercha
andauthored
feat: refactor onchain OHLC ws endpoint (#55)
* feat: refacto /onchain/ohlc ws endpoint * feat: check if existing onchain pair * Update utils.rs Co-authored-by: adel <akherchache@pm.me> --------- Co-authored-by: adel <akherchache@pm.me>
1 parent 5073e69 commit 67a5c66

File tree

6 files changed

+215
-69
lines changed

6 files changed

+215
-69
lines changed

pragma-common/src/types.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use chrono::{NaiveDateTime, Timelike};
2-
use serde::Deserialize;
2+
use serde::{Deserialize, Serialize};
33
use utoipa::ToSchema;
44

5-
#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
5+
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
66
pub enum AggregationMode {
77
#[serde(rename = "median")]
88
#[default]
@@ -13,7 +13,7 @@ pub enum AggregationMode {
1313
Twap,
1414
}
1515

16-
#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
16+
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
1717
pub enum Network {
1818
#[serde(rename = "testnet")]
1919
#[default]
@@ -32,7 +32,7 @@ pub enum DataType {
3232
}
3333

3434
// Supported Aggregation Intervals
35-
#[derive(Default, Debug, Deserialize, ToSchema, Clone, Copy)]
35+
#[derive(Default, Debug, Serialize, Deserialize, ToSchema, Clone, Copy)]
3636
pub enum Interval {
3737
#[serde(rename = "1min")]
3838
#[default]
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,200 @@
11
use std::time::Duration;
22

3-
use axum::extract::{Query, State};
3+
use axum::extract::State;
44
use axum::response::IntoResponse;
5+
use pragma_entities::InfraError;
6+
use serde::{Deserialize, Serialize};
57
use serde_json::json;
68

79
use pragma_common::types::{Interval, Network};
10+
use tokio::time::interval;
811

9-
use crate::handlers::entries::utils::currency_pair_to_pair_id;
10-
use crate::handlers::entries::GetOnchainOHLCParams;
12+
use crate::handlers::entries::utils::is_onchain_existing_pair;
1113
use crate::infra::repositories::entry_repository::OHLCEntry;
1214
use crate::infra::repositories::onchain_repository::get_ohlc;
13-
use crate::utils::PathExtractor;
1415
use crate::AppState;
1516

1617
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
1718

18-
pub const WS_UPDATING_INTERVAL_IN_SECONDS: u64 = 10;
19+
#[derive(Default, Debug, Serialize, Deserialize)]
20+
enum SubscriptionType {
21+
#[serde(rename = "subscribe")]
22+
#[default]
23+
Subscribe,
24+
#[serde(rename = "unsubscribe")]
25+
Unsubscribe,
26+
}
27+
28+
#[derive(Debug, Serialize, Deserialize)]
29+
struct SubscriptionRequest {
30+
msg_type: SubscriptionType,
31+
pair: String,
32+
network: Network,
33+
interval: Interval,
34+
}
35+
36+
#[derive(Debug, Serialize, Deserialize)]
37+
struct SubscriptionAck {
38+
msg_type: SubscriptionType,
39+
pair: String,
40+
network: Network,
41+
interval: Interval,
42+
}
43+
44+
/// Interval in milliseconds that the channel will update the client with the latest prices.
45+
const CHANNEL_UPDATE_INTERVAL_IN_MS: u64 = 500;
1946

2047
#[utoipa::path(
2148
get,
22-
path = "/node/v1/onchain/ws/ohlc/{base}/{quote}",
49+
path = "/node/v1/onchain/ohlc",
2350
responses(
2451
(
2552
status = 200,
26-
description = "Get OHLC data for a pair continuously updated through a ws connection",
27-
body = GetOnchainOHLCResponse
53+
description = "Subscribe to a list of OHLC entries",
54+
body = [SubscribeToEntryResponse]
2855
)
29-
),
30-
params(
31-
("base" = String, Path, description = "Base Asset"),
32-
("quote" = String, Path, description = "Quote Asset"),
33-
("network" = Network, Query, description = "Network"),
34-
("interval" = Interval, Query, description = "Interval of the OHLC data"),
35-
),
56+
)
3657
)]
37-
pub async fn get_onchain_ohlc_ws(
58+
pub async fn subscribe_to_onchain_ohlc(
3859
ws: WebSocketUpgrade,
3960
State(state): State<AppState>,
40-
PathExtractor(pair): PathExtractor<(String, String)>,
41-
Query(params): Query<GetOnchainOHLCParams>,
4261
) -> impl IntoResponse {
43-
let pair_id = currency_pair_to_pair_id(&pair.0, &pair.1);
44-
ws.on_upgrade(move |socket| {
45-
handle_ohlc_ws(socket, state, pair_id, params.network, params.interval)
46-
})
62+
ws.on_upgrade(move |socket| handle_channel(socket, state))
4763
}
4864

49-
async fn handle_ohlc_ws(
50-
mut socket: WebSocket,
51-
state: AppState,
52-
pair_id: String,
53-
network: Network,
54-
interval: Interval,
55-
) {
56-
// Initial OHLC to compute
57-
let mut ohlc_to_compute = 10;
58-
let mut update_interval =
59-
tokio::time::interval(Duration::from_secs(WS_UPDATING_INTERVAL_IN_SECONDS));
65+
/// Handle the WebSocket channel.
66+
async fn handle_channel(mut socket: WebSocket, state: AppState) {
67+
let waiting_duration = Duration::from_millis(CHANNEL_UPDATE_INTERVAL_IN_MS);
68+
let mut update_interval = interval(waiting_duration);
69+
let mut subscribed_pair: Option<String> = None;
70+
let mut network = Network::default();
71+
let mut interval = Interval::default();
6072

73+
let mut ohlc_to_compute = 10;
6174
let mut ohlc_data: Vec<OHLCEntry> = Vec::new();
6275

6376
loop {
64-
update_interval.tick().await;
65-
match get_ohlc(
66-
&mut ohlc_data,
67-
&state.postgres_pool,
68-
network,
69-
pair_id.clone(),
70-
interval,
71-
ohlc_to_compute,
72-
)
73-
.await
74-
{
75-
Ok(()) => {
76-
if socket
77-
.send(Message::Text(serde_json::to_string(&ohlc_data).unwrap()))
78-
.await
79-
.is_err()
80-
{
81-
break;
77+
tokio::select! {
78+
Some(msg) = socket.recv() => {
79+
if let Ok(Message::Text(text)) = msg {
80+
handle_message_received(&mut socket, &state, &mut subscribed_pair, &mut network, &mut interval, text).await;
8281
}
82+
},
83+
_ = update_interval.tick() => {
84+
match send_ohlc_data(&mut socket, &state, &subscribed_pair, &mut ohlc_data, network, interval, ohlc_to_compute).await {
85+
Ok(_) => {
86+
// After the first request, we only get the latest interval
87+
if !ohlc_data.is_empty() {
88+
ohlc_to_compute = 1;
89+
}
90+
},
91+
Err(_) => break
92+
};
8393
}
84-
Err(e) => {
85-
if socket
86-
.send(Message::Text(json!({ "error": e.to_string() }).to_string()))
87-
.await
88-
.is_err()
89-
{
90-
break;
94+
}
95+
}
96+
}
97+
98+
/// Handle the message received from the client.
99+
/// Subscribe or unsubscribe to the pairs requested.
100+
async fn handle_message_received(
101+
socket: &mut WebSocket,
102+
state: &AppState,
103+
subscribed_pair: &mut Option<String>,
104+
network: &mut Network,
105+
interval: &mut Interval,
106+
message: String,
107+
) {
108+
if let Ok(subscription_msg) = serde_json::from_str::<SubscriptionRequest>(&message) {
109+
match subscription_msg.msg_type {
110+
SubscriptionType::Subscribe => {
111+
let pair_exists = is_onchain_existing_pair(
112+
&state.postgres_pool,
113+
&subscription_msg.pair,
114+
subscription_msg.network,
115+
)
116+
.await;
117+
if !pair_exists {
118+
let error_msg = "Pair does not exist in the onchain database.";
119+
send_error_message(socket, error_msg).await;
120+
return;
91121
}
122+
123+
*network = subscription_msg.network;
124+
*subscribed_pair = Some(subscription_msg.pair.clone());
125+
*interval = subscription_msg.interval;
126+
}
127+
SubscriptionType::Unsubscribe => {
128+
*subscribed_pair = None;
92129
}
130+
};
131+
// We send an ack message to the client with the subscribed pairs (so
132+
// the client knows which pairs are successfully subscribed).
133+
if let Ok(ack_message) = serde_json::to_string(&SubscriptionAck {
134+
msg_type: subscription_msg.msg_type,
135+
pair: subscription_msg.pair,
136+
network: subscription_msg.network,
137+
interval: subscription_msg.interval,
138+
}) {
139+
if socket.send(Message::Text(ack_message)).await.is_err() {
140+
let error_msg = "Message received but could not send ack message.";
141+
send_error_message(socket, error_msg).await;
142+
}
143+
} else {
144+
let error_msg = "Could not serialize ack message.";
145+
send_error_message(socket, error_msg).await;
93146
}
94-
// After the first request, we only get the latest interval
95-
if !ohlc_data.is_empty() {
96-
ohlc_to_compute = 1;
147+
} else {
148+
let error_msg = "Invalid message type. Please check the documentation for more info.";
149+
send_error_message(socket, error_msg).await;
150+
}
151+
}
152+
153+
/// Send the current median entries to the client.
154+
async fn send_ohlc_data(
155+
socket: &mut WebSocket,
156+
state: &AppState,
157+
subscribed_pair: &Option<String>,
158+
ohlc_data: &mut Vec<OHLCEntry>,
159+
network: Network,
160+
interval: Interval,
161+
ohlc_to_compute: i64,
162+
) -> Result<(), InfraError> {
163+
if subscribed_pair.is_none() {
164+
return Ok(());
165+
}
166+
167+
let pair_id = subscribed_pair.as_ref().unwrap();
168+
169+
let entries = match get_ohlc(
170+
ohlc_data,
171+
&state.postgres_pool,
172+
network,
173+
pair_id.clone(),
174+
interval,
175+
ohlc_to_compute,
176+
)
177+
.await
178+
{
179+
Ok(()) => ohlc_data,
180+
Err(e) => {
181+
send_error_message(socket, &e.to_string()).await;
182+
return Err(e);
183+
}
184+
};
185+
if let Ok(json_response) = serde_json::to_string(&entries) {
186+
if socket.send(Message::Text(json_response)).await.is_err() {
187+
send_error_message(socket, "Could not send prices.").await;
97188
}
189+
} else {
190+
send_error_message(socket, "Could not serialize prices.").await;
98191
}
192+
Ok(())
193+
}
194+
195+
/// Send an error message to the client.
196+
/// (Does not close the connection)
197+
async fn send_error_message(socket: &mut WebSocket, error: &str) {
198+
let error_msg = json!({ "error": error }).to_string();
199+
socket.send(Message::Text(error_msg)).await.unwrap();
99200
}

pragma-node/src/handlers/entries/utils.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
use bigdecimal::{BigDecimal, ToPrimitive};
22
use chrono::NaiveDateTime;
3+
use deadpool_diesel::postgres::Pool;
4+
use pragma_common::types::Network;
35
use std::collections::HashMap;
46

5-
use crate::infra::repositories::entry_repository::MedianEntry;
7+
use crate::infra::repositories::{
8+
entry_repository::MedianEntry, onchain_repository::get_existing_pairs,
9+
};
610

711
const ONE_YEAR_IN_SECONDS: f64 = 3153600_f64;
812

@@ -68,6 +72,16 @@ pub(crate) fn compute_median_price_and_time(
6872
Some((median_price, latest_time))
6973
}
7074

75+
/// Given a pair and a network, returns if it exists in the
76+
/// onchain database.
77+
pub(crate) async fn is_onchain_existing_pair(pool: &Pool, pair: &String, network: Network) -> bool {
78+
let existings_pairs = get_existing_pairs(pool, network)
79+
.await
80+
.expect("Couldn't get the existing pairs from the database.");
81+
82+
existings_pairs.into_iter().any(|p| p.pair_id == *pair)
83+
}
84+
7185
/// Computes the volatility from a list of entries.
7286
/// The volatility is computed as the annualized standard deviation of the log returns.
7387
/// The log returns are computed as the natural logarithm of the ratio between two consecutive median prices.

pragma-node/src/infra/repositories/onchain_repository.rs

+31
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,37 @@ pub async fn get_last_updated_timestamp(
181181
Ok(most_recent_entry.timestamp.and_utc().timestamp() as u64)
182182
}
183183

184+
#[derive(Queryable, QueryableByName)]
185+
pub struct EntryPairId {
186+
#[diesel(sql_type = VarChar)]
187+
pub pair_id: String,
188+
}
189+
190+
// TODO(0xevolve): Only works for Spot entries
191+
pub async fn get_existing_pairs(
192+
pool: &Pool,
193+
network: Network,
194+
) -> Result<Vec<EntryPairId>, InfraError> {
195+
let raw_sql = format!(
196+
r#"
197+
SELECT DISTINCT
198+
pair_id
199+
FROM
200+
{table_name};
201+
"#,
202+
table_name = get_table_name(network, DataType::SpotEntry)
203+
);
204+
205+
let conn = pool.get().await.map_err(adapt_infra_error)?;
206+
let raw_entries = conn
207+
.interact(move |conn| diesel::sql_query(raw_sql).load::<EntryPairId>(conn))
208+
.await
209+
.map_err(adapt_infra_error)?
210+
.map_err(adapt_infra_error)?;
211+
212+
Ok(raw_entries)
213+
}
214+
184215
#[derive(Queryable, QueryableByName)]
185216
struct RawCheckpoint {
186217
#[diesel(sql_type = VarChar)]

pragma-node/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async fn main() {
3636
handlers::entries::get_onchain::get_onchain,
3737
handlers::entries::get_onchain::checkpoints::get_onchain_checkpoints,
3838
handlers::entries::get_onchain::publishers::get_onchain_publishers,
39-
handlers::entries::get_onchain::ohlc::get_onchain_ohlc_ws,
39+
handlers::entries::get_onchain::ohlc::subscribe_to_onchain_ohlc,
4040
),
4141
components(
4242
schemas(pragma_entities::dto::Entry, pragma_entities::EntryError),

pragma-node/src/routes.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use utoipa::OpenApi as OpenApiT;
66
use utoipa_swagger_ui::SwaggerUi;
77

88
use crate::handlers::entries::get_onchain::{
9-
checkpoints::get_onchain_checkpoints, get_onchain, ohlc::get_onchain_ohlc_ws,
9+
checkpoints::get_onchain_checkpoints, get_onchain, ohlc::subscribe_to_onchain_ohlc,
1010
publishers::get_onchain_publishers,
1111
};
1212
use crate::handlers::entries::{create_entries, get_entry, get_ohlc, get_volatility};
@@ -47,7 +47,7 @@ fn onchain_routes(state: AppState) -> Router<AppState> {
4747
.route("/:base/:quote", get(get_onchain))
4848
.route("/checkpoints/:base/:quote", get(get_onchain_checkpoints))
4949
.route("/publishers", get(get_onchain_publishers))
50-
.route("/ws/ohlc/:base/:quote", get(get_onchain_ohlc_ws))
50+
.route("/ws/ohlc", get(subscribe_to_onchain_ohlc))
5151
.with_state(state)
5252
}
5353

0 commit comments

Comments
 (0)