uni_xervo/
runtime.rs

1//! The core runtime that manages providers, catalogs, and loaded model instances.
2
3use crate::api::{ModelAliasSpec, ModelRuntimeKey};
4use crate::error::{Result, RuntimeError};
5use crate::options_validation::validate_provider_options;
6use crate::reliability::{
7    InstrumentedEmbeddingModel, InstrumentedGeneratorModel, InstrumentedRerankerModel,
8};
9use crate::traits::{
10    EmbeddingModel, GeneratorModel, LoadedModelHandle, ModelProvider, RerankerModel,
11};
12use std::any::Any;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{Mutex, RwLock};
16
/// Load timeout, in seconds, used whenever a spec leaves
/// [`ModelAliasSpec::load_timeout`] unset (`None`): ten minutes.
const DEFAULT_LOAD_TIMEOUT_SECS: u64 = 10 * 60;
19
20/// The central runtime that owns registered providers and a catalog of model
21/// aliases.
22///
23/// Obtain an instance via [`ModelRuntime::builder()`] and the
24/// [`ModelRuntimeBuilder`].  Once built, use [`embedding`](Self::embedding),
25/// [`reranker`](Self::reranker), or [`generator`](Self::generator) to obtain
26/// typed, instrumented model handles.
27///
28/// Models are loaded lazily on first access (unless configured for eager or
29/// background warmup) and cached in an internal registry so that subsequent
30/// requests for the same model are served instantly.
31pub struct ModelRuntime {
32    providers: HashMap<String, Box<dyn ModelProvider>>,
33    registry: Arc<ModelRegistry>,
34    catalog: RwLock<HashMap<String, ModelAliasSpec>>,
35}
36
37/// Internal registry that caches loaded model instances and coordinates
38/// concurrent load requests to prevent duplicate work.
39#[derive(Default)]
40pub struct ModelRegistry {
41    instances: RwLock<HashMap<ModelRuntimeKey, LoadedModelHandle>>,
42    /// Per-key mutexes to prevent concurrent loads of the same model.
43    loader_locks: Mutex<HashMap<ModelRuntimeKey, Arc<Mutex<()>>>>,
44}
45
46impl ModelRuntime {
47    /// Create a new [`ModelRuntimeBuilder`] for configuring and constructing a
48    /// runtime.
49    pub fn builder() -> ModelRuntimeBuilder {
50        ModelRuntimeBuilder::default()
51    }
52
53    /// Register a new model alias at runtime.
54    pub async fn register(&self, spec: ModelAliasSpec) -> Result<()> {
55        spec.validate()?;
56        if !self.providers.contains_key(&spec.provider_id) {
57            return Err(RuntimeError::Config(format!(
58                "Unknown provider '{}' for alias '{}'",
59                spec.provider_id, spec.alias
60            )));
61        }
62        validate_provider_options(&spec.provider_id, spec.task, &spec.options)?;
63        let mut catalog = self.catalog.write().await;
64        if catalog.contains_key(&spec.alias) {
65            return Err(RuntimeError::Config(format!(
66                "Alias '{}' already exists",
67                spec.alias
68            )));
69        }
70        catalog.insert(spec.alias.clone(), spec);
71        Ok(())
72    }
73
74    /// Check if an alias exists in the catalog.
75    pub async fn contains_alias(&self, alias: &str) -> bool {
76        let catalog = self.catalog.read().await;
77        catalog.contains_key(alias)
78    }
79
80    /// Look up a spec by alias, returning an error if not found.
81    async fn lookup_spec(&self, alias: &str) -> Result<ModelAliasSpec> {
82        let catalog = self.catalog.read().await;
83        catalog
84            .get(alias)
85            .cloned()
86            .ok_or_else(|| RuntimeError::Config(format!("Alias '{}' not found", alias)))
87    }
88
89    /// Pre-load and cache every model in the catalog.
90    ///
91    /// Models already loaded are skipped. Fails fast on the first error.
92    /// Call this during application startup to avoid cold-start latency on
93    /// first inference.
94    pub async fn prefetch_all(&self) -> Result<()> {
95        let specs: Vec<ModelAliasSpec> = {
96            let catalog = self.catalog.read().await;
97            catalog.values().cloned().collect()
98        };
99        for spec in specs {
100            tracing::info!(alias = %spec.alias, "Prefetching model");
101            self.resolve_and_load_internal(&spec).await?;
102        }
103        Ok(())
104    }
105
106    /// Pre-load and cache specific aliases.
107    ///
108    /// Returns an error immediately if an alias is not found in the catalog
109    /// or if any model fails to load. Models already loaded are skipped.
110    pub async fn prefetch(&self, aliases: &[&str]) -> Result<()> {
111        for alias in aliases {
112            let spec = self.lookup_spec(alias).await?;
113            tracing::info!(alias = %alias, "Prefetching model");
114            self.resolve_and_load_internal(&spec).await?;
115        }
116        Ok(())
117    }
118
119    /// Resolve, load (if necessary), and return an instrumented [`EmbeddingModel`]
120    /// handle for the given alias.
121    pub async fn embedding(&self, alias: &str) -> Result<Arc<dyn EmbeddingModel>> {
122        let spec = self.lookup_spec(alias).await?;
123        let handle = self.resolve_and_load_internal(&spec).await?;
124        if let Some(model) = handle.downcast_ref::<Arc<dyn EmbeddingModel>>() {
125            let instrumented = InstrumentedEmbeddingModel {
126                inner: model.clone(),
127                alias: alias.to_string(),
128                provider_id: spec.provider_id.clone(),
129                timeout: spec.timeout.map(std::time::Duration::from_secs),
130                retry: spec.retry.clone(),
131            };
132            return Ok(Arc::new(instrumented));
133        }
134
135        Err(RuntimeError::CapabilityMismatch(format!(
136            "Model for alias '{}' does not implement EmbeddingModel",
137            alias
138        )))
139    }
140
141    /// Resolve, load (if necessary), and return an instrumented [`RerankerModel`]
142    /// handle for the given alias.
143    pub async fn reranker(&self, alias: &str) -> Result<Arc<dyn RerankerModel>> {
144        let spec = self.lookup_spec(alias).await?;
145        let handle = self.resolve_and_load_internal(&spec).await?;
146        if let Some(model) = handle.downcast_ref::<Arc<dyn RerankerModel>>() {
147            let instrumented = InstrumentedRerankerModel {
148                inner: model.clone(),
149                alias: alias.to_string(),
150                provider_id: spec.provider_id.clone(),
151                timeout: spec.timeout.map(std::time::Duration::from_secs),
152                retry: spec.retry.clone(),
153            };
154            return Ok(Arc::new(instrumented));
155        }
156        Err(RuntimeError::CapabilityMismatch(format!(
157            "Model for alias '{}' does not implement RerankerModel",
158            alias
159        )))
160    }
161
162    /// Resolve, load (if necessary), and return an instrumented [`GeneratorModel`]
163    /// handle for the given alias.
164    pub async fn generator(&self, alias: &str) -> Result<Arc<dyn GeneratorModel>> {
165        let spec = self.lookup_spec(alias).await?;
166        let handle = self.resolve_and_load_internal(&spec).await?;
167        if let Some(model) = handle.downcast_ref::<Arc<dyn GeneratorModel>>() {
168            let instrumented = InstrumentedGeneratorModel {
169                inner: model.clone(),
170                alias: alias.to_string(),
171                provider_id: spec.provider_id.clone(),
172                timeout: spec.timeout.map(std::time::Duration::from_secs),
173                retry: spec.retry.clone(),
174            };
175            return Ok(Arc::new(instrumented));
176        }
177        Err(RuntimeError::CapabilityMismatch(format!(
178            "Model for alias '{}' does not implement GeneratorModel",
179            alias
180        )))
181    }
182
183    #[tracing::instrument(skip(self, spec), fields(provider, model))]
184    async fn resolve_and_load_internal(
185        &self,
186        spec: &ModelAliasSpec,
187    ) -> Result<Arc<dyn Any + Send + Sync>> {
188        let key = ModelRuntimeKey::new(spec);
189
190        // Fast path: already loaded
191        {
192            let registry = self.registry.instances.read().await;
193            if let Some(handle) = registry.get(&key) {
194                return Ok(handle.clone());
195            }
196        }
197
198        // Slow path: coordinate loading
199        let lock = {
200            let mut locks = self.registry.loader_locks.lock().await;
201            locks
202                .entry(key.clone())
203                .or_insert_with(|| Arc::new(Mutex::new(())))
204                .clone()
205        };
206
207        // Acquire loader lock for this key
208        let _guard = lock.lock().await;
209
210        // Double-check after acquiring the loader lock
211        {
212            let registry = self.registry.instances.read().await;
213            if let Some(handle) = registry.get(&key) {
214                let result = Ok(handle.clone());
215                let mut locks = self.registry.loader_locks.lock().await;
216                locks.remove(&key);
217                return result;
218            }
219        }
220
221        let load_timeout =
222            std::time::Duration::from_secs(spec.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT_SECS));
223
224        let result = match tokio::time::timeout(load_timeout, async {
225            let provider = self.providers.get(&spec.provider_id).ok_or_else(|| {
226                RuntimeError::ProviderNotFound(format!("Provider '{}' not found", spec.provider_id))
227            })?;
228
229            tracing::info!(alias = %spec.alias, provider = %spec.provider_id, "Loading model instance");
230            let start = std::time::Instant::now();
231            let handle_result = provider.load(spec).await;
232            let duration = start.elapsed().as_secs_f64();
233
234            metrics::histogram!("model_load.duration_seconds").record(duration);
235
236            let handle = match handle_result {
237                Ok(h) => {
238                    metrics::counter!("model_load.total", "status" => "success").increment(1);
239                    h
240                }
241                Err(e) => {
242                    metrics::counter!("model_load.total", "status" => "failure").increment(1);
243                    tracing::error!(alias = %spec.alias, error = %e, "Model load failed");
244                    return Err(e);
245                }
246            };
247
248            // Model warmup
249            if let Some(model) = handle.downcast_ref::<Arc<dyn EmbeddingModel>>() {
250                model.warmup().await?;
251            } else if let Some(model) = handle.downcast_ref::<Arc<dyn RerankerModel>>() {
252                model.warmup().await?;
253            } else if let Some(model) = handle.downcast_ref::<Arc<dyn GeneratorModel>>() {
254                model.warmup().await?;
255            }
256
257            {
258                let mut registry = self.registry.instances.write().await;
259                registry.insert(key.clone(), handle.clone());
260            }
261
262            Ok(handle)
263        })
264        .await
265        {
266            Ok(res) => res,
267            Err(_) => {
268                metrics::counter!("model_load.total", "status" => "failure").increment(1);
269                tracing::error!(
270                    alias = %spec.alias,
271                    provider = %spec.provider_id,
272                    timeout_secs = load_timeout.as_secs(),
273                    "Model load timed out"
274                );
275                Err(RuntimeError::Timeout)
276            }
277        };
278
279        // Bound loader lock map growth by removing this key once the load path completes.
280        // Existing waiters hold cloned lock Arcs, so this is safe.
281        {
282            let mut locks = self.registry.loader_locks.lock().await;
283            locks.remove(&key);
284        }
285
286        result
287    }
288}
289
290/// Builder for constructing a [`ModelRuntime`] with registered providers,
291/// a model catalog, and a warmup policy.
292///
293/// ```rust,no_run
294/// # use uni_xervo::runtime::ModelRuntime;
295/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
296/// let runtime = ModelRuntime::builder()
297///     // .register_provider(...)
298///     // .catalog(...)
299///     .build()
300///     .await?;
301/// # Ok(())
302/// # }
303/// ```
304#[derive(Default)]
305pub struct ModelRuntimeBuilder {
306    providers: HashMap<String, Box<dyn ModelProvider>>,
307    catalog: Vec<ModelAliasSpec>,
308    warmup_policy: crate::api::WarmupPolicy,
309}
310
311impl ModelRuntimeBuilder {
312    /// Register a provider. The provider's
313    /// [`provider_id`](crate::traits::ModelProvider::provider_id) is used as
314    /// the lookup key; registering a second provider with the same ID
315    /// replaces the first.
316    pub fn register_provider<P: ModelProvider + 'static>(mut self, provider: P) -> Self {
317        self.providers
318            .insert(provider.provider_id().to_string(), Box::new(provider));
319        self
320    }
321
322    /// Set the model catalog from a pre-built vector of specs.
323    pub fn catalog(mut self, catalog: Vec<ModelAliasSpec>) -> Self {
324        self.catalog = catalog;
325        self
326    }
327
328    /// Load catalog from a JSON string (array of model alias specs).
329    pub fn catalog_from_str(mut self, s: &str) -> Result<Self> {
330        self.catalog = crate::api::catalog_from_str(s)?;
331        Ok(self)
332    }
333
334    /// Load catalog from a JSON file (array of model alias specs).
335    pub fn catalog_from_file(mut self, path: impl AsRef<std::path::Path>) -> Result<Self> {
336        self.catalog = crate::api::catalog_from_file(path)?;
337        Ok(self)
338    }
339
340    /// Set the global warmup policy applied to providers during
341    /// [`build`](Self::build).
342    pub fn warmup_policy(mut self, policy: crate::api::WarmupPolicy) -> Self {
343        self.warmup_policy = policy;
344        self
345    }
346
347    /// Validate the catalog, execute the warmup policy, and return the
348    /// constructed [`ModelRuntime`].
349    ///
350    /// Returns an error if any spec references an unknown provider, contains
351    /// invalid options, or if a required eager warmup fails.
352    pub async fn build(self) -> Result<Arc<ModelRuntime>> {
353        let mut catalog_map = HashMap::new();
354        for spec in self.catalog {
355            spec.validate()?;
356            if !self.providers.contains_key(&spec.provider_id) {
357                return Err(RuntimeError::Config(format!(
358                    "Unknown provider '{}' for alias '{}'",
359                    spec.provider_id, spec.alias
360                )));
361            }
362            validate_provider_options(&spec.provider_id, spec.task, &spec.options)?;
363            if catalog_map.insert(spec.alias.clone(), spec).is_some() {
364                return Err(RuntimeError::Config(
365                    "Duplicate alias in catalog".to_string(),
366                ));
367            }
368        }
369
370        let runtime = Arc::new(ModelRuntime {
371            providers: self.providers,
372            registry: Arc::new(ModelRegistry::default()),
373            catalog: RwLock::new(catalog_map),
374        });
375
376        // Provider Warmup Phase
377        match self.warmup_policy {
378            crate::api::WarmupPolicy::Eager => {
379                for (id, provider) in &runtime.providers {
380                    tracing::info!(provider = %id, "Eagerly warming up provider");
381                    provider.warmup().await.map_err(|e| {
382                        RuntimeError::Load(format!("Failed to warmup provider {}: {}", id, e))
383                    })?;
384                }
385            }
386            crate::api::WarmupPolicy::Background => {
387                for id in runtime.providers.keys() {
388                    tracing::info!(provider = %id, "Scheduling background provider warmup");
389                    // We have the Arc<ModelRuntime> already.
390                    let rt = runtime.clone();
391                    let provider_id = id.clone();
392                    tokio::spawn(async move {
393                        if let Some(provider) = rt.providers.get(&provider_id)
394                            && let Err(e) = provider.warmup().await
395                        {
396                            tracing::error!(provider = %provider_id, error = %e, "Background provider warmup failed");
397                        }
398                    });
399                }
400            }
401            crate::api::WarmupPolicy::Lazy => {
402                tracing::debug!("Lazy provider warmup (no-op)");
403            }
404        }
405
406        // Model Warmup Phase
407        let mut warmup_tasks = Vec::new();
408
409        let specs: Vec<ModelAliasSpec> = {
410            let catalog = runtime.catalog.read().await;
411            catalog.values().cloned().collect()
412        };
413
414        for spec in specs {
415            match spec.warmup {
416                crate::api::WarmupPolicy::Eager => {
417                    tracing::info!(alias = %spec.alias, "Eagerly warming up model");
418                    if let Err(e) = runtime.resolve_and_load_internal(&spec).await {
419                        if spec.required {
420                            return Err(e);
421                        }
422                        tracing::error!(
423                            alias = %spec.alias,
424                            provider = %spec.provider_id,
425                            error = %e,
426                            "Optional eager model warmup failed; continuing startup"
427                        );
428                    }
429                }
430                crate::api::WarmupPolicy::Background => {
431                    tracing::info!(alias = %spec.alias, "Scheduling background warmup");
432                    let rt = runtime.clone();
433                    let spec_clone = spec.clone();
434                    // Spawn background task
435                    warmup_tasks.push(tokio::spawn(async move {
436                        if let Err(e) = rt.resolve_and_load_internal(&spec_clone).await {
437                            tracing::error!(alias = %spec_clone.alias, error = %e, "Background warmup failed");
438                        }
439                    }));
440                }
441                crate::api::WarmupPolicy::Lazy => {
442                    tracing::debug!(alias = %spec.alias, "Lazy warmup (no-op)");
443                }
444            }
445        }
446
447        // We don't await background tasks here, they run detached.
448        // Eager tasks are already awaited.
449
450        Ok(runtime)
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::{ModelTask, WarmupPolicy};
    use crate::mock::{MockProvider, make_spec};

    /// Build a runtime whose only provider is `provider` and whose catalog
    /// holds only `spec`.
    async fn runtime_with<P: ModelProvider + 'static>(
        provider: P,
        spec: ModelAliasSpec,
    ) -> Arc<ModelRuntime> {
        ModelRuntime::builder()
            .register_provider(provider)
            .catalog(vec![spec])
            .build()
            .await
            .unwrap()
    }

    /// Panic with `message` unless the runtime's loader-lock map is empty.
    async fn assert_loader_locks_drained(runtime: &ModelRuntime, message: &str) {
        let locks = runtime.registry.loader_locks.lock().await;
        assert!(locks.is_empty(), "{message}");
    }

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_successful_load() {
        let spec = make_spec("embed/test", ModelTask::Embed, "mock/embed", "test-model");
        let runtime = runtime_with(MockProvider::embed_only(), spec).await;

        runtime.embedding("embed/test").await.unwrap();

        assert_loader_locks_drained(&runtime, "loader lock map should be empty after load").await;
    }

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_failed_load() {
        let mut spec = make_spec("embed/test", ModelTask::Embed, "mock/failing", "test-model");
        spec.warmup = WarmupPolicy::Lazy;
        let runtime = runtime_with(MockProvider::failing(), spec).await;

        assert!(runtime.embedding("embed/test").await.is_err());

        assert_loader_locks_drained(&runtime, "loader lock map should be empty after failure")
            .await;
    }

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_load_timeout() {
        let mut spec = make_spec("embed/test", ModelTask::Embed, "mock/embed", "test-model");
        spec.warmup = WarmupPolicy::Lazy;
        spec.load_timeout = Some(1);
        let slow_provider = MockProvider::embed_only().with_load_delay(2_000);
        let runtime = runtime_with(slow_provider, spec).await;

        let outcome = runtime.embedding("embed/test").await;
        assert!(matches!(outcome, Err(RuntimeError::Timeout)));

        assert_loader_locks_drained(
            &runtime,
            "loader lock map should be empty after load timeout",
        )
        .await;
    }
}
522}