(Hopefully) productionized shepherdd
This commit is contained in:
parent
f3e62c43ea
commit
ac2d2abfed
55 changed files with 7418 additions and 1353 deletions
1178
Cargo.lock
generated
1178
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
69
Cargo.toml
69
Cargo.toml
|
|
@ -1,12 +1,63 @@
|
|||
[package]
|
||||
name = "shepherd-launcher"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/shepherd-util",
|
||||
"crates/shepherd-api",
|
||||
"crates/shepherd-host-api",
|
||||
"crates/shepherd-config",
|
||||
"crates/shepherd-store",
|
||||
"crates/shepherd-core",
|
||||
"crates/shepherd-host-linux",
|
||||
"crates/shepherd-ipc",
|
||||
"crates/shepherdd",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
smithay-client-toolkit = "0.19"
|
||||
wayland-client = "0.31"
|
||||
cairo-rs = { version = "0.20", features = ["v1_16"] }
|
||||
chrono = "0.4"
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
authors = ["Shepherd Contributors"]
|
||||
repository = "https://github.com/shepherd-project/shepherdd"
|
||||
|
||||
[workspace.dependencies]
|
||||
# Internal crates
|
||||
shepherd-util = { path = "crates/shepherd-util" }
|
||||
shepherd-api = { path = "crates/shepherd-api" }
|
||||
shepherd-host-api = { path = "crates/shepherd-host-api" }
|
||||
shepherd-config = { path = "crates/shepherd-config" }
|
||||
shepherd-store = { path = "crates/shepherd-store" }
|
||||
shepherd-core = { path = "crates/shepherd-core" }
|
||||
shepherd-host-linux = { path = "crates/shepherd-host-linux" }
|
||||
shepherd-ipc = { path = "crates/shepherd-ipc" }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
toml = "0.8"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.35", features = ["full", "signal"] }
|
||||
|
||||
# Database
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
|
||||
# Time
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
|
||||
# Logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
|
||||
# Error handling
|
||||
thiserror = "1.0"
|
||||
anyhow = "1.0"
|
||||
|
||||
# Utilities
|
||||
uuid = { version = "1.6", features = ["v4", "serde"] }
|
||||
bitflags = "2.4"
|
||||
|
||||
# Unix-specific
|
||||
nix = { version = "0.29", features = ["signal", "process", "user", "socket"] }
|
||||
|
||||
# Testing
|
||||
tempfile = "3.9"
|
||||
|
|
|
|||
138
config.example.toml
Normal file
138
config.example.toml
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
# Sample shepherdd configuration
|
||||
# This file defines the policy for what applications can run and when
|
||||
|
||||
config_version = 1
|
||||
|
||||
[daemon]
|
||||
# Uncomment to customize paths
|
||||
# socket_path = "/run/shepherdd/shepherdd.sock"
|
||||
# log_dir = "/var/log/shepherdd"
|
||||
# data_dir = "/var/lib/shepherdd"
|
||||
|
||||
# Default max run duration if not specified per entry (1 hour)
|
||||
default_max_run_seconds = 3600
|
||||
|
||||
# Default warning thresholds
|
||||
[[daemon.default_warnings]]
|
||||
seconds_before = 300
|
||||
severity = "info"
|
||||
message = "5 minutes remaining"
|
||||
|
||||
[[daemon.default_warnings]]
|
||||
seconds_before = 60
|
||||
severity = "warn"
|
||||
message = "1 minute remaining!"
|
||||
|
||||
[[daemon.default_warnings]]
|
||||
seconds_before = 10
|
||||
severity = "critical"
|
||||
message = "10 seconds remaining!"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Entries
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Example: ScummVM - classic adventure games
|
||||
[[entries]]
|
||||
id = "scummvm"
|
||||
label = "ScummVM"
|
||||
icon = "scummvm"
|
||||
|
||||
[entries.kind]
|
||||
type = "process"
|
||||
argv = ["scummvm", "-f"]
|
||||
|
||||
[entries.availability]
|
||||
[[entries.availability.windows]]
|
||||
days = "weekdays"
|
||||
start = "15:00"
|
||||
end = "18:00"
|
||||
|
||||
[[entries.availability.windows]]
|
||||
days = "weekends"
|
||||
start = "10:00"
|
||||
end = "20:00"
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 3600 # 1 hour max
|
||||
daily_quota_seconds = 7200 # 2 hours per day
|
||||
cooldown_seconds = 300 # 5 minute cooldown after each session
|
||||
|
||||
# Example: Minecraft (via Prism Launcher)
|
||||
[[entries]]
|
||||
id = "minecraft"
|
||||
label = "Minecraft"
|
||||
icon = "prismlauncher"
|
||||
|
||||
[entries.kind]
|
||||
type = "process"
|
||||
argv = ["prismlauncher"]
|
||||
|
||||
[entries.availability]
|
||||
always = true # No time restrictions
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 5400 # 90 minutes
|
||||
daily_quota_seconds = 10800 # 3 hours per day
|
||||
|
||||
[[entries.warnings]]
|
||||
seconds_before = 600
|
||||
severity = "info"
|
||||
message = "10 minutes left - start wrapping up!"
|
||||
|
||||
[[entries.warnings]]
|
||||
seconds_before = 120
|
||||
severity = "warn"
|
||||
message = "2 minutes remaining - save your game!"
|
||||
|
||||
[[entries.warnings]]
|
||||
seconds_before = 30
|
||||
severity = "critical"
|
||||
message = "30 seconds! Save NOW!"
|
||||
|
||||
# Example: Educational game - unrestricted
|
||||
[[entries]]
|
||||
id = "tuxmath"
|
||||
label = "Tux Math"
|
||||
icon = "tuxmath"
|
||||
|
||||
[entries.kind]
|
||||
type = "process"
|
||||
argv = ["tuxmath"]
|
||||
|
||||
[entries.availability]
|
||||
always = true
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 7200 # 2 hours
|
||||
|
||||
# Example: Web browser (restricted)
|
||||
[[entries]]
|
||||
id = "browser"
|
||||
label = "Web Browser"
|
||||
icon = "firefox"
|
||||
|
||||
[entries.kind]
|
||||
type = "process"
|
||||
argv = ["firefox", "-P", "kids"]
|
||||
|
||||
[entries.availability]
|
||||
[[entries.availability.windows]]
|
||||
days = ["sat", "sun"]
|
||||
start = "14:00"
|
||||
end = "17:00"
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 1800 # 30 minutes
|
||||
daily_quota_seconds = 3600 # 1 hour per day
|
||||
|
||||
# Example: Disabled entry
|
||||
[[entries]]
|
||||
id = "disabled-game"
|
||||
label = "Game Under Maintenance"
|
||||
disabled = true
|
||||
disabled_reason = "This game is being updated"
|
||||
|
||||
[entries.kind]
|
||||
type = "process"
|
||||
argv = ["/bin/false"]
|
||||
13
crates/shepherd-api/Cargo.toml
Normal file
13
crates/shepherd-api/Cargo.toml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "shepherd-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Protocol types for shepherdd IPC: commands, events, versioning"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
223
crates/shepherd-api/src/commands.rs
Normal file
223
crates/shepherd-api/src/commands.rs
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
//! Command types for the shepherdd protocol
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_util::{ClientId, EntryId};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{ClientRole, StopMode, API_VERSION};
|
||||
|
||||
/// Request wrapper with metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
/// Request ID for correlation
|
||||
pub request_id: u64,
|
||||
/// API version
|
||||
pub api_version: u32,
|
||||
/// The command
|
||||
pub command: Command,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
pub fn new(request_id: u64, command: Command) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
api_version: API_VERSION,
|
||||
command,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response wrapper
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Response {
|
||||
/// Corresponding request ID
|
||||
pub request_id: u64,
|
||||
/// API version
|
||||
pub api_version: u32,
|
||||
/// Response payload or error
|
||||
pub result: ResponseResult,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn success(request_id: u64, payload: ResponsePayload) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
api_version: API_VERSION,
|
||||
result: ResponseResult::Ok(payload),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(request_id: u64, error: ErrorInfo) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
api_version: API_VERSION,
|
||||
result: ResponseResult::Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseResult {
|
||||
Ok(ResponsePayload),
|
||||
Err(ErrorInfo),
|
||||
}
|
||||
|
||||
/// Error information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorInfo {
|
||||
pub code: ErrorCode,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl ErrorInfo {
|
||||
pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error codes for the protocol
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ErrorCode {
|
||||
InvalidRequest,
|
||||
EntryNotFound,
|
||||
LaunchDenied,
|
||||
NoActiveSession,
|
||||
SessionActive,
|
||||
PermissionDenied,
|
||||
RateLimited,
|
||||
ConfigError,
|
||||
HostError,
|
||||
InternalError,
|
||||
}
|
||||
|
||||
/// All possible commands from clients
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum Command {
|
||||
/// Get current daemon state
|
||||
GetState,
|
||||
|
||||
/// List available entries
|
||||
ListEntries {
|
||||
/// Optional: evaluate at a specific time (for preview)
|
||||
at_time: Option<DateTime<Local>>,
|
||||
},
|
||||
|
||||
/// Request to launch an entry
|
||||
Launch { entry_id: EntryId },
|
||||
|
||||
/// Stop the current session
|
||||
StopCurrent { mode: StopMode },
|
||||
|
||||
/// Reload configuration
|
||||
ReloadConfig,
|
||||
|
||||
/// Subscribe to events (returns immediately, events stream separately)
|
||||
SubscribeEvents,
|
||||
|
||||
/// Unsubscribe from events
|
||||
UnsubscribeEvents,
|
||||
|
||||
/// Get health status
|
||||
GetHealth,
|
||||
|
||||
// Admin commands
|
||||
|
||||
/// Extend the current session (admin only)
|
||||
ExtendCurrent { by: Duration },
|
||||
|
||||
/// Ping for keepalive
|
||||
Ping,
|
||||
}
|
||||
|
||||
/// Response payloads
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponsePayload {
|
||||
State(crate::DaemonStateSnapshot),
|
||||
Entries(Vec<crate::EntryView>),
|
||||
LaunchApproved {
|
||||
session_id: shepherd_util::SessionId,
|
||||
deadline: DateTime<Local>,
|
||||
},
|
||||
LaunchDenied {
|
||||
reasons: Vec<crate::ReasonCode>,
|
||||
},
|
||||
Stopped,
|
||||
ConfigReloaded,
|
||||
Subscribed {
|
||||
client_id: ClientId,
|
||||
},
|
||||
Unsubscribed,
|
||||
Health(crate::HealthStatus),
|
||||
Extended {
|
||||
new_deadline: DateTime<Local>,
|
||||
},
|
||||
Pong,
|
||||
}
|
||||
|
||||
/// Client connection info (set by IPC layer)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
pub client_id: ClientId,
|
||||
pub role: ClientRole,
|
||||
/// Unix UID if available
|
||||
pub uid: Option<u32>,
|
||||
/// Process name if available
|
||||
pub process_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientInfo {
|
||||
pub fn new(role: ClientRole) -> Self {
|
||||
Self {
|
||||
client_id: ClientId::new(),
|
||||
role,
|
||||
uid: None,
|
||||
process_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_uid(mut self, uid: u32) -> Self {
|
||||
self.uid = Some(uid);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_serialization() {
|
||||
let req = Request::new(1, Command::GetState);
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
let parsed: Request = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.request_id, 1);
|
||||
assert!(matches!(parsed.command, Command::GetState));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_serialization() {
|
||||
let resp = Response::success(
|
||||
1,
|
||||
ResponsePayload::State(crate::DaemonStateSnapshot {
|
||||
api_version: API_VERSION,
|
||||
policy_loaded: true,
|
||||
current_session: None,
|
||||
entry_count: 5,
|
||||
}),
|
||||
);
|
||||
|
||||
let json = serde_json::to_string(&resp).unwrap();
|
||||
let parsed: Response = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.request_id, 1);
|
||||
}
|
||||
}
|
||||
105
crates/shepherd-api/src/events.rs
Normal file
105
crates/shepherd-api/src/events.rs
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
//! Event types for shepherdd -> client streaming
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_util::{EntryId, SessionId};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{DaemonStateSnapshot, SessionEndReason, WarningSeverity, API_VERSION};
|
||||
|
||||
/// Event envelope
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Event {
|
||||
pub api_version: u32,
|
||||
pub timestamp: DateTime<Local>,
|
||||
pub payload: EventPayload,
|
||||
}
|
||||
|
||||
impl Event {
|
||||
pub fn new(payload: EventPayload) -> Self {
|
||||
Self {
|
||||
api_version: API_VERSION,
|
||||
timestamp: Local::now(),
|
||||
payload,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible events from daemon to clients
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum EventPayload {
|
||||
/// Full state snapshot (sent on subscribe and major changes)
|
||||
StateChanged(DaemonStateSnapshot),
|
||||
|
||||
/// Session has started
|
||||
SessionStarted {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
label: String,
|
||||
deadline: DateTime<Local>,
|
||||
},
|
||||
|
||||
/// Warning issued for current session
|
||||
WarningIssued {
|
||||
session_id: SessionId,
|
||||
threshold_seconds: u64,
|
||||
time_remaining: Duration,
|
||||
severity: WarningSeverity,
|
||||
message: Option<String>,
|
||||
},
|
||||
|
||||
/// Session is expiring (termination initiated)
|
||||
SessionExpiring {
|
||||
session_id: SessionId,
|
||||
},
|
||||
|
||||
/// Session has ended
|
||||
SessionEnded {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
reason: SessionEndReason,
|
||||
duration: Duration,
|
||||
},
|
||||
|
||||
/// Policy was reloaded
|
||||
PolicyReloaded {
|
||||
entry_count: usize,
|
||||
},
|
||||
|
||||
/// Entry availability changed (for UI updates)
|
||||
EntryAvailabilityChanged {
|
||||
entry_id: EntryId,
|
||||
enabled: bool,
|
||||
},
|
||||
|
||||
/// Daemon is shutting down
|
||||
Shutdown,
|
||||
|
||||
/// Audit event (for admin clients)
|
||||
AuditEntry {
|
||||
event_type: String,
|
||||
details: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn event_serialization() {
|
||||
let event = Event::new(EventPayload::SessionStarted {
|
||||
session_id: SessionId::new(),
|
||||
entry_id: EntryId::new("game-1"),
|
||||
label: "Test Game".into(),
|
||||
deadline: Local::now(),
|
||||
});
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
let parsed: Event = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.api_version, API_VERSION);
|
||||
assert!(matches!(parsed.payload, EventPayload::SessionStarted { .. }));
|
||||
}
|
||||
}
|
||||
18
crates/shepherd-api/src/lib.rs
Normal file
18
crates/shepherd-api/src/lib.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
//! Protocol types for shepherdd IPC
|
||||
//!
|
||||
//! This crate defines the stable API between the daemon and clients:
|
||||
//! - Commands (requests from clients)
|
||||
//! - Responses
|
||||
//! - Events (daemon -> clients)
|
||||
//! - Versioning
|
||||
|
||||
mod commands;
|
||||
mod events;
|
||||
mod types;
|
||||
|
||||
pub use commands::*;
|
||||
pub use events::*;
|
||||
pub use types::*;
|
||||
|
||||
/// Current API version
|
||||
pub const API_VERSION: u32 = 1;
|
||||
252
crates/shepherd-api/src/types.rs
Normal file
252
crates/shepherd-api/src/types.rs
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
//! Shared types for the shepherdd API
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_util::{EntryId, SessionId};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Entry kind tag for capability matching
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EntryKindTag {
|
||||
Process,
|
||||
Vm,
|
||||
Media,
|
||||
Custom,
|
||||
}
|
||||
|
||||
/// Entry kind with launch details
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum EntryKind {
|
||||
Process {
|
||||
argv: Vec<String>,
|
||||
#[serde(default)]
|
||||
env: HashMap<String, String>,
|
||||
cwd: Option<PathBuf>,
|
||||
},
|
||||
Vm {
|
||||
driver: String,
|
||||
#[serde(default)]
|
||||
args: HashMap<String, serde_json::Value>,
|
||||
},
|
||||
Media {
|
||||
library_id: String,
|
||||
#[serde(default)]
|
||||
args: HashMap<String, serde_json::Value>,
|
||||
},
|
||||
Custom {
|
||||
type_name: String,
|
||||
payload: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
impl EntryKind {
|
||||
pub fn tag(&self) -> EntryKindTag {
|
||||
match self {
|
||||
EntryKind::Process { .. } => EntryKindTag::Process,
|
||||
EntryKind::Vm { .. } => EntryKindTag::Vm,
|
||||
EntryKind::Media { .. } => EntryKindTag::Media,
|
||||
EntryKind::Custom { .. } => EntryKindTag::Custom,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// View of an entry for UI display
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EntryView {
|
||||
pub entry_id: EntryId,
|
||||
pub label: String,
|
||||
pub icon_ref: Option<String>,
|
||||
pub kind_tag: EntryKindTag,
|
||||
pub enabled: bool,
|
||||
pub reasons: Vec<ReasonCode>,
|
||||
/// If enabled, maximum run duration if started now
|
||||
pub max_run_if_started_now: Option<Duration>,
|
||||
}
|
||||
|
||||
/// Structured reason codes for why an entry is unavailable
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "code", rename_all = "snake_case")]
|
||||
pub enum ReasonCode {
|
||||
/// Outside allowed time window
|
||||
OutsideTimeWindow {
|
||||
/// When the next window opens (if known)
|
||||
next_window_start: Option<DateTime<Local>>,
|
||||
},
|
||||
/// Daily quota exhausted
|
||||
QuotaExhausted {
|
||||
used: Duration,
|
||||
quota: Duration,
|
||||
},
|
||||
/// Cooldown period active
|
||||
CooldownActive {
|
||||
available_at: DateTime<Local>,
|
||||
},
|
||||
/// Another session is active
|
||||
SessionActive {
|
||||
entry_id: EntryId,
|
||||
remaining: Duration,
|
||||
},
|
||||
/// Host doesn't support this entry kind
|
||||
UnsupportedKind {
|
||||
kind: EntryKindTag,
|
||||
},
|
||||
/// Entry is explicitly disabled
|
||||
Disabled {
|
||||
reason: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Warning severity level
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum WarningSeverity {
|
||||
Info,
|
||||
Warn,
|
||||
Critical,
|
||||
}
|
||||
|
||||
/// Warning threshold configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WarningThreshold {
|
||||
/// Seconds before expiry to issue this warning
|
||||
pub seconds_before: u64,
|
||||
pub severity: WarningSeverity,
|
||||
pub message_template: Option<String>,
|
||||
}
|
||||
|
||||
/// Session end reason
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SessionEndReason {
|
||||
/// Session expired (time limit reached)
|
||||
Expired,
|
||||
/// User requested stop
|
||||
UserStop,
|
||||
/// Admin requested stop
|
||||
AdminStop,
|
||||
/// Process exited on its own
|
||||
ProcessExited { exit_code: Option<i32> },
|
||||
/// Policy change terminated session
|
||||
PolicyStop,
|
||||
/// Daemon shutdown
|
||||
DaemonShutdown,
|
||||
/// Launch failed
|
||||
LaunchFailed { error: String },
|
||||
}
|
||||
|
||||
/// Current session state
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SessionState {
|
||||
Launching,
|
||||
Running,
|
||||
Warned,
|
||||
Expiring,
|
||||
Ended,
|
||||
}
|
||||
|
||||
/// Active session information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionInfo {
|
||||
pub session_id: SessionId,
|
||||
pub entry_id: EntryId,
|
||||
pub label: String,
|
||||
pub state: SessionState,
|
||||
pub started_at: DateTime<Local>,
|
||||
pub deadline: DateTime<Local>,
|
||||
pub time_remaining: Duration,
|
||||
pub warnings_issued: Vec<u64>,
|
||||
}
|
||||
|
||||
/// Full daemon state snapshot
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DaemonStateSnapshot {
|
||||
pub api_version: u32,
|
||||
pub policy_loaded: bool,
|
||||
pub current_session: Option<SessionInfo>,
|
||||
pub entry_count: usize,
|
||||
}
|
||||
|
||||
/// Role for authorization
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ClientRole {
|
||||
/// UI/HUD - can view state, launch entries, stop current
|
||||
Shell,
|
||||
/// Local admin - can also extend, reload config
|
||||
Admin,
|
||||
/// Read-only observer
|
||||
Observer,
|
||||
}
|
||||
|
||||
impl ClientRole {
|
||||
pub fn can_launch(&self) -> bool {
|
||||
matches!(self, ClientRole::Shell | ClientRole::Admin)
|
||||
}
|
||||
|
||||
pub fn can_stop(&self) -> bool {
|
||||
matches!(self, ClientRole::Shell | ClientRole::Admin)
|
||||
}
|
||||
|
||||
pub fn can_extend(&self) -> bool {
|
||||
matches!(self, ClientRole::Admin)
|
||||
}
|
||||
|
||||
pub fn can_reload_config(&self) -> bool {
|
||||
matches!(self, ClientRole::Admin)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop mode for session termination
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopMode {
|
||||
/// Try graceful termination first
|
||||
Graceful,
|
||||
/// Force immediate termination
|
||||
Force,
|
||||
}
|
||||
|
||||
/// Health status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthStatus {
|
||||
pub live: bool,
|
||||
pub ready: bool,
|
||||
pub policy_loaded: bool,
|
||||
pub host_adapter_ok: bool,
|
||||
pub store_ok: bool,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn entry_kind_serialization() {
|
||||
let kind = EntryKind::Process {
|
||||
argv: vec!["scummvm".into(), "-f".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&kind).unwrap();
|
||||
let parsed: EntryKind = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(kind, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reason_code_serialization() {
|
||||
let reason = ReasonCode::QuotaExhausted {
|
||||
used: Duration::from_secs(3600),
|
||||
quota: Duration::from_secs(3600),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&reason).unwrap();
|
||||
assert!(json.contains("quota_exhausted"));
|
||||
}
|
||||
}
|
||||
19
crates/shepherd-config/Cargo.toml
Normal file
19
crates/shepherd-config/Cargo.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "shepherd-config"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Configuration parsing and validation for shepherdd"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
100
crates/shepherd-config/src/lib.rs
Normal file
100
crates/shepherd-config/src/lib.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
//! Configuration parsing and validation for shepherdd
|
||||
//!
|
||||
//! Supports TOML configuration with:
|
||||
//! - Versioned schema
|
||||
//! - Entry definitions with availability policies
|
||||
//! - Time windows, limits, and warnings
|
||||
//! - Validation with clear error messages
|
||||
|
||||
mod policy;
|
||||
mod schema;
|
||||
mod validation;
|
||||
|
||||
pub use policy::*;
|
||||
pub use schema::*;
|
||||
pub use validation::*;
|
||||
|
||||
use std::path::Path;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Configuration errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigError {
|
||||
#[error("Failed to read config file: {0}")]
|
||||
ReadError(#[from] std::io::Error),
|
||||
|
||||
#[error("Failed to parse TOML: {0}")]
|
||||
ParseError(#[from] toml::de::Error),
|
||||
|
||||
#[error("Validation failed: {errors:?}")]
|
||||
ValidationFailed { errors: Vec<ValidationError> },
|
||||
|
||||
#[error("Unsupported config version: {0}")]
|
||||
UnsupportedVersion(u32),
|
||||
}
|
||||
|
||||
pub type ConfigResult<T> = Result<T, ConfigError>;
|
||||
|
||||
/// Load and validate configuration from a TOML file
|
||||
pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<Policy> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
parse_config(&content)
|
||||
}
|
||||
|
||||
/// Parse and validate configuration from a TOML string
|
||||
pub fn parse_config(content: &str) -> ConfigResult<Policy> {
|
||||
let raw: RawConfig = toml::from_str(content)?;
|
||||
|
||||
// Check version
|
||||
if raw.config_version != CURRENT_CONFIG_VERSION {
|
||||
return Err(ConfigError::UnsupportedVersion(raw.config_version));
|
||||
}
|
||||
|
||||
// Validate
|
||||
let errors = validate_config(&raw);
|
||||
if !errors.is_empty() {
|
||||
return Err(ConfigError::ValidationFailed { errors });
|
||||
}
|
||||
|
||||
// Convert to policy
|
||||
Ok(Policy::from_raw(raw))
|
||||
}
|
||||
|
||||
/// Current supported config version
|
||||
pub const CURRENT_CONFIG_VERSION: u32 = 1;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_minimal_config() {
|
||||
let config = r#"
|
||||
config_version = 1
|
||||
|
||||
[[entries]]
|
||||
id = "test-game"
|
||||
label = "Test Game"
|
||||
kind = { type = "process", argv = ["/usr/bin/game"] }
|
||||
"#;
|
||||
|
||||
let policy = parse_config(config).unwrap();
|
||||
assert_eq!(policy.entries.len(), 1);
|
||||
assert_eq!(policy.entries[0].id.as_str(), "test-game");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reject_wrong_version() {
|
||||
let config = r#"
|
||||
config_version = 99
|
||||
|
||||
[[entries]]
|
||||
id = "test"
|
||||
label = "Test"
|
||||
kind = { type = "process", argv = ["/bin/test"] }
|
||||
"#;
|
||||
|
||||
let result = parse_config(config);
|
||||
assert!(matches!(result, Err(ConfigError::UnsupportedVersion(99))));
|
||||
}
|
||||
}
|
||||
304
crates/shepherd-config/src/policy.rs
Normal file
304
crates/shepherd-config/src/policy.rs
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
//! Validated policy structures
|
||||
|
||||
use crate::schema::{RawConfig, RawDays, RawEntry, RawEntryKind, RawWarningThreshold};
|
||||
use crate::validation::{parse_days, parse_time};
|
||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Validated policy ready for use by the core engine
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Policy {
|
||||
/// Daemon configuration
|
||||
pub daemon: DaemonConfig,
|
||||
|
||||
/// Validated entries
|
||||
pub entries: Vec<Entry>,
|
||||
|
||||
/// Default warning thresholds
|
||||
pub default_warnings: Vec<WarningThreshold>,
|
||||
|
||||
/// Default max run duration
|
||||
pub default_max_run: Duration,
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
/// Convert from raw config (after validation)
|
||||
pub fn from_raw(raw: RawConfig) -> Self {
|
||||
let default_warnings = raw
|
||||
.daemon
|
||||
.default_warnings
|
||||
.clone()
|
||||
.map(|w| w.into_iter().map(convert_warning).collect())
|
||||
.unwrap_or_else(default_warning_thresholds);
|
||||
|
||||
let default_max_run = raw
|
||||
.daemon
|
||||
.default_max_run_seconds
|
||||
.map(Duration::from_secs)
|
||||
.unwrap_or(Duration::from_secs(3600)); // 1 hour default
|
||||
|
||||
let entries = raw
|
||||
.entries
|
||||
.into_iter()
|
||||
.map(|e| Entry::from_raw(e, &default_warnings, default_max_run))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
daemon: DaemonConfig::from_raw(raw.daemon),
|
||||
entries,
|
||||
default_warnings,
|
||||
default_max_run,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get entry by ID
|
||||
pub fn get_entry(&self, id: &EntryId) -> Option<&Entry> {
|
||||
self.entries.iter().find(|e| &e.id == id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Daemon configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DaemonConfig {
|
||||
pub socket_path: PathBuf,
|
||||
pub log_dir: PathBuf,
|
||||
pub data_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl DaemonConfig {
|
||||
fn from_raw(raw: crate::schema::RawDaemonConfig) -> Self {
|
||||
Self {
|
||||
socket_path: raw
|
||||
.socket_path
|
||||
.unwrap_or_else(|| PathBuf::from("/run/shepherdd/shepherdd.sock")),
|
||||
log_dir: raw
|
||||
.log_dir
|
||||
.unwrap_or_else(|| PathBuf::from("/var/log/shepherdd")),
|
||||
data_dir: raw
|
||||
.data_dir
|
||||
.unwrap_or_else(|| PathBuf::from("/var/lib/shepherdd")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DaemonConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
socket_path: PathBuf::from("/run/shepherdd/shepherdd.sock"),
|
||||
log_dir: PathBuf::from("/var/log/shepherdd"),
|
||||
data_dir: PathBuf::from("/var/lib/shepherdd"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validated entry definition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Entry {
|
||||
pub id: EntryId,
|
||||
pub label: String,
|
||||
pub icon_ref: Option<String>,
|
||||
pub kind: EntryKind,
|
||||
pub availability: AvailabilityPolicy,
|
||||
pub limits: LimitsPolicy,
|
||||
pub warnings: Vec<WarningThreshold>,
|
||||
pub disabled: bool,
|
||||
pub disabled_reason: Option<String>,
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
fn from_raw(
|
||||
raw: RawEntry,
|
||||
default_warnings: &[WarningThreshold],
|
||||
default_max_run: Duration,
|
||||
) -> Self {
|
||||
let kind = convert_entry_kind(raw.kind);
|
||||
let availability = raw
|
||||
.availability
|
||||
.map(convert_availability)
|
||||
.unwrap_or_default();
|
||||
let limits = raw
|
||||
.limits
|
||||
.map(|l| convert_limits(l, default_max_run))
|
||||
.unwrap_or_else(|| LimitsPolicy {
|
||||
max_run: default_max_run,
|
||||
daily_quota: None,
|
||||
cooldown: None,
|
||||
});
|
||||
let warnings = raw
|
||||
.warnings
|
||||
.map(|w| w.into_iter().map(convert_warning).collect())
|
||||
.unwrap_or_else(|| default_warnings.to_vec());
|
||||
|
||||
Self {
|
||||
id: EntryId::new(raw.id),
|
||||
label: raw.label,
|
||||
icon_ref: raw.icon,
|
||||
kind,
|
||||
availability,
|
||||
limits,
|
||||
warnings,
|
||||
disabled: raw.disabled,
|
||||
disabled_reason: raw.disabled_reason,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// When an entry is available
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AvailabilityPolicy {
|
||||
/// Time windows when entry is available
|
||||
pub windows: Vec<TimeWindow>,
|
||||
/// If true, always available (ignores windows)
|
||||
pub always: bool,
|
||||
}
|
||||
|
||||
impl AvailabilityPolicy {
|
||||
/// Check if available at given local time
|
||||
pub fn is_available(&self, dt: &chrono::DateTime<chrono::Local>) -> bool {
|
||||
if self.always {
|
||||
return true;
|
||||
}
|
||||
if self.windows.is_empty() {
|
||||
return true; // No windows = always available
|
||||
}
|
||||
self.windows.iter().any(|w| w.contains(dt))
|
||||
}
|
||||
|
||||
/// Get remaining time in current window
|
||||
pub fn remaining_in_window(
|
||||
&self,
|
||||
dt: &chrono::DateTime<chrono::Local>,
|
||||
) -> Option<Duration> {
|
||||
if self.always {
|
||||
return None; // No limit from windows
|
||||
}
|
||||
self.windows.iter().find_map(|w| w.remaining_duration(dt))
|
||||
}
|
||||
}
|
||||
|
||||
/// Time limits for an entry
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LimitsPolicy {
|
||||
pub max_run: Duration,
|
||||
pub daily_quota: Option<Duration>,
|
||||
pub cooldown: Option<Duration>,
|
||||
}
|
||||
|
||||
// Conversion helpers
|
||||
|
||||
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
|
||||
match raw {
|
||||
RawEntryKind::Process { argv, env, cwd } => EntryKind::Process { argv, env, cwd },
|
||||
RawEntryKind::Vm { driver, args } => EntryKind::Vm { driver, args },
|
||||
RawEntryKind::Media { library_id, args } => EntryKind::Media { library_id, args },
|
||||
RawEntryKind::Custom { type_name, payload } => EntryKind::Custom {
|
||||
type_name,
|
||||
payload: payload.unwrap_or(serde_json::Value::Null),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_availability(raw: crate::schema::RawAvailability) -> AvailabilityPolicy {
|
||||
let windows = raw.windows.into_iter().map(convert_time_window).collect();
|
||||
AvailabilityPolicy {
|
||||
windows,
|
||||
always: raw.always,
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_time_window(raw: crate::schema::RawTimeWindow) -> TimeWindow {
|
||||
let days_mask = parse_days(&raw.days).unwrap_or(0x7F);
|
||||
let (start_h, start_m) = parse_time(&raw.start).unwrap_or((0, 0));
|
||||
let (end_h, end_m) = parse_time(&raw.end).unwrap_or((23, 59));
|
||||
|
||||
TimeWindow {
|
||||
days: DaysOfWeek::new(days_mask),
|
||||
start: WallClock::new(start_h, start_m).unwrap(),
|
||||
end: WallClock::new(end_h, end_m).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_limits(raw: crate::schema::RawLimits, default_max_run: Duration) -> LimitsPolicy {
|
||||
LimitsPolicy {
|
||||
max_run: raw
|
||||
.max_run_seconds
|
||||
.map(Duration::from_secs)
|
||||
.unwrap_or(default_max_run),
|
||||
daily_quota: raw.daily_quota_seconds.map(Duration::from_secs),
|
||||
cooldown: raw.cooldown_seconds.map(Duration::from_secs),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_warning(raw: RawWarningThreshold) -> WarningThreshold {
|
||||
let severity = match raw.severity.to_lowercase().as_str() {
|
||||
"info" => WarningSeverity::Info,
|
||||
"critical" => WarningSeverity::Critical,
|
||||
_ => WarningSeverity::Warn,
|
||||
};
|
||||
|
||||
WarningThreshold {
|
||||
seconds_before: raw.seconds_before,
|
||||
severity,
|
||||
message_template: raw.message,
|
||||
}
|
||||
}
|
||||
|
||||
fn default_warning_thresholds() -> Vec<WarningThreshold> {
|
||||
vec![
|
||||
WarningThreshold {
|
||||
seconds_before: 300, // 5 minutes
|
||||
severity: WarningSeverity::Info,
|
||||
message_template: Some("5 minutes remaining".into()),
|
||||
},
|
||||
WarningThreshold {
|
||||
seconds_before: 60, // 1 minute
|
||||
severity: WarningSeverity::Warn,
|
||||
message_template: Some("1 minute remaining".into()),
|
||||
},
|
||||
WarningThreshold {
|
||||
seconds_before: 10,
|
||||
severity: WarningSeverity::Critical,
|
||||
message_template: Some("10 seconds remaining!".into()),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{Local, TimeZone};
|
||||
|
||||
#[test]
|
||||
fn test_availability_always() {
|
||||
let policy = AvailabilityPolicy {
|
||||
windows: vec![],
|
||||
always: true,
|
||||
};
|
||||
|
||||
let dt = Local::now();
|
||||
assert!(policy.is_available(&dt));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_availability_window() {
|
||||
let policy = AvailabilityPolicy {
|
||||
windows: vec![TimeWindow {
|
||||
days: DaysOfWeek::ALL_DAYS,
|
||||
start: WallClock::new(14, 0).unwrap(),
|
||||
end: WallClock::new(18, 0).unwrap(),
|
||||
}],
|
||||
always: false,
|
||||
};
|
||||
|
||||
// 3 PM should be available
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 26, 15, 0, 0).unwrap();
|
||||
assert!(policy.is_available(&dt));
|
||||
|
||||
// 10 AM should not be available
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 26, 10, 0, 0).unwrap();
|
||||
assert!(!policy.is_available(&dt));
|
||||
}
|
||||
}
|
||||
216
crates/shepherd-config/src/schema.rs
Normal file
216
crates/shepherd-config/src/schema.rs
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
//! Raw configuration schema (as parsed from TOML)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Raw configuration as parsed from TOML
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RawConfig {
|
||||
/// Config schema version
|
||||
pub config_version: u32,
|
||||
|
||||
/// Global daemon settings
|
||||
#[serde(default)]
|
||||
pub daemon: RawDaemonConfig,
|
||||
|
||||
/// List of allowed entries
|
||||
#[serde(default)]
|
||||
pub entries: Vec<RawEntry>,
|
||||
}
|
||||
|
||||
/// Daemon-level settings
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct RawDaemonConfig {
|
||||
/// IPC socket path (default: /run/shepherdd/shepherdd.sock)
|
||||
pub socket_path: Option<PathBuf>,
|
||||
|
||||
/// Log directory
|
||||
pub log_dir: Option<PathBuf>,
|
||||
|
||||
/// Data directory for store
|
||||
pub data_dir: Option<PathBuf>,
|
||||
|
||||
/// Default warning thresholds (can be overridden per entry)
|
||||
pub default_warnings: Option<Vec<RawWarningThreshold>>,
|
||||
|
||||
/// Default max run duration
|
||||
pub default_max_run_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// Raw entry definition
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RawEntry {
|
||||
/// Unique stable ID
|
||||
pub id: String,
|
||||
|
||||
/// Display label
|
||||
pub label: String,
|
||||
|
||||
/// Icon reference (opaque, interpreted by shell)
|
||||
pub icon: Option<String>,
|
||||
|
||||
/// Entry kind and launch details
|
||||
pub kind: RawEntryKind,
|
||||
|
||||
/// Availability time windows
|
||||
#[serde(default)]
|
||||
pub availability: Option<RawAvailability>,
|
||||
|
||||
/// Time limits
|
||||
#[serde(default)]
|
||||
pub limits: Option<RawLimits>,
|
||||
|
||||
/// Warning configuration
|
||||
#[serde(default)]
|
||||
pub warnings: Option<Vec<RawWarningThreshold>>,
|
||||
|
||||
/// Explicitly disabled
|
||||
#[serde(default)]
|
||||
pub disabled: bool,
|
||||
|
||||
/// Reason for disabling
|
||||
pub disabled_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// Raw entry kind
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum RawEntryKind {
|
||||
Process {
|
||||
argv: Vec<String>,
|
||||
#[serde(default)]
|
||||
env: HashMap<String, String>,
|
||||
cwd: Option<PathBuf>,
|
||||
},
|
||||
Vm {
|
||||
driver: String,
|
||||
#[serde(default)]
|
||||
args: HashMap<String, serde_json::Value>,
|
||||
},
|
||||
Media {
|
||||
library_id: String,
|
||||
#[serde(default)]
|
||||
args: HashMap<String, serde_json::Value>,
|
||||
},
|
||||
Custom {
|
||||
type_name: String,
|
||||
#[serde(default)]
|
||||
payload: Option<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Availability configuration
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct RawAvailability {
|
||||
/// Time windows when entry is available
|
||||
#[serde(default)]
|
||||
pub windows: Vec<RawTimeWindow>,
|
||||
|
||||
/// If true, entry is always available (ignores windows)
|
||||
#[serde(default)]
|
||||
pub always: bool,
|
||||
}
|
||||
|
||||
/// Time window
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RawTimeWindow {
|
||||
/// Days of week: "weekdays", "weekends", "all", or list like ["mon", "tue", "wed"]
|
||||
pub days: RawDays,
|
||||
|
||||
/// Start time (HH:MM format)
|
||||
pub start: String,
|
||||
|
||||
/// End time (HH:MM format)
|
||||
pub end: String,
|
||||
}
|
||||
|
||||
/// Days specification
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RawDays {
|
||||
Preset(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
/// Time limits
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct RawLimits {
|
||||
/// Maximum run duration in seconds
|
||||
pub max_run_seconds: Option<u64>,
|
||||
|
||||
/// Daily quota in seconds
|
||||
pub daily_quota_seconds: Option<u64>,
|
||||
|
||||
/// Cooldown after session ends, in seconds
|
||||
pub cooldown_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// Warning threshold
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct RawWarningThreshold {
|
||||
/// Seconds before expiry
|
||||
pub seconds_before: u64,
|
||||
|
||||
/// Severity: "info", "warn", "critical"
|
||||
#[serde(default = "default_severity")]
|
||||
pub severity: String,
|
||||
|
||||
/// Message template
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
fn default_severity() -> String {
|
||||
"warn".to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_process_entry() {
|
||||
let toml_str = r#"
|
||||
config_version = 1
|
||||
|
||||
[[entries]]
|
||||
id = "scummvm"
|
||||
label = "ScummVM"
|
||||
kind = { type = "process", argv = ["scummvm", "-f"] }
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 3600
|
||||
"#;
|
||||
|
||||
let config: RawConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.entries.len(), 1);
|
||||
assert_eq!(config.entries[0].id, "scummvm");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_time_windows() {
|
||||
let toml_str = r#"
|
||||
config_version = 1
|
||||
|
||||
[[entries]]
|
||||
id = "game"
|
||||
label = "Game"
|
||||
kind = { type = "process", argv = ["/bin/game"] }
|
||||
|
||||
[entries.availability]
|
||||
[[entries.availability.windows]]
|
||||
days = "weekdays"
|
||||
start = "14:00"
|
||||
end = "18:00"
|
||||
|
||||
[[entries.availability.windows]]
|
||||
days = ["sat", "sun"]
|
||||
start = "10:00"
|
||||
end = "20:00"
|
||||
"#;
|
||||
|
||||
let config: RawConfig = toml::from_str(toml_str).unwrap();
|
||||
let avail = config.entries[0].availability.as_ref().unwrap();
|
||||
assert_eq!(avail.windows.len(), 2);
|
||||
}
|
||||
}
|
||||
273
crates/shepherd-config/src/validation.rs
Normal file
273
crates/shepherd-config/src/validation.rs
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
//! Configuration validation
|
||||
|
||||
use crate::schema::{RawConfig, RawDays, RawEntry, RawEntryKind, RawTimeWindow};
|
||||
use std::collections::HashSet;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Validation error
|
||||
#[derive(Debug, Clone, Error)]
|
||||
pub enum ValidationError {
|
||||
#[error("Entry '{entry_id}': {message}")]
|
||||
EntryError { entry_id: String, message: String },
|
||||
|
||||
#[error("Duplicate entry ID: {0}")]
|
||||
DuplicateEntryId(String),
|
||||
|
||||
#[error("Invalid time format '{value}': {message}")]
|
||||
InvalidTimeFormat { value: String, message: String },
|
||||
|
||||
#[error("Invalid day specification: {0}")]
|
||||
InvalidDaySpec(String),
|
||||
|
||||
#[error("Warning threshold {seconds}s >= max_run {max_run}s for entry '{entry_id}'")]
|
||||
WarningExceedsMaxRun {
|
||||
entry_id: String,
|
||||
seconds: u64,
|
||||
max_run: u64,
|
||||
},
|
||||
|
||||
#[error("Global config error: {0}")]
|
||||
GlobalError(String),
|
||||
}
|
||||
|
||||
/// Validate a raw configuration
|
||||
pub fn validate_config(config: &RawConfig) -> Vec<ValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Check for duplicate entry IDs
|
||||
let mut seen_ids = HashSet::new();
|
||||
for entry in &config.entries {
|
||||
if !seen_ids.insert(&entry.id) {
|
||||
errors.push(ValidationError::DuplicateEntryId(entry.id.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each entry
|
||||
for entry in &config.entries {
|
||||
errors.extend(validate_entry(entry, config));
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
fn validate_entry(entry: &RawEntry, config: &RawConfig) -> Vec<ValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Validate kind
|
||||
match &entry.kind {
|
||||
RawEntryKind::Process { argv, .. } => {
|
||||
if argv.is_empty() {
|
||||
errors.push(ValidationError::EntryError {
|
||||
entry_id: entry.id.clone(),
|
||||
message: "argv cannot be empty".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RawEntryKind::Vm { driver, .. } => {
|
||||
if driver.is_empty() {
|
||||
errors.push(ValidationError::EntryError {
|
||||
entry_id: entry.id.clone(),
|
||||
message: "VM driver cannot be empty".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RawEntryKind::Media { library_id, .. } => {
|
||||
if library_id.is_empty() {
|
||||
errors.push(ValidationError::EntryError {
|
||||
entry_id: entry.id.clone(),
|
||||
message: "library_id cannot be empty".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
RawEntryKind::Custom { type_name, .. } => {
|
||||
if type_name.is_empty() {
|
||||
errors.push(ValidationError::EntryError {
|
||||
entry_id: entry.id.clone(),
|
||||
message: "type_name cannot be empty".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate availability windows
|
||||
if let Some(avail) = &entry.availability {
|
||||
for window in &avail.windows {
|
||||
errors.extend(validate_time_window(window, &entry.id));
|
||||
}
|
||||
}
|
||||
|
||||
// Validate warning thresholds vs max_run
|
||||
let max_run = entry
|
||||
.limits
|
||||
.as_ref()
|
||||
.and_then(|l| l.max_run_seconds)
|
||||
.or(config.daemon.default_max_run_seconds);
|
||||
|
||||
if let (Some(warnings), Some(max_run)) = (&entry.warnings, max_run) {
|
||||
for warning in warnings {
|
||||
if warning.seconds_before >= max_run {
|
||||
errors.push(ValidationError::WarningExceedsMaxRun {
|
||||
entry_id: entry.id.clone(),
|
||||
seconds: warning.seconds_before,
|
||||
max_run,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
fn validate_time_window(window: &RawTimeWindow, entry_id: &str) -> Vec<ValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Validate days
|
||||
if let Err(e) = parse_days(&window.days) {
|
||||
errors.push(ValidationError::EntryError {
|
||||
entry_id: entry_id.to_string(),
|
||||
message: e,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate start time
|
||||
if let Err(e) = parse_time(&window.start) {
|
||||
errors.push(ValidationError::InvalidTimeFormat {
|
||||
value: window.start.clone(),
|
||||
message: e,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate end time
|
||||
if let Err(e) = parse_time(&window.end) {
|
||||
errors.push(ValidationError::InvalidTimeFormat {
|
||||
value: window.end.clone(),
|
||||
message: e,
|
||||
});
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
/// Parse HH:MM time format
|
||||
pub fn parse_time(s: &str) -> Result<(u8, u8), String> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err("Expected HH:MM format".into());
|
||||
}
|
||||
|
||||
let hour: u8 = parts[0]
|
||||
.parse()
|
||||
.map_err(|_| "Invalid hour".to_string())?;
|
||||
let minute: u8 = parts[1]
|
||||
.parse()
|
||||
.map_err(|_| "Invalid minute".to_string())?;
|
||||
|
||||
if hour >= 24 {
|
||||
return Err("Hour must be 0-23".into());
|
||||
}
|
||||
if minute >= 60 {
|
||||
return Err("Minute must be 0-59".into());
|
||||
}
|
||||
|
||||
Ok((hour, minute))
|
||||
}
|
||||
|
||||
/// Parse days specification
|
||||
pub fn parse_days(days: &RawDays) -> Result<u8, String> {
|
||||
match days {
|
||||
RawDays::Preset(preset) => match preset.to_lowercase().as_str() {
|
||||
"all" | "every" | "daily" => Ok(0x7F),
|
||||
"weekdays" => Ok(0x1F), // Mon-Fri
|
||||
"weekends" => Ok(0x60), // Sat-Sun
|
||||
other => Err(format!("Unknown day preset: {}", other)),
|
||||
},
|
||||
RawDays::List(list) => {
|
||||
let mut mask = 0u8;
|
||||
for day in list {
|
||||
let bit = match day.to_lowercase().as_str() {
|
||||
"mon" | "monday" => 1 << 0,
|
||||
"tue" | "tuesday" => 1 << 1,
|
||||
"wed" | "wednesday" => 1 << 2,
|
||||
"thu" | "thursday" => 1 << 3,
|
||||
"fri" | "friday" => 1 << 4,
|
||||
"sat" | "saturday" => 1 << 5,
|
||||
"sun" | "sunday" => 1 << 6,
|
||||
other => return Err(format!("Unknown day: {}", other)),
|
||||
};
|
||||
mask |= bit;
|
||||
}
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_time() {
|
||||
assert_eq!(parse_time("14:30").unwrap(), (14, 30));
|
||||
assert_eq!(parse_time("00:00").unwrap(), (0, 0));
|
||||
assert_eq!(parse_time("23:59").unwrap(), (23, 59));
|
||||
|
||||
assert!(parse_time("24:00").is_err());
|
||||
assert!(parse_time("12:60").is_err());
|
||||
assert!(parse_time("invalid").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_days() {
|
||||
assert_eq!(parse_days(&RawDays::Preset("weekdays".into())).unwrap(), 0x1F);
|
||||
assert_eq!(parse_days(&RawDays::Preset("weekends".into())).unwrap(), 0x60);
|
||||
assert_eq!(parse_days(&RawDays::Preset("all".into())).unwrap(), 0x7F);
|
||||
|
||||
assert_eq!(
|
||||
parse_days(&RawDays::List(vec!["mon".into(), "wed".into(), "fri".into()])).unwrap(),
|
||||
0b10101
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_id_detection() {
|
||||
let config = RawConfig {
|
||||
config_version: 1,
|
||||
daemon: Default::default(),
|
||||
entries: vec![
|
||||
RawEntry {
|
||||
id: "game".into(),
|
||||
label: "Game 1".into(),
|
||||
icon: None,
|
||||
kind: RawEntryKind::Process {
|
||||
argv: vec!["game1".into()],
|
||||
env: Default::default(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: None,
|
||||
limits: None,
|
||||
warnings: None,
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
RawEntry {
|
||||
id: "game".into(),
|
||||
label: "Game 2".into(),
|
||||
icon: None,
|
||||
kind: RawEntryKind::Process {
|
||||
argv: vec!["game2".into()],
|
||||
env: Default::default(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: None,
|
||||
limits: None,
|
||||
warnings: None,
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let errors = validate_config(&config);
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateEntryId(_))));
|
||||
}
|
||||
}
|
||||
22
crates/shepherd-core/Cargo.toml
Normal file
22
crates/shepherd-core/Cargo.toml
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "shepherd-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Core policy engine and session state machine for shepherdd"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
shepherd-config = { workspace = true }
|
||||
shepherd-store = { workspace = true }
|
||||
shepherd-host-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
707
crates/shepherd-core/src/engine.rs
Normal file
707
crates/shepherd-core/src/engine.rs
Normal file
|
|
@ -0,0 +1,707 @@
|
|||
//! Core policy engine
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use shepherd_api::{
|
||||
DaemonStateSnapshot, EntryKindTag, EntryView, ReasonCode, SessionEndReason,
|
||||
WarningSeverity, API_VERSION,
|
||||
};
|
||||
use shepherd_config::{Entry, Policy};
|
||||
use shepherd_host_api::{HostCapabilities, HostSessionHandle};
|
||||
use shepherd_store::{AuditEvent, AuditEventType, Store};
|
||||
use shepherd_util::{EntryId, MonotonicInstant, SessionId};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{ActiveSession, CoreEvent, SessionPlan, StopResult};
|
||||
|
||||
/// Launch decision from the core engine
|
||||
#[derive(Debug)]
|
||||
pub enum LaunchDecision {
|
||||
Approved(SessionPlan),
|
||||
Denied { reasons: Vec<ReasonCode> },
|
||||
}
|
||||
|
||||
/// Stop decision from the core engine
|
||||
#[derive(Debug)]
|
||||
pub enum StopDecision {
|
||||
Stopped(StopResult),
|
||||
NoActiveSession,
|
||||
}
|
||||
|
||||
/// The core policy engine
|
||||
pub struct CoreEngine {
|
||||
policy: Policy,
|
||||
store: Arc<dyn Store>,
|
||||
capabilities: HostCapabilities,
|
||||
current_session: Option<ActiveSession>,
|
||||
}
|
||||
|
||||
impl CoreEngine {
|
||||
/// Create a new core engine
|
||||
pub fn new(
|
||||
policy: Policy,
|
||||
store: Arc<dyn Store>,
|
||||
capabilities: HostCapabilities,
|
||||
) -> Self {
|
||||
info!(
|
||||
entry_count = policy.entries.len(),
|
||||
"Core engine initialized"
|
||||
);
|
||||
|
||||
// Log policy load
|
||||
let _ = store.append_audit(AuditEvent::new(AuditEventType::PolicyLoaded {
|
||||
entry_count: policy.entries.len(),
|
||||
}));
|
||||
|
||||
Self {
|
||||
policy,
|
||||
store,
|
||||
capabilities,
|
||||
current_session: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current policy
|
||||
pub fn policy(&self) -> &Policy {
|
||||
&self.policy
|
||||
}
|
||||
|
||||
/// Reload policy
|
||||
pub fn reload_policy(&mut self, policy: Policy) -> CoreEvent {
|
||||
let entry_count = policy.entries.len();
|
||||
self.policy = policy;
|
||||
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::PolicyLoaded {
|
||||
entry_count,
|
||||
}));
|
||||
|
||||
info!(entry_count, "Policy reloaded");
|
||||
|
||||
CoreEvent::PolicyReloaded { entry_count }
|
||||
}
|
||||
|
||||
/// List all entries with availability status
|
||||
pub fn list_entries(&self, now: DateTime<Local>) -> Vec<EntryView> {
|
||||
self.policy
|
||||
.entries
|
||||
.iter()
|
||||
.map(|entry| self.evaluate_entry(entry, now))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Evaluate a single entry for availability
|
||||
fn evaluate_entry(&self, entry: &Entry, now: DateTime<Local>) -> EntryView {
|
||||
let mut reasons = Vec::new();
|
||||
let mut enabled = true;
|
||||
|
||||
// Check if explicitly disabled
|
||||
if entry.disabled {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::Disabled {
|
||||
reason: entry.disabled_reason.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Check host capabilities
|
||||
let kind_tag = entry.kind.tag();
|
||||
if !self.capabilities.supports_kind(kind_tag) {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::UnsupportedKind { kind: kind_tag });
|
||||
}
|
||||
|
||||
// Check availability window
|
||||
if !entry.availability.is_available(&now) {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::OutsideTimeWindow {
|
||||
next_window_start: None, // TODO: compute next window
|
||||
});
|
||||
}
|
||||
|
||||
// Check if another session is active
|
||||
if let Some(session) = &self.current_session {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::SessionActive {
|
||||
entry_id: session.plan.entry_id.clone(),
|
||||
remaining: session.time_remaining(MonotonicInstant::now()),
|
||||
});
|
||||
}
|
||||
|
||||
// Check cooldown
|
||||
if let Ok(Some(until)) = self.store.get_cooldown_until(&entry.id) {
|
||||
if until > now {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::CooldownActive { available_at: until });
|
||||
}
|
||||
}
|
||||
|
||||
// Check daily quota
|
||||
if let Some(quota) = entry.limits.daily_quota {
|
||||
let today = now.date_naive();
|
||||
if let Ok(used) = self.store.get_usage(&entry.id, today) {
|
||||
if used >= quota {
|
||||
enabled = false;
|
||||
reasons.push(ReasonCode::QuotaExhausted { used, quota });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate max run if enabled
|
||||
let max_run_if_started_now = if enabled {
|
||||
Some(self.compute_max_duration(entry, now))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
EntryView {
|
||||
entry_id: entry.id.clone(),
|
||||
label: entry.label.clone(),
|
||||
icon_ref: entry.icon_ref.clone(),
|
||||
kind_tag,
|
||||
enabled,
|
||||
reasons,
|
||||
max_run_if_started_now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute maximum duration for an entry if started now
|
||||
fn compute_max_duration(&self, entry: &Entry, now: DateTime<Local>) -> Duration {
|
||||
let mut max = entry.limits.max_run;
|
||||
|
||||
// Limit by time window remaining
|
||||
if let Some(window_remaining) = entry.availability.remaining_in_window(&now) {
|
||||
max = max.min(window_remaining);
|
||||
}
|
||||
|
||||
// Limit by daily quota remaining
|
||||
if let Some(quota) = entry.limits.daily_quota {
|
||||
let today = now.date_naive();
|
||||
if let Ok(used) = self.store.get_usage(&entry.id, today) {
|
||||
let remaining = quota.saturating_sub(used);
|
||||
max = max.min(remaining);
|
||||
}
|
||||
}
|
||||
|
||||
max
|
||||
}
|
||||
|
||||
/// Request to launch an entry
|
||||
pub fn request_launch(
|
||||
&self,
|
||||
entry_id: &EntryId,
|
||||
now: DateTime<Local>,
|
||||
) -> LaunchDecision {
|
||||
// Find entry
|
||||
let entry = match self.policy.get_entry(entry_id) {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
return LaunchDecision::Denied {
|
||||
reasons: vec![ReasonCode::Disabled {
|
||||
reason: Some("Entry not found".into()),
|
||||
}],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Evaluate availability
|
||||
let view = self.evaluate_entry(entry, now);
|
||||
|
||||
if !view.enabled {
|
||||
// Log denial
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::LaunchDenied {
|
||||
entry_id: entry_id.clone(),
|
||||
reasons: view.reasons.iter().map(|r| format!("{:?}", r)).collect(),
|
||||
}));
|
||||
|
||||
return LaunchDecision::Denied {
|
||||
reasons: view.reasons,
|
||||
};
|
||||
}
|
||||
|
||||
// Compute session plan
|
||||
let max_duration = view.max_run_if_started_now.unwrap();
|
||||
let plan = SessionPlan {
|
||||
session_id: SessionId::new(),
|
||||
entry_id: entry_id.clone(),
|
||||
label: entry.label.clone(),
|
||||
max_duration,
|
||||
warnings: entry.warnings.clone(),
|
||||
};
|
||||
|
||||
debug!(
|
||||
entry_id = %entry_id,
|
||||
max_duration_secs = max_duration.as_secs(),
|
||||
"Launch approved"
|
||||
);
|
||||
|
||||
LaunchDecision::Approved(plan)
|
||||
}
|
||||
|
||||
/// Start a session from an approved plan
|
||||
pub fn start_session(
|
||||
&mut self,
|
||||
plan: SessionPlan,
|
||||
now: DateTime<Local>,
|
||||
now_mono: MonotonicInstant,
|
||||
) -> CoreEvent {
|
||||
let session = ActiveSession::new(plan.clone(), now, now_mono);
|
||||
|
||||
let event = CoreEvent::SessionStarted {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
entry_id: session.plan.entry_id.clone(),
|
||||
label: session.plan.label.clone(),
|
||||
deadline: session.deadline,
|
||||
};
|
||||
|
||||
// Log to audit
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionStarted {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
entry_id: session.plan.entry_id.clone(),
|
||||
label: session.plan.label.clone(),
|
||||
deadline: session.deadline,
|
||||
}));
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
entry_id = %session.plan.entry_id,
|
||||
deadline = %session.deadline,
|
||||
"Session started"
|
||||
);
|
||||
|
||||
self.current_session = Some(session);
|
||||
|
||||
event
|
||||
}
|
||||
|
||||
/// Attach host handle to current session
|
||||
pub fn attach_host_handle(&mut self, handle: HostSessionHandle) {
|
||||
if let Some(session) = &mut self.current_session {
|
||||
session.attach_handle(handle);
|
||||
}
|
||||
}
|
||||
|
||||
/// Tick the engine - check for warnings and expiry
|
||||
pub fn tick(&mut self, now_mono: MonotonicInstant) -> Vec<CoreEvent> {
|
||||
let mut events = Vec::new();
|
||||
|
||||
let session = match &mut self.current_session {
|
||||
Some(s) => s,
|
||||
None => return events,
|
||||
};
|
||||
|
||||
// Check for pending warnings
|
||||
for (threshold, remaining) in session.pending_warnings(now_mono) {
|
||||
let severity = session
|
||||
.plan
|
||||
.warnings
|
||||
.iter()
|
||||
.find(|w| w.seconds_before == threshold)
|
||||
.map(|w| w.severity)
|
||||
.unwrap_or(WarningSeverity::Warn);
|
||||
|
||||
let message = session
|
||||
.plan
|
||||
.warnings
|
||||
.iter()
|
||||
.find(|w| w.seconds_before == threshold)
|
||||
.and_then(|w| w.message_template.clone());
|
||||
|
||||
session.mark_warning_issued(threshold);
|
||||
|
||||
// Log to audit
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::WarningIssued {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
threshold_seconds: threshold,
|
||||
}));
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
threshold_seconds = threshold,
|
||||
remaining_secs = remaining.as_secs(),
|
||||
"Warning issued"
|
||||
);
|
||||
|
||||
events.push(CoreEvent::Warning {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
threshold_seconds: threshold,
|
||||
time_remaining: remaining,
|
||||
severity,
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
// Check for expiry
|
||||
if session.is_expired(now_mono)
|
||||
&& session.state != shepherd_api::SessionState::Expiring
|
||||
&& session.state != shepherd_api::SessionState::Ended
|
||||
{
|
||||
session.mark_expiring();
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
"Session expiring"
|
||||
);
|
||||
|
||||
events.push(CoreEvent::ExpireDue {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
/// Notify that a session has exited
|
||||
pub fn notify_session_exited(
|
||||
&mut self,
|
||||
exit_code: Option<i32>,
|
||||
now_mono: MonotonicInstant,
|
||||
now: DateTime<Local>,
|
||||
) -> Option<CoreEvent> {
|
||||
let session = self.current_session.take()?;
|
||||
|
||||
let duration = session.duration_so_far(now_mono);
|
||||
let reason = if session.state == shepherd_api::SessionState::Expiring {
|
||||
SessionEndReason::Expired
|
||||
} else {
|
||||
SessionEndReason::ProcessExited { exit_code }
|
||||
};
|
||||
|
||||
// Update usage accounting
|
||||
let today = now.date_naive();
|
||||
let _ = self.store.add_usage(&session.plan.entry_id, today, duration);
|
||||
|
||||
// Set cooldown if configured
|
||||
if let Some(entry) = self.policy.get_entry(&session.plan.entry_id) {
|
||||
if let Some(cooldown) = entry.limits.cooldown {
|
||||
let until = now + chrono::Duration::from_std(cooldown).unwrap();
|
||||
let _ = self.store.set_cooldown_until(&session.plan.entry_id, until);
|
||||
}
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
entry_id: session.plan.entry_id.clone(),
|
||||
reason: reason.clone(),
|
||||
duration,
|
||||
}));
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
entry_id = %session.plan.entry_id,
|
||||
duration_secs = duration.as_secs(),
|
||||
reason = ?reason,
|
||||
"Session ended"
|
||||
);
|
||||
|
||||
Some(CoreEvent::SessionEnded {
|
||||
session_id: session.plan.session_id,
|
||||
entry_id: session.plan.entry_id,
|
||||
reason,
|
||||
duration,
|
||||
})
|
||||
}
|
||||
|
||||
/// Stop the current session
|
||||
pub fn stop_current(
|
||||
&mut self,
|
||||
reason: SessionEndReason,
|
||||
now_mono: MonotonicInstant,
|
||||
now: DateTime<Local>,
|
||||
) -> StopDecision {
|
||||
let session = match self.current_session.take() {
|
||||
Some(s) => s,
|
||||
None => return StopDecision::NoActiveSession,
|
||||
};
|
||||
|
||||
let duration = session.duration_so_far(now_mono);
|
||||
|
||||
// Update usage accounting
|
||||
let today = now.date_naive();
|
||||
let _ = self.store.add_usage(&session.plan.entry_id, today, duration);
|
||||
|
||||
// Set cooldown if configured
|
||||
if let Some(entry) = self.policy.get_entry(&session.plan.entry_id) {
|
||||
if let Some(cooldown) = entry.limits.cooldown {
|
||||
let until = now + chrono::Duration::from_std(cooldown).unwrap();
|
||||
let _ = self.store.set_cooldown_until(&session.plan.entry_id, until);
|
||||
}
|
||||
}
|
||||
|
||||
// Log to audit
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
entry_id: session.plan.entry_id.clone(),
|
||||
reason: reason.clone(),
|
||||
duration,
|
||||
}));
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
reason = ?reason,
|
||||
"Session stopped"
|
||||
);
|
||||
|
||||
StopDecision::Stopped(StopResult {
|
||||
session_id: session.plan.session_id,
|
||||
entry_id: session.plan.entry_id,
|
||||
reason,
|
||||
duration,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current daemon state snapshot
|
||||
pub fn get_state(&self) -> DaemonStateSnapshot {
|
||||
let current_session = self.current_session.as_ref().map(|s| {
|
||||
s.to_session_info(MonotonicInstant::now())
|
||||
});
|
||||
|
||||
DaemonStateSnapshot {
|
||||
api_version: API_VERSION,
|
||||
policy_loaded: true,
|
||||
current_session,
|
||||
entry_count: self.policy.entries.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current session reference
|
||||
pub fn current_session(&self) -> Option<&ActiveSession> {
|
||||
self.current_session.as_ref()
|
||||
}
|
||||
|
||||
/// Get mutable current session reference
|
||||
pub fn current_session_mut(&mut self) -> Option<&mut ActiveSession> {
|
||||
self.current_session.as_mut()
|
||||
}
|
||||
|
||||
/// Check if a session is active
|
||||
pub fn has_active_session(&self) -> bool {
|
||||
self.current_session.is_some()
|
||||
}
|
||||
|
||||
/// Extend current session (admin action)
|
||||
pub fn extend_current(
|
||||
&mut self,
|
||||
by: Duration,
|
||||
now_mono: MonotonicInstant,
|
||||
now: DateTime<Local>,
|
||||
) -> Option<DateTime<Local>> {
|
||||
let session = self.current_session.as_mut()?;
|
||||
|
||||
session.deadline_mono = session.deadline_mono + by;
|
||||
session.deadline = session.deadline + chrono::Duration::from_std(by).unwrap();
|
||||
|
||||
// Log to audit
|
||||
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionExtended {
|
||||
session_id: session.plan.session_id.clone(),
|
||||
extended_by: by,
|
||||
new_deadline: session.deadline,
|
||||
}));
|
||||
|
||||
info!(
|
||||
session_id = %session.plan.session_id,
|
||||
extended_by_secs = by.as_secs(),
|
||||
new_deadline = %session.deadline,
|
||||
"Session extended"
|
||||
);
|
||||
|
||||
Some(session.deadline)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
|
||||
use shepherd_api::EntryKind;
|
||||
use shepherd_store::SqliteStore;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_test_policy() -> Policy {
|
||||
Policy {
|
||||
daemon: Default::default(),
|
||||
entries: vec![Entry {
|
||||
id: EntryId::new("test-game"),
|
||||
label: "Test Game".into(),
|
||||
icon_ref: None,
|
||||
kind: EntryKind::Process {
|
||||
argv: vec!["game".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: AvailabilityPolicy {
|
||||
windows: vec![],
|
||||
always: true,
|
||||
},
|
||||
limits: LimitsPolicy {
|
||||
max_run: Duration::from_secs(300),
|
||||
daily_quota: None,
|
||||
cooldown: None,
|
||||
},
|
||||
warnings: vec![],
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
default_warnings: vec![],
|
||||
default_max_run: Duration::from_secs(3600),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_entries() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entries = engine.list_entries(Local::now());
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries[0].enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_launch_approval() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let decision = engine.request_launch(&entry_id, Local::now());
|
||||
|
||||
assert!(matches!(decision, LaunchDecision::Approved(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_blocks_new_launch() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Launch first session
|
||||
if let LaunchDecision::Approved(plan) = engine.request_launch(&entry_id, now) {
|
||||
engine.start_session(plan, now, now_mono);
|
||||
}
|
||||
|
||||
// Try to launch again - should be denied
|
||||
let decision = engine.request_launch(&entry_id, now);
|
||||
assert!(matches!(decision, LaunchDecision::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tick_warnings() {
|
||||
let policy = Policy {
|
||||
entries: vec![Entry {
|
||||
id: EntryId::new("test"),
|
||||
label: "Test".into(),
|
||||
icon_ref: None,
|
||||
kind: EntryKind::Process {
|
||||
argv: vec!["test".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: AvailabilityPolicy {
|
||||
windows: vec![],
|
||||
always: true,
|
||||
},
|
||||
limits: LimitsPolicy {
|
||||
max_run: Duration::from_secs(120), // 2 minutes
|
||||
daily_quota: None,
|
||||
cooldown: None,
|
||||
},
|
||||
warnings: vec![shepherd_api::WarningThreshold {
|
||||
seconds_before: 60,
|
||||
severity: WarningSeverity::Warn,
|
||||
message_template: Some("1 minute left".into()),
|
||||
}],
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
daemon: Default::default(),
|
||||
default_warnings: vec![],
|
||||
default_max_run: Duration::from_secs(3600),
|
||||
};
|
||||
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
if let LaunchDecision::Approved(plan) = engine.request_launch(&entry_id, now) {
|
||||
engine.start_session(plan, now, now_mono);
|
||||
}
|
||||
|
||||
// No warnings initially
|
||||
let events = engine.tick(now_mono);
|
||||
assert!(events.is_empty());
|
||||
|
||||
// At 70 seconds (10 seconds past warning threshold), warning should fire
|
||||
let later = now_mono + Duration::from_secs(70);
|
||||
let events = engine.tick(later);
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(events[0], CoreEvent::Warning { threshold_seconds: 60, .. }));
|
||||
|
||||
// Warning shouldn't fire twice
|
||||
let events = engine.tick(later);
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_expiry() {
|
||||
let policy = Policy {
|
||||
entries: vec![Entry {
|
||||
id: EntryId::new("test"),
|
||||
label: "Test".into(),
|
||||
icon_ref: None,
|
||||
kind: EntryKind::Process {
|
||||
argv: vec!["test".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: AvailabilityPolicy {
|
||||
windows: vec![],
|
||||
always: true,
|
||||
},
|
||||
limits: LimitsPolicy {
|
||||
max_run: Duration::from_secs(60),
|
||||
daily_quota: None,
|
||||
cooldown: None,
|
||||
},
|
||||
warnings: vec![],
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
daemon: Default::default(),
|
||||
default_warnings: vec![],
|
||||
default_max_run: Duration::from_secs(3600),
|
||||
};
|
||||
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
if let LaunchDecision::Approved(plan) = engine.request_launch(&entry_id, now) {
|
||||
engine.start_session(plan, now, now_mono);
|
||||
}
|
||||
|
||||
// At 61 seconds, should be expired
|
||||
let later = now_mono + Duration::from_secs(61);
|
||||
let events = engine.tick(later);
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(events[0], CoreEvent::ExpireDue { .. }));
|
||||
}
|
||||
}
|
||||
51
crates/shepherd-core/src/events.rs
Normal file
51
crates/shepherd-core/src/events.rs
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
//! Core events emitted by the engine
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use shepherd_api::{SessionEndReason, WarningSeverity};
|
||||
use shepherd_util::{EntryId, SessionId};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Events emitted by the core engine
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CoreEvent {
|
||||
/// Session started successfully
|
||||
SessionStarted {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
label: String,
|
||||
deadline: DateTime<Local>,
|
||||
},
|
||||
|
||||
/// Warning threshold reached
|
||||
Warning {
|
||||
session_id: SessionId,
|
||||
threshold_seconds: u64,
|
||||
time_remaining: Duration,
|
||||
severity: WarningSeverity,
|
||||
message: Option<String>,
|
||||
},
|
||||
|
||||
/// Session is expiring (termination initiated)
|
||||
ExpireDue {
|
||||
session_id: SessionId,
|
||||
},
|
||||
|
||||
/// Session has ended
|
||||
SessionEnded {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
reason: SessionEndReason,
|
||||
duration: Duration,
|
||||
},
|
||||
|
||||
/// Entry availability changed
|
||||
EntryAvailabilityChanged {
|
||||
entry_id: EntryId,
|
||||
enabled: bool,
|
||||
},
|
||||
|
||||
/// Policy was reloaded
|
||||
PolicyReloaded {
|
||||
entry_count: usize,
|
||||
},
|
||||
}
|
||||
15
crates/shepherd-core/src/lib.rs
Normal file
15
crates/shepherd-core/src/lib.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
//! Core policy engine and session state machine for shepherdd
|
||||
//!
|
||||
//! This crate is the heart of shepherdd, containing:
|
||||
//! - Policy evaluation (what's available, when, for how long)
|
||||
//! - Session state machine (Idle -> Launching -> Running -> Warned -> Expiring -> Ended)
|
||||
//! - Warning and expiry scheduling
|
||||
//! - Time enforcement using monotonic time
|
||||
|
||||
mod engine;
|
||||
mod events;
|
||||
mod session;
|
||||
|
||||
pub use engine::*;
|
||||
pub use events::*;
|
||||
pub use session::*;
|
||||
263
crates/shepherd-core/src/session.rs
Normal file
263
crates/shepherd-core/src/session.rs
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
//! Session state machine
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use shepherd_api::{SessionEndReason, SessionState, WarningThreshold};
|
||||
use shepherd_host_api::HostSessionHandle;
|
||||
use shepherd_util::{EntryId, MonotonicInstant, SessionId};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Session plan computed at launch approval
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionPlan {
|
||||
pub session_id: SessionId,
|
||||
pub entry_id: EntryId,
|
||||
pub label: String,
|
||||
pub max_duration: Duration,
|
||||
pub warnings: Vec<WarningThreshold>,
|
||||
}
|
||||
|
||||
impl SessionPlan {
|
||||
/// Compute warning times (as durations after start)
|
||||
pub fn warning_times(&self) -> Vec<(u64, Duration)> {
|
||||
self.warnings
|
||||
.iter()
|
||||
.filter(|w| Duration::from_secs(w.seconds_before) < self.max_duration)
|
||||
.map(|w| {
|
||||
let trigger_after =
|
||||
self.max_duration - Duration::from_secs(w.seconds_before);
|
||||
(w.seconds_before, trigger_after)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Active session tracking
|
||||
#[derive(Debug)]
|
||||
pub struct ActiveSession {
|
||||
/// Session plan
|
||||
pub plan: SessionPlan,
|
||||
|
||||
/// Current state
|
||||
pub state: SessionState,
|
||||
|
||||
/// Wall-clock start time (for display/logging)
|
||||
pub started_at: DateTime<Local>,
|
||||
|
||||
/// Monotonic start time (for enforcement)
|
||||
pub started_at_mono: MonotonicInstant,
|
||||
|
||||
/// Wall-clock deadline (for display)
|
||||
pub deadline: DateTime<Local>,
|
||||
|
||||
/// Monotonic deadline (for enforcement)
|
||||
pub deadline_mono: MonotonicInstant,
|
||||
|
||||
/// Warning thresholds already issued (seconds before expiry)
|
||||
pub warnings_issued: Vec<u64>,
|
||||
|
||||
/// Host session handle (for stopping)
|
||||
pub host_handle: Option<HostSessionHandle>,
|
||||
}
|
||||
|
||||
impl ActiveSession {
|
||||
/// Create a new session from an approved plan
|
||||
pub fn new(
|
||||
plan: SessionPlan,
|
||||
now: DateTime<Local>,
|
||||
now_mono: MonotonicInstant,
|
||||
) -> Self {
|
||||
let deadline = now + chrono::Duration::from_std(plan.max_duration).unwrap();
|
||||
let deadline_mono = now_mono + plan.max_duration;
|
||||
|
||||
Self {
|
||||
plan,
|
||||
state: SessionState::Launching,
|
||||
started_at: now,
|
||||
started_at_mono: now_mono,
|
||||
deadline,
|
||||
deadline_mono,
|
||||
warnings_issued: Vec::new(),
|
||||
host_handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach the host handle once spawn succeeds
|
||||
pub fn attach_handle(&mut self, handle: HostSessionHandle) {
|
||||
self.host_handle = Some(handle);
|
||||
self.state = SessionState::Running;
|
||||
}
|
||||
|
||||
/// Get time remaining using monotonic time
|
||||
pub fn time_remaining(&self, now_mono: MonotonicInstant) -> Duration {
|
||||
self.deadline_mono.saturating_duration_until(now_mono)
|
||||
}
|
||||
|
||||
/// Check if session is expired
|
||||
pub fn is_expired(&self, now_mono: MonotonicInstant) -> bool {
|
||||
now_mono >= self.deadline_mono
|
||||
}
|
||||
|
||||
/// Get pending warnings (not yet issued) that should fire now
|
||||
pub fn pending_warnings(&self, now_mono: MonotonicInstant) -> Vec<(u64, Duration)> {
|
||||
let elapsed = now_mono.duration_since(self.started_at_mono);
|
||||
let remaining = self.time_remaining(now_mono);
|
||||
|
||||
self.plan
|
||||
.warning_times()
|
||||
.into_iter()
|
||||
.filter(|(threshold, trigger_after)| {
|
||||
// Should trigger if elapsed >= trigger_after and not already issued
|
||||
elapsed >= *trigger_after && !self.warnings_issued.contains(threshold)
|
||||
})
|
||||
.map(|(threshold, _)| (threshold, remaining))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Mark a warning as issued
|
||||
pub fn mark_warning_issued(&mut self, threshold: u64) {
|
||||
if !self.warnings_issued.contains(&threshold) {
|
||||
self.warnings_issued.push(threshold);
|
||||
}
|
||||
// Update state to Warned if not already expiring
|
||||
if self.state == SessionState::Running {
|
||||
self.state = SessionState::Warned;
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark session as expiring
|
||||
pub fn mark_expiring(&mut self) {
|
||||
self.state = SessionState::Expiring;
|
||||
}
|
||||
|
||||
/// Mark session as ended
|
||||
pub fn mark_ended(&mut self) {
|
||||
self.state = SessionState::Ended;
|
||||
}
|
||||
|
||||
/// Get session duration so far
|
||||
pub fn duration_so_far(&self, now_mono: MonotonicInstant) -> Duration {
|
||||
now_mono.duration_since(self.started_at_mono)
|
||||
}
|
||||
|
||||
/// Get session info for API
|
||||
pub fn to_session_info(&self, now_mono: MonotonicInstant) -> shepherd_api::SessionInfo {
|
||||
shepherd_api::SessionInfo {
|
||||
session_id: self.plan.session_id.clone(),
|
||||
entry_id: self.plan.entry_id.clone(),
|
||||
label: self.plan.label.clone(),
|
||||
state: self.state,
|
||||
started_at: self.started_at,
|
||||
deadline: self.deadline,
|
||||
time_remaining: self.time_remaining(now_mono),
|
||||
warnings_issued: self.warnings_issued.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of stopping a session
|
||||
#[derive(Debug)]
|
||||
pub struct StopResult {
|
||||
pub session_id: SessionId,
|
||||
pub entry_id: EntryId,
|
||||
pub reason: SessionEndReason,
|
||||
pub duration: Duration,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use shepherd_api::WarningSeverity;
|
||||
|
||||
fn make_test_plan(duration_secs: u64) -> SessionPlan {
|
||||
SessionPlan {
|
||||
session_id: SessionId::new(),
|
||||
entry_id: EntryId::new("test"),
|
||||
label: "Test".into(),
|
||||
max_duration: Duration::from_secs(duration_secs),
|
||||
warnings: vec![
|
||||
WarningThreshold {
|
||||
seconds_before: 60,
|
||||
severity: WarningSeverity::Warn,
|
||||
message_template: None,
|
||||
},
|
||||
WarningThreshold {
|
||||
seconds_before: 10,
|
||||
severity: WarningSeverity::Critical,
|
||||
message_template: None,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_creation() {
|
||||
let plan = make_test_plan(300);
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
let session = ActiveSession::new(plan, now, now_mono);
|
||||
|
||||
assert_eq!(session.state, SessionState::Launching);
|
||||
assert!(session.warnings_issued.is_empty());
|
||||
assert_eq!(session.time_remaining(now_mono), Duration::from_secs(300));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_warning_times() {
|
||||
let plan = make_test_plan(300); // 5 minutes
|
||||
|
||||
let times = plan.warning_times();
|
||||
assert_eq!(times.len(), 2);
|
||||
|
||||
// 60s warning should trigger at 240s (4 min)
|
||||
let w60 = times.iter().find(|(t, _)| *t == 60).unwrap();
|
||||
assert_eq!(w60.1, Duration::from_secs(240));
|
||||
|
||||
// 10s warning should trigger at 290s
|
||||
let w10 = times.iter().find(|(t, _)| *t == 10).unwrap();
|
||||
assert_eq!(w10.1, Duration::from_secs(290));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_warning_not_issued_for_short_session() {
|
||||
// Session shorter than warning threshold
|
||||
let plan = SessionPlan {
|
||||
session_id: SessionId::new(),
|
||||
entry_id: EntryId::new("test"),
|
||||
label: "Test".into(),
|
||||
max_duration: Duration::from_secs(30), // 30 seconds
|
||||
warnings: vec![WarningThreshold {
|
||||
seconds_before: 60, // 60 second warning - longer than session!
|
||||
severity: WarningSeverity::Warn,
|
||||
message_template: None,
|
||||
}],
|
||||
};
|
||||
|
||||
let times = plan.warning_times();
|
||||
assert!(times.is_empty()); // No warnings should be scheduled
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pending_warnings() {
|
||||
let plan = make_test_plan(300);
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
let mut session = ActiveSession::new(plan, now, now_mono);
|
||||
|
||||
// At start, no warnings pending
|
||||
let pending = session.pending_warnings(now_mono);
|
||||
assert!(pending.is_empty());
|
||||
|
||||
// Simulate time passing - at 250s, 60s warning should be pending
|
||||
let later = now_mono + Duration::from_secs(250);
|
||||
let pending = session.pending_warnings(later);
|
||||
assert_eq!(pending.len(), 1);
|
||||
assert_eq!(pending[0].0, 60);
|
||||
|
||||
// Mark it issued
|
||||
session.mark_warning_issued(60);
|
||||
let pending = session.pending_warnings(later);
|
||||
assert!(pending.is_empty());
|
||||
}
|
||||
}
|
||||
17
crates/shepherd-host-api/Cargo.toml
Normal file
17
crates/shepherd-host-api/Cargo.toml
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "shepherd-host-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Host adapter trait interfaces for shepherdd platform integration"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
async-trait = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = { workspace = true }
|
||||
108
crates/shepherd-host-api/src/capabilities.rs
Normal file
108
crates/shepherd-host-api/src/capabilities.rs
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
//! Host capabilities model
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_api::EntryKindTag;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Describes what a host adapter can do
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HostCapabilities {
|
||||
/// Entry kinds this host can spawn
|
||||
pub spawn_kinds_supported: HashSet<EntryKindTag>,
|
||||
|
||||
/// Can forcefully kill processes/sessions
|
||||
pub can_kill_forcefully: bool,
|
||||
|
||||
/// Can attempt graceful stop (e.g., SIGTERM)
|
||||
pub can_graceful_stop: bool,
|
||||
|
||||
/// Can group process trees (process groups, job objects)
|
||||
pub can_group_process_tree: bool,
|
||||
|
||||
/// Can observe process exit
|
||||
pub can_observe_exit: bool,
|
||||
|
||||
/// Can detect when window/app is ready (optional)
|
||||
pub can_observe_window_ready: bool,
|
||||
|
||||
/// Can force an app to foreground (optional)
|
||||
pub can_force_foreground: bool,
|
||||
|
||||
/// Can force fullscreen mode (optional)
|
||||
pub can_force_fullscreen: bool,
|
||||
|
||||
/// Can lock to single app (MDM/kiosk mode, optional)
|
||||
pub can_lock_to_single_app: bool,
|
||||
}
|
||||
|
||||
impl HostCapabilities {
|
||||
/// Create minimal capabilities (process spawn/kill only)
|
||||
pub fn minimal() -> Self {
|
||||
let mut spawn_kinds = HashSet::new();
|
||||
spawn_kinds.insert(EntryKindTag::Process);
|
||||
|
||||
Self {
|
||||
spawn_kinds_supported: spawn_kinds,
|
||||
can_kill_forcefully: true,
|
||||
can_graceful_stop: true,
|
||||
can_group_process_tree: false,
|
||||
can_observe_exit: true,
|
||||
can_observe_window_ready: false,
|
||||
can_force_foreground: false,
|
||||
can_force_fullscreen: false,
|
||||
can_lock_to_single_app: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create capabilities for a full Linux host with Sway
|
||||
pub fn linux_full() -> Self {
|
||||
let mut spawn_kinds = HashSet::new();
|
||||
spawn_kinds.insert(EntryKindTag::Process);
|
||||
spawn_kinds.insert(EntryKindTag::Vm);
|
||||
spawn_kinds.insert(EntryKindTag::Media);
|
||||
|
||||
Self {
|
||||
spawn_kinds_supported: spawn_kinds,
|
||||
can_kill_forcefully: true,
|
||||
can_graceful_stop: true,
|
||||
can_group_process_tree: true,
|
||||
can_observe_exit: true,
|
||||
can_observe_window_ready: true,
|
||||
can_force_foreground: true,
|
||||
can_force_fullscreen: true,
|
||||
can_lock_to_single_app: false, // Would need additional setup
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this host can spawn the given entry kind
|
||||
pub fn supports_kind(&self, kind: EntryKindTag) -> bool {
|
||||
self.spawn_kinds_supported.contains(&kind)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HostCapabilities {
|
||||
fn default() -> Self {
|
||||
Self::minimal()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn minimal_capabilities() {
|
||||
let caps = HostCapabilities::minimal();
|
||||
assert!(caps.supports_kind(EntryKindTag::Process));
|
||||
assert!(!caps.supports_kind(EntryKindTag::Vm));
|
||||
assert!(caps.can_kill_forcefully);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linux_full_capabilities() {
|
||||
let caps = HostCapabilities::linux_full();
|
||||
assert!(caps.supports_kind(EntryKindTag::Process));
|
||||
assert!(caps.supports_kind(EntryKindTag::Vm));
|
||||
assert!(caps.can_group_process_tree);
|
||||
}
|
||||
}
|
||||
135
crates/shepherd-host-api/src/handle.rs
Normal file
135
crates/shepherd-host-api/src/handle.rs
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
//! Session handle abstraction
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_util::SessionId;
|
||||
|
||||
/// Opaque handle to a running session on the host
|
||||
///
|
||||
/// This contains platform-specific identifiers and is created by the
|
||||
/// host adapter when a session is spawned.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HostSessionHandle {
|
||||
/// Session ID from the core
|
||||
pub session_id: SessionId,
|
||||
|
||||
/// Platform-specific payload (opaque to core)
|
||||
payload: HostHandlePayload,
|
||||
}
|
||||
|
||||
impl HostSessionHandle {
|
||||
pub fn new(session_id: SessionId, payload: HostHandlePayload) -> Self {
|
||||
Self { session_id, payload }
|
||||
}
|
||||
|
||||
pub fn payload(&self) -> &HostHandlePayload {
|
||||
&self.payload
|
||||
}
|
||||
}
|
||||
|
||||
/// Platform-specific handle payload
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "platform", rename_all = "snake_case")]
|
||||
pub enum HostHandlePayload {
|
||||
/// Linux: process group ID
|
||||
Linux {
|
||||
pid: u32,
|
||||
pgid: u32,
|
||||
},
|
||||
|
||||
/// Windows: job object handle (serialized as name/id)
|
||||
Windows {
|
||||
job_name: String,
|
||||
process_id: u32,
|
||||
},
|
||||
|
||||
/// macOS: bundle or process identifier
|
||||
MacOs {
|
||||
pid: u32,
|
||||
bundle_id: Option<String>,
|
||||
},
|
||||
|
||||
/// Mock for testing
|
||||
Mock {
|
||||
id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl HostHandlePayload {
|
||||
/// Get the process ID if applicable
|
||||
pub fn pid(&self) -> Option<u32> {
|
||||
match self {
|
||||
HostHandlePayload::Linux { pid, .. } => Some(*pid),
|
||||
HostHandlePayload::Windows { process_id, .. } => Some(*process_id),
|
||||
HostHandlePayload::MacOs { pid, .. } => Some(*pid),
|
||||
HostHandlePayload::Mock { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exit status from a session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExitStatus {
|
||||
/// Exit code if the process exited normally
|
||||
pub code: Option<i32>,
|
||||
|
||||
/// Whether the process was signaled
|
||||
pub signaled: bool,
|
||||
|
||||
/// Signal number if signaled (Unix)
|
||||
pub signal: Option<i32>,
|
||||
}
|
||||
|
||||
impl ExitStatus {
|
||||
pub fn success() -> Self {
|
||||
Self {
|
||||
code: Some(0),
|
||||
signaled: false,
|
||||
signal: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_code(code: i32) -> Self {
|
||||
Self {
|
||||
code: Some(code),
|
||||
signaled: false,
|
||||
signal: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signaled(signal: i32) -> Self {
|
||||
Self {
|
||||
code: None,
|
||||
signaled: true,
|
||||
signal: Some(signal),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_success(&self) -> bool {
|
||||
self.code == Some(0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn handle_serialization() {
|
||||
let handle = HostSessionHandle::new(
|
||||
SessionId::new(),
|
||||
HostHandlePayload::Linux { pid: 1234, pgid: 1234 },
|
||||
);
|
||||
|
||||
let json = serde_json::to_string(&handle).unwrap();
|
||||
let parsed: HostSessionHandle = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(handle.payload().pid(), parsed.payload().pid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exit_status() {
|
||||
assert!(ExitStatus::success().is_success());
|
||||
assert!(!ExitStatus::with_code(1).is_success());
|
||||
assert!(!ExitStatus::signaled(9).is_success());
|
||||
}
|
||||
}
|
||||
14
crates/shepherd-host-api/src/lib.rs
Normal file
14
crates/shepherd-host-api/src/lib.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
//! Host adapter trait interfaces for shepherdd
|
||||
//!
|
||||
//! This crate defines the capability-based interface between the daemon core
|
||||
//! and platform-specific implementations. It contains no platform code itself.
|
||||
|
||||
mod capabilities;
|
||||
mod handle;
|
||||
mod mock;
|
||||
mod traits;
|
||||
|
||||
pub use capabilities::*;
|
||||
pub use handle::*;
|
||||
pub use mock::*;
|
||||
pub use traits::*;
|
||||
231
crates/shepherd-host-api/src/mock.rs
Normal file
231
crates/shepherd-host-api/src/mock.rs
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
//! Mock host adapter for testing
|
||||
|
||||
use async_trait::async_trait;
|
||||
use shepherd_api::EntryKind;
|
||||
use shepherd_util::SessionId;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{
|
||||
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload,
|
||||
HostResult, HostSessionHandle, SpawnOptions, StopMode,
|
||||
};
|
||||
|
||||
/// Mock session state for testing
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockSession {
|
||||
pub session_id: SessionId,
|
||||
pub mock_id: u64,
|
||||
pub running: bool,
|
||||
pub exit_delay: Option<Duration>,
|
||||
}
|
||||
|
||||
/// Mock host adapter for unit/integration testing
|
||||
pub struct MockHost {
|
||||
capabilities: HostCapabilities,
|
||||
next_id: AtomicU64,
|
||||
sessions: Arc<Mutex<HashMap<u64, MockSession>>>,
|
||||
event_tx: mpsc::UnboundedSender<HostEvent>,
|
||||
event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<HostEvent>>>>,
|
||||
|
||||
/// Configure spawn to fail
|
||||
pub fail_spawn: Arc<Mutex<bool>>,
|
||||
|
||||
/// Configure stop to fail
|
||||
pub fail_stop: Arc<Mutex<bool>>,
|
||||
|
||||
/// Auto-exit delay (simulates process exiting on its own)
|
||||
pub auto_exit_delay: Arc<Mutex<Option<Duration>>>,
|
||||
}
|
||||
|
||||
impl MockHost {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
|
||||
Self {
|
||||
capabilities: HostCapabilities::minimal(),
|
||||
next_id: AtomicU64::new(1),
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
event_tx: tx,
|
||||
event_rx: Arc::new(Mutex::new(Some(rx))),
|
||||
fail_spawn: Arc::new(Mutex::new(false)),
|
||||
fail_stop: Arc::new(Mutex::new(false)),
|
||||
auto_exit_delay: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_capabilities(mut self, caps: HostCapabilities) -> Self {
|
||||
self.capabilities = caps;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get list of running sessions
|
||||
pub fn running_sessions(&self) -> Vec<SessionId> {
|
||||
self.sessions
|
||||
.lock()
|
||||
.unwrap()
|
||||
.values()
|
||||
.filter(|s| s.running)
|
||||
.map(|s| s.session_id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Simulate process exit
|
||||
pub fn simulate_exit(&self, session_id: &SessionId, status: ExitStatus) {
|
||||
let sessions = self.sessions.lock().unwrap();
|
||||
if let Some(session) = sessions.values().find(|s| &s.session_id == session_id) {
|
||||
let handle = HostSessionHandle::new(
|
||||
session.session_id.clone(),
|
||||
HostHandlePayload::Mock { id: session.mock_id },
|
||||
);
|
||||
let _ = self.event_tx.send(HostEvent::Exited { handle, status });
|
||||
}
|
||||
}
|
||||
|
||||
/// Set auto-exit behavior
|
||||
pub fn set_auto_exit(&self, delay: Option<Duration>) {
|
||||
*self.auto_exit_delay.lock().unwrap() = delay;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockHost {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HostAdapter for MockHost {
|
||||
fn capabilities(&self) -> &HostCapabilities {
|
||||
&self.capabilities
|
||||
}
|
||||
|
||||
async fn spawn(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
_entry_kind: &EntryKind,
|
||||
_options: SpawnOptions,
|
||||
) -> HostResult<HostSessionHandle> {
|
||||
if *self.fail_spawn.lock().unwrap() {
|
||||
return Err(HostError::SpawnFailed("Mock spawn failure".into()));
|
||||
}
|
||||
|
||||
let mock_id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
let session = MockSession {
|
||||
session_id: session_id.clone(),
|
||||
mock_id,
|
||||
running: true,
|
||||
exit_delay: *self.auto_exit_delay.lock().unwrap(),
|
||||
};
|
||||
|
||||
self.sessions.lock().unwrap().insert(mock_id, session.clone());
|
||||
|
||||
let handle = HostSessionHandle::new(
|
||||
session_id.clone(),
|
||||
HostHandlePayload::Mock { id: mock_id },
|
||||
);
|
||||
|
||||
// If auto-exit is configured, spawn a task to send exit event
|
||||
if let Some(delay) = session.exit_delay {
|
||||
let tx = self.event_tx.clone();
|
||||
let exit_handle = handle.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(delay).await;
|
||||
let _ = tx.send(HostEvent::Exited {
|
||||
handle: exit_handle,
|
||||
status: ExitStatus::success(),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn stop(&self, handle: &HostSessionHandle, _mode: StopMode) -> HostResult<()> {
|
||||
if *self.fail_stop.lock().unwrap() {
|
||||
return Err(HostError::StopFailed("Mock stop failure".into()));
|
||||
}
|
||||
|
||||
let mock_id = match handle.payload() {
|
||||
HostHandlePayload::Mock { id } => *id,
|
||||
_ => return Err(HostError::SessionNotFound),
|
||||
};
|
||||
|
||||
let mut sessions = self.sessions.lock().unwrap();
|
||||
if let Some(session) = sessions.get_mut(&mock_id) {
|
||||
session.running = false;
|
||||
let _ = self.event_tx.send(HostEvent::Exited {
|
||||
handle: handle.clone(),
|
||||
status: ExitStatus::signaled(15), // SIGTERM
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(HostError::SessionNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> mpsc::UnboundedReceiver<HostEvent> {
|
||||
self.event_rx
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.expect("subscribe() can only be called once")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_spawn_and_stop() {
|
||||
let host = MockHost::new();
|
||||
let _rx = host.subscribe();
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let entry = EntryKind::Process {
|
||||
argv: vec!["test".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
let handle = host
|
||||
.spawn(session_id.clone(), &entry, SpawnOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(host.running_sessions().len(), 1);
|
||||
|
||||
host.stop(&handle, StopMode::Force).await.unwrap();
|
||||
|
||||
// Session marked as not running
|
||||
let sessions = host.sessions.lock().unwrap();
|
||||
let session = sessions.values().next().unwrap();
|
||||
assert!(!session.running);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mock_spawn_failure() {
|
||||
let host = MockHost::new();
|
||||
let _rx = host.subscribe();
|
||||
*host.fail_spawn.lock().unwrap() = true;
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let entry = EntryKind::Process {
|
||||
argv: vec!["test".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
let result = host
|
||||
.spawn(session_id, &entry, SpawnOptions::default())
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
146
crates/shepherd-host-api/src/traits.rs
Normal file
146
crates/shepherd-host-api/src/traits.rs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
//! Host adapter traits
|
||||
|
||||
use async_trait::async_trait;
|
||||
use shepherd_api::EntryKind;
|
||||
use shepherd_util::SessionId;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::{ExitStatus, HostCapabilities, HostSessionHandle};
|
||||
|
||||
/// Errors from host adapter operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HostError {
|
||||
#[error("Spawn failed: {0}")]
|
||||
SpawnFailed(String),
|
||||
|
||||
#[error("Stop failed: {0}")]
|
||||
StopFailed(String),
|
||||
|
||||
#[error("Unsupported entry kind")]
|
||||
UnsupportedKind,
|
||||
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
#[error("Permission denied: {0}")]
|
||||
PermissionDenied(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
pub type HostResult<T> = Result<T, HostError>;
|
||||
|
||||
/// Stop mode for session termination
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum StopMode {
|
||||
/// Try graceful stop with timeout, then force
|
||||
Graceful { timeout: Duration },
|
||||
/// Force immediate termination
|
||||
Force,
|
||||
}
|
||||
|
||||
impl Default for StopMode {
|
||||
fn default() -> Self {
|
||||
Self::Graceful {
|
||||
timeout: Duration::from_secs(5),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for spawning a session
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SpawnOptions {
|
||||
/// Capture stdout to log file
|
||||
pub capture_stdout: bool,
|
||||
|
||||
/// Capture stderr to log file
|
||||
pub capture_stderr: bool,
|
||||
|
||||
/// Log file path (if capturing)
|
||||
pub log_path: Option<std::path::PathBuf>,
|
||||
|
||||
/// Request fullscreen (if supported)
|
||||
pub fullscreen: bool,
|
||||
|
||||
/// Request foreground focus (if supported)
|
||||
pub foreground: bool,
|
||||
}
|
||||
|
||||
/// Events from the host adapter
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HostEvent {
|
||||
/// Process/session has exited
|
||||
Exited {
|
||||
handle: HostSessionHandle,
|
||||
status: ExitStatus,
|
||||
},
|
||||
|
||||
/// Window is ready (for UI notification)
|
||||
WindowReady {
|
||||
handle: HostSessionHandle,
|
||||
},
|
||||
|
||||
/// Spawn failed after handle was created
|
||||
SpawnFailed {
|
||||
session_id: SessionId,
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Host adapter trait - implemented by platform-specific adapters
|
||||
#[async_trait]
|
||||
pub trait HostAdapter: Send + Sync {
|
||||
/// Get the capabilities of this host adapter
|
||||
fn capabilities(&self) -> &HostCapabilities;
|
||||
|
||||
/// Spawn a new session
|
||||
async fn spawn(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
entry_kind: &EntryKind,
|
||||
options: SpawnOptions,
|
||||
) -> HostResult<HostSessionHandle>;
|
||||
|
||||
/// Stop a running session
|
||||
async fn stop(&self, handle: &HostSessionHandle, mode: StopMode) -> HostResult<()>;
|
||||
|
||||
/// Subscribe to host events
|
||||
fn subscribe(&self) -> mpsc::UnboundedReceiver<HostEvent>;
|
||||
|
||||
/// Optional: set foreground focus (if supported)
|
||||
async fn set_foreground(&self, _handle: &HostSessionHandle) -> HostResult<()> {
|
||||
Err(HostError::Internal("Not supported".into()))
|
||||
}
|
||||
|
||||
/// Optional: set fullscreen mode (if supported)
|
||||
async fn set_fullscreen(&self, _handle: &HostSessionHandle) -> HostResult<()> {
|
||||
Err(HostError::Internal("Not supported".into()))
|
||||
}
|
||||
|
||||
/// Optional: ensure the shell/launcher is visible
|
||||
async fn ensure_shell_visible(&self) -> HostResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Optional: check if the host adapter is healthy
|
||||
fn is_healthy(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stop_mode_default() {
|
||||
let mode = StopMode::default();
|
||||
assert!(matches!(mode, StopMode::Graceful { timeout } if timeout == Duration::from_secs(5)));
|
||||
}
|
||||
}
|
||||
20
crates/shepherd-host-linux/Cargo.toml
Normal file
20
crates/shepherd-host-linux/Cargo.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "shepherd-host-linux"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Linux host adapter for shepherdd: process groups, spawn/kill, exit observation"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
shepherd-host-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
nix = { workspace = true }
|
||||
async-trait = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
284
crates/shepherd-host-linux/src/adapter.rs
Normal file
284
crates/shepherd-host-linux/src/adapter.rs
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
//! Linux host adapter implementation
|
||||
|
||||
use async_trait::async_trait;
|
||||
use shepherd_api::{EntryKind, EntryKindTag};
|
||||
use shepherd_host_api::{
|
||||
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload,
|
||||
HostResult, HostSessionHandle, SpawnOptions, StopMode,
|
||||
};
|
||||
use shepherd_util::SessionId;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::ManagedProcess;
|
||||
|
||||
/// Linux host adapter
|
||||
pub struct LinuxHost {
|
||||
capabilities: HostCapabilities,
|
||||
processes: Arc<Mutex<HashMap<u32, ManagedProcess>>>,
|
||||
event_tx: mpsc::UnboundedSender<HostEvent>,
|
||||
event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<HostEvent>>>>,
|
||||
}
|
||||
|
||||
impl LinuxHost {
|
||||
pub fn new() -> Self {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
|
||||
Self {
|
||||
capabilities: HostCapabilities::linux_full(),
|
||||
processes: Arc::new(Mutex::new(HashMap::new())),
|
||||
event_tx: tx,
|
||||
event_rx: Arc::new(Mutex::new(Some(rx))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the background process monitor
|
||||
pub fn start_monitor(&self) -> tokio::task::JoinHandle<()> {
|
||||
let processes = self.processes.clone();
|
||||
let event_tx = self.event_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let mut exited = Vec::new();
|
||||
|
||||
{
|
||||
let mut procs = processes.lock().unwrap();
|
||||
for (pid, proc) in procs.iter_mut() {
|
||||
match proc.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
exited.push((*pid, proc.pgid, status));
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
warn!(pid = pid, error = %e, "Error checking process status");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (pid, _, _) in &exited {
|
||||
procs.remove(pid);
|
||||
}
|
||||
}
|
||||
|
||||
for (pid, pgid, status) in exited {
|
||||
debug!(pid = pid, status = ?status, "Process exited");
|
||||
|
||||
// We don't have the session_id here, so we use a placeholder
|
||||
// The daemon should track the mapping
|
||||
let handle = HostSessionHandle::new(
|
||||
SessionId::new(), // This will be matched by PID
|
||||
HostHandlePayload::Linux { pid, pgid },
|
||||
);
|
||||
|
||||
let _ = event_tx.send(HostEvent::Exited { handle, status });
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinuxHost {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HostAdapter for LinuxHost {
|
||||
fn capabilities(&self) -> &HostCapabilities {
|
||||
&self.capabilities
|
||||
}
|
||||
|
||||
async fn spawn(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
entry_kind: &EntryKind,
|
||||
options: SpawnOptions,
|
||||
) -> HostResult<HostSessionHandle> {
|
||||
let (argv, env, cwd) = match entry_kind {
|
||||
EntryKind::Process { argv, env, cwd } => (argv.clone(), env.clone(), cwd.clone()),
|
||||
EntryKind::Vm { driver, args } => {
|
||||
// Construct command line from VM driver
|
||||
let mut argv = vec![driver.clone()];
|
||||
for (key, value) in args {
|
||||
argv.push(format!("--{}", key));
|
||||
if let Some(v) = value.as_str() {
|
||||
argv.push(v.to_string());
|
||||
} else {
|
||||
argv.push(value.to_string());
|
||||
}
|
||||
}
|
||||
(argv, HashMap::new(), None)
|
||||
}
|
||||
EntryKind::Media { library_id, args } => {
|
||||
// For media, we'd typically launch a media player
|
||||
// This is a placeholder - real implementation would integrate with a player
|
||||
let mut argv = vec!["xdg-open".to_string(), library_id.clone()];
|
||||
(argv, HashMap::new(), None)
|
||||
}
|
||||
EntryKind::Custom { type_name, payload } => {
|
||||
return Err(HostError::UnsupportedKind);
|
||||
}
|
||||
};
|
||||
|
||||
let proc = ManagedProcess::spawn(
|
||||
&argv,
|
||||
&env,
|
||||
cwd.as_ref(),
|
||||
options.capture_stdout || options.capture_stderr,
|
||||
)?;
|
||||
|
||||
let pid = proc.pid;
|
||||
let pgid = proc.pgid;
|
||||
|
||||
let handle = HostSessionHandle::new(
|
||||
session_id,
|
||||
HostHandlePayload::Linux { pid, pgid },
|
||||
);
|
||||
|
||||
self.processes.lock().unwrap().insert(pid, proc);
|
||||
|
||||
info!(pid = pid, pgid = pgid, "Spawned process");
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn stop(&self, handle: &HostSessionHandle, mode: StopMode) -> HostResult<()> {
|
||||
let (pid, _pgid) = match handle.payload() {
|
||||
HostHandlePayload::Linux { pid, pgid } => (*pid, *pgid),
|
||||
_ => return Err(HostError::SessionNotFound),
|
||||
};
|
||||
|
||||
// Check if process exists
|
||||
{
|
||||
let procs = self.processes.lock().unwrap();
|
||||
if !procs.contains_key(&pid) {
|
||||
return Err(HostError::SessionNotFound);
|
||||
}
|
||||
}
|
||||
|
||||
match mode {
|
||||
StopMode::Graceful { timeout } => {
|
||||
// Send SIGTERM
|
||||
{
|
||||
let procs = self.processes.lock().unwrap();
|
||||
if let Some(p) = procs.get(&pid) {
|
||||
p.terminate()?;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for graceful exit
|
||||
let start = std::time::Instant::now();
|
||||
loop {
|
||||
if start.elapsed() >= timeout {
|
||||
// Force kill after timeout
|
||||
let procs = self.processes.lock().unwrap();
|
||||
if let Some(p) = procs.get(&pid) {
|
||||
p.kill()?;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
{
|
||||
let procs = self.processes.lock().unwrap();
|
||||
if !procs.contains_key(&pid) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
StopMode::Force => {
|
||||
let procs = self.processes.lock().unwrap();
|
||||
if let Some(p) = procs.get(&pid) {
|
||||
p.kill()?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> mpsc::UnboundedReceiver<HostEvent> {
|
||||
self.event_rx
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.expect("subscribe() can only be called once")
|
||||
}
|
||||
|
||||
fn is_healthy(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_and_exit() {
|
||||
let host = LinuxHost::new();
|
||||
let _rx = host.subscribe();
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let entry = EntryKind::Process {
|
||||
argv: vec!["true".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
let handle = host
|
||||
.spawn(session_id, &entry, SpawnOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Give it time to exit
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
// Process should have exited
|
||||
match handle.payload() {
|
||||
HostHandlePayload::Linux { pid, .. } => {
|
||||
let procs = host.processes.lock().unwrap();
|
||||
// Process may or may not still be tracked depending on monitor timing
|
||||
}
|
||||
_ => panic!("Expected Linux handle"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_and_kill() {
|
||||
let host = LinuxHost::new();
|
||||
let _rx = host.subscribe();
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let entry = EntryKind::Process {
|
||||
argv: vec!["sleep".into(), "60".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
let handle = host
|
||||
.spawn(session_id, &entry, SpawnOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(50)).await;
|
||||
|
||||
// Kill it
|
||||
host.stop(
|
||||
&handle,
|
||||
StopMode::Graceful {
|
||||
timeout: Duration::from_secs(1),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
13
crates/shepherd-host-linux/src/lib.rs
Normal file
13
crates/shepherd-host-linux/src/lib.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
//! Linux host adapter for shepherdd
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Process spawning with process group isolation
|
||||
//! - Graceful (SIGTERM) and forceful (SIGKILL) termination
|
||||
//! - Exit observation
|
||||
//! - stdout/stderr capture
|
||||
|
||||
mod adapter;
|
||||
mod process;
|
||||
|
||||
pub use adapter::*;
|
||||
pub use process::*;
|
||||
245
crates/shepherd-host-linux/src/process.rs
Normal file
245
crates/shepherd-host-linux/src/process.rs
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
//! Process management utilities
|
||||
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus};
|
||||
use nix::unistd::Pid;
|
||||
use std::collections::HashMap;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use shepherd_host_api::{ExitStatus, HostError, HostResult};
|
||||
|
||||
/// Managed child process with process group
|
||||
pub struct ManagedProcess {
|
||||
pub child: Child,
|
||||
pub pid: u32,
|
||||
pub pgid: u32,
|
||||
}
|
||||
|
||||
impl ManagedProcess {
|
||||
/// Spawn a new process in its own process group
|
||||
pub fn spawn(
|
||||
argv: &[String],
|
||||
env: &HashMap<String, String>,
|
||||
cwd: Option<&PathBuf>,
|
||||
capture_output: bool,
|
||||
) -> HostResult<Self> {
|
||||
if argv.is_empty() {
|
||||
return Err(HostError::SpawnFailed("Empty argv".into()));
|
||||
}
|
||||
|
||||
let program = &argv[0];
|
||||
let args = &argv[1..];
|
||||
|
||||
let mut cmd = Command::new(program);
|
||||
cmd.args(args);
|
||||
|
||||
// Set environment
|
||||
cmd.env_clear();
|
||||
// Inherit some basic environment
|
||||
if let Ok(path) = std::env::var("PATH") {
|
||||
cmd.env("PATH", path);
|
||||
}
|
||||
if let Ok(home) = std::env::var("HOME") {
|
||||
cmd.env("HOME", home);
|
||||
}
|
||||
if let Ok(display) = std::env::var("DISPLAY") {
|
||||
cmd.env("DISPLAY", display);
|
||||
}
|
||||
if let Ok(wayland) = std::env::var("WAYLAND_DISPLAY") {
|
||||
cmd.env("WAYLAND_DISPLAY", wayland);
|
||||
}
|
||||
if let Ok(xdg_runtime) = std::env::var("XDG_RUNTIME_DIR") {
|
||||
cmd.env("XDG_RUNTIME_DIR", xdg_runtime);
|
||||
}
|
||||
|
||||
// Add custom environment
|
||||
for (k, v) in env {
|
||||
cmd.env(k, v);
|
||||
}
|
||||
|
||||
// Set working directory
|
||||
if let Some(dir) = cwd {
|
||||
cmd.current_dir(dir);
|
||||
}
|
||||
|
||||
// Configure output capture
|
||||
if capture_output {
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
} else {
|
||||
cmd.stdout(Stdio::null());
|
||||
cmd.stderr(Stdio::null());
|
||||
}
|
||||
|
||||
cmd.stdin(Stdio::null());
|
||||
|
||||
// Set up process group - this child becomes its own process group leader
|
||||
// SAFETY: This is safe in the pre-exec context
|
||||
unsafe {
|
||||
cmd.pre_exec(|| {
|
||||
// Create new session (which creates new process group)
|
||||
// This ensures the child is the leader of a new process group
|
||||
nix::unistd::setsid().map_err(|e| {
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
|
||||
})?;
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
let child = cmd.spawn().map_err(|e| {
|
||||
HostError::SpawnFailed(format!("Failed to spawn {}: {}", program, e))
|
||||
})?;
|
||||
|
||||
let pid = child.id();
|
||||
let pgid = pid; // After setsid, pid == pgid
|
||||
|
||||
debug!(pid = pid, pgid = pgid, program = %program, "Process spawned");
|
||||
|
||||
Ok(Self { child, pid, pgid })
|
||||
}
|
||||
|
||||
/// Send SIGTERM to the process group
|
||||
pub fn terminate(&self) -> HostResult<()> {
|
||||
let pgid = Pid::from_raw(-(self.pgid as i32)); // Negative for process group
|
||||
|
||||
match signal::kill(pgid, Signal::SIGTERM) {
|
||||
Ok(()) => {
|
||||
debug!(pgid = self.pgid, "Sent SIGTERM to process group");
|
||||
Ok(())
|
||||
}
|
||||
Err(nix::errno::Errno::ESRCH) => {
|
||||
// Process already gone
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(HostError::StopFailed(format!(
|
||||
"Failed to send SIGTERM: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send SIGKILL to the process group
|
||||
pub fn kill(&self) -> HostResult<()> {
|
||||
let pgid = Pid::from_raw(-(self.pgid as i32));
|
||||
|
||||
match signal::kill(pgid, Signal::SIGKILL) {
|
||||
Ok(()) => {
|
||||
debug!(pgid = self.pgid, "Sent SIGKILL to process group");
|
||||
Ok(())
|
||||
}
|
||||
Err(nix::errno::Errno::ESRCH) => {
|
||||
// Process already gone
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(HostError::StopFailed(format!(
|
||||
"Failed to send SIGKILL: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the process has exited (non-blocking)
|
||||
pub fn try_wait(&mut self) -> HostResult<Option<ExitStatus>> {
|
||||
match self.child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
let exit_status = if let Some(code) = status.code() {
|
||||
ExitStatus::with_code(code)
|
||||
} else {
|
||||
// Killed by signal
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
if let Some(sig) = status.signal() {
|
||||
ExitStatus::signaled(sig)
|
||||
} else {
|
||||
ExitStatus::with_code(-1)
|
||||
}
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
ExitStatus::with_code(-1)
|
||||
}
|
||||
};
|
||||
Ok(Some(exit_status))
|
||||
}
|
||||
Ok(None) => Ok(None), // Still running
|
||||
Err(e) => Err(HostError::Internal(format!("Wait failed: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for the process to exit (blocking)
|
||||
pub fn wait(&mut self) -> HostResult<ExitStatus> {
|
||||
match self.child.wait() {
|
||||
Ok(status) => {
|
||||
let exit_status = if let Some(code) = status.code() {
|
||||
ExitStatus::with_code(code)
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
if let Some(sig) = status.signal() {
|
||||
ExitStatus::signaled(sig)
|
||||
} else {
|
||||
ExitStatus::with_code(-1)
|
||||
}
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
ExitStatus::with_code(-1)
|
||||
}
|
||||
};
|
||||
Ok(exit_status)
|
||||
}
|
||||
Err(e) => Err(HostError::Internal(format!("Wait failed: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn spawn_simple_process() {
|
||||
let argv = vec!["true".to_string()];
|
||||
let env = HashMap::new();
|
||||
|
||||
let mut proc = ManagedProcess::spawn(&argv, &env, None, false).unwrap();
|
||||
|
||||
// Wait for it to complete
|
||||
let status = proc.wait().unwrap();
|
||||
assert!(status.is_success());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spawn_with_args() {
|
||||
let argv = vec!["echo".to_string(), "hello".to_string()];
|
||||
let env = HashMap::new();
|
||||
|
||||
let mut proc = ManagedProcess::spawn(&argv, &env, None, false).unwrap();
|
||||
let status = proc.wait().unwrap();
|
||||
assert!(status.is_success());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn terminate_sleeping_process() {
|
||||
let argv = vec!["sleep".to_string(), "60".to_string()];
|
||||
let env = HashMap::new();
|
||||
|
||||
let proc = ManagedProcess::spawn(&argv, &env, None, false).unwrap();
|
||||
|
||||
// Give it a moment to start
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
|
||||
// Terminate it
|
||||
proc.terminate().unwrap();
|
||||
|
||||
// Wait a bit and check
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
|
||||
// Process should be gone or terminating
|
||||
}
|
||||
}
|
||||
19
crates/shepherd-ipc/Cargo.toml
Normal file
19
crates/shepherd-ipc/Cargo.toml
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "shepherd-ipc"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "IPC layer for shepherdd: Unix domain socket server, NDJSON protocol"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
nix = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
93
crates/shepherd-ipc/src/client.rs
Normal file
93
crates/shepherd-ipc/src/client.rs
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
//! IPC client implementation
|
||||
|
||||
use shepherd_api::{Command, Event, Request, Response, ResponseResult};
|
||||
use std::path::Path;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
use crate::{IpcError, IpcResult};
|
||||
|
||||
/// IPC Client for connecting to shepherdd
|
||||
pub struct IpcClient {
|
||||
reader: BufReader<tokio::net::unix::OwnedReadHalf>,
|
||||
writer: tokio::net::unix::OwnedWriteHalf,
|
||||
next_request_id: u64,
|
||||
}
|
||||
|
||||
impl IpcClient {
|
||||
/// Connect to the daemon
|
||||
pub async fn connect(socket_path: impl AsRef<Path>) -> IpcResult<Self> {
|
||||
let stream = UnixStream::connect(socket_path).await?;
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
|
||||
Ok(Self {
|
||||
reader: BufReader::new(read_half),
|
||||
writer: write_half,
|
||||
next_request_id: 1,
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a command and wait for response
|
||||
pub async fn send(&mut self, command: Command) -> IpcResult<Response> {
|
||||
let request_id = self.next_request_id;
|
||||
self.next_request_id += 1;
|
||||
|
||||
let request = Request::new(request_id, command);
|
||||
let mut json = serde_json::to_string(&request)?;
|
||||
json.push('\n');
|
||||
|
||||
self.writer.write_all(json.as_bytes()).await?;
|
||||
|
||||
// Read response
|
||||
let mut line = String::new();
|
||||
let n = self.reader.read_line(&mut line).await?;
|
||||
if n == 0 {
|
||||
return Err(IpcError::ConnectionClosed);
|
||||
}
|
||||
|
||||
let response: Response = serde_json::from_str(line.trim())?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Subscribe to events and consume this client to return an event stream
|
||||
pub async fn subscribe(mut self) -> IpcResult<EventStream> {
|
||||
let response = self.send(Command::SubscribeEvents).await?;
|
||||
|
||||
match response.result {
|
||||
ResponseResult::Ok(_) => {}
|
||||
ResponseResult::Err(e) => {
|
||||
return Err(IpcError::ServerError(e.message));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(EventStream {
|
||||
reader: self.reader,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream of events from the daemon
|
||||
pub struct EventStream {
|
||||
reader: BufReader<tokio::net::unix::OwnedReadHalf>,
|
||||
}
|
||||
|
||||
impl EventStream {
|
||||
/// Wait for the next event
|
||||
pub async fn next(&mut self) -> IpcResult<Event> {
|
||||
let mut line = String::new();
|
||||
let n = self.reader.read_line(&mut line).await?;
|
||||
if n == 0 {
|
||||
return Err(IpcError::ConnectionClosed);
|
||||
}
|
||||
|
||||
let event: Event = serde_json::from_str(line.trim())?;
|
||||
Ok(event)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// Client tests would require a running server
|
||||
// See integration tests
|
||||
}
|
||||
36
crates/shepherd-ipc/src/lib.rs
Normal file
36
crates/shepherd-ipc/src/lib.rs
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
//! IPC layer for shepherdd
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Unix domain socket server
|
||||
//! - NDJSON (newline-delimited JSON) protocol
|
||||
//! - Client connection management
|
||||
//! - Peer UID authentication
|
||||
|
||||
mod client;
|
||||
mod server;
|
||||
|
||||
pub use client::*;
|
||||
pub use server::*;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// IPC errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum IpcError {
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("Connection closed")]
|
||||
ConnectionClosed,
|
||||
|
||||
#[error("Invalid message: {0}")]
|
||||
InvalidMessage(String),
|
||||
|
||||
#[error("Server error: {0}")]
|
||||
ServerError(String),
|
||||
}
|
||||
|
||||
pub type IpcResult<T> = Result<T, IpcError>;
|
||||
336
crates/shepherd-ipc/src/server.rs
Normal file
336
crates/shepherd-ipc/src/server.rs
Normal file
|
|
@ -0,0 +1,336 @@
|
|||
//! IPC server implementation
|
||||
|
||||
use shepherd_api::{ClientInfo, ClientRole, Event, Request, Response};
|
||||
use shepherd_util::ClientId;
|
||||
use std::collections::HashMap;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{IpcError, IpcResult};
|
||||
|
||||
/// Message from client to server
|
||||
pub enum ServerMessage {
|
||||
Request {
|
||||
client_id: ClientId,
|
||||
request: Request,
|
||||
},
|
||||
ClientConnected {
|
||||
client_id: ClientId,
|
||||
info: ClientInfo,
|
||||
},
|
||||
ClientDisconnected {
|
||||
client_id: ClientId,
|
||||
},
|
||||
}
|
||||
|
||||
/// IPC Server
|
||||
pub struct IpcServer {
|
||||
socket_path: PathBuf,
|
||||
listener: Option<UnixListener>,
|
||||
clients: Arc<RwLock<HashMap<ClientId, ClientHandle>>>,
|
||||
event_tx: broadcast::Sender<Event>,
|
||||
message_tx: mpsc::UnboundedSender<ServerMessage>,
|
||||
message_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<ServerMessage>>>>,
|
||||
}
|
||||
|
||||
struct ClientHandle {
|
||||
info: ClientInfo,
|
||||
response_tx: mpsc::UnboundedSender<String>,
|
||||
subscribed: bool,
|
||||
}
|
||||
|
||||
impl IpcServer {
|
||||
/// Create a new IPC server
|
||||
pub fn new(socket_path: impl AsRef<Path>) -> Self {
|
||||
let (event_tx, _) = broadcast::channel(100);
|
||||
let (message_tx, message_rx) = mpsc::unbounded_channel();
|
||||
|
||||
Self {
|
||||
socket_path: socket_path.as_ref().to_path_buf(),
|
||||
listener: None,
|
||||
clients: Arc::new(RwLock::new(HashMap::new())),
|
||||
event_tx,
|
||||
message_tx,
|
||||
message_rx: Arc::new(Mutex::new(Some(message_rx))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start listening
|
||||
pub async fn start(&mut self) -> IpcResult<()> {
|
||||
// Remove existing socket if present
|
||||
if self.socket_path.exists() {
|
||||
std::fs::remove_file(&self.socket_path)?;
|
||||
}
|
||||
|
||||
// Create parent directory if needed
|
||||
if let Some(parent) = self.socket_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let listener = UnixListener::bind(&self.socket_path)?;
|
||||
|
||||
// Set socket permissions (readable/writable by owner and group)
|
||||
std::fs::set_permissions(&self.socket_path, std::fs::Permissions::from_mode(0o660))?;
|
||||
|
||||
info!(path = %self.socket_path.display(), "IPC server listening");
|
||||
|
||||
self.listener = Some(listener);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get receiver for server messages
|
||||
pub async fn take_message_receiver(&self) -> Option<mpsc::UnboundedReceiver<ServerMessage>> {
|
||||
self.message_rx.lock().await.take()
|
||||
}
|
||||
|
||||
/// Accept connections in a loop
|
||||
pub async fn run(&self) -> IpcResult<()> {
|
||||
let listener = self
|
||||
.listener
|
||||
.as_ref()
|
||||
.ok_or_else(|| IpcError::ServerError("Server not started".into()))?;
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, _)) => {
|
||||
let client_id = ClientId::new();
|
||||
|
||||
// Get peer credentials
|
||||
let uid = get_peer_uid(&stream);
|
||||
|
||||
// Determine role based on UID
|
||||
let role = match uid {
|
||||
Some(0) => ClientRole::Admin, // root
|
||||
Some(u) if u == nix::unistd::getuid().as_raw() => ClientRole::Admin,
|
||||
_ => ClientRole::Shell,
|
||||
};
|
||||
|
||||
let info = ClientInfo::new(role);
|
||||
let info = if let Some(u) = uid {
|
||||
info.with_uid(u)
|
||||
} else {
|
||||
info
|
||||
};
|
||||
|
||||
info!(client_id = %client_id, uid = ?uid, role = ?role, "Client connected");
|
||||
|
||||
self.handle_client(stream, client_id, info).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to accept connection");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_client(&self, stream: UnixStream, client_id: ClientId, info: ClientInfo) {
|
||||
let (read_half, write_half) = stream.into_split();
|
||||
let (response_tx, mut response_rx) = mpsc::unbounded_channel::<String>();
|
||||
|
||||
// Register client
|
||||
{
|
||||
let mut clients = self.clients.write().await;
|
||||
clients.insert(
|
||||
client_id.clone(),
|
||||
ClientHandle {
|
||||
info: info.clone(),
|
||||
response_tx: response_tx.clone(),
|
||||
subscribed: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Notify of connection
|
||||
let _ = self.message_tx.send(ServerMessage::ClientConnected {
|
||||
client_id: client_id.clone(),
|
||||
info: info.clone(),
|
||||
});
|
||||
|
||||
let clients = self.clients.clone();
|
||||
let message_tx = self.message_tx.clone();
|
||||
let event_tx = self.event_tx.clone();
|
||||
let client_id_clone = client_id.clone();
|
||||
|
||||
// Spawn reader task
|
||||
let reader_handle = tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(read_half);
|
||||
let mut line = String::new();
|
||||
|
||||
loop {
|
||||
line.clear();
|
||||
match reader.read_line(&mut line).await {
|
||||
Ok(0) => {
|
||||
debug!(client_id = %client_id_clone, "Client disconnected (EOF)");
|
||||
break;
|
||||
}
|
||||
Ok(_) => {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<Request>(line) {
|
||||
Ok(request) => {
|
||||
// Check for subscribe command
|
||||
if matches!(request.command, shepherd_api::Command::SubscribeEvents) {
|
||||
let mut clients = clients.write().await;
|
||||
if let Some(handle) = clients.get_mut(&client_id_clone) {
|
||||
handle.subscribed = true;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = message_tx.send(ServerMessage::Request {
|
||||
client_id: client_id_clone.clone(),
|
||||
request,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
client_id = %client_id_clone,
|
||||
error = %e,
|
||||
"Invalid request"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(client_id = %client_id_clone, error = %e, "Read error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn writer task
|
||||
let mut event_rx = event_tx.subscribe();
|
||||
let clients_writer = self.clients.clone();
|
||||
let client_id_writer = client_id.clone();
|
||||
let message_tx_writer = self.message_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut writer = write_half;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Handle responses
|
||||
Some(response) = response_rx.recv() => {
|
||||
let mut msg = response;
|
||||
msg.push('\n');
|
||||
if let Err(e) = writer.write_all(msg.as_bytes()).await {
|
||||
debug!(client_id = %client_id_writer, error = %e, "Write error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle events (for subscribed clients)
|
||||
Ok(event) = event_rx.recv() => {
|
||||
let is_subscribed = {
|
||||
let clients = clients_writer.read().await;
|
||||
clients.get(&client_id_writer).map(|h| h.subscribed).unwrap_or(false)
|
||||
};
|
||||
|
||||
if is_subscribed {
|
||||
if let Ok(json) = serde_json::to_string(&event) {
|
||||
let mut msg = json;
|
||||
msg.push('\n');
|
||||
if let Err(e) = writer.write_all(msg.as_bytes()).await {
|
||||
debug!(client_id = %client_id_writer, error = %e, "Event write error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify of disconnection
|
||||
let _ = message_tx_writer.send(ServerMessage::ClientDisconnected {
|
||||
client_id: client_id_writer.clone(),
|
||||
});
|
||||
|
||||
// Remove client
|
||||
let mut clients = clients_writer.write().await;
|
||||
clients.remove(&client_id_writer);
|
||||
});
|
||||
}
|
||||
|
||||
/// Send a response to a specific client
|
||||
pub async fn send_response(&self, client_id: &ClientId, response: Response) -> IpcResult<()> {
|
||||
let json = serde_json::to_string(&response)?;
|
||||
|
||||
let clients = self.clients.read().await;
|
||||
if let Some(handle) = clients.get(client_id) {
|
||||
handle
|
||||
.response_tx
|
||||
.send(json)
|
||||
.map_err(|_| IpcError::ConnectionClosed)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Broadcast an event to all subscribed clients
|
||||
pub fn broadcast_event(&self, event: Event) {
|
||||
let _ = self.event_tx.send(event);
|
||||
}
|
||||
|
||||
/// Get client info
|
||||
pub async fn get_client_info(&self, client_id: &ClientId) -> Option<ClientInfo> {
|
||||
let clients = self.clients.read().await;
|
||||
clients.get(client_id).map(|h| h.info.clone())
|
||||
}
|
||||
|
||||
/// Get connected client count
|
||||
pub async fn client_count(&self) -> usize {
|
||||
self.clients.read().await.len()
|
||||
}
|
||||
|
||||
/// Shutdown the server
|
||||
pub fn shutdown(&self) {
|
||||
if self.socket_path.exists() {
|
||||
let _ = std::fs::remove_file(&self.socket_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for IpcServer {
|
||||
fn drop(&mut self) {
|
||||
self.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get peer UID from Unix socket
|
||||
fn get_peer_uid(stream: &UnixStream) -> Option<u32> {
|
||||
use std::os::unix::io::AsFd;
|
||||
|
||||
// Get the borrowed file descriptor from the stream
|
||||
let fd = stream.as_fd();
|
||||
|
||||
match nix::sys::socket::getsockopt(&fd, nix::sys::socket::sockopt::PeerCredentials) {
|
||||
Ok(cred) => Some(cred.uid()),
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_server_start() {
|
||||
let dir = tempdir().unwrap();
|
||||
let socket_path = dir.path().join("test.sock");
|
||||
|
||||
let mut server = IpcServer::new(&socket_path);
|
||||
server.start().await.unwrap();
|
||||
|
||||
assert!(socket_path.exists());
|
||||
}
|
||||
}
|
||||
20
crates/shepherd-store/Cargo.toml
Normal file
20
crates/shepherd-store/Cargo.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "shepherd-store"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Persistence layer for shepherdd: audit log, usage accounting, cooldowns"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
rusqlite = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
92
crates/shepherd-store/src/audit.rs
Normal file
92
crates/shepherd-store/src/audit.rs
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
//! Audit event types
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use shepherd_api::SessionEndReason;
|
||||
use shepherd_util::{EntryId, SessionId};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Types of audit events
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AuditEventType {
|
||||
/// Daemon started
|
||||
DaemonStarted,
|
||||
|
||||
/// Daemon stopped
|
||||
DaemonStopped,
|
||||
|
||||
/// Policy loaded/reloaded
|
||||
PolicyLoaded { entry_count: usize },
|
||||
|
||||
/// Session started
|
||||
SessionStarted {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
label: String,
|
||||
deadline: DateTime<Local>,
|
||||
},
|
||||
|
||||
/// Warning issued
|
||||
WarningIssued {
|
||||
session_id: SessionId,
|
||||
threshold_seconds: u64,
|
||||
},
|
||||
|
||||
/// Session ended
|
||||
SessionEnded {
|
||||
session_id: SessionId,
|
||||
entry_id: EntryId,
|
||||
reason: SessionEndReason,
|
||||
duration: Duration,
|
||||
},
|
||||
|
||||
/// Launch denied
|
||||
LaunchDenied {
|
||||
entry_id: EntryId,
|
||||
reasons: Vec<String>,
|
||||
},
|
||||
|
||||
/// Session extended (admin action)
|
||||
SessionExtended {
|
||||
session_id: SessionId,
|
||||
extended_by: Duration,
|
||||
new_deadline: DateTime<Local>,
|
||||
},
|
||||
|
||||
/// Config reload requested
|
||||
ConfigReloaded { success: bool },
|
||||
|
||||
/// Client connected
|
||||
ClientConnected {
|
||||
client_id: String,
|
||||
role: String,
|
||||
uid: Option<u32>,
|
||||
},
|
||||
|
||||
/// Client disconnected
|
||||
ClientDisconnected { client_id: String },
|
||||
}
|
||||
|
||||
/// Full audit event with metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuditEvent {
|
||||
/// Unique event ID
|
||||
pub id: i64,
|
||||
|
||||
/// Event timestamp
|
||||
pub timestamp: DateTime<Local>,
|
||||
|
||||
/// Event type and details
|
||||
pub event: AuditEventType,
|
||||
}
|
||||
|
||||
impl AuditEvent {
|
||||
pub fn new(event: AuditEventType) -> Self {
|
||||
Self {
|
||||
id: 0, // Will be set by store
|
||||
timestamp: Local::now(),
|
||||
event,
|
||||
}
|
||||
}
|
||||
}
|
||||
47
crates/shepherd-store/src/lib.rs
Normal file
47
crates/shepherd-store/src/lib.rs
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
//! Persistence layer for shepherdd
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Audit log (append-only)
|
||||
//! - Usage accounting (per entry/day)
|
||||
//! - Cooldown tracking
|
||||
//! - State snapshot for recovery
|
||||
|
||||
mod audit;
|
||||
mod sqlite;
|
||||
mod traits;
|
||||
|
||||
pub use audit::*;
|
||||
pub use sqlite::*;
|
||||
pub use traits::*;
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Store errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum StoreError {
|
||||
#[error("Database error: {0}")]
|
||||
Database(String),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
}
|
||||
|
||||
impl From<rusqlite::Error> for StoreError {
|
||||
fn from(e: rusqlite::Error) -> Self {
|
||||
StoreError::Database(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for StoreError {
|
||||
fn from(e: serde_json::Error) -> Self {
|
||||
StoreError::Serialization(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub type StoreResult<T> = Result<T, StoreError>;
|
||||
346
crates/shepherd-store/src/sqlite.rs
Normal file
346
crates/shepherd-store/src/sqlite.rs
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
//! SQLite-based store implementation
|
||||
|
||||
use chrono::{DateTime, Local, NaiveDate, TimeZone};
|
||||
use rusqlite::{params, Connection, OptionalExtension};
|
||||
use shepherd_util::EntryId;
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{AuditEvent, SessionSnapshot, StateSnapshot, Store, StoreError, StoreResult};
|
||||
|
||||
/// SQLite-based store
|
||||
pub struct SqliteStore {
|
||||
conn: Mutex<Connection>,
|
||||
}
|
||||
|
||||
impl SqliteStore {
|
||||
/// Open or create a store at the given path
|
||||
pub fn open(path: impl AsRef<Path>) -> StoreResult<Self> {
|
||||
let conn = Connection::open(path)?;
|
||||
let store = Self {
|
||||
conn: Mutex::new(conn),
|
||||
};
|
||||
store.init_schema()?;
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
/// Create an in-memory store (for testing)
|
||||
pub fn in_memory() -> StoreResult<Self> {
|
||||
let conn = Connection::open_in_memory()?;
|
||||
let store = Self {
|
||||
conn: Mutex::new(conn),
|
||||
};
|
||||
store.init_schema()?;
|
||||
Ok(store)
|
||||
}
|
||||
|
||||
fn init_schema(&self) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
-- Audit log (append-only)
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
event_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Usage accounting
|
||||
CREATE TABLE IF NOT EXISTS usage (
|
||||
entry_id TEXT NOT NULL,
|
||||
day TEXT NOT NULL,
|
||||
duration_secs INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (entry_id, day)
|
||||
);
|
||||
|
||||
-- Cooldowns
|
||||
CREATE TABLE IF NOT EXISTS cooldowns (
|
||||
entry_id TEXT PRIMARY KEY,
|
||||
until TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- State snapshot (single row)
|
||||
CREATE TABLE IF NOT EXISTS snapshot (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
snapshot_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
-- Indexes
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_day ON usage(day);
|
||||
"#,
|
||||
)?;
|
||||
|
||||
debug!("Store schema initialized");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Store for SqliteStore {
|
||||
fn append_audit(&self, mut event: AuditEvent) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let event_json = serde_json::to_string(&event.event)?;
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO audit_log (timestamp, event_json) VALUES (?, ?)",
|
||||
params![event.timestamp.to_rfc3339(), event_json],
|
||||
)?;
|
||||
|
||||
event.id = conn.last_insert_rowid();
|
||||
debug!(event_id = event.id, "Audit event appended");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_recent_audits(&self, limit: usize) -> StoreResult<Vec<AuditEvent>> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, timestamp, event_json FROM audit_log ORDER BY id DESC LIMIT ?",
|
||||
)?;
|
||||
|
||||
let rows = stmt.query_map([limit], |row| {
|
||||
let id: i64 = row.get(0)?;
|
||||
let timestamp_str: String = row.get(1)?;
|
||||
let event_json: String = row.get(2)?;
|
||||
Ok((id, timestamp_str, event_json))
|
||||
})?;
|
||||
|
||||
let mut events = Vec::new();
|
||||
for row in rows {
|
||||
let (id, timestamp_str, event_json) = row?;
|
||||
let timestamp = DateTime::parse_from_rfc3339(×tamp_str)
|
||||
.map(|dt| dt.with_timezone(&Local))
|
||||
.unwrap_or_else(|_| Local::now());
|
||||
let event: crate::AuditEventType = serde_json::from_str(&event_json)?;
|
||||
|
||||
events.push(AuditEvent {
|
||||
id,
|
||||
timestamp,
|
||||
event,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
fn get_usage(&self, entry_id: &EntryId, day: NaiveDate) -> StoreResult<Duration> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let day_str = day.format("%Y-%m-%d").to_string();
|
||||
|
||||
let secs: Option<i64> = conn
|
||||
.query_row(
|
||||
"SELECT duration_secs FROM usage WHERE entry_id = ? AND day = ?",
|
||||
params![entry_id.as_str(), day_str],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
Ok(Duration::from_secs(secs.unwrap_or(0) as u64))
|
||||
}
|
||||
|
||||
fn add_usage(&self, entry_id: &EntryId, day: NaiveDate, duration: Duration) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let day_str = day.format("%Y-%m-%d").to_string();
|
||||
let secs = duration.as_secs() as i64;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
INSERT INTO usage (entry_id, day, duration_secs)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(entry_id, day)
|
||||
DO UPDATE SET duration_secs = duration_secs + excluded.duration_secs
|
||||
"#,
|
||||
params![entry_id.as_str(), day_str, secs],
|
||||
)?;
|
||||
|
||||
debug!(entry_id = %entry_id, day = %day_str, added_secs = secs, "Usage added");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_cooldown_until(&self, entry_id: &EntryId) -> StoreResult<Option<DateTime<Local>>> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
let until_str: Option<String> = conn
|
||||
.query_row(
|
||||
"SELECT until FROM cooldowns WHERE entry_id = ?",
|
||||
[entry_id.as_str()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
let result = until_str.and_then(|s| {
|
||||
DateTime::parse_from_rfc3339(&s)
|
||||
.map(|dt| dt.with_timezone(&Local))
|
||||
.ok()
|
||||
});
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn set_cooldown_until(
|
||||
&self,
|
||||
entry_id: &EntryId,
|
||||
until: DateTime<Local>,
|
||||
) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
INSERT INTO cooldowns (entry_id, until)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(entry_id)
|
||||
DO UPDATE SET until = excluded.until
|
||||
"#,
|
||||
params![entry_id.as_str(), until.to_rfc3339()],
|
||||
)?;
|
||||
|
||||
debug!(entry_id = %entry_id, until = %until, "Cooldown set");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_cooldown(&self, entry_id: &EntryId) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute("DELETE FROM cooldowns WHERE entry_id = ?", [entry_id.as_str()])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_snapshot(&self) -> StoreResult<Option<StateSnapshot>> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
let json: Option<String> = conn
|
||||
.query_row("SELECT snapshot_json FROM snapshot WHERE id = 1", [], |row| {
|
||||
row.get(0)
|
||||
})
|
||||
.optional()?;
|
||||
|
||||
match json {
|
||||
Some(s) => {
|
||||
let snapshot: StateSnapshot = serde_json::from_str(&s)?;
|
||||
Ok(Some(snapshot))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn save_snapshot(&self, snapshot: &StateSnapshot) -> StoreResult<()> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let json = serde_json::to_string(snapshot)?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
INSERT INTO snapshot (id, snapshot_json)
|
||||
VALUES (1, ?)
|
||||
ON CONFLICT(id)
|
||||
DO UPDATE SET snapshot_json = excluded.snapshot_json
|
||||
"#,
|
||||
[json],
|
||||
)?;
|
||||
|
||||
debug!("Snapshot saved");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_healthy(&self) -> bool {
|
||||
match self.conn.lock() {
|
||||
Ok(conn) => {
|
||||
conn.query_row("SELECT 1", [], |_| Ok(())).is_ok()
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Store lock poisoned");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::AuditEventType;
|
||||
|
||||
#[test]
|
||||
fn test_in_memory_store() {
|
||||
let store = SqliteStore::in_memory().unwrap();
|
||||
assert!(store.is_healthy());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_audit_log() {
|
||||
let store = SqliteStore::in_memory().unwrap();
|
||||
|
||||
let event = AuditEvent::new(AuditEventType::DaemonStarted);
|
||||
store.append_audit(event).unwrap();
|
||||
|
||||
let events = store.get_recent_audits(10).unwrap();
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(events[0].event, AuditEventType::DaemonStarted));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_accounting() {
|
||||
let store = SqliteStore::in_memory().unwrap();
|
||||
let entry_id = EntryId::new("game-1");
|
||||
let today = Local::now().date_naive();
|
||||
|
||||
// Initially zero
|
||||
let usage = store.get_usage(&entry_id, today).unwrap();
|
||||
assert_eq!(usage, Duration::ZERO);
|
||||
|
||||
// Add some usage
|
||||
store
|
||||
.add_usage(&entry_id, today, Duration::from_secs(300))
|
||||
.unwrap();
|
||||
let usage = store.get_usage(&entry_id, today).unwrap();
|
||||
assert_eq!(usage, Duration::from_secs(300));
|
||||
|
||||
// Add more usage
|
||||
store
|
||||
.add_usage(&entry_id, today, Duration::from_secs(200))
|
||||
.unwrap();
|
||||
let usage = store.get_usage(&entry_id, today).unwrap();
|
||||
assert_eq!(usage, Duration::from_secs(500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cooldowns() {
|
||||
let store = SqliteStore::in_memory().unwrap();
|
||||
let entry_id = EntryId::new("game-1");
|
||||
|
||||
// No cooldown initially
|
||||
assert!(store.get_cooldown_until(&entry_id).unwrap().is_none());
|
||||
|
||||
// Set cooldown
|
||||
let until = Local::now() + chrono::Duration::hours(1);
|
||||
store.set_cooldown_until(&entry_id, until).unwrap();
|
||||
|
||||
let stored = store.get_cooldown_until(&entry_id).unwrap().unwrap();
|
||||
assert!((stored - until).num_seconds().abs() < 1);
|
||||
|
||||
// Clear cooldown
|
||||
store.clear_cooldown(&entry_id).unwrap();
|
||||
assert!(store.get_cooldown_until(&entry_id).unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_snapshot() {
|
||||
let store = SqliteStore::in_memory().unwrap();
|
||||
|
||||
// No snapshot initially
|
||||
assert!(store.load_snapshot().unwrap().is_none());
|
||||
|
||||
// Save snapshot
|
||||
let snapshot = StateSnapshot {
|
||||
timestamp: Local::now(),
|
||||
active_session: None,
|
||||
};
|
||||
store.save_snapshot(&snapshot).unwrap();
|
||||
|
||||
// Load it back
|
||||
let loaded = store.load_snapshot().unwrap().unwrap();
|
||||
assert!(loaded.active_session.is_none());
|
||||
}
|
||||
}
|
||||
74
crates/shepherd-store/src/traits.rs
Normal file
74
crates/shepherd-store/src/traits.rs
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
//! Store trait definitions
|
||||
|
||||
use chrono::{DateTime, Local, NaiveDate};
|
||||
use shepherd_util::{EntryId, SessionId};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::{AuditEvent, StoreResult};
|
||||
|
||||
/// Main store trait
|
||||
pub trait Store: Send + Sync {
|
||||
// Audit log
|
||||
|
||||
/// Append an audit event
|
||||
fn append_audit(&self, event: AuditEvent) -> StoreResult<()>;
|
||||
|
||||
/// Get recent audit events
|
||||
fn get_recent_audits(&self, limit: usize) -> StoreResult<Vec<AuditEvent>>;
|
||||
|
||||
// Usage accounting
|
||||
|
||||
/// Get total usage for an entry on a specific day
|
||||
fn get_usage(&self, entry_id: &EntryId, day: NaiveDate) -> StoreResult<Duration>;
|
||||
|
||||
/// Add usage for an entry on a specific day
|
||||
fn add_usage(&self, entry_id: &EntryId, day: NaiveDate, duration: Duration) -> StoreResult<()>;
|
||||
|
||||
// Cooldown tracking
|
||||
|
||||
/// Get cooldown expiry time for an entry
|
||||
fn get_cooldown_until(&self, entry_id: &EntryId) -> StoreResult<Option<DateTime<Local>>>;
|
||||
|
||||
/// Set cooldown expiry time for an entry
|
||||
fn set_cooldown_until(
|
||||
&self,
|
||||
entry_id: &EntryId,
|
||||
until: DateTime<Local>,
|
||||
) -> StoreResult<()>;
|
||||
|
||||
/// Clear cooldown for an entry
|
||||
fn clear_cooldown(&self, entry_id: &EntryId) -> StoreResult<()>;
|
||||
|
||||
// State snapshot
|
||||
|
||||
/// Load last saved snapshot
|
||||
fn load_snapshot(&self) -> StoreResult<Option<StateSnapshot>>;
|
||||
|
||||
/// Save state snapshot
|
||||
fn save_snapshot(&self, snapshot: &StateSnapshot) -> StoreResult<()>;
|
||||
|
||||
// Health
|
||||
|
||||
/// Check if store is healthy
|
||||
fn is_healthy(&self) -> bool;
|
||||
}
|
||||
|
||||
/// State snapshot for crash recovery
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct StateSnapshot {
|
||||
/// Timestamp of snapshot
|
||||
pub timestamp: DateTime<Local>,
|
||||
|
||||
/// Active session info (if any)
|
||||
pub active_session: Option<SessionSnapshot>,
|
||||
}
|
||||
|
||||
/// Snapshot of an active session
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct SessionSnapshot {
|
||||
pub session_id: SessionId,
|
||||
pub entry_id: EntryId,
|
||||
pub started_at: DateTime<Local>,
|
||||
pub deadline: DateTime<Local>,
|
||||
pub warnings_issued: Vec<u64>,
|
||||
}
|
||||
17
crates/shepherd-util/Cargo.toml
Normal file
17
crates/shepherd-util/Cargo.toml
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "shepherd-util"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Shared utilities for shepherdd: time, IDs, error types, rate limiting"
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
74
crates/shepherd-util/src/error.rs
Normal file
74
crates/shepherd-util/src/error.rs
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
//! Error types for shepherdd
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::EntryId;
|
||||
|
||||
/// Core error type for shepherdd operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ShepherdError {
|
||||
#[error("Entry not found: {0}")]
|
||||
EntryNotFound(EntryId),
|
||||
|
||||
#[error("No active session")]
|
||||
NoActiveSession,
|
||||
|
||||
#[error("Session already active")]
|
||||
SessionAlreadyActive,
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
|
||||
#[error("Validation error: {0}")]
|
||||
ValidationError(String),
|
||||
|
||||
#[error("Store error: {0}")]
|
||||
StoreError(String),
|
||||
|
||||
#[error("Host error: {0}")]
|
||||
HostError(String),
|
||||
|
||||
#[error("IPC error: {0}")]
|
||||
IpcError(String),
|
||||
|
||||
#[error("Permission denied: {0}")]
|
||||
PermissionDenied(String),
|
||||
|
||||
#[error("Rate limited")]
|
||||
RateLimited,
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl ShepherdError {
|
||||
pub fn config(msg: impl Into<String>) -> Self {
|
||||
Self::ConfigError(msg.into())
|
||||
}
|
||||
|
||||
pub fn validation(msg: impl Into<String>) -> Self {
|
||||
Self::ValidationError(msg.into())
|
||||
}
|
||||
|
||||
pub fn store(msg: impl Into<String>) -> Self {
|
||||
Self::StoreError(msg.into())
|
||||
}
|
||||
|
||||
pub fn host(msg: impl Into<String>) -> Self {
|
||||
Self::HostError(msg.into())
|
||||
}
|
||||
|
||||
pub fn ipc(msg: impl Into<String>) -> Self {
|
||||
Self::IpcError(msg.into())
|
||||
}
|
||||
|
||||
pub fn permission(msg: impl Into<String>) -> Self {
|
||||
Self::PermissionDenied(msg.into())
|
||||
}
|
||||
|
||||
pub fn internal(msg: impl Into<String>) -> Self {
|
||||
Self::Internal(msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ShepherdError>;
|
||||
128
crates/shepherd-util/src/ids.rs
Normal file
128
crates/shepherd-util/src/ids.rs
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
//! Strongly-typed identifiers for shepherdd
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for an entry in the policy whitelist
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct EntryId(String);
|
||||
|
||||
impl EntryId {
|
||||
pub fn new(id: impl Into<String>) -> Self {
|
||||
Self(id.into())
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for EntryId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for EntryId {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for EntryId {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier for a running session
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct SessionId(Uuid);
|
||||
|
||||
impl SessionId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
pub fn from_uuid(uuid: Uuid) -> Self {
|
||||
Self(uuid)
|
||||
}
|
||||
|
||||
pub fn as_uuid(&self) -> &Uuid {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SessionId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier for a connected IPC client
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ClientId(Uuid);
|
||||
|
||||
impl ClientId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
|
||||
pub fn from_uuid(uuid: Uuid) -> Self {
|
||||
Self(uuid)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ClientId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn entry_id_equality() {
|
||||
let id1 = EntryId::new("game-1");
|
||||
let id2 = EntryId::new("game-1");
|
||||
let id3 = EntryId::new("game-2");
|
||||
|
||||
assert_eq!(id1, id2);
|
||||
assert_ne!(id1, id3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_id_uniqueness() {
|
||||
let s1 = SessionId::new();
|
||||
let s2 = SessionId::new();
|
||||
assert_ne!(s1, s2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ids_serialize_deserialize() {
|
||||
let entry_id = EntryId::new("test-entry");
|
||||
let json = serde_json::to_string(&entry_id).unwrap();
|
||||
let parsed: EntryId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(entry_id, parsed);
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let json = serde_json::to_string(&session_id).unwrap();
|
||||
let parsed: SessionId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(session_id, parsed);
|
||||
}
|
||||
}
|
||||
17
crates/shepherd-util/src/lib.rs
Normal file
17
crates/shepherd-util/src/lib.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
//! Shared utilities for shepherdd
|
||||
//!
|
||||
//! This crate provides:
|
||||
//! - ID types (EntryId, SessionId, ClientId)
|
||||
//! - Time utilities (monotonic time, duration helpers)
|
||||
//! - Error types
|
||||
//! - Rate limiting helpers
|
||||
|
||||
mod error;
|
||||
mod ids;
|
||||
mod rate_limit;
|
||||
mod time;
|
||||
|
||||
pub use error::*;
|
||||
pub use ids::*;
|
||||
pub use rate_limit::*;
|
||||
pub use time::*;
|
||||
112
crates/shepherd-util/src/rate_limit.rs
Normal file
112
crates/shepherd-util/src/rate_limit.rs
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
//! Rate limiting utilities
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::ClientId;
|
||||
|
||||
/// Simple token-bucket rate limiter
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimiter {
|
||||
/// Maximum tokens (requests) per bucket
|
||||
max_tokens: u32,
|
||||
/// How often tokens are replenished
|
||||
refill_interval: Duration,
|
||||
/// Per-client state
|
||||
clients: HashMap<ClientId, ClientBucket>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ClientBucket {
|
||||
tokens: u32,
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
/// Create a new rate limiter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_requests` - Maximum requests allowed per interval
|
||||
/// * `interval` - Time interval for the limit
|
||||
pub fn new(max_requests: u32, interval: Duration) -> Self {
|
||||
Self {
|
||||
max_tokens: max_requests,
|
||||
refill_interval: interval,
|
||||
clients: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request should be allowed for the given client
|
||||
///
|
||||
/// Returns `true` if allowed, `false` if rate limited
|
||||
pub fn check(&mut self, client_id: &ClientId) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let bucket = self.clients.entry(client_id.clone()).or_insert(ClientBucket {
|
||||
tokens: self.max_tokens,
|
||||
last_refill: now,
|
||||
});
|
||||
|
||||
// Refill tokens if interval has passed
|
||||
let elapsed = now.duration_since(bucket.last_refill);
|
||||
if elapsed >= self.refill_interval {
|
||||
let intervals = (elapsed.as_millis() / self.refill_interval.as_millis()) as u32;
|
||||
bucket.tokens = (bucket.tokens + intervals * self.max_tokens).min(self.max_tokens);
|
||||
bucket.last_refill = now;
|
||||
}
|
||||
|
||||
// Try to consume a token
|
||||
if bucket.tokens > 0 {
|
||||
bucket.tokens -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a client's rate limit state
|
||||
pub fn remove_client(&mut self, client_id: &ClientId) {
|
||||
self.clients.remove(client_id);
|
||||
}
|
||||
|
||||
/// Clean up stale client entries
|
||||
pub fn cleanup(&mut self, stale_after: Duration) {
|
||||
let now = Instant::now();
|
||||
self.clients.retain(|_, bucket| {
|
||||
now.duration_since(bucket.last_refill) < stale_after
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_allows_within_limit() {
|
||||
let mut limiter = RateLimiter::new(5, Duration::from_secs(1));
|
||||
let client = ClientId::new();
|
||||
|
||||
for _ in 0..5 {
|
||||
assert!(limiter.check(&client));
|
||||
}
|
||||
|
||||
// 6th request should be denied
|
||||
assert!(!limiter.check(&client));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_different_clients() {
|
||||
let mut limiter = RateLimiter::new(2, Duration::from_secs(1));
|
||||
let client1 = ClientId::new();
|
||||
let client2 = ClientId::new();
|
||||
|
||||
assert!(limiter.check(&client1));
|
||||
assert!(limiter.check(&client1));
|
||||
assert!(!limiter.check(&client1));
|
||||
|
||||
// Client 2 should have its own bucket
|
||||
assert!(limiter.check(&client2));
|
||||
assert!(limiter.check(&client2));
|
||||
}
|
||||
}
|
||||
301
crates/shepherd-util/src/time.rs
Normal file
301
crates/shepherd-util/src/time.rs
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
//! Time utilities for shepherdd
|
||||
//!
|
||||
//! Provides both monotonic time (for countdown enforcement) and
|
||||
//! wall-clock time (for availability windows).
|
||||
|
||||
use chrono::{DateTime, Datelike, Local, NaiveTime, Timelike, Weekday};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Represents a point in monotonic time for countdown enforcement.
|
||||
/// This is immune to wall-clock changes.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct MonotonicInstant(Instant);
|
||||
|
||||
impl MonotonicInstant {
|
||||
pub fn now() -> Self {
|
||||
Self(Instant::now())
|
||||
}
|
||||
|
||||
pub fn elapsed(&self) -> Duration {
|
||||
self.0.elapsed()
|
||||
}
|
||||
|
||||
pub fn duration_since(&self, earlier: MonotonicInstant) -> Duration {
|
||||
self.0.duration_since(earlier.0)
|
||||
}
|
||||
|
||||
pub fn checked_add(&self, duration: Duration) -> Option<MonotonicInstant> {
|
||||
self.0.checked_add(duration).map(MonotonicInstant)
|
||||
}
|
||||
|
||||
/// Returns duration until `self`, or zero if `self` is in the past
|
||||
pub fn saturating_duration_until(&self, from: MonotonicInstant) -> Duration {
|
||||
if self.0 > from.0 {
|
||||
self.0.duration_since(from.0)
|
||||
} else {
|
||||
Duration::ZERO
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<Duration> for MonotonicInstant {
|
||||
type Output = MonotonicInstant;
|
||||
|
||||
fn add(self, rhs: Duration) -> Self::Output {
|
||||
MonotonicInstant(self.0 + rhs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wall-clock time for availability windows
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct WallClock {
|
||||
pub hour: u8,
|
||||
pub minute: u8,
|
||||
}
|
||||
|
||||
impl WallClock {
|
||||
pub fn new(hour: u8, minute: u8) -> Option<Self> {
|
||||
if hour < 24 && minute < 60 {
|
||||
Some(Self { hour, minute })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_naive_time(self) -> NaiveTime {
|
||||
NaiveTime::from_hms_opt(self.hour as u32, self.minute as u32, 0).unwrap()
|
||||
}
|
||||
|
||||
pub fn from_naive_time(time: NaiveTime) -> Self {
|
||||
Self {
|
||||
hour: time.hour() as u8,
|
||||
minute: time.minute() as u8,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns seconds since midnight
|
||||
pub fn as_seconds_from_midnight(&self) -> u32 {
|
||||
(self.hour as u32) * 3600 + (self.minute as u32) * 60
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for WallClock {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for WallClock {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.as_seconds_from_midnight()
|
||||
.cmp(&other.as_seconds_from_midnight())
|
||||
}
|
||||
}
|
||||
|
||||
/// Days of the week mask
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
pub struct DaysOfWeek(u8);
|
||||
|
||||
impl DaysOfWeek {
|
||||
pub const MONDAY: u8 = 1 << 0;
|
||||
pub const TUESDAY: u8 = 1 << 1;
|
||||
pub const WEDNESDAY: u8 = 1 << 2;
|
||||
pub const THURSDAY: u8 = 1 << 3;
|
||||
pub const FRIDAY: u8 = 1 << 4;
|
||||
pub const SATURDAY: u8 = 1 << 5;
|
||||
pub const SUNDAY: u8 = 1 << 6;
|
||||
|
||||
pub const WEEKDAYS: DaysOfWeek = DaysOfWeek(
|
||||
Self::MONDAY | Self::TUESDAY | Self::WEDNESDAY | Self::THURSDAY | Self::FRIDAY,
|
||||
);
|
||||
pub const WEEKENDS: DaysOfWeek = DaysOfWeek(Self::SATURDAY | Self::SUNDAY);
|
||||
pub const ALL_DAYS: DaysOfWeek = DaysOfWeek(0x7F);
|
||||
pub const NONE: DaysOfWeek = DaysOfWeek(0);
|
||||
|
||||
pub fn new(mask: u8) -> Self {
|
||||
Self(mask & 0x7F)
|
||||
}
|
||||
|
||||
pub fn contains(&self, weekday: Weekday) -> bool {
|
||||
let bit = match weekday {
|
||||
Weekday::Mon => Self::MONDAY,
|
||||
Weekday::Tue => Self::TUESDAY,
|
||||
Weekday::Wed => Self::WEDNESDAY,
|
||||
Weekday::Thu => Self::THURSDAY,
|
||||
Weekday::Fri => Self::FRIDAY,
|
||||
Weekday::Sat => Self::SATURDAY,
|
||||
Weekday::Sun => Self::SUNDAY,
|
||||
};
|
||||
(self.0 & bit) != 0
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::BitOr for DaysOfWeek {
|
||||
type Output = Self;
|
||||
|
||||
fn bitor(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 | rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A time window during which an entry is available
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TimeWindow {
|
||||
pub days: DaysOfWeek,
|
||||
pub start: WallClock,
|
||||
pub end: WallClock,
|
||||
}
|
||||
|
||||
impl TimeWindow {
|
||||
pub fn new(days: DaysOfWeek, start: WallClock, end: WallClock) -> Self {
|
||||
Self { days, start, end }
|
||||
}
|
||||
|
||||
/// Check if the given local datetime falls within this window
|
||||
pub fn contains(&self, dt: &DateTime<Local>) -> bool {
|
||||
let weekday = dt.weekday();
|
||||
if !self.days.contains(weekday) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let time = WallClock::from_naive_time(dt.time());
|
||||
|
||||
// Handle windows that don't cross midnight
|
||||
if self.start <= self.end {
|
||||
time >= self.start && time < self.end
|
||||
} else {
|
||||
// Window crosses midnight (e.g., 22:00 - 02:00)
|
||||
time >= self.start || time < self.end
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate duration remaining in this window from the given time
|
||||
pub fn remaining_duration(&self, dt: &DateTime<Local>) -> Option<Duration> {
|
||||
if !self.contains(dt) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let now_time = WallClock::from_naive_time(dt.time());
|
||||
let now_secs = now_time.as_seconds_from_midnight();
|
||||
let end_secs = self.end.as_seconds_from_midnight();
|
||||
|
||||
let remaining_secs = if self.start <= self.end {
|
||||
// Normal window
|
||||
end_secs.saturating_sub(now_secs)
|
||||
} else {
|
||||
// Cross-midnight window
|
||||
if now_secs >= self.start.as_seconds_from_midnight() {
|
||||
// We're in the evening portion, count until midnight then add morning
|
||||
(86400 - now_secs) + end_secs
|
||||
} else {
|
||||
// We're in the morning portion
|
||||
end_secs.saturating_sub(now_secs)
|
||||
}
|
||||
};
|
||||
|
||||
Some(Duration::from_secs(remaining_secs as u64))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to format durations in human-readable form
|
||||
pub fn format_duration(d: Duration) -> String {
|
||||
let total_secs = d.as_secs();
|
||||
let hours = total_secs / 3600;
|
||||
let minutes = (total_secs % 3600) / 60;
|
||||
let seconds = total_secs % 60;
|
||||
|
||||
if hours > 0 {
|
||||
format!("{}h {}m {}s", hours, minutes, seconds)
|
||||
} else if minutes > 0 {
|
||||
format!("{}m {}s", minutes, seconds)
|
||||
} else {
|
||||
format!("{}s", seconds)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::TimeZone;
|
||||
|
||||
#[test]
|
||||
fn test_wall_clock_ordering() {
|
||||
let morning = WallClock::new(8, 0).unwrap();
|
||||
let noon = WallClock::new(12, 0).unwrap();
|
||||
let evening = WallClock::new(18, 30).unwrap();
|
||||
|
||||
assert!(morning < noon);
|
||||
assert!(noon < evening);
|
||||
assert!(morning < evening);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_days_of_week() {
|
||||
let weekdays = DaysOfWeek::WEEKDAYS;
|
||||
assert!(weekdays.contains(Weekday::Mon));
|
||||
assert!(weekdays.contains(Weekday::Fri));
|
||||
assert!(!weekdays.contains(Weekday::Sat));
|
||||
assert!(!weekdays.contains(Weekday::Sun));
|
||||
|
||||
let weekends = DaysOfWeek::WEEKENDS;
|
||||
assert!(!weekends.contains(Weekday::Mon));
|
||||
assert!(weekends.contains(Weekday::Sat));
|
||||
assert!(weekends.contains(Weekday::Sun));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_window_contains() {
|
||||
let window = TimeWindow::new(
|
||||
DaysOfWeek::WEEKDAYS,
|
||||
WallClock::new(14, 0).unwrap(), // 2 PM
|
||||
WallClock::new(18, 0).unwrap(), // 6 PM
|
||||
);
|
||||
|
||||
// Monday at 3 PM - should be in window
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 29, 15, 0, 0).unwrap(); // Monday
|
||||
assert!(window.contains(&dt));
|
||||
|
||||
// Monday at 10 AM - outside window
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 29, 10, 0, 0).unwrap();
|
||||
assert!(!window.contains(&dt));
|
||||
|
||||
// Saturday at 3 PM - wrong day
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 27, 15, 0, 0).unwrap();
|
||||
assert!(!window.contains(&dt));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_time_window_remaining() {
|
||||
let window = TimeWindow::new(
|
||||
DaysOfWeek::ALL_DAYS,
|
||||
WallClock::new(14, 0).unwrap(),
|
||||
WallClock::new(18, 0).unwrap(),
|
||||
);
|
||||
|
||||
let dt = Local.with_ymd_and_hms(2025, 12, 26, 15, 0, 0).unwrap(); // 3 PM
|
||||
let remaining = window.remaining_duration(&dt).unwrap();
|
||||
assert_eq!(remaining, Duration::from_secs(3 * 3600)); // 3 hours
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_duration() {
|
||||
assert_eq!(format_duration(Duration::from_secs(30)), "30s");
|
||||
assert_eq!(format_duration(Duration::from_secs(90)), "1m 30s");
|
||||
assert_eq!(format_duration(Duration::from_secs(3661)), "1h 1m 1s");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_monotonic_instant() {
|
||||
let t1 = MonotonicInstant::now();
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
let t2 = MonotonicInstant::now();
|
||||
|
||||
assert!(t2 > t1);
|
||||
assert!(t2.duration_since(t1) >= Duration::from_millis(10));
|
||||
}
|
||||
}
|
||||
32
crates/shepherdd/Cargo.toml
Normal file
32
crates/shepherdd/Cargo.toml
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
[package]
|
||||
name = "shepherdd"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "The shepherdd daemon: policy enforcement for child-focused computing"
|
||||
|
||||
[[bin]]
|
||||
name = "shepherdd"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
shepherd-util = { workspace = true }
|
||||
shepherd-api = { workspace = true }
|
||||
shepherd-host-api = { workspace = true }
|
||||
shepherd-config = { workspace = true }
|
||||
shepherd-store = { workspace = true }
|
||||
shepherd-core = { workspace = true }
|
||||
shepherd-host-linux = { workspace = true }
|
||||
shepherd-ipc = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
659
crates/shepherdd/src/main.rs
Normal file
659
crates/shepherdd/src/main.rs
Normal file
|
|
@ -0,0 +1,659 @@
|
|||
//! shepherdd - The shepherd daemon
|
||||
//!
|
||||
//! This is the main entry point for the shepherdd service.
|
||||
//! It wires together all the components:
|
||||
//! - Configuration loading
|
||||
//! - Store initialization
|
||||
//! - Core engine
|
||||
//! - Host adapter (Linux)
|
||||
//! - IPC server
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Local;
|
||||
use clap::Parser;
|
||||
use shepherd_api::{
|
||||
Command, DaemonStateSnapshot, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
|
||||
Response, ResponsePayload, SessionEndReason, StopMode, API_VERSION,
|
||||
};
|
||||
use shepherd_config::{load_config, Policy};
|
||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
|
||||
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode};
|
||||
use shepherd_host_linux::LinuxHost;
|
||||
use shepherd_ipc::{IpcServer, ServerMessage};
|
||||
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
|
||||
use shepherd_util::{ClientId, EntryId, MonotonicInstant, RateLimiter};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, info, warn, Level};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
/// shepherdd - Policy enforcement daemon for child-focused computing
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "shepherdd")]
|
||||
#[command(about = "Policy enforcement daemon for child-focused computing", long_about = None)]
|
||||
struct Args {
|
||||
/// Configuration file path
|
||||
#[arg(short, long, default_value = "/etc/shepherdd/config.toml")]
|
||||
config: PathBuf,
|
||||
|
||||
/// Socket path override
|
||||
#[arg(short, long)]
|
||||
socket: Option<PathBuf>,
|
||||
|
||||
/// Data directory override
|
||||
#[arg(short, long)]
|
||||
data_dir: Option<PathBuf>,
|
||||
|
||||
/// Log level
|
||||
#[arg(short, long, default_value = "info")]
|
||||
log_level: String,
|
||||
}
|
||||
|
||||
/// Main daemon state
|
||||
struct Daemon {
|
||||
engine: CoreEngine,
|
||||
host: Arc<LinuxHost>,
|
||||
ipc: Arc<IpcServer>,
|
||||
store: Arc<dyn Store>,
|
||||
rate_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl Daemon {
|
||||
async fn new(args: &Args) -> Result<Self> {
|
||||
// Load configuration
|
||||
let policy = load_config(&args.config)
|
||||
.with_context(|| format!("Failed to load config from {:?}", args.config))?;
|
||||
|
||||
info!(
|
||||
config_path = %args.config.display(),
|
||||
entry_count = policy.entries.len(),
|
||||
"Configuration loaded"
|
||||
);
|
||||
|
||||
// Determine paths
|
||||
let socket_path = args
|
||||
.socket
|
||||
.clone()
|
||||
.unwrap_or_else(|| policy.daemon.socket_path.clone());
|
||||
|
||||
let data_dir = args
|
||||
.data_dir
|
||||
.clone()
|
||||
.unwrap_or_else(|| policy.daemon.data_dir.clone());
|
||||
|
||||
// Create data directory
|
||||
std::fs::create_dir_all(&data_dir)
|
||||
.with_context(|| format!("Failed to create data directory {:?}", data_dir))?;
|
||||
|
||||
// Initialize store
|
||||
let db_path = data_dir.join("shepherdd.db");
|
||||
let store: Arc<dyn Store> = Arc::new(
|
||||
SqliteStore::open(&db_path)
|
||||
.with_context(|| format!("Failed to open database {:?}", db_path))?,
|
||||
);
|
||||
|
||||
info!(db_path = %db_path.display(), "Store initialized");
|
||||
|
||||
// Log daemon start
|
||||
store.append_audit(AuditEvent::new(AuditEventType::DaemonStarted))?;
|
||||
|
||||
// Initialize host adapter
|
||||
let host = Arc::new(LinuxHost::new());
|
||||
|
||||
// Initialize core engine
|
||||
let engine = CoreEngine::new(policy, store.clone(), host.capabilities().clone());
|
||||
|
||||
// Initialize IPC server
|
||||
let mut ipc = IpcServer::new(&socket_path);
|
||||
ipc.start().await?;
|
||||
|
||||
info!(socket_path = %socket_path.display(), "IPC server started");
|
||||
|
||||
// Rate limiter: 30 requests per second per client
|
||||
let rate_limiter = RateLimiter::new(30, Duration::from_secs(1));
|
||||
|
||||
Ok(Self {
|
||||
engine,
|
||||
host,
|
||||
ipc: Arc::new(ipc),
|
||||
store,
|
||||
rate_limiter,
|
||||
})
|
||||
}
|
||||
|
||||
async fn run(self) -> Result<()> {
|
||||
// Start host process monitor
|
||||
let _monitor_handle = self.host.start_monitor();
|
||||
|
||||
// Get channels
|
||||
let mut host_events = self.host.subscribe();
|
||||
let ipc_ref = self.ipc.clone();
|
||||
let mut ipc_messages = ipc_ref
|
||||
.take_message_receiver()
|
||||
.await
|
||||
.expect("Message receiver should be available");
|
||||
|
||||
// Wrap mutable state
|
||||
let engine = Arc::new(Mutex::new(self.engine));
|
||||
let rate_limiter = Arc::new(Mutex::new(self.rate_limiter));
|
||||
let host = self.host.clone();
|
||||
let store = self.store.clone();
|
||||
|
||||
// Spawn IPC accept task
|
||||
let ipc_accept = ipc_ref.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = ipc_accept.run().await {
|
||||
error!(error = %e, "IPC server error");
|
||||
}
|
||||
});
|
||||
|
||||
// Main event loop
|
||||
let tick_interval = Duration::from_millis(100);
|
||||
let mut tick_timer = tokio::time::interval(tick_interval);
|
||||
|
||||
info!("Daemon running");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Tick timer - check warnings and expiry
|
||||
_ = tick_timer.tick() => {
|
||||
let now_mono = MonotonicInstant::now();
|
||||
let now = Local::now();
|
||||
|
||||
let events = {
|
||||
let mut engine = engine.lock().await;
|
||||
engine.tick(now_mono)
|
||||
};
|
||||
|
||||
for event in events {
|
||||
Self::handle_core_event(&engine, &host, &ipc_ref, event, now_mono, now).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Host events (process exit)
|
||||
Some(host_event) = host_events.recv() => {
|
||||
Self::handle_host_event(&engine, &ipc_ref, host_event).await;
|
||||
}
|
||||
|
||||
// IPC messages
|
||||
Some(msg) = ipc_messages.recv() => {
|
||||
Self::handle_ipc_message(&engine, &host, &ipc_ref, &store, &rate_limiter, msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_core_event(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
host: &Arc<LinuxHost>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
event: CoreEvent,
|
||||
_now_mono: MonotonicInstant,
|
||||
_now: chrono::DateTime<Local>,
|
||||
) {
|
||||
match &event {
|
||||
CoreEvent::Warning {
|
||||
session_id,
|
||||
threshold_seconds,
|
||||
time_remaining,
|
||||
severity,
|
||||
message,
|
||||
} => {
|
||||
info!(
|
||||
session_id = %session_id,
|
||||
threshold = threshold_seconds,
|
||||
remaining = ?time_remaining,
|
||||
"Warning issued"
|
||||
);
|
||||
|
||||
ipc.broadcast_event(Event::new(EventPayload::WarningIssued {
|
||||
session_id: session_id.clone(),
|
||||
threshold_seconds: *threshold_seconds,
|
||||
time_remaining: *time_remaining,
|
||||
severity: *severity,
|
||||
message: message.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
CoreEvent::ExpireDue { session_id } => {
|
||||
info!(session_id = %session_id, "Session expired, stopping");
|
||||
|
||||
// Get the host handle and stop it
|
||||
let handle = {
|
||||
let engine = engine.lock().await;
|
||||
engine
|
||||
.current_session()
|
||||
.and_then(|s| s.host_handle.clone())
|
||||
};
|
||||
|
||||
if let Some(handle) = handle {
|
||||
if let Err(e) = host
|
||||
.stop(
|
||||
&handle,
|
||||
HostStopMode::Graceful {
|
||||
timeout: Duration::from_secs(5),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %e, "Failed to stop session gracefully, forcing");
|
||||
let _ = host.stop(&handle, HostStopMode::Force).await;
|
||||
}
|
||||
}
|
||||
|
||||
ipc.broadcast_event(Event::new(EventPayload::SessionExpiring {
|
||||
session_id: session_id.clone(),
|
||||
}));
|
||||
}
|
||||
|
||||
CoreEvent::SessionStarted {
|
||||
session_id,
|
||||
entry_id,
|
||||
label,
|
||||
deadline,
|
||||
} => {
|
||||
ipc.broadcast_event(Event::new(EventPayload::SessionStarted {
|
||||
session_id: session_id.clone(),
|
||||
entry_id: entry_id.clone(),
|
||||
label: label.clone(),
|
||||
deadline: *deadline,
|
||||
}));
|
||||
}
|
||||
|
||||
CoreEvent::SessionEnded {
|
||||
session_id,
|
||||
entry_id,
|
||||
reason,
|
||||
duration,
|
||||
} => {
|
||||
ipc.broadcast_event(Event::new(EventPayload::SessionEnded {
|
||||
session_id: session_id.clone(),
|
||||
entry_id: entry_id.clone(),
|
||||
reason: reason.clone(),
|
||||
duration: *duration,
|
||||
}));
|
||||
|
||||
// Broadcast state change
|
||||
let state = {
|
||||
let engine = engine.lock().await;
|
||||
engine.get_state()
|
||||
};
|
||||
ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
|
||||
}
|
||||
|
||||
CoreEvent::PolicyReloaded { entry_count } => {
|
||||
ipc.broadcast_event(Event::new(EventPayload::PolicyReloaded {
|
||||
entry_count: *entry_count,
|
||||
}));
|
||||
}
|
||||
|
||||
CoreEvent::EntryAvailabilityChanged { entry_id, enabled } => {
|
||||
ipc.broadcast_event(Event::new(EventPayload::EntryAvailabilityChanged {
|
||||
entry_id: entry_id.clone(),
|
||||
enabled: *enabled,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_host_event(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
event: HostEvent,
|
||||
) {
|
||||
match event {
|
||||
HostEvent::Exited { handle, status } => {
|
||||
let now_mono = MonotonicInstant::now();
|
||||
let now = Local::now();
|
||||
|
||||
debug!(
|
||||
session_id = %handle.session_id,
|
||||
status = ?status,
|
||||
"Host process exited"
|
||||
);
|
||||
|
||||
let core_event = {
|
||||
let mut engine = engine.lock().await;
|
||||
engine.notify_session_exited(status.code, now_mono, now)
|
||||
};
|
||||
|
||||
if let Some(event) = core_event {
|
||||
if let CoreEvent::SessionEnded {
|
||||
session_id,
|
||||
entry_id,
|
||||
reason,
|
||||
duration,
|
||||
} = event
|
||||
{
|
||||
ipc.broadcast_event(Event::new(EventPayload::SessionEnded {
|
||||
session_id,
|
||||
entry_id,
|
||||
reason,
|
||||
duration,
|
||||
}));
|
||||
|
||||
// Broadcast state change
|
||||
let state = {
|
||||
let engine = engine.lock().await;
|
||||
engine.get_state()
|
||||
};
|
||||
ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HostEvent::WindowReady { handle } => {
|
||||
debug!(session_id = %handle.session_id, "Window ready");
|
||||
}
|
||||
|
||||
HostEvent::SpawnFailed { session_id, error } => {
|
||||
error!(session_id = %session_id, error = %error, "Spawn failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ipc_message(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
host: &Arc<LinuxHost>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
store: &Arc<dyn Store>,
|
||||
rate_limiter: &Arc<Mutex<RateLimiter>>,
|
||||
msg: ServerMessage,
|
||||
) {
|
||||
match msg {
|
||||
ServerMessage::Request { client_id, request } => {
|
||||
// Rate limiting
|
||||
{
|
||||
let mut limiter = rate_limiter.lock().await;
|
||||
if !limiter.check(&client_id) {
|
||||
let response = Response::error(
|
||||
request.request_id,
|
||||
ErrorInfo::new(ErrorCode::RateLimited, "Too many requests"),
|
||||
);
|
||||
let _ = ipc.send_response(&client_id, response).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let response =
|
||||
Self::handle_command(engine, host, ipc, store, &client_id, request.request_id, request.command)
|
||||
.await;
|
||||
|
||||
let _ = ipc.send_response(&client_id, response).await;
|
||||
}
|
||||
|
||||
ServerMessage::ClientConnected { client_id, info } => {
|
||||
info!(
|
||||
client_id = %client_id,
|
||||
role = ?info.role,
|
||||
uid = ?info.uid,
|
||||
"Client connected"
|
||||
);
|
||||
|
||||
let _ = store.append_audit(AuditEvent::new(
|
||||
AuditEventType::ClientConnected {
|
||||
client_id: client_id.to_string(),
|
||||
role: format!("{:?}", info.role),
|
||||
uid: info.uid,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
ServerMessage::ClientDisconnected { client_id } => {
|
||||
debug!(client_id = %client_id, "Client disconnected");
|
||||
|
||||
let _ = store.append_audit(AuditEvent::new(
|
||||
AuditEventType::ClientDisconnected {
|
||||
client_id: client_id.to_string(),
|
||||
},
|
||||
));
|
||||
|
||||
// Clean up rate limiter
|
||||
let mut limiter = rate_limiter.lock().await;
|
||||
limiter.remove_client(&client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_command(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
host: &Arc<LinuxHost>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
store: &Arc<dyn Store>,
|
||||
client_id: &ClientId,
|
||||
request_id: u64,
|
||||
command: Command,
|
||||
) -> Response {
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
match command {
|
||||
Command::GetState => {
|
||||
let state = engine.lock().await.get_state();
|
||||
Response::success(request_id, ResponsePayload::State(state))
|
||||
}
|
||||
|
||||
Command::ListEntries { at_time } => {
|
||||
let time = at_time.unwrap_or(now);
|
||||
let entries = engine.lock().await.list_entries(time);
|
||||
Response::success(request_id, ResponsePayload::Entries(entries))
|
||||
}
|
||||
|
||||
Command::Launch { entry_id } => {
|
||||
let mut eng = engine.lock().await;
|
||||
|
||||
match eng.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(plan) => {
|
||||
// Start the session in the engine
|
||||
let event = eng.start_session(plan.clone(), now, now_mono);
|
||||
|
||||
// Get the entry kind for spawning
|
||||
let entry_kind = eng
|
||||
.policy()
|
||||
.get_entry(&entry_id)
|
||||
.map(|e| e.kind.clone());
|
||||
|
||||
drop(eng); // Release lock before spawning
|
||||
|
||||
if let Some(kind) = entry_kind {
|
||||
match host
|
||||
.spawn(
|
||||
plan.session_id.clone(),
|
||||
&kind,
|
||||
shepherd_host_api::SpawnOptions::default(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(handle) => {
|
||||
// Attach handle to session
|
||||
let mut eng = engine.lock().await;
|
||||
eng.attach_host_handle(handle);
|
||||
|
||||
// Broadcast session started
|
||||
if let CoreEvent::SessionStarted {
|
||||
session_id,
|
||||
entry_id,
|
||||
label,
|
||||
deadline,
|
||||
} = event
|
||||
{
|
||||
ipc.broadcast_event(Event::new(EventPayload::SessionStarted {
|
||||
session_id: session_id.clone(),
|
||||
entry_id,
|
||||
label,
|
||||
deadline,
|
||||
}));
|
||||
|
||||
Response::success(
|
||||
request_id,
|
||||
ResponsePayload::LaunchApproved {
|
||||
session_id,
|
||||
deadline,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::InternalError, "Unexpected event"),
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Notify session ended with error
|
||||
let mut eng = engine.lock().await;
|
||||
eng.notify_session_exited(Some(-1), now_mono, now);
|
||||
|
||||
Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(
|
||||
ErrorCode::HostError,
|
||||
format!("Spawn failed: {}", e),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::EntryNotFound, "Entry not found"),
|
||||
)
|
||||
}
|
||||
}
|
||||
LaunchDecision::Denied { reasons } => {
|
||||
Response::success(request_id, ResponsePayload::LaunchDenied { reasons })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Command::StopCurrent { mode } => {
|
||||
let mut eng = engine.lock().await;
|
||||
|
||||
// Get handle before stopping in engine
|
||||
let handle = eng
|
||||
.current_session()
|
||||
.and_then(|s| s.host_handle.clone());
|
||||
|
||||
let reason = match mode {
|
||||
StopMode::Graceful => SessionEndReason::UserStop,
|
||||
StopMode::Force => SessionEndReason::AdminStop,
|
||||
};
|
||||
|
||||
match eng.stop_current(reason, now_mono, now) {
|
||||
StopDecision::Stopped(_result) => {
|
||||
drop(eng); // Release lock before host operations
|
||||
|
||||
// Stop the actual process
|
||||
if let Some(h) = handle {
|
||||
let host_mode = match mode {
|
||||
StopMode::Graceful => HostStopMode::Graceful {
|
||||
timeout: Duration::from_secs(5),
|
||||
},
|
||||
StopMode::Force => HostStopMode::Force,
|
||||
};
|
||||
let _ = host.stop(&h, host_mode).await;
|
||||
}
|
||||
|
||||
Response::success(request_id, ResponsePayload::Stopped)
|
||||
}
|
||||
StopDecision::NoActiveSession => Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::NoActiveSession, "No active session"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Command::ReloadConfig => {
|
||||
// Check permission
|
||||
if let Some(info) = ipc.get_client_info(client_id).await {
|
||||
if !info.role.can_reload_config() {
|
||||
return Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::PermissionDenied, "Admin role required"),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Reload from original config path
|
||||
Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::InternalError, "Reload not yet implemented"),
|
||||
)
|
||||
}
|
||||
|
||||
Command::SubscribeEvents => {
|
||||
Response::success(
|
||||
request_id,
|
||||
ResponsePayload::Subscribed {
|
||||
client_id: client_id.clone(),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
Command::UnsubscribeEvents => {
|
||||
Response::success(request_id, ResponsePayload::Unsubscribed)
|
||||
}
|
||||
|
||||
Command::GetHealth => {
|
||||
let _eng = engine.lock().await;
|
||||
let health = HealthStatus {
|
||||
live: true,
|
||||
ready: true,
|
||||
policy_loaded: true,
|
||||
host_adapter_ok: host.is_healthy(),
|
||||
store_ok: store.is_healthy(),
|
||||
};
|
||||
Response::success(request_id, ResponsePayload::Health(health))
|
||||
}
|
||||
|
||||
Command::ExtendCurrent { by } => {
|
||||
// Check permission
|
||||
if let Some(info) = ipc.get_client_info(client_id).await {
|
||||
if !info.role.can_extend() {
|
||||
return Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::PermissionDenied, "Admin role required"),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut eng = engine.lock().await;
|
||||
match eng.extend_current(by, now_mono, now) {
|
||||
Some(new_deadline) => {
|
||||
Response::success(request_id, ResponsePayload::Extended { new_deadline })
|
||||
}
|
||||
None => Response::error(
|
||||
request_id,
|
||||
ErrorInfo::new(ErrorCode::NoActiveSession, "No active session"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
Command::Ping => Response::success(request_id, ResponsePayload::Pong),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Initialize logging
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(&args.log_level));
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(true)
|
||||
.init();
|
||||
|
||||
info!(
|
||||
version = env!("CARGO_PKG_VERSION"),
|
||||
"shepherdd starting"
|
||||
);
|
||||
|
||||
// Create and run daemon
|
||||
let daemon = Daemon::new(&args).await?;
|
||||
daemon.run().await
|
||||
}
|
||||
321
crates/shepherdd/tests/integration.rs
Normal file
321
crates/shepherdd/tests/integration.rs
Normal file
|
|
@ -0,0 +1,321 @@
|
|||
//! Integration tests for shepherdd
|
||||
//!
|
||||
//! These tests verify the end-to-end behavior of the daemon.
|
||||
|
||||
use chrono::Local;
|
||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, Policy};
|
||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision};
|
||||
use shepherd_host_api::{HostCapabilities, MockHost};
|
||||
use shepherd_store::{SqliteStore, Store};
|
||||
use shepherd_util::{EntryId, MonotonicInstant};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
fn make_test_policy() -> Policy {
|
||||
Policy {
|
||||
daemon: Default::default(),
|
||||
entries: vec![
|
||||
Entry {
|
||||
id: EntryId::new("test-game"),
|
||||
label: "Test Game".into(),
|
||||
icon_ref: None,
|
||||
kind: EntryKind::Process {
|
||||
argv: vec!["sleep".into(), "999".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
},
|
||||
availability: AvailabilityPolicy {
|
||||
windows: vec![],
|
||||
always: true,
|
||||
},
|
||||
limits: LimitsPolicy {
|
||||
max_run: Duration::from_secs(10), // Short for testing
|
||||
daily_quota: None,
|
||||
cooldown: None,
|
||||
},
|
||||
warnings: vec![
|
||||
WarningThreshold {
|
||||
seconds_before: 5,
|
||||
severity: WarningSeverity::Warn,
|
||||
message_template: Some("5 seconds left".into()),
|
||||
},
|
||||
WarningThreshold {
|
||||
seconds_before: 2,
|
||||
severity: WarningSeverity::Critical,
|
||||
message_template: Some("2 seconds left!".into()),
|
||||
},
|
||||
],
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
],
|
||||
default_warnings: vec![],
|
||||
default_max_run: Duration::from_secs(3600),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_loading() {
|
||||
let policy = make_test_policy();
|
||||
assert_eq!(policy.entries.len(), 1);
|
||||
assert_eq!(policy.entries[0].id.as_str(), "test-game");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_entry_listing() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entries = engine.list_entries(Local::now());
|
||||
|
||||
assert_eq!(entries.len(), 1);
|
||||
assert!(entries[0].enabled);
|
||||
assert_eq!(entries[0].entry_id.as_str(), "test-game");
|
||||
assert!(entries[0].max_run_if_started_now.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_launch_approval() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let decision = engine.request_launch(&entry_id, Local::now());
|
||||
|
||||
assert!(matches!(decision, LaunchDecision::Approved(plan) if plan.max_duration == Duration::from_secs(10)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_lifecycle() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Launch
|
||||
let plan = match engine.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(p) => p,
|
||||
LaunchDecision::Denied { .. } => panic!("Launch should be approved"),
|
||||
};
|
||||
|
||||
// Start session
|
||||
let event = engine.start_session(plan, now, now_mono);
|
||||
assert!(matches!(event, CoreEvent::SessionStarted { .. }));
|
||||
|
||||
// Verify session is active
|
||||
assert!(engine.has_active_session());
|
||||
|
||||
// Second launch should be denied
|
||||
let decision = engine.request_launch(&entry_id, now);
|
||||
assert!(matches!(decision, LaunchDecision::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_warning_emission() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
let plan = match engine.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(p) => p,
|
||||
_ => panic!(),
|
||||
};
|
||||
engine.start_session(plan, now, now_mono);
|
||||
|
||||
// No warnings at start
|
||||
let events = engine.tick(now_mono);
|
||||
assert!(events.is_empty());
|
||||
|
||||
// At 6 seconds (4 seconds remaining), 5-second warning should fire
|
||||
let at_6s = now_mono + Duration::from_secs(6);
|
||||
let events = engine.tick(at_6s);
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(&events[0], CoreEvent::Warning { threshold_seconds: 5, .. }));
|
||||
|
||||
// At 9 seconds (1 second remaining), 2-second warning should fire
|
||||
let at_9s = now_mono + Duration::from_secs(9);
|
||||
let events = engine.tick(at_9s);
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(&events[0], CoreEvent::Warning { threshold_seconds: 2, .. }));
|
||||
|
||||
// Warnings shouldn't repeat
|
||||
let events = engine.tick(at_9s);
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_expiry() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
let plan = match engine.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(p) => p,
|
||||
_ => panic!(),
|
||||
};
|
||||
engine.start_session(plan, now, now_mono);
|
||||
|
||||
// At 11 seconds, session should be expired
|
||||
let at_11s = now_mono + Duration::from_secs(11);
|
||||
let events = engine.tick(at_11s);
|
||||
|
||||
// Should have both remaining warnings + expiry
|
||||
let has_expiry = events.iter().any(|e| matches!(e, CoreEvent::ExpireDue { .. }));
|
||||
assert!(has_expiry, "Expected ExpireDue event");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_accounting() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let store_check = store.clone();
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
let plan = match engine.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(p) => p,
|
||||
_ => panic!(),
|
||||
};
|
||||
engine.start_session(plan, now, now_mono);
|
||||
|
||||
// Simulate 5 seconds passing
|
||||
let later_mono = now_mono + Duration::from_secs(5);
|
||||
let later = now + chrono::Duration::seconds(5);
|
||||
|
||||
// Session exits
|
||||
engine.notify_session_exited(Some(0), later_mono, later);
|
||||
|
||||
// Check usage was recorded
|
||||
let usage = store_check.get_usage(&entry_id, now.date_naive()).unwrap();
|
||||
assert!(usage >= Duration::from_secs(4) && usage <= Duration::from_secs(6));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_host_integration() {
|
||||
use shepherd_host_api::{HostAdapter, SpawnOptions};
|
||||
use shepherd_util::SessionId;
|
||||
|
||||
let host = MockHost::new();
|
||||
let _rx = host.subscribe();
|
||||
|
||||
let session_id = SessionId::new();
|
||||
let entry = EntryKind::Process {
|
||||
argv: vec!["test".into()],
|
||||
env: HashMap::new(),
|
||||
cwd: None,
|
||||
};
|
||||
|
||||
// Spawn
|
||||
let handle = host
|
||||
.spawn(session_id.clone(), &entry, SpawnOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify running
|
||||
assert_eq!(host.running_sessions().len(), 1);
|
||||
|
||||
// Stop
|
||||
host.stop(
|
||||
&handle,
|
||||
shepherd_host_api::StopMode::Graceful {
|
||||
timeout: Duration::from_secs(1),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_parsing() {
|
||||
use shepherd_config::parse_config;
|
||||
|
||||
let config = r#"
|
||||
config_version = 1
|
||||
|
||||
[[entries]]
|
||||
id = "scummvm"
|
||||
label = "ScummVM"
|
||||
kind = { type = "process", argv = ["scummvm", "-f"] }
|
||||
|
||||
[entries.availability]
|
||||
[[entries.availability.windows]]
|
||||
days = "weekdays"
|
||||
start = "14:00"
|
||||
end = "18:00"
|
||||
|
||||
[entries.limits]
|
||||
max_run_seconds = 3600
|
||||
daily_quota_seconds = 7200
|
||||
cooldown_seconds = 300
|
||||
|
||||
[[entries.warnings]]
|
||||
seconds_before = 300
|
||||
severity = "info"
|
||||
message = "5 minutes remaining"
|
||||
"#;
|
||||
|
||||
let policy = parse_config(config).unwrap();
|
||||
assert_eq!(policy.entries.len(), 1);
|
||||
assert_eq!(policy.entries[0].id.as_str(), "scummvm");
|
||||
assert_eq!(policy.entries[0].limits.max_run, Duration::from_secs(3600));
|
||||
assert_eq!(policy.entries[0].limits.daily_quota, Some(Duration::from_secs(7200)));
|
||||
assert_eq!(policy.entries[0].limits.cooldown, Some(Duration::from_secs(300)));
|
||||
assert_eq!(policy.entries[0].warnings.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_extension() {
|
||||
let policy = make_test_policy();
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
let caps = HostCapabilities::minimal();
|
||||
let mut engine = CoreEngine::new(policy, store, caps);
|
||||
|
||||
let entry_id = EntryId::new("test-game");
|
||||
let now = Local::now();
|
||||
let now_mono = MonotonicInstant::now();
|
||||
|
||||
// Start session
|
||||
let plan = match engine.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(p) => p,
|
||||
_ => panic!(),
|
||||
};
|
||||
engine.start_session(plan, now, now_mono);
|
||||
|
||||
// Get original deadline
|
||||
let original_deadline = engine.current_session().unwrap().deadline;
|
||||
|
||||
// Extend by 5 minutes
|
||||
let new_deadline = engine.extend_current(Duration::from_secs(300), now_mono, now);
|
||||
assert!(new_deadline.is_some());
|
||||
|
||||
let new_deadline = new_deadline.unwrap();
|
||||
let extension = new_deadline.signed_duration_since(original_deadline);
|
||||
assert!(extension.num_seconds() >= 299 && extension.num_seconds() <= 301);
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
# Process Spawning API
|
||||
|
||||
The daemon now supports spawning graphical processes within the current session.
|
||||
|
||||
## API Messages
|
||||
|
||||
### SpawnProcess
|
||||
Spawns a new process with the specified command and arguments.
|
||||
|
||||
```rust
|
||||
use crate::daemon::{IpcClient, IpcMessage, IpcResponse};
|
||||
|
||||
// Spawn a process with arguments
|
||||
let message = IpcMessage::SpawnProcess {
|
||||
command: "firefox".to_string(),
|
||||
args: vec!["--new-window".to_string(), "https://example.com".to_string()],
|
||||
};
|
||||
|
||||
match IpcClient::send_message(&message) {
|
||||
Ok(IpcResponse::ProcessSpawned { success, pid, message }) => {
|
||||
if success {
|
||||
println!("Process spawned with PID: {:?}", pid);
|
||||
} else {
|
||||
eprintln!("Failed to spawn: {}", message);
|
||||
}
|
||||
}
|
||||
Ok(other) => eprintln!("Unexpected response: {:?}", other),
|
||||
Err(e) => eprintln!("IPC error: {}", e),
|
||||
}
|
||||
```
|
||||
|
||||
### LaunchApp (Legacy)
|
||||
Spawns a process from a command string (command and args in one string).
|
||||
|
||||
```rust
|
||||
let message = IpcMessage::LaunchApp {
|
||||
name: "Terminal".to_string(),
|
||||
command: "alacritty".to_string(),
|
||||
};
|
||||
|
||||
match IpcClient::send_message(&message) {
|
||||
Ok(IpcResponse::ProcessSpawned { success, pid, message }) => {
|
||||
println!("Launch result: {} (PID: {:?})", message, pid);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
```
|
||||
|
||||
## Process Management
|
||||
|
||||
### Automatic Cleanup
|
||||
The daemon automatically tracks spawned processes and cleans up when they exit:
|
||||
- Each spawned process is tracked by PID
|
||||
- The daemon periodically checks for finished processes
|
||||
- Exited processes are automatically removed from tracking
|
||||
|
||||
### Status Query
|
||||
Get the number of currently running processes:
|
||||
|
||||
```rust
|
||||
match IpcClient::send_message(&IpcMessage::GetStatus) {
|
||||
Ok(IpcResponse::Status { uptime_secs, apps_running }) => {
|
||||
println!("Daemon uptime: {}s, Processes running: {}",
|
||||
uptime_secs, apps_running);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
```
|
||||
|
||||
## Environment Inheritance
|
||||
|
||||
Spawned processes inherit the daemon's environment, which includes:
|
||||
- `WAYLAND_DISPLAY` - for Wayland session access
|
||||
- `XDG_RUNTIME_DIR` - runtime directory
|
||||
- `DISPLAY` - for X11 fallback (if available)
|
||||
- All other environment variables from the daemon
|
||||
|
||||
This ensures graphical applications can connect to the display server.
|
||||
|
||||
## Examples
|
||||
|
||||
### Spawn a terminal emulator
|
||||
```rust
|
||||
IpcClient::send_message(&IpcMessage::SpawnProcess {
|
||||
command: "alacritty".to_string(),
|
||||
args: vec![],
|
||||
})
|
||||
```
|
||||
|
||||
### Spawn a browser with URL
|
||||
```rust
|
||||
IpcClient::send_message(&IpcMessage::SpawnProcess {
|
||||
command: "firefox".to_string(),
|
||||
args: vec!["https://github.com".to_string()],
|
||||
})
|
||||
```
|
||||
|
||||
### Spawn with working directory (using sh wrapper)
|
||||
```rust
|
||||
IpcClient::send_message(&IpcMessage::SpawnProcess {
|
||||
command: "sh".to_string(),
|
||||
args: vec![
|
||||
"-c".to_string(),
|
||||
"cd /path/to/project && code .".to_string()
|
||||
],
|
||||
})
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
`ProcessSpawned` response contains:
|
||||
- `success: bool` - Whether the spawn was successful
|
||||
- `pid: Option<u32>` - Process ID if successful, None on failure
|
||||
- `message: String` - Human-readable status message
|
||||
|
||||
## Error Handling
|
||||
|
||||
Common errors:
|
||||
- Command not found: Returns `success: false` with error message
|
||||
- Permission denied: Returns `success: false` with permission error
|
||||
- Invalid arguments: Returns `success: false` with argument error
|
||||
|
||||
Always check the `success` field before assuming the process started.
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
# Daemon and IPC Implementation
|
||||
|
||||
This directory contains the daemon process and IPC (Inter-Process Communication) implementation for shepherd-launcher.
|
||||
|
||||
## Architecture
|
||||
|
||||
The application uses a multi-process architecture:
|
||||
- **Main Process**: Spawns the daemon and runs the UI
|
||||
- **Daemon Process**: Background service that handles application launching and state management
|
||||
- **IPC**: Unix domain sockets for communication between processes
|
||||
|
||||
## Files
|
||||
|
||||
- `mod.rs`: Module exports
|
||||
- `daemon.rs`: Daemon process implementation
|
||||
- `ipc.rs`: IPC protocol, message types, client and server implementations
|
||||
|
||||
## IPC Protocol
|
||||
|
||||
Communication uses JSON-serialized messages over Unix domain sockets.
|
||||
|
||||
### Message Types (UI → Daemon)
|
||||
- `Ping`: Simple health check
|
||||
- `GetStatus`: Request daemon status (uptime, running apps)
|
||||
- `LaunchApp { name, command }`: Request to launch an application
|
||||
- `Shutdown`: Request daemon shutdown
|
||||
|
||||
### Response Types (Daemon → UI)
|
||||
- `Pong`: Response to Ping
|
||||
- `Status { uptime_secs, apps_running }`: Daemon status information
|
||||
- `AppLaunched { success, message }`: Result of app launch request
|
||||
- `ShuttingDown`: Acknowledgment of shutdown request
|
||||
- `Error { message }`: Error response
|
||||
|
||||
## Socket Location
|
||||
|
||||
The IPC socket is created at: `$XDG_RUNTIME_DIR/shepherd-launcher.sock` (typically `/run/user/1000/shepherd-launcher.sock`)
|
||||
|
||||
## Usage Example
|
||||
|
||||
```rust
|
||||
use crate::daemon::{IpcClient, IpcMessage, IpcResponse};
|
||||
|
||||
// Send a ping
|
||||
match IpcClient::send_message(&IpcMessage::Ping) {
|
||||
Ok(IpcResponse::Pong) => println!("Daemon is alive!"),
|
||||
Ok(other) => println!("Unexpected response: {:?}", other),
|
||||
Err(e) => eprintln!("IPC error: {}", e),
|
||||
}
|
||||
|
||||
// Get daemon status
|
||||
match IpcClient::send_message(&IpcMessage::GetStatus) {
|
||||
Ok(IpcResponse::Status { uptime_secs, apps_running }) => {
|
||||
println!("Uptime: {}s, Apps: {}", uptime_secs, apps_running);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Launch an app
|
||||
let msg = IpcMessage::LaunchApp {
|
||||
name: "Firefox".to_string(),
|
||||
command: "firefox".to_string(),
|
||||
};
|
||||
match IpcClient::send_message(&msg) {
|
||||
Ok(IpcResponse::AppLaunched { success, message }) => {
|
||||
println!("Launch {}: {}", if success { "succeeded" } else { "failed" }, message);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
```
|
||||
|
||||
## Current Functionality
|
||||
|
||||
Currently this is a dummy implementation demonstrating the IPC pattern:
|
||||
- The daemon process runs in the background
|
||||
- The UI periodically queries the daemon status (every 5 seconds)
|
||||
- Messages are printed to stdout for debugging
|
||||
- App launching is simulated (doesn't actually launch apps yet)
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Actual application launching logic
|
||||
- App state tracking
|
||||
- Bi-directional notifications (daemon → UI events)
|
||||
- Multiple concurrent IPC connections
|
||||
- Authentication/security
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
use super::ipc::IpcServer;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Start the daemon process
|
||||
pub fn start_daemon() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("[Daemon] Starting shepherd-launcher daemon...");
|
||||
|
||||
let mut ipc_server = IpcServer::new()?;
|
||||
println!("[Daemon] IPC server listening on socket");
|
||||
|
||||
loop {
|
||||
// Handle incoming IPC connections
|
||||
match ipc_server.accept_and_handle() {
|
||||
Ok(should_shutdown) => {
|
||||
if should_shutdown {
|
||||
println!("[Daemon] Shutdown requested, exiting...");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => eprintln!("[Daemon] Error handling client: {}", e),
|
||||
}
|
||||
|
||||
// Sleep briefly to avoid busy-waiting
|
||||
std::thread::sleep(Duration::from_millis(10));
|
||||
}
|
||||
|
||||
println!("[Daemon] Daemon shut down cleanly");
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::os::unix::net::{UnixListener, UnixStream};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command};
|
||||
|
||||
/// Messages that can be sent from the UI to the daemon
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum IpcMessage {
|
||||
Ping,
|
||||
GetStatus,
|
||||
LaunchApp { name: String, command: String },
|
||||
SpawnProcess { command: String, args: Vec<String> },
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Responses sent from the daemon to the UI
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum IpcResponse {
|
||||
Pong,
|
||||
Status { uptime_secs: u64, apps_running: usize },
|
||||
AppLaunched { success: bool, message: String },
|
||||
ProcessSpawned { success: bool, pid: Option<u32>, message: String },
|
||||
ShuttingDown,
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// Get the IPC socket path
|
||||
pub fn get_socket_path() -> PathBuf {
|
||||
let runtime_dir = std::env::var("XDG_RUNTIME_DIR")
|
||||
.unwrap_or_else(|_| "/tmp".to_string());
|
||||
PathBuf::from(runtime_dir).join("shepherd-launcher.sock")
|
||||
}
|
||||
|
||||
/// Server-side IPC handler for the daemon
|
||||
pub struct IpcServer {
|
||||
listener: UnixListener,
|
||||
start_time: std::time::Instant,
|
||||
processes: HashMap<u32, Child>,
|
||||
}
|
||||
|
||||
impl IpcServer {
|
||||
pub fn new() -> std::io::Result<Self> {
|
||||
let socket_path = get_socket_path();
|
||||
|
||||
// Remove old socket if it exists
|
||||
let _ = std::fs::remove_file(&socket_path);
|
||||
|
||||
let listener = UnixListener::bind(&socket_path)?;
|
||||
listener.set_nonblocking(true)?;
|
||||
|
||||
Ok(Self {
|
||||
listener,
|
||||
start_time: std::time::Instant::now(),
|
||||
processes: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn accept_and_handle(&mut self) -> std::io::Result<bool> {
|
||||
// Clean up finished processes
|
||||
self.cleanup_processes();
|
||||
|
||||
match self.listener.accept() {
|
||||
Ok((stream, _)) => {
|
||||
let should_shutdown = self.handle_client(stream)?;
|
||||
Ok(should_shutdown)
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||
Ok(false)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_client(&mut self, mut stream: UnixStream) -> std::io::Result<bool> {
|
||||
let mut reader = BufReader::new(stream.try_clone()?);
|
||||
let mut line = String::new();
|
||||
|
||||
reader.read_line(&mut line)?;
|
||||
|
||||
let message: IpcMessage = match serde_json::from_str(&line) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
let response = IpcResponse::Error {
|
||||
message: format!("Failed to parse message: {}", e),
|
||||
};
|
||||
let response_json = serde_json::to_string(&response)?;
|
||||
writeln!(stream, "{}", response_json)?;
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let should_shutdown = matches!(message, IpcMessage::Shutdown);
|
||||
let response = self.process_message(message);
|
||||
let response_json = serde_json::to_string(&response)?;
|
||||
writeln!(stream, "{}", response_json)?;
|
||||
|
||||
Ok(should_shutdown)
|
||||
}
|
||||
|
||||
fn process_message(&mut self, message: IpcMessage) -> IpcResponse {
|
||||
match message {
|
||||
IpcMessage::Ping => IpcResponse::Pong,
|
||||
IpcMessage::GetStatus => {
|
||||
let uptime_secs = self.start_time.elapsed().as_secs();
|
||||
IpcResponse::Status {
|
||||
uptime_secs,
|
||||
apps_running: self.processes.len(),
|
||||
}
|
||||
}
|
||||
IpcMessage::LaunchApp { name, command } => {
|
||||
println!("[Daemon] Launching app: {} ({})", name, command);
|
||||
self.spawn_graphical_process(&command, &[])
|
||||
}
|
||||
IpcMessage::SpawnProcess { command, args } => {
|
||||
println!("[Daemon] Spawning process: {} {:?}", command, args);
|
||||
self.spawn_graphical_process(&command, &args)
|
||||
}
|
||||
IpcMessage::Shutdown => IpcResponse::ShuttingDown,
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_graphical_process(&mut self, command: &str, args: &[String]) -> IpcResponse {
|
||||
// Parse command if it contains arguments and args is empty
|
||||
let (cmd, cmd_args) = if args.is_empty() {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
if parts.is_empty() {
|
||||
return IpcResponse::ProcessSpawned {
|
||||
success: false,
|
||||
pid: None,
|
||||
message: "Empty command".to_string(),
|
||||
};
|
||||
}
|
||||
(parts[0], parts[1..].iter().map(|s| s.to_string()).collect())
|
||||
} else {
|
||||
(command, args.to_vec())
|
||||
};
|
||||
|
||||
match Command::new(cmd)
|
||||
.args(&cmd_args)
|
||||
.spawn()
|
||||
{
|
||||
Ok(child) => {
|
||||
let pid = child.id();
|
||||
println!("[Daemon] Successfully spawned process PID: {}", pid);
|
||||
self.processes.insert(pid, child);
|
||||
IpcResponse::ProcessSpawned {
|
||||
success: true,
|
||||
pid: Some(pid),
|
||||
message: format!("Process spawned with PID {}", pid),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[Daemon] Failed to spawn process '{}': {}", cmd, e);
|
||||
IpcResponse::ProcessSpawned {
|
||||
success: false,
|
||||
pid: None,
|
||||
message: format!("Failed to spawn: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup_processes(&mut self) {
|
||||
// Check for finished processes and remove them
|
||||
let mut finished = Vec::new();
|
||||
for (pid, child) in self.processes.iter_mut() {
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
println!("[Daemon] Process {} exited with status: {}", pid, status);
|
||||
finished.push(*pid);
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
eprintln!("[Daemon] Error checking process {}: {}", pid, e);
|
||||
finished.push(*pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
for pid in finished {
|
||||
self.processes.remove(&pid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Client-side IPC handler for the UI
|
||||
pub struct IpcClient;
|
||||
|
||||
impl IpcClient {
|
||||
pub fn send_message(message: &IpcMessage) -> std::io::Result<IpcResponse> {
|
||||
let socket_path = get_socket_path();
|
||||
let mut stream = UnixStream::connect(&socket_path)?;
|
||||
|
||||
let message_json = serde_json::to_string(message)?;
|
||||
writeln!(stream, "{}", message_json)?;
|
||||
|
||||
let mut reader = BufReader::new(stream);
|
||||
let mut response_line = String::new();
|
||||
reader.read_line(&mut response_line)?;
|
||||
|
||||
let response: IpcResponse = serde_json::from_str(&response_line)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
mod daemon;
|
||||
mod ipc;
|
||||
|
||||
pub use daemon::start_daemon;
|
||||
pub use ipc::{IpcClient, IpcMessage, IpcResponse};
|
||||
79
src/main.rs
79
src/main.rs
|
|
@ -1,79 +0,0 @@
|
|||
mod daemon;
|
||||
mod ui;
|
||||
|
||||
use std::env;
|
||||
use std::process::{Command, Stdio};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
// Check if we're running as the daemon
|
||||
if args.len() > 1 && args[1] == "--daemon" {
|
||||
return daemon::start_daemon();
|
||||
}
|
||||
|
||||
// Spawn the daemon process
|
||||
println!("[Main] Spawning daemon process...");
|
||||
let mut daemon_child = Command::new(&args[0])
|
||||
.arg("--daemon")
|
||||
.stdout(Stdio::inherit())
|
||||
.stderr(Stdio::inherit())
|
||||
.spawn()?;
|
||||
|
||||
let daemon_pid = daemon_child.id();
|
||||
println!("[Main] Daemon spawned with PID: {}", daemon_pid);
|
||||
|
||||
// Give the daemon a moment to start up
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
|
||||
// Test the IPC connection
|
||||
println!("[Main] Testing IPC connection...");
|
||||
match daemon::IpcClient::send_message(&daemon::IpcMessage::Ping) {
|
||||
Ok(daemon::IpcResponse::Pong) => println!("[Main] IPC connection successful!"),
|
||||
Ok(response) => println!("[Main] Unexpected response: {:?}", response),
|
||||
Err(e) => println!("[Main] IPC connection failed: {}", e),
|
||||
}
|
||||
|
||||
// Start the UI
|
||||
println!("[Main] Starting UI...");
|
||||
let ui_result = ui::run();
|
||||
|
||||
// UI has exited, shut down the daemon
|
||||
println!("[Main] UI exited, shutting down daemon...");
|
||||
match daemon::IpcClient::send_message(&daemon::IpcMessage::Shutdown) {
|
||||
Ok(daemon::IpcResponse::ShuttingDown) => {
|
||||
println!("[Main] Daemon acknowledged shutdown");
|
||||
}
|
||||
Ok(response) => {
|
||||
println!("[Main] Unexpected shutdown response: {:?}", response);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[Main] Failed to send shutdown to daemon: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for daemon to exit (with timeout)
|
||||
let wait_start = std::time::Instant::now();
|
||||
loop {
|
||||
match daemon_child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
println!("[Main] Daemon exited with status: {}", status);
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
if wait_start.elapsed().as_secs() > 5 {
|
||||
eprintln!("[Main] Daemon did not exit in time, killing it");
|
||||
let _ = daemon_child.kill();
|
||||
break;
|
||||
}
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[Main] Error waiting for daemon: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ui_result
|
||||
}
|
||||
224
src/ui/clock.rs
224
src/ui/clock.rs
|
|
@ -1,224 +0,0 @@
|
|||
use chrono::Local;
|
||||
use smithay_client_toolkit::{
|
||||
compositor::{CompositorHandler, CompositorState},
|
||||
delegate_compositor, delegate_layer, delegate_output, delegate_registry, delegate_shm,
|
||||
output::{OutputHandler, OutputState},
|
||||
registry::{ProvidesRegistryState, RegistryState},
|
||||
registry_handlers,
|
||||
shell::{
|
||||
wlr_layer::{LayerShell, LayerShellHandler, LayerSurface, LayerSurfaceConfigure},
|
||||
WaylandSurface,
|
||||
},
|
||||
shm::{slot::SlotPool, Shm, ShmHandler},
|
||||
};
|
||||
use wayland_client::{
|
||||
protocol::{wl_output, wl_shm, wl_surface},
|
||||
Connection, QueueHandle,
|
||||
};
|
||||
|
||||
pub struct ClockApp {
|
||||
pub registry_state: RegistryState,
|
||||
pub output_state: OutputState,
|
||||
pub compositor_state: CompositorState,
|
||||
pub shm_state: Shm,
|
||||
pub layer_shell: LayerShell,
|
||||
|
||||
pub pool: Option<SlotPool>,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub layer_surface: Option<LayerSurface>,
|
||||
pub configured: bool,
|
||||
}
|
||||
|
||||
impl ClockApp {
|
||||
pub fn draw(&mut self, _qh: &QueueHandle<Self>) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let Some(layer_surface) = &self.layer_surface {
|
||||
let width = self.width;
|
||||
let height = self.height;
|
||||
let stride = width as i32 * 4;
|
||||
|
||||
let pool = self.pool.get_or_insert_with(|| {
|
||||
SlotPool::new((width * height * 4) as usize, &self.shm_state).unwrap()
|
||||
});
|
||||
|
||||
let (buffer, canvas) = pool
|
||||
.create_buffer(width as i32, height as i32, stride, wl_shm::Format::Argb8888)
|
||||
.unwrap();
|
||||
|
||||
// Get current time
|
||||
let now = Local::now();
|
||||
let time_str = now.format("%H:%M:%S").to_string();
|
||||
let date_str = now.format("%A, %B %d, %Y").to_string();
|
||||
|
||||
// Draw using cairo
|
||||
// Safety: We ensure the buffer lifetime is valid for the cairo surface
|
||||
unsafe {
|
||||
let surface = cairo::ImageSurface::create_for_data_unsafe(
|
||||
canvas.as_mut_ptr(),
|
||||
cairo::Format::ARgb32,
|
||||
width as i32,
|
||||
height as i32,
|
||||
stride,
|
||||
)?;
|
||||
|
||||
let ctx = cairo::Context::new(&surface)?;
|
||||
|
||||
// Background
|
||||
ctx.set_source_rgb(0.1, 0.1, 0.15);
|
||||
ctx.paint()?;
|
||||
|
||||
// Draw time
|
||||
ctx.set_source_rgb(1.0, 1.0, 1.0);
|
||||
ctx.select_font_face("Sans", cairo::FontSlant::Normal, cairo::FontWeight::Bold);
|
||||
ctx.set_font_size(60.0);
|
||||
|
||||
let time_extents = ctx.text_extents(&time_str)?;
|
||||
let time_x = (width as f64 - time_extents.width()) / 2.0 - time_extents.x_bearing();
|
||||
let time_y = height as f64 / 2.0 - 10.0;
|
||||
ctx.move_to(time_x, time_y);
|
||||
ctx.show_text(&time_str)?;
|
||||
|
||||
// Draw date
|
||||
ctx.set_font_size(20.0);
|
||||
ctx.select_font_face("Sans", cairo::FontSlant::Normal, cairo::FontWeight::Normal);
|
||||
let date_extents = ctx.text_extents(&date_str)?;
|
||||
let date_x = (width as f64 - date_extents.width()) / 2.0 - date_extents.x_bearing();
|
||||
let date_y = height as f64 / 2.0 + 35.0;
|
||||
ctx.move_to(date_x, date_y);
|
||||
ctx.show_text(&date_str)?;
|
||||
}
|
||||
|
||||
layer_surface
|
||||
.wl_surface()
|
||||
.attach(Some(buffer.wl_buffer()), 0, 0);
|
||||
layer_surface.wl_surface().damage_buffer(0, 0, width as i32, height as i32);
|
||||
layer_surface.wl_surface().commit();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl CompositorHandler for ClockApp {
|
||||
fn scale_factor_changed(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_surface: &wl_surface::WlSurface,
|
||||
_new_factor: i32,
|
||||
) {
|
||||
}
|
||||
|
||||
fn transform_changed(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_surface: &wl_surface::WlSurface,
|
||||
_new_transform: wl_output::Transform,
|
||||
) {
|
||||
}
|
||||
|
||||
fn frame(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
qh: &QueueHandle<Self>,
|
||||
_surface: &wl_surface::WlSurface,
|
||||
_time: u32,
|
||||
) {
|
||||
let _ = self.draw(qh);
|
||||
}
|
||||
|
||||
fn surface_enter(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_surface: &wl_surface::WlSurface,
|
||||
_output: &wl_output::WlOutput,
|
||||
) {
|
||||
}
|
||||
|
||||
fn surface_leave(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_surface: &wl_surface::WlSurface,
|
||||
_output: &wl_output::WlOutput,
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
impl OutputHandler for ClockApp {
|
||||
fn output_state(&mut self) -> &mut OutputState {
|
||||
&mut self.output_state
|
||||
}
|
||||
|
||||
fn new_output(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_output: wl_output::WlOutput,
|
||||
) {
|
||||
}
|
||||
|
||||
fn update_output(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_output: wl_output::WlOutput,
|
||||
) {
|
||||
}
|
||||
|
||||
fn output_destroyed(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
_qh: &QueueHandle<Self>,
|
||||
_output: wl_output::WlOutput,
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
impl LayerShellHandler for ClockApp {
|
||||
fn closed(&mut self, _conn: &Connection, _qh: &QueueHandle<Self>, _layer: &LayerSurface) {
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
fn configure(
|
||||
&mut self,
|
||||
_conn: &Connection,
|
||||
qh: &QueueHandle<Self>,
|
||||
_layer: &LayerSurface,
|
||||
configure: LayerSurfaceConfigure,
|
||||
_serial: u32,
|
||||
) {
|
||||
if configure.new_size.0 != 0 {
|
||||
self.width = configure.new_size.0;
|
||||
}
|
||||
if configure.new_size.1 != 0 {
|
||||
self.height = configure.new_size.1;
|
||||
}
|
||||
|
||||
self.configured = true;
|
||||
let _ = self.draw(qh);
|
||||
}
|
||||
}
|
||||
|
||||
impl ShmHandler for ClockApp {
|
||||
fn shm_state(&mut self) -> &mut Shm {
|
||||
&mut self.shm_state
|
||||
}
|
||||
}
|
||||
|
||||
delegate_compositor!(ClockApp);
|
||||
delegate_output!(ClockApp);
|
||||
delegate_shm!(ClockApp);
|
||||
delegate_layer!(ClockApp);
|
||||
|
||||
delegate_registry!(ClockApp);
|
||||
|
||||
impl ProvidesRegistryState for ClockApp {
|
||||
fn registry(&mut self) -> &mut RegistryState {
|
||||
&mut self.registry_state
|
||||
}
|
||||
|
||||
registry_handlers![OutputState];
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
mod clock;
|
||||
mod ui;
|
||||
|
||||
pub use ui::run;
|
||||
111
src/ui/ui.rs
111
src/ui/ui.rs
|
|
@ -1,111 +0,0 @@
|
|||
use super::clock::ClockApp;
|
||||
use crate::daemon::{IpcClient, IpcMessage, IpcResponse};
|
||||
use smithay_client_toolkit::{
|
||||
compositor::CompositorState,
|
||||
output::OutputState,
|
||||
registry::RegistryState,
|
||||
shell::{
|
||||
wlr_layer::{Anchor, KeyboardInteractivity, Layer, LayerShell},
|
||||
WaylandSurface,
|
||||
},
|
||||
shm::Shm,
|
||||
};
|
||||
use wayland_client::globals::registry_queue_init;
|
||||
use wayland_client::Connection;
|
||||
|
||||
pub fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let conn = Connection::connect_to_env()?;
|
||||
let (globals, mut event_queue) = registry_queue_init(&conn)?;
|
||||
let qh = event_queue.handle();
|
||||
|
||||
let mut app = ClockApp {
|
||||
registry_state: RegistryState::new(&globals),
|
||||
output_state: OutputState::new(&globals, &qh),
|
||||
compositor_state: CompositorState::bind(&globals, &qh)?,
|
||||
shm_state: Shm::bind(&globals, &qh)?,
|
||||
layer_shell: LayerShell::bind(&globals, &qh)?,
|
||||
|
||||
pool: None,
|
||||
width: 400,
|
||||
height: 200,
|
||||
layer_surface: None,
|
||||
configured: false,
|
||||
};
|
||||
|
||||
// Create the layer surface
|
||||
let surface = app.compositor_state.create_surface(&qh);
|
||||
let layer_surface = app.layer_shell.create_layer_surface(
|
||||
&qh,
|
||||
surface,
|
||||
Layer::Top,
|
||||
Some("clock"),
|
||||
None,
|
||||
);
|
||||
|
||||
layer_surface.set_anchor(Anchor::TOP | Anchor::LEFT | Anchor::RIGHT);
|
||||
layer_surface.set_size(app.width, app.height);
|
||||
layer_surface.set_exclusive_zone(app.height as i32);
|
||||
layer_surface.set_keyboard_interactivity(KeyboardInteractivity::None);
|
||||
layer_surface.commit();
|
||||
|
||||
app.layer_surface = Some(layer_surface);
|
||||
|
||||
// Periodically query daemon status via IPC
|
||||
let mut counter = 0;
|
||||
|
||||
// Example: Spawn a test process after 2 seconds
|
||||
let mut test_spawned = false;
|
||||
|
||||
loop {
|
||||
event_queue.blocking_dispatch(&mut app)?;
|
||||
|
||||
if app.configured {
|
||||
app.draw(&qh)?;
|
||||
|
||||
// Example: Spawn a simple graphical process after 2 seconds
|
||||
if counter == 4 && !test_spawned {
|
||||
println!("[UI] Testing process spawn API...");
|
||||
match IpcClient::send_message(&IpcMessage::SpawnProcess {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["Hello from spawned process!".to_string()],
|
||||
}) {
|
||||
Ok(IpcResponse::ProcessSpawned { success, pid, message }) => {
|
||||
if success {
|
||||
println!("[UI] Process spawned successfully! PID: {:?}, Message: {}",
|
||||
pid, message);
|
||||
} else {
|
||||
println!("[UI] Process spawn failed: {}", message);
|
||||
}
|
||||
}
|
||||
Ok(response) => {
|
||||
println!("[UI] Unexpected response: {:?}", response);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[UI] Failed to spawn process: {}", e);
|
||||
}
|
||||
}
|
||||
test_spawned = true;
|
||||
}
|
||||
|
||||
// Every 10 iterations (5 seconds), query the daemon
|
||||
if counter % 10 == 0 {
|
||||
match IpcClient::send_message(&IpcMessage::GetStatus) {
|
||||
Ok(IpcResponse::Status { uptime_secs, apps_running }) => {
|
||||
println!("[UI] Daemon status - Uptime: {}s, Apps running: {}",
|
||||
uptime_secs, apps_running);
|
||||
}
|
||||
Ok(response) => {
|
||||
println!("[UI] Unexpected daemon response: {:?}", response);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[UI] Failed to communicate with daemon: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
counter += 1;
|
||||
|
||||
// Sleep briefly to reduce CPU usage
|
||||
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue