uni_xervo/provider/openai.rs

use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

/// Remote provider that calls the [OpenAI API](https://platform.openai.com/docs/api-reference)
/// for embedding (`/v1/embeddings`) and text generation (`/v1/chat/completions`).
///
/// Requires the `OPENAI_API_KEY` environment variable (or a custom env var name
/// via the `api_key_env` option).
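///
/// # Example
///
/// An illustrative sketch of loading an embedding model through this provider.
/// The module paths are assumed from this file's `use` statements, and the
/// spec fields mirror the `spec()` test helper at the bottom of this file:
///
/// ```ignore
/// use uni_xervo::api::{ModelAliasSpec, ModelTask, WarmupPolicy};
/// use uni_xervo::provider::openai::RemoteOpenAIProvider;
/// use uni_xervo::traits::ModelProvider;
///
/// # async fn demo() -> uni_xervo::error::Result<()> {
/// // Requires OPENAI_API_KEY to be set in the environment.
/// let provider = RemoteOpenAIProvider::new();
/// let spec = ModelAliasSpec {
///     alias: "embed/default".to_string(),
///     task: ModelTask::Embed,
///     provider_id: "remote/openai".to_string(),
///     model_id: "text-embedding-3-small".to_string(),
///     revision: None,
///     warmup: WarmupPolicy::Lazy,
///     required: false,
///     timeout: None,
///     load_timeout: None,
///     retry: None,
///     options: serde_json::Value::Null,
/// };
/// let handle = provider.load(&spec).await?;
/// # Ok(())
/// # }
/// ```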
pub struct RemoteOpenAIProvider {
    base: RemoteProviderBase,
}

impl Default for RemoteOpenAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

impl RemoteOpenAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

#[async_trait]
impl ModelProvider for RemoteOpenAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/openai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
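        // Breakers are keyed by runtime key (task + model id), so distinct
        // aliases that resolve to the same remote model share one circuit
        // breaker; see `breaker_reused_for_same_runtime_key` below.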
        let cb = self.base.circuit_breaker_for(spec);
        let api_key = resolve_api_key(&spec.options, "api_key_env", "OPENAI_API_KEY")?;

        match spec.task {
            ModelTask::Embed => {
                let model = OpenAIEmbeddingModel {
                    client: self.base.client.clone(),
                    cb: cb.clone(),
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = OpenAIGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    model_id: spec.model_id.clone(),
                    api_key,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "OpenAI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }
}

/// Embedding model backed by the OpenAI embeddings API.
pub struct OpenAIEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let response = self
                    .client
                    .post("https://api.openai.com/v1/embeddings")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&json!({
                        "model": self.model_id,
                        "input": texts
                    }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

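                // The embeddings endpoint returns `{"data": [{"embedding": [..]}, ..]}`;
                // collect each vector in request order.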
                let mut embeddings = Vec::new();
                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
                    for item in data {
                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
                            let vec: Vec<f32> = embedding
                                .iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect();
                            embeddings.push(vec);
                        }
                    }
                }
                Ok(embeddings)
            })
            .await
    }

    fn dimensions(&self) -> u32 {
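        // Known OpenAI embedding widths: text-embedding-3-large is 3072;
        // text-embedding-3-small and text-embedding-ada-002 are both 1536,
        // which is also the fallback for unrecognized model ids.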
        match self.model_id.as_str() {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        }
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}

// ---------------------------------------------------------------------------
// Generator
// ---------------------------------------------------------------------------

/// Generator model backed by the OpenAI chat completions API.
struct OpenAIGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    model_id: String,
    api_key: String,
}

#[async_trait]
impl GeneratorModel for OpenAIGeneratorModel {
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
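        // Assign alternating chat roles to the flat message list, starting
        // with "user" at index 0; this interface carries no explicit roles
        // and no system prompt.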
        let messages: Vec<serde_json::Value> = messages
            .iter()
            .enumerate()
            .map(|(i, content)| {
                let role = if i % 2 == 0 { "user" } else { "assistant" };
                json!({ "role": role, "content": content })
            })
            .collect();

        self.cb
            .call(move || async move {
                let mut body = json!({
                    "model": self.model_id,
                    "messages": messages,
                });

                if let Some(max_tokens) = options.max_tokens {
                    body["max_tokens"] = json!(max_tokens);
                }
                if let Some(temperature) = options.temperature {
                    body["temperature"] = json!(temperature);
                }
                if let Some(top_p) = options.top_p {
                    body["top_p"] = json!(top_p);
                }

                let response = self
                    .client
                    .post("https://api.openai.com/v1/chat/completions")
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&body)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

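                // Chat completions return `{"choices": [{"message": {"content": ..}}]}`;
                // fall back to an empty string if the shape is unexpected.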
                let text = body["choices"][0]["message"]["content"]
                    .as_str()
                    .unwrap_or("")
                    .to_string();

                let usage = body.get("usage").map(|u| TokenUsage {
                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult { text, usage })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

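    // Serializes env-var mutation: these tests toggle OPENAI_API_KEY, and the
    // process environment is global state shared across concurrent tests.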
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let s1 = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let s2 = spec("embed/b", ModelTask::Embed, "text-embedding-3-small");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_isolated_by_task_and_model() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let embed = spec("embed/a", ModelTask::Embed, "text-embedding-3-small");
        let gen_spec = spec("chat/a", ModelTask::Generate, "gpt-4o-mini");

        let _ = provider.load(&embed).await.unwrap();
        let _ = provider.load(&gen_spec).await.unwrap();

        assert_eq!(provider.breaker_count(), 2);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::set_var("OPENAI_API_KEY", "test-key") };

        let provider = RemoteOpenAIProvider::new();
        let stale = spec("embed/stale", ModelTask::Embed, "text-embedding-3-small");
        let fresh = spec("embed/fresh", ModelTask::Embed, "text-embedding-3-large");
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: protected by ENV_LOCK
        unsafe { std::env::remove_var("OPENAI_API_KEY") };
    }
}
324}