|
@@ -0,0 +1,335 @@
|
|
|
|
|
+//! In-memory task registry for sub-agent task lifecycle management.
|
|
|
|
|
+//!
|
|
|
|
|
+//! Provides create, get, list, stop, update, and output operations
|
|
|
|
|
+//! matching the upstream TaskCreate/TaskGet/TaskList/TaskStop/TaskUpdate/TaskOutput
|
|
|
|
|
+//! tool surface.
|
|
|
|
|
+
|
|
|
|
|
+use std::collections::HashMap;
|
|
|
|
|
+use std::sync::{Arc, Mutex};
|
|
|
|
|
+use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
|
+
|
|
|
|
|
+use serde::{Deserialize, Serialize};
|
|
|
|
|
+
|
|
|
|
|
+/// Current status of a managed task.
|
|
|
|
|
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
|
|
|
+#[serde(rename_all = "snake_case")]
|
|
|
|
|
+pub enum TaskStatus {
|
|
|
|
|
+ Created,
|
|
|
|
|
+ Running,
|
|
|
|
|
+ Completed,
|
|
|
|
|
+ Failed,
|
|
|
|
|
+ Stopped,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl std::fmt::Display for TaskStatus {
|
|
|
|
|
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Created => write!(f, "created"),
|
|
|
|
|
+ Self::Running => write!(f, "running"),
|
|
|
|
|
+ Self::Completed => write!(f, "completed"),
|
|
|
|
|
+ Self::Failed => write!(f, "failed"),
|
|
|
|
|
+ Self::Stopped => write!(f, "stopped"),
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// A single managed task entry.
|
|
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
+pub struct Task {
|
|
|
|
|
+ pub task_id: String,
|
|
|
|
|
+ pub prompt: String,
|
|
|
|
|
+ pub description: Option<String>,
|
|
|
|
|
+ pub status: TaskStatus,
|
|
|
|
|
+ pub created_at: u64,
|
|
|
|
|
+ pub updated_at: u64,
|
|
|
|
|
+ pub messages: Vec<TaskMessage>,
|
|
|
|
|
+ pub output: String,
|
|
|
|
|
+ pub team_id: Option<String>,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// A message exchanged with a running task.
|
|
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
|
|
+pub struct TaskMessage {
|
|
|
|
|
+ pub role: String,
|
|
|
|
|
+ pub content: String,
|
|
|
|
|
+ pub timestamp: u64,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Thread-safe task registry.
|
|
|
|
|
+#[derive(Debug, Clone, Default)]
|
|
|
|
|
+pub struct TaskRegistry {
|
|
|
|
|
+ inner: Arc<Mutex<RegistryInner>>,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+#[derive(Debug, Default)]
|
|
|
|
|
+struct RegistryInner {
|
|
|
|
|
+ tasks: HashMap<String, Task>,
|
|
|
|
|
+ counter: u64,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+fn now_secs() -> u64 {
|
|
|
|
|
+ SystemTime::now()
|
|
|
|
|
+ .duration_since(UNIX_EPOCH)
|
|
|
|
|
+ .unwrap_or_default()
|
|
|
|
|
+ .as_secs()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl TaskRegistry {
|
|
|
|
|
+ /// Create a new empty registry.
|
|
|
|
|
+ #[must_use]
|
|
|
|
|
+ pub fn new() -> Self {
|
|
|
|
|
+ Self::default()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Create a new task and return its ID.
|
|
|
|
|
+ pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ inner.counter += 1;
|
|
|
|
|
+ let ts = now_secs();
|
|
|
|
|
+ let task_id = format!("task_{:08x}_{}", ts, inner.counter);
|
|
|
|
|
+ let task = Task {
|
|
|
|
|
+ task_id: task_id.clone(),
|
|
|
|
|
+ prompt: prompt.to_owned(),
|
|
|
|
|
+ description: description.map(str::to_owned),
|
|
|
|
|
+ status: TaskStatus::Created,
|
|
|
|
|
+ created_at: ts,
|
|
|
|
|
+ updated_at: ts,
|
|
|
|
|
+ messages: Vec::new(),
|
|
|
|
|
+ output: String::new(),
|
|
|
|
|
+ team_id: None,
|
|
|
|
|
+ };
|
|
|
|
|
+ inner.tasks.insert(task_id, task.clone());
|
|
|
|
|
+ task
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Look up a task by ID.
|
|
|
|
|
+ pub fn get(&self, task_id: &str) -> Option<Task> {
|
|
|
|
|
+ let inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ inner.tasks.get(task_id).cloned()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// List all tasks, optionally filtered by status.
|
|
|
|
|
+ pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
|
|
|
|
|
+ let inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .values()
|
|
|
|
|
+ .filter(|t| status_filter.map_or(true, |s| t.status == s))
|
|
|
|
|
+ .cloned()
|
|
|
|
|
+ .collect()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Mark a task as stopped.
|
|
|
|
|
+ pub fn stop(&self, task_id: &str) -> Result<Task, String> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get_mut(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+
|
|
|
|
|
+ match task.status {
|
|
|
|
|
+ TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
|
|
|
|
|
+ return Err(format!(
|
|
|
|
|
+ "task {task_id} is already in terminal state: {}",
|
|
|
|
|
+ task.status
|
|
|
|
|
+ ));
|
|
|
|
|
+ }
|
|
|
|
|
+ _ => {}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ task.status = TaskStatus::Stopped;
|
|
|
|
|
+ task.updated_at = now_secs();
|
|
|
|
|
+ Ok(task.clone())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Send a message to a task, updating its state.
|
|
|
|
|
+ pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get_mut(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+
|
|
|
|
|
+ task.messages.push(TaskMessage {
|
|
|
|
|
+ role: String::from("user"),
|
|
|
|
|
+ content: message.to_owned(),
|
|
|
|
|
+ timestamp: now_secs(),
|
|
|
|
|
+ });
|
|
|
|
|
+ task.updated_at = now_secs();
|
|
|
|
|
+ Ok(task.clone())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Get the accumulated output of a task.
|
|
|
|
|
+ pub fn output(&self, task_id: &str) -> Result<String, String> {
|
|
|
|
|
+ let inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+ Ok(task.output.clone())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Append output to a task (used by the task executor).
|
|
|
|
|
+ pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get_mut(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+ task.output.push_str(output);
|
|
|
|
|
+ task.updated_at = now_secs();
|
|
|
|
|
+ Ok(())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Transition a task to a new status.
|
|
|
|
|
+ pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get_mut(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+ task.status = status;
|
|
|
|
|
+ task.updated_at = now_secs();
|
|
|
|
|
+ Ok(())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Assign a task to a team.
|
|
|
|
|
+ pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ let task = inner
|
|
|
|
|
+ .tasks
|
|
|
|
|
+ .get_mut(task_id)
|
|
|
|
|
+ .ok_or_else(|| format!("task not found: {task_id}"))?;
|
|
|
|
|
+ task.team_id = Some(team_id.to_owned());
|
|
|
|
|
+ task.updated_at = now_secs();
|
|
|
|
|
+ Ok(())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Remove a task from the registry.
|
|
|
|
|
+ pub fn remove(&self, task_id: &str) -> Option<Task> {
|
|
|
|
|
+ let mut inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ inner.tasks.remove(task_id)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Number of tasks in the registry.
|
|
|
|
|
+ #[must_use]
|
|
|
|
|
+ pub fn len(&self) -> usize {
|
|
|
|
|
+ let inner = self.inner.lock().expect("registry lock poisoned");
|
|
|
|
|
+ inner.tasks.len()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Whether the registry has no tasks.
|
|
|
|
|
+ #[must_use]
|
|
|
|
|
+ pub fn is_empty(&self) -> bool {
|
|
|
|
|
+ self.len() == 0
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+#[cfg(test)]
|
|
|
|
|
+mod tests {
|
|
|
|
|
+ use super::*;
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn creates_and_retrieves_tasks() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ let task = registry.create("Do something", Some("A test task"));
|
|
|
|
|
+ assert_eq!(task.status, TaskStatus::Created);
|
|
|
|
|
+ assert_eq!(task.prompt, "Do something");
|
|
|
|
|
+ assert_eq!(task.description.as_deref(), Some("A test task"));
|
|
|
|
|
+
|
|
|
|
|
+ let fetched = registry.get(&task.task_id).expect("task should exist");
|
|
|
|
|
+ assert_eq!(fetched.task_id, task.task_id);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn lists_tasks_with_optional_filter() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ registry.create("Task A", None);
|
|
|
|
|
+ let task_b = registry.create("Task B", None);
|
|
|
|
|
+ registry
|
|
|
|
|
+ .set_status(&task_b.task_id, TaskStatus::Running)
|
|
|
|
|
+ .expect("set status should succeed");
|
|
|
|
|
+
|
|
|
|
|
+ let all = registry.list(None);
|
|
|
|
|
+ assert_eq!(all.len(), 2);
|
|
|
|
|
+
|
|
|
|
|
+ let running = registry.list(Some(TaskStatus::Running));
|
|
|
|
|
+ assert_eq!(running.len(), 1);
|
|
|
|
|
+ assert_eq!(running[0].task_id, task_b.task_id);
|
|
|
|
|
+
|
|
|
|
|
+ let created = registry.list(Some(TaskStatus::Created));
|
|
|
|
|
+ assert_eq!(created.len(), 1);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn stops_running_task() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ let task = registry.create("Stoppable", None);
|
|
|
|
|
+ registry
|
|
|
|
|
+ .set_status(&task.task_id, TaskStatus::Running)
|
|
|
|
|
+ .unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
|
|
|
|
+ assert_eq!(stopped.status, TaskStatus::Stopped);
|
|
|
|
|
+
|
|
|
|
|
+ // Stopping again should fail
|
|
|
|
|
+ let result = registry.stop(&task.task_id);
|
|
|
|
|
+ assert!(result.is_err());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn updates_task_with_messages() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ let task = registry.create("Messageable", None);
|
|
|
|
|
+ let updated = registry
|
|
|
|
|
+ .update(&task.task_id, "Here's more context")
|
|
|
|
|
+ .expect("update should succeed");
|
|
|
|
|
+ assert_eq!(updated.messages.len(), 1);
|
|
|
|
|
+ assert_eq!(updated.messages[0].content, "Here's more context");
|
|
|
|
|
+ assert_eq!(updated.messages[0].role, "user");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn appends_and_retrieves_output() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ let task = registry.create("Output task", None);
|
|
|
|
|
+ registry
|
|
|
|
|
+ .append_output(&task.task_id, "line 1\n")
|
|
|
|
|
+ .expect("append should succeed");
|
|
|
|
|
+ registry
|
|
|
|
|
+ .append_output(&task.task_id, "line 2\n")
|
|
|
|
|
+ .expect("append should succeed");
|
|
|
|
|
+
|
|
|
|
|
+ let output = registry.output(&task.task_id).expect("output should exist");
|
|
|
|
|
+ assert_eq!(output, "line 1\nline 2\n");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn assigns_team_and_removes_task() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ let task = registry.create("Team task", None);
|
|
|
|
|
+ registry
|
|
|
|
|
+ .assign_team(&task.task_id, "team_abc")
|
|
|
|
|
+ .expect("assign should succeed");
|
|
|
|
|
+
|
|
|
|
|
+ let fetched = registry.get(&task.task_id).unwrap();
|
|
|
|
|
+ assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
|
|
|
|
|
+
|
|
|
|
|
+ let removed = registry.remove(&task.task_id);
|
|
|
|
|
+ assert!(removed.is_some());
|
|
|
|
|
+ assert!(registry.get(&task.task_id).is_none());
|
|
|
|
|
+ assert!(registry.is_empty());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn rejects_operations_on_missing_task() {
|
|
|
|
|
+ let registry = TaskRegistry::new();
|
|
|
|
|
+ assert!(registry.stop("nonexistent").is_err());
|
|
|
|
|
+ assert!(registry.update("nonexistent", "msg").is_err());
|
|
|
|
|
+ assert!(registry.output("nonexistent").is_err());
|
|
|
|
|
+ assert!(registry.append_output("nonexistent", "data").is_err());
|
|
|
|
|
+ assert!(registry
|
|
|
|
|
+ .set_status("nonexistent", TaskStatus::Running)
|
|
|
|
|
+ .is_err());
|
|
|
|
|
+ }
|
|
|
|
|
+}
|