uni_xervo/provider/
vertexai.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{
4    RemoteProviderBase, build_google_generate_payload, check_http_status,
5};
6use crate::traits::{
7    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
8    Message, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde_json::json;
13use std::sync::Arc;
14
15fn options_map<'a>(
16    provider_id: &str,
17    options: &'a serde_json::Value,
18) -> Result<Option<&'a serde_json::Map<String, serde_json::Value>>> {
19    match options {
20        serde_json::Value::Null => Ok(None),
21        serde_json::Value::Object(map) => Ok(Some(map)),
22        _ => Err(RuntimeError::Config(format!(
23            "Options for provider '{}' must be a JSON object or null",
24            provider_id
25        ))),
26    }
27}
28
29fn option_string(
30    provider_id: &str,
31    map: Option<&serde_json::Map<String, serde_json::Value>>,
32    key: &str,
33) -> Result<Option<String>> {
34    let Some(map) = map else {
35        return Ok(None);
36    };
37    let Some(value) = map.get(key) else {
38        return Ok(None);
39    };
40    let s = value.as_str().ok_or_else(|| {
41        RuntimeError::Config(format!(
42            "Option '{}' for provider '{}' must be a string",
43            key, provider_id
44        ))
45    })?;
46    Ok(Some(s.to_string()))
47}
48
49fn option_u32(
50    provider_id: &str,
51    map: Option<&serde_json::Map<String, serde_json::Value>>,
52    key: &str,
53) -> Result<Option<u32>> {
54    let Some(map) = map else {
55        return Ok(None);
56    };
57    let Some(value) = map.get(key) else {
58        return Ok(None);
59    };
60    let n = value.as_u64().ok_or_else(|| {
61        RuntimeError::Config(format!(
62            "Option '{}' for provider '{}' must be a positive integer",
63            key, provider_id
64        ))
65    })?;
66    if n == 0 {
67        return Err(RuntimeError::Config(format!(
68            "Option '{}' for provider '{}' must be greater than 0",
69            key, provider_id
70        )));
71    }
72    let n_u32 = u32::try_from(n).map_err(|_| {
73        RuntimeError::Config(format!(
74            "Option '{}' for provider '{}' is out of range for u32",
75            key, provider_id
76        ))
77    })?;
78    Ok(Some(n_u32))
79}
80
/// Resolved and validated Vertex AI configuration extracted from a
/// [`ModelAliasSpec`]'s options and environment variables.
#[derive(Clone)]
struct VertexAiResolvedOptions {
    /// Bearer token sent in the `Authorization` header of every request.
    token: String,
    /// GCP project id, from the `project_id` option or `VERTEX_AI_PROJECT`.
    project_id: String,
    /// Region (defaults to "us-central1"); also forms the endpoint hostname.
    location: String,
    /// Publisher path segment in the endpoint URL (defaults to "google").
    publisher: String,
    /// Optional override for the embedding dimensionality reported to callers.
    embedding_dimensions: Option<u32>,
}
91
92impl VertexAiResolvedOptions {
93    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
94        let provider_id = "remote/vertexai";
95        let map = options_map(provider_id, &spec.options)?;
96
97        let token_env = option_string(provider_id, map, "api_token_env")?
98            .unwrap_or_else(|| "VERTEX_AI_TOKEN".to_string());
99        let token = std::env::var(&token_env)
100            .map_err(|_| RuntimeError::Config(format!("{} env var not set", token_env)))?;
101
102        let project_id = if let Some(project_id) = option_string(provider_id, map, "project_id")? {
103            project_id
104        } else {
105            std::env::var("VERTEX_AI_PROJECT").map_err(|_| {
106                RuntimeError::Config(
107                    "project_id option not set and VERTEX_AI_PROJECT env var not set".to_string(),
108                )
109            })?
110        };
111
112        let location =
113            option_string(provider_id, map, "location")?.unwrap_or_else(|| "us-central1".into());
114        let publisher =
115            option_string(provider_id, map, "publisher")?.unwrap_or_else(|| "google".into());
116        let embedding_dimensions = option_u32(provider_id, map, "embedding_dimensions")?;
117
118        Ok(Self {
119            token,
120            project_id,
121            location,
122            publisher,
123            embedding_dimensions,
124        })
125    }
126}
127
/// Remote provider that calls the [Google Vertex AI](https://cloud.google.com/vertex-ai/docs)
/// prediction and generation endpoints for embedding and text generation.
///
/// Requires the `VERTEX_AI_TOKEN` environment variable (or a custom env var
/// via `api_token_env`) and either the `project_id` option or the
/// `VERTEX_AI_PROJECT` env var.
pub struct RemoteVertexAIProvider {
    /// Shared HTTP client and per-model circuit-breaker bookkeeping.
    base: RemoteProviderBase,
}
137
impl RemoteVertexAIProvider {
    /// Creates a new provider; identical to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Test-only hook: seeds a circuit-breaker entry for `key` with the given
    /// age (delegates to [`RemoteProviderBase`]).
    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    /// Test-only hook: number of circuit-breaker entries currently tracked.
    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    /// Test-only hook: makes the breaker cleanup pass eligible to run
    /// immediately (delegates to [`RemoteProviderBase`]).
    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}
