1#![cfg(feature = "advanced-primitives")]
2
3use reqwest::StatusCode;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::time::Duration;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13pub struct SemanticClientConfig {
14 pub enabled: bool,
16 #[serde(default, skip_serializing_if = "Option::is_none")]
18 pub endpoint: Option<String>,
19 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub api_key: Option<String>,
22 #[serde(default, skip_serializing_if = "Option::is_none")]
24 pub domain: Option<String>,
25 pub max_results: usize,
27 pub timeout_ms: u64,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct SemanticQueryRequest {
34 pub query: String,
36 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub domain: Option<String>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub session_id: Option<String>,
42 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub task_id: Option<String>,
45 pub source: String,
47 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
49 pub hints: HashMap<String, String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SemanticQueryHit {
55 pub title: String,
57 pub snippet: String,
59 #[serde(default, skip_serializing_if = "Option::is_none")]
61 pub source: Option<String>,
62 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub score: Option<f32>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct SemanticQueryResult {
70 #[serde(default, skip_serializing_if = "Option::is_none")]
72 pub domain: Option<String>,
73 pub summary: String,
75 #[serde(default, skip_serializing_if = "Vec::is_empty")]
77 pub hits: Vec<SemanticQueryHit>,
78}
79
80#[derive(Debug, thiserror::Error)]
82pub enum SemanticClientError {
83 #[error("semantic client is missing a valid endpoint")]
85 InvalidConfiguration,
86 #[error("semantic request failed: {0}")]
88 Transport(#[from] reqwest::Error),
89 #[error("semantic backend returned status {status}: {body}")]
91 Status {
92 status: StatusCode,
94 body: String,
96 },
97}
98
99#[derive(Debug, Clone)]
101pub struct SemanticClient {
102 config: SemanticClientConfig,
103 client: reqwest::Client,
104}
105
106impl SemanticClient {
107 pub fn new(config: SemanticClientConfig) -> Result<Self, SemanticClientError> {
109 let client = reqwest::Client::builder()
110 .timeout(Duration::from_millis(config.timeout_ms.max(100)))
111 .build()?;
112 Ok(Self { config, client })
113 }
114
115 pub async fn query(
117 &self,
118 request: &SemanticQueryRequest,
119 ) -> Result<Option<SemanticQueryResult>, SemanticClientError> {
120 if !self.config.enabled || request.query.trim().is_empty() {
121 return Ok(None);
122 }
123 let Some(endpoint) = self.config.endpoint.as_deref() else {
124 return Err(SemanticClientError::InvalidConfiguration);
125 };
126
127 let mut builder = self.client.post(endpoint).json(&serde_json::json!({
128 "query": request.query.clone(),
129 "domain": request.domain.as_ref().or(self.config.domain.as_ref()),
130 "session_id": request.session_id.clone(),
131 "task_id": request.task_id.clone(),
132 "source": request.source.clone(),
133 "limit": self.config.max_results.max(1),
134 "hints": request.hints.clone(),
135 }));
136 if let Some(api_key) = self.config.api_key.as_deref() {
137 builder = builder.bearer_auth(api_key);
138 }
139
140 let response = builder.send().await?;
141 let status = response.status();
142 if !status.is_success() {
143 let body = response.text().await.unwrap_or_default();
144 return Err(SemanticClientError::Status { status, body });
145 }
146
147 let body: Value = response.json().await?;
148 Ok(parse_semantic_response(
149 &body,
150 request
151 .domain
152 .clone()
153 .or_else(|| self.config.domain.clone()),
154 self.config.max_results.max(1),
155 ))
156 }
157}
158
159fn parse_semantic_response(
160 body: &Value,
161 fallback_domain: Option<String>,
162 max_results: usize,
163) -> Option<SemanticQueryResult> {
164 let summary = body
165 .get("summary")
166 .and_then(Value::as_str)
167 .or_else(|| body.get("content").and_then(Value::as_str))
168 .or_else(|| body.pointer("/answer/summary").and_then(Value::as_str))
169 .unwrap_or_default()
170 .trim()
171 .to_string();
172
173 let hits = body
174 .get("results")
175 .or_else(|| body.get("hits"))
176 .or_else(|| body.pointer("/data/results"))
177 .and_then(Value::as_array)
178 .map(|values| {
179 values
180 .iter()
181 .take(max_results)
182 .map(|value| SemanticQueryHit {
183 title: value
184 .get("title")
185 .and_then(Value::as_str)
186 .unwrap_or("semantic-result")
187 .trim()
188 .to_string(),
189 snippet: value
190 .get("snippet")
191 .or_else(|| value.get("summary"))
192 .or_else(|| value.get("text"))
193 .and_then(Value::as_str)
194 .unwrap_or_default()
195 .trim()
196 .to_string(),
197 source: value
198 .get("source")
199 .or_else(|| value.get("uri"))
200 .or_else(|| value.get("url"))
201 .and_then(Value::as_str)
202 .map(ToOwned::to_owned),
203 score: value
204 .get("score")
205 .and_then(Value::as_f64)
206 .map(|score| score as f32),
207 })
208 .filter(|hit| !hit.snippet.is_empty() || !hit.title.is_empty())
209 .collect::<Vec<_>>()
210 })
211 .unwrap_or_default();
212
213 if summary.is_empty() && hits.is_empty() {
214 return None;
215 }
216
217 Some(SemanticQueryResult {
218 domain: body
219 .get("domain")
220 .and_then(Value::as_str)
221 .map(ToOwned::to_owned)
222 .or(fallback_domain),
223 summary,
224 hits,
225 })
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A typical `summary` + `results` payload parses. The caller's fallback domain
    /// is used because the payload does not name one.
    #[test]
    fn parse_semantic_response_accepts_generic_payloads() {
        let payload = json!({
            "summary": "Cross-check the known protocol constraints before execution.",
            "results": [
                {
                    "title": "Protocol reference",
                    "snippet": "Ring transport expects BOS1921 success waveform IDs.",
                    "source": "docs://ring"
                }
            ]
        });

        let Some(parsed) = parse_semantic_response(&payload, Some(String::from("haptics")), 3)
        else {
            panic!("result should parse");
        };

        assert_eq!(parsed.domain, Some(String::from("haptics")));
        assert_eq!(parsed.hits.len(), 1);
        assert!(parsed.summary.contains("protocol constraints"));
    }

    /// A reply with no summary and no hits yields `None` rather than an empty result.
    #[test]
    fn parse_semantic_response_returns_none_for_empty_payloads() {
        let payload = json!({"results": []});
        let parsed = parse_semantic_response(&payload, None, 3);
        assert!(parsed.is_none());
    }
}