1use crate::error::{Result, RuntimeError};
5use crate::traits::{
6 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, Message, RerankerModel,
7 ScoredDoc,
8};
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Circuit-breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Normal operation; calls pass through.
    Closed,
    /// Failing fast; calls are rejected until the wait period elapses.
    Open,
    /// Recovery check: a single probe call is allowed through.
    HalfOpen,
}
20
/// Tunables for [`CircuitBreakerWrapper`].
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures (while closed) that trips the circuit open.
    pub failure_threshold: u32,
    /// Seconds the circuit stays open before allowing a single half-open probe.
    pub open_wait_seconds: u64,
}
28
29impl Default for CircuitBreakerConfig {
30 fn default() -> Self {
31 Self {
32 failure_threshold: 5,
33 open_wait_seconds: 10,
34 }
35 }
36}
37
/// Mutable circuit-breaker state, guarded by the wrapper's `Mutex`.
struct Inner {
    state: State,
    /// Consecutive failure count; reset to 0 on a success while closed.
    failures: u32,
    /// Timestamp of the most recent failure; drives the open -> half-open timer.
    last_failure: Option<Instant>,
    config: CircuitBreakerConfig,
    /// True while a half-open probe call is executing, so only one probe runs.
    half_open_probe_in_flight: bool,
}
45
/// Cloneable circuit breaker; all clones share the same state via `Arc`.
#[derive(Clone)]
pub struct CircuitBreakerWrapper {
    inner: Arc<Mutex<Inner>>,
}
56
impl CircuitBreakerWrapper {
    /// Creates a breaker in the `Closed` state with zero recorded failures.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                state: State::Closed,
                failures: 0,
                last_failure: None,
                config,
                half_open_probe_in_flight: false,
            })),
        }
    }

    /// Runs `f` through the breaker.
    ///
    /// - `Closed`: the call always runs; a success resets the failure count,
    ///   and reaching `failure_threshold` failures opens the circuit.
    /// - `Open`: callers get `RuntimeError::Unavailable` until
    ///   `open_wait_seconds` have elapsed since the last failure, at which
    ///   point the breaker moves to `HalfOpen`.
    /// - `HalfOpen`: at most one probe call runs at a time; a successful probe
    ///   closes the circuit, a failed probe re-opens it.
    ///
    /// The mutex guard is dropped before `f` is awaited, so the wrapped
    /// operation never runs while holding the lock.
    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        let is_probe_call;

        {
            // Admission check: decide under the lock whether this call may run.
            let mut inner = self.inner.lock().unwrap();
            match inner.state {
                State::Open => {
                    if let Some(last) = inner.last_failure {
                        if last.elapsed() >= Duration::from_secs(inner.config.open_wait_seconds) {
                            // Wait period elapsed: eligible for a half-open probe.
                            inner.state = State::HalfOpen;
                        } else {
                            // Still within the open window: fail fast.
                            return Err(RuntimeError::Unavailable);
                        }
                    }
                }
                State::HalfOpen => {
                    if inner.half_open_probe_in_flight {
                        // Another probe already holds the half-open slot.
                        return Err(RuntimeError::Unavailable);
                    }
                }
                State::Closed => {}
            }
            // This call is the probe iff the breaker is (now) half-open.
            is_probe_call = inner.state == State::HalfOpen;
            if is_probe_call {
                inner.half_open_probe_in_flight = true;
            }
        }

        // Run the wrapped operation without holding the lock.
        let result = f().await;

        // Outcome bookkeeping, again under the lock.
        let mut inner = self.inner.lock().unwrap();
        match result {
            Ok(val) => {
                if is_probe_call {
                    // Successful probe: full recovery.
                    inner.state = State::Closed;
                    inner.failures = 0;
                    inner.half_open_probe_in_flight = false;
                } else if inner.state == State::Closed {
                    // Success while closed resets the consecutive-failure count.
                    inner.failures = 0;
                }
                Ok(val)
            }
            Err(e) => {
                if is_probe_call {
                    // Release the half-open slot even though the probe failed.
                    inner.half_open_probe_in_flight = false;
                }
                inner.failures += 1;
                inner.last_failure = Some(Instant::now());

                // A failed probe re-opens immediately; while closed, open only
                // once the failure threshold is reached.
                if is_probe_call
                    || (inner.state == State::Closed
                        && inner.failures >= inner.config.failure_threshold)
                {
                    inner.state = State::Open;
                }
                Err(e)
            }
        }
    }
}
143
/// [`EmbeddingModel`] decorator adding per-attempt timeout, retry with
/// backoff, and duration/outcome metrics around the wrapped model.
pub struct InstrumentedEmbeddingModel {
    /// The wrapped model that performs the actual embedding.
    pub inner: Arc<dyn EmbeddingModel>,
    /// Model alias used as a metric and log label.
    pub alias: String,
    /// Provider identifier used as a metric label.
    pub provider_id: String,
    /// Optional per-attempt timeout; `None` means no timeout is applied.
    pub timeout: Option<Duration>,
    /// Optional retry policy; `None` means exactly one attempt.
    pub retry: Option<crate::api::RetryConfig>,
}
154
155#[async_trait]
156impl EmbeddingModel for InstrumentedEmbeddingModel {
157 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
158 let start = Instant::now();
159 let mut attempts = 0;
160 let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
161
162 let res = loop {
163 attempts += 1;
164 let fut = self.inner.embed(texts.clone());
165
166 let res = if let Some(timeout) = self.timeout {
167 match tokio::time::timeout(timeout, fut).await {
168 Ok(r) => r,
169 Err(_) => Err(RuntimeError::Timeout),
170 }
171 } else {
172 fut.await
173 };
174
175 match res {
176 Ok(val) => break Ok(val),
177 Err(e) if e.is_retryable() && attempts < max_attempts => {
178 let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
179 tracing::warn!(
180 alias = %self.alias,
181 attempt = attempts,
182 backoff_ms = backoff.as_millis(),
183 error = %e,
184 "Retrying embedding call"
185 );
186 tokio::time::sleep(backoff).await;
187 continue;
188 }
189 Err(e) => break Err(e),
190 }
191 };
192
193 let duration = start.elapsed();
194 let status = if res.is_ok() { "success" } else { "failure" };
195
196 metrics::histogram!(
197 "model_inference.duration_seconds",
198 "alias" => self.alias.clone(),
199 "task" => "embed",
200 "provider" => self.provider_id.clone()
201 )
202 .record(duration.as_secs_f64());
203
204 metrics::counter!(
205 "model_inference.total",
206 "alias" => self.alias.clone(),
207 "task" => "embed",
208 "provider" => self.provider_id.clone(),
209 "status" => status
210 )
211 .increment(1);
212
213 res
214 }
215
216 fn dimensions(&self) -> u32 {
217 self.inner.dimensions()
218 }
219
220 fn model_id(&self) -> &str {
221 self.inner.model_id()
222 }
223
224 async fn warmup(&self) -> Result<()> {
225 self.inner.warmup().await
226 }
227}
228
/// [`GeneratorModel`] decorator adding per-attempt timeout, retry with
/// backoff, and duration/outcome metrics around the wrapped model.
pub struct InstrumentedGeneratorModel {
    /// The wrapped model that performs the actual generation.
    pub inner: Arc<dyn GeneratorModel>,
    /// Model alias used as a metric and log label.
    pub alias: String,
    /// Provider identifier used as a metric label.
    pub provider_id: String,
    /// Optional per-attempt timeout; `None` means no timeout is applied.
    pub timeout: Option<Duration>,
    /// Optional retry policy; `None` means exactly one attempt.
    pub retry: Option<crate::api::RetryConfig>,
}
239
240#[async_trait]
241impl GeneratorModel for InstrumentedGeneratorModel {
242 async fn generate(
243 &self,
244 messages: &[Message],
245 options: GenerationOptions,
246 ) -> Result<GenerationResult> {
247 let start = Instant::now();
248 let mut attempts = 0;
249 let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
250
251 let res = loop {
252 attempts += 1;
253 let fut = self.inner.generate(messages, options.clone());
254
255 let res = if let Some(timeout) = self.timeout {
256 match tokio::time::timeout(timeout, fut).await {
257 Ok(r) => r,
258 Err(_) => Err(RuntimeError::Timeout),
259 }
260 } else {
261 fut.await
262 };
263
264 match res {
265 Ok(val) => break Ok(val),
266 Err(e) if e.is_retryable() && attempts < max_attempts => {
267 let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
268 tracing::warn!(
269 alias = %self.alias,
270 attempt = attempts,
271 backoff_ms = backoff.as_millis(),
272 error = %e,
273 "Retrying generation call"
274 );
275 tokio::time::sleep(backoff).await;
276 continue;
277 }
278 Err(e) => break Err(e),
279 }
280 };
281
282 let duration = start.elapsed();
283 let status = if res.is_ok() { "success" } else { "failure" };
284
285 metrics::histogram!(
286 "model_inference.duration_seconds",
287 "alias" => self.alias.clone(),
288 "task" => "generate",
289 "provider" => self.provider_id.clone()
290 )
291 .record(duration.as_secs_f64());
292
293 metrics::counter!(
294 "model_inference.total",
295 "alias" => self.alias.clone(),
296 "task" => "generate",
297 "provider" => self.provider_id.clone(),
298 "status" => status
299 )
300 .increment(1);
301
302 res
303 }
304
305 async fn warmup(&self) -> Result<()> {
306 self.inner.warmup().await
307 }
308}
309
/// [`RerankerModel`] decorator adding per-attempt timeout, retry with
/// backoff, and duration/outcome metrics around the wrapped model.
pub struct InstrumentedRerankerModel {
    /// The wrapped model that performs the actual reranking.
    pub inner: Arc<dyn RerankerModel>,
    /// Model alias used as a metric and log label.
    pub alias: String,
    /// Provider identifier used as a metric label.
    pub provider_id: String,
    /// Optional per-attempt timeout; `None` means no timeout is applied.
    pub timeout: Option<Duration>,
    /// Optional retry policy; `None` means exactly one attempt.
    pub retry: Option<crate::api::RetryConfig>,
}
320
321#[async_trait]
322impl RerankerModel for InstrumentedRerankerModel {
323 async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
324 let start = Instant::now();
325 let mut attempts = 0;
326 let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
327
328 let res = loop {
329 attempts += 1;
330 let fut = self.inner.rerank(query, docs);
331
332 let res = if let Some(timeout) = self.timeout {
333 match tokio::time::timeout(timeout, fut).await {
334 Ok(r) => r,
335 Err(_) => Err(RuntimeError::Timeout),
336 }
337 } else {
338 fut.await
339 };
340
341 match res {
342 Ok(val) => break Ok(val),
343 Err(e) if e.is_retryable() && attempts < max_attempts => {
344 let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
345 tracing::warn!(
346 alias = %self.alias,
347 attempt = attempts,
348 backoff_ms = backoff.as_millis(),
349 error = %e,
350 "Retrying rerank call"
351 );
352 tokio::time::sleep(backoff).await;
353 continue;
354 }
355 Err(e) => break Err(e),
356 }
357 };
358
359 let duration = start.elapsed();
360 let status = if res.is_ok() { "success" } else { "failure" };
361
362 metrics::histogram!(
363 "model_inference.duration_seconds",
364 "alias" => self.alias.clone(),
365 "task" => "rerank",
366 "provider" => self.provider_id.clone()
367 )
368 .record(duration.as_secs_f64());
369
370 metrics::counter!(
371 "model_inference.total",
372 "alias" => self.alias.clone(),
373 "task" => "rerank",
374 "provider" => self.provider_id.clone(),
375 "status" => status
376 )
377 .increment(1);
378
379 res
380 }
381
382 async fn warmup(&self) -> Result<()> {
383 self.inner.warmup().await
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Walks the breaker through Closed -> Open -> HalfOpen (failed probe)
    /// -> Open -> HalfOpen (successful probe) -> Closed.
    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);
        let counter = Arc::new(AtomicU32::new(0));

        // Closed: a success passes through.
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());

        // Two consecutive failures reach the threshold and open the circuit.
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

        // Open: calls are rejected without ever invoking the closure.
        let res = cb
            .call(|| async {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Wait out the open window so the next call becomes a half-open probe.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // A failing probe re-opens the circuit immediately.
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

        // Back in Open: even a would-be success is rejected unrun.
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");

        tokio::time::sleep(Duration::from_millis(1100)).await;

        // A successful probe closes the circuit; subsequent calls pass.
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());

        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());
    }

    /// While a half-open probe is in flight, concurrent calls are rejected
    /// and the probe runs exactly once.
    #[tokio::test]
    async fn test_half_open_allows_single_probe() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);

        // A single failure trips the circuit (threshold = 1).
        let _ = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;

        // Wait out the open window so the breaker is eligible for half-open.
        tokio::time::sleep(Duration::from_millis(1100)).await;

        let started = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let finished = Arc::new(std::sync::atomic::AtomicU32::new(0));

        // Launch a slow probe in the background.
        let cb_probe = cb.clone();
        let started_probe = started.clone();
        let finished_probe = finished.clone();
        let probe = tokio::spawn(async move {
            cb_probe
                .call(|| async move {
                    started_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(150)).await;
                    finished_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Ok::<_, RuntimeError>(())
                })
                .await
        });

        // Give the probe time to claim the half-open slot.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // A second call while the probe is in flight is rejected.
        let second = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(matches!(second, Err(RuntimeError::Unavailable)));

        // The probe itself completes successfully and ran exactly once.
        let probe_result = probe.await.unwrap();
        assert!(probe_result.is_ok());
        assert_eq!(started.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(finished.load(std::sync::atomic::Ordering::SeqCst), 1);

        // The successful probe closed the circuit; normal calls pass again.
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());
    }
}