uni_xervo/provider/candle.rs

use crate::api::{ModelAliasSpec, ModelTask};
use crate::error::{Result, RuntimeError};
use crate::traits::{
    EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
};
use async_trait::async_trait;
use candle_core::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig, DTYPE};
use candle_transformers::models::gemma::{Config as GemmaConfig, Model as GemmaModel};
use candle_transformers::models::jina_bert::{
    BertModel as JinaBertModel, Config as JinaBertConfig,
};
use hf_hub::{
    Repo, RepoType,
    api::tokio::{Api, ApiBuilder},
};
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use tokio::sync::Mutex;

#[derive(Deserialize, Debug)]
struct BaseConfig {
    architectures: Option<Vec<String>>,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum ModelArchitecture {
    Bert,
    JinaBert,
    Gemma,
}

impl ModelArchitecture {
    fn from_config(config: &BaseConfig) -> Result<Self> {
        if let Some(archs) = &config.architectures
            && let Some(arch) = archs.first()
        {
            return match arch.as_str() {
                "BertModel" | "BertForMaskedLM" => Ok(Self::Bert),
                "JinaBertModel" | "JinaBertForMaskedLM" => Ok(Self::JinaBert),
                "GemmaModel" | "GemmaForCausalLM" => Ok(Self::Gemma),
                _ => Err(RuntimeError::Config(format!(
                    "Unsupported architecture: {}",
                    arch
                ))),
            };
        }
        // Default to Bert if unspecified (legacy behavior)
        Ok(Self::Bert)
    }
}
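
// Typical `config.json` fragment that `from_config` matches against (models
// omitting the `architectures` field fall back to Bert):
//
//     { "architectures": ["BertModel"], ... }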

/// Local embedding provider using the [Candle](https://github.com/huggingface/candle)
/// ML framework.
///
/// Supports Bert, JinaBert, and Gemma architectures with lazy weight loading
/// from HuggingFace Hub and mean-pooled, L2-normalized embeddings.
#[derive(Default)]
pub struct LocalCandleProvider;

impl LocalCandleProvider {
    pub fn new() -> Self {
        Self
    }
}
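
// Hypothetical usage sketch (how `ModelAliasSpec` is constructed lives in
// crate::api; the provider only reads its `model_id`, `task`, `revision`,
// and `options` fields):
//
//     let provider = LocalCandleProvider::new();
//     let handle = provider.load(&spec).await?; // requires spec.task == ModelTask::Embed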

#[async_trait]
impl ModelProvider for LocalCandleProvider {
    fn provider_id(&self) -> &'static str {
        "local/candle"
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supported_tasks: vec![ModelTask::Embed],
        }
    }

    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
        if spec.task != ModelTask::Embed {
            return Err(RuntimeError::CapabilityMismatch(format!(
                "Candle provider does not support task {:?}",
                spec.task
            )));
        }

        let model_type = CandleTextModel::from_name(&spec.model_id).ok_or_else(|| {
            RuntimeError::Config(format!("Unsupported Candle model: {}", spec.model_id))
        })?;

        let cache_dir =
            crate::cache::resolve_cache_dir("candle", model_type.model_id(), &spec.options);

        tracing::info!(model = ?model_type, "Initializing Candle model");
        let model = CandleEmbeddingModel::new(model_type, spec.revision.clone(), cache_dir);

        let handle: Arc<dyn EmbeddingModel> = Arc::new(model);
        Ok(Arc::new(handle) as LoadedModelHandle)
    }

    async fn health(&self) -> ProviderHealth {
        ProviderHealth::Healthy
    }

    async fn warmup(&self) -> Result<()> {
        tracing::info!("Warming up LocalCandleProvider");
        // Pre-build the HF Hub client so cache setup errors surface early
        let _ = Api::new().map_err(|e| RuntimeError::Load(e.to_string()))?;
        Ok(())
    }
}

/// Supported text embedding models.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CandleTextModel {
    /// all-MiniLM-L6-v2: 384 dimensions, fastest, English-optimized
    #[default]
    AllMiniLmL6V2,
    /// BGE-small-en-v1.5: 384 dimensions, high-quality English
    BgeSmallEnV15,
    /// BGE-base-en-v1.5: 768 dimensions, higher-quality English
    BgeBaseEnV15,
}

impl CandleTextModel {
    pub fn model_id(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "sentence-transformers/all-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "BAAI/bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "BAAI/bge-base-en-v1.5",
        }
    }

    pub fn dimensions(&self) -> u32 {
        match self {
            Self::AllMiniLmL6V2 | Self::BgeSmallEnV15 => 384,
            Self::BgeBaseEnV15 => 768,
        }
    }

    pub fn name(&self) -> &'static str {
        match self {
            Self::AllMiniLmL6V2 => "all-MiniLM-L6-v2",
            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
        }
    }

    pub fn from_name(name: &str) -> Option<Self> {
        match name.to_lowercase().as_str() {
            "all-minilm-l6-v2" | "allminilml6v2" | "default" => Some(Self::AllMiniLmL6V2),
            "bge-small-en-v1.5" | "bgesmallenv15" => Some(Self::BgeSmallEnV15),
            "bge-base-en-v1.5" | "bgebaseenv15" => Some(Self::BgeBaseEnV15),
            // Map known HF IDs to enum
            "sentence-transformers/all-minilm-l6-v2" => Some(Self::AllMiniLmL6V2),
            "baai/bge-small-en-v1.5" => Some(Self::BgeSmallEnV15),
            "baai/bge-base-en-v1.5" => Some(Self::BgeBaseEnV15),
            _ => None,
        }
    }
}
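
// Alias resolution is case-insensitive and accepts either the short name or
// the full HF repo id:
//
//     assert_eq!(
//         CandleTextModel::from_name("BAAI/bge-small-en-v1.5"),
//         Some(CandleTextModel::BgeSmallEnV15),
//     );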

enum InnerModel {
    Bert(BertModel),
    JinaBert(JinaBertModel),
    Gemma(GemmaModel),
}

struct LoadedModel {
    model: InnerModel,
    tokenizer: Tokenizer,
    device: Device,
}

/// A lazily-loaded embedding model backed by Candle.
///
/// On first [`embed`](crate::traits::EmbeddingModel::embed) call (or explicit
/// [`warmup`](crate::traits::EmbeddingModel::warmup)), the model weights and
/// tokenizer are downloaded from HuggingFace Hub and loaded into memory.
pub struct CandleEmbeddingModel {
    model_type: CandleTextModel,
    revision: Option<String>,
    cache_dir: PathBuf,
    state: Arc<Mutex<Option<LoadedModel>>>,
}

impl CandleEmbeddingModel {
    pub fn new(model_type: CandleTextModel, revision: Option<String>, cache_dir: PathBuf) -> Self {
        Self {
            model_type,
            revision,
            cache_dir,
            state: Arc::new(Mutex::new(None)),
        }
    }

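    /// Downloads model artifacts on first use and initializes the tokenizer
    /// and weights. The entire load runs under the state mutex, so concurrent
    /// callers block until a single load completes rather than racing to
    /// download twice.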
    async fn ensure_loaded(&self) -> Result<()> {
        let mut state = self.state.lock().await;
        if state.is_some() {
            return Ok(());
        }

        tracing::info!(
            model = self.model_type.name(),
            "Loading Candle embedding model"
        );

        let api = ApiBuilder::new()
            .with_cache_dir(self.cache_dir.clone())
            .build()
            .map_err(|e| RuntimeError::Load(e.to_string()))?;
        let repo = match &self.revision {
            Some(rev) => Repo::with_revision(
                self.model_type.model_id().to_string(),
                RepoType::Model,
                rev.clone(),
            ),
            None => Repo::model(self.model_type.model_id().to_string()),
        };
        let api_repo = api.repo(repo);

        let config_path = api_repo
            .get("config.json")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let config_contents =
            std::fs::read_to_string(&config_path).map_err(|e| RuntimeError::Load(e.to_string()))?;

        let base_config: BaseConfig = serde_json::from_str(&config_contents)
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let arch = ModelArchitecture::from_config(&base_config)?;
        tracing::info!(architecture = ?arch, "Detected model architecture");

        let tokenizer_path = api_repo
            .get("tokenizer.json")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;
        let weights_path = api_repo
            .get("model.safetensors")
            .await
            .map_err(|e| RuntimeError::Load(e.to_string()))?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| RuntimeError::Load(format!("Failed to load tokenizer: {}", e)))?;

        let padding = PaddingParams {
            strategy: PaddingStrategy::BatchLongest,
            ..Default::default()
        };
        tokenizer.with_padding(Some(padding));

        // Truncate to 512 tokens, the max sequence length for these
        // BERT-family checkpoints. Gemma variants accept longer inputs, but
        // 512 is a safe shared default here.
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| RuntimeError::Load(format!("Failed to set truncation: {}", e)))?;

        let device = Device::Cpu;
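        // Safety: the weights file is memory-mapped; this is only sound while
        // the underlying file is not modified for the lifetime of the mapping.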
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)
                .map_err(|e| RuntimeError::Load(e.to_string()))?
        };

        let model = match arch {
            ModelArchitecture::Bert => {
                let config: BertConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model =
                    BertModel::load(vb, &config).map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::Bert(model)
            }
            ModelArchitecture::JinaBert => {
                let config: JinaBertConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model = JinaBertModel::new(vb, &config)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::JinaBert(model)
            }
            ModelArchitecture::Gemma => {
                let config: GemmaConfig = serde_json::from_str(&config_contents)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                let model = GemmaModel::new(false, &config, vb)
                    .map_err(|e| RuntimeError::Load(e.to_string()))?;
                InnerModel::Gemma(model)
            }
        };

        tracing::info!(
            model = self.model_type.name(),
            dimensions = self.model_type.dimensions(),
            "Candle embedding model loaded"
        );

        *state = Some(LoadedModel {
            model,
            tokenizer,
            device,
        });

        Ok(())
    }
}

#[async_trait]
impl EmbeddingModel for CandleEmbeddingModel {
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        self.ensure_loaded().await?;

        let state_guard = self.state.lock().await;
        let loaded = state_guard
            .as_ref()
            .ok_or_else(|| RuntimeError::Load("Model state missing".to_string()))?;

        if texts.is_empty() {
            return Ok(vec![]);
        }

        let encodings = loaded
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|e| RuntimeError::InferenceError(format!("Tokenization failed: {}", e)))?;

        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_token_type_ids = Vec::new();

        for encoding in &encodings {
            all_input_ids.push(
                encoding
                    .get_ids()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
            all_attention_masks.push(
                encoding
                    .get_attention_mask()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
            all_token_type_ids.push(
                encoding
                    .get_type_ids()
                    .iter()
                    .map(|&x| x as i64)
                    .collect::<Vec<_>>(),
            );
        }

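        // `PaddingStrategy::BatchLongest` pads every encoding in the batch to
        // the same length, so the first row's length is the shared seq_len.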
        let batch_size = texts.len();
        let seq_len = all_input_ids[0].len();

        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let token_type_ids_flat: Vec<i64> = all_token_type_ids.into_iter().flatten().collect();

        let input_ids = Tensor::from_vec(input_ids_flat, (batch_size, seq_len), &loaded.device)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let attention_mask =
            Tensor::from_vec(attention_mask_flat, (batch_size, seq_len), &loaded.device)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let token_type_ids =
            Tensor::from_vec(token_type_ids_flat, (batch_size, seq_len), &loaded.device)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let embeddings = match &loaded.model {
            InnerModel::Bert(m) => m
                .forward(&input_ids, &token_type_ids, Some(&attention_mask))
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?,
            InnerModel::JinaBert(m) => m
                .forward(&input_ids)
                .map_err(|e| RuntimeError::InferenceError(e.to_string()))?,
            InnerModel::Gemma(_m) => {
                // `candle_transformers::models::gemma::Model::forward` returns
                // LM-head logits of shape (batch, seq, vocab), not the last
                // hidden states, so there is nothing meaningful to mean-pool
                // without forking the model code to expose hidden states.
                return Err(RuntimeError::InferenceError(
                    "Gemma embedding not fully implemented (requires hidden state access)"
                        .to_string(),
                ));
            }
        };

        // Mean pooling
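        // Sum hidden states over the real (unpadded) token positions, then
        // divide by the mask sum: mean_j = sum_i(h_ij * m_i) / max(sum_i(m_i), 1e-9)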
        let attention_mask_f32 = attention_mask
            .to_dtype(DType::F32)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_expanded = attention_mask_f32
            .unsqueeze(2)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_expanded = mask_expanded
            .broadcast_as(embeddings.shape())
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let masked_embeddings = embeddings
            .mul(&mask_expanded)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let sum_embeddings = masked_embeddings
            .sum(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mask_sum = attention_mask_f32
            .sum(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .unsqueeze(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mask_sum = mask_sum
            .broadcast_as(sum_embeddings.shape())
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;
        let mask_sum = mask_sum
            .clamp(1e-9, f64::MAX)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let mean_embeddings = sum_embeddings
            .div(&mask_sum)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

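        // L2-normalize so downstream cosine similarity reduces to a dot product.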
        let norm = mean_embeddings
            .sqr()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .sum_keepdim(1)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .sqrt()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?
            .clamp(1e-12, f64::MAX)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let normalized = mean_embeddings
            .broadcast_div(&norm)
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        let embeddings_vec: Vec<Vec<f32>> = normalized
            .to_vec2()
            .map_err(|e| RuntimeError::InferenceError(e.to_string()))?;

        Ok(embeddings_vec)
    }

    fn dimensions(&self) -> u32 {
        self.model_type.dimensions()
    }

    fn model_id(&self) -> &str {
        self.model_type.model_id()
    }

    async fn warmup(&self) -> Result<()> {
        self.ensure_loaded().await
    }
}
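
// Minimal end-to-end sketch (hypothetical cache path; the first call downloads
// weights from HuggingFace Hub, so network access is required):
//
//     let model = CandleEmbeddingModel::new(
//         CandleTextModel::AllMiniLmL6V2,
//         None,
//         std::path::PathBuf::from("/tmp/candle-cache"),
//     );
//     let vecs = model.embed(vec!["hello", "world"]).await?;
//     assert_eq!(vecs[0].len(), 384);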