1use crate::error::{Result, RuntimeError};
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6
/// Task a model alias is bound to.
///
/// Serialized with serde's `snake_case` renaming, so the variants appear as
/// `"embed"`, `"rerank"` and `"generate"` in JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelTask {
    /// Embedding task (`"embed"`).
    Embed,
    /// Reranking task (`"rerank"`).
    Rerank,
    /// Generation task (`"generate"`).
    Generate,
}
18
/// Warm-up policy attached to a model alias.
///
/// Serialized in `snake_case` (`"eager"`, `"lazy"`, `"background"`).
/// [`WarmupPolicy::Lazy`] is the `Default`, which is what a spec gets when the
/// `warmup` field is omitted.
///
/// NOTE(review): the exact loading behavior behind each variant is implemented
/// outside this file; the per-variant notes below are inferred from the names
/// and should be confirmed against the runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum WarmupPolicy {
    /// Presumably: load the model up front — TODO confirm.
    Eager,
    /// Default policy. Presumably: load the model on first use — TODO confirm.
    #[default]
    Lazy,
    /// Presumably: load the model in the background — TODO confirm.
    Background,
}
33
34impl std::fmt::Display for WarmupPolicy {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::Eager => write!(f, "eager"),
38 Self::Lazy => write!(f, "lazy"),
39 Self::Background => write!(f, "background"),
40 }
41 }
42}
43
/// Declarative description of a single model alias, (de)serializable as JSON.
///
/// Only `alias`, `task`, `provider_id` and `model_id` are mandatory in JSON;
/// every other field is either optional or has a serde default.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelAliasSpec {
    /// Alias name; `validate` requires it to be non-empty and in `task/name` form.
    pub alias: String,
    /// Task this alias is bound to.
    pub task: ModelTask,
    /// Identifier of the provider (e.g. `"local/candle"` in this file's tests).
    pub provider_id: String,
    /// Identifier of the model as understood by the provider.
    pub model_id: String,
    /// Optional model revision; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub revision: Option<String>,
    /// Warm-up policy; falls back to `WarmupPolicy::default()` (`Lazy`) when absent.
    #[serde(default)]
    pub warmup: WarmupPolicy,
    /// Defaults to `false` when absent.
    /// NOTE(review): what "required" implies at runtime is decided outside this file.
    #[serde(default)]
    pub required: bool,
    /// Optional inference timeout; `validate` rejects `Some(0)`.
    /// NOTE(review): the unit is not visible in this file — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
    /// Optional load timeout; `validate` rejects `Some(0)`.
    /// NOTE(review): the unit is not visible in this file — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub load_timeout: Option<u64>,
    /// Optional retry settings; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry: Option<RetryConfig>,
    /// Free-form JSON options; defaults to `serde_json::Value::default()` when
    /// absent. `ModelRuntimeKey::new` hashes this value into `variant_hash`.
    #[serde(default)]
    pub options: serde_json::Value,
}
97
/// Retry settings: an attempt limit plus the base delay used by `get_backoff`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Attempt limit (3 in the `Default` impl).
    /// NOTE(review): whether this counts the first try is decided by the caller.
    pub max_attempts: u32,
    /// Base backoff in milliseconds; `get_backoff` doubles it per attempt.
    pub initial_backoff_ms: u64,
}
106
107impl RetryConfig {
108 pub fn get_backoff(&self, attempt: u32) -> std::time::Duration {
112 std::time::Duration::from_millis(
113 self.initial_backoff_ms * 2u64.pow(attempt.saturating_sub(1)),
114 )
115 }
116}
117
118impl Default for RetryConfig {
119 fn default() -> Self {
120 Self {
121 max_attempts: 3,
122 initial_backoff_ms: 100,
123 }
124 }
125}
126
/// Hashable key built from a `ModelAliasSpec` by `ModelRuntimeKey::new`.
///
/// Two specs produce equal keys when their task, provider, model, revision
/// and hashed `options` agree; the alias name and the other spec fields are
/// not part of the key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelRuntimeKey {
    /// Copied from `ModelAliasSpec::task`.
    pub task: ModelTask,
    /// Copied from `ModelAliasSpec::provider_id`.
    pub provider_id: String,
    /// Copied from `ModelAliasSpec::model_id`.
    pub model_id: String,
    /// Copied from `ModelAliasSpec::revision`.
    pub revision: Option<String>,
    /// Hash of the spec's `options` JSON, computed with `hash_json_value`.
    pub variant_hash: u64,
}
145
146impl ModelRuntimeKey {
147 pub fn new(spec: &ModelAliasSpec) -> Self {
150 let mut hasher = std::collections::hash_map::DefaultHasher::new();
151 use std::hash::Hasher;
152
153 hash_json_value(&spec.options, &mut hasher);
157
158 Self {
159 task: spec.task,
160 provider_id: spec.provider_id.clone(),
161 model_id: spec.model_id.clone(),
162 revision: spec.revision.clone(),
163 variant_hash: hasher.finish(),
164 }
165 }
166}
167
168fn hash_json_value<H: std::hash::Hasher>(value: &serde_json::Value, hasher: &mut H) {
175 use std::hash::Hash;
176
177 match value {
178 serde_json::Value::Null => {
179 0u8.hash(hasher);
180 }
181 serde_json::Value::Bool(v) => {
182 1u8.hash(hasher);
183 v.hash(hasher);
184 }
185 serde_json::Value::Number(v) => {
186 2u8.hash(hasher);
187 v.to_string().hash(hasher);
188 }
189 serde_json::Value::String(v) => {
190 3u8.hash(hasher);
191 v.hash(hasher);
192 }
193 serde_json::Value::Array(values) => {
194 4u8.hash(hasher);
195 values.len().hash(hasher);
196 for v in values {
197 hash_json_value(v, hasher);
198 }
199 }
200 serde_json::Value::Object(map) => {
201 5u8.hash(hasher);
202 map.len().hash(hasher);
203
204 let mut entries: Vec<_> = map.iter().collect();
205 entries.sort_by_key(|(k, _)| *k);
206 for (k, v) in entries {
207 k.hash(hasher);
208 hash_json_value(v, hasher);
209 }
210 }
211 }
212}
213
214impl ModelAliasSpec {
215 pub fn validate(&self) -> Result<()> {
218 if self.alias.is_empty() {
219 return Err(RuntimeError::Config("Alias cannot be empty".to_string()));
220 }
221 if !self.alias.contains('/') {
222 return Err(RuntimeError::Config(format!(
223 "Alias '{}' must be in 'task/name' format",
224 self.alias
225 )));
226 }
227 if self.timeout == Some(0) {
228 return Err(RuntimeError::Config(
229 "Inference timeout must be greater than 0".to_string(),
230 ));
231 }
232 if self.load_timeout == Some(0) {
233 return Err(RuntimeError::Config(
234 "Load timeout must be greater than 0".to_string(),
235 ));
236 }
237 Ok(())
238 }
239
240 pub fn from_json(value: serde_json::Value) -> Result<Self> {
242 let spec: Self = serde_json::from_value(value)
243 .map_err(|e| RuntimeError::Config(format!("Invalid ModelAliasSpec JSON: {}", e)))?;
244 spec.validate()?;
245 Ok(spec)
246 }
247
248 pub fn from_json_str(s: &str) -> Result<Self> {
250 let spec: Self = serde_json::from_str(s)
251 .map_err(|e| RuntimeError::Config(format!("Invalid ModelAliasSpec JSON: {}", e)))?;
252 spec.validate()?;
253 Ok(spec)
254 }
255}
256
257pub fn catalog_from_str(s: &str) -> Result<Vec<ModelAliasSpec>> {
259 let specs: Vec<ModelAliasSpec> = serde_json::from_str(s)
260 .map_err(|e| RuntimeError::Config(format!("Invalid catalog JSON: {}", e)))?;
261 for spec in &specs {
262 spec.validate()?;
263 }
264 Ok(specs)
265}
266
267pub fn catalog_from_file(path: impl AsRef<Path>) -> Result<Vec<ModelAliasSpec>> {
271 let path = path.as_ref();
272 let contents = std::fs::read_to_string(path).map_err(|e| {
273 RuntimeError::Config(format!(
274 "Failed to read catalog file '{}': {}",
275 path.display(),
276 e
277 ))
278 })?;
279 catalog_from_str(&contents)
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal spec: only the mandatory fields.
    const VALID_JSON: &str = r#"{
        "alias": "embed/default",
        "task": "embed",
        "provider_id": "local/candle",
        "model_id": "sentence-transformers/all-MiniLM-L6-v2"
    }"#;

    /// Two-entry catalog; the second entry exercises the optional fields.
    const VALID_CATALOG_JSON: &str = r#"[
        {
            "alias": "embed/default",
            "task": "embed",
            "provider_id": "local/candle",
            "model_id": "sentence-transformers/all-MiniLM-L6-v2"
        },
        {
            "alias": "chat/fast",
            "task": "generate",
            "provider_id": "local/mistralrs",
            "model_id": "mistralai/Mistral-7B-v0.1",
            "warmup": "background",
            "required": false,
            "options": { "isq": "Q4K" }
        }
    ]"#;

    #[test]
    fn from_json_str_parses_valid_spec() {
        let spec = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        assert_eq!(spec.alias, "embed/default");
        assert_eq!(spec.task, ModelTask::Embed);
        assert_eq!(spec.provider_id, "local/candle");
        // Omitted fields fall back to their serde defaults.
        assert_eq!(spec.warmup, WarmupPolicy::Lazy);
        assert!(!spec.required);
    }

    #[test]
    fn from_json_value_parses_valid_spec() {
        let value = json!({
            "alias": "embed/fast",
            "task": "embed",
            "provider_id": "local/fastembed",
            "model_id": "BAAI/bge-small-en-v1.5",
            "required": true,
            "warmup": "eager"
        });
        let spec = ModelAliasSpec::from_json(value).unwrap();
        assert_eq!(spec.alias, "embed/fast");
        assert_eq!(spec.warmup, WarmupPolicy::Eager);
        assert!(spec.required);
    }

    #[test]
    fn from_json_str_rejects_missing_slash_in_alias() {
        let json = r#"{"alias":"noSlash","task":"embed","provider_id":"x","model_id":"y"}"#;
        assert!(ModelAliasSpec::from_json_str(json).is_err());
    }

    #[test]
    fn from_json_str_rejects_invalid_json() {
        assert!(ModelAliasSpec::from_json_str("{not valid}").is_err());
    }

    #[test]
    fn catalog_from_str_parses_array() {
        let specs = catalog_from_str(VALID_CATALOG_JSON).unwrap();
        assert_eq!(specs.len(), 2);
        assert_eq!(specs[0].alias, "embed/default");
        assert_eq!(specs[1].alias, "chat/fast");
        assert_eq!(specs[1].options["isq"], "Q4K");
    }

    #[test]
    fn catalog_from_str_rejects_invalid_spec() {
        let json = r#"[{"alias":"bad","task":"embed","provider_id":"x","model_id":"y"}]"#;
        assert!(catalog_from_str(json).is_err());
    }

    #[test]
    fn catalog_from_file_reads_and_parses() {
        // The process id keeps concurrent test runs that share the temp
        // directory from overwriting or deleting each other's file.
        let file_name = format!("test_catalog_{}.json", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        std::fs::write(&path, VALID_CATALOG_JSON).unwrap();

        let outcome = catalog_from_file(&path);
        // Remove the file before asserting so a failure does not leave it behind.
        let removed = std::fs::remove_file(&path);

        let specs = outcome.unwrap();
        assert_eq!(specs.len(), 2);
        removed.unwrap();
    }

    #[test]
    fn catalog_from_file_errors_on_missing_file() {
        assert!(catalog_from_file("/nonexistent/path/catalog.json").is_err());
    }

    #[test]
    fn runtime_key_distinguishes_non_object_options() {
        let mut spec_null = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec_null.options = serde_json::Value::Null;

        let mut spec_bool = spec_null.clone();
        spec_bool.options = json!(true);

        let mut spec_array = spec_null.clone();
        spec_array.options = json!(["a", 1]);

        let key_null = ModelRuntimeKey::new(&spec_null);
        let key_bool = ModelRuntimeKey::new(&spec_bool);
        let key_array = ModelRuntimeKey::new(&spec_array);

        assert_ne!(key_null, key_bool);
        assert_ne!(key_null, key_array);
        assert_ne!(key_bool, key_array);
    }

    #[test]
    fn runtime_key_nested_option_order_independence() {
        let mut spec1 = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec1.options = json!({
            "outer": {
                "b": [3, 2, 1],
                "a": {"y": 2, "x": 1}
            }
        });

        let mut spec2 = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec2.options = json!({
            "outer": {
                "a": {"x": 1, "y": 2},
                "b": [3, 2, 1]
            }
        });

        // Same members in a different key order must produce the same key.
        let key1 = ModelRuntimeKey::new(&spec1);
        let key2 = ModelRuntimeKey::new(&spec2);
        assert_eq!(key1, key2);
    }
}