use crate::error::{Result, RuntimeError};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, RerankerModel, ScoredDoc,
};
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

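/// State of the breaker: requests flow freely in `Closed`, are rejected
/// in `Open`, and exactly one probe request is admitted in `HalfOpen`.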
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    Closed,
    Open,
    HalfOpen,
}

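/// Tuning knobs for [`CircuitBreakerWrapper`].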
pub struct CircuitBreakerConfig {
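    /// Consecutive failures required to trip the breaker open.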
    pub failure_threshold: u32,
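    /// Seconds to remain open before allowing a half-open probe.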
    pub open_wait_seconds: u64,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            open_wait_seconds: 10,
        }
    }
}

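/// Mutable breaker state, shared behind a mutex by all clones of the wrapper.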
struct Inner {
    state: State,
    failures: u32,
    last_failure: Option<Instant>,
    config: CircuitBreakerConfig,
    half_open_probe_in_flight: bool,
}

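/// A minimal circuit breaker for async model calls.
///
/// The breaker starts closed. After `failure_threshold` consecutive failures
/// it opens and rejects calls with [`RuntimeError::Unavailable`]. Once
/// `open_wait_seconds` have elapsed it becomes half-open and admits exactly
/// one probe call: success closes the breaker, failure re-opens it.
///
/// Usage sketch (`fetch_embeddings` is a placeholder for any future that
/// returns this crate's `Result`):
///
/// ```ignore
/// let cb = CircuitBreakerWrapper::new(CircuitBreakerConfig::default());
/// let out = cb.call(|| async { fetch_embeddings().await }).await;
/// ```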
#[derive(Clone)]
pub struct CircuitBreakerWrapper {
    inner: Arc<Mutex<Inner>>,
}

impl CircuitBreakerWrapper {
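    /// Creates a breaker in the closed state with the given thresholds.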
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                state: State::Closed,
                failures: 0,
                last_failure: None,
                config,
                half_open_probe_in_flight: false,
            })),
        }
    }

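    /// Runs `f` through the breaker and updates the state from its outcome.
    ///
    /// Fails fast with [`RuntimeError::Unavailable`] while the breaker is
    /// open (until the open wait has elapsed) or when a half-open probe is
    /// already in flight.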
    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        let is_probe_call;

        {
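            // Decide under the lock whether this call may proceed; the lock
            // is dropped before the future is awaited.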
            let mut inner = self.inner.lock().unwrap();
            match inner.state {
                State::Open => {
                    if let Some(last) = inner.last_failure {
                        if last.elapsed() >= Duration::from_secs(inner.config.open_wait_seconds) {
                            inner.state = State::HalfOpen;
                        } else {
                            return Err(RuntimeError::Unavailable);
                        }
                    }
                }
                State::HalfOpen => {
                    if inner.half_open_probe_in_flight {
                        return Err(RuntimeError::Unavailable);
                    }
                }
                State::Closed => {}
            }
            is_probe_call = inner.state == State::HalfOpen;
            if is_probe_call {
                inner.half_open_probe_in_flight = true;
            }
        }

        let result = f().await;

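        // Re-acquire the lock and record the outcome.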
        let mut inner = self.inner.lock().unwrap();
        match result {
            Ok(val) => {
                if is_probe_call {
                    inner.state = State::Closed;
                    inner.failures = 0;
                    inner.half_open_probe_in_flight = false;
                } else if inner.state == State::Closed {
                    inner.failures = 0;
                }
                Ok(val)
            }
            Err(e) => {
                if is_probe_call {
                    inner.half_open_probe_in_flight = false;
                }
                inner.failures += 1;
                inner.last_failure = Some(Instant::now());

                if is_probe_call
                    || (inner.state == State::Closed
                        && inner.failures >= inner.config.failure_threshold)
                {
                    inner.state = State::Open;
                }
                Err(e)
            }
        }
    }
}

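/// Wraps an [`EmbeddingModel`] with a per-call timeout, retries of retryable
/// errors with backoff, and latency/outcome metrics keyed by alias and
/// provider.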
pub struct InstrumentedEmbeddingModel {
    pub inner: Arc<dyn EmbeddingModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl EmbeddingModel for InstrumentedEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

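        // Retry loop: each attempt is bounded by the configured timeout, and
        // only retryable errors are retried. `retry` is always Some when
        // attempts < max_attempts, so the unwrap below cannot panic.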
        let res = loop {
            attempts += 1;
            let fut = self.inner.embed(texts.clone());

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying embedding call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

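        // Emit latency and outcome metrics regardless of success or failure.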
        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "embed",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "embed",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    fn dimensions(&self) -> u32 {
        self.inner.dimensions()
    }

    fn model_id(&self) -> &str {
        self.inner.model_id()
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

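/// Wraps a [`GeneratorModel`] with the same timeout, retry, and metrics
/// behavior as [`InstrumentedEmbeddingModel`].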
pub struct InstrumentedGeneratorModel {
    pub inner: Arc<dyn GeneratorModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl GeneratorModel for InstrumentedGeneratorModel {
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

        let res = loop {
            attempts += 1;
            let fut = self.inner.generate(messages, options.clone());

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying generation call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "generate",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "generate",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

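/// Wraps a [`RerankerModel`] with the same timeout, retry, and metrics
/// behavior as the other instrumented wrappers.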
pub struct InstrumentedRerankerModel {
    pub inner: Arc<dyn RerankerModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl RerankerModel for InstrumentedRerankerModel {
    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

        let res = loop {
            attempts += 1;
            let fut = self.inner.rerank(query, docs);

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying rerank call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "rerank",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "rerank",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);
        let counter = Arc::new(AtomicU32::new(0));

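        // A successful call leaves the breaker closed.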
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());

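        // First failure: still below the threshold of 2.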
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

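        // Second consecutive failure trips the breaker open.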
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

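        // The breaker is open: the call is rejected and the closure never runs.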
        let res = cb
            .call(|| async {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");
        assert_eq!(counter.load(Ordering::SeqCst), 0);

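        // Wait out the open window so the breaker moves to half-open.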
        tokio::time::sleep(Duration::from_millis(1100)).await;

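        // Half-open: the probe fails, so the breaker re-opens immediately.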
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

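        // Still open after the failed probe: calls are rejected again.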
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");

        tokio::time::sleep(Duration::from_millis(1100)).await;

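        // After the wait, a successful probe closes the breaker.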
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());

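        // Closed again: subsequent calls pass straight through.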
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn test_half_open_allows_single_probe() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);

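        // A single failure trips the breaker (threshold is 1).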
        let _ = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;

        tokio::time::sleep(Duration::from_millis(1100)).await;

        let started = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let finished = Arc::new(std::sync::atomic::AtomicU32::new(0));

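        // Spawn a slow probe that occupies the single half-open slot.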
        let cb_probe = cb.clone();
        let started_probe = started.clone();
        let finished_probe = finished.clone();
        let probe = tokio::spawn(async move {
            cb_probe
                .call(|| async move {
                    started_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(150)).await;
                    finished_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Ok::<_, RuntimeError>(())
                })
                .await
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

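        // While the probe is in flight, a second call is rejected.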
        let second = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(matches!(second, Err(RuntimeError::Unavailable)));

        let probe_result = probe.await.unwrap();
        assert!(probe_result.is_ok());
        assert_eq!(started.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(finished.load(std::sync::atomic::Ordering::SeqCst), 1);

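        // The successful probe closed the breaker; calls now pass.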
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());
    }
}