WIP: network check
This commit is contained in:
parent
dc58817aea
commit
266685628e
18 changed files with 1710 additions and 33 deletions
886
Cargo.lock
generated
886
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -58,8 +58,14 @@ anyhow = "1.0"
|
||||||
uuid = { version = "1.6", features = ["v4", "serde"] }
|
uuid = { version = "1.6", features = ["v4", "serde"] }
|
||||||
bitflags = "2.4"
|
bitflags = "2.4"
|
||||||
|
|
||||||
|
# HTTP client (for connectivity checks)
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
|
||||||
|
|
||||||
# Unix-specific
|
# Unix-specific
|
||||||
nix = { version = "0.29", features = ["signal", "process", "user", "socket"] }
|
nix = { version = "0.29", features = ["signal", "process", "user", "socket"] }
|
||||||
|
netlink-sys = "0.8"
|
||||||
|
netlink-packet-core = "0.7"
|
||||||
|
netlink-packet-route = "0.21"
|
||||||
|
|
||||||
# CLI
|
# CLI
|
||||||
clap = { version = "4.5", features = ["derive", "env"] }
|
clap = { version = "4.5", features = ["derive", "env"] }
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,16 @@ max_volume = 80 # Maximum volume percentage (0-100)
|
||||||
allow_mute = true # Whether mute toggle is allowed
|
allow_mute = true # Whether mute toggle is allowed
|
||||||
allow_change = true # Whether volume changes are allowed at all
|
allow_change = true # Whether volume changes are allowed at all
|
||||||
|
|
||||||
|
# Network connectivity settings (optional)
|
||||||
|
# Used to check if Internet is available before allowing network-dependent entries
|
||||||
|
[service.network]
|
||||||
|
# URL to check for global network connectivity (default: Google's connectivity check)
|
||||||
|
# check_url = "http://connectivitycheck.gstatic.com/generate_204"
|
||||||
|
# How often to perform periodic connectivity checks, in seconds (default: 60)
|
||||||
|
# check_interval_seconds = 60
|
||||||
|
# Timeout for connectivity checks, in seconds (default: 5)
|
||||||
|
# check_timeout_seconds = 5
|
||||||
|
|
||||||
# Default warning thresholds
|
# Default warning thresholds
|
||||||
[[service.default_warnings]]
|
[[service.default_warnings]]
|
||||||
seconds_before = 300
|
seconds_before = 300
|
||||||
|
|
@ -211,6 +221,11 @@ message = "30 seconds! Save NOW!"
|
||||||
[entries.volume]
|
[entries.volume]
|
||||||
max_volume = 60 # Limit volume during gaming sessions
|
max_volume = 60 # Limit volume during gaming sessions
|
||||||
|
|
||||||
|
# Network requirements for online games
|
||||||
|
[entries.network]
|
||||||
|
required = true # Minecraft needs network for authentication and multiplayer
|
||||||
|
check_url = "http://www.msftconnecttest.com/connecttest.txt" # Use Microsoft's check (Minecraft is owned by Microsoft)
|
||||||
|
|
||||||
## === Steam games ===
|
## === Steam games ===
|
||||||
# Steam can be used via Canonical's Steam snap package:
|
# Steam can be used via Canonical's Steam snap package:
|
||||||
# https://snapcraft.io/steam
|
# https://snapcraft.io/steam
|
||||||
|
|
@ -244,6 +259,9 @@ end = "20:00"
|
||||||
# No [entries.limits] section - uses service defaults
|
# No [entries.limits] section - uses service defaults
|
||||||
# Omitting limits entirely uses default_max_run_seconds
|
# Omitting limits entirely uses default_max_run_seconds
|
||||||
|
|
||||||
|
[entries.network]
|
||||||
|
required = true # Steam needs network for authentication
|
||||||
|
|
||||||
# A Short Hike via Steam
|
# A Short Hike via Steam
|
||||||
# https://store.steampowered.com/app/1055540/A_Short_Hike/
|
# https://store.steampowered.com/app/1055540/A_Short_Hike/
|
||||||
[[entries]]
|
[[entries]]
|
||||||
|
|
@ -267,6 +285,9 @@ days = "weekends"
|
||||||
start = "10:00"
|
start = "10:00"
|
||||||
end = "20:00"
|
end = "20:00"
|
||||||
|
|
||||||
|
[entries.network]
|
||||||
|
required = true # Steam needs network for authentication
|
||||||
|
|
||||||
## === Media ===
|
## === Media ===
|
||||||
# Just use `mpv` to play media (for now).
|
# Just use `mpv` to play media (for now).
|
||||||
# Files can be local on your system or URLs (YouTube, etc).
|
# Files can be local on your system or URLs (YouTube, etc).
|
||||||
|
|
@ -314,6 +335,9 @@ max_run_seconds = 0 # Unlimited: sleep/study aid
|
||||||
daily_quota_seconds = 0 # Unlimited
|
daily_quota_seconds = 0 # Unlimited
|
||||||
cooldown_seconds = 0 # No cooldown
|
cooldown_seconds = 0 # No cooldown
|
||||||
|
|
||||||
|
[entries.network]
|
||||||
|
required = true # YouTube streaming needs network
|
||||||
|
|
||||||
# Terminal for debugging only
|
# Terminal for debugging only
|
||||||
[[entries]]
|
[[entries]]
|
||||||
id = "terminal"
|
id = "terminal"
|
||||||
|
|
|
||||||
|
|
@ -234,6 +234,7 @@ mod tests {
|
||||||
current_session: None,
|
current_session: None,
|
||||||
entry_count: 5,
|
entry_count: 5,
|
||||||
entries: vec![],
|
entries: vec![],
|
||||||
|
connectivity: Default::default(),
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,14 @@ pub enum EventPayload {
|
||||||
event_type: String,
|
event_type: String,
|
||||||
details: serde_json::Value,
|
details: serde_json::Value,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Network connectivity status changed
|
||||||
|
ConnectivityChanged {
|
||||||
|
/// Whether global connectivity check now passes
|
||||||
|
connected: bool,
|
||||||
|
/// The URL that was checked
|
||||||
|
check_url: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,11 @@ pub enum ReasonCode {
|
||||||
Disabled {
|
Disabled {
|
||||||
reason: Option<String>,
|
reason: Option<String>,
|
||||||
},
|
},
|
||||||
|
/// Network connectivity check failed
|
||||||
|
NetworkUnavailable {
|
||||||
|
/// The URL that was checked
|
||||||
|
check_url: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Warning severity level
|
/// Warning severity level
|
||||||
|
|
@ -197,6 +202,20 @@ pub struct ServiceStateSnapshot {
|
||||||
/// Available entries for UI display
|
/// Available entries for UI display
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub entries: Vec<EntryView>,
|
pub entries: Vec<EntryView>,
|
||||||
|
/// Network connectivity status
|
||||||
|
#[serde(default)]
|
||||||
|
pub connectivity: ConnectivityStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Network connectivity status
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
pub struct ConnectivityStatus {
|
||||||
|
/// Whether global network connectivity check passed
|
||||||
|
pub connected: bool,
|
||||||
|
/// The URL that was checked for global connectivity
|
||||||
|
pub check_url: Option<String>,
|
||||||
|
/// When the last check was performed
|
||||||
|
pub last_check: Option<DateTime<Local>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Role for authorization
|
/// Role for authorization
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
//! Validated policy structures
|
//! Validated policy structures
|
||||||
|
|
||||||
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
|
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawNetworkConfig, RawEntryNetwork, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
|
||||||
use crate::validation::{parse_days, parse_time};
|
use crate::validation::{parse_days, parse_time};
|
||||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||||
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir, socket_path_without_env};
|
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir, socket_path_without_env};
|
||||||
|
|
@ -24,6 +24,9 @@ pub struct Policy {
|
||||||
|
|
||||||
/// Global volume restrictions
|
/// Global volume restrictions
|
||||||
pub volume: VolumePolicy,
|
pub volume: VolumePolicy,
|
||||||
|
|
||||||
|
/// Network connectivity policy
|
||||||
|
pub network: NetworkPolicy,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Policy {
|
impl Policy {
|
||||||
|
|
@ -50,6 +53,13 @@ impl Policy {
|
||||||
.map(convert_volume_config)
|
.map(convert_volume_config)
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let network = raw
|
||||||
|
.service
|
||||||
|
.network
|
||||||
|
.as_ref()
|
||||||
|
.map(convert_network_config)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
let entries = raw
|
let entries = raw
|
||||||
.entries
|
.entries
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
@ -62,6 +72,7 @@ impl Policy {
|
||||||
default_warnings,
|
default_warnings,
|
||||||
default_max_run,
|
default_max_run,
|
||||||
volume: global_volume,
|
volume: global_volume,
|
||||||
|
network,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -130,6 +141,7 @@ pub struct Entry {
|
||||||
pub limits: LimitsPolicy,
|
pub limits: LimitsPolicy,
|
||||||
pub warnings: Vec<WarningThreshold>,
|
pub warnings: Vec<WarningThreshold>,
|
||||||
pub volume: Option<VolumePolicy>,
|
pub volume: Option<VolumePolicy>,
|
||||||
|
pub network: NetworkRequirement,
|
||||||
pub disabled: bool,
|
pub disabled: bool,
|
||||||
pub disabled_reason: Option<String>,
|
pub disabled_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
@ -159,6 +171,11 @@ impl Entry {
|
||||||
.map(|w| w.into_iter().map(convert_warning).collect())
|
.map(|w| w.into_iter().map(convert_warning).collect())
|
||||||
.unwrap_or_else(|| default_warnings.to_vec());
|
.unwrap_or_else(|| default_warnings.to_vec());
|
||||||
let volume = raw.volume.as_ref().map(convert_volume_config);
|
let volume = raw.volume.as_ref().map(convert_volume_config);
|
||||||
|
let network = raw
|
||||||
|
.network
|
||||||
|
.as_ref()
|
||||||
|
.map(convert_entry_network)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
id: EntryId::new(raw.id),
|
id: EntryId::new(raw.id),
|
||||||
|
|
@ -169,6 +186,7 @@ impl Entry {
|
||||||
limits,
|
limits,
|
||||||
warnings,
|
warnings,
|
||||||
volume,
|
volume,
|
||||||
|
network,
|
||||||
disabled: raw.disabled,
|
disabled: raw.disabled,
|
||||||
disabled_reason: raw.disabled_reason,
|
disabled_reason: raw.disabled_reason,
|
||||||
}
|
}
|
||||||
|
|
@ -250,6 +268,52 @@ impl VolumePolicy {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Default connectivity check URL (Google's connectivity check service)
|
||||||
|
pub const DEFAULT_CHECK_URL: &str = "http://connectivitycheck.gstatic.com/generate_204";
|
||||||
|
|
||||||
|
/// Default interval for periodic connectivity checks (60 seconds)
|
||||||
|
pub const DEFAULT_CHECK_INTERVAL_SECS: u64 = 60;
|
||||||
|
|
||||||
|
/// Default timeout for connectivity checks (5 seconds)
|
||||||
|
pub const DEFAULT_CHECK_TIMEOUT_SECS: u64 = 5;
|
||||||
|
|
||||||
|
/// Network connectivity policy
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct NetworkPolicy {
|
||||||
|
/// URL to check for global network connectivity
|
||||||
|
pub check_url: String,
|
||||||
|
/// How often to perform periodic connectivity checks
|
||||||
|
pub check_interval: Duration,
|
||||||
|
/// Timeout for connectivity checks
|
||||||
|
pub check_timeout: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for NetworkPolicy {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
check_url: DEFAULT_CHECK_URL.to_string(),
|
||||||
|
check_interval: Duration::from_secs(DEFAULT_CHECK_INTERVAL_SECS),
|
||||||
|
check_timeout: Duration::from_secs(DEFAULT_CHECK_TIMEOUT_SECS),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Network requirements for a specific entry
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub struct NetworkRequirement {
|
||||||
|
/// Whether this entry requires network connectivity to launch
|
||||||
|
pub required: bool,
|
||||||
|
/// Override check URL for this entry (uses global if None)
|
||||||
|
pub check_url_override: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NetworkRequirement {
|
||||||
|
/// Get the check URL to use for this entry, given the global policy
|
||||||
|
pub fn effective_check_url<'a>(&'a self, global: &'a NetworkPolicy) -> &'a str {
|
||||||
|
self.check_url_override.as_deref().unwrap_or(&global.check_url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Conversion helpers
|
// Conversion helpers
|
||||||
|
|
||||||
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
|
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
|
||||||
|
|
@ -282,6 +346,21 @@ fn convert_volume_config(raw: &RawVolumeConfig) -> VolumePolicy {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_network_config(raw: &RawNetworkConfig) -> NetworkPolicy {
|
||||||
|
NetworkPolicy {
|
||||||
|
check_url: raw.check_url.clone().unwrap_or_else(|| DEFAULT_CHECK_URL.to_string()),
|
||||||
|
check_interval: Duration::from_secs(raw.check_interval_seconds.unwrap_or(DEFAULT_CHECK_INTERVAL_SECS)),
|
||||||
|
check_timeout: Duration::from_secs(raw.check_timeout_seconds.unwrap_or(DEFAULT_CHECK_TIMEOUT_SECS)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_entry_network(raw: &RawEntryNetwork) -> NetworkRequirement {
|
||||||
|
NetworkRequirement {
|
||||||
|
required: raw.required,
|
||||||
|
check_url_override: raw.check_url.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_time_window(raw: crate::schema::RawTimeWindow) -> TimeWindow {
|
fn convert_time_window(raw: crate::schema::RawTimeWindow) -> TimeWindow {
|
||||||
let days_mask = parse_days(&raw.days).unwrap_or(0x7F);
|
let days_mask = parse_days(&raw.days).unwrap_or(0x7F);
|
||||||
let (start_h, start_m) = parse_time(&raw.start).unwrap_or((0, 0));
|
let (start_h, start_m) = parse_time(&raw.start).unwrap_or((0, 0));
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,26 @@ pub struct RawServiceConfig {
|
||||||
/// Global volume restrictions
|
/// Global volume restrictions
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub volume: Option<RawVolumeConfig>,
|
pub volume: Option<RawVolumeConfig>,
|
||||||
|
|
||||||
|
/// Network connectivity settings
|
||||||
|
#[serde(default)]
|
||||||
|
pub network: Option<RawNetworkConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Network connectivity configuration
|
||||||
|
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||||
|
pub struct RawNetworkConfig {
|
||||||
|
/// URL to check for global network connectivity
|
||||||
|
/// Default: "http://connectivitycheck.gstatic.com/generate_204"
|
||||||
|
pub check_url: Option<String>,
|
||||||
|
|
||||||
|
/// How often to perform periodic connectivity checks (in seconds)
|
||||||
|
/// Default: 30
|
||||||
|
pub check_interval_seconds: Option<u64>,
|
||||||
|
|
||||||
|
/// Timeout for connectivity checks (in seconds)
|
||||||
|
/// Default: 5
|
||||||
|
pub check_timeout_seconds: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Raw entry definition
|
/// Raw entry definition
|
||||||
|
|
@ -81,6 +101,10 @@ pub struct RawEntry {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub volume: Option<RawVolumeConfig>,
|
pub volume: Option<RawVolumeConfig>,
|
||||||
|
|
||||||
|
/// Network requirements for this entry
|
||||||
|
#[serde(default)]
|
||||||
|
pub network: Option<RawEntryNetwork>,
|
||||||
|
|
||||||
/// Explicitly disabled
|
/// Explicitly disabled
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub disabled: bool,
|
pub disabled: bool,
|
||||||
|
|
@ -89,6 +113,20 @@ pub struct RawEntry {
|
||||||
pub disabled_reason: Option<String>,
|
pub disabled_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Network requirements for an entry
|
||||||
|
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||||
|
pub struct RawEntryNetwork {
|
||||||
|
/// Whether this entry requires network connectivity to launch
|
||||||
|
/// If true, the entry will not be available if the network check fails
|
||||||
|
#[serde(default)]
|
||||||
|
pub required: bool,
|
||||||
|
|
||||||
|
/// Override check URL for this entry
|
||||||
|
/// If specified, this URL will be checked instead of the global check_url
|
||||||
|
/// This is useful for entries that need specific services (e.g., Google, Microsoft)
|
||||||
|
pub check_url: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Raw entry kind
|
/// Raw entry kind
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
|
|
||||||
|
|
@ -260,6 +260,7 @@ mod tests {
|
||||||
limits: None,
|
limits: None,
|
||||||
warnings: None,
|
warnings: None,
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: None,
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
},
|
},
|
||||||
|
|
@ -277,6 +278,7 @@ mod tests {
|
||||||
limits: None,
|
limits: None,
|
||||||
warnings: None,
|
warnings: None,
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: None,
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -506,6 +506,8 @@ impl CoreEngine {
|
||||||
current_session,
|
current_session,
|
||||||
entry_count: self.policy.entries.len(),
|
entry_count: self.policy.entries.len(),
|
||||||
entries,
|
entries,
|
||||||
|
// Connectivity is populated by the daemon, not the core engine
|
||||||
|
connectivity: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -565,7 +567,7 @@ impl CoreEngine {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
|
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement};
|
||||||
use shepherd_api::EntryKind;
|
use shepherd_api::EntryKind;
|
||||||
use shepherd_store::SqliteStore;
|
use shepherd_store::SqliteStore;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -594,12 +596,14 @@ mod tests {
|
||||||
},
|
},
|
||||||
warnings: vec![],
|
warnings: vec![],
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: NetworkRequirement::default(),
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
}],
|
}],
|
||||||
default_warnings: vec![],
|
default_warnings: vec![],
|
||||||
default_max_run: Some(Duration::from_secs(3600)),
|
default_max_run: Some(Duration::from_secs(3600)),
|
||||||
volume: Default::default(),
|
volume: Default::default(),
|
||||||
|
network: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -677,6 +681,7 @@ mod tests {
|
||||||
message_template: Some("1 minute left".into()),
|
message_template: Some("1 minute left".into()),
|
||||||
}],
|
}],
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: NetworkRequirement::default(),
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
}],
|
}],
|
||||||
|
|
@ -684,6 +689,7 @@ mod tests {
|
||||||
default_warnings: vec![],
|
default_warnings: vec![],
|
||||||
default_max_run: Some(Duration::from_secs(3600)),
|
default_max_run: Some(Duration::from_secs(3600)),
|
||||||
volume: Default::default(),
|
volume: Default::default(),
|
||||||
|
network: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||||
|
|
@ -742,6 +748,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
warnings: vec![],
|
warnings: vec![],
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: NetworkRequirement::default(),
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
}],
|
}],
|
||||||
|
|
@ -749,6 +756,7 @@ mod tests {
|
||||||
default_warnings: vec![],
|
default_warnings: vec![],
|
||||||
default_max_run: Some(Duration::from_secs(3600)),
|
default_max_run: Some(Duration::from_secs(3600)),
|
||||||
volume: Default::default(),
|
volume: Default::default(),
|
||||||
|
network: Default::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,11 @@ nix = { workspace = true }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
dirs = "5.0"
|
dirs = "5.0"
|
||||||
shell-escape = "0.1"
|
shell-escape = "0.1"
|
||||||
|
chrono = { workspace = true }
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
|
||||||
|
netlink-sys = "0.8"
|
||||||
|
netlink-packet-core = "0.7"
|
||||||
|
netlink-packet-route = "0.21"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
|
|
||||||
497
crates/shepherd-host-linux/src/connectivity.rs
Normal file
497
crates/shepherd-host-linux/src/connectivity.rs
Normal file
|
|
@ -0,0 +1,497 @@
|
||||||
|
//! Network connectivity monitoring for Linux
|
||||||
|
//!
|
||||||
|
//! This module provides:
|
||||||
|
//! - Periodic connectivity checks to a configurable URL
|
||||||
|
//! - Network interface change detection via netlink
|
||||||
|
//! - Per-entry connectivity status tracking
|
||||||
|
|
||||||
|
#![allow(dead_code)] // Methods on ConnectivityMonitor may be used for future admin commands
|
||||||
|
|
||||||
|
use chrono::{DateTime, Local};
|
||||||
|
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
|
||||||
|
use netlink_packet_route::RouteNetlinkMessage;
|
||||||
|
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
|
||||||
|
use reqwest::Client;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::os::fd::AsRawFd;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::{mpsc, watch, RwLock};
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
/// Events emitted by the connectivity monitor
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ConnectivityEvent {
|
||||||
|
/// Global connectivity status changed
|
||||||
|
StatusChanged {
|
||||||
|
connected: bool,
|
||||||
|
check_url: String,
|
||||||
|
},
|
||||||
|
/// Network interface changed (may trigger recheck)
|
||||||
|
InterfaceChanged,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for the connectivity monitor
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ConnectivityConfig {
|
||||||
|
/// URL to check for global network connectivity
|
||||||
|
pub check_url: String,
|
||||||
|
/// How often to perform periodic connectivity checks
|
||||||
|
pub check_interval: Duration,
|
||||||
|
/// Timeout for connectivity checks
|
||||||
|
pub check_timeout: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cached connectivity check result
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct CheckResult {
|
||||||
|
connected: bool,
|
||||||
|
checked_at: DateTime<Local>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connectivity monitor that tracks network availability
|
||||||
|
pub struct ConnectivityMonitor {
|
||||||
|
/// HTTP client for connectivity checks
|
||||||
|
client: Client,
|
||||||
|
/// Configuration
|
||||||
|
config: ConnectivityConfig,
|
||||||
|
/// Current global connectivity status
|
||||||
|
global_status: Arc<RwLock<Option<CheckResult>>>,
|
||||||
|
/// Cached results for specific URLs (entry-specific checks)
|
||||||
|
url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||||
|
/// Channel for sending events
|
||||||
|
event_tx: mpsc::Sender<ConnectivityEvent>,
|
||||||
|
/// Shutdown signal
|
||||||
|
shutdown_rx: watch::Receiver<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectivityMonitor {
|
||||||
|
/// Create a new connectivity monitor
|
||||||
|
pub fn new(
|
||||||
|
config: ConnectivityConfig,
|
||||||
|
shutdown_rx: watch::Receiver<bool>,
|
||||||
|
) -> (Self, mpsc::Receiver<ConnectivityEvent>) {
|
||||||
|
let (event_tx, event_rx) = mpsc::channel(32);
|
||||||
|
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(config.check_timeout)
|
||||||
|
.connect_timeout(config.check_timeout)
|
||||||
|
.build()
|
||||||
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
|
let monitor = Self {
|
||||||
|
client,
|
||||||
|
config,
|
||||||
|
global_status: Arc::new(RwLock::new(None)),
|
||||||
|
url_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
event_tx,
|
||||||
|
shutdown_rx,
|
||||||
|
};
|
||||||
|
|
||||||
|
(monitor, event_rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the connectivity monitor (runs until shutdown)
|
||||||
|
pub async fn run(self) {
|
||||||
|
let check_interval = self.config.check_interval;
|
||||||
|
let check_url = self.config.check_url.clone();
|
||||||
|
|
||||||
|
// Spawn periodic check task
|
||||||
|
let periodic_handle = {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let global_status = self.global_status.clone();
|
||||||
|
let event_tx = self.event_tx.clone();
|
||||||
|
let check_url = check_url.clone();
|
||||||
|
let check_timeout = self.config.check_timeout;
|
||||||
|
let mut shutdown = self.shutdown_rx.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(check_interval);
|
||||||
|
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||||
|
|
||||||
|
// Do initial check immediately
|
||||||
|
let connected = check_url_reachable(&client, &check_url, check_timeout).await;
|
||||||
|
update_global_status(&global_status, &event_tx, &check_url, connected).await;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = interval.tick() => {
|
||||||
|
let connected = check_url_reachable(&client, &check_url, check_timeout).await;
|
||||||
|
update_global_status(&global_status, &event_tx, &check_url, connected).await;
|
||||||
|
}
|
||||||
|
_ = shutdown.changed() => {
|
||||||
|
if *shutdown.borrow() {
|
||||||
|
debug!("Periodic check task shutting down");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
// Spawn netlink monitor task
|
||||||
|
let netlink_handle = {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let global_status = self.global_status.clone();
|
||||||
|
let url_cache = self.url_cache.clone();
|
||||||
|
let event_tx = self.event_tx.clone();
|
||||||
|
let check_url = check_url.clone();
|
||||||
|
let check_timeout = self.config.check_timeout;
|
||||||
|
let mut shutdown = self.shutdown_rx.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = run_netlink_monitor(
|
||||||
|
&client,
|
||||||
|
&global_status,
|
||||||
|
&url_cache,
|
||||||
|
&event_tx,
|
||||||
|
&check_url,
|
||||||
|
check_timeout,
|
||||||
|
&mut shutdown,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!(error = %e, "Netlink monitor failed, network change detection unavailable");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wait for shutdown
|
||||||
|
let mut shutdown = self.shutdown_rx.clone();
|
||||||
|
let _ = shutdown.changed().await;
|
||||||
|
|
||||||
|
// Cancel tasks
|
||||||
|
periodic_handle.abort();
|
||||||
|
netlink_handle.abort();
|
||||||
|
|
||||||
|
info!("Connectivity monitor stopped");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current global connectivity status
|
||||||
|
pub async fn is_connected(&self) -> bool {
|
||||||
|
self.global_status
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|r| r.connected)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the last check time
|
||||||
|
pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
|
||||||
|
self.global_status.read().await.as_ref().map(|r| r.checked_at)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a specific URL is reachable (with caching)
|
||||||
|
/// Used for entry-specific network requirements
|
||||||
|
pub async fn check_url(&self, url: &str) -> bool {
|
||||||
|
// Check cache first
|
||||||
|
{
|
||||||
|
let cache = self.url_cache.read().await;
|
||||||
|
if let Some(result) = cache.get(url) {
|
||||||
|
// Cache valid for half the check interval
|
||||||
|
let cache_ttl = self.config.check_interval / 2;
|
||||||
|
let age = shepherd_util::now()
|
||||||
|
.signed_duration_since(result.checked_at)
|
||||||
|
.to_std()
|
||||||
|
.unwrap_or(Duration::MAX);
|
||||||
|
if age < cache_ttl {
|
||||||
|
return result.connected;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform check
|
||||||
|
let connected = check_url_reachable(&self.client, url, self.config.check_timeout).await;
|
||||||
|
|
||||||
|
// Update cache
|
||||||
|
{
|
||||||
|
let mut cache = self.url_cache.write().await;
|
||||||
|
cache.insert(
|
||||||
|
url.to_string(),
|
||||||
|
CheckResult {
|
||||||
|
connected,
|
||||||
|
checked_at: shepherd_util::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
connected
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Force an immediate connectivity recheck
|
||||||
|
pub async fn trigger_recheck(&self) {
|
||||||
|
let connected =
|
||||||
|
check_url_reachable(&self.client, &self.config.check_url, self.config.check_timeout)
|
||||||
|
.await;
|
||||||
|
update_global_status(
|
||||||
|
&self.global_status,
|
||||||
|
&self.event_tx,
|
||||||
|
&self.config.check_url,
|
||||||
|
connected,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Clear URL cache to force rechecks
|
||||||
|
self.url_cache.write().await.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a URL is reachable
|
||||||
|
async fn check_url_reachable(client: &Client, url: &str, timeout: Duration) -> bool {
|
||||||
|
debug!(url = %url, "Checking connectivity");
|
||||||
|
|
||||||
|
match client
|
||||||
|
.get(url)
|
||||||
|
.timeout(timeout)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
|
let status = response.status();
|
||||||
|
let connected = status.is_success() || status.as_u16() == 204;
|
||||||
|
debug!(url = %url, status = %status, connected = connected, "Connectivity check complete");
|
||||||
|
connected
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!(url = %url, error = %e, "Connectivity check failed");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update global status and emit event if changed
|
||||||
|
async fn update_global_status(
|
||||||
|
global_status: &Arc<RwLock<Option<CheckResult>>>,
|
||||||
|
event_tx: &mpsc::Sender<ConnectivityEvent>,
|
||||||
|
check_url: &str,
|
||||||
|
connected: bool,
|
||||||
|
) {
|
||||||
|
let mut status = global_status.write().await;
|
||||||
|
let previous = status.as_ref().map(|r| r.connected);
|
||||||
|
|
||||||
|
*status = Some(CheckResult {
|
||||||
|
connected,
|
||||||
|
checked_at: shepherd_util::now(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Emit event if status changed
|
||||||
|
if previous != Some(connected) {
|
||||||
|
info!(
|
||||||
|
connected = connected,
|
||||||
|
url = %check_url,
|
||||||
|
"Global connectivity status changed"
|
||||||
|
);
|
||||||
|
let _ = event_tx
|
||||||
|
.send(ConnectivityEvent::StatusChanged {
|
||||||
|
connected,
|
||||||
|
check_url: check_url.to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the netlink monitor to detect network interface changes
|
||||||
|
async fn run_netlink_monitor(
|
||||||
|
client: &Client,
|
||||||
|
global_status: &Arc<RwLock<Option<CheckResult>>>,
|
||||||
|
url_cache: &Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||||
|
event_tx: &mpsc::Sender<ConnectivityEvent>,
|
||||||
|
check_url: &str,
|
||||||
|
check_timeout: Duration,
|
||||||
|
shutdown: &mut watch::Receiver<bool>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
// Create netlink socket for route notifications
|
||||||
|
let mut socket = Socket::new(NETLINK_ROUTE)?;
|
||||||
|
|
||||||
|
// Bind to multicast groups for link and address changes
|
||||||
|
// RTMGRP_LINK = 1, RTMGRP_IPV4_IFADDR = 0x10, RTMGRP_IPV6_IFADDR = 0x100
|
||||||
|
let groups = 1 | 0x10 | 0x100;
|
||||||
|
let addr = SocketAddr::new(0, groups);
|
||||||
|
socket.bind(&addr)?;
|
||||||
|
|
||||||
|
// Set non-blocking for async compatibility
|
||||||
|
socket.set_non_blocking(true)?;
|
||||||
|
|
||||||
|
info!("Netlink monitor started");
|
||||||
|
|
||||||
|
let fd = socket.as_raw_fd();
|
||||||
|
let mut buf = vec![0u8; 4096];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Use tokio's async fd for the socket
|
||||||
|
let async_fd = tokio::io::unix::AsyncFd::new(fd)?;
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
result = async_fd.readable() => {
|
||||||
|
match result {
|
||||||
|
Ok(mut guard) => {
|
||||||
|
// Try to read from socket
|
||||||
|
match socket.recv(&mut buf, 0) {
|
||||||
|
Ok(len) if len > 0 => {
|
||||||
|
// Parse netlink messages
|
||||||
|
if has_relevant_netlink_event(&buf[..len]) {
|
||||||
|
debug!("Network interface change detected");
|
||||||
|
let _ = event_tx.send(ConnectivityEvent::InterfaceChanged).await;
|
||||||
|
|
||||||
|
// Clear URL cache
|
||||||
|
url_cache.write().await.clear();
|
||||||
|
|
||||||
|
// Recheck connectivity after a short delay
|
||||||
|
// (give network time to stabilize)
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
|
||||||
|
let connected = check_url_reachable(client, check_url, check_timeout).await;
|
||||||
|
update_global_status(global_status, event_tx, check_url, connected).await;
|
||||||
|
}
|
||||||
|
guard.clear_ready();
|
||||||
|
}
|
||||||
|
Ok(_) => {
|
||||||
|
guard.clear_ready();
|
||||||
|
}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
guard.clear_ready();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "Netlink recv error");
|
||||||
|
guard.clear_ready();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "Async fd error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = shutdown.changed() => {
|
||||||
|
if *shutdown.borrow() {
|
||||||
|
debug!("Netlink monitor shutting down");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a netlink message buffer contains relevant network events
|
||||||
|
fn has_relevant_netlink_event(buf: &[u8]) -> bool {
|
||||||
|
let mut offset = 0;
|
||||||
|
|
||||||
|
while offset < buf.len() {
|
||||||
|
match NetlinkMessage::<RouteNetlinkMessage>::deserialize(&buf[offset..]) {
|
||||||
|
Ok(msg) => {
|
||||||
|
if let NetlinkPayload::InnerMessage(route_msg) = &msg.payload
|
||||||
|
&& matches!(
|
||||||
|
route_msg,
|
||||||
|
// Link up/down events
|
||||||
|
RouteNetlinkMessage::NewLink(_)
|
||||||
|
| RouteNetlinkMessage::DelLink(_)
|
||||||
|
// Address added/removed
|
||||||
|
| RouteNetlinkMessage::NewAddress(_)
|
||||||
|
| RouteNetlinkMessage::DelAddress(_)
|
||||||
|
// Route changes
|
||||||
|
| RouteNetlinkMessage::NewRoute(_)
|
||||||
|
| RouteNetlinkMessage::DelRoute(_)
|
||||||
|
)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move to next message
|
||||||
|
let len = msg.header.length as usize;
|
||||||
|
if len == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
offset += len;
|
||||||
|
}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle for accessing connectivity status from other parts of the service
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ConnectivityHandle {
|
||||||
|
client: Client,
|
||||||
|
global_status: Arc<RwLock<Option<CheckResult>>>,
|
||||||
|
url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||||
|
check_timeout: Duration,
|
||||||
|
cache_ttl: Duration,
|
||||||
|
global_check_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnectivityHandle {
|
||||||
|
/// Create a handle from the monitor
|
||||||
|
pub fn from_monitor(monitor: &ConnectivityMonitor) -> Self {
|
||||||
|
Self {
|
||||||
|
client: monitor.client.clone(),
|
||||||
|
global_status: monitor.global_status.clone(),
|
||||||
|
url_cache: monitor.url_cache.clone(),
|
||||||
|
check_timeout: monitor.config.check_timeout,
|
||||||
|
cache_ttl: monitor.config.check_interval / 2,
|
||||||
|
global_check_url: monitor.config.check_url.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current global connectivity status
|
||||||
|
pub async fn is_connected(&self) -> bool {
|
||||||
|
self.global_status
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|r| r.connected)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the last check time
|
||||||
|
pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
|
||||||
|
self.global_status.read().await.as_ref().map(|r| r.checked_at)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the global check URL
|
||||||
|
pub fn global_check_url(&self) -> &str {
|
||||||
|
&self.global_check_url
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a specific URL is reachable (with caching)
|
||||||
|
pub async fn check_url(&self, url: &str) -> bool {
|
||||||
|
// If it's the global URL, use global status
|
||||||
|
if url == self.global_check_url {
|
||||||
|
return self.is_connected().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check cache first
|
||||||
|
{
|
||||||
|
let cache = self.url_cache.read().await;
|
||||||
|
if let Some(result) = cache.get(url) {
|
||||||
|
let age = shepherd_util::now()
|
||||||
|
.signed_duration_since(result.checked_at)
|
||||||
|
.to_std()
|
||||||
|
.unwrap_or(Duration::MAX);
|
||||||
|
if age < self.cache_ttl {
|
||||||
|
return result.connected;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform check
|
||||||
|
let connected = check_url_reachable(&self.client, url, self.check_timeout).await;
|
||||||
|
|
||||||
|
// Update cache
|
||||||
|
{
|
||||||
|
let mut cache = self.url_cache.write().await;
|
||||||
|
cache.insert(
|
||||||
|
url.to_string(),
|
||||||
|
CheckResult {
|
||||||
|
connected,
|
||||||
|
checked_at: shepherd_util::now(),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
connected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -6,11 +6,14 @@
|
||||||
//! - Exit observation
|
//! - Exit observation
|
||||||
//! - stdout/stderr capture
|
//! - stdout/stderr capture
|
||||||
//! - Volume control with auto-detection of sound systems
|
//! - Volume control with auto-detection of sound systems
|
||||||
|
//! - Network connectivity monitoring via netlink
|
||||||
|
|
||||||
mod adapter;
|
mod adapter;
|
||||||
|
mod connectivity;
|
||||||
mod process;
|
mod process;
|
||||||
mod volume;
|
mod volume;
|
||||||
|
|
||||||
pub use adapter::*;
|
pub use adapter::*;
|
||||||
|
pub use connectivity::*;
|
||||||
pub use process::*;
|
pub use process::*;
|
||||||
pub use volume::*;
|
pub use volume::*;
|
||||||
|
|
|
||||||
|
|
@ -252,5 +252,6 @@ fn reason_to_message(reason: &ReasonCode) -> &'static str {
|
||||||
ReasonCode::SessionActive { .. } => "Another session is active",
|
ReasonCode::SessionActive { .. } => "Another session is active",
|
||||||
ReasonCode::UnsupportedKind { .. } => "Entry type not supported",
|
ReasonCode::UnsupportedKind { .. } => "Entry type not supported",
|
||||||
ReasonCode::Disabled { .. } => "Entry disabled",
|
ReasonCode::Disabled { .. } => "Entry disabled",
|
||||||
|
ReasonCode::NetworkUnavailable { .. } => "Network connection required",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,10 @@ impl SharedState {
|
||||||
EventPayload::VolumeChanged { .. } => {
|
EventPayload::VolumeChanged { .. } => {
|
||||||
// Volume events are handled by HUD
|
// Volume events are handled by HUD
|
||||||
}
|
}
|
||||||
|
EventPayload::ConnectivityChanged { .. } => {
|
||||||
|
// Connectivity changes may affect entry availability - request fresh state
|
||||||
|
self.set(LauncherState::Connecting);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ tracing-subscriber = { workspace = true }
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
clap = { version = "4.5", features = ["derive", "env"] }
|
clap = { version = "4.5", features = ["derive", "env"] }
|
||||||
|
nix = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -8,17 +8,22 @@
|
||||||
//! - Host adapter (Linux)
|
//! - Host adapter (Linux)
|
||||||
//! - IPC server
|
//! - IPC server
|
||||||
//! - Volume control
|
//! - Volume control
|
||||||
|
//! - Network connectivity monitoring
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use shepherd_api::{
|
use shepherd_api::{
|
||||||
Command, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
|
Command, ConnectivityStatus, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
|
||||||
Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo, VolumeRestrictions,
|
ReasonCode, Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo,
|
||||||
|
VolumeRestrictions,
|
||||||
};
|
};
|
||||||
use shepherd_config::{load_config, VolumePolicy};
|
use shepherd_config::{load_config, VolumePolicy};
|
||||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
|
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
|
||||||
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode, VolumeController};
|
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode, VolumeController};
|
||||||
use shepherd_host_linux::{LinuxHost, LinuxVolumeController};
|
use shepherd_host_linux::{
|
||||||
|
ConnectivityConfig, ConnectivityEvent, ConnectivityHandle, ConnectivityMonitor, LinuxHost,
|
||||||
|
LinuxVolumeController,
|
||||||
|
};
|
||||||
use shepherd_ipc::{IpcServer, ServerMessage};
|
use shepherd_ipc::{IpcServer, ServerMessage};
|
||||||
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
|
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
|
||||||
use shepherd_util::{default_config_path, ClientId, MonotonicInstant, RateLimiter};
|
use shepherd_util::{default_config_path, ClientId, MonotonicInstant, RateLimiter};
|
||||||
|
|
@ -26,7 +31,7 @@ use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::signal::unix::{signal, SignalKind};
|
use tokio::signal::unix::{signal, SignalKind};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{mpsc, watch, Mutex};
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
|
@ -60,10 +65,12 @@ struct Service {
|
||||||
ipc: Arc<IpcServer>,
|
ipc: Arc<IpcServer>,
|
||||||
store: Arc<dyn Store>,
|
store: Arc<dyn Store>,
|
||||||
rate_limiter: RateLimiter,
|
rate_limiter: RateLimiter,
|
||||||
|
connectivity: ConnectivityHandle,
|
||||||
|
shutdown_tx: watch::Sender<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
async fn new(args: &Args) -> Result<Self> {
|
async fn new(args: &Args) -> Result<(Self, mpsc::Receiver<ConnectivityEvent>)> {
|
||||||
// Load configuration
|
// Load configuration
|
||||||
let policy = load_config(&args.config)
|
let policy = load_config(&args.config)
|
||||||
.with_context(|| format!("Failed to load config from {:?}", args.config))?;
|
.with_context(|| format!("Failed to load config from {:?}", args.config))?;
|
||||||
|
|
@ -116,6 +123,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize core engine
|
// Initialize core engine
|
||||||
|
let network_policy = policy.network.clone();
|
||||||
let engine = CoreEngine::new(policy, store.clone(), host.capabilities().clone());
|
let engine = CoreEngine::new(policy, store.clone(), host.capabilities().clone());
|
||||||
|
|
||||||
// Initialize IPC server
|
// Initialize IPC server
|
||||||
|
|
@ -127,17 +135,43 @@ impl Service {
|
||||||
// Rate limiter: 30 requests per second per client
|
// Rate limiter: 30 requests per second per client
|
||||||
let rate_limiter = RateLimiter::new(30, Duration::from_secs(1));
|
let rate_limiter = RateLimiter::new(30, Duration::from_secs(1));
|
||||||
|
|
||||||
Ok(Self {
|
// Initialize connectivity monitor
|
||||||
engine,
|
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||||
host,
|
let connectivity_config = ConnectivityConfig {
|
||||||
volume,
|
check_url: network_policy.check_url,
|
||||||
ipc: Arc::new(ipc),
|
check_interval: network_policy.check_interval,
|
||||||
store,
|
check_timeout: network_policy.check_timeout,
|
||||||
rate_limiter,
|
};
|
||||||
})
|
let (connectivity_monitor, connectivity_events) =
|
||||||
|
ConnectivityMonitor::new(connectivity_config, shutdown_rx);
|
||||||
|
let connectivity = ConnectivityHandle::from_monitor(&connectivity_monitor);
|
||||||
|
|
||||||
|
// Spawn connectivity monitor task
|
||||||
|
tokio::spawn(async move {
|
||||||
|
connectivity_monitor.run().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
info!(
|
||||||
|
check_url = %connectivity.global_check_url(),
|
||||||
|
"Connectivity monitor started"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
Self {
|
||||||
|
engine,
|
||||||
|
host,
|
||||||
|
volume,
|
||||||
|
ipc: Arc::new(ipc),
|
||||||
|
store,
|
||||||
|
rate_limiter,
|
||||||
|
connectivity,
|
||||||
|
shutdown_tx,
|
||||||
|
},
|
||||||
|
connectivity_events,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(self) -> Result<()> {
|
async fn run(self, mut connectivity_events: mpsc::Receiver<ConnectivityEvent>) -> Result<()> {
|
||||||
// Start host process monitor
|
// Start host process monitor
|
||||||
let _monitor_handle = self.host.start_monitor();
|
let _monitor_handle = self.host.start_monitor();
|
||||||
|
|
||||||
|
|
@ -155,6 +189,8 @@ impl Service {
|
||||||
let host = self.host.clone();
|
let host = self.host.clone();
|
||||||
let volume = self.volume.clone();
|
let volume = self.volume.clone();
|
||||||
let store = self.store.clone();
|
let store = self.store.clone();
|
||||||
|
let connectivity = self.connectivity.clone();
|
||||||
|
let shutdown_tx = self.shutdown_tx.clone();
|
||||||
|
|
||||||
// Spawn IPC accept task
|
// Spawn IPC accept task
|
||||||
let ipc_accept = ipc_ref.clone();
|
let ipc_accept = ipc_ref.clone();
|
||||||
|
|
@ -218,7 +254,12 @@ impl Service {
|
||||||
|
|
||||||
// IPC messages
|
// IPC messages
|
||||||
Some(msg) = ipc_messages.recv() => {
|
Some(msg) = ipc_messages.recv() => {
|
||||||
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, msg).await;
|
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, &connectivity, msg).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connectivity events
|
||||||
|
Some(conn_event) = connectivity_events.recv() => {
|
||||||
|
Self::handle_connectivity_event(&engine, &ipc_ref, &connectivity, conn_event).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -226,6 +267,9 @@ impl Service {
|
||||||
// Graceful shutdown
|
// Graceful shutdown
|
||||||
info!("Shutting down shepherdd");
|
info!("Shutting down shepherdd");
|
||||||
|
|
||||||
|
// Signal connectivity monitor to stop
|
||||||
|
let _ = shutdown_tx.send(true);
|
||||||
|
|
||||||
// Stop all running sessions
|
// Stop all running sessions
|
||||||
{
|
{
|
||||||
let engine = engine.lock().await;
|
let engine = engine.lock().await;
|
||||||
|
|
@ -433,6 +477,45 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_connectivity_event(
|
||||||
|
engine: &Arc<Mutex<CoreEngine>>,
|
||||||
|
ipc: &Arc<IpcServer>,
|
||||||
|
connectivity: &ConnectivityHandle,
|
||||||
|
event: ConnectivityEvent,
|
||||||
|
) {
|
||||||
|
match event {
|
||||||
|
ConnectivityEvent::StatusChanged {
|
||||||
|
connected,
|
||||||
|
check_url,
|
||||||
|
} => {
|
||||||
|
info!(connected = connected, url = %check_url, "Connectivity status changed");
|
||||||
|
|
||||||
|
// Broadcast connectivity change event
|
||||||
|
ipc.broadcast_event(Event::new(EventPayload::ConnectivityChanged {
|
||||||
|
connected,
|
||||||
|
check_url,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Also broadcast state change so clients can update entry availability
|
||||||
|
let state = {
|
||||||
|
let eng = engine.lock().await;
|
||||||
|
let mut state = eng.get_state();
|
||||||
|
state.connectivity = ConnectivityStatus {
|
||||||
|
connected: connectivity.is_connected().await,
|
||||||
|
check_url: Some(connectivity.global_check_url().to_string()),
|
||||||
|
last_check: connectivity.last_check_time().await,
|
||||||
|
};
|
||||||
|
state
|
||||||
|
};
|
||||||
|
ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
|
||||||
|
}
|
||||||
|
|
||||||
|
ConnectivityEvent::InterfaceChanged => {
|
||||||
|
debug!("Network interface changed, connectivity recheck in progress");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_ipc_message(
|
async fn handle_ipc_message(
|
||||||
engine: &Arc<Mutex<CoreEngine>>,
|
engine: &Arc<Mutex<CoreEngine>>,
|
||||||
host: &Arc<LinuxHost>,
|
host: &Arc<LinuxHost>,
|
||||||
|
|
@ -440,6 +523,7 @@ impl Service {
|
||||||
ipc: &Arc<IpcServer>,
|
ipc: &Arc<IpcServer>,
|
||||||
store: &Arc<dyn Store>,
|
store: &Arc<dyn Store>,
|
||||||
rate_limiter: &Arc<Mutex<RateLimiter>>,
|
rate_limiter: &Arc<Mutex<RateLimiter>>,
|
||||||
|
connectivity: &ConnectivityHandle,
|
||||||
msg: ServerMessage,
|
msg: ServerMessage,
|
||||||
) {
|
) {
|
||||||
match msg {
|
match msg {
|
||||||
|
|
@ -458,7 +542,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
let response =
|
let response =
|
||||||
Self::handle_command(engine, host, volume, ipc, store, &client_id, request.request_id, request.command)
|
Self::handle_command(engine, host, volume, ipc, store, connectivity, &client_id, request.request_id, request.command)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let _ = ipc.send_response(&client_id, response).await;
|
let _ = ipc.send_response(&client_id, response).await;
|
||||||
|
|
@ -504,6 +588,7 @@ impl Service {
|
||||||
volume: &Arc<LinuxVolumeController>,
|
volume: &Arc<LinuxVolumeController>,
|
||||||
ipc: &Arc<IpcServer>,
|
ipc: &Arc<IpcServer>,
|
||||||
store: &Arc<dyn Store>,
|
store: &Arc<dyn Store>,
|
||||||
|
connectivity: &ConnectivityHandle,
|
||||||
client_id: &ClientId,
|
client_id: &ClientId,
|
||||||
request_id: u64,
|
request_id: u64,
|
||||||
command: Command,
|
command: Command,
|
||||||
|
|
@ -513,7 +598,13 @@ impl Service {
|
||||||
|
|
||||||
match command {
|
match command {
|
||||||
Command::GetState => {
|
Command::GetState => {
|
||||||
let state = engine.lock().await.get_state();
|
let mut state = engine.lock().await.get_state();
|
||||||
|
// Add connectivity status
|
||||||
|
state.connectivity = ConnectivityStatus {
|
||||||
|
connected: connectivity.is_connected().await,
|
||||||
|
check_url: Some(connectivity.global_check_url().to_string()),
|
||||||
|
last_check: connectivity.last_check_time().await,
|
||||||
|
};
|
||||||
Response::success(request_id, ResponsePayload::State(state))
|
Response::success(request_id, ResponsePayload::State(state))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -526,6 +617,30 @@ impl Service {
|
||||||
Command::Launch { entry_id } => {
|
Command::Launch { entry_id } => {
|
||||||
let mut eng = engine.lock().await;
|
let mut eng = engine.lock().await;
|
||||||
|
|
||||||
|
// First check if the entry requires network and if it's available
|
||||||
|
if let Some(entry) = eng.policy().get_entry(&entry_id)
|
||||||
|
&& entry.network.required
|
||||||
|
{
|
||||||
|
let check_url = entry.network.effective_check_url(&eng.policy().network);
|
||||||
|
let network_ok = connectivity.check_url(check_url).await;
|
||||||
|
|
||||||
|
if !network_ok {
|
||||||
|
info!(
|
||||||
|
entry_id = %entry_id,
|
||||||
|
check_url = %check_url,
|
||||||
|
"Launch denied: network connectivity check failed"
|
||||||
|
);
|
||||||
|
return Response::success(
|
||||||
|
request_id,
|
||||||
|
ResponsePayload::LaunchDenied {
|
||||||
|
reasons: vec![ReasonCode::NetworkUnavailable {
|
||||||
|
check_url: check_url.to_string(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match eng.request_launch(&entry_id, now) {
|
match eng.request_launch(&entry_id, now) {
|
||||||
LaunchDecision::Approved(plan) => {
|
LaunchDecision::Approved(plan) => {
|
||||||
// Start the session in the engine
|
// Start the session in the engine
|
||||||
|
|
@ -939,6 +1054,6 @@ async fn main() -> Result<()> {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create and run the service
|
// Create and run the service
|
||||||
let service = Service::new(&args).await?;
|
let (service, connectivity_events) = Service::new(&args).await?;
|
||||||
service.run().await
|
service.run(connectivity_events).await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
//! These tests verify the end-to-end behavior of shepherdd.
|
//! These tests verify the end-to-end behavior of shepherdd.
|
||||||
|
|
||||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, Policy};
|
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement, Policy};
|
||||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision};
|
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision};
|
||||||
use shepherd_host_api::{HostCapabilities, MockHost};
|
use shepherd_host_api::{HostCapabilities, MockHost};
|
||||||
use shepherd_store::{SqliteStore, Store};
|
use shepherd_store::{SqliteStore, Store};
|
||||||
|
|
@ -48,6 +48,7 @@ fn make_test_policy() -> Policy {
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
volume: None,
|
volume: None,
|
||||||
|
network: NetworkRequirement::default(),
|
||||||
disabled: false,
|
disabled: false,
|
||||||
disabled_reason: None,
|
disabled_reason: None,
|
||||||
},
|
},
|
||||||
|
|
@ -55,6 +56,7 @@ fn make_test_policy() -> Policy {
|
||||||
default_warnings: vec![],
|
default_warnings: vec![],
|
||||||
default_max_run: Some(Duration::from_secs(3600)),
|
default_max_run: Some(Duration::from_secs(3600)),
|
||||||
volume: Default::default(),
|
volume: Default::default(),
|
||||||
|
network: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue