// uni_xervo/traits.rs
1//! Core traits that every provider and model implementation must satisfy.
2
3use crate::api::{ModelAliasSpec, ModelTask};
4use crate::error::Result;
5use async_trait::async_trait;
6use std::any::Any;
7
8/// Advertised capabilities of a [`ModelProvider`].
9#[derive(Debug, Clone)]
10pub struct ProviderCapabilities {
11 /// The set of [`ModelTask`] variants this provider can handle.
12 pub supported_tasks: Vec<ModelTask>,
13}
14
/// Health status reported by a provider.
///
/// Returned by [`ModelProvider::health`]. `PartialEq`/`Eq` are derived so
/// callers and tests can compare statuses directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderHealth {
    /// The provider is fully operational.
    Healthy,
    /// The provider is operational but experiencing partial issues.
    /// The string is a human-readable description of the issue.
    Degraded(String),
    /// The provider cannot serve requests.
    /// The string is a human-readable reason.
    Unhealthy(String),
}

impl ProviderHealth {
    /// Returns `true` only for [`ProviderHealth::Healthy`].
    ///
    /// `Degraded` is deliberately *not* considered healthy here; callers that
    /// want to keep using a degraded provider should match on the variant.
    pub fn is_healthy(&self) -> bool {
        matches!(self, ProviderHealth::Healthy)
    }
}
25
/// A pluggable backend that knows how to load models for one or more
/// [`ModelTask`] types.
///
/// Providers are registered with [`ModelRuntimeBuilder::register_provider`](crate::runtime::ModelRuntimeBuilder::register_provider)
/// and are identified by their [`provider_id`](ModelProvider::provider_id)
/// (e.g. `"local/candle"`, `"remote/openai"`).
///
/// The `Send + Sync` bound lets a single provider instance be shared across
/// async runtime tasks.
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Unique identifier for this provider (e.g. `"local/candle"`, `"remote/openai"`).
    ///
    /// This is the key under which the provider is registered, so it should
    /// be stable across runs. NOTE(review): uniqueness across registered
    /// providers is assumed — confirm the runtime rejects duplicates.
    fn provider_id(&self) -> &'static str;

    /// Return the set of tasks this provider supports.
    ///
    /// NOTE(review): call frequency is not visible from here (once at
    /// registration vs. per request) — implementors should keep this cheap.
    fn capabilities(&self) -> ProviderCapabilities;

    /// Load (or connect to) a model described by `spec` and return a type-erased
    /// handle.
    ///
    /// The returned [`LoadedModelHandle`] is expected to contain an
    /// `Arc<dyn EmbeddingModel>`, `Arc<dyn RerankerModel>`, or
    /// `Arc<dyn GeneratorModel>` depending on the task.
    ///
    /// # Errors
    ///
    /// Returns an error when the model cannot be loaded or connected to
    /// (the concrete failure modes are provider-specific).
    async fn load(&self, spec: &ModelAliasSpec) -> Result<LoadedModelHandle>;

    /// Report the current health of this provider.
    async fn health(&self) -> ProviderHealth;

    /// Optional one-time warmup hook called during runtime startup.
    ///
    /// Use this for provider-wide initialization such as setting up API clients
    /// or pre-caching shared resources. The default implementation is a no-op.
    async fn warmup(&self) -> Result<()> {
        Ok(())
    }
}
59
/// A type-erased, reference-counted handle to a loaded model instance.
///
/// Providers wrap their concrete model (e.g. `Arc<dyn EmbeddingModel>`) inside
/// this `Arc<dyn Any + Send + Sync>` so the runtime can store them uniformly.
/// The runtime later downcasts the handle back to the expected trait object.
///
/// NOTE(review): a mismatched downcast fails at runtime, not compile time —
/// presumably the runtime records which task a handle was loaded for and
/// checks that before downcasting; verify against the lookup path.
pub type LoadedModelHandle = std::sync::Arc<dyn Any + Send + Sync>;
66
/// A model that produces dense vector embeddings from text.
///
/// The `Any` supertrait supports the type-erased [`LoadedModelHandle`]
/// downcasting scheme used by the runtime.
#[async_trait]
pub trait EmbeddingModel: Send + Sync + Any {
    /// Embed a batch of text strings into dense vectors.
    ///
    /// Returns one `Vec<f32>` per input text, each with [`dimensions()`](EmbeddingModel::dimensions)
    /// elements.
    ///
    /// NOTE(review): this takes `Vec<&str>` while [`RerankerModel::rerank`]
    /// takes `&[&str]`; a slice would be the more idiomatic borrow, but
    /// changing the signature now would break every implementor.
    async fn embed(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>>;

    /// The dimensionality of the embedding vectors produced by this model.
    fn dimensions(&self) -> u32;

    /// The underlying model identifier (e.g. a HuggingFace repo ID or API model name).
    fn model_id(&self) -> &str;

    /// Optional warmup hook (e.g. load weights into memory on first access).
    /// The default is a no-op.
    async fn warmup(&self) -> Result<()> {
        Ok(())
    }
}
88
/// A single scored document returned by a [`RerankerModel`].
///
/// `PartialEq` is derived (but not `Eq`, since `score` is an `f32`) so
/// results can be compared directly in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredDoc {
    /// Zero-based index into the original `docs` slice passed to
    /// [`RerankerModel::rerank`].
    pub index: usize,
    /// Relevance score assigned by the reranker (higher is more relevant).
    pub score: f32,
    /// The document text, if the provider returns it. May be `None` when the
    /// backend reports only indices and scores.
    pub text: Option<String>,
}
100
/// A model that re-scores documents against a query for relevance ranking.
#[async_trait]
pub trait RerankerModel: Send + Sync {
    /// Rerank `docs` by relevance to `query`, returning scored results
    /// (typically sorted by descending score).
    ///
    /// Each [`ScoredDoc::index`] refers back into the `docs` slice, so the
    /// caller can recover the original text even when [`ScoredDoc::text`]
    /// is `None`.
    ///
    /// NOTE(review): "typically sorted" leaves the ordering guarantee
    /// implementation-defined — tighten the contract if callers rely on it.
    async fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<ScoredDoc>>;

