1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle, Message, MessageRole,
6 ModelProvider, ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Model provider (`provider_id` = `"remote/anthropic"`) that serves
/// `ModelTask::Generate` aliases through Anthropic's hosted Messages API.
pub struct RemoteAnthropicProvider {
    /// Shared remote-provider state: supplies the HTTP client handed to loaded
    /// models and the circuit breakers that are looked up per model spec.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAnthropicProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
30impl RemoteAnthropicProvider {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[cfg(test)]
36 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37 self.base.insert_test_breaker(key, age);
38 }
39
40 #[cfg(test)]
41 fn breaker_count(&self) -> usize {
42 self.base.breaker_count()
43 }
44
45 #[cfg(test)]
46 fn force_cleanup_now_for_test(&self) {
47 self.base.force_cleanup_now_for_test();
48 }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteAnthropicProvider {
53 fn provider_id(&self) -> &'static str {
54 "remote/anthropic"
55 }
56
57 fn capabilities(&self) -> ProviderCapabilities {
58 ProviderCapabilities {
59 supported_tasks: vec![ModelTask::Generate],
60 }
61 }
62
63 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64 let cb = self.base.circuit_breaker_for(spec);
65 let api_key = resolve_api_key(&spec.options, "api_key_env", "ANTHROPIC_API_KEY")?;
66
67 let anthropic_version = spec
68 .options
69 .get("anthropic_version")
70 .and_then(|v| v.as_str())
71 .unwrap_or("2023-06-01")
72 .to_string();
73
74 match spec.task {
75 ModelTask::Generate => {
76 let model = AnthropicGeneratorModel {
77 client: self.base.client.clone(),
78 cb,
79 model_id: spec.model_id.clone(),
80 api_key,
81 anthropic_version,
82 };
83 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
84 Ok(Arc::new(handle) as LoadedModelHandle)
85 }
86 _ => Err(RuntimeError::CapabilityMismatch(format!(
87 "Anthropic provider does not support task {:?}",
88 spec.task
89 ))),
90 }
91 }
92
93 async fn health(&self) -> ProviderHealth {
94 ProviderHealth::Healthy
95 }
96}
97
/// A loaded Anthropic generation model: everything needed to issue
/// `POST https://api.anthropic.com/v1/messages` requests for one model id.
struct AnthropicGeneratorModel {
    /// HTTP client cloned from the provider's shared base.
    client: Client,
    /// Circuit breaker that every generation request is routed through.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Anthropic model identifier, sent as the `model` field of each request body.
    model_id: String,
    /// API key, sent in the `x-api-key` header.
    api_key: String,
    /// Value of the `anthropic-version` header (`"2023-06-01"` unless overridden at load time).
    anthropic_version: String,
}
105
106fn build_anthropic_payload(
107 model_id: &str,
108 messages: &[serde_json::Value],
109 options: &GenerationOptions,
110 system: Option<&str>,
111) -> serde_json::Value {
112 let max_tokens = options.max_tokens.unwrap_or(1024);
113
114 let mut body = json!({
115 "model": model_id,
116 "max_tokens": max_tokens,
117 "messages": messages,
118 });
119
120 if let Some(system_text) = system {
121 body["system"] = json!(system_text);
122 }
123 if let Some(temperature) = options.temperature {
124 body["temperature"] = json!(temperature);
125 }
126 if let Some(top_p) = options.top_p {
127 body["top_p"] = json!(top_p);
128 }
129
130 body
131}
132
133#[async_trait]
134impl GeneratorModel for AnthropicGeneratorModel {
135 async fn generate(
136 &self,
137 messages: &[Message],
138 options: GenerationOptions,
139 ) -> Result<GenerationResult> {
140 let system_parts: Vec<String> = messages
142 .iter()
143 .filter(|m| m.role == MessageRole::System)
144 .map(|m| m.text())
145 .collect();
146 let system_text = if system_parts.is_empty() {
147 None
148 } else {
149 Some(system_parts.join("\n"))
150 };
151
152 let messages: Vec<serde_json::Value> = messages
153 .iter()
154 .filter(|msg| msg.role != MessageRole::System)
155 .map(|msg| {
156 let role = match msg.role {
157 MessageRole::User => "user",
158 MessageRole::Assistant => "assistant",
159 MessageRole::System => unreachable!("system messages filtered above"),
160 };
161 json!({ "role": role, "content": msg.text() })
162 })
163 .collect();
164
165 self.cb
166 .call(move || async move {
167 let body = build_anthropic_payload(
168 &self.model_id,
169 &messages,
170 &options,
171 system_text.as_deref(),
172 );
173
174 let response = self
175 .client
176 .post("https://api.anthropic.com/v1/messages")
177 .header("x-api-key", &self.api_key)
178 .header("anthropic-version", &self.anthropic_version)
179 .header("content-type", "application/json")
180 .json(&body)
181 .send()
182 .await
183 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
184
185 let body: serde_json::Value = check_http_status("Anthropic", response)?
186 .json()
187 .await
188 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
189
190 let text = body
191 .get("content")
192 .and_then(|c| c.as_array())
193 .and_then(|arr| arr.first())
194 .and_then(|item| item.get("text"))
195 .and_then(|t| t.as_str())
196 .unwrap_or("")
197 .to_string();
198
199 let usage = body.get("usage").map(|u| TokenUsage {
200 prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0) as usize,
201 completion_tokens: u["output_tokens"].as_u64().unwrap_or(0) as usize,
202 total_tokens: (u["input_tokens"].as_u64().unwrap_or(0)
203 + u["output_tokens"].as_u64().unwrap_or(0))
204 as usize,
205 });
206
207 Ok(GenerationResult {
208 text,
209 usage,
210 images: vec![],
211 audio: None,
212 })
213 })
214 .await
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module that mutate `ANTHROPIC_API_KEY`.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Sets `ANTHROPIC_API_KEY` on creation and removes it on drop.
    ///
    /// Because cleanup happens in `Drop`, the variable is removed even when an assertion
    /// panics, so a failing test cannot leak the key into later tests. Bind the guard
    /// *after* acquiring `ENV_LOCK`. Locals drop in reverse order, so the variable is then
    /// removed while the lock is still held.
    struct ApiKeyEnvGuard;

    impl ApiKeyEnvGuard {
        fn set(value: &str) -> Self {
            // SAFETY: test-only. Callers hold `ENV_LOCK`, so no other test in this module
            // reads or writes the environment concurrently. NOTE(review): tests in other
            // modules of the same test binary are not covered by this lock.
            unsafe { std::env::set_var("ANTHROPIC_API_KEY", value) };
            Self
        }
    }

    impl Drop for ApiKeyEnvGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `ApiKeyEnvGuard::set`; the guard is dropped
            // before the `ENV_LOCK` guard that was acquired ahead of it.
            unsafe { std::env::remove_var("ANTHROPIC_API_KEY") };
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/anthropic".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s1 = spec("gen/a", ModelTask::Generate, "claude-sonnet-4-5-20250929");
        let s2 = spec("gen/b", ModelTask::Generate, "claude-sonnet-4-5-20250929");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let stale = spec(
            "gen/stale",
            ModelTask::Generate,
            "claude-sonnet-4-5-20250929",
        );
        let fresh = spec(
            "gen/fresh",
            ModelTask::Generate,
            "claude-haiku-3-5-20241022",
        );
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn embed_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s = spec("embed/a", ModelTask::Embed, "claude-sonnet-4-5-20250929");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }

    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _key = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s = spec("rerank/a", ModelTask::Rerank, "claude-sonnet-4-5-20250929");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }

    #[test]
    fn payload_defaults_max_tokens_to_1024() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions::default(),
            None,
        );
        assert_eq!(payload["max_tokens"], 1024);
    }

    #[test]
    fn payload_uses_explicit_max_tokens() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions {
                max_tokens: Some(512),
                ..Default::default()
            },
            None,
        );
        assert_eq!(payload["max_tokens"], 512);
    }

    #[test]
    fn payload_includes_system_field() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions::default(),
            Some("you are helpful"),
        );
        assert_eq!(payload["system"], "you are helpful");
    }

    #[test]
    fn payload_omits_system_field_when_none() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions::default(),
            None,
        );
        assert!(payload.get("system").is_none());
    }

    #[test]
    fn payload_includes_sampling_params_when_set() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions {
                temperature: Some(0.5),
                top_p: Some(0.25),
                ..Default::default()
            },
            None,
        );
        // 0.5 and 0.25 are exactly representable, so the f64 comparison is exact.
        assert_eq!(payload["temperature"], 0.5);
        assert_eq!(payload["top_p"], 0.25);
    }

    #[test]
    fn payload_omits_sampling_params_when_unset() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions::default(),
            None,
        );
        assert!(payload.get("temperature").is_none());
        assert!(payload.get("top_p").is_none());
        assert_eq!(payload["model"], "claude-sonnet-4-5-20250929");
        assert_eq!(payload["messages"], json!(messages));
    }
}