WIP: network check

This commit is contained in:
Albert Armea 2026-01-06 19:18:48 -05:00
parent dc58817aea
commit 266685628e
18 changed files with 1710 additions and 33 deletions

886
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -58,8 +58,14 @@ anyhow = "1.0"
uuid = { version = "1.6", features = ["v4", "serde"] }
bitflags = "2.4"
# HTTP client (for connectivity checks)
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
# Unix-specific
nix = { version = "0.29", features = ["signal", "process", "user", "socket"] }
netlink-sys = "0.8"
netlink-packet-core = "0.7"
netlink-packet-route = "0.21"
# CLI
clap = { version = "4.5", features = ["derive", "env"] }

View file

@ -30,6 +30,16 @@ max_volume = 80 # Maximum volume percentage (0-100)
allow_mute = true # Whether mute toggle is allowed
allow_change = true # Whether volume changes are allowed at all
# Network connectivity settings (optional)
# Used to check if Internet is available before allowing network-dependent entries
[service.network]
# URL to check for global network connectivity (default: Google's connectivity check)
# check_url = "http://connectivitycheck.gstatic.com/generate_204"
# How often to perform periodic connectivity checks, in seconds (default: 60)
# check_interval_seconds = 60
# Timeout for connectivity checks, in seconds (default: 5)
# check_timeout_seconds = 5
# Default warning thresholds
[[service.default_warnings]]
seconds_before = 300
@ -211,6 +221,11 @@ message = "30 seconds! Save NOW!"
[entries.volume]
max_volume = 60 # Limit volume during gaming sessions
# Network requirements for online games
[entries.network]
required = true # Minecraft needs network for authentication and multiplayer
check_url = "http://www.msftconnecttest.com/connecttest.txt" # Use Microsoft's check (Minecraft is owned by Microsoft)
## === Steam games ===
# Steam can be used via Canonical's Steam snap package:
# https://snapcraft.io/steam
@ -244,6 +259,9 @@ end = "20:00"
# No [entries.limits] section - uses service defaults
# Omitting limits entirely uses default_max_run_seconds
[entries.network]
required = true # Steam needs network for authentication
# A Short Hike via Steam
# https://store.steampowered.com/app/1055540/A_Short_Hike/
[[entries]]
@ -267,6 +285,9 @@ days = "weekends"
start = "10:00"
end = "20:00"
[entries.network]
required = true # Steam needs network for authentication
## === Media ===
# Just use `mpv` to play media (for now).
# Files can be local on your system or URLs (YouTube, etc).
@ -314,6 +335,9 @@ max_run_seconds = 0 # Unlimited: sleep/study aid
daily_quota_seconds = 0 # Unlimited
cooldown_seconds = 0 # No cooldown
[entries.network]
required = true # YouTube streaming needs network
# Terminal for debugging only
[[entries]]
id = "terminal"

View file

@ -234,6 +234,7 @@ mod tests {
current_session: None,
entry_count: 5,
entries: vec![],
connectivity: Default::default(),
}),
);

View file

@ -88,6 +88,14 @@ pub enum EventPayload {
event_type: String,
details: serde_json::Value,
},
/// Network connectivity status changed
ConnectivityChanged {
/// Whether global connectivity check now passes
connected: bool,
/// The URL that was checked
check_url: String,
},
}
#[cfg(test)]

View file

@ -121,6 +121,11 @@ pub enum ReasonCode {
Disabled {
reason: Option<String>,
},
/// Network connectivity check failed
NetworkUnavailable {
/// The URL that was checked
check_url: String,
},
}
/// Warning severity level
@ -197,6 +202,20 @@ pub struct ServiceStateSnapshot {
/// Available entries for UI display
#[serde(default)]
pub entries: Vec<EntryView>,
/// Network connectivity status
#[serde(default)]
pub connectivity: ConnectivityStatus,
}
/// Network connectivity status
///
/// Snapshot of the daemon-side connectivity monitor, embedded in
/// `ServiceStateSnapshot`. The `Default` value (used via `#[serde(default)]`
/// when talking to an older daemon that does not send this field) reports
/// `connected = false` with no check URL and no timestamp.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConnectivityStatus {
    /// Whether global network connectivity check passed
    pub connected: bool,
    /// The URL that was checked for global connectivity
    pub check_url: Option<String>,
    /// When the last check was performed (None if no check has run yet)
    pub last_check: Option<DateTime<Local>>,
}
/// Role for authorization

View file

