1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6 Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, RerankerModel,
7 ScoredDoc, TokenUsage,
8};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde_json::json;
12use std::sync::Arc;
13
/// Model provider for Cohere's hosted HTTP API, registered as `remote/cohere`.
///
/// Hands out embedding, generation, and reranking model handles through its
/// `ModelProvider` implementation. The HTTP client and the per-model circuit
/// breakers both come from the wrapped [`RemoteProviderBase`].
pub struct RemoteCohereProvider {
    // Shared remote-provider state: supplies `client` and `circuit_breaker_for`.
    base: RemoteProviderBase,
}
22
23impl Default for RemoteCohereProvider {
24 fn default() -> Self {
25 Self {
26 base: RemoteProviderBase::new(),
27 }
28 }
29}
30
31impl RemoteCohereProvider {
32 pub fn new() -> Self {
33 Self::default()
34 }
35
36 #[cfg(test)]
37 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
38 self.base.insert_test_breaker(key, age);
39 }
40
41 #[cfg(test)]
42 fn breaker_count(&self) -> usize {
43 self.base.breaker_count()
44 }
45
46 #[cfg(test)]
47 fn force_cleanup_now_for_test(&self) {
48 self.base.force_cleanup_now_for_test();
49 }
50}
51
52#[async_trait]
53impl ModelProvider for RemoteCohereProvider {
54 fn provider_id(&self) -> &'static str {
55 "remote/cohere"
56 }
57
58 fn capabilities(&self) -> ProviderCapabilities {
59 ProviderCapabilities {
60 supported_tasks: vec![ModelTask::Embed, ModelTask::Generate, ModelTask::Rerank],
61 }
62 }
63
64 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
65 let cb = self.base.circuit_breaker_for(spec);
66 let api_key = resolve_api_key(&spec.options, "api_key_env", "CO_API_KEY")?;
67
68 let input_type = spec
69 .options
70 .get("input_type")
71 .and_then(|v| v.as_str())
72 .unwrap_or("search_document")
73 .to_string();
74
75 match spec.task {
76 ModelTask::Embed => {
77 let model = CohereEmbeddingModel {
78 client: self.base.client.clone(),
79 cb: cb.clone(),
80 model_id: spec.model_id.clone(),
81 api_key,
82 input_type,
83 };
84 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
85 Ok(Arc::new(handle) as LoadedModelHandle)
86 }
87 ModelTask::Generate => {
88 let model = CohereGeneratorModel {
89 client: self.base.client.clone(),
90 cb,
91 model_id: spec.model_id.clone(),
92 api_key,
93 };
94 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
95 Ok(Arc::new(handle) as LoadedModelHandle)
96 }
97 ModelTask::Rerank => {
98 let model = CohereRerankerModel {
99 client: self.base.client.clone(),
100 cb,
101 model_id: spec.model_id.clone(),
102 api_key,
103 };
104 let handle: Arc<dyn RerankerModel> = Arc::new(model);
105 Ok(Arc::new(handle) as LoadedModelHandle)
106 }
107 }
108 }
109
110 async fn health(&self) -> ProviderHealth {
111 ProviderHealth::Healthy
112 }
113}
114
/// Embedding model that calls Cohere's `v2/embed` endpoint.
struct CohereEmbeddingModel {
    // HTTP client used to POST requests.
    client: Client,
    // Circuit breaker; every request is issued through `cb.call`.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Sent as the `model` field; also selects the value returned by `dimensions`.
    model_id: String,
    // Sent as the `Authorization: Bearer …` token.
    api_key: String,
    // Sent as the `input_type` field of the embed request.
    input_type: String,
}
122
123#[async_trait]
124impl EmbeddingModel for CohereEmbeddingModel {
125 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
126 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
127
128 self.cb
129 .call(move || async move {
130 let response = self
131 .client
132 .post("https://api.cohere.com/v2/embed")
133 .header("Authorization", format!("Bearer {}", self.api_key))
134 .json(&json!({
135 "texts": texts,
136 "model": self.model_id,
137 "input_type": self.input_type,
138 "embedding_types": ["float"]
139 }))
140 .send()
141 .await
142 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
143
144 let body: serde_json::Value = check_http_status("Cohere", response)?
145 .json()
146 .await
147 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
148
149 let float_embeddings = body
150 .get("embeddings")
151 .and_then(|e| e.get("float"))
152 .and_then(|f| f.as_array())
153 .ok_or_else(|| {
154 RuntimeError::ApiError(
155 "Invalid Cohere embedding response format".to_string(),
156 )
157 })?;
158
159 let mut result = Vec::new();
160 for embedding in float_embeddings {
161 if let Some(values) = embedding.as_array() {
162 let vec: Vec<f32> = values
163 .iter()
164 .filter_map(|v| v.as_f64().map(|f| f as f32))
165 .collect();
166 result.push(vec);
167 }
168 }
169 Ok(result)
170 })
171 .await
172 }
173
174 fn dimensions(&self) -> u32 {
175 match self.model_id.as_str() {
176 "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
177 _ => 1024,
178 }
179 }
180
181 fn model_id(&self) -> &str {
182 &self.model_id
183 }
184}
185
/// Text-generation model that calls Cohere's `v2/chat` endpoint.
struct CohereGeneratorModel {
    // HTTP client used to POST requests.
    client: Client,
    // Circuit breaker; every request is issued through `cb.call`.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Sent as the `model` field of the chat request.
    model_id: String,
    // Sent as the `Authorization: Bearer …` token.
    api_key: String,
}
192
193#[async_trait]
194impl GeneratorModel for CohereGeneratorModel {
195 async fn generate(
196 &self,
197 messages: &[Message],
198 options: GenerationOptions,
199 ) -> Result<GenerationResult> {
200 let messages: Vec<serde_json::Value> = messages
201 .iter()
202 .map(|msg| {
203 let role = match msg.role {
204 MessageRole::System => "system",
205 MessageRole::User => "user",
206 MessageRole::Assistant => "assistant",
207 };
208 json!({ "role": role, "content": msg.text() })
209 })
210 .collect();
211
212 self.cb
213 .call(move || async move {
214 let mut body = json!({
215 "model": self.model_id,
216 "messages": messages,
217 });
218
219 if let Some(max_tokens) = options.max_tokens {
220 body["max_tokens"] = json!(max_tokens);
221 }
222 if let Some(temperature) = options.temperature {
223 body["temperature"] = json!(temperature);
224 }
225 if let Some(top_p) = options.top_p {
226 body["p"] = json!(top_p);
227 }
228
229 let response = self
230 .client
231 .post("https://api.cohere.com/v2/chat")
232 .header("Authorization", format!("Bearer {}", self.api_key))
233 .json(&body)
234 .send()
235 .await
236 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
237
238 let body: serde_json::Value = check_http_status("Cohere", response)?
239 .json()
240 .await
241 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
242
243 let text = body
244 .get("message")
245 .and_then(|m| m.get("content"))
246 .and_then(|c| c.as_array())
247 .and_then(|arr| arr.first())
248 .and_then(|item| item.get("text"))
249 .and_then(|t| t.as_str())
250 .unwrap_or("")
251 .to_string();
252
253 let usage = body.get("usage").map(|u| {
254 let input = u
255 .get("tokens")
256 .and_then(|t| t.get("input_tokens"))
257 .and_then(|v| v.as_u64())
258 .unwrap_or(0);
259 let output = u
260 .get("tokens")
261 .and_then(|t| t.get("output_tokens"))
262 .and_then(|v| v.as_u64())
263 .unwrap_or(0);
264 TokenUsage {
265 prompt_tokens: input as usize,
266 completion_tokens: output as usize,
267 total_tokens: (input + output) as usize,
268 }
269 });
270
271 Ok(GenerationResult {
272 text,
273 usage,
274 images: vec![],
275 audio: None,
276 })
277 })
278 .await
279 }
280}
281
/// Reranking model that calls Cohere's `v2/rerank` endpoint.
struct CohereRerankerModel {
    // HTTP client used to POST requests.
    client: Client,
    // Circuit breaker; every request is issued through `cb.call`.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Sent as the `model` field of the rerank request.
    model_id: String,
    // Sent as the `Authorization: Bearer …` token.
    api_key: String,
}
288
289#[async_trait]
290impl RerankerModel for CohereRerankerModel {
291 async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
292 let query = query.to_string();
293 let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
294
295 self.cb
296 .call(move || async move {
297 let response = self
298 .client
299 .post("https://api.cohere.com/v2/rerank")
300 .header("Authorization", format!("Bearer {}", self.api_key))
301 .json(&json!({
302 "query": query,
303 "documents": docs,
304 "model": self.model_id,
305 }))
306 .send()
307 .await
308 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
309
310 let body: serde_json::Value = check_http_status("Cohere", response)?
311 .json()
312 .await
313 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
314
315 let results_json =
316 body.get("results")
317 .and_then(|r| r.as_array())
318 .ok_or_else(|| {
319 RuntimeError::ApiError("Invalid rerank response format".to_string())
320 })?;
321
322 let mut results = Vec::new();
323 for item in results_json {
324 let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
325 let score = item
326 .get("relevance_score")
327 .and_then(|s| s.as_f64())
328 .unwrap_or(0.0) as f32;
329 results.push(ScoredDoc {
330 index,
331 score,
332 text: None,
333 });
334 }
335 Ok(results)
336 })
337 .await
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module, all of which mutate `CO_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Holds `ENV_LOCK` and keeps `CO_API_KEY` set for as long as it is alive.
    ///
    /// The variable is removed in `Drop`, so it is cleaned up even when a test
    /// assertion panics. `Drop::drop` runs before the lock field is released,
    /// so the removal still happens under the lock.
    struct ApiKeyEnv {
        _lock: tokio::sync::MutexGuard<'static, ()>,
    }

    impl ApiKeyEnv {
        /// Waits for the lock, then sets `CO_API_KEY` to a dummy value.
        async fn acquire() -> Self {
            let lock = ENV_LOCK.lock().await;
            // SAFETY: tests in this module only touch the environment while
            // holding ENV_LOCK. NOTE(review): threads outside this module that
            // read or write the environment concurrently are not covered.
            unsafe { std::env::set_var("CO_API_KEY", "test-key") };
            Self { _lock: lock }
        }
    }

    impl Drop for ApiKeyEnv {
        fn drop(&mut self) {
            // SAFETY: ENV_LOCK is still held here; see `acquire` for the caveat.
            unsafe { std::env::remove_var("CO_API_KEY") };
        }
    }

    /// Builds a minimal `remote/cohere` alias spec with no options.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/cohere".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    /// Two aliases with the same task and model must share a single breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _env = ApiKeyEnv::acquire().await;

        let provider = RemoteCohereProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let s2 = spec("embed/b", ModelTask::Embed, "embed-english-v3.0");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    /// Different task/model combinations must each get their own breaker.
    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _env = ApiKeyEnv::acquire().await;

        let provider = RemoteCohereProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let gen_spec = spec("chat/a", ModelTask::Generate, "command-r-plus");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 3);
    }

    /// A breaker older than `BREAKER_TTL` is evicted; a recent one survives.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _env = ApiKeyEnv::acquire().await;

        let provider = RemoteCohereProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embed-english-v3.0");
        let fresh = spec("chat/fresh", ModelTask::Generate, "command-r-plus");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    /// `load` succeeds for every task the provider advertises.
    #[tokio::test]
    async fn supports_all_three_tasks() {
        let _env = ApiKeyEnv::acquire().await;

        let provider = RemoteCohereProvider::new();

        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        assert!(provider.load(&embed).await.is_ok());

        let gen_spec = spec("gen/a", ModelTask::Generate, "command-r-plus");
        assert!(provider.load(&gen_spec).await.is_ok());

        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");
        assert!(provider.load(&rerank).await.is_ok());
    }
}