task_registry.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
//! In-memory task registry for sub-agent task lifecycle management.
  2. use std::collections::HashMap;
  3. use std::sync::{Arc, Mutex};
  4. use std::time::{SystemTime, UNIX_EPOCH};
  5. use serde::{Deserialize, Serialize};
  6. use crate::{validate_packet, TaskPacket, TaskPacketValidationError};
  7. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
  8. #[serde(rename_all = "snake_case")]
  9. pub enum TaskStatus {
  10. Created,
  11. Running,
  12. Completed,
  13. Failed,
  14. Stopped,
  15. }
  16. impl std::fmt::Display for TaskStatus {
  17. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  18. match self {
  19. Self::Created => write!(f, "created"),
  20. Self::Running => write!(f, "running"),
  21. Self::Completed => write!(f, "completed"),
  22. Self::Failed => write!(f, "failed"),
  23. Self::Stopped => write!(f, "stopped"),
  24. }
  25. }
  26. }
  27. #[derive(Debug, Clone, Serialize, Deserialize)]
  28. pub struct Task {
  29. pub task_id: String,
  30. pub prompt: String,
  31. pub description: Option<String>,
  32. pub task_packet: Option<TaskPacket>,
  33. pub status: TaskStatus,
  34. pub created_at: u64,
  35. pub updated_at: u64,
  36. pub messages: Vec<TaskMessage>,
  37. pub output: String,
  38. pub team_id: Option<String>,
  39. }
  40. #[derive(Debug, Clone, Serialize, Deserialize)]
  41. pub struct TaskMessage {
  42. pub role: String,
  43. pub content: String,
  44. pub timestamp: u64,
  45. }
  46. #[derive(Debug, Clone, Default)]
  47. pub struct TaskRegistry {
  48. inner: Arc<Mutex<RegistryInner>>,
  49. }
  50. #[derive(Debug, Default)]
  51. struct RegistryInner {
  52. tasks: HashMap<String, Task>,
  53. counter: u64,
  54. }
  /// Current wall-clock time in whole seconds since the Unix epoch.
  ///
  /// A system clock set before 1970 yields `0` rather than panicking.
  fn now_secs() -> u64 {
      match SystemTime::now().duration_since(UNIX_EPOCH) {
          Ok(since_epoch) => since_epoch.as_secs(),
          Err(_) => 0,
      }
  }
  61. impl TaskRegistry {
  62. #[must_use]
  63. pub fn new() -> Self {
  64. Self::default()
  65. }
  66. pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
  67. self.create_task(
  68. prompt.to_owned(),
  69. description.map(str::to_owned),
  70. None,
  71. )
  72. }
  73. pub fn create_from_packet(
  74. &self,
  75. packet: TaskPacket,
  76. ) -> Result<Task, TaskPacketValidationError> {
  77. let packet = validate_packet(packet)?.into_inner();
  78. Ok(self.create_task(
  79. packet.objective.clone(),
  80. Some(packet.scope.clone()),
  81. Some(packet),
  82. ))
  83. }
  84. fn create_task(
  85. &self,
  86. prompt: String,
  87. description: Option<String>,
  88. task_packet: Option<TaskPacket>,
  89. ) -> Task {
  90. let mut inner = self.inner.lock().expect("registry lock poisoned");
  91. inner.counter += 1;
  92. let ts = now_secs();
  93. let task_id = format!("task_{:08x}_{}", ts, inner.counter);
  94. let task = Task {
  95. task_id: task_id.clone(),
  96. prompt,
  97. description,
  98. task_packet,
  99. status: TaskStatus::Created,
  100. created_at: ts,
  101. updated_at: ts,
  102. messages: Vec::new(),
  103. output: String::new(),
  104. team_id: None,
  105. };
  106. inner.tasks.insert(task_id, task.clone());
  107. task
  108. }
  109. pub fn get(&self, task_id: &str) -> Option<Task> {
  110. let inner = self.inner.lock().expect("registry lock poisoned");
  111. inner.tasks.get(task_id).cloned()
  112. }
  113. pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
  114. let inner = self.inner.lock().expect("registry lock poisoned");
  115. inner
  116. .tasks
  117. .values()
  118. .filter(|t| status_filter.map_or(true, |s| t.status == s))
  119. .cloned()
  120. .collect()
  121. }
  122. pub fn stop(&self, task_id: &str) -> Result<Task, String> {
  123. let mut inner = self.inner.lock().expect("registry lock poisoned");
  124. let task = inner
  125. .tasks
  126. .get_mut(task_id)
  127. .ok_or_else(|| format!("task not found: {task_id}"))?;
  128. match task.status {
  129. TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
  130. return Err(format!(
  131. "task {task_id} is already in terminal state: {}",
  132. task.status
  133. ));
  134. }
  135. _ => {}
  136. }
  137. task.status = TaskStatus::Stopped;
  138. task.updated_at = now_secs();
  139. Ok(task.clone())
  140. }
  141. pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
  142. let mut inner = self.inner.lock().expect("registry lock poisoned");
  143. let task = inner
  144. .tasks
  145. .get_mut(task_id)
  146. .ok_or_else(|| format!("task not found: {task_id}"))?;
  147. task.messages.push(TaskMessage {
  148. role: String::from("user"),
  149. content: message.to_owned(),
  150. timestamp: now_secs(),
  151. });
  152. task.updated_at = now_secs();
  153. Ok(task.clone())
  154. }
  155. pub fn output(&self, task_id: &str) -> Result<String, String> {
  156. let inner = self.inner.lock().expect("registry lock poisoned");
  157. let task = inner
  158. .tasks
  159. .get(task_id)
  160. .ok_or_else(|| format!("task not found: {task_id}"))?;
  161. Ok(task.output.clone())
  162. }
  163. pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
  164. let mut inner = self.inner.lock().expect("registry lock poisoned");
  165. let task = inner
  166. .tasks
  167. .get_mut(task_id)
  168. .ok_or_else(|| format!("task not found: {task_id}"))?;
  169. task.output.push_str(output);
  170. task.updated_at = now_secs();
  171. Ok(())
  172. }
  173. pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
  174. let mut inner = self.inner.lock().expect("registry lock poisoned");
  175. let task = inner
  176. .tasks
  177. .get_mut(task_id)
  178. .ok_or_else(|| format!("task not found: {task_id}"))?;
  179. task.status = status;
  180. task.updated_at = now_secs();
  181. Ok(())
  182. }
  183. pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
  184. let mut inner = self.inner.lock().expect("registry lock poisoned");
  185. let task = inner
  186. .tasks
  187. .get_mut(task_id)
  188. .ok_or_else(|| format!("task not found: {task_id}"))?;
  189. task.team_id = Some(team_id.to_owned());
  190. task.updated_at = now_secs();
  191. Ok(())
  192. }
  193. pub fn remove(&self, task_id: &str) -> Option<Task> {
  194. let mut inner = self.inner.lock().expect("registry lock poisoned");
  195. inner.tasks.remove(task_id)
  196. }
  197. #[must_use]
  198. pub fn len(&self) -> usize {
  199. let inner = self.inner.lock().expect("registry lock poisoned");
  200. inner.tasks.len()
  201. }
  202. #[must_use]
  203. pub fn is_empty(&self) -> bool {
  204. self.len() == 0
  205. }
  206. }
  #[cfg(test)]
  mod tests {
      use super::*;

      #[test]
      fn creates_and_retrieves_tasks() {
          let registry = TaskRegistry::new();
          let task = registry.create("Do something", Some("A test task"));
          assert_eq!(task.status, TaskStatus::Created);
          assert_eq!(task.prompt, "Do something");
          assert_eq!(task.description.as_deref(), Some("A test task"));
          assert_eq!(task.task_packet, None);
          let fetched = registry.get(&task.task_id).expect("task should exist");
          assert_eq!(fetched.task_id, task.task_id);
      }

      #[test]
      fn creates_task_from_packet() {
          let registry = TaskRegistry::new();
          let packet = TaskPacket {
              objective: "Ship task packet support".to_string(),
              scope: "runtime/task system".to_string(),
              repo: "claw-code-parity".to_string(),
              branch_policy: "origin/main only".to_string(),
              acceptance_tests: vec!["cargo test --workspace".to_string()],
              commit_policy: "single commit".to_string(),
              reporting_contract: "print commit sha".to_string(),
              escalation_policy: "manual escalation".to_string(),
          };
          let task = registry
              .create_from_packet(packet.clone())
              .expect("packet-backed task should be created");
          assert_eq!(task.prompt, packet.objective);
          assert_eq!(task.description.as_deref(), Some("runtime/task system"));
          assert_eq!(task.task_packet, Some(packet.clone()));
          let fetched = registry.get(&task.task_id).expect("task should exist");
          assert_eq!(fetched.task_packet, Some(packet));
      }

      #[test]
      fn lists_tasks_with_optional_filter() {
          let registry = TaskRegistry::new();
          let task_a = registry.create("Task A", None);
          let task_b = registry.create("Task B", None);
          registry
              .set_status(&task_b.task_id, TaskStatus::Running)
              .expect("set status should succeed");

          assert_eq!(registry.list(None).len(), 2);

          let running = registry.list(Some(TaskStatus::Running));
          assert_eq!(running.len(), 1);
          assert_eq!(running[0].task_id, task_b.task_id);

          let created = registry.list(Some(TaskStatus::Created));
          assert_eq!(created.len(), 1);
          assert_eq!(created[0].task_id, task_a.task_id);

          assert!(registry.list(Some(TaskStatus::Completed)).is_empty());
      }

      #[test]
      fn stops_running_task() {
          let registry = TaskRegistry::new();
          let task = registry.create("Stoppable", None);
          registry
              .set_status(&task.task_id, TaskStatus::Running)
              .expect("set status should succeed");
          let stopped = registry.stop(&task.task_id).expect("stop should succeed");
          assert_eq!(stopped.status, TaskStatus::Stopped);
          let stored = registry.get(&task.task_id).expect("task should exist");
          assert_eq!(stored.status, TaskStatus::Stopped);
          // Stopping again should fail because `Stopped` is terminal.
          assert!(registry.stop(&task.task_id).is_err());
      }

      #[test]
      fn updates_task_with_messages() {
          let registry = TaskRegistry::new();
          let task = registry.create("Messageable", None);
          let updated = registry
              .update(&task.task_id, "Here's more context")
              .expect("update should succeed");
          assert_eq!(updated.messages.len(), 1);
          assert_eq!(updated.messages[0].content, "Here's more context");
          assert_eq!(updated.messages[0].role, "user");

          // Further updates append in call order.
          let updated = registry
              .update(&task.task_id, "And a follow-up")
              .expect("update should succeed");
          let contents: Vec<&str> = updated
              .messages
              .iter()
              .map(|message| message.content.as_str())
              .collect();
          assert_eq!(contents, ["Here's more context", "And a follow-up"]);
      }

      #[test]
      fn appends_and_retrieves_output() {
          let registry = TaskRegistry::new();
          let task = registry.create("Output task", None);
          registry
              .append_output(&task.task_id, "line 1\n")
              .expect("append should succeed");
          registry
              .append_output(&task.task_id, "line 2\n")
              .expect("append should succeed");
          let output = registry.output(&task.task_id).expect("output should exist");
          assert_eq!(output, "line 1\nline 2\n");
      }

      #[test]
      fn assigns_team_and_removes_task() {
          let registry = TaskRegistry::new();
          let task = registry.create("Team task", None);
          registry
              .assign_team(&task.task_id, "team_abc")
              .expect("assign should succeed");
          let fetched = registry.get(&task.task_id).expect("task should exist");
          assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
          let removed = registry.remove(&task.task_id);
          assert!(removed.is_some());
          assert!(registry.get(&task.task_id).is_none());
          assert!(registry.is_empty());
      }

      #[test]
      fn rejects_operations_on_missing_task() {
          let registry = TaskRegistry::new();
          assert!(registry.get("nonexistent").is_none());
          assert!(registry.stop("nonexistent").is_err());
          assert!(registry.update("nonexistent", "msg").is_err());
          assert!(registry.output("nonexistent").is_err());
          assert!(registry.append_output("nonexistent", "data").is_err());
          assert!(registry
              .set_status("nonexistent", TaskStatus::Running)
              .is_err());
          assert!(registry.assign_team("nonexistent", "team").is_err());
      }

      #[test]
      fn task_status_display_all_variants() {
          let cases = [
              (TaskStatus::Created, "created"),
              (TaskStatus::Running, "running"),
              (TaskStatus::Completed, "completed"),
              (TaskStatus::Failed, "failed"),
              (TaskStatus::Stopped, "stopped"),
          ];
          for (status, expected) in cases {
              assert_eq!(status.to_string(), expected, "display of {status:?}");
          }
      }

      #[test]
      fn stop_rejects_terminal_states() {
          for terminal in [TaskStatus::Completed, TaskStatus::Failed, TaskStatus::Stopped] {
              // given
              let registry = TaskRegistry::new();
              let task = registry.create("finished", None);
              registry
                  .set_status(&task.task_id, terminal)
                  .expect("set status should succeed");
              // when
              let result = registry.stop(&task.task_id);
              // then
              let error = result.expect_err("terminal task should be rejected");
              assert!(error.contains("already in terminal state"));
              assert!(error.contains(&terminal.to_string()));
              // A rejected stop must leave the status untouched.
              let stored = registry.get(&task.task_id).expect("task should exist");
              assert_eq!(stored.status, terminal);
          }
      }

      #[test]
      fn stop_succeeds_from_created_state() {
          // given
          let registry = TaskRegistry::new();
          let task = registry.create("created task", None);
          // when
          let stopped = registry.stop(&task.task_id).expect("stop should succeed");
          // then
          assert_eq!(stopped.status, TaskStatus::Stopped);
          assert!(stopped.updated_at >= task.updated_at);
      }

      #[test]
      fn new_registry_is_empty() {
          // given
          let registry = TaskRegistry::new();
          // when
          let all_tasks = registry.list(None);
          // then
          assert!(registry.is_empty());
          assert_eq!(registry.len(), 0);
          assert!(all_tasks.is_empty());
      }

      #[test]
      fn create_without_description() {
          // given
          let registry = TaskRegistry::new();
          // when
          let task = registry.create("Do the thing", None);
          // then
          assert!(task.task_id.starts_with("task_"));
          assert_eq!(task.description, None);
          assert_eq!(task.task_packet, None);
          assert!(task.messages.is_empty());
          assert!(task.output.is_empty());
          assert_eq!(task.team_id, None);
      }

      #[test]
      fn assigns_distinct_ids_to_tasks_created_together() {
          // given
          let registry = TaskRegistry::new();
          // when
          let first = registry.create("first", None);
          let second = registry.create("second", None);
          // then
          assert_ne!(first.task_id, second.task_id);
          assert_eq!(registry.len(), 2);
      }

      #[test]
      fn cloned_registry_shares_state() {
          // given
          let registry = TaskRegistry::new();
          let handle = registry.clone();
          // when
          let task = handle.create("shared", None);
          // then
          assert!(registry.get(&task.task_id).is_some());
          assert_eq!(registry.len(), 1);
      }

      #[test]
      fn remove_nonexistent_returns_none() {
          // given
          let registry = TaskRegistry::new();
          // when
          let removed = registry.remove("missing");
          // then
          assert!(removed.is_none());
      }

      #[test]
      fn assign_team_rejects_missing_task() {
          // given
          let registry = TaskRegistry::new();
          // when
          let result = registry.assign_team("missing", "team_123");
          // then
          let error = result.expect_err("missing task should be rejected");
          assert_eq!(error, "task not found: missing");
      }
  }