1use serde_json::Value;
11use std::path::PathBuf;
12
/// Converts a model identifier into a string usable as a directory name.
///
/// Each `/` becomes `--`. ASCII and Unicode alphanumeric characters are kept,
/// as are `-`, `_` and `.`. Every other character is dropped.
///
/// Note that names made up only of dots (for example `..`) are returned
/// unchanged, and an empty input yields an empty string.
pub fn sanitize_model_name(model_id: &str) -> String {
    let mut safe = String::new();
    for ch in model_id.chars() {
        match ch {
            // Namespace separator: flatten `org/model` into `org--model`.
            '/' => safe.push_str("--"),
            '-' | '_' | '.' => safe.push(ch),
            _ if ch.is_alphanumeric() => safe.push(ch),
            // Anything else is silently discarded.
            _ => {}
        }
    }
    safe
}
21
/// Name of the environment variable that overrides the cache root directory.
pub const CACHE_ROOT_ENV: &str = "UNI_CACHE_DIR";

/// Relative cache root used when [`CACHE_ROOT_ENV`] is unset or empty.
const DEFAULT_CACHE_ROOT: &str = ".uni_cache";

/// Returns the root directory under which all provider caches are placed.
///
/// The value of [`CACHE_ROOT_ENV`] is used when it is set to a non-empty
/// value; otherwise the relative path [`DEFAULT_CACHE_ROOT`] is returned.
fn cache_root() -> PathBuf {
    // `var_os` rather than `var`: paths are not required to be valid UTF-8,
    // and `var` reports such values as an error, which would silently discard
    // a legitimate override in favour of the default.
    std::env::var_os(CACHE_ROOT_ENV)
        // Treat an empty value as unset; an empty root would otherwise turn
        // every cache path into a bare `provider/model` relative path.
        .filter(|root| !root.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(DEFAULT_CACHE_ROOT))
}
34
35pub fn resolve_provider_cache_root(provider: &str) -> PathBuf {
44 cache_root().join(provider)
45}
46
47pub fn resolve_cache_dir(provider: &str, model_id: &str, options: &Value) -> PathBuf {
54 if let Some(dir) = options.get("cache_dir").and_then(|v| v.as_str()) {
55 return PathBuf::from(dir);
56 }
57 cache_root()
58 .join(provider)
59 .join(sanitize_model_name(model_id))
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes the tests below that read or modify `CACHE_ROOT_ENV`.
    /// The test harness runs tests on multiple threads, and the process
    /// environment is shared between them.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning.
    ///
    /// The mutex guards no data, so a panic in one test leaves nothing in an
    /// inconsistent state. Recovering keeps one failing test from cascading
    /// into `PoisonError` failures in every other environment test.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn sanitize_slash_replaced_with_double_dash() {
        assert_eq!(
            sanitize_model_name("sentence-transformers/all-MiniLM-L6-v2"),
            "sentence-transformers--all-MiniLM-L6-v2"
        );
    }

    #[test]
    fn sanitize_strips_unsafe_chars() {
        assert_eq!(sanitize_model_name("foo:bar@baz"), "foobarbaz");
    }

    #[test]
    fn sanitize_keeps_safe_chars() {
        assert_eq!(
            sanitize_model_name("BAAI--bge-small-en-v1.5"),
            "BAAI--bge-small-en-v1.5"
        );
    }

    #[test]
    fn resolve_default_path() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        let path = resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}));
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_env_var_root() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::set_var(CACHE_ROOT_ENV, "/data/models") };
        let path = resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}));
        // Clean up before asserting so a failure does not leak the override
        // into later tests.
        // SAFETY: `ENV_LOCK` is still held.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        assert_eq!(
            path,
            PathBuf::from("/data/models/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_options_cache_dir_takes_priority_over_env() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::set_var(CACHE_ROOT_ENV, "/data/models") };
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = resolve_cache_dir("fastembed", "some-model", &opts);
        // Clean up before asserting so a failure does not leak the override
        // into later tests.
        // SAFETY: `ENV_LOCK` is still held.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_user_override() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = resolve_cache_dir("fastembed", "some-model", &opts);
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_candle_path() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        let path = resolve_cache_dir(
            "candle",
            "sentence-transformers/all-MiniLM-L6-v2",
            &json!({}),
        );
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/candle/sentence-transformers--all-MiniLM-L6-v2")
        );
    }

    #[test]
    fn resolve_provider_root_default() {
        let _lock = lock_env();
        // SAFETY: `ENV_LOCK` is held, which serializes every environment
        // access made by the tests in this module.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        assert_eq!(
            resolve_provider_cache_root("candle"),
            PathBuf::from(".uni_cache/candle")
        );
    }
}