// uni_xervo/provider/mistral.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider that calls the [Mistral AI API](https://docs.mistral.ai/api/)
/// for embedding and text generation (chat completions).
///
/// Requires the `MISTRAL_API_KEY` environment variable (or a custom env var
/// name via the `api_key_env` option).
///
/// Loaded models clone the base's HTTP client and share a circuit breaker
/// keyed by the spec's runtime key (task + model), so repeated loads of the
/// same model reuse one breaker.
pub struct RemoteMistralProvider {
    /// Common remote-provider state: shared HTTP client plus the
    /// circuit-breaker registry used by `load`.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteMistralProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
30impl RemoteMistralProvider {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    #[cfg(test)]
36    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37        self.base.insert_test_breaker(key, age);
38    }
39
40    #[cfg(test)]
41    fn breaker_count(&self) -> usize {
42        self.base.breaker_count()
43    }
44
45    #[cfg(test)]
46    fn force_cleanup_now_for_test(&self) {
47        self.base.force_cleanup_now_for_test();
48    }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteMistralProvider {
53    fn provider_id(&self) -> &'static str {
54        "remote/mistral"
55    }
56
57    fn capabilities(&self) -> ProviderCapabilities {
58        ProviderCapabilities {
59            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
60        }
61    }
62
63    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64        let cb = self.base.circuit_breaker_for(spec);
65        let api_key = resolve_api_key(&spec.options, "api_key_env", "MISTRAL_API_KEY")?;
66
67        match spec.task {
68            ModelTask::Embed => {
69                let model = MistralEmbeddingModel {
70                    client: self.base.client.clone(),
71                    cb: cb.clone(),
72                    model_id: spec.model_id.clone(),
73                    api_key,
74                };
75                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
76                Ok(Arc::new(handle) as LoadedModelHandle)
77            }
78            ModelTask::Generate => {
79                let model = MistralGeneratorModel {
80                    client: self.base.client.clone(),
81                    cb,
82                    model_id: spec.model_id.clone(),
83                    api_key,
84                };
85                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
86                Ok(Arc::new(handle) as LoadedModelHandle)
87            }
88            _ => Err(RuntimeError::CapabilityMismatch(format!(
89                "Mistral provider does not support task {:?}",
90                spec.task
91            ))),
92        }
93    }
94
95    async fn health(&self) -> ProviderHealth {
96        ProviderHealth::Healthy
97    }
98}
99
/// Loaded handle for one Mistral embedding model.
struct MistralEmbeddingModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker for this model's runtime key; wraps every API call.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Mistral model identifier, e.g. `mistral-embed`.
    model_id: String,
    /// Bearer token resolved from the environment at load time.
    api_key: String,
}
106
107#[async_trait]
108impl EmbeddingModel for MistralEmbeddingModel {
109    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
110        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
111
112        self.cb
113            .call(move || async move {
114                let response = self
115                    .client
116                    .post("https://api.mistral.ai/v1/embeddings")
117                    .header("Authorization", format!("Bearer {}", self.api_key))
118                    .json(&json!({
119                        "model": self.model_id,
120                        "input": texts
121                    }))
122                    .send()
123                    .await
124                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
125
126                let body: serde_json::Value = check_http_status("Mistral", response)?
127                    .json()
128                    .await
129                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
130
131                let mut embeddings = Vec::new();
132                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
133                    for item in data {
134                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
135                            let vec: Vec<f32> = embedding
136                                .iter()
137                                .filter_map(|v| v.as_f64().map(|f| f as f32))
138                                .collect();
139                            embeddings.push(vec);
140                        }
141                    }
142                }
143                Ok(embeddings)
144            })
145            .await
146    }
147
148    fn dimensions(&self) -> u32 {
149        // All current Mistral embedding models use 1024 dimensions.
150        1024
151    }
152
153    fn model_id(&self) -> &str {
154        &self.model_id
155    }
156}
157
/// Loaded handle for one Mistral chat-completion model.
struct MistralGeneratorModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker for this model's runtime key; wraps every API call.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Mistral model identifier, e.g. `mistral-small-latest`.
    model_id: String,
    /// Bearer token resolved from the environment at load time.
    api_key: String,
}
164
165#[async_trait]
166impl GeneratorModel for MistralGeneratorModel {
167    async fn generate(
168        &self,
169        messages: &[Message],
170        options: GenerationOptions,
171    ) -> Result<GenerationResult> {
172        let messages: Vec<serde_json::Value> = messages
173            .iter()
174            .map(|msg| {
175                let role = match msg.role {
176                    MessageRole::System => "system",
177                    MessageRole::User => "user",
178                    MessageRole::Assistant => "assistant",
179                };
180                json!({ "role": role, "content": msg.text() })
181            })
182            .collect();
183
184        self.cb
185            .call(move || async move {
186                let mut body = json!({
187                    "model": self.model_id,
188                    "messages": messages,
189                });
190
191                if let Some(max_tokens) = options.max_tokens {
192                    body["max_tokens"] = json!(max_tokens);
193                }
194                if let Some(temperature) = options.temperature {
195                    body["temperature"] = json!(temperature);
196                }
197                if let Some(top_p) = options.top_p {
198                    body["top_p"] = json!(top_p);
199                }
200
201                let response = self
202                    .client
203                    .post("https://api.mistral.ai/v1/chat/completions")
204                    .header("Authorization", format!("Bearer {}", self.api_key))
205                    .json(&body)
206                    .send()
207                    .await
208                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
209
210                let body: serde_json::Value = check_http_status("Mistral", response)?
211                    .json()
212                    .await
213                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
214
215                let text = body["choices"][0]["message"]["content"]
216                    .as_str()
217                    .unwrap_or("")
218                    .to_string();
219
220                let usage = body.get("usage").map(|u| TokenUsage {
221                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
222                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
223                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
224                });
225
226                Ok(GenerationResult {
227                    text,
228                    usage,
229                    images: vec![],
230                    audio: None,
231                })
232            })
233            .await
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes the tests below: each one mutates the process-global
    // MISTRAL_API_KEY environment variable, so they must not interleave.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    // Builds a minimal alias spec targeting this provider; every optional
    // policy field is left unset/permissive.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/mistral".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    // Two aliases with the same task and model_id must share one breaker:
    // the runtime key ignores the alias name.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: ENV_LOCK is held, so no other test in this module mutates
        // the environment concurrently.
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let s2 = spec("embed/b", ModelTask::Embed, "mistral-embed");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // NOTE(review): skipped if an assert above panics; harmless today
        // because every test here sets the variable itself first.
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // Different task/model combinations get isolated breakers.
    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let gen_spec = spec("chat/a", ModelTask::Generate, "mistral-small-latest");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // An entry older than BREAKER_TTL is evicted by cleanup; a fresh entry
    // survives and is reused by the subsequent load.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "mistral-embed");
        let fresh = spec("chat/fresh", ModelTask::Generate, "mistral-small-latest");
        // Age the stale entry 5s past the TTL so it is unambiguously expired.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // Rerank is outside this provider's capabilities: load must fail with
    // the CapabilityMismatch message from `load`.
    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let s = spec("rerank/a", ModelTask::Rerank, "mistral-embed");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        // SAFETY: ENV_LOCK is held (see above).
        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }
}