@ -1,6 +1,6 @@
//! Validated policy structures
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawNetworkConfig, RawEntryNetwork, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
use crate::validation::{parse_days, parse_time};
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir, socket_path_without_env};
@ -24,6 +24,9 @@ pub struct Policy {
/// Global volume restrictions
pub volume: VolumePolicy,
/// Network connectivity policy
pub network: NetworkPolicy,
}
impl Policy {
@ -50,6 +53,13 @@ impl Policy {
.map(convert_volume_config)
.unwrap_or_default();
let network = raw
.service
.network
.as_ref()
.map(convert_network_config)
.unwrap_or_default();
let entries = raw
.entries
.into_iter()
@ -62,6 +72,7 @@ impl Policy {
default_warnings,
default_max_run,
volume: global_volume,
network,
}
}
@ -130,6 +141,7 @@ pub struct Entry {
pub limits: LimitsPolicy,
pub warnings: Vec<WarningThreshold>,
pub volume: Option<VolumePolicy>,
pub network: NetworkRequirement,
pub disabled: bool,
pub disabled_reason: Option<String>,
}
@ -159,6 +171,11 @@ impl Entry {
.map(|w| w.into_iter().map(convert_warning).collect())
.unwrap_or_else(|| default_warnings.to_vec());
let volume = raw.volume.as_ref().map(convert_volume_config);
let network = raw
.network
.as_ref()
.map(convert_entry_network)
.unwrap_or_default();
Self {
id: EntryId::new(raw.id),
@ -169,6 +186,7 @@ impl Entry {
limits,
warnings,
volume,
network,
disabled: raw.disabled,
disabled_reason: raw.disabled_reason,
}
@ -250,6 +268,52 @@ impl VolumePolicy {
}
}
/// Default connectivity check URL (Google's connectivity check service)
pub const DEFAULT_CHECK_URL: &str = "http://connectivitycheck.gstatic.com/generate_204";

/// Default interval for periodic connectivity checks (60 seconds)
pub const DEFAULT_CHECK_INTERVAL_SECS: u64 = 60;

/// Default timeout for connectivity checks (5 seconds)
pub const DEFAULT_CHECK_TIMEOUT_SECS: u64 = 5;

/// Network connectivity policy
///
/// Service-wide settings for the connectivity monitor. Built from the
/// optional `[service.network]` config section; any field the user omits
/// falls back to the `DEFAULT_CHECK_*` constants above (see `Default`).
#[derive(Debug, Clone)]
pub struct NetworkPolicy {
    /// URL to check for global network connectivity
    pub check_url: String,
    /// How often to perform periodic connectivity checks
    pub check_interval: Duration,
    /// Timeout for connectivity checks
    pub check_timeout: Duration,
}

impl Default for NetworkPolicy {
    fn default() -> Self {
        Self {
            check_url: DEFAULT_CHECK_URL.to_string(),
            check_interval: Duration::from_secs(DEFAULT_CHECK_INTERVAL_SECS),
            check_timeout: Duration::from_secs(DEFAULT_CHECK_TIMEOUT_SECS),
        }
    }
}

/// Network requirements for a specific entry
///
/// `Default` yields `required = false` and no URL override, i.e. the entry
/// does not depend on connectivity at all.
#[derive(Debug, Clone, Default)]
pub struct NetworkRequirement {
    /// Whether this entry requires network connectivity to launch
    pub required: bool,
    /// Override check URL for this entry (uses global if None)
    pub check_url_override: Option<String>,
}
impl NetworkRequirement {
    /// Resolve which URL should be probed for this entry: the per-entry
    /// override when one was configured, otherwise the service-wide
    /// `check_url` from the global policy.
    pub fn effective_check_url<'a>(&'a self, global: &'a NetworkPolicy) -> &'a str {
        match self.check_url_override.as_deref() {
            Some(url) => url,
            None => &global.check_url,
        }
    }
}
// Conversion helpers
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
@ -282,6 +346,21 @@ fn convert_volume_config(raw: &RawVolumeConfig) -> VolumePolicy {
}
}
/// Build a `NetworkPolicy` from its raw TOML form, substituting the
/// compiled-in defaults for every field the user omitted.
fn convert_network_config(raw: &RawNetworkConfig) -> NetworkPolicy {
    let check_url = match &raw.check_url {
        Some(url) => url.clone(),
        None => DEFAULT_CHECK_URL.to_string(),
    };
    let interval_secs = raw
        .check_interval_seconds
        .unwrap_or(DEFAULT_CHECK_INTERVAL_SECS);
    let timeout_secs = raw
        .check_timeout_seconds
        .unwrap_or(DEFAULT_CHECK_TIMEOUT_SECS);
    NetworkPolicy {
        check_url,
        check_interval: Duration::from_secs(interval_secs),
        check_timeout: Duration::from_secs(timeout_secs),
    }
}
/// Build a per-entry `NetworkRequirement` from its raw TOML form.
/// A configured `check_url` becomes the entry-specific override.
fn convert_entry_network(raw: &RawEntryNetwork) -> NetworkRequirement {
    let check_url_override = raw.check_url.clone();
    NetworkRequirement {
        required: raw.required,
        check_url_override,
    }
}
fn convert_time_window(raw: crate::schema::RawTimeWindow) -> TimeWindow {
let days_mask = parse_days(&raw.days).unwrap_or(0x7F);
let (start_h, start_m) = parse_time(&raw.start).unwrap_or((0, 0));

View file

@ -48,6 +48,26 @@ pub struct RawServiceConfig {
/// Global volume restrictions
#[serde(default)]
pub volume: Option<RawVolumeConfig>,
/// Network connectivity settings
#[serde(default)]
pub network: Option<RawNetworkConfig>,
}
/// Network connectivity configuration
///
/// Raw deserialized form of the optional `[service.network]` TOML section.
/// All fields are optional; defaults are applied when converting to the
/// validated `NetworkPolicy`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct RawNetworkConfig {
    /// URL to check for global network connectivity
    /// Default: "http://connectivitycheck.gstatic.com/generate_204"
    pub check_url: Option<String>,
    /// How often to perform periodic connectivity checks (in seconds)
    /// Default: 60
    pub check_interval_seconds: Option<u64>,
    /// Timeout for connectivity checks (in seconds)
    /// Default: 5
    pub check_timeout_seconds: Option<u64>,
}
/// Raw entry definition
@ -81,6 +101,10 @@ pub struct RawEntry {
#[serde(default)]
pub volume: Option<RawVolumeConfig>,
/// Network requirements for this entry
#[serde(default)]
pub network: Option<RawEntryNetwork>,
/// Explicitly disabled
#[serde(default)]
pub disabled: bool,
@ -89,6 +113,20 @@ pub struct RawEntry {
pub disabled_reason: Option<String>,
}
/// Network requirements for an entry
///
/// Raw deserialized form of an `[entries.network]` TOML table. The
/// `Default` impl (used when the table is absent) means "no network
/// requirement".
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct RawEntryNetwork {
    /// Whether this entry requires network connectivity to launch
    /// If true, the entry will not be available if the network check fails
    #[serde(default)]
    pub required: bool,
    /// Override check URL for this entry
    /// If specified, this URL will be checked instead of the global check_url
    /// This is useful for entries that need specific services (e.g., Google, Microsoft)
    pub check_url: Option<String>,
}
/// Raw entry kind
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]

View file

@ -260,6 +260,7 @@ mod tests {
limits: None,
warnings: None,
volume: None,
network: None,
disabled: false,
disabled_reason: None,
},
@ -277,6 +278,7 @@ mod tests {
limits: None,
warnings: None,
volume: None,
network: None,
disabled: false,
disabled_reason: None,
},

View file

@ -506,6 +506,8 @@ impl CoreEngine {
current_session,
entry_count: self.policy.entries.len(),
entries,
// Connectivity is populated by the daemon, not the core engine
connectivity: Default::default(),
}
}
@ -565,7 +567,7 @@ impl CoreEngine {
#[cfg(test)]
mod tests {
use super::*;
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement};
use shepherd_api::EntryKind;
use shepherd_store::SqliteStore;
use std::collections::HashMap;
@ -594,12 +596,14 @@ mod tests {
},
warnings: vec![],
volume: None,
network: NetworkRequirement::default(),
disabled: false,
disabled_reason: None,
}],
default_warnings: vec![],
default_max_run: Some(Duration::from_secs(3600)),
volume: Default::default(),
network: Default::default(),
}
}
@ -677,6 +681,7 @@ mod tests {
message_template: Some("1 minute left".into()),
}],
volume: None,
network: NetworkRequirement::default(),
disabled: false,
disabled_reason: None,
}],
@ -684,6 +689,7 @@ mod tests {
default_warnings: vec![],
default_max_run: Some(Duration::from_secs(3600)),
volume: Default::default(),
network: Default::default(),
};
let store = Arc::new(SqliteStore::in_memory().unwrap());
@ -742,6 +748,7 @@ mod tests {
},
warnings: vec![],
volume: None,
network: NetworkRequirement::default(),
disabled: false,
disabled_reason: None,
}],
@ -749,6 +756,7 @@ mod tests {
default_warnings: vec![],
default_max_run: Some(Duration::from_secs(3600)),
volume: Default::default(),
network: Default::default(),
};
let store = Arc::new(SqliteStore::in_memory().unwrap());

