uni_xervo/provider/openai.rs

use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

/// Remote provider that calls the [OpenAI API](https://platform.openai.com/docs/api-reference)
/// for embedding (`/v1/embeddings`) and text generation (`/v1/chat/completions`).
///
/// Requires the `OPENAI_API_KEY` environment variable (or a custom env var name
/// via the `api_key_env` option).
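///
/// # Example
///
/// A minimal, illustrative sketch (not compiled; the crate name `uni_xervo`
/// and module paths are assumed, and `OPENAI_API_KEY` must be set):
///
/// ```ignore
/// use uni_xervo::api::{ModelAliasSpec, ModelTask, WarmupPolicy};
/// use uni_xervo::provider::openai::RemoteOpenAIProvider;
/// use uni_xervo::traits::ModelProvider;
///
/// async fn demo() -> uni_xervo::error::Result<()> {
///     let provider = RemoteOpenAIProvider::new();
///     let spec = ModelAliasSpec {
///         alias: "embed/default".to_string(),
///         task: ModelTask::Embed,
///         provider_id: "remote/openai".to_string(),
///         model_id: "text-embedding-3-small".to_string(),
///         revision: None,
///         warmup: WarmupPolicy::Lazy,
///         required: false,
///         timeout: None,
///         load_timeout: None,
///         retry: None,
///         options: serde_json::Value::Null,
///     };
///     let _handle = provider.load(&spec).await?;
///     Ok(())
/// }
/// ```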
pub struct RemoteOpenAIProvider {
    base: RemoteProviderBase,
}

impl Default for RemoteOpenAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

impl RemoteOpenAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

#[async_trait]
impl ModelProvider for RemoteOpenAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/openai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
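        // Breakers are keyed per runtime key (task + model), so two aliases
        // that resolve to the same remote model share a single circuit
        // breaker (see the tests below).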
        let cb = self.base.circuit_breaker_for(spec);
        let api_key = resolve_api_key(&spec.options, "api_key_env", "OPENAI_API_KEY")?;

        match spec.task {
            ModelTask::Embed => {
                let model = OpenAIEmbeddingModel {
                    client: self.base.client.clone(),
                    cb: cb.clone(),
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = OpenAIGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "OpenAI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

    async fn health(&self) -> ProviderHealth {
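        // No upstream ping is performed here; request-level failures are
        // surfaced through the per-model circuit breakers instead.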
        ProviderHealth::Healthy
    }
}

/// Embedding model backed by the OpenAI embeddings API.
pub struct OpenAIEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
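        // Own the input strings up front so the request body can be built
        // inside the circuit-breaker closure below.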
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let response = self
                    .client
                    .post("https://api.openai.com/v1/embeddings")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&json!({
                        "model": self.model_id,
                        "input": texts
                    }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let mut embeddings = Vec::new();
                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
                    for item in data {
                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
                            let vec: Vec<f32> = embedding
                                .iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect();
                            embeddings.push(vec);
                        }
                    }
                }
                Ok(embeddings)
            })
            .await
    }

    fn dimensions(&self) -> u32 {
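        // text-embedding-3-large returns 3072-dimensional vectors; OpenAI's
        // other embedding models (text-embedding-3-small, text-embedding-ada-002)
        // return 1536. The request above does not send the API's optional
        // `dimensions` parameter, so full-size vectors are assumed.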
        match self.model_id.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}

// ---------------------------------------------------------------------------
// Generator
// ---------------------------------------------------------------------------

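/// Generator model backed by the OpenAI chat completions API.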
struct OpenAIGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl GeneratorModel for OpenAIGeneratorModel {
    async fn generate(
        &self,
        messages: &[Message],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let messages: Vec<serde_json::Value> = messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                };
                json!({ "role": role, "content": msg.text() })
            })
            .collect();

        self.cb
            .call(move || async move {
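                // Only include sampling options the caller actually set;
                // omitted fields fall back to OpenAI's server-side defaults.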
                let mut body = json!({
                    "model": self.model_id,
                    "messages": messages,
                });

                if let Some(max_tokens) = options.max_tokens {
                    body["max_tokens"] = json!(max_tokens);
                }
                if let Some(temperature) = options.temperature {
                    body["temperature"] = json!(temperature);
                }
                if let Some(top_p) = options.top_p {
                    body["top_p"] = json!(top_p);
                }

                let response = self
                    .client
                    .post("https://api.openai.com/v1/chat/completions")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&body)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

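                // Take the first choice's message content; a missing or
                // non-string field degrades to an empty string rather than a
                // parse error. The `usage` block is optional, so absent
                // counters default to zero.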
                let text = body["choices"][0]["message"]["content"]
                    .as_str()
                    .unwrap_or("")
                    .to_string();

                let usage = body.get("usage").map(|u| TokenUsage {
                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult {
                    text,
                    usage,
                    images: vec![],
                    audio: None,
                })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

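    // `std::env::set_var` mutates process-global state, so every test that
    // touches OPENAI_API_KEY serializes behind this lock.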
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

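    // Builds a minimal alias spec; only the fields that feed breaker keying
    // (task and model_id) vary across these tests.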
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let s2 = spec("embed/b", ModelTask::Embed, "text-embedding-3-small");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let gen_spec = spec("chat/a", ModelTask::Generate, "gpt-4o-mini");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "text-embedding-3-small");
        let fresh = spec("embed/fresh", ModelTask::Embed, "text-embedding-3-large");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }
}