1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
6 RerankerModel, ScoredDoc,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
13pub struct RemoteVoyageAIProvider {
19 base: RemoteProviderBase,
20}
21
22impl Default for RemoteVoyageAIProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
30impl RemoteVoyageAIProvider {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[cfg(test)]
36 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37 self.base.insert_test_breaker(key, age);
38 }
39
40 #[cfg(test)]
41 fn breaker_count(&self) -> usize {
42 self.base.breaker_count()
43 }
44
45 #[cfg(test)]
46 fn force_cleanup_now_for_test(&self) {
47 self.base.force_cleanup_now_for_test();
48 }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteVoyageAIProvider {
53 fn provider_id(&self) -> &'static str {
54 "remote/voyageai"
55 }
56
57 fn capabilities(&self) -> ProviderCapabilities {
58 ProviderCapabilities {
59 supported_tasks: vec![ModelTask::Embed, ModelTask::Rerank],
60 }
61 }
62
63 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64 let cb = self.base.circuit_breaker_for(spec);
65 let api_key = resolve_api_key(&spec.options, "api_key_env", "VOYAGE_API_KEY")?;
66
67 match spec.task {
68 ModelTask::Embed => {
69 let model = VoyageAIEmbeddingModel {
70 client: self.base.client.clone(),
71 cb: cb.clone(),
72 model_id: spec.model_id.clone(),
73 api_key,
74 };
75 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
76 Ok(Arc::new(handle) as LoadedModelHandle)
77 }
78 ModelTask::Rerank => {
79 let model = VoyageAIRerankerModel {
80 client: self.base.client.clone(),
81 cb,
82 model_id: spec.model_id.clone(),
83 api_key,
84 };
85 let handle: Arc<dyn RerankerModel> = Arc::new(model);
86 Ok(Arc::new(handle) as LoadedModelHandle)
87 }
88 _ => Err(RuntimeError::CapabilityMismatch(format!(
89 "Voyage AI provider does not support task {:?}",
90 spec.task
91 ))),
92 }
93 }
94
95 async fn health(&self) -> ProviderHealth {
96 ProviderHealth::Healthy
97 }
98}
99
100struct VoyageAIEmbeddingModel {
101 client: Client,
102 cb: crate::reliability::CircuitBreakerWrapper,
103 model_id: String,
104 api_key: String,
105}
106
107#[async_trait]
108impl EmbeddingModel for VoyageAIEmbeddingModel {
109 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
110 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
111
112 self.cb
113 .call(move || async move {
114 let response = self
115 .client
116 .post("https://api.voyageai.com/v1/embeddings")
117 .header("Authorization", format!("Bearer {}", self.api_key))
118 .json(&json!({
119 "input": texts,
120 "model": self.model_id
121 }))
122 .send()
123 .await
124 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
125
126 let body: serde_json::Value = check_http_status("Voyage AI", response)?
127 .json()
128 .await
129 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
130
131 let mut embeddings = Vec::new();
132 if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
133 for item in data {
134 if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
135 let vec: Vec<f32> = embedding
136 .iter()
137 .filter_map(|v| v.as_f64().map(|f| f as f32))
138 .collect();
139 embeddings.push(vec);
140 }
141 }
142 }
143 Ok(embeddings)
144 })
145 .await
146 }
147
148 fn dimensions(&self) -> u32 {
149 match self.model_id.as_str() {
150 "voyage-large-2" => 1536,
151 _ => 1024,
152 }
153 }
154
155 fn model_id(&self) -> &str {
156 &self.model_id
157 }
158}
159
160struct VoyageAIRerankerModel {
161 client: Client,
162 cb: crate::reliability::CircuitBreakerWrapper,
163 model_id: String,
164 api_key: String,
165}
166
167#[async_trait]
168impl RerankerModel for VoyageAIRerankerModel {
169 async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>> {
170 let query = query.to_string();
171 let docs: Vec<String> = docs.iter().map(|s| s.to_string()).collect();
172
173 self.cb
174 .call(move || async move {
175 let response = self
176 .client
177 .post("https://api.voyageai.com/v1/reranking")
178 .header("Authorization", format!("Bearer {}", self.api_key))
179 .json(&json!({
180 "query": query,
181 "documents": docs,
182 "model": self.model_id,
183 }))
184 .send()
185 .await
186 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
187
188 let body: serde_json::Value = check_http_status("Voyage AI", response)?
189 .json()
190 .await
191 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
192
193 let data = body.get("data").and_then(|d| d.as_array()).ok_or_else(|| {
194 RuntimeError::ApiError("Invalid rerank response format".to_string())
195 })?;
196
197 let mut results = Vec::new();
198 for item in data {
199 let index = item.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
200 let score = item
201 .get("relevance_score")
202 .and_then(|s| s.as_f64())
203 .unwrap_or(0.0) as f32;
204 results.push(ScoredDoc {
205 index,
206 score,
207 text: None,
208 });
209 }
210 Ok(results)
211 })
212 .await
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::ffi::OsString;
    use std::time::Duration;

    /// Serialises the tests in this module that mutate `VOYAGE_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    const API_KEY_ENV: &str = "VOYAGE_API_KEY";

    /// Sets `VOYAGE_API_KEY` for its lifetime and restores the previous value (or unsets it) on
    /// drop. A failing assertion therefore cannot leak the test key, and a developer's real
    /// key is not clobbered.
    ///
    /// Bind it *after* the `ENV_LOCK` guard so it is dropped — and the variable restored —
    /// while the lock is still held.
    struct ApiKeyGuard {
        previous: Option<OsString>,
    }

    impl ApiKeyGuard {
        fn set(value: &str) -> Self {
            let previous = std::env::var_os(API_KEY_ENV);
            // SAFETY: callers hold ENV_LOCK, which serialises environment mutation among the
            // tests in this module. Other test modules touching the environment are not covered.
            unsafe { std::env::set_var(API_KEY_ENV, value) };
            Self { previous }
        }
    }

    impl Drop for ApiKeyGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as `ApiKeyGuard::set`. The guard is bound after the
            // ENV_LOCK guard, so it is dropped while the lock is still held.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(API_KEY_ENV, value),
                    None => std::env::remove_var(API_KEY_ENV),
                }
            }
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/voyageai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "voyage-3");
        let s2 = spec("embed/b", ModelTask::Embed, "voyage-3");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "voyage-3");
        let rerank = spec("rerank/a", ModelTask::Rerank, "rerank-2");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&rerank).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "voyage-3");
        let fresh = spec("rerank/fresh", ModelTask::Rerank, "rerank-2");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn generate_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyGuard::set("test-key");

        let provider = RemoteVoyageAIProvider::new();
        let s = spec("gen/a", ModelTask::Generate, "voyage-3");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }
}