use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::traits::{
    EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
};
use async_trait::async_trait;
use candle_core::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
use candle_transformers::models::gemma::{Config as GemmaConfig, Model as GemmaModel};
use candle_transformers::models::jina_bert::{
    BertModel as JinaBertModel, Config as JinaBertConfig,
};
use hf_hub::{
    Repo, RepoType,
    api::tokio::{Api, ApiBuilder},
};
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tokio::sync::Mutex;

#[derive(Deserialize, Debug)]
struct BaseConfig {
    architectures: Option<Vec<String>>,
}

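/// Model architecture families this provider can instantiate, detected from
/// the `architectures` field of a repo's `config.json`.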
#[derive(Debug, Clone, Copy, PartialEq)]
enum ModelArchitecture {
    Bert,
    JinaBert,
    Gemma,
}

impl ModelArchitecture {
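    /// Map the first entry of `architectures` to a known family. A missing or
    /// empty `architectures` list falls back to plain BERT.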
    fn from_config(config: &BaseConfig) -> Result<Self> {
        if let Some(archs) = &config.architectures
            && let Some(arch) = archs.first()
        {
            return match arch.as_str() {
                "BertModel" | "BertForMaskedLM" => Ok(Self::Bert),
                "JinaBertModel" | "JinaBertForMaskedLM" => Ok(Self::JinaBert),
                "GemmaModel" | "GemmaForCausalLM" => Ok(Self::Gemma),
                _ => Err(RuntimeError::Config(format!(
                    "Unsupported architecture: {}",
                    arch
                ))),
            };
        }
        Ok(Self::Bert)
    }
}

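/// Local inference provider backed by [candle]. Currently embedding-only and
/// CPU-only; model artifacts are pulled from the Hugging Face Hub on first use.
///
/// [candle]: https://github.com/huggingface/candle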
#[derive(Default)]
pub struct LocalCandleProvider;

impl LocalCandleProvider {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait]
impl ModelProvider for LocalCandleProvider {
    fn provider_id(&self) -> &'static str {
        "local/candle"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
        if spec.task != ModelTask::Embed {
            return Err(RuntimeError::CapabilityMismatch(format!(
                "Candle provider does not support task {:?}",
                spec.task
            )));
        }

        let model_type = CandleTextModel::from_name(&spec.model_id).ok_or_else(|| {
            RuntimeError::Config(format!("Unsupported Candle model: {}", spec.model_id))
        })?;

        let cache_dir =
            crate::cache::resolve_cache_dir("candle", model_type.model_id(), &spec.options);

        tracing::info!(model = ?model_type, "Initializing Candle model");
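        // Construction is cheap: downloading the tokenizer and weights is
        // deferred until the first embed()/warmup() call on the handle.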
        let model = CandleEmbeddingModel::new(model_type, spec.revision.clone(), cache_dir);

        let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
        Ok(Arc::new(handle) as LoadedModelHandle)
    }

    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }

    async fn warmup(&self) -> Result<()> {
        tracing::info!("Warming up LocalCandleProvider");
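        // Constructing the hub client up front surfaces cache-dir/environment
        // problems at warmup time instead of on the first model load.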
        let _ = Api::new().map_err(|e| RuntimeError::Load(e.to_string()))?;
        Ok(())
    }
}

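/// The fixed catalogue of embedding models this provider knows how to load,
/// with their Hugging Face repo ids and output dimensions.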
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CandleTextModel {
    #[default]
    AllMiniLmL6V2,
    BgeSmallEnV15,
    BgeBaseEnV15,
}

impl CandleTextModel {
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
        }
    }

    pub fn dimensions(&self) -> u32 {
        match self {
            Self::AllMiniLmL6V2 | Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 => 768,
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "all-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
        }
    }

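    /// Resolve a user-facing alias (case-insensitive) to a catalogue entry.
    /// Accepts short names, squashed names, full repo ids, and `"default"`;
    /// returns `None` for anything unrecognized. For example (illustrative,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     CandleTextModel::from_name("BAAI/bge-small-en-v1.5"),
    ///     Some(CandleTextModel::BgeSmallEnV15)
    /// );
    /// assert_eq!(CandleTextModel::from_name("unknown-model"), None);
    /// ```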
    pub fn from_name(name: &str) -> Option<Self> {
        match name.to_lowercase().as_str() {
            "all-minilm-l6-v2" | "allminilml6v2" | "default" => Some(Self::AllMiniLmL6V2),
            "bge-small-en-v1.5" | "bgesmallenv15" => Some(Self::BgeSmallEnV15),
            "bge-base-en-v1.5" | "bgebaseenv15" => Some(Self::BgeBaseEnV15),
            "sentence-transformers/all-minilm-l6-v2" => Some(Self::AllMiniLmL6V2),
            "baai/bge-small-en-v1.5" => Some(Self::BgeSmallEnV15),
            "baai/bge-base-en-v1.5" => Some(Self::BgeBaseEnV15),
            _ => None,
        }
    }
}

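/// Architecture-specific backbone, so a single loaded-state struct can hold
/// any supported model family.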
enum InnerModel {
    Bert(BertModel),
    JinaBert(JinaBertModel),
    Gemma(GemmaModel),
}

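/// Everything needed for inference once `ensure_loaded` has run.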
struct LoadedModel {
    model: InnerModel,
    tokenizer: Tokenizer,
    device: Device,
}

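/// A lazily-initialized Candle embedding model. The heavy artifacts
/// (tokenizer, safetensors weights) are fetched and deserialized on first
/// use, guarded by a tokio `Mutex` so concurrent callers load exactly once.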
pub struct CandleEmbeddingModel {
    model_type: CandleTextModel,
    revision: Option<String>,
    cache_dir: PathBuf,
    state: Arc<Mutex<Option<LoadedModel>>>,
}

impl CandleEmbeddingModel {
    pub fn new(model_type: CandleTextModel, revision: Option<String>, cache_dir: PathBuf) -> Self {
        Self {
            model_type,
            revision,
            cache_dir,
            state: Arc::new(Mutex::new(None)),
        }
    }

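    /// Idempotent lazy initialization: download `config.json`, the tokenizer,
    /// and the safetensors weights, detect the architecture, and build the
    /// model. Holding the state lock across the whole load keeps concurrent
    /// callers from downloading twice.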
    async fn ensure_loaded(&self) -> Result<()> {
        let mut state = self.state.lock().await;
        if state.is_some() {
            return Ok(());
        }

        tracing::info!(
            model = self.model_type.name(),
            "Loading Candle embedding model"
        );

        let api = ApiBuilder::new()
            .with_cache_dir(self.cache_dir.clone())
            .build()
            .map_err(|e| RuntimeError::Load(e.to_string()))?;
        let repo = match &self.revision {
            Some(rev) => Repo::with_revision(
                self.model_type.model_id().to_string(),
                RepoType::Model,
                rev.clone(),
            ),
            None => Repo::model(self.model_type.model_id().to_string()),
        };
        let api_repo = api.repo(repo);

        let config_path = api_repo
            .get("config.json")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let config_contents =
            std::fs::read_to_string(&config_path).map_err(|e| RuntimeError::Load(e.to_string()))?;

        let base_config: BaseConfig = serde_json::from_str(&config_contents)
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let arch = ModelArchitecture::from_config(&base_config)?;
        tracing::info!(architecture = ?arch, "Detected model architecture");

        let tokenizer_path = api_repo
            .get("tokenizer.json")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;
        let weights_path = api_repo
            .get("model.safetensors")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| RuntimeError::Load(format!("Failed to load tokenizer: {}", e)))?;

        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| RuntimeError::Load(format!("Failed to set truncation: {}", e)))?;

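        // CPU-only for now. SAFETY: the weights file is memory-mapped
        // read-only; it must not be truncated or modified while the
        // VarBuilder (and the model built from it) are alive.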
        let device = Device::Cpu;
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)
                .map_err(|e| RuntimeError::Load(e.to_string()))?
        };

        let model = match arch {
            ModelArchitecture::Bert => {
                let config: BertConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model =
                    BertModel::load(vb, &config).map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::Bert(model)
            }
            ModelArchitecture::JinaBert => {
                let config: JinaBertConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model = JinaBertModel::new(vb, &config)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::JinaBert(model)
            }
            ModelArchitecture::Gemma => {
                let config: GemmaConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model = GemmaModel::new(false, &config, vb)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::Gemma(model)
            }
        };

        tracing::info!(
            model = self.model_type.name(),
            dimensions = self.model_type.dimensions(),
            "Candle embedding model loaded"
        );

        *state = Some(LoadedModel {
            model,
            tokenizer,
            device,
        });

        Ok(())
    }
}

#[async_trait]
impl EmbeddingModel for CandleEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        self.ensure_loaded().await?;

        let state_guard = self.state.lock().await;
        let loaded = state_guard
            .as_ref()
            .ok_or_else(|| RuntimeError::Load("Model state missing".to_string()))?;

        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = loaded
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| RuntimeError::InferenceError(format!("Tokenization failed: {}", e)))?;

        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_token_type_ids = Vec::new();

        for encoding in &encodings {
            all_input_ids.push(
                encoding
                    .get_ids()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
            all_attention_masks.push(
                encoding
                    .get_attention_mask()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
            all_token_type_ids.push(
                encoding
                    .get_type_ids()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
        }

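        // PaddingStrategy::BatchLongest guarantees every encoding in the
        // batch has the same length, so the first row determines seq_len.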
        let batch_size = texts.len();
        let seq_len = all_input_ids[0].len();

        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<i64> = all_token_type_ids.into_iter().flatten().collect();

        let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), &loaded.device)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let attention_mask =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &loaded.device)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let token_type_ids =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &loaded.device)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let embeddings = match &loaded.model {
            InnerModel::Bert(m) => m
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?,
            InnerModel::JinaBert(m) => m
                .forward(&input_ids)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?,
            InnerModel::Gemma(_m) => {
                // Gemma's causal-LM forward exposes logits rather than the
                // final hidden states needed for pooling, so embedding is not
                // supported yet.
                return Err(RuntimeError::InferenceError(
                    "Gemma embedding not fully implemented (requires hidden state access)"
                        .to_string(),
                ));
            }
        };

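        // Masked mean pooling over the token axis, then L2 normalization:
        //   pooled_d = sum_t(mask_t * h_{t,d}) / max(sum_t(mask_t), 1e-9)
        //   out_d    = pooled_d / max(||pooled||_2, 1e-12)
        // Padding tokens are zeroed out so they contribute nothing to the mean.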
        let attention_mask_f32 = attention_mask
            .to_dtype(DType::F32)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_expanded = attention_mask_f32
            .unsqueeze(2)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_expanded = mask_expanded
            .broadcast_as(embeddings.shape())
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let masked_embeddings = embeddings
            .mul(&mask_expanded)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mask_sum = attention_mask_f32
            .sum(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .unsqueeze(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mask_sum = mask_sum
            .broadcast_as(sum_embeddings.shape())
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_sum = mask_sum
            .clamp(1e-9, f64::MAX)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mean_embeddings = sum_embeddings
            .div(&mask_sum)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let norm = mean_embeddings
            .sqr()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .sum_keepdim(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .sqrt()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .clamp(1e-12, f64::MAX)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let normalized = mean_embeddings
            .broadcast_div(&norm)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let embeddings_vec: Vec<Vec<f32>> = normalized
            .to_vec2()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        Ok(embeddings_vec)
    }

    fn dimensions(&self) -> u32 {
        self.model_type.dimensions()
    }

    fn model_id(&self) -> &str {
        self.model_type.model_id()
    }

    async fn warmup(&self) -> Result<()> {
        self.ensure_loaded().await
    }
}
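
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal usage sketch: downloads all-MiniLM-L6-v2 from the Hugging Face
    // Hub on first run (hence `#[ignore]`), and assumes tokio's `macros` and
    // `rt` features are enabled in dev-dependencies. The cache path is
    // illustrative only.
    #[tokio::test]
    #[ignore = "downloads model weights from the Hugging Face Hub"]
    async fn embeds_and_normalizes() -> Result<()> {
        let model = CandleEmbeddingModel::new(
            CandleTextModel::AllMiniLmL6V2,
            None,
            std::env::temp_dir().join("candle-embed-test-cache"),
        );

        let embeddings = model.embed(vec!["hello world", "goodbye world"]).await?;
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);

        // Output vectors are L2-normalized, so each norm should be ~1.
        let norm: f32 = embeddings[0].iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3);
        Ok(())
    }
}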