1use crate::api::{ModelAliasSpec, ModelRuntimeKey};
4use crate::error::{Result, RuntimeError};
5use crate::options_validation::validate_provider_options;
6use crate::reliability::{
7 InstrumentedEmbeddingModel, InstrumentedGeneratorModel, InstrumentedRerankerModel,
8};
9use crate::traits::{
10 EmbeddingModel, GeneratorModel, LoadedModelHandle, ModelProvider, RerankerModel,
11};
12use std::any::Any;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::{Mutex, RwLock};
16
/// Load timeout, in seconds, used when a spec leaves `load_timeout` unset (ten minutes).
const DEFAULT_LOAD_TIMEOUT_SECS: u64 = 10 * 60;
19
20pub struct ModelRuntime {
32 providers: HashMap<String, Box<dyn ModelProvider>>,
33 registry: Arc<ModelRegistry>,
34 catalog: RwLock<HashMap<String, ModelAliasSpec>>,
35}
36
37#[derive(Default)]
40pub struct ModelRegistry {
41 instances: RwLock<HashMap<ModelRuntimeKey, LoadedModelHandle>>,
42 loader_locks: Mutex<HashMap<ModelRuntimeKey, Arc<Mutex<()>>>>,
44}
45
46impl ModelRuntime {
47 pub fn builder() -> ModelRuntimeBuilder {
50 ModelRuntimeBuilder::default()
51 }
52
53 pub async fn register(&self, spec: ModelAliasSpec) -> Result<()> {
55 spec.validate()?;
56 if !self.providers.contains_key(&spec.provider_id) {
57 return Err(RuntimeError::Config(format!(
58 "Unknown provider '{}' for alias '{}'",
59 spec.provider_id, spec.alias
60 )));
61 }
62 validate_provider_options(&spec.provider_id, spec.task, &spec.options)?;
63 let mut catalog = self.catalog.write().await;
64 if catalog.contains_key(&spec.alias) {
65 return Err(RuntimeError::Config(format!(
66 "Alias '{}' already exists",
67 spec.alias
68 )));
69 }
70 catalog.insert(spec.alias.clone(), spec);
71 Ok(())
72 }
73
74 pub async fn contains_alias(&self, alias: &str) -> bool {
76 let catalog = self.catalog.read().await;
77 catalog.contains_key(alias)
78 }
79
80 async fn lookup_spec(&self, alias: &str) -> Result<ModelAliasSpec> {
82 let catalog = self.catalog.read().await;
83 catalog
84 .get(alias)
85 .cloned()
86 .ok_or_else(|| RuntimeError::Config(format!("Alias '{}' not found", alias)))
87 }
88
89 pub async fn prefetch_all(&self) -> Result<()> {
95 let specs: Vec<ModelAliasSpec> = {
96 let catalog = self.catalog.read().await;
97 catalog.values().cloned().collect()
98 };
99 for spec in specs {
100 tracing::info!(alias = %spec.alias, "Prefetching model");
101 self.resolve_and_load_internal(&spec).await?;
102 }
103 Ok(())
104 }
105
106 pub async fn prefetch(&self, aliases: &[&str]) -> Result<()> {
111 for alias in aliases {
112 let spec = self.lookup_spec(alias).await?;
113 tracing::info!(alias = %alias, "Prefetching model");
114 self.resolve_and_load_internal(&spec).await?;
115 }
116 Ok(())
117 }
118
119 pub async fn embedding(&self, alias: &str) -> Result<Arc<dyn EmbeddingModel>> {
122 let spec = self.lookup_spec(alias).await?;
123 let handle = self.resolve_and_load_internal(&spec).await?;
124 if let Some(model) = handle.downcast_ref::<Arc<dyn EmbeddingModel>>() {
125 let instrumented = InstrumentedEmbeddingModel {
126 inner: model.clone(),
127 alias: alias.to_string(),
128 provider_id: spec.provider_id.clone(),
129 timeout: spec.timeout.map(std::time::Duration::from_secs),
130 retry: spec.retry.clone(),
131 };
132 return Ok(Arc::new(instrumented));
133 }
134
135 Err(RuntimeError::CapabilityMismatch(format!(
136 "Model for alias '{}' does not implement EmbeddingModel",
137 alias
138 )))
139 }
140
141 pub async fn reranker(&self, alias: &str) -> Result<Arc<dyn RerankerModel>> {
144 let spec = self.lookup_spec(alias).await?;
145 let handle = self.resolve_and_load_internal(&spec).await?;
146 if let Some(model) = handle.downcast_ref::<Arc<dyn RerankerModel>>() {
147 let instrumented = InstrumentedRerankerModel {
148 inner: model.clone(),
149 alias: alias.to_string(),
150 provider_id: spec.provider_id.clone(),
151 timeout: spec.timeout.map(std::time::Duration::from_secs),
152 retry: spec.retry.clone(),
153 };
154 return Ok(Arc::new(instrumented));
155 }
156 Err(RuntimeError::CapabilityMismatch(format!(
157 "Model for alias '{}' does not implement RerankerModel",
158 alias
159 )))
160 }
161
162 pub async fn generator(&self, alias: &str) -> Result<Arc<dyn GeneratorModel>> {
165 let spec = self.lookup_spec(alias).await?;
166 let handle = self.resolve_and_load_internal(&spec).await?;
167 if let Some(model) = handle.downcast_ref::<Arc<dyn GeneratorModel>>() {
168 let instrumented = InstrumentedGeneratorModel {
169 inner: model.clone(),
170 alias: alias.to_string(),
171 provider_id: spec.provider_id.clone(),
172 timeout: spec.timeout.map(std::time::Duration::from_secs),
173 retry: spec.retry.clone(),
174 };
175 return Ok(Arc::new(instrumented));
176 }
177 Err(RuntimeError::CapabilityMismatch(format!(
178 "Model for alias '{}' does not implement GeneratorModel",
179 alias
180 )))
181 }
182
183 #[tracing::instrument(skip(self, spec), fields(provider, model))]
184 async fn resolve_and_load_internal(
185 &self,
186 spec: &ModelAliasSpec,
187 ) -> Result<Arc<dyn Any + Send + Sync>> {
188 let key = ModelRuntimeKey::new(spec);
189
190 {
192 let registry = self.registry.instances.read().await;
193 if let Some(handle) = registry.get(&key) {
194 return Ok(handle.clone());
195 }
196 }
197
198 let lock = {
200 let mut locks = self.registry.loader_locks.lock().await;
201 locks
202 .entry(key.clone())
203 .or_insert_with(|| Arc::new(Mutex::new(())))
204 .clone()
205 };
206
207 let _guard = lock.lock().await;
209
210 {
212 let registry = self.registry.instances.read().await;
213 if let Some(handle) = registry.get(&key) {
214 let result = Ok(handle.clone());
215 let mut locks = self.registry.loader_locks.lock().await;
216 locks.remove(&key);
217 return result;
218 }
219 }
220
221 let load_timeout =
222 std::time::Duration::from_secs(spec.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT_SECS));
223
224 let result = match tokio::time::timeout(load_timeout, async {
225 let provider = self.providers.get(&spec.provider_id).ok_or_else(|| {
226 RuntimeError::ProviderNotFound(format!("Provider '{}' not found", spec.provider_id))
227 })?;
228
229 tracing::info!(alias = %spec.alias, provider = %spec.provider_id, "Loading model instance");
230 let start = std::time::Instant::now();
231 let handle_result = provider.load(spec).await;
232 let duration = start.elapsed().as_secs_f64();
233
234 metrics::histogram!("model_load.duration_seconds").record(duration);
235
236 let handle = match handle_result {
237 Ok(h) => {
238 metrics::counter!("model_load.total", "status" => "success").increment(1);
239 h
240 }
241 Err(e) => {
242 metrics::counter!("model_load.total", "status" => "failure").increment(1);
243 tracing::error!(alias = %spec.alias, error = %e, "Model load failed");
244 return Err(e);
245 }
246 };
247
248 if let Some(model) = handle.downcast_ref::<Arc<dyn EmbeddingModel>>() {
250 model.warmup().await?;
251 } else if let Some(model) = handle.downcast_ref::<Arc<dyn RerankerModel>>() {
252 model.warmup().await?;
253 } else if let Some(model) = handle.downcast_ref::<Arc<dyn GeneratorModel>>() {
254 model.warmup().await?;
255 }
256
257 {
258 let mut registry = self.registry.instances.write().await;
259 registry.insert(key.clone(), handle.clone());
260 }
261
262 Ok(handle)
263 })
264 .await
265 {
266 Ok(res) => res,
267 Err(_) => {
268 metrics::counter!("model_load.total", "status" => "failure").increment(1);
269 tracing::error!(
270 alias = %spec.alias,
271 provider = %spec.provider_id,
272 timeout_secs = load_timeout.as_secs(),
273 "Model load timed out"
274 );
275 Err(RuntimeError::Timeout)
276 }
277 };
278
279 {
282 let mut locks = self.registry.loader_locks.lock().await;
283 locks.remove(&key);
284 }
285
286 result
287 }
288}
289
290#[derive(Default)]
305pub struct ModelRuntimeBuilder {
306 providers: HashMap<String, Box<dyn ModelProvider>>,
307 catalog: Vec<ModelAliasSpec>,
308 warmup_policy: crate::api::WarmupPolicy,
309}
310
311impl ModelRuntimeBuilder {
312 pub fn register_provider<P: ModelProvider + 'static>(mut self, provider: P) -> Self {
317 self.providers
318 .insert(provider.provider_id().to_string(), Box::new(provider));
319 self
320 }
321
322 pub fn catalog(mut self, catalog: Vec<ModelAliasSpec>) -> Self {
324 self.catalog = catalog;
325 self
326 }
327
328 pub fn catalog_from_str(mut self, s: &str) -> Result<Self> {
330 self.catalog = crate::api::catalog_from_str(s)?;
331 Ok(self)
332 }
333
334 pub fn catalog_from_file(mut self, path: impl AsRef<std::path::Path>) -> Result<Self> {
336 self.catalog = crate::api::catalog_from_file(path)?;
337 Ok(self)
338 }
339
340 pub fn warmup_policy(mut self, policy: crate::api::WarmupPolicy) -> Self {
343 self.warmup_policy = policy;
344 self
345 }
346
347 pub async fn build(self) -> Result<Arc<ModelRuntime>> {
353 let mut catalog_map = HashMap::new();
354 for spec in self.catalog {
355 spec.validate()?;
356 if !self.providers.contains_key(&spec.provider_id) {
357 return Err(RuntimeError::Config(format!(
358 "Unknown provider '{}' for alias '{}'",
359 spec.provider_id, spec.alias
360 )));
361 }
362 validate_provider_options(&spec.provider_id, spec.task, &spec.options)?;
363 if catalog_map.insert(spec.alias.clone(), spec).is_some() {
364 return Err(RuntimeError::Config(
365 "Duplicate alias in catalog".to_string(),
366 ));
367 }
368 }
369
370 let runtime = Arc::new(ModelRuntime {
371 providers: self.providers,
372 registry: Arc::new(ModelRegistry::default()),
373 catalog: RwLock::new(catalog_map),
374 });
375
376 match self.warmup_policy {
378 crate::api::WarmupPolicy::Eager => {
379 for (id, provider) in &runtime.providers {
380 tracing::info!(provider = %id, "Eagerly warming up provider");
381 provider.warmup().await.map_err(|e| {
382 RuntimeError::Load(format!("Failed to warmup provider {}: {}", id, e))
383 })?;
384 }
385 }
386 crate::api::WarmupPolicy::Background => {
387 for id in runtime.providers.keys() {
388 tracing::info!(provider = %id, "Scheduling background provider warmup");
389 let rt = runtime.clone();
391 let provider_id = id.clone();
392 tokio::spawn(async move {
393 if let Some(provider) = rt.providers.get(&provider_id)
394 && let Err(e) = provider.warmup().await
395 {
396 tracing::error!(provider = %provider_id, error = %e, "Background provider warmup failed");
397 }
398 });
399 }
400 }
401 crate::api::WarmupPolicy::Lazy => {
402 tracing::debug!("Lazy provider warmup (no-op)");
403 }
404 }
405
406 let mut warmup_tasks = Vec::new();
408
409 let specs: Vec<ModelAliasSpec> = {
410 let catalog = runtime.catalog.read().await;
411 catalog.values().cloned().collect()
412 };
413
414 for spec in specs {
415 match spec.warmup {
416 crate::api::WarmupPolicy::Eager => {
417 tracing::info!(alias = %spec.alias, "Eagerly warming up model");
418 if let Err(e) = runtime.resolve_and_load_internal(&spec).await {
419 if spec.required {
420 return Err(e);
421 }
422 tracing::error!(
423 alias = %spec.alias,
424 provider = %spec.provider_id,
425 error = %e,
426 "Optional eager model warmup failed; continuing startup"
427 );
428 }
429 }
430 crate::api::WarmupPolicy::Background => {
431 tracing::info!(alias = %spec.alias, "Scheduling background warmup");
432 let rt = runtime.clone();
433 let spec_clone = spec.clone();
434 warmup_tasks.push(tokio::spawn(async move {
436 if let Err(e) = rt.resolve_and_load_internal(&spec_clone).await {
437 tracing::error!(alias = %spec_clone.alias, error = %e, "Background warmup failed");
438 }
439 }));
440 }
441 crate::api::WarmupPolicy::Lazy => {
442 tracing::debug!(alias = %spec.alias, "Lazy warmup (no-op)");
443 }
444 }
445 }
446
447 Ok(runtime)
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelTask;
    use crate::mock::{MockProvider, make_spec};

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_successful_load() {
        let spec = make_spec("embed/test", ModelTask::Embed, "mock/embed", "test-model");
        let runtime = ModelRuntime::builder()
            .register_provider(MockProvider::embed_only())
            .catalog(vec![spec])
            .build()
            .await
            .unwrap();

        let _ = runtime.embedding("embed/test").await.unwrap();

        let locks = runtime.registry.loader_locks.lock().await;
        assert!(
            locks.is_empty(),
            "loader lock map should be empty after load"
        );
    }

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_failed_load() {
        let mut spec = make_spec("embed/test", ModelTask::Embed, "mock/failing", "test-model");
        spec.warmup = crate::api::WarmupPolicy::Lazy;
        let runtime = ModelRuntime::builder()
            .register_provider(MockProvider::failing())
            .catalog(vec![spec])
            .build()
            .await
            .unwrap();

        let err = runtime.embedding("embed/test").await;
        assert!(err.is_err());

        let locks = runtime.registry.loader_locks.lock().await;
        assert!(
            locks.is_empty(),
            "loader lock map should be empty after failure"
        );
    }

    #[tokio::test]
    async fn loader_lock_entries_cleaned_after_load_timeout() {
        let mut spec = make_spec("embed/test", ModelTask::Embed, "mock/embed", "test-model");
        spec.warmup = crate::api::WarmupPolicy::Lazy;
        spec.load_timeout = Some(1);

        let runtime = ModelRuntime::builder()
            .register_provider(MockProvider::embed_only().with_load_delay(2_000))
            .catalog(vec![spec])
            .build()
            .await
            .unwrap();

        let err = runtime.embedding("embed/test").await;
        assert!(matches!(err, Err(RuntimeError::Timeout)));

        let locks = runtime.registry.loader_locks.lock().await;
        assert!(
            locks.is_empty(),
            "loader lock map should be empty after load timeout"
        );
    }

    /// Two simultaneous requests for the same key must share one loaded instance, and the
    /// waiter that finds the model already cached must still leave the lock map empty.
    #[tokio::test]
    async fn concurrent_loads_share_instance_and_clean_up_lock() {
        let mut spec = make_spec("embed/test", ModelTask::Embed, "mock/embed", "test-model");
        spec.warmup = crate::api::WarmupPolicy::Lazy;

        let runtime = ModelRuntime::builder()
            .register_provider(MockProvider::embed_only().with_load_delay(50))
            .catalog(vec![spec])
            .build()
            .await
            .unwrap();

        let (first, second) = tokio::join!(
            runtime.embedding("embed/test"),
            runtime.embedding("embed/test")
        );
        assert!(first.is_ok());
        assert!(second.is_ok());

        assert_eq!(runtime.registry.instances.read().await.len(), 1);

        let locks = runtime.registry.loader_locks.lock().await;
        assert!(
            locks.is_empty(),
            "loader lock map should be empty after concurrent loads"
        );
    }
}