// uni_xervo/provider/azure_openai.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider that calls the [Azure OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/)
/// for embedding and text generation.
///
/// Requires the `AZURE_OPENAI_API_KEY` environment variable (or a custom env
/// var name via the `api_key_env` option) and the `resource_name` option.
pub struct RemoteAzureOpenAIProvider {
    /// Shared HTTP client plus the per-runtime-key circuit-breaker registry
    /// common to remote providers.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAzureOpenAIProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
impl RemoteAzureOpenAIProvider {
    /// Creates a new provider; identical to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Test-only hook: seeds the base's breaker map with an entry whose
    /// recorded age is `age`, so TTL-driven cleanup can be exercised.
    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    /// Test-only hook: number of circuit breakers currently tracked.
    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    /// Test-only hook: makes the next breaker lookup run a cleanup pass
    /// immediately instead of waiting for the scheduled interval.
    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}
50
/// Resolved Azure OpenAI configuration extracted from a [`ModelAliasSpec`]'s
/// options and environment variables.
#[derive(Clone)]
struct AzureResolvedOptions {
    /// API key resolved via `resolve_api_key` (default env var
    /// `AZURE_OPENAI_API_KEY`, overridable with the `api_key_env` option).
    api_key: String,
    /// Azure resource name; becomes the host prefix
    /// `https://{resource_name}.openai.azure.com`.
    resource_name: String,
    /// Value of the `api-version` query parameter (defaults to `2024-10-21`).
    api_version: String,
}
59
60impl AzureResolvedOptions {
61    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
62        let api_key = resolve_api_key(&spec.options, "api_key_env", "AZURE_OPENAI_API_KEY")?;
63
64        let resource_name = spec
65            .options
66            .get("resource_name")
67            .and_then(|v| v.as_str())
68            .ok_or_else(|| {
69                RuntimeError::Config(
70                    "Option 'resource_name' is required for Azure OpenAI provider".to_string(),
71                )
72            })?
73            .to_string();
74
75        let api_version = spec
76            .options
77            .get("api_version")
78            .and_then(|v| v.as_str())
79            .unwrap_or("2024-10-21")
80            .to_string();
81
82        Ok(Self {
83            api_key,
84            resource_name,
85            api_version,
86        })
87    }
88
89    fn embed_url(&self, deployment: &str) -> String {
90        format!(
91            "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
92            self.resource_name, deployment, self.api_version
93        )
94    }
95
96    fn chat_url(&self, deployment: &str) -> String {
97        format!(
98            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
99            self.resource_name, deployment, self.api_version
100        )
101    }
102}
103
104#[async_trait]
105impl ModelProvider for RemoteAzureOpenAIProvider {
106    fn provider_id(&self) -> &'static str {
107        "remote/azure-openai"
108    }
109
110    fn capabilities(&self) -> ProviderCapabilities {
111        ProviderCapabilities {
112            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
113        }
114    }
115
116    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
117        let cb = self.base.circuit_breaker_for(spec);
118        let resolved = AzureResolvedOptions::from_spec(spec)?;
119
120        match spec.task {
121            ModelTask::Embed => {
122                let model = AzureOpenAIEmbeddingModel {
123                    client: self.base.client.clone(),
124                    cb: cb.clone(),
125                    deployment: spec.model_id.clone(),
126                    options: resolved,
127                };
128                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
129                Ok(Arc::new(handle) as LoadedModelHandle)
130            }
131            ModelTask::Generate => {
132                let model = AzureOpenAIGeneratorModel {
133                    client: self.base.client.clone(),
134                    cb,
135                    deployment: spec.model_id.clone(),
136                    options: resolved,
137                };
138                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
139                Ok(Arc::new(handle) as LoadedModelHandle)
140            }
141            _ => Err(RuntimeError::CapabilityMismatch(format!(
142                "Azure OpenAI provider does not support task {:?}",
143                spec.task
144            ))),
145        }
146    }
147
148    async fn health(&self) -> ProviderHealth {
149        ProviderHealth::Healthy
150    }
151}
152
/// Loaded handle for one Azure OpenAI embeddings deployment, created by
/// [`RemoteAzureOpenAIProvider::load`].
struct AzureOpenAIEmbeddingModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker guarding remote calls for this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Azure deployment name (the spec's `model_id`).
    deployment: String,
    /// Resolved key, resource name and API version.
    options: AzureResolvedOptions,
}
159
160#[async_trait]
161impl EmbeddingModel for AzureOpenAIEmbeddingModel {
162    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
163        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
164
165        self.cb
166            .call(move || async move {
167                let url = self.options.embed_url(&self.deployment);
168
169                let response = self
170                    .client
171                    .post(&url)
172                    .header("api-key", &self.options.api_key)
173                    .json(&json!({
174                        "input": texts
175                    }))
176                    .send()
177                    .await
178                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
179
180                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
181                    .json()
182                    .await
183                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
184
185                let mut embeddings = Vec::new();
186                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
187                    for item in data {
188                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
189                            let vec: Vec<f32> = embedding
190                                .iter()
191                                .filter_map(|v| v.as_f64().map(|f| f as f32))
192                                .collect();
193                            embeddings.push(vec);
194                        }
195                    }
196                }
197                Ok(embeddings)
198            })
199            .await
200    }
201
202    fn dimensions(&self) -> u32 {
203        // Azure deployments may use various embedding models;
204        // default to 1536 (text-embedding-ada-002 / text-embedding-3-small).
205        1536
206    }
207
208    fn model_id(&self) -> &str {
209        &self.deployment
210    }
211}
212
/// Loaded handle for one Azure OpenAI chat-completions deployment, created by
/// [`RemoteAzureOpenAIProvider::load`].
struct AzureOpenAIGeneratorModel {
    /// HTTP client cloned from the provider base.
    client: Client,
    /// Circuit breaker guarding remote calls for this model's runtime key.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Azure deployment name (the spec's `model_id`).
    deployment: String,
    /// Resolved key, resource name and API version.
    options: AzureResolvedOptions,
}
219
220#[async_trait]
221impl GeneratorModel for AzureOpenAIGeneratorModel {
222    async fn generate(
223        &self,
224        messages: &[Message],
225        options: GenerationOptions,
226    ) -> Result<GenerationResult> {
227        let messages: Vec<serde_json::Value> = messages
228            .iter()
229            .map(|msg| {
230                let role = match msg.role {
231                    MessageRole::System => "system",
232                    MessageRole::User => "user",
233                    MessageRole::Assistant => "assistant",
234                };
235                json!({ "role": role, "content": msg.text() })
236            })
237            .collect();
238
239        self.cb
240            .call(move || async move {
241                let url = self.options.chat_url(&self.deployment);
242
243                let mut body = json!({
244                    "messages": messages,
245                });
246
247                if let Some(max_tokens) = options.max_tokens {
248                    body["max_tokens"] = json!(max_tokens);
249                }
250                if let Some(temperature) = options.temperature {
251                    body["temperature"] = json!(temperature);
252                }
253                if let Some(top_p) = options.top_p {
254                    body["top_p"] = json!(top_p);
255                }
256
257                let response = self
258                    .client
259                    .post(&url)
260                    .header("api-key", &self.options.api_key)
261                    .json(&body)
262                    .send()
263                    .await
264                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
265
266                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
267                    .json()
268                    .await
269                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
270
271                let text = body["choices"][0]["message"]["content"]
272                    .as_str()
273                    .unwrap_or("")
274                    .to_string();
275
276                let usage = body.get("usage").map(|u| TokenUsage {
277                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
278                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
279                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
280                });
281
282                Ok(GenerationResult {
283                    text,
284                    usage,
285                    images: vec![],
286                    audio: None,
287                })
288            })
289            .await
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes tests that mutate the process-global AZURE_OPENAI_API_KEY
    // environment variable so they cannot observe each other's changes.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Builds a minimal alias spec targeting the Azure OpenAI provider with
    /// the given alias, task, deployment id and options blob.
    fn spec_with_opts(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/azure-openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    /// The minimal options a successful load needs: just `resource_name`.
    fn default_opts() -> serde_json::Value {
        json!({ "resource_name": "my-resource" })
    }

    /// Two aliases with identical model/options must share one breaker entry.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s1 = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let s2 = spec_with_opts(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Same runtime key => the breaker map must not grow past one entry.
        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    /// Entries older than the breaker TTL are evicted on the next cleanup.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let stale = spec_with_opts(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let fresh = spec_with_opts("chat/fresh", ModelTask::Generate, "gpt-4o", default_opts());
        // Seed one entry just past the TTL and one well within it.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        // Force cleanup to run on the next load instead of on its schedule.
        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        // Only the stale entry should have been evicted.
        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    /// Loading without the mandatory `resource_name` option is a config error.
    #[tokio::test]
    async fn load_fails_without_resource_name() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            serde_json::Value::Null,
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("resource_name"));

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    /// Rerank is not a supported task and must yield a capability mismatch.
    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "rerank/a",
            ModelTask::Rerank,
            "text-embedding-ada-002",
            default_opts(),
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    /// Embed and chat URLs follow the documented Azure endpoint layout.
    #[test]
    fn azure_url_construction() {
        let opts = AzureResolvedOptions {
            api_key: "key".to_string(),
            resource_name: "my-resource".to_string(),
            api_version: "2024-10-21".to_string(),
        };

        assert_eq!(
            opts.embed_url("text-embedding-ada-002"),
            "https://my-resource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-10-21"
        );

        assert_eq!(
            opts.chat_url("gpt-4o"),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21"
        );
    }
}