1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6 ModelProvider, ProviderCapabilities, ProviderHealth, RerankerModel, ScoredDoc, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
13pub struct RemoteCohereProvider {
19 base: RemoteProviderBase,
20}
21
22impl Default for RemoteCohereProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
30impl RemoteCohereProvider {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[cfg(test)]
36 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37 self.base.insert_test_breaker(key, age);
38 }
39
40 #[cfg(test)]
41 fn breaker_count(&self) -> usize {
42 self.base.breaker_count()
43 }
44
45 #[cfg(test)]
46 fn force_cleanup_now_for_test(&self) {
47 self.base.force_cleanup_now_for_test();
48 }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteCohereProvider {
53 fn provider_id(&self) -> &'static str {
54 "remote/cohere"
55 }
56
57 fn capabilities(&self) -> ProviderCapabilities {
58 ProviderCapabilities {
59 supported_tasks: vec![ModelTask::Embed, ModelTask::Generate, ModelTask::Rerank],
60 }
61 }
62
63 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64 let cb = self.base.circuit_breaker_for(spec);
65 let api_key = resolve_api_key(&spec.options, "api_key_env", "CO_API_KEY")?;
66
67 let input_type = spec
68 .options
69 .get("input_type")
70 .and_then(|v| v.as_str())
71 .unwrap_or("search_document")
72 .to_string();
73
74 match spec.task {
75 ModelTask::Embed => {
76 let model = CohereEmbeddingModel {
77 client: self.base.client.clone(),
78 cb: cb.clone(),
79 model_id: spec.model_id.clone(),
80 api_key,
81 input_type,
82 };
83 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
84 Ok(Arc::new(handle) as LoadedModelHandle)
85 }
86 ModelTask::Generate => {
87 let model = CohereGeneratorModel {
88 client: self.base.client.clone(),
89 cb,
90 model_id: spec.model_id.clone(),
91 api_key,
92 };
93 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
94 Ok(Arc::new(handle) as LoadedModelHandle)
95 }
96 ModelTask::Rerank => {
97 let model = CohereRerankerModel {
98 client: self.base.client.clone(),
99 cb,
100 model_id: spec.model_id.clone(),
101 api_key,
102 };
103 let handle: Arc<dyn RerankerModel> = Arc::new(model);
104 Ok(Arc::new(handle) as LoadedModelHandle)
105 }
106 }
107 }
108
109 async fn health(&self) -> ProviderHealth {
110 ProviderHealth::Healthy
111 }
112}
113
114struct CohereEmbeddingModel {
115 client: Client,
116 cb: crate::reliability::CircuitBreakerWrapper,
117 model_id: String,
118 api_key: String,
119 input_type: String,
120}
121
122#[async_trait]
123impl EmbeddingModel for CohereEmbeddingModel {
124 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
125 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
126
127 self.cb
128 .call(move || async move {
129 let response = self
130 .client
131 .post("https://api.cohere.com/v2/embed")
132 .header("Authorization", format!("Bearer {}", self.api_key))
133 .json(&json!({
134 "texts": texts,
135 "model": self.model_id,
136 "input_type": self.input_type,
137 "embedding_types": ["float"]
138 }))
139 .send()
140 .await
141 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
142
143 let body: serde_json::Value = check_http_status("Cohere", response)?
144 .json()
145 .await
146 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
147
148 let float_embeddings = body
149 .get("embeddings")
150 .and_then(|e| e.get("float"))
151 .and_then(|f| f.as_array())
152 .ok_or_else(|| {
153 RuntimeError::ApiError(
154 "Invalid Cohere embedding response format".to_string(),
155 )
156 })?;
157
158 let mut result = Vec::new();
159 for embedding in float_embeddings {
160 if let Some(values) = embedding.as_array() {
161 let vec: Vec<f32> = values
162 .iter()
163 .filter_map(|v| v.as_f64().map(|f| f as f32))
164 .collect();
165 result.push(vec);
166 }
167 }
168 Ok(result)
169 })
170 .await
171 }
172
173 fn dimensions(&self) -> u32 {
174 match self.model_id.as_str() {
175 "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
176 _ => 1024,
177 }
178 }
179
180 fn model_id(&self) -> &str {
181 &self.model_id
182 }
183}
184
185struct CohereGeneratorModel {
186 client: Client,
187 cb: crate::reliability::CircuitBreakerWrapper,
188 model_id: String,
189 api_key: String,
190}
191
192#[async_trait]
193impl GeneratorModel for CohereGeneratorModel {
194 async fn generate(
195 &self,
196 messages: &[String],
197 options: GenerationOptions,
198 ) -> Result<GenerationResult> {
199 let messages: Vec<serde_json::Value> = messages
200 .iter()
201 .enumerate()
202 .map(|(i, content)| {
203 let role = if i % 2 == 0 { "user" } else { "assistant" };
204 json!({ "role": role, "content": content })
205 })
206 .collect();
207
208 self.cb
209 .call(move || async move {
210 let mut body = json!({
211 "model": self.model_id,
212 "messages": messages,
213 });
214
215 if let Some(max_tokens) = options.max_tokens {
216 body["max_tokens"] = json!(max_tokens);
217 }
218 if let Some(temperature) = options.temperature {
219 body["temperature"] = json!(temperature);
220 }
221 if let Some(top_p) = options.top_p {
222 body["p"] = json!(top_p);
223 }
224
225 let response = self
226 .client
227 .post("https://api.cohere.com/v2/chat")
228 .header("Authorization", format!("Bearer {}", self.api_key))
229 .json(&body)
230 .send()
231 .await
232 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
233
234 let body: serde_json::Value = check_http_status("Cohere", response)?
235 .json()
236 .await
237 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
238
239 let text = body
240 .get("message")
241 .and_then(|m| m.get("content"))
242 .and_then(|c| c.as_array())
243 .and_then(|arr| arr.first())
244 .and_then(|item| item.get("text"))
245 .and_then(|t| t.as_str())
246 .unwrap_or("")
247 .to_string();
248
249 let usage = body.get("usage").map(|u| {
250 let input = u
251 .get("tokens")
252 .and_then(|t| t.get("input_tokens"))
253 .and_then(|v| v.as_u64())
254 .unwrap_or(0);
255 let output = u
256 .get("tokens")
257 .and_then(|t| t.get("output_tokens"))
258 .and_then(|v| v.as_u64())
259 .unwrap_or(0);
260 TokenUsage {
261 prompt_tokens: input as usize,
262 completion_tokens: output as usize,
263 total_tokens: (input + output) as usize,
264 }
265 });
266
267 Ok(GenerationResult { text, usage })
268 })
269 .await
270 }
271}
272
273struct CohereRerankerModel {
274 client: Client,
275 cb: crate::reliability::CircuitBreakerWrapper,
276 model_id: String,
277 api_key: String,
278}
279
280#[async_trait]
281impl RerankerModel for CohereRerankerModel {
282 async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
283 let query = query.to_string();
284 let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
285
286 self.cb
287 .call(move || async move {
288 let response = self
289 .client
290 .post("https://api.cohere.com/v2/rerank")
291 .header("Authorization", format!("Bearer {}", self.api_key))
292 .json(&json!({
293 "query": query,
294 "documents": docs,
295 "model": self.model_id,
296 }))
297 .send()
298 .await
299 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
300
301 let body: serde_json::Value = check_http_status("Cohere", response)?
302 .json()
303 .await
304 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
305
306 let results_json =
307 body.get("results")
308 .and_then(|r| r.as_array())
309 .ok_or_else(|| {
310 RuntimeError::ApiError("Invalid rerank response format".to_string())
311 })?;
312
313 let mut results = Vec::new();
314 for item in results_json {
315 let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
316 let score = item
317 .get("relevance_score")
318 .and_then(|s| s.as_f64())
319 .unwrap_or(0.0) as f32;
320 results.push(ScoredDoc {
321 index,
322 score,
323 text: None,
324 });
325 }
326 Ok(results)
327 })
328 .await
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module, all of which mutate `CO_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Environment variable the provider reads its API key from in these tests.
    const API_KEY_VAR: &str = "CO_API_KEY";

    /// Holds `ENV_LOCK` and keeps `CO_API_KEY` set for the lifetime of the guard.
    ///
    /// The variable is removed on drop, so it is cleaned up even when a test
    /// assertion panics.
    struct ApiKeyGuard {
        _lock: tokio::sync::MutexGuard<'static, ()>,
    }

    impl ApiKeyGuard {
        /// Acquires the lock, then sets `CO_API_KEY` to a dummy value.
        async fn acquire() -> Self {
            let lock = ENV_LOCK.lock().await;
            // SAFETY: `ENV_LOCK` is held, so no other test in this module touches
            // the environment concurrently. Assumes no code outside this module
            // reads or writes the environment while these tests run.
            unsafe { std::env::set_var(API_KEY_VAR, "test-key") };
            Self { _lock: lock }
        }
    }

    impl Drop for ApiKeyGuard {
        fn drop(&mut self) {
            // SAFETY: the lock field is released only after this body returns, so
            // the removal is still serialized against the other tests here.
            unsafe { std::env::remove_var(API_KEY_VAR) };
        }
    }

    /// Builds a minimal `remote/cohere` spec with no options.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/cohere".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let s2 = spec("embed/b", ModelTask::Embed, "embed-english-v3.0");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Same task and model under different aliases share one breaker.
        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let gen_spec = spec("chat/a", ModelTask::Generate, "command-r-plus");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 3);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embed-english-v3.0");
        let fresh = spec("chat/fresh", ModelTask::Generate, "command-r-plus");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        // Only the entry older than the TTL is expected to be gone.
        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn supports_all_three_tasks() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();

        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        assert!(provider.load(&embed).await.is_ok());

        let gen_spec = spec("gen/a", ModelTask::Generate, "command-r-plus");
        assert!(provider.load(&gen_spec).await.is_ok());

        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");
        assert!(provider.load(&rerank).await.is_ok());
    }
}