task_registry.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. //! In-memory task registry for sub-agent task lifecycle management.
  2. use std::collections::HashMap;
  3. use std::sync::{Arc, Mutex};
  4. use std::time::{SystemTime, UNIX_EPOCH};
  5. use serde::{Deserialize, Serialize};
  6. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
  7. #[serde(rename_all = "snake_case")]
  8. pub enum TaskStatus {
  9. Created,
  10. Running,
  11. Completed,
  12. Failed,
  13. Stopped,
  14. }
  15. impl std::fmt::Display for TaskStatus {
  16. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  17. match self {
  18. Self::Created => write!(f, "created"),
  19. Self::Running => write!(f, "running"),
  20. Self::Completed => write!(f, "completed"),
  21. Self::Failed => write!(f, "failed"),
  22. Self::Stopped => write!(f, "stopped"),
  23. }
  24. }
  25. }
  26. #[derive(Debug, Clone, Serialize, Deserialize)]
  27. pub struct Task {
  28. pub task_id: String,
  29. pub prompt: String,
  30. pub description: Option<String>,
  31. pub status: TaskStatus,
  32. pub created_at: u64,
  33. pub updated_at: u64,
  34. pub messages: Vec<TaskMessage>,
  35. pub output: String,
  36. pub team_id: Option<String>,
  37. }
  38. #[derive(Debug, Clone, Serialize, Deserialize)]
  39. pub struct TaskMessage {
  40. pub role: String,
  41. pub content: String,
  42. pub timestamp: u64,
  43. }
  44. #[derive(Debug, Clone, Default)]
  45. pub struct TaskRegistry {
  46. inner: Arc<Mutex<RegistryInner>>,
  47. }
  48. #[derive(Debug, Default)]
  49. struct RegistryInner {
  50. tasks: HashMap<String, Task>,
  51. counter: u64,
  52. }
  /// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` instead of failing if the system clock reports a time before
