uni_xervo/provider/
mistral.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider that calls the [Mistral AI API](https://docs.mistral.ai/api/)
/// for embedding and text generation (chat completions).
///
/// Requires the `MISTRAL_API_KEY` environment variable (or a custom env var
/// name via the `api_key_env` option).
pub struct RemoteMistralProvider {
    // Shared HTTP client plus the per-model circuit-breaker registry that all
    // remote providers reuse (see `RemoteProviderBase`).
    base: RemoteProviderBase,
}
21
22impl Default for RemoteMistralProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
impl RemoteMistralProvider {
    /// Creates a provider with a fresh HTTP client and breaker registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Test hook: inserts a breaker entry with the given age so eviction
    /// behavior can be exercised deterministically.
    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    /// Test hook: number of entries currently in the breaker registry.
    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    /// Test hook: forces the periodic stale-breaker cleanup to be eligible
    /// to run on the next registry access.
    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}
50
51#[async_trait]
52impl ModelProvider for RemoteMistralProvider {
53    fn provider_id(&self) -> &'static str {
54        "remote/mistral"
55    }
56
57    fn capabilities(&self) -> ProviderCapabilities {
58        ProviderCapabilities {
59            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
60        }
61    }
62
63    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64        let cb = self.base.circuit_breaker_for(spec);
65        let api_key = resolve_api_key(&spec.options, "api_key_env", "MISTRAL_API_KEY")?;
66
67        match spec.task {
68            ModelTask::Embed => {
69                let model = MistralEmbeddingModel {
70                    client: self.base.client.clone(),
71                    cb: cb.clone(),
72                    model_id: spec.model_id.clone(),
73                    api_key,
74                };
75                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
76                Ok(Arc::new(handle) as LoadedModelHandle)
77            }
78            ModelTask::Generate => {
79                let model = MistralGeneratorModel {
80                    client: self.base.client.clone(),
81                    cb,
82                    model_id: spec.model_id.clone(),
83                    api_key,
84                };
85                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
86                Ok(Arc::new(handle) as LoadedModelHandle)
87            }
88            _ => Err(RuntimeError::CapabilityMismatch(format!(
89                "Mistral provider does not support task {:?}",
90                spec.task
91            ))),
92        }
93    }
94
95    async fn health(&self) -> ProviderHealth {
96        ProviderHealth::Healthy
97    }
98}
99
/// Handle for one Mistral embedding model; one instance per loaded alias.
struct MistralEmbeddingModel {
    // Shared HTTP client (cloned from the provider base; reqwest clients are
    // cheap, connection-pooled clones).
    client: Client,
    // Per-(task, model) circuit breaker shared with other handles of the
    // same runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    // Bearer token resolved at load time from the environment.
    api_key: String,
}
106
107#[async_trait]
108impl EmbeddingModel for MistralEmbeddingModel {
109    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
110        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
111
112        self.cb
113            .call(move || async move {
114                let response = self
115                    .client
116                    .post("https://api.mistral.ai/v1/embeddings")
117                    .header("Authorization", format!("Bearer {}", self.api_key))
118                    .json(&json!({
119                        "model": self.model_id,
120                        "input": texts
121                    }))
122                    .send()
123                    .await
124                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
125
126                let body: serde_json::Value = check_http_status("Mistral", response)?
127                    .json()
128                    .await
129                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
130
131                let mut embeddings = Vec::new();
132                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
133                    for item in data {
134                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
135                            let vec: Vec<f32> = embedding
136                                .iter()
137                                .filter_map(|v| v.as_f64().map(|f| f as f32))
138                                .collect();
139                            embeddings.push(vec);
140                        }
141                    }
142                }
143                Ok(embeddings)
144            })
145            .await
146    }
147
148    fn dimensions(&self) -> u32 {
149        // All current Mistral embedding models use 1024 dimensions.
150        1024
151    }
152
153    fn model_id(&self) -> &str {
154        &self.model_id
155    }
156}
157
/// Handle for one Mistral chat-completion model; one instance per loaded alias.
struct MistralGeneratorModel {
    // Shared HTTP client (cheap, connection-pooled clone of the base client).
    client: Client,
    // Per-(task, model) circuit breaker shared with other handles of the
    // same runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    // Bearer token resolved at load time from the environment.
    api_key: String,
}
164
165#[async_trait]
166impl GeneratorModel for MistralGeneratorModel {
167    async fn generate(
168        &self,
169        messages: &[String],
170        options: GenerationOptions,
171    ) -> Result<GenerationResult> {
172        let messages: Vec<serde_json::Value> = messages
173            .iter()
174            .enumerate()
175            .map(|(i, content)| {
176                let role = if i % 2 == 0 { "user" } else { "assistant" };
177                json!({ "role": role, "content": content })
178            })
179            .collect();
180
181        self.cb
182            .call(move || async move {
183                let mut body = json!({
184                    "model": self.model_id,
185                    "messages": messages,
186                });
187
188                if let Some(max_tokens) = options.max_tokens {
189                    body["max_tokens"] = json!(max_tokens);
190                }
191                if let Some(temperature) = options.temperature {
192                    body["temperature"] = json!(temperature);
193                }
194                if let Some(top_p) = options.top_p {
195                    body["top_p"] = json!(top_p);
196                }
197
198                let response = self
199                    .client
200                    .post("https://api.mistral.ai/v1/chat/completions")
201                    .header("Authorization", format!("Bearer {}", self.api_key))
202                    .json(&body)
203                    .send()
204                    .await
205                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
206
207                let body: serde_json::Value = check_http_status("Mistral", response)?
208                    .json()
209                    .await
210                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
211
212                let text = body["choices"][0]["message"]["content"]
213                    .as_str()
214                    .unwrap_or("")
215                    .to_string();
216
217                let usage = body.get("usage").map(|u| TokenUsage {
218                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
219                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
220                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
221                });
222
223                Ok(GenerationResult { text, usage })
224            })
225            .await
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes tests that mutate the process-wide MISTRAL_API_KEY env var.
    // An async-aware mutex so the guard can be held across `.await` points.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Builds a minimal `ModelAliasSpec` for this provider with lazy warmup,
    /// no timeouts/retries, and null options.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/mistral".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    // NOTE(review): in all tests below, if an assertion panics before
    // `remove_var`, the key leaks into later tests. ENV_LOCK still serializes
    // access, but a drop guard would make cleanup panic-safe — consider it.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: process-wide env mutation, serialized by ENV_LOCK so no
        // other test reads or writes this variable concurrently.
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        // Same task + model id => same runtime key, despite distinct aliases.
        let s1 = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let s2 = spec("embed/b", ModelTask::Embed, "mistral-embed");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: same ENV_LOCK serialization as above.
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: process-wide env mutation, serialized by ENV_LOCK.
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        // Different task and model id => distinct runtime keys and breakers.
        let embed = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let gen_spec = spec("chat/a", ModelTask::Generate, "mistral-small-latest");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        // SAFETY: same ENV_LOCK serialization as above.
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: process-wide env mutation, serialized by ENV_LOCK.
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "mistral-embed");
        let fresh = spec("chat/fresh", ModelTask::Generate, "mistral-small-latest");
        // Pre-age one entry past the TTL so cleanup has something to evict.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        // Cleanup is triggered lazily; `load` provides the registry access
        // that actually runs it.
        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: same ENV_LOCK serialization as above.
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: process-wide env mutation, serialized by ENV_LOCK.
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        // Rerank is not in this provider's supported tasks.
        let s = spec("rerank/a", ModelTask::Rerank, "mistral-embed");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        // SAFETY: same ENV_LOCK serialization as above.
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }
}