uni_xervo/
api.rs

1//! Public API types for configuring models, catalogs, and runtime behavior.
2
3use crate::error::{Result, RuntimeError};
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6
7/// The kind of inference task a model performs.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
9#[serde(rename_all = "snake_case")]
10pub enum ModelTask {
11    /// Produce dense vector embeddings from text.
12    Embed,
13    /// Re-score a set of documents against a query.
14    Rerank,
15    /// Generate text (chat completions, summarization, etc.).
16    Generate,
17}
18
19/// Controls when a model or provider is initialized during runtime startup.
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
21#[serde(rename_all = "snake_case")]
22pub enum WarmupPolicy {
23    /// Load immediately during [`ModelRuntime::builder().build()`](crate::runtime::ModelRuntimeBuilder::build).
24    /// Startup blocks until the load completes (or fails).
25    Eager,
26    /// Defer loading until the first inference request. This is the default.
27    #[default]
28    Lazy,
29    /// Spawn loading in a background task at startup. Inference calls that arrive
30    /// before loading finishes will trigger a blocking wait.
31    Background,
32}
33
34impl std::fmt::Display for WarmupPolicy {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            Self::Eager => write!(f, "eager"),
38            Self::Lazy => write!(f, "lazy"),
39            Self::Background => write!(f, "background"),
40        }
41    }
42}
43
44/// Declarative specification that maps a human-readable alias to a concrete
45/// provider and model.
46///
47/// A model catalog is a `Vec<ModelAliasSpec>` — either built programmatically or
48/// parsed from JSON with [`catalog_from_str`] / [`catalog_from_file`].
49///
50/// # Example JSON
51///
52/// ```json
53/// {
54///   "alias": "embed/default",
55///   "task": "embed",
56///   "provider_id": "local/candle",
57///   "model_id": "sentence-transformers/all-MiniLM-L6-v2",
58///   "warmup": "lazy"
59/// }
60/// ```
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub struct ModelAliasSpec {
63    /// Human-readable name used to request this model (e.g. `"embed/default"`).
64    /// Must contain a `/` separator.
65    pub alias: String,
66    /// The inference task this alias targets.
67    pub task: ModelTask,
68    /// Identifier of the provider that will load this model (e.g. `"local/candle"`,
69    /// `"remote/openai"`).
70    pub provider_id: String,
71    /// Model identifier understood by the provider — typically a HuggingFace repo ID
72    /// for local providers or an API model name for remote providers.
73    pub model_id: String,
74    /// Optional HuggingFace revision (branch, tag, or commit hash).
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub revision: Option<String>,
77    /// When this model should be initialized. Defaults to [`WarmupPolicy::Lazy`].
78    #[serde(default)]
79    pub warmup: WarmupPolicy,
80    /// If `true`, a failed eager warmup aborts runtime startup. Defaults to `false`.
81    #[serde(default)]
82    pub required: bool,
83    /// Per-inference timeout in seconds. `None` means no timeout.
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub timeout: Option<u64>,
86    /// Model load timeout in seconds. Defaults to 600 s if unset.
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub load_timeout: Option<u64>,
89    /// Retry configuration for transient inference failures.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub retry: Option<RetryConfig>,
92    /// Provider-specific options (e.g. `{"isq": "Q4K"}` for mistral.rs,
93    /// `{"api_key_env": "MY_KEY"}` for remote providers). Defaults to `{}`.
94    #[serde(default)]
95    pub options: serde_json::Value,
96}
97
98/// Configuration for exponential-backoff retries on transient inference errors.
99#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
100pub struct RetryConfig {
101    /// Maximum number of attempts (including the initial call).
102    pub max_attempts: u32,
103    /// Base delay in milliseconds; doubled on each subsequent attempt.
104    pub initial_backoff_ms: u64,
105}
106
107impl RetryConfig {
108    /// Compute the backoff duration for the given 1-based `attempt` number.
109    ///
110    /// Uses `initial_backoff_ms * 2^(attempt - 1)` with saturating arithmetic.
111    pub fn get_backoff(&self, attempt: u32) -> std::time::Duration {
112        std::time::Duration::from_millis(
113            self.initial_backoff_ms * 2u64.pow(attempt.saturating_sub(1)),
114        )
115    }
116}
117
118impl Default for RetryConfig {
119    fn default() -> Self {
120        Self {
121            max_attempts: 3,
122            initial_backoff_ms: 100,
123        }
124    }
125}
126
127/// Deduplication key used by the runtime to share a single loaded model instance
128/// across multiple aliases that point to the same provider, model, revision, and
129/// options.
130#[derive(Debug, Clone, PartialEq, Eq, Hash)]
131pub struct ModelRuntimeKey {
132    /// The task type (embed, rerank, generate).
133    pub task: ModelTask,
134    /// Provider that owns this model instance.
135    pub provider_id: String,
136    /// Model identifier within the provider.
137    pub model_id: String,
138    /// Optional HuggingFace revision.
139    pub revision: Option<String>,
140    /// Hash of the provider-specific options JSON. Two specs with semantically
141    /// equivalent options (same keys/values, any object-key order) produce the
142    /// same hash.
143    pub variant_hash: u64,
144}
145
146impl ModelRuntimeKey {
147    /// Derive a runtime key from a [`ModelAliasSpec`], hashing the options JSON
148    /// in a key-order-independent manner.
149    pub fn new(spec: &ModelAliasSpec) -> Self {
150        let mut hasher = std::collections::hash_map::DefaultHasher::new();
151        use std::hash::Hasher;
152
153        // Hash all JSON option shapes with deterministic key ordering.
154        // This avoids collisions for non-object values while preserving
155        // object-order independence for semantically equivalent JSON.
156        hash_json_value(&spec.options, &mut hasher);
157
158        Self {
159            task: spec.task,
160            provider_id: spec.provider_id.clone(),
161            model_id: spec.model_id.clone(),
162            revision: spec.revision.clone(),
163            variant_hash: hasher.finish(),
164        }
165    }
166}
167
168/// Recursively hash a JSON value in a deterministic, key-order-independent way.
169///
170/// Each JSON variant is prefixed with a unique discriminant byte to avoid
171/// collisions between structurally different values (e.g. `null` vs `false`).
172/// Object keys are sorted before hashing so that `{"a":1,"b":2}` and
173/// `{"b":2,"a":1}` produce the same hash.
174fn hash_json_value<H: std::hash::Hasher>(value: &serde_json::Value, hasher: &mut H) {
175    use std::hash::Hash;
176
177    match value {
178        serde_json::Value::Null => {
179            0u8.hash(hasher);
180        }
181        serde_json::Value::Bool(v) => {
182            1u8.hash(hasher);
183            v.hash(hasher);
184        }
185        serde_json::Value::Number(v) => {
186            2u8.hash(hasher);
187            v.to_string().hash(hasher);
188        }
189        serde_json::Value::String(v) => {
190            3u8.hash(hasher);
191            v.hash(hasher);
192        }
193        serde_json::Value::Array(values) => {
194            4u8.hash(hasher);
195            values.len().hash(hasher);
196            for v in values {
197                hash_json_value(v, hasher);
198            }
199        }
200        serde_json::Value::Object(map) => {
201            5u8.hash(hasher);
202            map.len().hash(hasher);
203
204            let mut entries: Vec<_> = map.iter().collect();
205            entries.sort_by_key(|(k, _)| *k);
206            for (k, v) in entries {
207                k.hash(hasher);
208                hash_json_value(v, hasher);
209            }
210        }
211    }
212}
213
214impl ModelAliasSpec {
215    /// Validate invariants: alias must be non-empty and contain a `'/'`, timeouts
216    /// must be non-zero when set.
217    pub fn validate(&self) -> Result<()> {
218        if self.alias.is_empty() {
219            return Err(RuntimeError::Config("Alias cannot be empty".to_string()));
220        }
221        if !self.alias.contains('/') {
222            return Err(RuntimeError::Config(format!(
223                "Alias '{}' must be in 'task/name' format",
224                self.alias
225            )));
226        }
227        if self.timeout == Some(0) {
228            return Err(RuntimeError::Config(
229                "Inference timeout must be greater than 0".to_string(),
230            ));
231        }
232        if self.load_timeout == Some(0) {
233            return Err(RuntimeError::Config(
234                "Load timeout must be greater than 0".to_string(),
235            ));
236        }
237        Ok(())
238    }
239
240    /// Parse a single `ModelAliasSpec` from a JSON value.
241    pub fn from_json(value: serde_json::Value) -> Result<Self> {
242        let spec: Self = serde_json::from_value(value)
243            .map_err(|e| RuntimeError::Config(format!("Invalid ModelAliasSpec JSON: {}", e)))?;
244        spec.validate()?;
245        Ok(spec)
246    }
247
248    /// Parse a single `ModelAliasSpec` from a JSON string.
249    pub fn from_json_str(s: &str) -> Result<Self> {
250        let spec: Self = serde_json::from_str(s)
251            .map_err(|e| RuntimeError::Config(format!("Invalid ModelAliasSpec JSON: {}", e)))?;
252        spec.validate()?;
253        Ok(spec)
254    }
255}
256
257/// Parse a catalog (array) of `ModelAliasSpec` from a JSON string.
258pub fn catalog_from_str(s: &str) -> Result<Vec<ModelAliasSpec>> {
259    let specs: Vec<ModelAliasSpec> = serde_json::from_str(s)
260        .map_err(|e| RuntimeError::Config(format!("Invalid catalog JSON: {}", e)))?;
261    for spec in &specs {
262        spec.validate()?;
263    }
264    Ok(specs)
265}
266
267/// Read and parse a catalog from a JSON file.
268///
269/// The file must contain a JSON array of model alias specs.
270pub fn catalog_from_file(path: impl AsRef<Path>) -> Result<Vec<ModelAliasSpec>> {
271    let path = path.as_ref();
272    let contents = std::fs::read_to_string(path).map_err(|e| {
273        RuntimeError::Config(format!(
274            "Failed to read catalog file '{}': {}",
275            path.display(),
276            e
277        ))
278    })?;
279    catalog_from_str(&contents)
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    const VALID_JSON: &str = r#"{
        "alias": "embed/default",
        "task": "embed",
        "provider_id": "local/candle",
        "model_id": "sentence-transformers/all-MiniLM-L6-v2"
    }"#;

    const VALID_CATALOG_JSON: &str = r#"[
        {
            "alias": "embed/default",
            "task": "embed",
            "provider_id": "local/candle",
            "model_id": "sentence-transformers/all-MiniLM-L6-v2"
        },
        {
            "alias": "chat/fast",
            "task": "generate",
            "provider_id": "local/mistralrs",
            "model_id": "mistralai/Mistral-7B-v0.1",
            "warmup": "background",
            "required": false,
            "options": { "isq": "Q4K" }
        }
    ]"#;

    #[test]
    fn from_json_str_parses_valid_spec() {
        let spec = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        assert_eq!(spec.alias, "embed/default");
        assert_eq!(spec.task, ModelTask::Embed);
        assert_eq!(spec.provider_id, "local/candle");
        assert_eq!(spec.warmup, WarmupPolicy::Lazy); // default
        assert!(!spec.required); // default
    }

    #[test]
    fn options_default_to_null_when_omitted() {
        // `#[serde(default)]` on `serde_json::Value` yields `Value::Null`.
        let spec = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        assert!(spec.options.is_null());
    }

    #[test]
    fn from_json_value_parses_valid_spec() {
        let value = json!({
            "alias": "embed/fast",
            "task": "embed",
            "provider_id": "local/fastembed",
            "model_id": "BAAI/bge-small-en-v1.5",
            "required": true,
            "warmup": "eager"
        });
        let spec = ModelAliasSpec::from_json(value).unwrap();
        assert_eq!(spec.alias, "embed/fast");
        assert_eq!(spec.warmup, WarmupPolicy::Eager);
        assert!(spec.required);
    }

    #[test]
    fn from_json_str_rejects_missing_slash_in_alias() {
        let json = r#"{"alias":"noSlash","task":"embed","provider_id":"x","model_id":"y"}"#;
        assert!(ModelAliasSpec::from_json_str(json).is_err());
    }

    #[test]
    fn from_json_str_rejects_empty_alias() {
        let json = r#"{"alias":"","task":"embed","provider_id":"x","model_id":"y"}"#;
        assert!(ModelAliasSpec::from_json_str(json).is_err());
    }

    #[test]
    fn from_json_str_rejects_invalid_json() {
        assert!(ModelAliasSpec::from_json_str("{not valid}").is_err());
    }

    #[test]
    fn validate_rejects_zero_timeouts() {
        let mut spec = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec.timeout = Some(0);
        assert!(spec.validate().is_err());

        spec.timeout = Some(5);
        spec.load_timeout = Some(0);
        assert!(spec.validate().is_err());

        spec.load_timeout = Some(30);
        assert!(spec.validate().is_ok());
    }

    #[test]
    fn spec_round_trips_through_json() {
        for spec in catalog_from_str(VALID_CATALOG_JSON).unwrap() {
            let value = serde_json::to_value(&spec).unwrap();
            assert_eq!(ModelAliasSpec::from_json(value).unwrap(), spec);
        }
    }

    #[test]
    fn warmup_policy_display_matches_serde_name() {
        for &policy in &[
            WarmupPolicy::Eager,
            WarmupPolicy::Lazy,
            WarmupPolicy::Background,
        ] {
            assert_eq!(serde_json::to_value(policy).unwrap(), json!(policy.to_string()));
        }
    }

    #[test]
    fn retry_backoff_doubles_per_attempt() {
        let retry = RetryConfig::default();
        let ms = std::time::Duration::from_millis;
        assert_eq!(retry.get_backoff(0), ms(100));
        assert_eq!(retry.get_backoff(1), ms(100));
        assert_eq!(retry.get_backoff(2), ms(200));
        assert_eq!(retry.get_backoff(4), ms(800));
    }

    #[test]
    fn catalog_from_str_parses_array() {
        let specs = catalog_from_str(VALID_CATALOG_JSON).unwrap();
        assert_eq!(specs.len(), 2);
        assert_eq!(specs[0].alias, "embed/default");
        assert_eq!(specs[1].alias, "chat/fast");
        assert_eq!(specs[1].options["isq"], "Q4K");
    }

    #[test]
    fn catalog_from_str_rejects_invalid_spec() {
        let json = r#"[{"alias":"bad","task":"embed","provider_id":"x","model_id":"y"}]"#;
        assert!(catalog_from_str(json).is_err()); // alias has no '/'
    }

    #[test]
    fn catalog_from_file_reads_and_parses() {
        // Per-process file name: a fixed name in the shared temp dir can clash
        // with concurrent test runs (other crates, CI jobs, users).
        let path = std::env::temp_dir().join(format!(
            "uni_xervo_test_catalog_{}.json",
            std::process::id()
        ));
        std::fs::write(&path, VALID_CATALOG_JSON).unwrap();
        let result = catalog_from_file(&path);
        // Clean up before asserting so a failing assertion doesn't leak the file.
        std::fs::remove_file(&path).unwrap();
        let specs = result.unwrap();
        assert_eq!(specs.len(), 2);
    }

    #[test]
    fn catalog_from_file_errors_on_missing_file() {
        assert!(catalog_from_file("/nonexistent/path/catalog.json").is_err());
    }

    #[test]
    fn runtime_key_distinguishes_non_object_options() {
        let mut spec_null = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec_null.options = serde_json::Value::Null;

        let mut spec_bool = spec_null.clone();
        spec_bool.options = json!(true);

        let mut spec_array = spec_null.clone();
        spec_array.options = json!(["a", 1]);

        let key_null = ModelRuntimeKey::new(&spec_null);
        let key_bool = ModelRuntimeKey::new(&spec_bool);
        let key_array = ModelRuntimeKey::new(&spec_array);

        assert_ne!(key_null, key_bool);
        assert_ne!(key_null, key_array);
        assert_ne!(key_bool, key_array);
    }

    #[test]
    fn runtime_key_nested_option_order_independence() {
        let mut spec1 = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec1.options = json!({
            "outer": {
                "b": [3, 2, 1],
                "a": {"y": 2, "x": 1}
            }
        });

        let mut spec2 = ModelAliasSpec::from_json_str(VALID_JSON).unwrap();
        spec2.options = json!({
            "outer": {
                "a": {"x": 1, "y": 2},
                "b": [3, 2, 1]
            }
        });

        let key1 = ModelRuntimeKey::new(&spec1);
        let key2 = ModelRuntimeKey::new(&spec2);
        assert_eq!(key1, key2);
    }
}