uni_xervo/
cache.rs

1//! Model and weight cache directory resolution.
2//!
3//! Local providers download model weights to a per-provider, per-model directory.
4//! This module determines where that directory lives based on (in priority order):
5//!
6//! 1. A per-model `cache_dir` option in the spec's JSON options.
7//! 2. The `UNI_CACHE_DIR` environment variable (global root override).
8//! 3. A default `.uni_cache/` directory relative to the working directory.
9
10use serde_json::Value;
11use std::path::PathBuf;
12
/// Convert a model identifier into a name usable as a single directory component.
///
/// Every `/` becomes `--`. Of the remaining characters only alphanumerics,
/// `-`, `_` and `.` are kept; everything else is dropped.
///
/// The result is never empty and never consists solely of dots. An input that
/// would otherwise sanitize to `""`, `"."`, `".."` (or a longer run of dots)
/// is returned as underscores instead: one per dot, and at least one. This
/// guarantees that joining the result onto a cache root names a real
/// sub-directory rather than the root itself or its parent.
pub fn sanitize_model_name(model_id: &str) -> String {
    // Single pass over the input; the capacity is only an estimate because
    // `/` expands to two characters and unsafe characters are dropped.
    let mut name = String::with_capacity(model_id.len());
    for c in model_id.chars() {
        match c {
            '/' => name.push_str("--"),
            c if c.is_alphanumeric() || matches!(c, '-' | '_' | '.') => name.push(c),
            _ => {}
        }
    }

    // `all` is vacuously true for the empty string, so this also covers inputs
    // whose characters were all stripped. `.` is one byte, so `len()` is the
    // number of dots.
    if name.chars().all(|c| c == '.') {
        return "_".repeat(name.len().max(1));
    }
    name
}
21
/// The environment variable used to override the root cache directory.
pub const CACHE_ROOT_ENV: &str = "UNI_CACHE_DIR";

/// Default root cache directory name (relative to CWD).
const DEFAULT_CACHE_ROOT: &str = ".uni_cache";

/// Return the cache root directory, respecting the `UNI_CACHE_DIR` env var.
///
/// The variable is read as an OS string, so a path that is not valid UTF-8 is
/// used as-is instead of being discarded. An unset or empty variable falls
/// back to [`DEFAULT_CACHE_ROOT`]; an empty value is treated as "unset"
/// because an empty root would place provider directories directly in the
/// working directory.
fn cache_root() -> PathBuf {
    std::env::var_os(CACHE_ROOT_ENV)
        .filter(|root| !root.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(DEFAULT_CACHE_ROOT))
}
34
35/// Resolve the root cache directory for a provider (no model sub-directory).
36///
37/// Used when setting a process-global cache env var (e.g. `HF_HOME` for mistralrs)
38/// before the first model load.
39///
40/// Priority:
41/// 1. `UNI_CACHE_DIR` env var -- resolves to `$UNI_CACHE_DIR/<provider>`
42/// 2. `.uni_cache/<provider>` -- default
43pub fn resolve_provider_cache_root(provider: &str) -> PathBuf {
44    cache_root().join(provider)
45}
46
47/// Resolve the cache directory for a given provider and model.
48///
49/// Priority (highest first):
50/// 1. `options["cache_dir"]` -- per-model override
51/// 2. `UNI_CACHE_DIR` env var -- global root override; resolves to `$UNI_CACHE_DIR/<provider>/<model>`
52/// 3. `.uni_cache/<provider>/<model>` -- default
53pub fn resolve_cache_dir(provider: &str, model_id: &str, options: &Value) -> PathBuf {
54    if let Some(dir) = options.get("cache_dir").and_then(|v| v.as_str()) {
55        return PathBuf::from(dir);
56    }
57    cache_root()
58        .join(provider)
59        .join(sanitize_model_name(model_id))
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Serialise all tests that read or write UNI_CACHE_DIR to avoid races
    // between parallel test threads (env vars are process-global).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `ENV_LOCK`, recovering the guard if the mutex is poisoned.
    ///
    /// The mutex guards no data, so a panic in one test must not cascade into
    /// `PoisonError` failures in every other environment-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Run `resolve` with `UNI_CACHE_DIR` set to `root` (or unset for `None`)
    /// while holding `ENV_LOCK`, then clear the variable and return the result.
    ///
    /// Callers assert on the returned value after the lock is released, so a
    /// failing assertion neither poisons the lock nor leaves the variable set.
    fn with_cache_root_env<T>(root: Option<&str>, resolve: impl FnOnce() -> T) -> T {
        let _lock = lock_env();
        // SAFETY: every environment access in this test module goes through
        // this helper and is therefore serialised by ENV_LOCK.
        unsafe {
            match root {
                Some(dir) => std::env::set_var(CACHE_ROOT_ENV, dir),
                None => std::env::remove_var(CACHE_ROOT_ENV),
            }
        }
        let resolved = resolve();
        // SAFETY: ENV_LOCK is still held here.
        unsafe { std::env::remove_var(CACHE_ROOT_ENV) };
        resolved
    }

    #[test]
    fn sanitize_slash_replaced_with_double_dash() {
        assert_eq!(
            sanitize_model_name("sentence-transformers/all-MiniLM-L6-v2"),
            "sentence-transformers--all-MiniLM-L6-v2"
        );
    }

    #[test]
    fn sanitize_strips_unsafe_chars() {
        assert_eq!(sanitize_model_name("foo:bar@baz"), "foobarbaz");
    }

    #[test]
    fn sanitize_keeps_safe_chars() {
        assert_eq!(
            sanitize_model_name("BAAI--bge-small-en-v1.5"),
            "BAAI--bge-small-en-v1.5"
        );
    }

    #[test]
    fn resolve_default_path() {
        let path = with_cache_root_env(None, || {
            resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}))
        });
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_env_var_root() {
        let path = with_cache_root_env(Some("/data/models"), || {
            resolve_cache_dir("fastembed", "BAAI/bge-small-en-v1.5", &json!({}))
        });
        assert_eq!(
            path,
            PathBuf::from("/data/models/fastembed/BAAI--bge-small-en-v1.5")
        );
    }

    #[test]
    fn resolve_options_cache_dir_takes_priority_over_env() {
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = with_cache_root_env(Some("/data/models"), || {
            resolve_cache_dir("fastembed", "some-model", &opts)
        });
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_user_override() {
        let opts = json!({ "cache_dir": "/tmp/my_cache" });
        let path = with_cache_root_env(None, || {
            resolve_cache_dir("fastembed", "some-model", &opts)
        });
        assert_eq!(path, PathBuf::from("/tmp/my_cache"));
    }

    #[test]
    fn resolve_candle_path() {
        let path = with_cache_root_env(None, || {
            resolve_cache_dir(
                "candle",
                "sentence-transformers/all-MiniLM-L6-v2",
                &json!({}),
            )
        });
        assert_eq!(
            path,
            PathBuf::from(".uni_cache/candle/sentence-transformers--all-MiniLM-L6-v2")
        );
    }

    #[test]
    fn resolve_provider_root_default() {
        let path = with_cache_root_env(None, || resolve_provider_cache_root("mistralrs"));
        assert_eq!(path, PathBuf::from(".uni_cache/mistralrs"));
    }

    #[test]
    fn resolve_provider_root_env_var() {
        let path = with_cache_root_env(Some("/data/models"), || {
            resolve_provider_cache_root("mistralrs")
        });
        assert_eq!(path, PathBuf::from("/data/models/mistralrs"));
    }
}