1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{
4 RemoteProviderBase, build_google_generate_payload, check_http_status, resolve_api_key,
5};
6use crate::traits::{
7 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
8 Message, ModelProvider, ProviderCapabilities, ProviderHealth,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde_json::json;
13use std::sync::Arc;
14
/// Remote model provider backed by Google's Gemini (Generative Language) API.
///
/// Supports the `Embed` and `Generate` tasks (see [`ModelProvider::capabilities`]);
/// shared concerns — the HTTP client plus the per-model circuit-breaker map and
/// its cleanup — live in [`RemoteProviderBase`].
pub struct RemoteGeminiProvider {
    // Shared remote-provider plumbing (reqwest client, circuit breakers).
    base: RemoteProviderBase,
}
23
24impl RemoteGeminiProvider {
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 #[cfg(test)]
30 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
31 self.base.insert_test_breaker(key, age);
32 }
33
34 #[cfg(test)]
35 fn breaker_count(&self) -> usize {
36 self.base.breaker_count()
37 }
38
39 #[cfg(test)]
40 fn force_cleanup_now_for_test(&self) {
41 self.base.force_cleanup_now_for_test();
42 }
43}
44
45impl Default for RemoteGeminiProvider {
46 fn default() -> Self {
47 Self {
48 base: RemoteProviderBase::new(),
49 }
50 }
51}
52
53#[async_trait]
54impl ModelProvider for RemoteGeminiProvider {
55 fn provider_id(&self) -> &'static str {
56 "remote/gemini"
57 }
58
59 fn capabilities(&self) -> ProviderCapabilities {
60 ProviderCapabilities {
61 supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
62 }
63 }
64
65 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
66 let cb = self.base.circuit_breaker_for(spec);
67 let api_key = resolve_api_key(&spec.options, "api_key_env", "GEMINI_API_KEY")?;
68
69 match spec.task {
70 ModelTask::Embed => {
71 let model = GeminiEmbeddingModel {
72 client: self.base.client.clone(),
73 cb: cb.clone(),
74 model_id: spec.model_id.clone(),
75 api_key,
76 };
77 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
78 Ok(Arc::new(handle) as LoadedModelHandle)
79 }
80 ModelTask::Generate => {
81 let model = GeminiGeneratorModel {
82 client: self.base.client.clone(),
83 cb,
84 model_id: spec.model_id.clone(),
85 api_key,
86 };
87 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
88 Ok(Arc::new(handle) as LoadedModelHandle)
89 }
90 _ => Err(RuntimeError::CapabilityMismatch(format!(
91 "Gemini provider does not support task {:?}",
92 spec.task
93 ))),
94 }
95 }
96
97 async fn health(&self) -> ProviderHealth {
98 ProviderHealth::Healthy
99 }
100}
101
/// Embedding model that calls Gemini's `batchEmbedContents` endpoint.
pub struct GeminiEmbeddingModel {
    // HTTP client cloned from the provider base.
    client: Client,
    // Circuit breaker shared per runtime key; wraps every upstream call.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Remote model identifier (e.g. "embedding-001").
    model_id: String,
    // API key resolved at load time via resolve_api_key (spec options /
    // GEMINI_API_KEY); sent as a URL query parameter.
    api_key: String,
}
109
110#[async_trait]
111impl EmbeddingModel for GeminiEmbeddingModel {
112 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
113 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
114
115 self.cb
116 .call(move || async move {
117 let url = format!(
118 "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
119 self.model_id, self.api_key
120 );
121
122 let requests: Vec<_> = texts
123 .iter()
124 .map(|t| {
125 json!({
126 "model": format!("models/{}", self.model_id),
127 "content": { "parts": [{ "text": t }] }
128 })
129 })
130 .collect();
131
132 let response = self
133 .client
134 .post(&url)
135 .json(&json!({ "requests": requests }))
136 .send()
137 .await
138 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
139
140 let body: serde_json::Value = check_http_status("Gemini", response)?
141 .json()
142 .await
143 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
144
145 let embeddings_json = body
146 .get("embeddings")
147 .and_then(|v| v.as_array())
148 .ok_or_else(|| {
149 RuntimeError::ApiError("Invalid response format".to_string())
150 })?;
151
152 let mut result = Vec::new();
153 for item in embeddings_json {
154 let values = item
155 .get("values")
156 .and_then(|v| v.as_array())
157 .ok_or_else(|| {
158 RuntimeError::ApiError("Missing values in embedding".to_string())
159 })?;
160
161 let vec: Vec<f32> = values
162 .iter()
163 .filter_map(|v| v.as_f64().map(|f| f as f32))
164 .collect();
165 result.push(vec);
166 }
167 Ok(result)
168 })
169 .await
170 }
171
172 fn dimensions(&self) -> u32 {
173 768
175 }
176
177 fn model_id(&self) -> &str {
178 &self.model_id
179 }
180}
181
/// Generator model that calls Gemini's `generateContent` endpoint.
pub struct GeminiGeneratorModel {
    // HTTP client cloned from the provider base.
    client: Client,
    // Circuit breaker shared per runtime key; wraps every upstream call.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Remote model identifier (e.g. "gemini-pro").
    model_id: String,
    // API key resolved at load time via resolve_api_key (spec options /
    // GEMINI_API_KEY); sent as a URL query parameter.
    api_key: String,
}
189
190#[async_trait]
191impl GeneratorModel for GeminiGeneratorModel {
192 async fn generate(
193 &self,
194 messages: &[Message],
195 options: GenerationOptions,
196 ) -> Result<GenerationResult> {
197 let messages: Vec<Message> = messages.to_vec();
198
199 self.cb
200 .call(move || async move {
201 let url = format!(
202 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
203 self.model_id, self.api_key
204 );
205
206 let payload = build_google_generate_payload(&messages, &options);
207
208 let response = self
209 .client
210 .post(&url)
211 .json(&payload)
212 .send()
213 .await
214 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
215
216 let body: serde_json::Value = check_http_status("Gemini", response)?
217 .json()
218 .await
219 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
220
221 let candidates = body
222 .get("candidates")
223 .and_then(|v| v.as_array())
224 .ok_or_else(|| RuntimeError::ApiError("No candidates returned".to_string()))?;
225
226 let first_candidate = candidates
227 .first()
228 .ok_or_else(|| RuntimeError::ApiError("Empty candidates".to_string()))?;
229
230 let content_parts = first_candidate
231 .get("content")
232 .and_then(|c| c.get("parts"))
233 .and_then(|p| p.as_array())
234 .ok_or_else(|| RuntimeError::ApiError("Invalid content format".to_string()))?;
235
236 let text = content_parts
237 .first()
238 .and_then(|p| p.get("text"))
239 .and_then(|t| t.as_str())
240 .unwrap_or("")
241 .to_string();
242
243 Ok(GenerationResult {
244 text,
245 usage: None,
246 images: vec![],
247 audio: None,
248 })
249 })
250 .await
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes env-var mutation across async tests: the process environment
    // is global state, and set_var/remove_var are unsafe under concurrency.
    // Every test that touches GEMINI_API_KEY must hold this lock first.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Builds a minimal alias spec targeting the Gemini provider; everything
    /// not under test is left at its neutral value.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/gemini".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    // Two aliases with the same (task, model_id) runtime key must share one
    // circuit breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY-context: guarded by ENV_LOCK; no concurrent env access.
        unsafe { std::env::set_var("GEMINI_API_KEY", "test-key") };

        let provider = RemoteGeminiProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "embedding-001");
        let s2 = spec("embed/b", ModelTask::Embed, "embedding-001");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("GEMINI_API_KEY") };
    }

    // Different task or model_id yields a distinct runtime key, hence a
    // distinct breaker.
    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY-context: guarded by ENV_LOCK; no concurrent env access.
        unsafe { std::env::set_var("GEMINI_API_KEY", "test-key") };

        let provider = RemoteGeminiProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "embedding-001");
        let gen_spec = spec("chat/a", ModelTask::Generate, "gemini-pro");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        unsafe { std::env::remove_var("GEMINI_API_KEY") };
    }

    // Entries older than BREAKER_TTL are evicted when cleanup runs; fresh
    // entries survive. Cleanup is forced, then a load triggers the sweep.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY-context: guarded by ENV_LOCK; no concurrent env access.
        unsafe { std::env::set_var("GEMINI_API_KEY", "test-key") };

        let provider = RemoteGeminiProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "embedding-001");
        let fresh = spec("embed/fresh", ModelTask::Embed, "embedding-002");
        // Seed one entry just past the TTL and one well within it.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        // Only the stale entry should have been evicted.
        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("GEMINI_API_KEY") };
    }

    // The payload builder maps assistant messages to Gemini's "model" role.
    #[test]
    fn generation_payload_alternates_roles() {
        use crate::traits::Message;
        let messages = vec![
            Message::user("user question"),
            Message::assistant("assistant answer"),
            Message::user("user follow-up"),
        ];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());
        let contents = payload["contents"].as_array().unwrap();

        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[2]["role"], "user");
    }

    // Generation options map onto Gemini's camelCase generationConfig keys.
    #[test]
    fn generation_payload_includes_generation_options() {
        use crate::traits::Message;
        let messages = vec![Message::user("hello")];
        let payload = build_google_generate_payload(
            &messages,
            &GenerationOptions {
                max_tokens: Some(64),
                temperature: Some(0.7),
                top_p: Some(0.9),
                ..Default::default()
            },
        );

        assert_eq!(payload["generationConfig"]["maxOutputTokens"], 64);
        let temperature = payload["generationConfig"]["temperature"].as_f64().unwrap();
        let top_p = payload["generationConfig"]["topP"].as_f64().unwrap();
        // Floats round-trip through JSON; compare with a tolerance.
        assert!((temperature - 0.7).abs() < 1e-6);
        assert!((top_p - 0.9).abs() < 1e-6);
    }

    // System messages are lifted out of `contents` into `system_instruction`.
    #[test]
    fn generation_payload_extracts_system_instruction() {
        use crate::traits::Message;
        let messages = vec![Message::system("you are helpful"), Message::user("hello")];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());

        let si = &payload["system_instruction"];
        assert_eq!(si["parts"][0]["text"], "you are helpful");

        // The system message must not also appear in `contents`.
        let contents = payload["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    // Without system messages, the key is omitted entirely (not null).
    #[test]
    fn generation_payload_no_system_instruction_without_system_messages() {
        use crate::traits::Message;
        let messages = vec![Message::user("hello"), Message::assistant("hi")];
        let payload = build_google_generate_payload(&messages, &GenerationOptions::default());

        assert!(payload.get("system_instruction").is_none());

        let contents = payload["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 2);
    }
}