1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::provider::remote_common::{RemoteProviderBase, check_http_status, resolve_api_key};
4use crate::traits::{
5 GenerationOptions, GenerationResult, GeneratorModel, LoadedModelHandle, ModelProvider,
6 ProviderCapabilities, ProviderHealth, TokenUsage,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::json;
11use std::sync::Arc;
12
/// Model provider for Anthropic's hosted Messages API (provider id `remote/anthropic`).
///
/// Only [`ModelTask::Generate`] is supported. The HTTP client and the per-model
/// circuit breakers handed to loaded models come from the shared
/// [`RemoteProviderBase`].
pub struct RemoteAnthropicProvider {
    /// Shared remote-provider state: HTTP client plus the circuit-breaker cache.
    base: RemoteProviderBase,
}
21
22impl Default for RemoteAnthropicProvider {
23 fn default() -> Self {
24 Self {
25 base: RemoteProviderBase::new(),
26 }
27 }
28}
29
30impl RemoteAnthropicProvider {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 #[cfg(test)]
36 fn insert_test_breaker(&self, key: crate::api::ModelRuntimeKey, age: std::time::Duration) {
37 self.base.insert_test_breaker(key, age);
38 }
39
40 #[cfg(test)]
41 fn breaker_count(&self) -> usize {
42 self.base.breaker_count()
43 }
44
45 #[cfg(test)]
46 fn force_cleanup_now_for_test(&self) {
47 self.base.force_cleanup_now_for_test();
48 }
49}
50
51#[async_trait]
52impl ModelProvider for RemoteAnthropicProvider {
53 fn provider_id(&self) -> &'static str {
54 "remote/anthropic"
55 }
56
57 fn capabilities(&self) -> ProviderCapabilities {
58 ProviderCapabilities {
59 supported_tasks: vec![ModelTask::Generate],
60 }
61 }
62
63 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
64 let cb = self.base.circuit_breaker_for(spec);
65 let api_key = resolve_api_key(&spec.options, "api_key_env", "ANTHROPIC_API_KEY")?;
66
67 let anthropic_version = spec
68 .options
69 .get("anthropic_version")
70 .and_then(|v| v.as_str())
71 .unwrap_or("2023-06-01")
72 .to_string();
73
74 match spec.task {
75 ModelTask::Generate => {
76 let model = AnthropicGeneratorModel {
77 client: self.base.client.clone(),
78 cb,
79 model_id: spec.model_id.clone(),
80 api_key,
81 anthropic_version,
82 };
83 let handle: Arc<dyn GeneratorModel> = Arc::new(model);
84 Ok(Arc::new(handle) as LoadedModelHandle)
85 }
86 _ => Err(RuntimeError::CapabilityMismatch(format!(
87 "Anthropic provider does not support task {:?}",
88 spec.task
89 ))),
90 }
91 }
92
93 async fn health(&self) -> ProviderHealth {
94 ProviderHealth::Healthy
95 }
96}
97
/// A loaded Anthropic generator bound to a single model id.
///
/// Every request goes through `cb` and is authenticated with `api_key`.
struct AnthropicGeneratorModel {
    /// HTTP client cloned from the provider's shared base.
    client: Client,
    /// Circuit breaker obtained for this model's spec; wraps each API call.
    cb: crate::reliability::CircuitBreakerWrapper,
    /// Model identifier sent as the `model` field of each request body.
    model_id: String,
    /// API key sent in the `x-api-key` header.
    api_key: String,
    /// Value sent in the `anthropic-version` header.
    anthropic_version: String,
}
105
106fn build_anthropic_payload(
107 model_id: &str,
108 messages: &[serde_json::Value],
109 options: &GenerationOptions,
110) -> serde_json::Value {
111 let max_tokens = options.max_tokens.unwrap_or(1024);
112
113 let mut body = json!({
114 "model": model_id,
115 "max_tokens": max_tokens,
116 "messages": messages,
117 });
118
119 if let Some(temperature) = options.temperature {
120 body["temperature"] = json!(temperature);
121 }
122 if let Some(top_p) = options.top_p {
123 body["top_p"] = json!(top_p);
124 }
125
126 body
127}
128
129#[async_trait]
130impl GeneratorModel for AnthropicGeneratorModel {
131 async fn generate(
132 &self,
133 messages: &[String],
134 options: GenerationOptions,
135 ) -> Result<GenerationResult> {
136 let messages: Vec<serde_json::Value> = messages
137 .iter()
138 .enumerate()
139 .map(|(i, content)| {
140 let role = if i % 2 == 0 { "user" } else { "assistant" };
141 json!({ "role": role, "content": content })
142 })
143 .collect();
144
145 self.cb
146 .call(move || async move {
147 let body = build_anthropic_payload(&self.model_id, &messages, &options);
148
149 let response = self
150 .client
151 .post("https://api.anthropic.com/v1/messages")
152 .header("x-api-key", &self.api_key)
153 .header("anthropic-version", &self.anthropic_version)
154 .header("content-type", "application/json")
155 .json(&body)
156 .send()
157 .await
158 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
159
160 let body: serde_json::Value = check_http_status("Anthropic", response)?
161 .json()
162 .await
163 .map_err(|e| RuntimeError::ApiError(e.to_string()))?;
164
165 let text = body
166 .get("content")
167 .and_then(|c| c.as_array())
168 .and_then(|arr| arr.first())
169 .and_then(|item| item.get("text"))
170 .and_then(|t| t.as_str())
171 .unwrap_or("")
172 .to_string();
173
174 let usage = body.get("usage").map(|u| TokenUsage {
175 prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0) as usize,
176 completion_tokens: u["output_tokens"].as_u64().unwrap_or(0) as usize,
177 total_tokens: (u["input_tokens"].as_u64().unwrap_or(0)
178 + u["output_tokens"].as_u64().unwrap_or(0))
179 as usize,
180 });
181
182 Ok(GenerationResult { text, usage })
183 })
184 .await
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ModelRuntimeKey;
    use crate::provider::remote_common::RemoteProviderBase;
    use crate::traits::ModelProvider;
    use std::time::Duration;

    /// Serializes the tests in this module that mutate process-wide env vars.
    static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

    /// Sets `ANTHROPIC_API_KEY` for the guard's lifetime and removes it on
    /// drop. Cleanup therefore also happens when an assertion panics.
    ///
    /// Bind it to a named local (e.g. `_env`), not `_`, and declare it after
    /// the `ENV_LOCK` guard. Locals drop in reverse order, so the variable is
    /// removed while the lock is still held.
    struct ApiKeyEnvGuard;

    impl ApiKeyEnvGuard {
        fn set(value: &str) -> Self {
            // SAFETY: callers hold `ENV_LOCK`, so no other test in this module
            // mutates the environment concurrently.
            unsafe { std::env::set_var("ANTHROPIC_API_KEY", value) };
            Self
        }
    }

    impl Drop for ApiKeyEnvGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is dropped before the `ENV_LOCK` guard (see the
            // struct docs), so the lock is still held here.
            unsafe { std::env::remove_var("ANTHROPIC_API_KEY") };
        }
    }

    fn spec(alias: &str, task: ModelTask, model_id: &str) -> ModelAliasSpec {
        ModelAliasSpec {
            alias: alias.to_string(),
            task,
            provider_id: "remote/anthropic".to_string(),
            model_id: model_id.to_string(),
            revision: None,
            warmup: crate::api::WarmupPolicy::Lazy,
            required: false,
            timeout: None,
            load_timeout: None,
            retry: None,
            options: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn breaker_reused_for_same_runtime_key() {
        let _lock = ENV_LOCK.lock().await;
        let _env = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s1 = spec("gen/a", ModelTask::Generate, "claude-sonnet-4-5-20250929");
        let s2 = spec("gen/b", ModelTask::Generate, "claude-sonnet-4-5-20250929");

        let _ = provider.load(&s1).await.unwrap();
        let _ = provider.load(&s2).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn breaker_cleanup_evicts_stale_entries() {
        let _lock = ENV_LOCK.lock().await;
        let _env = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let stale = spec(
            "gen/stale",
            ModelTask::Generate,
            "claude-sonnet-4-5-20250929",
        );
        let fresh = spec(
            "gen/fresh",
            ModelTask::Generate,
            "claude-haiku-3-5-20241022",
        );
        provider.insert_test_breaker(
            ModelRuntimeKey::new(&stale),
            RemoteProviderBase::BREAKER_TTL + Duration::from_secs(5),
        );
        provider.insert_test_breaker(ModelRuntimeKey::new(&fresh), Duration::from_secs(1));
        assert_eq!(provider.breaker_count(), 2);

        provider.force_cleanup_now_for_test();
        let _ = provider.load(&fresh).await.unwrap();

        assert_eq!(provider.breaker_count(), 1);
    }

    #[tokio::test]
    async fn embed_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _env = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s = spec("embed/a", ModelTask::Embed, "claude-sonnet-4-5-20250929");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }

    #[tokio::test]
    async fn rerank_capability_mismatch() {
        let _lock = ENV_LOCK.lock().await;
        let _env = ApiKeyEnvGuard::set("test-key");

        let provider = RemoteAnthropicProvider::new();
        let s = spec("rerank/a", ModelTask::Rerank, "claude-sonnet-4-5-20250929");
        let result = provider.load(&s).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("does not support task")
        );
    }

    #[test]
    fn payload_defaults_max_tokens_to_1024() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions::default(),
        );
        assert_eq!(payload["max_tokens"], 1024);
    }

    #[test]
    fn payload_uses_explicit_max_tokens() {
        let messages = vec![json!({"role": "user", "content": "hello"})];
        let payload = build_anthropic_payload(
            "claude-sonnet-4-5-20250929",
            &messages,
            &GenerationOptions {
                max_tokens: Some(512),
                temperature: None,
                top_p: None,
            },
        );
        assert_eq!(payload["max_tokens"], 512);
    }
}