158
159impl Default for RemoteVertexAIProvider {
160    fn default() -> Self {
161        Self {
162            base: RemoteProviderBase::new(),
163        }
164    }
165}
166
167#[async_trait]
168impl ModelProvider for RemoteVertexAIProvider {
169    fn provider_id(&self) -> &'static str {
170        "remote/vertexai"
171    }
172
173    fn capabilities(&self) -> ProviderCapabilities {
174        ProviderCapabilities {
175            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
176        }
177    }
178
179    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
180        let cb = self.base.circuit_breaker_for(spec);
181        let resolved = VertexAiResolvedOptions::from_spec(spec)?;
182
183        match spec.task {
184            ModelTask::Embed => {
185                let model = VertexAiEmbeddingModel {
186                    client: self.base.client.clone(),
187                    cb: cb.clone(),
188                    model_id: spec.model_id.clone(),
189                    options: resolved.clone(),
190                    dimensions: resolved.embedding_dimensions.unwrap_or(768),
191                };
192                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
193                Ok(Arc::new(handle) as LoadedModelHandle)
194            }
195            ModelTask::Generate => {
196                let model = VertexAiGeneratorModel {
197                    client: self.base.client.clone(),
198                    cb,
199                    model_id: spec.model_id.clone(),
200                    options: resolved,
201                };
202                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
203                Ok(Arc::new(handle) as LoadedModelHandle)
204            }
205            _ => Err(RuntimeError::CapabilityMismatch(format!(
206                "Vertex AI provider does not support task {:?}",
207                spec.task
208            ))),
209        }
210    }
211
212    async fn health(&self) -> ProviderHealth {
213        ProviderHealth::Healthy
214    }
215}
216
/// Embedding model backed by the Vertex AI prediction API.
pub struct VertexAiEmbeddingModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker guarding remote calls for this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Vertex AI model identifier (final URL path segment).
    model_id: String,
    /// Resolved token/project/location/publisher configuration.
    options: VertexAiResolvedOptions,
    /// Dimensionality reported to callers (configured value or the 768 default).
    dimensions: u32,
}
225
226impl VertexAiEmbeddingModel {
227    fn endpoint_url(&self) -> String {
228        format!(
229            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:predict",
230            self.options.location,
231            self.options.project_id,
232            self.options.location,
233            self.options.publisher,
234            self.model_id
235        )
236    }
237}
238
239#[async_trait]
240impl EmbeddingModel for VertexAiEmbeddingModel {
241    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
242        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
243
244        self.cb
245            .call(move || async move {
246                let instances: Vec<_> = texts.iter().map(|t| json!({ "content": t })).collect();
247                let response = self
248                    .client
249                    .post(self.endpoint_url())
250                    .header("Authorization", format!("Bearer {}", self.options.token))
251                    .json(&json!({ "instances": instances }))
252                    .send()
253                    .await
254                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
255
256                let body: serde_json::Value = check_http_status("Vertex AI", response)?
257                    .json()
258                    .await
259                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
260
261                let predictions = body
262                    .get("predictions")
263                    .and_then(|v| v.as_array())
264                    .ok_or_else(|| {
265                        RuntimeError::ApiError("Invalid response: missing predictions".to_string())
266                    })?;
267
268                let mut result = Vec::new();
269                for item in predictions {
270                    let values_opt = item
271                        .get("embeddings")
272                        .and_then(|e| e.get("values").and_then(|v| v.as_array()))
273                        .or_else(|| {
274                            item.get("embeddings")
275                                .and_then(|e| e.as_array())
276                                .or_else(|| item.get("values").and_then(|v| v.as_array()))
277                        });
278
279                    let values = values_opt.ok_or_else(|| {
280                        RuntimeError::ApiError(
281                            "Invalid embedding format in Vertex AI response".to_string(),
282                        )
283                    })?;
284
285                    let vec: Vec<f32> = values
286                        .iter()
287                        .filter_map(|v| v.as_f64().map(|f| f as f32))
288                        .collect();
289                    result.push(vec);
290                }
291
292                Ok(result)
293            })
294            .await
295    }
296
297    fn dimensions(&self) -> u32 {
298        self.dimensions
299    }
300
301    fn model_id(&self) -> &str {
302        &self.model_id
303    }
304}
305
/// Text generation model backed by the Vertex AI `generateContent` endpoint.
pub struct VertexAiGeneratorModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker guarding remote calls for this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Vertex AI model identifier (final URL path segment).
    model_id: String,
    /// Resolved token/project/location/publisher configuration.
    options: VertexAiResolvedOptions,
}
313
314impl VertexAiGeneratorModel {
315    fn endpoint_url(&self) -> String {
316        format!(
317            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/{}/models/{}:generateContent",
318            self.options.location,
319            self.options.project_id,
320            self.options.location,
321            self.options.publisher,
322            self.model_id
323        )
324    }
325}
326
#[async_trait]
impl GeneratorModel for VertexAiGeneratorModel {
    /// Generates text via the Vertex AI `generateContent` endpoint, guarded
    /// by the per-model circuit breaker.
    ///
    /// Only the first text part of the first candidate is returned; a missing
    /// text part yields an empty string rather than an error.
    ///
    /// # Errors
    /// Returns `RuntimeError::ApiError` on transport failures, non-success
    /// HTTP statuses, or a response without a usable `candidates` array.
    async fn generate(
        &self,
        messages: &[Message],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        // Owned copy so the async closure can capture the messages by value.
        let messages: Vec<Message> = messages.to_vec();

        self.cb
            .call(move || async move {
                let payload = build_google_generate_payload(&messages, &options);
                let response = self
                    .client
                    .post(self.endpoint_url())
                    .header("Authorization", format!("Bearer {}", self.options.token))
                    .json(&payload)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Vertex AI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let candidates = body
                    .get("candidates")
                    .and_then(|v| v.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("No candidates returned".to_string()))?;

                let first_candidate = candidates
                    .first()
                    .ok_or_else(|| RuntimeError::ApiError("Empty candidates".to_string()))?;

                let content_parts = first_candidate
                    .get("content")
                    .and_then(|c| c.get("parts"))
                    .and_then(|p| p.as_array())
                    .ok_or_else(|| RuntimeError::ApiError("Invalid content format".to_string()))?;

                // Additional parts beyond the first are dropped — presumably
                // single-part responses are expected here; verify for models
                // that stream multiple parts.
                let text = content_parts
                    .first()
                    .and_then(|p| p.get("text"))
                    .and_then(|t| t.as_str())
                    .unwrap_or("")
                    .to_string();

                // `Value` string-indexing yields Null for absent keys, so
                // missing usage counters default to 0 instead of panicking.
                let usage = body.get("usageMetadata").map(|u| TokenUsage {
                    prompt_tokens: u["promptTokenCount"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["candidatesTokenCount"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["totalTokenCount"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult {
                    text,
                    usage,
                    images: vec![],
                    audio: None,
                })
            })
            .await
    }
}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes env-var mutation across async tests; the process environment
    // is global, so concurrent tests would otherwise race.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Builds a minimal alias spec targeting the Vertex AI provider.
    fn spec(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/vertexai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    // Two aliases resolving to the same runtime key must share one breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let s1 = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let s2 = spec(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Same model id (different aliases) -> one shared breaker entry.
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    // An entry older than the TTL is evicted on the next cleanup-triggering load.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::set_var("VERTEX_AI_PROJECT", "test-project");
        }

        let provider = RemoteVertexAIProvider::new();
        let stale = spec(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let fresh = spec(
            "embed/fresh",
            ModelTask::Embed,
            "text-embedding-004",
            serde_json::Value::Null,
        );
        // Age the stale entry past the TTL so cleanup is eligible to drop it.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        // Loading triggers the cleanup pass; only the fresh entry survives.
        let _ = provider.load(&fresh).await.unwrap();
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }
    }

    // Without a project_id option or env var, load must fail with a message
    // pointing at VERTEX_AI_PROJECT.
    #[tokio::test]
    async fn load_fails_when_project_is_missing() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::set_var("VERTEX_AI_TOKEN", "test-token");
            std::env::remove_var("VERTEX_AI_PROJECT");
        }

        let provider = RemoteVertexAIProvider::new();
        let s = spec(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-005",
            serde_json::Value::Null,
        );
        let err = provider.load(&s).await.unwrap_err();
        assert!(err.to_string().contains("VERTEX_AI_PROJECT"));

        // SAFETY: protected by ENV_LOCK
        unsafe {
            std::env::remove_var("VERTEX_AI_TOKEN");
        }
    }

    // Google's API expects "user"/"model" roles; assistant maps to "model".
    #[test]
    fn generation_payload_alternates_roles() {
        use crate::traits::Message;
        let messages = vec![
            Message::user("user question"),
            Message::assistant("assistant answer"),
            Message::user("user follow-up"),
        ];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());
        let contents = payload["contents"].as_array().unwrap();

        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[2]["role"], "user");
    }

    // Generation options map onto the camelCase generationConfig fields.
    #[test]
    fn generation_payload_includes_generation_options() {
        use crate::traits::Message;
        let messages = vec![Message::user("hello")];
        let payload = build_google_generate_payload(
            &messages,
            &GenerationOptions {
                max_tokens: Some(64),
                temperature: Some(0.7),
                top_p: Some(0.9),
                ..Default::default()
            },
        );

        assert_eq!(payload["generationConfig"]["maxOutputTokens"], 64);
        // Floats compared with a tolerance; JSON round-trips via f64.
        let temperature = payload["generationConfig"]["temperature"].as_f64().unwrap();
        let top_p = payload["generationConfig"]["topP"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 1e-6);
        assert!((top_p - 0.9).abs() < 1e-6);
    }
}