View file

@ -17,6 +17,11 @@ nix = { workspace = true }
async-trait = "0.1"
dirs = "5.0"
shell-escape = "0.1"
chrono = { workspace = true }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
netlink-sys = "0.8"
netlink-packet-core = "0.7"
netlink-packet-route = "0.21"
[dev-dependencies]
tempfile = { workspace = true }

View file

@ -0,0 +1,497 @@
//! Network connectivity monitoring for Linux
//!
//! This module provides:
//! - Periodic connectivity checks to a configurable URL
//! - Network interface change detection via netlink
//! - Per-entry connectivity status tracking
#![allow(dead_code)] // Methods on ConnectivityMonitor may be used for future admin commands
use chrono::{DateTime, Local};
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
use netlink_packet_route::RouteNetlinkMessage;
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
use reqwest::Client;
use std::collections::HashMap;
use std::os::fd::AsRawFd;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch, RwLock};
use tracing::{debug, error, info, warn};
/// Events emitted by the connectivity monitor
#[derive(Debug, Clone)]
pub enum ConnectivityEvent {
    /// Global connectivity status changed
    StatusChanged {
        /// Whether the global check now passes
        connected: bool,
        /// The URL that was checked
        check_url: String,
    },
    /// Network interface changed (may trigger recheck)
    InterfaceChanged,
}

/// Configuration for the connectivity monitor
///
/// Mirrors `NetworkPolicy` from the config crate; the daemon copies the
/// policy fields into this struct when constructing the monitor.
#[derive(Debug, Clone)]
pub struct ConnectivityConfig {
    /// URL to check for global network connectivity
    pub check_url: String,
    /// How often to perform periodic connectivity checks
    pub check_interval: Duration,
    /// Timeout for connectivity checks
    pub check_timeout: Duration,
}

/// Cached connectivity check result
#[derive(Debug, Clone)]
struct CheckResult {
    /// Outcome of the HTTP probe (true = reachable)
    connected: bool,
    /// Local wall-clock time at which the probe completed
    checked_at: DateTime<Local>,
}

