uni_xervo/provider/voyageai.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
6    RerankerModel, ScoredDoc,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
13/// Remote provider that calls the [Voyage AI API](https://docs.voyageai.com/reference/embeddings-api)
14/// for embedding and reranking. Does not support text generation.
15///
16/// Requires the `VOYAGE_API_KEY` environment variable (or a custom env var
17/// name via the `api_key_env` option).
18pub struct RemoteVoyageAIProvider {
19    base: RemoteProviderBase,
20}
21
22impl Default for RemoteVoyageAIProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
30impl RemoteVoyageAIProvider {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    #[cfg(test)]
36    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37        self.base.insert_test_breaker(key, age);
38    }
39
40    #[cfg(test)]
41    fn breaker_count(&self) -> usize {
42        self.base.breaker_count()
43    }
44
45    #[cfg(test)]
46    fn force_cleanup_now_for_test(&self) {
47        self.base.force_cleanup_now_for_test();
48    }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteVoyageAIProvider {
53    fn provider_id(&self) -> &'static str {
54        "remote/voyageai"
55    }
56
57    fn capabilities(&self) -> ProviderCapabilities {
58        ProviderCapabilities {
59            supported_tasks: vec![ModelTask::Embed, ModelTask::Rerank],
60        }
61    }
62
63    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64        let cb = self.base.circuit_breaker_for(spec);
65        let api_key = resolve_api_key(&spec.options, "api_key_env", "VOYAGE_API_KEY")?;
66
67        match spec.task {
68            ModelTask::Embed => {
69                let model = VoyageAIEmbeddingModel {
70                    client: self.base.client.clone(),
71                    cb: cb.clone(),
72                    model_id: spec.model_id.clone(),
73                    api_key,
74                };
75                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
76                Ok(Arc::new(handle) as LoadedModelHandle)
77            }
78            ModelTask::Rerank => {
79                let model = VoyageAIRerankerModel {
80                    client: self.base.client.clone(),
81                    cb,
82                    model_id: spec.model_id.clone(),
83                    api_key,
84                };
85                let handle: Arc<dyn RerankerModel> = Arc::new(model);
86                Ok(Arc::new(handle) as LoadedModelHandle)
87            }
88            _ => Err(RuntimeError::CapabilityMismatch(format!(
89                "Voyage AI provider does not support task {:?}",
90                spec.task
91            ))),
92        }
93    }
94
95    async fn health(&self) -> ProviderHealth {
96        ProviderHealth::Healthy
97    }
98}
99
100struct VoyageAIEmbeddingModel {
101    client: Client,
102    cb: crate::reliability::CircuitBreakerWrapper,
103    model_id: String,
104    api_key: String,
105}
106
107#[async_trait]
108impl EmbeddingModel for VoyageAIEmbeddingModel {
109    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
110        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
111
112        self.cb
113            .call(move || async move {
114                let response = self
115                    .client
116                    .post("https://api.voyageai.com/v1/embeddings")
117                    .header("Authorization", format!("Bearer {}", self.api_key))
118                    .json(&json!({
119                        "input": texts,
120                        "model": self.model_id
121                    }))
122                    .send()
123                    .await
124                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
125
126                let body: serde_json::Value = check_http_status("Voyage AI", response)?
127                    .json()
128                    .await
129                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
130
131                let mut embeddings = Vec::new();
132                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
133                    for item in data {
134                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
135                            let vec: Vec<f32> = embedding
136                                .iter()
137                                .filter_map(|v| v.as_f64().map(|f| f as f32))
138                                .collect();
139                            embeddings.push(vec);
140                        }
141                    }
142                }
143                Ok(embeddings)
144            })
145            .await
146    }
147
148    fn dimensions(&self) -> u32 {
149        match self.model_id.as_str() {
150            "voyage-large-2" => 1536,
151            _ => 1024,
152        }
153    }
154
155    fn model_id(&self) -> &str {
156        &self.model_id
157    }
158}
159
160struct VoyageAIRerankerModel {
161    client: Client,
162    cb: crate::reliability::CircuitBreakerWrapper,
163    model_id: String,
164    api_key: String,
165}
166
167#[async_trait]
168impl RerankerModel for VoyageAIRerankerModel {
169    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
170        let query = query.to_string();
171        let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
172
173        self.cb
174            .call(move || async move {
175                let response = self
176                    .client
177                    .post("https://api.voyageai.com/v1/reranking")
178                    .header("Authorization", format!("Bearer {}", self.api_key))
179                    .json(&json!({
180                        "query": query,
181                        "documents": docs,
182                        "model": self.model_id,
183                    }))
184                    .send()
185                    .await
186                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
187
188                let body: serde_json::Value = check_http_status("Voyage AI", response)?
189                    .json()
190                    .await
191                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
192
193                let data = body.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
194                    RuntimeError::ApiError("Invalid rerank response format".to_string())
195                })?;
196
197                let mut results = Vec::new();
198                for item in data {
199                    let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
200                    let score = item
201                        .get("relevance_score")
202                        .and_then(|s| s.as_f64())
203                        .unwrap_or(0.0) as f32;
204                    results.push(ScoredDoc {
205                        index,
206                        score,
207                        text: None,
208                    });
209                }
210                Ok(results)
211            })
212            .await
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module that mutate `VOYAGE_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Sets `VOYAGE_API_KEY` for the guard's lifetime and removes it on drop.
    ///
    /// Removal also happens when an assertion panics, so a failing test
    /// cannot leak the variable into later tests. Declare it after the
    /// `ENV_LOCK` guard: locals drop in reverse order, so the variable is
    /// removed while the lock is still held.
    struct ApiKeyGuard;

    impl ApiKeyGuard {
        fn set(value: &str) -> Self {
            // SAFETY: callers hold ENV_LOCK, so tests in this module never
            // mutate the environment concurrently. NOTE(review): this does not
            // exclude other tests in the crate reading the environment at the
            // same time.
            unsafe { std::env::set_var("VOYAGE_API_KEY", value) };
            Self
        }
    }

    impl Drop for ApiKeyGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as `ApiKeyGuard::set`; the guard drops
            // before the ENV_LOCK guard it was declared after.
            unsafe { std::env::remove_var("VOYAGE_API_KEY") };
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/voyageai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "voyage-3");
        let s2 = spec("embed/b", ModelTask::Embed, "voyage-3");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "voyage-3");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-2");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "voyage-3");
        let fresh = spec("rerank/fresh", ModelTask::Rerank, "rerank-2");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn generate_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let s = spec("gen/a", ModelTask::Generate, "voyage-3");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }
}