// uni_xervo/reliability.rs

//! Reliability primitives: circuit breaker, instrumented model wrappers with
//! timeout and retry support, and metrics emission.

use crate::error::{Result, RuntimeError};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, RerankerModel, ScoredDoc,
};
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Internal circuit breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    Closed,
    Open,
    HalfOpen,
}

/// Tunable parameters for the circuit breaker.
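///
/// A configuration sketch; the values shown are illustrative, not recommended defaults:
///
/// ```ignore
/// // Open after 3 consecutive failures and probe again after 30 seconds.
/// let config = CircuitBreakerConfig { failure_threshold: 3, open_wait_seconds: 30 };
/// ```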
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures before the breaker opens.
    pub failure_threshold: u32,
    /// Seconds to wait in the open state before allowing a probe call.
    pub open_wait_seconds: u64,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            open_wait_seconds: 10,
        }
    }
}

struct Inner {
    state: State,
    failures: u32,
    last_failure: Option<Instant>,
    config: CircuitBreakerConfig,
    half_open_probe_in_flight: bool,
}

/// Thread-safe circuit breaker that tracks failures and short-circuits calls
/// when a provider is unhealthy.
///
/// State transitions: **Closed** -> (failures >= threshold) -> **Open** ->
/// (wait period elapsed) -> **HalfOpen** -> (probe succeeds) -> **Closed**
/// (or probe fails -> back to **Open**).
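///
/// A minimal usage sketch; `provider_call` stands in for any async operation
/// returning this crate's [`Result`]:
///
/// ```ignore
/// let breaker = CircuitBreakerWrapper::new(CircuitBreakerConfig::default());
/// // Route every provider call through the breaker so failures are tracked.
/// let value = breaker.call(|| provider_call()).await?;
/// ```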
#[derive(Clone)]
pub struct CircuitBreakerWrapper {
    inner: Arc<Mutex<Inner>>,
}

impl CircuitBreakerWrapper {
    /// Create a new circuit breaker with the given configuration.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Inner {
                state: State::Closed,
                failures: 0,
                last_failure: None,
                config,
                half_open_probe_in_flight: false,
            })),
        }
    }

    /// Execute `f` through the circuit breaker.
    ///
    /// Returns [`RuntimeError::Unavailable`] immediately when the breaker is
    /// open. In the half-open state only a single probe call is allowed;
    /// concurrent callers receive `Unavailable` until the probe completes.
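    ///
    /// A handling sketch; `fetch` is a hypothetical async provider call:
    ///
    /// ```ignore
    /// match breaker.call(|| fetch()).await {
    ///     Ok(value) => { /* provider succeeded */ }
    ///     Err(RuntimeError::Unavailable) => { /* breaker open or probe in flight: fail fast */ }
    ///     Err(other) => { /* provider error, counted toward the failure threshold */ }
    /// }
    /// ```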
    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        let is_probe_call;

        // 1. Check state
        {
            let mut inner = self.inner.lock().unwrap();
            match inner.state {
                State::Open => {
                    if let Some(last) = inner.last_failure {
                        if last.elapsed() >= Duration::from_secs(inner.config.open_wait_seconds) {
                            inner.state = State::HalfOpen;
                        } else {
                            return Err(RuntimeError::Unavailable);
                        }
                    }
                }
                State::HalfOpen => {
                    if inner.half_open_probe_in_flight {
                        return Err(RuntimeError::Unavailable);
                    }
                }
                State::Closed => {}
            }
            is_probe_call = inner.state == State::HalfOpen;
            if is_probe_call {
                inner.half_open_probe_in_flight = true;
            }
        }

        // 2. Execute
        let result = f().await;

        // 3. Update state
        let mut inner = self.inner.lock().unwrap();
        match result {
            Ok(val) => {
                if is_probe_call {
                    inner.state = State::Closed;
                    inner.failures = 0;
                    inner.half_open_probe_in_flight = false;
                } else if inner.state == State::Closed {
                    inner.failures = 0;
                }
                Ok(val)
            }
            Err(e) => {
                if is_probe_call {
                    inner.half_open_probe_in_flight = false;
                }
                inner.failures += 1;
                inner.last_failure = Some(Instant::now());

                if is_probe_call
                    || (inner.state == State::Closed
                        && inner.failures >= inner.config.failure_threshold)
                {
                    inner.state = State::Open;
                }
                Err(e)
            }
        }
    }
}

/// Wrapper around an [`EmbeddingModel`] that adds per-call timeout enforcement,
/// exponential-backoff retries for transient errors, and metrics emission
/// (`model_inference.duration_seconds`, `model_inference.total`).
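///
/// A construction sketch; `MyEmbedder` is a hypothetical [`EmbeddingModel`]
/// implementation and the timeout value is illustrative:
///
/// ```ignore
/// let wrapped = InstrumentedEmbeddingModel {
///     inner: Arc::new(MyEmbedder::default()),
///     alias: "minilm".into(),
///     provider_id: "local".into(),
///     timeout: Some(Duration::from_secs(5)),
///     retry: None, // or Some(retry_config) to enable exponential-backoff retries
/// };
/// let vectors = wrapped.embed(vec!["hello world"]).await?;
/// ```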
pub struct InstrumentedEmbeddingModel {
    pub inner: Arc<dyn EmbeddingModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl EmbeddingModel for InstrumentedEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

        let res = loop {
            attempts += 1;
            let fut = self.inner.embed(texts.clone());

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying embedding call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "embed",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "embed",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    fn dimensions(&self) -> u32 {
        self.inner.dimensions()
    }

    fn model_id(&self) -> &str {
        self.inner.model_id()
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

/// Wrapper around a [`GeneratorModel`] that adds timeout, retry, and metrics.
///
/// See [`InstrumentedEmbeddingModel`] for details on the instrumentation behavior.
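///
/// A usage sketch; `wrapped` is an already-constructed instance and
/// `GenerationOptions::default()` is assumed to be available:
///
/// ```ignore
/// let messages = vec!["Summarize the release notes.".to_string()];
/// let result = wrapped.generate(&messages, GenerationOptions::default()).await?;
/// ```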
pub struct InstrumentedGeneratorModel {
    pub inner: Arc<dyn GeneratorModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl GeneratorModel for InstrumentedGeneratorModel {
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

        let res = loop {
            attempts += 1;
            let fut = self.inner.generate(messages, options.clone());

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying generation call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "generate",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "generate",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

/// Wrapper around a [`RerankerModel`] that adds timeout, retry, and metrics.
///
/// See [`InstrumentedEmbeddingModel`] for details on the instrumentation behavior.
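///
/// A usage sketch; `wrapped` is an already-constructed instance:
///
/// ```ignore
/// let ranked = wrapped.rerank("circuit breaker pattern", &["doc a", "doc b"]).await?;
/// ```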
pub struct InstrumentedRerankerModel {
    pub inner: Arc<dyn RerankerModel>,
    pub alias: String,
    pub provider_id: String,
    pub timeout: Option<Duration>,
    pub retry: Option<crate::api::RetryConfig>,
}

#[async_trait]
impl RerankerModel for InstrumentedRerankerModel {
    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
        let start = Instant::now();
        let mut attempts = 0;
        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);

        let res = loop {
            attempts += 1;
            let fut = self.inner.rerank(query, docs);

            let res = if let Some(timeout) = self.timeout {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(r) => r,
                    Err(_) => Err(RuntimeError::Timeout),
                }
            } else {
                fut.await
            };

            match res {
                Ok(val) => break Ok(val),
                Err(e) if e.is_retryable() && attempts < max_attempts => {
                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
                    tracing::warn!(
                        alias = %self.alias,
                        attempt = attempts,
                        backoff_ms = backoff.as_millis(),
                        error = %e,
                        "Retrying rerank call"
                    );
                    tokio::time::sleep(backoff).await;
                    continue;
                }
                Err(e) => break Err(e),
            }
        };

        let duration = start.elapsed();
        let status = if res.is_ok() { "success" } else { "failure" };

        metrics::histogram!(
            "model_inference.duration_seconds",
            "alias" => self.alias.clone(),
            "task" => "rerank",
            "provider" => self.provider_id.clone()
        )
        .record(duration.as_secs_f64());

        metrics::counter!(
            "model_inference.total",
            "alias" => self.alias.clone(),
            "task" => "rerank",
            "provider" => self.provider_id.clone(),
            "status" => status
        )
        .increment(1);

        res
    }

    async fn warmup(&self) -> Result<()> {
        self.inner.warmup().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);
        let counter = Arc::new(AtomicU32::new(0));

        // 1. Success calls - state remains Closed
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());

        // 2. Failures - state transitions to Open
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err()); // Fail 1

        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err()); // Fail 2 -> Open

        // 3. Open state - calls rejected immediately
        let res = cb
            .call(|| async {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");
        assert_eq!(counter.load(Ordering::SeqCst), 0); // Should not have run

        // 4. Wait for HalfOpen
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // 5. HalfOpen - allow one call
        // If it fails, go back to Open
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

        // Should be Open again
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");

        // 6. Wait again for HalfOpen
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // 7. Success - transition to Closed
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());

        // Should be closed now, next call works
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());
    }

    #[tokio::test]
    async fn test_half_open_allows_single_probe() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);

        // Open breaker.
        let _ = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;

        tokio::time::sleep(Duration::from_millis(1100)).await;

        let started = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let finished = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let cb_probe = cb.clone();
        let started_probe = started.clone();
        let finished_probe = finished.clone();
        let probe = tokio::spawn(async move {
            cb_probe
                .call(|| async move {
                    started_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(150)).await;
                    finished_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Ok::<_, RuntimeError>(())
                })
                .await
        });

        // Allow the first probe to enter.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // A concurrent call during half-open probe should fail fast.
        let second = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(matches!(second, Err(RuntimeError::Unavailable)));

        let probe_result = probe.await.unwrap();
        assert!(probe_result.is_ok());
        assert_eq!(started.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(finished.load(std::sync::atomic::Ordering::SeqCst), 1);

        // Closed again.
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());
    }
}