/// Connectivity monitor that tracks network availability
///
/// Shared state (`global_status`, `url_cache`) is held behind `Arc` so a
/// `ConnectivityHandle` can observe it after the monitor is moved into its
/// background task.
pub struct ConnectivityMonitor {
    /// HTTP client for connectivity checks
    client: Client,
    /// Configuration
    config: ConnectivityConfig,
    /// Current global connectivity status (None until the first check runs)
    global_status: Arc<RwLock<Option<CheckResult>>>,
    /// Cached results for specific URLs (entry-specific checks)
    url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
    /// Channel for sending events
    event_tx: mpsc::Sender<ConnectivityEvent>,
    /// Shutdown signal
    shutdown_rx: watch::Receiver<bool>,
}
impl ConnectivityMonitor {
    /// Create a new connectivity monitor
    ///
    /// Returns the monitor plus the receiving end of its event channel
    /// (capacity 32). `run()` must be called for any checks to happen.
    ///
    /// # Panics
    /// Panics if the reqwest `Client` cannot be built; with the static
    /// builder settings used here this indicates a broken TLS backend,
    /// which is unrecoverable at startup.
    pub fn new(
        config: ConnectivityConfig,
        shutdown_rx: watch::Receiver<bool>,
    ) -> (Self, mpsc::Receiver<ConnectivityEvent>) {
        let (event_tx, event_rx) = mpsc::channel(32);

        // Both total-request and connect timeouts use the configured check
        // timeout so a probe can never outlive one check budget.
        let client = Client::builder()
            .timeout(config.check_timeout)
            .connect_timeout(config.check_timeout)
            .build()
            .expect("Failed to create HTTP client");

        let monitor = Self {
            client,
            config,
            global_status: Arc::new(RwLock::new(None)),
            url_cache: Arc::new(RwLock::new(HashMap::new())),
            event_tx,
            shutdown_rx,
        };
        (monitor, event_rx)
    }

    /// Start the connectivity monitor (runs until shutdown)
    ///
    /// Spawns two background tasks — a periodic HTTP probe and a netlink
    /// interface watcher — then blocks until the shutdown watch flips,
    /// aborting both tasks on the way out.
    pub async fn run(self) {
        let check_interval = self.config.check_interval;
        let check_url = self.config.check_url.clone();

        // Spawn periodic check task
        let periodic_handle = {
            let client = self.client.clone();
            let global_status = self.global_status.clone();
            let event_tx = self.event_tx.clone();
            let check_url = check_url.clone();
            let check_timeout = self.config.check_timeout;
            let mut shutdown = self.shutdown_rx.clone();
            tokio::spawn(async move {
                let mut interval = tokio::time::interval(check_interval);
                // Skip (rather than burst) ticks missed while a slow probe ran.
                interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

                // Do initial check immediately
                let connected = check_url_reachable(&client, &check_url, check_timeout).await;
                update_global_status(&global_status, &event_tx, &check_url, connected).await;

                loop {
                    tokio::select! {
                        _ = interval.tick() => {
                            let connected = check_url_reachable(&client, &check_url, check_timeout).await;
                            update_global_status(&global_status, &event_tx, &check_url, connected).await;
                        }
                        _ = shutdown.changed() => {
                            if *shutdown.borrow() {
                                debug!("Periodic check task shutting down");
                                break;
                            }
                        }
                    }
                }
            })
        };

        // Spawn netlink monitor task
        let netlink_handle = {
            let client = self.client.clone();
            let global_status = self.global_status.clone();
            let url_cache = self.url_cache.clone();
            let event_tx = self.event_tx.clone();
            let check_url = check_url.clone();
            let check_timeout = self.config.check_timeout;
            let mut shutdown = self.shutdown_rx.clone();
            tokio::spawn(async move {
                // A netlink failure only degrades change *detection*; the
                // periodic task above keeps working, so log and carry on.
                if let Err(e) = run_netlink_monitor(
                    &client,
                    &global_status,
                    &url_cache,
                    &event_tx,
                    &check_url,
                    check_timeout,
                    &mut shutdown,
                )
                .await
                {
                    warn!(error = %e, "Netlink monitor failed, network change detection unavailable");
                }
            })
        };

        // Wait for shutdown
        let mut shutdown = self.shutdown_rx.clone();
        let _ = shutdown.changed().await;

        // Cancel tasks
        periodic_handle.abort();
        netlink_handle.abort();
        info!("Connectivity monitor stopped");
    }

    /// Get the current global connectivity status
    ///
    /// Returns false until the first check completes.
    pub async fn is_connected(&self) -> bool {
        self.global_status
            .read()
            .await
            .as_ref()
            .is_some_and(|r| r.connected)
    }

