// uni_xervo/reliability.rs

1//! Reliability primitives: circuit breaker, instrumented model wrappers with
2//! timeout and retry support, and metrics emission.
3
4use crate::error::{Result, RuntimeError};
5use crate::traits::{
6    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, Message, RerankerModel,
7    ScoredDoc,
8};
9use async_trait::async_trait;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Internal circuit breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Calls flow through normally; consecutive failures are counted.
    Closed,
    /// Calls are rejected immediately until the open wait period elapses.
    Open,
    /// A single probe call is allowed through to test whether the provider
    /// has recovered; concurrent callers are rejected while it runs.
    HalfOpen,
}
20
/// Tunable parameters for the circuit breaker.
///
/// Plain-old-data: cheap to copy, comparable, and printable for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures before the breaker opens.
    pub failure_threshold: u32,
    /// Seconds to wait in the open state before allowing a probe call.
    pub open_wait_seconds: u64,
}
28
29impl Default for CircuitBreakerConfig {
30    fn default() -> Self {
31        Self {
32            failure_threshold: 5,
33            open_wait_seconds: 10,
34        }
35    }
36}
37
/// Shared mutable breaker state, always accessed behind the wrapper's mutex.
struct Inner {
    // Current position in the Closed/Open/HalfOpen state machine.
    state: State,
    // Consecutive failure count; compared against `config.failure_threshold`.
    failures: u32,
    // Timestamp of the most recent failure; `None` until the first failure.
    last_failure: Option<Instant>,
    // Thresholds governing the Open and HalfOpen transitions.
    config: CircuitBreakerConfig,
    // True while a half-open probe call is executing; gates concurrent probes.
    half_open_probe_in_flight: bool,
}
45
/// Thread-safe circuit breaker that tracks failures and short-circuits calls
/// when a provider is unhealthy.
///
/// State transitions: **Closed** -> (failures >= threshold) -> **Open** ->
/// (wait period elapsed) -> **HalfOpen** -> (probe succeeds) -> **Closed**
/// (or probe fails -> back to **Open**).
#[derive(Clone)]
pub struct CircuitBreakerWrapper {
    // Shared state; all clones of the wrapper observe the same breaker.
    inner: Arc<Mutex<Inner>>,
}
56
57impl CircuitBreakerWrapper {
58    /// Create a new circuit breaker with the given configuration.
59    pub fn new(config: CircuitBreakerConfig) -> Self {
60        Self {
61            inner: Arc::new(Mutex::new(Inner {
62                state: State::Closed,
63                failures: 0,
64                last_failure: None,
65                config,
66                half_open_probe_in_flight: false,
67            })),
68        }
69    }
70
71    /// Execute `f` through the circuit breaker.
72    ///
73    /// Returns [`RuntimeError::Unavailable`] immediately when the breaker is
74    /// open.  In the half-open state only a single probe call is allowed;
75    /// concurrent callers receive `Unavailable` until the probe completes.
76    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
77    where
78        F: FnOnce() -> Fut,
79        Fut: std::future::Future<Output = Result<T>>,
80    {
81        let is_probe_call;
82
83        // 1. Check state
84        {
85            let mut inner = self.inner.lock().unwrap();
86            match inner.state {
87                State::Open => {
88                    if let Some(last) = inner.last_failure {
89                        if last.elapsed() >= Duration::from_secs(inner.config.open_wait_seconds) {
90                            inner.state = State::HalfOpen;
91                        } else {
92                            return Err(RuntimeError::Unavailable);
93                        }
94                    }
95                }
96                State::HalfOpen => {
97                    if inner.half_open_probe_in_flight {
98                        return Err(RuntimeError::Unavailable);
99                    }
100                }
101                State::Closed => {}
102            }
103            is_probe_call = inner.state == State::HalfOpen;
104            if is_probe_call {
105                inner.half_open_probe_in_flight = true;
106            }
107        }
108
109        // 2. Execute
110        let result = f().await;
111
112        // 3. Update state
113        let mut inner = self.inner.lock().unwrap();
114        match result {
115            Ok(val) => {
116                if is_probe_call {
117                    inner.state = State::Closed;
118                    inner.failures = 0;
119                    inner.half_open_probe_in_flight = false;
120                } else if inner.state == State::Closed {
121                    inner.failures = 0;
122                }
123                Ok(val)
124            }
125            Err(e) => {
126                if is_probe_call {
127                    inner.half_open_probe_in_flight = false;
128                }
129                inner.failures += 1;
130                inner.last_failure = Some(Instant::now());
131
132                if is_probe_call
133                    || (inner.state == State::Closed
134                        && inner.failures >= inner.config.failure_threshold)
135                {
136                    inner.state = State::Open;
137                }
138                Err(e)
139            }
140        }
141    }
142}
143
/// Wrapper around an [`EmbeddingModel`] that adds per-call timeout enforcement,
/// exponential-backoff retries for transient errors, and metrics emission
/// (`model_inference.duration_seconds`, `model_inference.total`).
pub struct InstrumentedEmbeddingModel {
    /// The wrapped model that performs the actual embedding calls.
    pub inner: Arc<dyn EmbeddingModel>,
    /// Logical model alias; used as the `alias` metrics/log label.
    pub alias: String,
    /// Backing provider identifier; used as the `provider` metrics label.
    pub provider_id: String,
    /// Timeout applied to each individual attempt; `None` disables it.
    pub timeout: Option<Duration>,
    /// Retry policy; `None` means a single attempt with no retries.
    pub retry: Option<crate::api::RetryConfig>,
}
154
155#[async_trait]
156impl EmbeddingModel for InstrumentedEmbeddingModel {
157    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
158        let start = Instant::now();
159        let mut attempts = 0;
160        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
161
162        let res = loop {
163            attempts += 1;
164            let fut = self.inner.embed(texts.clone());
165
166            let res = if let Some(timeout) = self.timeout {
167                match tokio::time::timeout(timeout, fut).await {
168                    Ok(r) => r,
169                    Err(_) => Err(RuntimeError::Timeout),
170                }
171            } else {
172                fut.await
173            };
174
175            match res {
176                Ok(val) => break Ok(val),
177                Err(e) if e.is_retryable() && attempts < max_attempts => {
178                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
179                    tracing::warn!(
180                        alias = %self.alias,
181                        attempt = attempts,
182                        backoff_ms = backoff.as_millis(),
183                        error = %e,
184                        "Retrying embedding call"
185                    );
186                    tokio::time::sleep(backoff).await;
187                    continue;
188                }
189                Err(e) => break Err(e),
190            }
191        };
192
193        let duration = start.elapsed();
194        let status = if res.is_ok() { "success" } else { "failure" };
195
196        metrics::histogram!(
197            "model_inference.duration_seconds",
198            "alias" => self.alias.clone(),
199            "task" => "embed",
200            "provider" => self.provider_id.clone()
201        )
202        .record(duration.as_secs_f64());
203
204        metrics::counter!(
205            "model_inference.total",
206            "alias" => self.alias.clone(),
207            "task" => "embed",
208            "provider" => self.provider_id.clone(),
209            "status" => status
210        )
211        .increment(1);
212
213        res
214    }
215
216    fn dimensions(&self) -> u32 {
217        self.inner.dimensions()
218    }
219
220    fn model_id(&self) -> &str {
221        self.inner.model_id()
222    }
223
224    async fn warmup(&self) -> Result<()> {
225        self.inner.warmup().await
226    }
227}
228
/// Wrapper around a [`GeneratorModel`] that adds timeout, retry, and metrics.
///
/// See [`InstrumentedEmbeddingModel`] for details on the instrumentation behavior.
pub struct InstrumentedGeneratorModel {
    /// The wrapped model that performs the actual generation calls.
    pub inner: Arc<dyn GeneratorModel>,
    /// Logical model alias; used as the `alias` metrics/log label.
    pub alias: String,
    /// Backing provider identifier; used as the `provider` metrics label.
    pub provider_id: String,
    /// Timeout applied to each individual attempt; `None` disables it.
    pub timeout: Option<Duration>,
    /// Retry policy; `None` means a single attempt with no retries.
    pub retry: Option<crate::api::RetryConfig>,
}
239
240#[async_trait]
241impl GeneratorModel for InstrumentedGeneratorModel {
242    async fn generate(
243        &self,
244        messages: &[Message],
245        options: GenerationOptions,
246    ) -> Result<GenerationResult> {
247        let start = Instant::now();
248        let mut attempts = 0;
249        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
250
251        let res = loop {
252            attempts += 1;
253            let fut = self.inner.generate(messages, options.clone());
254
255            let res = if let Some(timeout) = self.timeout {
256                match tokio::time::timeout(timeout, fut).await {
257                    Ok(r) => r,
258                    Err(_) => Err(RuntimeError::Timeout),
259                }
260            } else {
261                fut.await
262            };
263
264            match res {
265                Ok(val) => break Ok(val),
266                Err(e) if e.is_retryable() && attempts < max_attempts => {
267                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
268                    tracing::warn!(
269                        alias = %self.alias,
270                        attempt = attempts,
271                        backoff_ms = backoff.as_millis(),
272                        error = %e,
273                        "Retrying generation call"
274                    );
275                    tokio::time::sleep(backoff).await;
276                    continue;
277                }
278                Err(e) => break Err(e),
279            }
280        };
281
282        let duration = start.elapsed();
283        let status = if res.is_ok() { "success" } else { "failure" };
284
285        metrics::histogram!(
286            "model_inference.duration_seconds",
287            "alias" => self.alias.clone(),
288            "task" => "generate",
289            "provider" => self.provider_id.clone()
290        )
291        .record(duration.as_secs_f64());
292
293        metrics::counter!(
294            "model_inference.total",
295            "alias" => self.alias.clone(),
296            "task" => "generate",
297            "provider" => self.provider_id.clone(),
298            "status" => status
299        )
300        .increment(1);
301
302        res
303    }
304
305    async fn warmup(&self) -> Result<()> {
306        self.inner.warmup().await
307    }
308}
309
/// Wrapper around a [`RerankerModel`] that adds timeout, retry, and metrics.
///
/// See [`InstrumentedEmbeddingModel`] for details on the instrumentation behavior.
pub struct InstrumentedRerankerModel {
    /// The wrapped model that performs the actual rerank calls.
    pub inner: Arc<dyn RerankerModel>,
    /// Logical model alias; used as the `alias` metrics/log label.
    pub alias: String,
    /// Backing provider identifier; used as the `provider` metrics label.
    pub provider_id: String,
    /// Timeout applied to each individual attempt; `None` disables it.
    pub timeout: Option<Duration>,
    /// Retry policy; `None` means a single attempt with no retries.
    pub retry: Option<crate::api::RetryConfig>,
}
320
321#[async_trait]
322impl RerankerModel for InstrumentedRerankerModel {
323    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
324        let start = Instant::now();
325        let mut attempts = 0;
326        let max_attempts = self.retry.as_ref().map(|r| r.max_attempts).unwrap_or(1);
327
328        let res = loop {
329            attempts += 1;
330            let fut = self.inner.rerank(query, docs);
331
332            let res = if let Some(timeout) = self.timeout {
333                match tokio::time::timeout(timeout, fut).await {
334                    Ok(r) => r,
335                    Err(_) => Err(RuntimeError::Timeout),
336                }
337            } else {
338                fut.await
339            };
340
341            match res {
342                Ok(val) => break Ok(val),
343                Err(e) if e.is_retryable() && attempts < max_attempts => {
344                    let backoff = self.retry.as_ref().unwrap().get_backoff(attempts);
345                    tracing::warn!(
346                        alias = %self.alias,
347                        attempt = attempts,
348                        backoff_ms = backoff.as_millis(),
349                        error = %e,
350                        "Retrying rerank call"
351                    );
352                    tokio::time::sleep(backoff).await;
353                    continue;
354                }
355                Err(e) => break Err(e),
356            }
357        };
358
359        let duration = start.elapsed();
360        let status = if res.is_ok() { "success" } else { "failure" };
361
362        metrics::histogram!(
363            "model_inference.duration_seconds",
364            "alias" => self.alias.clone(),
365            "task" => "rerank",
366            "provider" => self.provider_id.clone()
367        )
368        .record(duration.as_secs_f64());
369
370        metrics::counter!(
371            "model_inference.total",
372            "alias" => self.alias.clone(),
373            "task" => "rerank",
374            "provider" => self.provider_id.clone(),
375            "status" => status
376        )
377        .increment(1);
378
379        res
380    }
381
382    async fn warmup(&self) -> Result<()> {
383        self.inner.warmup().await
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // NOTE(review): the breaker tracks time with std::time::Instant, which
    // tokio's paused-time testing cannot advance, so these tests use real
    // ~1.1 s sleeps and cannot be sped up without changing the breaker.

    /// Drives the breaker through Closed -> Open -> HalfOpen -> Open ->
    /// HalfOpen -> Closed, checking call admission at each step.
    #[tokio::test]
    async fn test_circuit_breaker_transitions() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);
        let counter = Arc::new(AtomicU32::new(0));

        // 1. Success calls - state remains Closed
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());

        // 2. Failures - state transitions to Open
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err()); // Fail 1

        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err()); // Fail 2 -> Open

        // 3. Open state - calls rejected immediately; the counter proves the
        //    closure was never even invoked.
        let res = cb
            .call(|| async {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(())
            })
            .await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");
        assert_eq!(counter.load(Ordering::SeqCst), 0); // Should not have run

        // 4. Wait for HalfOpen (open_wait_seconds = 1, plus slack)
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // 5. HalfOpen - allow one call
        // If it fails, go back to Open
        let res = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;
        assert!(res.is_err());

        // Should be Open again
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_err());
        assert_eq!(res.err().unwrap().to_string(), "Unavailable");

        // 6. Wait again for HalfOpen
        tokio::time::sleep(Duration::from_millis(1100)).await;

        // 7. Success - transition to Closed
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());

        // Should be closed now, next call works
        let res = cb.call(|| async { Ok(()) }).await;
        assert!(res.is_ok());
    }

    /// While a half-open probe is executing, concurrent callers must be
    /// rejected fast with `Unavailable` rather than piling onto the provider.
    #[tokio::test]
    async fn test_half_open_allows_single_probe() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_wait_seconds: 1,
        };
        let cb = CircuitBreakerWrapper::new(config);

        // Open breaker.
        let _ = cb
            .call(|| async { Err::<(), _>(RuntimeError::InferenceError("fail".into())) })
            .await;

        tokio::time::sleep(Duration::from_millis(1100)).await;

        // Counters verify the probe body ran exactly once, to completion.
        let started = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let finished = Arc::new(std::sync::atomic::AtomicU32::new(0));

        let cb_probe = cb.clone();
        let started_probe = started.clone();
        let finished_probe = finished.clone();
        let probe = tokio::spawn(async move {
            cb_probe
                .call(|| async move {
                    started_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(150)).await;
                    finished_probe.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Ok::<_, RuntimeError>(())
                })
                .await
        });

        // Allow the first probe to enter.
        tokio::time::sleep(Duration::from_millis(20)).await;

        // A concurrent call during half-open probe should fail fast.
        let second = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(matches!(second, Err(RuntimeError::Unavailable)));

        let probe_result = probe.await.unwrap();
        assert!(probe_result.is_ok());
        assert_eq!(started.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(finished.load(std::sync::atomic::Ordering::SeqCst), 1);

        // Closed again.
        let res = cb.call(|| async { Ok::<_, RuntimeError>(()) }).await;
        assert!(res.is_ok());
    }
}