uni_xervo/provider/
cohere.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    ModelProvider, ProviderCapabilities, ProviderHealth, RerankerModel, ScoredDoc, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
13/// Remote provider that calls the [Cohere API](https://docs.cohere.com/reference/about)
14/// for embedding, text generation (chat), and reranking.
15///
16/// Requires the `CO_API_KEY` environment variable (or a custom env var name
17/// via the `api_key_env` option).
18pub struct RemoteCohereProvider {
19    base: RemoteProviderBase,
20}
21
22impl Default for RemoteCohereProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
30impl RemoteCohereProvider {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    #[cfg(test)]
36    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37        self.base.insert_test_breaker(key, age);
38    }
39
40    #[cfg(test)]
41    fn breaker_count(&self) -> usize {
42        self.base.breaker_count()
43    }
44
45    #[cfg(test)]
46    fn force_cleanup_now_for_test(&self) {
47        self.base.force_cleanup_now_for_test();
48    }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteCohereProvider {
53    fn provider_id(&self) -> &'static str {
54        "remote/cohere"
55    }
56
57    fn capabilities(&self) -> ProviderCapabilities {
58        ProviderCapabilities {
59            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate, ModelTask::Rerank],
60        }
61    }
62
63    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64        let cb = self.base.circuit_breaker_for(spec);
65        let api_key = resolve_api_key(&spec.options, "api_key_env", "CO_API_KEY")?;
66
67        let input_type = spec
68            .options
69            .get("input_type")
70            .and_then(|v| v.as_str())
71            .unwrap_or("search_document")
72            .to_string();
73
74        match spec.task {
75            ModelTask::Embed => {
76                let model = CohereEmbeddingModel {
77                    client: self.base.client.clone(),
78                    cb: cb.clone(),
79                    model_id: spec.model_id.clone(),
80                    api_key,
81                    input_type,
82                };
83                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
84                Ok(Arc::new(handle) as LoadedModelHandle)
85            }
86            ModelTask::Generate => {
87                let model = CohereGeneratorModel {
88                    client: self.base.client.clone(),
89                    cb,
90                    model_id: spec.model_id.clone(),
91                    api_key,
92                };
93                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
94                Ok(Arc::new(handle) as LoadedModelHandle)
95            }
96            ModelTask::Rerank => {
97                let model = CohereRerankerModel {
98                    client: self.base.client.clone(),
99                    cb,
100                    model_id: spec.model_id.clone(),
101                    api_key,
102                };
103                let handle: Arc<dyn RerankerModel> = Arc::new(model);
104                Ok(Arc::new(handle) as LoadedModelHandle)
105            }
106        }
107    }
108
109    async fn health(&self) -> ProviderHealth {
110        ProviderHealth::Healthy
111    }
112}
113
114struct CohereEmbeddingModel {
115    client: Client,
116    cb: crate::reliability::CircuitBreakerWrapper,
117    model_id: String,
118    api_key: String,
119    input_type: String,
120}
121
122#[async_trait]
123impl EmbeddingModel for CohereEmbeddingModel {
124    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
125        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
126
127        self.cb
128            .call(move || async move {
129                let response = self
130                    .client
131                    .post("https://api.cohere.com/v2/embed")
132                    .header("Authorization", format!("Bearer {}", self.api_key))
133                    .json(&json!({
134                        "texts": texts,
135                        "model": self.model_id,
136                        "input_type": self.input_type,
137                        "embedding_types": ["float"]
138                    }))
139                    .send()
140                    .await
141                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
142
143                let body: serde_json::Value = check_http_status("Cohere", response)?
144                    .json()
145                    .await
146                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
147
148                let float_embeddings = body
149                    .get("embeddings")
150                    .and_then(|e| e.get("float"))
151                    .and_then(|f| f.as_array())
152                    .ok_or_else(|| {
153                        RuntimeError::ApiError(
154                            "Invalid Cohere embedding response format".to_string(),
155                        )
156                    })?;
157
158                let mut result = Vec::new();
159                for embedding in float_embeddings {
160                    if let Some(values) = embedding.as_array() {
161                        let vec: Vec<f32> = values
162                            .iter()
163                            .filter_map(|v| v.as_f64().map(|f| f as f32))
164                            .collect();
165                        result.push(vec);
166                    }
167                }
168                Ok(result)
169            })
170            .await
171    }
172
173    fn dimensions(&self) -> u32 {
174        match self.model_id.as_str() {
175            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
176            _ => 1024,
177        }
178    }
179
180    fn model_id(&self) -> &str {
181        &self.model_id
182    }
183}
184
185struct CohereGeneratorModel {
186    client: Client,
187    cb: crate::reliability::CircuitBreakerWrapper,
188    model_id: String,
189    api_key: String,
190}
191
192#[async_trait]
193impl GeneratorModel for CohereGeneratorModel {
194    async fn generate(
195        &self,
196        messages: &[String],
197        options: GenerationOptions,
198    ) -> Result<GenerationResult> {
199        let messages: Vec<serde_json::Value> = messages
200            .iter()
201            .enumerate()
202            .map(|(i, content)| {
203                let role = if i % 2 == 0 { "user" } else { "assistant" };
204                json!({ "role": role, "content": content })
205            })
206            .collect();
207
208        self.cb
209            .call(move || async move {
210                let mut body = json!({
211                    "model": self.model_id,
212                    "messages": messages,
213                });
214
215                if let Some(max_tokens) = options.max_tokens {
216                    body["max_tokens"] = json!(max_tokens);
217                }
218                if let Some(temperature) = options.temperature {
219                    body["temperature"] = json!(temperature);
220                }
221                if let Some(top_p) = options.top_p {
222                    body["p"] = json!(top_p);
223                }
224
225                let response = self
226                    .client
227                    .post("https://api.cohere.com/v2/chat")
228                    .header("Authorization", format!("Bearer {}", self.api_key))
229                    .json(&body)
230                    .send()
231                    .await
232                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
233
234                let body: serde_json::Value = check_http_status("Cohere", response)?
235                    .json()
236                    .await
237                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
238
239                let text = body
240                    .get("message")
241                    .and_then(|m| m.get("content"))
242                    .and_then(|c| c.as_array())
243                    .and_then(|arr| arr.first())
244                    .and_then(|item| item.get("text"))
245                    .and_then(|t| t.as_str())
246                    .unwrap_or("")
247                    .to_string();
248
249                let usage = body.get("usage").map(|u| {
250                    let input = u
251                        .get("tokens")
252                        .and_then(|t| t.get("input_tokens"))
253                        .and_then(|v| v.as_u64())
254                        .unwrap_or(0);
255                    let output = u
256                        .get("tokens")
257                        .and_then(|t| t.get("output_tokens"))
258                        .and_then(|v| v.as_u64())
259                        .unwrap_or(0);
260                    TokenUsage {
261                        prompt_tokens: input as usize,
262                        completion_tokens: output as usize,
263                        total_tokens: (input + output) as usize,
264                    }
265                });
266
267                Ok(GenerationResult { text, usage })
268            })
269            .await
270    }
271}
272
273struct CohereRerankerModel {
274    client: Client,
275    cb: crate::reliability::CircuitBreakerWrapper,
276    model_id: String,
277    api_key: String,
278}
279
280#[async_trait]
281impl RerankerModel for CohereRerankerModel {
282    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
283        let query = query.to_string();
284        let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
285
286        self.cb
287            .call(move || async move {
288                let response = self
289                    .client
290                    .post("https://api.cohere.com/v2/rerank")
291                    .header("Authorization", format!("Bearer {}", self.api_key))
292                    .json(&json!({
293                        "query": query,
294                        "documents": docs,
295                        "model": self.model_id,
296                    }))
297                    .send()
298                    .await
299                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
300
301                let body: serde_json::Value = check_http_status("Cohere", response)?
302                    .json()
303                    .await
304                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
305
306                let results_json =
307                    body.get("results")
308                        .and_then(|r| r.as_array())
309                        .ok_or_else(|| {
310                            RuntimeError::ApiError("Invalid rerank response format".to_string())
311                        })?;
312
313                let mut results = Vec::new();
314                for item in results_json {
315                    let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
316                    let score = item
317                        .get("relevance_score")
318                        .and_then(|s| s.as_f64())
319                        .unwrap_or(0.0) as f32;
320                    results.push(ScoredDoc {
321                        index,
322                        score,
323                        text: None,
324                    });
325                }
326                Ok(results)
327            })
328            .await
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module that mutate `CO_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// RAII guard: sets `CO_API_KEY` on creation and removes it on drop, so the
    /// variable is cleared even when a test panics part-way through.
    ///
    /// Bind it *after* the `ENV_LOCK` guard so it is dropped (and the variable
    /// removed) before the lock is released.
    struct ApiKeyEnv;

    impl ApiKeyEnv {
        fn set() -> Self {
            // SAFETY: callers hold ENV_LOCK, which serializes environment
            // mutation among the tests in this module.
            // NOTE(review): tests in other modules do not take this lock.
            unsafe { std::env::set_var("CO_API_KEY", "test-key") };
            ApiKeyEnv
        }
    }

    impl Drop for ApiKeyEnv {
        fn drop(&mut self) {
            // SAFETY: same invariant as `ApiKeyEnv::set`; the guard is dropped
            // while ENV_LOCK is still held.
            unsafe { std::env::remove_var("CO_API_KEY") };
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/cohere".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnv::set();

        let provider = RemoteCohereProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let s2 = spec("embed/b", ModelTask::Embed, "embed-english-v3.0");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnv::set();

        let provider = RemoteCohereProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let gen_spec = spec("chat/a", ModelTask::Generate, "command-r-plus");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 3);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnv::set();

        let provider = RemoteCohereProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embed-english-v3.0");
        let fresh = spec("chat/fresh", ModelTask::Generate, "command-r-plus");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn supports_all_three_tasks() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnv::set();

        let provider = RemoteCohereProvider::new();

        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        assert!(provider.load(&embed).await.is_ok());

        let gen_spec = spec("gen/a", ModelTask::Generate, "command-r-plus");
        assert!(provider.load(&gen_spec).await.is_ok());

        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");
        assert!(provider.load(&rerank).await.is_ok());
    }
}