    /// Get the last check time
    pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
        self.global_status.read().await.as_ref().map(|r| r.checked_at)
    }

    /// Check if a specific URL is reachable (with caching)
    /// Used for entry-specific network requirements
    ///
    /// A cached result is trusted for half the periodic check interval;
    /// past that, the URL is re-probed and the cache refreshed.
    pub async fn check_url(&self, url: &str) -> bool {
        // Check cache first
        {
            let cache = self.url_cache.read().await;
            if let Some(result) = cache.get(url) {
                // Cache valid for half the check interval
                let cache_ttl = self.config.check_interval / 2;
                // Clock skew makes the age negative -> treat as expired (MAX).
                let age = shepherd_util::now()
                    .signed_duration_since(result.checked_at)
                    .to_std()
                    .unwrap_or(Duration::MAX);
                if age < cache_ttl {
                    return result.connected;
                }
            }
        }

        // Perform check (no lock held during the network round-trip)
        let connected = check_url_reachable(&self.client, url, self.config.check_timeout).await;

        // Update cache
        {
            let mut cache = self.url_cache.write().await;
            cache.insert(
                url.to_string(),
                CheckResult {
                    connected,
                    checked_at: shepherd_util::now(),
                },
            );
        }
        connected
    }

    /// Force an immediate connectivity recheck
    ///
    /// Probes the global URL, updates shared status (emitting an event on
    /// transition), and drops all cached per-URL results.
    pub async fn trigger_recheck(&self) {
        let connected =
            check_url_reachable(&self.client, &self.config.check_url, self.config.check_timeout)
                .await;
        update_global_status(
            &self.global_status,
            &self.event_tx,
            &self.config.check_url,
            connected,
        )
        .await;
        // Clear URL cache to force rechecks
        self.url_cache.write().await.clear();
    }
}
/// Check if a URL is reachable
///
/// Issues a GET with the given per-request timeout and treats any 2xx
/// response as "connected". Transport errors (DNS failure, refused
/// connection, timeout) and non-2xx statuses count as unreachable.
///
/// NOTE(review): a captive portal that intercepts the request and answers
/// 200 with a login page will be reported as connected — confirm whether
/// stricter 204-only matching is wanted for generate_204-style endpoints.
async fn check_url_reachable(client: &Client, url: &str, timeout: Duration) -> bool {
    debug!(url = %url, "Checking connectivity");

    match client.get(url).timeout(timeout).send().await {
        Ok(response) => {
            let status = response.status();
            // `is_success()` covers the whole 2xx range, including the
            // 204 No Content that connectivity-check endpoints return, so
            // no separate 204 test is needed.
            let connected = status.is_success();
            debug!(url = %url, status = %status, connected = connected, "Connectivity check complete");
            connected
        }
        Err(e) => {
            debug!(url = %url, error = %e, "Connectivity check failed");
            false
        }
    }
}
/// Update global status and emit event if changed
///
/// Stores the latest check result in the shared slot and, when the
/// connected/disconnected state differs from the previous result (or this
/// is the very first check), emits a `StatusChanged` event. A send failure
/// (receiver dropped) is deliberately ignored.
///
/// The write guard is confined to its own scope and released *before* the
/// channel send: `event_tx.send().await` can park when the channel is
/// full, and holding the `RwLock` write guard across that await would
/// block every reader of `global_status` in the meantime.
async fn update_global_status(
    global_status: &Arc<RwLock<Option<CheckResult>>>,
    event_tx: &mpsc::Sender<ConnectivityEvent>,
    check_url: &str,
    connected: bool,
) {
    let previous = {
        let mut status = global_status.write().await;
        let previous = status.as_ref().map(|r| r.connected);
        *status = Some(CheckResult {
            connected,
            checked_at: shepherd_util::now(),
        });
        previous
    }; // guard dropped here, before any further awaits

    // Emit event if status changed (None != Some(_) covers the first check)
    if previous != Some(connected) {
        info!(
            connected = connected,
            url = %check_url,
            "Global connectivity status changed"
        );
        let _ = event_tx
            .send(ConnectivityEvent::StatusChanged {
                connected,
                check_url: check_url.to_string(),
            })
            .await;
    }
}
/// Run the netlink monitor to detect network interface changes
///
/// Subscribes to the rtnetlink multicast groups for link, address, and
/// route changes. When a relevant change arrives, the per-URL cache is
/// cleared, an `InterfaceChanged` event is emitted, and global
/// connectivity is rechecked after a short settling delay.
///
/// # Errors
/// Returns an error if the netlink socket cannot be created, bound, made
/// non-blocking, or registered with the tokio reactor.
async fn run_netlink_monitor(
    client: &Client,
    global_status: &Arc<RwLock<Option<CheckResult>>>,
    url_cache: &Arc<RwLock<HashMap<String, CheckResult>>>,
    event_tx: &mpsc::Sender<ConnectivityEvent>,
    check_url: &str,
    check_timeout: Duration,
    shutdown: &mut watch::Receiver<bool>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Create netlink socket for route notifications
    let mut socket = Socket::new(NETLINK_ROUTE)?;

    // Bind to multicast groups for link and address changes
    // RTMGRP_LINK = 1, RTMGRP_IPV4_IFADDR = 0x10, RTMGRP_IPV6_IFADDR = 0x100
    let groups = 1 | 0x10 | 0x100;
    let addr = SocketAddr::new(0, groups);
    socket.bind(&addr)?;

    // Non-blocking is required: reads below are driven by reactor readiness.
    socket.set_non_blocking(true)?;

    info!("Netlink monitor started");

    // Register the fd with the tokio reactor ONCE, outside the loop.
    // Re-creating the AsyncFd each iteration (the previous behavior) would
    // repeatedly register/deregister the fd and, with edge-triggered
    // readiness, could miss events that fired between registrations.
    let async_fd = tokio::io::unix::AsyncFd::new(socket.as_raw_fd())?;
    let mut buf = vec![0u8; 4096];

    loop {
        tokio::select! {
            result = async_fd.readable() => {
                let mut guard = match result {
                    Ok(guard) => guard,
                    Err(e) => {
                        error!(error = %e, "Async fd error");
                        continue;
                    }
                };

                // Drain the socket until it would block. With edge-triggered
                // readiness we must consume everything the kernel buffered
                // before clearing; clearing after a single successful read
                // could strand queued datagrams and stall the monitor.
                let mut relevant = false;
                loop {
                    match socket.recv(&mut buf, 0) {
                        Ok(len) if len > 0 => {
                            if has_relevant_netlink_event(&buf[..len]) {
                                relevant = true;
                            }
                        }
                        Ok(_) => break, // zero-length read: nothing more to do
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // Kernel buffer empty: safe to clear readiness now.
                            guard.clear_ready();
                            break;
                        }
                        Err(e) => {
                            error!(error = %e, "Netlink recv error");
                            guard.clear_ready();
                            break;
                        }
                    }
                }

                if relevant {
                    debug!("Network interface change detected");
                    let _ = event_tx.send(ConnectivityEvent::InterfaceChanged).await;
                    // Stale per-URL results must not mask the new network state.
                    url_cache.write().await.clear();
                    // Recheck connectivity after a short delay
                    // (give network time to stabilize)
                    tokio::time::sleep(Duration::from_millis(500)).await;
                    let connected = check_url_reachable(client, check_url, check_timeout).await;
                    update_global_status(global_status, event_tx, check_url, connected).await;
                }
            }
            _ = shutdown.changed() => {
                if *shutdown.borrow() {
                    debug!("Netlink monitor shutting down");
                    break;
                }
            }
        }
    }
    Ok(())
}
/// Check if a netlink message buffer contains relevant network events
///
/// Walks the datagram message-by-message and reports true as soon as any
/// link, address, or route add/remove message is found. Parse failures and
/// zero-length headers terminate the walk.
///
/// NOTE(review): `header.length` is used verbatim as the stride with no
/// NLMSG_ALIGN rounding — confirm the deserializer's length already
/// accounts for padding between messages.
fn has_relevant_netlink_event(buf: &[u8]) -> bool {
    let mut offset = 0;
    while offset < buf.len() {
        let msg = match NetlinkMessage::<RouteNetlinkMessage>::deserialize(&buf[offset..]) {
            Ok(msg) => msg,
            // Undecodable remainder: stop scanning.
            Err(_) => break,
        };

        if let NetlinkPayload::InnerMessage(route_msg) = &msg.payload {
            let is_relevant = matches!(
                route_msg,
                // Link up/down, address add/remove, route add/remove
                RouteNetlinkMessage::NewLink(_)
                    | RouteNetlinkMessage::DelLink(_)
                    | RouteNetlinkMessage::NewAddress(_)
                    | RouteNetlinkMessage::DelAddress(_)
                    | RouteNetlinkMessage::NewRoute(_)
                    | RouteNetlinkMessage::DelRoute(_)
            );
            if is_relevant {
                return true;
            }
        }

        // Advance to the next message; a zero length would loop forever.
        let len = msg.header.length as usize;
        if len == 0 {
            break;
        }
        offset += len;
    }
    false
}
/// Handle for accessing connectivity status from other parts of the service
///
/// Cheap to clone: shares the monitor's `Arc`-held status and cache, plus
/// its own copy of the client and config values. Remains usable after the
/// monitor itself has been moved into its background task.
#[derive(Clone)]
pub struct ConnectivityHandle {
    client: Client,
    global_status: Arc<RwLock<Option<CheckResult>>>,
    url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
    check_timeout: Duration,
    // TTL for cached per-URL results; fixed at construction time to half
    // the monitor's check interval.
    cache_ttl: Duration,
    global_check_url: String,
}

