1use gestura_core_foundation::context::ResolvedContext;
7use serde::{Deserialize, Serialize};
8
9use crate::reflection::{quality_signals_for_response, score_reflection_improvement};
10use crate::types::{AgentResponse, ToolCallRecord, ToolResult};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub enum ReflectionEvalToolOutcome {
15 Success,
16 Error,
17 Skipped,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ReflectionEvalToolResult {
23 pub name: String,
24 pub outcome: ReflectionEvalToolOutcome,
25 pub detail: String,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ReflectionEvalTurn {
31 pub content: String,
32 pub tool_results: Vec<ReflectionEvalToolResult>,
33 pub truncated: bool,
34 pub iterations: usize,
35}
36
37impl ReflectionEvalTurn {
38 fn to_agent_response(&self) -> AgentResponse {
39 AgentResponse {
40 content: self.content.clone(),
41 thinking: None,
42 tool_calls: self
43 .tool_results
44 .iter()
45 .enumerate()
46 .map(|(index, tool)| ToolCallRecord {
47 id: format!("eval-tool-{index}"),
48 name: tool.name.clone(),
49 arguments: "{}".to_string(),
50 result: match tool.outcome {
51 ReflectionEvalToolOutcome::Success => {
52 ToolResult::Success(tool.detail.clone())
53 }
54 ReflectionEvalToolOutcome::Error => ToolResult::Error(tool.detail.clone()),
55 ReflectionEvalToolOutcome::Skipped => {
56 ToolResult::Skipped(tool.detail.clone())
57 }
58 },
59 duration_ms: 10,
60 })
61 .collect(),
62 usage: None,
63 context_used: ResolvedContext::default(),
64 truncated: self.truncated,
65 iterations: self.iterations,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ReflectionEvalCase {
73 pub name: String,
74 pub request_summary: String,
75 pub initial: ReflectionEvalTurn,
76 pub retry: ReflectionEvalTurn,
77 pub max_iterations: usize,
78 pub expected_min_improvement: f32,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct ReflectionEvalReport {
84 pub name: String,
85 pub request_summary: String,
86 pub initial_quality_score: f32,
87 pub retry_quality_score: f32,
88 pub improvement_score: f32,
89 pub passed: bool,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ReflectionEvalSummary {
95 pub reports: Vec<ReflectionEvalReport>,
96 pub total_cases: usize,
97 pub passed_cases: usize,
98 pub average_improvement_score: f32,
99}
100
101pub fn evaluate_reflection_case(case: &ReflectionEvalCase) -> ReflectionEvalReport {
103 let initial_quality_score =
104 quality_signals_for_response(&case.initial.to_agent_response(), case.max_iterations)
105 .score();
106 let retry_quality_score =
107 quality_signals_for_response(&case.retry.to_agent_response(), case.max_iterations).score();
108 let improvement_score =
109 score_reflection_improvement(initial_quality_score, retry_quality_score);
110
111 ReflectionEvalReport {
112 name: case.name.clone(),
113 request_summary: case.request_summary.clone(),
114 initial_quality_score,
115 retry_quality_score,
116 improvement_score,
117 passed: improvement_score >= case.expected_min_improvement,
118 }
119}
120
121pub fn evaluate_reflection_cases(cases: &[ReflectionEvalCase]) -> ReflectionEvalSummary {
123 let reports: Vec<_> = cases.iter().map(evaluate_reflection_case).collect();
124 let total_cases = reports.len();
125 let passed_cases = reports.iter().filter(|report| report.passed).count();
126 let average_improvement_score = if total_cases == 0 {
127 0.0
128 } else {
129 reports
130 .iter()
131 .map(|report| report.improvement_score)
132 .sum::<f32>()
133 / total_cases as f32
134 };
135
136 ReflectionEvalSummary {
137 reports,
138 total_cases,
139 passed_cases,
140 average_improvement_score,
141 }
142}
143
144pub fn builtin_reflection_eval_cases() -> Vec<ReflectionEvalCase> {
146 vec![
147 ReflectionEvalCase {
148 name: "tool_failure_becomes_actionable".to_string(),
149 request_summary: "User asked for a file that was looked up with the wrong path".to_string(),
150 initial: ReflectionEvalTurn {
151 content: "I'm sorry, I can't access that file right now.".to_string(),
152 tool_results: vec![ReflectionEvalToolResult {
153 name: "file".to_string(),
154 outcome: ReflectionEvalToolOutcome::Error,
155 detail: "config/app.toml does not exist".to_string(),
156 }],
157 truncated: false,
158 iterations: 2,
159 },
160 retry: ReflectionEvalTurn {
161 content: "The file lookup failed because `config/app.toml` does not exist. Please confirm the correct path and I can continue from there.".to_string(),
162 tool_results: vec![ReflectionEvalToolResult {
163 name: "file".to_string(),
164 outcome: ReflectionEvalToolOutcome::Error,
165 detail: "config/app.toml does not exist".to_string(),
166 }],
167 truncated: false,
168 iterations: 1,
169 },
170 max_iterations: 4,
171 expected_min_improvement: 0.20,
172 },
173 ReflectionEvalCase {
174 name: "truncated_answer_is_completed".to_string(),
175 request_summary: "Agent response was cut off before giving the final answer".to_string(),
176 initial: ReflectionEvalTurn {
177 content: "Here are the first two steps, but the rest was truncated".to_string(),
178 tool_results: Vec::new(),
179 truncated: true,
180 iterations: 3,
181 },
182 retry: ReflectionEvalTurn {
183 content: "Here are all four steps, followed by the exact next action and the validation command to run.".to_string(),
184 tool_results: Vec::new(),
185 truncated: false,
186 iterations: 1,
187 },
188 max_iterations: 4,
189 expected_min_improvement: 0.15,
190 },
191 ]
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in case should show the improvement it expects.
    #[test]
    fn builtin_reflection_eval_cases_pass() {
        let cases = builtin_reflection_eval_cases();
        let summary = evaluate_reflection_cases(&cases);

        assert_eq!(summary.total_cases, 2);
        assert_eq!(summary.passed_cases, 2);
        assert!(summary.average_improvement_score > 0.15);
    }

    /// A retry whose text is unchanged from the initial turn must not pass.
    #[test]
    fn regression_case_fails_when_retry_does_not_improve() {
        let mut case = builtin_reflection_eval_cases()
            .into_iter()
            .next()
            .expect("builtin eval cases should not be empty");
        // Make the retry say exactly what the initial turn said.
        case.retry.content.clone_from(&case.initial.content);

        let report = evaluate_reflection_case(&case);

        assert!(!report.passed);
        assert_eq!(report.improvement_score, 0.0);
    }
}