    /// Optional warmup hook. The default is a no-op.
    async fn warmup(&self) -> Result<()> {
        Ok(())
    }
}
113
/// Sampling and length parameters for text generation.
///
/// Every field is optional; `None` defers to the provider's own default.
/// All fields are `Copy`, so the struct itself is `Copy` and can be passed
/// by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct GenerationOptions {
    /// Maximum number of tokens to generate. Provider default if `None`.
    pub max_tokens: Option<usize>,
    /// Sampling temperature (0.0 = greedy, higher = more random).
    pub temperature: Option<f32>,
    /// Nucleus sampling threshold.
    pub top_p: Option<f32>,
}
124
/// The output of a text generation call.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// The generated text.
    pub text: String,
    /// Token usage statistics, if reported by the provider.
    /// `None` when the backend does not expose token accounting.
    pub usage: Option<TokenUsage>,
}
133
/// Token counts for a generation request.
///
/// Plain-old-data: all fields are `usize`, so `Copy`, `PartialEq`, and `Eq`
/// are derived for cheap passing and direct comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt / input.
    pub prompt_tokens: usize,
    /// Number of tokens generated.
    pub completion_tokens: usize,
    /// Sum of prompt and completion tokens.
    pub total_tokens: usize,
}

impl TokenUsage {
    /// Build a `TokenUsage` from prompt and completion counts, computing
    /// `total_tokens` so the `total = prompt + completion` invariant holds
    /// by construction.
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            // Saturating add: counts near usize::MAX are nonsensical, but we
            // must not panic in debug builds on pathological provider output.
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
144
/// A model that generates text from a conversational message history.
///
/// Messages are passed as a flat `&[String]` slice where even-indexed entries
/// (0, 2, 4, ...) are user turns and odd-indexed entries are assistant turns.
///
/// NOTE(review): this positional encoding leaves no room for system prompts
/// or tool turns; a typed `{ role, content }` message would be safer, but
/// changing it would break every implementor and caller.
#[async_trait]
pub trait GeneratorModel: Send + Sync {
    /// Generate a response given a conversation history and sampling options.
    ///
    /// Under the even/odd convention above, the final entry of `messages` is
    /// presumably the pending user turn — TODO confirm with callers.
    async fn generate(
        &self,
        messages: &[String],
        options: GenerationOptions,
    ) -> Result<GenerationResult>;

    /// Optional warmup hook. The default is a no-op.
    async fn warmup(&self) -> Result<()> {
        Ok(())
    }
}
162}