// uni_xervo/provider/azure_openai.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider that calls the [Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/)
/// for embedding and text generation.
///
/// Requires the `AZURE_OPENAI_API_KEY` environment variable (or a custom env
/// var name via the `api_key_env` option) and the `resource_name` option.
pub struct RemoteAzureOpenAIProvider {
    // Shared HTTP client plus per-runtime-key circuit-breaker bookkeeping.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAzureOpenAIProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
impl RemoteAzureOpenAIProvider {
    /// Creates a provider with a fresh [`RemoteProviderBase`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Test hook: registers a circuit breaker for `key` recorded as last used
    /// `age` ago, so TTL-based eviction can be exercised deterministically.
    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    /// Test hook: number of circuit breakers currently tracked by the base.
    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    /// Test hook: forces stale-breaker cleanup to run on the next access
    /// instead of waiting for the normal cleanup interval.
    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}
50
/// Resolved Azure OpenAI configuration extracted from a [`ModelAliasSpec`]'s
/// options and environment variables.
#[derive(Clone)]
struct AzureResolvedOptions {
    // Secret resolved from the environment; sent via the `api-key` header.
    api_key: String,
    // Azure resource name; becomes the `<resource>.openai.azure.com` host.
    resource_name: String,
    // `api-version` query parameter; defaults to "2024-10-21" when unset.
    api_version: String,
}
59
60impl AzureResolvedOptions {
61    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
62        let api_key = resolve_api_key(&spec.options, "api_key_env", "AZURE_OPENAI_API_KEY")?;
63
64        let resource_name = spec
65            .options
66            .get("resource_name")
67            .and_then(|v| v.as_str())
68            .ok_or_else(|| {
69                RuntimeError::Config(
70                    "Option 'resource_name' is required for Azure OpenAI provider".to_string(),
71                )
72            })?
73            .to_string();
74
75        let api_version = spec
76            .options
77            .get("api_version")
78            .and_then(|v| v.as_str())
79            .unwrap_or("2024-10-21")
80            .to_string();
81
82        Ok(Self {
83            api_key,
84            resource_name,
85            api_version,
86        })
87    }
88
89    fn embed_url(&self, deployment: &str) -> String {
90        format!(
91            "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
92            self.resource_name, deployment, self.api_version
93        )
94    }
95
96    fn chat_url(&self, deployment: &str) -> String {
97        format!(
98            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
99            self.resource_name, deployment, self.api_version
100        )
101    }
102}
103
104#[async_trait]
105impl ModelProvider for RemoteAzureOpenAIProvider {
106    fn provider_id(&self) -> &'static str {
107        "remote/azure-openai"
108    }
109
110    fn capabilities(&self) -> ProviderCapabilities {
111        ProviderCapabilities {
112            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
113        }
114    }
115
116    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
117        let cb = self.base.circuit_breaker_for(spec);
118        let resolved = AzureResolvedOptions::from_spec(spec)?;
119
120        match spec.task {
121            ModelTask::Embed => {
122                let model = AzureOpenAIEmbeddingModel {
123                    client: self.base.client.clone(),
124                    cb: cb.clone(),
125                    deployment: spec.model_id.clone(),
126                    options: resolved,
127                };
128                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
129                Ok(Arc::new(handle) as LoadedModelHandle)
130            }
131            ModelTask::Generate => {
132                let model = AzureOpenAIGeneratorModel {
133                    client: self.base.client.clone(),
134                    cb,
135                    deployment: spec.model_id.clone(),
136                    options: resolved,
137                };
138                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
139                Ok(Arc::new(handle) as LoadedModelHandle)
140            }
141            _ => Err(RuntimeError::CapabilityMismatch(format!(
142                "Azure OpenAI provider does not support task {:?}",
143                spec.task
144            ))),
145        }
146    }
147
148    async fn health(&self) -> ProviderHealth {
149        ProviderHealth::Healthy
150    }
151}
152
/// Azure OpenAI-backed [`EmbeddingModel`] bound to a single deployment.
struct AzureOpenAIEmbeddingModel {
    // HTTP client cloned from the shared provider base.
    client: Client,
    // Circuit breaker shared across models with the same runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Azure deployment name; used as the URL path segment and model id.
    deployment: String,
    options: AzureResolvedOptions,
}
159
160#[async_trait]
161impl EmbeddingModel for AzureOpenAIEmbeddingModel {
162    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
163        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
164
165        self.cb
166            .call(move || async move {
167                let url = self.options.embed_url(&self.deployment);
168
169                let response = self
170                    .client
171                    .post(&url)
172                    .header("api-key", &self.options.api_key)
173                    .json(&json!({
174                        "input": texts
175                    }))
176                    .send()
177                    .await
178                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
179
180                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
181                    .json()
182                    .await
183                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
184
185                let mut embeddings = Vec::new();
186                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
187                    for item in data {
188                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
189                            let vec: Vec<f32> = embedding
190                                .iter()
191                                .filter_map(|v| v.as_f64().map(|f| f as f32))
192                                .collect();
193                            embeddings.push(vec);
194                        }
195                    }
196                }
197                Ok(embeddings)
198            })
199            .await
200    }
201
202    fn dimensions(&self) -> u32 {
203        // Azure deployments may use various embedding models;
204        // default to 1536 (text-embedding-ada-002 / text-embedding-3-small).
205        1536
206    }
207
208    fn model_id(&self) -> &str {
209        &self.deployment
210    }
211}
212
/// Azure OpenAI-backed [`GeneratorModel`] bound to a single chat deployment.
struct AzureOpenAIGeneratorModel {
    // HTTP client cloned from the shared provider base.
    client: Client,
    // Circuit breaker shared across models with the same runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    // Azure deployment name; used as the URL path segment.
    deployment: String,
    options: AzureResolvedOptions,
}
219
220#[async_trait]
221impl GeneratorModel for AzureOpenAIGeneratorModel {
222    async fn generate(
223        &self,
224        messages: &[String],
225        options: GenerationOptions,
226    ) -> Result<GenerationResult> {
227        let messages: Vec<serde_json::Value> = messages
228            .iter()
229            .enumerate()
230            .map(|(i, content)| {
231                let role = if i % 2 == 0 { "user" } else { "assistant" };
232                json!({ "role": role, "content": content })
233            })
234            .collect();
235
236        self.cb
237            .call(move || async move {
238                let url = self.options.chat_url(&self.deployment);
239
240                let mut body = json!({
241                    "messages": messages,
242                });
243
244                if let Some(max_tokens) = options.max_tokens {
245                    body["max_tokens"] = json!(max_tokens);
246                }
247                if let Some(temperature) = options.temperature {
248                    body["temperature"] = json!(temperature);
249                }
250                if let Some(top_p) = options.top_p {
251                    body["top_p"] = json!(top_p);
252                }
253
254                let response = self
255                    .client
256                    .post(&url)
257                    .header("api-key", &self.options.api_key)
258                    .json(&body)
259                    .send()
260                    .await
261                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
262
263                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
264                    .json()
265                    .await
266                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
267
268                let text = body["choices"][0]["message"]["content"]
269                    .as_str()
270                    .unwrap_or("")
271                    .to_string();
272
273                let usage = body.get("usage").map(|u| TokenUsage {
274                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
275                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
276                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
277                });
278
279                Ok(GenerationResult { text, usage })
280            })
281            .await
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes tests that mutate process-wide environment state so
    // concurrent test tasks cannot observe each other's AZURE_OPENAI_API_KEY.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    // Builds a minimal lazy, non-required ModelAliasSpec targeting this
    // provider, with the caller supplying task, deployment id, and options.
    fn spec_with_opts(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/azure-openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    // Smallest option set that passes config validation (resource_name only).
    fn default_opts() -> serde_json::Value {
        json!({ "resource_name": "my-resource" })
    }

    // Two aliases with identical task/model/options share one runtime key and
    // therefore one circuit breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY-style note: env mutation is serialized by ENV_LOCK above.
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s1 = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let s2 = spec_with_opts(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Both loads must have mapped to the same breaker entry.
        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // A breaker older than BREAKER_TTL is evicted by cleanup; a fresh one
    // survives a subsequent load.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let stale = spec_with_opts(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let fresh = spec_with_opts("chat/fresh", ModelTask::Generate, "gpt-4o", default_opts());
        // Insert one entry just past the TTL and one well inside it.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        // Cleanup runs on the next breaker access, triggered by load().
        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Missing `resource_name` must fail load() with a Config-style error that
    // names the offending option.
    #[tokio::test]
    async fn load_fails_without_resource_name() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            serde_json::Value::Null,
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("resource_name"));

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Rerank is not in this provider's capability set; load() must reject it.
    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "rerank/a",
            ModelTask::Rerank,
            "text-embedding-ada-002",
            default_opts(),
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Pure string test: endpoint URLs follow Azure's
    // resource/deployment/operation/api-version layout.
    #[test]
    fn azure_url_construction() {
        let opts = AzureResolvedOptions {
            api_key: "key".to_string(),
            resource_name: "my-resource".to_string(),
            api_version: "2024-10-21".to_string(),
        };

        assert_eq!(
            opts.embed_url("text-embedding-ada-002"),
            "https://my-resource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-10-21"
        );

        assert_eq!(
            opts.chat_url("gpt-4o"),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21"
        );
    }
}