1#![warn(missing_debug_implementations, missing_docs, unreachable_pub)]
5
6use crate::filter::AsyncFilter;
7use futures_util::future::Either;
8use pin_project_lite::pin_project;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use std::{
12 future,
13 pin::Pin,
14 task::{Context, Poll},
15};
16use tracing::error;
17
18mod delay;
19mod latency;
20mod rotating_histogram;
21mod select;
22
23use delay::Delay;
24use latency::Latency;
25use rotating_histogram::RotatingHistogram;
26use select::Select;
27
/// Latency histogram shared, behind a mutex, by both latency-recording
/// services and by the delay and select policies built in `Hedge::new_with_histo`.
type Histo = Arc<Mutex<RotatingHistogram>>;
/// The composed hedging service.
///
/// A `Select` between:
/// - the original service wrapped in `Latency`, which records into `Histo`;
/// - a second `Latency`-wrapped copy of the service, gated by
///   `AsyncFilter`/`PolicyPredicate` and delayed by `DelayPolicy`.
type Service<S, P> = select::Select<
    SelectPolicy<P>,
    Latency<Histo, S>,
    Delay<DelayPolicy, AsyncFilter<Latency<Histo, S>, PolicyPredicate<P>>>,
>;
34
/// A middleware that sends a second ("hedge") copy of a request after a
/// latency-derived delay.
///
/// Internally this wraps a `select::Select` between the original service and a
/// delayed, policy-filtered clone of it (see `Hedge::new`). Which response is
/// ultimately returned is decided by `select::Select`. That is presumably the
/// first one to complete; confirm in the `select` module.
#[derive(Debug)]
pub struct Hedge<S, P>(Service<S, P>);
40
pin_project! {
    /// Response future returned by `Hedge`'s `Service::call`.
    ///
    /// It wraps the inner service's future and converts that future's error
    /// into `crate::BoxError` when polled.
    #[derive(Debug)]
    pub struct Future<S, Request>
    where
        S: tower_service::Service<Request>,
    {
        // Structurally pinned so `poll` can forward through `project()`.
        #[pin]
        inner: S::Future,
    }
}
54
/// User-supplied policy that decides whether, and with what request, a hedge
/// request may be sent.
pub trait Policy<Request> {
    /// Produce a copy of `req` to be used as the hedge request.
    ///
    /// Returning `None` means this request is never hedged.
    fn clone_request(&self, req: &Request) -> Option<Request>;