impl ConnectivityHandle {
    /// Create a handle from the monitor
    ///
    /// Must be called before the monitor is moved into `run()`.
    pub fn from_monitor(monitor: &ConnectivityMonitor) -> Self {
        Self {
            client: monitor.client.clone(),
            global_status: monitor.global_status.clone(),
            url_cache: monitor.url_cache.clone(),
            check_timeout: monitor.config.check_timeout,
            cache_ttl: monitor.config.check_interval / 2,
            global_check_url: monitor.config.check_url.clone(),
        }
    }

    /// Get the current global connectivity status
    ///
    /// Returns false until the monitor's first check completes.
    pub async fn is_connected(&self) -> bool {
        self.global_status
            .read()
            .await
            .as_ref()
            .is_some_and(|r| r.connected)
    }

    /// Get the last check time
    pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
        self.global_status.read().await.as_ref().map(|r| r.checked_at)
    }

    /// Get the global check URL
    pub fn global_check_url(&self) -> &str {
        &self.global_check_url
    }

    /// Check if a specific URL is reachable (with caching)
    ///
    /// The global URL short-circuits to the monitor's own status; any
    /// other URL is served from the shared cache while fresher than
    /// `cache_ttl`, and re-probed (cache refreshed) otherwise.
    pub async fn check_url(&self, url: &str) -> bool {
        // If it's the global URL, use global status
        if url == self.global_check_url {
            return self.is_connected().await;
        }

        // Check cache first
        {
            let cache = self.url_cache.read().await;
            if let Some(result) = cache.get(url) {
                // Clock skew makes the age negative -> treat as expired (MAX).
                let age = shepherd_util::now()
                    .signed_duration_since(result.checked_at)
                    .to_std()
                    .unwrap_or(Duration::MAX);
                if age < self.cache_ttl {
                    return result.connected;
                }
            }
        }

        // Perform check (no lock held during the network round-trip)
        let connected = check_url_reachable(&self.client, url, self.check_timeout).await;

        // Update cache
        {
            let mut cache = self.url_cache.write().await;
            cache.insert(
                url.to_string(),
                CheckResult {
                    connected,
                    checked_at: shepherd_util::now(),
                },
            );
        }
        connected
    }
}

