uni_xervo/provider/
anthropic.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5    GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle, Message, MessageRole,
6    ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Remote provider that calls the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages)
/// for text generation. Does not support embedding or reranking.
///
/// Requires the `ANTHROPIC_API_KEY` environment variable (or a custom env var
/// name via the `api_key_env` option).
///
/// The optional `anthropic_version` string option overrides the value sent in
/// the `anthropic-version` request header; when absent, `2023-06-01` is used.
pub struct RemoteAnthropicProvider {
    /// Shared remote-provider state: the source of the HTTP client and of the
    /// circuit breaker handed to each loaded model.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAnthropicProvider {
23    fn default() -> Self {
24        Self {
25            base: RemoteProviderBase::new(),
26        }
27    }
28}
29
30impl RemoteAnthropicProvider {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    #[cfg(test)]
36    fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37        self.base.insert_test_breaker(key, age);
38    }
39
40    #[cfg(test)]
41    fn breaker_count(&self) -> usize {
42        self.base.breaker_count()
43    }
44
45    #[cfg(test)]
46    fn force_cleanup_now_for_test(&self) {
47        self.base.force_cleanup_now_for_test();
48    }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteAnthropicProvider {
53    fn provider_id(&self) -> &'static str {
54        "remote/anthropic"
55    }
56
57    fn capabilities(&self) -> ProviderCapabilities {
58        ProviderCapabilities {
59            supported_tasks: vec![ModelTask::Generate],
60        }
61    }
62
63    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64        let cb = self.base.circuit_breaker_for(spec);
65        let api_key = resolve_api_key(&spec.options, "api_key_env", "ANTHROPIC_API_KEY")?;
66
67        let anthropic_version = spec
68            .options
69            .get("anthropic_version")
70            .and_then(|v| v.as_str())
71            .unwrap_or("2023-06-01")
72            .to_string();
73
74        match spec.task {
75            ModelTask::Generate => {
76                let model = AnthropicGeneratorModel {
77                    client: self.base.client.clone(),
78                    cb,
79                    model_id: spec.model_id.clone(),
80                    api_key,
81                    anthropic_version,
82                };
83                let handle: Arc<dyn GeneratorModel> = Arc::new(model);
84                Ok(Arc::new(handle) as LoadedModelHandle)
85            }
86            _ => Err(RuntimeError::CapabilityMismatch(format!(
87                "Anthropic provider does not support task {:?}",
88                spec.task
89            ))),
90        }
91    }
92
93    async fn health(&self) -> ProviderHealth {
94        ProviderHealth::Healthy
95    }
96}
97
/// Generator handle produced by `RemoteAnthropicProvider::load`; each
/// `generate` call issues one POST to the Anthropic Messages endpoint.
struct AnthropicGeneratorModel {
    /// HTTP client, cloned from the provider's shared base at load time.
    client: Client,
    /// Circuit breaker through which every `generate` request is executed.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Model identifier sent as the `model` field of the request body.
    model_id: String,
    /// Secret sent in the `x-api-key` request header.
    api_key: String,
    /// Value sent in the `anthropic-version` request header.
    anthropic_version: String,
}
105
106fn build_anthropic_payload(
107    model_id: &str,
108    messages: &[serde_json::Value],
109    options: &GenerationOptions,
110    system: Option<&str>,
111) -> serde_json::Value {
112    let max_tokens = options.max_tokens.unwrap_or(1024);
113
114    let mut body = json!({
115        "model": model_id,
116        "max_tokens": max_tokens,
117        "messages": messages,
118    });
119
120    if let Some(system_text) = system {
121        body["system"] = json!(system_text);
122    }
123    if let Some(temperature) = options.temperature {
124        body["temperature"] = json!(temperature);
125    }
126    if let Some(top_p) = options.top_p {
127        body["top_p"] = json!(top_p);
128    }
129
130    body
131}
132
133#[async_trait]
134impl GeneratorModel for AnthropicGeneratorModel {
135    async fn generate(
136        &self,
137        messages: &[Message],
138        options: GenerationOptions,
139    ) -> Result<GenerationResult> {
140        // Extract system messages into a single combined string
141        let system_parts: Vec<String> = messages
142            .iter()
143            .filter(|m| m.role == MessageRole::System)
144            .map(|m| m.text())
145            .collect();
146        let system_text = if system_parts.is_empty() {
147            None
148        } else {
149            Some(system_parts.join("\n"))
150        };
151
152        let messages: Vec<serde_json::Value> = messages
153            .iter()
154            .filter(|msg| msg.role != MessageRole::System)
155            .map(|msg| {
156                let role = match msg.role {
157                    MessageRole::User => "user",
158                    MessageRole::Assistant => "assistant",
159                    MessageRole::System => unreachable!("system messages filtered above"),
160                };
161                json!({ "role": role, "content": msg.text() })
162            })
163            .collect();
164
165        self.cb
166            .call(move || async move {
167                let body = build_anthropic_payload(
168                    &self.model_id,
169                    &messages,
170                    &options,
171                    system_text.as_deref(),
172                );
173
174                let response = self
175                    .client
176                    .post("https://api.anthropic.com/v1/messages")
177                    .header("x-api-key", &self.api_key)
178                    .header("anthropic-version", &self.anthropic_version)
179                    .header("content-type", "application/json")
180                    .json(&body)
181                    .send()
182                    .await
183                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
184
185                let body: serde_json::Value = check_http_status("Anthropic", response)?
186                    .json()
187                    .await
188                    .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
189
190                let text = body
191                    .get("content")
192                    .and_then(|c| c.as_array())
193                    .and_then(|arr| arr.first())
194                    .and_then(|item| item.get("text"))
195                    .and_then(|t| t.as_str())
196                    .unwrap_or("")
197                    .to_string();
198
199                let usage = body.get("usage").map(|u| TokenUsage {
200                    prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0) as usize,
201                    completion_tokens: u["output_tokens"].as_u64().unwrap_or(0) as usize,
202                    total_tokens: (u["input_tokens"].as_u64().unwrap_or(0)
203                        + u["output_tokens"].as_u64().unwrap_or(0))
204                        as usize,
205                });
206
207                Ok(GenerationResult {
208                    text,
209                    usage,
210                    images: vec![],
211                    audio: None,
212                })
213            })
214            .await
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module that mutate the process environment.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Environment variable the provider reads its API key from by default.
    const API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
    /// Model id shared by most tests; no request is ever sent with it.
    const SONNET: &str = "claude-sonnet-4-5-20250929";

    /// RAII guard that holds `ENV_LOCK` and keeps `API_KEY_VAR` set.
    ///
    /// The variable is removed in `Drop`, so cleanup also happens when a test
    /// assertion panics, and it happens before the lock is released.
    struct ApiKeyEnv {
        _serialized: tokio::sync::MutexGuard<'static, ()>,
    }

    impl ApiKeyEnv {
        /// Acquires `ENV_LOCK`, then sets `API_KEY_VAR` to a dummy value.
        async fn install() -> Self {
            let serialized = ENV_LOCK.lock().await;
            // SAFETY: tests in this module only mutate the environment while
            // holding `ENV_LOCK`, so these writes never race with each other.
            // NOTE(review): tests elsewhere in the same binary that touch the
            // environment concurrently are not covered by this lock.
            unsafe { std::env::set_var(API_KEY_VAR, "test-key") };
            Self {
                _serialized: serialized,
            }
        }
    }

    impl Drop for ApiKeyEnv {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here; the guard field is only
            // dropped after this body returns. Same caveat as in `install`.
            unsafe { std::env::remove_var(API_KEY_VAR) };
        }
    }

    /// Builds a minimal alias spec targeting this provider.
    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/anthropic".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    /// A one-message conversation used by the payload tests.
    fn user_hello() -> Vec<serde_json::Value> {
        vec![json!({"role": "user", "content": "hello"})]
    }

    /// Asserts that loading a spec with `task` fails with the capability
    /// mismatch message.
    async fn assert_task_rejected(alias: &str, task: ModelTask) {
        let _env = ApiKeyEnv::install().await;

        let provider = RemoteAnthropicProvider::new();
        let s = spec(alias, task, SONNET);
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _env = ApiKeyEnv::install().await;

        let provider = RemoteAnthropicProvider::new();
        let s1 = spec("gen/a", ModelTask::Generate, SONNET);
        let s2 = spec("gen/b", ModelTask::Generate, SONNET);

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        // Two aliases with the same model must share a single breaker.
        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _env = ApiKeyEnv::install().await;

        let provider = RemoteAnthropicProvider::new();
        let stale = spec("gen/stale", ModelTask::Generate, SONNET);
        let fresh = spec(
            "gen/fresh",
            ModelTask::Generate,
            "claude-haiku-3-5-20241022",
        );
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        // Only the entry younger than the TTL should remain.
        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn embed_capability_mismatch() {
        assert_task_rejected("embed/a", ModelTask::Embed).await;
    }

    #[tokio::test]
    async fn rerank_capability_mismatch() {
        assert_task_rejected("rerank/a", ModelTask::Rerank).await;
    }

    #[test]
    fn payload_defaults_max_tokens_to_1024() {
        let payload =
            build_anthropic_payload(SONNET, &user_hello(), &GenerationOptions::default(), None);
        assert_eq!(payload["max_tokens"], 1024);
    }

    #[test]
    fn payload_uses_explicit_max_tokens() {
        let options = GenerationOptions {
            max_tokens: Some(512),
            ..Default::default()
        };
        let payload = build_anthropic_payload(SONNET, &user_hello(), &options, None);
        assert_eq!(payload["max_tokens"], 512);
    }

    #[test]
    fn payload_includes_system_field() {
        let payload = build_anthropic_payload(
            SONNET,
            &user_hello(),
            &GenerationOptions::default(),
            Some("you are helpful"),
        );
        assert_eq!(payload["system"], "you are helpful");
    }

    #[test]
    fn payload_omits_system_field_when_none() {
        let payload =
            build_anthropic_payload(SONNET, &user_hello(), &GenerationOptions::default(), None);
        assert!(payload.get("system").is_none());
    }
}