Run rustfmt

This commit is contained in:
Albert Armea 2026-02-08 14:01:49 -05:00
parent c8675be472
commit 3861092d3d
35 changed files with 863 additions and 672 deletions

View file

@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use shepherd_util::{ClientId, EntryId};
use std::time::Duration;
use crate::{ClientRole, StopMode, API_VERSION};
use crate::{API_VERSION, ClientRole, StopMode};
/// Request wrapper with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -128,7 +128,6 @@ pub enum Command {
GetHealth,
// Volume control commands
/// Get current volume status
GetVolume,
@ -142,7 +141,6 @@ pub enum Command {
SetMute { muted: bool },
// Admin commands
/// Extend the current session (admin only)
ExtendCurrent { by: Duration },

View file

@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use shepherd_util::{EntryId, SessionId};
use std::time::Duration;
use crate::{ServiceStateSnapshot, SessionEndReason, WarningSeverity, API_VERSION};
use crate::{API_VERSION, ServiceStateSnapshot, SessionEndReason, WarningSeverity};
/// Event envelope
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -51,9 +51,7 @@ pub enum EventPayload {
},
/// Session is expiring (termination initiated)
SessionExpiring {
session_id: SessionId,
},
SessionExpiring { session_id: SessionId },
/// Session has ended
SessionEnded {
@ -64,21 +62,13 @@ pub enum EventPayload {
},
/// Policy was reloaded
PolicyReloaded {
entry_count: usize,
},
PolicyReloaded { entry_count: usize },
/// Entry availability changed (for UI updates)
EntryAvailabilityChanged {
entry_id: EntryId,
enabled: bool,
},
EntryAvailabilityChanged { entry_id: EntryId, enabled: bool },
/// Volume status changed
VolumeChanged {
percent: u8,
muted: bool,
},
VolumeChanged { percent: u8, muted: bool },
/// Service is shutting down
Shutdown,
@ -107,7 +97,10 @@ mod tests {
let parsed: Event = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.api_version, API_VERSION);
assert!(matches!(parsed.payload, EventPayload::SessionStarted { .. }));
assert!(matches!(
parsed.payload,
EventPayload::SessionStarted { .. }
));
}
#[test]

View file

@ -125,14 +125,9 @@ pub enum ReasonCode {
next_window_start: Option<DateTime<Local>>,
},
/// Daily quota exhausted
QuotaExhausted {
used: Duration,
quota: Duration,
},
QuotaExhausted { used: Duration, quota: Duration },
/// Cooldown period active
CooldownActive {
available_at: DateTime<Local>,
},
CooldownActive { available_at: DateTime<Local> },
/// Another session is active
SessionActive {
entry_id: EntryId,
@ -140,17 +135,11 @@ pub enum ReasonCode {
remaining: Option<Duration>,
},
/// Host doesn't support this entry kind
UnsupportedKind {
kind: EntryKindTag,
},
UnsupportedKind { kind: EntryKindTag },
/// Entry is explicitly disabled
Disabled {
reason: Option<String>,
},
Disabled { reason: Option<String> },
/// Internet connectivity is required but unavailable
InternetUnavailable {
check: Option<String>,
},
InternetUnavailable { check: Option<String> },
}
/// Warning severity level

View file

@ -29,7 +29,10 @@ fn main() -> ExitCode {
// Check file exists
if !config_path.exists() {
eprintln!("Error: Configuration file not found: {}", config_path.display());
eprintln!(
"Error: Configuration file not found: {}",
config_path.display()
);
return ExitCode::from(1);
}
@ -39,7 +42,10 @@ fn main() -> ExitCode {
println!("✓ Configuration is valid");
println!();
println!("Summary:");
println!(" Config version: {}", shepherd_config::CURRENT_CONFIG_VERSION);
println!(
" Config version: {}",
shepherd_config::CURRENT_CONFIG_VERSION
);
println!(" Entries: {}", policy.entries.len());
// Show entry summary

View file

@ -65,8 +65,9 @@ impl InternetCheckTarget {
let (host, port_opt) = parse_host_port(host_port)?;
let port = match scheme {
InternetCheckScheme::Tcp => port_opt
.ok_or_else(|| "tcp check requires explicit port".to_string())?,
InternetCheckScheme::Tcp => {
port_opt.ok_or_else(|| "tcp check requires explicit port".to_string())?
}
_ => port_opt.unwrap_or_else(|| scheme.default_port()),
};
@ -149,4 +150,3 @@ pub struct EntryInternetPolicy {
pub required: bool,
pub check: Option<InternetCheckTarget>,
}

View file

@ -6,15 +6,15 @@
//! - Time windows, limits, and warnings
//! - Validation with clear error messages
mod internet;
mod policy;
mod schema;
mod validation;
mod internet;
pub use internet::*;
pub use policy::*;
pub use schema::*;
pub use validation::*;
pub use internet::*;
use std::path::Path;
use thiserror::Error;

View file

@ -1,16 +1,19 @@
//! Validated policy structures
use crate::schema::{
RawConfig, RawEntry, RawEntryKind, RawInternetConfig, RawVolumeConfig, RawServiceConfig,
RawWarningThreshold,
};
use crate::internet::{
EntryInternetPolicy, InternetCheckTarget, InternetConfig, DEFAULT_INTERNET_CHECK_INTERVAL,
DEFAULT_INTERNET_CHECK_TIMEOUT,
DEFAULT_INTERNET_CHECK_INTERVAL, DEFAULT_INTERNET_CHECK_TIMEOUT, EntryInternetPolicy,
InternetCheckTarget, InternetConfig,
};
use crate::schema::{
RawConfig, RawEntry, RawEntryKind, RawInternetConfig, RawServiceConfig, RawVolumeConfig,
RawWarningThreshold,
};
use crate::validation::{parse_days, parse_time};
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir, socket_path_without_env};
use shepherd_util::{
DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir,
socket_path_without_env,
};
use std::path::PathBuf;
use std::time::Duration;
@ -94,24 +97,17 @@ pub struct ServiceConfig {
impl ServiceConfig {
fn from_raw(raw: RawServiceConfig) -> Self {
let log_dir = raw
.log_dir
.clone()
.unwrap_or_else(default_log_dir);
let log_dir = raw.log_dir.clone().unwrap_or_else(default_log_dir);
let child_log_dir = raw
.child_log_dir
.unwrap_or_else(|| log_dir.join("sessions"));
let internet = convert_internet_config(raw.internet.as_ref());
Self {
socket_path: raw
.socket_path
.unwrap_or_else(socket_path_without_env),
socket_path: raw.socket_path.unwrap_or_else(socket_path_without_env),
log_dir,
capture_child_output: raw.capture_child_output,
child_log_dir,
data_dir: raw
.data_dir
.unwrap_or_else(default_data_dir),
data_dir: raw.data_dir.unwrap_or_else(default_data_dir),
internet,
}
}
@ -126,7 +122,11 @@ impl Default for ServiceConfig {
log_dir,
data_dir: default_data_dir(),
capture_child_output: false,
internet: InternetConfig::new(None, DEFAULT_INTERNET_CHECK_INTERVAL, DEFAULT_INTERNET_CHECK_TIMEOUT),
internet: InternetConfig::new(
None,
DEFAULT_INTERNET_CHECK_INTERVAL,
DEFAULT_INTERNET_CHECK_TIMEOUT,
),
}
}
}
@ -212,10 +212,7 @@ impl AvailabilityPolicy {
}
/// Get remaining time in current window
pub fn remaining_in_window(
&self,
dt: &chrono::DateTime<chrono::Local>,
) -> Option<Duration> {
pub fn remaining_in_window(&self, dt: &chrono::DateTime<chrono::Local>) -> Option<Duration> {
if self.always {
return None; // No limit from windows
}
@ -269,8 +266,28 @@ impl VolumePolicy {
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
match raw {
RawEntryKind::Process { command, args, env, cwd } => EntryKind::Process { command, args, env, cwd },
RawEntryKind::Snap { snap_name, command, args, env } => EntryKind::Snap { snap_name, command, args, env },
RawEntryKind::Process {
command,
args,
env,
cwd,
} => EntryKind::Process {
command,
args,
env,
cwd,
},
RawEntryKind::Snap {
snap_name,
command,
args,
env,
} => EntryKind::Snap {
snap_name,
command,
args,
env,
},
RawEntryKind::Steam { app_id, args, env } => EntryKind::Steam { app_id, args, env },
RawEntryKind::Flatpak { app_id, args, env } => EntryKind::Flatpak { app_id, args, env },
RawEntryKind::Vm { driver, args } => EntryKind::Vm { driver, args },
@ -347,7 +364,10 @@ fn seconds_to_duration_or_unlimited(secs: u64) -> Option<Duration> {
}
}
fn convert_limits(raw: crate::schema::RawLimits, default_max_run: Option<Duration>) -> LimitsPolicy {
fn convert_limits(
raw: crate::schema::RawLimits,
default_max_run: Option<Duration>,
) -> LimitsPolicy {
LimitsPolicy {
max_run: raw
.max_run_seconds

View file

@ -1,7 +1,7 @@
//! Configuration validation
use crate::schema::{RawConfig, RawDays, RawEntry, RawEntryKind, RawTimeWindow};
use crate::internet::InternetCheckTarget;
use crate::schema::{RawConfig, RawDays, RawEntry, RawEntryKind, RawTimeWindow};
use std::collections::HashSet;
use thiserror::Error;
@ -38,7 +38,8 @@ pub fn validate_config(config: &RawConfig) -> Vec<ValidationError> {
// Validate global internet check (if set)
if let Some(internet) = &config.service.internet
&& let Some(check) = &internet.check
&& let Err(e) = InternetCheckTarget::parse(check) {
&& let Err(e) = InternetCheckTarget::parse(check)
{
errors.push(ValidationError::GlobalError(format!(
"Invalid internet check '{}': {}",
check, e
@ -47,13 +48,15 @@ pub fn validate_config(config: &RawConfig) -> Vec<ValidationError> {
if let Some(internet) = &config.service.internet {
if let Some(interval) = internet.interval_seconds
&& interval == 0 {
&& interval == 0
{
errors.push(ValidationError::GlobalError(
"Internet check interval_seconds must be > 0".into(),
));
}
if let Some(timeout) = internet.timeout_ms
&& timeout == 0 {
&& timeout == 0
{
errors.push(ValidationError::GlobalError(
"Internet check timeout_ms must be > 0".into(),
));
@ -156,7 +159,8 @@ fn validate_entry(entry: &RawEntry, config: &RawConfig) -> Vec<ValidationError>
// Only validate warnings if max_run is Some and not 0 (unlimited)
if let (Some(warnings), Some(max_run)) = (&entry.warnings, max_run)
&& max_run > 0 {
&& max_run > 0
{
for warning in warnings {
if warning.seconds_before >= max_run {
errors.push(ValidationError::WarningExceedsMaxRun {
@ -172,7 +176,8 @@ fn validate_entry(entry: &RawEntry, config: &RawConfig) -> Vec<ValidationError>
// Validate internet requirements
if let Some(internet) = &entry.internet {
if let Some(check) = &internet.check
&& let Err(e) = InternetCheckTarget::parse(check) {
&& let Err(e) = InternetCheckTarget::parse(check)
{
errors.push(ValidationError::EntryError {
entry_id: entry.id.clone(),
message: format!("Invalid internet check '{}': {}", check, e),
@ -236,12 +241,8 @@ pub fn parse_time(s: &str) -> Result<(u8, u8), String> {
return Err("Expected HH:MM format".into());
}
let hour: u8 = parts[0]
.parse()
.map_err(|_| "Invalid hour".to_string())?;
let minute: u8 = parts[1]
.parse()
.map_err(|_| "Invalid minute".to_string())?;
let hour: u8 = parts[0].parse().map_err(|_| "Invalid hour".to_string())?;
let minute: u8 = parts[1].parse().map_err(|_| "Invalid minute".to_string())?;
if hour >= 24 {
return Err("Hour must be 0-23".into());
@ -299,12 +300,23 @@ mod tests {
#[test]
fn test_parse_days() {
assert_eq!(parse_days(&RawDays::Preset("weekdays".into())).unwrap(), 0x1F);
assert_eq!(parse_days(&RawDays::Preset("weekends".into())).unwrap(), 0x60);
assert_eq!(
parse_days(&RawDays::Preset("weekdays".into())).unwrap(),
0x1F
);
assert_eq!(
parse_days(&RawDays::Preset("weekends".into())).unwrap(),
0x60
);
assert_eq!(parse_days(&RawDays::Preset("all".into())).unwrap(), 0x7F);
assert_eq!(
parse_days(&RawDays::List(vec!["mon".into(), "wed".into(), "fri".into()])).unwrap(),
parse_days(&RawDays::List(vec![
"mon".into(),
"wed".into(),
"fri".into()
]))
.unwrap(),
0b10101
);
}
@ -355,6 +367,10 @@ mod tests {
};
let errors = validate_config(&config);
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateEntryId(_))));
assert!(
errors
.iter()
.any(|e| matches!(e, ValidationError::DuplicateEntryId(_)))
);
}
}

View file

@ -2,10 +2,9 @@
use chrono::{DateTime, Local};
use shepherd_api::{
ServiceStateSnapshot, EntryView, ReasonCode, SessionEndReason,
WarningSeverity, API_VERSION,
API_VERSION, EntryView, ReasonCode, ServiceStateSnapshot, SessionEndReason, WarningSeverity,
};
use shepherd_config::{Entry, Policy, InternetCheckTarget};
use shepherd_config::{Entry, InternetCheckTarget, Policy};
use shepherd_host_api::{HostCapabilities, HostSessionHandle};
use shepherd_store::{AuditEvent, AuditEventType, Store};
use shepherd_util::{EntryId, MonotonicInstant, SessionId};
@ -44,11 +43,7 @@ pub struct CoreEngine {
impl CoreEngine {
/// Create a new core engine
pub fn new(
policy: Policy,
store: Arc<dyn Store>,
capabilities: HostCapabilities,
) -> Self {
pub fn new(policy: Policy, store: Arc<dyn Store>, capabilities: HostCapabilities) -> Self {
info!(
entry_count = policy.entries.len(),
"Core engine initialized"
@ -79,7 +74,9 @@ impl CoreEngine {
let entry_count = policy.entries.len();
self.policy = policy;
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::PolicyLoaded {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::PolicyLoaded {
entry_count,
}));
@ -137,11 +134,12 @@ impl CoreEngine {
// Check internet requirement
if entry.internet.required {
let check = entry
let check = entry.internet.check.as_ref().or(self
.policy
.service
.internet
.check
.as_ref()
.or(self.policy.service.internet.check.as_ref());
.as_ref());
let available = check
.map(|target| self.internet_available(target))
.unwrap_or(false);
@ -165,16 +163,20 @@ impl CoreEngine {
// Check cooldown
if let Ok(Some(until)) = self.store.get_cooldown_until(&entry.id)
&& until > now {
&& until > now
{
enabled = false;
reasons.push(ReasonCode::CooldownActive { available_at: until });
reasons.push(ReasonCode::CooldownActive {
available_at: until,
});
}
// Check daily quota
if let Some(quota) = entry.limits.daily_quota {
let today = now.date_naive();
if let Ok(used) = self.store.get_usage(&entry.id, today)
&& used >= quota {
&& used >= quota
{
enabled = false;
reasons.push(ReasonCode::QuotaExhausted { used, quota });
}
@ -227,11 +229,7 @@ impl CoreEngine {
}
/// Request to launch an entry
pub fn request_launch(
&self,
entry_id: &EntryId,
now: DateTime<Local>,
) -> LaunchDecision {
pub fn request_launch(&self, entry_id: &EntryId, now: DateTime<Local>) -> LaunchDecision {
// Find entry
let entry = match self.policy.get_entry(entry_id) {
Some(e) => e,
@ -249,7 +247,9 @@ impl CoreEngine {
if !view.enabled {
// Log denial
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::LaunchDenied {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::LaunchDenied {
entry_id: entry_id.clone(),
reasons: view.reasons.iter().map(|r| format!("{:?}", r)).collect(),
}));
@ -302,7 +302,9 @@ impl CoreEngine {
};
// Log to audit
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionStarted {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::SessionStarted {
session_id: session.plan.session_id.clone(),
entry_id: session.plan.entry_id.clone(),
label: session.plan.label.clone(),
@ -384,7 +386,9 @@ impl CoreEngine {
session.mark_warning_issued(threshold);
// Log to audit
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::WarningIssued {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::WarningIssued {
session_id: session.plan.session_id.clone(),
threshold_seconds: threshold,
}));
@ -443,17 +447,22 @@ impl CoreEngine {
// Update usage accounting
let today = now.date_naive();
let _ = self.store.add_usage(&session.plan.entry_id, today, duration);
let _ = self
.store
.add_usage(&session.plan.entry_id, today, duration);
// Set cooldown if configured
if let Some(entry) = self.policy.get_entry(&session.plan.entry_id)
&& let Some(cooldown) = entry.limits.cooldown {
&& let Some(cooldown) = entry.limits.cooldown
{
let until = now + chrono::Duration::from_std(cooldown).unwrap();
let _ = self.store.set_cooldown_until(&session.plan.entry_id, until);
}
// Log to audit
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
session_id: session.plan.session_id.clone(),
entry_id: session.plan.entry_id.clone(),
reason: reason.clone(),
@ -492,17 +501,22 @@ impl CoreEngine {
// Update usage accounting
let today = now.date_naive();
let _ = self.store.add_usage(&session.plan.entry_id, today, duration);
let _ = self
.store
.add_usage(&session.plan.entry_id, today, duration);
// Set cooldown if configured
if let Some(entry) = self.policy.get_entry(&session.plan.entry_id)
&& let Some(cooldown) = entry.limits.cooldown {
&& let Some(cooldown) = entry.limits.cooldown
{
let until = now + chrono::Duration::from_std(cooldown).unwrap();
let _ = self.store.set_cooldown_until(&session.plan.entry_id, until);
}
// Log to audit
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::SessionEnded {
session_id: session.plan.session_id.clone(),
entry_id: session.plan.entry_id.clone(),
reason: reason.clone(),
@ -525,9 +539,10 @@ impl CoreEngine {
/// Get current service state snapshot
pub fn get_state(&self) -> ServiceStateSnapshot {
let current_session = self.current_session.as_ref().map(|s| {
s.to_session_info(MonotonicInstant::now())
});
let current_session = self
.current_session
.as_ref()
.map(|s| s.to_session_info(MonotonicInstant::now()));
// Build entry views for the snapshot
let entries = self.list_entries(shepherd_util::now());
@ -577,7 +592,9 @@ impl CoreEngine {
session.deadline = Some(new_deadline);
// Log to audit
let _ = self.store.append_audit(AuditEvent::new(AuditEventType::SessionExtended {
let _ = self
.store
.append_audit(AuditEvent::new(AuditEventType::SessionExtended {
session_id: session.plan.session_id.clone(),
extended_by: by,
new_deadline,
@ -597,8 +614,8 @@ impl CoreEngine {
#[cfg(test)]
mod tests {
use super::*;
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
use shepherd_api::EntryKind;
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
use shepherd_store::SqliteStore;
use std::collections::HashMap;
@ -736,19 +753,34 @@ mod tests {
// No warnings initially (first tick may emit AvailabilitySetChanged)
let events = engine.tick(now_mono, now);
// Filter to just warning events for this test
let warning_events: Vec<_> = events.iter().filter(|e| matches!(e, CoreEvent::Warning { .. })).collect();
let warning_events: Vec<_> = events
.iter()
.filter(|e| matches!(e, CoreEvent::Warning { .. }))
.collect();
assert!(warning_events.is_empty());
// At 70 seconds (10 seconds past warning threshold), warning should fire
let later = now_mono + Duration::from_secs(70);
let events = engine.tick(later, now);
let warning_events: Vec<_> = events.iter().filter(|e| matches!(e, CoreEvent::Warning { .. })).collect();
let warning_events: Vec<_> = events
.iter()
.filter(|e| matches!(e, CoreEvent::Warning { .. }))
.collect();
assert_eq!(warning_events.len(), 1);
assert!(matches!(warning_events[0], CoreEvent::Warning { threshold_seconds: 60, .. }));
assert!(matches!(
warning_events[0],
CoreEvent::Warning {
threshold_seconds: 60,
..
}
));
// Warning shouldn't fire twice
let events = engine.tick(later, now);
let warning_events: Vec<_> = events.iter().filter(|e| matches!(e, CoreEvent::Warning { .. })).collect();
let warning_events: Vec<_> = events
.iter()
.filter(|e| matches!(e, CoreEvent::Warning { .. }))
.collect();
assert!(warning_events.is_empty());
}
@ -803,7 +835,10 @@ mod tests {
let later = now_mono + Duration::from_secs(61);
let events = engine.tick(later, now);
// Filter to just expiry events for this test
let expiry_events: Vec<_> = events.iter().filter(|e| matches!(e, CoreEvent::ExpireDue { .. })).collect();
let expiry_events: Vec<_> = events
.iter()
.filter(|e| matches!(e, CoreEvent::ExpireDue { .. }))
.collect();
assert_eq!(expiry_events.len(), 1);
assert!(matches!(expiry_events[0], CoreEvent::ExpireDue { .. }));
}

View file

@ -30,9 +30,7 @@ pub enum CoreEvent {
},
/// Session is expiring (termination initiated)
ExpireDue {
session_id: SessionId,
},
ExpireDue { session_id: SessionId },
/// Session has ended
SessionEnded {
@ -43,13 +41,8 @@ pub enum CoreEvent {
},
/// Entry availability changed
EntryAvailabilityChanged {
entry_id: EntryId,
enabled: bool,
},
EntryAvailabilityChanged { entry_id: EntryId, enabled: bool },
/// Policy was reloaded
PolicyReloaded {
entry_count: usize,
},
PolicyReloaded { entry_count: usize },
}

View file

@ -29,8 +29,7 @@ impl SessionPlan {
.iter()
.filter(|w| Duration::from_secs(w.seconds_before) < max_duration)
.map(|w| {
let trigger_after =
max_duration - Duration::from_secs(w.seconds_before);
let trigger_after = max_duration - Duration::from_secs(w.seconds_before);
(w.seconds_before, trigger_after)
})
.collect()
@ -67,11 +66,7 @@ pub struct ActiveSession {
impl ActiveSession {
/// Create a new session from an approved plan
pub fn new(
plan: SessionPlan,
now: DateTime<Local>,
now_mono: MonotonicInstant,
) -> Self {
pub fn new(plan: SessionPlan, now: DateTime<Local>, now_mono: MonotonicInstant) -> Self {
let (deadline, deadline_mono) = match plan.max_duration {
Some(max_dur) => {
let deadline = now + chrono::Duration::from_std(max_dur).unwrap();
@ -101,7 +96,8 @@ impl ActiveSession {
/// Get time remaining using monotonic time. None means unlimited.
pub fn time_remaining(&self, now_mono: MonotonicInstant) -> Option<Duration> {
self.deadline_mono.map(|deadline| deadline.saturating_duration_until(now_mono))
self.deadline_mono
.map(|deadline| deadline.saturating_duration_until(now_mono))
}
/// Check if session is expired (never true for unlimited sessions)
@ -220,7 +216,10 @@ mod tests {
assert_eq!(session.state, SessionState::Launching);
assert!(session.warnings_issued.is_empty());
assert_eq!(session.time_remaining(now_mono), Some(Duration::from_secs(300)));
assert_eq!(
session.time_remaining(now_mono),
Some(Duration::from_secs(300))
);
}
#[test]

View file

@ -18,7 +18,10 @@ pub struct HostSessionHandle {
impl HostSessionHandle {
pub fn new(session_id: SessionId, payload: HostHandlePayload) -> Self {
Self { session_id, payload }
Self {
session_id,
payload,
}
}
pub fn payload(&self) -> &HostHandlePayload {
@ -31,27 +34,16 @@ impl HostSessionHandle {
#[serde(tag = "platform", rename_all = "snake_case")]
pub enum HostHandlePayload {
/// Linux: process group ID
Linux {
pid: u32,
pgid: u32,
},
Linux { pid: u32, pgid: u32 },
/// Windows: job object handle (serialized as name/id)
Windows {
job_name: String,
process_id: u32,
},
Windows { job_name: String, process_id: u32 },
/// macOS: bundle or process identifier
MacOs {
pid: u32,
bundle_id: Option<String>,
},
MacOs { pid: u32, bundle_id: Option<String> },
/// Mock for testing
Mock {
id: u64,
},
Mock { id: u64 },
}
impl HostHandlePayload {
@ -117,7 +109,10 @@ mod tests {
fn handle_serialization() {
let handle = HostSessionHandle::new(
SessionId::new(),
HostHandlePayload::Linux { pid: 1234, pgid: 1234 },
HostHandlePayload::Linux {
pid: 1234,
pgid: 1234,
},
);
let json = serde_json::to_string(&handle).unwrap();

View file

@ -10,8 +10,8 @@ use std::time::Duration;
use tokio::sync::mpsc;
use crate::{
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload,
HostResult, HostSessionHandle, SpawnOptions, StopMode,
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload, HostResult,
HostSessionHandle, SpawnOptions, StopMode,
};
/// Mock session state for testing
@ -79,7 +79,9 @@ impl MockHost {
if let Some(session) = sessions.values().find(|s| &s.session_id == session_id) {
let handle = HostSessionHandle::new(
session.session_id.clone(),
HostHandlePayload::Mock { id: session.mock_id },
HostHandlePayload::Mock {
id: session.mock_id,
},
);
let _ = self.event_tx.send(HostEvent::Exited { handle, status });
}
@ -122,12 +124,13 @@ impl HostAdapter for MockHost {
exit_delay: *self.auto_exit_delay.lock().unwrap(),
};
self.sessions.lock().unwrap().insert(mock_id, session.clone());
self.sessions
.lock()
.unwrap()
.insert(mock_id, session.clone());
let handle = HostSessionHandle::new(
session_id.clone(),
HostHandlePayload::Mock { id: mock_id },
);
let handle =
HostSessionHandle::new(session_id.clone(), HostHandlePayload::Mock { id: mock_id });
// If auto-exit is configured, spawn a task to send exit event
if let Some(delay) = session.exit_delay {

View file

@ -82,9 +82,7 @@ pub enum HostEvent {
},
/// Window is ready (for UI notification)
WindowReady {
handle: HostSessionHandle,
},
WindowReady { handle: HostSessionHandle },
/// Spawn failed after handle was created
SpawnFailed {
@ -141,6 +139,8 @@ mod tests {
#[test]
fn stop_mode_default() {
let mode = StopMode::default();
assert!(matches!(mode, StopMode::Graceful { timeout } if timeout == Duration::from_secs(5)));
assert!(
matches!(mode, StopMode::Graceful { timeout } if timeout == Duration::from_secs(5))
);
}
}

View file

@ -3,8 +3,8 @@
use async_trait::async_trait;
use shepherd_api::EntryKind;
use shepherd_host_api::{
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload,
HostResult, HostSessionHandle, SpawnOptions, StopMode,
ExitStatus, HostAdapter, HostCapabilities, HostError, HostEvent, HostHandlePayload, HostResult,
HostSessionHandle, SpawnOptions, StopMode,
};
use shepherd_util::SessionId;
use std::collections::{HashMap, HashSet};
@ -14,8 +14,8 @@ use tokio::sync::mpsc;
use tracing::{info, warn};
use crate::process::{
find_steam_game_pids, init, kill_by_command, kill_flatpak_cgroup, kill_snap_cgroup,
kill_steam_game_processes, ManagedProcess,
ManagedProcess, find_steam_game_pids, init, kill_by_command, kill_flatpak_cgroup,
kill_snap_cgroup, kill_steam_game_processes,
};
/// Expand `~` at the beginning of a path to the user's home directory
@ -93,14 +93,8 @@ impl LinuxHost {
tokio::time::sleep(Duration::from_millis(100)).await;
let mut exited = Vec::new();
let steam_pids: HashSet<u32> = {
steam_sessions
.lock()
.unwrap()
.keys()
.cloned()
.collect()
};
let steam_pids: HashSet<u32> =
{ steam_sessions.lock().unwrap().keys().cloned().collect() };
{
let mut procs = processes.lock().unwrap();
@ -140,14 +134,8 @@ impl LinuxHost {
}
// Track Steam sessions by Steam App ID instead of process exit
let steam_snapshot: Vec<SteamSession> = {
steam_sessions
.lock()
.unwrap()
.values()
.cloned()
.collect()
};
let steam_snapshot: Vec<SteamSession> =
{ steam_sessions.lock().unwrap().values().cloned().collect() };
let mut ended = Vec::new();
@ -206,22 +194,33 @@ impl HostAdapter for LinuxHost {
) -> HostResult<HostSessionHandle> {
// Extract argv, env, cwd, snap_name, flatpak_app_id, and steam_app_id based on entry kind
let (argv, env, cwd, snap_name, flatpak_app_id, steam_app_id) = match entry_kind {
EntryKind::Process { command, args, env, cwd } => {
EntryKind::Process {
command,
args,
env,
cwd,
} => {
let mut argv = vec![expand_tilde(command)];
argv.extend(expand_args(args));
let expanded_cwd = cwd.as_ref().map(|c| {
std::path::PathBuf::from(expand_tilde(&c.to_string_lossy()))
});
let expanded_cwd = cwd
.as_ref()
.map(|c| std::path::PathBuf::from(expand_tilde(&c.to_string_lossy())));
(argv, env.clone(), expanded_cwd, None, None, None)
}
EntryKind::Snap { snap_name, command, args, env } => {
EntryKind::Snap {
snap_name,
command,
args,
env,
} => {
// For snap apps, we need to use 'snap run <snap_name>' to launch them.
// The command (if specified) is passed as an argument after the snap name,
// followed by any additional args.
let mut argv = vec!["snap".to_string(), "run".to_string(), snap_name.clone()];
// If a custom command is specified (different from snap_name), add it
if let Some(cmd) = command
&& cmd != snap_name {
&& cmd != snap_name
{
argv.push(cmd.clone());
}
argv.extend(expand_args(args));
@ -257,13 +256,19 @@ impl HostAdapter for LinuxHost {
}
(argv, HashMap::new(), None, None, None, None)
}
EntryKind::Media { library_id, args: _ } => {
EntryKind::Media {
library_id,
args: _,
} => {
// For media, we'd typically launch a media player
// This is a placeholder - real implementation would integrate with a player
let argv = vec!["xdg-open".to_string(), expand_tilde(library_id)];
(argv, HashMap::new(), None, None, None, None)
}
EntryKind::Custom { type_name: _, payload: _ } => {
EntryKind::Custom {
type_name: _,
payload: _,
} => {
return Err(HostError::UnsupportedKind);
}
};
@ -301,13 +306,13 @@ impl HostAdapter for LinuxHost {
flatpak_app_id: flatpak_app_id.clone(),
steam_app_id,
};
self.session_info.lock().unwrap().insert(session_id.clone(), session_info_entry);
self.session_info
.lock()
.unwrap()
.insert(session_id.clone(), session_info_entry);
info!(session_id = %session_id, command = %command_name, snap = ?snap_name, flatpak = ?flatpak_app_id, "Tracking session info");
let handle = HostSessionHandle::new(
session_id,
HostHandlePayload::Linux { pid, pgid },
);
let handle = HostSessionHandle::new(session_id, HostHandlePayload::Linux { pid, pgid });
self.processes.lock().unwrap().insert(pid, proc);
@ -354,11 +359,15 @@ impl HostAdapter for LinuxHost {
kill_snap_cgroup(snap, nix::sys::signal::Signal::SIGTERM);
info!(snap = %snap, "Sent SIGTERM via snap cgroup");
} else if let Some(app_id) = info.steam_app_id {
let _ = kill_steam_game_processes(app_id, nix::sys::signal::Signal::SIGTERM);
let _ =
kill_steam_game_processes(app_id, nix::sys::signal::Signal::SIGTERM);
if let Ok(mut map) = self.steam_sessions.lock() {
map.entry(pid).and_modify(|entry| entry.seen_game = true);
}
info!(steam_app_id = app_id, "Sent SIGTERM to Steam game processes");
info!(
steam_app_id = app_id,
"Sent SIGTERM to Steam game processes"
);
} else if let Some(ref app_id) = info.flatpak_app_id {
kill_flatpak_cgroup(app_id, nix::sys::signal::Signal::SIGTERM);
info!(flatpak = %app_id, "Sent SIGTERM via flatpak cgroup");
@ -370,7 +379,10 @@ impl HostAdapter for LinuxHost {
}
// Also send SIGTERM via process handle (skip for Steam sessions)
let is_steam = session_info.as_ref().and_then(|info| info.steam_app_id).is_some();
let is_steam = session_info
.as_ref()
.and_then(|info| info.steam_app_id)
.is_some();
if !is_steam {
let procs = self.processes.lock().unwrap();
if let Some(p) = procs.get(&pid) {
@ -388,13 +400,22 @@ impl HostAdapter for LinuxHost {
kill_snap_cgroup(snap, nix::sys::signal::Signal::SIGKILL);
info!(snap = %snap, "Sent SIGKILL via snap cgroup (timeout)");
} else if let Some(app_id) = info.steam_app_id {
let _ = kill_steam_game_processes(app_id, nix::sys::signal::Signal::SIGKILL);
info!(steam_app_id = app_id, "Sent SIGKILL to Steam game processes (timeout)");
let _ = kill_steam_game_processes(
app_id,
nix::sys::signal::Signal::SIGKILL,
);
info!(
steam_app_id = app_id,
"Sent SIGKILL to Steam game processes (timeout)"
);
} else if let Some(ref app_id) = info.flatpak_app_id {
kill_flatpak_cgroup(app_id, nix::sys::signal::Signal::SIGKILL);
info!(flatpak = %app_id, "Sent SIGKILL via flatpak cgroup (timeout)");
} else {
kill_by_command(&info.command_name, nix::sys::signal::Signal::SIGKILL);
kill_by_command(
&info.command_name,
nix::sys::signal::Signal::SIGKILL,
);
info!(command = %info.command_name, "Sent SIGKILL via command name (timeout)");
}
}
@ -433,11 +454,15 @@ impl HostAdapter for LinuxHost {
kill_snap_cgroup(snap, nix::sys::signal::Signal::SIGKILL);
info!(snap = %snap, "Sent SIGKILL via snap cgroup");
} else if let Some(app_id) = info.steam_app_id {
let _ = kill_steam_game_processes(app_id, nix::sys::signal::Signal::SIGKILL);
let _ =
kill_steam_game_processes(app_id, nix::sys::signal::Signal::SIGKILL);
if let Ok(mut map) = self.steam_sessions.lock() {
map.entry(pid).and_modify(|entry| entry.seen_game = true);
}
info!(steam_app_id = app_id, "Sent SIGKILL to Steam game processes");
info!(
steam_app_id = app_id,
"Sent SIGKILL to Steam game processes"
);
} else if let Some(ref app_id) = info.flatpak_app_id {
kill_flatpak_cgroup(app_id, nix::sys::signal::Signal::SIGKILL);
info!(flatpak = %app_id, "Sent SIGKILL via flatpak cgroup");
@ -448,7 +473,10 @@ impl HostAdapter for LinuxHost {
}
// Also force kill via process handle (skip for Steam sessions)
let is_steam = session_info.as_ref().and_then(|info| info.steam_app_id).is_some();
let is_steam = session_info
.as_ref()
.and_then(|info| info.steam_app_id)
.is_some();
if !is_steam {
let procs = self.processes.lock().unwrap();
if let Some(p) = procs.get(&pid) {

View file

@ -83,7 +83,10 @@ pub fn kill_snap_cgroup(snap_name: &str, _signal: Signal) -> bool {
}
if stopped_any {
info!(snap = snap_name, "Killed snap scope(s) via systemctl SIGKILL");
info!(
snap = snap_name,
"Killed snap scope(s) via systemctl SIGKILL"
);
} else {
debug!(snap = snap_name, "No snap scope found to kill");
}
@ -147,7 +150,10 @@ pub fn kill_flatpak_cgroup(app_id: &str, _signal: Signal) -> bool {
}
if stopped_any {
info!(app_id = app_id, "Killed flatpak scope(s) via systemctl SIGKILL");
info!(
app_id = app_id,
"Killed flatpak scope(s) via systemctl SIGKILL"
);
} else {
debug!(app_id = app_id, "No flatpak scope found to kill");
}
@ -226,11 +232,18 @@ pub fn kill_by_command(command_name: &str, signal: Signal) -> bool {
Ok(output) => {
// pkill returns 0 if processes were found and signaled
if output.status.success() {
info!(command = command_name, signal = signal_name, "Killed processes by command name");
info!(
command = command_name,
signal = signal_name,
"Killed processes by command name"
);
true
} else {
// No processes found is not an error
debug!(command = command_name, "No processes found matching command name");
debug!(
command = command_name,
"No processes found matching command name"
);
false
}
}
@ -276,7 +289,8 @@ impl ManagedProcess {
// Build command: script -q -c "original command" logfile
// -q: quiet mode (no start/done messages)
// -c: command to run
let original_cmd = argv.iter()
let original_cmd = argv
.iter()
.map(|arg| shell_escape::escape(std::borrow::Cow::Borrowed(arg)))
.collect::<Vec<_>>()
.join(" ");
@ -465,23 +479,27 @@ impl ManagedProcess {
// SAFETY: This is safe in the pre-exec context
unsafe {
cmd.pre_exec(|| {
nix::unistd::setsid().map_err(|e| {
std::io::Error::other(e.to_string())
})?;
nix::unistd::setsid().map_err(|e| std::io::Error::other(e.to_string()))?;
Ok(())
});
}
let child = cmd.spawn().map_err(|e| {
HostError::SpawnFailed(format!("Failed to spawn {}: {}", program, e))
})?;
let child = cmd
.spawn()
.map_err(|e| HostError::SpawnFailed(format!("Failed to spawn {}: {}", program, e)))?;
let pid = child.id();
let pgid = pid; // After setsid, pid == pgid
info!(pid = pid, pgid = pgid, program = %program, snap = ?snap_name, "Process spawned");
Ok(Self { child, pid, pgid, command_name, snap_name })
Ok(Self {
child,
pid,
pgid,
command_name,
snap_name,
})
}
/// Get all descendant PIDs of this process using /proc
@ -508,7 +526,8 @@ impl ManagedProcess {
let fields: Vec<&str> = after_comm.split_whitespace().collect();
if fields.len() >= 2
&& let Ok(ppid) = fields[1].parse::<i32>()
&& ppid == parent_pid {
&& ppid == parent_pid
{
descendants.push(pid);
to_check.push(pid);
}

View file

@ -148,7 +148,8 @@ impl LinuxVolumeController {
// Output: "Volume: front-left: 65536 / 100% / -0.00 dB, front-right: ..."
if let Some(percent_str) = stdout.split('/').nth(1)
&& let Ok(percent) = percent_str.trim().trim_end_matches('%').parse::<u8>() {
&& let Ok(percent) = percent_str.trim().trim_end_matches('%').parse::<u8>()
{
status.percent = percent;
}
}
@ -185,7 +186,8 @@ impl LinuxVolumeController {
// Extract percentage: [100%]
if let Some(start) = line.find('[')
&& let Some(end) = line[start..].find('%')
&& let Ok(percent) = line[start + 1..start + end].parse::<u8>() {
&& let Ok(percent) = line[start + 1..start + end].parse::<u8>()
{
status.percent = percent;
}
// Check mute status: [on] or [off]
@ -210,7 +212,11 @@ impl LinuxVolumeController {
/// Set volume via PulseAudio
fn set_volume_pulseaudio(percent: u8) -> VolumeResult<()> {
Command::new("pactl")
.args(["set-sink-volume", "@DEFAULT_SINK@", &format!("{}%", percent)])
.args([
"set-sink-volume",
"@DEFAULT_SINK@",
&format!("{}%", percent),
])
.status()
.map_err(|e| VolumeError::Backend(e.to_string()))?;
Ok(())
@ -323,7 +329,10 @@ impl VolumeController for LinuxVolumeController {
async fn volume_up(&self, step: u8) -> VolumeResult<()> {
let current = self.get_status().await?;
let new_volume = current.percent.saturating_add(step).min(self.capabilities.max_volume);
let new_volume = current
.percent
.saturating_add(step)
.min(self.capabilities.max_volume);
self.set_volume(new_volume).await
}

View file

@ -414,9 +414,9 @@ fn build_hud_content(state: SharedState) -> gtk4::Box {
let remaining = time_remaining_at_warning.saturating_sub(elapsed);
time_display_clone.set_remaining(Some(remaining));
// Use configuration-defined message if present, otherwise show time-based message
let warning_text = message.clone().unwrap_or_else(|| {
format!("Only {} seconds remaining!", remaining)
});
let warning_text = message
.clone()
.unwrap_or_else(|| format!("Only {} seconds remaining!", remaining));
warning_label_clone.set_text(&warning_text);
// Apply severity-based CSS classes

View file

@ -35,14 +35,16 @@ impl BatteryStatus {
// Check for battery
if name_str.starts_with("BAT")
&& let Some((percent, charging)) = read_battery_info(&path) {
&& let Some((percent, charging)) = read_battery_info(&path)
{
status.percent = Some(percent);
status.charging = charging;
}
// Check for AC adapter
if (name_str.starts_with("AC") || name_str.contains("ADP"))
&& let Some(online) = read_ac_status(&path) {
&& let Some(online) = read_ac_status(&path)
{
status.ac_connected = online;
}
}

View file

@ -43,8 +43,7 @@ fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&args.log_level)),
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)),
)
.init();

View file

@ -218,7 +218,8 @@ impl SharedState {
entry_name,
..
} = state
&& sid == session_id {
&& sid == session_id
{
*state = SessionState::Warning {
session_id: session_id.clone(),
entry_id: entry_id.clone(),

View file

@ -60,9 +60,7 @@ pub fn toggle_mute() -> anyhow::Result<()> {
shepherd_api::ResponseResult::Ok(ResponsePayload::VolumeDenied { reason }) => {
Err(anyhow::anyhow!("Volume denied: {}", reason))
}
shepherd_api::ResponseResult::Err(e) => {
Err(anyhow::anyhow!("Error: {}", e.message))
}
shepherd_api::ResponseResult::Err(e) => Err(anyhow::anyhow!("Error: {}", e.message)),
_ => Err(anyhow::anyhow!("Unexpected response")),
}
})
@ -83,9 +81,7 @@ pub fn set_volume(percent: u8) -> anyhow::Result<()> {
shepherd_api::ResponseResult::Ok(ResponsePayload::VolumeDenied { reason }) => {
Err(anyhow::anyhow!("Volume denied: {}", reason))
}
shepherd_api::ResponseResult::Err(e) => {
Err(anyhow::anyhow!("Error: {}", e.message))
}
shepherd_api::ResponseResult::Err(e) => Err(anyhow::anyhow!("Error: {}", e.message)),
_ => Err(anyhow::anyhow!("Unexpected response")),
}
})

View file

@ -8,7 +8,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
use tokio::sync::{Mutex, RwLock, broadcast, mpsc};
use tracing::{debug, error, info, warn};
use crate::{IpcError, IpcResult};
@ -75,10 +75,9 @@ impl IpcServer {
let listener = UnixListener::bind(&self.socket_path)?;
// Set socket permissions (readable/writable by owner and group)
if let Err(err) = std::fs::set_permissions(
&self.socket_path,
std::fs::Permissions::from_mode(0o660),
) {
if let Err(err) =
std::fs::set_permissions(&self.socket_path, std::fs::Permissions::from_mode(0o660))
{
if err.kind() == std::io::ErrorKind::PermissionDenied {
warn!(
path = %self.socket_path.display(),
@ -190,7 +189,8 @@ impl IpcServer {
match serde_json::from_str::<Request>(line) {
Ok(request) => {
// Check for subscribe command
if matches!(request.command, shepherd_api::Command::SubscribeEvents) {
if matches!(request.command, shepherd_api::Command::SubscribeEvents)
{
let mut clients = clients.write().await;
if let Some(handle) = clients.get_mut(&client_id_clone) {
handle.subscribed = true;
@ -342,7 +342,8 @@ mod tests {
let mut server = IpcServer::new(&socket_path);
if let Err(err) = server.start().await {
if let IpcError::Io(ref io_err) = err
&& io_err.kind() == std::io::ErrorKind::PermissionDenied {
&& io_err.kind() == std::io::ErrorKind::PermissionDenied
{
eprintln!(
"Skipping IPC server start test due to permission error: {}",
io_err

View file

@ -162,11 +162,17 @@ impl ServiceClient {
}
ResponsePayload::Entries(entries) => {
// Only update if we're in idle state
if matches!(self.state.get(), LauncherState::Idle { .. } | LauncherState::Connecting) {
if matches!(
self.state.get(),
LauncherState::Idle { .. } | LauncherState::Connecting
) {
self.state.set(LauncherState::Idle { entries });
}
}
ResponsePayload::LaunchApproved { session_id, deadline } => {
ResponsePayload::LaunchApproved {
session_id,
deadline,
} => {
let now = shepherd_util::now();
// For unlimited sessions (deadline=None), time_remaining is None
let time_remaining = deadline.and_then(|d| {
@ -195,9 +201,7 @@ impl ServiceClient {
Ok(())
}
ResponseResult::Err(e) => {
self.state.set(LauncherState::Error {
message: e.message,
});
self.state.set(LauncherState::Error { message: e.message });
Ok(())
}
}
@ -218,17 +222,23 @@ impl CommandClient {
pub async fn launch(&self, entry_id: &EntryId) -> Result<Response> {
let mut client = IpcClient::connect(&self.socket_path).await?;
client.send(Command::Launch {
client
.send(Command::Launch {
entry_id: entry_id.clone(),
}).await.map_err(Into::into)
})
.await
.map_err(Into::into)
}
#[allow(dead_code)]
pub async fn stop_current(&self) -> Result<Response> {
let mut client = IpcClient::connect(&self.socket_path).await?;
client.send(Command::StopCurrent {
client
.send(Command::StopCurrent {
mode: shepherd_api::StopMode::Graceful,
}).await.map_err(Into::into)
})
.await
.map_err(Into::into)
}
pub async fn get_state(&self) -> Result<Response> {
@ -239,7 +249,10 @@ impl CommandClient {
#[allow(dead_code)]
pub async fn list_entries(&self) -> Result<Response> {
let mut client = IpcClient::connect(&self.socket_path).await?;
client.send(Command::ListEntries { at_time: None }).await.map_err(Into::into)
client
.send(Command::ListEntries { at_time: None })
.await
.map_err(Into::into)
}
}

View file

@ -51,7 +51,8 @@ mod imp {
// Configure flow box
self.flow_box.set_homogeneous(true);
self.flow_box.set_selection_mode(gtk4::SelectionMode::Single);
self.flow_box
.set_selection_mode(gtk4::SelectionMode::Single);
self.flow_box.set_max_children_per_line(6);
self.flow_box.set_min_children_per_line(2);
self.flow_box.set_row_spacing(24);
@ -118,7 +119,8 @@ impl LauncherGrid {
let on_launch = imp.on_launch.clone();
tile.connect_clicked(move |tile| {
if let Some(entry_id) = tile.entry_id()
&& let Some(callback) = on_launch.borrow().as_ref() {
&& let Some(callback) = on_launch.borrow().as_ref()
{
callback(entry_id);
}
});
@ -244,7 +246,6 @@ impl LauncherGrid {
}
}
}
}
impl Default for LauncherGrid {

View file

@ -1,6 +1,6 @@
//! Launcher application state management
use shepherd_api::{ServiceStateSnapshot, EntryView, Event, EventPayload};
use shepherd_api::{EntryView, Event, EventPayload, ServiceStateSnapshot};
use shepherd_util::SessionId;
use std::time::Duration;
use tokio::sync::watch;
@ -18,7 +18,7 @@ pub enum LauncherState {
/// Launch requested, waiting for response
Launching {
#[allow(dead_code)]
entry_id: String
entry_id: String,
},
/// Session is running
SessionActive {
@ -62,7 +62,10 @@ impl SharedState {
tracing::info!(event = ?event.payload, "Received event from shepherdd");
match event.payload {
EventPayload::StateChanged(snapshot) => {
tracing::info!(has_session = snapshot.current_session.is_some(), "Applying state snapshot");
tracing::info!(
has_session = snapshot.current_session.is_some(),
"Applying state snapshot"
);
self.apply_snapshot(snapshot);
}
EventPayload::SessionStarted {
@ -87,7 +90,12 @@ impl SharedState {
time_remaining,
});
}
EventPayload::SessionEnded { session_id, entry_id, reason, .. } => {
EventPayload::SessionEnded {
session_id,
entry_id,
reason,
..
} => {
tracing::info!(session_id = %session_id, entry_id = %entry_id, reason = ?reason, "Session ended event - setting Connecting");
// Will be followed by StateChanged, but set to connecting
// to ensure grid reloads

View file

@ -143,7 +143,11 @@ impl LauncherTile {
}
pub fn entry_id(&self) -> Option<shepherd_util::EntryId> {
self.imp().entry.borrow().as_ref().map(|e| e.entry_id.clone())
self.imp()
.entry
.borrow()
.as_ref()
.map(|e| e.entry_id.clone())
}
}

View file

@ -1,7 +1,7 @@
//! SQLite-based store implementation
use chrono::{DateTime, Local, NaiveDate};
use rusqlite::{params, Connection, OptionalExtension};
use rusqlite::{Connection, OptionalExtension, params};
use shepherd_util::EntryId;
use std::path::Path;
use std::sync::Mutex;
@ -98,9 +98,8 @@ impl Store for SqliteStore {
fn get_recent_audits(&self, limit: usize) -> StoreResult<Vec<AuditEvent>> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn.prepare(
"SELECT id, timestamp, event_json FROM audit_log ORDER BY id DESC LIMIT ?",
)?;
let mut stmt = conn
.prepare("SELECT id, timestamp, event_json FROM audit_log ORDER BY id DESC LIMIT ?")?;
let rows = stmt.query_map([limit], |row| {
let id: i64 = row.get(0)?;
@ -181,11 +180,7 @@ impl Store for SqliteStore {
Ok(result)
}
fn set_cooldown_until(
&self,
entry_id: &EntryId,
until: DateTime<Local>,
) -> StoreResult<()> {
fn set_cooldown_until(&self, entry_id: &EntryId, until: DateTime<Local>) -> StoreResult<()> {
let conn = self.conn.lock().unwrap();
conn.execute(
@ -204,7 +199,10 @@ impl Store for SqliteStore {
fn clear_cooldown(&self, entry_id: &EntryId) -> StoreResult<()> {
let conn = self.conn.lock().unwrap();
conn.execute("DELETE FROM cooldowns WHERE entry_id = ?", [entry_id.as_str()])?;
conn.execute(
"DELETE FROM cooldowns WHERE entry_id = ?",
[entry_id.as_str()],
)?;
Ok(())
}
@ -212,9 +210,11 @@ impl Store for SqliteStore {
let conn = self.conn.lock().unwrap();
let json: Option<String> = conn
.query_row("SELECT snapshot_json FROM snapshot WHERE id = 1", [], |row| {
row.get(0)
})
.query_row(
"SELECT snapshot_json FROM snapshot WHERE id = 1",
[],
|row| row.get(0),
)
.optional()?;
match json {
@ -246,9 +246,7 @@ impl Store for SqliteStore {
fn is_healthy(&self) -> bool {
match self.conn.lock() {
Ok(conn) => {
conn.query_row("SELECT 1", [], |_| Ok(())).is_ok()
}
Ok(conn) => conn.query_row("SELECT 1", [], |_| Ok(())).is_ok(),
Err(_) => {
warn!("Store lock poisoned");
false

View file

@ -30,11 +30,7 @@ pub trait Store: Send + Sync {
fn get_cooldown_until(&self, entry_id: &EntryId) -> StoreResult<Option<DateTime<Local>>>;
/// Set cooldown expiry time for an entry
fn set_cooldown_until(
&self,
entry_id: &EntryId,
until: DateTime<Local>,
) -> StoreResult<()>;
fn set_cooldown_until(&self, entry_id: &EntryId, until: DateTime<Local>) -> StoreResult<()>;
/// Clear cooldown for an entry
fn clear_cooldown(&self, entry_id: &EntryId) -> StoreResult<()>;

View file

@ -40,7 +40,9 @@ pub fn default_socket_path() -> PathBuf {
pub fn socket_path_without_env() -> PathBuf {
// Try XDG_RUNTIME_DIR first (typically /run/user/<uid>)
if let Ok(runtime_dir) = std::env::var("XDG_RUNTIME_DIR") {
return PathBuf::from(runtime_dir).join(APP_DIR).join(SOCKET_FILENAME);
return PathBuf::from(runtime_dir)
.join(APP_DIR)
.join(SOCKET_FILENAME);
}
// Fallback to /tmp with username
@ -109,7 +111,10 @@ pub fn default_log_dir() -> PathBuf {
/// Get the parent directory of the socket (for creating it)
pub fn socket_dir() -> PathBuf {
let socket_path = socket_path_without_env();
socket_path.parent().map(|p| p.to_path_buf()).unwrap_or_else(|| {
socket_path
.parent()
.map(|p| p.to_path_buf())
.unwrap_or_else(|| {
// Should never happen with our paths, but just in case
PathBuf::from("/tmp").join(APP_DIR)
})

View file

@ -42,7 +42,10 @@ impl RateLimiter {
pub fn check(&mut self, client_id: &ClientId) -> bool {
let now = Instant::now();
let bucket = self.clients.entry(client_id.clone()).or_insert(ClientBucket {
let bucket = self
.clients
.entry(client_id.clone())
.or_insert(ClientBucket {
tokens: self.max_tokens,
last_refill: now,
});
@ -72,9 +75,8 @@ impl RateLimiter {
/// Clean up stale client entries
pub fn cleanup(&mut self, stale_after: Duration) {
let now = Instant::now();
self.clients.retain(|_, bucket| {
now.duration_since(bucket.last_refill) < stale_after
});
self.clients
.retain(|_, bucket| now.duration_since(bucket.last_refill) < stale_after);
}
}

View file

@ -37,7 +37,9 @@ fn get_mock_time_offset() -> Option<chrono::Duration> {
{
if let Ok(mock_time_str) = std::env::var(MOCK_TIME_ENV_VAR) {
// Parse the mock time string
if let Ok(naive_dt) = NaiveDateTime::parse_from_str(&mock_time_str, "%Y-%m-%d %H:%M:%S") {
if let Ok(naive_dt) =
NaiveDateTime::parse_from_str(&mock_time_str, "%Y-%m-%d %H:%M:%S")
{
if let Some(mock_dt) = Local.from_local_datetime(&naive_dt).single() {
let real_now = chrono::Local::now();
let offset = mock_dt.signed_duration_since(real_now);
@ -201,9 +203,8 @@ impl DaysOfWeek {
pub const SATURDAY: u8 = 1 << 5;
pub const SUNDAY: u8 = 1 << 6;
pub const WEEKDAYS: DaysOfWeek = DaysOfWeek(
Self::MONDAY | Self::TUESDAY | Self::WEDNESDAY | Self::THURSDAY | Self::FRIDAY,
);
pub const WEEKDAYS: DaysOfWeek =
DaysOfWeek(Self::MONDAY | Self::TUESDAY | Self::WEDNESDAY | Self::THURSDAY | Self::FRIDAY);
pub const WEEKENDS: DaysOfWeek = DaysOfWeek(Self::SATURDAY | Self::SUNDAY);
pub const ALL_DAYS: DaysOfWeek = DaysOfWeek(0x7F);
pub const NONE: DaysOfWeek = DaysOfWeek(0);
@ -534,15 +535,24 @@ mod tests {
// Time within window
let in_window = Local.with_ymd_and_hms(2025, 12, 25, 15, 0, 0).unwrap();
assert!(window.contains(&in_window), "15:00 should be within 14:00-18:00 window");
assert!(
window.contains(&in_window),
"15:00 should be within 14:00-18:00 window"
);
// Time before window
let before_window = Local.with_ymd_and_hms(2025, 12, 25, 10, 0, 0).unwrap();
assert!(!window.contains(&before_window), "10:00 should be before 14:00-18:00 window");
assert!(
!window.contains(&before_window),
"10:00 should be before 14:00-18:00 window"
);
// Time after window
let after_window = Local.with_ymd_and_hms(2025, 12, 25, 20, 0, 0).unwrap();
assert!(!window.contains(&after_window), "20:00 should be after 14:00-18:00 window");
assert!(
!window.contains(&after_window),
"20:00 should be after 14:00-18:00 window"
);
}
#[test]
@ -556,15 +566,24 @@ mod tests {
// Thursday at 3 PM - should be available (weekday, in time window)
let thursday = Local.with_ymd_and_hms(2025, 12, 25, 15, 0, 0).unwrap(); // Christmas 2025 is Thursday
assert!(window.contains(&thursday), "Thursday 15:00 should be in weekday afternoon window");
assert!(
window.contains(&thursday),
"Thursday 15:00 should be in weekday afternoon window"
);
// Saturday at 3 PM - should NOT be available (weekend)
let saturday = Local.with_ymd_and_hms(2025, 12, 27, 15, 0, 0).unwrap();
assert!(!window.contains(&saturday), "Saturday should not be in weekday window");
assert!(
!window.contains(&saturday),
"Saturday should not be in weekday window"
);
// Sunday at 3 PM - should NOT be available (weekend)
let sunday = Local.with_ymd_and_hms(2025, 12, 28, 15, 0, 0).unwrap();
assert!(!window.contains(&sunday), "Sunday should not be in weekday window");
assert!(
!window.contains(&sunday),
"Sunday should not be in weekday window"
);
}
}

View file

@ -26,7 +26,8 @@ impl InternetMonitor {
for entry in &policy.entries {
if entry.internet.required
&& let Some(check) = entry.internet.check.clone()
&& !targets.contains(&check) {
&& !targets.contains(&check)
{
targets.push(check);
}
}

View file

@ -12,20 +12,20 @@
use anyhow::{Context, Result};
use clap::Parser;
use shepherd_api::{
Command, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo, VolumeRestrictions,
Command, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus, Response, ResponsePayload,
SessionEndReason, StopMode, VolumeInfo, VolumeRestrictions,
};
use shepherd_config::{load_config, VolumePolicy};
use shepherd_config::{VolumePolicy, load_config};
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode, VolumeController};
use shepherd_host_linux::{LinuxHost, LinuxVolumeController};
use shepherd_ipc::{IpcServer, ServerMessage};
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
use shepherd_util::{default_config_path, ClientId, MonotonicInstant, RateLimiter};
use shepherd_util::{ClientId, MonotonicInstant, RateLimiter, default_config_path};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::signal::unix::{signal, SignalKind};
use tokio::signal::unix::{SignalKind, signal};
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter;
@ -180,12 +180,11 @@ impl Service {
});
// Set up signal handlers
let mut sigterm = signal(SignalKind::terminate())
.context("Failed to create SIGTERM handler")?;
let mut sigint = signal(SignalKind::interrupt())
.context("Failed to create SIGINT handler")?;
let mut sighup = signal(SignalKind::hangup())
.context("Failed to create SIGHUP handler")?;
let mut sigterm =
signal(SignalKind::terminate()).context("Failed to create SIGTERM handler")?;
let mut sigint =
signal(SignalKind::interrupt()).context("Failed to create SIGINT handler")?;
let mut sighup = signal(SignalKind::hangup()).context("Failed to create SIGHUP handler")?;
// Main event loop
let tick_interval = Duration::from_millis(100);
@ -246,9 +245,16 @@ impl Service {
let engine = engine.lock().await;
if let Some(session) = engine.current_session() {
info!(session_id = %session.plan.session_id, "Stopping active session");
if let Some(handle) = &session.host_handle && let Err(e) = host.stop(handle, HostStopMode::Graceful {
if let Some(handle) = &session.host_handle
&& let Err(e) = host
.stop(
handle,
HostStopMode::Graceful {
timeout: Duration::from_secs(5),
}).await {
},
)
.await
{
warn!(error = %e, "Failed to stop session gracefully");
}
}
@ -301,9 +307,7 @@ impl Service {
// Get the host handle and stop it
let handle = {
let engine = engine.lock().await;
engine
.current_session()
.and_then(|s| s.host_handle.clone())
engine.current_session().and_then(|s| s.host_handle.clone())
};
if let Some(handle) = handle
@ -405,7 +409,10 @@ impl Service {
engine.notify_session_exited(status.code, now_mono, now)
};
info!(has_event = core_event.is_some(), "notify_session_exited result");
info!(
has_event = core_event.is_some(),
"notify_session_exited result"
);
if let Some(CoreEvent::SessionEnded {
session_id,
@ -472,8 +479,16 @@ impl Service {
}
}
let response =
Self::handle_command(engine, host, volume, ipc, store, &client_id, request.request_id, request.command)
let response = Self::handle_command(
engine,
host,
volume,
ipc,
store,
&client_id,
request.request_id,
request.command,
)
.await;
let _ = ipc.send_response(&client_id, response).await;
@ -487,23 +502,19 @@ impl Service {
"Client connected"
);
let _ = store.append_audit(AuditEvent::new(
AuditEventType::ClientConnected {
let _ = store.append_audit(AuditEvent::new(AuditEventType::ClientConnected {
client_id: client_id.to_string(),
role: format!("{:?}", info.role),
uid: info.uid,
},
));
}));
}
ServerMessage::ClientDisconnected { client_id } => {
debug!(client_id = %client_id, "Client disconnected");
let _ = store.append_audit(AuditEvent::new(
AuditEventType::ClientDisconnected {
let _ = store.append_audit(AuditEvent::new(AuditEventType::ClientDisconnected {
client_id: client_id.to_string(),
},
));
}));
// Clean up rate limiter
let mut limiter = rate_limiter.lock().await;
@ -547,10 +558,7 @@ impl Service {
let event = eng.start_session(plan.clone(), now, now_mono);
// Get the entry kind for spawning
let entry_kind = eng
.policy()
.get_entry(&entry_id)
.map(|e| e.kind.clone());
let entry_kind = eng.policy().get_entry(&entry_id).map(|e| e.kind.clone());
// Build spawn options with log path if capture_child_output is enabled
let spawn_options = if eng.policy().service.capture_child_output {
@ -577,11 +585,7 @@ impl Service {
if let Some(kind) = entry_kind {
match host
.spawn(
plan.session_id.clone(),
&kind,
spawn_options,
)
.spawn(plan.session_id.clone(), &kind, spawn_options)
.await
{
Ok(handle) => {
@ -597,12 +601,14 @@ impl Service {
deadline,
} = event
{
ipc.broadcast_event(Event::new(EventPayload::SessionStarted {
ipc.broadcast_event(Event::new(
EventPayload::SessionStarted {
session_id: session_id.clone(),
entry_id,
label,
deadline,
}));
},
));
Response::success(
request_id,
@ -614,7 +620,10 @@ impl Service {
} else {
Response::error(
request_id,
ErrorInfo::new(ErrorCode::InternalError, "Unexpected event"),
ErrorInfo::new(
ErrorCode::InternalError,
"Unexpected event",
),
)
}
}
@ -628,16 +637,20 @@ impl Service {
duration,
}) = eng.notify_session_exited(Some(-1), now_mono, now)
{
ipc.broadcast_event(Event::new(EventPayload::SessionEnded {
ipc.broadcast_event(Event::new(
EventPayload::SessionEnded {
session_id,
entry_id,
reason,
duration,
}));
},
));
// Broadcast state change so clients return to idle
let state = eng.get_state();
ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
ipc.broadcast_event(Event::new(
EventPayload::StateChanged(state),
));
}
Response::error(
@ -666,9 +679,7 @@ impl Service {
let mut eng = engine.lock().await;
// Get handle before stopping in engine
let handle = eng
.current_session()
.and_then(|s| s.host_handle.clone());
let handle = eng.current_session().and_then(|s| s.host_handle.clone());
let reason = match mode {
StopMode::Graceful => SessionEndReason::UserStop,
@ -719,7 +730,8 @@ impl Service {
Command::ReloadConfig => {
// Check permission
if let Some(info) = ipc.get_client_info(client_id).await
&& !info.role.can_reload_config() {
&& !info.role.can_reload_config()
{
return Response::error(
request_id,
ErrorInfo::new(ErrorCode::PermissionDenied, "Admin role required"),
@ -733,14 +745,12 @@ impl Service {
)
}
Command::SubscribeEvents => {
Response::success(
Command::SubscribeEvents => Response::success(
request_id,
ResponsePayload::Subscribed {
client_id: client_id.clone(),
},
)
}
),
Command::UnsubscribeEvents => {
Response::success(request_id, ResponsePayload::Unsubscribed)
@ -761,7 +771,8 @@ impl Service {
Command::ExtendCurrent { by } => {
// Check permission
if let Some(info) = ipc.get_client_info(client_id).await
&& !info.role.can_extend() {
&& !info.role.can_extend()
{
return Response::error(
request_id,
ErrorInfo::new(ErrorCode::PermissionDenied, "Admin role required"),
@ -770,12 +781,18 @@ impl Service {
let mut eng = engine.lock().await;
match eng.extend_current(by, now_mono, now) {
Some(new_deadline) => {
Response::success(request_id, ResponsePayload::Extended { new_deadline: Some(new_deadline) })
}
Some(new_deadline) => Response::success(
request_id,
ResponsePayload::Extended {
new_deadline: Some(new_deadline),
},
),
None => Response::error(
request_id,
ErrorInfo::new(ErrorCode::NoActiveSession, "No active session or session is unlimited"),
ErrorInfo::new(
ErrorCode::NoActiveSession,
"No active session or session is unlimited",
),
),
}
}
@ -917,7 +934,8 @@ impl Service {
// Check if there's an active session with volume restrictions
if let Some(session) = eng.current_session()
&& let Some(entry) = eng.policy().get_entry(&session.plan.entry_id)
&& let Some(ref vol_policy) = entry.volume {
&& let Some(ref vol_policy) = entry.volume
{
return Self::convert_volume_policy(vol_policy);
}
@ -940,18 +958,15 @@ async fn main() -> Result<()> {
let args = Args::parse();
// Initialize logging
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&args.log_level));
let filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level));
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(true)
.init();
info!(
version = env!("CARGO_PKG_VERSION"),
"shepherdd starting"
);
info!(version = env!("CARGO_PKG_VERSION"), "shepherdd starting");
// Create and run the service
let service = Service::new(&args).await?;

View file

@ -15,8 +15,7 @@ use std::time::Duration;
fn make_test_policy() -> Policy {
Policy {
service: Default::default(),
entries: vec![
Entry {
entries: vec![Entry {
id: EntryId::new("test-game"),
label: "Test Game".into(),
icon_ref: None,
@ -51,8 +50,7 @@ fn make_test_policy() -> Policy {
disabled: false,
disabled_reason: None,
internet: Default::default(),
},
],
}],
default_warnings: vec![],
default_max_run: Some(Duration::from_secs(3600)),
volume: Default::default(),
@ -91,7 +89,9 @@ fn test_launch_approval() {
let entry_id = EntryId::new("test-game");
let decision = engine.request_launch(&entry_id, shepherd_util::now());
assert!(matches!(decision, LaunchDecision::Approved(plan) if plan.max_duration == Some(Duration::from_secs(10))));
assert!(
matches!(decision, LaunchDecision::Approved(plan) if plan.max_duration == Some(Duration::from_secs(10)))
);
}
#[test]
@ -150,14 +150,26 @@ fn test_warning_emission() {
let at_6s = now + chrono::Duration::seconds(6);
let events = engine.tick(at_6s_mono, at_6s);
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], CoreEvent::Warning { threshold_seconds: 5, .. }));
assert!(matches!(
&events[0],
CoreEvent::Warning {
threshold_seconds: 5,
..
}
));
// At 9 seconds (1 second remaining), 2-second warning should fire
let at_9s_mono = now_mono + Duration::from_secs(9);
let at_9s = now + chrono::Duration::seconds(9);
let events = engine.tick(at_9s_mono, at_9s);
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], CoreEvent::Warning { threshold_seconds: 2, .. }));
assert!(matches!(
&events[0],
CoreEvent::Warning {
threshold_seconds: 2,
..
}
));
// Warnings shouldn't repeat
let events = engine.tick(at_9s_mono, at_9s);
@ -188,7 +200,9 @@ fn test_session_expiry() {
let events = engine.tick(at_11s_mono, at_11s);
// Should have both remaining warnings + expiry
let has_expiry = events.iter().any(|e| matches!(e, CoreEvent::ExpireDue { .. }));
let has_expiry = events
.iter()
.any(|e| matches!(e, CoreEvent::ExpireDue { .. }));
assert!(has_expiry, "Expected ExpireDue event");
}
@ -291,9 +305,18 @@ fn test_config_parsing() {
let policy = parse_config(config).unwrap();
assert_eq!(policy.entries.len(), 1);
assert_eq!(policy.entries[0].id.as_str(), "scummvm");
assert_eq!(policy.entries[0].limits.max_run, Some(Duration::from_secs(3600)));
assert_eq!(policy.entries[0].limits.daily_quota, Some(Duration::from_secs(7200)));
assert_eq!(policy.entries[0].limits.cooldown, Some(Duration::from_secs(300)));
assert_eq!(
policy.entries[0].limits.max_run,
Some(Duration::from_secs(3600))
);
assert_eq!(
policy.entries[0].limits.daily_quota,
Some(Duration::from_secs(7200))
);
assert_eq!(
policy.entries[0].limits.cooldown,
Some(Duration::from_secs(300))
);
assert_eq!(policy.entries[0].warnings.len(), 1);
}
@ -316,7 +339,11 @@ fn test_session_extension() {
engine.start_session(plan, now, now_mono);
// Get original deadline (should be Some for this test)
let original_deadline = engine.current_session().unwrap().deadline.expect("Expected deadline");
let original_deadline = engine
.current_session()
.unwrap()
.deadline
.expect("Expected deadline");
// Extend by 5 minutes
let new_deadline = engine.extend_current(Duration::from_secs(300), now_mono, now);