/// the epoch.
fn now_secs() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
  59. impl TaskRegistry {
  60. #[must_use]
  61. pub fn new() -> Self {
  62. Self::default()
  63. }
  64. pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
  65. let mut inner = self.inner.lock().expect("registry lock poisoned");
  66. inner.counter += 1;
  67. let ts = now_secs();
  68. let task_id = format!("task_{:08x}_{}", ts, inner.counter);
  69. let task = Task {
  70. task_id: task_id.clone(),
  71. prompt: prompt.to_owned(),
  72. description: description.map(str::to_owned),
  73. status: TaskStatus::Created,
  74. created_at: ts,
  75. updated_at: ts,
  76. messages: Vec::new(),
  77. output: String::new(),
  78. team_id: None,
  79. };
  80. inner.tasks.insert(task_id, task.clone());
  81. task
  82. }
  83. pub fn get(&self, task_id: &str) -> Option<Task> {
  84. let inner = self.inner.lock().expect("registry lock poisoned");
  85. inner.tasks.get(task_id).cloned()
  86. }
  87. pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
  88. let inner = self.inner.lock().expect("registry lock poisoned");
  89. inner
  90. .tasks
  91. .values()
  92. .filter(|t| status_filter.map_or(true, |s| t.status == s))
  93. .cloned()
  94. .collect()
  95. }
  96. pub fn stop(&self, task_id: &str) -> Result<Task, String> {
  97. let mut inner = self.inner.lock().expect("registry lock poisoned");
  98. let task = inner
  99. .tasks
  100. .get_mut(task_id)
  101. .ok_or_else(|| format!("task not found: {task_id}"))?;
  102. match task.status {
  103. TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
  104. return Err(format!(
  105. "task {task_id} is already in terminal state: {}",
  106. task.status
  107. ));
  108. }
  109. _ => {}
  110. }
  111. task.status = TaskStatus::Stopped;
  112. task.updated_at = now_secs();
  113. Ok(task.clone())
  114. }
  115. pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
  116. let mut inner = self.inner.lock().expect("registry lock poisoned");
  117. let task = inner
  118. .tasks
  119. .get_mut(task_id)
  120. .ok_or_else(|| format!("task not found: {task_id}"))?;
  121. task.messages.push(TaskMessage {
  122. role: String::from("user"),
  123. content: message.to_owned(),
  124. timestamp: now_secs(),
  125. });
  126. task.updated_at = now_secs();
  127. Ok(task.clone())
  128. }
  129. pub fn output(&self, task_id: &str) -> Result<String, String> {
  130. let inner = self.inner.lock().expect("registry lock poisoned");
  131. let task = inner
  132. .tasks
  133. .get(task_id)
  134. .ok_or_else(|| format!("task not found: {task_id}"))?;
  135. Ok(task.output.clone())
  136. }
  137. pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
  138. let mut inner = self.inner.lock().expect("registry lock poisoned");
  139. let task = inner
  140. .tasks
  141. .get_mut(task_id)
  142. .ok_or_else(|| format!("task not found: {task_id}"))?;
  143. task.output.push_str(output);
  144. task.updated_at = now_secs();
  145. Ok(())
  146. }
  147. pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
  148. let mut inner = self.inner.lock().expect("registry lock poisoned");
  149. let task = inner
  150. .tasks
  151. .get_mut(task_id)
  152. .ok_or_else(|| format!("task not found: {task_id}"))?;
  153. task.status = status;
  154. task.updated_at = now_secs();
  155. Ok(())
  156. }
  157. pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
  158. let mut inner = self.inner.lock().expect("registry lock poisoned");
  159. let task = inner
  160. .tasks
  161. .get_mut(task_id)
  162. .ok_or_else(|| format!("task not found: {task_id}"))?;
  163. task.team_id = Some(team_id.to_owned());
  164. task.updated_at = now_secs();
  165. Ok(())
  166. }
  167. pub fn remove(&self, task_id: &str) -> Option<Task> {
  168. let mut inner = self.inner.lock().expect("registry lock poisoned");
  169. inner.tasks.remove(task_id)
  170. }
  171. #[must_use]
  172. pub fn len(&self) -> usize {
  173. let inner = self.inner.lock().expect("registry lock poisoned");
  174. inner.tasks.len()
  175. }
  176. #[must_use]
  177. pub fn is_empty(&self) -> bool {
  178. self.len() == 0
  179. }
  180. }
  #[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn creates_and_retrieves_tasks() {
        let registry = TaskRegistry::new();
        let task = registry.create("Do something", Some("A test task"));
        assert_eq!(task.status, TaskStatus::Created);
        assert_eq!(task.prompt, "Do something");
        assert_eq!(task.description.as_deref(), Some("A test task"));
        let fetched = registry.get(&task.task_id).expect("task should exist");
        assert_eq!(fetched.task_id, task.task_id);
    }

    #[test]
    fn lists_tasks_with_optional_filter() {
        let registry = TaskRegistry::new();
        registry.create("Task A", None);
        let task_b = registry.create("Task B", None);
        registry
            .set_status(&task_b.task_id, TaskStatus::Running)
            .expect("set status should succeed");
        let all = registry.list(None);
        assert_eq!(all.len(), 2);
        let running = registry.list(Some(TaskStatus::Running));
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].task_id, task_b.task_id);
        let created = registry.list(Some(TaskStatus::Created));
        assert_eq!(created.len(), 1);
    }

    #[test]
    fn stops_running_task() {
        let registry = TaskRegistry::new();
        let task = registry.create("Stoppable", None);
        registry
            .set_status(&task.task_id, TaskStatus::Running)
            .expect("set status should succeed");
        let stopped = registry.stop(&task.task_id).expect("stop should succeed");
        assert_eq!(stopped.status, TaskStatus::Stopped);
        // Stopped is terminal, so a second stop must be rejected.
        let error = registry
            .stop(&task.task_id)
            .expect_err("second stop should be rejected");
        assert!(error.contains("already in terminal state"));
        assert!(error.contains("stopped"));
    }

    #[test]
    fn updates_task_with_messages() {
        let registry = TaskRegistry::new();
        let task = registry.create("Messageable", None);
        let updated = registry
            .update(&task.task_id, "Here's more context")
            .expect("update should succeed");
        assert_eq!(updated.messages.len(), 1);
        assert_eq!(updated.messages[0].content, "Here's more context");
        assert_eq!(updated.messages[0].role, "user");
    }

    #[test]
    fn update_appends_messages_in_order() {
        let registry = TaskRegistry::new();
        let task = registry.create("Chatty", None);
        registry
            .update(&task.task_id, "first")
            .expect("first update should succeed");
        let updated = registry
            .update(&task.task_id, "second")
            .expect("second update should succeed");
        let contents: Vec<&str> = updated
            .messages
            .iter()
            .map(|message| message.content.as_str())
            .collect();
        assert_eq!(contents, ["first", "second"]);
        assert!(updated.messages.iter().all(|message| message.role == "user"));
    }

    #[test]
    fn appends_and_retrieves_output() {
        let registry = TaskRegistry::new();
        let task = registry.create("Output task", None);
        registry
            .append_output(&task.task_id, "line 1\n")
            .expect("append should succeed");
        registry
            .append_output(&task.task_id, "line 2\n")
            .expect("append should succeed");
        let output = registry.output(&task.task_id).expect("output should exist");
        assert_eq!(output, "line 1\nline 2\n");
    }

    #[test]
    fn assigns_team_and_removes_task() {
        let registry = TaskRegistry::new();
        let task = registry.create("Team task", None);
        registry
            .assign_team(&task.task_id, "team_abc")
            .expect("assign should succeed");
        let fetched = registry.get(&task.task_id).expect("task should exist");
        assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
        let removed = registry.remove(&task.task_id);
        assert!(removed.is_some());
        assert!(registry.get(&task.task_id).is_none());
        assert!(registry.is_empty());
    }

    #[test]
    fn rejects_operations_on_missing_task() {
        let registry = TaskRegistry::new();
        assert!(registry.stop("nonexistent").is_err());
        assert!(registry.update("nonexistent", "msg").is_err());
        assert!(registry.output("nonexistent").is_err());
        assert!(registry.append_output("nonexistent", "data").is_err());
        assert!(registry
            .set_status("nonexistent", TaskStatus::Running)
            .is_err());
        assert!(registry.assign_team("nonexistent", "team").is_err());
    }

    #[test]
    fn task_status_display_all_variants() {
        let cases = [
            (TaskStatus::Created, "created"),
            (TaskStatus::Running, "running"),
            (TaskStatus::Completed, "completed"),
            (TaskStatus::Failed, "failed"),
            (TaskStatus::Stopped, "stopped"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected, "display of {status:?}");
        }
    }

    #[test]
    fn stop_rejects_completed_task() {
        // given
        let registry = TaskRegistry::new();
        let task = registry.create("done", None);
        registry
            .set_status(&task.task_id, TaskStatus::Completed)
            .expect("set status should succeed");
        // when
        let result = registry.stop(&task.task_id);
        // then
        let error = result.expect_err("completed task should be rejected");
        assert!(error.contains("already in terminal state"));
        assert!(error.contains("completed"));
    }

    #[test]
    fn stop_rejects_failed_task() {
        // given
        let registry = TaskRegistry::new();
        let task = registry.create("failed", None);
        registry
            .set_status(&task.task_id, TaskStatus::Failed)
            .expect("set status should succeed");
        // when
        let result = registry.stop(&task.task_id);
        // then
        let error = result.expect_err("failed task should be rejected");
        assert!(error.contains("already in terminal state"));
        assert!(error.contains("failed"));
    }

    #[test]
    fn stop_succeeds_from_created_state() {
        // given
        let registry = TaskRegistry::new();
        let task = registry.create("created task", None);
        // when
        let stopped = registry.stop(&task.task_id).expect("stop should succeed");
        // then
        assert_eq!(stopped.status, TaskStatus::Stopped);
        assert!(stopped.updated_at >= task.updated_at);
    }

    #[test]
    fn new_registry_is_empty() {
        // given
        let registry = TaskRegistry::new();
        // when
        let all_tasks = registry.list(None);
        // then
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(all_tasks.is_empty());
    }

    #[test]
    fn create_without_description() {
        // given
        let registry = TaskRegistry::new();
        // when
        let task = registry.create("Do the thing", None);
        // then
        assert!(task.task_id.starts_with("task_"));
        assert_eq!(task.description, None);
        assert!(task.messages.is_empty());
        assert!(task.output.is_empty());
        assert_eq!(task.team_id, None);
    }

    #[test]
    fn task_ids_are_unique() {
        // given
        let registry = TaskRegistry::new();
        // when
        let ids: HashSet<String> = (0..5)
            .map(|i| registry.create(&format!("task {i}"), None).task_id)
            .collect();
        // then
        assert_eq!(ids.len(), 5);
        assert_eq!(registry.len(), 5);
    }

    #[test]
    fn clones_share_underlying_state() {
        // given
        let registry = TaskRegistry::new();
        let handle = registry.clone();
        // when
        let task = handle.create("shared", None);
        // then
        assert!(registry.get(&task.task_id).is_some());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn remove_nonexistent_returns_none() {
        // given
        let registry = TaskRegistry::new();
        // when
        let removed = registry.remove("missing");
        // then
        assert!(removed.is_none());
    }

    #[test]
    fn assign_team_rejects_missing_task() {
        // given
        let registry = TaskRegistry::new();
        // when
        let result = registry.assign_team("missing", "team_123");
        // then
        let error = result.expect_err("missing task should be rejected");
        assert_eq!(error, "task not found: missing");
    }
}