uni_xervo/provider/
cohere.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, RerankerModel,
7    ScoredDoc, TokenUsage,
8};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde_json::json;
12use std::sync::Arc;
13
14/// Remote provider that calls the [Cohere API](https://docs.cohere.com/reference/about)
15/// for embedding, text generation (chat), and reranking.
16///
17/// Requires the `CO_API_KEY` environment variable (or a custom env var name
18/// via the `api_key_env` option).
19pub struct RemoteCohereProvider {
20    base: RemoteProviderBase,
21}
22
23impl Default for RemoteCohereProvider {
24    fn default() -> Self {
25        Self {
26            base: RemoteProviderBase::new(),
27        }
28    }
29}
30
31impl RemoteCohereProvider {
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    #[cfg(test)]
37    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
38        self.base.insert_test_breaker(key, age);
39    }
40
41    #[cfg(test)]
42    fn breaker_count(&self) -> usize {
43        self.base.breaker_count()
44    }
45
46    #[cfg(test)]
47    fn force_cleanup_now_for_test(&self) {
48        self.base.force_cleanup_now_for_test();
49    }
50}
51
52#[async_trait]
53impl ModelProvider for RemoteCohereProvider {
54    fn provider_id(&self) -> &'static str {
55        "remote/cohere"
56    }
57
58    fn capabilities(&self) -> ProviderCapabilities {
59        ProviderCapabilities {
60            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate, ModelTask::Rerank],
61        }
62    }
63
64    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
65        let cb = self.base.circuit_breaker_for(spec);
66        let api_key = resolve_api_key(&spec.options, "api_key_env", "CO_API_KEY")?;
67
68        let input_type = spec
69            .options
70            .get("input_type")
71            .and_then(|v| v.as_str())
72            .unwrap_or("search_document")
73            .to_string();
74
75        match spec.task {
76            ModelTask::Embed => {
77                let model = CohereEmbeddingModel {
78                    client: self.base.client.clone(),
79                    cb: cb.clone(),
80                    model_id: spec.model_id.clone(),
81                    api_key,
82                    input_type,
83                };
84                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
85                Ok(Arc::new(handle) as LoadedModelHandle)
86            }
87            ModelTask::Generate => {
88                let model = CohereGeneratorModel {
89                    client: self.base.client.clone(),
90                    cb,
91                    model_id: spec.model_id.clone(),
92                    api_key,
93                };
94                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
95                Ok(Arc::new(handle) as LoadedModelHandle)
96            }
97            ModelTask::Rerank => {
98                let model = CohereRerankerModel {
99                    client: self.base.client.clone(),
100                    cb,
101                    model_id: spec.model_id.clone(),
102                    api_key,
103                };
104                let handle: Arc<dyn RerankerModel> = Arc::new(model);
105                Ok(Arc::new(handle) as LoadedModelHandle)
106            }
107        }
108    }
109
110    async fn health(&self) -> ProviderHealth {
111        ProviderHealth::Healthy
112    }
113}
114
115struct CohereEmbeddingModel {
116    client: Client,
117    cb: crate::reliability::CircuitBreakerWrapper,
118    model_id: String,
119    api_key: String,
120    input_type: String,
121}
122
123#[async_trait]
124impl EmbeddingModel for CohereEmbeddingModel {
125    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
126        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
127
128        self.cb
129            .call(move || async move {
130                let response = self
131                    .client
132                    .post("https://api.cohere.com/v2/embed")
133                    .header("Authorization", format!("Bearer {}", self.api_key))
134                    .json(&json!({
135                        "texts": texts,
136                        "model": self.model_id,
137                        "input_type": self.input_type,
138                        "embedding_types": ["float"]
139                    }))
140                    .send()
141                    .await
142                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
143
144                let body: serde_json::Value = check_http_status("Cohere", response)?
145                    .json()
146                    .await
147                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
148
149                let float_embeddings = body
150                    .get("embeddings")
151                    .and_then(|e| e.get("float"))
152                    .and_then(|f| f.as_array())
153                    .ok_or_else(|| {
154                        RuntimeError::ApiError(
155                            "Invalid Cohere embedding response format".to_string(),
156                        )
157                    })?;
158
159                let mut result = Vec::new();
160                for embedding in float_embeddings {
161                    if let Some(values) = embedding.as_array() {
162                        let vec: Vec<f32> = values
163                            .iter()
164                            .filter_map(|v| v.as_f64().map(|f| f as f32))
165                            .collect();
166                        result.push(vec);
167                    }
168                }
169                Ok(result)
170            })
171            .await
172    }
173
174    fn dimensions(&self) -> u32 {
175        match self.model_id.as_str() {
176            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
177            _ => 1024,
178        }
179    }
180
181    fn model_id(&self) -> &str {
182        &self.model_id
183    }
184}
185
186struct CohereGeneratorModel {
187    client: Client,
188    cb: crate::reliability::CircuitBreakerWrapper,
189    model_id: String,
190    api_key: String,
191}
192
193#[async_trait]
194impl GeneratorModel for CohereGeneratorModel {
195    async fn generate(
196        &self,
197        messages: &[Message],
198        options: GenerationOptions,
199    ) -> Result<GenerationResult> {
200        let messages: Vec<serde_json::Value> = messages
201            .iter()
202            .map(|msg| {
203                let role = match msg.role {
204                    MessageRole::System => "system",
205                    MessageRole::User => "user",
206                    MessageRole::Assistant => "assistant",
207                };
208                json!({ "role": role, "content": msg.text() })
209            })
210            .collect();
211
212        self.cb
213            .call(move || async move {
214                let mut body = json!({
215                    "model": self.model_id,
216                    "messages": messages,
217                });
218
219                if let Some(max_tokens) = options.max_tokens {
220                    body["max_tokens"] = json!(max_tokens);
221                }
222                if let Some(temperature) = options.temperature {
223                    body["temperature"] = json!(temperature);
224                }
225                if let Some(top_p) = options.top_p {
226                    body["p"] = json!(top_p);
227                }
228
229                let response = self
230                    .client
231                    .post("https://api.cohere.com/v2/chat")
232                    .header("Authorization", format!("Bearer {}", self.api_key))
233                    .json(&body)
234                    .send()
235                    .await
236                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
237
238                let body: serde_json::Value = check_http_status("Cohere", response)?
239                    .json()
240                    .await
241                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
242
243                let text = body
244                    .get("message")
245                    .and_then(|m| m.get("content"))
246                    .and_then(|c| c.as_array())
247                    .and_then(|arr| arr.first())
248                    .and_then(|item| item.get("text"))
249                    .and_then(|t| t.as_str())
250                    .unwrap_or("")
251                    .to_string();
252
253                let usage = body.get("usage").map(|u| {
254                    let input = u
255                        .get("tokens")
256                        .and_then(|t| t.get("input_tokens"))
257                        .and_then(|v| v.as_u64())
258                        .unwrap_or(0);
259                    let output = u
260                        .get("tokens")
261                        .and_then(|t| t.get("output_tokens"))
262                        .and_then(|v| v.as_u64())
263                        .unwrap_or(0);
264                    TokenUsage {
265                        prompt_tokens: input as usize,
266                        completion_tokens: output as usize,
267                        total_tokens: (input + output) as usize,
268                    }
269                });
270
271                Ok(GenerationResult {
272                    text,
273                    usage,
274                    images: vec![],
275                    audio: None,
276                })
277            })
278            .await
279    }
280}
281
282struct CohereRerankerModel {
283    client: Client,
284    cb: crate::reliability::CircuitBreakerWrapper,
285    model_id: String,
286    api_key: String,
287}
288
289#[async_trait]
290impl RerankerModel for CohereRerankerModel {
291    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
292        let query = query.to_string();
293        let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
294
295        self.cb
296            .call(move || async move {
297                let response = self
298                    .client
299                    .post("https://api.cohere.com/v2/rerank")
300                    .header("Authorization", format!("Bearer {}", self.api_key))
301                    .json(&json!({
302                        "query": query,
303                        "documents": docs,
304                        "model": self.model_id,
305                    }))
306                    .send()
307                    .await
308                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
309
310                let body: serde_json::Value = check_http_status("Cohere", response)?
311                    .json()
312                    .await
313                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
314
315                let results_json =
316                    body.get("results")
317                        .and_then(|r| r.as_array())
318                        .ok_or_else(|| {
319                            RuntimeError::ApiError("Invalid rerank response format".to_string())
320                        })?;
321
322                let mut results = Vec::new();
323                for item in results_json {
324                    let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
325                    let score = item
326                        .get("relevance_score")
327                        .and_then(|s| s.as_f64())
328                        .unwrap_or(0.0) as f32;
329                    results.push(ScoredDoc {
330                        index,
331                        score,
332                        text: None,
333                    });
334                }
335                Ok(results)
336            })
337            .await
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module, all of which mutate `CO_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Holds [`ENV_LOCK`] with `CO_API_KEY` set for as long as it is alive.
    ///
    /// The variable is removed in `Drop`, so it is cleaned up even when a test
    /// assertion panics; the lock is released only after the removal.
    struct ApiKeyGuard {
        _lock: tokio::sync::MutexGuard<'static, ()>,
    }

    impl ApiKeyGuard {
        /// Waits for the lock, then sets `CO_API_KEY` to a dummy value.
        async fn acquire() -> Self {
            let lock = ENV_LOCK.lock().await;
            // SAFETY: every test in this module mutates the environment only
            // while holding `ENV_LOCK`. NOTE(review): code outside this module
            // that touches the environment concurrently is not covered.
            unsafe { std::env::set_var("CO_API_KEY", "test-key") };
            Self { _lock: lock }
        }
    }

    impl Drop for ApiKeyGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here; the `_lock` field is
            // dropped only after this body returns.
            unsafe { std::env::remove_var("CO_API_KEY") };
        }
    }

    /// Builds a minimal `remote/cohere` spec with default (null) options.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/cohere".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    /// Two aliases with the same task and model share a single breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let s2 = spec("embed/b", ModelTask::Embed, "embed-english-v3.0");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    /// Distinct task/model combinations each get their own breaker.
    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        let gen_spec = spec("chat/a", ModelTask::Generate, "command-r-plus");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 3);
    }

    /// A breaker older than the TTL is gone after cleanup plus a load.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embed-english-v3.0");
        let fresh = spec("chat/fresh", ModelTask::Generate, "command-r-plus");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    /// `load` succeeds for each of the three advertised tasks.
    #[tokio::test]
    async fn supports_all_three_tasks() {
        let _env = ApiKeyGuard::acquire().await;

        let provider = RemoteCohereProvider::new();

        let embed = spec("embed/a", ModelTask::Embed, "embed-english-v3.0");
        assert!(provider.load(&embed).await.is_ok());

        let gen_spec = spec("gen/a", ModelTask::Generate, "command-r-plus");
        assert!(provider.load(&gen_spec).await.is_ok());

        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-english-v3.0");
        assert!(provider.load(&rerank).await.is_ok());
    }
}