// uni_xervo/provider/fastembed.rs

1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::traits::{
4    EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
5};
6use anyhow::anyhow;
7use async_trait::async_trait;
8use fastembed::{InitOptions, TextEmbedding};
9use std::path::Path;
10use std::sync::{Arc, Mutex};
11use std::thread;
12use tokio::sync::oneshot;
13
/// Local embedding provider backed by [FastEmbed](https://github.com/Anush008/fastembed-rs)
/// (ONNX Runtime).
///
/// Supports a wide range of embedding models. Inference is offloaded to a
/// dedicated thread with an enlarged stack to accommodate ONNX Runtime's
/// requirements.
pub struct LocalFastEmbedProvider;

impl LocalFastEmbedProvider {
    /// Creates a new provider instance. The provider itself is stateless;
    /// per-model state lives in the handles returned by `load`.
    pub fn new() -> Self {
        LocalFastEmbedProvider
    }
}

impl Default for LocalFastEmbedProvider {
    fn default() -> Self {
        LocalFastEmbedProvider::new()
    }
}
33
34#[async_trait]
35impl ModelProvider for LocalFastEmbedProvider {
36    fn provider_id(&self) -> &'static str {
37        "local/fastembed"
38    }
39
40    fn capabilities(&self) -> ProviderCapabilities {
41        ProviderCapabilities {
42            supported_tasks: vec![ModelTask::Embed],
43        }
44    }
45
46    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
47        if spec.task != ModelTask::Embed {
48            return Err(RuntimeError::CapabilityMismatch(format!(
49                "FastEmbed provider does not support task {:?}",
50                spec.task
51            )));
52        }
53
54        let model_name = spec.model_id.clone();
55        let cache_dir = crate::cache::resolve_cache_dir("fastembed", &model_name, &spec.options);
56
57        // Offload initialization to a blocking thread because it can refer to onnxruntime which might be heavy
58        // fastembed init might block.
59        let service =
60            tokio::task::spawn_blocking(move || FastEmbedService::new(&model_name, &cache_dir))
61                .await
62                .map_err(|e| RuntimeError::Load(format!("Join error: {}", e)))?
63                .map_err(|e| RuntimeError::Load(e.to_string()))?;
64
65        let handle: Arc<dyn EmbeddingModel> = Arc::new(service);
66        Ok(Arc::new(handle) as LoadedModelHandle)
67    }
68
69    async fn health(&self) -> ProviderHealth {
70        ProviderHealth::Healthy
71    }
72}
73
/// Stack size (8 MiB) for the short-lived embedding worker threads; the
/// enlarged stack accommodates ONNX Runtime's requirements.
const EMBEDDING_THREAD_STACK_SIZE: usize = 8 * 1024 * 1024;
76
/// Wrapper around a [`TextEmbedding`] instance that implements
/// [`EmbeddingModel`].
///
/// Each inference call spawns a short-lived worker thread with a larger stack
/// to satisfy ONNX Runtime's stack requirements.
pub struct FastEmbedService {
    // The loaded ONNX model. Mutex because fastembed's embed() needs
    // exclusive (mutable) access; Arc so the handle can move to the worker
    // thread spawned per inference call.
    model: Arc<Mutex<TextEmbedding>>,
    // Name the model was requested under, reported back via `model_id()`.
    model_name: String,
    // Embedding vector width, resolved at load time from the model variant.
    dimensions: u32,
}
87
88impl FastEmbedService {
89    pub fn new(model_name: &str, cache_dir: &Path) -> anyhow::Result<Self> {
90        let model_enum = match model_name {
91            "AllMiniLML6V2" | "all-MiniLM-L6-v2" => fastembed::EmbeddingModel::AllMiniLML6V2,
92            "AllMiniLML6V2Q" => fastembed::EmbeddingModel::AllMiniLML6V2Q,
93            "AllMiniLML12V2" => fastembed::EmbeddingModel::AllMiniLML12V2,
94            "AllMiniLML12V2Q" => fastembed::EmbeddingModel::AllMiniLML12V2Q,
95            "AllMpnetBaseV2" | "all-mpnet-base-v2" => fastembed::EmbeddingModel::AllMpnetBaseV2,
96            "BGEBaseENV15" | "bge-base-en-v1.5" => fastembed::EmbeddingModel::BGEBaseENV15,
97            "BGEBaseENV15Q" => fastembed::EmbeddingModel::BGEBaseENV15Q,
98            "BGELargeENV15" | "bge-large-en-v1.5" => fastembed::EmbeddingModel::BGELargeENV15,
99            "BGELargeENV15Q" => fastembed::EmbeddingModel::BGELargeENV15Q,
100            "BGESmallENV15" | "bge-small-en-v1.5" => fastembed::EmbeddingModel::BGESmallENV15,
101            "BGESmallENV15Q" => fastembed::EmbeddingModel::BGESmallENV15Q,
102            "NomicEmbedTextV1" => fastembed::EmbeddingModel::NomicEmbedTextV1,
103            "NomicEmbedTextV15" | "nomic-embed-text-v1.5" => {
104                fastembed::EmbeddingModel::NomicEmbedTextV15
105            }
106            "NomicEmbedTextV15Q" => fastembed::EmbeddingModel::NomicEmbedTextV15Q,
107            "ParaphraseMLMiniLML12V2" => fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2,
108            "ParaphraseMLMiniLML12V2Q" => fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2Q,
109            "ParaphraseMLMpnetBaseV2" => fastembed::EmbeddingModel::ParaphraseMLMpnetBaseV2,
110            "BGESmallZHV15" => fastembed::EmbeddingModel::BGESmallZHV15,
111            "BGELargeZHV15" => fastembed::EmbeddingModel::BGELargeZHV15,
112            "BGEM3" => fastembed::EmbeddingModel::BGEM3,
113            "ModernBertEmbedLarge" => fastembed::EmbeddingModel::ModernBertEmbedLarge,
114            "MultilingualE5Small" | "multilingual-e5-small" => {
115                fastembed::EmbeddingModel::MultilingualE5Small
116            }
117            "MultilingualE5Base" | "multilingual-e5-base" => {
118                fastembed::EmbeddingModel::MultilingualE5Base
119            }
120            "MultilingualE5Large" | "multilingual-e5-large" => {
121                fastembed::EmbeddingModel::MultilingualE5Large
122            }
123            "MxbaiEmbedLargeV1" | "mxbai-embed-large-v1" => {
124                fastembed::EmbeddingModel::MxbaiEmbedLargeV1
125            }
126            _ => {
127                return Err(anyhow!(
128                    "Unsupported FastEmbed model: {}. Please check fastembed docs for supported models.",
129                    model_name
130                ));
131            }
132        };
133
134        let mut options = InitOptions::new(model_enum.clone());
135        options = options.with_cache_dir(cache_dir.to_path_buf());
136
137        let model = TextEmbedding::try_new(options)
138            .map_err(|e| anyhow!("Failed to initialize FastEmbed model: {}", e))?;
139
140        // Determine dimensions
141        let dimensions = match model_enum {
142            fastembed::EmbeddingModel::AllMiniLML6V2
143            | fastembed::EmbeddingModel::AllMiniLML6V2Q
144            | fastembed::EmbeddingModel::AllMiniLML12V2
145            | fastembed::EmbeddingModel::AllMiniLML12V2Q
146            | fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2
147            | fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2Q
148            | fastembed::EmbeddingModel::BGESmallENV15
149            | fastembed::EmbeddingModel::BGESmallENV15Q
150            | fastembed::EmbeddingModel::MultilingualE5Small => 384,
151
152            fastembed::EmbeddingModel::BGESmallZHV15 => 512,
153
154            fastembed::EmbeddingModel::AllMpnetBaseV2
155            | fastembed::EmbeddingModel::ParaphraseMLMpnetBaseV2
156            | fastembed::EmbeddingModel::BGEBaseENV15
157            | fastembed::EmbeddingModel::BGEBaseENV15Q
158            | fastembed::EmbeddingModel::NomicEmbedTextV1
159            | fastembed::EmbeddingModel::NomicEmbedTextV15
160            | fastembed::EmbeddingModel::NomicEmbedTextV15Q
161            | fastembed::EmbeddingModel::MultilingualE5Base => 768,
162
163            fastembed::EmbeddingModel::BGELargeENV15
164            | fastembed::EmbeddingModel::BGELargeENV15Q
165            | fastembed::EmbeddingModel::BGELargeZHV15
166            | fastembed::EmbeddingModel::BGEM3
167            | fastembed::EmbeddingModel::ModernBertEmbedLarge
168            | fastembed::EmbeddingModel::MultilingualE5Large
169            | fastembed::EmbeddingModel::MxbaiEmbedLargeV1 => 1024,
170
171            _ => {
172                // Fallback for new models or quantized variants not explicitly listed
173                // We could log a warning here or return a default.
174                // Assuming 768 is a safe-ish bet for unknown models or 1024 for "Large" ones.
175                // Better approach: Since we can't easily probe without loading, we might just
176                // assume a default and let the user override via config if needed.
177                // But for now, to satisfy exhaustiveness:
178                1024
179            }
180        };
181
182        Ok(Self {
183            model: Arc::new(Mutex::new(model)),
184            model_name: model_name.to_string(),
185            dimensions,
186        })
187    }
188}
189
190#[async_trait]
191impl EmbeddingModel for FastEmbedService {
192    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
193        let texts_vec: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
194        let model = self.model.clone();
195
196        let (tx, rx) = oneshot::channel();
197
198        // Spawn a dedicated thread with larger stack for ONNX Runtime
199        thread::Builder::new()
200            .name("fastembed-worker".to_string())
201            .stack_size(EMBEDDING_THREAD_STACK_SIZE)
202            .spawn(move || {
203                let result = model
204                    .lock()
205                    .map_err(|_| anyhow!("Failed to lock embedding model"))
206                    .and_then(|mut guard| {
207                        guard
208                            .embed(texts_vec, None)
209                            .map_err(|e| anyhow!("FastEmbed error: {}", e))
210                    });
211                let _ = tx.send(result);
212            })
213            .map_err(|e| {
214                RuntimeError::InferenceError(format!("Failed to spawn embedding thread: {}", e))
215            })?;
216
217        let result = rx
218            .await
219            .map_err(|_| RuntimeError::InferenceError("Embedding thread panicked".to_string()))?;
220
221        result.map_err(|e| RuntimeError::InferenceError(e.to_string()))
222    }
223
224    fn dimensions(&self) -> u32 {
225        self.dimensions
226    }
227
228    fn model_id(&self) -> &str {
229        &self.model_name
230    }
231}