use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{
    RemoteProviderBase, build_google_generate_payload, check_http_status,
};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

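/// Interprets an alias's `options` value as a JSON object.
/// `null` means "no options set"; any other non-object value is a config error.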
fn options_map<'a>(
    provider_id: &str,
    options: &'a serde_json::Value,
) -> Result<Option<&'a serde_json::Map<String, serde_json::Value>>> {
    match options {
        serde_json::Value::Null => Ok(None),
        serde_json::Value::Object(map) => Ok(Some(map)),
        _ => Err(RuntimeError::Config(format!(
            "Options for provider '{}' must be a JSON object or null",
            provider_id
        ))),
    }
}

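/// Reads an optional string option from the parsed options map, erroring if
/// the key is present but its value is not a string.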
fn option_string(
    provider_id: &str,
    map: Option<&serde_json::Map<String, serde_json::Value>>,
    key: &str,
) -> Result<Option<String>> {
    let Some(map) = map else {
        return Ok(None);
    };
    let Some(value) = map.get(key) else {
        return Ok(None);
    };
    let s = value.as_str().ok_or_else(|| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a string",
            key, provider_id
        ))
    })?;
    Ok(Some(s.to_string()))
}

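/// Reads an optional positive integer option from the parsed options map,
/// rejecting zero and values that do not fit in a `u32`.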
fn option_u32(
    provider_id: &str,
    map: Option<&serde_json::Map<String, serde_json::Value>>,
    key: &str,
) -> Result<Option<u32>> {
    let Some(map) = map else {
        return Ok(None);
    };
    let Some(value) = map.get(key) else {
        return Ok(None);
    };
    let n = value.as_u64().ok_or_else(|| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a positive integer",
            key, provider_id
        ))
    })?;
    if n == 0 {
        return Err(RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be greater than 0",
            key, provider_id
        )));
    }
    let n_u32 = u32::try_from(n).map_err(|_| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' is out of range for u32",
            key, provider_id
        ))
    })?;
    Ok(Some(n_u32))
}

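/// Connection settings for a Vertex AI alias, resolved from the alias options
/// and environment variables.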
#[derive(Clone)]
struct VertexAiResolvedOptions {
    token: String,
    project_id: String,
    location: String,
    publisher: String,
    embedding_dimensions: Option<u32>,
}

impl VertexAiResolvedOptions {
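    /// Resolves options from the alias spec. The bearer token is always read
    /// from the environment (the variable named by `api_token_env`, default
    /// `VERTEX_AI_TOKEN`); `project_id` falls back to `VERTEX_AI_PROJECT`.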
    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
        let provider_id = "remote/vertexai";
        let map = options_map(provider_id, &spec.options)?;

        let token_env = option_string(provider_id, map, "api_token_env")?
            .unwrap_or_else(|| "VERTEX_AI_TOKEN".to_string());
        let token = std::env::var(&token_env)
            .map_err(|_| RuntimeError::Config(format!("{} env var not set", token_env)))?;

        let project_id = if let Some(project_id) = option_string(provider_id, map, "project_id")? {
            project_id
        } else {
            std::env::var("VERTEX_AI_PROJECT").map_err(|_| {
                RuntimeError::Config(
                    "project_id option not set and VERTEX_AI_PROJECT env var not set".to_string(),
                )
            })?
        };

        let location =
            option_string(provider_id, map, "location")?.unwrap_or_else(|| "us-central1".into());
        let publisher =
            option_string(provider_id, map, "publisher")?.unwrap_or_else(|| "google".into());
        let embedding_dimensions = option_u32(provider_id, map, "embedding_dimensions")?;

        Ok(Self {
            token,
            project_id,
            location,
            publisher,
            embedding_dimensions,
        })
    }
}

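/// Model provider backed by the Vertex AI REST API. Supports embedding via
/// the `:predict` endpoint and generation via `:generateContent`, with one
/// circuit breaker shared per model runtime key.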
pub struct RemoteVertexAIProvider {
    base: RemoteProviderBase,
}

impl RemoteVertexAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

impl Default for RemoteVertexAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

#[async_trait]
impl ModelProvider for RemoteVertexAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/vertexai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
        // One circuit breaker per runtime key, so aliases that resolve to the
        // same remote model share failure state.
        let cb = self.base.circuit_breaker_for(spec);
        let resolved = VertexAiResolvedOptions::from_spec(spec)?;

        match spec.task {
            ModelTask::Embed => {
                let model = VertexAiEmbeddingModel {
                    client: self.base.client.clone(),
                    cb: cb.clone(),
                    model_id: spec.model_id.clone(),
                    options: resolved.clone(),
                    dimensions: resolved.embedding_dimensions.unwrap_or(768),
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = VertexAiGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    options: resolved,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "Vertex AI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

    async fn health(&self) -> ProviderHealth {
        // No remote probe is made here; failures surface through the
        // per-call circuit breakers instead.
        ProviderHealth::Healthy
    }
}

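/// Embedding model that calls the Vertex AI `:predict` endpoint through a
/// shared circuit breaker.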
pub struct VertexAiEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    options: VertexAiResolvedOptions,
    dimensions: u32,
}

impl VertexAiEmbeddingModel {
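    /// Builds the regional `:predict` URL for this model.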
    fn endpoint_url(&self) -> String {
        format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:predict",
            self.options.location,
            self.options.project_id,
            self.options.location,
            self.options.publisher,
            self.model_id
        )
    }
}

#[async_trait]
impl EmbeddingModel for VertexAiEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let instances: Vec<_> = texts.iter().map(|t| json!({ "content": t })).collect();
                let response = self
                    .client
                    .post(self.endpoint_url())
                    .header("Authorization", format!("Bearer {}", self.options.token))
                    .json(&json!({ "instances": instances }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Vertex AI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let predictions = body
                    .get("predictions")
                    .and_then(|v| v.as_array())
                    .ok_or_else(|| {
                        RuntimeError::ApiError("Invalid response: missing predictions".to_string())
                    })?;

                let mut result = Vec::new();
                for item in predictions {
                    // Embedding models nest the vector differently across
                    // response shapes: `embeddings.values`, a bare `embeddings`
                    // array, or a top-level `values` array. Try each in turn.
                    let values_opt = item
                        .get("embeddings")
                        .and_then(|e| e.get("values").and_then(|v| v.as_array()))
                        .or_else(|| {
                            item.get("embeddings")
                                .and_then(|e| e.as_array())
                                .or_else(|| item.get("values").and_then(|v| v.as_array()))
                        });

                    let values = values_opt.ok_or_else(|| {
                        RuntimeError::ApiError(
                            "Invalid embedding format in Vertex AI response".to_string(),
                        )
                    })?;

                    let vec: Vec<f32> = values
                        .iter()
                        .filter_map(|v| v.as_f64().map(|f| f as f32))
                        .collect();
                    result.push(vec);
                }

                Ok(result)
            })
            .await
    }

    fn dimensions(&self) -> u32 {
        self.dimensions
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}

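/// Generator model that calls the Vertex AI `:generateContent` endpoint
/// through a shared circuit breaker.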
pub struct VertexAiGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    options: VertexAiResolvedOptions,
}

impl VertexAiGeneratorModel {
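    /// Builds the regional `:generateContent` URL for this model.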
    fn endpoint_url(&self) -> String {
        format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:generateContent",
            self.options.location,
            self.options.project_id,
            self.options.location,
            self.options.publisher,
            self.model_id
        )
    }
}

#[async_trait]
impl GeneratorModel for VertexAiGeneratorModel {
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let messages = messages.to_vec();

        self.cb
            .call(move || async move {
                let payload = build_google_generate_payload(&messages, &options);
                let response = self
                    .client
                    .post(self.endpoint_url())
                    .header("Authorization", format!("Bearer {}", self.options.token))
                    .json(&payload)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Vertex AI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let candidates = body
                    .get("candidates")
                    .and_then(|v| v.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("No candidates returned".to_string()))?;

                let first_candidate = candidates
                    .first()
                    .ok_or_else(|| RuntimeError::ApiError("Empty candidates".to_string()))?;

                let content_parts = first_candidate
                    .get("content")
                    .and_then(|c| c.get("parts"))
                    .and_then(|p| p.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("Invalid content format".to_string()))?;

                // Only the first text part of the first candidate is returned.
                let text = content_parts
                    .first()
                    .and_then(|p| p.get("text"))
                    .and_then(|t| t.as_str())
                    .unwrap_or("")
                    .to_string();

                // usage is None when usageMetadata is absent; individual
                // missing counts default to 0.
                let usage = body.get("usageMetadata").map(|u| TokenUsage {
                    prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult { text, usage })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

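    // Serializes tests that mutate process-wide environment variables so
    // they cannot observe each other's state.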
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

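    /// Builds a minimal `ModelAliasSpec` for the Vertex AI provider.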
    fn spec(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/vertexai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let s1 = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let s2 = spec(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let stale = spec(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let fresh = spec(
            "embed/fresh",
            ModelTask::Embed,
            "text-embedding-004",
            serde_json::Value::Null,
        );
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();
        assert_eq!(provider.breaker_count(), 1);

        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    #[tokio::test]
    async fn load_fails_when_project_is_missing() {
        let _lock = ENV_LOCK.lock().await;
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }

        let provider = RemoteVertexAIProvider::new();
        let s = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let err = provider.load(&s).await.unwrap_err();
        assert!(err.to_string().contains("VERTEX_AI_PROJECT"));

        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
        }
    }

    #[test]
    fn generation_payload_alternates_roles() {
        let messages = vec![
            "user question".to_string(),
            "assistant answer".to_string(),
            "user follow-up".to_string(),
        ];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());
        let contents = payload["contents"].as_array().unwrap();

        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[2]["role"], "user");
    }

    #[test]
    fn generation_payload_includes_generation_options() {
        let messages = vec!["hello".to_string()];
        let payload = build_google_generate_payload(
            &messages,
            &GenerationOptions {
                max_tokens: Some(64),
                temperature: Some(0.7),
                top_p: Some(0.9),
            },
        );

        assert_eq!(payload["generationConfig"]["maxOutputTokens"], 64);
        let temperature = payload["generationConfig"]["temperature"].as_f64().unwrap();
        let top_p = payload["generationConfig"]["topP"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 1e-6);
        assert!((top_p - 0.9).abs() < 1e-6);
    }
}