uni_xervo/provider/gemini.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{
4    RemoteProviderBase, build_google_generate_payload, check_http_status, resolve_api_key,
5};
6use crate::traits::{
7    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
8    Message, ModelProvider, ProviderCapabilities, ProviderHealth,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde_json::json;
13use std::sync::Arc;
14
15/// Remote provider that calls the [Google Gemini API](https://ai.google.dev/api)
16/// for embedding (`batchEmbedContents`) and text generation (`generateContent`).
17///
18/// Requires the `GEMINI_API_KEY` environment variable (or a custom env var name
19/// via the `api_key_env` option).
20pub struct RemoteGeminiProvider {
21    base: RemoteProviderBase,
22}
23
24impl RemoteGeminiProvider {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    #[cfg(test)]
30    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
31        self.base.insert_test_breaker(key, age);
32    }
33
34    #[cfg(test)]
35    fn breaker_count(&self) -> usize {
36        self.base.breaker_count()
37    }
38
39    #[cfg(test)]
40    fn force_cleanup_now_for_test(&self) {
41        self.base.force_cleanup_now_for_test();
42    }
43}
44
45impl Default for RemoteGeminiProvider {
46    fn default() -> Self {
47        Self {
48            base: RemoteProviderBase::new(),
49        }
50    }
51}
52
53#[async_trait]
54impl ModelProvider for RemoteGeminiProvider {
55    fn provider_id(&self) -> &'static str {
56        "remote/gemini"
57    }
58
59    fn capabilities(&self) -> ProviderCapabilities {
60        ProviderCapabilities {
61            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
62        }
63    }
64
65    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
66        let cb = self.base.circuit_breaker_for(spec);
67        let api_key = resolve_api_key(&spec.options, "api_key_env", "GEMINI_API_KEY")?;
68
69        match spec.task {
70            ModelTask::Embed => {
71                let model = GeminiEmbeddingModel {
72                    client: self.base.client.clone(),
73                    cb: cb.clone(),
74                    model_id: spec.model_id.clone(),
75                    api_key,
76                };
77                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
78                Ok(Arc::new(handle) as LoadedModelHandle)
79            }
80            ModelTask::Generate => {
81                let model = GeminiGeneratorModel {
82                    client: self.base.client.clone(),
83                    cb,
84                    model_id: spec.model_id.clone(),
85                    api_key,
86                };
87                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
88                Ok(Arc::new(handle) as LoadedModelHandle)
89            }
90            _ => Err(RuntimeError::CapabilityMismatch(format!(
91                "Gemini provider does not support task {:?}",
92                spec.task
93            ))),
94        }
95    }
96
97    async fn health(&self) -> ProviderHealth {
98        ProviderHealth::Healthy
99    }
100}
101
102/// Embedding model backed by the Gemini batch embedding API.
103pub struct GeminiEmbeddingModel {
104    client: Client,
105    cb: crate::reliability::CircuitBreakerWrapper,
106    model_id: String,
107    api_key: String,
108}
109
110#[async_trait]
111impl EmbeddingModel for GeminiEmbeddingModel {
112    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
113        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
114
115        self.cb
116            .call(move || async move {
117                let url = format!(
118                    "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
119                    self.model_id, self.api_key
120                );
121
122                let requests: Vec<_> = texts
123                    .iter()
124                    .map(|t| {
125                        json!({
126                            "model": format!("models/{}", self.model_id),
127                            "content": { "parts": [{ "text": t }] }
128                        })
129                    })
130                    .collect();
131
132                let response = self
133                    .client
134                    .post(&url)
135                    .json(&json!({ "requests": requests }))
136                    .send()
137                    .await
138                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
139
140                let body: serde_json::Value = check_http_status("Gemini", response)?
141                    .json()
142                    .await
143                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
144
145                let embeddings_json = body
146                    .get("embeddings")
147                    .and_then(|v| v.as_array())
148                    .ok_or_else(|| {
149                        RuntimeError::ApiError("Invalid response format".to_string())
150                    })?;
151
152                let mut result = Vec::new();
153                for item in embeddings_json {
154                    let values = item
155                        .get("values")
156                        .and_then(|v| v.as_array())
157                        .ok_or_else(|| {
158                            RuntimeError::ApiError("Missing values in embedding".to_string())
159                        })?;
160
161                    let vec: Vec<f32> = values
162                        .iter()
163                        .filter_map(|v| v.as_f64().map(|f| f as f32))
164                        .collect();
165                    result.push(vec);
166                }
167                Ok(result)
168            })
169            .await
170    }
171
172    fn dimensions(&self) -> u32 {
173        // All current Gemini embedding models use 768 dimensions.
174        768
175    }
176
177    fn model_id(&self) -> &str {
178        &self.model_id
179    }
180}
181
182/// Text generation model backed by the Gemini `generateContent` API.
183pub struct GeminiGeneratorModel {
184    client: Client,
185    cb: crate::reliability::CircuitBreakerWrapper,
186    model_id: String,
187    api_key: String,
188}
189
190#[async_trait]
191impl GeneratorModel for GeminiGeneratorModel {
192    async fn generate(
193        &self,
194        messages: &[Message],
195        options: GenerationOptions,
196    ) -> Result<GenerationResult> {
197        let messages: Vec<Message> = messages.to_vec();
198
199        self.cb
200            .call(move || async move {
201                let url = format!(
202                    "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
203                    self.model_id, self.api_key
204                );
205
206                let payload = build_google_generate_payload(&messages, &options);
207
208                let response = self
209                    .client
210                    .post(&url)
211                    .json(&payload)
212                    .send()
213                    .await
214                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
215
216                let body: serde_json::Value = check_http_status("Gemini", response)?
217                    .json()
218                    .await
219                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
220
221                let candidates = body
222                    .get("candidates")
223                    .and_then(|v| v.as_array())
224                    .ok_or_else(|| RuntimeError::ApiError("No candidates returned".to_string()))?;
225
226                let first_candidate = candidates
227                    .first()
228                    .ok_or_else(|| RuntimeError::ApiError("Empty candidates".to_string()))?;
229
230                let content_parts = first_candidate
231                    .get("content")
232                    .and_then(|c| c.get("parts"))
233                    .and_then(|p| p.as_array())
234                    .ok_or_else(|| RuntimeError::ApiError("Invalid content format".to_string()))?;
235
236                let text = content_parts
237                    .first()
238                    .and_then(|p| p.get("text"))
239                    .and_then(|t| t.as_str())
240                    .unwrap_or("")
241                    .to_string();
242
243                Ok(GenerationResult {
244                    text,
245                    usage: None,
246                    images: vec![],
247                    audio: None,
248                })
249            })
250            .await
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serialises the tests in this module that mutate process environment variables.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Holds `ENV_LOCK` with `GEMINI_API_KEY` set. The variable is removed on drop, so it
    /// is cleared even when an assertion panics partway through a test.
    struct ApiKeyEnv {
        // Fields drop after `Drop::drop` runs, so the removal happens while still locked.
        _lock: tokio::sync::MutexGuard<'static, ()>,
    }

    impl ApiKeyEnv {
        async fn set() -> Self {
            let lock = ENV_LOCK.lock().await;
            // SAFETY: protected by ENV_LOCK
            unsafe { std::env::set_var("GEMINI_API_KEY", "test-key") };
            Self { _lock: lock }
        }
    }

    impl Drop for ApiKeyEnv {
        fn drop(&mut self) {
            // SAFETY: ENV_LOCK is still held; `_lock` is released after this body runs.
            unsafe { std::env::remove_var("GEMINI_API_KEY") };
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/gemini".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _env = ApiKeyEnv::set().await;

        let provider = RemoteGeminiProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embedding-001");
        let s2 = spec("embed/b", ModelTask::Embed, "embedding-001");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _env = ApiKeyEnv::set().await;

        let provider = RemoteGeminiProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embedding-001");
        let gen_spec = spec("chat/a", ModelTask::Generate, "gemini-pro");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _env = ApiKeyEnv::set().await;

        let provider = RemoteGeminiProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embedding-001");
        let fresh = spec("embed/fresh", ModelTask::Embed, "embedding-002");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[test]
    fn generation_payload_alternates_roles() {
        use crate::traits::Message;
        let messages = vec![
            Message::user("user question"),
            Message::assistant("assistant answer"),
            Message::user("user follow-up"),
        ];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());
        let contents = payload["contents"].as_array().unwrap();

        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[2]["role"], "user");
    }

    #[test]
    fn generation_payload_includes_generation_options() {
        use crate::traits::Message;
        let messages = vec![Message::user("hello")];
        let payload = build_google_generate_payload(
            &messages,
            &GenerationOptions {
                max_tokens: Some(64),
                temperature: Some(0.7),
                top_p: Some(0.9),
                ..Default::default()
            },
        );

        assert_eq!(payload["generationConfig"]["maxOutputTokens"], 64);
        let temperature = payload["generationConfig"]["temperature"].as_f64().unwrap();
        let top_p = payload["generationConfig"]["topP"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 1e-6);
        assert!((top_p - 0.9).abs() < 1e-6);
    }

    #[test]
    fn generation_payload_extracts_system_instruction() {
        use crate::traits::Message;
        let messages = vec![Message::system("you are helpful"), Message::user("hello")];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());

        // The system message is moved into `system_instruction`.
        let si = &payload["system_instruction"];
        assert_eq!(si["parts"][0]["text"], "you are helpful");

        // `contents` keeps only the user message.
        let contents = payload["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    #[test]
    fn generation_payload_no_system_instruction_without_system_messages() {
        use crate::traits::Message;
        let messages = vec![Message::user("hello"), Message::assistant("hi")];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());

        // Without a system message, no `system_instruction` field is emitted.
        assert!(payload.get("system_instruction").is_none());

        let contents = payload["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 2);
    }
}