1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6 Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider for the Mistral AI HTTP API, supporting the embeddings
/// and chat-completions endpoints. Shared plumbing (HTTP client and the
/// per-model circuit-breaker registry) lives in [`RemoteProviderBase`].
pub struct RemoteMistralProvider {
    // Shared HTTP client + circuit-breaker bookkeeping common to all
    // remote providers.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteMistralProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
30impl RemoteMistralProvider {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[cfg(test)]
36 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37 self.base.insert_test_breaker(key, age);
38 }
39
40 #[cfg(test)]
41 fn breaker_count(&self) -> usize {
42 self.base.breaker_count()
43 }
44
45 #[cfg(test)]
46 fn force_cleanup_now_for_test(&self) {
47 self.base.force_cleanup_now_for_test();
48 }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteMistralProvider {
53 fn provider_id(&self) -> &'static str {
54 "remote/mistral"
55 }
56
57 fn capabilities(&self) -> ProviderCapabilities {
58 ProviderCapabilities {
59 supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
60 }
61 }
62
63 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64 let cb = self.base.circuit_breaker_for(spec);
65 let api_key = resolve_api_key(&spec.options, "api_key_env", "MISTRAL_API_KEY")?;
66
67 match spec.task {
68 ModelTask::Embed => {
69 let model = MistralEmbeddingModel {
70 client: self.base.client.clone(),
71 cb: cb.clone(),
72 model_id: spec.model_id.clone(),
73 api_key,
74 };
75 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
76 Ok(Arc::new(handle) as LoadedModelHandle)
77 }
78 ModelTask::Generate => {
79 let model = MistralGeneratorModel {
80 client: self.base.client.clone(),
81 cb,
82 model_id: spec.model_id.clone(),
83 api_key,
84 };
85 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
86 Ok(Arc::new(handle) as LoadedModelHandle)
87 }
88 _ => Err(RuntimeError::CapabilityMismatch(format!(
89 "Mistral provider does not support task {:?}",
90 spec.task
91 ))),
92 }
93 }
94
95 async fn health(&self) -> ProviderHealth {
96 ProviderHealth::Healthy
97 }
98}
99
/// Loaded handle for a Mistral embedding model; one instance per alias,
/// sharing the provider's HTTP client and a per-model circuit breaker.
struct MistralEmbeddingModel {
    // Shared reqwest client (cloned from RemoteProviderBase; cheap to clone).
    client: Client,
    // Circuit breaker scoped to this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Mistral model identifier sent in request payloads (e.g. "mistral-embed").
    model_id: String,
    // Bearer token resolved at load time from spec options / MISTRAL_API_KEY.
    api_key: String,
}
106
107#[async_trait]
108impl EmbeddingModel for MistralEmbeddingModel {
109 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
110 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
111
112 self.cb
113 .call(move || async move {
114 let response = self
115 .client
116 .post("https://api.mistral.ai/v1/embeddings")
117 .header("Authorization", format!("Bearer {}", self.api_key))
118 .json(&json!({
119 "model": self.model_id,
120 "input": texts
121 }))
122 .send()
123 .await
124 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
125
126 let body: serde_json::Value = check_http_status("Mistral", response)?
127 .json()
128 .await
129 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
130
131 let mut embeddings = Vec::new();
132 if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
133 for item in data {
134 if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
135 let vec: Vec<f32> = embedding
136 .iter()
137 .filter_map(|v| v.as_f64().map(|f| f as f32))
138 .collect();
139 embeddings.push(vec);
140 }
141 }
142 }
143 Ok(embeddings)
144 })
145 .await
146 }
147
148 fn dimensions(&self) -> u32 {
149 1024
151 }
152
153 fn model_id(&self) -> &str {
154 &self.model_id
155 }
156}
157
/// Loaded handle for a Mistral chat-completion model; one instance per
/// alias, sharing the provider's HTTP client and a per-model circuit breaker.
struct MistralGeneratorModel {
    // Shared reqwest client (cloned from RemoteProviderBase; cheap to clone).
    client: Client,
    // Circuit breaker scoped to this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Mistral model identifier sent in request payloads
    // (e.g. "mistral-small-latest").
    model_id: String,
    // Bearer token resolved at load time from spec options / MISTRAL_API_KEY.
    api_key: String,
}
164
165#[async_trait]
166impl GeneratorModel for MistralGeneratorModel {
167 async fn generate(
168 &self,
169 messages: &[Message],
170 options: GenerationOptions,
171 ) -> Result<GenerationResult> {
172 let messages: Vec<serde_json::Value> = messages
173 .iter()
174 .map(|msg| {
175 let role = match msg.role {
176 MessageRole::System => "system",
177 MessageRole::User => "user",
178 MessageRole::Assistant => "assistant",
179 };
180 json!({ "role": role, "content": msg.text() })
181 })
182 .collect();
183
184 self.cb
185 .call(move || async move {
186 let mut body = json!({
187 "model": self.model_id,
188 "messages": messages,
189 });
190
191 if let Some(max_tokens) = options.max_tokens {
192 body["max_tokens"] = json!(max_tokens);
193 }
194 if let Some(temperature) = options.temperature {
195 body["temperature"] = json!(temperature);
196 }
197 if let Some(top_p) = options.top_p {
198 body["top_p"] = json!(top_p);
199 }
200
201 let response = self
202 .client
203 .post("https://api.mistral.ai/v1/chat/completions")
204 .header("Authorization", format!("Bearer {}", self.api_key))
205 .json(&body)
206 .send()
207 .await
208 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
209
210 let body: serde_json::Value = check_http_status("Mistral", response)?
211 .json()
212 .await
213 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
214
215 let text = body["choices"][0]["message"]["content"]
216 .as_str()
217 .unwrap_or("")
218 .to_string();
219
220 let usage = body.get("usage").map(|u| TokenUsage {
221 prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
222 completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
223 total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
224 });
225
226 Ok(GenerationResult {
227 text,
228 usage,
229 images: vec![],
230 audio: None,
231 })
232 })
233 .await
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes tests that mutate MISTRAL_API_KEY: the process environment
    // is global, so concurrent async tests would otherwise race on
    // set_var/remove_var.
    // NOTE(review): if an assertion fails before remove_var, the key leaks
    // into the process env for later tests — consider a drop guard.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    // Builds a minimal alias spec targeting this provider: lazy warmup,
    // no retries/timeouts, and no per-alias options (so the API key is
    // resolved from MISTRAL_API_KEY).
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/mistral".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    // Two aliases with the same task + model share one breaker: the breaker
    // is keyed by runtime key, not by alias.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let s2 = spec("embed/b", ModelTask::Embed, "mistral-embed");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // Different task/model combinations get distinct breakers, so a tripped
    // embed breaker cannot block chat traffic (and vice versa).
    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "mistral-embed");
        let gen_spec = spec("chat/a", ModelTask::Generate, "mistral-small-latest");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // Seeds one breaker older than the TTL and one fresh, forces the
    // cleanup window open, then triggers cleanup via a load; only the
    // stale entry should be evicted.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "mistral-embed");
        let fresh = spec("chat/fresh", ModelTask::Generate, "mistral-small-latest");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }

    // Unsupported tasks (here: Rerank) must surface CapabilityMismatch
    // rather than silently returning a handle.
    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("MISTRAL_API_KEY", "test-key") };

        let provider = RemoteMistralProvider::new();
        let s = spec("rerank/a", ModelTask::Rerank, "mistral-embed");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        unsafe { std::env::remove_var("MISTRAL_API_KEY") };
    }
}