use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
use crate::traits::{
    EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
    Message, MessageRole, ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::sync::Arc;

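/// Remote provider that serves embeddings and chat completions from an Azure
/// OpenAI resource. Each alias's `model_id` is treated as the name of an Azure
/// *deployment*, and request URLs are assembled from the `resource_name` and
/// `api_version` options resolved in [`AzureResolvedOptions::from_spec`].
///
/// A minimal sketch of the alias options this provider reads. The key names
/// are the ones `from_spec` looks up; the values shown here are illustrative:
///
/// ```ignore
/// let options = serde_json::json!({
///     // Required: the Azure resource name, i.e. the `https://{resource}.openai.azure.com` host.
///     "resource_name": "my-resource",
///     // Optional: API version query parameter; defaults to "2024-10-21".
///     "api_version": "2024-10-21",
///     // Optional: presumably names the environment variable holding the API key;
///     // defaults to AZURE_OPENAI_API_KEY.
///     "api_key_env": "AZURE_OPENAI_API_KEY",
/// });
/// ```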
pub struct RemoteAzureOpenAIProvider {
    base: RemoteProviderBase,
}

impl Default for RemoteAzureOpenAIProvider {
    fn default() -> Self {
        Self {
            base: RemoteProviderBase::new(),
        }
    }
}

impl RemoteAzureOpenAIProvider {
    pub fn new() -> Self {
        Self::default()
    }

    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}

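/// Connection settings shared by the embedding and generator models, resolved
/// once per alias from the spec's `options` map.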
#[derive(Clone)]
struct AzureResolvedOptions {
    api_key: String,
    resource_name: String,
    api_version: String,
}

impl AzureResolvedOptions {
    fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
        let api_key = resolve_api_key(&spec.options, "api_key_env", "AZURE_OPENAI_API_KEY")?;

        let resource_name = spec
            .options
            .get("resource_name")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                RuntimeError::Config(
                    "Option 'resource_name' is required for Azure OpenAI provider".to_string(),
                )
            })?
            .to_string();

        let api_version = spec
            .options
            .get("api_version")
            .and_then(|v| v.as_str())
            .unwrap_or("2024-10-21")
            .to_string();

        Ok(Self {
            api_key,
            resource_name,
            api_version,
        })
    }

    fn embed_url(&self, deployment: &str) -> String {
        format!(
            "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
            self.resource_name, deployment, self.api_version
        )
    }

    fn chat_url(&self, deployment: &str) -> String {
        format!(
            "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
            self.resource_name, deployment, self.api_version
        )
    }
}

#[async_trait]
impl ModelProvider for RemoteAzureOpenAIProvider {
    fn provider_id(&self) -> &'static str {
        "remote/azure-openai"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
        let cb = self.base.circuit_breaker_for(spec);
        let resolved = AzureResolvedOptions::from_spec(spec)?;

        match spec.task {
            ModelTask::Embed => {
                let model = AzureOpenAIEmbeddingModel {
                    client: self.base.client.clone(),
                    cb: cb.clone(),
                    deployment: spec.model_id.clone(),
                    options: resolved,
                };
                let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            ModelTask::Generate => {
                let model = AzureOpenAIGeneratorModel {
                    client: self.base.client.clone(),
                    cb,
                    deployment: spec.model_id.clone(),
                    options: resolved,
                };
                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
                Ok(Arc::new(handle) as LoadedModelHandle)
            }
            _ => Err(RuntimeError::CapabilityMismatch(format!(
                "Azure OpenAI provider does not support task {:?}",
                spec.task
            ))),
        }
    }

    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }
}

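/// Embedding model bound to a single Azure OpenAI deployment; requests are
/// sent to that deployment's `/embeddings` endpoint through the shared
/// circuit breaker.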
struct AzureOpenAIEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    deployment: String,
    options: AzureResolvedOptions,
}

#[async_trait]
impl EmbeddingModel for AzureOpenAIEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();

        self.cb
            .call(move || async move {
                let url = self.options.embed_url(&self.deployment);

                let response = self
                    .client
                    .post(&url)
                    .header("api-key", &self.options.api_key)
                    .json(&json!({
                        "input": texts
                    }))
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let mut embeddings = Vec::new();
                if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
                    for item in data {
                        if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
                            let vec: Vec<f32> = embedding
                                .iter()
                                .filter_map(|v| v.as_f64().map(|f| f as f32))
                                .collect();
                            embeddings.push(vec);
                        }
                    }
                }
                Ok(embeddings)
            })
            .await
    }

    fn dimensions(&self) -> u32 {
        // Hardcoded to 1536, the vector size returned by text-embedding-ada-002
        // and text-embedding-3-small deployments.
        1536
    }

    fn model_id(&self) -> &str {
        &self.deployment
    }
}

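/// Chat-completion model bound to a single Azure OpenAI deployment; requests
/// are sent to that deployment's `/chat/completions` endpoint through the
/// shared circuit breaker.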
struct AzureOpenAIGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    deployment: String,
    options: AzureResolvedOptions,
}

#[async_trait]
impl GeneratorModel for AzureOpenAIGeneratorModel {
    async fn generate(
        &self,
        messages: &[Message],
        options: GenerationOptions,
    ) -> Result<GenerationResult> {
        let messages: Vec<serde_json::Value> = messages
            .iter()
            .map(|msg| {
                let role = match msg.role {
                    MessageRole::System => "system",
                    MessageRole::User => "user",
                    MessageRole::Assistant => "assistant",
                };
                json!({ "role": role, "content": msg.text() })
            })
            .collect();

        self.cb
            .call(move || async move {
                let url = self.options.chat_url(&self.deployment);

                let mut body = json!({
                    "messages": messages,
                });

                if let Some(max_tokens) = options.max_tokens {
                    body["max_tokens"] = json!(max_tokens);
                }
                if let Some(temperature) = options.temperature {
                    body["temperature"] = json!(temperature);
                }
                if let Some(top_p) = options.top_p {
                    body["top_p"] = json!(top_p);
                }

                let response = self
                    .client
                    .post(&url)
                    .header("api-key", &self.options.api_key)
                    .json(&body)
                    .send()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
                    .json()
                    .await
                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;

                let text = body["choices"][0]["message"]["content"]
                    .as_str()
                    .unwrap_or("")
                    .to_string();

                let usage = body.get("usage").map(|u| TokenUsage {
                    prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
                    completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
                    total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
                });

                Ok(GenerationResult {
                    text,
                    usage,
                    images: vec![],
                    audio: None,
                })
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    fn spec_with_opts(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/azure-openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    fn default_opts() -> serde_json::Value {
        json!({ "resource_name": "my-resource" })
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s1 = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let s2 = spec_with_opts(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let stale = spec_with_opts(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let fresh = spec_with_opts("chat/fresh", ModelTask::Generate, "gpt-4o", default_opts());
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn load_fails_without_resource_name() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            serde_json::Value::Null,
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("resource_name"));

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "rerank/a",
            ModelTask::Rerank,
            "text-embedding-ada-002",
            default_opts(),
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    #[test]
    fn azure_url_construction() {
        let opts = AzureResolvedOptions {
            api_key: "key".to_string(),
            resource_name: "my-resource".to_string(),
            api_version: "2024-10-21".to_string(),
        };

        assert_eq!(
            opts.embed_url("text-embedding-ada-002"),
            "https://my-resource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-10-21"
        );

        assert_eq!(
            opts.chat_url("gpt-4o"),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21"
        );
    }
}