View file

@ -6,11 +6,14 @@
//! - Exit observation
//! - stdout/stderr capture
//! - Volume control with auto-detection of sound systems
//! - Network connectivity monitoring via netlink
mod adapter;
mod connectivity;
mod process;
mod volume;
pub use adapter::*;
pub use connectivity::*;
pub use process::*;
pub use volume::*;

View file

@ -252,5 +252,6 @@ fn reason_to_message(reason: &ReasonCode) -> &'static str {
ReasonCode::SessionActive { .. } => "Another session is active",
ReasonCode::UnsupportedKind { .. } => "Entry type not supported",
ReasonCode::Disabled { .. } => "Entry disabled",
ReasonCode::NetworkUnavailable { .. } => "Network connection required",
}
}

View file

@ -117,6 +117,10 @@ impl SharedState {
EventPayload::VolumeChanged { .. } => {
// Volume events are handled by HUD
}
EventPayload::ConnectivityChanged { .. } => {
// Connectivity changes may affect entry availability - request fresh state
self.set(LauncherState::Connecting);
}
}
}

View file

@ -27,6 +27,7 @@ tracing-subscriber = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
clap = { version = "4.5", features = ["derive", "env"] }
nix = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }

View file

@ -8,17 +8,22 @@
//! - Host adapter (Linux)
//! - IPC server
//! - Volume control
//! - Network connectivity monitoring
use anyhow::{Context, Result};
use clap::Parser;
use shepherd_api::{
Command, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo, VolumeRestrictions,
Command, ConnectivityStatus, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
ReasonCode, Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo,
VolumeRestrictions,
};
use shepherd_config::{load_config, VolumePolicy};
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode, VolumeController};
use shepherd_host_linux::{LinuxHost, LinuxVolumeController};
use shepherd_host_linux::{
ConnectivityConfig, ConnectivityEvent, ConnectivityHandle, ConnectivityMonitor, LinuxHost,
LinuxVolumeController,
};
use shepherd_ipc::{IpcServer, ServerMessage};
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
use shepherd_util::{default_config_path, ClientId, MonotonicInstant, RateLimiter};
@ -26,7 +31,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::Mutex;
use tokio::sync::{mpsc, watch, Mutex};
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter;
@ -60,10 +65,12 @@ struct Service {
ipc: Arc<IpcServer>,
store: Arc<dyn Store>,
rate_limiter: RateLimiter,
connectivity: ConnectivityHandle,
shutdown_tx: watch::Sender<bool>,
}
impl Service {
async fn new(args: &Args) -> Result<Self> {
async fn new(args: &Args) -> Result<(Self, mpsc::Receiver<ConnectivityEvent>)> {
// Load configuration
let policy = load_config(&args.config)
.with_context(|| format!("Failed to load config from {:?}", args.config))?;
@ -116,6 +123,7 @@ impl Service {
}
// Initialize core engine
let network_policy = policy.network.clone();
let engine = CoreEngine::new(policy, store.clone(), host.capabilities().clone());
// Initialize IPC server
@ -127,17 +135,43 @@ impl Service {
// Rate limiter: 30 requests per second per client
let rate_limiter = RateLimiter::new(30, Duration::from_secs(1));
Ok(Self {
engine,
host,
volume,
ipc: Arc::new(ipc),
store,
rate_limiter,
})
// Initialize connectivity monitor
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let connectivity_config = ConnectivityConfig {
check_url: network_policy.check_url,
check_interval: network_policy.check_interval,
check_timeout: network_policy.check_timeout,
};
let (connectivity_monitor, connectivity_events) =
ConnectivityMonitor::new(connectivity_config, shutdown_rx);
let connectivity = ConnectivityHandle::from_monitor(&connectivity_monitor);
// Spawn connectivity monitor task
tokio::spawn(async move {
connectivity_monitor.run().await;
});
info!(
check_url = %connectivity.global_check_url(),
"Connectivity monitor started"
);
Ok((
Self {
engine,
host,
volume,
ipc: Arc::new(ipc),
store,
rate_limiter,
connectivity,
shutdown_tx,
},
connectivity_events,
))
}
async fn run(self) -> Result<()> {
async fn run(self, mut connectivity_events: mpsc::Receiver<ConnectivityEvent>) -> Result<()> {
// Start host process monitor
let _monitor_handle = self.host.start_monitor();
@ -155,6 +189,8 @@ impl Service {
let host = self.host.clone();
let volume = self.volume.clone();
let store = self.store.clone();
let connectivity = self.connectivity.clone();
let shutdown_tx = self.shutdown_tx.clone();
// Spawn IPC accept task
let ipc_accept = ipc_ref.clone();
@ -218,7 +254,12 @@ impl Service {
// IPC messages
Some(msg) = ipc_messages.recv() => {
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, msg).await;
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, &connectivity, msg).await;
}
// Connectivity events
Some(conn_event) = connectivity_events.recv() => {
Self::handle_connectivity_event(&engine, &ipc_ref, &connectivity, conn_event).await;
}
}
}
@ -226,6 +267,9 @@ impl Service {
// Graceful shutdown
info!("Shutting down shepherdd");
// Signal connectivity monitor to stop
let _ = shutdown_tx.send(true);
// Stop all running sessions
{
let engine = engine.lock().await;
@ -433,6 +477,45 @@ impl Service {
}
}
    /// React to an event from the connectivity monitor.
    ///
    /// On a status transition, broadcasts a `ConnectivityChanged` event and
    /// then a full `StateChanged` snapshot (with connectivity overlaid) so
    /// clients can refresh entry availability. Interface-change events are
    /// informational only: the monitor itself performs the recheck.
    async fn handle_connectivity_event(
        engine: &Arc<Mutex<CoreEngine>>,
        ipc: &Arc<IpcServer>,
        connectivity: &ConnectivityHandle,
        event: ConnectivityEvent,
    ) {
        match event {
            ConnectivityEvent::StatusChanged {
                connected,
                check_url,
            } => {
                info!(connected = connected, url = %check_url, "Connectivity status changed");

                // Broadcast connectivity change event
                ipc.broadcast_event(Event::new(EventPayload::ConnectivityChanged {
                    connected,
                    check_url,
                }));

                // Also broadcast state change so clients can update entry availability.
                // The engine lock is scoped to this block so it is not held
                // while broadcasting.
                let state = {
                    let eng = engine.lock().await;
                    let mut state = eng.get_state();
                    // NOTE(review): this re-reads the handle instead of using
                    // the event's `connected`; the two could disagree if the
                    // status flipped again in between — confirm intended.
                    state.connectivity = ConnectivityStatus {
                        connected: connectivity.is_connected().await,
                        check_url: Some(connectivity.global_check_url().to_string()),
                        last_check: connectivity.last_check_time().await,
                    };
                    state
                };
                ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
            }
            ConnectivityEvent::InterfaceChanged => {
                debug!("Network interface changed, connectivity recheck in progress");
            }
        }
    }
async fn handle_ipc_message(
engine: &Arc<Mutex<CoreEngine>>,
host: &Arc<LinuxHost>,
@ -440,6 +523,7 @@ impl Service {
ipc: &Arc<IpcServer>,
store: &Arc<dyn Store>,
rate_limiter: &Arc<Mutex<RateLimiter>>,
connectivity: &ConnectivityHandle,
msg: ServerMessage,
) {
match msg {
@ -458,7 +542,7 @@ impl Service {
}
let response =
Self::handle_command(engine, host, volume, ipc, store, &client_id, request.request_id, request.command)
Self::handle_command(engine, host, volume, ipc, store, connectivity, &client_id, request.request_id, request.command)
.await;
let _ = ipc.send_response(&client_id, response).await;
@ -504,6 +588,7 @@ impl Service {
volume: &Arc<LinuxVolumeController>,
ipc: &Arc<IpcServer>,
store: &Arc<dyn Store>,
connectivity: &ConnectivityHandle,
client_id: &ClientId,
request_id: u64,
command: Command,
@ -513,7 +598,13 @@ impl Service {
match command {
Command::GetState => {
let state = engine.lock().await.get_state();
let mut state = engine.lock().await.get_state();
// Add connectivity status
state.connectivity = ConnectivityStatus {
connected: connectivity.is_connected().await,
check_url: Some(connectivity.global_check_url().to_string()),
last_check: connectivity.last_check_time().await,
};
Response::success(request_id, ResponsePayload::State(state))
}
@ -526,6 +617,30 @@ impl Service {
Command::Launch { entry_id } => {
let mut eng = engine.lock().await;
// First check if the entry requires network and if it's available
if let Some(entry) = eng.policy().get_entry(&entry_id)
&& entry.network.required
{
let check_url = entry.network.effective_check_url(&eng.policy().network);
let network_ok = connectivity.check_url(check_url).await;
if !network_ok {
info!(
entry_id = %entry_id,
check_url = %check_url,
"Launch denied: network connectivity check failed"
);
return Response::success(
request_id,
ResponsePayload::LaunchDenied {
reasons: vec![ReasonCode::NetworkUnavailable {
check_url: check_url.to_string(),
}],
},
);
}
}
match eng.request_launch(&entry_id, now) {
LaunchDecision::Approved(plan) => {
// Start the session in the engine
@ -939,6 +1054,6 @@ async fn main() -> Result<()> {
);
// Create and run the service
let service = Service::new(&args).await?;
service.run().await
let (service, connectivity_events) = Service::new(&args).await?;
service.run(connectivity_events).await
}

View file

@ -3,7 +3,7 @@
//! These tests verify the end-to-end behavior of shepherdd.
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, Policy};
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement, Policy};
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision};
use shepherd_host_api::{HostCapabilities, MockHost};
use shepherd_store::{SqliteStore, Store};
@ -48,6 +48,7 @@ fn make_test_policy() -> Policy {
},
],
volume: None,
network: NetworkRequirement::default(),
disabled: false,
disabled_reason: None,
},
@ -55,6 +56,7 @@ fn make_test_policy() -> Policy {
default_warnings: vec![],
default_max_run: Some(Duration::from_secs(3600)),
volume: Default::default(),
network: Default::default(),
}
}