uni_xervo/provider/
fastembed.rs1use crate::api::{ModelAliasSpec, ModelTask};
2use crate::error::{Result, RuntimeError};
3use crate::traits::{
4 EmbeddingModel, LoadedModelHandle, ModelProvider, ProviderCapabilities, ProviderHealth,
5};
6use anyhow::anyhow;
7use async_trait::async_trait;
8use fastembed::{InitOptions, TextEmbedding};
9use std::path::Path;
10use std::sync::{Arc, Mutex};
11use std::thread;
12use tokio::sync::oneshot;
13
/// Model provider that serves embedding models through a locally loaded FastEmbed model.
///
/// The provider is a stateless unit type; the per-model state lives in the service created
/// when a model is loaded.
#[derive(Default)]
pub struct LocalFastEmbedProvider;

impl LocalFastEmbedProvider {
    /// Creates the provider. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        LocalFastEmbedProvider
    }
}
33
34#[async_trait]
35impl ModelProvider for LocalFastEmbedProvider {
36 fn provider_id(&self) -> &'static str {
37 "local/fastembed"
38 }
39
40 fn capabilities(&self) -> ProviderCapabilities {
41 ProviderCapabilities {
42 supported_tasks: vec![ModelTask::Embed],
43 }
44 }
45
46 async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle> {
47 if spec.task != ModelTask::Embed {
48 return Err(RuntimeError::CapabilityMismatch(format!(
49 "FastEmbed provider does not support task {:?}",
50 spec.task
51 )));
52 }
53
54 let model_name = spec.model_id.clone();
55 let cache_dir = crate::cache::resolve_cache_dir("fastembed", &model_name, &spec.options);
56
57 let service =
60 tokio::task::spawn_blocking(move || FastEmbedService::new(&model_name, &cache_dir))
61 .await
62 .map_err(|e| RuntimeError::Load(format!("Join error: {}", e)))?
63 .map_err(|e| RuntimeError::Load(e.to_string()))?;
64
65 let handle: Arc<dyn EmbeddingModel> = Arc::new(service);
66 Ok(Arc::new(handle) as LoadedModelHandle)
67 }
68
69 async fn health(&self) -> ProviderHealth {
70 ProviderHealth::Healthy
71 }
72}
73
/// Stack size, in bytes, requested for the worker thread that runs embedding inference
/// (8 MiB).
const EMBEDDING_THREAD_STACK_SIZE: usize = 8 << 20;
76
77pub struct FastEmbedService {
83 model: Arc<Mutex<TextEmbedding>>,
84 model_name: String,
85 dimensions: u32,
86}
87
88impl FastEmbedService {
89 pub fn new(model_name: &str, cache_dir: &Path) -> anyhow::Result<Self> {
90 let model_enum = match model_name {
91 "AllMiniLML6V2" | "all-MiniLM-L6-v2" => fastembed::EmbeddingModel::AllMiniLML6V2,
92 "AllMiniLML6V2Q" => fastembed::EmbeddingModel::AllMiniLML6V2Q,
93 "AllMiniLML12V2" => fastembed::EmbeddingModel::AllMiniLML12V2,
94 "AllMiniLML12V2Q" => fastembed::EmbeddingModel::AllMiniLML12V2Q,
95 "AllMpnetBaseV2" | "all-mpnet-base-v2" => fastembed::EmbeddingModel::AllMpnetBaseV2,
96 "BGEBaseENV15" | "bge-base-en-v1.5" => fastembed::EmbeddingModel::BGEBaseENV15,
97 "BGEBaseENV15Q" => fastembed::EmbeddingModel::BGEBaseENV15Q,
98 "BGELargeENV15" | "bge-large-en-v1.5" => fastembed::EmbeddingModel::BGELargeENV15,
99 "BGELargeENV15Q" => fastembed::EmbeddingModel::BGELargeENV15Q,
100 "BGESmallENV15" | "bge-small-en-v1.5" => fastembed::EmbeddingModel::BGESmallENV15,
101 "BGESmallENV15Q" => fastembed::EmbeddingModel::BGESmallENV15Q,
102 "NomicEmbedTextV1" => fastembed::EmbeddingModel::NomicEmbedTextV1,
103 "NomicEmbedTextV15" | "nomic-embed-text-v1.5" => {
104 fastembed::EmbeddingModel::NomicEmbedTextV15
105 }
106 "NomicEmbedTextV15Q" => fastembed::EmbeddingModel::NomicEmbedTextV15Q,
107 "ParaphraseMLMiniLML12V2" => fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2,
108 "ParaphraseMLMiniLML12V2Q" => fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2Q,
109 "ParaphraseMLMpnetBaseV2" => fastembed::EmbeddingModel::ParaphraseMLMpnetBaseV2,
110 "BGESmallZHV15" => fastembed::EmbeddingModel::BGESmallZHV15,
111 "BGELargeZHV15" => fastembed::EmbeddingModel::BGELargeZHV15,
112 "BGEM3" => fastembed::EmbeddingModel::BGEM3,
113 "ModernBertEmbedLarge" => fastembed::EmbeddingModel::ModernBertEmbedLarge,
114 "MultilingualE5Small" | "multilingual-e5-small" => {
115 fastembed::EmbeddingModel::MultilingualE5Small
116 }
117 "MultilingualE5Base" | "multilingual-e5-base" => {
118 fastembed::EmbeddingModel::MultilingualE5Base
119 }
120 "MultilingualE5Large" | "multilingual-e5-large" => {
121 fastembed::EmbeddingModel::MultilingualE5Large
122 }
123 "MxbaiEmbedLargeV1" | "mxbai-embed-large-v1" => {
124 fastembed::EmbeddingModel::MxbaiEmbedLargeV1
125 }
126 _ => {
127 return Err(anyhow!(
128 "Unsupported FastEmbed model: {}. Please check fastembed docs for supported models.",
129 model_name
130 ));
131 }
132 };
133
134 let mut options = InitOptions::new(model_enum.clone());
135 options = options.with_cache_dir(cache_dir.to_path_buf());
136
137 let model = TextEmbedding::try_new(options)
138 .map_err(|e| anyhow!("Failed to initialize FastEmbed model: {}", e))?;
139
140 let dimensions = match model_enum {
142 fastembed::EmbeddingModel::AllMiniLML6V2
143 | fastembed::EmbeddingModel::AllMiniLML6V2Q
144 | fastembed::EmbeddingModel::AllMiniLML12V2
145 | fastembed::EmbeddingModel::AllMiniLML12V2Q
146 | fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2
147 | fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2Q
148 | fastembed::EmbeddingModel::BGESmallENV15
149 | fastembed::EmbeddingModel::BGESmallENV15Q
150 | fastembed::EmbeddingModel::MultilingualE5Small => 384,
151
152 fastembed::EmbeddingModel::BGESmallZHV15 => 512,
153
154 fastembed::EmbeddingModel::AllMpnetBaseV2
155 | fastembed::EmbeddingModel::ParaphraseMLMpnetBaseV2
156 | fastembed::EmbeddingModel::BGEBaseENV15
157 | fastembed::EmbeddingModel::BGEBaseENV15Q
158 | fastembed::EmbeddingModel::NomicEmbedTextV1
159 | fastembed::EmbeddingModel::NomicEmbedTextV15
160 | fastembed::EmbeddingModel::NomicEmbedTextV15Q
161 | fastembed::EmbeddingModel::MultilingualE5Base => 768,
162
163 fastembed::EmbeddingModel::BGELargeENV15
164 | fastembed::EmbeddingModel::BGELargeENV15Q
165 | fastembed::EmbeddingModel::BGELargeZHV15
166 | fastembed::EmbeddingModel::BGEM3
167 | fastembed::EmbeddingModel::ModernBertEmbedLarge
168 | fastembed::EmbeddingModel::MultilingualE5Large
169 | fastembed::EmbeddingModel::MxbaiEmbedLargeV1 => 1024,
170
171 _ => {
172 1024
179 }
180 };
181
182 Ok(Self {
183 model: Arc::new(Mutex::new(model)),
184 model_name: model_name.to_string(),
185 dimensions,
186 })
187 }
188}
189
190#[async_trait]
191impl EmbeddingModel for FastEmbedService {
192 async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
193 let texts_vec: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
194 let model = self.model.clone();
195
196 let (tx, rx) = oneshot::channel();
197
198 thread::Builder::new()
200 .name("fastembed-worker".to_string())
201 .stack_size(EMBEDDING_THREAD_STACK_SIZE)
202 .spawn(move || {
203 let result = model
204 .lock()
205 .map_err(|_| anyhow!("Failed to lock embedding model"))
206 .and_then(|mut guard| {
207 guard
208 .embed(texts_vec, None)
209 .map_err(|e| anyhow!("FastEmbed error: {}", e))
210 });
211 let _ = tx.send(result);
212 })
213 .map_err(|e| {
214 RuntimeError::InferenceError(format!("Failed to spawn embedding thread: {}", e))
215 })?;
216
217 let result = rx
218 .await
219 .map_err(|_| RuntimeError::InferenceError("Embedding thread panicked".to_string()))?;
220
221 result.map_err(|e| RuntimeError::InferenceError(e.to_string()))
222 }
223
224 fn dimensions(&self) -> u32 {
225 self.dimensions
226 }
227
228 fn model_id(&self) -> &str {
229 &self.model_name
230 }
231}