1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 EmbeddingModel, GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle,
6 ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider for Azure OpenAI deployments (embeddings and chat
/// completions). HTTP-client sharing and circuit-breaker bookkeeping are
/// delegated to [`RemoteProviderBase`].
pub struct RemoteAzureOpenAIProvider {
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAzureOpenAIProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
impl RemoteAzureOpenAIProvider {
    /// Creates a provider with a fresh HTTP client and an empty
    /// circuit-breaker registry (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Test-only: seeds a breaker entry whose recorded age is `age`, so
    /// TTL-based cleanup can be exercised deterministically.
    #[cfg(test)]
    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
        self.base.insert_test_breaker(key, age);
    }

    /// Test-only: number of live circuit-breaker entries.
    #[cfg(test)]
    fn breaker_count(&self) -> usize {
        self.base.breaker_count()
    }

    /// Test-only: forces the breaker-registry cleanup pass to run now
    /// instead of waiting for its normal schedule.
    #[cfg(test)]
    fn force_cleanup_now_for_test(&self) {
        self.base.force_cleanup_now_for_test();
    }
}
50
/// Connection settings resolved once per `load` from the alias spec's
/// options map: API key (from the environment variable named by
/// `api_key_env`, default `AZURE_OPENAI_API_KEY`), the mandatory Azure
/// resource name, and the API version (default "2024-10-21").
#[derive(Clone)]
struct AzureResolvedOptions {
    api_key: String,
    resource_name: String,
    api_version: String,
}
59
60impl AzureResolvedOptions {
61 fn from_spec(spec: &ModelAliasSpec) -> Result<Self> {
62 let api_key = resolve_api_key(&spec.options, "api_key_env", "AZURE_OPENAI_API_KEY")?;
63
64 let resource_name = spec
65 .options
66 .get("resource_name")
67 .and_then(|v| v.as_str())
68 .ok_or_else(|| {
69 RuntimeError::Config(
70 "Option 'resource_name' is required for Azure OpenAI provider".to_string(),
71 )
72 })?
73 .to_string();
74
75 let api_version = spec
76 .options
77 .get("api_version")
78 .and_then(|v| v.as_str())
79 .unwrap_or("2024-10-21")
80 .to_string();
81
82 Ok(Self {
83 api_key,
84 resource_name,
85 api_version,
86 })
87 }
88
89 fn embed_url(&self, deployment: &str) -> String {
90 format!(
91 "https://{}.openai.azure.com/openai/deployments/{}/embeddings?api-version={}",
92 self.resource_name, deployment, self.api_version
93 )
94 }
95
96 fn chat_url(&self, deployment: &str) -> String {
97 format!(
98 "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}",
99 self.resource_name, deployment, self.api_version
100 )
101 }
102}
103
104#[async_trait]
105impl ModelProvider for RemoteAzureOpenAIProvider {
106 fn provider_id(&self) -> &'static str {
107 "remote/azure-openai"
108 }
109
110 fn capabilities(&self) -> ProviderCapabilities {
111 ProviderCapabilities {
112 supported_tasks: vec![ModelTask::Embed, ModelTask::Generate],
113 }
114 }
115
116 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
117 let cb = self.base.circuit_breaker_for(spec);
118 let resolved = AzureResolvedOptions::from_spec(spec)?;
119
120 match spec.task {
121 ModelTask::Embed => {
122 let model = AzureOpenAIEmbeddingModel {
123 client: self.base.client.clone(),
124 cb: cb.clone(),
125 deployment: spec.model_id.clone(),
126 options: resolved,
127 };
128 let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
129 Ok(Arc::new(handle) as LoadedModelHandle)
130 }
131 ModelTask::Generate => {
132 let model = AzureOpenAIGeneratorModel {
133 client: self.base.client.clone(),
134 cb,
135 deployment: spec.model_id.clone(),
136 options: resolved,
137 };
138 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
139 Ok(Arc::new(handle) as LoadedModelHandle)
140 }
141 _ => Err(RuntimeError::CapabilityMismatch(format!(
142 "Azure OpenAI provider does not support task {:?}",
143 spec.task
144 ))),
145 }
146 }
147
148 async fn health(&self) -> ProviderHealth {
149 ProviderHealth::Healthy
150 }
151}
152
/// Calls the Azure OpenAI embeddings endpoint for one deployment, guarded
/// by a circuit breaker shared across aliases with the same runtime key.
struct AzureOpenAIEmbeddingModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    // Azure deployment name; it appears in the request path, not the body.
    deployment: String,
    options: AzureResolvedOptions,
}
159
160#[async_trait]
161impl EmbeddingModel for AzureOpenAIEmbeddingModel {
162 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
163 let texts: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
164
165 self.cb
166 .call(move || async move {
167 let url = self.options.embed_url(&self.deployment);
168
169 let response = self
170 .client
171 .post(&url)
172 .header("api-key", &self.options.api_key)
173 .json(&json!({
174 "input": texts
175 }))
176 .send()
177 .await
178 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
179
180 let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
181 .json()
182 .await
183 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
184
185 let mut embeddings = Vec::new();
186 if let Some(data) = body.get("data").and_then(|d| d.as_array()) {
187 for item in data {
188 if let Some(embedding) = item.get("embedding").and_then(|e| e.as_array()) {
189 let vec: Vec<f32> = embedding
190 .iter()
191 .filter_map(|v| v.as_f64().map(|f| f as f32))
192 .collect();
193 embeddings.push(vec);
194 }
195 }
196 }
197 Ok(embeddings)
198 })
199 .await
200 }
201
202 fn dimensions(&self) -> u32 {
203 1536
206 }
207
208 fn model_id(&self) -> &str {
209 &self.deployment
210 }
211}
212
/// Calls the Azure OpenAI chat-completions endpoint for one deployment,
/// guarded by a circuit breaker shared across aliases with the same
/// runtime key.
struct AzureOpenAIGeneratorModel {
    client: Client,
    cb: crate::reliability::CircuitBreakerWrapper,
    // Azure deployment name; it appears in the request path, not the body.
    deployment: String,
    options: AzureResolvedOptions,
}
219
220#[async_trait]
221impl GeneratorModel for AzureOpenAIGeneratorModel {
222 async fn generate(
223 &self,
224 messages: &[String],
225 options: GenerationOptions,
226 ) -> Result<GenerationResult> {
227 let messages: Vec<serde_json::Value> = messages
228 .iter()
229 .enumerate()
230 .map(|(i, content)| {
231 let role = if i % 2 == 0 { "user" } else { "assistant" };
232 json!({ "role": role, "content": content })
233 })
234 .collect();
235
236 self.cb
237 .call(move || async move {
238 let url = self.options.chat_url(&self.deployment);
239
240 let mut body = json!({
241 "messages": messages,
242 });
243
244 if let Some(max_tokens) = options.max_tokens {
245 body["max_tokens"] = json!(max_tokens);
246 }
247 if let Some(temperature) = options.temperature {
248 body["temperature"] = json!(temperature);
249 }
250 if let Some(top_p) = options.top_p {
251 body["top_p"] = json!(top_p);
252 }
253
254 let response = self
255 .client
256 .post(&url)
257 .header("api-key", &self.options.api_key)
258 .json(&body)
259 .send()
260 .await
261 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
262
263 let body: serde_json::Value = check_http_status("Azure OpenAI", response)?
264 .json()
265 .await
266 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
267
268 let text = body["choices"][0]["message"]["content"]
269 .as_str()
270 .unwrap_or("")
271 .to_string();
272
273 let usage = body.get("usage").map(|u| TokenUsage {
274 prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0) as usize,
275 completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0) as usize,
276 total_tokens: u["total_tokens"].as_u64().unwrap_or(0) as usize,
277 });
278
279 Ok(GenerationResult { text, usage })
280 })
281 .await
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    // Serializes tests that mutate AZURE_OPENAI_API_KEY: the process
    // environment is global state, and tokio may run tests concurrently.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    // Builds a minimal alias spec targeting this provider with the given
    // task, deployment (model_id), and raw options JSON.
    fn spec_with_opts(
        alias: &str,
        task: ModelTask,
        model_id: &str,
        options: serde_json::Value,
    ) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/azure-openai".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options,
        }
    }

    // Smallest option set accepted by from_spec; the API key comes from the
    // environment variable set by each test.
    fn default_opts() -> serde_json::Value {
        json!({ "resource_name": "my-resource" })
    }

    // Two aliases resolving to the same runtime key must share one breaker.
    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s1 = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let s2 = spec_with_opts(
            "embed/b",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Different aliases, same deployment/task -> one shared breaker.
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // An entry older than BREAKER_TTL must be evicted by cleanup, while a
    // fresh entry survives.
    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let stale = spec_with_opts(
            "embed/stale",
            ModelTask::Embed,
            "text-embedding-ada-002",
            default_opts(),
        );
        let fresh = spec_with_opts("chat/fresh", ModelTask::Generate, "gpt-4o", default_opts());
        // Seed one entry just past the TTL and one well within it.
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        // Cleanup runs on the next load once forced.
        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        // Only the stale entry should have been evicted.
        assert_eq!(provider.breaker_count(), 1);

        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Null options -> from_spec must fail naming 'resource_name'.
    #[tokio::test]
    async fn load_fails_without_resource_name() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "embed/a",
            ModelTask::Embed,
            "text-embedding-ada-002",
            serde_json::Value::Null,
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("resource_name"));

        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Rerank is not in this provider's capabilities; load must reject it.
    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::set_var("AZURE_OPENAI_API_KEY", "test-key") };

        let provider = RemoteAzureOpenAIProvider::new();
        let s = spec_with_opts(
            "rerank/a",
            ModelTask::Rerank,
            "text-embedding-ada-002",
            default_opts(),
        );
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );

        // SAFETY: env mutation is serialized by ENV_LOCK.
        unsafe { std::env::remove_var("AZURE_OPENAI_API_KEY") };
    }

    // Pure URL formatting: resource name, deployment, and api-version must
    // land in the expected positions for both endpoints.
    #[test]
    fn azure_url_construction() {
        let opts = AzureResolvedOptions {
            api_key: "key".to_string(),
            resource_name: "my-resource".to_string(),
            api_version: "2024-10-21".to_string(),
        };

        assert_eq!(
            opts.embed_url("text-embedding-ada-002"),
            "https://my-resource.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-10-21"
        );

        assert_eq!(
            opts.chat_url("gpt-4o"),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-10-21"
        );
    }
}