    /// Decide whether the cloned hedge request should actually be sent.
    ///
    /// This is consulted by the `AsyncFilter` that sits behind the hedge delay,
    /// so presumably it runs once the delay has elapsed. When it returns
    /// `false`, the hedge leg never completes and only the original request can
    /// produce a response.
    fn can_retry(&self, req: &Request) -> bool;
}
64
#[doc(hidden)]
// Adapts a `Policy` into an `AsyncPredicate` for `AsyncFilter`: the hedge
// request passes through only if `Policy::can_retry` returns `true`.
// It is `pub` only because it appears in the `Service` alias, which is
// exposed through `Hedge`'s `Future` type.
#[derive(Clone, Debug)]
pub struct PolicyPredicate<P>(P);
70
#[doc(hidden)]
// Delay policy for the hedge leg. The delay is the `latency_percentile`
// quantile of the shared histogram, interpreted as milliseconds.
#[derive(Debug)]
pub struct DelayPolicy {
    // Histogram shared with both `Latency` recorders and `SelectPolicy`.
    histo: Histo,
    // Passed (as f64) to `value_at_quantile`; presumably in [0.0, 1.0] — TODO confirm.
    latency_percentile: f32,
}
77
#[doc(hidden)]
// Select policy. It builds the hedge request via the user's `Policy`, but
// only once the shared histogram holds at least `min_data_points` samples.
#[derive(Debug)]
pub struct SelectPolicy<P> {
    // User policy consulted for `clone_request`.
    policy: P,
    // Histogram shared with the `Latency` recorders and `DelayPolicy`.
    histo: Histo,
    // No hedging until this many latencies have been recorded.
    min_data_points: u64,
}
85
86impl<S, P> Hedge<S, P> {
87 pub fn new<Request>(
93 service: S,
94 policy: P,
95 min_data_points: u64,
96 latency_percentile: f32,
97 period: Duration,
98 ) -> Hedge<S, P>
99 where
100 S: tower_service::Service<Request> + Clone,
101 S::Error: Into<crate::BoxError>,
102 P: Policy<Request> + Clone,
103 {
104 assert!(
105 period > Duration::ZERO,
106 "histogram rotation period must be greater than zero"
107 );
108 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
109 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
110 }
111
112 pub fn new_with_mock_latencies<Request>(
115 service: S,
116 policy: P,
117 min_data_points: u64,
118 latency_percentile: f32,
119 period: Duration,
120 latencies_ms: &[u64],
121 ) -> Hedge<S, P>
122 where
123 S: tower_service::Service<Request> + Clone,
124 S::Error: Into<crate::BoxError>,
125 P: Policy<Request> + Clone,
126 {
127 assert!(
128 period > Duration::ZERO,
129 "histogram rotation period must be greater than zero"
130 );
131
132 let histo = Arc::new(Mutex::new(RotatingHistogram::new(period)));
133 {
134 let mut locked = histo.lock().unwrap();
135 for latency in latencies_ms.iter() {
136 locked.read().record(*latency).unwrap();
137 }
138 }
139 Self::new_with_histo(service, policy, min_data_points, latency_percentile, histo)
140 }
141
142 fn new_with_histo<Request>(
143 service: S,
144 policy: P,
145 min_data_points: u64,
146 latency_percentile: f32,
147 histo: Histo,
148 ) -> Hedge<S, P>
149 where
150 S: tower_service::Service<Request> + Clone,
151 S::Error: Into<crate::BoxError>,
152 P: Policy<Request> + Clone,
153 {
154 let recorded_a = Latency::new(histo.clone(), service.clone());
157 let recorded_b = Latency::new(histo.clone(), service);
158
159 let filtered = AsyncFilter::new(recorded_b, PolicyPredicate(policy.clone()));
161
162 let delay_policy = DelayPolicy {
165 histo: histo.clone(),
166 latency_percentile,
167 };
168 let delayed = Delay::new(delay_policy, filtered);
169
170 let select_policy = SelectPolicy {
173 policy,
174 histo,
175 min_data_points,
176 };
177 Hedge(Select::new(select_policy, recorded_a, delayed))
178 }
179}
180
181impl<S, P, Request> tower_service::Service<Request> for Hedge<S, P>
182where
183 S: tower_service::Service<Request> + Clone,
184 S::Error: Into<crate::BoxError>,
185 P: Policy<Request> + Clone,
186{
187 type Response = S::Response;
188 type Error = crate::BoxError;
189 type Future = Future<Service<S, P>, Request>;
190
191 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192 self.0.poll_ready(cx)
193 }
194
195 fn call(&mut self, request: Request) -> Self::Future {
196 Future {
197 inner: self.0.call(request),
198 }
199 }
200}
201
202impl<S, Request> std::future::Future for Future<S, Request>
203where
204 S: tower_service::Service<Request>,
205 S::Error: Into<crate::BoxError>,
206{
207 type Output = Result<S::Response, crate::BoxError>;
208
209 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
210 self.project().inner.poll(cx).map_err(Into::into)
211 }
212}
213
const NANOS_PER_MILLI: u32 = 1_000_000;
const MILLIS_PER_SEC: u64 = 1_000;

/// Convert `duration` to whole milliseconds, rounding any partial
/// millisecond *up*, and saturating at `u64::MAX` instead of overflowing.
fn millis(duration: Duration) -> u64 {
    let whole_seconds_ms = duration.as_secs().saturating_mul(MILLIS_PER_SEC);
    // Ceiling division. `subsec_nanos() < 1e9`, so the sum cannot overflow u32.
    let rounded_up_ms = (duration.subsec_nanos() + (NANOS_PER_MILLI - 1)) / NANOS_PER_MILLI;
    whole_seconds_ms.saturating_add(u64::from(rounded_up_ms))
}
225
226impl latency::Record for Histo {
227 fn record(&mut self, latency: Duration) {
228 let mut locked = self.lock().unwrap();
229 locked.write().record(millis(latency)).unwrap_or_else(|e| {
230 error!("Failed to write to hedge histogram: {:?}", e);
231 })
232 }
233}
234
235impl<P, Request> crate::filter::AsyncPredicate<Request> for PolicyPredicate<P>
236where
237 P: Policy<Request>,
238{
239 type Future = Either<
240 future::Ready<Result<Request, crate::BoxError>>,
241 future::Pending<Result<Request, crate::BoxError>>,
242 >;
243 type Request = Request;
244
245 fn check(&mut self, request: Request) -> Self::Future {
246 if self.0.can_retry(&request) {
247 Either::Left(future::ready(Ok(request)))
248 } else {
249 Either::Right(future::pending())
254 }
255 }
256}
257
258impl<Request> delay::Policy<Request> for DelayPolicy {
259 fn delay(&self, _req: &Request) -> Duration {
260 let mut locked = self.histo.lock().unwrap();
261 let millis = locked
262 .read()
263 .value_at_quantile(self.latency_percentile.into());
264 Duration::from_millis(millis)
265 }
266}
267
268impl<P, Request> select::Policy<Request> for SelectPolicy<P>
269where
270 P: Policy<Request>,
271{
272 fn clone_request(&self, req: &Request) -> Option<Request> {
273 self.policy.clone_request(req).filter(|_| {
274 let mut locked = self.histo.lock().unwrap();
275 locked.read().len() >= self.min_data_points
278 })
279 }
280}