use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

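/// Remote provider for OpenAI's hosted `/v1/embeddings` and
/// `/v1/chat/completions` endpoints. The API key is resolved from
/// `OPENAI_API_KEY` unless the alias options name a different variable via
/// `api_key_env`.
///
/// Minimal usage sketch (not compiled by rustdoc; assumes a `ModelAliasSpec`
/// built elsewhere by the host application):
///
/// ```ignore
/// let provider = RemoteOpenAIProvider::new();
/// let handle = provider.load(&spec).await?; // spec.task: Embed or Generate
/// ```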
pub struct RemoteOpenAIProvider {
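    /// Shared HTTP client and the per-model circuit-breaker registry
    /// reused across every alias loaded through this provider.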
    base: RemoteProviderBase,
}

impl Default for RemoteOpenAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

impl RemoteOpenAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

#[async_trait]
impl ModelProvider for RemoteOpenAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/openai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
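        // Breakers are keyed by runtime key (task + model), so aliases that
        // resolve to the same remote model share a single breaker.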
        let cb = self.base.circuit_breaker_for(spec);
        let api_key = resolve_api_key(&spec.options, "api_key_env", "OPENAI_API_KEY")?;

        match spec.task {
            ModelTask::Embed => {
                let model = OpenAIEmbeddingModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = OpenAIGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "OpenAI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

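    /// Always reports healthy: there is no upstream probe, and request-level
    /// failures are surfaced through the per-model circuit breakers instead.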
    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }
}

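/// Embedding handle produced by `RemoteOpenAIProvider::load`; posts to the
/// `/v1/embeddings` endpoint with the configured model and API key.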
pub struct OpenAIEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
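        // Copy the borrowed inputs into owned Strings so the request future
        // handed to the circuit breaker does not borrow from the caller.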
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let response = self
                    .client
                    .post("https://api.openai.com/v1/embeddings")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&json!({
                        "model": self.model_id,
                        "input": texts
                    }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

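                // Collect one vector per returned item, skipping entries
                // without a well-formed "embedding" array.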
                let mut embeddings = Vec::new();
                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
                    for item in data {
                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
                            let vec: Vec<f32> = embedding
                                .iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect();
                            embeddings.push(vec);
                        }
                    }
                }
                Ok(embeddings)
            })
            .await
    }

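    /// Known output widths for OpenAI embedding models: 3072 for
    /// text-embedding-3-large; the 1536 fallback covers
    /// text-embedding-3-small and text-embedding-ada-002.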
    fn dimensions(&self) -> u32 {
        match self.model_id.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}

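/// Chat-completion handle produced by `RemoteOpenAIProvider::load`; posts to
/// the `/v1/chat/completions` endpoint with the configured model and API key.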
struct OpenAIGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl GeneratorModel for OpenAIGeneratorModel {
    async fn generate(
        &self,
        messages: &[Message],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
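        // Project runtime messages onto the Chat Completions role/content
        // wire format.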
        let messages: Vec<serde_json::Value> = messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                };
                json!({ "role": role, "content": msg.text() })
            })
            .collect();

        self.cb
            .call(move || async move {
                let mut body = json!({
                    "model": self.model_id,
                    "messages": messages,
                });

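                // Forward only the sampling options the caller actually set.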
                if let Some(max_tokens) = options.max_tokens {
                    body["max_tokens"] = json!(max_tokens);
                }
                if let Some(temperature) = options.temperature {
                    body["temperature"] = json!(temperature);
                }
                if let Some(top_p) = options.top_p {
                    body["top_p"] = json!(top_p);
                }

                let response = self
                    .client
                    .post("https://api.openai.com/v1/chat/completions")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&body)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let text = body["choices"][0]["message"]["content"]
                    .as_str()
                    .unwrap_or("")
                    .to_string();

                let usage = body.get("usage").map(|u| TokenUsage {
                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult {
                    text,
                    usage,
                    images: vec![],
                    audio: None,
                })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

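    // Tests mutate the process-global OPENAI_API_KEY variable, so they
    // serialize on this lock to avoid interfering with one another.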
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

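    /// Minimal alias spec for exercising `load`; the tests vary only the
    /// alias, task, and model_id.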
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let s2 = spec("embed/b", ModelTask::Embed, "text-embedding-3-small");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let gen_spec = spec("chat/a", ModelTask::Generate, "gpt-4o-mini");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "text-embedding-3-small");
        let fresh = spec("embed/fresh", ModelTask::Embed, "text-embedding-3-large");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

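        // Force the cleanup window, then trigger a sweep via the next load;
        // only the expired entry should be evicted.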
        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }
}