Skip to content

Commit f76d16c

Browse files
authored
feat: small ws improvements (#59)
* feat: doc example and timestamp parsing * feat: cors/logging layer + ping msg * Update ohlc.rs * feat: msg handling * fix: clippy * fix: taplo
1 parent c070177 commit f76d16c

File tree

11 files changed

+120
-22
lines changed

11 files changed

+120
-22
lines changed

Cargo.lock

+32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pragma-node/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ strum = { version = "0.25.0", features = ["derive"] }
3030
thiserror = "1.0.49"
3131
time = "0.3.29"
3232
tokio = { version = "1.0", features = ["sync", "macros", "rt-multi-thread"] }
33+
tower-http = { version = "0.4.0", features = ["fs", "trace", "cors"] }
3334
tracing = "0.1"
3435
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
3536
url = "2.5.0"

pragma-node/src/handlers/entries/get_entry.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub async fn get_entry(
3434
// Construct pair id
3535
let pair_id = currency_pair_to_pair_id(&pair.0, &pair.1);
3636

37-
let now = chrono::Utc::now().timestamp() as u64;
37+
let now = chrono::Utc::now().timestamp();
3838

3939
let timestamp = if let Some(timestamp) = params.timestamp {
4040
timestamp

pragma-node/src/handlers/entries/get_ohlc.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub async fn get_ohlc(
3131
// Construct pair id
3232
let pair_id = currency_pair_to_pair_id(&pair.0, &pair.1);
3333

34-
let now = chrono::Utc::now().timestamp() as u64;
34+
let now = chrono::Utc::now().timestamp();
3535

3636
let timestamp = if let Some(timestamp) = params.timestamp {
3737
timestamp

pragma-node/src/handlers/entries/get_onchain/ohlc.rs

+41-7
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,26 @@ async fn handle_channel(mut socket: WebSocket, state: AppState) {
7373
let mut ohlc_to_compute = 10;
7474
let mut ohlc_data: Vec<OHLCEntry> = Vec::new();
7575

76+
//send a ping (unsupported by some browsers) just to kick things off and get a response
77+
if socket.send(Message::Ping(vec![])).await.is_ok() {
78+
tracing::info!("Pinged ...");
79+
} else {
80+
tracing::info!("Could not send ping !");
81+
// no Error here since the only thing we can do is to close the connection.
82+
// If we can not send messages, there is no way to salvage the statemachine anyway.
83+
return;
84+
}
85+
7686
loop {
7787
tokio::select! {
78-
Some(msg) = socket.recv() => {
79-
if let Ok(Message::Text(text)) = msg {
80-
handle_message_received(&mut socket, &state, &mut subscribed_pair, &mut network, &mut interval, text).await;
81-
}
88+
Some(maybe_msg) = socket.recv() => {
89+
// TODO: remove once we have proper top-level error handling
90+
let msg = if let Ok(msg) = maybe_msg {
91+
msg
92+
} else {
93+
break;
94+
};
95+
handle_message_received(&mut socket, &state, &mut subscribed_pair, &mut network, &mut interval, msg).await;
8296
},
8397
_ = update_interval.tick() => {
8498
match send_ohlc_data(&mut socket, &state, &subscribed_pair, &mut ohlc_data, network, interval, ohlc_to_compute).await {
@@ -103,9 +117,27 @@ async fn handle_message_received(
103117
subscribed_pair: &mut Option<String>,
104118
network: &mut Network,
105119
interval: &mut Interval,
106-
message: String,
120+
message: Message,
107121
) {
108-
if let Ok(subscription_msg) = serde_json::from_str::<SubscriptionRequest>(&message) {
122+
let maybe_client_message = match message {
123+
Message::Close(_) => {
124+
// TODO: Send the close message to gracefully shut down the connection
125+
// Otherwise the client might get an abnormal Websocket closure
126+
// error.
127+
return;
128+
}
129+
Message::Text(text) => serde_json::from_str::<SubscriptionRequest>(&text),
130+
Message::Binary(data) => serde_json::from_slice::<SubscriptionRequest>(&data),
131+
Message::Ping(_) => {
132+
// Axum will send Pong automatically
133+
return;
134+
}
135+
Message::Pong(_) => {
136+
return;
137+
}
138+
};
139+
140+
if let Ok(subscription_msg) = maybe_client_message {
109141
match subscription_msg.msg_type {
110142
SubscriptionType::Subscribe => {
111143
let pair_exists = is_onchain_existing_pair(
@@ -196,5 +228,7 @@ async fn send_ohlc_data(
196228
/// (Does not close the connection)
197229
async fn send_error_message(socket: &mut WebSocket, error: &str) {
198230
let error_msg = json!({ "error": error }).to_string();
199-
socket.send(Message::Text(error_msg)).await.unwrap();
231+
if socket.send(Message::Text(error_msg)).await.is_err() {
232+
tracing::error!("Client already disconnected. Could not send error message.");
233+
}
200234
}

pragma-node/src/handlers/entries/mod.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ pub use get_ohlc::get_ohlc;
1010
pub use get_onchain::get_onchain;
1111
pub use get_volatility::get_volatility;
1212

13-
use crate::infra::repositories::entry_repository::OHLCEntry;
13+
use crate::{
14+
infra::repositories::entry_repository::OHLCEntry,
15+
utils::{doc_examples, UnixTimestamp},
16+
};
1417

1518
pub mod create_entry;
1619
pub mod get_entry;
@@ -134,7 +137,11 @@ pub struct GetOnchainCheckpointsResponse(pub Vec<Checkpoint>);
134137
135138
#[derive(Debug, Deserialize, IntoParams, ToSchema)]
136139
pub struct GetEntryParams {
137-
pub timestamp: Option<u64>,
140+
/// The unix timestamp in seconds. This endpoint will return the first update whose
141+
/// timestamp is <= the provided value.
142+
#[param(value_type = i64)]
143+
#[param(example = doc_examples::timestamp_example)]
144+
pub timestamp: Option<UnixTimestamp>,
138145
pub interval: Option<Interval>,
139146
pub routing: Option<bool>,
140147
pub aggregation: Option<AggregationMode>,
@@ -143,7 +150,7 @@ pub struct GetEntryParams {
143150
impl Default for GetEntryParams {
144151
fn default() -> Self {
145152
Self {
146-
timestamp: Some(chrono::Utc::now().timestamp() as u64),
153+
timestamp: Some(chrono::Utc::now().timestamp()),
147154
interval: Some(Interval::default()),
148155
routing: Some(false),
149156
aggregation: Some(AggregationMode::default()),

pragma-node/src/infra/repositories/entry_repository.rs

+9-9
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ pub async fn routing(
8787
pool: &deadpool_diesel::postgres::Pool,
8888
pair_id: String,
8989
interval: Interval,
90-
timestamp: u64,
90+
timestamp: i64,
9191
is_routing: bool,
9292
agg_mode: AggregationMode,
9393
) -> Result<(MedianEntry, u32), InfraError> {
@@ -164,7 +164,7 @@ async fn find_alternative_pair_price(
164164
base: &str,
165165
quote: &str,
166166
interval: Interval,
167-
timestamp: u64,
167+
timestamp: i64,
168168
agg_mode: AggregationMode,
169169
) -> Result<(MedianEntry, u32), InfraError> {
170170
let conn = pool.get().await.map_err(adapt_infra_error)?;
@@ -213,7 +213,7 @@ async fn get_price_decimals(
213213
pool: &deadpool_diesel::postgres::Pool,
214214
pair_id: String,
215215
interval: Interval,
216-
timestamp: u64,
216+
timestamp: i64,
217217
agg_mode: AggregationMode,
218218
) -> Result<(MedianEntry, u32), InfraError> {
219219
let entry = match agg_mode {
@@ -251,7 +251,7 @@ pub async fn get_twap_price(
251251
pool: &deadpool_diesel::postgres::Pool,
252252
pair_id: String,
253253
interval: Interval,
254-
time: u64,
254+
time: i64,
255255
) -> Result<MedianEntry, InfraError> {
256256
let conn = pool.get().await.map_err(adapt_infra_error)?;
257257

@@ -330,7 +330,7 @@ pub async fn get_twap_price(
330330
}
331331
};
332332

333-
let date_time = DateTime::from_timestamp(time as i64, 0).ok_or(InfraError::InvalidTimeStamp)?;
333+
let date_time = DateTime::from_timestamp(time, 0).ok_or(InfraError::InvalidTimeStamp)?;
334334

335335
let raw_entry = conn
336336
.interact(move |conn| {
@@ -357,7 +357,7 @@ pub async fn get_median_price(
357357
pool: &deadpool_diesel::postgres::Pool,
358358
pair_id: String,
359359
interval: Interval,
360-
time: u64,
360+
time: i64,
361361
) -> Result<MedianEntry, InfraError> {
362362
let conn = pool.get().await.map_err(adapt_infra_error)?;
363363

@@ -436,7 +436,7 @@ pub async fn get_median_price(
436436
}
437437
};
438438

439-
let date_time = DateTime::from_timestamp(time as i64, 0).ok_or(InfraError::InvalidTimeStamp)?;
439+
let date_time = DateTime::from_timestamp(time, 0).ok_or(InfraError::InvalidTimeStamp)?;
440440

441441
let raw_entry = conn
442442
.interact(move |conn| {
@@ -595,7 +595,7 @@ pub async fn get_ohlc(
595595
pool: &deadpool_diesel::postgres::Pool,
596596
pair_id: String,
597597
interval: Interval,
598-
time: u64,
598+
time: i64,
599599
) -> Result<Vec<OHLCEntry>, InfraError> {
600600
let conn = pool.get().await.map_err(adapt_infra_error)?;
601601

@@ -682,7 +682,7 @@ pub async fn get_ohlc(
682682
}
683683
};
684684

685-
let date_time = DateTime::from_timestamp(time as i64, 0).ok_or(InfraError::InvalidTimeStamp)?;
685+
let date_time = DateTime::from_timestamp(time, 0).ok_or(InfraError::InvalidTimeStamp)?;
686686

687687
let raw_entries = conn
688688
.interact(move |conn| {

pragma-node/src/main.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use deadpool_diesel::postgres::Pool;
22
use pragma_entities::connection::{ENV_POSTGRES_DATABASE_URL, ENV_TS_DATABASE_URL};
33
use std::net::SocketAddr;
4+
use tower_http::cors::CorsLayer;
5+
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
46
use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme};
57
use utoipa::Modify;
68
use utoipa::OpenApi;
@@ -110,7 +112,15 @@ async fn main() {
110112
postgres_pool,
111113
};
112114

113-
let app = app_router::<ApiDoc>(state.clone()).with_state(state);
115+
let app = app_router::<ApiDoc>(state.clone())
116+
.with_state(state)
117+
// Logging so we can see whats going on
118+
.layer(
119+
TraceLayer::new_for_http()
120+
.make_span_with(DefaultMakeSpan::default().include_headers(true)),
121+
)
122+
// Permissive CORS layer to allow all origins
123+
.layer(CorsLayer::permissive());
114124

115125
let host = config.server_host();
116126
let port = config.server_port();

pragma-node/src/utils/doc_examples.rs

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
use crate::utils::types::UnixTimestamp;
2+
3+
/// Example value for a unix timestamp
4+
pub fn timestamp_example() -> UnixTimestamp {
5+
const STATIC_UNIX_TIMESTAMP: UnixTimestamp = 1717632000; // Thursday, 6 June 2024 00:00:00 GMT
6+
STATIC_UNIX_TIMESTAMP
7+
}

pragma-node/src/utils/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ pub use conversion::{convert_via_quote, format_bigdecimal_price, normalize_to_de
22
pub use custom_extractors::json_extractor::JsonExtractor;
33
pub use custom_extractors::path_extractor::PathExtractor;
44
pub use signing::typed_data::TypedData;
5+
pub use types::UnixTimestamp;
56

67
mod conversion;
78
mod custom_extractors;
9+
pub mod doc_examples;
810
mod signing;
11+
mod types;

pragma-node/src/utils/types.rs

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
/// The number of seconds since the Unix epoch (00:00:00 UTC on 1 Jan 1970). The timestamp is
2+
/// always positive, but represented as a signed integer because that's the standard on Unix
3+
/// systems and allows easy subtraction to compute durations.
4+
pub type UnixTimestamp = i64;

0 commit comments

Comments
 (0)