uni_xervo/provider/vertexai.rs

use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{
    RemoteProviderBase, build_google_generate_payload, check_http_status,
};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

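/// Interprets a spec's `options` value as an optional JSON object, rejecting
/// anything that is neither an object nor `null`.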
fn options_map<'a>(
    provider_id: &str,
    options: &'a serde_json::Value,
) -> Result<Option<&'a serde_json::Map<String, serde_json::Value>>> {
    match options {
        serde_json::Value::Null => Ok(None),
        serde_json::Value::Object(map) => Ok(Some(map)),
        _ => Err(RuntimeError::Config(format!(
            "Options for provider '{}' must be a JSON object or null",
            provider_id
        ))),
    }
}

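/// Reads an optional string option, erroring if the key is present but holds
/// a non-string value.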
fn option_string(
    provider_id: &str,
    map: Option<&serde_json::Map<String, serde_json::Value>>,
    key: &str,
) -> Result<Option<String>> {
    let Some(map) = map else {
        return Ok(None);
    };
    let Some(value) = map.get(key) else {
        return Ok(None);
    };
    let s = value.as_str().ok_or_else(|| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a string",
            key, provider_id
        ))
    })?;
    Ok(Some(s.to_string()))
}

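/// Reads an optional positive-integer option, erroring on zero, non-integer
/// values, and values that do not fit in a `u32`.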
fn option_u32(
    provider_id: &str,
    map: Option<&serde_json::Map<String, serde_json::Value>>,
    key: &str,
) -> Result<Option<u32>> {
    let Some(map) = map else {
        return Ok(None);
    };
    let Some(value) = map.get(key) else {
        return Ok(None);
    };
    let n = value.as_u64().ok_or_else(|| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be a positive integer",
            key, provider_id
        ))
    })?;
    if n == 0 {
        return Err(RuntimeError::Config(format!(
            "Option '{}' for provider '{}' must be greater than 0",
            key, provider_id
        )));
    }
    let n_u32 = u32::try_from(n).map_err(|_| {
        RuntimeError::Config(format!(
            "Option '{}' for provider '{}' is out of range for u32",
            key, provider_id
        ))
    })?;
    Ok(Some(n_u32))
}

/// Resolved and validated Vertex AI configuration extracted from a
/// [`ModelAliasSpec`]'s options and environment variables.
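///
/// A typical options object looks like this (a sketch: the keys match the
/// parsing in `from_spec` below, the values are illustrative):
///
/// ```json
/// {
///     "api_token_env": "VERTEX_AI_TOKEN",
///     "project_id": "my-gcp-project",
///     "location": "us-central1",
///     "publisher": "google",
///     "embedding_dimensions": 768
/// }
/// ```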
#[derive(Clone)]
struct VertexAiResolvedOptions {
    token: String,
    project_id: String,
    location: String,
    publisher: String,
    embedding_dimensions: Option<u32>,
}

impl VertexAiResolvedOptions {
    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
        let provider_id = "remote/vertexai";
        let map = options_map(provider_id, &spec.options)?;

        let token_env = option_string(provider_id, map, "api_token_env")?
            .unwrap_or_else(|| "VERTEX_AI_TOKEN".to_string());
        let token = std::env::var(&token_env)
            .map_err(|_| RuntimeError::Config(format!("{} env var not set", token_env)))?;

        let project_id = if let Some(project_id) = option_string(provider_id, map, "project_id")? {
            project_id
        } else {
            std::env::var("VERTEX_AI_PROJECT").map_err(|_| {
                RuntimeError::Config(
                    "project_id option not set and VERTEX_AI_PROJECT env var not set".to_string(),
                )
            })?
        };

        let location =
            option_string(provider_id, map, "location")?.unwrap_or_else(|| "us-central1".into());
        let publisher =
            option_string(provider_id, map, "publisher")?.unwrap_or_else(|| "google".into());
        let embedding_dimensions = option_u32(provider_id, map, "embedding_dimensions")?;

        Ok(Self {
            token,
            project_id,
            location,
            publisher,
            embedding_dimensions,
        })
    }
}

/// Remote provider that calls the [Google Vertex AI](https://cloud.google.com/vertex-ai/docs)
/// prediction and generation endpoints for embedding and text generation.
///
/// Requires the `VERTEX_AI_TOKEN` environment variable (or a custom env var
/// via `api_token_env`) and either the `project_id` option or the
/// `VERTEX_AI_PROJECT` env var.
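///
/// A minimal usage sketch (assuming a `ModelAliasSpec` named `spec` whose
/// `provider_id` is `"remote/vertexai"` and whose task is supported):
///
/// ```rust,ignore
/// let provider = RemoteVertexAIProvider::new();
/// let handle = provider.load(&spec).await?;
/// ```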
pub struct RemoteVertexAIProvider {
    base: RemoteProviderBase,
}

impl RemoteVertexAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

impl Default for RemoteVertexAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

#[async_trait]
impl ModelProvider for RemoteVertexAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/vertexai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
        let cb = self.base.circuit_breaker_for(spec);
        let resolved = VertexAiResolvedOptions::from_spec(spec)?;

        match spec.task {
            ModelTask::Embed => {
                let model = VertexAiEmbeddingModel {
                    client: self.base.client.clone(),
                    cb: cb.clone(),
                    model_id: spec.model_id.clone(),
                    options: resolved.clone(),
                    dimensions: resolved.embedding_dimensions.unwrap_or(768),
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = VertexAiGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    options: resolved,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "Vertex AI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }
}

/// Embedding model backed by the Vertex AI prediction API.
pub struct VertexAiEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    options: VertexAiResolvedOptions,
    dimensions: u32,
}

impl VertexAiEmbeddingModel {
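    /// Builds the regional `:predict` URL. With the default options it looks
    /// like this (project and model values are hypothetical, for illustration):
    ///
    /// `https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/publishers/google/models/text-embedding-005:predict`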
    fn endpoint_url(&self) -> String {
        format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:predict",
            self.options.location,
            self.options.project_id,
            self.options.location,
            self.options.publisher,
            self.model_id
        )
    }
}

#[async_trait]
impl EmbeddingModel for VertexAiEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let instances: Vec<_> = texts.iter().map(|t| json!({ "content": t })).collect();
                let response = self
                    .client
                    .post(self.endpoint_url())
                    .header("Authorization", format!("Bearer {}", self.options.token))
                    .json(&json!({ "instances": instances }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Vertex AI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

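                // Expected response shape, a sketch inferred from the parsing
                // below (the fallbacks cover variants that return the values
                // array directly):
                //
                //   { "predictions": [ { "embeddings": { "values": [0.1, ...] } } ] }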
                let predictions = body
                    .get("predictions")
                    .and_then(|v| v.as_array())
                    .ok_or_else(|| {
                        RuntimeError::ApiError("Invalid response: missing predictions".to_string())
                    })?;

                let mut result = Vec::new();
                for item in predictions {
                    let values_opt = item
                        .get("embeddings")
                        .and_then(|e| e.get("values").and_then(|v| v.as_array()))
                        .or_else(|| {
                            item.get("embeddings")
                                .and_then(|e| e.as_array())
                                .or_else(|| item.get("values").and_then(|v| v.as_array()))
                        });

                    let values = values_opt.ok_or_else(|| {
                        RuntimeError::ApiError(
                            "Invalid embedding format in Vertex AI response".to_string(),
                        )
                    })?;

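                    // Non-numeric entries are silently skipped, so a malformed
                    // value shortens the vector rather than raising an error.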
                    let vec: Vec<f32> = values
                        .iter()
                        .filter_map(|v| v.as_f64().map(|f| f as f32))
                        .collect();
                    result.push(vec);
                }

                Ok(result)
            })
            .await
    }

    fn dimensions(&self) -> u32 {
        self.dimensions
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}

/// Text generation model backed by the Vertex AI `generateContent` endpoint.
pub struct VertexAiGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    options: VertexAiResolvedOptions,
}

impl VertexAiGeneratorModel {
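    /// Builds the regional `:generateContent` URL; the same layout as the
    /// embedding endpoint, ending in `:generateContent` instead of `:predict`.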
    fn endpoint_url(&self) -> String {
        format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:generateContent",
            self.options.location,
            self.options.project_id,
            self.options.location,
            self.options.publisher,
            self.model_id
        )
    }
}

#[async_trait]
impl GeneratorModel for VertexAiGeneratorModel {
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let messages: Vec<String> = messages.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let payload = build_google_generate_payload(&messages, &options);
                let response = self
                    .client
                    .post(self.endpoint_url())
                    .header("Authorization", format!("Bearer {}", self.options.token))
                    .json(&payload)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Vertex AI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

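                // Expected response shape, a sketch matching the parsing below:
                //
                //   {
                //     "candidates": [
                //       { "content": { "parts": [ { "text": "..." } ] } }
                //     ],
                //     "usageMetadata": {
                //       "promptTokenCount": 1, "candidatesTokenCount": 2,
                //       "totalTokenCount": 3
                //     }
                //   }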
                let candidates = body
                    .get("candidates")
                    .and_then(|v| v.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("No candidates returned".to_string()))?;

                let first_candidate = candidates
                    .first()
                    .ok_or_else(|| RuntimeError::ApiError("Empty candidates".to_string()))?;

                let content_parts = first_candidate
                    .get("content")
                    .and_then(|c| c.get("parts"))
                    .and_then(|p| p.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("Invalid content format".to_string()))?;

                let text = content_parts
                    .first()
                    .and_then(|p| p.get("text"))
                    .and_then(|t| t.as_str())
                    .unwrap_or("")
                    .to_string();

                let usage = body.get("usageMetadata").map(|u| TokenUsage {
                    prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult { text, usage })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    fn spec(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/vertexai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let s1 = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let s2 = spec(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

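        // Two aliases that resolve to the same runtime key share one breaker.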
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let stale = spec(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let fresh = spec(
            "embed/fresh",
            ModelTask::Embed,
            "text-embedding-004",
            serde_json::Value::Null,
        );
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    #[tokio::test]
    async fn load_fails_when_project_is_missing() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }

        let provider = RemoteVertexAIProvider::new();
        let s = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let err = provider.load(&s).await.unwrap_err();
        assert!(err.to_string().contains("VERTEX_AI_PROJECT"));

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
        }
    }

    #[test]
    fn generation_payload_alternates_roles() {
        let messages = vec![
            "user question".to_string(),
            "assistant answer".to_string(),
            "user follow-up".to_string(),
        ];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());
        let contents = payload["contents"].as_array().unwrap();

        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[2]["role"], "user");
    }

    #[test]
    fn generation_payload_includes_generation_options() {
        let messages = vec!["hello".to_string()];
        let payload = build_google_generate_payload(
            &messages,
            &GenerationOptions {
                max_tokens: Some(64),
                temperature: Some(0.7),
                top_p: Some(0.9),
            },
        );

        assert_eq!(payload["generationConfig"]["maxOutputTokens"], 64);
        let temperature = payload["generationConfig"]["temperature"].as_f64().unwrap();
        let top_p = payload["generationConfig"]["topP"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 1e-6);
        assert!((top_p - 0.9).abs() < 1e-6);
    }
}