Compare commits
1 commit
main
...
u/aarmea/9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
266685628e |
18 changed files with 1710 additions and 33 deletions
886
Cargo.lock
generated
886
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -58,8 +58,14 @@ anyhow = "1.0"
|
|||
uuid = { version = "1.6", features = ["v4", "serde"] }
|
||||
bitflags = "2.4"
|
||||
|
||||
# HTTP client (for connectivity checks)
|
||||
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
|
||||
|
||||
# Unix-specific
|
||||
nix = { version = "0.29", features = ["signal", "process", "user", "socket"] }
|
||||
netlink-sys = "0.8"
|
||||
netlink-packet-core = "0.7"
|
||||
netlink-packet-route = "0.21"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4.5", features = ["derive", "env"] }
|
||||
|
|
|
|||
|
|
@ -30,6 +30,16 @@ max_volume = 80 # Maximum volume percentage (0-100)
|
|||
allow_mute = true # Whether mute toggle is allowed
|
||||
allow_change = true # Whether volume changes are allowed at all
|
||||
|
||||
# Network connectivity settings (optional)
|
||||
# Used to check if Internet is available before allowing network-dependent entries
|
||||
[service.network]
|
||||
# URL to check for global network connectivity (default: Google's connectivity check)
|
||||
# check_url = "http://connectivitycheck.gstatic.com/generate_204"
|
||||
# How often to perform periodic connectivity checks, in seconds (default: 60)
|
||||
# check_interval_seconds = 60
|
||||
# Timeout for connectivity checks, in seconds (default: 5)
|
||||
# check_timeout_seconds = 5
|
||||
|
||||
# Default warning thresholds
|
||||
[[service.default_warnings]]
|
||||
seconds_before = 300
|
||||
|
|
@ -211,6 +221,11 @@ message = "30 seconds! Save NOW!"
|
|||
[entries.volume]
|
||||
max_volume = 60 # Limit volume during gaming sessions
|
||||
|
||||
# Network requirements for online games
|
||||
[entries.network]
|
||||
required = true # Minecraft needs network for authentication and multiplayer
|
||||
check_url = "http://www.msftconnecttest.com/connecttest.txt" # Use Microsoft's check (Minecraft is owned by Microsoft)
|
||||
|
||||
## === Steam games ===
|
||||
# Steam can be used via Canonical's Steam snap package:
|
||||
# https://snapcraft.io/steam
|
||||
|
|
@ -244,6 +259,9 @@ end = "20:00"
|
|||
# No [entries.limits] section - uses service defaults
|
||||
# Omitting limits entirely uses default_max_run_seconds
|
||||
|
||||
[entries.network]
|
||||
required = true # Steam needs network for authentication
|
||||
|
||||
# A Short Hike via Steam
|
||||
# https://store.steampowered.com/app/1055540/A_Short_Hike/
|
||||
[[entries]]
|
||||
|
|
@ -267,6 +285,9 @@ days = "weekends"
|
|||
start = "10:00"
|
||||
end = "20:00"
|
||||
|
||||
[entries.network]
|
||||
required = true # Steam needs network for authentication
|
||||
|
||||
## === Media ===
|
||||
# Just use `mpv` to play media (for now).
|
||||
# Files can be local on your system or URLs (YouTube, etc).
|
||||
|
|
@ -314,6 +335,9 @@ max_run_seconds = 0 # Unlimited: sleep/study aid
|
|||
daily_quota_seconds = 0 # Unlimited
|
||||
cooldown_seconds = 0 # No cooldown
|
||||
|
||||
[entries.network]
|
||||
required = true # YouTube streaming needs network
|
||||
|
||||
# Terminal for debugging only
|
||||
[[entries]]
|
||||
id = "terminal"
|
||||
|
|
|
|||
|
|
@ -234,6 +234,7 @@ mod tests {
|
|||
current_session: None,
|
||||
entry_count: 5,
|
||||
entries: vec![],
|
||||
connectivity: Default::default(),
|
||||
}),
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -88,6 +88,14 @@ pub enum EventPayload {
|
|||
event_type: String,
|
||||
details: serde_json::Value,
|
||||
},
|
||||
|
||||
/// Network connectivity status changed
|
||||
ConnectivityChanged {
|
||||
/// Whether global connectivity check now passes
|
||||
connected: bool,
|
||||
/// The URL that was checked
|
||||
check_url: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -121,6 +121,11 @@ pub enum ReasonCode {
|
|||
Disabled {
|
||||
reason: Option<String>,
|
||||
},
|
||||
/// Network connectivity check failed
|
||||
NetworkUnavailable {
|
||||
/// The URL that was checked
|
||||
check_url: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Warning severity level
|
||||
|
|
@ -197,6 +202,20 @@ pub struct ServiceStateSnapshot {
|
|||
/// Available entries for UI display
|
||||
#[serde(default)]
|
||||
pub entries: Vec<EntryView>,
|
||||
/// Network connectivity status
|
||||
#[serde(default)]
|
||||
pub connectivity: ConnectivityStatus,
|
||||
}
|
||||
|
||||
/// Network connectivity status
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ConnectivityStatus {
|
||||
/// Whether global network connectivity check passed
|
||||
pub connected: bool,
|
||||
/// The URL that was checked for global connectivity
|
||||
pub check_url: Option<String>,
|
||||
/// When the last check was performed
|
||||
pub last_check: Option<DateTime<Local>>,
|
||||
}
|
||||
|
||||
/// Role for authorization
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
//! Validated policy structures
|
||||
|
||||
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
|
||||
use crate::schema::{RawConfig, RawEntry, RawEntryKind, RawNetworkConfig, RawEntryNetwork, RawVolumeConfig, RawServiceConfig, RawWarningThreshold};
|
||||
use crate::validation::{parse_days, parse_time};
|
||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||
use shepherd_util::{DaysOfWeek, EntryId, TimeWindow, WallClock, default_data_dir, default_log_dir, socket_path_without_env};
|
||||
|
|
@ -24,6 +24,9 @@ pub struct Policy {
|
|||
|
||||
/// Global volume restrictions
|
||||
pub volume: VolumePolicy,
|
||||
|
||||
/// Network connectivity policy
|
||||
pub network: NetworkPolicy,
|
||||
}
|
||||
|
||||
impl Policy {
|
||||
|
|
@ -50,6 +53,13 @@ impl Policy {
|
|||
.map(convert_volume_config)
|
||||
.unwrap_or_default();
|
||||
|
||||
let network = raw
|
||||
.service
|
||||
.network
|
||||
.as_ref()
|
||||
.map(convert_network_config)
|
||||
.unwrap_or_default();
|
||||
|
||||
let entries = raw
|
||||
.entries
|
||||
.into_iter()
|
||||
|
|
@ -62,6 +72,7 @@ impl Policy {
|
|||
default_warnings,
|
||||
default_max_run,
|
||||
volume: global_volume,
|
||||
network,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -130,6 +141,7 @@ pub struct Entry {
|
|||
pub limits: LimitsPolicy,
|
||||
pub warnings: Vec<WarningThreshold>,
|
||||
pub volume: Option<VolumePolicy>,
|
||||
pub network: NetworkRequirement,
|
||||
pub disabled: bool,
|
||||
pub disabled_reason: Option<String>,
|
||||
}
|
||||
|
|
@ -159,6 +171,11 @@ impl Entry {
|
|||
.map(|w| w.into_iter().map(convert_warning).collect())
|
||||
.unwrap_or_else(|| default_warnings.to_vec());
|
||||
let volume = raw.volume.as_ref().map(convert_volume_config);
|
||||
let network = raw
|
||||
.network
|
||||
.as_ref()
|
||||
.map(convert_entry_network)
|
||||
.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
id: EntryId::new(raw.id),
|
||||
|
|
@ -169,6 +186,7 @@ impl Entry {
|
|||
limits,
|
||||
warnings,
|
||||
volume,
|
||||
network,
|
||||
disabled: raw.disabled,
|
||||
disabled_reason: raw.disabled_reason,
|
||||
}
|
||||
|
|
@ -250,6 +268,52 @@ impl VolumePolicy {
|
|||
}
|
||||
}
|
||||
|
||||
/// Default connectivity check URL (Google's connectivity check service)
|
||||
pub const DEFAULT_CHECK_URL: &str = "http://connectivitycheck.gstatic.com/generate_204";
|
||||
|
||||
/// Default interval for periodic connectivity checks (60 seconds)
|
||||
pub const DEFAULT_CHECK_INTERVAL_SECS: u64 = 60;
|
||||
|
||||
/// Default timeout for connectivity checks (5 seconds)
|
||||
pub const DEFAULT_CHECK_TIMEOUT_SECS: u64 = 5;
|
||||
|
||||
/// Network connectivity policy
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetworkPolicy {
|
||||
/// URL to check for global network connectivity
|
||||
pub check_url: String,
|
||||
/// How often to perform periodic connectivity checks
|
||||
pub check_interval: Duration,
|
||||
/// Timeout for connectivity checks
|
||||
pub check_timeout: Duration,
|
||||
}
|
||||
|
||||
impl Default for NetworkPolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
check_url: DEFAULT_CHECK_URL.to_string(),
|
||||
check_interval: Duration::from_secs(DEFAULT_CHECK_INTERVAL_SECS),
|
||||
check_timeout: Duration::from_secs(DEFAULT_CHECK_TIMEOUT_SECS),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Network requirements for a specific entry
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct NetworkRequirement {
|
||||
/// Whether this entry requires network connectivity to launch
|
||||
pub required: bool,
|
||||
/// Override check URL for this entry (uses global if None)
|
||||
pub check_url_override: Option<String>,
|
||||
}
|
||||
|
||||
impl NetworkRequirement {
|
||||
/// Get the check URL to use for this entry, given the global policy
|
||||
pub fn effective_check_url<'a>(&'a self, global: &'a NetworkPolicy) -> &'a str {
|
||||
self.check_url_override.as_deref().unwrap_or(&global.check_url)
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion helpers
|
||||
|
||||
fn convert_entry_kind(raw: RawEntryKind) -> EntryKind {
|
||||
|
|
@ -282,6 +346,21 @@ fn convert_volume_config(raw: &RawVolumeConfig) -> VolumePolicy {
|
|||
}
|
||||
}
|
||||
|
||||
fn convert_network_config(raw: &RawNetworkConfig) -> NetworkPolicy {
|
||||
NetworkPolicy {
|
||||
check_url: raw.check_url.clone().unwrap_or_else(|| DEFAULT_CHECK_URL.to_string()),
|
||||
check_interval: Duration::from_secs(raw.check_interval_seconds.unwrap_or(DEFAULT_CHECK_INTERVAL_SECS)),
|
||||
check_timeout: Duration::from_secs(raw.check_timeout_seconds.unwrap_or(DEFAULT_CHECK_TIMEOUT_SECS)),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_entry_network(raw: &RawEntryNetwork) -> NetworkRequirement {
|
||||
NetworkRequirement {
|
||||
required: raw.required,
|
||||
check_url_override: raw.check_url.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_time_window(raw: crate::schema::RawTimeWindow) -> TimeWindow {
|
||||
let days_mask = parse_days(&raw.days).unwrap_or(0x7F);
|
||||
let (start_h, start_m) = parse_time(&raw.start).unwrap_or((0, 0));
|
||||
|
|
|
|||
|
|
@ -48,6 +48,26 @@ pub struct RawServiceConfig {
|
|||
/// Global volume restrictions
|
||||
#[serde(default)]
|
||||
pub volume: Option<RawVolumeConfig>,
|
||||
|
||||
/// Network connectivity settings
|
||||
#[serde(default)]
|
||||
pub network: Option<RawNetworkConfig>,
|
||||
}
|
||||
|
||||
/// Network connectivity configuration
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct RawNetworkConfig {
|
||||
/// URL to check for global network connectivity
|
||||
/// Default: "http://connectivitycheck.gstatic.com/generate_204"
|
||||
pub check_url: Option<String>,
|
||||
|
||||
/// How often to perform periodic connectivity checks (in seconds)
|
||||
/// Default: 30
|
||||
pub check_interval_seconds: Option<u64>,
|
||||
|
||||
/// Timeout for connectivity checks (in seconds)
|
||||
/// Default: 5
|
||||
pub check_timeout_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// Raw entry definition
|
||||
|
|
@ -81,6 +101,10 @@ pub struct RawEntry {
|
|||
#[serde(default)]
|
||||
pub volume: Option<RawVolumeConfig>,
|
||||
|
||||
/// Network requirements for this entry
|
||||
#[serde(default)]
|
||||
pub network: Option<RawEntryNetwork>,
|
||||
|
||||
/// Explicitly disabled
|
||||
#[serde(default)]
|
||||
pub disabled: bool,
|
||||
|
|
@ -89,6 +113,20 @@ pub struct RawEntry {
|
|||
pub disabled_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// Network requirements for an entry
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct RawEntryNetwork {
|
||||
/// Whether this entry requires network connectivity to launch
|
||||
/// If true, the entry will not be available if the network check fails
|
||||
#[serde(default)]
|
||||
pub required: bool,
|
||||
|
||||
/// Override check URL for this entry
|
||||
/// If specified, this URL will be checked instead of the global check_url
|
||||
/// This is useful for entries that need specific services (e.g., Google, Microsoft)
|
||||
pub check_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Raw entry kind
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
|
|
|
|||
|
|
@ -260,6 +260,7 @@ mod tests {
|
|||
limits: None,
|
||||
warnings: None,
|
||||
volume: None,
|
||||
network: None,
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
|
|
@ -277,6 +278,7 @@ mod tests {
|
|||
limits: None,
|
||||
warnings: None,
|
||||
volume: None,
|
||||
network: None,
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -506,6 +506,8 @@ impl CoreEngine {
|
|||
current_session,
|
||||
entry_count: self.policy.entries.len(),
|
||||
entries,
|
||||
// Connectivity is populated by the daemon, not the core engine
|
||||
connectivity: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -565,7 +567,7 @@ impl CoreEngine {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy};
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement};
|
||||
use shepherd_api::EntryKind;
|
||||
use shepherd_store::SqliteStore;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -594,12 +596,14 @@ mod tests {
|
|||
},
|
||||
warnings: vec![],
|
||||
volume: None,
|
||||
network: NetworkRequirement::default(),
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
default_warnings: vec![],
|
||||
default_max_run: Some(Duration::from_secs(3600)),
|
||||
volume: Default::default(),
|
||||
network: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -677,6 +681,7 @@ mod tests {
|
|||
message_template: Some("1 minute left".into()),
|
||||
}],
|
||||
volume: None,
|
||||
network: NetworkRequirement::default(),
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
|
|
@ -684,6 +689,7 @@ mod tests {
|
|||
default_warnings: vec![],
|
||||
default_max_run: Some(Duration::from_secs(3600)),
|
||||
volume: Default::default(),
|
||||
network: Default::default(),
|
||||
};
|
||||
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
|
|
@ -742,6 +748,7 @@ mod tests {
|
|||
},
|
||||
warnings: vec![],
|
||||
volume: None,
|
||||
network: NetworkRequirement::default(),
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
}],
|
||||
|
|
@ -749,6 +756,7 @@ mod tests {
|
|||
default_warnings: vec![],
|
||||
default_max_run: Some(Duration::from_secs(3600)),
|
||||
volume: Default::default(),
|
||||
network: Default::default(),
|
||||
};
|
||||
|
||||
let store = Arc::new(SqliteStore::in_memory().unwrap());
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@ nix = { workspace = true }
|
|||
async-trait = "0.1"
|
||||
dirs = "5.0"
|
||||
shell-escape = "0.1"
|
||||
chrono = { workspace = true }
|
||||
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
|
||||
netlink-sys = "0.8"
|
||||
netlink-packet-core = "0.7"
|
||||
netlink-packet-route = "0.21"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
|
|
|||
497
crates/shepherd-host-linux/src/connectivity.rs
Normal file
497
crates/shepherd-host-linux/src/connectivity.rs
Normal file
|
|
@ -0,0 +1,497 @@
|
|||
//! Network connectivity monitoring for Linux
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - Periodic connectivity checks to a configurable URL
|
||||
//! - Network interface change detection via netlink
|
||||
//! - Per-entry connectivity status tracking
|
||||
|
||||
#![allow(dead_code)] // Methods on ConnectivityMonitor may be used for future admin commands
|
||||
|
||||
use chrono::{DateTime, Local};
|
||||
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
|
||||
use netlink_packet_route::RouteNetlinkMessage;
|
||||
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::AsRawFd;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// Events emitted by the connectivity monitor
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConnectivityEvent {
|
||||
/// Global connectivity status changed
|
||||
StatusChanged {
|
||||
connected: bool,
|
||||
check_url: String,
|
||||
},
|
||||
/// Network interface changed (may trigger recheck)
|
||||
InterfaceChanged,
|
||||
}
|
||||
|
||||
/// Configuration for the connectivity monitor
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectivityConfig {
|
||||
/// URL to check for global network connectivity
|
||||
pub check_url: String,
|
||||
/// How often to perform periodic connectivity checks
|
||||
pub check_interval: Duration,
|
||||
/// Timeout for connectivity checks
|
||||
pub check_timeout: Duration,
|
||||
}
|
||||
|
||||
/// Cached connectivity check result
|
||||
#[derive(Debug, Clone)]
|
||||
struct CheckResult {
|
||||
connected: bool,
|
||||
checked_at: DateTime<Local>,
|
||||
}
|
||||
|
||||
/// Connectivity monitor that tracks network availability
|
||||
pub struct ConnectivityMonitor {
|
||||
/// HTTP client for connectivity checks
|
||||
client: Client,
|
||||
/// Configuration
|
||||
config: ConnectivityConfig,
|
||||
/// Current global connectivity status
|
||||
global_status: Arc<RwLock<Option<CheckResult>>>,
|
||||
/// Cached results for specific URLs (entry-specific checks)
|
||||
url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||
/// Channel for sending events
|
||||
event_tx: mpsc::Sender<ConnectivityEvent>,
|
||||
/// Shutdown signal
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl ConnectivityMonitor {
|
||||
/// Create a new connectivity monitor
|
||||
pub fn new(
|
||||
config: ConnectivityConfig,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
) -> (Self, mpsc::Receiver<ConnectivityEvent>) {
|
||||
let (event_tx, event_rx) = mpsc::channel(32);
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(config.check_timeout)
|
||||
.connect_timeout(config.check_timeout)
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
let monitor = Self {
|
||||
client,
|
||||
config,
|
||||
global_status: Arc::new(RwLock::new(None)),
|
||||
url_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
event_tx,
|
||||
shutdown_rx,
|
||||
};
|
||||
|
||||
(monitor, event_rx)
|
||||
}
|
||||
|
||||
/// Start the connectivity monitor (runs until shutdown)
|
||||
pub async fn run(self) {
|
||||
let check_interval = self.config.check_interval;
|
||||
let check_url = self.config.check_url.clone();
|
||||
|
||||
// Spawn periodic check task
|
||||
let periodic_handle = {
|
||||
let client = self.client.clone();
|
||||
let global_status = self.global_status.clone();
|
||||
let event_tx = self.event_tx.clone();
|
||||
let check_url = check_url.clone();
|
||||
let check_timeout = self.config.check_timeout;
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(check_interval);
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
// Do initial check immediately
|
||||
let connected = check_url_reachable(&client, &check_url, check_timeout).await;
|
||||
update_global_status(&global_status, &event_tx, &check_url, connected).await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = interval.tick() => {
|
||||
let connected = check_url_reachable(&client, &check_url, check_timeout).await;
|
||||
update_global_status(&global_status, &event_tx, &check_url, connected).await;
|
||||
}
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
debug!("Periodic check task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Spawn netlink monitor task
|
||||
let netlink_handle = {
|
||||
let client = self.client.clone();
|
||||
let global_status = self.global_status.clone();
|
||||
let url_cache = self.url_cache.clone();
|
||||
let event_tx = self.event_tx.clone();
|
||||
let check_url = check_url.clone();
|
||||
let check_timeout = self.config.check_timeout;
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = run_netlink_monitor(
|
||||
&client,
|
||||
&global_status,
|
||||
&url_cache,
|
||||
&event_tx,
|
||||
&check_url,
|
||||
check_timeout,
|
||||
&mut shutdown,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(error = %e, "Netlink monitor failed, network change detection unavailable");
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Wait for shutdown
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
let _ = shutdown.changed().await;
|
||||
|
||||
// Cancel tasks
|
||||
periodic_handle.abort();
|
||||
netlink_handle.abort();
|
||||
|
||||
info!("Connectivity monitor stopped");
|
||||
}
|
||||
|
||||
/// Get the current global connectivity status
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
self.global_status
|
||||
.read()
|
||||
.await
|
||||
.as_ref()
|
||||
.is_some_and(|r| r.connected)
|
||||
}
|
||||
|
||||
/// Get the last check time
|
||||
pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
|
||||
self.global_status.read().await.as_ref().map(|r| r.checked_at)
|
||||
}
|
||||
|
||||
/// Check if a specific URL is reachable (with caching)
|
||||
/// Used for entry-specific network requirements
|
||||
pub async fn check_url(&self, url: &str) -> bool {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.url_cache.read().await;
|
||||
if let Some(result) = cache.get(url) {
|
||||
// Cache valid for half the check interval
|
||||
let cache_ttl = self.config.check_interval / 2;
|
||||
let age = shepherd_util::now()
|
||||
.signed_duration_since(result.checked_at)
|
||||
.to_std()
|
||||
.unwrap_or(Duration::MAX);
|
||||
if age < cache_ttl {
|
||||
return result.connected;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform check
|
||||
let connected = check_url_reachable(&self.client, url, self.config.check_timeout).await;
|
||||
|
||||
// Update cache
|
||||
{
|
||||
let mut cache = self.url_cache.write().await;
|
||||
cache.insert(
|
||||
url.to_string(),
|
||||
CheckResult {
|
||||
connected,
|
||||
checked_at: shepherd_util::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
connected
|
||||
}
|
||||
|
||||
/// Force an immediate connectivity recheck
|
||||
pub async fn trigger_recheck(&self) {
|
||||
let connected =
|
||||
check_url_reachable(&self.client, &self.config.check_url, self.config.check_timeout)
|
||||
.await;
|
||||
update_global_status(
|
||||
&self.global_status,
|
||||
&self.event_tx,
|
||||
&self.config.check_url,
|
||||
connected,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Clear URL cache to force rechecks
|
||||
self.url_cache.write().await.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a URL is reachable
|
||||
async fn check_url_reachable(client: &Client, url: &str, timeout: Duration) -> bool {
|
||||
debug!(url = %url, "Checking connectivity");
|
||||
|
||||
match client
|
||||
.get(url)
|
||||
.timeout(timeout)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
let connected = status.is_success() || status.as_u16() == 204;
|
||||
debug!(url = %url, status = %status, connected = connected, "Connectivity check complete");
|
||||
connected
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(url = %url, error = %e, "Connectivity check failed");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Update global status and emit event if changed
|
||||
async fn update_global_status(
|
||||
global_status: &Arc<RwLock<Option<CheckResult>>>,
|
||||
event_tx: &mpsc::Sender<ConnectivityEvent>,
|
||||
check_url: &str,
|
||||
connected: bool,
|
||||
) {
|
||||
let mut status = global_status.write().await;
|
||||
let previous = status.as_ref().map(|r| r.connected);
|
||||
|
||||
*status = Some(CheckResult {
|
||||
connected,
|
||||
checked_at: shepherd_util::now(),
|
||||
});
|
||||
|
||||
// Emit event if status changed
|
||||
if previous != Some(connected) {
|
||||
info!(
|
||||
connected = connected,
|
||||
url = %check_url,
|
||||
"Global connectivity status changed"
|
||||
);
|
||||
let _ = event_tx
|
||||
.send(ConnectivityEvent::StatusChanged {
|
||||
connected,
|
||||
check_url: check_url.to_string(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the netlink monitor to detect network interface changes
|
||||
async fn run_netlink_monitor(
|
||||
client: &Client,
|
||||
global_status: &Arc<RwLock<Option<CheckResult>>>,
|
||||
url_cache: &Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||
event_tx: &mpsc::Sender<ConnectivityEvent>,
|
||||
check_url: &str,
|
||||
check_timeout: Duration,
|
||||
shutdown: &mut watch::Receiver<bool>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Create netlink socket for route notifications
|
||||
let mut socket = Socket::new(NETLINK_ROUTE)?;
|
||||
|
||||
// Bind to multicast groups for link and address changes
|
||||
// RTMGRP_LINK = 1, RTMGRP_IPV4_IFADDR = 0x10, RTMGRP_IPV6_IFADDR = 0x100
|
||||
let groups = 1 | 0x10 | 0x100;
|
||||
let addr = SocketAddr::new(0, groups);
|
||||
socket.bind(&addr)?;
|
||||
|
||||
// Set non-blocking for async compatibility
|
||||
socket.set_non_blocking(true)?;
|
||||
|
||||
info!("Netlink monitor started");
|
||||
|
||||
let fd = socket.as_raw_fd();
|
||||
let mut buf = vec![0u8; 4096];
|
||||
|
||||
loop {
|
||||
// Use tokio's async fd for the socket
|
||||
let async_fd = tokio::io::unix::AsyncFd::new(fd)?;
|
||||
|
||||
tokio::select! {
|
||||
result = async_fd.readable() => {
|
||||
match result {
|
||||
Ok(mut guard) => {
|
||||
// Try to read from socket
|
||||
match socket.recv(&mut buf, 0) {
|
||||
Ok(len) if len > 0 => {
|
||||
// Parse netlink messages
|
||||
if has_relevant_netlink_event(&buf[..len]) {
|
||||
debug!("Network interface change detected");
|
||||
let _ = event_tx.send(ConnectivityEvent::InterfaceChanged).await;
|
||||
|
||||
// Clear URL cache
|
||||
url_cache.write().await.clear();
|
||||
|
||||
// Recheck connectivity after a short delay
|
||||
// (give network time to stabilize)
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
let connected = check_url_reachable(client, check_url, check_timeout).await;
|
||||
update_global_status(global_status, event_tx, check_url, connected).await;
|
||||
}
|
||||
guard.clear_ready();
|
||||
}
|
||||
Ok(_) => {
|
||||
guard.clear_ready();
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||
guard.clear_ready();
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Netlink recv error");
|
||||
guard.clear_ready();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Async fd error");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
debug!("Netlink monitor shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a netlink message buffer contains relevant network events
|
||||
fn has_relevant_netlink_event(buf: &[u8]) -> bool {
|
||||
let mut offset = 0;
|
||||
|
||||
while offset < buf.len() {
|
||||
match NetlinkMessage::<RouteNetlinkMessage>::deserialize(&buf[offset..]) {
|
||||
Ok(msg) => {
|
||||
if let NetlinkPayload::InnerMessage(route_msg) = &msg.payload
|
||||
&& matches!(
|
||||
route_msg,
|
||||
// Link up/down events
|
||||
RouteNetlinkMessage::NewLink(_)
|
||||
| RouteNetlinkMessage::DelLink(_)
|
||||
// Address added/removed
|
||||
| RouteNetlinkMessage::NewAddress(_)
|
||||
| RouteNetlinkMessage::DelAddress(_)
|
||||
// Route changes
|
||||
| RouteNetlinkMessage::NewRoute(_)
|
||||
| RouteNetlinkMessage::DelRoute(_)
|
||||
)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Move to next message
|
||||
let len = msg.header.length as usize;
|
||||
if len == 0 {
|
||||
break;
|
||||
}
|
||||
offset += len;
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Handle for accessing connectivity status from other parts of the service
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectivityHandle {
|
||||
client: Client,
|
||||
global_status: Arc<RwLock<Option<CheckResult>>>,
|
||||
url_cache: Arc<RwLock<HashMap<String, CheckResult>>>,
|
||||
check_timeout: Duration,
|
||||
cache_ttl: Duration,
|
||||
global_check_url: String,
|
||||
}
|
||||
|
||||
impl ConnectivityHandle {
|
||||
/// Create a handle from the monitor
|
||||
pub fn from_monitor(monitor: &ConnectivityMonitor) -> Self {
|
||||
Self {
|
||||
client: monitor.client.clone(),
|
||||
global_status: monitor.global_status.clone(),
|
||||
url_cache: monitor.url_cache.clone(),
|
||||
check_timeout: monitor.config.check_timeout,
|
||||
cache_ttl: monitor.config.check_interval / 2,
|
||||
global_check_url: monitor.config.check_url.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current global connectivity status
|
||||
pub async fn is_connected(&self) -> bool {
|
||||
self.global_status
|
||||
.read()
|
||||
.await
|
||||
.as_ref()
|
||||
.is_some_and(|r| r.connected)
|
||||
}
|
||||
|
||||
/// Get the last check time
|
||||
pub async fn last_check_time(&self) -> Option<DateTime<Local>> {
|
||||
self.global_status.read().await.as_ref().map(|r| r.checked_at)
|
||||
}
|
||||
|
||||
/// Get the global check URL
|
||||
pub fn global_check_url(&self) -> &str {
|
||||
&self.global_check_url
|
||||
}
|
||||
|
||||
/// Check if a specific URL is reachable (with caching)
|
||||
pub async fn check_url(&self, url: &str) -> bool {
|
||||
// If it's the global URL, use global status
|
||||
if url == self.global_check_url {
|
||||
return self.is_connected().await;
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.url_cache.read().await;
|
||||
if let Some(result) = cache.get(url) {
|
||||
let age = shepherd_util::now()
|
||||
.signed_duration_since(result.checked_at)
|
||||
.to_std()
|
||||
.unwrap_or(Duration::MAX);
|
||||
if age < self.cache_ttl {
|
||||
return result.connected;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform check
|
||||
let connected = check_url_reachable(&self.client, url, self.check_timeout).await;
|
||||
|
||||
// Update cache
|
||||
{
|
||||
let mut cache = self.url_cache.write().await;
|
||||
cache.insert(
|
||||
url.to_string(),
|
||||
CheckResult {
|
||||
connected,
|
||||
checked_at: shepherd_util::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
connected
|
||||
}
|
||||
}
|
||||
|
|
@ -6,11 +6,14 @@
|
|||
//! - Exit observation
|
||||
//! - stdout/stderr capture
|
||||
//! - Volume control with auto-detection of sound systems
|
||||
//! - Network connectivity monitoring via netlink
|
||||
|
||||
mod adapter;
|
||||
mod connectivity;
|
||||
mod process;
|
||||
mod volume;
|
||||
|
||||
pub use adapter::*;
|
||||
pub use connectivity::*;
|
||||
pub use process::*;
|
||||
pub use volume::*;
|
||||
|
|
|
|||
|
|
@ -252,5 +252,6 @@ fn reason_to_message(reason: &ReasonCode) -> &'static str {
|
|||
ReasonCode::SessionActive { .. } => "Another session is active",
|
||||
ReasonCode::UnsupportedKind { .. } => "Entry type not supported",
|
||||
ReasonCode::Disabled { .. } => "Entry disabled",
|
||||
ReasonCode::NetworkUnavailable { .. } => "Network connection required",
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -117,6 +117,10 @@ impl SharedState {
|
|||
EventPayload::VolumeChanged { .. } => {
|
||||
// Volume events are handled by HUD
|
||||
}
|
||||
EventPayload::ConnectivityChanged { .. } => {
|
||||
// Connectivity changes may affect entry availability - request fresh state
|
||||
self.set(LauncherState::Connecting);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ tracing-subscriber = { workspace = true }
|
|||
tokio = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
clap = { version = "4.5", features = ["derive", "env"] }
|
||||
nix = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -8,17 +8,22 @@
|
|||
//! - Host adapter (Linux)
|
||||
//! - IPC server
|
||||
//! - Volume control
|
||||
//! - Network connectivity monitoring
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use shepherd_api::{
|
||||
Command, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
|
||||
Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo, VolumeRestrictions,
|
||||
Command, ConnectivityStatus, ErrorCode, ErrorInfo, Event, EventPayload, HealthStatus,
|
||||
ReasonCode, Response, ResponsePayload, SessionEndReason, StopMode, VolumeInfo,
|
||||
VolumeRestrictions,
|
||||
};
|
||||
use shepherd_config::{load_config, VolumePolicy};
|
||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision, StopDecision};
|
||||
use shepherd_host_api::{HostAdapter, HostEvent, StopMode as HostStopMode, VolumeController};
|
||||
use shepherd_host_linux::{LinuxHost, LinuxVolumeController};
|
||||
use shepherd_host_linux::{
|
||||
ConnectivityConfig, ConnectivityEvent, ConnectivityHandle, ConnectivityMonitor, LinuxHost,
|
||||
LinuxVolumeController,
|
||||
};
|
||||
use shepherd_ipc::{IpcServer, ServerMessage};
|
||||
use shepherd_store::{AuditEvent, AuditEventType, SqliteStore, Store};
|
||||
use shepherd_util::{default_config_path, ClientId, MonotonicInstant, RateLimiter};
|
||||
|
|
@ -26,7 +31,7 @@ use std::path::PathBuf;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{mpsc, watch, Mutex};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
|
|
@ -60,10 +65,12 @@ struct Service {
|
|||
ipc: Arc<IpcServer>,
|
||||
store: Arc<dyn Store>,
|
||||
rate_limiter: RateLimiter,
|
||||
connectivity: ConnectivityHandle,
|
||||
shutdown_tx: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
async fn new(args: &Args) -> Result<Self> {
|
||||
async fn new(args: &Args) -> Result<(Self, mpsc::Receiver<ConnectivityEvent>)> {
|
||||
// Load configuration
|
||||
let policy = load_config(&args.config)
|
||||
.with_context(|| format!("Failed to load config from {:?}", args.config))?;
|
||||
|
|
@ -116,6 +123,7 @@ impl Service {
|
|||
}
|
||||
|
||||
// Initialize core engine
|
||||
let network_policy = policy.network.clone();
|
||||
let engine = CoreEngine::new(policy, store.clone(), host.capabilities().clone());
|
||||
|
||||
// Initialize IPC server
|
||||
|
|
@ -127,17 +135,43 @@ impl Service {
|
|||
// Rate limiter: 30 requests per second per client
|
||||
let rate_limiter = RateLimiter::new(30, Duration::from_secs(1));
|
||||
|
||||
Ok(Self {
|
||||
engine,
|
||||
host,
|
||||
volume,
|
||||
ipc: Arc::new(ipc),
|
||||
store,
|
||||
rate_limiter,
|
||||
})
|
||||
// Initialize connectivity monitor
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let connectivity_config = ConnectivityConfig {
|
||||
check_url: network_policy.check_url,
|
||||
check_interval: network_policy.check_interval,
|
||||
check_timeout: network_policy.check_timeout,
|
||||
};
|
||||
let (connectivity_monitor, connectivity_events) =
|
||||
ConnectivityMonitor::new(connectivity_config, shutdown_rx);
|
||||
let connectivity = ConnectivityHandle::from_monitor(&connectivity_monitor);
|
||||
|
||||
// Spawn connectivity monitor task
|
||||
tokio::spawn(async move {
|
||||
connectivity_monitor.run().await;
|
||||
});
|
||||
|
||||
info!(
|
||||
check_url = %connectivity.global_check_url(),
|
||||
"Connectivity monitor started"
|
||||
);
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
engine,
|
||||
host,
|
||||
volume,
|
||||
ipc: Arc::new(ipc),
|
||||
store,
|
||||
rate_limiter,
|
||||
connectivity,
|
||||
shutdown_tx,
|
||||
},
|
||||
connectivity_events,
|
||||
))
|
||||
}
|
||||
|
||||
async fn run(self) -> Result<()> {
|
||||
async fn run(self, mut connectivity_events: mpsc::Receiver<ConnectivityEvent>) -> Result<()> {
|
||||
// Start host process monitor
|
||||
let _monitor_handle = self.host.start_monitor();
|
||||
|
||||
|
|
@ -155,6 +189,8 @@ impl Service {
|
|||
let host = self.host.clone();
|
||||
let volume = self.volume.clone();
|
||||
let store = self.store.clone();
|
||||
let connectivity = self.connectivity.clone();
|
||||
let shutdown_tx = self.shutdown_tx.clone();
|
||||
|
||||
// Spawn IPC accept task
|
||||
let ipc_accept = ipc_ref.clone();
|
||||
|
|
@ -218,7 +254,12 @@ impl Service {
|
|||
|
||||
// IPC messages
|
||||
Some(msg) = ipc_messages.recv() => {
|
||||
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, msg).await;
|
||||
Self::handle_ipc_message(&engine, &host, &volume, &ipc_ref, &store, &rate_limiter, &connectivity, msg).await;
|
||||
}
|
||||
|
||||
// Connectivity events
|
||||
Some(conn_event) = connectivity_events.recv() => {
|
||||
Self::handle_connectivity_event(&engine, &ipc_ref, &connectivity, conn_event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -226,6 +267,9 @@ impl Service {
|
|||
// Graceful shutdown
|
||||
info!("Shutting down shepherdd");
|
||||
|
||||
// Signal connectivity monitor to stop
|
||||
let _ = shutdown_tx.send(true);
|
||||
|
||||
// Stop all running sessions
|
||||
{
|
||||
let engine = engine.lock().await;
|
||||
|
|
@ -433,6 +477,45 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
async fn handle_connectivity_event(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
connectivity: &ConnectivityHandle,
|
||||
event: ConnectivityEvent,
|
||||
) {
|
||||
match event {
|
||||
ConnectivityEvent::StatusChanged {
|
||||
connected,
|
||||
check_url,
|
||||
} => {
|
||||
info!(connected = connected, url = %check_url, "Connectivity status changed");
|
||||
|
||||
// Broadcast connectivity change event
|
||||
ipc.broadcast_event(Event::new(EventPayload::ConnectivityChanged {
|
||||
connected,
|
||||
check_url,
|
||||
}));
|
||||
|
||||
// Also broadcast state change so clients can update entry availability
|
||||
let state = {
|
||||
let eng = engine.lock().await;
|
||||
let mut state = eng.get_state();
|
||||
state.connectivity = ConnectivityStatus {
|
||||
connected: connectivity.is_connected().await,
|
||||
check_url: Some(connectivity.global_check_url().to_string()),
|
||||
last_check: connectivity.last_check_time().await,
|
||||
};
|
||||
state
|
||||
};
|
||||
ipc.broadcast_event(Event::new(EventPayload::StateChanged(state)));
|
||||
}
|
||||
|
||||
ConnectivityEvent::InterfaceChanged => {
|
||||
debug!("Network interface changed, connectivity recheck in progress");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_ipc_message(
|
||||
engine: &Arc<Mutex<CoreEngine>>,
|
||||
host: &Arc<LinuxHost>,
|
||||
|
|
@ -440,6 +523,7 @@ impl Service {
|
|||
ipc: &Arc<IpcServer>,
|
||||
store: &Arc<dyn Store>,
|
||||
rate_limiter: &Arc<Mutex<RateLimiter>>,
|
||||
connectivity: &ConnectivityHandle,
|
||||
msg: ServerMessage,
|
||||
) {
|
||||
match msg {
|
||||
|
|
@ -458,7 +542,7 @@ impl Service {
|
|||
}
|
||||
|
||||
let response =
|
||||
Self::handle_command(engine, host, volume, ipc, store, &client_id, request.request_id, request.command)
|
||||
Self::handle_command(engine, host, volume, ipc, store, connectivity, &client_id, request.request_id, request.command)
|
||||
.await;
|
||||
|
||||
let _ = ipc.send_response(&client_id, response).await;
|
||||
|
|
@ -504,6 +588,7 @@ impl Service {
|
|||
volume: &Arc<LinuxVolumeController>,
|
||||
ipc: &Arc<IpcServer>,
|
||||
store: &Arc<dyn Store>,
|
||||
connectivity: &ConnectivityHandle,
|
||||
client_id: &ClientId,
|
||||
request_id: u64,
|
||||
command: Command,
|
||||
|
|
@ -513,7 +598,13 @@ impl Service {
|
|||
|
||||
match command {
|
||||
Command::GetState => {
|
||||
let state = engine.lock().await.get_state();
|
||||
let mut state = engine.lock().await.get_state();
|
||||
// Add connectivity status
|
||||
state.connectivity = ConnectivityStatus {
|
||||
connected: connectivity.is_connected().await,
|
||||
check_url: Some(connectivity.global_check_url().to_string()),
|
||||
last_check: connectivity.last_check_time().await,
|
||||
};
|
||||
Response::success(request_id, ResponsePayload::State(state))
|
||||
}
|
||||
|
||||
|
|
@ -526,6 +617,30 @@ impl Service {
|
|||
Command::Launch { entry_id } => {
|
||||
let mut eng = engine.lock().await;
|
||||
|
||||
// First check if the entry requires network and if it's available
|
||||
if let Some(entry) = eng.policy().get_entry(&entry_id)
|
||||
&& entry.network.required
|
||||
{
|
||||
let check_url = entry.network.effective_check_url(&eng.policy().network);
|
||||
let network_ok = connectivity.check_url(check_url).await;
|
||||
|
||||
if !network_ok {
|
||||
info!(
|
||||
entry_id = %entry_id,
|
||||
check_url = %check_url,
|
||||
"Launch denied: network connectivity check failed"
|
||||
);
|
||||
return Response::success(
|
||||
request_id,
|
||||
ResponsePayload::LaunchDenied {
|
||||
reasons: vec![ReasonCode::NetworkUnavailable {
|
||||
check_url: check_url.to_string(),
|
||||
}],
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match eng.request_launch(&entry_id, now) {
|
||||
LaunchDecision::Approved(plan) => {
|
||||
// Start the session in the engine
|
||||
|
|
@ -939,6 +1054,6 @@ async fn main() -> Result<()> {
|
|||
);
|
||||
|
||||
// Create and run the service
|
||||
let service = Service::new(&args).await?;
|
||||
service.run().await
|
||||
let (service, connectivity_events) = Service::new(&args).await?;
|
||||
service.run(connectivity_events).await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
//! These tests verify the end-to-end behavior of shepherdd.
|
||||
|
||||
use shepherd_api::{EntryKind, WarningSeverity, WarningThreshold};
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, Policy};
|
||||
use shepherd_config::{AvailabilityPolicy, Entry, LimitsPolicy, NetworkRequirement, Policy};
|
||||
use shepherd_core::{CoreEngine, CoreEvent, LaunchDecision};
|
||||
use shepherd_host_api::{HostCapabilities, MockHost};
|
||||
use shepherd_store::{SqliteStore, Store};
|
||||
|
|
@ -48,6 +48,7 @@ fn make_test_policy() -> Policy {
|
|||
},
|
||||
],
|
||||
volume: None,
|
||||
network: NetworkRequirement::default(),
|
||||
disabled: false,
|
||||
disabled_reason: None,
|
||||
},
|
||||
|
|
@ -55,6 +56,7 @@ fn make_test_policy() -> Policy {
|
|||
default_warnings: vec![],
|
||||
default_max_run: Some(Duration::from_secs(3600)),
|
||||
volume: Default::default(),
|
||||
network: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue