package alert
import (
"crypto/sha256"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/processctx"
)
// alertDispatchFailuresMetric is the metric name under which the
// dispatch-failure counter below is created and registered.
const alertDispatchFailuresMetric = "csm_alert_dispatch_failures_total"
// alertDispatchFailures counts individual channel send failures (email,
// webhook, phpanel) so an operator can see when alerts are silently failing
// to deliver instead of the daemon looking healthy while findings never
// reach anyone. It is incremented once per error recorded through
// addDispatchError.
var alertDispatchFailures = metrics.NewCounter(
alertDispatchFailuresMetric,
"Alert deliveries that failed (email/webhook/phpanel). Sustained growth means findings are being detected but not reaching operators -- check SMTP/webhook reachability and credentials.",
)
// init registers the dispatch-failure counter with the metrics package
// under its metric name when the alert package is loaded.
// NOTE(review): MustRegister presumably panics on a duplicate or invalid
// registration -- confirm against internal/metrics.
func init() {
metrics.MustRegister(alertDispatchFailuresMetric, alertDispatchFailures)
}
// addDispatchError records one failed channel delivery: err is appended to
// the caller's error list and the dispatch-failure counter is bumped by one.
func addDispatchError(errs *[]error, err error) {
	collected := append(*errs, err)
	*errs = collected
	alertDispatchFailures.Inc()
}
// Severity is the urgency level attached to a finding. Values are ordered
// from least (Warning) to most (Critical) urgent.
type Severity int

const (
	Warning Severity = iota
	High
	Critical
)

// String returns the upper-case label for s, or "UNKNOWN" for any value
// outside the declared levels.
func (s Severity) String() string {
	labels := [...]string{Warning: "WARNING", High: "HIGH", Critical: "CRITICAL"}
	if s >= 0 && int(s) < len(labels) {
		return labels[s]
	}
	return "UNKNOWN"
}
// Finding represents a single security check result. It is the unit that
// flows through Dispatch: it is deduplicated by Key, rendered for
// operators by String, and serialised with the JSON tags below.
type Finding struct {
// Severity is the urgency level; findings are grouped and counted by it
// when an alert body is formatted.
Severity Severity `json:"severity"`
// Check is the identifier of the check that produced the finding. It is
// part of the deduplication key and fingerprint.
Check string `json:"check"`
// Message is the one-line summary; Details carries optional multi-line
// context. Both are redacted before being placed in an alert body.
Message string `json:"message"`
Details string `json:"details,omitempty"`
// FilePath is the file the finding refers to, when there is one.
FilePath string `json:"file_path,omitempty"`
ProcessInfo string `json:"process_info,omitempty"` // "pid=N cmd=name uid=N" from fanotify
PID int `json:"pid,omitempty"` // structured PID for auto-response
// PHP-relay structured fields (Stage 1 email_php_relay_abuse). All optional;
// zero values mean "this finding does not carry that dimension".
Path string `json:"path,omitempty"` // path1 trigger label: "header" | "volume" | "volume_account" | "fanout" | "baseline" | "reputation"
MsgIDs []string `json:"msg_ids,omitempty"` // sample of in-flight msgIDs (auto-action acts on the live snapshot, not this list)
ScriptKey string `json:"script_key,omitempty"` // host:path from X-PHP-Script
SourceIP string `json:"source_ip,omitempty"` // IP after "for " in X-PHP-Script
CPUser string `json:"cp_user,omitempty"` // cPanel user from spool -H line 2
// RelayTotal is the trigger count for the PHP-relay path that fired
// (qualifying/volume/fanout/account-window count). RelayBreakdown lists
// the scripts that contributed, with per-script hit counts and a sample
// subject. Both optional; volume_account carries RelayTotal with no
// breakdown (account log-tail path has no trustworthy script key).
RelayTotal int `json:"relay_total,omitempty"`
RelayBreakdown []RelayScriptHit `json:"relay_breakdown,omitempty"`
// Tenant context (added v2.12.0). Optional - populated when the check
// has enough info to attribute the finding to a specific tenant within
// a multi-tenant host. Empty strings render as omitted JSON keys so
// existing webhook consumers see no diff.
TenantID string `json:"tenant_id,omitempty"`
Domain string `json:"domain,omitempty"`
Mailbox string `json:"mailbox,omitempty"`
// Process context (Phase 1 process-ancestry enrichment). Optional.
// Populated by exec/connection live monitors when cache or enricher
// has data. Omitted from JSON when nil so existing webhook consumers
// see no diff.
Process *processctx.ProcessContext `json:"process,omitempty"`
// Timestamp is when the finding was produced. It is printed by String
// and is always present in the JSON encoding (no omitempty).
Timestamp time.Time `json:"timestamp"`
}
// RelayScriptHit is one script's contribution to a PHP-relay finding. A
// slice of these is carried in Finding.RelayBreakdown.
type RelayScriptHit struct {
ScriptKey string `json:"script_key"` // "host:/path" from X-PHP-Script
Hits int `json:"hits"` // messages counted in the path window
LastSeen time.Time `json:"last_seen"` // NOTE(review): presumably the time of the most recent counted message -- confirm at the producer
SampleSubject string `json:"sample_subject,omitempty"` // attacker-controlled; render escaped
}
// String renders the finding as a multi-line, human-readable block: a
// "[SEVERITY] check - message" header, the details (each continuation line
// prefixed with a space) when present, the process info when present, and
// finally the timestamp.
func (f Finding) String() string {
	var out strings.Builder
	fmt.Fprintf(&out, "[%s] %s - %s", f.Severity, f.Check, f.Message)
	if f.Details != "" {
		out.WriteString("\n ")
		out.WriteString(strings.ReplaceAll(f.Details, "\n", "\n "))
	}
	if f.ProcessInfo != "" {
		fmt.Fprintf(&out, "\n Process: %s", f.ProcessInfo)
	}
	fmt.Fprintf(&out, "\n Time: %s", f.Timestamp.Format("2006-01-02 15:04:05"))
	return out.String()
}
// Key returns the deduplication key for the finding. Checks that are keyed
// by source IP use the "<check>:ip:<addr>" key; everything else uses
// "<check>:<message>", extended with a short hash of Details when Details
// is non-empty.
func (f Finding) Key() string {
	if ipKey := f.sourceIPKey(); ipKey != "" {
		return ipKey
	}
	base := fmt.Sprintf("%s:%s", f.Check, f.Message)
	if f.Details == "" {
		return base
	}
	digest := sha256.Sum256([]byte(f.Details))
	return fmt.Sprintf("%s:%x", base, digest[:4])
}
// Fingerprint returns the content hash used by alert-state deduplication:
// the first 8 bytes (hex) of the SHA-256 of either the source-IP key, when
// the check is IP-keyed, or of "check:message:details".
func (f Finding) Fingerprint() string {
	material := f.sourceIPKey()
	if material == "" {
		material = fmt.Sprintf("%s:%s:%s", f.Check, f.Message, f.Details)
	}
	sum := sha256.Sum256([]byte(material))
	return fmt.Sprintf("%x", sum[:8])
}
// sourceIPKey returns "<check>:ip:<addr>" for the checks that are
// deduplicated per source address, and "" for every other check or when no
// address can be determined. The address comes from SourceIP when that
// parses as an IP, otherwise from the message text.
func (f Finding) sourceIPKey() string {
	ipKeyed := false
	switch f.Check {
	case "admin_panel_bruteforce", "wp_login_bruteforce", "wp_user_enumeration", "xmlrpc_abuse",
		"http_request_flood", "http_ua_spoof":
		ipKeyed = true
	}
	if !ipKeyed {
		return ""
	}

	addr := normalizeFindingIP(f.SourceIP)
	if addr == "" {
		addr = sourceIPFromFindingMessage(f.Message)
	}
	if addr == "" {
		return ""
	}
	return fmt.Sprintf("%s:ip:%s", f.Check, addr)
}
// normalizeFindingIP reduces raw to the canonical textual form of an IP
// address. Surrounding whitespace, a ":port" suffix and IPv6 brackets are
// stripped first; anything that still does not parse as an IP yields "".
func normalizeFindingIP(raw string) string {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}
	// "host:port" and "[v6]:port" forms: keep only the host part.
	host, _, err := net.SplitHostPort(candidate)
	if err == nil {
		candidate = host
	}
	candidate = strings.Trim(candidate, "[]")
	if parsed := net.ParseIP(candidate); parsed != nil {
		return parsed.String()
	}
	return ""
}
// sourceIPFromFindingMessage tries to pull an IP address out of a finding
// message. It looks at the first whitespace-delimited token after the last
// " from " and, failing that, after the last ": ", trims trailing
// punctuation from it, and returns its normalized form. It returns "" when
// neither position yields a valid IP.
func sourceIPFromFindingMessage(msg string) string {
	for _, marker := range [...]string{" from ", ": "} {
		pos := strings.LastIndex(msg, marker)
		if pos == -1 {
			continue
		}
		tokens := strings.Fields(msg[pos+len(marker):])
		if len(tokens) > 0 {
			trimmed := strings.TrimRight(tokens[0], ",:;)([]")
			if ip := normalizeFindingIP(trimmed); ip != "" {
				return ip
			}
		}
	}
	return ""
}
// SplitEmail splits addr at its last '@' into (localpart, domain). It
// returns ("", "") when there is no '@', or when either side of the last
// '@' would be empty.
func SplitEmail(addr string) (localpart, domain string) {
	at := strings.LastIndex(addr, "@")
	if at > 0 && at < len(addr)-1 {
		return addr[:at], addr[at+1:]
	}
	return "", ""
}
// Deduplicate drops findings whose Key() has already been seen, keeping the
// first occurrence and preserving input order. The result is nil when no
// finding is kept.
func Deduplicate(findings []Finding) []Finding {
	var unique []Finding
	visited := make(map[string]struct{})
	for _, f := range findings {
		k := f.Key()
		if _, dup := visited[k]; dup {
			continue
		}
		visited[k] = struct{}{}
		unique = append(unique, f)
	}
	return unique
}
// FormatAlert renders findings as a plain-text alert body for hostname: a
// header with the current time and per-severity counts, then every finding
// grouped Critical, High, Warning (input order preserved within a group),
// then a footer. Message and Details are redacted before rendering.
func FormatAlert(hostname string, findings []Finding) string {
	tally := make(map[Severity]int, 3)
	for i := range findings {
		tally[findings[i].Severity]++
	}

	divider := strings.Repeat("─", 60)
	var out strings.Builder
	fmt.Fprintf(&out, "SECURITY ALERT - %s\n", hostname)
	fmt.Fprintf(&out, "Timestamp: %s\n", time.Now().Format("2006-01-02 15:04:05 MST"))
	fmt.Fprintf(&out, "Findings: %d critical, %d high, %d warning\n", tally[Critical], tally[High], tally[Warning])
	out.WriteString(divider + "\n\n")

	for _, level := range [...]Severity{Critical, High, Warning} {
		for _, f := range findings {
			if f.Severity != level {
				continue
			}
			out.WriteString(sanitizeFinding(f).String())
			out.WriteString("\n\n")
		}
	}

	out.WriteString(divider + "\n")
	out.WriteString("CSM - Continuous Security Monitor\n")
	return out.String()
}
// sanitizeFinding returns a copy of f whose Message and Details have had
// sensitive values (passwords, tokens) replaced by redactSensitive. All
// other fields are carried over unchanged.
func sanitizeFinding(f Finding) Finding {
	redacted := f
	redacted.Message = redactSensitive(f.Message)
	redacted.Details = redactSensitive(f.Details)
	return redacted
}
// redactSensitive replaces password and API-token values in s with
// "[REDACTED]" and returns the result. Keys are matched case-insensitively
// (ASCII only) and every occurrence of every key is redacted.
//
// Password keys: password=, pass=, passwd=, new_password=, old_password=,
// confirmpassword=. The value runs up to the next '&', space, newline,
// double or single quote, or comma.
//
// Token keys: token_value=, api_token=. The value runs up to the next
// space, newline or '&'.
//
// Matching is done directly on the bytes of s. An earlier version located
// keys in strings.ToLower(s) and applied those offsets to s; because
// lower-casing can change the byte length (invalid UTF-8 becomes U+FFFD,
// some letters lower-case to a shorter encoding) the offsets could point at
// the wrong bytes and leave the secret in place. It also redacted only the
// first occurrence of each token key.
func redactSensitive(s string) string {
	// Every key ends in '=', so text without one cannot contain a match.
	if s == "" || !strings.Contains(s, "=") {
		return s
	}
	passwordKeys := []string{
		"password=", "pass=", "passwd=", "new_password=",
		"old_password=", "confirmpassword=",
	}
	s = redactValuesAfter(s, passwordKeys, isPasswordValueEnd)

	tokenKeys := []string{"token_value=", "api_token="}
	return redactValuesAfter(s, tokenKeys, isTokenValueEnd)
}

// redactValuesAfter replaces the value following every occurrence of each
// key in lowerKeys (lower-case ASCII) with "[REDACTED]". isEnd reports
// whether a byte terminates a value.
func redactValuesAfter(s string, lowerKeys []string, isEnd func(byte) bool) string {
	const mask = "[REDACTED]"
	for _, key := range lowerKeys {
		from := 0
		for {
			idx := indexASCIIFold(s, key, from)
			if idx < 0 {
				break
			}
			valStart := idx + len(key)
			valEnd := valStart
			for valEnd < len(s) && !isEnd(s[valEnd]) {
				valEnd++
			}
			if valEnd == valStart {
				// Empty value (e.g. `password=&`): step past this
				// occurrence so a later populated field is still redacted.
				from = valStart
				continue
			}
			s = s[:valStart] + mask + s[valEnd:]
			// Resume after the mask so the same key is never re-matched,
			// which would loop forever.
			from = valStart + len(mask)
		}
	}
	return s
}

// indexASCIIFold returns the byte index of the first occurrence of lowerKey
// in s at or after from, comparing ASCII letters case-insensitively, or -1
// if there is none. lowerKey must be lower-case ASCII.
func indexASCIIFold(s, lowerKey string, from int) int {
	for i := from; i+len(lowerKey) <= len(s); i++ {
		j := 0
		for j < len(lowerKey) {
			c := s[i+j]
			if 'A' <= c && c <= 'Z' {
				c += 'a' - 'A'
			}
			if c != lowerKey[j] {
				break
			}
			j++
		}
		if j == len(lowerKey) {
			return i
		}
	}
	return -1
}

// isPasswordValueEnd reports whether c terminates a password value.
func isPasswordValueEnd(c byte) bool {
	switch c {
	case '&', ' ', '\n', '"', '\'', ',':
		return true
	}
	return false
}

// isTokenValueEnd reports whether c terminates a token value.
func isTokenValueEnd(c byte) bool {
	return c == ' ' || c == '\n' || c == '&'
}
// filterChecks returns the findings whose Check is not listed in
// disabledChecks. Entries in disabledChecks are whitespace-trimmed and blank
// entries are ignored. When there is nothing to filter the input slice is
// returned as-is; otherwise a new slice is returned.
func filterChecks(findings []Finding, disabledChecks []string) []Finding {
	if len(findings) == 0 || len(disabledChecks) == 0 {
		return findings
	}

	muted := make(map[string]struct{}, len(disabledChecks))
	for _, name := range disabledChecks {
		if trimmed := strings.TrimSpace(name); trimmed != "" {
			muted[trimmed] = struct{}{}
		}
	}
	if len(muted) == 0 {
		return findings
	}

	kept := make([]Finding, 0, len(findings))
	for _, f := range findings {
		if _, skip := muted[f.Check]; skip {
			continue
		}
		kept = append(kept, f)
	}
	return kept
}
// buildSubject returns the alert subject line for hostname. If any finding
// is Critical the subject carries a CRITICAL marker; otherwise it is the
// generic "N security finding(s)" form.
func buildSubject(hostname string, findings []Finding) string {
	total := len(findings)
	for i := range findings {
		if findings[i].Severity == Critical {
			return fmt.Sprintf("[CSM] CRITICAL - %s - %d finding(s)", hostname, total)
		}
	}
	return fmt.Sprintf("[CSM] %s - %d security finding(s)", hostname, total)
}
// rateLimitState tracks alerts sent per hour. It is the JSON document
// persisted as "ratelimit.json" under the state path.
type rateLimitState struct {
// Hour is the hour bucket the count belongs to, formatted with the
// layout "2006-01-02T15". A different current hour resets the count.
Hour string `json:"hour"`
// Count is the number of dispatches recorded for Hour.
Count int `json:"count"`
}
// FindingBus is set by the daemon at startup to the broadcast.Bus that
// passive observers (e.g. SSE subscribers) drain. nil-safe: Dispatch
// only publishes if non-nil. Importing the broadcast package directly
// would create an import cycle (broadcast imports alert), so this is
// declared as an interface satisfied by *broadcast.Bus.
//
// Dispatch calls Publish once per deduplicated finding and reads this
// variable without any locking.
var FindingBus interface {
Publish(Finding)
}
// ReportHook, when set by the daemon at startup, is called once per
// deduplicated finding so the abuse reporter can consider it for submission to
// a central abuse database or collector. It must not block. Declared as a func
// to avoid an import cycle (the reporting package imports alert for the
// Finding type).
//
// Install or clear it with SetReportHook so Dispatch reads a consistent value.
var ReportHook func(Finding)
// reportHookMu guards ReportHook for the accesses made through
// SetReportHook and currentReportHook.
var reportHookMu sync.RWMutex
// SetReportHook installs h as the abuse-reporting hook used by Dispatch.
// Passing nil clears the hook.
func SetReportHook(h func(Finding)) {
	reportHookMu.Lock()
	defer reportHookMu.Unlock()
	ReportHook = h
}
// currentReportHook returns the currently installed abuse-reporting hook
// (possibly nil), read under the hook's read lock.
func currentReportHook() func(Finding) {
	reportHookMu.RLock()
	defer reportHookMu.RUnlock()
	return ReportHook
}
// callReportHook invokes the installed abuse-reporting hook with f, if one
// is set. A panic raised by the hook is recovered and reported on stderr so
// it cannot abort the dispatch that called it.
func callReportHook(f Finding) {
	hook := currentReportHook()
	if hook != nil {
		defer func() {
			if recover() != nil {
				fmt.Fprintln(os.Stderr, "alert: report hook panic")
			}
		}()
		hook(f)
	}
}
// CentralHook, when set by the daemon, is called once per deduplicated finding
// so the central-intel consumer can escalate (challenge/block) when the
// finding's IP is in the verified central scored-set. A finding firing on an IP
// is the node's own local signal, so this is the local-corroboration path.
// Must not block. Install or clear with SetCentralHook.
var CentralHook func(Finding)
// centralHookMu guards CentralHook for the accesses made through
// SetCentralHook and currentCentralHook.
var centralHookMu sync.RWMutex
// SetCentralHook installs h as the central-intel hook used by Dispatch.
// Passing nil clears the hook.
func SetCentralHook(h func(Finding)) {
	centralHookMu.Lock()
	defer centralHookMu.Unlock()
	CentralHook = h
}
// currentCentralHook returns the currently installed central-intel hook
// (possibly nil), read under the hook's read lock.
func currentCentralHook() func(Finding) {
	centralHookMu.RLock()
	defer centralHookMu.RUnlock()
	return CentralHook
}
// callCentralHook invokes the installed central-intel hook with f, if one
// is set. A panic raised by the hook is recovered and reported on stderr so
// it cannot abort the dispatch that called it.
func callCentralHook(f Finding) {
	hook := currentCentralHook()
	if hook != nil {
		defer func() {
			if recover() != nil {
				fmt.Fprintln(os.Stderr, "alert: central hook panic")
			}
		}()
		hook(f)
	}
}
// rateLimitKey identifies one hourly budget: the state directory the
// counter file lives in plus the hour bucket ("2006-01-02T15" layout).
type rateLimitKey struct {
StatePath string
Hour string
}
// rateLimitReservation is an in-memory claim on one slot of an hourly
// budget, handed out by reserveRateLimit. active is cleared when the claim
// is released so releasing the same reservation twice is a no-op.
type rateLimitReservation struct {
key rateLimitKey
active bool
}
var (
// rateLimitMu serialises every read and write of rateLimitPending and
// of the on-disk rate-limit file done by the functions in this file.
rateLimitMu sync.Mutex
// rateLimitPending counts reservations that have been handed out but
// not yet released or committed, per budget key.
rateLimitPending = make(map[rateLimitKey]int)
)
// reserveRateLimit claims one slot of the current hour's alert budget
// without committing it to disk. It returns (reservation, true) when the
// persisted count plus the reservations still in flight is below
// maxPerHour, and (nil, false) otherwise or when maxPerHour is not
// positive. The in-memory claim keeps concurrent dispatches from all taking
// the same final slot while a send is still blocked on SMTP or webhook I/O.
// An unreadable or unparsable state file is treated as a zero count.
func reserveRateLimit(statePath string, maxPerHour int) (*rateLimitReservation, bool) {
	rateLimitMu.Lock()
	defer rateLimitMu.Unlock()

	if maxPerHour <= 0 {
		return nil, false
	}

	hour := time.Now().Format("2006-01-02T15")
	var persisted rateLimitState
	// #nosec G304 -- filepath.Join(statePath, "ratelimit.json"); statePath from operator config.
	if raw, err := os.ReadFile(filepath.Join(statePath, "ratelimit.json")); err == nil {
		_ = json.Unmarshal(raw, &persisted)
	}

	// Only a count recorded for the current hour bucket applies.
	used := 0
	if persisted.Hour == hour {
		used = persisted.Count
	}

	slot := rateLimitKey{StatePath: statePath, Hour: hour}
	if used+rateLimitPending[slot] >= maxPerHour {
		return nil, false
	}
	rateLimitPending[slot]++
	return &rateLimitReservation{key: slot, active: true}, true
}
// releaseRateLimit gives back an uncommitted reservation. It is a no-op for
// a nil reservation; for a non-nil one it takes the rate-limit lock and
// delegates to releaseRateLimitLocked.
func releaseRateLimit(reservation *rateLimitReservation) {
	if reservation != nil {
		rateLimitMu.Lock()
		defer rateLimitMu.Unlock()
		releaseRateLimitLocked(reservation)
	}
}
// releaseRateLimitLocked decrements the pending count for the reservation's
// key (removing the map entry when it reaches zero) and marks the
// reservation inactive. Nil or already-released reservations are ignored.
// It does not lock; the callers in this file hold rateLimitMu around it.
func releaseRateLimitLocked(reservation *rateLimitReservation) {
	if reservation == nil || !reservation.active {
		return
	}
	outstanding := rateLimitPending[reservation.key]
	if outstanding > 1 {
		rateLimitPending[reservation.key] = outstanding - 1
	} else {
		delete(rateLimitPending, reservation.key)
	}
	reservation.active = false
}
// commitRateLimit records one successful dispatch against the hourly
// budget: it releases the in-memory reservation (nil is accepted), then
// increments the on-disk counter for the current hour, starting from zero
// when the file is missing, unparsable, or belongs to another hour.
// Marshal and write failures are logged to stderr so a disk-full or
// permission regression is visible instead of the counter silently drifting
// from the on-disk record.
func commitRateLimit(statePath string, reservation *rateLimitReservation) {
	rateLimitMu.Lock()
	defer rateLimitMu.Unlock()
	releaseRateLimitLocked(reservation)

	hour := time.Now().Format("2006-01-02T15")
	path := filepath.Join(statePath, "ratelimit.json")

	state := rateLimitState{}
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	raw, readErr := os.ReadFile(path)
	if readErr == nil {
		_ = json.Unmarshal(raw, &state)
	}
	if state.Hour != hour {
		state = rateLimitState{Hour: hour}
	}
	state.Count++

	encoded, err := json.Marshal(state)
	if err != nil {
		fmt.Fprintf(os.Stderr, "alert: rate-limit marshal failed: %v\n", err)
		return
	}
	if err = os.WriteFile(path, encoded, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "alert: rate-limit write failed for %s: %v\n", path, err)
	}
}
// checkRateLimit returns true if we can send more alerts this hour, and in
// that case immediately counts the send in the on-disk state file. It
// returns false once the count for the current hour has reached maxPerHour.
// A missing, unreadable or unparsable state file is treated as a zero
// count.
//
// Persisting the new count is best-effort: the function still returns true
// when the state cannot be marshalled or written, but the failure is logged
// to stderr (matching commitRateLimit) so a disk-full or permission problem
// does not silently disable the limit.
//
// Deprecated: kept for callers that expect the check-and-increment pattern.
// New code should use reserveRateLimit and commitRateLimit so a failed
// dispatch does not consume the operator's budget.
func checkRateLimit(statePath string, maxPerHour int) bool {
	rateLimitMu.Lock()
	defer rateLimitMu.Unlock()

	rlPath := filepath.Join(statePath, "ratelimit.json")
	currentHour := time.Now().Format("2006-01-02T15")

	var rl rateLimitState
	// #nosec G304 -- filepath.Join(statePath, "ratelimit.json"); statePath from operator config.
	if data, err := os.ReadFile(rlPath); err == nil {
		_ = json.Unmarshal(data, &rl)
	}

	// Reset if new hour.
	if rl.Hour != currentHour {
		rl = rateLimitState{Hour: currentHour}
	}
	if rl.Count >= maxPerHour {
		return false
	}
	rl.Count++

	newData, err := json.Marshal(rl)
	if err != nil {
		fmt.Fprintf(os.Stderr, "alert: rate-limit marshal failed: %v\n", err)
		return true
	}
	if err := os.WriteFile(rlPath, newData, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "alert: rate-limit write failed for %s: %v\n", rlPath, err)
	}
	return true
}
// formatDispatchErrors folds the per-channel errors into a single error
// whose message lists them separated by "; ". It returns nil when errs is
// empty.
func formatDispatchErrors(errs []error) error {
	if len(errs) == 0 {
		return nil
	}
	var joined strings.Builder
	for i, e := range errs {
		if i > 0 {
			joined.WriteString("; ")
		}
		joined.WriteString(e.Error())
	}
	return fmt.Errorf("alert dispatch errors: %s", joined.String())
}
// Dispatch sends alerts via all configured channels.
//
// The steps run in a fixed order:
//  1. findings are deduplicated;
//  2. every deduplicated finding goes to the audit log, the FindingBus (if
//     set), the report hook and the central hook;
//  3. if the webhook is enabled with type "phpanel", every deduplicated
//     finding is sent to it individually;
//  4. FilterBlockedAlerts is applied, and the remaining findings are split
//     into an email batch (minus the email channel's disabled checks) and a
//     webhook batch (only for non-phpanel webhooks);
//  5. unless a Critical finding is present, a rate-limit slot is reserved;
//     if none is available the operator-facing send is skipped with a
//     stderr message;
//  6. email and webhook are sent, and the rate-limit slot is committed only
//     if at least one of them succeeded.
//
// The returned error aggregates every channel send failure recorded along
// the way (phpanel, email, webhook); it is nil when none failed. Being
// rate-limited or having nothing to send is not an error.
func Dispatch(cfg *config.Config, findings []Finding) error {
// Deduplicate
findings = Deduplicate(findings)
// Audit log captures every (deduplicated) finding before
// FilterBlockedAlerts and the rate limiter, so SIEMs see the
// complete picture even when email/webhook are throttled or
// when "this IP is already blocked" suppression hides a finding
// from the operator-facing channels.
emitAudit(cfg, findings)
// Publish to passive observers (e.g. SSE subscribers) immediately after
// auditing, before rate-limit and webhook delivery, so subscribers see
// the complete picture even when operator-facing channels are throttled.
if FindingBus != nil {
for _, f := range findings {
FindingBus.Publish(f)
}
}
// Offer every finding to the abuse reporter (it gates and minimizes
// internally, queueing only confirmed-abuse findings for the drain loop)
// and to the central-intel consumer (it escalates findings whose IP is in
// the verified central scored-set).
for _, f := range findings {
callReportHook(f)
callCentralHook(f)
}
// errs accumulates per-channel send failures; it is turned into the
// return value by formatDispatchErrors on every exit path below.
var errs []error
// Phpanel consumes this webhook as a signed data-plane stream. Send the
// full deduplicated stream before operator notification suppression and
// rate limiting, otherwise fleet correlation can miss attacker spread.
phpanelWebhook := cfg.Alerts.Webhook.Enabled && cfg.Alerts.Webhook.Type == "phpanel"
if phpanelWebhook {
for _, f := range findings {
if err := SendPhpanelWebhookFinding(cfg, f); err != nil {
addDispatchError(&errs, fmt.Errorf("phpanel webhook (check=%s): %w", f.Check, err))
}
}
}
// Filter out blocked IP alerts if configured
findings = FilterBlockedAlerts(cfg, findings)
if len(findings) == 0 {
return formatDispatchErrors(errs)
}
// Email batch: only built when email is enabled, minus the checks the
// operator disabled for that channel.
emailFindings := []Finding(nil)
if cfg.Alerts.Email.Enabled {
emailFindings = filterChecks(findings, cfg.Alerts.Email.DisabledChecks)
}
// Webhook batch: the phpanel webhook was already served above, so only
// other webhook types receive the formatted alert.
webhookFindings := []Finding(nil)
if cfg.Alerts.Webhook.Enabled && !phpanelWebhook {
webhookFindings = findings
}
if len(emailFindings) == 0 && len(webhookFindings) == 0 {
return formatDispatchErrors(errs)
}
// Rate limit check — CRITICAL realtime findings (malware, webshells,
// backdoors) always get through. Only non-critical alerts are rate-limited.
// Note the scan is over the post-FilterBlockedAlerts findings, not over
// the per-channel batches.
hasCritical := false
for _, f := range findings {
if f.Severity == Critical {
hasCritical = true
break
}
}
var reservation *rateLimitReservation
if !hasCritical {
var ok bool
reservation, ok = reserveRateLimit(cfg.StatePath, cfg.Alerts.MaxPerHour)
if !ok {
fmt.Fprintf(os.Stderr, "Alert rate limit reached (%d/hour), skipping non-critical alert dispatch\n", cfg.Alerts.MaxPerHour)
return formatDispatchErrors(errs)
}
// Give the slot back on every exit path; this is a no-op once
// commitRateLimit below has already released the reservation.
defer releaseRateLimit(reservation)
}
// dispatched becomes true as soon as one channel accepts the alert.
dispatched := false
if len(emailFindings) > 0 {
subject := buildSubject(cfg.Hostname, emailFindings)
body := FormatAlert(cfg.Hostname, emailFindings)
if err := SendEmail(cfg, subject, body); err != nil {
addDispatchError(&errs, fmt.Errorf("email: %w", err))
} else {
dispatched = true
}
}
if len(webhookFindings) > 0 {
subject := buildSubject(cfg.Hostname, webhookFindings)
body := FormatAlert(cfg.Hostname, webhookFindings)
if err := SendWebhook(cfg, subject, body); err != nil {
addDispatchError(&errs, fmt.Errorf("webhook: %w", err))
} else {
dispatched = true
}
}
// Commit the rate-limit slot only after at least one channel
// accepted the message. Without this, a failed send burned the
// budget; the next non-critical alert was then throttled with no
// operator-facing trace. When a Critical finding bypassed the limiter
// reservation is nil, and the commit still counts the send.
if dispatched {
commitRateLimit(cfg.StatePath, reservation)
}
return formatDispatchErrors(errs)
}
// SendHeartbeat pings a dead man's switch URL with a GET request using a
// 10-second client timeout. It does nothing when the heartbeat is disabled
// or no URL is configured.
//
// Failures are reported on stderr only; nothing is returned. Both transport
// errors and non-2xx responses are logged: an HTTP client does not turn a
// 4xx/5xx reply into an error, so without the status check a rejected ping
// would pass silently.
func SendHeartbeat(cfg *config.Config) {
	if !cfg.Alerts.Heartbeat.Enabled || cfg.Alerts.Heartbeat.URL == "" {
		return
	}
	client := httpClient(10 * time.Second)
	resp, err := client.Get(cfg.Alerts.Heartbeat.URL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Heartbeat failed: %v\n", err)
		return
	}
	defer closeWebhookResponseBody(resp)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		fmt.Fprintf(os.Stderr, "Heartbeat failed: unexpected status %s\n", resp.Status)
	}
}
package alert
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"runtime/debug"
"strconv"
"sync"
"sync/atomic"
"github.com/pidginhost/csm/internal/config"
)
// Audit-log dispatching is layered on top of the existing email +
// webhook fork in Dispatch(). Sinks live behind a package-level
// manager so the daemon's per-call Dispatch path does not pay the
// cost of opening a JSONL file or dialling syslog on every alert.
//
// The manager keys its sink set by a fingerprint of the relevant
// config sub-block; on hot reload the fingerprint changes and the
// manager closes the old sinks and rebuilds.
var (
// auditMu guards auditSinks and auditFingerprint.
auditMu sync.Mutex
// auditSinks is the currently active sink set; nil until the first
// successful build and again after CloseAuditSinks.
auditSinks []AuditSink
// auditFingerprint is the config fingerprint auditSinks was built for.
auditFingerprint string
)
// emitAudit delivers every finding to the registered finding observers and
// to every configured audit-log sink. Dispatch calls it BEFORE the
// rate-limit checks so the audit trail is complete even when email/webhook
// are throttled. A nil cfg makes it a no-op. Sink errors are logged to
// stderr and do not stop delivery to the remaining sinks or findings.
func emitAudit(cfg *config.Config, findings []Finding) {
	if cfg == nil {
		return
	}
	ensureAuditSinks(cfg)

	// Work on a snapshot so sinks are not called with auditMu held.
	auditMu.Lock()
	active := make([]AuditSink, len(auditSinks))
	copy(active, auditSinks)
	auditMu.Unlock()

	for _, f := range findings {
		// Observers go first so the incident correlator sees every
		// finding even when no audit sinks are configured.
		notifyFindingObservers(f)
		if len(active) > 0 {
			ev := NewAuditEvent(cfg.Hostname, f)
			for i := range active {
				sink := active[i]
				if err := sink.Emit(ev); err != nil {
					fmt.Fprintf(os.Stderr, "[audit-log] %s emit failed: %v\n", sink.Name(), err)
				}
			}
		}
	}
}
// ensureAuditSinks (re)builds the active sink set when the audit-log config
// fingerprint differs from the one the current set was built for, or when
// no sink set exists. In the steady state with sinks open it is a
// fingerprint compare and a return. On rebuild the old sinks are closed
// first; a sink that fails to initialise is logged to stderr and left out.
func ensureAuditSinks(cfg *config.Config) {
	wanted := auditConfigFingerprint(cfg)

	auditMu.Lock()
	defer auditMu.Unlock()

	if auditSinks != nil && wanted == auditFingerprint {
		return
	}

	// Config changed -- close the old sinks before building new ones so
	// file descriptors / sockets are released cleanly.
	for i := range auditSinks {
		_ = auditSinks[i].Close()
	}
	auditSinks = nil

	if cfg.Alerts.AuditLog.File.Enabled && cfg.Alerts.AuditLog.File.Path != "" {
		if jsonl, err := NewJSONLSink(cfg.Alerts.AuditLog.File.Path); err == nil {
			auditSinks = append(auditSinks, jsonl)
		} else {
			fmt.Fprintf(os.Stderr, "[audit-log] jsonl init failed: %v\n", err)
		}
	}

	if cfg.Alerts.AuditLog.Syslog.Enabled {
		syslogCfg := SyslogConfig{
			Network:   cfg.Alerts.AuditLog.Syslog.Network,
			Address:   cfg.Alerts.AuditLog.Syslog.Address,
			Facility:  cfg.Alerts.AuditLog.Syslog.Facility,
			Hostname:  cfg.Hostname,
			TLSCAFile: cfg.Alerts.AuditLog.Syslog.TLSCAFile,
		}
		if remote, err := NewSyslogSink(syslogCfg); err == nil {
			auditSinks = append(auditSinks, remote)
		} else {
			fmt.Fprintf(os.Stderr, "[audit-log] syslog init failed: %v\n", err)
		}
	}

	auditFingerprint = wanted
}
// auditConfigFingerprint hashes the audit-log settings (plus the hostname,
// which appears in every emitted event) into a hex SHA-256 string.
// ensureAuditSinks compares it to detect a config change without a deep
// reflect-based diff.
func auditConfigFingerprint(cfg *config.Config) string {
	hasher := sha256.New()

	_, _ = fmt.Fprintf(hasher, "host=%s|", cfg.Hostname)

	fileEnabled := cfg.Alerts.AuditLog.File.Enabled
	filePath := cfg.Alerts.AuditLog.File.Path
	_, _ = fmt.Fprintf(hasher, "file.enabled=%t|file.path=%s|", fileEnabled, filePath)

	syslogEnabled := cfg.Alerts.AuditLog.Syslog.Enabled
	network := cfg.Alerts.AuditLog.Syslog.Network
	address := cfg.Alerts.AuditLog.Syslog.Address
	facility := cfg.Alerts.AuditLog.Syslog.Facility
	caFile := cfg.Alerts.AuditLog.Syslog.TLSCAFile
	_, _ = fmt.Fprintf(hasher, "syslog.enabled=%t|syslog.network=%s|syslog.address=%s|syslog.facility=%s|syslog.tls=%s",
		syslogEnabled, network, address, facility, caFile)

	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// CloseAuditSinks closes every active sink and clears the cached sink set
// and fingerprint, so the next emit rebuilds from config. The daemon calls
// it at shutdown to release file descriptors / sockets before the process
// exits. Close errors are ignored. Safe to call when no sinks are open.
func CloseAuditSinks() {
	auditMu.Lock()
	defer auditMu.Unlock()
	for i := range auditSinks {
		_ = auditSinks[i].Close()
	}
	auditSinks, auditFingerprint = nil, ""
}
// resetAuditSinksForTest is the test-only seam to wipe the package
// state between cases. Production code never needs this -- live
// daemons run a single Dispatch path with a single config object.
// It simply closes the sinks, which also clears the cached fingerprint.
func resetAuditSinksForTest() {
CloseAuditSinks()
}
// findingObservers registry. Used by the daemon to feed the incident
// correlator without making the alert package depend on internal/incident.
var (
// findingObserversMu guards findingObservers.
findingObserversMu sync.RWMutex
// findingObservers holds the registered observers in registration order.
findingObservers []findingObserver
// findingObserverSeq hands out observer ids; the first id issued is 1.
findingObserverSeq atomic.Uint64
)
// findingObserver pairs a callback with the id used to remove it again.
type findingObserver struct {
id uint64
fn func(Finding)
}
// RegisterFindingObserver registers fn to be called for every finding
// dispatched through emitAudit. It returns a cancel func that removes the
// observer; calling cancel more than once is harmless. Safe for concurrent
// use; observer panics are recovered by the notifier so one bad observer
// cannot stop dispatch.
func RegisterFindingObserver(fn func(Finding)) func() {
	id := findingObserverSeq.Add(1)

	findingObserversMu.Lock()
	findingObservers = append(findingObservers, findingObserver{id: id, fn: fn})
	findingObserversMu.Unlock()

	return func() {
		findingObserversMu.Lock()
		defer findingObserversMu.Unlock()

		// Filter in place, reusing the backing array.
		kept := findingObservers[:0]
		for _, o := range findingObservers {
			if o.id != id {
				kept = append(kept, o)
			}
		}
		// Zero the slots past the new length. Without this the backing
		// array keeps a reference to the removed observer's closure (and
		// whatever it captures), so it could not be garbage-collected
		// until a later registration overwrote the slot.
		for i := len(kept); i < len(findingObservers); i++ {
			findingObservers[i] = findingObserver{}
		}
		findingObservers = kept
	}
}
// notifyFindingObservers calls every registered observer with f, in
// registration order, using a snapshot taken under the read lock so the
// callbacks run without the lock held. Each call has its own recover scope:
// a panicking observer is logged to stderr with a stack trace and the
// remaining observers still run.
func notifyFindingObservers(f Finding) {
	findingObserversMu.RLock()
	snapshot := make([]findingObserver, len(findingObservers))
	copy(snapshot, findingObservers)
	findingObserversMu.RUnlock()

	invoke := func(target findingObserver) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			fmt.Fprintf(os.Stderr,
				"alert: finding observer id=%d panic for check=%q: %s\n%s",
				target.id, f.Check, formatRecoverValue(r), debug.Stack())
		}()
		target.fn(f)
	}
	for i := range snapshot {
		invoke(snapshot[i])
	}
}
// formatRecoverValue renders a recovered panic value as a quoted,
// ASCII-only string that is safe to put on a single log line. If formatting
// the value itself panics, a quoted placeholder naming the value's type is
// returned instead.
func formatRecoverValue(v any) (out string) {
	defer func() {
		if r := recover(); r == nil {
			return
		}
		placeholder := fmt.Sprintf("<unprintable panic value of type %T>", v)
		out = strconv.Quote(placeholder)
	}()
	out = strconv.QuoteToASCII(fmt.Sprint(v))
	return out
}
package alert
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
)
// JSONLSink appends one JSON object per finding to a file. Designed
// for SIEM ingest via standard log shippers (Vector, Filebeat,
// Fluentbit). Writes are mutex-serialised so concurrent emits cannot
// interleave bytes inside a single line.
//
// The file is opened with O_APPEND (see NewJSONLSink), so each write is
// placed at the current end of the file. That is what lets logrotate's
// copytruncate rotation work without a daemon restart: after the truncate
// the still-open descriptor simply keeps appending from the new end.
type JSONLSink struct {
path string // path the sink was opened with
mu sync.Mutex // serialises Emit and Close
f *os.File // nil once the sink has been closed
}
// NewJSONLSink opens path for appending, creating the file (mode 0640) and
// any missing parent directories (mode 0750) first. The file is
// group-readable so a log shipper running as a non-root user in the right
// group can tail it. It returns an error when path is empty or when the
// directory or file cannot be created/opened.
func NewJSONLSink(path string) (*JSONLSink, error) {
	if path == "" {
		return nil, errors.New("jsonl sink: path is empty")
	}

	parent := filepath.Dir(path)
	if err := os.MkdirAll(parent, 0750); err != nil {
		return nil, fmt.Errorf("jsonl sink: creating dir: %w", err)
	}

	const openFlags = os.O_CREATE | os.O_APPEND | os.O_WRONLY
	// #nosec G304 G302 -- G304: path comes from cfg.Alerts.AuditLog.File.Path, which is operator-controlled (root-owned daemon config), not attacker input. G302: 0640 is intentional; SIEM log shippers (Vector, Filebeat, Fluentbit) commonly run as a non-root user that needs group-read access. 0600 would force the shipper to run as root.
	file, err := os.OpenFile(path, openFlags, 0640)
	if err != nil {
		return nil, fmt.Errorf("jsonl sink: opening %s: %w", path, err)
	}

	sink := &JSONLSink{path: path, f: file}
	return sink, nil
}
// Name reports the fixed identifier ("jsonl") that diagnostics use to refer
// to this sink.
func (s *JSONLSink) Name() string {
	const sinkName = "jsonl"
	return sinkName
}
// Emit serialises event as JSON and appends it to the file as one line. The
// newline is part of the same Write call as the payload so the line and its
// terminator are never issued as two separate writes. It returns an error
// if marshalling or writing fails, or if the sink has been closed.
func (s *JSONLSink) Emit(event AuditEvent) error {
	encoded, err := json.Marshal(event)
	if err != nil {
		return fmt.Errorf("jsonl sink: marshal: %w", err)
	}
	encoded = append(encoded, '\n')

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.f == nil {
		return errors.New("jsonl sink: closed")
	}
	_, err = s.f.Write(encoded)
	if err != nil {
		return fmt.Errorf("jsonl sink: write: %w", err)
	}
	return nil
}
// Close closes the underlying file and marks the sink as closed. The first
// call returns the file's close error, if any; later calls are no-ops that
// return nil.
func (s *JSONLSink) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	file := s.f
	if file == nil {
		return nil
	}
	s.f = nil
	return file.Close()
}
package alert
import (
"crypto/sha256"
"encoding/hex"
"time"
"github.com/pidginhost/csm/internal/processctx"
)
// AuditSchemaVersion is the value emitted in every AuditEvent's "v"
// field. Frozen contract -- downstream JSONL / syslog parsers pin on
// it. Bump only on incompatible schema changes; additive fields stay
// at the same version. NewAuditEvent stamps it on every event it builds.
const AuditSchemaVersion = 1
// AuditEvent is the wire-stable shape every audit-log sink emits. JSON
// keys match what downstream SIEMs expect; fields are added at the
// end so older parsers ignore unknown ones.
//
// Events are built from a Finding by NewAuditEvent; apart from V,
// FindingID and Hostname every field is copied from the finding.
type AuditEvent struct {
V int `json:"v"` // schema version, AuditSchemaVersion
Timestamp time.Time `json:"ts"` // finding timestamp, converted to UTC
FindingID string `json:"finding_id"` // 16-hex-char hash from makeFindingID
Severity string `json:"severity"` // Severity.String() label, e.g. "HIGH"
Check string `json:"check"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
FilePath string `json:"file_path,omitempty"`
Hostname string `json:"hostname"` // supplied by the caller, not taken from the finding
TenantID string `json:"tenant_id,omitempty"`
Domain string `json:"domain,omitempty"`
Mailbox string `json:"mailbox,omitempty"`
Process *processctx.ProcessContext `json:"process,omitempty"`
}
// AuditSink is what every audit-log destination implements. Emit must
// be safe for concurrent calls; the alert dispatcher fans out to
// multiple sinks per finding. An Emit error is logged by the dispatcher
// together with the sink's Name and does not stop delivery to other sinks.
type AuditSink interface {
// Name identifies the sink for diagnostics (e.g. "jsonl", "syslog").
Name() string
// Emit ships one event. Should return promptly; sinks that need
// long-haul I/O are expected to handle their own buffering.
Emit(event AuditEvent) error
// Close releases any held resources. Safe to call multiple times.
Close() error
}
// NewAuditEvent converts a Finding into a versioned audit event. hostname
// is stamped on the event as given (callers pass cfg.Hostname and are
// responsible for keeping it stable across emits); the timestamp is
// converted to UTC and the finding ID is derived by makeFindingID.
func NewAuditEvent(hostname string, f Finding) AuditEvent {
	var ev AuditEvent

	ev.V = AuditSchemaVersion
	ev.Hostname = hostname
	ev.Timestamp = f.Timestamp.UTC()
	ev.FindingID = makeFindingID(f)
	ev.Severity = f.Severity.String()

	ev.Check = f.Check
	ev.Message = f.Message
	ev.Details = f.Details
	ev.FilePath = f.FilePath

	ev.TenantID = f.TenantID
	ev.Domain = f.Domain
	ev.Mailbox = f.Mailbox
	ev.Process = f.Process

	return ev
}
// makeFindingID derives a stable 16-hex-char ID from a finding by hashing
// its UTC timestamp (RFC 3339 with nanoseconds), check, severity label,
// message and file path with SHA-256. Emitting the same finding twice
// yields the same ID, so downstream dedup works across re-runs.
//
// The fields are joined with "|" so that plain concatenation cannot make
// two different field splits hash alike (e.g. a Check name ending in the
// characters another field starts with).
func makeFindingID(f Finding) string {
	parts := [...]string{
		f.Timestamp.UTC().Format(time.RFC3339Nano),
		f.Check,
		f.Severity.String(),
		f.Message,
		f.FilePath,
	}

	hasher := sha256.New()
	for i, part := range parts {
		if i > 0 {
			_, _ = hasher.Write([]byte("|"))
		}
		_, _ = hasher.Write([]byte(part))
	}

	digest := hex.EncodeToString(hasher.Sum(nil))
	return digest[:16]
}
package alert
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
)
// SyslogConfig drives the syslog sink. Network is one of "udp",
// "tcp", "unix", "unixgram", or "tls"; Address is host:port (or a
// filesystem path for the unix variants). Facility names follow the
// classic syslog set; default is "local0".
//
// NewSyslogSink validates these fields: Network and Address are
// required, Facility is trimmed and lower-cased before lookup, and an
// empty Hostname is replaced by os.Hostname() (or "localhost" if that
// call fails).
type SyslogConfig struct {
Network string // transport selector; "tls" is dialled as TCP wrapped in TLS
Address string // dial target passed to the dialer unchanged
Facility string // facility name looked up in facilityCodes; empty means "local0"
Hostname string // typically cfg.Hostname; falls back to os.Hostname()
TLSCAFile string // optional CA cert path for tls; empty = system roots
}
// SyslogSink is an RFC 5424 syslog client. The wire payload is the
// AuditEvent JSON so SIEMs that parse our JSONL file have a single
// schema regardless of transport. Messages bigger than the legacy
// 1024-byte limit are still sent -- modern receivers (rsyslog,
// syslog-ng) accept the full RFC 5424 max of 8192 bytes; if the
// operator's receiver caps shorter, syslog truncation is the
// expected behaviour.
//
// The struct embeds a sync.Mutex by value, so it must be used through
// a pointer (all methods have pointer receivers).
type SyslogSink struct {
cfg SyslogConfig // validated copy of the configuration, with Hostname filled in
priority int // facility*8; format adds the per-event syslog level to form PRI
mu sync.Mutex // guards conn; held across dial, write and close
conn net.Conn // nil after a failed write or Close; Emit redials when nil
}
// facilityCodes is the standard syslog facility number set. local0..7
// is the customary range for application-level audit traffic; we
// default to local0 if the operator leaves the field blank.
//
// Keys are lower-case; NewSyslogSink trims and lower-cases the
// configured name before the lookup and multiplies the value by 8 to
// obtain the facility part of the PRI field.
var facilityCodes = map[string]int{
"kern": 0, "user": 1, "mail": 2, "daemon": 3, "auth": 4,
"syslog": 5, "lpr": 6, "news": 7, "uucp": 8, "cron": 9,
"authpriv": 10, "ftp": 11,
"local0": 16, "local1": 17, "local2": 18, "local3": 19,
"local4": 20, "local5": 21, "local6": 22, "local7": 23,
}
// NewSyslogSink checks the configuration, connects to the receiver and
// hands back a sink that is ready for Emit.
//
// Validation, in order: Network and Address must both be non-empty;
// Network must be one of udp, tcp, unix, unixgram or tls; Facility
// (trimmed, case-insensitive, "local0" when blank) must be a known
// facility name. A blank Hostname is filled from os.Hostname(), or
// "localhost" when that fails.
//
// A dial failure is returned to the caller rather than retried:
// treat it as "audit syslog is misconfigured". After construction, a
// failed write drops the connection and the next Emit redials once.
func NewSyslogSink(cfg SyslogConfig) (*SyslogSink, error) {
	if cfg.Address == "" || cfg.Network == "" {
		return nil, errors.New("syslog sink: network and address required")
	}

	supported := false
	for _, candidate := range [...]string{"udp", "tcp", "unix", "unixgram", "tls"} {
		if cfg.Network == candidate {
			supported = true
			break
		}
	}
	if !supported {
		return nil, fmt.Errorf("syslog sink: unknown network %q (want udp|tcp|unix|unixgram|tls)", cfg.Network)
	}

	name := strings.ToLower(strings.TrimSpace(cfg.Facility))
	if name == "" {
		name = "local0"
	}
	code, known := facilityCodes[name]
	if !known {
		// Report the operator's original spelling, not the normalised one.
		return nil, fmt.Errorf("syslog sink: unknown facility %q", cfg.Facility)
	}

	if cfg.Hostname == "" {
		detected, err := os.Hostname()
		if err != nil {
			detected = "localhost"
		}
		cfg.Hostname = detected
	}

	sink := &SyslogSink{
		cfg:      cfg,
		priority: code * 8,
	}
	if err := sink.dial(); err != nil {
		return nil, err
	}
	return sink, nil
}
// Name returns the fixed label "syslog", used to identify this sink in
// diagnostics and error messages.
func (s *SyslogSink) Name() string {
	return "syslog"
}
// Emit renders the event as an RFC 5424 line and writes it to the
// receiver. Formatting happens before the lock is taken; the dial and
// the write run under s.mu so concurrent callers cannot interleave
// bytes on stream transports (TCP, TLS).
//
// If there is no live connection (a previous write failed or Close was
// called) Emit redials first and returns the dial error on failure. A
// failed write closes and forgets the connection so the next Emit
// starts with a fresh dial; the event that failed is not retried.
func (s *SyslogSink) Emit(event AuditEvent) error {
	payload, err := s.format(event)
	if err != nil {
		return err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.conn == nil {
		if dialErr := s.dialLocked(); dialErr != nil {
			return dialErr
		}
	}

	_, writeErr := s.conn.Write(payload)
	if writeErr == nil {
		return nil
	}

	// Drop the broken connection; the close error is not actionable.
	_ = s.conn.Close()
	s.conn = nil
	return fmt.Errorf("syslog sink: write: %w", writeErr)
}
// Close shuts the connection, if any, and returns the error from
// closing it. The connection field is cleared either way, so a second
// Close is a no-op that returns nil and a later Emit will redial.
func (s *SyslogSink) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	current := s.conn
	if current == nil {
		return nil
	}
	s.conn = nil
	return current.Close()
}
// dial takes s.mu and establishes the connection via dialLocked. It is
// the entry point for callers that do not already hold the lock
// (NewSyslogSink); Emit holds the lock and calls dialLocked directly.
func (s *SyslogSink) dial() error {
s.mu.Lock()
defer s.mu.Unlock()
return s.dialLocked()
}
// dialLocked establishes the connection and stores it in s.conn.
// Caller must hold s.mu.
//
// Both transports share one 5-second bound. For "tls" the bound covers
// the TCP connect and the TLS handshake together; previously the TLS
// path had no timeout, so an unresponsive receiver could block the
// caller indefinitely while s.mu was held, stalling every other Emit.
// All other networks are dialled as-is with the same timeout.
//
// On failure s.conn is left untouched and a wrapped error naming the
// target address is returned.
func (s *SyslogSink) dialLocked() error {
	const dialTimeout = 5 * time.Second
	dialer := &net.Dialer{Timeout: dialTimeout}

	if s.cfg.Network == "tls" {
		tlsCfg, err := s.tlsConfig()
		if err != nil {
			return err
		}
		conn, err := tls.DialWithDialer(dialer, "tcp", s.cfg.Address, tlsCfg)
		if err != nil {
			return fmt.Errorf("syslog sink: tls dial %s: %w", s.cfg.Address, err)
		}
		s.conn = conn
		return nil
	}

	conn, err := dialer.Dial(s.cfg.Network, s.cfg.Address)
	if err != nil {
		return fmt.Errorf("syslog sink: %s dial %s: %w", s.cfg.Network, s.cfg.Address, err)
	}
	s.conn = conn
	return nil
}
// tlsConfig builds the client TLS configuration for the "tls"
// transport. TLS 1.2 is the minimum version. When TLSCAFile is empty
// RootCAs stays nil, which makes crypto/tls use the host's root set;
// otherwise the file is read and must contain at least one PEM
// certificate, which becomes the only trusted root pool.
func (s *SyslogSink) tlsConfig() (*tls.Config, error) {
	out := &tls.Config{MinVersion: tls.VersionTLS12}

	caPath := s.cfg.TLSCAFile
	if caPath == "" {
		return out, nil
	}

	// #nosec G304 -- TLSCAFile is operator-supplied via cfg.Alerts.AuditLog.Syslog.TLSCAFile; the operator owns the daemon config. Not attacker-controlled.
	bundle, err := os.ReadFile(caPath)
	if err != nil {
		return nil, fmt.Errorf("syslog sink: reading TLS CA: %w", err)
	}

	roots := x509.NewCertPool()
	if ok := roots.AppendCertsFromPEM(bundle); !ok {
		return nil, fmt.Errorf("syslog sink: TLS CA file %s is not a valid PEM bundle", caPath)
	}
	out.RootCAs = roots
	return out, nil
}
// format renders one event as an RFC 5424 syslog line whose MSG part is
// the JSON encoding of the event, so receivers can parse the structured
// payload directly:
//
//	<PRI>1 TIMESTAMP HOSTNAME APP-NAME PROCID MSGID - MSG
//
// PRI is the sink's facility*8 plus the syslog level for the event's
// severity. APP-NAME is the literal "csm", PROCID is this process's
// PID, MSGID is the check name (or "-" when empty) and STRUCTURED-DATA
// is always "-" because everything lives in the JSON body.
//
// Datagram transports (udp, unixgram) carry one message per packet and
// get no terminator. Stream transports (tcp, tls, unix) get a trailing
// LF, the non-transparent framing rsyslog accepts by default.
func (s *SyslogSink) format(event AuditEvent) ([]byte, error) {
	payload, err := json.Marshal(event)
	if err != nil {
		return nil, fmt.Errorf("syslog sink: marshal: %w", err)
	}

	msgID := "-"
	if event.Check != "" {
		msgID = event.Check
	}

	line := fmt.Sprintf("<%d>1 %s %s csm %d %s - %s",
		s.priority+severityToSyslogLevel(event.Severity),
		event.Timestamp.UTC().Format(time.RFC3339Nano),
		s.cfg.Hostname,
		os.Getpid(),
		msgID,
		payload,
	)

	switch s.cfg.Network {
	case "tcp", "tls", "unix":
		line += "\n"
	}
	return []byte(line), nil
}
// severityToSyslogLevel translates a CSM severity string into the
// numeric syslog level used in the PRI field:
//
//	"CRITICAL" -> 2 (crit)
//	"HIGH"     -> 3 (err)
//	"WARNING"  -> 4 (warning)
//	anything else -> 6 (info)
//
// The comparison is exact and case-sensitive.
func severityToSyslogLevel(s string) int {
	const (
		levelCrit    = 2
		levelErr     = 3
		levelWarning = 4
		levelInfo    = 6
	)

	if s == "CRITICAL" {
		return levelCrit
	}
	if s == "HIGH" {
		return levelErr
	}
	if s == "WARNING" {
		return levelWarning
	}
	return levelInfo
}
package alert
import (
"fmt"
"net/smtp"
"strings"
"github.com/pidginhost/csm/internal/config"
)
// SendEmail delivers a plain-text UTF-8 alert to every address in
// cfg.Alerts.Email.To through the SMTP server at cfg.Alerts.Email.SMTP.
//
// It returns an error when no recipients are configured or when both
// delivery attempts fail. The first attempt is smtp.SendMail with no
// credentials. If that fails, a second attempt drives the SMTP session
// by hand (HELO/EHLO, MAIL, RCPT, DATA, QUIT) without STARTTLS or AUTH,
// which is what a local relay typically accepts.
//
// Header values (From, To, Subject) have CR and LF replaced by spaces
// before being written. Subjects are built from finding data, and an
// embedded line break would otherwise end the header early and let the
// remainder be interpreted as extra headers or body. The message body
// is sent unchanged.
func SendEmail(cfg *config.Config, subject, body string) error {
	to := cfg.Alerts.Email.To
	from := cfg.Alerts.Email.From
	smtpAddr := cfg.Alerts.Email.SMTP

	if len(to) == 0 {
		return fmt.Errorf("no email recipients configured")
	}

	// Neutralise line breaks in header values only.
	headerSafe := strings.NewReplacer("\r", " ", "\n", " ")

	msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nContent-Type: text/plain; charset=UTF-8\r\n\r\n%s",
		headerSafe.Replace(from),
		headerSafe.Replace(strings.Join(to, ", ")),
		headerSafe.Replace(subject),
		body,
	)

	// NOTE(review): this takes everything before the first ':' as the
	// HELO name, which truncates bracketed IPv6 literals ("[::1]:25").
	host := strings.Split(smtpAddr, ":")[0]

	err := smtp.SendMail(smtpAddr, nil, from, to, []byte(msg))
	if err == nil {
		return nil
	}

	// Retry with a minimal hand-driven session for local servers.
	c, dialErr := smtp.Dial(smtpAddr)
	if dialErr != nil {
		return fmt.Errorf("smtp dial %s: %w (original: %v)", smtpAddr, dialErr, err)
	}
	defer func() { _ = c.Close() }()

	if err := c.Hello(host); err != nil {
		return fmt.Errorf("smtp hello: %w", err)
	}
	if err := c.Mail(from); err != nil {
		return fmt.Errorf("smtp mail from: %w", err)
	}
	for _, addr := range to {
		if err := c.Rcpt(addr); err != nil {
			return fmt.Errorf("smtp rcpt %s: %w", addr, err)
		}
	}
	w, err := c.Data()
	if err != nil {
		return fmt.Errorf("smtp data: %w", err)
	}
	if _, err := w.Write([]byte(msg)); err != nil {
		return fmt.Errorf("smtp write: %w", err)
	}
	if err := w.Close(); err != nil {
		return fmt.Errorf("smtp close: %w", err)
	}
	return c.Quit()
}
package alert
import (
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
)
// State files drive suppression of reputation and auto-block alerts.
// Parse failures must degrade open so corrupt state never hides
// operator-facing findings; warning logs keep the corruption visible.
// Missing state files stay silent because they are normal before the
// first block.
// FilterBlockedAlerts removes reputation and auto-block alerts for IPs
// that are already handled by the CSM firewall, so operators only see
// findings that need human action.
//
// When cfg.Suppressions.SuppressBlockedAlerts is off the input is
// returned untouched. Otherwise the "handled" set is assembled from:
//   - IPs currently blocked according to persisted state,
//   - IPs queued for blocking (rate-limited, will be blocked later),
//   - IPs and CIDRs announced by auto_block findings in this batch
//     ("AUTO-BLOCK: <ip>" / "AUTO-BLOCK-SUBNET: <cidr>").
//
// If that set is empty the input is returned untouched. Otherwise:
//   - every auto_block finding is dropped;
//   - ip_reputation / local_threat_score findings are dropped outright
//     when auto-response IP blocking is enabled, and otherwise dropped
//     only when their message mentions a handled IP or an address
//     inside a handled subnet;
//   - all other findings pass through in their original order.
//
// IP matching is token-aware: a handled IP only matches when it is not
// embedded in a longer address (blocked "1.2.3.4" no longer suppresses
// an alert about "11.2.3.4" or "1.2.3.45"), and empty IP keys coming
// from damaged state never match. Both changes fail open, in line with
// the rule that corrupt state must not hide operator-facing findings.
func FilterBlockedAlerts(cfg *config.Config, findings []Finding) []Finding {
	if !cfg.Suppressions.SuppressBlockedAlerts {
		return findings
	}

	// Load all currently blocked IPs and queued blocks from state.
	blockedIPs, pendingIPs := loadBlockedAlertState(cfg.StatePath)

	// Also collect IPs and subnets blocked in this batch.
	var blockedSubnets []*net.IPNet
	for _, f := range findings {
		if f.Check != "auto_block" {
			continue
		}
		parts := strings.Fields(f.Message)
		for i, p := range parts {
			if p == "AUTO-BLOCK:" && i+1 < len(parts) {
				blockedIPs[parts[i+1]] = true
				break
			}
			if p == "AUTO-BLOCK-SUBNET:" && i+1 < len(parts) {
				if _, ipnet, err := net.ParseCIDR(parts[i+1]); err == nil {
					blockedSubnets = append(blockedSubnets, ipnet)
				}
				break
			}
		}
	}

	// Also suppress alerts for IPs queued for blocking (rate-limited).
	// These will be blocked once the rate limit resets - no need to alert.
	for ip := range pendingIPs {
		blockedIPs[ip] = true
	}

	if len(blockedIPs) == 0 && len(blockedSubnets) == 0 {
		return findings
	}

	var filtered []Finding
	for _, f := range findings {
		if f.Check == "ip_reputation" || f.Check == "local_threat_score" {
			// If auto-blocking is enabled, these IPs are handled automatically.
			// Skip the alert - there's nothing for the operator to do.
			if cfg.AutoResponse.Enabled && cfg.AutoResponse.BlockIPs {
				continue
			}

			// Check if the message names an already-blocked IP.
			isBlocked := false
			for ip := range blockedIPs {
				if findingMessageMentionsIP(f.Message, ip) {
					isBlocked = true
					break
				}
			}

			// Also suppress if the IP falls within a freshly-blocked subnet.
			// AUTO-BLOCK-SUBNET: findings from the same batch must silence
			// per-IP reputation alerts for addresses inside that subnet.
			if !isBlocked && len(blockedSubnets) > 0 {
				if parsed := extractIPFromFindingMessage(f.Message); parsed != nil {
					for _, subnet := range blockedSubnets {
						if subnet.Contains(parsed) {
							isBlocked = true
							break
						}
					}
				}
			}

			if isBlocked {
				continue
			}
		}
		if f.Check == "auto_block" {
			continue
		}
		filtered = append(filtered, f)
	}
	return filtered
}

// findingMessageMentionsIP reports whether msg contains ip as a
// standalone address rather than as a fragment of a longer one. An
// occurrence is rejected when the byte directly before or after it
// could continue the address: a decimal digit for IPv4, a hex digit
// for IPv6 (detected by a ':' in ip). An empty ip never matches.
func findingMessageMentionsIP(msg, ip string) bool {
	if ip == "" {
		return false
	}
	isV6 := strings.Contains(ip, ":")
	continuesAddress := func(c byte) bool {
		if c >= '0' && c <= '9' {
			return true
		}
		if !isV6 {
			return false
		}
		return (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')
	}

	for from := 0; from < len(msg); {
		rel := strings.Index(msg[from:], ip)
		if rel < 0 {
			return false
		}
		start := from + rel
		end := start + len(ip)
		embeddedLeft := start > 0 && continuesAddress(msg[start-1])
		embeddedRight := end < len(msg) && continuesAddress(msg[end])
		if !embeddedLeft && !embeddedRight {
			return true
		}
		// Keep scanning: a later occurrence may be standalone.
		from = start + 1
	}
	return false
}
// extractIPFromFindingMessage scans a finding message for the first
// whitespace-separated token that parses as an IP address and returns
// it, or nil if there is none. Used to match reputation findings
// against blocked-subnet CIDRs.
//
// Punctuation that commonly wraps an address is stripped before
// parsing: trailing ",:;)([]" and leading "(" / "[". Leading ':' is
// deliberately kept because IPv6 addresses such as "::1" start with
// it. Previously only the trailing side was trimmed, so an address
// written as "(1.2.3.4)" or "[2001:db8::1]" was never recognised.
func extractIPFromFindingMessage(msg string) net.IP {
	for _, field := range strings.Fields(msg) {
		candidate := strings.TrimRight(field, ",:;)([]")
		candidate = strings.TrimLeft(candidate, "([")
		if ip := net.ParseIP(candidate); ip != nil {
			return ip
		}
	}
	return nil
}
// loadPendingIPs returns the set of IPs listed in the pending section
// of blocked_ips.json under statePath. Only that section is requested;
// no expiry filtering applies to pending entries, so the zero time is
// passed. A missing or unreadable file yields an empty, non-nil map.
func loadPendingIPs(statePath string) map[string]bool {
	pending := map[string]bool{}
	var unusedNow time.Time
	loadBlockFileEntries(statePath, unusedNow, nil, pending, blockFilePendingSection)
	return pending
}
// loadBlockedIPSource merges currently blocked IPs into ips from the
// primary source. When BlockedIPsFunc has been injected, its result is
// copied in and the firewall state file is not consulted; otherwise
// the on-disk firewall state under statePath is read, using now to
// drop expired entries.
func loadBlockedIPSource(statePath string, now time.Time, ips map[string]bool) {
	if BlockedIPsFunc == nil {
		loadFirewallStateFile(statePath, now, ips)
		return
	}
	// Injected loader (bbolt-backed) takes precedence.
	for addr, blocked := range BlockedIPsFunc() {
		ips[addr] = blocked
	}
}
// BlockedIPsFunc is an optional callback that returns currently blocked IPs.
// Set by the daemon (or tests) to provide blocked IPs from bbolt store,
// avoiding a circular import between alert and store packages.
// When nil, loadBlockedIPs falls back to reading flat files.
//
// The returned map is copied entry by entry into the caller's set,
// values included. The variable itself is read without
// synchronisation, so it should be assigned before filtering starts.
var BlockedIPsFunc func() map[string]bool
// loadBlockedIPs returns the set of IPs that are blocked right now. It
// combines the primary source (injected loader or firewall state file)
// with the blocked section of the legacy blocked_ips.json, evaluating
// expiry for both against a single time.Now() snapshot.
func loadBlockedIPs(statePath string) map[string]bool {
	snapshot := time.Now()
	blocked := map[string]bool{}
	loadBlockedIPSource(statePath, snapshot, blocked)
	loadBlockFileEntries(statePath, snapshot, blocked, nil, blockFileIPsSection)
	return blocked
}
// loadBlockedAlertState loads everything FilterBlockedAlerts needs in
// one pass: the first result is the set of currently blocked IPs
// (primary source plus the blocked section of blocked_ips.json), the
// second is the set of IPs still queued for blocking. Both maps are
// always non-nil, and blocked_ips.json is read only once.
func loadBlockedAlertState(statePath string) (map[string]bool, map[string]bool) {
	snapshot := time.Now()
	blocked := map[string]bool{}
	queued := map[string]bool{}

	loadBlockedIPSource(statePath, snapshot, blocked)

	const bothSections = blockFileIPsSection | blockFilePendingSection
	loadBlockFileEntries(statePath, snapshot, blocked, queued, bothSections)

	return blocked, queued
}
// loadFirewallStateFile adds the unexpired entries of
// <statePath>/firewall/state.json to ips.
//
// A file that cannot be read is ignored silently (it is normal for it
// not to exist before the first block). A file that cannot be parsed
// is logged as a warning and contributes nothing. An entry counts as
// active when it has no expiry or its expiry is still after now.
func loadFirewallStateFile(statePath string, now time.Time, ips map[string]bool) {
	type blockedEntry struct {
		IP        string    `json:"ip"`
		ExpiresAt time.Time `json:"expires_at"`
	}
	var state struct {
		Blocked []blockedEntry `json:"blocked"`
	}

	statePathFile := filepath.Join(statePath, "firewall", "state.json")
	raw, readErr := os.ReadFile(statePathFile) // #nosec G304 -- filepath.Join under operator-configured statePath.
	if readErr != nil {
		return
	}

	if parseErr := json.Unmarshal(raw, &state); parseErr != nil {
		csmlog.Warn("alert filter: firewall state.json unparseable, suppression degraded",
			"path", statePathFile, "err", parseErr)
		return
	}

	for _, item := range state.Blocked {
		expired := !item.ExpiresAt.IsZero() && !now.Before(item.ExpiresAt)
		if expired {
			continue
		}
		ips[item.IP] = true
	}
}
// blockFile is the decoded form of blocked_ips.json. It carries no JSON
// tags because loadBlockFile decodes the "ips" and "pending" sections
// separately into these fields.
type blockFile struct {
IPs []blockFileIP
Pending []blockFilePendingIP
}
// blockFileIP is one entry of the "ips" section. A zero ExpiresAt is
// treated by loadBlockFileEntries as "never expires".
type blockFileIP struct {
IP string `json:"ip"`
ExpiresAt time.Time `json:"expires_at"`
}
// blockFilePendingIP is one entry of the "pending" section; only the
// address is decoded.
type blockFilePendingIP struct {
IP string `json:"ip"`
}
// blockFileSection is a bit set selecting which sections of
// blocked_ips.json loadBlockFile should decode.
type blockFileSection uint8
const (
// blockFileIPsSection (bit 0) selects the "ips" section.
blockFileIPsSection blockFileSection = 1 << iota
// blockFilePendingSection (bit 1) selects the "pending" section.
blockFilePendingSection
)
// loadBlockFileEntries reads blocked_ips.json once and distributes its
// contents into the caller's sets. sections selects what is decoded;
// ips and pending are the destinations, and either may be nil to
// discard that section. Blocked entries are added only when they have
// no expiry or expire after now; pending entries are added
// unconditionally. Nothing happens if the file cannot be loaded.
func loadBlockFileEntries(statePath string, now time.Time, ips map[string]bool, pending map[string]bool, sections blockFileSection) {
	parsed, ok := loadBlockFile(statePath, sections)
	if !ok {
		return
	}

	if ips != nil {
		for _, item := range parsed.IPs {
			if !item.ExpiresAt.IsZero() && !now.Before(item.ExpiresAt) {
				continue // expired block
			}
			ips[item.IP] = true
		}
	}

	if pending != nil {
		for _, item := range parsed.Pending {
			pending[item.IP] = true
		}
	}
}
// loadBlockFile reads <statePath>/blocked_ips.json and decodes the
// sections requested in sections.
//
// The boolean is false when the file cannot be read (silently; a
// missing file is normal) or when the top-level JSON is invalid
// (logged as a warning). Each section is decoded independently: a
// section that is absent, empty or JSON null is skipped, and a section
// that fails to decode is logged and left empty while the other
// section is still returned with ok == true.
func loadBlockFile(statePath string, sections blockFileSection) (blockFile, bool) {
	blockedPath := filepath.Join(statePath, "blocked_ips.json")
	content, readErr := os.ReadFile(blockedPath) // #nosec G304 -- filepath.Join under operator-configured statePath.
	if readErr != nil {
		return blockFile{}, false
	}

	var envelope struct {
		IPs     json.RawMessage `json:"ips"`
		Pending json.RawMessage `json:"pending"`
	}
	if parseErr := json.Unmarshal(content, &envelope); parseErr != nil {
		csmlog.Warn("alert filter: blocked_ips.json unparseable, suppression degraded",
			"path", blockedPath, "err", parseErr)
		return blockFile{}, false
	}

	// A section is worth decoding only if it was requested and holds
	// something other than nothing or a JSON null.
	wanted := func(flag blockFileSection, section json.RawMessage) bool {
		return sections&flag != 0 && len(section) > 0 && string(section) != "null"
	}

	var result blockFile
	if wanted(blockFileIPsSection, envelope.IPs) {
		if decodeErr := json.Unmarshal(envelope.IPs, &result.IPs); decodeErr != nil {
			csmlog.Warn("alert filter: blocked_ips.json blocked entries unparseable, suppression degraded",
				"path", blockedPath, "err", decodeErr)
		}
	}
	if wanted(blockFilePendingSection, envelope.Pending) {
		if decodeErr := json.Unmarshal(envelope.Pending, &result.Pending); decodeErr != nil {
			csmlog.Warn("alert filter: blocked_ips.json pending entries unparseable, suppression degraded",
				"path", blockedPath, "err", decodeErr)
		}
	}
	return result, true
}
package alert
import "github.com/pidginhost/csm/internal/config"
// EmitForTest runs emitAudit for a single finding with a minimal config
// whose only populated field is Hostname ("test"), so no audit sinks
// are configured. Tests, often in other packages, use it to trigger
// observer fan-out without setting up jsonl/syslog. It lives in a
// non-test file because Go tests cannot import _test.go symbols across
// packages.
func EmitForTest(f Finding) {
	cfg := &config.Config{Hostname: "test"}
	batch := []Finding{f}
	emitAudit(cfg, batch)
}
package alert
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/pidginhost/csm/internal/config"
)
// webhookTransport is the shared *http.Transport reused across every
// webhook dispatch. Reusing one transport keeps the underlying TCP /
// TLS connections in its keepalive pool so hosts that fire hundreds
// of webhooks per hour avoid a new handshake per alert. Per-call
// timeouts stay configurable via httpClient: each call wraps the
// shared transport in a fresh *http.Client carrying the requested
// timeout. http.DefaultTransport already configures sensible
// defaults; reuse it directly rather than instantiating a separate
// pool that would shadow Go's HTTP/2 / proxy plumbing.
//
// Tests replace it through SetWebhookTransportForTest.
var webhookTransport http.RoundTripper = http.DefaultTransport
// maxWebhookResponseDrainBytes (512 KiB) is the largest response body
// closeWebhookResponseBody will read to completion so the connection
// can be reused; bodies declared larger are closed without draining.
const maxWebhookResponseDrainBytes int64 = 512 << 10
// maxWebhookResponseDrainDuration caps how long closeWebhookResponseBody
// waits for the drain before closing the body anyway.
const maxWebhookResponseDrainDuration = 250 * time.Millisecond
// httpClient builds a fresh *http.Client for one webhook call. The
// client carries the requested overall timeout and uses the shared
// webhookTransport, so connections are pooled across dispatches while
// each call keeps its own timeout.
func httpClient(timeout time.Duration) *http.Client {
	client := new(http.Client)
	client.Transport = webhookTransport
	client.Timeout = timeout
	return client
}
// SetWebhookTransportForTest swaps the shared webhook transport for rt
// and returns a function that puts the previous transport back.
// The swap is a plain assignment with no locking: install the fake
// before any parallel dispatch starts and call restore once it is
// done.
func SetWebhookTransportForTest(rt http.RoundTripper) (restore func()) {
	original := webhookTransport
	restore = func() {
		webhookTransport = original
	}
	webhookTransport = rt
	return restore
}
// closeWebhookResponseBody closes a webhook response, first draining
// the body when that is cheap so the transport can reuse the
// connection.
//
// A nil response or nil body is ignored. The body is closed without
// draining when the server asked to close the connection, when the
// declared length is zero, or when it exceeds
// maxWebhookResponseDrainBytes. Otherwise up to the limit plus one byte
// is read in a goroutine and the caller waits at most
// maxWebhookResponseDrainDuration before closing regardless; the close
// is what ends a drain that is still running.
func closeWebhookResponseBody(resp *http.Response) {
	if resp == nil || resp.Body == nil {
		return
	}

	switch {
	case resp.Close,
		resp.ContentLength == 0,
		resp.ContentLength > maxWebhookResponseDrainBytes:
		_ = resp.Body.Close()
		return
	}

	drained := make(chan struct{})
	go func() {
		// One byte beyond the limit lets a body of exactly the limit
		// size reach EOF, which is what makes the connection reusable.
		capped := io.LimitReader(resp.Body, maxWebhookResponseDrainBytes+1)
		_, _ = io.Copy(io.Discard, capped)
		close(drained)
	}()

	select {
	case <-time.After(maxWebhookResponseDrainDuration):
		// Slow or streaming response: a pooled connection is not worth
		// delaying alert dispatch for.
	case <-drained:
	}
	_ = resp.Body.Close()
}
// SendWebhook posts an alert as JSON to cfg.Alerts.Webhook.URL.
//
// The payload shape follows cfg.Alerts.Webhook.Type:
//   - "slack":   {"text": "*subject*" followed by the body in a code fence}
//   - "discord": {"content": "**subject**" followed by the body in a code fence}
//   - otherwise: {"subject": ..., "body": ...}
//
// The request uses a 10-second client timeout. An error is returned
// when no URL is configured, when encoding or the POST fails, or when
// the response status is 400 or above.
func SendWebhook(cfg *config.Config, subject, body string) error {
	endpoint := cfg.Alerts.Webhook.URL
	if endpoint == "" {
		return fmt.Errorf("no webhook URL configured")
	}

	var fields map[string]string
	switch cfg.Alerts.Webhook.Type {
	case "slack":
		fields = map[string]string{
			"text": fmt.Sprintf("*%s*\n```\n%s\n```", subject, body),
		}
	case "discord":
		fields = map[string]string{
			"content": fmt.Sprintf("**%s**\n```\n%s\n```", subject, body),
		}
	default:
		fields = map[string]string{
			"subject": subject,
			"body":    body,
		}
	}

	payload, err := json.Marshal(fields)
	if err != nil {
		return fmt.Errorf("marshaling webhook payload: %w", err)
	}

	resp, err := httpClient(10*time.Second).Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return fmt.Errorf("webhook POST: %w", err)
	}
	defer closeWebhookResponseBody(resp)

	if resp.StatusCode < 400 {
		return nil
	}
	return fmt.Errorf("webhook returned %d", resp.StatusCode)
}
// SendPhpanelWebhookFinding posts one finding to the phpanel endpoint
// configured in cfg.Alerts.Webhook.URL. It keeps no state; filtering
// and batching are the caller's job.
//
// The JSON body has three keys: "hostname" (cfg.Hostname), "timestamp"
// (send time, UTC, RFC 3339) and "finding". The exact body bytes are
// signed with HMAC-SHA256 using the configured secret, and the result
// is sent as "sha256=<hex>" in the X-CSM-Signature header, alongside
// X-CSM-Hostname and a "csm" User-Agent.
//
// Errors: missing URL, missing secret, encoding failure, request
// construction failure, transport failure, or a response status of 400
// or above. The client timeout is 10 seconds.
func SendPhpanelWebhookFinding(cfg *config.Config, f Finding) error {
	endpoint := cfg.Alerts.Webhook.URL
	if endpoint == "" {
		return fmt.Errorf("phpanel webhook URL not set")
	}
	secret := phpanelWebhookSecret(cfg)
	if secret == "" {
		return fmt.Errorf("phpanel webhook HMAC secret not configured")
	}

	body, err := json.Marshal(map[string]interface{}{
		"hostname":  cfg.Hostname,
		"timestamp": time.Now().UTC().Format(time.RFC3339),
		"finding":   f,
	})
	if err != nil {
		return fmt.Errorf("marshaling phpanel payload: %w", err)
	}

	// Sign exactly the bytes that go on the wire.
	signer := hmac.New(sha256.New, []byte(secret))
	signer.Write(body)
	signature := "sha256=" + hex.EncodeToString(signer.Sum(nil))

	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return err
	}
	for name, value := range map[string]string{
		"Content-Type":    "application/json",
		"X-CSM-Signature": signature,
		"X-CSM-Hostname":  cfg.Hostname,
		"User-Agent":      "csm",
	} {
		req.Header.Set(name, value)
	}

	resp, err := httpClient(10 * time.Second).Do(req)
	if err != nil {
		return fmt.Errorf("phpanel webhook POST: %w", err)
	}
	defer closeWebhookResponseBody(resp)

	if resp.StatusCode < 400 {
		return nil
	}
	return fmt.Errorf("phpanel webhook HTTP %d", resp.StatusCode)
}
// phpanelWebhookSecret resolves the HMAC secret for phpanel webhooks.
// If HMACSecretEnv names an environment variable and that variable is
// non-empty, its value wins; in every other case the HMACSecret value
// from the config is returned (which may itself be empty).
func phpanelWebhookSecret(cfg *config.Config) string {
	fallback := cfg.Alerts.Webhook.HMACSecret

	envName := cfg.Alerts.Webhook.HMACSecretEnv
	if envName == "" {
		return fallback
	}
	if fromEnv := os.Getenv(envName); fromEnv != "" {
		return fromEnv
	}
	return fallback
}
// Package atomicio implements atomic file writes used by state-bearing
// callers (firewall engine, autoblock tracker, etc.). The package is a
// dependency leaf: it imports only the standard library so any caller
// can use AtomicWriteJSON without risking an import cycle through the
// existing state / store packages.
package atomicio
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
)
// AtomicWriteJSON encodes v as indented JSON and replaces the file at
// path atomically (temp file, fsync, rename, directory fsync). The
// first error encountered is returned.
//
// Before writing, it removes a leftover "<path>.tmp" file if one
// exists; failing to remove it, for any reason other than it not
// existing, aborts the write.
//
// Intended for state files that are read back on the next start or
// tick, where a torn write would leave the daemon with stale or
// corrupt state. Callers that only need a best-effort cache dump on a
// hot path should not pay for this.
func AtomicWriteJSON(path string, perm os.FileMode, v any) error {
	encoded, marshalErr := json.MarshalIndent(v, "", " ")
	if marshalErr != nil {
		return fmt.Errorf("marshal: %w", marshalErr)
	}

	// Clear the fixed-name temp file left behind by an earlier run.
	switch err := os.Remove(path + ".tmp"); {
	case err == nil, os.IsNotExist(err):
	default:
		return fmt.Errorf("remove stale tmp: %w", err)
	}

	return atomicWrite(path, perm, encoded)
}
// AtomicWrite writes already-serialized bytes to path atomically with the
// same write-tmp, fsync, rename, dir-fsync sequence as AtomicWriteJSON.
// Unlike AtomicWriteJSON it does not remove a leftover "<path>.tmp"
// file first. perm is applied to the temp file before it is renamed
// into place.
func AtomicWrite(path string, perm os.FileMode, data []byte) error {
return atomicWrite(path, perm, data)
}
// atomicWrite replaces the file at path with data so that readers see
// either the old or the new content, never a partial write.
//
// Sequence: create a uniquely named temp file in the destination
// directory, chmod it to perm, write, fsync, close, rename over path,
// then fsync the directory so the new entry survives power loss.
// Any failure up to and including the rename removes the temp file
// (best effort). Failures after the rename are still reported even
// though the destination already holds the new content.
func atomicWrite(path string, perm os.FileMode, data []byte) error {
	dir := filepath.Dir(path)

	// #nosec G304 -- caller owns the destination path; tmp lives in
	// the same operator-owned state directory.
	tmpFile, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp")
	if err != nil {
		return fmt.Errorf("open tmp: %w", err)
	}
	tmpName := tmpFile.Name()

	// abandon discards the temp file and wraps cause with the stage
	// name. stillOpen says whether the handle needs closing first.
	abandon := func(stillOpen bool, stage string, cause error) error {
		if stillOpen {
			_ = tmpFile.Close()
		}
		_ = os.Remove(tmpName)
		return fmt.Errorf("%s: %w", stage, cause)
	}

	if err := tmpFile.Chmod(perm); err != nil {
		return abandon(true, "chmod tmp", err)
	}

	written, err := tmpFile.Write(data)
	switch {
	case err != nil:
		return abandon(true, "write tmp", err)
	case written != len(data):
		return abandon(true, "write tmp", io.ErrShortWrite)
	}

	if err := tmpFile.Sync(); err != nil {
		return abandon(true, "fsync tmp", err)
	}
	if err := tmpFile.Close(); err != nil {
		return abandon(false, "close tmp", err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return abandon(false, "rename", err)
	}

	// #nosec G304 -- dir is filepath.Dir of caller-owned path; opened
	// read-only solely to fsync the directory after rename so the
	// new dentry survives a power-loss.
	dirHandle, err := os.Open(dir)
	if err != nil {
		return fmt.Errorf("open dir: %w", err)
	}
	if syncErr := dirHandle.Sync(); syncErr != nil {
		_ = dirHandle.Close()
		return fmt.Errorf("fsync dir: %w", syncErr)
	}
	if closeErr := dirHandle.Close(); closeErr != nil {
		return fmt.Errorf("close dir: %w", closeErr)
	}
	return nil
}
package attackdb
import (
"bufio"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// AttackType categorises observed attacks for grouping and scoring.
// The string values are persisted: they appear as the "type" field of
// Event and as the keys of IPRecord.AttackCounts in JSON.
type AttackType string
// Attack categories. checkToAttack assigns one of these to each
// recognised finding check name; AttackOther is used for records
// imported from the permanent blocklist.
const (
AttackBruteForce AttackType = "brute_force"
AttackWAFBlock AttackType = "waf_block"
AttackWebshell AttackType = "webshell"
AttackPhishing AttackType = "phishing"
AttackC2 AttackType = "c2"
AttackRecon AttackType = "recon"
AttackSPAM AttackType = "spam"
AttackCPanelLogin AttackType = "cpanel_login"
AttackFileUpload AttackType = "file_upload"
AttackReputation AttackType = "reputation"
AttackOther AttackType = "other"
)
// checkToAttack maps alert.Finding.Check values to attack types.
// RecordFinding ignores any finding whose check name is not a key of
// this map, so adding a check here is what makes it count towards an
// IP's attack record.
var checkToAttack = map[string]AttackType{
// Brute force
"wp_login_bruteforce": AttackBruteForce,
"xmlrpc_abuse": AttackBruteForce,
"ftp_bruteforce": AttackBruteForce,
"ssh_login_unknown_ip": AttackBruteForce,
"ssh_login_realtime": AttackBruteForce,
"webmail_bruteforce": AttackBruteForce,
"api_auth_failure": AttackBruteForce,
"api_auth_failure_realtime": AttackBruteForce,
"ftp_auth_failure_realtime": AttackBruteForce,
"email_auth_failure_realtime": AttackBruteForce,
"credential_stuffing": AttackBruteForce,
"pam_bruteforce": AttackBruteForce,
"smtp_bruteforce": AttackBruteForce,
"smtp_probe_abuse": AttackBruteForce,
"smtp_subnet_spray": AttackBruteForce,
"mail_bruteforce": AttackBruteForce,
"mail_subnet_spray": AttackBruteForce,
"mail_account_compromised": AttackBruteForce,
"admin_panel_bruteforce": AttackBruteForce,
// Webshells and malware
"webshell": AttackWebshell,
"new_webshell_file": AttackWebshell,
"obfuscated_php": AttackWebshell,
"php_dropper": AttackWebshell,
"suspicious_php_content": AttackWebshell,
"new_php_in_languages": AttackWebshell,
"new_php_in_upgrade": AttackWebshell,
"backdoor_binary": AttackWebshell,
"new_executable_in_config": AttackWebshell,
// Phishing
"phishing_page": AttackPhishing,
"phishing_php": AttackPhishing,
"phishing_iframe": AttackPhishing,
"phishing_redirector": AttackPhishing,
"phishing_credential_log": AttackPhishing,
"phishing_kit_archive": AttackPhishing,
"phishing_directory": AttackPhishing,
// C2 and suspicious processes
"fake_kernel_thread": AttackC2,
"suspicious_process": AttackC2,
"php_suspicious_execution": AttackC2,
"user_outbound_connection": AttackC2,
"exfiltration_paste_site": AttackC2,
// Recon
"wp_user_enumeration": AttackRecon,
// SPAM
"mail_per_account": AttackSPAM,
"exim_frozen_realtime": AttackSPAM,
// WAF
"modsec_block": AttackWAFBlock,
"waf_block": AttackWAFBlock,
// cPanel/webmail login
"cpanel_login": AttackCPanelLogin,
"cpanel_login_realtime": AttackCPanelLogin,
"cpanel_multi_ip_login": AttackCPanelLogin,
"webmail_login_realtime": AttackCPanelLogin,
"ftp_login": AttackCPanelLogin,
"ftp_login_realtime": AttackCPanelLogin,
"pam_login": AttackCPanelLogin,
// File upload
"cpanel_file_upload_realtime": AttackFileUpload,
// Reputation - known malicious IPs from threat database
"ip_reputation": AttackReputation,
// NOTE: "local_threat_score" is intentionally excluded - it is a derived
// finding, not a raw attack. Recording it would create a feedback loop
// that inflates EventCount by +1 every 10-minute cycle.
}
// Event is a single observed attack incident. RecordFinding builds one
// per recorded finding and queues it in DB.pendingEvents.
type Event struct {
Timestamp time.Time `json:"ts"` // when the incident was observed
IP string `json:"ip"` // source address the incident is attributed to
AttackType AttackType `json:"type"` // category from checkToAttack (or AttackOther for imports)
CheckName string `json:"check"` // originating finding check name
Severity int `json:"sev"` // numeric alert.Severity of the finding
Account string `json:"account,omitempty"` // affected account, when one could be extracted
Message string `json:"msg,omitempty"` // finding message, truncated by the producers in this file to 200
}
// IPRecord is the per-IP aggregated intelligence record.
//
// NOTE(review): "omitempty" has no effect on the time.Time fields
// below; encoding/json never treats a struct value as empty, so zero
// times are still written.
type IPRecord struct {
IP string `json:"ip"`
FirstSeen time.Time `json:"first_seen"` // set when the record is created
LastSeen time.Time `json:"last_seen"` // updated on every recorded finding
EventCount int `json:"event_count"` // total findings recorded for this IP
AttackCounts map[AttackType]int `json:"attack_counts"` // findings per attack category
Accounts map[string]int `json:"accounts"` // findings per targeted account
ThreatScore int `json:"threat_score"` // cached result of the score computation
AutoBlocked bool `json:"auto_blocked"` // set by MarkBlocked and blocklist import
// Sustained brute-force tracking, maintained by updateBruteForceWindow
// (defined elsewhere); exact window semantics are not visible here.
BruteForceWindowStart time.Time `json:"brute_force_window_start,omitempty"`
BruteForceWindowCount int `json:"brute_force_window_count,omitempty"`
BruteForceSustainedAt time.Time `json:"brute_force_sustained_at,omitempty"`
}
// DB is the in-memory attack database backed by JSON files.
// All fields are guarded by mu; helpers whose names end in "Locked"
// require the caller to already hold it.
type DB struct {
mu sync.RWMutex
records map[string]*IPRecord // live records keyed by IP string
deletedIPs map[string]struct{} // IPs removed since the last flush
dirtyIPs map[string]struct{} // IPs whose record changed since the last flush
pendingEvents []Event // events queued for persistence
dbPath string // storage directory; empty for NewForTest databases
dirty bool // flush gate: true when anything needs writing
stopCh chan struct{} // presumably signals backgroundSaver to stop -- confirm in Close/Stop
wg sync.WaitGroup // tracks the backgroundSaver goroutine started by Init
}
// markDirtyLocked notes that the record for ip has changed and must be
// written on the next flush. The dirty set is created lazily. The
// overall dirty flag is raised as well, because it remains the gate
// Flush uses to decide whether there is anything to write.
// The caller must hold db.mu.
func (db *DB) markDirtyLocked(ip string) {
	pendingWrites := db.dirtyIPs
	if pendingWrites == nil {
		pendingWrites = make(map[string]struct{})
		db.dirtyIPs = pendingWrites
	}
	pendingWrites[ip] = struct{}{}
	db.dirty = true
}
// markDeletedLocked notes that the record for ip has been removed and
// that the removal must be persisted on the next flush. Any pending
// "dirty" mark for the same IP is cancelled, since there is no longer
// a record to write, and the overall dirty flag is raised.
// The caller must hold db.mu.
func (db *DB) markDeletedLocked(ip string) {
	removed := db.deletedIPs
	if removed == nil {
		removed = make(map[string]struct{})
		db.deletedIPs = removed
	}
	removed[ip] = struct{}{}

	// Deleting from a nil map is a no-op, so no guard is needed.
	delete(db.dirtyIPs, ip)
	db.dirty = true
}
// Package-level singleton state for the attack database.
var (
// globalDB is the process-wide instance returned by Global; nil
// until Init or SetGlobal installs one.
globalDB *DB
// globalMu guards globalDB in Global and SetGlobal.
globalMu sync.Mutex
// dbInitOnce makes Init's construction run at most once per process.
dbInitOnce sync.Once
)
// Init initializes the global attack database and returns it.
//
// The first call creates <statePath>/attack_db (mode 0700, creation
// errors ignored), loads persisted records, prunes expired ones,
// starts the background saver goroutine and publishes the instance.
// Later calls skip all of that, ignore statePath, and return whatever
// instance is currently installed (which may have been replaced via
// SetGlobal).
//
// The instance is published and read under globalMu, the same lock
// Global and SetGlobal use. Previously the assignment and the final
// read bypassed the mutex, so a goroutine calling Global or SetGlobal
// while Init ran raced on globalDB.
func Init(statePath string) *DB {
	dbInitOnce.Do(func() {
		dbPath := statePath + "/attack_db"
		_ = os.MkdirAll(dbPath, 0700)

		db := &DB{
			records:    make(map[string]*IPRecord),
			deletedIPs: make(map[string]struct{}),
			dbPath:     dbPath,
			stopCh:     make(chan struct{}),
		}
		db.load()
		db.pruneExpired()

		// Background saver periodically flushes dirty records.
		db.wg.Add(1)
		go db.backgroundSaver()

		globalMu.Lock()
		globalDB = db
		globalMu.Unlock()
	})

	globalMu.Lock()
	defer globalMu.Unlock()
	return globalDB
}
// SeedFromPermanentBlocklist imports addresses from
// <statePath>/threat_db/permanent.txt into the attack database and
// returns how many were added. Those are IPs that attacked before and
// were auto-blocked, so each new record starts with one AttackOther
// event, AutoBlocked set and a freshly computed threat score.
//
// File format, one entry per line: the first whitespace-separated
// field is the IP; blank lines, lines starting with "#", and lines
// whose first field is not a valid IP are skipped. Text after the
// first "# " becomes the event message (capped at 200 by truncate);
// otherwise a generic reason is used. IPs already present in the
// database are left untouched. A file that cannot be opened yields 0.
//
// db.mu is held for the whole scan.
func (db *DB) SeedFromPermanentBlocklist(statePath string) int {
	listPath := statePath + "/threat_db/permanent.txt"
	// #nosec G304 -- fixed filename under operator-configured statePath.
	file, openErr := os.Open(listPath)
	if openErr != nil {
		return 0
	}
	defer func() { _ = file.Close() }()

	added := 0
	stamp := time.Now()
	lines := bufio.NewScanner(file)

	db.mu.Lock()
	for lines.Scan() {
		entry := strings.TrimSpace(lines.Text())
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}

		// entry is non-empty after trimming, so there is always a first field.
		ip := strings.Fields(entry)[0]
		if net.ParseIP(ip) == nil {
			continue
		}
		if _, tracked := db.records[ip]; tracked {
			continue // already tracked
		}

		// Extract reason from comment: "1.2.3.4 # reason [date]"
		reason := "auto-blocked (historical)"
		if marker := strings.Index(entry, "# "); marker > 0 {
			reason = strings.TrimSpace(entry[marker+2:])
		}

		rec := &IPRecord{
			IP:           ip,
			FirstSeen:    stamp,
			LastSeen:     stamp,
			EventCount:   1,
			AttackCounts: map[AttackType]int{AttackOther: 1},
			Accounts:     make(map[string]int),
			AutoBlocked:  true,
		}
		db.records[ip] = rec
		rec.ThreatScore = ComputeScore(rec)

		db.pendingEvents = append(db.pendingEvents, Event{
			Timestamp:  stamp,
			IP:         ip,
			AttackType: AttackOther,
			CheckName:  "permanent_blocklist_import",
			Severity:   2,
			Message:    truncate(reason, 200),
		})

		// A re-imported IP must not stay queued for deletion.
		delete(db.deletedIPs, ip)
		db.markDirtyLocked(ip)
		added++
	}
	db.mu.Unlock()

	return added
}
// Global returns the process-wide attack database, or nil if neither
// Init nor SetGlobal has installed one yet. The read is performed
// under globalMu.
func Global() *DB {
	globalMu.Lock()
	current := globalDB
	globalMu.Unlock()
	return current
}
// SetGlobal replaces the process-wide attack database with db (which
// may be nil). It mirrors store.SetGlobal: production code installs
// the instance exactly once through Init, while tests use SetGlobal to
// plug in a pre-seeded DB without going through the sync.Once-guarded
// Init path.
func SetGlobal(db *DB) {
	globalMu.Lock()
	defer globalMu.Unlock()
	globalDB = db
}
// NewForTest returns an in-memory DB seeded with copies of the given
// records. It is meant for unit tests only; production code goes
// through Init.
//
// No background saver is started, so there is no goroutine to clean
// up, and no disk path is set. Every record is deep-copied, including
// its AttackCounts and Accounts maps, so later changes to the caller's
// data cannot leak into the DB. The copied maps are always non-nil,
// even when the source maps were nil.
func NewForTest(records map[string]*IPRecord) *DB {
	out := &DB{
		records:    make(map[string]*IPRecord, len(records)),
		deletedIPs: make(map[string]struct{}),
		stopCh:     make(chan struct{}),
	}

	for ip, src := range records {
		dup := new(IPRecord)
		*dup = *src

		counts := make(map[AttackType]int, len(src.AttackCounts))
		for kind, n := range src.AttackCounts {
			counts[kind] = n
		}
		dup.AttackCounts = counts

		accounts := make(map[string]int, len(src.Accounts))
		for name, n := range src.Accounts {
			accounts[name] = n
		}
		dup.Accounts = accounts

		out.records[ip] = dup
	}
	return out
}
// RecordFinding folds a finding into the attack database.
//
// Findings whose check name is not present in checkToAttack, or from which
// no IP can be extracted, are ignored. Otherwise the per-IP record is created
// on first sight and its counters, brute-force window (only for checks
// accepted by tracksSustainedBruteScore), per-account tally and threat score
// are updated. The event is queued in pendingEvents for the next flush and
// the IP is marked dirty. A zero finding timestamp is replaced by the current
// time for the record fields; the queued event keeps the finding's own
// timestamp.
func (db *DB) RecordFinding(f alert.Finding) {
	attackType, tracked := checkToAttack[f.Check]
	if !tracked {
		return // not an attack-related check
	}
	ip := extractFindingIP(f)
	if ip == "" {
		return
	}
	account := extractFindingAccount(f)

	seenAt := f.Timestamp
	if seenAt.IsZero() {
		seenAt = time.Now()
	}

	db.mu.Lock()
	defer db.mu.Unlock()

	rec, known := db.records[ip]
	if !known {
		rec = &IPRecord{
			IP:           ip,
			FirstSeen:    seenAt,
			AttackCounts: make(map[AttackType]int),
			Accounts:     make(map[string]int),
		}
		db.records[ip] = rec
	}

	rec.LastSeen = seenAt
	rec.EventCount++
	rec.AttackCounts[attackType]++
	if tracksSustainedBruteScore(f.Check) {
		updateBruteForceWindow(rec, seenAt)
	}
	if account != "" {
		rec.Accounts[account]++
	}
	rec.ThreatScore = computeScoreAt(rec, seenAt)

	db.pendingEvents = append(db.pendingEvents, Event{
		Timestamp:  f.Timestamp,
		IP:         ip,
		AttackType: attackType,
		CheckName:  f.Check,
		Severity:   int(f.Severity),
		Account:    account,
		Message:    truncate(f.Message, 200),
	})
	// A re-observed IP must no longer be treated as pending deletion.
	delete(db.deletedIPs, ip)
	db.markDirtyLocked(ip)
}
// MarkBlocked flags the record for ip as auto-blocked, recomputes its threat
// score, and marks it dirty. IPs that are not tracked are left untouched.
func (db *DB) MarkBlocked(ip string) {
	db.mu.Lock()
	defer db.mu.Unlock()

	rec, found := db.records[ip]
	if !found {
		return
	}
	rec.AutoBlocked = true
	rec.ThreatScore = ComputeScore(rec)
	delete(db.deletedIPs, ip)
	db.markDirtyLocked(ip)
}
// LookupIP returns a snapshot of the record tracked for ip, or nil when the
// IP is unknown. The snapshot carries its own AttackCounts and Accounts maps,
// so the caller can read or modify it without racing against the DB.
func (db *DB) LookupIP(ip string) *IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()

	src, found := db.records[ip]
	if !found {
		return nil
	}

	snapshot := *src
	attacks := make(map[AttackType]int, len(src.AttackCounts))
	for kind, count := range src.AttackCounts {
		attacks[kind] = count
	}
	snapshot.AttackCounts = attacks

	accounts := make(map[string]int, len(src.Accounts))
	for name, count := range src.Accounts {
		accounts[name] = count
	}
	snapshot.Accounts = accounts

	return &snapshot
}
// TopAttackers returns deep copies of the n highest-ranked records, ordered
// by sortRecords (threat score descending, then event count descending).
//
// n is clamped to the range [0, number of tracked IPs]: asking for more
// records than exist returns all of them, and a negative n yields an empty
// slice instead of panicking on an out-of-range slice expression.
func (db *DB) TopAttackers(n int) []*IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()

	all := make([]*IPRecord, 0, len(db.records))
	for _, rec := range db.records {
		// Copy the record and its maps so callers never alias live state.
		cp := *rec
		cp.AttackCounts = make(map[AttackType]int, len(rec.AttackCounts))
		for k, v := range rec.AttackCounts {
			cp.AttackCounts[k] = v
		}
		cp.Accounts = make(map[string]int, len(rec.Accounts))
		for k, v := range rec.Accounts {
			cp.Accounts[k] = v
		}
		all = append(all, &cp)
	}

	// Sort by threat score descending, then event count
	sortRecords(all)

	if n < 0 {
		n = 0
	}
	if n > len(all) {
		n = len(all)
	}
	return all[:n]
}
// Flush persists everything accumulated since the previous flush. It is
// invoked periodically by backgroundSaver and once more from Stop.
//
// The pending events and the dirty flag are taken (and reset) under the
// lock; the actual writes happen after the lock is released. The returned
// error is always nil.
func (db *DB) Flush() error {
	db.mu.Lock()
	pending, wasDirty := db.pendingEvents, db.dirty
	db.pendingEvents, db.dirty = nil, false
	db.mu.Unlock()

	if len(pending) != 0 {
		db.appendEvents(pending)
	}
	if wasDirty {
		db.saveRecords()
	}
	return nil
}
// Stop signals the background saver to exit by closing stopCh, waits on the
// DB's WaitGroup, and then performs one final Flush (whose result is
// discarded).
//
// NOTE(review): stopCh is closed unconditionally, so a second call to Stop on
// the same DB would panic with "close of closed channel" -- confirm callers
// invoke it at most once.
func (db *DB) Stop() {
close(db.stopCh)
db.wg.Wait()
_ = db.Flush()
}
// backgroundSaver flushes the DB every 30 seconds until stopCh is closed.
// It calls wg.Done on exit so Stop can wait for it.
func (db *DB) backgroundSaver() {
	defer db.wg.Done()

	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			_ = db.Flush()
		case <-db.stopCh:
			return
		}
	}
}
// extractIP scans a finding message for an IP address.
//
// For each known separator, tried in order, it looks at the first
// whitespace-delimited token following the separator's first occurrence,
// strips trailing punctuation and anything from an embedded "(" onward
// (e.g. a glued "(AbuseIPDB" suffix), and returns the token if it parses as
// an IP. An empty string means no separator produced a valid address.
func extractIP(message string) string {
	separators := [...]string{" from ", ": ", "accessing server: "}
	for i := range separators {
		at := strings.Index(message, separators[i])
		if at < 0 {
			continue
		}
		tokens := strings.Fields(message[at+len(separators[i]):])
		if len(tokens) == 0 {
			continue
		}
		candidate := strings.TrimRight(tokens[0], ",:;)([]")
		// Strip AbuseIPDB score suffix like "(AbuseIPDB"
		if cut := strings.Index(candidate, "("); cut > 0 {
			candidate = candidate[:cut]
		}
		if net.ParseIP(candidate) == nil {
			continue
		}
		return candidate
	}
	return ""
}
// extractFindingIP picks the IP for a finding: the structured SourceIP field
// when it normalizes to a valid address, otherwise whatever extractIP can
// pull out of the message text.
func extractFindingIP(f alert.Finding) string {
	structured := normalizeRecordIP(f.SourceIP)
	if structured == "" {
		return extractIP(f.Message)
	}
	return structured
}
// normalizeRecordIP turns a raw address string into the canonical textual
// form produced by net.IP.String. Surrounding whitespace, a ":port" suffix
// and IPv6 square brackets are tolerated. It returns "" when the input is
// empty or does not contain a parseable IP.
func normalizeRecordIP(raw string) string {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}
	// Accept "host:port" / "[v6]:port"; anything else is parsed as-is.
	host, _, splitErr := net.SplitHostPort(candidate)
	if splitErr == nil {
		candidate = host
	}
	parsed := net.ParseIP(strings.Trim(candidate, "[]"))
	if parsed == nil {
		return ""
	}
	return parsed.String()
}
// extractFindingAccount derives the account label for a finding, in order of
// preference:
//  1. the Mailbox field -- returned as-is when it already contains "@" or no
//     Domain is set, otherwise joined with the lower-cased Domain;
//  2. the TenantID field;
//  3. whatever extractAccount finds in the message/details text.
//
// All structured fields are whitespace-trimmed before use.
func extractFindingAccount(f alert.Finding) string {
	mailbox := strings.TrimSpace(f.Mailbox)
	if mailbox == "" {
		if tenant := strings.TrimSpace(f.TenantID); tenant != "" {
			return tenant
		}
		return extractAccount(f.Message, f.Details)
	}

	domain := strings.TrimSpace(f.Domain)
	if domain != "" && !strings.Contains(mailbox, "@") {
		return mailbox + "@" + strings.ToLower(domain)
	}
	return mailbox
}
// extractAccount looks for an account name in free-form finding text.
//
// It first searches details, then message, for an "Account: <name>" label
// and returns the first whitespace-delimited token after it. Failing that,
// it returns the path component that follows "/home/" in message, provided
// that component is non-empty and terminated by a slash. Returns "" when
// nothing matches.
func extractAccount(message, details string) string {
	const label = "Account: "
	for _, src := range [2]string{details, message} {
		pos := strings.Index(src, label)
		if pos < 0 {
			continue
		}
		if words := strings.Fields(src[pos+len(label):]); len(words) > 0 {
			return words[0]
		}
	}

	// Try /home/username/ pattern
	const homePrefix = "/home/"
	pos := strings.Index(message, homePrefix)
	if pos < 0 {
		return ""
	}
	tail := message[pos+len(homePrefix):]
	end := strings.Index(tail, "/")
	if end <= 0 {
		return ""
	}
	return tail[:end]
}
// truncate returns s cut down to at most n runes (not bytes), so multi-byte
// characters are never split. Strings that already fit are returned
// unchanged. A non-positive n yields the empty string; previously a negative
// n panicked on the slice expression.
func truncate(s string, n int) string {
	if n <= 0 {
		return ""
	}
	// The byte length is an upper bound on the rune count, so a string that
	// fits in bytes fits in runes too -- skip the []rune allocation.
	if len(s) <= n {
		return s
	}
	r := []rune(s)
	if len(r) <= n {
		return s
	}
	return string(r[:n])
}
// updateBruteForceWindow advances the record's brute-force rate window for
// an event observed at ts.
//
// The window restarts (count reset to 1) when it has never been started,
// when ts predates its start, or when ts lies more than
// sustainedBruteForceWindow after its start. Otherwise the count is
// incremented, and once it reaches sustainedBruteForceThreshold the
// BruteForceSustainedAt marker is moved to ts -- unless ts is older than the
// marker already recorded.
func updateBruteForceWindow(rec *IPRecord, ts time.Time) {
	start := rec.BruteForceWindowStart
	switch {
	case start.IsZero(), ts.Before(start), ts.Sub(start) > sustainedBruteForceWindow:
		rec.BruteForceWindowStart = ts
		rec.BruteForceWindowCount = 1
		return
	}

	rec.BruteForceWindowCount++
	if rec.BruteForceWindowCount < sustainedBruteForceThreshold {
		return
	}
	if marked := rec.BruteForceSustainedAt; marked.IsZero() || !ts.Before(marked) {
		rec.BruteForceSustainedAt = ts
	}
}
// tracksSustainedBruteScore reports whether findings from the named check
// feed the sustained brute-force window. Only the real-time mail
// authentication failure check qualifies.
func tracksSustainedBruteScore(check string) bool {
	switch check {
	case "email_auth_failure_realtime":
		return true
	default:
		return false
	}
}
// RemoveIP drops the in-memory record for ip and registers the removal via
// markDeletedLocked, both under the write lock.
func (db *DB) RemoveIP(ip string) {
	db.mu.Lock()
	defer db.mu.Unlock()

	delete(db.records, ip)
	db.markDeletedLocked(ip)
}
// PruneExpired is the exported entry point for pruneExpired: it removes
// every record whose LastSeen timestamp is more than 90 days in the past.
func (db *DB) PruneExpired() {
db.pruneExpired()
}
// pruneExpired deletes every record whose LastSeen is older than the 90-day
// retention period and registers each removal via markDeletedLocked.
func (db *DB) pruneExpired() {
	const retention = 90 * 24 * time.Hour
	oldest := time.Now().Add(-retention)

	db.mu.Lock()
	defer db.mu.Unlock()

	for ip, rec := range db.records {
		if !rec.LastSeen.Before(oldest) {
			continue
		}
		delete(db.records, ip)
		db.markDeletedLocked(ip)
	}
}
// TotalIPs reports how many IPs currently have a record in the database.
func (db *DB) TotalIPs() int {
	db.mu.RLock()
	count := len(db.records)
	db.mu.RUnlock()
	return count
}
// AllRecords returns an independent copy of every tracked record. Each copy
// owns fresh AttackCounts and Accounts maps. The order of the returned slice
// follows map iteration and is therefore unspecified.
func (db *DB) AllRecords() []*IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()

	snapshot := make([]*IPRecord, 0, len(db.records))
	for _, src := range db.records {
		dup := *src

		attacks := make(map[AttackType]int, len(src.AttackCounts))
		for kind, count := range src.AttackCounts {
			attacks[kind] = count
		}
		dup.AttackCounts = attacks

		accounts := make(map[string]int, len(src.Accounts))
		for name, count := range src.Accounts {
			accounts[name] = count
		}
		dup.Accounts = accounts

		snapshot = append(snapshot, &dup)
	}
	return snapshot
}
// FormatTopLine renders a one-line summary -- number of tracked IPs and how
// many of them are flagged auto-blocked -- suitable for stderr logging.
func (db *DB) FormatTopLine() string {
	db.mu.RLock()
	tracked := len(db.records)
	var autoBlocked int
	for _, rec := range db.records {
		if rec.AutoBlocked {
			autoBlocked++
		}
	}
	db.mu.RUnlock()

	return fmt.Sprintf("%d IPs tracked, %d auto-blocked", tracked, autoBlocked)
}
package attackdb
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/pidginhost/csm/internal/store"
)
// File names and limits for the flat-file persistence fallback, used when no
// bbolt store is available. Both files are joined onto db.dbPath.
const (
// recordsFile holds the JSON-encoded map of per-IP records.
recordsFile = "records.json"
// eventsFile is the append-only event log, one JSON object per line.
eventsFile = "events.jsonl"
// maxEventsBytes is the size above which appendEvents rotates eventsFile.
maxEventsBytes = 10 * 1024 * 1024 // 10 MB
)
// load reads IP records from the bbolt store (if available) or from
// the flat-file records.json.
//
// Store path: every stored record is converted to the in-memory shape,
// normalized via normalizeLoadedRecord (records changed by normalization are
// marked dirty so the correction is persisted), and merged into db.records.
//
// Flat-file path: the decoded map replaces db.records wholesale. A missing
// or unreadable file is silently treated as "nothing to load"; a file that
// fails to parse is reported on stderr and ignored. A file whose content is
// the JSON literal null, or that contains null entries, is tolerated: the
// result is always a non-nil map holding only non-nil records, so later
// writes to db.records cannot panic.
func (db *DB) load() {
	if sdb := store.Global(); sdb != nil {
		storeRecords := sdb.LoadAllIPRecords()
		db.mu.Lock()
		defer db.mu.Unlock()
		if db.records == nil {
			db.records = make(map[string]*IPRecord, len(storeRecords))
		}
		for ip, sr := range storeRecords {
			rec := &IPRecord{
				IP:                    sr.IP,
				FirstSeen:             sr.FirstSeen,
				LastSeen:              sr.LastSeen,
				EventCount:            sr.EventCount,
				ThreatScore:           sr.ThreatScore,
				AutoBlocked:           sr.AutoBlocked,
				BruteForceWindowStart: sr.BruteForceWindowStart,
				BruteForceWindowCount: sr.BruteForceWindowCount,
				BruteForceSustainedAt: sr.BruteForceSustainedAt,
				AttackCounts:          make(map[AttackType]int),
				Accounts:              make(map[string]int),
			}
			for k, v := range sr.AttackCounts {
				rec.AttackCounts[AttackType(k)] = v
			}
			for k, v := range sr.Accounts {
				rec.Accounts[k] = v
			}
			if normalizeLoadedRecord(rec) {
				db.markDirtyLocked(ip)
			}
			db.records[ip] = rec
		}
		return
	}

	// Fallback: flat-file records.json.
	path := filepath.Join(db.dbPath, recordsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}
	var records map[string]*IPRecord
	if err := json.Unmarshal(data, &records); err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error loading %s: %v\n", path, err)
		return
	}
	if records == nil {
		// The file held the JSON literal null; never install a nil map.
		records = make(map[string]*IPRecord)
	}

	db.mu.Lock()
	defer db.mu.Unlock()
	for ip, rec := range records {
		if rec == nil {
			// A null entry carries no data and would crash normalization.
			// Deleting the current key while ranging over a map is safe.
			delete(records, ip)
			continue
		}
		if normalizeLoadedRecord(rec) {
			db.markDirtyLocked(ip)
		}
	}
	db.records = records
}
// normalizeLoadedRecord repairs a record freshly read from persistence and
// reports whether anything worth re-persisting was changed.
//
// Adjustments that count as a change:
//   - BruteForceWindowCount is capped at the recorded brute-force total;
//   - a missing BruteForceSustainedAt is backfilled from LastSeen when the
//     window count already meets sustainedBruteForceThreshold;
//   - ThreatScore is recomputed and replaced when it differs.
//
// Replacing nil maps with empty ones is done as well but deliberately does
// not count as a change.
func normalizeLoadedRecord(rec *IPRecord) bool {
	// Empty maps are an in-memory invariant; bbolt omits them, so nil-to-empty
	// alone should not dirty every account-free record on each startup.
	if rec.AttackCounts == nil {
		rec.AttackCounts = make(map[AttackType]int)
	}
	if rec.Accounts == nil {
		rec.Accounts = make(map[string]int)
	}

	dirty := false

	if ceiling := rec.AttackCounts[AttackBruteForce]; rec.BruteForceWindowCount > ceiling {
		rec.BruteForceWindowCount = ceiling
		dirty = true
	}

	if rec.BruteForceSustainedAt.IsZero() && rec.BruteForceWindowCount >= sustainedBruteForceThreshold {
		rec.BruteForceSustainedAt = rec.LastSeen
		dirty = true
	}

	if fresh := ComputeScore(rec); fresh != rec.ThreatScore {
		rec.ThreatScore = fresh
		dirty = true
	}

	return dirty
}
// saveRecords writes records to the bbolt store (if available) or to
// the flat-file records.json.
//
// Store path: only IPs present in dirtyIPs are written, and IPs present in
// deletedIPs are deleted from the store. Saves that fail are re-marked dirty;
// deletes that fail stay in deletedIPs and set db.dirty so a later flush
// retries them.
//
// Flat-file path: the complete records map is marshaled under the lock and
// written to a ".tmp" sibling which is then renamed over records.json. On any
// failure the snapshotted dirty set is handed back through requeueDirty.
//
// All disk/store I/O is performed without holding db.mu.
func (db *DB) saveRecords() {
if sdb := store.Global(); sdb != nil {
// Incremental: only records changed since the last flush are
// re-serialized. On a host tracking tens of thousands of IPs the old
// full rewrite cost seconds of CPU on every 30s flush and on shutdown.
// Snapshot and clear the dirty set under the lock so concurrent
// mutations land in the fresh set for the next flush; a write that
// fails is re-marked dirty below so it retries.
db.mu.Lock()
dirty := db.dirtyIPs
db.dirtyIPs = make(map[string]struct{})
records := make([]store.IPRecord, 0, len(dirty))
for ip := range dirty {
rec, ok := db.records[ip]
if !ok {
continue // removed since marked; deletedIPs carries the removal
}
records = append(records, toStoreIPRecord(rec))
}
// deletedIPs is only snapshotted here; entries are removed further down,
// and only for deletes the store acknowledged.
var deleted []string
for ip := range db.deletedIPs {
deleted = append(deleted, ip)
}
db.mu.Unlock()
var failed []string
for _, sr := range records {
if err := sdb.SaveIPRecord(sr); err != nil {
fmt.Fprintf(os.Stderr, "attackdb: store save %s: %v\n", sr.IP, err)
failed = append(failed, sr.IP)
}
}
if len(deleted) > 0 {
var removed []string
var failedDeletes []string
for _, ip := range deleted {
if err := sdb.DeleteIPRecord(ip); err != nil {
fmt.Fprintf(os.Stderr, "attackdb: store delete %s: %v\n", ip, err)
failedDeletes = append(failedDeletes, ip)
continue
}
removed = append(removed, ip)
}
if len(removed) > 0 {
db.mu.Lock()
for _, ip := range removed {
delete(db.deletedIPs, ip)
}
db.mu.Unlock()
}
if len(failedDeletes) > 0 {
// Failed deletes remain in deletedIPs; raising the dirty flag makes the
// next Flush call saveRecords again so they are retried.
db.mu.Lock()
db.dirty = true
db.mu.Unlock()
}
}
if len(failed) > 0 {
db.mu.Lock()
for _, ip := range failed {
db.markDirtyLocked(ip)
}
db.mu.Unlock()
}
return
}
// Fallback: flat-file records.json. The whole records map is rewritten
// each flush, so removals are reflected by absence and deletedIPs is
// redundant here -- but it must still be drained or it grows for the
// process lifetime on a host with no bbolt store. Snapshot under the same
// lock as the marshal and swap dirtyIPs before disk I/O, so mutations
// during the write land in a fresh set. Failed writes requeue the snapshot.
db.mu.Lock()
data, err := json.Marshal(db.records)
var drained []string
for ip := range db.deletedIPs {
drained = append(drained, ip)
}
flushedDirty := db.dirtyIPs
db.dirtyIPs = make(map[string]struct{})
db.mu.Unlock()
if err != nil {
fmt.Fprintf(os.Stderr, "attackdb: error marshaling records: %v\n", err)
db.requeueDirty(flushedDirty, len(drained) > 0)
return
}
path := filepath.Join(db.dbPath, recordsFile)
// Write to a temporary sibling first and rename it into place, so the
// previous records.json is only replaced by a completely written file.
tmpPath := path + ".tmp"
if err := os.WriteFile(tmpPath, data, 0600); err != nil {
fmt.Fprintf(os.Stderr, "attackdb: error writing %s: %v\n", tmpPath, err)
db.requeueDirty(flushedDirty, len(drained) > 0)
return
}
if err := os.Rename(tmpPath, path); err != nil {
fmt.Fprintf(os.Stderr, "attackdb: error renaming %s: %v\n", path, err)
db.requeueDirty(flushedDirty, len(drained) > 0)
return
}
// The rewrite succeeded: forget the deletions that were pending when the
// snapshot was taken.
db.mu.Lock()
for _, ip := range drained {
delete(db.deletedIPs, ip)
}
db.mu.Unlock()
}
// requeueDirty puts a failed flush back on the retry path: every IP in the
// given set is re-marked dirty, and when deletions were pending the DB-wide
// dirty flag is raised as well. It does nothing (and takes no lock) when
// there is nothing to requeue.
func (db *DB) requeueDirty(dirty map[string]struct{}, hasPendingDelete bool) {
	if !hasPendingDelete && len(dirty) == 0 {
		return
	}

	db.mu.Lock()
	defer db.mu.Unlock()

	if hasPendingDelete {
		db.dirty = true
	}
	for ip := range dirty {
		db.markDirtyLocked(ip)
	}
}
// toStoreIPRecord converts an in-memory record into the store's persistence
// type. The count maps are copied (attack types become plain string keys),
// so the returned value never shares a map with the live record.
func toStoreIPRecord(rec *IPRecord) store.IPRecord {
	attacks := make(map[string]int, len(rec.AttackCounts))
	for kind, count := range rec.AttackCounts {
		attacks[string(kind)] = count
	}

	accounts := make(map[string]int, len(rec.Accounts))
	for name, count := range rec.Accounts {
		accounts[name] = count
	}

	return store.IPRecord{
		IP:                    rec.IP,
		FirstSeen:             rec.FirstSeen,
		LastSeen:              rec.LastSeen,
		EventCount:            rec.EventCount,
		ThreatScore:           rec.ThreatScore,
		AutoBlocked:           rec.AutoBlocked,
		BruteForceWindowStart: rec.BruteForceWindowStart,
		BruteForceWindowCount: rec.BruteForceWindowCount,
		BruteForceSustainedAt: rec.BruteForceSustainedAt,
		AttackCounts:          attacks,
		Accounts:              accounts,
	}
}
// appendEvents writes events to the bbolt store (if available) or appends
// to the JSONL file, rotating if needed.
//
// Store path: each event is recorded individually; the event's index within
// the batch is passed to RecordAttackEvent, and an event with a zero
// timestamp is stamped with the current time. Per-event failures are logged
// to stderr and do not stop the batch.
//
// Flat-file path: the file is rotated first when it already exceeds
// maxEventsBytes, then the batch is appended as one JSON object per line.
// Encoding and write failures are reported on stderr rather than silently
// dropped; the affected events are not retried because the caller has
// already cleared its pending queue.
func (db *DB) appendEvents(events []Event) {
	if sdb := store.Global(); sdb != nil {
		for i, ev := range events {
			ts := ev.Timestamp
			if ts.IsZero() {
				ts = time.Now()
			}
			se := store.AttackEvent{
				Timestamp:  ts,
				IP:         ev.IP,
				AttackType: string(ev.AttackType),
				CheckName:  ev.CheckName,
				Severity:   ev.Severity,
				Account:    ev.Account,
				Message:    ev.Message,
			}
			if err := sdb.RecordAttackEvent(se, i); err != nil {
				fmt.Fprintf(os.Stderr, "attackdb: store event: %v\n", err)
			}
		}
		return
	}

	// Fallback: flat-file JSONL.
	path := filepath.Join(db.dbPath, eventsFile)

	// Check file size and rotate if needed
	if info, err := os.Stat(path); err == nil && info.Size() > maxEventsBytes {
		rotateEventsFile(path)
	}

	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error opening %s: %v\n", path, err)
		return
	}
	defer func() { _ = f.Close() }()

	w := bufio.NewWriter(f)
	enc := json.NewEncoder(w)

	// Keep going after a failed event so one bad record cannot discard the
	// rest of the batch; report the failures once instead of per event.
	failed := 0
	var lastErr error
	for _, ev := range events {
		if err := enc.Encode(ev); err != nil {
			failed++
			lastErr = err
		}
	}
	if failed > 0 {
		fmt.Fprintf(os.Stderr, "attackdb: error encoding %d of %d events for %s: %v\n",
			failed, len(events), path, lastErr)
	}
	if err := w.Flush(); err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error writing %s: %v\n", path, err)
	}
}
// rotateEventsFile shrinks the events log to roughly its newer half.
//
// The cut point is the byte after the first newline found at or beyond the
// file's midpoint, so the retained part always starts on a line boundary.
// The tail is written to a ".tmp" sibling and renamed over the original.
// The function is best-effort: on a read or write failure, or when no
// suitable newline leaves anything to keep, the file is left untouched.
func rotateEventsFile(path string) {
	// #nosec G304 -- path is filepath.Join under operator-configured db.dbPath.
	data, err := os.ReadFile(path)
	if err != nil {
		return
	}

	// Find the midpoint newline
	cut := len(data)
	for i := len(data) / 2; i < len(data); i++ {
		if data[i] == '\n' {
			cut = i + 1
			break
		}
	}
	if cut >= len(data) {
		// No newline after the midpoint, or nothing follows it.
		return
	}

	tmpPath := path + ".tmp"
	// #nosec G703 -- path is db.path, derived from the operator-configured
	// statePath at DB open time (see Open / NewDB).
	writeErr := os.WriteFile(tmpPath, data[cut:], 0600)
	if writeErr != nil {
		return
	}
	_ = os.Rename(tmpPath, path)
}
// QueryEvents returns the events recorded for ip, newest first.
//
// With a bbolt store available the query (including limit) is delegated to
// the store and the results are converted to Event values. Otherwise the
// JSONL file is scanned: lines that fail to parse are skipped, matches are
// collected in file order, cut down to the last `limit` entries when limit
// is positive and exceeded, and then reversed. In the flat-file path a
// missing or unreadable file, as well as a file without matches, yields nil.
func (db *DB) QueryEvents(ip string, limit int) []Event {
	if sdb := store.Global(); sdb != nil {
		stored := sdb.QueryAttackEvents(ip, limit)
		converted := make([]Event, 0, len(stored))
		for _, se := range stored {
			converted = append(converted, Event{
				Timestamp:  se.Timestamp,
				IP:         se.IP,
				AttackType: AttackType(se.AttackType),
				CheckName:  se.CheckName,
				Severity:   se.Severity,
				Account:    se.Account,
				Message:    se.Message,
			})
		}
		return converted
	}

	// Fallback: flat-file JSONL.
	path := filepath.Join(db.dbPath, eventsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	var matches []Event
	lines := bufio.NewScanner(f)
	lines.Buffer(make([]byte, 1024*1024), 1024*1024)
	for lines.Scan() {
		var ev Event
		if json.Unmarshal(lines.Bytes(), &ev) != nil {
			continue
		}
		if ev.IP != ip {
			continue
		}
		matches = append(matches, ev)
	}

	// Keep only the newest `limit` entries (the tail of the file order).
	if limit > 0 && len(matches) > limit {
		matches = matches[len(matches)-limit:]
	}

	// Flip in place so the newest event comes first.
	lo, hi := 0, len(matches)-1
	for lo < hi {
		matches[lo], matches[hi] = matches[hi], matches[lo]
		lo++
		hi--
	}
	return matches
}
package attackdb
import (
"sort"
"time"
)
// Parameters of the sustained brute-force detection used by
// updateBruteForceWindow and hasSustainedBruteForce.
const (
// sustainedBruteForceThreshold is the number of events inside one window
// at which the brute force counts as sustained.
sustainedBruteForceThreshold = 50
// sustainedBruteForceWindow is both the maximum span of one counting
// window and how long a sustained marker keeps contributing to the score.
sustainedBruteForceWindow = 30 * time.Minute
)
// ComputeScore returns a 0-100 local threat score from an IPRecord,
// evaluated at the current wall-clock time (it delegates to computeScoreAt
// with time.Now()).
//
// Scoring logic:
//   - Volume: min(event_count * 2, 30)
//   - Attack type bonuses (applied once per type, regardless of its count)
//   - Sustained brute force within the recent window: +30
//   - Multi-account targeting: +10
//   - Auto-blocked floor: 50
//   - Hard cap: 100
func ComputeScore(r *IPRecord) int {
return computeScoreAt(r, time.Now())
}
// computeScoreAt calculates the threat score of r as seen at time now.
//
// The score is the sum of a volume component (two points per event, at most
// 30), a fixed bonus for each attack type that occurred at least once, a
// bonus for more than five WAF blocks, a bonus for a currently sustained
// brute force, and a bonus when more than one account was targeted. An
// auto-blocked record never scores below 50, and the result never exceeds
// 100. now only influences the sustained brute-force check.
func computeScoreAt(r *IPRecord, now time.Time) int {
	// Volume component - caps at 30
	total := r.EventCount * 2
	if total > 30 {
		total = 30
	}

	// Bonuses granted as soon as the attack type was seen at all.
	presence := [...]struct {
		kind   AttackType
		points int
	}{
		{AttackC2, 35},
		{AttackWebshell, 30},
		{AttackPhishing, 25},
		{AttackBruteForce, 15},
		{AttackFileUpload, 20},
	}
	for _, bonus := range presence {
		if r.AttackCounts[bonus.kind] > 0 {
			total += bonus.points
		}
	}

	// The sustained-brute tier is rate-bound and tied to the raw mail-auth
	// signal so stale passwords and unrelated brute-force checks cannot become
	// block-eligible by slowly accumulating failures over retention.
	if hasSustainedBruteForce(r, now) {
		total += 30
	}

	// WAF blocks only count once there are more than five of them.
	if r.AttackCounts[AttackWAFBlock] > 5 {
		total += 10
	}

	// Multi-account targeting
	if len(r.Accounts) > 1 {
		total += 10
	}

	// Auto-blocked floor
	if r.AutoBlocked && total < 50 {
		total = 50
	}

	// Cap at 100
	if total > 100 {
		return 100
	}
	return total
}
// hasSustainedBruteForce reports whether r currently qualifies for the
// sustained brute-force bonus: it needs a sustained marker, at least
// sustainedBruteForceThreshold brute-force events overall, and the marker
// must not be older than sustainedBruteForceWindow relative to now.
func hasSustainedBruteForce(r *IPRecord, now time.Time) bool {
	if r.AttackCounts[AttackBruteForce] < sustainedBruteForceThreshold {
		return false
	}
	if r.BruteForceSustainedAt.IsZero() {
		return false
	}
	windowStart := now.Add(-sustainedBruteForceWindow)
	return !r.BruteForceSustainedAt.Before(windowStart)
}
// sortRecords orders recs in place: highest threat score first, ties broken
// by the higher event count.
func sortRecords(recs []*IPRecord) {
	sort.Slice(recs, func(a, b int) bool {
		left, right := recs[a], recs[b]
		if left.ThreatScore == right.ThreatScore {
			return left.EventCount > right.EventCount
		}
		return left.ThreatScore > right.ThreatScore
	})
}
package attackdb
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"github.com/pidginhost/csm/internal/store"
)
// statsCacheTTL is how long a result computed by Stats is served from the
// cache before it is recomputed.
const statsCacheTTL = 30 * time.Second
// AttackStats contains aggregate statistics for the API and dashboard.
// Lifetime figures are derived from the in-memory records; the time-windowed
// figures are derived from the persisted events (see computeStats).
type AttackStats struct {
TotalIPs int `json:"total_ips"` // number of tracked IP records
TotalEvents int `json:"total_events"` // sum of EventCount over all records
Last24hEvents int `json:"last_24h_events"` // events with a timestamp inside the last 24 hours
Last7dEvents int `json:"last_7d_events"` // events with a timestamp inside the last 7 days
BlockedIPs int `json:"blocked_ips"` // records flagged AutoBlocked
ByType map[AttackType]int `json:"by_type"` // lifetime, aggregated from IPRecord.AttackCounts
ByType24h map[AttackType]int `json:"by_type_24h"` // last 24h, aggregated from the events log
TopAttackers []*IPRecord `json:"top_attackers"` // result of TopAttackers(10)
HourlyBuckets [24]int `json:"hourly_buckets"` // last 24h, index 0 = oldest hour
DailyBuckets [7]int `json:"daily_buckets"` // last 7 days, index 0 = oldest day
}
// Cache backing (*DB).Stats. These are package-level variables, so the cache
// is shared by every DB value in the process rather than kept per instance.
var (
cachedStats AttackStats // most recently computed statistics
cachedStatsTime time.Time // when cachedStats was stored; zero until the first computation
cachedStatsMu sync.Mutex // guards cachedStats and cachedStatsTime
)
// Stats returns aggregate statistics. Results are cached for statsCacheTTL
// (30 seconds) so that repeated API calls do not re-read the complete event
// log each time. The computation itself runs without holding the cache
// mutex; the lock is only taken to read and to store the cached value.
func (db *DB) Stats() AttackStats {
	cachedStatsMu.Lock()
	stillFresh := time.Since(cachedStatsTime) < statsCacheTTL
	snapshot := cachedStats
	cachedStatsMu.Unlock()
	if stillFresh {
		return snapshot
	}

	computed := db.computeStats()

	cachedStatsMu.Lock()
	cachedStats, cachedStatsTime = computed, time.Now()
	cachedStatsMu.Unlock()

	return computed
}
// computeStats builds a fresh AttackStats value.
//
// Lifetime totals (IP count, event count, blocked IPs, per-type counts) are
// gathered from the in-memory records under the read lock. The 24-hour and
// 7-day figures and the hourly/daily histograms are then derived from every
// event returned by readAllEvents, and finally the ten top attackers are
// attached.
func (db *DB) computeStats() AttackStats {
	now := time.Now()
	dayAgo := now.Add(-24 * time.Hour)
	weekAgo := now.Add(-7 * 24 * time.Hour)

	out := AttackStats{
		ByType:    make(map[AttackType]int),
		ByType24h: make(map[AttackType]int),
	}

	db.mu.RLock()
	out.TotalIPs = len(db.records)
	for _, rec := range db.records {
		out.TotalEvents += rec.EventCount
		if rec.AutoBlocked {
			out.BlockedIPs++
		}
		for kind, count := range rec.AttackCounts {
			out.ByType[kind] += count
		}
	}
	db.mu.RUnlock()

	// Compute time-based stats from events log
	for _, ev := range db.readAllEvents() {
		age := now.Sub(ev.Timestamp)

		if ev.Timestamp.After(dayAgo) {
			out.Last24hEvents++
			// Skip events with no attack type — a malformed or legacy JSONL
			// line would otherwise surface as a "" key in the JSON response.
			if ev.AttackType != "" {
				out.ByType24h[ev.AttackType]++
			}
			if hours := int(age.Hours()); hours >= 0 && hours < 24 {
				out.HourlyBuckets[23-hours]++
			}
		}

		if ev.Timestamp.After(weekAgo) {
			out.Last7dEvents++
			if days := int(age.Hours() / 24); days >= 0 && days < 7 {
				out.DailyBuckets[6-days]++
			}
		}
	}

	out.TopAttackers = db.TopAttackers(10)
	return out
}
// readAllEvents loads every persisted event for statistics computation.
// It reads from the bbolt store when one is available and otherwise from the
// JSONL file, skipping lines that fail to parse. In the flat-file path a
// missing or unreadable file yields nil.
func (db *DB) readAllEvents() []Event {
	if sdb := store.Global(); sdb != nil {
		stored := sdb.ReadAllAttackEvents()
		converted := make([]Event, len(stored))
		for i, se := range stored {
			converted[i] = Event{
				Timestamp:  se.Timestamp,
				IP:         se.IP,
				AttackType: AttackType(se.AttackType),
				CheckName:  se.CheckName,
				Severity:   se.Severity,
				Account:    se.Account,
				Message:    se.Message,
			}
		}
		return converted
	}

	// Fallback: flat-file events.jsonl
	path := filepath.Join(db.dbPath, eventsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	var collected []Event
	lines := bufio.NewScanner(f)
	lines.Buffer(make([]byte, 1024*1024), 1024*1024)
	for lines.Scan() {
		var ev Event
		if json.Unmarshal(lines.Bytes(), &ev) == nil {
			collected = append(collected, ev)
		}
	}
	return collected
}
package auditd
import (
"os"
"os/exec"
)
// rulesPath is where Deploy writes the embedded auditd rules and what
// EnsureDeployed and Remove operate on.
const rulesPath = "/etc/audit/rules.d/csm.rules"
const rules = `## Continuous Security Monitor - auditd rules
# Password/auth file changes
-w /etc/shadow -p wa -k csm_shadow_change
-w /etc/passwd -p wa -k csm_passwd_change
-w /etc/group -p wa -k csm_group_change
# SSH config and keys
-w /etc/ssh/sshd_config -p wa -k csm_sshd_change
-w /root/.ssh/authorized_keys -p wa -k csm_root_ssh_keys
# WHM API tokens
-w /var/cpanel/authn/api_tokens_v2/ -p wa -k csm_whm_api_tokens
# Crontab modifications
-w /var/spool/cron/ -p wa -k csm_crontab_change
-w /etc/cron.d/ -p wa -k csm_crond_change
# Password change commands
-w /usr/bin/passwd -p x -k csm_passwd_exec
-w /usr/sbin/chpasswd -p x -k csm_chpasswd_exec
# CSM binary self-protection
-w /opt/csm/csm -p wa -k csm_binary_tamper
-w /etc/csm/csm.yaml -p wa -k csm_config_tamper
-w /opt/csm/csm.yaml -p wa -k csm_config_tamper
# Execution from suspicious locations
-a always,exit -F arch=b64 -S execve -F dir=/tmp -k csm_exec_tmp
-a always,exit -F arch=b64 -S execve -F dir=/dev/shm -k csm_exec_shm
# User account modifications
-w /usr/sbin/useradd -p x -k csm_useradd
-w /usr/sbin/usermod -p x -k csm_usermod
-w /usr/sbin/userdel -p x -k csm_userdel
# AF_ALG socket creation — CVE-2026-31431 "Copy Fail" exploit signature.
# AF_ALG (numeric family 38) is essentially never used by cPanel/PHP
# workloads, so any non-system UID hitting socket(AF_ALG, ...) is suspicious.
# Filter on uid, not auid: service-launched PHP/cPanel workers commonly have
# unset audit login UID while still running as the account user.
# Two rules — b64 covers native 64-bit binaries, b32 closes the i386
# emulation evasion path on x86_64 hosts with 32-bit compat enabled.
-a always,exit -F arch=b64 -S socket -F a0=38 -F uid>=1000 -k csm_af_alg_socket
-a always,exit -F arch=b32 -S socket -F a0=38 -F uid>=1000 -k csm_af_alg_socket
`
// Deploy writes the embedded rules to rulesPath (mode 0640) and then runs
// "augenrules --load" so auditd picks them up. It returns the write error,
// if any, and otherwise the result of the reload command.
func Deploy() error {
	// #nosec G306 -- /etc/audit/rules.d/csm.rules is read by the auditd
	// tooling (augenrules) on reload; 0640 keeps world-read off.
	writeErr := os.WriteFile(rulesPath, []byte(rules), 0640)
	if writeErr != nil {
		return writeErr
	}

	// Reload auditd rules
	reload := exec.Command("augenrules", "--load")
	return reload.Run()
}
// EnsureDeployed makes sure the rules file on disk matches the embedded
// rules constant, re-running Deploy when it does not. The daemon calls this
// at startup so that a CSM upgrade shipping new auditd rules does not stay
// inactive just because the package postinstall never invoked Deploy.
//
// It returns (redeployed, err). redeployed is true only when Deploy ran
// successfully; it is false when the file already matched or when an error
// is returned. A missing rules file counts as drift and triggers Deploy; any
// other read failure is returned as-is without attempting a deploy.
func EnsureDeployed() (bool, error) {
	onDisk, readErr := os.ReadFile(rulesPath)
	switch {
	case readErr == nil:
		if string(onDisk) == rules {
			return false, nil
		}
	case !os.IsNotExist(readErr):
		return false, readErr
	}

	if deployErr := Deploy(); deployErr != nil {
		return false, deployErr
	}
	return true, nil
}
// Remove deletes the rules file at rulesPath and runs "augenrules --load"
// again. Both steps are best-effort: their errors are ignored.
func Remove() {
	_ = os.Remove(rulesPath)
	reload := exec.Command("augenrules", "--load")
	_ = reload.Run()
}
// Package bpf provides the shared scaffolding that BPF-backed live monitors
// across the daemon use: a common Backend interface, backend-kind constants
// for operator config, sentinel errors that distinguish "not built" from
// "kernel unsupported", and a per-feature backend metric.
//
// Real BPF code (program loading, ringbuf consumption, capability probing)
// lives behind the linux && bpf build tag in sibling files. Default builds
// compile stubs that report all capabilities as unavailable.
package bpf
import (
"context"
"errors"
"sync"
"github.com/pidginhost/csm/internal/metrics"
)
// Backend is the shape every BPF-backed live monitor implements. The legacy
// fallback for each feature implements the same interface so the coordinator
// hands the daemon a uniform handle.
type Backend interface {
// Mode returns a string describing the backend.
// NOTE(review): presumably one of the Backend* kind values -- confirm
// against the implementations.
Mode() string
// EventCount returns a counter of events.
// NOTE(review): presumably events handled since start -- confirm.
EventCount() uint64
// Run drives the backend; it receives a context.
// NOTE(review): presumably blocks until ctx is cancelled -- confirm.
Run(ctx context.Context)
}
// Backend kind constants are the shared internal values used after each
// feature validates its operator-facing config setting. Individual features
// may keep older public values, such as AF_ALG's "auditd", and map them to
// BackendLegacy internally.
const (
BackendAuto = "auto"
BackendBPF = "bpf"
BackendLegacy = "legacy"
BackendNone = "none"
)
// ErrNotBuilt is returned by feature loaders when CSM was built without the
// bpf build tag. The coordinator treats this identically to a kernel that
// lacks the required BPF program type: log it and fall back to legacy.
var ErrNotBuilt = errors.New("BPF support not compiled in (rebuild with -tags bpf)")
// ErrUnsupported is returned when CSM was built with the bpf tag but the
// running kernel does not accept the requested BPF program type. Distinct
// from ErrNotBuilt so operator logs explain whether the fix is "rebuild" or
// "newer kernel".
var ErrUnsupported = errors.New("kernel does not support requested BPF program type")
// Package-level state behind MetricFor, SetActive and ActiveKind.
var (
backendMetric *metrics.GaugeVec // shared csm_bpf_backend gauge, created on first MetricFor call
backendMetricOnce sync.Once // ensures the gauge is created and registered once
activeMu sync.RWMutex // guards activeKinds
activeKinds map[string]string // feature -> kind last passed to SetActive; allocated lazily
)
// MetricFor returns the shared csm_bpf_backend gauge vector, creating and
// registering it on the first call. The feature argument is ignored; it only
// exists to make call sites read naturally, since all features share one
// GaugeVec and are told apart by the "feature" label value. Prefer SetActive
// over calling With on the returned vector directly.
func MetricFor(_ string) *metrics.GaugeVec {
	const metricName = "csm_bpf_backend"
	backendMetricOnce.Do(func() {
		labels := []string{"feature", "kind"}
		backendMetric = metrics.NewGaugeVec(
			metricName,
			"Active backend for each BPF-backed live monitor; 1 for the selected kind, 0 otherwise.",
			labels,
		)
		metrics.MustRegister(metricName, backendMetric)
	})
	return backendMetric
}
// SetActive publishes the backend selected for a feature. For each of the
// kinds bpf, legacy and none the gauge series is set to 1 when it equals
// active and to 0 otherwise; an active value outside those three therefore
// leaves all three series at 0. The active kind is also remembered
// in-process so ActiveKind can report it without consulting the metric
// registry. Intended to be called by the coordinator after backend
// selection.
func SetActive(feature, active string) {
	gauge := MetricFor(feature)
	kinds := [...]string{BackendBPF, BackendLegacy, BackendNone}
	for i := range kinds {
		value := 0.0
		if kinds[i] == active {
			value = 1.0
		}
		gauge.With(feature, kinds[i]).Set(value)
	}

	activeMu.Lock()
	defer activeMu.Unlock()
	if activeKinds == nil {
		activeKinds = make(map[string]string)
	}
	activeKinds[feature] = active
}
// ActiveKind reports the kind most recently passed to SetActive for the
// given feature, or "" when SetActive has not been called for it. It lets
// internal/health render per-feature capability strings without importing
// the metrics registry.
func ActiveKind(feature string) string {
	activeMu.RLock()
	kind := activeKinds[feature]
	activeMu.RUnlock()
	return kind
}
package bpf
import "sync"
// Capabilities reports which BPF program types the running kernel can
// actually load and attach. Populated once at daemon startup by Probe.
//
// Each field maps to a kernel feature, not a CSM feature: a single CSM
// feature (e.g. AF_ALG kernel-side blocking) may need multiple capability bits
// (LSMAttach + Ringbuf).
type Capabilities struct {
	LSMAttach  bool // BPF LSM programs can attach (kernel >= 5.7 with BPF LSM trampoline)
	CgroupSock bool // BPF_PROG_TYPE_CGROUP_SOCK_ADDR can attach to cgroup/connect4 (>= 4.10)
	Tracepoint bool // BPF_PROG_TYPE_TRACEPOINT can attach to sched/sched_process_exec (>= 4.7)
	Ringbuf    bool // BPF_MAP_TYPE_RINGBUF available (>= 5.8)
}

// Any reports whether at least one capability bit is set. Callers use it
// when all they need to know is whether any BPF surface is usable before
// choosing between legacy and auto.
func (c Capabilities) Any() bool {
	switch {
	case c.LSMAttach, c.CgroupSock, c.Tracepoint, c.Ringbuf:
		return true
	}
	return false
}
// Process-wide cache behind Probe.
var (
probeOnce sync.Once // ensures probeKernel runs a single time
probeResult Capabilities // value produced by that single probeKernel call
)
// Probe reports the BPF capabilities of this process. probeKernel runs on
// the first call only; its result is cached and every later call returns
// the same value without probing again.
func Probe() Capabilities {
	probeOnce.Do(func() { probeResult = probeKernel() })
	return probeResult
}
//go:build !(linux && bpf)
package bpf
// probeKernel is the stub compiled when the linux && bpf build constraint is
// not satisfied. It reports the zero-value Capabilities, so every
// coordinator on such a build sees "no BPF surface available" and selects
// the legacy backend.
func probeKernel() Capabilities {
	return Capabilities{}
}
package bpf
// dropEventLogStride is the interval, in dropped events, between log lines
// once the first drop has been reported. 256 keeps sustained back-pressure
// visible without flooding the daemon log.
const dropEventLogStride uint64 = 256

// shouldLogDroppedEvent decides whether the drop path should emit a log line
// for the given running drop count: yes for the very first drop and for
// every positive multiple of dropEventLogStride, no otherwise (including a
// count of zero).
func shouldLogDroppedEvent(dropped uint64) bool {
	switch {
	case dropped == 0:
		return false
	case dropped == 1:
		return true
	default:
		return dropped%dropEventLogStride == 0
	}
}
// Package broadcast provides a one-to-many publish bus for alert.Finding
// events. Subscribers each get a buffered channel; if a subscriber's
// buffer fills, that subscriber drops the message rather than blocking
// the publisher. Used by the SSE event stream and any other in-process
// passive consumer.
//
// This is intentionally separate from the daemon's primary alert pipeline
// (the unbuffered or large-buffered alertCh that feeds Dispatch). The bus
// is a side-channel for observers that should not influence dispatch.
package broadcast
import (
"sync"
"github.com/pidginhost/csm/internal/alert"
)
// defaultMaxSubscribers caps concurrent subscribers so a flood of event-stream
// connections (each a goroutine plus a buffered channel) cannot exhaust the
// daemon's memory. Generous for an operator dashboard. NewBus installs this
// value as the initial cap; SetMaxSubscribers can override it, and only
// TrySubscribe enforces it.
const defaultMaxSubscribers = 256
// Bus fans out published findings to every subscriber. All fields are
// guarded by mu.
type Bus struct {
mu sync.RWMutex // read-locked by Publish, write-locked by every mutating method
subscribers map[chan alert.Finding]struct{} // set of live subscriber channels
buffer int // capacity of each subscriber channel
maxSubs int // subscriber cap enforced by TrySubscribe
closed bool // set by Close; later subscriptions receive an already-closed channel
}
// NewBus builds an empty Bus whose subscriber channels have the given buffer
// capacity. Any buffer below 1 is replaced by 16. The subscriber cap starts
// at defaultMaxSubscribers.
func NewBus(buffer int) *Bus {
	const fallbackBuffer = 16

	bus := &Bus{
		subscribers: make(map[chan alert.Finding]struct{}),
		maxSubs:     defaultMaxSubscribers,
		buffer:      fallbackBuffer,
	}
	if buffer >= 1 {
		bus.buffer = buffer
	}
	return bus
}
// SetMaxSubscribers replaces the concurrent-subscriber cap enforced by
// TrySubscribe. Values below 1 are ignored and leave the current cap in
// place. The update is made under the bus lock.
func (b *Bus) SetMaxSubscribers(n int) {
	if n >= 1 {
		b.mu.Lock()
		defer b.mu.Unlock()
		b.maxSubs = n
	}
}
// TrySubscribe is Subscribe with the concurrent-subscriber cap enforced. It
// returns ok=false (and a nil channel) when the cap is reached so an untrusted
// caller (the SSE endpoint, reachable with a low-trust read token) cannot open
// unbounded long-lived streams. Use this for externally-driven subscriptions;
// Subscribe remains for trusted in-process consumers.
//
// On a closed bus it returns an already-closed channel with ok=true, so the
// caller's read loop terminates immediately, exactly as with Subscribe.
//
// The subscriber channel is allocated only after the cap check passes, so a
// rejected call does not pay for a buffered channel it will never use.
func (b *Bus) TrySubscribe() (<-chan alert.Finding, bool) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		ch := make(chan alert.Finding, b.buffer)
		close(ch)
		return ch, true
	}
	if len(b.subscribers) >= b.maxSubs {
		return nil, false
	}

	ch := make(chan alert.Finding, b.buffer)
	b.subscribers[ch] = struct{}{}
	return ch, true
}
// Subscribe registers a new buffered channel that receives every finding
// published from now on and returns its receive side. No subscriber cap is
// applied here. If the bus has already been closed, the returned channel is
// closed and never registered. Callers must Unsubscribe when done so the
// channel is released.
func (b *Bus) Subscribe() <-chan alert.Finding {
	b.mu.Lock()
	defer b.mu.Unlock()

	sub := make(chan alert.Finding, b.buffer)
	if !b.closed {
		b.subscribers[sub] = struct{}{}
		return sub
	}
	close(sub)
	return sub
}
// Unsubscribe detaches ch from the bus and closes it, which tells the consumer
// to leave its read loop. Passing a channel that is not currently registered
// is a no-op.
//
// The registry is keyed by the bidirectional channel while callers only hold
// the receive side, so the matching entry is found by scanning the set.
func (b *Bus) Unsubscribe(ch <-chan alert.Finding) {
	b.mu.Lock()
	defer b.mu.Unlock()

	var match chan alert.Finding
	for candidate := range b.subscribers {
		if (<-chan alert.Finding)(candidate) == ch {
			match = candidate
			break
		}
	}
	if match == nil {
		return
	}
	delete(b.subscribers, match)
	close(match)
}
// Publish offers f to every current subscriber without ever blocking: a
// subscriber whose buffer is full simply misses this finding. Publishing on a
// closed bus does nothing.
func (b *Bus) Publish(f alert.Finding) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	if !b.closed {
		for sub := range b.subscribers {
			select {
			case sub <- f:
				// Delivered.
			default:
				// Buffer full: drop for this subscriber instead of stalling.
			}
		}
	}
}
// Close marks the bus closed, then closes and forgets every subscriber
// channel. Calling it again is a no-op.
func (b *Bus) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if !b.closed {
		b.closed = true
		for sub := range b.subscribers {
			delete(b.subscribers, sub)
			close(sub)
		}
	}
}
package challenge
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// CaptchaProvider verifies a third-party CAPTCHA token. Implementations
// post the operator's secret + the visitor's response token to the
// provider's siteverify endpoint and return a single bool: did the
// provider accept this submission?
type CaptchaProvider interface {
// Name returns the provider's identifier (for the built-in implementation,
// the lower-cased configured name such as "turnstile" or "hcaptcha").
Name() string
// Verify submits token (and, when non-empty, remoteIP) for verification.
// The built-in implementation returns (false, nil) when the provider
// answers 200 with success=false, and a non-nil error for an empty token,
// a transport failure, a non-200 status or an undecodable response.
Verify(ctx context.Context, token, remoteIP string) (bool, error)
}
// providerEndpoint maps each supported provider name to its siteverify URL.
// NewCaptchaProvider looks names up here and rejects any name that is absent.
// It is exposed as a package var so tests can point a provider at
// httptest.Server rather than the live siteverify URL.
var providerEndpoint = map[string]string{
"turnstile": "https://challenges.cloudflare.com/turnstile/v0/siteverify",
"hcaptcha": "https://hcaptcha.com/siteverify",
}
// captchaProvider implements both Cloudflare Turnstile and hCaptcha;
// they accept identical request/response shapes (POST form, JSON
// {"success":bool} reply) so a single struct covers both.
type captchaProvider struct {
name string // normalized provider name, returned by Name
endpoint string // siteverify URL taken from providerEndpoint
secret string // operator secret sent as the "secret" form field
client *http.Client // HTTP client carrying the configured timeout
}
// NewCaptchaProvider builds the provider selected by name, which is trimmed
// and lower-cased before lookup.
//
// An empty name means CAPTCHA is not enabled: the result is (nil, nil) and
// the server treats the nil provider as "feature off". An unknown name or an
// empty secret yields an error. A timeout of zero or less falls back to ten
// seconds for the provider's HTTP client.
func NewCaptchaProvider(name, secret string, timeout time.Duration) (CaptchaProvider, error) {
	key := strings.ToLower(strings.TrimSpace(name))
	if key == "" {
		return nil, nil
	}

	target, known := providerEndpoint[key]
	switch {
	case !known:
		return nil, fmt.Errorf("unknown captcha provider %q (want turnstile or hcaptcha)", key)
	case secret == "":
		return nil, fmt.Errorf("captcha provider %q requires secret_key", key)
	}

	httpTimeout := timeout
	if httpTimeout <= 0 {
		httpTimeout = 10 * time.Second
	}

	provider := &captchaProvider{
		name:     key,
		endpoint: target,
		secret:   secret,
		client:   &http.Client{Timeout: httpTimeout},
	}
	return provider, nil
}
// Name returns the normalized provider name this instance was built with.
func (p *captchaProvider) Name() string {
	return p.name
}
// Verify submits the operator's secret and the visitor's token to the
// provider's siteverify endpoint as a form-encoded POST. remoteIP is optional
// and is sent as "remoteip" only when non-empty; both Turnstile and hCaptcha
// accept it for binding the verification to a single client.
//
// It returns the provider's "success" flag. An empty token, a request
// construction failure, a transport error, a non-200 status and an
// undecodable body all return (false, err); a 200 reply carrying
// success=false returns (false, nil).
func (p *captchaProvider) Verify(ctx context.Context, token, remoteIP string) (bool, error) {
	if token == "" {
		return false, errors.New("empty captcha token")
	}

	payload := url.Values{
		"secret":   {p.secret},
		"response": {token},
	}
	if remoteIP != "" {
		payload.Set("remoteip", remoteIP)
	}

	// #nosec G704 -- p.endpoint is set from the providerEndpoint package-level map (turnstile / hcaptcha) by NewCaptchaProvider, which rejects unknown names. Not attacker-controlled; SSRF is not possible.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.endpoint, strings.NewReader(payload.Encode()))
	if err != nil {
		return false, fmt.Errorf("building siteverify request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// #nosec G704 -- same as above: the request URL is hardcoded to a known siteverify endpoint, never operator or attacker input.
	resp, err := p.client.Do(req)
	if err != nil {
		return false, fmt.Errorf("siteverify call: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return false, fmt.Errorf("siteverify status %d", status)
	}

	// Only the "success" field of the reply is of interest.
	var verdict struct {
		Success bool `json:"success"`
	}
	decodeErr := json.NewDecoder(resp.Body).Decode(&verdict)
	if decodeErr != nil {
		return false, fmt.Errorf("decoding siteverify response: %w", decodeErr)
	}
	return verdict.Success, nil
}
package challenge
import (
"context"
"net"
"strings"
"sync"
"time"
)
// crawlerSuffix names a verified-crawler family along with the PTR
// record suffixes that legitimate hosts in that family use.
type crawlerSuffix struct {
name string // family identifier, e.g. "googlebot"
domains []string // accepted PTR suffixes, lower-case, each with leading and trailing dot
}
// builtinCrawlers lists the known canonical reverse-DNS suffixes for
// the supported crawler families, keyed by lower-case family name.
// Adding a new family means appending to this map and documenting the
// name in csm.yaml's challenge.verified_crawlers.providers.
//
// Suffixes carry a trailing dot because they are matched with
// strings.HasSuffix against fully-qualified PTR names.
var builtinCrawlers = map[string]crawlerSuffix{
"googlebot": {name: "googlebot", domains: []string{".googlebot.com.", ".google.com."}},
"bingbot": {name: "bingbot", domains: []string{".search.msn.com."}},
}
// Resolver matches the subset of net.Resolver that CrawlerVerifier
// uses, so tests can swap in a fake without spinning up a real DNS
// server.
type Resolver interface {
// LookupAddr performs the reverse (PTR) lookup for addr. The verifier
// expects the returned names to be fully qualified (trailing dot).
LookupAddr(ctx context.Context, addr string) (names []string, err error)
// LookupHost performs the forward lookup for host and returns its
// addresses in textual form.
LookupHost(ctx context.Context, host string) (addrs []string, err error)
}
// CrawlerVerifier classifies an IP as a verified search crawler iff
// the IP's reverse-DNS PTR matches one of the configured suffixes AND
// the PTR forward-resolves back to the same IP. The verifier caches
// both positive and negative results; positive cache TTL is the
// configured cacheTTL, negative is one-fifth of that to keep a
// transiently-broken resolver from locking out a legitimate crawler.
type CrawlerVerifier struct {
suffixes []crawlerSuffix // enabled crawler families; empty means the verifier is disabled
resolver Resolver // DNS backend used for PTR and forward lookups
posTTL time.Duration // lifetime of a cached "verified" result
negTTL time.Duration // lifetime of a cached "not verified" result
mu sync.Mutex // guards cache
cache map[string]cacheEntry // verification results keyed by the IP string as passed in
maxSize int // entry cap applied on insert; values <= 0 disable the cap
}
// crawlerCacheMaxEntries bounds the verifier cache between the daemon's
// 60-second prune ticks. Without a cap a scan from many unique source IPs
// (every miss inserts a negative entry) grows the map to request-rate x 60s,
// an external memory-pressure lever. At the cap an insert first drops expired
// entries, then evicts the soonest-to-expire entry to make room.
// NewCrawlerVerifier installs this value as the verifier's maxSize.
const crawlerCacheMaxEntries = 50000
// cacheEntry is one cached verification outcome for a single IP.
type cacheEntry struct {
verified bool // result of the reverse + forward DNS check
expires time.Time // instant after which the entry is treated as stale
}
// NewCrawlerVerifier builds a verifier with the named crawler families
// enabled. Names are trimmed and lower-cased before lookup; names that are
// not in builtinCrawlers are silently ignored (operators may have configured
// a family this binary does not know about; that is harmless).
//
// A cacheTTL of zero or less falls back to 15 minutes; negative results are
// cached for one fifth of the positive TTL. A nil resolver falls back to
// net.DefaultResolver.
func NewCrawlerVerifier(providers []string, cacheTTL time.Duration, resolver Resolver) *CrawlerVerifier {
	positive := cacheTTL
	if positive <= 0 {
		positive = 15 * time.Minute
	}

	dns := resolver
	if dns == nil {
		dns = net.DefaultResolver
	}

	families := make([]crawlerSuffix, 0, len(providers))
	for _, raw := range providers {
		key := strings.ToLower(strings.TrimSpace(raw))
		family, known := builtinCrawlers[key]
		if !known {
			continue
		}
		families = append(families, family)
	}

	return &CrawlerVerifier{
		suffixes: families,
		resolver: dns,
		posTTL:   positive,
		negTTL:   positive / 5,
		cache:    make(map[string]cacheEntry),
		maxSize:  crawlerCacheMaxEntries,
	}
}
// Enabled reports whether at least one crawler family is configured. It is
// safe to call on a nil verifier, which counts as disabled. The server uses
// this to skip the verifier entirely (no DNS round trip) when the operator
// has not opted in.
func (v *CrawlerVerifier) Enabled() bool {
	if v == nil {
		return false
	}
	return len(v.suffixes) != 0
}
// Verified reports whether ip belongs to an enabled crawler family. A cached
// answer is returned when one is still live; otherwise the reverse-DNS plus
// forward-confirm probe runs and its outcome is cached. A disabled (or nil)
// verifier always answers false without touching DNS.
func (v *CrawlerVerifier) Verified(ctx context.Context, ip string) bool {
	if !v.Enabled() {
		return false
	}

	cached, found := v.cacheGet(ip)
	if found {
		return cached
	}

	outcome := v.probe(ctx, ip)
	v.cachePut(ip, outcome)
	return outcome
}
// probe performs the uncached verification for ip: a reverse lookup, a suffix
// check on each returned PTR name, and a forward lookup of every matching
// name that must include ip. It returns true on the first name that
// forward-confirms. Resolver errors and empty results count as "not
// verified".
//
// Forward-confirmation compares addresses semantically (net.IP.Equal) rather
// than as raw strings, so differing textual forms of the same address — a
// non-canonical IPv6 spelling, or an IPv4-mapped IPv6 form — still confirm.
// Exact string equality is kept as a fallback for values net.ParseIP cannot
// parse.
func (v *CrawlerVerifier) probe(ctx context.Context, ip string) bool {
	names, err := v.resolver.LookupAddr(ctx, ip)
	if err != nil || len(names) == 0 {
		return false
	}

	want := net.ParseIP(ip) // nil when ip is not a plain IP literal

	for _, name := range names {
		// LookupAddr returns FQDNs with a trailing dot. The suffix
		// list also has trailing dots so HasSuffix is unambiguous.
		lower := strings.ToLower(name)
		if !v.suffixMatches(lower) {
			continue
		}
		addrs, err := v.resolver.LookupHost(ctx, strings.TrimSuffix(lower, "."))
		if err != nil {
			continue
		}
		for _, addr := range addrs {
			if addr == ip {
				return true
			}
			if want == nil {
				continue
			}
			if got := net.ParseIP(addr); got != nil && got.Equal(want) {
				return true
			}
		}
	}
	return false
}
// suffixMatches reports whether name ends with any PTR suffix of any enabled
// crawler family. name is expected to be lower-case and fully qualified
// (trailing dot), matching the form of the configured suffixes.
func (v *CrawlerVerifier) suffixMatches(name string) bool {
	for i := range v.suffixes {
		for _, suffix := range v.suffixes[i].domains {
			if strings.HasSuffix(name, suffix) {
				return true
			}
		}
	}
	return false
}
// cacheGet looks up ip in the cache. The second result is true only when a
// live entry exists, in which case the first result is the cached verdict.
// An entry that has expired is removed on the spot and reported as a miss.
func (v *CrawlerVerifier) cacheGet(ip string) (bool, bool) {
	v.mu.Lock()
	defer v.mu.Unlock()

	entry, present := v.cache[ip]
	switch {
	case !present:
		return false, false
	case time.Now().After(entry.expires):
		delete(v.cache, ip)
		return false, false
	}
	return entry.verified, true
}
// cachePut stores the verdict for ip, using the positive TTL for verified
// results and the negative TTL otherwise. When ip is a new key and the cache
// has reached maxSize (with maxSize > 0), room is made first via
// evictForInsertLocked. Updating an existing key never triggers eviction.
func (v *CrawlerVerifier) cachePut(ip string, verified bool) {
	lifetime := v.posTTL
	if !verified {
		lifetime = v.negTTL
	}

	v.mu.Lock()
	defer v.mu.Unlock()

	_, known := v.cache[ip]
	atCapacity := v.maxSize > 0 && len(v.cache) >= v.maxSize
	if !known && atCapacity {
		v.evictForInsertLocked()
	}
	v.cache[ip] = cacheEntry{verified: verified, expires: time.Now().Add(lifetime)}
}
// evictForInsertLocked frees a slot when the cache is at capacity. It first
// drops every expired entry; if that reclaimed nothing (all entries still
// live), it evicts the single soonest-to-expire entry. Caller holds v.mu.
//
// Victim selection tracks "have we picked one yet" with an explicit flag
// rather than treating the empty string as a sentinel: "" is a legal map key
// (Verified caches whatever ip string it is given), and using it as the
// sentinel would both reset the running minimum and make that entry
// impossible to evict.
func (v *CrawlerVerifier) evictForInsertLocked() {
	now := time.Now()

	before := len(v.cache)
	for ip, e := range v.cache {
		if now.After(e.expires) {
			delete(v.cache, ip)
		}
	}
	if len(v.cache) < before {
		return
	}

	var (
		victim  string
		soonest time.Time
		found   bool
	)
	for ip, e := range v.cache {
		if !found || e.expires.Before(soonest) {
			victim, soonest, found = ip, e.expires, true
		}
	}
	if found {
		delete(v.cache, victim)
	}
}
// cleanExpired removes every cache entry whose expiry lies before now, so a
// scan from many IPs does not leave stale entries behind until each IP is
// queried again. It is a no-op on a nil verifier. The comment on the original
// notes it is called from Server.CleanExpired on the daemon's 60-second
// ticker; now is a parameter so the caller can share one timestamp across
// several cleanup paths.
func (v *CrawlerVerifier) cleanExpired(now time.Time) {
	if v == nil {
		return
	}

	v.mu.Lock()
	defer v.mu.Unlock()

	for ip, entry := range v.cache {
		if !now.After(entry.expires) {
			continue
		}
		delete(v.cache, ip)
	}
}
package challenge
import (
"bytes"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
// DefaultMapPath is the webserver-readable Apache / LSWS RewriteMap.
// It lives under /run rather than state_path because state_path is mode
// 0700 and must stay private to CSM's bbolt database.
const DefaultMapPath = "/run/csm/challenge_ips.txt"
// DefaultNginxMapPath is the webserver-readable Nginx map include.
// SetNginxMap falls back to it when given a blank path.
const DefaultNginxMapPath = "/run/csm/challenge_ips.nginx.map"
// challengeEntry stores the challenge metadata for a single IP.
type challengeEntry struct {
ExpiresAt time.Time // instant after which ExpiredEntries removes the entry
Reason string // why the IP was challenged; carried into ExpiredEntry
NonEscalating bool // when true, expiry removes the entry without reporting it for escalation
}
// ExpiredEntry is returned by ExpiredEntries for escalation. It describes
// one escalating challenge that timed out.
type ExpiredEntry struct {
IP string // the challenged IP, as originally passed to Add
Reason string // the reason recorded when the IP was added
}
// IPList manages the set of IPs that should see challenge pages.
// Webserver integrations read its maps to redirect IPs to the challenge server.
type IPList struct {
path string // Apache / LSWS RewriteMap file; set once at construction
nginxPath string // Nginx map include; empty until SetNginxMap is called
nginxReload func() error // optional callback run after the Nginx include changes
ips map[string]challengeEntry // currently challenged IPs
mu sync.Mutex // guards ips, gate, nginxPath and nginxReload
gate PortGate // optional firewall gate kept in step with Add/Remove; may be nil
}
// NewIPList creates an IP list whose webserver map is written to
// "challenge_ips.txt" inside statePath.
func NewIPList(statePath string) *IPList {
	defaultMap := filepath.Join(statePath, "challenge_ips.txt")
	return NewIPListWithMapPath(statePath, defaultMap)
}
// NewIPListWithMapPath creates an IP list that writes its webserver-facing
// map to mapPath. A blank mapPath falls back to "challenge_ips.txt" inside
// statePath. The (empty) map is written immediately on a best-effort basis;
// a write failure does not prevent construction.
func NewIPListWithMapPath(statePath, mapPath string) *IPList {
	target := mapPath
	if strings.TrimSpace(target) == "" {
		target = filepath.Join(statePath, "challenge_ips.txt")
	}

	list := &IPList{
		ips:  make(map[string]challengeEntry),
		path: target,
	}
	// Result intentionally ignored: no Nginx map is configured yet, and a
	// failed initial write is retried on the next mutation.
	_ = list.flush()
	return list
}
// SetPortGate attaches a PortGate so every later Add/Remove also opens or
// closes the kernel-level allow for that IP. Passing nil detaches the gate;
// Add and Remove skip the gate call when it is nil. Safe to call before any
// Add/Remove; not safe to swap a non-nil gate for another at runtime.
func (l *IPList) SetPortGate(g PortGate) {
	l.mu.Lock()
	l.gate = g
	l.mu.Unlock()
}
// SetNginxMap enables the second, Nginx-format map writer. A blank path
// falls back to DefaultNginxMapPath. The maps are rewritten immediately, and
// reload is invoked (outside the lock) only if the rendered Nginx include
// actually changed on disk. reload is also remembered for later mutations
// and may be nil.
func (l *IPList) SetNginxMap(path string, reload func() error) {
	target := path
	if strings.TrimSpace(target) == "" {
		target = DefaultNginxMapPath
	}

	changed := func() bool {
		l.mu.Lock()
		defer l.mu.Unlock()
		l.nginxPath = target
		l.nginxReload = reload
		return l.flush()
	}()

	l.reloadNginx(changed, reload)
}
// Add puts ip on the challenge list for duration, recording reason. The
// entry is escalating: once it expires, ExpiredEntries reports it.
func (l *IPList) Add(ip string, reason string, duration time.Duration) {
	const nonEscalating = false
	l.add(ip, reason, duration, nonEscalating)
}
// AddNonEscalating puts ip on the challenge list for duration like Add, but
// the entry is dropped silently on expiry instead of being reported by
// ExpiredEntries for timeout-to-block escalation.
func (l *IPList) AddNonEscalating(ip string, reason string, duration time.Duration) {
	const nonEscalating = true
	l.add(ip, reason, duration, nonEscalating)
}
// add is the shared implementation behind Add and AddNonEscalating. Under
// the lock it stores (or overwrites) the entry for ip and rewrites the map
// files; after releasing the lock it opens the port gate for ip, if one is
// attached, and reloads Nginx when its include changed. A port-gate failure
// is logged to stderr and does not undo the list update.
func (l *IPList) add(ip string, reason string, duration time.Duration, nonEscalating bool) {
	entry := challengeEntry{Reason: reason, NonEscalating: nonEscalating}

	l.mu.Lock()
	entry.ExpiresAt = time.Now().Add(duration)
	l.ips[ip] = entry
	nginxChanged := l.flush()
	gate, reload := l.gate, l.nginxReload
	l.mu.Unlock()

	// Firewall and reload work happens outside the lock.
	if gate != nil {
		err := gate.Allow(ip, duration)
		if err != nil {
			fmt.Fprintf(os.Stderr, "challenge: port-gate allow %s: %v\n", ip, err)
		}
	}
	l.reloadNginx(nginxChanged, reload)
}
// Remove takes ip off the challenge list (the visitor passed, or an operator
// removed it). Under the lock it deletes the entry and rewrites the map
// files; afterwards it revokes the port-gate allow, if a gate is attached,
// and reloads Nginx when its include changed. A port-gate failure is logged
// to stderr only.
func (l *IPList) Remove(ip string) {
	l.mu.Lock()
	delete(l.ips, ip)
	nginxChanged := l.flush()
	gate, reload := l.gate, l.nginxReload
	l.mu.Unlock()

	if gate != nil {
		err := gate.Revoke(ip)
		if err != nil {
			fmt.Fprintf(os.Stderr, "challenge: port-gate revoke %s: %v\n", ip, err)
		}
	}
	l.reloadNginx(nginxChanged, reload)
}
// Contains reports whether ip currently has an entry on the challenge list.
// Expiry is not evaluated here: an entry counts until it is removed.
func (l *IPList) Contains(ip string) bool {
	l.mu.Lock()
	_, present := l.ips[ip]
	l.mu.Unlock()
	return present
}
// Count returns how many IPs currently have an entry on the challenge list.
func (l *IPList) Count() int {
	l.mu.Lock()
	total := len(l.ips)
	l.mu.Unlock()
	return total
}
// ExpiredEntries removes every entry whose expiry has passed and returns the
// escalating ones; non-escalating entries are removed without being
// reported. The caller is expected to hard-block the returned IPs.
//
// The map files are rewritten only when at least one entry was removed, and
// the Nginx reload callback runs (outside the lock) only when its include
// changed.
func (l *IPList) ExpiredEntries() []ExpiredEntry {
	l.mu.Lock()

	cutoff := time.Now()
	sizeBefore := len(l.ips)

	var escalate []ExpiredEntry
	for ip, entry := range l.ips {
		if !cutoff.After(entry.ExpiresAt) {
			continue
		}
		delete(l.ips, ip)
		if entry.NonEscalating {
			continue
		}
		escalate = append(escalate, ExpiredEntry{IP: ip, Reason: entry.Reason})
	}

	nginxChanged := false
	if len(l.ips) != sizeBefore {
		nginxChanged = l.flush()
	}
	reload := l.nginxReload
	l.mu.Unlock()

	l.reloadNginx(nginxChanged, reload)
	return escalate
}
// CleanExpired drops expired entries and discards the escalation list.
// Use ExpiredEntries() instead when escalation is needed.
func (l *IPList) CleanExpired() {
	l.ExpiredEntries()
}
// flush renders the current IP set, sorted, into each configured webserver
// format and writes it to disk. The caller must hold l.mu.
//
// The Apache / LSWS RewriteMap is always rewritten. The Nginx include is
// written only when a Nginx path is configured and only if its content
// differs from what is on disk. The return value is true exactly when the
// Nginx include was rewritten successfully, i.e. when Nginx needs a reload;
// any write failure yields false.
func (l *IPList) flush() bool {
	ips := sortedIPKeys(l.ips)

	// render builds one map file: shared banner, format line, then one
	// "<ip><suffix>" line per IP.
	render := func(formatLine, suffix string) []byte {
		var sb strings.Builder
		sb.WriteString("# CSM Challenge IP list - auto-generated, do not edit\n")
		sb.WriteString(formatLine)
		for _, ip := range ips {
			sb.WriteString(ip)
			sb.WriteString(suffix)
		}
		return []byte(sb.String())
	}

	apache := render("# Format: IP challenge (for Apache RewriteMap)\n", " challenge\n")
	if writeMapFile(l.path, apache) != nil {
		return false
	}

	if strings.TrimSpace(l.nginxPath) == "" {
		return false
	}

	nginx := render("# Format: IP 1; (for Nginx map include)\n", " 1;\n")
	changed, err := writeMapFileIfChanged(l.nginxPath, nginx)
	return changed && err == nil
}
// sortedIPKeys returns the keys of ips in ascending string order, giving the
// rendered map files a deterministic layout.
func sortedIPKeys(ips map[string]challengeEntry) []string {
	keys := make([]string, len(ips))
	next := 0
	for ip := range ips {
		keys[next] = ip
		next++
	}
	sort.Strings(keys)
	return keys
}
// writeMapFileIfChanged writes data to path unless the file already holds
// exactly those bytes. The boolean is false only when the write was skipped
// because the content was identical; whenever a write is attempted it is
// true, with the write's error (if any) alongside. A file that cannot be
// read is treated as changed.
func writeMapFileIfChanged(path string, data []byte) (bool, error) {
	// #nosec G304 -- path is the daemon-owned challenge map under /run/csm
	// (DefaultNginxMapPath or operator-set via SetNginxMap), never
	// attacker-controlled. Read only to diff the rendered map and skip
	// rewrite when content is unchanged.
	current, readErr := os.ReadFile(path)
	unchanged := readErr == nil && bytes.Equal(current, data)
	if unchanged {
		return false, nil
	}
	if err := writeMapFile(path, data); err != nil {
		return true, err
	}
	return true, nil
}
// writeMapFile atomically replaces path with data: it writes a sibling
// "<path>.tmp" file and renames it over path. The parent directory is created
// if missing.
//
// Both the directory (0o755) and the file (0o644) are explicitly chmod-ed
// after creation, because the modes passed to MkdirAll and WriteFile are
// filtered by the process umask, and WriteFile leaves the mode of an
// already-existing temp file untouched. On failure after the temp file was
// created, the temp file is removed on a best-effort basis.
func writeMapFile(path string, data []byte) error {
	// /run/csm must be world-readable so the webserver user
	// (www-data / nobody / lsws) can stat + read the map underneath.
	// The directory holds no sensitive data; only CSM-owned IP files
	// live inside.
	//
	// MkdirAll respects the process umask, so on cPanel/CloudLinux
	// hosts where csm.service inherits umask 027 the directory ends
	// up at 0o750 and the webserver gets EACCES on the RewriteMap.
	// Explicit Chmod after creation forces the mode the integration
	// requires regardless of umask.
	mapDir := filepath.Dir(path)
	// #nosec G301 -- world-readable rationale above.
	if err := os.MkdirAll(mapDir, 0o755); err != nil {
		return err
	}
	// #nosec G302 -- same world-readable rationale; needed for the
	// webserver user to stat into the directory and read the map.
	if err := os.Chmod(mapDir, 0o755); err != nil {
		return err
	}

	tmpPath := path + ".tmp"
	// #nosec G306 -- webservers read this map file directly. It has to
	// be readable by the webserver user. No sensitive data; only a list
	// of IPs that must re-solve the PoW challenge.
	if err := os.WriteFile(tmpPath, data, 0o644); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup of a partial temp file
		return err
	}
	// The same umask problem that affects the directory affects the file:
	// under umask 027 WriteFile creates it as 0o640, and a stale temp file
	// keeps whatever mode it had. Force the mode before publishing it.
	// #nosec G302 -- world-readable rationale above.
	if err := os.Chmod(tmpPath, 0o644); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	return nil
}
// reloadNginx invokes reload when the Nginx include changed and a callback
// is configured. A reload failure is logged to stderr and otherwise ignored.
func (l *IPList) reloadNginx(changed bool, reload func() error) {
	if changed && reload != nil {
		if err := reload(); err != nil {
			fmt.Fprintf(os.Stderr, "challenge: nginx reload after map update: %v\n", err)
		}
	}
}
package challenge
import (
"net"
"strings"
"time"
)
// PortGate locks the challenge listener TCP port to specific source IPs
// via the host firewall. An IP is allowed only while it is on the
// challenge IPList (plus operator infra IPs and loopback). Everything
// else gets dropped at the kernel before the listener sees the SYN, so
// the listener is invisible to port scanners and stays reachable only
// for the visitors the daemon has actually redirected.
//
// Implementations are pluggable so the netlink-backed Linux variant
// can be swapped for a stub on platforms that do not have nftables.
//
// NOTE: a nil PortGate means "gate not active". Calling a method on a nil
// interface value panics, so the nil check lives in IPList: its add and
// Remove paths test the gate for nil before calling Allow/Revoke. Code that
// goes through IPList does not need its own nil check; code that holds a
// PortGate directly does.
type PortGate interface {
// Allow opens the gate for the source IP for at most ttl. The
// underlying firewall enforces the TTL via the set's own timeout
// so the entry expires even if Revoke is never called (daemon
// crash, missed expiry).
//
// The Linux implementation in this file returns an error when ip
// cannot be parsed or the firewall update fails, and returns nil
// without touching the firewall when the IP's address family is not
// gated. IPList treats any returned error as log-only.
Allow(ip string, ttl time.Duration) error
// Revoke closes the gate for ip immediately. Safe to call for IPs
// that were never on the gate (no-op).
Revoke(ip string) error
// Close tears down the gate's nftables footprint (chain, sets,
// table). The port reverts to whatever the rest of the host
// firewall would do with it.
Close() error
}
// PortGateConfig wraps the inputs the gate needs to install rules.
type PortGateConfig struct {
// ListenAddr is the challenge listener's bind address. It decides
// whether a gate is needed at all (loopback needs none) and which
// address families are gated.
ListenAddr string
// ListenPort is the TCP port to gate. The Linux gate rejects values
// outside 1-65535.
ListenPort int
// InfraCIDRs lists operator infrastructure sources that are always
// allowed. Entries may be CIDRs or bare IPs; unparsable entries are
// skipped by the Linux gate.
InfraCIDRs []string
}
// NewPortGate returns the platform-appropriate gate for cfg by delegating to
// the build-specific newPortGate. On Linux that installs a dedicated
// `csm_chal` inet table.
//
// When the listen address is loopback the result is (nil, nil): no gate is
// needed because loopback traffic cannot originate from off-host. Callers
// treat a nil gate as "gate not active" and proceed without it.
func NewPortGate(cfg PortGateConfig) (PortGate, error) {
	if !isLoopbackListenAddr(cfg.ListenAddr) {
		return newPortGate(cfg)
	}
	return nil, nil
}
// portGateFamily records which address families the gate should bind to.
// familyForListenAddr derives it from the listen address: 0.0.0.0 / blank
// -> v4 only; :: -> dual stack; a specific literal IP gates only that
// family.
type portGateFamily struct {
v4 bool // gate IPv4 sources
v6 bool // gate IPv6 sources
}
// familyForListenAddr maps a listen address to the address families the gate
// must cover. Surrounding whitespace and square brackets are stripped first.
//
//   - "" or "0.0.0.0"       -> IPv4 only
//   - "::"                  -> IPv4 and IPv6
//   - an IPv4 literal       -> IPv4 only
//   - an IPv6 literal       -> IPv6 only
//   - anything unparsable   -> IPv4 only
func familyForListenAddr(addr string) portGateFamily {
	host := strings.TrimSpace(addr)
	host = strings.Trim(host, "[]")

	v4Only := portGateFamily{v4: true}

	if host == "" || host == "0.0.0.0" {
		return v4Only
	}
	if host == "::" {
		return portGateFamily{v4: true, v6: true}
	}

	parsed := net.ParseIP(host)
	if parsed == nil || parsed.To4() != nil {
		return v4Only
	}
	return portGateFamily{v6: true}
}
// portGateFamilyAcceptsIP reports whether ip belongs to an address family
// that fam gates. A nil ip is never accepted. Anything with a 4-byte form
// (including IPv4-mapped IPv6) counts as IPv4.
func portGateFamilyAcceptsIP(fam portGateFamily, ip net.IP) bool {
	switch {
	case ip == nil:
		return false
	case ip.To4() != nil:
		return fam.v4
	default:
		return fam.v6
	}
}
//go:build linux
package challenge
import (
"fmt"
"net"
"sync"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"golang.org/x/sys/unix"
)
// linuxPortGate owns the nftables `csm_chal` table. The table is kept
// separate from the firewall package's `csm` table so the two can be
// installed and torn down independently (challenge can run with or
// without csm.firewall enabled).
type linuxPortGate struct {
mu sync.Mutex // serializes use of conn and guards the table/set pointers
conn *nftables.Conn // netlink connection used for every nftables operation
cfg PortGateConfig // configuration the gate was built from
family portGateFamily // address families being gated; fixed at construction
table *nftables.Table // the csm_chal table; nil after Close
setChalIPs *nftables.Set // IPv4 challenged IPs (with timeout); nil when v4 is not gated
setChalIPs6 *nftables.Set // IPv6 challenged IPs (with timeout); nil when v6 is not gated
setInfra *nftables.Set // IPv4 infra interval set; nil when v4 is not gated
setInfra6 *nftables.Set // IPv6 infra interval set; nil when v6 is not gated
}
// newPortGate is the Linux implementation behind NewPortGate. It validates
// the listen port (1-65535), opens an nftables connection, and installs the
// gate's table, sets and rules. Any failure returns a nil gate and the error.
func newPortGate(cfg PortGateConfig) (PortGate, error) {
	if port := cfg.ListenPort; port <= 0 || port > 65535 {
		return nil, fmt.Errorf("port-gate: invalid listen port %d", port)
	}

	families := familyForListenAddr(cfg.ListenAddr)

	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("port-gate: nftables open: %w", err)
	}

	gate := &linuxPortGate{
		conn:   conn,
		cfg:    cfg,
		family: families,
	}
	installErr := gate.install()
	if installErr != nil {
		return nil, installErr
	}
	return gate, nil
}
// install lays down the table, sets, chain, and per-port rules. Any
// pre-existing `csm_chal` table is deleted first so a daemon restart
// always converges on a clean rule shape (and stale rules from a
// crashed previous run do not linger).
//
// For each gated family it creates two sets: a timeout-enabled set of
// challenged IPs (initially empty) and an interval set pre-filled from
// cfg.InfraCIDRs. It then adds one input-hook filter chain holding the
// accept rules followed by the drop rule, and sends everything to the kernel
// with a single Flush. It returns an error if adding a set or the final
// Flush fails.
func (g *linuxPortGate) install() error {
g.mu.Lock()
defer g.mu.Unlock()
// Best-effort removal of a leftover table; failures there are ignored.
g.dropExistingTableLocked()
g.table = g.conn.AddTable(&nftables.Table{
Family: nftables.TableFamilyINet,
Name: "csm_chal",
})
if g.family.v4 {
// Challenged IPv4 sources; elements carry their own timeout.
g.setChalIPs = &nftables.Set{
Table: g.table,
Name: "chal_ips",
KeyType: nftables.TypeIPAddr,
HasTimeout: true,
}
if err := g.conn.AddSet(g.setChalIPs, nil); err != nil {
return fmt.Errorf("port-gate: add chal_ips: %w", err)
}
// Operator infra IPv4 ranges, stored as intervals.
g.setInfra = &nftables.Set{
Table: g.table,
Name: "chal_infra",
KeyType: nftables.TypeIPAddr,
Interval: true,
}
if err := g.conn.AddSet(g.setInfra, infraElementsV4(g.cfg.InfraCIDRs)); err != nil {
return fmt.Errorf("port-gate: add chal_infra: %w", err)
}
}
if g.family.v6 {
// Challenged IPv6 sources; elements carry their own timeout.
g.setChalIPs6 = &nftables.Set{
Table: g.table,
Name: "chal_ips6",
KeyType: nftables.TypeIP6Addr,
HasTimeout: true,
}
if err := g.conn.AddSet(g.setChalIPs6, nil); err != nil {
return fmt.Errorf("port-gate: add chal_ips6: %w", err)
}
// Operator infra IPv6 ranges, stored as intervals.
g.setInfra6 = &nftables.Set{
Table: g.table,
Name: "chal_infra6",
KeyType: nftables.TypeIP6Addr,
Interval: true,
}
if err := g.conn.AddSet(g.setInfra6, infraElementsV6(g.cfg.InfraCIDRs)); err != nil {
return fmt.Errorf("port-gate: add chal_infra6: %w", err)
}
}
// The chain's default policy is accept: only traffic to the challenge
// port is affected, via the explicit drop rule added last.
prio := nftables.ChainPriority(-200)
policy := nftables.ChainPolicyAccept
chain := g.conn.AddChain(&nftables.Chain{
Name: "challenge_gate",
Table: g.table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: &prio,
Policy: &policy,
})
// Order matters: accepts first, then the catch-all drop for the port.
g.addAcceptRules(chain)
g.addDropRule(chain)
if err := g.conn.Flush(); err != nil {
return fmt.Errorf("port-gate: install flush: %w", err)
}
return nil
}
// dropExistingTableLocked deletes a pre-existing inet `csm_chal` table, if
// one is listed, and flushes that deletion immediately so a re-install does
// not stack rules on top of a stale chain. It is best-effort: a failure to
// list tables or to flush is ignored. When no such table exists it does
// nothing. Caller holds g.mu.
func (g *linuxPortGate) dropExistingTableLocked() {
	tables, err := g.conn.ListTables()
	if err != nil {
		return
	}
	for i := range tables {
		candidate := tables[i]
		if candidate.Family != nftables.TableFamilyINet || candidate.Name != "csm_chal" {
			continue
		}
		g.conn.DelTable(candidate)
		_ = g.conn.Flush() // best-effort; install's own Flush reports real problems
		return
	}
}
// addAcceptRules queues the accept rules for the challenge port on chain.
// For each gated family, in this order: loopback sources (127.0.0.0/8 for
// IPv4, ::1/128 for IPv6), sources in the infra set, and sources in the
// challenged-IP set.
func (g *linuxPortGate) addAcceptRules(chain *nftables.Chain) {
	dport := portU16(g.cfg.ListenPort)
	accept := expr.VerdictAccept

	if g.family.v4 {
		loopNet := net.IPv4(127, 0, 0, 0).To4()
		loopMask := net.IPv4Mask(255, 0, 0, 0)
		v4Rules := [][]expr.Any{
			exprsTCPDportFromV4(dport, loopNet, loopMask, accept), // loopback bypass
			exprsTCPDportSetMatchV4(dport, g.setInfra, accept),
			exprsTCPDportSetMatchV4(dport, g.setChalIPs, accept),
		}
		for _, exprs := range v4Rules {
			g.addRule(chain, exprs)
		}
	}

	if g.family.v6 {
		loopNet := net.ParseIP("::1").To16()
		loopMask := net.CIDRMask(128, 128)
		v6Rules := [][]expr.Any{
			exprsTCPDportFromV6(dport, loopNet, loopMask, accept), // loopback bypass
			exprsTCPDportSetMatchV6(dport, g.setInfra6, accept),
			exprsTCPDportSetMatchV6(dport, g.setChalIPs6, accept),
		}
		for _, exprs := range v6Rules {
			g.addRule(chain, exprs)
		}
	}
}
// addDropRule queues the final rule on chain: drop any TCP packet whose
// destination port is the challenge port. It must follow the accept rules,
// since a packet reaching it matched none of them.
func (g *linuxPortGate) addDropRule(chain *nftables.Chain) {
	dport := binaryutil.BigEndian.PutUint16(portU16(g.cfg.ListenPort))

	match := []expr.Any{
		// L4 protocol == TCP
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
		// destination port == challenge port
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: dport},
		&expr.Verdict{Kind: expr.VerdictDrop},
	}
	g.addRule(chain, match)
}
// addRule queues one rule made of exprs on chain in the gate's table. The
// rule reaches the kernel on the next Flush.
func (g *linuxPortGate) addRule(chain *nftables.Chain, exprs []expr.Any) {
	rule := &nftables.Rule{
		Table: g.table,
		Chain: chain,
		Exprs: exprs,
	}
	g.conn.AddRule(rule)
}
// Allow opens the gate for ip for at most ttl by adding it, with that
// timeout, to the challenged-IP set of its address family. A ttl of zero or
// less falls back to 30 minutes.
//
// It returns an error when ip cannot be parsed, when queueing the element
// fails, or when the netlink Flush fails (wrapped with context, matching
// install and Close). It returns nil without touching the firewall when the
// IP's family is not gated (e.g. a v6 IP on a v4-only listener) or when the
// matching set does not exist (for instance after Close), so the IPList Add
// path does not surface an unactionable error.
//
// The set is chosen strictly by the IP's own family: an IPv4 address is
// never written to the IPv6 set.
func (g *linuxPortGate) Allow(ip string, ttl time.Duration) error {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return fmt.Errorf("port-gate: invalid ip %q", ip)
	}
	if ttl <= 0 {
		ttl = 30 * time.Minute
	}
	if !portGateFamilyAcceptsIP(g.family, parsed) {
		return nil
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	set, key, label := g.setChalIPs6, []byte(parsed.To16()), "v6"
	if ip4 := parsed.To4(); ip4 != nil {
		set, key, label = g.setChalIPs, []byte(ip4), "v4"
	}
	if set == nil {
		return nil
	}

	elems := []nftables.SetElement{{Key: key, Timeout: ttl}}
	if err := g.conn.SetAddElements(set, elems); err != nil {
		return fmt.Errorf("port-gate: add %s %s: %w", label, ip, err)
	}
	if err := g.conn.Flush(); err != nil {
		return fmt.Errorf("port-gate: add %s %s flush: %w", label, ip, err)
	}
	return nil
}
// Revoke closes the gate for ip immediately by deleting it from the
// challenged-IP set of its address family.
//
// It returns an error when ip cannot be parsed, when queueing the deletion
// fails, or when the netlink Flush fails (wrapped with context, matching
// install and Close). It returns nil without touching the firewall when the
// IP's family is not gated or when the matching set does not exist (for
// instance after Close).
//
// The set is chosen strictly by the IP's own family: an IPv4 address is
// never looked up in the IPv6 set.
func (g *linuxPortGate) Revoke(ip string) error {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return fmt.Errorf("port-gate: invalid ip %q", ip)
	}
	if !portGateFamilyAcceptsIP(g.family, parsed) {
		return nil
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	set, key, label := g.setChalIPs6, []byte(parsed.To16()), "v6"
	if ip4 := parsed.To4(); ip4 != nil {
		set, key, label = g.setChalIPs, []byte(ip4), "v4"
	}
	if set == nil {
		return nil
	}

	elems := []nftables.SetElement{{Key: key}}
	if err := g.conn.SetDeleteElements(set, elems); err != nil {
		return fmt.Errorf("port-gate: del %s %s: %w", label, ip, err)
	}
	if err := g.conn.Flush(); err != nil {
		return fmt.Errorf("port-gate: del %s %s flush: %w", label, ip, err)
	}
	return nil
}
// Close deletes the gate's nftables table and, once the kernel has accepted
// that, forgets the table and all set handles. Calling Close on an
// already-closed gate is a no-op. If the Flush fails the handles are kept so
// a later Close can retry.
func (g *linuxPortGate) Close() error {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.table == nil {
		return nil
	}

	g.conn.DelTable(g.table)
	flushErr := g.conn.Flush()
	if flushErr != nil {
		return fmt.Errorf("port-gate: close flush: %w", flushErr)
	}

	g.table = nil
	g.setChalIPs, g.setChalIPs6 = nil, nil
	g.setInfra, g.setInfra6 = nil, nil
	return nil
}
// portU16 converts a port number to uint16. Values outside 0-65535 map to 0.
func portU16(p int) uint16 {
	if p >= 0 && p <= 65535 {
		// #nosec G115 -- bounds-checked above.
		return uint16(p)
	}
	return 0
}
// infraElementsV4 turns the operator's infra list into nftables interval-set
// elements, keeping only IPv4 entries. Entries may be CIDRs or bare IPs;
// unparsable and non-IPv4 entries are skipped. Each kept entry contributes a
// start element (the network address) and an IntervalEnd element one past
// the network's last address, i.e. a half-open [start, end) range.
func infraElementsV4(cidrs []string) []nftables.SetElement {
	var elems []nftables.SetElement
	for _, entry := range cidrs {
		network := parseCIDROrIP(entry)
		if network == nil || network.IP.To4() == nil {
			continue
		}
		first := network.IP.To4()
		upper := ipv4Inc(lastIPv4(network))
		elems = append(elems, nftables.SetElement{Key: first})
		elems = append(elems, nftables.SetElement{Key: upper, IntervalEnd: true})
	}
	return elems
}
// infraElementsV6 is the IPv6 counterpart of infraElementsV4: it turns the
// infra list into interval-set elements, skipping unparsable entries and
// anything that has an IPv4 form. Each kept entry contributes a start
// element and an IntervalEnd element one past the network's last address.
func infraElementsV6(cidrs []string) []nftables.SetElement {
	var elems []nftables.SetElement
	for _, entry := range cidrs {
		network := parseCIDROrIP(entry)
		if network == nil || network.IP.To4() != nil {
			continue
		}
		first := network.IP.To16()
		upper := ipv6Inc(lastIPv6(network))
		elems = append(elems, nftables.SetElement{Key: first})
		elems = append(elems, nftables.SetElement{Key: upper, IntervalEnd: true})
	}
	return elems
}
// parseCIDROrIP parses raw as either CIDR notation ("1.2.3.0/24") or a bare
// IP ("1.2.3.4"). A CIDR yields its network as returned by net.ParseCIDR; a
// bare IP is treated as a single-address network (/32 for IPv4, /128 for
// IPv6). Anything else yields nil.
func parseCIDROrIP(raw string) *net.IPNet {
	_, network, cidrErr := net.ParseCIDR(raw)
	if cidrErr == nil {
		return network
	}

	addr := net.ParseIP(raw)
	switch {
	case addr == nil:
		return nil
	case addr.To4() != nil:
		return &net.IPNet{IP: addr.To4(), Mask: net.CIDRMask(32, 32)}
	default:
		return &net.IPNet{IP: addr.To16(), Mask: net.CIDRMask(128, 128)}
	}
}
// lastIPv4 returns the highest address in the IPv4 network n (the network
// address with every host bit set) as a 4-byte IP.
//
// The mask may be 4 or 16 bytes long. net.ParseCIDR yields a 16-byte mask for
// an IPv4-mapped IPv6 CIDR such as "::ffff:10.1.2.0/120", even though the
// address has a 4-byte form; in that case only the last four mask bytes
// apply to the IPv4 address. It returns nil when n.IP has no IPv4 form or
// the mask has any other length.
func lastIPv4(n *net.IPNet) net.IP {
	ip := n.IP.To4()
	mask := n.Mask
	if len(mask) == net.IPv6len {
		mask = mask[net.IPv6len-net.IPv4len:]
	}
	if ip == nil || len(mask) != net.IPv4len {
		return nil
	}

	out := make(net.IP, net.IPv4len)
	for i := range out {
		out[i] = ip[i] | ^mask[i]
	}
	return out
}
// lastIPv6 returns the highest address in the network n (the network address
// with every host bit set) as a 16-byte IP. It expects n.Mask to be 16 bytes
// long.
func lastIPv6(n *net.IPNet) net.IP {
	base := n.IP.To16()
	last := make(net.IP, 16)
	for i := range last {
		hostBits := ^n.Mask[i]
		last[i] = base[i] | hostBits
	}
	return last
}
// ipv4Inc returns a fresh 4-byte address equal to ip + 1, treating the
// address as a big-endian integer. The input is not modified.
// 255.255.255.255 wraps around to 0.0.0.0.
func ipv4Inc(ip net.IP) net.IP {
	next := make(net.IP, net.IPv4len)
	copy(next, ip.To4())
	// Add one to the lowest byte and carry leftwards only while the byte
	// just incremented wrapped to zero.
	for idx := len(next) - 1; idx >= 0; idx-- {
		next[idx]++
		if next[idx] != 0 {
			break
		}
	}
	return next
}
// ipv6Inc returns a fresh 16-byte address equal to ip + 1, treating the
// address as a big-endian integer. The input is not modified. The
// all-ones address wraps around to all zeros.
func ipv6Inc(ip net.IP) net.IP {
	next := make(net.IP, net.IPv6len)
	copy(next, ip.To16())
	// Add one to the lowest byte and carry leftwards only while the byte
	// just incremented wrapped to zero.
	for idx := len(next) - 1; idx >= 0; idx-- {
		next[idx]++
		if next[idx] != 0 {
			break
		}
	}
	return next
}
// exprsTCPDportFromV4 assembles the expression list for a rule of the
// shape "L4=TCP over IPv4, source address in network/mask, dport=port ->
// verdict". mask is the 4-byte IPv4 subnet mask; network is compared
// verbatim against the masked source address, so callers presumably pass
// the network base address.
func exprsTCPDportFromV4(port uint16, network net.IP, mask net.IPMask, verdict expr.VerdictKind) []expr.Any {
	rule := make([]expr.Any, 0, 10)

	// Transport protocol must be TCP.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
	)
	// Network family must be IPv4.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.NFPROTO_IPV4}},
	)
	// Load 4 bytes at network-header offset 12 (the source address),
	// apply the subnet mask, and compare with the network.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
		&expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: mask, Xor: []byte{0, 0, 0, 0}},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: network},
	)
	// Load 2 bytes at transport-header offset 2 (the destination port)
	// and compare with port in network byte order.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
	)

	return append(rule, &expr.Verdict{Kind: verdict})
}
// exprsTCPDportFromV6 assembles the expression list for a rule of the
// shape "L4=TCP over IPv6, source address in network/mask, dport=port ->
// verdict". The bitwise step works on 16 bytes, so mask is expected to
// be the 16-byte IPv6 prefix mask; network is compared verbatim against
// the masked source address.
func exprsTCPDportFromV6(port uint16, network net.IP, mask net.IPMask, verdict expr.VerdictKind) []expr.Any {
	rule := make([]expr.Any, 0, 10)

	// Transport protocol must be TCP.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
	)
	// Network family must be IPv6.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.NFPROTO_IPV6}},
	)
	// Load 16 bytes at network-header offset 8 (the source address),
	// apply the prefix mask, and compare with the network.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
		&expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 16, Mask: mask, Xor: make([]byte, 16)},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: network},
	)
	// Load 2 bytes at transport-header offset 2 (the destination port)
	// and compare with port in network byte order.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
	)

	return append(rule, &expr.Verdict{Kind: verdict})
}
// exprsTCPDportSetMatchV4 assembles the expression list for a rule of
// the shape "L4=TCP over IPv4, source address found in set, dport=port
// -> verdict". The set is referenced by both its name and its ID.
func exprsTCPDportSetMatchV4(port uint16, set *nftables.Set, verdict expr.VerdictKind) []expr.Any {
	rule := make([]expr.Any, 0, 9)

	// Transport protocol must be TCP.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
	)
	// Network family must be IPv4.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.NFPROTO_IPV4}},
	)
	// Load 4 bytes at network-header offset 12 (the source address) and
	// look them up in the set.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
		&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
	)
	// Load 2 bytes at transport-header offset 2 (the destination port)
	// and compare with port in network byte order.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
	)

	return append(rule, &expr.Verdict{Kind: verdict})
}
// exprsTCPDportSetMatchV6 assembles the expression list for a rule of
// the shape "L4=TCP over IPv6, source address found in set, dport=port
// -> verdict". The set is referenced by both its name and its ID.
func exprsTCPDportSetMatchV6(port uint16, set *nftables.Set, verdict expr.VerdictKind) []expr.Any {
	rule := make([]expr.Any, 0, 9)

	// Transport protocol must be TCP.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}},
	)
	// Network family must be IPv6.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.NFPROTO_IPV6}},
	)
	// Load 16 bytes at network-header offset 8 (the source address) and
	// look them up in the set.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
		&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
	)
	// Load 2 bytes at transport-header offset 2 (the destination port)
	// and compare with port in network byte order.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
	)

	return append(rule, &expr.Verdict{Kind: verdict})
}
package challenge
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"html"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
)
// IPUnblocker is the interface for temporarily allowing an IP. The
// challenge server calls it from markVerified once a visitor has passed
// the challenge or qualified for a bypass.
type IPUnblocker interface {
// TempAllowIP allows ip for the given timeout. reason is a short
// human-readable label describing why the allow was granted (for
// example "passed challenge"). A returned error is logged by the
// challenge server and does not abort the verification flow.
TempAllowIP(ip string, reason string, timeout time.Duration) error
}
// Server serves challenge pages to gray-listed IPs.
// When an IP passes the challenge, it gets a temporary allow.
type Server struct {
// cfg is the daemon configuration; handlers read the Challenge and
// WebUI sections from it at request time.
cfg *config.Config
// secret keys the HMAC that binds a page token to (ip, nonce) in
// makeToken. New fills it from cfg.Challenge.Secret, or with 32 random
// bytes when that is empty.
secret []byte
// unblocker grants the temporary firewall allow in markVerified. May
// be nil, in which case that step is skipped.
unblocker IPUnblocker
// ipList is consulted by handleGate and pruned by markVerified. May be
// nil.
ipList *IPList
// srv is the HTTP server built in New and run by Start.
srv *http.Server
// trustedProxies holds the canonical IPs of peers whose
// X-Forwarded-For header extractIP is willing to believe.
trustedProxies map[string]bool
// Track spent challenge nonces to prevent replay. Keys are nonces (not
// IPs); values are the time the nonce was consumed. Guarded by
// verifiedMu.
verified map[string]time.Time
verifiedMu sync.Mutex
// Optional subsystems wired in by configuration. Any of these may
// be nil; the handlers branch accordingly so a fresh deployment
// without the new blocks behaves exactly as before.
captcha CaptchaProvider
sessionSigner *AdminSessionSigner
crawlers *CrawlerVerifier
// verifySigner signs the csm_verified cookie handed to every visitor
// who passes the PoW/CAPTCHA. It binds the cookie to one IP and an
// expiry so a returning visitor skips the gate for the allow window
// without it being replayable elsewhere. Always constructed; the key
// rotates on restart like sessionSigner.
verifySigner *AdminSessionSigner
// adminFailures tracks failed admin-token submissions per source
// IP for rate-limiting brute-force probes. Sliding window: an IP
// that hits adminMaxFailuresInWindow within adminFailureWindow
// gets locked out (subsequent submissions return 429) until the
// oldest failure ages out. Guarded by adminFailuresMu.
adminFailures map[string][]time.Time
adminFailuresMu sync.Mutex
}
const (
// adminFailureWindow is the sliding window over which failed
// admin-token submissions from one IP are counted.
adminFailureWindow = 5 * time.Minute
// adminMaxFailuresInWindow is the number of failures inside
// adminFailureWindow at which an IP starts receiving 429 responses.
adminMaxFailuresInWindow = 5
// verifyCookieTTL is the lifetime of the csm_verified bypass cookie. It
// matches the firewall allow window applied in markVerified.
verifyCookieTTL = 4 * time.Hour
)
// New creates a challenge server.
//
// The HMAC secret comes from cfg.Challenge.Secret; when that is empty a
// random 32-byte secret is generated, which means page tokens do not
// survive a restart. unblocker and ipList may be nil; the handlers skip
// the corresponding side effects.
//
// Optional sub-features (CAPTCHA fallback, verified admin session,
// verified crawlers) are wired in only when configured. A sub-feature
// whose constructor fails is logged to stderr and left disabled rather
// than preventing the server from being built.
//
// The returned server is not listening yet; call Start.
func New(cfg *config.Config, unblocker IPUnblocker, ipList *IPList) *Server {
	secret := []byte(cfg.Challenge.Secret)
	if len(secret) == 0 {
		secret = make([]byte, 32)
		// Best effort: New has no error return to report a failing
		// system random source through.
		_, _ = rand.Read(secret)
	}

	// Store proxies in canonical form so the lookup in extractIP agrees
	// with the canonicalised peer address.
	trusted := make(map[string]bool)
	for _, p := range cfg.Challenge.TrustedProxies {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		trusted[canonicalIP(p)] = true
	}

	s := &Server{
		cfg:            cfg,
		secret:         secret,
		unblocker:      unblocker,
		ipList:         ipList,
		trustedProxies: trusted,
		verified:       make(map[string]time.Time),
		adminFailures:  make(map[string][]time.Time),
	}

	// The verify-cookie signer is always on (independent of the optional
	// admin-session feature). TTL matches the markVerified allow window.
	if vs, err := NewAdminSessionSigner(verifyCookieTTL); err != nil {
		fmt.Fprintf(os.Stderr, "[challenge] verify-cookie signing disabled: %v\n", err)
	} else {
		s.verifySigner = vs
	}

	// Optional sub-features. Each is opt-in via its own config block;
	// initialization errors degrade to "feature off" with a stderr
	// message rather than refusing to start the challenge server. The
	// constructor result is stored only on success so a failed
	// constructor can never leave a half-built (or typed-nil) value
	// behind that the handlers' nil checks would treat as "enabled".
	if name := cfg.Challenge.CaptchaFallback.Provider; name != "" {
		p, err := NewCaptchaProvider(name, cfg.Challenge.CaptchaFallback.SecretKey, cfg.Challenge.CaptchaFallback.Timeout)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[challenge] captcha disabled: %v\n", err)
		} else {
			s.captcha = p
		}
	}
	if cfg.Challenge.VerifiedSession.Enabled {
		signer, err := NewAdminSessionSigner(cfg.Challenge.VerifiedSession.TTL)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[challenge] verified-session disabled: %v\n", err)
		} else {
			s.sessionSigner = signer
		}
	}
	if cfg.Challenge.VerifiedCrawlers.Enabled {
		s.crawlers = NewCrawlerVerifier(
			cfg.Challenge.VerifiedCrawlers.Providers,
			cfg.Challenge.VerifiedCrawlers.CacheTTL,
			nil, // net.DefaultResolver
		)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/challenge", s.handleChallenge)
	mux.HandleFunc("/challenge/gate", s.handleGate)
	mux.HandleFunc("/challenge/verify", s.handleVerify)
	mux.HandleFunc("/challenge/captcha-verify", s.handleCaptchaVerify)
	mux.HandleFunc("/challenge/admin-token", s.handleAdminToken)

	bindAddr := cfg.Challenge.ListenAddr
	if bindAddr == "" {
		// Defaults are applied during config.Load, but tests construct
		// Server directly with an empty config; default to loopback so
		// the production safety guarantee (never bind public by default)
		// holds even on those paths.
		bindAddr = "127.0.0.1"
	}
	s.srv = &http.Server{
		Addr:              net.JoinHostPort(bindAddr, strconv.Itoa(cfg.Challenge.ListenPort)),
		Handler:           mux,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		ReadHeaderTimeout: 5 * time.Second,
	}
	return s
}
// Start begins serving challenge pages and blocks until the underlying
// http.Server returns. Explicit challenge TLS makes the listener HTTPS.
// Direct/public listeners can reuse the WebUI TLS pair. Loopback
// listeners stay plain HTTP by default.
//
// Resolution order:
//  1. challenge.tls_cert + challenge.tls_key (explicit per-service)
//  2. webui.tls_cert + webui.tls_key (direct/public binds only)
//  3. plain HTTP (loopback-only default)
//
// When a non-loopback listener ends up on plain HTTP a warning is
// written to stderr before serving.
func (s *Server) Start() error {
	cert, key := s.resolveTLSMaterial()
	if cert != "" && key != "" {
		return s.srv.ListenAndServeTLS(cert, key)
	}

	if !isLoopbackListenAddr(s.cfg.Challenge.ListenAddr) {
		fmt.Fprintf(os.Stderr,
			"[%s] WARNING: public challenge listener has no complete TLS cert/key configured; HSTS-pinned domains will fail with ERR_SSL_PROTOCOL_ERROR\n",
			time.Now().Format("2006-01-02 15:04:05"))
	}
	return s.srv.ListenAndServe()
}
// resolveTLSMaterial picks the cert / key pair the challenge listener
// should present, or two empty strings for plain HTTP. An explicit
// challenge pair always wins. The WebUI pair is a fallback for
// direct/public binds only: a loopback listener without explicit
// challenge TLS stays on plain HTTP. A pair is used only when both
// halves are set.
func (s *Server) resolveTLSMaterial() (cert, key string) {
	cert, key = s.cfg.Challenge.TLSCert, s.cfg.Challenge.TLSKey
	if cert != "" && key != "" {
		return cert, key
	}

	if isLoopbackListenAddr(s.cfg.Challenge.ListenAddr) {
		return "", ""
	}

	cert, key = s.cfg.WebUI.TLSCert, s.cfg.WebUI.TLSKey
	if cert == "" || key == "" {
		return "", ""
	}
	return cert, key
}
// isLoopbackListenAddr reports whether addr names a loopback bind
// address. Surrounding whitespace is ignored. The empty string and
// "localhost" (any case) count as loopback; anything else must parse as
// an IP literal in a loopback range. Other hostnames are treated as
// non-loopback.
func isLoopbackListenAddr(addr string) bool {
	host := strings.TrimSpace(addr)
	switch {
	case host == "":
		return true
	case strings.EqualFold(host, "localhost"):
		return true
	}
	if parsed := net.ParseIP(host); parsed != nil {
		return parsed.IsLoopback()
	}
	return false
}
// Shutdown stops the server by calling Close on the underlying
// http.Server, which closes the listener and any open connections
// immediately instead of draining in-flight requests. The close error
// is deliberately discarded.
func (s *Server) Shutdown() {
	if err := s.srv.Close(); err != nil {
		// Nothing useful can be done with a close error during teardown.
		_ = err
	}
}
// handleChallenge serves the proof-of-work challenge page for the
// requesting IP, unless one of the bypass paths applies.
func (s *Server) handleChallenge(w http.ResponseWriter, r *http.Request) {
	ip := s.extractIP(r)

	// Bypass paths run before the PoW page is generated. Each
	// short-circuits to the same markVerified flow, so a passing
	// visitor never sees the challenge UI even once. The checks run in
	// order and stop at the first hit.
	var bypassReason string
	switch {
	case s.bypassByAdminCookie(r, ip):
		bypassReason = "admin session"
	case s.bypassByVerifiedCrawler(r.Context(), ip):
		bypassReason = "verified crawler"
	case s.bypassByVerifyCookie(r, ip):
		bypassReason = "verified cookie"
	}
	if bypassReason != "" {
		s.markVerified(w, r, ip, bypassReason, "")
		return
	}

	// Fresh nonce per page view; the token binds it to this IP.
	nonce := generateNonce()
	token := s.makeToken(ip, nonce)
	difficulty := s.cfg.Challenge.Difficulty

	headers := w.Header()
	headers.Set("Content-Type", "text/html; charset=utf-8")
	headers.Set("Cache-Control", "no-store")

	captchaBlock := s.captchaNoscriptHTML(token, nonce)
	// #nosec G705 -- All interpolated values are constrained to non-HTML
	// character sets: ip is validated via net.ParseIP / net.SplitHostPort
	// in extractIP (never attacker-controlled string); nonce and token
	// are hex.EncodeToString output (0-9, a-f only); difficulty is an int.
	// captchaBlock is constructed from configured site keys (operator
	// supplied) plus hex token/nonce; it never includes attacker input.
	// html.EscapeString on ip is defence-in-depth so static analysers can
	// see the request-derived value is rendered safe even if a future
	// change to extractIP loosened validation.
	fmt.Fprintf(w, challengePageHTML, html.EscapeString(ip), captchaBlock, nonce, token, difficulty, difficulty)
}
// handleGate answers "does this visitor need to be challenged?" with a
// bare status code. GET and HEAD only; other methods get 405. It replies
// 204 when any bypass applies or the IP is not on the challenge list
// (or no list is configured), and 401 when the IP is listed and has no
// bypass.
func (s *Server) handleGate(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet, http.MethodHead:
		// allowed
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	ip := s.extractIP(r)
	bypassed := s.bypassByAdminCookie(r, ip) ||
		s.bypassByVerifiedCrawler(r.Context(), ip) ||
		s.bypassByVerifyCookie(r, ip)

	status := http.StatusNoContent
	// The list is consulted only when no bypass matched.
	if !bypassed && s.ipList != nil && s.ipList.Contains(ip) {
		status = http.StatusUnauthorized
	}
	w.WriteHeader(status)
}
// bypassByAdminCookie reports whether the request carries a valid
// signed-session cookie issued by THIS daemon for THIS IP. It is always
// false when the verified-session feature is off. The cookie name comes
// from configuration and falls back to "csm_admin_session". A daemon
// restart rotates the signing key, so old cookies fall back to the
// normal PoW flow automatically.
func (s *Server) bypassByAdminCookie(r *http.Request, ip string) bool {
	if s.sessionSigner == nil {
		return false
	}

	name := "csm_admin_session"
	if configured := s.cfg.Challenge.VerifiedSession.CookieName; configured != "" {
		name = configured
	}

	cookie, err := r.Cookie(name)
	if err != nil {
		return false
	}
	if cookie.Value == "" {
		return false
	}
	return s.sessionSigner.Verify(cookie.Value, ip) == nil
}
// bypassByVerifiedCrawler delegates to the crawler verifier, which
// resolves the visitor's reverse DNS and confirms the forward lookup,
// so the PoW is skipped only for traffic from the configured crawler
// families. A spoofed UA from a residential IP fails forward-confirm
// and falls through to PoW. Always false when crawler verification is
// not configured.
func (s *Server) bypassByVerifiedCrawler(ctx context.Context, ip string) bool {
	return s.crawlers != nil && s.crawlers.Verified(ctx, ip)
}
// handleVerify accepts a proof-of-work submission (POST form fields
// nonce, token, solution, and optional dest). It responds 405 for
// non-POST requests and 403 when the token does not match this IP and
// nonce, when the solution does not satisfy the configured difficulty,
// or when the nonce has already been spent. On success the visitor goes
// through markVerified.
func (s *Server) handleVerify(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	ip := s.extractIP(r)
	nonce := r.FormValue("nonce")
	token := r.FormValue("token")
	solution := r.FormValue("solution")

	// Verify HMAC token. hmac.Equal compares in constant time so response
	// timing does not reveal how many leading bytes of a forged token
	// were correct.
	expected := s.makeToken(ip, nonce)
	if !hmac.Equal([]byte(token), []byte(expected)) {
		http.Error(w, "Invalid token", http.StatusForbidden)
		return
	}

	// Verify proof-of-work solution
	if !verifyPoW(nonce, solution, s.cfg.Challenge.Difficulty) {
		http.Error(w, "Invalid solution", http.StatusForbidden)
		return
	}

	// Prevent replay -- one nonce, one verification. Check and record
	// happen under the same lock so two concurrent submissions of the
	// same nonce cannot both pass.
	s.verifiedMu.Lock()
	if _, seen := s.verified[nonce]; seen {
		s.verifiedMu.Unlock()
		http.Error(w, "Token already used", http.StatusForbidden)
		return
	}
	s.verified[nonce] = time.Now()
	s.verifiedMu.Unlock()

	s.markVerified(w, r, ip, "passed challenge", r.FormValue("dest"))
}
// handleCaptchaVerify accepts a provider token, validates it
// server-side, and (on success) puts the visitor through the same
// markVerified flow PoW uses. Available only when the operator has
// configured a CAPTCHA provider; otherwise it responds 404.
//
// Form fields: nonce and token (from the challenge page), captcha-token
// (from the provider widget), and optional dest. Responses: 405 for
// non-POST, 403 for a bad page token, a spent nonce, or a rejected
// CAPTCHA.
func (s *Server) handleCaptchaVerify(w http.ResponseWriter, r *http.Request) {
	if s.captcha == nil {
		http.NotFound(w, r)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	ip := s.extractIP(r)
	nonce := r.FormValue("nonce")
	token := r.FormValue("token")
	captchaToken := r.FormValue("captcha-token")

	// Bind the CAPTCHA submission to the page token before asking the
	// provider. The nonce is spent only after the provider accepts, so a
	// rejected widget token can be retried from the same page.
	// hmac.Equal compares in constant time so response timing does not
	// reveal how many leading bytes of a forged token were correct.
	expected := s.makeToken(ip, nonce)
	if !hmac.Equal([]byte(token), []byte(expected)) {
		http.Error(w, "Invalid token", http.StatusForbidden)
		return
	}

	// Cheap early rejection of an already-spent nonce, so the provider
	// is not called for an obvious replay.
	s.verifiedMu.Lock()
	_, seen := s.verified[nonce]
	s.verifiedMu.Unlock()
	if seen {
		http.Error(w, "Token already used", http.StatusForbidden)
		return
	}

	ok, err := s.captcha.Verify(r.Context(), captchaToken, ip)
	if err != nil || !ok {
		http.Error(w, "CAPTCHA verification failed", http.StatusForbidden)
		return
	}

	// Re-check under the lock: the lock was released during the provider
	// call, so a concurrent request may have spent the nonce meanwhile.
	s.verifiedMu.Lock()
	if _, seen := s.verified[nonce]; seen {
		s.verifiedMu.Unlock()
		http.Error(w, "Token already used", http.StatusForbidden)
		return
	}
	s.verified[nonce] = time.Now()
	s.verifiedMu.Unlock()

	s.markVerified(w, r, ip, "passed captcha", r.FormValue("dest"))
}
// handleAdminToken issues a signed-session cookie when the operator
// presents the configured admin_secret in the "secret" form field.
//
// Responses: 404 when the verified-session feature is off, 405 for
// non-POST, 429 once an IP has burned through adminMaxFailuresInWindow
// failures inside adminFailureWindow, 403 on a bad/missing secret, and
// 204 with Set-Cookie on success. The 429 is returned BEFORE the
// constant-time compare so an attacker cannot keep probing once they
// are throttled.
func (s *Server) handleAdminToken(w http.ResponseWriter, r *http.Request) {
	switch {
	case s.sessionSigner == nil:
		http.NotFound(w, r)
		return
	case r.Method != http.MethodPost:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	ip := s.extractIP(r)
	if s.adminRateLimited(ip) {
		http.Error(w, "Too many failed attempts; try again later.", http.StatusTooManyRequests)
		return
	}

	if presented := r.FormValue("secret"); !CompareAdminSecret(s.cfg.Challenge.VerifiedSession.AdminSecret, presented) {
		s.recordAdminFailure(ip)
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	// Successful auth clears the failure log so a legitimate operator
	// who fat-fingered the secret a few times can keep using the
	// endpoint after they get it right.
	s.clearAdminFailures(ip)

	name := "csm_admin_session"
	if configured := s.cfg.Challenge.VerifiedSession.CookieName; configured != "" {
		name = configured
	}

	// The cookie is bound to this IP; Max-Age mirrors the signer's TTL.
	http.SetCookie(w, &http.Cookie{
		Name:     name,
		Value:    s.sessionSigner.Issue(ip),
		Path:     "/",
		MaxAge:   int(s.sessionSigner.TTL().Seconds()),
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	w.WriteHeader(http.StatusNoContent)
}
// adminRateLimited returns true when ip has at least
// adminMaxFailuresInWindow failures in the last adminFailureWindow.
// Also opportunistically prunes aged-out entries so the per-IP slice
// stays bounded. An IP left with no recent failures is removed from the
// map instead of being stored with an empty slice, so merely calling
// the endpoint does not add a map key per source IP.
func (s *Server) adminRateLimited(ip string) bool {
	cutoff := time.Now().Add(-adminFailureWindow)

	s.adminFailuresMu.Lock()
	defer s.adminFailuresMu.Unlock()

	// Filter in place, reusing the existing backing array.
	recent := s.adminFailures[ip][:0]
	for _, t := range s.adminFailures[ip] {
		if t.After(cutoff) {
			recent = append(recent, t)
		}
	}

	if len(recent) == 0 {
		// delete is a no-op when ip was never recorded.
		delete(s.adminFailures, ip)
	} else {
		s.adminFailures[ip] = recent
	}
	return len(recent) >= adminMaxFailuresInWindow
}
// recordAdminFailure appends the current time to ip's failure log under
// adminFailuresMu.
func (s *Server) recordAdminFailure(ip string) {
	s.adminFailuresMu.Lock()
	failures := append(s.adminFailures[ip], time.Now())
	s.adminFailures[ip] = failures
	s.adminFailuresMu.Unlock()
}
// clearAdminFailures drops ip's entire failure log under
// adminFailuresMu. Removing an IP that has no entry is a no-op.
func (s *Server) clearAdminFailures(ip string) {
	s.adminFailuresMu.Lock()
	delete(s.adminFailures, ip)
	s.adminFailuresMu.Unlock()
}
// markVerified is the shared post-success path used by handleVerify,
// handleCaptchaVerify, and the bypass shortcuts in handleChallenge.
// Centralising the side effects keeps every caller in sync instead of
// letting copy-pasted variants drift apart over time.
//
// In order it: asks the unblocker (if any) for a temporary allow of ip,
// logging but otherwise ignoring a failure; removes ip from the
// challenge list (if any); sets the signed csm_verified cookie (if the
// signer is available); and writes a small HTML page that redirects
// after two seconds. destOverride is the requested redirect target; it
// is passed through sanitizeRedirectDest together with the request
// host, and the result is HTML-escaped before being written.
func (s *Server) markVerified(w http.ResponseWriter, r *http.Request, ip, reason, destOverride string) {
	const allowWindow = 4 * time.Hour

	const redirectPage = `<!DOCTYPE html><html><head><meta http-equiv="refresh" content="2;url=%s">
<style>body{font-family:system-ui;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#1a2234;color:#c8d3e0}
.ok{text-align:center}.ok h1{color:#2fb344;font-size:3em}p{font-size:1.2em}</style>
</head><body><div class="ok"><h1>✓</h1><p>Verified - redirecting...</p></div></body></html>`

	if s.unblocker != nil {
		err := s.unblocker.TempAllowIP(ip, reason, allowWindow)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[challenge] failed to allow %s: %v\n", ip, err)
		}
	}

	if s.ipList != nil {
		s.ipList.Remove(ip)
	}

	// Set verification cookie so the visitor skips the gate until the allow
	// window expires. Secure is always on: CSM is designed to run behind
	// HTTPS and the cookie grants a multi-hour bypass of the PoW gate, so
	// leaking it over plaintext is never acceptable.
	if s.verifySigner != nil {
		verified := &http.Cookie{
			Name:     "csm_verified",
			Value:    s.verifySigner.Issue(ip),
			Path:     "/",
			MaxAge:   int(allowWindow.Seconds()),
			Secure:   true,
			HttpOnly: true,
			SameSite: http.SameSiteLaxMode,
		}
		http.SetCookie(w, verified)
	}

	target := html.EscapeString(sanitizeRedirectDest(destOverride, r.Host))
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	fmt.Fprintf(w, redirectPage, target)
}
// captchaNoscriptHTML renders the CAPTCHA fallback widget for the
// configured provider ("turnstile" or "hcaptcha"). It returns the empty
// string when no provider is configured or the provider name is not
// recognised, in which case the challenge page's <noscript> block falls
// back to the existing "JavaScript is required" message.
//
// token and nonce are guaranteed hex by upstream constructors; siteKey
// is operator-supplied so we escape it to keep a typo from breaking
// the page (an attacker would need write access to csm.yaml to inject
// HTML here, which is already game-over, but defence-in-depth is
// cheap).
func (s *Server) captchaNoscriptHTML(token, nonce string) string {
	if s.captcha == nil {
		return ""
	}

	// Pick the provider-specific markup; all templates take the same
	// three arguments in the same order.
	var widget string
	switch s.captcha.Name() {
	case "turnstile":
		widget = `<script src="https://challenges.cloudflare.com/turnstile/v0/api.js" async defer></script>
<form method="POST" action="/challenge/captcha-verify">
<input type="hidden" name="token" value="%s">
<input type="hidden" name="nonce" value="%s">
<div class="cf-turnstile" data-sitekey="%s" data-callback="csmCaptchaCallback"></div>
<input type="hidden" name="captcha-token" id="captchaToken">
</form>
<script>function csmCaptchaCallback(t){document.getElementById('captchaToken').value=t;document.forms[0].submit();}</script>`
	case "hcaptcha":
		widget = `<script src="https://js.hcaptcha.com/1/api.js" async defer></script>
<form method="POST" action="/challenge/captcha-verify">
<input type="hidden" name="token" value="%s">
<input type="hidden" name="nonce" value="%s">
<div class="h-captcha" data-sitekey="%s" data-callback="csmCaptchaCallback"></div>
<input type="hidden" name="captcha-token" id="captchaToken">
</form>
<script>function csmCaptchaCallback(t){document.getElementById('captchaToken').value=t;document.forms[0].submit();}</script>`
	default:
		return ""
	}

	siteKey := html.EscapeString(s.cfg.Challenge.CaptchaFallback.SiteKey)
	return fmt.Sprintf(widget, token, nonce, siteKey)
}
// makeToken returns the hex-encoded HMAC-SHA256 of "<ip>:<nonce>" keyed
// with the server secret. It is what binds a challenge page to the IP
// it was served to.
func (s *Server) makeToken(ip, nonce string) string {
	h := hmac.New(sha256.New, s.secret)
	// Streaming the three pieces hashes the same bytes as the
	// concatenated string.
	h.Write([]byte(ip))
	h.Write([]byte(":"))
	h.Write([]byte(nonce))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
// bypassByVerifyCookie reports whether the request carries a valid
// csm_verified cookie issued by this daemon for this IP. The cookie
// binds to one IP and an expiry, so it cannot be replayed from another
// network or after the allow window; a daemon restart rotates the
// signing key and invalidates every outstanding cookie. Always false
// when the verify signer could not be constructed.
func (s *Server) bypassByVerifyCookie(r *http.Request, ip string) bool {
	if s.verifySigner == nil {
		return false
	}
	cookie, err := r.Cookie("csm_verified")
	switch {
	case err != nil:
		return false
	case cookie.Value == "":
		return false
	}
	return s.verifySigner.Verify(cookie.Value, ip) == nil
}
// CleanExpired removes old verification records, prunes the
// admin-failure log, and evicts stale crawler-cache entries. Called
// from the daemon's challengeEscalator ticker every 60 seconds; under
// a sustained scan from many source IPs, this is the only thing
// keeping per-IP map entries from accumulating until restart.
//
// All three sweeps measure age against a single timestamp taken on
// entry.
func (s *Server) CleanExpired() {
	now := time.Now()

	// Spent nonces older than four hours.
	func() {
		s.verifiedMu.Lock()
		defer s.verifiedMu.Unlock()
		cutoff := now.Add(-4 * time.Hour)
		for nonce, usedAt := range s.verified {
			if usedAt.Before(cutoff) {
				delete(s.verified, nonce)
			}
		}
	}()

	// Drop admin-failure entries whose latest failure has aged out of
	// the rate-limit window. An IP that hammered the endpoint once
	// and stopped will otherwise sit in the map forever.
	func() {
		s.adminFailuresMu.Lock()
		defer s.adminFailuresMu.Unlock()
		cutoff := now.Add(-adminFailureWindow)
		for ip, failures := range s.adminFailures {
			// Filter in place, reusing the backing array.
			live := failures[:0]
			for _, at := range failures {
				if at.After(cutoff) {
					live = append(live, at)
				}
			}
			if len(live) > 0 {
				s.adminFailures[ip] = live
				continue
			}
			delete(s.adminFailures, ip)
		}
	}()

	if s.crawlers != nil {
		s.crawlers.cleanExpired(now)
	}
}
// extractIP returns the client IP for the request in canonical form.
// X-Forwarded-For is only trusted when the direct peer is in the
// configured trusted_proxies list. In that case the rightmost parseable
// entry is used (the one appended by the trusted proxy), not the
// leftmost (which the client controls). Without trusted proxies, or when
// the header is absent or holds no parseable address, the peer address
// from RemoteAddr is used — this prevents attackers from spoofing their
// IP to mint firewall allow rules for arbitrary addresses.
func (s *Server) extractIP(r *http.Request) string {
	peer, _, _ := net.SplitHostPort(r.RemoteAddr)
	peer = canonicalIP(peer)

	// A lookup in an empty (or nil) map is simply false.
	if !s.trustedProxies[peer] {
		return peer
	}

	forwarded := r.Header.Get("X-Forwarded-For")
	if forwarded == "" {
		return peer
	}

	// Walk right to left and take the first entry that is a valid IP.
	hops := strings.Split(forwarded, ",")
	for i := len(hops) - 1; i >= 0; i-- {
		candidate := strings.TrimSpace(hops[i])
		if net.ParseIP(candidate) == nil {
			continue
		}
		return canonicalIP(candidate)
	}
	return peer
}
// canonicalIP normalizes an IP string so downstream exact-match lookups
// (the firewall's blocked-IP index) agree on one form. A dual-stack
// listener reports an IPv4 peer as "::ffff:1.2.3.4"; net.IP.String
// collapses that to the plain IPv4 form the firewall stores. Input that
// does not parse as an IP is returned unchanged.
func canonicalIP(ip string) string {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return ip
	}
	return parsed.String()
}
// generateNonce returns 16 random bytes from crypto/rand as a
// 32-character lowercase hex string. The error from the random source
// is ignored.
func generateNonce() string {
	var raw [16]byte
	_, _ = rand.Read(raw[:])
	return hex.EncodeToString(raw[:])
}
// sanitizeRedirectDest validates that a redirect destination is a safe
// same-origin relative path or an absolute/protocol-relative URL whose
// host matches the request host. Anything else collapses to "/".
//
// dest is the caller-supplied destination; requestHost is the request's
// Host header value, with or without a port (the port is ignored for
// the comparison).
//
// Rejected: cross-origin hosts, non-HTTP schemes such as javascript:,
// backslash-based bypasses, an http/https scheme with no host
// ("http:///evil.com", "https:/evil.com"), and host-less paths that
// start with "//" ("///evil.com", "/%2Fevil.com"). Browsers resolve
// those last two families to a foreign host even though Go parses them
// with an empty Host.
//
// The result is reconstructed from parsed components to prevent any
// raw-string injection.
func sanitizeRedirectDest(dest, requestHost string) string {
	const fallback = "/"

	if dest == "" {
		return fallback
	}

	// Parse through url.Parse to normalize and detect scheme/host
	parsed, err := url.Parse(dest)
	if err != nil {
		return fallback
	}

	// Scheme whitelist — applies even for opaque URLs with empty Host.
	// Without this, `javascript:alert(1)` produces an opaque URL with
	// Host="" and Scheme="javascript", which would slip past the
	// host-equality check below. The only acceptable schemes are the
	// empty string (relative references) and the two HTTP variants.
	switch strings.ToLower(parsed.Scheme) {
	case "", "http", "https":
	default:
		return fallback
	}

	if parsed.Host != "" {
		// Reject anything with a host component that doesn't match the
		// request host. This catches protocol-relative (//evil.com) and
		// explicit cross-origin URLs.
		reqHost := requestHost
		if h, _, splitErr := net.SplitHostPort(requestHost); splitErr == nil {
			reqHost = h
		}
		if parsed.Hostname() != reqHost {
			return fallback
		}
	} else {
		// A scheme without an authority is never a legitimate same-origin
		// target. Go parses "http:///evil.com" and "https:/evil.com" with
		// an empty Host, but the reconstructed "scheme:///evil.com" is
		// resolved by browsers as "scheme://evil.com/".
		if parsed.Scheme != "" {
			return fallback
		}
		// Pure relative reference: must be an absolute path.
		if !strings.HasPrefix(parsed.Path, "/") {
			return fallback
		}
		// A decoded path starting with "//" would be emitted as a
		// protocol-relative URL pointing at another host.
		if strings.HasPrefix(parsed.Path, "//") {
			return fallback
		}
		// Reject backslash in path (browser normalization attack:
		// "/\evil.com" is treated like "//evil.com" by some browsers).
		if strings.ContainsRune(parsed.Path, '\\') {
			return fallback
		}
	}

	// Reconstruct from parsed components to prevent raw-string injection.
	// Userinfo and opaque data are intentionally dropped.
	safe := &url.URL{
		Scheme:   parsed.Scheme,
		Host:     parsed.Host,
		Path:     parsed.Path,
		RawQuery: parsed.RawQuery,
		Fragment: parsed.Fragment,
	}
	return safe.String()
}
// verifyPoW checks that the hex form of SHA256(nonce + solution) starts
// with `difficulty` zero nibbles. A difficulty of zero or less is
// satisfied by any input; a difficulty beyond the 64 hex digits of the
// digest can never be satisfied.
func verifyPoW(nonce, solution string, difficulty int) bool {
	digest := sha256.Sum256([]byte(nonce + solution))
	encoded := hex.EncodeToString(digest[:])

	if difficulty <= 0 {
		return true
	}
	if difficulty > len(encoded) {
		return false
	}
	for _, nibble := range []byte(encoded[:difficulty]) {
		if nibble != '0' {
			return false
		}
	}
	return true
}
package challenge
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"errors"
"strings"
"time"
)
// AdminSessionSigner mints and verifies the signed cookies that let
// authenticated operators bypass the PoW. The signing key is generated
// on construction; rebuilding the signer (i.e., daemon restart)
// invalidates every previously-issued cookie -- the rotation contract.
type AdminSessionSigner struct {
// key is the HMAC-SHA256 key: 32 random bytes generated by
// NewAdminSessionSigner.
key []byte
// ttl is the lifetime given to cookies minted by Issue.
ttl time.Duration
}
// Sentinel errors reported by AdminSessionSigner.Verify. Branch on them
// with errors.Is.
var (
	// ErrSessionExpired is returned by Verify for cookies whose embedded
	// expiry has passed.
	ErrSessionExpired = errors.New("session expired")

	// ErrSessionBadSignature is returned by Verify for cookies whose HMAC
	// does not match. Includes tampered payloads and cookies signed by a
	// previous AdminSessionSigner instance (post-rotation).
	ErrSessionBadSignature = errors.New("session signature invalid")

	// ErrSessionIPMismatch is returned when the cookie was issued for a
	// different IP than the one presenting it. Stops cookie theft from a
	// different network.
	ErrSessionIPMismatch = errors.New("session IP mismatch")

	// ErrSessionMalformed is returned for cookies or payloads that cannot
	// be decoded, so a corrupt cookie has a distinct sentinel from a
	// tampered one.
	ErrSessionMalformed = errors.New("session payload malformed")
)
// NewAdminSessionSigner builds a signer with a fresh 32-byte random
// signing key. A ttl of zero or less falls back to four hours. It
// returns an error only when the system random source fails.
//
// The caller must keep the returned pointer for the lifetime of the
// challenge server; never construct a second signer for the same
// server, or already-issued cookies will be invalidated mid-session.
func NewAdminSessionSigner(ttl time.Duration) (*AdminSessionSigner, error) {
	if ttl <= 0 {
		ttl = 4 * time.Hour
	}
	signer := &AdminSessionSigner{
		key: make([]byte, 32),
		ttl: ttl,
	}
	if _, err := rand.Read(signer.key); err != nil {
		return nil, err
	}
	return signer, nil
}
// TTL returns the cookie lifetime this signer was configured with, so
// the server can set the matching Max-Age on the Set-Cookie header.
func (s *AdminSessionSigner) TTL() time.Duration {
	return s.ttl
}
// Issue returns a cookie value of the form
// "<base64(payload)>.<base64 hmac>" that expires one TTL from now. The
// payload binds the cookie to a single IP and a single expiry so a
// stolen cookie does not work elsewhere or after the TTL.
func (s *AdminSessionSigner) Issue(ip string) string {
	expiry := time.Now().Add(s.ttl)
	return s.issueAt(ip, expiry)
}
// issueAt mints a cookie for ip with an explicit expiry. It is the test
// seam behind Issue: expiry tests can construct already-expired cookies
// without reaching into encodeSessionPayload directly. The result is
// the unpadded URL-safe base64 of the payload, a dot, and the unpadded
// URL-safe base64 of the payload's HMAC-SHA256.
func (s *AdminSessionSigner) issueAt(ip string, exp time.Time) string {
	body := encodeSessionPayload(ip, exp)

	signer := hmac.New(sha256.New, s.key)
	signer.Write(body)
	signature := signer.Sum(nil)

	enc := base64.RawURLEncoding
	return enc.EncodeToString(body) + "." + enc.EncodeToString(signature)
}
// Verify checks, in this order, the cookie's overall shape, its HMAC,
// the payload format, the expiry, and the IP binding. It returns nil
// for a valid cookie and otherwise one of ErrSessionMalformed,
// ErrSessionBadSignature, ErrSessionExpired, or ErrSessionIPMismatch.
// Use errors.Is to branch on the failure mode.
func (s *AdminSessionSigner) Verify(cookieValue, ip string) error {
	// Split on the last dot; both halves must be non-empty.
	sep := strings.LastIndexByte(cookieValue, '.')
	if sep <= 0 || sep == len(cookieValue)-1 {
		return ErrSessionMalformed
	}

	enc := base64.RawURLEncoding
	payload, err := enc.DecodeString(cookieValue[:sep])
	if err != nil {
		return ErrSessionMalformed
	}
	sig, err := enc.DecodeString(cookieValue[sep+1:])
	if err != nil {
		return ErrSessionMalformed
	}

	// Authenticate before interpreting any payload byte.
	h := hmac.New(sha256.New, s.key)
	h.Write(payload)
	if want := h.Sum(nil); subtle.ConstantTimeCompare(sig, want) != 1 {
		return ErrSessionBadSignature
	}

	boundIP, expiry, err := decodeSessionPayload(payload)
	if err != nil {
		return ErrSessionMalformed
	}

	switch {
	case time.Now().After(expiry):
		return ErrSessionExpired
	case boundIP != ip:
		return ErrSessionIPMismatch
	}
	return nil
}
// CompareAdminSecret reports whether the presented secret equals the
// stored one, using a constant-time comparison. An empty stored secret
// never matches, so a misconfigured admin_secret cannot accidentally
// accept any caller.
func CompareAdminSecret(stored, presented string) bool {
	if len(stored) == 0 {
		return false
	}
	match := subtle.ConstantTimeCompare([]byte(stored), []byte(presented))
	return match == 1
}
// encodeSessionPayload serialises a session as:
//
//	1 byte version (always 1) | 8 byte unix expiry, big-endian | n byte IP
//
// The IP is the variable-length tail, written as the raw bytes of the
// string with no length prefix; a reader takes everything after the
// ninth byte.
func encodeSessionPayload(ip string, exp time.Time) []byte {
	const headerLen = 9 // version byte + 8-byte timestamp

	buf := make([]byte, headerLen, headerLen+len(ip))
	buf[0] = 1
	// #nosec G115 -- exp is always future-dated (now + positive TTL via NewAdminSessionSigner / Issue), so Unix() is positive and the int64->uint64 cast cannot overflow.
	binary.BigEndian.PutUint64(buf[1:headerLen], uint64(exp.Unix()))
	return append(buf, ip...)
}
// decodeSessionPayload parses a payload laid out as 1 version byte, an
// 8 byte big-endian unix expiry, and the IP as the remaining bytes. It
// returns ErrSessionMalformed when the payload is shorter than the fixed
// header, carries a version other than 1, or encodes an expiry above 1<<62
// seconds.
func decodeSessionPayload(p []byte) (ip string, exp time.Time, err error) {
	const headerLen = 9
	if len(p) < headerLen || p[0] != 1 {
		return "", time.Time{}, ErrSessionMalformed
	}
	secs := binary.BigEndian.Uint64(p[1:headerLen])
	if secs > uint64(1<<62) {
		// Reject absurd expiries before converting to a signed value.
		return "", time.Time{}, ErrSessionMalformed
	}
	ip = string(p[headerLen:])
	exp = time.Unix(int64(secs), 0)
	return ip, exp, nil
}
package checks
import (
"context"
"fmt"
"os"
"os/user"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// accountScanMaxFilesDefault is the per-scan file cap used when the
// configuration does not supply a positive value.
const accountScanMaxFilesDefault = 10000

// effectiveAccountScanMaxFiles returns cfg.Thresholds.AccountScanMaxFiles
// when it is positive, and accountScanMaxFilesDefault otherwise (including
// for a nil cfg).
func effectiveAccountScanMaxFiles(cfg *config.Config) int {
	if cfg != nil {
		if limit := cfg.Thresholds.AccountScanMaxFiles; limit > 0 {
			return limit
		}
	}
	return accountScanMaxFilesDefault
}
// rankPathsByMtimeDesc returns paths ordered newest-modified first, capped
// at maxFiles entries when maxFiles > 0. Ranking by mtime keeps a cap or a
// downstream timeout from always cutting off the same lexical tail of the
// input, which would hide indicators on late-alphabet accounts.
//
// Paths that cannot be stat'ed are kept with a zero mtime, so they sort
// last and are the first to be removed by the cap. They are deliberately
// not dropped here; downstream readers deal with missing files themselves.
// Ties on mtime are broken by ascending path.
//
// Any paths removed by the cap are reported through
// recordAccountScanTruncatedPaths. A nil ctx is treated as
// context.Background(); a canceled ctx yields nil. maxFiles <= 0 disables
// the cap but the sort still runs.
func rankPathsByMtimeDesc(ctx context.Context, paths []string, maxFiles int) []string {
	if ctx == nil {
		ctx = context.Background()
	}
	canceled := func() bool { return ctx.Err() != nil }
	if canceled() {
		return nil
	}

	type stamped struct {
		path     string
		modified time.Time
	}
	items := make([]stamped, 0, len(paths))
	for _, p := range paths {
		if canceled() {
			return nil
		}
		item := stamped{path: p}
		if fi, statErr := osFS.Stat(p); statErr == nil {
			item.modified = fi.ModTime()
		}
		items = append(items, item)
	}
	if canceled() {
		return nil
	}

	sort.Slice(items, func(a, b int) bool {
		left, right := items[a], items[b]
		if !left.modified.Equal(right.modified) {
			return left.modified.After(right.modified)
		}
		return left.path < right.path
	})
	if canceled() {
		return nil
	}

	keep := len(items)
	if maxFiles > 0 && keep > maxFiles {
		keep = maxFiles
		overflow := make([]string, 0, len(items)-keep)
		for _, it := range items[keep:] {
			overflow = append(overflow, it.path)
		}
		recordAccountScanTruncatedPaths(ctx, overflow, maxFiles)
	}

	out := make([]string, keep)
	for i := range out {
		out[i] = items[i].path
	}
	return out
}
// RunAccountScan runs every account-applicable check against a single
// account and returns only the findings that belong to it. It does not
// trigger auto-response actions.
//
// When /home/<account> does not exist a single "account_scan" warning is
// returned instead. Otherwise the account scope travels through the context
// (ContextWithAccountScope), so concurrent scans of different accounts do
// not share scope. At most four checks run at once. Findings without a
// timestamp get the scan completion time, out-of-scope findings are
// filtered out, and an empty TenantID is filled with the account name.
func RunAccountScan(cfg *config.Config, store *state.Store, account string) []alert.Finding {
	// Refuse to scan an account that has no home directory.
	if _, statErr := osFS.Stat(filepath.Join("/home", account)); os.IsNotExist(statErr) {
		missing := alert.Finding{
			Severity:  alert.Warning,
			Check:     "account_scan",
			Message:   fmt.Sprintf("Account '%s' not found (no /home/%s directory)", account, account),
			Timestamp: time.Now(),
		}
		return []alert.Finding{missing}
	}

	// Generic filesystem checks first, then the three that are built
	// around the account name.
	checks := []namedCheck{
		{"webshells", CheckWebshells},
		{"htaccess", CheckHtaccess},
		{"wp_core", CheckWPCore},
		{"php_content", CheckPHPContent},
		{"phishing", CheckPhishing},
		{"filesystem", CheckFilesystem},
		{"group_writable_php", CheckGroupWritablePHP},
		{"nulled_plugins", CheckNulledPlugins},
		{"open_basedir", CheckOpenBasedir},
		{"symlink_attacks", CheckSymlinkAttacks},
		{"db_content", CheckDatabaseContent},
		{"php_config_changes", CheckPHPConfigChanges},
		{"ssh_keys_account", makeAccountSSHKeyCheck(account)},
		{"crontab_account", makeAccountCrontabCheck(account)},
		{"backdoor_binaries", makeAccountBackdoorCheck(account)},
	}

	scanCtx, truncations := withAccountScanTruncationCollector(context.Background())
	scanCtx = ContextWithAccountScope(scanCtx, account)

	// Bounded parallelism: the filesystem checks all walk the same tree,
	// so running too many at once makes them starve each other on hosts
	// with slow I/O.
	const maxConcurrent = 4
	var (
		wg        sync.WaitGroup
		resultsMu sync.Mutex
		collected []alert.Finding
	)
	slots := make(chan struct{}, maxConcurrent)
	for i := range checks {
		check := checks[i]
		wg.Add(1)
		// Checks parse user-controlled content (unparsed PHP, crafted
		// archives, foreign encodings), so a panic is plausible. The
		// runner goes through SafeGo so the scan and the daemon stay up.
		obs.SafeGo("account-scan-runner", func() {
			defer wg.Done()
			slots <- struct{}{}
			defer func() { <-slots }()

			got := runAccountScanCheck(scanCtx, check, cfg, store, checkTimeout)
			if len(got) == 0 {
				return
			}
			resultsMu.Lock()
			collected = append(collected, got...)
			resultsMu.Unlock()
		})
	}
	wg.Wait()

	now := time.Now()
	collected = append(collected, truncations.findings(now)...)

	// Stamp missing timestamps and keep only this account's findings.
	var inScope []alert.Finding
	for _, f := range collected {
		if f.Timestamp.IsZero() {
			f.Timestamp = now
		}
		if accountScanFindingInScope(f, account) {
			inScope = append(inScope, f)
		}
	}
	return stampTenantIDIfEmpty(inScope, account)
}
// runAccountScanCheck executes a single check with a deadline of timeout
// layered on top of ctx and returns its findings.
//
// The check itself runs on its own SafeGo goroutine. If it has not
// delivered a result by the time the derived context is done (deadline
// reached or parent canceled), a "check_timeout" warning is returned in its
// place. Per the surrounding design notes, a panic inside the check is
// captured by SafeGo, so nothing is ever delivered and the same timeout
// path reports it -- the scan degrades instead of crashing the daemon.
func runAccountScanCheck(ctx context.Context, c namedCheck, cfg *config.Config, store *state.Store, timeout time.Duration) []alert.Finding {
	checkCtx, stop := context.WithTimeout(ctx, timeout)
	defer stop()

	// Capacity 1 lets a late worker deliver its result and exit even
	// after this function has given up waiting.
	resultCh := make(chan []alert.Finding, 1)
	obs.SafeGo("account-scan-exec", func() {
		resultCh <- c.fn(checkCtx, cfg, store)
	})

	select {
	case <-checkCtx.Done():
		timedOut := alert.Finding{
			Severity:  alert.Warning,
			Check:     "check_timeout",
			Message:   fmt.Sprintf("Account scan check '%s' timed out", c.name),
			Timestamp: time.Now(),
		}
		return []alert.Finding{timedOut}
	case got := <-resultCh:
		return got
	}
}
// accountScanFindingInScope decides whether finding f belongs to account.
//
// An empty account accepts everything. Otherwise the decision is made in
// this order:
//  1. FilePath under a home directory: in scope only when the owning
//     account matches; a home-rooted path with no resolvable owner is out.
//  2. Message/Details text: any reference to this account's home is in
//     scope; a reference to some other home (and none to this one) is out.
//  3. No home reference anywhere: in scope.
func accountScanFindingInScope(f alert.Finding, account string) bool {
	if account == "" {
		return true
	}

	if path := f.FilePath; path != "" {
		if owner := accountFromHomePath(path); owner != "" {
			return owner == account
		}
		if containsHomeReference(path) {
			return false
		}
	}

	msgHome, msgMine := textHomeScope(f.Message, account)
	detHome, detMine := textHomeScope(f.Details, account)
	switch {
	case msgMine || detMine:
		return true
	case msgHome || detHome:
		return false
	default:
		return true
	}
}
// containsHomeReference reports whether path, once cleaned, is a home root
// or lives below one. A home root is "/home" optionally followed by decimal
// digits (e.g. "/home2"); the match must end there or continue with "/".
func containsHomeReference(path string) bool {
	const root = "/home"
	normalized := filepath.ToSlash(filepath.Clean(path))
	if !strings.HasPrefix(normalized, root) {
		return false
	}
	// Skip the optional numeric suffix, then require a component boundary.
	rest := strings.TrimLeft(normalized[len(root):], "0123456789")
	return rest == "" || rest[0] == '/'
}
// textHomeScope scans free text for "/home" occurrences and reports
// whether any of them is a home-directory reference (hasHomeRef) and
// whether any such reference names account (hasAccountRef). What counts as
// a reference, and which account it names, is decided by homeAccountAt.
func textHomeScope(text, account string) (hasHomeRef, hasAccountRef bool) {
	const needle = "/home"
	offset := 0
	for offset < len(text) {
		rel := strings.Index(text[offset:], needle)
		if rel < 0 {
			break
		}
		at := offset + rel
		if name, ok := homeAccountAt(text[at:]); ok {
			hasHomeRef = true
			hasAccountRef = hasAccountRef || name == account
		}
		// Resume right after this occurrence.
		offset = at + len(needle)
	}
	return hasHomeRef, hasAccountRef
}
// homeAccountAt inspects text, which is expected to begin at a "/home"
// occurrence, and extracts the account name that follows.
//
// The boolean is true when text starts with a home root ("/home" plus
// optional digits) that either ends the text or is followed by "/". The
// returned name is the run of account-name bytes after that slash and may
// be empty. Anything else (e.g. "/homepage") returns ("", false).
func homeAccountAt(text string) (string, bool) {
	const root = "/home"
	if !strings.HasPrefix(text, root) {
		return "", false
	}

	pos := len(root)
	for ; pos < len(text); pos++ {
		if c := text[pos]; c < '0' || c > '9' {
			break
		}
	}
	switch {
	case pos == len(text):
		return "", true
	case text[pos] != '/':
		return "", false
	}

	nameStart := pos + 1
	nameEnd := nameStart
	for nameEnd < len(text) && isHomeAccountByte(text[nameEnd]) {
		nameEnd++
	}
	return text[nameStart:nameEnd], true
}
// isHomeAccountByte reports whether b may appear in an account name as
// recognized by the text scanner: ASCII letters, digits, '_', '-' and '.'.
func isHomeAccountByte(b byte) bool {
	switch {
	case 'a' <= b && b <= 'z', 'A' <= b && b <= 'Z', '0' <= b && b <= '9':
		return true
	case b == '_', b == '-', b == '.':
		return true
	}
	return false
}
// stampTenantIDIfEmpty sets TenantID to account on every finding that has
// none, modifying the slice in place and returning it. Findings that
// already carry a TenantID are left alone, and an empty account changes
// nothing. Without this, account-scoped findings would reach the correlator
// with no tenant and be keyed by weaker identities (UID, PID, file hash),
// splitting one account's incidents across several keys.
func stampTenantIDIfEmpty(findings []alert.Finding, account string) []alert.Finding {
	if account != "" {
		for i := range findings {
			if f := &findings[i]; f.TenantID == "" {
				f.TenantID = account
			}
		}
	}
	return findings
}
// GetScanHomeDirs lists the home directories a check should visit.
//
// With an account scope on ctx (see ContextWithAccountScope) the result is
// a single entry for /home/<account>, or the Stat error if that fails.
// Without a scope -- including a nil ctx from legacy callers -- it is the
// full listing of /home.
func GetScanHomeDirs(ctx context.Context) ([]os.DirEntry, error) {
	account := AccountFromContext(ctx)
	if account == "" {
		return osFS.ReadDir("/home")
	}

	fi, err := osFS.Stat(filepath.Join("/home", account))
	if err != nil {
		return nil, err
	}
	return []os.DirEntry{fakeDirEntry{fi: fi}}, nil
}
// ResolveWebRoots returns the list of directory paths CSM should scan for
// web-facing content (wp-config.php, .htaccess, public_html trees, etc.).
//
// Resolution order:
//  1. If cfg.AccountRoots is set, expand each glob and return the result.
//     Explicit config always wins.
//  2. On cPanel hosts (detected via platform.Detect), fall back to
//     /home/*/public_html for backward compatibility.
//  3. On non-cPanel hosts with no config, return an empty list. Callers
//     should treat this as "no scanning" and skip cleanly.
//
// A nil cfg is treated as "no AccountRoots configured" rather than being
// dereferenced, matching the nil tolerance of the other config helpers in
// this package.
//
// Every returned path is a glob match that stats as a directory; duplicates
// across patterns are removed and first-seen order is kept. Patterns that
// fail to expand or match nothing are skipped.
func ResolveWebRoots(cfg *config.Config) []string {
	var patterns []string
	switch {
	case cfg != nil && len(cfg.AccountRoots) > 0:
		patterns = cfg.AccountRoots
	case platform.Detect().IsCPanel():
		patterns = []string{"/home/*/public_html"}
	default:
		return nil
	}

	var roots []string
	seen := make(map[string]struct{})
	for _, pattern := range patterns {
		matches, err := osFS.Glob(pattern)
		if err != nil {
			// Best-effort: a pattern that cannot be expanded is skipped.
			continue
		}
		for _, m := range matches {
			info, err := osFS.Stat(m)
			if err != nil || !info.IsDir() {
				continue
			}
			if _, dup := seen[m]; dup {
				continue
			}
			seen[m] = struct{}{}
			roots = append(roots, m)
		}
	}
	return roots
}
// fakeDirEntry adapts an os.FileInfo so it can be handed out where an
// os.DirEntry is expected. Every method simply forwards to the wrapped
// FileInfo.
type fakeDirEntry struct {
	fi os.FileInfo
}

// Name returns the base name reported by the wrapped FileInfo.
func (e fakeDirEntry) Name() string {
	return e.fi.Name()
}

// IsDir reports whether the wrapped FileInfo describes a directory.
func (e fakeDirEntry) IsDir() bool {
	return e.fi.IsDir()
}

// Type returns only the type bits of the wrapped FileInfo's mode.
func (e fakeDirEntry) Type() os.FileMode {
	mode := e.fi.Mode()
	return mode.Type()
}

// Info returns the wrapped FileInfo; it never fails.
func (e fakeDirEntry) Info() (os.FileInfo, error) {
	return e.fi, nil
}
// makeAccountSSHKeyCheck builds a check that compares the current content
// hash of /home/<account>/.ssh/authorized_keys with the hash previously
// recorded in the store under "_ssh_user_keys:<path>".
//
// The check emits one High "ssh_keys" finding when a recorded hash exists
// and differs. It returns nothing when the file cannot be hashed, when no
// hash has been recorded yet, or when the hashes agree. It only reads the
// store; it does not record the new hash.
func makeAccountSSHKeyCheck(account string) CheckFunc {
	keyFile := filepath.Join("/home", account, ".ssh", "authorized_keys")
	stateKey := fmt.Sprintf("_ssh_user_keys:%s", keyFile)

	return func(_ context.Context, _ *config.Config, store *state.Store) []alert.Finding {
		current, err := hashFileContent(keyFile)
		if err != nil {
			return nil
		}
		previous, known := store.GetRaw(stateKey)
		if !known || previous == current {
			return nil
		}
		return []alert.Finding{{
			Severity: alert.High,
			Check:    "ssh_keys",
			Message:  fmt.Sprintf("User authorized_keys modified: %s", keyFile),
		}}
	}
}
// makeAccountCrontabCheck builds a check that reads
// /var/spool/cron/<account> and emits one Critical "suspicious_crontab"
// finding per pattern returned by MatchCrontabPatternsDeep. Each finding's
// Details carries the crontab path and its full content. An unreadable or
// missing crontab produces no findings.
func makeAccountCrontabCheck(account string) CheckFunc {
	return func(_ context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
		raw, err := osFS.ReadFile(filepath.Join("/var/spool/cron", account))
		if err != nil {
			return nil
		}
		body := string(raw)
		details := fmt.Sprintf("File: /var/spool/cron/%s\nContent:\n%s", account, body)

		var findings []alert.Finding
		for _, hit := range MatchCrontabPatternsDeep(body, cfg) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "suspicious_crontab",
				Message:  fmt.Sprintf("Suspicious pattern in crontab for %s: %s", account, hit),
				Details:  details,
			})
		}
		return findings
	}
}
// makeAccountBackdoorCheck builds a check that looks for known backdoor
// binary names inside the account's ~/.config tree (one directory level
// below .config) and emits one Critical "backdoor_binary" finding per file.
//
// The two glob patterns overlap: assuming osFS.Glob follows filepath.Glob
// semantics, ".config/*/*" also matches everything ".config/htop/*" does.
// Matches are therefore de-duplicated by path so a single file is reported
// once rather than twice.
func makeAccountBackdoorCheck(account string) CheckFunc {
	return func(_ context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
		backdoorNames := map[string]bool{
			"defunct": true, "defunct.dat": true, "gs-netcat": true,
			"gs-sftp": true, "gs-mount": true, "gsocket": true,
		}
		patterns := []string{
			filepath.Join("/home", account, ".config", "htop", "*"),
			filepath.Join("/home", account, ".config", "*", "*"),
		}

		var findings []alert.Finding
		reported := make(map[string]struct{})
		for _, pattern := range patterns {
			// Best-effort: a glob error simply yields no matches.
			matches, _ := osFS.Glob(pattern)
			for _, path := range matches {
				if !backdoorNames[filepath.Base(path)] {
					continue
				}
				if _, dup := reported[path]; dup {
					continue
				}
				reported[path] = struct{}{}

				// Size and mtime are a convenience for the operator; the
				// finding is still emitted when the file cannot be stat'ed.
				var details string
				if info, _ := osFS.Stat(path); info != nil {
					details = fmt.Sprintf("Size: %d bytes, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
				}
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "backdoor_binary",
					Message:  fmt.Sprintf("Backdoor binary found: %s", path),
					Details:  details,
					FilePath: path,
				})
			}
		}
		return findings
	}
}
// LookupUID returns the numeric UID for a system account name, or -1 when
// the account cannot be resolved.
//
// -1 is also returned when the resolved Uid is not a non-negative decimal
// number. Previously the parse error was ignored and such a value fell
// through as 0, i.e. it was indistinguishable from root.
func LookupUID(account string) int {
	u, err := user.Lookup(account)
	if err != nil {
		return -1
	}
	var uid int
	if _, err := fmt.Sscanf(u.Uid, "%d", &uid); err != nil || uid < 0 {
		return -1
	}
	return uid
}
package checks
import (
"context"
"fmt"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// accountScanTruncationContextKey is the context key under which the
// truncation collector for a scan is stored.
type accountScanTruncationContextKey struct{}
// accountScanTruncationCollector accumulates, for the duration of one scan,
// how many paths were dropped because a file cap was hit. The totals are
// turned into findings by its findings method.
type accountScanTruncationCollector struct {
// mu guards droppedBy; both record and findings hold it.
mu sync.Mutex
// droppedBy is keyed by account so operators can see which tenant hit
// which cap. The empty-string key keeps dropped paths that are not
// attributable to /home/<account>/. The inner map goes from cap value
// to the total number of paths dropped under that cap.
droppedBy map[string]map[int]int
}
// withAccountScanTruncationCollector attaches a fresh, empty truncation
// collector to ctx and returns both the derived context and the collector.
// A nil ctx is replaced with context.Background().
func withAccountScanTruncationCollector(ctx context.Context) (context.Context, *accountScanTruncationCollector) {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	c := &accountScanTruncationCollector{droppedBy: make(map[string]map[int]int)}
	return context.WithValue(parent, accountScanTruncationContextKey{}, c), c
}
// recordAccountScanTruncated adds dropped to the collector carried by ctx,
// attributed to the account scope on ctx (possibly "") and the given cap.
// It does nothing when ctx is nil, carries no collector, or when dropped or
// cap is not positive.
func recordAccountScanTruncated(ctx context.Context, dropped, cap int) {
	if ctx == nil || dropped <= 0 || cap <= 0 {
		return
	}
	// A failed type assertion leaves c nil, which covers "no collector".
	if c, _ := ctx.Value(accountScanTruncationContextKey{}).(*accountScanTruncationCollector); c != nil {
		c.record(AccountFromContext(ctx), dropped, cap)
	}
}
// recordAccountScanTruncatedPaths records droppedPaths in the collector
// carried by ctx, split per account as computed by
// accountScanTruncationAccounts. It does nothing when ctx is nil, carries
// no collector, droppedPaths is empty, or cap is not positive.
func recordAccountScanTruncatedPaths(ctx context.Context, droppedPaths []string, cap int) {
	if ctx == nil || cap <= 0 || len(droppedPaths) == 0 {
		return
	}
	// A failed type assertion leaves c nil, which covers "no collector".
	c, _ := ctx.Value(accountScanTruncationContextKey{}).(*accountScanTruncationCollector)
	if c == nil {
		return
	}
	perAccount := accountScanTruncationAccounts(ctx, droppedPaths)
	for owner, count := range perAccount {
		c.record(owner, count, cap)
	}
}
// accountScanTruncationAccounts counts paths per owning account. When ctx
// carries an account scope, every path is attributed to that account.
// Otherwise each path is attributed via accountFromHomePath, with paths
// outside any home directory counted under the empty key.
func accountScanTruncationAccounts(ctx context.Context, paths []string) map[string]int {
	scoped := AccountFromContext(ctx)
	if scoped != "" {
		return map[string]int{scoped: len(paths)}
	}

	perAccount := make(map[string]int)
	for _, p := range paths {
		owner := accountFromHomePath(p)
		perAccount[owner]++
	}
	return perAccount
}
// accountFromHomePath returns the account that owns path, i.e. the first
// path component below a home root ("/home" optionally followed by
// digits). It returns "" when the cleaned path is not under a home root,
// is the home root itself, or when that component is "." or "..".
func accountFromHomePath(path string) string {
	normalized := filepath.ToSlash(filepath.Clean(path))
	if !containsHomeReference(normalized) {
		return ""
	}

	// Drop "/home" and its optional numeric suffix; what remains must
	// start with the slash that introduces the account component.
	rest := strings.TrimLeft(normalized[len("/home"):], "0123456789")
	if rest == "" || rest[0] != '/' {
		return ""
	}

	name := rest[1:]
	if end := strings.IndexByte(name, '/'); end >= 0 {
		name = name[:end]
	}
	switch name {
	case ".", "..":
		return ""
	}
	return name
}
// record adds dropped to the running total kept for (account, cap),
// creating the per-account map on first use. Safe for concurrent callers.
func (c *accountScanTruncationCollector) record(account string, dropped, cap int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	perCap, found := c.droppedBy[account]
	if !found {
		perCap = make(map[int]int)
		c.droppedBy[account] = perCap
	}
	perCap[cap] += dropped
}
// findings converts the recorded truncations into Warning findings, one
// per (account, cap) pair, each stamped with now. Accounts are emitted in
// sorted order and caps in ascending order so output is stable across
// runs. The empty account is described as "host scan" and carries no
// TenantID. Returns nil when nothing was recorded.
func (c *accountScanTruncationCollector) findings(now time.Time) []alert.Finding {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(c.droppedBy) == 0 {
		return nil
	}

	names := make([]string, 0, len(c.droppedBy))
	for name := range c.droppedBy {
		names = append(names, name)
	}
	sort.Strings(names)

	var out []alert.Finding
	for _, name := range names {
		perCap := c.droppedBy[name]
		limits := make([]int, 0, len(perCap))
		for limit := range perCap {
			limits = append(limits, limit)
		}
		sort.Ints(limits)

		scope := "host scan"
		if name != "" {
			scope = fmt.Sprintf("account %s", name)
		}
		for _, limit := range limits {
			out = append(out, alert.Finding{
				Severity:  alert.Warning,
				Check:     "account_scan_truncated",
				TenantID:  name, // empty for the host-wide bucket
				Message:   fmt.Sprintf("Account scan truncated for %s: %d file(s) skipped past cap of %d", scope, perCap[limit], limit),
				Details:   "Raise thresholds.account_scan_max_files if recent detection coverage matters more than scan duration.",
				Timestamp: now,
			})
		}
	}
	return out
}
package checks
import (
"context"
"path/filepath"
)
// accountScopeKey is the context key under which the account scope is kept.
type accountScopeKey struct{}

// ContextWithAccountScope derives a context that limits filesystem-based
// checks to one cPanel/Linux account. The scope is read back with
// AccountFromContext (and by helpers such as GetScanHomeDirs). An empty
// account attaches nothing, leaving the context equivalent to a full host
// scan. A nil ctx is replaced with context.Background().
func ContextWithAccountScope(ctx context.Context, account string) context.Context {
	base := ctx
	if base == nil {
		base = context.Background()
	}
	if account != "" {
		return context.WithValue(base, accountScopeKey{}, account)
	}
	return base
}
// AccountFromContext returns the account scope stored on ctx by
// ContextWithAccountScope. It returns "" when ctx is nil or carries no
// scope.
func AccountFromContext(ctx context.Context) string {
	if ctx != nil {
		if account, ok := ctx.Value(accountScopeKey{}).(string); ok {
			return account
		}
	}
	return ""
}
// homeGlob expands the pattern /home/<owner>/<elem...>, where <owner> is
// the account scope carried by ctx, or "*" when ctx has no scope.
func homeGlob(ctx context.Context, elem ...string) ([]string, error) {
	owner := "*"
	if scoped := AccountFromContext(ctx); scoped != "" {
		owner = scoped
	}
	segments := make([]string, 0, 2+len(elem))
	segments = append(segments, "/home", owner)
	segments = append(segments, elem...)
	return osFS.Glob(filepath.Join(segments...))
}
package checks
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// afAlgEvent is the parsed view of a single SYSCALL record tagged with the
// csm_af_alg_socket auditd key. Fields are kept as strings because the
// downstream consumer is a human-readable Finding message.
//
// Timestamp and Serial together come from the record's msg=audit(TS:SERIAL)
// id and are what ordering between events is based on; the remaining fields
// are the values of the like-named key=value pairs on the line and are
// empty when the pair is absent.
type afAlgEvent struct {
Timestamp string // e.g. "1761826283.452"
Serial string // e.g. "91234"
UID string // value of the uid= field
AUID string // value of the auid= field
PID string // process id from the SYSCALL record; needed for live kill reaction
Comm string // value of the comm= field, surrounding quotes removed
Exe string // value of the exe= field, surrounding quotes removed
}
// AFAlgEvent is the package-public view of afAlgEvent used by callers
// outside internal/checks (the daemon's live audit-log listener emits
// findings derived from this shape). It is a type alias, so the two names
// denote the same type.
type AFAlgEvent = afAlgEvent
// ParseAFAlgEventLine is the exported entry point to parseAFAlgEvent, used
// by the daemon's live listener. It returns the parsed event and true when
// line is a SYSCALL record carrying the csm_af_alg_socket key, and a zero
// event and false otherwise.
func ParseAFAlgEventLine(line string) (AFAlgEvent, bool) {
	ev, ok := parseAFAlgEvent(line)
	return ev, ok
}
// after reports whether e is strictly newer than other. Events are ordered
// by Timestamp first (compared as float64) and, only when the Timestamp
// strings are identical, by Serial (compared as int). Parse errors are
// ignored; whatever value strconv returns alongside the error is used.
func (e afAlgEvent) after(other afAlgEvent) bool {
	if e.Timestamp == other.Timestamp {
		mine, _ := strconv.Atoi(e.Serial)
		theirs, _ := strconv.Atoi(other.Serial)
		return mine > theirs
	}
	mine, _ := strconv.ParseFloat(e.Timestamp, 64)
	theirs, _ := strconv.ParseFloat(other.Timestamp, 64)
	return mine > theirs
}
// parseAFAlgEvent turns one audit log line into an afAlgEvent. It succeeds
// only for lines that start with "type=SYSCALL ", contain the
// key="csm_af_alg_socket" tag, and carry a well-formed msg=audit(TS:SERIAL)
// id; every other line yields (zero event, false).
//
// The SYSCALL requirement matters because auditd also writes the rule's key
// into CONFIG_CHANGE records whenever the rule is added or removed (i.e. on
// every CSM restart), and those are not exploit signatures.
func parseAFAlgEvent(line string) (afAlgEvent, bool) {
	isSyscall := strings.HasPrefix(line, "type=SYSCALL ")
	if !isSyscall || !strings.Contains(line, `key="csm_af_alg_socket"`) {
		return afAlgEvent{}, false
	}

	ts, serial, ok := parseAuditMsgID(line)
	if !ok {
		return afAlgEvent{}, false
	}

	return afAlgEvent{
		Timestamp: ts,
		Serial:    serial,
		UID:       extractAuditField(line, "uid"),
		AUID:      extractAuditField(line, "auid"),
		PID:       extractAuditField(line, "pid"),
		Comm:      extractAuditField(line, "comm"),
		Exe:       extractAuditField(line, "exe"),
	}, true
}
// parseAuditMsgID pulls the (timestamp, serial) pair out of the first
// `msg=audit(TS:SERIAL)` in line. Both parts are returned as the original
// strings. The boolean is false when the marker, the closing parenthesis or
// the colon is missing, when TS does not parse as a float, or when SERIAL
// does not parse as an integer.
func parseAuditMsgID(line string) (string, string, bool) {
	const marker = "msg=audit("

	start := strings.Index(line, marker)
	if start < 0 {
		return "", "", false
	}
	body := line[start+len(marker):]

	closeAt := strings.IndexByte(body, ')')
	if closeAt < 0 {
		return "", "", false
	}
	body = body[:closeAt]

	sep := strings.IndexByte(body, ':')
	if sep < 0 {
		return "", "", false
	}
	ts, serial := body[:sep], body[sep+1:]

	// Parsed only to validate; callers get the raw strings.
	_, tsErr := strconv.ParseFloat(ts, 64)
	_, serialErr := strconv.Atoi(serial)
	if tsErr != nil || serialErr != nil {
		return "", "", false
	}
	return ts, serial, true
}
// extractAuditField returns the value of the first `key=...` pair in an
// audit log line, or "" when there is none.
//
// A match only counts at the start of the line or right after a space or
// tab, so asking for "uid" does not pick up "auid=". A value that opens
// with a double quote runs to the next double quote (and may contain
// spaces); an unterminated quote yields "". Any other value runs to the
// next space or tab, or to the end of the line.
func extractAuditField(line, key string) string {
	needle := key + "="
	searchFrom := 0
	for {
		rel := strings.Index(line[searchFrom:], needle)
		if rel < 0 {
			return ""
		}
		at := searchFrom + rel

		if at > 0 {
			if prev := line[at-1]; prev != ' ' && prev != '\t' {
				// Suffix of a longer key; keep looking further right.
				searchFrom = at + 1
				continue
			}
		}

		value := line[at+len(needle):]
		if strings.HasPrefix(value, `"`) {
			value = value[1:]
			if closing := strings.IndexByte(value, '"'); closing >= 0 {
				return value[:closing]
			}
			return ""
		}
		if ws := strings.IndexAny(value, " \t"); ws >= 0 {
			return value[:ws]
		}
		return value
	}
}
const (
	// afAlgLogPath is the audit log that is searched for tagged events.
	afAlgLogPath = "/var/log/audit/audit.log"
	// afAlgCursorKey is the state.Store key holding the newest event
	// already reported, encoded as "timestamp:serial".
	afAlgCursorKey = "_af_alg_last_seen"
)

// CheckAFAlgSocketUsage looks for csm_af_alg_socket events in the audit log
// and emits one Critical finding for every event strictly newer than the
// stored cursor. With no cursor stored (first run) every event found is
// reported: AF_ALG from userland is treated as an exploit signature for
// CVE-2026-31431 ("Copy Fail"), not a baseline metric, so silently seeding
// the cursor would hide a pre-existing compromise. The cursor lives in the
// state store, which is what prevents duplicates on later sweeps and across
// daemon restarts.
//
// The log is pre-filtered with grep so the (potentially very large) file is
// never loaded whole. RunAllowNonZero is used because grep exits 1 on "no
// match", the healthy default; with non-zero exits swallowed, empty output
// simply means there is nothing to do.
func CheckAFAlgSocketUsage(_ context.Context, _ *config.Config, st *state.Store) []alert.Finding {
	raw, err := cmdExec.RunAllowNonZero("grep", "-a", "csm_af_alg_socket", afAlgLogPath)
	if err != nil || len(raw) == 0 {
		return nil
	}

	storedCursor, hasCursor := st.GetRaw(afAlgCursorKey)
	cursor := decodeCursor(storedCursor)

	// newest tracks the event the cursor should advance to.
	newest, haveNewest := cursor, hasCursor

	var findings []alert.Finding
	for _, line := range strings.Split(string(raw), "\n") {
		ev, ok := parseAFAlgEvent(line)
		if !ok || (hasCursor && !ev.after(cursor)) {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "af_alg_socket_use",
			Message:  fmt.Sprintf("AF_ALG socket opened by uid=%s exe=%s", ev.UID, ev.Exe),
			Details: fmt.Sprintf(
				"Audit event: timestamp=%s serial=%s\nauid=%s uid=%s comm=%q exe=%q\n"+
					"AF_ALG is essentially never used by cPanel/PHP workloads. This is\n"+
					"the kernel-level exploit signature for CVE-2026-31431 (\"Copy Fail\").\n"+
					"Investigate this process immediately and consider unloading algif_aead\n"+
					"(modprobe -r algif_aead af_alg) and adding a modprobe.d blacklist.",
				ev.Timestamp, ev.Serial, ev.AUID, ev.UID, ev.Comm, ev.Exe,
			),
		})
		if !haveNewest || ev.after(newest) {
			newest, haveNewest = ev, true
		}
	}

	if haveNewest {
		st.SetRaw(afAlgCursorKey, encodeCursor(newest))
	}
	return findings
}
// encodeCursor serializes an event's position as "timestamp:serial", the
// form decodeCursor reads back.
func encodeCursor(ev afAlgEvent) string {
	return ev.Timestamp + ":" + ev.Serial
}
// decodeCursor parses a "timestamp:serial" cursor into an afAlgEvent that
// has only Timestamp and Serial set, splitting at the first colon. An empty
// string or one without a colon yields the zero event.
func decodeCursor(s string) afAlgEvent {
	sep := strings.Index(s, ":")
	if sep < 0 {
		// Also covers the empty string.
		return afAlgEvent{}
	}
	return afAlgEvent{Timestamp: s[:sep], Serial: s[sep+1:]}
}
package checks
import (
"context"
"errors"
"fmt"
"os"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// EnforceAction is the single outcome of the pure enforcement decision.
// Each value maps to one operational step taken by the impure wrapper.
type EnforceAction int

const (
	// EnforceActionNoop: nothing to do.
	EnforceActionNoop EnforceAction = iota
	// EnforceActionRestoreMarker: rewrite the marker file only.
	EnforceActionRestoreMarker
	// EnforceActionUnloadModules: unload the kernel modules only.
	EnforceActionUnloadModules
	// EnforceActionRestoreAndUnload: rewrite the marker and unload modules.
	EnforceActionRestoreAndUnload
)

// decideAFAlgEnforcement is the pure, deterministic core of the enforcement
// check: it maps the observed marker/module state to exactly one action.
//
// Inputs:
//   - markerPresent: /etc/modprobe.d/csm-copy-fail-mitigation.conf exists.
//   - markerContentValid: the file's content matches the canonical
//     CSM-managed content (a hand-edited file therefore triggers a rewrite).
//   - loaded: algif_aead OR af_alg is currently loaded.
//
// Without a marker the result is always Noop, even when the modules are
// loaded: the operator has not opted in to enforcement, so modules they may
// legitimately use are not unloaded unilaterally. The hardening audit and
// the auditd tripwire still surface that gap.
func decideAFAlgEnforcement(markerPresent, markerContentValid, loaded bool) EnforceAction {
	switch {
	case !markerPresent:
		return EnforceActionNoop
	case markerContentValid:
		if loaded {
			return EnforceActionUnloadModules
		}
		return EnforceActionNoop
	case loaded:
		return EnforceActionRestoreAndUnload
	default:
		return EnforceActionRestoreMarker
	}
}
// afAlgMarkerPath is the canonical location of the CSM-managed mitigation
// marker. Its presence is the signal that operator-driven enforcement is
// active for this host; when it is absent the enforcer takes no action.
const afAlgMarkerPath = "/etc/modprobe.d/csm-copy-fail-mitigation.conf"
// canonicalAFAlgMarker is the byte-exact content the enforcer writes and
// re-asserts on drift. Hand-written variants (`blacklist algif_aead`, etc.)
// still satisfy the hardening audit, but the enforcer rewrites them to this
// canonical form so the file's content can be trivially validated:
// validateMarkerContent accepts a file only when it equals this string
// exactly, so any edit here (including whitespace) changes what counts as a
// valid marker.
const canonicalAFAlgMarker = `# CSM Copy Fail (CVE-2026-31431) mitigation — managed by CSM.
# Restored automatically by the af_alg_enforce critical-tier check.
# Remove this file (and run ` + "`csm harden --copy-fail`" + ` again) if you
# need to re-enable AF_ALG.
install algif_aead /bin/false
install af_alg /bin/false
`
// EnforceResult describes what enforceAFAlgBlocked observed and did, in a
// shape both the CLI subcommand and the periodic Check function can format
// for the operator without re-deriving the same conclusions.
//
// ModuleUnloaded reports the OBSERVED post-call state, not the syscall
// attempt: it is true only when /proc/modules no longer contains the
// targeted modules after `modprobe -r` ran. Use this field to distinguish
// "unload succeeded" from "unload attempted but module is in use".
type EnforceResult struct {
Action EnforceAction // decision returned by decideAFAlgEnforcement for the observed state
MarkerPresent bool // marker file existed at the start of the call
MarkerValid bool // marker was readable and matched the canonical content
ModulesLoaded []string // names of currently-loaded targeted modules at start of call
MarkerWritten bool // wrapper wrote/restored the marker file this call
ModuleUnloaded bool // post-call /proc/modules shows targeted modules gone
Notes []string // operator-readable lines (warnings, stuck-module names)
}
// validateMarkerContent reports whether data is byte-for-byte identical to
// canonicalAFAlgMarker.
func validateMarkerContent(data []byte) bool {
	content := string(data)
	return content == canonicalAFAlgMarker
}
// loadedTargetedModules returns the subset of {algif_aead, af_alg} that
// appears in the module list reported by loadModuleList, in the order that
// list yields them. It is called before an unload (to decide what to do)
// and again afterwards (to verify the unload took effect).
func loadedTargetedModules() []string {
	var found []string
	for _, name := range loadModuleList() {
		switch name {
		case "algif_aead", "af_alg":
			found = append(found, name)
		}
	}
	return found
}
// enforceAFAlgBlocked observes the marker file and the loaded modules, asks
// decideAFAlgEnforcement what to do, and carries the action out through
// osFS and cmdExec.
//
// Returned errors: an unexpected Stat failure on the marker (wrapped) and a
// WriteFile failure while restoring it. In both cases the partially filled
// result is returned alongside. The outcome of `modprobe -r` is never
// reported as an error; it is judged by re-reading the module list, because
// RunAllowNonZero swallows the non-zero exit and that list is the only
// reliable signal that the unload happened.
func enforceAFAlgBlocked() (EnforceResult, error) {
	var res EnforceResult

	// A missing marker is the expected "advisory mode" (operator has not
	// opted in). Any other Stat failure, e.g. EACCES on a hardened
	// /etc/modprobe.d/, is surfaced instead of being mistaken for it.
	_, statErr := osFS.Stat(afAlgMarkerPath)
	if statErr == nil {
		res.MarkerPresent = true
		// An unreadable marker is left classified as invalid.
		if content, readErr := osFS.ReadFile(afAlgMarkerPath); readErr == nil {
			res.MarkerValid = validateMarkerContent(content)
		}
	} else if !errors.Is(statErr, os.ErrNotExist) {
		return res, fmt.Errorf("stat %s: %w", afAlgMarkerPath, statErr)
	}

	res.ModulesLoaded = loadedTargetedModules()
	res.Action = decideAFAlgEnforcement(res.MarkerPresent, res.MarkerValid, len(res.ModulesLoaded) > 0)

	restore := res.Action == EnforceActionRestoreMarker || res.Action == EnforceActionRestoreAndUnload
	unload := res.Action == EnforceActionUnloadModules || res.Action == EnforceActionRestoreAndUnload

	if restore {
		if err := osFS.WriteFile(afAlgMarkerPath, []byte(canonicalAFAlgMarker), 0o644); err != nil {
			return res, err
		}
		res.MarkerWritten = true
	}
	if !unload {
		return res, nil
	}

	// Exit status and output are deliberately discarded: modprobe exits 1
	// when a module is in use and prints its "FATAL: module in use"
	// message on stderr, which the helper does not capture. The module
	// list is re-read below to learn what actually happened.
	_, _ = cmdExec.RunAllowNonZero("modprobe", "-r", "algif_aead", "af_alg")

	remaining := loadedTargetedModules()
	if len(remaining) > 0 {
		res.Notes = append(res.Notes, fmt.Sprintf(
			"modprobe -r attempted but %v still loaded — module is in use; will retry next tick",
			remaining,
		))
		return res, nil
	}
	res.ModuleUnloaded = true
	return res, nil
}
// AFAlgMarkerPath returns the canonical marker file location. It exists
// for the cmd/csm CLI, which prints the path to operators; code inside this
// package should use the unexported constant directly.
func AFAlgMarkerPath() string {
	return afAlgMarkerPath
}
// WriteAFAlgMarker forces the canonical marker content to disk regardless
// of current state. Used by `csm harden --copy-fail` to ensure subsequent
// EnforceAFAlgBlocked() runs see a valid marker even on first install.
func WriteAFAlgMarker() error {
return osFS.WriteFile(afAlgMarkerPath, []byte(canonicalAFAlgMarker), 0o644)
}
// EnforceAFAlgBlocked is the exported entry point to enforceAFAlgBlocked,
// provided for cmd/csm. It forwards the call unchanged; the periodic check
// inside this package calls the unexported form directly.
func EnforceAFAlgBlocked() (EnforceResult, error) {
	return enforceAFAlgBlocked()
}
// CheckAFAlgEnforcement is the periodic check that applies the AF_ALG
// mitigation policy by running enforceAFAlgBlocked once per invocation.
//
// Behaviour:
//   - cfg.AutoResponse.DisableEnforceAFAlg set: nothing is run, nil is returned.
//   - enforcement returned an error: one Warning finding carrying the error
//     and the partial result.
//   - enforcement decided on EnforceActionNoop (steady state): nil.
//   - any other action: one Warning finding describing what was observed and
//     what was changed, so the operator has an alert-pipeline record that
//     system state was modified.
//
// Warning is used because it is the lowest level in alert.Severity
// (Warning < High < Critical; there is no Info level). The context and state
// store parameters are unused.
func CheckAFAlgEnforcement(_ context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	const checkName = "af_alg_enforcement_corrected"

	// A nil cfg counts as "not disabled".
	if disabled := cfg != nil && cfg.AutoResponse.DisableEnforceAFAlg; disabled {
		return nil
	}

	res, err := enforceAFAlgBlocked()
	switch {
	case err != nil:
		return []alert.Finding{{
			Severity: alert.Warning,
			Check:    checkName,
			Message:  "AF_ALG enforcement encountered an error",
			Details:  fmt.Sprintf("error: %v\nresult: %+v", err, res),
		}}
	case res.Action == EnforceActionNoop:
		return nil
	}

	action := actionName(res.Action)
	details := fmt.Sprintf(
		"Action: %s\nMarker present: %v\nMarker valid: %v\nModules loaded: %v\nMarker written: %v\nModule unload succeeded: %v\nNotes: %v",
		action, res.MarkerPresent, res.MarkerValid, res.ModulesLoaded,
		res.MarkerWritten, res.ModuleUnloaded, res.Notes,
	)
	corrected := alert.Finding{
		Severity: alert.Warning,
		Check:    checkName,
		Message:  fmt.Sprintf("AF_ALG enforcement re-applied (%s)", action),
		Details:  details,
	}
	return []alert.Finding{corrected}
}
// actionName maps an EnforceAction to the short label used in finding
// messages. Values without a known label render as "Unknown(<n>)" with the
// numeric value.
func actionName(a EnforceAction) string {
	var label string
	switch a {
	case EnforceActionNoop:
		label = "Noop"
	case EnforceActionRestoreMarker:
		label = "RestoreMarker"
	case EnforceActionUnloadModules:
		label = "UnloadModules"
	case EnforceActionRestoreAndUnload:
		label = "RestoreAndUnload"
	default:
		label = fmt.Sprintf("Unknown(%d)", int(a))
	}
	return label
}
package checks
import (
"bytes"
"compress/gzip"
"os"
"strings"
)
// copyFailCVE is the CVE identifier of the Copy Fail vulnerability. It is
// matched as a literal substring against `kcarectl --patch-info` output to
// detect a KernelCare/kpatch livepatch, and it is embedded in the
// operator-visible kernel-state description.
const copyFailCVE = "CVE-2026-31431"
// configHasBuiltInAFAlgAEAD reports whether the given kernel config text
// contains a line that, once surrounding whitespace is trimmed, is exactly
// CONFIG_CRYPTO_USER_API_AEAD=y, i.e. the AEAD user API is compiled into
// the kernel image rather than built as a loadable module (=m).
//
// On =y kernels the modprobe blacklist mitigation is ineffective because
// there is no module to block from loading; the code is always present.
//
// Only an exact match counts. The kconfig "# ... is not set" comment form,
// an =m setting, and an absent option all yield false.
func configHasBuiltInAFAlgAEAD(configText string) bool {
	const builtInLine = "CONFIG_CRYPTO_USER_API_AEAD=y"

	remaining := configText
	for {
		// Peel off one newline-terminated line per iteration; the final
		// segment (after the last newline, possibly empty) is checked too.
		line := remaining
		nl := strings.IndexByte(remaining, '\n')
		if nl >= 0 {
			line, remaining = remaining[:nl], remaining[nl+1:]
		}
		if strings.TrimSpace(line) == builtInLine {
			return true
		}
		if nl < 0 {
			return false
		}
	}
}
// configHasModularAFAlgAEAD reports whether the given kernel config text
// contains a line that, once surrounding whitespace is trimmed, is exactly
// CONFIG_CRYPTO_USER_API_AEAD=m (the AEAD user API is a loadable module).
//
// It feeds the "is this host actually exploitable?" policy decision: a
// kernel built with neither =y nor =m has no AF_ALG aead interface at all,
// so Copy Fail is not reachable on it.
func configHasModularAFAlgAEAD(configText string) bool {
	const modularLine = "CONFIG_CRYPTO_USER_API_AEAD=m"

	lines := strings.Split(configText, "\n")
	for i := range lines {
		if strings.TrimSpace(lines[i]) == modularLine {
			return true
		}
	}
	return false
}
// kcareReportsCopyFailPatched reports whether `kcarectl --patch-info`
// output mentions the Copy Fail CVE anywhere. kcarectl emits per-patch
// records such as:
//
// kpatch-name: rhel8/.../CVE-2026-NNNNN-foo.patch
// kpatch-cve: CVE-2026-NNNNN
// kpatch-cve-url: ...
//
// The check is a plain substring search for the literal CVE id rather than
// a parse of a specific field, so a format change to the URL or filename
// lines does not silently break detection.
func kcareReportsCopyFailPatched(out []byte) bool {
	needle := []byte(copyFailCVE)
	return bytes.Index(out, needle) >= 0
}
// kernelHasBuiltInAFAlgAEAD is the impure counterpart of
// configHasBuiltInAFAlgAEAD: it obtains the running kernel's config through
// readKernelConfigText and reports whether the AEAD interface is statically
// linked.
//
// When no config text is available (unreadable or empty) it returns
// (false, nil). Per the original design note, callers treat "config
// unknown" the same as "modular". The error result is always nil in this
// implementation.
func kernelHasBuiltInAFAlgAEAD() (bool, error) {
	text, _ := readKernelConfigText()
	if text != "" {
		return configHasBuiltInAFAlgAEAD(text), nil
	}
	return false, nil
}
// readKernelConfigText returns the running kernel's .config text and
// whether it could be obtained.
//
// Sources, in order:
//  1. /boot/config-<release>, where <release> comes from readKernelRelease.
//  2. /proc/config.gz, gunzipped in memory.
//
// The second source is tried whenever the first did not produce a result,
// whether the release lookup or the file read failed. ("", false) is
// returned when neither source is readable or the gzip stream cannot be
// decoded, leaving it to the caller to decide whether "unknown" means
// conservatively vulnerable.
func readKernelConfigText() (string, bool) {
	release, relErr := readKernelRelease()
	if relErr == nil {
		bootCfg, readErr := osFS.ReadFile("/boot/config-" + release)
		if readErr == nil {
			return string(bootCfg), true
		}
	}

	compressed, err := osFS.ReadFile("/proc/config.gz")
	if err != nil {
		return "", false
	}
	gz, err := gzip.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return "", false
	}
	// Close errors carry no useful signal for an in-memory reader.
	defer func() { _ = gz.Close() }()

	var text bytes.Buffer
	if _, err := text.ReadFrom(gz); err != nil {
		return "", false
	}
	return text.String(), true
}
// readKernelRelease reads /proc/sys/kernel/osrelease and returns its
// content with surrounding whitespace trimmed, the same value `uname -r`
// would print. It is used to locate the matching config-* file under
// /boot. A read failure is returned unchanged with an empty string.
func readKernelRelease() (string, error) {
	raw, err := osFS.ReadFile("/proc/sys/kernel/osrelease")
	if err == nil {
		return strings.TrimSpace(string(raw)), nil
	}
	return "", err
}
// kcareHasCopyFailPatch runs `kcarectl --patch-info` and reports whether
// its output advertises a livepatch covering Copy Fail.
//
// It returns false when the command cannot be run, when it produces no
// output, or when the output does not mention the CVE. KernelCare is
// optional, so its absence is not treated as an error condition; the
// function has no error result.
func kcareHasCopyFailPatch() bool {
	out, err := cmdExec.RunAllowNonZero("kcarectl", "--patch-info")
	if err != nil || len(out) == 0 {
		return false
	}
	return kcareReportsCopyFailPatched(out)
}
// AFAlgKernelState is the assembled view of how the running kernel
// exposes AF_ALG. Used by the hardening audit, by csm harden, and by
// the live-monitor coordinator to decide whether protection is needed.
//
// BuiltIn and Modular are only meaningful when ConfigReadable is true;
// when the config could not be read they stay false, which means
// "unknown" rather than "absent".
type AFAlgKernelState struct {
BuiltIn bool // CONFIG_CRYPTO_USER_API_AEAD=y in the running kernel
Modular bool // CONFIG_CRYPTO_USER_API_AEAD=m (loadable module exists)
ConfigReadable bool // /boot/config-$(uname -r) or /proc/config.gz was parseable
LivepatchActive bool // KernelCare/kpatch has applied a CVE-2026-31431 patch
}
// observeAFAlgKernelState builds an AFAlgKernelState snapshot from the
// impure probes: the kernel config (for BuiltIn/Modular/ConfigReadable)
// and kcarectl (for LivepatchActive). The fields record what was
// determined and what could not be, so callers can apply policy without
// repeating the probes.
func observeAFAlgKernelState() AFAlgKernelState {
	cfgText, readable := readKernelConfigText()

	snapshot := AFAlgKernelState{ConfigReadable: readable}
	if readable {
		snapshot.BuiltIn = configHasBuiltInAFAlgAEAD(cfgText)
		snapshot.Modular = configHasModularAFAlgAEAD(cfgText)
	}
	snapshot.LivepatchActive = kcareHasCopyFailPatch()
	return snapshot
}
// IsCopyFailExploitable reports whether the observed kernel should be
// treated as vulnerable to Copy Fail (CVE-2026-31431). The daemon's
// live-monitor coordinator uses it to skip starting the listener on hosts
// that do not need protection.
//
// The result is false in exactly two situations:
//   - a livepatch is active (treated as patched regardless of the config), or
//   - the config was readable and declares the AEAD interface neither
//     built-in nor modular, so there is nothing to exploit.
//
// Everything else, including an unreadable config, is reported as
// exploitable: better to over-monitor than to miss a vulnerable host.
func (s AFAlgKernelState) IsCopyFailExploitable() bool {
	patched := s.LivepatchActive
	// Definitively built without the AF_ALG aead interface.
	interfaceAbsent := s.ConfigReadable && !(s.BuiltIn || s.Modular)
	return !patched && !interfaceAbsent
}
// String renders the kernel state as one short sentence for
// operator-visible messages (for example a single AuditResult.Message
// line). The first matching condition wins, in this order: livepatch
// active, built-in, modular, config readable but interface absent, config
// unreadable.
func (s AFAlgKernelState) String() string {
	if s.LivepatchActive {
		return "kernel patched by KernelCare (livepatch active for " + copyFailCVE + ")"
	}
	if s.BuiltIn {
		return "AF_ALG is built into the kernel (CONFIG_CRYPTO_USER_API_AEAD=y) and no livepatch is active"
	}
	if s.Modular {
		return "AF_ALG aead is a loadable module on this kernel"
	}
	if s.ConfigReadable {
		return "AF_ALG aead is not present in this kernel build"
	}
	return "kernel config unreadable; treating as potentially vulnerable"
}
// ObserveAFAlgKernelState is the exported entry point to
// observeAFAlgKernelState, provided for cmd/csm and the daemon. It forwards
// the call unchanged; code inside this package uses the unexported form.
func ObserveAFAlgKernelState() AFAlgKernelState {
	return observeAFAlgKernelState()
}
// ErrKernelConfigUnreadable is a sentinel error value for the
// "kernel-config file absent" case. It is an alias of os.ErrNotExist, so
// callers can test for it with os.IsNotExist or errors.Is.
// NOTE(review): no function in the visible part of this file returns it —
// confirm it is still used before relying on it.
var ErrKernelConfigUnreadable = os.ErrNotExist
package checks
import (
"fmt"
"os"
"path/filepath"
"strings"
)
// SeccompDropInBaseName is the file name CSM writes inside each unit's
// /etc/systemd/system/<unit>.d/ override directory. The name is kept
// stable so the hardening audit can scan for it and the remove path can
// clean up without guessing. Presence of a file with this name is what
// the package treats as "this unit is covered".
const SeccompDropInBaseName = "csm-copy-fail-seccomp.conf"
// seccompDropInContent is the byte-exact body of every drop-in CSM
// writes. The leading comment lines inside the file let a curious
// sysadmin understand who manages it without consulting external docs;
// the only directive is RestrictAddressFamilies=~AF_ALG in the [Service]
// section.
const seccompDropInContent = `# CSM Copy Fail (CVE-2026-31431) seccomp mitigation - managed by CSM.
# Blocks socket(AF_ALG, ...) for processes spawned by this unit.
# Remove this file to disable.
[Service]
RestrictAddressFamilies=~AF_ALG
`
// afAlgSeccompCandidateUnits is the catalog of systemd units that, on
// shared-hosting servers, regularly spawn untrusted user-level code
// and therefore need the AF_ALG block. Units that do not exist on the
// running host are filtered out at apply time.
//
// The list is intentionally inclusive: a unit that does not exist
// gets no drop-in because the apply path checks each candidate with
// systemctl list-unit-files before writing anything.
//
// The scan, apply and remove functions all iterate this slice in
// declaration order.
var afAlgSeccompCandidateUnits = []string{
// Web servers
"lshttpd.service", // LiteSpeed (cPanel default)
"httpd.service", // Apache (RHEL family, cPanel EA4 fallback)
"apache2.service", // Apache (Debian/Ubuntu)
"nginx.service", // Nginx
// PHP-FPM, cPanel EA4
"ea-php72-php-fpm.service",
"ea-php73-php-fpm.service",
"ea-php74-php-fpm.service",
"ea-php80-php-fpm.service",
"ea-php81-php-fpm.service",
"ea-php82-php-fpm.service",
"ea-php83-php-fpm.service",
"ea-php84-php-fpm.service",
"cpanel_php_fpm.service",
// PHP-FPM, generic distro versions
"php-fpm.service",
"php7.4-fpm.service",
"php8.0-fpm.service",
"php8.1-fpm.service",
"php8.2-fpm.service",
"php8.3-fpm.service",
"php8.4-fpm.service",
// Cron and mail (drop privileges to user before running content
// filters or scheduled scripts)
"crond.service",
"cron.service",
"exim.service",
"dovecot.service",
}
// SeccompUnitState describes one unit's mitigation status in operator
// terms. Returned by ScanAFAlgSeccompState so both the CLI and the
// hardening audit can render the same view without re-deriving it.
//
// Exists and HasFile are probed independently, so HasFile can be true
// for a unit that systemd does not currently know about.
type SeccompUnitState struct {
Unit string // e.g. "lshttpd.service"
Exists bool // unit is registered with systemd on this host
HasFile bool // CSM-managed drop-in is present on disk
}
// ScanAFAlgSeccompState probes every entry of afAlgSeccompCandidateUnits
// and returns one SeccompUnitState per candidate, in catalog order. Units
// that systemd does not know are still reported (Exists=false), so the
// operator can confirm CSM did not silently skip something they expected.
func ScanAFAlgSeccompState() []SeccompUnitState {
	var states []SeccompUnitState
	for _, unit := range afAlgSeccompCandidateUnits {
		registered := systemdUnitExists(unit)
		dropIn := seccompDropInPresent(unit)
		states = append(states, SeccompUnitState{
			Unit:    unit,
			Exists:  registered,
			HasFile: dropIn,
		})
	}
	return states
}
// SeccompCoverageSummary collapses the per-unit scan into three lists of
// unit names: existing units that have the CSM drop-in, existing units
// that do not, and candidate units that are not registered with systemd
// on this host. Each candidate appears in exactly one list.
type SeccompCoverageSummary struct {
Covered []string // existing units with the CSM drop-in
Uncovered []string // existing units without the drop-in
NotInstalled []string // candidate units not registered with systemd
}
// SummarizeAFAlgSeccompCoverage runs ScanAFAlgSeccompState and sorts each
// candidate unit into one of the three SeccompCoverageSummary lists. A unit
// that is not registered with systemd goes to NotInstalled even if a
// drop-in file exists for it. Used by the hardening audit and the CLI
// status output.
func SummarizeAFAlgSeccompCoverage() SeccompCoverageSummary {
	var summary SeccompCoverageSummary
	for _, st := range ScanAFAlgSeccompState() {
		if !st.Exists {
			summary.NotInstalled = append(summary.NotInstalled, st.Unit)
			continue
		}
		if st.HasFile {
			summary.Covered = append(summary.Covered, st.Unit)
			continue
		}
		summary.Uncovered = append(summary.Uncovered, st.Unit)
	}
	return summary
}
// ApplyAFAlgSeccompDropIns writes the canonical drop-in for every candidate
// unit that is registered with systemd on this host and does not already
// have the file. If at least one file was written it then runs
// `systemctl daemon-reload` followed by `systemctl try-restart` for each
// touched unit, so the seccomp filter takes effect immediately.
//
// It returns the units that received a new drop-in during this call. A nil
// list with a nil error means everything was already covered (idempotent
// re-run). On failure the units written so far are returned with the
// error, and the remaining steps are not attempted.
func ApplyAFAlgSeccompDropIns() ([]string, error) {
	var touched []string
	for _, unit := range afAlgSeccompCandidateUnits {
		// Skip units systemd does not know, and units already covered.
		if !systemdUnitExists(unit) || seccompDropInPresent(unit) {
			continue
		}
		err := writeSeccompDropIn(unit)
		if err != nil {
			return touched, fmt.Errorf("write drop-in for %s: %w", unit, err)
		}
		touched = append(touched, unit)
	}

	if len(touched) == 0 {
		return nil, nil
	}

	_, reloadErr := cmdExec.Run("systemctl", "daemon-reload")
	if reloadErr != nil {
		return touched, fmt.Errorf("systemctl daemon-reload: %w", reloadErr)
	}

	for _, unit := range touched {
		// try-restart performs a full restart (re-exec of the master
		// process) only if the unit is currently active. A reload (e.g.
		// SIGUSR2 to PHP-FPM) is NOT enough: systemd attaches seccomp
		// filters at process spawn time, so the existing master has to be
		// replaced for RestrictAddressFamilies to apply to it and its
		// workers. try-restart skips units that were intentionally
		// stopped, so nothing is surprise-started.
		_, restartErr := cmdExec.RunAllowNonZero("systemctl", "try-restart", unit)
		if restartErr != nil {
			return touched, fmt.Errorf("systemctl try-restart %s: %w", unit, restartErr)
		}
	}
	return touched, nil
}
// RemoveAFAlgSeccompDropIns deletes the CSM-managed seccomp drop-in of
// every candidate unit that has one on disk. If anything was removed it
// runs `systemctl daemon-reload` followed by `systemctl try-restart` for
// each affected unit that systemd still knows, so the seccomp filter is
// dropped from running processes. Idempotent: units without the drop-in
// are skipped.
//
// It returns the units whose drop-in was removed. On failure the units
// handled so far are returned with the error.
func RemoveAFAlgSeccompDropIns() ([]string, error) {
	var cleaned []string
	for _, unit := range afAlgSeccompCandidateUnits {
		if !seccompDropInPresent(unit) {
			continue
		}
		// A file that vanished between the presence check and the removal
		// is not an error.
		rmErr := osFS.Remove(seccompDropInPath(unit))
		if rmErr != nil && !os.IsNotExist(rmErr) {
			return cleaned, fmt.Errorf("remove drop-in for %s: %w", unit, rmErr)
		}
		// Best-effort: drop the .d directory too. Removing it only
		// succeeds when it is empty; a failure is harmless.
		_ = osFS.Remove(seccompDropInDir(unit))
		cleaned = append(cleaned, unit)
	}

	if len(cleaned) == 0 {
		return nil, nil
	}

	_, reloadErr := cmdExec.Run("systemctl", "daemon-reload")
	if reloadErr != nil {
		return cleaned, fmt.Errorf("systemctl daemon-reload: %w", reloadErr)
	}

	for _, unit := range cleaned {
		if !systemdUnitExists(unit) {
			continue
		}
		// Same reasoning as the apply path: a full re-exec is required so
		// the seccomp filter is dropped from the running master.
		_, restartErr := cmdExec.RunAllowNonZero("systemctl", "try-restart", unit)
		if restartErr != nil {
			return cleaned, fmt.Errorf("systemctl try-restart %s: %w", unit, restartErr)
		}
	}
	return cleaned, nil
}
// seccompDropInDir builds the systemd override directory for a unit:
// /etc/systemd/system/<unit>.d
func seccompDropInDir(unit string) string {
	const overrideRoot = "/etc/systemd/system"
	dirName := unit + ".d"
	return filepath.Join(overrideRoot, dirName)
}
// seccompDropInPath builds the full path of the CSM-managed drop-in file
// for a unit: the unit's override directory joined with
// SeccompDropInBaseName.
func seccompDropInPath(unit string) string {
	dir := seccompDropInDir(unit)
	return filepath.Join(dir, SeccompDropInBaseName)
}
// seccompDropInPresent reports whether a file exists at the canonical
// drop-in path for the unit. The content is not inspected; presence at
// that path is the policy signal. Any Stat error, not only "not found",
// counts as absent.
func seccompDropInPresent(unit string) bool {
	if _, err := osFS.Stat(seccompDropInPath(unit)); err != nil {
		return false
	}
	return true
}
// writeSeccompDropIn creates the override directory and writes the
// canonical drop-in content. Errors propagate to the caller.
func writeSeccompDropIn(unit string) error {
dir := seccompDropInDir(unit)
if err := osFS.MkdirAll(dir, 0o755); err != nil {
return err
}
return osFS.WriteFile(seccompDropInPath(unit), []byte(seccompDropInContent), 0o644)
}
// systemdUnitExists asks `systemctl list-unit-files` whether the given
// unit is known to the running systemd. It returns false on any command
// error, including a missing systemctl binary, so a non-systemd host is
// treated as having no units to mitigate.
func systemdUnitExists(unit string) bool {
	out, err := cmdExec.RunAllowNonZero(
		"systemctl", "list-unit-files", unit, "--no-legend", "--no-pager",
	)
	if err != nil {
		return false
	}
	// list-unit-files prints one row per match; output that is empty after
	// trimming means the name is not registered with this systemd instance.
	listing := strings.TrimSpace(string(out))
	return listing != ""
}
package checks
import (
"context"
"encoding/json"
"fmt"
"net"
"path/filepath"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"gopkg.in/yaml.v3"
)
// CheckShadowChanges detects modifications of /etc/shadow between runs.
//
// It compares the file's mtime with the one stored under "_shadow_mtime"
// and only produces findings when the file is newer than the stored value.
// Possible findings:
//   - root_password_change (Critical) when "root" is among the changed users;
//   - bulk_password_change (High) when 5 or more non-root entries changed;
//   - otherwise shadow_change, Critical by default and downgraded to Warning
//     while the current time is inside the configured UPCP window.
//
// A change attributed to infrastructure IPs (isInfraShadowChange) produces
// no finding. Whenever /etc/shadow can be stat'ed, the current mtime, the
// aggregate hash and the per-user hashes are written back to the store
// before returning, including on the suppressed path.
//
// Returns nil when /etc/shadow cannot be stat'ed. The ctx parameter is not
// used by this function.
func CheckShadowChanges(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
var findings []alert.Finding
info, err := osFS.Stat("/etc/shadow")
if err != nil {
return nil
}
mtime := info.ModTime()
mtimeKey := "_shadow_mtime"
hashKey := "_shadow_hash"
// Load current shadow entries (user:hash pairs, no sensitive data stored)
currentEntries := parseShadowUsers()
// Aggregate fingerprint of all entries, built from the map's formatted
// representation.
currentHash := hashBytes([]byte(fmt.Sprintf("%v", currentEntries)))
prevMtimeRaw, mtimeExists := store.GetRaw(mtimeKey)
prevHash, hashExists := store.GetRaw(hashKey)
// Without a stored mtime (first run) there is nothing to compare against;
// fall through to storing the baseline.
if mtimeExists {
var lastMtime time.Time
if err := json.Unmarshal([]byte(prevMtimeRaw), &lastMtime); err == nil {
if mtime.After(lastMtime) {
// Shadow file was modified - find what changed
var details string
if hashExists && prevHash != currentHash {
changed := diffShadowChanges(store, currentEntries)
if len(changed) > 0 {
details = fmt.Sprintf("Previous: %s\nCurrent: %s\nAccounts changed: %s",
lastMtime.Format("2006-01-02 15:04:05"),
mtime.Format("2006-01-02 15:04:05"),
strings.Join(changed, ", "))
}
}
// No per-account diff available: report the two timestamps only.
if details == "" {
details = fmt.Sprintf("Previous: %s\nCurrent: %s",
lastMtime.Format("2006-01-02 15:04:05"),
mtime.Format("2006-01-02 15:04:05"))
}
// Check if within upcp window
// NOTE(review): the comparison below assumes start <= end within one
// day; a window spanning midnight would never match — confirm against
// parseTimeMin and the config format.
sev := alert.Critical
if cfg.Suppressions.UPCPWindowStart != "" {
now := time.Now()
h, m := now.Hour(), now.Minute()
nowMin := h*60 + m
start := parseTimeMin(cfg.Suppressions.UPCPWindowStart)
end := parseTimeMin(cfg.Suppressions.UPCPWindowEnd)
if nowMin >= start && nowMin <= end {
sev = alert.Warning
}
}
// Check auditd for who made the change
auditInfo := getAuditShadowInfo()
if auditInfo != "" {
details += "\n" + auditInfo
}
// Suppress alerts for password changes made by infra IPs
// (admin-initiated password resets via WHM/xml-api)
if isInfraShadowChange(cfg) {
// Still update state, but don't alert
goto storeState
}
// Separate root password change (higher severity)
// The diff is recomputed here unconditionally (it was only computed
// above when the aggregate hash changed). Entries reported as
// "<user> (new)" count toward userCount.
changed := diffShadowChanges(store, currentEntries)
rootChanged := false
userCount := 0
for _, c := range changed {
if c == "root" {
rootChanged = true
} else {
userCount++
}
}
if rootChanged {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "root_password_change",
Message: "Root password changed",
Details: details,
})
}
// Bulk password changes (5+ accounts at once)
if userCount >= 5 {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "bulk_password_change",
Message: fmt.Sprintf("Bulk password change: %d accounts modified", userCount),
Details: details,
})
} else {
findings = append(findings, alert.Finding{
Severity: sev,
Check: "shadow_change",
Message: "/etc/shadow modified",
Details: details,
})
}
}
}
}
storeState:
// Store current state
mtimeData, _ := json.Marshal(mtime)
store.SetRaw(mtimeKey, string(mtimeData))
store.SetRaw(hashKey, currentHash)
// Store per-user hashes for diff next time
for user, hash := range currentEntries {
store.SetRaw("_shadow_user:"+user, hash)
}
return findings
}
// parseShadowUsers reads /etc/shadow and returns a map from user name to a
// fingerprint of that user's password field. The field itself is never
// kept: it is passed through hashBytes first, so only a hash of the hash
// is stored. Lines without a ':' separator or with an empty user name are
// ignored. Returns nil when the file cannot be read.
func parseShadowUsers() map[string]string {
	raw, err := osFS.ReadFile("/etc/shadow")
	if err != nil {
		return nil
	}

	users := map[string]string{}
	for _, record := range strings.Split(string(raw), "\n") {
		// Only the first two colon-separated fields matter: name, password.
		fields := strings.SplitN(record, ":", 3)
		if len(fields) >= 2 && fields[0] != "" {
			users[fields[0]] = hashBytes([]byte(fields[1]))
		}
	}
	return users
}
// diffShadowChanges compares the current per-user fingerprints against the
// ones stored under "_shadow_user:<name>" and returns the users that
// differ.
//
// A user whose stored fingerprint differs is reported by name; a user with
// no stored fingerprint is reported as "<name> (new)". Users that exist
// only in the store (removed accounts) are not reported. The result is nil
// when nothing changed.
//
// The result is sorted. Go randomises map iteration order, so without the
// sort the "Accounts changed" text built from this list would vary between
// runs for the same underlying change.
func diffShadowChanges(store *state.Store, current map[string]string) []string {
	var changed []string
	for user, hash := range current {
		prev, exists := store.GetRaw("_shadow_user:" + user)
		switch {
		case !exists:
			changed = append(changed, user+" (new)")
		case prev != hash:
			changed = append(changed, user)
		}
	}
	sort.Strings(changed)
	return changed
}
// getAuditShadowInfo looks up the most recent audit event tagged
// "csm_shadow_change" in /var/log/audit/audit.log and summarises who made
// the change.
//
// It greps the log, takes the last matching line and extracts the first
// exe= and first comm= fields from it (surrounding double quotes removed).
// A comm value that decodes as hex is replaced by its decoded form. The
// result is "Changed by: <exe> (command: <comm>)", or "" when the grep
// fails, finds nothing, or neither field is present.
func getAuditShadowInfo() string {
	out, err := runCmd("grep", "csm_shadow_change", "/var/log/audit/audit.log")
	if err != nil || len(out) == 0 {
		return ""
	}

	// Split always yields at least one element, so indexing the last event
	// is safe even when the trimmed output is empty.
	events := strings.Split(strings.TrimSpace(string(out)), "\n")
	latest := events[len(events)-1]

	var exe, comm string
	haveExe, haveComm := false, false
	for _, field := range strings.Fields(latest) {
		switch {
		case !haveExe && strings.HasPrefix(field, "exe="):
			exe = strings.Trim(strings.TrimPrefix(field, "exe="), "\"")
			haveExe = true
		case !haveComm && strings.HasPrefix(field, "comm="):
			raw := strings.Trim(strings.TrimPrefix(field, "comm="), "\"")
			// auditd hex-encodes some comm values; fall back to the raw
			// text when it does not decode.
			comm = raw
			if decoded := decodeHexString(raw); decoded != "" {
				comm = decoded
			}
			haveComm = true
		}
	}

	if exe == "" && comm == "" {
		return ""
	}
	return fmt.Sprintf("Changed by: %s (command: %s)", exe, comm)
}
// decodeHexString attempts to decode a hex-encoded string (auditd encodes
// some comm fields this way).
//
// The input must have an even length of at least 4 and consist only of hex
// digits (either case); otherwise "" is returned. Decoding stops at the
// first zero byte, and "" is also returned when nothing was decoded before
// that point.
func decodeHexString(s string) string {
	if len(s) < 4 || len(s)%2 != 0 {
		return ""
	}

	// Reject anything that is not purely hex digits.
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case c >= '0' && c <= '9':
		case c >= 'a' && c <= 'f':
		case c >= 'A' && c <= 'F':
		default:
			return ""
		}
	}

	decoded := make([]byte, 0, len(s)/2)
	for i := 0; i+1 < len(s); i += 2 {
		// #nosec G115 -- hexVal returns 0..15; (h<<4)|h fits in a byte.
		b := byte(hexVal(s[i])<<4 | hexVal(s[i+1]))
		if b == 0 {
			// A NUL byte ends the decoded value.
			break
		}
		decoded = append(decoded, b)
	}
	return string(decoded)
}
// CheckUID0Accounts scans /etc/passwd for accounts with UID 0 other than
// the allowed set (root, sync, shutdown, halt, operator) and emits one
// Critical "uid0_account" finding per offender, with the full passwd line
// as details. Lines with fewer than four colon-separated fields are
// ignored. Returns nil when /etc/passwd cannot be read. The ctx parameter
// is not used.
func CheckUID0Accounts(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	passwd, err := osFS.ReadFile("/etc/passwd")
	if err != nil {
		return nil
	}

	var findings []alert.Finding
	for _, entry := range strings.Split(string(passwd), "\n") {
		cols := strings.Split(entry, ":")
		if len(cols) < 4 {
			continue
		}
		name, uid := cols[0], cols[2]
		if uid != "0" {
			continue
		}
		// Account names that are allowed to carry UID 0.
		switch name {
		case "root", "sync", "shutdown", "halt", "operator":
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "uid0_account",
			Message:  fmt.Sprintf("Unauthorized UID 0 account: %s", name),
			Details:  entry,
		})
	}
	return findings
}
// CheckSSHKeys detects changes to authorized_keys files by comparing a
// content hash with the one stored on the previous run.
//
//   - /root/.ssh/authorized_keys: a changed hash yields a Critical
//     "ssh_keys" finding.
//   - /home/*/.ssh/authorized_keys: a changed hash yields a High "ssh_keys"
//     finding per file.
//
// A file seen for the first time only has its hash recorded; no finding is
// produced. Files that cannot be hashed are skipped. User files are visited
// in the order returned by rankPathsByMtimeDesc so recently touched
// accounts come first when the context expires and cuts the loop short. A
// nil ctx is replaced with context.Background().
func CheckSSHKeys(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var findings []alert.Finding

	// Root's key file.
	const rootKeysPath = "/root/.ssh/authorized_keys"
	const rootStateKey = "_ssh_root_keys_hash"
	rootHash, rootErr := hashFileContent(rootKeysPath)
	if rootErr == nil {
		if old, seen := store.GetRaw(rootStateKey); seen && old != rootHash {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "ssh_keys",
				Message:  "Root authorized_keys modified",
				Details:  fmt.Sprintf("File: %s", rootKeysPath),
			})
		}
		store.SetRaw(rootStateKey, rootHash)
	}

	// Per-user key files under /home.
	userKeyFiles, _ := osFS.Glob("/home/*/.ssh/authorized_keys")
	ranked := rankPathsByMtimeDesc(ctx, userKeyFiles, effectiveAccountScanMaxFiles(cfg))
	for _, keyFile := range ranked {
		if ctx.Err() != nil {
			break
		}
		digest, err := hashFileContent(keyFile)
		if err != nil {
			continue
		}
		stateKey := fmt.Sprintf("_ssh_user_keys:%s", keyFile)
		if old, seen := store.GetRaw(stateKey); seen && old != digest {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "ssh_keys",
				Message:  fmt.Sprintf("User authorized_keys modified: %s", keyFile),
			})
		}
		store.SetRaw(stateKey, digest)
	}
	return findings
}
// CheckAPITokens reports risky API tokens.
//
// It first runs the WHM root token check (checkWHMRootAPITokens) and then
// inspects per-user cPanel token files at
// /home/<user>/.cpanel/api_tokens/<token_name>, reading them directly from
// disk instead of spawning uapi per user. A user token is reported (High,
// "api_tokens") when its file content indicates full access and no IP
// whitelist (null, empty, or no whitelist_ips key at all), unless the
// token name is listed in cfg.Suppressions.KnownAPITokens.
//
// Token directories are visited in the order returned by
// rankPathsByMtimeDesc so recently touched accounts are processed first
// when the context expires. A nil ctx is replaced with
// context.Background().
func CheckAPITokens(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var findings []alert.Finding

	// WHM root API tokens
	if rootFinding, fire := checkWHMRootAPITokens(store); fire {
		findings = append(findings, rootFinding)
	}

	// Reports whether the operator has marked this token name as known.
	suppressed := func(name string) bool {
		for _, known := range cfg.Suppressions.KnownAPITokens {
			if name == known {
				return true
			}
		}
		return false
	}

	tokenDirs, _ := osFS.Glob("/home/*/.cpanel/api_tokens")
	ranked := rankPathsByMtimeDesc(ctx, tokenDirs, effectiveAccountScanMaxFiles(cfg))
	for _, tokenDir := range ranked {
		if ctx.Err() != nil {
			break
		}
		// /home/<user>/.cpanel/api_tokens -> <user>
		user := filepath.Base(filepath.Dir(filepath.Dir(tokenDir)))
		tokenFiles, _ := osFS.Glob(filepath.Join(tokenDir, "*"))
		for _, tokenFile := range tokenFiles {
			if ctx.Err() != nil {
				return findings
			}
			raw, err := osFS.ReadFile(tokenFile)
			if err != nil {
				continue
			}
			body := string(raw)

			// Substring checks cover both compact and spaced JSON output.
			fullAccess := strings.Contains(body, `"has_full_access":1`) ||
				strings.Contains(body, `"has_full_access": 1`)
			if !fullAccess {
				continue
			}
			unrestricted := strings.Contains(body, `"whitelist_ips":null`) ||
				strings.Contains(body, `"whitelist_ips": null`) ||
				strings.Contains(body, `"whitelist_ips":[]`) ||
				strings.Contains(body, `"whitelist_ips": []`) ||
				!strings.Contains(body, "whitelist_ips")
			if !unrestricted {
				continue
			}

			tokenName := filepath.Base(tokenFile)
			if suppressed(tokenName) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "api_tokens",
				Message:  fmt.Sprintf("User %s has full-access API token '%s' with no IP whitelist", user, tokenName),
				Details:  fmt.Sprintf("File: %s", tokenFile),
			})
		}
	}
	return findings
}
// State-store keys used by the WHM root API token check.
// whmAPITokensStateKey holds the structured per-token signature;
// whmAPITokensHashKey holds the legacy hash of the raw
// `whmapi1 api_token_list` output.
const (
whmAPITokensStateKey = "_whm_api_tokens_state" // #nosec G101 -- bbolt state-store key name, not a credential
whmAPITokensHashKey = "_whm_api_tokens_hash" // #nosec G101 -- bbolt state-store key name, not a credential
)
// checkWHMRootAPITokens checks the WHM root API token set for changes.
// It prefers the structured path: `whmapi1 api_token_list --output=json`
// parsed into a token signature. When the command fails or its output
// cannot be parsed, it falls back to checkWHMRootAPITokensLegacyHash.
// The boolean result reports whether the returned finding should be
// emitted.
func checkWHMRootAPITokens(store *state.Store) (alert.Finding, bool) {
	jsonOut, err := runCmd("whmapi1", "api_token_list", "--output=json")
	if err != nil {
		return checkWHMRootAPITokensLegacyHash(store)
	}
	cur, ok := parseWHMTokenSig(jsonOut)
	if !ok {
		return checkWHMRootAPITokensLegacyHash(store)
	}
	return checkStructuredWHMRootAPITokens(store, cur, nil)
}
// checkWHMRootAPITokensLegacyHash is the fallback for hosts where the JSON
// output is unusable. It runs plain `whmapi1 api_token_list`; if that
// output parses as a token signature the structured comparison is used
// (with the raw output passed along so the legacy hash stays current),
// otherwise the raw output is compared by hash only. A command failure
// yields no finding.
func checkWHMRootAPITokensLegacyHash(store *state.Store) (alert.Finding, bool) {
	raw, err := runCmd("whmapi1", "api_token_list")
	if err != nil {
		return alert.Finding{}, false
	}
	cur, ok := parseWHMTokenSigYAML(raw)
	if !ok {
		return checkWHMRootAPITokensLegacyHashOutput(store, raw)
	}
	return checkStructuredWHMRootAPITokens(store, cur, raw)
}
// checkWHMRootAPITokensLegacyHashOnly runs plain `whmapi1 api_token_list`
// and compares its raw output against the stored legacy hash, without
// attempting any structured parsing. A command failure yields no finding.
func checkWHMRootAPITokensLegacyHashOnly(store *state.Store) (alert.Finding, bool) {
	raw, err := runCmd("whmapi1", "api_token_list")
	if err == nil {
		return checkWHMRootAPITokensLegacyHashOutput(store, raw)
	}
	return alert.Finding{}, false
}
// checkStructuredWHMRootAPITokens compares the current token signature cur
// against stored state and returns the finding to emit, if any.
//
// legacyOut is the raw `whmapi1 api_token_list` output when the caller has
// it, or nil. Three situations are distinguished by what the store held on
// entry:
//   - structured state present: the result of diffWHMTokens is used;
//   - only the legacy hash present (migration): a legacy-hash comparison is
//     run and, if it fires, its finding replaces the diff result;
//   - no state at all (first run): baselineWHMTokens decides.
//
// In every case the structured state is overwritten with cur before
// returning, and the legacy hash is refreshed when legacyOut is non-nil.
func checkStructuredWHMRootAPITokens(store *state.Store, cur tokenSig, legacyOut []byte) (alert.Finding, bool) {
// Capture which kinds of state exist before anything below can modify them.
_, hadStructuredState := store.GetRaw(whmAPITokensStateKey)
_, hadLegacyState := store.GetRaw(whmAPITokensHashKey)
// The diff is computed unconditionally; the switch may override its result.
finding, fire := diffWHMTokens(store, cur)
switch {
case !hadStructuredState && hadLegacyState:
// Migrating from the legacy hash: the tokens were already vetted under
// the old scheme, so diff against the legacy hash rather than
// re-flagging the whole set as new.
var legacyFinding alert.Finding
var legacyFire bool
if legacyOut != nil {
legacyFinding, legacyFire = checkWHMRootAPITokensLegacyHashOutput(store, legacyOut)
} else {
// No raw output was handed in, so the helper runs the command itself.
legacyFinding, legacyFire = checkWHMRootAPITokensLegacyHashOnly(store)
}
if legacyFire {
finding, fire = legacyFinding, true
}
case !hadStructuredState && !hadLegacyState:
// True first run with no prior state of any kind. A root API token
// created by an attacker before CSM was installed would otherwise be
// baselined as "known" and never alert. Surface the pre-existing
// operator tokens for review.
finding, fire = baselineWHMTokens(cur)
}
// Persist the new baseline for the next run.
store.SetRaw(whmAPITokensStateKey, marshalTokenSig(cur))
if legacyOut != nil {
store.SetRaw(whmAPITokensHashKey, hashBytes(legacyOut))
}
return finding, fire
}
// checkWHMRootAPITokensLegacyHashOutput hashes the raw token-list output,
// stores that hash under whmAPITokensHashKey, and reports a Critical
// "api_tokens" finding when a previous hash existed and differs. The first
// run (no previous hash) only records the baseline.
func checkWHMRootAPITokensLegacyHashOutput(store *state.Store, out []byte) (alert.Finding, bool) {
	current := hashBytes(out)
	previous, seen := store.GetRaw(whmAPITokensHashKey)
	store.SetRaw(whmAPITokensHashKey, current)

	if !seen || previous == current {
		return alert.Finding{}, false
	}
	changed := alert.Finding{
		Severity: alert.Critical,
		Check:    "api_tokens",
		Message:  "WHM root API tokens changed",
		Details:  "Run 'whmapi1 api_token_list' to review",
	}
	return changed, true
}
// baselineWHMTokens builds the first-scan finding for root API tokens that
// already existed when CSM started tracking them.
//
// Cluster-managed trust tokens churn on their own and are expected, so
// they are left out; a baseline consisting only of those stays silent.
// Every other token is listed (sorted by name) in a High "api_tokens"
// finding, so a token planted before CSM was installed cannot pass as
// "known".
func baselineWHMTokens(cur tokenSig) (alert.Finding, bool) {
	var names []string
	for name, info := range cur {
		routine := isClusterManagedToken(info) || isClusterManagedTokenAddition(name, info)
		if !routine {
			names = append(names, name)
		}
	}
	if len(names) == 0 {
		return alert.Finding{}, false
	}

	// Map iteration order is random; sort for a stable message.
	sort.Strings(names)
	baseline := alert.Finding{
		Severity: alert.High,
		Check:    "api_tokens",
		Message:  "Pre-existing WHM root API tokens present at baseline",
		Details:  "Review with 'whmapi1 api_token_list': " + strings.Join(names, "; "),
	}
	return baseline, true
}
// tokenSig maps each WHM root API token name to the security traits compared
// between scans. It is what marshalTokenSig serializes into the state store
// and what diffWHMTokens compares against the previous scan.
type tokenSig map[string]tokenInfo
// tokenInfo holds the per-token traits that drive alert severity.
type tokenInfo struct {
// FullAccess is true when the token's "all" ACL decoded as enabled.
FullAccess bool `json:"full_access"`
// ClusterManaged is true when the token looks like a DNS-clustering trust
// token (see clusterManagedFromACLs); it is only honoured when FullAccess
// is false (see isClusterManagedToken).
ClusterManaged bool `json:"cluster_managed"`
}
// isClusterManagedToken reports whether the token's lifecycle belongs to
// cPanel DNS clustering rather than to an operator. Clustering creates,
// rotates, and deletes these tokens by itself:
//   - reverse_trust_<uuid>: trust granted to a remote WHM peer
//   - <host>-trust (e.g. ns2-trust): local end of a trust relationship
//
// A token holding full access is never treated as cluster-managed, whatever
// its recorded ClusterManaged flag says, so it always alerts at full severity.
func isClusterManagedToken(info tokenInfo) bool {
	if info.FullAccess {
		return false
	}
	return info.ClusterManaged
}
// isClusterManagedTokenAddition reports whether a newly seen token counts as
// routine clustering churn. Only generated reverse_trust_ names qualify, and
// only when the token is also cluster-managed without full access.
func isClusterManagedTokenAddition(name string, info tokenInfo) bool {
	if !strings.HasPrefix(name, "reverse_trust_") {
		return false
	}
	return isClusterManagedToken(info)
}
// clusterManagedFromACLs classifies a token as a clustering trust token from
// its name and decoded ACLs. A reverse_trust_ prefix is enough by itself. A
// "-trust" suffix counts only when the "clustering" ACL is enabled.
func clusterManagedFromACLs(name string, acls map[string]bool) bool {
	switch {
	case strings.HasPrefix(name, "reverse_trust_"):
		return true
	case strings.HasSuffix(name, "-trust"):
		return acls["clustering"]
	default:
		return false
	}
}
// parseWHMTokenSig turns the JSON printed by
// `whmapi1 api_token_list --output=json` into a tokenSig.
//
// The second result is false when the document is not valid JSON or has no
// data.tokens object; the caller should then use the legacy hash path. Each
// token is decoded leniently: an entry that is not an object, or whose ACLs
// cannot be read, is still recorded (with FullAccess=false), so a malformed
// entry never hides the fact that a token exists.
func parseWHMTokenSig(out []byte) (tokenSig, bool) {
	type tokenEntry struct {
		ACLs map[string]json.RawMessage `json:"acls"`
	}
	var envelope struct {
		Data struct {
			Tokens map[string]json.RawMessage `json:"tokens"`
		} `json:"data"`
	}
	if json.Unmarshal(out, &envelope) != nil {
		return nil, false
	}
	rawTokens := envelope.Data.Tokens
	if rawTokens == nil {
		return nil, false
	}
	result := make(tokenSig, len(rawTokens))
	for tokenName, payload := range rawTokens {
		var entry tokenEntry
		// Deliberately best-effort: on a decode error the token is kept
		// with whatever ACLs were read (possibly none).
		_ = json.Unmarshal(payload, &entry)
		granted := decodeWHMTokenACLs(entry.ACLs)
		result[tokenName] = tokenInfo{
			FullAccess:     granted["all"],
			ClusterManaged: clusterManagedFromACLs(tokenName, granted),
		}
	}
	return result, true
}
// parseWHMTokenSigYAML is the YAML counterpart of parseWHMTokenSig. The token
// map is read from data.tokens, or from result.data.tokens when the former is
// absent. The second result is false when the document does not parse or
// neither location holds a token map.
func parseWHMTokenSigYAML(out []byte) (tokenSig, bool) {
	type yamlToken struct {
		ACLs map[string]any `yaml:"acls"`
	}
	type yamlTokenData struct {
		Tokens map[string]yamlToken `yaml:"tokens"`
	}
	var doc struct {
		Data   yamlTokenData `yaml:"data"`
		Result struct {
			Data yamlTokenData `yaml:"data"`
		} `yaml:"result"`
	}
	if yaml.Unmarshal(out, &doc) != nil {
		return nil, false
	}
	// Prefer the top-level layout; fall back to the nested one.
	found := doc.Data.Tokens
	if found == nil {
		found = doc.Result.Data.Tokens
	}
	if found == nil {
		return nil, false
	}
	result := make(tokenSig, len(found))
	for tokenName, entry := range found {
		granted := decodeWHMTokenYAMLACLs(entry.ACLs)
		result[tokenName] = tokenInfo{
			FullAccess:     granted["all"],
			ClusterManaged: clusterManagedFromACLs(tokenName, granted),
		}
	}
	return result, true
}
// decodeWHMTokenACLs converts each raw JSON ACL value into a boolean flag,
// keyed by ACL name. A nil input yields an empty, non-nil map.
func decodeWHMTokenACLs(raw map[string]json.RawMessage) map[string]bool {
	flags := make(map[string]bool, len(raw))
	for aclName := range raw {
		flags[aclName] = decodeWHMTokenACL(raw[aclName])
	}
	return flags
}
// decodeWHMTokenYAMLACLs converts each YAML-decoded ACL value into a boolean
// flag, keyed by ACL name. A nil input yields an empty, non-nil map.
func decodeWHMTokenYAMLACLs(raw map[string]any) map[string]bool {
	flags := make(map[string]bool, len(raw))
	for aclName := range raw {
		flags[aclName] = decodeWHMTokenYAMLACL(raw[aclName])
	}
	return flags
}
// decodeWHMTokenACL reports whether a single raw JSON ACL value means
// "enabled".
//
// It applies the same rules as decodeWHMTokenYAMLACL, so one token decodes to
// the same FullAccess/ClusterManaged traits whichever output format was
// parsed:
//   - JSON booleans are used as-is;
//   - JSON numbers are enabled only when numerically equal to 1 (so 1.0 and
//     1e0 count, 2 does not);
//   - JSON strings are enabled when, after trimming whitespace, they are "1"
//     or a case-insensitive "true".
//
// Anything else (null, objects, arrays, empty or invalid JSON) is disabled.
func decodeWHMTokenACL(raw json.RawMessage) bool {
	var value any
	if err := json.Unmarshal(raw, &value); err != nil {
		// Empty or malformed input: treat the ACL as not granted.
		return false
	}
	switch v := value.(type) {
	case bool:
		return v
	case float64:
		// encoding/json decodes every JSON number into float64 when the
		// target is an interface value.
		return v == 1
	case string:
		s := strings.TrimSpace(v)
		return s == "1" || strings.EqualFold(s, "true")
	default:
		return false
	}
}
// decodeWHMTokenYAMLACL reports whether a YAML-decoded ACL value means
// "enabled". Booleans are used as-is. Strings are enabled when, trimmed, they
// are "1" or a case-insensitive "true". int, int64, uint64, and float64 are
// enabled only when equal to 1. Any other dynamic type, including nil, is
// disabled.
func decodeWHMTokenYAMLACL(raw any) bool {
	if flag, ok := raw.(bool); ok {
		return flag
	}
	if text, ok := raw.(string); ok {
		text = strings.TrimSpace(text)
		return text == "1" || strings.EqualFold(text, "true")
	}
	switch num := raw.(type) {
	case int:
		return num == 1
	case int64:
		return num == 1
	case uint64:
		return num == 1
	case float64:
		return num == 1
	}
	return false
}
// marshalTokenSig renders a tokenSig as JSON text for the state store.
// encoding/json emits map keys in sorted order, so an unchanged token set
// always yields the same string.
func marshalTokenSig(sig tokenSig) string {
	// The error is intentionally discarded: the value is a map of structs
	// holding only bools, which json.Marshal can always encode.
	encoded, _ := json.Marshal(sig)
	return string(encoded)
}
// unmarshalTokenSig decodes a stored token signature. It first tries the
// current name -> tokenInfo layout. If that fails, it falls back to the older
// name -> bool layout, where the bool is the full-access flag. For that
// layout ClusterManaged is inferred from the name (reverse_trust_ prefix or
// -trust suffix) and is never set on a full-access token. The second result
// is false when neither layout decodes.
func unmarshalTokenSig(raw string) (tokenSig, bool) {
	payload := []byte(raw)

	var structured tokenSig
	if json.Unmarshal(payload, &structured) == nil {
		return structured, true
	}

	var flat map[string]bool
	if json.Unmarshal(payload, &flat) != nil {
		return nil, false
	}
	converted := make(tokenSig, len(flat))
	for tokenName, fullAccess := range flat {
		trustName := strings.HasPrefix(tokenName, "reverse_trust_") || strings.HasSuffix(tokenName, "-trust")
		converted[tokenName] = tokenInfo{
			FullAccess:     fullAccess,
			ClusterManaged: !fullAccess && trustName,
		}
	}
	return converted, true
}
// diffWHMTokens reports what changed between the stored token set and cur,
// returning at most one finding. Changes fall into two buckets:
//   - Critical: a token added that is not generated reverse_trust churn, a
//     removed token that was not cluster-managed, or any existing token that
//     gained the full-access "all" ACL.
//   - Warning: only reverse_trust additions and cluster-managed removals.
//     cPanel produces these during normal DNS clustering, so they must not
//     page.
//
// If any Critical change exists, the finding lists only the Critical changes.
// When no stored state exists, or it cannot be decoded, nothing is reported;
// this function does not write to the store.
func diffWHMTokens(store *state.Store, cur tokenSig) (alert.Finding, bool) {
	stored, found := store.GetRaw(whmAPITokensStateKey)
	if !found {
		return alert.Finding{}, false
	}
	prev, decoded := unmarshalTokenSig(stored)
	if !decoded {
		return alert.Finding{}, false
	}

	var paging, churn []string

	// Additions and privilege escalations.
	for name, info := range cur {
		before, known := prev[name]
		switch {
		case known:
			if info.FullAccess && !before.FullAccess {
				paging = append(paging, "escalated "+name+" to full access")
			}
		case isClusterManagedTokenAddition(name, info):
			churn = append(churn, "added "+name)
		default:
			paging = append(paging, "added "+name)
		}
	}

	// Removals, classified by the traits recorded on the previous scan.
	for name, info := range prev {
		if _, present := cur[name]; present {
			continue
		}
		entry := "removed " + name
		if isClusterManagedToken(info) {
			churn = append(churn, entry)
		} else {
			paging = append(paging, entry)
		}
	}

	if len(paging) > 0 {
		sort.Strings(paging)
		return alert.Finding{
			Severity: alert.Critical,
			Check:    "api_tokens",
			Message:  "WHM root API tokens changed",
			Details:  "Review with 'whmapi1 api_token_list': " + strings.Join(paging, "; "),
		}, true
	}
	if len(churn) > 0 {
		sort.Strings(churn)
		return alert.Finding{
			Severity: alert.Warning,
			Check:    "api_tokens",
			Message:  "WHM root cluster trust tokens changed",
			Details:  "cPanel DNS clustering churn (expected): " + strings.Join(churn, "; "),
		}, true
	}
	return alert.Finding{}, false
}
// shadowMutatingWHMEndpoints lists WHM JSON-API endpoints whose handlers
// rewrite /etc/shadow as a side effect:
// - suspendacct/unsuspendacct: lock/unlock password field, swap login shell
// - passwd/forcepasswordchange: set or expire user password
// - createacct/removeacct/killacct: add or remove the shadow entry entirely
//
// Hits on these endpoints from infra IPs explain shadow mtime changes
// without involving an attacker; hits from non-infra IPs do not.
//
// Entries are matched by exact string equality against the logged request
// path after its query string has been removed (see lineHitsShadowEndpoint),
// so each entry must be the full path with no trailing slash or parameters.
var shadowMutatingWHMEndpoints = []string{
"/json-api/suspendacct",
"/json-api/unsuspendacct",
"/json-api/passwd",
"/json-api/forcepasswordchange",
"/json-api/createacct",
"/json-api/removeacct",
"/json-api/killacct",
}
// isInfraShadowChange decides whether a /etc/shadow modification can be
// attributed entirely to infrastructure activity. Two log sources are
// consulted on every call:
//
//  1. session_log PURGE password_change events, written when WHM/cPanel sets
//     a new password through the session machinery.
//  2. Successful api_tokens_log entries for shadow-mutating WHM JSON-API
//     endpoints (suspendacct, passwd, createacct, ...). These are not session
//     events and never reach the session log, which is why a session-only
//     check fired on every internal `suspendacct` call.
//
// It returns true only when at least one explaining event was found AND every
// event in both sources came from an infra IP, loopback, or "internal". A
// single event from anywhere else, or one whose source cannot be parsed,
// yields false, so a stolen token or compromised neighbour is never
// suppressed.
func isInfraShadowChange(cfg *config.Config) bool {
	sessionSeen, sessionClean := scanSessionLogShadow(cfg)
	tokenSeen, tokenClean := scanAPITokensLogShadow(cfg)
	if !sessionClean || !tokenClean {
		return false
	}
	return sessionSeen || tokenSeen
}
// scanSessionLogShadow inspects the last 100 lines of the cPanel session log,
// newest first, for PURGE password_change events.
//
// foundAny is true once any such event is seen. allInfra starts true and
// becomes false at the first event whose source is neither "internal" nor
// accepted by isTrustedShadowSource; scanning stops there. An event whose
// source cannot be extracted is evaluated with an empty source string.
func scanSessionLogShadow(cfg *config.Config) (foundAny, allInfra bool) {
	// Line shape: [ts] info [xml-api|whostmgr|security] IP PURGE account:token password_change
	tags := []string{"[xml-api]", "[whostmgr]", "[security]"}
	entries := tailFile("/usr/local/cpanel/logs/session_log", 100)
	for n := len(entries); n > 0; n-- {
		entry := entries[n-1]
		if !(strings.Contains(entry, "PURGE") && strings.Contains(entry, "password_change")) {
			continue
		}
		foundAny = true

		// The source is the first whitespace-separated word after the first
		// tag (in list order) present in the line.
		source := ""
		for _, tag := range tags {
			pos := strings.Index(entry, tag)
			if pos < 0 {
				continue
			}
			if words := strings.Fields(entry[pos+len(tag):]); len(words) > 0 {
				source = words[0]
			}
			break
		}

		if source == "internal" {
			continue
		}
		if !isTrustedShadowSource(source, cfg) {
			return foundAny, false
		}
	}
	return foundAny, true
}
// scanAPITokensLogShadow inspects the last 200 lines of the WHM
// api_tokens_log, newest first, for calls to endpoints that rewrite
// /etc/shadow. Only calls whose logged HTTP status starts with "2" count.
//
// The results mean the same as for scanSessionLogShadow: foundAny is true
// once a qualifying call is seen, and allInfra becomes false (ending the
// scan) at the first qualifying call whose host is neither "internal" nor
// accepted by isTrustedShadowSource.
func scanAPITokensLogShadow(cfg *config.Config) (foundAny, allInfra bool) {
	entries := tailFile("/usr/local/cpanel/logs/api_tokens_log", 200)
	for n := len(entries) - 1; n >= 0; n-- {
		entry := entries[n]
		if !lineHitsShadowEndpoint(entry) || !apiTokensHTTPStatusOK(entry) {
			continue
		}
		foundAny = true
		host := extractAPITokensHost(entry)
		if host == "internal" {
			continue
		}
		if !isTrustedShadowSource(host, cfg) {
			return foundAny, false
		}
	}
	return foundAny, true
}
// lineHitsShadowEndpoint reports whether the request path logged in line is
// exactly one of shadowMutatingWHMEndpoints. Lines without a parseable
// request path never match.
func lineHitsShadowEndpoint(line string) bool {
	reqPath := extractAPITokensRequestPath(line)
	if reqPath != "" {
		for i := range shadowMutatingWHMEndpoints {
			if shadowMutatingWHMEndpoints[i] == reqPath {
				return true
			}
		}
	}
	return false
}
// apiTokensHTTPStatusOK reports whether the line's "HTTP Status" field begins
// with the digit 2. A missing field counts as not OK.
func apiTokensHTTPStatusOK(line string) bool {
	code := extractAPITokensField(line, "HTTP Status: ['")
	return len(code) > 0 && code[0] == '2'
}
// extractAPITokensHost returns the value of the Host field from an
// api_tokens_log line written by whostmgrd, whose shape is:
//
//	[ts] info [whostmgrd] Host: ['<ip>'] HTTP Status: [...], ...
//
// It returns "" when the field is absent.
func extractAPITokensHost(line string) string {
	const hostMarker = "Host: ['"
	return extractAPITokensField(line, hostMarker)
}
// extractAPITokensRequestPath returns the path portion of the line's Request
// field: the second whitespace-separated word, with anything from the first
// '?' onward removed. It returns "" when the field is missing or holds fewer
// than two words.
func extractAPITokensRequestPath(line string) string {
	words := strings.Fields(extractAPITokensField(line, "Request: ['"))
	if len(words) < 2 {
		return ""
	}
	// Drop the query string, if any.
	target, _, _ := strings.Cut(words[1], "?")
	return target
}
// extractAPITokensField returns the text between the first occurrence of
// marker and the next "']" that follows it. It returns "" when marker is not
// present or the value is not terminated by "']".
func extractAPITokensField(line, marker string) string {
	_, tail, hasMarker := strings.Cut(line, marker)
	if !hasMarker {
		return ""
	}
	value, _, closed := strings.Cut(tail, "']")
	if !closed {
		return ""
	}
	return value
}
// isTrustedShadowSource reports whether ip may explain a shadow change
// without raising suspicion: a loopback address, or an address that isInfraIP
// accepts for cfg.InfraIPs. Text that does not parse as an IP is passed
// straight to isInfraIP.
func isTrustedShadowSource(ip string, cfg *config.Config) bool {
	if addr := net.ParseIP(ip); addr != nil && addr.IsLoopback() {
		return true
	}
	return isInfraIP(ip, cfg.InfraIPs)
}
// parseTimeMin converts an "H:M" string into hours*60 + minutes. Input that
// does not contain exactly one ':' yields 0. A side that does not start with
// a decimal integer contributes 0. No range checking is done on either side.
func parseTimeMin(s string) int {
	hourPart, minutePart, ok := strings.Cut(s, ":")
	if !ok || strings.Contains(minutePart, ":") {
		return 0
	}
	var hours, minutes int
	// Scan errors are deliberately ignored: a side that fails to parse keeps
	// its zero value.
	_, _ = fmt.Sscanf(hourPart, "%d", &hours)
	_, _ = fmt.Sscanf(minutePart, "%d", &minutes)
	return hours*60 + minutes
}
package checks
import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/atomicio"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/firewall"
)
const (
// defaultBlockExpiry is the duration string parseExpiry substitutes when
// the configured expiry is empty.
defaultBlockExpiry = "24h"
// blockStateFile is the file name, joined onto the state path, where
// loadBlockState and saveBlockState keep the auto-block tracker.
blockStateFile = "blocked_ips.json"
// maxPendingBlocks bounds the rate-limit overflow queue. Under a
// sustained flood the daemon can see more distinct attacker IPs in an
// hour than the block cap allows; without a bound the pending queue
// grows without limit and bloats blocked_ips.json. Dropped IPs are
// re-detected from the same findings on the next scan, so the cap
// loses no durable protection.
maxPendingBlocks = 1000
)
// IPBlocker abstracts the firewall engine for auto-blocking.
// When set, blocks go through nftables firewall engine.
type IPBlocker interface {
// BlockIP blocks ip for the given timeout, recording reason. The
// permanent-block fallback in this file passes a timeout of 0.
BlockIP(ip string, reason string, timeout time.Duration) error
// UnblockIP removes a block for ip.
UnblockIP(ip string) error
// IsBlocked reports whether the engine considers ip blocked. This may be a
// cached view; see liveBlocker for the live query.
IsBlocked(ip string) bool
}
// outcomeBlocker is satisfied by engines that report what they actually
// did. When the wired IPBlocker supports it, the auto-block path uses the
// outcome to decide whether to apply local side effects: a dry-run or
// verdict-allowed call must not mutate blocked_ips.json, must not bump the
// hourly counter, and must not emit the operator-facing "AUTO-BLOCK"
// finding (which would falsely claim a real block landed). The plain
// IPBlocker interface stays as a back-compat fallback for tests and any
// legacy implementation.
type outcomeBlocker interface {
// BlockIPOutcome takes the same arguments as IPBlocker.BlockIP and also
// returns the outcome. AutoBlockIPs records local state only for
// firewall.BlockOutcomeLive.
BlockIPOutcome(ip, reason string, timeout time.Duration) (firewall.BlockOutcome, error)
}
// liveBlocker is satisfied by engines that can query the live kernel firewall
// state, not just an in-memory cache built from state.json. The tracker
// reconcile loop prefers this because the cache can drift when nft
// auto-expires entries faster than CSM rewrites state.json, or when an
// out-of-band flush dropped entries the cache still claims are live. Falls
// back to IPBlocker.IsBlocked when the live query is unavailable.
type liveBlocker interface {
// IsBlockedLive reports the live block status of ip. When it returns a
// non-nil error, isBlockedLiveOrCached ignores the result and uses
// IPBlocker.IsBlocked instead.
IsBlockedLive(ip string) (bool, error)
}
// subnetBlocker is satisfied by engines that can block a whole CIDR range.
// AutoBlockIPs type-asserts the wired IPBlocker to this interface for both
// the subnet-spray fast path and the auto-netblock escalation; engines that
// do not implement it skip subnet blocking.
type subnetBlocker interface {
// BlockSubnet blocks cidr for the given timeout, recording reason.
BlockSubnet(cidr string, reason string, timeout time.Duration) error
}
// permanentPromoter is satisfied by engines that can upgrade an existing
// temporary block to a permanent one. PermBlock escalation runs in the same
// scan cycle as the temp block that triggered it, so the ordinary block path
// (which skips an already-blocked IP and returns BlockOutcomeNoop) would never
// clear the kernel timeout and the "permanent" block would silently expire.
type permanentPromoter interface {
// PromoteToPermanentBlock makes the block on ip permanent. When it returns
// an error, promoteToPermanentBlock logs it and reports failure, so no
// AUTO-PERMBLOCK finding is emitted.
PromoteToPermanentBlock(ip, reason string) error
}
// subnetBlockStatus is satisfied by engines that can report whether a CIDR
// range is already blocked. isSubnetAlreadyBlocked treats engines without it
// as "not blocked".
type subnetBlockStatus interface {
// IsSubnetBlocked reports whether cidr is currently blocked.
IsSubnetBlocked(cidr string) bool
}
// fwBlockerSlot wraps an IPBlocker so atomic.Pointer can store it. The
// extra struct layer is required because atomic.Pointer needs a
// concrete type and interfaces cannot be stored directly.
type fwBlockerSlot struct{ b IPBlocker }
// fwBlockerHolder holds the currently wired blocker. It is written by
// SetIPBlocker and read by getIPBlocker; nil until the first SetIPBlocker.
var fwBlockerHolder atomic.Pointer[fwBlockerSlot]
// blockStateMu is held by AutoBlockIPs for the whole call, serializing its
// load/modify/save cycle on the block state file.
var blockStateMu sync.Mutex
// autoBlockNow supplies the clock used for the hourly rate-limit key.
// NOTE(review): presumably a variable so tests can substitute a fixed clock
// -- confirm. Other timestamps in AutoBlockIPs call time.Now directly.
var autoBlockNow = time.Now
// SetIPBlocker wires in the firewall engine used for auto-blocking. The new
// blocker is published with a single atomic store, so it is safe to call
// while AutoBlockIPs is running: a scan already in progress keeps using the
// blocker it captured at its start.
func SetIPBlocker(b IPBlocker) {
	slot := &fwBlockerSlot{b: b}
	fwBlockerHolder.Store(slot)
}
// getIPBlocker reads the wired blocker with one atomic load and returns nil
// when none has been installed. Callers should keep the result in a local
// variable for the length of one operation, so that a concurrent
// SetIPBlocker cannot spread a single scan across two engines.
func getIPBlocker() IPBlocker {
	if slot := fwBlockerHolder.Load(); slot != nil {
		return slot.b
	}
	return nil
}
// blockedIP is one entry of the local auto-block tracker persisted in
// blocked_ips.json.
type blockedIP struct {
IP string `json:"ip"`
// Reason is the message of the finding that triggered the block.
Reason string `json:"reason"`
BlockedAt time.Time `json:"blocked_at"`
// ExpiresAt is BlockedAt plus the configured expiry as computed when the
// block was recorded. AutoBlockIPs prunes entries by asking the engine,
// not by comparing this timestamp.
ExpiresAt time.Time `json:"expires_at"`
}
// pendingIP is an IP that could not be blocked because the hourly cap was
// reached; it is retried at the start of the next AutoBlockIPs call.
type pendingIP struct {
IP string `json:"ip"`
// Reason is carried over so the eventual block records why it was queued.
Reason string `json:"reason"`
}
// blockState is the document persisted in blocked_ips.json: the tracked
// blocks, the rate-limit overflow queue, and the hourly rate-limit counters.
type blockState struct {
// IPs lists the blocks this tracker believes are live.
IPs []blockedIP `json:"ips"`
Pending []pendingIP `json:"pending,omitempty"` // IPs waiting for rate-limit reset
// BlocksThisHour counts live blocks recorded during the hour named by
// HourKey; it is reset to 0 when the hour changes.
BlocksThisHour int `json:"blocks_this_hour"`
// HourKey is the hour window in "2006-01-02T15" layout.
HourKey string `json:"hour_key"`
// RateLimitWarnedHour is the HourKey for which the rate-limit warning
// was already emitted. The warning reflects a steady-state condition,
// not a per-IP event, so it fires once per hour window instead of on
// every scan tick -- the per-tick emission flooded the audit log with
// one identical finding every few seconds during a sustained attack.
RateLimitWarnedHour string `json:"rate_limit_warned_hour,omitempty"`
}
// AutoBlockIPs processes findings and blocks attacker IPs via the firewall engine.
// Note: this should be called with ALL findings (not just new ones)
// for reputation-based blocking to work on repeat offenders.
//
// It returns nil immediately when auto-response or IP blocking is disabled in
// cfg. Otherwise it holds blockStateMu for the whole call and, in order:
//  1. loads blocked_ips.json and prunes entries the engine no longer blocks;
//  2. resets the hourly counter when the hour window changed;
//  3. drains the pending queue left by earlier rate-limited calls;
//  4. blocks subnets named directly by subnet-spray findings;
//  5. collects eligible attacker IPs from findings and blocks them, queueing
//     the overflow once the hourly cap is reached;
//  6. optionally escalates repeat offenders to permanent blocks and groups of
//     blocked IPs to subnet blocks;
//  7. saves the state file.
//
// The returned findings (all with Check "auto_block") describe the actions
// taken; the slice is nil when nothing was done.
func AutoBlockIPs(cfg *config.Config, findings []alert.Finding) []alert.Finding {
if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.BlockIPs {
return nil
}
blockStateMu.Lock()
defer blockStateMu.Unlock()
// Snapshot the wired firewall engine ONCE per call. A concurrent
// SetIPBlocker (SIGHUP re-wire, test cleanup) can swap the global
// mid-scan; reading the atomic pointer once and reusing the
// returned value keeps every block decision in this batch routed
// to the same engine. The previous unsynchronized read of the
// global also tripped the race detector.
blocker := getIPBlocker()
var actions []alert.Finding
// Load block state
state := loadBlockState(cfg.StatePath)
// Prune IPs that the firewall engine no longer has blocked.
// The engine handles expiry natively via nftables timeouts -
// we just sync our state to match. Use the live nftables query
// when the engine supports it so the tracker stays in lock-step
// with the kernel; the in-memory cache (IsBlocked) can lag when
// the kernel expires entries before state.json is rewritten.
// With no engine wired (blocker == nil) nothing is pruned.
var stillBlocked []blockedIP
for _, b := range state.IPs {
if blocker != nil {
if !isBlockedLiveOrCached(blocker, b.IP) {
// Engine expired this block - clean up our state
fmt.Fprintf(os.Stderr, "[%s] AUTO-UNBLOCK: %s removed (engine expired)\n", time.Now().Format("2006-01-02 15:04:05"), b.IP)
continue
}
}
stillBlocked = append(stillBlocked, b)
}
state.IPs = stillBlocked
// Check rate limit
currentHour := autoBlockNow().Format("2006-01-02T15")
if state.HourKey != currentHour {
state.HourKey = currentHour
state.BlocksThisHour = 0
}
// Collect IPs to block from findings
ipsToBlock := make(map[string]string) // ip -> reason
// Always blockable findings carry a confirmed attacker IP: thresholded
// brute force, confirmed compromise, C2/reputation, or escalation.
// Raw mailbox auth failures and account-only mail findings feed incident
// grouping and thresholded trackers, but one row is not enough evidence
// for a firewall block.
alwaysBlock := map[string]bool{
"wp_login_bruteforce": true,
"xmlrpc_abuse": true,
"http_request_flood": true,
"http_ua_spoof": true,
"ftp_bruteforce": true,
"smtp_bruteforce": true,
"smtp_probe_abuse": true,
"mail_bruteforce": true,
"mail_account_compromised": true,
"admin_panel_bruteforce": true,
"ssh_login_unknown_ip": true,
"ssh_login_realtime": true,
"c2_connection": true,
"ip_reputation": true,
"local_threat_score": true,
"modsec_block_escalation": true,
"modsec_csm_block_escalation": true,
"email_compromised_account": true,
"email_cloud_relay_abuse": true,
"waf_attack_blocked": true,
}
// Only blockable when block_cpanel_logins is enabled (disabled by default).
// cpanel_login / cpanel_login_realtime are deliberately absent: those
// fire as Warning-level audit on every direct form login from a non-
// infra IP and a single event is not brute-force evidence. Blocking on
// one Warning turns a legitimate customer logging in from a new country
// into a 24h lockout. Thresholded brute checks below stay blockable.
cpanelWebmailChecks := map[string]bool{
"cpanel_multi_ip_login": true,
"cpanel_file_upload_realtime": true,
"api_auth_failure": true,
"api_auth_failure_realtime": true,
"webmail_bruteforce": true,
"webmail_login_realtime": true,
"ftp_login_realtime": true,
"ftp_auth_failure_realtime": true,
}
// Drain pending queue first (IPs from prior rate-limited cycles)
// Only the local tracker is consulted here; the infra, engine, and
// challenge-list checks applied to fresh findings below are not repeated
// for queued entries.
for _, p := range state.Pending {
if !isAlreadyBlocked(state, p.IP) {
ipsToBlock[p.IP] = p.Reason
}
}
state.Pending = nil
// Subnet fast-path: checks that represent a subnet directly.
// Independent of the per-IP rate limit, because a single subnet block
// replaces what would otherwise be hundreds of per-IP blocks.
for _, f := range findings {
if f.Check != "smtp_subnet_spray" && f.Check != "mail_subnet_spray" {
continue
}
cidr := extractCIDRFromFinding(f)
if cidr == "" {
continue
}
if blocker == nil {
fmt.Fprintf(os.Stderr, "auto-block: firewall engine not available, skipping subnet %s\n", cidr)
continue
}
sb, ok := blocker.(subnetBlocker)
if !ok {
fmt.Fprintf(os.Stderr, "auto-block: firewall engine does not support subnet blocking, skipping %s\n", cidr)
continue
}
if isSubnetAlreadyBlocked(blocker, cidr) {
continue
}
reason := fmt.Sprintf("CSM auto-block (subnet): %s", truncate(f.Message, 100))
if err := sb.BlockSubnet(cidr, reason, parseExpiry(cfg.AutoResponse.BlockExpiry)); err != nil {
fmt.Fprintf(os.Stderr, "auto-block: error blocking subnet %s: %v\n", cidr, err)
continue
}
fmt.Fprintf(os.Stderr, "[%s] AUTO-BLOCK-SUBNET: %s blocked\n", time.Now().Format("2006-01-02 15:04:05"), cidr)
actions = append(actions, alert.Finding{
Severity: alert.Critical,
Check: "auto_block",
Message: fmt.Sprintf("AUTO-BLOCK-SUBNET: %s blocked", cidr),
Details: fmt.Sprintf("Reason: %s", f.Message),
Timestamp: time.Now(),
})
}
// Per-IP collection: one map entry per IP, so when several findings name
// the same IP the reason of the last one processed is kept.
for _, f := range findings {
isBlockable := alwaysBlock[f.Check]
if !isBlockable && cfg.AutoResponse.BlockCpanelLogins && cpanelWebmailChecks[f.Check] {
isBlockable = true
}
if !isBlockable {
continue
}
ip := extractIPFromFinding(f)
if ip == "" {
continue
}
// Never block infra IPs
if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
continue
}
// Don't re-block already blocked IPs.
if isAlreadyBlocked(state, ip) || (blocker != nil && isBlockedLiveOrCached(blocker, ip)) {
continue
}
// Skip IPs on the challenge list (they'll be challenged, not blocked)
if cl := GetChallengeIPList(); cl != nil && cl.Contains(ip) {
continue
}
ipsToBlock[ip] = f.Message
}
// Block IPs - queue any that can't be blocked due to rate limit
expiry := parseExpiry(cfg.AutoResponse.BlockExpiry)
maxPerHour := cfg.AutoResponse.MaxBlocksPerHour
if maxPerHour <= 0 {
maxPerHour = config.DefaultMaxBlocksPerHour
}
rateLimited := false
droppedPending := 0
// ipsToBlock is a map, so iteration order is random: which IPs are blocked
// before the cap is hit and which are queued is not deterministic.
for ip, reason := range ipsToBlock {
if state.BlocksThisHour >= maxPerHour {
// Queue for next cycle instead of dropping, bounded so a
// sustained flood cannot grow the queue without limit.
if len(state.Pending) < maxPendingBlocks {
state.Pending = append(state.Pending, pendingIP{IP: ip, Reason: reason})
} else {
droppedPending++
}
rateLimited = true
continue
}
// Block via firewall engine (nftables)
blockReason := fmt.Sprintf("CSM auto-block: %s", truncate(reason, 100))
// Without an engine the IP is neither blocked nor re-queued.
if blocker == nil {
fmt.Fprintf(os.Stderr, "auto-block: firewall engine not available, skipping %s\n", ip)
continue
}
outcome, err := callBlockIP(blocker, ip, blockReason, expiry)
if err != nil {
fmt.Fprintf(os.Stderr, "auto-block: error blocking %s: %v\n", ip, err)
continue
}
switch outcome {
case firewall.BlockOutcomeLive:
// nft was mutated. Record the real block below.
case firewall.BlockOutcomeDryRun:
// dry-run intercepted: nft was NOT mutated. Do not record a real
// block locally or in the permanent threat DB; emit a Warning
// notice instead so operators see what would have been blocked.
actions = append(actions, alert.Finding{
Severity: alert.Warning,
Check: "auto_block",
Message: fmt.Sprintf("AUTO-BLOCK [dry-run]: %s would be blocked (expires in %s)", ip, expiry),
Details: fmt.Sprintf("Reason: %s", reason),
Timestamp: time.Now(),
})
continue
case firewall.BlockOutcomeAllowed:
// Verdict callback returned "allow": CSM intentionally did not
// block. Stay silent at finding level - the panel already knows
// it downgraded the decision and the engine logged it to stderr.
continue
case firewall.BlockOutcomeNoop:
// Already-blocked, deny-limit, or other guard rejected the call.
// No local state to record.
continue
default:
fmt.Fprintf(os.Stderr, "auto-block: unknown block outcome %q for %s, skipping local state\n", outcome, ip)
continue
}
// Only the stderr log line depends on IsBlocked; the counter, threat DB,
// tracker entry, and finding below are recorded for every live outcome.
if blocker.IsBlocked(ip) {
fmt.Fprintf(os.Stderr, "[%s] AUTO-BLOCK: %s blocked (expires in %s)\n", time.Now().Format("2006-01-02 15:04:05"), ip, expiry)
}
state.BlocksThisHour++
// Add to permanent local threat database
if db := GetThreatDB(); db != nil {
db.AddPermanent(ip, reason)
}
state.IPs = append(state.IPs, blockedIP{
IP: ip,
Reason: reason,
BlockedAt: time.Now(),
ExpiresAt: time.Now().Add(expiry),
})
actions = append(actions, alert.Finding{
Severity: alert.Critical,
Check: "auto_block",
Message: fmt.Sprintf("AUTO-BLOCK: %s blocked (expires in %s)", ip, expiry),
Details: fmt.Sprintf("Reason: %s", reason),
Timestamp: time.Now(),
})
// Permanent block escalation: promote to permanent after N temp blocks
if cfg.AutoResponse.PermBlock {
// A configured count below 2 falls back to 4.
count := cfg.AutoResponse.PermBlockCount
if count < 2 {
count = 4
}
interval := parseExpiry(cfg.AutoResponse.PermBlockInterval)
if interval == 0 {
interval = 24 * time.Hour
}
if checkPermBlockEscalation(cfg.StatePath, ip, count, interval) {
permReason := fmt.Sprintf("PERMBLOCK: %d temp blocks within %s", count, interval)
if promoteToPermanentBlock(blocker, ip, permReason) {
actions = append(actions, alert.Finding{
Severity: alert.Critical,
Check: "auto_block",
Message: fmt.Sprintf("AUTO-PERMBLOCK: %s promoted to permanent block (%d temp blocks)", ip, count),
Timestamp: time.Now(),
})
}
}
}
}
// Emit the rate-limit warning at most once per hour window.
if rateLimited && state.RateLimitWarnedHour != currentHour {
state.RateLimitWarnedHour = currentHour
msg := fmt.Sprintf("Auto-block rate limit reached (%d/hour), %d IPs queued for next cycle", maxPerHour, len(state.Pending))
if droppedPending > 0 {
msg += fmt.Sprintf(", %d dropped (queue full)", droppedPending)
}
actions = append(actions, alert.Finding{
Severity: alert.Warning,
Check: "auto_block",
Message: msg,
Timestamp: time.Now(),
})
}
// Subnet auto-blocking: detect per-family subnet patterns.
if cfg.AutoResponse.NetBlock && blocker != nil {
// A configured threshold below 2 falls back to 3.
threshold := cfg.AutoResponse.NetBlockThreshold
if threshold < 2 {
threshold = 3
}
subnetExpiry := parseExpiry(cfg.AutoResponse.BlockExpiry)
// Count blocked IPs per subnet (IPv4 /24, IPv6 /64).
subnetCounts := make(map[string]int)
subnetBlocked := make(map[string]bool)
for _, b := range state.IPs {
cidr := subnetEscalationCIDR(b.IP)
if cidr != "" {
subnetCounts[cidr]++
}
}
// A failed BlockSubnet is silently skipped here (no log, no finding).
for cidr, count := range subnetCounts {
if count >= threshold && !subnetBlocked[cidr] {
if sb, ok := blocker.(subnetBlocker); ok {
if isSubnetAlreadyBlocked(blocker, cidr) {
continue
}
reason := fmt.Sprintf("Auto-netblock: %d IPs from %s", count, cidr)
if err := sb.BlockSubnet(cidr, reason, subnetExpiry); err == nil {
subnetBlocked[cidr] = true
fmt.Fprintf(os.Stderr, "[%s] AUTO-NETBLOCK: %s blocked (%d IPs from same subnet)\n", time.Now().Format("2006-01-02 15:04:05"), cidr, count)
actions = append(actions, alert.Finding{
Severity: alert.Critical,
Check: "auto_block",
Message: fmt.Sprintf("AUTO-NETBLOCK: %s blocked (%d IPs from same subnet)", cidr, count),
Timestamp: time.Now(),
})
}
}
}
}
}
// Save state (expired IPs were already pruned at the top of this function)
saveBlockState(cfg.StatePath, state)
return actions
}
// isBlockedLiveOrCached answers "is ip blocked?" using the live firewall
// query when the blocker implements liveBlocker, and the cached IsBlocked
// view otherwise. The reconcile loop depends on the live answer to prune
// blocked_ips.json entries that the kernel has already expired. If the live
// query returns an error, the cached answer is used instead, so a transient
// netlink failure does not wipe the local tracker.
func isBlockedLiveOrCached(b IPBlocker, ip string) bool {
	live, supportsLive := b.(liveBlocker)
	if !supportsLive {
		return b.IsBlocked(ip)
	}
	blocked, err := live.IsBlockedLive(ip)
	if err != nil {
		return b.IsBlocked(ip)
	}
	return blocked
}
// callBlockIP blocks ip through the richest interface the blocker offers.
// Engines implementing outcomeBlocker report their own outcome. A plain
// IPBlocker is assumed to have applied the block for real when BlockIP
// succeeds (the only behaviour that existed before BlockIPOutcome), and
// yields BlockOutcomeNoop together with the error when it fails. This keeps
// tests and third-party IPBlocker implementations working unchanged.
func callBlockIP(b IPBlocker, ip, reason string, timeout time.Duration) (firewall.BlockOutcome, error) {
	if reporter, ok := b.(outcomeBlocker); ok {
		return reporter.BlockIPOutcome(ip, reason, timeout)
	}
	err := b.BlockIP(ip, reason, timeout)
	if err == nil {
		return firewall.BlockOutcomeLive, nil
	}
	return firewall.BlockOutcomeNoop, err
}
// promoteToPermanentBlock turns an existing temporary block into a permanent
// one and reports whether that succeeded.
//
// Engines implementing permanentPromoter clear the kernel timeout in place; a
// failure there is logged to stderr. Legacy blockers that only implement
// BlockIP have not marked the IP blocked in a way that trips skipExisting, so
// a fresh zero-timeout block lands live on them. That fallback succeeds only
// when the call returns no error and a live outcome, preserving earlier
// behaviour for tests and third-party implementations.
func promoteToPermanentBlock(b IPBlocker, ip, reason string) bool {
	promoter, ok := b.(permanentPromoter)
	if !ok {
		outcome, err := callBlockIP(b, ip, reason, 0)
		if err != nil {
			return false
		}
		return outcome == firewall.BlockOutcomeLive
	}
	if err := promoter.PromoteToPermanentBlock(ip, reason); err != nil {
		fmt.Fprintf(os.Stderr, "auto-block: permblock promotion of %s failed: %v\n", ip, err)
		return false
	}
	return true
}
// isSubnetAlreadyBlocked reports whether the engine says cidr is already
// blocked. Engines that do not implement subnetBlockStatus always yield
// false.
func isSubnetAlreadyBlocked(b IPBlocker, cidr string) bool {
	status, ok := b.(subnetBlockStatus)
	if !ok {
		return false
	}
	return status.IsSubnetBlocked(cidr)
}
// ExtractIPFromFinding extracts an IP address from a finding.
// It is the exported entry point for extractIPFromFinding and returns "" when
// no usable address is found.
func ExtractIPFromFinding(f alert.Finding) string {
return extractIPFromFinding(f)
}
// extractIPFromFinding returns the attacker IP carried by a finding in
// canonical form, or "" when none is usable.
//
// A non-blank structured SourceIP is authoritative: its normalized value is
// returned even if that is "". Otherwise the message text is searched as a
// fallback for detectors that have not adopted SourceIP yet. Only
// auto-block-eligible checks reach that path, and those detectors append a
// CSM-parsed IP at the end of their own message. The LAST " from " is tried
// first, then the last ": ", so the rightmost (CSM-appended) IP wins over any
// log-injected content earlier in the message.
func extractIPFromFinding(f alert.Finding) string {
	if strings.TrimSpace(f.SourceIP) != "" {
		return normalizeBlockIP(f.SourceIP)
	}
	separators := [...]string{" from ", ": "}
	for i := range separators {
		sep := separators[i]
		pos := strings.LastIndex(f.Message, sep)
		if pos < 0 {
			continue
		}
		words := strings.Fields(f.Message[pos+len(sep):])
		if len(words) == 0 {
			continue
		}
		// Strip trailing punctuation that detectors put after the address.
		token := strings.TrimRight(words[0], ",:;)([]")
		if addr := normalizeBlockIP(token); addr != "" {
			return addr
		}
	}
	return ""
}
// normalizeBlockIP reduces a raw address token to the canonical text form of
// an IP that may be blocked. It trims whitespace, drops a ":port" suffix when
// the input parses as host:port, and strips surrounding square brackets. It
// returns "" for empty input, for anything that is not an IP literal, and for
// loopback or unspecified addresses.
func normalizeBlockIP(raw string) string {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}
	host, _, err := net.SplitHostPort(candidate)
	if err == nil {
		candidate = host
	}
	candidate = strings.Trim(candidate, "[]")
	parsed := net.ParseIP(candidate)
	switch {
	case parsed == nil:
		return ""
	case parsed.IsLoopback(), parsed.IsUnspecified():
		return ""
	}
	return parsed.String()
}
// isAlreadyBlocked reports whether the local tracker already has an entry
// whose IP text equals ip exactly.
func isAlreadyBlocked(state *blockState, ip string) bool {
	for i := range state.IPs {
		if state.IPs[i].IP == ip {
			return true
		}
	}
	return false
}
// parseExpiry converts a configured duration string into a time.Duration. An
// empty string is replaced by defaultBlockExpiry, and a string that
// time.ParseDuration rejects yields 24 hours.
func parseExpiry(s string) time.Duration {
	spec := s
	if spec == "" {
		spec = defaultBlockExpiry
	}
	if d, err := time.ParseDuration(spec); err == nil {
		return d
	}
	return 24 * time.Hour
}
// loadBlockState reads the auto-block tracker from blockStateFile under
// statePath. It never fails. An unreadable file yields an empty state. A file
// that does not decode is reported on stderr, and whatever was decoded before
// the error is returned.
func loadBlockState(statePath string) *blockState {
	loaded := &blockState{}
	file := filepath.Join(statePath, blockStateFile)
	raw, readErr := osFS.ReadFile(file)
	if readErr != nil {
		return loaded
	}
	if err := json.Unmarshal(raw, loaded); err != nil {
		fmt.Fprintf(os.Stderr, "autoblock: %s is corrupt, ignoring queued blocks: %v\n", file, err)
	}
	return loaded
}
// saveBlockState writes the auto-block tracker to blockStateFile under
// statePath via atomicio.AtomicWriteJSON with mode 0600. A write failure is
// reported on stderr and otherwise ignored.
func saveBlockState(statePath string, s *blockState) {
	target := filepath.Join(statePath, blockStateFile)
	err := atomicio.AtomicWriteJSON(target, 0o600, s)
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "autoblock: persist %s failed: %v\n", target, err)
}
// PendingBlockIPs returns the set of IPs waiting in the rate-limit queue of
// the persisted block state, as a map whose values are all true.
// alert.FilterBlockedAlerts uses it to suppress reputation alerts for IPs
// that are about to be blocked anyway.
func PendingBlockIPs(statePath string) map[string]bool {
	queued := loadBlockState(statePath).Pending
	set := make(map[string]bool, len(queued))
	for i := range queued {
		set[queued[i].IP] = true
	}
	return set
}
// subnetEscalationCIDR maps an IP to the canonical CIDR used by the
// auto-netblock escalation path.
//
//   - IPv4 (including IPv4-mapped IPv6) collapses to its /24, the historical
//     block size.
//   - IPv6 collapses to its /64: most providers hand a /64 to each end user,
//     so a /128 would let an attacker rotate addresses inside the prefix and
//     never escalate, while anything wider risks hitting legitimate
//     neighbours.
//
// The empty string is returned when ip does not parse.
func subnetEscalationCIDR(ip string) string {
	addr := net.ParseIP(ip)
	if addr == nil {
		return ""
	}

	if v4 := addr.To4(); v4 != nil {
		return fmt.Sprintf("%d.%d.%d.0/24", v4[0], v4[1], v4[2])
	}

	v6 := addr.To16()
	if v6 == nil {
		return ""
	}
	prefix := net.CIDRMask(64, 128)
	subnet := net.IPNet{IP: v6.Mask(prefix), Mask: prefix}
	return subnet.String()
}
// --- Permanent block escalation (LF_PERMBLOCK) ---
// permBlockTracker is the JSON-persisted history of block times per IP. It is
// read by loadPermBlockTracker, updated by checkPermBlockEscalation and
// written back by savePermBlockTracker.
type permBlockTracker struct {
IPs map[string][]time.Time `json:"ips"` // IP -> list of block timestamps
}
// checkPermBlockEscalation records a block of ip at the current time and
// reports whether ip has now been blocked at least count times within the
// trailing interval.
//
// As a side effect the tracker is pruned and saved on every call: timestamps
// for ip that fall outside the interval are dropped, and any tracked IP with
// no timestamps, or whose newest timestamp is older than twice the interval,
// is removed entirely.
func checkPermBlockEscalation(statePath, ip string, count int, interval time.Duration) bool {
	tracker := loadPermBlockTracker(statePath)
	now := time.Now()
	windowStart := now.Add(-interval)

	// Record this block, then keep only the timestamps inside the window.
	history := append(tracker.IPs[ip], now)
	var inWindow []time.Time
	for i := range history {
		if history[i].After(windowStart) {
			inWindow = append(inWindow, history[i])
		}
	}
	tracker.IPs[ip] = inWindow

	// Forget IPs that have not been seen for twice the interval.
	staleAfter := interval * 2
	for addr, stamps := range tracker.IPs {
		if len(stamps) == 0 || now.Sub(stamps[len(stamps)-1]) > staleAfter {
			delete(tracker.IPs, addr)
		}
	}

	savePermBlockTracker(statePath, tracker)
	return len(inWindow) >= count
}
// loadPermBlockTracker reads permblock_tracker.json from statePath. The
// returned tracker always has a non-nil IPs map. A read error yields an empty
// tracker without any message; a decode error is reported on stderr and the
// tracker is returned as json.Unmarshal left it.
func loadPermBlockTracker(statePath string) *permBlockTracker {
	path := filepath.Join(statePath, "permblock_tracker.json")
	loaded := &permBlockTracker{IPs: make(map[string][]time.Time)}

	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return loaded
	}
	if decodeErr := json.Unmarshal(raw, loaded); decodeErr != nil {
		fmt.Fprintf(os.Stderr, "autoblock: %s is corrupt, ignoring escalation history: %v\n", path, decodeErr)
	}
	// Decoding may have replaced the map with nil (for example `"ips": null`).
	if loaded.IPs == nil {
		loaded.IPs = make(map[string][]time.Time)
	}
	return loaded
}
// savePermBlockTracker writes tracker as JSON to permblock_tracker.json under
// statePath through atomicio.AtomicWriteJSON with mode 0600. A failure is
// reported on stderr and otherwise ignored.
func savePermBlockTracker(statePath string, tracker *permBlockTracker) {
	path := filepath.Join(statePath, "permblock_tracker.json")
	err := atomicio.AtomicWriteJSON(path, 0o600, tracker)
	if err == nil {
		return
	}
	fmt.Fprintf(os.Stderr, "autoblock: persist %s failed: %v\n", path, err)
}
// extractCIDRFromFinding pulls a CIDR out of a finding's message. It looks at
// the first whitespace-delimited token after the last " from " separator,
// strips trailing punctuation from it, and returns the network in canonical
// form. It returns "" when the separator is missing, nothing follows it, or
// the token does not parse as a CIDR.
func extractCIDRFromFinding(f alert.Finding) string {
	const sep = " from "

	at := strings.LastIndex(f.Message, sep)
	if at < 0 {
		return ""
	}
	tokens := strings.Fields(f.Message[at+len(sep):])
	if len(tokens) == 0 {
		return ""
	}
	token := strings.TrimRight(tokens[0], ",:;)([]")
	if _, network, err := net.ParseCIDR(token); err == nil {
		return network.String()
	}
	return ""
}
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// quarantineDir is the directory that quarantined files, and their ".meta"
// sidecars, are placed in by the auto-response quarantine functions.
// var (not const) so tests can redirect to t.TempDir().
var quarantineDir = "/opt/csm/quarantine"
// QuarantineMeta stores original file metadata alongside quarantined files.
// It is serialised as indented JSON into "<quarantine path>.meta" by the
// quarantine functions in this file.
type QuarantineMeta struct {
// OriginalPath is the path the file or directory was moved from.
OriginalPath string `json:"original_path"`
// Owner and Group are the numeric uid/gid taken from the Lstat result; they
// stay 0 when the platform stat structure is not available.
Owner int `json:"owner_uid"`
Group int `json:"group_gid"`
// Mode is the textual form of the original file mode (os.FileMode.String).
Mode string `json:"mode"`
// Size is the size reported by Lstat before the move.
Size int64 `json:"size"`
// QuarantineAt is the time the sidecar was produced.
QuarantineAt time.Time `json:"quarantined_at"`
// Reason is a human-readable explanation, e.g. the triggering finding's message.
Reason string `json:"reason"`
}
// AutoKillProcesses sends SIGKILL to processes named by high-confidence
// critical findings and returns one "auto_response" finding per process it
// killed. It returns nil when auto-response or process killing is disabled.
//
// Only findings with Check "fake_kernel_thread", "suspicious_process" or
// "php_suspicious_execution" and Critical severity are considered. The PID is
// taken from the structured PID field, falling back to a "PID: N" token in
// Details.
//
// Safety gates, applied in order; any failure skips the finding:
//   - the PID must be a plain decimal number greater than 1;
//   - the process's UID must be readable from /proc and must not be 0
//     (root processes are never killed automatically);
//   - the process's executable must not live under a system/cPanel prefix
//     (see isSafeProcess).
//
// A failed kill is skipped silently.
func AutoKillProcesses(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.KillProcesses {
		return nil
	}

	var actions []alert.Finding
	for _, f := range findings {
		// Only act on specific high-confidence critical checks.
		switch f.Check {
		case "fake_kernel_thread", "suspicious_process", "php_suspicious_execution":
		default:
			continue
		}
		if f.Severity != alert.Critical {
			continue
		}

		// Use the structured PID field when available, fall back to text
		// extraction. The text comes from free-form Details, so it is
		// validated as a canonical decimal number BEFORE it is used to build
		// a /proc path: Sscanf alone accepts inputs such as "123abc" or
		// "+123", and a non-numeric token ("self", "../x") must never reach
		// filepath.Join. The round-trip comparison rejects all of those.
		pidInt := f.PID
		if pidInt == 0 {
			text := extractPID(f.Details)
			if text == "" {
				continue
			}
			var parsed int
			n, err := fmt.Sscanf(text, "%d", &parsed)
			if err != nil || n != 1 || fmt.Sprintf("%d", parsed) != text {
				continue
			}
			pidInt = parsed
		}
		// Never signal PID 0/1 or a negative PID (which would address a
		// process group).
		if pidInt <= 1 {
			continue
		}
		pid := fmt.Sprintf("%d", pidInt)

		// Safety: verify the process is not root/system.
		uid := getProcessUID(pid)
		if uid == "0" || uid == "" {
			continue // never kill root processes automatically
		}

		// Safety: verify it's not a cPanel/system process.
		exe := getProcessExe(pid)
		if isSafeProcess(exe) {
			continue
		}

		if err := syscall.Kill(pidInt, syscall.SIGKILL); err != nil {
			continue
		}

		actions = append(actions, alert.Finding{
			Severity: alert.Critical,
			Check:    "auto_response",
			Message:  fmt.Sprintf("AUTO-KILL: Process %s killed (was: %s)", pid, f.Check),
			Details:  fmt.Sprintf("Original finding: %s\nProcess: %s (UID: %s)", f.Message, exe, uid),
			// Stamped like the other auto-response actions in this file.
			Timestamp: time.Now(),
		})
	}
	return actions
}
// AutoQuarantineFiles moves malicious files (or directories) named by
// critical file-based findings into quarantineDir and returns one
// "auto_response" finding per action taken. It returns nil when auto-response
// or file quarantine is disabled.
//
// Per finding:
//   - The path comes from the structured FilePath field, falling back to
//     parsing the message for legacy findings.
//   - "signature_match_realtime" findings must additionally pass
//     isHighConfidenceRealtimeMatch, and then always go straight to
//     quarantine.
//   - Symlinks and paths that cannot be Lstat'ed are skipped.
//   - For other findings on regular files, ShouldCleanInsteadOfQuarantine may
//     route the file to CleanInfectedFile. A successful clean, or a clean that
//     found nothing to change, ends processing of that finding; a failed
//     clean is reported and the file is quarantined instead.
//   - The original path, ownership, mode, size and reason are written to a
//     "<quarantine path>.meta" JSON sidecar. Sidecar problems are reported on
//     stderr but do not undo the quarantine.
func AutoQuarantineFiles(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.QuarantineFiles {
		return nil
	}

	var actions []alert.Finding
	for _, f := range findings {
		// Only quarantine specific file-based findings.
		isRealtimeMatch := false
		switch f.Check {
		case "webshell", "backdoor_binary", "new_webshell_file", "new_executable_in_config",
			"obfuscated_php", "php_dropper", "suspicious_php_content",
			"new_php_in_languages", "new_php_in_upgrade",
			"phishing_page", "phishing_directory",
			"htaccess_handler_abuse":
		case "signature_match_realtime":
			isRealtimeMatch = true
		default:
			continue
		}
		if f.Severity != alert.Critical {
			continue
		}

		// Extract file path - prefer structured field, fallback to message parsing.
		path := f.FilePath
		if path == "" {
			path = extractFilePath(f.Message) // fallback for legacy findings
		}
		if path == "" {
			continue
		}

		// Realtime signature matches require additional validation to avoid
		// quarantining false positives (e.g. legitimate PHPMailer matching
		// "webshell_marijuana", or zip libraries matching hex patterns).
		// Only quarantine when the file is genuinely obfuscated malware.
		if isRealtimeMatch && !isHighConfidenceRealtimeMatch(f, path, nil) {
			continue
		}

		// Verify file or directory exists and reject symlinks.
		info, err := osFS.Lstat(path)
		if err != nil {
			continue
		}
		if info.Mode()&os.ModeSymlink != 0 {
			continue
		}

		// For WP core/plugin/theme files: clean surgically instead of
		// quarantining, which preserves site functionality while removing the
		// injected code. Realtime high-confidence matches are fully
		// obfuscated malware with no legitimate code to preserve, so they
		// skip cleaning and go straight to quarantine.
		if !isRealtimeMatch && !info.IsDir() && ShouldCleanInsteadOfQuarantine(path) {
			result := CleanInfectedFile(path)
			switch {
			case result.Cleaned:
				actions = append(actions, alert.Finding{
					Severity:  alert.Critical,
					Check:     "auto_response",
					Message:   fmt.Sprintf("AUTO-CLEAN: %s surgically cleaned", path),
					Details:   fmt.Sprintf("Backup: %s\n%s", result.BackupPath, strings.Join(result.Removals, "\n")),
					Timestamp: time.Now(),
				})
				continue // successfully cleaned, skip quarantine
			case result.Error != "":
				// Cleaning failed - report it and fall through to quarantine.
				actions = append(actions, alert.Finding{
					Severity:  alert.Warning,
					Check:     "auto_response",
					Message:   fmt.Sprintf("AUTO-CLEAN failed for %s, quarantining instead", path),
					Details:   result.Error,
					Timestamp: time.Now(),
				})
			default:
				continue // no changes needed
			}
		}

		// Create the quarantine directory. Without it no move can succeed, so
		// report the cause instead of failing silently further down.
		if err := os.MkdirAll(quarantineDir, 0700); err != nil {
			fmt.Fprintf(os.Stderr, "autoresponse: cannot create quarantine directory %s: %v\n", quarantineDir, err)
			continue
		}

		// Build the quarantine destination; the original directory structure
		// is flattened into the file name.
		safeName := strings.ReplaceAll(path, "/", "_")
		ts := time.Now().Format("20060102-150405")
		qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_%s", ts, safeName))

		// Get file ownership.
		var uid, gid int
		if stat, ok := info.Sys().(*syscall.Stat_t); ok {
			uid = int(stat.Uid)
			gid = int(stat.Gid)
		}

		if info.IsDir() {
			// Directory quarantine (e.g., LEVIATHAN/ webshell directories).
			if err := os.Rename(path, qPath); err != nil {
				// Cross-device: skip directory move (too complex for auto-response).
				continue
			}
		} else {
			// Move file to quarantine via the TOCTOU-safe path: fd open,
			// fstat-verify, hardlink-by-fd, unlink. If hardlinking is
			// unavailable, copy from the same verified fd.
			if err := quarantineFileTOCTOUSafe(path, qPath, info); err != nil {
				fmt.Fprintf(os.Stderr, "autoresponse: refused quarantine of %s: %v\n", path, err)
				continue
			}
		}

		// Write metadata sidecar. This stays best-effort (the file is already
		// quarantined), but a failure is reported because without the sidecar
		// the original path and ownership are not recorded anywhere.
		meta := QuarantineMeta{
			OriginalPath: path,
			Owner:        uid,
			Group:        gid,
			Mode:         info.Mode().String(),
			Size:         info.Size(),
			QuarantineAt: time.Now(),
			Reason:       f.Message,
		}
		if metaData, merr := json.MarshalIndent(meta, "", " "); merr != nil {
			fmt.Fprintf(os.Stderr, "autoresponse: encode quarantine metadata for %s failed: %v\n", qPath, merr)
		} else if werr := os.WriteFile(qPath+".meta", metaData, 0600); werr != nil {
			fmt.Fprintf(os.Stderr, "autoresponse: write quarantine metadata %s.meta failed: %v\n", qPath, werr)
		}

		actions = append(actions, alert.Finding{
			Severity: alert.Critical,
			Check:    "auto_response",
			Message:  fmt.Sprintf("AUTO-QUARANTINE: %s moved to quarantine", path),
			Details:  fmt.Sprintf("Quarantined to: %s\nOriginal finding: %s", qPath, f.Message),
			// Stamped like the AUTO-CLEAN actions above.
			Timestamp: time.Now(),
		})
	}
	return actions
}
// AutoFixPermissions sets world/group-writable PHP files to 0644.
//
// Only "world_writable_php" and "group_writable_php" findings are handled.
// The target path is taken from the structured FilePath field when present
// and otherwise parsed out of the message; the message parser stops at the
// first space or comma, so the structured field is the only reliable source
// for file names containing those characters. The path is then validated by
// resolveExistingFixPath against fixPermissionsAllowedRoots; directories and
// paths that fail validation are skipped, as are files whose chmod fails.
//
// Returns the auto-response action findings and the keys of original findings
// that were successfully fixed (so the caller can dismiss them from the UI).
// Both are nil when auto-response or permission enforcement is disabled.
func AutoFixPermissions(cfg *config.Config, findings []alert.Finding) (actions []alert.Finding, fixedKeys []string) {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.EnforcePermissions {
		return nil, nil
	}

	for _, f := range findings {
		switch f.Check {
		case "world_writable_php", "group_writable_php":
		default:
			continue
		}

		// Prefer the structured field, like the other responders in this
		// file; fall back to message parsing for legacy findings.
		candidate := f.FilePath
		if candidate == "" {
			candidate = extractFilePath(f.Message)
		}
		if candidate == "" {
			continue
		}

		path, info, err := resolveExistingFixPath(candidate, fixPermissionsAllowedRoots)
		if err != nil || info.IsDir() {
			continue
		}

		oldMode := info.Mode().Perm()
		// #nosec G302 -- same as fixPermissions: restoring canonical web-content
		// mode on a user file flagged for dangerous (e.g. world-writable) perms.
		if err := os.Chmod(path, 0644); err != nil {
			continue
		}

		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-FIX: %s permissions set to 644 (was %o)", path, oldMode),
			Timestamp: time.Now(),
		})
		fixedKeys = append(fixedKeys, f.Check+":"+f.Message)
	}
	return actions, fixedKeys
}
// AutoFixWPCron applies the WP-Cron fix (disable WP-Cron, install a per-user
// system cron) for every "perf_wp_cron" finding whose Details name a
// wp-config.php path. It is gated behind an explicit opt-in because it edits
// customer wp-config.php files and crontabs.
//
// Returns one "auto_response" finding per successful fix plus the keys of the
// original findings so the caller can dismiss them. Findings without a path,
// or whose fix does not succeed, are skipped. Both results are nil when
// auto-response or the WP-Cron fix is disabled.
func AutoFixWPCron(cfg *config.Config, findings []alert.Finding) (actions []alert.Finding, fixedKeys []string) {
	if !(cfg.AutoResponse.Enabled && cfg.AutoResponse.FixWPCron) {
		return nil, nil
	}

	fixOpts := WPCronFixOptions{
		IntervalMinutes: cfg.Performance.WPCronFix.IntervalMinutes,
		PHPBin:          cfg.Performance.WPCronFix.PHPBin,
	}
	roots := ResolveWebRoots(cfg)

	for _, f := range findings {
		if f.Check != "perf_wp_cron" {
			continue
		}
		configPath := extractWPConfigPath(f.Details)
		if configPath == "" {
			continue
		}
		outcome := FixDisableWPCronInRoots(configPath, roots, fixOpts)
		if !outcome.Success {
			continue
		}

		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-FIX: %s", outcome.Description),
			Timestamp: time.Now(),
		})
		fixedKeys = append(fixedKeys, f.Key())
	}
	return actions, fixedKeys
}
// extractWPConfigPath returns the wp-config.php path embedded in a
// perf_wp_cron finding's Details, which are formatted as
// "File: <path> - add define(...)". The path is whatever sits between the
// first "File: " and the following " - " (or the end of the string), with
// surrounding whitespace removed. It returns "" when "File: " is absent.
func extractWPConfigPath(details string) string {
	_, afterMarker, found := strings.Cut(details, "File: ")
	if !found {
		return ""
	}
	pathPart, _, _ := strings.Cut(afterMarker, " - ")
	return strings.TrimSpace(pathPart)
}
// extractPID returns the token that follows the first "PID: " in details. The
// token ends at the first space, comma, newline or tab, so a trailing word
// ("PID: 42 exe=/bin/ls") is not returned as part of it. The result is not
// checked for being numeric. It returns "" when "PID: " is absent.
func extractPID(details string) string {
	const marker = "PID: "

	at := strings.Index(details, marker)
	if at < 0 {
		return ""
	}
	token := details[at+len(marker):]
	if end := strings.IndexAny(token, " ,\n\t"); end >= 0 {
		token = token[:end]
	}
	return strings.TrimSpace(token)
}
// extractFilePath pulls a filesystem path out of a free-form finding message.
//
// It recognises paths that begin with /var/tmp/, /dev/shm/, /home/ or /tmp/
// and returns the one that starts EARLIEST in the message, running up to the
// next space, comma or newline (or the end of the message). It returns ""
// when none of the prefixes occurs.
//
// Choosing by position rather than by prefix priority matters: a prefix can
// legitimately occur inside a longer path. "/tmp/" occurs inside
// "/var/tmp/x", and "/var/tmp/" occurs inside
// "/home/u/public_html/var/tmp/x". Taking the earliest match returns the
// enclosing path in both cases, whereas a fixed priority order would return
// the inner fragment, which names a different file.
func extractFilePath(message string) string {
	start := -1
	for _, prefix := range []string{"/var/tmp/", "/dev/shm/", "/home/", "/tmp/"} {
		if idx := strings.Index(message, prefix); idx >= 0 && (start < 0 || idx < start) {
			start = idx
		}
	}
	if start < 0 {
		return ""
	}

	// Path ends at space, comma, newline, or end of string.
	rest := message[start:]
	if end := strings.IndexAny(rest, " ,\n"); end >= 0 {
		rest = rest[:end]
	}
	return rest
}
// getProcessUID returns the first value on the "Uid:" line of
// /proc/<pid>/status as a string. It returns "" when the status file cannot
// be read or contains no usable "Uid:" line.
func getProcessUID(pid string) string {
	const marker = "Uid:\t"

	raw, err := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
	if err != nil {
		return ""
	}
	for _, line := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(line, marker) {
			continue
		}
		if ids := strings.Fields(line[len(marker):]); len(ids) > 0 {
			return ids[0]
		}
	}
	return ""
}
// getProcessExe returns the target of the /proc/<pid>/exe symlink, or "" when
// the link cannot be read.
func getProcessExe(pid string) string {
	target, err := osFS.Readlink(filepath.Join("/proc", pid, "exe"))
	if err == nil {
		return target
	}
	return ""
}
// isSafeProcess reports whether exe lives under one of the system, cPanel,
// CloudLinux or Imunify360 directory prefixes that auto-kill must never
// touch. The check is a plain string-prefix match on the executable path.
func isSafeProcess(exe string) bool {
	switch {
	case strings.HasPrefix(exe, "/usr/local/cpanel/"),
		strings.HasPrefix(exe, "/usr/sbin/"),
		strings.HasPrefix(exe, "/usr/bin/"),
		strings.HasPrefix(exe, "/usr/libexec/"),
		strings.HasPrefix(exe, "/opt/cpanel/"),
		strings.HasPrefix(exe, "/opt/cloudlinux/"),
		strings.HasPrefix(exe, "/opt/imunify360/"):
		return true
	}
	return false
}
// isHighConfidenceRealtimeMatch decides whether a realtime signature match is
// certain enough to auto-quarantine, so that legitimate libraries (PHPMailer,
// zip helpers) and theme code that merely trip a signature are left alone.
//
// data should be the content the caller already read (fanotify fd or scanner)
// so the file is not re-read after detection; when data is nil the content is
// read from path, and a read failure counts as "not a match".
//
// A match is high-confidence only when all of the following hold:
//  1. The finding's category is "dropper" or "webshell".
//  2. The content is at least 512 bytes (entropy is unreliable below that).
//  3. The content looks obfuscated, for either category:
//     - Shannon entropy >= 5.5, which is sufficient on its own. Ordinary
//     PHP with many constants measures 4.5-5.3 (WPML wpml_zip.php = 5.25,
//     Breakdance google-fonts.php = 4.90); packers land at 5.8+.
//     - Otherwise, hex-escape density above 20% AND an obfuscated-execution
//     signal. Density alone is not enough: a ZIP/PDF library's magic-byte
//     tables saturate it while being inert data. Judging by content rather
//     than by a trusted-path allowlist means a hex-encoded packer is still
//     caught wherever it is planted.
func isHighConfidenceRealtimeMatch(f alert.Finding, path string, data []byte) bool {
	if cat := extractCategory(f.Details); cat != "dropper" && cat != "webshell" {
		return false
	}

	payload := data
	if payload == nil {
		read, err := osFS.ReadFile(path)
		if err != nil {
			return false
		}
		payload = read
	}
	if len(payload) < 512 {
		return false
	}

	text := string(payload)
	switch {
	case shannonEntropy(text) >= 5.5:
		// Packed/encrypted payload: strong standalone signal.
		return true
	case hexEncodingDensity(text) > 0.20:
		// Long hex-string payloads (LEVIATHAN signature) have low entropy,
		// so they are caught here, but only together with an execution signal.
		return hasObfuscatedExecutionSignal(text)
	default:
		return false
	}
}
// Patterns used by the obfuscated-execution heuristics in this file.
var (
// reVariableFunctionCall matches a PHP variable immediately invoked as a
// function: "$name", optional whitespace, then "(".
reVariableFunctionCall = regexp.MustCompile(`\$[A-Za-z_]\w*\s*\(`)
// reHexEscapedStringConcat matches two double-quoted strings made up solely
// of \xNN escapes joined by the "." concatenation operator (case-insensitive).
reHexEscapedStringConcat = regexp.MustCompile(`(?i)"(?:\\x[0-9a-f]{2})+"\s*\.\s*"(?:\\x[0-9a-f]{2})+"`)
)
// hasObfuscatedExecutionSignal reports whether content shows a structural
// sign of obfuscated code execution, which separates a packed webshell from
// inert binary data such as a ZIP library's magic-byte constants.
//
// Comments are stripped first; some checks additionally run on a lower-cased
// copy with string literals removed. Any one of these signals is enough:
//   - more than 10 "goto " occurrences (LEVIATHAN-style control-flow
//     spaghetti);
//   - more than 10 concatenations of hex-escaped string literals (function
//     names assembled as "\x65"."\x76"... to dodge literal-name detection;
//     standalone hex constant tables without concatenation do not count);
//   - a variable bound to a decoder/exec primitive and later invoked;
//   - a direct call to a decoder (base64/gz/rot13/openssl/hex2bin/uudecode)
//     together with an executor: eval/assert/create_function, a literal
//     dangerous callback, or a request-scoped variable-function call.
func hasObfuscatedExecutionSignal(content string) bool {
	code := stripPHPCommentsFromCode(content)
	bare := strings.ToLower(stripPHPStringsFromCode(code))

	// Signals that stand on their own.
	standalone := countOccurrences(bare, "goto ") > 10 ||
		countHexEscapedStringConcats(code) > 10 ||
		detectVarFuncDangerousAssignment(code)
	if standalone {
		return true
	}

	// Everything below requires a decoder call to be present.
	decoders := []string{
		"base64_decode", "gzinflate", "gzuncompress", "gzdecode",
		"str_rot13", "openssl_decrypt", "hex2bin", "convert_uudecode",
	}
	if !containsDirectPHPFunctionCall(bare, decoders) {
		return false
	}

	executors := []string{
		"eval", "assert", "create_function",
	}
	return containsDirectPHPFunctionCall(bare, executors) ||
		hasLiteralCallbackExecutor(code) ||
		hasRequestScopedVariableFunctionCall(bare)
}
// countHexEscapedStringConcats returns the number of non-overlapping matches
// of reHexEscapedStringConcat in code, i.e. how many times two hex-escaped
// double-quoted strings are joined with ".".
func countHexEscapedStringConcats(code string) int {
	matches := reHexEscapedStringConcat.FindAllStringIndex(code, -1)
	return len(matches)
}
// containsDirectPHPFunctionCall reports whether codeNoStrings contains a
// standalone call to any function in names. Each name is checked as
// "<name>(" through containsStandaloneFunc.
func containsDirectPHPFunctionCall(codeNoStrings string, names []string) bool {
	for i := range names {
		call := names[i] + "("
		if containsStandaloneFunc(codeNoStrings, call) {
			return true
		}
	}
	return false
}
// hasRequestScopedVariableFunctionCall reports whether any single line of
// codeNoStrings both references a request superglobal (per
// containsRequestSuperglobal) and contains a variable-function call (per
// reVariableFunctionCall).
func hasRequestScopedVariableFunctionCall(codeNoStrings string) bool {
	lines := strings.Split(codeNoStrings, "\n")
	for i := range lines {
		if !containsRequestSuperglobal(lines[i]) {
			continue
		}
		if reVariableFunctionCall.MatchString(lines[i]) {
			return true
		}
	}
	return false
}
// hasLiteralCallbackExecutor scans PHP source byte by byte for a call of the
// form name('callback', ...) where name is one of callbackFirstArgFuncs and
// the quoted first argument is one of callbackExecNames. It returns true on
// the first such call and false when the scan reaches the end of code.
//
// NOTE(review): the exact rules for what may precede a function name live in
// canStartPHPFunctionName / canStartGlobalPHPFunction, and string parsing in
// skipPHPString / readPHPFunctionString; they are defined outside this block.
func hasLiteralCallbackExecutor(code string) bool {
for i := 0; i < len(code); i++ {
// String literals are skipped as a whole so that identifiers inside
// them are never treated as calls.
if isPHPQuote(code[i]) {
i = skipPHPString(code, i)
continue
}
nameStart := i
if code[i] == '\\' {
// A leading backslash (fully-qualified global function) is only
// accepted when an identifier follows and the position qualifies.
if i+1 >= len(code) || !isPHPIdentifierStart(code[i+1]) || !canStartGlobalPHPFunction(code, i) {
continue
}
nameStart = i + 1
} else if !isPHPIdentifierStart(code[i]) || !canStartPHPFunctionName(code, i) {
continue
}
// Consume the rest of the identifier.
nameEnd := nameStart + 1
for nameEnd < len(code) && isPHPIdentifierPart(code[nameEnd]) {
nameEnd++
}
// Names are matched case-insensitively. On every rejection below the
// scan resumes right after the identifier (the loop's i++ follows).
name := strings.ToLower(code[nameStart:nameEnd])
if _, ok := callbackFirstArgFuncs[name]; !ok {
i = nameEnd - 1
continue
}
// The identifier must be followed by "(" to be a call.
openParen := skipPHPWhitespace(code, nameEnd)
if openParen >= len(code) || code[openParen] != '(' {
i = nameEnd - 1
continue
}
// Only a quoted first argument is of interest.
firstArg := skipPHPWhitespace(code, openParen+1)
if firstArg >= len(code) || !isPHPQuote(code[firstArg]) {
i = nameEnd - 1
continue
}
callbackName, _, ok := readPHPFunctionString(code, firstArg)
if !ok {
i = nameEnd - 1
continue
}
if _, dangerous := callbackExecNames[callbackName]; dangerous {
return true
}
i = nameEnd - 1
}
return false
}
// hexEncodingDensity returns the fraction of s, measured in bytes, that is
// made up of PHP hex escape sequences of the form \xNN. An empty string has
// density 0.
//
// LEVIATHAN AES-encrypted webshells encode their payload as long hex strings:
// the repeating \x prefix pushes Shannon entropy down to ~3.5 (below normal
// PHP), while the hex density reaches 40-60%.
func hexEncodingDensity(s string) float64 {
	total := len(s)
	if total == 0 {
		return 0
	}

	escaped := 0
	for pos := 0; pos+3 < total; {
		if s[pos] == '\\' && s[pos+1] == 'x' && isHexDigit(s[pos+2]) && isHexDigit(s[pos+3]) {
			// A full four-byte escape: count it and jump past it.
			escaped += 4
			pos += 4
			continue
		}
		pos++
	}
	return float64(escaped) / float64(total)
}

// isHexDigit reports whether b is an ASCII hexadecimal digit (0-9, a-f, A-F).
func isHexDigit(b byte) bool {
	switch {
	case '0' <= b && b <= '9',
		'a' <= b && b <= 'f',
		'A' <= b && b <= 'F':
		return true
	}
	return false
}
// InlineQuarantineGated runs InlineQuarantine only when the operator's policy
// allows it. The realtime fanotify path detects malware continuously, but
// moving a file is a customer-impacting auto-response action, so it must
// honor the same master switch and quarantine opt-in as the batch
// AutoQuarantineFiles dispatcher and never act on detection alone.
//
// With a nil cfg, auto-response off, or quarantine_files off (monitor mode)
// it returns ("", false) and nothing is moved; the operator still gets the
// alert.
func InlineQuarantineGated(cfg *config.Config, f alert.Finding, path string, data []byte) (string, bool) {
	allowed := cfg != nil && cfg.AutoResponse.Enabled && cfg.AutoResponse.QuarantineFiles
	if !allowed {
		return "", false
	}
	return InlineQuarantine(f, path, data)
}
// InlineQuarantine moves a file to quarantine immediately if it passes the
// high-confidence validation gates. Called from fanotify's analyzeFile to
// quarantine malware without waiting for the 5-second batch dispatcher.
//
// The data parameter is the file content already read by the caller (avoids
// TOCTOU re-read). Pass nil to read from path.
//
// It returns the quarantine path and true if the file was quarantined, and
// ("", false) when the match is not high-confidence, the path cannot be
// Lstat'ed, the path is a symlink, the quarantine directory cannot be
// created, or the move fails. A "<quarantine path>.meta" JSON sidecar records
// the original path, ownership, mode and size; sidecar problems are reported
// on stderr but do not change the result, because the file has already been
// moved.
func InlineQuarantine(f alert.Finding, path string, data []byte) (string, bool) {
	if !isHighConfidenceRealtimeMatch(f, path, data) {
		return "", false
	}

	// Lstat (not Stat) so a symlink is seen as a symlink and rejected. Stat
	// follows the link, letting an attacker who swaps the file for a symlink
	// between detection and the move trick CSM into relocating the target.
	info, err := osFS.Lstat(path)
	if err != nil {
		return "", false
	}
	if info.Mode()&os.ModeSymlink != 0 {
		return "", false
	}

	// Without the quarantine directory no move can succeed, so report the
	// cause instead of failing silently further down.
	if err := os.MkdirAll(quarantineDir, 0700); err != nil {
		fmt.Fprintf(os.Stderr, "autoresponse: cannot create quarantine directory %s: %v\n", quarantineDir, err)
		return "", false
	}

	safeName := strings.ReplaceAll(path, "/", "_")
	ts := time.Now().Format("20060102-150405")
	qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_%s", ts, safeName))

	if info.IsDir() {
		if err := os.Rename(path, qPath); err != nil {
			return "", false
		}
	} else {
		// Move the file through the TOCTOU-safe path (fd open with O_NOFOLLOW,
		// fstat-verify the inode, hardlink-by-fd, unlink) just like the batch
		// AutoQuarantineFiles dispatcher, so a late inode/symlink swap fails
		// closed instead of relocating an attacker-chosen file.
		if err := quarantineFileTOCTOUSafe(path, qPath, info); err != nil {
			fmt.Fprintf(os.Stderr, "autoresponse: refused inline quarantine of %s: %v\n", path, err)
			return "", false
		}
	}

	// Write metadata sidecar (best-effort, but failures are reported: without
	// the sidecar the original path and ownership are not recorded anywhere).
	var uid, gid int
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		uid = int(stat.Uid)
		gid = int(stat.Gid)
	}
	meta := QuarantineMeta{
		OriginalPath: path,
		Owner:        uid,
		Group:        gid,
		Mode:         info.Mode().String(),
		Size:         info.Size(),
		QuarantineAt: time.Now(),
		Reason:       "Inline quarantine: high-confidence realtime signature match",
	}
	if metaData, merr := json.MarshalIndent(meta, "", " "); merr != nil {
		fmt.Fprintf(os.Stderr, "autoresponse: encode quarantine metadata for %s failed: %v\n", qPath, merr)
	} else if werr := os.WriteFile(qPath+".meta", metaData, 0600); werr != nil {
		fmt.Fprintf(os.Stderr, "autoresponse: write quarantine metadata %s.meta failed: %v\n", qPath, werr)
	}

	return qPath, true
}
// extractCategory returns the value of the first line in details that starts
// with "Category: ". The value is everything after the prefix on that line,
// untrimmed. It returns "" when no such line exists.
func extractCategory(details string) string {
	const marker = "Category: "

	lines := strings.Split(details, "\n")
	for i := range lines {
		if !strings.HasPrefix(lines[i], marker) {
			continue
		}
		return lines[i][len(marker):]
	}
	return ""
}
// AutoCleanHtaccess runs the hardened .htaccess cleaner for every finding
// produced by the hardened .htaccess detectors (see
// isHtaccessHardenedFinding). It returns nil when the auto-response pipeline
// or AutoResponse.CleanHtaccess is disabled.
//
// Unlike AutoQuarantineFiles, this bypasses the quarantine/clean fork:
// .htaccess files are infrastructure, and moving them to
// /opt/csm/quarantine breaks the site. Each invocation backs up the original
// to /opt/csm/quarantine/pre_clean/<ts>_* inside CleanHtaccessFile before
// atomic-replacing.
//
// Each file is cleaned at most once per call, however many findings point at
// it. A successful clean yields a Critical action; a failure yields a Warning
// action unless the error merely says there were no malicious directives.
func AutoCleanHtaccess(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.CleanHtaccess {
		return nil
	}

	var actions []alert.Finding
	handled := make(map[string]struct{})
	for _, f := range findings {
		if !isHtaccessHardenedFinding(f.Check) {
			continue
		}

		target := f.FilePath
		if target == "" {
			target = extractFilePath(f.Message)
		}
		if target == "" {
			continue
		}

		// One Clean per file per autoresponse pass: multiple detector
		// findings on the same file converge on a single cleaning call
		// (CleanHtaccessFile re-runs every detector).
		if _, dup := handled[target]; dup {
			continue
		}
		handled[target] = struct{}{}

		result := CleanHtaccessFile(target)
		switch {
		case result.Success:
			actions = append(actions, alert.Finding{
				Severity:  alert.Critical,
				Check:     "auto_response",
				Message:   fmt.Sprintf("AUTO-CLEAN: %s hardened directives removed", target),
				Details:   result.Description,
				Timestamp: time.Now(),
			})
		case result.Error != "" && !strings.Contains(result.Error, "no malicious directives"):
			actions = append(actions, alert.Finding{
				Severity:  alert.Warning,
				Check:     "auto_response",
				Message:   fmt.Sprintf("AUTO-CLEAN failed: %s", target),
				Details:   result.Error,
				Timestamp: time.Now(),
			})
		}
	}
	return actions
}
// isHtaccessHardenedFinding reports whether check is one of the hardened
// .htaccess detector names that AutoCleanHtaccess acts on.
func isHtaccessHardenedFinding(check string) bool {
	hardened := [...]string{
		"htaccess_auto_prepend",
		"htaccess_errordocument_hijack",
		"htaccess_filesmatch_shield",
		"htaccess_header_injection",
		"htaccess_php_in_uploads",
		"htaccess_spam_redirect",
		"htaccess_user_agent_cloak",
	}
	for i := range hardened {
		if hardened[i] == check {
			return true
		}
	}
	return false
}
package checks
import (
"fmt"
"net"
"sync/atomic"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// asnLookupFunc resolves an IP to its autonomous system number and
// organization name.
type asnLookupFunc func(ip string) (asn uint, org string)

// asnLookupHolder wraps a resolver so it can be published through an
// atomic.Pointer.
type asnLookupHolder struct {
	fn asnLookupFunc
}

// asnLookup holds the resolver that maps an IP to its autonomous system
// number and organization via the GeoLite2-ASN database. The daemon injects
// it at startup/reload; nil (no ASN database, or unit tests that do not
// exercise the live path) disables bad-ASN classification on the outbound
// connection scan.
var asnLookup atomic.Pointer[asnLookupHolder]

// SetASNLookup wires the GeoLite2-ASN resolver used by the outbound
// connection scan. Passing nil clears it.
func SetASNLookup(fn func(ip string) (asn uint, org string)) {
	var holder *asnLookupHolder
	if fn != nil {
		holder = &asnLookupHolder{fn: fn}
	}
	asnLookup.Store(holder)
}

// CurrentASNLookup returns the wired ASN resolver, or nil when none is set.
// Both the polling connection scan and the live BPF connection evaluator
// use it so bad-ASN classification behaves identically on either path.
func CurrentASNLookup() func(ip string) (asn uint, org string) {
	if holder := asnLookup.Load(); holder != nil {
		return holder.fn
	}
	return nil
}
// EvaluateBadASNOutbound classifies the destination of one outbound
// connection by autonomous system and returns a High "bad_asn_outbound"
// finding with true when the ASN is bad, or a zero finding with false
// otherwise. It performs no IO: the caller passes the destination IP together
// with the ASN and organization already resolved from the GeoLite2-ASN
// database. The result is the third leg of the host-takeover chain
// correlator.
//
// Classification (see asnIsBad):
//   - blocked_asns always flags (e.g. known bulletproof hosters);
//   - when allowed_asns is non-empty, any ASN outside it flags (allowlist
//     mode for hosts whose legitimate egress is confined to a few providers).
//
// Nothing is reported when cfg is nil or the check is disabled; when the
// destination is nil, private, loopback, link-local or unspecified (ASN
// lookup is meaningless for those); or when asn is 0, because classifying
// "no AS found" would flag every destination missing from the ASN database.
//
// IPv6 destinations are bracketed in Message and Details; SourceIP carries
// the plain address.
func EvaluateBadASNOutbound(cfg *config.Config, dstIP net.IP, asn uint, asOrg string) (alert.Finding, bool) {
	if cfg == nil || !cfg.Detection.BadASNOutbound.Enabled {
		return alert.Finding{}, false
	}

	classifiable := dstIP != nil &&
		!dstIP.IsLoopback() &&
		!dstIP.IsUnspecified() &&
		!dstIP.IsPrivate() &&
		!dstIP.IsLinkLocalUnicast() &&
		!dstIP.IsLinkLocalMulticast()
	if !classifiable || asn == 0 || !asnIsBad(cfg, asn) {
		return alert.Finding{}, false
	}

	orgName := asOrg
	if orgName == "" {
		orgName = "unknown organization"
	}

	plain := dstIP.String()
	shown := plain
	if dstIP.To4() == nil {
		shown = "[" + plain + "]"
	}

	finding := alert.Finding{
		Severity: alert.High,
		Check:    "bad_asn_outbound",
		Message:  fmt.Sprintf("Outbound connection to bad ASN %d (%s): %s", asn, orgName, shown),
		Details: fmt.Sprintf("Destination: %s\nASN: %d (%s)\n"+
			"Combined with a new uid-0 account or a planted suid binary this escalates to a host takeover.",
			shown, asn, orgName),
		SourceIP: plain,
	}
	return finding, true
}
// asnIsBad applies the blocklist-then-allowlist policy to a single ASN: an
// ASN in BlockedASNs is always bad; otherwise, when AllowedASNs is non-empty,
// an ASN is bad exactly when it is missing from that list. With an empty
// allowlist only blocklisted ASNs are bad.
func asnIsBad(cfg *config.Config, asn uint) bool {
	blocked := cfg.Detection.BadASNOutbound.BlockedASNs
	for i := range blocked {
		if blocked[i] == asn {
			return true
		}
	}

	allowed := cfg.Detection.BadASNOutbound.AllowedASNs
	if len(allowed) == 0 {
		return false
	}
	inAllowlist := false
	for i := range allowed {
		if allowed[i] == asn {
			inAllowlist = true
			break
		}
	}
	return !inAllowlist
}
package checks
import (
"context"
"fmt"
"net"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// domlogDiscoveryDropped counts per-vhost log paths the discovery helper
// dropped silently (broken symlink, Stat failure). Same family of
// hidden-input bug as the lex-order issue scanDomlogs already fixed:
// without telemetry, operators have no way to notice when discovery
// loses a sizable fraction of the vhosts they expected to scan.
//
// The counter is created and registered lazily, on the first call to
// observeDomlogDrop, and must only be touched through that function.
var (
// domlogDiscoveryDropped is nil until observeDomlogDrop has run once.
domlogDiscoveryDropped *metrics.CounterVec
// domlogDiscoveryDroppedOnce guards the one-time creation and registration.
domlogDiscoveryDroppedOnce sync.Once
)
// observeDomlogDrop increments the domlog discovery drop counter for the
// given reason label. The counter vector is created and registered on the
// first call; later calls only increment.
func observeDomlogDrop(reason string) {
	const metricName = "csm_checks_domlog_discovery_dropped_total"

	domlogDiscoveryDroppedOnce.Do(func() {
		labels := []string{"reason"}
		domlogDiscoveryDropped = metrics.NewCounterVec(
			metricName,
			"Per-vhost access-log paths the WP brute-force domlog discovery helper dropped before scanning. Labels: reason (evalsymlinks_error|stat_error). Steady growth means a chunk of vhosts is being silently skipped each cycle -- usually a broken symlink farm or a permissions regression on the log directory. Stale-mtime drops are intentional filtering, not counted here.",
			labels,
		)
		metrics.MustRegister(metricName, domlogDiscoveryDropped)
	})

	counter := domlogDiscoveryDropped.With(reason)
	counter.Inc()
}
// Built-in thresholds and scan limits for the log-based brute-force checks.
// NOTE(review): only CheckWPBruteForce is visible next to these constants;
// the ftp/webmail/api thresholds are presumably consumed by sibling checks --
// confirm against their callers.
const (
wpLoginThreshold = 20 // attempts per IP across all logs
// NOTE(review): presumably per-IP counts like wpLoginThreshold -- confirm.
xmlrpcThreshold = 30
ftpFailThreshold = 10
webmailThreshold = 10
apiFailThreshold = 10
// domlogTailLines is the built-in default for how many trailing lines
// to read from each domlog. Operators can override via
// cfg.Thresholds.DomlogTailLines. 500 covers ~10 minutes of traffic
// on a busy site.
domlogTailLines = 500
// domlogMaxAge skips domlogs not modified recently (inactive sites).
domlogMaxAge = 30 * time.Minute
// domlogMaxFiles caps the number of domlogs scanned per cycle
// to prevent unbounded I/O on servers with thousands of domains.
domlogMaxFiles = 500
)
// CheckWPBruteForce detects brute-force attacks against wp-login.php and
// xmlrpc.php by scanning access logs.
//
// Two sources feed one shared set of per-IP statistics:
//  1. Per-domain domlogs, always scanned: on LiteSpeed+cPanel, virtual host
//     traffic only appears there. SSL and non-SSL logs are both covered
//     because attackers may use plain HTTP.
//  2. The central access log, as a supplement for non-vhost traffic. Only the
//     first configured path that yields any lines is read. On LiteSpeed it
//     mostly holds WHM/server-level requests; on Apache it duplicates domlog
//     data, and the minor double-counting is acceptable since thresholds are
//     high enough.
//
// Counts are aggregated per IP across ALL domains, which catches attackers
// who spread requests over many sites to stay under per-site thresholds.
//
// The number of central-log lines read is cfg.Thresholds.BruteForceWindow,
// defaulting to 5000 when unset or non-positive. Findings carrying the
// generic legacy Details text get it replaced with the number of per-vhost
// logs actually scanned.
func CheckWPBruteForce(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	tailWindow := cfg.Thresholds.BruteForceWindow
	if tailWindow <= 0 {
		tailWindow = 5000
	}

	stats := newDomlogStats()

	// Primary source: per-vhost domlogs.
	scanned := scanDomlogsStats(ctx, cfg, stats)

	// Supplement: the first central access log that has content.
	for _, logPath := range platform.Detect().AccessLogPaths {
		lines := tailFile(logPath, tailWindow)
		if len(lines) == 0 {
			continue
		}
		for _, line := range lines {
			if rec, ok := parseAccessLogRecord(line); ok {
				stats.scan(rec, cfg, currentBotClassifier(cfg))
			}
		}
		break
	}

	// Replace the generic legacy Details with the actual scanned-file count.
	const legacyDetails = "Aggregated across per-vhost access logs"
	findings := stats.emit(cfg)
	for i := range findings {
		if findings[i].Details != legacyDetails {
			continue
		}
		findings[i].Details = "Aggregated across " + itoa(scanned) + " per-vhost access logs"
	}
	return findings
}
// knownCentralAccessLogPaths lists central web-server log paths that may
// appear in broad per-vhost glob patterns. CheckWPBruteForce tails the
// central log on its own pass, so per-vhost discovery filters these paths
// out to avoid counting the same lines twice. centralAccessLogSet seeds its
// exclusion set from this list.
var knownCentralAccessLogPaths = []string{
"/var/log/apache2/access.log",
"/var/log/apache2/access_log",
"/var/log/httpd/access.log",
"/var/log/httpd/access_log",
"/var/log/nginx/access.log",
"/usr/local/lsws/logs/access.log",
}
// discoverFreshDomlogs returns the per-vhost access-log paths that should
// be tailed this cycle.
//
// Pipeline: expand platform.DomlogGlobs, resolve each match through its
// symlinks and dedupe on the resolved path, skip the central logs (they are
// scanned separately), skip files whose mtime is older than maxAge, order
// the survivors newest-first (path as tie-breaker), then keep at most
// maxFiles of them.
//
// Ranking by mtime before applying the cap is what keeps the selection
// fair: a cap on lexical glob order would permanently hide domains that
// sort late. maxFiles <= 0 and maxAge <= 0 select the built-in
// domlogMaxFiles / domlogMaxAge defaults. A nil ctx is treated as
// context.Background(); a cancelled ctx makes the function return nil at
// the next checkpoint.
//
// scanDomlogs and scanDomlogsStats both go through here so their path
// selection cannot drift apart.
func discoverFreshDomlogs(ctx context.Context, maxFiles int, maxAge time.Duration) []string {
	if ctx == nil {
		ctx = context.Background()
	}
	if maxFiles <= 0 {
		maxFiles = domlogMaxFiles
	}
	if maxAge <= 0 {
		maxAge = domlogMaxAge
	}
	if ctx.Err() != nil {
		return nil
	}

	detected := platform.Detect()
	central := centralAccessLogSet(detected.AccessLogPaths)

	var candidates []string
	for _, pattern := range detected.DomlogGlobs {
		if ctx.Err() != nil {
			return nil
		}
		matches, _ := osFS.Glob(pattern)
		candidates = append(candidates, matches...)
	}

	type rankedLog struct {
		path     string
		modified time.Time
	}
	ranked := make([]rankedLog, 0, len(candidates))
	visited := make(map[string]bool)
	oldest := time.Now().Add(-maxAge)

	for _, candidate := range candidates {
		if ctx.Err() != nil {
			return nil
		}

		// cPanel points the SSL and non-SSL names at one backing file, so
		// identity is the resolved path, not the glob match.
		resolved, err := filepath.EvalSymlinks(candidate)
		if err != nil {
			observeDomlogDrop("evalsymlinks_error")
			continue
		}
		if visited[resolved] || central[resolved] {
			continue
		}
		visited[resolved] = true

		// Age filtering happens before the cap so idle sites cannot use
		// up the budget meant for active ones.
		info, err := osFS.Stat(resolved)
		if err != nil {
			observeDomlogDrop("stat_error")
			continue
		}
		modified := info.ModTime()
		if modified.Before(oldest) {
			continue
		}
		ranked = append(ranked, rankedLog{path: resolved, modified: modified})
	}

	if ctx.Err() != nil {
		return nil
	}

	sort.Slice(ranked, func(a, b int) bool {
		x, y := ranked[a], ranked[b]
		if !x.modified.Equal(y.modified) {
			return x.modified.After(y.modified)
		}
		return x.path < y.path
	})
	if len(ranked) > maxFiles {
		ranked = ranked[:maxFiles]
	}

	out := make([]string, 0, len(ranked))
	for i := range ranked {
		out = append(out, ranked[i].path)
	}
	return out
}
// centralAccessLogSet builds the set of paths that per-vhost discovery must
// skip because they are central logs. It always contains every entry of
// knownCentralAccessLogPaths. The configured list is only honoured when it
// holds exactly one path: with several entries they are fallback
// candidates, CheckWPBruteForce tails just the first one that has data, and
// excluding all of them would hide logs that never get a central scan.
func centralAccessLogSet(configured []string) map[string]bool {
	set := make(map[string]bool, len(knownCentralAccessLogPaths)+len(configured))
	for i := range knownCentralAccessLogPaths {
		addCentralAccessLog(set, knownCentralAccessLogPaths[i])
	}
	if len(configured) != 1 {
		return set
	}
	addCentralAccessLog(set, configured[0])
	return set
}
// addCentralAccessLog marks path in out, plus its symlink-resolved form
// when resolution succeeds. An empty path is ignored. Recording both
// spellings lets callers match on either the literal or the real path.
func addCentralAccessLog(out map[string]bool, path string) {
	if path == "" {
		return
	}
	out[path] = true

	resolved, err := filepath.EvalSymlinks(path)
	if err != nil {
		return
	}
	out[resolved] = true
}
// effectiveDomlogTailLines resolves how many trailing lines to read per
// domlog: cfg.Thresholds.DomlogTailLines when it is positive, otherwise
// (including a nil cfg) the built-in domlogTailLines default.
func effectiveDomlogTailLines(cfg *config.Config) int {
	if cfg != nil {
		if configured := cfg.Thresholds.DomlogTailLines; configured > 0 {
			return configured
		}
	}
	return domlogTailLines
}
// effectiveDomlogMaxAge resolves the domlog staleness cutoff:
// cfg.Thresholds.DomlogMaxAgeMin interpreted as minutes when it is
// positive, otherwise (including a nil cfg) the built-in domlogMaxAge
// default.
func effectiveDomlogMaxAge(cfg *config.Config) time.Duration {
	if cfg != nil {
		if minutes := cfg.Thresholds.DomlogMaxAgeMin; minutes > 0 {
			return time.Duration(minutes) * time.Minute
		}
	}
	return domlogMaxAge
}
// tailDomlogsInto reads the last tailLines lines of every path, parses each
// line as an access-log record, tags the record with the vhost derived from
// the file path, and hands it to stats. Unparseable lines are skipped.
//
// The return value is the number of files tailed. ctx (when non-nil) is
// consulted before each file, so a cancellation stops the pass early and
// the count reflects only the files handled up to that point.
//
// Every domlog scanner funnels through this loop so the cancellation gate,
// the parse-or-skip rule and the counter stay identical across callers.
func tailDomlogsInto(ctx context.Context, paths []string, cfg *config.Config, stats *domlogStats, classifier botClassifier, tailLines int) int {
	tailed := 0
	for i := range paths {
		if ctx != nil && ctx.Err() != nil {
			return tailed
		}
		logPath := paths[i]
		vhost := domainFromDomlogPath(logPath)
		lines := tailFile(logPath, tailLines)
		for j := range lines {
			rec, ok := parseAccessLogRecord(lines[j])
			if ok {
				rec.Domain = vhost
				stats.scan(rec, cfg, classifier)
			}
		}
		tailed++
	}
	return tailed
}
// domainFromDomlogPath maps a per-domain log path to its vhost name.
// It tries, in order: the Plesk layout (domain taken from the directory
// structure), a known log-file suffix stripped from the file name, and
// finally the bare file name. Whatever candidate wins is validated by
// cleanDomlogDomain, so paths that do not look like a domain log (the
// central access log, odd file names) yield "" and stay out of the per-IP
// vhost set.
func domainFromDomlogPath(p string) string {
	name := filepath.Base(p)
	candidate := name
	if domain, ok := pleskDomlogDomain(p, name); ok {
		candidate = domain
	} else if domain, ok := trimDomlogSuffix(name); ok {
		candidate = domain
	}
	return cleanDomlogDomain(candidate)
}
// pleskDomlogDomain handles the Plesk layout, where every vhost uses the
// same fixed log file names and the domain is the directory two levels
// above the file (<domain>/logs/<file>).
//
// p is the full log path and base its final element. When base
// (case-insensitively, ignoring surrounding whitespace) is one of the
// fixed Plesk access-log names, the grandparent directory name is returned
// with ok=true; otherwise ("", false).
//
// Both the Apache pair (access_log / access_ssl_log) and the nginx-proxy
// pair (proxy_access_log / proxy_access_ssl_log) are recognised, so the
// HTTP and HTTPS variants of each are attributed to the same vhost.
func pleskDomlogDomain(p, base string) (string, bool) {
	switch strings.ToLower(strings.TrimSpace(base)) {
	case "access_log", "access_ssl_log", "proxy_access_log", "proxy_access_ssl_log":
		return filepath.Base(filepath.Dir(filepath.Dir(p))), true
	default:
		return "", false
	}
}
// trimDomlogSuffix strips a recognised access-log suffix from a log file
// name. Matching is case-insensitive and ignores surrounding whitespace;
// the returned string keeps the caller's original casing. The suffix list
// is ordered most-specific first so ".access.log" wins over ".log".
// When no suffix matches, the whitespace-trimmed name is returned with
// ok=false.
func trimDomlogSuffix(base string) (string, bool) {
	name := strings.TrimSpace(base)
	lowered := strings.ToLower(name)

	suffixes := [...]string{
		".access.log",
		"-access.log",
		"_access.log",
		"-access_log",
		"_access_log",
		"-ssl_log",
		"_log",
		".log",
	}
	for i := range suffixes {
		if strings.HasSuffix(lowered, suffixes[i]) {
			cut := len(name) - len(suffixes[i])
			return name[:cut], true
		}
	}
	return name, false
}
// cleanDomlogDomain normalises a candidate vhost name and rejects anything
// that is not a plausible multi-label hostname. The input is trimmed and
// lower-cased; "" is returned when it is empty, longer than 253 bytes, has
// no dot, has an empty label (leading/trailing/double dot), parses as an IP
// address, or contains a label that is longer than 63 bytes, starts or ends
// with '-', or uses a character outside [a-z0-9-]. Otherwise the
// normalised name is returned.
func cleanDomlogDomain(domain string) string {
	host := strings.ToLower(strings.TrimSpace(domain))

	switch {
	case host == "", len(host) > 253, !strings.Contains(host, "."):
		return ""
	case strings.HasPrefix(host, "."), strings.HasSuffix(host, "."), strings.Contains(host, ".."):
		return ""
	case net.ParseIP(host) != nil:
		return ""
	}

	notHostRune := func(r rune) bool {
		isLetter := r >= 'a' && r <= 'z'
		isDigit := r >= '0' && r <= '9'
		return !(isLetter || isDigit || r == '-')
	}
	validLabel := func(label string) bool {
		if label == "" || len(label) > 63 {
			return false
		}
		if label[0] == '-' || label[len(label)-1] == '-' {
			return false
		}
		return strings.IndexFunc(label, notHostRune) < 0
	}

	for _, label := range strings.Split(host, ".") {
		if !validLabel(label) {
			return ""
		}
	}
	return host
}
// scanDomlogsStats is the production domlog scan used by CheckWPBruteForce.
// It discovers fresh per-vhost logs using the operator's max-files and
// max-age thresholds, tails each one with the configured line budget and
// the current bot classifier, and feeds every parsed record into stats.
// A nil cfg is replaced by an empty config so all thresholds fall back to
// their defaults. Returns the number of files actually tailed.
func scanDomlogsStats(ctx context.Context, cfg *config.Config, stats *domlogStats) int {
	effective := cfg
	if effective == nil {
		effective = &config.Config{}
	}
	maxAge := effectiveDomlogMaxAge(effective)
	paths := discoverFreshDomlogs(ctx, effective.Thresholds.DomlogMaxFiles, maxAge)
	classifier := currentBotClassifier(effective)
	tailLines := effectiveDomlogTailLines(effective)
	return tailDomlogsInto(ctx, paths, effective, stats, classifier, tailLines)
}
// scanDomlogs is the legacy infra-IPs-only entry kept for test fixtures
// that drive the brute-force counters directly. Production code calls
// scanDomlogsStats. Both share discoverFreshDomlogs + tailDomlogsInto so
// path selection and per-file ctx semantics cannot diverge.
//
// maxFiles <= 0 falls back to the built-in domlogMaxFiles default.
func scanDomlogs(ctx context.Context, infraIPs []string, maxFiles int, wpLogin, xmlrpc, userEnum map[string]int) int {
cfg := &config.Config{InfraIPs: infraIPs}
stats := newDomlogStats()
stats.wpLogin = wpLogin
stats.xmlrpc = xmlrpc
stats.userEnum = userEnum
paths := discoverFreshDomlogs(ctx, maxFiles, 0)
return tailDomlogsInto(ctx, paths, cfg, stats, nopBotClassifier{}, domlogTailLines)
}
// countBruteForce parses Combined Log Format lines and increments per-IP
// counters via the shared domlogStats aggregator. Kept as a thin shim
// for tests that feed lines directly (no file discovery / tail step).
func countBruteForce(lines []string, infraIPs []string, wpLogin, xmlrpc, userEnum map[string]int) {
cfg := &config.Config{InfraIPs: infraIPs}
stats := newDomlogStats()
stats.wpLogin = wpLogin
stats.xmlrpc = xmlrpc
stats.userEnum = userEnum
for _, line := range lines {
rec, ok := parseAccessLogRecord(line)
if !ok {
continue
}
stats.scan(rec, cfg, nopBotClassifier{})
}
}
// syslogMessagesTailLinesDefault is the built-in fallback for how many
// trailing lines of /var/log/messages CheckFTPLogins tails per cycle.
// It applies when cfg is nil or the override is not positive.
// Operator override: cfg.Thresholds.SyslogMessagesTailLines.
const syslogMessagesTailLinesDefault = 200
// CheckFTPLogins scans the tail of /var/log/messages for pure-ftpd activity
// and reports:
//
//   - ftp_login: every "is now logged in" line whose IP is not an infra IP;
//   - ftp_bruteforce: every non-infra IP with at least ftpFailThreshold
//     authentication failures in the tailed window.
//
// The number of lines tailed is cfg.Thresholds.SyslogMessagesTailLines when
// positive, otherwise syslogMessagesTailLinesDefault. A nil cfg is treated
// as an empty config (default tail size, no infra IPs). Returns nil when
// the log yields no lines. Brute-force findings are emitted in ascending
// IP-string order so the output is stable between runs.
func CheckFTPLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	// The tail-size override was already nil-safe; make the infra-IP
	// lookups below nil-safe as well instead of panicking.
	if cfg == nil {
		cfg = &config.Config{}
	}
	tailLines := syslogMessagesTailLinesDefault
	if cfg.Thresholds.SyslogMessagesTailLines > 0 {
		tailLines = cfg.Thresholds.SyslogMessagesTailLines
	}

	lines := tailFile("/var/log/messages", tailLines)
	if len(lines) == 0 {
		return nil
	}

	var findings []alert.Finding
	failedFTP := make(map[string]int)
	for _, line := range lines {
		// pure-ftpd logs: "pure-ftpd: ... [WARNING] Authentication failed for user"
		if !strings.Contains(line, "pure-ftpd") {
			continue
		}

		if strings.Contains(line, "Authentication failed") || strings.Contains(line, "auth failed") {
			if ip := extractIPFromLog(line); ip != "" && !isInfraIP(ip, cfg.InfraIPs) {
				failedFTP[ip]++
			}
		}

		// Successful FTP login from non-infra
		if strings.Contains(line, "is now logged in") {
			if ip := extractIPFromLog(line); ip != "" && !isInfraIP(ip, cfg.InfraIPs) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "ftp_login",
					Message:  fmt.Sprintf("FTP login from non-infra IP: %s", ip),
					Details:  truncate(line, 200),
				})
			}
		}
	}

	// Map iteration order is random; sort the offenders for stable output.
	offenders := make([]string, 0, len(failedFTP))
	for ip, count := range failedFTP {
		if count >= ftpFailThreshold {
			offenders = append(offenders, ip)
		}
	}
	sort.Strings(offenders)
	for _, ip := range offenders {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "ftp_bruteforce",
			Message:  fmt.Sprintf("FTP brute force from %s: %d failed attempts", ip, failedFTP[ip]),
		})
	}
	return findings
}
// CheckWebmailLogins tails the cPanel access log and reports
// webmail_bruteforce for every non-infra, non-loopback client IP that made
// at least webmailThreshold login-looking POSTs on a webmail port.
//
// A line counts when it mentions 2095 or 2096 (webmail HTTP/HTTPS),
// contains "POST", and contains "login" or "auth". The check is skipped
// (nil) when webmail findings are suppressed in cfg or the platform is not
// cPanel, since both the ports and the log path are cPanel-specific.
func CheckWebmailLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if cfg.Suppressions.SuppressWebmail || !platform.Detect().IsCPanel() {
		return nil
	}

	attemptsByIP := make(map[string]int)
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 300) {
		// Webmail ports: 2095 (HTTP), 2096 (HTTPS)
		onWebmailPort := strings.Contains(entry, "2095") || strings.Contains(entry, "2096")
		if !onWebmailPort {
			continue
		}

		tokens := strings.Fields(entry)
		if len(tokens) == 0 {
			continue
		}
		client := tokens[0]
		if isInfraIP(client, cfg.InfraIPs) || client == "127.0.0.1" {
			continue
		}

		looksLikeLogin := strings.Contains(entry, "login") || strings.Contains(entry, "auth")
		if strings.Contains(entry, "POST") && looksLikeLogin {
			attemptsByIP[client]++
		}
	}

	var findings []alert.Finding
	for ip, attempts := range attemptsByIP {
		if attempts < webmailThreshold {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "webmail_bruteforce",
			Message:  fmt.Sprintf("Webmail brute force from %s: %d attempts", ip, attempts),
		})
	}
	return findings
}
// CheckAPIAuthFailures tails the cPanel access log and reports
// api_auth_failure for every non-infra, non-loopback client IP with at
// least apiFailThreshold rejected API requests.
//
// A line counts when it carries a 401 or 403 status and targets an API
// endpoint (contains "json-api", "/execute/" or "cpsess"). Returns nil on
// non-cPanel platforms because the log is cPanel-specific.
func CheckAPIAuthFailures(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !platform.Detect().IsCPanel() {
		return nil
	}

	failuresByIP := make(map[string]int)
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 300) {
		// Look for 401/403 responses on API endpoints
		rejected := strings.Contains(entry, "\" 401 ") || strings.Contains(entry, "\" 403 ")
		if !rejected {
			continue
		}
		isAPI := strings.Contains(entry, "json-api") ||
			strings.Contains(entry, "/execute/") ||
			strings.Contains(entry, "cpsess")
		if !isAPI {
			continue
		}

		tokens := strings.Fields(entry)
		if len(tokens) == 0 {
			continue
		}
		client := tokens[0]
		if isInfraIP(client, cfg.InfraIPs) || client == "127.0.0.1" {
			continue
		}
		failuresByIP[client]++
	}

	var findings []alert.Finding
	for ip, failures := range failuresByIP {
		if failures < apiFailThreshold {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "api_auth_failure",
			Message:  fmt.Sprintf("cPanel API auth failures from %s: %d attempts", ip, failures),
			Details:  "Possible API token brute force or unauthorized API access",
		})
	}
	return findings
}
// extractIPFromLog returns the first IPv4 address found among the
// whitespace-separated fields of a log line, or "" when there is none.
//
// A field is a candidate when it starts with a digit and contains exactly
// three dots. Trailing punctuation and a trailing ":port" are stripped, and
// the remainder must parse as a real IPv4 address. Candidates that fail
// validation (version strings such as "5.14.0.284-el9", out-of-range
// octets) are skipped so they cannot shadow a genuine address later in the
// line.
//
// NOTE(review): fields where the address is wrapped in leading punctuation,
// e.g. "(user@1.2.3.4)", are still not recognised because the field must
// start with a digit -- confirm whether any consumed log format needs that.
func extractIPFromLog(line string) string {
	for _, field := range strings.Fields(line) {
		// Cheap shape filter before the real parse.
		if len(field) < 7 || field[0] < '0' || field[0] > '9' || strings.Count(field, ".") != 3 {
			continue
		}

		// Strip trailing punctuation
		candidate := strings.TrimRight(field, ",:;)([]")
		// IPv4 text never contains ':', so anything after it is a port.
		if colon := strings.IndexByte(candidate, ':'); colon >= 0 {
			candidate = candidate[:colon]
		}

		if ip := net.ParseIP(candidate); ip == nil || ip.To4() == nil {
			continue
		}
		return candidate
	}
	return ""
}
package checks
import (
"fmt"
"os"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// ChallengeIPList abstracts the challenge IP list for routing.
// ChallengeRouteIPs is the consumer in this file: it calls Contains to skip
// IPs that are already listed and Add to enrol new ones.
type ChallengeIPList interface {
// Add enrols ip on the challenge list. ChallengeRouteIPs passes the
// triggering finding's message as reason and challengeDuration as
// duration.
Add(ip string, reason string, duration time.Duration)
// Contains reports whether ip is currently on the challenge list.
Contains(ip string) bool
}
// challengeIPList is the package-wide challenge list. It is nil until
// SetChallengeIPList installs one, and ChallengeRouteIPs is a no-op while
// it is nil.
// NOTE(review): reads and writes are not synchronised here -- presumably it
// is installed once during startup; confirm against the caller.
var challengeIPList ChallengeIPList
// SetChallengeIPList sets the challenge IP list for routing.
// Passing nil clears it, which disables ChallengeRouteIPs.
func SetChallengeIPList(list ChallengeIPList) {
challengeIPList = list
}
// GetChallengeIPList returns the current challenge IP list (for AutoBlockIPs skip check).
// The result is nil when no list has been installed.
func GetChallengeIPList() ChallengeIPList {
return challengeIPList
}
// hardBlockChecks are exact check names that must NEVER be routed to challenge.
// isHardBlockCheck consults this set first and then hardBlockPrefixes;
// ChallengeRouteIPs skips any finding for which it returns true.
var hardBlockChecks = map[string]bool{
"signature_match_realtime": true,
"yara_match_realtime": true,
"webshell": true,
"backdoor_binary": true,
"cross_account_malware": true,
"c2_connection": true,
"backdoor_port": true,
"backdoor_port_outbound": true,
"exfiltration_paste_site": true,
"htaccess_injection": true,
"htaccess_handler_abuse": true,
"db_siteurl_hijack": true,
"db_options_injection": true,
"db_post_injection": true,
"db_rogue_admin": true,
"db_spam_injection": true,
"phishing_page": true,
"phishing_iframe": true,
"phishing_php": true,
"phishing_redirector": true,
"phishing_credential_log": true,
"phishing_kit_archive": true,
"phishing_directory": true,
"php_shield_webshell": true,
"php_shield_block": true,
"php_shield_eval": true,
"suspicious_crontab": true,
"suspicious_process": true,
"fake_kernel_thread": true,
"php_suspicious_execution": true,
"suspicious_file": true,
"password_hijack_confirmed": true,
"symlink_attack": true,
"shadow_change": true,
"root_password_change": true,
"coordinated_attack": true,
"database_dump": true,
"kernel_module": true,
"uid0_account": true,
"suid_binary": true,
"rpm_integrity": true,
"email_spam_outbreak": true,
"modsec_csm_block_escalation": true,
"api_auth_failure_realtime": true, // cPanel API brute force — challenge is useless, hard-block
"ftp_auth_failure_realtime": true, // FTP brute force — can't challenge non-HTTP
"credential_stuffing": true, // PAM breadth signal — can't challenge non-HTTP
"pam_bruteforce": true, // PAM brute force — can't challenge non-HTTP
"smtp_bruteforce": true, // SMTP brute force — can't challenge non-HTTP
"smtp_probe_abuse": true, // SMTP probe abuse (connect-rate) — can't challenge non-HTTP
"smtp_subnet_spray": true, // SMTP subnet spray — can't challenge non-HTTP
"mail_bruteforce": true, // Mail brute force — can't challenge non-HTTP
"mail_subnet_spray": true, // Mail subnet spray — can't challenge non-HTTP
"mail_account_compromised": true, // Mail account compromise — instant block, zero-FP signal
"admin_panel_bruteforce": true, // Admin panel brute force — tight path set makes FP near-impossible
"waf_attack_blocked": true, // WAF already blocked repeated attacks; keep auto-block path direct
}
// hardBlockPrefixes match any check name starting with these strings.
// isHardBlockCheck tests them with strings.HasPrefix after the exact-name
// lookup in hardBlockChecks, so an entry here covers a whole family of
// check names.
var hardBlockPrefixes = []string{
"outgoing_mail_",
"spam_",
"modsec_",
"email_auth_failure", // email brute force - SMTP/IMAP, can't challenge via HTTP
"email_compromised", // confirmed compromised email account
"email_credential", // credential leak
}
// challengeableChecks lists checks whose findings contain attacker IPs and
// are appropriate for challenge routing. Closed allowlist: ChallengeRouteIPs
// ignores every finding whose check name is not a key here. Two rules for
// inclusion:
//
// 1. The IP carrying the finding must be a CLIENT IP making an HTTPS/HTTP
// request that a browser could see. Background tasks (DNS recursion,
// SSH/FTP from CLI clients, internal auth daemons) have no browser to
// present a CAPTCHA to; routing them produces guaranteed
// challenge-timeout hard-blocks, not gated access.
//
// 2. The finding must indicate an ATTACK signal, not an audit-trail
// event. A single successful login is normal customer traffic;
// repeated failed logins (brute force) are an attack signal.
//
// Removed from this list (do not reintroduce without revisiting the two
// rules above):
//
// - cpanel_login / cpanel_login_realtime: post-auth audit events; the
// user is already inside cPanel and never makes a fresh connection
// the gate could catch.
// - cpanel_file_upload / cpanel_file_upload_realtime: same; post-auth.
// - cpanel_multi_ip_login / whm_password_change: multi-vector audit.
// - ftp_login / ftp_login_realtime / ssh_login_realtime /
// ssh_login_unknown_ip: no browser at the other end of FTP or SSH.
// - webmail_login_realtime: same as cpanel_login_realtime; post-auth.
// - dns_connection / user_outbound_connection: recursive resolvers and
// egress targets have no client browser.
// - api_auth_failure: API clients, not browsers.
// - brute_force: legacy bucket; superseded by per-protocol entries.
var challengeableChecks = map[string]bool{
// Pre-auth brute force on browser-facing endpoints. Attacker hits a
// public login page repeatedly; the next request from the same IP
// gets routed to the challenge.
"wp_login_bruteforce": true,
"xmlrpc_abuse": true,
"wp_user_enumeration": true,
"webmail_bruteforce": true,
// Reputation / scoring on the HTTP path. The IP is suspect across
// many checks; before hard-blocking, give a browser one verifier.
"ip_reputation": true,
"local_threat_score": true,
}
// isChallengeableCheck reports whether check is on the challengeableChecks
// allowlist. Unknown names yield false, so new checks are never routed to
// the challenge by default.
func isChallengeableCheck(check string) bool {
	allowed, listed := challengeableChecks[check]
	return listed && allowed
}
// isHardBlockCheck reports whether findings of the given check must be
// hard-blocked and never challenged: either the name is listed in
// hardBlockChecks or it starts with one of hardBlockPrefixes.
func isHardBlockCheck(check string) bool {
	blocked := hardBlockChecks[check]
	for i := 0; !blocked && i < len(hardBlockPrefixes); i++ {
		blocked = strings.HasPrefix(check, hardBlockPrefixes[i])
	}
	return blocked
}
// challengeDuration is the duration ChallengeRouteIPs passes to
// ChallengeIPList.Add for every IP it routes, and the expiry it quotes in
// the resulting challenge_route finding.
const challengeDuration = 30 * time.Minute
// ChallengeRouteIPs walks findings and enrols eligible attacker IPs on the
// challenge list instead of leaving them to be hard-blocked. It has to run
// BEFORE AutoBlockIPs so those IPs are already listed when AutoBlockIPs
// asks Contains().
//
// An IP is routed when its finding is not a hard-block check, is on the
// challengeableChecks allowlist, yields a non-empty IP that is neither an
// infra IP nor 127.0.0.1, and the IP is not already on the list or routed
// earlier in this call. Each routed IP is logged to stderr and produces one
// Warning-level "challenge_route" finding in the result. Returns nil when
// challenges are disabled in cfg or no list is installed.
func ChallengeRouteIPs(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.Challenge.Enabled || challengeIPList == nil {
		return nil
	}

	var actions []alert.Finding
	handled := make(map[string]bool)

	for _, finding := range findings {
		// Allowlist semantics: anything not explicitly challengeable is
		// skipped, so numeric fields of unrelated findings (versions,
		// sizes) can never be mistaken for an IP and acted on. New
		// IP-bearing checks must be added to challengeableChecks.
		if isHardBlockCheck(finding.Check) || !isChallengeableCheck(finding.Check) {
			continue
		}

		ip := extractIPFromFinding(finding)
		switch {
		case ip == "" || handled[ip]:
			continue
		case isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1":
			continue
		case challengeIPList.Contains(ip):
			continue
		}

		challengeIPList.Add(ip, finding.Message, challengeDuration)
		handled[ip] = true

		fmt.Fprintf(os.Stderr, "[%s] CHALLENGE: %s routed to challenge (check: %s)\n",
			time.Now().Format("2006-01-02 15:04:05"), ip, finding.Check)

		action := alert.Finding{
			Severity:  alert.Warning,
			Check:     "challenge_route",
			Message:   fmt.Sprintf("CHALLENGE: %s sent to PoW challenge (expires in %s)", ip, challengeDuration),
			Details:   fmt.Sprintf("Reason: %s", finding.Message),
			Timestamp: time.Now(),
		}
		actions = append(actions, action)
	}
	return actions
}
package checks
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"syscall"
"time"
"unicode"
"golang.org/x/sys/unix"
)
// CleanResult describes the outcome of a cleaning attempt.
// CleanInfectedFile fills it in; Error is non-empty whenever Cleaned is
// false.
type CleanResult struct {
// Path is the file the caller asked to clean.
Path string
// Cleaned is true only after the cleaned content has replaced the file.
Cleaned bool
// BackupPath is the pre-clean copy; set once the backup has been written,
// even if cleaning later fails.
BackupPath string
Removals []string // descriptions of what was removed
// Error explains why the file was not cleaned; empty on success.
Error string
}
// Compiled once at package scope; the surgical-cleaning strategies share
// these patterns. The anchored ones (leading ^, no multi-line flag) only
// match at the start of the string they are applied to -- presumably a
// single line; confirm at the call sites.
var (
// Leading class is [\s\x00\x0b] rather than \s: an injection can pad
// the line start with NUL or vertical-tab bytes that Go's \s does not
// cover. Detection still flags the file (analyzePHPContent is not
// anchored); widening the class here lets the surgical cleaner strip
// the injection instead of falling back to whole-file quarantine.
//
// cleanRegexpMaliciousInclude: @include( followed by a quoted path under
// /tmp/, /dev/shm/ or /var/tmp/, or by a base64_decode( / str_rot13( /
// gzinflate( call.
cleanRegexpMaliciousInclude = regexp.MustCompile(
`(?i)^[\s\x00\x0b]*@include\s*\(\s*(?:` +
`['"](?:/tmp/|/dev/shm/|/var/tmp/)` +
`|base64_decode\s*\(` +
`|str_rot13\s*\(` +
`|gzinflate\s*\(` +
`)`)
// cleanRegexpVarInclude: @include($var) where the argument is a single
// variable made of letters and underscores.
cleanRegexpVarInclude = regexp.MustCompile(`(?i)^[\s\x00\x0b]*@include\s*\(\s*\$[a-zA-Z_]+\s*\)`)
// cleanRegexpCloseOpen: a PHP close tag followed (after optional
// whitespace) by a new open tag.
cleanRegexpCloseOpen = regexp.MustCompile(`(?i)\?>\s*<\?php`)
// cleanRegexpInlineEval: eval( (optionally @-prefixed) whose argument
// starts with base64_decode(, gzinflate(, gzuncompress( or str_rot13(.
cleanRegexpInlineEval = regexp.MustCompile(
`(?i)^[\s\x00\x0b]*(?:@?)eval\s*\(\s*(?:base64_decode|gzinflate|gzuncompress|str_rot13)\s*\(`)
// cleanRegexpMultiB64: two or more directly nested base64_decode( calls.
cleanRegexpMultiB64 = regexp.MustCompile(`(?i)(?:base64_decode\s*\(\s*){2,}`)
// cleanRegexpChainedB64: a variable assigned base64_decode(base64_decode...
cleanRegexpChainedB64 = regexp.MustCompile(`(?i)\$\w+\s*=\s*base64_decode\s*\(\s*base64_decode`)
// cleanRegexpChrChain: five or more chr(N) calls in a row, optionally
// joined by the "." concatenation operator.
cleanRegexpChrChain = regexp.MustCompile(`(?i)\bchr\s*\(\s*\d+\s*\)(?:\s*\.?\s*\bchr\s*\(\s*\d+\s*\)){4,}`)
// cleanRegexpPackHex: pack("H*", ... with either quote style.
cleanRegexpPackHex = regexp.MustCompile(`(?i)pack\s*\(\s*["']H\*["']\s*,`)
// cleanRegexpHexVar: three or more consecutive \xHH escapes, either each
// preceded by a double quote or back to back.
cleanRegexpHexVar = regexp.MustCompile(`(?:"\x5c\x78[0-9a-fA-F]{2}){3,}|(?:\\x[0-9a-fA-F]{2}){3,}`)
)
// cleanMaxFileSize bounds how large a file surgical cleaning will read
// into memory. The detector that routes a file here (analyzePHPContent)
// only inspects bounded head and tail windows, so an attacker can match a
// signature inside those windows and pad the rest to many gigabytes. Reading
// that whole file with io.ReadAll plus the strings.Split and regex passes
// below would OOM the root daemon. Above this ceiling we refuse, and the
// caller falls back to quarantine-by-rename, which never reads content.
// Legitimate plugin/theme PHP files are far smaller than this. Var, not
// const, so tests can lower it. 8 MiB.
//
// CleanInfectedFile compares this against the size reported by the stat
// taken when the file was opened.
var cleanMaxFileSize int64 = 8 << 20
// CleanInfectedFile attempts to surgically remove malicious code from a PHP file
// while preserving the legitimate content. Always creates a backup first.
//
// Cleaning strategies (applied in order, each on the output of the last):
//  1. @include injection - remove @include lines pointing to /tmp, eval, base64, or via variables
//  2. Prepend injection - remove malicious code blocks at start of file (entropy-validated)
//  3. Append injection - remove malicious code after closing ?> or end of PSR-12 file
//  4. Inline eval injection - remove eval(base64_decode(...)) single-line injections
//  5. Multi-layer base64 - remove nested base64_decode chains
//  6. chr()/pack() - remove code assembled from chr()/pack() calls
//  7. Hex variables - remove hex-escaped variable injections
//
// The returned CleanResult has Cleaned=true and Removals populated only
// when at least one strategy changed the content and the cleaned file was
// written back. In every other case Error explains why; BackupPath is set
// as soon as the pre-clean backup exists, even if cleaning then fails.
// Files larger than cleanMaxFileSize are refused so the caller can fall
// back to quarantine.
func CleanInfectedFile(path string) CleanResult {
	result := CleanResult{Path: path}

	target, err := openCleanTarget(path)
	if err != nil {
		result.Error = fmt.Sprintf("cannot read file: %v", err)
		return result
	}
	defer target.Close()

	if sz := target.Info.Size(); sz > cleanMaxFileSize {
		result.Error = fmt.Sprintf("file too large to clean (%d bytes > %d), quarantining instead", sz, cleanMaxFileSize)
		return result
	}

	// The size above comes from the stat taken at open time. The file's
	// owner can keep appending afterwards, so cap the read itself: one byte
	// past the ceiling is enough to detect growth without buffering it.
	data, err := io.ReadAll(io.LimitReader(target.File, cleanMaxFileSize+1))
	if err != nil {
		result.Error = fmt.Sprintf("cannot read file: %v", err)
		return result
	}
	if int64(len(data)) > cleanMaxFileSize {
		result.Error = fmt.Sprintf("file too large to clean (grew past %d bytes during read), quarantining instead", cleanMaxFileSize)
		return result
	}

	// Create backup before any modification
	backupDir := filepath.Join(quarantineDir, "pre_clean")
	// A MkdirAll failure surfaces through the WriteFile error just below.
	_ = os.MkdirAll(backupDir, 0700)
	ts := time.Now().Format("20060102-150405")
	safeName := strings.ReplaceAll(path, "/", "_")
	backupPath := filepath.Join(backupDir, fmt.Sprintf("%s_%s", ts, safeName))
	if err := os.WriteFile(backupPath, data, 0600); err != nil {
		result.Error = fmt.Sprintf("cannot create backup: %v", err)
		return result
	}
	result.BackupPath = backupPath

	// Metadata sidecar derived from the same fd we read, so a directory
	// race after open cannot change the metadata we record. Best effort:
	// a missing sidecar does not block cleaning.
	meta := map[string]interface{}{
		"original_path":  path,
		"owner_uid":      target.UID,
		"group_gid":      target.GID,
		"mode":           target.Info.Mode().String(),
		"size":           target.Info.Size(),
		"quarantined_at": time.Now(),
		"reason":         "Pre-clean backup (surgical cleaning)",
	}
	metaData, _ := json.MarshalIndent(meta, "", "  ")
	_ = os.WriteFile(backupPath+".meta", metaData, 0600)

	content := string(data)
	originalLen := len(content)
	var removals []string

	// Strategy 1: Remove @include injections (including variable-based)
	content, removed := removeIncludeInjections(content)
	removals = append(removals, removed...)

	// Strategy 2: Remove prepend injections (with entropy validation)
	content, removed = removePrependInjection(content)
	removals = append(removals, removed...)

	// Strategy 3: Remove append injections (handles files with and without closing ?>)
	content, removed = removeAppendInjection(content)
	removals = append(removals, removed...)

	// Strategy 4: Remove inline eval(base64_decode(...)) injections
	content, removed = removeInlineEvalInjections(content)
	removals = append(removals, removed...)

	// Strategy 5: Remove multi-layer base64 decode chains
	content, removed = removeMultiLayerBase64(content)
	removals = append(removals, removed...)

	// Strategy 6: Remove chr()/pack() constructed code
	content, removed = removeChrPackInjections(content)
	removals = append(removals, removed...)

	// Strategy 7: Remove hex-encoded variable injections
	content, removed = removeHexVarInjections(content)
	removals = append(removals, removed...)

	// If nothing was removed, file couldn't be cleaned
	if len(removals) == 0 || len(content) == originalLen {
		result.Error = "no known injection patterns found - file may need manual review"
		return result
	}

	if err := writeCleanedFileAtomic(target, []byte(content)); err != nil {
		result.Error = fmt.Sprintf("cannot write cleaned file: %v", err)
		return result
	}

	result.Cleaned = true
	result.Removals = removals
	return result
}
// cleanTarget bundles the descriptors and metadata of a file opened for
// surgical cleaning. openCleanTarget builds it; Close releases both
// descriptors.
type cleanTarget struct {
// Path is the path the caller passed in.
Path string
// DirFD is the open descriptor of the file's parent directory; the
// read, the temp file and the final rename are all resolved against it.
DirFD int
// Name is the file's name inside the parent directory.
Name string
// File is the read-only handle opened via DirFD.
File *os.File
// Info is the stat taken from File right after opening.
Info os.FileInfo
// UID and GID are the file's owner and group; meaningful only when
// OwnerKnown is true.
UID int
GID int
// OwnerKnown is true when Info exposed a *syscall.Stat_t to read the
// owner from.
OwnerKnown bool
}
// Close releases the file handle and the pinned parent-directory
// descriptor. It is safe to call on a nil receiver and safe to call more
// than once: each descriptor is cleared after it is closed, so a repeat
// call cannot close a descriptor number that has since been reused for
// something else. Close errors are ignored; both descriptors are
// read-only.
func (t *cleanTarget) Close() {
	if t == nil {
		return
	}
	if t.File != nil {
		_ = t.File.Close()
		t.File = nil
	}
	if t.DirFD >= 0 {
		_ = unix.Close(t.DirFD)
		t.DirFD = -1
	}
}
// openCleanTarget opens path for cleaning without following symlinks at
// either the parent directory or the file itself.
//
// It splits path into directory and name, refuses a symlinked parent,
// opens the parent with O_NOFOLLOW, confirms the opened directory is the
// same dev/inode that was lstat'ed, then opens the file relative to that
// directory descriptor (again O_NOFOLLOW) and requires it to be a regular
// file. On success the returned cleanTarget owns both descriptors and the
// caller must Close it; on any error both are closed before returning.
func openCleanTarget(path string) (*cleanTarget, error) {
dir, name := filepath.Split(path)
if name == "" || name == "." || name == ".." {
return nil, fmt.Errorf("invalid target path %q", path)
}
// A bare file name means the current directory.
if dir == "" {
dir = "."
}
dir = filepath.Clean(dir)
parentInfo, err := os.Lstat(dir)
if err != nil {
return nil, fmt.Errorf("stat parent directory: %w", err)
}
if parentInfo.Mode()&os.ModeSymlink != 0 {
return nil, fmt.Errorf("refusing symlinked parent directory")
}
// Pin the immediate parent. A swap of that directory to a symlink
// after detection must not redirect either the read or the writeback.
dirFD, err := unix.Open(dir, unix.O_RDONLY|unix.O_DIRECTORY|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0)
if err != nil {
return nil, fmt.Errorf("open parent directory: %w", err)
}
// closeDir/closeFile stay true on every error path so the deferred
// cleanups release the descriptors; they are cleared only once ownership
// passes to the returned cleanTarget.
closeDir := true
defer func() {
if closeDir {
_ = unix.Close(dirFD)
}
}()
// The directory could have been replaced between Lstat and Open; make
// sure the descriptor refers to the directory that was inspected.
if pinErr := verifyCleanParentStillPinned(dirFD, parentInfo); pinErr != nil {
return nil, pinErr
}
fd, err := unix.Openat(dirFD, name, unix.O_RDONLY|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0)
if err != nil {
return nil, err
}
// #nosec G115 -- unix.Openat returned a non-negative fd because err is nil.
file := os.NewFile(uintptr(fd), path)
closeFile := true
defer func() {
if closeFile {
_ = file.Close()
}
}()
// Stat through the open handle so the metadata describes the file that
// will actually be read.
info, err := file.Stat()
if err != nil {
return nil, fmt.Errorf("stat file: %w", err)
}
if !info.Mode().IsRegular() {
return nil, fmt.Errorf("refusing non-regular file (mode=%v)", info.Mode())
}
target := &cleanTarget{
Path: path,
DirFD: dirFD,
Name: name,
File: file,
Info: info,
}
// Ownership is recorded only when the platform stat structure is
// available; OwnerKnown tells the writer whether to chown.
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
target.UID = int(stat.Uid)
target.GID = int(stat.Gid)
target.OwnerKnown = true
}
closeDir = false
closeFile = false
return target, nil
}
// verifyCleanParentStillPinned confirms that the directory behind dirFD is
// the same one described by want (the stat taken before it was opened).
// It returns an error when the descriptor cannot be stat'ed or when the
// identities differ, and nil when they match.
func verifyCleanParentStillPinned(dirFD int, want os.FileInfo) error {
	var opened unix.Stat_t
	err := unix.Fstat(dirFD, &opened)
	switch {
	case err != nil:
		return fmt.Errorf("stat opened parent directory: %w", err)
	case !sameUnixStatIdentity(want, opened):
		return fmt.Errorf("parent directory changed during cleaning")
	default:
		return nil
	}
}
// sameUnixStatIdentity reports whether info and stat describe the same
// filesystem object, i.e. their device and inode numbers both match.
// It returns false when info is nil or does not carry a *syscall.Stat_t.
func sameUnixStatIdentity(info os.FileInfo, stat unix.Stat_t) bool {
	if info == nil {
		return false
	}
	sys, isStat := info.Sys().(*syscall.Stat_t)
	if !isStat {
		return false
	}
	// Widen both sides to uint64 so the comparison does not depend on the
	// field widths of the two stat structures.
	sameDevice := uint64(sys.Dev) == uint64(stat.Dev)
	sameInode := uint64(sys.Ino) == uint64(stat.Ino)
	return sameDevice && sameInode
}
// writeCleanedFileAtomic stages cleaned content through a hidden sibling
// name under the pinned parent directory and renames it over the original
// only after the path still resolves to the inode we read.
//
// The temp file receives the original owner (when known) and mode before
// it is synced. Any failure before the rename removes the temp file and
// leaves the original untouched. Returns nil once the rename succeeds.
func writeCleanedFileAtomic(target *cleanTarget, content []byte) error {
tmp, tmpName, err := createCleanTempFile(target.DirFD)
if err != nil {
return err
}
// removeTmp stays true until the rename succeeds, so every early return
// below unlinks the staged file.
removeTmp := true
defer func() {
if removeTmp {
_ = unix.Unlinkat(target.DirFD, tmpName, 0)
}
_ = tmp.Close()
}()
if _, err := tmp.Write(content); err != nil {
return err
}
// Ownership is restored before the mode is applied.
if target.OwnerKnown {
if err := tmp.Chown(target.UID, target.GID); err != nil {
return err
}
}
if err := tmp.Chmod(cleanReplacementMode(target.Info.Mode())); err != nil {
return err
}
// Flush the content before it can become visible under the real name.
if err := tmp.Sync(); err != nil {
return err
}
// Last check before the swap: the name must still refer to the file
// that was read, unchanged.
if err := verifyCleanTargetUnchanged(target); err != nil {
return err
}
// Both names are resolved relative to the pinned directory descriptor.
if err := unix.Renameat(target.DirFD, tmpName, target.DirFD, target.Name); err != nil {
return err
}
removeTmp = false
return nil
}
// createCleanTempFile creates a fresh, exclusively-owned staging file
// inside the directory behind dirFD. The name is ".csm-clean-" plus 16
// random hex digits; the file is opened write-only with O_CREAT|O_EXCL and
// O_NOFOLLOW, mode 0600. A name collision is retried with a new random
// name, up to 100 attempts. Returns the open file and its name, or an
// error when randomness is unavailable, the open fails for another reason,
// or every attempt collided.
func createCleanTempFile(dirFD int) (*os.File, string, error) {
	const maxAttempts = 100
	openFlags := unix.O_WRONLY | unix.O_CREAT | unix.O_EXCL | unix.O_NOFOLLOW | unix.O_CLOEXEC

	var suffix [8]byte
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if _, err := rand.Read(suffix[:]); err != nil {
			return nil, "", fmt.Errorf("random temp name: %w", err)
		}
		candidate := ".csm-clean-" + hex.EncodeToString(suffix[:])

		fd, err := unix.Openat(dirFD, candidate, openFlags, 0o600)
		switch {
		case err == nil:
			// #nosec G115 -- unix.Openat returned a non-negative fd because err is nil.
			return os.NewFile(uintptr(fd), candidate), candidate, nil
		case err == unix.EEXIST:
			continue
		default:
			return nil, "", err
		}
	}
	return nil, "", fmt.Errorf("could not allocate temp file name")
}
// verifyCleanTargetUnchanged reopens target.Name relative to target.DirFD
// (read-only, O_NOFOLLOW) and returns an error unless the reopened file is a
// regular file that still matches target.Info according to both
// sameFileIdentity and sameCleanContentShape.
func verifyCleanTargetUnchanged(target *cleanTarget) error {
	rawFD, openErr := unix.Openat(target.DirFD, target.Name, unix.O_RDONLY|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0)
	if openErr != nil {
		return fmt.Errorf("open target before rename: %w", openErr)
	}
	// #nosec G115 -- unix.Openat returned a non-negative fd because err is nil.
	reopened := os.NewFile(uintptr(rawFD), target.Path)
	defer func() { _ = reopened.Close() }()

	current, statErr := reopened.Stat()
	if statErr != nil {
		return fmt.Errorf("stat target before rename: %w", statErr)
	}
	if mode := current.Mode(); !mode.IsRegular() {
		return fmt.Errorf("refusing non-regular target before rename (mode=%v)", mode)
	}
	if sameFileIdentity(current, target.Info) && sameCleanContentShape(current, target.Info) {
		return nil
	}
	return fmt.Errorf("file changed during cleaning")
}
// sameCleanContentShape reports whether two FileInfo values agree on size and
// modification time. It returns false when either argument is nil.
func sameCleanContentShape(a, b os.FileInfo) bool {
	switch {
	case a == nil, b == nil:
		return false
	case a.Size() != b.Size():
		return false
	}
	return a.ModTime().Equal(b.ModTime())
}
// cleanReplacementMode reduces mode to the permission bits plus the setuid,
// setgid and sticky bits; every other mode bit is cleared.
func cleanReplacementMode(mode os.FileMode) os.FileMode {
	const preserved = os.ModePerm | os.ModeSetuid | os.ModeSetgid | os.ModeSticky
	return mode & preserved
}
// ShouldCleanInsteadOfQuarantine decides between surgical cleaning (true) and
// full quarantine (false) for the file at path.
//
//   - Paths containing /wp-includes/ or /wp-admin/ are always cleaned:
//     removing WP core files breaks the site.
//   - Paths under /wp-content/plugins/ or /wp-content/themes/ are cleaned
//     unless the lower-cased base name is recognised by isWebshellName, in
//     which case the file itself is the malware and is quarantined.
//   - Everything else (droppers, standalone webshells) is quarantined.
func ShouldCleanInsteadOfQuarantine(path string) bool {
	switch {
	case strings.Contains(path, "/wp-includes/"), strings.Contains(path, "/wp-admin/"):
		return true
	case strings.Contains(path, "/wp-content/plugins/"), strings.Contains(path, "/wp-content/themes/"):
		base := strings.ToLower(filepath.Base(path))
		return !isWebshellName(base)
	default:
		return false
	}
}
// removeIncludeInjections drops lines containing @include statements that
// load malicious files and returns the remaining content together with one
// description per removed line.
//
// Two cases are handled:
//   - lines matching cleanRegexpMaliciousInclude are removed outright;
//   - lines matching cleanRegexpVarInclude are removed only when the
//     surrounding context (3 lines either side) shows obfuscation markers:
//     base64_decode, str_rot13, chr(, a "\x escape, or more than five ". "
//     concatenations.
func removeIncludeInjections(content string) (string, []string) {
	var (
		removals []string
		kept     []string
	)
	lines := strings.Split(content, "\n")
	for idx := range lines {
		current := lines[idx]

		if cleanRegexpMaliciousInclude.MatchString(current) {
			removals = append(removals, fmt.Sprintf("removed @include injection: %s", trimCleanRemovalLine(current)))
			continue
		}

		if cleanRegexpVarInclude.MatchString(current) {
			nearby := strings.ToLower(getLineContext(lines, idx, 3))
			obfuscated := false
			for _, marker := range []string{"base64_decode", "str_rot13", "chr(", `"\x`} {
				if strings.Contains(nearby, marker) {
					obfuscated = true
					break
				}
			}
			if !obfuscated {
				// Heavy string concatenation also counts as obfuscation.
				obfuscated = strings.Count(nearby, ". ") > 5
			}
			if obfuscated {
				removals = append(removals, fmt.Sprintf("removed obfuscated @include: %s", trimCleanRemovalLine(current)))
				continue
			}
		}

		kept = append(kept, current)
	}
	return strings.Join(kept, "\n"), removals
}
// trimCleanRemovalLine strips leading and trailing NUL bytes and Unicode
// whitespace from line.
func trimCleanRemovalLine(line string) string {
	droppable := func(r rune) bool {
		if r == '\x00' {
			return true
		}
		return unicode.IsSpace(r)
	}
	return strings.TrimFunc(line, droppable)
}
// removePrependInjection removes malicious PHP code injected in front of the
// legitimate file content.
//
// The file must start (after whitespace) with a PHP open tag, compared
// case-insensitively because <?php, <?PHP and <?Php all execute. The text
// before the first cleanRegexpCloseOpen match is treated as the candidate
// prefix. It is removed only when it contains a known malicious marker AND
// either has Shannon entropy of at least 4.5 or contains an encoded-looking
// run of 100+ characters; the second gate avoids false positives on
// legitimate code that merely uses a close/open tag pair.
//
// Returns the (possibly rewritten) content and a description of the removal.
func removePrependInjection(content string) (string, []string) {
	const openTag = "<?php"

	head := strings.TrimSpace(content)
	if len(head) < len(openTag) || !strings.EqualFold(head[:len(openTag)], openTag) {
		return content, nil
	}

	match := cleanRegexpCloseOpen.FindStringIndex(content)
	if match == nil {
		return content, nil
	}

	injected := content[:match[0]]
	injectedLower := strings.ToLower(injected)
	suspicious := false
	for _, marker := range []string{"eval(", "base64_decode", "gzinflate", "str_rot13", "@include"} {
		if strings.Contains(injectedLower, marker) {
			suspicious = true
			break
		}
	}
	if !suspicious {
		return content, nil
	}

	// Low entropy and no long encoded strings: likely legitimate code.
	entropy := shannonEntropy(injected)
	if entropy < 4.5 && !containsLongEncodedString(injected, 100) {
		return content, nil
	}

	// Keep everything after the match, restoring a single open tag in front.
	note := fmt.Sprintf("removed %d-byte prepend injection (entropy: %.2f)", match[1], entropy)
	return "<?php" + content[match[1]:], []string{note}
}
// removeAppendInjection removes malicious code appended after the end of the
// legitimate PHP content and returns the result plus a removal description.
//
// Case 1: the file has a closing ?> and the non-blank text after the LAST ?>
// contains a dangerous marker (eval(, base64_decode, gzinflate, system(,
// exec(, @include, or another <?php). Everything after that ?> is replaced
// by a single newline.
//
// Case 2: PSR-12 style file without a usable closing tag. Within the last 10
// lines, the first blank line marks the start of a candidate tail block; the
// tail is removed when it contains both eval( and base64_decode. Files with
// fewer than 5 lines are left alone.
func removeAppendInjection(content string) (string, []string) {
	if closeIdx := strings.LastIndex(content, "?>"); closeIdx >= 0 {
		keepUntil := closeIdx + len("?>")
		trailing := content[keepUntil:]
		if probe := strings.ToLower(strings.TrimSpace(trailing)); probe != "" {
			markers := []string{"eval(", "base64_decode", "gzinflate", "system(", "exec(", "@include", "<?php"}
			for _, marker := range markers {
				if strings.Contains(probe, marker) {
					note := fmt.Sprintf("removed %d-byte append injection (after ?>)", len(trailing))
					return content[:keepUntil] + "\n", []string{note}
				}
			}
		}
	}

	lines := strings.Split(content, "\n")
	if len(lines) < 5 {
		return content, nil
	}

	first := len(lines) - 10
	if first < 0 {
		first = 0
	}
	for i := first; i < len(lines); i++ {
		if strings.TrimSpace(lines[i]) != "" {
			continue
		}
		tail := strings.Join(lines[i:], "\n")
		lowered := strings.ToLower(tail)
		if strings.Contains(lowered, "eval(") && strings.Contains(lowered, "base64_decode") {
			note := fmt.Sprintf("removed %d-byte PSR-12 append injection", len(tail))
			return strings.Join(lines[:i], "\n") + "\n", []string{note}
		}
		// Only the first blank line in the window is considered.
		break
	}
	return content, nil
}
// removeInlineEvalInjections removes single-line eval(base64_decode("..."));
// injections that are inserted as standalone lines in PHP files.
func removeInlineEvalInjections(content string) (string, []string) {
var removals []string
lines := strings.Split(content, "\n")
var clean []string
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
// Strip /* ... */ block comments and trailing // / # comments
// before matching. Attackers wedge comments between the keyword
// and the open paren ("@eval/*x*/(base64_decode(...))") to slip
// past a strict regex; the cleaner has to see the line the way
// the PHP tokenizer does, not byte-for-byte.
normalized := stripPHPComments(trimmedLine)
if cleanRegexpInlineEval.MatchString(normalized) {
// Length gate stays on the ORIGINAL line so a short legitimate
// eval() does not get sucked into the removal path.
if len(trimmedLine) > 50 {
removals = append(removals, fmt.Sprintf("removed inline eval injection (%d chars)", len(trimmedLine)))
continue
}
}
clean = append(clean, line)
}
return strings.Join(clean, "\n"), removals
}
// stripPHPComments returns line with PHP comments removed by
// stripPHPCommentsFromCode and surrounding whitespace trimmed. Quoted strings
// are meant to stay intact so comment-looking payload data does not change
// the code the cleaner evaluates.
func stripPHPComments(line string) string {
	withoutComments := stripPHPCommentsFromCode(line)
	return strings.TrimSpace(withoutComments)
}
// --- Helper functions ---

// shannonEntropy calculates the Shannon entropy of s in bits per byte,
// treating the string as a sequence of bytes. It returns 0 for the empty
// string.
// Obfuscated/encoded code typically has entropy > 5.0.
// Normal PHP code typically has entropy 4.0-4.5.
//
// Byte frequencies are tallied in a fixed-size array and summed in byte-value
// order. Iterating a map instead would add the floating-point terms in a
// random order, which can change the low bits of the result between calls on
// the same input; callers compare the value against a hard threshold, so the
// result must be reproducible.
func shannonEntropy(s string) float64 {
	if len(s) == 0 {
		return 0
	}
	var counts [256]int
	for i := 0; i < len(s); i++ {
		counts[s[i]]++
	}
	total := float64(len(s))
	entropy := 0.0
	for _, c := range counts {
		if c == 0 {
			continue
		}
		p := float64(c) / total
		entropy -= p * math.Log2(p)
	}
	return entropy
}
// containsLongEncodedString reports whether s contains an uninterrupted run
// of at least minLength bytes accepted by isEncodedStringByte (a
// base64-looking blob). A non-positive minLength always yields true.
func containsLongEncodedString(s string, minLength int) bool {
	if minLength <= 0 {
		return true
	}
	streak := 0
	for _, b := range []byte(s) {
		if !isEncodedStringByte(b) {
			streak = 0
			continue
		}
		streak++
		if streak >= minLength {
			return true
		}
	}
	return false
}
// isEncodedStringByte reports whether b belongs to the base64 alphabet as
// used here: ASCII letters, digits, '+', '/' and the '=' padding character.
func isEncodedStringByte(b byte) bool {
	switch {
	case 'A' <= b && b <= 'Z', 'a' <= b && b <= 'z', '0' <= b && b <= '9':
		return true
	case b == '+', b == '/', b == '=':
		return true
	default:
		return false
	}
}
// getLineContext joins the lines from idx-window through idx+window
// (inclusive) with newlines, clamping the range to the bounds of lines.
func getLineContext(lines []string, idx, window int) string {
	lo, hi := idx-window, idx+window+1
	if lo < 0 {
		lo = 0
	}
	if hi > len(lines) {
		hi = len(lines)
	}
	return strings.Join(lines[lo:hi], "\n")
}
// FormatCleanResult renders a CleanResult as human-readable text: a FAILED
// line when r.Error is set, a "No changes" line when nothing was cleaned,
// otherwise a CLEANED header followed by the backup path and one bullet per
// removal.
func FormatCleanResult(r CleanResult) string {
	switch {
	case r.Error != "":
		return fmt.Sprintf("FAILED to clean %s: %s", r.Path, r.Error)
	case !r.Cleaned:
		return fmt.Sprintf("No changes made to %s", r.Path)
	}

	var out strings.Builder
	fmt.Fprintf(&out, "CLEANED %s\n", r.Path)
	fmt.Fprintf(&out, " Backup: %s\n", r.BackupPath)
	for i := range r.Removals {
		fmt.Fprintf(&out, " - %s\n", r.Removals[i])
	}
	return out.String()
}
// --- Strategy 5: Multi-layer base64 decode chains ---

// removeMultiLayerBase64 drops lines that contain nested or chained
// base64_decode calls and returns the remaining content plus one description
// per removed line.
// Catches: eval(base64_decode(base64_decode("...")))
// Catches: $x=base64_decode("...");$y=base64_decode($x);eval($y);
//
// Only lines whose trimmed length is at least 80 characters are examined;
// they are removed when they match cleanRegexpMultiB64 or
// cleanRegexpChainedB64.
func removeMultiLayerBase64(content string) (string, []string) {
	var (
		removals []string
		kept     []string
	)
	for _, raw := range strings.Split(content, "\n") {
		candidate := strings.TrimSpace(raw)
		if len(candidate) >= 80 &&
			(cleanRegexpMultiB64.MatchString(candidate) || cleanRegexpChainedB64.MatchString(candidate)) {
			removals = append(removals, fmt.Sprintf("removed multi-layer base64 chain (%d chars)", len(candidate)))
			continue
		}
		kept = append(kept, raw)
	}
	return strings.Join(kept, "\n"), removals
}
// --- Strategy 6: chr()/pack() constructed code ---

// removeChrPackInjections removes code that is assembled from chr() chains
// or pack() hex strings, returning the remaining content and one description
// per removed line.
// Catches: eval(chr(115).chr(121).chr(115)...);
// Catches: $f = pack("H*", "73797374656d"); $f($_POST['cmd']);
//
// chr() chains are located with cleanRegexpChrChain over the WHOLE content so
// that multi-line chains (chr(115)\n.chr(121)\n...) count as one chain rather
// than slipping past a per-line check; every line of the enclosing statement
// (see chrChainStatementLineRange) is dropped. pack() lines matching
// cleanRegexpPackHex are dropped only when the same line also mentions eval,
// $_, system or exec.
func removeChrPackInjections(content string) (string, []string) {
	lines := strings.Split(content, "\n")

	// Byte offset at which each line starts inside content.
	lineStarts := make([]int, len(lines))
	for i, pos := 0, 0; i < len(lines); i++ {
		lineStarts[i] = pos
		pos += len(lines[i]) + 1 // +1 for the newline strings.Split consumed
	}

	doomed := make(map[int]bool)
	for _, match := range cleanRegexpChrChain.FindAllStringIndex(content, -1) {
		first, last := chrChainStatementLineRange(lines, lineStarts, match)
		for n := first; n <= last; n++ {
			doomed[n] = true
		}
	}

	var (
		removals []string
		kept     []string
	)
	for n, raw := range lines {
		stripped := strings.TrimSpace(raw)
		if doomed[n] {
			removals = append(removals, fmt.Sprintf("removed chr() chain injection (line %d, %d chars)", n+1, len(stripped)))
			continue
		}

		if cleanRegexpPackHex.MatchString(stripped) {
			// Only remove if combined with execution.
			lowered := strings.ToLower(stripped)
			executes := false
			for _, marker := range []string{"eval", "$_", "system", "exec"} {
				if strings.Contains(lowered, marker) {
					executes = true
					break
				}
			}
			if executes {
				removals = append(removals, fmt.Sprintf("removed pack() code construction (line %d)", n+1))
				continue
			}
		}

		kept = append(kept, raw)
	}
	return strings.Join(kept, "\n"), removals
}
// chrChainStatementLineRange maps a byte span (start, end-exclusive) inside
// the original content to the inclusive range of line indexes forming the
// whole statement around it.
//
// The range is grown upwards while the preceding line ends in a continuation
// token (chrChainPrefixContinues), and downwards while the current last line
// has no ';' and the following line looks like a continuation
// (chrChainSuffixContinues).
func chrChainStatementLineRange(lines []string, offsets []int, span []int) (int, int) {
	first := lineIndexForOffset(offsets, span[0])
	last := lineIndexForOffset(offsets, span[1]-1)

	for first > 0 {
		if !chrChainPrefixContinues(lines[first-1]) {
			break
		}
		first--
	}

	for last+1 < len(lines) {
		// A ';' on the current last line terminates the statement.
		if strings.Contains(lines[last], ";") || !chrChainSuffixContinues(lines[last+1]) {
			break
		}
		last++
	}
	return first, last
}
// chrChainPrefixContinues reports whether line, once trimmed, ends with a
// token that leaves the statement open ('=', '.', '(', ',' or '['). Blank
// lines and a bare "<?php" never continue.
func chrChainPrefixContinues(line string) bool {
	body := strings.TrimSpace(line)
	switch body {
	case "", "<?php":
		return false
	}
	switch body[len(body)-1] {
	case '=', '.', '(', ',', '[':
		return true
	}
	return false
}
// chrChainSuffixContinues reports whether line can be the continuation of a
// statement started on an earlier line: it is blank, or its trimmed form
// begins with '.', ')', ',' or ';'.
func chrChainSuffixContinues(line string) bool {
	body := strings.TrimSpace(line)
	if body == "" {
		return true
	}
	switch body[0] {
	case '.', ')', ',', ';':
		return true
	}
	return false
}
// lineIndexForOffset returns the index of the line containing the byte
// at offset within the original content.
//
// offsets holds the starting byte offset of every line in ascending order.
// The result is the last line whose start is <= byteOffset. It is 0 when
// offsets is empty, byteOffset is negative, or byteOffset precedes every
// line start; an offset past the final line start maps to the last line.
func lineIndexForOffset(offsets []int, byteOffset int) int {
	if len(offsets) == 0 || byteOffset < 0 {
		return 0
	}
	// sort.Search yields the first line starting AFTER byteOffset, a value in
	// [0, len(offsets)], so the line before it is at most len(offsets)-1 and
	// only the lower bound needs clamping.
	idx := sort.Search(len(offsets), func(i int) bool {
		return offsets[i] > byteOffset
	}) - 1
	if idx < 0 {
		return 0
	}
	return idx
}
// --- Strategy 7: Hex-encoded variable injections ---

// removeHexVarInjections drops lines that use hex-escaped variable names
// together with a dangerous operation, returning the remaining content and
// one description per removed line.
// Catches: $GLOBALS["\x61\x64\x6d\x69\x6e"] = eval(...)
// Catches: ${"\x47\x4c\x4f\x42\x41\x4c\x53"}[...] = ...
//
// Lines shorter than 30 characters (trimmed) are never touched. A line
// matching cleanRegexpHexVar is removed only when it also mentions eval,
// system(, exec(, base64_decode, assert( or one of $_POST/$_REQUEST/$_GET
// (compared case-insensitively).
func removeHexVarInjections(content string) (string, []string) {
	dangerous := []string{
		"eval", "system(", "exec(", "base64_decode",
		"assert(", "$_post", "$_request", "$_get",
	}

	var (
		removals []string
		kept     []string
	)
	for n, raw := range strings.Split(content, "\n") {
		stripped := strings.TrimSpace(raw)
		if len(stripped) >= 30 && cleanRegexpHexVar.MatchString(stripped) {
			lowered := strings.ToLower(stripped)
			hit := false
			for _, marker := range dangerous {
				if strings.Contains(lowered, marker) {
					hit = true
					break
				}
			}
			if hit {
				removals = append(removals, fmt.Sprintf("removed hex-encoded variable injection (line %d, %d chars)", n+1, len(stripped)))
				continue
			}
		}
		kept = append(kept, raw)
	}
	return strings.Join(kept, "\n"), removals
}
package checks
import (
"crypto/sha256"
"encoding/hex"
"io"
"sync"
)
// CMSHashCache stores SHA256 hashes of verified-clean CMS core files.
// After wp core verify-checksums confirms an installation is clean,
// all its core files are hashed and cached. The real-time scanner
// checks this cache before reporting signature matches - if a file's
// hash is in the cache, it's a known-clean CMS file and signature
// matches on it are false positives.
//
// All methods are guarded by an RWMutex. The zero value is an empty,
// ready-to-use cache.
type CMSHashCache struct {
	mu     sync.RWMutex
	hashes map[string]bool // SHA256 hex → true
}

var (
	globalCache     *CMSHashCache
	globalCacheOnce sync.Once
)

// GlobalCMSCache returns the singleton cache, creating it on first call.
func GlobalCMSCache() *CMSHashCache {
	globalCacheOnce.Do(func() {
		globalCache = &CMSHashCache{
			hashes: make(map[string]bool),
		}
	})
	return globalCache
}

// Add inserts a file hash into the cache.
func (c *CMSHashCache) Add(hash string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Allocate lazily so a zero-value cache does not panic on a nil map.
	if c.hashes == nil {
		c.hashes = make(map[string]bool)
	}
	c.hashes[hash] = true
}

// Contains checks if a file hash is in the cache.
func (c *CMSHashCache) Contains(hash string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.hashes[hash]
}

// Size returns the number of cached hashes.
func (c *CMSHashCache) Size() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return len(c.hashes)
}

// Clear removes all cached hashes (used before rebuilding).
func (c *CMSHashCache) Clear() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.hashes = make(map[string]bool)
}
// HashFile streams the file at path through SHA256 and returns the digest as
// lowercase hex. It returns the empty string when the file cannot be opened
// or read.
func HashFile(path string) string {
	file, openErr := osFS.Open(path)
	if openErr != nil {
		return ""
	}
	defer func() { _ = file.Close() }()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, file); copyErr != nil {
		return ""
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
// IsVerifiedCMSFile reports whether the file at path is a known-clean CMS
// core file, i.e. whether its SHA256 hash is present in the global cache.
// It returns false when the cache is empty (the file is then not hashed at
// all) or when the file cannot be hashed.
//
// The cache is keyed by SHA256 hash alone (not path+hash) - this is correct:
//   - SHA256 preimage resistance makes it computationally infeasible for an
//     attacker to craft a file that produces the same hash as a legitimate
//     WP core file. Birthday attacks do not apply here because the attacker
//     must hit a specific pre-existing hash, not merely find any collision.
//   - If file content matches a known WP core file byte-for-byte, it IS that
//     file regardless of where it is located on disk. The path is irrelevant
//     to whether the content is clean.
func IsVerifiedCMSFile(path string) bool {
	known := GlobalCMSCache()
	if known.Size() == 0 {
		return false
	}
	if digest := HashFile(path); digest != "" {
		return known.Contains(digest)
	}
	return false
}
package checks
import (
"context"
"encoding/hex"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// safeRemotePorts and safeUsers are package-level so both legacy polling and
// the BPF live coordinator share one source of truth.
//
// safeRemotePorts lists destination ports that EvaluateConnection never
// reports.
var safeRemotePorts = map[uint16]bool{
53: true, 80: true, 443: true, 25: true, 587: true, 465: true,
993: true, 995: true, 110: true, 143: true,
}
// safeUsers lists account names whose connections EvaluateConnection never
// reports.
var safeUsers = map[string]bool{
"imunify360-webshield": true,
"named": true,
"mysql": true,
"memcached": true,
"icinga": true,
"dovecot": true,
"mailman": true,
}
// serverLocalPorts lists local ports for which EvaluateConnection ignores the
// connection regardless of its destination.
// NOTE(review): presumably these are the ports of services hosted on the
// machine itself - confirm against the deployment.
var serverLocalPorts = map[uint16]bool{
21: true, 25: true, 26: true, 53: true, 80: true, 110: true,
143: true, 443: true, 465: true, 587: true, 993: true, 995: true,
2082: true, 2083: true, 2086: true, 2087: true, 2095: true, 2096: true,
3306: true, 4190: true,
52223: true, 52224: true, 52227: true, 52228: true,
52229: true, 52230: true, 52231: true, 52232: true,
}
// EvaluateConnection returns a populated alert.Finding and true when the
// connection should be reported, or a zero finding and false when it should
// be ignored. Pure function: no IO, no clock. Used by the BPF live backend
// (per-event) and the polling backend (per row of /proc/net/tcp[6]).
//
// A connection is ignored when it belongs to root (uid 0), has a nil,
// loopback or unspecified destination, uses a local port in
// serverLocalPorts, targets a port in safeRemotePorts, targets an address
// accepted by isInfraIP for cfg.InfraIPs, or is owned by a user in
// safeUsers. IPv6 destinations are wrapped in brackets in the message.
func EvaluateConnection(
	cfg *config.Config,
	uid uint32,
	dstIP net.IP,
	dstPort uint16,
	localPort uint16,
	proto string,
	user string,
) (alert.Finding, bool) {
	switch {
	case uid == 0:
		return alert.Finding{}, false
	case dstIP == nil, dstIP.IsLoopback(), dstIP.IsUnspecified():
		return alert.Finding{}, false
	case serverLocalPorts[localPort]:
		return alert.Finding{}, false
	case safeRemotePorts[dstPort]:
		return alert.Finding{}, false
	case isInfraIP(dstIP.String(), cfg.InfraIPs):
		return alert.Finding{}, false
	case safeUsers[user]:
		return alert.Finding{}, false
	}

	target := dstIP.String()
	if dstIP.To4() == nil {
		target = "[" + target + "]"
	}

	finding := alert.Finding{
		Severity: alert.High,
		Check:    "user_outbound_connection",
		Message:  fmt.Sprintf("Non-root user connecting to unusual destination: %s:%d", target, dstPort),
		Details:  fmt.Sprintf("UID: %d (%s), Local port: %d, Proto: %s", uid, user, localPort, proto),
	}
	return finding, true
}
// CheckOutboundUserConnections looks for non-root user processes making
// outbound connections to IPs that aren't infra or well-known services.
// Catches compromised accounts phoning home.
//
// It scans /proc/net/tcp and then, when readable, /proc/net/tcp6. If
// /proc/net/tcp itself cannot be read the function returns nil without
// looking at tcp6 (historical behaviour that is deliberately preserved).
func CheckOutboundUserConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	tcp4Data, err := osFS.ReadFile("/proc/net/tcp")
	if err != nil {
		return nil
	}

	var findings []alert.Finding
	findings = append(findings, scanProcNetTCP(cfg, tcp4Data, false)...)

	tcp6Data, err := osFS.ReadFile("/proc/net/tcp6")
	if err == nil {
		findings = append(findings, scanProcNetTCP(cfg, tcp6Data, true)...)
	}
	return findings
}
// scanProcNetTCP parses one /proc/net/tcp[6] dump and returns findings for
// every ESTABLISHED row that EvaluateConnection flags.
//
// A first pass collects local sockets in LISTEN state so that an ESTABLISHED
// row whose local address and port have a listener is recognised as the
// accept side of an inbound connection (e.g. pure-ftpd PASV data channels,
// user-owned daemons listening on high ports) and not an outbound connect().
// Wildcard listeners match every local address for that port.
//
// Besides EvaluateConnection, each surviving row is also passed to
// EvaluateBadASNOutbound (for every UID, when an ASN lookup is available and
// the detector is enabled) and to EvaluateDirectSMTPEgress (non-root only,
// when the "legacy" backend is enabled). Every finding is stamped with
// time.Now().
//
// data is the raw file content; ipv6 selects the tcp6 address format and the
// "tcp6" protocol label.
func scanProcNetTCP(cfg *config.Config, data []byte, ipv6 bool) []alert.Finding {
var findings []alert.Finding
proto := "tcp"
if ipv6 {
proto = "tcp6"
}
lines := strings.Split(string(data), "\n")
listeners := collectListenSockets(lines, ipv6)
directSMTPEnabled := DirectSMTPEgressBackendEnabled(cfg, "legacy")
var mta platform.MTAIdents
if directSMTPEnabled {
// Resolve MTA identities once per scan; legacy poller has no per-PID context,
// so the EvaluateDirectSMTPEgress UID/user gate carries the load.
mta = platform.LocalMTAIdentities(platform.Detect())
}
for _, line := range lines {
fields := strings.Fields(line)
// Skip rows with too few columns and the header row (first column "sl").
if len(fields) < 8 || fields[0] == "sl" {
continue
}
// State 01 = ESTABLISHED
if fields[3] != "01" {
continue
}
// Column 7 is parsed as a decimal 32-bit UID; unparsable rows are skipped.
uidStr := fields[7]
uidU64, err := strconv.ParseUint(uidStr, 10, 32)
if err != nil {
continue
}
uidU32 := uint32(uidU64)
var (
localIP net.IP
dstIP net.IP
dstPort int
localPort int
)
// Column 1 is the local address, column 2 the remote address.
if ipv6 {
localIP, localPort = parseHex6Addr(fields[1])
dstIP, dstPort = parseHex6Addr(fields[2])
} else {
localAddr, parsedLocalPort := parseHexAddr(fields[1])
localIP = net.ParseIP(localAddr)
localPort = parsedLocalPort
remoteIP, remotePort := parseHexAddr(fields[2])
dstIP = net.ParseIP(remoteIP)
dstPort = remotePort
}
// Drop rows whose addresses or ports could not be parsed.
if localIP == nil || dstIP == nil || localPort <= 0 || dstPort <= 0 {
continue
}
// A listener on the local socket means this is the accept side of an
// inbound connection, not an outbound connect().
if listeners.has(localIP, localPort) {
continue
}
// Bad-ASN egress is classified for every UID, root included: a
// post-exploit root process exfiltrating to a bad ASN is the
// host-takeover signal. The live BPF tracker drops root events for
// flood control, so this periodic scan is where root egress is seen.
// The non-root detectors below intentionally skip root.
if lookup := CurrentASNLookup(); lookup != nil && cfg.Detection.BadASNOutbound.Enabled {
asn, org := lookup(dstIP.String())
if f, ok := EvaluateBadASNOutbound(cfg, dstIP, asn, org); ok {
f.Timestamp = time.Now()
findings = append(findings, f)
}
}
// Everything below applies to non-root connections only.
if uidU64 == 0 {
continue
}
user := LookupUser(uidU32)
// #nosec G115 -- ports parsed from /proc/net/tcp[6] are bounded by uint16.
if f, ok := EvaluateConnection(cfg, uidU32, dstIP, uint16(dstPort), uint16(localPort), proto, user); ok {
f.Timestamp = time.Now()
findings = append(findings, f)
}
if directSMTPEnabled {
// #nosec G115 -- ports parsed from /proc/net/tcp[6] are bounded by uint16.
if f, ok := EvaluateDirectSMTPEgress(cfg, DirectSMTPEgressInput{
UID: uidU32,
User: user,
DstIP: dstIP,
DstPort: uint16(dstPort),
MTA: mta,
}); ok {
f.Timestamp = time.Now()
findings = append(findings, f)
}
}
}
return findings
}
// listenSocket identifies one listening socket by its normalised local
// address (see normalizeListenIP) and port. It is used as a map key.
type listenSocket struct {
address string
port int
}
// listenSocketSet is the set of listening sockets collected from one
// /proc/net/tcp[6] dump.
type listenSocketSet struct {
// wildcardPorts holds ports bound on the unspecified address; such a
// listener matches every local address for that port.
wildcardPorts map[int]bool
// sockets holds listeners bound to a specific address.
sockets map[listenSocket]bool
}
// has reports whether a listener covers the given local address and port:
// either a wildcard listener on that port, or a listener bound to exactly
// that (normalised) address. Non-positive ports never match, and a nil ip
// can only match a wildcard listener.
func (s listenSocketSet) has(ip net.IP, port int) bool {
	switch {
	case port <= 0:
		return false
	case s.wildcardPorts[port]:
		return true
	case ip == nil:
		return false
	}
	key := listenSocket{address: normalizeListenIP(ip), port: port}
	return s.sockets[key]
}
// collectListenSockets scans /proc/net/tcp[6] rows for state 0A (LISTEN) and
// returns the set of local sockets a process is bound to. Listeners on the
// unspecified address are recorded as wildcard ports; all others are recorded
// by normalised address and port. Header rows, short rows and rows whose
// local address cannot be parsed are skipped.
func collectListenSockets(lines []string, ipv6 bool) listenSocketSet {
	set := listenSocketSet{
		wildcardPorts: make(map[int]bool),
		sockets:       make(map[listenSocket]bool),
	}
	for _, row := range lines {
		cols := strings.Fields(row)
		if len(cols) < 8 || cols[0] == "sl" || cols[3] != "0A" {
			continue
		}

		var (
			addr net.IP
			port int
		)
		if ipv6 {
			addr, port = parseHex6Addr(cols[1])
		} else {
			text, parsedPort := parseHexAddr(cols[1])
			addr, port = net.ParseIP(text), parsedPort
		}
		if addr == nil || port <= 0 {
			continue
		}

		if addr.IsUnspecified() {
			set.wildcardPorts[port] = true
		} else {
			set.sockets[listenSocket{address: normalizeListenIP(addr), port: port}] = true
		}
	}
	return set
}
// normalizeListenIP returns the textual form of ip, using the 4-byte IPv4
// representation whenever ip can be expressed as IPv4 (To4 succeeds).
func normalizeListenIP(ip net.IP) string {
	v4 := ip.To4()
	if v4 == nil {
		return ip.String()
	}
	return v4.String()
}
// parseHex6Addr parses an IPv6 address:port from /proc/net/tcp6 format.
// IPv6 addresses are 32 hex chars (128 bits) in little-endian 4-byte groups.
// It returns (nil, 0) when the separator is missing, the address part is not
// exactly 32 characters, the port is rejected by parseProcNetHexPort, or any
// 8-character group does not decode to 4 bytes.
func parseHex6Addr(s string) (net.IP, int) {
	pieces := strings.SplitN(s, ":", 2)
	if len(pieces) != 2 {
		return nil, 0
	}
	addrHex, portHex := pieces[0], pieces[1]
	if len(addrHex) != 32 {
		return nil, 0
	}
	port, ok := parseProcNetHexPort(portHex)
	if !ok {
		return nil, 0
	}

	// Parse as 4 little-endian 32-bit words.
	const wordBytes = 4
	addr := make(net.IP, 16)
	for w := 0; w < 4; w++ {
		// The decode error is not inspected: a failed decode does not yield
		// 4 bytes and is rejected by the length check.
		raw, _ := hex.DecodeString(addrHex[w*8 : (w+1)*8])
		if len(raw) != wordBytes {
			return nil, 0
		}
		// Reverse bytes within each 32-bit word (little-endian to big-endian).
		for j := 0; j < wordBytes; j++ {
			addr[w*wordBytes+j] = raw[wordBytes-1-j]
		}
	}
	return addr, port
}
// CheckSSHDConfig monitors sshd_config for dangerous changes.
//
// The file hash and the effective PasswordAuthentication / PermitRootLogin
// values are kept in the state store between runs. When a previous hash
// exists and differs from the current one, a Critical finding is raised for
// each of the two settings that has newly become "yes"; if neither did, a
// single High "sshd_config modified" finding is raised instead. The stored
// values are refreshed on every run. Returns nil when the file cannot be
// hashed.
func CheckSSHDConfig(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	const (
		hashKey = "_sshd_config_hash"
		passKey = "_sshd_passwordauthentication"
		rootKey = "_sshd_permitrootlogin"
	)

	hash, err := hashFileContent(sshdConfigPath)
	if err != nil {
		return nil
	}
	current := currentSSHDSettings()

	prevHash, exists := store.GetRaw(hashKey)
	prevPass, _ := store.GetRaw(passKey)
	prevRoot, _ := store.GetRaw(rootKey)

	var findings []alert.Finding
	if exists && prevHash != hash {
		// Only alert when the effective setting changed into a dangerous value.
		// This avoids false positives from commented defaults or Match blocks.
		passwordEnabled := current.PasswordAuthentication == "yes" && prevPass != "yes"
		rootEnabled := current.PermitRootLogin == "yes" && prevRoot != "yes"

		if passwordEnabled {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "sshd_config_change",
				Message:  "PasswordAuthentication changed to 'yes' in sshd_config",
				Details:  "This allows password-based SSH login - high risk if passwords are compromised",
			})
		}
		if rootEnabled {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "sshd_config_change",
				Message:  "PermitRootLogin changed to 'yes' in sshd_config",
			})
		}
		// Generic change alert if no specific dangerous setting found.
		if !passwordEnabled && !rootEnabled {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "sshd_config_change",
				Message:  "sshd_config modified",
			})
		}
	}

	store.SetRaw(hashKey, hash)
	store.SetRaw(passKey, current.PasswordAuthentication)
	store.SetRaw(rootKey, current.PermitRootLogin)
	return findings
}
// CheckNulledPlugins scans WordPress plugin directories for signs of
// nulled/pirated plugins.
//
// For every home directory returned by GetScanHomeDirs it looks at
// /home/<account>/public_html/wp-content/plugins/<plugin>/*.php, reads the
// first 10KB of each file and searches it case-insensitively for known
// crack/null signatures. Each matching file yields one High finding naming
// the account, plugin, file and the first signature that matched.
// Unreadable directories and files are skipped silently.
func CheckNulledPlugins(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	// Known crack/null signatures in PHP files.
	crackSignatures := []string{
		"nulled by", "cracked by", "gpl-club", "gpldl.com",
		"developer license", "remove license check",
		"license_key_bypass", "activation_bypass",
		"@remove_license", "null_license",
	}
	const headBytes = 10 * 1024

	var findings []alert.Finding
	homeDirs, _ := GetScanHomeDirs(ctx)
	for _, home := range homeDirs {
		if !home.IsDir() {
			continue
		}
		account := home.Name()
		pluginsRoot := filepath.Join("/home", account, "public_html", "wp-content", "plugins")
		entries, err := osFS.ReadDir(pluginsRoot)
		if err != nil {
			continue
		}

		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			pluginName := entry.Name()
			// Only top-level PHP files of the plugin are inspected.
			phpFiles, _ := osFS.Glob(filepath.Join(pluginsRoot, pluginName, "*.php"))
			for _, phpFile := range phpFiles {
				head := readFileHead(phpFile, headBytes)
				if head == nil {
					continue
				}
				lowered := strings.ToLower(string(head))
				for _, sig := range crackSignatures {
					if !strings.Contains(lowered, sig) {
						continue
					}
					findings = append(findings, alert.Finding{
						Severity: alert.High,
						Check:    "nulled_plugin",
						Message:  fmt.Sprintf("Possible nulled plugin: %s/%s", account, pluginName),
						Details:  fmt.Sprintf("File: %s\nSignature: %s", phpFile, sig),
					})
					// One finding per file: stop at the first signature.
					break
				}
			}
		}
	}
	return findings
}
// readFileHead reads up to the first maxBytes bytes of the file at path.
// It returns nil when maxBytes is not positive, when the file cannot be
// opened, or when no bytes could be read; otherwise it returns the bytes
// that were read (fewer than maxBytes for a shorter file).
//
// A single Read call is allowed to return fewer bytes than requested even
// when more data is available, so reading continues until the buffer is
// full, an error (including EOF) is reported, or a read makes no progress.
func readFileHead(path string, maxBytes int) []byte {
	if maxBytes <= 0 {
		return nil
	}
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	buf := make([]byte, maxBytes)
	filled := 0
	for filled < len(buf) {
		n, readErr := f.Read(buf[filled:])
		filled += n
		// Stop on EOF/error; the n == 0 guard prevents spinning on a reader
		// that keeps returning (0, nil).
		if readErr != nil || n == 0 {
			break
		}
	}
	if filled == 0 {
		return nil
	}
	return buf[:filled]
}
package checks
import (
"fmt"
"sort"
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// Checks that indicate active security events (not static config issues).
// Only these are used for cross-account correlation.
//
// securityEventChecks is keyed by the Check name of a finding;
// CorrelateFindings counts a Critical finding towards the coordinated-attack
// threshold only when its Check is present here.
var securityEventChecks = map[string]bool{
"fake_kernel_thread": true,
"suspicious_process": true,
"php_suspicious_execution": true,
"backdoor_binary": true,
"webshell": true,
"new_webshell_file": true,
"new_executable_in_config": true,
"new_php_in_uploads": true,
"new_php_in_languages": true,
"new_php_in_upgrade": true,
"obfuscated_php": true,
"php_dropper": true,
"webshell_realtime": true,
"php_in_uploads_realtime": true,
"php_in_sensitive_dir_realtime": true,
"executable_in_config_realtime": true,
"obfuscated_php_realtime": true,
"webshell_content_realtime": true,
"c2_connection": true,
"cpanel_file_upload_realtime": true,
"shadow_change": true,
"root_password_change": true,
}
// CorrelateFindings analyzes findings for cross-account attack patterns.
// Only considers active security events, not static config issues
// (like WAF status, open_basedir, world-writable files).
//
// Two correlations are produced, in this order:
//   - "coordinated_attack": raised once when 3 or more distinct accounts
//     have a Critical finding whose Check is in securityEventChecks.
//   - "cross_account_malware": raised per malware check type
//     (new_executable_in_config, backdoor_binary, webshell,
//     new_webshell_file) that appears in 2 or more distinct accounts.
//
// Account names are taken from extractAccountFromFinding; findings without a
// recognisable account are ignored. The returned slice contains only the
// extra correlation findings, with account lists sorted and the per-check
// findings emitted in sorted check-name order so the output is
// deterministic.
func CorrelateFindings(findings []alert.Finding) []alert.Finding {
	var extra []alert.Finding

	// Accounts with at least one critical security-event finding.
	affected := make(map[string]bool)
	for _, f := range findings {
		if f.Severity != alert.Critical || !securityEventChecks[f.Check] {
			continue
		}
		if account := extractAccountFromFinding(f); account != "" {
			affected[account] = true
		}
	}

	// If 3+ accounts have critical security events, it's a coordinated attack.
	if len(affected) >= 3 {
		accountNames := make([]string, 0, len(affected))
		for account := range affected {
			accountNames = append(accountNames, account)
		}
		sort.Strings(accountNames)
		extra = append(extra, alert.Finding{
			Severity: alert.Critical,
			Check:    "coordinated_attack",
			Message:  fmt.Sprintf("Possible coordinated attack: %d accounts have critical security events", len(accountNames)),
			Details:  fmt.Sprintf("Affected accounts: %s", strings.Join(accountNames, ", ")),
		})
	}

	// Check for same malware type across accounts.
	malwareByCheck := make(map[string][]string)
	for _, f := range findings {
		switch f.Check {
		case "new_executable_in_config", "backdoor_binary", "webshell", "new_webshell_file":
			if account := extractAccountFromFinding(f); account != "" {
				malwareByCheck[f.Check] = append(malwareByCheck[f.Check], account)
			}
		}
	}

	// Map iteration order is random; walk the check names in sorted order so
	// the emitted findings always come out in the same sequence.
	checks := make([]string, 0, len(malwareByCheck))
	for check := range malwareByCheck {
		checks = append(checks, check)
	}
	sort.Strings(checks)

	for _, check := range checks {
		unique := uniqueStrings(malwareByCheck[check])
		if len(unique) < 2 {
			continue
		}
		sort.Strings(unique)
		extra = append(extra, alert.Finding{
			Severity: alert.Critical,
			Check:    "cross_account_malware",
			Message:  fmt.Sprintf("Same malware type (%s) found in %d accounts", check, len(unique)),
			Details:  fmt.Sprintf("Accounts: %s", strings.Join(unique, ", ")),
		})
	}
	return extra
}
// extractAccountFromFinding derives an account name from the first
// "/home/<account>/" path segment found in the finding's Message, falling
// back to its Details. Only the first "/home/" occurrence in each field is
// examined, and it must be followed by a non-empty segment terminated by "/".
// Returns "" when neither field yields a name.
func extractAccountFromFinding(f alert.Finding) string {
	const marker = "/home/"
	for _, text := range [...]string{f.Message, f.Details} {
		pos := strings.Index(text, marker)
		if pos < 0 {
			continue
		}
		tail := text[pos+len(marker):]
		end := strings.IndexByte(tail, '/')
		if end <= 0 {
			// Either no terminating slash or an empty segment ("/home//").
			continue
		}
		return tail[:end]
	}
	return ""
}
// uniqueStrings returns the distinct values of input in first-occurrence
// order. The result is nil when input contains no elements.
func uniqueStrings(input []string) []string {
	var result []string
	visited := make(map[string]struct{})
	for _, candidate := range input {
		if _, dup := visited[candidate]; dup {
			continue
		}
		visited[candidate] = struct{}{}
		result = append(result, candidate)
	}
	return result
}
package checks
import (
"context"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Parameters for the cPanel session-log scan performed by CheckCpanelLogins.
const (
// sessionLogPath is the session log that CheckCpanelLogins tails.
sessionLogPath = "/usr/local/cpanel/logs/session_log"
// sessionLogTailLines is the number of trailing log lines requested from tailFile.
sessionLogTailLines = 1000
// defaultMultiIPThreshold is the distinct-IP count used by multiIPThreshold
// when the configured threshold is not positive.
defaultMultiIPThreshold = 3
// defaultMultiIPWindowMin is the look-back window, in minutes, used by
// multiIPWindowMin when the configured window is not positive.
defaultMultiIPWindowMin = 60
)
// CheckCpanelLogins scans the tail of the cPanel session log for suspicious
// login activity inside the configured look-back window:
//   - cpanel_login (Warning): a direct cpaneld login from a non-infra IP.
//     Suppressed entirely when cfg.Suppressions.SuppressCpanelLogin is set.
//   - cpanel_multi_ip_login (Critical): one account seen from at least
//     multiIPThreshold(cfg) distinct non-infra IPs.
//   - cpanel_password_purge (High): a PURGE ... password_change event, reported
//     once per account.
//
// Session log format:
//
//	[2026-03-25 07:19:47 +0200] info [cpaneld] 203.0.113.133 NEW user:token address=IP,...
//	[2026-03-25 07:38:18 +0200] info [security] internal PURGE user:token password_change
//
// Returns nil when the log tail is empty. The multi-IP findings and the IP
// list inside their Details follow map iteration order and are therefore not
// ordered deterministically.
func CheckCpanelLogins(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	lines := tailFile(sessionLogPath, sessionLogTailLines)
	if len(lines) == 0 {
		return nil
	}

	var (
		findings      []alert.Finding
		purgedAccount []string // accounts with a password purge, first-seen order
	)
	ipsByAccount := make(map[string]map[string]bool)
	purgeSeen := make(map[string]bool)
	suppressLogin := cfg.Suppressions.SuppressCpanelLogin
	windowMin := multiIPWindowMin(cfg)

	// Only events newer than the cutoff are considered.
	cutoff := time.Now().Add(-time.Duration(windowMin) * time.Minute)

	for _, line := range lines {
		when := parseSessionTimestamp(line)
		if when.IsZero() || when.Before(cutoff) {
			continue
		}

		if strings.Contains(line, "[cpaneld]") && strings.Contains(line, " NEW ") {
			// API/portal sessions (create_user_session) are not direct form
			// logins and are skipped -- but only while login alerts are
			// enabled. With alerts suppressed they still feed the multi-IP
			// correlation below.
			if !suppressLogin {
				viaAPI := strings.Contains(line, "method=create_user_session") ||
					strings.Contains(line, "method=create_session") ||
					strings.Contains(line, "create_user_session")
				if viaAPI {
					continue
				}
			}

			ip, account := parseCpanelLogin(line)
			if ip == "" || account == "" {
				continue
			}
			if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" || ip == "internal" {
				continue
			}

			// Always record the IP for multi-IP correlation.
			seenIPs, ok := ipsByAccount[account]
			if !ok {
				seenIPs = make(map[string]bool)
				ipsByAccount[account] = seenIPs
			}
			seenIPs[ip] = true

			// Logins are audit trail, not paging-level, hence Warning. The
			// multi-IP correlation raises its own Critical finding.
			if !suppressLogin {
				findings = append(findings, alert.Finding{
					Severity: alert.Warning,
					Check:    "cpanel_login",
					Message:  fmt.Sprintf("cPanel direct login from non-infra IP: %s (account: %s)", ip, account),
					Details:  truncateString(line, 300),
				})
			}
		}

		// Password change purge events, de-duplicated per account.
		if strings.Contains(line, "PURGE") && strings.Contains(line, "password_change") {
			if account := parsePurgeAccount(line); account != "" && !purgeSeen[account] {
				purgeSeen[account] = true
				purgedAccount = append(purgedAccount, account)
			}
		}
	}

	// Multi-IP correlation: same account from threshold+ distinct non-infra IPs.
	threshold := multiIPThreshold(cfg)
	for account, ips := range ipsByAccount {
		if len(ips) < threshold {
			continue
		}
		ipList := make([]string, 0, len(ips))
		for ip := range ips {
			ipList = append(ipList, ip)
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "cpanel_multi_ip_login",
			Message:  fmt.Sprintf("Account '%s' logged in from %d distinct IPs (credential compromise likely)", account, len(ips)),
			Details:  fmt.Sprintf("IPs: %s\nThreshold: %d IPs within %d minutes", strings.Join(ipList, ", "), threshold, windowMin),
		})
	}

	// The log line alone does not say whether a purge was triggered by an
	// automated response or by a person, so the finding stays neutral.
	for _, account := range purgedAccount {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "cpanel_password_purge",
			Message:  fmt.Sprintf("cPanel sessions purged via password change for account: %s", account),
			Details:  "This may indicate an automated security response or attacker-initiated password change",
		})
	}
	return findings
}
// CheckCpanelFileManager scans the last 300 lines of the cPanel access log
// and raises a Critical cpanel_file_upload finding for every File Manager
// write request that came from a non-infra IP.
//
// A line is considered only when it mentions "2083" (the cPanel port), was
// not answered with 401/403 (the server rejected it, so nothing was written),
// and its first whitespace-separated field is neither an infra IP nor
// 127.0.0.1. The action is matched against the quoted request line only, so
// that referers such as upload-ajax.html cannot trigger a match.
func CheckCpanelFileManager(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	// Write actions only -- read-only calls like get_homedir are ignored.
	// All entries are lowercase; the request line is lowercased before
	// comparison.
	writeEndpoints := [...]string{
		"fileman/save_file_content",
		"fileman/upload_files",
		"fileman/save_file",
		"fileman/paste",
		"fileman/rename",
		"fileman/delete",
	}

	var findings []alert.Finding
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 300) {
		if !strings.Contains(entry, "2083") {
			continue
		}
		if strings.Contains(entry, `" 401 `) || strings.Contains(entry, `" 403 `) {
			continue
		}

		cols := strings.Fields(entry)
		if len(cols) == 0 {
			continue
		}
		clientIP := cols[0]
		if isInfraIP(clientIP, cfg.InfraIPs) || clientIP == "127.0.0.1" {
			continue
		}

		request := strings.ToLower(extractRequestURIChecks(entry))
		isWrite := false
		for _, endpoint := range writeEndpoints {
			if strings.Contains(request, endpoint) {
				isWrite = true
				break
			}
		}
		if !isWrite {
			continue
		}

		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "cpanel_file_upload",
			Message:  fmt.Sprintf("cPanel File Manager write operation from non-infra IP: %s", clientIP),
			Details:  truncateString(entry, 300),
		})
	}
	return findings
}
// parseSessionTimestamp returns the timestamp enclosed by the first "[" and
// the first "]" of a session log line, e.g. [2026-03-25 07:19:47 +0200].
// The zero time.Time is returned when the brackets are missing, empty, out of
// order, or when the enclosed text matches none of the known layouts.
func parseSessionTimestamp(line string) time.Time {
	open := strings.IndexByte(line, '[')
	shut := strings.IndexByte(line, ']')
	if open < 0 || shut < 0 || shut <= open+1 {
		return time.Time{}
	}
	stamp := line[open+1 : shut]

	// Layouts seen in cPanel session logs, tried in order.
	layouts := [...]string{
		"2006-01-02 15:04:05 -0700",
		"2006-01-02 15:04:05 +0000",
	}
	for i := range layouts {
		parsed, err := time.Parse(layouts[i], stamp)
		if err != nil {
			continue
		}
		return parsed
	}
	return time.Time{}
}
// parseCpanelLogin extracts the client IP and the account name from a
// session NEW line.
// Format: [timestamp] info [cpaneld] 203.0.113.133 NEW user:token address=IP,...
//
// ip is the first field after "[cpaneld]"; account is the part before the
// first ":" of the field that follows the first "NEW" token. Both results are
// empty when "[cpaneld]" is missing or fewer than three fields follow it;
// account alone is empty when no "NEW" token with a successor is present.
func parseCpanelLogin(line string) (ip, account string) {
	const tag = "[cpaneld]"
	at := strings.Index(line, tag)
	if at < 0 {
		return "", ""
	}
	tokens := strings.Fields(line[at+len(tag):])
	if len(tokens) < 3 {
		return "", ""
	}
	ip = tokens[0]

	// Stop one short of the end: a trailing "NEW" has no user:token after it.
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] != "NEW" {
			continue
		}
		account = tokens[i+1]
		if colon := strings.IndexByte(account, ':'); colon >= 0 {
			account = account[:colon]
		}
		break
	}
	return ip, account
}
// parsePurgeAccount extracts the account name from a PURGE password_change line.
// Format: [timestamp] info [security] internal PURGE user:token password_change
//
// The account is the part before the first ":" of the first field that
// follows the first "PURGE" occurrence (the whole field when it has no ":").
// Returns "" when "PURGE" is missing or nothing follows it.
func parsePurgeAccount(line string) string {
	const verb = "PURGE"
	at := strings.Index(line, verb)
	if at < 0 {
		return ""
	}
	tokens := strings.Fields(line[at+len(verb):])
	if len(tokens) == 0 {
		return ""
	}
	userToken := tokens[0]
	if colon := strings.IndexByte(userToken, ':'); colon >= 0 {
		return userToken[:colon]
	}
	return userToken
}
// multiIPThreshold returns the number of distinct IPs that triggers the
// multi-IP login finding: cfg.Thresholds.MultiIPLoginThreshold when it is
// positive, otherwise defaultMultiIPThreshold.
func multiIPThreshold(cfg *config.Config) int {
	configured := cfg.Thresholds.MultiIPLoginThreshold
	if configured <= 0 {
		return defaultMultiIPThreshold
	}
	return configured
}
// multiIPWindowMin returns the login look-back window in minutes:
// cfg.Thresholds.MultiIPLoginWindowMin when it is positive, otherwise
// defaultMultiIPWindowMin.
func multiIPWindowMin(cfg *config.Config) int {
	configured := cfg.Thresholds.MultiIPLoginWindowMin
	if configured <= 0 {
		return defaultMultiIPWindowMin
	}
	return configured
}
// extractRequestURIChecks returns the text between the first pair of double
// quotes in an access log entry, which is the request line.
// Format: ... "METHOD /path HTTP/1.1" ... yields "METHOD /path HTTP/1.1".
// Returns "" when the line holds fewer than two double quotes.
func extractRequestURIChecks(line string) string {
	open := strings.IndexByte(line, '"')
	if open < 0 {
		return ""
	}
	inner := line[open+1:]
	shut := strings.IndexByte(inner, '"')
	if shut < 0 {
		return ""
	}
	return inner[:shut]
}
package checks
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"path/filepath"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// credentialReuseMinAccounts is the threshold for emitting the
// password-reuse finding: the same WordPress admin password hash present
// on this many or more distinct accounts. Two is the common shared-hosting
// pattern (an agency reusing one admin password across client sites) where
// a single credential disclosure compromises every site at once.
//
// CheckCredentialReuse passes this value to buildCredentialReuseFindings,
// which treats anything below 2 as 2.
const credentialReuseMinAccounts = 2
// CheckCredentialReuse flags WordPress administrator accounts that share
// an identical password hash across two or more distinct hosting accounts.
// Because it compares stored hashes for exact equality it is an at-rest
// hash reuse signal, not a weak-password detector.
//
// Every /home/*/public_html/wp-config.php is visited; sites whose account,
// database name or table prefix cannot be determined are skipped. The scan
// aborts and returns nil as soon as ctx is cancelled.
//
// Privacy: this function only ever handles the fingerprints returned by
// adminPasswordFingerprintsForSite, never a raw password hash, and the
// findings report the affected accounts and a count -- not the hash.
func CheckCredentialReuse(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	wpConfigs, _ := osFS.Glob("/home/*/public_html/wp-config.php")

	// fingerprint -> set of distinct accounts carrying that admin hash.
	accountsByFP := make(map[string]map[string]struct{})

	for _, cfgPath := range wpConfigs {
		if ctx.Err() != nil {
			return nil
		}

		owner := extractUser(filepath.Dir(cfgPath))
		if owner == "" {
			continue
		}
		creds := parseWPConfig(cfgPath)
		if creds.dbName == "" {
			continue
		}
		prefix, ok := resolveTablePrefix(creds)
		if !ok {
			continue
		}
		creds.tablePrefix = prefix

		for _, fp := range adminPasswordFingerprintsForSite(creds, prefix) {
			if fp == "" {
				continue
			}
			owners, known := accountsByFP[fp]
			if !known {
				owners = make(map[string]struct{})
				accountsByFP[fp] = owners
			}
			owners[owner] = struct{}{}
		}
	}

	return buildCredentialReuseFindings(accountsByFP, credentialReuseMinAccounts)
}
// adminPasswordFingerprintsForSite returns one fingerprint per distinct
// user_pass hash held by an administrator on the WordPress site. The raw
// hashes are turned into fingerprints immediately and never leave this
// function. The query goes through the root MySQL helper because wp-config
// passwords drift on cPanel hosts (same rationale as adminEmailsForSite).
//
// NOTE(review): prefix is interpolated into the SQL text verbatim and the
// statement runs as root -- confirm resolveTablePrefix only ever yields a
// validated identifier.
func adminPasswordFingerprintsForSite(creds wpDBCreds, prefix string) []string {
	const queryFormat = "SELECT DISTINCT u.user_pass FROM `%susers` u " +
		"JOIN `%susermeta` um ON u.ID = um.user_id " +
		"WHERE um.meta_key = '%scapabilities' AND um.meta_value LIKE '%%administrator%%'"

	var fingerprints []string
	for _, rawRow := range runMySQLQueryRoot(creds.dbName, fmt.Sprintf(queryFormat, prefix, prefix, prefix)) {
		fp := credentialHashFingerprint(strings.TrimSpace(rawRow))
		if fp == "" {
			continue
		}
		fingerprints = append(fingerprints, fp)
	}
	return fingerprints
}
// credentialHashFingerprint maps a raw password hash to a short grouping
// key: "fp:" followed by the first 16 hex characters of the SHA-256 digest
// of rawHash. Identical inputs always produce the same key, and the key
// does not contain the raw hash. Empty input yields "".
func credentialHashFingerprint(rawHash string) string {
	if len(rawHash) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(rawHash))
	// Eight digest bytes encode to exactly the 16 hex characters of the key.
	return "fp:" + hex.EncodeToString(digest[:8])
}
// buildCredentialReuseFindings emits one Warning per fingerprint whose
// account set holds at least minAccounts distinct accounts (values below 2
// are raised to 2). The fingerprint itself never appears in a finding --
// only the sorted account list and its size, so an operator can rotate the
// shared credential. Findings are ordered by their account lists, making the
// result independent of map iteration order. Returns nil when no group
// qualifies.
func buildCredentialReuseFindings(byFingerprint map[string]map[string]struct{}, minAccounts int) []alert.Finding {
	if minAccounts < 2 {
		minAccounts = 2
	}

	// A qualifying group, with the key it is ordered by computed once.
	type reuseGroup struct {
		members []string
		sortKey string
	}

	qualifying := make([]reuseGroup, 0, len(byFingerprint))
	for _, set := range byFingerprint {
		if len(set) < minAccounts {
			continue
		}
		members := make([]string, 0, len(set))
		for name := range set {
			members = append(members, name)
		}
		sort.Strings(members)
		qualifying = append(qualifying, reuseGroup{
			members: members,
			sortKey: strings.Join(members, "\x00"),
		})
	}
	sort.Slice(qualifying, func(a, b int) bool {
		return qualifying[a].sortKey < qualifying[b].sortKey
	})

	var out []alert.Finding
	for i := range qualifying {
		members := qualifying[i].members
		listed := strings.Join(members, ", ")
		out = append(out, alert.Finding{
			Severity: alert.Warning,
			Check:    "credential_reuse",
			Message: fmt.Sprintf("Identical WordPress admin password hash reused across %d accounts: %s",
				len(members), listed),
			Details: fmt.Sprintf("Accounts sharing one admin password hash: %s\n"+
				"Rotate the shared credential: a single disclosure compromises every listed site.",
				listed),
			Timestamp: time.Now(),
		})
	}
	return out
}
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"syscall"
"time"
)
// fixCrontabAllowedRoots limits suspicious_crontab remediation to the cron
// spool. Declared as a var so tests can redirect it under t.TempDir()
// without touching the real /var/spool/cron. fixSuspiciousCrontab hands it
// to resolveExistingFixPath as the set of allowed roots.
var fixCrontabAllowedRoots = []string{"/var/spool/cron"}
// fixSuspiciousCrontab copies a user crontab matching known-bad persistence
// markers into quarantine, writes a restore-ready metadata sidecar, and then
// truncates the live file to empty. Truncation (not deletion) keeps cron(8)
// from re-reading stale content and preserves the caller's ability to
// inspect file perms while the malware is gone.
//
// path is first passed through resolveExistingFixPath together with
// fixCrontabAllowedRoots; any error from that step, from reading the file,
// from writing the quarantine copy, or from the final truncation is reported
// in RemediationResult.Error. A failed metadata write is only logged to
// stderr and does not abort the remediation. The steps run strictly in the
// order quarantine copy -> metadata -> truncate, so the live file is never
// emptied before its content has been saved.
func fixSuspiciousCrontab(path string) RemediationResult {
if path == "" {
return RemediationResult{Error: "could not extract file path from finding"}
}
// From here on path is the resolved path and info its file info.
path, info, err := resolveExistingFixPath(path, fixCrontabAllowedRoots)
if err != nil {
return RemediationResult{Error: err.Error()}
}
data, err := osFS.ReadFile(path)
if err != nil {
return RemediationResult{Error: fmt.Sprintf("cannot read: %v", err)}
}
// Best effort: if the directory cannot be created, the quarantine write
// below fails and reports the problem.
_ = os.MkdirAll(quarantineDir, 0700)
// Quarantine name is "<timestamp>_crontab_<user>", where <user> is the
// base name of the crontab file. The timestamp has one-second resolution.
ts := time.Now().Format("20060102-150405")
user := filepath.Base(path)
qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_crontab_%s", ts, user))
if err := os.WriteFile(qPath, data, 0600); err != nil {
return RemediationResult{Error: fmt.Sprintf("cannot write quarantine: %v", err)}
}
// Ownership stays 0/0 when the platform does not expose a *syscall.Stat_t.
var uid, gid int
if stat, ok := info.Sys().(*syscall.Stat_t); ok {
uid = int(stat.Uid)
gid = int(stat.Gid)
}
// Sidecar describing the original file, stored next to the quarantined
// copy as "<qPath>.meta".
meta := map[string]interface{}{
"original_path": path,
"owner_uid": uid,
"group_gid": gid,
"mode": info.Mode().String(),
"size": info.Size(),
"quarantine_at": time.Now(),
"reason": "suspicious_crontab remediation",
}
metaData, _ := json.MarshalIndent(meta, "", " ")
if err := os.WriteFile(qPath+".meta", metaData, 0600); err != nil {
fmt.Fprintf(os.Stderr, "remediate: error writing crontab quarantine metadata %s: %v\n", qPath+".meta", err)
}
// Truncate live crontab. 0600 is the mode cron(8) enforces for user
// spool files; any other mode makes cron skip the file with a warning.
// #nosec G306 -- cron(8) rejects world-readable user crontabs, so 0600
// is the only safe mode for /var/spool/cron/<user>.
if err := os.WriteFile(path, []byte{}, 0600); err != nil {
return RemediationResult{Error: fmt.Sprintf("cannot truncate crontab: %v", err)}
}
return RemediationResult{
Success: true,
Action: fmt.Sprintf("quarantined crontab %s -> %s and truncated", path, qPath),
Description: fmt.Sprintf("Truncated %d-byte crontab; copy saved to quarantine", len(data)),
}
}
package checks
import (
"context"
"encoding/base64"
"fmt"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/state"
)
// crontabBase64Truncated counts base64 candidates that were longer than the
// per-blob cap returned by effectiveCrontabBase64BlobMaxBytes. Sustained
// growth means an attacker is padding outer blobs to push the real payload
// past the decode window; raise the cap or split the scanner.
//
// The counter is created and registered lazily, on the first call to
// observeCrontabBase64Truncation, guarded by crontabBase64TruncatedOnce. It
// is nil until then.
var (
crontabBase64Truncated *metrics.Counter
crontabBase64TruncatedOnce sync.Once
)
// observeCrontabBase64Truncation records one truncated base64 candidate.
// The first call creates the csm_checks_crontab_base64_truncated_total
// counter and registers it; every call, including the first, increments it.
func observeCrontabBase64Truncation() {
	const metricName = "csm_checks_crontab_base64_truncated_total"
	crontabBase64TruncatedOnce.Do(func() {
		crontabBase64Truncated = metrics.NewCounter(
			metricName,
			"Crontab base64 candidates that exceeded the per-blob decode cap before decoded-content pattern matching ran. Sustained growth means the scanner inspected only the leading decoded window of large encoded cron content.",
		)
		metrics.MustRegister(metricName, crontabBase64Truncated)
	})
	crontabBase64Truncated.Inc()
}
// CheckCrontabs inspects the cron configuration in three passes and returns
// the findings collected so far whenever ctx is cancelled (a nil ctx is
// treated as context.Background()):
//
//  1. /var/spool/cron/root: its content hash is compared with the hash kept
//     in store under "_crontab_root_hash"; a difference raises a High
//     crontab_change finding. The stored hash is then updated.
//  2. Every other /var/spool/cron/* file: one Critical suspicious_crontab
//     finding per pattern reported by MatchCrontabPatternsDeep.
//  3. /etc/cron.d/*: same hash comparison as pass 1, keyed "_crond:<name>",
//     raising a High crond_change finding on a difference.
//
// In passes 1 and 3 a file without a stored hash only has its baseline
// recorded; no finding is raised for it. Files that cannot be hashed or read
// are skipped.
func CheckCrontabs(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var findings []alert.Finding
	if ctx.Err() != nil {
		return findings
	}

	// Split the spool into root and per-account crontabs. Root stays outside
	// the account cap: it is a system baseline, not an account-scoped path.
	spool, _ := osFS.Glob("/var/spool/cron/*")
	var rootPaths []string
	userPaths := make([]string, 0, len(spool))
	for _, p := range spool {
		if filepath.Base(p) == "root" {
			rootPaths = append(rootPaths, p)
		} else {
			userPaths = append(userPaths, p)
		}
	}

	// Pass 1: root crontab change detection.
	const rootHashKey = "_crontab_root_hash"
	rankedRoot := rankPathsByMtimeDesc(ctx, rootPaths, 0)
	if ctx.Err() != nil {
		return findings
	}
	for _, p := range rankedRoot {
		if ctx.Err() != nil {
			return findings
		}
		sum, err := hashFileContent(p)
		if err != nil {
			continue
		}
		if ctx.Err() != nil {
			return findings
		}
		if old, known := store.GetRaw(rootHashKey); known && old != sum {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "crontab_change",
				Message:  "Root crontab modified",
				Details:  "Review with: crontab -l",
			})
		}
		store.SetRaw(rootHashKey, sum)
	}

	// Pass 2: account crontabs, most recently modified first so that
	// recently-touched users are processed before a timeout cuts the
	// iteration short.
	rankedUsers := rankPathsByMtimeDesc(ctx, userPaths, effectiveAccountScanMaxFiles(cfg))
	if ctx.Err() != nil {
		return findings
	}
	for _, p := range rankedUsers {
		if ctx.Err() != nil {
			return findings
		}
		raw, err := osFS.ReadFile(p)
		if err != nil {
			continue
		}
		if ctx.Err() != nil {
			return findings
		}
		owner := filepath.Base(p)
		body := string(raw)
		for _, hit := range MatchCrontabPatternsDeep(body, cfg) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "suspicious_crontab",
				Message:  fmt.Sprintf("Suspicious pattern in crontab for user %s: %s", owner, hit),
				Details:  fmt.Sprintf("File: %s\nContent:\n%s", p, truncate(body, 500)),
			})
		}
	}

	// Pass 3: /etc/cron.d change detection. This is a system directory, so
	// account_scan_max_files must not hide older cron.d baselines.
	if ctx.Err() != nil {
		return findings
	}
	cronD, _ := osFS.Glob("/etc/cron.d/*")
	rankedCronD := rankPathsByMtimeDesc(ctx, cronD, 0)
	if ctx.Err() != nil {
		return findings
	}
	for _, p := range rankedCronD {
		if ctx.Err() != nil {
			return findings
		}
		sum, err := hashFileContent(p)
		if err != nil {
			continue
		}
		if ctx.Err() != nil {
			return findings
		}
		hashKey := fmt.Sprintf("_crond:%s", filepath.Base(p))
		if old, known := store.GetRaw(hashKey); known && old != sum {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "crond_change",
				Message:  fmt.Sprintf("Cron.d file modified: %s", p),
			})
		}
		store.SetRaw(hashKey, sum)
	}

	return findings
}
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut; strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back
// (by at most three bytes) to the start of the rune that straddles maxLen, so
// valid UTF-8 input always yields valid UTF-8 output. Input that is not valid
// UTF-8 at the cut position is sliced at maxLen as-is. A negative maxLen is
// treated as 0 instead of panicking.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// s[cut] exists because cut <= maxLen < len(s). A byte of the form
	// 10xxxxxx is a UTF-8 continuation byte, i.e. the middle of a rune.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	if s[cut]&0xC0 == 0x80 {
		// More continuation bytes than any rune can have: not UTF-8, so
		// fall back to the plain byte cut.
		cut = maxLen
	}
	return s[:cut] + "..."
}
// crontabSuspiciousPatterns is the shared list of case-insensitive
// substrings that mark a crontab as likely malicious. Single source of
// truth for CheckCrontabs (system scan) and makeAccountCrontabCheck
// (per-account scan) so the two cannot drift apart.
//
// matchCrontabPatterns compares each entry, lowercased, against the
// lowercased file content as a plain substring, and reports matches in the
// order listed here -- so order is part of the output contract.
var crontabSuspiciousPatterns = []string{
"defunct-kernel",
"SEED PRNG",
"base64_decode",
"base64 -d|bash",
"base64 -d | bash",
"base64 --decode|bash",
"eval(",
"/dev/tcp/",
"gsocket",
"gs-netcat",
"reverse",
"bash -i",
"/bin/sh -i",
"nc -e",
"ncat -e",
"python -c",
"perl -e",
}
// matchCrontabPatterns reports which entries of crontabSuspiciousPatterns
// occur in content, comparing case-insensitively as plain substrings. The
// entries are returned exactly as spelled in the list and in list order;
// the result is nil when nothing matches.
func matchCrontabPatterns(content string) []string {
	var hits []string
	haystack := strings.ToLower(content)
	for i := range crontabSuspiciousPatterns {
		needle := crontabSuspiciousPatterns[i]
		if !strings.Contains(haystack, strings.ToLower(needle)) {
			continue
		}
		hits = append(hits, needle)
	}
	return hits
}
// crontabBase64BlobMaxBytesDefault is the built-in fallback cap for a
// single base64 candidate before decoding. 16384 encoded bytes
// (~12 KiB decoded) comfortably fits any realistic gsocket /
// `base64 -d|bash` payload while bounding work on adversarial input.
// Operator override: cfg.Thresholds.CrontabBase64BlobMaxBytes.
//
// Must stay a multiple of 4 -- standard padded base64 is decoded in 4-byte
// groups, so a candidate cut at this cap only decodes cleanly when the cut
// falls on a group boundary. The validator rejects non-aligned operator
// values.
const crontabBase64BlobMaxBytesDefault = 16384
// effectiveCrontabBase64BlobMaxBytes returns the per-candidate base64 cap:
// cfg.Thresholds.CrontabBase64BlobMaxBytes when cfg is non-nil and the value
// is positive, otherwise crontabBase64BlobMaxBytesDefault. No alignment
// check happens here; multiple-of-4 alignment of operator values is left to
// the config validator.
func effectiveCrontabBase64BlobMaxBytes(cfg *config.Config) int {
	if cfg != nil {
		if limit := cfg.Thresholds.CrontabBase64BlobMaxBytes; limit > 0 {
			return limit
		}
	}
	return crontabBase64BlobMaxBytesDefault
}
// crontabBase64BlobMaxCount caps the number of base64 candidates examined
// per crontab. A realistic gsocket cron entry has one outer blob; we
// allow enough headroom for a handful without doing unbounded work.
// MatchCrontabPatternsDeep passes it as the match limit to FindAllString, so
// only the first matches in file order are inspected.
const crontabBase64BlobMaxCount = 16
// crontabBase64BlobRE matches contiguous standard-alphabet base64 of
// length >= 40 (with optional padding). The 40-char floor avoids matching
// short config IDs and noise like Wordfence cookie names. The URL-safe
// alphabet ('-' and '_') is not matched.
var crontabBase64BlobRE = regexp.MustCompile(`[A-Za-z0-9+/]{40,}={0,2}`)
// MatchCrontabPatternsDeep is matchCrontabPatterns plus a single base64
// decode pass: it pulls out base64 candidates from content and re-runs
// pattern matching on the decoded bytes. Catches attackers who wrap the
// `base64 -d|bash` pipe chain in an outer base64 layer so the literal
// markers never appear in the cron file as written. Single decode depth;
// no recursion. cfg nil uses the built-in defaults; pass the live
// operator config to honour `thresholds.crontab_base64_blob_max_bytes`.
//
// The result lists the patterns found in the plain content first (in
// pattern-list order), followed by patterns found only in decoded content,
// each pattern at most once. At most crontabBase64BlobMaxCount candidates
// are examined, and a candidate longer than the cap is cut to the cap (and
// counted via observeCrontabBase64Truncation) before decoding.
//
// Candidates are decoded as padded standard base64 first. When that fails --
// typically because the blob was written without "=" padding, which the
// candidate regex accepts -- the padding is stripped and the blob is decoded
// as unpadded base64 instead, so dropping the padding does not hide a payload
// from the scan.
func MatchCrontabPatternsDeep(content string, cfg *config.Config) []string {
	maxBytes := effectiveCrontabBase64BlobMaxBytes(cfg)
	matched := matchCrontabPatterns(content)
	seen := make(map[string]bool, len(matched))
	for _, m := range matched {
		seen[m] = true
	}

	candidates := crontabBase64BlobRE.FindAllString(content, crontabBase64BlobMaxCount)
	for _, blob := range candidates {
		if len(blob) > maxBytes {
			observeCrontabBase64Truncation()
			blob = blob[:maxBytes]
		}

		decoded, err := base64.StdEncoding.DecodeString(blob)
		if err != nil {
			// Unpadded fallback. A length of 4k+1 can never be valid
			// base64, so the dangling character is dropped.
			raw := strings.TrimRight(blob, "=")
			if len(raw)%4 == 1 {
				raw = raw[:len(raw)-1]
			}
			decoded, err = base64.RawStdEncoding.DecodeString(raw)
		}
		if err != nil {
			continue
		}

		for _, m := range matchCrontabPatterns(string(decoded)) {
			if !seen[m] {
				matched = append(matched, m)
				seen[m] = true
			}
		}
	}
	return matched
}
package checks
import (
"context"
"fmt"
"path/filepath"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// adminEmailRetention bounds how far back an admin observation stays
// relevant for overlap detection. A contractor email seen six months
// ago on one account and never since is not actionable signal -- the
// access likely lapsed. The window is a deliberate balance against the
// alternative of evicting on every scan, which would lose overlaps when
// scans run asynchronously across customer accounts.
// CheckAdminEmailOverlap passes it to OverlappingAdminEmails.
const adminEmailRetention = 90 * 24 * time.Hour
// adminEmailDefaultMinAccounts is the default threshold for emitting
// the cross-account overlap finding. Matches the most common
// compromise pattern on shared hosting: a contractor administering two
// or more customer cPanels. A positive cfg.Detection.AdminOverlapMinAccounts
// takes precedence over this default.
const adminEmailDefaultMinAccounts = 2
// CheckAdminEmailOverlap records every WordPress administrator email
// encountered during an account scan into a server-wide bbolt bucket,
// then emits a Warning finding for each email whose owner list now
// spans the configured minimum number of distinct accounts. The
// detection surface is shared-hosting credential leakage: a single
// compromised contractor account is one credential disclosure away
// from administrator access on every site they touch.
//
// The check is silent when the bbolt store is unavailable (early
// daemon startup, test harness without state injection) -- it can't
// observe overlap without persistence between scans, and falling
// silent is better than a misleading partial result. It also returns nil
// when ctx is cancelled mid-scan or when the overlap query fails.
//
// Sites whose owning account cannot be derived from the wp-config path are
// skipped, exactly as CheckCredentialReuse does: recording an observation
// under an empty account name would let that empty name count as a distinct
// account in the overlap calculation.
//
// Emails matching cfg.Detection.AdminOverlapTrustedEmails or
// AdminOverlapTrustedDomains are removed before findings are built.
func CheckAdminEmailOverlap(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}

	now := time.Now()
	wpConfigs, _ := osFS.Glob("/home/*/public_html/wp-config.php")
	for _, wpConfig := range wpConfigs {
		if ctx.Err() != nil {
			return nil
		}
		account := extractUser(filepath.Dir(wpConfig))
		if account == "" {
			continue
		}
		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" {
			continue
		}
		prefix, ok := resolveTablePrefix(creds)
		if !ok {
			continue
		}
		creds.tablePrefix = prefix
		for _, email := range adminEmailsForSite(creds, prefix) {
			// Best effort: one failed write must not abort the scan of the
			// remaining sites.
			_ = db.RecordAdminEmail(email, account, creds.dbName, now)
		}
	}

	minAccounts := adminEmailDefaultMinAccounts
	if cfg != nil && cfg.Detection.AdminOverlapMinAccounts > 0 {
		minAccounts = cfg.Detection.AdminOverlapMinAccounts
	}
	overlaps, err := db.OverlappingAdminEmails(minAccounts, adminEmailRetention)
	if err != nil || len(overlaps) == 0 {
		return nil
	}
	overlaps = filterTrustedAdminOverlaps(overlaps, cfg)
	return buildAdminOverlapFindings(overlaps)
}
// adminEmailsForSite returns the distinct, lowercased email addresses of the
// administrators currently stored on the WordPress site, with surrounding
// whitespace removed and empty rows dropped. The query goes through the
// existing root-MySQL helper so it works on cPanel hosts where wp-config
// passwords drift.
//
// NOTE(review): prefix is interpolated into the SQL text verbatim and the
// statement runs as root -- confirm resolveTablePrefix only ever yields a
// validated identifier.
func adminEmailsForSite(creds wpDBCreds, prefix string) []string {
	const queryFormat = "SELECT DISTINCT LOWER(u.user_email) FROM `%susers` u " +
		"JOIN `%susermeta` um ON u.ID = um.user_id " +
		"WHERE um.meta_key = '%scapabilities' AND um.meta_value LIKE '%%administrator%%'"

	var emails []string
	for _, rawRow := range runMySQLQueryRoot(creds.dbName, fmt.Sprintf(queryFormat, prefix, prefix, prefix)) {
		if addr := strings.TrimSpace(rawRow); addr != "" {
			emails = append(emails, addr)
		}
	}
	return emails
}
// buildAdminOverlapFindings turns each overlap entry into exactly one
// Warning finding. Emails are processed in sorted order and every account
// list is de-duplicated and sorted, so the Message of a given overlap is
// identical from scan to scan and the downstream dedup layer can recognise
// it. Details additionally lists every owner entry (account, schema, last
// seen) in the order supplied. The result is an empty, non-nil slice when
// overlaps is empty.
func buildAdminOverlapFindings(overlaps map[string][]store.AdminEmailEntry) []alert.Finding {
	sortedEmails := make([]string, 0, len(overlaps))
	for addr := range overlaps {
		sortedEmails = append(sortedEmails, addr)
	}
	sort.Strings(sortedEmails)

	findings := make([]alert.Finding, 0, len(sortedEmails))
	for _, addr := range sortedEmails {
		entries := overlaps[addr]

		// Distinct account names across all owner entries.
		distinct := make(map[string]struct{}, len(entries))
		accounts := make([]string, 0, len(entries))
		for _, entry := range entries {
			if _, dup := distinct[entry.Account]; dup {
				continue
			}
			distinct[entry.Account] = struct{}{}
			accounts = append(accounts, entry.Account)
		}
		sort.Strings(accounts)
		accountList := strings.Join(accounts, ", ")

		var details strings.Builder
		fmt.Fprintf(&details, "Email: %s\nAccounts: %s\n", addr, accountList)
		for _, entry := range entries {
			fmt.Fprintf(&details, "- %s (schema %s, last seen %s)\n", entry.Account, entry.Schema, entry.LastSeen.Format(time.RFC3339))
		}

		findings = append(findings, alert.Finding{
			Severity:  alert.Warning,
			Check:     "admin_cross_account_overlap",
			Message:   fmt.Sprintf("Admin email %s appears on %d accounts: %s", addr, len(accounts), accountList),
			Details:   details.String(),
			Timestamp: time.Now(),
		})
	}
	return findings
}
// filterTrustedAdminOverlaps drops every overlap whose email is trusted
// according to trustedAdminOverlapEmail. When cfg is nil or configures
// neither trusted emails nor trusted domains, the input map itself is
// returned; otherwise a new map holding the remaining entries is built.
func filterTrustedAdminOverlaps(overlaps map[string][]store.AdminEmailEntry, cfg *config.Config) map[string][]store.AdminEmailEntry {
	if cfg == nil {
		return overlaps
	}
	trustConfigured := len(cfg.Detection.AdminOverlapTrustedEmails) > 0 ||
		len(cfg.Detection.AdminOverlapTrustedDomains) > 0
	if !trustConfigured {
		return overlaps
	}

	kept := make(map[string][]store.AdminEmailEntry, len(overlaps))
	for addr, entries := range overlaps {
		if !trustedAdminOverlapEmail(addr, cfg) {
			kept[addr] = entries
		}
	}
	return kept
}
// trustedAdminOverlapEmail reports whether email is exempt from overlap
// findings: either the whole address equals an entry of
// cfg.Detection.AdminOverlapTrustedEmails, or its domain (the part after the
// last "@") equals an entry of cfg.Detection.AdminOverlapTrustedDomains.
// Both sides of every comparison are trimmed and lowercased first. An empty
// address is never trusted. cfg must not be nil.
func trustedAdminOverlapEmail(email string, cfg *config.Config) bool {
	normalized := strings.ToLower(strings.TrimSpace(email))
	if normalized == "" {
		return false
	}

	for _, candidate := range cfg.Detection.AdminOverlapTrustedEmails {
		if strings.ToLower(strings.TrimSpace(candidate)) == normalized {
			return true
		}
	}

	if domain := adminEmailDomain(normalized); domain != "" {
		for _, candidate := range cfg.Detection.AdminOverlapTrustedDomains {
			if strings.ToLower(strings.TrimSpace(candidate)) == domain {
				return true
			}
		}
	}
	return false
}
// adminEmailDomain returns everything after the last "@" of email, or ""
// when there is no "@" or nothing follows it.
func adminEmailDomain(email string) string {
	at := strings.LastIndex(email, "@")
	if at >= 0 && at+1 < len(email) {
		return email[at+1:]
	}
	return ""
}
package checks
import (
"fmt"
"net"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// AutoRespondDBMalware dispatches database findings to their automated
// response handlers and returns the action findings those handlers produce.
// It does nothing (returns nil) unless both cfg.AutoResponse.Enabled and
// cfg.AutoResponse.CleanDatabase are set.
//
// Handled checks:
//   - db_options_injection -> handleMaliciousOption
//   - db_siteurl_hijack (siteurl/home pointing to malicious content)
//     -> handleSiteurlHijack
//   - db_malicious_trigger/event/procedure/function -> handleMaliciousDBObject
//
// Every other check is ignored here, notably:
//   - db_spam_injection (spam posts — needs manual review)
//   - db_post_injection (script in posts — too many FPs from page builders)
//
// Whether a handled finding is confident enough to act on is decided by the
// individual handler, not by this dispatcher.
func AutoRespondDBMalware(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !(cfg.AutoResponse.Enabled && cfg.AutoResponse.CleanDatabase) {
		return nil
	}

	var actions []alert.Finding
	for _, finding := range findings {
		var produced []alert.Finding
		switch finding.Check {
		case "db_options_injection":
			produced = handleMaliciousOption(cfg, finding)
		case "db_siteurl_hijack":
			produced = handleSiteurlHijack(cfg, finding)
		case "db_malicious_trigger", "db_malicious_event",
			"db_malicious_procedure", "db_malicious_function":
			produced = handleMaliciousDBObject(finding)
		default:
			continue
		}
		actions = append(actions, produced...)
	}
	return actions
}
// dbDropObjectFn is the seam through which handleMaliciousDBObject performs
// the backup-then-DROP. Overridden in tests so the routing and action
// emission can be exercised without a live MySQL server. Defaults to
// DBDropObject; handleMaliciousDBObject reads the Success and Message fields
// of its result.
var dbDropObjectFn = DBDropObject
// handleMaliciousDBObject drops a confirmed malicious stored database object
// (trigger, event, procedure or function) and reports the outcome.
//
// The caller is responsible for checking that auto-response database
// cleaning is enabled. The object kind is derived from the check name
// (db_malicious_<kind>); account, schema, kind and name are read from the
// header lines of f.Details. Nothing is done, and nil is returned, when:
//   - the check name does not map to a known object kind,
//   - any of the four header fields is missing, or
//   - the kind in the details disagrees with the kind in the check name.
//
// Otherwise exactly one "auto_response" finding of Warning severity is
// returned, describing either the successful drop or the failure reason.
// The drop itself goes through dbDropObjectFn; per the original notes,
// DBDropObject stores a SHOW CREATE backup before dropping so the action can
// be reversed.
func handleMaliciousDBObject(f alert.Finding) []alert.Finding {
	kind := maliciousDBObjectKind(f.Check)
	if kind == "" {
		return nil
	}

	account, schema, detailKind, name := parseDBObjectFindingDetails(f.Details)
	switch {
	case account == "", schema == "", detailKind == "", name == "":
		return nil
	case detailKind != kind:
		// Check name and details must agree before anything is dropped.
		return nil
	}

	// NOTE(review): the trailing false is presumably the preview flag; confirm
	// against DBDropObject's signature.
	outcome := dbDropObjectFn(account, schema, kind, name, false)

	var msg string
	if outcome.Success {
		msg = fmt.Sprintf("AUTO-DB-CLEAN: Dropped malicious %s %s.%s (backup retained for restore)", kind, schema, name)
	} else {
		msg = fmt.Sprintf("AUTO-DB-CLEAN failed to drop %s %s.%s: %s", kind, schema, name, outcome.Message)
	}

	return []alert.Finding{{
		Severity:  alert.Warning,
		Check:     "auto_response",
		Message:   msg,
		Timestamp: time.Now(),
	}}
}
// maliciousDBObjectKind maps a check name of the form "db_malicious_<kind>"
// to <kind>. It returns "" when the prefix is missing or when IsDBObjectKind
// rejects the remainder.
func maliciousDBObjectKind(check string) string {
	const prefix = "db_malicious_"
	if !strings.HasPrefix(check, prefix) {
		return ""
	}
	if candidate := check[len(prefix):]; IsDBObjectKind(candidate) {
		return candidate
	}
	return ""
}
// parseDBObjectFindingDetails reads the "Account: ", "Schema: ", "Kind: " and
// "Name: " header lines from the Details block of a db_malicious_<kind>
// finding. Missing headers yield "".
//
// The SQL body that follows the headers is attacker-controlled and may
// contain lines that imitate metadata. Two rules keep it from being trusted:
// scanning stops at the first line starting with "Body:", and for each key
// only the first value seen is kept.
func parseDBObjectFindingDetails(details string) (account, schema, kind, name string) {
	headers := []struct {
		label string
		dst   *string
	}{
		{"Account: ", &account},
		{"Schema: ", &schema},
		{"Kind: ", &kind},
		{"Name: ", &name},
	}

	for _, raw := range strings.Split(details, "\n") {
		row := strings.TrimSpace(raw)
		if strings.HasPrefix(row, "Body:") {
			break
		}
		for _, h := range headers {
			if !strings.HasPrefix(row, h.label) {
				continue
			}
			// First value wins; later duplicates are ignored.
			if *h.dst == "" {
				*h.dst = strings.TrimSpace(row[len(h.label):])
			}
			break
		}
	}
	return account, schema, kind, name
}
// handleMaliciousOption responds to a db_options_injection finding whose
// option value carries a confirmed malicious external script URL.
//
// It returns nil without touching anything when:
//   - the database or option name cannot be parsed from f.Details,
//   - the option name fails isValidOptionName,
//   - the option is one of CSM's own csm_backup_* options (they hold the
//     original malicious content for recovery; acting on them would cause
//     cascading backup loops),
//   - no credentials are found for the database,
//   - the option value reads back empty, or
//   - the full value contains no URL that extractMaliciousScriptURL flags.
//
// Otherwise, in this order, it:
//  1. emits one Critical "auto_block" finding per suspicious session IP,
//  2. revokes sessions of users with suspicious IPs and reports the count,
//  3. backs up and cleans the option, reporting success if it was cleaned.
func handleMaliciousOption(cfg *config.Config, f alert.Finding) []alert.Finding {
	dbName, optionName := parseDBFindingDetails(f.Details)
	switch {
	case dbName == "", optionName == "":
		return nil
	case !isValidOptionName(optionName):
		// The name is interpolated into SQL below; refuse anything implausible.
		return nil
	case strings.HasPrefix(optionName, "csm_backup_"):
		return nil
	}

	creds := findCredsForDB(dbName)
	if creds.dbName == "" {
		return nil
	}
	prefix := creds.tablePrefix
	if prefix == "" {
		prefix = "wp_"
	}

	// The finding's Details only carries a truncated preview, so classify the
	// value as it is stored right now.
	fullValue := readOptionValue(creds, prefix, optionName)
	if fullValue == "" {
		return nil
	}
	maliciousURL := extractMaliciousScriptURL(fullValue)
	if maliciousURL == "" {
		return nil
	}

	var actions []alert.Finding
	report := func(sev alert.Severity, check, msg string) {
		actions = append(actions, alert.Finding{
			Severity:  sev,
			Check:     check,
			Message:   msg,
			Timestamp: time.Now(),
		})
	}

	// 1. Emitted under the auto_block check so AutoBlockIPs processes them.
	for _, ip := range extractSuspiciousSessionIPs(creds, prefix, cfg.InfraIPs) {
		report(alert.Critical, "auto_block",
			fmt.Sprintf("AUTO-BLOCK: %s (active WP session on compromised site, DB: %s)", ip, dbName))
	}

	// 2. Only users with suspicious IPs lose their sessions, so an admin
	// working from an infra IP stays logged in.
	if revoked := revokeCompromisedSessions(creds, prefix, cfg.InfraIPs); revoked > 0 {
		report(alert.Warning, "auto_response",
			fmt.Sprintf("AUTO-DB-CLEAN: Revoked %d compromised WordPress sessions (DB: %s)", revoked, dbName))
	}

	// 3. Back up the original value, then strip the malicious script.
	if backupAndCleanOption(creds, prefix, optionName, fullValue, maliciousURL) {
		report(alert.Warning, "auto_response",
			fmt.Sprintf("AUTO-DB-CLEAN: Removed malicious script from wp_options '%s' (DB: %s, URL: %s)", optionName, dbName, maliciousURL))
	}

	return actions
}
// handleSiteurlHijack responds to a db_siteurl_hijack finding by emitting
// block findings for suspicious session IPs and revoking the sessions of the
// affected users. It never writes to the siteurl/home options themselves.
//
// It returns nil when the database name cannot be parsed from f.Details or
// when no credentials are found for that database.
func handleSiteurlHijack(cfg *config.Config, f alert.Finding) []alert.Finding {
	dbName, _ := parseDBFindingDetails(f.Details)
	if dbName == "" {
		return nil
	}

	creds := findCredsForDB(dbName)
	if creds.dbName == "" {
		return nil
	}
	prefix := creds.tablePrefix
	if prefix == "" {
		prefix = "wp_"
	}

	var actions []alert.Finding
	report := func(sev alert.Severity, check, msg string) {
		actions = append(actions, alert.Finding{
			Severity:  sev,
			Check:     check,
			Message:   msg,
			Timestamp: time.Now(),
		})
	}

	for _, ip := range extractSuspiciousSessionIPs(creds, prefix, cfg.InfraIPs) {
		report(alert.Critical, "auto_block",
			fmt.Sprintf("AUTO-BLOCK: %s (active session on hijacked site, DB: %s)", ip, dbName))
	}

	if revoked := revokeCompromisedSessions(creds, prefix, cfg.InfraIPs); revoked > 0 {
		report(alert.Warning, "auto_response",
			fmt.Sprintf("AUTO-DB-CLEAN: Revoked %d sessions on hijacked site (DB: %s)", revoked, dbName))
	}

	return actions
}
// --- URL analysis ---
// scriptSrcRe matches <script src="..."> or <script src=...> patterns.
// Accepts https://, http://, and protocol-relative // URLs, since real
// attackers use all three forms to load external payloads.
//
// Matching is case-insensitive. At least one character must separate
// "<script" from "src", the quote around the value is optional, and capture
// group 1 holds the URL, which ends at the first quote, whitespace or '>'.
var scriptSrcRe = regexp.MustCompile(`(?i)<script[^>]+src\s*=\s*["']?((?:https?:)?//[^"'\s>]+)`)
// knownSafeDomains are legitimate services that embed scripts in wp_options.
//
// Entries are lowercase registrable-style domains. isSafeScriptDomain treats
// a host as safe when it equals an entry or ends in "."+entry, so every
// subdomain of an entry is covered as well. That makes entries such as
// "connect.facebook.net" and "cdnjs.cloudflare.com" redundant with their
// parent entries; they are harmless and kept for readability.
var knownSafeDomains = []string{
"googletagmanager.com",
"google-analytics.com",
"googleapis.com",
"gstatic.com",
"google.com",
"facebook.net",
"facebook.com",
"fbcdn.net",
"connect.facebook.net",
"chimpstatic.com",
"mailchimp.com",
"hotjar.com",
"clarity.ms",
"cloudflare.com",
"cdnjs.cloudflare.com",
"jquery.com",
"jsdelivr.net",
"unpkg.com",
"wp.com",
"wordpress.com",
"gravatar.com",
"tawk.to",
"crisp.chat",
"tidio.co",
"intercom.io",
"zendesk.com",
"hubspot.com",
"hubspot.net",
"hs-scripts.com",
"hs-analytics.net",
"hsforms.com",
"mautic.net",
"pinterest.com",
"twitter.com",
"linkedin.com",
"addthis.com",
"sharethis.com",
"recaptcha.net",
"stripe.com",
"paypal.com",
"brevo-mail.com",
}
// extractMaliciousScriptURL scans content for <script src="..."> tags (as
// matched by scriptSrcRe) and returns the first URL that
// isAttackerScriptURL classifies as an attacker script. It returns "" when
// no such URL is present.
//
// Classification follows an attack-indicator model (see url_reputation.go):
// per the design notes, a URL is flagged only when it shows attacker
// markers such as a raw IP host, an abused TLD, plaintext HTTP, a known-bad
// exfil host, or no valid TLD. This replaced an allowlist-only model that
// raised HIGH-severity findings for legitimate third-party widgets whose
// domains were simply not on the allowlist.
//
// knownSafeDomains is retained as a fast path and operator-pre-approved
// list; see isAttackerScriptURL for how the two are composed.
func extractMaliciousScriptURL(content string) string {
	for _, groups := range scriptSrcRe.FindAllStringSubmatch(content, -1) {
		// groups[1] is the captured URL; the length guard is defensive.
		if len(groups) >= 2 && isAttackerScriptURL(groups[1]) {
			return groups[1]
		}
	}
	return ""
}
// isSafeScriptDomain reports whether the host of a script URL is one of
// knownSafeDomains or a subdomain of one.
//
// Accepted forms are https://host, http://host, protocol-relative //host,
// and any of those with a port, userinfo, path, query or fragment. The
// comparison is case-insensitive.
//
// The host is isolated the way a browser would resolve it, because the
// input is attacker-controlled markup:
//   - the authority ends at the first '/', '\', '?' or '#' (browsers treat
//     '\' like '/' in http(s) URLs), so "evil.tld?x=.google.com" is judged
//     on "evil.tld";
//   - anything up to the last '@' is userinfo, so
//     "google.com:x@evil.tld" is judged on "evil.tld";
//   - a trailing ":port" is dropped.
func isSafeScriptDomain(url string) bool {
	rest := strings.ToLower(url)
	rest = strings.TrimPrefix(rest, "https://")
	rest = strings.TrimPrefix(rest, "http://")
	rest = strings.TrimPrefix(rest, "//")

	// Cut the authority off from path, query and fragment.
	host := rest
	if end := strings.IndexAny(host, `/\?#`); end >= 0 {
		host = host[:end]
	}
	// Drop userinfo; the real host follows the last '@'.
	if at := strings.LastIndexByte(host, '@'); at >= 0 {
		host = host[at+1:]
	}
	// Drop the port.
	if colon := strings.IndexByte(host, ':'); colon >= 0 {
		host = host[:colon]
	}

	for _, safe := range knownSafeDomains {
		if host == safe || strings.HasSuffix(host, "."+safe) {
			return true
		}
	}
	return false
}
// --- Validation ---

// validOptionNameRe accepts only ASCII letters, digits, underscore, hyphen,
// colon and dot, over the whole string. Quotes, whitespace and every other
// character that could break out of a SQL string literal are rejected.
var validOptionNameRe = regexp.MustCompile(`^[a-zA-Z0-9_\-:.]+$`)

// isValidOptionName reports whether name is safe to interpolate into SQL as
// an option name: 1 to 191 bytes long and made up solely of the characters
// validOptionNameRe allows.
func isValidOptionName(name string) bool {
	if n := len(name); n == 0 || n > 191 {
		return false
	}
	return validOptionNameRe.MatchString(name)
}
// --- DB helpers ---

// parseDBFindingDetails extracts the database name and option name from the
// "Database: " and "Option: " lines of a finding's Details field. A missing
// key yields "".
//
// For each key only the first value is kept, and values are trimmed of
// surrounding whitespace. This mirrors parseDBObjectFindingDetails: Details
// also carries a preview of the stored option value, so any later line that
// imitates a header must not be able to redirect the automated response to a
// different database or option.
// NOTE(review): assumes the real header lines precede the content preview in
// Details; confirm against the code that builds these findings.
func parseDBFindingDetails(details string) (dbName, optionName string) {
	for _, line := range strings.Split(details, "\n") {
		line = strings.TrimSpace(line)
		switch {
		case strings.HasPrefix(line, "Database: "):
			if dbName == "" {
				dbName = strings.TrimSpace(strings.TrimPrefix(line, "Database: "))
			}
		case strings.HasPrefix(line, "Option: "):
			if optionName == "" {
				optionName = strings.TrimSpace(strings.TrimPrefix(line, "Option: "))
			}
		}
	}
	return dbName, optionName
}
// findCredsForDB returns the credentials parsed from the first wp-config.php
// whose database name equals dbName, or a zero wpDBCreds when none matches.
//
// Candidates are /home/*/public_html/wp-config.php followed by
// /home/*/*/wp-config.php; glob errors are ignored. A config whose
// $table_prefix fails resolveTablePrefix is skipped, because that value
// comes from a cPanel-user-writable file and ends up in SQL issued by
// handleMaliciousOption / handleSiteurlHijack. On success the returned
// creds carry the resolved prefix in tablePrefix.
func findCredsForDB(dbName string) wpDBCreds {
	var candidates []string
	for _, pattern := range []string{
		"/home/*/public_html/wp-config.php",
		"/home/*/*/wp-config.php",
	} {
		found, _ := osFS.Glob(pattern)
		candidates = append(candidates, found...)
	}

	for _, cfgPath := range candidates {
		creds := parseWPConfig(cfgPath)
		if creds.dbName != dbName {
			continue
		}
		if safePrefix, ok := resolveTablePrefix(creds); ok {
			creds.tablePrefix = safePrefix
			return creds
		}
	}
	return wpDBCreds{}
}
// readOptionValue fetches option_value for optionName from <prefix>options
// and returns the first line of the query output. It returns "" when the
// option name fails isValidOptionName or when the query yields no rows.
func readOptionValue(creds wpDBCreds, prefix, optionName string) string {
	if !isValidOptionName(optionName) {
		return ""
	}
	rows := runMySQLQuery(creds, fmt.Sprintf(
		"SELECT option_value FROM %soptions WHERE option_name='%s' LIMIT 1",
		prefix, escapeSQLString(optionName)))
	if len(rows) > 0 {
		return rows[0]
	}
	return ""
}
// wpSessionIPRe captures the value of each "ip" entry in a PHP-serialized
// WordPress session_tokens blob. Compiled once at package scope instead of
// on every call.
var wpSessionIPRe = regexp.MustCompile(`"ip";s:\d+:"([^"]+)"`)

// wpNumericIDRe matches a non-empty run of ASCII digits. It guards the
// user_id that is interpolated, unquoted, into the revoke UPDATE.
var wpNumericIDRe = regexp.MustCompile(`^[0-9]+$`)

// suspiciousIPsInSessionData returns, in order of appearance and without
// de-duplication, every IP in a serialized session_tokens value that parses
// as an IP address, is neither loopback nor private, and is not reported as
// infrastructure by isInfraIP.
func suspiciousIPsInSessionData(sessionData string, infraIPs []string) []string {
	var out []string
	for _, m := range wpSessionIPRe.FindAllStringSubmatch(sessionData, -1) {
		if len(m) < 2 {
			continue
		}
		ip := m[1]
		parsed := net.ParseIP(ip)
		if parsed == nil || parsed.IsLoopback() || parsed.IsPrivate() {
			continue
		}
		if isInfraIP(ip, infraIPs) {
			continue
		}
		out = append(out, ip)
	}
	return out
}

// extractSuspiciousSessionIPs reads every non-empty session_tokens row from
// <prefix>usermeta and returns the distinct IPs that are NOT infra IPs, not
// private, and not loopback, in first-seen order.
func extractSuspiciousSessionIPs(creds wpDBCreds, prefix string, infraIPs []string) []string {
	query := fmt.Sprintf(
		"SELECT meta_value FROM %susermeta WHERE meta_key='session_tokens' AND meta_value != ''",
		prefix)

	seen := make(map[string]bool)
	var ips []string
	for _, line := range runMySQLQuery(creds, query) {
		for _, ip := range suspiciousIPsInSessionData(line, infraIPs) {
			if seen[ip] {
				continue
			}
			seen[ip] = true
			ips = append(ips, ip)
		}
	}
	return ips
}

// revokeCompromisedSessions clears session_tokens only for WP users whose
// sessions contain at least one non-infra, non-private, non-loopback IP. It
// returns the number of users for whom a revoke statement was issued.
//
// Rows that do not split into "user_id<TAB>meta_value", or whose user_id is
// not purely numeric, are skipped: the id is placed unquoted into the
// UPDATE, where string escaping offers no protection.
// NOTE(review): the result of the UPDATE is not inspected, so the count
// reflects attempts, not confirmed changes.
func revokeCompromisedSessions(creds wpDBCreds, prefix string, infraIPs []string) int {
	query := fmt.Sprintf(
		"SELECT user_id, meta_value FROM %susermeta WHERE meta_key='session_tokens' AND meta_value != ''",
		prefix)

	revoked := 0
	for _, line := range runMySQLQuery(creds, query) {
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		userID := strings.TrimSpace(parts[0])
		if !wpNumericIDRe.MatchString(userID) {
			continue
		}
		if len(suspiciousIPsInSessionData(parts[1], infraIPs)) == 0 {
			continue
		}

		revokeQuery := fmt.Sprintf(
			"UPDATE %susermeta SET meta_value='' WHERE user_id=%s AND meta_key='session_tokens'",
			prefix, userID)
		runMySQLQuery(creds, revokeQuery)
		revoked++
	}
	return revoked
}
// backupAndCleanOption stores originalValue under a new backup option and
// then overwrites optionName with the value stripped of attacker scripts.
//
// It returns false, without issuing any SQL, when maliciousURL is empty,
// when the cleaned value still contains a script URL that
// extractMaliciousScriptURL flags, or when cleaning changed nothing. The
// second rule keeps removal locked to detection: a value that still carries
// a live payload is never written back or reported as cleaned. Plain-text
// mentions of the same URL are inert option data and do not block a cleanup.
//
// The backup option is named csm_backup_<optionName>_<unix-seconds>, cut to
// 191 bytes, and inserted with autoload 'no'.
// NOTE(review): the results of the INSERT and the UPDATE are discarded, so a
// true return means both statements were issued, not that they succeeded.
func backupAndCleanOption(creds wpDBCreds, prefix, optionName, originalValue, maliciousURL string) bool {
	if maliciousURL == "" {
		return false
	}

	cleaned := removeMaliciousScripts(originalValue)
	stillInfected := extractMaliciousScriptURL(cleaned) != ""
	if stillInfected || cleaned == originalValue {
		return false
	}

	const maxOptionNameLen = 191
	backupName := fmt.Sprintf("csm_backup_%s_%d", optionName, time.Now().Unix())
	if len(backupName) > maxOptionNameLen {
		backupName = backupName[:maxOptionNameLen]
	}

	// Preserve the original first, then write the cleaned value.
	runMySQLQuery(creds, fmt.Sprintf(
		"INSERT INTO %soptions (option_name, option_value, autoload) VALUES ('%s', '%s', 'no')",
		prefix, escapeSQLString(backupName), escapeSQLString(originalValue)))
	runMySQLQuery(creds, fmt.Sprintf(
		"UPDATE %soptions SET option_value='%s' WHERE option_name='%s'",
		prefix, escapeSQLString(cleaned), escapeSQLString(optionName)))
	return true
}
// --- Script removal ---

// maliciousScriptRe matches the style-break injection pattern:
// </style><script src=...></script><style>
var maliciousScriptRe = regexp.MustCompile(
	`(?i)</style>\s*<script[^>]*src\s*=\s*[^>]+>\s*</script>\s*<style>`)

// simpleScriptRe matches a standalone <script src="..."></script> tag. Its
// src grammar mirrors scriptSrcRe (https://, http:// and protocol-relative
// //) so that every URL form the detector can flag is one the remover can
// strip.
var simpleScriptRe = regexp.MustCompile(
	`(?i)<script[^>]*src\s*=\s*["']?(?:https?:)?//[^"'\s>]+["']?[^>]*>\s*</script>`)

// removeMaliciousScripts deletes attacker <script> injections from content
// and returns the result with surrounding whitespace trimmed.
//
// Two passes run in order: first the style-break wrapper
// (maliciousScriptRe), then standalone script tags (simpleScriptRe). In both
// passes a match is removed only when the URL extracted from it by
// scriptSrcRe is classified as an attacker script by isAttackerScriptURL,
// the same predicate extractMaliciousScriptURL uses. Anything the detector
// would not flag is left untouched, so cleaning an option that holds a real
// injection next to a legitimate third-party embed does not silently drop
// the legitimate embed.
func removeMaliciousScripts(content string) string {
	// stripIfAttacker erases a matched tag only when its src is an attacker URL.
	stripIfAttacker := func(tag string) string {
		groups := scriptSrcRe.FindStringSubmatch(tag)
		if len(groups) >= 2 && isAttackerScriptURL(groups[1]) {
			return ""
		}
		return tag
	}

	for _, pass := range []*regexp.Regexp{maliciousScriptRe, simpleScriptRe} {
		content = pass.ReplaceAllStringFunc(content, stripIfAttacker)
	}
	return strings.TrimSpace(content)
}
// escapeSQLString backslash-escapes the bytes that are special inside a
// single-quoted MySQL string literal: backslash, single quote, NUL, newline,
// carriage return and Ctrl-Z (0x1a). All other bytes pass through unchanged.
func escapeSQLString(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	// Every escaped value is a single ASCII byte, so a byte-wise walk cannot
	// split or alter multi-byte UTF-8 sequences.
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '\\':
			b.WriteString(`\\`)
		case '\'':
			b.WriteString(`\'`)
		case 0x00:
			b.WriteString(`\0`)
		case '\n':
			b.WriteString(`\n`)
		case '\r':
			b.WriteString(`\r`)
		case 0x1a:
			b.WriteString(`\Z`)
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}
package checks
import (
"context"
"fmt"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/mysqlclient"
)
// spamPostIDRe guards the ID list interpolated into root DELETE statements.
// It matches only a non-empty string made up entirely of ASCII digits, so
// any row of query output that is not a bare integer is dropped before the
// IN (...) list is built.
var spamPostIDRe = regexp.MustCompile(`^[0-9]+$`)
// DBCleanResult describes the outcome of a database cleanup operation.
// The zero value of Success is false, so a result is a failure until the
// operation explicitly marks it successful. Preview runs also set Success
// when they complete.
type DBCleanResult struct {
// Account is the cPanel account name the operation was requested for.
Account string
// Database is the WordPress database that was resolved for the account;
// empty when no database was found.
Database string
Action string // "clean-option", "revoke-user", "delete-spam"
// Success reports whether the operation (or its preview) completed.
Success bool
// Message is the one-line summary shown to the operator.
Message string
Details []string // individual actions taken
BackupNames []string // names of backup options created
}
// DBCleanOption removes malicious script injections from one wp_options
// value of the WordPress database belonging to account, backing the original
// up first. With preview set it only reports what would be done.
//
// The returned result has Success=false and an explanatory Message when the
// option name is invalid, no database is found, the option is missing or
// empty, no malicious script URL is detected, cleaning changes nothing,
// cleaning leaves a flagged script behind, or the backup-and-clean step
// reports failure.
// NOTE(review): the backup name reported here is recomputed from the current
// time after backupAndCleanOption returns; it can differ from the name
// actually written if the clock ticks over a second in between.
func DBCleanOption(account, optionName string, preview bool) DBCleanResult {
	result := DBCleanResult{Account: account, Action: "clean-option"}

	if !isValidOptionName(optionName) {
		result.Message = fmt.Sprintf("Invalid option name: %q", optionName)
		return result
	}

	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName

	current := readOptionValue(creds, prefix, optionName)
	if current == "" {
		result.Message = fmt.Sprintf("Option %q not found or empty in %s", optionName, creds.dbName)
		return result
	}

	maliciousURL := extractMaliciousScriptURL(current)
	if maliciousURL == "" {
		result.Message = fmt.Sprintf("No malicious external script found in %q", optionName)
		result.Details = append(result.Details, "Option exists but contains no confirmed malicious URLs")
		return result
	}

	// Dry-run the cleaner so both preview and real runs report the same facts.
	cleaned := removeMaliciousScripts(current)
	switch {
	case cleaned == current:
		result.Message = "Content unchanged after cleaning"
		return result
	case extractMaliciousScriptURL(cleaned) != "":
		result.Message = "Failed to remove all malicious scripts"
		return result
	}

	result.Details = append(result.Details,
		fmt.Sprintf("Malicious URL: %s", maliciousURL),
		fmt.Sprintf("Original length: %d, Cleaned length: %d", len(current), len(cleaned)))

	if preview {
		result.Message = fmt.Sprintf("PREVIEW: Would clean malicious script from %q", optionName)
		result.Success = true
		return result
	}

	if !backupAndCleanOption(creds, prefix, optionName, current, maliciousURL) {
		result.Message = "Failed to clean option"
		return result
	}

	backupName := fmt.Sprintf("csm_backup_%s_%d", optionName, time.Now().Unix())
	if len(backupName) > 191 {
		backupName = backupName[:191]
	}
	result.BackupNames = append(result.BackupNames, backupName)
	result.Details = append(result.Details, fmt.Sprintf("Backup saved as: %s", backupName))
	result.Message = fmt.Sprintf("Cleaned malicious script from %q", optionName)
	result.Success = true
	return result
}
// DBRevokeUser revokes all WordPress sessions of the user with the given ID
// in the database belonging to account, and optionally demotes the user to
// the subscriber role. With preview set it reports what would be done
// without modifying the database.
//
// The result has Success=false when no WordPress database is found for the
// account or when the user ID does not exist. Details lists the user's login
// and email, the number of active sessions, and each action taken.
//
// The session count is taken from the complete session_tokens value, one
// "expiration" key per stored session. Reading only a short prefix of the
// value would cut the serialized data off inside the first session entry and
// under-report users who are logged in from several places.
// NOTE(review): the results of the UPDATE statements are not inspected, so
// Success reflects that they were issued, not that rows changed.
func DBRevokeUser(account string, userID int, demote, preview bool) DBCleanResult {
	result := DBCleanResult{
		Account: account,
		Action:  "revoke-user",
	}
	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName

	// Verify user exists.
	query := fmt.Sprintf(
		"SELECT user_login, user_email FROM %susers WHERE ID=%d LIMIT 1",
		prefix, userID)
	lines := runMySQLQueryRoot(creds.dbName, query)
	if len(lines) == 0 {
		result.Message = fmt.Sprintf("User ID %d not found in %s", userID, creds.dbName)
		return result
	}
	parts := strings.SplitN(lines[0], "\t", 2)
	login := parts[0]
	email := ""
	if len(parts) > 1 {
		email = parts[1]
	}
	result.Details = append(result.Details, fmt.Sprintf("User: %s (email: %s)", login, email))

	// Count current sessions over the full value; sum across rows in case the
	// output is split over several lines.
	sessQuery := fmt.Sprintf(
		"SELECT meta_value FROM %susermeta WHERE user_id=%d AND meta_key='session_tokens'",
		prefix, userID)
	sessionCount := 0
	for _, row := range runMySQLQueryRoot(creds.dbName, sessQuery) {
		sessionCount += strings.Count(row, `"expiration"`)
	}
	result.Details = append(result.Details, fmt.Sprintf("Active sessions: %d", sessionCount))

	if preview {
		msg := fmt.Sprintf("PREVIEW: Would revoke %d sessions for user %s (ID %d)", sessionCount, login, userID)
		if demote {
			msg += " and demote to subscriber"
		}
		result.Message = msg
		result.Success = true
		return result
	}

	// Revoke sessions.
	revokeQuery := fmt.Sprintf(
		"UPDATE %susermeta SET meta_value='' WHERE user_id=%d AND meta_key='session_tokens'",
		prefix, userID)
	runMySQLQueryRoot(creds.dbName, revokeQuery)
	result.Details = append(result.Details, "Sessions revoked")

	// Demote to subscriber.
	if demote {
		// Look up the capabilities meta_key, whose name varies with the table
		// prefix. Only the first matching key is rewritten.
		capQuery := fmt.Sprintf(
			"SELECT meta_key FROM %susermeta WHERE user_id=%d AND meta_key LIKE '%%capabilities'",
			prefix, userID)
		capLines := runMySQLQueryRoot(creds.dbName, capQuery)
		if len(capLines) > 0 {
			capKey := capLines[0]
			demoteQuery := fmt.Sprintf(
				"UPDATE %susermeta SET meta_value='a:1:{s:10:\"subscriber\";b:1;}' WHERE user_id=%d AND meta_key='%s'",
				prefix, userID, escapeSQLString(capKey))
			runMySQLQueryRoot(creds.dbName, demoteQuery)
			result.Details = append(result.Details, "Demoted to subscriber role")
		}
	}

	result.Message = fmt.Sprintf("Revoked sessions for user %s (ID %d)", login, userID)
	result.Success = true
	return result
}
// DBDeleteSpam deletes published posts whose title or content matches one of
// a fixed set of spam keyword patterns, in the WordPress database belonging
// to account. Only rows with post_type='post' and post_status='publish' are
// targeted, which leaves pages, attachments and plugin data alone. With
// preview set it reports counts without deleting.
//
// Matching posts are first counted per pattern. The total is an upper bound
// ("up to"), since a post matching several patterns is counted once per
// pattern. Deletion then proceeds pattern by pattern in batches of 100 IDs:
// postmeta rows, then revisions, then the posts themselves. Only IDs that
// are plain digit strings are ever placed into a DELETE.
// NOTE(review): the DELETE results are not inspected; the reported number is
// the count of IDs submitted for deletion.
func DBDeleteSpam(account string, preview bool) DBCleanResult {
	result := DBCleanResult{Account: account, Action: "delete-spam"}

	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName

	patterns := []struct {
		keyword string
		sqlLike string
	}{
		{"casino", "%casino-%"},
		{"betting", "%betting%"},
		{"cialis", "%cialis%"},
		{"viagra", "%viagra%"},
		{"pharma", "%pharma%"},
		{"buy-cheap", "%buy-cheap-%"},
		{"crack-serial", "%crack-serial%"},
		{"free-download", "%free-download%"},
	}

	// spamPostsFrom builds the shared "<table> WHERE ..." tail. It uses plain
	// concatenation because the LIKE patterns contain '%', which a format
	// string would misread.
	spamPostsFrom := func(like string) string {
		return prefix + "posts WHERE post_type='post' AND post_status='publish' AND (post_content LIKE '" + like + "' OR post_title LIKE '" + like + "')"
	}

	totalCount := 0
	for _, p := range patterns {
		rows := runMySQLQueryRoot(creds.dbName, "SELECT COUNT(*) FROM "+spamPostsFrom(p.sqlLike))
		if len(rows) == 0 {
			continue
		}
		var count int
		if _, err := fmt.Sscanf(rows[0], "%d", &count); err != nil || count <= 0 {
			continue
		}
		result.Details = append(result.Details, fmt.Sprintf("%s: %d posts", p.keyword, count))
		totalCount += count
	}

	switch {
	case totalCount == 0:
		result.Message = "No spam posts found"
		result.Success = true
		return result
	case preview:
		result.Message = fmt.Sprintf("PREVIEW: Would delete up to %d spam posts", totalCount)
		result.Success = true
		return result
	}

	const batchSize = 100
	deleted := 0
	for _, p := range patterns {
		idLines := runMySQLQueryRoot(creds.dbName, "SELECT ID FROM "+spamPostsFrom(p.sqlLike))

		for start := 0; start < len(idLines); start += batchSize {
			end := start + batchSize
			if end > len(idLines) {
				end = len(idLines)
			}

			// Keep only purely numeric IDs; they go unquoted into IN (...).
			var validIDs []string
			for _, raw := range idLines[start:end] {
				if id := strings.TrimSpace(raw); spamPostIDRe.MatchString(id) {
					validIDs = append(validIDs, id)
				}
			}
			if len(validIDs) == 0 {
				continue
			}
			idList := strings.Join(validIDs, ",")

			// Dependants first (postmeta, revisions), then the posts themselves.
			statements := []string{
				fmt.Sprintf("DELETE FROM %spostmeta WHERE post_id IN (%s)", prefix, idList),
				fmt.Sprintf("DELETE FROM %sposts WHERE post_parent IN (%s) AND post_type='revision'",
					prefix, idList),
				fmt.Sprintf("DELETE FROM %sposts WHERE ID IN (%s) AND post_type='post' AND post_status='publish'",
					prefix, idList),
			}
			for _, stmt := range statements {
				runMySQLQueryRoot(creds.dbName, stmt)
			}
			deleted += len(validIDs)
		}
	}

	result.Message = fmt.Sprintf("Deleted %d spam posts and their metadata", deleted)
	result.Success = true
	return result
}
// FormatDBCleanResult renders r as terminal text: a status line
// "[OK|FAILED] <action> — <message>", an optional Database line when a
// database was resolved, and one line per entry in Details. Every line ends
// with a newline.
func FormatDBCleanResult(r DBCleanResult) string {
	label := "OK"
	if !r.Success {
		label = "FAILED"
	}

	var out strings.Builder
	fmt.Fprintf(&out, "[%s] %s — %s\n", label, r.Action, r.Message)
	if r.Database != "" {
		fmt.Fprintf(&out, " Database: %s\n", r.Database)
	}
	for i := range r.Details {
		fmt.Fprintf(&out, " %s\n", r.Details[i])
	}
	return out.String()
}
// --- helpers ---

// findCredsForAccount locates the WordPress database of a cPanel account and
// returns credentials for it together with the validated table prefix.
//
// It tries /home/<account>/public_html/wp-config.php first, then every
// /home/<account>/*/wp-config.php (glob errors ignored), and uses the first
// config that yields a database name and a prefix accepted by
// resolveTablePrefix. When none qualifies it returns a zero wpDBCreds and "".
//
// The returned creds have user and password blanked and host forced to
// "localhost": CSM runs as root and authenticates through /root/.my.cnf,
// because wp-config.php passwords are unreliable (cPanel password rotations
// don't always update the file).
func findCredsForAccount(account string) (wpDBCreds, string) {
	candidates := []string{
		fmt.Sprintf("/home/%s/public_html/wp-config.php", account),
	}
	addon, _ := osFS.Glob(fmt.Sprintf("/home/%s/*/wp-config.php", account))
	candidates = append(candidates, addon...)

	for _, cfgPath := range candidates {
		creds := parseWPConfig(cfgPath)
		if creds.dbName == "" {
			continue
		}
		safePrefix, ok := resolveTablePrefix(creds)
		if !ok {
			continue
		}
		creds.tablePrefix = safePrefix
		// Switch to root authentication.
		creds.dbUser, creds.dbPass, creds.dbHost = "", "", "localhost"
		return creds, safePrefix
	}
	return wpDBCreds{}, ""
}
// resolveTablePrefix returns the table prefix to use for the parsed
// wp-config, and whether it is safe. An empty prefix defaults to "wp_". A
// prefix that validTablePrefix rejects (anything outside [A-Za-z0-9_]+)
// yields ("", false), so callers refuse to interpolate an
// attacker-controlled $table_prefix into root-credentialled MySQL.
func resolveTablePrefix(creds wpDBCreds) (string, bool) {
	candidate := "wp_"
	if creds.tablePrefix != "" {
		candidate = creds.tablePrefix
	}
	if validTablePrefix.MatchString(candidate) {
		return candidate, true
	}
	return "", false
}
// runMySQLExecRoot runs a non-SELECT MySQL statement under root
// credentials and returns the underlying exec error verbatim. Used
// by the persistence-mechanism cleaner where success is signalled by
// a zero exit + empty stdout, which runMySQLQueryRoot misclassifies
// as "no output, treat as failure".
//
// dbName selects the schema the statement runs against and stmt is the
// statement text. The call uses context.Background(), so no caller deadline
// or cancellation applies here.
func runMySQLExecRoot(dbName, stmt string) error {
// Any output from the statement is discarded; only the error matters.
_, err := mysqlclient.RootExecSchema(context.Background(), dbName, stmt)
return err
}
// runMySQLQueryRoot runs query against dbName with root credentials from
// /root/.my.cnf (no explicit user/password) and returns the output rows,
// each trimmed of surrounding whitespace, with blank rows dropped.
//
// It returns nil when the query fails or produces no non-blank rows, so
// callers cannot tell an error from an empty result.
func runMySQLQueryRoot(dbName, query string) []string {
	rows, err := mysqlclient.RootQuerySchema(context.Background(), dbName, query)
	if err != nil {
		return nil
	}

	var kept []string
	for i := range rows {
		if trimmed := strings.TrimSpace(rows[i]); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	return kept
}
package checks
import (
"context"
"fmt"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// DB persistence-mechanism scanner.
//
// Vanilla CMS installs (WordPress, Joomla, Drupal, Magento, OpenCart)
// ship zero triggers / events / stored procedures / stored functions.
// Any presence is operator-review territory at minimum, and a body
// matching known-malware patterns is critical -- attacker persistence
// often re-injects on the next INSERT after a file-level cleanup, so
// detection here closes a real gap.
//
// The scanner reuses the existing parseWPConfig + runMySQLQuery
// infrastructure from dbscan.go. When the multi-CMS adapter layer
// lands later, this file's helpers become reusable across CMSes by
// swapping the credential discovery, not the queries.
//
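// Per spec: this scanner is detection only and never drops anything
// itself. Operators drop manually via `csm db-clean drop-object`; the
// only automated drop is the opt-in AutoRespondDBMalware path, which
// runs only when auto-response and its clean_database option are enabled.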
// dbPersistenceMalwarePatterns supplements dbMalwarePatterns from
// dbscan.go with MySQL-specific persistence-attack signals. Lowercase
// for case-insensitive matching against SQL bodies.
//
// The entries are plain substrings, not regular expressions.
// NOTE(review): case-insensitive matching only works if the caller
// lowercases the body before comparing; confirm at the call site.
var dbPersistenceMalwarePatterns = []string{
"sys_exec",
"lib_mysqludf_sys",
"into outfile",
"into dumpfile",
"load_file(",
"load data infile",
}
// magicTokenRegex extracts high-entropy activation tokens from trigger
// bodies that gate privileged actions on `display_name LIKE
// '%<token>%'`. Common display-name filters such as "%administrator%"
// must not escalate a merely unexpected trigger into Critical.
//
// Capture group 1 is the token: 10 to 32 characters from [A-Za-z0-9_-],
// wrapped in '%' inside single or double quotes. Matching is
// case-insensitive. The pattern checks only length and alphabet, so an
// ordinary word of that length (for example "administrator") also matches;
// any entropy or common-word filtering has to happen in the caller.
var magicTokenRegex = regexp.MustCompile(`(?i)display_name\s+like\s+['"]%([A-Za-z0-9_-]{10,32})%['"]`)
// validTablePrefix matches the character class WordPress accepts for
// $table_prefix (alphanumerics and underscore). Untrusted prefixes
// from a malformed wp-config.php fail this check, which keeps
// scanMagicTokenUsers from concatenating attacker-controlled data into
// its SQL literal.
//
// The pattern is anchored at both ends and requires at least one character.
// resolveTablePrefix uses it as the gate for every prefix that the db-clean
// helpers interpolate into SQL.
var validTablePrefix = regexp.MustCompile(`^[A-Za-z0-9_]+$`)
// dbPersistenceMalwareRegexes catches multi-token shapes that no
// substring set can match cleanly: role-escalation writes and
// password-hash exfiltration reads. Pre-compiled at package init time
// -- a regex parse error here is a build-time bug, not a runtime one.
//
// Patterns intentionally case-insensitive ((?i) prefix) and tolerant of
// whitespace / line breaks across MySQL trigger bodies. The role-write
// pattern requires the literal string "administrator" inside the
// serialized capabilities payload -- promotion to subscriber/customer
// is the legitimate WP-signup shape and must not match.
//
// MustCompile panics on an invalid pattern, so a bad expression fails at
// program start rather than during a scan.
var dbPersistenceMalwareRegexes = []*regexp.Regexp{
// Role escalation: UPDATE on *_usermeta writing administrator caps.
// The (?s) flag lets `.` match newlines so multi-line trigger
// bodies with the UPDATE split across lines still hit.
// "administrator" must be wrapped in double quotes or backticks (\x60)
// and may appear anywhere after "SET meta_value =", since `.*?` can span
// the rest of the body. The "s:13:" alternative is subsumed by the bare
// quoted alternative and is kept for readability.
regexp.MustCompile(`(?is)update\s+` + "`?" + `\w*usermeta` + "`?" + `\s+set\s+meta_value\s*=.*?(?:s:13:["\x60]administrator["\x60]|["\x60]administrator["\x60])`),
// Password-hash exfil read: SELECT user_pass FROM <users-like>
// table. Real WP code goes through wp_check_password() in PHP, never
// raw SELECT user_pass from SQL.
// Only the exact single-column form "SELECT user_pass FROM" matches; the
// table name may carry an optional leading backtick and any prefix.
regexp.MustCompile(`(?i)select\s+user_pass\s+from\s+` + "`?" + `\w*users`),
}
// dbObjectKind names the four MySQL object types this scanner
// inspects. Used in finding categories and CLI subcommands.
//
// The lowercase string values are load-bearing: toFinding embeds them
// in check names ("db_unexpected_<kind>" / "db_malicious_<kind>"),
// allowlistKey embeds them in suppression keys, and IsDBObjectKind
// compares operator input against them verbatim.
type dbObjectKind string
const (
dbObjectTrigger dbObjectKind = "trigger"
dbObjectEvent dbObjectKind = "event"
dbObjectProcedure dbObjectKind = "procedure"
dbObjectFunction dbObjectKind = "function"
)
// dbObjectAllKinds lists all valid kinds for the CLI's type validator
// (IsDBObjectKind). A kind added to the const block above must also be
// added here, otherwise operator input naming it is rejected.
var dbObjectAllKinds = []dbObjectKind{
dbObjectTrigger, dbObjectEvent, dbObjectProcedure, dbObjectFunction,
}
// dbObjectFinding describes one detector hit before it is converted
// to an alert.Finding -- carrying the structured fields the CLI
// drop-object subcommand needs to look up the same row.
//
// Account, Schema, Kind and Name together form the allowlist key
// (see allowlistKey); Body is only used for classification and for the
// truncated excerpt rendered by toFinding.
type dbObjectFinding struct {
Account string // account name derived from the wp-config.php path
Schema string // database (schema) the object was found in
Kind dbObjectKind // trigger / event / procedure / function
Name string // object name as returned by INFORMATION_SCHEMA
Body string // SQL body of the object, as returned by the query
IsMalw bool // true: malware pattern hit; false: unexpected presence
}
// CheckDatabaseObjects scans every WordPress installation's database
// for triggers, events, procedures, and functions. Critical findings
// fire when the body matches a known-malware pattern; Warning
// findings fire when an object exists at all (vanilla CMSes ship
// none). The Detection.DBObjectScanning kill-switch silences both
// emit paths without disabling the manual drop-object CLI.
//
// A nil ctx is treated as context.Background(). When ctx is cancelled
// mid-scan the findings collected so far are returned. Installs whose
// wp-config.php yields no DB name or DB user are skipped.
//
// Allowlist entries suppress only the Warning tier; a malware-pattern
// hit always produces a finding. The magic-token retro-scan still runs
// for allowlisted triggers.
func CheckDatabaseObjects(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !dbObjectScanningEnabled(cfg) {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}

	// The glob error is deliberately ignored: with no matches there is
	// nothing to scan either way.
	wpConfigs, _ := osFS.Glob("/home/*/public_html/wp-config.php")
	if len(wpConfigs) == 0 {
		return nil
	}
	allowlist := dbObjectAllowlistMap(cfg)

	var findings []alert.Finding
	// Rank by mtime desc so recently touched WP installs are processed
	// first when the check timeout cuts iteration short.
	for _, wpConfig := range rankPathsByMtimeDesc(ctx, wpConfigs, effectiveAccountScanMaxFiles(cfg)) {
		if ctx.Err() != nil {
			return findings
		}
		account := extractUser(filepath.Dir(wpConfig))
		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" || creds.dbUser == "" {
			continue
		}

		hits := scanDBObjects(account, creds)

		// Collect activation tokens across ALL triggers of this install
		// before querying. Backdoors commonly install several triggers
		// gated on the same token; scanning per trigger would repeat the
		// same query and emit the same Critical user finding once per
		// trigger.
		var tokens []string
		seenTokens := make(map[string]struct{})
		for _, h := range hits {
			if h.IsMalw || !allowlist[allowlistKey(h)] {
				findings = append(findings, h.toFinding())
			}
			if h.Kind != dbObjectTrigger {
				continue
			}
			for _, tok := range extractMagicTokens(h.Body) {
				if _, dup := seenTokens[tok]; dup {
					continue
				}
				seenTokens[tok] = struct{}{}
				tokens = append(tokens, tok)
			}
		}

		// Retro-scan: when a trigger gates a privileged action on a
		// secret token in display_name, find users whose display_name
		// still carries the token. Zero matches is itself useful for
		// the incident report ("no evidence the backdoor fired").
		// scanMagicTokenUsers returns nil for an empty token list, so
		// installs without token-gated triggers issue no query here.
		findings = append(findings, scanMagicTokenUsers(account, creds.dbName, creds.tablePrefix, tokens)...)
	}
	return findings
}
// scanDBObjects enumerates the triggers, events and stored routines
// of the schema named in creds.dbName and classifies each one. It
// issues three INFORMATION_SCHEMA queries (TRIGGERS, EVENTS, ROUTINES)
// through runMySQLQueryRoot and returns one dbObjectFinding per row
// that carries a non-empty object name. An empty dbName yields nil.
//
// Connections use root credentials via /root/.my.cnf rather than the
// credentials in wp-config.php: WP-config passwords are unreliable on
// cPanel hosts (password rotations don't update the file), so a
// WP-creds path would silently miss persistence objects. The db-clean
// code (db_clean.go: findCredsForAccount) follows the same reasoning.
func scanDBObjects(account string, creds wpDBCreds) []dbObjectFinding {
	schema := creds.dbName
	if schema == "" {
		return nil
	}
	literal := mysqlSchemaLiteral(schema)
	var found []dbObjectFinding

	// Triggers and events share the two-column (name, body) row shape,
	// so they run through the same loop. Order matters only for the
	// order of the returned findings: triggers first, then events.
	nameBodySources := []struct {
		kind   dbObjectKind
		format string
	}{
		{dbObjectTrigger, `SELECT TRIGGER_NAME, ACTION_STATEMENT FROM INFORMATION_SCHEMA.TRIGGERS WHERE TRIGGER_SCHEMA = %s`},
		{dbObjectEvent, `SELECT EVENT_NAME, EVENT_DEFINITION FROM INFORMATION_SCHEMA.EVENTS WHERE EVENT_SCHEMA = %s`},
	}
	for _, src := range nameBodySources {
		for _, row := range runMySQLQueryRoot(schema, fmt.Sprintf(src.format, literal)) {
			if name, body := splitTabRow(row); name != "" {
				found = append(found, classifyDBObject(account, schema, src.kind, name, body))
			}
		}
	}

	// Routines carry an extra ROUTINE_TYPE column that separates
	// functions from procedures; anything that is not "FUNCTION"
	// (case-insensitively) is treated as a procedure.
	routinesQuery := fmt.Sprintf(
		`SELECT ROUTINE_NAME, ROUTINE_TYPE, ROUTINE_DEFINITION FROM INFORMATION_SCHEMA.ROUTINES WHERE ROUTINE_SCHEMA = %s`,
		literal)
	for _, row := range runMySQLQueryRoot(schema, routinesQuery) {
		name, routineType, body := splitTabRow3(row)
		if name == "" {
			continue
		}
		kind := dbObjectFunction
		if !strings.EqualFold(routineType, "FUNCTION") {
			kind = dbObjectProcedure
		}
		found = append(found, classifyDBObject(account, schema, kind, name, body))
	}
	return found
}
// classifyDBObject packages one INFORMATION_SCHEMA row into a
// dbObjectFinding and tags it: IsMalw is true when the body matches a
// malware pattern (rendered as Critical), false when the object merely
// exists (rendered as Warning).
func classifyDBObject(account, schema string, kind dbObjectKind, name, body string) dbObjectFinding {
	hit := dbObjectFinding{
		Account: account,
		Schema:  schema,
		Kind:    kind,
		Name:    name,
		Body:    body,
	}
	hit.IsMalw = bodyHasMalwarePattern(body)
	return hit
}
// bodyHasMalwarePattern reports whether the SQL body trips any of the
// classifier tiers, checked in this order:
//
//  1. dbMalwarePatterns / dbPersistenceMalwarePatterns: substring
//     tokens (script/eval-style content markers, OS-exec UDFs and
//     file-IO sinks such as sys_exec or INTO OUTFILE).
//  2. extractMagicTokens: high-entropy display_name activation gates.
//  3. dbPersistenceMalwareRegexes: multi-token shapes for
//     role-escalation writes and password-hash exfiltration reads.
//
// Substring tiers compare against a lowercased copy of the body. The
// dbMalwarePatterns needles are lowercased here; the persistence
// needles are compared as written. The regex tier sees the original
// body and relies on each pattern's own `(?i)` flag.
func bodyHasMalwarePattern(body string) bool {
	folded := strings.ToLower(body)

	for i := range dbMalwarePatterns {
		if needle := strings.ToLower(dbMalwarePatterns[i].pattern); strings.Contains(folded, needle) {
			return true
		}
	}
	for _, needle := range dbPersistenceMalwarePatterns {
		if strings.Contains(folded, needle) {
			return true
		}
	}

	if tokens := extractMagicTokens(body); len(tokens) > 0 {
		return true
	}

	for i := range dbPersistenceMalwareRegexes {
		if dbPersistenceMalwareRegexes[i].MatchString(body) {
			return true
		}
	}
	return false
}
// toFinding renders the structured hit into the alert.Finding shape
// the rest of the pipeline expects. Finding category encodes both
// kind and severity tier so operators can suppress per attack type:
// "db_malicious_<kind>" at Critical for malware-pattern hits,
// "db_unexpected_<kind>" at Warning otherwise.
//
// The body is excerpted to at most 240 bytes plus "...". The cut is
// moved back to the nearest UTF-8 rune boundary so a multi-byte
// character straddling the limit is dropped whole rather than split
// into an invalid byte sequence in Details.
func (h dbObjectFinding) toFinding() alert.Finding {
	const maxExcerptBytes = 240

	check := fmt.Sprintf("db_unexpected_%s", h.Kind)
	severity := alert.Warning
	intro := "Unexpected"
	if h.IsMalw {
		check = fmt.Sprintf("db_malicious_%s", h.Kind)
		severity = alert.Critical
		intro = "Malicious"
	}

	excerpt := h.Body
	if len(excerpt) > maxExcerptBytes {
		cut := maxExcerptBytes
		// A byte of the form 10xxxxxx is a UTF-8 continuation byte: if
		// the byte AT the cut is one, the cut would land inside a rune.
		// A rune has at most three continuation bytes, so the back-off
		// is bounded even for bodies that are not valid UTF-8.
		for cut > maxExcerptBytes-3 && excerpt[cut]&0xC0 == 0x80 {
			cut--
		}
		excerpt = excerpt[:cut] + "..."
	}

	return alert.Finding{
		Severity:  severity,
		Check:     check,
		Message:   fmt.Sprintf("%s %s %s in %s.%s", intro, h.Kind, h.Name, h.Account, h.Schema),
		Details:   fmt.Sprintf("Account: %s\nSchema: %s\nKind: %s\nName: %s\nBody: %s", h.Account, h.Schema, h.Kind, h.Name, excerpt),
		Timestamp: time.Now(),
	}
}
// allowlistKey builds the suppression key for a hit in the form
// `<account>:<schema>:<type>:<name>`. Only Warning-tier hits are ever
// looked up under this key -- Critical malware-pattern hits always
// fire regardless of the allowlist.
func allowlistKey(h dbObjectFinding) string {
	fields := []string{h.Account, h.Schema, string(h.Kind), h.Name}
	return strings.Join(fields, ":")
}
// dbObjectAllowlistMap turns cfg.Detection.DBObjectAllowlist into a
// set keyed by the whitespace-trimmed entry. It always returns a
// non-nil map, so callers can index it directly even when cfg is nil
// or the allowlist is empty.
func dbObjectAllowlistMap(cfg *config.Config) map[string]bool {
	allowed := make(map[string]bool)
	if cfg != nil {
		for _, entry := range cfg.Detection.DBObjectAllowlist {
			allowed[strings.TrimSpace(entry)] = true
		}
	}
	return allowed
}
// splitTabRow splits a MySQL `-B -N` style row at its first tab and
// returns the two halves; any further tabs stay inside the second
// value. A row without a tab yields two empty strings.
func splitTabRow(row string) (string, string) {
	first, remainder, hasTab := strings.Cut(row, "\t")
	if !hasTab {
		return "", ""
	}
	return first, remainder
}
// splitTabRow3 splits a row at its first two tabs and returns the
// three resulting fields; tabs beyond the second stay inside the third
// value. Used for ROUTINES rows, which carry (name, type, body). A row
// with fewer than two tabs yields three empty strings.
func splitTabRow3(row string) (string, string, string) {
	first, tail, ok := strings.Cut(row, "\t")
	if !ok {
		return "", "", ""
	}
	second, third, ok := strings.Cut(tail, "\t")
	if !ok {
		return "", "", ""
	}
	return first, second, third
}
// mysqlSchemaLiteral renders name as a single-quoted SQL string
// literal, prefixing every backslash and single quote with a
// backslash. The DB name comes from wp-config.php
// (operator-controlled) so the risk is low, but the string-literal
// route is consistent with how the existing dbscan queries handle
// string args.
func mysqlSchemaLiteral(name string) string {
	var lit strings.Builder
	lit.Grow(len(name) + 2)
	lit.WriteByte('\'')
	for i := 0; i < len(name); i++ {
		if c := name[i]; c == '\\' || c == '\'' {
			lit.WriteByte('\\')
		}
		lit.WriteByte(name[i])
	}
	lit.WriteByte('\'')
	return lit.String()
}
// dbObjectScanningEnabled resolves the tri-state
// Detection.DBObjectScanning flag. Scanning is on by default: a nil
// cfg or an unset (nil) flag both return true, and only an explicit
// pointer to false turns it off.
func dbObjectScanningEnabled(cfg *config.Config) bool {
	if cfg == nil {
		return true
	}
	flag := cfg.Detection.DBObjectScanning
	return flag == nil || *flag
}
// IsDBObjectKind reports whether s is one of the four valid kinds.
// Used by the CLI subcommand to validate user input before opening
// a connection.
func IsDBObjectKind(s string) bool {
for _, k := range dbObjectAllKinds {
if string(k) == s {
return true
}
}
return false
}
// extractMagicTokens collects the secret activation tokens that a
// trigger body references through `display_name LIKE '%<token>%'`
// clauses. Each magicTokenRegex capture must also pass validMagicToken;
// survivors are returned in first-seen order with duplicates removed,
// which keeps the retro-scan's query count bounded when one trigger
// repeats the same token across branches. Both the body classifier and
// the user retro-scan rely on this helper.
//
// The result is nil when the body contains no qualifying token, so
// callers can skip MySQL entirely for benign bodies.
func extractMagicTokens(body string) []string {
	groups := magicTokenRegex.FindAllStringSubmatch(body, -1)
	if len(groups) == 0 {
		return nil
	}

	var tokens []string
	emitted := make(map[string]bool, len(groups))
	for _, g := range groups {
		// Defensive: the pattern has exactly one capture group, so a
		// match is expected to have two elements.
		if len(g) < 2 {
			continue
		}
		candidate := g[1]
		if emitted[candidate] || !validMagicToken(candidate) {
			continue
		}
		emitted[candidate] = true
		tokens = append(tokens, candidate)
	}
	return tokens
}
// validMagicToken reports whether tok looks like a high-entropy
// activation token: 10 to 32 bytes long, drawn only from
// [A-Za-z0-9_-], and containing at least one upper-case letter, one
// lower-case letter and one digit. The mixed-class requirement is what
// keeps ordinary words such as "administrator" from qualifying.
func validMagicToken(tok string) bool {
	if n := len(tok); n < 10 || n > 32 {
		return false
	}
	var upper, lower, digit bool
	// Byte-wise scan: every byte of a non-ASCII character is >= 0x80
	// and therefore lands in the rejecting default branch.
	for i := 0; i < len(tok); i++ {
		c := tok[i]
		switch {
		case 'A' <= c && c <= 'Z':
			upper = true
		case 'a' <= c && c <= 'z':
			lower = true
		case '0' <= c && c <= '9':
			digit = true
		case c == '_', c == '-':
			// Permitted separators; they do not count toward any class.
		default:
			return false
		}
	}
	return upper && lower && digit
}
// scanMagicTokenUsers looks in `<tablePrefix>users` for accounts whose
// display_name contains one of the backdoor activation tokens. A match
// is forensic evidence that the trigger fired against that user --
// they may still be administrator, or the attacker may have demoted
// them after promotion -- so every matching row becomes a Critical
// "db_magic_token_user" finding that needs manual review.
//
// Query construction is deliberately conservative. The table prefix
// must be non-empty and match validTablePrefix, otherwise no query is
// issued at all; each token is re-validated with validMagicToken and
// silently skipped if it fails. Only values from those restricted
// character classes are ever concatenated into the SQL text. Rows with
// fewer than four tab-separated fields are ignored.
func scanMagicTokenUsers(account, schema, tablePrefix string, tokens []string) []alert.Finding {
	switch {
	case len(tokens) == 0,
		tablePrefix == "",
		!validTablePrefix.MatchString(tablePrefix):
		return nil
	}

	const userQueryFormat = "SELECT ID, user_login, user_email, display_name FROM `%susers` WHERE display_name LIKE '%%%s%%'"

	var out []alert.Finding
	for _, token := range tokens {
		if !validMagicToken(token) {
			continue
		}
		for _, row := range runMySQLQueryRoot(schema, fmt.Sprintf(userQueryFormat, tablePrefix, token)) {
			fields := strings.SplitN(row, "\t", 4)
			if len(fields) < 4 {
				continue
			}
			id, login, email, display := fields[0], fields[1], fields[2], fields[3]
			message := fmt.Sprintf("User %s (ID %s) carries backdoor activation token in %s.%susers", login, id, account, tablePrefix)
			details := fmt.Sprintf("Account: %s\nSchema: %s\nToken: %s\nUser ID: %s\nUser login: %s\nUser email: %s\nDisplay name: %s", account, schema, token, id, login, email, display)
			out = append(out, alert.Finding{
				Severity:  alert.Critical,
				Check:     "db_magic_token_user",
				Message:   message,
				Details:   details,
				Timestamp: time.Now(),
			})
		}
	}
	return out
}
package checks
import (
"errors"
"fmt"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/store"
)
// reAccountName matches the cPanel-username shape we accept from
// operator CLI input: a leading ASCII letter followed by up to 31
// letters, digits, underscores or hyphens (32 characters total).
// Constrained on purpose: anything outside this charset would either
// fail later validation (QuoteIdent on schema) or escape /home via the
// path interpolation in findAccountSchemas, where the account name is
// formatted straight into a glob pattern -- a `*`, `/` or `..` there
// would widen or redirect the lookup.
var reAccountName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_-]{0,31}$`)
// errInvalidAccountName flags an account string that fails the
// reAccountName allowed-charset check. Surfaced through the CLI so the
// operator sees a clear error before any filesystem or SQL lookup.
// DBDropObject formats it into DBCleanResult.Message together with the
// rejected value rather than returning it as an error.
var errInvalidAccountName = errors.New("invalid account name (want [a-zA-Z][a-zA-Z0-9_-]{0,31})")
// DBDropObject drops a single trigger / event / stored procedure /
// stored function from the operator-supplied account+schema. Steps, in
// the order the code performs them:
//
// 1. Validate the account name against reAccountName and the kind
// ("trigger" | "event" | "procedure" | "function").
// 2. QuoteIdent both <schema> and <name>, so identifier strings
// never participate in raw SQL string concatenation.
// 3. Validate that <schema> is one of the databases this account
// hosts. The list comes from the wp-config.php files under
// /home/<account>/ (see findAccountSchemas); an attacker who can
// pass an arbitrary <schema> here gets no further than DROP'ping
// their own database.
// 4. SHOW CREATE the object to capture its definition.
// 5. Persist the captured definition to the db_object_backups bbolt
// bucket as the backup.
// 6. DROP the object.
//
// preview=true returns after step 4: the function reports what it
// would do (kind, schema, name, size of the captured CREATE SQL)
// without writing a backup or executing the DROP. Note that preview
// still runs the read-only SHOW CREATE query, so it does contact the
// database and fails if the object does not exist.
//
// Every failure is reported through the returned DBCleanResult
// (Success stays false and Message explains why); the function does
// not return an error value.
//
// Per spec: detection is always-on; drop is operator-driven.
func DBDropObject(account, schema, kind, name string, preview bool) DBCleanResult {
result := DBCleanResult{
Account: account,
Action: "drop-object",
}
if !reAccountName.MatchString(account) {
result.Message = fmt.Sprintf("%v: %q", errInvalidAccountName, account)
return result
}
if !IsDBObjectKind(kind) {
result.Message = fmt.Sprintf("Invalid object kind %q (want trigger|event|procedure|function)", kind)
return result
}
quotedSchema, err := QuoteIdent(schema)
if err != nil {
result.Message = fmt.Sprintf("Invalid schema name: %v", err)
return result
}
quotedName, err := QuoteIdent(name)
if err != nil {
result.Message = fmt.Sprintf("Invalid object name: %v", err)
return result
}
// Ownership check: the schema must be referenced by one of the
// account's own wp-config.php files.
knownSchemas := findAccountSchemas(account)
if !containsString(knownSchemas, schema) {
result.Message = fmt.Sprintf("Schema %q is not one of the databases discovered for account %q (known: %v)",
schema, account, knownSchemas)
return result
}
// Database is only filled in once the schema is known to belong to
// the account; earlier failure results leave it empty.
result.Database = schema
// SHOW CREATE captures the backup. Different MySQL grammars per
// kind: TRIGGER and EVENT use the schema-qualified name in
// `<schema>.<name>` form; PROCEDURE and FUNCTION accept the same
// shape under modern MySQL. Use the unified form for consistency.
showCreateSQL := fmt.Sprintf("SHOW CREATE %s %s.%s",
strings.ToUpper(kind), quotedSchema, quotedName)
createOutput := runMySQLQueryRoot(schema, showCreateSQL)
if len(createOutput) == 0 {
result.Message = fmt.Sprintf("SHOW CREATE returned no rows for %s %s.%s -- object missing or permission denied",
kind, schema, name)
return result
}
// NOTE(review): createSQL is every returned row joined with newlines.
// SHOW CREATE result sets normally carry columns besides the CREATE
// statement itself (object name, sql_mode, character set, ...) --
// confirm that what runMySQLQueryRoot returns here replays cleanly
// when RestoreDBObjectBackup executes it.
createSQL := strings.Join(createOutput, "\n")
if preview {
result.Message = fmt.Sprintf("PREVIEW: would drop %s %s.%s", kind, schema, name)
result.Details = []string{
fmt.Sprintf("Captured CREATE SQL (%d bytes)", len(createSQL)),
"No backup written and no DROP executed in preview mode.",
}
result.Success = true
return result
}
// Persist backup BEFORE the drop so a SQL failure on DROP still
// leaves the operator with a record of what existed.
sdb := store.Global()
if sdb == nil {
result.Message = "bbolt store not available; refusing to drop without a recorded backup"
return result
}
if err := sdb.PutDBObjectBackup(store.DBObjectBackup{
Account: account,
Schema: schema,
Kind: kind,
Name: name,
CreateSQL: createSQL,
DroppedAt: time.Now().UTC(),
DroppedBy: "csm-cli",
}); err != nil {
result.Message = fmt.Sprintf("recording backup failed (refusing to drop): %v", err)
return result
}
dropSQL := fmt.Sprintf("DROP %s IF EXISTS %s.%s",
strings.ToUpper(kind), quotedSchema, quotedName)
// runMySQLExecRoot reports the mysql client's exec error
// directly. The previous use of runMySQLQueryRoot misread a
// zero-exit + empty-stdout (the success signature for DROP) as
// failure.
// If the DROP fails, the backup record written above is left in
// place; it is not rolled back.
if err := runMySQLExecRoot(schema, dropSQL); err != nil {
result.Message = fmt.Sprintf("DROP %s %s.%s failed: %v", kind, schema, name, err)
return result
}
result.Details = []string{
fmt.Sprintf("Dropped %s %s.%s", kind, schema, name),
fmt.Sprintf("Backup recorded in bbolt: %d bytes", len(createSQL)),
}
result.Message = fmt.Sprintf("Dropped %s %s.%s (backup retained)", kind, schema, name)
result.Success = true
return result
}
// findAccountSchemas lists the distinct database names referenced by
// the account's wp-config.php files, in discovery order: first
// /home/<account>/public_html/wp-config.php, then every
// /home/<account>/*/wp-config.php match. Several WordPress installs
// under one account commonly share a database but may use different
// ones; the CLI validates operator-supplied schema names against this
// list before opening any connection.
//
// The account name is interpolated into the glob pattern unescaped, so
// callers must validate it first (DBDropObject uses reAccountName).
func findAccountSchemas(account string) []string {
	// A glob error is ignored: it simply leaves no addon candidates.
	addonConfigs, _ := osFS.Glob(fmt.Sprintf("/home/%s/*/wp-config.php", account))

	// The public_html path is added unconditionally. parseWPConfig
	// returns empty credentials for a file it cannot open, so a
	// missing public_html/wp-config.php is harmless, and a path that
	// also appears in the glob results is absorbed by the dedup below.
	candidates := make([]string, 0, len(addonConfigs)+1)
	candidates = append(candidates, fmt.Sprintf("/home/%s/public_html/wp-config.php", account))
	candidates = append(candidates, addonConfigs...)

	known := make(map[string]bool, len(candidates))
	var schemas []string
	for _, cfgPath := range candidates {
		dbName := parseWPConfig(cfgPath).dbName
		if dbName == "" || known[dbName] {
			continue
		}
		known[dbName] = true
		schemas = append(schemas, dbName)
	}
	return schemas
}
// containsString reports whether needle is an element of haystack
// (exact, case-sensitive match). It is defined here because the
// package's other helper of the same name lives in a _test.go file
// (waf_test.go) and is therefore invisible to production builds.
func containsString(haystack []string, needle string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
// RestoreDBObjectBackup recreates a previously dropped MySQL trigger /
// event / procedure / function by re-executing the CREATE SQL stored
// in the db_object_backups bbolt bucket under backupKey. The key is an
// exact lookup; the caller (typically the web UI's cleanup-history
// page) passes the key it received from the listing endpoint.
//
// Outcomes are reported through the returned DBCleanResult. A failure
// to record the restored-state marker after a successful CREATE does
// not fail the restore; it is appended to Details instead.
//
// Per spec the operation is operator-driven: there is no auto-
// restore. The webui handler enforces the same.
func RestoreDBObjectBackup(backupKey string) DBCleanResult {
	res := DBCleanResult{Action: "restore-object"}
	// fail stamps the message on the result as populated so far and
	// returns it with Success left false.
	fail := func(msg string) DBCleanResult {
		res.Message = msg
		return res
	}

	db := store.Global()
	if db == nil {
		return fail("bbolt store not available")
	}

	backup, found, err := db.GetDBObjectBackupByKey(backupKey)
	switch {
	case err != nil:
		return fail(fmt.Sprintf("looking up backup: %v", err))
	case !found:
		return fail("backup not found (may have been pruned)")
	}
	res.Account, res.Database = backup.Account, backup.Schema

	if backup.CreateSQL == "" {
		return fail("backup record has no CREATE SQL")
	}
	if execErr := runMySQLExecRoot(backup.Schema, backup.CreateSQL); execErr != nil {
		return fail(fmt.Sprintf("re-executing CREATE failed: %v", execErr))
	}

	res.Details = append(res.Details,
		fmt.Sprintf("Restored %s %s.%s", backup.Kind, backup.Schema, backup.Name),
		fmt.Sprintf("Original drop: %s by %s", backup.DroppedAt.Format(time.RFC3339), backup.DroppedBy),
	)
	// The object is back at this point; losing the bookkeeping write
	// is surfaced to the operator but does not undo the restore.
	if markErr := db.MarkDBObjectBackupRestored(backupKey, time.Now().UTC()); markErr != nil {
		res.Details = append(res.Details, fmt.Sprintf("Restore state was not recorded: %v", markErr))
	}

	res.Message = fmt.Sprintf("Restored %s %s.%s from backup", backup.Kind, backup.Schema, backup.Name)
	res.Success = true
	return res
}
package checks
import (
"bufio"
"context"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/mysqlclient"
"github.com/pidginhost/csm/internal/state"
)
// dbMalwarePatterns lists malicious patterns in WordPress database
// content. Each entry carries the substring to look for, the severity
// to report, a human-readable description, and a post-filter flag.
//
// requiresExternalScript: when true, a matching row is only reported if
// its content also contains a <script src=...> pointing at a domain NOT
// on the known-safe list. This filters out the legitimate analytics and
// widget embeds that site owners place in page content (Google Tag
// Manager, Google merchant badge, HubSpot, Mailchimp, etc.) without
// weakening detection of attacker-injected external loaders.
//
// bodyHasMalwarePattern also iterates this table when classifying
// trigger / event / routine bodies. That caller uses only the pattern
// field (compared case-insensitively) and ignores severity and
// requiresExternalScript, so any entry added here also makes a
// matching database object Critical.
var dbMalwarePatterns = []struct {
pattern string
severity alert.Severity
desc string
requiresExternalScript bool
}{
// The script-tag entry catches BOTH inline <script> blocks and
// <script src=...> loaders as a fast LIKE pre-filter; the Go post-
// filter (hasMaliciousExternalScript) verifies the presence of a
// non-safe-domain external src before raising a finding. Inline
// obfuscation without an external src is caught by the subsequent
// code-pattern entries below.
{"<script", alert.High, "injected <script> tag with non-safe external src", true},
{"eval(", alert.High, "eval() in database content", false},
{"base64_decode", alert.High, "base64_decode in database content", false},
{"document.write(", alert.High, "document.write injection", false},
{"String.fromCharCode", alert.High, "JavaScript obfuscation (fromCharCode)", false},
{".workers.dev", alert.Critical, "Cloudflare Workers exfiltration URL", false},
{"gist.githubusercontent.com", alert.Critical, "GitHub Gist payload URL", false},
{"pastebin.com/raw", alert.Critical, "Pastebin payload URL", false},
}
// CheckDatabaseContent scans WordPress databases for injected malware,
// spam content, siteurl hijacking, and rogue admin accounts.
//
// Every /home/*/public_html/wp-config.php found by homeGlob is parsed
// for credentials; installs without a DB name or DB user, or whose
// table prefix cannot be resolved, are skipped. For each remaining
// install the options, posts and users tables are checked, plus the
// per-blog options/posts tables when the install declares MULTISITE.
//
// Cancellation is honoured between installs: once ctx is done the
// findings gathered so far are returned, matching CheckDatabaseObjects.
// Queries already in flight are not interrupted, because runMySQLQuery
// applies its own timeout independent of ctx.
func CheckDatabaseContent(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	// A glob error is treated the same as "no installs found".
	wpConfigs, _ := homeGlob(ctx, "public_html", "wp-config.php")
	if len(wpConfigs) == 0 {
		return nil
	}

	var findings []alert.Finding
	for _, wpConfig := range wpConfigs {
		// Each install can cost several multi-minute queries; stop
		// starting new ones as soon as the check has been cancelled or
		// has timed out. ctx is nil-checked because this function never
		// required a non-nil context before.
		if ctx != nil && ctx.Err() != nil {
			return findings
		}

		user := extractUser(filepath.Dir(wpConfig))
		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" || creds.dbUser == "" {
			continue
		}
		prefix, ok := resolveTablePrefix(creds)
		if !ok {
			continue
		}
		creds.tablePrefix = prefix

		// Always scan the main-site (or single-site) tables. In
		// multisite, blog ID 1 keeps the unprefixed names; in a
		// single-site install these are the only tables.
		findings = append(findings, checkWPOptions(user, creds, prefix)...)
		findings = append(findings, checkWPPosts(user, creds, prefix)...)

		// wp_users / wp_usermeta are network-wide in multisite, so
		// the user-table scan runs once regardless of the layout.
		findings = append(findings, checkWPUsers(user, creds, prefix)...)

		// Multisite: enumerate active secondary blog IDs and scan
		// each one's wp_<N>_options / wp_<N>_posts. Spam, archived,
		// and deleted blogs are excluded -- their content is
		// already operator-suppressed at the WP level, and most
		// hosts have stale ones we'd otherwise alert on
		// indefinitely.
		if creds.multisite {
			findings = append(findings, scanMultisiteSecondaryBlogs(user, creds, prefix)...)
		}
	}
	return findings
}
// scanMultisiteSecondaryBlogs reads the active blog IDs (not archived,
// deleted or spam) from `<prefix>blogs` and runs the standard options
// and posts checks against each blog's `<prefix><id>_` tables.
//
// blog_id=1 is excluded, both in SQL and again in Go, because its
// tables are unprefixed and were already scanned by the caller. The
// user-table scan is not repeated per blog: WP shares wp_users /
// wp_usermeta across the whole network by default, and configurations
// that override that are ignored for v1.
func scanMultisiteSecondaryBlogs(user string, creds wpDBCreds, prefix string) []alert.Finding {
	blogRows := runMySQLQuery(creds, fmt.Sprintf(
		"SELECT blog_id FROM %sblogs WHERE archived = 0 AND deleted = 0 AND spam = 0 AND blog_id != 1",
		prefix,
	))

	var findings []alert.Finding
	for _, raw := range blogRows {
		id := strings.TrimSpace(raw)
		// The ID is concatenated into table names below, so anything
		// that is not purely digits is dropped. isAllDigits also
		// rejects the empty string.
		if id == "1" || !isAllDigits(id) {
			continue
		}
		blogPrefix := fmt.Sprintf("%s%s_", prefix, id)
		findings = append(findings, checkWPOptions(user, creds, blogPrefix)...)
		findings = append(findings, checkWPPosts(user, creds, blogPrefix)...)
	}
	return findings
}
// isAllDigits reports whether s is non-empty and consists solely of
// the ASCII digits 0-9. Signs, whitespace and non-ASCII digit
// characters all make it return false.
func isAllDigits(s string) bool {
	if len(s) == 0 {
		return false
	}
	// Byte-wise scan: every byte of a multi-byte character is >= 0x80
	// and so falls outside '0'..'9'.
	for i := 0; i < len(s); i++ {
		if c := s[i]; c < '0' || c > '9' {
			return false
		}
	}
	return true
}
// wpDBCreds holds the database settings parseWPConfig reads out of a
// wp-config.php. The zero value means "nothing could be parsed";
// callers treat an empty dbName (and, for content scans, an empty
// dbUser) as a signal to skip the install.
type wpDBCreds struct {
dbName string // value of the DB_NAME define
dbUser string // value of the DB_USER define
dbPass string // value of the DB_PASSWORD define
dbHost string // value of the DB_HOST define; parseWPConfig defaults it to "localhost"
tablePrefix string // quoted value on the `$table_prefix` line, as written in the file
// multisite is set when wp-config.php declares
// `define('MULTISITE', true)`. In multisite, the main blog
// (ID 1) keeps the unprefixed table names and secondary blogs
// live under `wp_<N>_options` / `wp_<N>_posts`. CheckDatabaseContent
// scans both layouts when this is set; a single-site install
// (multisite=false) skips the wp_blogs lookup and per-site
// iteration entirely.
multisite bool
}
// parseWPConfig extracts database credentials from wp-config.php.
//
// The file is read line by line. For every setting the last matching
// line wins. Recognised settings are the DB_NAME / DB_USER /
// DB_PASSWORD / DB_HOST defines, the `$table_prefix` assignment, and
// `define('MULTISITE', true)`. dbHost defaults to "localhost" when the
// file does not set it.
//
// Parsing is best-effort: a file that cannot be opened yields the zero
// wpDBCreds, and a read error part-way through returns whatever was
// collected up to that point.
func parseWPConfig(path string) wpDBCreds {
	f, err := osFS.Open(path)
	if err != nil {
		return wpDBCreds{}
	}
	defer func() { _ = f.Close() }()

	var creds wpDBCreds
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()

		// Match: define( 'DB_NAME', 'value' );
		if val := extractDefine(line, "DB_NAME"); val != "" {
			creds.dbName = val
		}
		if val := extractDefine(line, "DB_USER"); val != "" {
			creds.dbUser = val
		}
		if val := extractDefine(line, "DB_PASSWORD"); val != "" {
			creds.dbPass = val
		}
		if val := extractDefine(line, "DB_HOST"); val != "" {
			creds.dbHost = val
		}

		// Match: $table_prefix = 'wp_';
		// Commented-out assignments are skipped with the same three
		// comment markers the define helpers use. Otherwise a leftover
		// `// $table_prefix = 'old_';` after the real assignment would
		// override it and point every later query at tables that do
		// not exist.
		if strings.Contains(line, "$table_prefix") {
			trimmed := strings.TrimSpace(line)
			commented := strings.HasPrefix(trimmed, "//") ||
				strings.HasPrefix(trimmed, "#") ||
				strings.HasPrefix(trimmed, "/*")
			if !commented {
				if val := extractPHPString(line); val != "" {
					creds.tablePrefix = val
				}
			}
		}

		// Match: define( 'MULTISITE', true );
		if extractDefineBool(line, "MULTISITE") {
			creds.multisite = true
		}
	}

	if creds.dbHost == "" {
		creds.dbHost = "localhost"
	}
	return creds
}
// extractDefine extracts the value from: define( 'KEY', 'value' );
//
// It returns "" when the line is a comment (starts with //, # or /*
// after trimming), when the key is not present as a complete quoted
// PHP string ('KEY' or "KEY"), when no comma follows the key, or when
// no quoted value follows the comma.
//
// The key must match inside its own quotes. A bare substring match
// would let `define('DB_NAME_OLD', ...)` or `define('OLD_DB_NAME', ...)`
// overwrite the real DB_NAME. extractDefineBool guards against the
// same collision for WP_ALLOW_MULTISITE vs MULTISITE.
func extractDefine(line, key string) string {
	// Skip comments.
	trimmed := strings.TrimSpace(line)
	if strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "#") || strings.HasPrefix(trimmed, "/*") {
		return ""
	}

	// Locate the quoted key; single quotes are tried first.
	keyEnd := -1
	for _, quote := range []string{"'", `"`} {
		if idx := strings.Index(line, quote+key+quote); idx >= 0 {
			keyEnd = idx + len(quote) + len(key) + len(quote)
			break
		}
	}
	if keyEnd < 0 {
		return ""
	}

	// Step past the comma that separates key and value so that
	// extractPHPString starts at the VALUE's opening quote. A line
	// that names the key but has no comma after it (for example
	// `defined('DB_NAME')`) carries no value to extract.
	rest := line[keyEnd:]
	commaIdx := strings.Index(rest, ",")
	if commaIdx < 0 {
		return ""
	}
	return extractPHPString(rest[commaIdx+1:])
}
// extractDefineBool reports whether line is a non-comment
// define('<key>', true) -- a bare boolean value rather than a quoted
// string. Used for `MULTISITE` and any future bool defines CSM cares
// about. Whitespace around the value is ignored, the literal `true` is
// matched case-insensitively, and anything after the first closing
// paren (terminator, trailing comment) is disregarded.
//
// The key has to appear as a complete quoted PHP string, 'KEY' or
// "KEY". Matching it as a bare substring would collide with the
// common `define('WP_ALLOW_MULTISITE', true)` -- which merely enables
// the network creator on single-site installs -- and misreport those
// hosts as multisite.
//
// Only the canonical `true` literal counts. Spellings such as
// `!false`, `1` or `defined('FOO')` return false; that is preferable
// to running an arbitrary PHP expression evaluator over wp-config.php.
func extractDefineBool(line, key string) bool {
	stmt := strings.TrimSpace(line)
	for _, marker := range []string{"//", "#", "/*"} {
		if strings.HasPrefix(stmt, marker) {
			return false
		}
	}
	if !strings.Contains(stmt, "define") {
		return false
	}

	// Find the quoted key and keep only what follows its closing
	// quote. Single quotes are tried before double quotes.
	afterKey, keyFound := "", false
	for _, quote := range []string{"'", `"`} {
		if _, tail, ok := strings.Cut(stmt, quote+key+quote); ok {
			afterKey, keyFound = tail, true
			break
		}
	}
	if !keyFound {
		return false
	}

	// The value sits between the first comma and the first closing
	// paren after it (or the end of the line when there is no paren).
	_, value, hasComma := strings.Cut(afterKey, ",")
	if !hasComma {
		return false
	}
	if inner, _, closed := strings.Cut(value, ")"); closed {
		value = inner
	}
	return strings.EqualFold(strings.TrimSpace(value), "true")
}
// extractPHPString extracts the first quoted string value from a line.
//
// The opening quote is whichever of ' or " occurs first in s, and the
// value runs up to the next occurrence of that same character. The
// other quote character may therefore appear inside the value, so
// "O'Brien" yields O'Brien. If an opening quote has no matching close,
// it is skipped and the search continues with the next quote character
// after it. Returns "" when s contains no terminated quoted string.
//
// Backslash escapes are not interpreted: a value containing an escaped
// copy of its own delimiter is truncated at that delimiter.
func extractPHPString(s string) string {
	for {
		start := strings.IndexAny(s, `'"`)
		if start < 0 {
			return ""
		}
		quote := s[start]
		rest := s[start+1:]
		if end := strings.IndexByte(rest, quote); end >= 0 {
			return rest[:end]
		}
		// Unterminated opener (e.g. a stray apostrophe): drop it and
		// look for a later, properly closed string.
		s = rest
	}
}
// runMySQLQuery runs query against the install's database using the
// credentials from wp-config.php, through the in-process
// mysqlclient.PerAccountQuery driver. Each result row comes back as
// one tab-joined string, matching the legacy `mysql -N -B -e <query>`
// output shape so existing tab-split callers keep working unchanged.
//
// Rows are whitespace-trimmed and rows that are empty after trimming
// are dropped. The result is nil when nothing remains, and also on any
// driver error (the legacy implementation swallowed errors the same
// way). The query runs under its own 2-minute timeout derived from
// context.Background().
func runMySQLQuery(creds wpDBCreds, query string) []string {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	conn := mysqlclient.Creds{
		User:     creds.dbUser,
		Password: creds.dbPass,
		Host:     creds.dbHost,
		DBName:   creds.dbName,
	}
	raw, err := mysqlclient.PerAccountQuery(ctx, conn, query)
	if err != nil {
		return nil
	}

	kept := make([]string, 0, len(raw))
	for i := range raw {
		if trimmed := strings.TrimSpace(raw[i]); trimmed != "" {
			kept = append(kept, trimmed)
		}
	}
	if len(kept) > 0 {
		return kept
	}
	return nil
}
// checkWPOptions checks for siteurl/home hijacking and injected JavaScript
// in the `<prefix>options` table. It runs three queries through
// runMySQLQuery (the install's own credentials) and returns one
// Critical finding per offending row:
//
// - db_siteurl_hijack: siteurl or home contains "eval(" or "<script".
// - db_options_injection (path 1): any option holding a <script src=...>
// whose URL extractMaliciousScriptURL reports as malicious.
// - db_options_injection (path 2): a core option (siteurl, home,
// blogname, blogdescription, admin_email) containing "<script".
//
// A siteurl/home row containing "<script" can satisfy more than one of
// these and is then reported once per matching path.
//
// prefix is interpolated into the SQL text unescaped.
// NOTE(review): assumes callers pass a validated prefix (the
// CheckDatabaseContent path obtains it from resolveTablePrefix) --
// confirm for any other caller.
// NOTE(review): unlike the db-object findings, the findings built here
// do not set Timestamp -- confirm the pipeline fills it in.
func checkWPOptions(user string, creds wpDBCreds, prefix string) []alert.Finding {
var findings []alert.Finding
// Check siteurl and home for hijacking
// (admin_email is fetched by this query as well, but only siteurl and
// home are inspected below.)
query := fmt.Sprintf(
"SELECT option_name, option_value FROM %soptions WHERE option_name IN ('siteurl', 'home', 'admin_email') LIMIT 10",
prefix)
lines := runMySQLQuery(creds, query)
for _, line := range lines {
parts := strings.SplitN(line, "\t", 2)
if len(parts) != 2 {
continue
}
optName := parts[0]
// Lowercased copy for matching only; the finding shows the
// original value.
optValue := strings.ToLower(parts[1])
// Flag siteurl/home values that embed executable content
// ("eval(" or a "<script" tag). No domain comparison is done here.
if (optName == "siteurl" || optName == "home") &&
(strings.Contains(optValue, "eval(") || strings.Contains(optValue, "<script")) {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "db_siteurl_hijack",
Message: fmt.Sprintf("WordPress %s contains malicious code (account: %s)", optName, user),
Details: fmt.Sprintf("Database: %s\n%s = %s", creds.dbName, optName, truncateDB(parts[1], 200)),
})
}
}
// Path 1: External script URLs in any option — only flag non-safe domains.
// The LIKE clause is a coarse SQL-side pre-filter; the decision is
// made in Go by extractMaliciousScriptURL.
query = fmt.Sprintf(
"SELECT option_name, option_value FROM %soptions WHERE option_value LIKE '%%<script%%src=%%' LIMIT 20",
prefix)
lines = runMySQLQuery(creds, query)
for _, line := range lines {
parts := strings.SplitN(line, "\t", 2)
if len(parts) != 2 {
continue
}
optName := parts[0]
optValue := parts[1]
// Skip CSM backup options — they preserve the original malicious
// content for recovery and should not be re-detected/re-cleaned.
if strings.HasPrefix(optName, "csm_backup_") {
continue
}
// An empty result means no script URL in the value was judged
// malicious, so the row is not reported.
maliciousURL := extractMaliciousScriptURL(optValue)
if maliciousURL == "" {
continue
}
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "db_options_injection",
Message: fmt.Sprintf("Malicious script injection in wp_options '%s' (account: %s)", optName, user),
Details: fmt.Sprintf("Database: %s\nOption: %s\nMalicious URL: %s\nContent preview: %s", creds.dbName, optName, maliciousURL, truncateDB(optValue, 200)),
})
}
// Path 2: Inline script/code injection in core WP options that should
// NEVER contain JavaScript (siteurl, home, blogname, blogdescription,
// admin_email).
// coreOpts deliberately omits the outermost quotes: the format string
// below wraps it in ('%s'), which supplies the first opening and the
// last closing quote of the IN list.
coreOpts := "siteurl', 'home', 'blogname', 'blogdescription', 'admin_email"
codePatterns := "<script"
query = fmt.Sprintf(
"SELECT option_name, LEFT(option_value, 500) FROM %soptions WHERE option_name IN ('%s') AND option_value LIKE '%%%s%%'",
prefix, coreOpts, codePatterns)
lines = runMySQLQuery(creds, query)
for _, line := range lines {
parts := strings.SplitN(line, "\t", 2)
if len(parts) != 2 {
continue
}
// Every row returned by this query is reported; there is no Go-side
// post-filter on this path.
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "db_options_injection",
Message: fmt.Sprintf("Malicious content in core wp_option '%s' (account: %s)", parts[0], user),
Details: fmt.Sprintf("Database: %s\nOption: %s\nContent preview: %s", creds.dbName, parts[0], truncateDB(parts[1], 200)),
})
}
return findings
}
// checkWPPosts scans published rows of <prefix>posts for injected
// scripts/malware (db_post_injection) and cloaked spam keywords
// (db_spam_injection). user is the hosting account name and is only used
// in messages.
//
// Two classes of false positive are suppressed compared with a naive
// LIKE-based scan:
//
//   - post_types used for plugin-managed storage (form submissions,
//     revisions, templates, minified bundles) are excluded through the
//     shared nonScannablePostTypes denylist; see dbscan_filters.go for
//     the rationale and the full list.
//
//   - patterns that match too broadly in SQL (the bare <script
//     substring, bare-word spam keywords such as "cialis") are
//     re-checked in Go, so legitimate analytics embeds and substring
//     coincidences ("specialist" containing "cialis") stay silent.
//
// The denylist is defense-in-depth: custom post_types created by a theme
// or plugin remain in scope, so an attacker cannot evade the scan by
// inventing a new post_type value.
func checkWPPosts(user string, creds wpDBCreds, prefix string) []alert.Finding {
	// maxReportedIDs caps how many post IDs one malware finding lists.
	const maxReportedIDs = 5

	var findings []alert.Finding
	excludedTypes := nonScannablePostTypesSQLList()

	for _, mp := range dbMalwarePatterns {
		// ID is selected first so a row can always be split on its
		// first tab, whatever the content column contains.
		rows := runMySQLQuery(creds, fmt.Sprintf(
			"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND (post_content LIKE '%%%s%%' OR post_content_filtered LIKE '%%%s%%') LIMIT 20",
			prefix, excludedTypes, mp.pattern, mp.pattern))
		if len(rows) == 0 {
			continue
		}

		hitIDs := make([]string, 0, maxReportedIDs)
		for _, row := range rows {
			if len(hitIDs) >= maxReportedIDs {
				break
			}
			tab := strings.IndexByte(row, '\t')
			if tab < 0 {
				continue
			}
			postID, body := row[:tab], row[tab+1:]
			// Patterns flagged requiresExternalScript only count when
			// the post-specific predicate confirms a malicious external
			// script. That predicate ignores plaintext HTTP because
			// legacy author embeds from the pre-TLS era are legitimate
			// content, not injection.
			if mp.requiresExternalScript && !hasMaliciousExternalScriptInPost(body) {
				continue
			}
			hitIDs = append(hitIDs, postID)
		}
		if len(hitIDs) == 0 {
			continue
		}

		findings = append(findings, alert.Finding{
			Severity: mp.severity,
			Check:    "db_post_injection",
			Message:  fmt.Sprintf("WordPress posts contain %s (account: %s, %d posts)", mp.desc, user, len(hitIDs)),
			Details: fmt.Sprintf("Database: %s\nAffected post IDs: %s\nPattern: %s",
				creds.dbName, strings.Join(hitIDs, ", "), mp.pattern),
		})
	}

	// Spam keyword scan, filtered in three layers:
	//
	//  1. SQL LIKE as a fast server-side pre-filter (reduces rows).
	//  2. Word-boundary matching inside countCloakedSpamMatches, which
	//     rejects substring hits such as "specialist" / "cialis".
	//  3. The SEO-context requirement: a keyword only counts when it is
	//     accompanied by CSS cloaking, an injection fingerprint, or an
	//     external anchor whose URL path contains the keyword. Bare
	//     prose mentions do not fire.
	//
	// See spam_context.go for the full signal catalog.
	for _, sp := range dbSpamPatterns {
		rows := runMySQLQuery(creds, fmt.Sprintf(
			"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND post_content LIKE '%s' LIMIT 200",
			prefix, excludedTypes, sp.likeFragment))
		if len(rows) == 0 {
			continue
		}

		bodies := make([]string, 0, len(rows))
		for _, row := range rows {
			if tab := strings.IndexByte(row, '\t'); tab >= 0 {
				bodies = append(bodies, row[tab+1:])
			}
		}

		matched := countCloakedSpamMatches(sp, bodies)
		if matched == 0 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "db_spam_injection",
			Message:  fmt.Sprintf("WordPress posts contain cloaked spam keyword '%s' (%d posts, account: %s)", sp.keyword, matched, user),
			Details:  fmt.Sprintf("Database: %s", creds.dbName),
		})
	}

	return findings
}
// checkWPUsers looks for suspicious WordPress administrator accounts in
// one database. user is the hosting account name and is only used in
// messages. Two checks run:
//
//   - db_rogue_admin (Critical): every administrator registered within
//     the last 7 days (at most 10 rows).
//   - db_suspicious_admin_email (High): every administrator (at most 50
//     rows examined) whose e-mail DOMAIN contains the name of a known
//     disposable-mail provider.
//
// Query failures are indistinguishable from empty results (see
// runMySQLQuery), so a failing database simply yields no findings.
func checkWPUsers(user string, creds wpDBCreds, prefix string) []alert.Finding {
	var findings []alert.Finding

	// Administrators created in the last 7 days.
	query := fmt.Sprintf(
		"SELECT u.ID, u.user_login, u.user_email, u.user_registered FROM %susers u "+
			"INNER JOIN %susermeta m ON u.ID = m.user_id "+
			"WHERE m.meta_key = '%scapabilities' AND m.meta_value LIKE '%%administrator%%' "+
			"AND u.user_registered >= DATE_SUB(NOW(), INTERVAL 7 DAY) "+
			"LIMIT 10",
		prefix, prefix, prefix)
	for _, line := range runMySQLQuery(creds, query) {
		// The registration date is optional: a row with only three
		// columns is still reported, with an empty "Registered" value.
		parts := strings.SplitN(line, "\t", 4)
		if len(parts) < 3 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "db_rogue_admin",
			Message:  fmt.Sprintf("New WordPress admin account created in last 7 days: %s (account: %s)", parts[1], user),
			Details: fmt.Sprintf("Database: %s\nUser ID: %s\nLogin: %s\nEmail: %s\nRegistered: %s",
				creds.dbName, parts[0], parts[1], parts[2], safeGet(parts, 3)),
		})
	}

	// Provider names of disposable / temporary mail services. Built once
	// here rather than once per row.
	disposableProviders := []string{
		"tempmail", "guerrillamail", "mailinator", "throwaway",
		"yopmail", "sharklasers", "trashmail", "maildrop",
	}

	// Administrators whose mailbox is hosted by a disposable provider.
	query = fmt.Sprintf(
		"SELECT u.user_login, u.user_email FROM %susers u "+
			"INNER JOIN %susermeta m ON u.ID = m.user_id "+
			"WHERE m.meta_key = '%scapabilities' AND m.meta_value LIKE '%%administrator%%' "+
			"LIMIT 50",
		prefix, prefix, prefix)
	for _, line := range runMySQLQuery(creds, query) {
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		email := strings.ToLower(parts[1])

		// Match the domain only. Matching the whole address flagged
		// legitimate mailboxes whose local part merely contains a
		// provider name (for example "throwaway@company.example"). An
		// address without '@' is matched as a whole, as before.
		domain := email
		if at := strings.LastIndexByte(email, '@'); at >= 0 {
			domain = email[at+1:]
		}

		for _, provider := range disposableProviders {
			if !strings.Contains(domain, provider) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "db_suspicious_admin_email",
				Message:  fmt.Sprintf("WordPress admin '%s' has disposable email (account: %s)", parts[0], user),
				Details:  fmt.Sprintf("Database: %s\nEmail: %s", creds.dbName, email),
			})
			break // one finding per administrator is enough
		}
	}

	return findings
}
// safeGet returns parts[idx], or "" when idx is outside the slice
// (negative, or at/after len(parts)). It lets callers read optional
// trailing columns of a tab-split row without a bounds check of their
// own.
func safeGet(parts []string, idx int) string {
	if idx >= 0 && idx < len(parts) {
		return parts[idx]
	}
	return ""
}
// truncateDB shortens s to at most maxLen bytes for display in finding
// details and appends "..." when anything was cut off. Strings that
// already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: when byte
// maxLen is a continuation byte the cut backs off (by at most three
// bytes, the longest possible run of continuation bytes in valid UTF-8)
// to the start of that rune, so the preview stays valid UTF-8. A
// negative maxLen is treated as 0.
func truncateDB(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}
	// len(s) > maxLen here, so s[cut] is always in range.
	cut := maxLen
	for cut > 0 && maxLen-cut < 3 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// CleanDatabaseSpam removes known malware fragments from the WordPress
// databases of one hosting account and reports what it did.
//
// Every wp-config.php found directly under /home/<account>/<dir>/ is
// processed once. For each site:
//
//   - the fragments "<script>", "eval(", "base64_decode(" and
//     "document.write(" are deleted from <prefix>posts.post_content
//     with an UPDATE ... REPLACE; a db_spam_cleaned finding records the
//     number of rows that matched before the update.
//   - published posts are scanned for cloaked spam keywords with the
//     same denylist and context rules as checkWPPosts; hits are reported
//     as db_spam_found and left for manual review.
//
// Sites whose credentials or table prefix cannot be resolved are
// skipped. runMySQLQuery does not report errors, so a failed UPDATE
// cannot be told apart from a successful one here.
func CleanDatabaseSpam(account string) []alert.Finding {
	var findings []alert.Finding

	// Glob errors are ignored: a failed glob just contributes no paths.
	// NOTE(review): the second pattern looks redundant -- assuming
	// osFS.Glob follows filepath.Glob semantics, "*" already matches
	// "public_html". It is kept in case osFS differs; de-duplication
	// below makes the overlap harmless.
	wpConfigs, _ := osFS.Glob(filepath.Join("/home", account, "*/wp-config.php"))
	wpConfigs2, _ := osFS.Glob(filepath.Join("/home", account, "public_html/wp-config.php"))
	wpConfigs = append(wpConfigs, wpConfigs2...)

	// Fragments deleted from post_content. Loop-invariant, so built once.
	spamPatterns := []struct {
		pattern string
		desc    string
	}{
		{"<script>", "injected script tag"},
		{"eval(", "eval() in post content"},
		{"base64_decode(", "base64_decode in post content"},
		{"document.write(", "document.write injection"},
	}
	postTypeExcl := nonScannablePostTypesSQLList()

	// The two glob patterns overlap, so the same wp-config.php can be
	// listed twice. Processing it twice would duplicate every
	// db_spam_found finding for that site.
	seen := make(map[string]bool, len(wpConfigs))
	for _, wpConfig := range wpConfigs {
		if seen[wpConfig] {
			continue
		}
		seen[wpConfig] = true

		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" {
			continue
		}
		prefix, ok := resolveTablePrefix(creds)
		if !ok {
			continue
		}
		creds.tablePrefix = prefix

		// Clean malware fragments from wp_posts.
		for _, sp := range spamPatterns {
			// Count affected rows first so the finding can report them
			// and so clean databases are not written to.
			countQuery := fmt.Sprintf(
				"SELECT COUNT(*) FROM %sposts WHERE post_content LIKE '%%%s%%'",
				prefix, sp.pattern)
			countLines := runMySQLQuery(creds, countQuery)
			if len(countLines) == 0 || countLines[0] == "0" {
				continue
			}

			// Remove the fragment itself; the rest of the post stays.
			cleanQuery := fmt.Sprintf(
				"UPDATE %sposts SET post_content = REPLACE(post_content, '%s', '') WHERE post_content LIKE '%%%s%%'",
				prefix, sp.pattern, sp.pattern)
			runMySQLQuery(creds, cleanQuery)

			findings = append(findings, alert.Finding{
				Severity:  alert.High,
				Check:     "db_spam_cleaned",
				Message:   fmt.Sprintf("Cleaned %s from %s posts in %s (account: %s)", sp.desc, countLines[0], creds.dbName, account),
				Timestamp: time.Now(),
			})
		}

		// Report (but do not modify) posts carrying cloaked spam
		// keywords, using the same word-boundary regex, post_type
		// denylist and SEO-context requirement as checkWPPosts so an
		// operator-initiated cleanup surfaces the same findings as the
		// periodic scan.
		for _, sp := range dbSpamPatterns {
			query := fmt.Sprintf(
				"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND post_content LIKE '%s' LIMIT 200",
				prefix, postTypeExcl, sp.likeFragment)
			lines := runMySQLQuery(creds, query)
			if len(lines) == 0 {
				continue
			}
			contents := make([]string, 0, len(lines))
			for _, line := range lines {
				parts := strings.SplitN(line, "\t", 2)
				if len(parts) < 2 {
					continue
				}
				contents = append(contents, parts[1])
			}
			n := countCloakedSpamMatches(sp, contents)
			if n == 0 {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "db_spam_found",
				Message:  fmt.Sprintf("Found spam keyword '%s' in %d published posts in %s (account: %s) - manual review recommended", sp.keyword, n, creds.dbName, account),
			})
		}
	}

	return findings
}
package checks
import (
"context"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Drupal database content scanner.
//
// v1 covers Drupal 8 and later (the unified-table layout: config /
// node_field_data / users_field_data). Drupal 7 uses a different
// schema (variable / node / users) and reached EOL in January 2025;
// scanning it lands as a follow-up if any operator reports D7 sites
// still in production.
//
// Discovery: glob /home/*/public_html/sites/default/settings.php
// and confirm core/lib/Drupal.php exists in the site root --
// that file is the canonical D8+ marker (D7 has bootstrap.inc /
// modules/ but no core/ directory).
//
// Credentials: parsed by regex over the canonical $databases
// array literal. Drupal allows both array() and short [] syntax;
// the regex accepts either.
//
// Scanned tables (all unprefixed -- D8+ does not use a per-site
// table prefix in standard installs):
//
// config name + data; data is a
// serialized PHP array carrying
// site name, slogan, theme, etc.
// Common hijack target.
// node_revision__body entity_id + body_value; the
// actual article body. Scanned for
// any pattern in dbMalwarePatterns
// with the same external-script
// post-filter the WP and Joomla
// scanners use.
// users_field_data user table; joined with
// user__roles on entity_id = uid. Rows where
// roles_target_id = 'administrator'
// surface as drupal_admin_injection.
//
// Three new finding categories with CMS-explicit names:
// drupal_settings_injection, drupal_content_injection,
// drupal_admin_injection.
// drupalAdminRoleID is the role identifier that scanDrupalAdmins
// compares against user__roles.roles_target_id. It is the canonical
// identifier for site administrators in vanilla D8+. Operators on
// hardened installs may have renamed the role; v1 matches this exact
// value only, so administrators under another role id are not listed.
const drupalAdminRoleID = "administrator"
// drupalDBNameRe, drupalDBUserRe, drupalDBPassRe and drupalDBHostRe
// pull credentials out of the $databases array literal. Each field is
// matched independently rather than by parsing the array structure --
// attackers occasionally reorder keys, and a key-only regex ignores
// layout differences.
//
// Limits that follow from the patterns themselves:
//   - the key must be single-quoted; the value may use either quote
//     style;
//   - the value must be non-empty and cannot contain a single or a
//     double quote, so a password containing a quote character, or an
//     empty password, is not captured;
//   - nothing ties a match to a particular array or to live code: the
//     patterns match the key wherever it appears in the text handed to
//     them.
var (
drupalDBNameRe = regexp.MustCompile(`'database'\s*=>\s*['"]([^'"]+)['"]`)
drupalDBUserRe = regexp.MustCompile(`'username'\s*=>\s*['"]([^'"]+)['"]`)
drupalDBPassRe = regexp.MustCompile(`'password'\s*=>\s*['"]([^'"]+)['"]`)
drupalDBHostRe = regexp.MustCompile(`'host'\s*=>\s*['"]([^'"]+)['"]`)
)
// drupalCreds carries the connection details parsed from a Drupal
// settings.php. Mirrors the jConfigCreds shape (minus the table prefix)
// so existing helpers (runMySQLQuery, via asWPDBCreds) work uniformly.
type drupalCreds struct {
// dbName is the value of the 'database' key; empty when not found.
dbName string
// dbUser is the value of the 'username' key; empty when not found.
dbUser string
// dbPass is the value of the 'password' key; empty when not found.
dbPass string
// dbHost is the value of the 'host' key; parseDrupalSettings falls
// back to "localhost" when the key is absent.
dbHost string
// path is the settings.php file the credentials were read from.
path string
}
// asWPDBCreds converts the Drupal credentials into the wpDBCreds shape
// that runMySQLQuery accepts. Only the database name, user, password and
// host are carried over; every other wpDBCreds field keeps its zero
// value.
func (c drupalCreds) asWPDBCreds() wpDBCreds {
	var out wpDBCreds
	out.dbName = c.dbName
	out.dbUser = c.dbUser
	out.dbPass = c.dbPass
	out.dbHost = c.dbHost
	return out
}
// CheckDrupalContent finds Drupal 8+ sites under /home/*/public_html and
// runs the config, article-body and administrator scans against each
// site's database. It mirrors CheckJoomlaContent without sharing code:
// the credential layout and table set are distinct enough that a generic
// dispatcher would be more abstraction than a 4-CMS pipeline calls for.
//
// A nil ctx is replaced by context.Background(). Sites are visited most
// recently modified first, so a cancelled ctx cuts off the least
// recently touched sites; findings gathered up to that point are still
// returned. The result is nil when no settings.php exists at all.
func CheckDrupalContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}

	candidates, _ := osFS.Glob("/home/*/public_html/sites/default/settings.php")
	if len(candidates) == 0 {
		return nil
	}

	var out []alert.Finding
	ranked := rankPathsByMtimeDesc(ctx, candidates, effectiveAccountScanMaxFiles(cfg))
	for _, settingsPath := range ranked {
		if ctx.Err() != nil {
			break
		}

		// Walk up from sites/default/settings.php to public_html:
		// three path components.
		siteRoot := settingsPath
		for up := 0; up < 3; up++ {
			siteRoot = filepath.Dir(siteRoot)
		}
		if !looksLikeDrupal8Plus(siteRoot) {
			continue
		}

		// /home/<account> sits one level above public_html.
		account := extractUser(filepath.Dir(siteRoot))
		creds := parseDrupalSettings(settingsPath)
		if creds.dbName != "" && creds.dbUser != "" {
			out = append(out, scanDrupalConfig(account, creds)...)
			out = append(out, scanDrupalContent(account, creds)...)
			out = append(out, scanDrupalAdmins(account, creds)...)
		}
	}
	return out
}
// looksLikeDrupal8Plus reports whether publicHTML contains
// core/lib/Drupal.php, the marker that distinguishes D8+ from D7. The
// file is only stat'ed, never opened, so no content is read just to
// prove existence. Any Stat error counts as "not Drupal 8+".
func looksLikeDrupal8Plus(publicHTML string) bool {
	if _, err := osFS.Stat(filepath.Join(publicHTML, "core", "lib", "Drupal.php")); err != nil {
		return false
	}
	return true
}
// parseDrupalSettings reads the settings.php at path and returns the
// database credentials found in it. The returned value always carries
// path; when the file cannot be read every other field is empty.
//
// Each credential is the FIRST match of its drupalDB*Re pattern in the
// file's live (non-comment) lines. When settings.php defines several
// connections (replicas, per-environment overrides) only the first
// occurrence of each key is reported. dbHost defaults to "localhost"
// when the file was read but carries no 'host' key.
//
// Comment lines are dropped before matching because a stock Drupal
// settings.php documents the $databases array in a docblock, with
// placeholder values such as 'databasename' and 'sqlusername', ahead of
// the real array; matching the raw file returned those placeholders
// instead of the site's credentials. The filter is line-based: a line is
// a comment when it starts with "//", "#", "*" or "/*", or lies inside
// a "/* ... */" block opened at the start of an earlier line. Code that
// shares a line with the start or end of a block comment is dropped with
// it.
func parseDrupalSettings(path string) drupalCreds {
	creds := drupalCreds{path: path}
	// #nosec G304 -- path resolved via osFS.Glob over /home/*/public_html; not attacker-controlled.
	data, err := osFS.ReadFile(path)
	if err != nil {
		return creds
	}

	var live strings.Builder
	inBlockComment := false
	for _, line := range strings.Split(string(data), "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case inBlockComment:
			if strings.Contains(trimmed, "*/") {
				inBlockComment = false
			}
			continue
		case strings.HasPrefix(trimmed, "/*"):
			// Look for the terminator after the opener so that "/*/"
			// is not mistaken for a comment closed on the same line.
			inBlockComment = !strings.Contains(trimmed[2:], "*/")
			continue
		case strings.HasPrefix(trimmed, "//"),
			strings.HasPrefix(trimmed, "#"),
			strings.HasPrefix(trimmed, "*"):
			continue
		}
		live.WriteString(line)
		live.WriteByte('\n')
	}
	body := live.String()

	if m := drupalDBNameRe.FindStringSubmatch(body); m != nil {
		creds.dbName = m[1]
	}
	if m := drupalDBUserRe.FindStringSubmatch(body); m != nil {
		creds.dbUser = m[1]
	}
	if m := drupalDBPassRe.FindStringSubmatch(body); m != nil {
		creds.dbPass = m[1]
	}
	if m := drupalDBHostRe.FindStringSubmatch(body); m != nil {
		creds.dbHost = m[1]
	}
	if creds.dbHost == "" {
		creds.dbHost = "localhost"
	}
	return creds
}
// scanDrupalConfig selects rows of the config table whose data column
// passes the SQL pre-filter built by paramsLikeClause, then keeps only
// the rows that classifyMalwareRow confirms (called with false, the
// strict / config-storage variant). Each confirmed row becomes one
// drupal_settings_injection finding carrying the severity and
// description the classifier returned. Rows for which splitTabRow yields
// no name are ignored.
func scanDrupalConfig(account string, creds drupalCreds) []alert.Finding {
	var hits []alert.Finding

	sql := fmt.Sprintf(
		"SELECT name, data FROM config WHERE %s",
		paramsLikeClause("data"))
	for _, row := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		cfgName, blob := splitTabRow(row)
		if cfgName == "" {
			continue
		}
		severity, what, matched := classifyMalwareRow(blob, false)
		if !matched {
			continue
		}
		hits = append(hits, alert.Finding{
			Severity: severity,
			Check:    "drupal_settings_injection",
			Message:  fmt.Sprintf("Drupal config injection on %s: %s (%s)", account, cfgName, what),
			Details:  fmt.Sprintf("Account: %s\nConfig name: %s\nMatch: %s", account, cfgName, what),
		})
	}
	return hits
}
// scanDrupalContent selects rows of node_revision__body whose body_value
// passes the SQL pre-filter built by paramsLikeClause, then keeps only
// the rows that classifyMalwareRow confirms. The classifier is called
// with true, the looser post-content variant, because article bodies are
// author-written and may carry pre-TLS-era embeds. Each confirmed row
// becomes one drupal_content_injection finding. Rows for which
// splitTabRow yields no entity id are ignored.
func scanDrupalContent(account string, creds drupalCreds) []alert.Finding {
	var hits []alert.Finding

	sql := fmt.Sprintf(
		"SELECT entity_id, body_value FROM node_revision__body WHERE %s",
		paramsLikeClause("body_value"))
	for _, row := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		nodeID, article := splitTabRow(row)
		if nodeID == "" {
			continue
		}
		severity, what, matched := classifyMalwareRow(article, true)
		if !matched {
			continue
		}
		hits = append(hits, alert.Finding{
			Severity: severity,
			Check:    "drupal_content_injection",
			Message:  fmt.Sprintf("Drupal article content injection on %s: node %s (%s)", account, nodeID, what),
			Details:  fmt.Sprintf("Account: %s\nNode entity_id: %s\nMatch: %s", account, nodeID, what),
		})
	}
	return hits
}
// scanDrupalAdmins joins users_field_data with user__roles and emits one
// Warning-severity drupal_admin_injection finding per account holding
// the drupalAdminRoleID role. The legitimate site administrator shows up
// here too, so this is operator-review territory rather than
// auto-actionable.
//
// users_field_data is multilingual in D8+: a single uid can appear once
// per language code on translated sites. The default_langcode = 1 filter
// keeps each admin to exactly one finding regardless of how many
// translations the site has.
//
// The message names the uid (the first column); the whole row -- uid,
// name and mail, tab-separated -- goes into Details. Returns nil when
// the query yields no rows.
func scanDrupalAdmins(account string, creds drupalCreds) []alert.Finding {
	query := fmt.Sprintf(
		"SELECT u.uid, u.name, u.mail FROM users_field_data u JOIN user__roles r ON u.uid = r.entity_id WHERE r.roles_target_id = '%s' AND u.default_langcode = 1",
		drupalAdminRoleID)
	rows := runMySQLQuery(creds.asWPDBCreds(), query)
	if len(rows) == 0 {
		return nil
	}

	findings := make([]alert.Finding, 0, len(rows))
	for _, row := range rows {
		// The uid is everything before the first tab, or the whole row
		// when there is none. The previous strings.Split followed by a
		// len(fields) < 1 guard could never skip a row: Split with a
		// non-empty separator always returns at least one element.
		uid := row
		if tab := strings.IndexByte(row, '\t'); tab >= 0 {
			uid = row[:tab]
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "drupal_admin_injection",
			Message:  fmt.Sprintf("Drupal administrator account on %s: %s", account, uid),
			Details:  fmt.Sprintf("Account: %s\nRow: %s\nReview: confirm this is the legitimate site administrator.", account, row),
		})
	}
	return findings
}
package checks
import (
"regexp"
"strings"
)
// This file contains pure-function helpers used by the database content
// scanner (checkWPPosts in dbscan.go). Keeping them pure (no MySQL, no
// filesystem) makes them deterministically testable and independently
// reusable.
//
// Two classes of false positive were historically observed on real
// production traffic:
//
// 1. db_post_injection fired on every post containing a script tag,
// including site-owner-added analytics and widget embeds (Google Tag
// Manager, Google merchant rating badge, etc.).
//
// 2. db_spam_injection used substring LIKE matching, so "specialist"
// triggered on "cialis", "pharmaceutical" triggered on "pharma",
// "casino resort" triggered on "casino", etc. It also scanned all
// post_types including Contact Form 7 / WPForms / Jetpack stored
// submissions, which routinely contain spambot form fills the site
// owner never displays.
//
// The helpers below encode the decisions needed to eliminate those FPs
// without opening detection holes: word-boundary keyword matching,
// post_type filtering against a denylist (not an allowlist, so attackers
// cannot hide a post by renaming post_type to one we didn't anticipate),
// and safe-domain filtering for external script-tag sources.
// nonScannablePostTypes are post_type values that legitimately store
// non-site-content data (form submissions, revisions, templates, feeds).
// They are excluded from malware and spam scans because their content
// is operator-invisible storage, not material rendered to site visitors.
//
// This is a DENYLIST, not an allowlist. A custom post_type created by a
// theme or plugin (for example WooCommerce `product`, events, portfolios)
// is still scanned. Adding a new value here is safe; an attacker cannot
// hide a post by choosing a new post_type, because we default to
// scanning anything not on this list.
//
// The list is consumed by isScannablePostType (exact, case-sensitive
// comparison) and by nonScannablePostTypesSQLList, which renders it for
// a `post_type NOT IN (...)` clause.
var nonScannablePostTypes = []string{
// WordPress internals / templates / navigation
"revision",
"customize_changeset",
"oembed_cache",
"nav_menu_item",
"wp_template",
"wp_template_part",
"wp_global_styles",
"wp_navigation",
// Minification plugins (store compiled bundles that legitimately
// contain JavaScript and obfuscated character sequences).
"wphb_minify_group",
// Form builders (store plugin configuration and visitor submissions.
// Contact-form spam landing here is noise, not site compromise.)
"wpforms",
"wpforms_entries",
"wpforms-log",
"wpcf7_contact_form",
"flamingo_inbound",
"flamingo_outbound",
"cf7_message",
"feedback",
"jetpack_feedback",
}
// isScannablePostType reports whether posts of the given post_type are
// included in malware and spam scans, i.e. whether postType is absent
// from nonScannablePostTypes. The comparison is exact and
// case-sensitive. The decision mirrors the SQL `post_type NOT IN (...)`
// clause used in checkWPPosts, so Go-side callers and test assertions
// stay consistent with the live SQL.
func isScannablePostType(postType string) bool {
	excluded := false
	for i := range nonScannablePostTypes {
		if nonScannablePostTypes[i] == postType {
			excluded = true
			break
		}
	}
	return !excluded
}
// nonScannablePostTypesSQLList renders nonScannablePostTypes as a
// comma-separated list of single-quoted SQL string literals (for example
// "'revision','wp_template',...") ready to be placed inside a
// `post_type NOT IN (...)` clause. An empty denylist renders as "".
//
// The hardcoded values contain only [a-z_-], so SQL injection is not a
// risk today; backslashes and single quotes are escaped anyway in case a
// future maintainer adds a value containing one.
func nonScannablePostTypesSQLList() string {
	// A single pass that doubles backslashes and backslash-escapes
	// quotes; the backslashes it inserts are not escaped again.
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)

	var list strings.Builder
	for i, postType := range nonScannablePostTypes {
		if i > 0 {
			list.WriteString(",")
		}
		list.WriteString("'")
		list.WriteString(escaper.Replace(postType))
		list.WriteString("'")
	}
	return list.String()
}
// dbSpamPattern is a single spam-keyword detector. The LIKE fragment
// is a MySQL server-side pre-filter that quickly narrows the set of
// candidate posts; the Go regex then applies a strict word-boundary
// test so that "specialist" does not match "cialis", "pharmacy" does
// not match "pharma", and so on.
//
// Patterns that end with a non-word character (dash) already have an
// implicit right boundary from that character and only need a left
// word boundary. Pure-word patterns need boundaries on both sides.
//
// likeFragment is interpolated verbatim between single quotes in the
// scan queries, so it must not contain a single quote.
type dbSpamPattern struct {
keyword string // human-readable keyword used in finding messages
regex *regexp.Regexp // applied Go-side to candidate rows
likeFragment string // SQL LIKE fragment, always bracketed with '%'
}
// dbSpamPatterns enumerates the keywords we flag as SEO/pharma/gambling
// spam in WordPress post content. Each entry pairs a fast SQL LIKE with
// a strict Go-side word-boundary regex.
//
// The regexes are case-insensitive to catch CIALIS / Cialis / cialis.
// The LIKE fragments are lowercase because MySQL LIKE is case-insensitive
// under the default _ci collation used by cPanel MariaDB.
var dbSpamPatterns = []dbSpamPattern{
{"viagra", regexp.MustCompile(`(?i)\bviagra\b`), "%viagra%"},
{"cialis", regexp.MustCompile(`(?i)\bcialis\b`), "%cialis%"},
{"pharma", regexp.MustCompile(`(?i)\bpharma\b`), "%pharma%"},
{"betting", regexp.MustCompile(`(?i)\bbetting\b`), "%betting%"},
// Dashed variants. For "casino-" and "buy-cheap-" the trailing dash is
// itself a non-word char and serves as the right boundary, so only a
// left word-boundary is needed. "free-download" and "crack-serial" end
// in a word character and carry no right boundary, so they also match
// longer words that start with the keyword (e.g. "free-downloads").
{"casino-", regexp.MustCompile(`(?i)\bcasino-`), "%casino-%"},
{"buy-cheap-", regexp.MustCompile(`(?i)\bbuy-cheap-`), "%buy-cheap-%"},
{"free-download", regexp.MustCompile(`(?i)\bfree-download`), "%free-download%"},
{"crack-serial", regexp.MustCompile(`(?i)\bcrack-serial`), "%crack-serial%"},
}
// countSpamMatches counts how many entries of contents match
// pattern.regex. Callers are expected to pass only rows already narrowed
// by the pattern.likeFragment SQL pre-filter; this function applies the
// strict word-boundary test on top of that. A nil or empty contents
// yields 0.
func countSpamMatches(pattern dbSpamPattern, contents []string) int {
	matched := 0
	for i := range contents {
		if !pattern.regex.MatchString(contents[i]) {
			continue
		}
		matched++
	}
	return matched
}
// hasMaliciousExternalScript reports whether extractMaliciousScriptURL
// finds an offending script URL in content, i.e. a script tag whose src
// points at a domain that is NOT on the known-safe list (see
// knownSafeDomains in db_autoresponse.go).
//
// This is the STRICT predicate (classifier isAttackerScriptURL): raw-IP
// hosts, abused TLDs, known-bad exfil hosts, empty hosts AND plaintext
// HTTP external scripts all count. It suits wp_options and similar
// configuration storage, where fresh content is expected; post bodies
// use the looser hasMaliciousExternalScriptInPost instead.
//
// Inline script blocks without a src attribute are not classified here;
// the code-pattern entries in dbMalwarePatterns cover common inline
// obfuscation techniques.
//
// Rationale: a bare script-tag match was the main source of false
// positives on real traffic. Legitimate analytics embeds (Google Tag
// Manager, Google Analytics, Mailchimp, HubSpot, ...) install both an
// external loader tag and an inline initialization block, and flagging
// the inline block alone produced many HIGH severity noise findings.
// Requiring an external src on a non-safe domain removes those while
// still catching tags that point at an untrusted domain.
func hasMaliciousExternalScript(content string) bool {
	offendingURL := extractMaliciousScriptURL(content)
	return len(offendingURL) > 0
}
// hasMaliciousExternalScriptInPost is the post_content counterpart of
// hasMaliciousExternalScript. It collects every scriptSrcRe match in
// content and returns true as soon as one captured src URL is classified
// as hostile by isAttackerScriptURLInPost, the classifier variant that
// drops the plaintext-HTTP indicator. Matches without a capture group
// are ignored.
//
// Rationale: post_content carries author-written text that can contain
// pre-TLS-era embeds (e.g. a 2013-era video embed). Those use http://
// and would be flagged on scheme alone by the strict classifier even
// though the post has not changed in a decade. Attackers rarely host on
// plaintext-HTTP mainstream-TLD URLs -- they use raw IPs, abused TLDs or
// cheap exfil hosts -- so dropping the HTTP signal here removes the
// legacy-embed false positives without giving up meaningful detection.
func hasMaliciousExternalScriptInPost(content string) bool {
	for _, groups := range scriptSrcRe.FindAllStringSubmatch(content, -1) {
		if len(groups) >= 2 && isAttackerScriptURLInPost(groups[1]) {
			return true
		}
	}
	return false
}
package checks
import (
"context"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Joomla database content scanner.
//
// Discovery: glob /home/*/public_html/configuration.php and
// verify the file contains `class JConfig` -- the canonical marker
// for a Joomla site, distinguishing it from PHP files that happen
// to share the configuration.php filename. Credentials are read
// via regex over public-property assignments (`public $host = ...;`)
// rather than PHP eval; the parser ignores anything outside the
// JConfig class body.
//
// Scanned tables (all prefixed; the prefix is operator-controlled
// via configuration.php's $dbprefix and defaults to `jos_` /
// `<random>_` on fresh installs):
//
// <prefix>extensions params blob -- live_site, sitename,
// offline_message are common hijack targets
// <prefix>content article body for malware patterns
// <prefix>users user table; joined with
// <prefix>user_usergroup_map to find rogue Super Users
// (group_id = 8 in vanilla Joomla)
//
// Three new finding categories. CMS-explicit names so operators
// running mixed-CMS hosts can suppress per-CMS:
//
// joomla_extensions_injection (Critical) -- malware pattern in
// an extension's params
// joomla_content_injection (Critical) -- malware pattern in
// an article body
// joomla_admin_injection (Critical) -- rogue Super User
// account
// jConfigCredsPattern parses a `public $foo = 'value';` assignment
// (single OR double quotes). Capture group 1 is the property name,
// group 2 the value, which may be empty but cannot contain a quote
// character of either kind.
//
// The pattern is anchored with `^` and does not enable multi-line mode,
// so it only matches at the very start of the string it is given;
// arbitrary text inside string literals further along can't be misread
// as a credential. NOTE(review): this presumes callers apply it to one
// line at a time -- confirm against parseJConfig.
var jConfigCredsPattern = regexp.MustCompile(`^\s*public\s+\$(\w+)\s*=\s*['"]([^'"]*)['"]\s*;`)
// joomlaSuperUserGroupID is the canonical group id for "Super Users"
// in vanilla Joomla 3+. Operators on hardened installs may have
// renumbered the group; the spec narrows to 8 for v1, so Super Users
// under a different group id are out of scope.
const joomlaSuperUserGroupID = 8
// jConfigCreds carries the credentials extracted from a Joomla
// configuration.php. Mirrors the wpDBCreds shape so the existing
// runMySQLQuery / mysqlSchemaLiteral helpers can be reused (via
// asWPDBCreds), but kept distinct because Joomla configuration.php and
// WordPress wp-config.php are not interchangeable.
type jConfigCreds struct {
// dbName, dbUser, dbPass and dbHost are the MySQL connection details.
dbName string
dbUser string
dbPass string
dbHost string
// dbPrefix is the table prefix; CheckJoomlaContent falls back to
// "jos_" when it is empty.
dbPrefix string
// path presumably records the configuration.php the values came
// from -- the code that fills it is not shown here; confirm against
// parseJConfig.
path string
}
// asWPDBCreds converts the Joomla credentials into the wpDBCreds shape
// so runMySQLQuery can be reused. The database name, user, password and
// host are copied as-is and dbPrefix becomes tablePrefix; every other
// wpDBCreds field (such as `multisite`, which the mysql client wrapper
// does not look at) keeps its zero value.
func (c jConfigCreds) asWPDBCreds() wpDBCreds {
	var out wpDBCreds
	out.dbName = c.dbName
	out.dbUser = c.dbUser
	out.dbPass = c.dbPass
	out.dbHost = c.dbHost
	out.tablePrefix = c.dbPrefix
	return out
}
// CheckJoomlaContent finds every Joomla installation under
// /home/*/public_html and runs the extensions, article and Super User
// scans against its database. It mirrors the structure of
// CheckDatabaseContent without sharing code: the credentials and table
// layout differ enough that a generic dispatcher is more abstraction
// than this point in the codebase needs.
//
// A nil ctx is replaced by context.Background(). Installs are visited
// most recently modified first, so a cancelled ctx cuts off the least
// recently touched installs; findings gathered up to that point are
// still returned. Files that are not a JConfig class, or that yield no
// database name or user, are skipped. An empty table prefix falls back
// to "jos_". The result is nil when no configuration.php exists at all.
func CheckJoomlaContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}

	candidates, _ := osFS.Glob("/home/*/public_html/configuration.php")
	if len(candidates) == 0 {
		return nil
	}

	var out []alert.Finding
	ranked := rankPathsByMtimeDesc(ctx, candidates, effectiveAccountScanMaxFiles(cfg))
	for _, cfgPath := range ranked {
		if ctx.Err() != nil {
			break
		}
		if !looksLikeJoomlaConfig(cfgPath) {
			continue
		}

		account := extractUser(filepath.Dir(cfgPath))
		creds := parseJConfig(cfgPath)
		if creds.dbName == "" || creds.dbUser == "" {
			continue
		}

		tablePrefix := "jos_"
		if creds.dbPrefix != "" {
			tablePrefix = creds.dbPrefix
		}

		out = append(out, scanJoomlaExtensions(account, creds, tablePrefix)...)
		out = append(out, scanJoomlaContent(account, creds, tablePrefix)...)
		out = append(out, scanJoomlaSuperUsers(account, creds, tablePrefix)...)
	}
	return out
}
// looksLikeJoomlaConfig reports whether the file at path contains the
// text "class jconfig", compared case-insensitively over the whole
// file. An unreadable file counts as "not Joomla".
//
// The whole file is read into memory with no explicit size limit
// (vanilla Joomla configuration.php is ~3 KB).
//
// The marker check is intentionally loose: a defaced
// configuration.php that still carries the class declaration but has
// injected malicious public properties is exactly what we WANT to
// scan.
func looksLikeJoomlaConfig(path string) bool {
	// #nosec G304 -- path resolved via osFS.Glob over /home/*/public_html; not attacker-controlled.
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return false
	}
	return strings.Contains(strings.ToLower(string(raw)), "class jconfig")
}
// parseJConfig reads configuration.php and collects the database
// settings from its `public $name = 'value';` lines. Property names
// are matched case-insensitively against host, user, password, db and
// dbprefix; every other line (comments, unrelated properties, code
// outside the class body) is ignored because jConfigCredsPattern only
// matches the assignment form.
//
// When the same property appears more than once the last occurrence
// wins. If the file cannot be read, the result carries only path. On
// a successful read an empty host defaults to "localhost".
func parseJConfig(path string) jConfigCreds {
	out := jConfigCreds{path: path}
	// #nosec G304 -- same Glob-resolved path as above.
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return out
	}

	// Lower-cased property name -> destination field.
	targets := map[string]*string{
		"host":     &out.dbHost,
		"user":     &out.dbUser,
		"password": &out.dbPass,
		"db":       &out.dbName,
		"dbprefix": &out.dbPrefix,
	}
	for _, ln := range strings.Split(string(raw), "\n") {
		groups := jConfigCredsPattern.FindStringSubmatch(ln)
		if groups == nil {
			continue
		}
		if dst, known := targets[strings.ToLower(groups[1])]; known {
			*dst = groups[2]
		}
	}

	if out.dbHost == "" {
		out.dbHost = "localhost"
	}
	return out
}
// scanJoomlaExtensions looks for malware patterns in the params
// column of the <prefix>extensions table and returns one finding per
// offending extension.
//
// Classification happens in two phases:
//
//  1. The SQL WHERE clause (paramsLikeClause) pre-filters rows with
//     LIKE so only candidate params blobs are pulled into the daemon
//     -- a vanilla table has ~50 rows, a plugin-heavy install can
//     exceed 200.
//
//  2. classifyMalwareRow re-checks every pattern against the full
//     blob in Go. It is called with inPostContext=false, i.e. the
//     strict external-script predicate, because params is config
//     storage; this drops rows whose only LIKE hit was <script>
//     noise such as a legitimate analytics embed.
//
// Rows with an empty extension name are skipped.
func scanJoomlaExtensions(account string, creds jConfigCreds, prefix string) []alert.Finding {
	sql := fmt.Sprintf(
		"SELECT name, params FROM %sextensions WHERE %s",
		prefix, paramsLikeClause("params"))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		extName, params := splitTabRow(line)
		if extName == "" {
			continue
		}
		severity, what, hit := classifyMalwareRow(params, false)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "joomla_extensions_injection",
				Message:  fmt.Sprintf("Joomla extension params injection on %s: %s (%s)", account, extName, what),
				Details:  fmt.Sprintf("Account: %s\nExtension: %s\nMatch: %s", account, extName, what),
			})
		}
	}
	return out
}
// scanJoomlaContent looks for malware patterns in the introtext
// column of the <prefix>content table (article bodies) and returns
// one finding per offending article.
//
// It uses the same two-phase approach as scanJoomlaExtensions (SQL
// LIKE pre-filter, then classifyMalwareRow in Go) but passes
// inPostContext=true to select the looser external-script predicate:
// articles are author-written and may carry pre-TLS-era embeds the
// strict predicate would flag on scheme alone.
//
// Only introtext is queried in v1. The full-text half of the article
// is left out because it is rarely populated on modern installs and
// scanning it would double the query cost for marginal coverage;
// revisit if operators see missed detections.
//
// Rows that do not split into three tab-separated fields (id, title,
// body) are skipped.
func scanJoomlaContent(account string, creds jConfigCreds, prefix string) []alert.Finding {
	sql := fmt.Sprintf(
		"SELECT id, title, introtext FROM %scontent WHERE %s",
		prefix, paramsLikeClause("introtext"))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		cols := strings.SplitN(line, "\t", 3)
		if len(cols) < 3 {
			continue
		}
		articleID, articleTitle := cols[0], cols[1]
		severity, what, hit := classifyMalwareRow(cols[2], true)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "joomla_content_injection",
				Message:  fmt.Sprintf("Joomla article content injection on %s: id=%s title=%q (%s)", account, articleID, articleTitle, what),
				Details:  fmt.Sprintf("Account: %s\nArticle ID: %s\nTitle: %s\nMatch: %s", account, articleID, articleTitle, what),
			})
		}
	}
	return out
}
// classifyMalwareRow tests body against every entry of
// dbMalwarePatterns (case-insensitive substring match) and returns
// the severity and description of the highest-severity pattern that
// survives filtering; on a severity tie the earlier pattern in the
// list is kept. The third result is false when body is empty or
// nothing genuine matched, in which case the caller skips the row.
//
// Patterns flagged requiresExternalScript only count when the
// external-script predicate also accepts the body. inPostContext
// chooses the predicate: false uses the strict
// hasMaliciousExternalScript (config-storage rows such as Joomla
// extension params or Drupal config), true uses the looser
// hasMaliciousExternalScriptInPost (author-written article content).
// This mirrors how the WP scanner picks its predicate per table.
//
// Shared by the Joomla and Drupal scanners; the function lives in
// dbscan_joomla.go for historical reasons (added with the Joomla
// scanner; renamed when Drupal also needed it).
func classifyMalwareRow(body string, inPostContext bool) (alert.Severity, string, bool) {
	var (
		topSev  alert.Severity
		topDesc string
		found   bool
	)
	if body == "" {
		return topSev, topDesc, found
	}

	// externalScriptOK evaluates the predicate selected by inPostContext.
	externalScriptOK := func() bool {
		verdict := hasMaliciousExternalScript(body)
		if inPostContext {
			verdict = hasMaliciousExternalScriptInPost(body)
		}
		return verdict
	}

	haystack := strings.ToLower(body)
	for _, pat := range dbMalwarePatterns {
		switch {
		case !strings.Contains(haystack, strings.ToLower(pat.pattern)):
			continue
		case pat.requiresExternalScript && !externalScriptOK():
			continue
		}
		if !found || pat.severity > topSev {
			topSev, topDesc = pat.severity, pat.desc
		}
		found = true
	}
	return topSev, topDesc, found
}
// scanJoomlaSuperUsers lists the accounts that belong to the Super
// Users group (group_id = joomlaSuperUserGroupID, 8 by default). The
// two-table join is necessary because Joomla stores group membership
// separately from the user row.
//
// This is operator-review territory: the legitimate site admin shows
// up here too, so one Warning is emitted per row for operators to
// confirm. A separate Critical detector for accounts created in the
// last hour is a follow-up; v1 emits visibility only.
//
// Rows whose first column (the user id) is empty are skipped, the
// same way the sibling scanners skip rows without an identifier, so a
// blank output line cannot turn into a finding. Returns nil when
// there is nothing to report.
func scanJoomlaSuperUsers(account string, creds jConfigCreds, prefix string) []alert.Finding {
	query := fmt.Sprintf(
		"SELECT u.id, u.username, u.email FROM %susers u JOIN %suser_usergroup_map m ON u.id = m.user_id WHERE m.group_id = %d",
		prefix, prefix, joomlaSuperUserGroupID)

	var findings []alert.Finding
	for _, row := range runMySQLQuery(creds.asWPDBCreds(), query) {
		// strings.Split never returns an empty slice for a non-empty
		// separator, so a length check cannot detect a blank row; test
		// the id column itself.
		id, _, _ := strings.Cut(row, "\t")
		if id == "" {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "joomla_admin_injection",
			Message:  fmt.Sprintf("Joomla Super User account on %s: %s", account, id),
			Details:  fmt.Sprintf("Account: %s\nRow: %s\nReview: confirm this is the legitimate site administrator.", account, row),
		})
	}
	return findings
}
// paramsLikeClause builds the WHERE-clause body used by the CMS
// scanners: for every supplied column, one `col LIKE '%pattern%'`
// term per entry of dbMalwarePatterns, all joined with " OR ".
// Patterns pass through mysqlEscapeForLike before being embedded;
// column names are inserted verbatim, so callers must pass trusted
// identifiers. With no columns the result is "1=0" (matches nothing).
//
// We use plain LIKE rather than full-text search so the query runs
// against any MySQL configuration -- some shared-hosting instances
// have ngram FTS off, and we don't want to depend on it for a
// security scan.
func paramsLikeClause(columns ...string) string {
	if len(columns) == 0 {
		return "1=0"
	}

	// Escape each pattern once; the literals are reused for every column.
	literals := make([]string, 0, len(dbMalwarePatterns))
	for _, pat := range dbMalwarePatterns {
		literals = append(literals, mysqlEscapeForLike(pat.pattern))
	}

	terms := make([]string, 0, len(columns)*len(literals))
	for _, column := range columns {
		for _, literal := range literals {
			terms = append(terms, fmt.Sprintf("%s LIKE '%%%s%%'", column, literal))
		}
	}
	return strings.Join(terms, " OR ")
}
// mysqlEscapeForLike escapes s so it can be embedded between the
// wildcards of a single-quoted MySQL LIKE pattern ('%...%') and match
// only as a literal substring.
//
// Two layers of interpretation apply to such a pattern: the string
// literal parser, then LIKE itself (whose default escape character is
// the backslash). The replacements below account for both:
//
//	\  ->  \\\\   parser yields \\, LIKE then matches one literal backslash
//	'  ->  \'     parser yields a literal quote
//	%  ->  \%     LIKE matches a literal percent sign, not "any run"
//	_  ->  \_     LIKE matches a literal underscore, not "any one char"
//
// Input without any of these four characters is returned unchanged.
// The replacement is a single pass, so backslashes added for one
// character are never re-escaped.
func mysqlEscapeForLike(s string) string {
	return strings.NewReplacer(
		`\`, `\\\\`,
		`'`, `\'`,
		`%`, `\%`,
		`_`, `\_`,
	).Replace(s)
}
package checks
import (
"context"
"encoding/xml"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Magento database content scanner.
//
// Single file covering both major versions: M1 (Magento 1.x, EOL'd
// June 2020 but still found on legacy hosts) and M2 (Magento 2.x).
// The two versions share table names and content shape but disagree
// on configuration file format -- M1 stores credentials in
// app/etc/local.xml as XML, M2 stores them in app/etc/env.php as
// PHP arrays.
//
// Discovery prefers M2: app/etc/env.php is probed first because M2 is
// the actively maintained version. app/etc/local.xml (M1) is only
// consulted for accounts whose env.php did not yield a database name,
// so a half-migrated host that carries both files is scanned once.
//
// Scanned tables (identical between M1 and M2):
//
// core_config_data (path, value) -- settings.
// web/unsecure/base_url is the
// canonical hijack target;
// attackers redirect the storefront
// by overwriting it.
// catalog_product_entity_text product description text;
// spam-injection vector for SEO
// cms_block + cms_page CMS-managed content; same
// pattern as the product text scan
// admin_user backend administrator accounts.
// One Warning per row -- legitimate
// admin shows up too.
//
// Three new finding categories with CMS-explicit names:
// magento_settings_injection, magento_content_injection,
// magento_admin_injection.
// magentoCreds holds the connection details parsed from a Magento
// configuration file. It mirrors jConfigCreds / drupalCreds, with an
// extra version field recording which discovery path produced the
// values (used in finding messages).
type magentoCreds struct {
	// Connection settings.
	dbName, dbUser, dbPass, dbHost string

	// dbPrefix is the table prefix; it may be empty.
	dbPrefix string

	// version is "M1" (local.xml) or "M2" (env.php).
	version string

	// path is the configuration file the values were read from.
	path string
}
// asWPDBCreds converts the Magento credentials into a wpDBCreds value
// for use with runMySQLQuery. The version and path fields have no
// counterpart and are dropped.
func (c magentoCreds) asWPDBCreds() wpDBCreds {
	var out wpDBCreds
	out.dbName, out.dbUser, out.dbPass = c.dbName, c.dbUser, c.dbPass
	out.dbHost = c.dbHost
	out.tablePrefix = c.dbPrefix
	return out
}
// Decoding targets for a Magento 1.x app/etc/local.xml. Only the
// elements needed for a database connection are declared; everything
// else in the document is ignored by encoding/xml. CDATA wrapping is
// transparent to the decoder, so the bare and CDATA-wrapped forms of a
// value produce the same string.
type (
	// magentoM1XMLRoot maps the <config> document element.
	magentoM1XMLRoot struct {
		XMLName    xml.Name              `xml:"config"`
		Connection magentoM1XMLConnBlock `xml:"global>resources>default_setup>connection"`
		Resources  magentoM1XMLResources `xml:"global>resources>db"`
	}

	// magentoM1XMLConnBlock maps the default_setup <connection> element.
	magentoM1XMLConnBlock struct {
		Host     string `xml:"host"`
		Username string `xml:"username"`
		Password string `xml:"password"`
		DBName   string `xml:"dbname"`
	}

	// magentoM1XMLResources maps <resources><db>, which carries the
	// table prefix.
	magentoM1XMLResources struct {
		TablePrefix string `xml:"table_prefix"`
	}
)
// M2 env.php is a PHP file returning an array; we extract by regex
// rather than wiring a PHP parser. The patterns match the canonical
// nested-array layout that vendor/magento installers produce; hand-
// rolled env.php files with reordered keys still parse because
// each pattern matches independently.
//
// Each pattern captures the quoted value that follows `'key' =>` in
// group 1. The patterns are not scoped to the db/connection section:
// parseMagentoM2 takes the first occurrence of each key anywhere in
// the file. Values containing a quote character are not captured
// correctly. The host, username, password and dbname values must be
// non-empty to match; table_prefix may be empty.
var (
	magentoM2HostRe   = regexp.MustCompile(`['"]host['"]\s*=>\s*['"]([^'"]+)['"]`)
	magentoM2UserRe   = regexp.MustCompile(`['"]username['"]\s*=>\s*['"]([^'"]+)['"]`)
	magentoM2PassRe   = regexp.MustCompile(`['"]password['"]\s*=>\s*['"]([^'"]+)['"]`)
	magentoM2DBRe     = regexp.MustCompile(`['"]dbname['"]\s*=>\s*['"]([^'"]+)['"]`)
	magentoM2PrefixRe = regexp.MustCompile(`['"]table_prefix['"]\s*=>\s*['"]([^'"]*)['"]`)
)
// CheckMagentoContent discovers Magento installs (M2 via
// app/etc/env.php first, then M1 via app/etc/local.xml) under
// /home/*/public_html and runs scanMagentoAll on each one that yields
// a database name. Mirrors CheckJoomlaContent and CheckDrupalContent
// without sharing code -- the version branching is local to this
// scanner.
//
// Accounts whose env.php produced credentials are remembered so the
// M1 pass does not scan them again -- including the common case where
// the M2 scan found nothing (a clean install). Without this, a
// half-migrated host with both env.php and a stale local.xml would be
// scanned twice with different credential sets.
//
// A nil ctx is treated as context.Background(). Once ctx is done the
// findings collected so far are returned.
func CheckMagentoContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var (
		out       []alert.Finding
		viaM2Seen = map[string]bool{}
	)

	// sweep processes every config file matching pattern, most
	// recently modified first so fresh installs win when the check
	// timeout cuts iteration short. It returns false when ctx ended
	// mid-iteration.
	sweep := func(pattern string, parse func(string) magentoCreds, isM2 bool) bool {
		matches, _ := osFS.Glob(pattern)
		for _, cfgFile := range rankPathsByMtimeDesc(ctx, matches, effectiveAccountScanMaxFiles(cfg)) {
			if ctx.Err() != nil {
				return false
			}
			acct := magentoAccountFromPath(cfgFile)
			if !isM2 && viaM2Seen[acct] {
				continue
			}
			creds := parse(cfgFile)
			if creds.dbName == "" {
				continue
			}
			if isM2 {
				viaM2Seen[acct] = true
			}
			out = append(out, scanMagentoAll(acct, creds)...)
		}
		return true
	}

	// M2 discovery first (active version).
	if !sweep("/home/*/public_html/app/etc/env.php", parseMagentoM2, true) {
		return out
	}
	// M1 fallback for hosts where env.php is absent or unparseable.
	sweep("/home/*/public_html/app/etc/local.xml", parseMagentoM1, false)
	return out
}
// magentoAccountFromPath derives the account name from a Magento
// config path of the conventional cPanel form
// /home/<account>/public_html/app/etc/<file>: it walks four directory
// levels up (etc, app, public_html, then the file's own level) and
// hands the resulting /home/<account> to extractUser.
//
// NOTE(review): the Joomla and OpenCart scanners pass the public_html
// directory to extractUser, one level deeper than here -- confirm
// extractUser treats both forms alike.
func magentoAccountFromPath(path string) string {
	etcDir := filepath.Dir(path)
	appDir := filepath.Dir(etcDir)
	docRoot := filepath.Dir(appDir)
	return extractUser(filepath.Dir(docRoot))
}
// parseMagentoM1 reads a Magento 1.x local.xml and extracts the
// default_setup connection block plus the table prefix. All values
// are whitespace-trimmed.
//
// On a read or XML-decoding error the result carries only path and
// version ("M1"); a malformed file silently skips the install rather
// than crashing the deep tier. After a successful decode an empty
// host defaults to "localhost".
func parseMagentoM1(path string) magentoCreds {
	out := magentoCreds{path: path, version: "M1"}
	// #nosec G304 -- path resolved via osFS.Glob over /home/*/public_html/app/etc/; not attacker-controlled.
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return out
	}
	var doc magentoM1XMLRoot
	if xml.Unmarshal(raw, &doc) != nil {
		return out
	}

	conn := doc.Connection
	out.dbHost = strings.TrimSpace(conn.Host)
	out.dbUser = strings.TrimSpace(conn.Username)
	out.dbPass = strings.TrimSpace(conn.Password)
	out.dbName = strings.TrimSpace(conn.DBName)
	out.dbPrefix = strings.TrimSpace(doc.Resources.TablePrefix)

	if out.dbHost == "" {
		out.dbHost = "localhost"
	}
	return out
}
// parseMagentoM2 reads a Magento 2.x env.php and pulls the connection
// settings out with the field-level magentoM2*Re patterns. Each key
// is matched independently (first occurrence in the file), which
// keeps the parser robust against operator-edited files that reorder
// the array.
//
// On a read error the result carries only path and version ("M2").
// After a successful read an empty host defaults to "localhost".
func parseMagentoM2(path string) magentoCreds {
	out := magentoCreds{path: path, version: "M2"}
	// #nosec G304 -- same Glob-resolved path as parseMagentoM1.
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return out
	}
	text := string(raw)

	// store copies capture group 1 into dst when the pattern matched.
	store := func(groups []string, dst *string) {
		if groups != nil {
			*dst = groups[1]
		}
	}
	store(magentoM2HostRe.FindStringSubmatch(text), &out.dbHost)
	store(magentoM2UserRe.FindStringSubmatch(text), &out.dbUser)
	store(magentoM2PassRe.FindStringSubmatch(text), &out.dbPass)
	store(magentoM2DBRe.FindStringSubmatch(text), &out.dbName)
	store(magentoM2PrefixRe.FindStringSubmatch(text), &out.dbPrefix)

	if out.dbHost == "" {
		out.dbHost = "localhost"
	}
	return out
}
// scanMagentoAll runs every Magento scan against one install:
// settings (core_config_data), the three content tables
// (catalog_product_entity_text, cms_block, cms_page) and the admin
// listing. The helper exists so M1 and M2 dispatch through the same
// post-creds code path.
//
// The table prefix is read from a config file inside the account's
// own public_html and is spliced unquoted into every query, so an
// install whose prefix contains anything other than plain identifier
// characters is not scanned at all (nil is returned).
func scanMagentoAll(account string, creds magentoCreds) []alert.Finding {
	if !magentoTablePrefixSafe(creds.dbPrefix) {
		return nil
	}
	var findings []alert.Finding
	findings = append(findings, scanMagentoSettings(account, creds)...)
	findings = append(findings, scanMagentoContent(account, creds, "catalog_product_entity_text", "value")...)
	findings = append(findings, scanMagentoContent(account, creds, "cms_block", "content")...)
	findings = append(findings, scanMagentoContent(account, creds, "cms_page", "content")...)
	findings = append(findings, scanMagentoAdmins(account, creds)...)
	return findings
}

// magentoTablePrefixSafe reports whether prefix consists solely of
// ASCII letters, digits, '_' and '$' -- the characters MySQL permits
// in an unquoted identifier. The empty string is considered safe.
func magentoTablePrefixSafe(prefix string) bool {
	for i := 0; i < len(prefix); i++ {
		c := prefix[i]
		switch {
		case c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9',
			c == '_', c == '$':
			// allowed
		default:
			return false
		}
	}
	return true
}
// scanMagentoSettings looks for malware patterns in the value column
// of <prefix>core_config_data. The path column carries
// dotted-namespace identifiers (web/unsecure/base_url,
// design/header/welcome, etc.) and is kept in the finding for triage.
//
// Candidates are pre-filtered in SQL (paramsLikeClause) and
// re-checked by classifyMalwareRow with inPostContext=false, the
// strict predicate. Rows with an empty path are skipped.
func scanMagentoSettings(account string, creds magentoCreds) []alert.Finding {
	sql := fmt.Sprintf(
		"SELECT path, value FROM %score_config_data WHERE %s",
		creds.dbPrefix, paramsLikeClause("value"))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		settingPath, settingValue := splitTabRow(line)
		if settingPath == "" {
			continue
		}
		severity, what, hit := classifyMalwareRow(settingValue, false)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "magento_settings_injection",
				Message:  fmt.Sprintf("Magento %s settings injection on %s: %s (%s)", creds.version, account, settingPath, what),
				Details:  fmt.Sprintf("Account: %s\nConfig path: %s\nMatch: %s", account, settingPath, what),
			})
		}
	}
	return out
}
// scanMagentoContent scans one content table for malware patterns in
// valueCol and returns one finding per offending row. The id column
// is chosen from the table name: entity_id for
// catalog_product_entity_text, block_id for cms_block, page_id for
// cms_page, and row_id for any other table.
//
// classifyMalwareRow is called with inPostContext=true (the looser
// hasMaliciousExternalScriptInPost predicate) because these tables
// carry author-written content. Rows with an empty id are skipped.
//
// table and valueCol are inserted into the SQL text verbatim; callers
// must pass trusted identifiers.
func scanMagentoContent(account string, creds magentoCreds, table, valueCol string) []alert.Finding {
	idCol, known := map[string]string{
		"catalog_product_entity_text": "entity_id",
		"cms_block":                   "block_id",
		"cms_page":                    "page_id",
	}[table]
	if !known {
		idCol = "row_id"
	}

	sql := fmt.Sprintf(
		"SELECT %s, %s FROM %s%s WHERE %s",
		idCol, valueCol, creds.dbPrefix, table, paramsLikeClause(valueCol))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		rowID, content := splitTabRow(line)
		if rowID == "" {
			continue
		}
		severity, what, hit := classifyMalwareRow(content, true)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "magento_content_injection",
				Message:  fmt.Sprintf("Magento %s content injection on %s: %s id=%s (%s)", creds.version, account, table, rowID, what),
				Details:  fmt.Sprintf("Account: %s\nTable: %s\nRow id: %s\nMatch: %s", account, table, rowID, what),
			})
		}
	}
	return out
}
// scanMagentoAdmins enumerates the <prefix>admin_user table and emits
// one Warning per row. The rows include the legitimate site admin, so
// this is operator-review territory rather than a detection.
//
// Rows whose first column (user_id) is empty are skipped, the same
// way the sibling scanners skip rows without an identifier, so a
// blank output line cannot turn into a finding. Returns nil when
// there is nothing to report.
func scanMagentoAdmins(account string, creds magentoCreds) []alert.Finding {
	query := fmt.Sprintf(
		"SELECT user_id, username, email FROM %sadmin_user",
		creds.dbPrefix)

	var findings []alert.Finding
	for _, row := range runMySQLQuery(creds.asWPDBCreds(), query) {
		// strings.Split never returns an empty slice for a non-empty
		// separator, so a length check cannot detect a blank row; test
		// the id column itself.
		userID, _, _ := strings.Cut(row, "\t")
		if userID == "" {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "magento_admin_injection",
			Message:  fmt.Sprintf("Magento %s admin account on %s: user_id=%s", creds.version, account, userID),
			Details:  fmt.Sprintf("Account: %s\nRow: %s\nReview: confirm this is the legitimate site administrator.", account, row),
		})
	}
	return findings
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// OpenCart database content scanner.
//
// Discovery: glob /home/*/public_html/config.php and confirm both
// it AND /home/*/public_html/admin/config.php contain
// `define('DB_DRIVER'`. The admin-side config.php pair is the
// canonical OpenCart marker -- plain PHP sites carry a config.php
// in the document root that's nothing to do with OpenCart.
//
// Credentials use PHP define() constants with OC-specific names:
//
// DB_HOSTNAME DB_USERNAME DB_PASSWORD DB_DATABASE DB_PREFIX
//
// Reuses the existing extractDefine helper from dbscan.go (the WP
// scanner already understands this shape). DB_PREFIX defaults to
// "oc_" on vanilla installs.
//
// Scanned tables (all prefixed):
//
// <prefix>setting k/v pairs; values are JSON
// blobs. config_url / config_ssl
// are the canonical hijack
// targets for storefront redirect.
// <prefix>product_description product description text
// <prefix>information_description CMS-managed information pages
// <prefix>user admin/staff accounts.
// Customer accounts live in the
// oc_customer table, not here;
// every oc_user row is admin-shaped.
//
// Three new finding categories:
// opencart_settings_injection, opencart_content_injection,
// opencart_admin_injection.
// opencartCreds holds the database settings parsed from an OpenCart
// config.php (the DB_* defines).
type opencartCreds struct {
	// Connection settings.
	dbName, dbUser, dbPass, dbHost string

	// dbPrefix is the table prefix (DB_PREFIX); it may be empty until
	// the caller applies the "oc_" default.
	dbPrefix string

	// path is the config.php the values were read from.
	path string
}
// asWPDBCreds converts the OpenCart credentials into a wpDBCreds
// value for use with runMySQLQuery. The path field has no counterpart
// and is dropped.
func (c opencartCreds) asWPDBCreds() wpDBCreds {
	var out wpDBCreds
	out.dbName, out.dbUser, out.dbPass = c.dbName, c.dbUser, c.dbPass
	out.dbHost = c.dbHost
	out.tablePrefix = c.dbPrefix
	return out
}
// CheckOpenCartContent discovers OpenCart installs at
// /home/*/public_html/config.php and scans the four canonical
// attacker-touched tables (setting, product_description,
// information_description, user). Mirrors the other CMS scanners; the
// discovery and credentials parsing are the only OC-specific bits.
//
// A nil ctx is treated as context.Background(). Once ctx is done the
// findings collected so far are returned. Returns nil when no
// config.php exists.
//
// An install is skipped when looksLikeOpenCart rejects it, when no
// database name could be parsed, or when the table prefix contains
// characters that are not valid in an unquoted MySQL identifier (see
// opencartTablePrefixSafe). An empty prefix defaults to "oc_".
func CheckOpenCartContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var findings []alert.Finding
	// The Glob error is discarded: an empty result simply ends the check.
	configs, _ := osFS.Glob("/home/*/public_html/config.php")
	if len(configs) == 0 {
		return nil
	}
	// Rank by mtime desc so recently touched OpenCart installs are processed
	// first when the check timeout cuts iteration short.
	for _, path := range rankPathsByMtimeDesc(ctx, configs, effectiveAccountScanMaxFiles(cfg)) {
		if ctx.Err() != nil {
			return findings
		}
		if !looksLikeOpenCart(path) {
			continue
		}
		account := extractUser(filepath.Dir(path))
		creds := parseOpenCartConfig(path)
		if creds.dbName == "" {
			continue
		}
		prefix := creds.dbPrefix
		if prefix == "" {
			prefix = "oc_"
		}
		// The prefix is read from a file inside the account's own
		// public_html and is spliced unquoted into the SQL text of every
		// scan; refuse anything that is not a plain identifier fragment.
		if !opencartTablePrefixSafe(prefix) {
			continue
		}
		creds.dbPrefix = prefix
		findings = append(findings, scanOpenCartSettings(account, creds)...)
		findings = append(findings, scanOpenCartContentTable(account, creds, "product_description", "description")...)
		findings = append(findings, scanOpenCartContentTable(account, creds, "information_description", "description")...)
		findings = append(findings, scanOpenCartAdmins(account, creds)...)
	}
	return findings
}

// opencartTablePrefixSafe reports whether prefix consists solely of
// ASCII letters, digits, '_' and '$' -- the characters MySQL permits
// in an unquoted identifier. The empty string is considered safe.
func opencartTablePrefixSafe(prefix string) bool {
	for i := 0; i < len(prefix); i++ {
		c := prefix[i]
		switch {
		case c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9',
			c == '_', c == '$':
			// allowed
		default:
			return false
		}
	}
	return true
}
// looksLikeOpenCart reports whether rootConfig and the sibling
// admin/config.php are both readable and both mention DB_DRIVER. The
// admin-side file is what distinguishes OpenCart from arbitrary PHP
// sites that happen to ship a config.php at the document root; it is
// only consulted when the root file already passed.
func looksLikeOpenCart(rootConfig string) bool {
	adminConfig := filepath.Join(filepath.Dir(rootConfig), "admin", "config.php")
	return configContainsDBDriver(rootConfig) && configContainsDBDriver(adminConfig)
}
// configContainsDBDriver reports whether the file at path can be read
// and contains the case-sensitive substring "DB_DRIVER" anywhere
// (the check is a plain substring test, not a parse of a define()).
func configContainsDBDriver(path string) bool {
	// #nosec G304 -- path resolved via osFS.Glob over /home/*/public_html or its admin/ subdir; not attacker-controlled.
	raw, readErr := osFS.ReadFile(path)
	return readErr == nil && strings.Contains(string(raw), "DB_DRIVER")
}
// parseOpenCartConfig extracts the DB_HOSTNAME, DB_USERNAME,
// DB_PASSWORD, DB_DATABASE and DB_PREFIX defines from a config.php,
// line by line, using the WP scanner's extractDefine helper (the OC
// defines have the same `define('KEY', 'value')` shape WP uses).
//
// A define whose extracted value is empty never overwrites an earlier
// non-empty one; otherwise the last occurrence wins. On a read error
// the result carries only path. After a successful read an empty host
// defaults to "localhost".
func parseOpenCartConfig(path string) opencartCreds {
	out := opencartCreds{path: path}
	// #nosec G304 -- same Glob-resolved path.
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return out
	}

	// Define name -> destination field, probed in this order per line.
	wanted := []struct {
		key string
		dst *string
	}{
		{"DB_HOSTNAME", &out.dbHost},
		{"DB_USERNAME", &out.dbUser},
		{"DB_PASSWORD", &out.dbPass},
		{"DB_DATABASE", &out.dbName},
		{"DB_PREFIX", &out.dbPrefix},
	}
	for _, ln := range strings.Split(string(raw), "\n") {
		for _, w := range wanted {
			if val := extractDefine(ln, w.key); val != "" {
				*w.dst = val
			}
		}
	}

	if out.dbHost == "" {
		out.dbHost = "localhost"
	}
	return out
}
// scanOpenCartSettings looks for malware patterns in the value column
// of the <prefix>setting key/value table. The value column is a
// JSON-serialized blob; candidates are pre-filtered in SQL
// (paramsLikeClause) and re-checked by classifyMalwareRow with
// inPostContext=false -- the strict variant, because this is config
// storage, not author-written content. Rows with an empty key are
// skipped.
func scanOpenCartSettings(account string, creds opencartCreds) []alert.Finding {
	sql := fmt.Sprintf(
		"SELECT `key`, value FROM %ssetting WHERE %s",
		creds.dbPrefix, paramsLikeClause("value"))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		settingKey, settingValue := splitTabRow(line)
		if settingKey == "" {
			continue
		}
		severity, what, hit := classifyMalwareRow(settingValue, false)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "opencart_settings_injection",
				Message:  fmt.Sprintf("OpenCart settings injection on %s: %s (%s)", account, settingKey, what),
				Details:  fmt.Sprintf("Account: %s\nSetting key: %s\nMatch: %s", account, settingKey, what),
			})
		}
	}
	return out
}
// scanOpenCartContentTable scans one of the description tables
// (product_description, information_description) for malware patterns
// in valueCol. The id column is derived from the table name:
// information_id for information_description, product_id for
// anything else.
//
// classifyMalwareRow is called with inPostContext=true (the looser
// hasMaliciousExternalScriptInPost predicate) because these tables
// carry author-written content.
//
// Both tables carry one row per language per product or page, so an
// unfiltered query on a multilingual storefront would emit one
// finding per installed language for the same injected item. The
// query is therefore limited to language_id = 1 (English / the
// vanilla default); non-English monolingual sites and genuine
// multilingual coverage need a follow-up that reads
// config_language_id from oc_setting first.
//
// Rows with an empty id are skipped. table and valueCol are inserted
// into the SQL text verbatim; callers must pass trusted identifiers.
func scanOpenCartContentTable(account string, creds opencartCreds, table, valueCol string) []alert.Finding {
	var idCol string
	switch table {
	case "information_description":
		idCol = "information_id"
	default:
		idCol = "product_id"
	}

	sql := fmt.Sprintf(
		"SELECT %s, %s FROM %s%s WHERE language_id = 1 AND %s",
		idCol, valueCol, creds.dbPrefix, table, paramsLikeClause(valueCol))

	var out []alert.Finding
	for _, line := range runMySQLQuery(creds.asWPDBCreds(), sql) {
		rowID, content := splitTabRow(line)
		if rowID == "" {
			continue
		}
		severity, what, hit := classifyMalwareRow(content, true)
		if hit {
			out = append(out, alert.Finding{
				Severity: severity,
				Check:    "opencart_content_injection",
				Message:  fmt.Sprintf("OpenCart content injection on %s: %s id=%s (%s)", account, table, rowID, what),
				Details:  fmt.Sprintf("Account: %s\nTable: %s\nRow id: %s\nMatch: %s", account, table, rowID, what),
			})
		}
	}
	return out
}
// scanOpenCartAdmins enumerates the <prefix>user table (admins/staff,
// not customers -- customers live in oc_customer) and emits one
// Warning per row, like the other CMS adapters. The rows include the
// legitimate administrator, so this is operator-review territory.
//
// Rows whose first column (user_id) is empty are skipped, the same
// way the sibling scanners skip rows without an identifier, so a
// blank output line cannot turn into a finding. Returns nil when
// there is nothing to report.
func scanOpenCartAdmins(account string, creds opencartCreds) []alert.Finding {
	query := fmt.Sprintf(
		"SELECT user_id, username, email FROM %suser",
		creds.dbPrefix)

	var findings []alert.Finding
	for _, row := range runMySQLQuery(creds.asWPDBCreds(), query) {
		// strings.Split never returns an empty slice for a non-empty
		// separator, so a length check cannot detect a blank row; test
		// the id column itself.
		userID, _, _ := strings.Cut(row, "\t")
		if userID == "" {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "opencart_admin_injection",
			Message:  fmt.Sprintf("OpenCart admin account on %s: user_id=%s", account, userID),
			Details:  fmt.Sprintf("Account: %s\nRow: %s\nReview: confirm this is the legitimate site administrator.", account, row),
		})
	}
	return findings
}
package checks
import (
"fmt"
"net"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/processctx"
)
// DirectSMTPEgressInput is the input to the pure evaluator. The caller
// (BPF connection consumer or legacy poller) builds it from the live
// event and passes the platform-resolved MTA allowlist as MTA.
//
// Process is optional; when present the resulting finding includes the
// full process-ancestry tree. UID/User/PID/Comm/Exe are the live event
// fields used in finding details and account attribution.
type DirectSMTPEgressInput struct {
	// UID is the numeric user id from the event. Events with UID 0 are
	// never reported by EvaluateDirectSMTPEgress.
	UID uint32
	// User is the user name for UID. It is checked against the MTA
	// allowlist and is the fallback tenant when Process carries no
	// account.
	User string
	// PID and Comm identify the process in the finding details.
	PID  uint32
	Comm string
	// Exe is carried with the event; EvaluateDirectSMTPEgress does not
	// read it.
	Exe string
	// DstIP and DstPort are the connection's destination. Nil,
	// loopback and unspecified addresses are never reported.
	DstIP   net.IP
	DstPort uint16
	// MTA is the platform-resolved MTA allowlist.
	MTA platform.MTAIdents
	// Process is optional process context; it is attached to the
	// finding as-is and its Account, when set, becomes the tenant.
	Process *processctx.ProcessContext
	// Domain is an optional rDNS-resolved name for DstIP. When set, it
	// is included in the finding details. Populating it is the caller's
	// responsibility (off-path enrichment lands in Task 6).
	Domain string
}
// EvaluateDirectSMTPEgress decides whether in describes a non-MTA
// local process opening an outbound SMTP connection and, if so,
// returns a populated High-severity finding and true. It is a pure
// function: no IO, no clock.
//
// It returns (zero, false) -- checked in this order -- when:
//   - cfg is nil, the detector is disabled, or its backend is "none"
//     (the input is not inspected in that case);
//   - the event's UID is 0;
//   - the destination IP is nil, loopback or unspecified;
//   - the destination port is not in the configured port list;
//   - the destination is one of cfg.InfraIPs;
//   - the user is on the MTA allowlist.
//
// Destinations for which To4() returns nil (IPv6) are written in
// brackets. Domain, when set, is appended to the details.
func EvaluateDirectSMTPEgress(cfg *config.Config, in DirectSMTPEgressInput) (alert.Finding, bool) {
	switch {
	case cfg == nil || !cfg.Detection.DirectSMTPEgress.Enabled || directSMTPEgressBackend(cfg) == "none",
		in.UID == 0,
		in.DstIP == nil || in.DstIP.IsLoopback() || in.DstIP.IsUnspecified(),
		!portInList(in.DstPort, cfg.Detection.DirectSMTPEgress.Ports),
		isInfraIP(in.DstIP.String(), cfg.InfraIPs),
		in.MTA.IsMTAUser(in.User):
		return alert.Finding{}, false
	}

	target := in.DstIP.String()
	if in.DstIP.To4() == nil {
		target = "[" + target + "]"
	}

	finding := alert.Finding{
		Severity: alert.High,
		Check:    "direct_smtp_egress",
		Message:  fmt.Sprintf("Non-MTA process opened outbound SMTP connection to %s:%d", target, in.DstPort),
		TenantID: directSMTPTenant(in),
		Process:  in.Process,
	}
	finding.Details = fmt.Sprintf("UID: %d (%s), Process: %s, PID: %d, Destination: %s:%d",
		in.UID, in.User, in.Comm, in.PID, target, in.DstPort)
	if in.Domain != "" {
		finding.Details += ", Domain: " + in.Domain
	}
	return finding, true
}
// DirectSMTPEgressBackendEnabled reports whether the named backend
// should run the direct-SMTP-egress detector under cfg. It is false
// when cfg is nil or the detector is disabled. With the configured
// backend "auto" (also the default for an empty setting) every
// backend is enabled; with "bpf" or "legacy" only the backend of that
// exact name is; any other configured value enables none.
func DirectSMTPEgressBackendEnabled(cfg *config.Config, backend string) bool {
	if cfg == nil || !cfg.Detection.DirectSMTPEgress.Enabled {
		return false
	}
	selected := directSMTPEgressBackend(cfg)
	if selected == "auto" {
		return true
	}
	return (selected == "bpf" || selected == "legacy") && backend == selected
}
// directSMTPEgressBackend returns the configured backend name,
// trimmed of surrounding whitespace and lower-cased, or "auto" when
// the setting is empty. cfg must be non-nil.
func directSMTPEgressBackend(cfg *config.Config) string {
	if name := strings.ToLower(strings.TrimSpace(cfg.Detection.DirectSMTPEgress.Backend)); name != "" {
		return name
	}
	return "auto"
}
// directSMTPTenant picks the tenant identifier for a finding: the process
// context's Account when one is attached and non-empty, otherwise the
// connection's user name.
func directSMTPTenant(in DirectSMTPEgressInput) string {
	if in.Process == nil || in.Process.Account == "" {
		return in.User
	}
	return in.Process.Account
}
// portInList reports whether port p appears in list. Only list entries in the
// range 1..65535 can match, so port 0 is never considered present.
func portInList(p uint16, list []int) bool {
	if p == 0 {
		// Entries <= 0 are ignored, so nothing can equal port 0.
		return false
	}
	want := int(p)
	for i := range list {
		// want is within 1..65535, so equality implies the entry is valid.
		if list[i] == want {
			return true
		}
	}
	return false
}
package checks
import (
"sync/atomic"
"github.com/pidginhost/csm/internal/metrics"
)
// directSMTPEgressFindingsTotal is the running count of direct SMTP egress
// findings. It is incremented by BumpDirectSMTPEgressFindings and read by the
// counter function installed in RegisterDirectSMTPEgressMetrics.
var directSMTPEgressFindingsTotal atomic.Uint64
// RegisterDirectSMTPEgressMetrics exposes the direct-SMTP-egress finding
// counter on reg as "csm_direct_smtp_egress_findings_total". The registered
// function reads the package-level atomic counter at collection time.
// Production callers should pass metrics.Default(); tests pass
// metrics.NewRegistry() to keep registration isolated.
func RegisterDirectSMTPEgressMetrics(reg *metrics.Registry) {
	const (
		name = "csm_direct_smtp_egress_findings_total"
		help = "Direct SMTP egress findings emitted by the connection consumer."
	)
	read := func() float64 {
		return float64(directSMTPEgressFindingsTotal.Load())
	}
	reg.RegisterCounterFunc(name, help, read)
}
// BumpDirectSMTPEgressFindings adds one to the direct-SMTP-egress finding
// counter. The connection consumer calls it each time
// EvaluateDirectSMTPEgress returns a finding.
func BumpDirectSMTPEgressFindings() {
	const step = 1
	directSMTPEgressFindingsTotal.Add(step)
}
// resetDirectSMTPEgressMetricsForTest zeroes the finding counter so tests can
// start from a known value. Test seam only.
func resetDirectSMTPEgressMetricsForTest() {
	const zero = 0
	directSMTPEgressFindingsTotal.Store(zero)
}
package checks
import (
"bufio"
"context"
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// dnsServerUsers lists the account names of DNS server daemons that
// legitimately connect to many resolvers (e.g. BIND doing recursive
// resolution on a cPanel server). resolveDNSServerUIDs matches these names
// against the first field of each /etc/passwd line.
var dnsServerUsers = map[string]bool{
	"named":   true, // BIND
	"unbound": true, // Unbound
	"pdns":    true, // PowerDNS
}
// CheckDNSConnections reports established TCP connections to remote port 53
// whose peer is not one of the nameservers in /etc/resolv.conf. This catches
// DNS tunneling, GSocket relay discovery, and malware using hardcoded
// resolvers.
//
// A connection is ignored when its remote IP is a configured resolver,
// 127.0.0.1, 0.0.0.0, an infrastructure IP from cfg.InfraIPs, or when the
// owning UID belongs to a known DNS server account (e.g. named). The check
// returns nil when no resolvers are configured or /proc/net/tcp cannot be
// read. Only /proc/net/tcp is parsed; ctx is not consulted.
func CheckDNSConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	resolvers := parseResolvers()
	if len(resolvers) == 0 {
		return nil
	}

	// Peers that are always acceptable: localhost, the wildcard address and
	// every configured resolver.
	allowed := map[string]bool{
		"127.0.0.1": true,
		"0.0.0.0":   true,
	}
	for i := range resolvers {
		allowed[resolvers[i]] = true
	}

	// Resolve DNS daemon UIDs once up front rather than reading /etc/passwd
	// for every connection row.
	dnsServerUIDs := resolveDNSServerUIDs()

	raw, err := osFS.ReadFile("/proc/net/tcp")
	if err != nil {
		return nil
	}

	resolverList := strings.Join(resolvers, ", ")
	var findings []alert.Finding
	for _, row := range strings.Split(string(raw), "\n") {
		cols := strings.Fields(row)
		if len(cols) < 8 || cols[0] == "sl" {
			continue // short row or the header line
		}
		if cols[3] != "01" {
			continue // state 01 = ESTABLISHED
		}

		remoteIP, remotePort := parseHexAddr(cols[2])
		if remotePort != 53 ||
			allowed[remoteIP] ||
			isInfraIP(remoteIP, cfg.InfraIPs) ||
			dnsServerUIDs[cols[7]] { // column 7 is the owning UID
			continue
		}

		_, localPort := parseHexAddr(cols[1])
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "dns_connection",
			Message:  fmt.Sprintf("DNS connection to non-configured resolver: %s", remoteIP),
			Details:  fmt.Sprintf("Local port: %d, Remote: %s:53\nConfigured resolvers: %s", localPort, remoteIP, resolverList),
		})
	}
	return findings
}
// resolveDNSServerUIDs reads /etc/passwd once and returns the set of UID
// strings whose account name is listed in dnsServerUsers. The result is an
// empty (non-nil) map when the file cannot be read.
func resolveDNSServerUIDs() map[string]bool {
	uids := map[string]bool{}
	passwd, err := osFS.ReadFile("/etc/passwd")
	if err != nil {
		return uids
	}
	for _, entry := range strings.Split(string(passwd), "\n") {
		cols := strings.Split(entry, ":")
		if len(cols) < 3 {
			continue
		}
		// passwd layout: name:password:uid:...
		if name, uid := cols[0], cols[2]; dnsServerUsers[name] {
			uids[uid] = true
		}
	}
	return uids
}
// parseResolvers returns the nameserver addresses listed in /etc/resolv.conf,
// in file order. It returns nil when the file cannot be opened.
//
// Each line is split on arbitrary whitespace, so "nameserver" may be
// separated from its address by spaces or tabs, and anything after the
// address on the same line is ignored rather than glued onto it. A read
// error part-way through the file yields the entries parsed so far.
func parseResolvers() []string {
	f, err := osFS.Open("/etc/resolv.conf")
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	var resolvers []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		// Need the keyword plus at least one value token.
		if len(fields) >= 2 && fields[0] == "nameserver" {
			resolvers = append(resolvers, fields[1])
		}
	}
	return resolvers
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// maxBulkDNSChanges is the threshold above which simultaneous zone changes
// are treated as a cPanel bulk operation (AutoSSL, serial bump, DNSSEC
// rotation) and suppressed. CheckDNSZoneChanges reports only when the number
// of changed zones in one pass is at most this value (1-5 zones), which is
// more likely to be a targeted modification.
const maxBulkDNSChanges = 5
// CheckDNSZoneChanges watches the *.db zone files in /var/named for content
// changes between runs.
//
// Every readable zone file is hashed and its hash stored under
// "_dns_zone:<name>"; a zone counts as changed when a previous hash exists
// and differs. A zone seen for the first time only establishes a baseline.
// If more than maxBulkDNSChanges zones changed in one pass, the change is
// assumed to be cPanel maintenance and nothing is reported (the new hashes
// are still stored). Otherwise one High finding is emitted per changed zone.
// Returns nil when the directory cannot be read.
func CheckDNSZoneChanges(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	// cPanel keeps its zone files here.
	const zoneDir = "/var/named"

	entries, err := osFS.ReadDir(zoneDir)
	if err != nil {
		return nil
	}

	// Pass 1: refresh stored hashes and collect the zones that differ.
	var changed []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if !strings.HasSuffix(name, ".db") {
			continue
		}
		digest, err := hashFileContent(filepath.Join(zoneDir, name))
		if err != nil {
			continue
		}
		stateKey := "_dns_zone:" + name
		previous, seen := store.GetRaw(stateKey)
		store.SetRaw(stateKey, digest)
		if !seen || previous == digest {
			continue
		}
		changed = append(changed, name)
	}

	// Many zones at once looks like cPanel maintenance: stay quiet.
	if len(changed) > maxBulkDNSChanges {
		return nil
	}

	// Pass 2: report the targeted changes.
	var findings []alert.Finding
	for _, name := range changed {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "dns_zone_change",
			Message:  fmt.Sprintf("DNS zone file modified: %s", name),
			Details:  fmt.Sprintf("File: %s\nThis could indicate DNS hijacking or unauthorized domain changes", filepath.Join(zoneDir, name)),
		})
	}
	return findings
}
// CheckSSLCertIssuance watches the AutoSSL log directory for new activity.
// Attackers may issue certificates for phishing domains using compromised
// accounts.
//
// The number of entries in /var/cpanel/logs/autossl is stored under
// "_ssl_autossl_count" on every run. When the count has grown since the
// previous run, the most recently modified non-directory entry is tailed
// (last 50 lines) and a Warning finding is emitted for every line containing
// "installed" or "issued" (case-insensitive). The first run only records a
// baseline. Returns nil when the directory cannot be read.
func CheckSSLCertIssuance(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	const (
		logDir   = "/var/cpanel/logs/autossl"
		stateKey = "_ssl_autossl_count"
	)

	entries, err := osFS.ReadDir(logDir)
	if err != nil {
		return nil
	}

	// The entry count is used as a cheap change indicator.
	total := len(entries)
	stored, seen := store.GetRaw(stateKey)
	store.SetRaw(stateKey, fmt.Sprintf("%d", total))
	if !seen {
		return nil
	}

	// An unparsable stored value leaves previous at 0.
	var previous int
	_, _ = fmt.Sscanf(stored, "%d", &previous)
	if total <= previous {
		return nil
	}

	// Locate the newest log file by modification time.
	newest, newestAt := "", int64(0)
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		info, err := entry.Info()
		if err != nil {
			continue
		}
		if ts := info.ModTime().Unix(); ts > newestAt {
			newestAt = ts
			newest = filepath.Join(logDir, entry.Name())
		}
	}
	if newest == "" {
		return nil
	}

	var findings []alert.Finding
	for _, line := range tailFile(newest, 50) {
		lowered := strings.ToLower(line)
		if !strings.Contains(lowered, "installed") && !strings.Contains(lowered, "issued") {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "ssl_cert_issued",
			Message:  "New SSL certificate issued via AutoSSL",
			Details:  truncate(line, 300),
		})
	}
	return findings
}
package checks
import (
"bufio"
"context"
"crypto/sha1" // #nosec G505 -- SHA1 is the digest format required by the Have I Been Pwned range API (https://haveibeenpwned.com/API/v3#PwnedPasswords). We send the first 5 chars of the digest and compare remaining chars against the returned list — HIBP does not offer a stronger-hash endpoint.
"crypto/sha256"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"unicode"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// doveadmSemaphore limits concurrent doveadm processes. It is a counting
// semaphore with three slots: verifyDoveadmContext sends to acquire a slot
// and receives to release it.
var doveadmSemaphore = make(chan struct{}, 3)

// hibpClient is used for HIBP API requests. The 10-second timeout bounds
// each request.
var hibpClient = &http.Client{Timeout: 10 * time.Second}

// hibpEndpoint is the base URL queried for HIBP password-range lookups; the
// 5-character hash prefix is appended directly to it.
// Declared as a var (not const) so tests can swap in an httptest server.
// Production callers must not modify this.
var hibpEndpoint = "https://api.pwnedpasswords.com/range/"
// currentYear reports the calendar year according to the local clock. It is
// evaluated on every call so a long-running daemon never works with a stale
// year after Jan 1.
func currentYear() int {
	now := time.Now()
	return now.Year()
}
// Cache for the bundled weak-password wordlist. loadWeakPasswords fills
// weakPasswords inside weakPasswordOnce, so the file is read at most once.
var (
	weakPasswordOnce sync.Once
	weakPasswords    []string
)
// parseShadowLine splits a Dovecot shadow line of the form
// "mailbox:{scheme}hash" at its first colon. Both parts must be non-empty;
// otherwise ("", "") is returned. Any further colons stay in the hash part.
func parseShadowLine(line string) (mailbox, hash string) {
	name, digest, found := strings.Cut(line, ":")
	if !found || name == "" || digest == "" {
		return "", ""
	}
	return name, digest
}
// isLockedHash reports whether a stored hash marks a locked or disabled
// account: an empty hash, or one beginning with '!' or '*'.
func isLockedHash(hash string) bool {
	if len(hash) == 0 {
		return true
	}
	switch hash[0] {
	case '!', '*':
		return true
	default:
		return false
	}
}
// generateCandidates builds the list of predictable passwords derived from a
// mailbox username and its domain.
//
// Two bases are used: the username and the first label of the domain (the
// text before the first dot, or the whole domain when no dot follows a
// non-empty label). For each base, and for the base with its first letter
// upper-cased, the candidates are: the bare word, the word followed by each
// year from currentYear()-2 through currentYear()+2, and the word followed by
// every two-digit suffix 00..99. Candidates shorter than 6 bytes and
// duplicates are dropped; order of first appearance is preserved.
func generateCandidates(username, domain string) []string {
	var out []string
	seen := map[string]struct{}{}
	push := func(values ...string) {
		for _, v := range values {
			if len(v) < 6 {
				continue
			}
			if _, dup := seen[v]; dup {
				continue
			}
			seen[v] = struct{}{}
			out = append(out, v)
		}
	}

	label := domain
	if dot := strings.IndexByte(domain, '.'); dot > 0 {
		label = domain[:dot]
	}

	for _, base := range [...]string{username, label} {
		capped := capitalizeFirst(base)
		push(base, capped)

		// Year variants: current year +/- 2.
		thisYear := currentYear()
		for offset := -2; offset <= 2; offset++ {
			ys := strconv.Itoa(thisYear + offset)
			push(base+ys, capped+ys)
		}

		// Two-digit suffix variants: 00-99.
		for n := 0; n < 100; n++ {
			nn := fmt.Sprintf("%02d", n)
			push(base+nn, capped+nn)
		}
	}
	return out
}
// capitalizeFirst upper-cases the first rune of s and leaves the rest as is.
// The string is round-tripped through []rune, so the result is always valid
// UTF-8. An empty string is returned unchanged.
func capitalizeFirst(s string) string {
	if s == "" {
		return s
	}
	letters := []rune(s)
	head := unicode.ToUpper(letters[0])
	letters[0] = head
	return string(letters)
}
// verifyDoveadm reports whether candidate matches the stored hash. It is
// verifyDoveadmContext with a background (never-cancelled) context.
func verifyDoveadm(hash, candidate string) bool {
	background := context.Background()
	return verifyDoveadmContext(background, hash, candidate)
}
// verifyDoveadmContext reports whether candidate matches the stored hash by
// running "doveadm pw -t <hash> -p <candidate>"; a nil error from the command
// means a match.
//
// A slot in doveadmSemaphore is held for the duration of the command. The
// function returns false without running anything if ctx is cancelled while
// waiting for a slot or immediately after acquiring one. A nil ctx is treated
// as context.Background().
func verifyDoveadmContext(ctx context.Context, hash, candidate string) bool {
	if ctx == nil {
		ctx = context.Background()
	}

	select {
	case <-ctx.Done():
		return false
	case doveadmSemaphore <- struct{}{}:
		// Slot acquired; hand it back when this function returns.
		defer func() { <-doveadmSemaphore }()
	}

	if ctx.Err() != nil {
		return false
	}

	// Routed through cmdExec so tests can mock doveadm without a real install.
	if _, err := cmdExec.RunContext(ctx, "doveadm", "pw", "-t", hash, "-p", candidate); err != nil {
		return false
	}
	return true
}
// hashFingerprint returns the lowercase hex SHA-256 digest of a stored
// password hash. It is used for change detection, so a mailbox is re-audited
// only when its hash changes.
func hashFingerprint(hash string) string {
	hasher := sha256.New()
	_, _ = hasher.Write([]byte(hash)) // hash.Hash.Write never returns an error
	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// parseHIBPCount scans a HIBP range response ("SUFFIX:COUNT" per line) for
// the given hash suffix and returns its breach count. The suffix comparison
// ignores case and surrounding whitespace. It returns 0 when the suffix is
// absent or when the first matching line has a non-numeric count.
func parseHIBPCount(body, suffix string) int {
	want := strings.ToUpper(suffix)
	for _, raw := range strings.Split(body, "\n") {
		key, val, ok := strings.Cut(strings.TrimSpace(raw), ":")
		if !ok || strings.ToUpper(strings.TrimSpace(key)) != want {
			continue
		}
		n, err := strconv.Atoi(strings.TrimSpace(val))
		if err != nil {
			return 0
		}
		return n
	}
	return 0
}
// checkHIBP looks up a plaintext password in the HIBP Pwned Passwords range
// API and returns its breach count (0 if not found or on any error). It is
// checkHIBPWithContext with a background context.
func checkHIBP(plaintext string) int {
	background := context.Background()
	return checkHIBPWithContext(background, plaintext)
}
// checkHIBPWithContext looks up a plaintext password via the HIBP range API.
//
// The password is SHA-1 hashed and upper-case hex encoded. Only the first
// five hex characters are sent (appended to hibpEndpoint); the remainder is
// matched locally against the response body. Returns the breach count, or 0
// when ctx is already cancelled, the request fails, the status is not 200,
// the body cannot be read, or the suffix is not listed. A nil ctx is treated
// as context.Background().
func checkHIBPWithContext(ctx context.Context, plaintext string) int {
	if ctx == nil {
		ctx = context.Background()
	}
	if ctx.Err() != nil {
		return 0
	}

	// #nosec G401 -- SHA1 is mandated by the HIBP Pwned Passwords range API; see import comment.
	sum := sha1.Sum([]byte(plaintext))
	digest := fmt.Sprintf("%X", sum[:])
	rangeKey, tail := digest[:5], digest[5:]

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, hibpEndpoint+rangeKey, nil)
	if err != nil {
		return 0
	}
	resp, err := hibpClient.Do(req)
	if err != nil {
		return 0
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return 0
	}
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0
	}
	return parseHIBPCount(string(payload), tail)
}
// shadowFile identifies one Dovecot shadow file found by discoverShadowFiles.
type shadowFile struct {
	path    string // full path, as matched by /home/*/etc/*/shadow
	account string // third path component: the directory under /home
	domain  string // fifth path component: the directory under etc
}
// mailboxEntry is one usable (non-locked) line of a shadow file, together
// with the account and domain of the file it came from.
type mailboxEntry struct {
	account string // copied from the owning shadowFile
	domain  string // copied from the owning shadowFile
	mailbox string // text before the first ':' on the shadow line
	hash    string // text after the first ':' on the shadow line
}
// discoverShadowFiles globs /home/*/etc/*/shadow and returns one shadowFile
// per match, with account and domain taken from the path components.
//
// Matches are passed through rankPathsByMtimeDesc so recently touched
// password files are inspected first when the check timeout cuts work short.
// maxFiles is forwarded to that helper as the cap; 0 disables the cap.
// Glob errors are ignored and yield an empty result.
func discoverShadowFiles(ctx context.Context, maxFiles int) []shadowFile {
	paths, _ := osFS.Glob("/home/*/etc/*/shadow")
	ordered := rankPathsByMtimeDesc(ctx, paths, maxFiles)

	found := make([]shadowFile, 0, len(ordered))
	for _, p := range ordered {
		// "/home/{account}/etc/{domain}/shadow" splits into
		// ["", "home", account, "etc", domain, "shadow"].
		segs := strings.Split(p, "/")
		if len(segs) < 5 {
			continue
		}
		found = append(found, shadowFile{path: p, account: segs[2], domain: segs[4]})
	}
	return found
}
// readShadowFile parses one Dovecot shadow file into mailbox entries.
// Blank lines, '#' comment lines, malformed lines and locked/disabled hashes
// are skipped. Returns nil when the file cannot be opened.
func readShadowFile(sf shadowFile) []mailboxEntry {
	fh, err := osFS.Open(sf.path)
	if err != nil {
		return nil
	}
	defer fh.Close()

	var out []mailboxEntry
	for sc := bufio.NewScanner(fh); sc.Scan(); {
		text := strings.TrimSpace(sc.Text())
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		box, pw := parseShadowLine(text)
		if box == "" || pw == "" || isLockedHash(pw) {
			continue
		}
		out = append(out, mailboxEntry{
			account: sf.account,
			domain:  sf.domain,
			mailbox: box,
			hash:    pw,
		})
	}
	return out
}
// loadWeakPasswords returns the bundled weak-password wordlist, reading
// /opt/csm/configs/weak_passwords.txt on the first call only. Entries shorter
// than 6 bytes and lines starting with '#' are dropped. If the file cannot be
// opened on that first call the list stays empty for the process lifetime.
func loadWeakPasswords() []string {
	weakPasswordOnce.Do(func() {
		fh, err := osFS.Open("/opt/csm/configs/weak_passwords.txt")
		if err != nil {
			return
		}
		defer fh.Close()

		for sc := bufio.NewScanner(fh); sc.Scan(); {
			entry := strings.TrimSpace(sc.Text())
			if strings.HasPrefix(entry, "#") || len(entry) < 6 {
				continue
			}
			weakPasswords = append(weakPasswords, entry)
		}
	})
	return weakPasswords
}
// checkWordlist tries every bundled weak password against hash and returns
// the first one that matches, or "" when none does. It is
// checkWordlistContext with a background context.
func checkWordlist(hash string) string {
	background := context.Background()
	return checkWordlistContext(background, hash)
}
// checkWordlistContext tries the bundled weak passwords against hash in list
// order and returns the first match. It returns "" when nothing matches or as
// soon as ctx is cancelled. A nil ctx is treated as context.Background().
func checkWordlistContext(ctx context.Context, hash string) string {
	if ctx == nil {
		ctx = context.Background()
	}
	words := loadWeakPasswords()
	for i := 0; i < len(words); i++ {
		if ctx.Err() != nil {
			return ""
		}
		if candidate := words[i]; verifyDoveadmContext(ctx, hash, candidate) {
			return candidate
		}
	}
	return ""
}
// CheckEmailPasswords audits Dovecot email account passwords for weak/predictable
// patterns. Uses internal throttle: skips if last refresh was less than
// password_check_interval_min ago.
//
// For every non-locked mailbox found in the discovered shadow files whose
// hash fingerprint differs from the stored one, the hash is tested against
// (1) heuristic candidates derived from the mailbox name and domain and
// (2) the bundled wordlist. A match produces a Critical
// "email_weak_password" finding, enriched with a HIBP breach count when the
// lookup returns one. The fingerprint is then stored so the mailbox is not
// re-audited until its hash changes.
//
// Returns nil when the global store is unavailable, the throttle interval has
// not elapsed, no shadow files are found, or ctx is cancelled while collecting
// entries. If ctx is cancelled during the audit, the findings gathered so far
// are returned and the last-refresh timestamp is left untouched.
func CheckEmailPasswords(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	db := store.Global()
	if db == nil {
		return nil
	}
	// Internal throttle -- same pattern as CheckOutdatedPlugins
	if !ForceAll {
		lastRefresh := db.GetEmailPWLastRefresh()
		interval := time.Duration(cfg.EmailProtection.PasswordCheckIntervalMin) * time.Minute
		if time.Since(lastRefresh) < interval {
			return nil
		}
	}
	shadowFiles := discoverShadowFiles(ctx, effectiveAccountScanMaxFiles(cfg))
	if len(shadowFiles) == 0 {
		return nil
	}
	// Collect all mailbox entries
	var allEntries []mailboxEntry
	for _, sf := range shadowFiles {
		if ctx.Err() != nil {
			return nil
		}
		allEntries = append(allEntries, readShadowFile(sf)...)
	}
	if ctx.Err() != nil {
		return nil
	}
	if len(allEntries) == 0 {
		// Nothing to audit; still record the refresh so the throttle applies.
		_ = db.SetEmailPWLastRefresh(time.Now())
		return nil
	}
	// mu guards findings, which is appended to from the worker goroutines.
	var mu sync.Mutex
	var findings []alert.Finding
	var wg sync.WaitGroup
	// Process each mailbox concurrently (bounded by semaphore).
	// One goroutine is started per entry; at most 5 run the audit at a time.
	sem := make(chan struct{}, 5)
	for _, entry := range allEntries {
		if ctx.Err() != nil {
			break
		}
		wg.Add(1)
		go func(e mailboxEntry) {
			defer wg.Done()
			// Wait for a worker slot, giving up if the check is cancelled.
			select {
			case sem <- struct{}{}:
			case <-ctx.Done():
				return
			}
			defer func() { <-sem }()
			if ctx.Err() != nil {
				return
			}
			fullMailbox := e.mailbox + "@" + e.domain
			storeKey := fmt.Sprintf("email:pwaudit:%s:%s", e.account, fullMailbox)
			// Skip if hash hasn't changed since last audit
			fp := hashFingerprint(e.hash)
			if stored := db.GetMetaString(storeKey); stored == fp {
				return
			}
			// Layer 1: Heuristic candidates
			candidates := generateCandidates(e.mailbox, e.domain)
			var matched string
			var matchType string
			for _, c := range candidates {
				if ctx.Err() != nil {
					return
				}
				if verifyDoveadmContext(ctx, e.hash, c) {
					matched = c
					matchType = "heuristic"
					break
				}
			}
			// Layer 2: Common wordlist (skip if layer 1 matched)
			if matched == "" {
				if ctx.Err() != nil {
					return
				}
				if w := checkWordlistContext(ctx, e.hash); w != "" {
					matched = w
					matchType = "wordlist"
				}
			}
			if matched != "" {
				// NOTE(review): the matched plaintext is embedded in the
				// finding details -- confirm that is acceptable for every
				// alert channel these findings are delivered to.
				details := fmt.Sprintf("Account: %s\nMailbox: %s\nMatch type: %s\nMatched password pattern: %q",
					e.account, fullMailbox, matchType, matched)
				// Layer 3: HIBP enrichment (only for confirmed matches)
				breachCount := checkHIBPWithContext(ctx, matched)
				if breachCount > 0 {
					details += fmt.Sprintf("\nHIBP: password found in %d data breaches", breachCount)
				}
				mu.Lock()
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_weak_password",
					Message:  fmt.Sprintf("Weak email password for %s (account: %s)", fullMailbox, e.account),
					Details:  details,
					Domain:   e.domain,
					Mailbox:  fullMailbox,
				})
				mu.Unlock()
			}
			// Record fingerprint so we don't re-audit until hash changes.
			// Skipped on cancellation: the audit above may have been cut
			// short, so the mailbox must be looked at again next time.
			if ctx.Err() != nil {
				return
			}
			_ = db.SetMetaString(storeKey, fp)
		}(entry)
	}
	wg.Wait()
	if ctx.Err() != nil {
		// Partial run: return what was found but do not reset the throttle.
		return findings
	}
	_ = db.SetEmailPWLastRefresh(time.Now())
	return findings
}
package checks
import (
"bytes"
"context"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailspool"
"github.com/pidginhost/csm/internal/state"
)
// emailBodySampleSize caps how many bytes of a spooled message body are
// inspected by scanEximMessage.
const emailBodySampleSize = 4096 // Read first 4KB of email body

// Suspicious mailer headers that indicate mass mailing scripts.
// Entries are lowercase and matched as substrings of the lower-cased
// X-Mailer / User-Agent value.
var suspiciousMailers = []string{
	"phpmailer", "swiftmailer", "mass mailer", "bulk mailer",
	"leaf phpmailer", "phpmail", "mail.php",
}

// Known safe mailers that should not be flagged.
// A substring match here takes precedence over suspiciousMailers.
var safeMailers = []string{
	"wordpress", "woocommerce", "roundcube", "squirrelmail",
	"thunderbird", "outlook", "apple mail", "cpanel",
	"postfix", "exim", "dovecot",
}

// Phishing URL patterns in email body.
// Matched as substrings of the lower-cased body sample.
var emailPhishPatterns = []string{
	".workers.dev",
	"//bit.ly/", "//tinyurl.com/", "//is.gd/", "//rb.gy/",
	"//t.co/",
	"/redir?url=", "/redirect?url=", "link?url=",
	"effi.redir",
}

// Phishing language in email body.
// scanEximMessage raises an indicator when at least two of these phrases
// occur in the lower-cased body sample.
var emailPhishLanguage = []string{
	"verify your account",
	"confirm your identity",
	"unusual activity",
	"your account will be",
	"suspended unless",
	"click here to verify",
	"update your payment",
	"confirm your email address",
	"security alert",
	"unauthorized access",
}

// Brand impersonation in email body (when sender doesn't match).
// scanEximMessage looks for these names in the From: header and compares
// them against the envelope sender's domain.
var emailBrandNames = []string{
	"paypal", "microsoft", "apple", "google", "amazon",
	"netflix", "facebook", "instagram", "bank of",
	"wells fargo", "chase", "citibank",
}
// eximArrivalLineRegex extracts the message ID (group 1) and envelope sender
// (group 2) from an Exim main-log arrival ("<=") line.
//
// Exim prefixes each line with a variable number of whitespace-separated
// tokens before the message ID (date and time by default, optionally a
// timezone and/or "[pid]"), so one or more leading tokens are accepted. The
// lazy quantifier keeps the captures identical for lines that carry only a
// single leading token. Compiled once at package scope instead of per call.
var eximArrivalLineRegex = regexp.MustCompile(`^(?:\S+\s+)+?(\S+)\s+<=\s+(\S+)`)

// CheckOutboundEmailContent samples outbound email content from the Exim
// spool for phishing URLs, credential harvesting language, suspicious
// mailers, and Reply-To mismatches.
//
// It reads the last 200 lines of /var/log/exim_mainlog, picks out message
// arrival lines, skips bounces (sender "<>") and message IDs already seen in
// this pass, and hands each remaining message to scanEximMessage. Every
// non-nil result becomes a finding. Returns nil when the log tail is empty.
func CheckOutboundEmailContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding

	// Read recent outbound messages from exim_mainlog
	lines := tailFile("/var/log/exim_mainlog", 200)
	if len(lines) == 0 {
		return nil
	}

	scanned := make(map[string]bool) // avoid scanning same message twice
	for _, line := range lines {
		matches := eximArrivalLineRegex.FindStringSubmatch(line)
		if len(matches) < 3 {
			continue
		}
		msgID, sender := matches[1], matches[2]
		if sender == "<>" || sender == "" {
			continue // bounce message
		}
		if scanned[msgID] {
			continue
		}
		scanned[msgID] = true

		// Scan the message
		if result := scanEximMessage(msgID, sender, cfg); result != nil {
			findings = append(findings, *result)
		}
	}
	return findings
}
// scanEximMessage reads an Exim spool message and checks for suspicious content.
//
// msgID is the Exim message ID and sender the envelope sender taken from the
// log line; cfg is accepted but not used by the current body. The message's
// -H (headers) and -D (body) files are looked up in the known spool
// directories. Returns nil when the message is no longer in the spool, the
// header file cannot be read or parsed, or no indicator fires. Otherwise it
// returns an "email_phishing_content" finding listing every indicator:
// High severity by default, Critical when three or more indicators fired.
func scanEximMessage(msgID, sender string, cfg *config.Config) *alert.Finding {
	// Exim spool paths
	// Headers: /var/spool/exim/input/{msgID}-H
	// Body: /var/spool/exim/input/{msgID}-D
	spoolDirs := []string{
		"/var/spool/exim/input",
		"/var/spool/exim4/input",
	}
	var headerPath, bodyPath string
	// The first spool directory that holds the header file wins.
	for _, dir := range spoolDirs {
		h := filepath.Join(dir, msgID+"-H")
		b := filepath.Join(dir, msgID+"-D")
		if _, err := osFS.Stat(h); err == nil {
			headerPath = h
			bodyPath = b
			break
		}
	}
	if headerPath == "" {
		return nil // message already delivered/removed from spool
	}
	var indicators []string
	// Read and parse headers via the emailspool Exim -H parser. We go through
	// osFS.ReadFile + bytes.NewReader rather than emailspool.ParseHeaders(path)
	// so check tests can inject mock spool contents through the existing
	// osFS seam.
	headerData, err := osFS.ReadFile(headerPath)
	if err != nil {
		return nil
	}
	parsed, err := emailspool.ParseHeadersReader(bytes.NewReader(headerData))
	if err != nil {
		// Malformed or truncated -H file: nothing to check, skip silently as
		// the previous loose parse would have done.
		return nil
	}
	// Lower-cased raw bytes are still required for the base64-text/html
	// combination heuristic, which inspects MIME framing the emailspool
	// Headers struct does not surface.
	headersLower := strings.ToLower(string(headerData))
	// Check 1: Reply-To mismatch
	if parsed.From != "" && parsed.ReplyTo != "" {
		fromDomain := emailspool.ExtractDomain(parsed.From)
		replyDomain := emailspool.ExtractDomain(parsed.ReplyTo)
		if fromDomain != "" && replyDomain != "" && fromDomain != replyDomain {
			indicators = append(indicators, fmt.Sprintf("Reply-To mismatch: From=%s, Reply-To=%s", fromDomain, replyDomain))
		}
	}
	// Check 2: Suspicious X-Mailer (fall back to User-Agent, which the
	// emailspool parser surfaces alongside the other RFC 5322 fields).
	mailer := parsed.XMailer
	if mailer == "" {
		mailer = parsed.UserAgent
	}
	if mailer != "" {
		mailerLower := strings.ToLower(mailer)
		// A safe-list hit suppresses the suspicious-mailer check entirely.
		isSafe := false
		for _, safe := range safeMailers {
			if strings.Contains(mailerLower, safe) {
				isSafe = true
				break
			}
		}
		if !isSafe {
			for _, suspicious := range suspiciousMailers {
				if strings.Contains(mailerLower, suspicious) {
					indicators = append(indicators, fmt.Sprintf("suspicious mailer: %s", strings.TrimSpace(mailer)))
					break
				}
			}
		}
	}
	// Check 3: Spoofed display name (brand name in From: but sender is not that brand)
	if parsed.From != "" {
		fromLower := strings.ToLower(parsed.From)
		senderDomain := emailspool.ExtractDomain(sender)
		// Only the first matching brand is reported.
		for _, brand := range emailBrandNames {
			if strings.Contains(fromLower, brand) && !strings.Contains(strings.ToLower(senderDomain), brand) {
				indicators = append(indicators, fmt.Sprintf("spoofed brand in From: '%s' (actual sender: %s)", strings.TrimSpace(parsed.From), sender))
				break
			}
		}
	}
	// Read and analyze body (sample first 4KB). A read error leaves bodyData
	// empty, which skips the body checks below.
	bodyData, _ := osFS.ReadFile(bodyPath)
	if len(bodyData) > emailBodySampleSize {
		bodyData = bodyData[:emailBodySampleSize]
	}
	if len(bodyData) > 0 {
		bodyLower := strings.ToLower(string(bodyData))
		// Check 4: Phishing URLs in body (first matching pattern only)
		for _, pattern := range emailPhishPatterns {
			if strings.Contains(bodyLower, pattern) {
				indicators = append(indicators, fmt.Sprintf("phishing URL pattern: %s", pattern))
				break
			}
		}
		// Check 5: Credential harvesting language (needs at least two phrases)
		harvestCount := 0
		for _, phrase := range emailPhishLanguage {
			if strings.Contains(bodyLower, phrase) {
				harvestCount++
			}
		}
		if harvestCount >= 2 {
			indicators = append(indicators, fmt.Sprintf("credential harvesting language (%d phrases)", harvestCount))
		}
		// Check 6: Base64-encoded HTML body (used to bypass filters).
		// Decided from the raw header bytes alone; the body is not decoded.
		if strings.Contains(headersLower, "content-transfer-encoding: base64") &&
			strings.Contains(headersLower, "text/html") {
			// Check if decoded content has phishing patterns
			indicators = append(indicators, "base64-encoded HTML body (potential filter bypass)")
		}
	}
	if len(indicators) == 0 {
		return nil
	}
	severity := alert.High
	if len(indicators) >= 3 {
		severity = alert.Critical
	}
	_, senderDomain := alert.SplitEmail(sender)
	return &alert.Finding{
		Severity: severity,
		Check:    "email_phishing_content",
		Message:  fmt.Sprintf("Suspicious outbound email from %s (message: %s)", sender, msgID),
		Details:  fmt.Sprintf("Indicators:\n- %s", strings.Join(indicators, "\n- ")),
		Domain:   senderDomain,
		Mailbox:  sender,
	}
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// uidStringToUser maps a decimal UID string (as read from /proc/<pid>/status)
// to a username through LookupUser. When the string does not parse as an
// unsigned 32-bit integer it is returned unchanged.
func uidStringToUser(uid string) string {
	if parsed, err := strconv.ParseUint(uid, 10, 32); err == nil {
		return LookupUser(uint32(parsed))
	}
	return uid
}
// CheckDatabaseDumps flags running mysqldump, pg_dump or mongodump processes
// that are not owned by root, as potential data exfiltration.
//
// Every /proc/<pid>/cmdline is examined. The real UID is read from the
// "Uid:" line of the matching status file; processes owned by UID 0, or
// whose UID could not be determined, are skipped. A process whose command
// line contains one of the tool names yields a single Critical finding.
func CheckDatabaseDumps(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	dumpTools := [...]string{"mysqldump", "pg_dump", "mongodump"}

	var findings []alert.Finding
	cmdlinePaths, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cmdlinePath := range cmdlinePaths {
		pid := filepath.Base(filepath.Dir(cmdlinePath))

		// Read UID: first value on the "Uid:" line.
		status, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		uid := ""
		for _, row := range strings.Split(string(status), "\n") {
			if !strings.HasPrefix(row, "Uid:\t") {
				continue
			}
			if ids := strings.Fields(strings.TrimPrefix(row, "Uid:\t")); len(ids) > 0 {
				uid = ids[0]
			}
		}

		// Skip root - root may run legitimate backups. Unknown UIDs are
		// skipped too.
		switch uid {
		case "", "0":
			continue
		}

		raw, err := osFS.ReadFile(cmdlinePath)
		if err != nil {
			continue
		}
		// cmdline arguments are NUL-separated.
		cmd := strings.ReplaceAll(string(raw), "\x00", " ")

		for _, tool := range dumpTools {
			if !strings.Contains(cmd, tool) {
				continue
			}
			owner := uidStringToUser(uid)
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "database_dump",
				Message:  fmt.Sprintf("Database dump by non-root user: %s (%s)", owner, tool),
				Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uid, strings.TrimSpace(cmd)),
			})
			break // one finding per process
		}
	}
	return findings
}
// CheckOutboundPasteSites flags non-root processes whose command line
// mentions a known paste or file-drop site.
//
// Every /proc/<pid>/cmdline is examined. Processes owned by UID 0 are
// skipped; a process whose status file could not be read has an empty UID
// and is not treated as root. The lower-cased command line is searched for
// each site name, and the first hit yields one Critical finding for that
// process.
func CheckOutboundPasteSites(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	pasteSites := [...]string{
		"pastebin.com", "hastebin.com", "ghostbin.co",
		"paste.ee", "dpaste.org", "gist.githubusercontent.com",
		"raw.githubusercontent.com", "transfer.sh",
		"file.io", "0x0.st", "ix.io",
	}

	var findings []alert.Finding
	cmdlinePaths, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cmdlinePath := range cmdlinePaths {
		pid := filepath.Base(filepath.Dir(cmdlinePath))

		// Real UID: first value on the "Uid:" line of the status file.
		status, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		uid := ""
		for _, row := range strings.Split(string(status), "\n") {
			if !strings.HasPrefix(row, "Uid:\t") {
				continue
			}
			if ids := strings.Fields(strings.TrimPrefix(row, "Uid:\t")); len(ids) > 0 {
				uid = ids[0]
			}
		}
		if uid == "0" {
			continue
		}

		raw, err := osFS.ReadFile(cmdlinePath)
		if err != nil {
			continue
		}
		// cmdline arguments are NUL-separated; compare case-insensitively.
		cmd := strings.ToLower(strings.ReplaceAll(string(raw), "\x00", " "))

		for _, site := range pasteSites {
			if !strings.Contains(cmd, site) {
				continue
			}
			owner := uidStringToUser(uid)
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "exfiltration_paste_site",
				Message:  fmt.Sprintf("Process connecting to paste/exfiltration site: %s (user: %s)", site, owner),
				Details:  fmt.Sprintf("PID: %s, cmdline: %s", pid, strings.TrimSpace(cmd)),
			})
			break // one finding per process
		}
	}
	return findings
}
package checks
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// fileIndexScanCount tracks the number of CheckFileIndex invocations.
// Every 6th scan forces a full directory rescan, bypassing the mtime cache
// to catch writes that don't update parent directory mtime (e.g. hard links).
// It is updated with atomic.AddInt32.
var fileIndexScanCount int32

// suspiciousExtensions are file extensions that should never appear in web roots.
// Keys are lowercase and include the leading dot.
var suspiciousExtensions = map[string]bool{
	".phtml": true, ".pht": true, ".php5": true,
	".haxor": true, ".cgix": true,
}

// dirMtimeCache maps directory paths to their last-known mtime (unix seconds).
// Directories with unchanged mtime are skipped during scanning.
// It is persisted as JSON in dircache.json under the state directory.
type dirMtimeCache map[string]int64
// loadDirCache reads the persisted directory-mtime cache from
// <stateDir>/dircache.json. Loading is best effort: a missing, unreadable or
// malformed file yields an empty (or partially filled) cache, which only
// causes the affected directories to be rescanned.
//
// The returned map is never nil. A file containing the JSON literal null
// makes json.Unmarshal set the destination map to nil, and callers write into
// the cache, so that case is reset to a fresh empty map.
func loadDirCache(stateDir string) dirMtimeCache {
	cache := make(dirMtimeCache)
	data, err := osFS.ReadFile(filepath.Join(stateDir, "dircache.json"))
	if err != nil {
		return cache
	}
	// Decode errors are deliberately ignored; see the function comment.
	_ = json.Unmarshal(data, &cache)
	if cache == nil {
		cache = make(dirMtimeCache)
	}
	return cache
}
// saveDirCache persists cache as JSON to <stateDir>/dircache.json.
//
// The data is written to dircache.json.tmp first and renamed into place only
// after the write succeeded, so a failed or partial write never replaces the
// previous good cache. Saving is best effort: on any failure the temp file is
// removed and the existing cache file is left untouched.
func saveDirCache(stateDir string, cache dirMtimeCache) {
	data, err := json.Marshal(cache)
	if err != nil {
		return
	}
	tmpPath := filepath.Join(stateDir, "dircache.json.tmp")
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		// Do not promote a possibly truncated temp file.
		_ = os.Remove(tmpPath)
		return
	}
	if err := os.Rename(tmpPath, filepath.Join(stateDir, "dircache.json")); err != nil {
		_ = os.Remove(tmpPath)
	}
}
// dirChanged reports whether dir needs to be re-read, and records the
// directory's current mtime in cache.
//
// It returns true when the directory cannot be stat'ed (scan it to be safe;
// the cache is left untouched in that case), when forceFullScan is set
// (catches writes that bypass parent mtime updates, e.g. hard links or mount
// tricks), when the directory has no cached mtime yet, or when the mtime
// differs from the cached value.
func dirChanged(dir string, cache dirMtimeCache, forceFullScan bool) bool {
	st, err := osFS.Stat(dir)
	if err != nil {
		return true
	}
	current := st.ModTime().Unix()
	last, known := cache[dir]
	cache[dir] = current
	return forceFullScan || !known || current != last
}
// CheckFileIndex builds an index of suspicious files using pure Go directory
// reads, diffs against the previous index, and alerts on new files.
// Uses directory mtime caching: unchanged dirs carry forward previous entries
// without calling ReadDir, while changed dirs are re-scanned.
//
// State lives under cfg.StatePath: "fileindex.current", "fileindex.previous"
// and the directory-mtime cache. The state-store parameter is unused. A nil
// ctx is replaced with context.Background().
//
// It returns nil (no findings) when the scan was cancelled, on the first run
// (the current index is only saved as the baseline), and when the sanity
// checks below reject the new index. Otherwise it returns one finding per
// newly indexed path that matches a rule, and promotes the current index to
// be the next baseline.
func CheckFileIndex(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
if ctx == nil {
ctx = context.Background()
}
// The counter is package-level, so the "every 6th cycle" cadence counts
// calls within this process.
scanNum := atomic.AddInt32(&fileIndexScanCount, 1)
forceFullScan := scanNum%6 == 0 // full rescan every 6th cycle
indexDir := cfg.StatePath
currentPath := filepath.Join(indexDir, "fileindex.current")
previousPath := filepath.Join(indexDir, "fileindex.previous")
// Load caches
dirCache := loadDirCache(indexDir)
// Build a set of previous entries grouped by directory ancestry, so
// unchanged dirs can carry forward their whole subtree without ReadDir.
previousEntries := loadIndex(previousPath)
prevByDir := groupEntriesByUploadDir(previousEntries)
// Build current index
currentEntries := buildFileIndex(ctx, dirCache, prevByDir, forceFullScan)
// A cancelled scan produced a partial index. Do not write or promote it:
// the partial set or its mtimes would make the next scan compare against
// stale cache state instead of the last complete baseline.
if ctx.Err() != nil {
return nil
}
// Save updated dir cache
saveDirCache(indexDir, dirCache)
// Write current index (atomic)
writeIndex(currentPath, currentEntries)
// First run - save baseline
if _, err := osFS.Stat(previousPath); os.IsNotExist(err) {
copyFile(currentPath, previousPath)
return nil
}
// Validate index
// NOTE(review): both early returns below leave fileindex.previous untouched
// although the dir cache was already saved above; entries found this cycle
// in re-scanned directories may therefore not be diffed again until those
// directories change or a forced full scan runs -- confirm this is intended.
if len(previousEntries) > 10 && len(currentEntries) == 0 {
fmt.Fprintf(os.Stderr, "file_index: current index empty but previous had %d entries, skipping diff\n", len(previousEntries))
return nil
}
if len(previousEntries) > 0 && len(currentEntries) < len(previousEntries)/2 {
fmt.Fprintf(os.Stderr, "file_index: current index (%d) < half of previous (%d), skipping diff\n", len(currentEntries), len(previousEntries))
return nil
}
// Diff
prevSet := make(map[string]bool, len(previousEntries))
for _, e := range previousEntries {
prevSet[e] = true
}
var newFiles []string
for _, e := range currentEntries {
if !prevSet[e] {
newFiles = append(newFiles, e)
}
}
// Analyze new files - lazy stat only for alerting
for _, path := range newFiles {
name := filepath.Base(path)
nameLower := strings.ToLower(name)
// Operator suppressions are applied before any classification.
suppressed := false
for _, ignore := range cfg.Suppressions.IgnorePaths {
if matchGlob(path, ignore) {
suppressed = true
break
}
}
if suppressed {
continue
}
// A negative severity means "no rule matched yet". The rules below run in
// order and a later unconditional rule overwrites an earlier verdict.
severity := alert.Severity(-1)
check := ""
message := ""
if strings.Contains(path, "/wp-content/uploads/") && phpPathExecutes(path, nameLower) {
// Content decides, never the path or name. A negative
// severity is a content-verified inert stub (e.g. the
// WordPress "silence is golden" index.php, or BackWPup's
// "<?php //<json>" working files) and is suppressed; any
// real code surfaces, malicious or merely present.
if sev, ck, msg := classifyUploadPHP(path); sev >= 0 {
severity = sev
check = ck
message = msg
} else {
continue
}
}
// PHP files in wp-content/languages and wp-content/upgrade: content-first.
// Path-only Critical buried real alerts under location noise (WPML
// translation queues, WP auto-update staging). See classifySensitiveDirPHP.
if sev, ck, msg := classifySensitiveDirPHP(path, name); sev >= 0 {
severity = sev
check = ck
message = msg
}
if strings.Contains(path, "/.config/") {
severity = alert.Critical
check = "new_executable_in_config"
message = fmt.Sprintf("New executable in .config: %s", path)
}
if isWebshellName(nameLower) {
severity = alert.Critical
check = "new_webshell_file"
message = fmt.Sprintf("New file with webshell name: %s", path)
}
// The name heuristic is a fallback only: it never replaces a verdict that
// an earlier rule already produced.
if isExecutablePHPName(nameLower) && isSuspiciousPHPName(nameLower) {
if severity < 0 {
severity = alert.High
check = "new_suspicious_php"
message = fmt.Sprintf("New suspicious PHP file: %s", path)
}
}
if severity >= 0 {
// Details are best effort; a file that vanished since indexing is still
// reported, just without size/mtime.
details := ""
if info, err := osFS.Stat(path); err == nil {
details = fmt.Sprintf("Size: %d, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
}
findings = append(findings, alert.Finding{
Severity: severity,
Check: check,
Message: message,
Details: details,
FilePath: path,
})
}
}
// Promote the current index so the next run diffs against it.
copyFile(currentPath, previousPath)
return findings
}
// classifySensitiveDirPHP classifies a PHP file found under
// /wp-content/languages/ or /wp-content/upgrade/ and returns
// (severity, check, message). The severity is negative when the finding must
// be dropped: the file does not execute as PHP, it is not inside one of the
// two sensitive directories, or its body is a content-verified inert stub.
//
// The decision is driven by file content; there is no filename allowlist, so
// naming a backdoor like a translation or index file does not hide it:
//   - a content indicator keeps the severity and check name reported by
//     analyzePHPContent;
//   - an unreadable or empty body fails closed at High
//     (new_php_in_sensitive_dir), so racing the scanner with rm or chmod 000
//     does not earn a demotion;
//   - any other real code surfaces as a Warning under
//     new_php_in_sensitive_dir_clean, a visibility-only signal.
func classifySensitiveDirPHP(path, name string) (alert.Severity, string, string) {
	if !phpPathExecutes(path, strings.ToLower(name)) {
		return -1, "", ""
	}

	// upgrade wins when a path happens to contain both markers.
	var locLabel string
	switch {
	case strings.Contains(path, "/wp-content/upgrade/"):
		locLabel = "wp-content/upgrade"
	case strings.Contains(path, "/wp-content/languages/"):
		locLabel = "wp-content/languages"
	default:
		return -1, "", ""
	}

	verdict := analyzePHPContent(path)
	switch {
	case verdict.severity >= 0:
		return verdict.severity, verdict.check, fmt.Sprintf("%s: %s", verdict.message, path)
	case !verdict.readOK:
		return alert.High, "new_php_in_sensitive_dir",
			fmt.Sprintf("New unreadable PHP file in %s: %s", locLabel, path)
	case verdict.empty:
		return alert.High, "new_php_in_sensitive_dir",
			fmt.Sprintf("New empty PHP file in %s: %s", locLabel, path)
	case IsBenignPHPStub(path):
		// e.g. the "silence is golden" index.php
		return -1, "", ""
	}
	return alert.Warning, "new_php_in_sensitive_dir_clean",
		fmt.Sprintf("New PHP file in %s (content clean): %s", locLabel, path)
}
// groupEntriesByUploadDir indexes every entry under each of its ancestor
// directories, stopping before "." and the filesystem root. The result lets a
// walker that skips an unchanged directory carry forward that directory's
// whole cached subtree with a single map lookup.
func groupEntriesByUploadDir(entries []string) map[string][]string {
	byAncestor := make(map[string][]string)
	root := string(filepath.Separator)
	for _, entry := range entries {
		ancestor := filepath.Dir(entry)
		for ancestor != "." && ancestor != root {
			byAncestor[ancestor] = append(byAncestor[ancestor], entry)
			ancestor = filepath.Dir(ancestor)
		}
	}
	return byAncestor
}
// buildFileIndex scans targeted directory subtrees using ReadDir.
// Skips directories whose mtime hasn't changed - carries forward
// their entries from the previous index instead.
// If forceFullScan is true, all directories are re-scanned regardless of mtime.
//
// dirCache is updated in place by the walkers; prevByDir is the previous
// index grouped by ancestor directory. A nil ctx is replaced with
// context.Background().
//
// The result is sorted on normal completion. It is nil when ctx is already
// cancelled on entry or GetScanHomeDirs fails, and a partial, unsorted slice
// when cancellation is observed mid-walk; callers must check ctx.Err() before
// trusting the result.
func buildFileIndex(ctx context.Context, dirCache dirMtimeCache, prevByDir map[string][]string, forceFullScan bool) []string {
var entries []string
if ctx == nil {
ctx = context.Background()
}
if ctx.Err() != nil {
return nil
}
homeDirs, err := GetScanHomeDirs(ctx)
if err != nil {
return nil
}
for _, homeEntry := range homeDirs {
// A cancelled scan must stop walking and let the caller discard the
// partial index rather than promote it as the new baseline.
if ctx.Err() != nil {
return entries
}
if !homeEntry.IsDir() {
continue
}
user := homeEntry.Name()
homeDir := filepath.Join("/home", user)
// Scan wp-content/uploads for PHP files
uploadDirs := []string{
filepath.Join(homeDir, "public_html", "wp-content", "uploads"),
}
// Scan directories that shouldn't normally contain user PHP:
// languages (translation files only), upgrade (temp dir), mu-plugins
sensitiveWPDirs := []string{
filepath.Join(homeDir, "public_html", "wp-content", "languages"),
filepath.Join(homeDir, "public_html", "wp-content", "upgrade"),
filepath.Join(homeDir, "public_html", "wp-content", "mu-plugins"),
}
// Every other non-hidden top-level directory of the home is treated as a
// possible addon-domain document root. A ReadDir error is ignored: only
// the public_html targets above are scanned in that case.
subDirs, _ := osFS.ReadDir(homeDir)
for _, sd := range subDirs {
if ctx.Err() != nil {
return entries
}
if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
uploadsPath := filepath.Join(homeDir, sd.Name(), "wp-content", "uploads")
if info, err := osFS.Stat(uploadsPath); err == nil && info.IsDir() {
uploadDirs = append(uploadDirs, uploadsPath)
}
// Also track sensitive dirs for addon domains
for _, subDir := range []string{"languages", "upgrade", "mu-plugins"} {
if ctx.Err() != nil {
return entries
}
sensitiveDir := filepath.Join(homeDir, sd.Name(), "wp-content", subDir)
if info, err := osFS.Stat(sensitiveDir); err == nil && info.IsDir() {
sensitiveWPDirs = append(sensitiveWPDirs, sensitiveDir)
}
}
}
}
// Uploads are walked up to 6 levels deep, starting from an empty handler
// overlay.
for _, uploadsDir := range uploadDirs {
if ctx.Err() != nil {
return entries
}
scanDirForPHPContext(ctx, uploadsDir, 6, dirCache, prevByDir, forceFullScan, phpHandlerOverlay{}, &entries)
}
// Scan sensitive WP directories for any PHP files
for _, sensitiveDir := range sensitiveWPDirs {
if ctx.Err() != nil {
return entries
}
scanDirForPHPContext(ctx, sensitiveDir, 4, dirCache, prevByDir, forceFullScan, phpHandlerOverlay{}, &entries)
}
// Scan .config for executables
configDir := filepath.Join(homeDir, ".config")
if ctx.Err() != nil {
return entries
}
scanDirForExecutablesContext(ctx, configDir, 3, dirCache, prevByDir, forceFullScan, &entries)
}
// The global temp directories are walked only when AccountFromContext
// returns the empty string.
if AccountFromContext(ctx) == "" {
// Scan tmp dirs
for _, tmpDir := range []string{"/tmp", "/dev/shm", "/var/tmp"} {
if ctx.Err() != nil {
return entries
}
scanDirForSuspiciousExtContext(ctx, tmpDir, 2, dirCache, prevByDir, forceFullScan, &entries)
}
}
sort.Strings(entries)
return entries
}
// scanDirForPHP recursively reads directories for PHP-executable files.
// If directory mtime is unchanged, carries forward previous entries.
//
// It is the non-cancellable form of scanDirForPHPContext: it delegates with
// context.Background(), passing every other argument through unchanged.
func scanDirForPHP(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, overlay phpHandlerOverlay, entries *[]string) {
scanDirForPHPContext(context.Background(), dir, maxDepth, cache, prev, forceFullScan, overlay, entries)
}
// scanDirForPHPContext walks dir up to maxDepth levels and appends every file
// that the handler overlay says executes as PHP, or whose extension is in
// suspiciousExtensions, to *entries.
//
// A readable .htaccess in dir is merged into overlay before the directory is
// examined, and the merged overlay is handed down to sub-directories. When
// the directory mtime is unchanged (and neither forceFullScan nor an active
// overlay forces a rescan) the entries recorded for dir in prev are carried
// forward and the directory is not read. The walk stops silently as soon as
// ctx is cancelled, leaving *entries partial.
func scanDirForPHPContext(ctx context.Context, dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, overlay phpHandlerOverlay, entries *[]string) {
	if maxDepth <= 0 || ctx.Err() != nil {
		return
	}

	htaccessBody, readErr := osFS.ReadFile(filepath.Join(dir, ".htaccess"))
	if readErr == nil {
		overlay = overlay.mergeHtaccess(htaccessBody)
	}
	if ctx.Err() != nil {
		return
	}

	mustRescan := forceFullScan || overlay.active()
	rescan := dirChanged(dir, cache, mustRescan)
	if ctx.Err() != nil {
		return
	}
	if !rescan {
		// Unchanged: reuse what the previous index held for this directory.
		*entries = append(*entries, prev[dir]...)
		return
	}

	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		childName := child.Name()
		childPath := filepath.Join(dir, childName)
		if child.IsDir() {
			scanDirForPHPContext(ctx, childPath, maxDepth-1, cache, prev, forceFullScan, overlay, entries)
			continue
		}

		// index.php is indexed as well: a webshell must not hide behind the
		// WordPress silence-stub name; the inert stub is filtered later by
		// content analysis. Both predicates can match the same file, so the
		// combined test indexes it once.
		lowered := strings.ToLower(childName)
		if !overlay.executes(lowered) && !suspiciousExtensions[filepath.Ext(lowered)] {
			continue
		}
		*entries = append(*entries, childPath)
	}
}
// scanDirForExecutables reads .config dirs for executable files.
//
// It is the non-cancellable form of scanDirForExecutablesContext: it
// delegates with context.Background(), passing every other argument through
// unchanged.
func scanDirForExecutables(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
scanDirForExecutablesContext(context.Background(), dir, maxDepth, cache, prev, forceFullScan, entries)
}
// scanDirForExecutablesContext walks dir up to maxDepth levels and appends
// every regular entry that has at least one execute permission bit set to
// *entries. A directory whose mtime is unchanged (and forceFullScan is false)
// is not read; the entries recorded for it in prev are carried forward
// instead. The walk stops silently once ctx is cancelled.
func scanDirForExecutablesContext(ctx context.Context, dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
	if maxDepth <= 0 || ctx.Err() != nil {
		return
	}
	rescan := dirChanged(dir, cache, forceFullScan)
	if ctx.Err() != nil {
		return
	}
	if !rescan {
		*entries = append(*entries, prev[dir]...)
		return
	}

	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		childPath := filepath.Join(dir, child.Name())
		if child.IsDir() {
			scanDirForExecutablesContext(ctx, childPath, maxDepth-1, cache, prev, forceFullScan, entries)
			continue
		}
		// Entries whose metadata cannot be read are skipped.
		meta, err := child.Info()
		if err != nil || meta.Mode()&0111 == 0 {
			continue
		}
		*entries = append(*entries, childPath)
	}
}
// scanDirForSuspiciousExt reads tmp dirs for files with suspicious extensions.
//
// It is the non-cancellable form of scanDirForSuspiciousExtContext: it
// delegates with context.Background(), passing every other argument through
// unchanged.
func scanDirForSuspiciousExt(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
scanDirForSuspiciousExtContext(context.Background(), dir, maxDepth, cache, prev, forceFullScan, entries)
}
// scanDirForSuspiciousExtContext walks dir up to maxDepth levels and appends
// every file whose lower-cased extension is in suspiciousExtensions to
// *entries. A directory whose mtime is unchanged (and forceFullScan is false)
// is not read; the entries recorded for it in prev are carried forward
// instead. The walk stops silently once ctx is cancelled.
func scanDirForSuspiciousExtContext(ctx context.Context, dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
	if maxDepth <= 0 || ctx.Err() != nil {
		return
	}
	rescan := dirChanged(dir, cache, forceFullScan)
	if ctx.Err() != nil {
		return
	}
	if !rescan {
		*entries = append(*entries, prev[dir]...)
		return
	}

	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		childName := child.Name()
		childPath := filepath.Join(dir, childName)
		if child.IsDir() {
			scanDirForSuspiciousExtContext(ctx, childPath, maxDepth-1, cache, prev, forceFullScan, entries)
			continue
		}
		if !suspiciousExtensions[filepath.Ext(strings.ToLower(childName))] {
			continue
		}
		*entries = append(*entries, childPath)
	}
}
// isWebshellName reports whether name is exactly one of the well-known
// webshell or uploader file names. The comparison is case-sensitive; callers
// pass an already lower-cased base name.
func isWebshellName(name string) bool {
	switch name {
	case "h4x0r.php", "c99.php", "r57.php",
		"wso.php", "alfa.php", "b374k.php",
		"shell.php", "cmd.php", "backdoor.php",
		"webshell.php", "hack.php", "0x.php",
		"up.php", "uploader.php", "filemanager.php":
		return true
	default:
		return false
	}
}
// isSuspiciousPHPName reports whether a (lower-cased) PHP file name looks
// attacker-chosen. It is a name-only heuristic with two triggers:
//
//  1. the name contains a keyword commonly used for shells and droppers
//     ("shell", "cmd", "eval", ...);
//  2. the name without its extension is at most five characters long and
//     contains a digit (e.g. "a1.php", "x9.phtml").
//
// For the second trigger the extension is stripped for every PHP-family
// extension (anything starting with ".ph": .php, .php7, .phtml, .phar, ...),
// not just ".php", so renaming a dropper to another PHP-executable extension
// does not switch the check off. Names with any other extension are measured
// whole.
func isSuspiciousPHPName(name string) bool {
	suspicious := []string{
		"shell", "cmd", "exec", "hack", "backdoor", "upload",
		"exploit", "reverse", "connect", "proxy", "tunnel",
		"0x", "x0", "eval", "assert", "passthru",
	}
	for _, s := range suspicious {
		if strings.Contains(name, s) {
			return true
		}
	}

	stem := name
	if ext := filepath.Ext(name); strings.HasPrefix(ext, ".ph") {
		stem = strings.TrimSuffix(name, ext)
	}
	return len(stem) <= 5 && strings.ContainsAny(stem, "0123456789")
}
// writeIndex stores entries, one per line, at path.
//
// The data is written to path+".tmp" and renamed into place, so path only
// ever holds a complete index. The operation is best effort (no error is
// returned), but any failure while writing, flushing, closing or renaming
// aborts it and removes the temp file: a truncated index is never promoted,
// because the caller would diff against it and then adopt it as the baseline.
func writeIndex(path string, entries []string) {
	tmpPath := path + ".tmp"
	// #nosec G304 -- path is filepath.Join under operator-configured StatePath.
	f, err := os.Create(tmpPath)
	if err != nil {
		return
	}
	w := bufio.NewWriter(f)
	for _, e := range entries {
		// bufio.Writer remembers the first write error and reports it from
		// Flush, so the per-line result can be ignored here.
		_, _ = w.WriteString(e + "\n")
	}
	flushErr := w.Flush()
	closeErr := f.Close()
	if flushErr != nil || closeErr != nil {
		_ = os.Remove(tmpPath)
		return
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath)
	}
}
// loadIndex reads an index file written by writeIndex and returns its
// non-empty lines in file order. It returns nil when the file cannot be
// opened. Lines are limited to 1 MiB; scanning simply stops at the first
// read error or over-long line and the entries collected so far are returned.
func loadIndex(path string) []string {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	const maxLine = 1024 * 1024
	lines := bufio.NewScanner(f)
	lines.Buffer(make([]byte, 0, maxLine), maxLine)

	var entries []string
	for lines.Scan() {
		if text := lines.Text(); text != "" {
			entries = append(entries, text)
		}
	}
	return entries
}
// copyFile copies the whole content of src to dst, creating dst with mode
// 0600 if it does not exist. It is best effort: nothing is written when src
// cannot be read, and a write error is ignored.
func copyFile(src, dst string) {
	if data, err := osFS.ReadFile(src); err == nil {
		_ = os.WriteFile(dst, data, 0600)
	}
}
package checks
import (
"fmt"
"github.com/pidginhost/csm/internal/alert"
)
// classifyUploadPHP returns (severity, check, message) for a new PHP file
// under wp-content/uploads. Only the file CONTENT is consulted, never its
// path or name: skipping a file because it sits under /cache/ or is called
// index.php is exactly how a webshell hides in a "safe" location.
//
// Outcomes, in priority order:
//   - a content indicator keeps the severity and check name reported by
//     analyzePHPContent;
//   - an unreadable or zero-byte body fails closed at High
//     (new_php_in_uploads), so racing the scanner with `rm` or chmod 000 does
//     not earn a demotion;
//   - a content-verified inert stub returns a negative severity, meaning
//     "suppress";
//   - any other real code is a Warning under new_php_in_uploads_clean, a
//     visibility-only check name.
//
// The decision ladder parallels classifySensitiveDirPHP so the two
// anomalous-PHP-location detectors behave identically.
func classifyUploadPHP(path string) (alert.Severity, string, string) {
	verdict := analyzePHPContent(path)
	switch {
	case verdict.severity >= 0:
		return verdict.severity, verdict.check, fmt.Sprintf("%s: %s", verdict.message, path)
	case !verdict.readOK:
		return alert.High, "new_php_in_uploads", fmt.Sprintf("New unreadable PHP file in uploads: %s", path)
	case verdict.empty:
		return alert.High, "new_php_in_uploads", fmt.Sprintf("New empty PHP file in uploads: %s", path)
	case IsBenignPHPStub(path):
		return -1, "", ""
	default:
		return alert.Warning, "new_php_in_uploads_clean", fmt.Sprintf("New PHP file in uploads (content clean): %s", path)
	}
}
package checks
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFilesystem uses globs and targeted ReadDir to check for backdoors,
// hidden files, and SUID binaries. No `find` command needed.
//
// It runs three passes and returns their findings concatenated:
//  1. known backdoor binary names under each home's .config (Critical,
//     "backdoor_binary");
//  2. only when AccountFromContext returns "": hidden files directly in
//     /tmp, /dev/shm and /var/tmp (High, "suspicious_file") and SUID files in
//     those trees;
//  3. SUID files in the scanned home directories (shallow).
//
// A nil ctx is replaced with context.Background(). Cancellation is checked
// between steps; on cancellation the findings collected so far are returned.
// The state-store parameter is unused.
func CheckFilesystem(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
if ctx == nil {
ctx = context.Background()
}
var findings []alert.Finding
// GSocket / backdoor binaries in .config dirs - glob (instant).
// Rank by mtime desc so recently-touched accounts process first
// when the check timeout cuts iteration short.
backdoorNames := map[string]bool{
"defunct": true, "defunct.dat": true, "gs-netcat": true,
"gs-sftp": true, "gs-mount": true, "gsocket": true,
}
configGlobs := [][]string{
{".config", "htop", "*"},
{".config", "*", "*"},
}
// The htop glob is a subset of the wider .config glob. Deduplicate
// before ranking so the per-account cap applies to this scanner once.
configCandidates := make([]string, 0)
seenConfigCandidate := make(map[string]struct{})
for _, pattern := range configGlobs {
if ctx.Err() != nil {
return findings
}
// A glob error is ignored; whatever matches were returned are used.
matches, _ := homeGlob(ctx, pattern...)
for _, path := range matches {
if ctx.Err() != nil {
return findings
}
if backdoorNames[filepath.Base(path)] {
if _, seen := seenConfigCandidate[path]; seen {
continue
}
seenConfigCandidate[path] = struct{}{}
configCandidates = append(configCandidates, path)
}
}
}
rankedConfigCandidates := rankPathsByMtimeDesc(ctx, configCandidates, effectiveAccountScanMaxFiles(cfg))
if ctx.Err() != nil {
return findings
}
for _, path := range rankedConfigCandidates {
if ctx.Err() != nil {
return findings
}
// Details are best effort: a failed stat still produces the finding,
// just without size/mtime.
info, _ := osFS.Stat(path)
var details string
if info != nil {
details = fmt.Sprintf("Size: %d bytes, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
}
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "backdoor_binary",
Message: fmt.Sprintf("Backdoor binary found: %s", path),
Details: details,
FilePath: path,
})
}
if AccountFromContext(ctx) == "" {
// Hidden files in /tmp, /dev/shm, /var/tmp - glob (instant)
// Names starting with one of these prefixes are skipped as known-benign.
safeHiddenPrefixes := []string{
".s.PGSQL", ".font-unix", ".ICE-unix", ".X11-unix",
".XIM-unix", ".crontab.", ".Test-unix",
}
for _, pattern := range []string{"/tmp/.*", "/dev/shm/.*", "/var/tmp/.*"} {
if ctx.Err() != nil {
return findings
}
matches, _ := osFS.Glob(pattern)
candidates := make([]string, 0, len(matches))
for _, match := range matches {
if ctx.Err() != nil {
return findings
}
base := filepath.Base(match)
safe := false
for _, prefix := range safeHiddenPrefixes {
if strings.HasPrefix(base, prefix) {
safe = true
break
}
}
if safe {
continue
}
candidates = append(candidates, match)
}
// These are global temp locations, not account paths; do not let
// account_scan_max_files hide older suspicious files here.
ranked := rankPathsByMtimeDesc(ctx, candidates, 0)
if ctx.Err() != nil {
return findings
}
for _, match := range ranked {
if ctx.Err() != nil {
return findings
}
// Hidden directories and entries that can no longer be stat'ed are
// not reported.
info, err := osFS.Stat(match)
if err != nil || info.IsDir() {
continue
}
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "suspicious_file",
Message: fmt.Sprintf("Suspicious hidden file: %s", match),
Details: fmt.Sprintf("Size: %d, Mtime: %s", info.Size(), info.ModTime()),
FilePath: match,
})
}
}
// SUID binaries in tmp dirs - ReadDir + stat (small dirs, fast)
for _, dir := range []string{"/tmp", "/var/tmp", "/dev/shm"} {
if ctx.Err() != nil {
return findings
}
scanForSUID(ctx, dir, 3, &findings)
}
}
// SUID in /home - shallow scan only
if ctx.Err() != nil {
return findings
}
// An error from GetScanHomeDirs is ignored; the loop then runs over
// whatever entries were returned.
homeDirs, _ := GetScanHomeDirs(ctx)
for _, entry := range homeDirs {
if ctx.Err() != nil {
return findings
}
if !entry.IsDir() {
continue
}
scanForSUID(ctx, filepath.Join("/home", entry.Name()), 3, &findings)
}
return findings
}
// scanForSUID walks dir up to maxDepth levels with ReadDir and appends a
// Critical "suid_binary" finding to *findings for every non-directory entry
// whose mode has the setuid bit. Directories named virtfs, mail and
// public_html are not descended into. Unreadable directories and entries
// whose metadata cannot be fetched are skipped; the walk stops silently once
// ctx is cancelled.
func scanForSUID(ctx context.Context, dir string, maxDepth int, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		childPath := filepath.Join(dir, child.Name())
		if child.IsDir() {
			switch child.Name() {
			case "virtfs", "mail", "public_html":
				// Skip virtfs and known large dirs
			default:
				scanForSUID(ctx, childPath, maxDepth-1, findings)
			}
			continue
		}
		meta, err := child.Info()
		if err != nil || meta.Mode()&os.ModeSetuid == 0 {
			continue
		}
		*findings = append(*findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "suid_binary",
			Message:  fmt.Sprintf("SUID binary in unusual location: %s", childPath),
			Details:  fmt.Sprintf("Mode: %s, Size: %d", meta.Mode(), meta.Size()),
			FilePath: childPath,
		})
	}
}
// CheckWebshells uses pure Go ReadDir to scan for known webshell files
// and directories. No `find` command needed.
//
// For every scanned home directory it walks public_html plus every other
// non-hidden top-level directory (except mail, etc, logs, ssl and tmp) as a
// potential document root, up to 8 levels deep, and returns the findings
// produced by scanForWebshells. On cancellation the findings collected so
// far are returned. The state-store parameter is unused.
//
// A nil ctx is replaced with context.Background(), matching CheckFilesystem
// and CheckFileIndex; without the guard a nil ctx would panic on the first
// ctx.Err() call.
func CheckWebshells(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	var findings []alert.Finding
	webshellNames := map[string]bool{
		"h4x0r.php": true, "c99.php": true, "r57.php": true,
		"wso.php": true, "alfa.php": true, "b374k.php": true,
		"mini.php": true, "adminer.php": true,
	}
	webshellDirs := map[string]bool{
		"LEVIATHAN": true, "haxorcgiapi": true,
	}
	// Scan each user's public_html and addon domains
	homeDirs, _ := GetScanHomeDirs(ctx)
	for _, homeEntry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !homeEntry.IsDir() {
			continue
		}
		homeDir := filepath.Join("/home", homeEntry.Name())
		// Get all potential document roots
		docRoots := []string{filepath.Join(homeDir, "public_html")}
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sd := range subDirs {
			if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
				!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
				sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
				docRoots = append(docRoots, filepath.Join(homeDir, sd.Name()))
			}
		}
		for _, docRoot := range docRoots {
			scanForWebshells(ctx, docRoot, 8, webshellNames, webshellDirs, cfg, &findings)
			if ctx.Err() != nil {
				return findings
			}
		}
	}
	return findings
}
// scanForWebshells walks dir up to maxDepth levels with ReadDir (no stat
// unless something matches) and appends findings to *findings for:
//   - directories whose exact name is in dirs (Critical, "webshell");
//   - files whose lower-cased name is in names (Critical, "webshell");
//   - files ending in .haxor or .cgix (Critical, "webshell");
//   - PHP-executable files that are world-writable (High,
//     "world_writable_php").
//
// Any entry matching one of cfg.Suppressions.IgnorePaths is skipped entirely;
// a suppressed directory is neither reported nor descended into. The walk
// stops silently once ctx is cancelled.
func scanForWebshells(ctx context.Context, dir string, maxDepth int, names map[string]bool, dirs map[string]bool, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	isSuppressed := func(p string) bool {
		for _, pattern := range cfg.Suppressions.IgnorePaths {
			if matchGlob(p, pattern) {
				return true
			}
		}
		return false
	}

	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		childName := child.Name()
		childPath := filepath.Join(dir, childName)
		if isSuppressed(childPath) {
			continue
		}

		if child.IsDir() {
			if dirs[childName] {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "webshell",
					Message:  fmt.Sprintf("Webshell directory found: %s", childPath),
					FilePath: childPath,
				})
			}
			scanForWebshells(ctx, childPath, maxDepth-1, names, dirs, cfg, findings)
			continue
		}

		lowered := strings.ToLower(childName)

		if names[lowered] {
			// Size/mtime are best effort; the finding is raised either way.
			var details string
			if st, _ := osFS.Stat(childPath); st != nil {
				details = fmt.Sprintf("Size: %d, Mtime: %s", st.Size(), st.ModTime())
			}
			*findings = append(*findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "webshell",
				Message:  fmt.Sprintf("Known webshell found: %s", childPath),
				Details:  details,
				FilePath: childPath,
			})
		}

		// .haxor extension
		if strings.HasSuffix(lowered, ".haxor") || strings.HasSuffix(lowered, ".cgix") {
			*findings = append(*findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "webshell",
				Message:  fmt.Sprintf("Suspicious CGI file: %s", childPath),
				FilePath: childPath,
			})
		}

		// File permission anomalies - only PHP-executable files are checked to
		// keep the walk fast.
		if !isExecutablePHPName(lowered) {
			continue
		}
		meta, err := child.Info()
		if err != nil {
			continue
		}
		// World-writable PHP. The executable bit is deliberately not checked:
		// most PHP files on cPanel have +x due to suPHP/lsapi, which made
		// that check too noisy.
		if perm := meta.Mode(); perm&0002 != 0 {
			*findings = append(*findings, alert.Finding{
				Severity: alert.High,
				Check:    "world_writable_php",
				Message:  fmt.Sprintf("World-writable PHP file: %s", childPath),
				Details:  fmt.Sprintf("Mode: %s", perm),
				FilePath: childPath,
			})
		}
	}
}
// matchGlob reports whether path is covered by an operator suppression
// pattern. An empty pattern never matches.
//
// A pattern without wildcard characters ("*", "?", "[") is a literal
// substring test, so an operator can suppress a directory ("/uploads/") or a
// filename ("adminer.php").
//
// A pattern with wildcards is tried, in order, as:
//  1. a filepath.Match glob against the basename ("*.php", "*.log") and then
//     against the full path;
//  2. for a leading any-depth glob made only of "*" wildcards, a substring
//     match of the pattern with every "*" removed -- but ONLY when that
//     residue still contains a path separator plus literal content (e.g.
//     "*/node_modules/*" -> "/node_modules/"). This keeps the "directory
//     anywhere in the path" intent without widening anchored globs such as
//     "/tmp/safe/*" into recursive subtree suppressions.
//
// The separator requirement guards against over-suppression: stripping the
// "*" from "*.php" leaves the bare token ".php", and substring-matching that
// would silence every path that merely contains ".php" -- a whole-subtree
// allowlist an attacker could hide a webshell in.
func matchGlob(path, pattern string) bool {
	if pattern == "" {
		return false
	}
	if !strings.ContainsAny(pattern, "*?[") {
		return strings.Contains(path, pattern)
	}

	for _, candidate := range [...]string{filepath.Base(path), path} {
		if ok, _ := filepath.Match(pattern, candidate); ok {
			return true
		}
	}

	if strings.ContainsAny(pattern, "?[") || !hasLeadingAnyDepthGlob(pattern) {
		return false
	}
	residue := strings.ReplaceAll(pattern, "*", "")
	hasLiteralDir := strings.Contains(residue, "/") && strings.Trim(residue, "/") != ""
	return hasLiteralDir && strings.Contains(path, residue)
}
// hasLeadingAnyDepthGlob reports whether pattern begins with one or more "*"
// immediately followed by a "/" (for example "*/node_modules/*" or "**/x").
// Patterns without a slash, with a leading slash, or with any other character
// before the first slash return false.
func hasLeadingAnyDepthGlob(pattern string) bool {
	slash := strings.IndexByte(pattern, '/')
	if slash < 1 {
		return false
	}
	head := pattern[:slash]
	return strings.Count(head, "*") == len(head)
}
package checks
import (
"context"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFirewall verifies the CSM-managed nftables table and returns findings
// for anything unexpected. When cfg.Firewall.Enabled is false the firewall is
// not managed by CSM and no checks run. ctx is currently unused.
//
// Checks, in order:
//   - `nft list table inet csm` must succeed, otherwise a single Critical
//     finding is returned and nothing else is checked;
//   - the listing must mention every expected chain and set (High each);
//   - a hash of the listing without dynamic set-element lines is compared
//     with the hash stored under "_nftables_rules_hash"; a difference is
//     reported as an out-of-band modification (High). The stored hash is
//     always replaced with the current one;
//   - configured TCP_IN ports are screened by checkDangerousPorts.
func CheckFirewall(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	if !cfg.Firewall.Enabled {
		// Firewall not managed by CSM - skip nftables checks
		return findings
	}

	// Routed through cmdExec so tests can mock the nft response without
	// requiring a real nftables stack.
	raw, err := cmdExec.RunAllowNonZero("nft", "list", "table", "inet", "csm")
	if err != nil {
		return append(findings, alert.Finding{
			Severity:  alert.Critical,
			Check:     "firewall",
			Message:   "CSM firewall table not found in nftables - rules may not be active",
			Timestamp: time.Now(),
		})
	}
	listing := string(raw)

	for _, component := range []string{"chain input", "chain output", "set blocked_ips", "set allowed_ips", "set infra_ips"} {
		if strings.Contains(listing, component) {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "firewall",
			Message:   fmt.Sprintf("Firewall missing expected component: %s", component),
			Timestamp: time.Now(),
		})
	}

	// Hash only the rule structure. Set elements change on every
	// block/unblock, so their lines are left out; what remains changes only
	// when chains or rules are added or removed.
	var stable []byte
	for _, line := range strings.Split(listing, "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case strings.HasPrefix(trimmed, "elements"), strings.Contains(trimmed, "expires"):
			// "elements = {" openers and entries carrying a timeout
			continue
		case len(trimmed) > 0 && trimmed[0] >= '0' && trimmed[0] <= '9':
			// element continuation lines (start with a digit)
			continue
		}
		stable = append(stable, line...)
		stable = append(stable, '\n')
	}
	digest := hashBytes(stable)

	if prior, known := store.GetRaw("_nftables_rules_hash"); known && prior != digest {
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "firewall",
			Message:   "nftables ruleset modified outside of CSM",
			Timestamp: time.Now(),
		})
	}
	store.SetRaw("_nftables_rules_hash", digest)

	// Check for dangerous ports in config
	return append(findings, checkDangerousPorts(cfg)...)
}
// checkDangerousPorts screens the publicly opened TCP ports
// (cfg.Firewall.TCPIn) against two lists and returns a High "firewall_ports"
// finding for each hit: ports listed in cfg.BackdoorPorts, and ports listed
// in cfg.Firewall.RestrictedTCP (meant to be infra-only). A port on both
// lists yields two findings, backdoor first.
func checkDangerousPorts(cfg *config.Config) []alert.Finding {
	backdoor := make(map[int]bool)
	for _, port := range cfg.BackdoorPorts {
		backdoor[port] = true
	}
	infraOnly := make(map[int]bool)
	for _, port := range cfg.Firewall.RestrictedTCP {
		infraOnly[port] = true
	}

	var findings []alert.Finding
	report := func(message string) {
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "firewall_ports",
			Message:   message,
			Timestamp: time.Now(),
		})
	}
	for _, open := range cfg.Firewall.TCPIn {
		if backdoor[open] {
			report(fmt.Sprintf("Known backdoor port %d is open in firewall TCP_IN", open))
		}
		if infraOnly[open] {
			report(fmt.Sprintf("Restricted port %d found in public TCP_IN - should be infra-only", open))
		}
	}
	return findings
}
package checks
import (
"bufio"
"context"
"crypto/sha256"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// parseValiasLine splits one valiases line of the form
// "local_part: destination" at the first colon and returns both halves with
// surrounding whitespace removed. Blank lines, "#" comments and lines without
// a colon yield two empty strings.
func parseValiasLine(line string) (localPart, dest string) {
	trimmed := strings.TrimSpace(line)
	if trimmed == "" || trimmed[0] == '#' {
		return "", ""
	}
	halves := strings.SplitN(trimmed, ":", 2)
	if len(halves) != 2 {
		return "", ""
	}
	return strings.TrimSpace(halves[0]), strings.TrimSpace(halves[1])
}
// isPipeForwarder returns true if the destination is a pipe forwarder
// ("|command args..."), excluding known-safe cPanel built-in pipes
// (autoresponder, BoxTrapper, mailman).
//
// The allowlist is matched against the piped COMMAND only: the first
// whitespace-separated token after the "|", with surrounding quotes removed
// and the path cleaned. Matching the allowlisted path anywhere in the
// destination would let an attacker whitelist an arbitrary program simply by
// mentioning a safe path in its arguments, e.g.
// "|/home/u/evil.sh /usr/local/cpanel/bin/autorespond", or by path traversal
// through a safe prefix.
func isPipeForwarder(dest string) bool {
	if !strings.HasPrefix(dest, "|") {
		return false
	}
	fields := strings.Fields(dest[1:])
	if len(fields) == 0 {
		// A bare "|" is still a pipe destination; nothing vouches for it.
		return true
	}
	command := filepath.Clean(strings.Trim(fields[0], `"'`))
	safe := []string{
		"/usr/local/cpanel/bin/autorespond",
		"/usr/local/cpanel/bin/boxtrapper",
		"/usr/local/cpanel/bin/mailman",
	}
	for _, s := range safe {
		if command == s {
			return false
		}
	}
	return true
}
// isDevNullForwarder reports whether the forwarder destination is exactly
// "/dev/null", i.e. the mail is discarded. No trimming or case folding is
// applied.
func isDevNullForwarder(dest string) bool {
	const discard = "/dev/null"
	return dest == discard
}
// isExternalDest reports whether dest is an email address whose domain (the
// text after the last "@", lower-cased) is not marked true in localDomains.
// A destination without "@" or with nothing after it is not an address and
// returns false.
func isExternalDest(dest string, localDomains map[string]bool) bool {
	at := strings.LastIndex(dest, "@")
	if at == -1 {
		return false
	}
	domain := dest[at+1:]
	if domain == "" {
		return false
	}
	isLocal := localDomains[strings.ToLower(domain)]
	return !isLocal
}
// parseVfilterExternalDests collects the external email destinations named
// in vfilter content. Only lines that (after trimming) begin with "to "
// or "to\t" are considered; the destination is the text between the first
// pair of double quotes on such a line. Destinations are kept when
// isExternalDest classifies them as external, in file order, duplicates
// included.
func parseVfilterExternalDests(content string, localDomains map[string]bool) []string {
	var found []string
	for _, raw := range strings.Split(content, "\n") {
		directive := strings.TrimSpace(raw)
		if !(strings.HasPrefix(directive, "to ") || strings.HasPrefix(directive, "to\t")) {
			continue
		}
		// Locate the opening quote, then the closing quote after it.
		open := strings.Index(directive, `"`)
		if open == -1 {
			continue
		}
		quoted := directive[open+1:]
		closing := strings.Index(quoted, `"`)
		if closing == -1 {
			continue
		}
		if addr := quoted[:closing]; isExternalDest(addr, localDomains) {
			found = append(found, addr)
		}
	}
	return found
}
// parseLocalDomainsContent turns the text of a local-domains style file
// into a set of lower-cased domain names. Blank lines and "#" comments are
// ignored. For "domain: user" lines (the virtualdomains layout) only the
// part before the first colon is kept; a colon in the first column leaves
// the line untouched.
func parseLocalDomainsContent(content string) map[string]bool {
	set := make(map[string]bool)
	for _, raw := range strings.Split(content, "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || entry[0] == '#' {
			continue
		}
		if colon := strings.Index(entry, ":"); colon > 0 {
			entry = strings.TrimSpace(entry[:colon])
		}
		set[strings.ToLower(entry)] = true
	}
	return set
}
// loadLocalDomains builds the set of locally hosted domains by merging the
// parsed contents of /etc/localdomains and /etc/virtualdomains. A file that
// cannot be read is skipped, so the result may be empty but is never nil.
func loadLocalDomains() map[string]bool {
	merged := make(map[string]bool)
	sources := [...]string{"/etc/localdomains", "/etc/virtualdomains"}
	for i := range sources {
		raw, err := osFS.ReadFile(sources[i])
		if err != nil {
			continue
		}
		for name, present := range parseLocalDomainsContent(string(raw)) {
			merged[name] = present
		}
	}
	return merged
}
// isKnownForwarder reports whether the rule "localPart@domain: dest" appears
// in knownForwarders. Each configured entry is trimmed of surrounding
// whitespace and compared case-insensitively against that exact rendering.
func isKnownForwarder(localPart, domain, dest string, knownForwarders []string) bool {
	want := fmt.Sprintf("%s@%s: %s", localPart, domain, dest)
	for i := range knownForwarders {
		candidate := strings.TrimSpace(knownForwarders[i])
		if strings.EqualFold(candidate, want) {
			return true
		}
	}
	return false
}
// fileContentHash reads the file at path through osFS and returns the
// lowercase hexadecimal SHA-256 digest of its bytes. A read failure is
// returned unchanged together with an empty string.
func fileContentHash(path string) (string, error) {
	content, err := osFS.ReadFile(path)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", sha256.Sum256(content)), nil
}
// CheckForwarders audits every file matched by /etc/valiases/* and
// /etc/vfilters/* for dangerous forwarder patterns and returns the findings
// produced by auditValiasFileWithStatus and auditVfilterFileWithStatus.
//
// It returns nil when no global store is available. Unless ForceAll is set it
// also returns nil when the "email:fwd_last_refresh" marker is younger than
// cfg.EmailProtection.PasswordCheckIntervalMin minutes (the password-check
// interval doubles as this check's throttle). A nil ctx is replaced with
// context.Background(). ctx is re-checked between steps; once it is cancelled
// the function returns early with whatever was collected so far and does not
// advance the refresh marker. The *state.Store argument is unused.
func CheckForwarders(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
if ctx == nil {
ctx = context.Background()
}
db := store.Global()
if db == nil {
return nil
}
// Internal throttle: the interval is PasswordCheckIntervalMin, interpreted
// in minutes. An unparsable or missing marker never throttles.
if !ForceAll {
lastRefreshStr := db.GetMetaString("email:fwd_last_refresh")
if lastRefreshStr != "" {
if lastRefresh, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
interval := time.Duration(cfg.EmailProtection.PasswordCheckIntervalMin) * time.Minute
if time.Since(lastRefresh) < interval {
return nil
}
}
}
}
if ctx.Err() != nil {
return nil
}
localDomains := loadLocalDomains()
var findings []alert.Finding
// Audit valiases. Rank by mtime desc so recently-changed mail
// domains process first when the check timeout cuts iteration short.
if ctx.Err() != nil {
return findings
}
maxFiles := effectiveAccountScanMaxFiles(cfg)
// baselineComplete stays true only if both globs succeed, neither file list
// exceeds maxFiles, and every per-file audit reports success.
baselineComplete := true
valiasFiles, err := osFS.Glob("/etc/valiases/*")
if err != nil {
baselineComplete = false
}
baselineComplete = baselineComplete && scanCoversAllFiles(valiasFiles, maxFiles)
rankedValiasFiles := rankPathsByMtimeDesc(ctx, valiasFiles, maxFiles)
if ctx.Err() != nil {
return findings
}
for _, path := range rankedValiasFiles {
if ctx.Err() != nil {
return findings
}
// The file name under /etc/valiases is used as the mail domain.
domain := filepath.Base(path)
entries, ok := auditValiasFileWithStatus(path, domain, localDomains, cfg)
if !ok {
baselineComplete = false
}
findings = append(findings, entries...)
}
// Audit vfilters, following the same glob / rank / per-file pattern.
if ctx.Err() != nil {
return findings
}
vfilterFiles, err := osFS.Glob("/etc/vfilters/*")
if err != nil {
baselineComplete = false
}
baselineComplete = baselineComplete && scanCoversAllFiles(vfilterFiles, maxFiles)
rankedVfilterFiles := rankPathsByMtimeDesc(ctx, vfilterFiles, maxFiles)
if ctx.Err() != nil {
return findings
}
for _, path := range rankedVfilterFiles {
if ctx.Err() != nil {
return findings
}
domain := filepath.Base(path)
entries, ok := auditVfilterFileWithStatus(path, domain, localDomains, cfg)
if !ok {
baselineComplete = false
}
findings = append(findings, entries...)
}
if ctx.Err() != nil {
return findings
}
// Write the refresh marker when this run was complete, or when a marker
// already exists. An incomplete first run therefore leaves the marker
// unset, and forwarderFileIsNew keeps treating files without a stored hash
// as pre-existing backlog.
baselineExists := db.GetMetaString("email:fwd_last_refresh") != ""
if baselineComplete || baselineExists {
_ = db.SetMetaString("email:fwd_last_refresh", time.Now().Format(time.RFC3339))
}
return findings
}
// forwarderFileIsNew decides whether a forwarder/filter file counts as newly
// added or changed.
//
// When a hash is already stored under hashKey, the file is new exactly when
// that hash differs from currentHash. When no hash is stored, the answer
// depends on the baseline marker stored under baselineKey: with no marker
// (before the first complete audit) the file is treated as pre-existing
// backlog and is not new; with a marker present the file appeared after the
// baseline and is new.
func forwarderFileIsNew(db *store.DB, baselineKey, hashKey, currentHash string) bool {
	if previous, known := db.GetForwarderHash(hashKey); known {
		return previous != currentHash
	}
	baselineTaken := db.GetMetaString(baselineKey) != ""
	return baselineTaken
}
// scanCoversAllFiles reports whether a scan limited to maxFiles entries
// still covers every path in paths. A non-positive maxFiles always counts
// as full coverage.
func scanCoversAllFiles(paths []string, maxFiles int) bool {
	if maxFiles <= 0 {
		return true
	}
	return len(paths) <= maxFiles
}
// auditValiasFile audits one valiases file and returns only its findings,
// discarding the completeness flag reported by auditValiasFileWithStatus.
func auditValiasFile(path, domain string, localDomains map[string]bool, cfg *config.Config) []alert.Finding {
	results, _ := auditValiasFileWithStatus(path, domain, localDomains, cfg)
	return results
}
// auditValiasFileWithStatus scans one valiases file for dangerous forwarders.
//
// For every destination of every entry (destinations are comma-separated) it
// skips rules listed in cfg.EmailProtection.KnownForwarders and then reports:
//   - a Critical finding for a pipe forwarder,
//   - a High finding for a /dev/null blackhole,
//   - a High finding for an external destination, but only when the file is
//     considered new/changed according to forwarderFileIsNew.
//
// The boolean result is false when the file cannot be opened, hashed or fully
// scanned, or when the new hash cannot be stored. The stored hash is updated
// only after an otherwise complete scan. Without a global store no hashing
// takes place and nothing is treated as new.
func auditValiasFileWithStatus(path, domain string, localDomains map[string]bool, cfg *config.Config) ([]alert.Finding, bool) {
	file, err := osFS.Open(path)
	if err != nil {
		return nil, false
	}
	defer file.Close()

	var (
		db          = store.Global()
		hashKey     = "valiases:" + domain
		currentHash string
		fileChanged bool
		ok          = true
		results     []alert.Finding
	)
	if db != nil {
		sum, hashErr := fileContentHash(path)
		currentHash = sum
		if hashErr != nil {
			ok = false
		} else {
			fileChanged = forwarderFileIsNew(db, "email:fwd_last_refresh", hashKey, currentHash)
		}
	}

	// The message suffix depends only on the file, not on the entry.
	suffix := ""
	if fileChanged {
		suffix = " (newly added)"
	}

	lines := bufio.NewScanner(file)
	for lines.Scan() {
		localPart, dest := parseValiasLine(lines.Text())
		if localPart == "" || dest == "" {
			continue
		}
		for _, target := range strings.Split(dest, ",") {
			target = strings.TrimSpace(target)
			if target == "" || isKnownForwarder(localPart, domain, target, cfg.EmailProtection.KnownForwarders) {
				continue
			}
			switch {
			case isPipeForwarder(target):
				results = append(results, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_pipe_forwarder",
					Message:  fmt.Sprintf("Pipe forwarder detected: %s@%s -> %s%s", localPart, domain, target, suffix),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s\nPipe forwarders execute arbitrary commands on incoming mail.", domain, localPart, target, path),
				})
			case isDevNullForwarder(target):
				results = append(results, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  fmt.Sprintf("Mail blackhole: %s@%s -> /dev/null%s", localPart, domain, suffix),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: /dev/null\nFile: %s\nAll mail to this address is silently discarded.", domain, localPart, path),
				})
			case fileChanged && isExternalDest(target, localDomains):
				var message string
				if localPart == "*" {
					message = fmt.Sprintf("Wildcard catch-all to external: *@%s -> %s%s", domain, target, suffix)
				} else {
					message = fmt.Sprintf("External forwarder: %s@%s -> %s%s", localPart, domain, target, suffix)
				}
				results = append(results, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  message,
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, target, path),
				})
			}
		}
	}
	if lines.Err() != nil {
		ok = false
	}
	// Persist the hash only after a complete pass.
	if ok && db != nil {
		if db.SetForwarderHash(hashKey, currentHash) != nil {
			ok = false
		}
	}
	return results, ok
}
// auditVfilterFile audits one vfilters file and returns only its findings,
// discarding the completeness flag reported by auditVfilterFileWithStatus.
func auditVfilterFile(path, domain string, localDomains map[string]bool, cfg *config.Config) []alert.Finding {
	results, _ := auditVfilterFileWithStatus(path, domain, localDomains, cfg)
	return results
}
// auditVfilterFileWithStatus scans one vfilters file for external
// destinations.
//
// The file's SHA-256 is compared with the stored hash through
// forwarderFileIsNew and then stored. External destinations are reported (as
// High "email_suspicious_forwarder" findings) only when the file is
// considered new/changed and the rule "*@domain: dest" is not listed in
// cfg.EmailProtection.KnownForwarders. The boolean result is false when the
// file cannot be read or the hash cannot be stored. Without a global store
// nothing is treated as new, so no findings are produced.
func auditVfilterFileWithStatus(path, domain string, localDomains map[string]bool, cfg *config.Config) ([]alert.Finding, bool) {
	raw, err := osFS.ReadFile(path)
	if err != nil {
		return nil, false
	}

	ok := true
	changed := false
	if db := store.Global(); db != nil {
		hashKey := "vfilters:" + domain
		digest := fmt.Sprintf("%x", sha256.Sum256(raw))
		changed = forwarderFileIsNew(db, "email:fwd_last_refresh", hashKey, digest)
		if db.SetForwarderHash(hashKey, digest) != nil {
			ok = false
		}
	}

	var results []alert.Finding
	for _, addr := range parseVfilterExternalDests(string(raw), localDomains) {
		// vfilter rules have no local part, so "*" stands in for it when
		// consulting the suppression list. Unchanged files stay quiet.
		if isKnownForwarder("*", domain, addr, cfg.EmailProtection.KnownForwarders) || !changed {
			continue
		}
		results = append(results, alert.Finding{
			Severity: alert.High,
			Check:    "email_suspicious_forwarder",
			Message:  fmt.Sprintf("External destination in vfilter: %s -> %s (newly added)", domain, addr),
			Details:  fmt.Sprintf("Domain: %s\nDestination: %s\nFile: %s\nA mail filter rule forwards messages to an external address.", domain, addr, path),
		})
	}
	return results, ok
}
package checks
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckOpenBasedir reports cPanel accounts that lack PHP isolation: accounts
// for which CageFS counts as disabled and for which hasOpenBasedir finds no
// open_basedir setting. When the CageFS mode is "unknown" and no disabled
// users were listed, every scanned account is treated as CageFS-disabled.
// The config and state arguments are unused.
func CheckOpenBasedir(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	mode := getCageFSMode()
	unprotected := getCageFSDisabledUsers(mode)
	accounts := cPanelUsersForOpenBasedirScan(ctx)

	if mode == "unknown" && len(unprotected) == 0 {
		for i := range accounts {
			unprotected[accounts[i]] = true
		}
	}

	var findings []alert.Finding
	for _, account := range accounts {
		// Accounts covered by CageFS, or with open_basedir set, are fine.
		if !unprotected[account] || hasOpenBasedir(account) {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "open_basedir",
			Message:  fmt.Sprintf("Account %s has no PHP isolation: CageFS disabled and no open_basedir", account),
			Details:  "This account's PHP scripts can read any file on the server including other accounts' wp-config.php and /etc/shadow",
		})
	}
	return findings
}
// cPanelUsersForOpenBasedirScan returns the accounts to examine. If ctx
// carries an account (per AccountFromContext), the result is that single
// account, or nil when it has no entry under /var/cpanel/users. Otherwise
// the result is the name of every entry in /var/cpanel/users; a failure to
// list the directory yields an empty slice.
func cPanelUsersForOpenBasedirScan(ctx context.Context) []string {
	const usersDir = "/var/cpanel/users"

	account := AccountFromContext(ctx)
	if account == "" {
		entries, _ := osFS.ReadDir(usersDir)
		names := make([]string, len(entries))
		for i, entry := range entries {
			names[i] = entry.Name()
		}
		return names
	}

	if _, err := osFS.Stat(filepath.Join(usersDir, account)); err != nil {
		return nil
	}
	return []string{account}
}
// getCageFSMode returns "enabled" when /etc/cagefs/cagefs.mp can be read and
// contains something other than whitespace, and "unknown" in every other
// case.
func getCageFSMode() string {
	raw, err := osFS.ReadFile("/etc/cagefs/cagefs.mp")
	if err == nil && strings.TrimSpace(string(raw)) != "" {
		return "enabled"
	}
	return "unknown"
}
// getCageFSDisabledUsers returns the set of users for which CageFS counts as
// disabled. The set starts with the non-empty lines printed by
// "cagefsctl --list-disabled" (the command's error is ignored). If that
// yields nothing and mode is "unknown", every entry of /var/cpanel/users is
// added instead.
func getCageFSDisabledUsers(mode string) map[string]bool {
	disabled := make(map[string]bool)

	if out, _ := runCmd("cagefsctl", "--list-disabled"); out != nil {
		for _, raw := range strings.Split(strings.TrimSpace(string(out)), "\n") {
			if name := strings.TrimSpace(raw); name != "" {
				disabled[name] = true
			}
		}
	}

	if mode != "unknown" || len(disabled) > 0 {
		return disabled
	}

	// No CageFS information at all: treat every account as disabled.
	entries, _ := osFS.ReadDir("/var/cpanel/users")
	for _, entry := range entries {
		disabled[entry.Name()] = true
	}
	return disabled
}
// hasOpenBasedir reports whether an open_basedir setting is present for user.
// It looks, in order, at:
//   - /home/<user>/public_html/.user.ini (case-insensitive),
//   - /home/<user>/public_html/.htaccess (case-insensitive),
//   - local.ini in each /opt/cpanel/ea-php*/root/etc/php.d/ directory
//     (case-sensitive; NOTE(review): these files are not user-specific, so a
//     match here applies to every account -- confirm this is intended).
//
// Only non-comment lines count: a line whose first non-blank character is
// ";" or "#" is ignored, so a commented-out directive such as
// ";open_basedir = /home/user" no longer makes an unprotected account look
// isolated. Unreadable files are skipped.
func hasOpenBasedir(user string) bool {
	docRoot := filepath.Join("/home", user, "public_html")
	for _, name := range []string{".user.ini", ".htaccess"} {
		data, err := osFS.ReadFile(filepath.Join(docRoot, name))
		if err != nil {
			continue
		}
		if hasActiveOpenBasedirLine(strings.ToLower(string(data))) {
			return true
		}
	}

	phpConfDirs, _ := osFS.Glob("/opt/cpanel/ea-php*/root/etc/php.d/")
	for _, confDir := range phpConfDirs {
		data, err := osFS.ReadFile(filepath.Join(confDir, "local.ini"))
		if err != nil {
			continue
		}
		if hasActiveOpenBasedirLine(string(data)) {
			return true
		}
	}
	return false
}

// hasActiveOpenBasedirLine reports whether any line of content that is not a
// ";" or "#" comment contains the text "open_basedir".
func hasActiveOpenBasedirLine(content string) bool {
	for _, line := range strings.Split(content, "\n") {
		line = strings.TrimSpace(line)
		if line == "" || line[0] == ';' || line[0] == '#' {
			continue
		}
		if strings.Contains(line, "open_basedir") {
			return true
		}
	}
	return false
}
// CheckSymlinkAttacks looks for symbolic links below each account's
// public_html that point outside the account's own home directory. Home
// directories come from GetScanHomeDirs; each one is scanned by
// scanForMaliciousSymlinks with a depth budget of 4. The config and state
// arguments are unused.
func CheckSymlinkAttacks(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding

	entries, _ := GetScanHomeDirs(ctx)
	for _, entry := range entries {
		if entry.IsDir() {
			account := entry.Name()
			home := filepath.Join("/home", account)
			scanForMaliciousSymlinks(filepath.Join(home, "public_html"), account, home, 4, &findings)
		}
	}
	return findings
}
// scanForMaliciousSymlinks walks dir (recursing into real sub-directories
// while maxDepth allows) and appends a Critical "symlink_attack" finding to
// *findings for every symlink whose cleaned target
//   - lies under /home/<other user>/, or
//   - is /etc/shadow, /etc/passwd, or /root (or anything below /root).
//
// Targets accepted by isSymlinkSafe are ignored. Relative targets are
// resolved against the directory containing the link. Symlinks themselves
// are never followed. Unreadable directories and links are skipped.
func scanForMaliciousSymlinks(dir, user, homeDir string, maxDepth int, findings *[]alert.Finding) {
	if maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	// Entries ending in "/" denote a directory tree, the others a file.
	sensitiveTargets := []string{"/etc/shadow", "/etc/passwd", "/root/"}

	for _, entry := range entries {
		fullPath := filepath.Join(dir, entry.Name())

		// Not a symlink: descend into directories, ignore everything else.
		if entry.Type()&os.ModeSymlink == 0 {
			if entry.IsDir() {
				scanForMaliciousSymlinks(fullPath, user, homeDir, maxDepth-1, findings)
			}
			continue
		}

		target, err := osFS.Readlink(fullPath)
		if err != nil {
			continue
		}
		// Resolve relative targets against the link's own directory.
		if !filepath.IsAbs(target) {
			target = filepath.Join(filepath.Dir(fullPath), target)
		}
		target = filepath.Clean(target)

		if isSymlinkSafe(target, user, homeDir) {
			continue
		}

		// Target inside another account's home directory.
		if strings.HasPrefix(target, "/home/") {
			parts := strings.SplitN(target[6:], "/", 2)
			if len(parts) > 0 && parts[0] != user {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "symlink_attack",
					Message:  fmt.Sprintf("Symlink to another user's directory: %s -> %s", fullPath, target),
					Details:  fmt.Sprintf("User: %s, target user: %s\nThis could be used to read other accounts' files", user, parts[0]),
				})
				continue
			}
		}

		// Target is a sensitive system location. filepath.Clean removed any
		// trailing slash from target, so a link to the directory itself
		// ("/root") has to be matched explicitly; the "/root/" prefix alone
		// only catches paths below it.
		for _, sens := range sensitiveTargets {
			if target == strings.TrimSuffix(sens, "/") || strings.HasPrefix(target, sens) {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "symlink_attack",
					Message:  fmt.Sprintf("Symlink to sensitive system file: %s -> %s", fullPath, target),
					Details:  fmt.Sprintf("User: %s", user),
				})
				break
			}
		}
	}
}
// isSymlinkSafe reports whether a symlink target is considered harmless:
// either it is the account's home directory or something below it, or it
// starts with one of a fixed list of system path prefixes. The user
// argument is not consulted.
func isSymlinkSafe(target, user, homeDir string) bool {
	if target == homeDir || strings.HasPrefix(target, homeDir+"/") {
		return true
	}
	// System locations accepted as symlink targets.
	for _, prefix := range [...]string{
		"/etc/apache2/logs/",
		"/usr/local/apache/logs/",
		"/var/cpanel/",
		"/opt/cpanel/",
		"/usr/",
		"/var/lib/mysql/",
		"/var/run/",
	} {
		if strings.HasPrefix(target, prefix) {
			return true
		}
	}
	return false
}
package checks
import (
"bufio"
"context"
"encoding/hex"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// auditCmdTimeout is the per-subprocess timeout that auditRunCmd applies to
// every command it runs. Audit checks are fast config reads, not heavy
// scans, so a command that exceeds this limit is reported as timed out.
const auditCmdTimeout = 10 * time.Second
// RunHardeningAudit executes every hardening check group -- SSH, PHP, web
// server, mail, cPanel (skipped on "bare" servers), OS and firewall, in that
// order -- and bundles the results into a report. Score is the number of
// results whose status is "pass"; Total is the number of results. The cfg
// argument is not consulted here, and the store package is used only for its
// result types.
func RunHardeningAudit(cfg *config.Config) *store.AuditReport {
	serverType := detectServerType()

	var results []store.AuditResult
	collect := func(batch []store.AuditResult) {
		results = append(results, batch...)
	}

	collect(auditSSH())
	collect(auditPHP(serverType))
	collect(auditWebServer(serverType))
	collect(auditMail())
	if serverType != "bare" {
		collect(auditCPanel(serverType))
	}
	collect(auditOS())
	collect(auditFirewall())

	passed := 0
	for i := range results {
		if results[i].Status == "pass" {
			passed++
		}
	}

	return &store.AuditReport{
		Timestamp:  time.Now(),
		ServerType: serverType,
		Results:    results,
		Score:      passed,
		Total:      len(results),
	}
}
// detectServerType classifies the host from platform.Detect(): "bare" when
// it is not a cPanel server, "cloudlinux" when it is cPanel on CloudLinux,
// and "cpanel" for any other cPanel server.
func detectServerType() string {
	info := platform.Detect()
	switch {
	case !info.IsCPanel():
		return "bare"
	case info.OS == platform.OSCloudLinux:
		return "cloudlinux"
	default:
		return "cpanel"
	}
}
// auditRunCmd runs name with args through the cmdExec injector under a
// context limited to auditCmdTimeout, which lets tests substitute the
// command runner instead of invoking real binaries. If the deadline has
// expired by the time the command returns, the output is discarded and a
// "command timed out" error is returned; otherwise the runner's output and
// error are passed through.
func auditRunCmd(name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), auditCmdTimeout)
	defer cancel()

	output, runErr := cmdExec.RunContext(ctx, name, args...)
	switch ctx.Err() {
	case context.DeadlineExceeded:
		return nil, fmt.Errorf("command timed out: %s", name)
	default:
		return output, runErr
	}
}
// --- SSH checks ---
// sshdDefaults are the OpenSSH compiled defaults for settings we audit.
// Keys are lower-case keyword names; sshdEffective falls back to this table
// when a keyword is absent from the parsed configuration.
// NOTE(review): defaults differ between OpenSSH releases -- confirm these
// match the versions deployed.
var sshdDefaults = map[string]string{
"port": "22",
"protocol": "2",
"passwordauthentication": "yes",
"permitrootlogin": "prohibit-password",
"maxauthtries": "6",
"x11forwarding": "no",
"usedns": "no",
}
// sshdConfigPath is the top-level sshd configuration file read by
// parseSSHDConfig. NOTE(review): declared as a variable, presumably so tests
// can point it at a fixture -- confirm.
var sshdConfigPath = "/etc/ssh/sshd_config"
// sshdSettings carries the effective (lower-cased or defaulted) values of
// the three sshd keywords returned by currentSSHDSettings.
type sshdSettings struct {
PasswordAuthentication string
PermitRootLogin string
X11Forwarding string
}
// parseSSHDConfig parses sshdConfigPath, including any Include drop-ins, and
// returns the global settings keyed by lower-case keyword. The first
// occurrence of a keyword wins. Directives inside Match blocks are ignored
// because the audit evaluates the global configuration only. An unreadable
// file produces an empty map.
func parseSSHDConfig() map[string]string {
	settings := map[string]string{}
	parseSSHDFile(sshdConfigPath, settings)
	return settings
}
// parseSSHDFile reads one sshd configuration file and records its global
// directives in effective, keyed by lower-case keyword. Existing keys are
// never overwritten (first match wins). An unreadable file is silently
// skipped.
//
// Keyword and value may be separated by spaces, tabs, or optional whitespace
// around a single "=", as sshd itself accepts. The previous implementation
// only understood a single space or a bare "=", so tab-separated directives
// were dropped and "Key = value" produced the value "= value".
//
// A Match line starts a block that runs until the end of the file; every
// directive inside it is skipped. Include directives (outside Match blocks)
// are expanded in place: each whitespace-separated pattern is globbed,
// relative patterns being resolved against the directory of the including
// file, and every match is parsed recursively.
func parseSSHDFile(path string, effective map[string]string) {
	f, err := osFS.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()

	inMatch := false
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		keyword, value, ok := splitSSHDDirective(line)
		if !ok {
			continue
		}

		// A Match block continues until the next Match keyword or EOF,
		// regardless of indentation (per sshd_config(5)).
		if keyword == "match" {
			inMatch = true
			continue
		}
		if inMatch {
			continue
		}

		if keyword == "include" {
			for _, pattern := range strings.Fields(value) {
				pattern = strings.Trim(pattern, `"`)
				if !filepath.IsAbs(pattern) {
					pattern = filepath.Join(filepath.Dir(path), pattern)
				}
				matches, _ := osFS.Glob(pattern)
				for _, m := range matches {
					parseSSHDFile(m, effective)
				}
			}
			continue
		}

		// First-match-wins: only record the first occurrence.
		if _, exists := effective[keyword]; !exists {
			effective[keyword] = value
		}
	}
}

// splitSSHDDirective splits a trimmed, non-comment sshd_config line into its
// lower-cased keyword and its value. The separator is any run of spaces or
// tabs, optionally containing one "=". ok is false when the line has no
// keyword or no value.
func splitSSHDDirective(line string) (keyword, value string, ok bool) {
	sep := strings.IndexAny(line, " \t=")
	if sep <= 0 {
		return "", "", false
	}
	rest := strings.TrimLeft(line[sep:], " \t")
	rest = strings.TrimPrefix(rest, "=")
	value = strings.TrimSpace(rest)
	if value == "" {
		return "", "", false
	}
	return strings.ToLower(line[:sep]), value, true
}
// sshdEffective returns the effective value for key: the parsed value in
// lower case when the keyword was configured, otherwise the entry from
// sshdDefaults (an empty string if the key has no default).
func sshdEffective(parsed map[string]string, key string) string {
	configured, present := parsed[key]
	if !present {
		return sshdDefaults[key]
	}
	return strings.ToLower(configured)
}
// currentSSHDSettings parses the sshd configuration and returns the
// effective PasswordAuthentication, PermitRootLogin and X11Forwarding
// values.
func currentSSHDSettings() sshdSettings {
	parsed := parseSSHDConfig()
	lookup := func(keyword string) string {
		return sshdEffective(parsed, keyword)
	}
	return sshdSettings{
		PasswordAuthentication: lookup("passwordauthentication"),
		PermitRootLogin:        lookup("permitrootlogin"),
		X11Forwarding:          lookup("x11forwarding"),
	}
}
// auditSSH evaluates the effective global sshd settings and returns one
// result per check, always in this order: ssh_port, ssh_protocol,
// ssh_password_auth, ssh_root_login, ssh_max_auth_tries, ssh_x11_forwarding,
// ssh_use_dns. Results with status "pass" carry no Fix text.
func auditSSH() []store.AuditResult {
	parsed := parseSSHDConfig()

	var results []store.AuditResult
	report := func(name, title, status, message, fix string) {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: name, Title: title,
			Status: status, Message: message,
			Fix: fix,
		})
	}

	// ssh_port
	if port := sshdEffective(parsed, "port"); port == "22" {
		report("ssh_port", "SSH Port", "warn", "SSH is running on default port 22",
			"Change to a non-standard port in /etc/ssh/sshd_config to reduce automated scan noise. Update firewall rules before changing.")
	} else {
		report("ssh_port", "SSH Port", "pass", fmt.Sprintf("SSH on non-standard port %s", port), "")
	}

	// ssh_protocol
	if proto := sshdEffective(parsed, "protocol"); strings.Contains(proto, "1") {
		report("ssh_protocol", "SSH Protocol", "fail", "SSHv1 protocol is enabled",
			"Set 'Protocol 2' in /etc/ssh/sshd_config. SSHv1 has known cryptographic weaknesses.")
	} else {
		report("ssh_protocol", "SSH Protocol", "pass", "SSHv1 disabled", "")
	}

	// ssh_password_auth
	if passAuth := sshdEffective(parsed, "passwordauthentication"); passAuth != "no" {
		report("ssh_password_auth", "SSH PasswordAuthentication", "fail", "Password authentication is enabled",
			"Set 'PasswordAuthentication no' in /etc/ssh/sshd_config and use SSH key authentication only.")
	} else {
		report("ssh_password_auth", "SSH PasswordAuthentication", "pass", "Password authentication disabled", "")
	}

	// ssh_root_login
	if rootLogin := sshdEffective(parsed, "permitrootlogin"); rootLogin == "yes" {
		report("ssh_root_login", "SSH PermitRootLogin", "fail", "Direct root login is permitted",
			"Set 'PermitRootLogin no' or 'PermitRootLogin prohibit-password' in /etc/ssh/sshd_config.")
	} else {
		report("ssh_root_login", "SSH PermitRootLogin", "pass", fmt.Sprintf("PermitRootLogin set to %s", rootLogin), "")
	}

	// ssh_max_auth_tries: an unparsable or zero value falls back to 6.
	tries, _ := strconv.Atoi(sshdEffective(parsed, "maxauthtries"))
	if tries == 0 {
		tries = 6
	}
	if tries > 4 {
		report("ssh_max_auth_tries", "SSH MaxAuthTries", "warn",
			fmt.Sprintf("MaxAuthTries is %d (recommended: 4 or less)", tries),
			"Set 'MaxAuthTries 4' in /etc/ssh/sshd_config to limit brute-force attempts per connection.")
	} else {
		report("ssh_max_auth_tries", "SSH MaxAuthTries", "pass", fmt.Sprintf("MaxAuthTries set to %d", tries), "")
	}

	// ssh_x11_forwarding
	if x11 := sshdEffective(parsed, "x11forwarding"); x11 != "no" {
		report("ssh_x11_forwarding", "SSH X11Forwarding", "warn", "X11 forwarding is enabled",
			"Set 'X11Forwarding no' in /etc/ssh/sshd_config unless X11 forwarding is actively needed.")
	} else {
		report("ssh_x11_forwarding", "SSH X11Forwarding", "pass", "X11 forwarding disabled", "")
	}

	// ssh_use_dns
	if useDNS := sshdEffective(parsed, "usedns"); useDNS != "no" {
		report("ssh_use_dns", "SSH UseDNS", "warn", "UseDNS is enabled",
			"Set 'UseDNS no' in /etc/ssh/sshd_config. Otherwise lfd may not track SSH login failures by IP.")
	} else {
		report("ssh_use_dns", "SSH UseDNS", "pass", "UseDNS disabled", "")
	}

	return results
}
// --- OS hardening checks ---
// auditOS runs the operating-system hardening checks and returns their
// results in this order: /tmp and /var/tmp permissions, /etc/shadow
// permissions, swap (omitted when /proc/swaps cannot be read), distro EOL,
// the nobody crontab, unnecessary services, the algif_aead block check, and
// the table-driven sysctl checks. Values that cannot be read are reported
// with status "warn" rather than "fail".
func auditOS() []store.AuditResult {
var results []store.AuditResult
// /tmp and /var/tmp permissions
for _, dir := range []struct {
path, id, title string
}{
{"/tmp", "os_tmp_permissions", "/tmp Permissions"},
{"/var/tmp", "os_var_tmp_permissions", "/var/tmp Permissions"},
} {
info, err := osFS.Stat(dir.path)
if err != nil {
results = append(results, store.AuditResult{
Category: "os", Name: dir.id, Title: dir.title,
Status: "warn", Message: fmt.Sprintf("Cannot stat %s: %v", dir.path, err),
})
continue
}
// Use the raw Unix mode bits from syscall to get the traditional
// permission representation (sticky=01000, setuid=04000, etc.).
// Go's os.ModeSticky uses high bits that don't map to Unix octal,
// so os.FileMode math produces wrong values for comparison.
// Only the sticky bit plus the nine rwx bits are compared (mask 01777);
// setuid/setgid, which CloudLinux/CageFS may set on virtmp mounts, are
// masked out.
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
results = append(results, store.AuditResult{
Category: "os", Name: dir.id, Title: dir.title,
Status: "warn", Message: fmt.Sprintf("Cannot read ownership of %s", dir.path),
})
continue
}
mode := stat.Mode & 01777 // sticky + rwxrwxrwx, ignore setuid/setgid
if mode != 01777 || stat.Uid != 0 || stat.Gid != 0 {
results = append(results, store.AuditResult{
Category: "os", Name: dir.id, Title: dir.title,
Status: "fail",
Message: fmt.Sprintf("%s has mode %04o uid=%d gid=%d (expected 1777 root:root)", dir.path, mode, stat.Uid, stat.Gid),
Fix: fmt.Sprintf("chmod 1777 %s && chown root:root %s", dir.path, dir.path),
})
} else {
results = append(results, store.AuditResult{
Category: "os", Name: dir.id, Title: dir.title,
Status: "pass", Message: fmt.Sprintf("%s is 1777 root:root", dir.path),
})
}
}
// /etc/shadow permissions
// Accept 0000, 0600 (RHEL/CentOS default), and 0640 (Debian default).
// All three restrict access to root only. 0600 is the standard on
// CentOS/CloudLinux — changing it can break passwd/chage.
if info, err := osFS.Stat("/etc/shadow"); err == nil {
perm := info.Mode().Perm()
if perm == 0 || perm == 0o600 || perm == 0o640 {
results = append(results, store.AuditResult{
Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
Status: "pass", Message: fmt.Sprintf("/etc/shadow has mode %04o", perm),
})
} else {
results = append(results, store.AuditResult{
Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
Status: "fail", Message: fmt.Sprintf("/etc/shadow has mode %04o (expected 0000, 0600, or 0640)", perm),
Fix: "chmod 0600 /etc/shadow",
})
}
} else {
results = append(results, store.AuditResult{
Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
Status: "warn", Message: "Cannot stat /etc/shadow",
})
}
// Swap: the first line of /proc/swaps is treated as a header, so every
// further line counts as one active swap device. No result is emitted when
// the file cannot be read.
if data, err := osFS.ReadFile("/proc/swaps"); err == nil {
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) > 1 {
results = append(results, store.AuditResult{
Category: "os", Name: "os_swap", Title: "Swap Configured",
Status: "pass", Message: fmt.Sprintf("%d swap device(s) active", len(lines)-1),
})
} else {
results = append(results, store.AuditResult{
Category: "os", Name: "os_swap", Title: "Swap Configured",
Status: "warn", Message: "No swap configured",
Fix: "Configure swap space to prevent OOM kills: fallocate -l 2G /swapfile && chmod 600 /swapfile && mkswap /swapfile && swapon /swapfile",
})
}
}
// Distro EOL
results = append(results, checkDistroEOL()...)
// nobody crontab: any stat error is treated the same as the file being
// absent.
if info, err := osFS.Stat("/var/spool/cron/nobody"); err != nil {
// absent is fine
results = append(results, store.AuditResult{
Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
Status: "pass", Message: "No crontab for nobody user",
})
} else if info.Size() == 0 {
results = append(results, store.AuditResult{
Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
Status: "pass", Message: "Nobody crontab is empty",
})
} else {
results = append(results, store.AuditResult{
Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
Status: "fail", Message: "nobody user has a crontab with content",
Fix: "Review and remove: crontab -u nobody -r",
})
}
// Unnecessary services
results = append(results, checkUnnecessaryServices()...)
// CVE-2026-31431 "Copy Fail" — algif_aead is the AF_ALG submodule the
// exploit chains through. Blacklisting it neutralises the attack on
// unpatched kernels.
results = append(results, auditAlgifAEAD())
// Sysctl checks (table-driven): each entry names a /proc/sys file and the
// exact value it must contain.
sysctlChecks := []struct {
id, title, path, expected string
}{
{"os_sysctl_syncookies", "TCP SYN Cookies", "/proc/sys/net/ipv4/tcp_syncookies", "1"},
{"os_sysctl_aslr", "Address Space Layout Randomization", "/proc/sys/kernel/randomize_va_space", "2"},
{"os_sysctl_rp_filter", "Reverse Path Filtering", "/proc/sys/net/ipv4/conf/all/rp_filter", "1"},
{"os_sysctl_icmp_broadcast", "ICMP Broadcast Ignore", "/proc/sys/net/ipv4/icmp_echo_ignore_broadcasts", "1"},
{"os_sysctl_symlinks", "Protected Symlinks", "/proc/sys/fs/protected_symlinks", "1"},
{"os_sysctl_hardlinks", "Protected Hardlinks", "/proc/sys/fs/protected_hardlinks", "1"},
}
for _, sc := range sysctlChecks {
data, err := osFS.ReadFile(sc.path)
if err != nil {
results = append(results, store.AuditResult{
Category: "os", Name: sc.id, Title: sc.title,
Status: "warn", Message: fmt.Sprintf("Cannot read %s", sc.path),
})
continue
}
val := strings.TrimSpace(string(data))
// Convert /proc/sys path to sysctl dotted notation for fix command
sysctlKey := strings.TrimPrefix(sc.path, "/proc/sys/")
sysctlKey = strings.ReplaceAll(sysctlKey, "/", ".")
if val == sc.expected {
results = append(results, store.AuditResult{
Category: "os", Name: sc.id, Title: sc.title,
Status: "pass", Message: fmt.Sprintf("%s = %s", sysctlKey, val),
})
} else {
results = append(results, store.AuditResult{
Category: "os", Name: sc.id, Title: sc.title,
Status: "fail", Message: fmt.Sprintf("%s = %s (expected %s)", sysctlKey, val, sc.expected),
Fix: fmt.Sprintf("sysctl -w %s=%s && echo '%s = %s' >> /etc/sysctl.d/99-csm-hardening.conf", sysctlKey, sc.expected, sysctlKey, sc.expected),
})
}
}
return results
}
// algifAEADBlacklisted reports whether any of the supplied modprobe.d files
// (map of file path to contents) contains a non-comment directive that
// prevents algif_aead from loading. Any one of the following is sufficient:
//
//	blacklist algif_aead
//	install algif_aead /bin/false
//	blacklist af_alg            (parent -- algif_aead depends on it)
//	install af_alg /bin/false   (parent -- same reason)
//
// Blocking only the parent af_alg is accepted because a host hardened that
// way has mitigated Copy Fail without touching the AEAD submodule; reporting
// "no blacklist exists" there would be a false-fail alert.
//
// Module names are compared with "-" and "_" treated as the same character,
// matching modprobe's own behaviour, so "blacklist algif-aead" counts too.
//
// `install <module> /sbin/modprobe --ignore-install <module>` is the
// idiomatic re-load form and does NOT block the module; it is recognised by
// the basename of the replacement command being "modprobe". An install line
// without a replacement command is ignored as malformed.
func algifAEADBlacklisted(confs map[string]string) bool {
	for _, body := range confs {
		for _, line := range strings.Split(body, "\n") {
			line = strings.TrimSpace(line)
			if line == "" || strings.HasPrefix(line, "#") {
				continue
			}
			fields := strings.Fields(line)
			if len(fields) < 2 {
				continue
			}
			// modprobe makes no distinction between "-" and "_" in names.
			module := strings.ReplaceAll(fields[1], "-", "_")
			if module != "algif_aead" && module != "af_alg" {
				continue
			}
			switch fields[0] {
			case "blacklist":
				return true
			case "install":
				if len(fields) < 3 {
					// Malformed (no replacement command). Don't claim a
					// pass on a half-written directive.
					continue
				}
				// Match on the basename of the first token, not a substring
				// anywhere in the line:
				//   /sbin/modprobe --ignore-install <module> -> re-load
				//   /bin/false                               -> block
				//   /usr/local/bin/my-modprobe-wrapper       -> block
				// Substring matching would lump the wrapper in with the
				// re-load case and produce a false-fail alert.
				if filepath.Base(fields[2]) == "modprobe" {
					continue
				}
				return true
			}
		}
	}
	return false
}
// Identifiers for the algif_aead audit check, shared by the pure helper
// (evaluateAlgifAEAD) and the impure wrapper (auditAlgifAEAD). They are
// hoisted to package scope so a future edit cannot drift the two copies out
// of sync, which would silently mismatch the Name/Title fields between the
// pass/fail path and the warn path.
const (
// algifAEADAuditID is used as the AuditResult Name of every result this check emits.
algifAEADAuditID = "os_algif_aead_blocked"
// algifAEADAuditTitle is used as the AuditResult Title of every result this check emits.
algifAEADAuditTitle = "AF_ALG (algif_aead) Blocked — CVE-2026-31431"
)
// evaluateAlgifAEAD is the pure, testable core of the algif_aead hardening
// check. loaded says whether algif_aead is currently a loaded module; confs
// maps each modprobe.d file path to its contents.
//
// Verdicts:
//   - loaded                  -> fail, regardless of any blacklist
//   - not loaded, blocked     -> pass
//   - not loaded, not blocked -> fail (the module can be loaded on demand)
func evaluateAlgifAEAD(loaded bool, confs map[string]string) store.AuditResult {
	blocked := algifAEADBlacklisted(confs)

	// Identity fields are the same for every verdict.
	verdict := store.AuditResult{
		Category: "os", Name: algifAEADAuditID, Title: algifAEADAuditTitle,
	}

	if loaded {
		verdict.Status = "fail"
		verdict.Message = "algif_aead is currently loaded — Copy Fail (CVE-2026-31431) exploitable"
		verdict.Fix = "echo 'install algif_aead /bin/false' > /etc/modprobe.d/csm-disable-algif.conf && modprobe -r algif_aead af_alg"
		return verdict
	}

	if !blocked {
		verdict.Status = "fail"
		verdict.Message = "algif_aead is not loaded but no modprobe.d blacklist exists — module can be loaded on demand"
		verdict.Fix = "echo 'install algif_aead /bin/false' > /etc/modprobe.d/csm-disable-algif.conf"
		return verdict
	}

	verdict.Status = "pass"
	verdict.Message = "algif_aead is blacklisted and not loaded"
	// Use the more specific message when the file at afAlgMarkerPath is
	// present among confs and passes validateMarkerContent.
	if marker, found := confs[afAlgMarkerPath]; found && validateMarkerContent([]byte(marker)) {
		verdict.Message = "algif_aead is blocked by CSM-managed enforcement (csm harden --copy-fail)"
	}
	return verdict
}
// auditAlgifAEAD is the impure wrapper around evaluateAlgifAEAD. It gathers
// state through observeAFAlgKernelState, SummarizeAFAlgSeccompCoverage,
// loadModuleList and the /etc/modprobe.d/*.conf files (read via osFS), then
// produces a single AuditResult.
//
// Decision order:
//  1. A livepatch is reported active -> pass.
//  2. The AEAD code is built into the kernel -> the modprobe blacklist is
//     ineffective; pass only if seccomp drop-ins cover every candidate
//     unit (at least one covered, none uncovered), otherwise fail with a
//     truthful message.
//  3. Otherwise fall through to the modprobe-state evaluator for hosts
//     where AF_ALG is a loadable module.
//
// If any modprobe.d file cannot be read, a "warn" result naming the
// offending file is returned rather than silently misreporting.
func auditAlgifAEAD() store.AuditResult {
	pass := func(message string) store.AuditResult {
		return store.AuditResult{
			Category: "os", Name: algifAEADAuditID, Title: algifAEADAuditTitle,
			Status:  "pass",
			Message: message,
		}
	}

	state := observeAFAlgKernelState()
	if state.LivepatchActive {
		return pass(state.String())
	}

	if state.BuiltIn {
		// On built-in kernels the only interim mitigations are a livepatch
		// (handled above) or per-service seccomp drop-ins, so recognise
		// full seccomp coverage as a pass.
		coverage := SummarizeAFAlgSeccompCoverage()
		nCovered, nUncovered := len(coverage.Covered), len(coverage.Uncovered)
		if nCovered > 0 && nUncovered == 0 {
			return pass(fmt.Sprintf(
				"kernel built-in but Copy Fail blocked by seccomp drop-ins on %d units (%s)",
				nCovered, strings.Join(coverage.Covered, ", "),
			))
		}

		detail := "modprobe blacklist is ineffective on this kernel and Copy Fail (CVE-2026-31431) is exploitable"
		if nCovered > 0 {
			detail = fmt.Sprintf(
				"seccomp drop-ins present on %d units but %d candidate units still uncovered (%s)",
				nCovered, nUncovered, strings.Join(coverage.Uncovered, ", "),
			)
		}
		return store.AuditResult{
			Category: "os", Name: algifAEADAuditID, Title: algifAEADAuditTitle,
			Status:  "fail",
			Message: "AF_ALG aead is built into the kernel (CONFIG_CRYPTO_USER_API_AEAD=y); " + detail,
			Fix: "Apply KernelCare/kpatch when the CVE-2026-31431 patch ships (kcarectl --update); " +
				"or run `csm harden --copy-fail-seccomp` to apply per-service seccomp drop-ins now. " +
				"The modprobe blacklist file is harmless but does not protect this kernel.",
		}
	}

	moduleLoaded := false
	for _, name := range loadModuleList() {
		if name == "algif_aead" {
			moduleLoaded = true
			break
		}
	}

	// A Glob error is treated the same as "no files matched".
	paths, globErr := osFS.Glob("/etc/modprobe.d/*.conf")
	if globErr != nil {
		paths = nil
	}
	contents := make(map[string]string)
	for _, path := range paths {
		raw, readErr := osFS.ReadFile(path)
		if readErr != nil {
			return store.AuditResult{
				Category: "os", Name: algifAEADAuditID, Title: algifAEADAuditTitle,
				Status:  "warn",
				Message: fmt.Sprintf("Cannot read %s: %v — blacklist state undetermined", path, readErr),
			}
		}
		contents[path] = string(raw)
	}
	return evaluateAlgifAEAD(moduleLoaded, contents)
}
// distroEOLPolicy encodes the oldest supported major version per known OS.
// Anything below the minimum is considered EOL by this check.
//
// evaluateDistroEOL looks a distro up here after parsing the integer prefix
// of its version string; a distro with no entry yields a "warn" result
// ("no EOL policy configured"), and CentOS is failed unconditionally before
// this table is consulted.
var distroEOLPolicy = map[platform.OSFamily]int{
platform.OSAlma: 8,
platform.OSRocky: 8,
platform.OSRHEL: 8,
platform.OSCloudLinux: 7,
platform.OSUbuntu: 20, // 20.04 is the oldest non-EOL LTS
platform.OSDebian: 11, // Debian 11 "bullseye"
}
// checkDistroEOL detects the running platform, reads PRETTY_NAME from
// /etc/os-release, and hands both to evaluateDistroEOL for the verdict.
func checkDistroEOL() []store.AuditResult {
	info := platform.Detect()
	pretty := readOSReleasePretty()
	return evaluateDistroEOL(info, pretty)
}
// evaluateDistroEOL is the pure, testable core of checkDistroEOL. It derives
// a single-element result slice from the supplied platform info and
// PRETTY_NAME string (either may be empty).
//
// Checks run in this order, first match wins:
//  1. unknown OS or empty version        -> warn
//  2. CentOS                             -> fail
//  3. major version not parseable        -> warn
//  4. no entry in distroEOLPolicy        -> warn
//  5. major below the policy minimum     -> fail
//  6. otherwise                          -> pass
func evaluateDistroEOL(info platform.Info, prettyName string) []store.AuditResult {
	// verdict stamps the fields common to every outcome of this check.
	verdict := func(r store.AuditResult) []store.AuditResult {
		r.Category = "os"
		r.Name = "os_distro_eol"
		r.Title = "Distribution End of Life"
		return []store.AuditResult{r}
	}

	// Fall back to "<os> <version>" when os-release had no PRETTY_NAME.
	label := prettyName
	if label == "" && info.OSVersion != "" {
		label = fmt.Sprintf("%s %s", info.OS, info.OSVersion)
	}

	switch {
	case info.OS == platform.OSUnknown || info.OSVersion == "":
		return verdict(store.AuditResult{
			Status: "warn", Message: "Unable to determine distribution version",
		})
	case info.OS == platform.OSCentOS:
		return verdict(store.AuditResult{
			Status:  "fail",
			Message: fmt.Sprintf("%s — CentOS is end-of-life", label),
			Fix:     "Migrate to a supported replacement such as AlmaLinux, Rocky Linux, or RHEL. CentOS no longer receives security patches.",
		})
	}

	// The major version is everything before the first dot: Ubuntu/Debian
	// use "24.04" / "12", the RHEL family uses "8.6" / "10".
	majorText := info.OSVersion
	if dot := strings.Index(majorText, "."); dot >= 0 {
		majorText = majorText[:dot]
	}
	major, parseErr := strconv.Atoi(majorText)
	if parseErr != nil {
		return verdict(store.AuditResult{
			Status: "warn", Message: fmt.Sprintf("%s — unable to parse version %q", label, info.OSVersion),
		})
	}

	oldestSupported, hasPolicy := distroEOLPolicy[info.OS]
	if !hasPolicy {
		return verdict(store.AuditResult{
			Status: "warn", Message: fmt.Sprintf("%s — no EOL policy configured for this distro", label),
		})
	}

	if major >= oldestSupported {
		return verdict(store.AuditResult{
			Status: "pass", Message: label,
		})
	}

	// Family-specific advice; the Debian-family text wins if both apply.
	advice := "Upgrade to a supported release. EOL distributions receive no security patches."
	if info.IsRHELFamily() {
		advice = fmt.Sprintf("Upgrade to %s %d+ or newer. EOL distributions receive no security patches.", info.OS, oldestSupported)
	}
	if info.IsDebianFamily() {
		advice = fmt.Sprintf("Upgrade to %s %d+ or newer LTS. EOL distributions receive no security patches.", info.OS, oldestSupported)
	}
	return verdict(store.AuditResult{
		Status:  "fail",
		Message: fmt.Sprintf("%s — major version %d is EOL", label, major),
		Fix:     advice,
	})
}
// readOSReleasePretty returns the value of the first PRETTY_NAME= line in
// /etc/os-release with surrounding single/double quotes stripped. It returns
// "" when the file cannot be read or has no such line.
func readOSReleasePretty() string {
	const key = "PRETTY_NAME="

	raw, err := osFS.ReadFile("/etc/os-release")
	if err != nil {
		return ""
	}
	for _, entry := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(entry, key) {
			continue
		}
		return strings.Trim(entry[len(key):], `"'`)
	}
	return ""
}
// checkUnnecessaryServices lists enabled systemd unit files and reports any
// whose name (with a trailing ".service" removed) is on the unwanted list.
// It returns one result: "warn" if systemctl cannot be queried or if any
// unwanted service is enabled (with a disable command as the fix), "pass"
// otherwise.
func checkUnnecessaryServices() []store.AuditResult {
	unwanted := map[string]bool{
		"avahi-daemon":   true,
		"bluetooth":      true,
		"cups":           true,
		"cupsd":          true,
		"gdm":            true,
		"ModemManager":   true,
		"packagekit":     true,
		"rpcbind":        true,
		"wpa_supplicant": true,
		"firewalld":      true,
	}

	out, err := auditRunCmd("systemctl", "list-unit-files", "--state=enabled", "--no-pager", "--no-legend")
	if err != nil {
		return []store.AuditResult{{
			Category: "os", Name: "os_services", Title: "Unnecessary Services",
			Status: "warn", Message: "Cannot query systemd unit files",
		}}
	}

	// The first column of each row is the unit file name.
	var enabled []string
	for _, row := range strings.Split(string(out), "\n") {
		cols := strings.Fields(row)
		if len(cols) == 0 {
			continue
		}
		if name := strings.TrimSuffix(cols[0], ".service"); unwanted[name] {
			enabled = append(enabled, name)
		}
	}

	if len(enabled) > 0 {
		return []store.AuditResult{{
			Category: "os", Name: "os_services", Title: "Unnecessary Services",
			Status:  "warn",
			Message: fmt.Sprintf("Unnecessary services enabled: %s", strings.Join(enabled, ", ")),
			Fix:     fmt.Sprintf("systemctl disable --now %s", strings.Join(enabled, " ")),
		}}
	}
	return []store.AuditResult{{
		Category: "os", Name: "os_services", Title: "Unnecessary Services",
		Status: "pass", Message: "No unnecessary services enabled",
	}}
}
// --- Firewall checks ---

// auditFirewall runs the firewall checks and returns their results in this
// order: fw_active, fw_default_policy, fw_mysql_exposed, fw_telnet, fw_ipv6.
//
// Firewall state is taken from the text output of `nft list ruleset` and
// `iptables -L INPUT -n`; a backend counts as present when its command
// succeeds and prints something other than whitespace. Default-deny
// detection is a text match: "policy drop"/"policy reject" anywhere in the
// nft ruleset, or DROP/REJECT in the first "Chain INPUT" header of iptables.
func auditFirewall() []store.AuditResult {
	var results []store.AuditResult
	record := func(r store.AuditResult) {
		r.Category = "firewall"
		results = append(results, r)
	}

	// Gather nft and iptables state.
	nftRaw, nftErr := auditRunCmd("nft", "list", "ruleset")
	nftRules := string(nftRaw)
	hasNft := nftErr == nil && strings.TrimSpace(nftRules) != ""

	iptRaw, iptErr := auditRunCmd("iptables", "-L", "INPUT", "-n")
	iptRules := string(iptRaw)
	hasIpt := iptErr == nil && strings.TrimSpace(iptRules) != ""

	// fw_active
	if !hasNft && !hasIpt {
		record(store.AuditResult{
			Name: "fw_active", Title: "Firewall Active",
			Status: "fail", Message: "No active firewall rules detected",
			Fix: "Install and configure nftables or iptables with a default-deny policy.",
		})
	} else {
		record(store.AuditResult{
			Name: "fw_active", Title: "Firewall Active",
			Status: "pass", Message: "Firewall has active rules",
		})
	}

	// fw_default_policy: nft is consulted first, iptables only if nft did
	// not already show a default-deny policy.
	denyByDefault := false
	if hasNft {
		lowered := strings.ToLower(nftRules)
		denyByDefault = strings.Contains(lowered, "policy drop") || strings.Contains(lowered, "policy reject")
	}
	if hasIpt && !denyByDefault {
		for _, row := range strings.Split(iptRules, "\n") {
			if !strings.HasPrefix(row, "Chain INPUT") {
				continue
			}
			header := strings.ToUpper(row)
			denyByDefault = strings.Contains(header, "POLICY DROP") || strings.Contains(header, "POLICY REJECT")
			break // only the first INPUT chain header is examined
		}
	}
	if !denyByDefault {
		record(store.AuditResult{
			Name: "fw_default_policy", Title: "Default INPUT Policy",
			Status: "fail", Message: "INPUT chain does not have a DROP/REJECT policy",
			Fix: "Set the default INPUT policy to DROP: iptables -P INPUT DROP (or nft equivalent).",
		})
	} else {
		record(store.AuditResult{
			Name: "fw_default_policy", Title: "Default INPUT Policy",
			Status: "pass", Message: "INPUT chain has default-deny policy",
		})
	}

	// fw_mysql_exposed
	results = append(results, checkMySQLExposed(hasNft, nftRules, hasIpt, iptRules)...)

	// fw_telnet
	if !isPortListening(23) {
		record(store.AuditResult{
			Name: "fw_telnet", Title: "Telnet Service",
			Status: "pass", Message: "Nothing listening on port 23",
		})
	} else {
		record(store.AuditResult{
			Name: "fw_telnet", Title: "Telnet Service",
			Status: "fail", Message: "Something is listening on port 23 (telnet)",
			Fix: "Disable telnet: systemctl disable --now telnet.socket xinetd; use SSH instead.",
		})
	}

	// fw_ipv6
	results = append(results, checkIPv6Firewall()...)
	return results
}
// getListeningAddr scans /proc/net/tcp and then /proc/net/tcp6 for the first
// socket in LISTEN state (0A) on the given port and returns the hex-encoded
// local IP exactly as it appears in the file, or "" if none is found.
// A table that cannot be read is skipped.
func getListeningAddr(port int) string {
	wantPort := fmt.Sprintf("%04X", port)
	for _, table := range []string{"/proc/net/tcp", "/proc/net/tcp6"} {
		raw, err := osFS.ReadFile(table)
		if err != nil {
			continue
		}
		for _, row := range strings.Split(string(raw), "\n") {
			cols := strings.Fields(row)
			// cols[1] = local_address (hex_ip:hex_port), cols[3] = state.
			if len(cols) < 4 || cols[3] != "0A" {
				continue
			}
			addr, hexPort, ok := strings.Cut(cols[1], ":")
			if ok && hexPort == wantPort {
				return addr
			}
		}
	}
	return ""
}
// hexToIPv4 converts an 8-character hex IP as found in /proc/net/tcp into
// dotted-decimal notation, reversing the four bytes (the file stores the
// address as a little-endian 32-bit value). Input that is not exactly eight
// valid hex characters is returned unchanged.
func hexToIPv4(h string) string {
	if len(h) != 8 {
		return h
	}
	var octets [4]byte
	if _, err := hex.Decode(octets[:], []byte(h)); err != nil {
		return h
	}
	// Emit the bytes in reverse to undo the little-endian storage order.
	return fmt.Sprintf("%d.%d.%d.%d", octets[3], octets[2], octets[1], octets[0])
}
// isPrivateOrLoopback reports whether ipStr parses as a loopback address or
// as an address in one of the private ranges 10.0.0.0/8, 172.16.0.0/12,
// 192.168.0.0/16 (RFC 1918) or fc00::/7 (RFC 4193). Unparseable input yields
// false.
func isPrivateOrLoopback(ipStr string) bool {
	addr := net.ParseIP(ipStr)
	if addr == nil {
		return false
	}
	// net.IP.IsPrivate covers exactly the RFC 1918 and RFC 4193 ranges.
	return addr.IsLoopback() || addr.IsPrivate()
}
// isPortListening reports whether /proc/net/tcp or /proc/net/tcp6 contains a
// socket in LISTEN state (0A) whose local port equals port. A table that
// cannot be read is skipped.
func isPortListening(port int) bool {
	wantPort := fmt.Sprintf("%04X", port)
	for _, table := range []string{"/proc/net/tcp", "/proc/net/tcp6"} {
		raw, err := osFS.ReadFile(table)
		if err != nil {
			continue
		}
		for _, row := range strings.Split(string(raw), "\n") {
			cols := strings.Fields(row)
			// cols[1] = local_address (hex_ip:hex_port), cols[3] = state.
			if len(cols) < 4 || cols[3] != "0A" {
				continue
			}
			if _, hexPort, ok := strings.Cut(cols[1], ":"); ok && hexPort == wantPort {
				return true
			}
		}
	}
	return false
}
// checkMySQLExposed produces the single fw_mysql_exposed result. It looks
// for a listener on port 3306 via getListeningAddr and classifies it:
//
//   - no listener                              -> pass
//   - bound to a private/loopback address      -> pass
//   - wildcard/public bind, firewall blocks it -> warn
//   - wildcard/public bind otherwise           -> fail
//
// "Firewall blocks it" is a text heuristic on the supplied rule dumps: the
// rules never mention "3306" and a default drop/reject policy is present
// (nft anywhere in the ruleset; iptables in the first "Chain INPUT" header).
// hasNft/hasIpt say whether the corresponding dump is usable.
func checkMySQLExposed(hasNft bool, nftRules string, hasIpt bool, iptRules string) []store.AuditResult {
	// single stamps the fields common to every outcome of this check.
	single := func(r store.AuditResult) []store.AuditResult {
		r.Category = "firewall"
		r.Name = "fw_mysql_exposed"
		r.Title = "MySQL Exposure"
		return []store.AuditResult{r}
	}

	hexAddr := getListeningAddr(3306)
	if hexAddr == "" {
		return single(store.AuditResult{
			Status: "pass", Message: "MySQL is not listening on any port",
		})
	}

	// Decode the hex address into a printable IP. Eight characters is an
	// IPv4 entry; anything else is handled as a /proc/net/tcp6 entry, whose
	// bytes are reversed within each 4-byte group.
	var ip string
	if len(hexAddr) == 8 {
		ip = hexToIPv4(hexAddr)
	} else if raw, err := hex.DecodeString(hexAddr); err == nil {
		addr := make(net.IP, len(raw))
		for off := 0; off+4 <= len(raw); off += 4 {
			addr[off], addr[off+1], addr[off+2], addr[off+3] = raw[off+3], raw[off+2], raw[off+1], raw[off]
		}
		ip = addr.String()
	}

	// An address consisting solely of '0' characters is a wildcard bind.
	wildcard := strings.Trim(hexAddr, "0") == ""
	specific := !wildcard && ip != ""

	if specific && isPrivateOrLoopback(ip) {
		return single(store.AuditResult{
			Status: "pass", Message: fmt.Sprintf("MySQL bound to private/loopback address %s", ip),
		})
	}

	// Wildcard or public bind — decide whether the firewall blocks 3306.
	blocked := false
	if hasNft && !strings.Contains(nftRules, "3306") {
		// No mention of 3306 plus a default-deny policy means it is blocked.
		lowered := strings.ToLower(nftRules)
		blocked = strings.Contains(lowered, "policy drop") || strings.Contains(lowered, "policy reject")
	}
	if !blocked && hasIpt && !strings.Contains(iptRules, "3306") {
		for _, row := range strings.Split(iptRules, "\n") {
			if !strings.HasPrefix(row, "Chain INPUT") {
				continue
			}
			header := strings.ToUpper(row)
			blocked = strings.Contains(header, "POLICY DROP") || strings.Contains(header, "POLICY REJECT")
			break // only the first INPUT chain header is examined
		}
	}

	bindDesc := "wildcard (0.0.0.0)"
	if specific {
		bindDesc = ip
	}

	if !blocked {
		return single(store.AuditResult{
			Status:  "fail",
			Message: fmt.Sprintf("MySQL bound to %s and port 3306 appears accessible", bindDesc),
			Fix:     "Bind MySQL to 127.0.0.1 in /etc/my.cnf and/or block port 3306 in firewall.",
		})
	}
	return single(store.AuditResult{
		Status:  "warn",
		Message: fmt.Sprintf("MySQL bound to %s but firewall blocks port 3306", bindDesc),
		Fix:     "Bind MySQL to 127.0.0.1 in /etc/my.cnf: bind-address = 127.0.0.1",
	})
}
// checkIPv6Firewall produces the single fw_ipv6 result.
//
// It passes when /proc/net/if_inet6 cannot be read, or when every address
// listed there is on the "lo" interface or starts with "fe80". Otherwise it
// requires a default-deny input policy, looked for first in the output of
// `nft list chains` (an inet or ip6 table line followed by a "hook input"
// line carrying "policy drop"/"policy reject") and then in the first
// "Chain INPUT" header of `ip6tables -L INPUT -n`. If neither shows one, the
// check fails.
func checkIPv6Firewall() []store.AuditResult {
	// single stamps the fields common to every outcome of this check.
	single := func(r store.AuditResult) []store.AuditResult {
		r.Category = "firewall"
		r.Name = "fw_ipv6"
		r.Title = "IPv6 Firewall"
		return []store.AuditResult{r}
	}

	raw, err := osFS.ReadFile("/proc/net/if_inet6")
	if err != nil {
		return single(store.AuditResult{
			Status: "pass", Message: "IPv6 not active (cannot read /proc/net/if_inet6)",
		})
	}

	// Look for any address that is neither on loopback nor link-local.
	// Column 0 is the hex address, column 5 the interface name.
	routable := false
	for _, row := range strings.Split(strings.TrimSpace(string(raw)), "\n") {
		cols := strings.Fields(row)
		if len(cols) < 6 {
			continue
		}
		onLoopback := cols[5] == "lo"
		linkLocal := strings.HasPrefix(strings.ToLower(cols[0]), "fe80")
		if onLoopback || linkLocal {
			continue
		}
		routable = true
		break
	}
	if !routable {
		return single(store.AuditResult{
			Status: "pass", Message: "No non-link-local IPv6 addresses found",
		})
	}

	// nftables: track the family of the most recent "table <family> <name>"
	// line and accept a default-deny input hook in an inet or ip6 table.
	// Expected shape of the output:
	//   table inet filter {
	//     chain input {
	//       type filter hook input priority filter; policy drop;
	//     }
	//   }
	if chains, nftErr := auditRunCmd("nft", "list", "chains"); nftErr == nil {
		family := ""
		for _, row := range strings.Split(strings.ToLower(string(chains)), "\n") {
			text := strings.TrimSpace(row)
			if strings.HasPrefix(text, "table ") {
				if words := strings.Fields(text); len(words) >= 3 {
					family = words[1]
				}
			}
			if family != "inet" && family != "ip6" {
				continue
			}
			if !strings.Contains(text, "hook input") {
				continue
			}
			if strings.Contains(text, "policy drop") || strings.Contains(text, "policy reject") {
				return single(store.AuditResult{
					Status: "pass", Message: fmt.Sprintf("IPv6 active; nftables %s family has default-deny input chain", family),
				})
			}
		}
	}

	// Fallback: ip6tables.
	if out, ip6Err := auditRunCmd("ip6tables", "-L", "INPUT", "-n"); ip6Err == nil {
		for _, row := range strings.Split(string(out), "\n") {
			if !strings.HasPrefix(row, "Chain INPUT") {
				continue
			}
			header := strings.ToUpper(row)
			if strings.Contains(header, "POLICY DROP") || strings.Contains(header, "POLICY REJECT") {
				return single(store.AuditResult{
					Status: "pass", Message: "IPv6 active; ip6tables INPUT chain has default-deny policy",
				})
			}
			break // only the first INPUT chain header is examined
		}
	}

	return single(store.AuditResult{
		Status: "fail", Message: "IPv6 is active but no default-deny input policy found",
		Fix: "Configure ip6tables or nftables inet family with a default DROP policy for INPUT.",
	})
}
// --- cPanel/WHM and CloudLinux checks ---

// auditCPanel runs the cPanel/WHM hardening checks and returns their
// results, all in the "cpanel" category:
//
//   - a table of key/value checks against /var/cpanel/cpanel.config
//     (a missing key reads as the empty string);
//   - cp_max_emails_hour: maxemailsperhour must be set and non-zero;
//   - cp_compilers: /usr/bin/cc must carry no permission bits beyond 0750
//     (a missing compiler passes);
//   - cp_ftp_anonymous: /etc/pure-ftpd.conf must contain "NoAnonymous yes"
//     (skipped when the file cannot be read);
//   - cp_updates: /etc/cpupdate.conf must contain UPDATES=daily
//     (skipped when the file cannot be read).
//
// When serverType is "cloudlinux" the results of auditCloudLinux are
// appended.
func auditCPanel(serverType string) []store.AuditResult {
	var results []store.AuditResult
	cpConf := parseCpanelConfig("/var/cpanel/cpanel.config")

	// Table-driven boolean checks on cpanel.config.
	// fix is the human-readable remediation shown in the UI.
	type cpCheck struct {
		id, title, key, wantVal string
		invert                  bool // true = fail when value matches wantVal
		fix                     string
	}
	checks := []cpCheck{
		{"cp_ssl_only", "Always Redirect to SSL", "alwaysredirecttossl", "1", false,
			"In WHM > Tweak Settings > Redirection, set 'Always redirect to SSL/TLS' to On."},
		{"cp_boxtrapper", "BoxTrapper Disabled", "skipboxtrapper", "1", false,
			"In WHM > Tweak Settings > Mail, set 'Enable BoxTrapper spam trap' to Off."},
		{"cp_password_reset", "Password Reset Disabled", "resetpass", "1", true,
			"In WHM > Tweak Settings > System, set 'Reset Password for cPanel accounts' to Off."},
		{"cp_password_reset_sub", "Subaccount Password Reset Disabled", "resetpass_sub", "1", true,
			"In WHM > Tweak Settings > System, set 'Reset Password for Subaccounts' to Off."},
		{"cp_email_passwords", "Email Passwords Disabled", "emailpasswords", "1", true,
			"In WHM > Tweak Settings > Security, set 'Send passwords when creating a new account' to Off."},
		{"cp_cookie_validation", "Cookie IP Validation", "cookieipvalidation", "strict", false,
			"In WHM > Tweak Settings > Security, set 'Cookie IP validation' to strict."},
		{"cp_remote_domains", "Remote Domains Disabled", "allowremotedomains", "1", true,
			"In WHM > Tweak Settings > Domains, set 'Allow Remote Domains' to Off."},
		{"cp_core_dumps", "Core Dumps Disabled", "coredump", "1", true,
			"In WHM > Tweak Settings > Security, set 'Generate core dumps' to Off."},
		{"cp_nobodyspam", "Nobody Spam Prevention", "nobodyspam", "1", false,
			"In WHM > Tweak Settings > Mail, set 'Prevent nobody from sending mail' to On."},
	}
	for _, c := range checks {
		val := cpConf[c.key]
		var pass bool
		if c.invert {
			pass = val != c.wantVal
		} else {
			pass = val == c.wantVal
		}
		if pass {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: c.id, Title: c.title,
				Status: "pass", Message: fmt.Sprintf("%s = %s", c.key, val),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: c.id, Title: c.title,
				Status: "fail", Message: fmt.Sprintf("%s = %s", c.key, val),
				Fix: c.fix,
			})
		}
	}

	// cp_max_emails_hour
	maxEmail := cpConf["maxemailsperhour"]
	if maxEmail != "" && maxEmail != "0" {
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_max_emails_hour", Title: "Max Emails Per Hour",
			Status: "pass", Message: fmt.Sprintf("maxemailsperhour = %s", maxEmail),
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_max_emails_hour", Title: "Max Emails Per Hour",
			Status: "fail", Message: "maxemailsperhour is not set or is 0",
			Fix: "In WHM > Tweak Settings, set 'Max emails per hour per domain' to a reasonable limit (e.g., 200).",
		})
	}

	// cp_compilers: check /usr/bin/cc permissions
	if info, err := osFS.Stat("/usr/bin/cc"); err == nil {
		perm := info.Mode().Perm()
		// Permission bits are a bit set, not a magnitude: a numeric
		// "perm <= 0750" would accept modes such as 0707 or 0747, which
		// grant world read/write/execute. Require instead that no bit
		// outside 0750 is set.
		if perm&^0o750 == 0 {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
				Status: "pass", Message: fmt.Sprintf("/usr/bin/cc has mode %04o", perm),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
				Status: "fail", Message: fmt.Sprintf("/usr/bin/cc has mode %04o (should be <= 0750)", perm),
				Fix: "WHM > Security Center > Compiler Access, or: chmod 750 /usr/bin/cc",
			})
		}
	} else {
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
			Status: "pass", Message: "No compiler found at /usr/bin/cc",
		})
	}

	// cp_ftp_anonymous: parse /etc/pure-ftpd.conf
	if data, err := osFS.ReadFile("/etc/pure-ftpd.conf"); err == nil {
		noAnon := false
		for _, line := range strings.Split(string(data), "\n") {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "#") {
				continue
			}
			if strings.HasPrefix(trimmed, "NoAnonymous") {
				parts := strings.Fields(trimmed)
				if len(parts) >= 2 && strings.EqualFold(parts[1], "yes") {
					noAnon = true
				}
			}
		}
		if noAnon {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_ftp_anonymous", Title: "Anonymous FTP Disabled",
				Status: "pass", Message: "NoAnonymous is enabled in pure-ftpd",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_ftp_anonymous", Title: "Anonymous FTP Disabled",
				Status: "fail", Message: "Anonymous FTP may be enabled",
				Fix: "Set 'NoAnonymous yes' in /etc/pure-ftpd.conf and restart pure-ftpd.",
			})
		}
	}

	// cp_updates: parse /etc/cpupdate.conf
	if data, err := osFS.ReadFile("/etc/cpupdate.conf"); err == nil {
		updatesDaily := false
		for _, line := range strings.Split(string(data), "\n") {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "#") {
				continue
			}
			if strings.HasPrefix(strings.ToUpper(trimmed), "UPDATES=") {
				// The value is everything after the first '='.
				_, val, _ := strings.Cut(trimmed, "=")
				if strings.EqualFold(strings.TrimSpace(val), "daily") {
					updatesDaily = true
				}
			}
		}
		if updatesDaily {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_updates", Title: "cPanel Auto-Updates",
				Status: "pass", Message: "UPDATES=daily in cpupdate.conf",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_updates", Title: "cPanel Auto-Updates",
				Status: "warn", Message: "cPanel auto-updates not set to daily",
				Fix: "Set UPDATES=daily in /etc/cpupdate.conf or WHM > Update Preferences.",
			})
		}
	}

	// CloudLinux-specific checks
	if serverType == "cloudlinux" {
		results = append(results, auditCloudLinux()...)
	}
	return results
}
// parseCpanelConfig reads a key=value configuration file via osFS and
// returns its settings as a map. Blank lines, lines starting with "#", and
// lines with no "=" or with nothing before the first "=" are ignored; keys
// and values are whitespace-trimmed and a repeated key keeps its last value.
// An unreadable file yields an empty (non-nil) map.
func parseCpanelConfig(path string) map[string]string {
	settings := make(map[string]string)

	raw, err := osFS.ReadFile(path)
	if err != nil {
		return settings
	}

	for _, row := range strings.Split(string(raw), "\n") {
		entry := strings.TrimSpace(row)
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		name, value, found := strings.Cut(entry, "=")
		if !found || name == "" {
			continue
		}
		settings[strings.TrimSpace(name)] = strings.TrimSpace(value)
	}
	return settings
}
// auditCloudLinux runs the CloudLinux-specific checks and returns their
// results in the "cpanel" category:
//
//   - cl_cagefs: output of `cagefsctl --cagefs-status` must contain
//     "Enabled" (warn if the command fails);
//   - cl_symlink_protection: /proc/sys/fs/enforce_symlinksifowner must be
//     >= 1 (skipped if unreadable);
//   - cl_proc_virtualization: /proc/sys/fs/proc_can_see_other_uid must be
//     "0" (skipped if unreadable).
func auditCloudLinux() []store.AuditResult {
	var results []store.AuditResult
	record := func(r store.AuditResult) {
		r.Category = "cpanel"
		results = append(results, r)
	}

	// cl_cagefs
	status, err := auditRunCmd("cagefsctl", "--cagefs-status")
	if err != nil {
		record(store.AuditResult{
			Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "warn", Message: "Cannot check CageFS status",
		})
	} else if strings.Contains(string(status), "Enabled") {
		record(store.AuditResult{
			Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "pass", Message: "CageFS is enabled",
		})
	} else {
		record(store.AuditResult{
			Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "fail", Message: "CageFS is not enabled",
			Fix: "Enable CageFS: cagefsctl --enable-all",
		})
	}

	// cl_symlink_protection
	if raw, readErr := osFS.ReadFile("/proc/sys/fs/enforce_symlinksifowner"); readErr == nil {
		val := strings.TrimSpace(string(raw))
		// A non-numeric value parses as 0 and therefore fails the check.
		level, _ := strconv.Atoi(val)
		if level < 1 {
			record(store.AuditResult{
				Name: "cl_symlink_protection", Title: "CloudLinux Symlink Protection",
				Status: "fail", Message: fmt.Sprintf("enforce_symlinksifowner = %s (expected >= 1)", val),
				Fix: "sysctl -w fs.enforce_symlinksifowner=1",
			})
		} else {
			record(store.AuditResult{
				Name: "cl_symlink_protection", Title: "CloudLinux Symlink Protection",
				Status: "pass", Message: fmt.Sprintf("enforce_symlinksifowner = %s", val),
			})
		}
	}

	// cl_proc_virtualization
	if raw, readErr := osFS.ReadFile("/proc/sys/fs/proc_can_see_other_uid"); readErr == nil {
		val := strings.TrimSpace(string(raw))
		if val != "0" {
			record(store.AuditResult{
				Name: "cl_proc_virtualization", Title: "CloudLinux /proc Virtualization",
				Status: "fail", Message: fmt.Sprintf("proc_can_see_other_uid = %s (expected 0)", val),
				Fix: "sysctl -w fs.proc_can_see_other_uid=0",
			})
		} else {
			record(store.AuditResult{
				Name: "cl_proc_virtualization", Title: "CloudLinux /proc Virtualization",
				Status: "pass", Message: "proc_can_see_other_uid = 0",
			})
		}
	}

	return results
}
// --- PHP checks ---
func auditPHP(serverType string) []store.AuditResult {
var results []store.AuditResult
type phpInstall struct {
version string // e.g. "8.1"
shortID string // e.g. "81"
iniPath string
fpmDir string // for FPM pool override merging
}
var installs []phpInstall
// cPanel EA4 PHP installs
eaInis, _ := osFS.Glob("/opt/cpanel/ea-php*/root/etc/php.ini")
for _, ini := range eaInis {
// Extract version from path: /opt/cpanel/ea-php81/root/etc/php.ini -> "81"
dir := filepath.Dir(filepath.Dir(filepath.Dir(ini))) // /opt/cpanel/ea-php81/root -> /opt/cpanel/ea-php81
base := filepath.Base(dir) // ea-php81
shortID := strings.TrimPrefix(base, "ea-php")
if len(shortID) >= 2 {
ver := shortID[:len(shortID)-1] + "." + shortID[len(shortID)-1:]
fpmDir := filepath.Join(dir, "root", "etc", "php-fpm.d")
installs = append(installs, phpInstall{
version: ver,
shortID: shortID,
iniPath: ini,
fpmDir: fpmDir,
})
}
}
// CloudLinux alt-php installs (skip Imunify360's internal PHP builds)
if serverType == "cloudlinux" {
altInis, _ := osFS.Glob("/opt/alt/php*/etc/php.ini")
for _, ini := range altInis {
dir := filepath.Dir(filepath.Dir(ini)) // /opt/alt/php81
base := filepath.Base(dir) // php81
if strings.Contains(base, "-") {
continue // skip php74-imunify, php81-hardened, etc.
}
shortID := strings.TrimPrefix(base, "php")
if len(shortID) >= 2 {
ver := shortID[:len(shortID)-1] + "." + shortID[len(shortID)-1:]
installs = append(installs, phpInstall{
version: ver,
shortID: shortID,
iniPath: ini,
})
}
}
}
// Bare server fallback
if len(installs) == 0 {
out, err := auditRunCmd("php", "-i")
if err == nil {
for _, line := range strings.Split(string(out), "\n") {
if strings.HasPrefix(line, "Loaded Configuration File") {
parts := strings.SplitN(line, "=>", 2)
if len(parts) == 2 {
iniPath := strings.TrimSpace(parts[1])
if iniPath != "(none)" && iniPath != "" {
installs = append(installs, phpInstall{
version: "unknown",
shortID: "system",
iniPath: iniPath,
})
}
}
}
}
}
// Try to get version for bare
if len(installs) > 0 && installs[0].version == "unknown" {
vout, verr := auditRunCmd("php", "-v")
if verr == nil {
first := strings.SplitN(string(vout), "\n", 2)[0]
// "PHP 8.2.15 (cli) ..."
fields := strings.Fields(first)
if len(fields) >= 2 {
verParts := strings.SplitN(fields[1], ".", 3)
if len(verParts) >= 2 {
installs[0].version = verParts[0] + "." + verParts[1]
installs[0].shortID = verParts[0] + verParts[1]
}
}
}
}
}
for _, inst := range installs {
data, err := osFS.ReadFile(inst.iniPath)
if err != nil {
continue
}
ini := parsePHPIni(string(data))
// Merge FPM pool overrides if available
if inst.fpmDir != "" {
poolConfs, _ := osFS.Glob(filepath.Join(inst.fpmDir, "*.conf"))
for _, pc := range poolConfs {
pdata, perr := osFS.ReadFile(pc)
if perr != nil {
continue
}
for _, line := range strings.Split(string(pdata), "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, ";") {
continue
}
// php_admin_value[key] = val or php_value[key] = val
for _, prefix := range []string{"php_admin_value[", "php_value["} {
if strings.HasPrefix(line, prefix) {
rest := strings.TrimPrefix(line, prefix)
if idx := strings.Index(rest, "]"); idx > 0 {
key := rest[:idx]
valPart := rest[idx+1:]
if eqIdx := strings.Index(valPart, "="); eqIdx >= 0 {
val := strings.TrimSpace(valPart[eqIdx+1:])
ini[key] = val
}
}
}
}
}
}
}
suffix := inst.shortID
// php_version check
major, minor := parsePHPVersion(inst.version)
if major > 0 {
if major < 8 || (major == 8 && minor < 1) {
results = append(results, store.AuditResult{
Category: "php", Name: "php_version_" + suffix, Title: fmt.Sprintf("PHP %s Version", inst.version),
Status: "fail", Message: fmt.Sprintf("PHP %s is end-of-life", inst.version),
Fix: fmt.Sprintf("Upgrade PHP %s to 8.1 or later. EOL versions receive no security patches.", inst.version),
})
} else {
results = append(results, store.AuditResult{
Category: "php", Name: "php_version_" + suffix, Title: fmt.Sprintf("PHP %s Version", inst.version),
Status: "pass", Message: fmt.Sprintf("PHP %s is supported", inst.version),
})
}
}
// php_disable_functions
df := strings.TrimSpace(ini["disable_functions"])
if df == "" || strings.EqualFold(df, "none") {
results = append(results, store.AuditResult{
Category: "php", Name: "php_disable_functions_" + suffix, Title: fmt.Sprintf("PHP %s disable_functions", inst.version),
Status: "fail", Message: "disable_functions is empty",
Fix: fmt.Sprintf("Set disable_functions in %s to include dangerous functions like exec, system, passthru, shell_exec, popen, proc_open.", inst.iniPath),
})
} else {
results = append(results, store.AuditResult{
Category: "php", Name: "php_disable_functions_" + suffix, Title: fmt.Sprintf("PHP %s disable_functions", inst.version),
Status: "pass", Message: "disable_functions is configured",
})
}
// php_expose
expose := strings.TrimSpace(strings.ToLower(ini["expose_php"]))
if expose == "off" || expose == "0" {
results = append(results, store.AuditResult{
Category: "php", Name: "php_expose_" + suffix, Title: fmt.Sprintf("PHP %s expose_php", inst.version),
Status: "pass", Message: "expose_php is off",
})
} else {
results = append(results, store.AuditResult{
Category: "php", Name: "php_expose_" + suffix, Title: fmt.Sprintf("PHP %s expose_php", inst.version),
Status: "warn", Message: "expose_php is on — PHP version disclosed in headers",
Fix: fmt.Sprintf("Set expose_php = Off in %s", inst.iniPath),
})
}
// php_allow_url_fopen
auf := strings.TrimSpace(strings.ToLower(ini["allow_url_fopen"]))
if auf == "off" || auf == "0" {
results = append(results, store.AuditResult{
Category: "php", Name: "php_allow_url_fopen_" + suffix, Title: fmt.Sprintf("PHP %s allow_url_fopen", inst.version),
Status: "pass", Message: "allow_url_fopen is off",
})
} else {
results = append(results, store.AuditResult{
Category: "php", Name: "php_allow_url_fopen_" + suffix, Title: fmt.Sprintf("PHP %s allow_url_fopen", inst.version),
Status: "warn", Message: "allow_url_fopen is on — remote file inclusion risk",
Fix: fmt.Sprintf("Set allow_url_fopen = Off in %s", inst.iniPath),
})
}
// php_enable_dl
edl := strings.TrimSpace(strings.ToLower(ini["enable_dl"]))
if edl == "on" || edl == "1" {
results = append(results, store.AuditResult{
Category: "php", Name: "php_enable_dl_" + suffix, Title: fmt.Sprintf("PHP %s enable_dl", inst.version),
Status: "fail", Message: "enable_dl is on — allows loading arbitrary shared objects",
Fix: fmt.Sprintf("Set enable_dl = Off in %s", inst.iniPath),
})
} else {
results = append(results, store.AuditResult{
Category: "php", Name: "php_enable_dl_" + suffix, Title: fmt.Sprintf("PHP %s enable_dl", inst.version),
Status: "pass", Message: "enable_dl is off",
})
}
}
return results
}
// parsePHPIni parses php.ini-style text into a flat key -> value map.
//
// Blank lines, full-line ";" comments and "[section]" headers are skipped.
// Section names are not tracked, so a key repeated later in the file
// overwrites the earlier value. Lines without "=" or with nothing before the
// "=" are ignored.
//
// Values are normalised before being stored: a value that starts with a quote
// yields the text up to the matching closing quote, and an unquoted value is
// cut at the first ";" (inline comment). Without this, `expose_php = Off ; x`
// would be stored as "Off ; x" and `disable_functions = ""` as a non-empty
// string, and checks comparing those values would misreport.
func parsePHPIni(content string) map[string]string {
	ini := make(map[string]string)
	for _, line := range strings.Split(content, "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, ";") || strings.HasPrefix(line, "[") {
			continue
		}
		idx := strings.Index(line, "=")
		if idx <= 0 {
			continue
		}
		key := strings.TrimSpace(line[:idx])
		val := strings.TrimSpace(line[idx+1:])
		if len(val) > 0 && (val[0] == '"' || val[0] == '\'') {
			// Quoted value: keep what is between the quotes. A ";" inside the
			// quotes is data, not a comment. An unterminated quote is left
			// untouched rather than guessed at.
			if end := strings.IndexByte(val[1:], val[0]); end >= 0 {
				val = val[1 : 1+end]
			}
		} else if sc := strings.IndexByte(val, ';'); sc >= 0 {
			// Unquoted value: everything from ";" on is an inline comment.
			val = strings.TrimSpace(val[:sc])
		}
		ini[key] = val
	}
	return ini
}
// parsePHPVersion extracts the major and minor numbers from a dotted version
// string such as "8.2" or "7.4.33". A string with fewer than two dot-separated
// parts (for example "unknown") yields 0, 0. Conversion errors are ignored, so
// a non-numeric component comes back as whatever strconv.Atoi returned for it.
func parsePHPVersion(ver string) (int, int) {
	segments := strings.SplitN(ver, ".", 3)
	if len(segments) >= 2 {
		majorNum, _ := strconv.Atoi(segments[0])
		minorNum, _ := strconv.Atoi(segments[1])
		return majorNum, minorNum
	}
	return 0, 0
}
// --- Web server checks ---
// auditWebServer audits the web server's global configuration and its TLS
// protocol support, returning one store.AuditResult per check.
//
// Config checks run only when one of the known main config files exists and
// can be read; otherwise they are skipped without emitting a result. They
// cover four directives (ServerTokens, ServerSignature, TraceEnable, FileETag)
// and a scan for an enabled "Options ... Indexes". The TLS checks always run:
// each one probes localhost:443 with "openssl s_client" forcing a single
// legacy protocol version.
//
// serverType is not consulted by this function.
func auditWebServer(serverType string) []store.AuditResult {
	var results []store.AuditResult

	// Use the first main config file that exists, in priority order.
	var configPath string
	configPaths := []string{
		"/etc/apache2/conf/httpd.conf",          // cPanel EA4
		"/usr/local/lsws/conf/httpd_config.xml", // LiteSpeed
		"/etc/httpd/conf/httpd.conf",            // RHEL/CentOS bare
		"/etc/apache2/apache2.conf",             // Debian/Ubuntu bare
	}
	for _, p := range configPaths {
		if _, err := osFS.Stat(p); err == nil {
			configPath = p
			break
		}
	}

	if configPath != "" {
		if configData, err := osFS.ReadFile(configPath); err == nil {
			// Split once; the directive table and the Options scan walk the
			// same lines.
			lines := strings.Split(string(configData), "\n")

			// Table-driven directive checks
			type directiveCheck struct {
				id, title, directive string
				goodValues           []string
			}
			dirChecks := []directiveCheck{
				{"web_server_tokens", "ServerTokens", "ServerTokens", []string{"prod", "productonly"}},
				{"web_server_signature", "ServerSignature", "ServerSignature", []string{"off"}},
				{"web_trace_enable", "TraceEnable", "TraceEnable", []string{"off"}},
				{"web_file_etag", "FileETag", "FileETag", []string{"none"}},
			}
			for _, dc := range dirChecks {
				found := false
				var foundVal string
				for _, line := range lines {
					// Tokenise on any whitespace so "Directive<TAB>value" is
					// recognised as well as "Directive value". A commented-out
					// line never matches because its first token starts with
					// "#". The last occurrence in the file wins.
					parts := strings.Fields(line)
					if len(parts) < 2 || !strings.EqualFold(parts[0], dc.directive) {
						continue
					}
					foundVal = parts[1]
					found = true
				}
				if !found {
					results = append(results, store.AuditResult{
						Category: "webserver", Name: dc.id, Title: dc.title,
						Status: "warn", Message: fmt.Sprintf("%s not set in %s", dc.directive, configPath),
						Fix: fmt.Sprintf("Add '%s %s' to %s", dc.directive, dc.goodValues[0], configPath),
					})
					continue
				}
				isGood := false
				for _, gv := range dc.goodValues {
					if strings.EqualFold(foundVal, gv) {
						isGood = true
						break
					}
				}
				if isGood {
					results = append(results, store.AuditResult{
						Category: "webserver", Name: dc.id, Title: dc.title,
						Status: "pass", Message: fmt.Sprintf("%s = %s", dc.directive, foundVal),
					})
				} else {
					results = append(results, store.AuditResult{
						Category: "webserver", Name: dc.id, Title: dc.title,
						Status: "fail", Message: fmt.Sprintf("%s = %s", dc.directive, foundVal),
						Fix: fmt.Sprintf("Set '%s %s' in %s", dc.directive, dc.goodValues[0], configPath),
					})
				}
			}

			// Directory listing check: any non-comment line mentioning both
			// "options" and "indexes" counts, unless it contains "-indexes".
			hasIndexes := false
			for _, line := range lines {
				trimmed := strings.TrimSpace(line)
				if strings.HasPrefix(trimmed, "#") {
					continue
				}
				lower := strings.ToLower(trimmed)
				if strings.Contains(lower, "options") && strings.Contains(lower, "indexes") && !strings.Contains(lower, "-indexes") {
					hasIndexes = true
					break
				}
			}
			if hasIndexes {
				results = append(results, store.AuditResult{
					Category: "webserver", Name: "web_directory_listing", Title: "Directory Listing",
					Status: "warn", Message: "Global config enables directory listing (Options Indexes)",
					Fix: "Replace 'Indexes' with '-Indexes' in Options directives.",
				})
			} else {
				results = append(results, store.AuditResult{
					Category: "webserver", Name: "web_directory_listing", Title: "Directory Listing",
					Status: "pass", Message: "No global directory listing enabled",
				})
			}
		}
	}

	// TLS version checks: probe with openssl
	for _, tc := range []struct {
		id, title, flag, version string
	}{
		{"web_tls_version", "Legacy TLS Disabled", "-tls1", "TLSv1.0"},
		{"web_tls11_version", "TLS 1.1 Disabled", "-tls1_1", "TLSv1.1"},
	} {
		out, err := auditRunCmd("openssl", "s_client", "-connect", "localhost:443", tc.flag)
		output := string(out)
		// The handshake counts as successful only when the command exits
		// cleanly and its output shows "SSL-Session:" without any ":error:".
		// Every other outcome, including a failed probe, is reported as pass.
		succeeded := err == nil && strings.Contains(output, "SSL-Session:") && !strings.Contains(output, ":error:")
		if succeeded {
			results = append(results, store.AuditResult{
				Category: "webserver", Name: tc.id, Title: tc.title,
				Status: "fail", Message: fmt.Sprintf("Server accepts %s connections", tc.version),
				Fix: fmt.Sprintf("Disable %s in your web server's SSL configuration. Minimum should be TLSv1.2.", tc.version),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "webserver", Name: tc.id, Title: tc.title,
				Status: "pass", Message: fmt.Sprintf("%s is rejected", tc.version),
			})
		}
	}
	return results
}
// --- Mail checks ---
// auditMail runs the mail-related audit checks and returns their results in a
// fixed order: root forwarder, Exim argument logging, Exim SSLv2, Exim secure
// auth, Dovecot minimum TLS version.
//
// The Exim logging and SSLv2 checks read the output of "exim -bP". When that
// command fails, the logging check reports a warning and the SSLv2 check emits
// nothing. The Dovecot check prefers "doveconf -a" and falls back to reading
// the Dovecot config files only when that command returns an error.
func auditMail() []store.AuditResult {
	var results []store.AuditResult
	// emit stamps the shared category and appends the result.
	emit := func(r store.AuditResult) {
		r.Category = "mail"
		results = append(results, r)
	}

	// mail_root_forwarder
	fwdInfo, fwdErr := osFS.Stat("/root/.forward")
	switch {
	case fwdErr != nil:
		emit(store.AuditResult{
			Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
			Status: "warn", Message: "/root/.forward does not exist — root mail may go unread",
			Fix: "Create /root/.forward with an email address to receive root's mail.",
		})
	case fwdInfo.Size() == 0:
		emit(store.AuditResult{
			Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
			Status: "warn", Message: "/root/.forward is empty — root mail may go unread",
			Fix: "Add an email address to /root/.forward to receive root's mail.",
		})
	default:
		emit(store.AuditResult{
			Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
			Status: "pass", Message: "Root mail forwarding is configured",
		})
	}

	// Exim's effective configuration feeds the next two checks. An empty
	// string means it could not be queried.
	var eximConfig string
	if raw, err := auditRunCmd("exim", "-bP"); err == nil {
		eximConfig = string(raw)
	}
	eximLower := strings.ToLower(eximConfig)

	// mail_exim_logging
	switch {
	case eximConfig == "":
		emit(store.AuditResult{
			Name: "mail_exim_logging", Title: "Exim Argument Logging",
			Status: "warn", Message: "Cannot query exim configuration",
		})
	case strings.Contains(eximLower, "+arguments") || strings.Contains(eximLower, "+all"):
		emit(store.AuditResult{
			Name: "mail_exim_logging", Title: "Exim Argument Logging",
			Status: "pass", Message: "Exim logs include +arguments",
		})
	default:
		emit(store.AuditResult{
			Name: "mail_exim_logging", Title: "Exim Argument Logging",
			Status: "warn", Message: "Exim log_selector does not include +arguments",
			Fix: "Add '+arguments' to log_selector in exim configuration for better forensics.",
		})
	}

	// mail_exim_tls: "+no_sslv2" in openssl_options means SSLv2 is DISABLED
	// (good). Fail only when "sslv2" is mentioned and no "no_sslv2" appears
	// anywhere; searching for "no_sslv2" also covers the "+no_sslv2" spelling.
	if eximConfig != "" {
		mentionsSSLv2 := strings.Contains(eximLower, "sslv2")
		explicitlyOff := strings.Contains(eximLower, "no_sslv2")
		if !mentionsSSLv2 || explicitlyOff {
			emit(store.AuditResult{
				Name: "mail_exim_tls", Title: "Exim TLS Ciphers",
				Status: "pass", Message: "SSLv2 is disabled in exim TLS configuration",
			})
		} else {
			emit(store.AuditResult{
				Name: "mail_exim_tls", Title: "Exim TLS Ciphers",
				Status: "fail", Message: "Exim allows SSLv2 connections",
				Fix: "Add '+no_sslv2' to openssl_options in exim configuration.",
			})
		}
	}

	// mail_secure_auth: check /etc/exim.conf.localopts
	localopts, localoptsErr := osFS.ReadFile("/etc/exim.conf.localopts")
	switch {
	case localoptsErr != nil:
		emit(store.AuditResult{
			Name: "mail_secure_auth", Title: "Exim Secure Authentication",
			Status: "pass", Message: "No local exim overrides file found (default is secure)",
		})
	case strings.Contains(string(localopts), "require_secure_auth=0"):
		emit(store.AuditResult{
			Name: "mail_secure_auth", Title: "Exim Secure Authentication",
			Status: "fail", Message: "require_secure_auth is disabled in /etc/exim.conf.localopts",
			Fix: "Remove or set require_secure_auth=1 in /etc/exim.conf.localopts.",
		})
	default:
		emit(store.AuditResult{
			Name: "mail_secure_auth", Title: "Exim Secure Authentication",
			Status: "pass", Message: "Secure authentication is not disabled",
		})
	}

	// mail_dovecot_tls: check ssl_min_protocol.
	// setsModernMin reports whether any line of text sets ssl_min_protocol to
	// a value containing TLSv1.2 or TLSv1.3. Lines starting with "#" can never
	// carry the ssl_min_protocol prefix, so comments need no separate skip.
	setsModernMin := func(text string) bool {
		modern := false
		for _, raw := range strings.Split(text, "\n") {
			entry := strings.TrimSpace(raw)
			if !strings.HasPrefix(entry, "ssl_min_protocol") {
				continue
			}
			val := strings.TrimSpace(strings.TrimPrefix(entry, "ssl_min_protocol"))
			val = strings.TrimLeft(val, "= ")
			if strings.Contains(val, "TLSv1.2") || strings.Contains(val, "TLSv1.3") {
				modern = true
			}
		}
		return modern
	}

	// Use 'doveconf -a' for the effective config — cPanel manages Dovecot
	// settings outside the standard config files, so file parsing misses
	// them. Routed through cmdExec so tests can mock the doveconf output.
	dovecotTLS := false
	if out, err := cmdExec.Run("doveconf", "-a"); err == nil {
		dovecotTLS = setsModernMin(string(out))
	} else {
		// Fallback: try config files. Both are read even if the first
		// already satisfies the check.
		for _, path := range []string{"/etc/dovecot/conf.d/10-ssl.conf", "/etc/dovecot/dovecot.conf"} {
			data, readErr := osFS.ReadFile(path)
			if readErr != nil {
				continue
			}
			if setsModernMin(string(data)) {
				dovecotTLS = true
			}
		}
	}

	if dovecotTLS {
		emit(store.AuditResult{
			Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
			Status: "pass", Message: "Dovecot ssl_min_protocol is TLSv1.2 or higher",
		})
	} else if _, err := osFS.Stat("/etc/dovecot/dovecot.conf"); err != nil {
		emit(store.AuditResult{
			Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
			Status: "warn", Message: "Dovecot configuration not found",
		})
	} else {
		emit(store.AuditResult{
			Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
			Status: "fail", Message: "Dovecot ssl_min_protocol not set to TLSv1.2 or higher",
			Fix: "Set 'ssl_min_protocol = TLSv1.2' in /etc/dovecot/conf.d/10-ssl.conf.",
		})
	}
	return results
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// CheckHealth verifies that CSM's own dependencies are working and returns
// one finding per problem, all under the "csm_health" check name.
//
// It reports, in this order: missing required commands for the detected
// platform, missing optional commands, auditd rules that do not include
// "csm_shadow_change", a BPF direct-SMTP-egress setting that is enabled while
// the connection tracker is on the legacy backend or has none, and a state
// directory that cannot be written to.
//
// cfg may be nil, in which case the BPF check is skipped and the default state
// directory is probed. ctx and the state store are not used by this function.
func CheckHealth(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	info := platform.Detect()

	// Required commands depend on the platform. On plain Linux hosts we
	// don't need Exim/cPanel-specific tooling.
	for _, cmd := range platformRequiredCommands(info) {
		if _, err := cmdExec.LookPath(cmd); err != nil {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  fmt.Sprintf("Required command not found: %s", cmd),
				Details:  "Some checks will be skipped",
			})
		}
	}

	// Optional commands. Only complain about cPanel tools on cPanel hosts.
	// Kept in a slice rather than a map so findings come out in the same
	// order on every run; map iteration order is randomised.
	type optionalCmd struct {
		name, impact string
	}
	optionalCmds := []optionalCmd{
		{"wp", "WordPress core integrity check will be skipped"},
	}
	if info.IsCPanel() {
		optionalCmds = append(optionalCmds, optionalCmd{"whmapi1", "WHM API token check will be skipped"})
	}
	for _, oc := range optionalCmds {
		if _, err := cmdExec.LookPath(oc.name); err != nil {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  fmt.Sprintf("Optional command not found: %s", oc.name),
				Details:  oc.impact,
			})
		}
	}

	// Check auditd is running and has CSM rules. The error is deliberately
	// ignored: a nil output means there is nothing to inspect, and a missing
	// auditctl binary is already reported by the required-command loop.
	out, _ := runCmd("auditctl", "-l")
	if out != nil {
		rules := string(out)
		if !strings.Contains(rules, "csm_shadow_change") {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  "auditd CSM rules not loaded",
				Details:  "Run 'csm install' to deploy auditd rules, then 'service auditd restart'",
			})
		}
	}

	if cfg != nil && cfg.BPFEnforcement.Enabled && cfg.BPFEnforcement.DirectSMTPEgress {
		switch active := bpf.ActiveKind("connection_tracker"); active {
		case bpf.BackendLegacy, bpf.BackendNone:
			message := "BPF enforcement enabled but connection tracker is running on legacy backend"
			if active == bpf.BackendNone {
				message = "BPF enforcement enabled but connection tracker has no active backend"
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  message,
				Details:  "bpf_enforcement.direct_smtp_egress requires the connection tracker BPF backend. Check kernel version, LSM availability, or CAP_BPF.",
			})
		}
	}

	// Check state directory is writable by creating and removing a probe file.
	stateDir := "/var/lib/csm/state"
	if cfg != nil && cfg.StatePath != "" {
		stateDir = cfg.StatePath
	}
	testFile := filepath.Join(stateDir, ".health_check")
	if err := osFS.WriteFile(testFile, []byte("ok"), 0600); err != nil {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "csm_health",
			Message:  fmt.Sprintf("State directory not writable: %s", stateDir),
			Details:  err.Error(),
		})
	} else {
		// Best-effort cleanup; a leftover probe file is harmless.
		_ = osFS.Remove(testFile)
	}
	return findings
}
// platformRequiredCommands lists the external commands CSM needs on the
// detected platform: "find" and "auditctl" everywhere, plus "exim" when the
// host is a cPanel server.
func platformRequiredCommands(info platform.Info) []string {
	if !info.IsCPanel() {
		return []string{"find", "auditctl"}
	}
	return []string{"find", "auditctl", "exim"}
}
package checks
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
)
// cmdTimeout is the deadline the run*Real helpers in this file put on every
// external command, via context.WithTimeout.
const cmdTimeout = 2 * time.Minute
// systemCommandSearchDirs lists the directories lookupSystemCommand probes, in
// this order, when exec.LookPath cannot find a bare command name.
var systemCommandSearchDirs = []string{
"/usr/local/sbin",
"/usr/sbin",
"/sbin",
"/usr/local/bin",
"/usr/bin",
"/bin",
}
// hashFileContent reads the whole file at path through osFS and returns the
// lowercase hex SHA-256 of its bytes. A read failure is returned unchanged
// with an empty digest.
func hashFileContent(path string) (string, error) {
	raw, readErr := osFS.ReadFile(path)
	if readErr != nil {
		return "", readErr
	}
	return fmt.Sprintf("%x", sha256.Sum256(raw)), nil
}
// hashBytes returns the lowercase hex SHA-256 digest of data.
func hashBytes(data []byte) string {
	return fmt.Sprintf("%x", sha256.Sum256(data))
}
// runCmd runs name with args through the package-level cmdExec provider and
// hands back whatever it returns. Check functions call runCmd; tests swap
// cmdExec via SetCmdRunner.
func runCmd(name string, args ...string) ([]byte, error) {
	output, err := cmdExec.Run(name, args...)
	return output, err
}
// runCmdAllowNonZero forwards to cmdExec.RunAllowNonZero and returns its
// result unchanged.
func runCmdAllowNonZero(name string, args ...string) ([]byte, error) {
	output, err := cmdExec.RunAllowNonZero(name, args...)
	return output, err
}
// runCmdCombinedContext forwards to cmdExec.RunContext with the caller's
// context and returns its result unchanged.
func runCmdCombinedContext(parent context.Context, name string, args ...string) ([]byte, error) {
	output, err := cmdExec.RunContext(parent, name, args...)
	return output, err
}
// lookupSystemCommand resolves name to an executable path.
//
// exec.LookPath is tried first. If it succeeds, or if name contains a path
// separator, its result is returned as is. For a bare name that LookPath could
// not find, each directory in systemCommandSearchDirs is probed for a
// non-directory entry with at least one execute bit set. When nothing
// matches, the original LookPath error is returned with an empty path.
func lookupSystemCommand(name string) (string, error) {
	resolved, lookErr := exec.LookPath(name)
	if lookErr == nil {
		return resolved, nil
	}
	if strings.ContainsRune(name, os.PathSeparator) {
		// Explicit paths get no fallback search.
		return resolved, lookErr
	}
	for _, dir := range systemCommandSearchDirs {
		candidate := filepath.Join(dir, name)
		fi, statErr := os.Stat(candidate)
		if statErr != nil || fi.IsDir() {
			continue
		}
		if fi.Mode()&0111 != 0 {
			return candidate, nil
		}
	}
	return "", lookErr
}
// resolveSystemCommand returns the path lookupSystemCommand finds for name,
// or name itself when the lookup fails.
func resolveSystemCommand(name string) string {
	if resolved, err := lookupSystemCommand(name); err == nil {
		return resolved
	}
	return name
}
// ---------------------------------------------------------------------------
// Real implementations — used by realCmd in provider.go
//
// Every caller of these helpers passes a constant system command name
// (nft, rpm, wp-cli, doveadm, systemctl, etc.) with arguments built from
// either config, filesystem state, or CSM-generated data. Nothing here
// reaches out to HTTP request bodies or webui form inputs. The gosec
// G204 suppressions below are in that trust model.
// ---------------------------------------------------------------------------
// runCmdReal runs the resolved command with a cmdTimeout deadline and returns
// its stdout together with the error from exec. When the deadline is hit, a
// notice is written to stderr and the function returns (nil, nil).
func runCmdReal(name string, args ...string) ([]byte, error) {
	runCtx, release := context.WithTimeout(context.Background(), cmdTimeout)
	defer release()
	// #nosec G204 -- see package-level trust note above.
	stdout, runErr := exec.CommandContext(runCtx, resolveSystemCommand(name), args...).Output()
	if runCtx.Err() != context.DeadlineExceeded {
		return stdout, runErr
	}
	fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
	return nil, nil
}
// runCmdAllowNonZeroReal behaves like runCmdReal but treats a non-zero exit
// status as success: when the error is an *exec.ExitError, stdout is returned
// with a nil error. Other errors are passed through. On timeout a notice is
// written to stderr and the function returns (nil, nil).
func runCmdAllowNonZeroReal(name string, args ...string) ([]byte, error) {
	runCtx, release := context.WithTimeout(context.Background(), cmdTimeout)
	defer release()
	// #nosec G204 -- see package-level trust note above.
	stdout, runErr := exec.CommandContext(runCtx, resolveSystemCommand(name), args...).Output()
	if runCtx.Err() == context.DeadlineExceeded {
		fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
		return nil, nil
	}
	if runErr == nil {
		return stdout, nil
	}
	var exitErr *exec.ExitError
	if errors.As(runErr, &exitErr) {
		// The command ran and exited non-zero; the caller wants its output.
		return stdout, nil
	}
	return stdout, runErr
}
// runCmdCombinedContextReal runs the resolved command under parent with an
// additional cmdTimeout deadline and returns stdout and stderr interleaved.
//
// If the derived context reports DeadlineExceeded, a notice is written to
// stderr and (nil, nil) is returned. Otherwise, if parent has ended, its error
// is returned with nil output. In all other cases the command's output and
// error are returned as is.
func runCmdCombinedContextReal(parent context.Context, name string, args ...string) ([]byte, error) {
	runCtx, release := context.WithTimeout(parent, cmdTimeout)
	defer release()
	// #nosec G204 -- see package-level trust note above.
	combined, runErr := exec.CommandContext(runCtx, resolveSystemCommand(name), args...).CombinedOutput()
	switch {
	case runCtx.Err() == context.DeadlineExceeded:
		fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
		return nil, nil
	case parent.Err() != nil:
		return nil, parent.Err()
	default:
		return combined, runErr
	}
}
// runCmdStdoutContextReal runs a command under parent with a cmdTimeout
// deadline and returns stdout only. Stderr is not captured, so chatter from
// the child process (PHP warnings, MySQL deprecation notices, wp-cli plugin
// backtraces, ...) cannot poison parsers that expect JSON/URL bytes on stdout.
//
// Unlike the other helpers here, a timeout is reported as
// context.DeadlineExceeded instead of a silent (nil, nil), so callers can
// tell a hung command from a legitimately empty result. If parent has ended
// for another reason, its error is returned with nil output.
func runCmdStdoutContextReal(parent context.Context, name string, args ...string) ([]byte, error) {
	runCtx, release := context.WithTimeout(parent, cmdTimeout)
	defer release()
	// #nosec G204 -- see package-level trust note above.
	stdout, runErr := exec.CommandContext(runCtx, resolveSystemCommand(name), args...).Output()
	switch {
	case runCtx.Err() == context.DeadlineExceeded:
		return nil, context.DeadlineExceeded
	case parent.Err() != nil:
		return nil, parent.Err()
	default:
		return stdout, runErr
	}
}
// runCmdWithEnvReal runs the resolved command with the current process
// environment plus extraEnv appended, under a cmdTimeout deadline, and returns
// its stdout. On timeout a notice naming only the command (not its arguments)
// is written to stderr and (nil, nil) is returned.
func runCmdWithEnvReal(name string, args []string, extraEnv ...string) ([]byte, error) {
	runCtx, release := context.WithTimeout(context.Background(), cmdTimeout)
	defer release()
	// #nosec G204 -- see package-level trust note above.
	child := exec.CommandContext(runCtx, resolveSystemCommand(name), args...)
	child.Env = append(os.Environ(), extraEnv...)
	stdout, runErr := child.Output()
	if runCtx.Err() != context.DeadlineExceeded {
		return stdout, runErr
	}
	fmt.Fprintf(os.Stderr, "Command timed out: %s\n", name)
	return nil, nil
}
package checks
import (
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// .htaccess hardened detection / cleaning.
//
// Detection emits one of seven specific finding names so operators
// can suppress, route, or auto-respond per attack pattern instead of
// relying on the generic htaccess_injection / htaccess_handler_abuse
// categories. Cleaning is gated by AutoResponse.CleanHtaccess and is
// always backed up under /opt/csm/quarantine/pre_clean/.
//
// Each detector returns matches as byte ranges into the file
// content; the cleaner merges all ranges (deduplicating overlaps),
// removes them, and writes the result atomically. If post-clean
// content is identical to pre-clean (no detector matched anything
// new), no write happens and no backup is created.
// htaccessBackupDirRoot is the parent directory under which
// CleanHtaccessFile writes <ts>_<sanitized-path> backups, each with
// a "<backup>.meta" JSON sidecar next to it. CleanHtaccessFile
// creates the directory (mode 0750) if it does not exist. Exposed
// as a package var so tests can redirect it to a t.TempDir().
var htaccessBackupDirRoot = "/opt/csm/quarantine/pre_clean"
// htaccessByteRange is a half-open [start, end) byte slice into the
// file content. Cleaning removes the bytes; the line ending after
// `end` is included when `end` falls just before a `\n` so we do
// not leave a blank line behind.
type htaccessByteRange struct {
// Start is the offset of the first byte in the range (inclusive).
Start int
// End is the offset one past the last byte in the range (exclusive).
End int
}
// htaccessMatch is one finding-worthy hit returned by a detector.
// AuditHtaccessFile turns each match into one alert.Finding and
// collects its Range for removal.
type htaccessMatch struct {
// Range is the span of file content the detector wants removed.
Range htaccessByteRange
Excerpt string // the offending line(s), trimmed for the finding details
}
// htaccessDetector pairs a finding category with a function that
// scans the content and reports every range the detector wants
// removed. Adding an 8th pattern is one entry in this slice.
type htaccessDetector struct {
// Name is used as the Check value of every finding this detector produces.
Name string
// Severity is copied onto every finding this detector produces.
Severity alert.Severity
// Detect receives the raw file content and the file's path and returns
// one htaccessMatch per hit.
Detect func(content []byte, path string) []htaccessMatch
}
// htaccessSpamTLDs lists TLDs commonly abused for spam-redirect
// .htaccess injections. Operators on legitimate hosts in these TLDs
// can suppress the finding by file path.
// Every entry is lowercase and starts with a dot.
// NOTE(review): the consumer is not visible from here — presumably
// detectSpamRedirect; confirm how the match is anchored.
var htaccessSpamTLDs = []string{
".xyz", ".tk", ".ml", ".ga", ".cf", ".gq", ".click",
".country", ".loan", ".work", ".top",
}
// htaccessNonScriptDirHints names directory components where PHP
// execution is rarely legitimate. .htaccess files inside one of
// these get the htaccess_php_in_uploads finding when they map
// non-PHP extensions to a PHP handler.
//
// Each hint is written with a leading and trailing slash.
//
// /tmp/ is intentionally NOT in this list even though attackers
// drop payloads there: a real-world .htaccess inside /tmp/ would
// only be reached if the webserver served /tmp/, which is rare
// outside misconfigurations -- and including it caused
// false-positive matches against Linux t.TempDir() paths under
// /tmp/ in the test suite. The auto_prepend detector covers the
// /tmp/ payload-target angle separately.
var htaccessNonScriptDirHints = []string{
"/uploads/", "/images/", "/cache/",
"/wp-content/uploads/", "/wp-content/cache/",
"/files/", "/media/",
}
// htaccessSuspiciousAutoPrependPaths are filesystem locations that
// auto_prepend_file should never reference. Anything in /tmp/,
// /dev/shm/, /var/tmp/ or pointing at an image extension is
// always-malicious. The image extensions themselves are listed
// separately in htaccessImageExtensions.
var htaccessSuspiciousAutoPrependPaths = []string{
"/tmp/", "/dev/shm/", "/var/tmp/",
}
// htaccessImageExtensions lists image file suffixes, lowercase with a leading
// dot.
// NOTE(review): presumably consulted by the auto_prepend detector to flag
// targets that point at an image file — confirm against its implementation.
var htaccessImageExtensions = []string{".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".ico"}
// htaccessTrackingHeaders is a small allowlist of header *names*
// known to be used in injection campaigns. Scoped intentionally;
// false positives on legitimate analytics/CDN headers are worse
// than missed detections here.
// NOTE(review): most entries end in "-", which suggests they are matched as
// name prefixes, while "X-Click-ID" looks like a full name — confirm against
// the header-injection detector.
var htaccessTrackingHeaders = []string{
"X-Track-", "X-Affiliate-", "X-Promo-", "X-Click-ID",
}
// Line-oriented patterns for the .htaccess detectors. All but crawlerUARegex
// are compiled with (?im): case-insensitive, with ^ and $ matching at line
// boundaries, and each allows leading whitespace before the directive.
var (
// rePHPHandlerMap matches an AddHandler, SetHandler or ForceType line whose
// first argument contains "php". Group 1 is the directive, group 2 the rest
// of the line.
rePHPHandlerMap = regexp.MustCompile(`(?im)^\s*(AddHandler|SetHandler|ForceType)\s+\S*php\S*\s+([^\n]+)$`)
// Match both forms because mod_php and some LSAPI builds honor either
// directive in .htaccess.
// reAutoPrepend captures the auto_prepend_file target in group 1.
reAutoPrepend = regexp.MustCompile(`(?im)^\s*php(?:_admin)?_value\s+auto_prepend_file\s+(\S+)`)
// reUACloakCond matches a RewriteCond on %{HTTP_USER_AGENT}. Group 1 is
// the rest of the line.
reUACloakCond = regexp.MustCompile(`(?im)^\s*RewriteCond\s+%\{HTTP_USER_AGENT\}\s+([^\n]+)`)
// reSpamRedirect matches a RewriteRule whose substitution is an absolute
// http(s) URL. Group 1 is the URL, ending at whitespace or "[".
reSpamRedirect = regexp.MustCompile(`(?im)^\s*RewriteRule\s+\S+\s+(https?://[^\s\[]+)`)
// reFilesMatchOpen matches an opening <FilesMatch> tag whose pattern
// contains an escaped dot followed by php, phtml or ph2..ph7.
reFilesMatchOpen = regexp.MustCompile(`(?im)^\s*<FilesMatch\s+["']?[^"'>]*\\\.(php|phtml|ph[2-7])[^"'>]*["']?\s*>`)
// reFilesMatchClose matches the closing </FilesMatch> tag.
reFilesMatchClose = regexp.MustCompile(`(?im)^\s*</FilesMatch>`)
// reHeaderSetAdd matches "Header set" / "Header add". Group 1 is the
// action, group 2 the header name.
reHeaderSetAdd = regexp.MustCompile(`(?im)^\s*Header\s+(set|add)\s+([A-Za-z0-9_-]+)`)
// reErrorDocument matches an ErrorDocument whose target is an absolute
// http(s) URL. Group 1 is the URL.
reErrorDocument = regexp.MustCompile(`(?im)^\s*ErrorDocument\s+\d+\s+(https?://[^\s]+)`)
// crawlerUARegex matches the UA strings frequently used in cloak
// conditions: search-engine bots and social-share scrapers. Used
// as a positive filter on the htaccess_user_agent_cloak finding;
// matching one of these names is what makes a UA-keyed redirect
// suspicious.
crawlerUARegex = regexp.MustCompile(`(?i)(googlebot|bingbot|baiduspider|yandex|facebookexternalhit|slurp|duckduckbot)`)
)
// htaccessDetectors is the registry. AuditHtaccessFile runs the
// detectors in slice order, so this order fixes the order findings
// are emitted in.
// NOTE(review): the earlier wording said "the first detector to see
// a line wins", but AuditHtaccessFile emits a finding for every match
// from every detector — confirm whether any de-duplication happens
// elsewhere.
var htaccessDetectors = []htaccessDetector{
{
Name: "htaccess_php_in_uploads",
Severity: alert.Critical,
Detect: detectPHPInUploads,
},
{
Name: "htaccess_auto_prepend",
Severity: alert.Critical,
Detect: detectAutoPrepend,
},
{
Name: "htaccess_user_agent_cloak",
Severity: alert.High,
Detect: detectUserAgentCloak,
},
{
Name: "htaccess_spam_redirect",
Severity: alert.High,
Detect: detectSpamRedirect,
},
{
Name: "htaccess_filesmatch_shield",
Severity: alert.Critical,
Detect: detectFilesMatchShield,
},
{
Name: "htaccess_header_injection",
Severity: alert.High,
Detect: detectHeaderInjection,
},
{
Name: "htaccess_errordocument_hijack",
Severity: alert.High,
Detect: detectErrorDocumentHijack,
},
}
// AuditHtaccessFile runs every registered detector against the file
// at path. It returns one alert finding per match and the merged byte
// ranges the cleaner would remove. The two outputs are built from the
// same matches so cleaning never disagrees with what the operator was
// alerted about.
//
// Both results are nil when path's base name is not ".htaccess" or
// the file cannot be read.
func AuditHtaccessFile(path string) ([]alert.Finding, []htaccessByteRange) {
	if filepath.Base(path) != ".htaccess" {
		return nil, nil
	}
	// #nosec G304 -- path resolved by the caller via ResolveWebRoots / scanHtaccess; the operator's file tree, not attacker input.
	content, err := os.ReadFile(path)
	if err != nil {
		return nil, nil
	}

	var (
		findings []alert.Finding
		ranges   []htaccessByteRange
	)
	for i := range htaccessDetectors {
		det := &htaccessDetectors[i]
		for _, hit := range det.Detect(content, path) {
			findings = append(findings, alert.Finding{
				Severity:  det.Severity,
				Check:     det.Name,
				Message:   fmt.Sprintf("%s in %s", det.Name, path),
				Details:   fmt.Sprintf("File: %s\nMatch: %s", path, hit.Excerpt),
				FilePath:  path,
				Timestamp: time.Now(),
			})
			ranges = append(ranges, hit.Range)
		}
	}
	return findings, mergeRanges(ranges)
}
// CleanHtaccessFile runs the registered detectors over the file,
// computes the removal range set, backs up the original, and writes
// the trimmed content. It returns Success=false with no Action when
// no detector matched (i.e., nothing to clean) or when any step fails;
// the Error field says which.
//
// Order of operations: read the file once, detect on those bytes, write
// the backup and its ".meta" sidecar under htaccessBackupDirRoot, then
// write "<file>.csm-clean.tmp" and rename it over the original. The
// original is never modified before its backup is on disk.
//
// Caller is responsible for gating on cfg.AutoResponse.CleanHtaccess
// before invoking; this function will clean unconditionally if
// detectors find anything.
func CleanHtaccessFile(path string) RemediationResult {
	if filepath.Base(path) != ".htaccess" {
		return RemediationResult{Error: "automated .htaccess remediation only applies to .htaccess files"}
	}
	resolved, _, err := resolveExistingFixPath(path, fixHtaccessAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	// #nosec G304 -- resolved path verified inside the allowed roots.
	original, err := os.ReadFile(resolved)
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot read: %v", err)}
	}

	// Detect on the bytes already in memory instead of letting
	// AuditHtaccessFile re-read the file: ranges are offsets into the exact
	// content they were computed from, and a second read could see a file
	// that changed in between. The base-name guard mirrors AuditHtaccessFile,
	// which yields no ranges for a resolved path not named ".htaccess".
	var ranges []htaccessByteRange
	if filepath.Base(resolved) == ".htaccess" {
		for _, d := range htaccessDetectors {
			for _, m := range d.Detect(original, resolved) {
				ranges = append(ranges, m.Range)
			}
		}
		ranges = mergeRanges(ranges)
	}
	if len(ranges) == 0 {
		return RemediationResult{Error: "no malicious directives found to remove"}
	}
	cleaned := applyRangeRemoval(original, ranges)
	if len(cleaned) == len(original) {
		return RemediationResult{Error: "no bytes removed (range computation produced empty diff)"}
	}

	backupDir := htaccessBackupDirRoot
	if err = os.MkdirAll(backupDir, 0750); err != nil {
		return RemediationResult{Error: fmt.Sprintf("creating backup dir: %v", err)}
	}
	stamp := time.Now().UTC().Format("20060102T150405Z")
	backupPath := filepath.Join(backupDir, fmt.Sprintf("%s_%s", stamp, sanitizePathForBackup(resolved)))
	// #nosec G306 G703 -- 0640 matches the rest of pre_clean/. backupPath is filepath.Join(backupDir, <ts>_<sanitizePathForBackup>) where sanitizePathForBackup strips every / and .. so the result cannot escape backupDir; resolved itself was validated by resolveExistingFixPath (fixHtaccessAllowedRoots).
	if err = os.WriteFile(backupPath, original, 0640); err != nil {
		return RemediationResult{Error: fmt.Sprintf("writing backup: %v", err)}
	}

	// .meta written as JSON in the same shape as autoresponse.go's
	// QuarantineMeta so the existing /api/v1/quarantine listing and
	// /api/v1/quarantine-restore handlers pick up htaccess pre_clean
	// backups without a parallel codepath. The early implementation
	// used a plain key=value sidecar; nothing in the pipeline read
	// that, which made htaccess backups invisible in the UI.
	metaPath := backupPath + ".meta"
	metaJSON, err := json.Marshal(QuarantineMeta{
		OriginalPath: resolved,
		Size:         int64(len(original)),
		QuarantineAt: time.Now().UTC(),
		Reason:       fmt.Sprintf("htaccess clean: %d ranges removed (%d -> %d bytes)", len(ranges), len(original), len(cleaned)),
	})
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("encoding backup meta: %v", err)}
	}
	// #nosec G306 -- sidecar meta; 0640 matches the backup file mode.
	if err := os.WriteFile(metaPath, metaJSON, 0640); err != nil {
		return RemediationResult{Error: fmt.Sprintf("writing backup meta: %v", err)}
	}

	// Write next to the target and rename so the live file is replaced in
	// one step.
	tmp := resolved + ".csm-clean.tmp"
	// #nosec G306 G703 -- 0644 is what the webserver expects for static content. tmp = resolved + ".csm-clean.tmp"; resolved was validated by resolveExistingFixPath against fixHtaccessAllowedRoots so path traversal is impossible.
	if err := os.WriteFile(tmp, cleaned, 0644); err != nil {
		// Do not leave a partially written tmp file in the web tree.
		_ = os.Remove(tmp)
		return RemediationResult{Error: fmt.Sprintf("writing cleaned tmp: %v", err)}
	}
	if err := os.Rename(tmp, resolved); err != nil {
		_ = os.Remove(tmp)
		return RemediationResult{Error: fmt.Sprintf("atomic rename: %v", err)}
	}

	bytesRemoved := len(original) - len(cleaned)
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("removed %d malicious byte(s) from %s", bytesRemoved, resolved),
		Description: fmt.Sprintf("Cleaned .htaccess: %d ranges, %d bytes removed (backup: %s)", len(ranges), bytesRemoved, backupPath),
	}
}
// mergeRanges returns a normalised copy of in: ranges are ordered by
// Start, and any range that overlaps or touches the previous one is
// folded into it. The input slice is left untouched, so the result
// does not depend on the order detectors reported their ranges in.
// An empty input yields nil.
func mergeRanges(in []htaccessByteRange) []htaccessByteRange {
	n := len(in)
	if n == 0 {
		return nil
	}
	sorted := make([]htaccessByteRange, n)
	copy(sorted, in)
	sort.Slice(sorted, func(a, b int) bool { return sorted[a].Start < sorted[b].Start })

	merged := []htaccessByteRange{sorted[0]}
	for k := 1; k < n; k++ {
		cur := sorted[k]
		tail := len(merged) - 1
		switch {
		case cur.Start > merged[tail].End:
			// Gap before cur: it opens a new merged range.
			merged = append(merged, cur)
		case cur.End > merged[tail].End:
			// Overlapping or adjacent and reaching further: extend.
			merged[tail].End = cur.End
		}
	}
	return merged
}
// applyRangeRemoval returns a copy of content with every range in
// ranges cut out. Ranges are expected to be sorted and merged (see
// mergeRanges). When a range's End lands on a '\n' that newline is
// removed too, so no blank line is left behind.
//
// The function is defensive about ranges that do not fit content
// (for example when they were computed against a different read of
// the file): End is clamped to len(content), a Start before the
// current cursor is pulled forward to it, and a range that ends
// before the cursor is skipped. The cursor therefore never moves
// backwards, so already-removed bytes are never emitted again and
// out-of-range offsets cannot cause a slice panic.
func applyRangeRemoval(content []byte, ranges []htaccessByteRange) []byte {
	out := make([]byte, 0, len(content))
	cursor := 0
	for _, r := range ranges {
		start, end := r.Start, r.End
		if end > len(content) {
			end = len(content)
		}
		if start < cursor {
			start = cursor
		}
		if end < start {
			// Entirely behind the cursor (or inverted): nothing to drop.
			continue
		}
		out = append(out, content[cursor:start]...)
		if end < len(content) && content[end] == '\n' {
			end++
		}
		cursor = end
	}
	if cursor < len(content) {
		out = append(out, content[cursor:]...)
	}
	return out
}
// htaccessPhysicalByteLine is one newline-delimited line of a file
// together with its byte offsets in the original content, as produced
// by splitHtaccessPhysicalByteLines.
type htaccessPhysicalByteLine struct {
// text is the line's content without the terminating '\n'.
text string
// start is the byte offset of the line's first byte.
start int
// end is the byte offset just past the line's last byte: the position
// of the terminating '\n', or len(content) for an unterminated line.
end int
}
// htaccessLogicalByteLine is one logical line: one or more physical
// lines joined by htaccessLogicalByteLines whenever
// htaccessContinuationBody reports that a line continues.
type htaccessLogicalByteLine struct {
// text is the concatenation of the bodies htaccessContinuationBody
// returned for each joined physical line.
text string
// span runs from the first joined physical line's start to the last
// joined physical line's end, as byte offsets into the content.
span htaccessByteRange
}
// splitHtaccessPhysicalByteLines cuts content at every '\n' and
// returns one entry per line with its text (newline excluded) and its
// [start, end) byte offsets. The segment after the last '\n' is always
// emitted, even when it is empty, so empty content yields a single
// empty line and content ending in '\n' yields a trailing empty line.
func splitHtaccessPhysicalByteLines(content []byte) []htaccessPhysicalByteLine {
	var lines []htaccessPhysicalByteLine
	lineStart := 0
	for pos, b := range content {
		if b != '\n' {
			continue
		}
		lines = append(lines, htaccessPhysicalByteLine{
			text:  string(content[lineStart:pos]),
			start: lineStart,
			end:   pos,
		})
		lineStart = pos + 1
	}
	// Final segment: whatever follows the last newline (possibly nothing).
	return append(lines, htaccessPhysicalByteLine{
		text:  string(content[lineStart:]),
		start: lineStart,
		end:   len(content),
	})
}
// htaccessLogicalByteLines groups the physical lines of content into
// logical lines. Each physical line is passed to
// htaccessContinuationBody together with a flag saying whether another
// physical line follows; while it reports a continuation, the next
// physical line is appended to the same logical line. The resulting
// span covers the first joined line's start through the last joined
// line's end.
func htaccessLogicalByteLines(content []byte) []htaccessLogicalByteLine {
	physical := splitHtaccessPhysicalByteLines(content)
	last := len(physical) - 1
	var logical []htaccessLogicalByteLine
	idx := 0
	for idx < len(physical) {
		first := idx
		var text strings.Builder
		for {
			body, more := htaccessContinuationBody(physical[idx].text, idx < last)
			text.WriteString(body)
			if !more {
				break
			}
			idx++
		}
		logical = append(logical, htaccessLogicalByteLine{
			text: text.String(),
			span: htaccessByteRange{Start: physical[first].start, End: physical[idx].end},
		})
		idx++
	}
	return logical
}
// matchesFromLogicalLineRegex runs re against the text of every
// logical line of content and returns one match per hit. The match
// range is the logical line's byte span and the excerpt is taken from
// the same span of the original content.
func matchesFromLogicalLineRegex(content []byte, re *regexp.Regexp) []htaccessMatch {
	var hits []htaccessMatch
	for _, line := range htaccessLogicalByteLines(content) {
		if re.MatchString(line.text) {
			hits = append(hits, htaccessMatch{
				Range:   line.span,
				Excerpt: trimExcerpt(content, line.span.Start, line.span.End),
			})
		}
	}
	return hits
}
// sanitizePathForBackup flattens a path into a single file-name-safe
// token: every '/', '\\', ' ' and ':' becomes '_', and one leading '_'
// (left behind by an absolute path's root separator) is dropped.
func sanitizePathForBackup(p string) string {
	flat := []byte(p)
	for i, c := range flat {
		switch c {
		case '/', '\\', ' ', ':':
			flat[i] = '_'
		}
	}
	return strings.TrimPrefix(string(flat), "_")
}
// detectPHPInUploads reports logical lines matching rePHPHandlerMap
// (AddHandler/SetHandler/ForceType mappings to PHP), but only when the
// .htaccess path lies in a directory where PHP execution is rarely
// legitimate, as decided by pathInNonScriptDir. For any other path it
// returns nil without scanning.
func detectPHPInUploads(content []byte, path string) []htaccessMatch {
	if pathInNonScriptDir(path) {
		return matchesFromLogicalLineRegex(content, rePHPHandlerMap)
	}
	return nil
}
// pathInNonScriptDir reports whether the lower-cased path contains any
// of the substrings listed in htaccessNonScriptDirHints.
func pathInNonScriptDir(path string) bool {
	needle := strings.ToLower(path)
	for i := range htaccessNonScriptDirHints {
		if strings.Contains(needle, htaccessNonScriptDirHints[i]) {
			return true
		}
	}
	return false
}
// detectAutoPrepend reports every reAutoPrepend match whose first
// capture group (the auto_prepend_file target) is judged suspicious by
// autoPrependTargetSuspicious. The reported range is the full line(s)
// containing the match. The path argument is unused.
func detectAutoPrepend(content []byte, _ string) []htaccessMatch {
	var hits []htaccessMatch
	for _, m := range reAutoPrepend.FindAllSubmatchIndex(content, -1) {
		if len(m) < 4 {
			continue
		}
		matchStart, matchEnd := m[0], m[1]
		if autoPrependTargetSuspicious(string(content[m[2]:m[3]])) {
			hits = append(hits, htaccessMatch{
				Range:   lineRange(content, matchStart, matchEnd),
				Excerpt: trimExcerpt(content, matchStart, matchEnd),
			})
		}
	}
	return hits
}
// autoPrependTargetSuspicious normalises target (trim whitespace,
// strip surrounding quotes, lower-case) and reports whether it starts
// with one of htaccessSuspiciousAutoPrependPaths or ends with one of
// htaccessImageExtensions.
func autoPrependTargetSuspicious(target string) bool {
	cleaned := strings.ToLower(strings.Trim(strings.TrimSpace(target), `"'`))
	for i := range htaccessSuspiciousAutoPrependPaths {
		if strings.HasPrefix(cleaned, htaccessSuspiciousAutoPrependPaths[i]) {
			return true
		}
	}
	for i := range htaccessImageExtensions {
		if strings.HasSuffix(cleaned, htaccessImageExtensions[i]) {
			return true
		}
	}
	return false
}
// reUARewriteRuleAfter captures the substitution and flag list of a
// RewriteRule directive. Used to inspect the rule paired with a
// preceding UA cond so the detector can distinguish defensive blocks
// (forbid / no-op / sinkhole) from cloaks (rewrite to a different
// file or external URL).
//
// Group 1 is the substitution; optional group 2 is the flag list
// without its surrounding brackets. Matching is case-insensitive and
// anchored per line. Both the rule pattern and the substitution are
// matched as a single run of non-whitespace, so a directive whose
// arguments contain spaces does not match.
var reUARewriteRuleAfter = regexp.MustCompile(`(?im)^\s*RewriteRule\s+\S+\s+(\S+)(?:\s+\[([^\]]+)\])?\s*$`)
// uaCloakDefensiveFlags lists RewriteRule flags that, when set on the
// rule paired with a UA cond, indicate defensive blocking rather than
// content cloaking. Apache combines flags with a comma;
// uaCloakPairedRuleIsDefensive lower-cases the bracketed list, splits
// it on commas and compares each trimmed token for equality, so the
// entries here must be lower-case and are matched as whole flags, not
// as substrings.
var uaCloakDefensiveFlags = []string{"f", "g"}
// uaCloakBlocklistThreshold is the number of OR-list entries in the
// UA cond's regex that converts the cond from "potential cloak" to
// "operator-installed bot blocklist". A cond with 4+ alternation
// entries is overwhelmingly a defensive block (the canonical
// SoftAculous / Apache Bad Bots list ships ~20+ entries).
// detectUserAgentCloak applies the same threshold (with >=) to the
// number of UA conds in a multi-line chain.
const uaCloakBlocklistThreshold = 4
// uaCloakAlternationCount returns the number of alternation branches
// in a UA cond regex pattern: one, plus one for every unescaped "|"
// that sits outside all parentheses or inside exactly one group (the
// common "(a|b|c)" wrapper). A "|" nested two or more groups deep is
// ignored, and a backslash hides the byte that follows it. Used to
// identify long bot blocklists.
func uaCloakAlternationCount(pattern string) int {
	branches := 1
	nesting := 0
	for i := 0; i < len(pattern); i++ {
		switch pattern[i] {
		case '\\':
			i++ // the escaped byte is never structural
		case '(':
			nesting++
		case ')':
			if nesting > 0 {
				nesting--
			}
		case '|':
			if nesting < 2 {
				branches++
			}
		}
	}
	return branches
}
// uaCloakPairedRuleIsDefensive looks at the first RewriteRule
// directive found anywhere after condEnd (the first
// reUARewriteRuleAfter match in the remaining content) and reports
// whether it looks defensive rather than cloaking. It returns true
// when:
//
//   - no RewriteRule follows at all,
//   - the rule's substitution is exactly "-" (no-op), or
//   - one of the rule's comma-separated flags, lower-cased and
//     trimmed, equals an entry of uaCloakDefensiveFlags.
//
// Any other rule returns false. Intervening blank lines or other
// directives are not inspected; only the next RewriteRule matters.
func uaCloakPairedRuleIsDefensive(content []byte, condEnd int) bool {
	tail := content[condEnd:]
	m := reUARewriteRuleAfter.FindSubmatchIndex(tail)
	if m == nil {
		return true
	}
	if string(tail[m[2]:m[3]]) == "-" {
		return true
	}
	flagList := ""
	if m[4] != -1 {
		flagList = strings.ToLower(string(tail[m[4]:m[5]]))
	}
	// Whole-token comparison: "f" matches "[F]", "[F,L]" and "[L,F]"
	// but not "[NC]" or "[QSA]".
	for _, tok := range strings.Split(flagList, ",") {
		tok = strings.TrimSpace(tok)
		for _, defensive := range uaCloakDefensiveFlags {
			if tok == defensive {
				return true
			}
		}
	}
	return false
}
// detectUserAgentCloak reports RewriteCond %{HTTP_USER_AGENT}
// directives (reUACloakCond matches) whose pattern names a known
// crawler (crawlerUARegex) and that survive four suppression gates,
// checked in this order:
//
//  1. Negated cond: the pattern, after trimming whitespace, starts
//     with "!". The rule then applies only when the UA is NOT the
//     crawler, which is how cache plugins exclude scrapers.
//
//  2. Long alternation: the pattern has at least
//     uaCloakBlocklistThreshold alternation branches, the shape of a
//     defensive bot blocklist.
//
//  3. Long multi-line chain: the cond belongs to a run of adjacent UA
//     conds of at least uaCloakBlocklistThreshold lines.
//
//  4. Defensive paired rule: uaCloakPairedRuleIsDefensive accepts the
//     RewriteRule that follows. Note that it also accepts the case
//     where no RewriteRule follows, so such conds are suppressed.
//
// A cond that passes every gate is reported with the range of the
// full line(s) containing it. The path argument is unused.
func detectUserAgentCloak(content []byte, _ string) []htaccessMatch {
	conds := reUACloakCond.FindAllSubmatchIndex(content, -1)
	chainLens := uaCloakChainSizes(content, conds)
	var hits []htaccessMatch
	for n, m := range conds {
		if len(m) < 4 {
			continue
		}
		pattern := string(content[m[2]:m[3]])
		// Each case below is a reason to skip; only the default emits.
		switch {
		case !crawlerUARegex.MatchString(pattern):
			// not about a known crawler
		case strings.HasPrefix(strings.TrimSpace(pattern), "!"):
			// gate 1: negated cond
		case uaCloakAlternationCount(pattern) >= uaCloakBlocklistThreshold:
			// gate 2: long alternation list
		case chainLens[n] >= uaCloakBlocklistThreshold:
			// gate 3: long multi-line chain
		case uaCloakPairedRuleIsDefensive(content, m[1]):
			// gate 4: paired rule is defensive
		default:
			hits = append(hits, htaccessMatch{
				Range:   lineRange(content, m[0], m[1]),
				Excerpt: trimExcerpt(content, m[0], m[1]),
			})
		}
	}
	return hits
}
// uaCloakChainSizes returns a slice parallel to idxs in which entry i
// is the number of UA cond matches in the chain that match i belongs
// to. Consecutive matches are in the same chain when
// uaCondsAreAdjacent says the previous match's end and the next
// match's start are separated by a single line break plus optional
// whitespace; anything else starts a new chain.
func uaCloakChainSizes(content []byte, idxs [][]int) []int {
	total := len(idxs)
	sizes := make([]int, total)
	for begin := 0; begin < total; {
		// Grow the run while each match is adjacent to its predecessor.
		stop := begin + 1
		for stop < total && uaCondsAreAdjacent(content, idxs[stop-1][1], idxs[stop][0]) {
			stop++
		}
		for k := begin; k < stop; k++ {
			sizes[k] = stop - begin
		}
		begin = stop
	}
	return sizes
}
// uaCondsAreAdjacent reports whether content[prevEnd:nextStart]
// consists of exactly one '\n' plus any number of spaces, tabs or
// carriage returns. A second newline (blank line) or any other byte
// (another directive in between) makes it false, as does an empty gap,
// prevEnd > nextStart, or nextStart beyond the content.
func uaCondsAreAdjacent(content []byte, prevEnd, nextStart int) bool {
	if prevEnd > nextStart || nextStart > len(content) {
		return false
	}
	newlines := 0
	for pos := prevEnd; pos < nextStart; pos++ {
		c := content[pos]
		if c == '\n' {
			newlines++
			if newlines > 1 {
				return false
			}
			continue
		}
		if c != ' ' && c != '\t' && c != '\r' {
			return false
		}
	}
	return newlines == 1
}
// detectSpamRedirect reports every reSpamRedirect match whose first
// capture group parses to a host that hostOnSpamTLD accepts. The
// reported range is the full line(s) containing the match. Legitimate
// hosts on those TLDs need a per-path suppression. The path argument
// is unused.
func detectSpamRedirect(content []byte, _ string) []htaccessMatch {
	var hits []htaccessMatch
	for _, m := range reSpamRedirect.FindAllSubmatchIndex(content, -1) {
		if len(m) < 4 {
			continue
		}
		redirectHost := extractHost(string(content[m[2]:m[3]]))
		if hostOnSpamTLD(redirectHost) {
			hits = append(hits, htaccessMatch{
				Range:   lineRange(content, m[0], m[1]),
				Excerpt: trimExcerpt(content, m[0], m[1]),
			})
		}
	}
	return hits
}
// extractHost parses rawURL and returns its host name (port and IPv6
// brackets removed) in lower case. It returns "" when the URL cannot
// be parsed.
func extractHost(rawURL string) string {
	if parsed, err := url.Parse(rawURL); err == nil {
		return strings.ToLower(parsed.Hostname())
	}
	return ""
}
// hostOnSpamTLD reports whether host ends with one of the suffixes in
// htaccessSpamTLDs.
//
// The host is normalised first: surrounding whitespace is trimmed, it
// is lower-cased, and a single trailing "." is removed. A fully
// qualified name such as "evil.tk." resolves exactly like "evil.tk",
// so without this step a trailing dot (or upper-case letters from a
// caller that did not lower-case) would slip past the suffix match.
// An empty host never matches.
func hostOnSpamTLD(host string) bool {
	normalized := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(host)), ".")
	if normalized == "" {
		return false
	}
	for _, tld := range htaccessSpamTLDs {
		if strings.HasSuffix(normalized, tld) {
			return true
		}
	}
	return false
}
// reFilesMatchPattern captures the regex pattern inside the FilesMatch
// quotes. The capture group is everything between the optional quote
// characters, which is the Apache-side regex applied to filenames.
//
// Matching is case-insensitive and anchored at the start of a line
// (leading whitespace allowed). The captured group is lazy and cannot
// contain a quote character or '>'.
var reFilesMatchPattern = regexp.MustCompile(`(?im)^\s*<FilesMatch\s+["']?([^"'>]+?)["']?\s*>`)
// filesMatchPHPExtension is the regex fragment for one PHP handler
// extension as it appears inside a FilesMatch pattern: "php" with an
// optional version digit, "phtml", or "ph2".."ph7".
const filesMatchPHPExtension = `(?:php[0-9]?|phtml|ph[2-7])`

// reFilesMatchExtensionTail strips the PHP extension suffix from a
// FilesMatch pattern so the remaining literal can be examined. It
// recognises both the single-extension form ("\.php$", "\.php5$") and
// the grouped form ("\.(php|phtml)$", "\.(?:php|phtml|php5)$"),
// optionally followed by "$" and a closing ")". Without the grouped
// form the letters of the extension list itself would be mistaken for
// a specific file name.
var reFilesMatchExtensionTail = regexp.MustCompile(
	`(?i)\\\.(?:` + filesMatchPHPExtension +
		`|\((?:\?:)?` + filesMatchPHPExtension + `(?:\|` + filesMatchPHPExtension + `)*\)` +
		`)\$?\)?$`)

// filesMatchPatternIsTargeted reports whether the FilesMatch regex
// names at least one specific PHP filename rather than granting access
// to every .php file in the directory. Stock plugins ship targeted
// patterns ("wpc\.php$", "ps_facetedsearch-.+\.php$",
// "(webp-on-demand\.php|webp-realizer\.php)$"); the malicious shape
// is a bare wildcard ("\.php$", ".*\.php$", "[^/]+\.php$",
// "\.(php|phtml)$").
//
// The check is character-class based: after the trailing extension
// suffix is stripped (see reFilesMatchExtensionTail), any ASCII
// letter, digit, dash, or underscore left in the pattern means it
// names something specific and the function returns true. A remainder
// made only of regex meta-characters returns false, and the caller
// goes on to its wildcard-context check.
func filesMatchPatternIsTargeted(pattern string) bool {
	stripped := reFilesMatchExtensionTail.ReplaceAllString(pattern, "")
	for _, c := range stripped {
		if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-' {
			return true
		}
	}
	return false
}
// htaccessParentPHPFileCount returns how many non-directory entries in
// the directory containing htaccessPath have a PHP handler extension
// (.php, .phtml, .ph2 through .ph7, compared case-insensitively). The
// count separates a plugin directory full of PHP dispatchers from an
// upload directory holding one or zero PHP files.
//
// If the directory cannot be listed (missing, permission denied, race)
// the result is 0, which the caller treats as "not enough sibling
// PHP".
func htaccessParentPHPFileCount(htaccessPath string) int {
	siblings, err := os.ReadDir(filepath.Dir(htaccessPath))
	if err != nil {
		return 0
	}
	count := 0
	for i := range siblings {
		entry := siblings[i]
		if entry.IsDir() {
			continue
		}
		switch strings.ToLower(filepath.Ext(entry.Name())) {
		case ".php", ".phtml", ".ph2", ".ph3", ".ph4", ".ph5", ".ph6", ".ph7":
			count++
		}
	}
	return count
}
// filesMatchShieldSiblingThreshold is the number of sibling PHP files
// that converts a bare-wildcard FilesMatch shield from "fire" to
// "treat as legitimate plugin allowlist". A directory with 3+ stock
// PHP dispatchers existing alongside the shield is overwhelmingly
// likely to be a legitimate webapp module (KCFinder ships ~5,
// PrestaShop modules ship ~10+, vendor dispatchers ship a handful).
// The attacker drop pattern is .htaccess + one or zero PHP files.
// detectFilesMatchShield suppresses the finding when
// htaccessParentPHPFileCount is greater than or equal to this value.
const filesMatchShieldSiblingThreshold = 3
// detectFilesMatchShield reports <FilesMatch> blocks (reFilesMatchOpen
// through the next reFilesMatchClose) whose body contains
// "allow from all" or "require all granted", compared in lower case.
// That is the canonical "let everyone execute what we just dropped"
// shape. The reported range covers the whole block, opening line
// through closing line.
//
// Two suppression gates run before a block is reported:
//
//  1. Targeted pattern: when reFilesMatchPattern yields a pattern for
//     this opening tag and filesMatchPatternIsTargeted accepts it, the
//     block is treated as a plugin allowlist and skipped.
//
//  2. Sibling PHP context: when path is non-empty and the .htaccess
//     directory holds at least filesMatchShieldSiblingThreshold PHP
//     files, the block is treated as part of an existing plugin layout
//     and skipped.
//
// When neither gate can decide (no pattern captured, empty path,
// unreadable directory) the block is still reported. An opening tag
// with no closing tag after it is ignored.
func detectFilesMatchShield(content []byte, path string) []htaccessMatch {
	opens := reFilesMatchOpen.FindAllIndex(content, -1)
	closes := reFilesMatchClose.FindAllIndex(content, -1)
	if len(opens) == 0 || len(closes) == 0 {
		return nil
	}
	patterns := reFilesMatchPattern.FindAllSubmatchIndex(content, -1)

	// closeAfter returns the first closing tag starting at or after pos.
	closeAfter := func(pos int) []int {
		for _, c := range closes {
			if c[0] >= pos {
				return c
			}
		}
		return nil
	}
	// patternAt returns the FilesMatch pattern captured for the opening
	// tag that starts at pos, or "" when none was captured.
	patternAt := func(pos int) string {
		for _, p := range patterns {
			if len(p) >= 4 && p[0] == pos {
				return string(content[p[2]:p[3]])
			}
		}
		return ""
	}

	var hits []htaccessMatch
	for _, open := range opens {
		closing := closeAfter(open[1])
		if closing == nil {
			continue
		}
		body := strings.ToLower(string(content[open[1]:closing[0]]))
		grantsAll := strings.Contains(body, "allow from all") || strings.Contains(body, "require all granted")
		if !grantsAll {
			continue
		}
		if pattern := patternAt(open[0]); pattern != "" && filesMatchPatternIsTargeted(pattern) {
			continue
		}
		// Bare wildcard: a directory with several PHP dispatchers is
		// taken to be a legitimate plugin layout.
		if path != "" && htaccessParentPHPFileCount(path) >= filesMatchShieldSiblingThreshold {
			continue
		}
		hits = append(hits, htaccessMatch{
			Range:   blockRange(content, open[0], closing[1]),
			Excerpt: trimExcerpt(content, open[0], closing[1]),
		})
	}
	return hits
}
// detectHeaderInjection reports every reHeaderSetAdd match (Header
// set / Header add) whose second capture group, the header name, is
// accepted by headerNameSuspicious. Header names that do not start
// with a listed tracking header are not reported. The reported range
// is the full line(s) containing the match. The path argument is
// unused.
func detectHeaderInjection(content []byte, _ string) []htaccessMatch {
	var hits []htaccessMatch
	for _, m := range reHeaderSetAdd.FindAllSubmatchIndex(content, -1) {
		if len(m) < 6 {
			continue
		}
		headerName := string(content[m[4]:m[5]])
		if headerNameSuspicious(headerName) {
			hits = append(hits, htaccessMatch{
				Range:   lineRange(content, m[0], m[1]),
				Excerpt: trimExcerpt(content, m[0], m[1]),
			})
		}
	}
	return hits
}
// headerNameSuspicious reports whether name starts with any entry of
// htaccessTrackingHeaders, comparing both sides in lower case.
func headerNameSuspicious(name string) bool {
	candidate := strings.ToLower(name)
	for i := range htaccessTrackingHeaders {
		prefix := strings.ToLower(htaccessTrackingHeaders[i])
		if strings.HasPrefix(candidate, prefix) {
			return true
		}
	}
	return false
}
// errorDocumentHostShareThreshold bounds how short a host-vs-path
// substring match can be while still treating the redirect as
// "same-brand". Four characters is the floor: anything shorter would
// match incidental segments ("us" inside "user", "co" inside "co.uk")
// and let an attacker tunnel through with a name like "co.evil.com".
// Host labels shorter than this are never considered same-brand.
const errorDocumentHostShareThreshold = 4
// errorDocumentHostIsSameBrand reports whether the URL host's
// "registrable label" (see registrableLabel) shares an alphanumeric
// stem of >= errorDocumentHostShareThreshold chars with any directory
// component of the .htaccess file's path. This catches the dominant
// legitimate shape: a custom 404 redirect to the site's own homepage
// on the same brand domain.
//
// Empty components and the generic "home", "public_html" components
// are skipped. The ".htaccess" file name is skipped as well: it is
// present in every path, so counting it would make any host whose
// label contains "htaccess" look same-brand. Hosts whose label is
// shorter than the threshold are never same-brand.
//
// Examples:
//
//	/home/flores/public_html/.htaccess + https://floresgrup.ro
//	  -> account "flores" is a substring of label "floresgrup" -> same-brand
//
//	/home/shop/example-shop.com/.htaccess + https://www.example-shop.com/404
//	  -> domain dir "example-shop.com" contains label "example-shop" -> same-brand
//
//	/home/victim/public_html/.htaccess + https://attacker.com/landing
//	  -> "attacker" shares no >=4-char stem with any path component -> different-brand
//
//	/home/victim/public_html/.htaccess + https://myhtaccess.com/
//	  -> the only overlap is the file name itself -> different-brand
func errorDocumentHostIsSameBrand(htaccessPath, urlHost string) bool {
	label := registrableLabel(urlHost)
	if len(label) < errorDocumentHostShareThreshold {
		return false
	}
	labelLower := strings.ToLower(label)
	for _, component := range strings.Split(htaccessPath, string(filepath.Separator)) {
		switch component {
		case "", "public_html", "home", ".htaccess":
			// Generic components carry no brand information.
			continue
		}
		comp := strings.ToLower(component)
		if longestCommonAlnumRun(comp, labelLower) >= errorDocumentHostShareThreshold {
			return true
		}
	}
	return false
}
// registrableLabel returns a heuristic "brand" label for host: for
// "www.example-shop.com" it is "example-shop", for "floresgrup.ro" it
// is "floresgrup".
//
// The host is trimmed and lower-cased, a leading "www." is removed,
// and the name is split on dots. Normally the second-to-last segment
// is returned. When there are at least three segments and the
// second-to-last is one of "co", "com", "net", "org", "ac", "gov",
// "edu" (a two-segment suffix such as "co.uk" or "com.au"), the
// third-from-last segment is returned instead. A name without any dot
// is returned as is.
//
// Returns "" for an empty host and for an IPv4 dotted quad: numeric
// targets never qualify as same-brand.
func registrableLabel(host string) string {
	h := strings.ToLower(strings.TrimSpace(host))
	if h == "" || isIPv4(h) {
		return ""
	}
	h = strings.TrimPrefix(h, "www.")
	labels := strings.Split(h, ".")
	n := len(labels)
	if n < 2 {
		return h
	}
	if n >= 3 {
		switch labels[n-2] {
		case "co", "com", "net", "org", "ac", "gov", "edu":
			return labels[n-3]
		}
	}
	return labels[n-2]
}
// isIPv4 reports whether s has the shape of a dotted-quad IPv4
// address: exactly four dot-separated segments, each made of one to
// three ASCII digits. Segment values are not range-checked. IPv6, hex
// and mixed forms are not recognised here.
func isIPv4(s string) bool {
	octets := strings.Split(s, ".")
	if len(octets) != 4 {
		return false
	}
	for _, octet := range octets {
		if n := len(octet); n == 0 || n > 3 {
			return false
		}
		for i := 0; i < len(octet); i++ {
			if octet[i] < '0' || octet[i] > '9' {
				return false
			}
		}
	}
	return true
}
// longestCommonAlnumRun splits a into its maximal alphanumeric tokens
// (symbols such as ".", "-", "_" act as separators) and returns the
// length of the longest token that occurs as a substring of b, or 0
// when none does. An "example-shop.com" path component therefore
// matches an "example-shop" label on the "example" stem without
// crediting the dash. Only whole tokens of a are considered, and b is
// used exactly as given.
func longestCommonAlnumRun(a, b string) int {
	longest := 0
	for _, token := range splitAlnumTokens(a) {
		if n := len(token); n > longest && strings.Contains(b, token) {
			longest = n
		}
	}
	return longest
}
// splitAlnumTokens returns the maximal runs of ASCII letters and
// digits in s, in order. Every other byte acts as a separator. The
// result is nil when s contains no such run.
func splitAlnumTokens(s string) []string {
	isAlnum := func(c byte) bool {
		return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')
	}
	var tokens []string
	for pos := 0; pos < len(s); {
		if !isAlnum(s[pos]) {
			pos++
			continue
		}
		begin := pos
		for pos < len(s) && isAlnum(s[pos]) {
			pos++
		}
		tokens = append(tokens, s[begin:pos])
	}
	return tokens
}
// detectErrorDocumentHijack reports ErrorDocument directives
// (reErrorDocument matches) whose target URL points somewhere other
// than the site's own brand.
//
// A match is always reported when its host is on a spam TLD or is an
// IPv4 address, whatever the path looks like. Otherwise it is
// suppressed only when path is non-empty and
// errorDocumentHostIsSameBrand finds a shared stem between the host
// and the path; same-brand redirects (custom 404 -> site homepage)
// are common and must not be flagged. With an empty path there is
// nothing to compare against, so the match is reported.
func detectErrorDocumentHijack(content []byte, path string) []htaccessMatch {
	var hits []htaccessMatch
	for _, m := range reErrorDocument.FindAllSubmatchIndex(content, -1) {
		if len(m) < 4 {
			continue
		}
		host := extractHost(string(content[m[2]:m[3]]))
		alwaysFlag := hostOnSpamTLD(host) || isIPv4(host)
		if !alwaysFlag && path != "" && errorDocumentHostIsSameBrand(path, host) {
			continue
		}
		hits = append(hits, htaccessMatch{
			Range:   lineRange(content, m[0], m[1]),
			Excerpt: trimExcerpt(content, m[0], m[1]),
		})
	}
	return hits
}
// lineRange widens [start, end] to whole lines. The returned Start is
// moved back to just after the preceding '\n' (or to 0), and the
// returned End is moved forward to the next '\n' (exclusive) or to
// len(content) when the last line has no trailing newline. A negative
// start is treated as 0 and an end beyond the content as len(content).
func lineRange(content []byte, start, end int) htaccessByteRange {
	lo, hi := start, end
	if lo < 0 {
		lo = 0
	}
	if n := len(content); hi > n {
		hi = n
	}
	for ; lo > 0; lo-- {
		if content[lo-1] == '\n' {
			break
		}
	}
	for ; hi < len(content); hi++ {
		if content[hi] == '\n' {
			break
		}
	}
	return htaccessByteRange{Start: lo, End: hi}
}
// blockRange returns the range of a whole block, from the line holding
// the opening directive through the line holding the closing one. It
// applies exactly the same line-snapping as lineRange.
func blockRange(content []byte, start, end int) htaccessByteRange {
	return lineRange(content, start, end)
}
// trimExcerpt returns a short, single-line excerpt of
// content[start:end] for finding details: surrounding whitespace is
// trimmed, every newline is replaced by " | ", and anything beyond
// 200 bytes is cut off and replaced by "...".
//
// start and end are clamped to the content, and an empty or inverted
// range yields "". The cut never lands inside a multi-byte UTF-8
// sequence: it backs off by up to three continuation bytes so the
// excerpt stays valid UTF-8 when the content is.
func trimExcerpt(content []byte, start, end int) string {
	const maxExcerptBytes = 200
	if start < 0 {
		start = 0
	}
	if end > len(content) {
		end = len(content)
	}
	if start >= end {
		return ""
	}
	s := strings.TrimSpace(string(content[start:end]))
	s = strings.ReplaceAll(s, "\n", " | ")
	if len(s) <= maxExcerptBytes {
		return s
	}
	cut := maxExcerptBytes
	// 10xxxxxx marks a UTF-8 continuation byte; a sequence has at most
	// three of them, so stepping back that far reaches its lead byte.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
// HTTP abuse detection.
//
// This file holds the access-log line parser, the per-scan aggregator
// struct (domlogStats), the UA classifier, the bot-classifier interface,
// and the static allowlist classifier that consults embedded bot IP ranges.
// The rDNS verifying classifier arrives in Task 5.
package checks
import (
"fmt"
"net"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/threatintel"
)
// accessLogRecord is the parsed shape of one access-log line. Combined
// Log Format plus the cPanel final-vhost extension:
//
// IP - - [time] "METHOD URI PROTO" status bytes "referer" "ua" "vhost"
//
// The parser tolerates either the 9-field plain CLF or the 10-field
// cPanel variant. Bad lines return ok=false; callers must skip them.
type accessLogRecord struct {
// RemoteIP is the connecting peer address from the log line.
RemoteIP string
// Time is the request timestamp; domlogStats.scan tests it against
// the flood window.
Time time.Time
// Method is the HTTP method; domlogStats.scan compares it to "POST".
Method string
// URI is the request target, matched by substring in domlogStats.scan.
URI string
// Status is the response status code.
Status int
// UserAgent is the request's User-Agent value, passed to classifyUA
// and to the bot classifier.
UserAgent string
XFF string // optional; only trusted when RemoteIP is a trusted proxy
Domain string // vhost the line came from (per-domain domlog); empty for the central log
}
// uaKind is the User-Agent classification produced by classifyUA and
// consumed by domlogStats.scan and the http_ua_spoof emit logic.
type uaKind int
// uaKindBrowser is the zero value. domlogStats.scan never records
// uaKindClaimedBot as such: it becomes uaKindClaimedBotNegative when
// the bot classifier confirms a negative for the IP, and uaKindBrowser
// otherwise. Every kind other than uaKindBrowser marks the request's
// vhost for the http_ua_spoof check in scan. The criteria for the
// remaining kinds are decided by classifyUA.
const (
uaKindBrowser uaKind = iota
uaKindClaimedBot
uaKindClaimedBotNegative
uaKindKnownScanner
uaKindWPSpoofPingback
uaKindScriptingLang
uaKindHeadless
uaKindEmpty
)
// botClassifier decides whether the source IP is a known verified bot
// the detector should NOT count or flag. Returns true to skip. The
// real implementation lives in internal/threatintel; tests use the
// nopBotClassifier below.
//
// domlogStats.scan also probes the classifier for an optional
// ConfirmedNegative(ip, ua string) bool method via a type assertion;
// implementations that lack it are treated as never confirming.
type botClassifier interface {
// IsVerifiedBot reports whether requests from ip presenting ua should
// be skipped entirely.
IsVerifiedBot(ip string, ua string) bool
}
// nopBotClassifier is a botClassifier that never verifies anything, so
// no request is skipped on its account.
type nopBotClassifier struct{}
// IsVerifiedBot always returns false.
func (nopBotClassifier) IsVerifiedBot(string, string) bool { return false }
// httpSample is one representative request kept per IP for forensic
// display in the finding Details field. First-seen wins; subsequent
// requests increment counters only.
type httpSample struct {
// Method is copied from the record's Method.
Method string
// URI is copied from the record's URI.
URI string
// UA is copied from the record's UserAgent.
UA string
}
// domlogStats is the per-scan aggregator. One instance per CheckWPBruteForce
// invocation. Every counter is map[ip] -> int so emit can produce findings
// per source IP without a second pass.
type domlogStats struct {
// wpLogin counts POST requests whose URI contains "wp-login.php".
wpLogin map[string]int
// xmlrpc counts POST requests whose URI contains "xmlrpc.php".
xmlrpc map[string]int
// userEnum counts requests whose URI contains "?author=" or
// "/wp-json/wp/v2/users" (the latter excluding "/users/me").
userEnum map[string]int
// httpReqs counts requests that fall inside the flood window.
httpReqs map[string]int
// uaCat counts in-window requests per IP and per uaKind.
uaCat map[string]map[uaKind]int
// samples keeps the first in-window request seen for each IP.
samples map[string]httpSample
// domains tracks the set of distinct vhosts each IP touched, so the
// per-IP aggregate findings can report cross-site spread (one IP
// scanning many vhosts on a shared host). Empty-domain records (the
// central access log) do not contribute.
domains map[string]map[string]struct{}
// abuseDomains tracks the in-window vhosts where each source made the
// request shape for a specific HTTP-abuse finding. The distributed
// rollup uses this instead of all in-window touches so an IP that was
// abusive elsewhere cannot make a normal hit count against a vhost.
// Keyed check name -> IP -> set of vhosts.
abuseDomains map[string]map[string]map[string]struct{}
// scanTime is the reference time handed to withinHTTPFloodWindow.
scanTime time.Time
}
// newDomlogStats returns an empty aggregator whose scan time is the
// current wall-clock time.
func newDomlogStats() *domlogStats {
	now := time.Now()
	return newDomlogStatsAt(now)
}
// newDomlogStatsAt returns an aggregator with every map allocated and
// empty, and with scanTime set to t.
func newDomlogStatsAt(t time.Time) *domlogStats {
	stats := &domlogStats{scanTime: t}
	stats.wpLogin = map[string]int{}
	stats.xmlrpc = map[string]int{}
	stats.userEnum = map[string]int{}
	stats.httpReqs = map[string]int{}
	stats.uaCat = map[string]map[uaKind]int{}
	stats.samples = map[string]httpSample{}
	stats.domains = map[string]map[string]struct{}{}
	stats.abuseDomains = map[string]map[string]map[string]struct{}{}
	return stats
}
// scan folds one parsed access-log record into the aggregator.
//
// The client IP is resolved through clientIPForRecord and
// normalizeHTTPClientIP. The record is ignored when that IP is empty,
// when cfg is non-nil and the IP is listed in cfg.InfraIPs, or when
// bot is non-nil and reports the IP as a verified bot. cfg may be nil.
//
// The wp-login, xmlrpc and user-enumeration counters are updated for
// every remaining record. Everything else (per-IP vhost sets, abuse
// vhost sets, the first-seen sample, the request count and the UA
// category counts) is updated only when the record's time falls
// inside the flood window, so lines with unusable timestamps still
// feed the legacy counters but not the rate or UA findings.
func (s *domlogStats) scan(rec accessLogRecord, cfg *config.Config, bot botClassifier) {
	ip := normalizeHTTPClientIP(clientIPForRecord(rec, cfg))
	switch {
	case ip == "":
		return
	case cfg != nil && isInfraIP(ip, cfg.InfraIPs):
		return
	case bot != nil && bot.IsVerifiedBot(ip, rec.UserAgent):
		return
	}

	isPost := rec.Method == "POST"
	wpLoginHit := isPost && strings.Contains(rec.URI, "wp-login.php")
	xmlrpcHit := isPost && strings.Contains(rec.URI, "xmlrpc.php")
	userEnumHit := strings.Contains(rec.URI, "?author=") ||
		(strings.Contains(rec.URI, "/wp-json/wp/v2/users") && !strings.Contains(rec.URI, "/users/me"))

	if wpLoginHit {
		s.wpLogin[ip]++
	}
	if xmlrpcHit {
		s.xmlrpc[ip]++
	}
	if userEnumHit {
		s.userEnum[ip]++
	}

	// Rate and UA bookkeeping is limited to the flood window.
	if !withinHTTPFloodWindow(rec.Time, cfg, s.scanTime) {
		return
	}

	hasDomain := rec.Domain != ""
	if hasDomain {
		touched := s.domains[ip]
		if touched == nil {
			touched = make(map[string]struct{})
			s.domains[ip] = touched
		}
		touched[rec.Domain] = struct{}{}
		if wpLoginHit {
			s.recordAbuseDomain("wp_login_bruteforce", ip, rec.Domain)
		}
		if xmlrpcHit {
			s.recordAbuseDomain("xmlrpc_abuse", ip, rec.Domain)
		}
		if userEnumHit {
			s.recordAbuseDomain("wp_user_enumeration", ip, rec.Domain)
		}
		s.recordAbuseDomain("http_request_flood", ip, rec.Domain)
	}

	// First in-window request per IP is kept as the forensic sample.
	if _, seen := s.samples[ip]; !seen {
		s.samples[ip] = httpSample{Method: rec.Method, URI: rec.URI, UA: rec.UserAgent}
	}
	s.httpReqs[ip]++

	kind := classifyUA(rec.UserAgent, rec.Method)
	if kind == uaKindClaimedBot {
		// Verified bots already returned above. A claimed bot is only
		// held against the IP when the classifier offers
		// ConfirmedNegative and it says so; otherwise it is counted as
		// a browser so a not-yet-verified legitimate bot IP does not
		// trip ua_spoof.
		kind = uaKindBrowser
		verifier, ok := bot.(interface {
			ConfirmedNegative(ip, ua string) bool
		})
		if ok && verifier.ConfirmedNegative(ip, rec.UserAgent) {
			kind = uaKindClaimedBotNegative
		}
	}
	counts := s.uaCat[ip]
	if counts == nil {
		counts = make(map[uaKind]int)
		s.uaCat[ip] = counts
	}
	counts[kind]++
	if hasDomain && kind != uaKindBrowser {
		s.recordAbuseDomain("http_ua_spoof", ip, rec.Domain)
	}
}
// recordAbuseDomain notes that ip produced the request shape for check
// against domain. The nested per-check and per-IP maps are created on
// first use.
func (s *domlogStats) recordAbuseDomain(check, ip, domain string) {
	if s.abuseDomains[check] == nil {
		s.abuseDomains[check] = make(map[string]map[string]struct{})
	}
	perIP := s.abuseDomains[check]
	if perIP[ip] == nil {
		perIP[ip] = make(map[string]struct{})
	}
	perIP[ip][domain] = struct{}{}
}
// emitLegacy returns the three pre-existing finding kinds
// (wp_login_bruteforce, xmlrpc_abuse, wp_user_enumeration). It is kept
// apart from emit so the parity test can assert that no newer finding
// kinds appear. The config argument is unused.
func (s *domlogStats) emitLegacy(_ *config.Config) []alert.Finding {
	var findings []alert.Finding

	// collect appends one finding per IP whose counter reached threshold.
	collect := func(counts map[string]int, threshold int, sev alert.Severity, check, label, unit, details string) {
		for ip, hits := range counts {
			if hits < threshold {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: sev,
				Check:    check,
				SourceIP: ip,
				Message:  formatLegacyMessage(label, ip, hits, unit),
				Details:  details,
			})
		}
	}

	collect(s.wpLogin, wpLoginThreshold, alert.Critical,
		"wp_login_bruteforce", "WordPress login brute force", "attempts",
		"Aggregated across per-vhost access logs")
	collect(s.xmlrpc, xmlrpcThreshold, alert.Critical,
		"xmlrpc_abuse", "XML-RPC abuse", "requests",
		"Aggregated across per-vhost access logs")
	collect(s.userEnum, 5, alert.High,
		"wp_user_enumeration", "WordPress user enumeration", "requests",
		"Requests to /wp-json/wp/v2/users or ?author=")

	return findings
}
// emit produces all finding kinds from a single populated domlogStats.
// Legacy three kinds come first (so existing callers still get them when
// running through emit), then http_request_flood, then http_ua_spoof,
// and finally the per-vhost http_distributed_flood rollup.
//
// With a nil cfg only the legacy findings are produced. Within each map
// loop the per-IP order follows Go map iteration and is therefore not
// deterministic.
func (s *domlogStats) emit(cfg *config.Config) []alert.Finding {
out := s.emitLegacy(cfg)
// Flood findings are disabled when the threshold is unset or <= 0.
if cfg != nil && cfg.Thresholds.HTTPFloodThreshold > 0 {
for ip, count := range s.httpReqs {
if count < cfg.Thresholds.HTTPFloodThreshold {
continue
}
sample := s.samples[ip]
out = append(out, alert.Finding{
Severity: alert.High,
Check: "http_request_flood",
SourceIP: ip,
Domain: s.singleDomain(ip),
Message: "HTTP request flood from " + ip + ": " + itoa(count) + " requests" + s.vhostSuffix(ip),
Details: "Sample: " + sample.Method + " " + sample.URI + " UA=" + truncate(sample.UA, 120),
})
}
}
if cfg != nil {
// A non-positive configured threshold falls back to 30 hits.
threshold := cfg.Thresholds.HTTPUASpoofThreshold
if threshold <= 0 {
threshold = 30
}
// At most one http_ua_spoof finding per IP: the categories are
// tried in priority order and each `continue` stops at the first
// match. Failed-rDNS bots and known scanners fire on any hit; the
// remaining categories must reach threshold, and the last three
// are additionally gated by their config switches.
for ip, byKind := range s.uaCat {
if hits, ok := byKind[uaKindClaimedBotNegative]; ok && hits > 0 {
out = append(out, s.makeUASpoofFinding(ip,
"claimed search-engine bot failed rDNS verification",
s.samples[ip], hits))
continue
}
if hits, ok := byKind[uaKindKnownScanner]; ok && hits > 0 {
out = append(out, s.makeUASpoofFinding(ip, "known scanner UA",
s.samples[ip], hits))
continue
}
if hits := byKind[uaKindWPSpoofPingback]; hits >= threshold {
out = append(out, s.makeUASpoofFinding(ip,
"WordPress/<ver> UA on GET (pingback spoof)",
s.samples[ip], hits))
continue
}
if cfg.Thresholds.HTTPUAScriptingEnabled {
if hits := byKind[uaKindScriptingLang]; hits >= threshold {
out = append(out, s.makeUASpoofFinding(ip,
"scripting-language UA (curl/python/etc.)",
s.samples[ip], hits))
continue
}
}
if cfg.Thresholds.HTTPUAHeadlessEnabled {
if hits := byKind[uaKindHeadless]; hits >= threshold {
out = append(out, s.makeUASpoofFinding(ip,
"headless browser UA", s.samples[ip], hits))
continue
}
}
if cfg.Thresholds.HTTPUAEmptyEnabled {
if hits := byKind[uaKindEmpty]; hits >= threshold {
out = append(out, s.makeUASpoofFinding(ip,
"empty/dash User-Agent", s.samples[ip], hits))
continue
}
}
}
}
// Distributed attack: many distinct abusive IPs hitting one vhost.
// Built from the per-IP findings already emitted above, so only IPs
// that crossed an abuse threshold count -- a popular site's normal
// visitor spread never trips it.
out = append(out, s.emitDistributedFlood(cfg, out)...)
return out
}
// httpAbuseChecks are the per-IP finding kinds that mark a source IP as
// abusive for the distributed-attack rollup. emitDistributedFlood uses
// each key both to filter prior findings by Check and to index
// abuseDomains, so the names here must match the check names passed to
// recordAbuseDomain.
var httpAbuseChecks = map[string]struct{}{
"wp_login_bruteforce": {},
"xmlrpc_abuse": {},
"wp_user_enumeration": {},
"http_request_flood": {},
"http_ua_spoof": {},
}
// emitDistributedFlood adds the many-IPs-one-target view on top of the
// per-IP findings in prior. For every vhost it counts the distinct
// source IPs that (a) already produced an HTTP-abuse finding and (b)
// made that abusive request shape against the vhost in this window.
// A vhost reaching HTTPDistributedMinIPs such IPs gets one
// http_distributed_flood finding; the per-IP findings are left as they
// are. Returns nil when cfg is nil or the threshold is <= 0. Output is
// ordered by vhost name.
func (s *domlogStats) emitDistributedFlood(cfg *config.Config, prior []alert.Finding) []alert.Finding {
	if cfg == nil || cfg.Thresholds.HTTPDistributedMinIPs <= 0 {
		return nil
	}
	minIPs := cfg.Thresholds.HTTPDistributedMinIPs

	// vhost -> set of abusive source IPs observed against it.
	attackers := make(map[string]map[string]struct{})
	for idx := range prior {
		check, src := prior[idx].Check, prior[idx].SourceIP
		if src == "" {
			continue
		}
		if _, abusive := httpAbuseChecks[check]; !abusive {
			continue
		}
		for vhost := range s.abuseDomains[check][src] {
			ips, seen := attackers[vhost]
			if !seen {
				ips = make(map[string]struct{})
				attackers[vhost] = ips
			}
			ips[src] = struct{}{}
		}
	}

	// Sort the vhost names so the emitted order is stable.
	vhosts := make([]string, 0, len(attackers))
	for vhost := range attackers {
		vhosts = append(vhosts, vhost)
	}
	sort.Strings(vhosts)

	var findings []alert.Finding
	for _, vhost := range vhosts {
		distinct := len(attackers[vhost])
		if distinct >= minIPs {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "http_distributed_flood",
				Domain:   vhost,
				Message:  fmt.Sprintf("Distributed HTTP attack on %s: %d distinct abusive source IPs", vhost, distinct),
				Details: fmt.Sprintf("%d source IPs each tripped an HTTP-abuse threshold against %s in this window. "+
					"Likely a botnet or distributed brute-force; consider edge rate-limiting or a challenge.", distinct, vhost),
				Timestamp: time.Now(),
			})
		}
	}
	return findings
}
// makeUASpoofFinding assembles a Critical http_ua_spoof finding for ip.
// reason is the human-readable trigger, sample is the request shown in
// Details (UA truncated to 200), and hits is the count reported there.
func (s *domlogStats) makeUASpoofFinding(ip, reason string, sample httpSample, hits int) alert.Finding {
	finding := alert.Finding{
		Severity: alert.Critical,
		Check:    "http_ua_spoof",
		SourceIP: ip,
	}
	finding.Domain = s.singleDomain(ip)
	finding.Message = "User-Agent spoof from " + ip + ": " + reason + s.vhostSuffix(ip)

	shown := sample.Method + " " + sample.URI + " UA=" + truncate(sample.UA, 200)
	finding.Details = "Hits: " + itoa(hits) + ", Sample: " + shown
	return finding
}
// singleDomain returns the vhost ip touched when it touched exactly one.
// It returns "" when the IP touched none or several, because a per-IP
// aggregate spanning multiple vhosts cannot be attributed to one domain
// (vhostSuffix describes that spread in the message instead).
func (s *domlogStats) singleDomain(ip string) string {
	only := ""
	if vhosts := s.domains[ip]; len(vhosts) == 1 {
		for name := range vhosts {
			only = name
		}
	}
	return only
}
// vhostSuffix returns " across N vhosts" when ip touched more than one
// vhost in the window, and "" otherwise.
func (s *domlogStats) vhostSuffix(ip string) string {
	touched := len(s.domains[ip])
	if touched <= 1 {
		return ""
	}
	return " across " + itoa(touched) + " vhosts"
}
// withinHTTPFloodWindow reports whether ts lies in the flood rate window
// ending at now. The window length is cfg.Thresholds.HTTPFloodWindowMin
// minutes (5 when that is <= 0), and timestamps up to one minute past
// now are tolerated for clock skew. A zero ts (malformed log line) or a
// nil cfg always yields false.
func withinHTTPFloodWindow(ts time.Time, cfg *config.Config, now time.Time) bool {
	if cfg == nil || ts.IsZero() {
		return false
	}

	span := 5 * time.Minute
	if cfg.Thresholds.HTTPFloodWindowMin > 0 {
		span = time.Duration(cfg.Thresholds.HTTPFloodWindowMin) * time.Minute
	}

	earliest := now.Add(-span)
	latest := now.Add(time.Minute)
	if ts.Before(earliest) {
		return false
	}
	return !ts.After(latest)
}
// formatLegacyMessage renders "<kind> from <ip>: <n> <unit>" for the
// legacy per-IP findings.
func formatLegacyMessage(kind, ip string, n int, unit string) string {
	msg := kind + " from " + ip
	msg += ": " + itoa(n)
	return msg + " " + unit
}
// itoa formats n in base 10 without going through fmt. It handles the
// full int range, including the minimum value.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}
	// 19 digits plus a sign is the longest possible 64-bit rendering.
	var buf [20]byte
	pos := len(buf)
	neg := n < 0

	// Take the magnitude as an unsigned value. Negating the minimum int
	// in signed arithmetic overflows back to itself, which would leave
	// the digit loop with nothing to do and produce just "-".
	mag := uint64(n)
	if neg {
		mag = uint64(-(n + 1)) + 1
	}
	for mag > 0 {
		pos--
		buf[pos] = byte('0' + mag%10)
		mag /= 10
	}
	if neg {
		pos--
		buf[pos] = '-'
	}
	return string(buf[pos:])
}
// parseAccessLogRecord parses one Combined Log Format line into an
// accessLogRecord. It does NOT use strings.Fields because quoted
// request/referer/user-agent fields can contain spaces.
//
// Format:
//
// <ip> <ident> <user> [<time>] "<method> <uri> <proto>" <status> <bytes> "<ref>" "<ua>" ["<vhost>"]
//
// Fields filled in: RemoteIP, Time (left zero when the timestamp does not
// parse), Method, URI (truncated to 4096 bytes), Status, UserAgent
// (truncated to 512 bytes) and XFF (the last trailing quoted extension
// that contains at least one IP). The referer is skipped.
//
// Returns ok=false for any line that cannot be parsed. Never panics.
//
// NOTE(review): every quoted field ends at the next '"' byte; a
// backslash-escaped quote inside a field is not recognised and would end
// the field early. Confirm whether the log format in use can emit those.
func parseAccessLogRecord(line string) (accessLogRecord, bool) {
const maxUALen = 512
const maxURILen = 4096
var rec accessLogRecord
// IP is everything up to the first space.
sp := strings.IndexByte(line, ' ')
if sp <= 0 {
return rec, false
}
rec.RemoteIP = line[:sp]
rest := line[sp+1:]
// Skip ident, user (two single-token fields). Loose: we just need to
// land at the [time] bracket.
br := strings.IndexByte(rest, '[')
if br < 0 {
return rec, false
}
rest = rest[br+1:]
closeBr := strings.IndexByte(rest, ']')
if closeBr < 0 {
return rec, false
}
timeStr := rest[:closeBr]
rest = rest[closeBr+1:]
// time format: 02/Jan/2006:15:04:05 -0700
// A timestamp that fails to parse is not fatal: rec.Time stays zero and
// the rest of the line is still parsed.
t, err := time.Parse("02/Jan/2006:15:04:05 -0700", timeStr)
if err == nil {
rec.Time = t
}
// Request quoted field.
q1 := strings.IndexByte(rest, '"')
if q1 < 0 {
return rec, false
}
rest = rest[q1+1:]
q2 := strings.IndexByte(rest, '"')
if q2 < 0 {
return rec, false
}
request := rest[:q2]
rest = rest[q2+1:]
// Split into at most method, URI, protocol; a short request line
// simply leaves the missing pieces empty.
parts := strings.SplitN(request, " ", 3)
if len(parts) >= 1 {
rec.Method = parts[0]
}
if len(parts) >= 2 {
uri := parts[1]
if len(uri) > maxURILen {
uri = uri[:maxURILen]
}
rec.URI = uri
}
// status (skip leading spaces)
// Status is only read when a space follows the token; otherwise it
// stays 0.
rest = strings.TrimLeft(rest, " ")
end := strings.IndexByte(rest, ' ')
if end > 0 {
rec.Status = atoiSafe(rest[:end])
rest = rest[end+1:]
}
// bytes field -- skip leading spaces then advance past the token
rest = strings.TrimLeft(rest, " ")
end = strings.IndexByte(rest, ' ')
if end > 0 {
rest = rest[end+1:]
} else {
// no more fields after bytes
return rec, true
}
// referer quoted field (skipped).
q1 = strings.IndexByte(rest, '"')
if q1 < 0 {
return rec, false
}
rest = rest[q1+1:]
q2 = strings.IndexByte(rest, '"')
if q2 < 0 {
return rec, false
}
rest = rest[q2+1:]
// UA quoted field.
q1 = strings.IndexByte(rest, '"')
if q1 < 0 {
return rec, true // no UA present is fine
}
rest = rest[q1+1:]
q2 = strings.IndexByte(rest, '"')
if q2 < 0 {
return rec, false
}
ua := rest[:q2]
if len(ua) > maxUALen {
ua = ua[:maxUALen]
}
rec.UserAgent = ua
rest = rest[q2+1:]
// Optional quoted extensions. cPanel may append a quoted vhost after
// UA. Custom proxy formats may append an X-Forwarded-For value. Only
// retain a quoted extension that parses as an IP list; clientIPForRecord
// still ignores it unless RemoteIP is a configured trusted proxy.
// When several extensions qualify, the last one wins.
for {
q1 = strings.IndexByte(rest, '"')
if q1 < 0 {
break
}
rest = rest[q1+1:]
q2 = strings.IndexByte(rest, '"')
if q2 < 0 {
return rec, false
}
extra := rest[:q2]
if looksLikeXFF(extra) {
rec.XFF = extra
}
rest = rest[q2+1:]
}
return rec, true
}
// atoiSafe converts the leading run of ASCII digits in s to an int and
// stops at the first non-digit. It returns 0 when s does not start with
// a digit (including a sign or empty input) and never reports an error.
func atoiSafe(s string) int {
	total := 0
	for _, b := range []byte(s) {
		if b < '0' || b > '9' {
			return total
		}
		total = total*10 + int(b-'0')
	}
	return total
}
// clientIPForRecord picks the address to attribute a request to. The
// logged X-Forwarded-For value is consulted only when cfg lists trusted
// proxies and rec.RemoteIP is one of them; in every other case the
// direct peer address is returned.
//
// For a trusted proxy, the header is walked from right to left and the
// first entry that parses as an IP is returned. The rightmost entry is
// the one the proxy itself appended; entries further left may have been
// supplied by the client, and are only reached when the entries to their
// right do not parse. If nothing parses, rec.RemoteIP is returned.
func clientIPForRecord(rec accessLogRecord, cfg *config.Config) string {
	direct := rec.RemoteIP
	if cfg == nil || rec.XFF == "" {
		return direct
	}
	proxies := cfg.WebServer.TrustedProxies
	if len(proxies) == 0 || !isTrustedProxy(direct, proxies) {
		return direct
	}

	hops := strings.Split(rec.XFF, ",")
	for len(hops) > 0 {
		last := strings.TrimSpace(hops[len(hops)-1])
		hops = hops[:len(hops)-1]
		if net.ParseIP(last) != nil {
			return last
		}
	}
	return direct
}
// normalizeHTTPClientIP turns a raw client address into the canonical
// textual form of its IP. Surrounding whitespace, a ":port" suffix and
// IPv6 square brackets are removed first. It returns "" when the result
// is not an IP, or is a loopback or unspecified address.
func normalizeHTTPClientIP(raw string) string {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}

	// Only host:port / [host]:port forms split cleanly; anything else
	// is kept as-is.
	host, _, err := net.SplitHostPort(candidate)
	if err == nil {
		candidate = host
	}
	candidate = strings.Trim(candidate, "[]")

	parsed := net.ParseIP(candidate)
	switch {
	case parsed == nil:
		return ""
	case parsed.IsLoopback(), parsed.IsUnspecified():
		return ""
	}
	return parsed.String()
}
// isTrustedProxy reports whether addr is covered by any entry in
// proxies. An entry may be a CIDR block or a single IP; blank entries
// and entries that parse as neither are ignored. An addr that is not a
// valid IP is never trusted.
func isTrustedProxy(addr string, proxies []string) bool {
	target := net.ParseIP(strings.TrimSpace(addr))
	if target == nil {
		return false
	}
	for _, raw := range proxies {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		_, network, err := net.ParseCIDR(candidate)
		if err != nil {
			// Not a CIDR: fall back to an exact-address comparison.
			if single := net.ParseIP(candidate); single != nil && single.Equal(target) {
				return true
			}
			continue
		}
		if network.Contains(target) {
			return true
		}
	}
	return false
}
// looksLikeXFF reports whether raw, read as a comma-separated list,
// contains at least one entry that parses as an IP address.
func looksLikeXFF(raw string) bool {
	fields := strings.Split(raw, ",")
	for idx := 0; idx < len(fields); idx++ {
		entry := strings.TrimSpace(fields[idx])
		if net.ParseIP(entry) != nil {
			return true
		}
	}
	return false
}
// classifyUA maps a User-Agent string to a uaKind. Only the first 512
// bytes are considered, and substring matching runs on a lower-cased
// copy because scanner and impersonation tools routinely vary case.
//
// Precedence, first match wins: empty or "-" UA, known scanner,
// WordPress pingback UA on a GET, claimed search/AI bot, headless
// browser, scripting-language client; anything else is a browser.
func classifyUA(ua, method string) uaKind {
	const maxUALen = 512
	if len(ua) > maxUALen {
		ua = ua[:maxUALen]
	}
	switch ua {
	case "", "-":
		return uaKindEmpty
	}

	lowered := strings.ToLower(ua)
	hasAny := func(needles []string) bool {
		for _, needle := range needles {
			if strings.Contains(lowered, needle) {
				return true
			}
		}
		return false
	}

	switch {
	case hasAny(knownScannerSubstrings):
		return uaKindKnownScanner
	case method == "GET" && strings.HasPrefix(lowered, "wordpress/"):
		// Legitimate pingback clients always POST, so a GET carrying
		// this UA is a scraper or probe borrowing the pingback agent.
		return uaKindWPSpoofPingback
	case hasAny(claimedBotSubstrings):
		return uaKindClaimedBot
	case hasAny(headlessSubstrings):
		return uaKindHeadless
	case hasAny(scriptingSubstrings):
		return uaKindScriptingLang
	}
	return uaKindBrowser
}
// Substring tables consulted by classifyUA. They are matched with
// strings.Contains against the lower-cased User-Agent, so every entry
// must itself be lower case.
var (
// knownScannerSubstrings identify security scanners; classifyUA checks
// these first.
// NOTE(review): "nmap " carries a trailing space, presumably to avoid
// matching longer words that contain "nmap" -- confirm.
knownScannerSubstrings = []string{
"nikto", "sqlmap", "acunetix", "nmap ", "masscan", "wpscan",
"nuclei", "dirbuster", "gobuster", "feroxbuster",
}
// claimedBotSubstrings identify User-Agents that claim to be a search
// engine or AI crawler; classifyUA maps them to uaKindClaimedBot.
claimedBotSubstrings = []string{
"googlebot", "bingbot", "applebot", "duckduckbot", "yandexbot",
"baiduspider", "facebookexternalhit", "twitterbot",
// Appendix A bots: no published static IP range; rDNS-verified.
"amazonbot", "gptbot", "chatgpt-user", "claudebot", "claude-searchbot",
"perplexitybot", "meta-externalagent", "meta-webindexer", "bravebot",
}
// headlessSubstrings identify headless-browser automation.
headlessSubstrings = []string{
"headlesschrome", "phantomjs", "puppeteer", "playwright",
}
// scriptingSubstrings identify HTTP client libraries and CLI tools.
scriptingSubstrings = []string{
"python-requests/", "curl/", "go-http-client/", "java/", "wget/",
"libwww-perl/", "node-fetch/",
}
)
// staticAllowlistClassifier verifies bots against the embedded static
// IP ranges only. A request counts as verified when its User-Agent
// claims a known bot and the source IP lies inside that bot's published
// range; verified requests are excluded from all counters. IPs outside
// the static ranges are the business of verifyingClassifier (rDNS).
type staticAllowlistClassifier struct{}

// IsVerifiedBot reports whether ipStr falls inside the static range of
// the bot claimed by ua. It is false when ua claims no known bot or
// ipStr is not a valid IP.
func (staticAllowlistClassifier) IsVerifiedBot(ipStr, ua string) bool {
	if claimed := threatintel.ClaimedBotFromUA(ua); claimed != "" {
		if addr := net.ParseIP(ipStr); addr != nil {
			return threatintel.DefaultRanges().IPInBot(addr, claimed)
		}
	}
	return false
}
// verifyingClassifier checks the static allowlist first and then the
// rDNS verification cache. On a cache miss it queues an asynchronous
// verification and reports the request as unverified for this scan.
type verifyingClassifier struct {
	async    *threatintel.AsyncBotVerifier
	cacheGet func(net.IP, string) (bool, bool)
}

// newVerifyingClassifier wires a classifier to the given async verifier
// and cache reader. Either may be nil, in which case that step is
// skipped.
func newVerifyingClassifier(async *threatintel.AsyncBotVerifier,
	cacheGet func(net.IP, string) (bool, bool)) verifyingClassifier {
	var c verifyingClassifier
	c.async = async
	c.cacheGet = cacheGet
	return c
}

// IsVerifiedBot reports whether the bot claimed by ua is confirmed for
// ipStr, either by the static range or by a valid cache entry. With no
// valid cache entry it enqueues a verification job and returns false.
func (c verifyingClassifier) IsVerifiedBot(ipStr, ua string) bool {
	claimed := threatintel.ClaimedBotFromUA(ua)
	if claimed == "" {
		return false
	}
	addr := net.ParseIP(ipStr)
	if addr == nil {
		return false
	}

	switch {
	case threatintel.DefaultRanges().IPInBot(addr, claimed):
		// Fast positive path: published static range.
		return true
	case c.cacheGet != nil:
		// A valid entry is authoritative either way; the negative case
		// is picked up separately through ConfirmedNegative.
		if verdict, known := c.cacheGet(addr, claimed); known {
			return verdict
		}
	}

	// Nothing known yet: ask the async verifier and stay unverified for
	// this scan cycle.
	if c.async != nil {
		c.async.Enqueue(addr, claimed)
	}
	return false
}
// ConfirmedNegative reports whether the rDNS cache holds a definitive
// "not this bot" result for the IP and the bot claimed by ua. scan uses
// it to decide whether uaKindClaimedBot becomes
// uaKindClaimedBotNegative. It is false when ua claims no bot, ipStr is
// not an IP, no cache reader is installed, or the cache has no valid
// entry.
func (c verifyingClassifier) ConfirmedNegative(ipStr, ua string) bool {
	claimed := threatintel.ClaimedBotFromUA(ua)
	if claimed == "" {
		return false
	}
	addr := net.ParseIP(ipStr)
	if c.cacheGet == nil || addr == nil {
		return false
	}
	if verified, known := c.cacheGet(addr, claimed); known {
		return !verified
	}
	return false
}
// Daemon-lifetime bot verification hooks. botMu guards both values.
var (
	globalBotVerifier *threatintel.AsyncBotVerifier
	globalBotGet      func(net.IP, string) (bool, bool)
	botMu             sync.RWMutex
)

// SetBotVerifier installs the daemon-lifetime async verifier and cache
// reader. Called from daemon.go after the store and goroutine are ready.
func SetBotVerifier(v *threatintel.AsyncBotVerifier, get func(net.IP, string) (bool, bool)) {
	botMu.Lock()
	globalBotVerifier, globalBotGet = v, get
	botMu.Unlock()
}
// currentBotClassifier selects the botClassifier for this scan. With a
// nil cfg or bot verification disabled it returns the static-only
// classifier, so no DNS-backed verification is ever triggered.
// Otherwise it snapshots the installed verifier and cache reader under
// the read lock and returns a verifying classifier built from them.
func currentBotClassifier(cfg *config.Config) botClassifier {
	if cfg != nil && cfg.BotVerifyEnabled() {
		botMu.RLock()
		verifier, lookup := globalBotVerifier, globalBotGet
		botMu.RUnlock()
		return newVerifyingClassifier(verifier, lookup)
	}
	return staticAllowlistClassifier{}
}
package checks
import (
"errors"
"fmt"
)
// mysqlIdentMaxLen is the MySQL identifier limit (64 per the manual).
// Longer names are rejected here, at validation time, so the failure is
// loud rather than surfacing later from the server.
const mysqlIdentMaxLen = 64

// Validation errors returned by QuoteIdent. They are phrased for the
// operator so the CLI can say "the name you typed is not a valid MySQL
// identifier" instead of relaying a SQL syntax error from the server.
var (
	errEmptyIdent   = errors.New("identifier is empty")
	errInvalidIdent = errors.New("identifier contains characters outside [A-Za-z0-9_$]")
	errLongIdent    = fmt.Errorf("identifier exceeds %d bytes (MySQL limit)", mysqlIdentMaxLen)
)

// QuoteIdent returns name wrapped in backticks, ready to be spliced
// into a SQL statement as an identifier. It is meant for every site
// where an attacker-controlled object name (trigger / event / routine /
// schema) would otherwise reach string-built SQL.
//
// Accepted names are 1..64 bytes long and consist solely of ASCII
// letters, digits, '_' and '$'. Everything else that MySQL itself would
// allow (dotted names, spaces, backticks, non-ASCII) is rejected on
// purpose: the cleaner only needs CMS-shaped identifiers and
// operator-typed schema names, and refusing the rest is cheaper than
// reasoning about quoting edge cases in dynamic SQL.
//
// Errors: errEmptyIdent for "", errLongIdent for names over the limit,
// and an error wrapping errInvalidIdent (with the offending name
// quoted) for any disallowed character.
func QuoteIdent(name string) (string, error) {
	switch size := len(name); {
	case size == 0:
		return "", errEmptyIdent
	case size > mysqlIdentMaxLen:
		return "", errLongIdent
	}
	for _, r := range name {
		safe := r == '_' || r == '$' ||
			('0' <= r && r <= '9') ||
			('a' <= r && r <= 'z') ||
			('A' <= r && r <= 'Z')
		if !safe {
			return "", fmt.Errorf("%w: %q", errInvalidIdent, name)
		}
	}
	return "`" + name + "`", nil
}
package checks
import (
"context"
"fmt"
"math"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// localThreatScoreFindingLimit caps how many findings one run of
// CheckLocalThreatScore may return.
const localThreatScoreFindingLimit = 50

// CheckLocalThreatScore generates findings for IPs that have accumulated
// a high local threat score but have not yet been blocked.
// Runs every 10 minutes as part of TierCritical.
//
// An IP is reported when its computed score is at least 70 and it is
// not in the blocked set loaded from cfg.StatePath. The walk stops once
// localThreatScoreFindingLimit findings are collected or ctx is done,
// returning whatever was gathered so far. Returns nil when no global
// attack database is available.
func CheckLocalThreatScore(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := attackdb.Global()
	if db == nil {
		return nil
	}
	blocked := loadAllBlockedIPs(cfg.StatePath)

	var out []alert.Finding
	for _, attacker := range db.TopAttackers(math.MaxInt) {
		if ctx.Err() != nil {
			break
		}
		if blocked[attacker.IP] {
			continue
		}
		score := attackdb.ComputeScore(attacker)
		if score < 70 {
			continue
		}
		out = append(out, alert.Finding{
			Severity:  alert.Critical,
			Check:     "local_threat_score",
			Message:   fmt.Sprintf("High local threat score: %s (score %d/100, %d attacks)", attacker.IP, score, attacker.EventCount),
			Details:   fmt.Sprintf("Attack types: %v\nAccounts targeted: %d\nFirst seen: %s\nLast seen: %s", attacker.AttackCounts, len(attacker.Accounts), attacker.FirstSeen.Format("2006-01-02 15:04"), attacker.LastSeen.Format("2006-01-02 15:04")),
			Timestamp: time.Now(),
			SourceIP:  attacker.IP,
		})
		if len(out) >= localThreatScoreFindingLimit {
			break
		}
	}
	return out
}
package checks
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckMailQueue reports on the size of the Exim mail queue as printed
// by `exim -bpc`. A count at or above cfg.Thresholds.MailQueueCrit
// yields one Critical finding; otherwise a count at or above
// MailQueueWarn yields one Warning. It returns nil when the command
// fails, prints nothing, prints something that is not an integer, or
// the queue is below both thresholds. ctx is currently not consulted.
func CheckMailQueue(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	raw, err := runCmd("exim", "-bpc")
	if err != nil || raw == nil {
		return nil
	}
	queued, err := strconv.Atoi(strings.TrimSpace(string(raw)))
	if err != nil {
		return nil
	}

	switch {
	case queued >= cfg.Thresholds.MailQueueCrit:
		return []alert.Finding{{
			Severity: alert.Critical,
			Check:    "mail_queue",
			Message:  fmt.Sprintf("Exim mail queue critical: %d messages", queued),
			Details:  "Possible spam outbreak from compromised account",
		}}
	case queued >= cfg.Thresholds.MailQueueWarn:
		return []alert.Finding{{
			Severity: alert.Warning,
			Check:    "mail_queue",
			Message:  fmt.Sprintf("Exim mail queue elevated: %d messages", queued),
		}}
	}
	return nil
}
package checks
import (
"context"
"crypto/sha256"
"fmt"
"path/filepath"
"sort"
"strings"
"time"
"unicode"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// Mail-filter exfiltration detector.
//
// cPanel stores per-mailbox Exim filters at
// /home/<user>/etc/<domain>/<localpart>/filter and domain-wide defaults at
// /etc/vfilters/<domain>. A compromised webmail account is commonly weaponised
// by writing a filter that copies every inbound message to an external dropbox
// while keeping a local copy, so the victim never notices the interception
// (business email compromise). This check parses those Exim filters and scores
// the deliver/save actions for that stealth pattern.
//
// Unlike CheckForwarders (valiases redirects), the stealth combination here is
// inherently malicious, so it is reported even when the filter predates CSM --
// newness gating only applies to plain external forwards that are frequently
// legitimate customer configuration.
// filterAction is a single Exim filter action (deliver/save/pipe/finish/...).
type filterAction struct {
// verb is the lower-cased action keyword as recognised by parseEximFilter.
verb string
// arg is the text of the token following the verb, when the verb takes
// an argument and the next token is a string or a non-keyword bareword;
// otherwise it is empty.
arg string
// unseen is true when the action was preceded by the "unseen" keyword.
unseen bool
}
// filterRule is one branch of an Exim filter: the condition that guards it and
// the actions it performs. The unconditional top level is represented as a rule
// with an empty condition.
type filterRule struct {
// condition is the rendered condition text; "else" for an else branch.
// Rules returned by parseEximFilter join ancestor conditions with " && ".
condition string
// matchesAll is the match-all heuristic's verdict for this rule (see
// conditionMatchesAll).
matchesAll bool
// actions are the actions of the branch, in source order.
actions []filterAction
}
// filterMailbox identifies the mailbox a filter file belongs to. localPart is
// "*" for a domain-wide /etc/vfilters file.
type filterMailbox struct {
localPart string
domain string
}
// String renders the mailbox as "<localPart>@<domain>".
func (m filterMailbox) String() string {
return m.localPart + "@" + m.domain
}
// filterFinding is the scorer's intermediate result before it is turned into an
// alert.Finding (which needs file path and newness context).
type filterFinding struct {
severity alert.Severity
check string
kind string // "exfil" | "forwarder" | "pipe" | "blackhole"
dest string // external destination, when applicable, for correlation
reason string
// NOTE(review): presumably marks findings that should only be reported
// when the filter is new (the newness gating for plain external
// forwards) -- confirm against the scorer.
onlyIfNew bool
}
// safePipeCommands are cPanel built-in pipe targets that are not attacker code.
// NOTE(review): how a pipe argument is compared against these paths
// (exact match or prefix) is decided by the scorer, which is not shown
// here -- confirm before adding entries.
var safePipeCommands = []string{
"/usr/local/cpanel/bin/autorespond",
"/usr/local/cpanel/bin/boxtrapper",
"/usr/local/cpanel/bin/mailman",
}
// ---------------------------------------------------------------------------
// Parser
// ---------------------------------------------------------------------------
// eximToken is one lexical token of an Exim filter. str is true when
// the token came from a double-quoted string, in which case text holds
// the unquoted, unescaped contents.
type eximToken struct {
	text string
	str  bool
}

// tokenizeExim splits Exim filter source into tokens:
//
//   - a double-quoted string becomes one string token with the quotes
//     removed; a backslash makes the following character literal (so
//     \" and \\ work), and an unterminated string runs to end of input;
//   - each parenthesis is a token of its own;
//   - '#' outside a string starts a comment that runs to end of line
//     and is dropped;
//   - any other run of characters up to whitespace, a parenthesis, a
//     quote or '#' is a bareword.
func tokenizeExim(s string) []eximToken {
	src := []rune(s)
	total := len(src)
	var out []eximToken

	for pos := 0; pos < total; {
		ch := src[pos]

		if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
			pos++
			continue
		}
		if ch == '#' {
			// The newline itself is consumed as whitespace next round.
			for pos < total && src[pos] != '\n' {
				pos++
			}
			continue
		}
		if ch == '(' || ch == ')' {
			out = append(out, eximToken{text: string(ch)})
			pos++
			continue
		}
		if ch == '"' {
			pos++ // opening quote
			var lit strings.Builder
			for pos < total && src[pos] != '"' {
				// A backslash with something after it escapes that
				// character; a trailing backslash is kept literally.
				if src[pos] == '\\' && pos+1 < total {
					pos++
				}
				lit.WriteRune(src[pos])
				pos++
			}
			if pos < total {
				pos++ // closing quote
			}
			out = append(out, eximToken{text: lit.String(), str: true})
			continue
		}

		begin := pos
		for pos < total && !strings.ContainsRune(" \t\n\r()\"#", src[pos]) {
			pos++
		}
		out = append(out, eximToken{text: string(src[begin:pos])})
	}
	return out
}

// renderCondition rebuilds a condition string from its tokens, joined
// by single spaces with string tokens wrapped in double quotes again,
// so the match-all heuristic can run on text shaped like the source.
//
// Embedded quotes and backslashes in string tokens are not re-escaped,
// so for such tokens the output does not tokenize back to the same
// token list.
func renderCondition(toks []eximToken) string {
	rendered := make([]string, 0, len(toks))
	for idx := range toks {
		piece := toks[idx].text
		if toks[idx].str {
			piece = `"` + piece + `"`
		}
		rendered = append(rendered, piece)
	}
	return strings.Join(rendered, " ")
}
// actionVerbs are the lower-cased barewords parseEximFilter records as
// actions; any other bareword outside a condition is skipped.
var actionVerbs = map[string]bool{
"deliver": true,
"save": true,
"pipe": true,
"finish": true,
"mail": true,
"vacation": true,
}
// actionArgs are the action verbs for which parseEximFilter tries to
// consume the next token as the action's argument. "finish" is absent
// because it takes none.
var actionArgs = map[string]bool{
"deliver": true,
"save": true,
"pipe": true,
"mail": true,
"vacation": true,
}
// controlWords are structural keywords. isBareActionArg refuses to treat
// them as a bareword argument of an action.
var controlWords = map[string]bool{
"if": true,
"elif": true,
"else": true,
"endif": true,
"then": true,
"unseen": true,
}
// filterRuleNode is a branch in the tree parseEximFilter builds while
// walking the filter. parent is the enclosing branch and is nil for the
// top-level node; flattenRuleNode follows it to fold ancestor conditions
// and actions into one filterRule.
type filterRuleNode struct {
rule filterRule
parent *filterRuleNode
}
// parseEximFilter parses Exim filter source into a flat list of rules, one per
// if/elif/else branch plus one for any unconditional top-level actions. Nested
// branches include ancestor actions so split deliver/save patterns still score
// as one executed branch.
//
// The parser keeps a stack of open branches whose bottom entry is the
// top-level node and is never popped. Branches that end up with no
// actions of their own are left out of the result. Unknown barewords and
// stray string tokens are skipped, so malformed input degrades to fewer
// rules rather than an error.
func parseEximFilter(content string) []filterRule {
toks := tokenizeExim(content)
top := &filterRuleNode{rule: filterRule{matchesAll: true}}
stack := []*filterRuleNode{top}
rules := []*filterRuleNode{top}
// pendingUnseen carries an "unseen" keyword forward to the next action;
// any structural keyword clears it.
pendingUnseen := false
i := 0
for i < len(toks) {
t := toks[i]
// A string token outside a condition or argument position carries no
// structure.
if t.str {
i++
continue
}
kw := strings.ToLower(t.text)
switch kw {
case "if", "elif":
// elif closes the previous sibling branch first, so the new branch
// hangs off the same parent; a plain if nests under the current one.
if kw == "elif" && len(stack) > 1 {
stack = stack[:len(stack)-1]
}
parent := stack[len(stack)-1]
i++
// Everything up to the next bareword "then" is the condition.
condStart := i
for i < len(toks) && !tokenIs(toks[i], "then") {
i++
}
cond := renderCondition(toks[condStart:i])
if i < len(toks) {
i++ // consume "then"
}
r := &filterRuleNode{
rule: filterRule{condition: cond, matchesAll: conditionMatchesAll(cond)},
parent: parent,
}
rules = append(rules, r)
stack = append(stack, r)
pendingUnseen = false
case "else":
// Like elif: replace the previous sibling branch on the stack. The
// else rule keeps matchesAll at its zero value (false).
if len(stack) > 1 {
stack = stack[:len(stack)-1]
}
parent := stack[len(stack)-1]
i++
r := &filterRuleNode{
rule: filterRule{condition: "else"},
parent: parent,
}
rules = append(rules, r)
stack = append(stack, r)
pendingUnseen = false
case "endif":
if len(stack) > 1 {
stack = stack[:len(stack)-1]
}
i++
pendingUnseen = false
case "unseen":
pendingUnseen = true
i++
default:
if !actionVerbs[kw] {
i++
continue
}
verb := kw
i++
arg := ""
// Take the next token as the argument only when the verb expects one
// and the token is a string or a non-keyword bareword.
if actionArgs[verb] && i < len(toks) {
if toks[i].str || isBareActionArg(toks[i]) {
arg = toks[i].text
i++
}
}
// The action belongs to the innermost open branch.
cur := stack[len(stack)-1]
cur.rule.actions = append(cur.rule.actions, filterAction{verb: verb, arg: arg, unseen: pendingUnseen})
pendingUnseen = false
}
}
// Only branches with actions of their own are emitted; each is flattened
// so it also carries its ancestors' conditions and actions.
out := make([]filterRule, 0, len(rules))
for _, r := range rules {
if len(r.rule.actions) > 0 {
out = append(out, flattenRuleNode(r))
}
}
return out
}
// tokenIs reports whether t is a bareword (not a quoted string) equal
// to word, ignoring case.
func tokenIs(t eximToken, word string) bool {
	if t.str {
		return false
	}
	return strings.EqualFold(t.text, word)
}
// isBareActionArg reports whether t may serve as the argument of an
// action: any quoted string, or a bareword that is neither a control
// word nor an action verb.
func isBareActionArg(t eximToken) bool {
	if t.str {
		return true
	}
	word := strings.ToLower(t.text)
	if controlWords[word] || actionVerbs[word] {
		return false
	}
	return true
}
// flattenRuleNode folds node and all of its ancestors into one
// filterRule, outermost first. The non-empty conditions are joined with
// " && ", the actions are concatenated, and matchesAll is true only when
// every rule on the path matches all mail.
func flattenRuleNode(node *filterRuleNode) filterRule {
	// Lay the path out root-first.
	depth := 0
	for n := node; n != nil; n = n.parent {
		depth++
	}
	path := make([]*filterRuleNode, depth)
	for n, slot := node, depth-1; n != nil; n, slot = n.parent, slot-1 {
		path[slot] = n
	}

	flat := filterRule{matchesAll: true}
	var conds []string
	for _, step := range path {
		if c := step.rule.condition; c != "" {
			conds = append(conds, c)
		}
		flat.matchesAll = flat.matchesAll && step.rule.matchesAll
		flat.actions = append(flat.actions, step.rule.actions...)
	}
	flat.condition = strings.Join(conds, " && ")
	return flat
}
// conditionMatchesAll reports whether a filter condition fires on
// effectively all mail: either it is blank (unconditional), or its
// token expression only asserts address or header comparisons that
// hold for every normal email address.
func conditionMatchesAll(cond string) bool {
	trimmed := strings.TrimSpace(cond)
	return trimmed == "" || tokenExpressionMatchesAll(tokenizeExim(trimmed))
}
// tokenExpressionMatchesAll treats toks as an "or" expression and
// reports whether at least one top-level alternative matches all mail.
// An empty expression (after stripping enclosing parentheses) does not.
func tokenExpressionMatchesAll(toks []eximToken) bool {
	inner := trimOuterParens(toks)
	if len(inner) == 0 {
		return false
	}
	alternatives := splitTopLevel(inner, "or")
	for idx := range alternatives {
		if tokenConjunctionMatchesAll(alternatives[idx]) {
			return true
		}
	}
	return false
}
// tokenConjunctionMatchesAll treats toks as an "and" expression and
// reports whether every top-level operand matches all mail. It is false
// when there are no operands at all.
func tokenConjunctionMatchesAll(toks []eximToken) bool {
	inner := trimOuterParens(toks)
	if len(inner) == 0 {
		return false
	}
	operands := splitTopLevel(inner, "and")
	if len(operands) == 0 {
		return false
	}
	for idx := range operands {
		if !tokenTermMatchesAll(operands[idx]) {
			return false
		}
	}
	return true
}
// tokenTermMatchesAll decides whether a single term matches all mail.
// A term that still contains a top-level "or" or "and" (typically
// after parentheses were stripped) is handed back to the expression or
// conjunction evaluator. Otherwise the term matches all mail only when
// it contains no bareword "not", has at least one address comparison
// (contains / matches / is), and every such comparison both uses a
// match-everything value and is applied to an address-like operand.
func tokenTermMatchesAll(toks []eximToken) bool {
	toks = trimOuterParens(toks)
	if len(toks) == 0 {
		return false
	}
	if len(splitTopLevel(toks, "or")) > 1 {
		return tokenExpressionMatchesAll(toks)
	}
	if len(splitTopLevel(toks, "and")) > 1 {
		return tokenConjunctionMatchesAll(toks)
	}

	sawComparison := false
	lastIdx := len(toks) - 1
	for idx, tok := range toks {
		if tok.str {
			continue
		}
		// Any negation disqualifies the term outright.
		if strings.EqualFold(tok.text, "not") {
			return false
		}
		// An operator needs a value token after it.
		if idx == lastIdx {
			continue
		}
		op := strings.ToLower(tok.text)
		if !isAddressComparisonOperator(op) {
			continue
		}
		if !(comparisonMatchesAllAddress(op, toks[idx+1].text) && comparisonHasAddressOperand(toks, idx)) {
			return false
		}
		sawComparison = true
	}
	return sawComparison
}
// splitTopLevel cuts toks into segments at every unquoted occurrence of word
// (compared case-insensitively) that sits outside all parentheses. The
// separator tokens are dropped and empty segments are not emitted. String
// tokens never count as parentheses or separators.
func splitTopLevel(toks []eximToken, word string) [][]eximToken {
	var segments [][]eximToken
	begin, nesting := 0, 0
	for pos := 0; pos < len(toks); pos++ {
		if toks[pos].str {
			continue
		}
		text := toks[pos].text
		if text == "(" {
			nesting++
			continue
		}
		if text == ")" {
			// Unbalanced closers are tolerated: depth never goes negative.
			if nesting > 0 {
				nesting--
			}
			continue
		}
		if nesting != 0 || !strings.EqualFold(text, word) {
			continue
		}
		if begin < pos {
			segments = append(segments, toks[begin:pos])
		}
		begin = pos + 1
	}
	if begin < len(toks) {
		segments = append(segments, toks[begin:])
	}
	return segments
}
// trimOuterParens repeatedly removes a leading "(" and trailing ")" for as
// long as outerParensEncloseAll reports that the pair wraps the entire
// sequence, and returns what is left.
func trimOuterParens(toks []eximToken) []eximToken {
	for {
		n := len(toks)
		if n < 2 || !tokenIs(toks[0], "(") || !tokenIs(toks[n-1], ")") || !outerParensEncloseAll(toks) {
			return toks
		}
		toks = toks[1 : n-1]
	}
}
// outerParensEncloseAll reports whether the parentheses in toks are balanced
// and the nesting depth returns to zero only at the final token. String tokens
// are ignored when counting.
func outerParensEncloseAll(toks []eximToken) bool {
	last := len(toks) - 1
	open := 0
	for pos := range toks {
		if toks[pos].str {
			continue
		}
		if toks[pos].text == "(" {
			open++
			continue
		}
		if toks[pos].text != ")" {
			continue
		}
		open--
		// Closing too early (or closing more than was opened) means the
		// first "(" does not pair with the last ")".
		if open < 0 || (open == 0 && pos != last) {
			return false
		}
	}
	return open == 0
}
// isAddressComparisonOperator reports whether op is one of the comparison
// operators the match-all analysis understands: "contains", "matches" or "is".
// The comparison is exact, so callers lowercase op first.
func isAddressComparisonOperator(op string) bool {
	return op == "contains" || op == "matches" || op == "is"
}
// comparisonMatchesAllAddress reports whether comparing an address with
// operator op against value is true for every normal (non-empty, containing
// "@") email address. value is trimmed and lowercased before comparison; op
// must already be lowercase.
//
// For "matches", the catch-all patterns ".*", ".+", "^.*$" and "^.+$" are
// accepted alongside the "@"-based ones: each accepts a superset of what
// ".*@.*" / ".+@.+" accept, so a rule written with them forwards at least as
// much mail and must not be classified as a narrow condition.
func comparisonMatchesAllAddress(op, value string) bool {
	v := strings.ToLower(strings.TrimSpace(value))
	switch op {
	case "contains":
		return v == "@"
	case "matches":
		switch v {
		case "@", ".*@.*", ".+@.+", "^.*@.*$", "^.+@.+$",
			".*", ".+", "^.*$", "^.+$":
			return true
		}
	case "is":
		return v == "*@*"
	}
	return false
}
// comparisonHasAddressOperand walks backwards from the operator at opIndex
// looking for an address-like operand on its left-hand side. String tokens and
// "not" are skipped. Reaching a boolean connective, "then"/"else" or a
// parenthesis first means the comparison has no such operand.
func comparisonHasAddressOperand(toks []eximToken, opIndex int) bool {
	for pos := opIndex - 1; pos >= 0; pos-- {
		if toks[pos].str {
			continue
		}
		lowered := strings.ToLower(toks[pos].text)
		if lowered == "and" || lowered == "or" || lowered == "then" || lowered == "else" ||
			lowered == "(" || lowered == ")" {
			return false
		}
		if lowered != "not" && tokenLooksAddressOperand(lowered) {
			return true
		}
	}
	return false
}
// tokenLooksAddressOperand reports whether token contains the name of a
// variable or keyword that carries an email address: $thisaddress,
// foranyaddress, the envelope sender/return path, or one of the address
// headers in either $header_ or $h_ spelling. Matching is a case-sensitive
// substring test.
func tokenLooksAddressOperand(token string) bool {
	for _, marker := range [...]string{
		"$thisaddress",
		"foranyaddress",
		"$sender_address",
		"$return_path",
		"$header_from",
		"$h_from",
		"$header_to",
		"$h_to",
		"$header_cc",
		"$h_cc",
		"$header_bcc",
		"$h_bcc",
		"$header_reply-to",
		"$h_reply-to",
		"$header_sender",
		"$h_sender",
		"$header_return-path",
		"$h_return-path",
	} {
		if strings.Contains(token, marker) {
			return true
		}
	}
	return false
}
// ---------------------------------------------------------------------------
// Scorer
// ---------------------------------------------------------------------------
// destIsExternal reports whether an Exim deliver destination leaves the local
// mail system. Destinations that cannot be split into local part and domain
// are not external, and neither are domains deliverDomainIsLocal accepts (the
// $domain variable, the mailbox's own domain, or a configured local domain).
func destIsExternal(dest string, mb filterMailbox, localDomains map[string]bool) bool {
	if _, domain, valid := splitDeliverDest(dest); valid {
		return !deliverDomainIsLocal(domain, mb, localDomains)
	}
	return false
}
// destIsLocalSelf reports whether a deliver destination routes back into the
// local mail system (a self re-delivery that keeps a copy for the victim).
// Destinations that cannot be parsed are not considered local.
func destIsLocalSelf(dest string, mb filterMailbox, localDomains map[string]bool) bool {
	_, domain, valid := splitDeliverDest(dest)
	return valid && deliverDomainIsLocal(domain, mb, localDomains)
}
// splitDeliverDest splits a deliver destination at its last "@" into local
// part and lowercased domain. Surrounding whitespace and double quotes are
// stripped from the whole value and from each half. The boolean is false when
// there is no "@", nothing follows it, or either half ends up empty.
func splitDeliverDest(dest string) (string, string, bool) {
	unquote := func(s string) string {
		return strings.Trim(strings.TrimSpace(s), `"`)
	}
	addr := unquote(dest)
	sep := strings.LastIndex(addr, "@")
	if sep == -1 || sep+1 == len(addr) {
		return "", "", false
	}
	localPart := unquote(addr[:sep])
	domain := strings.ToLower(unquote(addr[sep+1:]))
	if localPart != "" && domain != "" {
		return localPart, domain, true
	}
	return "", "", false
}
// deliverDomainIsLocal reports whether dom (compared trimmed and lowercased)
// stays on this server: the Exim $domain / ${domain} variable, the domain of
// the mailbox that owns the filter, or an entry in localDomains.
func deliverDomainIsLocal(dom string, mb filterMailbox, localDomains map[string]bool) bool {
	normalized := strings.ToLower(strings.TrimSpace(dom))
	if normalized == "$domain" || normalized == "${domain}" || normalized == strings.ToLower(mb.domain) {
		return true
	}
	return localDomains[normalized]
}
// isSafePipe reports whether the first word of a pipe command exactly equals
// one of the entries in safePipeCommands.
func isSafePipe(cmd string) bool {
	word := firstPipeCommandWord(cmd)
	for _, allowed := range safePipeCommands {
		if allowed == word {
			return true
		}
	}
	return false
}
// firstPipeCommandWord extracts the first shell-style word of a pipe command.
// One leading "|" and surrounding whitespace are dropped. Single or double
// quotes group characters (including spaces) and are removed from the result.
// The word ends at the first whitespace outside quotes.
func firstPipeCommandWord(cmd string) string {
	body := strings.TrimSpace(cmd)
	body = strings.TrimPrefix(body, "|")
	body = strings.TrimSpace(body)
	var word strings.Builder
	var open rune // active quote character; 0 while outside quotes
scan:
	for _, ch := range body {
		switch {
		case open != 0 && ch == open:
			open = 0
		case open != 0:
			word.WriteRune(ch)
		case ch == '\'' || ch == '"':
			open = ch
		case unicode.IsSpace(ch):
			break scan
		default:
			word.WriteRune(ch)
		}
	}
	return word.String()
}
// scoreFilterRules evaluates one mailbox's parsed filter rules and returns the
// dangerous patterns found. Suppression entries in known (format
// "local@domain: dest") drop matching plain external destinations.
//
// Per rule it classifies the actions:
//   - deliver to an external address: "exfil" (Critical) when the rule is
//     stealthy, otherwise "forwarder" (High, reported only for new files);
//   - pipe to a command not accepted by isSafePipe: "pipe" (Critical);
//   - a match-all rule that saves to /dev/null without any external
//     delivery: "blackhole" (High).
//
// A delivery is stealthy when the same rule also keeps a local copy, saves to
// /dev/null, matches all mail, or the delivery itself is marked unseen.
// Findings are de-duplicated on kind+destination across all rules.
func scoreFilterRules(rules []filterRule, mb filterMailbox, localDomains map[string]bool, known []string) []filterFinding {
	var out []filterFinding
	seen := map[string]bool{}
	// add appends f unless an equivalent finding exists. An exfil finding
	// supersedes a forwarder finding for the same destination: a later
	// forwarder is dropped and an earlier one is removed from out.
	add := func(f filterFinding) {
		key := f.kind + "|" + f.dest
		if seen[key] {
			return
		}
		if f.kind == "forwarder" && seen["exfil|"+f.dest] {
			return
		}
		if f.kind == "exfil" && seen["forwarder|"+f.dest] {
			delete(seen, "forwarder|"+f.dest)
			for i := range out {
				if out[i].kind == "forwarder" && out[i].dest == f.dest {
					out = append(out[:i], out[i+1:]...)
					break
				}
			}
		}
		seen[key] = true
		out = append(out, f)
	}
	// externalDelivery records one external deliver action of the current
	// rule; scoring is deferred until all of the rule's actions are known.
	type externalDelivery struct {
		dest   string
		unseen bool
	}
	for _, r := range rules {
		var external []externalDelivery
		hasLocalCopy := false
		hasDevNull := false
		for _, a := range r.actions {
			switch a.verb {
			case "deliver":
				switch {
				case destIsExternal(a.arg, mb, localDomains):
					external = append(external, externalDelivery{dest: a.arg, unseen: a.unseen})
				case destIsLocalSelf(a.arg, mb, localDomains):
					hasLocalCopy = true
				}
			case "save":
				// Saving anywhere other than /dev/null keeps a local copy.
				if strings.TrimSpace(a.arg) == "/dev/null" {
					hasDevNull = true
				} else {
					hasLocalCopy = true
				}
			case "pipe":
				if !isSafePipe(a.arg) {
					add(filterFinding{
						severity: alert.Critical,
						check:    "email_filter_pipe",
						kind:     "pipe",
						dest:     a.arg,
						reason:   fmt.Sprintf("filter pipes mail to a command: %s", a.arg),
					})
				}
			}
		}
		if len(external) > 0 {
			for _, delivery := range external {
				stealth := hasLocalCopy || hasDevNull || r.matchesAll || delivery.unseen
				// The known-forwarder allowlist only suppresses plain
				// forwards; stealthy deliveries are always reported.
				if !stealth && isKnownForwarder(mb.localPart, mb.domain, delivery.dest, known) {
					continue
				}
				if stealth {
					add(filterFinding{
						severity: alert.Critical,
						check:    "email_filter_exfil",
						kind:     "exfil",
						dest:     delivery.dest,
						reason:   stealthReason(hasLocalCopy, hasDevNull, r.matchesAll, delivery.unseen),
					})
				} else {
					add(filterFinding{
						severity:  alert.High,
						check:     "email_filter_forwarder",
						kind:      "forwarder",
						dest:      delivery.dest,
						reason:    "filter forwards mail to an external address",
						onlyIfNew: true,
					})
				}
			}
			// A rule with external deliveries is never also reported as a
			// blackhole; the /dev/null save is already part of the exfil
			// reason.
			continue
		}
		if hasDevNull && r.matchesAll {
			add(filterFinding{
				severity: alert.High,
				check:    "email_filter_blackhole",
				kind:     "blackhole",
				reason:   "filter discards all mail to /dev/null",
			})
		}
	}
	return out
}
// stealthReason picks the human-readable explanation for an external delivery.
// Discarding to /dev/null takes precedence, then keeping a local copy (or an
// unseen delivery), then a bare match-all rule. Each of the first two has a
// wording variant for match-all rules.
func stealthReason(localCopy, devNull, matchAll, unseen bool) string {
	if devNull {
		if !matchAll {
			return "filter forwards mail externally and discards the local copy to hide it"
		}
		return "filter forwards all mail externally and discards the local copy to hide it"
	}
	if localCopy || unseen {
		if matchAll {
			return "filter copies every message to an external address while keeping a local copy (stealth interception)"
		}
		return "filter sends matching mail to an external address while keeping a local copy (stealth interception)"
	}
	if matchAll {
		return "filter forwards all mail to an external address"
	}
	return "filter forwards mail to an external address"
}
// ---------------------------------------------------------------------------
// Check
// ---------------------------------------------------------------------------
// mailboxFromFilterPath derives the mailbox from a filter file path. Per-mailbox
// filters live at /home/<user>/etc/<domain>/<localpart>/filter; domain-wide
// filters at /etc/vfilters/<domain>.
//
// Resolution order:
//  1. A file directly under /etc/vfilters is a domain-wide filter
//     (localPart "*", domain = file name).
//  2. A path whose last four components are etc/<domain>/<localpart>/<file>
//     is a per-mailbox filter. Anchoring at the end keeps an unrelated "etc"
//     directory earlier in the path (in the home root, or a user named "etc")
//     from being mistaken for the mail configuration directory.
//  3. Otherwise the first "etc" component followed by two non-empty
//     components is used.
//  4. Failing all of that, the parent directory name is reported as the
//     domain with localPart "*".
func mailboxFromFilterPath(path string) filterMailbox {
	if filepath.Dir(path) == "/etc/vfilters" {
		return filterMailbox{localPart: "*", domain: filepath.Base(path)}
	}
	parts := strings.Split(path, "/")
	if n := len(parts); n >= 4 && parts[n-4] == "etc" && parts[n-3] != "" && parts[n-2] != "" {
		return filterMailbox{localPart: parts[n-2], domain: parts[n-3]}
	}
	// Fallback for paths that do not end in the canonical layout.
	for i := 0; i+2 < len(parts); i++ {
		if parts[i] == "etc" && parts[i+1] != "" && parts[i+2] != "" {
			return filterMailbox{localPart: parts[i+2], domain: parts[i+1]}
		}
	}
	return filterMailbox{localPart: "*", domain: filepath.Base(filepath.Dir(path))}
}
// CheckMailFilters scans per-mailbox and domain-wide Exim filters for BEC-style
// exfiltration rules. Throttled to PasswordCheckIntervalMin, like the forwarder
// audit it complements.
//
// A nil ctx is replaced by context.Background(). When the global store is
// unavailable the check emits a single (throttled) Warning and does nothing
// else. ForceAll bypasses the time throttle. Account-scoped scans (a non-empty
// AccountFromContext) skip /etc/vfilters and never refresh the baseline
// timestamp.
//
// NOTE(review): cfg is dereferenced without a nil check, and the throttle uses
// the raw PasswordCheckIntervalMin value (a non-positive value disables the
// throttle), whereas the store-unavailable path goes through
// mailFilterAuditIntervalMin, which defaults to 1440 -- confirm this
// asymmetry is intended.
func CheckMailFilters(ctx context.Context, cfg *config.Config, st *state.Store) []alert.Finding {
	if ctx == nil {
		ctx = context.Background()
	}
	db := store.Global()
	if db == nil {
		if !shouldReportMailFilterStoreUnavailable(st, cfg) {
			return nil
		}
		// Without the store the whole check is inoperative (hashes and the
		// throttle live there). Say so instead of looking like a clean host.
		return []alert.Finding{{
			Severity:  alert.Warning,
			Check:     "email_mail_filters",
			Message:   "Mail filter audit skipped: state store unavailable",
			Timestamp: time.Now(),
		}}
	}
	// Throttle: skip when the last completed refresh is younger than the
	// configured interval. An unparsable timestamp does not throttle.
	if !ForceAll {
		if last := db.GetMetaString("email:mailfilter_last_refresh"); last != "" {
			if ts, err := time.Parse(time.RFC3339, last); err == nil {
				interval := time.Duration(cfg.EmailProtection.PasswordCheckIntervalMin) * time.Minute
				if time.Since(ts) < interval {
					return nil
				}
			}
		}
	}
	if ctx.Err() != nil {
		return nil
	}
	localDomains := loadLocalDomains()
	var files []string
	// baselineComplete stays true only if every source was enumerated, every
	// file was read, and every hash was stored.
	baselineComplete := true
	if perMailbox, err := homeGlob(ctx, "etc", "*", "*", "filter"); err == nil {
		files = append(files, perMailbox...)
	} else {
		baselineComplete = false
	}
	// Domain-wide filters are only scanned on full (not account-scoped) runs.
	if AccountFromContext(ctx) == "" {
		if vfilters, err := osFS.Glob("/etc/vfilters/*"); err == nil {
			files = append(files, vfilters...)
		} else {
			baselineComplete = false
		}
	}
	maxFiles := effectiveAccountScanMaxFiles(cfg)
	baselineComplete = baselineComplete && scanCoversAllFiles(files, maxFiles)
	ranked := rankPathsByMtimeDesc(ctx, files, maxFiles)
	if ctx.Err() != nil {
		return nil
	}
	var collected []mailFilterPending
	for _, path := range ranked {
		// Cancellation mid-scan returns what was found so far.
		if ctx.Err() != nil {
			return findingsFromPending(collected)
		}
		data, err := osFS.ReadFile(path)
		if err != nil {
			baselineComplete = false
			continue
		}
		currentHash := sha256Hex(data)
		// isNew must be computed before the stored hash is overwritten.
		isNew := forwarderFileIsNew(db, "email:mailfilter_last_refresh", "mailfilter:"+path, currentHash)
		if err := db.SetForwarderHash("mailfilter:"+path, currentHash); err != nil {
			baselineComplete = false
		}
		mb := mailboxFromFilterPath(path)
		rules := parseEximFilter(string(data))
		for _, ff := range scoreFilterRules(rules, mb, localDomains, cfg.EmailProtection.KnownForwarders) {
			// Plain forwarders are only reported for files forwarderFileIsNew
			// flagged as new.
			if ff.onlyIfNew && !isNew {
				continue
			}
			collected = append(collected, mailFilterPending{
				finding: alert.Finding{
					Severity: ff.severity,
					Check:    ff.check,
					Message:  fmt.Sprintf("%s: %s", mb.String(), ff.reason),
					Details:  mailFilterDetails(mb, path, ff),
					FilePath: path,
					Domain:   mb.domain,
					Mailbox:  mailboxField(mb),
				},
				dest:    ff.dest,
				mailbox: mb.String(),
			})
		}
	}
	annotateCrossAccount(collected)
	// NOTE(review): unlike the in-loop cancellation path, this returns nil
	// even though hashes for the scanned files were already stored above, so
	// onlyIfNew findings collected in this run would not be re-detected as
	// new on the next scan -- confirm this is intended.
	if ctx.Err() != nil {
		return nil
	}
	// Only a full scan establishes the baseline / refreshes the throttle:
	// an account-scoped scan hashes one account's files, and marking it
	// complete would make the next full scan treat every other account's
	// existing filters as newly created.
	baselineExists := db.GetMetaString("email:mailfilter_last_refresh") != ""
	if AccountFromContext(ctx) == "" && (baselineComplete || baselineExists) {
		_ = db.SetMetaString("email:mailfilter_last_refresh", time.Now().Format(time.RFC3339))
	}
	return findingsFromPending(collected)
}
// shouldReportMailFilterStoreUnavailable decides whether the "state store
// unavailable" warning may be emitted now. Without a state store there is
// nothing to throttle against, so the answer is always yes; otherwise the
// decision is delegated to ShouldRunThrottled with the audit interval.
func shouldReportMailFilterStoreUnavailable(st *state.Store, cfg *config.Config) bool {
	if st != nil {
		return st.ShouldRunThrottled("email_mail_filters_store_unavailable", mailFilterAuditIntervalMin(cfg))
	}
	return true
}
// mailFilterAuditIntervalMin returns the audit interval in minutes: the
// configured PasswordCheckIntervalMin when it is positive, otherwise 1440
// (also used when cfg is nil).
func mailFilterAuditIntervalMin(cfg *config.Config) int {
	if cfg != nil {
		if minutes := cfg.EmailProtection.PasswordCheckIntervalMin; minutes > 0 {
			return minutes
		}
	}
	return 1440
}
// mailFilterDetails renders the multi-line Details text of a filter finding:
// mailbox, domain and file path, an optional destination line, and finally the
// reason with no trailing newline.
func mailFilterDetails(mb filterMailbox, path string, ff filterFinding) string {
	details := fmt.Sprintf("Mailbox: %s\nDomain: %s\nFile: %s\n", mb.String(), mb.domain, path)
	if ff.dest != "" {
		details += fmt.Sprintf("Destination: %s\n", ff.dest)
	}
	return details + ff.reason
}
// mailboxField returns the value for a finding's Mailbox field: the mailbox
// string, or "" when the local part is the "*" wildcard used for domain-wide
// filters.
func mailboxField(mb filterMailbox) string {
	if mb.localPart != "*" {
		return mb.String()
	}
	return ""
}
// sha256Hex returns the SHA-256 digest of data as a lowercase hex string.
func sha256Hex(data []byte) string {
	hasher := sha256.New()
	_, _ = hasher.Write(data) // hash.Hash.Write never returns an error
	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// mailFilterPending is an in-flight finding plus the fields needed for the
// cross-account correlation pass before findings are emitted.
type mailFilterPending struct {
	// finding is the alert that will eventually be returned to the caller.
	finding alert.Finding
	// dest is the destination of the scored rule; it is empty for findings
	// without one and is the grouping key for correlation.
	dest string
	// mailbox is the string form of the mailbox that owns the filter.
	mailbox string
}
// annotateCrossAccount escalates findings whose destination is shared by two
// or more distinct mailboxes -- a strong campaign signal. Every pending entry
// with such a destination is raised to Critical and gets a "Cross-account"
// paragraph appended to its Details, listing the other mailboxes in sorted
// order. Entries without a destination are left untouched.
func annotateCrossAccount(collected []mailFilterPending) {
	mailboxesByDest := make(map[string]map[string]bool)
	for idx := range collected {
		entry := &collected[idx]
		if entry.dest == "" {
			continue
		}
		users, ok := mailboxesByDest[entry.dest]
		if !ok {
			users = make(map[string]bool)
			mailboxesByDest[entry.dest] = users
		}
		users[entry.mailbox] = true
	}
	for idx := range collected {
		entry := &collected[idx]
		users := mailboxesByDest[entry.dest]
		if len(users) < 2 {
			continue
		}
		others := make([]string, 0, len(users))
		for name := range users {
			if name == entry.mailbox {
				continue
			}
			others = append(others, name)
		}
		sort.Strings(others)
		entry.finding.Severity = alert.Critical
		entry.finding.Details += fmt.Sprintf(
			"\nCross-account: the same destination %s is used by %d mailboxes (also %s). This indicates a coordinated campaign.",
			entry.dest, len(users), strings.Join(others, ", "))
	}
}
// findingsFromPending extracts the alert findings from the pending entries,
// preserving order. The result is always non-nil, even for empty input.
func findingsFromPending(collected []mailFilterPending) []alert.Finding {
	findings := make([]alert.Finding, len(collected))
	for idx := range collected {
		findings[idx] = collected[idx].finding
	}
	return findings
}
package checks
import (
"context"
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// perAccountMailThreshold is the number of messages attributed to one sender
// domain within the tailed log window at which CheckMailPerAccount raises a
// finding (the comparison is >=).
const perAccountMailThreshold = 100 // emails per recent log window
// mailLogTailLinesDefault is the built-in fallback for how many trailing
// lines of /var/log/exim_mainlog CheckMailPerAccount tails per cycle.
// Operator override: cfg.Thresholds.MailLogTailLines.
const mailLogTailLinesDefault = 500
// CheckMailPerAccount tails /var/log/exim_mainlog and counts, per sender
// domain, the lines carrying the " <= " marker. Any domain that reaches
// perAccountMailThreshold within the window yields a High finding.
//
// The window size is cfg.Thresholds.MailLogTailLines when positive, else
// mailLogTailLinesDefault. Senders without "@", with an empty domain, equal to
// "<>", or starting with "cPanel" are ignored. The order of the returned
// findings follows map iteration and is therefore not stable.
func CheckMailPerAccount(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	window := mailLogTailLinesDefault
	if cfg != nil {
		if override := cfg.Thresholds.MailLogTailLines; override > 0 {
			window = override
		}
	}

	// Exim log format: "... <= user@domain.com ..."; the sender is the first
	// whitespace-separated field after the marker.
	perDomain := make(map[string]int)
	for _, entry := range tailFile("/var/log/exim_mainlog", window) {
		marker := strings.Index(entry, " <= ")
		if marker < 0 {
			continue
		}
		tokens := strings.Fields(entry[marker+4:])
		if len(tokens) == 0 {
			continue
		}
		sender := tokens[0]
		at := strings.LastIndex(sender, "@")
		if at < 0 {
			continue
		}
		domain := sender[at+1:]
		// Skip system/bounce messages.
		if domain == "" || sender == "<>" || strings.HasPrefix(sender, "cPanel") {
			continue
		}
		perDomain[domain]++
	}

	var findings []alert.Finding
	for domain, sent := range perDomain {
		if sent < perAccountMailThreshold {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "mail_per_account",
			Message:  fmt.Sprintf("High email volume from %s: %d messages in recent log", domain, sent),
			Details:  "Possible spam outbreak or compromised email account",
			Domain:   domain,
		})
	}
	return findings
}
package checks
import (
"context"
"fmt"
"net"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckOutboundConnections inspects established TCP connections listed in
// /proc/net/tcp and reports three situations:
//
//   - c2_connection (Critical): the remote IP is in cfg.C2Blocklist;
//   - backdoor_port (Critical): the local port is in cfg.BackdoorPorts,
//     i.e. something on this host accepted a connection on a backdoor port;
//   - backdoor_port_outbound (High): the remote port is in cfg.BackdoorPorts,
//     the local port is not a well-known service port, and the remote IP is
//     not infrastructure per isInfraIP (reverse-shell style call-back).
//
// Loopback, unspecified and unparsable remote addresses are skipped. A nil cfg
// or an unreadable /proc/net/tcp yields no findings.
func CheckOutboundConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	// With no configuration there is no blocklist or port list to compare
	// against, so there is nothing to report.
	if cfg == nil {
		return nil
	}
	// Parse /proc/net/tcp for established connections
	// Format: sl local_address rem_address st ...
	// local_address = IP:port we are listening/connecting from
	// rem_address = IP:port of the remote end
	data, err := osFS.ReadFile("/proc/net/tcp")
	if err != nil {
		return nil
	}
	// Local ports of regular services. When our side of the connection is one
	// of these, the remote port is just the client's ephemeral port and must
	// not be compared against the backdoor list. Built once, outside the
	// per-line loop.
	knownServicePorts := map[int]bool{
		21: true, 25: true, 26: true, 53: true, 80: true, 110: true,
		143: true, 443: true, 465: true, 587: true, 993: true, 995: true,
		2082: true, 2083: true, 2086: true, 2087: true, 2095: true, 2096: true,
		3306: true, 4190: true,
	}
	var findings []alert.Finding
	for _, line := range strings.Split(string(data), "\n") {
		fields := strings.Fields(line)
		if len(fields) < 4 {
			continue
		}
		// Header row.
		if fields[0] == "sl" {
			continue
		}
		// State 01 = ESTABLISHED
		if fields[3] != "01" {
			continue
		}
		localAddr := fields[1]
		remoteAddr := fields[2]
		_, localPort := parseHexAddr(localAddr)
		remoteIP, remotePort := parseHexAddr(remoteAddr)
		if remoteIP == "" || remoteIP == "127.0.0.1" || remoteIP == "0.0.0.0" {
			continue
		}
		// Check remote IP against C2 blocklist
		for _, blocked := range cfg.C2Blocklist {
			if remoteIP == blocked {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "c2_connection",
					Message:  fmt.Sprintf("Connection to known C2 IP: %s:%d", remoteIP, remotePort),
					Details:  fmt.Sprintf("Local port: %d", localPort),
					SourceIP: remoteIP,
				})
			}
		}
		// Check if OUR LOCAL port is a backdoor port (we're listening on it)
		// This catches backdoor listeners, not clients connecting from high ports
		for _, bp := range cfg.BackdoorPorts {
			if localPort == bp {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "backdoor_port",
					Message:  fmt.Sprintf("Listening on known backdoor port %d, connected from %s:%d", localPort, remoteIP, remotePort),
					SourceIP: remoteIP,
				})
			}
		}
		// Also check if we're connecting OUT to a backdoor port on a remote host
		// (e.g. reverse shell calling back to attacker's listener)
		// Skip if our local port is a known service (the remote port is just
		// the client's ephemeral port, not a backdoor listener)
		if knownServicePorts[localPort] {
			continue
		}
		for _, bp := range cfg.BackdoorPorts {
			if remotePort == bp {
				if isInfraIP(remoteIP, cfg.InfraIPs) {
					continue
				}
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "backdoor_port_outbound",
					Message:  fmt.Sprintf("Outbound connection to backdoor port: %s:%d", remoteIP, remotePort),
					Details:  fmt.Sprintf("Local port: %d", localPort),
					SourceIP: remoteIP,
				})
			}
		}
	}
	return findings
}
// isInfraIP reports whether ip belongs to the operator's infrastructure. Each
// infraNets entry may be a CIDR (e.g. "10.0.0.0/8") or a single address (e.g.
// "1.2.3.4"); blank and unparsable entries are ignored. Addresses inside the
// cached Cloudflare ranges also count as infrastructure. An unparsable ip is
// never infrastructure.
func isInfraIP(ip string, infraNets []string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	for _, raw := range infraNets {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		// An entry that parses as a CIDR is only ever treated as a CIDR.
		if _, cidr, err := net.ParseCIDR(candidate); err == nil {
			if cidr.Contains(addr) {
				return true
			}
			continue
		}
		if single := net.ParseIP(candidate); single != nil && single.Equal(addr) {
			return true
		}
	}
	// Cloudflare edge addresses front many legitimate users, so they are
	// treated like infrastructure as well.
	return isCloudflareIP(addr)
}
// Cached Cloudflare IP ranges consulted by isCloudflareIP.
var (
	// cfNets holds the parsed ranges; it is replaced wholesale by
	// SetCloudflareNets.
	cfNets []*net.IPNet
	// cfNetsMu guards cfNets (write lock in SetCloudflareNets, read lock in
	// isCloudflareIP).
	cfNetsMu sync.RWMutex
)
// SetCloudflareNets replaces the cached Cloudflare IP ranges with the CIDRs
// that parse successfully; invalid entries are dropped silently. Passing no
// valid CIDR clears the cache.
// Called by the daemon after fetching CF IPs.
func SetCloudflareNets(cidrs []string) {
	var parsed []*net.IPNet
	for idx := range cidrs {
		if _, ipNet, err := net.ParseCIDR(cidrs[idx]); err == nil {
			parsed = append(parsed, ipNet)
		}
	}
	cfNetsMu.Lock()
	defer cfNetsMu.Unlock()
	cfNets = parsed
}
// isCloudflareIP reports whether ip falls inside any of the cached Cloudflare
// ranges. The cache is read under the read lock.
func isCloudflareIP(ip net.IP) bool {
	cfNetsMu.RLock()
	defer cfNetsMu.RUnlock()
	for idx := range cfNets {
		if cfNets[idx].Contains(ip) {
			return true
		}
	}
	return false
}
// parseHexAddr decodes an "IP:port" column of /proc/net/tcp. The IP must be
// exactly 8 hex characters and is decoded with its bytes in reverse order. It
// returns ("", 0) when the value is not two colon-separated parts or the IP
// part has the wrong length, and (ip, 0) when only the port is malformed.
func parseHexAddr(hexAddr string) (string, int) {
	fields := strings.Split(hexAddr, ":")
	if len(fields) != 2 || len(fields[0]) != 8 {
		return "", 0
	}
	ipHex, portHex := fields[0], fields[1]
	// The last hex pair is the first octet of the dotted-quad form.
	ip := net.IPv4(
		hexToByte(ipHex[6:8]),
		hexToByte(ipHex[4:6]),
		hexToByte(ipHex[2:4]),
		hexToByte(ipHex[0:2]),
	).String()
	if port, ok := parseProcNetHexPort(portHex); ok {
		return ip, port
	}
	return ip, 0
}
// parseProcNetHexPort parses a port written as exactly four hex digits. The
// boolean is false when the length is wrong or a character is not a hex digit.
func parseProcNetHexPort(hexPort string) (int, bool) {
	if len(hexPort) != 4 {
		return 0, false
	}
	value := 0
	for _, digit := range []byte(hexPort) {
		if !isHexDigit(digit) {
			return 0, false
		}
		value = value*16 + hexVal(digit)
	}
	return value, true
}
// hexToByte converts a two-character hex string to its byte value. Any other
// length yields 0, and characters that are not hex digits contribute 0 (see
// hexVal).
func hexToByte(s string) byte {
	if len(s) == 2 {
		hi, lo := hexVal(s[0]), hexVal(s[1])
		// #nosec G115 -- hexVal returns 0..15; (h<<4)|h fits in a byte (0..255).
		return byte(hi<<4 | lo)
	}
	return 0
}
// hexVal returns the numeric value (0..15) of a hex digit in either case.
// Any other byte maps to 0.
func hexVal(c byte) int {
	if '0' <= c && c <= '9' {
		return int(c - '0')
	}
	if 'a' <= c && c <= 'f' {
		return int(c-'a') + 10
	}
	if 'A' <= c && c <= 'F' {
		return int(c-'A') + 10
	}
	return 0
}
package checks
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"syscall"
)
// fixPerfAllowedRoots scopes performance remediations to per-account web
// content. Same shape as the other fix*AllowedRoots so tests can swap in
// a t.TempDir(). It is the root list FixErrorLogBloat and FixDisplayErrorsOn
// pass to their *InRoots variants.
var fixPerfAllowedRoots = []string{"/home"}
// FixErrorLogBloat empties an account-owned error_log file in place, limited
// to the default remediation roots (fixPerfAllowedRoots).
//
// Truncating rather than deleting keeps the inode and ownership, so a PHP
// process that already holds the descriptor keeps appending to the same file
// without an open/reopen race. It is also the least risky action: nothing on
// the host needs the historical lines to keep serving traffic.
func FixErrorLogBloat(path string) RemediationResult {
	roots := fixPerfAllowedRoots
	return FixErrorLogBloatInRoots(path, roots)
}
// FixErrorLogBloatInRoots is FixErrorLogBloat with caller-supplied roots.
// The Web UI uses this to include configured account_roots while tests can
// keep writes under t.TempDir().
//
// It refuses an empty path, anything resolveExistingFixPath rejects, a
// directory, and any file whose resolved base name is not "error_log". On
// success the result reports the size the file had before truncation.
func FixErrorLogBloatInRoots(path string, allowedRoots []string) RemediationResult {
	fail := func(msg string) RemediationResult {
		return RemediationResult{Error: msg}
	}
	if path == "" {
		return fail("could not extract file path from finding")
	}
	resolved, info, err := resolveExistingFixPath(path, allowedRoots)
	switch {
	case err != nil:
		return fail(err.Error())
	case info.IsDir():
		return fail("refusing to truncate a directory")
	case filepath.Base(resolved) != "error_log":
		return fail(fmt.Sprintf("refusing to truncate non error_log file: %s", resolved))
	}
	previous := info.Size()
	if truncErr := truncateFilePreservingIdentity(resolved, info); truncErr != nil {
		return fail(fmt.Sprintf("truncate failed: %v", truncErr))
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("truncate %s", resolved),
		Description: fmt.Sprintf("Emptied error_log (was %s)", humanBytes(previous)),
	}
}
// FixDisplayErrorsOn rewrites an INI / .htaccess / .user.ini file so that
// display_errors ends up Off, limited to the default remediation roots
// (fixPerfAllowedRoots). Existing display_errors lines are kept but commented
// out for operator review, and an override line is appended at the end of the
// file so that last-write-wins evaluation lands on Off regardless of earlier
// statements.
//
// wp-config.php and other PHP source files need code-level edits this routine
// does not attempt; the caller (web UI) should not advertise the fix for them.
func FixDisplayErrorsOn(path string) RemediationResult {
	roots := fixPerfAllowedRoots
	return FixDisplayErrorsOnInRoots(path, roots)
}
// FixDisplayErrorsOnInRoots is FixDisplayErrorsOn with caller-supplied
// roots. The Web UI uses this to honor account_roots outside /home.
//
// Steps: validate the path against allowedRoots, reject directories and
// unsupported file names, read the file (re-checking its identity), comment
// out every display_errors line, append the format-specific Off override, and
// replace the file while preserving owner and mode. If no display_errors line
// is found the file is left untouched and an error result is returned.
//
// NOTE(review): the supported-name check accepts any base name ending in
// ".ini", which is broader than the ".user.ini, php.ini, and .htaccess" list
// in the error message -- confirm which one is intended.
func FixDisplayErrorsOnInRoots(path string, allowedRoots []string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	resolved, info, err := resolveExistingFixPath(path, allowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	if info.IsDir() {
		return RemediationResult{Error: "refusing to edit a directory"}
	}
	base := filepath.Base(resolved)
	var (
		isHtaccess bool // selects the php_flag override syntax
		supported  bool
	)
	switch {
	case base == ".user.ini" || base == "php.ini" || strings.HasSuffix(base, ".ini"):
		supported = true
	case base == ".htaccess":
		supported = true
		isHtaccess = true
	}
	if !supported {
		return RemediationResult{Error: fmt.Sprintf("automated display_errors fix only supports .user.ini, php.ini, and .htaccess (got %s)", base)}
	}
	// #nosec G304 -- path was validated by resolveExistingFixPath against
	// the supplied remediation roots; symlinks already rejected.
	data, err := readFilePreservingIdentity(resolved, info)
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("read failed: %v", err)}
	}
	rewritten, changedLines := commentDisplayErrorsLines(data)
	// Nothing to disable: do not append an override to a file that never set
	// the directive.
	if changedLines == 0 {
		return RemediationResult{Error: "no display_errors directive found in file"}
	}
	if isHtaccess {
		rewritten = appendHtaccessOverride(rewritten)
	} else {
		rewritten = appendIniOverride(rewritten)
	}
	// Preserve ownership + mode. Write atomically via a sibling temp file +
	// rename so a partial write does not leave the operator with a broken
	// config.
	if err := writeFilePreservingOwner(resolved, rewritten, info); err != nil {
		return RemediationResult{Error: err.Error()}
	}
	return RemediationResult{
		Success: true,
		Action:  fmt.Sprintf("disable display_errors in %s", resolved),
		Description: fmt.Sprintf(
			"Commented %d display_errors line(s) and appended an Off override at end of file",
			changedLines,
		),
	}
}
// commentDisplayErrorsLines rewrites data line by line, prefixing every
// non-blank, non-comment line that mentions display_errors (in any letter
// case) with a "# csm: disabled by remediation -- " marker. Lines starting
// with "#" or ";" after trimming count as comments and are copied unchanged,
// as is everything else. Original line endings are preserved, including a
// missing final newline. It returns the rewritten bytes and the number of
// lines that were commented out.
func commentDisplayErrorsLines(data []byte) ([]byte, int) {
	const prefix = "# csm: disabled by remediation -- "
	var rewritten bytes.Buffer
	disabled := 0
	for _, chunk := range bytes.SplitAfter(data, []byte("\n")) {
		if len(chunk) == 0 {
			continue
		}
		body := strings.TrimSuffix(string(chunk), "\n")
		stripped := strings.TrimSpace(body)
		isDirective := stripped != "" &&
			!strings.HasPrefix(stripped, "#") &&
			!strings.HasPrefix(stripped, ";")
		if !isDirective || !strings.Contains(strings.ToLower(stripped), "display_errors") {
			rewritten.Write(chunk)
			continue
		}
		rewritten.WriteString(prefix)
		rewritten.WriteString(body)
		// Re-emit the newline only if the original line had one.
		if len(body) != len(chunk) {
			rewritten.WriteByte('\n')
		}
		disabled++
	}
	return rewritten.Bytes(), disabled
}
// appendIniOverride appends the INI-syntax override ("display_errors = Off"),
// preceded by a ";" marker comment, to data.
func appendIniOverride(data []byte) []byte {
	return appendOverrideLine(data, "display_errors = Off", "; csm: appended by remediation")
}
// appendHtaccessOverride appends the .htaccess-syntax override
// ("php_flag display_errors Off"), preceded by a "#" marker comment, to data.
func appendHtaccessOverride(data []byte) []byte {
	return appendOverrideLine(data, "php_flag display_errors Off", "# csm: appended by remediation")
}
// appendOverrideLine returns a copy of data followed by the marker line and
// the directive line, each newline-terminated. If data is non-empty and does
// not already end in a newline, one is inserted first so the marker starts on
// its own line.
func appendOverrideLine(data []byte, directive, marker string) []byte {
	result := make([]byte, 0, len(data)+len(marker)+len(directive)+3)
	result = append(result, data...)
	if n := len(result); n > 0 && result[n-1] != '\n' {
		result = append(result, '\n')
	}
	result = append(result, marker...)
	result = append(result, '\n')
	result = append(result, directive...)
	return append(result, '\n')
}
// These helpers re-check the target inode after the initial path validation
// so a file swap between validation and mutation fails closed.

// truncateFilePreservingIdentity opens path write-only without following a
// final symlink, verifies the opened file is the same file as expected, and
// only then truncates it to zero length.
func truncateFilePreservingIdentity(path string, expected os.FileInfo) error {
	// #nosec G304 -- path was validated by resolveExistingFixPath against
	// the per-account remediation roots; the open uses O_NOFOLLOW and the
	// subsequent sameFileIdentity check fails closed on inode swap.
	file, openErr := os.OpenFile(path, os.O_WRONLY|syscall.O_NOFOLLOW, 0)
	if openErr != nil {
		return openErr
	}
	defer func() { _ = file.Close() }()
	current, statErr := file.Stat()
	switch {
	case statErr != nil:
		return statErr
	case !sameFileIdentity(current, expected):
		return fmt.Errorf("file changed during remediation")
	}
	return file.Truncate(0)
}
// readFilePreservingIdentity opens path read-only without following a final
// symlink, verifies the opened file is the same file as expected, and returns
// its full contents.
func readFilePreservingIdentity(path string, expected os.FileInfo) ([]byte, error) {
	// #nosec G304 -- path was validated by resolveExistingFixPath against
	// the per-account remediation roots; the open uses O_NOFOLLOW and the
	// subsequent sameFileIdentity check fails closed on inode swap.
	file, openErr := os.OpenFile(path, os.O_RDONLY|syscall.O_NOFOLLOW, 0)
	if openErr != nil {
		return nil, openErr
	}
	defer func() { _ = file.Close() }()
	current, statErr := file.Stat()
	switch {
	case statErr != nil:
		return nil, statErr
	case !sameFileIdentity(current, expected):
		return nil, fmt.Errorf("file changed during remediation")
	}
	return io.ReadAll(file)
}
// writeFilePreservingOwner replaces path with data via a sibling temp file and
// rename, carrying over the permission bits of original and the owner/group
// of the file currently at path.
//
// Just before the rename the target is re-examined with Lstat: if it has
// become a symlink or is no longer the same file as original, the write is
// abandoned. The temp file is removed on every failure path. Returned errors
// wrap the underlying cause (%w) so callers can inspect it with
// errors.Is / errors.As.
func writeFilePreservingOwner(path string, data []byte, original os.FileInfo) error {
	dir := filepath.Dir(path)
	tmp, createErr := os.CreateTemp(dir, ".csm-perf-fix-*")
	if createErr != nil {
		return fmt.Errorf("create temp: %w", createErr)
	}
	tmpPath := tmp.Name()
	cleanup := func() { _ = os.Remove(tmpPath) }
	if _, werr := tmp.Write(data); werr != nil {
		_ = tmp.Close()
		cleanup()
		return fmt.Errorf("write temp: %w", werr)
	}
	if cerr := tmp.Close(); cerr != nil {
		cleanup()
		return fmt.Errorf("close temp: %w", cerr)
	}
	if merr := os.Chmod(tmpPath, original.Mode().Perm()); merr != nil {
		cleanup()
		return fmt.Errorf("chmod temp: %w", merr)
	}
	// Re-validate the target immediately before swapping it out.
	info, statErr := os.Lstat(path)
	if statErr != nil {
		cleanup()
		return fmt.Errorf("stat original: %w", statErr)
	}
	if info.Mode()&os.ModeSymlink != 0 {
		cleanup()
		return fmt.Errorf("original became a symlink during remediation")
	}
	if !sameFileIdentity(info, original) {
		cleanup()
		return fmt.Errorf("file changed during remediation")
	}
	// Ownership is copied only when the platform exposes it via Stat_t.
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		if chownErr := os.Chown(tmpPath, int(stat.Uid), int(stat.Gid)); chownErr != nil {
			cleanup()
			return fmt.Errorf("chown temp: %w", chownErr)
		}
	}
	if renameErr := os.Rename(tmpPath, path); renameErr != nil {
		cleanup()
		return fmt.Errorf("rename: %w", renameErr)
	}
	return nil
}
// sameFileIdentity reports whether a and b describe the same file. nil on
// either side is never the same. os.SameFile is tried first; otherwise the
// device and inode numbers are compared when both sides expose a
// *syscall.Stat_t.
func sameFileIdentity(a, b os.FileInfo) bool {
	switch {
	case a == nil || b == nil:
		return false
	case os.SameFile(a, b):
		return true
	}
	statA, okA := a.Sys().(*syscall.Stat_t)
	if !okA {
		return false
	}
	statB, okB := b.Sys().(*syscall.Stat_t)
	return okB && statA.Dev == statB.Dev && statA.Ino == statB.Ino
}
package checks
import (
"bufio"
"context"
"encoding/json"
"fmt"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/mysqlclient"
"github.com/pidginhost/csm/internal/redisinfo"
"github.com/pidginhost/csm/internal/state"
)
// perfEnabled reports whether performance checks are enabled. An unset (nil)
// Performance.Enabled counts as enabled; only an explicit false disables
// them.
func perfEnabled(cfg *config.Config) bool {
	if flag := cfg.Performance.Enabled; flag != nil {
		return *flag
	}
	return true
}
// cpuCoresOnce guards the cached CPU core count.
var (
	// cpuCoresOnce ensures /proc/cpuinfo is read at most once per process.
	cpuCoresOnce sync.Once
	// cpuCoresCache is the value getCPUCores returns after the first call.
	cpuCoresCache int
)
// getCPUCores returns the number of lines in /proc/cpuinfo that start with
// "processor\t". The value is computed on the first call and cached. If the
// file cannot be opened or contains no such line, the result is 1.
func getCPUCores() int {
	cpuCoresOnce.Do(func() {
		cores := 0
		if f, err := osFS.Open("/proc/cpuinfo"); err == nil {
			defer func() { _ = f.Close() }()
			for sc := bufio.NewScanner(f); sc.Scan(); {
				if strings.HasPrefix(sc.Text(), "processor\t") {
					cores++
				}
			}
		}
		if cores < 1 {
			cores = 1
		}
		cpuCoresCache = cores
	})
	return cpuCoresCache
}
// parseLoadAvg reads /proc/loadavg and returns its first three fields (the
// 1-, 5- and 15-minute load averages) as floats. It fails when the file
// cannot be read, has fewer than three fields, or a field is not a number.
func parseLoadAvg() ([3]float64, error) {
	var loads [3]float64
	raw, err := osFS.ReadFile("/proc/loadavg")
	if err != nil {
		return loads, fmt.Errorf("reading /proc/loadavg: %w", err)
	}
	tokens := strings.Fields(string(raw))
	if len(tokens) < 3 {
		return loads, fmt.Errorf("unexpected /proc/loadavg format: %q", string(raw))
	}
	for idx := range loads {
		parsed, perr := strconv.ParseFloat(tokens[idx], 64)
		if perr != nil {
			return loads, fmt.Errorf("parsing load avg field %d: %w", idx, perr)
		}
		loads[idx] = parsed
	}
	return loads, nil
}
// parseMemInfo reads /proc/meminfo and returns the numeric values of the
// MemTotal, MemAvailable, SwapTotal and SwapFree lines (the file reports them
// in kilobytes). Any value that is missing or unparsable stays 0; if the file
// cannot be opened all four are 0.
func parseMemInfo() (total, available, swapTotal, swapFree uint64) {
	f, err := osFS.Open("/proc/meminfo")
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()

	// Map each key of interest to the result it fills in.
	targets := map[string]*uint64{
		"MemTotal:":     &total,
		"MemAvailable:": &available,
		"SwapTotal:":    &swapTotal,
		"SwapFree:":     &swapFree,
	}
	for sc := bufio.NewScanner(f); sc.Scan(); {
		cols := strings.Fields(sc.Text())
		if len(cols) < 2 {
			continue
		}
		dst, wanted := targets[cols[0]]
		if !wanted {
			continue
		}
		if n, perr := strconv.ParseUint(cols[1], 10, 64); perr == nil {
			*dst = n
		}
	}
	return total, available, swapTotal, swapFree
}
// Integer limits used by the clamping helpers in this file.
const (
	maxInt64Value  int64  = 1<<63 - 1
	maxUint64Value uint64 = 1<<64 - 1
)

// uint64ToInt64Clamped converts v to int64 for byte display, saturating at
// the largest int64 instead of wrapping negative.
func uint64ToInt64Clamped(v uint64) int64 {
	if v <= uint64(maxInt64Value) {
		return int64(v)
	}
	return maxInt64Value
}
// redisLargeDatasetThresholdBytes converts a threshold expressed in GiB to
// bytes. Non-positive inputs yield 0, and a value whose byte count would
// overflow uint64 saturates at the maximum uint64.
func redisLargeDatasetThresholdBytes(gb int) uint64 {
	const bytesPerGB uint64 = 1 << 30
	switch {
	case gb <= 0:
		return 0
	case uint64(gb) > maxUint64Value/bytesPerGB:
		return maxUint64Value
	default:
		return uint64(gb) * bytesPerGB
	}
}
// humanBytes renders a byte count in a compact form using binary units:
// one decimal place for gibibytes ("1.5G"), truncated whole numbers for
// mebibytes ("3M") and kibibytes ("12K"). Anything below 1 KiB, including
// negative values, is rendered as "0B".
func humanBytes(b int64) string {
	const (
		kib int64 = 1 << 10
		mib int64 = 1 << 20
		gib int64 = 1 << 30
	)
	if b >= gib {
		return fmt.Sprintf("%.1fG", float64(b)/float64(gib))
	}
	if b >= mib {
		return fmt.Sprintf("%dM", b/mib)
	}
	if b >= kib {
		return fmt.Sprintf("%dK", b/kib)
	}
	return "0B"
}
// kibToDisplayBytes converts a KiB count to bytes for display, saturating at
// the largest int64 when the multiplication would overflow.
func kibToDisplayBytes(kib uint64) int64 {
	const kibLimit = uint64(maxInt64Value / 1024)
	if kib <= kibLimit {
		return int64(kib) * 1024
	}
	return maxInt64Value
}
// CheckLoadAverage evaluates the system load averages against thresholds
// scaled by the CPU core count:
//
//   - 1-minute load above cores*LoadCriticalMultiplier -> Critical
//   - 1-minute load above cores*LoadHighMultiplier     -> High
//   - otherwise, 5- and 15-minute loads both above 70% of the High
//     threshold -> Warning (sustained pressure that a 1-minute spike
//     check would miss)
//
// At most one finding is returned. The check yields nil when performance
// checks are disabled or /proc/loadavg cannot be parsed.
func CheckLoadAverage(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	loads, err := parseLoadAvg()
	if err != nil {
		return nil
	}
	cores := getCPUCores()
	critical := float64(cores) * cfg.Performance.LoadCriticalMultiplier
	high := float64(cores) * cfg.Performance.LoadHighMultiplier
	sustained := high * 0.7

	// build wraps the single finding this check can emit; only severity,
	// message and the threshold label/value differ between the cases.
	build := func(sev alert.Severity, msg, label string, threshold float64) []alert.Finding {
		return []alert.Finding{{
			Severity: sev,
			Check:    "perf_load",
			Message:  msg,
			Details: fmt.Sprintf("Load: %.1f/%.1f/%.1f, Cores: %d, %s: %.1f",
				loads[0], loads[1], loads[2], cores, label, threshold),
			Timestamp: time.Now(),
		}}
	}

	if loads[0] > critical {
		return build(alert.Critical, "High load average exceeds critical threshold", "Threshold", critical)
	}
	if loads[0] > high {
		return build(alert.High, "High load average exceeds high threshold", "Threshold", high)
	}
	// The 1-minute figure is calm; look for load that has stayed elevated
	// across both longer windows.
	if loads[1] > sustained && loads[2] > sustained {
		return build(alert.Warning, "Sustained load (5m + 15m) above 70% of high threshold", "Sustained threshold", sustained)
	}
	return nil
}
// CheckPHPProcessLoad counts processes whose command line contains "lsphp"
// by reading /proc/<pid>/cmdline, attributing each one to a user via the Uid
// row of /proc/<pid>/status. It emits a Critical finding when the total
// exceeds cores*PHPProcessCriticalTotalMult and a High finding for every user
// whose process count exceeds PHPProcessWarnPerUser (with up to three sample
// command lines). Returns nil when disabled or when no lsphp process exists.
func CheckPHPProcessLoad(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	cores := getCPUCores()

	// ownerUID returns the first field of the Uid row in the process status
	// file, or "unknown" when the file or the field is unavailable.
	ownerUID := func(pid string) string {
		raw, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		for _, row := range strings.Split(string(raw), "\n") {
			if !strings.HasPrefix(row, "Uid:\t") {
				continue
			}
			if ids := strings.Fields(strings.TrimPrefix(row, "Uid:\t")); len(ids) > 0 {
				return ids[0]
			}
			break // only the first Uid row is considered
		}
		return "unknown"
	}

	matches, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	byUser := make(map[string][]string) // user -> command lines of its lsphp processes
	lsphpTotal := 0
	for _, cmdPath := range matches {
		raw, err := osFS.ReadFile(cmdPath)
		if err != nil {
			continue
		}
		// Arguments in cmdline are NUL-separated; flatten them for display.
		cmdline := strings.TrimSpace(strings.ReplaceAll(string(raw), "\x00", " "))
		if !strings.Contains(cmdline, "lsphp") {
			continue
		}
		pid := filepath.Base(filepath.Dir(cmdPath))
		owner := uidStringToUser(ownerUID(pid))
		byUser[owner] = append(byUser[owner], cmdline)
		lsphpTotal++
	}
	if lsphpTotal == 0 {
		return nil
	}

	var findings []alert.Finding

	// Host-wide ceiling.
	totalLimit := cores * cfg.Performance.PHPProcessCriticalTotalMult
	if lsphpTotal > totalLimit {
		findings = append(findings, alert.Finding{
			Severity:  alert.Critical,
			Check:     "perf_php_processes",
			Message:   "Total lsphp process count exceeds critical threshold",
			Details:   fmt.Sprintf("Count: %d, Threshold: %d (cores: %d × %d)", lsphpTotal, totalLimit, cores, cfg.Performance.PHPProcessCriticalTotalMult),
			Timestamp: time.Now(),
		})
	}

	// Per-user ceiling.
	perUserLimit := cfg.Performance.PHPProcessWarnPerUser
	for owner, cmds := range byUser {
		if len(cmds) <= perUserLimit {
			continue
		}
		shown := cmds
		if len(shown) > 3 {
			shown = shown[:3]
		}
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "perf_php_processes",
			Message:   fmt.Sprintf("Excessive lsphp processes for user %s", owner),
			Details:   fmt.Sprintf("Count: %d, Threshold: %d, Sample cmdlines: %s", len(cmds), perUserLimit, strings.Join(shown, " | ")),
			Timestamp: time.Now(),
		})
	}
	return findings
}
// CheckSwapAndOOM reports memory pressure from two sources:
//
//   - dmesg error-level output: the first line mentioning "out of memory" or
//     "oom_reaper" whose timestamp falls within the last hour produces one
//     Critical finding.
//   - /proc/meminfo: swap usage above 50% produces a High finding.
//
// Returns nil when performance checks are disabled.
func CheckSwapAndOOM(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	var findings []alert.Finding

	// Ask dmesg for ISO timestamps first; if that invocation fails or yields
	// nothing, retry with -T (human-readable local time).
	kernelLog, isoErr := runCmd("dmesg", "--time-format", "iso", "--level=err")
	iso := isoErr == nil && kernelLog != nil
	if !iso {
		kernelLog, _ = runCmd("dmesg", "--level=err", "-T")
	}
	if kernelLog != nil {
		oldest := time.Now().Add(-time.Hour)
		for _, entry := range strings.Split(string(kernelLog), "\n") {
			folded := strings.ToLower(entry)
			mentionsOOM := strings.Contains(folded, "out of memory") || strings.Contains(folded, "oom_reaper")
			if !mentionsOOM {
				continue
			}
			// Lines that cannot be dated are dropped rather than reported:
			// without a timestamp there is no way to tell a fresh OOM from
			// one that is long past.
			stamp, dated := parseDmesgOOMTime(entry, iso)
			if dated && !stamp.Before(oldest) {
				findings = append(findings, alert.Finding{
					Severity:  alert.Critical,
					Check:     "perf_memory",
					Message:   "OOM killer invoked in the last hour",
					Details:   strings.TrimSpace(entry),
					Timestamp: time.Now(),
				})
				break // a single OOM finding per scan is sufficient
			}
		}
	}

	// Swap pressure.
	if _, _, swapTotal, swapFree := parseMemInfo(); swapTotal > 0 {
		used := swapTotal - swapFree
		pct := float64(used) / float64(swapTotal) * 100
		if pct > 50 {
			findings = append(findings, alert.Finding{
				Severity:  alert.High,
				Check:     "perf_memory",
				Message:   "High swap usage",
				Details:   fmt.Sprintf("Swap used: %s / %s (%.0f%%)", humanBytes(kibToDisplayBytes(used)), humanBytes(kibToDisplayBytes(swapTotal)), pct),
				Timestamp: time.Now(),
			})
		}
	}
	return findings
}
// parseDmesgOOMTime returns the timestamp carried by a dmesg line.
//
// With useISO set, the text before the first space is taken as the timestamp
// (e.g. "2006-01-02T15:04:05,000000+0300"); its comma decimal separator is
// converted to a dot and the zone offset may be written with or without a
// colon. Otherwise the line must begin with a bracketed ctime such as
// "[Mon Jan _2 15:04:05 2006]", which is interpreted in the local time zone.
//
// The boolean result is false when no timestamp can be parsed, letting the
// caller discard lines it cannot date.
func parseDmesgOOMTime(line string, useISO bool) (time.Time, bool) {
	if !useISO {
		// The bracket has to be the very first byte of the line.
		if !strings.HasPrefix(line, "[") {
			return time.Time{}, false
		}
		end := strings.IndexByte(line, ']')
		if end <= 0 {
			return time.Time{}, false
		}
		// dmesg -T prints wall-clock time without a zone, hence time.Local.
		stamp, err := time.ParseInLocation("Mon Jan _2 15:04:05 2006", strings.TrimSpace(line[1:end]), time.Local)
		if err != nil {
			return time.Time{}, false
		}
		return stamp, true
	}

	field := line
	if sp := strings.IndexByte(line, ' '); sp >= 0 {
		field = line[:sp]
	}
	field = strings.Replace(field, ",", ".", 1)
	layouts := [...]string{
		"2006-01-02T15:04:05.000000-0700",
		"2006-01-02T15:04:05.000000-07:00",
	}
	for i := range layouts {
		if stamp, err := time.Parse(layouts[i], field); err == nil {
			return stamp, true
		}
	}
	return time.Time{}, false
}
// CheckPHPHandler reports PHP versions that are served through the CGI
// handler on a host where the LiteSpeed binary
// (/usr/local/lsws/bin/litespeed) is present. Handler data comes from
// `whmapi1 php_get_handlers --output=json`; if that command fails or prints
// nothing, /etc/cpanel/ea4/ea4.conf is parsed instead. All affected versions
// are combined into a single Critical finding. Returns nil when disabled,
// when LiteSpeed is absent, or when no CGI handler is found.
func CheckPHPHandler(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	// The check is only meaningful on LiteSpeed hosts.
	if _, statErr := osFS.Stat("/usr/local/lsws/bin/litespeed"); statErr != nil {
		return nil
	}

	var cgiVersions []string
	apiOut, apiErr := runCmd("whmapi1", "php_get_handlers", "--output=json")
	switch {
	case apiErr == nil && len(apiOut) > 0:
		// Primary source: the WHM API response.
		var payload struct {
			Data struct {
				Handlers []struct {
					Version string `json:"version"`
					Handler string `json:"handler"`
					Type    string `json:"type"`
				} `json:"handlers"`
			} `json:"data"`
		}
		if json.Unmarshal(apiOut, &payload) == nil {
			for _, entry := range payload.Data.Handlers {
				kind := strings.ToLower(entry.Handler + " " + entry.Type)
				// "cgi" also appears inside names such as fcgi-based FPM or
				// LSAPI handlers, so those are excluded explicitly.
				plainCGI := strings.Contains(kind, "cgi") &&
					!strings.Contains(kind, "lsapi") &&
					!strings.Contains(kind, "fpm")
				if plainCGI {
					cgiVersions = append(cgiVersions, entry.Version)
				}
			}
		}
	default:
		// Fallback source: rows such as "ea-php74.handler = cgi".
		conf, readErr := osFS.ReadFile("/etc/cpanel/ea4/ea4.conf")
		if readErr != nil {
			break
		}
		for _, row := range strings.Split(string(conf), "\n") {
			row = strings.TrimSpace(row)
			if !strings.Contains(row, ".handler") {
				continue
			}
			eq := strings.Index(row, "=")
			if eq < 0 {
				continue
			}
			if strings.TrimSpace(row[eq+1:]) == "cgi" {
				cgiVersions = append(cgiVersions, strings.TrimSpace(row[:eq]))
			}
		}
	}

	if len(cgiVersions) == 0 {
		return nil
	}
	return []alert.Finding{{
		Severity:  alert.Critical,
		Check:     "perf_php_handler",
		Message:   "PHP handler set to CGI instead of LSAPI on LiteSpeed",
		Details:   fmt.Sprintf("Affected PHP versions: %s", strings.Join(cgiVersions, ", ")),
		Timestamp: time.Now(),
	}}
}
// CheckMySQLConfig runs three root queries against MySQL and turns the
// results into findings, all under the check name "perf_mysql_config":
//
//   - SHOW GLOBAL VARIABLES: oversized join_buffer_size (Critical), long
//     wait_timeout / interactive_timeout (High), unlimited
//     max_user_connections and a disabled slow query log (Warning).
//   - SHOW GLOBAL STATUS: more than 25% of temporary tables created on disk
//     (Warning) and an InnoDB buffer pool hit ratio below 95% (High).
//   - SHOW PROCESSLIST: users holding more connections than
//     MySQLMaxConnectionsPerUser (High).
//
// Each issue is reported as its own finding with a fixed message so that
// deduplication keyed on the message keeps working. A query that fails or
// returns nothing simply skips its section.
func CheckMySQLConfig(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	var findings []alert.Finding
	report := func(sev alert.Severity, msg, details string) {
		findings = append(findings, alert.Finding{
			Severity:  sev,
			Check:     "perf_mysql_config",
			Message:   msg,
			Details:   details,
			Timestamp: time.Now(),
		})
	}

	// --- Global variables ---
	varRows, varErr := mysqlclient.RootQuery(ctx,
		"SHOW GLOBAL VARIABLES WHERE Variable_name IN ('join_buffer_size','wait_timeout','interactive_timeout','max_user_connections','slow_query_log')")
	if varText := strings.Join(varRows, "\n"); varErr == nil && varText != "" {
		joinBufMax := int64(cfg.Performance.MySQLJoinBufferMaxMB) * 1024 * 1024
		timeoutMax := cfg.Performance.MySQLWaitTimeoutMax
		for _, row := range strings.Split(varText, "\n") {
			cols := strings.Fields(row)
			if len(cols) < 2 {
				continue
			}
			name, value := cols[0], cols[1]
			switch name {
			case "join_buffer_size":
				if size, perr := strconv.ParseInt(value, 10, 64); perr == nil && size > joinBufMax {
					report(alert.Critical, "MySQL join_buffer_size exceeds safe maximum",
						fmt.Sprintf("Current: %s, Max: %s", humanBytes(size), humanBytes(joinBufMax)))
				}
			case "wait_timeout", "interactive_timeout":
				// Both timeouts share the same ceiling and message shape.
				if secs, perr := strconv.Atoi(value); perr == nil && secs > timeoutMax {
					report(alert.High, "MySQL "+name+" is too high",
						fmt.Sprintf("Current: %ds, Max: %ds", secs, timeoutMax))
				}
			case "max_user_connections":
				if value == "0" {
					report(alert.Warning, "MySQL max_user_connections is unlimited",
						fmt.Sprintf("Current: 0 (unlimited), Recommended: %d", cfg.Performance.MySQLMaxConnectionsPerUser))
				}
			case "slow_query_log":
				if strings.ToUpper(value) == "OFF" {
					report(alert.Warning, "MySQL slow query log is disabled",
						"Set slow_query_log=ON to help diagnose performance issues")
				}
			}
		}
	}

	// --- InnoDB buffer pool hit ratio + temporary disk tables ---
	statusRows, statusErr := mysqlclient.RootQuery(ctx,
		"SHOW GLOBAL STATUS WHERE Variable_name IN ('Innodb_buffer_pool_read_requests','Innodb_buffer_pool_reads','Created_tmp_disk_tables','Created_tmp_tables')")
	if statusText := strings.Join(statusRows, "\n"); statusErr == nil && statusText != "" {
		counters := make(map[string]int64, 4)
		for _, row := range strings.Split(statusText, "\n") {
			cols := strings.Fields(row)
			if len(cols) < 2 {
				continue
			}
			if n, perr := strconv.ParseInt(cols[1], 10, 64); perr == nil {
				counters[cols[0]] = n
			}
		}

		tmpDisk, tmpAll := counters["Created_tmp_disk_tables"], counters["Created_tmp_tables"]
		if tmpAll > 0 && tmpDisk > 0 {
			if diskPct := float64(tmpDisk) / float64(tmpAll) * 100; diskPct > 25.0 {
				report(alert.Warning, "MySQL creating excessive temporary tables on disk",
					fmt.Sprintf("Disk ratio: %.1f%% (%d disk tables / %d total tables)", diskPct, tmpDisk, tmpAll))
			}
		}

		requests, diskReads := counters["Innodb_buffer_pool_read_requests"], counters["Innodb_buffer_pool_reads"]
		if requests > 0 {
			if hitPct := float64(requests-diskReads) / float64(requests) * 100; hitPct < 95.0 {
				report(alert.High, "InnoDB buffer pool hit ratio is low",
					fmt.Sprintf("Hit ratio: %.1f%% (threshold: 95%%), disk reads: %d", hitPct, diskReads))
			}
		}
	}

	// --- Per-user connection counts ---
	if procRows, procErr := mysqlclient.RootQuery(ctx, "SHOW PROCESSLIST"); procErr == nil && len(procRows) > 0 {
		// SHOW PROCESSLIST columns: Id, User, Host, db, Command, Time, State, Info
		perUser := make(map[string]int)
		for _, row := range procRows {
			cols := strings.Fields(row)
			if len(cols) < 2 || cols[1] == "User" { // too short, or the header row
				continue
			}
			perUser[cols[1]]++
		}
		limit := cfg.Performance.MySQLMaxConnectionsPerUser
		for dbUser, held := range perUser {
			if held > limit {
				report(alert.High, fmt.Sprintf("MySQL user %s holding excessive connections", dbUser),
					fmt.Sprintf("Connections: %d, Threshold: %d", held, limit))
			}
		}
	}
	return findings
}
// CheckRedisConfig inspects a local Redis instance for performance-impacting
// misconfigurations and reports them under the check name
// "perf_redis_config":
//
//   - maxmemory unset (Critical)
//   - maxmemory-policy noeviction (High)
//   - more than 95% of keys without a TTL, unless the policy is allkeys-*
//     (Warning)
//   - a bgsave ("save") schedule with an interval below
//     RedisBgsaveMinInterval while used memory exceeds RedisLargeDatasetGB
//     (High)
//   - used memory at or above 80% (Warning) / 90% (High) of maxmemory
//
// The check returns nil when performance checks are disabled or when the
// initial memory query fails, which is how hosts without a reachable local
// Redis are skipped.
func CheckRedisConfig(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	// One memory query doubles as the reachability probe and as the data
	// source for the bgsave and headroom checks below, so used and max come
	// from the same snapshot and Redis is not queried repeatedly for the
	// same numbers.
	usedBytes, maxBytes, memErr := redisinfo.MemoryUsage(ctx)
	if memErr != nil {
		return nil
	}

	var findings []alert.Finding
	report := func(sev alert.Severity, msg, details string) {
		findings = append(findings, alert.Finding{
			Severity:  sev,
			Check:     "perf_redis_config",
			Message:   msg,
			Details:   details,
			Timestamp: time.Now(),
		})
	}

	// --- maxmemory ---
	if maxMem, err := redisinfo.ConfigGet(ctx, "maxmemory"); err == nil && maxMem == "0" {
		report(alert.Critical, "Redis maxmemory is not set",
			"maxmemory=0 means Redis will use all available system memory without bound")
	}

	// --- maxmemory-policy ---
	// policyLower stays empty when the policy cannot be read.
	policyLower := ""
	if policy, err := redisinfo.ConfigGet(ctx, "maxmemory-policy"); err == nil {
		policyLower = strings.ToLower(strings.TrimSpace(policy))
	}
	if policyLower == "noeviction" {
		report(alert.High, "Redis maxmemory-policy is noeviction",
			"noeviction causes Redis to return errors when memory is full instead of evicting keys")
	}

	// --- Non-expiring keys ratio via keyspace ---
	// A high non-expiring ratio only breaks eviction under volatile-* policies
	// (which evict keys carrying a TTL) or noeviction. Under allkeys-* Redis
	// evicts any key, so non-expiring keys are reclaimable and the ratio is
	// benign.
	if stats, err := redisinfo.KeyspaceStats(ctx); err == nil && stats.TotalKeys > 0 && !strings.HasPrefix(policyLower, "allkeys-") {
		nonExpiring := stats.TotalKeys - stats.TotalExpires
		ratio := float64(nonExpiring) / float64(stats.TotalKeys) * 100
		if ratio > 95.0 {
			report(alert.Warning, "Redis has excessive non-expiring keys",
				fmt.Sprintf("Non-expiring: %d / %d total keys (%.1f%%); %s",
					nonExpiring, stats.TotalKeys, ratio, redisNonExpiringPolicyDetail(policyLower)))
		}
	}

	// --- bgsave interval vs dataset size ---
	largeDatasetBytes := redisLargeDatasetThresholdBytes(cfg.Performance.RedisLargeDatasetGB)
	bgsaveMinInterval := cfg.Performance.RedisBgsaveMinInterval
	saveSpec, saveErr := redisinfo.ConfigGet(ctx, "save")
	if saveErr == nil && saveSpec != "" && usedBytes > largeDatasetBytes {
		// `CONFIG GET save` returns the spec as a single space-separated
		// string of alternating "<seconds> <changes>" pairs, e.g.
		// "900 1 300 10 60 10000". Walk the seconds tokens (every other
		// field) and flag any below the configured floor.
		fields := strings.Fields(saveSpec)
		aggressiveSave := false
		for i := 0; i < len(fields); i += 2 {
			if seconds, convErr := strconv.Atoi(fields[i]); convErr == nil && seconds < bgsaveMinInterval {
				aggressiveSave = true
				break
			}
		}
		if aggressiveSave {
			report(alert.High, "Redis bgsave interval too aggressive for dataset size",
				fmt.Sprintf(
					"Used memory: %s, Threshold: %s, Minimum safe bgsave interval: %ds",
					humanBytes(uint64ToInt64Clamped(usedBytes)),
					humanBytes(uint64ToInt64Clamped(largeDatasetBytes)),
					bgsaveMinInterval,
				))
		}
	}

	// --- used_memory vs maxmemory headroom ---
	// The maxmemory==0 branch above flags the unset case. When maxmemory
	// IS set, used/max ratio is the operator-meaningful signal: at 80%
	// the eviction policy is about to start churning hot keys; at 90%
	// noeviction-policy instances start returning OOM errors.
	if maxBytes > 0 && usedBytes > 0 {
		pct := float64(usedBytes) / float64(maxBytes) * 100
		usage := fmt.Sprintf(
			"Used: %s / Max: %s (%.1f%%)",
			humanBytes(uint64ToInt64Clamped(usedBytes)),
			humanBytes(uint64ToInt64Clamped(maxBytes)),
			pct,
		)
		switch {
		case pct >= 90:
			report(alert.High, "Redis used memory >= 90% of maxmemory", usage)
		case pct >= 80:
			report(alert.Warning, "Redis used memory >= 80% of maxmemory", usage)
		}
	}
	return findings
}
// redisNonExpiringPolicyDetail returns the explanatory clause appended to the
// "excessive non-expiring keys" finding. policyLower is the lower-cased
// maxmemory-policy, or the empty string when the policy could not be read.
func redisNonExpiringPolicyDetail(policyLower string) string {
	if policyLower == "" {
		return "maxmemory-policy is unavailable, so non-expiring keys may be unsafe under memory pressure"
	}
	if policyLower == "noeviction" {
		return "maxmemory-policy noeviction does not evict keys under memory pressure"
	}
	if strings.HasPrefix(policyLower, "volatile-") {
		return fmt.Sprintf("maxmemory-policy %q only evicts keys with a TTL under memory pressure", policyLower)
	}
	return fmt.Sprintf("maxmemory-policy %q may leave non-expiring keys unreclaimable under memory pressure", policyLower)
}
// ---------------------------------------------------------------------------
// Performance check helpers (WP-specific)
// ---------------------------------------------------------------------------
// safeIdentRe matches strings made up solely of ASCII letters, digits and
// underscores.
var safeIdentRe = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// safeIdentifier reports whether s is a non-empty string consisting only of
// ASCII letters, digits and underscores. It is the gate applied to values
// before they are placed into commands or SQL text, so anything carrying
// shell or SQL metacharacters is rejected.
func safeIdentifier(s string) bool {
	if s == "" {
		return false
	}
	return safeIdentRe.MatchString(s)
}
// extractPHPDefine extracts the value argument from a PHP define() line:
//
//	define('KEY', 'value');   or   define("KEY", "value");
//
// It is distinct from extractDefine (dbscan.go) which requires a key parameter.
//
// Quoted values are returned without their surrounding quotes; the value ends
// at the first unescaped quote matching the opening one, so a trailing
// comment that happens to contain quotes or parentheses does not leak into
// the result. Backslash escapes inside the value are returned as written
// (they are not decoded). Unquoted literals such as true, false or 256 are
// returned as the bare token.
//
// Returns the empty string if no value can be extracted (no parenthesis, no
// comma, empty value, or an unterminated quoted string).
func extractPHPDefine(line string) string {
	line = strings.TrimSpace(line)

	// Everything after the first "(" holds the argument list.
	parenIdx := strings.Index(line, "(")
	if parenIdx < 0 {
		return ""
	}
	inner := line[parenIdx+1:]
	// Drop the last ")" and whatever follows it.
	if closeIdx := strings.LastIndex(inner, ")"); closeIdx >= 0 {
		inner = inner[:closeIdx]
	}

	// inner is now like: 'KEY', 'value' -- the value follows the first comma.
	commaIdx := strings.Index(inner, ",")
	if commaIdx < 0 {
		return ""
	}
	valuePart := strings.TrimSpace(inner[commaIdx+1:])
	if valuePart == "" {
		return ""
	}

	// Quoted string: scan for the first unescaped closing quote.
	if q := valuePart[0]; q == '\'' || q == '"' {
		for i := 1; i < len(valuePart); i++ {
			switch valuePart[i] {
			case '\\':
				i++ // the escaped character can never close the string
			case q:
				return valuePart[1:i]
			}
		}
		return "" // unterminated string
	}

	// Unquoted literal (boolean/number constant), e.g.
	//   define('DISABLE_WP_CRON', true);
	//   define('WP_MEMORY_LIMIT', 256);
	// The token ends at the first whitespace or punctuation character.
	if end := strings.IndexAny(valuePart, " \t;),"); end >= 0 {
		return strings.TrimSpace(valuePart[:end])
	}
	return strings.TrimSpace(valuePart)
}
// ---------------------------------------------------------------------------
// Subdirs to skip in recursive helpers.
// ---------------------------------------------------------------------------

// skipDirs is the set of directory names the recursive scanners in this file
// do not descend into. Membership is tested by base name only.
var skipDirs = map[string]bool{
	"cache":        true,
	"node_modules": true,
	"vendor":       true,
	"wp-admin":     true,
	"wp-content":   true,
	"wp-includes":  true,
}
// ---------------------------------------------------------------------------
// CheckErrorLogBloat
// ---------------------------------------------------------------------------
// scanErrorLogs walks dir recursively, descending at most depth further
// levels and never into directories listed in skipDirs, and appends a Warning
// finding to *findings for every file named "error_log" whose size is
// strictly greater than thresholdBytes. The walk stops as soon as *findings
// holds 20 entries. Unreadable directories and entries that cannot be
// stat'ed are ignored.
func scanErrorLogs(dir string, thresholdBytes int64, depth int, findings *[]alert.Finding) {
	const maxFindings = 20
	if depth < 0 || len(*findings) >= maxFindings {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, entry := range entries {
		if len(*findings) >= maxFindings {
			return
		}
		name := entry.Name()
		path := filepath.Join(dir, name)
		switch {
		case entry.IsDir():
			if !skipDirs[name] {
				scanErrorLogs(path, thresholdBytes, depth-1, findings)
			}
		case name == "error_log":
			info, statErr := entry.Info()
			if statErr != nil || info.Size() <= thresholdBytes {
				continue
			}
			*findings = append(*findings, alert.Finding{
				Severity:  alert.Warning,
				Check:     "perf_error_logs",
				Message:   fmt.Sprintf("Bloated error_log: %s", path),
				Details:   fmt.Sprintf("Size: %s", humanBytes(info.Size())),
				Timestamp: time.Now(),
			})
		}
	}
}
// CheckErrorLogBloat scans every web root returned by ResolveWebRoots (three
// directory levels below each root) for error_log files larger than
// Performance.ErrorLogWarnSizeMB and reports at most 20 of them. The runner
// enforces a 60-minute throttle via checkThrottleMin.
func CheckErrorLogBloat(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	limitBytes := int64(cfg.Performance.ErrorLogWarnSizeMB) * 1024 * 1024
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		scanErrorLogs(root, limitBytes, 3, &findings)
		if len(findings) >= 20 {
			break
		}
	}
	return findings
}
// ---------------------------------------------------------------------------
// CheckWPConfig
// ---------------------------------------------------------------------------
// parseMemoryLimit interprets a PHP-style memory limit such as "256M" or
// "1G" and returns it in megabytes. The suffix is case-insensitive: K values
// are divided by 1024 (integer division), M values are taken as-is, G values
// are multiplied by 1024, and a number without a suffix is treated as
// megabytes. The empty string, "-1" and anything non-numeric yield 0.
func parseMemoryLimit(s string) int {
	limit := strings.TrimSpace(strings.ToUpper(s))
	if limit == "" || limit == "-1" {
		return 0
	}

	// Peel off a recognised unit suffix, if any.
	digits, unit := limit, byte(0)
	if last := limit[len(limit)-1]; last == 'K' || last == 'M' || last == 'G' {
		digits, unit = limit[:len(limit)-1], last
	}

	n, err := strconv.Atoi(digits)
	if err != nil {
		return 0
	}
	switch unit {
	case 'K':
		return n / 1024
	case 'G':
		return n * 1024
	}
	return n
}
// scanWPConfigs walks dir recursively (at most depth further levels, skipping
// directories in skipDirs) looking for wp-config.php files. For each one it
// appends "perf_wp_config" findings to *findings:
//
//   - Warning when the first line mentioning WP_MEMORY_LIMIT defines a value
//     above Performance.WPMemoryLimitMaxMB;
//   - High for a max_execution_time of 0 and Warning for display_errors set
//     to "on", as found in .htaccess, php.ini or .user.ini next to the
//     wp-config.php.
//
// account is only used to label the findings.
func scanWPConfigs(dir, account string, cfg *config.Config, depth int, findings *[]alert.Finding) {
	if depth < 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	emit := func(sev alert.Severity, msg, details string) {
		*findings = append(*findings, alert.Finding{
			Severity:  sev,
			Check:     "perf_wp_config",
			Message:   msg,
			Details:   details,
			Timestamp: time.Now(),
		})
	}
	// lastToken splits a directive on '=', spaces and tabs and returns its
	// final token; ok is false when the directive has fewer than two tokens.
	lastToken := func(directive string) (tok string, ok bool) {
		parts := strings.FieldsFunc(directive, func(r rune) bool { return r == '=' || r == ' ' || r == '\t' })
		if len(parts) < 2 {
			return "", false
		}
		return parts[len(parts)-1], true
	}

	for _, entry := range entries {
		name := entry.Name()
		path := filepath.Join(dir, name)
		if entry.IsDir() {
			if !skipDirs[name] {
				scanWPConfigs(path, account, cfg, depth-1, findings)
			}
			continue
		}
		if name != "wp-config.php" {
			continue
		}

		// --- WP_MEMORY_LIMIT: only the first matching line is examined ---
		if src, readErr := osFS.ReadFile(path); readErr == nil {
			for _, row := range strings.Split(string(src), "\n") {
				if !strings.Contains(row, "WP_MEMORY_LIMIT") {
					continue
				}
				limit := extractPHPDefine(strings.TrimSpace(row))
				if parseMemoryLimit(limit) > cfg.Performance.WPMemoryLimitMaxMB {
					emit(alert.Warning,
						fmt.Sprintf("Excessive WP_MEMORY_LIMIT for %s", account),
						fmt.Sprintf("File: %s, Value: %s", path, limit))
				}
				break
			}
		}

		// --- PHP settings files living beside wp-config.php ---
		siteDir := filepath.Dir(path)
		for _, iniName := range []string{".htaccess", "php.ini", ".user.ini"} {
			iniPath := filepath.Join(siteDir, iniName)
			body, readErr := osFS.ReadFile(iniPath)
			if readErr != nil {
				continue
			}
			// A .user.ini written by the cPanel MultiPHP INI Editor holds
			// settings an operator chose through the cPanel UI (for example
			// max_execution_time=0 for a backup importer, or
			// display_errors=On on a staging account), so it is not worth
			// alerting on. The exemption is limited to .user.ini: cPanel
			// does not own php.ini or .htaccess, so the same marker there
			// carries no authority and those files are scanned normally.
			if iniName == ".user.ini" && isCpanelManagedUserIni(body) {
				continue
			}
			for _, raw := range strings.Split(string(body), "\n") {
				directive := strings.TrimSpace(raw)
				if strings.HasPrefix(directive, "#") || strings.HasPrefix(directive, ";") {
					continue // comment line
				}
				lowered := strings.ToLower(directive)
				if strings.Contains(lowered, "max_execution_time") {
					// e.g. "max_execution_time = 0" or "php_value max_execution_time 0"
					if v, ok := lastToken(directive); ok && v == "0" {
						emit(alert.High,
							fmt.Sprintf("Unlimited max_execution_time for %s", account),
							fmt.Sprintf("File: %s, Value: 0", iniPath))
					}
				} else if strings.Contains(lowered, "display_errors") {
					if v, ok := lastToken(directive); ok && strings.ToLower(v) == "on" {
						emit(alert.Warning,
							fmt.Sprintf("display_errors enabled in production for %s", account),
							fmt.Sprintf("File: %s, Value: On", iniPath))
					}
				}
			}
		}
	}
}
// CheckWPConfig inspects every web root returned by ResolveWebRoots, to a
// recursion depth of 2, for wp-config.php files and reports excessive
// WP_MEMORY_LIMIT values, unlimited max_execution_time, and display_errors
// enabled in production. Findings are labelled with the account name derived
// from the web root path. The runner enforces a 60-minute throttle via
// checkThrottleMin.
func CheckWPConfig(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		account := accountFromPath(root)
		scanWPConfigs(root, account, cfg, 2, &findings)
	}
	return findings
}
// accountFromPath derives a best-effort account label from a web root path.
// If any path component is "home" and has a successor, that successor is
// returned (cPanel layout: /home/USER/public_html -> USER). Otherwise the
// parent directory name is returned when the path has a non-empty final
// component, and filepath.Base(dir) in all remaining cases. The result is
// meant for reporting only — never for authorization.
func accountFromPath(dir string) string {
	segments := strings.Split(dir, string(filepath.Separator))
	last := len(segments) - 1

	// cPanel shape: the component right after "home" is the account.
	for idx := 0; idx < last; idx++ {
		if segments[idx] == "home" {
			return segments[idx+1]
		}
	}
	// Generic shape (/var/www/<site>, /srv/http/<site>, ...): use the parent.
	if last >= 1 && segments[last] != "" {
		return segments[last-1]
	}
	return filepath.Base(dir)
}
// ---------------------------------------------------------------------------
// CheckWPTransientBloat
// ---------------------------------------------------------------------------
// findWPTransients walks dir recursively (at most depth further levels,
// skipping directories in skipDirs) and, for every wp-config.php found, asks
// that site's database for its five largest _transient_ options above
// warnBytes. Each returned row becomes a "perf_wp_transients" finding: High
// when larger than critBytes, Warning when larger than warnBytes. Sites
// whose config lacks a database name or user, or whose identifiers fail
// safeIdentifier, are skipped, as are sites whose query fails or returns
// nothing.
func findWPTransients(dir string, cfg *config.Config, warnBytes, critBytes int64, depth int, findings *[]alert.Finding) {
	if depth < 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, entry := range entries {
		name := entry.Name()
		path := filepath.Join(dir, name)
		switch {
		case entry.IsDir():
			if !skipDirs[name] {
				findWPTransients(path, cfg, warnBytes, critBytes, depth-1, findings)
			}
			continue
		case name != "wp-config.php":
			continue
		}

		info := parseWPConfig(path)
		if info.dbName == "" || info.dbUser == "" {
			continue
		}
		// WordPress' stock prefix applies when the config does not set one.
		if info.tablePrefix == "" {
			info.tablePrefix = "wp_"
		}
		// The prefix is interpolated into the SQL text below, so every
		// identifier must pass the allow-list first.
		identifiersOK := safeIdentifier(info.dbName) && safeIdentifier(info.dbUser) && safeIdentifier(info.tablePrefix)
		if !identifiersOK {
			continue
		}

		query := fmt.Sprintf(
			"SELECT option_name, LENGTH(option_value) as size FROM %soptions WHERE option_name LIKE '_transient_%%' AND LENGTH(option_value) > %d ORDER BY size DESC LIMIT 5",
			info.tablePrefix,
			warnBytes,
		)
		creds := mysqlclient.Creds{
			User:     info.dbUser,
			Password: info.dbPass,
			Host:     info.dbHost,
			DBName:   info.dbName,
		}
		rows, queryErr := mysqlclient.PerAccountQuery(context.Background(), creds, query)
		if queryErr != nil || len(rows) == 0 {
			continue
		}

		for _, row := range rows {
			cols := strings.Fields(row)
			if len(cols) < 2 {
				continue
			}
			size, convErr := strconv.ParseInt(cols[1], 10, 64)
			if convErr != nil {
				continue
			}
			sev := alert.Warning
			if size > critBytes {
				sev = alert.High
			} else if size <= warnBytes {
				continue
			}
			*findings = append(*findings, alert.Finding{
				Severity:  sev,
				Check:     "perf_wp_transients",
				Message:   fmt.Sprintf("Bloated transient %s in %s", cols[0], info.dbName),
				Details:   fmt.Sprintf("Size: %s", humanBytes(size)),
				Timestamp: time.Now(),
			})
		}
	}
}
// CheckWPTransientBloat scans configured web roots (default /home/*/public_html
// on cPanel) for WordPress installs and queries each database for oversized
// transients. DB credentials are read from wp-config.php; the password is
// passed via MYSQL_PWD environment variable (never on the command line).
// The runner enforces a 60-minute throttle via checkThrottleMin.
//
// The warning and critical thresholds come from
// cfg.Performance.WPTransientWarnMB and WPTransientCriticalMB, converted to
// bytes. ctx is consulted between web roots: once it is cancelled the scan
// stops and the findings gathered so far are returned. Returns nil when the
// performance checks are disabled in cfg.
func CheckWPTransientBloat(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	warnBytes := int64(cfg.Performance.WPTransientWarnMB) * 1024 * 1024
	critBytes := int64(cfg.Performance.WPTransientCriticalMB) * 1024 * 1024

	var findings []alert.Finding
	for _, dir := range ResolveWebRoots(cfg) {
		// Each root can trigger several database queries; stop early
		// when the caller has given up.
		if ctx.Err() != nil {
			break
		}
		findWPTransients(dir, cfg, warnBytes, critBytes, 2, &findings)
	}
	return findings
}
// ---------------------------------------------------------------------------
// CheckWPCron
// ---------------------------------------------------------------------------
// scanWPCron walks dir, descending at most depth further levels and
// skipping directory names present in skipDirs, and appends a Warning
// finding for every readable wp-config.php in which
// wpCronHasActiveDisableDefine does not find an active DISABLE_WP_CRON
// define. The walk stops as soon as *findings holds 30 entries; unreadable
// directories and files are skipped silently.
func scanWPCron(dir, account string, depth int, findings *[]alert.Finding) {
	const maxFindings = 30

	if depth < 0 || len(*findings) >= maxFindings {
		return
	}
	listing, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	for _, entry := range listing {
		// Re-check on every entry: recursive calls may have filled the cap.
		if len(*findings) >= maxFindings {
			return
		}

		base := entry.Name()
		target := filepath.Join(dir, base)

		if entry.IsDir() {
			if !skipDirs[base] {
				scanWPCron(target, account, depth-1, findings)
			}
			continue
		}
		if base != "wp-config.php" {
			continue
		}

		raw, readErr := osFS.ReadFile(target)
		if readErr != nil || wpCronHasActiveDisableDefine(raw) {
			continue
		}

		*findings = append(*findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "perf_wp_cron",
			Message:  fmt.Sprintf("WP-Cron not disabled for %s", account),
			Details: fmt.Sprintf(
				"File: %s - add define('DISABLE_WP_CRON', true); and use a real cron job instead",
				target,
			),
			Timestamp: time.Now(),
		})
	}
}
// CheckWPCron scans configured web roots (default /home/*/public_html on
// cPanel) for WordPress installs that have not disabled the built-in
// WP-Cron mechanism. Running WP-Cron via HTTP is a common cause of high
// load on busy sites.
// The runner enforces a 60-minute throttle via checkThrottleMin.
//
// At most 30 findings are collected. ctx is consulted between web roots:
// once it is cancelled the scan stops and the findings gathered so far are
// returned. Returns nil when the performance checks are disabled in cfg.
func CheckWPCron(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	var findings []alert.Finding
	for _, dir := range ResolveWebRoots(cfg) {
		// Stop on cancellation or once the finding cap has been reached.
		if ctx.Err() != nil || len(findings) >= 30 {
			break
		}
		scanWPCron(dir, accountFromPath(dir), 2, &findings)
	}
	return findings
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"unicode"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// phishingReadSize is the size in bytes of the buffer the content analyzers
// fill from the start of each candidate file; anything past it is never
// inspected.
const phishingReadSize = 16384 // Read first 16KB - phishing pages are self-contained
// ---------------------------------------------------------------------------
// Brand impersonation patterns
// ---------------------------------------------------------------------------
// phishingBrands lists the brands whose impersonation is detected.
//
// titlePatterns are substring-matched against the page's <title> text and
// bodyPatterns against the whole analyzed content. Both are compared with
// lowercased input, so every pattern must be written in lowercase. The
// analyzers stop at the first entry that matches, so order matters: the
// broad "Generic Login" entry is kept last.
var phishingBrands = []struct {
name string
titlePatterns []string
bodyPatterns []string
}{
{
name: "Microsoft/SharePoint",
titlePatterns: []string{"sharepoint", "onedrive", "microsoft 365", "outlook web", "office 365", "ms online"},
bodyPatterns: []string{"sharepoint", "onedrive", "secured by microsoft", "microsoft corporation"},
},
{
name: "Google",
titlePatterns: []string{"google drive", "google docs", "google sign", "gmail", "google workspace"},
bodyPatterns: []string{"google drive", "google docs", "accounts.google", "secured by google"},
},
{
name: "Dropbox",
titlePatterns: []string{"dropbox", "shared file", "shared folder"},
bodyPatterns: []string{"dropbox", "dropbox.com", "secured by dropbox"},
},
{
name: "DocuSign",
titlePatterns: []string{"docusign", "document signing", "e-signature"},
bodyPatterns: []string{"docusign", "please review and sign", "e-signature"},
},
{
name: "Adobe",
titlePatterns: []string{"adobe sign", "adobe document", "adobe acrobat"},
bodyPatterns: []string{"adobe sign", "adobe document cloud", "secured by adobe"},
},
{
name: "WeTransfer",
titlePatterns: []string{"wetransfer", "file transfer"},
bodyPatterns: []string{"wetransfer", "download your files"},
},
{
name: "Apple/iCloud",
titlePatterns: []string{"icloud", "apple id", "find my"},
bodyPatterns: []string{"icloud.com", "apple id", "secured by apple"},
},
{
name: "PayPal",
titlePatterns: []string{"paypal", "pay pal"},
bodyPatterns: []string{"paypal.com", "secured by paypal"},
},
{
name: "Webmail/Roundcube",
titlePatterns: []string{"roundcube", "horde", "webmail login", "webmail ::", "squirrelmail", "zimbra"},
bodyPatterns: []string{"roundcube webmail", "horde login", "zimbra web client", "webmail login"},
// Note: bare "squirrelmail"/"roundcube" removed from body - sites legitimately
// link to their server's webmail (e.g. href="squirrelmail/index.php").
},
{
name: "cPanel/WHM",
titlePatterns: []string{"cpanel", "whm login", "webhost manager"},
bodyPatterns: []string{"cpanel login", "whm login", "webhost manager"},
},
{
name: "Banking/Financial",
titlePatterns: []string{"online banking", "bank login", "secure banking", "account login"},
bodyPatterns: []string{"online banking", "bank account", "transaction verification"},
},
{
name: "Generic Login",
titlePatterns: []string{"secure access", "verify your", "confirm your identity", "account verification", "email verification", "sign in"},
bodyPatterns: []string{"verify your identity", "confirm your account", "unusual activity"},
},
}
// ---------------------------------------------------------------------------
// Content-based indicators
// ---------------------------------------------------------------------------
// harvestIndicators are lowercase substrings that suggest a page collects
// credentials and then redirects or silently posts them elsewhere. The
// analyzers add one point to the score for every entry found in the
// lowercased content. Some entries also appear in exfilPatterns or
// trustBadgePatterns, so a single phrase can be counted by more than one
// list.
var harvestIndicators = []string{
"window.location.href",
"window.location.replace",
"window.location =",
"document.location.href",
"form.submit()",
".workers.dev",
"confirm access",
"verify your email",
"confirm your email",
"verify identity",
"continue to document",
"access confirmed, redirecting",
"secured by microsoft",
"secured by google",
"secured by apple",
"256-bit encrypted",
// Same phrase written with a non-ASCII hyphen character.
"256‑bit encrypted",
// fetch/XHR exfiltration - silent credential POST without redirect
"fetch(",
"xmlhttprequest",
"$.ajax(",
"$.post(",
"navigator.sendbeacon(",
}
// exfilPatterns are lowercase URL fragments of redirectors, link shorteners
// and similar hosts used to move victims or captured data off-site. The
// analyzers add two points to the score for every entry found in the
// lowercased content. Note that "/servlet/effi.redir" contains
// "effi.redir", so a page with the longer form matches both entries.
var exfilPatterns = []string{
".workers.dev",
"//t.co/",
"/redir?",
"/redirect?",
"effi.redir",
"link?url=",
"goto_url=",
"//bit.ly/",
"//tinyurl.com/",
"//rb.gy/",
"//is.gd/",
"/servlet/effi.redir",
}
// Fake trust badge patterns - security claims in pages not on the brand's domain.
// Entries are lowercase substrings; the HTML analyzer adds one point to the
// score for every entry found in the lowercased content.
var trustBadgePatterns = []string{
"secured by microsoft",
"secured by google",
"secured by apple",
"secured by dropbox",
"secured by adobe",
"verified by microsoft",
"protected by microsoft",
"256-bit encrypted",
// Same phrase written with a non-ASCII hyphen character.
"256‑bit encrypted",
"ssl secured",
"bank-level encryption",
"enterprise security",
}
// Urgency language used to pressure victims.
// Entries are lowercase substrings. The HTML analyzer counts how many are
// present but adds only a single point to the score no matter how many
// match; the count is reported in the indicator text.
var urgencyPatterns = []string{
"expires in",
"temporary hold",
"limited time",
"unusual activity detected",
"suspicious activity",
"your account will be",
"verify within",
"action required",
"immediate action",
"account suspended",
"access will be revoked",
}
// Embedded asset indicators - phishing kits embed logos to avoid external loading.
// Each entry is the lowercase prefix of a base64 data: URI. The HTML
// analyzer sums the occurrences of all entries and adds a single point to
// the score when the total is non-zero.
var embeddedAssetPatterns = []string{
"data:image/png;base64,",
"data:image/svg+xml;base64,",
"data:image/jpeg;base64,",
}
// ---------------------------------------------------------------------------
// Main check entry point
// ---------------------------------------------------------------------------
// CheckPhishing scans files in user document roots for phishing pages.
// Uses three detection layers:
//  1. Content analysis - brand impersonation + credential harvesting patterns
//  2. Structural analysis - self-contained HTML with embedded assets
//  3. Directory anomaly - lone HTML files in otherwise empty directories
//
// For every directory entry returned by GetScanHomeDirs (dot-prefixed names
// and "virtfs" excluded) it scans /home/<user>/public_html plus every other
// non-hidden subdirectory of the home except mail, etc, logs, ssl and tmp,
// each to a depth of 3. It returns nil when the home directories cannot be
// listed and the findings collected so far when ctx is cancelled.
func CheckPhishing(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding

	homeDirs, err := GetScanHomeDirs(ctx)
	if err != nil {
		return nil
	}

	for _, home := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !home.IsDir() {
			continue
		}
		user := home.Name()
		if user == "virtfs" || strings.HasPrefix(user, ".") {
			continue
		}

		homeDir := filepath.Join("/home", user)
		docRoots := []string{filepath.Join(homeDir, "public_html")}

		// Addon/sub domains may live beside public_html; a listing error
		// simply leaves public_html as the only root.
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sub := range subDirs {
			if !sub.IsDir() {
				continue
			}
			subName := sub.Name()
			if strings.HasPrefix(subName, ".") {
				continue
			}
			switch subName {
			case "public_html", "mail", "etc", "logs", "ssl", "tmp":
				continue
			}
			docRoots = append(docRoots, filepath.Join(homeDir, subName))
		}

		for _, root := range docRoots {
			scanForPhishing(ctx, root, 3, user, cfg, &findings)
			if ctx.Err() != nil {
				return findings
			}
		}
	}
	return findings
}
// ---------------------------------------------------------------------------
// Directory scanner
// ---------------------------------------------------------------------------
// scanForPhishing inspects the entries of dir and recurses into
// subdirectories while maxDepth stays positive, appending to *findings.
//
// Entries whose full path matches one of cfg.Suppressions.IgnorePaths are
// skipped. Directories not accepted by isKnownSafeDir get the
// directory-anomaly check and are then descended into. Files are
// classified by lowercased name and size:
//   - .html/.htm: brand/content analysis (3000-100000 bytes) or iframe
//     check (1-2999 bytes)
//   - executable PHP names other than known CMS files: content analysis
//     (3000-100000 bytes) or open-redirector check (1-1023 bytes)
//   - credential-log names (non-empty, under 10 MiB): credential log check
//   - .zip (over 1000 bytes, under 50 MiB): phishing-kit filename check
//
// The walk returns early when ctx is cancelled or dir cannot be read.
func scanForPhishing(ctx context.Context, dir string, maxDepth int, user string, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	// isSuppressed reports whether p matches any configured ignore glob.
	isSuppressed := func(p string) bool {
		for _, glob := range cfg.Suppressions.IgnorePaths {
			if matchGlob(p, glob) {
				return true
			}
		}
		return false
	}
	// report appends one finding to the shared result slice.
	report := func(f alert.Finding) {
		*findings = append(*findings, f)
	}

	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}

		name := entry.Name()
		fullPath := filepath.Join(dir, name)
		if isSuppressed(fullPath) {
			continue
		}

		if entry.IsDir() {
			if isKnownSafeDir(name) {
				continue
			}
			// Directory anomaly detection runs before descending.
			if anomaly := analyzeDirectoryStructure(fullPath, user); anomaly != nil {
				report(*anomaly)
			}
			scanForPhishing(ctx, fullPath, maxDepth-1, user, cfg, findings)
			continue
		}

		nameLower := strings.ToLower(name)
		info, statErr := entry.Info()
		if statErr != nil {
			continue
		}
		size := info.Size()

		switch {
		// HTML/HTM phishing pages.
		case strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm"):
			// Standard phishing page check (3KB-100KB).
			if size >= 3000 && size <= 100000 {
				if res := analyzeHTMLForPhishing(fullPath); res != nil {
					report(alert.Finding{
						Severity: alert.Critical,
						Check:    "phishing_page",
						Message:  fmt.Sprintf("Phishing page detected (%s impersonation): %s", res.brand, fullPath),
						Details: fmt.Sprintf("Account: %s\nBrand: %s\nScore: %d/10\nIndicators:\n- %s\nSize: %d bytes",
							user, res.brand, res.score, strings.Join(res.indicators, "\n- "), size),
						FilePath: fullPath,
					})
				}
			}
			// Tiny HTML files that merely embed an external phishing page.
			if size > 0 && size < 3000 {
				if detail := checkIframePhishing(fullPath); detail != "" {
					report(alert.Finding{
						Severity: alert.Critical,
						Check:    "phishing_iframe",
						Message:  fmt.Sprintf("Iframe phishing page detected: %s", fullPath),
						Details:  fmt.Sprintf("Account: %s\n%s", user, detail),
						FilePath: fullPath,
					})
				}
			}

		// PHP phishing pages and open redirectors.
		case isExecutablePHPName(nameLower):
			if isKnownCMSFile(nameLower) {
				continue
			}
			// Same brand/content analysis as HTML (3KB-100KB).
			if size >= 3000 && size <= 100000 {
				if res := analyzePHPForPhishing(fullPath); res != nil {
					report(alert.Finding{
						Severity: alert.Critical,
						Check:    "phishing_php",
						Message:  fmt.Sprintf("PHP phishing page detected (%s): %s", res.brand, fullPath),
						Details: fmt.Sprintf("Account: %s\nBrand: %s\nScore: %d/10\nIndicators:\n- %s\nSize: %d bytes",
							user, res.brand, res.score, strings.Join(res.indicators, "\n- "), size),
						FilePath: fullPath,
					})
				}
			}
			// Open redirectors are tiny PHP files under 1KB.
			if size > 0 && size < 1024 {
				if detail := checkPHPRedirector(fullPath); detail != "" {
					report(alert.Finding{
						Severity: alert.High,
						Check:    "phishing_redirector",
						Message:  fmt.Sprintf("PHP open redirector detected: %s", fullPath),
						Details:  fmt.Sprintf("Account: %s\n%s", user, detail),
						FilePath: fullPath,
					})
				}
			}

		// Credential log files.
		case isCredentialLogName(nameLower) && size > 0 && size < 10*1024*1024:
			if detail := checkCredentialLog(fullPath); detail != "" {
				report(alert.Finding{
					Severity: alert.Critical,
					Check:    "phishing_credential_log",
					Message:  fmt.Sprintf("Harvested credential log file detected: %s", fullPath),
					Details:  fmt.Sprintf("Account: %s\n%s", user, detail),
					FilePath: fullPath,
				})
			}

		// Phishing kit ZIP archives.
		case strings.HasSuffix(nameLower, ".zip") && size > 1000 && size < 50*1024*1024:
			if isPhishingKitZip(nameLower) {
				report(alert.Finding{
					Severity: alert.High,
					Check:    "phishing_kit_archive",
					Message:  fmt.Sprintf("Suspected phishing kit archive: %s", fullPath),
					Details: fmt.Sprintf("Account: %s\nFilename: %s\nSize: %d bytes",
						user, name, size),
					FilePath: fullPath,
				})
			}
		}
	}
}
// ---------------------------------------------------------------------------
// Layer 1: Content analysis
// ---------------------------------------------------------------------------
// phishingResult describes a file that content analysis classified as a
// phishing page. The analyzers return nil instead when a file does not
// reach the decision threshold.
type phishingResult struct {
// brand is the name of the matched phishingBrands entry; the HTML
// analyzer uses "Unknown" when a page is flagged without a brand match.
brand string
// score is the sum of the points awarded for every detected indicator.
score int
// indicators holds one human-readable line per detected signal.
indicators []string
}
// analyzeHTMLForPhishing scores the first phishingReadSize bytes of the
// HTML file at path and returns a result when it looks like a phishing
// page, or nil otherwise (including when the file cannot be opened or is
// empty).
//
// A page is only considered when it contains a <form> or <input> and a
// credential input as recognised by hasHTMLCredentialInput. Points are
// then accumulated:
//   - brand impersonation: +3 for a title match, +2 for a body-only match
//     (first matching brand wins)
//   - +1 per harvestIndicators hit, +2 per exfilPatterns hit,
//     +1 per trustBadgePatterns hit
//   - +1 if any urgency phrase is present, +1 if any base64 asset is embedded
//   - +2 for an external form action, +1 for a self-contained page
//   - +2 when the file name (extension removed) looks like a person name
//
// The page is reported with score >= 4 when a brand matched and with
// score >= 6 (brand "Unknown") when none did.
func analyzeHTMLForPhishing(path string) *phishingResult {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	buf := make([]byte, phishingReadSize)
	// A read error is treated like an empty file: with no bytes there is
	// nothing to analyze.
	n, _ := f.Read(buf)
	if n == 0 {
		return nil
	}
	content := string(buf[:n])
	contentLower := strings.ToLower(content)

	// Must contain a form or input.
	if !strings.Contains(contentLower, "<form") && !strings.Contains(contentLower, "<input") {
		return nil
	}
	// Must contain an email/password input.
	if !hasHTMLCredentialInput(contentLower) {
		return nil
	}

	var indicators []string
	score := 0

	// --- Brand impersonation ---
	brandMatch := ""
	titleContent := extractTitle(contentLower)
	for _, brand := range phishingBrands {
		titleHit := false
		bodyHit := false
		for _, tp := range brand.titlePatterns {
			if strings.Contains(titleContent, tp) {
				titleHit = true
				indicators = append(indicators, fmt.Sprintf("title impersonates '%s'", tp))
				score += 3
				break
			}
		}
		for _, bp := range brand.bodyPatterns {
			if strings.Contains(contentLower, bp) {
				bodyHit = true
				// A body hit only scores when the title did not already.
				if !titleHit {
					indicators = append(indicators, fmt.Sprintf("body impersonates '%s'", bp))
					score += 2
				}
				break
			}
		}
		if titleHit || bodyHit {
			brandMatch = brand.name
			break
		}
	}

	// --- Credential harvesting / redirect indicators ---
	for _, pattern := range harvestIndicators {
		if strings.Contains(contentLower, pattern) {
			score++
			indicators = append(indicators, fmt.Sprintf("harvest: '%s'", pattern))
		}
	}

	// --- Exfiltration URL patterns ---
	for _, pattern := range exfilPatterns {
		if strings.Contains(contentLower, pattern) {
			score += 2
			indicators = append(indicators, fmt.Sprintf("exfiltration: '%s'", pattern))
		}
	}

	// --- Fake trust badges ---
	for _, pattern := range trustBadgePatterns {
		if strings.Contains(contentLower, pattern) {
			score++
			indicators = append(indicators, fmt.Sprintf("fake trust badge: '%s'", pattern))
		}
	}

	// --- Urgency language ---
	urgencyCount := 0
	for _, pattern := range urgencyPatterns {
		if strings.Contains(contentLower, pattern) {
			urgencyCount++
		}
	}
	if urgencyCount > 0 {
		score++
		indicators = append(indicators, fmt.Sprintf("urgency language (%d patterns)", urgencyCount))
	}

	// --- Embedded Base64 assets (logos embedded to avoid external loading) ---
	embeddedCount := 0
	for _, pattern := range embeddedAssetPatterns {
		embeddedCount += countOccurrences(contentLower, pattern)
	}
	if embeddedCount > 0 {
		score++
		indicators = append(indicators, fmt.Sprintf("embedded base64 assets (%d)", embeddedCount))
	}

	// --- Form action pointing to external domain ---
	if hasExternalFormAction(content) {
		score += 2
		indicators = append(indicators, "form action points to external domain")
	}

	// --- Self-contained page (all CSS inline, no external stylesheets except CDN) ---
	if isSelfContainedHTML(contentLower) {
		score++
		indicators = append(indicators, "self-contained HTML (all styles inline)")
	}

	// --- Person name as filename ---
	// The scanner selects HTML files by lowercased suffix, so the extension
	// must be stripped case-insensitively as well; otherwise a name such as
	// "JohnSmith.HTML" keeps its extension and can never match.
	baseName := filepath.Base(path)
	if ext := filepath.Ext(baseName); strings.EqualFold(ext, ".html") || strings.EqualFold(ext, ".htm") {
		baseName = strings.TrimSuffix(baseName, ext)
	}
	if looksLikePersonName(baseName) {
		score += 2
		indicators = append(indicators, fmt.Sprintf("filename looks like person name: '%s'", baseName))
	}

	// --- Decision ---
	// With brand match: need score >= 4 (brand gives 2-3 + at least 1 other signal)
	// Without brand match: need score >= 6 (multiple strong signals)
	if brandMatch != "" && score >= 4 {
		return &phishingResult{brand: brandMatch, score: score, indicators: indicators}
	}
	if brandMatch == "" && score >= 6 {
		return &phishingResult{brand: "Unknown", score: score, indicators: indicators}
	}
	return nil
}
// ---------------------------------------------------------------------------
// Layer 2: Structural analysis helpers
// ---------------------------------------------------------------------------
// extractTitle returns the text between the first "<title>" and the first
// "</title>" that follows it. contentLower must already be lowercased,
// because the tags are matched literally. It returns "" when there is no
// opening tag, no closing tag after it, or the title is empty.
//
// The closing tag is searched for only after the opening tag, so a stray
// "</title>" earlier in the document does not hide the real title.
func extractTitle(contentLower string) string {
	const openTag, closeTag = "<title>", "</title>"

	start := strings.Index(contentLower, openTag)
	if start < 0 {
		return ""
	}
	rest := contentLower[start+len(openTag):]
	end := strings.Index(rest, closeTag)
	if end <= 0 {
		return ""
	}
	return rest[:end]
}
// hasExternalFormAction reports whether any <form> tag in content carries a
// quoted action attribute holding an absolute URL: one that starts with
// "http://", "https://", or the scheme-relative "//host" form, which also
// targets another host. Relative actions are ignored. The URL's host is not
// compared with the site's own domain, so every absolute action counts.
func hasExternalFormAction(content string) bool {
	for _, action := range htmlAttrValues(strings.ToLower(content), "form", "action", false) {
		switch {
		case strings.HasPrefix(action, "http://"),
			strings.HasPrefix(action, "https://"),
			strings.HasPrefix(action, "//"):
			return true
		}
	}
	return false
}
// hasHTMLCredentialInput reports whether the lowercased HTML appears to ask
// for credentials. It is true when an <input> tag has
//   - type "email" or "password",
//   - name "email", "pass", "password" or "login", or
//   - a placeholder starting with "email", "you@" or "your email",
//
// or when the text contains "work or school email" or "corporate email".
func hasHTMLCredentialInput(contentLower string) bool {
	// inputAttr collects the values of one attribute across all <input> tags.
	inputAttr := func(attr string) []string {
		return htmlAttrValues(contentLower, "input", attr, true)
	}

	for _, raw := range inputAttr("type") {
		if kind := strings.TrimSpace(raw); kind == "email" || kind == "password" {
			return true
		}
	}

	for _, raw := range inputAttr("name") {
		field := strings.TrimSpace(raw)
		if field == "email" || field == "pass" || field == "password" || field == "login" {
			return true
		}
	}

	for _, raw := range inputAttr("placeholder") {
		hint := strings.TrimSpace(raw)
		for _, prefix := range []string{"email", "you@", "your email"} {
			if strings.HasPrefix(hint, prefix) {
				return true
			}
		}
	}

	if strings.Contains(contentLower, "work or school email") {
		return true
	}
	return strings.Contains(contentLower, "corporate email")
}
// htmlAttrValues returns, in document order, the value of attrName for
// every <tagName ...> tag in contentLower that carries it (at most one
// value per tag). Matching is literal, so contentLower, tagName and
// attrName are expected to share the same case. allowUnquoted is forwarded
// to tagAttrValue and controls whether unquoted values are accepted.
// Scanning stops at the first tag that has no closing '>'.
func htmlAttrValues(contentLower, tagName, attrName string, allowUnquoted bool) []string {
	opener := "<" + tagName

	var found []string
	cursor := 0
	for cursor < len(contentLower) {
		rel := strings.Index(contentLower[cursor:], opener)
		if rel < 0 {
			return found
		}
		attrsFrom := cursor + rel + len(opener)

		// "<formx" must not be mistaken for "<form": require a boundary
		// character right after the tag name.
		if attrsFrom < len(contentLower) && !isTagBoundary(contentLower[attrsFrom]) {
			cursor = attrsFrom
			continue
		}

		closeAt := findTagEnd(contentLower, attrsFrom)
		if closeAt < 0 {
			return found
		}
		if v, ok := tagAttrValue(contentLower[attrsFrom:closeAt], attrName, allowUnquoted); ok {
			found = append(found, v)
		}
		cursor = closeAt + 1
	}
	return found
}
// isTagBoundary reports whether the byte c may directly follow a tag name:
// '>' (end of tag), '/' (self-closing) or a byte that unicode.IsSpace
// accepts when widened to a rune.
func isTagBoundary(c byte) bool {
	switch c {
	case '>', '/':
		return true
	}
	return unicode.IsSpace(rune(c))
}
// findTagEnd returns the index of the first '>' at or after start that is
// not inside a single- or double-quoted attribute value, or -1 when the
// tag is never closed (including when a quote is left open).
func findTagEnd(content string, start int) int {
	var open byte // quote character currently open, 0 when outside quotes
	for pos := start; pos < len(content); pos++ {
		ch := content[pos]
		switch {
		case open != 0:
			// Inside a quoted value only the matching quote matters.
			if ch == open {
				open = 0
			}
		case ch == '"' || ch == '\'':
			open = ch
		case ch == '>':
			return pos
		}
	}
	return -1
}
// tagAttrValue scans attrs, the text between a tag name and its closing
// '>', and returns the value of the first attribute whose name equals
// attrName, trimmed of surrounding whitespace. The boolean is false when
// no such attribute with a usable value exists.
//
// Quoted values (single or double) are always accepted. Unquoted values
// are accepted only when allowUnquoted is true; otherwise they are skipped
// and scanning continues with the next attribute. Attributes without '='
// are skipped. Name comparison is exact, so callers pass attrs and attrName
// in the same case. The scan is byte-wise and classifies whitespace by
// passing each byte to unicode.IsSpace.
func tagAttrValue(attrs, attrName string, allowUnquoted bool) (string, bool) {
for i := 0; i < len(attrs); {
// Skip whitespace and '/' separators before the attribute name.
for i < len(attrs) && (unicode.IsSpace(rune(attrs[i])) || attrs[i] == '/') {
i++
}
nameStart := i
// The name runs until '=', '>', '/' or whitespace.
for i < len(attrs) && attrs[i] != '=' && attrs[i] != '>' &&
attrs[i] != '/' && !unicode.IsSpace(rune(attrs[i])) {
i++
}
if nameStart == i {
// No name characters consumed (e.g. a stray '='): step over the
// byte so the loop always makes progress.
i++
continue
}
name := attrs[nameStart:i]
for i < len(attrs) && unicode.IsSpace(rune(attrs[i])) {
i++
}
if i >= len(attrs) || attrs[i] != '=' {
// Attribute without a value; i already points at whatever follows.
continue
}
i++
for i < len(attrs) && unicode.IsSpace(rune(attrs[i])) {
i++
}
if i >= len(attrs) {
// '=' was the last thing in the tag: nothing left to parse.
return "", false
}
value := ""
if attrs[i] == '"' || attrs[i] == '\'' {
quote := attrs[i]
i++
valueStart := i
for i < len(attrs) && attrs[i] != quote {
i++
}
// An unterminated quote yields everything up to the end of attrs.
value = attrs[valueStart:i]
if i < len(attrs) {
i++
}
} else {
valueStart := i
for i < len(attrs) && attrs[i] != '>' && !unicode.IsSpace(rune(attrs[i])) {
i++
}
if !allowUnquoted {
// The unquoted value has been consumed but is not reported,
// even if this is the requested attribute.
continue
}
value = attrs[valueStart:i]
}
if name == attrName {
return strings.TrimSpace(value), true
}
}
return "", false
}
// isSelfContainedHTML reports whether a lowercased page carries its own
// styling: it must contain a "<style" tag and reference at most one
// external stylesheet (counted as rel="stylesheet" or rel='stylesheet').
// One external sheet is tolerated because phishing kits often pull a single
// CDN stylesheet such as Font Awesome.
func isSelfContainedHTML(contentLower string) bool {
	if !strings.Contains(contentLower, "<style") {
		return false
	}
	linked := countOccurrences(contentLower, "rel=\"stylesheet\"")
	linked += countOccurrences(contentLower, "rel='stylesheet'")
	return linked <= 1
}
// looksLikePersonName reports whether a filename (without extension) looks
// like a person name written in CamelCase with two or more capitalized
// words, e.g. "PalmerHamilton" or "MarilynEsguerra".
//
// The name must be at least 6 bytes long, consist solely of letters, and
// contain at least two word starts, where a word start is an uppercase
// letter at the beginning or directly after a lowercase letter. A few
// common web file names (index, readme, ...) are excluded regardless of
// their capitalization.
//
// The previous letter is tracked as a rune, so names with non-ASCII letters
// such as "JoséMaria" are recognised correctly.
func looksLikePersonName(name string) bool {
	if len(name) < 6 {
		return false
	}

	wordStarts := 0
	var prev rune
	for i, r := range name {
		// Anything that is not a letter (digits, '-', '.', ...) rules the
		// name out immediately.
		if !unicode.IsLetter(r) {
			return false
		}
		if unicode.IsUpper(r) && (i == 0 || unicode.IsLower(prev)) {
			wordStarts++
		}
		prev = r
	}
	if wordStarts < 2 {
		return false
	}

	// Exclude common web filenames such as "ReadMe" or "ChangeLog".
	switch strings.ToLower(name) {
	case "index", "default", "portal", "login", "home", "main",
		"readme", "changelog", "license", "manifest", "service":
		return false
	}
	return true
}
// ---------------------------------------------------------------------------
// Layer 3: Directory anomaly detection
// ---------------------------------------------------------------------------
// analyzeDirectoryStructure checks if a directory looks like a phishing drop:
//   - it has no subdirectories,
//   - ignoring dotfiles, it holds 1-3 HTML files and at most one other file,
//   - its name looks like a business/organization name, and
//   - at least one of the HTML files passes quickPhishingCheck.
//
// It returns a High-severity "phishing_directory" finding for user when
// all conditions hold and nil otherwise (including when dir cannot be read).
func analyzeDirectoryStructure(dir string, user string) *alert.Finding {
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return nil
	}

	var htmlFiles []string
	otherFiles := 0
	for _, entry := range entries {
		if entry.IsDir() {
			return nil // Has subdirectories - likely not a simple phishing drop
		}
		name := entry.Name()
		if strings.HasPrefix(name, ".") {
			continue // Skip dotfiles (.htaccess etc.)
		}
		nameLower := strings.ToLower(name)
		if strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm") {
			htmlFiles = append(htmlFiles, name)
		} else {
			otherFiles++
		}
	}

	// Must have exactly 1-3 HTML files and at most 1 other file
	if len(htmlFiles) == 0 || len(htmlFiles) > 3 || otherFiles > 1 {
		return nil
	}

	// Directory name should look like a business/organization (CamelCase or multi-word)
	dirName := filepath.Base(dir)
	if !looksLikeBusinessName(dirName) {
		return nil
	}

	// Verify at least one HTML file has credential inputs (quick check)
	hasPhishingContent := false
	for _, htmlFile := range htmlFiles {
		if quickPhishingCheck(filepath.Join(dir, htmlFile)) {
			hasPhishingContent = true
			break
		}
	}
	if !hasPhishingContent {
		return nil
	}

	// Build indicators
	indicators := []string{
		fmt.Sprintf("directory '%s' contains only %d HTML file(s)", dirName, len(htmlFiles)),
		fmt.Sprintf("directory name resembles business/organization: '%s'", dirName),
	}
	for _, h := range htmlFiles {
		// Files were selected by lowercased suffix, so strip whatever
		// extension is present regardless of its case ("Name.HTML").
		baseName := strings.TrimSuffix(h, filepath.Ext(h))
		if looksLikePersonName(baseName) {
			indicators = append(indicators, fmt.Sprintf("HTML filename looks like person name: '%s'", baseName))
		}
	}

	return &alert.Finding{
		Severity: alert.High,
		Check:    "phishing_directory",
		Message:  fmt.Sprintf("Suspected phishing directory (lone HTML in business-named folder): %s", dir),
		Details: fmt.Sprintf("Account: %s\nDirectory: %s\nHTML files: %s\nIndicators:\n- %s",
			user, dirName, strings.Join(htmlFiles, ", "), strings.Join(indicators, "\n- ")),
		FilePath: dir,
	}
}
// looksLikeBusinessName checks if a directory name looks like a business or
// organization name rather than a standard web directory.
//
// Names shorter than 5 bytes, names starting with a tech/dev term, and
// standard web directory names are rejected. Of the rest, a name is
// accepted when it
//   - is CamelCase: starts with an uppercase letter and has at least one
//     lowercase-to-uppercase transition (e.g. "WashingtonGolf"),
//   - has two or more parts separated by '-' or '_'
//     (e.g. "federated-lighting"), or
//   - consists only of letters and is at least 12 bytes long
//     (e.g. "healthcornerpediattrics").
//
// Letters are examined as runes, so non-ASCII names such as "CaféParis"
// are classified correctly.
func looksLikeBusinessName(name string) bool {
	if len(name) < 5 {
		return false
	}
	nameLower := strings.ToLower(name)

	// Skip names that start with tech/dev terms - these are tutorial
	// or test directories, not business names (e.g. "php-email-form",
	// "PHP-Login", "JavaScript Login")
	techPrefixes := []string{
		"php", "javascript", "js-", "css", "html", "python",
		"java", "node", "react", "vue", "angular", "jquery",
		"bootstrap", "wordpress", "wp-", "laravel",
	}
	for _, prefix := range techPrefixes {
		if strings.HasPrefix(nameLower, prefix) {
			return false
		}
	}

	// Skip standard web directories
	switch nameLower {
	case "images", "img", "css", "js", "fonts", "assets", "static",
		"media", "uploads", "downloads", "files", "docs", "data",
		"api", "admin", "config", "templates", "scripts", "lib",
		"src", "dist", "build", "public", "private", "backup",
		"old", "new", "test", "dev", "staging", "demo":
		return false
	}

	// Single pass over the runes: record the first letter, count
	// lowercase-to-uppercase transitions and note any non-letter.
	var first, prev rune
	upperTransitions := 0
	allLetters := true
	for i, r := range name {
		if i == 0 {
			first = r
		} else if unicode.IsUpper(r) && unicode.IsLower(prev) {
			upperTransitions++
		}
		if !unicode.IsLetter(r) {
			allLetters = false
		}
		prev = r
	}

	// CamelCase detection (e.g., WashingtonGolf, XRFScientificAmericasInc)
	if upperTransitions >= 1 && unicode.IsUpper(first) {
		return true
	}

	// Multi-word with separators (e.g., federated-lighting, northwest_crawlspace)
	if strings.ContainsAny(name, "-_") {
		parts := strings.FieldsFunc(name, func(r rune) bool { return r == '-' || r == '_' })
		if len(parts) >= 2 {
			return true
		}
	}

	// Long letters-only name that doesn't match standard dirs
	// (e.g., "healthcornerpediattrics"). Case is not checked here.
	return allLetters && len(name) >= 12
}
// quickPhishingCheck reads the first 4KB of the HTML file at path and
// reports whether it has phishing shape: a credential input AND at least
// one phishing-kit signal in the page itself. The directory-anomaly
// heuristic uses it so that benign HTML which merely ships an <input> tag
// is not flagged.
//
// A credential form alone also matches developer demo pages, JavaScript
// login tutorials, contact forms and password-reset stubs, so one of the
// following is required in addition:
//   - a form action that posts to an external host (exfiltration target),
//   - a self-contained, inline-styled page (kits ship a single file),
//   - brand impersonation in the title or body (Office/PayPal/etc).
//
// This keeps real kits in scope while tutorials and trivial forms drop out
// without consulting any path-name allowlist. Files that cannot be opened
// or yield no bytes return false.
func quickPhishingCheck(path string) bool {
	file, err := osFS.Open(path)
	if err != nil {
		return false
	}
	defer func() { _ = file.Close() }()

	// The first 4KB is enough for the head, form attrs and brand strings.
	head := make([]byte, 4096)
	read, _ := file.Read(head)
	if read == 0 {
		return false
	}
	raw := string(head[:read])
	lowered := strings.ToLower(raw)

	// Gate: without credential intake nothing else matters.
	if !hasHTMLCredentialInput(lowered) {
		return false
	}

	if hasExternalFormAction(raw) || isSelfContainedHTML(lowered) {
		return true
	}

	title := extractTitle(lowered)
	for _, brand := range phishingBrands {
		if title != "" {
			for _, tp := range brand.titlePatterns {
				if strings.Contains(title, tp) {
					return true
				}
			}
		}
		for _, bp := range brand.bodyPatterns {
			if strings.Contains(lowered, bp) {
				return true
			}
		}
	}
	return false
}
// ---------------------------------------------------------------------------
// Layer 4: PHP phishing pages
// ---------------------------------------------------------------------------
// analyzePHPForPhishing scores the head of the PHP file at path for signs of
// a brand-impersonating credential harvester. At most phishingReadSize bytes
// are read and matched case-insensitively (the content is lower-cased once).
//
// Score contributions:
//   - +2 first matching $_POST/$_REQUEST credential read
//   - +3 brand title pattern, or +2 brand body pattern (first brand that hits)
//   - +1 per harvest indicator, +2 per exfiltration pattern
//   - +2 mail() alongside request data and credential keywords
//   - +2 fwrite()/file_put_contents() alongside request data and keywords
//
// It returns nil when the file cannot be opened or yields no bytes, when it
// neither handles posted credentials nor renders a credential form, when no
// brand matches, or when the total score is below 4.
func analyzePHPForPhishing(path string) *phishingResult {
	file, openErr := osFS.Open(path)
	if openErr != nil {
		return nil
	}
	defer func() { _ = file.Close() }()

	// A single read of the head window; the read error is ignored and only a
	// zero-byte result aborts the analysis.
	window := make([]byte, phishingReadSize)
	got, _ := file.Read(window)
	if got == 0 {
		return nil
	}
	lowered := strings.ToLower(string(window[:got]))

	var (
		indicators []string
		score      int
	)

	// Credential-specific request reads. Generic primitives such as mail()
	// and fwrite() are scored further down, gated on request data.
	credentialReads := [...]string{
		"$_post['email']",
		"$_post['password']",
		"$_post['pass']",
		"$_post[\"email\"]",
		"$_post[\"password\"]",
		"$_post[\"pass\"]",
		"$_request['email']",
		"$_request['password']",
	}
	handlesCredentials := false
	for _, needle := range credentialReads {
		if !strings.Contains(lowered, needle) {
			continue
		}
		handlesCredentials = true
		indicators = append(indicators, fmt.Sprintf("PHP credential handling: '%s'", needle))
		score += 2
		break
	}

	// Gate: the file must either read posted credentials or emit a form that
	// asks for them.
	rendersForm := strings.Contains(lowered, "<form") || strings.Contains(lowered, "<input")
	asksForCredentials := hasHTMLCredentialInput(lowered)
	if !handlesCredentials && !(rendersForm && asksForCredentials) {
		return nil
	}

	// Brand impersonation: title patterns are tried before body patterns for
	// each brand, and the scan stops at the first brand that yields a name.
	matchedBrand := ""
	title := extractTitle(lowered)
	for _, candidate := range phishingBrands {
		for _, tp := range candidate.titlePatterns {
			if strings.Contains(title, tp) {
				matchedBrand = candidate.name
				indicators = append(indicators, fmt.Sprintf("title impersonates '%s'", tp))
				score += 3
				break
			}
		}
		if matchedBrand != "" {
			break
		}
		for _, bp := range candidate.bodyPatterns {
			if strings.Contains(lowered, bp) {
				matchedBrand = candidate.name
				indicators = append(indicators, fmt.Sprintf("body impersonates '%s'", bp))
				score += 2
				break
			}
		}
		if matchedBrand != "" {
			break
		}
	}

	// Harvest and exfiltration markers accumulate; every hit is recorded.
	for _, needle := range harvestIndicators {
		if strings.Contains(lowered, needle) {
			score++
			indicators = append(indicators, fmt.Sprintf("harvest: '%s'", needle))
		}
	}
	for _, needle := range exfilPatterns {
		if strings.Contains(lowered, needle) {
			score += 2
			indicators = append(indicators, fmt.Sprintf("exfiltration: '%s'", needle))
		}
	}

	// PHP-side exfiltration only counts when the file also reads request
	// data: a kit has to capture the form before it can mail or store it, and
	// without this gate any script using fwrite() plus a common keyword hits.
	readsRequest := strings.Contains(lowered, "$_post") || strings.Contains(lowered, "$_request")
	mentionsCredentials := strings.Contains(lowered, "password") || strings.Contains(lowered, "email")
	if readsRequest && mentionsCredentials && strings.Contains(lowered, "mail(") {
		score += 2
		indicators = append(indicators, "PHP mail() with credential data")
	}
	writesFile := strings.Contains(lowered, "fwrite(") || strings.Contains(lowered, "file_put_contents(")
	mentionsStorage := mentionsCredentials ||
		strings.Contains(lowered, "result") || strings.Contains(lowered, "log")
	if readsRequest && writesFile && mentionsStorage {
		score += 2
		indicators = append(indicators, "PHP writes credential data to file")
	}

	// Brand impersonation is mandatory: contact forms, CMS user admin and
	// similar software legitimately combine $_POST['email'] with mail().
	if matchedBrand == "" || score < 4 {
		return nil
	}
	return &phishingResult{brand: matchedBrand, score: score, indicators: indicators}
}
// isKnownCMSFile reports whether nameLower exactly matches one of the stock
// CMS entry-point or configuration file names that are exempt from phishing
// analysis because they produce too many false positives. The comparison is
// case-sensitive, so the caller is expected to pass an already lower-cased
// name.
func isKnownCMSFile(nameLower string) bool {
	switch nameLower {
	case "index.php", "wp-config.php", "wp-login.php",
		"wp-cron.php", "wp-settings.php", "wp-load.php",
		"wp-blog-header.php", "wp-links-opml.php",
		"xmlrpc.php", "wp-signup.php", "wp-activate.php",
		"wp-trackback.php", "wp-comments-post.php",
		"wp-mail.php", "configuration.php",
		"config.php", "settings.php":
		return true
	}
	return false
}
// ---------------------------------------------------------------------------
// Layer 5: PHP open redirectors
// ---------------------------------------------------------------------------
// checkPHPRedirector reads up to the first 1024 bytes of the PHP file at path
// and reports whether it looks like a redirector. The content is lower-cased
// before matching.
//
// It returns a human-readable description when the file calls header() with a
// Location header and either (1) references a redirect-style request
// parameter, or (2) contains one of the exfilPatterns. It returns "" when the
// file cannot be opened, is empty, has no Location header, or matches neither
// case.
func checkPHPRedirector(path string) string {
	file, openErr := osFS.Open(path)
	if openErr != nil {
		return ""
	}
	defer func() { _ = file.Close() }()

	head := make([]byte, 1024)
	got, _ := file.Read(head)
	if got == 0 {
		return ""
	}
	lowered := strings.ToLower(string(head[:got]))

	if !strings.Contains(lowered, "header(") {
		return ""
	}
	if !strings.Contains(lowered, "location:") && !strings.Contains(lowered, "location :") {
		return ""
	}

	// Case 1: the redirect target is user-controlled. Merely having $_GET
	// somewhere next to header() is too broad -- ordinary form handlers read
	// $_POST and then redirect -- so only parameter names conventionally used
	// for redirect targets, or header() concatenated straight onto a request
	// superglobal, are accepted.
	userControlledTargets := [...]string{
		"$_get['url']", "$_get[\"url\"]",
		"$_get['redirect']", "$_get[\"redirect\"]",
		"$_get['r']", "$_get[\"r\"]",
		"$_get['return']", "$_get[\"return\"]",
		"$_get['next']", "$_get[\"next\"]",
		"$_get['goto']", "$_get[\"goto\"]",
		"$_get['link']", "$_get[\"link\"]",
		"$_request['url']", "$_request[\"url\"]",
		"$_request['redirect']", "$_request[\"redirect\"]",
		"header(\"location: \".$_get", "header(\"location: \".$_request",
		"header('location: '.$_get", "header('location: '.$_request",
		"header(\"location:\".$_get", "header('location:'.$_get",
	}
	for _, needle := range userControlledTargets {
		if strings.Contains(lowered, needle) {
			return "PHP open redirector: header(Location) with user-supplied URL"
		}
	}

	// Case 2: a hard-coded redirect whose content matches a suspicious
	// destination pattern.
	for _, needle := range exfilPatterns {
		if strings.Contains(lowered, needle) {
			return fmt.Sprintf("PHP redirect to suspicious destination matching '%s'", needle)
		}
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 6: Credential log files
// ---------------------------------------------------------------------------
// isCredentialLogName reports whether nameLower looks like a file phishing
// kits use to store harvested credentials: either one of a fixed set of exact
// file names, or any name containing one of a few tell-tale words. Matching
// is case-sensitive, so the caller is expected to pass a lower-cased name.
func isCredentialLogName(nameLower string) bool {
	// Exact file names commonly dropped by kits.
	switch nameLower {
	case "results.txt", "result.txt", "log.txt",
		"logs.txt", "emails.txt", "data.txt",
		"passwords.txt", "creds.txt", "credentials.txt",
		"victims.txt", "output.txt", "harvested.txt",
		"results.log", "emails.log", "data.log",
		"results.csv", "emails.csv", "data.csv",
		"results.html":
		return true
	}

	// Otherwise any occurrence of a suspicious word anywhere in the name.
	for _, fragment := range [...]string{"result", "victim", "harvest", "credential", "creds", "stolen"} {
		if strings.Contains(nameLower, fragment) {
			return true
		}
	}
	return false
}
// checkCredentialLog reads up to the first 4 KiB of the file at path and
// reports whether it looks like a list of harvested credentials.
//
// A non-blank line containing "@" counts as an email line. It additionally
// counts as a credential line when, for the first of the delimiters ":", "|",
// tab or "," that fits, the text before the delimiter contains "@" and the
// text after it is non-empty and has no spaces.
//
// It returns a description when there are at least 3 credential lines, or at
// least 10 email lines in a file whose path does not end in ".csv"; otherwise
// "" (also when the file cannot be opened or is empty).
func checkCredentialLog(path string) string {
	file, openErr := osFS.Open(path)
	if openErr != nil {
		return ""
	}
	defer func() { _ = file.Close() }()

	head := make([]byte, 4096)
	got, _ := file.Read(head)
	if got == 0 {
		return ""
	}

	delimiters := [...]string{":", "|", "\t", ","}
	pairLines, emailLines := 0, 0
	for _, raw := range strings.Split(string(head[:got]), "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || !strings.Contains(entry, "@") {
			continue
		}
		emailLines++

		// Look for "<email><delimiter><secret>"; the first delimiter that
		// fits decides, so a line is counted at most once.
		for _, sep := range delimiters {
			fields := strings.SplitN(entry, sep, 3)
			if len(fields) < 2 {
				continue
			}
			left := strings.TrimSpace(fields[0])
			right := strings.TrimSpace(fields[1])
			if strings.Contains(left, "@") && right != "" && !strings.Contains(right, " ") {
				pairLines++
				break
			}
		}
	}

	// Three or more pair-shaped lines are treated as a credential log.
	if pairLines >= 3 {
		return fmt.Sprintf("File contains %d credential-like lines (email:password format) out of %d email lines",
			pairLines, emailLines)
	}
	// A dense list of addresses alone is suspicious unless the file is a CSV.
	isCSV := strings.HasSuffix(strings.ToLower(path), ".csv")
	if emailLines >= 10 && !isCSV {
		return fmt.Sprintf("File contains %d email addresses - possible harvested email list", emailLines)
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 7: Iframe phishing
// ---------------------------------------------------------------------------
// checkIframePhishing reads up to the first 3000 bytes of the HTML file at
// path and reports iframe-based phishing: a page whose first <iframe> loads
// an external http(s) URL and is either displayed full-screen or points at a
// URL matching one of the exfilPatterns. Matching is done on the lower-cased
// content, so the URL in the returned description is lower-cased too.
//
// Only a quoted src attribute on the first <iframe> tag is considered. It
// returns "" when the file cannot be opened, is empty, or nothing matches.
func checkIframePhishing(path string) string {
	file, openErr := osFS.Open(path)
	if openErr != nil {
		return ""
	}
	defer func() { _ = file.Close() }()

	head := make([]byte, 3000)
	got, _ := file.Read(head)
	if got == 0 {
		return ""
	}
	lowered := strings.ToLower(string(head[:got]))

	// Isolate the first "<iframe ... >" opening tag.
	tagStart := strings.Index(lowered, "<iframe")
	if tagStart < 0 {
		return ""
	}
	fromTag := lowered[tagStart:]
	tagClose := strings.Index(fromTag, ">")
	if tagClose < 0 {
		return ""
	}
	tag := fromTag[:tagClose+1]

	// Pull the quoted value that follows the first "src=" in the tag.
	attr := strings.Index(tag, "src=")
	if attr < 0 {
		return ""
	}
	value := tag[attr+4:]
	if value == "" {
		return ""
	}
	quote := value[0]
	if quote != '"' && quote != '\'' {
		return ""
	}
	closing := strings.IndexByte(value[1:], quote)
	if closing < 0 {
		return ""
	}
	src := value[1 : closing+1]

	// Only externally hosted frames are interesting.
	if !strings.HasPrefix(src, "http://") && !strings.HasPrefix(src, "https://") {
		return ""
	}

	// Full-screen: "100%" inside the tag itself, or viewport-covering CSS
	// anywhere in the scanned content.
	fullscreen := strings.Contains(tag, "100%")
	for _, css := range [...]string{"width:100%", "width: 100%", "position:fixed", "position: fixed"} {
		if fullscreen {
			break
		}
		fullscreen = strings.Contains(lowered, css)
	}
	if fullscreen {
		return fmt.Sprintf("Full-screen iframe loading external URL: %s", src)
	}

	// A smaller frame is still reported when its URL matches a known
	// suspicious destination pattern.
	for _, needle := range exfilPatterns {
		if strings.Contains(src, needle) {
			return fmt.Sprintf("Iframe loading suspicious external URL matching '%s': %s", needle, src)
		}
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 8: Phishing kit ZIP archives
// ---------------------------------------------------------------------------
// isPhishingKitZip reports whether a ZIP file name looks like a phishing kit
// archive. A single high-confidence keyword (an impersonated brand, "phish",
// "scam") is enough; the lower-confidence keywords need at least two distinct
// hits, because words such as "login" or "secure" also appear in legitimate
// archive names. Matching is case-sensitive substring search, so the caller
// is expected to pass a lower-cased name.
func isPhishingKitZip(nameLower string) bool {
	// High-confidence keywords: one hit flags the archive.
	for _, keyword := range [...]string{
		"office365", "office 365", "sharepoint", "onedrive",
		"microsoft", "outlook", "gmail",
		"dropbox", "docusign", "wetransfer",
		"paypal", "icloud", "netflix",
		"facebook", "instagram", "linkedin",
		"roundcube", "cpanel",
		"phish", "scam",
	} {
		if strings.Contains(nameLower, keyword) {
			return true
		}
	}

	// Lower-confidence keywords: count how many different ones occur.
	hits := 0
	for _, keyword := range [...]string{
		"login", "verify", "secure", "bank",
		"google", "apple", "adobe", "webmail",
	} {
		if strings.Contains(nameLower, keyword) {
			hits++
		}
	}
	return hits >= 2
}
// ---------------------------------------------------------------------------
// Safe directory list
// ---------------------------------------------------------------------------
// isKnownSafeDir reports whether name exactly matches an entry of the
// built-in safe directory list. The comparison is case-sensitive.
func isKnownSafeDir(name string) bool {
	switch name {
	case "wp-admin", "wp-includes", "wp-content",
		"node_modules", "vendor", ".git",
		"cgi-bin", ".well-known", "mail",
		"cache", "tmp", "logs":
		return true
	}
	return false
}
package checks
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// nestedEvalDecodeRe matches the PHP token sequence
// `<eval|assert> [ws] ( [ws] [@] [\] <ident> [ws] (`, with DOTALL so the
// source can have line breaks inside the whitespace gaps. Attackers wedge
// comments or common call modifiers between the sink and decoder; comments and
// strings are stripped before matching so only executable token structure is
// evaluated.
var nestedEvalDecodeRe = regexp.MustCompile(`(?is)\b(eval|assert)\s*\(\s*@?\s*\\?\s*(\w+)\s*\(`)
// reEvalVarCallee matches eval wrapping a variable function call,
// e.g. `eval($f(...))`. The literal-callee form above cannot capture a
// `$var` callee, yet eval'ing the result of a dynamic function call is a
// near-certain dropper signal in user web directories.
var reEvalVarCallee = regexp.MustCompile(`(?is)\beval\s*\(\s*@?\s*\$\w+\s*\(`)
// evalExecWrapInner lists code-construction primitives that, when wrapped
// directly by eval, indicate dynamic code execution rather than the
// decoder/decompressor chain nestedEvalDecodeRe already covers.
var evalExecWrapInner = map[string]struct{}{
"create_function": {},
"call_user_func": {},
"call_user_func_array": {},
}
// callbackFirstArgFuncs lists the functions whose calls hasCallbackExecName
// inspects: when such a call's first argument is a quoted string literal, that
// literal is read as the callback name and checked against callbackExecNames
// and callbackDecoderNames. Keys are lower-case; lookups lower-case the
// function name first.
var callbackFirstArgFuncs = map[string]struct{}{
"array_map": {},
"call_user_func": {},
"call_user_func_array": {},
"register_shutdown_function": {},
"register_tick_function": {},
}
// callbackExecNames are names that execute code when used as a callback. They
// are RCE regardless of where the call's arguments come from, so they flag
// unconditionally.
var callbackExecNames = map[string]struct{}{
"assert": {},
"call_user_func": {},
"create_function": {},
"eval": {},
"exec": {},
"passthru": {},
"popen": {},
"proc_open": {},
"shell_exec": {},
"system": {},
}
// callbackDecoderNames are decode/decompress primitives. As a callback they
// only transform data (array_map('base64_decode', $data) returns decoded
// bytes, it executes nothing), so legitimate plugins use them constantly. They
// only signal a dropper when the same call is fed request input; the
// decode-then-eval shape is covered separately by the eval-chain detectors.
var callbackDecoderNames = map[string]struct{}{
"base64_decode": {},
"gzinflate": {},
"gzuncompress": {},
"str_rot13": {},
}
// reVarVarCall matches a variable-variable or dynamic-expression function
// invocation, e.g. `$$h(...)` or `${$x}(...)`. On its own this shows up in
// some dispatcher code, so it is only treated as an indicator when the same
// line also carries a request superglobal (the RCE shape).
var reVarVarCall = regexp.MustCompile(`(?:\$\$\w+|\$\{[^}]{1,64}\})\s*\(`)
// includeDangerWrappers are stream wrappers / remote schemes that, as an
// include/require target, mean remote-file inclusion or php://input code
// execution. Matched on the comment-stripped (strings preserved) source.
var includeDangerWrappers = []string{"data://", "php://", "phar://", "http://", "https://", "ftp://"}
// includeKeywords are the PHP include/require statement keywords recognised by
// includeKeywordEnd, which compares them case-insensitively in this order.
var includeKeywords = []string{"include_once", "require_once", "include", "require"}
// Lower-case function-name sets in the shape nextStandalonePHPCall expects for
// its names argument. pregReplaceCallName drives hasPregReplaceEvalWithRequest
// and codeEvalPrimitiveCallNames drives hasCodeEvalPrimitiveWithRequest.
// NOTE(review): no use of callUserFuncCallNames is visible near these
// declarations -- confirm where it is consumed.
var (
pregReplaceCallName = map[string]struct{}{"preg_replace": {}}
codeEvalPrimitiveCallNames = map[string]struct{}{"assert": {}, "create_function": {}}
callUserFuncCallNames = map[string]struct{}{"call_user_func": {}, "call_user_func_array": {}}
)
// phpContentReadSize is the size of the head window read from a PHP file for
// content analysis. Most real PHP is far smaller, and the window is large
// enough that an attacker cannot cheaply hide a payload behind benign padding.
const phpContentReadSize = 1 << 20 // 1 MiB head window

// phpContentTailSize is the size of the tail window that is additionally
// scanned for files larger than the head window, so that a payload appended
// after more than 1 MiB of padding is still seen. Total memory stays bounded
// by head+tail.
const phpContentTailSize = 64 << 10 // 64 KiB tail window

// readPHPContentWindows returns the byte ranges of a PHP file to analyse.
//
// head holds the first min(size, phpContentReadSize) bytes; a negative size is
// treated as unknown and the full head window is requested. tail is filled
// only when size exceeds the head window: it covers the last
// phpContentTailSize bytes, clamped so that it never starts inside the head
// window. tailOffset is the file offset of tail, or -1 when there is no tail.
//
// readOK is false only when ReadAt fails with something other than io.EOF; in
// that case head and tail are nil and tailOffset is -1. A short read ending in
// io.EOF simply yields shorter slices.
func readPHPContentWindows(f io.ReaderAt, size int64) (head, tail []byte, tailOffset int64, readOK bool) {
	const headCap = int64(phpContentReadSize)

	want := headCap
	if size >= 0 && size < want {
		want = size
	}
	headBuf := make([]byte, want)
	got, err := f.ReadAt(headBuf, 0)
	if err != nil && !errors.Is(err, io.EOF) {
		return nil, nil, -1, false
	}
	head = headBuf[:got]

	// Files that fit in the head window have no tail.
	if size <= headCap {
		return head, nil, -1, true
	}

	// Start the tail phpContentTailSize bytes before EOF, but never before the
	// end of the head window, so the two windows do not overlap.
	tailStart := size - int64(phpContentTailSize)
	if tailStart < headCap {
		tailStart = headCap
	}
	span := size - tailStart
	if span <= 0 {
		return head, nil, -1, true
	}

	tailBuf := make([]byte, span)
	tailGot, tailErr := f.ReadAt(tailBuf, tailStart)
	if tailErr != nil && !errors.Is(tailErr, io.EOF) {
		return nil, nil, -1, false
	}
	return head, tailBuf[:tailGot], tailStart, true
}
// combinePHPContentWindows returns head followed by tail as one byte slice.
// When tail is empty, head itself is returned (no copy); otherwise the result
// is a freshly allocated slice that shares no storage with either input.
func combinePHPContentWindows(head, tail []byte) []byte {
	if len(tail) == 0 {
		return head
	}
	joined := make([]byte, len(head)+len(tail))
	written := copy(joined, head)
	copy(joined[written:], tail)
	return joined
}
// phpCodeOnlyWindows concatenates the head and tail windows with
// combinePHPContentWindows and returns the result of running phpCodeOnly over
// the combined text.
func phpCodeOnlyWindows(head, tail []byte) string {
return phpCodeOnly(string(combinePHPContentWindows(head, tail)))
}
// hasBacktickSuperglobal reports whether code contains a backtick-delimited
// span whose body contains a request superglobal (as decided by
// containsRequestSuperglobal). Quoted strings are skipped, so backticks inside
// string literals are ignored; a backslash inside the span escapes the next
// byte. An unterminated backtick span ends the scan without a match.
func hasBacktickSuperglobal(code string) bool {
	size := len(code)
	for pos := 0; pos < size; pos++ {
		ch := code[pos]
		if isPHPQuote(ch) {
			pos = skipPHPString(code, pos)
			continue
		}
		if ch != '`' {
			continue
		}

		// Find the closing backtick, stepping over backslash escapes.
		closing := pos + 1
		for closing < size && code[closing] != '`' {
			if code[closing] == '\\' && closing+1 < size {
				closing += 2
			} else {
				closing++
			}
		}
		if closing >= size {
			return false
		}
		if containsRequestSuperglobal(code[pos+1 : closing]) {
			return true
		}
		pos = closing
	}
	return false
}
// hasCallbackExecName reports whether code calls one of the
// callbackFirstArgFuncs with a quoted string literal as its first argument
// that names a dangerous callback. A name in callbackExecNames flags
// unconditionally; a name in callbackDecoderNames flags only when the call's
// parenthesised argument span also contains a request superglobal expression.
// Quoted strings outside the inspected position are skipped.
func hasCallbackExecName(code string) bool {
for i := 0; i < len(code); i++ {
if isPHPQuote(code[i]) {
i = skipPHPString(code, i)
continue
}
// Locate the start of a candidate function name, allowing a leading
// backslash (fully-qualified global call such as \array_map).
nameStart := i
if code[i] == '\\' {
if i+1 >= len(code) || !isPHPIdentifierStart(code[i+1]) || !canStartGlobalPHPFunction(code, i) {
continue
}
nameStart = i + 1
} else if !isPHPIdentifierStart(code[i]) || !canStartPHPFunctionName(code, i) {
continue
}
nameEnd := nameStart + 1
for nameEnd < len(code) && isPHPIdentifierPart(code[nameEnd]) {
nameEnd++
}
// Every rejection below resumes scanning right after this identifier.
name := strings.ToLower(code[nameStart:nameEnd])
if _, ok := callbackFirstArgFuncs[name]; !ok {
i = nameEnd - 1
continue
}
// The name must be followed (after optional whitespace) by "(".
openParen := skipPHPWhitespace(code, nameEnd)
if openParen >= len(code) || code[openParen] != '(' {
i = nameEnd - 1
continue
}
// Only a string-literal first argument is inspected.
firstArg := skipPHPWhitespace(code, openParen+1)
if firstArg >= len(code) || !isPHPQuote(code[firstArg]) {
i = nameEnd - 1
continue
}
// NOTE(review): callbackName is looked up in lower-case-keyed maps, so
// readPHPFunctionString presumably normalises case -- confirm.
callbackName, _, ok := readPHPFunctionString(code, firstArg)
if !ok {
i = nameEnd - 1
continue
}
if _, dangerous := callbackExecNames[callbackName]; dangerous {
return true
}
if _, decoder := callbackDecoderNames[callbackName]; decoder {
// matchingParen returns len(code) for an unbalanced call, so the slice
// below then runs to the end of the source.
closeParen := matchingParen(code, openParen)
if containsRequestSuperglobalExpression(code[openParen:closeParen]) {
return true
}
}
i = nameEnd - 1
}
return false
}
// matchingParen returns the index of the ")" that balances the "(" at
// openParen, or len(code) when the parentheses never balance. Quoted strings
// are stepped over so that parentheses inside string literals do not affect
// the nesting count.
func matchingParen(code string, openParen int) int {
	open := 0
	for pos := openParen; pos < len(code); pos++ {
		ch := code[pos]
		if isPHPQuote(ch) {
			pos = skipPHPString(code, pos)
			continue
		}
		if ch == '(' {
			open++
			continue
		}
		if ch != ')' {
			continue
		}
		open--
		if open == 0 {
			return pos
		}
	}
	return len(code)
}
// hasDangerousInclude reports whether code contains an include/require whose
// target expression is request input (per containsIncludeTargetExpression) or
// mentions one of the includeDangerWrappers -- the LFI/RFI and php://input
// code-execution shapes. Only the include expression itself is examined, not
// the rest of the line, so unrelated request reads or URLs that follow a
// static include do not trip the detector.
func hasDangerousInclude(code string) bool {
	for cursor := 0; cursor < len(code); {
		start, end, found := nextIncludeExpression(code, cursor)
		if !found {
			return false
		}
		target := code[start:end]
		if containsIncludeTargetExpression(target) {
			return true
		}
		// Wrapper schemes are matched case-insensitively.
		lowered := strings.ToLower(target)
		for _, scheme := range includeDangerWrappers {
			if strings.Contains(lowered, scheme) {
				return true
			}
		}
		// Resume after this expression; the loop ends once it reaches EOF.
		cursor = end
	}
	return false
}
// hasCodeEvalPrimitiveWithRequest reports whether any line of code contains a
// call to assert() or create_function() whose parenthesised argument span
// reads a request superglobal. assert() evaluates a string argument as PHP
// (pre-8.0) and create_function() evals its body argument, so either is RCE
// when driven by attacker input. Each line is examined on its own.
func hasCodeEvalPrimitiveWithRequest(code string) bool {
	for _, line := range strings.Split(code, "\n") {
		for from := 0; from < len(line); {
			_, openParen, closeParen, found := nextStandalonePHPCall(line, from, codeEvalPrimitiveCallNames)
			if !found {
				break
			}
			// Include the closing parenthesis when the call is balanced; an
			// unbalanced call already extends to the end of the line.
			stop := closeParen
			if stop < len(line) {
				stop++
			}
			if containsRequestSuperglobal(line[openParen:stop]) {
				return true
			}
			from = stop
		}
	}
	return false
}
// hasPregReplaceEvalWithRequest reports a preg_replace() call whose literal
// pattern carries the /e modifier and whose replacement or subject argument
// reads request input, either directly or through a variable tainted earlier
// in the source.
//
// The /e modifier (removed in PHP 7.0) evaluates the replacement as PHP after
// backreferences from the subject are interpolated, so request input in the
// pattern alone is not treated as a code-execution sink. A bare /e with static
// arguments is the legacy WordPress serialize-fix / autolink idiom -- real /e
// usage but not a dropper -- hence the request-input correlation, mirroring
// how the shell, include and assert detectors gate their sinks.
//
// Calls are located with nextStandalonePHPCall, which skips string literals,
// so a documentation example inside a string does not trip the detector.
// Calls with fewer than three arguments are ignored.
func hasPregReplaceEvalWithRequest(code string) bool {
	from := 0
	for {
		callStart, openParen, closeParen, found := nextStandalonePHPCall(code, from, pregReplaceCallName)
		if !found {
			return false
		}
		from = nextSearchOffset(closeParen, len(code))

		args := phpCallArguments(code, openParen+1, closeParen)
		if len(args) < 3 || !pregPatternArgumentHasEvalModifier(args[0]) {
			continue
		}
		tainted := requestTaintedVariablesBefore(code, callStart)
		if pregReplacementReadsRequest(args[1], tainted) || phpExpressionReadsRequest(args[2], tainted) {
			return true
		}
	}
}
// pregPatternArgumentHasEvalModifier reports whether expr, the first argument
// of a preg_replace() call, starts with a quoted string literal whose body is
// a pattern carrying the "e" modifier. Non-literal arguments yield false.
func pregPatternArgumentHasEvalModifier(expr string) bool {
	literal := strings.TrimSpace(expr)
	if literal == "" || !isPHPQuote(literal[0]) {
		return false
	}
	// skipPHPString yields the closing-quote index, or the last index for an
	// unterminated literal. An index of 0 means the opening quote is the only
	// byte, leaving no body to inspect.
	if closing := skipPHPString(literal, 0); closing > 0 {
		return pregPatternHasEvalModifier(literal[1:closing])
	}
	return false
}
// pregReplacementReadsRequest reports whether the replacement argument expr
// reads request input. The expression is checked as written first; when that
// finds nothing and expr is a single string literal, the literal's body is
// checked as well.
func pregReplacementReadsRequest(expr string, tainted map[string]struct{}) bool {
	if phpExpressionReadsRequest(expr, tainted) {
		return true
	}
	if body, _, isLiteral := phpStringLiteralExpression(expr); isLiteral {
		return phpExpressionReadsRequest(body, tainted)
	}
	return false
}
// phpExpressionReadsRequest reports whether expr contains a request
// superglobal expression or references a variable present in tainted.
func phpExpressionReadsRequest(expr string, tainted map[string]struct{}) bool {
return containsRequestSuperglobalExpression(expr) || phpExpressionReferencesTaintedVariable(expr, tainted)
}
// phpExpressionReferencesTaintedVariable reports whether expr mentions, outside
// of quoted strings, a "$name" variable whose name (as returned by
// readPHPVariableName) is a key of tainted. An empty or nil tainted set always
// yields false.
func phpExpressionReferencesTaintedVariable(expr string, tainted map[string]struct{}) bool {
	if len(tainted) == 0 {
		return false
	}
	for pos := 0; pos < len(expr); pos++ {
		ch := expr[pos]
		if isPHPQuote(ch) {
			pos = skipPHPString(expr, pos)
			continue
		}
		if ch != '$' {
			continue
		}
		name, after, parsed := readPHPVariableName(expr, pos)
		if !parsed {
			continue
		}
		if _, isTainted := tainted[name]; isTainted {
			return true
		}
		// Resume right after the variable name.
		pos = after - 1
	}
	return false
}
// requestTaintedVariablesBefore scans code[:limit] and returns the set of
// variable names (as returned by readPHPVariableName) that are currently
// tainted by request input at that point: assigned from an expression that
// reads a request superglobal or an already-tainted variable.
//
// Taint is tracked per function scope. A `function` keyword marks the next
// "{" as a function body, which gets its own empty taint set; the matching "}"
// discards it. The returned map is the set of the innermost scope still open
// at limit. It returns nil when limit <= 0; limit is clamped to len(code).
func requestTaintedVariablesBefore(code string, limit int) map[string]struct{} {
if limit > len(code) {
limit = len(code)
}
if limit <= 0 {
return nil
}
scan := code[:limit]
// taintStack[0] is the file-level scope; one entry is pushed per open
// function body. functionBraceStack records, for every open "{", whether it
// opened a function body (and therefore owns a taintStack entry).
taintStack := []map[string]struct{}{{}}
functionBraceStack := []bool{}
pendingFunctionScope := false
for i := 0; i < len(scan); i++ {
if isPHPQuote(scan[i]) {
i = skipPHPString(scan, i)
continue
}
if end, ok := phpKeywordAt(scan, i, "function"); ok {
pendingFunctionScope = true
i = end - 1
continue
}
switch scan[i] {
case '{':
functionBraceStack = append(functionBraceStack, pendingFunctionScope)
if pendingFunctionScope {
taintStack = append(taintStack, map[string]struct{}{})
}
pendingFunctionScope = false
continue
case '}':
if len(functionBraceStack) > 0 {
last := len(functionBraceStack) - 1
// Never pop the file-level scope, even on unbalanced input.
if functionBraceStack[last] && len(taintStack) > 1 {
taintStack = taintStack[:len(taintStack)-1]
}
functionBraceStack = functionBraceStack[:last]
}
pendingFunctionScope = false
continue
case ';':
// A `function` keyword followed by ";" before any "{" has no body here,
// so it must not claim a later, unrelated brace.
pendingFunctionScope = false
}
if scan[i] != '$' {
continue
}
variable, next, ok := readPHPVariableName(scan, i)
// Assignments to the request superglobals themselves are not tracked.
if !ok || isRequestSuperglobalVariable(variable) {
continue
}
j := skipPHPWhitespace(scan, next)
opLen, directAssign, appendAssign, ok := phpAssignmentOperator(scan, j)
if !ok {
// Not an assignment: resume after the variable name.
i = next - 1
continue
}
exprStart := skipPHPWhitespace(scan, j+opLen)
exprEnd := phpExpressionEnd(scan, exprStart)
expr := scan[exprStart:exprEnd]
tainted := taintStack[len(taintStack)-1]
exprTainted := phpExpressionReadsRequest(expr, tainted)
switch {
case directAssign:
// Plain assignment replaces the value: taint or clear accordingly.
if exprTainted {
tainted[variable] = struct{}{}
} else {
delete(tainted, variable)
}
case appendAssign:
// Appending keeps any earlier content, so it can add taint but never
// remove it.
if exprTainted {
tainted[variable] = struct{}{}
}
default:
// Any other operator accepted by phpAssignmentOperator is handled like a
// plain assignment.
if exprTainted {
tainted[variable] = struct{}{}
} else {
delete(tainted, variable)
}
}
// Skip the right-hand side; it has already been evaluated for taint.
i = exprEnd - 1
}
return taintStack[len(taintStack)-1]
}
// phpKeywordAt reports whether keyword occurs at code[start:] as a standalone
// word, compared case-insensitively, and returns the index just past it. The
// match is rejected when the preceding byte is an identifier character, "$",
// ">", ":" or "\" (variable, member, static or namespaced names), or when an
// identifier character follows. On rejection it returns (0, false).
func phpKeywordAt(code string, start int, keyword string) (int, bool) {
	stop := start + len(keyword)
	if stop > len(code) || !strings.EqualFold(code[start:stop], keyword) {
		return 0, false
	}
	if start > 0 {
		switch before := code[start-1]; {
		case isPHPIdentifierPart(before),
			before == '$', before == '>', before == ':', before == '\\':
			return 0, false
		}
	}
	if stop < len(code) && isPHPIdentifierPart(code[stop]) {
		return 0, false
	}
	return stop, true
}
// isRequestSuperglobalVariable reports whether variable, given without its
// leading "$", names one of the request superglobals _REQUEST, _POST, _GET,
// _COOKIE or _SERVER. The comparison ignores case.
func isRequestSuperglobalVariable(variable string) bool {
	lowered := strings.ToLower(variable)
	for _, known := range [...]string{"_request", "_post", "_get", "_cookie", "_server"} {
		if lowered == known {
			return true
		}
	}
	return false
}
// phpStringLiteralExpression checks whether expr, ignoring surrounding
// whitespace, consists of exactly one quoted string literal. On success it
// returns the text between the quotes, the quote byte, and true. It returns
// ("", 0, false) when expr does not start with a quote, the literal is
// unterminated, or anything other than whitespace follows the closing quote.
func phpStringLiteralExpression(expr string) (string, byte, bool) {
	literal := strings.TrimSpace(expr)
	if literal == "" {
		return "", 0, false
	}
	quote := literal[0]
	if !isPHPQuote(quote) {
		return "", 0, false
	}
	closing := skipPHPString(literal, 0)
	terminated := closing > 0 && closing < len(literal) && literal[closing] == quote
	if !terminated {
		return "", 0, false
	}
	// Nothing but whitespace may follow the closing quote.
	if skipPHPWhitespace(literal, closing+1) != len(literal) {
		return "", 0, false
	}
	return literal[1:closing], quote, true
}
// phpCallArguments splits code[start:end] -- the text between a call's
// parentheses -- into its top-level arguments. Commas nested inside (), [] or
// {} and commas inside quoted strings do not split. Each argument is returned
// with surrounding whitespace trimmed.
//
// start and end are clamped to [0, len(code)]; nil is returned when start >
// end or when the span is blank and contains no top-level comma. Once at least
// one comma has been seen, a blank final argument is kept as "".
func phpCallArguments(code string, start, end int) []string {
	if start < 0 {
		start = 0
	}
	if end > len(code) {
		end = len(code)
	}
	if start > end {
		return nil
	}
	var args []string
	argStart := skipPHPWhitespace(code, start)
	depth := 0
	for i := start; i < end; i++ {
		if isPHPQuote(code[i]) {
			i = skipPHPString(code, i)
			continue
		}
		switch code[i] {
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			// Tolerate unbalanced closers instead of going negative.
			if depth > 0 {
				depth--
			}
		case ',':
			if depth == 0 {
				args = append(args, strings.TrimSpace(code[argStart:i]))
				argStart = skipPHPWhitespace(code, i+1)
			}
		}
	}
	// skipPHPWhitespace is not bounded by end: when the whitespace after the
	// last comma (or after start) continues past end, argStart lands beyond it
	// and the slice below would panic. Clamp so the final argument is simply
	// empty in that case.
	if argStart > end {
		argStart = end
	}
	if tail := strings.TrimSpace(code[argStart:end]); tail != "" || len(args) > 0 {
		args = append(args, tail)
	}
	return args
}
// pregPatternHasEvalModifier returns true when the PCRE pattern string carries
// an "e" modifier after its closing delimiter.
//
// pat is the body of the pattern string literal (without the PHP quotes).
// Whitespace before the opening delimiter is skipped, because PHP silently
// ignores it: ' /x/e' is a live /e pattern and must not slip past the check.
// Bracket-style opening delimiters pair with their closing counterpart. The
// modifier run is scanned from the last closing delimiter and must consist of
// ASCII letters up to the "e".
func pregPatternHasEvalModifier(pat string) bool {
	// Match PHP's behaviour of ignoring whitespace ahead of the delimiter.
	pat = strings.TrimLeft(pat, " \t\n\r\v\f")
	if len(pat) < 2 {
		return false
	}
	open := pat[0]
	// PHP forbids alphanumeric and backslash delimiters; whitespace cannot be
	// the first byte any more after the trim above.
	if isPHPIdentifierPart(open) || open == '\\' {
		return false
	}
	closeDelim := open
	switch open {
	case '(':
		closeDelim = ')'
	case '[':
		closeDelim = ']'
	case '{':
		closeDelim = '}'
	case '<':
		closeDelim = '>'
	}
	// idx == 0 means the only occurrence is the opening delimiter itself.
	idx := strings.LastIndexByte(pat, closeDelim)
	if idx <= 0 {
		return false
	}
	for _, c := range pat[idx+1:] {
		// Anything other than an ASCII letter ends the modifier run.
		if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') {
			return false
		}
		if c == 'e' {
			return true
		}
	}
	return false
}
// nextIncludeExpression finds the first include/require keyword at or after
// searchFrom, skipping quoted strings, and returns the bounds of the
// expression that follows it: the start is the first non-whitespace byte after
// the keyword and the end is given by phpExpressionEnd. It returns
// (0, 0, false) when no keyword is found.
func nextIncludeExpression(line string, searchFrom int) (int, int, bool) {
	for pos := searchFrom; pos < len(line); pos++ {
		if isPHPQuote(line[pos]) {
			pos = skipPHPString(line, pos)
			continue
		}
		if after, isKeyword := includeKeywordEnd(line, pos); isKeyword {
			begin := skipPHPWhitespace(line, after)
			return begin, phpExpressionEnd(line, begin), true
		}
	}
	return 0, 0, false
}
// includeKeywordEnd reports whether one of the includeKeywords starts at
// line[start] as a standalone word (case-insensitive) and returns the index
// just past it. The position must pass canStartIncludeKeyword and the keyword
// must not be followed by an identifier character.
func includeKeywordEnd(line string, start int) (int, bool) {
	if !canStartIncludeKeyword(line, start) {
		return 0, false
	}
	for _, kw := range includeKeywords {
		stop := start + len(kw)
		if stop > len(line) {
			continue
		}
		if !strings.EqualFold(line[start:stop], kw) {
			continue
		}
		// Accept only when the keyword ends the input or a non-identifier
		// byte follows; otherwise it is the prefix of a longer name.
		if stop == len(line) || !isPHPIdentifierPart(line[stop]) {
			return stop, true
		}
	}
	return 0, false
}
// canStartIncludeKeyword reports whether an include/require keyword may begin
// at line[start]: either at the very start of the input, or when the
// preceding byte is not an identifier character, "$", ">", ":" or "\" (which
// would make the word part of a variable, member, static or namespaced name).
func canStartIncludeKeyword(line string, start int) bool {
	if start == 0 {
		return true
	}
	switch before := line[start-1]; before {
	case '$', '>', ':', '\\':
		return false
	default:
		return !isPHPIdentifierPart(before)
	}
}
// phpExpressionEnd returns the index at which the expression beginning at
// start ends: the first ";", "," or unmatched closing bracket found outside
// any nested (), [] or {} and outside quoted strings. It returns len(code)
// when no terminator is found.
func phpExpressionEnd(code string, start int) int {
	nesting := 0
	for pos := start; pos < len(code); pos++ {
		ch := code[pos]
		if isPHPQuote(ch) {
			pos = skipPHPString(code, pos)
			continue
		}
		switch ch {
		case '(', '[', '{':
			nesting++
		case ')', ']', '}':
			// A closer with nothing open belongs to the enclosing construct.
			if nesting == 0 {
				return pos
			}
			nesting--
		case ';', ',':
			if nesting == 0 {
				return pos
			}
		}
	}
	return len(code)
}
// phpCodeOnly blanks the inline-HTML regions of a PHP source so only the code
// inside <?php ... ?> (and <?= ... ?>) spans is analysed for execution sinks.
// Inline HTML is literal output and cannot execute PHP, so scanning it as code
// only yields false positives: apostrophes in prose ("don't", "you're") desync
// the string scanner, and href URLs, JS backtick template literals, and English
// words like "include"/"require" in markup then read as PHP execution sinks.
//
// HTML bytes become spaces (newlines kept so line-oriented detectors keep their
// line structure); a closing "?>" becomes "; " so it still bounds the preceding
// statement and consecutive <?php?> blocks do not run their expressions
// together. The "?>" scan skips PHP strings, heredoc/nowdoc bodies, and block
// comments so a "?>" inside them does not end PHP mode and blank real code. A
// file with no PHP open tag yields all blanks -- it executes nothing.
//
// The function alternates between two modes until the input is exhausted:
// HTML mode (handled inline here) and PHP mode (delegated to
// copyPHPModeRegion, which returns the index at which HTML mode resumes).
func phpCodeOnly(src string) string {
var b strings.Builder
b.Grow(len(src))
n := len(src)
i := 0
for i < n {
// HTML mode: blank up to the next "<?" open tag.
htmlStart := i
for i < n && (src[i] != '<' || i+1 >= n || src[i+1] != '?') {
i++
}
blankInlineHTML(&b, src[htmlStart:i])
// No further open tag: the rest of the input was inline HTML.
if i >= n {
break
}
// Blank the opening tag: "<?php" (needs trailing whitespace/EOF), "<?=",
// or a bare "<?" short tag.
i += 2
b.WriteString(" ")
if i+3 <= n && strings.EqualFold(src[i:i+3], "php") && (i+3 == n || isPHPSpace(src[i+3])) {
b.WriteString(" ")
i += 3
} else if i < n && src[i] == '=' {
b.WriteByte(' ')
i++
}
// PHP mode: copy verbatim until a top-level "?>".
i = copyPHPModeRegion(&b, src, i)
}
return b.String()
}
// copyPHPModeRegion copies src[start:] into b verbatim until a top-level "?>"
// (which it replaces with "; ") or EOF, and returns the resume index. Strings,
// heredoc/nowdoc bodies, and block comments are copied whole so a "?>" inside
// them is not mistaken for a closing tag. A "?>" inside a // or # line comment
// does end PHP mode, matching PHP's own tokeniser.
//
// The returned index is the position just past the "?>" that ended PHP mode,
// or len(src) when the input ran out first.
func copyPHPModeRegion(b *strings.Builder, src string, start int) int {
n := len(src)
i := start
for i < n {
// Heredoc/nowdoc: copy from the opener through the end reported by
// phpHeredocEnd in one piece.
if label, bodyStart, ok := phpHeredocOpen(src, i); ok {
end := phpHeredocEnd(src, bodyStart, label)
b.WriteString(src[i:end])
i = end
continue
}
// Quoted string: copyPHPString copies it and returns the index of its
// last byte, so resume one past that.
if isPHPQuote(src[i]) {
i = copyPHPString(b, src, i) + 1
continue
}
// Block comment: copy through the closing "*/", or to EOF when the
// comment is unterminated.
if src[i] == '/' && i+1 < n && src[i+1] == '*' {
b.WriteString("/*")
i += 2
for i < n {
if src[i] == '*' && i+1 < n && src[i+1] == '/' {
b.WriteString("*/")
i += 2
break
}
b.WriteByte(src[i])
i++
}
continue
}
// Line comment ("//" or "#"): copy up to, but not including, the newline
// or a "?>", whichever comes first. Quotes inside the comment are copied
// as plain bytes, not as string openers.
if (src[i] == '/' && i+1 < n && src[i+1] == '/') || src[i] == '#' {
for i < n && src[i] != '\n' {
if src[i] == '?' && i+1 < n && src[i+1] == '>' {
break
}
b.WriteByte(src[i])
i++
}
// The comment ended on "?>": that closes PHP mode.
if i+1 < n && src[i] == '?' && src[i+1] == '>' {
b.WriteString("; ")
return i + 2
}
// Otherwise i is at the newline (or EOF); the main loop copies it.
continue
}
// Top-level closing tag.
if src[i] == '?' && i+1 < n && src[i+1] == '>' {
b.WriteString("; ")
return i + 2
}
b.WriteByte(src[i])
i++
}
return i
}
// blankInlineHTML appends s to b with every byte replaced by a space, except
// "\n" and "\r", which are kept so that line-based detectors still see the
// original line boundaries across blanked template text. The output has the
// same byte length as s.
func blankInlineHTML(b *strings.Builder, s string) {
	for _, ch := range []byte(s) {
		switch ch {
		case '\n', '\r':
			b.WriteByte(ch)
		default:
			b.WriteByte(' ')
		}
	}
}
// nextStandalonePHPCall scans code from searchFrom for the next call to a
// function whose lower-cased name is a key of names, ignoring anything inside
// quoted strings. A leading "\" (global-namespace call) is accepted.
//
// It returns the index where the call starts (the "\" when present, otherwise
// the first name byte), the index of the opening parenthesis, the value
// matchingParen reports for that parenthesis, and true. When no call is found
// it returns 0, 0, 0, false.
func nextStandalonePHPCall(code string, searchFrom int, names map[string]struct{}) (int, int, int, bool) {
	pos := searchFrom
	for pos < len(code) {
		c := code[pos]
		if isPHPQuote(c) {
			// Jump past the whole string literal.
			pos = skipPHPString(code, pos) + 1
			continue
		}

		nameStart := pos
		switch {
		case c == '\\':
			if pos+1 >= len(code) || !isPHPIdentifierStart(code[pos+1]) || !canStartGlobalPHPFunction(code, pos) {
				pos++
				continue
			}
			nameStart = pos + 1
		case !isPHPIdentifierStart(c) || !canStartPHPFunctionName(code, pos):
			pos++
			continue
		}

		// Consume the rest of the identifier.
		nameEnd := nameStart + 1
		for nameEnd < len(code) && isPHPIdentifierPart(code[nameEnd]) {
			nameEnd++
		}

		if _, wanted := names[strings.ToLower(code[nameStart:nameEnd])]; !wanted {
			pos = nameEnd
			continue
		}

		// Only an identifier followed (after whitespace) by "(" is a call.
		openParen := skipPHPWhitespace(code, nameEnd)
		if openParen >= len(code) || code[openParen] != '(' {
			pos = nameEnd
			continue
		}
		return pos, openParen, matchingParen(code, openParen), true
	}
	return 0, 0, 0, false
}
// nextSearchOffset returns the offset at which a scan should resume after
// handling position pos: pos+1 while pos is still inside the code, otherwise
// codeLen so the caller's loop terminates.
func nextSearchOffset(pos, codeLen int) int {
	if pos < codeLen {
		return pos + 1
	}
	return codeLen
}
// requestSuperglobalNames lists the superglobals treated as request input by
// containsRequestSuperglobal and containsRequestSuperglobalExpression. The
// entries are lower-case because containsAnySuperglobal lower-cases the code
// before searching it.
var requestSuperglobalNames = []string{"$_request", "$_post", "$_get", "$_cookie", "$_server"}
// includeTargetSuperglobalNames omits $_SERVER path keys: including a path
// built from the server document root or script filename is the standard
// WordPress bootstrap idiom, not an LFI/RFI primitive. Header-derived
// $_SERVER keys are handled separately below.
var includeTargetSuperglobalNames = []string{"$_request", "$_post", "$_get", "$_cookie"}
// containsRequestSuperglobal reports whether code mentions any superglobal in
// requestSuperglobalNames. The match is case-insensitive and purely textual:
// it does not skip string literals or comments.
func containsRequestSuperglobal(code string) bool {
return containsAnySuperglobal(code, requestSuperglobalNames)
}
// containsAnySuperglobal reports whether code contains one of names as a
// complete variable name. The search is case-insensitive (code is lower-cased,
// so names must already be lower-case), and a hit only counts when the byte
// following it is not a PHP identifier character: "$_get" matches "$_GET["
// but not "$_getter".
func containsAnySuperglobal(code string, names []string) bool {
	lowered := strings.ToLower(code)
	for _, name := range names {
		rest := lowered
		for {
			idx := strings.Index(rest, name)
			if idx < 0 {
				break
			}
			// Continue after this occurrence; what follows decides whether
			// the occurrence was a whole name.
			rest = rest[idx+len(name):]
			if rest == "" || !isPHPIdentifierPart(rest[0]) {
				return true
			}
		}
	}
	return false
}
// containsRequestSuperglobalExpression reports whether code reads a request
// superglobal (requestSuperglobalNames) outside string literals or inside a
// double-quoted string.
//
// PHP single-quoted strings do not interpolate variables, but double-quoted
// strings do. The decoder callback gate needs that distinction so a literal
// '$_POST' data value does not look like request input.
func containsRequestSuperglobalExpression(code string) bool {
return containsSuperglobalExpression(code, requestSuperglobalNames)
}
// containsIncludeTargetExpression reports whether the target expression of an
// include/require reads attacker-controlled input: either one of the
// superglobals in includeTargetSuperglobalNames, or a $_SERVER reference that
// containsServerHeaderIncludeExpression classifies as dangerous. Non-header
// $_SERVER path keys are deliberately not flagged.
func containsIncludeTargetExpression(code string) bool {
	if containsSuperglobalExpression(code, includeTargetSuperglobalNames) {
		return true
	}
	return containsServerHeaderIncludeExpression(code)
}
// containsSuperglobalExpression reports whether code references one of names
// in a position where PHP would evaluate it: in code outside string literals,
// or inside a double-quoted string (which interpolates variables). Mentions
// inside other quoted literals are ignored.
func containsSuperglobalExpression(code string, names []string) bool {
	segStart := 0 // start of the current run of non-string code
	pos := 0
	for pos < len(code) {
		if !isPHPQuote(code[pos]) {
			pos++
			continue
		}
		// Check the plain code that precedes this string literal.
		if containsAnySuperglobal(code[segStart:pos], names) {
			return true
		}
		closing := skipPHPString(code, pos)
		// Only double-quoted literals interpolate variables.
		if code[pos] == '"' && containsAnySuperglobal(code[pos:closing+1], names) {
			return true
		}
		pos = closing + 1
		segStart = pos
	}
	return containsAnySuperglobal(code[segStart:], names)
}
// containsServerHeaderIncludeExpression reports whether code contains a
// $_SERVER reference that serverIncludeReferenceAt classifies as dangerous.
// References are looked for in code outside string literals and inside the
// body of double-quoted strings; other quoted literals are skipped.
func containsServerHeaderIncludeExpression(code string) bool {
	pos := 0
	for pos < len(code) {
		if isPHPQuote(code[pos]) {
			closing := skipPHPString(code, pos)
			if code[pos] == '"' && closing > pos && containsServerHeaderReference(code[pos+1:closing]) {
				return true
			}
			pos = closing + 1
			continue
		}

		dangerous, next, matched := serverIncludeReferenceAt(code, pos)
		switch {
		case !matched:
			pos++
		case dangerous:
			return true
		default:
			// Benign reference: resume right after it.
			pos = next
		}
	}
	return false
}
// containsServerHeaderReference reports whether code contains, at any offset,
// a $_SERVER reference that serverIncludeReferenceAt classifies as dangerous.
// Unlike containsServerHeaderIncludeExpression it does no quote handling: the
// whole input is scanned as-is.
func containsServerHeaderReference(code string) bool {
	pos := 0
	for pos < len(code) {
		dangerous, next, matched := serverIncludeReferenceAt(code, pos)
		switch {
		case !matched:
			pos++
		case dangerous:
			return true
		default:
			pos = next
		}
	}
	return false
}
// serverIncludeReferenceAt inspects code at start for a "$_SERVER" reference
// (matched case-insensitively, as a whole variable name).
//
// Results are (dangerous, next, matched):
//   - matched is false when no $_SERVER reference begins at start; dangerous
//     is then false and next equals start.
//   - dangerous is decided by isAttackerControlledServerKey only for a
//     well-formed subscript with a quoted or bare-identifier key followed by
//     "]". Every other shape (no subscript, a computed key, a truncated
//     subscript) is reported as dangerous.
//   - next is the index at which the caller should resume scanning.
func serverIncludeReferenceAt(code string, start int) (bool, int, bool) {
	const serverName = "$_server"

	nameEnd := start + len(serverName)
	if nameEnd > len(code) || !strings.EqualFold(code[start:nameEnd], serverName) {
		return false, start, false
	}
	// "$_SERVERX" is a different variable.
	if nameEnd < len(code) && isPHPIdentifierPart(code[nameEnd]) {
		return false, start, false
	}

	open := skipPHPWhitespace(code, nameEnd)
	if open >= len(code) || code[open] != '[' {
		// The array itself is used without a subscript.
		return true, nameEnd, true
	}

	keyStart := skipPHPWhitespace(code, open+1)
	if keyStart >= len(code) {
		return true, len(code), true
	}

	// Extract a literal key and the index just past it.
	var (
		key      string
		afterKey int
	)
	switch first := code[keyStart]; {
	case isPHPQuote(first):
		closing := skipPHPString(code, keyStart)
		key = phpStringLiteralValue(code, keyStart, closing)
		afterKey = closing + 1
	case isPHPIdentifierStart(first):
		afterKey = keyStart + 1
		for afterKey < len(code) && isPHPIdentifierPart(code[afterKey]) {
			afterKey++
		}
		key = code[keyStart:afterKey]
	default:
		// Anything else is not a literal key.
		return true, keyStart + 1, true
	}

	next := skipPHPWhitespace(code, afterKey)
	if next < len(code) && code[next] == ']' {
		return isAttackerControlledServerKey(key), next + 1, true
	}
	// The literal is only part of a larger key expression.
	return true, next, true
}
// phpStringLiteralValue decodes the PHP string literal whose opening quote is
// at code[start] and whose closing quote is at code[end], returning the value
// the literal denotes.
//
// Single-quoted literals only unescape \' and \\. Any other opening quote is
// decoded with the double-quoted rules: \xH[H], \u{H...}, \O[O[O]] (octal),
// \n \r \t \v \e \f, \\ \$ and \". An escape that is not recognised, or is
// malformed, is kept verbatim (backslash included). Variable interpolation is
// not evaluated.
//
// A \u{...} escape is rejected (kept verbatim) as soon as its value exceeds
// U+10FFFF. Bounding the accumulator while digits are consumed matters: an
// unbounded accumulator overflows int on a long digit run and can wrap around
// to a small, valid-looking code point.
func phpStringLiteralValue(code string, start, end int) string {
	const maxCodePoint = 0x10ffff

	var b strings.Builder
	quote := code[start]
	for i := start + 1; i < end; i++ {
		// Ordinary byte, or a lone backslash right before the closing quote.
		if code[i] != '\\' || i+1 >= end {
			b.WriteByte(code[i])
			continue
		}
		i++
		esc := code[i]

		if quote == '\'' {
			// Single-quoted: only \' and \\ are escapes.
			if esc != '\'' && esc != '\\' {
				b.WriteByte('\\')
			}
			b.WriteByte(esc)
			continue
		}

		switch esc {
		case 'x':
			// \x followed by one or two hex digits.
			if i+1 >= end || !isHexDigit(code[i+1]) {
				b.WriteByte('\\')
				b.WriteByte(esc)
				continue
			}
			value := hexVal(code[i+1])
			i++
			if i+1 < end && isHexDigit(code[i+1]) {
				value = value*16 + hexVal(code[i+1])
				i++
			}
			// #nosec G115 -- PHP hex string escapes are at most one byte.
			b.WriteByte(byte(value))
		case 'u':
			// \u{H...}: braces are mandatory and at least one digit is needed.
			if i+2 >= end || code[i+1] != '{' || !isHexDigit(code[i+2]) {
				b.WriteByte('\\')
				b.WriteByte(esc)
				continue
			}
			value := 0
			j := i + 2
			for ; j < end && isHexDigit(code[j]); j++ {
				value = value*16 + hexVal(code[j])
				if value > maxCodePoint {
					// Already out of range; stop before the accumulator can
					// overflow on further digits.
					break
				}
			}
			if value > maxCodePoint || j >= end || code[j] != '}' {
				// Invalid escape: emit "\u" and let the loop copy the rest
				// of the sequence literally.
				b.WriteByte('\\')
				b.WriteByte(esc)
				continue
			}
			// #nosec G115 -- value is capped at the largest valid Unicode code point.
			b.WriteRune(rune(value))
			i = j
		case '0', '1', '2', '3', '4', '5', '6', '7':
			// Octal escape of up to three digits; the value is truncated to
			// one byte.
			value := int(esc - '0')
			digits := 1
			for i+1 < end && digits < 3 && code[i+1] >= '0' && code[i+1] <= '7' {
				value = value*8 + int(code[i+1]-'0')
				i++
				digits++
			}
			// #nosec G115 -- PHP octal string escapes are byte escapes.
			b.WriteByte(byte(value))
		case 'n':
			b.WriteByte('\n')
		case 'r':
			b.WriteByte('\r')
		case 't':
			b.WriteByte('\t')
		case 'v':
			b.WriteByte('\v')
		case 'e':
			b.WriteByte(0x1b)
		case 'f':
			b.WriteByte('\f')
		case '\\', '$', '"':
			b.WriteByte(esc)
		default:
			// Unknown escape: PHP keeps the backslash.
			b.WriteByte('\\')
			b.WriteByte(esc)
		}
	}
	return b.String()
}
// isAttackerControlledServerKey reports whether a $_SERVER key carries data
// supplied by the HTTP client. The key is trimmed and compared
// case-insensitively; every "HTTP_*" key qualifies, as do the content and
// PHP_AUTH_* keys listed below.
func isAttackerControlledServerKey(key string) bool {
	normalized := strings.ToUpper(strings.TrimSpace(key))
	switch normalized {
	case "CONTENT_LENGTH", "CONTENT_TYPE", "PHP_AUTH_DIGEST", "PHP_AUTH_PW", "PHP_AUTH_USER":
		return true
	}
	return strings.HasPrefix(normalized, "HTTP_")
}
// canStartGlobalPHPFunction reports whether the backslash at index slash can
// begin a fully-qualified global function name. It can when it is the first
// byte of code, or when the byte before it is neither an identifier
// character nor one of '$', '>', ':' or '\'.
func canStartGlobalPHPFunction(code string, slash int) bool {
	if slash == 0 {
		return true
	}
	switch before := code[slash-1]; before {
	case '$', '>', ':', '\\':
		return false
	default:
		return !isPHPIdentifierPart(before)
	}
}
// canStartPHPFunctionName reports whether an identifier beginning at index
// start can be a free-standing function name. The byte before it must not
// continue an identifier, mark a variable ('$'), end a member or static
// access ('>' or ':'), or be a namespace separator ('\').
func canStartPHPFunctionName(code string, start int) bool {
	if start == 0 {
		return true
	}
	before := code[start-1]
	if isPHPIdentifierPart(before) {
		return false
	}
	for _, blocked := range [...]byte{'$', '>', ':', '\\'} {
		if before == blocked {
			return false
		}
	}
	return true
}
// NOTE: the paragraph below describes CheckPHPContent, which is defined
// further down in this file. It is separated from the phpFileStamp
// declaration by a blank line so that it is not picked up as the type's doc
// comment.
//
// CheckPHPContent scans new/suspicious PHP files for obfuscation patterns,
// remote payload fetching, and eval chains. This is designed to catch droppers
// like the LEVIATHAN attack's file.php and files.php that use goto spaghetti,
// hex-encoded strings, and call_user_func with string-built function names.
//
// This check scans PHP files in directories that shouldn't normally contain
// user-authored PHP: wp-content/languages, wp-content/upgrade, wp-content/mu-plugins,
// and also checks any PHP files flagged by the file index as new.

// phpFileStamp is the cheap content-version key for a scanned PHP file. A file
// whose mtime and size both match the previous cycle is treated as unchanged.
type phpFileStamp struct {
// Mtime is the file's modification time in Unix seconds (scanDir fills it
// from ModTime().Unix(), so sub-second changes are not distinguished).
Mtime int64 `json:"m"`
// Size is the file size in bytes as reported by Stat.
Size int64 `json:"s"`
}
// phpContentCache maps a file path to the stamp it carried when last confirmed
// clean. Only clean files are stored, so a present, matching entry means
// "unchanged and previously produced no finding."
//
// Keys are the full paths built by scanDir. The map is persisted as JSON in
// "phpcontentcache.json" under the state directory (see loadPHPContentCache
// and savePHPContentCache).
type phpContentCache map[string]phpFileStamp
// loadPHPContentCache reads the persisted clean-file cache from
// "phpcontentcache.json" inside stateDir. Loading is best-effort: an empty
// stateDir or an unreadable file yields an empty cache, and a JSON decoding
// error is ignored, leaving whatever was decoded before the error.
func loadPHPContentCache(stateDir string) phpContentCache {
	cache := phpContentCache{}
	if stateDir != "" {
		if raw, readErr := osFS.ReadFile(filepath.Join(stateDir, "phpcontentcache.json")); readErr == nil {
			// Best-effort: a corrupt cache only costs a re-scan.
			_ = json.Unmarshal(raw, &cache)
		}
	}
	return cache
}
// savePHPContentCache persists cache as JSON to "phpcontentcache.json" inside
// stateDir. It writes a sibling ".tmp" file first and renames it over the
// final name, so readers never see a half-written cache.
//
// Saving is best-effort and reports nothing to the caller: the cache is only
// an optimisation. Every failure aborts the save and removes the temporary
// file, so a failed write can never be renamed over the last good cache and
// no stale ".tmp" is left for a later save to publish. An empty stateDir is a
// no-op.
func savePHPContentCache(stateDir string, cache phpContentCache) {
	if stateDir == "" {
		return
	}
	data, err := json.Marshal(cache)
	if err != nil {
		// Nothing valid to write; keep the existing cache file.
		return
	}

	tmpPath := filepath.Join(stateDir, "phpcontentcache.json.tmp")
	finalPath := filepath.Join(stateDir, "phpcontentcache.json")

	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		// The temp file may be partial or pre-existing: discard it instead
		// of publishing it.
		_ = os.Remove(tmpPath)
		return
	}
	if err := os.Rename(tmpPath, finalPath); err != nil {
		_ = os.Remove(tmpPath)
	}
}
// phpContentHostScanCount drives a periodic forced full rescan that bypasses
// the content cache, mirroring the file-index cadence. The cache keys on
// mtime+size alone, so a content swap that preserves both (an attacker resetting
// mtime after editing in place) would be skipped until the file changed again.
// The forced rescan bounds that window; realtime fanotify covers it in between.
//
// It is incremented atomically by phpContentForceFull on every host-wide
// scan; every sixth increment forces a full rescan.
var phpContentHostScanCount int32
// phpContentAccountScanCount keeps account-scoped scans from consuming the
// host-wide full-rescan cadence. Account scans are subsets and never save the
// shared cache, so mixing the counters can make the host-wide backstop miss its
// intended cycle.
//
// It is incremented atomically by phpContentForceFull on every
// account-scoped scan, with the same one-in-six cadence.
var phpContentAccountScanCount int32
func phpContentForceFull(ctx context.Context) bool {
if AccountFromContext(ctx) != "" {
return atomic.AddInt32(&phpContentAccountScanCount, 1)%6 == 0
}
return atomic.AddInt32(&phpContentHostScanCount, 1)%6 == 0
}
// phpContentScan carries the per-cycle cache state through the recursive walk.
// prev is the previous cycle's clean-file stamps (read-only); next is rebuilt
// from the files seen this cycle, which prunes deleted files automatically.
type phpContentScan struct {
// cfg supplies the suppression globs (Suppressions.IgnorePaths) that
// scanDir applies to every path.
cfg *config.Config
// prev holds the stamps loaded from the previous cycle; scanDir only reads it.
prev phpContentCache
// next collects the stamps of files confirmed clean during this cycle.
next phpContentCache
// forceFull disables cache hits so every candidate file is re-analysed.
forceFull bool
}
// newPHPContentScan builds the state for one scan cycle. A nil prev is
// replaced by an empty cache, and next always starts out empty.
func newPHPContentScan(cfg *config.Config, prev phpContentCache, forceFull bool) *phpContentScan {
	scan := &phpContentScan{
		cfg:       cfg,
		prev:      prev,
		next:      phpContentCache{},
		forceFull: forceFull,
	}
	if scan.prev == nil {
		scan.prev = phpContentCache{}
	}
	return scan
}
// CheckPHPContent walks the WordPress content directories of every scanned
// home directory and returns a finding for each PHP-executable file whose
// content looks obfuscated or malicious.
//
// For each home directory the candidate document roots are public_html plus
// every other non-hidden sub-directory except mail, etc, logs, ssl and tmp.
// Under each root the wp-content sub-directories languages, upgrade,
// mu-plugins, plugins and themes are scanned to a depth of 4.
//
// Files confirmed clean are remembered in the content cache so unchanged
// files are skipped on later cycles. It returns nil when the home directories
// cannot be listed, and the findings gathered so far when ctx is cancelled.
func CheckPHPContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	homeDirs, err := GetScanHomeDirs(ctx)
	if err != nil {
		return nil
	}

	prevCache := loadPHPContentCache(cfg.StatePath)
	scan := newPHPContentScan(cfg, prevCache, phpContentForceFull(ctx))

	// wp-content sub-directories inspected under every document root.
	wpContentDirs := [...]string{"languages", "upgrade", "mu-plugins", "plugins", "themes"}

	var findings []alert.Finding
	for _, homeEntry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !homeEntry.IsDir() {
			continue
		}
		homeDir := filepath.Join("/home", homeEntry.Name())

		// Collect every potential document root: public_html first, then any
		// other sub-directory that is not a known non-web directory.
		docRoots := []string{filepath.Join(homeDir, "public_html")}
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sd := range subDirs {
			if !sd.IsDir() {
				continue
			}
			name := sd.Name()
			if strings.HasPrefix(name, ".") {
				continue
			}
			switch name {
			case "public_html", "mail", "etc", "logs", "ssl", "tmp":
				continue
			}
			docRoots = append(docRoots, filepath.Join(homeDir, name))
		}

		for _, docRoot := range docRoots {
			for _, sub := range wpContentDirs {
				scan.scanDir(ctx, filepath.Join(docRoot, "wp-content", sub), 4, phpHandlerOverlay{}, &findings)
				if ctx.Err() != nil {
					return findings
				}
			}
		}
	}

	// Persist only after a complete host-wide scan. A run cut short by ctx
	// timeout leaves scan.next missing the unscanned files; persisting it
	// would drop their cached-clean state and force a needless re-read next
	// cycle (still safe, just slower). An account-scoped run only walks one
	// account, so its scan.next is a subset of the shared cache and must not
	// overwrite the host-wide stamps.
	if ctx.Err() == nil && AccountFromContext(ctx) == "" {
		savePHPContentCache(cfg.StatePath, scan.next)
	}
	return findings
}
// scanDirForObfuscatedPHP analyses every PHP-executable file under dir, up to
// maxDepth levels, appending results to findings. It runs with no previous
// cache and with a forced full scan, so nothing is skipped; use it where
// there is no prior cycle to compare against.
func scanDirForObfuscatedPHP(ctx context.Context, dir string, maxDepth int, cfg *config.Config, findings *[]alert.Finding) {
	uncached := newPHPContentScan(cfg, nil, true)
	uncached.scanDir(ctx, dir, maxDepth, phpHandlerOverlay{}, findings)
}
// scanDir walks dir, descending at most maxDepth levels, and appends a finding
// for every PHP-executable file whose content analysis reports indicators.
//
// overlay is the set of .htaccess PHP-handler remappings inherited from parent
// directories; this directory's own .htaccess is layered on top before its
// entries are examined. A file that was clean last cycle and still carries the
// same mtime+size is skipped unless the scan is a forced full rescan. Files
// that produce a finding are never cached, so they re-surface every cycle.
func (s *phpContentScan) scanDir(ctx context.Context, dir string, maxDepth int, overlay phpHandlerOverlay, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	// An attacker who maps a non-PHP extension to the PHP interpreter (the
	// LEVIATHAN .htaccess trick), or who SetHandlers the whole directory, must
	// not be able to hide executable PHP behind a name the stock handler would
	// not run. The file is read once per directory.
	if htaccess, herr := osFS.ReadFile(filepath.Join(dir, ".htaccess")); herr == nil {
		overlay = overlay.mergeHtaccess(htaccess)
	}

	// isSuppressed reports whether p matches an operator-configured ignore glob.
	isSuppressed := func(p string) bool {
		for _, pattern := range s.cfg.Suppressions.IgnorePaths {
			if matchGlob(p, pattern) {
				return true
			}
		}
		return false
	}

	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}
		name := entry.Name()
		fullPath := filepath.Join(dir, name)
		if isSuppressed(fullPath) {
			continue
		}
		if entry.IsDir() {
			s.scanDir(ctx, fullPath, maxDepth-1, overlay, findings)
			continue
		}
		if !overlay.executes(strings.ToLower(name)) {
			continue
		}

		info, statErr := osFS.Stat(fullPath)
		canCache := statErr == nil
		var stamp phpFileStamp
		if canCache {
			stamp = phpFileStamp{Mtime: info.ModTime().Unix(), Size: info.Size()}
			// Cache hit: clean last cycle and unchanged since. The stamp is
			// carried forward only while the file is still readable; chmod
			// changes neither mtime nor size, so a stale clean entry must not
			// mask a file that can no longer be inspected.
			if cached, known := s.prev[fullPath]; !s.forceFull && known && cached == stamp && phpContentReadable(fullPath) {
				s.next[fullPath] = stamp
				continue
			}
		}

		// Every executable file is content-analysed. There is no
		// filename/path allowlist: clean files produce no finding, and any
		// skip is a place an attacker can hide a backdoor.
		result := analyzePHPContent(fullPath)
		if result.severity < 0 {
			// Record as clean only when the file was actually read; an
			// unreadable file might become readable later.
			if canCache && result.readOK {
				s.next[fullPath] = stamp
			}
			continue
		}

		details := result.details
		if info != nil {
			details += fmt.Sprintf("\nSize: %d, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
		}
		*findings = append(*findings, alert.Finding{
			Severity: result.severity,
			Check:    result.check,
			Message:  fmt.Sprintf("%s: %s", result.message, fullPath),
			Details:  details,
			FilePath: fullPath,
		})
	}
}
// phpContentReadable reports whether path can currently be opened for
// reading. The file is opened and closed again straight away; nothing is read.
func phpContentReadable(path string) bool {
	handle, openErr := osFS.Open(path)
	if openErr == nil {
		_ = handle.Close()
	}
	return openErr == nil
}
// phpAnalysisResult is the outcome of analysing one PHP file (or one window
// of it).
type phpAnalysisResult struct {
// severity is -1 when there is no finding; otherwise it is the severity to
// report (see phpAnalysisFromIndicators).
severity alert.Severity
// check and message are the Finding's check name and message prefix.
check string
message string
// details is the human-readable indicator list attached to the finding.
details string
// indicators holds the de-duplicated indicator strings behind the verdict.
indicators []string
// readOK is true when the file content was actually read; scanDir caches
// a file as clean only when this is set.
readOK bool
// empty is set by analyzePHPContent when both content windows are empty.
empty bool
}
// analyzePHPContent opens path, reads its head window (plus a tail window for
// large files) and scores the PHP code found there.
//
// The result has severity -1 when nothing suspicious was found or the file
// could not be analysed; readOK tells the two apart, and empty marks a file
// whose windows contained no bytes.
func analyzePHPContent(path string) phpAnalysisResult {
	f, err := osFS.Open(path)
	if err != nil {
		return phpAnalysisResult{severity: -1}
	}
	defer func() { _ = f.Close() }()

	// A size of -1 tells the window reader that the size is unknown.
	size := int64(-1)
	if info, statErr := f.Stat(); statErr == nil {
		size = info.Size()
	}

	head, tail, tailOffset, readOK := readPHPContentWindows(f, size)
	switch {
	case !readOK:
		return phpAnalysisResult{severity: -1, readOK: false}
	case len(head) == 0 && len(tail) == 0:
		return phpAnalysisResult{severity: -1, readOK: true, empty: true}
	case len(tail) > 0 && tailOffset > int64(len(head)):
		// Bytes were skipped between the windows. They may close a quote or
		// comment opened in the head, so the windows are scored
		// independently: head parser state must not hide executable tail
		// code that was read.
		headResult := analyzePHPCode(path, phpCodeOnly(string(head)), readOK)
		tailResult := analyzePHPCode(path, phpCodeOnly("<?php\n"+string(tail)), readOK)
		return mergePHPAnalysisResults(headResult, tailResult)
	}
	return analyzePHPCode(path, phpCodeOnlyWindows(head, tail), readOK)
}
// mergePHPAnalysisResults combines per-window results into one verdict. The
// indicators of every result that carries a finding are pooled and re-scored
// by phpAnalysisFromIndicators; the merged readOK is true only when every
// input was read successfully.
func mergePHPAnalysisResults(results ...phpAnalysisResult) phpAnalysisResult {
	allRead := true
	var combined []string
	for i := range results {
		r := &results[i]
		if !r.readOK {
			allRead = false
		}
		if r.severity < 0 {
			continue
		}
		combined = append(combined, r.indicators...)
	}
	return phpAnalysisFromIndicators(combined, allRead)
}
// analyzePHPCode runs every content heuristic over content and converts the
// collected indicators into a result via phpAnalysisFromIndicators.
//
// path is used only to ask IsVerifiedCMSFile whether a shell-function /
// request-input co-presence may be counted. content is the PHP code to score
// (analyzePHPContent passes the output of phpCodeOnly / phpCodeOnlyWindows).
// readOK is forwarded unchanged into the result.
//
// Several views of content are used: contentLower (lower-cased),
// commentStripped (comments removed, string literals kept) and codeLower
// (comments and strings removed, lower-cased). Each heuristic picks the view
// that preserves the tokens it needs.
//
// Each heuristic appends at most one indicator string, except the paste-host
// loop, which appends one per matching host. The only deferred indicator is
// the shell/request co-presence, which is appended last and only when some
// stronger indicator already fired.
func analyzePHPCode(path, content string, readOK bool) phpAnalysisResult {
contentLower := strings.ToLower(content)
var indicators []string
// --- Critical: Remote payload fetching ---
// Paste sites are always suspicious in PHP files.
// GitHub raw URLs are common in legitimate plugin update checkers,
// so only count them as indicators when they appear on the same line
// as a dangerous PHP function call.
pasteHosts := []string{
"pastebin.com/raw",
"paste.ee/r/",
"ghostbin.co/paste/",
"hastebin.com/raw/",
}
for _, host := range pasteHosts {
if strings.Contains(contentLower, host) {
indicators = append(indicators, fmt.Sprintf("remote payload URL: %s", host))
}
}
githubHosts := []string{"gist.githubusercontent.com", "raw.githubusercontent.com"}
dangerousCalls := []string{"file_put_contents(", "fwrite(", "shell_", "passthru(", "popen("}
for _, host := range githubHosts {
if !strings.Contains(contentLower, host) {
continue
}
// Same-line = strong signal (critical)
sameLine := false
for _, line := range strings.Split(contentLower, "\n") {
if !strings.Contains(line, host) {
continue
}
for _, fn := range dangerousCalls {
if strings.Contains(line, fn) {
indicators = append(indicators, fmt.Sprintf("remote payload URL with dangerous call: %s", host))
sameLine = true
break
}
}
if sameLine {
break
}
}
// Co-presence (different lines, same 32 KB window) was previously
// emitted as a weaker indicator. It generated standing FPs on
// legit plugins that fetch upstream resources from github mirrors
// (wp-statistics GeoLite2 updates, unyson font fetcher, polylang
// language packs). Same-line is the strong signal kept above; the
// co-presence path is removed entirely.
}
// --- Critical: eval() chains with decoding ---
// Only flag when eval directly wraps a decoder (structural nesting),
// not when they merely co-exist in the same file (which causes false
// positives on legitimate plugins that use eval for templates and
// base64_decode for unrelated data processing).
decoders := []string{
"base64_decode", "gzinflate", "gzuncompress", "str_rot13",
"rawurldecode", "gzdecode", "bzdecompress",
}
hasDecoder := false
hasNestedEvalDecode := false
for _, d := range decoders {
if strings.Contains(contentLower, d) {
hasDecoder = true
}
}
// Check for structural nesting: eval(base64_decode(...)), eval(gzinflate(...)), etc.
// PHP tolerates inline comments and arbitrary whitespace (including line
// breaks) between the keyword and its open paren, so a naive line-by-line
// `eval(` substring scan misses `eval /*x*/ ( base64_decode(...))` and
// `eval // bypass\n( base64_decode(...))`. Strip PHP comments and
// strings first, then match the structural pattern across whitespace,
// and require the inner callee to be one of the known decoders /
// decompressors.
commentStripped := stripPHPCommentsFromCode(content)
codeLower := strings.ToLower(stripPHPStringsFromCode(commentStripped))
for _, m := range nestedEvalDecodeRe.FindAllStringSubmatch(codeLower, -1) {
if len(m) < 3 {
continue
}
inner := m[2]
for _, d := range decoders {
if inner == d {
hasNestedEvalDecode = true
break
}
}
if hasNestedEvalDecode {
break
}
}
if hasNestedEvalDecode {
indicators = append(indicators, "eval() directly wrapping encoding/compression function")
}
// eval wrapping dynamic code construction the decoder loop above
// ignores: a variable callee (eval($f(...))) or a code-building
// primitive (eval(create_function(...)), eval(call_user_func(...))).
// These never appear in legitimate user-directory PHP; a single hit is
// surfaced as a High signal (the >=2 gate still governs quarantine).
hasEvalExecWrap := reEvalVarCallee.MatchString(codeLower)
if !hasEvalExecWrap {
for _, m := range nestedEvalDecodeRe.FindAllStringSubmatch(codeLower, -1) {
if len(m) < 3 {
continue
}
if m[1] != "eval" {
continue
}
if _, ok := evalExecWrapInner[m[2]]; ok {
hasEvalExecWrap = true
break
}
}
}
if hasEvalExecWrap {
indicators = append(indicators, "eval() wrapping a dynamic code-execution primitive")
}
// Backtick shell execution with request input -- `...$_GET...`.
// Match only executable backtick spans, not quoted examples.
if hasBacktickSuperglobal(commentStripped) {
indicators = append(indicators, "backtick shell execution with request input")
}
// Callback-position exec: an exec/decoder function name passed as a
// string callback (array_map("system", ...), register_shutdown_function(
// "passthru", ...)). Runs on the comment-stripped source so the literal
// callback name is preserved.
if hasCallbackExecName(commentStripped) {
indicators = append(indicators, "exec/decoder function name passed as a callback")
}
// Variable-variable / dynamic-expression function call co-located with
// request input on the same line -- $$h($_GET[...]) -- a dynamic-dispatch
// RCE shape. The same-line request-var gate keeps benign dispatcher code
// (which uses $$var without attacker input) from tripping.
for _, line := range strings.Split(codeLower, "\n") {
if reVarVarCall.MatchString(line) && lineContainsRequestVar(line) {
indicators = append(indicators, "variable-variable function call with request input")
break
}
}
// preg_replace() with the /e modifier evaluates its replacement as PHP.
// Removed in PHP 7.0; flag it when request input reaches the evaluated
// replacement or its subject backreferences.
if hasPregReplaceEvalWithRequest(commentStripped) {
indicators = append(indicators, "preg_replace with /e modifier (code execution)")
}
// include/require of request input or a remote/stream wrapper -- LFI,
// RFI, and php://input code execution.
if hasDangerousInclude(commentStripped) {
indicators = append(indicators, "include/require of request input or remote/data wrapper")
}
// assert()/create_function() driven by request input -- both evaluate a
// string argument as PHP.
if hasCodeEvalPrimitiveWithRequest(commentStripped) {
indicators = append(indicators, "code-eval primitive (assert/create_function) with request input")
}
// --- Critical: call_user_func with string-built function names ---
// LEVIATHAN droppers build the target function name on the call itself:
// call_user_func("\x63"."\x75"."\x72"."\x6c", $payload) == call_user_func("curl", ...)
// File-wide hex/concat counts are unsafe here: WPML bundles PHPZip
// (inc/wpml_zip.php) which declares 20+ ZIP-format signature constants
// as hex literals ("\x50\x4b\x03\x04" etc.) and makes a single benign
// call_user_func(self::$temp) call to invoke a temp-file factory.
// Match the obfuscation on the callable target argument itself.
if strings.Contains(contentLower, "call_user_func") {
foundCallUserFuncObfuscation := false
for _, line := range strings.Split(content, "\n") {
if !strings.Contains(strings.ToLower(line), "call_user_func") {
continue
}
for _, targetArg := range phpCallUserFuncTargetArgs(line) {
// PHP 7+ accepts both "\xNN" hex and "\u{NN}" unicode-codepoint
// escapes inside double-quoted strings. Treat them as
// equivalent obfuscation forms so an attacker cannot bypass
// the detector by swapping syntax.
lineHex := countOccurrences(targetArg, `"\x`) + countOccurrences(targetArg, `"\u{`)
lineConcat := countOccurrences(targetArg, `" . "`) + countOccurrences(targetArg, `"."`)
// Typical shortest obfuscated name is 3-4 bytes ("exec", "curl",
// "eval"); require >=3 escapes AND >=2 concatenations on the
// call target argument.
if lineHex >= 3 && lineConcat >= 2 {
indicators = append(indicators, "call_user_func with obfuscated function names")
foundCallUserFuncObfuscation = true
break
}
}
if foundCallUserFuncObfuscation {
break
}
}
}
// --- High: Goto obfuscation (LEVIATHAN signature) ---
gotoCount := countOccurrences(contentLower, "goto ")
if gotoCount > 10 {
indicators = append(indicators, fmt.Sprintf("excessive goto statements (%d found - obfuscation pattern)", gotoCount))
}
// --- High: Hex-encoded string construction ---
// Only flag hex strings when accompanied by concatenation - real obfuscation
// builds function names like "\x63" . "\x75" . "\x72" . "\x6c" (= "curl").
// Standalone hex arrays (Wordfence IPv6 subnet masks, binary data) are benign.
hexStringCount := countOccurrences(content, `"\x`)
dotConcatCount := countOccurrences(content, `" . "`)
if hexStringCount > 20 && dotConcatCount > 10 {
indicators = append(indicators, fmt.Sprintf("heavy hex-encoded strings with concatenation (%d hex, %d concat - obfuscation pattern)", hexStringCount, dotConcatCount))
}
// The standalone "concat>30 alone" branch was removed: WordPress
// themes and page builders concatenate literal CSS/HTML tokens
// dozens of times in dynamic style/markup builders (sydney theme,
// elementor, beaver builder), producing FPs on every install. Real
// function-name obfuscation always pairs concat with hex escapes
// and is still caught by the combined branch above.
// --- Critical: variable-function indirection that resolves to a
// decoder. Attackers slip past the literal "eval(base64_decode("
// detector by binding the dangerous name to a variable on one line
// and calling it on another:
// $d = "base64_decode";
// $r = "eval";
// $r($d("AAAA"));
// The heuristic looks for an assignment $var = "decoder_or_exec"
// followed by a $var( invocation in the same file. Hits must
// reference at least one decoder OR one shell-exec primitive,
// since plain variable function calls show up in legitimate
// metaprogramming.
if detectVarFuncDangerousAssignment(content) {
indicators = append(indicators, "variable function name resolves to decoder or exec primitive")
}
// --- High: Variable function calls with obfuscated names ---
// call_user_func + decoder alone is too broad - Elementor, WooCommerce, and
// dozens of plugins use call_user_func_array with base64_decode legitimately.
// Only flag when combined with hex-built callable target obfuscation.
if strings.Contains(contentLower, "call_user_func") && hasDecoder {
if hasCallUserFuncHexNameBuild(content) {
indicators = append(indicators, "variable function call with decoder and obfuscation")
}
}
// --- High: Shell execution functions combined with request input ---
// Uses containsStandaloneFunc to avoid substring false positives
// (e.g. "WP_Filesystem(" matching "exec(", "preg_match(" matching "exec(")
shellFuncs := []string{"system(", "passthru(", "exec(", "shell_exec(", "popen(", "proc_open(", "pcntl_exec("}
// Two-tier detection:
// Same line = CRITICAL signal (auto-quarantine eligible)
// Co-presence = HIGH signal (alert only, not quarantined alone)
// This prevents bypass by splitting across lines while avoiding
// false-positive quarantine of legitimate plugins.
hasShellFunc := false
sameLineShellRequest := false
for _, sf := range shellFuncs {
if containsStandaloneFunc(contentLower, sf) {
hasShellFunc = true
break
}
}
hasRequestVar := containsRequestSuperglobal(contentLower)
// Same-line is a strong signal on its own: "$ret = system($_POST['cmd']);"
// almost never occurs in legitimate code. Co-presence is weaker: elFinder
// and other media-processing libraries legitimately call exec() for
// ImageMagick and also consume $_POST for AJAX routing, placing both
// tokens in the same 32 KB window. The co-presence finding is therefore
// only emitted as CORROBORATION after all the stronger indicators have
// been collected -- see the deferred append further below.
coPresenceCandidate := false
if hasShellFunc && hasRequestVar {
for _, line := range strings.Split(contentLower, "\n") {
lineHasShell := false
for _, sf := range shellFuncs {
if containsStandaloneFunc(line, sf) {
lineHasShell = true
break
}
}
if !lineHasShell {
continue
}
if containsRequestSuperglobal(line) {
sameLineShellRequest = true
break
}
}
if sameLineShellRequest {
indicators = append(indicators, "shell function with request input on same line")
} else if !IsVerifiedCMSFile(path) {
coPresenceCandidate = true
}
}
// --- High: base64 encoding/decoding with execution on same line ---
if strings.Contains(contentLower, "base64_decode") && strings.Contains(contentLower, "base64_encode") {
for _, line := range strings.Split(contentLower, "\n") {
hasBoth := strings.Contains(line, "base64_decode") && strings.Contains(line, "base64_encode")
hasExec := false
for _, sf := range shellFuncs {
if containsStandaloneFunc(line, sf) {
hasExec = true
break
}
}
if hasBoth && hasExec {
indicators = append(indicators, "base64 encode+decode with execution on same line (command relay)")
break
}
}
}
// Deferred corroboration: a lone co-presence is not enough. If a
// stronger indicator was produced above, the co-presence is appended
// both as extra context for the operator and to nudge the severity
// into the >=2 Critical band for obfuscated droppers.
if coPresenceCandidate && len(indicators) > 0 {
indicators = append(indicators, "shell function co-present with request input")
}
// --- Determine severity based on indicators ---
return phpAnalysisFromIndicators(indicators, readOK)
}
// phpAnalysisFromIndicators turns a list of indicator strings into a verdict.
// Duplicates are dropped first. No indicators yields severity -1 (no
// finding), exactly one yields a High "suspicious_php_content" result, and
// two or more yield a Critical "obfuscated_php" result. readOK is copied into
// the result unchanged.
func phpAnalysisFromIndicators(indicators []string, readOK bool) phpAnalysisResult {
	indicators = dedupePHPIndicators(indicators)
	if len(indicators) == 0 {
		return phpAnalysisResult{severity: -1, readOK: readOK}
	}

	// Default verdict: a single hit stays High and lands in the operator queue.
	result := phpAnalysisResult{
		severity:   alert.High,
		check:      "suspicious_php_content",
		message:    "Suspicious PHP content detected",
		details:    fmt.Sprintf("Indicators found:\n- %s", strings.Join(indicators, "\n- ")),
		readOK:     readOK,
		indicators: indicators,
	}

	// Auto-quarantine in autoresponse.AutoQuarantineFiles acts only on
	// Critical findings. A single heuristic indicator has false-positive
	// classes severe enough to rm live production files (WPML's PHPZip
	// tripped the former "call_user_func with obfuscated" bypass on hex
	// constants that build ZIP magic bytes; legitimate plugins embed pastebin
	// URLs in support docstrings and release notes). Two converging
	// indicators are therefore required before the severity crosses the
	// destructive-action threshold.
	if len(indicators) >= 2 {
		result.severity = alert.Critical
		result.check = "obfuscated_php"
		result.message = "Obfuscated/malicious PHP detected"
	}
	return result
}
// dedupePHPIndicators removes repeated indicator strings while keeping the
// first occurrence of each in its original order. The filtering is done in
// place: the returned slice shares (and overwrites) the backing array of the
// argument.
func dedupePHPIndicators(indicators []string) []string {
	seen := make(map[string]struct{}, len(indicators))
	kept := 0
	for _, candidate := range indicators {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		indicators[kept] = candidate
		kept++
	}
	return indicators[:kept]
}
// countOccurrences returns the number of non-overlapping occurrences of
// substr in s, scanning left to right.
//
// An empty substr yields 0. The previous hand-rolled loop never advanced its
// offset for an empty needle and therefore spun forever; strings.Count has
// the same non-overlapping semantics for non-empty needles without that trap.
func countOccurrences(s, substr string) int {
	if substr == "" {
		return 0
	}
	return strings.Count(s, substr)
}
// hexNameAssignment records one assignment to a PHP variable as seen by
// findHexNameAssignments.
type hexNameAssignment struct {
// obfuscated is the phpExprCountsHaveHexNameBuild verdict for the counts below.
obfuscated bool
// hexEscapes is the number of hex/unicode escapes counted for the assigned
// expression; for ".=" it also includes the previous assignment's total.
hexEscapes int
// concatOps is the number of concatenation operators counted for the
// expression; for ".=" it also includes the previous total plus one.
concatOps int
// pos is the byte offset in the scanned code where the assigned expression ends.
pos int
}
// hasCallUserFuncHexNameBuild reports whether content contains a call (one of
// callUserFuncCallNames, located by nextStandalonePHPCall) whose first
// argument is a hex-escape name build. Two shapes are accepted:
//
//   - the first argument expression itself satisfies phpExprHasHexNameBuild;
//   - the first argument is a single variable whose most recent recorded
//     assignment ending at or before the call start was marked obfuscated.
//
// Comments are stripped with stripPHPCommentsFromCode before scanning.
func hasCallUserFuncHexNameBuild(content string) bool {
code := stripPHPCommentsFromCode(content)
// Collected once up front; looked up per call site below.
assignments := findHexNameAssignments(code)
searchFrom := 0
for {
callStart, openParen, closeParen, ok := nextStandalonePHPCall(code, searchFrom, callUserFuncCallNames)
if !ok {
// No further candidate calls.
return false
}
arg, argOK := firstPHPCallArgument(code, openParen+1)
if argOK && phpExprHasHexNameBuild(arg) {
return true
}
if argOK {
// Indirect form: the callable name was built in a variable earlier.
if variable, varOK := singlePHPVariableExpr(arg); varOK && hexNameBuildAt(assignments[variable], callStart) {
return true
}
}
// NOTE(review): presumably nextSearchOffset always moves past closeParen so
// the loop terminates; confirm against its definition.
searchFrom = nextSearchOffset(closeParen, len(code))
}
}
// findHexNameAssignments scans code byte by byte and returns, per variable
// name, the assignments to it in the order they were encountered.
//
// Quoted strings are skipped via skipPHPString. For each "$name" followed by
// an operator accepted by phpAssignmentOperator an entry is appended:
//
//   - "=" and ".=" count the hex escapes and concatenation operators of the
//     right-hand expression;
//   - ".=" additionally adds the totals of the variable's previous entry and
//     one extra concatenation for the append itself;
//   - any other compound operator records zero counts.
//
// Scanning resumes right after the variable name, so the right-hand
// expression is itself scanned for further assignments.
func findHexNameAssignments(code string) map[string][]hexNameAssignment {
assignments := map[string][]hexNameAssignment{}
for i := 0; i < len(code); i++ {
if isPHPQuote(code[i]) {
// Jump over the string literal so "$x = ..." inside it is ignored.
i = skipPHPString(code, i)
continue
}
if code[i] != '$' {
continue
}
variable, next, ok := readPHPVariableName(code, i)
if !ok {
continue
}
j := skipPHPWhitespace(code, next)
opLen, directAssign, appendAssign, ok := phpAssignmentOperator(code, j)
if !ok {
// Variable is used, not assigned; continue after its name.
i = next - 1
continue
}
exprStart := skipPHPWhitespace(code, j+opLen)
exprEnd := phpExpressionEnd(code, exprStart)
hexEscapes := 0
concatOps := 0
if directAssign || appendAssign {
hexEscapes = countPHPStringHexEscapes(code[exprStart:exprEnd])
concatOps = countPHPConcatOperators(code[exprStart:exprEnd])
}
if appendAssign {
// ".=" extends the previous value: carry its counts forward.
prev := lastHexNameAssignment(assignments[variable])
hexEscapes += prev.hexEscapes
concatOps += prev.concatOps + 1
}
assignments[variable] = append(assignments[variable], hexNameAssignment{
obfuscated: phpExprCountsHaveHexNameBuild(hexEscapes, concatOps),
hexEscapes: hexEscapes,
concatOps: concatOps,
pos: exprEnd,
})
i = next - 1
}
return assignments
}
// phpAssignmentOperator inspects code at byte offset pos and reports whether
// a PHP assignment operator starts there.
//
// Returns, in order:
//   - the operator length in bytes (0 when none matched);
//   - whether it is the plain "=" assignment ("==" and "=>" are rejected);
//   - whether it is the string-append ".=" assignment;
//   - whether any assignment operator matched at all.
//
// Recognised compound operators are "+=", "-=", "*=", "/=", "%=", "&=",
// "|=", "^=", "??=", "<<=", ">>=" and the exponentiation assignment "**=",
// which was previously missed. An out-of-range pos reports no match.
func phpAssignmentOperator(code string, pos int) (int, bool, bool, bool) {
	if pos < 0 || pos >= len(code) {
		return 0, false, false, false
	}
	rest := code[pos:]

	if rest[0] == '=' && (len(rest) == 1 || (rest[1] != '=' && rest[1] != '>')) {
		return 1, true, false, true
	}
	if strings.HasPrefix(rest, ".=") {
		return 2, false, true, true
	}
	if len(rest) >= 2 && rest[1] == '=' && strings.IndexByte("+-*/%&|^", rest[0]) >= 0 {
		return 2, false, false, true
	}
	if len(rest) >= 3 {
		switch rest[:3] {
		case "??=", "<<=", ">>=", "**=":
			return 3, false, false, true
		}
	}
	return 0, false, false, false
}
// lastHexNameAssignment returns the final entry of assignments, or the zero
// hexNameAssignment when the slice is empty.
func lastHexNameAssignment(assignments []hexNameAssignment) hexNameAssignment {
	if n := len(assignments); n > 0 {
		return assignments[n-1]
	}
	return hexNameAssignment{}
}
// hexNameBuildAt reports whether the latest assignment whose pos is not
// beyond callPos was marked obfuscated. Entries are examined in slice order
// and the walk stops at the first one positioned after callPos; with no
// qualifying entry the answer is false.
func hexNameBuildAt(assignments []hexNameAssignment, callPos int) bool {
	obfuscated := false
	for i := range assignments {
		if assignments[i].pos > callPos {
			break
		}
		obfuscated = assignments[i].obfuscated
	}
	return obfuscated
}
// singlePHPVariableExpr returns the variable name when expr, ignoring
// surrounding whitespace, consists of exactly one PHP variable and nothing
// else. Any other expression yields "", false.
func singlePHPVariableExpr(expr string) (string, bool) {
	trimmed := strings.TrimSpace(expr)
	if !strings.HasPrefix(trimmed, "$") {
		return "", false
	}
	name, after, ok := readPHPVariableName(trimmed, 0)
	if ok && skipPHPWhitespace(trimmed, after) == len(trimmed) {
		return name, true
	}
	return "", false
}
func phpExprHasHexNameBuild(expr string) bool {
return phpExprCountsHaveHexNameBuild(countPHPStringHexEscapes(expr), countPHPConcatOperators(expr))
}
// phpExprCountsHaveHexNameBuild applies the hex-name-build thresholds: at
// least three hex escapes together with at least two concatenations.
func phpExprCountsHaveHexNameBuild(hexEscapes, concatOps int) bool {
	const (
		minHexEscapes = 3
		minConcatOps  = 2
	)
	if hexEscapes < minHexEscapes {
		return false
	}
	return concatOps >= minConcatOps
}
// countPHPStringHexEscapes counts the `\x` and `\u{` sequences that appear
// inside double-quoted string literals of expr. The literal text is
// lower-cased before counting, so `\X` and `\U{` are included. Other quoted
// literals (as recognised by isPHPQuote) are skipped without being counted.
func countPHPStringHexEscapes(expr string) int {
count := 0
for i := 0; i < len(expr); i++ {
if expr[i] != '"' {
// Not a double-quoted literal: skip other string kinds wholesale.
if isPHPQuote(expr[i]) {
i = skipPHPString(expr, i)
}
continue
}
end := skipPHPString(expr, i)
if end > i {
// NOTE(review): assumes skipPHPString returns the index of the closing
// quote, so expr[i+1:end] is the literal body; confirm.
literal := strings.ToLower(expr[i+1 : end])
count += countOccurrences(literal, `\x`)
count += countOccurrences(literal, `\u{`)
}
i = end
}
return count
}
// countPHPConcatOperators counts the '.' bytes in expr that lie outside
// quoted string literals. Every such dot is counted, whatever its role.
func countPHPConcatOperators(expr string) int {
	total := 0
	for pos := 0; pos < len(expr); pos++ {
		switch c := expr[pos]; {
		case isPHPQuote(c):
			// Resume after the literal; dots inside strings do not count.
			pos = skipPHPString(expr, pos)
		case c == '.':
			total++
		}
	}
	return total
}
// phpCallUserFuncTargetArgs returns the first argument of every standalone
// call_user_func_array(...) / call_user_func(...) call found on line, in the
// order: all call_user_func_array hits first, then all call_user_func hits.
//
// Function names are matched ASCII-case-insensitively. The lower-cased copy
// is produced byte by byte so it always has the same length as line: offsets
// found in the copy are used to index line, and strings.ToLower can change
// the byte length of non-ASCII text, which previously misaligned those
// offsets (missed calls, and an out-of-range index when the copy grew).
//
// A hit is ignored when it is part of a longer identifier, follows "->" or
// "::" (see isPHPFuncNameBoundary), is not followed by "(", or has no
// terminated first argument.
func phpCallUserFuncTargetArgs(line string) []string {
	folded := []byte(line)
	for i, c := range folded {
		if c >= 'A' && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	lower := string(folded)

	var args []string
	for _, name := range []string{"call_user_func_array", "call_user_func"} {
		searchFrom := 0
		for {
			rel := strings.Index(lower[searchFrom:], name)
			if rel < 0 {
				break
			}
			pos := searchFrom + rel
			next := pos + len(name)
			searchFrom = next
			if !isPHPFuncNameBoundary(line, pos, next) {
				continue
			}
			open := skipPHPSpaceString(line, next)
			if open >= len(line) || line[open] != '(' {
				continue
			}
			if arg, ok := firstPHPCallArgument(line, open+1); ok {
				args = append(args, arg)
			}
		}
	}
	return args
}
// isPHPFuncNameBoundary reports whether s[start:end] stands alone as a
// function name: the byte before it (if any) is neither an identifier
// character nor '>' or ':' (tails of "->" and "::"), and the byte after it
// (if any) is not an identifier character.
func isPHPFuncNameBoundary(s string, start, end int) bool {
	if start > 0 {
		switch before := s[start-1]; {
		case before == '>', before == ':', isIdentCont(before):
			return false
		}
	}
	if end >= len(s) {
		return true
	}
	return !isIdentCont(s[end])
}
// skipPHPSpaceString returns the index of the first byte at or after i in s
// that is not PHP whitespace, or an index past the end when none remains.
func skipPHPSpaceString(s string, i int) int {
	for ; i < len(s); i++ {
		if !isPHPSpace(s[i]) {
			break
		}
	}
	return i
}
// firstPHPCallArgument extracts the first argument of a call whose argument
// list begins at s[start] (the byte right after the opening parenthesis).
//
// Leading whitespace is skipped; the argument then runs up to the first
// top-level ',' or the ')' that closes the call. Quoted strings (single or
// double, with backslash escapes) and nested (), [] and {} groups are
// stepped over. It returns "", false when the input ends before the argument
// is terminated.
//
// Closing ']' and '}' now reduce the nesting depth. Previously only ')' did,
// so an argument such as "$a[0]" or a closure body left the depth stuck
// above zero and the argument was never recognised.
func firstPHPCallArgument(s string, start int) (string, bool) {
	start = skipPHPSpaceString(s, start)
	depth := 0
	var quote byte
	escaped := false
	for i := start; i < len(s); i++ {
		c := s[i]
		if quote != 0 {
			// Inside a string literal: only look for its unescaped terminator.
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == quote:
				quote = 0
			}
			continue
		}
		switch c {
		case '"', '\'':
			quote = c
		case '(', '[', '{':
			depth++
		case ']', '}':
			// Stray closers at depth 0 are ignored rather than going negative.
			if depth > 0 {
				depth--
			}
		case ')':
			if depth == 0 {
				return s[start:i], true
			}
			depth--
		case ',':
			if depth == 0 {
				return s[start:i], true
			}
		}
	}
	return "", false
}
// containsStandaloneFunc reports whether content holds at least one
// occurrence of funcCall (for example "exec(") that is a genuine call to that
// PHP function instead of a longer construct ending in the same text.
//
// An occurrence is discarded, and the search continues after it, when:
//
//   - it is the tail of a longer identifier ("doubleval(" vs "eval("): the
//     preceding byte is a letter, digit or underscore;
//   - it is a method call ("$this->DB->exec("): preceded by "->";
//   - it is a static call ("Foo::exec("): preceded by "::";
//   - it sits at offset 1 right after a lone '>' or ':' — the operator cannot
//     be confirmed with a single byte, so the conservative choice is to skip;
//   - it is a declaration ("function exec("): preceded by "function ".
//
// An occurrence at offset 0 always counts. The method/static guards exist
// because elFinder volume drivers call "$this->DB->exec(...)" (SQLite) on the
// same line as $_SERVER references, which used to produce false positives.
func containsStandaloneFunc(content, funcCall string) bool {
	const declPrefix = "function "
	for from := 0; ; {
		rel := strings.Index(content[from:], funcCall)
		if rel < 0 {
			return false
		}
		at := from + rel
		if at == 0 {
			return true
		}

		head := content[:at]
		before := content[at-1]
		isWordByte := before == '_' ||
			(before >= '0' && before <= '9') ||
			(before >= 'a' && before <= 'z') ||
			(before >= 'A' && before <= 'Z')

		switch {
		case isWordByte:
			// Embedded in a longer identifier.
		case at >= 2 && (strings.HasSuffix(head, "->") || strings.HasSuffix(head, "::")):
			// Method or static invocation.
		case at == 1 && (before == '>' || before == ':'):
			// Possibly a truncated "->" / "::" at the very start of the buffer.
		case strings.HasSuffix(head, declPrefix):
			// Declaration of a same-named function.
		default:
			return true
		}

		// Rejected: resume after this occurrence, if anything is left.
		after := at + len(funcCall)
		if after >= len(content) {
			return false
		}
		from = after
	}
}
// containsAny reports whether any element of strs contains any of substrs.
func containsAny(strs []string, substrs ...string) bool {
	if len(substrs) == 0 {
		return false
	}
	for i := range strs {
		haystack := strs[i]
		for j := range substrs {
			if strings.Index(haystack, substrs[j]) >= 0 {
				return true
			}
		}
	}
	return false
}
// benignPHPStubMaxScan caps how many bytes of a candidate stub the
// recogniser will read. Stub files in the wild (BackWPup folder.php
// caches at ~160 KB, WP "silence is golden" index.php at ~30 B, plugin
// 404-stub headers under 1 KB) fit comfortably under this bound;
// anything larger must surface for normal alerting rather than be
// accepted on faith.
const benignPHPStubMaxScan = 4 * 1024 * 1024
// IsBenignPHPStub reports whether the reachable code region of a PHP
// file consists only of whitespace and comments, or terminates with a
// no-argument die / exit / __halt_compiler before any other statement.
// Files matching either shape cannot execute attacker-controlled code
// via a web request: PHP either runs to EOF emitting nothing, or hits
// the terminator and stops with the remaining bytes unreachable.
//
// The recogniser is content-shape only -- it does not look at the path,
// filename, parent directory, or whether a plugin is installed. An
// attacker cannot bypass it by naming a payload to mimic a known-plugin
// file because the gate fails the moment any executable statement
// appears before a terminator. Conversely a legitimate plugin that
// writes a stub-shaped working file (BackWPup writes
// "<?php //<json>" for job state and "<?php\n//path1\n//path2..." for
// folder caches) is recognised regardless of where it puts the file.
//
// Other detectors -- signature scans, YARA, suspicious filename, the
// webshell name list -- still run on the file in their own pipelines.
// Only the path-only "anomalous PHP location" warning is suppressed
// for files that this recogniser accepts.
//
// At most benignPHPStubMaxScan bytes are read. The file is read in a loop
// until the buffer is full or the reader reports an error/EOF, because a
// single Read call is allowed to return fewer bytes than are available; a
// short read would otherwise mark the buffer incomplete and reject a
// legitimate comment-only stub. An unopenable or empty file is not a stub.
func IsBenignPHPStub(path string) bool {
	f, err := osFS.Open(path)
	if err != nil {
		return false
	}
	defer func() { _ = f.Close() }()

	buf := make([]byte, benignPHPStubMaxScan)
	n := 0
	for n < len(buf) {
		m, readErr := f.Read(buf[n:])
		n += m
		// Stop on EOF/any error; also stop on a zero-byte read so a
		// misbehaving reader cannot spin this loop.
		if readErr != nil || m == 0 {
			break
		}
	}
	if n == 0 {
		return false
	}

	// The buffer is complete only when the file is no larger than what was read.
	info, err := f.Stat()
	complete := err == nil && info.Size() <= int64(n)
	return IsBenignPHPStubBytesComplete(buf[:n], complete)
}
// IsBenignPHPStubBytes is the buffer-only variant. The realtime fanotify
// path uses it on the bytes it already read from the file descriptor;
// IsBenignPHPStub provides the path-based entry point for the polled
// fileindex scan. Both rely on the same parser so realtime and scheduled
// scans agree on which files are stubs.
//
// buf is treated as the complete file content: this is exactly
// IsBenignPHPStubBytesComplete(buf, true).
//
// The parser tokenises the leading region of the buffer:
//
// - Optional UTF-8 BOM and whitespace, then the literal "<?php" opener.
// The short-echo opener "<?=" is rejected because it emits output.
// A "<?phpfoo" run-together opener is rejected because PHP requires
// whitespace (or EOF) after the tag.
// - Repeatedly accept whitespace, line comments ("//..." or "#..." up to
// newline or "?>"), and balanced block comments ("/* ... */"). A "/*"
// without a matching "*/" inside the scanned window is rejected -- we
// cannot prove the rest of the file is comment.
// - Accept the no-argument forms of die, exit, and __halt_compiler as
// terminators. Once seen, the rest of the buffer is treated as
// unreachable.
// - Reject any closing "?>" tag (would allow HTML escape and a later
// "<?php" re-entry that this gate does not analyse).
// - Reject any other identifier (return, if, system, eval, function,
// class, ...) and any stray punctuation ("$", "(", "=", ";", ...).
// Those are statements we cannot prove benign.
// - If the loop reaches EOF in a complete buffer having only seen
// whitespace and comments, accept: PHP outputs nothing and executes
// nothing.
func IsBenignPHPStubBytes(buf []byte) bool {
return IsBenignPHPStubBytesComplete(buf, true)
}
// IsBenignPHPStubBytesComplete runs the stub recogniser over buf. complete
// states whether buf holds the entire file. A file made only of whitespace
// and comments is accepted only when complete is true; a no-argument
// terminator (die / exit / __halt_compiler) is accepted even in a truncated
// buffer because PHP never reaches the bytes after it.
//
// The buffer must start with an optional UTF-8 BOM, optional whitespace and
// the "<?php" opener followed by whitespace or end of input. Anything other
// than whitespace, a comment or a terminator identifier after the opener --
// including a closing "?>" tag and an unterminated block comment -- rejects.
func IsBenignPHPStubBytesComplete(buf []byte, complete bool) bool {
	buf = bytes.TrimPrefix(buf, []byte{0xEF, 0xBB, 0xBF})

	pos := 0
	for pos < len(buf) && isPHPSpace(buf[pos]) {
		pos++
	}

	const opener = "<?php"
	if !bytes.HasPrefix(buf[pos:], []byte(opener)) {
		return false
	}
	pos += len(opener)
	// "<?phpfoo" is not a valid opener: the tag must be followed by whitespace.
	if pos < len(buf) && !isPHPSpace(buf[pos]) {
		return false
	}

	for pos < len(buf) {
		c := buf[pos]
		hasNext := pos+1 < len(buf)
		switch {
		case isPHPSpace(c):
			pos++
		case c == '#', c == '/' && hasNext && buf[pos+1] == '/':
			pos = skipPHPLineComment(buf, pos)
		case c == '/' && hasNext && buf[pos+1] == '*':
			bodyStart := pos + 2
			closeAt := bytes.Index(buf[bodyStart:], []byte("*/"))
			if closeAt < 0 {
				// Unterminated block comment: the rest cannot be proven inert.
				return false
			}
			pos = bodyStart + closeAt + 2
		case isIdentStart(c):
			wordEnd := pos
			for wordEnd < len(buf) && isIdentCont(buf[wordEnd]) {
				wordEnd++
			}
			keyword := strings.ToLower(string(buf[pos:wordEnd]))
			return isNoArgPHPTerminator(buf, wordEnd, keyword, complete)
		default:
			// Closing "?>" tag or any other punctuation: not provably benign.
			return false
		}
	}
	return complete
}
// isPHPSpace reports whether c is one of the whitespace bytes the stub
// parser skips: space, tab, newline, carriage return, vertical tab, form feed.
func isPHPSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\n', '\r', '\v', '\f':
		return true
	default:
		return false
	}
}
// isIdentStart reports whether c may begin an identifier: an ASCII letter or
// an underscore.
func isIdentStart(c byte) bool {
	switch {
	case c >= 'a' && c <= 'z':
		return true
	case c >= 'A' && c <= 'Z':
		return true
	}
	return c == '_'
}
// isIdentCont reports whether c may continue an identifier: anything
// isIdentStart accepts, or an ASCII digit.
func isIdentCont(c byte) bool {
	if isIdentStart(c) {
		return true
	}
	return '0' <= c && c <= '9'
}
// skipPHPSpace returns the index of the first byte at or after i in buf that
// is not PHP whitespace, or an index past the end when none remains.
func skipPHPSpace(buf []byte, i int) int {
	for ; i < len(buf); i++ {
		if !isPHPSpace(buf[i]) {
			break
		}
	}
	return i
}
// skipPHPLineComment advances from i to the end of a line comment and
// returns the index of the terminating newline or of the '?' of a closing
// "?>" tag, whichever comes first. The terminator itself is not consumed.
// When neither is found the result is len(buf).
func skipPHPLineComment(buf []byte, i int) int {
	for ; i < len(buf); i++ {
		switch buf[i] {
		case '\n':
			return i
		case '?':
			if i+1 < len(buf) && buf[i+1] == '>' {
				return i
			}
		}
	}
	return i
}
// isNoArgPHPTerminator reports whether word, whose last byte precedes buf[i],
// is a no-argument script terminator.
//
// word must be "die", "exit" or "__halt_compiler". For die/exit the keyword
// may be followed directly by ";" or "?>", or by end of input when the buffer
// is complete. Otherwise -- and always for __halt_compiler -- an empty "()"
// must follow, and the statement must then end as judged by
// phpTerminatorStatementEnds.
func isNoArgPHPTerminator(buf []byte, i int, word string, complete bool) bool {
	switch word {
	case "die", "exit", "__halt_compiler":
	default:
		return false
	}

	i = skipPHPSpace(buf, i)
	if word != "__halt_compiler" {
		switch {
		case i >= len(buf):
			return complete
		case buf[i] == ';':
			return true
		case buf[i] == '?' && i+1 < len(buf) && buf[i+1] == '>':
			return true
		case buf[i] != '(':
			// An argument or operator follows: not a bare terminator.
			return false
		}
	}

	after, ok := consumeEmptyPHPParens(buf, i)
	if !ok {
		return false
	}
	return phpTerminatorStatementEnds(buf, after, complete)
}
// consumeEmptyPHPParens expects "(" at buf[i], optional whitespace and ")".
// On success it returns the index after the ")" with trailing whitespace
// skipped, and true. On failure it returns the index where matching stopped
// and false.
func consumeEmptyPHPParens(buf []byte, i int) (int, bool) {
	if i >= len(buf) || buf[i] != '(' {
		return i, false
	}
	inner := skipPHPSpace(buf, i+1)
	if inner < len(buf) && buf[inner] == ')' {
		return skipPHPSpace(buf, inner+1), true
	}
	return inner, false
}
// phpTerminatorStatementEnds reports whether the statement ends at or after
// buf[i] once whitespace is skipped: with ";", with a closing "?>" tag, or --
// only when the buffer is complete -- with end of input.
func phpTerminatorStatementEnds(buf []byte, i int, complete bool) bool {
	i = skipPHPSpace(buf, i)
	switch {
	case i >= len(buf):
		return complete
	case buf[i] == ';':
		return true
	case buf[i] == '?':
		return i+1 < len(buf) && buf[i+1] == '>'
	}
	return false
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckPHPConfigChanges monitors per-account .user.ini files for PHP
// configuration changes that weaken security (disabling disable_functions,
// enabling dangerous functions).
// This runs as a deep check. The fanotify watcher also catches .user.ini writes in real-time.
//
// NOTE(review): earlier documentation also mentioned .htaccess, but this
// function only reads .user.ini paths.
//
// For every home directory it looks at public_html/.user.ini plus .user.ini
// in each other non-hidden top-level directory (addon domains), skipping
// mail, etc, logs, ssl and tmp. A content hash per file is kept in store
// under "_phpini:<path>"; a finding is produced only when a previously seen
// file changed and analyzePHPINI reports dangerous settings. The first
// sighting of a file just records its hash.
func CheckPHPConfigChanges(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
var findings []alert.Finding
// Best effort: an enumeration error leaves homeDirs empty and yields no findings.
homeDirs, _ := GetScanHomeDirs(ctx)
for _, homeEntry := range homeDirs {
if !homeEntry.IsDir() {
continue
}
user := homeEntry.Name()
// Check .user.ini in public_html and addon domains
iniPaths := []string{
filepath.Join("/home", user, "public_html", ".user.ini"),
}
subDirs, _ := osFS.ReadDir(filepath.Join("/home", user))
for _, sd := range subDirs {
if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
iniPath := filepath.Join("/home", user, sd.Name(), ".user.ini")
if _, err := osFS.Stat(iniPath); err == nil {
iniPaths = append(iniPaths, iniPath)
}
}
}
for _, iniPath := range iniPaths {
// Hash-based change detection
hash, err := hashFileContent(iniPath)
if err != nil {
// Missing or unreadable file (public_html/.user.ini is not pre-checked).
continue
}
key := "_phpini:" + iniPath
prev, exists := store.GetRaw(key)
// Record the new hash before deciding, so the same change alerts only once.
store.SetRaw(key, hash)
if !exists || prev == hash {
continue
}
// File changed - analyze content for dangerous settings
data, err := osFS.ReadFile(iniPath)
if err != nil {
continue
}
// analyzePHPINI matches lower-case directive names and values.
content := strings.ToLower(string(data))
// Check for dangerous PHP settings
dangerous := analyzePHPINI(content)
if len(dangerous) > 0 {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "php_config_change",
Message: fmt.Sprintf("Dangerous PHP config change: %s (user: %s)", iniPath, user),
Details: fmt.Sprintf("Dangerous settings:\n- %s", strings.Join(dangerous, "\n- ")),
})
}
}
}
return findings
}
// analyzePHPINI inspects the text of a PHP ini file and returns one
// human-readable message per security-weakening setting, in this order:
//
//  1. every disable_functions directive whose value is empty or "none";
//  2. for the first disable_functions directive, each dangerous function
//     (exec, system, passthru, shell_exec, popen, proc_open, pcntl_exec) that
//     is not listed in it;
//  3. every allow_url_include directive set to a true value (1/on/true/yes);
//  4. every open_basedir directive that is empty or "/".
//
// The result is nil when nothing dangerous is found.
//
// Lines are parsed into key/value directives first. Blank lines, comment
// lines (";" or "#") and section headers are ignored, directive names are
// compared case-insensitively, surrounding quotes are removed from values,
// and a trailing "; comment" after a value is dropped. This fixes two
// defects of the earlier substring matching: "exec" was considered disabled
// whenever "shell_exec" or "pcntl_exec" appeared in the list, and
// allow_url_include was reported for commented-out lines or when an
// unrelated "1"/"on" appeared anywhere on the line.
func analyzePHPINI(content string) []string {
	type directive struct {
		key      string
		value    string
		hasValue bool // false for a bare key without "="
	}

	// cleanValue trims a raw right-hand side, removes an inline comment and
	// strips one pair of matching surrounding quotes.
	cleanValue := func(raw string) string {
		v := strings.TrimSpace(raw)
		if v == "" {
			return v
		}
		if q := v[0]; q == '"' || q == '\'' {
			if end := strings.IndexByte(v[1:], q); end >= 0 {
				return v[1 : 1+end]
			}
			return v
		}
		if semi := strings.IndexByte(v, ';'); semi >= 0 {
			v = strings.TrimSpace(v[:semi])
		}
		return v
	}

	var directives []directive
	for _, raw := range strings.Split(content, "\n") {
		line := strings.TrimSpace(raw)
		if line == "" || line[0] == ';' || line[0] == '#' || line[0] == '[' {
			continue
		}
		d := directive{key: line}
		if eq := strings.IndexByte(line, '='); eq >= 0 {
			d.key = strings.TrimSpace(line[:eq])
			d.value = strings.ToLower(cleanValue(line[eq+1:]))
			d.hasValue = true
		}
		d.key = strings.ToLower(d.key)
		directives = append(directives, d)
	}

	var dangerous []string

	// disable_functions being cleared.
	for _, d := range directives {
		if d.key != "disable_functions" || !d.hasValue {
			continue
		}
		if d.value == "" || d.value == "none" {
			dangerous = append(dangerous, "disable_functions cleared (all PHP functions enabled)")
		}
	}

	// Dangerous functions missing from the (first) disable list.
	dangerousFuncs := []string{
		"exec", "system", "passthru", "shell_exec",
		"popen", "proc_open", "pcntl_exec",
	}
	for _, d := range directives {
		if d.key != "disable_functions" {
			continue
		}
		disabled := make(map[string]struct{})
		for _, name := range strings.Split(d.value, ",") {
			disabled[strings.TrimSpace(name)] = struct{}{}
		}
		for _, fn := range dangerousFuncs {
			if _, ok := disabled[fn]; !ok {
				dangerous = append(dangerous, fmt.Sprintf("%s not in disable_functions", fn))
			}
		}
		break
	}

	// allow_url_include being enabled.
	for _, d := range directives {
		if d.key != "allow_url_include" || !d.hasValue {
			continue
		}
		switch d.value {
		case "1", "on", "true", "yes":
			dangerous = append(dangerous, "allow_url_include enabled (remote code inclusion)")
		}
	}

	// open_basedir being removed.
	for _, d := range directives {
		if d.key != "open_basedir" || !d.hasValue {
			continue
		}
		if d.value == "" || d.value == "/" {
			dangerous = append(dangerous, "open_basedir cleared or set to / (no restriction)")
		}
	}

	return dangerous
}
package checks
import (
"path/filepath"
"strings"
)
// executablePHPExtensions are the extensions a stock PHP-capable web server
// (Apache mod_php / PHP-FPM via EasyApache4, LiteSpeed LSAPI, Nginx + php-fpm)
// routes to the PHP interpreter by default. Any file with one of these names
// can execute PHP, so a content scan that skipped them would let a webshell
// hide behind a non-".php" name. ".phps" is deliberately excluded: the stock
// handler renders it as highlighted source, it does not execute. Lowercase,
// leading dot.
var executablePHPExtensions = []string{
".php", ".php2", ".php3", ".php4", ".php5", ".php6", ".php7", ".php8",
".phtml", ".pht",
}
// IsExecutablePHPName reports whether a (lowercased) filename has an extension
// that a stock PHP handler executes. Shared by the realtime fanotify path and
// the periodic content scanners so the two never drift apart. It is a coarse,
// default-deny gate for content analysis only; per-directory .htaccess handler
// remappings are layered on top via phpHandlerOverlay.
//
// nameLower must already be lower-case: the comparison against the extension
// list is a plain suffix match. This is the exported entry point for the
// package-private isExecutablePHPName.
func IsExecutablePHPName(nameLower string) bool {
return isExecutablePHPName(nameLower)
}
// isExecutablePHPName reports whether nameLower ends with one of the entries
// in executablePHPExtensions. The caller is responsible for lower-casing.
func isExecutablePHPName(nameLower string) bool {
	for i := range executablePHPExtensions {
		if strings.HasSuffix(nameLower, executablePHPExtensions[i]) {
			return true
		}
	}
	return false
}
// phpHandlerOverlay carries the extra PHP execution mappings discovered from
// .htaccess files while walking a directory tree. Apache merges a parent
// directory's directives into its children, so the overlay accumulates down
// the recursion: a mapping declared in a parent applies to every descendant.
type phpHandlerOverlay struct {
// exts holds extra ".ext" entries (lowercase, leading dot) that a local
// AddHandler/AddType maps to a PHP handler, e.g. "AddHandler
// application/x-httpd-php .inc".
exts map[string]struct{}
// names holds exact lowercase basenames matched by a <Files> container.
names map[string]struct{}
// scanAll is set when a SetHandler/ForceType routes the PHP interpreter
// for the whole directory with no extension filter. Every file in the
// subtree then executes as PHP and must be content-analysed.
scanAll bool
}
// active reports whether the overlay carries any mapping at all: a
// directory-wide handler, at least one extension, or at least one file name.
func (o phpHandlerOverlay) active() bool {
	if o.scanAll {
		return true
	}
	return len(o.exts) > 0 || len(o.names) > 0
}
// executes reports whether a file whose lower-cased name is nameLower runs
// as PHP under this overlay. That is the case when the overlay applies to the
// whole directory, the name has a stock PHP extension, the exact name was
// mapped by a <Files> container, or the name ends in an .htaccess-mapped
// extension.
func (o phpHandlerOverlay) executes(nameLower string) bool {
	switch {
	case o.scanAll, isExecutablePHPName(nameLower):
		return true
	}
	if _, named := o.names[nameLower]; named {
		return true
	}
	for suffix := range o.exts {
		if strings.HasSuffix(nameLower, suffix) {
			return true
		}
	}
	return false
}
// mergeHtaccess returns an overlay combining the receiver (inherited from the
// parent directory) with the PHP handler directives parsed from content, the
// raw bytes of a directory's .htaccess file. When content is empty or holds
// no PHP handler directive the receiver is returned unchanged. The receiver's
// maps are never written to, so sibling directories do not see each other's
// mappings.
func (o phpHandlerOverlay) mergeHtaccess(content []byte) phpHandlerOverlay {
	if len(content) == 0 {
		return o
	}
	found := parsePHPHandlerDirectives(content)
	if !found.active() {
		return o
	}

	// union builds a fresh set from both inputs; nil when both are empty.
	union := func(a, b map[string]struct{}) map[string]struct{} {
		if len(a) == 0 && len(b) == 0 {
			return nil
		}
		out := make(map[string]struct{}, len(a)+len(b))
		for k := range a {
			out[k] = struct{}{}
		}
		for k := range b {
			out[k] = struct{}{}
		}
		return out
	}

	return phpHandlerOverlay{
		exts:    union(o.exts, found.exts),
		names:   union(o.names, found.names),
		scanAll: o.scanAll || found.scanAll,
	}
}
// parsePHPHandlerDirectives extracts extension-to-PHP mappings from .htaccess
// content. It recognises:
//
// AddHandler <php-handler> .ext [.ext...]
// AddType <php-mime> .ext [.ext...]
// SetHandler <php-handler> (no extension -> whole directory)
// ForceType <php-mime> (no extension -> whole directory)
//
// A directive counts as PHP when its handler/MIME token names PHP directly or
// routes through proxy-fcgi. x-httpd-php-source is excluded because it renders
// highlighted source instead of executing. Matching is case-insensitive.
//
// <Files> and <FilesMatch> containers are tracked as a stack of contexts. A
// handler directive without its own extensions that appears inside such a
// container contributes the container's extensions/names instead of marking
// the whole directory. Physical lines are first joined on trailing-backslash
// continuations, and blank lines and "#" comment lines are skipped.
func parsePHPHandlerDirectives(content []byte) phpHandlerOverlay {
var overlay phpHandlerOverlay
// contexts is the stack of currently open <Files>/<FilesMatch> containers.
var contexts []phpHandlerOverlay
for _, logical := range joinHtaccessContinuations(strings.Split(string(content), "\n")) {
line := strings.TrimSpace(logical.text)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if ctx, ok := openPHPHandlerContext(line); ok {
contexts = append(contexts, ctx)
continue
}
if closesPHPHandlerContext(line) {
// Pop the innermost container; a stray closer is ignored.
if len(contexts) > 0 {
contexts = contexts[:len(contexts)-1]
}
continue
}
fields := apacheDirectiveFields(line)
if len(fields) < 2 {
continue
}
directive := strings.ToLower(fields[0])
handler := fields[1]
switch directive {
case "addhandler", "addtype":
if !handlerIsPHP(handler) {
continue
}
// Remaining fields are extensions.
addExtensions(&overlay, normalizedExts(fields[2:]))
// Without explicit extensions, fall back to the enclosing containers.
if len(fields) == 2 {
mergeContext(&overlay, contexts)
}
case "sethandler", "forcetype":
if !handlerIsPHP(handler) {
continue
}
exts := normalizedExts(fields[2:])
if len(exts) > 0 {
addExtensions(&overlay, exts)
continue
}
// Inside a container that yielded extensions or names: scope to those.
if mergeContext(&overlay, contexts) {
continue
}
// Inside a container that yielded nothing usable: do not widen to the
// whole directory.
if len(contexts) > 0 {
continue
}
// No extension or file filter: the handler applies to every
// file in this directory.
overlay.scanAll = true
}
}
return overlay
}
// handlerIsPHP reports whether an Apache handler or MIME token routes files
// to an executing PHP backend. The token is trimmed, unquoted and
// lower-cased before matching.
func handlerIsPHP(token string) bool {
	name := strings.TrimSpace(token)
	name = strings.Trim(name, `"'`)
	name = strings.ToLower(name)

	switch {
	case strings.Contains(name, "php-source"):
		// The source viewer renders highlighted source instead of executing.
		return false
	case strings.Contains(name, "php"):
		return true
	}

	// cPanel/PHP-FPM .htaccess wiring can use a custom socket alias whose path
	// does not include the literal "php". A proxy-fcgi handler still routes
	// matching files to an executable backend, so remapped extensions must be
	// treated as PHP-executed for scanning and .htaccess alerts.
	if !strings.HasPrefix(name, "proxy:") {
		return false
	}
	return strings.Contains(name, "fcgi://")
}
// normalizeExt converts an .htaccess extension token into a lower-case
// extension with a leading dot. Surrounding whitespace and quotes and one
// leading dot are removed first. It returns "" for anything that is not a
// plain extension, i.e. that is empty or contains characters other than
// a-z, 0-9, '_' and '-' (a stray flag or MIME fragment, for example).
func normalizeExt(token string) string {
	bare := strings.TrimSpace(token)
	bare = strings.Trim(bare, `"'`)
	bare = strings.TrimPrefix(strings.ToLower(bare), ".")
	if bare == "" {
		return ""
	}
	for _, r := range bare {
		switch {
		case r >= 'a' && r <= 'z':
		case r >= '0' && r <= '9':
		case r == '_', r == '-':
		default:
			return ""
		}
	}
	return "." + bare
}
// htaccessLogicalLine is one Apache directive after joining physical
// continuation lines. text is the joined directive (continuation backslashes
// removed); lines are the original physical lines it spans, so a per-line
// rewrite can drop or keep them together. start is the 0-based index of the
// first physical line.
type htaccessLogicalLine struct {
text string
lines []string
start int
}
// joinHtaccessContinuations groups physical .htaccess lines into logical
// directives, honouring Apache's trailing-backslash continuation: a line
// that htaccessContinuationBody reports as continuing is joined with the
// next one. Without this, a directive split as "AddHandler ...php \" +
// ".jpg" would read as two harmless physical lines and every per-line
// scanner would miss the remap.
//
// Each result records the joined text, the physical lines it spans, and the
// 0-based index of the first of them.
func joinHtaccessContinuations(physical []string) []htaccessLogicalLine {
	var logical []htaccessLogicalLine
	idx := 0
	for idx < len(physical) {
		first := idx
		var joined strings.Builder
		var members []string
		for {
			current := physical[idx]
			members = append(members, current)
			body, more := htaccessContinuationBody(current, idx < len(physical)-1)
			joined.WriteString(body)
			if !more {
				break
			}
			// A continuation is only reported when a next line exists.
			idx++
		}
		logical = append(logical, htaccessLogicalLine{
			text:  joined.String(),
			lines: members,
			start: first,
		})
		idx++
	}
	return logical
}
// htaccessContinuationBody returns the text of one physical line as it
// should appear in the joined directive, and whether the directive continues
// on the next line.
//
// Trailing carriage returns are removed first, because on CRLF input the
// continuation backslash sits before the "\r". The line continues only when
// a next line exists, it ends in a backslash, and it is not a container
// open/close line that keeps its backslash; in that case the backslash is
// dropped from the returned text. A backslash on the final line stays
// literal, matching Apache.
func htaccessContinuationBody(line string, hasNext bool) (string, bool) {
	body := strings.TrimRight(line, "\r")
	continues := hasNext &&
		strings.HasSuffix(body, `\`) &&
		!htaccessContainerLineKeepsTrailingBackslash(body)
	if !continues {
		return body, false
	}
	return body[:len(body)-1], true
}
// htaccessContainerLineKeepsTrailingBackslash reports whether body, with one
// trailing backslash and surrounding whitespace removed, contains ">" and is
// a <Files>/<FilesMatch> opening or closing line. Such lines are not treated
// as continued.
func htaccessContainerLineKeepsTrailingBackslash(body string) bool {
	trimmed := strings.TrimSpace(strings.TrimSuffix(body, `\`))
	if strings.IndexByte(trimmed, '>') < 0 {
		return false
	}
	_, opens := openPHPHandlerContext(trimmed)
	return opens || closesPHPHandlerContext(trimmed)
}
// apacheDirectiveFields splits line on whitespace and drops the first field
// that starts with "#" together with everything after it.
func apacheDirectiveFields(line string) []string {
	fields := strings.Fields(line)
	cut := len(fields)
	for i := range fields {
		if fields[i][0] == '#' {
			cut = i
			break
		}
	}
	return fields[:cut]
}
// normalizedExts runs every token through normalizeExt and returns the
// non-empty results in their original order; nil when none qualifies.
func normalizedExts(tokens []string) []string {
	var exts []string
	for i := range tokens {
		ext := normalizeExt(tokens[i])
		if ext == "" {
			continue
		}
		exts = append(exts, ext)
	}
	return exts
}
// addExtensions records exts in overlay.exts, allocating the set on first
// use. Extensions that isExecutablePHPName already accepts are left out.
func addExtensions(overlay *phpHandlerOverlay, exts []string) {
	for _, ext := range exts {
		if isExecutablePHPName("x" + ext) {
			// Already a stock PHP extension.
			continue
		}
		if overlay.exts == nil {
			overlay.exts = make(map[string]struct{}, len(exts))
		}
		overlay.exts[ext] = struct{}{}
	}
}
// mergeContext copies the extensions and names of every open container
// context into overlay and reports whether anything was copied.
func mergeContext(overlay *phpHandlerOverlay, contexts []phpHandlerOverlay) bool {
	// absorb adds src into *dst, allocating *dst on demand; false when src is empty.
	absorb := func(dst *map[string]struct{}, src map[string]struct{}) bool {
		if len(src) == 0 {
			return false
		}
		if *dst == nil {
			*dst = make(map[string]struct{}, len(src))
		}
		for key := range src {
			(*dst)[key] = struct{}{}
		}
		return true
	}

	changed := false
	for i := range contexts {
		if absorb(&overlay.exts, contexts[i].exts) {
			changed = true
		}
		if absorb(&overlay.names, contexts[i].names) {
			changed = true
		}
	}
	return changed
}
// openPHPHandlerContext recognises a line that opens a <FilesMatch> or
// <Files> container (case-insensitive prefix match) and returns the overlay
// describing what the container selects. The boolean is false for any other
// line.
//
//   - <FilesMatch pattern>: extensions derived from the pattern;
//   - <Files name> with an empty name or one containing "/": empty overlay;
//   - <Files name> with a wildcard ("*", "?", "["): the name's extension;
//   - <Files name> otherwise: the exact lower-cased name.
func openPHPHandlerContext(line string) (phpHandlerOverlay, bool) {
	lower := strings.ToLower(line)
	if strings.HasPrefix(lower, "<filesmatch") {
		return overlayForFilesMatch(apacheContainerArgument(line)), true
	}
	if !strings.HasPrefix(lower, "<files") {
		return phpHandlerOverlay{}, false
	}

	name := strings.ToLower(strings.TrimSpace(apacheContainerArgument(line)))
	switch {
	case name == "" || strings.Contains(name, "/"):
		return phpHandlerOverlay{}, true
	case strings.ContainsAny(name, "*?["):
		var wildcard phpHandlerOverlay
		addExtensions(&wildcard, []string{filepath.Ext(name)})
		return wildcard, true
	}
	return phpHandlerOverlay{names: map[string]struct{}{name: {}}}, true
}
// closesPHPHandlerContext reports whether line closes a <Files> or
// <FilesMatch> container, matched case-insensitively by prefix.
func closesPHPHandlerContext(line string) bool {
	// "</filesmatch" itself starts with "</files", so one prefix covers both.
	return strings.HasPrefix(strings.ToLower(line), "</files")
}
// apacheContainerArgument returns the argument of a container line such as
// `<Files "x.php">`: the text between the first space and the last '>',
// trimmed, with one pair of matching surrounding quotes removed. It returns
// "" when there is no space or no '>' after it.
func apacheContainerArgument(line string) string {
	space := strings.IndexByte(line, ' ')
	closing := strings.LastIndexByte(line, '>')
	if space < 0 || closing <= space {
		return ""
	}
	arg := strings.TrimSpace(line[space:closing])
	if n := len(arg); n >= 2 && arg[0] == arg[n-1] && (arg[0] == '"' || arg[0] == '\'') {
		return arg[1 : n-1]
	}
	return arg
}
// overlayForFilesMatch builds the context overlay for a <FilesMatch>
// container from the extensions found in its pattern.
func overlayForFilesMatch(pattern string) phpHandlerOverlay {
	var overlay phpHandlerOverlay
	exts := extensionsFromFilesMatchPattern(pattern)
	addExtensions(&overlay, exts)
	return overlay
}
// extensionsFromFilesMatchPattern extracts the file extensions a
// <FilesMatch> regular expression refers to. It is a heuristic, not a regex
// engine: it looks for every escaped dot (`\.`) and takes what follows it:
//
//   - a parenthesised group, split on "|" (a leading "?:" is dropped), e.g.
//     `\.(inc|txt)$` yields ".inc" and ".txt";
//   - otherwise the run of a-z, 0-9, '_' and '-' characters, e.g.
//     `shell\.jpg$` yields ".jpg".
//
// The pattern is lower-cased first. Results are normalised with
// normalizeExt, de-duplicated, and returned in order of first appearance.
//
// After a plain extension the scan resumes on the byte that ended it. The
// earlier code skipped that byte, so in a pattern such as `\.foo\.bar$` the
// backslash of the second `\.` was jumped over and ".bar" was never
// reported.
func extensionsFromFilesMatchPattern(pattern string) []string {
	pattern = strings.ToLower(pattern)
	seen := make(map[string]struct{})
	var exts []string
	add := func(ext string) {
		ext = normalizeExt(ext)
		if ext == "" {
			return
		}
		if _, ok := seen[ext]; ok {
			return
		}
		seen[ext] = struct{}{}
		exts = append(exts, ext)
	}

	for i := 0; i+1 < len(pattern); i++ {
		if pattern[i] != '\\' || pattern[i+1] != '.' {
			continue
		}
		j := i + 2
		if j < len(pattern) && pattern[j] == '(' {
			if end := strings.IndexByte(pattern[j+1:], ')'); end >= 0 {
				group := strings.TrimPrefix(pattern[j+1:j+1+end], "?:")
				for _, part := range strings.Split(group, "|") {
					add(part)
				}
			}
			continue
		}
		start := j
		for j < len(pattern) {
			c := pattern[j]
			alnum := (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')
			if !alnum && c != '_' && c != '-' {
				break
			}
			j++
		}
		add(pattern[start:j])
		// The loop's i++ lands on j, the first byte not yet examined. j is at
		// least i+2, so the scan always moves forward.
		i = j - 1
	}
	return exts
}
// phpPathExecutes reports whether the file at path, whose lower-cased base
// name is nameLower, executes as PHP. Stock PHP extensions answer
// immediately; otherwise the .htaccess files of the directories returned by
// htaccessAncestorDirs are merged in that order and the resulting overlay
// decides. Unreadable or missing .htaccess files are skipped.
func phpPathExecutes(path, nameLower string) bool {
	if isExecutablePHPName(nameLower) {
		return true
	}
	var overlay phpHandlerOverlay
	for _, dir := range htaccessAncestorDirs(path) {
		data, err := osFS.ReadFile(filepath.Join(dir, ".htaccess"))
		if err != nil {
			continue
		}
		overlay = overlay.mergeHtaccess(data)
	}
	return overlay.executes(nameLower)
}
// htaccessAncestorDirs lists the directories whose .htaccess files apply to
// path, ordered from the outermost ancestor down to the file's own
// directory. The upward walk stops at the first directory accepted by
// stopHtaccessAncestorWalk (which is included) or at the filesystem root. A
// path without a directory component yields nil.
func htaccessAncestorDirs(path string) []string {
	current := filepath.Clean(filepath.Dir(path))
	if current == "." {
		return nil
	}

	// Collect leaf-to-root first.
	var upward []string
	for {
		upward = append(upward, current)
		if stopHtaccessAncestorWalk(current) {
			break
		}
		parent := filepath.Dir(current)
		if parent == current {
			break
		}
		current = parent
	}

	// Return root-to-leaf so parent directives are merged before children.
	ordered := make([]string, len(upward))
	for i, dir := range upward {
		ordered[len(upward)-1-i] = dir
	}
	return ordered
}
// stopHtaccessAncestorWalk reports whether dir has the exact form
// "/home/<name>": the "/home/" prefix followed by one non-empty path
// component with no further slash.
func stopHtaccessAncestorWalk(dir string) bool {
	const homeRoot = "/home/"
	if !strings.HasPrefix(dir, homeRoot) {
		return false
	}
	account := dir[len(homeRoot):]
	if account == "" {
		return false
	}
	return strings.IndexByte(account, '/') < 0
}
package checks
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
neturl "net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// wpCLIFlags is the command prefix (environment assignment, the `wp` binary
// and global flags) that CSM puts in front of every wp-cli invocation. It
// ends with a space so callers can append the subcommand directly.
//
// WP_CLI_PHP_ARGS disables PHP's display_errors/error_reporting for the
// bootstrap so stray Notices/Warnings/Deprecated messages from the site
// don't get emitted (they'd only be a problem on stderr, which we already
// discard, but suppressing them also avoids exit-255 on strict hosts that
// promote warnings to errors).
//
// --skip-plugins and --skip-themes make wp-cli enumerate plugins from the
// filesystem without loading them. That removes the biggest source of log
// noise: one broken plugin (e.g. a PHP Parse error in litespeed-cache on a
// site nobody updated for years) would otherwise crash the whole `wp plugin
// list` call with exit 255, or spew backtraces from plugins that call
// wp_redirect() during admin bootstrap. Skipping loads gives us the list
// plus update_version unchanged.
const wpCLIFlags = `WP_CLI_PHP_ARGS='-d display_errors=0 -d error_reporting=0' wp --skip-plugins --skip-themes `
// wpOrgHTTPClient is the shared HTTP client used for WordPress.org API
// lookups. Its 10-second Timeout bounds each request as a whole.
var wpOrgHTTPClient = &http.Client{Timeout: 10 * time.Second}
// wpOrgResponse is the subset of the WordPress.org plugin information
// response that this package decodes.
type wpOrgResponse struct {
// Slug is decoded from "slug"; it is not read by the code visible here.
Slug string `json:"slug"`
// Version becomes PluginInfo.LatestVersion.
Version string `json:"version"`
// Tested becomes PluginInfo.TestedUpTo.
Tested string `json:"tested"`
// Error is non-empty when the API reports a failure (e.g. "Plugin not found.").
Error string `json:"error"`
}
// parseWPOrgPluginResponse decodes a WordPress.org plugin information JSON
// body into a store.PluginInfo whose LastChecked is the current Unix time.
//
// An error is returned when body is not valid JSON, or when the decoded
// response carries a non-empty "error" field (e.g. "Plugin not found."). In
// both cases the returned PluginInfo is the zero value.
func parseWPOrgPluginResponse(body []byte) (store.PluginInfo, error) {
	var decoded wpOrgResponse
	err := json.Unmarshal(body, &decoded)
	switch {
	case err != nil:
		return store.PluginInfo{}, fmt.Errorf("wporg: invalid JSON: %w", err)
	case decoded.Error != "":
		return store.PluginInfo{}, fmt.Errorf("wporg: %s", decoded.Error)
	}
	info := store.PluginInfo{
		LatestVersion: decoded.Version,
		TestedUpTo:    decoded.Tested,
		LastChecked:   time.Now().Unix(),
	}
	return info, nil
}
// fetchWPOrgPluginInfo queries the WordPress.org plugin information API for
// slug and returns the parsed PluginInfo.
//
// The request is bound to ctx and sent through wpOrgHTTPClient. A non-200
// status is reported without reading the body, and the body read is capped so
// a misbehaving or spoofed endpoint cannot make the daemon buffer an
// arbitrarily large response. Errors are returned for request construction,
// transport failure, unexpected status, oversized or unreadable bodies, and
// anything parseWPOrgPluginResponse rejects.
func fetchWPOrgPluginInfo(ctx context.Context, slug string) (store.PluginInfo, error) {
	// Upper bound on the bytes buffered from the response body.
	const maxResponseBytes = 8 << 20

	endpoint := "https://api.wordpress.org/plugins/info/1.2/?action=plugin_information" +
		"&request[slug]=" + neturl.QueryEscape(slug) +
		"&request[fields][version]=1&request[fields][tested]=1"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: building request: %w", err)
	}
	resp, err := wpOrgHTTPClient.Do(req)
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: HTTP request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return store.PluginInfo{}, fmt.Errorf("wporg: unexpected status %d", resp.StatusCode)
	}
	// Read one byte past the cap so an oversized body is detected rather than
	// silently truncated into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: reading response body: %w", err)
	}
	if len(body) > maxResponseBytes {
		return store.PluginInfo{}, fmt.Errorf("wporg: response body exceeds %d bytes", maxResponseBytes)
	}
	return parseWPOrgPluginResponse(body)
}
// parseVersion converts a dotted version string such as "6.4.2" into its
// integer components ([]int{6, 4, 2}). Any segment that strconv.Atoi rejects
// is recorded as 0. The empty string yields nil.
func parseVersion(v string) []int {
	if v == "" {
		return nil
	}
	segments := strings.Split(v, ".")
	nums := make([]int, 0, len(segments))
	for _, seg := range segments {
		val, err := strconv.Atoi(seg)
		if err != nil {
			val = 0
		}
		nums = append(nums, val)
	}
	return nums
}
// compareVersions compares the first two components of installed and
// available.
//
// majorGap is true when available has a higher first component. Otherwise,
// when the first components are equal and available has a higher second
// component, minorBehind is the difference between the second components.
// If either version has fewer than two components the result is (false, 0).
func compareVersions(installed, available string) (majorGap bool, minorBehind int) {
	have, latest := parseVersion(installed), parseVersion(available)
	if len(have) < 2 || len(latest) < 2 {
		return false, 0
	}
	switch {
	case latest[0] > have[0]:
		return true, 0
	case latest[0] == have[0] && latest[1] > have[1]:
		return false, latest[1] - have[1]
	default:
		return false, 0
	}
}
// pluginAlertSeverity grades the gap between installed and available:
//
//   - "critical" when compareVersions reports a major gap;
//   - "high" when the installed version is three or more minor versions behind;
//   - "warning" when available is newer at the first differing component
//     (missing components count as 0);
//   - "" when the versions are identical, installed is ahead (e.g. custom or
//     premium builds), or either version has fewer than two components.
func pluginAlertSeverity(installed, available string) string {
	major, behind := compareVersions(installed, available)
	switch {
	case major:
		return "critical"
	case behind >= 3:
		return "high"
	}

	have := parseVersion(installed)
	latest := parseVersion(available)
	if len(have) < 2 || len(latest) < 2 {
		return ""
	}

	// component pads the shorter version with zeros.
	component := func(parts []int, idx int) int {
		if idx < len(parts) {
			return parts[idx]
		}
		return 0
	}
	width := len(have)
	if len(latest) > width {
		width = len(latest)
	}
	// The first differing component decides: only a newer available version warns.
	for idx := 0; idx < width; idx++ {
		cur, next := component(have, idx), component(latest, idx)
		switch {
		case next > cur:
			return "warning"
		case cur > next:
			return ""
		}
	}
	return ""
}
// pluginCheckWorkers is the number of worker goroutines refreshPluginCache
// starts to inventory WordPress installs concurrently.
const pluginCheckWorkers = 5
// CheckOutdatedPlugins reports outdated WordPress plugins from the cached
// plugin inventory.
//
// It returns nil when the global store is unavailable. If the time since the
// last recorded refresh exceeds cfg.Thresholds.PluginCheckIntervalMin
// minutes, the cache is refreshed first (refreshPluginCache). Findings are
// then produced by evaluatePluginCache. The state store argument is unused.
func CheckOutdatedPlugins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}

	refreshedAt := db.GetPluginRefreshTime()
	maxAge := time.Duration(cfg.Thresholds.PluginCheckIntervalMin) * time.Minute
	if stale := time.Since(refreshedAt) > maxAge; stale {
		refreshPluginCache(ctx, db)
	}

	return evaluatePluginCache(db)
}
// findAllWPInstalls returns the wp-config.php paths matched by a fixed set of
// glob patterns under /home. Results keep first-match order, are
// de-duplicated, and exclude any path whose lowercased form contains
// "/cache/", "/backup", "/staging" or "/.trash/". Glob errors are ignored.
func findAllWPInstalls() []string {
	globs := [...]string{
		"/home/*/public_html/wp-config.php",
		"/home/*/public_html/*/wp-config.php",
		"/home/*/*/wp-config.php",
	}
	excluded := [...]string{"/cache/", "/backup", "/staging", "/.trash/"}

	isExcluded := func(candidate string) bool {
		lowered := strings.ToLower(candidate)
		for _, fragment := range excluded {
			if strings.Contains(lowered, fragment) {
				return true
			}
		}
		return false
	}

	known := make(map[string]bool)
	var installs []string
	for _, g := range globs {
		hits, _ := osFS.Glob(g)
		for _, hit := range hits {
			if isExcluded(hit) || known[hit] {
				continue
			}
			known[hit] = true
			installs = append(installs, hit)
		}
	}
	return installs
}
// wpCLIPluginEntry mirrors one element of the JSON array printed by
// `wp plugin list --format=json` for the fields CSM requests
// (name, status, version, update_version).
type wpCLIPluginEntry struct {
// Name is stored as both Slug and Name of the cached plugin entry.
Name string `json:"name"`
// Status is compared against "active" / "active-network" when evaluating the cache.
Status string `json:"status"`
// Version is the installed version.
Version string `json:"version"`
// UpdateVersion is the update version reported by wp-cli; it may be empty.
UpdateVersion string `json:"update_version"`
}
// refreshPluginCache discovers all WP installs, runs wp-cli to inventory
// plugins for each site, enriches free plugins via the WordPress.org API,
// and stores everything in bbolt.
//
// Phases:
//  1. pluginCheckWorkers goroutines drain a buffered job channel, one
//     wp-config.php per job, and store each site's plugin list.
//  2. Every plugin name seen is looked up on WordPress.org unless a cached
//     entry younger than 24h exists.
//  3. Cached sites whose path was not processed in this run are deleted.
//  4. The refresh timestamp is written only when at least one site succeeded
//     and failures do not outnumber successes.
//
// If ctx is cancelled, the function returns without pruning or updating the
// timestamp. Failures are summarised on stderr; nothing is returned.
func refreshPluginCache(ctx context.Context, db *store.DB) {
wpConfigs := findAllWPInstalls()
if len(wpConfigs) == 0 {
return
}
// mu guards the counters and both maps below, which the workers share.
var mu sync.Mutex
var wg sync.WaitGroup
successCount := 0
var timeoutCount, execFailCount, parseFailCount int
slugsSeen := make(map[string]bool)
discoveredPaths := make(map[string]bool)
// The channel is buffered to hold every job, so the producer never blocks
// even when workers exit early on cancellation.
jobs := make(chan string, len(wpConfigs))
for i := 0; i < pluginCheckWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for wpConfig := range jobs {
if ctx.Err() != nil {
return
}
wpPath := filepath.Dir(wpConfig)
user := extractUser(wpPath)
domain := extractWPDomain(ctx, wpPath, user)
if ctx.Err() != nil {
return
}
// Mark the path as present before running wp-cli, so a site whose
// inventory fails this run is not pruned from the cache below.
mu.Lock()
discoveredPaths[wpPath] = true
mu.Unlock()
// Run wp plugin list as the site owner on stdout-only so PHP
// notices/warnings on stderr can't corrupt the JSON we parse.
// Use --path instead of shell cd to avoid shell injection via
// crafted directory names on shared hosting.
out, err := cmdExec.RunContextStdout(ctx, "su", "-", user, "-s", "/bin/bash", "-c",
wpCLIFlags+"plugin list --fields=name,status,version,update_version --format=json --path="+shellQuote(wpPath),
)
if err != nil {
if ctx.Err() != nil {
return
}
mu.Lock()
if errors.Is(err, context.DeadlineExceeded) {
timeoutCount++
} else {
execFailCount++
}
mu.Unlock()
continue
}
var entries []wpCLIPluginEntry
if err := json.Unmarshal(out, &entries); err != nil {
mu.Lock()
parseFailCount++
mu.Unlock()
continue
}
sitePlugins := store.SitePlugins{
Account: user,
Domain: domain,
}
for _, e := range entries {
sitePlugins.Plugins = append(sitePlugins.Plugins, store.SitePluginEntry{
Slug: e.Name,
Name: e.Name,
Status: e.Status,
InstalledVersion: e.Version,
UpdateVersion: e.UpdateVersion,
})
mu.Lock()
slugsSeen[e.Name] = true
mu.Unlock()
}
// A store failure is logged and not counted as a success; it shows up
// in failCount below but in none of the failureBreakdown categories.
if err := db.SetSitePlugins(wpPath, sitePlugins); err != nil {
fmt.Fprintf(os.Stderr, "plugincheck: store failed for %s: %v\n", wpPath, err)
continue
}
mu.Lock()
successCount++
mu.Unlock()
}
}()
}
for _, wpConfig := range wpConfigs {
if ctx.Err() != nil {
break
}
jobs <- wpConfig
}
close(jobs)
wg.Wait()
if ctx.Err() != nil {
return
}
// Enrich free plugins via WordPress.org API (one lookup per unique slug).
mu.Lock()
slugList := make([]string, 0, len(slugsSeen))
for slug := range slugsSeen {
slugList = append(slugList, slug)
}
mu.Unlock()
for _, slug := range slugList {
if ctx.Err() != nil {
return
}
// Skip if we have a recent cached entry (< 24h).
if cached, ok := db.GetPluginInfo(slug); ok {
if time.Since(time.Unix(cached.LastChecked, 0)) < 24*time.Hour {
continue
}
}
info, err := fetchWPOrgPluginInfo(ctx, slug)
if err != nil {
// Not found on .org = premium/custom plugin, skip silently.
continue
}
// Best effort: a failed cache write is ignored.
_ = db.SetPluginInfo(slug, info)
}
// Prune cache entries for WP installs no longer on disk.
allCached := db.AllSitePlugins()
for path := range allCached {
if !discoveredPaths[path] {
_ = db.DeleteSitePlugins(path)
}
}
// Only mark refresh as complete if the majority of sites refreshed
// successfully. A partial failure (e.g. one wp-cli timeout on a 100-site
// server) should not freeze ALL stale data for 24 hours. But if most
// sites failed (e.g. transient PHP issue), don't mark as fresh - allow
// retry next cycle.
mu.Lock()
sc := successCount
to, exf, pf := timeoutCount, execFailCount, parseFailCount
mu.Unlock()
// failCount is every site that did not reach the success counter.
failCount := len(wpConfigs) - sc
ts := time.Now().Format("2006-01-02 15:04:05")
if sc == 0 {
fmt.Fprintf(os.Stderr, "[%s] plugincheck: refresh failed, 0/%d sites succeeded%s, not updating timestamp\n",
ts, len(wpConfigs), failureBreakdown(to, exf, pf))
return
}
if failCount > sc {
fmt.Fprintf(os.Stderr, "[%s] plugincheck: refresh partial, %d/%d sites failed%s, not updating timestamp\n",
ts, failCount, len(wpConfigs), failureBreakdown(to, exf, pf))
return
}
if failCount > 0 {
fmt.Fprintf(os.Stderr, "[%s] plugincheck: refreshed %d/%d sites%s\n",
ts, sc, len(wpConfigs), failureBreakdown(to, exf, pf))
}
_ = db.SetPluginRefreshTime(time.Now())
}
// failureBreakdown renders the per-category failure counts as
// " (timeout=N exec_fail=N json_fail=N)", with a leading space so it can be
// spliced into a log line. It returns "" when all three counts are zero,
// which keeps the refresh log to a single line.
func failureBreakdown(timeout, execFail, parseFail int) string {
	if timeout != 0 || execFail != 0 || parseFail != 0 {
		return fmt.Sprintf(" (timeout=%d exec_fail=%d json_fail=%d)", timeout, execFail, parseFail)
	}
	return ""
}
// evaluatePluginCache turns the cached plugin inventory into findings, one
// aggregated "outdated_plugins" finding per site that has at least one
// outdated active plugin. Rolling plugins up per site keeps the finding
// count bounded on hosts with many sites; one finding per plugin previously
// saturated the alert channel and caused real signal to be dropped.
//
// Per plugin:
//   - only plugins with status "active" or "active-network" are considered;
//   - the available version is wp-cli's update_version, falling back to the
//     cached WordPress.org LatestVersion; plugins with neither are skipped;
//   - pluginAlertSeverity decides whether and how badly it is outdated.
//
// Per site:
//   - Severity is the worst constituent severity (critical > high > warning).
//   - Message is "<count> outdated plugin(s) on <domain> (<account>): worst
//     severity <label>".
//   - Details carries the install path and one line per outdated plugin with
//     its slug, name, installed version, available version and severity.
func evaluatePluginCache(db *store.DB) []alert.Finding {
	var findings []alert.Finding
	for wpPath, site := range db.AllSitePlugins() {
		var (
			lines      []string
			worst      alert.Severity
			worstLabel string
			// haveWorst is needed because the zero Severity is Warning: without
			// it, a site with only warning-level plugins would never record a label.
			haveWorst bool
		)
		for _, plugin := range site.Plugins {
			if plugin.Status != "active" && plugin.Status != "active-network" {
				continue
			}
			candidate := plugin.UpdateVersion
			if candidate == "" {
				if info, ok := db.GetPluginInfo(plugin.Slug); ok {
					candidate = info.LatestVersion
				}
			}
			if candidate == "" {
				// No update source at all (custom/private plugin).
				continue
			}
			label := pluginAlertSeverity(plugin.InstalledVersion, candidate)
			if label == "" {
				continue
			}
			level := alert.Warning
			switch label {
			case "critical":
				level = alert.Critical
			case "high":
				level = alert.High
			}
			lines = append(lines, fmt.Sprintf("- %s (%s): %s -> %s [%s]",
				plugin.Slug, plugin.Name, plugin.InstalledVersion, candidate, label))
			if !haveWorst || severityRank(level) > severityRank(worst) {
				worst, worstLabel, haveWorst = level, label, true
			}
		}
		count := len(lines)
		if count == 0 {
			continue
		}
		message := fmt.Sprintf("%d outdated plugin%s on %s (%s): worst severity %s",
			count, plural(count), site.Domain, site.Account, worstLabel)
		details := fmt.Sprintf("Path: %s\nOutdated plugins (%d):\n%s",
			wpPath, count, strings.Join(lines, "\n"))
		findings = append(findings, alert.Finding{
			Severity: worst,
			Check:    "outdated_plugins",
			Message:  message,
			Details:  details,
		})
	}
	return findings
}
// severityRank maps a severity to a sortable weight so an aggregate can pick
// the worst one: Critical is 3, High is 2, Warning is 1 and anything else
// is 0. A larger value means more severe.
func severityRank(s alert.Severity) int {
	if s == alert.Critical {
		return 3
	}
	if s == alert.High {
		return 2
	}
	if s == alert.Warning {
		return 1
	}
	return 0
}
// plural returns the suffix to append to an English noun counted by n:
// "" when n is exactly 1, "s" otherwise (including 0 and negatives).
func plural(n int) string {
	if n != 1 {
		return "s"
	}
	return ""
}
// extractWPDomain returns a display name for the WordPress site at wpPath.
//
// It runs `wp option get siteurl` as user and, when the command succeeds
// with non-empty output, returns that URL with a leading "https://" or
// "http://" removed. Only stdout is captured: some sites print warnings on
// stderr during wp-cli boot, and mixing those in would poison the value.
//
// When wp-cli fails or prints nothing, the path component that follows
// "public_html" is used (addon domain). If there is none, user is returned.
func extractWPDomain(ctx context.Context, wpPath, user string) string {
	out, err := cmdExec.RunContextStdout(ctx, "su", "-", user, "-s", "/bin/bash", "-c",
		wpCLIFlags+"option get siteurl --path="+shellQuote(wpPath),
	)
	if err == nil {
		if site := strings.TrimSpace(string(out)); site != "" {
			// Drop the scheme for display.
			for _, scheme := range []string{"https://", "http://"} {
				site = strings.TrimPrefix(site, scheme)
			}
			return site
		}
	}

	segments := strings.Split(wpPath, "/")
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "public_html" {
			return segments[idx+1]
		}
	}
	return user
}
// shellQuote wraps s in single quotes so it can be passed as one shell
// argument. Each embedded single quote is rewritten as '\'' (close the
// quoted string, emit an escaped quote, reopen the quoted string).
func shellQuote(s string) string {
	escaped := strings.ReplaceAll(s, "'", `'\''`)
	return "'" + escaped + "'"
}
package checks
import (
"context"
"fmt"
"math"
"path/filepath"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/processctx"
"github.com/pidginhost/csm/internal/state"
)
// processStatusUID extracts the first UID from the "Uid:" line of a
// /proc/<pid>/status style document. Only the first line that starts with
// "Uid:\t" is examined. The boolean is true only when that first field
// parses as a non-negative integer. When no such line exists, or the line
// has no fields, it returns (0, false).
func processStatusUID(data []byte) (int, bool) {
	const uidKey = "Uid:\t"
	for _, row := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(row, uidKey) {
			continue
		}
		ids := strings.Fields(row[len(uidKey):])
		if len(ids) == 0 {
			return 0, false
		}
		parsed, err := strconv.Atoi(ids[0])
		valid := err == nil && parsed >= 0
		return parsed, valid
	}
	return 0, false
}
// processIdentityForUID resolves uid to a (user, account) pair.
//
// user is whatever LookupUser returns for the UID. account equals user only
// when uid is at least 1000 and the name is non-empty and does not start
// with "uid:"; otherwise account is "". UIDs outside the uint32 range yield
// two empty strings.
func processIdentityForUID(uid int) (string, string) {
	if uid < 0 || uid > math.MaxUint32 {
		return "", ""
	}
	// #nosec G115 -- uid is range-checked against math.MaxUint32 above.
	name := LookupUser(uint32(uid))
	isAccount := uid >= 1000 && name != "" && !strings.HasPrefix(name, "uid:")
	if !isAccount {
		return name, ""
	}
	return name, name
}
// suspiciousExeNames flags processes whose exe basename contains any of
// these substrings (the basename is lowercased before matching). Shared
// between the periodic CheckSuspiciousProcesses and the live BPF exec
// backend (which cannot see cmdline patterns and relies on exe-name +
// exe-path matching).
var suspiciousExeNames = []string{"defunct", "gsocket", "gs-netcat", "gs-sftp"}
// suspiciousExePaths flags processes whose exe path contains any of these
// directory fragments. Matching is a case-sensitive substring test anywhere
// in the path, not a prefix test. Shared with the live BPF exec backend.
var suspiciousExePaths = []string{"/tmp/", "/dev/shm/", "/.config/"}
// suspiciousCmdlinePatterns is checked only by the periodic
// CheckSuspiciousProcesses; the BPF exec backend cannot read cmdline at
// the moment of exec. Patterns are matched as substrings of the lowercased
// command line.
var suspiciousCmdlinePatterns = []string{
"/bin/sh -i", "/bin/bash -i", "bash -i",
"/dev/tcp/", "semutmerah", "gsocket",
"reverse", "nc -e", "ncat -e",
}
// EvaluateExec evaluates one execve event observed by the BPF live backend
// and returns the findings it triggers. The inputs are the (UID, PID, comm,
// exe, parentComm) tuple collected by the kernel hook. The function performs
// no IO.
//
// Root processes (uid 0) never produce findings. For any other UID, up to
// three findings may be returned, in this order:
//   - fake_kernel_thread (Critical) when comm is wrapped in square brackets;
//   - suspicious_process (Critical) when the lowercased exe basename contains
//     an entry of suspiciousExeNames;
//   - suspicious_process (High) when exe contains an entry of
//     suspiciousExePaths.
//
// Cmdline-aware detection stays with the periodic checks
// (CheckSuspiciousProcesses, CheckFakeKernelThreads); it cannot be
// replicated here.
func EvaluateExec(uid uint32, pid uint32, comm, exe, parentComm string) []alert.Finding {
	if uid == 0 {
		return nil
	}

	var findings []alert.Finding
	procID := int(pid)
	containsAny := func(haystack string, needles []string) bool {
		for _, needle := range needles {
			if strings.Contains(haystack, needle) {
				return true
			}
		}
		return false
	}

	if n := len(comm); n >= 2 && comm[0] == '[' && comm[n-1] == ']' {
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "fake_kernel_thread",
			Message:  fmt.Sprintf("Non-root process masquerading as kernel thread: %s", comm),
			Details:  fmt.Sprintf("PID: %d, UID: %d, exe: %s, parent: %s", pid, uid, exe, parentComm),
			PID:      procID,
		})
	}

	base := filepath.Base(exe)
	if containsAny(strings.ToLower(base), suspiciousExeNames) {
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "suspicious_process",
			Message:  fmt.Sprintf("Suspicious process name: %s", base),
			Details:  fmt.Sprintf("PID: %d, UID: %d, exe: %s, comm: %s, parent: %s", pid, uid, exe, comm, parentComm),
			PID:      procID,
		})
	}

	if containsAny(exe, suspiciousExePaths) {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "suspicious_process",
			Message:  fmt.Sprintf("Process running from suspicious path: %s", exe),
			Details:  fmt.Sprintf("PID: %d, UID: %d, comm: %s, parent: %s", pid, uid, comm, parentComm),
			PID:      procID,
		})
	}
	return findings
}
// CheckFakeKernelThreads scans /proc/<pid>/status for non-root processes
// that present themselves as kernel threads. A process is reported
// (Critical, check "fake_kernel_thread") when its UID is neither empty nor
// "0" and either its space-joined cmdline or its Name starts with "[".
//
// Processes whose status file cannot be read are skipped. Errors reading
// cmdline or the exe link are ignored and leave those values empty. The ctx
// and config/state arguments are not used.
func CheckFakeKernelThreads(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	statusFiles, _ := osFS.Glob("/proc/[0-9]*/status")
	for _, statusFile := range statusFiles {
		pid := filepath.Base(filepath.Dir(statusFile))
		raw, err := osFS.ReadFile(statusFile)
		if err != nil {
			continue
		}

		var procName, uid string
		for _, row := range strings.Split(string(raw), "\n") {
			switch {
			case strings.HasPrefix(row, "Name:\t"):
				procName = strings.TrimPrefix(row, "Name:\t")
			case strings.HasPrefix(row, "Uid:\t"):
				if ids := strings.Fields(strings.TrimPrefix(row, "Uid:\t")); len(ids) > 0 {
					uid = ids[0]
				}
			}
		}
		// Genuine kernel threads run as root, so root (and processes whose
		// UID could not be determined) are out of scope.
		if uid == "" || uid == "0" {
			continue
		}

		// The cmdline is NUL-separated; join the arguments with spaces.
		rawCmd, _ := osFS.ReadFile(filepath.Join("/proc", pid, "cmdline"))
		cmdStr := strings.TrimRight(strings.ReplaceAll(string(rawCmd), "\x00", " "), " ")
		if !strings.HasPrefix(cmdStr, "[") && !strings.HasPrefix(procName, "[") {
			continue
		}

		exe, _ := osFS.Readlink(filepath.Join("/proc", pid, "exe"))
		uidNum, _ := strconv.Atoi(uid)
		pidNum, _ := strconv.Atoi(pid)
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "fake_kernel_thread",
			Message:  fmt.Sprintf("Non-root process masquerading as kernel thread: [%s]", procName),
			Details:  fmt.Sprintf("PID: %s, UID: %d, exe: %s, cmdline: %s", pid, uidNum, exe, cmdStr),
			PID:      pidNum,
		})
	}
	return findings
}
// CheckSuspiciousProcesses walks /proc/<pid>/exe and reports non-root
// processes that look like attacker tooling. Each process can produce at
// most one finding per rule:
//
//   - suspicious_process (Critical): the lowercased exe basename contains an
//     entry of suspiciousExeNames;
//   - suspicious_process (Critical): the lowercased cmdline contains an entry
//     of suspiciousCmdlinePatterns (the first matching pattern is named);
//   - suspicious_process (High): the exe path contains an entry of
//     suspiciousExePaths.
//
// Processes whose status reports UID "0" are skipped. Read errors on status,
// exe or cmdline are ignored and leave the corresponding value empty. The
// ctx and config/state arguments are not used.
func CheckSuspiciousProcesses(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding

	procs, _ := osFS.Glob("/proc/[0-9]*/exe")
	for _, exePath := range procs {
		pid := filepath.Base(filepath.Dir(exePath))
		pidInt, _ := strconv.Atoi(pid)

		statusData, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		var uid string
		for _, line := range strings.Split(string(statusData), "\n") {
			if strings.HasPrefix(line, "Uid:\t") {
				fields := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
				if len(fields) > 0 {
					uid = fields[0]
				}
			}
		}
		if uid == "0" {
			continue // Skip root processes for this check
		}

		exe, _ := osFS.Readlink(exePath)
		cmdline, _ := osFS.ReadFile(filepath.Join("/proc", pid, "cmdline"))
		cmdStr := strings.TrimRight(strings.ReplaceAll(string(cmdline), "\x00", " "), " ")

		// Check executable name. Lowercase once, and stop at the first match
		// so a name that hits several patterns yields a single finding, the
		// same as EvaluateExec and the two rules below.
		exeName := filepath.Base(exe)
		exeNameLower := strings.ToLower(exeName)
		for _, s := range suspiciousExeNames {
			if strings.Contains(exeNameLower, s) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Suspicious process name: %s", exeName),
					Details:  fmt.Sprintf("PID: %s, UID: %s, exe: %s, cmdline: %s", pid, uid, exe, cmdStr),
					PID:      pidInt,
				})
				break
			}
		}

		// Check cmdline for suspicious patterns
		cmdLower := strings.ToLower(cmdStr)
		for _, s := range suspiciousCmdlinePatterns {
			if strings.Contains(cmdLower, strings.ToLower(s)) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Suspicious cmdline pattern: %s", s),
					Details:  fmt.Sprintf("PID: %s, UID: %s, exe: %s, cmdline: %s", pid, uid, exe, cmdStr),
					PID:      pidInt,
				})
				break
			}
		}

		// Check executable path
		for _, s := range suspiciousExePaths {
			if strings.Contains(exe, s) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Process running from suspicious path: %s", exe),
					Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uid, cmdStr),
					PID:      pidInt,
				})
				break
			}
		}
	}
	return findings
}
// CheckPHPProcesses looks for lsphp processes whose command line references
// a location PHP should not be executing from, which indicates active
// webshell execution. Only files under /proc are read.
//
// A process is reported (Critical, check "php_suspicious_execution") when
// its space-joined cmdline contains "lsphp" and one of "/tmp/", "/dev/shm/",
// "/wp-content/uploads/" or "/.config/". The first matching fragment is
// named in the message. When the UID can be parsed from the process status,
// the finding also carries a ProcessContext with PID, UID, user and account;
// otherwise Process is nil and the UID in Details is empty.
func CheckPHPProcesses(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	watched := []string{
		"/tmp/",
		"/dev/shm/",
		"/wp-content/uploads/",
		"/.config/",
	}

	cmdFiles, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cmdFile := range cmdFiles {
		pid := filepath.Base(filepath.Dir(cmdFile))
		pidNum, _ := strconv.Atoi(pid)
		raw, err := osFS.ReadFile(cmdFile)
		if err != nil {
			continue
		}
		cmdStr := strings.ReplaceAll(string(raw), "\x00", " ")
		if !strings.Contains(cmdStr, "lsphp") {
			continue
		}

		// Find the first watched fragment present in the command line.
		hit := ""
		for _, fragment := range watched {
			if strings.Contains(cmdStr, fragment) {
				hit = fragment
				break
			}
		}
		if hit == "" {
			continue
		}

		statusData, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		var (
			uidText string
			proc    *processctx.ProcessContext
		)
		if uid, ok := processStatusUID(statusData); ok {
			uidText = strconv.Itoa(uid)
			userName, account := processIdentityForUID(uid)
			proc = &processctx.ProcessContext{
				PID:     pidNum,
				UID:     uid,
				User:    userName,
				Account: account,
			}
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "php_suspicious_execution",
			Message:  fmt.Sprintf("PHP executing from suspicious path: %s", hit),
			Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uidText, strings.TrimSpace(cmdStr)),
			PID:      pidNum,
			Process:  proc,
		})
	}
	return findings
}
package checks
import (
"context"
"os"
"path/filepath"
)
// ---------------------------------------------------------------------------
// OS abstraction — filesystem operations (read and write)
// ---------------------------------------------------------------------------
// OS abstracts filesystem operations (read and write) used by check functions.
// Production code uses realOS{}; tests swap in a mockOS via SetOS().
// The method set mirrors the same-named functions in package os, plus
// filepath.Glob.
type OS interface {
// Read-side operations.
ReadFile(name string) ([]byte, error)
ReadDir(name string) ([]os.DirEntry, error)
Stat(name string) (os.FileInfo, error)
Lstat(name string) (os.FileInfo, error)
Readlink(name string) (string, error)
Open(name string) (*os.File, error)
// Write-side operations.
WriteFile(name string, data []byte, perm os.FileMode) error
MkdirAll(path string, perm os.FileMode) error
Remove(name string) error
// Pattern matching over paths.
Glob(pattern string) ([]string, error)
}
// realOS is the production OS implementation. Every method forwards its
// arguments unchanged to the matching function in package os (or
// filepath.Glob) and returns that function's results.
type realOS struct{}
// ReadFile forwards to os.ReadFile.
// #nosec G304 -- filesystem abstraction; check functions pass trusted paths.
func (realOS) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) }
// ReadDir forwards to os.ReadDir.
func (realOS) ReadDir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) }
// Stat forwards to os.Stat.
func (realOS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }
// Lstat forwards to os.Lstat.
func (realOS) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) }
// Readlink forwards to os.Readlink.
func (realOS) Readlink(name string) (string, error) { return os.Readlink(name) }
// Open forwards to os.Open.
// #nosec G304 -- filesystem abstraction; check functions pass trusted paths.
func (realOS) Open(name string) (*os.File, error) { return os.Open(name) }
// WriteFile forwards to os.WriteFile with the caller-supplied permissions.
// #nosec G306 -- callers pass explicit perm; intent is operator-readable mode.
func (realOS) WriteFile(name string, data []byte, perm os.FileMode) error {
return os.WriteFile(name, data, perm)
}
// MkdirAll forwards to os.MkdirAll.
func (realOS) MkdirAll(path string, perm os.FileMode) error { return os.MkdirAll(path, perm) }
// Remove forwards to os.Remove.
func (realOS) Remove(name string) error { return os.Remove(name) }
// Glob forwards to filepath.Glob.
func (realOS) Glob(pattern string) ([]string, error) { return filepath.Glob(pattern) }
// osFS is the package-level filesystem provider. All check functions use
// this instead of calling os.ReadFile / os.ReadDir / etc. directly.
// It defaults to realOS{}.
var osFS OS = realOS{}
// SetOS replaces the filesystem provider. Used by tests to inject mocks.
// The assignment is a plain write to the package variable with no locking.
func SetOS(o OS) { osFS = o }
// ---------------------------------------------------------------------------
// CmdRunner abstraction — external command execution
// ---------------------------------------------------------------------------
// CmdRunner abstracts external command execution used by check functions.
// Production code uses realCmd{}; tests swap in a mockCmdRunner via SetCmdRunner().
//
// RunContext returns stdout+stderr merged (CombinedOutput) and is fine for
// tools that only write to stdout. RunContextStdout returns stdout only and
// should be used when the command prints structured output (JSON, a URL, ...)
// on stdout and chatter (warnings, PHP notices, MySQL deprecations) on stderr
// -- mixing the two streams would corrupt the parse. RunContextStdout also
// surfaces context.DeadlineExceeded on timeout so callers can distinguish
// "no output" from "empty output".
type CmdRunner interface {
Run(name string, args ...string) ([]byte, error)
RunAllowNonZero(name string, args ...string) ([]byte, error)
RunContext(parent context.Context, name string, args ...string) ([]byte, error)
RunContextStdout(parent context.Context, name string, args ...string) ([]byte, error)
RunWithEnv(name string, args []string, extraEnv ...string) ([]byte, error)
LookPath(file string) (string, error)
}
// realCmd is the production CmdRunner. Each method delegates to the
// package-level helper named in its comment, passing its arguments through
// unchanged.
type realCmd struct{}
// Run delegates to runCmdReal.
func (realCmd) Run(name string, args ...string) ([]byte, error) {
return runCmdReal(name, args...)
}
// RunAllowNonZero delegates to runCmdAllowNonZeroReal.
func (realCmd) RunAllowNonZero(name string, args ...string) ([]byte, error) {
return runCmdAllowNonZeroReal(name, args...)
}
// RunContext delegates to runCmdCombinedContextReal.
func (realCmd) RunContext(parent context.Context, name string, args ...string) ([]byte, error) {
return runCmdCombinedContextReal(parent, name, args...)
}
// RunContextStdout delegates to runCmdStdoutContextReal.
func (realCmd) RunContextStdout(parent context.Context, name string, args ...string) ([]byte, error) {
return runCmdStdoutContextReal(parent, name, args...)
}
// RunWithEnv delegates to runCmdWithEnvReal.
func (realCmd) RunWithEnv(name string, args []string, extraEnv ...string) ([]byte, error) {
return runCmdWithEnvReal(name, args, extraEnv...)
}
// LookPath delegates to lookupSystemCommand.
func (realCmd) LookPath(file string) (string, error) {
return lookupSystemCommand(file)
}
// cmdExec is the package-level command runner. All check functions use
// this instead of calling runCmd / exec.Command directly.
// It defaults to realCmd{}.
var cmdExec CmdRunner = realCmd{}
// SetCmdRunner replaces the command runner. Used by tests to inject mocks.
// The assignment is a plain write to the package variable with no locking.
func SetCmdRunner(r CmdRunner) { cmdExec = r }
//go:build linux
package checks
import (
"fmt"
"io"
"os"
"syscall"
"golang.org/x/sys/unix"
)
// quarantineLinkByFD is the hard-link step that quarantineFileTOCTOUSafe
// tries before falling back to a copy. It is a replaceable variable so Linux
// tests can force a failure such as EXDEV without depending on the host's
// filesystem layout.
var quarantineLinkByFD = linkQuarantineFileByFD
// quarantineFileTOCTOUSafe moves a single regular file into quarantine in
// a way that defends against the classic detect-then-quarantine race: an
// attacker who controls the directory can swap a legitimate file in
// between Lstat and Rename, tricking CSM into moving the wrong file out
// of the user's home. The defence:
//
//  1. Open the path with O_RDONLY|O_NOFOLLOW. Symlinks are refused at
//     the kernel level; the fd is bound to the inode that existed at
//     open time.
//  2. Fstat the fd and verify it still matches the inode we detected
//     earlier (sameFileIdentity), has the same size and mtime
//     (sameContentShape), and is a regular file. A late swap loses here.
//  3. linkat(/proc/self/fd/N, qPath, AT_SYMLINK_FOLLOW) creates a
//     hardlink to the inode we opened, by file descriptor, not by
//     path. If hardlinking is unavailable, copy from the same open
//     fd instead of reopening by path.
//  4. Unlink the source path only if it still resolves to the inode
//     we quarantined. If an attacker swapped in a replacement after
//     step 2, leave that replacement alone.
//
// path is the file to quarantine, qPath the destination inside quarantine,
// and originalInfo the stat recorded at detection time (it must be non-nil
// and must not describe a symlink).
//
// Returns nil on success. Errors describe what failed; callers should
// not retry blindly because a failure usually means the file moved.
func quarantineFileTOCTOUSafe(path, qPath string, originalInfo os.FileInfo) error {
if originalInfo == nil {
return fmt.Errorf("quarantine: missing original stat")
}
if originalInfo.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("quarantine: refused symlink at %s", path)
}
// O_NOFOLLOW makes open() fail with ELOOP if path resolved to a
// symlink in the final component; combined with the earlier Lstat
// rejection above, this closes the symlink-swap variant.
// #nosec G304 -- path is the quarantine subject; O_NOFOLLOW plus
// fd identity verification below fail closed on symlink and inode swaps.
fd, err := os.OpenFile(path, os.O_RDONLY|syscall.O_NOFOLLOW, 0)
if err != nil {
return fmt.Errorf("quarantine: open %s: %w", path, err)
}
defer fd.Close()
// Re-stat the open fd and confirm the inode matches what we
// detected. The race window between Lstat and OpenFile is narrow
// but real - an attacker who hits it gets caught here.
cur, err := fd.Stat()
if err != nil {
return fmt.Errorf("quarantine: fstat %s: %w", path, err)
}
if !sameFileIdentity(cur, originalInfo) {
return fmt.Errorf("quarantine: file at %s changed between detection and quarantine (TOCTOU)", path)
}
// Defence against inode reuse: on busy tmpfs / ext4 mounts the kernel
// can hand out the freed inode to whatever the attacker wrote next.
// A matching inode is necessary but not sufficient — also require the
// content shape (size + mtime) to match what the detector recorded.
if !sameContentShape(cur, originalInfo) {
return fmt.Errorf("quarantine: file at %s changed between detection and quarantine (TOCTOU, inode reused)", path)
}
// Refuse to quarantine a non-regular file (block, char, socket,
// FIFO). The detector only flags regular files, so a non-regular
// shape at this point means someone is trying to move CSM at a
// device node or pipe.
if !cur.Mode().IsRegular() {
return fmt.Errorf("quarantine: refusing non-regular file at %s (mode=%v)", path, cur.Mode())
}
// Prefer a hard link by fd. On any link error, fall back to copying from
// the same open fd; only the copy error is reported if both fail.
if err := quarantineLinkByFD(fd, qPath); err != nil {
if err := copyQuarantineFileByFD(fd, qPath); err != nil {
return fmt.Errorf("quarantine: copy %s -> %s: %w", path, qPath, err)
}
}
// Remove the source last, passing the verified fd stat so the helper can
// compare against it.
if err := removeQuarantinedSource(path, qPath, cur); err != nil {
return err
}
return nil
}
// sameContentShape reports whether a and b agree on both size and
// modification time. A nil value on either side never matches.
//
// It is a defence-in-depth companion to sameFileIdentity: an inode number
// can be recycled for a brand-new file on tmpfs / ext4, so an identity
// match alone does not prove the file is the one originally detected.
func sameContentShape(a, b os.FileInfo) bool {
	if a == nil || b == nil {
		return false
	}
	sizesMatch := a.Size() == b.Size()
	// && short-circuits, so mtimes are only consulted when sizes agree.
	return sizesMatch && a.ModTime().Equal(b.ModTime())
}
// linkQuarantineFileByFD creates qPath as a hard link to the file behind fd.
// The link source is the /proc/self/fd/N entry for the descriptor, resolved
// with AT_SYMLINK_FOLLOW, so the link is made to the already-open file
// rather than to whatever the original path names now. Any linkat error is
// returned unchanged.
func linkQuarantineFileByFD(fd *os.File, qPath string) error {
	return unix.Linkat(
		unix.AT_FDCWD, fmt.Sprintf("/proc/self/fd/%d", fd.Fd()),
		unix.AT_FDCWD, qPath,
		unix.AT_SYMLINK_FOLLOW,
	)
}
// copyQuarantineFileByFD streams the full contents of src into a new file at
// qPath. src is rewound to offset 0 first, so earlier reads on the descriptor
// do not truncate the copy. qPath is created exclusively with mode 0600; if
// it already exists the open error is returned and the existing file is left
// alone. If the copy or the final close fails, the partial file at qPath is
// removed before the error is returned.
func copyQuarantineFileByFD(src *os.File, qPath string) error {
	if _, seekErr := src.Seek(0, io.SeekStart); seekErr != nil {
		return fmt.Errorf("seek source: %w", seekErr)
	}
	// #nosec G304 G306 -- qPath is generated under the quarantine
	// directory; 0600 keeps cross-device quarantine copies private.
	out, err := os.OpenFile(qPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return err
	}

	_, copyErr := io.Copy(out, src)
	closeErr := out.Close()
	if copyErr == nil && closeErr == nil {
		return nil
	}

	// Never leave a truncated quarantine copy behind.
	_ = os.Remove(qPath)
	if copyErr != nil {
		// The copy failure is the root cause; a close error after it is
		// secondary and intentionally dropped.
		return copyErr
	}
	return closeErr
}
// removeQuarantinedSource unlinks path after its content has been preserved
// at qPath, but only while path still refers to the file described by
// original.
//
//   - path no longer exists: nothing to do, returns nil.
//   - path is now a symlink, or fails sameFileIdentity against original:
//     returns nil and leaves whatever is there untouched.
//   - identity matches but size/mtime differ, or Lstat/Remove fail for a
//     reason other than "not exist": qPath is removed again and an error
//     is returned.
func removeQuarantinedSource(path, qPath string, original os.FileInfo) error {
	// rollback discards the quarantine copy and passes the error through.
	rollback := func(err error) error {
		_ = os.Remove(qPath)
		return err
	}

	current, statErr := os.Lstat(path)
	switch {
	case statErr == nil:
		// Source still present; verify it below.
	case os.IsNotExist(statErr):
		return nil
	default:
		return rollback(fmt.Errorf("quarantine: stat source before unlink %s: %w", path, statErr))
	}

	replaced := current.Mode()&os.ModeSymlink != 0 || !sameFileIdentity(current, original)
	if replaced {
		return nil
	}
	if !sameContentShape(current, original) {
		return rollback(fmt.Errorf("quarantine: source changed before unlink %s", path))
	}

	rmErr := os.Remove(path)
	if rmErr == nil || os.IsNotExist(rmErr) {
		return nil
	}
	return rollback(fmt.Errorf("quarantine: unlink source %s: %w", path, rmErr))
}
package checks
import (
"container/list"
"net"
"sync"
"time"
)
// RDNSCacheConfig is the config block for NewRDNSCache. Resolve is the
// function used to perform the actual reverse lookup; production
// callers wrap net.LookupAddr. ResolveDeadline bounds each lookup;
// a value <= 0 disables the deadline. MaxSize caps the number of cached
// entries to keep memory bounded on hosts that see a wide spread of remote
// IPs (BPF SMTP-egress is the motivating case); the oldest entry by
// cachedAt is evicted before a new one is inserted past the cap.
// A MaxSize <= 0 falls back to rdnsCacheDefaultMaxSize.
type RDNSCacheConfig struct {
// TTL is how long a cached result (including an empty/negative one) is
// served before Lookup resolves again.
TTL time.Duration
// Resolve performs the reverse lookup. NewRDNSCache stores it as-is and
// Lookup calls it on every miss, so it must be non-nil.
Resolve func(ip net.IP) (string, error)
// ResolveDeadline is the maximum time Lookup waits for Resolve. When it
// is <= 0, Resolve is called synchronously with no timer.
ResolveDeadline time.Duration
// MaxSize is the cache entry cap; <= 0 selects rdnsCacheDefaultMaxSize.
MaxSize int
// MaxConcurrent caps the number of deadline-bound resolve goroutines in
// flight at once. A goroutine blocked on a wedged resolver cannot be
// cancelled in Go, so without a cap a burst of distinct IPs under deadline
// saturation spawns one abandonable goroutine per IP. A value <= 0 falls
// back to rdnsCacheDefaultMaxConcurrent. The cap only applies when
// ResolveDeadline > 0; synchronous lookups do not take a slot.
MaxConcurrent int
}
// Fallback limits applied by NewRDNSCache when the corresponding
// RDNSCacheConfig field is left at zero (or set to a negative value).
const (
// rdnsCacheDefaultMaxSize is the default cap on cached entries.
rdnsCacheDefaultMaxSize = 10000
// rdnsCacheDefaultMaxConcurrent is the default cap on in-flight
// deadline-bound resolve goroutines.
rdnsCacheDefaultMaxConcurrent = 64
)
// RDNSCache is a small TTL cache around reverse DNS lookups. Cached
// negative results (resolver error / NXDOMAIN) are kept until TTL too,
// so the detector does not hammer a slow resolver on a known-bad IP.
// Entries are capped at maxSize; the oldest-by-cachedAt entry is
// evicted on insert once the cap is reached.
type RDNSCache struct {
// mu guards order and entries (and the rdnsEntry values they hold).
mu sync.Mutex
// ttl is the freshness window for a cached entry.
ttl time.Duration
// deadln is the per-lookup deadline; <= 0 means resolve synchronously.
deadln time.Duration
// maxSize is the entry cap enforced on insert.
maxSize int
// resolve is the reverse-lookup function supplied via RDNSCacheConfig.
resolve func(ip net.IP) (string, error)
// now is the clock; NewRDNSCache sets it to time.Now.
now func() time.Time
// order holds *rdnsEntry values, least recently (re)cached at the front.
order *list.List
// entries maps the IP's string form to its element in order.
entries map[string]*list.Element
// sem is a counting semaphore for deadline-bound resolve goroutines.
sem chan struct{}
}
// rdnsEntry is one cached lookup result, stored as the Value of an element
// in RDNSCache.order.
type rdnsEntry struct {
// key is the IP's string form; kept so eviction can delete the map entry.
key string
// host is the resolved hostname, or "" for a failed/timed-out lookup.
host string
// cachedAt is when host was last stored; TTL is measured from here.
cachedAt time.Time
}
// NewRDNSCache builds an empty cache from cfg. MaxSize and MaxConcurrent
// values that are zero or negative are replaced by their package defaults;
// TTL, ResolveDeadline and Resolve are taken as given. The cache clock is
// time.Now.
func NewRDNSCache(cfg RDNSCacheConfig) *RDNSCache {
	// positiveOr returns v when it is a usable limit, otherwise fallback.
	positiveOr := func(v, fallback int) int {
		if v > 0 {
			return v
		}
		return fallback
	}

	cache := &RDNSCache{
		ttl:     cfg.TTL,
		deadln:  cfg.ResolveDeadline,
		maxSize: positiveOr(cfg.MaxSize, rdnsCacheDefaultMaxSize),
		resolve: cfg.Resolve,
		now:     time.Now,
		order:   list.New(),
		entries: map[string]*list.Element{},
	}
	cache.sem = make(chan struct{}, positiveOr(cfg.MaxConcurrent, rdnsCacheDefaultMaxConcurrent))
	return cache
}
// evictOldestLocked removes the entry at the front of the order list (the
// one cached longest ago) from both the list and the lookup map. It is a
// no-op on an empty cache. The caller must hold c.mu.
func (c *RDNSCache) evictOldestLocked() {
	if oldest := c.order.Front(); oldest != nil {
		delete(c.entries, oldest.Value.(*rdnsEntry).key)
		c.order.Remove(oldest)
	}
}
// Lookup returns the hostname for ip, or "" when ip is nil or the lookup
// missed, failed, or ran past the deadline. A fresh cache hit returns
// immediately; otherwise the caller is blocked for at most
// cfg.ResolveDeadline while the resolver runs. Whatever the resolver
// produced - including "" - is cached for the TTL.
func (c *RDNSCache) Lookup(ip net.IP) string {
	if ip == nil {
		return ""
	}
	key := ip.String()

	// Fast path: serve a cached value that is still within its TTL.
	cached, fresh := func() (string, bool) {
		c.mu.Lock()
		defer c.mu.Unlock()
		at := c.now()
		el, found := c.entries[key]
		if !found {
			return "", false
		}
		e := el.Value.(*rdnsEntry)
		return e.host, at.Sub(e.cachedAt) <= c.ttl
	}()
	if fresh {
		return cached
	}

	// Resolve without holding the lock so a slow resolver cannot stall
	// lookups for other addresses.
	host := c.runWithDeadline(ip)

	c.mu.Lock()
	defer c.mu.Unlock()
	at := c.now()
	if el, found := c.entries[key]; found {
		// Expired entry, or one inserted by a concurrent lookup while we
		// were resolving: refresh it in place and mark it newest.
		e := el.Value.(*rdnsEntry)
		e.host, e.cachedAt = host, at
		c.order.MoveToBack(el)
		return host
	}
	if c.order.Len() >= c.maxSize {
		c.evictOldestLocked()
	}
	c.entries[key] = c.order.PushBack(&rdnsEntry{key: key, host: host, cachedAt: at})
	return host
}
// runWithDeadline performs one reverse lookup and returns the hostname, or
// "" if the resolver returned an error, the deadline elapsed first, or no
// concurrency slot was free. With a deadline <= 0 the resolver is called
// synchronously on the caller's goroutine and no slot is taken.
func (c *RDNSCache) runWithDeadline(ip net.IP) string {
	// hostOrEmpty collapses any resolver error into the empty hostname.
	hostOrEmpty := func(host string, err error) string {
		if err != nil {
			return ""
		}
		return host
	}

	if c.deadln <= 0 {
		return hostOrEmpty(c.resolve(ip))
	}

	// Cap in-flight resolve goroutines. A goroutine blocked on a wedged
	// resolver keeps its slot until the call finally returns, so under
	// deadline saturation further lookups fail fast (return "" like a
	// deadline miss) instead of spawning more abandonable goroutines.
	select {
	case c.sem <- struct{}{}:
	default:
		return ""
	}

	// Buffered so the worker can always deliver and exit, even after the
	// caller has given up on the deadline.
	done := make(chan string, 1)
	go func() {
		defer func() { <-c.sem }()
		done <- hostOrEmpty(c.resolve(ip))
	}()

	deadline := time.NewTimer(c.deadln)
	defer deadline.Stop()
	select {
	case host := <-done:
		return host
	case <-deadline.C:
		return ""
	}
}
package checks
import "sort"
// CheckInfo describes a single named check emitted as an alert.Finding.Check.
// Category groups related checks for display in the settings UI. Internal is
// true for checks that exist for plumbing (self-tests, plumbing findings) and
// should not appear in user-facing dropdowns like alerts.email.disabled_checks.
type CheckInfo struct {
// Name is the exact Check string; LookupCheck matches on it verbatim.
Name string
// Category is one of the Category* labels; PublicCheckInfos groups by it.
Category string
// Internal hides the check from PublicCheckInfos.
Internal bool
}
// Category labels are the groupings shown in the multi-select UI. Keep the
// order below in sync with checkCategoryOrder so categories render in a sane
// order rather than alphabetically (Auth first, Internal last). The string
// values are the user-visible labels, so changing one changes the UI text.
const (
CategoryAuth = "Authentication & Login"
CategoryBruteForce = "Brute Force"
CategoryMalware = "Malware & Webshells"
CategoryWeb = "Web & Application"
CategoryDatabase = "Database Content"
CategoryEmail = "Email & Phishing"
CategoryPerformance = "Performance"
CategoryNetwork = "Network & Firewall"
CategorySystem = "System Integrity"
CategoryWAF = "WAF & ModSecurity"
CategoryCorrelation = "Correlation & Health"
CategoryInternal = "Internal"
)
// checkCategoryOrder is the canonical display order of categories.
// PublicCheckInfos walks this slice to emit its groups, so a category that
// is not listed here never appears in its output.
var checkCategoryOrder = []string{
CategoryAuth,
CategoryBruteForce,
CategoryMalware,
CategoryWeb,
CategoryDatabase,
CategoryEmail,
CategoryPerformance,
CategoryNetwork,
CategorySystem,
CategoryWAF,
CategoryCorrelation,
CategoryInternal,
}
// checkRegistry is the authoritative list of every Check string the daemon
// may emit. Adding a new alert.Finding Check name anywhere in internal/checks,
// internal/daemon, or internal/webui without also adding it here will fail
// TestCheckRegistryCoversProductionCode.
//
// The section banners below are organizational only: PublicCheckInfos groups
// by each entry's Category field and sorts by Name, and AllCheckNames sorts
// by Name, so the position of an entry in this slice does not affect what
// those accessors return.
var checkRegistry = []CheckInfo{
// --- Authentication & Login ------------------------------------------
{Name: "admin_panel_bruteforce", Category: CategoryAuth},
{Name: "api_auth_failure", Category: CategoryAuth},
{Name: "api_auth_failure_realtime", Category: CategoryAuth},
{Name: "api_tokens", Category: CategoryAuth},
{Name: "bulk_password_change", Category: CategoryAuth},
{Name: "cpanel_file_upload", Category: CategoryAuth},
{Name: "cpanel_file_upload_realtime", Category: CategoryAuth},
{Name: "cpanel_login", Category: CategoryAuth},
{Name: "cpanel_login_realtime", Category: CategoryAuth},
{Name: "cpanel_multi_ip_login", Category: CategoryAuth},
{Name: "cpanel_password_purge", Category: CategoryAuth},
{Name: "cpanel_password_purge_realtime", Category: CategoryAuth},
{Name: "ftp_auth_failure_realtime", Category: CategoryAuth},
{Name: "ftp_bruteforce", Category: CategoryAuth},
{Name: "ftp_login", Category: CategoryAuth},
{Name: "ftp_login_realtime", Category: CategoryAuth},
{Name: "credential_stuffing", Category: CategoryAuth},
{Name: "pam_bruteforce", Category: CategoryAuth},
{Name: "pam_login", Category: CategoryAuth},
{Name: "password_hijack_confirmed", Category: CategoryAuth},
{Name: "root_password_change", Category: CategoryAuth},
{Name: "shadow_change", Category: CategoryAuth},
{Name: "ssh_keys", Category: CategoryAuth},
{Name: "ssh_login_realtime", Category: CategoryAuth},
{Name: "ssh_login_unknown_ip", Category: CategoryAuth},
{Name: "sshd_config_change", Category: CategoryAuth},
{Name: "uid0_account", Category: CategoryAuth},
{Name: "webmail_bruteforce", Category: CategoryAuth},
{Name: "webmail_login_realtime", Category: CategoryAuth},
{Name: "whm_account_action", Category: CategoryAuth},
{Name: "whm_login_realtime", Category: CategoryAuth},
{Name: "whm_password_change", Category: CategoryAuth},
{Name: "whm_password_change_noninfra", Category: CategoryAuth},
{Name: "whm_unauth_scripts_realtime", Category: CategoryAuth},
// --- Brute Force -----------------------------------------------------
{Name: "http_request_flood", Category: CategoryBruteForce},
{Name: "http_ua_spoof", Category: CategoryBruteForce},
{Name: "http_distributed_flood", Category: CategoryBruteForce},
{Name: "mail_account_compromised", Category: CategoryBruteForce},
{Name: "mail_account_spray", Category: CategoryBruteForce},
{Name: "mail_bruteforce", Category: CategoryBruteForce},
{Name: "mail_subnet_spray", Category: CategoryBruteForce},
{Name: "smtp_account_spray", Category: CategoryBruteForce},
{Name: "smtp_bruteforce", Category: CategoryBruteForce},
{Name: "smtp_probe_abuse", Category: CategoryBruteForce},
{Name: "smtp_subnet_spray", Category: CategoryBruteForce},
{Name: "wp_login_bruteforce", Category: CategoryBruteForce},
{Name: "wp_user_enumeration", Category: CategoryBruteForce},
{Name: "xmlrpc_abuse", Category: CategoryBruteForce},
// --- Malware & Webshells --------------------------------------------
{Name: "backdoor_binary", Category: CategoryMalware},
{Name: "cgi_backdoor_realtime", Category: CategoryMalware},
{Name: "cgi_suspicious_location_realtime", Category: CategoryMalware},
{Name: "cross_account_malware", Category: CategoryMalware},
{Name: "executable_in_config_realtime", Category: CategoryMalware},
{Name: "executable_in_tmp_realtime", Category: CategoryMalware},
{Name: "fake_kernel_thread", Category: CategoryMalware},
{Name: "group_writable_php", Category: CategoryMalware},
{Name: "new_executable_in_config", Category: CategoryMalware},
{Name: "new_php_in_sensitive_dir", Category: CategoryMalware},
{Name: "new_php_in_sensitive_dir_clean", Category: CategoryMalware},
{Name: "new_php_in_uploads", Category: CategoryMalware},
{Name: "new_php_in_uploads_clean", Category: CategoryMalware},
{Name: "new_suspicious_php", Category: CategoryMalware},
{Name: "new_webshell_file", Category: CategoryMalware},
{Name: "nulled_plugin", Category: CategoryMalware},
{Name: "obfuscated_php", Category: CategoryMalware},
{Name: "obfuscated_php_realtime", Category: CategoryMalware},
{Name: "php_dropper_realtime", Category: CategoryMalware},
{Name: "php_in_sensitive_dir_realtime", Category: CategoryMalware},
{Name: "php_in_uploads_realtime", Category: CategoryMalware},
{Name: "php_shield_block", Category: CategoryMalware},
{Name: "php_shield_eval", Category: CategoryMalware},
{Name: "php_shield_webshell", Category: CategoryMalware},
{Name: "php_suspicious_execution", Category: CategoryMalware},
{Name: "signature_match_realtime", Category: CategoryMalware},
{Name: "suid_binary", Category: CategoryMalware},
{Name: "suspicious_file", Category: CategoryMalware},
{Name: "suspicious_php_content", Category: CategoryMalware},
{Name: "suspicious_process", Category: CategoryMalware},
{Name: "webshell", Category: CategoryMalware},
{Name: "webshell_content_realtime", Category: CategoryMalware},
{Name: "webshell_realtime", Category: CategoryMalware},
{Name: "world_writable_php", Category: CategoryMalware},
{Name: "yara_match_realtime", Category: CategoryMalware},
{Name: "yara_worker_crashed", Category: CategoryMalware},
// --- Web & Application ----------------------------------------------
{Name: "htaccess_auto_prepend", Category: CategoryWeb},
{Name: "htaccess_errordocument_hijack", Category: CategoryWeb},
{Name: "htaccess_filesmatch_shield", Category: CategoryWeb},
{Name: "htaccess_handler_abuse", Category: CategoryWeb},
{Name: "htaccess_header_injection", Category: CategoryWeb},
{Name: "htaccess_injection", Category: CategoryWeb},
{Name: "htaccess_injection_realtime", Category: CategoryWeb},
{Name: "htaccess_php_in_uploads", Category: CategoryWeb},
{Name: "htaccess_spam_redirect", Category: CategoryWeb},
{Name: "htaccess_user_agent_cloak", Category: CategoryWeb},
{Name: "open_basedir", Category: CategoryWeb},
{Name: "outdated_plugins", Category: CategoryWeb},
{Name: "php_config_change", Category: CategoryWeb},
{Name: "php_config_realtime", Category: CategoryWeb},
{Name: "symlink_attack", Category: CategoryWeb},
{Name: "wp_core_integrity", Category: CategoryWeb},
// --- Database Content -----------------------------------------------
{Name: "database_dump", Category: CategoryDatabase},
{Name: "db_malicious_event", Category: CategoryDatabase},
{Name: "db_malicious_function", Category: CategoryDatabase},
{Name: "db_malicious_procedure", Category: CategoryDatabase},
{Name: "admin_cross_account_overlap", Category: CategoryDatabase},
{Name: "credential_reuse", Category: CategoryDatabase},
// NOTE(review): supply_chain_vuln is categorized as Web & Application
// but listed under the Database banner. It still renders under Web in
// PublicCheckInfos; confirm whether the placement here is intentional.
{Name: "supply_chain_vuln", Category: CategoryWeb},
{Name: "db_magic_token_user", Category: CategoryDatabase},
{Name: "db_malicious_trigger", Category: CategoryDatabase},
{Name: "db_options_injection", Category: CategoryDatabase},
{Name: "db_post_injection", Category: CategoryDatabase},
{Name: "db_unexpected_event", Category: CategoryDatabase},
{Name: "db_unexpected_function", Category: CategoryDatabase},
{Name: "db_unexpected_procedure", Category: CategoryDatabase},
{Name: "db_unexpected_trigger", Category: CategoryDatabase},
{Name: "drupal_admin_injection", Category: CategoryDatabase},
{Name: "drupal_content_injection", Category: CategoryDatabase},
{Name: "drupal_settings_injection", Category: CategoryDatabase},
{Name: "joomla_admin_injection", Category: CategoryDatabase},
{Name: "joomla_content_injection", Category: CategoryDatabase},
{Name: "joomla_extensions_injection", Category: CategoryDatabase},
{Name: "magento_admin_injection", Category: CategoryDatabase},
{Name: "magento_content_injection", Category: CategoryDatabase},
{Name: "magento_settings_injection", Category: CategoryDatabase},
{Name: "opencart_admin_injection", Category: CategoryDatabase},
{Name: "opencart_content_injection", Category: CategoryDatabase},
{Name: "opencart_settings_injection", Category: CategoryDatabase},
{Name: "db_rogue_admin", Category: CategoryDatabase},
{Name: "db_siteurl_hijack", Category: CategoryDatabase},
{Name: "db_spam_cleaned", Category: CategoryDatabase},
{Name: "db_spam_found", Category: CategoryDatabase},
{Name: "db_spam_injection", Category: CategoryDatabase},
{Name: "db_suspicious_admin_email", Category: CategoryDatabase},
// --- Email & Phishing -----------------------------------------------
{Name: "credential_log_realtime", Category: CategoryEmail},
{Name: "email_auth_failure_realtime", Category: CategoryEmail},
{Name: "email_cloud_relay_abuse", Category: CategoryEmail},
{Name: "email_av_degraded", Category: CategoryEmail},
{Name: "email_av_parse_error", Category: CategoryEmail},
{Name: "email_av_quarantine_error", Category: CategoryEmail},
{Name: "email_av_scan_error", Category: CategoryEmail},
{Name: "email_av_timeout", Category: CategoryEmail},
{Name: "email_compromised_account", Category: CategoryEmail},
{Name: "email_credential_leak", Category: CategoryEmail},
{Name: "email_dkim_failure", Category: CategoryEmail},
{Name: "email_filter_blackhole", Category: CategoryEmail},
{Name: "email_filter_exfil", Category: CategoryEmail},
{Name: "email_filter_forwarder", Category: CategoryEmail},
{Name: "email_filter_pipe", Category: CategoryEmail},
{Name: "email_mail_filters", Category: CategoryEmail, Internal: true},
{Name: "email_malware", Category: CategoryEmail},
{Name: "email_phishing_content", Category: CategoryEmail},
{Name: "email_php_relay_abuse", Category: CategoryEmail},
{Name: "email_php_relay_account_volume_capped", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_action_dry_run", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_action_failed", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_action_skipped", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_cpanel_limit_unreadable", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_disabled", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_inotify_overflow", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_inotify_overflow_recovered", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_msgindex_persist_failed", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_no_exim", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_overflow_scan_truncated", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_path2b_disabled", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_policies_reload", Category: CategoryEmail},
{Name: "email_php_relay_rate_limit_hit", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_sweep_failed", Category: CategoryEmail, Internal: true},
{Name: "email_php_relay_watcher_failed", Category: CategoryEmail, Internal: true},
{Name: "email_defer_fail_governor", Category: CategoryEmail},
{Name: "email_pipe_forwarder", Category: CategoryEmail},
{Name: "email_rate_critical", Category: CategoryEmail},
{Name: "email_rate_warning", Category: CategoryEmail},
{Name: "email_spam_outbreak", Category: CategoryEmail},
{Name: "email_spf_rejection", Category: CategoryEmail},
{Name: "email_suspicious_forwarder", Category: CategoryEmail},
{Name: "email_suspicious_geo", Category: CategoryEmail},
{Name: "email_weak_password", Category: CategoryEmail},
{Name: "exim_frozen_realtime", Category: CategoryEmail},
{Name: "mail_per_account", Category: CategoryEmail},
{Name: "mail_queue", Category: CategoryEmail},
{Name: "phishing_credential_log", Category: CategoryEmail},
{Name: "phishing_directory", Category: CategoryEmail},
{Name: "phishing_iframe", Category: CategoryEmail},
{Name: "phishing_kit_archive", Category: CategoryEmail},
{Name: "phishing_kit_realtime", Category: CategoryEmail},
{Name: "phishing_page", Category: CategoryEmail},
{Name: "phishing_php", Category: CategoryEmail},
{Name: "phishing_realtime", Category: CategoryEmail},
{Name: "phishing_redirector", Category: CategoryEmail},
// --- Performance -----------------------------------------------------
{Name: "perf_error_logs", Category: CategoryPerformance},
{Name: "perf_load", Category: CategoryPerformance},
{Name: "perf_memory", Category: CategoryPerformance},
{Name: "perf_mysql_config", Category: CategoryPerformance},
{Name: "perf_php_handler", Category: CategoryPerformance},
{Name: "perf_php_processes", Category: CategoryPerformance},
{Name: "perf_redis_config", Category: CategoryPerformance},
{Name: "perf_wp_config", Category: CategoryPerformance},
{Name: "perf_wp_cron", Category: CategoryPerformance},
{Name: "perf_wp_transients", Category: CategoryPerformance},
// --- Network & Firewall ---------------------------------------------
{Name: "backdoor_port", Category: CategoryNetwork},
{Name: "backdoor_port_outbound", Category: CategoryNetwork},
{Name: "c2_connection", Category: CategoryNetwork},
{Name: "direct_smtp_egress", Category: CategoryNetwork},
{Name: "dns_connection", Category: CategoryNetwork},
{Name: "dns_zone_change", Category: CategoryNetwork},
{Name: "exfiltration_paste_site", Category: CategoryNetwork},
{Name: "firewall", Category: CategoryNetwork},
{Name: "firewall_ports", Category: CategoryNetwork},
{Name: "bad_asn_outbound", Category: CategoryNetwork},
{Name: "infra_ips_unresolvable", Category: CategoryNetwork},
{Name: "ip_reputation", Category: CategoryNetwork},
{Name: "ssl_cert_issued", Category: CategoryNetwork},
{Name: "user_outbound_connection", Category: CategoryNetwork},
// --- System Integrity ------------------------------------------------
{Name: "af_alg_enforcement_corrected", Category: CategorySystem},
{Name: "af_alg_socket_use", Category: CategorySystem},
{Name: "account_scan_truncated", Category: CategorySystem},
{Name: "bpf_unavailable", Category: CategorySystem},
{Name: "crond_change", Category: CategorySystem},
{Name: "crontab_change", Category: CategorySystem},
{Name: "dpkg_integrity", Category: CategorySystem},
{Name: "kernel_module", Category: CategorySystem},
{Name: "mysql_superuser", Category: CategorySystem},
{Name: "rpm_integrity", Category: CategorySystem},
{Name: "sensitive_file_modified", Category: CategorySystem},
{Name: "signature_update_rescan_queued", Category: CategorySystem},
{Name: "suspicious_crontab", Category: CategorySystem},
// --- WAF & ModSecurity ----------------------------------------------
{Name: "modsec_block_escalation", Category: CategoryWAF},
{Name: "modsec_block_realtime", Category: CategoryWAF},
{Name: "modsec_csm_block_escalation", Category: CategoryWAF},
{Name: "modsec_warning_realtime", Category: CategoryWAF},
{Name: "waf_attack_blocked", Category: CategoryWAF},
{Name: "waf_bypass", Category: CategoryWAF},
{Name: "waf_detection_only", Category: CategoryWAF},
{Name: "waf_rules", Category: CategoryWAF},
{Name: "waf_rules_stale", Category: CategoryWAF},
{Name: "waf_status", Category: CategoryWAF},
// --- Correlation & Health -------------------------------------------
{Name: "account_scan", Category: CategoryCorrelation},
{Name: "auto_block", Category: CategoryCorrelation},
{Name: "auto_response", Category: CategoryCorrelation},
{Name: "challenge_route", Category: CategoryCorrelation},
{Name: "check_timeout", Category: CategoryCorrelation},
{Name: "config_reload_error", Category: CategoryCorrelation},
{Name: "config_reload_restart_required", Category: CategoryCorrelation},
{Name: "coordinated_attack", Category: CategoryCorrelation},
{Name: "csm_health", Category: CategoryCorrelation},
{Name: "fanotify_overflow", Category: CategoryCorrelation},
{Name: "integrity", Category: CategoryCorrelation},
{Name: "local_threat_score", Category: CategoryCorrelation},
{Name: "mail_log_source_unavailable", Category: CategoryCorrelation},
// --- Internal (not shown in user-facing dropdowns) -------------------
{Name: "test_alert", Category: CategoryInternal, Internal: true},
}
// AllCheckNames lists the Name of every registry entry - internal ones
// included - in ascending alphabetical order. A fresh slice is returned on
// each call. User-facing UI should use PublicCheckInfos instead.
func AllCheckNames() []string {
	names := make([]string, len(checkRegistry))
	for i := range checkRegistry {
		names[i] = checkRegistry[i].Name
	}
	sort.Strings(names)
	return names
}
// PublicCheckInfos returns every check that is not marked Internal, ordered
// first by category (following checkCategoryOrder) and then alphabetically
// by Name inside each category. This is the list the settings UI shows for
// alerts.email.disabled_checks. Entries whose Category does not appear in
// checkCategoryOrder are not returned.
func PublicCheckInfos() []CheckInfo {
	grouped := make(map[string][]CheckInfo, len(checkCategoryOrder))
	for i := range checkRegistry {
		info := checkRegistry[i]
		if !info.Internal {
			grouped[info.Category] = append(grouped[info.Category], info)
		}
	}

	var ordered []CheckInfo
	for _, category := range checkCategoryOrder {
		bucket := grouped[category]
		sort.Slice(bucket, func(a, b int) bool { return bucket[a].Name < bucket[b].Name })
		ordered = append(ordered, bucket...)
	}
	return ordered
}
// LookupCheck finds the registry entry whose Name equals name exactly.
// The boolean is false, and the CheckInfo is the zero value, when no entry
// matches. The registry is scanned linearly and the first match wins.
func LookupCheck(name string) (CheckInfo, bool) {
	for i := range checkRegistry {
		if checkRegistry[i].Name == name {
			return checkRegistry[i], true
		}
	}
	return CheckInfo{}, false
}
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"
)
// eximMsgIDRegex validates Exim message ID format. Exim 4.96 and older use
// 6-6-2 ids; Exim 4.97 and newer use 6-11-4 ids. Every group is base-62
// (digits and ASCII letters) and the pattern is anchored at both ends, so
// only a string that is exactly one id matches. Compiled once at package
// scope.
var eximMsgIDRegex = regexp.MustCompile(`^[0-9A-Za-z]{6}-(?:[0-9A-Za-z]{6}-[0-9A-Za-z]{2}|[0-9A-Za-z]{11}-[0-9A-Za-z]{4})$`)
// Allowed roots for each fix action. Declared as vars (not consts) so tests
// can redirect remediation under t.TempDir() without writing to real /home,
// /tmp, or /var/spool. Production must not mutate these at runtime.
var (
// fixPermissionsAllowedRoots is passed to resolveExistingFixPath by
// fixPermissions.
fixPermissionsAllowedRoots = []string{"/home"}
// fixQuarantineAllowedRoots is passed to resolveExistingFixPath by
// fixQuarantine.
fixQuarantineAllowedRoots = []string{"/home", "/tmp", "/dev/shm", "/var/tmp"}
// fixHtaccessAllowedRoots - presumably the equivalent allow-list for the
// .htaccess fix; its use is outside this view (TODO confirm).
fixHtaccessAllowedRoots = []string{"/home"}
// eximSpoolDirs - presumably the candidate Exim spool input directories
// searched by the spool-message fix; its use is outside this view
// (TODO confirm).
eximSpoolDirs = []string{"/var/spool/exim/input", "/var/spool/exim4/input"}
)
// RemediationResult describes the outcome of a fix action. Failures are
// reported by leaving Success false and setting Error; successful fixes set
// Success, Action and Description.
type RemediationResult struct {
// Success is true only when the fix was applied.
Success bool `json:"success"`
// Action is the concrete operation performed, e.g. "chmod 644 <path>".
Action string `json:"action"` // human-readable description of what was done
// Description is a short summary of the effect of the fix.
Description string `json:"description"` // what fix was applied
// Error is the failure reason; omitted from JSON when empty.
Error string `json:"error,omitempty"`
}
// FixDescription previews, in one human-readable sentence, what ApplyFix
// would do for a finding of the given check type. The target path comes from
// selectFindingPath(message, filePath...). An empty string is returned when
// the check type has no automated fix or the detail the sentence needs
// (file path, or Exim message id for email_phishing_content) could not be
// determined. suspicious_crontab is the one type that still yields a generic
// sentence without a path.
func FixDescription(checkType, message string, filePath ...string) string {
	target := selectFindingPath(message, filePath...)

	// template is the path-based sentence for the check type; the cases that
	// do not depend solely on the path return directly.
	var template string
	switch checkType {
	case "email_phishing_content":
		if id := extractEximMsgID(message); id != "" {
			return fmt.Sprintf("Quarantine Exim spool message %s", id)
		}
		return ""
	case "suspicious_crontab":
		if target == "" {
			return "Quarantine and truncate crontab"
		}
		template = "Quarantine and truncate crontab %s"
	case "world_writable_php", "group_writable_php":
		template = "Set permissions to 644 on %s"
	case "webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory":
		template = "Quarantine %s to /opt/csm/quarantine/"
	case "backdoor_binary", "new_executable_in_config":
		template = "Kill process and quarantine %s"
	case "htaccess_injection", "htaccess_handler_abuse",
		"htaccess_auto_prepend", "htaccess_errordocument_hijack",
		"htaccess_filesmatch_shield", "htaccess_header_injection",
		"htaccess_php_in_uploads", "htaccess_spam_redirect",
		"htaccess_user_agent_cloak":
		template = "Remove malicious directives from %s"
	default:
		return ""
	}

	if target == "" {
		return ""
	}
	return fmt.Sprintf(template, target)
}
// HasFix reports whether checkType has a known automated fix, i.e. whether
// ApplyFix has a dedicated branch for it. The match is exact and
// case-sensitive; unknown or empty names return false.
//
// A switch over constant strings is used instead of a lookup table so that
// no map is allocated and populated on every call.
func HasFix(checkType string) bool {
	switch checkType {
	case "world_writable_php", "group_writable_php",
		"webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory",
		"backdoor_binary", "new_executable_in_config",
		"htaccess_injection", "htaccess_handler_abuse",
		// Per-pattern findings from the hardened detectors. Each routes
		// through CleanHtaccessFile, which runs the full registry and
		// removes every detector's matched ranges atomically.
		"htaccess_auto_prepend", "htaccess_errordocument_hijack",
		"htaccess_filesmatch_shield", "htaccess_header_injection",
		"htaccess_php_in_uploads", "htaccess_spam_redirect",
		"htaccess_user_agent_cloak",
		"email_phishing_content",
		"suspicious_crontab":
		return true
	}
	return false
}
// ApplyFix runs the automated remediation that matches checkType and returns
// its result. The target path is taken from
// selectFindingPath(message, filePath...); details is only consulted by the
// kill-and-quarantine fix. A check type without an automated fix produces a
// result whose Error names that type.
func ApplyFix(checkType, message, details string, filePath ...string) RemediationResult {
	target := selectFindingPath(message, filePath...)

	switch checkType {
	case "email_phishing_content":
		return fixQuarantineSpoolMessage(message)

	case "suspicious_crontab":
		return fixSuspiciousCrontab(target)

	case "world_writable_php", "group_writable_php":
		return fixPermissions(target)

	case "backdoor_binary", "new_executable_in_config":
		return fixKillAndQuarantine(target, details)

	case "webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory":
		return fixQuarantine(target)

	case "htaccess_injection", "htaccess_handler_abuse":
		return fixHtaccess(target, message)

	case "htaccess_auto_prepend", "htaccess_errordocument_hijack",
		"htaccess_filesmatch_shield", "htaccess_header_injection",
		"htaccess_php_in_uploads", "htaccess_spam_redirect",
		"htaccess_user_agent_cloak":
		// Per-pattern findings emit alongside the generic
		// htaccess_injection / htaccess_handler_abuse categories from
		// the existing detector. Both routes converge on byte-range
		// cleaning here -- CleanHtaccessFile re-runs the full
		// detector registry, so a single click cleans every malicious
		// directive the audit found.
		return CleanHtaccessFile(target)
	}

	return RemediationResult{Error: fmt.Sprintf("no automated fix available for check type '%s'", checkType)}
}
// fixPermissions resets the mode of the file at path to 0644. The path must
// first be accepted by resolveExistingFixPath against
// fixPermissionsAllowedRoots; the chmod is applied to the path that helper
// returns. On success the result records the previous permission bits.
func fixPermissions(path string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}

	resolved, info, err := resolveExistingFixPath(path, fixPermissionsAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	previous := info.Mode().Perm()

	// #nosec G302 -- Intentional: this is the remediation that sets the
	// canonical "safe web content" mode on a user file after we flagged
	// the file as having dangerous perms (e.g. 0777). 0644 is what the
	// webserver needs to serve static content as the file owner.
	if chmodErr := os.Chmod(resolved, 0644); chmodErr != nil {
		return RemediationResult{Error: fmt.Sprintf("chmod failed: %v", chmodErr)}
	}

	var outcome RemediationResult
	outcome.Success = true
	outcome.Action = fmt.Sprintf("chmod 644 %s", resolved)
	outcome.Description = fmt.Sprintf("Changed permissions from %o to 644", previous)
	return outcome
}
// fixQuarantine moves the file or directory at path into quarantineDir and
// writes a "<quarantine path>.meta" JSON sidecar (original path, owner,
// group, mode, size, timestamp, reason) for a later restore.
//
// The path must be accepted by resolveExistingFixPath against
// fixQuarantineAllowedRoots. The quarantine name is a second-resolution
// timestamp followed by the resolved path with every "/" replaced by "_".
// A failure to write the sidecar is logged to stderr and does not fail the
// quarantine.
func fixQuarantine(path string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	resolved, info, err := resolveExistingFixPath(path, fixQuarantineAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}

	// Best effort: if the directory cannot be created, the move below fails
	// and reports the error.
	_ = os.MkdirAll(quarantineDir, 0700)

	flattened := strings.ReplaceAll(resolved, "/", "_")
	stamp := time.Now().Format("20060102-150405")
	qPath := filepath.Join(quarantineDir, stamp+"_"+flattened)

	// Directories use the standard rename (they're rare in quarantine
	// remediation and harder to TOCTOU-swap atomically). Regular files
	// go through the fd-based safe path which closes the detect-then-
	// rename race window.
	switch {
	case info.IsDir():
		if renameErr := os.Rename(resolved, qPath); renameErr != nil {
			return RemediationResult{Error: fmt.Sprintf("cannot quarantine directory: %v", renameErr)}
		}
	default:
		if moveErr := quarantineFileTOCTOUSafe(resolved, qPath, info); moveErr != nil {
			return RemediationResult{Error: moveErr.Error()}
		}
	}

	// Ownership is only available when the platform stat is exposed;
	// otherwise both ids are recorded as 0.
	var ownerUID, ownerGID int
	if st, ok := info.Sys().(*syscall.Stat_t); ok {
		ownerUID, ownerGID = int(st.Uid), int(st.Gid)
	}

	// Sidecar consumed by restore. The marshal error is ignored: every
	// value in the map is a plain string, integer or time.Time.
	sidecar := qPath + ".meta"
	payload, _ := json.MarshalIndent(map[string]interface{}{
		"original_path": resolved,
		"owner_uid":     ownerUID,
		"group_gid":     ownerGID,
		"mode":          info.Mode().String(),
		"size":          info.Size(),
		"quarantine_at": time.Now(),
		"reason":        "Fixed via CSM Web UI",
	}, "", " ")
	if writeErr := os.WriteFile(sidecar, payload, 0600); writeErr != nil {
		fmt.Fprintf(os.Stderr, "remediate: error writing quarantine metadata %s: %v\n", sidecar, writeErr)
	}

	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("quarantined %s → %s", resolved, qPath),
		Description: fmt.Sprintf("Moved to quarantine: %s", qPath),
	}
}
// fixKillAndQuarantine sends SIGKILL to the process named in details (when
// one can be identified and it is safe to do so), then quarantines path via
// fixQuarantine.
//
// The kill is skipped when no PID can be extracted, when the PID is 0 or 1,
// or when the owning UID is root or cannot be determined. The returned
// result only mentions the kill when the signal was actually delivered;
// otherwise the plain quarantine result is returned unchanged so the
// operator is not told a process was killed when it was not.
func fixKillAndQuarantine(path, details string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}

	// Try to extract and kill PID from details.
	pid := extractPID(details)
	killed := false
	if pid != "" {
		pidInt := 0
		// On a parse failure pidInt stays 0 and the guard below skips the kill.
		if _, scanErr := fmt.Sscanf(pid, "%d", &pidInt); scanErr == nil && pidInt > 1 {
			uid := getProcessUID(pid)
			if uid != "0" && uid != "" { // never kill root
				killed = syscall.Kill(pidInt, syscall.SIGKILL) == nil
			}
		}
	}

	// Then quarantine, whether or not anything was killed.
	result := fixQuarantine(path)
	if result.Success && killed {
		result.Action = fmt.Sprintf("killed PID %s and %s", pid, result.Action)
		result.Description = "Process killed and file quarantined"
	}
	return result
}
// fixHtaccess rewrites an .htaccess file without the directives it judges
// malicious, keeping comments and directives that carry a known-safe marker
// (e.g., Wordfence, LiteSpeed).
//
// A directive is dropped when phpHandlerRemapsNonPHPInContext flags it, or
// when its lower-cased text contains one of the dangerous markers and none of
// the safe markers. If nothing is dropped the file is left untouched and an
// error result is returned.
func fixHtaccess(path, message string) RemediationResult {
	switch {
	case path == "":
		return RemediationResult{Error: "could not extract file path"}
	case filepath.Base(path) != ".htaccess":
		return RemediationResult{Error: "automated .htaccess remediation only applies to .htaccess files"}
	}

	target, _, err := resolveExistingFixPath(path, fixHtaccessAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	raw, err := osFS.ReadFile(target)
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot read: %v", err)}
	}

	dangerous := []string{"auto_prepend_file", "auto_append_file", "eval(", "base64_decode",
		"gzinflate", "str_rot13", "addhandler", "sethandler"}
	safe := []string{
		"wordfence-waf.php", "litespeed", "advanced-headers.php", "rsssl",
		"application/x-httpd-php", "application/x-httpd-ea-php", "application/x-httpd-alt-php",
		"-execcgi", "sethandler none", "sethandler default-handler",
		"text/html", "text/css", "text/javascript", "application/javascript",
		"image/", "font/", ".woff", ".woff2", ".ttf", ".eot", ".svg",
		"wordfence",
	}
	containsAny := func(haystack string, needles []string) bool {
		for _, needle := range needles {
			if strings.Contains(haystack, needle) {
				return true
			}
		}
		return false
	}

	var (
		kept         []string
		dropped      int
		handlerStack []phpHandlerOverlay
	)

	// Iterate logical directives so a malicious mapping split across an Apache
	// line continuation is removed as a unit (every physical line it spans).
	for _, directive := range joinHtaccessContinuations(strings.Split(string(raw), "\n")) {
		text := strings.TrimSpace(directive.text)

		// Comments are always preserved.
		if strings.HasPrefix(text, "#") {
			kept = append(kept, directive.lines...)
			continue
		}
		// Container open/close lines are preserved; they only adjust the
		// stack of handler contexts consulted for the lines in between.
		if overlay, opened := openPHPHandlerContext(text); opened {
			handlerStack = append(handlerStack, overlay)
			kept = append(kept, directive.lines...)
			continue
		}
		if closesPHPHandlerContext(text) {
			if depth := len(handlerStack); depth > 0 {
				handlerStack = handlerStack[:depth-1]
			}
			kept = append(kept, directive.lines...)
			continue
		}

		lowered := strings.ToLower(text)
		malicious := phpHandlerRemapsNonPHPInContext(lowered, handlerStack)
		if !malicious && containsAny(lowered, dangerous) && !containsAny(lowered, safe) {
			malicious = true
		}
		if malicious {
			dropped++
			continue
		}
		kept = append(kept, directive.lines...)
	}

	if dropped == 0 {
		return RemediationResult{Error: "no malicious directives found to remove"}
	}

	// #nosec G306 -- .htaccess rewritten for a user's public_html; 0644 is
	// the mode the webserver expects for static content.
	if err := os.WriteFile(target, []byte(strings.Join(kept, "\n")), 0644); err != nil {
		return RemediationResult{Error: fmt.Sprintf("write failed: %v", err)}
	}

	var res RemediationResult
	res.Success = true
	res.Action = fmt.Sprintf("removed %d malicious directive(s) from %s", dropped, target)
	res.Description = fmt.Sprintf("Cleaned .htaccess: removed %d line(s)", dropped)
	return res
}
// extractFilePathFromMessage extracts a file path from a finding message.
// Handles patterns like "World-writable PHP file: /path/to/file"
// and "Webshell found: /path/to/file".
//
// Only paths under /home/, /tmp/, /dev/shm/ and /var/tmp/ are recognised;
// the prefixes are tried in that order and the first one found wins. A
// prefix only counts when it begins a path, i.e. it sits at the start of the
// message or is not preceded by a file-name character. Without that rule the
// "/tmp/" prefix matched inside "/var/tmp/evil.php" and the function returned
// "/tmp/evil.php" -- a different file from the one the finding reported.
//
// The path ends at the first space, comma, newline or ')' or at the end of
// the message. Returns "" when no recognised path is present.
func extractFilePathFromMessage(message string) string {
	// isNameByte reports whether b can be part of a preceding path component.
	isNameByte := func(b byte) bool {
		switch {
		case b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= '0' && b <= '9':
			return true
		case b == '/' || b == '.' || b == '_' || b == '-':
			return true
		}
		return false
	}
	// pathStart returns the first index where prefix begins a path, or -1.
	pathStart := func(prefix string) int {
		for from := 0; from < len(message); {
			rel := strings.Index(message[from:], prefix)
			if rel < 0 {
				return -1
			}
			at := from + rel
			if at == 0 || !isNameByte(message[at-1]) {
				return at
			}
			// Match was the tail of a longer path; keep looking.
			from = at + 1
		}
		return -1
	}

	for _, prefix := range []string{"/home/", "/tmp/", "/dev/shm/", "/var/tmp/"} {
		start := pathStart(prefix)
		if start < 0 {
			continue
		}
		rest := message[start:]
		if end := strings.IndexAny(rest, " ,\n)"); end >= 0 {
			return rest[:end]
		}
		return rest
	}
	return ""
}
// selectFindingPath picks the path a remediation should act on. The first
// variadic filePath value wins when it is non-blank after trimming;
// otherwise the path is parsed out of message.
func selectFindingPath(message string, filePath ...string) string {
	if len(filePath) == 0 {
		return extractFilePathFromMessage(message)
	}
	explicit := strings.TrimSpace(filePath[0])
	if explicit == "" {
		return extractFilePathFromMessage(message)
	}
	return explicit
}
// resolveExistingFixPath validates path for automated remediation and
// returns its symlink-resolved form together with the Lstat info of that
// resolved path.
//
// Checks, in order:
//  1. the cleaned path must be absolute and inside one of allowedRoots;
//  2. the path must exist and must not itself be a symlink;
//  3. after resolving symlinks in parent components, the result must still
//     be inside allowedRoots;
//  4. when the input lives under /home/<account>/..., the resolved path must
//     stay inside that same account directory;
//  5. the resolved path must exist and must not be a symlink.
//
// Underlying filesystem errors are wrapped with %w so callers can still
// inspect them with errors.Is / errors.As; the message text is unchanged.
func resolveExistingFixPath(path string, allowedRoots []string) (string, os.FileInfo, error) {
	cleanPath, err := sanitizeFixPath(path, allowedRoots)
	if err != nil {
		return "", nil, err
	}

	info, err := osFS.Lstat(cleanPath)
	if err != nil {
		return "", nil, fmt.Errorf("file not found: %w", err)
	}
	if info.Mode()&os.ModeSymlink != 0 {
		return "", nil, fmt.Errorf("symlinked paths are not eligible for automated remediation: %s", cleanPath)
	}

	// The leaf is not a symlink, but a parent directory may be; resolve and
	// re-validate so a linked directory cannot redirect the fix elsewhere.
	resolved, err := filepath.EvalSymlinks(cleanPath)
	if err != nil {
		return "", nil, fmt.Errorf("cannot resolve path: %w", err)
	}
	resolved, err = sanitizeFixPath(resolved, allowedRoots)
	if err != nil {
		return "", nil, err
	}
	if accountRoot := homeAccountRoot(cleanPath); accountRoot != "" && !isPathWithinOrEqual(resolved, accountRoot) {
		return "", nil, fmt.Errorf("resolved path escapes account boundary: %s", resolved)
	}

	resolvedInfo, err := osFS.Lstat(resolved)
	if err != nil {
		return "", nil, fmt.Errorf("file not found: %w", err)
	}
	if resolvedInfo.Mode()&os.ModeSymlink != 0 {
		return "", nil, fmt.Errorf("symlinked paths are not eligible for automated remediation: %s", resolved)
	}
	return resolved, resolvedInfo, nil
}
// sanitizeFixPath trims and cleans path and checks that it is absolute and
// located inside (or equal to) one of allowedRoots. It returns the cleaned
// path on success.
//
// Emptiness is tested before filepath.Clean runs: Clean turns "" into ".",
// so testing afterwards could never report "file path is required" and blank
// input was misreported as a non-absolute path.
func sanitizeFixPath(path string, allowedRoots []string) (string, error) {
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return "", fmt.Errorf("file path is required")
	}
	cleaned := filepath.Clean(trimmed)
	if !filepath.IsAbs(cleaned) {
		return "", fmt.Errorf("file path must be absolute")
	}
	for _, root := range allowedRoots {
		if isPathWithinOrEqual(cleaned, root) {
			return cleaned, nil
		}
	}
	return "", fmt.Errorf("file path is outside the allowed remediation roots: %s", cleaned)
}

// isPathWithinOrEqual reports whether path, after cleaning, is base itself or
// lies below it. The comparison is done on whole path components, so
// "/homeevil" is not considered to be within "/home".
func isPathWithinOrEqual(path, base string) bool {
	cleanPath := filepath.Clean(path)
	cleanBase := filepath.Clean(base)
	return cleanPath == cleanBase || strings.HasPrefix(cleanPath, cleanBase+string(filepath.Separator))
}
// homeAccountRoot returns "/home/<account>" for a path that lies strictly
// below such a directory. It returns "" for paths outside /home/ and for
// paths with no component after the account name (e.g. "/home/alice").
func homeAccountRoot(path string) string {
	const homePrefix = "/home/"

	cleaned := filepath.Clean(path)
	if !strings.HasPrefix(cleaned, homePrefix) {
		return ""
	}
	// After Clean there are no empty components, so the text up to the next
	// separator is the account name; no separator means nothing lies below it.
	remainder := cleaned[len(homePrefix):]
	sep := strings.Index(remainder, string(filepath.Separator))
	if sep < 0 {
		return ""
	}
	return filepath.Join("/home", remainder[:sep])
}
// extractEximMsgID pulls the Exim message ID out of a finding message.
// It looks for the "(message: XXXXXX-XXXXXX-XX)" form used by emailscan.go
// and returns the whitespace-trimmed text between the marker and the next
// ")". Returns "" when the marker or the closing parenthesis is missing.
func extractEximMsgID(message string) string {
	const marker = "(message: "

	start := strings.Index(message, marker)
	if start == -1 {
		return ""
	}
	tail := message[start+len(marker):]
	if closeAt := strings.IndexByte(tail, ')'); closeAt >= 0 {
		return strings.TrimSpace(tail[:closeAt])
	}
	return ""
}
// fixQuarantineSpoolMessage moves Exim spool files (-H header and -D body)
// for a message ID into quarantine.
//
// The message ID is parsed from message and must match eximMsgIDRegex. The
// spool directory is the first entry of eximSpoolDirs that holds the -H file.
// Each file is renamed into quarantineDir; when the rename fails the file is
// copied and the original removed instead. If the original cannot be removed
// the copy is discarded and an error is returned, because the message would
// otherwise still be sitting in the spool while the result claimed it had
// been quarantined. A JSON sidecar describing the message is written on a
// best-effort basis.
func fixQuarantineSpoolMessage(message string) RemediationResult {
	msgID := extractEximMsgID(message)
	if msgID == "" {
		return RemediationResult{Error: "could not extract Exim message ID from finding"}
	}
	// Validate Exim message ID format to prevent path traversal.
	if !eximMsgIDRegex.MatchString(msgID) {
		return RemediationResult{Error: fmt.Sprintf("invalid Exim message ID format: %s", msgID)}
	}

	var spoolDir string
	for _, dir := range eximSpoolDirs {
		if _, err := osFS.Stat(filepath.Join(dir, msgID+"-H")); err == nil {
			spoolDir = dir
			break
		}
	}
	if spoolDir == "" {
		return RemediationResult{Error: fmt.Sprintf("spool message %s not found (already delivered or removed)", msgID)}
	}

	// Error intentionally ignored: a missing quarantine directory makes the
	// move below fail, and that failure is reported.
	_ = os.MkdirAll(quarantineDir, 0700)
	ts := time.Now().Format("20060102-150405")

	moved := 0
	for _, suffix := range []string{"-H", "-D"} {
		src := filepath.Join(spoolDir, msgID+suffix)
		if _, err := osFS.Stat(src); err != nil {
			continue
		}
		dst := filepath.Join(quarantineDir, fmt.Sprintf("%s_exim_%s%s", ts, msgID, suffix))
		if err := os.Rename(src, dst); err != nil {
			// Cross-device fallback: copy, then delete the original.
			data, readErr := osFS.ReadFile(src)
			if readErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot read %s: %v", src, readErr)}
			}
			if writeErr := os.WriteFile(dst, data, 0600); writeErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot write quarantine: %v", writeErr)}
			}
			if rmErr := os.Remove(src); rmErr != nil {
				// The original is still in the spool, so nothing was really
				// quarantined. Drop the copy and report the failure.
				_ = os.Remove(dst)
				return RemediationResult{Error: fmt.Sprintf("cannot remove %s from spool: %v", src, rmErr)}
			}
		}
		moved++
	}
	if moved == 0 {
		return RemediationResult{Error: fmt.Sprintf("no spool files found for message %s", msgID)}
	}

	// Write metadata sidecar.
	meta := map[string]interface{}{
		"message_id":    msgID,
		"spool_dir":     spoolDir,
		"quarantine_at": time.Now(),
		"reason":        "Phishing email quarantined via CSM Web UI",
	}
	metaData, _ := json.MarshalIndent(meta, "", " ")
	metaPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_exim_%s.meta", ts, msgID))
	if err := os.WriteFile(metaPath, metaData, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "remediate: error writing spool quarantine metadata %s: %v\n", metaPath, err)
	}

	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("quarantined spool message %s (%d files)", msgID, moved),
		Description: fmt.Sprintf("Exim spool files moved to quarantine for message %s", msgID),
	}
}
package checks
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/threatintel"
)
// Tunables and fixed paths for the IP reputation check.
const (
// reputationCacheFile is the flat-file cache name under the state path,
// used only when no bbolt store is available.
reputationCacheFile = "reputation_cache.json"
// Log files tailed by collectRecentIPs for SMTP and cPanel/WHM activity.
reputationEximMainlog = "/var/log/exim_mainlog"
reputationWHMAccessLog = "/usr/local/cpanel/logs/access_log"
// cacheExpiry is how long a cached AbuseIPDB result is considered fresh;
// cleanCache also uses it as the eviction age.
cacheExpiry = 6 * time.Hour
errorCacheExpiry = 1 * time.Hour // cache transient API errors to avoid retrying same IP
// abuseConfidenceThreshold is the minimum score (AbuseIPDB or
// supplemental) at which CheckIPReputation emits a finding.
abuseConfidenceThreshold = 50
maxQueriesPerCycle = 5 // max AbuseIPDB API calls per 10-min cycle (~720/day, fits free tier)
maxCacheEntries = 5000 // cap cache size
)
// maxDailyAbuseQueries is the store-backed daily circuit-breaker below
// the 1000/day free-tier ceiling. The 100-slot cushion below 1000 leaves
// room for API-side accounting differences and fallback paths that cannot
// share the store counter. Declared as a var (not const) so tests can lower it
// without burning seconds on 900 bbolt transactions. Production callers
// must not modify this.
var maxDailyAbuseQueries = 900
// abuseIPDBEndpoint is the URL queried for IP reputation. Declared as a
// var (not const) so tests can point it at an httptest server. Production
// callers must not modify this.
var abuseIPDBEndpoint = "https://api.abuseipdb.com/api/v2/check"
// abuseIPDBClient is the HTTP client used for AbuseIPDB queries. Declared
// at package scope so tests can swap in a mock client (e.g., one whose
// transport routes all traffic to an httptest server). The 10-second
// timeout bounds each individual lookup.
var abuseIPDBClient = &http.Client{Timeout: 10 * time.Second}
// reputationCache holds cached AbuseIPDB lookups keyed by IP address string.
// It is the in-memory working set for one CheckIPReputation run and the JSON
// shape of the flat-file fallback.
type reputationCache struct {
Entries map[string]*reputationEntry `json:"entries"`
}
// reputationEntry is one cached lookup result.
type reputationEntry struct {
// Score is the AbuseIPDB confidence score; CheckIPReputation stores -1
// for a lookup that failed.
Score int `json:"score"`
// Category is the descriptive text built by queryAbuseIPDB, or
// "error: ..." for a failed lookup.
Category string `json:"category"`
// CheckedAt is when the entry was recorded. Error entries are written with
// a back-dated value so they expire sooner than successful ones.
CheckedAt time.Time `json:"checked_at"`
}
// CheckIPReputation looks up non-infra IPs against threat intelligence.
// Four-tier approach:
// 1. Skip if already blocked
// 2. Check local threat DB (permanent blocklist + free feeds)
// 3. Check AbuseIPDB cache
// 4. Query AbuseIPDB for truly unknown IPs (max 5/cycle, ~720/day)
//
// The candidate IPs come from collectRecentIPs. ctx is only forwarded to the
// supplemental aggregator; the AbuseIPDB HTTP calls are bounded by the
// client's own timeout. The *state.Store parameter is unused. Returns the
// findings produced this cycle, or nil when no IPs were collected. As a side
// effect the reputation cache is cleaned and persisted before returning.
func CheckIPReputation(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
supplementalAgg := newSupplementalThreatAggregator(cfg)
ips := collectRecentIPs(cfg)
if len(ips) == 0 {
return nil
}
alreadyBlocked := loadAllBlockedIPs(cfg.StatePath)
threatDB := GetThreatDB()
cache := loadReputationCache(cfg.StatePath)
client := abuseIPDBClient
sdb := store.Global()
now := time.Now()
utcDay := now.UTC().Format("2006-01-02")
quotaExhausted := !abuseQuotaReady(sdb, now)
// Two-pass design so the slow AbuseIPDB HTTP queries can run in
// parallel:
//
// Pass 1 (serial): walk every IP and resolve via tier 1/2/3 plus
// the supplemental aggregator. Collect the IPs that genuinely
// need a tier-4 HTTP lookup into pendingQueries.
//
// Pass 2 (parallel, up to maxQueriesPerCycle workers): fan out
// queryAbuseIPDB and collect results.
//
// Pass 3 (serial): apply results back into the cache and emit
// findings.
//
// Pre-cache, all five HTTP queries ran in a serial loop, so a
// cycle paid ~5x worst-case AbuseIPDB latency. A busy production
// host saw ip_reputation averaging ~3.6 s per run because of
// this; the fan-out brings that down to ~max(single-call latency).
type pendingQuery struct {
ip string
source string
}
var pendingQueries []pendingQuery
// ips is a map, so iteration order is unspecified: which candidates end up
// filling the per-cycle query budget can differ from run to run.
for ip, source := range ips {
// Tier 1: Skip if already blocked
if alreadyBlocked[ip] {
continue
}
// Tier 2: Check local threat DB
if threatDB != nil {
if dbSource, found := threatDB.Lookup(ip); found {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "ip_reputation",
Message: fmt.Sprintf("Known malicious IP accessing server: %s (source: %s)", ip, dbSource),
Details: fmt.Sprintf("Detected via: %s\nMatched in local threat intelligence database", source),
Timestamp: time.Now(),
})
continue
}
}
// Tier 3: Check AbuseIPDB cache. Treat entries with CheckedAt in
// the future (legacy data written by a prior buggy error-caching
// formula) as expired so they get re-queried or aged out.
if entry, ok := cache.Entries[ip]; ok {
age := time.Since(entry.CheckedAt)
if age >= 0 && age < cacheExpiry {
// A fresh entry below the threshold (including the -1 written for
// cached errors) still gets a chance via the supplemental sources.
if entry.Score >= abuseConfidenceThreshold {
appendReputationFinding(&findings, ip, source, "AbuseIPDB", entry.Score, entry.Category)
} else if score, src, ok := supplementalThreatScore(ctx, supplementalAgg, ip); ok && score >= abuseConfidenceThreshold {
appendReputationFinding(&findings, ip, source, src, score, strings.ToLower(src)+" history")
}
continue
}
}
// Tier 4 candidate; defer the HTTP call to pass 2 unless the
// quota / config gates already preclude querying.
if cfg.Reputation.AbuseIPDBKey == "" || quotaExhausted || len(pendingQueries) >= maxQueriesPerCycle {
if score, src, ok := supplementalThreatScore(ctx, supplementalAgg, ip); ok && score >= abuseConfidenceThreshold {
appendReputationFinding(&findings, ip, source, src, score, strings.ToLower(src)+" history")
}
continue
}
pendingQueries = append(pendingQueries, pendingQuery{ip: ip, source: source})
}
// Pass 2: reserve daily quota slots up front and fan out the HTTP
// calls. The pre-reservation matches the prior "count the attempt
// before the call so a crash or network hang still consumes a slot"
// guarantee, while keeping near-cap cycles from spending more slots
// than the store can reserve.
type queryResult struct {
score int
category string
err error
}
if len(pendingQueries) > 0 {
if sdb != nil {
reserved := sdb.ReserveAbuseQuerySlots(utcDay, len(pendingQueries), maxDailyAbuseQueries)
// Queries that did not get a slot fall back to the supplemental
// sources and are dropped from the pending list.
if reserved < len(pendingQueries) {
for _, q := range pendingQueries[reserved:] {
if supplemental, src, ok := supplementalThreatScore(ctx, supplementalAgg, q.ip); ok && supplemental >= abuseConfidenceThreshold {
appendReputationFinding(&findings, q.ip, q.source, src, supplemental, strings.ToLower(src)+" history")
}
}
pendingQueries = pendingQueries[:reserved]
}
}
}
results := make(map[string]queryResult, len(pendingQueries))
if len(pendingQueries) > 0 {
var mu sync.Mutex
var wg sync.WaitGroup
workers := len(pendingQueries)
if workers > maxQueriesPerCycle {
workers = maxQueriesPerCycle
}
// The channel is buffered to the full job count and closed before the
// workers start, so filling it never blocks and workers exit when drained.
jobs := make(chan pendingQuery, len(pendingQueries))
for _, q := range pendingQueries {
jobs <- q
}
close(jobs)
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for q := range jobs {
score, category, err := queryAbuseIPDB(client, q.ip, cfg.Reputation.AbuseIPDBKey)
// results is shared between workers; guard every write.
mu.Lock()
results[q.ip] = queryResult{score: score, category: category, err: err}
mu.Unlock()
}
}()
}
wg.Wait()
}
// Pass 3: apply each tier-4 result back into cache + findings.
// Serial so cache writes and quota-exhaustion handling stay
// consistent regardless of which worker observed which HTTP error.
for _, q := range pendingQueries {
res, ok := results[q.ip]
if !ok {
continue
}
if res.err != nil {
// Quota errors are recognised by the "429" / "402" text that
// queryAbuseIPDB puts in its error messages.
if strings.Contains(res.err.Error(), "429") || strings.Contains(res.err.Error(), "402") {
resetAt := nextUTCMidnight(time.Now())
fmt.Fprintf(os.Stderr, "abuseipdb: quota exhausted (%v), pausing lookups until %s\n",
res.err, resetAt.Format(time.RFC3339))
// Persisted backoff is the load-bearing signal for the
// next cycle's classifier; the cycle-local quotaExhausted
// flag has no further reader past this loop.
if sdb != nil {
_ = sdb.SetAbuseQuotaExhaustedUntil(resetAt)
}
if supplemental, src, ok := supplementalThreatScore(ctx, supplementalAgg, q.ip); ok && supplemental >= abuseConfidenceThreshold {
appendReputationFinding(&findings, q.ip, q.source, src, supplemental, strings.ToLower(src)+" history")
}
continue
}
// Any other error is cached as a -1 score so the same IP is not
// retried on every cycle.
cache.Entries[q.ip] = &reputationEntry{
Score: -1,
Category: fmt.Sprintf("error: %v", res.err),
// CheckedAt is shifted into the past so time.Since returns
// ~(cacheExpiry-errorCacheExpiry) immediately; the Tier-3
// freshness check then flips false after a further
// errorCacheExpiry, giving a real ~1h TTL on error entries.
CheckedAt: time.Now().Add(-(cacheExpiry - errorCacheExpiry)),
}
if supplemental, src, ok := supplementalThreatScore(ctx, supplementalAgg, q.ip); ok && supplemental >= abuseConfidenceThreshold {
appendReputationFinding(&findings, q.ip, q.source, src, supplemental, strings.ToLower(src)+" history")
}
continue
}
cache.Entries[q.ip] = &reputationEntry{
Score: res.score,
Category: res.category,
CheckedAt: time.Now(),
}
// Report whichever of AbuseIPDB and the supplemental sources scored
// higher; only the AbuseIPDB result itself is cached above.
score := res.score
category := res.category
provider := "AbuseIPDB"
if supplemental, src, ok := supplementalThreatScore(ctx, supplementalAgg, q.ip); ok && supplemental > score {
score = supplemental
category = strings.ToLower(src) + " history"
provider = src
}
if score >= abuseConfidenceThreshold {
appendReputationFinding(&findings, q.ip, q.source, provider, score, category)
}
}
// Clean and cap cache
cleanCache(cache)
saveReputationCache(cfg.StatePath, cache)
return findings
}
// newSupplementalThreatAggregator builds the aggregator holding the optional
// reputation sources enabled in cfg (Rspamd and/or the upstream service).
// It returns nil when neither source is enabled. Whenever the upstream
// source is disabled its metrics source is cleared; when it is enabled the
// new source is registered for metrics before being added to the aggregator.
func newSupplementalThreatAggregator(cfg *config.Config) *threatintel.Aggregator {
	useRspamd := cfg.Reputation.Rspamd.Enabled
	useUpstream := cfg.Reputation.Upstream.Enabled

	if !useUpstream {
		threatintel.ClearUpstreamMetricsSource()
		if !useRspamd {
			return nil
		}
	}

	aggregator := threatintel.NewAggregator()
	if useRspamd {
		rspamd := threatintel.NewRspamdSource(
			cfg.Reputation.Rspamd.URL,
			cfg.Reputation.Rspamd.Token,
			cfg.Reputation.Rspamd.TokenEnv,
		)
		aggregator.Register(rspamd)
	}
	if !useUpstream {
		return aggregator
	}

	upstreamCfg := threatintel.UpstreamConfig{
		URL:      cfg.Reputation.Upstream.URL,
		Token:    cfg.Reputation.Upstream.Token,
		TokenEnv: cfg.Reputation.Upstream.TokenEnv,
		CacheTTL: time.Duration(cfg.Reputation.Upstream.CacheTTLMin) * time.Minute,
		Timeout:  time.Duration(cfg.Reputation.Upstream.TimeoutSec) * time.Second,
	}
	source := threatintel.NewUpstreamSource(upstreamCfg)
	threatintel.RegisterUpstreamMetrics(metrics.Default(), source)
	aggregator.Register(source)
	return aggregator
}
// supplementalThreatScore queries the aggregator for ip and returns the
// aggregated score, the name of the highest-scoring individual source
// (capitalised for operator-facing messages), and whether a usable score
// was found. Returns (0, "", false) when agg is nil, when the aggregator
// reports an error, or when the aggregated score is zero.
//
// When several sources share the highest individual score the
// lexicographically smallest name is used, so the label does not depend on
// map iteration order. If no source has a positive score the label falls
// back to "Supplemental".
func supplementalThreatScore(ctx context.Context, agg *threatintel.Aggregator, ip string) (int, string, bool) {
	if agg == nil {
		return 0, "", false
	}
	res, err := agg.Score(ctx, ip)
	if err != nil || res.AggregatedScore == 0 {
		return 0, "", false
	}

	// Identify the source with the highest individual score so callers can
	// label findings accurately (e.g. "Rspamd" vs "Upstream").
	dominant := "supplemental"
	best := 0
	for name, s := range res.Sources {
		if s > best || (s == best && best > 0 && name < dominant) {
			best = s
			dominant = name
		}
	}
	return res.AggregatedScore, capitalizeProvider(dominant), true
}
// capitalizeProvider upper-cases the first byte of a source name for
// operator-facing messages ("rspamd" -> "Rspamd", "upstream" -> "Upstream").
// The remainder of the name is left untouched; "" is returned as is.
func capitalizeProvider(name string) string {
	if name == "" {
		return name
	}
	head, tail := name[:1], name[1:]
	return strings.ToUpper(head) + tail
}
// appendReputationFinding appends a Critical "ip_reputation" finding for ip
// to *findings. provider and score appear in the message; detectedVia (the
// log source that surfaced the IP) and category go into the details.
func appendReputationFinding(findings *[]alert.Finding, ip, detectedVia, provider string, score int, category string) {
	headline := fmt.Sprintf("Known malicious IP accessing server: %s (%s score: %d/100)", ip, provider, score)
	body := fmt.Sprintf("Detected via: %s\nCategory: %s\nThis IP is reported in threat intelligence databases", detectedVia, category)

	entry := alert.Finding{
		Severity:  alert.Critical,
		Check:     "ip_reputation",
		Message:   headline,
		Details:   body,
		Timestamp: time.Now(),
	}
	*findings = append(*findings, entry)
}
// nextUTCMidnight returns 00:00 UTC on the day after now — the point at
// which AbuseIPDB's daily quota resets. now may be in any location; it is
// converted to UTC before the calendar day is taken.
func nextUTCMidnight(now time.Time) time.Time {
	year, month, day := now.UTC().Date()
	todayStart := time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
	return todayStart.AddDate(0, 0, 1)
}
// abuseQuotaReady reports whether AbuseIPDB may be called at time now.
// Two gates are consulted: the persisted backoff deadline (set when the API
// answers 429/402) and the per-UTC-day query counter, which must still be
// below maxDailyAbuseQueries. With no bbolt store (fallback mode) there is
// nothing to consult and the answer is always true.
func abuseQuotaReady(sdb *store.DB, now time.Time) bool {
	if sdb == nil {
		return true
	}

	backoffUntil := sdb.AbuseQuotaExhaustedUntil()
	if !backoffUntil.IsZero() && now.Before(backoffUntil) {
		return false
	}

	day := now.UTC().Format("2006-01-02")
	return sdb.AbuseQueryCount(day) < maxDailyAbuseQueries
}
// collectRecentIPs gathers non-infra IPs from multiple log sources.
// Returns map of IP → source description (e.g. "SSH login", "Dovecot
// IMAP/POP3 auth failure"). Sources are read in a fixed order and the first
// source to report an IP keeps the label, because addIfNotInfra never
// overwrites an existing entry.
func collectRecentIPs(cfg *config.Config) map[string]string {
	found := make(map[string]string)
	plat := platform.Detect()

	// record files ip under source unless it is blank or filtered out.
	record := func(ip, source string) {
		if ip == "" {
			return
		}
		addIfNotInfra(found, ip, source, cfg)
	}

	// SSH logins. Path differs by OS family (secure vs auth.log); the old
	// hardcoded /var/log/secure made this loop dead on Debian/Ubuntu.
	for _, line := range tailFile(plat.AuthLogPath(), 50) {
		if strings.Contains(line, "Accepted") {
			record(extractIPAfterKeyword(line, "from"), "SSH login")
		}
	}

	// Web server access logs, platform-detected (Apache/Nginx/LiteSpeed).
	// The WHM access log is skipped here; it has its own label below.
	for _, logPath := range plat.AccessLogPaths {
		if logPath == "" || isWHMAccessLog(logPath) {
			continue
		}
		for _, line := range tailFile(logPath, 100) {
			record(firstField(line), "HTTP request")
		}
	}

	// Dovecot - IMAP/POP3 auth failures.
	if mailLog := reputationMailLogPath(cfg, plat); mailLog != "" {
		for _, line := range tailFile(mailLog, 50) {
			if !strings.Contains(line, "auth failed") && !strings.Contains(line, "Aborted login") {
				continue
			}
			record(extractIPAfterKeyword(line, "rip="), "Dovecot IMAP/POP3 auth failure")
		}
	}

	// cPanel/WHM access log.
	if plat.IsCPanel() {
		for _, line := range tailFile(reputationWHMAccessLog, 100) {
			record(firstField(line), "cPanel/WHM access")
		}
	}

	// Exim main log - SMTP auth failures and rejected recipients.
	if shouldCollectEximMainlog(plat) {
		for _, line := range tailFile(reputationEximMainlog, 50) {
			if !strings.Contains(line, "authenticator failed") && !strings.Contains(line, "rejected RCPT") {
				continue
			}
			record(extractBracketedIP(line), "SMTP auth failure")
		}
	}

	return found
}
// reputationMailLogPath decides which mail log file to tail for Dovecot auth
// failures. An explicit cfg.MailLogs.File wins; a "journal" source yields ""
// (nothing to tail); otherwise, or when cfg is nil, the platform default is
// used.
func reputationMailLogPath(cfg *config.Config, info platform.Info) string {
	switch {
	case cfg == nil:
		return info.MailLogPath()
	case cfg.MailLogs.Source == "journal":
		return ""
	case cfg.MailLogs.File != "":
		return cfg.MailLogs.File
	default:
		return info.MailLogPath()
	}
}
// isWHMAccessLog reports whether path, once cleaned, is the cPanel/WHM
// access log (reputationWHMAccessLog).
func isWHMAccessLog(path string) bool {
	cleaned := filepath.Clean(path)
	return cleaned == reputationWHMAccessLog
}
// shouldCollectEximMainlog reports whether the Exim main log should be
// tailed. It is always collected on cPanel; elsewhere it is skipped only
// when stat shows the file does not exist (any other stat error still
// counts as "collect").
func shouldCollectEximMainlog(info platform.Info) bool {
	if info.IsCPanel() {
		return true
	}
	_, statErr := osFS.Stat(reputationEximMainlog)
	return statErr == nil || !os.IsNotExist(statErr)
}
// addIfNotInfra records ip with its source label in ips unless the address
// is blank, a loopback literal, matches cfg.InfraIPs, or is already present
// (the first label recorded for an IP is kept).
func addIfNotInfra(ips map[string]string, ip, source string, cfg *config.Config) {
	switch ip {
	case "", "127.0.0.1", "::1":
		return
	}
	if isInfraIP(ip, cfg.InfraIPs) {
		return
	}
	if _, seen := ips[ip]; seen {
		return
	}
	ips[ip] = source
}
// firstField returns the first whitespace-separated token of line when it
// looks like an IP address, and "" otherwise. The check is a shape test
// only: exactly three dots (v4-like) or at least one colon (v6-like).
func firstField(line string) string {
	tokens := strings.Fields(line)
	if len(tokens) < 1 {
		return ""
	}
	candidate := tokens[0]
	if strings.Count(candidate, ".") != 3 && !strings.Contains(candidate, ":") {
		return ""
	}
	return candidate
}
// extractIPAfterKeyword returns the IP-shaped token that follows the first
// occurrence of keyword in line, or "" if there is none. Spaces and '='
// directly after the keyword are skipped, and trailing punctuation
// (",:;)([]") is stripped from the token. "IP-shaped" means exactly three
// dots or at least one colon.
func extractIPAfterKeyword(line, keyword string) string {
	pos := strings.Index(line, keyword)
	if pos == -1 {
		return ""
	}
	tail := strings.TrimLeft(line[pos+len(keyword):], " =")
	tokens := strings.Fields(tail)
	if len(tokens) == 0 {
		return ""
	}
	candidate := strings.TrimRight(tokens[0], ",:;)([]")
	if strings.Count(candidate, ".") != 3 && !strings.Contains(candidate, ":") {
		return ""
	}
	return candidate
}
// extractBracketedIP returns the first bracketed token in line that looks
// like an IP address ("[1.2.3.4]" format common in exim logs), or "" when
// there is none. "Looks like" is a shape test: exactly three dots or at
// least one colon.
//
// Every bracket pair is considered, not just the first one, so a line that
// carries an earlier bracketed field such as a "[12345]" process ID still
// yields the address that follows it.
func extractBracketedIP(line string) string {
	rest := line
	for {
		open := strings.IndexByte(rest, '[')
		if open < 0 {
			return ""
		}
		rest = rest[open+1:]
		closeAt := strings.IndexByte(rest, ']')
		if closeAt < 0 {
			return ""
		}
		candidate := rest[:closeAt]
		// A nested '[' means this was an outer bracket; the next pass starts
		// at the inner one because rest already begins after the outer '['.
		if strings.IndexByte(candidate, '[') >= 0 {
			continue
		}
		if strings.Count(candidate, ".") == 3 || strings.Contains(candidate, ":") {
			return candidate
		}
	}
}
// loadAllBlockedIPs returns the set of IPs CSM currently has blocked.
// It reads firewall state through osFS so tests can inject a filesystem,
// then merges the legacy blocked_ips.json file. A file that cannot be read
// is skipped silently; a file that cannot be parsed is reported on stderr
// and contributes nothing.
func loadAllBlockedIPs(statePath string) map[string]bool {
	result := make(map[string]bool)

	// Read the authoritative firewall engine state. The engine persists
	// every block to firewall/state.json; the parallel bbolt fw:blocked
	// bucket is written only at migration, so reading it would return a
	// frozen snapshot that misses live blocks.
	enginePath := filepath.Join(statePath, "firewall", "state.json")
	if raw, readErr := osFS.ReadFile(enginePath); readErr == nil {
		var engine struct {
			Blocked []struct {
				IP        string    `json:"ip"`
				ExpiresAt time.Time `json:"expires_at"`
			} `json:"blocked"`
		}
		if decodeErr := json.Unmarshal(raw, &engine); decodeErr == nil {
			current := time.Now()
			for i := range engine.Blocked {
				block := &engine.Blocked[i]
				// A zero expiry counts as still blocked here.
				if !block.ExpiresAt.IsZero() && !current.Before(block.ExpiresAt) {
					continue
				}
				result[block.IP] = true
			}
		} else {
			fmt.Fprintf(os.Stderr, "reputation: %s is corrupt, alert suppression degraded: %v\n", enginePath, decodeErr)
		}
	}

	// Also read from blocked_ips.json (legacy CSM auto-block). Entries here
	// count only while their expiry lies in the future; pending entries
	// always count.
	legacyPath := filepath.Join(statePath, "blocked_ips.json")
	raw, readErr := osFS.ReadFile(legacyPath)
	if readErr != nil {
		return result
	}
	var legacy struct {
		IPs []struct {
			IP        string    `json:"ip"`
			ExpiresAt time.Time `json:"expires_at"`
		} `json:"ips"`
		Pending []struct {
			IP string `json:"ip"`
		} `json:"pending,omitempty"`
	}
	if decodeErr := json.Unmarshal(raw, &legacy); decodeErr != nil {
		fmt.Fprintf(os.Stderr, "reputation: %s is corrupt, alert suppression degraded: %v\n", legacyPath, decodeErr)
		return result
	}
	current := time.Now()
	for i := range legacy.IPs {
		if current.Before(legacy.IPs[i].ExpiresAt) {
			result[legacy.IPs[i].IP] = true
		}
	}
	for i := range legacy.Pending {
		result[legacy.Pending[i].IP] = true
	}
	return result
}
// abuseIPDBResponse is the subset of the AbuseIPDB check response body that
// queryAbuseIPDB decodes.
type abuseIPDBResponse struct {
// Data carries the lookup result: the confidence score is returned as the
// score, and usage type, ISP and report count are combined into the
// category text.
Data struct {
AbuseConfidenceScore int `json:"abuseConfidenceScore"`
UsageType string `json:"usageType"`
ISP string `json:"isp"`
TotalReports int `json:"totalReports"`
} `json:"data"`
// Errors is non-empty when the API reports a failure in the body; the
// first entry's Detail becomes the returned error text.
Errors []struct {
Detail string `json:"detail"`
Status int `json:"status"`
} `json:"errors"`
}
// queryAbuseIPDB looks ip up at abuseIPDBEndpoint and returns
// (score, category, error).
//
// The error text starts with "429" for rate limiting and "402" for quota
// exhaustion; CheckIPReputation matches on those strings, so they must not
// change. Any other non-200 status is reported as "HTTP <code>", and an
// error list in the response body is reported as "API error: <detail>".
// category is the usage type, followed by " (<ISP>)" and ", <n> reports"
// when those values are present.
//
// ip originates from log lines, so the query string is built with
// url.Values encoding rather than string concatenation. That keeps a
// malformed token from adding or overriding query parameters.
func queryAbuseIPDB(client *http.Client, ip, apiKey string) (int, string, error) {
	req, err := http.NewRequest(http.MethodGet, abuseIPDBEndpoint, nil)
	if err != nil {
		return 0, "", err
	}
	query := req.URL.Query()
	query.Set("ipAddress", ip)
	query.Set("maxAgeInDays", "90")
	req.URL.RawQuery = query.Encode()
	req.Header.Set("Key", apiKey)
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return 0, "", err
	}
	defer func() { _ = resp.Body.Close() }()

	switch resp.StatusCode {
	case http.StatusOK:
		// fall through to decoding below
	case http.StatusTooManyRequests:
		return 0, "", fmt.Errorf("429 rate limited")
	case http.StatusPaymentRequired:
		return 0, "", fmt.Errorf("402 quota exceeded")
	default:
		return 0, "", fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	var result abuseIPDBResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, "", err
	}
	if len(result.Errors) > 0 {
		return 0, "", fmt.Errorf("API error: %s", result.Errors[0].Detail)
	}

	category := result.Data.UsageType
	if result.Data.ISP != "" {
		category += " (" + result.Data.ISP + ")"
	}
	if result.Data.TotalReports > 0 {
		category += fmt.Sprintf(", %d reports", result.Data.TotalReports)
	}
	return result.Data.AbuseConfidenceScore, category, nil
}
// cleanCache removes expired entries and caps the cache at maxCacheEntries.
//
// The in-memory map is always pruned, including when a bbolt store is in
// use. CheckIPReputation calls saveReputationCache right after this function
// and that writes every in-memory entry to the store; leaving the map
// unpruned would put expired entries straight back after the store cleanup
// removed them. When a store is present its own expiry and cap routines are
// run as well.
//
// Entries older than cacheExpiry are dropped, as are nil entries. If the map
// is still over the cap, the oldest entries are removed until it fits.
func cleanCache(cache *reputationCache) {
	now := time.Now()

	// Remove expired entries - same expiry as the cache freshness check.
	for ip, entry := range cache.Entries {
		if entry == nil || now.Sub(entry.CheckedAt) > cacheExpiry {
			delete(cache.Entries, ip)
		}
	}

	// Cap at max entries - remove oldest if over limit.
	if excess := len(cache.Entries) - maxCacheEntries; excess > 0 {
		type aged struct {
			ip  string
			age time.Duration
		}
		entries := make([]aged, 0, len(cache.Entries))
		for ip, entry := range cache.Entries {
			entries = append(entries, aged{ip, now.Sub(entry.CheckedAt)})
		}
		// Sort by age descending (oldest first).
		sort.Slice(entries, func(i, j int) bool {
			return entries[i].age > entries[j].age
		})
		for i := 0; i < excess; i++ {
			delete(cache.Entries, entries[i].ip)
		}
	}

	// When using the bbolt store, also clean what is already persisted.
	if sdb := store.Global(); sdb != nil {
		sdb.CleanExpiredReputation(cacheExpiry)
		sdb.EnforceReputationCap(maxCacheEntries)
	}
}
// loadReputationCache returns the cached reputation entries. The result and
// its Entries map are never nil, and the map never holds nil entries.
//
// With a bbolt store every stored entry is copied into the map. Without one
// the flat JSON file under statePath is read; a missing file yields an empty
// cache. A file that does not decode cleanly is reported on stderr and
// whatever did decode is kept. JSON null entries are dropped, because the
// callers dereference every entry they find and a nil one would panic.
func loadReputationCache(statePath string) *reputationCache {
	cache := &reputationCache{Entries: make(map[string]*reputationEntry)}

	// Try bbolt store first - after migration the flat file is renamed to .bak.
	if sdb := store.Global(); sdb != nil {
		for ip, entry := range sdb.AllReputation() {
			cache.Entries[ip] = &reputationEntry{
				Score:     entry.Score,
				Category:  entry.Category,
				CheckedAt: entry.CheckedAt,
			}
		}
		return cache
	}

	// Fallback: flat-file JSON (pre-migration).
	cachePath := filepath.Join(statePath, reputationCacheFile)
	data, err := osFS.ReadFile(cachePath)
	if err != nil {
		return cache
	}
	if uerr := json.Unmarshal(data, cache); uerr != nil {
		fmt.Fprintf(os.Stderr, "reputation: cannot fully decode %s: %v\n", cachePath, uerr)
	}
	if cache.Entries == nil {
		cache.Entries = make(map[string]*reputationEntry)
	}
	for ip, entry := range cache.Entries {
		if entry == nil {
			delete(cache.Entries, ip)
		}
	}
	return cache
}
// saveReputationCache persists the reputation cache.
//
// With the global bbolt store available each entry is written to the store.
// Otherwise the cache is serialised to JSON and written to a temporary file
// under statePath, which is then renamed over the real cache file so readers
// never observe a half-written file.
//
// Persistence is best-effort and reports no error. On the flat-file path a
// failed marshal, write, or rename leaves the previous cache file in place
// and removes the temporary file, instead of renaming an empty or truncated
// file over good data.
func saveReputationCache(statePath string, cache *reputationCache) {
	if sdb := store.Global(); sdb != nil {
		for ip, entry := range cache.Entries {
			if entry == nil {
				// Nothing to persist; also avoids a nil dereference.
				continue
			}
			// Best-effort: one failed write must not stop the remaining
			// entries from being stored.
			_ = sdb.SetReputation(ip, store.ReputationEntry{
				Score:     entry.Score,
				Category:  entry.Category,
				CheckedAt: entry.CheckedAt,
			})
		}
		return
	}

	// Fallback: flat-file JSON.
	data, err := json.MarshalIndent(cache, "", " ")
	if err != nil {
		return
	}
	tmpPath := filepath.Join(statePath, reputationCacheFile+".tmp")
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		// Possibly partial temp file: discard it and keep the old cache.
		_ = os.Remove(tmpPath)
		return
	}
	if err := os.Rename(tmpPath, filepath.Join(statePath, reputationCacheFile)); err != nil {
		_ = os.Remove(tmpPath)
	}
}
package checks
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/state"
)
// autoResponseActions counts every auto-response action fired, by
// action class. Registered lazily on first observation.
var (
	autoResponseActions     *metrics.CounterVec
	autoResponseActionsOnce sync.Once
)

// observeAutoResponse adds n to the auto-response counter for the given
// action label. Non-positive n is ignored, so an idle cycle neither creates a
// label series nor registers the metric. The counter vector is created and
// registered on the first call that has something to record.
//
// The help text enumerates every action label the runner emits; keep it in
// step with the observeAutoResponse call sites in runParallelWithContext.
func observeAutoResponse(action string, n int) {
	if n <= 0 {
		return
	}
	autoResponseActionsOnce.Do(func() {
		autoResponseActions = metrics.NewCounterVec(
			"csm_auto_response_actions_total",
			"Auto-response actions fired, by action. Labels: action (kill|quarantine|htaccess_clean|block). Incremented once per finding the auto-response subsystem produced in each tier run; a batch of four IPs blocked in one cycle contributes 4 to action=block.",
			[]string{"action"},
		)
		metrics.MustRegister("csm_auto_response_actions_total", autoResponseActions)
	})
	autoResponseActions.With(action).Add(float64(n))
}
// Per-check latency histogram exposed on /metrics.
//
// Observations carry the check name and the tier, so a scraper can single out
// one regressing check without reading logs. The bucket boundaries run from
// roughly millisecond-scale process-list passes up to the heavy-check timeout
// ceiling.
var (
	checkDuration        *metrics.HistogramVec
	checkDurationOnce    sync.Once
	checkDurationBuckets = []float64{0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 120, 180, 300, 600, 900}
)

// observeCheckDuration records how long one check took, in seconds, under the
// given name and tier labels. The histogram is created and registered the
// first time this is called.
func observeCheckDuration(name, tier string, d time.Duration) {
	const metricName = "csm_check_duration_seconds"
	checkDurationOnce.Do(func() {
		checkDuration = metrics.NewHistogramVec(
			metricName,
			"Wall-clock time for each security check to complete. Label `name` is a check runner name; label `tier` is critical|deep|all. Use p95 across name to spot a single check regressing, and sum across name to track per-cycle pressure.",
			[]string{"name", "tier"},
			checkDurationBuckets,
		)
		metrics.MustRegister(metricName, checkDuration)
	})
	elapsed := d.Seconds()
	checkDuration.With(name, tier).Observe(elapsed)
}
// CheckFunc is the signature for all check functions. A check receives the
// configuration and the state store and returns the findings it produced;
// the runner ignores an empty result.
// The context is cancelled when the check times out so goroutines can exit.
// The runner derives it from the scan context, so cancelling the parent scan
// cancels it as well.
type CheckFunc func(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding
// splitDisabledChecks divides checks into those that should run and those
// switched off through cfg.DisabledChecks.
//
// Each configured entry is trimmed and blank entries are skipped. An entry
// that names a runner in checks disables that runner directly (kept for
// existing operator configs). Any other entry is treated as a public finding
// name, the vocabulary used by the settings UI and docs, and disables every
// runner in checks that is mapped to that finding.
//
// When cfg is nil, nothing is configured, or no entry matches a runner in
// checks, the input slice is returned unchanged with a nil disabled list.
func splitDisabledChecks(cfg *config.Config, checks []namedCheck) (enabled, disabled []namedCheck) {
	if cfg == nil || len(cfg.DisabledChecks) == 0 {
		return checks, nil
	}

	// Runner names present in this particular check set.
	registered := make(map[string]struct{}, len(checks))
	for i := range checks {
		registered[checks[i].name] = struct{}{}
	}

	skip := make(map[string]struct{}, len(cfg.DisabledChecks))
	for _, raw := range cfg.DisabledChecks {
		entry := strings.TrimSpace(raw)
		if entry == "" {
			continue
		}
		// A direct runner match wins over the finding-name lookup.
		if _, isRunner := registered[entry]; isRunner {
			skip[entry] = struct{}{}
			continue
		}
		for _, runner := range runnerNamesForFinding(entry) {
			if _, isRunner := registered[runner]; isRunner {
				skip[runner] = struct{}{}
			}
		}
	}
	if len(skip) == 0 {
		return checks, nil
	}

	enabled = make([]namedCheck, 0, len(checks))
	disabled = make([]namedCheck, 0, len(skip))
	for _, nc := range checks {
		if _, off := skip[nc.name]; off {
			disabled = append(disabled, nc)
		} else {
			enabled = append(enabled, nc)
		}
	}
	return enabled, disabled
}
// runnerNamesForFinding looks up the runner names mapped to the given finding
// name. An unknown finding name yields a nil slice.
func runnerNamesForFinding(finding string) []string {
	if runners, known := findingNameToRunnerNames[finding]; known {
		return runners
	}
	return nil
}
// DisabledCheckNames lists, in sorted order, the public finding names that
// top-level disabled_checks accepts for scheduled check execution. A finding
// is included only when LookupCheck knows it and does not mark it Internal.
// Runner IDs are still honoured by splitDisabledChecks for existing configs,
// but they are not part of this list and so are not exposed in the UI.
func DisabledCheckNames() []string {
	names := make([]string, 0, len(findingNameToRunnerNames))
	for name := range findingNameToRunnerNames {
		if info, known := LookupCheck(name); known && !info.Internal {
			names = append(names, name)
		}
	}
	sort.Strings(names)
	return names
}
// findingNameToRunnerNames is the inverse of runnerFindingNames: it maps a
// finding name to the runners that list it.
var findingNameToRunnerNames = buildFindingNameToRunnerNames()

// buildFindingNameToRunnerNames inverts runnerFindingNames. A finding listed
// under several runners maps to all of them. Each runner list is sorted so
// the result does not depend on Go's randomised map iteration order.
func buildFindingNameToRunnerNames() map[string][]string {
	out := make(map[string][]string, len(runnerFindingNames))
	for runner, findings := range runnerFindingNames {
		for _, finding := range findings {
			out[finding] = append(out[finding], runner)
		}
	}
	for _, runners := range out {
		sort.Strings(runners)
	}
	return out
}
// runnerFindingNames maps a check runner name (the name field of a
// namedCheck) to the finding names listed for that runner. It is inverted by
// buildFindingNameToRunnerNames to resolve disabled_checks entries, and read
// by latestPurgeCheckNamesForChecks to expand a set of runners into the
// finding names to purge. A finding name may appear under more than one
// runner (for example "obfuscated_php" is listed for both "file_index" and
// "php_content").
var runnerFindingNames = map[string][]string{
"admin_overlap": {"admin_cross_account_overlap"},
"credential_reuse": {"credential_reuse"},
"supply_chain": {"supply_chain_vuln"},
"af_alg_enforcement": {"af_alg_enforcement_corrected"},
"af_alg_socket_use": {"af_alg_socket_use"},
"api_auth_failures": {"api_auth_failure"},
"api_tokens": {"api_tokens"},
"cpanel_filemanager": {"cpanel_file_upload"},
"cpanel_logins": {"cpanel_login", "cpanel_multi_ip_login", "cpanel_password_purge"},
"crontabs": {"crond_change", "crontab_change", "suspicious_crontab"},
"database_dumps": {"database_dump"},
"db_content": {"db_options_injection", "db_post_injection", "db_rogue_admin", "db_siteurl_hijack", "db_spam_cleaned", "db_spam_found", "db_spam_injection", "db_suspicious_admin_email"},
"db_content_drupal": {"drupal_admin_injection", "drupal_content_injection", "drupal_settings_injection"},
"db_content_joomla": {"joomla_admin_injection", "joomla_content_injection", "joomla_extensions_injection"},
"db_content_magento": {"magento_admin_injection", "magento_content_injection", "magento_settings_injection"},
"db_content_opencart": {"opencart_admin_injection", "opencart_content_injection", "opencart_settings_injection"},
"db_objects": {"db_magic_token_user", "db_malicious_event", "db_malicious_function", "db_malicious_procedure", "db_malicious_trigger", "db_unexpected_event", "db_unexpected_function", "db_unexpected_procedure", "db_unexpected_trigger"},
"dns_connections": {"dns_connection"},
"dns_zones": {"dns_zone_change"},
"email_content": {"email_phishing_content"},
"email_forwarder_audit": {"email_pipe_forwarder", "email_suspicious_forwarder"},
"email_mail_filters": {"email_filter_blackhole", "email_filter_exfil", "email_filter_forwarder", "email_filter_pipe"},
"email_weak_password": {"email_weak_password"},
"exfiltration_paste": {"exfiltration_paste_site"},
"fake_kernel_threads": {"fake_kernel_thread"},
"file_index": {"new_executable_in_config", "new_php_in_sensitive_dir", "new_php_in_sensitive_dir_clean", "new_php_in_uploads", "new_php_in_uploads_clean", "new_suspicious_php", "new_webshell_file", "obfuscated_php", "suspicious_php_content"},
"filesystem": {"backdoor_binary", "suid_binary", "suspicious_file"},
"firewall": {"firewall", "firewall_ports"},
"ftp_logins": {"ftp_bruteforce", "ftp_login"},
"group_writable_php": {"group_writable_php"},
"health": {"csm_health"},
"htaccess": {"htaccess_handler_abuse", "htaccess_injection"},
"ip_reputation": {"ip_reputation"},
"kernel_modules": {"kernel_module"},
"local_threat_score": {"local_threat_score"},
"mail_per_account": {"mail_per_account"},
"mail_queue": {"mail_queue"},
"modsec_audit": {"waf_attack_blocked"},
"mysql_users": {"mysql_superuser"},
"nulled_plugins": {"nulled_plugin"},
"open_basedir": {"open_basedir"},
"outbound_connections": {"backdoor_port", "backdoor_port_outbound", "c2_connection"},
"outdated_plugins": {"outdated_plugins"},
"perf_error_logs": {"perf_error_logs"},
"perf_load": {"perf_load"},
"perf_memory": {"perf_memory"},
"perf_mysql_config": {"perf_mysql_config"},
"perf_php_handler": {"perf_php_handler"},
"perf_php_processes": {"perf_php_processes"},
"perf_redis_config": {"perf_redis_config"},
"perf_wp_config": {"perf_wp_config"},
"perf_wp_cron": {"perf_wp_cron"},
"perf_wp_transients": {"perf_wp_transients"},
"phishing": {"phishing_credential_log", "phishing_directory", "phishing_iframe", "phishing_kit_archive", "phishing_page", "phishing_php", "phishing_redirector"},
"php_config_changes": {"php_config_change"},
"php_content": {"obfuscated_php", "suspicious_php_content"},
"php_processes": {"php_suspicious_execution"},
"rpm_integrity": {"dpkg_integrity", "rpm_integrity"},
"shadow_changes": {"bulk_password_change", "root_password_change", "shadow_change"},
"ssh_keys": {"ssh_keys"},
"ssh_logins": {"ssh_login_unknown_ip"},
"sshd_config": {"sshd_config_change"},
"ssl_certs": {"ssl_cert_issued"},
"suspicious_processes": {"suspicious_process"},
"symlink_attacks": {"symlink_attack"},
"uid0_accounts": {"uid0_account"},
"user_outbound": {"user_outbound_connection"},
"waf_status": {"waf_bypass", "waf_detection_only", "waf_rules", "waf_rules_stale", "waf_status"},
"webmail_logins": {"webmail_bruteforce"},
"webshells": {"webshell", "world_writable_php"},
"whm_access": {"whm_account_action", "whm_password_change"},
"wp_bruteforce": {"wp_login_bruteforce", "wp_user_enumeration", "xmlrpc_abuse", "http_request_flood", "http_ua_spoof", "http_distributed_flood"},
"wp_core": {"wp_core_integrity"},
}
// namedCheck pairs a check function with its name for timeout reporting.
// The name also keys the disabled-check, throttle, heavy-timeout and
// purge-name lookups in this file.
type namedCheck struct {
name string // runner name, e.g. "webshells"
fn CheckFunc // the check implementation invoked by the runner
}
// ForceAll forces all checks to run regardless of throttle (used by baseline).
// In this file it is read by runAll, where it bypasses the "deep_scan"
// throttle so the deep checks are always appended.
var ForceAll bool
// Tier identifies which set of checks to run. The string value is also
// passed to the runner as the `tier` metric label.
type Tier string
const (
TierCritical Tier = "critical" // Fast checks - processes, auth, network (~5 seconds)
TierDeep Tier = "deep" // Filesystem scans - webshells, htaccess, WP core (~90 seconds)
TierAll Tier = "all" // Both tiers
)
// checkTimeout is the default per-check execution budget, applied to every
// check that is not listed in heavyChecks.
const checkTimeout = 5 * time.Minute
// heavyCheckTimeout applies to filesystem walks that traverse every WP
// install on the host. On busy shared servers (300+ WP installs, tens of
// thousands of plugin/theme PHP files) these legitimately run longer than
// the default 5-minute budget, so they get a wider window to avoid
// noisy check_timeout warnings while leaving fast checks aggressive.
const heavyCheckTimeout = 15 * time.Minute
// heavyChecks names the deep-tier checks that walk every account's
// document roots. Keep this list short and explicit; only checks that
// observably blow past 5 minutes on production hosts belong here.
// Keys are runner names; timeoutForFunc gives these heavyCheckTimeout.
var heavyChecks = map[string]bool{
"filesystem": true,
"webshells": true,
"htaccess": true,
"php_content": true,
"file_index": true,
"phishing": true,
}
// timeoutFor reports the execution budget for the named check. It delegates
// to timeoutForFunc, a package variable, so tests can substitute smaller
// budgets without touching the timeout constants.
func timeoutFor(name string) time.Duration {
	budget := timeoutForFunc(name)
	return budget
}

// timeoutForFunc is the replaceable implementation behind timeoutFor: checks
// listed in heavyChecks get heavyCheckTimeout, all others checkTimeout.
var timeoutForFunc = func(name string) time.Duration {
	if isHeavy := heavyChecks[name]; !isHeavy {
		return checkTimeout
	}
	return heavyCheckTimeout
}
// criticalChecks returns the checks that make up TierCritical. A fresh slice
// is built on every call, so callers may append to the result. Each name is
// the runner name used for throttling, disabling, metrics and purge lookups.
func criticalChecks() []namedCheck {
return []namedCheck{
{"fake_kernel_threads", CheckFakeKernelThreads},
{"suspicious_processes", CheckSuspiciousProcesses},
{"php_processes", CheckPHPProcesses},
{"shadow_changes", CheckShadowChanges},
{"uid0_accounts", CheckUID0Accounts},
{"ssh_keys", CheckSSHKeys},
{"sshd_config", CheckSSHDConfig},
{"ssh_logins", CheckSSHLogins},
{"api_tokens", CheckAPITokens},
{"crontabs", CheckCrontabs},
{"outbound_connections", CheckOutboundConnections},
{"user_outbound", CheckOutboundUserConnections},
{"dns_connections", CheckDNSConnections},
{"whm_access", CheckWHMAccess},
{"cpanel_logins", CheckCpanelLogins},
{"cpanel_filemanager", CheckCpanelFileManager},
{"firewall", CheckFirewall},
{"mail_queue", CheckMailQueue},
{"mail_per_account", CheckMailPerAccount},
{"kernel_modules", CheckKernelModules},
{"af_alg_socket_use", CheckAFAlgSocketUsage},
{"af_alg_enforcement", CheckAFAlgEnforcement},
{"mysql_users", CheckMySQLUsers},
{"database_dumps", CheckDatabaseDumps},
{"exfiltration_paste", CheckOutboundPasteSites},
{"wp_bruteforce", CheckWPBruteForce},
{"ftp_logins", CheckFTPLogins},
{"webmail_logins", CheckWebmailLogins},
{"api_auth_failures", CheckAPIAuthFailures},
{"ip_reputation", CheckIPReputation},
{"local_threat_score", CheckLocalThreatScore},
{"modsec_audit", CheckModSecAuditLog},
{"health", CheckHealth},
{"perf_load", CheckLoadAverage},
{"perf_php_processes", CheckPHPProcessLoad},
{"perf_memory", CheckSwapAndOOM},
}
}
// deepChecks returns the checks that make up TierDeep. A fresh slice is built
// on every call. reducedDeepChecks lists the same checks minus the ones
// skipped while fanotify is active; keep the two in step when adding or
// removing a deep check.
func deepChecks() []namedCheck {
return []namedCheck{
{"filesystem", CheckFilesystem},
{"webshells", CheckWebshells},
{"htaccess", CheckHtaccess},
{"wp_core", CheckWPCore},
{"file_index", CheckFileIndex},
{"php_content", CheckPHPContent},
{"phishing", CheckPhishing},
{"nulled_plugins", CheckNulledPlugins},
{"rpm_integrity", CheckRPMIntegrity},
{"group_writable_php", CheckGroupWritablePHP},
{"open_basedir", CheckOpenBasedir},
{"symlink_attacks", CheckSymlinkAttacks},
{"php_config_changes", CheckPHPConfigChanges},
{"dns_zones", CheckDNSZoneChanges},
{"ssl_certs", CheckSSLCertIssuance},
{"waf_status", CheckWAFStatus},
{"db_content", CheckDatabaseContent},
{"db_content_drupal", CheckDrupalContent},
{"db_content_joomla", CheckJoomlaContent},
{"db_content_magento", CheckMagentoContent},
{"db_content_opencart", CheckOpenCartContent},
{"db_objects", CheckDatabaseObjects},
{"admin_overlap", CheckAdminEmailOverlap},
{"credential_reuse", CheckCredentialReuse},
{"email_content", CheckOutboundEmailContent},
{"outdated_plugins", CheckOutdatedPlugins},
{"supply_chain", CheckSupplyChain},
{"email_weak_password", CheckEmailPasswords},
{"email_forwarder_audit", CheckForwarders},
{"email_mail_filters", CheckMailFilters},
{"perf_php_handler", CheckPHPHandler},
{"perf_mysql_config", CheckMySQLConfig},
{"perf_redis_config", CheckRedisConfig},
{"perf_error_logs", CheckErrorLogBloat},
{"perf_wp_config", CheckWPConfig},
{"perf_wp_transients", CheckWPTransientBloat},
{"perf_wp_cron", CheckWPCron},
}
}
// reducedDeepChecks returns the deep-tier checks that still need a periodic
// pass while fanotify is active. It is the deepChecks list, in the same
// order, minus the checks fanotify handles in real time. Deriving it from
// deepChecks means a newly added deep check is included here unless it is
// also added to the skip set below.
func reducedDeepChecks() []namedCheck {
	// Deep checks covered in real time by fanotify.
	liveCovered := map[string]struct{}{
		"filesystem":         {},
		"webshells":          {},
		"htaccess":           {},
		"file_index":         {},
		"php_content":        {},
		"phishing":           {},
		"php_config_changes": {},
	}

	full := deepChecks()
	reduced := make([]namedCheck, 0, len(full))
	for _, nc := range full {
		if _, skip := liveCovered[nc.name]; skip {
			continue
		}
		reduced = append(reduced, nc)
	}
	return reduced
}
// PerfCheckNamesForTier lists the runner names starting with "perf_" that are
// registered in the given tier, in registration order. The daemon uses the
// list for an atomic purge-and-merge when storing findings. The result is nil
// when the tier has no perf checks or is unknown.
func PerfCheckNamesForTier(tier Tier) []string {
	var perf []string
	tierChecks := checksForTier(tier)
	for i := range tierChecks {
		if name := tierChecks[i].name; strings.HasPrefix(name, "perf_") {
			perf = append(perf, name)
		}
	}
	return perf
}
// checkThrottleMin maps a check name to its minimum interval in minutes
// between executions. The runner consults this BEFORE invoking the check
// function. Throttled checks that get skipped in a given cycle are NOT
// added to the per-scan purge list, so their previously-emitted findings
// stay in the latest set instead of being wiped every cycle. Without this
// gating in the runner, a deep scan that ran while the throttle window was
// still open would purge stale findings and merge nothing, hiding real
// issues until the next non-throttled cycle (or daemon restart).
//
// The runner passes each value to store.ShouldRunThrottled together with the
// check name; checks absent from this map are never throttled by the runner.
var checkThrottleMin = map[string]int{
"perf_php_handler": 60,
"perf_mysql_config": 60,
"perf_redis_config": 60,
"perf_error_logs": 60,
"perf_wp_config": 60,
"perf_wp_transients": 60,
"perf_wp_cron": 60,
}
// LatestPurgeCheckNamesForTier gives the sorted set of names (runner names
// plus the finding names listed for them) owned by every check in the tier.
// The daemon uses it to replace a tier's scan output wholesale so findings
// from earlier runs do not linger.
func LatestPurgeCheckNamesForTier(tier Tier) []string {
	tierChecks := checksForTier(tier)
	return latestPurgeCheckNamesForChecks(tierChecks)
}
// LatestPurgeCheckNamesForReducedDeep returns the emitted finding names owned
// by the reduced deep set used while fanotify covers filesystem events.
func LatestPurgeCheckNamesForReducedDeep() []string {
return latestPurgeCheckNamesForChecks(reducedDeepChecks())
}
// latestVolatileCheckNames are finding names that never stay in the latest
// findings set: latestPurgeWithVolatile appends them to every purge list and
// latestPersistentFindings filters them out of what gets merged.
var latestVolatileCheckNames = []string{
"account_scan_truncated",
"auto_block",
"auto_response",
}
// latestDerivedCheckNames are correlation finding names. They are filtered
// out of incoming scan findings by latestPersistentFindings and instead
// purged and rebuilt from the merged set by StoreLatestScanFindings.
var latestDerivedCheckNames = []string{
"coordinated_attack",
"cross_account_malware",
}
// StoreLatestScanFindings swaps in the findings owned by one scan and then
// regenerates the derived correlation findings from the merged result.
//
// Step one purges purgeChecks plus the volatile names and merges the scan's
// findings, minus volatile and derived ones, so one-shot auto-response
// actions live in history and alerts rather than in the active findings view.
// Step two correlates the store's latest findings, stamps any result lacking
// a timestamp, and replaces the previously stored derived findings with them.
//
// The call is a no-op when st is nil or when both purgeChecks and findings
// are empty.
func StoreLatestScanFindings(st *state.Store, purgeChecks []string, findings []alert.Finding) {
	nothingToDo := len(purgeChecks) == 0 && len(findings) == 0
	if st == nil || nothingToDo {
		return
	}

	purge := latestPurgeWithVolatile(purgeChecks)
	keep := latestPersistentFindings(findings)
	st.PurgeAndMergeFindings(purge, keep)

	stamp := time.Now()
	correlated := CorrelateFindings(st.LatestFindings())
	for i := range correlated {
		f := &correlated[i]
		if f.Timestamp.IsZero() {
			f.Timestamp = stamp
		}
	}
	st.PurgeAndMergeFindings(latestDerivedCheckNames, correlated)
}
// latestPurgeWithVolatile returns a new slice holding purgeChecks followed by
// every volatile finding name. The input slice is not modified.
func latestPurgeWithVolatile(purgeChecks []string) []string {
	merged := make([]string, len(purgeChecks), len(purgeChecks)+len(latestVolatileCheckNames))
	copy(merged, purgeChecks)
	return append(merged, latestVolatileCheckNames...)
}
// latestPersistentFindings copies findings into a new slice, leaving out any
// whose Check is a volatile or derived finding name. Order is preserved.
func latestPersistentFindings(findings []alert.Finding) []alert.Finding {
	kept := make([]alert.Finding, 0, len(findings))
	for i := range findings {
		name := findings[i].Check
		if !isLatestVolatileFinding(name) && !isLatestDerivedFinding(name) {
			kept = append(kept, findings[i])
		}
	}
	return kept
}
// isLatestVolatileFinding reports whether check is one of the names in
// latestVolatileCheckNames.
func isLatestVolatileFinding(check string) bool {
	for i := range latestVolatileCheckNames {
		if latestVolatileCheckNames[i] == check {
			return true
		}
	}
	return false
}
// isLatestDerivedFinding reports whether check is one of the names in
// latestDerivedCheckNames.
func isLatestDerivedFinding(check string) bool {
	for i := range latestDerivedCheckNames {
		if latestDerivedCheckNames[i] == check {
			return true
		}
	}
	return false
}
// checksForTier resolves a tier to its check list: the critical set, the deep
// set, or critical followed by deep for TierAll. An unrecognised tier yields
// nil.
func checksForTier(tier Tier) []namedCheck {
	if tier == TierCritical {
		return criticalChecks()
	}
	if tier == TierDeep {
		return deepChecks()
	}
	if tier == TierAll {
		combined := criticalChecks()
		combined = append(combined, deepChecks()...)
		return combined
	}
	return nil
}
// latestPurgeCheckNamesForChecks expands a list of checks into the sorted,
// de-duplicated set of names they own: each runner name plus every finding
// name listed for it in runnerFindingNames. An empty input yields an empty,
// non-nil slice.
func latestPurgeCheckNamesForChecks(toScan []namedCheck) []string {
	unique := make(map[string]struct{})
	for i := range toScan {
		runner := toScan[i].name
		unique[runner] = struct{}{}
		for _, alias := range runnerFindingNames[runner] {
			unique[alias] = struct{}{}
		}
	}

	sorted := make([]string, 0, len(unique))
	for n := range unique {
		sorted = append(sorted, n)
	}
	sort.Strings(sorted)
	return sorted
}
// RunTier executes the checks of a single tier with auto-response enabled,
// using a background context. It returns the findings and the per-scan purge
// name list (the finding aliases owned by the checks that actually ran this
// cycle). Hand that list to StoreLatestScanFindings so checks that were
// throttled out keep their earlier findings.
//
// The dry-run state is passed down to the runner as a parameter instead of
// the former package-level toggle, so concurrent periodic scanners cannot
// race with a manual `csm check` invocation.
func RunTier(cfg *config.Config, store *state.Store, tier Tier) ([]alert.Finding, []string) {
	findings, purge := RunTierWithContext(context.Background(), cfg, store, tier)
	return findings, purge
}
// RunTierWithContext behaves like RunTier but runs under a context owned by
// the caller. The daemon's periodic scans pass their shutdown context so that
// an interrupted scan cannot hold up process exit.
func RunTierWithContext(ctx context.Context, cfg *config.Config, store *state.Store, tier Tier) ([]alert.Finding, []string) {
	const dryRun = false
	tierChecks := checksForTier(tier)
	return runParallelWithContext(ctx, cfg, store, tierChecks, string(tier), dryRun)
}
// RunTierDryRun runs a tier exactly like RunTier except that auto-response
// actions are skipped. It serves the `csm check*` socket commands and the
// legacy CLI.
func RunTierDryRun(cfg *config.Config, store *state.Store, tier Tier) ([]alert.Finding, []string) {
	const dryRun = true
	tierChecks := checksForTier(tier)
	return runParallelWithContext(context.Background(), cfg, store, tierChecks, string(tier), dryRun)
}
// RunReducedDeep runs the subset of deep checks that fanotify cannot stand in
// for, under a background context. The daemon uses it while fanotify is
// active.
//
// Not run here, because fanotify handles them in real time:
//
//	filesystem, webshells, htaccess, file_index, php_content,
//	phishing, php_config_changes
//
// The second return value is the per-scan purge name list, limited to the
// checks that actually ran this cycle.
func RunReducedDeep(cfg *config.Config, store *state.Store) ([]alert.Finding, []string) {
	findings, purge := RunReducedDeepWithContext(context.Background(), cfg, store)
	return findings, purge
}
// RunReducedDeepWithContext is the context-aware form of RunReducedDeep; the
// caller-owned parent context lets a daemon shutdown cancel the scan. The run
// is labelled with the deep tier and auto-response stays enabled.
func RunReducedDeepWithContext(ctx context.Context, cfg *config.Config, store *state.Store) ([]alert.Finding, []string) {
	const dryRun = false
	reduced := reducedDeepChecks()
	return runParallelWithContext(ctx, cfg, store, reduced, string(TierDeep), dryRun)
}
// RunAll always runs the critical checks and adds the deep checks when the
// deep-scan throttle allows it or ForceAll is set. Auto-response is enabled.
// The second return value is the per-scan purge name list, limited to the
// checks that actually ran this cycle.
func RunAll(cfg *config.Config, store *state.Store) ([]alert.Finding, []string) {
	const dryRun = false
	return runAll(cfg, store, dryRun)
}
// RunAllDryRun is RunAll with auto-response actions skipped; it backs
// `csm baseline`.
func RunAllDryRun(cfg *config.Config, store *state.Store) ([]alert.Finding, []string) {
	const dryRun = true
	return runAll(cfg, store, dryRun)
}
// runAll is the shared body of RunAll and RunAllDryRun. It starts from the
// critical checks and appends the deep checks when ForceAll is set or the
// "deep_scan" throttle (cfg.Thresholds.DeepScanIntervalMin) permits a run,
// then executes the selection under a background context with the "all" tier
// label. ForceAll is tested first, so the throttle is not consulted when it
// is set.
func runAll(cfg *config.Config, store *state.Store, dryRun bool) ([]alert.Finding, []string) {
	selected := criticalChecks()
	includeDeep := ForceAll || store.ShouldRunThrottled("deep_scan", cfg.Thresholds.DeepScanIntervalMin)
	if includeDeep {
		selected = append(selected, deepChecks()...)
	}
	return runParallelWithContext(context.Background(), cfg, store, selected, string(TierAll), dryRun)
}
// runParallel runs the given checks concurrently under a background context
// and returns their findings together with the per-scan purge name list.
// A throttled check whose window is still open is left out of the purge list,
// so its findings from the previous cycle survive. A disabled check is not
// run, yet its names remain in the purge list so that disabling a check
// clears whatever findings it owned before.
func runParallel(cfg *config.Config, store *state.Store, checks []namedCheck, tier string, dryRun bool) ([]alert.Finding, []string) {
	findings, purge := runParallelWithContext(context.Background(), cfg, store, checks, tier, dryRun)
	return findings, purge
}
// runParallelWithContext is the check runner behind every Run* entry point.
//
// It splits checks into enabled and disabled, drops enabled checks whose
// throttle window is still open, and runs the rest concurrently (at most five
// at a time), each under its own timeout derived from the scan context. A
// check that exceeds its budget contributes a check_timeout warning instead
// of results. After all checks finish, account-scan truncation findings and
// correlation findings are appended, and, unless dryRun is set, the
// auto-response stages (kill, quarantine, htaccess clean, IP block) run in
// that order, each seeing the findings produced so far.
//
// The second return value is the purge name list built from the checks that
// ran plus the disabled checks; throttled-out checks are excluded so their
// earlier findings persist. If the scan context is cancelled before the scan
// completes, both return values are nil so the caller does not overwrite the
// last completed scan with partial data. A nil parent is treated as
// context.Background().
func runParallelWithContext(parent context.Context, cfg *config.Config, store *state.Store, checks []namedCheck, tier string, dryRun bool) ([]alert.Finding, []string) {
	if parent == nil {
		parent = context.Background()
	}
	active, switchedOff := splitDisabledChecks(cfg, checks)
	scanCtx, truncations := withAccountScanTruncationCollector(parent)

	// Disabled checks are purged but never run; throttled-out checks are
	// neither run nor purged.
	toRun := make([]namedCheck, 0, len(active))
	toPurge := make([]namedCheck, 0, len(active)+len(switchedOff))
	toPurge = append(toPurge, switchedOff...)
	for _, nc := range active {
		if interval, throttled := checkThrottleMin[nc.name]; throttled && store != nil && !store.ShouldRunThrottled(nc.name, interval) {
			continue
		}
		toRun = append(toRun, nc)
		toPurge = append(toPurge, nc)
	}

	var (
		resultsMu sync.Mutex
		collected []alert.Finding
		pending   sync.WaitGroup
	)
	// record appends to the shared result slice under the mutex.
	record := func(fs ...alert.Finding) {
		resultsMu.Lock()
		collected = append(collected, fs...)
		resultsMu.Unlock()
	}

	// Cap concurrency so the scan cannot saturate the CPU (keeps the WebUI
	// responsive).
	slots := make(chan struct{}, 5)
	for _, nc := range toRun {
		pending.Add(1)
		check := nc
		// Checks parse user-controlled filesystem content (unparsed PHP,
		// crafted archives, foreign encodings), so a panic is plausible.
		// SafeGo contains it; the outer select then reports the check as a
		// check_timeout, and the scan and the daemon stay alive.
		obs.SafeGo("check-runner", func() {
			defer pending.Done()

			// Wait for a slot, or give up if the scan is cancelled first.
			select {
			case slots <- struct{}{}:
			case <-scanCtx.Done():
				return
			}
			defer func() { <-slots }()
			if scanCtx.Err() != nil {
				return
			}

			// The per-check context is cancellable so a timed-out check
			// can stop its own work.
			budget := timeoutFor(check.name)
			checkCtx, stop := context.WithTimeout(scanCtx, budget)
			// Buffered so a late-finishing check never blocks on send.
			out := make(chan []alert.Finding, 1)
			began := time.Now()
			obs.SafeGo("check-exec", func() {
				out <- check.fn(checkCtx, cfg, store)
			})

			select {
			case got := <-out:
				stop()
				observeCheckDuration(check.name, tier, time.Since(began))
				if len(got) > 0 {
					record(got...)
				}
			case <-checkCtx.Done():
				stop()
				observeCheckDuration(check.name, tier, time.Since(began))
				// A cancelled scan context means shutdown, not a genuine
				// per-check timeout: exit quietly instead of emitting a
				// bogus check_timeout for deliberately interrupted work.
				if scanCtx.Err() != nil {
					return
				}
				record(alert.Finding{
					Severity:  alert.Warning,
					Check:     "check_timeout",
					Message:   fmt.Sprintf("Check '%s' timed out after %s", check.name, budget),
					Timestamp: time.Now(),
				})
			}
		})
	}
	pending.Wait()

	if scanCtx.Err() != nil {
		// Incomplete scan: returning nothing stops the caller from replacing
		// the last completed scan state with partial findings or an empty
		// purge.
		return nil, nil
	}

	now := time.Now()
	// stampMissing gives every finding without a timestamp the scan time.
	stampMissing := func(fs []alert.Finding) {
		for i := range fs {
			if fs[i].Timestamp.IsZero() {
				fs[i].Timestamp = now
			}
		}
	}

	collected = append(collected, truncations.findings(now)...)
	stampMissing(collected)

	// Cross-account correlation.
	correlated := CorrelateFindings(collected)
	stampMissing(correlated)
	collected = append(collected, correlated...)

	// Auto-response is skipped for dry runs (check/baseline commands). Each
	// stage sees the findings appended by the stages before it.
	if !dryRun {
		kills := AutoKillProcesses(cfg, collected)
		stampMissing(kills)
		observeAutoResponse("kill", len(kills))
		collected = append(collected, kills...)

		quarantined := AutoQuarantineFiles(cfg, collected)
		stampMissing(quarantined)
		observeAutoResponse("quarantine", len(quarantined))
		collected = append(collected, quarantined...)

		cleaned := AutoCleanHtaccess(cfg, collected)
		stampMissing(cleaned)
		observeAutoResponse("htaccess_clean", len(cleaned))
		collected = append(collected, cleaned...)

		blocked := AutoBlockIPs(cfg, collected)
		stampMissing(blocked)
		observeAutoResponse("block", len(blocked))
		collected = append(collected, blocked...)
	}

	return collected, latestPurgeCheckNamesForChecks(toPurge)
}
package checks
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
// selfWriteTTL bounds how long a CSM-performed write to a sensitive path
// suppresses the sensitive-file detectors. Short enough that an independent
// tamper layered on later is still caught by the next scan.
const selfWriteTTL = 15 * time.Minute
// Registry of recent CSM self-writes, keyed by file path. selfWriteMu guards
// selfWrites; every function in this file that touches the map takes it.
var (
selfWriteMu sync.Mutex
selfWrites = map[string]selfWriteRecord{}
selfWriteNow = time.Now // overridable in tests
)
// selfWriteRecord is one registry entry: the content CSM wrote and the moment
// the entry stops suppressing findings.
type selfWriteRecord struct {
hash string // hex-encoded SHA-256 of the written content
expires time.Time // record time plus selfWriteTTL
}
// RecordSelfWrite notes that CSM remediation has just written content to a
// sensitive watched file at path. The sensitive-file detectors skip a finding
// only while the file still holds exactly these bytes and the record is
// within selfWriteTTL; a malicious change layered on top hashes differently
// and is still reported. This is deliberately not a path allowlist.
//
// Expired records are pruned on every call, and a later call for the same
// path replaces the earlier record.
func RecordSelfWrite(path string, content []byte) {
	digest := sha256.Sum256(content)
	encoded := hex.EncodeToString(digest[:])

	selfWriteMu.Lock()
	defer selfWriteMu.Unlock()

	at := selfWriteNow()
	pruneExpiredSelfWritesLocked(at)
	selfWrites[path] = selfWriteRecord{
		hash:    encoded,
		expires: at.Add(selfWriteTTL),
	}
}
// forgetSelfWrites removes the self-write records for the given paths, if
// any. Paths without a record are ignored.
func forgetSelfWrites(paths ...string) {
	selfWriteMu.Lock()
	defer selfWriteMu.Unlock()

	for i := range paths {
		delete(selfWrites, paths[i])
	}
}
// isExpectedSelfWrite reports whether content is byte-identical (by SHA-256)
// to a CSM self-write recorded for path that is still within its TTL.
// Expired records are pruned first, so they count as not expected.
func isExpectedSelfWrite(path string, content []byte) bool {
	selfWriteMu.Lock()
	defer selfWriteMu.Unlock()

	pruneExpiredSelfWritesLocked(selfWriteNow())

	if rec, found := selfWrites[path]; found {
		digest := sha256.Sum256(content)
		return rec.hash == hex.EncodeToString(digest[:])
	}
	return false
}
// pruneExpiredSelfWritesLocked deletes every record whose expiry lies
// strictly before now. The caller must hold selfWriteMu.
func pruneExpiredSelfWritesLocked(now time.Time) {
	for p, r := range selfWrites {
		if r.expires.Before(now) {
			delete(selfWrites, p)
		}
	}
}
package checks
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// sensitiveWatchset is the static set of system-configuration paths CSM
// raises a finding on when any of them is opened for write. The set is
// intentionally narrow and not operator-configurable: an attacker who
// learns that a path is excluded gets a free landing pad.
//
// Glob entries expand at runtime; non-glob entries appear once.
// Entries are written relative to "/" and are joined onto the root passed to
// ExpandWatchset.
var sensitiveWatchset = []string{
"/etc/shadow",
"/etc/gshadow",
"/etc/passwd",
"/etc/group",
"/etc/sudoers",
"/etc/sudoers.d/*",
"/etc/ssh/sshd_config",
"/etc/ssh/sshd_config.d/*",
"/etc/cron.d/*",
"/etc/cron.hourly/*",
"/etc/cron.daily/*",
"/etc/cron.weekly/*",
"/etc/cron.monthly/*",
"/var/spool/cron/*",
}
// sensitiveFileBaselineKey is a state key; judging by its name it marks that
// the initial sensitive-file hash baseline has been recorded.
// NOTE(review): its use is outside this section -- confirm against the caller.
const sensitiveFileBaselineKey = "_sensitive_file_hash:__baseline_complete"
// ExpandWatchset resolves sensitiveWatchset against the filesystem root and
// returns the resulting paths in watchset order. root is "/" in production
// and a t.TempDir in tests.
//
// Glob entries contribute only the paths that currently match, so files that
// do not exist yet are simply absent until a later refresh finds them.
// Literal entries are always included once, whether or not the file exists.
func ExpandWatchset(root string) []string {
	var paths []string
	for _, entry := range sensitiveWatchset {
		rooted := filepath.Join(root, entry)
		if !strings.ContainsAny(entry, "*?[") {
			paths = append(paths, rooted)
			continue
		}
		// Glob fails only on a malformed pattern; that case is ignored and
		// contributes no paths.
		expanded, _ := filepath.Glob(rooted)
		paths = append(paths, expanded...)
	}
	return paths
}
// classifySensitive returns a stable kind label for a watchset path so
// findings can vary their severity and message. The labels are "auth",
// "sshd", "sudo" and "cron"; an empty string means the path is not a
// recognised watchset member.
//
// The containing directory is checked before the file name. Files inside the
// cron, sudoers.d and sshd_config.d directories are named by whoever creates
// them, so a drop-in such as /etc/cron.d/passwd or /etc/sudoers.d/shadow must
// be labelled by where it lives, not mistaken for the account database of the
// same name. Only paths outside those directories are classified by name.
func classifySensitive(path string) string {
	dir := filepath.Dir(path)
	switch {
	case strings.Contains(dir, "/cron"):
		// Covers /etc/cron.*/ as well as /var/spool/cron.
		return "cron"
	case strings.Contains(dir, "/sudoers.d"):
		return "sudo"
	case strings.Contains(dir, "/sshd_config.d"):
		return "sshd"
	}

	switch filepath.Base(path) {
	case "shadow", "gshadow", "passwd", "group":
		return "auth"
	case "sshd_config":
		return "sshd"
	case "sudoers":
		return "sudo"
	}
	return ""
}
// EvaluateSensitiveFileWrite turns a write to a watchset path, as observed by
// the BPF live backend, into a finding. The boolean is false when no finding
// should be raised.
//
// No finding is produced when classifySensitive does not recognise the path.
// The BPF program already filters through its dev+inode map, but this guards
// against stale map entries that point at unrelated files. No finding is
// produced either when the file can be read and its current content matches a
// CSM self-write recorded for the path (e.g. installing a per-user wp-cron).
// That suppression is bound to the content, so a tamper layered on top
// changes the hash and is still reported.
//
// Otherwise the finding is High for uid 0 and Critical for any other uid, and
// it is passed through rescoreSensitive before being returned.
func EvaluateSensitiveFileWrite(path string, uid, pid uint32, comm string) (alert.Finding, bool) {
	class := classifySensitive(path)
	if class == "" {
		return alert.Finding{}, false
	}

	data, readErr := osFS.ReadFile(path)
	if readErr == nil && isExpectedSelfWrite(path, data) {
		return alert.Finding{}, false
	}

	severity := alert.Critical
	if uid == 0 {
		severity = alert.High
	}
	finding := alert.Finding{
		Severity: severity,
		Check:    "sensitive_file_modified",
		Message:  fmt.Sprintf("Write to sensitive system file: %s (uid=%d)", path, uid),
		Details:  fmt.Sprintf("Class: %s, PID: %d, Comm: %s, User: %s", class, pid, comm, LookupUser(uid)),
		FilePath: path,
	}
	return rescoreSensitive(finding, class, nil, pid, time.Now()), true
}
// EvaluateSensitiveFileAppearance builds the finding for a file that newly
// appeared inside a sensitive glob directory. The boolean is false for
// unrecognised paths and for files whose content matches a CSM self-write.
//
// The file is read once up front (best effort). Those bytes serve the
// self-write suppression and, for the "cron" class only, are passed on to
// rescoreSensitive so dangerous cron content can veto a demotion.
func EvaluateSensitiveFileAppearance(path string) (alert.Finding, bool) {
	class := classifySensitive(path)
	if class == "" {
		return alert.Finding{}, false
	}

	body, err := osFS.ReadFile(path)
	if err != nil {
		// Unreadable file: still report it, just without content.
		body = nil
	} else if isExpectedSelfWrite(path, body) {
		return alert.Finding{}, false
	}

	seen := time.Now()
	finding := alert.Finding{
		Severity:  alert.High,
		Check:     "sensitive_file_modified",
		Message:   fmt.Sprintf("New sensitive system file appeared: %s", path),
		Details:   fmt.Sprintf("Class: %s", class),
		FilePath:  path,
		Timestamp: seen,
	}

	// Only cron files are content-scored; other classes pass nil.
	if class != "cron" {
		body = nil
	}
	return rescoreSensitive(finding, class, body, 0, seen), true
}
// CheckSensitiveFiles is the periodic safety net for when the BPF live
// monitor is unavailable or disabled. It SHA-256 hashes every readable path
// from ExpandWatchset("/") and compares the digest with the one stored on
// the previous run.
//
//   - Unknown path: the digest is stored. Once the baseline marker exists,
//     the path is also reported through EvaluateSensitiveFileAppearance; on
//     the first run nothing is emitted.
//   - Same digest: nothing happens.
//   - Different digest: the stored digest is updated first, then a finding
//     is emitted unless the new content is an expected CSM self-write.
//
// A nil store yields no findings. The existing file note says
// CheckShadowChanges in auth.go does a richer per-user diff for /etc/shadow
// and that audit-log dedup handles the overlap between the two checks.
func CheckSensitiveFiles(_ context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	if store == nil {
		return nil
	}

	const keyPrefix = "_sensitive_file_hash:"
	_, seeded := store.GetRaw(sensitiveFileBaselineKey)

	var out []alert.Finding
	for _, target := range ExpandWatchset("/") {
		body, err := osFS.ReadFile(target)
		if err != nil {
			continue
		}
		digest := sha256.Sum256(body)
		current := hex.EncodeToString(digest[:])
		storeKey := keyPrefix + target

		previous, known := store.GetRaw(storeKey)
		switch {
		case !known:
			store.SetRaw(storeKey, current)
			if !seeded {
				continue
			}
			if f, emit := EvaluateSensitiveFileAppearance(target); emit {
				out = append(out, f)
			}
			continue
		case previous == current:
			continue
		}

		// Record the new digest before deciding whether to alert, so a
		// suppressed self-write still moves the baseline forward.
		store.SetRaw(storeKey, current)
		if isExpectedSelfWrite(target, body) {
			continue
		}

		class := classifySensitive(target)
		var scored []byte
		if class == "cron" {
			scored = body
		}
		changed := alert.Finding{
			Severity:  alert.High,
			Check:     "sensitive_file_modified",
			Message:   fmt.Sprintf("Periodic check: content hash changed for %s", target),
			Details:   fmt.Sprintf("Previous: %s, Current: %s", previous, current),
			FilePath:  target,
			Timestamp: time.Now(),
		}
		out = append(out, rescoreSensitive(changed, class, scored, 0, time.Now()))
	}

	if !seeded {
		store.SetRaw(sensitiveFileBaselineKey, "1")
	}
	return out
}
package checks
import (
"bytes"
"fmt"
"os"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// pkgManagerLogs is the ordered set of package-manager log files whose
// recent mtime is taken as evidence of a legitimate root-driven file system
// change. RPM-family hosts touch dnf.rpm.log / dnf.log / yum.log;
// Debian-family hosts touch dpkg.log and apt/history.log (the latter is
// also written by unattended-upgrades).
//
// The variable is package-private with no operator override, so an attacker
// who learns CSM is present cannot point the daemon at an empty path. The
// accepted trade-off is that spoofing an mtime on these files requires
// root, which is already game over for this detector class.
//
// pkgManagerWindow stats the entries in this order and stops at the first
// one that is recent enough.
var pkgManagerLogs = []string{
"/var/log/dnf.rpm.log",
"/var/log/dnf.log",
"/var/log/yum.log",
"/var/log/dpkg.log",
"/var/log/apt/history.log",
}
// AncestryProbe reports whether the process tree rooted at pid contains a
// package-manager process. It is nil on hosts without BPF process context,
// in which case ancestry checks no-op and the pkg-window + cron-content
// layers still apply. The BPF daemon wires this from processctx at startup.
//
// rescoreSensitive consults it only when pid is non-zero and the
// package-manager log window did not already justify a demotion.
var AncestryProbe func(pid uint32) bool
// pkgManagerWindow reports whether at least one file in pkgManagerLogs has
// a modification time strictly after now-window. Only the mtime is used;
// log contents are never parsed. Files that cannot be stat'ed are ignored.
func pkgManagerWindow(now time.Time, window time.Duration) bool {
	oldest := now.Add(-window)
	for _, logPath := range pkgManagerLogs {
		info, err := os.Stat(logPath)
		if err == nil && info.ModTime().After(oldest) {
			return true
		}
	}
	return false
}
// cronDangerTokens are byte fragments whose presence in a cron drop-in is
// inconsistent with vendor-shipped automation and consistent with
// post-exploitation persistence. The list is intentionally narrow: a false
// positive here only cancels the demote in rescoreSensitive, which is the
// safe failure mode. Matching is exact and case-sensitive (see
// cronHasDangerTokens), so the spacing variants are listed explicitly.
var cronDangerTokens = [][]byte{
// Output piped or chained into a shell.
[]byte("| sh"),
[]byte("|sh "),
[]byte("| bash"),
[]byte("|bash "),
[]byte("; sh "),
[]byte(";sh "),
[]byte("; bash"),
[]byte(";bash "),
// Decoding or evaluating an embedded payload.
[]byte("base64 -d"),
[]byte("base64 --decode"),
[]byte("base64_decode"),
[]byte("eval("),
[]byte("eval $("),
[]byte("eval \""),
// References to world-writable scratch locations.
[]byte("/tmp/"),
[]byte("/var/tmp/"),
[]byte("/dev/shm/"),
// Interpreter one-liners and interactive/reverse shells.
[]byte("python -c"),
[]byte("python3 -c"),
[]byte("perl -e"),
[]byte("ruby -e"),
[]byte("nc -e"),
[]byte("ncat -e"),
[]byte("bash -i"),
// A literal backslash followed by 'x' (two bytes), i.e. a hex escape.
[]byte("\\x"),
// Network fetchers; deliberately broad.
[]byte("curl "),
[]byte("wget "),
}
// cronHasDangerTokens reports whether content contains at least one entry
// of cronDangerTokens as a byte substring. The comparison is case-sensitive
// because cron content is shell, where case matters (PATH lookups,
// builtins). The curl/wget tokens are broad on purpose: a vendor cron that
// needs to fetch is rare enough that flagging is the right default.
func cronHasDangerTokens(content []byte) bool {
	for i := range cronDangerTokens {
		if bytes.Index(content, cronDangerTokens[i]) >= 0 {
			return true
		}
	}
	return false
}
// rescoreSensitive returns f with its severity adjusted for provenance.
// Only High findings are eligible; anything else is returned untouched.
//
// Decision order:
//  1. class "cron" with non-empty content containing a danger token vetoes
//     any demotion.
//  2. A package-manager log touched within pkgWindowDefault of now demotes
//     to Warning.
//  3. Otherwise, a non-zero pid for which AncestryProbe (when set) returns
//     true demotes to Warning.
//
// A demotion is recorded in Details, either on its own or appended to the
// existing text. content and pid are optional (nil / 0); class is "" for
// non-classified findings; now is injected for deterministic testing.
func rescoreSensitive(f alert.Finding, class string, content []byte, pid uint32, now time.Time) alert.Finding {
	if f.Severity != alert.High {
		return f
	}
	if class == "cron" && len(content) > 0 && cronHasDangerTokens(content) {
		return f
	}

	var why string
	switch {
	case pkgManagerWindow(now, pkgWindowDefault):
		why = "package manager active within window"
	case pid != 0 && AncestryProbe != nil && AncestryProbe(pid):
		why = "ancestor is package manager"
	default:
		return f
	}

	f.Severity = alert.Warning
	if f.Details != "" {
		f.Details = fmt.Sprintf("%s [demoted: %s]", f.Details, why)
		return f
	}
	f.Details = fmt.Sprintf("Demoted: %s", why)
	return f
}
// pkgWindowDefault is the slack allowed for a legitimate root-driven file
// system change after a package transaction. dnf scriptlets have been
// observed to leave up to a few seconds between the rpm log entry and the
// post-install file drops; 2 minutes covers slower scriptlets without
// opening a multi-minute window for an attacker who happens to time their
// change against a transaction.
const pkgWindowDefault = 2 * time.Minute
package checks
import (
"os"
"path/filepath"
"strings"
)
// Recognisers used by the realtime fanotify restore/probe dedup logic.
// The scheduled file index deliberately does not call these as scan skips:
// upload PHP is indexed first and then classified by content.
// LooksLikeCpanelRestoreStaging recognises files inside cPanel's
// pkgacct/restorepkg staging tree. Per the file's existing notes, cPanel
// extracts the user backup as root into
// /home/cpanelpkgrestore.TMP.work.<id>/ and then re-extracts it under the
// user identity into /home/<account>/. Both extractions raise events; the
// user-context one carries the real signal, so the staging-side alert is a
// duplicate.
//
// The path must begin with "/home/cpanelpkgrestore.TMP.work." and the id
// that follows (up to the next '/' or the end of the path) must be at least
// two ASCII alphanumeric characters. Requiring the marker directly under
// /home means an account confined to /home/<user>/ cannot fabricate a
// matching path.
func LooksLikeCpanelRestoreStaging(path string) bool {
	const stagingPrefix = "/home/cpanelpkgrestore.TMP.work."
	if !strings.HasPrefix(path, stagingPrefix) {
		return false
	}

	id := path[len(stagingPrefix):]
	if slash := strings.IndexByte(id, '/'); slash >= 0 {
		id = id[:slash]
	}
	if len(id) < 2 {
		return false
	}

	for _, b := range []byte(id) {
		isDigit := '0' <= b && b <= '9'
		isLower := 'a' <= b && b <= 'z'
		isUpper := 'A' <= b && b <= 'Z'
		if !isDigit && !isLower && !isUpper {
			return false
		}
	}
	return true
}
// LooksLikeWPOptimizeProbeByPath recognises WP-Optimize's per-server probe
// files from path structure alone; the file content is never read here.
// Per the file's existing notes, WP-Optimize writes tiny PHP files named
// test.php below /wp-content/uploads/wpo/ to test whether the host honours
// certain Apache/Nginx directives.
//
// All three gates must pass:
//
//  1. The path contains "/wp-content/uploads/wpo/".
//  2. The basename is exactly "test.php", so something like
//     /uploads/wpo/webshell.php falls through to the standard alert.
//  3. <site>/wp-content/plugins/wp-optimize exists and is a directory,
//     where <site> is everything before the first "/wp-content/uploads/"
//     in the path (checked with os.Stat).
//
// The realtime caller applies an additional content-shape gate before
// suppressing the duplicate alert, which is why this predicate stays narrow.
func LooksLikeWPOptimizeProbeByPath(path string) bool {
	const (
		uploadsDir = "/wp-content/uploads/"
		probeDir   = uploadsDir + "wpo/"
		pluginRel  = "/wp-content/plugins/wp-optimize"
	)

	if !strings.Contains(path, probeDir) || filepath.Base(path) != "test.php" {
		return false
	}
	cut := strings.Index(path, uploadsDir)
	if cut < 0 {
		return false
	}

	info, err := os.Stat(path[:cut] + pluginRel)
	return err == nil && info.IsDir()
}
package checks
import (
"regexp"
"strings"
)
// SEO-spam context analysis for WordPress post content.
//
// Word-boundary keyword matching (see countSpamMatches) eliminated the
// substring false positives where "specialist" triggered "cialis" and
// "pharmaceutical" triggered "pharma". It did not eliminate a second
// class of false positive: legitimate prose mentions of an industry or
// product category ("our advisor covers consumer goods, energy,
// pharma" or "Industria alimentara si Pharma"). Word-boundary matching
// cannot distinguish the prose mention from the cloaked black-hat SEO
// link the lalimanro attack injected on the same site.
//
// This file classifies a keyword HIT as SPAM only when the surrounding
// HTML shows an attacker signal: CSS cloaking (off-screen absolute
// positioning, display:none, visibility:hidden, text-indent, micro
// height, font-size:0), an injection fingerprint (short hex HTML
// comment bracketing content), or an external anchor whose URL path
// itself contains the keyword ("/buy/pharma/" style commercial paths).
//
// Bare keyword mentions with none of those signals do not fire, so
// legitimate industry-vertical prose is silent.
//
// The proximity window is bounded to ±spamContextWindow bytes around
// the keyword match. Cloaking at the top of a long post unrelated to
// a keyword mention at the bottom does not spuriously associate.
// spamContextWindow bounds how far (in bytes) on each side of a keyword
// match the analyzer looks for cloaking signals. 400 covers a typical
// cloaked-div attack (small <div> with style + <a> + keyword) without
// reaching into unrelated content. hitHasSpamContext clamps the resulting
// range to the bounds of the content.
const spamContextWindow = 400
// cssCloakPatterns are regexes that each, when matched, identify a CSS
// property value indicative of content cloaking. The list is conservative:
// each entry corresponds to a technique widely used in real SEO spam
// campaigns. `(?i)` makes matching case-insensitive; `\s*` around colons
// and values tolerates the whitespace variants attackers use to evade naive
// string scanners ("display : none" vs "display:none").
//
// Numeric zero/one values must not be followed by a '.' or another word
// character. A plain `\b` after the digit treats "0.9em" or "1.5" as a
// match, because '.' is a non-word byte, which would flag ordinary
// declarations such as "font-size:0.9em" or "line-height:1.5" as cloaking.
var cssCloakPatterns = []*regexp.Regexp{
	// display:none and visibility:hidden: classic hide.
	regexp.MustCompile(`(?i)\bdisplay\s*:\s*none\b`),
	regexp.MustCompile(`(?i)\bvisibility\s*:\s*hidden\b`),
	// text-indent with a negative value pushes text off-screen.
	regexp.MustCompile(`(?i)\btext-indent\s*:\s*-\s*\d+`),
	// Micro height (exactly 0 or 1, optionally "px") paired with
	// overflow:hidden to suppress contents. The pair is required because
	// height:1 alone is occasionally legitimate. The [^"'}]* class cannot
	// cross a quote or closing brace, which keeps both halves inside one
	// style attribute or rule.
	regexp.MustCompile(`(?i)\bheight\s*:\s*[01](?:\s*px)?[^\w."'}][^"'}]*\boverflow\s*:\s*hidden\b`),
	regexp.MustCompile(`(?i)\boverflow\s*:\s*hidden\b[^"'}]*\bheight\s*:\s*[01](?:\s*px)?(?:[^\w.]|$)`),
	// font-size:0 is the classic invisible-text technique. Only a bare
	// zero counts; "0.9em" and similar decimals do not.
	regexp.MustCompile(`(?i)\bfont-size\s*:\s*0(?:[^\w.]|$)`),
	// position:absolute paired with a negative coordinate in the same
	// style attribute or rule (the [^"'}]* class stops at quotes and at a
	// closing brace). Legitimate uses of position:absolute (menus,
	// tooltips) have non-negative coordinates; the pairing is the signal.
	regexp.MustCompile(`(?i)\bposition\s*:\s*absolute\b[^"'}]*\b(left|top|right|bottom)\s*:\s*-\s*\d+`),
	regexp.MustCompile(`(?i)\b(left|top|right|bottom)\s*:\s*-\s*\d+[^"'}]*\bposition\s*:\s*absolute\b`),
	// Off-screen via a large (four or more digits) negative margin.
	regexp.MustCompile(`(?i)\bmargin(-left|-top)?\s*:\s*-\s*\d{4,}`),
}
// injectionFingerprintRe matches the short hex HTML comment (4-8 hex
// chars) attackers use to tag their own injections across a campaign.
// The length bracket is deliberate: shorter matches are too ambiguous,
// longer matches overlap with legitimate WordPress UUIDs.
//
// The pattern has no (?i) flag, so only lowercase hex digits match.
// Whitespace between the comment delimiters and the token is tolerated.
var injectionFingerprintRe = regexp.MustCompile(`<!--\s*[a-f0-9]{4,8}\s*-->`)
// anchorHrefRe extracts the href attribute value from each anchor tag
// in a fragment. Double- and single-quoted values are both accepted;
// unquoted href values are not matched. Capture group 1 is the raw
// attribute value, which is what windowHasExternalSpamAnchor inspects.
var anchorHrefRe = regexp.MustCompile(`(?i)<a\b[^>]+href\s*=\s*["']([^"']+)["']`)
// externalSchemeRe recognises absolute or protocol-relative URLs. A
// relative URL ("/services/pharma/") is same-origin navigation and not
// an external spam link.
//
// Only http:// and https:// (any case) and a bare leading "//" qualify;
// other schemes such as mailto: or ftp: do not match.
var externalSchemeRe = regexp.MustCompile(`^(?i)(https?:)?//`)
// contentHasSpamContext reports whether at least one match of
// pattern.regex in content is accompanied by an SEO-spam signal within
// spamContextWindow bytes. A false result means every keyword hit (if any)
// is a bare mention without cloaking context, and the caller should
// suppress the finding.
func contentHasSpamContext(content string, pattern dbSpamPattern) bool {
	for _, span := range pattern.regex.FindAllStringIndex(content, -1) {
		from, to := span[0], span[1]
		if hitHasSpamContext(content, from, to, pattern.keyword) {
			return true
		}
	}
	return false
}
// countCloakedSpamMatches counts the rows in contents that contain the
// spam keyword together with an SEO-spam context signal. Per the file's
// existing notes, checkWPPosts uses it to decide whether to emit a
// db_spam_injection finding, so bare prose mentions of a keyword do not
// count.
//
// A qualifying row contributes exactly one to the total no matter how many
// keyword hits it holds: the finding is per post, not per hit.
func countCloakedSpamMatches(pattern dbSpamPattern, contents []string) int {
	total := 0
	for i := range contents {
		if !contentHasSpamContext(contents[i], pattern) {
			continue
		}
		total++
	}
	return total
}
// hitHasSpamContext inspects the region spanning spamContextWindow bytes
// before start through spamContextWindow bytes after end (clamped to the
// content) for any of three signals, checked in this order: CSS cloaking,
// an injection-fingerprint comment, or an external anchor whose URL path
// carries the keyword. It is a separate helper so tests can target a single
// hit.
func hitHasSpamContext(content string, start, end int, keyword string) bool {
	lo := start - spamContextWindow
	if lo < 0 {
		lo = 0
	}
	hi := end + spamContextWindow
	if hi > len(content) {
		hi = len(content)
	}
	region := content[lo:hi]

	return windowHasCSSCloaking(region) ||
		injectionFingerprintRe.MatchString(region) ||
		windowHasExternalSpamAnchor(region, keyword)
}
// positionAbsoluteRe and negativeCoordRe are evaluated together in
// windowHasCSSCloaking for the "off-screen absolute positioning"
// signal. Keeping them as two independent regexes (rather than one
// paired regex with [^"'}]* between them) closes an evasion where the
// attacker splits the cloak across two CSS rules — e.g.
// `<style>.a{position:absolute}.b{left:-9999px}</style>` — which the
// paired form stops matching at the first `}`. Both signals must
// appear somewhere in the proximity window; the window itself is
// bounded (see spamContextWindow) so the association is not unbounded.
//
// negativeCoordRe requires at least two digits after the minus sign, so a
// single-digit offset such as "left:-5px" does not count. It also accepts
// margin-left and margin-top in addition to the four edge properties.
var (
positionAbsoluteRe = regexp.MustCompile(`(?i)\bposition\s*:\s*absolute\b`)
negativeCoordRe = regexp.MustCompile(`(?i)\b(left|top|right|bottom|margin-left|margin-top)\s*:\s*-\s*\d{2,}`)
)
// windowHasCSSCloaking reports whether window shows CSS cloaking: either
// any single cssCloakPatterns entry matches, or position:absolute and a
// negative coordinate both occur somewhere in the window. The second,
// independent-signal form catches cloaks split across separate CSS rules.
func windowHasCSSCloaking(window string) bool {
	for i := range cssCloakPatterns {
		if cssCloakPatterns[i].MatchString(window) {
			return true
		}
	}
	return positionAbsoluteRe.MatchString(window) && negativeCoordRe.MatchString(window)
}
// windowHasExternalSpamAnchor reports whether window holds an <a href>
// that points at an external URL (absolute or protocol-relative) and whose
// URL path contains the spam keyword as a bounded segment. Same-origin
// relative links ("/services/pharma/") are internal navigation and are
// skipped. The keyword is compared case-insensitively.
func windowHasExternalSpamAnchor(window, keyword string) bool {
	needle := strings.ToLower(keyword)
	for _, m := range anchorHrefRe.FindAllStringSubmatch(window, -1) {
		target := m[1]
		if externalSchemeRe.MatchString(target) && hrefPathContainsKeyword(target, needle) {
			return true
		}
	}
	return false
}
// hrefPathContainsKeyword reports whether the path (plus query) portion of
// href contains keywordLower bounded on both sides by a character that is
// not a URL word character, or by the start/end of the path. The boundary
// makes "/buy/pharma/cheap" match "pharma" while "/pharmaceutical/" does
// not.
//
// keywordLower must already be lower-cased; the path is lower-cased here.
// An empty keyword never matches. Every occurrence is examined, so an
// earlier unbounded occurrence ("/pharmacy/buy/pharma/") does not hide a
// later bounded one.
func hrefPathContainsKeyword(href, keywordLower string) bool {
	if keywordLower == "" {
		return false
	}

	// Strip the scheme://host or //host prefix. The protocol-relative form
	// is tested first so a "://" that appears later (for example inside a
	// query parameter) is not mistaken for the scheme separator.
	rest := href
	if strings.HasPrefix(rest, "//") {
		rest = rest[2:]
	} else if i := strings.Index(rest, "://"); i >= 0 {
		rest = rest[i+3:]
	}

	// Everything from the first '/' onwards is the path+query; a URL with
	// no path cannot match.
	slash := strings.IndexByte(rest, '/')
	if slash < 0 {
		return false
	}
	lower := strings.ToLower(rest[slash:])

	for from := 0; ; {
		rel := strings.Index(lower[from:], keywordLower)
		if rel < 0 {
			return false
		}
		begin := from + rel
		end := begin + len(keywordLower)
		leftOK := begin == 0 || !isURLWordChar(lower[begin-1])
		rightOK := end == len(lower) || !isURLWordChar(lower[end])
		if leftOK && rightOK {
			return true
		}
		// Resume one byte later so overlapping occurrences are seen too.
		from = begin + 1
	}
}

// isURLWordChar reports whether b is an ASCII letter, an ASCII digit or an
// underscore. Hyphen is deliberately not a word character, so
// "buy-cheap-pharma" still yields "pharma" as a bounded segment.
func isURLWordChar(b byte) bool {
	switch {
	case b >= 'a' && b <= 'z', b >= 'A' && b <= 'Z', b >= '0' && b <= '9':
		return true
	}
	return b == '_'
}
package checks
import (
"context"
"encoding/json"
"fmt"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Supply-chain dependency scanning.
//
// This is the scanner half of the supply-chain check: it parses
// composer.lock / package-lock.json dependency trees under customer
// document roots and matches the resolved versions against a local
// advisory database. The advisory database itself is operational data,
// not shipped in the binary -- an operator or a sync job writes
// <state>/advisories/supply-chain.json (format documented in
// docs/supply-chain-advisories.md). With no advisory file present the
// check is dormant: it parses nothing it cannot match and emits nothing.
// This mirrors the YARA-forge mirror posture (machinery in CSM, signed
// data delivered out of band).
// supplyChainAdvisoryRelPath is where CSM looks for the advisory DB,
// relative to the configured state directory. loadSupplyChainAdvisories
// joins it onto the state path; a missing or unparsable file leaves the
// supply-chain check dormant.
const supplyChainAdvisoryRelPath = "advisories/supply-chain.json"
// supplyChainPkg is one resolved dependency from a lockfile. The
// Ecosystem/Name pair is the lookup key into the advisory index
// (compared case-insensitively via advisoryKey); Version is what
// versionVulnerable tests against advisory ranges.
type supplyChainPkg struct {
Ecosystem string // "composer" | "npm"
Name string
Version string
}
// supplyChainAdvisory is the OSV-subset advisory shape CSM matches
// against. A version is vulnerable when it falls inside any range:
// version >= introduced AND (fixed == "" OR version < fixed).
type supplyChainAdvisory struct {
Ecosystem string `json:"ecosystem"`
Package string `json:"package"`
Ranges []supplyChainAdvisoryRange `json:"ranges"`
ID string `json:"id"` // shown in findings; "advisory" is substituted when empty
Severity string `json:"severity"` // mapped case-insensitively by supplyChainFinding
Summary string `json:"summary"`
}
// supplyChainAdvisoryRange is one half-open vulnerable interval. An empty
// or "0" Introduced means "from the beginning"; an empty Fixed means no
// upper bound.
type supplyChainAdvisoryRange struct {
Introduced string `json:"introduced"`
Fixed string `json:"fixed"`
}
// supplyChainAdvisoryFile is the top-level JSON document of the advisory
// database file.
type supplyChainAdvisoryFile struct {
Advisories []supplyChainAdvisory `json:"advisories"`
}
// CheckSupplyChain scans customer dependency lockfiles for resolved
// versions that fall inside a known advisory range and returns one finding
// per (package, advisory) hit. It is dormant, returning nil, when cfg is
// nil or when no advisories can be loaded from
// <state>/advisories/supply-chain.json.
//
// Cancellation of ctx is honoured between lockfiles: findings gathered so
// far are returned. Unreadable lockfiles are skipped.
func CheckSupplyChain(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if cfg == nil {
		return nil
	}
	known := loadSupplyChainAdvisories(cfg.StatePath)
	if len(known) == 0 {
		// No data means there is nothing to match against.
		return nil
	}
	byPackage := indexAdvisories(known)

	var out []alert.Finding
	for _, lf := range discoverLockfiles(ctx) {
		if ctx != nil && ctx.Err() != nil {
			break
		}
		raw, err := osFS.ReadFile(lf.path)
		if err != nil {
			continue
		}
		for _, dep := range lf.parse(raw) {
			for _, adv := range byPackage[advisoryKey(dep.Ecosystem, dep.Name)] {
				if !versionVulnerable(dep.Version, adv.Ranges) {
					continue
				}
				out = append(out, supplyChainFinding(lf.account, lf.path, dep, adv))
			}
		}
	}
	return out
}
// loadSupplyChainAdvisories reads and decodes the advisory database below
// statePath. It returns nil when statePath is empty, the file cannot be
// read, or the JSON does not decode, so every failure mode leaves the check
// dormant rather than erroring.
func loadSupplyChainAdvisories(statePath string) []supplyChainAdvisory {
	if statePath == "" {
		return nil
	}
	raw, err := osFS.ReadFile(filepath.Join(statePath, supplyChainAdvisoryRelPath))
	if err != nil {
		return nil
	}
	var parsed supplyChainAdvisoryFile
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil
	}
	return parsed.Advisories
}
// advisoryKey builds the case-insensitive index key for an ecosystem and
// package name. The two lower-cased parts are joined with a NUL byte.
func advisoryKey(ecosystem, pkg string) string {
	eco := strings.ToLower(ecosystem)
	name := strings.ToLower(pkg)
	return strings.Join([]string{eco, name}, "\x00")
}
// indexAdvisories groups advisories by advisoryKey(ecosystem, package),
// preserving input order within each group. Entries with an empty ecosystem
// or package are dropped.
func indexAdvisories(advisories []supplyChainAdvisory) map[string][]supplyChainAdvisory {
	byKey := make(map[string][]supplyChainAdvisory)
	for i := range advisories {
		adv := advisories[i]
		if adv.Ecosystem == "" || adv.Package == "" {
			continue
		}
		key := advisoryKey(adv.Ecosystem, adv.Package)
		byKey[key] = append(byKey[key], adv)
	}
	return byKey
}
// lockfile is one discovered dependency lockfile together with the parser
// that understands its format.
type lockfile struct {
path string // location as returned by the glob in discoverLockfiles
account string // extractUser(path)
parse func([]byte) []supplyChainPkg // parseComposerLock or parsePackageLock
}
// packageLockDependency is one node of the "dependencies" tree in a
// package-lock.json. Nodes nest through their own Dependencies map, which
// appendPackageLockV1Dependencies walks.
type packageLockDependency struct {
Version string `json:"version"`
Dependencies map[string]packageLockDependency `json:"dependencies"`
}
// discoverLockfiles globs composer.lock and package-lock.json at the common
// project depths under customer home directories and pairs each hit with
// its parser. Discovery is bounded by the glob shapes (no recursive walk),
// so a deep node_modules tree cannot turn the scan into an unbounded crawl.
//
// Results follow pattern order, each path appears at most once, and a
// cancelled ctx stops discovery between patterns, returning what has been
// found so far. Glob errors are ignored.
func discoverLockfiles(ctx context.Context) []lockfile {
	type source struct {
		pattern string
		parser  func([]byte) []supplyChainPkg
	}
	sources := []source{
		{"/home/*/public_html/composer.lock", parseComposerLock},
		{"/home/*/composer.lock", parseComposerLock},
		{"/home/*/public_html/package-lock.json", parsePackageLock},
		{"/home/*/package-lock.json", parsePackageLock},
	}

	visited := make(map[string]bool)
	var found []lockfile
	for _, src := range sources {
		if ctx != nil && ctx.Err() != nil {
			break
		}
		hits, _ := osFS.Glob(src.pattern)
		for _, hit := range hits {
			if visited[hit] {
				continue
			}
			visited[hit] = true
			found = append(found, lockfile{path: hit, account: extractUser(hit), parse: src.parser})
		}
	}
	return found
}
// parseComposerLock extracts the resolved packages from a composer.lock
// document: first the "packages" array, then "packages-dev". Entries
// missing a name or version are skipped. Invalid JSON yields nil.
func parseComposerLock(data []byte) []supplyChainPkg {
	type entry struct{ Name, Version string }
	var doc struct {
		Packages    []entry `json:"packages"`
		PackagesDev []entry `json:"packages-dev"`
	}
	if err := json.Unmarshal(data, &doc); err != nil {
		return nil
	}

	var out []supplyChainPkg
	collect := func(entries []entry) {
		for _, e := range entries {
			if e.Name == "" || e.Version == "" {
				continue
			}
			out = append(out, supplyChainPkg{Ecosystem: "composer", Name: e.Name, Version: e.Version})
		}
	}
	collect(doc.Packages)
	collect(doc.PackagesDev)
	return out
}
// parsePackageLock extracts the resolved packages from a package-lock.json
// document. Invalid JSON yields nil.
//
// When the "packages" map is present and non-empty, only that map is used:
// keys are visited in sorted order and turned into names by
// npmNameFromPackagesKey, skipping entries without a name (such as the root
// project key "") or without a version. When it is empty or absent, the
// older "dependencies" tree is walked instead.
func parsePackageLock(data []byte) []supplyChainPkg {
	var doc struct {
		Packages map[string]struct {
			Version string `json:"version"`
		} `json:"packages"`
		Dependencies map[string]packageLockDependency `json:"dependencies"`
	}
	if err := json.Unmarshal(data, &doc); err != nil {
		return nil
	}

	// No "packages" map: fall back to the top-level dependencies tree.
	if len(doc.Packages) == 0 {
		return appendPackageLockV1Dependencies(nil, doc.Dependencies)
	}

	keys := make([]string, 0, len(doc.Packages))
	for k := range doc.Packages {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var out []supplyChainPkg
	for _, k := range keys {
		name, version := npmNameFromPackagesKey(k), doc.Packages[k].Version
		if name == "" || version == "" {
			continue
		}
		out = append(out, supplyChainPkg{Ecosystem: "npm", Name: name, Version: version})
	}
	return out
}
// appendPackageLockV1Dependencies appends every dependency found in a
// package-lock "dependencies" tree to out and returns the extended slice.
// The tree is walked iteratively with an explicit stack. Names within one
// level are visited in sorted order, which makes the output deterministic
// despite map iteration. Entries with an empty name or version are not
// emitted, but their children are still visited.
func appendPackageLockV1Dependencies(out []supplyChainPkg, deps map[string]packageLockDependency) []supplyChainPkg {
	pending := []map[string]packageLockDependency{deps}
	for len(pending) > 0 {
		top := len(pending) - 1
		level := pending[top]
		pending = pending[:top]

		ordered := make([]string, 0, len(level))
		for n := range level {
			ordered = append(ordered, n)
		}
		sort.Strings(ordered)

		for _, n := range ordered {
			child := level[n]
			if n != "" && child.Version != "" {
				out = append(out, supplyChainPkg{Ecosystem: "npm", Name: n, Version: child.Version})
			}
			if len(child.Dependencies) != 0 {
				pending = append(pending, child.Dependencies)
			}
		}
	}
	return out
}
// npmNameFromPackagesKey derives the package name from a package-lock
// "packages" key such as "node_modules/<name>" or the nested
// "node_modules/a/node_modules/b". The name is the segment following the
// last "node_modules" segment; for a scoped package ("@scope/name") it is
// the two segments following it. It returns "" for the root project key "",
// for keys without a "node_modules" segment, and for keys where the name
// (or the scoped half of it) is missing.
func npmNameFromPackagesKey(key string) string {
	if key == "" {
		return ""
	}
	segs := strings.Split(key, "/")

	// Find the last "node_modules" segment that has something after it.
	at := -1
	for i := len(segs) - 2; i >= 0; i-- {
		if segs[i] == "node_modules" {
			at = i
			break
		}
	}
	if at < 0 {
		return ""
	}

	name := segs[at+1]
	switch {
	case name == "":
		return ""
	case !strings.HasPrefix(name, "@"):
		return name
	case at+2 >= len(segs) || segs[at+2] == "":
		return ""
	}
	return name + "/" + segs[at+2]
}
// versionVulnerable reports whether version lies inside at least one of the
// advisory ranges, i.e. version >= introduced AND (fixed == "" OR
// version < fixed). An empty or "0" introduced means "from the beginning".
// With no ranges the result is false.
func versionVulnerable(version string, ranges []supplyChainAdvisoryRange) bool {
	for i := range ranges {
		r := ranges[i]
		// Below the lower bound.
		if r.Introduced != "" && r.Introduced != "0" && semverCompare(version, r.Introduced) < 0 {
			continue
		}
		// At or above the fixed release.
		if r.Fixed != "" && semverCompare(version, r.Fixed) >= 0 {
			continue
		}
		return true
	}
	return false
}
// semverCompare orders two dotted versions by comparing their numeric
// segments left to right, returning -1, 0 or 1. Parsing is delegated to
// semverSegments, which tolerates a leading "v", drops any
// pre-release/build suffix after the first "-" or "+", and maps non-numeric
// segments to 0 so a malformed version never panics. The shorter version is
// padded with zeros, so "1.2" equals "1.2.0".
func semverCompare(a, b string) int {
	left, right := semverSegments(a), semverSegments(b)
	width := len(left)
	if len(right) > width {
		width = len(right)
	}

	// at returns the i-th segment, or 0 past the end of the slice.
	at := func(segs []int, i int) int {
		if i < len(segs) {
			return segs[i]
		}
		return 0
	}

	for i := 0; i < width; i++ {
		switch l, r := at(left, i), at(right, i); {
		case l < r:
			return -1
		case l > r:
			return 1
		}
	}
	return 0
}
// semverSegments splits a version string into its numeric dot-separated
// segments. Surrounding whitespace and a leading "v" then "V" are removed,
// and everything from the first '-' or '+' onwards is discarded. A segment
// that does not parse as an int becomes 0. The result always has at least
// one element (an empty input yields [0]).
func semverSegments(v string) []int {
	core := strings.TrimSpace(v)
	core = strings.TrimPrefix(strings.TrimPrefix(core, "v"), "V")
	if cut := strings.IndexAny(core, "-+"); cut >= 0 {
		core = core[:cut]
	}

	fields := strings.Split(core, ".")
	nums := make([]int, len(fields))
	for i, field := range fields {
		if n, err := strconv.Atoi(strings.TrimSpace(field)); err == nil {
			nums[i] = n
		}
	}
	return nums
}
// supplyChainFinding renders one vulnerable dependency as a
// "supply_chain_vuln" finding.
//
// Severity mapping (case-insensitive on adv.Severity): "critical" gives
// Critical; "low", "medium", "moderate" and the empty string give Warning;
// anything else gives High. An empty advisory ID is shown as "advisory".
// The fix hint uses the first range that names a fixed version.
func supplyChainFinding(account, lockPath string, p supplyChainPkg, adv supplyChainAdvisory) alert.Finding {
	var severity alert.Severity
	switch strings.ToLower(adv.Severity) {
	case "critical":
		severity = alert.Critical
	case "", "low", "medium", "moderate":
		severity = alert.Warning
	default:
		severity = alert.High
	}

	advisoryID := adv.ID
	if advisoryID == "" {
		advisoryID = "advisory"
	}

	remedy := "no fixed version published"
	for i := range adv.Ranges {
		if v := adv.Ranges[i].Fixed; v != "" {
			remedy = "fixed in " + v
			break
		}
	}

	message := fmt.Sprintf("Vulnerable %s dependency %s %s (%s) on account %s",
		p.Ecosystem, p.Name, p.Version, advisoryID, account)
	details := fmt.Sprintf("Account: %s\nLockfile: %s\nEcosystem: %s\nPackage: %s\nVersion: %s\nAdvisory: %s\nSeverity: %s\nFix: %s\n%s",
		account, lockPath, p.Ecosystem, p.Name, p.Version, advisoryID, adv.Severity, remedy, adv.Summary)

	return alert.Finding{
		Severity:  severity,
		Check:     "supply_chain_vuln",
		Message:   message,
		Details:   details,
		FilePath:  lockPath,
		Timestamp: time.Now(),
	}
}
package checks
import (
"bufio"
"context"
"fmt"
"io"
"sort"
"strings"
"syscall"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/mysqlclient"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// CheckKernelModules compares the currently loaded kernel modules against
// the baseline kept in store and returns a High finding for each module
// that was not seen before.
//
// On the first run (no "_kmod_baseline_set" marker) every loaded module is
// recorded as known and nothing is reported. On later runs each unknown
// module is reported once and then recorded, so it does not alert again.
//
// A nil store yields no findings, matching CheckSensitiveFiles: without a
// store there is no baseline to compare against. An empty or unreadable
// module list also yields nil.
func CheckKernelModules(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	if store == nil {
		return nil
	}

	modules := loadModuleList()
	if len(modules) == 0 {
		return nil
	}

	// First run: store all current modules as the baseline.
	if _, baselineExists := store.GetRaw("_kmod_baseline_set"); !baselineExists {
		for _, mod := range modules {
			store.SetRaw("_kmod:"+mod, "baseline")
		}
		store.SetRaw("_kmod_baseline_set", "true")
		return nil
	}

	// Later runs: report modules that were not seen at baseline.
	var findings []alert.Finding
	for _, mod := range modules {
		key := "_kmod:" + mod
		if _, known := store.GetRaw(key); known {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "kernel_module",
			Message:  fmt.Sprintf("New kernel module loaded after baseline: %s", mod),
			Details:  "This module was not present when CSM baseline was set. Verify it is legitimate.",
		})
		// Remember it so the same module does not re-alert.
		store.SetRaw(key, "new")
	}
	return findings
}
// loadModuleList returns the first whitespace-separated field of every
// non-blank line in /proc/modules (the module name). It returns nil when the
// file cannot be opened.
func loadModuleList() []string {
	procModules, err := osFS.Open("/proc/modules")
	if err != nil {
		return nil
	}
	defer func() { _ = procModules.Close() }()

	var names []string
	for sc := bufio.NewScanner(procModules); sc.Scan(); {
		if cols := strings.Fields(sc.Text()); len(cols) > 0 {
			names = append(names, cols[0])
		}
	}
	return names
}
// CheckRPMIntegrity verifies that files shipped by a short list of
// security-critical packages have not been modified. Despite the name it is
// platform-aware: RHEL-family hosts are checked through rpm -V, Debian-family
// hosts through debsums or dpkg --verify. Any other platform yields nil.
func CheckRPMIntegrity(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	host := platform.Detect()
	if host.IsRHELFamily() {
		return checkRPMPackageIntegrity(rpmCriticalPackages)
	}
	if host.IsDebianFamily() {
		return checkDebianPackageIntegrity(debianCriticalPackages)
	}
	return nil
}
// rpmCriticalPackages lists the package names CheckRPMIntegrity verifies on
// RHEL-family hosts (passed to checkRPMPackageIntegrity).
var rpmCriticalPackages = []string{
"openssh-server",
"shadow-utils",
"sudo",
"coreutils",
"util-linux",
"passwd",
}
// debianCriticalPackages lists the package names CheckRPMIntegrity verifies on
// Debian-family hosts (passed to checkDebianPackageIntegrity).
var debianCriticalPackages = []string{
"openssh-server",
"passwd",
"sudo",
"coreutils",
"util-linux",
"login",
}
// checkRPMPackageIntegrity runs `rpm -V` for each package and reports a
// Critical finding for every listed file whose size (S) or digest (5) flag is
// set and which looksExecutableOrLibrary accepts. Lines marked as config (c)
// or documentation (d) are ignored, as are packages for which the command
// fails or prints nothing.
func checkRPMPackageIntegrity(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// rpm -V exits non-zero when it has something to report, so a
		// non-zero status is not treated as a command failure.
		raw, err := runCmdAllowNonZero("rpm", "-V", pkg)
		if err != nil || raw == nil {
			continue
		}
		report := strings.TrimSpace(string(raw))
		if report == "" {
			continue
		}

		for _, entry := range strings.Split(report, "\n") {
			// The first nine characters are the verification flags.
			if len(entry) < 9 {
				continue
			}
			attrs := entry[:9]
			target := strings.TrimSpace(entry[9:])

			switch {
			case strings.Contains(entry, " c ") || strings.Contains(entry, " d "):
				// Config and documentation files change legitimately.
				continue
			case !strings.ContainsAny(attrs, "S5"):
				// Neither size nor checksum differs.
				continue
			case !looksExecutableOrLibrary(target):
				continue
			}

			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "rpm_integrity",
				Message:  fmt.Sprintf("Modified system binary or library: %s (package: %s)", target, pkg),
				Details:  fmt.Sprintf("RPM verification flags: %s", attrs),
			})
		}
	}
	return findings
}
// checkDebianPackageIntegrity verifies Debian/Ubuntu packages. When a
// `debsums` binary can be located it is used (it compares files against the
// md5sums shipped with the package); otherwise the check falls back to
// `dpkg --verify`.
func checkDebianPackageIntegrity(packages []string) []alert.Finding {
	_, lookErr := cmdExec.LookPath("debsums")
	if lookErr != nil {
		return checkDpkgVerify(packages)
	}
	return checkDebsums(packages)
}
// checkDebsums runs `debsums -c` for each package and reports a Critical
// finding for every printed path that looksExecutableOrLibrary accepts.
// Packages for which the command fails are skipped.
func checkDebsums(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// debsums -c exits 2 on mismatches; that is a result, not a failure.
		raw, err := runCmdAllowNonZero("debsums", "-c", pkg)
		if err != nil || raw == nil {
			continue
		}
		changed := strings.Split(strings.TrimSpace(string(raw)), "\n")
		for i := range changed {
			target := strings.TrimSpace(changed[i])
			if target == "" || !looksExecutableOrLibrary(target) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "dpkg_integrity",
				Message:  fmt.Sprintf("Modified system binary or library: %s (package: %s)", target, pkg),
				Details:  "debsums reported md5 mismatch against the package manifest.",
			})
		}
	}
	return findings
}
// checkDpkgVerify runs `dpkg --verify` for each package and reports a Critical
// finding for every non-config file whose nine-character flag field contains
// '5' (digest mismatch) or 'S' and which looksExecutableOrLibrary accepts.
//
// Output lines look like:
//
//	??5?????? /usr/bin/passwd
func checkDpkgVerify(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// dpkg --verify exits 1 on mismatches; that is a result, not a failure.
		raw, err := runCmdAllowNonZero("dpkg", "--verify", pkg)
		if err != nil || raw == nil {
			continue
		}
		for _, entry := range strings.Split(strings.TrimSpace(string(raw)), "\n") {
			if len(entry) < 10 {
				continue
			}
			attrs := entry[:9]
			target := strings.TrimSpace(entry[9:])

			switch {
			case strings.Contains(entry, " c "):
				// Config files are expected to be edited.
				continue
			case !strings.ContainsAny(attrs, "5S"):
				continue
			case !looksExecutableOrLibrary(target):
				continue
			}

			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "dpkg_integrity",
				Message:  fmt.Sprintf("Modified system binary or library: %s (package: %s)", target, pkg),
				Details:  fmt.Sprintf("dpkg --verify flags: %s", attrs),
			})
		}
	}
	return findings
}
// looksExecutableOrLibrary reports whether path is a regular file that is
// either executable (any execute bit set) or starts with the ELF magic bytes.
//
// Package-integrity mismatches are only interesting for code that can run, so
// the integrity checks use this as their filter. Deciding by file type rather
// than by directory means a tampered binary or shared object is reported
// wherever it is installed, while modified non-executable data files stay
// quiet. Files that cannot be stat'ed or opened yield false.
func looksExecutableOrLibrary(path string) bool {
	st, err := osFS.Stat(path)
	if err != nil {
		return false
	}
	mode := st.Mode()
	if !mode.IsRegular() {
		return false
	}
	if mode&0o111 != 0 {
		return true
	}

	// Shared objects are usually not executable by mode, so fall back to
	// sniffing the four-byte ELF header.
	fh, err := osFS.Open(path)
	if err != nil {
		return false
	}
	defer func() { _ = fh.Close() }()

	header := make([]byte, 4)
	if _, err := io.ReadFull(fh, header); err != nil {
		return false
	}
	return string(header) == "\x7fELF"
}
// CheckMySQLUsers looks for MySQL accounts holding SUPER privilege other than
// the standard system accounts excluded by the query, and tracks that set
// across runs through a hash stored under "_mysql_super_users".
//
// First run: a non-empty set is reported for review instead of being silently
// accepted, because an account planted before CSM was installed would
// otherwise be baselined as "known". Later runs: any change of the set
// (including it becoming empty) is reported. The stored hash is refreshed on
// every successful query. A failed query yields nil and leaves state untouched.
func CheckMySQLUsers(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	const stateKey = "_mysql_super_users"

	rows, err := mysqlclient.RootQuery(ctx,
		"SELECT user, host FROM mysql.user WHERE Super_priv='Y' AND user NOT IN ('root','mysql.session','mysql.sys','mysql.infoschema','debian-sys-maint')")
	if err != nil {
		return nil
	}

	// Sort so that row order cannot change the hash.
	sort.Strings(rows)
	listing := strings.TrimSpace(strings.Join(rows, "\n"))
	digest := hashBytes([]byte(listing))

	var findings []alert.Finding
	previous, seen := store.GetRaw(stateKey)
	if !seen {
		if listing != "" {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "mysql_superuser",
				Message:  "Non-standard MySQL superuser accounts present at baseline",
				Details:  fmt.Sprintf("Review these privileged accounts:\n%s", listing),
			})
		}
	} else if previous != digest {
		details := "Current superusers: none"
		if listing != "" {
			details = fmt.Sprintf("Current superusers:\n%s", listing)
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "mysql_superuser",
			Message:  "MySQL superuser accounts changed",
			Details:  details,
		})
	}

	store.SetRaw(stateKey, digest)
	return findings
}
// CheckGroupWritablePHP reports PHP files under each account's public_html
// that are group-writable and whose group is one of the web-server groups.
// Such files can be rewritten by the web server itself, which lets a webshell
// persist. Returns nil when no web-server group is known.
func CheckGroupWritablePHP(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	gids := getWebServerGIDs()
	if len(gids) == 0 {
		return nil
	}

	var findings []alert.Finding
	homes, _ := GetScanHomeDirs(ctx)
	for _, home := range homes {
		if home.IsDir() {
			// Descend at most four directory levels below the document root.
			scanGroupWritablePHP(fmt.Sprintf("/home/%s/public_html", home.Name()), 4, gids, &findings)
		}
	}
	return findings
}
// scanGroupWritablePHP walks dir recursively, at most maxDepth levels deep,
// and appends a High finding to *findings for every file that
// isExecutablePHPName accepts (name lower-cased), has the group-write bit set,
// and belongs to a group listed in webGIDs. Directories named "cache",
// "node_modules" and "vendor" are not entered; unreadable directories and
// entries are skipped silently.
func scanGroupWritablePHP(dir string, maxDepth int, webGIDs map[uint32]bool, findings *[]alert.Finding) {
	if maxDepth <= 0 {
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	for _, child := range children {
		base := child.Name()
		childPath := dir + "/" + base

		if child.IsDir() {
			switch base {
			case "cache", "node_modules", "vendor":
				// Known large/safe directories: do not descend.
			default:
				scanGroupWritablePHP(childPath, maxDepth-1, webGIDs, findings)
			}
			continue
		}

		if !isExecutablePHPName(strings.ToLower(base)) {
			continue
		}
		meta, err := child.Info()
		if err != nil {
			continue
		}
		// 0020 is the group-write permission bit.
		if meta.Mode()&0020 == 0 {
			continue
		}
		owner, ok := meta.Sys().(*syscall.Stat_t)
		if !ok || !webGIDs[owner.Gid] {
			continue
		}
		*findings = append(*findings, alert.Finding{
			Severity: alert.High,
			Check:    "group_writable_php",
			Message:  fmt.Sprintf("Web-server group-writable PHP: %s", childPath),
			Details:  fmt.Sprintf("Mode: %s, GID: %d", meta.Mode(), owner.Gid),
		})
	}
}
// getWebServerGIDs returns the set of numeric group IDs that /etc/group
// assigns to the web-server groups "nobody", "www-data", "apache" and "www".
// The returned map is empty (never nil) when /etc/group cannot be read or
// names none of those groups.
//
// A matching line whose GID field cannot be parsed is skipped. Registering it
// with the zero value would mark gid 0 (the root group) as a web-server group
// and make every root-group, group-writable PHP file look suspicious.
func getWebServerGIDs() map[uint32]bool {
	gids := make(map[uint32]bool)
	data, err := osFS.ReadFile("/etc/group")
	if err != nil {
		return gids
	}
	for _, line := range strings.Split(string(data), "\n") {
		// Format: name:password:gid:members
		fields := strings.Split(line, ":")
		if len(fields) < 3 {
			continue
		}
		switch fields[0] {
		case "nobody", "www-data", "apache", "www":
		default:
			continue
		}
		var gid uint32
		if _, err := fmt.Sscanf(fields[2], "%d", &gid); err != nil {
			continue
		}
		gids[gid] = true
	}
	return gids
}
package checks
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/store"
)
// ThreatDB is a local IP reputation database built from:
// 1. CSM's own block history (permanent)
// 2. Public threat intelligence feeds (updated daily)
// 3. AbuseIPDB as fallback for unknown IPs
//
// The methods that read or modify the maps and slices below after
// construction take mu; the load* helpers run from InitThreatDB before the
// instance is published and do not lock.
type ThreatDB struct {
mu sync.RWMutex
// badIPs values are the feed name for feed entries, "permanent-blocklist"
// for entries loaded from persistent storage, or the caller-supplied reason
// for entries added through AddPermanent.
badIPs map[string]string // ip -> source/reason
badNets []*net.IPNet // CIDR ranges from feeds
whitelist map[string]bool // IPs to never flag
whitelistMeta map[string]*whitelistEntry // expiry metadata
// lastUpdate is the timestamp UpdateFeeds uses to throttle downloads.
lastUpdate time.Time
// dbPath is the directory holding permanent.txt, whitelist.txt, the
// per-feed cache files and the last_update stamp.
dbPath string
// Stats for WebUI
PermanentCount int
FeedIPCount int
FeedNetCount int
LastFeedUpdate time.Time
LastUpdated time.Time // tracks when feeds were last successfully loaded
}
// globalThreatDB is the process-wide instance returned by GetThreatDB;
// threatDBOnce makes InitThreatDB build it at most once.
var (
globalThreatDB *ThreatDB
threatDBOnce sync.Once
)
// feedMinEntries is the minimum number of entries (IPs plus CIDR ranges) each
// feed is expected to return, keyed by feed name. UpdateFeeds rejects a
// download that returns fewer, treating it as corrupted or truncated. Feeds
// without an entry here are not size-checked.
var feedMinEntries = map[string]int{
"spamhaus-drop": 50,
"spamhaus-edrop": 10,
"blocklist-de": 1000,
"cins-army": 5000,
}
// threatFeeds lists the free public threat intelligence feeds to download.
// name is also used as the source label stored in ThreatDB.badIPs and as the
// base name of the on-disk cache file (<name>.txt).
var threatFeeds = []struct {
name string
url string
}{
{"spamhaus-drop", "https://www.spamhaus.org/drop/drop.txt"},
{"spamhaus-edrop", "https://www.spamhaus.org/drop/edrop.txt"},
{"blocklist-de", "https://lists.blocklist.de/lists/all.txt"},
{"cins-army", "https://cinsscore.com/list/ci-badguys.txt"},
}
// InitThreatDB builds the global threat database on first call and returns
// it; later calls return the existing instance and ignore their arguments.
//
// statePath is the directory under which the "threat_db" data directory is
// created. whitelistIPs seeds the never-flag list. The permanent blocklist,
// the persisted whitelist and the cached feed data are loaded before the
// instance is published.
func InitThreatDB(statePath string, whitelistIPs []string) *ThreatDB {
	threatDBOnce.Do(func() {
		allow := make(map[string]bool)
		for i := range whitelistIPs {
			allow[whitelistIPs[i]] = true
		}

		fresh := &ThreatDB{
			dbPath:    filepath.Join(statePath, "threat_db"),
			whitelist: allow,
			badIPs:    make(map[string]string),
		}
		_ = os.MkdirAll(fresh.dbPath, 0700)

		fresh.loadPermanentBlocklist()
		fresh.loadPersistedWhitelist()
		fresh.loadFeedCache()

		globalThreatDB = fresh
	})
	return globalThreatDB
}
// GetThreatDB returns the process-wide threat database, or nil if it has not
// been set up yet (see InitThreatDB).
func GetThreatDB() *ThreatDB {
	current := globalThreatDB
	return current
}
// SetGlobalThreatDBForTest replaces the global threat DB with an empty one
// rooted at statePath/threat_db, without going through the once-guard, and
// returns a function that puts the previous global back. Test-only: it lets a
// test exercise threat-DB code paths without leaking state into tests that
// run afterwards.
func SetGlobalThreatDBForTest(statePath string) func() {
	saved := globalThreatDB

	scratch := &ThreatDB{
		dbPath:    filepath.Join(statePath, "threat_db"),
		whitelist: make(map[string]bool),
		badIPs:    make(map[string]string),
	}
	_ = os.MkdirAll(scratch.dbPath, 0700)
	globalThreatDB = scratch

	restore := func() { globalThreatDB = saved }
	return restore
}
// Lookup reports whether ip is known to the local threat database.
//
// It returns (source, true) for an exact match, ("threat-feed-cidr", true)
// when the address falls inside one of the feed CIDR ranges (IPv4 or IPv6),
// and ("", false) otherwise. Whitelisted addresses are never flagged.
func (db *ThreatDB) Lookup(ip string) (string, bool) {
	db.mu.RLock()
	defer db.mu.RUnlock()

	if db.whitelist[ip] {
		return "", false
	}
	if src, hit := db.badIPs[ip]; hit {
		return src, true
	}

	addr := net.ParseIP(ip)
	if addr == nil {
		// Not a parseable address, so no range can contain it.
		return "", false
	}
	for i := range db.badNets {
		if db.badNets[i].Contains(addr) {
			return "threat-feed-cidr", true
		}
	}
	return "", false
}
// AddPermanent adds ip to the permanent local blocklist, recording reason as
// its source in memory. Called when CSM auto-blocks an IP; the entry is
// persisted so it survives restarts.
//
// Persistence is skipped only when the IP was already present as a permanent
// entry (dedup). An IP that was known solely through a threat feed has never
// been written to permanent storage, so it is persisted: treating it as a
// duplicate would silently drop the permanent block once the feed stops
// listing the address.
//
// The entry goes to the global store when one is configured, otherwise it is
// appended to permanent.txt. Persistence is best-effort; failures are ignored.
func (db *ThreatDB) AddPermanent(ip, reason string) {
	db.mu.Lock()
	prev, exists := db.badIPs[ip]
	db.badIPs[ip] = reason
	db.mu.Unlock()

	// Feed entries carry the feed name as their source.
	fromFeed := false
	if exists {
		for _, feed := range threatFeeds {
			if feed.name == prev {
				fromFeed = true
				break
			}
		}
	}
	if exists && !fromFeed {
		return
	}

	if sdb := store.Global(); sdb != nil {
		_ = sdb.AddPermanentBlock(ip, reason)
		return
	}

	// Fallback: flat-file permanent.txt.
	f, err := os.OpenFile(filepath.Join(db.dbPath, "permanent.txt"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	_, _ = fmt.Fprintf(f, "%s # %s [%s]\n", ip, reason, time.Now().Format("2006-01-02"))
}
// RemovePermanent removes ip from the in-memory database and from permanent
// storage: the global store when one is configured, otherwise permanent.txt.
//
// The flat file is rewritten through a temporary file that is renamed into
// place only after it was written completely. If the write fails (for example
// on a full disk) the existing permanent.txt is left untouched rather than
// being replaced by a truncated copy. Blank and comment lines are preserved.
func (db *ThreatDB) RemovePermanent(ip string) {
	db.mu.Lock()
	delete(db.badIPs, ip)
	db.mu.Unlock()

	if sdb := store.Global(); sdb != nil {
		_ = sdb.RemovePermanentBlock(ip)
		return
	}

	// Fallback: rewrite permanent.txt without this IP.
	path := filepath.Join(db.dbPath, "permanent.txt")
	data, err := osFS.ReadFile(path)
	if err != nil {
		return
	}

	// Drop the trailing newline before splitting so the rewrite does not
	// append one more blank line every time it runs.
	var kept []string
	for _, line := range strings.Split(strings.TrimRight(string(data), "\n"), "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "#") {
			kept = append(kept, line)
			continue
		}
		if fields := strings.Fields(trimmed); len(fields) > 0 && fields[0] == ip {
			continue // skip this IP
		}
		kept = append(kept, line)
	}

	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, []byte(strings.Join(kept, "\n")+"\n"), 0600); err != nil {
		_ = os.Remove(tmpPath)
		return
	}
	_ = os.Rename(tmpPath, path)
}
// whitelistEntry holds the expiry metadata for one whitelisted IP. The IP
// itself is the key of ThreatDB.whitelistMeta.
type whitelistEntry struct {
ExpiresAt time.Time // zero = permanent
}
// AddWhitelist whitelists ip permanently (no expiry).
func (db *ThreatDB) AddWhitelist(ip string) {
	var never time.Time // the zero time marks a permanent entry
	db.addWhitelistEntry(ip, never)
}
// TempWhitelist whitelists ip until ttl has elapsed, measured from now.
func (db *ThreatDB) TempWhitelist(ip string, ttl time.Duration) {
	deadline := time.Now().Add(ttl)
	db.addWhitelistEntry(ip, deadline)
}
// addWhitelistEntry whitelists ip with the given expiry (zero time means
// permanent), removes it from the bad-IP map, and persists the entry to the
// global store when one is configured or to whitelist.txt otherwise.
func (db *ThreatDB) addWhitelistEntry(ip string, expiresAt time.Time) {
	func() {
		db.mu.Lock()
		defer db.mu.Unlock()

		db.whitelist[ip] = true
		if db.whitelistMeta == nil {
			db.whitelistMeta = make(map[string]*whitelistEntry)
		}
		db.whitelistMeta[ip] = &whitelistEntry{ExpiresAt: expiresAt}
		delete(db.badIPs, ip)
	}()

	sdb := store.Global()
	if sdb == nil {
		db.saveWhitelistFile()
		return
	}
	_ = sdb.AddWhitelistEntry(ip, expiresAt, expiresAt.IsZero())
}
// RemoveWhitelist drops ip from the whitelist and its expiry metadata, then
// persists the removal to the global store when one is configured or by
// rewriting whitelist.txt otherwise.
func (db *ThreatDB) RemoveWhitelist(ip string) {
	func() {
		db.mu.Lock()
		defer db.mu.Unlock()
		delete(db.whitelist, ip)
		delete(db.whitelistMeta, ip)
	}()

	sdb := store.Global()
	if sdb == nil {
		db.saveWhitelistFile()
		return
	}
	_ = sdb.RemoveWhitelistEntry(ip)
}
// PruneExpiredWhitelist removes temporary whitelist entries whose expiry has
// passed and returns how many were removed. When at least one entry was
// removed the change is persisted (global store if configured, otherwise
// whitelist.txt) and a line is logged to stderr. Intended to be called
// periodically from the daemon heartbeat.
func (db *ThreatDB) PruneExpiredWhitelist() int {
	cutoff := time.Now()

	db.mu.Lock()
	removed := 0
	for addr, meta := range db.whitelistMeta {
		// Permanent entries (zero expiry) and entries still valid stay.
		if meta.ExpiresAt.IsZero() || !cutoff.After(meta.ExpiresAt) {
			continue
		}
		delete(db.whitelist, addr)
		delete(db.whitelistMeta, addr)
		removed++
	}
	db.mu.Unlock()

	if removed == 0 {
		return removed
	}

	if sdb := store.Global(); sdb != nil {
		sdb.PruneExpiredWhitelist()
	} else {
		db.saveWhitelistFile()
	}
	fmt.Fprintf(os.Stderr, "[%s] Pruned %d expired whitelist entries\n",
		time.Now().Format("2006-01-02 15:04:05"), removed)
	return removed
}
// WhitelistIP describes one whitelisted IP and its expiry for JSON output;
// it is the element type returned by ThreatDB.WhitelistedIPs.
type WhitelistIP struct {
IP string `json:"ip"`
ExpiresAt *time.Time `json:"expires_at,omitempty"` // nil = permanent
Permanent bool `json:"permanent"`
}
// WhitelistedIPs returns every whitelisted IP, sorted by its string form,
// together with its expiry. Entries without expiry metadata, or with a zero
// expiry, are reported as permanent. The result is never nil.
func (db *ThreatDB) WhitelistedIPs() []WhitelistIP {
	db.mu.RLock()
	defer db.mu.RUnlock()

	addrs := make([]string, 0, len(db.whitelist))
	for addr := range db.whitelist {
		addrs = append(addrs, addr)
	}
	sort.Strings(addrs)

	listing := make([]WhitelistIP, 0, len(addrs))
	for _, addr := range addrs {
		item := WhitelistIP{IP: addr, Permanent: true}
		if meta := db.whitelistMeta[addr]; meta != nil && !meta.ExpiresAt.IsZero() {
			// Copy so the caller's pointer does not alias internal state.
			expiry := meta.ExpiresAt
			item.ExpiresAt = &expiry
			item.Permanent = false
		}
		listing = append(listing, item)
	}
	return listing
}
// saveWhitelistFile writes the current whitelist to whitelist.txt, one sorted
// line per IP: "<ip> permanent" or "<ip> expires=<RFC3339 time>".
//
// The file is written to a temporary path and renamed into place only after
// the write succeeded, so a failed write (for example on a full disk) cannot
// replace the existing whitelist with a truncated copy. Errors are otherwise
// ignored: persistence is best-effort.
func (db *ThreatDB) saveWhitelistFile() {
	path := filepath.Join(db.dbPath, "whitelist.txt")

	db.mu.RLock()
	lines := make([]string, 0, len(db.whitelist))
	for ip := range db.whitelist {
		entry := db.whitelistMeta[ip]
		if entry != nil && !entry.ExpiresAt.IsZero() {
			lines = append(lines, fmt.Sprintf("%s expires=%s", ip, entry.ExpiresAt.Format(time.RFC3339)))
		} else {
			lines = append(lines, fmt.Sprintf("%s permanent", ip))
		}
	}
	db.mu.RUnlock()

	sort.Strings(lines)
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, []byte(strings.Join(lines, "\n")+"\n"), 0600); err != nil {
		_ = os.Remove(tmpPath)
		return
	}
	_ = os.Rename(tmpPath, path)
}
// loadPersistedWhitelist restores whitelist entries from the global store
// when one is configured, otherwise from the flat file whitelist.txt.
// Entries whose expiry has already passed are not loaded; every loaded IP is
// also removed from the bad-IP map.
//
// Flat-file lines have the form "<ip> [expires=<RFC3339>]"; blank lines,
// comment lines and lines whose first field is not an IP are ignored. If any
// expired entry was skipped, the file is rewritten without it.
func (db *ThreatDB) loadPersistedWhitelist() {
	if db.whitelistMeta == nil {
		db.whitelistMeta = make(map[string]*whitelistEntry)
	}
	admit := func(addr string, meta *whitelistEntry) {
		db.whitelist[addr] = true
		db.whitelistMeta[addr] = meta
		delete(db.badIPs, addr)
	}

	if sdb := store.Global(); sdb != nil {
		stored := sdb.ListWhitelist()
		now := time.Now()
		for _, rec := range stored {
			expired := !rec.Permanent && !rec.ExpiresAt.IsZero() && now.After(rec.ExpiresAt)
			if expired {
				continue
			}
			admit(rec.IP, &whitelistEntry{ExpiresAt: rec.ExpiresAt})
		}
		return
	}

	// Fallback: flat-file whitelist.txt.
	file, err := osFS.Open(filepath.Join(db.dbPath, "whitelist.txt"))
	if err != nil {
		return
	}
	defer func() { _ = file.Close() }()

	now := time.Now()
	droppedExpired := false
	for sc := bufio.NewScanner(file); sc.Scan(); {
		text := strings.TrimSpace(sc.Text())
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		tokens := strings.Fields(text)
		if len(tokens) == 0 {
			continue
		}
		addr := tokens[0]
		if net.ParseIP(addr) == nil {
			continue
		}

		// Optional "expires=2026-03-28T19:00:00Z" token; an unparsable
		// timestamp leaves the entry permanent.
		meta := &whitelistEntry{}
		for _, tok := range tokens[1:] {
			if !strings.HasPrefix(tok, "expires=") {
				continue
			}
			if when, perr := time.Parse(time.RFC3339, strings.TrimPrefix(tok, "expires=")); perr == nil {
				meta.ExpiresAt = when
			}
		}

		if !meta.ExpiresAt.IsZero() && now.After(meta.ExpiresAt) {
			droppedExpired = true
			continue
		}
		admit(addr, meta)
	}

	if droppedExpired {
		// Rewrite synchronously. This runs once at startup, so the cost is
		// negligible, and a background goroutine could race daemon shutdown
		// and leave a stray .tmp file or a half-written whitelist.txt.
		db.saveWhitelistFile()
	}
}
// Count returns how many entries the database holds: individual bad IPs plus
// CIDR ranges.
func (db *ThreatDB) Count() int {
	db.mu.RLock()
	total := len(db.badIPs) + len(db.badNets)
	db.mu.RUnlock()
	return total
}
// Stats returns a snapshot of counters for the WebUI dashboard. Keys:
// "permanent_ips", "feed_ips", "feed_cidrs", "total" (bad IPs plus CIDR
// ranges), "whitelist", and "last_update" (RFC3339 string of LastFeedUpdate).
func (db *ThreatDB) Stats() map[string]interface{} {
	db.mu.RLock()
	defer db.mu.RUnlock()

	snapshot := make(map[string]interface{}, 6)
	snapshot["permanent_ips"] = db.PermanentCount
	snapshot["feed_ips"] = db.FeedIPCount
	snapshot["feed_cidrs"] = db.FeedNetCount
	snapshot["total"] = len(db.badIPs) + len(db.badNets)
	snapshot["whitelist"] = len(db.whitelist)
	snapshot["last_update"] = db.LastFeedUpdate.Format(time.RFC3339)
	return snapshot
}
// FeedsStale reports whether the threat feeds are more than 7 days old.
// LastUpdated is consulted first, falling back to lastUpdate when it is
// unset; if neither timestamp is set the feeds count as stale.
func (db *ThreatDB) FeedsStale() bool {
	const maxAge = 7 * 24 * time.Hour

	db.mu.RLock()
	defer db.mu.RUnlock()

	reference := db.LastUpdated
	if reference.IsZero() {
		reference = db.lastUpdate
	}
	if reference.IsZero() {
		return true
	}
	return time.Since(reference) > maxAge
}
// UpdateFeeds downloads fresh threat intelligence feeds and merges them into
// the database. It is a no-op when the last update is less than 20 hours old.
//
// Downloads happen outside the lock; the in-memory swap under the lock is
// short, so lookups are not blocked by slow network I/O.
//
// A feed that fails to download, or returns fewer entries than
// feedMinEntries requires, is skipped and its previously loaded IP entries
// are kept. Only entries belonging to a feed refreshed in this run are
// replaced. Entries that do not come from a feed are never touched, whatever
// their recorded source string: those loaded as "permanent-blocklist" and
// those added at runtime through AddPermanent.
//
// If no feed could be refreshed at all, the existing data and timestamps are
// left unchanged and an error is returned, so an outage is not recorded as a
// successful update.
//
// NOTE(review): CIDR ranges carry no per-feed source, so badNets is still
// replaced wholesale with the ranges of the feeds refreshed in this run.
func (db *ThreatDB) UpdateFeeds() error {
	db.mu.RLock()
	lastUpdate := db.lastUpdate
	db.mu.RUnlock()

	// Only update once per day.
	if time.Since(lastUpdate) < 20*time.Hour {
		return nil
	}

	client := &http.Client{Timeout: 30 * time.Second}

	// Download all feeds OUTSIDE the lock.
	newIPs := make(map[string]string)
	var newNets []*net.IPNet
	refreshed := make(map[string]bool, len(threatFeeds))
	totalIPs := 0
	totalNets := 0

	for _, feed := range threatFeeds {
		ips, nets, err := downloadFeed(client, feed.url, feed.name)
		if err != nil {
			fmt.Fprintf(os.Stderr, "threatdb: error downloading %s: %v\n", feed.name, err)
			continue
		}

		// Validate feed - reject partial downloads to avoid losing good data.
		minExpected := feedMinEntries[feed.name]
		if minExpected > 0 && len(ips)+len(nets) < minExpected {
			fmt.Fprintf(os.Stderr, "threatdb: WARNING %s returned only %d entries (expected >%d), keeping cached version\n",
				feed.name, len(ips)+len(nets), minExpected)
			continue // keep previous cached data for this feed
		}

		refreshed[feed.name] = true
		for _, ip := range ips {
			newIPs[ip] = feed.name
		}
		newNets = append(newNets, nets...)
		totalIPs += len(ips)
		totalNets += len(nets)

		// Cache to disk.
		cachePath := filepath.Join(db.dbPath, feed.name+".txt")
		saveLines(cachePath, ips)
	}

	if len(refreshed) == 0 {
		return fmt.Errorf("threatdb: no feed could be refreshed, keeping existing data")
	}

	now := time.Now()

	// Swap data UNDER the lock - fast operation.
	db.mu.Lock()
	// Drop only the entries owned by a feed that was refreshed just now.
	for ip, source := range db.badIPs {
		if refreshed[source] {
			delete(db.badIPs, ip)
		}
	}
	// Add new feed data without overwriting entries that were kept.
	for ip, source := range newIPs {
		if _, kept := db.badIPs[ip]; !kept {
			db.badIPs[ip] = source
		}
	}
	// Replace CIDR ranges entirely so they do not accumulate across updates.
	db.badNets = newNets
	db.lastUpdate = now
	db.FeedIPCount = totalIPs
	db.FeedNetCount = totalNets
	db.LastFeedUpdate = now
	db.LastUpdated = now
	db.mu.Unlock()

	// Save timestamp. Use the local copy: db.lastUpdate must not be read
	// without holding the lock.
	_ = os.WriteFile(filepath.Join(db.dbPath, "last_update"),
		[]byte(now.Format(time.RFC3339)), 0600)

	fmt.Fprintf(os.Stderr, "threatdb: updated %d IPs + %d CIDR ranges from %d feeds\n",
		totalIPs, totalNets, len(refreshed))
	return nil
}
// loadPermanentBlocklist fills the bad-IP map with the permanent blocklist,
// labelling every entry "permanent-blocklist", and sets PermanentCount. The
// global store is used when configured; otherwise permanent.txt is read,
// taking the first field of each non-comment line and keeping it only if it
// parses as an IPv4 or IPv6 address. After a flat-file load with at least one
// entry, the file is handed to compactPermanentFile for de-duplication.
func (db *ThreatDB) loadPermanentBlocklist() {
	const label = "permanent-blocklist"

	if sdb := store.Global(); sdb != nil {
		stored := sdb.AllPermanentBlocks()
		for _, rec := range stored {
			db.badIPs[rec.IP] = label
		}
		db.PermanentCount = len(stored)
		return
	}

	// Fallback: flat-file permanent.txt.
	listPath := filepath.Join(db.dbPath, "permanent.txt")
	file, err := osFS.Open(listPath)
	if err != nil {
		return
	}
	defer func() { _ = file.Close() }()

	distinct := make(map[string]bool)
	for sc := bufio.NewScanner(file); sc.Scan(); {
		text := strings.TrimSpace(sc.Text())
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		addr := strings.Fields(text)[0]
		if distinct[addr] || net.ParseIP(addr) == nil {
			continue
		}
		db.badIPs[addr] = label
		distinct[addr] = true
	}
	db.PermanentCount = len(distinct)

	if db.PermanentCount > 0 {
		compactPermanentFile(listPath, distinct)
	}
}
// compactPermanentFile rewrites the permanent blocklist at path so that each
// IP appears only once (the first occurrence wins), preserving blank and
// comment lines. Nothing is done unless the file has more than
// len(uniqueIPs)+5 lines, i.e. unless there is a meaningful number of
// duplicates.
//
// The new content is written to a temporary file that is renamed into place
// only after the write succeeded, so a failed write cannot replace the
// blocklist with a truncated copy.
func compactPermanentFile(path string, uniqueIPs map[string]bool) {
	// Read all lines to preserve comments/reasons.
	data, err := osFS.ReadFile(path)
	if err != nil {
		return
	}
	lines := strings.Split(string(data), "\n")
	if len(lines) <= len(uniqueIPs)+5 {
		return // not worth compacting - minimal duplicates
	}

	// Rewrite with deduplication.
	seen := make(map[string]bool, len(uniqueIPs))
	unique := make([]string, 0, len(lines))
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "#") {
			unique = append(unique, line)
			continue
		}
		ip := strings.Fields(trimmed)[0]
		if !seen[ip] {
			seen[ip] = true
			unique = append(unique, line)
		}
	}

	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, []byte(strings.Join(unique, "\n")+"\n"), 0600); err != nil {
		_ = os.Remove(tmpPath)
		return
	}
	_ = os.Rename(tmpPath, path)
}
// loadFeedCache restores feed state from disk: the last_update timestamp (if
// present and parseable) and, for every configured feed, the IPs cached in
// <feed>.txt. It then warns on stderr when no feed data exists at all or when
// the restored timestamp is more than 7 days old.
func (db *ThreatDB) loadFeedCache() {
	if raw, err := osFS.ReadFile(filepath.Join(db.dbPath, "last_update")); err == nil {
		if stamp, perr := time.Parse(time.RFC3339, strings.TrimSpace(string(raw))); perr == nil {
			db.lastUpdate = stamp
			db.LastFeedUpdate = stamp
			db.LastUpdated = stamp
		}
	}

	for _, feed := range threatFeeds {
		cached := loadLines(filepath.Join(db.dbPath, feed.name+".txt"))
		for i := range cached {
			db.badIPs[cached[i]] = feed.name
		}
		db.FeedIPCount += len(cached)
	}

	// Warn on startup if feeds are missing or stale.
	switch {
	case db.LastUpdated.IsZero() && db.FeedIPCount == 0:
		fmt.Fprintf(os.Stderr, "threatdb: WARNING no threat feed data loaded, feeds have never been fetched\n")
	case !db.LastUpdated.IsZero() && time.Since(db.LastUpdated) > 7*24*time.Hour:
		fmt.Fprintf(os.Stderr, "threatdb: WARNING threat feeds are stale (last updated %s, %d days ago)\n",
			db.LastUpdated.Format("2006-01-02"), int(time.Since(db.LastUpdated).Hours()/24))
	}
}
// downloadFeed fetches the feed at url with client and parses it into
// individual IP addresses and CIDR ranges. name identifies the feed in error
// messages.
//
// Parsing rules: blank lines and lines starting with '#' or ';' are ignored;
// anything after a ';' or '#' on a line is treated as a trailing comment; a
// remaining value containing '/' is parsed as a CIDR range, otherwise its
// first field is parsed as an IP. Unparsable values are skipped. At most
// 10 MiB of the response body is read.
//
// An error is returned for a transport failure, a non-200 status, or a read
// failure part-way through the body. In the last case the partial data is
// discarded so a truncated download is never mistaken for a complete feed.
func downloadFeed(client *http.Client, url, name string) ([]string, []*net.IPNet, error) {
	resp, err := client.Get(url)
	if err != nil {
		return nil, nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != 200 {
		return nil, nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	limited := io.LimitReader(resp.Body, 10*1024*1024)
	scanner := bufio.NewScanner(limited)

	var ips []string
	var nets []*net.IPNet

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
			continue
		}
		// Strip a trailing comment, e.g. "1.2.3.0/24 ; SBL123".
		if idx := strings.IndexAny(line, ";#"); idx > 0 {
			line = strings.TrimSpace(line[:idx])
		}

		if strings.Contains(line, "/") {
			if _, cidr, err := net.ParseCIDR(line); err == nil {
				nets = append(nets, cidr)
			}
			continue
		}

		ip := strings.Fields(line)[0]
		if net.ParseIP(ip) != nil {
			ips = append(ips, ip)
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, nil, fmt.Errorf("reading %s feed: %w", name, err)
	}

	return ips, nets, nil
}
// saveLines sorts lines in place (stable output makes the files diffable) and
// writes them to path, one per line, truncating any existing file. Errors are
// ignored: the cache is best-effort.
func saveLines(path string, lines []string) {
	sort.Strings(lines)

	// #nosec G304 -- path is filepath.Join under operator-configured statePath.
	out, err := os.Create(path)
	if err != nil {
		return
	}
	defer func() { _ = out.Close() }()

	buffered := bufio.NewWriter(out)
	for i := range lines {
		fmt.Fprintln(buffered, lines[i])
	}
	_ = buffered.Flush()
}
// loadLines returns the trimmed, non-blank, non-comment ('#') lines of the
// file at path, or nil when the file cannot be opened.
func loadLines(path string) []string {
	file, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = file.Close() }()

	var kept []string
	for sc := bufio.NewScanner(file); sc.Scan(); {
		text := strings.TrimSpace(sc.Text())
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		kept = append(kept, text)
	}
	return kept
}
package checks
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
)
// uidCache caches uid -> username from /etc/passwd. The first Lookup of an
// unknown uid reads and parses the file; subsequent lookups return from the
// in-memory map. Process-lifetime: callers that need fresh data after a
// useradd should call Refresh().
//
// Unresolvable uids are cached too, as the placeholder "uid:<n>", so a miss
// triggers only one re-read of the file until Refresh is called.
type uidCache struct {
// path is the passwd-format file to parse.
path string
// mu guards m.
mu sync.RWMutex
// m maps uid to username, or to the "uid:<n>" placeholder for a miss.
m map[uint32]string
}
// defaultUIDCache is the shared cache behind LookupUser, backed by /etc/passwd.
var defaultUIDCache = newUIDCache("/etc/passwd")
// newUIDCache returns an empty cache that will read the passwd-format file at
// path on the first lookup.
func newUIDCache(path string) *uidCache {
	cache := &uidCache{
		path: path,
		m:    make(map[uint32]string),
	}
	return cache
}
// LookupUser resolves uid to a username through the shared daemon-wide cache.
// When the uid cannot be resolved it returns the placeholder "uid:<n>". Safe
// for concurrent use.
func LookupUser(uid uint32) string {
	return defaultUIDCache.Lookup(uid)
}
// swapDefaultUIDCacheForTest points the shared uid cache at the file at path
// and returns a function that reinstates the previous cache. Test-only: the
// cache reads its file directly rather than through osFS, so tests that stub
// /etc/passwd via osFS need this shim to make LookupUser see a fixture file.
func swapDefaultUIDCacheForTest(path string) func() {
	original := defaultUIDCache
	defaultUIDCache = newUIDCache(path)
	restore := func() { defaultUIDCache = original }
	return restore
}
// Lookup returns the username for uid, or the placeholder "uid:<n>" when the
// passwd file has no such uid.
//
// A cached answer is served under the read lock. On a miss the write lock is
// taken, the cache is re-checked (another goroutine may have filled it in the
// meantime), the file is parsed, and the result is cached, including a
// negative result, so repeated lookups of an unknown uid do not re-read the
// file.
func (c *uidCache) Lookup(uid uint32) string {
	c.mu.RLock()
	cached, hit := c.m[uid]
	c.mu.RUnlock()
	if hit {
		return cached
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if cached, hit = c.m[uid]; hit {
		return cached
	}

	c.parseLocked()

	name, found := c.m[uid]
	if !found {
		name = fmt.Sprintf("uid:%d", uid)
		c.m[uid] = name
	}
	return name
}
// Refresh empties the cache, including cached misses. The next Lookup parses
// the passwd file again.
func (c *uidCache) Refresh() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.m = make(map[uint32]string)
}
// parseLocked reads the passwd file and merges every "name:passwd:uid:..."
// line into the cache, overwriting existing entries for the same uid; entries
// for uids not present in the file are left as they are. Lines with fewer
// than three fields or a uid that is not a 32-bit unsigned integer are
// skipped. If the file cannot be read the cache is left unchanged. Caller
// holds the write lock.
func (c *uidCache) parseLocked() {
	raw, err := os.ReadFile(c.path)
	if err != nil {
		return
	}
	records := strings.Split(string(raw), "\n")
	for i := range records {
		cols := strings.SplitN(records[i], ":", 4)
		if len(cols) < 3 {
			continue
		}
		parsed, perr := strconv.ParseUint(cols[2], 10, 32)
		if perr != nil {
			continue
		}
		c.m[uint32(parsed)] = cols[0]
	}
}
package checks
import (
"net"
"net/url"
"strings"
)
// URL reputation — attack-indicator based classifier for external
// <script src="..."> URLs embedded in WordPress content.
//
// PHILOSOPHY
//
// Earlier versions of this package classified script sources against a
// hardcoded allowlist of "known safe" domains (Google Tag Manager,
// Cloudflare CDN, HubSpot, Stripe, etc.). In practice the allowlist is
// unmaintainable: every new widget service (OneTrust, Issuu, regional
// video embeds, tax-form widgets, etc.) adds an entry and operators
// still see HIGH-severity findings for legitimate third-party embeds.
//
// This file takes the inverse approach: rather than asking "is this
// domain on my list of safe services?", it asks "does this URL show
// attacker-characteristic markers?". A <script src> fires a finding
// only when at least one attack indicator is present:
//
// - the host is a raw IP address (attackers dodge domain reputation);
// - the host TLD is on a well-known abused-TLD list (.tk, .ml, .ga,
// .cf, .gq free-abuse; .top, .icu, .click, .pw and similar cheap
// gTLDs per Spamhaus recurring bad-TLD reports);
// - the host is on the existing short known-bad-exfil list
// (Cloudflare Workers free tier, Pastebin raw, GitHub Gist raw —
// legitimate content rarely loads from these hosts);
// - the scheme is plaintext HTTP (an external JS loader without TLS
// is a plaintext MITM target regardless of the destination);
// - the host is empty, contains no dot, or otherwise fails basic
// FQDN validation.
//
// The knownSafeDomains list is retained as a FAST-PATH tie-breaker: if
// the host matches a well-known service we return "not malicious"
// immediately, skipping further analysis. It is an optimization, not
// the primary filter. Unknown hosts on unremarkable TLDs (e.g.
// onetrust.com, issuu.com, trilulilu.ro, formular230.ro) pass because
// they have zero attack indicators — no allowlist growth needed.
//
// TRADE-OFF
//
// An attacker who hosts payload on a compromised mainstream domain
// (.com/.org/.net HTTPS, normal-looking path) is not caught. This gap
// existed under the prior allowlist too — it is a fundamental limit of
// URL-only classification. Closing it requires threat-intelligence
// correlation or content-based JS analysis, both of which are out of
// scope here. The defensive value is in raising the bar for casual
// injection, not in defeating sophisticated adversaries.
// abusedTLDs are top-level domains whose registrations are cheap,
// unverified, and overwhelmingly abused for phishing, malware, and SEO
// spam per recurring Spamhaus and KnowBe4 reports.
//
// The entry bar is high: a TLD is included only if (a) it has shown up
// in top-10 abuse rankings across multiple reporting years, AND
// (b) legitimate business usage is rare. Mixed-use TLDs with
// significant legitimate traffic (.xyz, .online, .site, .live, .space)
// are intentionally excluded to keep false-positive rates low.
//
// Keys are lowercase and carry no leading dot. scriptSrcStrongReason
// looks up the part of the lowercased host that follows its final dot,
// so the observed TLD never has a dot that would need stripping.
var abusedTLDs = map[string]bool{
// Former Freenom TLDs — free registration, no verification.
// Near-100% abuse rate; Freenom itself was shut down in 2023 but
// the TLDs remain in DNS.
"tk": true,
"ml": true,
"ga": true,
"cf": true,
"gq": true,
// Cheap new gTLDs consistently in the Spamhaus top-abused list.
"top": true,
"icu": true,
"click": true,
"pw": true,
"loan": true,
"work": true,
"download": true,
// Spamhaus badlist recurring entries — legitimate usage is
// essentially nonexistent at meaningful volume.
"kim": true,
"gdn": true,
"stream": true,
"bid": true,
"racing": true,
"win": true,
"party": true,
"science": true,
"trade": true,
}
// knownBadExfilHosts are hosts from which legitimate WordPress content
// essentially never loads JavaScript, but from which attackers
// routinely do. The list is intentionally short — see knownSafeDomains
// for the inverse fast-path.
//
// Entries are lowercase host names and are compared against the
// lowercased URL host by scriptSrcStrongReason. An entry may start with
// a dot (".workers.dev"); the dot is trimmed for the exact-host
// comparison.
var knownBadExfilHosts = []string{
// Cloudflare Workers free-tier subdomain (e.g. x.workers.dev).
// Legitimate apps host on custom domains; free .workers.dev is a
// common payload drop.
".workers.dev",
// Pastebin and GitHub Gist raw endpoints — legitimate sites do not
// load JS from these hosts.
"pastebin.com",
"gist.githubusercontent.com",
// URL shorteners (bit.ly and similar) in a <script src> are a very
// strong attack signal — no legitimate embed uses a shortener for
// a JS asset.
"bit.ly",
"tinyurl.com",
"is.gd",
"cutt.ly",
}
// scriptSrcStrongReason classifies a <script src> URL by structural
// attack indicators that are context-independent: a raw IP host, an
// abused TLD, a known-bad exfil host, or an empty/unparseable/no-TLD
// host. These markers are rare-to-nonexistent in legitimate content of
// any age and remain valid signals whether the script appears in a
// freshly-written wp_options value or in decade-old post_content.
//
// Callers that operate on storage which is expected to hold current
// configuration (wp_options) should prefer scriptSrcMaliciousReason,
// which layers the plaintext-HTTP indicator on top.
//
// Known-bad hosts are matched on DNS label boundaries: a host matches
// an entry when it equals the entry (leading dot ignored) or is a
// subdomain of it. A plain suffix comparison is deliberately avoided
// because it would also flag unrelated registrations that merely end in
// the same characters (rabbit.ly vs bit.ly, this.gd vs is.gd).
//
// The function does not consult knownSafeDomains — callers should do
// that first as a fast path.
//
// It returns (true, reason) when an indicator is present and
// (false, "") otherwise. Checks run in a fixed order, so the reason
// reflects the first indicator hit: unparseable URL, empty host, raw
// IP, known-bad host, missing TLD, abused TLD.
func scriptSrcStrongReason(rawURL string) (bool, string) {
	// Protocol-relative URLs have no scheme; give them one so url.Parse
	// treats the leading component as the host rather than as a path.
	normalised := rawURL
	if strings.HasPrefix(normalised, "//") {
		normalised = "https:" + normalised
	}
	u, err := url.Parse(normalised)
	if err != nil || u == nil {
		return true, "unparseable URL"
	}
	host := strings.ToLower(u.Hostname())
	if host == "" {
		return true, "empty host"
	}
	if net.ParseIP(host) != nil {
		return true, "raw IP address host"
	}
	for _, bad := range knownBadExfilHosts {
		bare := strings.TrimPrefix(bad, ".")
		// Exact host, or a subdomain of it ("."+bare keeps the match on
		// a label boundary).
		if host == bare || strings.HasSuffix(host, "."+bare) {
			return true, "known-bad exfil host: " + bad
		}
	}
	lastDot := strings.LastIndexByte(host, '.')
	if lastDot < 0 || lastDot == len(host)-1 {
		return true, "host without valid TLD"
	}
	tld := host[lastDot+1:]
	if abusedTLDs[tld] {
		return true, "abused TLD: ." + tld
	}
	return false, ""
}
// scriptSrcMaliciousReason applies the full indicator set to a
// <script src> URL: every structural marker reported by
// scriptSrcStrongReason, plus plaintext HTTP. It is meant for
// wp_options and similar configuration storage, where an external
// script loaded over plain HTTP is treated as a signal by itself.
//
// Structural markers win over the scheme: when scriptSrcStrongReason
// flags the URL its reason is returned unchanged, so a raw-IP host
// served over HTTP is reported as a raw IP.
//
// Protocol-relative URLs (//host/path) inherit the page's scheme and
// are never reported as plaintext HTTP.
//
// The function does not consult knownSafeDomains — callers should do
// that first as a fast path.
func scriptSrcMaliciousReason(rawURL string) (bool, string) {
	if flagged, why := scriptSrcStrongReason(rawURL); flagged {
		return true, why
	}

	protocolRelative := strings.HasPrefix(rawURL, "//")
	candidate := rawURL
	if protocolRelative {
		candidate = "https:" + rawURL
	}

	parsed, parseErr := url.Parse(candidate)
	switch {
	case parseErr != nil || parsed == nil:
		// Defensive: the structural classifier normally reports this first.
		return true, "unparseable URL"
	case !protocolRelative && strings.EqualFold(parsed.Scheme, "http"):
		return true, "plaintext HTTP external script"
	default:
		return false, ""
	}
}
// isAttackerScriptURL reports whether rawURL should be treated as an
// attacker-controlled script source in contexts where plaintext HTTP is
// itself a signal (wp_options and similar configuration storage).
//
// The known-safe check runs first: a URL accepted by isSafeScriptDomain
// is never flagged, whatever the indicator classifier would say. Every
// other URL gets the verdict of scriptSrcMaliciousReason; the reason
// string is discarded.
func isAttackerScriptURL(rawURL string) bool {
	if !isSafeScriptDomain(rawURL) {
		flagged, _ := scriptSrcMaliciousReason(rawURL)
		return flagged
	}
	return false
}
// isAttackerScriptURLInPost is the post_content counterpart of
// isAttackerScriptURL. It consults scriptSrcStrongReason, which leaves
// out the plaintext-HTTP indicator: author embeds written before TLS
// was the norm are legitimate content, not injection, and must not
// produce db_post_injection findings.
//
// Structural markers (raw IP hosts, abused TLDs, known-bad exfil hosts,
// hosts without a TLD) still flag. A URL accepted by isSafeScriptDomain
// is never flagged.
func isAttackerScriptURLInPost(rawURL string) bool {
	if !isSafeScriptDomain(rawURL) {
		flagged, _ := scriptSrcStrongReason(rawURL)
		return flagged
	}
	return false
}
package checks
import "bytes"
// .user.ini cPanel-managed signature detection.
//
// cPanel's MultiPHP INI Editor writes .user.ini files with a fixed
// four-line header that begins:
//
// ; cPanel-generated php ini directives, do not edit
// ; Manual editing of this file may result in unexpected behavior.
// ; To make changes to this file, use the cPanel MultiPHP INI Editor ...
// ; For more information, read our documentation ...
//
// When this header is present the values in the file reflect operator
// choices made through cPanel's UI (max_execution_time=0 for a backup
// importer, display_errors=On for a staging account, etc.). These are
// not attacker signals. The severity of findings on values in a
// cPanel-managed file should be reduced to informational.
//
// When the header is absent we cannot tell whether the site owner
// hand-edited the file or an attacker planted it. In that case we
// preserve the original severity.
//
// Detection rule (minimal and precise):
//
// The signature must appear on the FIRST non-blank line of the file.
//
// Reasoning: cPanel always writes the header at the top of the file
// and rewrites the whole file on every edit through the UI. An
// attacker appending content below keeps the header position intact.
// An attacker prepending content above pushes the header down; the
// file is then no longer a regular cPanel-managed file and the
// attacker's values take precedence anyway — we must NOT accept this
// as "managed" because doing so would let attackers suppress findings
// by inserting one of their own lines and then the real cPanel header.
// cpanelUserIniSignature is the exact text cPanel writes on the first
// comment line of a managed .user.ini (the full line continues with
// ", do not edit"). Matching is case-sensitive: alternate
// capitalizations are rejected to avoid accepting forgeries.
const cpanelUserIniSignature = "cPanel-generated php ini directives"
// cpanelUserIniMaxLeadingBlanks caps how many blank lines may precede
// the signature. cPanel itself writes the signature as the very first
// line, so the tolerance exists only for files that picked up a few
// leading empty lines after being round-tripped through a text editor
// or an FTP client. Keeping the cap small also closes a forgery route:
// without a bound, a file could push the header arbitrarily far down
// behind a long run of blank lines and a first-non-blank-line check
// would still classify it as managed.
const cpanelUserIniMaxLeadingBlanks = 5
// isCpanelManagedUserIni reports whether data begins with the cPanel
// MultiPHP-managed .user.ini header.
//
// The first non-blank line must be an INI comment (it starts with ';')
// whose text begins with cpanelUserIniSignature. Requiring a comment
// line that starts with the signature, rather than any line that merely
// contains it, stops a directive line with the signature tucked into a
// trailing comment ("auto_prepend_file=x ; cPanel-generated ...") from
// being classified as managed.
//
// At most cpanelUserIniMaxLeadingBlanks blank lines may precede the
// header; beyond that the file is not considered cPanel-managed even if
// the signature appears later. Lines are split on '\n'; surrounding
// whitespace (including a trailing '\r') is ignored, and a leading
// UTF-8 byte-order mark is skipped. Empty or whitespace-only input
// returns false.
func isCpanelManagedUserIni(data []byte) bool {
	data = bytes.TrimPrefix(data, []byte("\xef\xbb\xbf"))
	blanks := 0
	for len(data) > 0 {
		line := data
		if nl := bytes.IndexByte(data, '\n'); nl >= 0 {
			line, data = data[:nl], data[nl+1:]
		} else {
			data = nil
		}
		line = bytes.TrimSpace(line)
		if len(line) == 0 {
			blanks++
			if blanks > cpanelUserIniMaxLeadingBlanks {
				return false
			}
			continue
		}
		// The first non-blank line decides the outcome either way.
		if line[0] != ';' {
			return false
		}
		comment := bytes.TrimLeft(line, "; \t")
		return bytes.HasPrefix(comment, []byte(cpanelUserIniSignature))
	}
	return false
}
package checks
import (
"strings"
)
// indirectFuncKind is a bit set describing the category of a PHP
// function name that is invoked through a variable (`$f = "name";
// $f(...)`). Each constant below occupies its own bit so kinds can be
// tested with a bitwise AND.
type indirectFuncKind uint8
const (
// indirectDecoder (value 1) marks decode/decompress names such as
// base64_decode or gzinflate.
indirectDecoder indirectFuncKind = 1 << iota
// indirectEval (value 2) marks names such as eval, assert and
// create_function.
indirectEval
// indirectShell (value 4) marks command-execution names such as
// system, exec and popen.
indirectShell
)
// indirectFuncKinds maps a PHP function name to its category. Keys are
// lowercase; findIndirectAssignments looks names up after
// normalizeIndirectFunctionName has lowercased them and removed leading
// backslashes. A name that is absent yields the zero kind.
var indirectFuncKinds = map[string]indirectFuncKind{
// Decoders and decompressors.
"base64_decode": indirectDecoder,
"gzinflate": indirectDecoder,
"gzuncompress": indirectDecoder,
"gzdecode": indirectDecoder,
"bzdecompress": indirectDecoder,
"str_rot13": indirectDecoder,
"rawurldecode": indirectDecoder,
// Code evaluation.
"eval": indirectEval,
"assert": indirectEval,
"create_function": indirectEval,
// OS command execution.
"system": indirectShell,
"passthru": indirectShell,
"exec": indirectShell,
"shell_exec": indirectShell,
"popen": indirectShell,
"proc_open": indirectShell,
"pcntl_exec": indirectShell,
}
// indirectAssignment records one `$var = ...` assignment found by
// findIndirectAssignments.
type indirectAssignment struct {
// kind is the category of the assigned string literal, or zero when
// the value is not a quoted string naming a tracked function.
kind indirectFuncKind
// pos is a byte offset into the scanned code: just past the closing
// quote for a quoted value, otherwise just past the '='. A call
// only sees this assignment when its own offset is >= pos.
pos int
}
// indirectCall records one `$var(` invocation whose variable held a
// tracked function name at that point, as found by findIndirectCalls.
type indirectCall struct {
// variable is the variable name without the leading '$'.
variable string
// kind is the category of the name the variable held at the call.
kind indirectFuncKind
// pos is the byte offset of the '$' that starts the call.
pos int
// lineStart and lineEnd bound the physical line containing pos, as
// returned by phpLineBounds (lineEnd is exclusive).
lineStart int
lineEnd int
// line is code[lineStart:lineEnd].
line string
}
// detectVarFuncDangerousAssignment reports whether content contains a
// `$var = "dangerous_name"` assignment followed by a `$var(` call that
// forms an execution sink. A decoder-only callback is not enough:
// direct base64_decode() calls are common in legitimate plugins, and
// indirect ones get the same restraint.
//
// Comments are stripped first. A call counts as a sink when, on its own
// line:
//   - a shell-kind call appears together with a request superglobal;
//   - an eval-kind call appears together with a request superglobal, a
//     direct decoder call, or another indirect decoder call;
//   - a decoder-kind call appears together with a direct eval call or
//     another indirect eval call.
func detectVarFuncDangerousAssignment(content string) bool {
	code := stripPHPCommentsFromCode(content)
	tracked := findIndirectAssignments(code)
	if len(tracked) == 0 {
		return false
	}
	calls := findIndirectCalls(code, tracked)
	for _, c := range calls {
		// The first kind that matches decides which rule applies; a call
		// is never re-tested under a later rule.
		if c.kind&indirectShell != 0 {
			if lineContainsRequestVar(c.line) {
				return true
			}
		} else if c.kind&indirectEval != 0 {
			if lineContainsRequestVar(c.line) {
				return true
			}
			if lineContainsDirectDecoderCall(c.line) {
				return true
			}
			if lineContainsIndirectCallKind(calls, c.lineStart, c.lineEnd, indirectDecoder) {
				return true
			}
		} else if c.kind&indirectDecoder != 0 {
			if lineContainsDirectEvalCall(c.line) {
				return true
			}
			if lineContainsIndirectCallKind(calls, c.lineStart, c.lineEnd, indirectEval) {
				return true
			}
		}
	}
	return false
}
// findIndirectAssignments scans code for `$name = ...` assignments and
// returns them grouped by variable name (without the '$'), in source
// order. Quoted strings are skipped so a '$' inside a literal is not
// mistaken for a variable. `==`, `===` and `=>` are not assignments.
//
// Every assignment is recorded, including ones whose value is not a
// tracked function name (kind zero), so a later reassignment can
// override an earlier dangerous one.
func findIndirectAssignments(code string) map[string][]indirectAssignment {
	found := make(map[string][]indirectAssignment)
	for pos := 0; pos < len(code); pos++ {
		if isPHPQuote(code[pos]) {
			pos = skipPHPString(code, pos)
			continue
		}
		if code[pos] != '$' {
			continue
		}
		name, afterName, ok := readPHPVariableName(code, pos)
		if !ok {
			continue
		}
		// Whatever follows, scanning resumes right after the name.
		pos = afterName - 1

		eq := skipPHPWhitespace(code, afterName)
		if eq >= len(code) || code[eq] != '=' {
			continue
		}
		if eq+1 < len(code) {
			if following := code[eq+1]; following == '=' || following == '>' {
				continue
			}
		}

		entry := indirectAssignment{pos: eq + 1}
		if v := skipPHPWhitespace(code, eq+1); v < len(code) && isPHPQuote(code[v]) {
			literal, literalEnd, valid := readPHPFunctionString(code, v)
			entry.pos = literalEnd
			if valid {
				entry.kind = indirectFuncKinds[literal]
			}
		}
		found[name] = append(found[name], entry)
	}
	return found
}
// findIndirectCalls scans code for `$name(` invocations and returns the
// ones whose variable held a tracked function name at the call site,
// according to assignments. Quoted strings are skipped. Each result
// carries the physical line that contains the call.
func findIndirectCalls(code string, assignments map[string][]indirectAssignment) []indirectCall {
	var out []indirectCall
	for pos := 0; pos < len(code); pos++ {
		if isPHPQuote(code[pos]) {
			pos = skipPHPString(code, pos)
			continue
		}
		if code[pos] != '$' {
			continue
		}
		name, afterName, ok := readPHPVariableName(code, pos)
		if !ok {
			continue
		}
		dollar := pos
		// Whatever follows, scanning resumes right after the name.
		pos = afterName - 1

		paren := skipPHPWhitespace(code, afterName)
		if paren >= len(code) || code[paren] != '(' {
			continue
		}
		kind, known := indirectKindAt(assignments[name], dollar)
		if !known {
			continue
		}
		from, to := phpLineBounds(code, dollar)
		out = append(out, indirectCall{
			variable:  name,
			kind:      kind,
			pos:       dollar,
			lineStart: from,
			lineEnd:   to,
			line:      code[from:to],
		})
	}
	return out
}
// indirectKindAt returns the kind held by a variable at byte offset
// callPos, given that variable's assignments in source order. The last
// assignment whose pos is <= callPos wins. The boolean is false when no
// assignment precedes the call or when the winning assignment has the
// zero kind.
func indirectKindAt(assignments []indirectAssignment, callPos int) (indirectFuncKind, bool) {
	winner := -1
	for i := range assignments {
		if assignments[i].pos > callPos {
			break
		}
		winner = i
	}
	if winner < 0 {
		return 0, false
	}
	kind := assignments[winner].kind
	return kind, kind != 0
}
// lineContainsDirectDecoderCall reports whether line, with its string
// literals blanked and lowercased, contains a direct call to any
// function registered as a decoder in indirectFuncKinds.
func lineContainsDirectDecoderCall(line string) bool {
	lowered := strings.ToLower(stripPHPStringsFromCode(line))
	for fn, kind := range indirectFuncKinds {
		if kind&indirectDecoder == 0 {
			continue
		}
		if containsStandaloneFunc(lowered, fn+"(") {
			return true
		}
	}
	return false
}
// lineContainsDirectEvalCall reports whether line, with its string
// literals blanked and lowercased, contains a direct call to any
// function registered as an eval-kind name in indirectFuncKinds.
func lineContainsDirectEvalCall(line string) bool {
	lowered := strings.ToLower(stripPHPStringsFromCode(line))
	for fn, kind := range indirectFuncKinds {
		if kind&indirectEval == 0 {
			continue
		}
		if containsStandaloneFunc(lowered, fn+"(") {
			return true
		}
	}
	return false
}
func lineContainsRequestVar(line string) bool {
return containsRequestSuperglobal(stripPHPStringsFromCode(line))
}
// lineContainsIndirectCallKind reports whether calls contains an
// indirect call of the given kind on the line bounded by lineStart and
// lineEnd. Both bounds must match exactly.
func lineContainsIndirectCallKind(calls []indirectCall, lineStart, lineEnd int, kind indirectFuncKind) bool {
	for i := range calls {
		c := &calls[i]
		if c.kind&kind == 0 {
			continue
		}
		if c.lineStart == lineStart && c.lineEnd == lineEnd {
			return true
		}
	}
	return false
}
// stripPHPCommentsFromCode returns code with PHP comments replaced by
// spaces. Line breaks inside comments are kept and every other byte is
// copied unchanged, so the result has the same length as the input and
// byte offsets and line numbers still line up.
//
// Quoted strings and heredoc/nowdoc bodies are copied verbatim so that
// a '#', '//' or '/*' inside them is not treated as a comment.
//
// A one-line comment ('//' or '#') ends at the line break or at a
// closing '?>' tag, whichever comes first, matching PHP itself. The
// '?>' and everything after it are kept, so code placed after
// `// ... ?><?php` on the same line is not blanked out.
func stripPHPCommentsFromCode(code string) string {
	var b strings.Builder
	b.Grow(len(code))
	for i := 0; i < len(code); i++ {
		// Heredoc/nowdoc bodies are string literals, not code.
		if label, bodyStart, ok := phpHeredocOpen(code, i); ok {
			end := phpHeredocEnd(code, bodyStart, label)
			b.WriteString(code[i:end])
			i = end - 1
			continue
		}
		if isPHPQuote(code[i]) {
			i = copyPHPString(&b, code, i)
			continue
		}
		if code[i] == '/' && i+1 < len(code) && code[i+1] == '*' {
			b.WriteString("  ")
			i += 2
			for i < len(code) {
				if code[i] == '*' && i+1 < len(code) && code[i+1] == '/' {
					b.WriteString("  ")
					i++
					break
				}
				writeCommentReplacementByte(&b, code[i])
				i++
			}
			continue
		}
		markerLen := 0
		switch {
		case code[i] == '#':
			markerLen = 1
		case code[i] == '/' && i+1 < len(code) && code[i+1] == '/':
			markerLen = 2
		}
		if markerLen > 0 {
			i = blankPHPLineComment(&b, code, i, markerLen)
			continue
		}
		b.WriteByte(code[i])
	}
	return b.String()
}

// blankPHPLineComment blanks the one-line comment whose marker ('#' or
// '//', markerLen bytes long) starts at code[start]. It returns the
// index of the last byte it consumed: the line break that ended the
// comment (already written to b), the byte just before a closing '?>'
// (the tag itself is left for the caller to copy), or the last byte of
// code when the comment runs to the end of the input.
func blankPHPLineComment(b *strings.Builder, code string, start, markerLen int) int {
	for k := 0; k < markerLen; k++ {
		b.WriteByte(' ')
	}
	i := start + markerLen
	for i < len(code) {
		switch {
		case code[i] == '\n' || code[i] == '\r':
			b.WriteByte(code[i])
			return i
		case code[i] == '?' && i+1 < len(code) && code[i+1] == '>':
			return i - 1
		}
		b.WriteByte(' ')
		i++
	}
	return len(code) - 1
}
// stripPHPStringsFromCode returns code with the contents of string
// literals replaced by spaces. Quoted strings are blanked by
// replacePHPString; heredoc/nowdoc constructs are blanked from the
// opening `<<<` through the closing label. Line breaks are kept in both
// cases, so the contents are not analysed as code and a quote inside a
// heredoc body cannot desync the scanner.
func stripPHPStringsFromCode(code string) string {
	var out strings.Builder
	out.Grow(len(code))
	for pos := 0; pos < len(code); pos++ {
		label, bodyStart, isHeredoc := phpHeredocOpen(code, pos)
		switch {
		case isHeredoc:
			end := phpHeredocEnd(code, bodyStart, label)
			for _, c := range []byte(code[pos:end]) {
				if c != '\n' && c != '\r' {
					c = ' '
				}
				out.WriteByte(c)
			}
			pos = end - 1
		case isPHPQuote(code[pos]):
			pos = replacePHPString(&out, code, pos)
		default:
			out.WriteByte(code[pos])
		}
	}
	return out.String()
}
// phpHeredocOpen reports whether a heredoc or nowdoc opener starts at
// code[i]. It accepts `<<<LABEL`, `<<<"LABEL"` (heredoc) and
// `<<<'LABEL'` (nowdoc), with optional spaces or tabs after `<<<`. Only
// spaces, tabs and carriage returns may follow the label before the
// newline that ends the opening line.
//
// On success it returns the label and the index of the first body byte
// (just past that newline). A `<<<` that is directly preceded by
// another '<' is not an opener.
func phpHeredocOpen(code string, i int) (label string, bodyStart int, ok bool) {
	if i+3 > len(code) || code[i:i+3] != "<<<" {
		return "", 0, false
	}
	if i > 0 && code[i-1] == '<' {
		return "", 0, false
	}
	isBlank := func(c byte) bool { return c == ' ' || c == '\t' }

	cur := i + 3
	for cur < len(code) && isBlank(code[cur]) {
		cur++
	}

	// Optional opening quote; zero means the label is bare.
	quote := byte(0)
	if cur < len(code) {
		if c := code[cur]; c == '\'' || c == '"' {
			quote = c
			cur++
		}
	}

	if cur >= len(code) || !isPHPIdentifierStart(code[cur]) {
		return "", 0, false
	}
	nameStart := cur
	for cur++; cur < len(code) && isPHPIdentifierPart(code[cur]); cur++ {
	}
	name := code[nameStart:cur]

	if quote != 0 {
		if cur >= len(code) || code[cur] != quote {
			return "", 0, false
		}
		cur++
	}

	for cur < len(code) && (isBlank(code[cur]) || code[cur] == '\r') {
		cur++
	}
	if cur >= len(code) || code[cur] != '\n' {
		return "", 0, false
	}
	return name, cur + 1, true
}
// phpHeredocEnd returns the index just past the closing label of a
// heredoc whose body starts at bodyStart. The closing label may be
// indented with spaces or tabs (PHP 7.3+ syntax) but must be the first
// thing on its line and must be followed by a byte that cannot continue
// an identifier, or by the end of the input. When no closing label is
// found the heredoc is treated as running to the end: len(code) is
// returned.
func phpHeredocEnd(code string, bodyStart int, label string) int {
	for lineStart := bodyStart; lineStart < len(code); {
		lineEnd := len(code)
		if nl := strings.IndexByte(code[lineStart:], '\n'); nl >= 0 {
			lineEnd = lineStart + nl
		}

		indent := lineStart
		for indent < lineEnd && (code[indent] == ' ' || code[indent] == '\t') {
			indent++
		}

		if strings.HasPrefix(code[indent:lineEnd], label) {
			after := indent + len(label)
			if after == len(code) || !isPHPIdentifierPart(code[after]) {
				return after
			}
		}
		lineStart = lineEnd + 1
	}
	return len(code)
}
// copyPHPString copies the quoted PHP string that opens at code[start]
// into b unchanged, quotes included, and returns the index of the
// closing quote. A backslash escapes the byte after it, so an escaped
// quote does not end the string. An unterminated string is copied to
// the end of code and len(code)-1 is returned.
func copyPHPString(b *strings.Builder, code string, start int) int {
	quote := code[start]
	end := len(code) - 1
	for i := start + 1; i < len(code); i++ {
		if code[i] == '\\' && i+1 < len(code) {
			i++
			continue
		}
		if code[i] == quote {
			end = i
			break
		}
	}
	b.WriteString(code[start : end+1])
	return end
}
// replacePHPString writes a blanked copy of the quoted PHP string that
// opens at code[start] into b: every byte, quotes included, becomes a
// space, except '\n' and '\r' which are kept. It returns the index of
// the closing quote. A backslash escapes the byte after it. An
// unterminated string is blanked to the end of code and len(code)-1 is
// returned.
func replacePHPString(b *strings.Builder, code string, start int) int {
	quote := code[start]
	end := len(code) - 1
	for i := start + 1; i < len(code); i++ {
		if code[i] == '\\' && i+1 < len(code) {
			i++
			continue
		}
		if code[i] == quote {
			end = i
			break
		}
	}

	b.WriteByte(' ') // opening quote
	for k := start + 1; k <= end; k++ {
		switch code[k] {
		case '\n', '\r':
			b.WriteByte(code[k])
		default:
			b.WriteByte(' ')
		}
	}
	return end
}
// writeCommentReplacementByte writes the stand-in for one comment byte
// to b: line breaks ('\n', '\r') are kept, anything else becomes a
// space.
func writeCommentReplacementByte(b *strings.Builder, c byte) {
	switch c {
	case '\n', '\r':
		b.WriteByte(c)
	default:
		b.WriteByte(' ')
	}
}
// readPHPVariableName reads the variable name that follows the '$' at
// code[dollar]. On success it returns the name (without the '$'), the
// index just past the name, and true. When the byte after '$' cannot
// start an identifier it returns "", dollar+1 and false.
func readPHPVariableName(code string, dollar int) (string, int, bool) {
	nameStart := dollar + 1
	if nameStart >= len(code) || !isPHPIdentifierStart(code[nameStart]) {
		return "", nameStart, false
	}
	nameEnd := nameStart + 1
	for ; nameEnd < len(code); nameEnd++ {
		if !isPHPIdentifierPart(code[nameEnd]) {
			break
		}
	}
	return code[nameStart:nameEnd], nameEnd, true
}
// readPHPFunctionString reads the quoted PHP string that opens at
// code[start] and interprets its value as a function name.
//
// It returns the normalised name (see normalizeIndirectFunctionName),
// the index just past the closing quote, and whether the value is a
// valid identifier. An unterminated string yields "", len(code), false.
//
// Escape handling follows PHP closely enough to recover a name:
//   - `\\` and an escaped quote of the string's own kind are unescaped
//     in both quote styles;
//   - inside double quotes, numeric escapes are decoded: `\xHH` (one or
//     two hex digits), `\NNN` (one to three octal digits) and `\u{H+}`
//     for code points below 0x80. Without this a name written as
//     "\x73ystem" or "\145val" would not be recognised;
//   - any other backslash is kept literally, which leaves the value a
//     non-identifier unless the backslash is a leading namespace
//     separator.
func readPHPFunctionString(code string, start int) (string, int, bool) {
	quote := code[start]
	var b strings.Builder
	for i := start + 1; i < len(code); i++ {
		c := code[i]
		if c == quote {
			value, ok := normalizeIndirectFunctionName(b.String())
			return value, i + 1, ok
		}
		if c != '\\' || i+1 >= len(code) {
			b.WriteByte(c)
			continue
		}
		next := code[i+1]
		if next == quote || next == '\\' {
			b.WriteByte(next)
			i++
			continue
		}
		if quote == '"' {
			if decoded, width, ok := decodePHPNumericEscape(code, i+1); ok {
				b.WriteByte(decoded)
				i += width
				continue
			}
		}
		b.WriteByte(c)
	}
	return "", len(code), false
}

// decodePHPNumericEscape decodes the numeric escape whose first byte
// after the backslash is code[pos]. It returns the decoded byte, the
// number of bytes consumed after the backslash, and whether an escape
// was recognised. Octal values above 0xff wrap to a single byte, as in
// PHP. `\u{...}` is accepted only for code points below 0x80 because
// larger ones cannot appear in the names this scanner tracks.
func decodePHPNumericEscape(code string, pos int) (byte, int, bool) {
	switch c := code[pos]; {
	case c == 'x' || c == 'X':
		value, digits := 0, 0
		for digits < 2 && pos+1+digits < len(code) {
			d, ok := phpHexDigitValue(code[pos+1+digits])
			if !ok {
				break
			}
			value = value*16 + d
			digits++
		}
		if digits == 0 {
			return 0, 0, false
		}
		return byte(value), 1 + digits, true
	case c >= '0' && c <= '7':
		value, digits := 0, 0
		for digits < 3 && pos+digits < len(code) && code[pos+digits] >= '0' && code[pos+digits] <= '7' {
			value = value*8 + int(code[pos+digits]-'0')
			digits++
		}
		return byte(value), digits, true
	case c == 'u':
		if pos+1 >= len(code) || code[pos+1] != '{' {
			return 0, 0, false
		}
		value, j := 0, pos+2
		for ; j < len(code); j++ {
			d, ok := phpHexDigitValue(code[j])
			if !ok {
				break
			}
			value = value*16 + d
			if value > 0x7f {
				return 0, 0, false
			}
		}
		if j == pos+2 || j >= len(code) || code[j] != '}' {
			return 0, 0, false
		}
		return byte(value), j - pos + 1, true
	}
	return 0, 0, false
}

// phpHexDigitValue returns the numeric value of the hexadecimal digit c
// and whether c is a hexadecimal digit at all.
func phpHexDigitValue(c byte) (int, bool) {
	switch {
	case c >= '0' && c <= '9':
		return int(c - '0'), true
	case c >= 'a' && c <= 'f':
		return int(c-'a') + 10, true
	case c >= 'A' && c <= 'F':
		return int(c-'A') + 10, true
	}
	return 0, false
}
// normalizeIndirectFunctionName turns a string value into a lookup key
// for indirectFuncKinds: leading backslashes are removed and the rest
// is lowercased. It returns "", false when what remains is not a valid
// identifier according to isPHPIdentifier.
func normalizeIndirectFunctionName(value string) (string, bool) {
	bare := strings.TrimLeft(value, "\\")
	if isPHPIdentifier(bare) {
		return strings.ToLower(bare), true
	}
	return "", false
}
// skipPHPString returns the index of the quote that closes the PHP
// string opening at code[start]. A backslash escapes the byte after it.
// An unterminated string yields len(code)-1.
func skipPHPString(code string, start int) int {
	quote := code[start]
	i := start + 1
	for i < len(code) {
		switch {
		case code[i] == '\\' && i+1 < len(code):
			i += 2 // skip the backslash and the byte it escapes
		case code[i] == quote:
			return i
		default:
			i++
		}
	}
	return len(code) - 1
}
// skipPHPWhitespace returns the index of the first byte at or after
// start that is not a space, tab, newline, carriage return, form feed
// or vertical tab. It returns start unchanged when start is already
// past the end, and len(code) when only whitespace remains.
func skipPHPWhitespace(code string, start int) int {
	const blanks = " \t\n\r\f\v"
	for ; start < len(code); start++ {
		if strings.IndexByte(blanks, code[start]) < 0 {
			break
		}
	}
	return start
}
// phpLineBounds returns the half-open byte range [start, end) of the
// physical line that contains pos. Both '\n' and '\r' count as line
// breaks; the range never includes one.
func phpLineBounds(code string, pos int) (int, int) {
	const breaks = "\n\r"
	// LastIndexAny yields -1 when there is no earlier break, which the
	// +1 turns into the start of the input.
	start := strings.LastIndexAny(code[:pos], breaks) + 1
	end := len(code)
	if offset := strings.IndexAny(code[pos:], breaks); offset >= 0 {
		end = pos + offset
	}
	return start, end
}
// isPHPQuote reports whether c opens a single- or double-quoted PHP
// string.
func isPHPQuote(c byte) bool {
	switch c {
	case '\'', '"':
		return true
	}
	return false
}
// isPHPIdentifier reports whether value is a non-empty identifier: its
// first byte satisfies isPHPIdentifierStart and every later byte
// satisfies isPHPIdentifierPart.
func isPHPIdentifier(value string) bool {
	for i := 0; i < len(value); i++ {
		c := value[i]
		if i == 0 {
			if !isPHPIdentifierStart(c) {
				return false
			}
			continue
		}
		if !isPHPIdentifierPart(c) {
			return false
		}
	}
	return value != ""
}
// isPHPIdentifierStart reports whether c may begin a PHP identifier.
// PHP accepts an ASCII letter, an underscore, or any byte in the range
// 0x80-0xff (its documented pattern is [a-zA-Z_\x80-\xff]). The high
// range matters for detection: obfuscated code frequently names its
// variables with non-ASCII bytes, and a scanner limited to ASCII would
// not recognise `$ñ = "system"; $ñ(...)` as a variable at all.
func isPHPIdentifierStart(c byte) bool {
	switch {
	case c == '_', c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z':
		return true
	}
	return c >= 0x80
}
// isPHPIdentifierPart reports whether c may appear after the first byte
// of a PHP identifier: an ASCII digit, or anything accepted by
// isPHPIdentifierStart.
func isPHPIdentifierPart(c byte) bool {
	if c >= '0' && c <= '9' {
		return true
	}
	return isPHPIdentifierStart(c)
}
package checks
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// wafRulesAssembleRetryDelay is how long CheckWAFStatus waits between a
// negative rule probe and the single re-probe on cPanel+LiteSpeed
// hosts. On those hosts cPanel's nightly modsec_assemble briefly leaves
// both `whmapi1 modsec_get_vendors` and the vendor dir empty while it
// rewrites the tree in place. Observed windows are <10s; 30s gives
// margin without meaningfully delaying the surrounding deep-scan tier.
// It is a variable rather than a constant so tests can shorten it and
// keep the suite fast.
var wafRulesAssembleRetryDelay = 30 * time.Second
// CheckWAFStatus verifies that ModSecurity is loaded, the engine is in
// enforcement mode (not DetectionOnly), OWASP/Comodo rules are active,
// and rules are up to date.
//
// The config and state store arguments are unused. ctx is only
// consulted while waiting for the cPanel+LiteSpeed re-probe; if it is
// cancelled during that wait the findings gathered so far are returned.
//
// Findings it can emit, by Check name:
//   - waf_status (Critical): no ModSecurity module detected; nothing
//     else is checked in that case.
//   - waf_detection_only (High): SecRuleEngine is DetectionOnly.
//   - waf_rules (High): no rule source reports any rules.
//   - waf_rules_stale (Warning): checkRuleAge still reports a positive
//     age after the cPanel auto-update attempt.
//   - waf_bypass (High): one per domain returned by
//     checkPerAccountBypass (cPanel only).
//
// A host with no web server yields no findings. On cPanel hosts the
// function also calls deployVirtualPatches as a side effect.
func CheckWAFStatus(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
info := platform.Detect()
// If there is no web server at all, WAF concerns don't apply to this host.
if info.WebServer == platform.WSNone {
return findings
}
modsecActive := modsecDetected(info)
if !modsecActive {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "waf_status",
Message: "ModSecurity WAF is not active",
Details: wafInstallHint(info),
})
return findings // no point checking further
}
// --- Engine mode check ---
engineMode := checkEngineMode(info)
if engineMode == "detectiononly" {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_detection_only",
Message: "ModSecurity is in DetectionOnly mode - attacks are logged but NOT blocked",
Details: "SecRuleEngine is set to DetectionOnly. Change to 'On' for enforcement:\nWHM > Security Center > ModSecurity > Edit Global Directive",
})
}
// --- Rule vendor check ---
ruleDirs := modsecRuleDirs(info)
hasRules := probeWAFRules(info, ruleDirs)
// cPanel+LiteSpeed: cPanel's nightly modsec_assemble rewrites the
// vendor tree in place, so for ~6-10s both `whmapi1
// modsec_get_vendors` and the vendor dir return empty. A production
// false positive at 01:10:27 fired 6s after the rewrite. Re-probe
// once after a short delay before alerting; on a host that really
// has no rules, the re-probe is still negative and we alert in the
// same scan, so this doesn't shift detection to the next deep tier.
if !hasRules && info.IsCPanel() && info.WebServer == platform.WSLiteSpeed {
select {
case <-time.After(wafRulesAssembleRetryDelay):
case <-ctx.Done():
// Cancelled while waiting: report what is known, skip the rest.
return findings
}
hasRules = probeWAFRules(info, ruleDirs)
}
if !hasRules {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_rules",
Message: "ModSecurity has no WAF rules loaded",
Details: wafRulesHint(info),
})
}
// --- Rule age check + auto-update ---
if hasRules {
staleAge := checkRuleAge(ruleDirs)
if staleAge > 0 {
// Attempt auto-update before alerting
updated := false
if info.IsCPanel() {
updated = autoUpdateWAFRules()
}
if updated {
// Re-check age after update
staleAge = checkRuleAge(ruleDirs)
}
if staleAge > 0 {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "waf_rules_stale",
Message: fmt.Sprintf("ModSecurity rules are %d days old - update recommended", staleAge),
Details: wafRulesStaleHint(info),
})
}
}
}
// --- Virtual patch deployment ---
// Only cPanel has the modsec user config dirs we write into.
if info.IsCPanel() {
deployVirtualPatches()
}
// --- Per-account WAF bypass check ---
// whmapi1-only, skip on non-cPanel hosts.
if info.IsCPanel() {
bypassed := checkPerAccountBypass()
for _, domain := range bypassed {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_bypass",
Message: fmt.Sprintf("ModSecurity disabled for domain: %s", domain),
Details: "This domain has ModSecurity bypassed. All web attacks pass through unfiltered.\nCheck: WHM > Security Center > ModSecurity > Domains",
})
}
}
return findings
}
// probeWAFRules reports whether any WAF rule source currently shows
// rules. On cPanel it first asks `whmapi1 modsec_get_vendors` and
// accepts output that mentions a Comodo or OWASP vendor (in the
// spellings listed below). Otherwise, or when that output names
// neither, it falls back to hasRuleArtifacts on ruleDirs.
//
// CheckWAFStatus calls it once directly and once more on retry for the
// cPanel+LiteSpeed modsec_assemble race.
func probeWAFRules(info platform.Info, ruleDirs []string) bool {
	if info.IsCPanel() {
		// The command error is deliberately ignored: a failed or empty
		// probe simply falls through to the on-disk check.
		out, _ := runCmd("whmapi1", "modsec_get_vendors")
		if out != nil {
			listing := string(out)
			for _, vendor := range [...]string{"comodo", "owasp", "OWASP", "Comodo"} {
				if strings.Contains(listing, vendor) {
					return true
				}
			}
		}
	}
	return hasRuleArtifacts(ruleDirs)
}
// modsecDetected reports whether ModSecurity appears to be enabled for
// the detected web server. On cPanel it first trusts
// `whmapi1 modsec_is_installed` when the output contains
// "installed: 1". Otherwise it reads every activation candidate config
// file (globs expanded) and returns true as soon as one of them
// satisfies modsecEnabledInConfig. Unreadable files are skipped.
func modsecDetected(info platform.Info) bool {
	if info.IsCPanel() {
		// The command error is ignored; a failed probe falls through to
		// the file-based checks.
		out, _ := runCmd("whmapi1", "modsec_is_installed")
		if out != nil && strings.Contains(string(out), "installed: 1") {
			return true
		}
	}

	candidates := expandPathGlobs(modsecActivationCandidates(info))
	for i := range candidates {
		raw, err := osFS.ReadFile(candidates[i])
		if err == nil && modsecEnabledInConfig(info, string(raw)) {
			return true
		}
	}
	return false
}
// modsecActivationCandidates lists the config files (some as glob
// patterns) that can enable the ModSecurity module for the detected web
// server. Apache and Nginx paths are built under the detected config
// dir and the result is nil when that dir is empty. LiteSpeed uses a
// fixed path. Any other web server yields nil.
func modsecActivationCandidates(info platform.Info) []string {
	var (
		root     string
		relative [][]string
	)
	switch info.WebServer {
	case platform.WSLiteSpeed:
		return []string{"/usr/local/lsws/conf/httpd_config.xml"}
	case platform.WSApache:
		root = info.ApacheConfigDir
		relative = [][]string{
			{"httpd.conf"},
			{"apache2.conf"},
			{"modsec2.conf"},
			{"conf.d", "modsec2.conf"},
			{"mods-enabled", "security2.conf"},
			{"conf-enabled", "security2.conf"},
			{"conf.d", "mod_security.conf"},
			{"conf.modules.d", "10-mod_security.conf"},
			{"conf.d", "*.conf"},
			{"mods-enabled", "*.conf"},
			{"conf-enabled", "*.conf"},
		}
	case platform.WSNginx:
		root = info.NginxConfigDir
		relative = [][]string{
			{"nginx.conf"},
			{"conf.d", "*.conf"},
			{"sites-enabled", "*"},
		}
	}
	if root == "" {
		return nil
	}

	var paths []string
	for _, parts := range relative {
		elems := append([]string{root}, parts...)
		paths = append(paths, filepath.Join(elems...))
	}
	return paths
}
// modsecConfigCandidates lists the config files worth scanning for
// ModSecurity directives on the detected web server: a copy of the
// activation candidates, plus extra Nginx module/modsec files (when the
// Nginx config dir is known) or the LiteSpeed modsec2.conf.
func modsecConfigCandidates(info platform.Info) []string {
	var paths []string
	paths = append(paths, modsecActivationCandidates(info)...)

	if info.WebServer == platform.WSNginx {
		if dir := info.NginxConfigDir; dir != "" {
			paths = append(paths, filepath.Join(dir, "modules-enabled", "*.conf"))
			paths = append(paths, filepath.Join(dir, "modsec", "main.conf"))
		}
	} else if info.WebServer == platform.WSLiteSpeed {
		paths = append(paths, "/usr/local/lsws/conf/modsec2.conf")
	}
	return paths
}
// modsecEnabledInConfig reports whether contents, the text of one web
// server config file, shows ModSecurity being enabled. Lines are
// trimmed and compared case-insensitively; blank lines are ignored.
//
//   - LiteSpeed: any line mentioning "mod_security" or "modsecurity".
//   - Nginx: a non-'#' line starting with "modsecurity on",
//     "modsecurity_rules " or "modsecurity_rules_file ".
//   - Anything else (Apache): a non-'#' line mentioning
//     "security2_module" or "mod_security2", or starting with
//     "secruleengine ".
func modsecEnabledInConfig(info platform.Info, contents string) bool {
	lines := bufio.NewScanner(strings.NewReader(contents))
	for lines.Scan() {
		lowered := strings.ToLower(strings.TrimSpace(lines.Text()))
		if lowered == "" {
			continue
		}

		if info.WebServer == platform.WSLiteSpeed {
			// No comment filtering on this branch.
			if strings.Contains(lowered, "mod_security") || strings.Contains(lowered, "modsecurity") {
				return true
			}
			continue
		}

		if strings.HasPrefix(lowered, "#") {
			continue
		}

		if info.WebServer == platform.WSNginx {
			for _, directive := range [...]string{"modsecurity on", "modsecurity_rules ", "modsecurity_rules_file "} {
				if strings.HasPrefix(lowered, directive) {
					return true
				}
			}
			continue
		}

		if strings.Contains(lowered, "security2_module") {
			return true
		}
		if strings.HasPrefix(lowered, "secruleengine ") {
			return true
		}
		if strings.Contains(lowered, "mod_security2") {
			return true
		}
	}
	return false
}
// expandPathGlobs expands glob patterns in paths and returns the
// resulting paths without duplicates, keeping first-seen order. An
// entry containing '*', '?' or '[' is passed to osFS.Glob; when that
// fails or matches nothing, the pattern itself is kept as a literal
// path. Entries without glob characters are kept as they are.
func expandPathGlobs(paths []string) []string {
	var result []string
	emitted := make(map[string]struct{})
	add := func(p string) {
		if _, dup := emitted[p]; dup {
			return
		}
		emitted[p] = struct{}{}
		result = append(result, p)
	}

	for _, pattern := range paths {
		if strings.ContainsAny(pattern, "*?[") {
			hits, err := osFS.Glob(pattern)
			if err == nil && len(hits) > 0 {
				for _, hit := range hits {
					add(hit)
				}
				continue
			}
		}
		add(pattern)
	}
	return result
}
// modsecRuleDirs returns the ModSecurity rule directories for the
// detected platform by delegating to the canonical helper,
// modsec.RuleDirs, in internal/modsec. It is kept as a package-local
// thin wrapper because the existing waf check tests reference this name
// directly.
func modsecRuleDirs(info platform.Info) []string {
return modsec.RuleDirs(info)
}
// wafInstallHint returns install instructions for ModSecurity that
// match the detected platform. cPanel takes precedence over everything
// else; after that the hint depends on the web server (Nginx or Apache)
// and the distro family (Debian is checked before RHEL). Any other
// combination gets a generic message.
func wafInstallHint(info platform.Info) string {
	if info.IsCPanel() {
		return "No ModSecurity module detected. Install: WHM > Security Center > ModSecurity"
	}
	switch info.WebServer {
	case platform.WSNginx:
		if info.IsDebianFamily() {
			return "No ModSecurity module detected for Nginx.\nInstall: apt install libnginx-mod-http-modsecurity modsecurity-crs"
		}
		if info.IsRHELFamily() {
			return "No ModSecurity module detected for Nginx.\nInstall (requires EPEL): dnf install -y epel-release && dnf install -y nginx-mod-http-modsecurity && systemctl restart nginx"
		}
	case platform.WSApache:
		if info.IsDebianFamily() {
			return "No ModSecurity module detected for Apache.\nInstall: apt install libapache2-mod-security2 modsecurity-crs && a2enmod security2"
		}
		if info.IsRHELFamily() {
			return "No ModSecurity module detected for Apache.\nInstall (requires EPEL): dnf install -y epel-release && dnf install -y mod_security mod_security_crs && systemctl restart httpd"
		}
	}
	return "No ModSecurity module detected. The server has no web application firewall protecting against SQL injection, XSS, and other web attacks."
}
// wafRulesHint returns the message shown when ModSecurity is installed but no
// rule set is loaded, with install instructions for cPanel, Debian-family and
// RHEL-family hosts and a generic message otherwise.
func wafRulesHint(info platform.Info) string {
	switch {
	case info.IsCPanel():
		return "ModSecurity is installed but has no OWASP or Comodo rules. Add rules: WHM > Security Center > ModSecurity Vendors"
	case info.IsDebianFamily():
		return "ModSecurity is installed but has no rules loaded. Install OWASP CRS: apt install modsecurity-crs"
	case info.IsRHELFamily():
		return "ModSecurity is installed but has no rules loaded. Install OWASP CRS: dnf install --enablerepo=epel modsecurity-crs"
	default:
		return "ModSecurity is installed but has no rules loaded."
	}
}
// wafRulesStaleHint returns the advice shown when the installed ModSecurity
// vendor rules are stale: the update path for cPanel, Debian-family and
// RHEL-family hosts, or a generic reminder for anything else.
func wafRulesStaleHint(info platform.Info) string {
	switch {
	case info.IsCPanel():
		return "Vendor rules should be updated at least monthly. Check: WHM > Security Center > ModSecurity Vendors > Update"
	case info.IsDebianFamily():
		return "Vendor rules should be updated at least monthly. Update with: apt update && apt upgrade modsecurity-crs"
	case info.IsRHELFamily():
		return "Vendor rules should be updated at least monthly. Update with: dnf upgrade modsecurity-crs"
	default:
		return "Vendor rules should be updated at least monthly."
	}
}
// checkEngineMode looks for the SecRuleEngine setting in the platform's
// ModSecurity config candidates plus the two distro-default modsecurity.conf
// locations. It returns the lowercased second token of the first
// non-comment SecRuleEngine line found (e.g. "on", "detectiononly", "off"),
// or "" when no readable file carries the directive.
func checkEngineMode(info platform.Info) string {
	candidates := modsecConfigCandidates(info)
	// Also include the top-level modsecurity.conf installed by distro packages.
	candidates = append(candidates,
		"/etc/modsecurity/modsecurity.conf",
		"/etc/nginx/modsec/modsecurity.conf",
	)

	// readMode scans one file and reports the first SecRuleEngine value in it.
	readMode := func(path string) (string, bool) {
		fh, err := osFS.Open(path)
		if err != nil {
			return "", false
		}
		defer func() { _ = fh.Close() }()

		sc := bufio.NewScanner(fh)
		for sc.Scan() {
			text := strings.TrimSpace(sc.Text())
			if strings.HasPrefix(text, "#") {
				continue
			}
			lowered := strings.ToLower(text)
			if !strings.HasPrefix(lowered, "secruleengine") {
				continue
			}
			if tokens := strings.Fields(lowered); len(tokens) >= 2 {
				return tokens[1], true
			}
		}
		return "", false
	}

	for _, path := range candidates {
		if mode, ok := readMode(path); ok {
			return mode
		}
	}
	return ""
}
// checkRuleAge returns the age in whole days of the oldest rule artifact under
// ruleDirs when that age exceeds 30 days. It returns 0 when no artifact is
// found or the oldest one is 30 days old or newer.
func checkRuleAge(ruleDirs []string) int {
	const staleAfterDays = 30

	oldest, ok := oldestRuleArtifact(ruleDirs)
	if !ok {
		return 0
	}
	if days := int(time.Since(oldest).Hours() / 24); days > staleAfterDays {
		return days
	}
	return 0
}
// hasRuleArtifacts reports whether oldestRuleArtifact finds at least one rule
// artifact under any of ruleDirs.
func hasRuleArtifacts(ruleDirs []string) bool {
	_, found := oldestRuleArtifact(ruleDirs)
	return found
}
// oldestRuleArtifact returns the earliest modification time among rule
// artifacts (see isRuleArtifact) found in ruleDirs, and whether any artifact
// was found at all. Each directory is scanned at two levels:
//
//   - files directly in the rule dir (distro CRS layout, e.g.
//     /usr/share/modsecurity-crs/rules/REQUEST-*.conf);
//   - files one level down in each subdirectory (cPanel vendor layout, e.g.
//     /usr/local/apache/conf/modsec_vendor_configs/OWASP/*.conf).
//
// Unreadable directories and entries whose metadata cannot be read are
// skipped silently.
func oldestRuleArtifact(ruleDirs []string) (time.Time, bool) {
	var (
		oldest time.Time
		found  bool
	)
	// note folds one artifact mtime into the running minimum.
	note := func(mtime time.Time) {
		if !found || mtime.Before(oldest) {
			oldest, found = mtime, true
		}
	}

	for _, dir := range ruleDirs {
		entries, err := osFS.ReadDir(dir)
		if err != nil {
			continue
		}
		for _, entry := range entries {
			if !entry.IsDir() {
				// Filter by name first so non-rule files are never stat'ed.
				if !isRuleArtifact(entry.Name()) {
					continue
				}
				if info, err := entry.Info(); err == nil {
					note(info.ModTime())
				}
				continue
			}

			// Subdirectory: scan one level deeper for vendor-packed rules.
			subEntries, err := osFS.ReadDir(filepath.Join(dir, entry.Name()))
			if err != nil {
				continue
			}
			for _, subEntry := range subEntries {
				if subEntry.IsDir() || !isRuleArtifact(subEntry.Name()) {
					continue
				}
				if info, err := subEntry.Info(); err == nil {
					note(info.ModTime())
				}
			}
		}
	}
	return oldest, found
}
// isRuleArtifact reports whether name ends, case-insensitively, in one of the
// ModSecurity rule/data extensions (.conf, .data, .rules). Filtering on these
// keeps unrelated files such as README or LICENSE out of the oldest-mtime
// calculation.
func isRuleArtifact(name string) bool {
	lowered := strings.ToLower(name)
	for _, suffix := range []string{".conf", ".data", ".rules"} {
		if strings.HasSuffix(lowered, suffix) {
			return true
		}
	}
	return false
}
// checkPerAccountBypass runs `whmapi1 modsec_get_rules` and returns the names
// of entries that look like domains and are reported as disabled.
//
// The output is treated as loosely YAML-like: a line containing "disabled: 1"
// or "active: 0" marks a disabled entry, and the name is taken from the
// nearest preceding line (at most five lines back) that ends in ":" and does
// not start with "-". Only names containing a "." are kept. It returns nil
// when the command fails or produces no output.
func checkPerAccountBypass() []string {
	out, err := runCmd("whmapi1", "modsec_get_rules")
	if err != nil || out == nil {
		return nil
	}

	const lookback = 5
	rows := strings.Split(string(out), "\n")

	var bypassed []string
	for idx := range rows {
		state := strings.ToLower(strings.TrimSpace(rows[idx]))
		if !strings.Contains(state, "disabled: 1") && !strings.Contains(state, "active: 0") {
			continue
		}

		floor := idx - lookback
		if floor < 0 {
			floor = 0
		}
		// Walk backward to the closest section header.
		for back := idx - 1; back >= floor; back-- {
			header := strings.TrimSpace(rows[back])
			if !strings.HasSuffix(header, ":") || strings.HasPrefix(header, "-") {
				continue
			}
			if name := strings.TrimSuffix(header, ":"); strings.Contains(name, ".") { // looks like a domain
				bypassed = append(bypassed, name)
			}
			break
		}
	}
	return bypassed
}
// deployVirtualPatches ensures CSM's custom ModSec rules are installed.
// These provide virtual patches for known WordPress CVEs.
//
// The rules are read from /opt/csm/configs/csm_modsec_custom.conf and deployed
// to the first candidate user-config path whose parent directory exists and
// that can be written:
//
//   - file already ends with the current rules: nothing to do;
//   - file holds an older CSM block: everything before that block is kept and
//     the block is replaced with the current rules;
//   - file holds only operator-authored content: the rules are appended;
//   - file is missing, empty or unreadable: the rules are written as the file.
//
// Operator-authored content is therefore never discarded. A failed write is
// logged to stderr and the next candidate path is tried.
func deployVirtualPatches() {
	const marker = "CSM Custom ModSecurity Rules"

	// Possible modsec user config paths
	destPaths := []string{
		"/etc/apache2/conf.d/modsec/modsec2.user.conf",
		"/usr/local/apache/conf/modsec2.user.conf",
	}

	srcData, err := osFS.ReadFile("/opt/csm/configs/csm_modsec_custom.conf")
	if err != nil {
		return // no custom rules to deploy
	}
	src := string(srcData)

	for _, dest := range destPaths {
		if _, err := osFS.Stat(filepath.Dir(dest)); os.IsNotExist(err) {
			continue
		}

		existing, readErr := osFS.ReadFile(dest)
		current := string(existing)
		managed := readErr == nil && strings.Contains(current, marker)

		// Up to date both when the file is exactly the rules and when the
		// rules sit at the end of an operator-authored file.
		if managed && strings.HasSuffix(current, src) {
			return
		}

		var writeErr error
		switch {
		case managed:
			// Older CSM block present: swap it out, keep what precedes it.
			updated := csmPatchPrefix(current, src, marker) + src
			// #nosec G306 -- webserver-readable WAF rule file; dest is fixed list above.
			writeErr = os.WriteFile(dest, []byte(updated), 0644)
		case readErr == nil && len(existing) > 0:
			// Operator-only content: append to existing user config.
			writeErr = appendCSMPatches(dest, []byte("\n\n"+src))
		default:
			// #nosec G306 -- same reason: webserver-readable WAF rule file.
			writeErr = os.WriteFile(dest, srcData, 0644)
		}

		stamp := time.Now().Format("2006-01-02 15:04:05")
		if writeErr != nil {
			fmt.Fprintf(os.Stderr, "[%s] Virtual patch deployment to %s failed: %v\n", stamp, dest, writeErr)
			continue
		}
		fmt.Fprintf(os.Stderr, "[%s] Virtual patches deployed to %s\n", stamp, dest)
		return
	}
}

// csmPatchPrefix returns the part of current that precedes a previously
// deployed CSM rule block so it can be preserved across a rule update. When
// the text in front of the marker matches the new rules' own header, the block
// is taken to start at that header; otherwise it is taken to start at the
// beginning of the marker's line.
func csmPatchPrefix(current, src, marker string) string {
	at := strings.Index(current, marker)
	if at < 0 {
		return ""
	}
	if lead := strings.Index(src, marker); lead >= 0 && lead <= at && current[at-lead:at] == src[:lead] {
		return current[:at-lead]
	}
	return current[:strings.LastIndexByte(current[:at], '\n')+1]
}

// appendCSMPatches appends data to the existing file at dest and reports any
// open, write or close error.
func appendCSMPatches(dest string, data []byte) error {
	// #nosec G302 G304 -- WAF rule file read by Apache/nginx as a different user; dest comes from a fixed list.
	f, err := os.OpenFile(dest, os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		return err
	}
	if _, err := f.Write(data); err != nil {
		_ = f.Close()
		return err
	}
	return f.Close()
}
// autoUpdateWAFRules asks whmapi1 for the installed ModSecurity vendors and
// requests an update for each one. It returns true when at least one vendor
// update reported "result: 1", and false when the vendor listing fails, lists
// no vendors, or no update succeeds.
func autoUpdateWAFRules() bool {
	listing, err := runCmd("whmapi1", "modsec_get_vendors")
	if err != nil || listing == nil {
		return false
	}

	// Collect vendor IDs from "vendor_id:" / "id:" lines.
	var vendorIDs []string
	for _, raw := range strings.Split(string(listing), "\n") {
		row := strings.TrimSpace(raw)
		if !strings.HasPrefix(row, "vendor_id:") && !strings.HasPrefix(row, "id:") {
			continue
		}
		// Both prefixes end at the first colon, so the value follows it.
		if value := strings.TrimSpace(row[strings.Index(row, ":")+1:]); value != "" {
			vendorIDs = append(vendorIDs, value)
		}
	}
	if len(vendorIDs) == 0 {
		return false
	}

	anyUpdated := false
	for _, id := range vendorIDs {
		resp, err := runCmd("whmapi1", "modsec_update_vendor", "vendor_id="+id)
		if err != nil || resp == nil || !strings.Contains(string(resp), "result: 1") {
			continue
		}
		fmt.Fprintf(os.Stderr, "[%s] WAF auto-update: vendor %s updated successfully\n",
			time.Now().Format("2006-01-02 15:04:05"), id)
		anyUpdated = true
	}
	return anyUpdated
}
// CheckModSecAuditLog tails the first non-empty ModSecurity audit log known
// for this platform (last 200 lines) and counts, per source IP, the lines
// that look like blocked requests. Every non-infra IP with 20 or more such
// lines yields a waf_attack_blocked finding carrying SourceIP; the check
// name is what ties the finding to auto-blocking. It returns nil when no log
// path is known or no log has content.
func CheckModSecAuditLog(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	const (
		tailLines      = 200
		blockThreshold = 20
	)

	paths := platform.Detect().ModSecAuditLogPaths
	if len(paths) == 0 {
		return nil
	}

	var tail []string
	for _, p := range paths {
		if tail = tailFile(p, tailLines); len(tail) > 0 {
			break
		}
	}
	if len(tail) == 0 {
		return nil
	}

	// looksBlocked reports whether a log line carries any block indicator.
	looksBlocked := func(entry string) bool {
		for _, indicator := range []string{"403", "Access denied", "MODSEC", "mod_security"} {
			if strings.Contains(entry, indicator) {
				return true
			}
		}
		return false
	}

	// Count blocked attacks per IP
	hits := make(map[string]int)
	for _, entry := range tail {
		if !looksBlocked(entry) {
			continue
		}
		addr := extractIPFromLog(entry)
		if addr == "" || isInfraIP(addr, cfg.InfraIPs) {
			continue
		}
		hits[addr]++
	}

	// Alert on high-volume attackers (auto-block integration via check name)
	var findings []alert.Finding
	for addr, n := range hits {
		if n < blockThreshold {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "waf_attack_blocked",
			SourceIP: addr,
			Message:  fmt.Sprintf("WAF blocking high-volume attacker: %s (%d blocked requests)", addr, n),
			Details:  fmt.Sprintf("IP %s has been blocked %d times by ModSecurity. Consider permanent block via CSM.", addr, n),
		})
	}
	return findings
}
package checks
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// wpChecksumWorkers is the number of worker goroutines CheckWPCore starts to
// run `wp core verify-checksums` concurrently.
const wpChecksumWorkers = 5

// htaccessMaxLineBytes bounds a single .htaccess line for the token scanner.
// Legitimate directives are far shorter; a line past this is itself an
// anomaly, so the scanner fails closed (flags the file) rather than silently
// truncating the rest.
const htaccessMaxLineBytes = 1 << 20 // 1 MiB
// CheckHtaccess walks every scanned home directory looking for .htaccess
// files with malicious directives. For each account it scans public_html and
// every other non-hidden top-level directory that is not a known
// account-internal one (treated as addon-domain docroots), to a depth of 5.
// The walk stops early, returning what was found so far, once ctx is done.
func CheckHtaccess(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	const scanDepth = 5

	suspiciousPatterns := []string{
		"auto_prepend_file",
		"auto_append_file",
		"eval(",
		"base64_decode",
		"gzinflate",
		"str_rot13",
		"php_value disable_functions",
		"addhandler",
		"addtype",
		"sethandler",
	}
	safePatterns := []string{
		"wordfence-waf.php",
		"litespeed",
		"advanced-headers.php",
		"rsssl",
		// Standard handler directives for PHP/static files are safe
		"application/x-httpd-php",
		"application/x-httpd-php5",
		"application/x-httpd-ea-php",
		"application/x-httpd-alt-php",
		"text/html",
		"text/css",
		"text/javascript",
		"application/javascript",
		"image/",
		"font/",
		"proxy:unix",
		// Security plugins that use handler directives to BLOCK execution
		"-execcgi",                   // Options -ExecCGI disables CGI (Wordfence pattern)
		"sethandler none",            // Disables all handlers (security measure)
		"sethandler default-handler", // Resets to default (security measure)
		// Legitimate MIME type additions
		"application/font",
		"application/vnd",
		".woff",
		".woff2",
		".ttf",
		".eot",
		".svg",
		// Wordfence code execution protection
		"wordfence",
	}

	// skipAsAddonRoot reports whether a top-level directory name is either
	// the primary docroot (scanned separately), an account-internal
	// directory, or hidden.
	skipAsAddonRoot := func(name string) bool {
		switch name {
		case "public_html", "mail", "etc", "logs", "ssl", "tmp":
			return true
		}
		return strings.HasPrefix(name, ".")
	}

	var findings []alert.Finding

	// Scan each user's document roots
	homes, _ := GetScanHomeDirs(ctx)
	for _, home := range homes {
		if ctx.Err() != nil {
			return findings
		}
		if !home.IsDir() {
			continue
		}
		accountDir := filepath.Join("/home", home.Name())
		scanHtaccess(ctx, filepath.Join(accountDir, "public_html"), scanDepth, suspiciousPatterns, safePatterns, cfg, &findings)

		// Also check addon domains
		children, _ := osFS.ReadDir(accountDir)
		for _, child := range children {
			if !child.IsDir() || skipAsAddonRoot(child.Name()) {
				continue
			}
			scanHtaccess(ctx, filepath.Join(accountDir, child.Name()), scanDepth, suspiciousPatterns, safePatterns, cfg, &findings)
		}
	}
	return findings
}
// scanHtaccess recursively inspects dir, descending at most maxDepth levels,
// and appends findings for every .htaccess file that is not covered by an
// operator suppression glob. Each file goes through the generic token scanner
// (checkHtaccessFile) and the hardened detector registry
// (AuditHtaccessFile). It returns early when ctx is done or dir is unreadable.
func scanHtaccess(ctx context.Context, dir string, maxDepth int, suspicious, safe []string, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}

	// isSuppressed reports whether p matches any configured ignore glob.
	isSuppressed := func(p string) bool {
		for _, pattern := range cfg.Suppressions.IgnorePaths {
			if matchGlob(p, pattern) {
				return true
			}
		}
		return false
	}

	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}
		target := filepath.Join(dir, entry.Name())

		if entry.IsDir() {
			scanHtaccess(ctx, target, maxDepth-1, suspicious, safe, cfg, findings)
			continue
		}
		if entry.Name() != ".htaccess" {
			continue
		}
		// Skip suppressed paths
		if isSuppressed(target) {
			continue
		}

		checkHtaccessFile(target, suspicious, safe, findings)

		// The hardened detectors run alongside the generic token scanner so
		// that per-pattern findings keep their own check names
		// (htaccess_php_in_uploads, htaccess_filesmatch_shield, etc.)
		// instead of collapsing into the catch-all categories. Both scans
		// may fire on the same line; downstream dedup at alert.Dispatch
		// handles same-key dups.
		hardened, _ := AuditHtaccessFile(target)
		*findings = append(*findings, hardened...)
	}
}
// phpExtension reports whether ext (leading dot, lowercase) is a stock
// PHP-executed extension that legitimately maps to a PHP handler.
// It builds a synthetic file name ("x" + ext) and defers the decision to
// isExecutablePHPName.
func phpExtension(ext string) bool {
	return isExecutablePHPName("x" + ext)
}
// phpHandlerRemapsNonPHP reports whether a single lowercased .htaccess line is
// an AddHandler/AddType/SetHandler directive routing a non-PHP file extension
// to a PHP execution handler. cPanel MultiPHP legitimately maps PHP extensions
// to a PHP handler (e.g. application/x-httpd-ea-php74___lsphp .php .php7), so we
// flag only when at least one mapped extension is not PHP-family.
//
// This is the context-free form: it calls phpHandlerRemapsNonPHPInContext
// with no enclosing contexts.
func phpHandlerRemapsNonPHP(lineLower string) bool {
	return phpHandlerRemapsNonPHPInContext(lineLower, nil)
}
// phpHandlerRemapsNonPHPInContext reports whether the lowercased directive
// line routes something other than a PHP-family extension to a PHP handler.
//
// Only addhandler, addtype, sethandler and forcetype directives whose first
// argument is a PHP handler are considered. When the directive lists
// extensions, any non-PHP extension makes it a remap. When it lists none, the
// decision is delegated to directiveHandlerContextTargetsNonPHP with the
// enclosing contexts.
func phpHandlerRemapsNonPHPInContext(lineLower string, contexts []phpHandlerOverlay) bool {
	tokens := apacheDirectiveFields(lineLower)
	if len(tokens) < 2 {
		return false
	}

	name, handler := tokens[0], tokens[1]
	if name != "addhandler" && name != "addtype" && name != "sethandler" && name != "forcetype" {
		return false
	}
	if !handlerIsPHP(handler) {
		return false
	}

	mapped := normalizedExts(tokens[2:])
	if len(mapped) == 0 {
		return directiveHandlerContextTargetsNonPHP(name, contexts)
	}
	for _, candidate := range mapped {
		if !phpExtension(candidate) {
			return true
		}
	}
	return false
}
// directiveHandlerContextTargetsNonPHP reports whether a sethandler or
// forcetype directive sits inside at least one context that selects a
// non-PHP extension or a file name that is not an executable PHP name. Any
// other directive yields false.
func directiveHandlerContextTargetsNonPHP(directive string, contexts []phpHandlerOverlay) bool {
	switch directive {
	case "sethandler", "forcetype":
		// context-scoped directives: inspect the enclosing contexts below
	default:
		return false
	}

	for _, overlay := range contexts {
		for ext := range overlay.exts {
			if !phpExtension(ext) {
				return true
			}
		}
		for fileName := range overlay.names {
			if !isExecutablePHPName(fileName) {
				return true
			}
		}
	}
	return false
}
// checkHtaccessFile runs the generic token scanner over one .htaccess file
// and appends htaccess_injection / htaccess_handler_abuse findings.
//
// The file is read fully first so cross-line context is available. It is then
// processed as logical lines (continuations joined) in two passes:
//
//  1. Per-line pass: comments are skipped; context open/close lines maintain a
//     stack of PHP-handler contexts; a PHP handler mapped onto a non-PHP
//     extension is always flagged; otherwise each suspicious pattern on the
//     line is flagged unless a safe pattern or one of the handler/AddType
//     exemptions applies.
//  2. File-level pass: when the file contains no "-execcgi", any AddHandler
//     line naming a known-malicious extension yields a Critical finding.
//
// An unopenable file is ignored; a file that cannot be scanned completely is
// reported as a finding instead of being treated as clean.
func checkHtaccessFile(path string, suspicious, safe []string, findings *[]alert.Finding) {
	f, err := osFS.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	// Read entire file to check cross-line context (e.g., AddHandler + Options -ExecCGI)
	var lines []string
	scanner := bufio.NewScanner(f)
	// A .htaccess line longer than the default 64 KB token would make
	// Scan stop early and silently drop every line after it, letting an
	// attacker hide a malicious directive behind one padded line. Raise
	// the ceiling, and if a line still exceeds it, fail closed below.
	scanner.Buffer(make([]byte, 0, 64*1024), htaccessMaxLineBytes)
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		// Could not read the whole file (oversized line or I/O error), so
		// the cross-line and per-line analysis below is incomplete. Flag
		// the file for review rather than reporting a clean partial scan.
		*findings = append(*findings, alert.Finding{
			Severity: alert.High,
			Check:    "htaccess_injection",
			Message:  "Unparseable .htaccess: oversized or unreadable line blocks full analysis",
			Details:  fmt.Sprintf("File: %s\nError: %v", path, err),
			FilePath: path,
		})
		return
	}
	// Build full file content for context checks
	fullContentLower := strings.ToLower(strings.Join(lines, "\n"))
	// If file contains handler directives paired with -ExecCGI, the whole
	// block is a security measure (e.g., Wordfence execution protection)
	hasExecCGIBlock := strings.Contains(fullContentLower, "-execcgi")
	// Stack of currently open PHP-handler contexts, innermost last.
	var phpHandlerContexts []phpHandlerOverlay
	for _, logical := range joinHtaccessContinuations(lines) {
		// logical.start is reported as lineNum+1 below, i.e. it is used as a
		// zero-based line index.
		lineNum := logical.start
		trimmed := strings.TrimSpace(logical.text)
		lineLower := strings.ToLower(trimmed)
		// Skip comments entirely - commented-out directives are not active
		if strings.HasPrefix(trimmed, "#") {
			continue
		}
		// Context open/close lines only update the stack; they are not
		// themselves matched against the patterns.
		if ctx, ok := openPHPHandlerContext(trimmed); ok {
			phpHandlerContexts = append(phpHandlerContexts, ctx)
			continue
		}
		if closesPHPHandlerContext(trimmed) {
			if len(phpHandlerContexts) > 0 {
				phpHandlerContexts = phpHandlerContexts[:len(phpHandlerContexts)-1]
			}
			continue
		}
		// A PHP execution handler mapped onto a non-PHP extension is the
		// handler-remap webshell technique: an uploaded .jpg then runs as
		// PHP. The safe-pattern and AddType skips below would otherwise
		// suppress it because the handler name itself is a normal PHP
		// handler, so this override fires first and unconditionally.
		remapsNonPHP := phpHandlerRemapsNonPHP(lineLower)
		if !remapsNonPHP && len(phpHandlerContexts) > 0 {
			remapsNonPHP = phpHandlerRemapsNonPHPInContext(lineLower, phpHandlerContexts)
		}
		if remapsNonPHP {
			*findings = append(*findings, alert.Finding{
				Severity: alert.High,
				Check:    "htaccess_injection",
				Message:  "PHP handler mapped to non-PHP extension (handler remap)",
				Details:  fmt.Sprintf("File: %s (line %d)\nContent: %s", path, lineNum+1, trimmed),
				FilePath: path,
			})
			continue
		}
		// Every suspicious pattern is evaluated independently, so one line
		// can produce several findings.
		for _, pattern := range suspicious {
			if !strings.Contains(lineLower, strings.ToLower(pattern)) {
				continue
			}
			// Check per-line safe patterns
			isSafe := false
			for _, sp := range safe {
				if strings.Contains(lineLower, strings.ToLower(sp)) {
					isSafe = true
					break
				}
			}
			if isSafe {
				continue
			}
			patternLower := strings.ToLower(pattern)
			// For handler directives, apply context-aware checks
			if patternLower == "addhandler" || patternLower == "sethandler" {
				// Skip if paired with -ExecCGI (Wordfence protection)
				if hasExecCGIBlock {
					continue
				}
				// Skip Drupal security handlers
				if strings.Contains(lineLower, "drupal_security") {
					continue
				}
				// Skip SetHandler none/default (disabling handlers = security measure)
				if strings.Contains(lineLower, "sethandler none") ||
					strings.Contains(lineLower, "sethandler default") {
					continue
				}
				// Skip AddHandler that maps only commonly seen script extensions
				if strings.Contains(lineLower, "addhandler") {
					// Only flag if mapping non-standard extensions
					// NOTE(review): standardCGI is never reassigned, so the
					// `&& standardCGI` term below is always true.
					standardCGI := true
					hasNonStandard := false
					// Check each extension on the line
					for _, ext := range []string{".haxor", ".cgix", ".phtml", ".php3",
						".php5", ".suspected", ".bak.php", ".shtml", ".sh"} {
						if strings.Contains(lineLower, ext) {
							hasNonStandard = true
							break
						}
					}
					// If every dot-prefixed token is one of .cgi, .pl, .py,
					// .php, .jsp or .asp, the mapping is treated as standard.
					if !hasNonStandard && standardCGI {
						onlyStandard := true
						parts := strings.Fields(lineLower)
						for _, p := range parts {
							if strings.HasPrefix(p, ".") && p != ".cgi" && p != ".pl" && p != ".py" &&
								p != ".php" && p != ".jsp" && p != ".asp" {
								// Has non-standard extension
								onlyStandard = false
								break
							}
						}
						if onlyStandard {
							continue
						}
					}
				}
			}
			// Skip AddType for any MIME type (application/*, text/*, x-mapp-*, etc.)
			if patternLower == "addtype" {
				// AddType is only dangerous if it maps to a PHP/CGI handler
				// Standard MIME type declarations are safe
				if strings.Contains(lineLower, "application/") ||
					strings.Contains(lineLower, "text/") ||
					strings.Contains(lineLower, "image/") ||
					strings.Contains(lineLower, "font/") ||
					strings.Contains(lineLower, "x-mapp-") ||
					strings.Contains(lineLower, "audio/") ||
					strings.Contains(lineLower, "video/") {
					continue
				}
			}
			*findings = append(*findings, alert.Finding{
				Severity: alert.High,
				Check:    "htaccess_injection",
				Message:  fmt.Sprintf("Suspicious .htaccess directive: %s", pattern),
				Details:  fmt.Sprintf("File: %s (line %d)\nContent: %s", path, lineNum+1, trimmed),
				FilePath: path,
			})
		}
	}
	// Special check: AddHandler mapping non-standard extensions WITHOUT -ExecCGI
	// (actual attack pattern - e.g., AddHandler cgi-script .haxor)
	if !hasExecCGIBlock && strings.Contains(fullContentLower, "addhandler") {
		for _, logical := range joinHtaccessContinuations(lines) {
			lineNum := logical.start
			line := logical.text
			lineLower := strings.ToLower(line)
			// Unlike the per-line pass, this pass does not skip commented lines.
			if !strings.Contains(lineLower, "addhandler") {
				continue
			}
			// Flag if it maps unusual extensions like .haxor, .cgix, etc.
			dangerousExts := []string{".haxor", ".cgix", ".suspected", ".bak.php"}
			for _, ext := range dangerousExts {
				if strings.Contains(lineLower, ext) {
					*findings = append(*findings, alert.Finding{
						Severity: alert.Critical,
						Check:    "htaccess_handler_abuse",
						Message:  fmt.Sprintf("Malicious handler mapping for %s extension", ext),
						Details:  fmt.Sprintf("File: %s (line %d)\nContent: %s", path, lineNum+1, strings.TrimSpace(line)),
						FilePath: path,
					})
				}
			}
		}
	}
}
// CheckWPCore runs `wp core verify-checksums` against every WordPress
// installation found under the scanned homes, using wpChecksumWorkers
// goroutines fed from a pre-sized queue.
//
// An installation that verifies cleanly has its core files hashed into
// GlobalCMSCache so the real-time scanner can skip signature matches on
// known-clean CMS files. For one that fails, every output line containing
// "should not exist" (other than error_log entries) becomes a
// wp_core_integrity finding. Workers stop picking up work once ctx is done.
func CheckWPCore(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	wpConfigs, _ := homeGlob(ctx, "public_html", "wp-config.php")
	if len(wpConfigs) == 0 {
		return nil
	}
	cache := GlobalCMSCache()

	var (
		findings   []alert.Finding
		findingsMu sync.Mutex
		workers    sync.WaitGroup
	)
	// Sized to hold every job, so the producer below never blocks.
	queue := make(chan string, len(wpConfigs))

	// verify drains the queue, checking one installation at a time.
	verify := func() {
		defer workers.Done()
		for cfgPath := range queue {
			if ctx.Err() != nil {
				return
			}
			root := filepath.Dir(cfgPath)
			owner := extractUser(root)

			output, err := runCmdCombinedContext(ctx, "wp", "core", "verify-checksums",
				"--path="+root, "--allow-root")
			if ctx.Err() != nil {
				return
			}
			if err == nil {
				// Verification passed - cache all core files
				cacheWPCoreFiles(cache, root)
				continue
			}
			if output == nil {
				continue
			}

			for _, row := range strings.Split(string(output), "\n") {
				if !strings.Contains(row, "should not exist") || strings.Contains(row, "error_log") {
					continue
				}
				findingsMu.Lock()
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "wp_core_integrity",
					Message:  fmt.Sprintf("WordPress core integrity failure for %s", owner),
					Details:  fmt.Sprintf("Path: %s\n%s", root, row),
				})
				findingsMu.Unlock()
			}
		}
	}

	workers.Add(wpChecksumWorkers)
	for i := 0; i < wpChecksumWorkers; i++ {
		go verify()
	}

	for _, cfgPath := range wpConfigs {
		if ctx.Err() != nil {
			break
		}
		queue <- cfgPath
	}
	close(queue)
	workers.Wait()

	fmt.Fprintf(os.Stderr, "CMS hash cache: %d verified core files cached\n", cache.Size())
	return findings
}
// cacheWPCoreFiles adds the hashes of a verified-clean WordPress install's
// core files to cache: a fixed list of root-level core PHP files, plus every
// .php and .js file found under wp-includes/ and wp-admin/. Files that cannot
// be hashed are skipped; a walk error ends the walk of that directory.
func cacheWPCoreFiles(cache *CMSHashCache, wpPath string) {
	// remember hashes one file and stores the digest when hashing succeeds.
	remember := func(file string) {
		if sum := HashFile(file); sum != "" {
			cache.Add(sum)
		}
	}

	// Root-level WP core files
	for _, name := range []string{
		"wp-config.php", "wp-cron.php", "wp-login.php", "wp-settings.php",
		"wp-load.php", "wp-blog-header.php", "wp-links-opml.php",
		"wp-mail.php", "wp-signup.php", "wp-activate.php",
		"wp-comments-post.php", "wp-trackback.php", "xmlrpc.php",
		"index.php",
	} {
		remember(filepath.Join(wpPath, name))
	}

	for _, sub := range []string{"wp-includes", "wp-admin"} {
		_ = filepath.Walk(filepath.Join(wpPath, sub), func(p string, fi os.FileInfo, walkErr error) error {
			if walkErr != nil {
				return walkErr
			}
			if fi.IsDir() {
				return nil
			}
			lower := strings.ToLower(fi.Name())
			if strings.HasSuffix(lower, ".php") || strings.HasSuffix(lower, ".js") {
				remember(p)
			}
			return nil
		})
	}
}
// extractUser returns the path segment that immediately follows the first
// "home" segment of a slash-separated path (e.g. "alice" for
// "/home/alice/public_html"), or "unknown" when no "home" segment is followed
// by another segment.
func extractUser(path string) string {
	segments := strings.Split(path, "/")
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "home" {
			return segments[idx+1]
		}
	}
	return "unknown"
}
package checks
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckWHMAccess inspects the last 200 lines of the cPanel access log for WHM
// (port 2087) requests coming from non-infra IPs. A line mentioning a
// password-change action yields a Critical whm_password_change finding; a
// line mentioning an account create/kill/suspend/unsuspend action yields a
// High whm_account_action finding. One line can produce both.
// Only the tail of the log is read, which keeps the check lightweight.
func CheckWHMAccess(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	passwordActions := []string{
		"passwd", "change_root_password", "chpasswd",
		"force_password_change", "resetpass",
	}
	accountActions := []string{
		"createacct", "killacct", "suspendacct", "unsuspendacct",
	}
	// mentionsAny reports whether haystack contains any of needles.
	mentionsAny := func(haystack string, needles []string) bool {
		for _, needle := range needles {
			if strings.Contains(haystack, needle) {
				return true
			}
		}
		return false
	}

	var findings []alert.Finding
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 200) {
		if !isWHMAccessLogLine(entry) {
			continue
		}
		// The client IP is the first whitespace-separated field.
		cols := strings.Fields(entry)
		if len(cols) < 1 {
			continue
		}
		src := cols[0]
		// Skip infra IPs
		if isInfraIP(src, cfg.InfraIPs) || src == "127.0.0.1" {
			continue
		}

		lowered := strings.ToLower(entry)
		// Check for password change actions
		if mentionsAny(lowered, passwordActions) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "whm_password_change",
				Message:  fmt.Sprintf("WHM password change from non-infra IP: %s", src),
				Details:  truncateString(entry, 200),
			})
		}
		// Check for account management from unknown IPs
		if mentionsAny(lowered, accountActions) {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "whm_account_action",
				Message:  fmt.Sprintf("WHM account action from non-infra IP: %s", src),
				Details:  truncateString(entry, 200),
			})
		}
	}
	return findings
}
// isWHMAccessLogLine reports whether an access-log line was served on the WHM
// port. The last field must be either the bare port "2087", in which case at
// least five quoted fields are required, or a value ending in ":2087", in
// which case at least four are required.
func isWHMAccessLogLine(line string) bool {
	port := lastAccessLogField(line)

	var minQuoted int
	if port == "2087" {
		minQuoted = 5
	} else if strings.HasSuffix(port, ":2087") {
		minQuoted = 4
	} else {
		return false
	}
	return accessLogQuotedFieldCount(line) >= minQuoted
}
// lastAccessLogField returns the final field of an access-log line. When the
// trimmed line ends in a double quote, the field is the text between the last
// pair of quotes (or "" if there is no opening quote); otherwise it is the
// last whitespace-separated token. A blank line yields "".
func lastAccessLogField(line string) string {
	trimmed := strings.TrimSpace(line)
	if trimmed == "" {
		return ""
	}

	if trimmed[len(trimmed)-1] != '"' {
		cols := strings.Fields(trimmed)
		if len(cols) == 0 {
			return ""
		}
		return cols[len(cols)-1]
	}

	// Quoted trailing field: drop the closing quote, then find its opener.
	inner := trimmed[:len(trimmed)-1]
	open := strings.LastIndexByte(inner, '"')
	if open < 0 {
		return ""
	}
	return inner[open+1:]
}
// accessLogQuotedFieldCount returns the number of complete double-quote pairs
// in line, i.e. half the number of quote characters, rounded down.
func accessLogQuotedFieldCount(line string) int {
	quotes := strings.Count(line, "\"")
	return quotes / 2
}
// CheckSSHLogins scans the last 100 lines of /var/log/secure for accepted SSH
// logins and emits a Critical ssh_login_unknown_ip finding for each one whose
// source address is neither an infra IP nor 127.0.0.1.
//
// Lines are expected to look like
// "Accepted publickey for root from 1.2.3.4 port 12345": the address is the
// token after "from" and the user is the token after "for" ("unknown" when
// absent). Lines without a token after "from" are skipped.
func CheckSSHLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	// tokenAfter returns the token following the first keyword that has one.
	tokenAfter := func(tokens []string, keyword string) (string, bool) {
		for i := 0; i+1 < len(tokens); i++ {
			if tokens[i] == keyword {
				return tokens[i+1], true
			}
		}
		return "", false
	}

	var findings []alert.Finding
	for _, entry := range tailFile("/var/log/secure", 100) {
		if !strings.Contains(entry, "Accepted") {
			continue
		}
		tokens := strings.Fields(entry)

		src, ok := tokenAfter(tokens, "from")
		if !ok {
			continue
		}
		if isInfraIP(src, cfg.InfraIPs) || src == "127.0.0.1" {
			continue
		}

		account, ok := tokenAfter(tokens, "for")
		if !ok {
			account = "unknown"
		}

		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "ssh_login_unknown_ip",
			Message:  fmt.Sprintf("SSH login from non-infra IP: %s (user: %s)", src, account),
			Details:  truncateString(entry, 200),
		})
	}
	return findings
}
// tailFile returns up to the last maxLines lines of the file at path. Files
// under 1 MiB are read in full; larger files are read through a bounded
// window at the end of the file (see readTailWindow), falling back to a full
// read if the window read fails. It returns nil when maxLines is not
// positive or the file cannot be opened or stat'ed.
func tailFile(path string, maxLines int) []string {
	if maxLines <= 0 {
		return nil
	}
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	st, err := f.Stat()
	if err != nil {
		return nil
	}

	const smallFileBytes = 1024 * 1024
	size := st.Size()
	if size < smallFileBytes {
		// For small files, just read all
		return readAllLines(f, maxLines)
	}

	window, err := readTailWindow(f, size, maxLines, maxTailWindowBytes)
	var src io.Reader = bytes.NewReader(window)
	if err != nil {
		src = f
	}
	return readAllLines(src, maxLines)
}
// readTailWindow reads backward from the end of f (whose length is size) in
// chunks of up to 256 KiB until it has seen more than maxLines newlines,
// reached the start of the file, or consumed maxBytes. It returns the bytes
// read, in file order.
//
// When the window does not reach the start of the file, the leading partial
// line is dropped so the result begins on a line boundary; if the window then
// holds no newline at all, (nil, nil) is returned. A non-positive maxBytes
// also yields (nil, nil). Read errors other than io.EOF are returned.
func readTailWindow(f *os.File, size int64, maxLines int, maxBytes int64) ([]byte, error) {
	const chunkSize int64 = 256 * 1024
	if maxBytes <= 0 {
		return nil, nil
	}

	var (
		pos          = size
		seenNewlines int
		consumed     int64
		total        int
		pieces       = make([][]byte, 0, 4)
	)
	for pos > 0 && seenNewlines <= maxLines && consumed < maxBytes {
		// Clamp the read to the chunk size, the file start and the budget.
		want := chunkSize
		if pos < want {
			want = pos
		}
		if budget := maxBytes - consumed; budget < want {
			want = budget
		}
		pos -= want

		buf := make([]byte, want)
		got, err := f.ReadAt(buf, pos)
		if err != nil && !errors.Is(err, io.EOF) {
			return nil, err
		}
		buf = buf[:got]

		consumed += int64(got)
		seenNewlines += bytes.Count(buf, []byte{'\n'})
		total += len(buf)
		pieces = append(pieces, buf)
	}

	// Pieces were collected back-to-front; stitch them in file order.
	data := make([]byte, 0, total)
	for i := len(pieces) - 1; i >= 0; i-- {
		data = append(data, pieces[i]...)
	}

	if pos <= 0 {
		return data, nil
	}
	cut := bytes.IndexByte(data, '\n')
	if cut < 0 {
		return nil, nil
	}
	return data[cut+1:], nil
}
// Size limits shared by the periodic log tail readers.
const (
	// maxLogLineBytes is the per-line cap for periodic log tailers.
	// Oversized records are skipped after the reader advances past the
	// terminator so a crafted long line cannot poison the next record.
	maxLogLineBytes = 256 * 1024 // 256 KiB
	// maxTailWindowBytes bounds the backward seek window before line
	// parsing starts. Without this, a huge unterminated final record makes
	// the tail reader cache the whole file while looking for maxLines.
	maxTailWindowBytes int64 = 32 * 1024 * 1024 // 32 MiB
)
// readAllLines reads r to the end and returns at most the last maxLines
// lines with their line endings removed. Empty reads and lines longer than
// maxLogLineBytes are dropped. It returns nil when maxLines is not positive.
func readAllLines(r io.Reader, maxLines int) []string {
	if maxLines <= 0 {
		return nil
	}
	reader := bufio.NewReaderSize(r, 64*1024)

	var kept []string
	for done := false; !done; {
		text, oversized, err := readBoundedLineLog(reader, maxLogLineBytes)
		if !oversized && len(text) > 0 {
			kept = append(kept, trimLogLineEnding(text))
		}
		// Any error (including io.EOF) ends the read after this record.
		done = err != nil
	}

	if extra := len(kept) - maxLines; extra > 0 {
		return kept[extra:]
	}
	return kept
}
// trimLogLineEnding removes one trailing "\n" and then one trailing "\r" from
// line, so both LF and CRLF terminators are stripped.
func trimLogLineEnding(line string) string {
	if n := len(line); n > 0 && line[n-1] == '\n' {
		line = line[:n-1]
	}
	if n := len(line); n > 0 && line[n-1] == '\r' {
		line = line[:n-1]
	}
	return line
}
// readBoundedLineLog reads up to and including the next '\n' from r. It
// returns the line, whether it was truncated, and the reader's error.
//
// A line longer than maxBytes is cut to maxBytes and reported as truncated;
// the reader is still advanced past the line's terminating newline so the
// next call starts on a fresh record. The error is whatever the final
// underlying ReadSlice call returned (nil for a complete line, io.EOF at end
// of input), which matches bufio.Reader.ReadString.
func readBoundedLineLog(r *bufio.Reader, maxBytes int) (string, bool, error) {
	var (
		out      strings.Builder
		overflow bool
	)
	for {
		piece, err := r.ReadSlice('\n')
		// After an overflow the remaining pieces are only drained.
		if len(piece) > 0 && !overflow {
			if out.Len()+len(piece) <= maxBytes {
				out.Write(piece)
			} else {
				if spare := maxBytes - out.Len(); spare > 0 {
					out.Write(piece[:spare])
				}
				overflow = true
			}
		}
		// ErrBufferFull means the line continues past the reader's buffer.
		if !errors.Is(err, bufio.ErrBufferFull) {
			return out.String(), overflow, err
		}
	}
}
// truncateString returns s unchanged when it is at most maxLen bytes long;
// otherwise it returns a prefix of at most maxLen bytes followed by "...".
//
// The cut never splits a multi-byte UTF-8 sequence: if the byte at the cut
// point is a continuation byte, the cut moves back (by at most three bytes)
// to the start of that rune, so valid input always yields valid UTF-8. Input
// that is not valid UTF-8 at the cut point is cut at exactly maxLen bytes. A
// negative maxLen is treated as 0.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// isContinuation reports whether b is a UTF-8 continuation byte (10xxxxxx).
	isContinuation := func(b byte) bool { return b&0xC0 == 0x80 }

	cut := maxLen
	for cut > 0 && maxLen-cut < 3 && isContinuation(s[cut]) {
		cut--
	}
	if isContinuation(s[cut]) {
		// No rune start within reach: not valid UTF-8 here, keep the byte cut.
		cut = maxLen
	}
	return s[:cut] + "..."
}
package checks
import (
"bytes"
"fmt"
"os"
"os/user"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"syscall"
)
// WPCronFixOptions carries operator-tunable parameters for the WP-Cron
// remediation. Both the Web UI handler and the daemon auto-response resolve
// these from config before calling the fix, so the remediation core itself
// stays free of config coupling. The zero value asks for the defaults
// described on each field.
type WPCronFixOptions struct {
	// IntervalMinutes is how often the installed system cron runs wp-cron.php.
	// Clamped to [1,60]; a non-positive value falls back to the 5-minute default.
	IntervalMinutes int
	// PHPBin is the interpreter the cron line invokes. Empty means "detect":
	// LookPath("php") first, then the cPanel default /usr/local/bin/php.
	PHPBin string
}
// Defaults and marker strings used by the WP-Cron remediation.
const (
	// wpCronDefaultIntervalMin is the cron interval, in minutes, used when
	// the caller does not supply a positive one.
	wpCronDefaultIntervalMin = 5
	// wpCronMaxIntervalMin is the largest accepted cron interval, in minutes.
	wpCronMaxIntervalMin = 60
	// wpCronEditMarker tags the line CSM inserts so the customer can see why
	// WP-Cron was disabled and so re-running the fix stays idempotent.
	wpCronEditMarker = "// CSM: WP-Cron disabled, served by system cron instead"
	// wpCronJobMarker prefixes the managed crontab block for a given docroot.
	wpCronJobMarker = "# CSM WP-Cron "
	// wpCronStopMarker is a text fragment searched for in wp-config.php.
	// NOTE(review): presumably the stock "stop editing" comment that marks
	// where custom defines must end - confirm against the code that uses it.
	wpCronStopMarker = "stop editing"
)
// wpCronDefineRe matches a define of DISABLE_WP_CRON set to a truthy value
// (true or 1, optionally quoted, case-insensitive), matching the detector's
// view of "already disabled".
var wpCronDefineRe = regexp.MustCompile(`(?i)define\s*\(\s*['"]DISABLE_WP_CRON['"]\s*,\s*['"]?(true|1)['"]?\s*\)`)

// validCPUser guards the username passed to `crontab -u`. It accepts 1 to 32
// characters: a lowercase letter or underscore first, then lowercase letters,
// digits, underscores or hyphens. Rejecting anything else keeps a surprising
// file owner from reaching the crontab argument vector.
var validCPUser = regexp.MustCompile(`^[a-z_][a-z0-9_-]{0,31}$`)

// wpCronHeredocStartRe matches the opener of a PHP heredoc/nowdoc (`<<<`
// followed by an optionally quoted identifier) and captures the identifier.
var wpCronHeredocStartRe = regexp.MustCompile(`<<<['"]?([A-Za-z_][A-Za-z0-9_]*)['"]?`)

// Crontab installs are read-modify-write; serialize each account in-process.
var wpCronCrontabLocks sync.Map

// wp-config.php edits are read-modify-write too. Deep and periodic scans can
// overlap, so serialize each config path before resolving, reading, and writing.
var wpCronConfigLocks sync.Map

// wpCronOwnerName resolves the account that owns a wp-config.php. It is a var
// so tests can inject a deterministic owner regardless of who runs `go test`.
var wpCronOwnerName = fileOwnerName
// FixDisableWPCron turns off WordPress's built-in WP-Cron in the given
// wp-config.php and installs a per-user system cron entry that runs
// wp-cron.php on a fixed interval instead. It is FixDisableWPCronInRoots
// restricted to the package's default allowed roots (fixPerfAllowedRoots).
func FixDisableWPCron(path string, opts WPCronFixOptions) RemediationResult {
	roots := fixPerfAllowedRoots
	return FixDisableWPCronInRoots(path, roots, opts)
}
// FixDisableWPCronInRoots is FixDisableWPCron with caller-supplied roots so the
// Web UI can honor configured account_roots and tests can write under
// t.TempDir().
//
// Steps, all under a per-path lock: validate and resolve the path, refuse
// anything that is a directory or not named wp-config.php, prepare the
// DISABLE_WP_CRON edit if no active define exists, install the owner's system
// cron, and only then write the edited wp-config.php. Failures are reported
// through RemediationResult.Error; a run that changes nothing still reports
// Success.
func FixDisableWPCronInRoots(path string, allowedRoots []string, opts WPCronFixOptions) RemediationResult {
	fail := func(msg string) RemediationResult {
		return RemediationResult{Error: msg}
	}

	if path == "" {
		return fail("could not extract file path from finding")
	}
	lockPath, err := sanitizeFixPath(path, allowedRoots)
	if err != nil {
		return fail(err.Error())
	}

	// Hold the per-config lock across resolve, read, and write.
	mu := wpCronConfigLock(lockPath)
	mu.Lock()
	defer mu.Unlock()

	resolved, info, err := resolveExistingFixPath(lockPath, allowedRoots)
	if err != nil {
		return fail(err.Error())
	}
	if info.IsDir() {
		return fail("refusing to edit a directory")
	}
	if base := filepath.Base(resolved); base != "wp-config.php" {
		return fail(fmt.Sprintf("automated WP-Cron fix only applies to wp-config.php (got %s)", base))
	}

	data, err := readFilePreservingIdentity(resolved, info)
	if err != nil {
		return fail(fmt.Sprintf("read failed: %v", err))
	}

	// Build the edited file up front so an unsafe wp-config.php is rejected
	// before anything is installed.
	var rewritten []byte
	needsDefine := !wpCronHasActiveDisableDefine(data)
	if needsDefine {
		patched, ok := insertDisableWPCron(data)
		if !ok {
			return fail("could not find a safe insertion point in wp-config.php (no \"stop editing\" marker or wp-settings.php require)")
		}
		rewritten = patched
	}

	docroot := filepath.Dir(resolved)
	owner, err := wpCronOwnerName(info)
	if err != nil {
		return fail(fmt.Sprintf("could not resolve account owner of wp-config.php: %v", err))
	}

	cronInstalled, err := installUserWPCron(owner, docroot, opts)
	if err != nil {
		return fail(fmt.Sprintf("system cron install failed: %v", err))
	}

	var actions []string
	if needsDefine {
		if werr := writeFilePreservingOwner(resolved, rewritten, info); werr != nil {
			if cronInstalled {
				return fail(fmt.Sprintf("system cron installed but wp-config.php update failed: %v", werr))
			}
			return fail(werr.Error())
		}
		actions = append(actions, "disabled WP-Cron in wp-config.php")
	}
	if cronInstalled {
		actions = append(actions, fmt.Sprintf("installed every-%d-minute system cron for %s", clampInterval(opts.IntervalMinutes), owner))
	}

	if len(actions) > 0 {
		return RemediationResult{
			Success:     true,
			Action:      fmt.Sprintf("disable WP-Cron + install system cron for %s", docroot),
			Description: strings.Join(actions, "; "),
		}
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("wp-cron already configured for %s", docroot),
		Description: "WP-Cron already disabled and system cron already present; no change needed",
	}
}
// insertDisableWPCron returns wp-config.php bytes with the DISABLE_WP_CRON
// define inserted before the "stop editing" marker, or before the
// wp-settings.php require as a fallback. The second return is false when no
// safe insertion point exists, so the caller can refuse rather than append a
// define into an unfamiliar PHP file.
func insertDisableWPCron(data []byte) ([]byte, bool) {
lines := bytes.Split(data, []byte("\n"))
defineLine := []byte("define( 'DISABLE_WP_CRON', true ); " + wpCronEditMarker)
insertAt := wpCronInsertionLine(lines)
if insertAt < 0 {
return nil, false
}
out := make([][]byte, 0, len(lines)+1)
out = append(out, lines[:insertAt]...)
out = append(out, defineLine)
out = append(out, lines[insertAt:]...)
return bytes.Join(out, []byte("\n")), true
}
// wpCronHasActiveDisableDefine reports whether data contains a truthy
// DISABLE_WP_CRON define in live PHP code. Each line is first reduced by
// wpCronActivePHPCode, so defines inside comments or heredoc bodies do not
// count.
func wpCronHasActiveDisableDefine(data []byte) bool {
	var (
		inComment bool
		heredoc   string
	)
	for _, raw := range strings.Split(string(data), "\n") {
		if active := wpCronActivePHPCode(raw, &inComment, &heredoc); wpCronDefineRe.MatchString(active) {
			return true
		}
	}
	return false
}
func wpCronInsertionLine(lines [][]byte) int {
inBlockComment := false
heredocLabel := ""
fallback := -1
for i, line := range lines {
safeAtLineStart := !inBlockComment && heredocLabel == ""
code := wpCronActivePHPCode(string(line), &inBlockComment, &heredocLabel)
if !safeAtLineStart {
continue
}
if bytes.Contains(bytes.ToLower(line), []byte(wpCronStopMarker)) {
return i
}
if fallback < 0 && strings.Contains(code, "wp-settings.php") {
fallback = i
}
}
return fallback
}
// wpCronActivePHPCode returns the part of one source line that is live PHP
// code, dropping comment text and heredoc bodies. It is a line-at-a-time
// scanner: inBlockComment and heredocLabel carry multi-line state between
// calls and are updated in place.
//
//   - While *heredocLabel is set the line is treated as heredoc body: the
//     result is "" and the label is cleared once the terminator line is seen.
//   - Text inside /* ... */ is skipped; an unterminated block comment leaves
//     *inBlockComment true for the next line.
//   - A // or # outside a string ends the scan for the line.
//   - Quoted strings are copied through verbatim, quotes included. Quote
//     state is local to the call, so it does not carry across lines.
//   - If the resulting code contains a heredoc opener, *heredocLabel is set to
//     its label for the following lines.
func wpCronActivePHPCode(line string, inBlockComment *bool, heredocLabel *string) string {
// Inside a heredoc body nothing on the line counts as code.
if *heredocLabel != "" {
if wpCronEndsHeredoc(line, *heredocLabel) {
*heredocLabel = ""
}
return ""
}
var out strings.Builder
quote := byte(0) // active quote character, 0 when outside a string
escaped := false // true when the previous byte inside a string was a backslash
for i := 0; i < len(line); i++ {
if *inBlockComment {
// Discard bytes until the closing */ and step over both of its characters.
if i+1 < len(line) && line[i] == '*' && line[i+1] == '/' {
*inBlockComment = false
i++
}
continue
}
c := line[i]
if quote != 0 {
// String contents are kept as-is so comment markers inside a string
// are not mistaken for comments.
out.WriteByte(c)
if escaped {
escaped = false
continue
}
if c == '\\' {
escaped = true
continue
}
if c == quote {
quote = 0
}
continue
}
if c == '\'' || c == '"' {
quote = c
out.WriteByte(c)
continue
}
// Start of a block comment: skip both characters of the opener.
if i+1 < len(line) && c == '/' && line[i+1] == '*' {
*inBlockComment = true
i++
continue
}
// Line comments (// or #) hide the rest of the line.
if i+1 < len(line) && c == '/' && line[i+1] == '/' {
break
}
if c == '#' {
break
}
out.WriteByte(c)
}
code := out.String()
// A heredoc opener on this line means the following lines are heredoc body
// until the label appears on its own.
if match := wpCronHeredocStartRe.FindStringSubmatch(code); len(match) == 2 {
*heredocLabel = match[1]
}
return code
}
// wpCronEndsHeredoc reports whether line closes the heredoc opened with
// label: after trimming surrounding whitespace the line must be exactly the
// label, optionally followed by a single semicolon.
func wpCronEndsHeredoc(line, label string) bool {
	switch strings.TrimSpace(line) {
	case label, label + ";":
		return true
	default:
		return false
	}
}
// installUserWPCron ensures the owner's crontab contains a CSM-managed line
// running wp-cron.php for docroot. It returns false (no error) when a matching
// job is already present, and true after a new crontab has been installed.
//
// The crontab is rewritten via a spool file because the command runner has no
// stdin channel; `crontab -u <user> <file>` installs and validates it
// atomically. Installing replaces the whole crontab, so the existing content
// must be read successfully first: if listing it fails, the function returns
// an error instead of continuing with an empty baseline, which would discard
// every job the account already has.
//
// NOTE(review): this relies on RunAllowNonZero reporting a non-zero exit
// (for example "no crontab for user") with a nil error, as its name
// suggests -- confirm against the command runner.
func installUserWPCron(owner, docroot string, opts WPCronFixOptions) (bool, error) {
	if !validCPUser.MatchString(owner) {
		return false, fmt.Errorf("refusing crontab edit for unexpected account name %q", owner)
	}

	// Read-modify-write of one account's crontab must not interleave.
	lock := wpCronCrontabLock(owner)
	lock.Lock()
	defer lock.Unlock()

	out, err := cmdExec.RunAllowNonZero("crontab", "-u", owner, "-l")
	if err != nil {
		return false, fmt.Errorf("read existing crontab: %w", err)
	}
	existing := string(out)
	if crontabHasWPCronJob(existing, docroot) {
		return false, nil
	}

	// Built only when an install is actually needed; it may look up the PHP
	// interpreter on PATH.
	want := wpCronJobLine(docroot, opts)

	var buf bytes.Buffer
	buf.WriteString(strings.TrimRight(existing, "\n"))
	if buf.Len() > 0 {
		buf.WriteByte('\n')
	}
	buf.WriteString(wpCronJobMarker + docroot + "\n")
	buf.WriteString(want + "\n")

	tmp, err := os.CreateTemp("", "csm-wpcron-*")
	if err != nil {
		return false, fmt.Errorf("create crontab spool: %w", err)
	}
	tmpPath := tmp.Name()
	// Best-effort cleanup; the staged file is not needed after the install.
	defer func() { _ = os.Remove(tmpPath) }()
	if _, err := tmp.Write(buf.Bytes()); err != nil {
		_ = tmp.Close()
		return false, fmt.Errorf("write crontab spool: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return false, fmt.Errorf("close crontab spool: %w", err)
	}

	// Register the expected content before the install so the file watchers
	// do not flag CSM's own write, and withdraw it again if the install fails.
	expected := append([]byte(nil), buf.Bytes()...)
	preRecordCrontabSelfWrite(owner, expected)
	if _, err := cmdExec.Run("crontab", "-u", owner, tmpPath); err != nil {
		forgetSelfWrites(crontabSpoolPaths(owner)...)
		return false, fmt.Errorf("crontab install: %w", err)
	}
	recordCrontabSelfWrite(owner, expected)
	return true, nil
}
// preRecordCrontabSelfWrite registers expected as a pending CSM self-write for
// every candidate spool path of owner, before the crontab is installed.
func preRecordCrontabSelfWrite(owner string, expected []byte) {
	spool := crontabSpoolPaths(owner)
	for i := range spool {
		RecordSelfWrite(spool[i], expected)
	}
}
// recordCrontabSelfWrite registers the just-installed crontab with the
// self-write ledger so the sensitive-file detectors do not flag CSM's own
// change. The on-disk spool content (cron may normalize it) is what the
// detectors hash, so the spool file's bytes are recorded rather than the
// staged buffer.
//
// Only the first spool path that can be read is considered; it is recorded
// when its content matches expected. Every other candidate path, and the
// readable one too when it does not match, has its ledger entry forgotten.
func recordCrontabSelfWrite(owner string, expected []byte) {
	candidates := crontabSpoolPaths(owner)

	kept := ""
	for _, candidate := range candidates {
		onDisk, err := osFS.ReadFile(candidate)
		if err != nil {
			continue
		}
		if crontabContentEqual(onDisk, expected) {
			RecordSelfWrite(candidate, onDisk)
			kept = candidate
		}
		// Stop at the first readable spool file, match or not.
		break
	}

	for _, candidate := range candidates {
		if candidate == kept {
			continue
		}
		forgetSelfWrites(candidate)
	}
}
// crontabSpoolPaths lists the locations where owner's crontab spool file may
// live: directly under /var/spool/cron, or under /var/spool/cron/crontabs.
// The order is fixed and callers probe the paths in that order.
func crontabSpoolPaths(owner string) []string {
	dirs := [...]string{"/var/spool/cron", "/var/spool/cron/crontabs"}
	paths := make([]string, 0, len(dirs))
	for _, dir := range dirs {
		paths = append(paths, filepath.Join(dir, owner))
	}
	return paths
}
// crontabContentEqual compares an on-disk crontab with the content CSM
// staged. CRLF line endings are normalized on both sides, and got is also
// accepted when it equals want with want's single trailing newline removed.
func crontabContentEqual(got, want []byte) bool {
	onDisk := normalizeCrontabLineEndings(got)
	staged := normalizeCrontabLineEndings(want)
	if bytes.Equal(onDisk, staged) {
		return true
	}
	if !bytes.HasSuffix(staged, []byte("\n")) {
		return false
	}
	return bytes.Equal(onDisk, staged[:len(staged)-1])
}
// normalizeCrontabLineEndings returns a copy of data with every CRLF pair
// replaced by a single LF.
func normalizeCrontabLineEndings(data []byte) []byte {
	crlf, lf := []byte("\r\n"), []byte("\n")
	return bytes.Replace(data, crlf, lf, -1)
}
// wpCronJobLine builds the crontab entry for docroot: every N minutes, cd into
// the docroot and run wp-cron.php with the PHP CLI, discarding output. CLI php
// is used (not an HTTP hit) so the job does not tie up a web worker, which is
// the load source the finding flags. max_execution_time caps a runaway cron
// pass. The interpreter is opts.PHPBin when set, otherwise the detected one;
// both the docroot and the interpreter path are shell-quoted.
func wpCronJobLine(docroot string, opts WPCronFixOptions) string {
	everyMin := clampInterval(opts.IntervalMinutes)

	phpBin := opts.PHPBin
	if phpBin == "" {
		phpBin = detectPHPBin()
	}

	const layout = "*/%d * * * * cd %s && %s -d max_execution_time=300 wp-cron.php >/dev/null 2>&1"
	return fmt.Sprintf(layout, everyMin, shellQuote(docroot), shellQuote(phpBin))
}
// crontabHasWPCronJob reports whether any job line in crontab already runs
// wp-cron.php for docroot. Interval and PHP path are ignored, so re-running
// the fix is a no-op once a suitable job exists. Comment, blank, and
// environment-assignment lines yield no command and are skipped.
func crontabHasWPCronJob(crontab, docroot string) bool {
	cleanRoot := filepath.Clean(docroot)
	for _, entry := range strings.Split(crontab, "\n") {
		if cmd := crontabCommand(entry); cmd != "" && commandRunsWPCronForDocroot(cmd, cleanRoot) {
			return true
		}
	}
	return false
}
// wpCronCrontabLock returns the process-wide mutex that serializes crontab
// edits for owner, creating it on first use.
func wpCronCrontabLock(owner string) *sync.Mutex {
	if held, ok := wpCronCrontabLocks.Load(owner); ok {
		return held.(*sync.Mutex)
	}
	// LoadOrStore settles a race between two first-time callers: both end up
	// with whichever mutex was stored first.
	held, _ := wpCronCrontabLocks.LoadOrStore(owner, new(sync.Mutex))
	return held.(*sync.Mutex)
}
// wpCronConfigLock returns the process-wide mutex that serializes edits of the
// wp-config.php at path, creating it on first use.
func wpCronConfigLock(path string) *sync.Mutex {
	if held, ok := wpCronConfigLocks.Load(path); ok {
		return held.(*sync.Mutex)
	}
	// LoadOrStore settles a race between two first-time callers: both end up
	// with whichever mutex was stored first.
	held, _ := wpCronConfigLocks.LoadOrStore(path, new(sync.Mutex))
	return held.(*sync.Mutex)
}
// crontabCommand extracts the command part of one crontab line. It returns ""
// for blank lines, comments, environment assignments (first field contains
// "="), and lines with too few fields. For "@keyword" schedules the command is
// everything after the keyword; otherwise it is everything after the five
// space- or tab-separated schedule fields.
func crontabCommand(line string) string {
	entry := strings.TrimSpace(line)
	if entry == "" || entry[0] == '#' {
		return ""
	}

	tokens := strings.Fields(entry)
	if len(tokens) == 0 {
		return ""
	}
	first := tokens[0]
	switch {
	case strings.Contains(first, "="):
		return ""
	case strings.HasPrefix(first, "@"):
		if len(tokens) < 2 {
			return ""
		}
		// entry has no leading whitespace, so it begins with first.
		return strings.TrimSpace(entry[len(first):])
	case len(tokens) < 6:
		return ""
	}

	// Walk past the five schedule fields on the raw text so the command keeps
	// its original spacing and quoting.
	remainder := entry
	for skipped := 0; skipped < 5; skipped++ {
		remainder = strings.TrimLeft(remainder, " \t")
		end := strings.IndexAny(remainder, " \t")
		if end < 0 {
			return ""
		}
		remainder = remainder[end:]
	}
	return strings.TrimSpace(remainder)
}
// commandRunsWPCronForDocroot reports whether a crontab command runs
// wp-cron.php for docroot. Two shapes are recognized after splitting the
// command into shell words:
//
//   - some word, cleaned, is exactly <docroot>/wp-cron.php;
//   - a "cd" whose first non-option argument cleans to docroot is followed
//     later by a word whose base name is wp-cron.php.
//
// docroot is compared as given, so callers pass it already cleaned.
func commandRunsWPCronForDocroot(command, docroot string) bool {
	const script = "wp-cron.php"
	if !strings.Contains(command, script) {
		return false
	}

	tokens := shellWords(command)

	// Shape 1: the script is referenced by its full path.
	target := filepath.Clean(filepath.Join(docroot, script))
	for i := range tokens {
		if cleanShellPathWord(tokens[i]) == target {
			return true
		}
	}

	// Shape 2: "cd <docroot>" followed by a relative or other reference.
	for i := range tokens {
		if tokens[i] != "cd" {
			continue
		}
		dirAt := i + 1
		for dirAt < len(tokens) && strings.HasPrefix(tokens[dirAt], "-") {
			dirAt++ // skip cd options
		}
		if dirAt >= len(tokens) || filepath.Clean(tokens[dirAt]) != docroot {
			continue
		}
		for _, later := range tokens[dirAt+1:] {
			if filepath.Base(cleanShellPathWord(later)) == script {
				return true
			}
		}
	}
	return false
}
// cleanShellPathWord turns one shell word into a comparable path: anything
// from the first '?' or '#' onward is dropped, and the remainder is passed
// through filepath.Clean.
func cleanShellPathWord(word string) string {
	cut := strings.IndexAny(word, "?#")
	if cut < 0 {
		return filepath.Clean(word)
	}
	return filepath.Clean(word[:cut])
}
// shellWords splits a shell command into words using a small subset of shell
// quoting rules:
//
//   - outside quotes, a backslash makes the next byte literal;
//   - single quotes keep everything literal, backslashes included;
//   - inside double quotes, a backslash makes the next byte literal;
//   - unquoted space, tab, and the operators ; & | ( ) < > end the current
//     word and are themselves dropped.
//
// Empty words are never emitted, and a dangling trailing backslash is kept as
// a literal backslash.
func shellWords(command string) []string {
	const separators = " \t;&|()<>"

	var (
		words   []string
		word    strings.Builder
		quote   byte // open quote character, 0 when unquoted
		escaped bool // previous byte was an active backslash
	)
	emit := func() {
		if word.Len() > 0 {
			words = append(words, word.String())
			word.Reset()
		}
	}

	for _, c := range []byte(command) {
		switch {
		case escaped:
			// Whatever follows an active backslash is taken literally.
			word.WriteByte(c)
			escaped = false
		case quote == '"' && c == '\\':
			escaped = true
		case quote != 0 && c == quote:
			quote = 0
		case quote != 0:
			word.WriteByte(c)
		case c == '\\':
			escaped = true
		case c == '\'' || c == '"':
			quote = c
		case strings.IndexByte(separators, c) >= 0:
			emit()
		default:
			word.WriteByte(c)
		}
	}

	if escaped {
		word.WriteByte('\\')
	}
	emit()
	return words
}
// clampInterval normalizes a configured cron interval in minutes: zero or
// negative values select wpCronDefaultIntervalMin, values above
// wpCronMaxIntervalMin are capped to it, and anything else is returned as is.
func clampInterval(minutes int) int {
	switch {
	case minutes <= 0:
		return wpCronDefaultIntervalMin
	case minutes > wpCronMaxIntervalMin:
		return wpCronMaxIntervalMin
	default:
		return minutes
	}
}
// detectPHPBin returns the PHP interpreter to use when none is configured:
// the "php" found through the command runner's LookPath, or
// /usr/local/bin/php when the lookup fails or returns an empty path.
func detectPHPBin() string {
	found, err := cmdExec.LookPath("php")
	if err != nil || found == "" {
		return "/usr/local/bin/php"
	}
	return found
}
// fileOwnerName resolves the username that owns the wp-config.php so the cron
// runs as the account, not root. The owner is the source of truth for which
// account this WordPress install belongs to.
//
// It fails when info carries no *syscall.Stat_t, when the file is owned by
// uid 0, or when the uid has no user database entry.
func fileOwnerName(info os.FileInfo) (string, error) {
	st, isStat := info.Sys().(*syscall.Stat_t)
	switch {
	case !isStat:
		return "", fmt.Errorf("unsupported file info")
	case st.Uid == 0:
		// A customer wp-config.php should never be root-owned; installing a
		// cron that runs wp-cron.php as root would be a privilege smell.
		return "", fmt.Errorf("refusing to install a root-owned cron for wp-config.php")
	}

	uid := strconv.FormatUint(uint64(st.Uid), 10)
	account, err := user.LookupId(uid)
	if err != nil {
		return "", fmt.Errorf("uid %s: %v", uid, err)
	}
	return account.Username, nil
}
package config
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"syscall"
"gopkg.in/yaml.v3"
)
// confFragment is one parsed conf.d drop-in.
type confFragment struct {
// path is the fragment's file path inside the resolved conf.d directory.
path string
// node is the YAML node produced by parsing the fragment.
node *yaml.Node
}
// ValidateConfDir vets an operator-selected conf.d directory before any YAML
// fragments are loaded from it. An empty dir means "no conf.d" and yields
// ("", nil). Otherwise dir must be absolute, must exist, must resolve to a
// directory, and must pass the ownership/mode trust check. The returned path
// is symlink-resolved so later reads do not depend on a mutable link name.
func ValidateConfDir(dir string) (string, error) {
	switch {
	case dir == "":
		return "", nil
	case !filepath.IsAbs(dir):
		return "", fmt.Errorf("conf.d directory must be an absolute path, got %q", dir)
	}

	resolved, err := filepath.EvalSymlinks(dir)
	switch {
	case err == nil:
	case errors.Is(err, os.ErrNotExist):
		// Unlike the loader, an explicitly chosen directory must exist.
		return "", fmt.Errorf("conf.d directory does not exist: %s", dir)
	default:
		return "", fmt.Errorf("conf.d directory symlink resolution: %w", err)
	}

	info, err := os.Stat(resolved)
	if err != nil {
		return "", fmt.Errorf("conf.d directory stat: %w", err)
	}
	if !info.IsDir() {
		return "", fmt.Errorf("conf.d directory is not a directory: %s", resolved)
	}
	if err := validateConfPathTrust("conf.d directory", resolved, info); err != nil {
		return "", err
	}
	return resolved, nil
}
// LoadConfDir reads every *.yaml and *.yml file in dir in lexicographic order
// and returns each non-empty one as a parsed YAML node. A missing directory is
// not an error and returns an empty slice; an unreadable file or invalid YAML
// is fatal so operators see misconfigurations at startup.
func LoadConfDir(dir string) ([]*yaml.Node, error) {
	frags, err := loadConfDirFragments(dir)
	if err != nil {
		return nil, err
	}
	nodes := make([]*yaml.Node, len(frags))
	for i := range frags {
		nodes[i] = frags[i].node
	}
	return nodes, nil
}
// ConfDirFragment is one conf.d drop-in fragment's filename and raw bytes,
// exported so the integrity hasher can cover the same fragment set the loader
// merges without duplicating the enumeration rules.
type ConfDirFragment struct {
// Name is the fragment's file name (directory entry name, not a full path).
Name string
// Data is the fragment's raw, unparsed file content.
Data []byte
}
// ConfDirFragmentDigestInput returns every non-empty trusted conf.d fragment in
// merge order (sorted .yaml/.yml, symlink-resolved, trust-validated) as
// name+content pairs for integrity hashing. An empty dir or no mergeable
// fragments yields nil so a config without conf.d hashes to the empty digest
// and its baseline is unaffected.
//
// Fragments that parse to an empty document are skipped, exactly as the loader
// skips them. A fragment that fails to parse, or that is otherwise rejected by
// parseConfFragment, aborts with that error.
func ConfDirFragmentDigestInput(dir string) ([]ConfDirFragment, error) {
	files, err := confDirFragmentFiles(dir)
	if err != nil {
		return nil, err
	}
	if len(files) == 0 {
		return nil, nil
	}
	out := make([]ConfDirFragment, 0, len(files))
	for _, ff := range files {
		_, ok, err := parseConfFragment(ff)
		if err != nil {
			return nil, err
		}
		if !ok {
			continue
		}
		out = append(out, ConfDirFragment{Name: ff.name, Data: ff.data})
	}
	// A directory holding only empty fragments contributes nothing to the
	// merge, so report it the same way as "no conf.d at all": nil, not an
	// empty non-nil slice.
	if len(out) == 0 {
		return nil, nil
	}
	return out, nil
}
// confFragmentFile is one conf.d fragment as read from disk, before parsing.
type confFragmentFile struct {
// name is the directory entry name of the fragment.
name string
// path is name joined onto the symlink-resolved conf.d directory.
path string
// data is the raw file content.
data []byte
}
// confDirFragmentFiles enumerates trusted conf.d fragment files and returns
// their raw bytes in merge order (file names sorted lexicographically). It is
// shared by loadConfDirFragments and the integrity hasher so both observe
// exactly the same fragment set.
//
// An empty dir, or a directory that does not exist, yields (nil, nil). The
// directory is symlink-resolved and trust-checked first; sub-directories and
// entries without a .yaml/.yml suffix are ignored; any fragment that cannot be
// read or fails its own trust check aborts the whole enumeration.
func confDirFragmentFiles(dir string) ([]confFragmentFile, error) {
	if dir == "" {
		return nil, nil
	}
	notFound := func(err error) bool { return errors.Is(err, fs.ErrNotExist) }

	resolved, err := filepath.EvalSymlinks(dir)
	switch {
	case err == nil:
	case notFound(err):
		return nil, nil
	default:
		return nil, fmt.Errorf("conf.d directory symlink resolution: %w", err)
	}

	info, err := os.Stat(resolved)
	switch {
	case err == nil:
	case notFound(err):
		return nil, nil
	default:
		return nil, fmt.Errorf("conf.d directory stat: %w", err)
	}
	if !info.IsDir() {
		return nil, fmt.Errorf("conf.d directory is not a directory: %s", resolved)
	}
	if err := validateConfPathTrust("conf.d directory", resolved, info); err != nil {
		return nil, err
	}

	entries, err := os.ReadDir(resolved)
	switch {
	case err == nil:
	case notFound(err):
		return nil, nil
	default:
		return nil, fmt.Errorf("reading %s: %w", resolved, err)
	}

	names := make([]string, 0, len(entries))
	for _, entry := range entries {
		name := entry.Name()
		isFragment := strings.HasSuffix(name, ".yaml") || strings.HasSuffix(name, ".yml")
		if entry.IsDir() || !isFragment {
			continue
		}
		names = append(names, name)
	}
	sort.Strings(names)

	files := make([]confFragmentFile, 0, len(names))
	for _, name := range names {
		full := filepath.Join(resolved, name)
		raw, err := readTrustedConfFragment(full)
		if err != nil {
			return nil, err
		}
		files = append(files, confFragmentFile{name: name, path: full, data: raw})
	}
	return files, nil
}
// loadConfDirFragments reads and parses every trusted fragment in dir, in
// merge order, keeping each fragment's path next to its parsed node. Empty
// fragments are dropped; the first read or parse error aborts the load.
func loadConfDirFragments(dir string) ([]confFragment, error) {
	files, err := confDirFragmentFiles(dir)
	if err != nil {
		return nil, err
	}
	parsed := make([]confFragment, 0, len(files))
	for i := range files {
		node, ok, err := parseConfFragment(files[i])
		switch {
		case err != nil:
			return nil, err
		case ok:
			parsed = append(parsed, confFragment{path: files[i].path, node: node})
		}
	}
	return parsed, nil
}
// parseConfFragment parses one fragment's bytes as YAML. The boolean result
// is false, with a nil node and nil error, when the fragment is empty and
// should simply be skipped. It returns an error for invalid YAML and for any
// fragment that sets a top-level "integrity" key, which is reserved for
// daemon-managed metadata.
func parseConfFragment(ff confFragmentFile) (*yaml.Node, bool, error) {
	doc := new(yaml.Node)
	if err := yaml.Unmarshal(ff.data, doc); err != nil {
		return nil, false, fmt.Errorf("parsing %s: %w", ff.path, err)
	}
	switch {
	case len(doc.Content) == 0:
		// Empty file: Unmarshal yields a document with no content.
		return nil, false, nil
	case hasTopLevelKey(doc, "integrity"):
		return nil, false, fmt.Errorf("conf.d fragment %s must not set daemon-managed integrity metadata", ff.path)
	}
	return doc, true, nil
}
// hasTopLevelKey reports whether the YAML tree rooted at root is a mapping
// with key among its top-level keys. A document node is unwrapped to its
// first child first; anything that is not a mapping, or an empty document,
// yields false.
func hasTopLevelKey(root *yaml.Node, key string) bool {
	mapping := root
	if root.Kind == yaml.DocumentNode {
		if len(root.Content) == 0 {
			return false
		}
		mapping = root.Content[0]
	}
	if mapping.Kind != yaml.MappingNode {
		return false
	}
	// Mapping content alternates key, value, key, value; look at the keys
	// only, consuming one pair per step.
	for pairs := mapping.Content; len(pairs) >= 2; pairs = pairs[2:] {
		if pairs[0].Value == key {
			return true
		}
	}
	return false
}
// readTrustedConfFragment opens one conf.d fragment and returns its bytes
// after checking, on the opened handle, that it is a regular file with
// acceptable ownership and mode. Content is read through
// readConfigBytesLimited; an oversized fragment is reported with the
// MaxConfigBytes cap in the message.
func readTrustedConfFragment(path string) ([]byte, error) {
	// #nosec G304 -- path is built from an operator-selected conf.d and a directory entry.
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", path, err)
	}
	defer f.Close()

	// Stat the open handle rather than the path, so the checks apply to the
	// same file that is read below.
	info, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat %s: %w", path, err)
	}
	if !info.Mode().IsRegular() {
		return nil, fmt.Errorf("conf.d fragment is not a regular file: %s", path)
	}
	if err := validateConfPathTrust("conf.d fragment", path, info); err != nil {
		return nil, err
	}

	data, err := readConfigBytesLimited(f)
	switch {
	case errors.Is(err, errConfigTooLarge):
		return nil, fmt.Errorf("conf.d fragment %s exceeds %d byte cap", path, MaxConfigBytes)
	case err != nil:
		return nil, fmt.Errorf("reading %s: %w", path, err)
	}
	return data, nil
}
// validateConfPathTrust decides whether a conf.d directory or fragment may be
// trusted. It rejects anything that is group- or world-writable, and anything
// owned by a uid other than root or the process's effective uid. kind and
// path are used only to build the error message. When info carries no
// *syscall.Stat_t the ownership check is skipped.
func validateConfPathTrust(kind, path string, info os.FileInfo) error {
	perm := info.Mode().Perm()
	if perm&0022 != 0 {
		return fmt.Errorf("%s %s has unsafe mode %04o (group or world writable); set 0750/0640 or stricter", kind, path, perm)
	}

	st, isStat := info.Sys().(*syscall.Stat_t)
	if !isStat {
		return nil
	}
	// #nosec G115 -- Linux uid_t is uint32; os.Geteuid returns the kernel
	// effective UID and cannot overflow that type on supported hosts.
	self := uint32(os.Geteuid())
	if st.Uid == 0 || st.Uid == self {
		return nil
	}
	return fmt.Errorf("%s %s owner uid=%d is neither root (0) nor process uid=%d; refusing to load untrusted YAML", kind, path, st.Uid, self)
}
package config
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"gopkg.in/yaml.v3"
"github.com/pidginhost/csm/internal/atomicio"
"github.com/pidginhost/csm/internal/firewall"
)
// WebUIToken is one entry in WebUI.Tokens. Scope must be "admin" or "read".
type WebUIToken struct {
// Name is read from the YAML key "name".
Name string `yaml:"name"`
// Token is read from the YAML key "token".
Token string `yaml:"token"`
// Scope is read from the YAML key "scope"; see the allowed values above.
Scope string `yaml:"scope"`
}
// DefaultMaxBlocksPerHour is the safe hourly cap used when the operator
// leaves auto_response.max_blocks_per_hour unset or sets it to 0, so an
// unset value never means "unlimited".
const DefaultMaxBlocksPerHour = 50
// MailLogsConfig controls how postfix/dovecot logs are read.
//
// source: auto - try file first; fall back to journal if file absent.
// source: file - require log file at the platform-default path.
// source: journal - read from systemd-journald (units must be set).
//
// Units is consulted for journal fallback: the daemon matches each
// systemd unit by name, appending ".service" for bare service names.
//
// File and Units are optional and are omitted from marshalled YAML when empty.
type MailLogsConfig struct {
Source string `yaml:"source"` // auto | file | journal
File string `yaml:"file,omitempty"` // override platform default
Units []string `yaml:"units,omitempty"` // for journal source
}
type Config struct {
ConfigFile string `yaml:"-"`
ConfigDir string `yaml:"-" hotreload:"restart"` // /etc/csm/conf.d (or operator override); empty means no drop-ins loaded
Hostname string `yaml:"hostname" hotreload:"restart"`
Alerts struct {
Email struct {
Enabled bool `yaml:"enabled"`
To []string `yaml:"to"`
From string `yaml:"from"`
SMTP string `yaml:"smtp"`
DisabledChecks []string `yaml:"disabled_checks"`
} `yaml:"email"`
Webhook struct {
Enabled bool `yaml:"enabled"`
URL string `yaml:"url"`
Type string `yaml:"type"` // slack, discord, generic, phpanel
// HMACSecret is the shared secret used to sign each request when
// Type=="phpanel". Read from this field directly OR via HMACSecretEnv
// (env wins, for secret hygiene).
HMACSecret string `yaml:"hmac_secret,omitempty"`
HMACSecretEnv string `yaml:"hmac_secret_env,omitempty"`
// PerFinding documents the expected phpanel delivery shape. Phpanel
// webhooks always emit one signed POST per finding; other webhook
// types keep the existing digest delivery.
PerFinding bool `yaml:"per_finding,omitempty"`
} `yaml:"webhook"`
Heartbeat struct {
Enabled bool `yaml:"enabled"`
URL string `yaml:"url"`
} `yaml:"heartbeat"`
MaxPerHour int `yaml:"max_per_hour"`
// AuditLog ships every (deduplicated) finding to one or more
// SIEM-friendly destinations. Schema is stable: parsers can
// pin on the v=1 contract. Both sub-blocks default off; the
// alert pipeline behaves identically to before when neither
// is enabled.
AuditLog struct {
File struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"` // default: /var/log/csm/audit.jsonl
} `yaml:"file"`
Syslog struct {
Enabled bool `yaml:"enabled"`
Network string `yaml:"network"` // udp | tcp | unix | unixgram | tls
Address string `yaml:"address"` // host:port or filesystem path
Facility string `yaml:"facility"` // default: local0
TLSCAFile string `yaml:"tls_ca"` // optional CA cert for tls
} `yaml:"syslog"`
} `yaml:"audit_log"`
} `yaml:"alerts" hotreload:"safe"`
Integrity struct {
BinaryHash string `yaml:"binary_hash"`
ConfigHash string `yaml:"config_hash"`
// ConfdHash covers the conf.d drop-in fragments merged on top of the
// main config. Empty when no fragments exist, so configs without
// conf.d stay byte-identical to their pre-existing baseline. Without
// it an attacker with write access to conf.d could override any
// setting (auto_response, block_ips, dry_run) without tripping
// integrity verification.
ConfdHash string `yaml:"confd_hash"`
Immutable bool `yaml:"immutable"`
} `yaml:"integrity"`
Thresholds struct {
MailQueueWarn int `yaml:"mail_queue_warn"`
MailQueueCrit int `yaml:"mail_queue_crit"`
StateExpiryHours int `yaml:"state_expiry_hours"`
DeepScanIntervalMin int `yaml:"deep_scan_interval_min"`
WPCoreCheckIntervalMin int `yaml:"wp_core_check_interval_min"`
WebshellScanIntervalMin int `yaml:"webshell_scan_interval_min"`
FilesystemScanIntervalMin int `yaml:"filesystem_scan_interval_min"`
MultiIPLoginThreshold int `yaml:"multi_ip_login_threshold"`
MultiIPLoginWindowMin int `yaml:"multi_ip_login_window_min"`
// CredStuffingDistinctAccounts is the number of distinct accounts a
// single source IP must fail against inside the multi_ip_login window
// to raise a credential_stuffing finding. This is the breadth signal
// (one source, many accounts) that the count-based pam_bruteforce
// detector does not catch. Default 5 when unset or <=0, matching the
// always-on posture of the sibling multi_ip_login_threshold.
CredStuffingDistinctAccounts int `yaml:"cred_stuffing_distinct_accounts"`
PluginCheckIntervalMin int `yaml:"plugin_check_interval_min"`
BruteForceWindow int `yaml:"brute_force_window"`
// DomlogMaxFiles caps how many per-domain access logs the WP brute
// force check scans per cycle. Sites are ranked by recent mtime so
// the cap chops least-active domains. Default 500. Bump on hosts
// with many active domains so late-alphabet sites are not skipped.
DomlogMaxFiles int `yaml:"domlog_max_files"`
// AccountScanMaxFiles caps how many account and mail-domain paths
// account-scoped scanners iterate per cycle. This covers SSH keys,
// cPanel API tokens, Dovecot shadow files, CMS DB scans, forwarders,
// user crontabs, and account .config backdoor paths. Paths are
// ranked by mtime desc so the cap chops least-active accounts.
// Default 10000 covers typical cPanel hosts with no cap effect.
// Raise on very large multi-tenant hosts where late-mtime accounts
// get skipped.
AccountScanMaxFiles int `yaml:"account_scan_max_files"`
// CrontabBase64BlobMaxBytes caps a single base64 candidate before
// decoding in MatchCrontabPatternsDeep. Default 16384 encoded
// bytes (~12 KiB decoded) comfortably fits any realistic gsocket
// or `base64 -d|bash` payload while bounding work on adversarial
// input. Raise on hosts where csm_checks_crontab_base64_truncated_total
// shows recurring truncation. Must be a multiple of 4: standard
// base64 requires aligned input or the decode fails and the
// candidate is silently skipped.
CrontabBase64BlobMaxBytes int `yaml:"crontab_base64_blob_max_bytes"`
// DomlogTailLines is how many trailing lines the WP brute force
// check reads from each per-domain access log per cycle. Default
// 500 covers roughly 10 minutes of traffic on a busy site. Raise
// on hosts where slow-burn attacks against high-volume domains
// spread across more than 500 lines per scan interval, so the
// per-IP counter window is wide enough to trip a threshold.
DomlogTailLines int `yaml:"domlog_tail_lines"`
// DomlogMaxAgeMin is how many minutes back the WP brute force
// scanner accepts a per-domain access log as "fresh enough to
// scan." Logs whose mtime is older than this are skipped. Default
// 30 keeps the scan focused on currently-active vhosts. Raise on
// low-traffic hosts where a slow-burn dictionary attack against
// a quiet domain still needs to fall inside the window.
DomlogMaxAgeMin int `yaml:"domlog_max_age_min"`
// MailLogTailLines is how many trailing lines CheckMailPerAccount
// reads from /var/log/exim_mainlog per cycle. Default 500. Raise
// on busy mail hosts where a single account's spam burst spreads
// across more than 500 lines per cycle.
MailLogTailLines int `yaml:"mail_log_tail_lines"`
// SyslogMessagesTailLines is how many trailing lines CheckFTPLogins
// reads from /var/log/messages per cycle. Default 200. Raise on
// hosts that share /var/log/messages with noisy services (e.g.
// systemd-resolved chatter) so pure-ftpd failure lines do not
// fall outside the window.
SyslogMessagesTailLines int `yaml:"syslog_messages_tail_lines"`
SMTPBruteForceThreshold int `yaml:"smtp_bruteforce_threshold"`
SMTPBruteForceWindowMin int `yaml:"smtp_bruteforce_window_min"`
SMTPBruteForceSuppressMin int `yaml:"smtp_bruteforce_suppress_min"`
SMTPBruteForceSubnetThresh int `yaml:"smtp_bruteforce_subnet_threshold"`
SMTPAccountSprayThreshold int `yaml:"smtp_account_spray_threshold"`
SMTPBruteForceMaxTracked int `yaml:"smtp_bruteforce_max_tracked"`
// SMTP probe abuse counts raw inbound SMTP connect events per source
// IP (independent of AUTH outcome) so probe-and-disconnect scanners
// that never reach the AUTH stage are still caught. Threshold sized
// well above any legitimate MUA usage. Explicit 0 disables.
SMTPProbeThreshold int `yaml:"smtp_probe_threshold"`
SMTPProbeWindowMin int `yaml:"smtp_probe_window_min"`
SMTPProbeSuppressMin int `yaml:"smtp_probe_suppress_min"`
SMTPProbeMaxTracked int `yaml:"smtp_probe_max_tracked"`
MailBruteForceThreshold int `yaml:"mail_bruteforce_threshold"`
MailBruteForceWindowMin int `yaml:"mail_bruteforce_window_min"`
MailBruteForceSuppressMin int `yaml:"mail_bruteforce_suppress_min"`
MailBruteForceSubnetThresh int `yaml:"mail_bruteforce_subnet_threshold"`
MailAccountSprayThreshold int `yaml:"mail_account_spray_threshold"`
MailBruteForceMaxTracked int `yaml:"mail_bruteforce_max_tracked"`
// MailBruteAccountKey selects how the account is extracted from a
// dovecot/postfix log line for per-account brute-force scoring.
// - builtin:dovecot-user (default) - match `user=<...>`
// - builtin:postfix-sasl - match `sasl_username=<...>`
// - regex:<pattern> - capture group 1 is the account
MailBruteAccountKey string `yaml:"mail_brute_account_key,omitempty"`
// ModSecEscalationHits is the number of ModSecurity denies from a
// single source IP, inside ModSecEscalationWindowMin, that
// promote the IP from "logged" to "escalated" (Critical finding
// + firewall hand-off). Default 3. Lower it on hosts where
// low-and-slow scanners go below the floor for too long.
ModSecEscalationHits int `yaml:"modsec_escalation_hits"`
// ModSecEscalationWindowMin is the sliding-window size for the
// hit counter. Default 10. Bumping it (e.g. to 60-240) catches
// paced attackers that spread denies across hours without
// changing the trip count.
ModSecEscalationWindowMin int `yaml:"modsec_escalation_window_min"`
// HTTPFloodThreshold is the minimum requests per http_flood_window_min
// from one source IP that emits http_request_flood. 0 (default) disables
// the detector. Operators should sample baseline traffic before setting
// a nonzero value.
HTTPFloodThreshold int `yaml:"http_flood_threshold"`
// HTTPFloodWindowMin is the rate window in minutes for HTTPFloodThreshold
// counting. Default 5.
HTTPFloodWindowMin int `yaml:"http_flood_window_min"`
// HTTPUASpoofThreshold is the minimum per-IP per-window count
// of WPSpoofPingback, ScriptingLang, Headless, or Empty UA
// requests that emits http_ua_spoof. KnownScanner and
// cache-confirmed negative ClaimedBot emit on count=1.
// Default 30.
HTTPUASpoofThreshold int `yaml:"http_ua_spoof_threshold"`
// HTTPDistributedMinIPs is the number of distinct source IPs that
// must each trip an HTTP-abuse threshold (wp-login / xmlrpc / user
// enumeration / request flood / UA spoof) against one vhost in a
// scan window before a single http_distributed_flood finding is
// emitted for that vhost. Only already-abusive IPs are counted, so
// a popular site's ordinary visitor spread does not trip it. 0
// disables the detector; the shipped sample sets 10.
HTTPDistributedMinIPs int `yaml:"http_distributed_min_ips"`
// HTTPUAScriptingEnabled opts in to flagging scripting-language
// UA strings (curl, python-requests, wget, etc.) as spoof candidates.
// Off by default: many legitimate API integrations use these.
HTTPUAScriptingEnabled bool `yaml:"http_ua_scripting_enabled"`
// HTTPUAHeadlessEnabled opts in to flagging headless-browser UA
// strings (HeadlessChrome, PhantomJS, Playwright, etc.). Off by
// default: headless browsers are used by legitimate monitoring tools.
HTTPUAHeadlessEnabled bool `yaml:"http_ua_headless_enabled"`
// HTTPUAEmptyEnabled opts in to flagging requests with an empty or
// dash User-Agent. Off by default: some CDN health checks omit UA.
HTTPUAEmptyEnabled bool `yaml:"http_ua_empty_enabled"`
} `yaml:"thresholds" hotreload:"safe"`
InfraIPs []string `yaml:"infra_ips" hotreload:"restart"`
StatePath string `yaml:"state_path" hotreload:"restart"`
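// Detection groups operator-facing knobs that gate or tune detection
// scanners: the DB persistence-mechanism scanner, the cross-account
// admin overlap correlator, rescans on signature updates, the live
// monitor backends (AF_ALG, connection tracker, exec monitor,
// sensitive files), and the direct-SMTP-egress and bad-ASN outbound
// detectors.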
Detection struct {
// DBObjectScanning toggles the MySQL persistence scanner
// (triggers/events/procedures/functions). Tri-state *bool
// matching the existing yara_worker_enabled pattern: nil =
// default-on, *true = explicit on, *false = explicit off.
// When off both the Critical (db_malicious_*) and Warning
// (db_unexpected_*) emit paths fall silent; the manual
// `csm db-clean drop-object` CLI keeps working so operators
// can act on objects discovered by other means.
DBObjectScanning *bool `yaml:"db_object_scanning"`
// DBObjectAllowlist suppresses the Warning tier
// (db_unexpected_*) for objects an operator has reviewed and
// accepted. Entries shaped <account>:<schema>:<type>:<name>.
// The Critical tier (db_malicious_*) ignores this list --
// pattern hits always fire.
DBObjectAllowlist []string `yaml:"db_object_allowlist"`
// AdminOverlapMinAccounts is the threshold at which the
// cross-account admin email correlator emits a finding. Default
// 2 matches the most common compromise pattern on shared hosting
// -- a contractor account used across multiple customer cPanels
// is a single credential leak away from compromising every site
// they touch. Operators with deliberately shared internal admin
// emails (e.g. one ops team across many sites) can raise the
// threshold to silence the routine overlap.
AdminOverlapMinAccounts int `yaml:"admin_overlap_min_accounts"`
// AdminOverlapTrustedEmails suppresses cross-account admin
// overlap findings for exact, operator-reviewed email addresses.
AdminOverlapTrustedEmails []string `yaml:"admin_overlap_trusted_emails"`
// AdminOverlapTrustedDomains suppresses cross-account admin
// overlap findings for exact email domains used by trusted
// developer or reseller admin accounts.
AdminOverlapTrustedDomains []string `yaml:"admin_overlap_trusted_domains"`
// RescanOnSignatureUpdate fires a forced full-tree deep
// scan the next time any file under cfg.Signatures.RulesDir
// has its mtime advance. Tri-state *bool: nil = default-on,
// *true = explicit on, *false = explicit off. Off means the
// existing behaviour (deep-tier runs against the fanotify
// short-list when fanotify is active) is unchanged; new
// rules only catch files that change after the update.
RescanOnSignatureUpdate *bool `yaml:"rescan_on_signature_update"`
// AFAlgBackend selects the live AF_ALG (CVE-2026-31431, "Copy
// Fail") detection backend. Empty / "auto" picks BPF LSM if
// the binary was built with -tags bpf and the kernel supports
// it, otherwise the audit-log inotify listener. "bpf" forces
// BPF and disables the audit fallback (no live monitor if BPF
// is unavailable; the periodic critical-tier check still
// runs). "auditd" forces the audit listener even on BPF-
// capable kernels -- a kill switch when a BPF-tagged release
// misbehaves and the operator wants to revert without
// rebuilding. "none" disables the live monitor entirely.
AFAlgBackend string `yaml:"af_alg_backend"`
// ConnectionTrackerBackend selects the live outbound-connection
// tracker. Empty / "auto" tries BPF cgroup/connect4,6 first and
// falls back to the existing /proc/net/tcp polling. "bpf"
// requires BPF (no fallback). "legacy" pins polling. "none"
// disables the live tracker; the periodic check still runs.
ConnectionTrackerBackend string `yaml:"connection_tracker_backend"`
// ConnectionPollInterval is how often the legacy polling backend
// reads /proc/net/tcp(6). Ignored when the BPF backend is active.
// Empty / zero defaults to 30s.
ConnectionPollInterval time.Duration `yaml:"connection_poll_interval"`
// ExecMonitorBackend selects the live process-exec monitor.
// Empty / "auto" tries the sched_process_exec BPF tracepoint and
// falls back to the periodic /proc walk. "bpf" requires BPF
// (no fallback). "legacy" pins polling. "none" disables the live
// monitor; the periodic deep-tier checks still run.
ExecMonitorBackend string `yaml:"exec_monitor_backend"`
// ExecMonitorPollInterval is how often the legacy polling backend
// runs CheckSuspiciousProcesses + CheckFakeKernelThreads. Ignored
// when the BPF backend is active. Empty / zero defaults to 30m.
ExecMonitorPollInterval time.Duration `yaml:"exec_monitor_poll_interval"`
// SensitiveFilesBackend selects the live sensitive-file write
// monitor. Empty / "auto" tries the BPF LSM hook on /etc/shadow
// and friends, falling back to a periodic content-hash check.
// "bpf" requires BPF (no fallback). "legacy" pins polling.
// "none" disables the live monitor; the periodic check still runs.
SensitiveFilesBackend string `yaml:"sensitive_files_backend"`
// SensitiveFilesPollInterval is how often the BPF watchset map
// refreshes (to pick up newly-created files in glob directories
// and handle inode reuse) and how often the legacy polling
// backend runs the content-hash check. Empty / zero defaults to 5m.
SensitiveFilesPollInterval time.Duration `yaml:"sensitive_files_poll_interval"`
// DirectSMTPEgress flags non-MTA local processes opening
// outbound SMTP connections. Phase 3 of the BPF Incident
// Response Roadmap. Detection-only this phase; Phase 4 will
// add the auto-response action gated by DryRun.
DirectSMTPEgress struct {
Enabled bool `yaml:"enabled"`
Backend string `yaml:"backend"` // auto / bpf / legacy / none
// DryRun, when true (or absent for safety), reports findings
// but takes no detector-scoped action. Phase 3 emits findings
// regardless; the knob exists for the Phase 4 action.
DryRun *bool `yaml:"dry_run,omitempty"`
Ports []int `yaml:"ports,omitempty"`
} `yaml:"direct_smtp_egress" hotreload:"safe"`
// BadASNOutbound flags outbound connections whose destination IP
// resolves (via the GeoLite2-ASN database) to a bad or unexpected
// autonomous system. It is the third leg of the host-takeover
// chain correlator (alongside a new uid-0 account and a planted
// suid binary). Requires the GeoLite2-ASN database. Off by default;
// classification needs operator-supplied ASN lists.
BadASNOutbound struct {
Enabled bool `yaml:"enabled"`
// BlockedASNs are autonomous system numbers always treated as
// bad (e.g. known bulletproof hosters).
BlockedASNs []uint `yaml:"blocked_asns"`
// AllowedASNs, when non-empty, switches to allowlist mode: any
// destination ASN outside this set is treated as bad. Use on
// hosts whose legitimate egress is confined to a few providers.
AllowedASNs []uint `yaml:"allowed_asns"`
} `yaml:"bad_asn_outbound" hotreload:"safe"`
} `yaml:"detection" hotreload:"safe"`
// Suppressions collects operator-set knobs that silence specific
// alerts (see the per-field notes). The group is tagged
// hotreload:"safe".
Suppressions struct {
// UPCPWindowStart / UPCPWindowEnd bound a suppression window.
// NOTE(review): presumably the cPanel upcp maintenance window; the
// expected time format is not visible here -- confirm against the
// code that parses these values.
UPCPWindowStart string `yaml:"upcp_window_start"`
UPCPWindowEnd string `yaml:"upcp_window_end"`
// NOTE(review): KnownAPITokens presumably lists API tokens whose use
// should not raise alerts -- confirm against callers.
KnownAPITokens []string `yaml:"known_api_tokens"`
// NOTE(review): IgnorePaths presumably lists filesystem paths whose
// findings are suppressed; matching rules (exact / prefix / glob)
// are not visible here -- confirm against callers.
IgnorePaths []string `yaml:"ignore_paths"`
SuppressWebmail bool `yaml:"suppress_webmail_alerts"` // don't alert on webmail logins
SuppressCpanelLogin bool `yaml:"suppress_cpanel_login_alerts"` // don't alert on cPanel direct logins
SuppressBlockedAlerts bool `yaml:"suppress_blocked_alerts"` // don't alert on IPs that were auto-blocked
TrustedCountries []string `yaml:"trusted_countries"` // ISO 3166-1 alpha-2 codes - suppress cPanel login alerts from these countries
} `yaml:"suppressions" hotreload:"safe"`
AutoResponse struct {
Enabled bool `yaml:"enabled"`
KillProcesses bool `yaml:"kill_processes"`
QuarantineFiles bool `yaml:"quarantine_files"`
BlockIPs bool `yaml:"block_ips"`
BlockExpiry string `yaml:"block_expiry"` // e.g. "24h", "12h"
EnforcePermissions bool `yaml:"enforce_permissions"` // auto-chmod 644 world/group-writable PHP files (default false)
FixWPCron bool `yaml:"fix_wp_cron"` // auto-disable WP-Cron + install per-user system cron on perf_wp_cron findings (default false)
BlockCpanelLogins bool `yaml:"block_cpanel_logins"` // block IPs on cPanel/webmail login alerts (default false)
NetBlock bool `yaml:"netblock"` // auto-block IPv4 /24 or IPv6 /64 at threshold
NetBlockThreshold int `yaml:"netblock_threshold"` // IPs from same IPv4 /24 or IPv6 /64 before subnet block (default 3)
// MaxBlocksPerHour caps per-IP auto-blocks per wall-clock hour.
// 0 uses DefaultMaxBlocksPerHour.
MaxBlocksPerHour int `yaml:"max_blocks_per_hour"`
PermBlock bool `yaml:"permblock"` // auto-promote to permanent after N temp blocks
PermBlockCount int `yaml:"permblock_count"` // temp blocks before permanent (default 4)
PermBlockInterval string `yaml:"permblock_interval"` // window for counting temp blocks (default "24h")
CleanDatabase bool `yaml:"clean_database"` // auto-clean malicious DB injections, revoke sessions, block attacker IPs (default false)
CleanHtaccess bool `yaml:"clean_htaccess"` // auto-clean .htaccess directives flagged by the hardened detectors (default false)
DisableEnforceAFAlg bool `yaml:"disable_enforce_af_alg"` // suspend periodic AF_ALG enforcement; marker file + detection remain active (default false = enforce when marker present)
CopyFailKillProcess bool `yaml:"copy_fail_kill_process"` // SIGKILL processes caught opening AF_ALG sockets via the live listener (default false; alert-only)
// DryRun, when true (or absent - safety default), logs the intended
// action but does NOT touch nftables. Mirrors the PHPRelay.DryRun
// pattern: pointer-bool to distinguish "operator explicitly set false"
// from "operator omitted the key". Implicit nil means dry-run on, so
// flipping block_ips: true alone never causes a real block.
DryRun *bool `yaml:"dry_run,omitempty"`
// PHPRelay controls the auto-freeze behaviour that companion
// email PHP-relay detectors emit findings for. Freeze and DryRun
// are *bool so we can distinguish OMITTED from EXPLICIT FALSE in
// YAML. A plain bool zero-value is false, which would let an
// operator write `freeze: true` and (by forgetting `dry_run`)
// get LIVE freezes against their will. Pointer values: nil =
// "not set in YAML"; *true / *false = explicit. Use the
// FreezeEnabled() / DryRunEnabled() accessors on *Config to
// resolve the safe defaults rather than dereferencing directly.
PHPRelay struct {
Freeze *bool `yaml:"freeze"`
DryRun *bool `yaml:"dry_run"`
MaxActionsPerMinute int `yaml:"max_actions_per_minute"`
} `yaml:"php_relay"`
// VerdictCallback lets phpanel observe each block decision before it's
// applied. CSM POSTs the verdict to the configured URL with HMAC-SHA256
// signing (same scheme as the phpanel webhook); the response is
// advisory - phpanel can attach a tenant_id, return "allow" to keep
// the event audit-only, or omit a response entirely
// (CSM proceeds with its default verdict). NOT a per-tenant nftables
// enforcement: that's a separate, larger feature.
//
// Secret resolution happens at call time (the verdict.Client reads
// HMACSecretEnv per call), so operators can rotate via env without
// restarting the daemon.
VerdictCallback struct {
Enabled bool `yaml:"enabled"`
URL string `yaml:"url"`
HMACSecret string `yaml:"hmac_secret,omitempty"`
HMACSecretEnv string `yaml:"hmac_secret_env,omitempty"`
TimeoutSec int `yaml:"timeout_sec"`
// RequireResponseSignature controls whether the panel must sign
// its response body with the same HMAC scheme used on the
// request (X-CSM-Signature header). Default true: when a secret
// is configured, CSM rejects unsigned or forged responses to
// prevent an on-path attacker from downgrading block to allow.
// Set false only during a phpanel rollout that has not yet
// implemented response signing. In that mode, CSM still checks
// nonce or timestamp fields the panel echoes; responses that
// omit both keep the legacy advisory shape working.
RequireResponseSignature *bool `yaml:"require_response_signature,omitempty"`
// AllowUnsigned opts out of the default fail-closed posture and
// permits the verdict callback to fire without an HMAC secret,
// including advisory "allow" responses. Only set true while
// bootstrapping a new panel or during local testing; production
// deployments must keep this false so the daemon refuses to start
// when the secret env var is empty.
AllowUnsigned bool `yaml:"allow_unsigned,omitempty"`
} `yaml:"verdict_callback"`
} `yaml:"auto_response" hotreload:"safe"`
// BPFEnforcement is the optional in-kernel deny path for matched
// outbound connections. Phase 4 of the BPF Incident Response
// Roadmap. Defaults are all-safe: enforcement off, dry-run on.
// Operators flip live denial only after dry-run telemetry review.
BPFEnforcement struct {
Enabled bool `yaml:"enabled"`
// DryRun, when true (or absent for safety), logs intended
// denials but allows the connect. False = real deny.
DryRun *bool `yaml:"dry_run,omitempty"`
DirectSMTPEgress bool `yaml:"direct_smtp_egress"`
// VerdictCallback, when true, asks auto_response.verdict_callback
// for an advisory ALLOW override before recording a USERSPACE
// action (incident close, audit note). The in-kernel hook NEVER
// waits on this; it would add latency to every connect.
VerdictCallback bool `yaml:"verdict_callback"`
} `yaml:"bpf_enforcement" hotreload:"safe"`
// Challenge configures the challenge server: a listener that serves
// challenge pages (proof-of-work, with an optional CAPTCHA fallback)
// instead of hard-blocking some IPs. The group is tagged
// hotreload:"restart".
Challenge struct {
Enabled bool `yaml:"enabled"` // enable challenge pages instead of hard block for some IPs
// ListenAddr defaults to loopback because the production path
// keeps the listener private until an operator deliberately
// exposes it. The webserver integration redirects browsers to
// challenge.public_url, so installed direct mode needs a
// non-loopback listen address plus TLS material via
// challenge.tls_cert / tls_key, or the webui TLS fallback.
ListenAddr string `yaml:"listen_addr"` // bind address for the challenge listener (default: 127.0.0.1)
ListenPort int `yaml:"listen_port"` // port for challenge server (default: 8439)
Secret string `yaml:"secret"` // HMAC secret for challenge tokens (auto-generated if empty)
Difficulty int `yaml:"difficulty"` // proof-of-work difficulty 0-5 (default: 2)
TrustedProxies []string `yaml:"trusted_proxies"` // IPs allowed to set X-Forwarded-For (empty = trust RemoteAddr only)
// TLSCert / TLSKey activate HTTPS on the challenge listener. Empty
// values keep loopback listeners on plain HTTP. Direct/public
// listeners fall back to webui.tls_cert / webui.tls_key so
// single-cert hosts can opt in without duplicating paths.
TLSCert string `yaml:"tls_cert"`
TLSKey string `yaml:"tls_key"`
// PublicURL is the external URL the webserver redirect target
// points at. Operators put it on an existing TLS-valid host so
// the integration does not need a new DNS record or cert.
// Example: https://server.example.com:8439/challenge with
// listen_addr=0.0.0.0, tls_cert/tls_key set to the host's
// cpanel cert.
// When empty the webserver-integration installer refuses to
// run; per-vhost reverse-proxy is no longer supported because
// LSWS proxy emulation does not honor it at server scope.
PublicURL string `yaml:"public_url"`
// CaptchaFallback shows a third-party CAPTCHA widget when JS is
// disabled. All fields default empty; the feature is off until
// the operator supplies provider + keys.
CaptchaFallback struct {
Provider string `yaml:"provider"` // "turnstile" | "hcaptcha" | "" (off)
SiteKey string `yaml:"site_key"` // public key embedded in the HTML widget
SecretKey string `yaml:"secret_key"` // verified server-side against the provider
Timeout time.Duration `yaml:"timeout"` // HTTP timeout for siteverify (default 10s)
} `yaml:"captcha_fallback"`
// VerifiedSession lets operators mint a signed cookie that
// bypasses the PoW for the cookie's TTL. The signing key is
// generated at daemon startup and rotates on restart.
VerifiedSession struct {
Enabled bool `yaml:"enabled"`
CookieName string `yaml:"cookie_name"` // default: csm_admin_session
TTL time.Duration `yaml:"ttl"` // default: 4h
AdminSecret string `yaml:"admin_secret"` // shared secret POST'd to /challenge/admin-token
} `yaml:"verified_session"`
// VerifiedCrawlers lets through traffic from search crawlers
// whose IP forward-confirms a reverse-DNS PTR matching one of
// the configured providers.
VerifiedCrawlers struct {
Enabled bool `yaml:"enabled"`
Providers []string `yaml:"providers"` // names: googlebot | bingbot
CacheTTL time.Duration `yaml:"cache_ttl"` // default: 15m
} `yaml:"verified_crawlers"`
// PortGate locks the challenge listener TCP port to specific
// source IPs via nftables. Enabled implies the daemon owns a
// dedicated `csm_chal` inet table whose chain drops all traffic
// to challenge.listen_port except: loopback, operator infra_ips,
// and IPs the IPList has just flagged. The set entry carries
// the same TTL as the challenge entry, so nftables expires the
// allow even if the daemon dies before calling Revoke. Auto-off
// when listen_addr is loopback (gate has no effect there).
PortGate struct {
Enabled bool `yaml:"enabled"`
} `yaml:"port_gate"`
} `yaml:"challenge" hotreload:"restart"`
// PHPShield toggles watching the PHP Shield event log for alerts.
// Off by default; the group is tagged hotreload:"restart".
PHPShield struct {
Enabled bool `yaml:"enabled"` // watch PHP Shield event log for alerts (default: false)
} `yaml:"php_shield" hotreload:"restart"`
Reputation struct {
AbuseIPDBKey string `yaml:"abuseipdb_key"`
Whitelist []string `yaml:"whitelist"` // IPs to never flag as malicious
// Rspamd queries the local rspamd controller for per-IP reject/junk
// counts. Disabled by default. URL must reach the controller HTTP port
// (default 11334). Token is the controller's admin password (rspamadm
// pw -e), supplied via env when possible. Token resolution happens at
// query time so operators can rotate via env without daemon restart.
Rspamd struct {
Enabled bool `yaml:"enabled"`
URL string `yaml:"url"`
Token string `yaml:"token,omitempty"`
TokenEnv string `yaml:"token_env,omitempty"`
} `yaml:"rspamd"`
// Upstream is an HTTP threat-intel source - typically a panel host that
// caches AbuseIPDB / proprietary scores on behalf of every agent in its
// fleet. Disabled by default. Token resolution happens at query time
// (see internal/threatintel/upstream_source.go) so operators can rotate
// the bearer via TokenEnv without restarting the daemon.
Upstream struct {
Enabled bool `yaml:"enabled"`
URL string `yaml:"url"`
Token string `yaml:"token,omitempty"` // discouraged - prefer TokenEnv
TokenEnv string `yaml:"token_env,omitempty"`
CacheTTLMin int `yaml:"cache_ttl_min"`
TimeoutSec int `yaml:"timeout_sec"`
} `yaml:"upstream"`
// BotVerifyEnabled controls async PTR+forward-A verification of
// claimed search-engine bot IPs. *bool so an explicit false
// survives SIGHUP reload without being overwritten by applyDefaults.
// Default true.
BotVerifyEnabled *bool `yaml:"bot_verify_enabled"`
// Report emits signed, minimized abuse reports for confirmed-abuse
// findings to a central abuse database or a private collector.
// Opt-in; default off. Keys/secrets resolve from *_env at startup.
Report struct {
Enabled bool `yaml:"enabled"`
Classes []string `yaml:"classes"` // bruteforce, php_relay, credential_stuffing, bad_asn_egress
SpoolPath string `yaml:"spool_path"` // bbolt file; default <state_dir>/abuse_reports.db
SpoolMax int `yaml:"spool_max"` // bounded queue size; default 10000
Targets []struct {
Name string `yaml:"name"`
URL string `yaml:"url"`
Transport string `yaml:"transport"` // ed25519 | hmac
NodeID string `yaml:"node_id"`
KeyID string `yaml:"key_id"`
KeyEnv string `yaml:"key_env"` // ed25519 hex private key, or hmac secret
TokenEnv string `yaml:"token_env"` // optional bearer for hmac collectors
} `yaml:"targets"`
} `yaml:"report" hotreload:"restart"`
// Central pulls a signed scored-set from the central abuse database and
// acts on it per Action. Opt-in; default off. Central data never hard
// blocks on its own (see Action / firebreaks).
Central struct {
Enabled bool `yaml:"enabled"`
SetURL string `yaml:"set_url"`
PubkeyEnv string `yaml:"pubkey_env"` // env holding the central Ed25519 public key (hex)
RefreshInterval string `yaml:"refresh_interval"` // e.g. "6h"; default 6h
Action string `yaml:"action"` // off | challenge | block_if_local_corroborated
BlockThreshold int `yaml:"block_threshold"` // node-local score threshold for block (default 80)
} `yaml:"central" hotreload:"restart"`
} `yaml:"reputation" hotreload:"safe"`
// Signatures configures the detection rule set: where the rules live
// on disk, how rule updates are fetched and verified, and the optional
// YARA Forge download. The group is tagged hotreload:"restart".
Signatures struct {
// RulesDir is the directory holding the rule files. It is the
// directory whose file mtimes detection.rescan_on_signature_update
// reacts to.
RulesDir string `yaml:"rules_dir"`
// UpdateURL is the location rule updates are fetched from; setting
// it turns AutoUpdate on by default (see AutoUpdate).
UpdateURL string `yaml:"update_url"`
AutoUpdate bool `yaml:"auto_update"` // auto-download rules daily (default: true if update_url set)
UpdateInterval string `yaml:"update_interval"` // how often to check (default: "24h")
SigningKey string `yaml:"signing_key"` // hex-encoded ed25519 public key for verifying rule updates
// YaraForge controls the optional YARA Forge rule download: which
// tier to fetch, how often, and from which URL.
YaraForge struct {
Enabled bool `yaml:"enabled"`
Tier string `yaml:"tier"` // "core", "extended", "full" (default: "core")
UpdateInterval string `yaml:"update_interval"` // default: "168h" (weekly)
DownloadURL string `yaml:"download_url"` // signed ZIP URL/template; supports {tier} and {version}
} `yaml:"yara_forge"`
DisabledRules []string `yaml:"disabled_rules"` // YARA rule names to exclude from Forge downloads
// YaraWorkerEnabled is a tri-state: nil means "use system default"
// (default-on, per ROADMAP item 2 follow-up), *true means explicit on,
// *false means explicit off. Callers must nil-check before dereferencing;
// daemon.yaraWorkerOn() is the canonical accessor.
YaraWorkerEnabled *bool `yaml:"yara_worker_enabled"`
} `yaml:"signatures" hotreload:"restart"`
// WebUI configures the web interface: its listen address, TLS
// material, on-disk UI files, and the tokens that guard it. The group
// is tagged hotreload:"restart"; MetricsToken alone overrides that
// with hotreload:"safe".
WebUI struct {
Enabled bool `yaml:"enabled"`
Listen string `yaml:"listen"`
// AuthToken is the legacy single credential; see Tokens for the
// multi-credential model that supersedes it.
AuthToken string `yaml:"auth_token"`
MetricsToken string `yaml:"metrics_token" hotreload:"safe"` // optional Bearer token for /metrics; rotate via SIGHUP without restart
TLSCert string `yaml:"tls_cert"`
TLSKey string `yaml:"tls_key"`
UIDir string `yaml:"ui_dir"` // path to UI files on disk (default: /opt/csm/ui)
// Tokens is the multi-credential model added in v2.12.0. Each entry has
// a stable name (for audit), an opaque secret, and a scope that gates
// which endpoints accept it. Legacy AuthToken is preserved during the
// migration window so callers that read it directly keep working;
// applyDefaults populates Tokens from AuthToken when only the legacy
// field is set.
Tokens []WebUIToken `yaml:"tokens,omitempty"`
} `yaml:"webui" hotreload:"restart"`
EmailAV EmailAVConfig `yaml:"email_av" hotreload:"restart"`
EmailProtection struct {
PasswordCheckIntervalMin int `yaml:"password_check_interval_min"`
HighVolumeSenders []string `yaml:"high_volume_senders"`
RateWarnThreshold int `yaml:"rate_warn_threshold"`
RateCritThreshold int `yaml:"rate_crit_threshold"`
RateWindowMin int `yaml:"rate_window_min"`
KnownForwarders []string `yaml:"known_forwarders"`
// PHPRelay is the operator-tunable knob block for the email
// PHP-relay protection feature (Stage 1). All thresholds default
// to the values set in applyDefaults(); leaving any field at its
// zero value triggers the documented default at load time.
PHPRelay struct {
Enabled bool `yaml:"enabled"`
RateWindowMin int `yaml:"rate_window_min"`
HeaderScoreVolumeMin int `yaml:"header_score_volume_min"`
AbsoluteVolumePerHour int `yaml:"absolute_volume_per_hour"`
AccountVolumePerHour int `yaml:"account_volume_per_hour"`
ReputationFailuresPer24h int `yaml:"reputation_failures_per_24h"`
FanoutDistinctScripts int `yaml:"fanout_distinct_scripts"`
FanoutWindowMin int `yaml:"fanout_window_min"`
BaselineSigma float64 `yaml:"baseline_sigma"`
BaselineObservationDays int `yaml:"baseline_observation_days"`
PoliciesDir string `yaml:"policies_dir"`
} `yaml:"php_relay"`
// CloudRelay scopes opt-out for the email_cloud_relay_abuse
// detector only. Use this when an operator legitimately runs a
// mailer on a public-cloud VM (Google Cloud, AWS, etc.) and the
// realtime/retro detectors keep false-firing on that mailbox.
// AllowUsers matches full mailboxes (case-insensitive). AllowDomains
// matches the domain part of the AUTH user (case-insensitive),
// covering every mailbox under that domain. Either match exits
// the detector before any window state is updated, so an
// allowlisted mailbox cannot prime the counter for another user.
// Leaving both empty preserves prior behavior. The shared
// EmailProtection.HighVolumeSenders list still applies as well.
CloudRelay struct {
AllowUsers []string `yaml:"allow_users"`
AllowDomains []string `yaml:"allow_domains"`
} `yaml:"cloud_relay"`
// ForwardGuard is the opt-in protection that holds spam/backscatter
// forward copies before they relay to an external provider. Default
// off; dry-run first. The MTA enforces; CSM only generates the rule
// and owns the quarantine, so it is never in the live mail path.
ForwardGuard ForwardGuardConfig `yaml:"forward_guard"`
} `yaml:"email_protection" hotreload:"safe"`
Firewall *firewall.FirewallConfig `yaml:"firewall" hotreload:"restart"`
// GeoIP holds the account credentials, edition list, and auto-update
// schedule for the GeoIP databases. The group is tagged
// hotreload:"restart".
// NOTE(review): account_id / license_key / editions look like MaxMind
// GeoLite2 download credentials (the GeoLite2-ASN database is
// referenced by detection.bad_asn_outbound) -- confirm.
GeoIP struct {
AccountID string `yaml:"account_id"`
LicenseKey string `yaml:"license_key"`
Editions []string `yaml:"editions"`
AutoUpdate *bool `yaml:"auto_update"` // nil = true when credentials set
UpdateInterval string `yaml:"update_interval"` // default "24h"
} `yaml:"geoip" hotreload:"restart"`
ModSecErrorLog string `yaml:"modsec_error_log" hotreload:"restart"`
// ModSec holds the ModSecurity rule-file and override-file paths plus
// the reload command (see the per-field notes). The group is tagged
// hotreload:"restart".
ModSec struct {
RulesFile string `yaml:"rules_file"` // path to modsec2.user.conf
OverridesFile string `yaml:"overrides_file"` // path to csm-overrides.conf
ReloadCommand string `yaml:"reload_command"` // e.g. "systemctl restart lsws"
} `yaml:"modsec" hotreload:"restart"`
// WebServer overrides the auto-detected web server paths. Every field is
// optional: anything left blank or empty falls back to what
// platform.Detect() returned at startup. Intended for hosts with a
// custom layout (reverse proxy in front of a second daemon, non-standard
// package locations, chroot, etc.).
WebServer struct {
Type string `yaml:"type"` // "apache", "nginx", "litespeed" -- overrides auto-detect
ConfigDir string `yaml:"config_dir"` // e.g. /etc/apache2 or /etc/nginx
AccessLogs []string `yaml:"access_logs"` // candidate access-log paths, tried in order
ErrorLogs []string `yaml:"error_logs"` // candidate error-log paths (used for modsec denies)
ModSecAudits []string `yaml:"modsec_audit_logs"` // candidate ModSecurity audit-log paths
// TrustedProxies is the list of IP addresses or CIDR ranges whose
// X-Forwarded-For header is trusted for client IP extraction. When
// the connecting IP is not in this list, XFF is ignored and
// RemoteIP is used as-is.
TrustedProxies []string `yaml:"trusted_proxies"` // IP/CIDR sources allowed to supply X-Forwarded-For
// DomlogGlobs overrides the auto-detected per-vhost access-log glob
// patterns. When set, the platform default for the detected panel/OS
// is discarded and only these patterns are used. Leave empty to use
// the auto-detected globs.
DomlogGlobs []string `yaml:"domlog_globs"` // per-vhost access-log globs
} `yaml:"web_server" hotreload:"restart"`
// AccountRoots lets operators point the account-scan based checks at
// non-cPanel web root layouts. Each entry is a glob pattern expanded
// at check time. Examples:
//
// account_roots:
// - /var/www/*/public
// - /srv/http/*
// - /home/*/public_html # cPanel default (implicit when unset on cPanel)
//
// When unset, CSM uses the cPanel default of /home/*/public_html on
// cPanel hosts and an empty list on non-cPanel hosts (account-scan
// checks skip entirely). See docs/src/configuration.md for the full
// list of checks that consume this.
AccountRoots []string `yaml:"account_roots" hotreload:"restart"`
Performance struct {
Enabled *bool `yaml:"enabled"`
LoadHighMultiplier float64 `yaml:"load_high_multiplier"`
LoadCriticalMultiplier float64 `yaml:"load_critical_multiplier"`
PHPProcessWarnPerUser int `yaml:"php_process_warn_per_user"`
PHPProcessCriticalTotalMult int `yaml:"php_process_critical_total_multiplier"`
ErrorLogWarnSizeMB int `yaml:"error_log_warn_size_mb"`
MySQLJoinBufferMaxMB int `yaml:"mysql_join_buffer_max_mb"`
MySQLWaitTimeoutMax int `yaml:"mysql_wait_timeout_max"`
MySQLMaxConnectionsPerUser int `yaml:"mysql_max_connections_per_user"`
RedisBgsaveMinInterval int `yaml:"redis_bgsave_min_interval"`
RedisLargeDatasetGB int `yaml:"redis_large_dataset_gb"`
WPMemoryLimitMaxMB int `yaml:"wp_memory_limit_max_mb"`
WPTransientWarnMB int `yaml:"wp_transient_warn_mb"`
WPTransientCriticalMB int `yaml:"wp_transient_critical_mb"`
// WPCronFix tunes the WP-Cron remediation (manual fix from the Web UI
// and the daemon auto-response). Disabling WP-Cron without a real cron
// would stop scheduled tasks, so the fix also installs a per-user system
// cron that runs wp-cron.php on this interval.
WPCronFix struct {
IntervalMinutes int `yaml:"interval_minutes"` // system cron frequency; default 5, clamped to [1,60]
PHPBin string `yaml:"php_bin"` // php interpreter for the cron line; empty => detect
} `yaml:"wp_cron_fix"`
} `yaml:"performance" hotreload:"restart"`
Cloudflare struct {
Enabled bool `yaml:"enabled"`
RefreshHours int `yaml:"refresh_hours"`
} `yaml:"cloudflare" hotreload:"restart"`
C2Blocklist []string `yaml:"c2_blocklist" hotreload:"restart"`
BackdoorPorts []int `yaml:"backdoor_ports" hotreload:"restart"`
// DisabledChecks lists check names that should be skipped entirely by
// the runner (no execution, no finding, no email/webhook/audit). Use
// this when a whole category does not apply to a host (e.g. WAF/web
// checks on DNS-only cPanel servers). Distinct from
// alerts.email.disabled_checks, which only suppresses email but still
// runs the check and emits findings to other sinks.
DisabledChecks []string `yaml:"disabled_checks" hotreload:"safe"`
// Retention bounds bbolt growth. When enabled, a daily sweep prunes
// per-bucket entries older than the configured TTL and an online
// compaction pass shrinks the on-disk file once the fill ratio drops
// below CompactFillRatio (and the file exceeds CompactMinSizeMB).
// All fields are tagged hotreload:"restart" because the retention
// goroutine captures these on daemon start.
Retention struct {
Enabled bool `yaml:"enabled"` // opt-in
FindingsDays int `yaml:"findings_days"` // default 90
HistoryDays int `yaml:"history_days"` // default 30
ReputationDays int `yaml:"reputation_days"` // default 180
SweepInterval string `yaml:"sweep_interval"` // default "24h"
CompactMinSizeMB int `yaml:"compact_min_size_mb"` // default 128
CompactFillRatio float64 `yaml:"compact_fill_ratio"` // default 0.5
} `yaml:"retention" hotreload:"restart"`
// Sentry ships panics and selected errors to a Sentry server for
// aggregation across hosts. Disabled by default; set enabled=true and
// provide a DSN from the Sentry project. Init is one-shot: changes
// require a daemon restart.
Sentry struct {
Enabled bool `yaml:"enabled"`
DSN string `yaml:"dsn"`
Environment string `yaml:"environment"` // e.g. "production", "staging"
SampleRate float64 `yaml:"sample_rate"` // 0 -> 1.0 (capture all errors)
Debug bool `yaml:"debug"` // SDK debug logs to stderr
} `yaml:"sentry" hotreload:"restart"`
// MailLogs selects the log source for the postfix/dovecot brute-force
// and relay detectors. Changing the source (file vs. journal) requires
// the daemon to re-attach its reader, so the field is tagged restart.
MailLogs MailLogsConfig `yaml:"mail_logs,omitempty" hotreload:"restart"`
// Updates controls the upstream release-availability poll surfaced
// in the Web UI top banner. The daemon never downloads or applies
// updates -- it only tells the operator that a newer version
// exists. Disable wholesale on air-gapped hosts.
Updates struct {
// CheckEnabled is a tri-state. nil means default-on; explicit
// false disables the poll entirely (no outbound HTTP, no
// package-manager probe). Use a pointer so the absence of the
// key in YAML is distinguishable from `check_enabled: false`.
CheckEnabled *bool `yaml:"check_enabled"`
// Interval is parsed by time.ParseDuration. Defaults to 24h;
// clamped to a 1h floor by updatecheck.New.
Interval string `yaml:"interval"`
// GitHubAPIURL overrides the default release endpoint. Tests
// and air-gapped mirrors use this; leave empty in production.
GitHubAPIURL string `yaml:"github_api_url,omitempty"`
// PackageName is the apt/dnf package name to query when the
// GitHub call fails. Defaults to "csm".
PackageName string `yaml:"package_name,omitempty"`
} `yaml:"updates" hotreload:"restart"`
// Incidents groups correlator-side knobs the operator can tune
// without code changes. Hot-reload "restart" because the daemon
// captures these on startup; flipping mid-run would race the
// retention loop.
Incidents struct {
// AutoClose resolves Open / Contained incidents whose UpdatedAt
// exceeds the per-kind idle threshold. Default-on with safe
// thresholds; mailbox takeover and credential spray expire at
// 24h, web-account compromise at 7d, and host-level kinds never
// auto-close. Operators who want to monitor decisions without
// writing back can flip dry_run=true.
AutoClose struct {
// Enabled is tri-state: nil (default) means default-on; an
// explicit false in YAML disables. Pointer so absence in YAML
// is distinguishable from "enabled: false".
Enabled *bool `yaml:"enabled"`
DryRun bool `yaml:"dry_run"`
// ByKind maps incident kind -> idle threshold (parsed by
// time.ParseDuration). Kinds absent from the map are never
// auto-closed. Use a string-keyed map so the operator can
// add custom kinds without recompiling. Empty map falls back
// to safe defaults (mailbox_takeover=24h,
// credential_spray=24h, web_account_compromise=7d).
ByKind map[string]string `yaml:"by_kind,omitempty"`
} `yaml:"auto_close"`
// SpraySuppression collapses one source IP brute-forcing many
// distinct mailboxes/accounts into a single credential_spray
// super-incident. Default-OFF + dry_run=TRUE so the path ships
// dark; counters and audit log show what would have happened.
// Operators flip enabled=true and dry_run=false after watching
// the counters on their own infra.
SpraySuppression struct {
Enabled bool `yaml:"enabled"`
DryRun bool `yaml:"dry_run"`
DistinctMailboxes int `yaml:"distinct_mailboxes"`
SeverityEscalateAt int `yaml:"severity_escalate_at"`
PerCheck []string `yaml:"per_check"`
MaxTrackedIPs int `yaml:"max_tracked_ips"`
// BlockAtSeverity drives the firewall hand-off. Empty (default)
// means detection-only: the super-incident opens and counters
// move, but the IP is not blocked. "high" blocks as soon as the
// detector trips at DistinctMailboxes. "critical" waits for the
// severity escalation (SeverityEscalateAt distinct mailboxes)
// before blocking. Auto_response.dry_run and block_ips still
// gate the actual firewall call.
BlockAtSeverity string `yaml:"block_at_severity"`
} `yaml:"spray_suppression"`
// AutoBlock is the generic incident-driven firewall hand-off.
// Independent of SpraySuppression; applies to any non-spray
// incident kind that carries a remote_ip in its correlation
// key. Default-OFF so the path is dormant until an operator
// opts in. Block requests still respect auto_response.enabled
// and auto_response.block_ips at decision time.
AutoBlock struct {
Enabled bool `yaml:"enabled"`
BlockAtSeverity string `yaml:"block_at_severity"`
Kinds []string `yaml:"kinds"`
} `yaml:"auto_block"`
} `yaml:"incidents" hotreload:"restart"`
}
// UpdatesCheckEnabled resolves the tri-state `updates.check_enabled`
// key into a plain bool. A missing key (nil pointer) counts as enabled;
// otherwise the operator's explicit value is returned unchanged.
func (c *Config) UpdatesCheckEnabled() bool {
	if flag := c.Updates.CheckEnabled; flag != nil {
		return *flag
	}
	return true
}
// UpdatesInterval parses `updates.interval` with time.ParseDuration.
// An empty or unparseable value yields 24h. A value that parses is
// returned as-is (including zero or negative durations); no clamping
// is done here.
func (c *Config) UpdatesInterval() time.Duration {
	const fallback = 24 * time.Hour
	raw := c.Updates.Interval
	if raw == "" {
		return fallback
	}
	if parsed, err := time.ParseDuration(raw); err == nil {
		return parsed
	}
	return fallback
}
// UpdatesPackageName reports the configured `updates.package_name`,
// or "csm" when the field is empty.
func (c *Config) UpdatesPackageName() string {
	if name := c.Updates.PackageName; name != "" {
		return name
	}
	return "csm"
}
// PHPRelayFreezeEnabled resolves the tri-state
// auto_response.php_relay freeze flag. Freeze is opt-in: a nil pointer
// (key not set) yields false, and only an explicit true yields true.
func (c *Config) PHPRelayFreezeEnabled() bool {
	freeze := c.AutoResponse.PHPRelay.Freeze
	if freeze == nil {
		return false
	}
	return *freeze
}
// IncidentsAutoCloseEnabled resolves the tri-state
// `incidents.auto_close.enabled` key. A missing key (nil pointer) is
// treated as enabled; an explicit value is returned unchanged, so only
// `enabled: false` turns auto-close off.
func (c *Config) IncidentsAutoCloseEnabled() bool {
	if flag := c.Incidents.AutoClose.Enabled; flag != nil {
		return *flag
	}
	return true
}
// IncidentsAutoCloseThresholds returns the per-kind idle thresholds in
// parsed form. The result starts from the built-in defaults and is then
// overlaid with the operator's by_kind map, entry by entry:
//
//   - an empty value removes the kind from the result, including a
//     built-in default;
//   - a value that parses to a positive duration sets or replaces the
//     threshold for that kind;
//   - a value that fails time.ParseDuration, or parses to zero or a
//     negative duration, is ignored, so any default for that kind stays.
//
// A fresh map is returned on every call.
func (c *Config) IncidentsAutoCloseThresholds() map[string]time.Duration {
	thresholds := defaultIncidentAutoCloseThresholds()
	for kind, raw := range c.Incidents.AutoClose.ByKind {
		switch {
		case raw == "":
			delete(thresholds, kind)
		default:
			if idle, err := time.ParseDuration(raw); err == nil && idle > 0 {
				thresholds[kind] = idle
			}
		}
	}
	return thresholds
}

// defaultIncidentAutoCloseThresholds builds a new map holding the
// built-in idle thresholds: 24h for mailbox_takeover and
// credential_spray, 7 days for web_account_compromise.
func defaultIncidentAutoCloseThresholds() map[string]time.Duration {
	const day = 24 * time.Hour
	defaults := make(map[string]time.Duration, 3)
	defaults["mailbox_takeover"] = day
	defaults["credential_spray"] = day
	defaults["web_account_compromise"] = 7 * day
	return defaults
}
// IncidentsAutoBlockKinds converts the `incidents.auto_block.kinds`
// list into a set keyed by kind name. Empty strings in the list are
// dropped. The returned map is never nil; per the original contract an
// empty result means "any non-spray kind" to the correlator.
func (c *Config) IncidentsAutoBlockKinds() map[string]bool {
	kinds := c.Incidents.AutoBlock.Kinds
	set := make(map[string]bool, len(kinds))
	for i := range kinds {
		if name := kinds[i]; name != "" {
			set[name] = true
		}
	}
	return set
}
// IncidentsSpraySuppressionPerCheck converts the
// `incidents.spray_suppression.per_check` list into a set keyed by
// check name. When the list is empty (zero length) a built-in set of
// three authentication-failure checks is used instead. Empty strings
// are dropped, so a list containing only empty strings produces an
// empty set rather than the built-in one.
func (c *Config) IncidentsSpraySuppressionPerCheck() map[string]bool {
	checks := []string{
		"email_auth_failure_realtime",
		"pam_auth_failure",
		"ssh_bruteforce",
	}
	if configured := c.Incidents.SpraySuppression.PerCheck; len(configured) > 0 {
		checks = configured
	}
	set := make(map[string]bool, len(checks))
	for i := range checks {
		if name := checks[i]; name != "" {
			set[name] = true
		}
	}
	return set
}
// PHPRelayDryRunEnabled resolves the tri-state dry-run flag of the
// email PHP-relay auto-freeze. Dry-run is the safe default: a nil
// pointer (key not set) yields true, so an operator who enables freeze
// without setting dry_run gets a dry-run rather than a live freeze.
// Only an explicit false yields false.
func (c *Config) PHPRelayDryRunEnabled() bool {
	if dry := c.AutoResponse.PHPRelay.DryRun; dry != nil {
		return *dry
	}
	return true
}
// AutoResponseDryRunEnabled resolves the tri-state
// `auto_response.dry_run` key the same way PHPRelayDryRunEnabled does:
// a nil pointer (key absent) yields true, and only an explicit false
// yields false. Per the original contract, explicit false is what
// enables live nftables blocking.
func (c *Config) AutoResponseDryRunEnabled() bool {
	if dry := c.AutoResponse.DryRun; dry != nil {
		return *dry
	}
	return true
}
// DirectSMTPEgressDryRunEnabled resolves the tri-state dry-run flag of
// the direct SMTP egress detector. A nil pointer (dry_run omitted)
// yields true as the safety default; an explicit value is returned
// unchanged, so `dry_run: false` is required for active mode.
func (c *Config) DirectSMTPEgressDryRunEnabled() bool {
	dry := c.Detection.DirectSMTPEgress.DryRun
	return dry == nil || *dry
}
// BPFEnforcementDryRunEnabled reports whether BPF cgroup-deny
// enforcement must run in dry-run mode. Dry-run wins if ANY of the
// following holds, checked in this order:
//
//  1. the global auto-response dry-run is on
//     (AutoResponseDryRunEnabled);
//  2. the direct_smtp_egress enforcement gate is set and the direct
//     SMTP egress detector itself is in dry-run
//     (DirectSMTPEgressDryRunEnabled);
//  3. bpf_enforcement's own dry_run is omitted (nil) or explicitly true.
//
// Live denial therefore needs an explicit `dry_run: false` here and on
// every layer listed above that applies.
func (c *Config) BPFEnforcementDryRunEnabled() bool {
	switch {
	case c.AutoResponseDryRunEnabled():
		return true
	case c.BPFEnforcement.DirectSMTPEgress && c.DirectSMTPEgressDryRunEnabled():
		return true
	}
	own := c.BPFEnforcement.DryRun
	return own == nil || *own
}
// BotVerifyEnabled resolves the tri-state reputation bot-verify flag
// (async PTR + forward-A verification of claimed search-engine bots).
// A nil pointer yields true; an explicit value is returned unchanged,
// so operators on air-gapped networks can switch the DNS calls off.
func (c *Config) BotVerifyEnabled() bool {
	verify := c.Reputation.BotVerifyEnabled
	return verify == nil || *verify
}
// defaultPresence records which YAML keys were literally present in the
// config input. defaultPresenceFromYAML fills it and applyDefaults reads
// it, so that an explicit zero/false written by the operator is not
// overwritten by a non-zero default.
type defaultPresence struct {
// smtpProbeThreshold is true when thresholds.smtp_probe_threshold
// appears in the YAML.
smtpProbeThreshold bool
// forwardGuard tracks the email_protection.forward_guard keys.
forwardGuard forwardGuardPresence
}
// forwardGuardPresence records which forward-guard fields were set explicitly,
// so an operator's literal false survives the safe default-true.
// Each flag is true when the corresponding key appears under
// email_protection.forward_guard (or its hold_signals sub-map) in the YAML.
type forwardGuardPresence struct {
// dryRun: forward_guard.dry_run was present.
dryRun bool
// retention: forward_guard.quarantine_retention_days was present.
retention bool
// The remaining flags track keys under forward_guard.hold_signals.
bounceBackscatter bool
spamFlagged bool
malware bool
badSenderIP bool
authFail bool
}
// ForwardGuardConfig configures the email forward-guard. Enabled is the master
// switch (default off); DryRun (default true) accounts without holding. Each
// hold signal is individually toggleable; unset signals default on but only
// matter once Enabled.
type ForwardGuardConfig struct {
// Enabled turns the guard on; validateForwardGuard skips all checks
// while it is false.
Enabled bool `yaml:"enabled"`
// DryRun is forced to true by applyForwardGuardDefaults unless the
// dry_run key was present in the YAML.
DryRun bool `yaml:"dry_run"`
// HoldSignals selects which signals may hold a forwarded copy.
HoldSignals ForwardHoldSignals `yaml:"hold_signals"`
// SkipForwarders is not interpreted in this file.
// NOTE(review): presumably forwarders exempt from the guard -- confirm
// against the consumer.
SkipForwarders []string `yaml:"skip_forwarders"`
// QuarantineRetentionDays is set to 14 when the key is absent from the
// YAML and must be > 0 when the guard is enabled.
QuarantineRetentionDays int `yaml:"quarantine_retention_days"`
}
// ForwardHoldSignals toggles which layered signals may hold a forward copy.
// applyForwardGuardDefaults sets every signal to true unless its key was
// present in the YAML. validateForwardGuard requires BounceBackscatter or
// BadSenderIP to be on when the guard is enabled with dry_run:false.
type ForwardHoldSignals struct {
BounceBackscatter bool `yaml:"bounce_backscatter"`
SpamFlagged bool `yaml:"spam_flagged"`
Malware bool `yaml:"malware"`
BadSenderIP bool `yaml:"bad_sender_ip"`
AuthFail bool `yaml:"auth_fail"`
}
// applyDefaults fills unset fields of cfg with their shipped defaults,
// mutating cfg in place. For most fields "unset" means the Go zero
// value (0, "", nil or an empty slice), so an operator cannot configure
// a literal zero for those keys. The exceptions are the fields guarded
// by presence, which records whether the key was literally present in
// the YAML: thresholds.smtp_probe_threshold and the forward-guard
// fields handled by applyForwardGuardDefaults.
//
// Besides writing defaults, the function reads the environment
// variable named by the webhook HMACSecretEnv field and, when it is
// non-empty, stores its value in HMACSecret.
func applyDefaults(cfg *Config, presence defaultPresence) {
// Defaults
if cfg.StatePath == "" {
cfg.StatePath = "/var/lib/csm/state"
}
// Audit-log sink defaults are only filled in for sinks that are enabled.
if cfg.Alerts.AuditLog.File.Enabled && cfg.Alerts.AuditLog.File.Path == "" {
cfg.Alerts.AuditLog.File.Path = "/var/log/csm/audit.jsonl"
}
if cfg.Alerts.AuditLog.Syslog.Enabled {
if cfg.Alerts.AuditLog.Syslog.Network == "" {
cfg.Alerts.AuditLog.Syslog.Network = "udp"
}
if cfg.Alerts.AuditLog.Syslog.Facility == "" {
cfg.Alerts.AuditLog.Syslog.Facility = "local0"
}
}
// A non-empty value in the named environment variable replaces any
// HMAC secret given directly in the config.
if cfg.Alerts.Webhook.HMACSecretEnv != "" {
if v := os.Getenv(cfg.Alerts.Webhook.HMACSecretEnv); v != "" {
cfg.Alerts.Webhook.HMACSecret = v
}
}
if cfg.Signatures.RulesDir == "" {
cfg.Signatures.RulesDir = "/opt/csm/rules"
}
if cfg.Signatures.YaraForge.Tier == "" {
cfg.Signatures.YaraForge.Tier = "core"
}
if cfg.Signatures.YaraForge.UpdateInterval == "" {
cfg.Signatures.YaraForge.UpdateInterval = "168h"
}
if cfg.WebUI.Listen == "" {
cfg.WebUI.Listen = "0.0.0.0:9443"
}
// Promote the single AuthToken to an admin-scoped entry in Tokens, but
// only when no token list was configured.
if cfg.WebUI.AuthToken != "" && len(cfg.WebUI.Tokens) == 0 {
cfg.WebUI.Tokens = []WebUIToken{{
Name: "legacy-auth-token", Token: cfg.WebUI.AuthToken, Scope: "admin",
}}
}
if cfg.Thresholds.MailQueueWarn == 0 {
cfg.Thresholds.MailQueueWarn = 500
}
if cfg.Thresholds.MailQueueCrit == 0 {
cfg.Thresholds.MailQueueCrit = 2000
}
if cfg.Thresholds.StateExpiryHours == 0 {
cfg.Thresholds.StateExpiryHours = 24
}
if cfg.Thresholds.DeepScanIntervalMin == 0 {
cfg.Thresholds.DeepScanIntervalMin = 60
}
if cfg.Thresholds.WPCoreCheckIntervalMin == 0 {
cfg.Thresholds.WPCoreCheckIntervalMin = 60
}
if cfg.Thresholds.WebshellScanIntervalMin == 0 {
cfg.Thresholds.WebshellScanIntervalMin = 30
}
if cfg.Thresholds.FilesystemScanIntervalMin == 0 {
cfg.Thresholds.FilesystemScanIntervalMin = 30
}
if cfg.Thresholds.PluginCheckIntervalMin == 0 {
cfg.Thresholds.PluginCheckIntervalMin = 1440
}
if cfg.Thresholds.BruteForceWindow == 0 {
cfg.Thresholds.BruteForceWindow = 5000
}
if cfg.Thresholds.DomlogMaxFiles == 0 {
cfg.Thresholds.DomlogMaxFiles = 500
}
if cfg.Thresholds.AccountScanMaxFiles == 0 {
cfg.Thresholds.AccountScanMaxFiles = 10000
}
if cfg.Thresholds.CrontabBase64BlobMaxBytes == 0 {
cfg.Thresholds.CrontabBase64BlobMaxBytes = 16384
}
if cfg.Thresholds.DomlogTailLines == 0 {
cfg.Thresholds.DomlogTailLines = 500
}
if cfg.Thresholds.DomlogMaxAgeMin == 0 {
cfg.Thresholds.DomlogMaxAgeMin = 30
}
if cfg.Thresholds.MailLogTailLines == 0 {
cfg.Thresholds.MailLogTailLines = 500
}
if cfg.Thresholds.SyslogMessagesTailLines == 0 {
cfg.Thresholds.SyslogMessagesTailLines = 200
}
if cfg.Thresholds.CredStuffingDistinctAccounts == 0 {
cfg.Thresholds.CredStuffingDistinctAccounts = 5
}
if cfg.Thresholds.ModSecEscalationHits == 0 {
cfg.Thresholds.ModSecEscalationHits = 3
}
if cfg.Thresholds.ModSecEscalationWindowMin == 0 {
cfg.Thresholds.ModSecEscalationWindowMin = 10
}
// Unlike most thresholds, the two HTTP checks below use <= 0, so
// negative values are replaced by the default as well.
if cfg.Thresholds.HTTPFloodWindowMin <= 0 {
cfg.Thresholds.HTTPFloodWindowMin = 5
}
// HTTPFloodThreshold has no nonzero default: 0 means disabled and
// that is the shipped behavior.
if cfg.Thresholds.HTTPUASpoofThreshold <= 0 {
cfg.Thresholds.HTTPUASpoofThreshold = 30
}
if cfg.Thresholds.SMTPBruteForceThreshold == 0 {
cfg.Thresholds.SMTPBruteForceThreshold = 5
}
if cfg.Thresholds.SMTPBruteForceWindowMin == 0 {
cfg.Thresholds.SMTPBruteForceWindowMin = 10
}
if cfg.Thresholds.SMTPBruteForceSuppressMin == 0 {
cfg.Thresholds.SMTPBruteForceSuppressMin = 60
}
if cfg.Thresholds.SMTPBruteForceSubnetThresh == 0 {
cfg.Thresholds.SMTPBruteForceSubnetThresh = 8
}
if cfg.Thresholds.SMTPAccountSprayThreshold == 0 {
cfg.Thresholds.SMTPAccountSprayThreshold = 12
}
if cfg.Thresholds.SMTPBruteForceMaxTracked == 0 {
cfg.Thresholds.SMTPBruteForceMaxTracked = 20000
}
// An explicit smtp_probe_threshold: 0 in the YAML is kept; the default
// applies only when the key was absent.
if cfg.Thresholds.SMTPProbeThreshold == 0 && !presence.smtpProbeThreshold {
cfg.Thresholds.SMTPProbeThreshold = 100
}
if cfg.Thresholds.SMTPProbeWindowMin == 0 {
cfg.Thresholds.SMTPProbeWindowMin = 5
}
if cfg.Thresholds.SMTPProbeSuppressMin == 0 {
cfg.Thresholds.SMTPProbeSuppressMin = 60
}
if cfg.Thresholds.SMTPProbeMaxTracked == 0 {
cfg.Thresholds.SMTPProbeMaxTracked = 20000
}
if cfg.Thresholds.MailBruteForceThreshold == 0 {
cfg.Thresholds.MailBruteForceThreshold = 5
}
if cfg.Thresholds.MailBruteForceWindowMin == 0 {
cfg.Thresholds.MailBruteForceWindowMin = 10
}
if cfg.Thresholds.MailBruteForceSuppressMin == 0 {
cfg.Thresholds.MailBruteForceSuppressMin = 60
}
if cfg.Thresholds.MailBruteForceSubnetThresh == 0 {
cfg.Thresholds.MailBruteForceSubnetThresh = 8
}
if cfg.Thresholds.MailAccountSprayThreshold == 0 {
cfg.Thresholds.MailAccountSprayThreshold = 12
}
if cfg.Thresholds.MailBruteForceMaxTracked == 0 {
cfg.Thresholds.MailBruteForceMaxTracked = 20000
}
if cfg.Alerts.MaxPerHour == 0 {
cfg.Alerts.MaxPerHour = 30
}
if cfg.Challenge.ListenAddr == "" {
cfg.Challenge.ListenAddr = "127.0.0.1"
}
if cfg.Challenge.ListenPort == 0 {
cfg.Challenge.ListenPort = 8439
}
if cfg.Challenge.Difficulty == 0 {
cfg.Challenge.Difficulty = 2
}
// A missing firewall section gets the firewall package's own defaults.
if cfg.Firewall == nil {
cfg.Firewall = firewall.DefaultConfig()
}
if len(cfg.GeoIP.Editions) == 0 {
cfg.GeoIP.Editions = []string{"GeoLite2-City", "GeoLite2-ASN"}
}
if cfg.GeoIP.UpdateInterval == "" {
cfg.GeoIP.UpdateInterval = "24h"
}
EmailAVDefaults(&cfg.EmailAV)
if cfg.EmailProtection.PasswordCheckIntervalMin == 0 {
cfg.EmailProtection.PasswordCheckIntervalMin = 1440
}
if cfg.EmailProtection.RateWarnThreshold == 0 {
cfg.EmailProtection.RateWarnThreshold = 50
}
if cfg.EmailProtection.RateCritThreshold == 0 {
cfg.EmailProtection.RateCritThreshold = 100
}
if cfg.EmailProtection.RateWindowMin == 0 {
cfg.EmailProtection.RateWindowMin = 10
}
applyForwardGuardDefaults(&cfg.EmailProtection.ForwardGuard, presence.forwardGuard)
// EmailProtection.PHPRelay defaults. Freeze/DryRun are *bool and
// remain nil here -- accessors resolve the safe defaults
// (PHPRelayFreezeEnabled / PHPRelayDryRunEnabled) so we do NOT
// mutate them. AccountVolumePerHour stays at 0 by default to mark
// "auto-derive from cPanel maxemailsperhour" downstream.
if cfg.EmailProtection.PHPRelay.RateWindowMin == 0 {
cfg.EmailProtection.PHPRelay.RateWindowMin = 5
}
if cfg.EmailProtection.PHPRelay.HeaderScoreVolumeMin == 0 {
cfg.EmailProtection.PHPRelay.HeaderScoreVolumeMin = 5
}
if cfg.EmailProtection.PHPRelay.AbsoluteVolumePerHour == 0 {
cfg.EmailProtection.PHPRelay.AbsoluteVolumePerHour = 30
}
if cfg.EmailProtection.PHPRelay.ReputationFailuresPer24h == 0 {
cfg.EmailProtection.PHPRelay.ReputationFailuresPer24h = 3
}
if cfg.EmailProtection.PHPRelay.FanoutDistinctScripts == 0 {
cfg.EmailProtection.PHPRelay.FanoutDistinctScripts = 3
}
if cfg.EmailProtection.PHPRelay.FanoutWindowMin == 0 {
cfg.EmailProtection.PHPRelay.FanoutWindowMin = 5
}
if cfg.EmailProtection.PHPRelay.BaselineSigma == 0 {
cfg.EmailProtection.PHPRelay.BaselineSigma = 3.0
}
if cfg.EmailProtection.PHPRelay.BaselineObservationDays == 0 {
cfg.EmailProtection.PHPRelay.BaselineObservationDays = 7
}
if cfg.EmailProtection.PHPRelay.PoliciesDir == "" {
cfg.EmailProtection.PHPRelay.PoliciesDir = "/opt/csm/policies/php_relay"
}
if cfg.AutoResponse.PHPRelay.MaxActionsPerMinute == 0 {
cfg.AutoResponse.PHPRelay.MaxActionsPerMinute = 60
}
if cfg.AutoResponse.MaxBlocksPerHour == 0 {
cfg.AutoResponse.MaxBlocksPerHour = DefaultMaxBlocksPerHour
}
// Performance defaults.
// Enabled is a tri-state *bool: nil means "use system default (on)", true means
// explicitly enabled, false means explicitly disabled. We do NOT apply a default
// here so that callers can distinguish "operator left it unset" (nil) from
// "operator set it to true" (&true). All callers must nil-check before dereferencing;
// perfEnabled() in checks/performance.go treats nil as true.
if cfg.Performance.LoadHighMultiplier == 0 {
cfg.Performance.LoadHighMultiplier = 1.0
}
if cfg.Performance.LoadCriticalMultiplier == 0 {
cfg.Performance.LoadCriticalMultiplier = 2.0
}
if cfg.Performance.PHPProcessWarnPerUser == 0 {
cfg.Performance.PHPProcessWarnPerUser = 20
}
if cfg.Performance.PHPProcessCriticalTotalMult == 0 {
cfg.Performance.PHPProcessCriticalTotalMult = 5
}
if cfg.Performance.ErrorLogWarnSizeMB == 0 {
cfg.Performance.ErrorLogWarnSizeMB = 50
}
if cfg.Performance.MySQLJoinBufferMaxMB == 0 {
cfg.Performance.MySQLJoinBufferMaxMB = 64
}
if cfg.Performance.MySQLWaitTimeoutMax == 0 {
cfg.Performance.MySQLWaitTimeoutMax = 3600
}
if cfg.Performance.MySQLMaxConnectionsPerUser == 0 {
cfg.Performance.MySQLMaxConnectionsPerUser = 10
}
if cfg.Performance.RedisBgsaveMinInterval == 0 {
cfg.Performance.RedisBgsaveMinInterval = 900
}
if cfg.Performance.RedisLargeDatasetGB == 0 {
cfg.Performance.RedisLargeDatasetGB = 4
}
if cfg.Performance.WPMemoryLimitMaxMB == 0 {
cfg.Performance.WPMemoryLimitMaxMB = 512
}
if cfg.Performance.WPTransientWarnMB == 0 {
cfg.Performance.WPTransientWarnMB = 1
}
if cfg.Performance.WPTransientCriticalMB == 0 {
cfg.Performance.WPTransientCriticalMB = 10
}
if cfg.Performance.WPCronFix.IntervalMinutes == 0 {
cfg.Performance.WPCronFix.IntervalMinutes = 5
}
if cfg.Cloudflare.RefreshHours == 0 {
cfg.Cloudflare.RefreshHours = 6
}
// Retention: defaults apply whether or not the feature is enabled, so
// that flipping `enabled: true` without further tuning gives the
// documented behaviour.
if cfg.Retention.FindingsDays == 0 {
cfg.Retention.FindingsDays = 90
}
if cfg.Retention.HistoryDays == 0 {
cfg.Retention.HistoryDays = 30
}
if cfg.Retention.ReputationDays == 0 {
cfg.Retention.ReputationDays = 180
}
if cfg.Retention.SweepInterval == "" {
cfg.Retention.SweepInterval = "24h"
}
if cfg.Retention.CompactMinSizeMB == 0 {
cfg.Retention.CompactMinSizeMB = 128
}
if cfg.Retention.CompactFillRatio == 0 {
cfg.Retention.CompactFillRatio = 0.5
}
if cfg.MailLogs.Source == "" {
cfg.MailLogs.Source = "auto"
}
if len(cfg.MailLogs.Units) == 0 {
cfg.MailLogs.Units = []string{"postfix", "dovecot"}
}
// Updates: the same fallbacks are also applied by the UpdatesInterval
// and UpdatesPackageName accessors.
if cfg.Updates.Interval == "" {
cfg.Updates.Interval = "24h"
}
if cfg.Updates.PackageName == "" {
cfg.Updates.PackageName = "csm"
}
if cfg.Thresholds.MailBruteAccountKey == "" {
cfg.Thresholds.MailBruteAccountKey = "builtin:dovecot-user"
}
if cfg.Reputation.Rspamd.URL == "" {
cfg.Reputation.Rspamd.URL = "http://127.0.0.1:11334"
}
// Token resolution happens at query time (see RspamdSource.Score).
if cfg.Reputation.Upstream.CacheTTLMin == 0 {
cfg.Reputation.Upstream.CacheTTLMin = 15
}
if cfg.Reputation.Upstream.TimeoutSec == 0 {
cfg.Reputation.Upstream.TimeoutSec = 5
}
// Token resolution happens at query time (UpstreamSource.resolveToken).
if cfg.AutoResponse.VerdictCallback.TimeoutSec == 0 {
cfg.AutoResponse.VerdictCallback.TimeoutSec = 2 // tight; the hook is on the block hot path
}
// Secret resolution happens at call time (verdict.Client reads env per call).
// Direct SMTP egress detector defaults. Backend "auto" lets the runtime
// pick BPF where available and fall back to legacy polling. Standard
// submission/relay ports cover the bulk of mass-mail abuse seen in the
// wild; operators can override via YAML to add e.g. 2525.
if cfg.Detection.DirectSMTPEgress.Backend == "" {
cfg.Detection.DirectSMTPEgress.Backend = "auto"
}
if len(cfg.Detection.DirectSMTPEgress.Ports) == 0 {
cfg.Detection.DirectSMTPEgress.Ports = []int{25, 465, 587}
}
}
// MaxConfigBytes caps the YAML config body size LoadBytes will parse.
// Real CSM configs (main + every drop-in fragment) top out near 64 KB
// even with verbose comments; 4 MB is several orders of magnitude
// above legitimate use and far below the size at which a malformed
// or attacker-supplied file would force the YAML parser to allocate
// gigabytes of intermediate state.
const MaxConfigBytes = 4 * 1024 * 1024

// errConfigTooLarge is returned by readConfigBytesLimited when the
// input is longer than MaxConfigBytes.
var errConfigTooLarge = errors.New("config input exceeds byte cap")

// readConfigBytesLimited reads r to the end, refusing inputs longer
// than MaxConfigBytes. At most MaxConfigBytes+1 bytes are consumed:
// the one extra byte is what distinguishes an oversized input from an
// input of exactly MaxConfigBytes. Read errors are returned unchanged;
// oversized input yields errConfigTooLarge. The byte slice is nil on
// any error.
func readConfigBytesLimited(r io.Reader) ([]byte, error) {
	capped := &io.LimitedReader{R: r, N: MaxConfigBytes + 1}
	body, err := io.ReadAll(capped)
	switch {
	case err != nil:
		return nil, err
	case len(body) > MaxConfigBytes:
		return nil, errConfigTooLarge
	}
	return body, nil
}
// LoadBytes decodes a YAML config body, applies all defaults and runs
// the load-time validators, matching Load. ConfigFile is left empty;
// the caller sets it.
//
// Inputs longer than MaxConfigBytes are rejected before any parsing.
// Decoding is strict (KnownFields), so unknown keys are an error. An
// io.EOF from the decoder is not treated as an error; the result is
// then a Config carrying only defaults. The first failing validator's
// error is returned as-is and the Config is nil on any error.
func LoadBytes(data []byte) (*Config, error) {
	if size := len(data); size > MaxConfigBytes {
		return nil, fmt.Errorf("parsing config: input size %d exceeds %d byte cap", size, MaxConfigBytes)
	}

	// Key presence has to be captured from the raw YAML: once decoded
	// into Config, an explicit zero/false looks the same as "unset".
	presence, err := defaultPresenceFromYAML(data)
	if err != nil {
		return nil, fmt.Errorf("parsing config: %w", err)
	}

	loaded := new(Config)
	decoder := yaml.NewDecoder(bytes.NewReader(data))
	decoder.KnownFields(true)
	if err := decoder.Decode(loaded); err != nil {
		if !errors.Is(err, io.EOF) {
			return nil, fmt.Errorf("parsing config: %w", err)
		}
	}

	applyDefaults(loaded, presence)

	// Validators run after defaults so they see the effective values.
	if err := validateWebUITokens(loaded); err != nil {
		return nil, err
	}
	if err := validateMailLogs(loaded); err != nil {
		return nil, err
	}
	if err := validateMailBruteAccountKey(loaded); err != nil {
		return nil, err
	}
	if err := validateReputation(loaded); err != nil {
		return nil, err
	}
	if err := validateVerdictCallback(loaded); err != nil {
		return nil, err
	}
	if err := validateDirectSMTPEgress(loaded); err != nil {
		return nil, err
	}
	if err := validateForwardGuard(loaded); err != nil {
		return nil, err
	}
	return loaded, nil
}
// defaultPresenceFromYAML inspects the raw YAML and reports which of
// the presence-tracked keys appear in it: thresholds.smtp_probe_threshold
// and, under email_protection.forward_guard, dry_run,
// quarantine_retention_days and the individual hold_signals entries.
// Only key presence is recorded; values are not interpreted.
//
// Blank input yields the zero value. A YAML error is returned together
// with whatever presence flags had been set before the failure.
func defaultPresenceFromYAML(data []byte) (defaultPresence, error) {
	var seen defaultPresence
	if len(bytes.TrimSpace(data)) == 0 {
		return seen, nil
	}

	type nodeMap = map[string]yaml.Node
	var doc struct {
		Thresholds      nodeMap `yaml:"thresholds"`
		EmailProtection struct {
			ForwardGuard nodeMap `yaml:"forward_guard"`
		} `yaml:"email_protection"`
	}
	if err := yaml.Unmarshal(data, &doc); err != nil {
		return seen, err
	}

	// has reports key presence; lookups on a nil map are simply false.
	has := func(m nodeMap, key string) bool {
		_, ok := m[key]
		return ok
	}

	seen.smtpProbeThreshold = has(doc.Thresholds, "smtp_probe_threshold")

	guard := doc.EmailProtection.ForwardGuard
	seen.forwardGuard.dryRun = has(guard, "dry_run")
	seen.forwardGuard.retention = has(guard, "quarantine_retention_days")

	signalsNode, ok := guard["hold_signals"]
	if !ok {
		return seen, nil
	}
	var signals nodeMap
	if err := signalsNode.Decode(&signals); err != nil {
		return seen, err
	}
	seen.forwardGuard.bounceBackscatter = has(signals, "bounce_backscatter")
	seen.forwardGuard.spamFlagged = has(signals, "spam_flagged")
	seen.forwardGuard.malware = has(signals, "malware")
	seen.forwardGuard.badSenderIP = has(signals, "bad_sender_ip")
	seen.forwardGuard.authFail = has(signals, "auth_fail")
	return seen, nil
}
// applyForwardGuardDefaults fills the forward-guard fields whose keys
// were absent from the YAML. dry_run and every hold signal are forced
// to true, and quarantine_retention_days is set to 14. A field whose
// key was present (per p) is left exactly as decoded, so an explicit
// false or an explicit retention value is kept.
func applyForwardGuardDefaults(fg *ForwardGuardConfig, p forwardGuardPresence) {
	// Boolean switches that default on, each paired with its presence flag.
	defaultOn := []struct {
		explicit bool
		field    *bool
	}{
		{p.dryRun, &fg.DryRun},
		{p.bounceBackscatter, &fg.HoldSignals.BounceBackscatter},
		{p.spamFlagged, &fg.HoldSignals.SpamFlagged},
		{p.malware, &fg.HoldSignals.Malware},
		{p.badSenderIP, &fg.HoldSignals.BadSenderIP},
		{p.authFail, &fg.HoldSignals.AuthFail},
	}
	for _, toggle := range defaultOn {
		if !toggle.explicit {
			*toggle.field = true
		}
	}
	if !p.retention {
		fg.QuarantineRetentionDays = 14
	}
}
// validateForwardGuard rejects nonsensical or unsafe forward-guard
// configs. Nothing is checked while the guard is disabled. When it is
// enabled:
//
//   - quarantine_retention_days must be positive;
//   - enforce mode (dry_run:false) needs bounce_backscatter or
//     bad_sender_ip switched on. Per the original rationale these are
//     the signals exim can evaluate at routing time, whereas
//     spam_flagged/malware/auth_fail are only accounted in dry-run, so
//     enforcing with only those would silently hold nothing.
func validateForwardGuard(cfg *Config) error {
	guard := &cfg.EmailProtection.ForwardGuard
	if !guard.Enabled {
		return nil
	}
	if guard.QuarantineRetentionDays <= 0 {
		return fmt.Errorf("email_protection.forward_guard.quarantine_retention_days must be > 0 when enabled")
	}
	enforceable := guard.HoldSignals.BounceBackscatter || guard.HoldSignals.BadSenderIP
	if guard.DryRun || enforceable {
		return nil
	}
	return fmt.Errorf("email_protection.forward_guard: enforce mode (dry_run:false) requires bounce_backscatter or bad_sender_ip enabled; spam_flagged/malware/auth_fail are dry-run only until exim content scanning is enabled")
}
// validateReputation checks the reputation.upstream settings.
//
// The numeric bounds are enforced even when the upstream is disabled: a
// non-zero cache_ttl_min must lie in [1, 1440] and a non-zero
// timeout_sec in [1, 60]; zero is accepted for both. The URL checks run
// only when the upstream is enabled: the URL must be non-blank, parse,
// use http or https, and include a host. Plain http is accepted only
// for loopback hosts, because the bearer token would otherwise travel
// in plaintext.
func validateReputation(cfg *Config) error {
	up := cfg.Reputation.Upstream
	if ttl := up.CacheTTLMin; ttl != 0 {
		if ttl < 1 || ttl > 1440 {
			return fmt.Errorf("reputation.upstream.cache_ttl_min must be between 1 and 1440")
		}
	}
	if timeout := up.TimeoutSec; timeout != 0 {
		if timeout < 1 || timeout > 60 {
			return fmt.Errorf("reputation.upstream.timeout_sec must be between 1 and 60")
		}
	}
	if !up.Enabled {
		return nil
	}

	if strings.TrimSpace(up.URL) == "" {
		return fmt.Errorf("reputation.upstream.enabled=true but url is empty")
	}
	target, err := url.Parse(up.URL)
	if err != nil {
		return fmt.Errorf("reputation.upstream.url: %w", err)
	}
	switch target.Scheme {
	case "http", "https":
	default:
		return fmt.Errorf("reputation.upstream.url must use http or https")
	}
	if target.Host == "" {
		return fmt.Errorf("reputation.upstream.url must include host")
	}
	if target.Scheme == "http" && !isLoopbackHost(target.Hostname()) {
		return fmt.Errorf("reputation.upstream.url must use https for non-loopback hosts (bearer token would otherwise leak in plaintext)")
	}
	return nil
}
// isLoopbackHost reports whether host names the local machine: the
// literal name "localhost" (compared case-insensitively) or an IP
// literal that net.IP.IsLoopback accepts. Any other hostname, and the
// empty string, is not loopback; no DNS lookup is performed. It is
// used to keep plain HTTP limited to same-host panel deployments.
func isLoopbackHost(host string) bool {
	switch {
	case host == "":
		return false
	case strings.EqualFold(host, "localhost"):
		return true
	}
	if addr := net.ParseIP(host); addr != nil {
		return addr.IsLoopback()
	}
	return false
}
// validateVerdictCallback runs validateVerdictCallbackField and returns
// only its error, discarding the offending field path.
func validateVerdictCallback(cfg *Config) error {
	if _, err := validateVerdictCallbackField(cfg); err != nil {
		return err
	}
	return nil
}
// validateVerdictCallbackField validates auto_response.verdict_callback
// and, on failure, returns the dotted config path of the offending
// field together with the error. On success it returns ("", nil).
//
// timeout_sec is bounded to [1, 30] whenever it is non-zero, even with
// the callback disabled. All other checks run only when the callback is
// enabled: the whitespace-trimmed URL must be non-empty, parse, use
// http or https and include a host, and the HMAC secret rules of
// validateVerdictCallbackSecret must hold.
func validateVerdictCallbackField(cfg *Config) (string, error) {
	const (
		timeoutField = "auto_response.verdict_callback.timeout_sec"
		urlField     = "auto_response.verdict_callback.url"
	)

	vc := cfg.AutoResponse.VerdictCallback
	if timeout := vc.TimeoutSec; timeout != 0 {
		if timeout < 1 || timeout > 30 {
			return timeoutField, fmt.Errorf("auto_response.verdict_callback.timeout_sec must be between 1 and 30")
		}
	}
	if !vc.Enabled {
		return "", nil
	}

	trimmed := strings.TrimSpace(vc.URL)
	if trimmed == "" {
		return urlField, fmt.Errorf("auto_response.verdict_callback.enabled=true but url is empty")
	}
	target, err := url.Parse(trimmed)
	if err != nil {
		return urlField, fmt.Errorf("auto_response.verdict_callback.url: %w", err)
	}
	switch target.Scheme {
	case "http", "https":
	default:
		return urlField, fmt.Errorf("auto_response.verdict_callback.url must use http or https")
	}
	if target.Host == "" {
		return urlField, fmt.Errorf("auto_response.verdict_callback.url must include host")
	}

	secret := verdictCallbackForValidation{
		HMACSecret:    vc.HMACSecret,
		HMACSecretEnv: vc.HMACSecretEnv,
		AllowUnsigned: vc.AllowUnsigned,
	}
	if err := validateVerdictCallbackSecret(secret); err != nil {
		return verdictCallbackSecretField(vc.HMACSecretEnv), err
	}
	return "", nil
}
// validateVerdictCallbackSecret enforces a fail-closed posture on the
// outbound HMAC of an enabled verdict callback. It passes when any one
// of these holds:
//
//   - allow_unsigned is true, i.e. the operator explicitly accepts
//     requests and responses without integrity protection;
//   - hmac_secret is non-blank;
//   - hmac_secret_env is set and the environment variable it names
//     currently holds a non-blank value.
//
// Otherwise an error is returned. Without this check a misconfigured
// deployment (env var typoed, secret not yet rotated in) would silently
// emit unsigned POSTs while the daemon keeps reporting healthy, and any
// on-path actor could forge or replay block decisions.
func validateVerdictCallbackSecret(vc verdictCallbackForValidation) error {
	if vc.AllowUnsigned || strings.TrimSpace(vc.HMACSecret) != "" {
		return nil
	}
	if vc.HMACSecretEnv == "" {
		return fmt.Errorf("auto_response.verdict_callback.enabled=true requires hmac_secret or hmac_secret_env (or allow_unsigned: true to acknowledge unsigned requests and responses)")
	}
	if strings.TrimSpace(os.Getenv(vc.HMACSecretEnv)) == "" {
		return fmt.Errorf("auto_response.verdict_callback.enabled=true but env var %q is empty or unset; set the secret, or opt in with allow_unsigned: true", vc.HMACSecretEnv)
	}
	return nil
}

// verdictCallbackForValidation carries just the fields
// validateVerdictCallbackSecret needs, so callers do not have to
// re-spell the anonymous struct literal in config.go.
type verdictCallbackForValidation struct {
	HMACSecret    string
	HMACSecretEnv string
	AllowUnsigned bool
}
// verdictCallbackSecretField picks the config path to blame for a
// verdict-callback secret error: the hmac_secret_env key when an env
// var name was configured, otherwise the hmac_secret key.
func verdictCallbackSecretField(hmacSecretEnv string) string {
	switch hmacSecretEnv {
	case "":
		return "auto_response.verdict_callback.hmac_secret"
	default:
		return "auto_response.verdict_callback.hmac_secret_env"
	}
}
// validateDirectSMTPEgress checks detection.direct_smtp_egress. The
// backend, compared after trimming whitespace and lower-casing, must be
// empty or one of auto, bpf, legacy, none. Every configured port must
// be in 1..65535; the error names the index of the first bad entry.
func validateDirectSMTPEgress(cfg *Config) error {
	detector := cfg.Detection.DirectSMTPEgress

	backend := strings.ToLower(strings.TrimSpace(detector.Backend))
	known := backend == "" || backend == "auto" || backend == "bpf" ||
		backend == "legacy" || backend == "none"
	if !known {
		return fmt.Errorf("detection.direct_smtp_egress.backend must be auto, bpf, legacy, or none")
	}

	for idx := range detector.Ports {
		if port := detector.Ports[idx]; port < 1 || port > 65535 {
			return fmt.Errorf("detection.direct_smtp_egress.ports[%d] must be between 1 and 65535", idx)
		}
	}
	return nil
}
// validateBPFEnforcement checks that an enabled bpf_enforcement section is
// backed by compatible detection settings. It is a no-op when enforcement is
// disabled. With enforcement on, the connection tracker backend must be
// auto/bpf (or empty), each requested feature gate must have its detection
// counterpart enabled on a compatible backend, and at least one gate must be
// requested.
func validateBPFEnforcement(cfg *Config) error {
	if !cfg.BPFEnforcement.Enabled {
		return nil
	}

	tracker := strings.ToLower(strings.TrimSpace(cfg.Detection.ConnectionTrackerBackend))
	switch tracker {
	case "legacy", "none":
		return fmt.Errorf("bpf_enforcement.enabled=true requires detection.connection_tracker_backend=auto or bpf")
	case "", "auto", "bpf":
		// compatible with enforcement
	default:
		return fmt.Errorf("detection.connection_tracker_backend must be auto, bpf, legacy, or none")
	}

	// Count the feature gates that passed their checks; enforcement with no
	// gate at all is rejected below.
	enabledGates := 0
	if cfg.BPFEnforcement.DirectSMTPEgress {
		if !cfg.Detection.DirectSMTPEgress.Enabled {
			return fmt.Errorf("bpf_enforcement.direct_smtp_egress requires detection.direct_smtp_egress.enabled=true")
		}
		egressBackend := strings.ToLower(strings.TrimSpace(cfg.Detection.DirectSMTPEgress.Backend))
		switch egressBackend {
		case "legacy", "none":
			return fmt.Errorf("bpf_enforcement.direct_smtp_egress requires detection.direct_smtp_egress.backend=auto or bpf")
		case "", "auto", "bpf":
			// compatible with enforcement
		default:
			return fmt.Errorf("detection.direct_smtp_egress.backend must be auto, bpf, legacy, or none")
		}
		enabledGates++
	}

	if enabledGates == 0 {
		return fmt.Errorf("bpf_enforcement.enabled=true requires at least one feature gate (direct_smtp_egress)")
	}
	return nil
}
// validateWebUITokens checks every webui.tokens entry in order and returns
// the first problem found: a blank name, a scope other than admin/read, an
// empty token, a name already used by an earlier entry (names are compared
// after trimming whitespace), or a token value already used by an earlier
// entry.
func validateWebUITokens(cfg *Config) error {
	tokens := cfg.WebUI.Tokens
	usedNames := make(map[string]struct{}, len(tokens))
	usedSecrets := make(map[string]struct{}, len(tokens))

	for idx := range tokens {
		entry := tokens[idx]
		trimmedName := strings.TrimSpace(entry.Name)

		// Per-entry shape checks, in the order operators see them reported.
		switch {
		case trimmedName == "":
			return fmt.Errorf("webui.tokens[%d]: empty name", idx)
		case entry.Scope != "admin" && entry.Scope != "read":
			return fmt.Errorf("webui.tokens[%d]: unknown scope %q (use admin or read)", idx, entry.Scope)
		case entry.Token == "":
			return fmt.Errorf("webui.tokens[%d]: empty token", idx)
		}

		// Cross-entry uniqueness checks.
		if _, dup := usedNames[trimmedName]; dup {
			return fmt.Errorf("webui.tokens[%d]: duplicate name %q", idx, entry.Name)
		}
		usedNames[trimmedName] = struct{}{}

		if _, dup := usedSecrets[entry.Token]; dup {
			return fmt.Errorf("webui.tokens[%d]: duplicate token", idx)
		}
		usedSecrets[entry.Token] = struct{}{}
	}
	return nil
}
// validateMailLogsField validates the mail_logs section and, on failure,
// returns the dotted key of the offending setting together with the error.
// A valid section yields ("", nil).
func validateMailLogsField(cfg *Config) (string, error) {
	source := cfg.MailLogs.Source
	switch source {
	case "", "auto", "file", "journal":
		// accepted source
	default:
		return "mail_logs.source", fmt.Errorf("mail_logs.source: must be auto, file, or journal (got %q)", source)
	}

	units := cfg.MailLogs.Units
	for idx := range units {
		if strings.TrimSpace(units[idx]) != "" {
			continue
		}
		return "mail_logs.units", fmt.Errorf("mail_logs.units[%d]: empty unit", idx)
	}
	return "", nil
}
// validateMailLogs is validateMailLogsField without the field name: it
// reports only whether the mail_logs section is valid.
func validateMailLogs(cfg *Config) error {
	if _, err := validateMailLogsField(cfg); err != nil {
		return err
	}
	return nil
}
// validateMailBruteAccountKeyField validates thresholds.mail_brute_account_key.
// Accepted values are the empty string, the two built-in extractors, or
// "regex:<pattern>" where the pattern compiles and has at least one capture
// group. On failure it returns the dotted key name and the error; on success
// it returns ("", nil).
func validateMailBruteAccountKeyField(cfg *Config) (string, error) {
	key := cfg.Thresholds.MailBruteAccountKey

	if key == "" || key == "builtin:dovecot-user" || key == "builtin:postfix-sasl" {
		return "", nil
	}
	if !strings.HasPrefix(key, "regex:") {
		return "thresholds.mail_brute_account_key", fmt.Errorf("mail_brute_account_key: %q must be builtin:* or regex:*", key)
	}

	pattern := strings.TrimPrefix(key, "regex:")
	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return "thresholds.mail_brute_account_key", fmt.Errorf("mail_brute_account_key: invalid regex: %w", err)
	}
	// The key is required to expose at least one capture group.
	if compiled.NumSubexp() < 1 {
		return "thresholds.mail_brute_account_key", fmt.Errorf("mail_brute_account_key: regex must contain at least one capture group")
	}
	return "", nil
}
// validateMailBruteAccountKey is validateMailBruteAccountKeyField without the
// field name: it reports only whether the setting is valid.
func validateMailBruteAccountKey(cfg *Config) error {
	if _, err := validateMailBruteAccountKeyField(cfg); err != nil {
		return err
	}
	return nil
}
// Load opens the config file at path, reads it through the size-capped
// reader, parses it with LoadBytes, and records path in the returned
// Config's ConfigFile field. An oversized file is reported with the byte cap
// rather than as a generic read error.
func Load(path string) (*Config, error) {
	// #nosec G304 -- path is operator-supplied config file (CLI flag / env).
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("reading config %s: %w", path, err)
	}
	defer file.Close()

	raw, readErr := readConfigBytesLimited(file)
	switch {
	case errors.Is(readErr, errConfigTooLarge):
		return nil, fmt.Errorf("config %s exceeds %d byte cap", path, MaxConfigBytes)
	case readErr != nil:
		return nil, fmt.Errorf("reading config %s: %w", path, readErr)
	}

	loaded, err := LoadBytes(raw)
	if err != nil {
		return nil, err
	}
	loaded.ConfigFile = path
	return loaded, nil
}
// LoadWithDir reads the main config file, overlays every fragment returned
// by loadConfDirFragments(confDir) on top of it in the order returned
// (documented as lexicographic), and parses the merged document with
// LoadBytes. Each scalar a fragment overrides is logged to stderr with
// sensitive values redacted. A missing confDir is documented as not being an
// error, and unknown fields in fragments as rejected (KnownFields=true);
// both behaviors live in the helpers, not in this function.
//
// The returned Config carries path in ConfigFile and confDir in ConfigDir.
func LoadWithDir(path, confDir string) (*Config, error) {
	// #nosec G304 -- path is operator-supplied (CLI flag).
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("reading config %s: %w", path, err)
	}
	defer file.Close()

	raw, err := readConfigBytesLimited(file)
	switch {
	case errors.Is(err, errConfigTooLarge):
		return nil, fmt.Errorf("config %s exceeds %d byte cap", path, MaxConfigBytes)
	case err != nil:
		return nil, fmt.Errorf("reading config %s: %w", path, err)
	}

	var doc yaml.Node
	if parseErr := yaml.Unmarshal(raw, &doc); parseErr != nil {
		return nil, fmt.Errorf("parsing %s: %w", path, parseErr)
	}

	fragments, err := loadConfDirFragments(confDir)
	if err != nil {
		return nil, err
	}
	for _, fragment := range fragments {
		// The callback runs synchronously inside DeepMergeTracked, once per
		// scalar this fragment overrides.
		logOverride := func(keyPath, oldVal, newVal string) {
			fmt.Fprintf(os.Stderr, "confd: %s overrides %s: %q -> %q\n",
				fragment.path,
				keyPath,
				redactConfigScalarForLog(keyPath, oldVal),
				redactConfigScalarForLog(keyPath, newVal),
			)
		}
		DeepMergeTracked(&doc, fragment.node, logOverride)
	}

	// Round-trip through bytes so the merged tree goes through the same
	// LoadBytes path as a plain single-file load.
	combined, err := yaml.Marshal(&doc)
	if err != nil {
		return nil, fmt.Errorf("marshaling merged config: %w", err)
	}
	loaded, err := LoadBytes(combined)
	if err != nil {
		return nil, err
	}
	loaded.ConfigFile = path
	loaded.ConfigDir = confDir
	return loaded, nil
}
// Save serializes cfg to YAML and writes it atomically, with mode 0600, to
// the path resolved by saveTargetPath from cfg.ConfigFile.
func Save(cfg *Config) error {
	encoded, marshalErr := yaml.Marshal(cfg)
	if marshalErr != nil {
		return fmt.Errorf("marshaling config: %w", marshalErr)
	}

	target, pathErr := saveTargetPath(cfg.ConfigFile)
	if pathErr != nil {
		return pathErr
	}

	// csm.yaml is the daemon's only config: a truncate-in-place write torn
	// by a crash would block the next daemon start with a parse error.
	return atomicio.AtomicWrite(target, 0600, encoded)
}
// saveTargetPath returns the path Save should write to for the configured
// config file location.
//
//   - If path does not exist yet, or exists and is not a symlink, path is
//     returned unchanged.
//   - If path is a symlink, the fully resolved target is returned so the
//     write updates the real file instead of replacing the link.
//
// Any other Lstat failure, or a symlink that cannot be resolved, is returned
// as a wrapped error with an empty path.
func saveTargetPath(path string) (string, error) {
	info, err := os.Lstat(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Nothing there yet: the write will create the file at path.
		return path, nil
	case err != nil:
		return "", fmt.Errorf("checking config path: %w", err)
	case info.Mode()&os.ModeSymlink == 0:
		return path, nil
	}

	// FHS migration leaves the legacy config path as a symlink to the
	// main config. Save must update the target instead of replacing the link.
	resolved, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", fmt.Errorf("resolving config symlink: %w", err)
	}
	return resolved, nil
}
package config
import "time"
// EmailAVConfig holds email antivirus scanning settings.
// Zero-valued fields other than the booleans and FailMode are filled in by
// EmailAVDefaults.
type EmailAVConfig struct {
Enabled bool `yaml:"enabled"`
// ClamdSocket is the clamd socket path.
ClamdSocket string `yaml:"clamd_socket"`
// ScanTimeout is a Go duration string (e.g. "30s"); see ScanTimeoutDuration.
ScanTimeout string `yaml:"scan_timeout"`
// MaxAttachmentSize and MaxExtractionSize default to 25 MB and 100 MB
// respectively (see EmailAVDefaults).
MaxAttachmentSize int64 `yaml:"max_attachment_size"`
MaxArchiveDepth int `yaml:"max_archive_depth"`
MaxArchiveFiles int `yaml:"max_archive_files"`
MaxExtractionSize int64 `yaml:"max_extraction_size"`
QuarantineInfected bool `yaml:"quarantine_infected"`
ScanConcurrency int `yaml:"scan_concurrency"`
// FailMode controls behavior when scanning cannot complete.
// "open" (default): deliver mail when engines are down, scans time out, or quarantine fails.
// "tempfail": defer delivery (Exim retries later) when scanning cannot complete or infected mail cannot be quarantined.
FailMode string `yaml:"fail_mode"`
}
// ScanTimeoutDuration returns ScanTimeout as a time.Duration. An empty or
// unparseable value yields the 30-second fallback.
func (c *EmailAVConfig) ScanTimeoutDuration() time.Duration {
	const fallback = 30 * time.Second

	if c.ScanTimeout != "" {
		if parsed, err := time.ParseDuration(c.ScanTimeout); err == nil {
			return parsed
		}
	}
	return fallback
}
// EmailAVDefaults fills every zero-valued tunable of c with its default.
// Fields that already hold a non-zero value are left alone; Enabled,
// QuarantineInfected and FailMode are never touched.
func EmailAVDefaults(c *EmailAVConfig) {
	const mib = 1024 * 1024

	// String settings.
	if c.ClamdSocket == "" {
		c.ClamdSocket = "/var/run/clamd.scan/clamd.sock"
	}
	if c.ScanTimeout == "" {
		c.ScanTimeout = "30s"
	}

	// Size limits.
	if c.MaxAttachmentSize == 0 {
		c.MaxAttachmentSize = 25 * mib // 25 MB
	}
	if c.MaxExtractionSize == 0 {
		c.MaxExtractionSize = 100 * mib // 100 MB
	}

	// Archive traversal and worker limits.
	if c.MaxArchiveDepth == 0 {
		c.MaxArchiveDepth = 1
	}
	if c.MaxArchiveFiles == 0 {
		c.MaxArchiveFiles = 50
	}
	if c.ScanConcurrency == 0 {
		c.ScanConcurrency = 4
	}
}
package config
import (
"reflect"
"sync/atomic"
)
// Hot-reload policy (ROADMAP item 7).
//
// Each top-level field of Config carries an optional `hotreload`
// struct tag:
//
// - "safe": a SIGHUP reload swaps the field in place; readers
// on the next tick see the new value.
// - "restart": the field cannot be applied without a full daemon
// restart (fanotify watched roots, bbolt path, web UI
// listener). A SIGHUP that touches a restart field
// logs a warning and leaves the prior config live.
// - (none): treated as restart-required. Tagging every safe
// field explicitly is the strict default; this closes
// the door on accidental hot-swaps of a fresh field
// someone adds without considering the safety of live
// mutation.
const (
// TagSafe is the `hotreload` tag value for fields that may be swapped live.
TagSafe = "safe"
// TagRestart is the `hotreload` tag value for fields that need a daemon
// restart; any other or missing tag value is classified the same way.
TagRestart = "restart"
)
// active holds the current live config. Readers on hot paths (check
// tick handlers, alert dispatchers, metrics-auth) call Active() to
// pick up the latest snapshot after a SIGHUP. Writers (daemon
// startup, SIGHUP reload) call SetActive.
//
// The zero value holds a nil pointer, which is what Active returns until
// SetActive has stored a config.
var active atomic.Pointer[Config]
// Active returns the Config most recently installed with SetActive, or nil
// when SetActive has not been called yet. Hot-path callers are expected to
// nil-check once per call; daemon startup installs a config before any tick
// runs, so nil in production indicates a bug.
func Active() *Config {
	current := active.Load()
	return current
}
// SetActive atomically publishes cfg as the live config; subsequent Active
// calls return it. Storing nil makes Active return nil again.
func SetActive(cfg *Config) {
	active.Store(cfg)
}
// Change describes a single field that differs between an old and a new
// Config, as reported by Diff. Despite the historical "top-level" wording,
// Field may be a dotted path when a nested field is reported separately.
type Change struct {
// Field is the YAML name (from the `yaml:"..."` struct tag), or
// the Go field name if no yaml tag is set.
Field string
// Tag is the hotreload classification: TagSafe, TagRestart, or
// "" for fields with no explicit tag (treated as TagRestart).
Tag string
}
// ReloadPolicy is the top-level hot-reload manifest exposed to tests and
// operator surfaces that need to explain whether a Settings section can be
// applied live or waits for a restart.
type ReloadPolicy struct {
// Field is the top-level field's name as produced by yamlFieldName.
Field string
// Tag is always TagSafe or TagRestart; HotReloadManifest normalizes a
// missing or unknown tag to TagRestart.
Tag string
// RestartRequired is true for every Tag other than TagSafe.
RestartRequired bool
}
// HotReloadManifest returns every operator-owned top-level Config field with
// its effective reload policy, in struct declaration order. Fields without
// an explicit supported tag are classified restart-required, matching Diff
// and RestartRequired.
//
// Skipped fields: unexported fields, the daemon-managed roots named by
// isReloadManifestIgnoredRoot, fields excluded from YAML with a `yaml:"-"`
// tag, and fields whose YAML name resolves to "" or "-".
func HotReloadManifest() []ReloadPolicy {
	cfgType := reflect.TypeOf(Config{})
	policies := make([]ReloadPolicy, 0, cfgType.NumField())
	for i := 0; i < cfgType.NumField(); i++ {
		field := cfgType.Field(i)
		if !field.IsExported() || isReloadManifestIgnoredRoot(field) {
			continue
		}
		// yamlFieldName maps a bare `yaml:"-"` tag to the Go field name, so
		// the name check below can never see it. Test the raw tag here so a
		// field that is excluded from YAML is not listed as operator-owned.
		if field.Tag.Get("yaml") == "-" {
			continue
		}
		name := yamlFieldName(field)
		if name == "" || name == "-" {
			continue
		}
		// Anything that is not explicitly safe or restart is normalized to
		// restart, the strict default.
		tag := field.Tag.Get("hotreload")
		if tag != TagSafe && tag != TagRestart {
			tag = TagRestart
		}
		policies = append(policies, ReloadPolicy{
			Field:           name,
			Tag:             tag,
			RestartRequired: tag != TagSafe,
		})
	}
	return policies
}
// isReloadManifestIgnoredRoot reports whether field is one of the Config
// roots (ConfigFile, ConfigDir, Integrity) that the hot-reload manifest and
// the root-level diff leave out. The match is on the Go field name only.
func isReloadManifestIgnoredRoot(field reflect.StructField) bool {
	switch field.Name {
	case "ConfigFile", "ConfigDir", "Integrity":
		return true
	default:
		return false
	}
}
// Diff reports which Config fields differ between oldCfg and newCfg, each
// classified by hotreload tag. It returns nil when either argument is nil.
//
// The walk (see diffStruct) is recursive. A field's own explicit tag wins;
// otherwise it inherits the tag of the nearest tagged ancestor. That lets a
// single safe field sit inside an otherwise restart-required parent, which
// is how webui.metrics_token can hot-reload even though the rest of WebUI
// needs a restart.
//
// Each Change carries the YAML path from the root (e.g. "thresholds" for a
// top-level struct, "webui.metrics_token" for a nested leaf). When nothing
// on the path is tagged the Change's Tag is "" and callers should treat it
// as TagRestart.
//
// Granularity: when every change inside a struct ends up with that struct's
// effective tag, the subtree is reported once at the struct's path, so "I
// changed three thresholds" is a single "thresholds" Change. When a subtree
// contains a differently tagged change, the nested changes are reported
// individually instead.
func Diff(oldCfg, newCfg *Config) []Change {
	if oldCfg != nil && newCfg != nil {
		before := reflect.ValueOf(*oldCfg)
		after := reflect.ValueOf(*newCfg)
		return diffStruct(before, after, "", "")
	}
	return nil
}
// diffStruct compares two reflect.Values of the same struct type field by
// field and returns a Change for each exported field that differs.
// parentPath is the dotted YAML path leading to this struct (empty at the
// root); parentTag is the hotreload tag inherited from the nearest tagged
// ancestor. The result is nil when nothing differs.
func diffStruct(oldV, newV reflect.Value, parentPath, parentTag string) []Change {
	var out []Change
	structType := oldV.Type()
	atRoot := parentPath == ""

	for idx, total := 0, structType.NumField(); idx < total; idx++ {
		sf := structType.Field(idx)

		// Unexported fields are not inspected. At the root, ConfigFile /
		// ConfigDir / Integrity are daemon-managed process metadata, not
		// operator policy fields.
		if !sf.IsExported() || (atRoot && isReloadManifestIgnoredRoot(sf)) {
			continue
		}

		before, after := oldV.Field(idx), newV.Field(idx)
		if reflect.DeepEqual(before.Interface(), after.Interface()) {
			continue
		}

		// An explicit, recognised tag on the field overrides whatever the
		// parent path handed down.
		effectiveTag := parentTag
		if own := sf.Tag.Get("hotreload"); own == TagSafe || own == TagRestart {
			effectiveTag = own
		}

		fieldPath := yamlFieldName(sf)
		if !atRoot {
			fieldPath = parentPath + "." + fieldPath
		}

		// Anything that is not a plain struct (scalars, pointers, slices,
		// maps) is a leaf. A pointer-to-struct stays a leaf because
		// DeepEqual already established that its target changed; walking it
		// again would only add duplicate noise.
		if sf.Type.Kind() != reflect.Struct {
			out = append(out, Change{Field: fieldPath, Tag: effectiveTag})
			continue
		}

		// Plain structs are walked so nested tag overrides can surface on
		// their own. A subtree whose changes all share this field's tag is
		// folded into one Change at this level; operators rarely need
		// "thresholds.mail_queue_warn" granularity.
		sub := diffStruct(before, after, fieldPath, effectiveTag)
		if folded, uniform := collapseIfUniform(sub, fieldPath, effectiveTag); uniform {
			out = append(out, folded)
			continue
		}
		out = append(out, sub...)
	}
	return out
}
// collapseIfUniform decides whether a subtree's changes can be reported as
// one Change at path. It returns (Change{Field: path, Tag: parentTag}, true)
// when nested is non-empty and every entry carries parentTag. An empty
// nested slice, or any entry with a different tag, yields (Change{}, false)
// so the caller keeps the granular changes.
func collapseIfUniform(nested []Change, path, parentTag string) (Change, bool) {
	uniform := len(nested) > 0
	for i := 0; uniform && i < len(nested); i++ {
		uniform = nested[i].Tag == parentTag
	}
	if !uniform {
		return Change{}, false
	}
	return Change{Field: path, Tag: parentTag}, true
}
// RestartRequired reports whether applying changes needs a daemon restart:
// true as soon as one Change carries any tag other than TagSafe (TagRestart,
// or the empty tag, which is treated as restart). An empty or nil slice
// yields false.
func RestartRequired(changes []Change) bool {
	for i := range changes {
		if changes[i].Tag == TagSafe {
			continue
		}
		return true
	}
	return false
}
// yamlFieldName returns the name portion of f's yaml struct tag, i.e. the
// text before the first comma, with options such as `,omitempty` or
// `,inline` dropped. A missing tag or a bare "-" tag yields the Go field
// name instead. A tag that starts with a comma yields the empty string.
func yamlFieldName(f reflect.StructField) string {
	raw := f.Tag.Get("yaml")
	switch raw {
	case "", "-":
		return f.Name
	}
	// Advance to the first comma, or to the end when there are no options.
	end := 0
	for end < len(raw) && raw[end] != ',' {
		end++
	}
	return raw[:end]
}
package config
import "gopkg.in/yaml.v3"
// CollisionFn is invoked when DeepMergeTracked detects a scalar in the
// overlay overwriting a different scalar in the base. Identical-value
// rewrites are not reported because the operator cannot act on them.
// keyPath uses dotted YAML notation rooted at the document
// ("mail_logs.source"); top-level keys have no parent.
// A nil CollisionFn disables reporting.
type CollisionFn func(keyPath, oldVal, newVal string)
// DeepMerge folds overlay into base in place and returns base. It is
// DeepMergeTracked without a collision callback. Both inputs must be
// DocumentNodes. Merge rules:
//   - mapping onto mapping: recurse key by key
//   - sequence onto sequence: append (base entries, then overlay entries);
//     all-scalar lists additionally drop duplicate entries
//   - anything else: the overlay node replaces the base node
//
// AliasNodes are treated as opaque scalars: an overlay alias replaces the
// base node; an alias inside base/overlay is not resolved before merging.
func DeepMerge(base, overlay *yaml.Node) *yaml.Node {
	var noReport CollisionFn
	return DeepMergeTracked(base, overlay, noReport)
}
// DeepMergeTracked merges overlay into base in place, like DeepMerge, and
// additionally calls onCollision once for every scalar whose value an
// overlay scalar changes. Pass nil to get plain DeepMerge behaviour.
//
// base is returned in every case. The merge is skipped when either node is
// nil, when either is not a document, or when overlay has no content. When
// base has no content it simply adopts overlay's content.
func DeepMergeTracked(base, overlay *yaml.Node, onCollision CollisionFn) *yaml.Node {
	if base == nil || overlay == nil {
		return base
	}

	// An empty yaml.Unmarshal result has Kind==0; treat it as an empty document.
	for _, doc := range []*yaml.Node{base, overlay} {
		if doc.Kind == 0 {
			doc.Kind = yaml.DocumentNode
		}
	}

	switch {
	case base.Kind != yaml.DocumentNode, overlay.Kind != yaml.DocumentNode:
		return base
	case len(overlay.Content) == 0:
		return base
	case len(base.Content) == 0:
		base.Content = overlay.Content
		return base
	}

	mergeNodesAt(base.Content[0], overlay.Content[0], "", onCollision)
	return base
}
// mergeNodesAt merges overlay node o into base node b, which sits at the
// dotted key path given by path. Two mappings merge key by key, two
// sequences are concatenated and then scalar-deduplicated, and any other
// pairing overwrites b with a copy of o. onCollision, when non-nil, is told
// about scalar-over-scalar replacements that change the value.
func mergeNodesAt(b, o *yaml.Node, path string, onCollision CollisionFn) {
	if b.Kind == o.Kind {
		switch b.Kind {
		case yaml.MappingNode:
			mergeMapAt(b, o, path, onCollision)
			return
		case yaml.SequenceNode:
			b.Content = dedupScalarSequence(append(b.Content, o.Content...))
			return
		}
	}

	bothScalars := b.Kind == yaml.ScalarNode && o.Kind == yaml.ScalarNode
	if bothScalars && onCollision != nil && b.Value != o.Value {
		onCollision(path, b.Value, o.Value)
	}
	*b = *o
}
// dedupScalarSequence drops repeated scalar entries from content, keeping
// the first occurrence of each (tag, value) pair and the original order. It
// filters in place, reusing content's backing array.
//
// It acts only when every element is a scalar. A list containing any
// non-scalar node (e.g. webui.tokens, a list of maps) is returned untouched,
// because position and identity matter there. Without this step,
// idempotent-by-content security lists (infra_ips, c2_blocklist,
// trusted_countries, disabled_checks) merged from a fragment that repeats a
// main-config entry would carry duplicates into validation and enforcement
// on every load.
func dedupScalarSequence(content []*yaml.Node) []*yaml.Node {
	for i := range content {
		if content[i].Kind != yaml.ScalarNode {
			return content
		}
	}

	seen := make(map[string]struct{}, len(content))
	kept := 0
	for _, node := range content {
		id := node.Tag + "\x00" + node.Value
		if _, dup := seen[id]; !dup {
			seen[id] = struct{}{}
			content[kept] = node
			kept++
		}
	}
	return content[:kept]
}
// mergeMapAt merges the key/value pairs of overlay mapping o into base
// mapping b. A key already present in b has its value merged recursively; a
// new key is appended to b together with its value. parent is the dotted
// path of the mapping itself and is used to build each child's path.
func mergeMapAt(b, o *yaml.Node, parent string, onCollision CollisionFn) {
	prefix := ""
	if parent != "" {
		prefix = parent + "."
	}

	// Mapping content alternates key, value; v always indexes the value.
	for v := 1; v < len(o.Content); v += 2 {
		keyNode, valNode := o.Content[v-1], o.Content[v]

		at := findKey(b, keyNode.Value)
		if at < 0 {
			b.Content = append(b.Content, keyNode, valNode)
			continue
		}
		mergeNodesAt(b.Content[at+1], valNode, prefix+keyNode.Value, onCollision)
	}
}
// findKey returns the index within m.Content of the key node whose Value
// equals key, or -1 when no complete key/value pair has that key. The value
// node, when found, sits at the returned index plus one.
func findKey(m *yaml.Node, key string) int {
	pairs := len(m.Content) / 2
	for p := 0; p < pairs; p++ {
		if at := 2 * p; m.Content[at].Value == key {
			return at
		}
	}
	return -1
}
package config
// redactedValue is the placeholder written in place of a non-empty secret
// by Redact and redactConfigScalarForLog.
const redactedValue = "***REDACTED***"
// sensitiveScalarPaths is the set of dotted YAML key paths whose scalar
// values redactConfigScalarForLog masks before they are logged. Lookups are
// by exact key path.
var sensitiveScalarPaths = map[string]struct{}{
"alerts.webhook.hmac_secret": {},
"auto_response.verdict_callback.hmac_secret": {},
"challenge.captcha_fallback.secret_key": {},
"challenge.secret": {},
"challenge.verified_session.admin_secret": {},
"geoip.license_key": {},
"integrity.binary_hash": {},
"integrity.config_hash": {},
"integrity.confd_hash": {},
"reputation.abuseipdb_key": {},
"reputation.rspamd.token": {},
"reputation.upstream.token": {},
"sentry.dsn": {},
"webui.auth_token": {},
"webui.metrics_token": {},
}
// Redact returns a copy of cfg in which every non-empty secret is replaced
// by the redaction marker. Empty secrets stay empty, and cfg itself is left
// unmodified: the struct is copied by value, and the token list and the
// Firewall pointer, which are written to or handed out separately, get
// copies of their own.
func Redact(cfg *Config) *Config {
	// Start from a by-value copy of the whole struct.
	out := *cfg

	// Give the copy its own Firewall value instead of sharing the pointer.
	if cfg.Firewall != nil {
		firewall := *cfg.Firewall
		out.Firewall = &firewall
	}

	// --- Web UI credentials ---
	if out.WebUI.AuthToken != "" {
		out.WebUI.AuthToken = redactedValue
	}
	if out.WebUI.MetricsToken != "" {
		out.WebUI.MetricsToken = redactedValue
	}
	if n := len(out.WebUI.Tokens); n > 0 {
		// Copy the slice before masking so the caller's tokens survive.
		masked := make([]WebUIToken, n)
		copy(masked, out.WebUI.Tokens)
		for i := range masked {
			if masked[i].Token != "" {
				masked[i].Token = redactedValue
			}
		}
		out.WebUI.Tokens = masked
	}

	// --- Outbound signing secrets ---
	if out.Alerts.Webhook.HMACSecret != "" {
		out.Alerts.Webhook.HMACSecret = redactedValue
	}
	if out.AutoResponse.VerdictCallback.HMACSecret != "" {
		out.AutoResponse.VerdictCallback.HMACSecret = redactedValue
	}

	// --- Third-party service keys ---
	if out.GeoIP.LicenseKey != "" {
		out.GeoIP.LicenseKey = redactedValue
	}
	if out.Reputation.AbuseIPDBKey != "" {
		out.Reputation.AbuseIPDBKey = redactedValue
	}
	if out.Reputation.Rspamd.Token != "" {
		out.Reputation.Rspamd.Token = redactedValue
	}
	if out.Reputation.Upstream.Token != "" {
		out.Reputation.Upstream.Token = redactedValue
	}
	if out.Sentry.DSN != "" {
		out.Sentry.DSN = redactedValue
	}

	// --- Challenge secrets ---
	if out.Challenge.Secret != "" {
		out.Challenge.Secret = redactedValue
	}
	if out.Challenge.CaptchaFallback.SecretKey != "" {
		out.Challenge.CaptchaFallback.SecretKey = redactedValue
	}
	if out.Challenge.VerifiedSession.AdminSecret != "" {
		out.Challenge.VerifiedSession.AdminSecret = redactedValue
	}

	// --- Integrity hashes ---
	if out.Integrity.BinaryHash != "" {
		out.Integrity.BinaryHash = redactedValue
	}
	if out.Integrity.ConfigHash != "" {
		out.Integrity.ConfigHash = redactedValue
	}
	if out.Integrity.ConfdHash != "" {
		out.Integrity.ConfdHash = redactedValue
	}

	return &out
}
// redactConfigScalarForLog returns value in a form that is safe to log:
// the redaction marker when keyPath is listed in sensitiveScalarPaths and
// value is non-empty, and value unchanged otherwise.
func redactConfigScalarForLog(keyPath, value string) string {
	_, sensitive := sensitiveScalarPaths[keyPath]
	if sensitive && value != "" {
		return redactedValue
	}
	return value
}
package config
import (
"reflect"
"strings"
"time"
)
// Schema returns a JSON Schema (draft-07-style, partial) describing the
// Config struct via reflection over yaml: tags. Used by phpanel's config
// editor for client-side validation. Not a complete spec implementation -
// covers YAML field names, scalar/container types, and nested objects.
//
// IMPORTANT: This schema is structural only. Imperative validation rules
// enforced by Validate() (e.g., mail_logs.source must be auto/file/journal,
// webui.tokens[].scope must be admin or read) are NOT encoded here.
// Phpanel must still call `csm validate` for the authoritative check.
func Schema() map[string]interface{} {
return reflectStruct(reflect.TypeOf(Config{}))
}
// reflectStruct returns the object schema for struct type t (one level of
// pointer is dereferenced first). Each field with a yaml tag that is neither
// "-" nor nameless becomes a property keyed by its YAML name. A non-struct
// type yields a bare {"type": "object"} with no properties.
func reflectStruct(t reflect.Type) map[string]interface{} {
	if t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	schema := map[string]interface{}{"type": "object"}
	if t.Kind() != reflect.Struct {
		return schema
	}

	properties := make(map[string]interface{})
	for idx, total := 0, t.NumField(); idx < total; idx++ {
		sf := t.Field(idx)
		raw := sf.Tag.Get("yaml")
		if raw == "" || raw == "-" {
			continue
		}
		if key, _ := splitYAMLTag(raw); key != "" {
			properties[key] = reflectField(sf.Type)
		}
	}
	schema["properties"] = properties
	return schema
}
// reflectField returns the schema fragment for a value of type t, after
// stripping every pointer level. time.Duration maps to a string with format
// "duration"; strings, booleans, integers and floats map to the matching
// JSON types; slices and arrays become arrays of their element schema; maps
// become objects whose additionalProperties is the value schema; structs
// recurse through reflectStruct. Any other kind yields an empty schema.
func reflectField(t reflect.Type) map[string]interface{} {
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}

	typed := func(jsonType string) map[string]interface{} {
		return map[string]interface{}{"type": jsonType}
	}

	// Checked before the kind switch: time.Duration has an integer kind.
	if t == reflect.TypeOf(time.Duration(0)) {
		out := typed("string")
		out["format"] = "duration"
		return out
	}

	switch t.Kind() {
	case reflect.Struct:
		return reflectStruct(t)
	case reflect.Map:
		out := typed("object")
		out["additionalProperties"] = reflectField(t.Elem())
		return out
	case reflect.Slice, reflect.Array:
		out := typed("array")
		out["items"] = reflectField(t.Elem())
		return out
	case reflect.String:
		return typed("string")
	case reflect.Bool:
		return typed("boolean")
	case reflect.Float32, reflect.Float64:
		return typed("number")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return typed("integer")
	}
	return map[string]interface{}{}
}
// splitYAMLTag splits a yaml struct tag value on commas and returns the
// leading name together with the remaining options (an empty slice when the
// tag carries none).
func splitYAMLTag(tag string) (string, []string) {
	fields := strings.Split(tag, ",")
	name, options := fields[0], fields[1:]
	return name, options
}
package config
import (
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"path/filepath"
"strings"
"time"
)
// ValidationResult represents a single validation finding.
// Its String method renders it as "[LEVEL] field: message".
type ValidationResult struct {
Level string // "error", "warn", "ok"
Field string // dotted path matching YAML keys
Message string
}
// String renders the finding as "[LEVEL] field: message", with the level
// upper-cased, satisfying fmt.Stringer.
func (v ValidationResult) String() string {
	level := strings.ToUpper(v.Level)
	return fmt.Sprintf("[%s] %s: %s", level, v.Field, v.Message)
}
// Validate checks the config for errors, warnings, and emits OK for valid sections.
func Validate(cfg *Config) []ValidationResult {
var results []ValidationResult
// --- Hostname ---
if cfg.Hostname == "" || cfg.Hostname == "SET_HOSTNAME_HERE" {
results = append(results, ValidationResult{"error", "hostname", "hostname is not set"})
} else {
results = append(results, ValidationResult{"ok", "hostname", cfg.Hostname})
}
// --- Alerts ---
if !cfg.Alerts.Email.Enabled && !cfg.Alerts.Webhook.Enabled {
results = append(results, ValidationResult{"error", "alerts", "no alert method enabled (enable email or webhook)"})
}
// --- Email alerts ---
if cfg.Alerts.Email.Enabled {
if len(cfg.Alerts.Email.To) == 0 {
results = append(results, ValidationResult{"error", "alerts.email.to", "email alerts enabled but no recipients configured"})
} else {
valid := true
for _, to := range cfg.Alerts.Email.To {
if to == "SET_EMAIL_HERE" || !strings.Contains(to, "@") {
results = append(results, ValidationResult{"error", "alerts.email.to", fmt.Sprintf("invalid email recipient: %s", to)})
valid = false
}
}
if valid {
results = append(results, ValidationResult{"ok", "alerts.email.to", strings.Join(cfg.Alerts.Email.To, ", ")})
}
}
if cfg.Alerts.Email.From == "" {
results = append(results, ValidationResult{"error", "alerts.email.from", "email alerts enabled but no from address configured"})
}
if cfg.Alerts.Email.SMTP == "" {
results = append(results, ValidationResult{"error", "alerts.email.smtp", "email alerts enabled but no SMTP server configured"})
} else {
results = append(results, ValidationResult{"ok", "alerts.email.smtp", cfg.Alerts.Email.SMTP})
}
}
// --- Webhook ---
if cfg.Alerts.Webhook.Enabled {
if cfg.Alerts.Webhook.URL == "" {
results = append(results, ValidationResult{"error", "alerts.webhook.url", "webhook alerts enabled but no URL configured"})
} else {
results = append(results, ValidationResult{"ok", "alerts.webhook.url", cfg.Alerts.Webhook.URL})
}
switch cfg.Alerts.Webhook.Type {
case "", "slack", "discord", "generic", "phpanel":
default:
results = append(results, ValidationResult{"error", "alerts.webhook.type", fmt.Sprintf("unknown webhook type %q", cfg.Alerts.Webhook.Type)})
}
if cfg.Alerts.Webhook.Type == "phpanel" {
secret := cfg.Alerts.Webhook.HMACSecret
if cfg.Alerts.Webhook.HMACSecretEnv != "" {
if v := os.Getenv(cfg.Alerts.Webhook.HMACSecretEnv); v != "" {
secret = v
}
}
if secret == "" {
field := "alerts.webhook.hmac_secret"
if cfg.Alerts.Webhook.HMACSecretEnv != "" {
field = "alerts.webhook.hmac_secret_env"
}
results = append(results, ValidationResult{"error", field, "phpanel webhook enabled but no HMAC secret configured"})
}
}
}
// --- Heartbeat ---
if cfg.Alerts.Heartbeat.Enabled {
if cfg.Alerts.Heartbeat.URL == "" {
results = append(results, ValidationResult{"error", "alerts.heartbeat.url", "heartbeat enabled but no URL configured"})
} else {
results = append(results, ValidationResult{"ok", "alerts.heartbeat.url", cfg.Alerts.Heartbeat.URL})
}
}
// --- MaxPerHour ---
if cfg.Alerts.MaxPerHour <= 0 {
results = append(results, ValidationResult{"error", "alerts.max_per_hour", "max_per_hour must be > 0"})
}
// --- WebUI ---
if cfg.WebUI.Enabled {
if err := validateWebUITokens(cfg); err != nil {
results = append(results, ValidationResult{"error", "webui.tokens", err.Error()})
}
tokenCount, adminCount := webUITokenCounts(cfg)
if tokenCount == 0 {
results = append(results, ValidationResult{"error", "webui.tokens", "webui enabled but no auth token configured"})
} else {
results = append(results, ValidationResult{"ok", "webui", fmt.Sprintf("listening on %s", cfg.WebUI.Listen)})
if adminCount == 0 {
results = append(results, ValidationResult{"warn", "webui.tokens", "no admin-scope token configured; browser login and admin API calls are disabled"})
}
}
}
// --- Trusted countries ---
for _, cc := range cfg.Suppressions.TrustedCountries {
if len(cc) != 2 {
results = append(results, ValidationResult{"error", "suppressions.trusted_countries", fmt.Sprintf("invalid country code: %q (expected 2-letter ISO code)", cc)})
}
}
// --- Duration fields (only check if non-empty) ---
if cfg.AutoResponse.BlockExpiry != "" {
if _, err := time.ParseDuration(cfg.AutoResponse.BlockExpiry); err != nil {
results = append(results, ValidationResult{"error", "auto_response.block_expiry", fmt.Sprintf("unparseable duration: %s", cfg.AutoResponse.BlockExpiry)})
}
}
if cfg.Signatures.UpdateInterval != "" {
if _, err := time.ParseDuration(cfg.Signatures.UpdateInterval); err != nil {
results = append(results, ValidationResult{"error", "signatures.update_interval", fmt.Sprintf("unparseable duration: %s", cfg.Signatures.UpdateInterval)})
}
}
if cfg.Signatures.YaraForge.UpdateInterval != "" {
if _, err := time.ParseDuration(cfg.Signatures.YaraForge.UpdateInterval); err != nil {
results = append(results, ValidationResult{"error", "signatures.yara_forge.update_interval", fmt.Sprintf("unparseable duration: %s", cfg.Signatures.YaraForge.UpdateInterval)})
}
}
if cfg.EmailAV.ScanTimeout != "" {
if _, err := time.ParseDuration(cfg.EmailAV.ScanTimeout); err != nil {
results = append(results, ValidationResult{"error", "email_av.scan_timeout", fmt.Sprintf("unparseable duration: %s", cfg.EmailAV.ScanTimeout)})
}
}
if cfg.GeoIP.UpdateInterval != "" {
if _, err := time.ParseDuration(cfg.GeoIP.UpdateInterval); err != nil {
results = append(results, ValidationResult{"error", "geoip.update_interval", fmt.Sprintf("unparseable duration: %s", cfg.GeoIP.UpdateInterval)})
}
}
if cfg.AutoResponse.PermBlockInterval != "" {
if _, err := time.ParseDuration(cfg.AutoResponse.PermBlockInterval); err != nil {
results = append(results, ValidationResult{"error", "auto_response.permblock_interval", fmt.Sprintf("unparseable duration: %s", cfg.AutoResponse.PermBlockInterval)})
}
}
// --- Retention ---
if cfg.Retention.Enabled {
if cfg.Retention.SweepInterval != "" {
if _, err := time.ParseDuration(cfg.Retention.SweepInterval); err != nil {
results = append(results, ValidationResult{"error", "retention.sweep_interval", fmt.Sprintf("unparseable duration: %s", cfg.Retention.SweepInterval)})
}
}
if cfg.Retention.FindingsDays < 0 {
results = append(results, ValidationResult{"error", "retention.findings_days", fmt.Sprintf("findings_days must be >= 0 (0 disables the sweep), got %d", cfg.Retention.FindingsDays)})
}
if cfg.Retention.HistoryDays < 0 {
results = append(results, ValidationResult{"error", "retention.history_days", fmt.Sprintf("history_days must be >= 0, got %d", cfg.Retention.HistoryDays)})
}
if cfg.Retention.ReputationDays < 0 {
results = append(results, ValidationResult{"error", "retention.reputation_days", fmt.Sprintf("reputation_days must be >= 0, got %d", cfg.Retention.ReputationDays)})
}
if cfg.Retention.CompactMinSizeMB < 0 {
results = append(results, ValidationResult{"error", "retention.compact_min_size_mb", fmt.Sprintf("compact_min_size_mb must be >= 0, got %d", cfg.Retention.CompactMinSizeMB)})
}
if cfg.Retention.CompactFillRatio <= 0 || cfg.Retention.CompactFillRatio > 1 {
results = append(results, ValidationResult{"error", "retention.compact_fill_ratio", fmt.Sprintf("compact_fill_ratio must be in (0, 1], got %v", cfg.Retention.CompactFillRatio)})
}
}
// --- Firewall ---
if cfg.Firewall != nil && cfg.Firewall.Enabled {
if cfg.Firewall.ConnRateLimit <= 0 {
results = append(results, ValidationResult{"error", "firewall.conn_rate_limit", "conn_rate_limit must be > 0 when firewall enabled"})
}
if cfg.Firewall.ConnLimit < 0 {
results = append(results, ValidationResult{"error", "firewall.conn_limit", "conn_limit must be >= 0 when firewall enabled (0 = disabled)"})
}
if cfg.Firewall.ConnRateLimit > 0 && cfg.Firewall.ConnLimit >= 0 {
results = append(results, ValidationResult{"ok", "firewall", fmt.Sprintf("enabled, conn_rate_limit=%d, conn_limit=%d", cfg.Firewall.ConnRateLimit, cfg.Firewall.ConnLimit)})
}
}
// --- Challenge ---
if cfg.Challenge.Difficulty < 0 || cfg.Challenge.Difficulty > 5 {
results = append(results, ValidationResult{"error", "challenge.difficulty", fmt.Sprintf("difficulty must be 0-5, got %d", cfg.Challenge.Difficulty)})
}
if cfg.Challenge.ListenPort < 0 || cfg.Challenge.ListenPort > 65535 {
results = append(results, ValidationResult{"error", "challenge.listen_port", fmt.Sprintf("listen_port must be 0-65535, got %d", cfg.Challenge.ListenPort)})
} else if cfg.Challenge.Enabled && cfg.Challenge.ListenPort == 0 {
results = append(results, ValidationResult{"error", "challenge.listen_port", fmt.Sprintf("listen_port must be 1-65535 when challenge.enabled, got %d", cfg.Challenge.ListenPort)})
}
// --- EmailAV ---
if cfg.EmailAV.Enabled && cfg.EmailAV.MaxAttachmentSize <= 0 {
results = append(results, ValidationResult{"error", "email_av.max_attachment_size", "max_attachment_size must be > 0 when email_av enabled"})
}
if cfg.EmailAV.FailMode != "" && cfg.EmailAV.FailMode != "open" && cfg.EmailAV.FailMode != "tempfail" {
results = append(results, ValidationResult{"error", "email_av.fail_mode",
fmt.Sprintf("invalid fail_mode %q: must be \"open\" or \"tempfail\"", cfg.EmailAV.FailMode)})
}
if cfg.Signatures.UpdateURL != "" && cfg.Signatures.SigningKey == "" {
results = append(results, ValidationResult{"error", "signatures.signing_key",
"signing_key is required when signatures.update_url is configured"})
}
if cfg.Signatures.YaraForge.Enabled && cfg.Signatures.SigningKey == "" {
results = append(results, ValidationResult{"error", "signatures.signing_key",
"signing_key is required when signatures.yara_forge.enabled is true"})
}
if cfg.Signatures.YaraForge.Enabled && cfg.Signatures.YaraForge.DownloadURL == "" {
results = append(results, ValidationResult{"error", "signatures.yara_forge.download_url",
"download_url is required because upstream YARA Forge releases do not publish CSM detached signatures"})
}
if cfg.Signatures.YaraForge.DownloadURL != "" {
if err := validateSignatureURL(cfg.Signatures.YaraForge.DownloadURL, true); err != nil {
results = append(results, ValidationResult{"error", "signatures.yara_forge.download_url", err.Error()})
}
}
if cfg.Signatures.UpdateURL != "" {
if err := validateSignatureURL(cfg.Signatures.UpdateURL, false); err != nil {
results = append(results, ValidationResult{"error", "signatures.update_url", err.Error()})
}
}
if err := validateDirectSMTPEgress(cfg); err != nil {
results = append(results, ValidationResult{"error", "detection.direct_smtp_egress", err.Error()})
}
if err := validateBPFEnforcement(cfg); err != nil {
results = append(results, ValidationResult{"error", "bpf_enforcement", err.Error()})
}
// --- EmailProtection ---
if cfg.EmailProtection.RateWarnThreshold > 0 && cfg.EmailProtection.RateWarnThreshold < 10 {
results = append(results, ValidationResult{"warn", "email_protection.rate_warn_threshold", "rate_warn_threshold < 10 may cause excessive alerts"})
}
if cfg.EmailProtection.RateCritThreshold > 0 && cfg.EmailProtection.RateCritThreshold <= cfg.EmailProtection.RateWarnThreshold {
results = append(results, ValidationResult{"error", "email_protection.rate_crit_threshold", "rate_crit_threshold must be > rate_warn_threshold"})
}
if cfg.EmailProtection.RateWindowMin > 0 && (cfg.EmailProtection.RateWindowMin < 5 || cfg.EmailProtection.RateWindowMin > 60) {
results = append(results, ValidationResult{"error", "email_protection.rate_window_min", "rate_window_min must be between 5 and 60"})
}
if cfg.EmailProtection.PasswordCheckIntervalMin > 0 && cfg.EmailProtection.PasswordCheckIntervalMin < 60 {
results = append(results, ValidationResult{"warn", "email_protection.password_check_interval_min", "password_check_interval_min < 60 may cause high CPU from doveadm"})
}
// --- EmailProtection.PHPRelay bounds ---
// Bounds checks fire only when the operator has supplied a value
// (zero means "use the applyDefaults value" or, for AccountVolumePerHour,
// "auto-derive from cPanel maxemailsperhour"). PoliciesDir is NOT
// validated here -- filesystem state probes belong in ValidateDeep.
pr := cfg.EmailProtection.PHPRelay
if pr.RateWindowMin != 0 && (pr.RateWindowMin < 1 || pr.RateWindowMin > 60) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.rate_window_min", fmt.Sprintf("rate_window_min must be between 1 and 60, got %d", pr.RateWindowMin)})
}
if pr.HeaderScoreVolumeMin != 0 && (pr.HeaderScoreVolumeMin < 2 || pr.HeaderScoreVolumeMin > 100) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.header_score_volume_min", fmt.Sprintf("header_score_volume_min must be between 2 and 100, got %d", pr.HeaderScoreVolumeMin)})
}
if pr.AbsoluteVolumePerHour != 0 && (pr.AbsoluteVolumePerHour < 10 || pr.AbsoluteVolumePerHour > 1000) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.absolute_volume_per_hour", fmt.Sprintf("absolute_volume_per_hour must be between 10 and 1000, got %d", pr.AbsoluteVolumePerHour)})
}
// AccountVolumePerHour: 0 is the documented "auto-derive" sentinel;
// only reject explicitly out-of-range positive values.
if pr.AccountVolumePerHour < 0 || pr.AccountVolumePerHour > 5000 {
results = append(results, ValidationResult{"error", "email_protection.php_relay.account_volume_per_hour", fmt.Sprintf("account_volume_per_hour must be between 0 (auto-derive) and 5000, got %d", pr.AccountVolumePerHour)})
}
if pr.ReputationFailuresPer24h != 0 && (pr.ReputationFailuresPer24h < 1 || pr.ReputationFailuresPer24h > 50) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.reputation_failures_per_24h", fmt.Sprintf("reputation_failures_per_24h must be between 1 and 50, got %d", pr.ReputationFailuresPer24h)})
}
if pr.FanoutDistinctScripts != 0 && (pr.FanoutDistinctScripts < 2 || pr.FanoutDistinctScripts > 20) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.fanout_distinct_scripts", fmt.Sprintf("fanout_distinct_scripts must be between 2 and 20, got %d", pr.FanoutDistinctScripts)})
}
if pr.FanoutWindowMin != 0 && (pr.FanoutWindowMin < 1 || pr.FanoutWindowMin > 60) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.fanout_window_min", fmt.Sprintf("fanout_window_min must be between 1 and 60, got %d", pr.FanoutWindowMin)})
}
if pr.BaselineSigma != 0 && (pr.BaselineSigma < 2.0 || pr.BaselineSigma > 6.0) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.baseline_sigma", fmt.Sprintf("baseline_sigma must be between 2.0 and 6.0, got %v", pr.BaselineSigma)})
}
if pr.BaselineObservationDays != 0 && (pr.BaselineObservationDays < 1 || pr.BaselineObservationDays > 30) {
results = append(results, ValidationResult{"error", "email_protection.php_relay.baseline_observation_days", fmt.Sprintf("baseline_observation_days must be between 1 and 30, got %d", pr.BaselineObservationDays)})
}
// --- AutoResponse.PHPRelay bounds ---
if cfg.AutoResponse.PHPRelay.MaxActionsPerMinute != 0 && (cfg.AutoResponse.PHPRelay.MaxActionsPerMinute < 1 || cfg.AutoResponse.PHPRelay.MaxActionsPerMinute > 600) {
results = append(results, ValidationResult{"error", "auto_response.php_relay.max_actions_per_minute", fmt.Sprintf("max_actions_per_minute must be between 1 and 600, got %d", cfg.AutoResponse.PHPRelay.MaxActionsPerMinute)})
}
if cfg.AutoResponse.MaxBlocksPerHour < 0 {
results = append(results, ValidationResult{"error", "auto_response.max_blocks_per_hour", fmt.Sprintf("max_blocks_per_hour must be >= 0 (0 uses default %d), got %d", DefaultMaxBlocksPerHour, cfg.AutoResponse.MaxBlocksPerHour)})
}
// --- SMTP brute-force thresholds ---
t := cfg.Thresholds
if t.DomlogMaxFiles != 0 && (t.DomlogMaxFiles < 1 || t.DomlogMaxFiles > 100000) {
results = append(results, ValidationResult{"error", "thresholds.domlog_max_files", "domlog_max_files must be between 1 and 100000"})
}
if t.AccountScanMaxFiles != 0 && (t.AccountScanMaxFiles < 1 || t.AccountScanMaxFiles > 100000) {
results = append(results, ValidationResult{"error", "thresholds.account_scan_max_files", "account_scan_max_files must be between 1 and 100000"})
}
if t.CrontabBase64BlobMaxBytes != 0 {
if t.CrontabBase64BlobMaxBytes < 1024 || t.CrontabBase64BlobMaxBytes > 1048576 {
results = append(results, ValidationResult{"error", "thresholds.crontab_base64_blob_max_bytes", "crontab_base64_blob_max_bytes must be between 1024 and 1048576"})
} else if t.CrontabBase64BlobMaxBytes%4 != 0 {
results = append(results, ValidationResult{"error", "thresholds.crontab_base64_blob_max_bytes", "crontab_base64_blob_max_bytes must be a multiple of 4 (standard base64 alignment)"})
}
}
if t.DomlogTailLines != 0 && (t.DomlogTailLines < 10 || t.DomlogTailLines > 100000) {
results = append(results, ValidationResult{"error", "thresholds.domlog_tail_lines", "domlog_tail_lines must be between 10 and 100000"})
}
if t.DomlogMaxAgeMin != 0 && (t.DomlogMaxAgeMin < 1 || t.DomlogMaxAgeMin > 1440) {
results = append(results, ValidationResult{"error", "thresholds.domlog_max_age_min", "domlog_max_age_min must be between 1 and 1440"})
}
if t.MailLogTailLines != 0 && (t.MailLogTailLines < 10 || t.MailLogTailLines > 100000) {
results = append(results, ValidationResult{"error", "thresholds.mail_log_tail_lines", "mail_log_tail_lines must be between 10 and 100000"})
}
if t.SyslogMessagesTailLines != 0 && (t.SyslogMessagesTailLines < 10 || t.SyslogMessagesTailLines > 100000) {
results = append(results, ValidationResult{"error", "thresholds.syslog_messages_tail_lines", "syslog_messages_tail_lines must be between 10 and 100000"})
}
if t.CredStuffingDistinctAccounts != 0 && (t.CredStuffingDistinctAccounts < 2 || t.CredStuffingDistinctAccounts > 200) {
results = append(results, ValidationResult{"error", "thresholds.cred_stuffing_distinct_accounts", "cred_stuffing_distinct_accounts must be between 2 and 200"})
}
if t.SMTPBruteForceThreshold != 0 && (t.SMTPBruteForceThreshold < 2 || t.SMTPBruteForceThreshold > 50) {
results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_threshold", "smtp_bruteforce_threshold must be between 2 and 50"})
}
if t.SMTPBruteForceWindowMin != 0 && (t.SMTPBruteForceWindowMin < 1 || t.SMTPBruteForceWindowMin > 60) {
results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_window_min", "smtp_bruteforce_window_min must be between 1 and 60"})
}
if t.SMTPBruteForceSuppressMin != 0 && (t.SMTPBruteForceSuppressMin < 1 || t.SMTPBruteForceSuppressMin > 1440) {
results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_suppress_min", "smtp_bruteforce_suppress_min must be between 1 and 1440"})
}
if t.SMTPBruteForceSubnetThresh != 0 && (t.SMTPBruteForceSubnetThresh < 2 || t.SMTPBruteForceSubnetThresh > 64) {
results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_subnet_threshold", "smtp_bruteforce_subnet_threshold must be between 2 and 64"})
}
if t.SMTPAccountSprayThreshold != 0 && (t.SMTPAccountSprayThreshold < 2 || t.SMTPAccountSprayThreshold > 200) {
results = append(results, ValidationResult{"error", "thresholds.smtp_account_spray_threshold", "smtp_account_spray_threshold must be between 2 and 200"})
}
if t.SMTPBruteForceMaxTracked != 0 && (t.SMTPBruteForceMaxTracked < 1000 || t.SMTPBruteForceMaxTracked > 200000) {
results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_max_tracked", "smtp_bruteforce_max_tracked must be between 1000 and 200000"})
}
if t.SMTPProbeThreshold != 0 && (t.SMTPProbeThreshold < 10 || t.SMTPProbeThreshold > 10000) {
results = append(results, ValidationResult{"error", "thresholds.smtp_probe_threshold", "smtp_probe_threshold must be between 10 and 10000"})
}
if t.SMTPProbeWindowMin != 0 && (t.SMTPProbeWindowMin < 1 || t.SMTPProbeWindowMin > 60) {
results = append(results, ValidationResult{"error", "thresholds.smtp_probe_window_min", "smtp_probe_window_min must be between 1 and 60"})
}
if t.SMTPProbeSuppressMin != 0 && (t.SMTPProbeSuppressMin < 1 || t.SMTPProbeSuppressMin > 1440) {
results = append(results, ValidationResult{"error", "thresholds.smtp_probe_suppress_min", "smtp_probe_suppress_min must be between 1 and 1440"})
}
if t.SMTPProbeMaxTracked != 0 && (t.SMTPProbeMaxTracked < 1000 || t.SMTPProbeMaxTracked > 200000) {
results = append(results, ValidationResult{"error", "thresholds.smtp_probe_max_tracked", "smtp_probe_max_tracked must be between 1000 and 200000"})
}
if t.MailBruteForceThreshold != 0 && (t.MailBruteForceThreshold < 2 || t.MailBruteForceThreshold > 50) {
results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_threshold", "mail_bruteforce_threshold must be between 2 and 50"})
}
if t.MailBruteForceWindowMin != 0 && (t.MailBruteForceWindowMin < 1 || t.MailBruteForceWindowMin > 60) {
results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_window_min", "mail_bruteforce_window_min must be between 1 and 60"})
}
if t.MailBruteForceSuppressMin != 0 && (t.MailBruteForceSuppressMin < 1 || t.MailBruteForceSuppressMin > 1440) {
results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_suppress_min", "mail_bruteforce_suppress_min must be between 1 and 1440"})
}
if t.MailBruteForceSubnetThresh != 0 && (t.MailBruteForceSubnetThresh < 2 || t.MailBruteForceSubnetThresh > 64) {
results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_subnet_threshold", "mail_bruteforce_subnet_threshold must be between 2 and 64"})
}
if t.MailAccountSprayThreshold != 0 && (t.MailAccountSprayThreshold < 2 || t.MailAccountSprayThreshold > 200) {
results = append(results, ValidationResult{"error", "thresholds.mail_account_spray_threshold", "mail_account_spray_threshold must be between 2 and 200"})
}
if t.MailBruteForceMaxTracked != 0 && (t.MailBruteForceMaxTracked < 1000 || t.MailBruteForceMaxTracked > 200000) {
results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_max_tracked", "mail_bruteforce_max_tracked must be between 1000 and 200000"})
}
if field, err := validateMailBruteAccountKeyField(cfg); err != nil {
results = append(results, ValidationResult{"error", field, err.Error()})
}
// --- Mail log source ---
if field, err := validateMailLogsField(cfg); err != nil {
results = append(results, ValidationResult{"error", field, err.Error()})
}
// --- Reputation.Rspamd ---
if cfg.Reputation.Rspamd.Enabled {
secret := cfg.Reputation.Rspamd.Token
if cfg.Reputation.Rspamd.TokenEnv != "" {
if v := os.Getenv(cfg.Reputation.Rspamd.TokenEnv); v != "" {
secret = v
}
}
if secret == "" {
results = append(results, ValidationResult{"warn", "reputation.rspamd.token", "rspamd enabled but no token configured (rspamd controller history may require auth)"})
}
}
// --- Reputation.Upstream ---
if cfg.Reputation.Upstream.Enabled {
secret := cfg.Reputation.Upstream.Token
if cfg.Reputation.Upstream.TokenEnv != "" {
if v := os.Getenv(cfg.Reputation.Upstream.TokenEnv); v != "" {
secret = v
}
}
if secret == "" {
results = append(results, ValidationResult{"warn", "reputation.upstream.token", "upstream enabled but no token configured (panel endpoint may require auth)"})
}
}
// --- AutoResponse.VerdictCallback ---
if field, err := validateVerdictCallbackField(cfg); err != nil {
results = append(results, ValidationResult{"error", field, err.Error()})
}
// --- Warnings ---
results = append(results, validateWarnings(cfg)...)
return results
}
// webUITokenCounts reports how many non-empty WebUI tokens are configured
// and how many of those have the "admin" scope.
//
// When the Tokens list is empty, a non-empty legacy AuthToken is counted as
// one token that is also an admin token. Entries in Tokens whose Token is
// the empty string are ignored.
func webUITokenCounts(cfg *Config) (tokens, admins int) {
	configured := cfg.WebUI.Tokens
	if len(configured) == 0 {
		// No token list: fall back to the single AuthToken, if any.
		if cfg.WebUI.AuthToken != "" {
			return 1, 1
		}
		return 0, 0
	}
	for _, entry := range configured {
		if entry.Token != "" {
			tokens++
			if entry.Scope == "admin" {
				admins++
			}
		}
	}
	return tokens, admins
}
// validateWarnings collects "warn"-level findings for configurations that
// are accepted but probably not what the operator intended. It never emits
// "error" results. Warnings are appended in a fixed order: GeoIP,
// auto-response actions, block_ips without firewall, infra IPs, netblock
// threshold, permblock count.
func validateWarnings(cfg *Config) []ValidationResult {
	var warnings []ValidationResult

	firewallOn := cfg.Firewall != nil && cfg.Firewall.Enabled

	// GeoIP credentials are present but auto_update was explicitly set to false.
	geoCredsSet := cfg.GeoIP.AccountID != "" && cfg.GeoIP.LicenseKey != ""
	if geoCredsSet && cfg.GeoIP.AutoUpdate != nil && !*cfg.GeoIP.AutoUpdate {
		warnings = append(warnings, ValidationResult{"warn", "geoip", "GeoIP credentials configured but auto_update is disabled"})
	}

	if cfg.AutoResponse.Enabled {
		// Auto-response is switched on but every action is switched off.
		anyAction := cfg.AutoResponse.KillProcesses || cfg.AutoResponse.QuarantineFiles || cfg.AutoResponse.BlockIPs
		if !anyAction {
			warnings = append(warnings, ValidationResult{"warn", "auto_response", "auto_response enabled but no actions configured (kill/quarantine/block all false)"})
		}
		// block_ips needs the firewall engine to apply its rules. With the
		// firewall disabled or absent the daemon would log "auto-blocked"
		// actions that never reach the kernel, and operators would only
		// notice when attackers keep coming back.
		if cfg.AutoResponse.BlockIPs && !firewallOn {
			warnings = append(warnings, ValidationResult{"warn", "auto_response.block_ips", "auto-response wants to block IPs but firewall is disabled; blocks will be no-ops"})
		}
	}

	// Infra IPs may live at the top level or in the firewall section.
	infraConfigured := len(cfg.InfraIPs) > 0 || (cfg.Firewall != nil && len(cfg.Firewall.InfraIPs) > 0)
	if !infraConfigured {
		warnings = append(warnings, ValidationResult{"warn", "infra_ips", "no infra_ips configured in either top-level or firewall section"})
		// An enabled firewall with no infra IPs risks locking operators out.
		if firewallOn {
			warnings = append(warnings, ValidationResult{"warn", "firewall", "firewall enabled but no infra_ips configured - risk of lockout"})
		}
	}

	// Netblock threshold too low.
	if cfg.AutoResponse.NetBlock && cfg.AutoResponse.NetBlockThreshold < 2 {
		warnings = append(warnings, ValidationResult{"warn", "auto_response.netblock_threshold", fmt.Sprintf("netblock_threshold=%d is very low (< 2), may cause excessive blocking", cfg.AutoResponse.NetBlockThreshold)})
	}

	// Permblock count too low.
	if cfg.AutoResponse.PermBlock && cfg.AutoResponse.PermBlockCount < 2 {
		warnings = append(warnings, ValidationResult{"warn", "auto_response.permblock_count", fmt.Sprintf("permblock_count=%d is very low (< 2), may permanently block too quickly", cfg.AutoResponse.PermBlockCount)})
	}

	return warnings
}
// ValidateDeep runs filesystem and connectivity probes for the services the
// configuration refers to and returns one or more results per probe.
// It does NOT call Validate(); the caller should invoke both separately.
//
// The state directory is always probed. Every other probe runs only when
// the corresponding feature is enabled and/or its path or address is set.
func ValidateDeep(cfg *Config) []ValidationResult {
	var results []ValidationResult

	// State directory (unconditional).
	results = append(results, probeStatePath(cfg.StatePath)...)

	// Signature rules directory.
	if dir := cfg.Signatures.RulesDir; dir != "" {
		results = append(results, probeRulesDir(dir)...)
	}

	// SMTP server used for email alerts.
	emailProbe := cfg.Alerts.Email.Enabled && cfg.Alerts.Email.SMTP != ""
	if emailProbe {
		results = append(results, probeSMTP(cfg.Alerts.Email.SMTP)...)
	}

	// ClamAV socket.
	clamdProbe := cfg.EmailAV.Enabled && cfg.EmailAV.ClamdSocket != ""
	if clamdProbe {
		results = append(results, probeClamd(cfg.EmailAV.ClamdSocket)...)
	}

	// TLS certificate and key are checked only when custom paths are set.
	if cert := cfg.WebUI.TLSCert; cert != "" {
		if _, err := os.Stat(cert); err == nil {
			results = append(results, ValidationResult{"ok", "webui.tls_cert", cert})
		} else {
			results = append(results, ValidationResult{"error", "webui.tls_cert", fmt.Sprintf("file not found: %s", cert)})
		}
	}
	if key := cfg.WebUI.TLSKey; key != "" {
		if _, err := os.Stat(key); err == nil {
			results = append(results, ValidationResult{"ok", "webui.tls_key", key})
		} else {
			results = append(results, ValidationResult{"error", "webui.tls_key", fmt.Sprintf("file not found: %s", key)})
		}
	}

	// Webhook endpoint.
	webhookProbe := cfg.Alerts.Webhook.Enabled && cfg.Alerts.Webhook.URL != ""
	if webhookProbe {
		results = append(results, probeWebhook(cfg.Alerts.Webhook.URL)...)
	}

	// GeoIP database files, only when credentials and editions are configured.
	geoProbe := cfg.GeoIP.AccountID != "" && cfg.GeoIP.LicenseKey != "" && len(cfg.GeoIP.Editions) > 0
	if geoProbe {
		results = append(results, probeGeoIPDBs(cfg.StatePath, cfg.GeoIP.Editions)...)
	}

	return results
}
// ValidateDeepSection runs only the deep probes that belong to the named
// section, so saving section X cannot fail because of an unrelated probe
// for section Y. Section names match the webui settings schema IDs.
//
// Sections with probes: "alerts" (SMTP and webhook), "email_av" (clamd
// socket) and "geoip" (database files). Any other section, or a known
// section whose probe preconditions are not met, yields nil.
func ValidateDeepSection(cfg *Config, section string) []ValidationResult {
	if section == "alerts" {
		var results []ValidationResult
		if cfg.Alerts.Email.Enabled && cfg.Alerts.Email.SMTP != "" {
			results = append(results, probeSMTP(cfg.Alerts.Email.SMTP)...)
		}
		if cfg.Alerts.Webhook.Enabled && cfg.Alerts.Webhook.URL != "" {
			results = append(results, probeWebhook(cfg.Alerts.Webhook.URL)...)
		}
		return results
	}

	if section == "email_av" && cfg.EmailAV.Enabled && cfg.EmailAV.ClamdSocket != "" {
		return probeClamd(cfg.EmailAV.ClamdSocket)
	}

	if section == "geoip" && cfg.GeoIP.AccountID != "" && cfg.GeoIP.LicenseKey != "" && len(cfg.GeoIP.Editions) > 0 {
		return probeGeoIPDBs(cfg.StatePath, cfg.GeoIP.Editions)
	}

	// "challenge": probeListenPortAvailable is not yet implemented in this
	// codebase. When added, invoke it for that section:
	// return probeListenPortAvailable(cfg.Challenge.ListenPort).
	return nil
}
// probeStatePath verifies that path exists, is a directory, and accepts new
// files. Writability is tested by creating and then removing a file named
// ".csm-validate-probe" inside the directory. Exactly one result is
// returned, with field "state_path".
func probeStatePath(path string) []ValidationResult {
	// failure builds the single-element error result "<reason>: <path>".
	failure := func(reason string) []ValidationResult {
		return []ValidationResult{{"error", "state_path", fmt.Sprintf("%s: %s", reason, path)}}
	}

	st, err := os.Stat(path)
	switch {
	case err != nil:
		return failure("directory not found")
	case !st.IsDir():
		return failure("not a directory")
	}

	marker := filepath.Join(path, ".csm-validate-probe")
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	fh, err := os.Create(marker)
	if err != nil {
		return failure("directory not writable")
	}
	// Cleanup is best-effort; failures here do not change the verdict.
	_ = fh.Close()
	_ = os.Remove(marker)

	return []ValidationResult{{"ok", "state_path", path}}
}
// probeRulesDir checks that the rules directory exists and contains at
// least one rule file (.yaml, .yml, .yar or .yara, matched case-sensitively
// on the file extension).
//
// The directory is listed once and every matching entry is counted, so the
// "ok" message reports the total number of rule files across all supported
// extensions. Subdirectories are not counted even if their name ends in a
// rule extension. Listing the directory (rather than globbing) also keeps
// the probe correct when path itself contains glob metacharacters such as
// '[' or '*'.
//
// Exactly one result is returned, with field "signatures.rules_dir".
func probeRulesDir(path string) []ValidationResult {
	info, err := os.Stat(path)
	if err != nil {
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("directory not found: %s", path)}}
	}
	if !info.IsDir() {
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("not a directory: %s", path)}}
	}

	entries, err := os.ReadDir(path)
	if err != nil {
		// The directory exists but cannot be listed (e.g. permissions).
		// Report that explicitly instead of claiming it holds no rules.
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("cannot read directory %s: %v", path, err)}}
	}

	ruleFiles := 0
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		switch filepath.Ext(entry.Name()) {
		case ".yaml", ".yml", ".yar", ".yara":
			ruleFiles++
		}
	}
	if ruleFiles == 0 {
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("no rule files (.yaml/.yml/.yar/.yara) found in %s", path)}}
	}
	return []ValidationResult{{"ok", "signatures.rules_dir", fmt.Sprintf("%s (%d rule files)", path, ruleFiles)}}
}
// probeSMTP checks that the SMTP server at addr accepts a TCP connection
// within three seconds. No SMTP traffic is exchanged; the connection is
// closed as soon as it is established. The single result uses field
// "alerts.email.smtp".
func probeSMTP(addr string) []ValidationResult {
	const dialTimeout = 3 * time.Second

	conn, dialErr := net.DialTimeout("tcp", addr, dialTimeout)
	if dialErr == nil {
		_ = conn.Close()
		return []ValidationResult{{"ok", "alerts.email.smtp", fmt.Sprintf("connected to %s", addr)}}
	}
	return []ValidationResult{{"error", "alerts.email.smtp", fmt.Sprintf("cannot connect to %s: %v", addr, dialErr)}}
}
// probeClamd checks that the ClamAV daemon's unix socket accepts a
// connection within three seconds. Nothing is sent over the socket; it is
// closed right after connecting. The single result uses field
// "email_av.clamd_socket".
func probeClamd(socket string) []ValidationResult {
	const dialTimeout = 3 * time.Second

	conn, dialErr := net.DialTimeout("unix", socket, dialTimeout)
	if dialErr == nil {
		_ = conn.Close()
		return []ValidationResult{{"ok", "email_av.clamd_socket", fmt.Sprintf("connected to %s", socket)}}
	}
	return []ValidationResult{{"error", "email_av.clamd_socket", fmt.Sprintf("cannot connect to %s: %v", socket, dialErr)}}
}
// probeWebhook sends an HTTP HEAD request (5 second client timeout) to
// check that the webhook endpoint is reachable.
//
// Any error returned by the request is reported as an error result. Any
// HTTP response at all -- including 401/403/404/405 -- counts as reachable,
// and its status code is included in the "ok" message. The single result
// uses field "alerts.webhook.url".
func probeWebhook(url string) []ValidationResult {
	prober := &http.Client{Timeout: 5 * time.Second}

	resp, reqErr := prober.Head(url)
	if reqErr != nil {
		return []ValidationResult{{"error", "alerts.webhook.url", fmt.Sprintf("cannot reach %s: %v", url, reqErr)}}
	}
	status := resp.StatusCode
	resp.Body.Close()

	return []ValidationResult{{"ok", "alerts.webhook.url", fmt.Sprintf("reachable (HTTP %d)", status)}}
}
// validateSignatureURL checks that raw is an acceptable signature download
// URL. Surrounding whitespace is trimmed first. When allowTemplates is
// true, template placeholders are replaced with sample values via
// sampleSignatureURLTemplate before parsing.
//
// An error is returned when the URL does not parse, when its scheme is not
// http or https (compared case-insensitively), when it has no host, or
// when validateSignaturesHost rejects the host.
func validateSignatureURL(raw string, allowTemplates bool) error {
	target := strings.TrimSpace(raw)
	if allowTemplates {
		target = sampleSignatureURLTemplate(target)
	}

	parsed, parseErr := url.Parse(target)
	if parseErr != nil {
		return fmt.Errorf("parse: %w", parseErr)
	}

	if scheme := strings.ToLower(parsed.Scheme); scheme != "http" && scheme != "https" {
		return fmt.Errorf("signatures URL must be an http or https URL")
	}

	hostname := parsed.Hostname()
	if hostname == "" {
		return fmt.Errorf("missing host in %q", raw)
	}
	return validateSignaturesHost(hostname)
}
// sampleSignatureURLTemplate substitutes fixed sample values for the
// "{tier}" and "{version}" placeholders in raw ("core" and "v1.0.0"
// respectively) and returns the result. Every occurrence is replaced;
// input without placeholders is returned unchanged.
func sampleSignatureURLTemplate(raw string) string {
	samples := [...][2]string{
		{"{tier}", "core"},
		{"{version}", "v1.0.0"},
	}
	for _, pair := range samples {
		raw = strings.ReplaceAll(raw, pair[0], pair[1])
	}
	return raw
}
// validateSignaturesHost rejects URL hosts that point at loopback,
// RFC1918 / ULA private, link-local, or unspecified addresses, as well as
// the names "localhost" and "localhost.localdomain" (case-insensitive, an
// optional trailing dot is ignored).
//
// Scoped IPv6 literals are valid URL hosts, so the host is parsed with
// netip rather than net.ParseIP to keep the zone intact. IPv4-mapped IPv6
// addresses are unmapped before classification. Hosts that are not IP
// literals (ordinary DNS names) are accepted without resolution.
func validateSignaturesHost(host string) error {
	normalized := strings.TrimSuffix(strings.ToLower(host), ".")

	switch normalized {
	case "localhost", "localhost.localdomain":
		return fmt.Errorf("signatures URL host %q is loopback; refuse for production downloads", host)
	}

	ip, parseErr := netip.ParseAddr(normalized)
	if parseErr != nil {
		// Not an IP literal; nothing further to check.
		return nil
	}
	ip = ip.Unmap()

	switch {
	case ip.IsLoopback():
		return fmt.Errorf("signatures URL host %s is loopback", host)
	case ip.IsPrivate():
		return fmt.Errorf("signatures URL host %s is RFC1918 / ULA private", host)
	case ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast():
		return fmt.Errorf("signatures URL host %s is link-local", host)
	case ip.IsUnspecified():
		return fmt.Errorf("signatures URL host %s is unspecified", host)
	}
	return nil
}
// probeGeoIPDBs checks that a "<edition>.mmdb" file exists under
// "<statePath>/geoip" for every requested edition.
//
// One "error" result is returned per missing database. When nothing is
// missing (including when editions is empty) a single "ok" result
// summarising the edition count is returned instead. All results use field
// "geoip".
func probeGeoIPDBs(statePath string, editions []string) []ValidationResult {
	var missing []ValidationResult
	for i := range editions {
		dbPath := filepath.Join(statePath, "geoip", editions[i]+".mmdb")
		if _, statErr := os.Stat(dbPath); statErr != nil {
			missing = append(missing, ValidationResult{"error", "geoip", fmt.Sprintf("database not found: %s", dbPath)})
		}
	}
	if len(missing) > 0 {
		return missing
	}
	return append(missing, ValidationResult{"ok", "geoip", fmt.Sprintf("all %d edition databases present", len(editions))})
}
package config
import (
"bytes"
"fmt"
"sort"
"strings"
"gopkg.in/yaml.v3"
)
// YAMLChange describes a single scalar, list, or map replacement to apply
// to a YAML document. Path holds the YAML key path from the document root,
// one element per mapping level (e.g. {"alerts", "email", "smtp"} rather
// than a single dotted string). Value is the Go value to serialise; nil
// means YAML null.
type YAMLChange struct {
// Path is the sequence of mapping keys leading to the target value.
Path []string
// Value is the replacement value; nil is written as YAML null.
Value interface{}
}
// lineIndex records where each line of a document begins. Element i holds
// the byte offset of the first byte of 1-based line i+1, so element 0 is
// always 0 (the start of line 1) for an index built by buildLineIndex.
type lineIndex []int

// buildLineIndex scans data and returns the start offset of every line.
// A new line begins immediately after each '\n'; the result therefore has
// one entry more than the number of newlines in data.
func buildLineIndex(data []byte) lineIndex {
	starts := lineIndex{0} // line 1 always begins at offset 0
	for pos := 0; pos < len(data); pos++ {
		if data[pos] != '\n' {
			continue
		}
		starts = append(starts, pos+1) // next line begins after the newline
	}
	return starts
}

// offset converts a 1-based line and 1-based column into a byte offset.
// For a line outside the index (below 1 or past the last line) the start
// offset of the last indexed line is returned, or 0 if the index is empty;
// col is ignored in that case.
func (li lineIndex) offset(line, col int) int {
	if line >= 1 && line <= len(li) {
		return li[line-1] + col - 1
	}
	// Out-of-range line: clamp to the last known line start.
	if len(li) == 0 {
		return 0
	}
	return li[len(li)-1]
}

// lineEnd returns the byte offset of the '\n' that terminates the given
// 1-based line. For the last indexed line (or any line beyond it) there is
// no following line start to derive it from, so len(data) is returned.
func (li lineIndex) lineEnd(line int, data []byte) int {
	if line >= len(li) {
		// Last line (or beyond): it ends where the data ends.
		return len(data)
	}
	// The following line starts at li[line]; its predecessor byte is the '\n'.
	return li[line] - 1
}

// lineStart returns the byte offset at which the given 1-based line
// begins. Lines below 1 map to offset 0; lines past the end map to the
// start of the last indexed line.
func (li lineIndex) lineStart(line int) int {
	switch {
	case line < 1:
		return 0
	case line > len(li):
		return li[len(li)-1]
	default:
		return li[line-1]
	}
}
// maxLine returns the largest Line value found on n or on any node
// reachable through Content beneath it.
func maxLine(n *yaml.Node) int {
	highest := n.Line
	for i := 0; i < len(n.Content); i++ {
		if sub := maxLine(n.Content[i]); sub > highest {
			highest = sub
		}
	}
	return highest
}
// findNode descends from root through nested mappings following path and
// locates the final segment.
//
// If root is a DocumentNode the walk starts at its first content node. At
// every level the current node must be a mapping; keys are compared by
// their scalar Value and the first matching key wins.
//
// Returns:
//   - (key, value, parent, nil) when the final segment exists;
//   - (nil, nil, parent, nil) when only the final segment is missing, so
//     the caller can insert into parent;
//   - an error when the document is empty, path is empty, an intermediate
//     segment is missing, or a node on the way is not a mapping.
func findNode(root *yaml.Node, path []string) (keyN, valN, parentMap *yaml.Node, err error) {
	node := root
	if node.Kind == yaml.DocumentNode {
		// The real root mapping is the document's first content node.
		if len(node.Content) == 0 {
			return nil, nil, nil, fmt.Errorf("empty document")
		}
		node = node.Content[0]
	}

	last := len(path) - 1
	for depth, name := range path {
		if node.Kind != yaml.MappingNode {
			return nil, nil, nil, fmt.Errorf("path %v: segment %q: expected mapping, got kind %d", path[:depth+1], name, node.Kind)
		}

		// Mapping content alternates key, value, key, value, ...
		hit := -1
		for at := 0; at+1 < len(node.Content); at += 2 {
			if node.Content[at].Value == name {
				hit = at
				break
			}
		}

		switch {
		case hit >= 0 && depth == last:
			return node.Content[hit], node.Content[hit+1], node, nil
		case hit >= 0:
			node = node.Content[hit+1]
		case depth == last:
			// Key missing at the final level: hand back the parent mapping.
			return nil, nil, node, nil
		default:
			return nil, nil, nil, fmt.Errorf("path %v: segment %q not found", path[:depth+1], name)
		}
	}
	return nil, nil, nil, fmt.Errorf("empty path")
}
// renderValueInline renders v in the form it takes on the right-hand side
// of a single-line "key: value" YAML entry.
//
// nil yields "null". Any other value is marshalled with yaml.Marshal as the
// sole field "v" of a wrapper struct; the leading "v: " and the trailing
// newline are then stripped. If what remains still contains a line break
// the value cannot be written inline and an error is returned.
func renderValueInline(v interface{}) (string, error) {
	if v == nil {
		return "null", nil
	}

	encoded, err := yaml.Marshal(struct {
		V interface{} `yaml:"v"`
	}{V: v})
	if err != nil {
		return "", err
	}

	// encoded looks like "v: VALUE\n"; keep only VALUE.
	text := strings.TrimSuffix(strings.TrimPrefix(string(encoded), "v: "), "\n")
	if strings.IndexAny(text, "\n\r") >= 0 {
		return "", fmt.Errorf("value of type %T cannot be rendered inline", v)
	}
	return text, nil
}
// renderKeyInline renders key as a YAML mapping key that fits on one line.
//
// The key is marshalled as a scalar node tagged !!str, leaving it to
// yaml.Marshal to add quotes where the bare text would be misread (e.g.
// "1", ":", " x"). An error is returned when the marshalled form contains a
// line break or starts with a block-scalar indicator ('|' or '>'), because
// such a key cannot be written as "key: value" on a single line.
func renderKeyInline(key string) (string, error) {
	scalar := yaml.Node{
		Kind:  yaml.ScalarNode,
		Tag:   "!!str",
		Value: key,
	}
	encoded, err := yaml.Marshal(&scalar)
	if err != nil {
		return "", err
	}

	// Marshalling a scalar node yields "value\n".
	text := strings.TrimSuffix(string(encoded), "\n")

	spansLines := strings.ContainsAny(text, "\n\r")
	blockScalar := strings.HasPrefix(text, "|") || strings.HasPrefix(text, ">")
	if spansLines || blockScalar {
		return "", fmt.Errorf("key %q requires multi-line YAML representation and cannot be used as a simple mapping key", key)
	}
	return text, nil
}
// renderValueBlock marshals v with yaml.Marshal and re-indents the output
// so that every line starts at the given 1-based column (indent-1 leading
// spaces). Empty lines in the marshalled output are dropped and every
// emitted line ends in '\n'. For example, a two-element sequence at
// indent 3 yields:
//
//	"  - a\n  - b\n"
func renderValueBlock(v interface{}, indent int) (string, error) {
	encoded, err := yaml.Marshal(v)
	if err != nil {
		return "", err
	}

	// encoded is like "- a\n- b\n" for a sequence, or "key: val\n" for a mapping.
	pad := strings.Repeat(" ", indent-1)
	isNewline := func(r rune) bool { return r == '\n' }

	var out strings.Builder
	// FieldsFunc yields only the non-empty pieces between newlines.
	for _, ln := range strings.FieldsFunc(string(encoded), isNewline) {
		out.WriteString(pad)
		out.WriteString(ln)
		out.WriteByte('\n')
	}
	return out.String(), nil
}
// isBlockValue reports whether valN must be rendered as a block: it is not a
// scalar and it begins on a later line than its key.
func isBlockValue(keyN, valN *yaml.Node) bool {
	return valN.Kind != yaml.ScalarNode && valN.Line > keyN.Line
}
// splice replaces data[start:end] with replacement.
func splice(data []byte, start, end int, replacement []byte) []byte {
var buf bytes.Buffer
buf.Write(data[:start])
buf.Write(replacement)
buf.Write(data[end:])
return buf.Bytes()
}
// edit holds a resolved splice operation: the half-open byte range
// [start, end) of the original document and the bytes that take its place.
// A pure insertion is expressed with start == end.
type edit struct {
// start is the byte offset in the original data where the replaced span begins.
start int
// end is the byte offset just past the replaced span (exclusive).
end int
// replacement is the text written in place of data[start:end].
replacement []byte
}
// YAMLEdit applies changes to data and returns the new document bytes.
// For every path that already exists, only the value span is rewritten
// at the same indent; untouched bytes (including all comments and
// whitespace) remain byte-identical. For a path that does not exist,
// a new key:value block is appended to the parent mapping at the parent's
// indent. Applies later edits first so earlier offsets remain valid.
//
// Errors are returned when two changes share a path, a path is empty, the
// input does not parse, a value cannot be rendered, two resolved edits touch
// overlapping byte ranges (for example a change to a mapping together with a
// change to one of its children), or the edited document no longer parses.
// An empty changes slice returns data unchanged.
func YAMLEdit(data []byte, changes []YAMLChange) ([]byte, error) {
	if len(changes) == 0 {
		return data, nil
	}

	// Reject duplicate paths up front. NUL cannot appear in a YAML key, so
	// joining on it gives an unambiguous map key.
	seen := make(map[string]struct{}, len(changes))
	for _, ch := range changes {
		key := strings.Join(ch.Path, "\x00")
		if _, dup := seen[key]; dup {
			return nil, fmt.Errorf("yamledit: duplicate path %v", ch.Path)
		}
		seen[key] = struct{}{}
	}

	var root yaml.Node
	if err := yaml.Unmarshal(data, &root); err != nil {
		return nil, fmt.Errorf("yamledit: parse: %w", err)
	}
	li := buildLineIndex(data)

	// Resolve every change against the ORIGINAL document so all offsets refer
	// to the same byte positions.
	edits := make([]edit, 0, len(changes))
	for _, ch := range changes {
		if len(ch.Path) == 0 {
			return nil, fmt.Errorf("yamledit: empty path")
		}
		keyN, valN, parentMap, err := findNode(&root, ch.Path)
		if err != nil {
			return nil, fmt.Errorf("yamledit: %w", err)
		}
		if valN == nil {
			// Key does not exist -- insert into parentMap.
			ed, err := buildInsertEdit(data, li, parentMap, ch.Path[len(ch.Path)-1], ch.Value)
			if err != nil {
				return nil, fmt.Errorf("yamledit: insert %v: %w", ch.Path, err)
			}
			edits = append(edits, ed)
			continue
		}
		// Key exists -- replace value span.
		ed, err := buildReplaceEdit(data, li, keyN, valN, ch.Value)
		if err != nil {
			return nil, fmt.Errorf("yamledit: replace %v: %w", ch.Path, err)
		}
		edits = append(edits, ed)
	}

	// Sort by start offset descending so we splice end-to-start. The sort is
	// stable and ties are broken by the larger end first, so the outcome is
	// deterministic and a replacement that begins at an insertion point is
	// applied before that insertion.
	sort.SliceStable(edits, func(i, j int) bool {
		if edits[i].start != edits[j].start {
			return edits[i].start > edits[j].start
		}
		return edits[i].end > edits[j].end
	})

	// Each edit was computed against the original bytes; if two ranges
	// intersect, applying one invalidates the other's offsets and the result
	// would be silently corrupted. Refuse instead.
	for i := 1; i < len(edits); i++ {
		prev, cur := edits[i-1], edits[i]
		if cur.end > prev.start {
			return nil, fmt.Errorf("yamledit: overlapping edits at byte ranges [%d,%d) and [%d,%d)",
				cur.start, cur.end, prev.start, prev.end)
		}
	}

	result := data
	for _, ed := range edits {
		result = splice(result, ed.start, ed.end, ed.replacement)
	}

	// Validate that the output is still parseable YAML. This catches edge cases
	// where unusual input formats (complex key notation, etc.) produce invalid output.
	var check yaml.Node
	if err := yaml.Unmarshal(result, &check); err != nil {
		return nil, fmt.Errorf("yamledit: output is not valid YAML: %w", err)
	}
	return result, nil
}
// buildReplaceEdit computes the splice for replacing an existing value node.
//
// Three shapes are handled:
//   - the existing value is a block sequence/mapping below its key: the whole
//     lines it occupies are replaced by value rendered as a block at the same
//     column;
//   - the existing value sits on the key's line and value renders on one line:
//     only the span from the value's first byte to the end of that line is
//     replaced, and a trailing "# comment" on the line is carried over;
//   - the existing value sits on the key's line but value needs several lines
//     (see needsBlockFallback): the whole "key: old" span is rewritten as
//     "key:" followed by the block, again keeping a trailing comment.
//
// Any rendering failure is returned wrapped with the key's name.
func buildReplaceEdit(data []byte, li lineIndex, keyN, valN *yaml.Node, value interface{}) (edit, error) {
	if isBlockValue(keyN, valN) {
		// Block sequence or mapping: value occupies one or more complete lines
		// starting at valN.Line. Replace from the start of valN.Line to the
		// end of the last line in the subtree.
		lastLine := maxLine(valN)
		start := li.lineStart(valN.Line)
		end := li.lineEnd(lastLine, data)
		if end < len(data) && data[end] == '\n' {
			end++ // include the trailing newline so we replace whole lines
		}
		rendered, err := renderValueBlock(value, valN.Column)
		if err != nil {
			return edit{}, fmt.Errorf("path %v: %w", keyN.Value, err)
		}
		return edit{start: start, end: end, replacement: []byte(rendered)}, nil
	}

	// Flow / scalar: try inline first. If the new value is a sequence or
	// mapping that cannot fit on one line (e.g. replacing `foo: []` with a
	// multi-item list) fall back to block rendering and expand the span to
	// cover the whole line, so the result is `foo:\n  - a\n  - b\n`.
	start := li.offset(valN.Line, valN.Column)
	end := li.lineEnd(valN.Line, data)
	rendered, err := renderValueInline(value)
	if err == nil {
		// If the value position directly follows the ':' (presumably the case
		// for an empty `key:` value -- confirm against yaml.v3 node positions),
		// writing the value there would fuse it with the key as `key:value`.
		// A separating space is always valid YAML.
		if start > 0 && start <= len(data) && data[start-1] == ':' {
			rendered = " " + rendered
		}
		// The replaced span runs to the end of the line, so a trailing
		// `# comment` would otherwise be deleted. Carry it over, matching the
		// block-fallback path below.
		trailingComment := ""
		if start <= end && end <= len(data) {
			trailingComment = extractInlineComment(data[start:end])
		}
		return edit{start: start, end: end, replacement: []byte(rendered + trailingComment)}, nil
	}
	if !needsBlockFallback(value) {
		return edit{}, fmt.Errorf("path %v: %w", keyN.Value, err)
	}

	// Replace from the key's column (beginning of `key:`) to end of the key's
	// line, emitting `key:\n<block>`. Using keyN.Column keeps the existing
	// indent. A trailing `# comment` on the original line is kept attached
	// to the key line so operator annotations survive a multi-select save.
	keyStart := li.offset(keyN.Line, keyN.Column)
	keyEnd := li.lineEnd(keyN.Line, data)
	renderedKey, kerr := renderKeyInline(keyN.Value)
	if kerr != nil {
		return edit{}, fmt.Errorf("path %v: render key: %w", keyN.Value, kerr)
	}
	block, berr := renderValueBlock(value, keyN.Column+2)
	if berr != nil {
		return edit{}, fmt.Errorf("path %v: %w", keyN.Value, berr)
	}
	trailingComment := extractInlineComment(data[keyStart:keyEnd])
	// The span ends before the line's newline, so the block's own final
	// newline is trimmed to avoid doubling it.
	return edit{start: keyStart, end: keyEnd, replacement: []byte(renderedKey + ":" + trailingComment + "\n" + strings.TrimRight(block, "\n"))}, nil
}
// extractInlineComment scans a single YAML line for a trailing `#` comment
// (one preceded by whitespace, or at the very start of the line, and outside
// quoted scalars) and returns it together with the whitespace run in front of
// it, e.g. " # keep empty". It returns "" when the line has no comment.
//
// Quote handling:
//   - a `"` or `'` only opens a quoted scalar when it stands at the start of
//     the line or follows a blank or one of `[ { , :`; a quote in the middle
//     of a plain scalar (the apostrophe in `don't`) is an ordinary character
//     and no longer hides a later comment;
//   - inside double quotes a backslash escapes the next byte;
//   - inside single quotes `''` is an escaped quote.
//
// The scanner stays deliberately simple: it works on one line only and does
// not understand block scalars, which do not reach this path.
func extractInlineComment(line []byte) string {
	isBlank := func(b byte) bool { return b == ' ' || b == '\t' }
	// opensQuote reports whether a quote byte at index i can begin a quoted
	// scalar, judged by the byte in front of it.
	opensQuote := func(i int) bool {
		if i == 0 {
			return true
		}
		switch line[i-1] {
		case ' ', '\t', '[', '{', ',', ':':
			return true
		}
		return false
	}

	const (
		plain = iota
		doubleQuoted
		singleQuoted
	)
	state := plain
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch state {
		case doubleQuoted:
			switch c {
			case '\\':
				i++ // skip the escaped byte, whatever it is
			case '"':
				state = plain
			}
		case singleQuoted:
			if c != '\'' {
				continue
			}
			if i+1 < len(line) && line[i+1] == '\'' {
				i++ // '' is a literal quote, not the terminator
				continue
			}
			state = plain
		default:
			switch {
			case c == '"' && opensQuote(i):
				state = doubleQuoted
			case c == '\'' && opensQuote(i):
				state = singleQuoted
			case c == '#' && (i == 0 || isBlank(line[i-1])):
				// A `#` only starts a comment when preceded by whitespace (or
				// at the start of the line); include that whitespace run so
				// the output keeps the original spacing.
				start := i
				for start > 0 && isBlank(line[start-1]) {
					start--
				}
				return string(line[start:])
			}
		}
	}
	return ""
}
// needsBlockFallback reports whether value is one of the sequence or mapping
// types ([]string, []interface{}, map[string]interface{},
// map[interface{}]interface{}) for which a failed inline rendering should be
// retried as a block.
func needsBlockFallback(value interface{}) bool {
	switch value.(type) {
	case []string, []interface{}:
		return true
	case map[string]interface{}, map[interface{}]interface{}:
		return true
	default:
		return false
	}
}
// buildInsertEdit computes the zero-width splice that appends "key: value" to
// the mapping parentMap.
//
// The new entry goes after the last line occupied by the mapping's final
// child (or after the mapping's own line when it has no children) and is
// indented to the column of the mapping's first child (or the mapping's own
// column when empty). []string and []interface{} values are written as a
// block under "key:"; every other value must render on one line.
func buildInsertEdit(data []byte, li lineIndex, parentMap *yaml.Node, key string, value interface{}) (edit, error) {
	// Anchor line and indent column default to the mapping itself and are
	// refined from its children when there are any. The first child's column
	// is used because, in a nested mapping, the parent node's own column does
	// not reflect the children's indentation.
	anchorLine, indentCol := parentMap.Line, parentMap.Column
	if n := len(parentMap.Content); n > 0 {
		anchorLine = maxLine(parentMap.Content[n-1])
		indentCol = parentMap.Content[0].Column
	}

	at := li.lineEnd(anchorLine, data)
	leadingNewline := false
	switch {
	case at < len(data) && data[at] == '\n':
		at++ // insert after the newline that ends the anchor line
	case at > 0 && data[at-1] != '\n':
		// The anchor line is not newline-terminated; start the new entry on
		// a line of its own.
		leadingNewline = true
	}

	indent := strings.Repeat(" ", indentCol-1)
	renderedKey, err := renderKeyInline(key)
	if err != nil {
		return edit{}, fmt.Errorf("render key %q: %w", key, err)
	}

	var out strings.Builder
	if leadingNewline {
		out.WriteByte('\n')
	}
	out.WriteString(indent)
	out.WriteString(renderedKey)

	switch value.(type) {
	case []string, []interface{}:
		// Sequences are written block-style, two columns deeper than the key.
		block, err := renderValueBlock(value, indentCol+2)
		if err != nil {
			return edit{}, fmt.Errorf("path %v: %w", key, err)
		}
		out.WriteString(":\n")
		out.WriteString(block)
	default:
		inline, err := renderValueInline(value)
		if err != nil {
			return edit{}, fmt.Errorf("path %v: %w", key, err)
		}
		out.WriteString(": ")
		out.WriteString(inline)
		out.WriteByte('\n')
	}
	return edit{start: at, end: at, replacement: []byte(out.String())}, nil
}
package daemon
import (
"context"
"crypto/ed25519"
"encoding/hex"
"log"
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/reporting"
)
// Defaults for the abuse-report spool and the in-memory queue in front of it.
const (
// abuseReportSpoolFile is the spool file name joined onto the daemon's state
// path when no explicit spool path is configured.
abuseReportSpoolFile = "abuse_reports.db"
// abuseReportSpoolDefault is the spool capacity used when the configured
// maximum is zero or negative.
abuseReportSpoolDefault = 10000
// abuseReportQueueDefault is the upper bound on the report queue's buffer
// size (see abuseReportQueueSize).
abuseReportQueueDefault = abuseReportSpoolDefault
// abuseReportDrainEvery is the interval of the spool drain ticker; it is
// also passed to reporting.NewSpooler.
abuseReportDrainEvery = time.Minute
)
// startAbuseReporting configures abuse reporting from d.cfg.Reputation.Report.
//
// It always clears any previously installed alert report hook first. When
// reporting is enabled and at least one usable target and one valid class are
// configured, it opens the spool, installs a hook that passes each finding
// through the gate and queues accepted reports without blocking (a full queue
// increments a drop counter), records the stop/done channels on the daemon,
// and returns the loop that moves queued reports into the spooler and drains
// the spool on a timer. The returned function runs until stopAbuseReporting
// closes the stop channel.
//
// It returns nil when reporting is disabled or misconfigured (logged), leaving
// the alert path untouched.
func (d *Daemon) startAbuseReporting() func() {
	alert.SetReportHook(nil)

	rc := d.cfg.Reputation.Report
	if !rc.Enabled {
		return nil
	}

	targets := buildReportTargets(rc.Targets)
	if len(targets) == 0 {
		log.Printf("abuse-reporting: enabled but no usable targets configured; reporting stays off")
		return nil
	}
	enabled := classSet(rc.Classes)
	if len(enabled) == 0 {
		log.Printf("abuse-reporting: enabled but no valid classes configured; reporting stays off")
		return nil
	}

	spoolPath := rc.SpoolPath
	if spoolPath == "" {
		spoolPath = filepath.Join(d.cfg.StatePath, abuseReportSpoolFile)
	}
	spoolMax := rc.SpoolMax
	if spoolMax <= 0 {
		spoolMax = abuseReportSpoolDefault
	}
	spool, err := reporting.NewSpool(spoolPath, "reports", spoolMax)
	if err != nil {
		log.Printf("abuse-reporting: cannot open spool %s: %v; reporting stays off", spoolPath, err)
		return nil
	}

	queue := make(chan reporting.Report, abuseReportQueueSize(spoolMax))
	spooler := reporting.NewSpooler(spool, reporting.NewSender(nil, nil), targets, abuseReportDrainEvery)
	gate := reporting.Gate{Enabled: enabled}

	stop := make(chan struct{})
	done := make(chan struct{})
	d.abuseReportStop = stop
	d.abuseReportDone = done

	// dropped is bumped from the hook (any goroutine); lastLogged is only
	// touched by reportDrops, which is called from the returned loop.
	var dropped atomic.Uint64
	var lastLogged uint64
	reportDrops := func() {
		if n := dropped.Load(); n != lastLogged {
			log.Printf("abuse-reporting: report queue full; dropped %d report(s)", n)
			lastLogged = n
		}
	}

	alert.SetReportHook(func(f alert.Finding) {
		r, ok := gate.Consider(f)
		if !ok {
			return
		}
		// Never block the alert path: a full queue drops the report.
		select {
		case queue <- r:
		default:
			dropped.Add(1)
		}
	})
	log.Printf("abuse-reporting: enabled for %d target(s), %d class(es)", len(targets), len(enabled))

	return func() {
		defer close(done)

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		// Cancel an in-flight drain as soon as a stop is requested.
		go func() {
			<-stop
			cancel()
		}()

		ticker := time.NewTicker(abuseReportDrainEvery)
		defer ticker.Stop()
		defer func() {
			alert.SetReportHook(nil)
			_ = spool.Close()
		}()

		for {
			select {
			case <-stop:
				// Detach the hook, then move whatever is already queued
				// into the spooler before exiting.
				alert.SetReportHook(nil)
				reportDrops()
				drainReportQueue(spooler, queue)
				return
			case <-ticker.C:
				reportDrops()
				spooler.DrainOnce(ctx)
			case r := <-queue:
				spooler.Enqueue(r)
			}
		}
	}
}
// stopAbuseReporting signals the abuse-report loop to exit by closing the
// stop channel, blocks until the loop closes the done channel, and then
// clears both channels on the daemon. It is a no-op when no stop channel is
// set.
func (d *Daemon) stopAbuseReporting() {
	stop := d.abuseReportStop
	if stop == nil {
		return
	}
	close(stop)
	<-d.abuseReportDone
	d.abuseReportStop, d.abuseReportDone = nil, nil
}
// abuseReportQueueSize returns the buffer size for the report queue: spoolMax
// when it is positive and below abuseReportQueueDefault, otherwise
// abuseReportQueueDefault.
func abuseReportQueueSize(spoolMax int) int {
	if spoolMax <= 0 || spoolMax >= abuseReportQueueDefault {
		return abuseReportQueueDefault
	}
	return spoolMax
}
// drainReportQueue hands every report currently buffered in reportQueue to
// spooler.Enqueue and returns as soon as a receive would block.
func drainReportQueue(spooler *reporting.Spooler, reportQueue <-chan reporting.Report) {
	for {
		var next reporting.Report
		select {
		case next = <-reportQueue:
		default:
			return
		}
		spooler.Enqueue(next)
	}
}
// classSet converts configured class names into the set the gate accepts.
// Only the four known report classes are kept; any other name is logged and
// skipped. The result is never nil but may be empty.
func classSet(names []string) map[reporting.Class]bool {
	accepted := make(map[reporting.Class]bool)
	for _, name := range names {
		switch class := reporting.Class(name); class {
		case reporting.ClassBruteforce,
			reporting.ClassPHPRelay,
			reporting.ClassCredentialStuffing,
			reporting.ClassBadASNEgress:
			accepted[class] = true
		default:
			log.Printf("abuse-reporting: ignoring unknown report class %q", name)
		}
	}
	return accepted
}
// reportTargetConfig mirrors the per-target config shape; declared so the
// builder takes a concrete slice type from the anonymous struct in config.
// It is a type alias (not a new type), so values of the config's anonymous
// struct are usable wherever reportTargetConfig is expected.
type reportTargetConfig = struct {
// Name identifies the target in log lines; required.
Name string `yaml:"name"`
// URL is the endpoint reports are sent to; required and checked with
// reporting.ValidateTargetURL.
URL string `yaml:"url"`
// Transport selects the signing scheme (Ed25519 or HMAC).
Transport string `yaml:"transport"`
// NodeID is copied to the sender target; required.
NodeID string `yaml:"node_id"`
// KeyID is copied to the sender target; required.
KeyID string `yaml:"key_id"`
// KeyEnv names the environment variable holding the key material.
KeyEnv string `yaml:"key_env"`
// TokenEnv optionally names the environment variable holding a bearer
// token; only read for the HMAC transport.
TokenEnv string `yaml:"token_env"`
}
// buildReportTargets turns configured targets into sender targets, reading
// key material from the environment variable named by each target's KeyEnv.
//
// A target is skipped (with a log line, never a startup failure) when name,
// url, node_id or key_id is empty, when the URL fails validation, when the
// Ed25519 key is not hex of exactly ed25519.PrivateKeySize bytes, when the
// HMAC secret is empty, or when the transport is not recognised. For HMAC
// targets an optional bearer token is read from TokenEnv.
func buildReportTargets(cfgTargets []reportTargetConfig) []reporting.Target {
	var usable []reporting.Target
	for _, ct := range cfgTargets {
		incomplete := ct.Name == "" || ct.URL == "" || ct.NodeID == "" || ct.KeyID == ""
		if incomplete {
			log.Printf("abuse-reporting: skipping target with missing name/url/node_id/key_id")
			continue
		}
		if err := reporting.ValidateTargetURL(ct.URL); err != nil {
			log.Printf("abuse-reporting: target %q: URL must be HTTPS or loopback HTTP; skipping", ct.Name)
			continue
		}

		target := reporting.Target{
			Name:   ct.Name,
			URL:    ct.URL,
			NodeID: ct.NodeID,
			KeyID:  ct.KeyID,
		}
		keyMaterial := os.Getenv(ct.KeyEnv)

		switch reporting.Transport(ct.Transport) {
		case reporting.TransportHMAC:
			if keyMaterial == "" {
				log.Printf("abuse-reporting: target %q: configured key_env is empty; skipping", ct.Name)
				continue
			}
			target.Transport = reporting.TransportHMAC
			target.HMACSecret = []byte(keyMaterial)
			if ct.TokenEnv != "" {
				target.BearerToken = os.Getenv(ct.TokenEnv)
			}
		case reporting.TransportEd25519:
			decoded, err := hex.DecodeString(keyMaterial)
			if err != nil || len(decoded) != ed25519.PrivateKeySize {
				log.Printf("abuse-reporting: target %q: configured key_env must hold a 64-byte hex Ed25519 key; skipping", ct.Name)
				continue
			}
			target.Transport = reporting.TransportEd25519
			target.Ed25519Key = ed25519.PrivateKey(decoded)
		default:
			log.Printf("abuse-reporting: target %q: unknown transport %q; skipping", ct.Name, ct.Transport)
			continue
		}
		usable = append(usable, target)
	}
	return usable
}
//go:build linux
package daemon
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
)
// auditLogPath is the file the kernel auditd writes events to. Var (not
// const) so tests can redirect it under t.TempDir(). NewAFAlgAuditListener
// copies the value into the listener when it is constructed.
var auditLogPath = "/var/log/audit/audit.log"
// AFAlgAuditListener tails /var/log/audit/audit.log via inotify, parses
// each new line for the csm_af_alg_socket auditd key, and emits a
// Critical alert.Finding (plus optional kill/quarantine reactions)
// within milliseconds of the syscall. Sub-second response on hosts where
// BPF LSM is not available.
//
// The listener seeks to end-of-file at startup — events that pre-date
// the daemon are intentionally not re-alerted; the periodic critical-tier
// CheckAFAlgSocketUsage handles backfill via its (timestamp, serial)
// cursor in state.Store.
type AFAlgAuditListener struct {
// alertCh receives one finding per matching audit line (non-blocking send).
alertCh chan<- alert.Finding
// cfg is handed to reactToAFAlgEvent for the optional reactions.
cfg *config.Config
// path is the audit log being tailed; set from auditLogPath at construction.
path string
// inotifyFd is the inotify descriptor watching path; 0 means "not opened yet".
inotifyFd int
// file is the open audit log handle that tail reads from.
file *os.File
// pos is the byte offset in file up to which data has been consumed.
pos int64
leftover []byte // partial line accumulator across reads
droppedOversize bool // mid-drop of a line that overflowed the cap
eventCount atomic.Uint64 // observed by tests / metrics
}
// afAlgMaxLeftoverBytes caps the partial-line accumulator. A real auditd
// SYSCALL record is well under 1 KiB; 64 KiB leaves generous headroom while
// bounding memory if a record never terminates.
const afAlgMaxLeftoverBytes = 64 * 1024
// Mode returns the machine-readable label of this live-monitor backend,
// "auditd-tail". It is the audit-log counterpart of the BPF path's "bpf-lsm"
// label, so the coordinator and operator-visible logs can tell them apart.
func (l *AFAlgAuditListener) Mode() string {
	return "auditd-tail"
}
// EventCount reports how many csm_af_alg_socket events this listener has
// parsed since it was created. It is an operational counter read atomically;
// the listener does not export it to Prometheus itself.
func (l *AFAlgAuditListener) EventCount() uint64 {
	return l.eventCount.Load()
}
// NewAFAlgAuditListener builds a listener for the file named by auditLogPath
// and opens it: the log is positioned at its current end and an inotify watch
// is registered for IN_MODIFY (bytes appended), IN_MOVE_SELF (logrotate moved
// the file) and IN_DELETE_SELF (file unlinked), so a rotated or replaced file
// can be re-opened later.
//
// An error is returned when the log cannot be opened or watched (for example
// because it does not exist). The caller is expected to log a warning and
// either run without live detection (the periodic check still applies) or
// retry later.
func NewAFAlgAuditListener(alertCh chan<- alert.Finding, cfg *config.Config) (*AFAlgAuditListener, error) {
	listener := &AFAlgAuditListener{
		path:    auditLogPath,
		cfg:     cfg,
		alertCh: alertCh,
	}
	err := listener.open()
	if err != nil {
		return nil, err
	}
	return listener, nil
}
// open sets up (or, after a rotation, replaces) the audit log handle and its
// inotify watch. The new handle is positioned at end-of-file so only records
// written from now on are seen. The listener's fields are updated only after
// every step has succeeded; on failure the resources acquired so far are
// released and the previous handle and watch are left in place.
func (l *AFAlgAuditListener) open() error {
	// #nosec G304 -- l.path is /var/log/audit/audit.log (or t.TempDir()
	// equivalent); not user-controlled.
	logFile, err := os.Open(l.path)
	if err != nil {
		return fmt.Errorf("open %s: %w", l.path, err)
	}

	eof, err := logFile.Seek(0, 2) // SEEK_END
	if err != nil {
		_ = logFile.Close()
		return fmt.Errorf("seek %s: %w", l.path, err)
	}

	watchFd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
	if err != nil {
		_ = logFile.Close()
		return fmt.Errorf("inotify_init: %w", err)
	}

	events := uint32(unix.IN_MODIFY | unix.IN_MOVE_SELF | unix.IN_DELETE_SELF)
	if _, err := unix.InotifyAddWatch(watchFd, l.path, events); err != nil {
		_ = unix.Close(watchFd)
		_ = logFile.Close()
		return fmt.Errorf("inotify_add_watch %s: %w", l.path, err)
	}

	// Everything succeeded: retire the previous handle and watch, if any
	// (rotation case), and start from a clean parsing state.
	if previous := l.file; previous != nil {
		_ = previous.Close()
	}
	if previousFd := l.inotifyFd; previousFd != 0 {
		_ = unix.Close(previousFd)
	}
	l.file, l.pos = logFile, eof
	l.inotifyFd = watchFd
	l.leftover, l.droppedOversize = nil, false
	return nil
}
// Run drains inotify events and tails the audit log until ctx is done.
// Polls the inotify fd every poll interval (matches forwarder_watcher's
// approach — no epoll, easier to reason about). On return it closes the
// inotify descriptor and the log file.
//
// On rotation (IN_MOVE_SELF / IN_DELETE_SELF) the listener re-opens the
// new audit.log and continues from its end. If the re-open fails — typically
// because the replacement file has not been created yet — the rotation stays
// pending and the re-open is retried on every following tick until it
// succeeds; the failure is logged once per rotation, not once per tick.
// There is a brief window during rotation where events written between the
// move and the re-open can be missed; the periodic critical-tier check covers
// that gap via its persistent cursor.
func (l *AFAlgAuditListener) Run(ctx context.Context) {
	defer func() {
		if l.inotifyFd != 0 {
			_ = unix.Close(l.inotifyFd)
		}
		if l.file != nil {
			_ = l.file.Close()
		}
	}()

	inotifyBuf := make([]byte, 4096)
	readBuf := make([]byte, 16*1024)

	// 500 ms tick gives sub-second average detection latency. The cost
	// is ~2 cheap syscalls/sec (inotify Read returns EAGAIN immediately
	// when nothing's queued, file Seek+Read returns EOF cheaply when
	// no new bytes). Cheap enough to not bother with epoll/select for
	// a v1 — we can refactor to true event-driven if we ever need
	// hundred-microsecond latency.
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	// Start by draining anything the kernel may have queued during
	// startup.
	l.tail(readBuf)

	// A rotation event is delivered only once, so it has to be remembered
	// across ticks: without this a single failed re-open would leave the
	// listener attached to the rotated-away file for good.
	reopenPending := false
	reopenWarned := false

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			rotated, ok := l.drainInotify(inotifyBuf)
			if rotated {
				// Record the rotation even if this tick is skipped below.
				reopenPending = true
			}
			if !ok {
				continue
			}
			if reopenPending {
				if err := l.open(); err != nil {
					if !reopenWarned {
						csmlog.Warn("af_alg audit listener: re-open failed", "err", err)
						reopenWarned = true
					}
					continue
				}
				reopenPending, reopenWarned = false, false
			}
			l.tail(readBuf)
		}
	}
}
// drainInotify reads any queued inotify events. Returns:
//
// ok=false on read errors that aren't EAGAIN/EINTR (the listener
// should skip this tick entirely; the next tick re-tries).
// rotated=true if any IN_MOVE_SELF or IN_DELETE_SELF event was seen.
//
// IN_MODIFY events implicitly trigger a tail() in the caller because we
// always read at the end of each tick on success.
//
// buf is scratch space for the raw event stream; its contents are
// overwritten on every read.
func (l *AFAlgAuditListener) drainInotify(buf []byte) (rotated, ok bool) {
for {
n, err := unix.Read(l.inotifyFd, buf)
if err != nil {
// EAGAIN: no more events queued — the normal terminator.
// EINTR: interrupted by signal, also benign — retry next tick.
if errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EWOULDBLOCK) || errors.Is(err, syscall.EINTR) {
return rotated, true
}
// Anything else (EBADF after a stray Close, EIO, EFAULT)
// signals the inotify fd is no longer healthy. Surface it
// so the operator notices, and tell the caller to skip the
// tail() this tick — the read is attempted again on the
// caller's next tick.
csmlog.Warn("af_alg audit listener: inotify read error", "err", err)
return rotated, false
}
if n <= 0 {
return rotated, true
}
// Walk the packed records: each is a fixed-size header followed by
// ev.Len bytes of name data, which is skipped here.
offset := 0
for offset+unix.SizeofInotifyEvent <= n {
// #nosec G103 -- inotify packed binary stream;
// reinterpretation is required and the bound check above
// guarantees the read is in-range.
ev := (*unix.InotifyEvent)(unsafe.Pointer(&buf[offset]))
if ev.Mask&(unix.IN_MOVE_SELF|unix.IN_DELETE_SELF) != 0 {
rotated = true
}
offset += unix.SizeofInotifyEvent + int(ev.Len)
}
}
}
// tail consumes everything appended to the log since l.pos. It seeks to
// l.pos, then reads into buf until a read reports an error (end-of-file being
// the usual one), handing each chunk to feed and advancing l.pos by the bytes
// read. A failed seek is logged and ends the call without reading.
func (l *AFAlgAuditListener) tail(buf []byte) {
	if _, err := l.file.Seek(l.pos, 0); err != nil {
		csmlog.Warn("af_alg audit listener: seek failed", "err", err)
		return
	}
	// Any read error (io.EOF or EAGAIN in practice) means "out of data for
	// this tick"; bytes returned alongside it are still processed first.
	var readErr error
	for readErr == nil {
		var n int
		n, readErr = l.file.Read(buf)
		if n <= 0 {
			continue
		}
		l.feed(buf[:n])
		l.pos += int64(n)
	}
}
// feed appends chunk to the partial-line buffer and passes every complete,
// newline-terminated line to handleLine. Bytes after the last newline stay in
// l.leftover until a later chunk completes the line, so a record split across
// two reads is still parsed whole.
//
// If the unterminated remainder grows past afAlgMaxLeftoverBytes it is
// discarded and the listener enters a drop state; the next newline then ends
// the oversized record, whose tail is skipped rather than parsed.
func (l *AFAlgAuditListener) feed(chunk []byte) {
	l.leftover = append(l.leftover, chunk...)

	for {
		nl := bytes.IndexByte(l.leftover, '\n')
		if nl < 0 {
			break
		}
		record := l.leftover[:nl]
		l.leftover = l.leftover[nl+1:]

		// This "line" is only the remainder of a record whose head was
		// already thrown away for exceeding the cap: not parseable.
		if l.droppedOversize {
			l.droppedOversize = false
			continue
		}
		l.handleLine(string(record))
	}

	// No complete line is left. A real audit record fits well under the cap;
	// an unterminated buffer past it is garbage (truncated write, binary
	// noise, or an attacker-stretched exe= path). Drop it so the accumulator
	// cannot grow without bound, and resync at the next newline.
	if len(l.leftover) > afAlgMaxLeftoverBytes {
		l.leftover = nil
		l.droppedOversize = true
	}
}
// handleLine processes one audit log line. Lines that
// checks.ParseAFAlgEventLine does not recognise are ignored. For a recognised
// event it bumps the event counter, offers a Critical finding to the alert
// channel without blocking, and then runs the configured reactions.
func (l *AFAlgAuditListener) handleLine(line string) {
	ev, matched := checks.ParseAFAlgEventLine(line)
	if !matched {
		return
	}
	l.eventCount.Add(1)

	message := fmt.Sprintf("AF_ALG socket opened by uid=%s exe=%s", ev.UID, ev.Exe)
	details := fmt.Sprintf(
		"Live audit-log detection: timestamp=%s serial=%s\nauid=%s uid=%s comm=%q exe=%q pid=%s\n"+
			"AF_ALG is essentially never used by cPanel/PHP workloads. This is\n"+
			"the kernel-level exploit signature for CVE-2026-31431 (\"Copy Fail\").\n"+
			"This event was caught by the live audit-log listener (sub-second\n"+
			"latency); investigate this process immediately.",
		ev.Timestamp, ev.Serial, ev.AUID, ev.UID, ev.Comm, ev.Exe, ev.PID,
	)
	finding := alert.Finding{
		Severity:  alert.Critical,
		Check:     "af_alg_socket_use",
		Message:   message,
		Timestamp: time.Now(),
		Details:   details,
	}

	// Non-blocking send — the alert dispatcher buffer is sized for bursts;
	// dropping a finding under extreme pressure is preferable to blocking
	// the listener loop.
	select {
	case l.alertCh <- finding:
	default:
		csmlog.Warn("af_alg audit listener: alert channel full; finding dropped", "uid", ev.UID, "exe", ev.Exe)
	}

	// Optional reactions (kill, quarantine) gated by config; implemented
	// in af_alg_react.go so the BPF path can reuse the same logic. They run
	// whether or not the finding was delivered.
	reactToAFAlgEvent(l.cfg, ev)
}
//go:build !(linux && bpf)
package daemon
import (
"context"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
)
// tryStartBPFLSM is the no-tag fallback. The real implementation lives in
// af_alg_bpf.go behind //go:build linux && bpf.
//
// This stub ignores all of its arguments and always returns a nil monitor
// together with bpf.ErrNotBuilt, so callers treat BPF as unavailable.
func tryStartBPFLSM(_ context.Context, _ chan<- alert.Finding, _ *config.Config) (AFAlgLiveMonitor, error) {
return nil, bpf.ErrNotBuilt
}
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64
package af_alg_bpfprog
import (
"bytes"
_ "embed"
"fmt"
"io"
"structs"
"github.com/cilium/ebpf"
)
// AFAlgAfAlgEvent is the generated Go mirror of the af_alg event record.
// The blank structs.HostLayout field marks the struct as using the host's
// memory layout. Field meanings below are inferred from their names --
// confirm against the BPF C source.
type AFAlgAfAlgEvent struct {
_ structs.HostLayout
// Uid, Pid and Ppid are presumably the user, process and parent process IDs.
Uid uint32
Pid uint32
Ppid uint32
// Comm and ParentComm are fixed 16-byte name buffers.
Comm [16]uint8
ParentComm [16]uint8
// Exe is a fixed 256-byte buffer, presumably the executable path.
Exe [256]uint8
}
// LoadAFAlg returns the embedded CollectionSpec for AFAlg.
//
// The spec is parsed from the embedded object bytes (_AFAlgBytes) on every
// call; a parse failure is returned wrapped.
func LoadAFAlg() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_AFAlgBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load AFAlg: %w", err)
}
// err is nil on this path.
return spec, err
}
// LoadAFAlgObjects loads AFAlg and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *AFAlgObjects
// *AFAlgPrograms
// *AFAlgMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
//
// Errors from LoadAFAlg are returned unchanged; otherwise the result of
// LoadAndAssign is returned.
func LoadAFAlgObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := LoadAFAlg()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// AFAlgSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type AFAlgSpecs struct {
AFAlgProgramSpecs
AFAlgMapSpecs
AFAlgVariableSpecs
}
// AFAlgProgramSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type AFAlgProgramSpecs struct {
// CsmBlockAfAlg is bound by struct tag to the program named "csm_block_af_alg".
CsmBlockAfAlg *ebpf.ProgramSpec `ebpf:"csm_block_af_alg"`
}
// AFAlgMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type AFAlgMapSpecs struct {
// Events is bound by struct tag to the map named "events".
Events *ebpf.MapSpec `ebpf:"events"`
}
// AFAlgVariableSpecs contains global variables before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
type AFAlgVariableSpecs struct {
// Unused is bound by struct tag to the global variable named "unused".
Unused *ebpf.VariableSpec `ebpf:"unused"`
}
// AFAlgObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to LoadAFAlgObjects or ebpf.CollectionSpec.LoadAndAssign.
type AFAlgObjects struct {
AFAlgPrograms
AFAlgMaps
AFAlgVariables
}
// Close closes the embedded programs and then the embedded maps, returning
// the first error encountered. The embedded variables are not closed here.
func (o *AFAlgObjects) Close() error {
return _AFAlgClose(
&o.AFAlgPrograms,
&o.AFAlgMaps,
)
}
// AFAlgMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to LoadAFAlgObjects or ebpf.CollectionSpec.LoadAndAssign.
type AFAlgMaps struct {
// Events is bound by struct tag to the map named "events".
Events *ebpf.Map `ebpf:"events"`
}
// Close closes the Events map and returns its error, if any.
func (m *AFAlgMaps) Close() error {
return _AFAlgClose(
m.Events,
)
}
// AFAlgVariables contains all global variables after they have been loaded into the kernel.
//
// It can be passed to LoadAFAlgObjects or ebpf.CollectionSpec.LoadAndAssign.
type AFAlgVariables struct {
// Unused is bound by struct tag to the global variable named "unused".
Unused *ebpf.Variable `ebpf:"unused"`
}
// AFAlgPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to LoadAFAlgObjects or ebpf.CollectionSpec.LoadAndAssign.
type AFAlgPrograms struct {
// CsmBlockAfAlg is bound by struct tag to the program named "csm_block_af_alg".
CsmBlockAfAlg *ebpf.Program `ebpf:"csm_block_af_alg"`
}
// Close closes the CsmBlockAfAlg program and returns its error, if any.
func (p *AFAlgPrograms) Close() error {
return _AFAlgClose(
p.CsmBlockAfAlg,
)
}
// _AFAlgClose closes the given closers in order and stops at the first error,
// which it returns; closers after a failing one are not closed. It returns
// nil when every Close succeeds.
func _AFAlgClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// _AFAlgBytes holds the contents of afalg_x86_bpfel.o, embedded at build time;
// LoadAFAlg parses it into a CollectionSpec.
//
// Do not access this directly.
//
//go:embed afalg_x86_bpfel.o
var _AFAlgBytes []byte
package daemon
import (
"context"
"fmt"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/metrics"
)
// AFAlgLiveMonitor was the local name for the live-monitor interface;
// it is now an alias of the shared bpf.Backend. Existing call sites in
// reactToAFAlgEvent etc. continue to compile.
type AFAlgLiveMonitor = bpf.Backend
// AFAlgBackend* are the operator-facing cfg.Detection.AFAlgBackend values.
// Reuse bpf constants where the public value already matches.
// StartAFAlgLiveMonitor compares the lower-cased, trimmed config value
// against these.
const (
AFAlgBackendAuto = bpf.BackendAuto
AFAlgBackendBPF = bpf.BackendBPF
AFAlgBackendAuditd = "auditd"
AFAlgBackendNone = bpf.BackendNone
)
// afAlgBackendMetric is the csm_af_alg_backend gauge vector. It is created
// and registered at most once, guarded by afAlgBackendMetricOnce, the first
// time ensureAFAlgBackendMetric runs.
var (
afAlgBackendMetricOnce sync.Once
afAlgBackendMetric *metrics.GaugeVec
)
func ensureAFAlgBackendMetric() {
afAlgBackendMetricOnce.Do(func() {
afAlgBackendMetric = metrics.NewGaugeVec(
"csm_af_alg_backend",
"Active AF_ALG (Copy Fail) live-monitor backend; 1 for the selected kind, 0 otherwise.",
[]string{"kind"},
)
metrics.MustRegister("csm_af_alg_backend", afAlgBackendMetric)
})
}
// setAFAlgBackendMetric publishes which live-monitor backend is active.
// The csm_af_alg_backend gauge is set to 1 for the kind equal to active and
// to 0 for the other known kinds ("bpf-lsm", "auditd-tail", "none"), and the
// shared bpf registry entry for "af_alg" is updated to match. Any value of
// active other than "bpf-lsm" or "auditd-tail" is recorded as no backend.
func setAFAlgBackendMetric(active string) {
	ensureAFAlgBackendMetric()

	for _, kind := range [...]string{"bpf-lsm", "auditd-tail", "none"} {
		gauge := afAlgBackendMetric.With(kind)
		if kind == active {
			gauge.Set(1.0)
		} else {
			gauge.Set(0.0)
		}
	}

	switch active {
	case "auditd-tail":
		bpf.SetActive("af_alg", bpf.BackendLegacy)
	case "bpf-lsm":
		bpf.SetActive("af_alg", bpf.BackendBPF)
	default:
		bpf.SetActive("af_alg", bpf.BackendNone)
	}
}
// StartAFAlgLiveMonitor starts the live monitor selected by
// cfg.Detection.AFAlgBackend (compared case-insensitively, surrounding
// whitespace ignored) and returns it, or nil when no backend ends up active.
//
//   - "" / "auto": try BPF LSM first (when compiled in and kernel-supported)
//     and fall back to the audit listener.
//   - "bpf": require BPF — no audit fallback if BPF is unavailable, useful for
//     hosts where the operator deliberately wants the kernel-side block or
//     nothing.
//   - "auditd": pin the audit listener even on BPF-capable hosts, the kill
//     switch when a BPF-tagged release misbehaves.
//   - "none": disable the live monitor (the periodic critical-tier check
//     still runs).
//
// An unrecognised value is logged and treated as "auto". The metric
// csm_af_alg_backend{kind=...} reflects whichever path was selected, and a
// BPF-unavailable finding is emitted whenever the BPF attempt returned an
// error.
func StartAFAlgLiveMonitor(alertCh chan<- alert.Finding, cfg *config.Config) AFAlgLiveMonitor {
	choice := strings.ToLower(strings.TrimSpace(cfg.Detection.AFAlgBackend))
	if choice == "" {
		choice = AFAlgBackendAuto
	}
	switch choice {
	case AFAlgBackendAuto, AFAlgBackendBPF, AFAlgBackendAuditd, AFAlgBackendNone:
		// recognised value; nothing to normalise
	default:
		csmlog.Warn("af_alg live monitor: unknown backend choice, falling back to auto",
			"value", choice,
		)
		choice = AFAlgBackendAuto
	}

	if choice == AFAlgBackendNone {
		csmlog.Info("af_alg live monitor: disabled by config")
		setAFAlgBackendMetric("none")
		return nil
	}

	// bpfErr remembers why BPF was unavailable so the audit path below can
	// still report it.
	var bpfErr error
	if choice == AFAlgBackendAuto || choice == AFAlgBackendBPF {
		mon, err := tryStartBPFLSMFn(context.Background(), alertCh, cfg)
		switch {
		case err != nil:
			bpfErr = err
			csmlog.Info("af_alg live monitor: BPF LSM unavailable",
				"state", "bpf-lsm-unsupported",
				"reason", err.Error(),
				"choice", choice,
			)
			if choice == AFAlgBackendBPF {
				csmlog.Warn("af_alg live monitor: af_alg_backend=bpf but BPF unavailable; no live detection",
					"reason", err.Error(),
				)
				setAFAlgBackendMetric("none")
				emitBPFUnavailableFinding(alertCh, "af_alg", choice, "", err)
				return nil
			}
		case mon != nil:
			csmlog.Info("af_alg live monitor", "backend", "bpf-lsm", "choice", choice)
			setAFAlgBackendMetric("bpf-lsm")
			return mon
		}
	}

	listener, err := NewAFAlgAuditListener(alertCh, cfg)
	if err != nil {
		csmlog.Warn("af_alg live monitor: auditd fallback unavailable", "err", err)
		setAFAlgBackendMetric("none")
		if bpfErr != nil {
			emitBPFUnavailableFinding(alertCh, "af_alg", choice, "", fmt.Errorf("BPF unavailable: %w; audit fallback unavailable: %w", bpfErr, err))
		}
		return nil
	}

	csmlog.Info("af_alg live monitor", "backend", "auditd-tail", "choice", choice)
	setAFAlgBackendMetric("auditd-tail")
	if bpfErr != nil {
		emitBPFUnavailableFinding(alertCh, "af_alg", choice, "auditd-tail", bpfErr)
	}
	return listener
}
// tryStartBPFLSMFn is a package-level indirection so tests can substitute a
// fake for the BPF probe path without needing the bpf build tag or kernel
// privileges. Production code goes through tryStartBPFLSM, which is the
// stub on default builds and the real probe on -tags bpf builds.
// StartAFAlgLiveMonitor always calls the probe through this variable, never
// tryStartBPFLSM directly.
var tryStartBPFLSMFn = tryStartBPFLSM
//go:build linux
package daemon
import (
"strconv"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
)
// reactToAFAlgEvent performs the optional live reaction to an AF_ALG socket
// open reported by either the audit-log listener or the BPF LSM hook. The
// only reaction implemented is sending SIGKILL to the reported PID, and it
// runs only when cfg is non-nil and cfg.AutoResponse.CopyFailKillProcess is
// set.
//
// Alerting is not done here: the listener emits its own critical alert, and
// this function only covers responses beyond that. The surface is kept
// deliberately small until the kill path has been observed in production;
// quarantining the offending executable is a possible later addition.
//
// A PID that does not parse as an integer, or that is <= 1, is refused and
// logged, so a parser mistake can never turn into a kill of PID 0/1.
func reactToAFAlgEvent(cfg *config.Config, ev checks.AFAlgEvent) {
	killEnabled := cfg != nil && cfg.AutoResponse.CopyFailKillProcess
	if !killEnabled {
		return
	}

	target, parseErr := strconv.Atoi(ev.PID)
	if parseErr != nil || target <= 1 {
		csmlog.Warn("af_alg react: refusing to kill",
			"reason", "implausible pid",
			"pid", ev.PID, "exe", ev.Exe, "uid", ev.UID,
		)
		return
	}

	killErr := unix.Kill(target, unix.SIGKILL)
	if killErr == nil {
		csmlog.Info("af_alg react: killed offending process",
			"pid", target, "exe", ev.Exe, "uid", ev.UID, "comm", ev.Comm,
		)
		return
	}
	csmlog.Warn("af_alg react: kill failed",
		"pid", target, "exe", ev.Exe, "uid", ev.UID,
		"err", killErr,
	)
}
package daemon
// pkgManagerComms is the set of process names (comm values) that CSM
// accepts as evidence that a sensitive-file write came from a legitimate
// root-driven package transaction. Shells (sh, bash) and generic utilities
// (cp, mv) are left out on purpose because attackers reuse them; the
// discriminator is the package manager binary itself appearing somewhere
// in the parent chain.
var pkgManagerComms = map[string]struct{}{
	// RPM family.
	"dnf":      {},
	"dnf-3":    {},
	"microdnf": {},
	"yum":      {},
	"rpm":      {},
	// Debian family.
	"dpkg":    {},
	"apt":     {},
	"apt-get": {},
	// unattended-upgrade is comm-truncated to TASK_COMM_LEN-1.
	"unattended-upgr": {},
}

// isPackageManagerComm reports whether comm is an exact, case-sensitive
// member of pkgManagerComms.
func isPackageManagerComm(comm string) bool {
	if _, known := pkgManagerComms[comm]; known {
		return true
	}
	return false
}
//go:build !linux || !bpf
package daemon
import "github.com/pidginhost/csm/internal/processctx"
// wireAncestryProbeIfAvailable is a no-op on hosts built without the bpf
// build tag. checks.AncestryProbe stays nil; rescoreSensitive falls back to
// the package-manager log mtime signal alone.
//
// The parameter is left unnamed because this variant never reads it; it
// exists only so the signature accepts a *processctx.Cache.
func wireAncestryProbeIfAvailable(*processctx.Cache) {}
package daemon
import "regexp"
// atomicWriteStageRE recognises staging filenames of the form
// `.temp.<digits>.<rest>`, as produced by cPanel's fileTransfer service and
// similar atomic-write helpers. The digit run is a nanosecond timestamp and
// <rest> is the basename the file will be rename(2)d to. The pattern is
// anchored at the start only and requires at least one character after the
// second dot.
var atomicWriteStageRE = regexp.MustCompile(`^\.temp\.\d+\..+`)

// looksLikeAtomicWriteStage reports whether name follows the atomic-write
// staging convention `.temp.<digits>.<name>`. name must be a basename; a
// full path will not match because the pattern is anchored at the start.
func looksLikeAtomicWriteStage(name string) bool {
	staged := atomicWriteStageRE.MatchString(name)
	return staged
}
package daemon
// PolicyMapPayload is the wire-shape of struct policy_state in the BPF
// program. Mirrors the C layout exactly: three uint32 fields, no
// alignment surprises (all fields are 4 bytes).
//
// Field order is part of the wire contract; do not reorder.
type PolicyMapPayload struct {
// Enforce is copied from BPFEnforcementPolicy.Enforce by policyMapPayload.
Enforce uint32
// DryRun is copied from BPFEnforcementPolicy.DryRun by policyMapPayload.
DryRun uint32
// ProtectedPorts is the number of entries in BPFEnforcementPolicy.Ports
// (policyMapPayload stores len(pol.Ports) here, not the ports themselves).
ProtectedPorts uint32
}
// policyMapPayload translates a userspace BPFEnforcementPolicy into the
// fixed-layout PolicyMapPayload. Enforce and DryRun are copied verbatim;
// ProtectedPorts carries only the number of ports. It is a pure function,
// so tests can exercise it without loading a BPF program.
func policyMapPayload(pol BPFEnforcementPolicy) PolicyMapPayload {
	var wire PolicyMapPayload
	wire.Enforce = pol.Enforce
	wire.DryRun = pol.DryRun
	// #nosec G115 -- bounded by protected_ports BPF map max_entries=16
	wire.ProtectedPorts = uint32(len(pol.Ports))
	return wire
}
package daemon
import (
"sync"
"github.com/pidginhost/csm/internal/metrics"
)
// Decision label values match the BPF DECISION_* numeric codes:
// allow (0), dry_run (1), deny (2). Wire-stable strings; SIEM
// dashboards pin on them.
const (
BPFDecisionAllow = "allow"
BPFDecisionDryRun = "dry_run"
BPFDecisionDeny = "deny"
)
// Handles for the BPF enforcement metrics. The three pointers are nil until
// RegisterBPFEnforcementMetrics assigns them. bpfEnfMu is held by
// RegisterBPFEnforcementMetrics, the Bump* helpers and
// resetBPFEnforcementMetricsForTest whenever they read or write the
// pointers.
var (
bpfEnfMu sync.Mutex
bpfEnfDecisionsVec *metrics.CounterVec
bpfEnfUIDRefreshTotal *metrics.Counter
bpfEnfUIDRefreshFailures *metrics.Counter
)
// RegisterBPFEnforcementMetrics creates the three BPF enforcement counters
// (decisions by label, UID-map refresh successes, UID-map refresh failures)
// and registers each on reg. Production passes metrics.Default(); tests pass
// metrics.NewRegistry() so registrations stay isolated. The package-level
// handles are assigned under bpfEnfMu.
func RegisterBPFEnforcementMetrics(reg *metrics.Registry) {
	const (
		decisionsName       = "csm_bpf_enforcement_decisions_total"
		refreshName         = "csm_bpf_enforcement_uid_map_refresh_total"
		refreshFailuresName = "csm_bpf_enforcement_uid_map_refresh_failures_total"
	)

	bpfEnfMu.Lock()
	defer bpfEnfMu.Unlock()

	// Each handle is assigned before its registration, one metric at a time.
	bpfEnfDecisionsVec = metrics.NewCounterVec(
		decisionsName,
		"BPF cgroup-deny decisions by label (allow/dry_run/deny).",
		[]string{"decision"},
	)
	reg.MustRegister(decisionsName, bpfEnfDecisionsVec)

	bpfEnfUIDRefreshTotal = metrics.NewCounter(
		refreshName,
		"BPF safe-UID map refresh successes.",
	)
	reg.MustRegister(refreshName, bpfEnfUIDRefreshTotal)

	bpfEnfUIDRefreshFailures = metrics.NewCounter(
		refreshFailuresName,
		"BPF safe-UID map refresh failures.",
	)
	reg.MustRegister(refreshFailuresName, bpfEnfUIDRefreshFailures)
}
// BumpBPFEnforcementDecision increments the decision counter for label.
// The connection consumer calls it when a ConnectionEvent carrying a
// decision field arrives. It does nothing when the metrics have not been
// registered, and silently ignores any label other than the three
// BPFDecision* constants (that would be a caller bug).
func BumpBPFEnforcementDecision(label string) {
	bpfEnfMu.Lock()
	vec := bpfEnfDecisionsVec
	bpfEnfMu.Unlock()

	recognised := label == BPFDecisionAllow ||
		label == BPFDecisionDryRun ||
		label == BPFDecisionDeny
	if vec == nil || !recognised {
		return
	}
	vec.With(label).Inc()
}
// BumpUIDRefresh increments the periodic-refresh success counter. It is a
// no-op until RegisterBPFEnforcementMetrics has run.
func BumpUIDRefresh() {
	bpfEnfMu.Lock()
	counter := bpfEnfUIDRefreshTotal
	bpfEnfMu.Unlock()

	if counter == nil {
		return
	}
	counter.Inc()
}
// BumpUIDRefreshFailure increments the periodic-refresh failure counter. It
// is a no-op until RegisterBPFEnforcementMetrics has run.
func BumpUIDRefreshFailure() {
	bpfEnfMu.Lock()
	counter := bpfEnfUIDRefreshFailures
	bpfEnfMu.Unlock()

	if counter == nil {
		return
	}
	counter.Inc()
}
// resetBPFEnforcementMetricsForTest is a test seam: it drops all three
// metric handles back to nil under bpfEnfMu, so the Bump* helpers become
// no-ops until the metrics are registered again.
func resetBPFEnforcementMetricsForTest() {
	bpfEnfMu.Lock()
	bpfEnfDecisionsVec, bpfEnfUIDRefreshTotal, bpfEnfUIDRefreshFailures = nil, nil, nil
	bpfEnfMu.Unlock()
}
package daemon
import (
"bufio"
"os"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
)
// BPFEnforcementPolicy is the userspace-derived state to be loaded into
// the BPF policy + protected_ports + safe_uids maps. Pure data; no IO.
// The zero value means "enforcement off".
type BPFEnforcementPolicy struct {
// Enforce is 1 when BuildBPFEnforcementPolicy produced an active policy,
// 0 otherwise.
Enforce uint32
// DryRun is set to 1 by BuildBPFEnforcementPolicy when
// cfg.BPFEnforcementDryRunEnabled() reports true.
DryRun uint32
// Ports holds the configured direct-SMTP-egress ports that passed the
// 1..65535 range check in BuildBPFEnforcementPolicy.
Ports []uint16
}
// BuildBPFEnforcementPolicy derives the policy struct consumed by the BPF
// program from cfg. It returns the zero policy (Enforce=0, which makes the
// in-kernel program short-circuit to allow) when cfg is nil, enforcement is
// disabled, the connection-tracker backend setting rules out BPF, or no
// valid protected port ends up configured. Ports are taken from
// cfg.Detection.DirectSMTPEgress.Ports only when direct-SMTP-egress
// enforcement is switched on and its "bpf" backend is enabled; values
// outside 1..65535 are skipped.
func BuildBPFEnforcementPolicy(cfg *config.Config) BPFEnforcementPolicy {
	if cfg == nil || !cfg.BPFEnforcement.Enabled || !connectionTrackerAllowsBPF(cfg) {
		return BPFEnforcementPolicy{}
	}

	policy := BPFEnforcementPolicy{Enforce: 1}
	if cfg.BPFEnforcementDryRunEnabled() {
		policy.DryRun = 1
	}

	protectSMTP := cfg.BPFEnforcement.DirectSMTPEgress && checks.DirectSMTPEgressBackendEnabled(cfg, "bpf")
	if protectSMTP {
		for _, candidate := range cfg.Detection.DirectSMTPEgress.Ports {
			if candidate <= 0 || candidate > 65535 {
				continue
			}
			policy.Ports = append(policy.Ports, uint16(candidate))
		}
	}

	// Nothing to protect means nothing to enforce.
	if len(policy.Ports) == 0 {
		return BPFEnforcementPolicy{}
	}
	return policy
}
// connectionTrackerAllowsBPF reports whether the configured
// connection-tracker backend permits the BPF path. After trimming and
// lower-casing, only the empty string, "auto" and "bpf" qualify.
func connectionTrackerAllowsBPF(cfg *config.Config) bool {
	backend := strings.ToLower(strings.TrimSpace(cfg.Detection.ConnectionTrackerBackend))
	return backend == "" || backend == "auto" || backend == "bpf"
}
// safeUIDsFromPasswd reads a passwd-format file and returns the set of UIDs
// that are exempt from the in-kernel deny path:
//
//   - every UID below 1000 (root and system/service accounts), and
//   - any UID at or above 1000 whose user name is a platform-known MTA
//     identity, for distros that place MTA users in that range.
//
// Hosted account UIDs (>=1000) are otherwise left out, so their connections
// are evaluated by the deny path when enforcement is active.
//
// Lines with fewer than four colon-separated fields, or whose UID field is
// not a 32-bit unsigned integer, are skipped. An open failure returns
// (nil, err); a scan failure returns the entries collected so far together
// with the scanner error. Production callers pass /etc/passwd.
func safeUIDsFromPasswd(path string) (map[uint32]bool, error) {
	passwd, err := os.Open(path) // #nosec G304 -- caller-controlled; production passes /etc/passwd
	if err != nil {
		return nil, err
	}
	defer passwd.Close()

	mtaUsers := platform.LocalMTAIdentities(platform.Detect())
	safe := make(map[uint32]bool)

	lines := bufio.NewScanner(passwd)
	for lines.Scan() {
		// Only name (field 0) and uid (field 2) are needed; the remainder
		// of the line stays unsplit in the last element.
		fields := strings.SplitN(lines.Text(), ":", 4)
		if len(fields) < 4 {
			continue
		}
		parsed, parseErr := strconv.ParseUint(fields[2], 10, 32)
		if parseErr != nil {
			continue
		}
		uid := uint32(parsed)
		// The MTA lookup is consulted only for UIDs >= 1000.
		if uid < 1000 || mtaUsers.IsMTAUser(fields[0]) {
			safe[uid] = true
		}
	}
	return safe, lines.Err()
}
package daemon
import (
"sync"
"sync/atomic"
"time"
)
// UIDRefresherConfig drives a UIDRefresher.
type UIDRefresherConfig struct {
	// Interval is the period between Refresh calls; production uses
	// 5 minutes. A zero or negative value falls back to that 5-minute
	// period instead of panicking inside the refresh goroutine.
	Interval time.Duration
	// Refresh is called on every tick; production wiring re-reads
	// /etc/passwd and repopulates the BPF safe_uids map. A non-nil error
	// counts as a failure, a nil error as a successful refresh.
	Refresh func() error
}

// UIDRefresherStats is a counter snapshot.
type UIDRefresherStats struct {
	Refreshes uint64 // Refresh calls that returned nil
	Failures  uint64 // Refresh calls that returned an error
}

// UIDRefresher runs the configured Refresh on a fixed interval. Start and
// Stop are each idempotent.
type UIDRefresher struct {
	cfg       UIDRefresherConfig
	stop      chan struct{} // closed by Stop to end the loop
	wg        sync.WaitGroup
	started   atomic.Bool
	stopped   atomic.Bool
	refreshes atomic.Uint64
	failures  atomic.Uint64
}

// NewUIDRefresher returns a stopped refresher. Call Start to launch.
func NewUIDRefresher(cfg UIDRefresherConfig) *UIDRefresher {
	return &UIDRefresher{cfg: cfg, stop: make(chan struct{})}
}

// Start launches the refresh goroutine. Only the first call has an effect.
func (r *UIDRefresher) Start() {
	if r.started.Swap(true) {
		return
	}
	r.wg.Add(1)
	go r.loop()
}

// Stop signals the goroutine and waits for it to exit. Only the first call
// closes the stop channel; later calls return immediately.
func (r *UIDRefresher) Stop() {
	if r.stopped.Swap(true) {
		return
	}
	close(r.stop)
	r.wg.Wait()
}

// Stats returns a counter snapshot.
func (r *UIDRefresher) Stats() UIDRefresherStats {
	return UIDRefresherStats{
		Refreshes: r.refreshes.Load(),
		Failures:  r.failures.Load(),
	}
}

// loop ticks until the stop channel is closed, calling Refresh once per
// tick and recording the outcome in the success/failure counters.
func (r *UIDRefresher) loop() {
	defer r.wg.Done()

	// time.NewTicker panics on a non-positive duration, and a panic in
	// this goroutine would take the whole process down. Use the
	// documented production period when the caller left Interval unset
	// or misconfigured.
	const fallbackInterval = 5 * time.Minute
	interval := r.cfg.Interval
	if interval <= 0 {
		interval = fallbackInterval
	}

	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-r.stop:
			return
		case <-t.C:
			if err := r.cfg.Refresh(); err != nil {
				r.failures.Add(1)
				continue
			}
			r.refreshes.Add(1)
		}
	}
}
package daemon
import (
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
)
// emitBPFUnavailableFinding posts an operator-visible "bpf_unavailable"
// finding when a BPF-backed live monitor could not start its
// kernel-attached path. fallback names the non-BPF backend that was
// selected instead; an empty fallback means no live fallback is active.
//
// Severity is Warning, raised to High when the operator explicitly chose
// the BPF backend or when nothing is running as a fallback. The send is
// non-blocking: the function returns true if the finding was queued and
// false if alertCh is nil or not ready to receive.
func emitBPFUnavailableFinding(alertCh chan<- alert.Finding, feature, choice, fallback string, err error) bool {
	if alertCh == nil {
		return false
	}

	finding := alert.Finding{
		Severity: alert.Warning,
		Check:    "bpf_unavailable",
	}
	if fallback == "" || choice == bpf.BackendBPF {
		finding.Severity = alert.High
	}

	var status string
	if fallback == "" {
		status = "no live fallback active"
	} else {
		status = fmt.Sprintf("running on %s fallback", fallback)
	}

	if err != nil {
		finding.Details = err.Error()
	}
	finding.Message = fmt.Sprintf("BPF backend unavailable for %s (operator choice=%q); %s", feature, choice, status)
	finding.Timestamp = time.Now()

	select {
	case alertCh <- finding:
		return true
	default:
		return false
	}
}
package daemon
import (
"context"
"crypto/ed25519"
"encoding/hex"
"log"
"net"
"os"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/reporting"
)
// Defaults and limits for the central-intel consumer.
const (
// centralRefreshDefault is the pull interval used when the config does
// not supply a valid positive refresh_interval.
centralRefreshDefault = 6 * time.Hour
// centralBlockThreshold is the threshold used when the configured
// BlockThreshold is <= 0.
centralBlockThreshold = 80
// centralChallengeTTL is the TTL passed to ipList.AddNonEscalating for a
// challenge decision.
centralChallengeTTL = 6 * time.Hour
// centralBlockTTL is the TTL passed to fwEngine.BlockIP for a block
// decision.
centralBlockTTL = 24 * time.Hour
// centralActionQueue is the buffer size of the channel between the alert
// hook and the consumer loop.
centralActionQueue = 1024
)
// centralQueuedAction is one planned central-intel action: the decision
// produced by planCentralAction and the IP it applies to. Values travel
// from the alert hook to performCentralAction.
type centralQueuedAction struct {
decision reporting.Decision
ip string // the finding's SourceIP, stored as the original string
}
// documentationNets are reserved/non-routable ranges (RFC 5737 documentation,
// RFC 3849 IPv6 documentation, RFC 2544 benchmarking) that must never be acted
// on; they are not routable real attackers. centralFirebreak treats any IP
// inside one of these ranges as protected.
var documentationNets = mustCIDRs(
"192.0.2.0/24", "198.51.100.0/24", "203.0.113.0/24", "198.18.0.0/15", "2001:db8::/32",
)
// knownCentralAction reports whether s, converted to reporting.Action, is
// one of the recognised central action policies: off, challenge, or
// block-if-locally-corroborated.
func knownCentralAction(s string) bool {
	action := reporting.Action(s)
	return action == reporting.ActionOff ||
		action == reporting.ActionChallenge ||
		action == reporting.ActionBlockIfLocalCorroborated
}
// mustCIDRs parses each CIDR string and returns the resulting networks in
// input order. It follows the Must convention and panics on an entry that
// does not parse: callers pass compile-time literals (for example the
// documentation ranges used as a firebreak), and silently dropping a
// mistyped literal would remove a protected range without any signal.
// Calling it with no arguments returns an empty, non-nil slice.
func mustCIDRs(cidrs ...string) []*net.IPNet {
	out := make([]*net.IPNet, 0, len(cidrs))
	for _, c := range cidrs {
		_, n, err := net.ParseCIDR(c)
		if err != nil {
			panic("mustCIDRs: invalid CIDR " + c + ": " + err.Error())
		}
		out = append(out, n)
	}
	return out
}
// startCentralConsume wires the central scored-set consumer: it pulls and
// verifies the signed set on an interval and installs alert.CentralHook so a
// finding whose IP is in the set is escalated per the configured action. It
// returns the refresh loop, or nil when disabled/misconfigured.
//
// The hook itself never performs an action: it plans one and enqueues it
// without blocking, counting a drop when the queue is full. The returned
// function is the consumer: it performs queued actions, refreshes the store
// on every tick, logs accumulated drops, and returns once d.stopCh becomes
// readable, clearing the hook on the way out.
func (d *Daemon) startCentralConsume() func() {
// Always clear any previously installed hook first, so every early
// return below leaves the consumer fully off.
alert.SetCentralHook(nil)
cc := d.cfg.Reputation.Central
if !cc.Enabled {
return nil
}
if cc.SetURL == "" {
log.Printf("central-intel: enabled but set_url is empty; consumer stays off")
return nil
}
// The public key comes from the environment variable named by
// cc.PubkeyEnv and must hex-decode to exactly ed25519.PublicKeySize bytes.
pubHex := os.Getenv(cc.PubkeyEnv)
if raw, err := hex.DecodeString(pubHex); err != nil || len(raw) != ed25519.PublicKeySize {
log.Printf("central-intel: %s must hold a 64-hex-char Ed25519 public key; consumer stays off", cc.PubkeyEnv)
return nil
}
policy := reporting.ParseAction(cc.Action)
if cc.Action != "" && !knownCentralAction(cc.Action) {
log.Printf("central-intel: unrecognized action %q, defaulting to challenge", cc.Action)
}
threshold := cc.BlockThreshold
if threshold <= 0 {
threshold = centralBlockThreshold
}
// An unparseable or non-positive refresh interval silently keeps the
// default.
interval := centralRefreshDefault
if cc.RefreshInterval != "" {
if d2, err := time.ParseDuration(cc.RefreshInterval); err == nil && d2 > 0 {
interval = d2
}
}
store := reporting.NewCentralStore(reporting.NewPuller(nil, cc.SetURL, pubHex))
firebreak := d.centralFirebreak()
actions := make(chan centralQueuedAction, centralActionQueue)
var droppedActions atomic.Uint64
alert.SetCentralHook(func(f alert.Finding) {
a, ok := d.planCentralAction(store, policy, threshold, firebreak, f)
if !ok {
return
}
// Non-blocking enqueue: the hook must not stall on a full queue.
select {
case actions <- a:
default:
droppedActions.Add(1)
}
})
log.Printf("central-intel: enabled (action=%s, threshold=%d, refresh=%s)", policy, threshold, interval)
return func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// logDropped reports and resets the drop counter in one atomic step.
logDropped := func() {
if n := droppedActions.Swap(0); n > 0 {
log.Printf("central-intel: action queue full; dropped %d action(s)", n)
}
}
// Cancel ctx as soon as shutdown is signalled so an in-flight
// store.Refresh sees the cancellation.
go func() {
<-d.stopCh
cancel()
}()
defer alert.SetCentralHook(nil)
// Initial pull so the set is usable before the first interval.
if err := store.Refresh(ctx); err != nil {
log.Printf("central-intel: initial pull failed: %v", err)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
// Check for shutdown first: the select below picks randomly among
// ready cases, so a busy action queue could otherwise delay exit.
select {
case <-d.stopCh:
logDropped()
return
default:
}
select {
case <-d.stopCh:
logDropped()
return
case a := <-actions:
d.performCentralAction(a)
case <-ticker.C:
logDropped()
if err := store.Refresh(ctx); err != nil {
log.Printf("central-intel: refresh failed: %v", err)
}
}
}
}
}
// applyCentral plans and immediately performs the central-intel action for
// a finding. The finding firing on the IP serves as the node's local
// corroboration; the firebreak and the action policy decide what, if
// anything, happens. Central data alone never causes a block.
func (d *Daemon) applyCentral(store *reporting.CentralStore, action reporting.Action, threshold int, firebreak func(string) bool, f alert.Finding) {
	if planned, actionable := d.planCentralAction(store, action, threshold, firebreak, f); actionable {
		d.performCentralAction(planned)
	}
}
// planCentralAction decides what to do about a finding's source IP without
// performing it. It looks the IP up in store, asks firebreak whether the IP
// is protected, and passes both to reporting.Decide together with the
// action policy and threshold. It returns (action, true) when something
// should be done, and (zero value, false) when the finding has no SourceIP
// or the decision is DecisionIgnore.
func (d *Daemon) planCentralAction(store *reporting.CentralStore, action reporting.Action, threshold int, firebreak func(string) bool, f alert.Finding) (centralQueuedAction, bool) {
	var nothing centralQueuedAction
	if f.SourceIP == "" {
		return nothing, false
	}
	ip := f.SourceIP

	entry, found := store.Lookup(ip)
	input := reporting.DecisionInput{
		Found:               found,
		Score:               entry.Score,
		Protected:           firebreak(ip),
		LocallyCorroborated: true, // a finding fired on this IP
	}

	switch decision := reporting.Decide(input, action, threshold); decision {
	case reporting.DecisionIgnore:
		return nothing, false
	default:
		return centralQueuedAction{decision: decision, ip: ip}, true
	}
}
// performCentralAction carries out a planned central-intel action. A
// challenge adds the IP to d.ipList (non-escalating) for
// centralChallengeTTL; a block calls d.fwEngine.BlockIP for centralBlockTTL
// and logs a failure. Either step is skipped when its component is nil, and
// any other decision is ignored.
func (d *Daemon) performCentralAction(a centralQueuedAction) {
	if a.decision == reporting.DecisionChallenge {
		if d.ipList != nil {
			d.ipList.AddNonEscalating(a.ip, "central-intel", centralChallengeTTL)
		}
		return
	}
	if a.decision != reporting.DecisionBlock || d.fwEngine == nil {
		return
	}
	blockErr := d.fwEngine.BlockIP(a.ip, "central-intel (locally corroborated)", centralBlockTTL)
	if blockErr != nil {
		log.Printf("central-intel: block %s failed: %v", a.ip, blockErr)
	}
}
// centralFirebreak builds the predicate that says an IP must never be acted
// on from central data. The predicate returns true for:
//
//   - a string that does not parse as an IP,
//   - loopback, unspecified, private and link-local-unicast addresses,
//   - addresses inside documentationNets,
//   - addresses covered by an operator infra_ips entry (top-level entries,
//     merged with the firewall section's entries when that section exists).
//
// Infra entries may be CIDRs or single IPs; single IPs become /32 or /128
// networks. Entries that parse as neither are ignored. The infra list is
// resolved once, when centralFirebreak is called.
func (d *Daemon) centralFirebreak() func(string) bool {
	entries := d.cfg.InfraIPs
	if d.cfg.Firewall != nil {
		entries = mergeInfraIPs(d.cfg.InfraIPs, d.cfg.Firewall.InfraIPs)
	}

	var infra []*net.IPNet
	for _, raw := range entries {
		_, cidr, cidrErr := net.ParseCIDR(raw)
		if cidrErr == nil {
			infra = append(infra, cidr)
			continue
		}
		addr := net.ParseIP(raw)
		if addr == nil {
			continue
		}
		width := 128
		if addr.To4() != nil {
			width = 32
		}
		infra = append(infra, &net.IPNet{IP: addr, Mask: net.CIDRMask(width, width)})
	}

	covered := func(nets []*net.IPNet, ip net.IP) bool {
		for _, candidate := range nets {
			if candidate.Contains(ip) {
				return true
			}
		}
		return false
	}

	return func(s string) bool {
		ip := net.ParseIP(s)
		if ip == nil {
			return true // unparseable: never act
		}
		switch {
		case ip.IsLoopback(), ip.IsUnspecified(), ip.IsPrivate(), ip.IsLinkLocalUnicast():
			return true
		}
		return covered(documentationNets, ip) || covered(infra, ip)
	}
}
package daemon
import (
"fmt"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/obs"
)
// --- Cloud-relay compromise detection -----------------------------------
//
// Detects the pattern where a mailbox's SMTP AUTH credentials are being
// abused from a rented botnet of cloud VMs. Characteristic signature:
// multiple authenticated sends in a short window, from several distinct
// cloud-provider IPs (Google Cloud, AWS, Azure, etc.), for the SAME
// mailbox user.
//
// A normal user on a residential/ISP IP never matches (cloud-PTR check).
// A legitimate self-hosted script on a single VPS won't match either
// (requires ≥2 distinct source IPs within the window). Credential abuse
// from a rotating fleet trips the detector within minutes of the first
// distinct IPs showing up.
//
// Action: one Critical finding per user per window. Customer-impacting
// response actions follow the configured auto-response and dry-run settings.
// The finding message embeds a source IP so autoblock can evaluate it.
// cloudProviderPTRSuffixes is an intentionally-conservative list of
// hostname suffixes that strongly indicate a cloud-VM source. Adding to
// this list increases detection coverage; removing reduces it. Keep
// entries specific enough to avoid catching ISP-transit ASNs.
//
// Entries must be lower-case and start with a dot: isCloudProviderPTR
// lower-cases the PTR and then does a plain suffix comparison.
var cloudProviderPTRSuffixes = []string{
// Google Cloud Platform
".googleusercontent.com",
".gce.internal",
// AWS EC2 — public PTRs only. `.compute.internal` is intentionally
// excluded: it is an AWS VPC-internal PTR but also appears on
// corporate VPN and self-hosted lab networks that have nothing to
// do with AWS, so matching it would risk suspending mailboxes
// that just happen to have an internal-looking reverse DNS.
".compute.amazonaws.com",
".compute-1.amazonaws.com",
// Microsoft Azure
".cloudapp.net",
".cloudapp.azure.com",
// Oracle Cloud
".oraclecloud.com",
".oraclevcn.com",
// DigitalOcean
".digitalocean.com",
".digitaloceanspaces.com",
// Linode / Akamai Cloud
".members.linode.com",
".linodeusercontent.com",
// Vultr
".vultr.com",
".vultrusercontent.com",
// Hetzner
".hetzner.com",
".your-server.de",
// OVH / OVHcloud
".ovh.net",
".ovhcloud.com",
".ovh.ca",
".ovh.us",
// Contabo
".contabo.net",
".contabo.host",
".contaboserver.net",
}
// isCloudRelayAllowed reports whether the AUTH user has been opted out of
// the email_cloud_relay_abuse detector through the operator-managed
// allowlists. users is compared against the whole mailbox; domains is
// compared against the part after the last '@'. Both comparisons ignore
// case. An empty user is never allowed (defense against malformed log
// lines), and a user with no '@' or nothing after it cannot match a domain.
func isCloudRelayAllowed(user string, users, domains []string) bool {
	if user == "" {
		return false
	}

	foldContains := func(candidates []string, want string) bool {
		for i := range candidates {
			if strings.EqualFold(want, candidates[i]) {
				return true
			}
		}
		return false
	}

	if foldContains(users, user) {
		return true
	}
	if len(domains) == 0 {
		return false
	}

	sep := strings.LastIndexByte(user, '@')
	if sep < 0 || sep+1 >= len(user) {
		return false
	}
	return foldContains(domains, user[sep+1:])
}
// isCloudProviderPTR reports whether the PTR hostname ends in one of the
// cloudProviderPTRSuffixes entries. The hostname is lower-cased first, so
// the match is case-insensitive. An empty ptr never matches.
func isCloudProviderPTR(ptr string) bool {
	if ptr == "" {
		return false
	}
	lowered := strings.ToLower(ptr)
	for i := range cloudProviderPTRSuffixes {
		if strings.HasSuffix(lowered, cloudProviderPTRSuffixes[i]) {
			return true
		}
	}
	return false
}
// extractEximHostname returns the value of the first " H=" field in an exim
// log line. The field typically reads
// "H=hostname.example (helo.string) [IP]:port"; only the hostname in front
// of the parenthesised HELO string is wanted, so the value is cut at the
// first space, tab or '('. It returns "" when the line has no " H=" field
// or when the field starts directly with one of those terminators.
func extractEximHostname(line string) string {
	_, value, found := strings.Cut(line, " H=")
	if !found {
		return ""
	}
	if stop := strings.IndexAny(value, " \t("); stop >= 0 {
		value = value[:stop]
	}
	return strings.TrimSpace(value)
}
// cloudRelayWindow tracks authenticated sends from cloud IPs for one user.
// Bounded so memory can't grow unbounded from a misbehaving log stream.
// mu guards every other field; both the parser and the eviction sweep take
// it before reading or writing them.
type cloudRelayWindow struct {
mu sync.Mutex
events []cloudRelayEvent
firedAt time.Time // last Critical emission; dedup guard
lastEvent time.Time // last append — used to garbage-collect idle entries
}
// cloudRelayEvent is one recorded authenticated send that passed the
// cloud-PTR check.
type cloudRelayEvent struct {
at time.Time // when the log line was processed (time.Now() in the parser)
ip string // source IP extracted from the log line
ptr string // hostname taken from the line's H= field
}
// cloudRelayWindows tracks per-user cloud-relay activity, keyed by the AUTH
// user string. Entries are created by lockCloudRelayWindowForUpdate and
// removed by evictCloudRelayWindows.
var cloudRelayWindows sync.Map // map[string]*cloudRelayWindow
// lockCloudRelayWindowForUpdate returns the window for user with its mutex
// already held; the caller must call w.mu.Unlock(). A new window (with
// lastEvent set to now) is created when none exists.
//
// The loop retries until the window it locked is still the one stored in
// the map, which closes the race with the eviction sweep described below.
func lockCloudRelayWindowForUpdate(user string, now time.Time) *cloudRelayWindow {
for {
val, _ := cloudRelayWindows.LoadOrStore(user, &cloudRelayWindow{lastEvent: now})
w, ok := val.(*cloudRelayWindow)
if !ok {
// Unexpected value type under this key: drop it and start over.
cloudRelayWindows.Delete(user)
continue
}
w.mu.Lock()
// Re-check while holding w.mu so the eviction sweep cannot delete
// a stale entry and orphan a parser update made through that pointer.
current, ok := cloudRelayWindows.Load(user)
currentWindow, currentOK := current.(*cloudRelayWindow)
if ok && currentOK && currentWindow == w {
return w
}
// The entry was removed or replaced between LoadOrStore and Lock;
// release this window and retry against the map's current state.
w.mu.Unlock()
}
}
// Detection thresholds. Two OR-combined signals within the same 60-min
// sliding window:
//
// A. Multi-IP burst: ≥ cloudRelayMinEvents sends from ≥ cloudRelayMinDistinctIP
// distinct cloud IPs. Catches rented-fleet abuse rotating IPs per-send
// (the typical credential-stuffing spam pattern).
//
// B. Volume burst: ≥ cloudRelayHighVolumeEvents sends regardless of
// distinct-IP count. Catches paced attacks that deliberately use one
// cloud IP per day to evade signal A. Threshold sits well above any
// legitimate SaaS integration seen on production (SmartBill ~2/hr,
// Nylas ~2/hr, WP transactional ≤3/hr).
//
// Tuning rationale: these values were chosen from the Apr 2026 incident
// analysis. A single-mailbox user with a legit single-VPS cron averaging
// ≤14 mails/hr stays silent; anything above that is either a compromised
// relay or a SaaS integration that should be added to
// `email_protection.high_volume_senders`.
const (
// cloudRelayWindow_ is the sliding-window length. The trailing underscore
// keeps the name distinct from the cloudRelayWindow type.
cloudRelayWindow_ = 60 * time.Minute
cloudRelayMinEvents = 3
cloudRelayMinDistinctIP = 2
cloudRelayHighVolumeEvents = 15
// cloudRelayDedupCooldown is the minimum gap between two findings for the
// same user.
cloudRelayDedupCooldown = 60 * time.Minute
cloudRelayMaxEvents = 256 // per-user cap; prevents unbounded growth
)
// parseCloudRelayFinding evaluates an exim acceptance line for the
// cloud-relay compromise pattern. Called from parseEximLogLine. Returns
// zero or one finding. Never auto-suspends on its own — emits a finding
// whose Message embeds the source IP; the existing autoblock + suspend
// pipeline picks it up by check name.
//
// Steps, in order: filter to authenticated acceptance lines, apply the
// high-volume-sender and cloud-relay allowlists, require a cloud-provider
// PTR and a source IP, record the event in the user's window, then evaluate
// the multi-IP and volume thresholds unless a finding already fired for the
// user within cloudRelayDedupCooldown. The event is recorded even while
// the cooldown suppresses a new finding.
func parseCloudRelayFinding(line string, cfg *config.Config) []alert.Finding {
// Only care about authenticated outbound acceptance lines.
if !strings.Contains(line, " <= ") || !strings.Contains(line, "A=dovecot_") {
return nil
}
user := extractAuthUser(line)
if user == "" {
return nil
}
if isHighVolumeSender(user, cfg.EmailProtection.HighVolumeSenders) {
return nil
}
if isCloudRelayAllowed(user, cfg.EmailProtection.CloudRelay.AllowUsers, cfg.EmailProtection.CloudRelay.AllowDomains) {
return nil
}
ptr := extractEximHostname(line)
if !isCloudProviderPTR(ptr) {
return nil
}
ip := extractBracketedIP(line)
if ip == "" {
// Without an IP we can't dedup distinct sources; bail silently
// so we don't count half-parsed records toward the threshold.
return nil
}
now := time.Now()
// The window comes back locked; everything below runs under w.mu.
w := lockCloudRelayWindowForUpdate(user, now)
defer w.mu.Unlock()
// Prune anything older than the window.
cutoff := now.Add(-cloudRelayWindow_)
// In-place filter: kept reuses the backing array of w.events.
kept := w.events[:0]
for _, e := range w.events {
if e.at.After(cutoff) {
kept = append(kept, e)
}
}
w.events = kept
// Append this event (with cap).
// When the cap is reached the new event is not stored, but lastEvent is
// still refreshed so the entry is not treated as idle.
if len(w.events) < cloudRelayMaxEvents {
w.events = append(w.events, cloudRelayEvent{at: now, ip: ip, ptr: ptr})
}
w.lastEvent = now
// Already fired recently for this user? Dedup.
if !w.firedAt.IsZero() && now.Sub(w.firedAt) < cloudRelayDedupCooldown {
return nil
}
// Evaluate thresholds.
distinctIPs := make(map[string]struct{}, len(w.events))
for _, e := range w.events {
distinctIPs[e.ip] = struct{}{}
}
multiIPBurst := len(w.events) >= cloudRelayMinEvents && len(distinctIPs) >= cloudRelayMinDistinctIP
volumeBurst := len(w.events) >= cloudRelayHighVolumeEvents
if !multiIPBurst && !volumeBurst {
return nil
}
// Start the dedup cooldown for this user.
w.firedAt = now
// Build an IP list for the details (newest first, deduped).
seen := make(map[string]struct{}, len(w.events))
ips := make([]string, 0, len(distinctIPs))
for i := len(w.events) - 1; i >= 0; i-- {
e := w.events[i]
if _, dup := seen[e.ip]; dup {
continue
}
seen[e.ip] = struct{}{}
ips = append(ips, e.ip)
}
// ips is non-empty here: either threshold requires at least
// cloudRelayMinEvents stored events, so ips[0] below is safe.
// The block source is carried in the structured SourceIP field below; the
// IP is included in the message only for operator-facing readability.
message := fmt.Sprintf(
"Email account %s sent %d authenticated messages from %d cloud-provider IPs in %d minutes - credentials compromised - from %s",
user, len(w.events), len(distinctIPs), int(cloudRelayWindow_.Minutes()), ips[0],
)
details := fmt.Sprintf(
"Authenticated SMTP submissions for %s in the last %d minutes:\n"+
" total sends: %d\n"+
" distinct source IPs: %d\n"+
" most recent PTR: %s\n"+
" recent IPs: %s\n\n"+
"Legitimate users do not send mail from rented cloud VMs. "+
"This pattern is characteristic of credential abuse by a bulk "+
"phishing operator. Outgoing mail hold and source-IP blocking "+
"follow the configured auto-response and dry-run settings.",
user,
int(cloudRelayWindow_.Minutes()),
len(w.events),
len(distinctIPs),
ptr,
strings.Join(truncateIPList(ips, 8), ", "),
)
mailbox, domain, tenant := splitMailAccount(user)
return []alert.Finding{{
Severity: alert.Critical,
Check: "email_cloud_relay_abuse",
Message: message,
Details: truncateDaemon(details, 800),
SourceIP: ips[0],
Mailbox: mailbox,
Domain: domain,
TenantID: tenant,
}}
}
// handleCloudRelayCredentialAbuse reacts to a cloud-relay finding for
// authUser. When a domain can be extracted from the address it first asks
// maybeHoldOutgoingMail to act on the mailbox and then records the domain
// as compromised. It does nothing when no domain can be extracted.
func handleCloudRelayCredentialAbuse(cfg *config.Config, authUser string) {
	domain := extractDomainFromEmail(authUser)
	if domain == "" {
		return
	}
	maybeHoldOutgoingMail(cfg, authUser)
	// This is correlation state, not an auto-response action; keep it
	// active even when the mail hold is disabled or dry-run gated.
	RecordCompromisedDomain(domain)
}
// truncateIPList returns at most the first n entries of ips. The result
// shares its backing array with ips; a list that already fits is returned
// unchanged.
func truncateIPList(ips []string, n int) []string {
	if len(ips) > n {
		return ips[:n]
	}
	return ips
}
// cloudRelayEvictWindow is how long a per-user cloudRelayWindow can stay
// idle before it is evicted. Picked at 2x cloudRelayWindow_ so a user who
// just barely cleared the threshold does not lose their entry before the
// dedup cooldown can suppress a repeat. Idleness is measured against the
// window's lastEvent timestamp.
const cloudRelayEvictWindow = 2 * cloudRelayWindow_
// StartCloudRelayEviction periodically prunes per-user cloud-relay
// windows that have not seen activity in cloudRelayEvictWindow. Without
// this sweep the cloudRelayWindows sync.Map grows linearly with every
// distinct authenticated sender ever seen, including users deleted by
// the operator. Pairs with the firedAt dedup guard so repeat alerts on
// long-lived attackers still pass.
func StartCloudRelayEviction(stopCh <-chan struct{}) {
obs.Go("cloud-relay-eviction", func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case now := <-ticker.C:
evictCloudRelayWindows(now)
}
}
})
}
// evictCloudRelayWindows removes every per-user entry whose lastEvent is
// older than cloudRelayEvictWindow relative to now. Values of an unexpected
// type are deleted outright. Each window's mutex is held while its
// timestamp is checked and the entry removed, and the removal is a
// compare-and-delete so a replaced entry is left alone. Safe to call from
// tests.
func evictCloudRelayWindows(now time.Time) {
	staleBefore := now.Add(-cloudRelayEvictWindow)

	cloudRelayWindows.Range(func(user, entry any) bool {
		window, isWindow := entry.(*cloudRelayWindow)
		if !isWindow {
			cloudRelayWindows.Delete(user)
			return true
		}

		window.mu.Lock()
		defer window.mu.Unlock()
		if window.lastEvent.Before(staleBefore) {
			cloudRelayWindows.CompareAndDelete(user, entry)
		}
		return true
	})
}
package daemon
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/store"
)
// --- Retrospective scan for cloud-relay credential abuse -----------------
//
// The realtime watcher in cloud_relay.go only catches traffic arriving
// AFTER CSM starts. A credential-abuse spam run can go for weeks before
// an operator notices (a real incident saw 230 outbound sends from one
// compromised account over 20 days before being flagged). This scanner replays
// the last N hours of exim_mainlog through the same rule, so on CSM
// startup any in-progress or recent compromise is surfaced immediately.
//
// Runs once at daemon startup. Not part of the tiered-check registry
// because it is purely an event-log replay; the realtime watcher owns
// live state thereafter.
// cloudRelayScanPathDefault is where cPanel exim writes its mainlog. It is
// the initial value of CloudRelayScanPath, which the retro scan falls back
// to when the caller passes an empty log path.
const cloudRelayScanPathDefault = "/var/log/exim_mainlog"
// Memory caps for the retro scan. A compromised account or a crafted log
// can otherwise grow byUser without bound. The thresholds are set well
// above the volume detector (cloudRelayHighVolumeEvents=15 in a 60-min
// window) so legitimate detection is unaffected; the caps only kick in
// for pathological volumes that would balloon memory.
//
// cloudRelayScanMaxEventsPerUser bounds the in-window event slice kept
// per accumulator (the oldest event is dropped once the cap is reached).
// cloudRelayScanMaxUsers bounds the number of accumulators in byUser (idle
// or oldest non-reportable users are evicted before a new one is added).
const (
cloudRelayScanMaxEventsPerUser = 5000
cloudRelayScanMaxUsers = 10000
)
// CloudRelayScanPath is the log file path scanned at startup when
// ScanEximHistoryForCloudRelay is called with an empty logPath. It is a
// var (not a const) so tests can point the scan at a fixture file.
var CloudRelayScanPath = cloudRelayScanPathDefault
// ScanEximHistoryForCloudRelay replays exim_mainlog for the last
// `lookback` duration (relative to `now`) and returns one Critical
// "email_cloud_relay_abuse" finding per mailbox whose peak 60-minute
// burst exceeds the cloud-relay thresholds.
//
// Parameters:
//   - cfg: supplies EmailProtection.HighVolumeSenders and the
//     detector-scoped EmailProtection.CloudRelay.AllowUsers / .AllowDomains
//     allowlists consulted per line.
//   - logPath: file to scan; empty means CloudRelayScanPath.
//   - now / lookback: lines timestamped before now-lookback are ignored.
//
// Returns nil when the log cannot be opened (the scan is best-effort).
// Findings are ordered by user name. A per-user marker in the global
// store suppresses re-emitting a finding on successive restarts when no
// newer event has been logged for that user.
//
// Memory is bounded in three ways: at most 256 KiB of any single log line
// is ever held (the remainder of a longer line is discarded), and
// the per-user / per-event caps are enforced while accumulating.
func ScanEximHistoryForCloudRelay(cfg *config.Config, logPath string, now time.Time, lookback time.Duration) []alert.Finding {
	if logPath == "" {
		logPath = CloudRelayScanPath
	}
	// #nosec G304 -- logPath is operator-configured / hardcoded to cPanel default.
	f, err := os.Open(logPath)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	since := now.Add(-lookback)
	byUser := make(map[string]*cloudRelayScanAccumulator)

	// Read with ReadSlice, not ReadString or bufio.Scanner:
	//   - Scanner aborts the whole loop with ErrTooLong on an oversized
	//     line, missing every later compromise event.
	//   - ReadString never reports ErrBufferFull; it keeps accumulating
	//     until it finds '\n', so a single pathological line (e.g. a huge
	//     Base64 subject, or a crafted log) would be read fully into memory.
	// ReadSlice hands back at most one buffer's worth and reports
	// ErrBufferFull, which lets us cap the per-line memory and skip the
	// rest of the line.
	const maxLineBytes = 256 * 1024
	reader := bufio.NewReaderSize(f, maxLineBytes)
readLoop:
	for {
		chunk, rerr := reader.ReadSlice('\n')
		if len(chunk) > 0 {
			if chunk[len(chunk)-1] == '\n' {
				chunk = chunk[:len(chunk)-1]
			}
			// chunk aliases the reader's buffer and is only valid until
			// the next read; string() takes the copy we keep. For an
			// oversized line this is its first 256 KiB, which still
			// carries the timestamp/host/auth fields the detector needs,
			// so padding a line cannot be used to hide it from the scan.
			processCloudRelayScanLine(string(chunk), cfg, since, byUser)
		}
		switch {
		case rerr == nil:
			// Complete line consumed; keep reading.
		case errors.Is(rerr, io.EOF):
			break readLoop
		case errors.Is(rerr, bufio.ErrBufferFull):
			// Line longer than 256 KiB: discard the remainder. Real exim
			// acceptance lines are well under 10 KB.
			if drainErr := drainUntilNewline(reader); drainErr != nil {
				break readLoop
			}
		default:
			// Any other I/O error: stop cleanly and report what we have.
			break readLoop
		}
	}

	users := make([]string, 0, len(byUser))
	for u := range byUser {
		users = append(users, u)
	}
	sort.Strings(users) // stable finding order

	var findings []alert.Finding
	for _, user := range users {
		acc := byUser[user]
		if acc == nil || !acc.reportable {
			continue
		}
		maxSends, maxDistinctIPs, fireAt, peakPTR := acc.bestSends, acc.bestDistinctIPs, acc.bestAt, acc.bestPTR
		// Same two triggers as the realtime detector: many sends from
		// several distinct cloud IPs, or sheer volume.
		multiIP := maxSends >= cloudRelayMinEvents && maxDistinctIPs >= cloudRelayMinDistinctIP
		volume := maxSends >= cloudRelayHighVolumeEvents
		if !multiIP && !volume {
			continue
		}

		// Persistent dedup: skip if we've already fired for this user
		// and no new event has landed since then.
		latestEvent := acc.latestEvent
		if alreadyReportedRetro(user, latestEvent) {
			continue
		}

		ips := acc.recentIPs
		if len(ips) == 0 {
			continue
		}
		recentIP := ips[0] // most recently seen source IP

		msg := fmt.Sprintf(
			"RETRO: account %s sent %d authenticated messages from %d cloud-provider IPs (peak 60-min burst) in the last %d hours - credentials compromised - from %s",
			user, maxSends, maxDistinctIPs, int(lookback.Hours()), recentIP,
		)
		details := fmt.Sprintf(
			"Retrospective exim_mainlog scan at %s found a cloud-relay pattern:\n"+
				"  user: %s\n"+
				"  total cloud-PTR sends (%dh): %d\n"+
				"  peak 60-min window: %d sends / %d distinct IPs ending at %s\n"+
				"  peak PTR: %s\n"+
				"  distinct source IPs observed: %s\n\n"+
				"Outgoing mail hold and source-IP blocking follow the configured "+
				"auto-response and dry-run settings. Older IPs are left as context "+
				"because rented-fleet addresses tend to be recycled outside a 2-hour window.",
			now.Format("2006-01-02 15:04:05"),
			user,
			int(lookback.Hours()),
			acc.total,
			maxSends, maxDistinctIPs, fireAt.Format("2006-01-02 15:04:05"),
			peakPTR,
			strings.Join(ips, ", "),
		)

		mailbox, domain, tenant := splitMailAccount(user)
		findings = append(findings, alert.Finding{
			Severity:  alert.Critical,
			Check:     "email_cloud_relay_abuse",
			Message:   msg,
			Details:   truncateDaemon(details, 900),
			Timestamp: now,
			SourceIP:  recentIP,
			Mailbox:   mailbox,
			Domain:    domain,
			TenantID:  tenant,
		})
		markReportedRetro(user, latestEvent)
	}
	return findings
}
// processCloudRelayScanLine inspects one exim log line and, when it is an
// authenticated acceptance (" <= " with an "A=dovecot_" authenticator)
// from a cloud-provider PTR inside the lookback window, records it under
// the AUTH user in byUser.
//
// Lines are dropped when the timestamp is unparsable or older than since,
// the AUTH user is missing, a configured high-volume sender, or on the
// cloud-relay allowlists, the PTR is not a cloud provider, or no
// bracketed IP is present.
//
// Before a new user is added, the cloudRelayScanMaxUsers cap is enforced
// in escalating steps: prune idle users, then evict the oldest
// non-reportable user, and finally drop the line if the map is still full.
func processCloudRelayScanLine(line string, cfg *config.Config, since time.Time, byUser map[string]*cloudRelayScanAccumulator) {
	isAcceptance := strings.Contains(line, " <= ")
	if !isAcceptance || !strings.Contains(line, "A=dovecot_") {
		return
	}

	when, parsed := parseEximTimestamp(line)
	if !parsed || when.Before(since) {
		return
	}

	user := extractAuthUser(line)
	if user == "" {
		return
	}
	if isHighVolumeSender(user, cfg.EmailProtection.HighVolumeSenders) {
		return
	}
	if isCloudRelayAllowed(user, cfg.EmailProtection.CloudRelay.AllowUsers, cfg.EmailProtection.CloudRelay.AllowDomains) {
		return
	}

	host := extractEximHostname(line)
	if !isCloudProviderPTR(host) {
		return
	}
	sourceIP := extractBracketedIP(line)
	if sourceIP == "" {
		return
	}

	acc, known := byUser[user]
	if !known {
		if len(byUser) >= cloudRelayScanMaxUsers {
			pruneCloudRelayScanUsers(byUser, when)
			if len(byUser) >= cloudRelayScanMaxUsers {
				evictOldestCloudRelayScanUser(byUser)
				if len(byUser) >= cloudRelayScanMaxUsers {
					// Every tracked user is reportable; refuse to grow.
					return
				}
			}
		}
		acc = newCloudRelayScanAccumulator()
		byUser[user] = acc
	}
	acc.record(cloudRelayScanEvent{at: when, ip: sourceIP, ptr: host})
}
// drainUntilNewline discards input from reader up to and including the
// next '\n'. It returns nil once a newline has been consumed; otherwise it
// returns the first non-ErrBufferFull error from the reader (io.EOF when
// the input ends before a newline).
func drainUntilNewline(reader *bufio.Reader) error {
	_, err := reader.ReadSlice('\n')
	// ErrBufferFull means we are still inside the oversized line: the
	// buffered chunk has been consumed, so just ask for the next one.
	for errors.Is(err, bufio.ErrBufferFull) {
		_, err = reader.ReadSlice('\n')
	}
	return err
}
// cloudRelayScanEvent is a single timestamped cloud-PTR AUTH send
// replayed from the log.
type cloudRelayScanEvent struct {
// at is the log-line timestamp (see parseEximTimestamp).
at time.Time
// ip is the bracketed source IP taken from the log line.
ip string
// ptr is the sending host name taken from the log line; it already
// passed isCloudProviderPTR when the event was created.
ptr string
}
// cloudRelayScanAccumulator is the per-user running state of the retro
// scan. It keeps only the events inside the current sliding window plus a
// few summary fields, so memory stays bounded however many lines a user
// contributes. All mutation goes through record.
type cloudRelayScanAccumulator struct {
// events holds the sends currently inside the sliding window, in the
// order they were recorded (capped at cloudRelayScanMaxEventsPerUser).
events []cloudRelayScanEvent
// ipCounts maps source IP -> number of entries in events from that IP;
// its length is the window's distinct-IP count.
ipCounts map[string]int
// recentIPs lists up to 10 distinct source IPs, most recently seen first.
recentIPs []string
// total counts every event ever recorded for the user, in or out of window.
total int
// latestEvent is the newest event timestamp recorded so far.
latestEvent time.Time
// bestSends / bestDistinctIPs describe the strongest window seen so far
// (highest send count, ties broken by more distinct IPs).
bestSends int
bestDistinctIPs int
// bestAt and bestPTR are the timestamp and PTR of the event that
// completed that strongest window.
bestAt time.Time
bestPTR string
// reportable latches to true once the best window meets a detection
// threshold; it is never reset.
reportable bool
}
// newCloudRelayScanAccumulator returns an empty accumulator whose ipCounts
// map is already allocated, so record can write to it immediately.
func newCloudRelayScanAccumulator() *cloudRelayScanAccumulator {
	acc := new(cloudRelayScanAccumulator)
	acc.ipCounts = map[string]int{}
	return acc
}
// record folds one event into the accumulator:
//
//  1. bump the lifetime total, latestEvent and the recent-IP list;
//  2. expire leading window events older than event.at - cloudRelayWindow_;
//  3. if the window is still at the per-user cap, drop its oldest event;
//  4. append the event and refresh the best-window summary;
//  5. latch reportable once the best window meets either threshold.
func (acc *cloudRelayScanAccumulator) record(event cloudRelayScanEvent) {
	acc.total++
	if acc.latestEvent.IsZero() || event.at.After(acc.latestEvent) {
		acc.latestEvent = event.at
	}
	acc.rememberRecentIP(event.ip)

	// Slide the window forward: count how many leading events fell out.
	windowStart := event.at.Add(-cloudRelayWindow_)
	expired := 0
	for _, old := range acc.events {
		if !old.at.Before(windowStart) {
			break
		}
		acc.removeWindowIP(old.ip)
		expired++
	}
	if expired > 0 {
		// Zero the dropped slots so their strings can be collected even
		// while the backing array is still referenced.
		clear(acc.events[:expired])
		acc.events = acc.events[expired:]
	}

	// Hard cap: make room by discarding the oldest in-window event.
	if len(acc.events) >= cloudRelayScanMaxEventsPerUser {
		acc.removeWindowIP(acc.events[0].ip)
		acc.events[0] = cloudRelayScanEvent{}
		acc.events = acc.events[1:]
	}

	acc.events = append(acc.events, event)
	acc.ipCounts[event.ip]++

	sends, distinctIPs := len(acc.events), len(acc.ipCounts)
	newPeak := sends > acc.bestSends ||
		(sends == acc.bestSends && distinctIPs > acc.bestDistinctIPs)
	if newPeak {
		acc.bestSends = sends
		acc.bestDistinctIPs = distinctIPs
		acc.bestAt = event.at
		acc.bestPTR = event.ptr
	}

	volumeHit := acc.bestSends >= cloudRelayHighVolumeEvents
	multiIPHit := acc.bestSends >= cloudRelayMinEvents && acc.bestDistinctIPs >= cloudRelayMinDistinctIP
	if volumeHit || multiIPHit {
		acc.reportable = true
	}
}
// removeWindowIP accounts for one event from ip leaving the window: the
// per-IP counter is decremented, and the key is removed entirely once no
// in-window event from that IP remains (so len(ipCounts) stays equal to
// the number of distinct IPs in the window).
func (acc *cloudRelayScanAccumulator) removeWindowIP(ip string) {
	if remaining := acc.ipCounts[ip] - 1; remaining > 0 {
		acc.ipCounts[ip] = remaining
		return
	}
	delete(acc.ipCounts, ip)
}
// rememberRecentIP moves ip to the front of recentIPs, inserting it if it
// is not already listed. The list holds distinct IPs ordered newest-first
// and is trimmed to 10 entries. An empty ip is ignored.
func (acc *cloudRelayScanAccumulator) rememberRecentIP(ip string) {
	const keep = 10
	if ip == "" {
		return
	}

	pos := -1
	for i := range acc.recentIPs {
		if acc.recentIPs[i] == ip {
			pos = i
			break
		}
	}

	if pos >= 0 {
		// Already known: shift everything ahead of it one slot right and
		// put it first.
		copy(acc.recentIPs[1:pos+1], acc.recentIPs[:pos])
		acc.recentIPs[0] = ip
		return
	}

	// New IP: grow by one, shift the whole list right, insert at the front.
	acc.recentIPs = append(acc.recentIPs, "")
	last := len(acc.recentIPs) - 1
	copy(acc.recentIPs[1:], acc.recentIPs[:last])
	acc.recentIPs[0] = ip
	if len(acc.recentIPs) > keep {
		acc.recentIPs = acc.recentIPs[:keep]
	}
}
// pruneCloudRelayScanUsers frees room in byUser by deleting nil entries
// and every non-reportable user whose latest event is older than
// now - cloudRelayWindow_. Reportable users are always kept.
func pruneCloudRelayScanUsers(byUser map[string]*cloudRelayScanAccumulator, now time.Time) {
	idleBefore := now.Add(-cloudRelayWindow_)
	for name, acc := range byUser {
		switch {
		case acc == nil:
			delete(byUser, name)
		case acc.reportable:
			// Already crossed a threshold; must survive to be reported.
		case acc.latestEvent.Before(idleBefore):
			delete(byUser, name)
		}
	}
}
// evictOldestCloudRelayScanUser removes a single entry from byUser to make
// room for a new user. A nil entry, if encountered during iteration, is
// removed immediately and the function returns. Otherwise the
// non-reportable user with the earliest latestEvent is removed. When every
// user is reportable nothing is removed.
func evictOldestCloudRelayScanUser(byUser map[string]*cloudRelayScanAccumulator) {
	victim := ""
	var victimSeen time.Time

	for name, acc := range byUser {
		if acc == nil {
			delete(byUser, name)
			return
		}
		if !acc.reportable && (victim == "" || acc.latestEvent.Before(victimSeen)) {
			victim, victimSeen = name, acc.latestEvent
		}
	}

	if victim == "" {
		return
	}
	delete(byUser, victim)
}
// maxCloudRelayBurst slides a cloudRelayWindow_-wide window over events
// (which must be sorted by time) and reports the strongest window as
// (sends, distinctIPs, peakEnd, peakPTR). "Strongest" means the highest
// send count, with ties going to the window with more distinct IPs.
// peakEnd and peakPTR come from the LAST event of that window, so
// operators see when the burst peaked rather than when it began. An empty
// input yields all zero values.
func maxCloudRelayBurst(events []cloudRelayScanEvent) (int, int, time.Time, string) {
	var (
		peakSends    int
		peakDistinct int
		peakEnd      time.Time
		peakPTR      string
	)
	if len(events) == 0 {
		return peakSends, peakDistinct, peakEnd, peakPTR
	}

	inWindow := make(map[string]int) // source IP -> events inside [start, end]
	start := 0
	for end := range events {
		cur := events[end]
		inWindow[cur.ip]++

		// Advance the left edge until the window spans at most
		// cloudRelayWindow_.
		for ; cur.at.Sub(events[start].at) > cloudRelayWindow_; start++ {
			oldIP := events[start].ip
			if n := inWindow[oldIP]; n > 1 {
				inWindow[oldIP] = n - 1
			} else {
				delete(inWindow, oldIP)
			}
		}

		sends, distinct := end-start+1, len(inWindow)
		if sends > peakSends || (sends == peakSends && distinct > peakDistinct) {
			peakSends, peakDistinct = sends, distinct
			peakEnd, peakPTR = cur.at, cur.ptr
		}
	}
	return peakSends, peakDistinct, peakEnd, peakPTR
}
// parseEximTimestamp reads the leading "YYYY-MM-DD HH:MM:SS" stamp of an
// exim log line and interprets it in the local time zone. The boolean is
// false when the line is shorter than the stamp or the prefix does not
// parse; the returned time is then the zero value.
func parseEximTimestamp(line string) (time.Time, bool) {
	const layout = "2006-01-02 15:04:05"
	if len(line) < len(layout) {
		return time.Time{}, false
	}
	stamp, err := time.ParseInLocation(layout, line[:len(layout)], time.Local)
	if err != nil {
		return time.Time{}, false
	}
	return stamp, true
}
// alreadyReportedRetro reports whether the retro scan has already alerted
// on this user with nothing new since: it is true only when a persisted
// "cloudrelay_retro:<user>" marker exists, parses as RFC 3339, and
// latestEvent is not after it. A missing store, a missing marker, or an
// unparsable marker all count as "not yet reported".
func alreadyReportedRetro(user string, latestEvent time.Time) bool {
	var marker string
	if db := store.Global(); db != nil {
		marker = db.GetMetaString("cloudrelay_retro:" + user)
	}
	if marker == "" {
		return false
	}
	reportedAt, err := time.Parse(time.RFC3339, marker)
	return err == nil && !latestEvent.After(reportedAt)
}
// extractSenderFromCloudRelayMessage returns the mailbox that follows the
// first "account " marker in a finding message produced by
// ScanEximHistoryForCloudRelay. It returns "" when the marker is absent,
// when the token after it is empty or is not followed by a space, or when
// that token contains no "@" -- so unexpected input never panics.
func extractSenderFromCloudRelayMessage(msg string) string {
	_, afterMarker, hasMarker := strings.Cut(msg, "account ")
	if !hasMarker {
		return ""
	}
	mailbox, _, terminated := strings.Cut(afterMarker, " ")
	if !terminated || mailbox == "" {
		return ""
	}
	if strings.Contains(mailbox, "@") {
		return mailbox
	}
	return ""
}
// markReportedRetro persists latestEvent (RFC 3339) under the
// "cloudrelay_retro:<user>" meta key so alreadyReportedRetro can suppress
// the same finding on a later scan. It is a no-op when no global store is
// available.
func markReportedRetro(user string, latestEvent time.Time) {
	if db := store.Global(); db != nil {
		// Best-effort: a failed write only means the finding may be
		// emitted again after the next restart.
		_ = db.SetMetaString("cloudrelay_retro:"+user, latestEvent.Format(time.RFC3339))
	}
}
//go:build !(linux && bpf)
package daemon
import (
"context"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
)
// connectionBPF is the placeholder for the BPF cgroup/connect backend on
// builds without the linux+bpf tags. The real type (map handles, links,
// ringbuf reader) lives in connection_bpf.go behind
// //go:build linux && bpf. On any other build the coordinator never
// obtains a value of this stub: startConnectionBPF returns
// bpf.ErrNotBuilt before one is constructed.
type connectionBPF struct{}

// Mode reports the backend label, "bpf".
func (c *connectionBPF) Mode() string {
	return "bpf"
}

// EventCount always reports zero: the stub never observes events.
func (c *connectionBPF) EventCount() uint64 {
	return 0
}

// Run returns immediately; the stub has no event loop.
func (c *connectionBPF) Run(_ context.Context) {
}
// startConnectionBPF is the stub used when the binary is built without the
// linux+bpf tags. It ignores all arguments and always returns
// (nil, bpf.ErrNotBuilt), which callers use to tell "not compiled in"
// apart from other BPF start failures.
func startConnectionBPF(_ context.Context, _ chan<- alert.Finding, _ *config.Config) (*connectionBPF, error) {
return nil, bpf.ErrNotBuilt
}
package daemon
import (
"context"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/verdict"
)
// applyBPFEnforcementVerdict asks the configured verdict callback about
// the destination of a denied (or dry-run denied) connection and folds the
// answer into the finding f.
//
// It does nothing unless cfg and f are non-nil, both
// BPFEnforcement.VerdictCallback and AutoResponse.VerdictCallback.Enabled
// are set, and ev.Decision is 1 (dry-run deny) or 2 (deny). A failed
// callback is logged as a warning and leaves f untouched. On success the
// response's tenant fills f.TenantID when that is still empty, and the
// verdict, tenant and note (each when non-empty) are appended to
// f.Details.
func applyBPFEnforcementVerdict(ctx context.Context, cfg *config.Config, ev ConnectionEvent, f *alert.Finding) {
	if cfg == nil || f == nil {
		return
	}
	if !cfg.BPFEnforcement.VerdictCallback || !cfg.AutoResponse.VerdictCallback.Enabled {
		return
	}
	switch ev.Decision {
	case 1, 2:
		// Only (dry-run) denies are worth a callback round trip.
	default:
		return
	}

	settings := cfg.AutoResponse.VerdictCallback
	client := verdict.New(verdict.Config{
		URL:                      settings.URL,
		HMACSecret:               settings.HMACSecret,
		HMACSecretEnv:            settings.HMACSecretEnv,
		RequireResponseSignature: settings.RequireResponseSignature,
		AllowUnsigned:            settings.AllowUnsigned,
		Timeout:                  time.Duration(settings.TimeoutSec) * time.Second,
	})

	dst := ev.DstIP.String()
	request := verdict.Request{
		IP:       dst,
		Reason:   fmt.Sprintf("bpf_enforcement:%s:%d", f.Check, ev.DstPort),
		Severity: f.Severity.String(),
		Source:   "bpf_enforcement",
	}
	resp, err := client.Ask(ctx, request)
	if err != nil {
		csmlog.Warn("bpf enforcement verdict callback failed", "err", err, "dst", dst)
		return
	}

	tenant := resp.TenantID
	if tenant != "" && f.TenantID == "" {
		f.TenantID = tenant
	}
	if resp.Verdict != "" {
		appendFindingDetail(f, "Verdict callback: "+resp.Verdict)
	}
	if tenant != "" {
		appendFindingDetail(f, "Verdict tenant: "+tenant)
	}
	if resp.Note != "" {
		appendFindingDetail(f, "Verdict note: "+resp.Note)
	}
}
// appendFindingDetail adds detail to f.Details, separating it from any
// existing text with ", ". An empty detail is ignored.
func appendFindingDetail(f *alert.Finding, detail string) {
	switch {
	case detail == "":
		return
	case f.Details == "":
		f.Details = detail
	default:
		f.Details += ", " + detail
	}
}
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64
package connection_bpfprog
import (
"bytes"
_ "embed"
"fmt"
"io"
"structs"
"github.com/cilium/ebpf"
)
// ConnectionConnEvent is the generated Go view of one event record
// produced by the Connection BPF object. The structs.HostLayout marker
// requests the host platform's struct layout. The field order and widths
// (five uint32, two 16-byte arrays, one uint32) add up to 56 bytes.
// NOTE(review): presumably this mirrors struct conn_event in the BPF C
// source -- confirm against connection.bpf.c.
type ConnectionConnEvent struct {
_ structs.HostLayout
Uid uint32
Pid uint32
Family uint32
DstPort uint32
DstIp4 uint32
DstIp6 [16]uint8
Comm [16]uint8
Decision uint32
}
// ConnectionPolicyState is the generated Go view of the policy record used
// by the Connection BPF object: three uint32 fields in host layout.
// NOTE(review): presumably the value type of the "policy" map -- confirm
// against the BPF C source.
type ConnectionPolicyState struct {
_ structs.HostLayout
Enforce uint32
DryRun uint32
ProtectedPorts uint32
}
// LoadConnection returns the embedded CollectionSpec for Connection.
//
// The spec is parsed from the object bytes embedded in _ConnectionBytes.
// A parse failure is returned wrapped with the "can't load Connection"
// prefix; on success the returned error is nil.
func LoadConnection() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_ConnectionBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load Connection: %w", err)
}
return spec, err
}
// LoadConnectionObjects loads Connection and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *ConnectionObjects
// *ConnectionPrograms
// *ConnectionMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
//
// It first obtains the spec via LoadConnection (returning that error
// unchanged on failure) and then returns the result of
// spec.LoadAndAssign(obj, opts).
func LoadConnectionObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := LoadConnection()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// ConnectionSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// It embeds the program, map and variable spec structs declared below.
type ConnectionSpecs struct {
ConnectionProgramSpecs
ConnectionMapSpecs
ConnectionVariableSpecs
}
// ConnectionProgramSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The `ebpf` struct tags name the programs inside the object file.
type ConnectionProgramSpecs struct {
CsmConnect4 *ebpf.ProgramSpec `ebpf:"csm_connect4"`
CsmConnect6 *ebpf.ProgramSpec `ebpf:"csm_connect6"`
}
// ConnectionMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The `ebpf` struct tags name the maps inside the object file.
type ConnectionMapSpecs struct {
Events *ebpf.MapSpec `ebpf:"events"`
Policy *ebpf.MapSpec `ebpf:"policy"`
ProtectedPorts *ebpf.MapSpec `ebpf:"protected_ports"`
SafeUids *ebpf.MapSpec `ebpf:"safe_uids"`
}
// ConnectionVariableSpecs contains global variables before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The `ebpf` struct tags name the variables inside the object file.
type ConnectionVariableSpecs struct {
Unused *ebpf.VariableSpec `ebpf:"unused"`
UnusedPolicy *ebpf.VariableSpec `ebpf:"unused_policy"`
}
// ConnectionObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to LoadConnectionObjects or ebpf.CollectionSpec.LoadAndAssign.
type ConnectionObjects struct {
ConnectionPrograms
ConnectionMaps
ConnectionVariables
}
// Close closes the embedded programs and then the embedded maps, returning
// the first error encountered. The embedded ConnectionVariables are not
// passed to the closer.
func (o *ConnectionObjects) Close() error {
return _ConnectionClose(
&o.ConnectionPrograms,
&o.ConnectionMaps,
)
}
// ConnectionMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to LoadConnectionObjects or ebpf.CollectionSpec.LoadAndAssign.
type ConnectionMaps struct {
Events *ebpf.Map `ebpf:"events"`
Policy *ebpf.Map `ebpf:"policy"`
ProtectedPorts *ebpf.Map `ebpf:"protected_ports"`
SafeUids *ebpf.Map `ebpf:"safe_uids"`
}
// Close closes the four maps in declaration order and returns the first
// error encountered; maps after a failing one are not closed by this call.
func (m *ConnectionMaps) Close() error {
return _ConnectionClose(
m.Events,
m.Policy,
m.ProtectedPorts,
m.SafeUids,
)
}
// ConnectionVariables contains all global variables after they have been loaded into the kernel.
//
// It can be passed to LoadConnectionObjects or ebpf.CollectionSpec.LoadAndAssign.
//
// No Close method is declared for this type in this file.
type ConnectionVariables struct {
Unused *ebpf.Variable `ebpf:"unused"`
UnusedPolicy *ebpf.Variable `ebpf:"unused_policy"`
}
// ConnectionPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to LoadConnectionObjects or ebpf.CollectionSpec.LoadAndAssign.
type ConnectionPrograms struct {
CsmConnect4 *ebpf.Program `ebpf:"csm_connect4"`
CsmConnect6 *ebpf.Program `ebpf:"csm_connect6"`
}
// Close closes both programs in declaration order and returns the first
// error encountered; if the first close fails the second is not attempted.
func (p *ConnectionPrograms) Close() error {
return _ConnectionClose(
p.CsmConnect4,
p.CsmConnect6,
)
}
// _ConnectionClose calls Close on each closer in order and stops at the
// first error, returning it. Closers after the failing one are left
// untouched. Returns nil when every Close succeeds (or none were given).
func _ConnectionClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// _ConnectionBytes holds the contents of connection_x86_bpfel.o, embedded
// at build time by the go:embed directive below; LoadConnection parses it.
// Do not access this directly.
//
//go:embed connection_x86_bpfel.o
var _ConnectionBytes []byte
package daemon
import (
"net"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
)
// Lazily-built singleton state behind rdnsCache: directSMTPRDNSOnce
// guards the one-time construction, directSMTPRDNSCache holds the result.
// Use rdnsCache() rather than reading directSMTPRDNSCache directly, since
// it is nil until the first call.
var (
directSMTPRDNSOnce sync.Once
directSMTPRDNSCache *checks.RDNSCache
)
// rdnsCache returns the daemon-wide reverse-DNS cache used by direct SMTP
// egress detection, building it on first use (sync.Once, so concurrent
// first callers share one instance).
//
// The cache is configured with a 30-minute TTL and a 1-second per-lookup
// deadline. Its resolver wraps net.LookupAddr and returns the first PTR
// name with the trailing dot removed; a lookup error or an empty answer
// yields "". Negative results are meant to be cached so a slow upstream
// does not stall the connection consumer (NOTE(review): that behaviour
// lives in checks.RDNSCache -- confirm there).
func rdnsCache() *checks.RDNSCache {
	directSMTPRDNSOnce.Do(func() {
		lookupPTR := func(ip net.IP) (string, error) {
			names, err := net.LookupAddr(ip.String())
			if err != nil {
				return "", err
			}
			if len(names) == 0 {
				return "", nil
			}
			return strings.TrimSuffix(names[0], "."), nil
		}

		directSMTPRDNSCache = checks.NewRDNSCache(checks.RDNSCacheConfig{
			TTL:             30 * time.Minute,
			ResolveDeadline: time.Second,
			Resolve:         lookupPTR,
		})
	})
	return directSMTPRDNSCache
}
// evaluateConnectionEvent runs every per-event detector against ev and
// returns the findings to emit, all stamped with the same timestamp. It
// never touches alertCh; the caller attaches process context (which MAY
// do IO via the enricher) and ships the findings.
//
// Detectors, in emission order:
//  1. direct SMTP egress (only when that detector's "bpf" backend is
//     enabled), enriched with the destination's rDNS name when available;
//  2. the pre-existing user_outbound_connection detector;
//  3. bad-ASN egress, for non-root events when
//     Detection.BadASNOutbound.Enabled is set and an ASN lookup exists.
//
// The enforcement decision counter is bumped for decisions 0/1/2 before
// any detector runs.
//
// The function lives in a non-build-tagged file so unit tests can drive a
// synthetic ConnectionEvent through the same policy logic the live BPF
// Run loop uses, without requiring the linux+bpf build tag.
func evaluateConnectionEvent(cfg *config.Config, mta platform.MTAIdents, ev ConnectionEvent, user string) []alert.Finding {
	switch ev.Decision {
	case 0:
		BumpBPFEnforcementDecision(BPFDecisionAllow)
	case 1:
		BumpBPFEnforcementDecision(BPFDecisionDryRun)
	case 2:
		BumpBPFEnforcementDecision(BPFDecisionDeny)
	}

	// Phase 4 note: bpf_enforcement.verdict_callback is applied by the
	// BPF Run loop after this evaluator returns. The in-kernel hook
	// NEVER waits on HTTP; cgroup/connect is synchronous and a remote
	// callback would add latency to every connect.
	//
	// Phase 3 note: DryRun knobs are not consulted here. Detection runs
	// regardless. The knobs gate the Phase 4 auto-response action that
	// has not landed yet.
	stamp := time.Now()
	var findings []alert.Finding
	emit := func(f alert.Finding) {
		f.Timestamp = stamp
		findings = append(findings, f)
	}

	// Direct SMTP egress (Phase 3). Distinct Check value; the inbound
	// smtp_probe meters never see this traffic.
	if checks.DirectSMTPEgressBackendEnabled(cfg, "bpf") {
		input := checks.DirectSMTPEgressInput{
			UID:     ev.UID,
			User:    user,
			PID:     ev.PID,
			Comm:    ev.Comm,
			DstIP:   ev.DstIP,
			DstPort: ev.DstPort,
			MTA:     mta,
		}
		if smtpFinding, hit := checks.EvaluateDirectSMTPEgress(cfg, input); hit {
			if domain := rdnsCache().Lookup(ev.DstIP); domain != "" {
				smtpFinding.Details += ", Domain: " + domain
			}
			emit(smtpFinding)
			checks.BumpDirectSMTPEgressFindings()
		}
	}

	// Pre-existing user_outbound_connection detector. SMTP destinations
	// are filtered out by checks.safeRemotePorts inside this evaluator,
	// so it does not double-fire for a 25/465/587 connect.
	connFinding, connHit := checks.EvaluateConnection(cfg, ev.UID, ev.DstIP, ev.DstPort, 0, protoFromFamily(ev.Family), user)
	if connHit {
		emit(connFinding)
	}

	// Bad-ASN egress (host-takeover chain leg). The BPF program emits
	// non-root connects only; root egress is covered by the periodic
	// /proc/net scan so root-heavy hosts do not flood the ringbuf.
	if ev.UID == 0 || !cfg.Detection.BadASNOutbound.Enabled {
		return findings
	}
	lookup := checks.CurrentASNLookup()
	if lookup == nil {
		return findings
	}
	asn, org := lookup(ev.DstIP.String())
	if asnFinding, hit := checks.EvaluateBadASNOutbound(cfg, ev.DstIP, asn, org); hit {
		emit(asnFinding)
	}
	return findings
}
// protoFromFamily turns a sockaddr family number into the protocol label
// used in finding details: 10 (AF_INET6) becomes "tcp6", every other
// value becomes "tcp". It lives here (not in connection_bpf.go) so the
// evaluator helper compiles on darwin without the bpf tag.
func protoFromFamily(f uint32) string {
	const afInet6 = 10
	switch f {
	case afInet6:
		return "tcp6"
	default:
		return "tcp"
	}
}
package daemon
import (
"encoding/binary"
"errors"
"net"
)
// ConnectionEvent is the userspace shape of a struct conn_event emitted by
// the cgroup/connect BPF program. Field layout matches connection.bpf.c
// byte for byte: scalars are little-endian (host order on amd64/arm64),
// dst_ip4 is network-order, dst_ip6 is the raw 16-byte address.
//
// Values are produced by decodeConnectionEvent, which narrows the 32-bit
// wire port to DstPort and collapses the two wire address fields into the
// single DstIP according to Family.
type ConnectionEvent struct {
UID uint32
PID uint32
Family uint32 // AF_INET=2, AF_INET6=10
DstPort uint16 // host order; BPF program calls bpf_ntohs
DstIP net.IP // resolved from dst_ip4 (v4) or dst_ip6 (v6) per Family
Comm string // null-terminated, up to 16 bytes
Decision uint32 // Phase 4: DECISION_* code (0=allow, 1=dry_run_deny, 2=deny)
}
// connectionEventSize is the wire size in bytes of one event record, as
// read by decodeConnectionEvent: uid(4) + pid(4) + family(4) +
// dst_port(4) + dst_ip4(4) + dst_ip6(16) + comm(16) + decision(4) = 56.
const connectionEventSize = 4 + 4 + 4 + 4 + 4 + 16 + 16 + 4
// decodeConnectionEvent parses one raw event record into a
// ConnectionEvent.
//
// Wire layout (little-endian scalars): uid [0:4], pid [4:8],
// family [8:12], dst_port [12:16], dst_ip4 [16:20] (network order),
// dst_ip6 [20:36], comm [36:52] (NUL-padded), decision [52:56].
//
// It returns an error when b is shorter than connectionEventSize or the
// family is neither 2 (AF_INET) nor 10 (AF_INET6); the returned event is
// the zero value in both cases. Bytes beyond connectionEventSize are
// ignored.
func decodeConnectionEvent(b []byte) (ConnectionEvent, error) {
	if len(b) < connectionEventSize {
		return ConnectionEvent{}, errors.New("connection event short buffer")
	}
	le := binary.LittleEndian

	family := le.Uint32(b[8:12])
	var dst net.IP
	switch family {
	case 2: // AF_INET: the four address bytes are already in network order
		dst = make(net.IP, net.IPv4len)
		copy(dst, b[16:20])
	case 10: // AF_INET6
		dst = make(net.IP, net.IPv6len)
		copy(dst, b[20:36])
	default:
		return ConnectionEvent{}, errors.New("unknown family")
	}

	// The BPF program stores dst_port as __u32 for alignment but calls
	// bpf_ntohs() before writing, which guarantees the value fits in 16
	// bits. The narrowing is safe by construction.
	port := le.Uint32(b[12:16]) & 0xffff

	return ConnectionEvent{
		UID:      le.Uint32(b[0:4]),
		PID:      le.Uint32(b[4:8]),
		Family:   family,
		DstPort:  uint16(port), // #nosec G115 -- masked to low 16 bits above
		DstIP:    dst,
		Comm:     nullTerm(b[36 : 36+16]),
		Decision: le.Uint32(b[52:56]),
	}, nil
}
// indexNull returns the position of the first NUL byte in b, or -1 when b
// contains none (including when b is empty or nil).
func indexNull(b []byte) int {
	for pos := 0; pos < len(b); pos++ {
		if b[pos] == 0 {
			return pos
		}
	}
	return -1
}
// nullTerm converts a fixed-size, NUL-padded character field into a Go
// string: the bytes before the first NUL, or all of b when it contains no
// NUL. Used for the character arrays the BPF programs emit (comm[16],
// filename[256], exe[256]).
func nullTerm(b []byte) string {
	end := indexNull(b)
	if end < 0 {
		end = len(b)
	}
	return string(b[:end])
}
package daemon
import (
"context"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
)
// connectionPoller is the userspace fallback. It runs CheckOutboundUserConnections
// on a fixed interval and forwards any findings to the alert channel. Used
// when the BPF backend is unavailable (no bpf tag, kernel rejects program,
// or operator pinned legacy via Detection.ConnectionTrackerBackend).
type connectionPoller struct {
// cfg is the startup configuration; Run reads the poll interval from it
// and passes it to activeConnectionCfg as the fallback config.
cfg *config.Config
// alertCh receives findings; Run sends without blocking and drops a
// finding when the channel is full.
alertCh chan<- alert.Finding
// count is the number of findings produced so far (incremented before
// the send attempt, so dropped findings are included).
count atomic.Uint64
}
// newConnectionPoller builds a poller bound to cfg and alertCh. Nothing
// runs until Run is called.
func newConnectionPoller(cfg *config.Config, alertCh chan<- alert.Finding) *connectionPoller {
	p := new(connectionPoller)
	p.cfg = cfg
	p.alertCh = alertCh
	return p
}

// Mode reports the backend label, "legacy".
func (p *connectionPoller) Mode() string {
	return "legacy"
}

// EventCount reports how many findings the poller has produced so far.
func (p *connectionPoller) EventCount() uint64 {
	return p.count.Load()
}
// Run polls until ctx is cancelled. On every tick (interval from
// pollerInterval, read once at start) it runs
// checks.CheckOutboundUserConnections against the currently active config
// and offers each finding to alertCh. Sends never block: when the channel
// is full the finding is dropped with a warning. The counter is bumped for
// every finding, delivered or not.
func (p *connectionPoller) Run(ctx context.Context) {
	ticker := time.NewTicker(pollerInterval(p.cfg))
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// fall through to one polling pass
		}

		for _, finding := range checks.CheckOutboundUserConnections(ctx, activeConnectionCfg(p.cfg), nil) {
			p.count.Add(1)
			select {
			case p.alertCh <- finding:
			default:
				csmlog.Warn("connection legacy: alert channel full, dropping finding")
			}
		}
	}
}
// pollerInterval returns Detection.ConnectionPollInterval when it is
// positive, and a 30-second default when it is unset (zero) or negative.
func pollerInterval(cfg *config.Config) time.Duration {
	const fallback = 30 * time.Second

	configured := cfg.Detection.ConnectionPollInterval
	if configured <= 0 {
		return fallback
	}
	return configured
}
package daemon
import (
"context"
"errors"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/processctx"
)
// StartConnectionTracker selects the active connection-tracker backend based
// on cfg.Detection.ConnectionTrackerBackend (trimmed, case-insensitive) and
// host capability:
//
//	"auto" (default) -- try BPF, fall back to legacy polling.
//	"bpf"            -- require BPF; return nil if unavailable (no fallback).
//	"legacy"         -- pin legacy polling.
//	"none"           -- disable the live tracker (the periodic check still runs).
//
// Unknown values fall back to "auto" with a warning. The metric
// csm_bpf_backend{feature="connection_tracker", kind="..."} reflects the
// chosen path. Whenever BPF was attempted and failed with an error, a
// "BPF unavailable" finding is emitted on alertCh.
//
// The returned backend has not been started; nil means no live tracker.
func StartConnectionTracker(alertCh chan<- alert.Finding, cfg *config.Config) bpf.Backend {
	const feature = "connection_tracker"

	choice := strings.ToLower(strings.TrimSpace(cfg.Detection.ConnectionTrackerBackend))
	if choice == "" {
		choice = bpf.BackendAuto
	}
	switch choice {
	case bpf.BackendAuto, bpf.BackendBPF, bpf.BackendLegacy, bpf.BackendNone:
		// recognised
	default:
		csmlog.Warn("connection_tracker: unknown backend choice, using auto", "value", choice)
		choice = bpf.BackendAuto
	}

	if choice == bpf.BackendNone {
		csmlog.Info("connection_tracker: disabled by config")
		bpf.SetActive(feature, bpf.BackendNone)
		return nil
	}

	var bpfErr error
	tryBPF := choice == bpf.BackendAuto || choice == bpf.BackendBPF
	if tryBPF {
		backend, err := tryStartConnectionBPFFn(context.Background(), alertCh, cfg)
		switch {
		case err == nil && backend != nil:
			csmlog.Info("connection_tracker", "backend", "bpf", "choice", choice)
			bpf.SetActive(feature, bpf.BackendBPF)
			return backend
		case err != nil:
			bpfErr = err
			reason := err.Error()
			state := "bpf-unsupported"
			if errors.Is(err, bpf.ErrNotBuilt) {
				state = "bpf-not-built"
			}
			csmlog.Info("connection_tracker: BPF unavailable", "state", state, "reason", reason, "choice", choice)
			if choice == bpf.BackendBPF {
				// Operator demanded BPF: do not silently degrade.
				csmlog.Warn("connection_tracker: backend=bpf but BPF unavailable; no live tracker", "reason", reason)
				bpf.SetActive(feature, bpf.BackendNone)
				emitBPFUnavailableFinding(alertCh, feature, choice, "", err)
				return nil
			}
		}
	}

	// Legacy polling: either pinned, or the "auto" fallback.
	poller := newConnectionPoller(cfg, alertCh)
	csmlog.Info("connection_tracker", "backend", "legacy", "choice", choice)
	bpf.SetActive(feature, bpf.BackendLegacy)
	if bpfErr != nil {
		emitBPFUnavailableFinding(alertCh, feature, choice, bpf.BackendLegacy, bpfErr)
	}
	return poller
}
// activeConnectionCfg returns the currently active configuration from
// config.Active, falling back to the startup configuration when none is
// active.
func activeConnectionCfg(startup *config.Config) *config.Config {
	live := config.Active()
	if live == nil {
		return startup
	}
	return live
}
// tryStartConnectionBPFFn is the package-level indirection so tests can
// substitute a fake without the bpf build tag. StartConnectionTracker
// calls BPF startup only through this variable; it defaults to
// tryStartConnectionBPF.
var tryStartConnectionBPFFn = tryStartConnectionBPF
// tryStartConnectionBPF starts the BPF connection backend and returns it
// as a bpf.Backend.
//
// On failure it returns (nil, err) with a genuinely nil interface. It also
// refuses to wrap a nil *connectionBPF: converting a nil pointer to the
// bpf.Backend interface produces a NON-nil interface value, which would
// pass the caller's `b != nil` check and be handed out as a live backend.
// That case is reported as an error instead, so the caller takes its
// normal "BPF unavailable" path.
func tryStartConnectionBPF(ctx context.Context, ch chan<- alert.Finding, cfg *config.Config) (bpf.Backend, error) {
	b, err := startConnectionBPF(ctx, ch, cfg)
	if err != nil {
		return nil, err
	}
	if b == nil {
		return nil, errors.New("connection bpf: start returned no backend and no error")
	}
	return b, nil
}
// attachProcessCtxToFinding sets f.Process from the cache when a verified
// snapshot exists for the event's process, and otherwise enqueues a /proc
// enrichment so the next finding for the same PID benefits. Cache miss is
// the common case for short-lived processes; the finding is then emitted
// with whatever context it already has (often none).
//
// When a snapshot is attached and it names an account, that account
// becomes f.TenantID if the finding is a "direct_smtp_egress" finding or
// has no tenant yet. A snapshot that reports it still needs enrichment is
// attached AND re-queued. Events with PID 0 are ignored.
func attachProcessCtxToFinding(cache *processctx.Cache, enr *processctx.Enricher, f *alert.Finding, ev ConnectionEvent) {
	if ev.PID == 0 {
		return
	}
	req := processctxRequestFromConnection(ev)

	snapshot, needsEnrichment := cache.MaterializeVerifiedSnapshot(req)
	if snapshot == nil {
		// Nothing cached: queue the lookup and emit the finding as-is.
		enr.Enqueue(req)
		return
	}

	f.Process = snapshot
	overrideTenant := f.Check == "direct_smtp_egress" || f.TenantID == ""
	if snapshot.Account != "" && overrideTenant {
		f.TenantID = snapshot.Account
	}
	if needsEnrichment {
		enr.Enqueue(req)
	}
}
// processctxRequestFromConnection builds the enrichment request for the
// process behind a connection event: PID, UID (marked as known) and comm
// are taken from the event, and StartedAt from processCtxStartedAt for
// that PID.
func processctxRequestFromConnection(ev ConnectionEvent) processctx.EnrichRequest {
	var req processctx.EnrichRequest
	req.PID = int(ev.PID)
	req.UID = int(ev.UID)
	req.UIDKnown = true
	req.Comm = ev.Comm
	req.StartedAt = processCtxStartedAt(req.PID)
	return req
}
package daemon
import (
"encoding/json"
"fmt"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/store"
)
// handleBaseline clears existing state and captures the current host as
// the new known-good reference. Mirrors the old `csm baseline` flow but
// runs inside the daemon so no external lock coordination is needed.
//
// Concurrency: a sync.Mutex on the daemon serialises baselines against
// each other. The baseline sweep still uses checks.ForceAll to bypass
// throttles; dry-run state is scoped to RunAllDryRun.
//
// Flow:
//  1. Decode the optional control.BaselineArgs.
//  2. If the store already holds history and args.Confirm is false, return
//     a result with NeedsConfirm set and change nothing.
//  3. Run every check in dry-run mode with ForceAll set, and store the
//     findings as the baseline.
//  4. Hash the daemon binary, re-sign the config files, and copy the new
//     hashes into the in-memory config.
//
// Returns a control.BaselineResult, or an error when the args cannot be
// parsed, the binary cannot be hashed, or the config cannot be signed.
// Note that the baseline is stored (step 3) before step 4 runs, so an
// error from step 4 is returned after the baseline was already replaced.
func (c *ControlListener) handleBaseline(argsRaw json.RawMessage) (any, error) {
var args control.BaselineArgs
if len(argsRaw) > 0 {
if err := json.Unmarshal(argsRaw, &args); err != nil {
return nil, fmt.Errorf("parsing args: %w", err)
}
}
// Held for the whole operation so two baselines never interleave.
c.d.baselineMu.Lock()
defer c.d.baselineMu.Unlock()
histCount := 0
if sdb := store.Global(); sdb != nil {
histCount = sdb.HistoryCount()
}
// Existing history requires explicit confirmation; report how much
// history is at stake and stop.
if histCount > 0 && !args.Confirm {
return control.BaselineResult{
HistoryCleared: histCount,
NeedsConfirm: true,
}, nil
}
// Force-all bypasses throttles for the baseline sweep. Dry-run threads
// through RunAllDryRun so a concurrent periodic scanner running in live
// mode is never silenced by this caller.
// NOTE(review): checks.ForceAll is a package-level variable, so it is
// also visible to any other scan running during this window; baselineMu
// only serialises baselines -- confirm that is acceptable.
prevForceAll := checks.ForceAll
checks.ForceAll = true
defer func() { checks.ForceAll = prevForceAll }()
cfg := c.d.currentCfg()
// The second return value of RunAllDryRun is deliberately discarded.
findings, _ := checks.RunAllDryRun(cfg, c.d.store)
// NOTE(review): presumably SetBaseline is what replaces the previous
// state/history counted in histCount -- confirm in the store package.
c.d.store.SetBaseline(findings)
binaryHash, err := integrity.HashFile(c.d.binaryPath)
if err != nil {
return nil, fmt.Errorf("hashing binary: %w", err)
}
configHash, confdHash, err := integrity.SignConfigFilePreserving(cfg.ConfigFile, cfg.ConfigDir, binaryHash)
if err != nil {
return nil, fmt.Errorf("saving integrity: %w", err)
}
// Keep the in-memory config in step with what was just signed.
cfg.Integrity.BinaryHash = binaryHash
cfg.Integrity.ConfigHash = configHash
cfg.Integrity.ConfdHash = confdHash
return control.BaselineResult{
Findings: len(findings),
HistoryCleared: histCount,
BinaryHash: binaryHash,
ConfigHash: cfg.Integrity.ConfigHash,
}, nil
}
package daemon
import (
"bytes"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"time"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/obs"
)
// Subnet / batch / meta firewall handlers: operations that either span
// multiple IPs (deny-file/allow-file, subnet ops) or reshape the whole
// ruleset (flush/restart/apply-confirmed/confirm). `restart` and
// `apply-confirmed` require a live fwEngine — a dead engine means
// "systemctl restart csm" rather than rebuild-from-handler.
func (c *ControlListener) handleFirewallDenySubnet(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallSubnetArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	// Argument validation comes before the engine check, so a malformed
	// CIDR is reported even when the firewall is disabled.
	_, _, cidrErr := net.ParseCIDR(args.CIDR)
	switch {
	case cidrErr != nil:
		return nil, fmt.Errorf("invalid cidr: %q", args.CIDR)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Blocked via CLI"
	}

	if blockErr := c.d.fwEngine.BlockSubnet(args.CIDR, why, 0); blockErr != nil {
		return nil, fmt.Errorf("block subnet %s: %w", args.CIDR, blockErr)
	}

	ack := control.FirewallAckResult{
		Message: fmt.Sprintf("Blocked subnet %s - %s", args.CIDR, why),
	}
	return ack, nil
}
// handleFirewallRemoveSubnet lifts a subnet block. The CIDR is validated
// first, then the engine must be present; the acknowledgement message names
// the removed CIDR.
func (c *ControlListener) handleFirewallRemoveSubnet(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallSubnetArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	_, _, cidrErr := net.ParseCIDR(args.CIDR)
	if cidrErr != nil {
		return nil, fmt.Errorf("invalid cidr: %q", args.CIDR)
	}
	if c.d.fwEngine == nil {
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	unblockErr := c.d.fwEngine.UnblockSubnet(args.CIDR)
	if unblockErr != nil {
		return nil, fmt.Errorf("remove subnet %s: %w", args.CIDR, unblockErr)
	}

	ack := control.FirewallAckResult{
		Message: fmt.Sprintf("Removed subnet block %s", args.CIDR),
	}
	return ack, nil
}
// handleFirewallDenyFile blocks a batch of IPs in one request. Entries that
// do not parse as an IP are counted as skipped, engine failures are counted
// as failed, and the batch carries on either way; the acknowledgement
// message carries the tallies.
func (c *ControlListener) handleFirewallDenyFile(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallFileArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	switch {
	case len(args.IPs) == 0:
		return nil, fmt.Errorf("no ips in batch")
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Bulk block via CLI"
	}

	var tally struct{ done, failed, invalid int }
	for _, candidate := range args.IPs {
		switch {
		case net.ParseIP(candidate) == nil:
			tally.invalid++
		// Operator-initiated batch: bypass auto_response.dry_run gate.
		case c.d.fwEngine.BlockIPForce(candidate, why, 0) != nil:
			tally.failed++
		default:
			tally.done++
		}
	}

	// Failures are only mentioned when there were any.
	var summary string
	if tally.failed > 0 {
		summary = fmt.Sprintf("Blocked %d, failed %d, skipped %d invalid", tally.done, tally.failed, tally.invalid)
	} else {
		summary = fmt.Sprintf("Blocked %d, skipped %d invalid", tally.done, tally.invalid)
	}
	return control.FirewallAckResult{Message: summary}, nil
}
// handleFirewallAllowFile allow-lists a batch of IPs in one request.
// Entries that do not parse as an IP are counted as skipped, engine
// failures are counted as failed, and the batch carries on either way; the
// acknowledgement message carries the tallies.
func (c *ControlListener) handleFirewallAllowFile(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallFileArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	switch {
	case len(args.IPs) == 0:
		return nil, fmt.Errorf("no ips in batch")
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Bulk allow via CLI"
	}

	var tally struct{ done, failed, invalid int }
	for _, candidate := range args.IPs {
		switch {
		case net.ParseIP(candidate) == nil:
			tally.invalid++
		case c.d.fwEngine.AllowIP(candidate, why) != nil:
			tally.failed++
		default:
			tally.done++
		}
	}

	// Failures are only mentioned when there were any.
	var summary string
	if tally.failed > 0 {
		summary = fmt.Sprintf("Allowed %d, failed %d, skipped %d invalid", tally.done, tally.failed, tally.invalid)
	} else {
		summary = fmt.Sprintf("Allowed %d, skipped %d invalid", tally.done, tally.invalid)
	}
	return control.FirewallAckResult{Message: summary}, nil
}
// handleFirewallFlush removes every blocked IP via the engine. The count in
// the acknowledgement is taken from the persisted state as it was before
// the flush; an error from loading that state is ignored.
func (c *ControlListener) handleFirewallFlush(_ json.RawMessage) (any, error) {
	if c.d.fwEngine == nil {
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	// Snapshot the count first: after the flush there is nothing to count.
	snapshot, _ := firewall.LoadState(c.d.currentCfg().StatePath)
	flushed := len(snapshot.Blocked)

	flushErr := c.d.fwEngine.FlushBlocked()
	if flushErr != nil {
		return nil, fmt.Errorf("flushing blocked: %w", flushErr)
	}

	ack := control.FirewallAckResult{Message: fmt.Sprintf("Flushed %d blocked IPs", flushed)}
	return ack, nil
}
// handleFirewallRestart re-applies the ruleset through the live engine and
// reports how many blocked and allowed entries the persisted state holds
// afterwards. Without an engine the caller is told to restart the daemon.
func (c *ControlListener) handleFirewallRestart(_ json.RawMessage) (any, error) {
	if c.d.fwEngine == nil {
		return nil, fmt.Errorf("firewall engine not running; restart the csm daemon")
	}

	applyErr := c.d.fwEngine.Apply()
	if applyErr != nil {
		return nil, fmt.Errorf("applying ruleset: %w", applyErr)
	}

	// A state-load error is ignored here; it only affects the summary counts.
	persisted, _ := firewall.LoadState(c.d.currentCfg().StatePath)
	summary := fmt.Sprintf("Firewall restarted. %d blocked, %d allowed IPs restored.", len(persisted.Blocked), len(persisted.Allowed))
	return control.FirewallAckResult{Message: summary}, nil
}
// handleFirewallApplyConfirmed applies the ruleset with a timed safety net.
// The sequence is:
//
//  1. dump the live nft ruleset to the rollback file,
//  2. apply the new ruleset through the engine,
//  3. write a confirm marker containing the RFC3339 deadline,
//  4. start a background goroutine that, once the timer elapses, restores
//     the dump if the confirm marker still exists.
//
// args.Minutes values outside 1..60 fall back to 3. The acknowledgement
// message reports the timer length and the blocked/allowed counts from the
// persisted state.
//
// NOTE(review): the rollback goroutine only checks that a confirm marker
// exists, not that it is the marker written by this call. A second
// apply-confirmed issued while an earlier timer is still sleeping looks like
// it would be rolled back by the earlier timer — confirm whether that is
// intended.
func (c *ControlListener) handleFirewallApplyConfirmed(argsRaw json.RawMessage) (any, error) {
var args control.FirewallApplyConfirmedArgs
if len(argsRaw) > 0 {
if err := json.Unmarshal(argsRaw, &args); err != nil {
return nil, fmt.Errorf("parsing args: %w", err)
}
}
// Out-of-range (or absent) timer lengths silently become 3 minutes.
minutes := args.Minutes
if minutes <= 0 || minutes > 60 {
minutes = 3
}
if c.d.fwEngine == nil {
return nil, fmt.Errorf("firewall engine not running; restart the csm daemon")
}
cfg := c.d.currentCfg()
confirmFile, rollbackFile, legacyRollbackFile := firewallRollbackFiles(cfg.StatePath)
if err := os.MkdirAll(filepath.Dir(rollbackFile), 0700); err != nil {
return nil, fmt.Errorf("creating firewall rollback dir: %w", err)
}
// Clear any rollback artefacts left behind (current and legacy names)
// before capturing a fresh dump.
if err := removeFirewallRollbackFiles(rollbackFile, legacyRollbackFile); err != nil {
return nil, err
}
// Capture the pre-apply ruleset BEFORE touching the engine.
if err := writeFirewallRollbackFile(rollbackFile); err != nil {
return nil, err
}
if err := c.d.fwEngine.Apply(); err != nil {
// Apply failed: drop the dump (best effort) since no timer will be armed.
_ = removeFileIfExists(rollbackFile)
return nil, fmt.Errorf("applying ruleset: %w", err)
}
deadline := time.Now().Add(time.Duration(minutes) * time.Minute)
if err := os.WriteFile(confirmFile, []byte(deadline.Format(time.RFC3339)), 0600); err != nil {
// Without a marker the timer goroutine would never roll back, so restore
// the previous ruleset immediately instead of leaving the new one
// unguarded.
if restoreErr := applyFirewallRollbackFile(rollbackFile); restoreErr != nil {
_ = removeFileIfExists(confirmFile)
return nil, fmt.Errorf("writing confirm marker: %w; rollback restore failed: %v", err, restoreErr)
}
_ = removeFirewallRollbackFiles(confirmFile, rollbackFile)
return nil, fmt.Errorf("writing confirm marker: %w; previous ruleset restored", err)
}
// Rollback goroutine lives in the daemon (long-lived, so it
// survives CLI exit -- an improvement over the old CLI version).
obs.SafeGo("fw-apply-confirmed-rollback", func() {
time.Sleep(time.Duration(minutes) * time.Minute)
// restoreFirewallRollback is a no-op when the confirm marker is gone,
// i.e. when the operator confirmed in the meantime.
if err := restoreFirewallRollback(confirmFile, rollbackFile); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Firewall rollback failed: %v\n", ts(), err)
}
})
// A state-load error is ignored; it only affects the counts in the message.
state, _ := firewall.LoadState(cfg.StatePath)
return control.FirewallAckResult{
Message: fmt.Sprintf("Firewall applied with %d-minute rollback timer. %d blocked, %d allowed. Run `csm firewall confirm` to keep.", minutes, len(state.Blocked), len(state.Allowed)),
}, nil
}
// handleFirewallConfirm cancels a pending apply-confirmed rollback by
// deleting the confirm marker together with the rollback files. When no
// marker exists, any leftover rollback files are still cleaned up and the
// caller is told there was nothing pending.
func (c *ControlListener) handleFirewallConfirm(_ json.RawMessage) (any, error) {
	confirmFile, rollbackFile, legacyRollbackFile := firewallRollbackFiles(c.d.currentCfg().StatePath)

	_, statErr := os.Stat(confirmFile)
	switch {
	case statErr == nil:
		// Marker present: removing it is what disarms the rollback timer.
		if err := removeFirewallRollbackFiles(confirmFile, rollbackFile, legacyRollbackFile); err != nil {
			return nil, err
		}
		return control.FirewallAckResult{
			Message: "Firewall confirmed. Rollback timer cancelled.",
		}, nil
	case !os.IsNotExist(statErr):
		return nil, fmt.Errorf("checking confirm marker: %w", statErr)
	}

	// No marker: nothing is pending, but stale rollback files may remain.
	if err := removeFirewallRollbackFiles(rollbackFile, legacyRollbackFile); err != nil {
		return nil, err
	}
	return control.FirewallAckResult{
		Message: "No pending confirmation. Firewall is already confirmed.",
	}, nil
}
// firewallRollbackFiles returns the three paths used by the apply-confirmed
// flow, all inside <statePath>/firewall: the confirm marker, the nft
// rollback dump, and the legacy shell-script rollback file.
func firewallRollbackFiles(statePath string) (confirmFile, rollbackFile, legacyRollbackFile string) {
	dir := filepath.Join(statePath, "firewall")
	confirmFile = filepath.Join(dir, "confirm_pending")
	rollbackFile = filepath.Join(dir, "rollback.nft")
	legacyRollbackFile = filepath.Join(dir, "rollback.sh")
	return confirmFile, rollbackFile, legacyRollbackFile
}
// writeFirewallRollbackFile dumps the live nft ruleset into rollbackFile in
// a form that `nft -f` can replay: a leading "flush ruleset" line followed
// by the dump, newline-terminated. If the write fails, the partially
// written file is removed (best effort) before the error is returned.
func writeFirewallRollbackFile(rollbackFile string) error {
	// #nosec G204 -- "nft list ruleset" is literal.
	dump, err := exec.Command("nft", "list", "ruleset").Output()
	if err != nil {
		return fmt.Errorf("capturing rollback ruleset: %w", err)
	}

	// The dump is nft syntax, so it is stored as data for nft -f. An empty
	// live ruleset still needs a rollback file; the flush line alone
	// restores that state.
	payload := append([]byte("flush ruleset\n"), dump...)
	if n := len(dump); n > 0 && dump[n-1] != '\n' {
		payload = append(payload, '\n')
	}

	// #nosec G306 -- root-only state dir; this is data, not an executable.
	writeErr := os.WriteFile(rollbackFile, payload, 0600)
	if writeErr == nil {
		return nil
	}
	_ = removeFileIfExists(rollbackFile)
	return fmt.Errorf("writing rollback ruleset: %w", writeErr)
}
// restoreFirewallRollback replays the rollback ruleset, but only while the
// confirm marker still exists; a missing marker means there is nothing to
// roll back and nil is returned. After a successful restore the marker, the
// rollback file and the legacy rollback file are all removed.
func restoreFirewallRollback(confirmFile, rollbackFile string) error {
	_, statErr := os.Stat(confirmFile)
	if os.IsNotExist(statErr) {
		return nil
	}
	if statErr != nil {
		return fmt.Errorf("checking confirm marker: %w", statErr)
	}

	restoreErr := applyFirewallRollbackFile(rollbackFile)
	if restoreErr != nil {
		return restoreErr
	}

	legacy := legacyRollbackFileFor(rollbackFile)
	return removeFirewallRollbackFiles(confirmFile, rollbackFile, legacy)
}
// applyFirewallRollbackFile loads rollbackFile into the kernel with
// `nft -f`. A missing file is reported as "rollback ruleset missing"; when
// nft fails, any trimmed combined output is appended to the error.
func applyFirewallRollbackFile(rollbackFile string) error {
	_, statErr := os.Stat(rollbackFile)
	switch {
	case statErr == nil:
		// File is there; go on to replay it.
	case os.IsNotExist(statErr):
		return fmt.Errorf("rollback ruleset missing")
	default:
		return fmt.Errorf("checking rollback ruleset: %w", statErr)
	}

	// #nosec G204 -- nft is hardcoded; rollbackFile is a CSM-written path.
	output, runErr := exec.Command("nft", "-f", rollbackFile).CombinedOutput()
	if runErr == nil {
		return nil
	}
	if detail := bytes.TrimSpace(output); len(detail) > 0 {
		return fmt.Errorf("restoring rollback ruleset: %w: %s", runErr, detail)
	}
	return fmt.Errorf("restoring rollback ruleset: %w", runErr)
}
// removeFirewallRollbackFiles deletes each path in order, tolerating files
// that are already gone. It stops at the first real failure and returns it,
// leaving any later paths untouched.
func removeFirewallRollbackFiles(paths ...string) error {
	for i := range paths {
		rmErr := removeFileIfExists(paths[i])
		if rmErr != nil {
			return rmErr
		}
	}
	return nil
}
// legacyRollbackFileFor returns the path of the legacy "rollback.sh" file
// in the same directory as rollbackFile.
func legacyRollbackFileFor(rollbackFile string) string {
	dir := filepath.Dir(rollbackFile)
	legacy := filepath.Join(dir, "rollback.sh")
	return legacy
}
// removeFileIfExists deletes path and treats "already absent" as success.
// Any other failure is wrapped with the file's base name.
func removeFileIfExists(path string) error {
	err := os.Remove(path)
	if err == nil || os.IsNotExist(err) {
		return nil
	}
	return fmt.Errorf("removing %s: %w", filepath.Base(path), err)
}
package daemon
import (
"encoding/json"
"fmt"
"net"
"time"
"github.com/pidginhost/csm/internal/control"
)
// Single-IP firewall mutation handlers. Each validates args, guards on
// c.d.fwEngine != nil, calls the matching engine method, and returns a
// FirewallAckResult with a human-readable message the CLI prints verbatim.
func (c *ControlListener) handleFirewallBlock(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	// The IP is validated before the engine check, so a bad address is
	// reported even when the firewall is disabled.
	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Blocked via CLI"
	}

	// Operator-initiated: bypass auto_response.dry_run gate.
	blockErr := c.d.fwEngine.BlockIPForce(args.IP, why, 0)
	if blockErr != nil {
		return nil, fmt.Errorf("block %s: %w", args.IP, blockErr)
	}

	ack := control.FirewallAckResult{Message: fmt.Sprintf("Blocked %s - %s", args.IP, why)}
	return ack, nil
}
// handleFirewallUnblock removes the block on a single IP. The address is
// validated first, then the engine must be present.
func (c *ControlListener) handleFirewallUnblock(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	unblockErr := c.d.fwEngine.UnblockIP(args.IP)
	if unblockErr != nil {
		return nil, fmt.Errorf("unblock %s: %w", args.IP, unblockErr)
	}

	ack := control.FirewallAckResult{Message: fmt.Sprintf("Unblocked %s", args.IP)}
	return ack, nil
}
// handleFirewallAllow allow-lists a single IP. If the engine reports that
// the address is still covered by a blocked subnet, the acknowledgement
// carries a warning naming that subnet.
func (c *ControlListener) handleFirewallAllow(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Allowed via CLI"
	}

	allowErr := c.d.fwEngine.AllowIP(args.IP, why)
	if allowErr != nil {
		return nil, fmt.Errorf("allow %s: %w", args.IP, allowErr)
	}

	summary := fmt.Sprintf("Allowed %s - %s", args.IP, why)
	// Checked after the allow succeeds so the warning reflects current state.
	cidr, covered := c.d.fwEngine.BlockedSubnetCovering(args.IP)
	if covered {
		summary += fmt.Sprintf(" (WARNING: still dropped by blocked subnet %s; unblock the subnet for this allow to take effect)", cidr)
	}
	return control.FirewallAckResult{Message: summary}, nil
}
// handleFirewallRemoveAllow takes a single IP off the allow list. The
// address is validated first, then the engine must be present.
func (c *ControlListener) handleFirewallRemoveAllow(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	removeErr := c.d.fwEngine.RemoveAllowIP(args.IP)
	if removeErr != nil {
		return nil, fmt.Errorf("remove-allow %s: %w", args.IP, removeErr)
	}

	ack := control.FirewallAckResult{Message: fmt.Sprintf("Removed %s from allow list", args.IP)}
	return ack, nil
}
// handleFirewallAllowPort allow-lists an IP for one port. The protocol
// defaults to "tcp" when omitted and must be "tcp" or "udp"; the port must
// be in 1..65535. Validation errors are reported in the order IP, port,
// proto, and only then is a missing engine reported.
func (c *ControlListener) handleFirewallAllowPort(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallPortArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	proto := args.Proto
	if proto == "" {
		proto = "tcp"
	}

	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case args.Port <= 0 || args.Port > 65535:
		return nil, fmt.Errorf("invalid port: %d", args.Port)
	case proto != "tcp" && proto != "udp":
		return nil, fmt.Errorf("invalid proto: %q (want tcp or udp)", args.Proto)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	why := args.Reason
	if why == "" {
		why = "Port-allowed via CLI"
	}

	allowErr := c.d.fwEngine.AllowIPPort(args.IP, args.Port, proto, why)
	if allowErr != nil {
		return nil, fmt.Errorf("allow-port %s %s:%d: %w", args.IP, proto, args.Port, allowErr)
	}

	ack := control.FirewallAckResult{
		Message: fmt.Sprintf("Allowed %s on %s:%d - %s", args.IP, proto, args.Port, why),
	}
	return ack, nil
}
// handleFirewallRemovePort removes a per-port allow for an IP. The protocol
// defaults to "tcp" when omitted and must be "tcp" or "udp"; the port must
// be in 1..65535. Validation errors are reported in the order IP, port,
// proto, and only then is a missing engine reported.
func (c *ControlListener) handleFirewallRemovePort(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallPortArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	proto := args.Proto
	if proto == "" {
		proto = "tcp"
	}

	switch {
	case net.ParseIP(args.IP) == nil:
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	case args.Port <= 0 || args.Port > 65535:
		return nil, fmt.Errorf("invalid port: %d", args.Port)
	case proto != "tcp" && proto != "udp":
		return nil, fmt.Errorf("invalid proto: %q (want tcp or udp)", args.Proto)
	case c.d.fwEngine == nil:
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}

	removeErr := c.d.fwEngine.RemoveAllowIPPort(args.IP, args.Port, proto)
	if removeErr != nil {
		return nil, fmt.Errorf("remove-port %s %s:%d: %w", args.IP, proto, args.Port, removeErr)
	}

	ack := control.FirewallAckResult{
		Message: fmt.Sprintf("Removed port-allow for %s on %s:%d", args.IP, proto, args.Port),
	}
	return ack, nil
}
// handleFirewallTempBan blocks a single IP for a bounded time.
//
// args.Timeout is required and must parse as a strictly positive Go
// duration (e.g. "30m", "2h"). Validation order is: timeout present, timeout
// parses, timeout positive, IP valid, engine present. On success the
// acknowledgement message names the IP, the duration and the reason.
func (c *ControlListener) handleFirewallTempBan(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	// Parse timeout FIRST so callers get a duration-parse error before
	// the engine-nil check (the unit test depends on this ordering).
	if args.Timeout == "" {
		return nil, fmt.Errorf("tempban requires timeout")
	}
	timeout, err := time.ParseDuration(args.Timeout)
	if err != nil {
		return nil, fmt.Errorf("parsing duration %q: %w", args.Timeout, err)
	}
	// ParseDuration accepts "0s" and negative values. The permanent-block
	// handler passes a zero duration to BlockIPForce, so letting zero through
	// here would hand the engine the same arguments as a permanent block.
	// Reject anything that is not a real, positive ban length.
	if timeout <= 0 {
		return nil, fmt.Errorf("invalid timeout %q: must be positive", args.Timeout)
	}
	if net.ParseIP(args.IP) == nil {
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	}
	if c.d.fwEngine == nil {
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}
	reason := args.Reason
	if reason == "" {
		reason = "Temp-banned via CLI"
	}
	// Operator-initiated: bypass auto_response.dry_run gate.
	if err := c.d.fwEngine.BlockIPForce(args.IP, reason, timeout); err != nil {
		return nil, fmt.Errorf("tempban %s: %w", args.IP, err)
	}
	return control.FirewallAckResult{
		Message: fmt.Sprintf("Temp-banned %s for %s - %s", args.IP, timeout, reason),
	}, nil
}
// handleFirewallTempAllow allow-lists a single IP for a bounded time.
//
// args.Timeout is required and must parse as a strictly positive Go
// duration (e.g. "30m", "2h"). Validation order is: timeout present, timeout
// parses, timeout positive, IP valid, engine present. On success the
// acknowledgement message names the IP, the duration and the reason.
func (c *ControlListener) handleFirewallTempAllow(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallIPArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	// Timeout is checked before the IP and the engine, matching tempban.
	if args.Timeout == "" {
		return nil, fmt.Errorf("tempallow requires timeout")
	}
	timeout, err := time.ParseDuration(args.Timeout)
	if err != nil {
		return nil, fmt.Errorf("parsing duration %q: %w", args.Timeout, err)
	}
	// ParseDuration accepts "0s" and negative values, neither of which is a
	// meaningful length for a temporary allow. Reject them here rather than
	// rely on how the engine interprets a non-positive duration.
	if timeout <= 0 {
		return nil, fmt.Errorf("invalid timeout %q: must be positive", args.Timeout)
	}
	if net.ParseIP(args.IP) == nil {
		return nil, fmt.Errorf("invalid ip: %q", args.IP)
	}
	if c.d.fwEngine == nil {
		return nil, fmt.Errorf("firewall disabled in csm.yaml")
	}
	reason := args.Reason
	if reason == "" {
		reason = "Temp-allowed via CLI"
	}
	if err := c.d.fwEngine.TempAllowIP(args.IP, reason, timeout); err != nil {
		return nil, fmt.Errorf("tempallow %s: %w", args.IP, err)
	}
	return control.FirewallAckResult{
		Message: fmt.Sprintf("Temp-allowed %s for %s - %s", args.IP, timeout, reason),
	}, nil
}
package daemon
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/firewall"
)
// Read-only firewall handlers: no state mutation, just surface what
// firewall.LoadState already has.
// fmtPortsSlice renders each port as its own decimal string, preserving
// order. The CLI joins the entries for display. An empty or nil input yields
// nil, so the JSON wire form is `null` / omitted rather than an empty array.
func fmtPortsSlice(ports []int) []string {
	if len(ports) == 0 {
		return nil
	}
	rendered := make([]string, 0, len(ports))
	for _, port := range ports {
		rendered = append(rendered, strconv.Itoa(port))
	}
	return rendered
}
// handleFirewallStatus reports the firewall configuration together with
// counts from the persisted state and the ten most recently blocked IPs
// (newest first). When the config has no firewall section, only the counts
// are filled in and Enabled stays false, so the CLI prints a disabled status
// rather than an error.
func (c *ControlListener) handleFirewallStatus(_ json.RawMessage) (any, error) {
	cfg := c.d.currentCfg()

	// LoadState tolerates a missing file (returns empty state).
	state, err := firewall.LoadState(cfg.StatePath)
	if err != nil {
		return nil, fmt.Errorf("loading firewall state: %w", err)
	}

	// The counts are reported whether or not a firewall section exists.
	result := control.FirewallStatusResult{
		BlockedCount:    len(state.Blocked),
		BlockedNetCount: len(state.BlockedNet),
		AllowedCount:    len(state.Allowed),
	}

	// cfg.Firewall is nil when the operator omitted the block entirely.
	fw := cfg.Firewall
	if fw == nil {
		return result, nil
	}

	result.Enabled = fw.Enabled
	result.TCPIn = fmtPortsSlice(fw.TCPIn)
	result.TCPOut = fmtPortsSlice(fw.TCPOut)
	result.UDPIn = fmtPortsSlice(fw.UDPIn)
	result.UDPOut = fmtPortsSlice(fw.UDPOut)
	result.Restricted = fmtPortsSlice(fw.RestrictedTCP)
	result.PassiveFTPStart = fw.PassiveFTPStart
	result.PassiveFTPEnd = fw.PassiveFTPEnd
	result.InfraIPCount = len(fw.InfraIPs)
	result.SYNFlood = fw.SYNFloodProtection
	result.ConnRateLimit = fw.ConnRateLimit
	result.LogDropped = fw.LogDropped
	result.LogRate = fw.LogRate

	// Recent blocked: walk the list from the end, at most 10 entries.
	// Times are emitted as RFC3339 UTC strings so the wire schema is not
	// tied to Go's time.Time encoding.
	blocked := state.Blocked
	for shown := 0; shown < 10 && shown < len(blocked); shown++ {
		b := blocked[len(blocked)-1-shown]
		entry := control.FirewallBlockedEntry{
			IP:        b.IP,
			Reason:    b.Reason,
			BlockedAt: b.BlockedAt.UTC().Format(time.RFC3339),
		}
		if !b.ExpiresAt.IsZero() {
			entry.ExpiresAt = b.ExpiresAt.UTC().Format(time.RFC3339)
		}
		result.RecentBlocked = append(result.RecentBlocked, entry)
	}
	return result, nil
}
// handleFirewallPorts renders the configured port lists as display lines:
// one titled section per list, each followed by a blank line. The
// restricted-TCP and passive-FTP sections appear only when configured.
// Without a firewall section the result carries no lines.
func (c *ControlListener) handleFirewallPorts(_ json.RawMessage) (any, error) {
	var lines []string

	fw := c.d.currentCfg().Firewall
	if fw == nil {
		return control.FirewallListResult{Lines: lines}, nil
	}

	lines = append(lines, "TCP Inbound (public):", " "+joinPorts(fw.TCPIn), "")
	if len(fw.RestrictedTCP) > 0 {
		lines = append(lines, "TCP Restricted (infra only):", " "+joinPorts(fw.RestrictedTCP), "")
	}
	lines = append(lines, "TCP Outbound:", " "+joinPorts(fw.TCPOut), "")
	lines = append(lines, "UDP Inbound:", " "+joinPorts(fw.UDPIn), "")
	lines = append(lines, "UDP Outbound:", " "+joinPorts(fw.UDPOut), "")

	// Passive FTP is the last section and has no trailing blank line.
	if fw.PassiveFTPStart > 0 {
		lines = append(lines,
			"Passive FTP:",
			fmt.Sprintf(" %d-%d", fw.PassiveFTPStart, fw.PassiveFTPEnd),
		)
	}
	return control.FirewallListResult{Lines: lines}, nil
}
// joinPorts renders ports as a single comma-separated string, or "(none)"
// when the list is empty. No wrapping is applied here: the wire payload
// keeps the full list so the CLI can format it for its own terminal width.
func joinPorts(ports []int) string {
	if len(ports) == 0 {
		return "(none)"
	}
	var b strings.Builder
	for i, port := range ports {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(strconv.Itoa(port))
	}
	return b.String()
}
// handleFirewallGrep searches the persisted firewall state and the
// configured infra IPs for a case-insensitive substring. Blocked and allowed
// entries and blocked subnets match on either the address/CIDR or the
// reason; infra IPs match on the address only. Results are display lines in
// the order blocked, allowed, subnets, infra.
func (c *ControlListener) handleFirewallGrep(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallGrepArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}
	cfg := c.d.currentCfg()

	// Empty pattern matches nothing: the CLI used to require a positional
	// arg and exit with usage, so return an empty result rather than
	// dumping everything.
	if args.Pattern == "" {
		return control.FirewallListResult{}, nil
	}

	needle := strings.ToLower(args.Pattern)
	hit := func(fields ...string) bool {
		for _, field := range fields {
			if strings.Contains(strings.ToLower(field), needle) {
				return true
			}
		}
		return false
	}

	state, err := firewall.LoadState(cfg.StatePath)
	if err != nil {
		return nil, fmt.Errorf("loading firewall state: %w", err)
	}

	var lines []string
	now := time.Now()

	for _, b := range state.Blocked {
		if !hit(b.IP, b.Reason) {
			continue
		}
		expiry := "permanent"
		if !b.ExpiresAt.IsZero() {
			expiry = fmt.Sprintf("%s left", b.ExpiresAt.Sub(now).Truncate(time.Minute))
		}
		age := now.Sub(b.BlockedAt).Truncate(time.Minute)
		lines = append(lines, fmt.Sprintf("BLOCKED %-18s (%s ago, %s) %s",
			b.IP, age, expiry, b.Reason))
	}

	for _, a := range state.Allowed {
		if !hit(a.IP, a.Reason) {
			continue
		}
		var portSuffix string
		if a.Port > 0 {
			portSuffix = fmt.Sprintf(" port:%d", a.Port)
		}
		lines = append(lines, fmt.Sprintf("ALLOWED %-18s%s %s",
			a.IP, portSuffix, a.Reason))
	}

	for _, s := range state.BlockedNet {
		if !hit(s.CIDR, s.Reason) {
			continue
		}
		age := now.Sub(s.BlockedAt).Truncate(time.Minute)
		lines = append(lines, fmt.Sprintf("SUBNET %-18s (%s ago) %s",
			s.CIDR, age, s.Reason))
	}

	if fw := cfg.Firewall; fw != nil {
		for _, infra := range fw.InfraIPs {
			if hit(infra) {
				lines = append(lines, fmt.Sprintf("INFRA %s", infra))
			}
		}
	}
	return control.FirewallListResult{Lines: lines}, nil
}
// handleFirewallAudit returns up to args.Limit firewall audit-log entries
// (50 when the limit is absent or non-positive), one formatted line each:
// timestamp, action, IP, then the optional duration and reason.
func (c *ControlListener) handleFirewallAudit(argsRaw json.RawMessage) (any, error) {
	var args control.FirewallAuditArgs
	if len(argsRaw) != 0 {
		if decodeErr := json.Unmarshal(argsRaw, &args); decodeErr != nil {
			return nil, fmt.Errorf("parsing args: %w", decodeErr)
		}
	}

	maxEntries := args.Limit
	if maxEntries <= 0 {
		maxEntries = 50
	}

	entries := firewall.ReadAuditLog(c.d.currentCfg().StatePath, maxEntries)

	// Always a non-nil slice, even when the log is empty.
	lines := make([]string, 0, len(entries))
	for _, e := range entries {
		var durPart, reasonPart string
		if e.Duration != "" {
			durPart = fmt.Sprintf(" (%s)", e.Duration)
		}
		if e.Reason != "" {
			reasonPart = fmt.Sprintf(" %s", e.Reason)
		}
		stamp := e.Timestamp.Format("2006-01-02 15:04:05")
		lines = append(lines, fmt.Sprintf("%s %-13s %-18s%s%s",
			stamp, e.Action, e.IP, durPart, reasonPart))
	}
	return control.FirewallListResult{Lines: lines}, nil
}
package daemon
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/firewall/rollback"
)
// handleFirewallRollbackStatus reports the rollback manager's current
// state. With no manager installed, a zero-value status is returned. The
// applied/expiry timestamps are emitted as RFC3339 strings and left empty
// when unset.
func (c *ControlListener) handleFirewallRollbackStatus(_ json.RawMessage) (any, error) {
	var out control.FirewallRollbackStatus

	mgr := rollback.Global()
	if mgr == nil {
		return out, nil
	}

	st := mgr.Status()
	out.Pending = st.Pending
	out.AppliedBy = st.AppliedBy
	out.PrevHash = st.PrevHash
	out.NewHash = st.NewHash
	out.SecondsRemaining = st.SecondsRemaining
	if !st.AppliedAt.IsZero() {
		out.AppliedAtRFC3339 = st.AppliedAt.Format(time.RFC3339)
	}
	if !st.ExpiresAt.IsZero() {
		out.ExpiresAtRFC3339 = st.ExpiresAt.Format(time.RFC3339)
	}
	return out, nil
}
// handleFirewallRollbackConfirm makes a pending firewall change permanent.
// A missing manager or the absence of a pending change is not an error; both
// are reported through the acknowledgement message.
func (c *ControlListener) handleFirewallRollbackConfirm(_ json.RawMessage) (any, error) {
	mgr := rollback.Global()

	var msg string
	switch {
	case mgr == nil:
		msg = "rollback manager not initialised"
	case !mgr.Status().Pending:
		msg = "no pending rollback"
	default:
		if err := mgr.Confirm(); err != nil {
			return nil, fmt.Errorf("confirm rollback: %w", err)
		}
		msg = "rollback confirmed; pending change is now permanent"
	}
	return control.FirewallAckResult{Message: msg}, nil
}
// handleFirewallRollbackRevert undoes a pending firewall change through the
// rollback manager, giving the revert at most 30 seconds. A missing manager
// or the absence of a pending change is not an error; both are reported
// through the acknowledgement message.
func (c *ControlListener) handleFirewallRollbackRevert(_ json.RawMessage) (any, error) {
	mgr := rollback.Global()

	var msg string
	switch {
	case mgr == nil:
		msg = "rollback manager not initialised"
	case !mgr.Status().Pending:
		msg = "no pending rollback"
	default:
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := mgr.Revert(ctx); err != nil {
			return nil, fmt.Errorf("revert rollback: %w", err)
		}
		msg = "rollback reverted; previous config restored, daemon restart issued"
	}
	return control.FirewallAckResult{Message: msg}, nil
}
package daemon
import (
"encoding/json"
"fmt"
"os"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// dispatch parses a raw request line, routes to the right handler, and
// wraps the handler's (result, error) pair in the response envelope.
// Unknown commands fail cleanly rather than crash the listener.
//
// Every failure mode is reported in-band as a Response with OK=false and a
// message in Error: malformed request JSON ("bad request: ..."), an
// unrecognised command, a handler error (its text verbatim), or a result
// that cannot be marshalled ("result marshal: ..."). On success the
// handler's result is JSON-encoded into Response.Result.
func (c *ControlListener) dispatch(line []byte) control.Response {
var req control.Request
if err := json.Unmarshal(line, &req); err != nil {
return control.Response{OK: false, Error: fmt.Sprintf("bad request: %v", err)}
}
var (
result any
err error
)
// One case per wire command; each handler receives the raw, still-encoded
// args and decodes them itself.
switch req.Cmd {
case control.CmdTierRun:
result, err = c.handleTierRun(req.Args)
case control.CmdStatus:
result, err = c.handleStatus(req.Args)
case control.CmdHistoryRead:
result, err = c.handleHistoryRead(req.Args)
case control.CmdRulesReload:
result, err = c.handleRulesReload(req.Args)
case control.CmdGeoIPReload:
result, err = c.handleGeoIPReload(req.Args)
case control.CmdBaseline:
result, err = c.handleBaseline(req.Args)
// Firewall: read-only queries.
case control.CmdFirewallStatus:
result, err = c.handleFirewallStatus(req.Args)
case control.CmdFirewallPorts:
result, err = c.handleFirewallPorts(req.Args)
case control.CmdFirewallGrep:
result, err = c.handleFirewallGrep(req.Args)
case control.CmdFirewallAudit:
result, err = c.handleFirewallAudit(req.Args)
// Firewall: single-IP mutations.
case control.CmdFirewallBlock:
result, err = c.handleFirewallBlock(req.Args)
case control.CmdFirewallUnblock:
result, err = c.handleFirewallUnblock(req.Args)
case control.CmdFirewallAllow:
result, err = c.handleFirewallAllow(req.Args)
case control.CmdFirewallRemoveAllow:
result, err = c.handleFirewallRemoveAllow(req.Args)
case control.CmdFirewallAllowPort:
result, err = c.handleFirewallAllowPort(req.Args)
case control.CmdFirewallRemovePort:
result, err = c.handleFirewallRemovePort(req.Args)
case control.CmdFirewallTempBan:
result, err = c.handleFirewallTempBan(req.Args)
case control.CmdFirewallTempAllow:
result, err = c.handleFirewallTempAllow(req.Args)
// Firewall: subnet, batch and whole-ruleset operations.
case control.CmdFirewallDenySubnet:
result, err = c.handleFirewallDenySubnet(req.Args)
case control.CmdFirewallRemoveSubnet:
result, err = c.handleFirewallRemoveSubnet(req.Args)
case control.CmdFirewallDenyFile:
result, err = c.handleFirewallDenyFile(req.Args)
case control.CmdFirewallAllowFile:
result, err = c.handleFirewallAllowFile(req.Args)
case control.CmdFirewallFlush:
result, err = c.handleFirewallFlush(req.Args)
case control.CmdFirewallRestart:
result, err = c.handleFirewallRestart(req.Args)
case control.CmdFirewallApplyConfirmed:
result, err = c.handleFirewallApplyConfirmed(req.Args)
case control.CmdFirewallConfirm:
result, err = c.handleFirewallConfirm(req.Args)
// Firewall: rollback-manager commands. Note that CmdFirewallRollbackOK
// maps to the handler named handleFirewallRollbackConfirm.
case control.CmdFirewallRollbackStatus:
result, err = c.handleFirewallRollbackStatus(req.Args)
case control.CmdFirewallRollbackOK:
result, err = c.handleFirewallRollbackConfirm(req.Args)
case control.CmdFirewallRollbackRevert:
result, err = c.handleFirewallRollbackRevert(req.Args)
case control.CmdStoreExport:
result, err = c.handleStoreExport(req.Args)
case control.CmdHistorySince:
result, err = c.handleHistorySince(req.Args)
// PHP relay commands.
case control.CmdPHPRelayStatus:
result, err = c.handlePHPRelayStatus(req.Args)
case control.CmdPHPRelayIgnoreScript:
result, err = c.handlePHPRelayIgnoreScript(req.Args)
case control.CmdPHPRelayUnignore:
result, err = c.handlePHPRelayUnignore(req.Args)
case control.CmdPHPRelayIgnoreList:
result, err = c.handlePHPRelayIgnoreList(req.Args)
case control.CmdPHPRelayDryRun:
result, err = c.handlePHPRelayDryRun(req.Args)
case control.CmdPHPRelayThaw:
result, err = c.handlePHPRelayThaw(req.Args)
// Incident commands.
case control.CmdIncidentsList:
result, err = c.handleIncidentsList(req.Args)
case control.CmdIncidentsShow:
result, err = c.handleIncidentsShow(req.Args)
case control.CmdIncidentsStatus:
result, err = c.handleIncidentsStatus(req.Args)
case control.CmdIncidentsBulkStatus:
result, err = c.handleIncidentsBulkStatus(req.Args)
default:
return control.Response{OK: false, Error: fmt.Sprintf("unknown command: %q", req.Cmd)}
}
// A handler error wins over any result it may also have returned.
if err != nil {
return control.Response{OK: false, Error: err.Error()}
}
payload, mErr := json.Marshal(result)
if mErr != nil {
return control.Response{OK: false, Error: "result marshal: " + mErr.Error()}
}
return control.Response{OK: true, Result: payload}
}
// parseTier translates the tier name received over the control socket
// into the matching checks.Tier constant. The mapping lives in the
// listener so the protocol package stays free of a checks dependency.
//
// An empty string is treated the same as "all". Any other unrecognised
// value yields an error that quotes the offending input.
func parseTier(s string) (checks.Tier, error) {
	if s == "" || s == "all" {
		return checks.TierAll, nil
	}
	if s == "critical" {
		return checks.TierCritical, nil
	}
	if s == "deep" {
		return checks.TierDeep, nil
	}
	return "", fmt.Errorf("unknown tier: %q", s)
}
// handleTierRun runs a tier synchronously and reports the result. The
// flow mirrors Daemon.runPeriodicChecks: integrity verify, RunTier,
// purge-and-merge, then hand findings to the alert pipeline. When
// Alerts=false the handler absorbs the old `csm check*` behaviour:
// skip auto-response for this run, append the raw findings to history,
// and return them in the response body so the CLI can render them
// verbatim.
//
// argsRaw is the JSON-encoded control.TierRunArgs; an empty payload
// leaves every field at its zero value (tier "" parses as "all",
// Alerts=false means dry-run).
//
// Returns a control.TierRunResult on success. Returns an error when the
// args do not parse, the tier name is unknown, or the integrity check
// fails.
func (c *ControlListener) handleTierRun(argsRaw json.RawMessage) (any, error) {
var args control.TierRunArgs
if len(argsRaw) > 0 {
if err := json.Unmarshal(argsRaw, &args); err != nil {
return nil, fmt.Errorf("parsing args: %w", err)
}
}
tier, err := parseTier(args.Tier)
if err != nil {
return nil, err
}
// dryRun ties together the three Alerts=false side effects:
// auto-response suppression, the post-run history append, and the
// FindingList in the response. Hoisting it makes the invariant
// "these three happen together" visually obvious.
dryRun := !args.Alerts
if vErr := c.verifyTierRunIntegrity(); vErr != nil {
// Integrity failures are escalated through the normal alert
// pipeline so the on-call path sees them regardless of who
// kicked the tier run. The client also gets an error so the
// systemd timer unit fails loudly.
if args.Alerts {
// Non-blocking send: if the alert channel is full the finding
// is dropped and counted instead of stalling this request.
select {
case c.d.alertCh <- alert.Finding{
Severity: alert.Critical,
Check: "integrity",
Message: fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", vErr),
Timestamp: time.Now(),
}:
default:
atomic.AddInt64(&c.d.droppedAlerts, 1)
}
}
return nil, fmt.Errorf("integrity verify failed: %w", vErr)
}
// Dry-run threads through RunTierDryRun, so a concurrent periodic
// scanner running in live mode never sees this caller's dry-run state.
start := time.Now()
var (
findings []alert.Finding
purgeChecks []string
)
// Snapshot the config once so the scan and the follow-up bookkeeping
// use the same value.
cfg := c.d.currentCfg()
if dryRun {
findings, purgeChecks = checks.RunTierDryRun(cfg, c.d.store, tier)
} else {
findings, purgeChecks = checks.RunTier(cfg, c.d.store, tier)
}
c.recordTierRunFindings(cfg, findings, purgeChecks, !dryRun, args.Alerts)
// Dry-run history + FindingList: the live path writes history via
// Daemon.runPeriodicChecks when the internal scanners fire; the
// pre-phase-2 `csm check*` wrote it via store.AppendHistory in
// cmd/csm/main.go. Preserve that quirk on the socket path.
if dryRun {
c.d.store.AppendHistory(findings)
}
// NewFindings is computed after the findings have been recorded above.
newCount := len(c.d.store.FilterNew(findings))
result := control.TierRunResult{
Findings: len(findings),
NewFindings: newCount,
ElapsedMs: time.Since(start).Milliseconds(),
}
if dryRun {
// A nil slice is replaced by an empty one -- presumably so the
// field encodes as [] rather than null for the CLI.
if findings == nil {
result.FindingList = []alert.Finding{}
} else {
result.FindingList = findings
}
}
return result, nil
}
// recordTierRunFindings persists a control-socket tier run's findings. Auto-fix
// is gated on a live run so a dry run never edits a customer's wp-config.php,
// and the alert push is gated separately on whether the caller asked for alerts.
//
// The latest-scan findings are always stored, regardless of live/alerts.
// live enables the WP-Cron auto-fix pass; alerts enqueues the findings
// on the alert pipeline tagged with the "control" source.
func (c *ControlListener) recordTierRunFindings(cfg *config.Config, findings []alert.Finding, purgeChecks []string, live, alerts bool) {
checks.StoreLatestScanFindings(c.d.store, purgeChecks, findings)
if live {
c.d.applyWPCronAutoFix(cfg, findings)
}
if alerts {
c.d.enqueueScanAlerts(findings, "control")
}
}
// verifyTierRunIntegrity runs integrity.Verify against the daemon's
// binary path and its current configuration, returning whatever error
// the verifier reports (nil when verification passes).
func (c *ControlListener) verifyTierRunIntegrity() error {
return integrity.Verify(c.d.binaryPath, c.d.currentCfg())
}
// handleStatus answers `csm status` from the running daemon's state
// rather than re-opening the store from disk.
//
// The result carries the daemon version, uptime in whole seconds (0 if
// the start time was never set), the latest scan time as an RFC 3339
// UTC string (empty if no scan time is recorded), the number of latest
// findings, the history count from the global store (0 when it is
// unavailable), the dropped-alert counter, and a health snapshot.
func (c *ControlListener) handleStatus(_ json.RawMessage) (any, error) {
	d := c.d

	latestFindings := d.store.LatestFindings()

	scanTime := ""
	if ts := d.store.LatestScanTime(); !ts.IsZero() {
		scanTime = ts.UTC().Format(time.RFC3339)
	}

	history := 0
	if db := store.Global(); db != nil {
		history = db.HistoryCount()
	}

	var uptimeSec int64
	if !d.startTime.IsZero() {
		uptimeSec = int64(time.Since(d.startTime).Seconds())
	}

	dropped := d.DroppedAlerts()
	snapshot := health.Build(d, d.version, health.Capabilities())

	return control.StatusResult{
		Version:        d.version,
		UptimeSec:      uptimeSec,
		LatestScanTime: scanTime,
		LatestFindings: len(latestFindings),
		HistoryCount:   history,
		DroppedAlerts:  dropped,
		Snapshot:       &snapshot,
	}, nil
}
// handleHistoryRead returns one page of history findings plus the total
// count. The page size is bounded so a buggy client cannot ask for
// everything at once: a Limit that is non-positive or above 1000 falls
// back to 100 (it is not clamped to 1000). A negative Offset is treated
// as 0.
func (c *ControlListener) handleHistoryRead(argsRaw json.RawMessage) (any, error) {
	var args control.HistoryReadArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}

	limit, offset := args.Limit, args.Offset
	if outOfRange := limit <= 0 || limit > 1000; outOfRange {
		limit = 100
	}
	if offset < 0 {
		offset = 0
	}

	page, total := c.d.store.ReadHistory(limit, offset)
	return control.HistoryReadResult{Findings: page, Total: total}, nil
}
// handleRulesReload replaces `kill -HUP $(pidof csm)`. Returns after
// the reload completes so the client can confirm it happened.
//
// The request arguments are ignored. The response is always
// {"status": "reloaded"} with a nil error; reloadSignatures returns no
// value, so a failed reload is not visible to the client from here.
func (c *ControlListener) handleRulesReload(_ json.RawMessage) (any, error) {
c.d.reloadSignatures()
return map[string]string{"status": "reloaded"}, nil
}
// handleGeoIPReload is the GeoIP equivalent of rules.reload.
//
// The request arguments are ignored. It calls publishGeoIP and then
// always answers {"status": "reloaded"} with a nil error.
func (c *ControlListener) handleGeoIPReload(_ json.RawMessage) (any, error) {
c.d.publishGeoIP()
return map[string]string{"status": "reloaded"}, nil
}
// handleHistorySince returns every history finding newer than the
// caller's cutoff, oldest first. `csm export --since` uses it for SIEM
// backfill; seeking with the daemon-side bbolt cursor is materially
// faster than paginating and filtering on the client for hosts with
// large histories.
//
// Since is mandatory and must be an RFC 3339 timestamp. The handler
// fails when the args do not parse, Since is missing or malformed, or
// the global bbolt store is unavailable.
func (c *ControlListener) handleHistorySince(argsRaw json.RawMessage) (any, error) {
	var args control.HistorySinceArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	if args.Since == "" {
		return nil, fmt.Errorf("since is required (RFC 3339)")
	}
	cutoff, err := time.Parse(time.RFC3339, args.Since)
	if err != nil {
		return nil, fmt.Errorf("parsing since: %w", err)
	}

	db := store.Global()
	if db == nil {
		return nil, fmt.Errorf("bbolt store not available")
	}

	// ReadHistorySince hands back newest-first; flip in place so JSONL
	// consumers see the same chronological order a live tail produces.
	findings := db.ReadHistorySince(cutoff)
	n := len(findings)
	for lo := 0; lo < n/2; lo++ {
		hi := n - 1 - lo
		findings[lo], findings[hi] = findings[hi], findings[lo]
	}
	return control.HistorySinceResult{Findings: findings}, nil
}
// handleStoreExport produces a tar+zstd backup holding the live bbolt
// snapshot, the state directory, and the signature-rules cache. The
// daemon decides every source path; the CLI supplies only the
// destination. Import intentionally does NOT go through the socket
// because it needs a stopped daemon.
//
// DstPath is mandatory. The handler fails when the args do not parse,
// DstPath is empty, the global bbolt store is unavailable, or the
// export itself returns an error. On success it reports the archive
// path, size, and the archive and bbolt SHA-256 digests.
func (c *ControlListener) handleStoreExport(argsRaw json.RawMessage) (any, error) {
	var args control.StoreExportArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	if args.DstPath == "" {
		return nil, fmt.Errorf("dst_path is required")
	}
	db := store.Global()
	if db == nil {
		return nil, fmt.Errorf("bbolt store not available")
	}

	cfg := c.d.currentCfg()
	// The lookup error is deliberately discarded: the manifest records
	// whatever name os.Hostname returned.
	host, _ := os.Hostname()
	plat := platform.Detect()

	manifest := store.Manifest{
		CSMVersion:     c.d.version,
		SourceHostname: host,
		SourcePlatform: map[string]string{
			"os":         string(plat.OS),
			"os_version": plat.OSVersion,
			"panel":      string(plat.Panel),
			"webserver":  string(plat.WebServer),
		},
	}
	opts := store.ExportOptions{
		StatePath: cfg.StatePath,
		RulesPath: cfg.Signatures.RulesDir,
		DstPath:   args.DstPath,
		Manifest:  manifest,
	}

	exported, err := db.Export(opts)
	if err != nil {
		return nil, err
	}
	return control.StoreExportResult{
		Path:          exported.Path,
		Bytes:         exported.Bytes,
		ArchiveSHA256: exported.ArchiveSHA256,
		BboltSHA256:   exported.BboltSHA256,
	}, nil
}
package daemon
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/incident"
)
// Page-size bounds for the incident control handlers. The default is
// used when the caller supplies a non-positive limit; the max is the
// ceiling a larger requested limit is clamped to.
const (
// Bounds applied by handleIncidentsList.
defaultIncidentListLimit = 100
maxIncidentListLimit = 1000
// Bounds applied by handleIncidentsBulkStatus.
defaultIncidentBulkLimit = 100
maxIncidentBulkLimit = 1000
)
// handleIncidentsList returns a bounded, newest-first incident page.
//
// Args (all optional): Status selects which incident statuses to
// include (see incidentListStatusFilter); Offset is the page start and
// is floored at 0; Limit is the page size, defaulting to
// defaultIncidentListLimit and clamped to maxIncidentListLimit; All
// overrides Limit with 0, which is passed through to the correlator.
//
// The result echoes the effective offset, limit and status label next
// to the page items and the total match count. Errors are returned for
// unparseable args (wrapped as "parsing args: ...", matching the other
// control handlers) and for an unknown status.
func (c *ControlListener) handleIncidentsList(argsRaw json.RawMessage) (any, error) {
	var args control.IncidentListArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	statuses, statusLabel, err := incidentListStatusFilter(args.Status)
	if err != nil {
		return nil, err
	}

	offset := args.Offset
	if offset < 0 {
		offset = 0
	}

	limit := args.Limit
	switch {
	case args.All:
		limit = 0
	case limit <= 0:
		limit = defaultIncidentListLimit
	case limit > maxIncidentListLimit:
		limit = maxIncidentListLimit
	}

	co := IncidentCorrelator()
	items, total := co.SnapshotPageStatuses(statuses, offset, limit)
	return control.IncidentListResult{
		Items:  items,
		Total:  total,
		Offset: offset,
		Limit:  limit,
		Status: statusLabel,
	}, nil
}
// incidentListStatusFilter turns the caller's status string into the
// set of incident statuses to list plus the label echoed back in the
// response. Matching ignores case and surrounding whitespace.
//
//   - "" or "all": nil filter, label "all"
//   - "active":    open + contained, label "active"
//   - open, contained, resolved, dismissed: that single status
//
// Anything else is an error that quotes the original input.
func incidentListStatusFilter(status string) ([]incident.Status, string, error) {
	normalized := strings.ToLower(strings.TrimSpace(status))
	if normalized == "" || normalized == "all" {
		return nil, "all", nil
	}
	if normalized == "active" {
		return []incident.Status{incident.StatusOpen, incident.StatusContained}, "active", nil
	}
	single := []incident.Status{
		incident.StatusOpen,
		incident.StatusContained,
		incident.StatusResolved,
		incident.StatusDismissed,
	}
	for _, st := range single {
		if normalized == string(st) {
			return []incident.Status{st}, string(st), nil
		}
	}
	return nil, "", fmt.Errorf("unknown status: %q", status)
}
// incidentBulkStatusFilter resolves the source-status selector of a
// bulk status change. Matching ignores case and surrounding whitespace.
// Only not-yet-closed statuses are accepted as a source:
//
//   - "" or "active": open + contained, label "active"
//   - open, contained: that single status
//
// Any other value is rejected.
func incidentBulkStatusFilter(status string) ([]incident.Status, string, error) {
	normalized := strings.ToLower(strings.TrimSpace(status))
	if normalized == "" || normalized == "active" {
		return []incident.Status{incident.StatusOpen, incident.StatusContained}, "active", nil
	}
	for _, st := range []incident.Status{incident.StatusOpen, incident.StatusContained} {
		if normalized == string(st) {
			return []incident.Status{st}, string(st), nil
		}
	}
	return nil, "", fmt.Errorf("bulk status source must be active, open, or contained")
}
// handleIncidentsShow returns one incident by id; ErrIncidentNotFound on miss.
//
// Unparseable args are reported as "parsing args: ..." (wrapping the
// decoder error), matching the other control handlers.
func (c *ControlListener) handleIncidentsShow(argsRaw json.RawMessage) (any, error) {
	var args control.IncidentShowArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	co := IncidentCorrelator()
	inc, ok := co.Get(args.ID)
	if !ok {
		return nil, incident.ErrIncidentNotFound
	}
	return inc, nil
}
// handleIncidentsStatus transitions an incident's status. Returns
// {"ok": true} on success; on failure it returns whatever error the
// correlator's SetStatus reports (e.g. ErrIncidentNotFound or a
// validation error).
//
// Unparseable args are reported as "parsing args: ..." (wrapping the
// decoder error), matching the other control handlers.
func (c *ControlListener) handleIncidentsStatus(argsRaw json.RawMessage) (any, error) {
	var args control.IncidentStatusArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	co := IncidentCorrelator()
	if err := co.SetStatus(args.ID, incident.Status(args.Status), args.Details); err != nil {
		return nil, err
	}
	return map[string]bool{"ok": true}, nil
}
// handleIncidentsBulkStatus moves a filtered batch of incidents to a
// closed status (resolved or dismissed), or previews that change.
//
// Args:
//   - Status: source statuses, see incidentBulkStatusFilter (default active).
//   - To: target status; defaults to resolved, must be resolved or dismissed.
//   - OlderThanSeconds / LastSeenBefore: age filter; at least one is
//     required. OlderThanSeconds must not be negative and must be small
//     enough to be expressed as a time.Duration.
//   - Kind, Domain, Account, Mailbox, Details: passed to the correlator
//     after trimming surrounding whitespace.
//   - Limit: batch size, defaulting to defaultIncidentBulkLimit and
//     clamped to maxIncidentBulkLimit.
//   - Apply / Confirm: without Apply the call is a dry run; Apply
//     without Confirm is rejected.
//
// The result echoes the effective parameters together with the matched
// and updated counts and the affected items. Unparseable args are
// reported as "parsing args: ..." like the other control handlers.
func (c *ControlListener) handleIncidentsBulkStatus(argsRaw json.RawMessage) (any, error) {
	// Largest whole-second count whose nanosecond value still fits in a
	// time.Duration. Anything above it would silently wrap around when
	// multiplied by time.Second and turn into an arbitrary cutoff for a
	// bulk-mutating operation.
	const maxOlderThanSeconds = time.Duration(1<<63-1) / time.Second

	var args control.IncidentBulkStatusArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	statuses, statusLabel, err := incidentBulkStatusFilter(args.Status)
	if err != nil {
		return nil, err
	}

	to := incident.Status(strings.ToLower(strings.TrimSpace(args.To)))
	if to == "" {
		to = incident.StatusResolved
	}
	if to != incident.StatusResolved && to != incident.StatusDismissed {
		return nil, fmt.Errorf("bulk status target must be resolved or dismissed")
	}

	if args.OlderThanSeconds < 0 {
		return nil, fmt.Errorf("older-than must be positive")
	}
	if time.Duration(args.OlderThanSeconds) > maxOlderThanSeconds {
		return nil, fmt.Errorf("older-than is too large")
	}
	olderThan := time.Duration(args.OlderThanSeconds) * time.Second
	if olderThan <= 0 && args.LastSeenBefore.IsZero() {
		return nil, fmt.Errorf("bulk status requires --older-than or --last-seen-before")
	}

	limit := args.Limit
	if limit <= 0 {
		limit = defaultIncidentBulkLimit
	}
	if limit > maxIncidentBulkLimit {
		limit = maxIncidentBulkLimit
	}

	// Mutating runs need an explicit second opt-in from the caller.
	if args.Apply && !args.Confirm {
		return nil, fmt.Errorf("bulk status apply requires confirmation")
	}
	dryRun := !args.Apply

	co := IncidentCorrelator()
	res, err := co.BulkSetStatus(incident.BulkStatusFilter{
		FromStatuses:   statuses,
		To:             to,
		OlderThan:      olderThan,
		LastSeenBefore: args.LastSeenBefore,
		Kind:           incident.Kind(strings.TrimSpace(args.Kind)),
		Domain:         strings.TrimSpace(args.Domain),
		Account:        strings.TrimSpace(args.Account),
		Mailbox:        strings.TrimSpace(args.Mailbox),
		Limit:          limit,
		DryRun:         dryRun,
		Details:        strings.TrimSpace(args.Details),
	})
	if err != nil {
		return nil, err
	}
	return control.IncidentBulkStatusResult{
		DryRun:           dryRun,
		Matched:          res.Matched,
		Updated:          res.Updated,
		Limit:            limit,
		Status:           statusLabel,
		To:               string(to),
		OlderThanSeconds: args.OlderThanSeconds,
		LastSeenBefore:   args.LastSeenBefore,
		Items:            res.Items,
	}, nil
}
package daemon
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"time"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/obs"
)
// controlSocketPath is the Unix socket the daemon binds for
// CLI-to-daemon IPC. A var (not const) so tests can redirect under
// t.TempDir(); production default matches internal/control.DefaultSocketPath.
// NewControlListener creates the parent directory, removes any stale
// file at this path, binds it and chmods it; Stop removes it again.
var controlSocketPath = "/var/run/csm/control.sock"
// controlRequestTimeout caps how long the I/O phases of a single client
// request may take. handleConnection applies it twice: as the read
// deadline for the request line, and again as the write deadline for
// the response. The handler body runs between those two deadlines on
// the connection's own goroutine, so handlers that legitimately take
// longer (tier.run on a large server) are not bounded by it.
const controlRequestTimeout = 2 * time.Second
// ControlListener serves the local command-line client over a Unix
// socket. One request and one response per connection, line-framed JSON.
// The daemon keeps exclusive ownership of the bbolt store; this listener
// is the only reason any CLI command needs to reach into daemon state.
type ControlListener struct {
// d is the owning daemon whose state the handlers read and mutate.
d *Daemon
// listener is the bound Unix socket created by NewControlListener.
listener net.Listener
phprelay *PHPRelayController // wired by Phase O2; may be nil in tests / pre-wiring
}
// NewControlListener binds the control socket at controlSocketPath and
// returns a listener ready to be run by the daemon.
//
// It creates the socket's parent directory (mode 0750) if needed,
// removes any leftover socket file, binds the Unix socket and restricts
// it to mode 0600. If the chmod fails the freshly bound listener is
// closed before the error is returned.
func NewControlListener(d *Daemon) (*ControlListener, error) {
	if err := os.MkdirAll(filepath.Dir(controlSocketPath), 0750); err != nil {
		return nil, fmt.Errorf("creating socket dir: %w", err)
	}

	// A socket file left behind by a crashed run would make the bind
	// fail with EADDRINUSE. The file is process-owned, so drop it first;
	// a failure here (typically "does not exist") is irrelevant.
	_ = os.Remove(controlSocketPath)

	sock, err := net.Listen("unix", controlSocketPath)
	if err != nil {
		return nil, fmt.Errorf("listening on %s: %w", controlSocketPath, err)
	}

	// Root-only 0600: the CLI client runs as root too (fanotify,
	// nftables and the cpanel APIs all need it), so no group access.
	if chmodErr := os.Chmod(controlSocketPath, 0600); chmodErr != nil {
		_ = sock.Close()
		return nil, fmt.Errorf("chmod socket: %w", chmodErr)
	}

	cl := &ControlListener{d: d, listener: sock}
	return cl, nil
}
// Run is the accept loop. It keeps accepting until an Accept error is
// observed after stopCh has been closed. Every accepted connection is
// served on its own goroutine so a slow request never stalls the loop.
//
// Accept errors seen while stopCh is still open are logged and retried
// after a 100ms pause.
func (c *ControlListener) Run(stopCh <-chan struct{}) {
	for {
		conn, acceptErr := c.listener.Accept()
		if acceptErr != nil {
			select {
			case <-stopCh:
				return
			default:
			}
			csmlog.Warn("control listener accept error", "err", acceptErr)
			time.Sleep(100 * time.Millisecond)
			continue
		}

		// SO_PEERCRED as defence-in-depth. The 0600 socket mode is the
		// primary barrier, but should a future install pattern relax
		// it -- or a confused-deputy mount expose the socket to a
		// non-root namespace -- the kernel-supplied credentials let us
		// turn away any non-root caller before a single payload byte
		// is read. The kernel cannot be lied to about these.
		if peerErr := verifyControlPeer(conn); peerErr != nil {
			csmlog.Warn("control listener rejecting peer", "err", peerErr)
			_ = conn.Close()
			continue
		}

		obs.SafeGo("control-conn", func() { c.handleConnection(conn) })
	}
}
// verifyControlPeer is defined per-OS:
// - linux: reads SO_PEERCRED and refuses non-root callers
// - other: no-op (the listener is Linux-only in production; macOS dev
// builds still need the package to compile)
// Stop closes the listener and removes the socket file. Safe to call
// after Run has already returned. Errors from both steps are ignored:
// this is best-effort teardown and there is nothing useful a caller
// could do with them.
func (c *ControlListener) Stop() {
_ = c.listener.Close()
_ = os.Remove(controlSocketPath)
}
// handleConnection reads one request, dispatches it, writes one
// response, and closes. The short timeout applies to I/O only; the
// handler itself can take as long as the underlying work requires.
//
// If the request line cannot be read (read deadline hit, line larger
// than the scanner buffer, transport error) the client receives an
// {"ok":false,"error":"reading request: ..."} line instead of a bare
// connection close. A peer that disconnects without sending anything
// gets no response.
func (c *ControlListener) handleConnection(conn net.Conn) {
	// Same wire shape as control.Response's failure case, used when no
	// richer response can be produced.
	const marshalFallback = `{"ok":false,"error":"internal: response marshal failed"}`

	defer func() { _ = conn.Close() }()

	// Read deadline. Writes use a separate deadline set after the
	// handler returns so slow scans don't count against the reader.
	_ = conn.SetReadDeadline(time.Now().Add(controlRequestTimeout))
	scanner := bufio.NewScanner(conn)
	// Requests and responses are single-line JSON; on a compromised
	// server a dry-run tier scan can legitimately return a findings
	// list that exceeds the old 1 MiB cap (think: many WordPress
	// installs, each with multiple infected files). The socket is
	// root-only 0600 so the original DoS guard no longer applies --
	// cap the buffer at 16 MiB so large finding lists round-trip.
	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)

	var (
		payload []byte
		err     error
	)
	if scanner.Scan() {
		resp := c.dispatch(scanner.Bytes())
		payload, err = json.Marshal(resp)
	} else {
		readErr := scanner.Err()
		if readErr == nil {
			// Clean EOF: the peer closed without sending a request.
			return
		}
		// Scan stopped on a real error; tell the client why rather
		// than leaving it to guess from a closed connection.
		payload, err = json.Marshal(struct {
			OK    bool   `json:"ok"`
			Error string `json:"error"`
		}{Error: "reading request: " + readErr.Error()})
	}
	if err != nil {
		// Marshalling a response should not fail; if it does, fall back
		// to a minimal error response the client can still parse.
		payload = []byte(marshalFallback)
	}
	payload = append(payload, '\n')

	_ = conn.SetWriteDeadline(time.Now().Add(controlRequestTimeout))
	// Best-effort write: the connection is closed right after, and a
	// client that went away cannot be told about it anyway.
	_, _ = conn.Write(payload)
}
//go:build linux
package daemon
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// controlPeerRequiredUID is the uid a connecting peer must have for
// verifyControlPeer to accept it. It is left at its zero value, i.e.
// uid 0 (root). Declared as a var rather than a const -- presumably so
// tests can override it; confirm against the test files.
var controlPeerRequiredUID uint32
// verifyControlPeer fetches the peer's SO_PEERCRED credentials from the
// kernel and rejects any caller whose uid differs from
// controlPeerRequiredUID (root by default). A nil return means the peer
// is acceptable.
//
// Errors are returned when conn is not a *net.UnixConn, when the raw
// descriptor or the socket option cannot be read, when the kernel
// returns no credentials, or when the uid does not match.
func verifyControlPeer(conn net.Conn) error {
	unixConn, isUnix := conn.(*net.UnixConn)
	if !isUnix || unixConn == nil {
		return fmt.Errorf("peer credentials: unsupported connection %T", conn)
	}
	rawConn, err := unixConn.SyscallConn()
	if err != nil {
		return fmt.Errorf("peer raw conn: %w", err)
	}

	var (
		cred    *unix.Ucred
		credErr error
	)
	ctrlErr := rawConn.Control(func(fd uintptr) {
		// #nosec G115 -- POSIX fd fits in int on Linux.
		cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	})

	switch {
	case ctrlErr != nil:
		return fmt.Errorf("peer credentials: %w", ctrlErr)
	case credErr != nil:
		return fmt.Errorf("peer credentials: %w", credErr)
	case cred == nil:
		return fmt.Errorf("peer credentials: empty result")
	case cred.Uid != controlPeerRequiredUID:
		return fmt.Errorf("peer uid=%d, want %d", cred.Uid, controlPeerRequiredUID)
	}
	return nil
}
package daemon
import (
"net"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
)
// infraHostnames picks the hostname entries out of the operator's
// cfg.InfraIPs list, i.e. everything that is neither a literal IP nor a
// CIDR. Those names are DNS-refreshed into the engine's infra-block
// guard, so operators can list panel hosts by name and keep them
// protected while the underlying address rotates.
//
// Entries are trimmed of surrounding whitespace; blank entries are
// dropped. The result is never nil and preserves input order.
func infraHostnames(entries []string) []string {
	hosts := make([]string, 0, len(entries))
	for _, raw := range entries {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		_, _, cidrErr := net.ParseCIDR(candidate)
		isAddress := cidrErr == nil || net.ParseIP(candidate) != nil
		if !isAddress {
			hosts = append(hosts, candidate)
		}
	}
	return hosts
}
// containsString reports whether s is an element of haystack. It walks
// the slice linearly: the call site loops short DynDNS host lists, for
// which building a map would cost more than it saves.
func containsString(haystack []string, s string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == s {
			return true
		}
	}
	return false
}
// expandWithCorrelation feeds a dispatch batch through cross-account
// correlation and appends every synthesized finding the batch does not
// already contain. The scan runner may have produced the very same
// synthetic findings earlier, so the helper has to be idempotent;
// otherwise the first batch would alert twice.
//
// Correlation findings already in the batch, and synthesized findings,
// get now as their timestamp when theirs is zero. Note that the former
// are stamped in place, i.e. in the caller's slice. Duplicates are
// detected via Finding.Key().
func expandWithCorrelation(findings []alert.Finding, now time.Time) []alert.Finding {
	known := make(map[string]struct{})
	for i := range findings {
		existing := &findings[i]
		if isCorrelationFinding(existing.Check) {
			if existing.Timestamp.IsZero() {
				existing.Timestamp = now
			}
			known[existing.Key()] = struct{}{}
		}
	}

	synthesized := checks.CorrelateFindings(findings)
	for i := range synthesized {
		candidate := &synthesized[i]
		if candidate.Timestamp.IsZero() {
			candidate.Timestamp = now
		}
		k := candidate.Key()
		if _, dup := known[k]; !dup {
			known[k] = struct{}{}
			findings = append(findings, *candidate)
		}
	}
	return findings
}
// isCorrelationFinding reports whether check names one of the synthetic
// cross-account correlation findings ("coordinated_attack" or
// "cross_account_malware").
func isCorrelationFinding(check string) bool {
	return check == "coordinated_attack" || check == "cross_account_malware"
}
package daemon
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/emailspool"
"gopkg.in/yaml.v3"
)
// userDomainsResolver resolves a cPanel user to the lowercased, IDN-normalised
// set of domains the account owns. TTL-based cache; safe for concurrent use.
type userDomainsResolver struct {
// root is the userdata directory; each user's file is <root>/<user>/main.
root string
// ttl is how long a cache entry (including a cached error) stays valid.
ttl time.Duration
// mu guards cache.
mu sync.Mutex
// cache maps a user name to its most recent lookup result.
cache map[string]userDomainsCacheEntry
}
// userDomainsCacheEntry is one cached lookup result of
// userDomainsResolver. Failed lookups are cached as well, so the error
// is replayed until the entry expires or is invalidated.
type userDomainsCacheEntry struct {
// domains is the normalised domain set returned by the lookup.
domains map[string]struct{}
// fetched is when the lookup completed; the TTL is measured from it.
fetched time.Time
// err is the lookup error, nil on success.
err error
}
// newUserDomainsResolver returns a resolver reading from /var/cpanel/userdata/
// with a five-minute cache TTL.
// Wired by daemon startup in O2; kept now so that future call sites compile.
//
//nolint:unused // consumed by daemon wiring (Task O2)
func newUserDomainsResolver() *userDomainsResolver {
return newUserDomainsResolverWithRoot("/var/cpanel/userdata", 5*time.Minute)
}
// newUserDomainsResolverWithRoot builds a resolver that reads userdata
// from below root and keeps lookup results for ttl. The cache starts
// empty.
func newUserDomainsResolverWithRoot(root string, ttl time.Duration) *userDomainsResolver {
return &userDomainsResolver{
root: root,
ttl: ttl,
cache: make(map[string]userDomainsCacheEntry),
}
}
// Domains returns the set of domains the cPanel user is authorised for.
// A result younger than the resolver's TTL is served from the cache,
// and that includes a cached error: unreadable userdata is reported
// again without touching the disk until the entry expires. Callers must
// treat an error as "skip the From-mismatch signal" and never as
// grounds to amplify.
//
// An empty user is rejected immediately and not cached. The userdata
// read happens outside the lock, so concurrent misses for the same user
// may each read the file; the last writer's entry stays in the cache.
func (r *userDomainsResolver) Domains(user string) (map[string]struct{}, error) {
	if user == "" {
		return nil, errors.New("empty user")
	}

	r.mu.Lock()
	cached, hit := r.cache[user]
	r.mu.Unlock()
	if hit && time.Since(cached.fetched) < r.ttl {
		return cached.domains, cached.err
	}

	domains, readErr := r.read(user)
	fresh := userDomainsCacheEntry{domains: domains, fetched: time.Now(), err: readErr}

	r.mu.Lock()
	r.cache[user] = fresh
	r.mu.Unlock()
	return domains, readErr
}
// Invalidate drops user's cache entry so the next Domains call reads
// the userdata again. Intended to be wired to inotify on
// /var/cpanel/userdata/<user>/. Invalidating an unknown user is a no-op.
func (r *userDomainsResolver) Invalidate(user string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.cache, user)
}
// read loads <root>/<user>/main, parses it as YAML and returns the
// normalised set of the account's main, addon (map keys), parked and
// sub domains. Blank entries are skipped.
//
// On a read or parse failure it returns an empty, non-nil set together
// with an error that names the file.
func (r *userDomainsResolver) read(user string) (map[string]struct{}, error) {
	mainFile := filepath.Join(r.root, user, "main")
	// #nosec G304 -- r.root is fixed at /var/cpanel/userdata/, user is a
	// cPanel-managed account name validated by the resolver's caller; the
	// resulting path is constrained to the cpanel userdata tree.
	content, err := os.ReadFile(mainFile)
	if err != nil {
		return map[string]struct{}{}, fmt.Errorf("read %s: %w", mainFile, err)
	}

	var doc struct {
		MainDomain    string            `yaml:"main_domain"`
		AddonDomains  map[string]string `yaml:"addon_domains"`
		ParkedDomains []string          `yaml:"parked_domains"`
		SubDomains    []string          `yaml:"sub_domains"`
	}
	if err := yaml.Unmarshal(content, &doc); err != nil {
		return map[string]struct{}{}, fmt.Errorf("parse %s: %w", mainFile, err)
	}

	// Gather every raw name first, then normalise them in a single pass.
	names := make([]string, 0, 1+len(doc.AddonDomains)+len(doc.ParkedDomains)+len(doc.SubDomains))
	names = append(names, doc.MainDomain)
	for addon := range doc.AddonDomains {
		names = append(names, addon)
	}
	names = append(names, doc.ParkedDomains...)
	names = append(names, doc.SubDomains...)

	domains := make(map[string]struct{}, 8)
	for _, name := range names {
		trimmed := strings.TrimSpace(name)
		if trimmed == "" {
			continue
		}
		// ExtractDomain expects an address-style input, so the bare host
		// is wrapped in a dummy mailbox; per the original author's note
		// it also takes care of IDN normalisation and lowercasing. Plain
		// lowercasing is the fallback when it yields nothing.
		normalised := emailspool.ExtractDomain("anyone@" + trimmed)
		if normalised == "" {
			normalised = strings.ToLower(trimmed)
		}
		domains[normalised] = struct{}{}
	}
	return domains, nil
}
// IsAuthorisedFromDomain reports whether fromDomain belongs to the
// user's authorised set. Subdomains are included: a From of
// sub.example.com is authorised when example.com is in the set, but NOT
// the other way round. An empty fromDomain or an empty set is never
// authorised.
func IsAuthorisedFromDomain(fromDomain string, authSet map[string]struct{}) bool {
	if fromDomain == "" {
		return false
	}
	// Ranging over an empty or nil set simply falls through to false.
	for owned := range authSet {
		if emailspool.IsSubdomainOrEqual(fromDomain, owned) {
			return true
		}
	}
	return false
}
package daemon
import (
"sort"
"time"
)
// credentialStuffingDetector watches auth failures per source IP and
// keeps, for each IP, the distinct accounts it failed against within a
// sliding window. An IP that reaches the configured number of distinct
// accounts is flagged. That is the breadth signature of credential
// stuffing / password spraying -- one source probing many accounts,
// often only once or twice each -- which the count-based pam_bruteforce
// detector (depth: many failures, any account) misses. CSM's auth
// sources never reveal the attempted password, so the detector keys on
// this distinct-account behaviour instead of a password fingerprint.
//
// Concurrency: the detector has no lock of its own. Callers serialise
// access (the PAM listener holds its mutex across Record).
type credentialStuffingDetector struct {
	// distinctAccounts is the firing threshold; never below 2.
	distinctAccounts int
	// window is how long a recorded (ip, account) failure stays counted.
	window time.Duration
	// now is the clock used by Record.
	now func() time.Time
	// perIP holds the live state, keyed by source IP.
	perIP map[string]*credStuffState
	// maxTrackedIPs caps perIP under sustained source-IP churn. When the
	// cap is reached, the entry with the oldest lastSeen is evicted
	// before a new IP is inserted.
	maxTrackedIPs int
}

// credStuffState is the per-IP bookkeeping of the detector.
type credStuffState struct {
	// accounts maps each account name to its most recent failure time.
	accounts map[string]time.Time
	// lastSeen is the time of the IP's most recent recorded failure.
	lastSeen time.Time
	// fired latches once the IP reaches the threshold, so one active
	// window produces a single finding rather than one per extra account.
	fired bool
}

// newCredentialStuffingDetector constructs a detector with an empty
// table and a 10000-IP tracking cap. Thresholds below 2 carry no
// breadth meaning and are raised to 2. A nil clock selects time.Now.
func newCredentialStuffingDetector(distinctAccounts int, window time.Duration, now func() time.Time) *credentialStuffingDetector {
	det := &credentialStuffingDetector{
		distinctAccounts: distinctAccounts,
		window:           window,
		now:              now,
		perIP:            make(map[string]*credStuffState),
		maxTrackedIPs:    10000,
	}
	if det.distinctAccounts < 2 {
		det.distinctAccounts = 2
	}
	if det.now == nil {
		det.now = time.Now
	}
	return det
}

// Record notes one auth failure for (ip, account). It returns the
// sorted list of distinct accounts together with true at the moment the
// IP reaches the threshold within the window; while the latch is set,
// further failures return (nil, false). Calls with an empty ip or
// account, or on a nil detector, are ignored.
//
// Accounts whose last failure is older than the window are dropped
// first. If that empties the IP's set, the IP starts from fresh state,
// so a new campaign does not inherit cold counts.
func (d *credentialStuffingDetector) Record(ip, account string) ([]string, bool) {
	if d == nil || ip == "" || account == "" {
		return nil, false
	}
	ts := d.now()

	entry, tracked := d.perIP[ip]
	if tracked {
		d.pruneAccountWindow(entry, ts)
		if len(entry.accounts) == 0 {
			delete(d.perIP, ip)
			entry = nil
		}
	}
	if entry == nil {
		// Make room before inserting a new IP when the table is full.
		if d.maxTrackedIPs > 0 && len(d.perIP) >= d.maxTrackedIPs {
			d.evictOldest()
		}
		entry = &credStuffState{accounts: make(map[string]time.Time)}
		d.perIP[ip] = entry
	}

	entry.accounts[account] = ts
	entry.lastSeen = ts

	switch {
	case len(entry.accounts) < d.distinctAccounts:
		// Below the threshold: make sure the latch is released.
		entry.fired = false
		return nil, false
	case entry.fired:
		return nil, false
	}

	entry.fired = true
	hit := make([]string, 0, len(entry.accounts))
	for name := range entry.accounts {
		hit = append(hit, name)
	}
	sort.Strings(hit)
	return hit, true
}

// PruneStale expires out-of-window accounts on every tracked IP and
// removes IPs that end up with no accounts at all. IPs that fall below
// the threshold have their latch released. The PAM listener's cleanup
// loop calls this so the table does not grow without bound between
// window resets. It returns how many IPs were removed.
func (d *credentialStuffingDetector) PruneStale(now time.Time) int {
	if d == nil {
		return 0
	}
	removed := 0
	for ip, entry := range d.perIP {
		d.pruneAccountWindow(entry, now)
		switch remaining := len(entry.accounts); {
		case remaining == 0:
			delete(d.perIP, ip)
			removed++
		case remaining < d.distinctAccounts:
			entry.fired = false
		}
	}
	return removed
}

// Clear forgets everything tracked for ip; used after a successful
// login. A nil detector or an empty ip is a no-op.
func (d *credentialStuffingDetector) Clear(ip string) {
	if d != nil && ip != "" {
		delete(d.perIP, ip)
	}
}

// Configure applies a new threshold and window to a live detector and
// then re-evaluates every tracked IP against them as of now. Thresholds
// below 2 are raised to 2. Callers hold their external mutex, matching
// Record's synchronisation contract.
func (d *credentialStuffingDetector) Configure(distinctAccounts int, window time.Duration, now time.Time) {
	if d == nil {
		return
	}
	if distinctAccounts < 2 {
		distinctAccounts = 2
	}
	d.distinctAccounts = distinctAccounts
	d.window = window
	// Same sweep as the periodic cleanup: expire accounts, drop empty
	// IPs, release latches that are now below the threshold.
	d.PruneStale(now)
}

// pruneAccountWindow removes from state every account whose last
// failure lies more than d.window before now. A nil state is ignored.
func (d *credentialStuffingDetector) pruneAccountWindow(state *credStuffState, now time.Time) {
	if state == nil {
		return
	}
	for name, seenAt := range state.accounts {
		if now.Sub(seenAt) <= d.window {
			continue
		}
		delete(state.accounts, name)
	}
}

// evictOldest deletes the tracked IP with the smallest lastSeen so the
// following insert stays within maxTrackedIPs. It scans the whole
// table, whose size is bounded by that cap.
func (d *credentialStuffingDetector) evictOldest() {
	var (
		victim   string
		victimAt time.Time
		found    bool
	)
	for ip, entry := range d.perIP {
		if !found || entry.lastSeen.Before(victimAt) {
			victim, victimAt, found = ip, entry.lastSeen, true
		}
	}
	if victim != "" {
		delete(d.perIP, victim)
	}
}
package daemon
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/auditd"
"github.com/pidginhost/csm/internal/broadcast"
"github.com/pidginhost/csm/internal/challenge"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/emailspool"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/firewall/rollback"
"github.com/pidginhost/csm/internal/geoip"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/integrity"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/maillog"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/sdnotify"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/threatintel"
"github.com/pidginhost/csm/internal/updatecheck"
"github.com/pidginhost/csm/internal/verdict"
"github.com/pidginhost/csm/internal/webui"
"github.com/pidginhost/csm/internal/yara"
"github.com/pidginhost/csm/internal/yaraworker"
)
// eximMainlogPath is the filesystem location of the Exim main log
// (exim_mainlog).
const eximMainlogPath = "/var/log/exim_mainlog"
// Daemon is the main persistent monitoring process.
//
// New fills cfg, store, lock, binaryPath, alertCh, stopCh and the three
// auth/probe trackers. Run sets startTime, scanCtx/scanCancel and
// findingBus. The remaining component fields are populated by start*/init*
// helpers that are defined outside this section of the file.
type Daemon struct {
cfg *config.Config
store *state.Store
lock *state.LockFile
binaryPath string
logWatchers []*LogWatcher
logWatchersMu sync.Mutex
fileMonitor *FileMonitor
hijackDetector *PasswordHijackDetector
pamListener *PAMListener
controlListener *ControlListener
spoolWatcher *SpoolWatcher
spoolWatcherMu sync.Mutex
forwarderWatcher *ForwarderWatcher
emailQuarantine *emailav.Quarantine
webServer *webui.Server
challengeServer *challenge.Server
ipList *challenge.IPList
challengeGate challenge.PortGate
// fwEngine is assigned through setFirewallEngine, which also publishes the
// engine to the firewall metrics gauges.
fwEngine *firewall.Engine
baselineMu sync.Mutex // serialises CmdBaseline handler runs
geoipDB *geoip.DB
geoipMu sync.Mutex // protects geoipDB for publishGeoIP
// version is set by SetVersion; registerBuildInfo uses it as the
// csm_build_info label ("unknown" when empty).
version string
// alertCh is created by New with a buffer of 500 findings.
alertCh chan alert.Finding
droppedAlerts int64 // atomic counter for alert channel backpressure drops
// stopCh is closed by Run when shutdown begins.
stopCh chan struct{}
scanCtx context.Context
scanCancel context.CancelFunc // cancels in-flight periodic scans on shutdown
abuseReportStop chan struct{}
abuseReportDone chan struct{}
// wg tracks worker goroutines; Run waits on it during shutdown before the
// final alert flush.
wg sync.WaitGroup
smtpAuthTracker *smtpAuthTracker
smtpProbeTracker *smtpProbeTracker
mailAuthTracker *mailAuthTracker
// startTime is stamped at the top of Run.
startTime time.Time
// yaraSup is the supervised YARA-X worker, wired up when
// yaraWorkerOn(cfg) is true (the default; see ROADMAP item 2).
// Nil when the in-process scanner is in use
// (cfg.Signatures.YaraWorkerEnabled explicitly set to false).
yaraSup *yaraworker.Supervisor
yaraCrashMu sync.Mutex
yaraLastCrashAlert time.Time
// forceFullRescan is armed by the signature watcher
// (sig_watch.go) when any tracked rule file's mtime advances.
// The deep-tier scheduler reads + clears the flag at the start
// of each tick; when set, the tick bypasses the fanotify
// short-list and runs the full account tree against the new
// ruleset.
forceFullRescan atomic.Bool
// policies holds the email PHP-relay pattern policies
// (suspicious/safe x-mailer classes, HTTP proxy ranges) loaded
// from EmailProtection.PHPRelay.PoliciesDir. Initialised in O2
// (daemon wiring); stays nil until then. The SIGHUP path
// nil-guards the Reload call so this commit is a no-op at
// runtime until O2 lands.
policies *emailspool.Policies
// PHP-relay components wired by startPHPRelay (Linux only). The fields are
// declared cross-platform but stay nil on non-cPanel or non-Linux hosts.
autoFreezer *autoFreezer
phpRelayShutdown []func() // ordered shutdown hooks; Run invokes them in reverse order
// watcherStatus tracks which top-level watchers have successfully attached.
// Keys are short stable names ("fanotify", "audit", "spool", "modsec",
// "afalg"). Values flip from false-to-true when the watcher's setup
// function completes without error. Used by /api/v1/status and the
// sd_notify gate.
//
// watcherChangedAt records the wall-clock time of the most recent state
// transition for the same key. Driven from MarkWatcher; consumed by the
// /api/v1/components endpoint so operators can see how long a watcher
// has been in its current state.
//
// watcherMu guards watcherStatus, watcherChangedAt and watcherUpstream.
watcherMu sync.RWMutex
watcherStatus map[string]bool
watcherChangedAt map[string]time.Time
// watcherUpstream maps a watcher name to a probe function that reports
// whether the upstream feeding the watcher is still active. The probe
// runs at most once per /api/v1/components scrape; results compose with
// WatcherStatuses to surface "deaf" (attached but no upstream traffic)
// distinct from "idle" (attached and quiet) in the dashboard.
watcherUpstream map[string]UpstreamProbe
// findingBus fans out dispatched findings to passive observers like
// the SSE event stream. Initialized in Run(); closed on shutdown.
findingBus *broadcast.Bus
// updateChecker polls upstream for new CSM releases. Wired in Run()
// when updates.check_enabled is true (default). Nil when disabled
// or before Run starts; UpdateInfo() handles that.
updateChecker *updatecheck.Checker
// automationActionCache memoises the newest automation-emitted
// finding so /api/v1/status does not run a 100-row history cursor on
// every poll. Invalidated after lastAutomationActionTTL elapses.
// automationActionCached is the time the cache was filled
// (presumably compared against the TTL; the reader is outside this
// section). automationActionMu guards both fields.
automationActionMu sync.Mutex
automationActionCache *health.AutomationAction
automationActionCached time.Time
}
// lastAutomationActionTTL is how long the memoised newest automation action
// (Daemon.automationActionCache) is kept before it is invalidated.
const lastAutomationActionTTL = 5 * time.Second
// New creates a new daemon instance.
//
// The returned Daemon carries the supplied config, state store, lock file
// and binary path, an alert channel buffered to 500 findings, an unbuffered
// stop channel, and the SMTP-auth, SMTP-probe and mail-auth trackers built
// from cfg.Thresholds (minute-valued settings are converted to
// time.Duration). cfg is dereferenced here, so it must be non-nil.
func New(cfg *config.Config, store *state.Store, lock *state.LockFile, binaryPath string) *Daemon {
	// Minute-based thresholds expressed as durations.
	smtpAuthWindow := time.Duration(cfg.Thresholds.SMTPBruteForceWindowMin) * time.Minute
	smtpAuthSuppress := time.Duration(cfg.Thresholds.SMTPBruteForceSuppressMin) * time.Minute
	smtpProbeWindow := time.Duration(cfg.Thresholds.SMTPProbeWindowMin) * time.Minute
	smtpProbeSuppress := time.Duration(cfg.Thresholds.SMTPProbeSuppressMin) * time.Minute
	mailAuthWindow := time.Duration(cfg.Thresholds.MailBruteForceWindowMin) * time.Minute
	mailAuthSuppress := time.Duration(cfg.Thresholds.MailBruteForceSuppressMin) * time.Minute

	smtpAuth := newSMTPAuthTracker(
		cfg.Thresholds.SMTPBruteForceThreshold,
		cfg.Thresholds.SMTPBruteForceSubnetThresh,
		cfg.Thresholds.SMTPAccountSprayThreshold,
		smtpAuthWindow,
		smtpAuthSuppress,
		cfg.Thresholds.SMTPBruteForceMaxTracked,
		time.Now,
	)
	smtpProbe := newSMTPProbeTracker(
		cfg.Thresholds.SMTPProbeThreshold,
		smtpProbeWindow,
		smtpProbeSuppress,
		cfg.Thresholds.SMTPProbeMaxTracked,
		time.Now,
		smtpProbeBlockExpiryString,
	)
	mailAuth := newMailAuthTracker(
		cfg.Thresholds.MailBruteForceThreshold,
		cfg.Thresholds.MailBruteForceSubnetThresh,
		cfg.Thresholds.MailAccountSprayThreshold,
		mailAuthWindow,
		mailAuthSuppress,
		cfg.Thresholds.MailBruteForceMaxTracked,
		time.Now,
	)

	return &Daemon{
		cfg:              cfg,
		store:            store,
		lock:             lock,
		binaryPath:       binaryPath,
		alertCh:          make(chan alert.Finding, 500),
		stopCh:           make(chan struct{}),
		smtpAuthTracker:  smtpAuth,
		smtpProbeTracker: smtpProbe,
		mailAuthTracker:  mailAuth,
	}
}
// SetVersion records the application version string on the daemon for
// display in the web UI. The write is not synchronised.
func (d *Daemon) SetVersion(v string) {
	d.version = v
}
// MarkWatcher stores whether the watcher called name is currently attached.
// Watcher startup paths call it with true on success and false on failure.
//
// watcherChangedAt[name] is stamped with the current time the first time a
// name is recorded and on every later flip of its state, so the components
// view can show "since"; repeating the same state leaves the stamp alone.
// Both maps are created lazily, and the whole update runs under watcherMu.
func (d *Daemon) MarkWatcher(name string, attached bool) {
	d.watcherMu.Lock()
	defer d.watcherMu.Unlock()

	if d.watcherChangedAt == nil {
		d.watcherChangedAt = make(map[string]time.Time)
	}
	if d.watcherStatus == nil {
		d.watcherStatus = make(map[string]bool)
	}

	last, seen := d.watcherStatus[name]
	d.watcherStatus[name] = attached
	if seen && last == attached {
		// Same state as before: keep the existing transition time.
		return
	}
	d.watcherChangedAt[name] = time.Now()
}
// WatcherStatuses returns a copy of the attachment state of every recorded
// watcher, taken under the read lock. The result is always non-nil and is
// safe for the caller to modify.
func (d *Daemon) WatcherStatuses() map[string]bool {
	d.watcherMu.RLock()
	defer d.watcherMu.RUnlock()

	snapshot := make(map[string]bool, len(d.watcherStatus))
	for name := range d.watcherStatus {
		snapshot[name] = d.watcherStatus[name]
	}
	return snapshot
}
// WatcherChangedAt returns a copy of the wall-clock time at which each
// watcher last changed state, taken under the read lock. A watcher with no
// recorded change is absent from the map, so looking it up yields the zero
// time.Time.
func (d *Daemon) WatcherChangedAt() map[string]time.Time {
	d.watcherMu.RLock()
	defer d.watcherMu.RUnlock()

	snapshot := make(map[string]time.Time, len(d.watcherChangedAt))
	for name := range d.watcherChangedAt {
		snapshot[name] = d.watcherChangedAt[name]
	}
	return snapshot
}
// UpstreamProbe reports whether a watcher's upstream input source is
// still feeding it. Return Fresh=false when the watcher is attached but
// can no longer hear from its source (PAM module not installed, log file
// rotated and never reappeared, fanotify marks lost, etc.) so the
// dashboard can surface "deaf" instead of conflating it with "idle".
//
// Probes are registered with RegisterUpstreamProbe and invoked by
// WatcherUpstream, which calls each one after releasing watcherMu and
// returns its health.UpstreamResult unchanged.
type UpstreamProbe func() health.UpstreamResult
// RegisterUpstreamProbe associates probe with the watcher called name,
// replacing any probe registered earlier under that name. It is safe to call
// from any watcher's startup path. Probes execute on the request thread of
// /api/v1/components, so they must be cheap (single stat / atomic load, not
// a syscall storm).
func (d *Daemon) RegisterUpstreamProbe(name string, probe UpstreamProbe) {
	d.watcherMu.Lock()
	defer d.watcherMu.Unlock()

	registry := d.watcherUpstream
	if registry == nil {
		// First registration: create the map lazily.
		registry = make(map[string]UpstreamProbe)
		d.watcherUpstream = registry
	}
	registry[name] = probe
}
// WatcherUpstream runs every registered upstream probe and returns the
// results keyed by watcher name. Watchers with no probe (or a nil probe) are
// absent from the map; the components API treats absence as "no probe wired,
// do not surface a deaf verdict for this watcher".
//
// The probe table is copied under the read lock and the probes themselves
// are called after the lock is released.
func (d *Daemon) WatcherUpstream() map[string]health.UpstreamResult {
	d.watcherMu.RLock()
	registered := make(map[string]UpstreamProbe, len(d.watcherUpstream))
	for name, fn := range d.watcherUpstream {
		registered[name] = fn
	}
	d.watcherMu.RUnlock()

	results := make(map[string]health.UpstreamResult, len(registered))
	for name, fn := range registered {
		if fn != nil {
			results[name] = fn()
		}
	}
	return results
}
// buildInfoOnce guards process-wide registration of the build_info
// gauge so repeated daemon construction in tests does not panic.
// Because registration runs only once, the version label comes from the
// first Daemon that calls registerBuildInfo.
var buildInfoOnce sync.Once
// storeSizeOnce guards the csm_store_size_bytes gauge hook so tests
// that create multiple daemons share a single registration.
// The gauge callback reads store.Global() at scrape time, so it is not tied
// to the Daemon that performed the registration.
var storeSizeOnce sync.Once
// registerBuildInfo publishes build metadata on /metrics in the usual
// Prometheus form: a gauge pinned at 1 whose labels carry the interesting
// fields so scrapers can join on them. The version label falls back to
// "unknown" when no version was set. Registration happens at most once per
// process (buildInfoOnce); later calls are no-ops.
func (d *Daemon) registerBuildInfo() {
	register := func() {
		label := d.version
		if label == "" {
			label = "unknown"
		}

		gauge := metrics.NewGaugeVec(
			"csm_build_info",
			"CSM build metadata. Value is always 1; read version from the label.",
			[]string{"version"},
		)
		gauge.With(label).Set(1)
		metrics.MustRegister("csm_build_info", gauge)
	}
	buildInfoOnce.Do(register)
}
// registerStoreSizeMetric publishes the on-disk size of the bbolt database
// as a gauge evaluated at scrape time. Nothing is cached: the expected
// scrape interval is 15+ seconds and stat is cheap. The gauge reports 0 when
// no global store is set or the file cannot be stat'ed. Registration happens
// at most once per process (storeSizeOnce).
func (d *Daemon) registerStoreSizeMetric() {
	sizeOnDisk := func() float64 {
		db := store.Global()
		if db == nil {
			return 0
		}
		if info, err := os.Stat(db.Path()); err == nil {
			return float64(info.Size())
		}
		return 0
	}

	storeSizeOnce.Do(func() {
		metrics.RegisterGaugeFunc(
			"csm_store_size_bytes",
			"On-disk size of the bbolt state database in bytes.",
			sizeOnDisk,
		)
	})
}
// firewallMetricsOnce guards /metrics registration of the firewall +
// blocked-IP gauges so repeated daemon starts in a test binary are
// idempotent.
var firewallMetricsOnce sync.Once
// firewallMetricsMu guards firewallMetricsEngine.
var firewallMetricsMu sync.RWMutex
// firewallMetricsEngine is the engine the firewall gauges read at scrape
// time. It is replaced through setFirewallMetricsEngine; while it is nil,
// firewallMetricsRuleCounts returns a zero-value firewall.RuleCounts.
var firewallMetricsEngine *firewall.Engine
// setFirewallMetricsEngine swaps the engine that the firewall gauges read at
// scrape time. Passing nil makes the gauges fall back to zero counts.
func setFirewallMetricsEngine(engine *firewall.Engine) {
	firewallMetricsMu.Lock()
	firewallMetricsEngine = engine
	firewallMetricsMu.Unlock()
}
// firewallMetricsRuleCounts returns the rule counts of the engine currently
// published for metrics, or a zero-value firewall.RuleCounts when no engine
// is set. The engine pointer is read under the read lock; RuleCounts itself
// is called after the lock is released.
func firewallMetricsRuleCounts() firewall.RuleCounts {
	firewallMetricsMu.RLock()
	current := firewallMetricsEngine
	firewallMetricsMu.RUnlock()

	var counts firewall.RuleCounts
	if current != nil {
		counts = current.RuleCounts()
	}
	return counts
}
// setFirewallEngine installs engine as the daemon's firewall engine and
// publishes the same pointer to the firewall metrics gauges so both views
// stay in step.
func (d *Daemon) setFirewallEngine(engine *firewall.Engine) {
	setFirewallMetricsEngine(engine)
	d.fwEngine = engine
}
// registerFirewallMetrics publishes two gauges: the number of blocked IPs
// and the total number of firewall rules (IPs + allowed + subnets + port
// allow entries). Both read the firewall engine at scrape time: the engine
// state file is authoritative, while the parallel bbolt fw:* buckets are
// written only at migration, so reading the store would freeze the gauge at
// the migration-time snapshot.
//
// Every call re-publishes d.fwEngine as the metrics engine; the gauges
// themselves are registered at most once per process (firewallMetricsOnce).
func (d *Daemon) registerFirewallMetrics() {
	setFirewallMetricsEngine(d.fwEngine)

	blockedIPs := func() float64 {
		return float64(firewallMetricsRuleCounts().Blocked)
	}
	totalRules := func() float64 {
		return float64(firewallMetricsRuleCounts().Total())
	}

	firewallMetricsOnce.Do(func() {
		metrics.RegisterGaugeFunc(
			"csm_blocked_ips_total",
			"Number of IPs currently on the firewall block list (excluding expired temp bans).",
			blockedIPs,
		)
		metrics.RegisterGaugeFunc(
			"csm_firewall_rules_total",
			"Total firewall rules across all categories (blocked IPs, allowed IPs, blocked subnets, port-specific allows).",
			totalRules,
		)
	})
}
// Run starts the daemon and blocks until stopped.
//
// Startup order, as executed below: logging and the scan context, the
// findings bus, platform overrides and detection, the ModSec registry,
// firewall rollback recovery, integrity verification, config publication,
// auditd rule self-heal, metrics plus the YARA/threat/attack databases, the
// firewall and challenge server, the real-time watchers and control
// surfaces, sd_notify READY, the synchronous initial critical-tier scan, and
// finally the alert dispatcher and the periodic background workers.
//
// Run then waits for signals. SIGHUP reloads config, signatures, PHP-relay
// policies, GeoIP and firewall rules and keeps running. SIGTERM or SIGINT
// begins shutdown: close stopCh, cancel in-flight scans, stop listeners,
// wait for tracked workers, flush pending alerts, close the store and
// release the lock.
//
// It returns a non-nil error when firewall rollback recovery, the startup
// integrity check, or the account-extractor install fails. It returns nil
// early when an expired firewall rollback was reverted (a restart has been
// issued), and nil after a normal shutdown.
func (d *Daemon) Run() error {
d.startTime = time.Now()
if d.store != nil {
d.store.EnsureBaseline(d.startTime)
}
// Initialize structured logging from environment (CSM_LOG_FORMAT,
// CSM_LOG_LEVEL). The default text handler preserves the legacy
// "[YYYY-MM-DD HH:MM:SS] msg" format so operators mixing csmlog
// with legacy fmt.Fprintf call sites see a uniform log stream.
// CSM_LOG_FORMAT=json switches to structured JSON for log shipping.
csmlog.Init()
csmlog.Info("CSM daemon starting")
// Periodic scans use this context so shutdown can abort an in-flight tier
// instead of blocking the worker drain for the full check budget.
scanCtx, scanCancel := context.WithCancel(context.Background())
d.scanCtx = scanCtx
d.scanCancel = scanCancel
defer scanCancel()
// Wire the active config to the incident-singleton's auto-close
// loop BEFORE the singleton is constructed, so the loop reads the
// operator-supplied thresholds on its first sweep. The closure
// captures the Daemon to pick up reloaded configs without restart.
SetIncidentConfigSource(func() *config.Config { return d.currentCfg() })
// Initialize the findings broadcast bus so passive observers (SSE, etc.)
// can subscribe before any findings are dispatched.
d.findingBus = broadcast.NewBus(64)
alert.FindingBus = d.findingBus
// Install config-supplied platform overrides BEFORE the first Detect()
// call so every check sees the merged view. Must happen before any
// other code calls platform.Detect() in this daemon run.
//
// WebServer is a *platform.WebServer pointer — take address only when
// the operator actually supplied a non-empty type, otherwise leave
// the auto-detected value alone.
var wsOverride *platform.WebServer
if t := d.cfg.WebServer.Type; t != "" {
ws := platform.WebServer(t)
wsOverride = &ws
}
platform.SetOverrides(platform.Overrides{
WebServer: wsOverride,
ApacheConfigDir: d.cfg.WebServer.ConfigDir,
AccessLogPaths: d.cfg.WebServer.AccessLogs,
ErrorLogPaths: d.cfg.WebServer.ErrorLogs,
ModSecAuditLogPaths: d.cfg.WebServer.ModSecAudits,
DomlogGlobs: d.cfg.WebServer.DomlogGlobs,
})
// Log detected platform as a structured record. In text mode this
// comes out as "[ts] platform detected os=X panel=Y ..."; in
// JSON mode as {"msg":"platform detected","os":"X","panel":"Y",...}.
pi := platform.Detect()
csmlog.Info("platform detected",
"os", orUnknown(string(pi.OS)),
"os_version", orUnknown(pi.OSVersion),
"panel", orNone(string(pi.Panel)),
"webserver", orNone(string(pi.WebServer)),
)
// Build the ModSec rule-action registry before any log watcher starts
// so the LiteSpeed classifier can tell pass-action vendor rules apart
// from real denies on the very first parsed line.
d.initModSecRegistry()
// Wire the firewall tentative-apply manager. Recovery has to run
// before integrity.Verify because a pending rollback whose deadline
// passed while the daemon was down restores the previous csm.yaml
// (and its integrity hash) to disk; verifying first would fail
// against the still-on-disk new config the operator never confirmed.
if sdb := store.Global(); sdb != nil {
mgr := rollback.NewManager(sdb, d.cfg.ConfigFile, rollback.SystemctlRestart, time.Now)
rollback.SetGlobal(mgr)
recoveryCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
reverted, rerr := mgr.RecoverOnStartup(recoveryCtx)
cancel()
if rerr != nil {
return fmt.Errorf("firewall rollback recovery failed: %w", rerr)
}
if reverted {
csmlog.Warn("firewall rollback expired during downtime; previous config restored, restart issued")
// systemctl restart csm.service from the manager will tear
// us down momentarily; bail out cleanly so we do not race
// the watchers we are about to start against the restart.
return nil
}
}
// Verify integrity on startup
if err := integrity.Verify(d.binaryPath, d.cfg); err != nil {
tamper := alert.Finding{
Severity: alert.Critical,
Check: "integrity",
Message: fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", err),
Timestamp: time.Now(),
}
_ = alert.Dispatch(d.cfg, []alert.Finding{tamper})
return fmt.Errorf("integrity check failed: %w", err)
}
// Publish the verified config as the process-wide live pointer.
// Hot paths (check ticks, alert dispatch, etc.) call
// config.Active() to pick up the current snapshot so a SIGHUP
// reload is visible on the next call without restart.
publishActiveConfig(d.cfg, "startup")
// Install the mail-brute account-key extractor selected by config.
// Validation in config.Load() already rejected invalid specs, so the
// error path here is defense-in-depth only.
if err := installAccountExtractorFromConfig(d.cfg); err != nil {
return err
}
// Self-heal the auditd rules file. Package upgrades sometimes ship
// a new csm binary without re-running auditd.Deploy() (postinstall
// hooks differ across apt/dnf and across operator deploy automation),
// which leaves new rules — including detection layers like
// csm_af_alg_socket — silently inactive on the upgraded host. The
// startup compare-and-redeploy here closes that gap. Errors are
// non-fatal: if auditd is absent or augenrules fails, the rest of
// CSM still runs.
if redeployed, err := auditd.EnsureDeployed(); err != nil {
csmlog.Warn("auditd rules ensure failed", "err", err)
} else if redeployed {
csmlog.Info("auditd rules redeployed (drift from embedded constant)")
}
// Deploy WHM plugin and configs if cPanel is present
deployConfigs()
// Initialize signature scanners and threat DB (fast, no I/O scan)
d.registerBuildInfo()
d.registerStoreSizeMetric()
d.registerFirewallMetrics()
checks.RegisterDirectSMTPEgressMetrics(metrics.Default())
RegisterBPFEnforcementMetrics(metrics.Default())
if err := d.initYaraBackend(); err != nil {
fmt.Fprintf(os.Stderr, "[%s] YARA backend init: %v\n", ts(), err)
}
checks.InitThreatDB(d.cfg.StatePath, d.cfg.Reputation.Whitelist)
if db := checks.GetThreatDB(); db != nil {
fmt.Fprintf(os.Stderr, "[%s] Threat DB initialized (%d entries)\n", ts(), db.Count())
}
if adb := attackdb.Init(d.cfg.StatePath); adb != nil {
// Seed from permanent blocklist on first run (when attack DB is empty)
if adb.TotalIPs() == 0 {
if n := adb.SeedFromPermanentBlocklist(d.cfg.StatePath); n > 0 {
fmt.Fprintf(os.Stderr, "[%s] Attack DB seeded %d IPs from permanent blocklist\n", ts(), n)
}
}
fmt.Fprintf(os.Stderr, "[%s] Attack DB initialized (%s)\n", ts(), adb.FormatTopLine())
}
// Start firewall engine if enabled
d.startFirewall()
// Construct the incident correlator after the firewall starts so the
// credential_spray block hand-off is installed before the singleton
// captures its config. This still happens before any finding is
// dispatched.
_ = IncidentCorrelator()
// Start challenge server if enabled (gray listing)
d.startChallengeServer()
// Start challenge escalation ticker
if d.ipList != nil {
d.wg.Add(1)
obs.Go("challenge-escalator", d.challengeEscalator)
}
// Create password hijack detector
d.hijackDetector = NewPasswordHijackDetector(d.cfg, d.alertCh, d.stopCh)
// Start inotify log watchers
d.startLogWatchers()
// Start PAM listener for real-time brute-force detection
d.startPAMListener()
// Start control socket listener for the thin-client CLI. The
// daemon is the sole bbolt owner; CLI commands that previously
// raced for the lock now route through this socket.
d.startControlListener()
// Start fanotify file monitor (real-time detection starts immediately)
d.startFileMonitor()
// Start email AV spool watcher (separate fanotify for Exim spool).
// Spool and forwarder watchers are cPanel-only; they watch paths
// (/var/spool/exim, /etc/valiases) that only exist on cPanel hosts.
if platform.Detect().IsCPanel() {
d.startSpoolWatcher()
d.startForwarderWatcher()
}
// Wire the update checker before the Web UI starts so the
// /api/v1/status handler always sees a non-nil checker (the
// goroutine that polls upstream still warms up for 5 minutes
// before the first poll, but UpdateInfo() returns zero values
// without racing on the field assignment).
d.startUpdateChecker()
// Start Web UI server - available immediately, before initial scan
d.startWebUI()
// Wire email quarantine to web server (after both start)
d.syncEmailAVWebState()
// Initialize GeoIP databases (after webServer so SetGeoIPDB can attach)
d.initGeoIP()
// Signal systemd we're up as soon as the real-time watchers and the
// control surfaces are attached. The initial baseline scan, the
// kernel-state probes, and the BPF tracker wiring below all run
// inline but no longer block `systemctl is-active` / `systemctl
// restart`. The watchdog notifier has to start in the same step so
// systemd's WatchdogSec doesn't trip while the baseline scan is
// still running on a large host.
d.wg.Add(1)
obs.Go("watchdog-notifier", d.watchdogNotifier)
if sent, err := sdnotify.Ready(); err != nil {
fmt.Fprintf(os.Stderr, "sd_notify READY failed: %v\n", err)
} else if sent {
fmt.Fprintf(os.Stderr, "sd_notify: daemon ready\n")
}
_, _ = sdnotify.Status(fmt.Sprintf("watchers attached: %d", countAttachedWatchers(d.WatcherStatuses())))
// Reconcile the opt-in email forward-guard to the current config (installs
// the exim rule when enabled+enforcing, removes it otherwise), then keep its
// bad-IP lookup fresh. Both no-op off cPanel and fail open on error.
d.reconcileForwardGuard()
d.wg.Add(1)
obs.Go("forward-guard-refresh", d.forwardGuardRefresher)
// Snapshot the live config ONCE for the entire initial-scan tick
// (detection + auto-response). Earlier code called d.currentCfg()
// for RunTier and then again later for the auto-response batch;
// a SIGHUP landing between the two reads split the same tick
// across old detection policy and new response policy.
initialCfg := d.currentCfg()
// Run initial scan synchronously (before dispatcher starts)
fmt.Fprintf(os.Stderr, "[%s] Running initial baseline scan...\n", ts())
initialFindings, initialPurge := checks.RunTierWithContext(d.scanContext(), initialCfg, d.store, checks.TierCritical)
// Seed the attack database with initial scan findings
if adb := attackdb.Global(); adb != nil {
for _, f := range initialFindings {
adb.RecordFinding(f)
}
}
// NOTE(review): d.store is nil-checked before EnsureBaseline at the top of
// Run but is used without a check from here on (and in Close at shutdown).
// Confirm a nil store is not a supported configuration for Run, or that
// *state.Store methods tolerate a nil receiver.
d.store.AppendHistory(initialFindings)
newFindings := d.store.FilterNew(initialFindings)
suppressions := d.store.LoadSuppressions()
initialAutoResponseFindings := initialFindings
if len(suppressions) > 0 {
initialAutoResponseFindings = filterUnsuppressedFindings(d.store, initialFindings, suppressions)
newFindings = filterUnsuppressedFindings(d.store, newFindings, suppressions)
}
// Permission auto-fix runs on ALL findings (not just new) because
// it's safe/idempotent and should fix baseline findings too.
permActions, permFixedKeys := checks.AutoFixPermissions(initialCfg, initialAutoResponseFindings)
// Challenge routing runs on ALL findings unconditionally when enabled.
challengeActions := checks.ChallengeRouteIPs(initialCfg, initialAutoResponseFindings)
// Other auto-response only on new findings
if len(newFindings) > 0 {
killActions := checks.AutoKillProcesses(initialCfg, newFindings)
quarantineActions := checks.AutoQuarantineFiles(initialCfg, newFindings)
// NOTE(review): unlike kill/quarantine, AutoBlockIPs is fed
// initialAutoResponseFindings (every unsuppressed finding) rather than
// newFindings, although the comment above says "only on new findings".
// Confirm this is intentional.
blockActions := checks.AutoBlockIPs(initialCfg, initialAutoResponseFindings)
newFindings = append(newFindings, killActions...)
newFindings = append(newFindings, quarantineActions...)
newFindings = append(newFindings, permActions...)
newFindings = append(newFindings, challengeActions...)
newFindings = append(newFindings, blockActions...)
// Cross-account correlation runs on the initial batch too, not
// just on subsequent ticks. Otherwise three account compromises
// landing in the first scan slip past with no synthetic alert.
newFindings = expandWithCorrelation(newFindings, time.Now())
co := IncidentCorrelator()
for _, f := range newFindings {
_, _, _ = co.OnFinding(f)
}
_ = alert.Dispatch(initialCfg, newFindings)
}
// Remove auto-fixed findings before storing to UI
if len(permFixedKeys) > 0 {
fixedSet := make(map[string]bool, len(permFixedKeys))
for _, k := range permFixedKeys {
fixedSet[k] = true
}
var filtered []alert.Finding
for _, f := range initialFindings {
key := f.Check + ":" + f.Message
if !fixedSet[key] {
filtered = append(filtered, f)
}
}
initialFindings = filtered
}
d.store.Update(initialFindings)
d.store.MarkAlerted(newFindings)
// Merge initial scan findings into the existing set. Previous deep scan
// results (outdated_plugins, wp_core, etc.) persist across restarts until
// the next deep scan replaces them. ClearLatestFindings is NOT called
// here - it would wipe deep scan findings that haven't re-run yet.
checks.StoreLatestScanFindings(d.store, initialPurge, initialFindings)
csmlog.Info("initial scan complete", "findings", len(initialFindings), "new", len(newFindings))
// NOW start the alert dispatcher - no more race with initial scan
d.wg.Add(1)
obs.Go("alert-dispatcher", d.alertDispatcher)
// Retrospective cloud-relay scan: replay the last 24h of exim_mainlog
// through the compromise-detection rule so any in-progress credential
// abuse is surfaced within seconds of daemon start, not after the
// realtime watcher sees a new line. Gated on cPanel because
// exim_mainlog is cPanel-specific; safe to run in a goroutine so
// startup isn't delayed by parsing a large log.
if platform.Detect().IsCPanel() {
d.wg.Add(1)
obs.Go("cloud-relay-retro-scan", func() {
defer d.wg.Done()
cfg := d.currentCfg()
retro := ScanEximHistoryForCloudRelay(cfg, "", time.Now(), 24*time.Hour)
for _, f := range retro {
// Enqueue the finding FIRST; only after it is
// accepted by the dispatcher do we trigger the
// account-suspend side-effect. This prevents a
// silent mailbox suspension if the daemon begins
// shutting down between these two operations.
select {
case d.alertCh <- f:
case <-d.stopCh:
return
}
sender := extractSenderFromCloudRelayMessage(f.Message)
if sender == "" {
continue
}
handleCloudRelayCredentialAbuse(cfg, sender)
}
})
}
// Start periodic scanners
d.wg.Add(1)
obs.Go("critical-scanner", d.criticalScanner)
d.wg.Add(1)
obs.Go("deep-scanner", d.deepScanner)
// Async bot-rDNS verifier: runs PTR+forward-A verification for
// claimed search-engine bot IPs that are not in a static range.
// Gated on reputation.bot_verify_enabled (default true) and on a
// non-nil store so the result can be persisted.
if d.cfg.BotVerifyEnabled() {
if db := store.Global(); db != nil {
if dropped, err := db.EnsureBotVerifyLogicVersion(threatintel.LogicVersion); err != nil {
csmlog.Warn("bot-verify cache version check failed", "err", err)
} else if dropped {
csmlog.Info("bot-verify cache dropped after logic upgrade",
"logic_version", threatintel.LogicVersion)
}
bv := threatintel.NewAsyncBotVerifier(db.PutBotVerify)
d.wg.Add(1)
obs.Go("bot-verify", func() {
defer d.wg.Done()
bv.Run(d.stopCh)
})
checks.SetBotVerifier(bv, db.GetBotVerify)
}
}
// Live AF_ALG listener (Copy Fail / CVE-2026-31431) — only started
// when the kernel is actually exploitable. Hosts with a KernelCare
// livepatch covering CVE-2026-31431, OR built without the AF_ALG
// aead interface entirely, skip the listener: there's nothing to
// detect, and the inotify watch + 500ms tick would just burn cycles.
// The hardening audit + periodic critical-tier check stay active
// either way, so re-introduction of the vulnerability (e.g., a
// kernel rollback) is still surfaced via the slower path.
kstate := checks.ObserveAFAlgKernelState()
switch {
case !kstate.IsCopyFailExploitable():
csmlog.Info("af_alg live listener: skipped",
"reason", "kernel not exploitable",
"state", kstate.String(),
)
default:
if mon := StartAFAlgLiveMonitor(d.alertCh, d.cfg); mon == nil {
csmlog.Warn("af_alg live listener: not started",
"reason", "no backend available",
"state", kstate.String(),
)
d.MarkWatcher("afalg", false)
} else {
csmlog.Info("af_alg live listener: started",
"backend", mon.Mode(),
"state", kstate.String(),
)
d.wg.Add(1)
obs.Go("af-alg-listener", func() {
defer d.wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { <-d.stopCh; cancel() }()
mon.Run(ctx)
})
d.MarkWatcher("afalg", true)
}
}
d.startPHPRelay()
if mon := StartConnectionTracker(d.alertCh, d.cfg); mon != nil {
csmlog.Info("connection_tracker: started", "backend", mon.Mode())
d.wg.Add(1)
obs.Go("connection-tracker", func() {
defer d.wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { <-d.stopCh; cancel() }()
mon.Run(ctx)
})
}
if mon := StartExecMonitor(d.alertCh, d.cfg); mon != nil {
csmlog.Info("exec_monitor: started", "backend", mon.Mode())
d.wg.Add(1)
obs.Go("exec-monitor", func() {
defer d.wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { <-d.stopCh; cancel() }()
mon.Run(ctx)
})
}
if mon := StartSensitiveFileMonitor(d.alertCh, d.cfg, d.store); mon != nil {
csmlog.Info("sensitive_files: started", "backend", mon.Mode())
d.wg.Add(1)
obs.Go("sensitive-files", func() {
defer d.wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { <-d.stopCh; cancel() }()
mon.Run(ctx)
})
}
// Start automatic signature updates
d.wg.Add(1)
obs.Go("signature-updater", d.signatureUpdater)
// Start signature mtime watcher: arms forceFullRescan when any
// rule file's mtime advances. Disabled wholesale via
// detection.rescan_on_signature_update: false.
d.wg.Add(1)
obs.Go("signature-watcher", d.signatureWatcher)
d.wg.Add(1)
obs.Go("geoip-updater", d.geoipUpdater)
// Start heartbeat
d.wg.Add(1)
obs.Go("heartbeat", d.heartbeat)
// Start abuse reporting (opt-in). startAbuseReporting installs the alert
// hook and returns the spool drain loop, or nil when disabled. The reporter
// stops after the final shutdown alert flush so late findings can still be
// queued before the bbolt spool closes.
if reportLoop := d.startAbuseReporting(); reportLoop != nil {
obs.Go("abuse-reporter", reportLoop)
}
// Start the central scored-set consumer (opt-in). Maintains the verified
// set and escalates findings whose IP is listed; nil when disabled.
if centralLoop := d.startCentralConsume(); centralLoop != nil {
d.wg.Add(1)
obs.Go("central-intel", func() {
defer d.wg.Done()
centralLoop()
})
}
// Start the retention sweep only when opted in. Compaction is not
// run from this goroutine; see internal/daemon/retention.go for why.
if d.cfg != nil && d.cfg.Retention.Enabled {
d.wg.Add(1)
obs.Go("retention-scanner", d.retentionScanner)
}
// Refresh the sd_notify status line now that AF_ALG, BPF trackers,
// and the periodic scanners have all reported in. The initial
// READY=1 fired earlier so systemctl restart didn't have to wait
// on the baseline scan; this is the operator-visible summary.
_, _ = sdnotify.Status(fmt.Sprintf("watchers attached: %d", countAttachedWatchers(d.WatcherStatuses())))
csmlog.Info("CSM daemon running")
// Wait for signals
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
for sig := range sigCh {
if sig == syscall.SIGHUP {
fmt.Fprintf(os.Stderr, "[%s] SIGHUP received - reloading config and rules\n", ts())
d.reloadConfig()
d.reloadSignatures()
if d.policies != nil {
if err := d.policies.Reload(d.cfg.EmailProtection.PHPRelay.PoliciesDir); err != nil {
// Previous valid version stays in effect; surface the failure
// via the existing alert pipeline so operators see partial reload.
d.emitReloadFinding(alert.Warning, "email_php_relay_policies_reload",
fmt.Sprintf("policies/email reload encountered errors: %v", err))
}
}
d.publishGeoIP()
if d.fwEngine != nil {
if err := d.fwEngine.Apply(); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Firewall reload error: %v\n", ts(), err)
} else {
fmt.Fprintf(os.Stderr, "[%s] Firewall rules reloaded\n", ts())
}
}
continue
}
break // SIGTERM or SIGINT
}
csmlog.Info("shutting down")
shutdownStart := time.Now()
close(d.stopCh)
// Abort any in-flight periodic scan so d.wg.Wait below is not held for a
// whole tier. Scanners observe d.stopCh between cycles; this cancels the
// scan currently executing inside RunTier.
if d.scanCancel != nil {
d.scanCancel()
}
// Log watchers own their files and close them when their Run loop exits on
// d.stopCh; d.wg.Wait below blocks until that happens. Calling Stop here
// would race the still-running Run on w.file.
if d.webServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = d.webServer.Shutdown(ctx)
cancel()
}
if d.challengeServer != nil {
d.challengeServer.Shutdown()
}
if d.challengeGate != nil {
if err := d.challengeGate.Close(); err != nil {
fmt.Fprintf(os.Stderr, "[%s] challenge port-gate close: %v\n", ts(), err)
}
d.challengeGate = nil
}
if d.fileMonitor != nil {
d.fileMonitor.Stop()
}
if sw := d.getSpoolWatcher(); sw != nil {
sw.Stop()
}
if d.pamListener != nil {
d.pamListener.Stop()
}
if d.controlListener != nil {
d.controlListener.Stop()
}
d.stopYaraBackend()
csmlog.Info("watchers signalled", "elapsed_ms", time.Since(shutdownStart).Milliseconds())
d.wg.Wait()
csmlog.Info("workers drained", "elapsed_ms", time.Since(shutdownStart).Milliseconds())
// Some producers can finish a tick after alertDispatcher observes stopCh.
// Drain again once tracked workers are gone and before state is closed.
d.flushPendingAlertsOnShutdown()
d.stopAbuseReporting()
if d.findingBus != nil {
d.findingBus.Close()
alert.FindingBus = nil
}
// PHP-relay shutdown hooks run in reverse registration order.
for i := len(d.phpRelayShutdown) - 1; i >= 0; i-- {
d.phpRelayShutdown[i]()
}
d.phpRelayShutdown = nil
if adb := attackdb.Global(); adb != nil {
adb.Stop()
}
// Stop the incident auto-close and retention goroutines before closing the
// store so neither writes to an already-closed bbolt database.
StopIncidentBackgroundLoops()
if err := d.store.Close(); err != nil {
fmt.Fprintf(os.Stderr, "[%s] error closing state store: %v\n", ts(), err)
}
d.lock.Release()
csmlog.Info("daemon stopped", "elapsed_ms", time.Since(shutdownStart).Milliseconds())
fmt.Fprintf(os.Stderr, "[%s] CSM daemon stopped\n", ts())
return nil
}
// DroppedAlerts returns the total number of alerts dropped due to
// channel backpressure since the daemon started.
//
// The counter is read with an atomic load, matching the atomic.AddInt64
// increments performed by producers whose non-blocking send on the alert
// channel falls through to the default branch.
func (d *Daemon) DroppedAlerts() int64 {
return atomic.LoadInt64(&d.droppedAlerts)
}
// scanContext returns the context that scans should run under: the
// daemon's scan context when one has been set, otherwise
// context.Background().
func (d *Daemon) scanContext() context.Context {
	if ctx := d.scanCtx; ctx != nil {
		return ctx
	}
	return context.Background()
}
// FindingBus returns the per-daemon broadcast.Bus used by passive
// observers like the SSE event stream. Returns nil before Run starts.
//
// The field is returned as-is, without locking or copying.
func (d *Daemon) FindingBus() *broadcast.Bus {
return d.findingBus
}
// alertDispatcher is the batching loop that sits between producers writing
// to d.alertCh and dispatchBatch. Findings accumulate in memory and are
// handed to dispatchBatch on a fixed 5-second tick. When d.stopCh closes,
// anything still queued is drained and handed to
// persistPendingFindingsOnShutdown instead of being dispatched.
func (d *Daemon) alertDispatcher() {
	defer d.wg.Done()

	flush := time.NewTicker(5 * time.Second)
	defer flush.Stop()

	var pending []alert.Finding
	for {
		select {
		case finding := <-d.alertCh:
			pending = append(pending, finding)
		case <-flush.C:
			// Nothing collected during this window; wait for the next tick.
			if len(pending) == 0 {
				continue
			}
			d.dispatchBatch(pending)
			pending = nil
		case <-d.stopCh:
			pending = d.drainAlertChannel(pending)
			d.persistPendingFindingsOnShutdown(pending)
			return
		}
	}
}
// drainAlertChannel appends every finding that is immediately available on
// d.alertCh to batch and returns the extended slice. It never blocks: it
// returns as soon as the channel has nothing ready, or reports closed.
func (d *Daemon) drainAlertChannel(batch []alert.Finding) []alert.Finding {
	for {
		select {
		case queued, open := <-d.alertCh:
			if open {
				batch = append(batch, queued)
				continue
			}
			return batch
		default:
			return batch
		}
	}
}
// flushPendingAlertsOnShutdown drains whatever is still sitting on the
// alert channel and records it via persistPendingFindingsOnShutdown.
func (d *Daemon) flushPendingAlertsOnShutdown() {
	leftover := d.drainAlertChannel(nil)
	d.persistPendingFindingsOnShutdown(leftover)
}
// persistPendingFindingsOnShutdown appends findings that were still queued at
// shutdown to the history log so they remain available for forensics. An
// empty batch is a no-op.
//
// It intentionally skips the auto-response pipeline (nftables blocks,
// permission fixes, kill/quarantine, DB cleanup) and network alert dispatch.
// Running those here held up service stop for tens of seconds -- potentially
// twice, once from the dispatcher's stop branch and once from the final
// flush -- while systemd waited. The work is also redundant, since the
// baseline scan on the next startup re-detects and re-acts on the same
// conditions. Because only history is written (store.Update is not called,
// so nothing is marked as sent), each finding stays re-alertable and the
// post-restart dispatch is not suppressed.
func (d *Daemon) persistPendingFindingsOnShutdown(batch []alert.Finding) {
	if len(batch) > 0 {
		d.store.AppendHistory(batch)
	}
}
// dispatchBatch processes one batch of findings collected by alertDispatcher.
//
// In order, it:
//  1. snapshots the live config and deduplicates the batch;
//  2. derives autoResponseFindings (the batch minus operator-suppressed
//     findings) and feeds it to the attack database, challenge routing,
//     auto-block, permission auto-fix and, when wired, the PHP-relay
//     auto-freezer;
//  3. narrows the batch to findings the state store considers new and not
//     suppressed, then appends the action findings produced in step 2;
//  4. if anything remains: writes history, runs kill/quarantine/DB-cleanup
//     auto-responses, expands with correlation, feeds the incident
//     correlator, broadcasts to the web server, and dispatches the subset
//     that is not purely informational;
//  5. records the deduplicated batch in the state store and marks the new
//     findings as alerted.
//
// Statement order matters throughout; see the inline comments.
func (d *Daemon) dispatchBatch(findings []alert.Finding) {
// Snapshot the live config once at the top of the batch. Every
// cfg.X read below picks up the last-reloaded value (ROADMAP
// item 7); taking one snapshot avoids the weirder case of a
// SIGHUP landing mid-batch and splitting some auto-response
// actions between old and new policy.
cfg := d.currentCfg()
findings = alert.Deduplicate(findings)
suppressions := d.store.LoadSuppressions()
// autoResponseFindings is the batch with operator-suppressed findings
// removed; with no suppressions it is the whole deduplicated batch.
autoResponseFindings := findings
if len(suppressions) > 0 {
autoResponseFindings = filterUnsuppressedFindings(d.store, findings, suppressions)
}
// Record every non-suppressed finding in the attack database before
// FilterNew runs - repeated attacks from the same IP must still be
// counted even if the alert itself was already sent and would be
// dropped by FilterNew.
if adb := attackdb.Global(); adb != nil {
for _, f := range autoResponseFindings {
adb.RecordFinding(f)
}
}
// Auto-block and permission fix run on every non-suppressed finding
// (not just new ones).
// These must execute BEFORE FilterNew because repeat offender IPs and
// recurring permission issues need to be fixed even if the alert was
// already sent in a previous cycle.
// Challenge routing runs FIRST - claims eligible IPs before hard-blocking.
challengeActions := checks.ChallengeRouteIPs(cfg, autoResponseFindings)
blockActions := checks.AutoBlockIPs(cfg, autoResponseFindings)
permActions, permFixedKeys := checks.AutoFixPermissions(cfg, autoResponseFindings)
// Mark auto-blocked IPs in attack database
if adb := attackdb.Global(); adb != nil {
for _, f := range blockActions {
if ip := checks.ExtractIPFromFinding(f); ip != "" {
adb.MarkBlocked(ip)
}
}
}
// Dismiss auto-fixed findings from the Findings page
for _, key := range permFixedKeys {
d.store.DismissLatestFinding(key)
}
// Filter through state - only new findings get alerted and logged
newFindings := d.store.FilterNew(findings)
// Filter out suppressed findings - prevents email/webhook alerts for
// paths the admin has explicitly suppressed (e.g. false positives).
// Suppressions are stored in state/suppressions.json, not in rule files.
if len(suppressions) > 0 {
newFindings = filterUnsuppressedFindings(d.store, newFindings, suppressions)
}
// Append auto-response actions to new findings for alerting
newFindings = append(newFindings, blockActions...)
newFindings = append(newFindings, challengeActions...)
newFindings = append(newFindings, permActions...)
// PHP-relay AutoFreeze: emit any new findings produced by post-emit
// freeze decisions back into the dispatched batch so operators see
// the action outcome alongside the original finding. Nil-guard for
// non-cPanel / non-linux hosts where wiring is skipped.
if d.autoFreezer != nil {
if freezeFindings := d.autoFreezer.Apply(autoResponseFindings); len(freezeFindings) > 0 {
newFindings = append(newFindings, freezeFindings...)
}
}
// Nothing new and no actions taken: still record the batch in state,
// but skip history, correlation and dispatch.
if len(newFindings) == 0 {
d.store.Update(findings)
return
}
// Log to history
d.store.AppendHistory(newFindings)
// Kill, quarantine, and DB cleanup only run on NEW findings
killActions := checks.AutoKillProcesses(cfg, newFindings)
quarantineActions := checks.AutoQuarantineFiles(cfg, newFindings)
dbCleanActions := checks.AutoRespondDBMalware(cfg, newFindings)
newFindings = append(newFindings, killActions...)
newFindings = append(newFindings, quarantineActions...)
newFindings = append(newFindings, dbCleanActions...)
// Correlation
newFindings = expandWithCorrelation(newFindings, time.Now())
co := IncidentCorrelator()
for _, f := range newFindings {
// All three return values are intentionally discarded here.
_, _, _ = co.OnFinding(f)
}
// Broadcast findings (no-op; dashboard uses polling)
if d.webServer != nil {
d.webServer.Broadcast(newFindings)
}
// Dispatch via email/webhook - filter out findings that are
// informational or fully automated (no human action needed).
// These are all visible in the web UI for forensics.
var alertable []alert.Finding
for _, f := range newFindings {
switch f.Check {
case "modsec_block_realtime", "modsec_warning_realtime", "modsec_block_escalation", "modsec_csm_block_escalation":
continue // ModSecurity: fully automated, visible on /modsec
case "outdated_plugins":
continue // informational, visible on findings page
case "email_dkim_failure", "email_spf_rejection":
continue // operational email auth issues - visible on findings page
case "email_auth_failure_realtime", "pam_bruteforce", "exim_frozen_realtime":
continue // failed logins and frozen bounces - informational, no action needed
}
alertable = append(alertable, f)
}
// A dispatch failure is logged but does not stop the state updates below.
if err := alert.Dispatch(cfg, alertable); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Alert dispatch error: %v\n", ts(), err)
}
d.store.Update(findings)
d.store.MarkAlerted(newFindings)
}
// autoFixWPCron lets daemon wiring tests avoid real wp-config.php and crontab
// edits; the checks package covers those side effects directly.
//
// It is a package-level variable (defaulting to checks.AutoFixWPCron) rather
// than a direct call so a test can substitute a stub; applyWPCronAutoFix is
// the caller.
var autoFixWPCron = checks.AutoFixWPCron
// processScanFindings handles the output of a deep or periodic scan: it persists
// the findings to the latest-findings surface, runs the auto-responses that act
// on warning-severity perf findings, then forwards the remaining findings to the
// alert dispatcher. Warning-severity perf findings stay off the alert channel so
// they never page an operator; that is exactly why the WP-Cron auto-fix runs
// here and not in dispatchBatch, which only ever sees what the channel carries.
//
// cfg is the config snapshot the scan ran with. purgeChecks is passed through
// to checks.StoreLatestScanFindings together with findings. label names the
// scan kind ("deep" or "periodic" at the visible call sites) and only appears
// in the log line written when a finding is dropped.
func (d *Daemon) processScanFindings(cfg *config.Config, findings []alert.Finding, purgeChecks []string, label string) {
// Persist first so the latest-findings surface reflects this scan even if
// the alert channel is full below.
checks.StoreLatestScanFindings(d.store, purgeChecks, findings)
d.applyWPCronAutoFix(cfg, findings)
d.enqueueScanAlerts(findings, label)
}
// enqueueScanAlerts pushes a scan's findings onto the alert channel without
// ever blocking the scanner. Perf findings ("perf_" prefix) at Warning
// severity are skipped so they never page an operator. When the channel is
// full the finding is dropped, the dropped-alert counter is incremented, and
// a line naming the scan label and check is written to stderr.
func (d *Daemon) enqueueScanAlerts(findings []alert.Finding, label string) {
	for _, finding := range findings {
		quietPerf := finding.Severity == alert.Warning && strings.HasPrefix(finding.Check, "perf_")
		if !quietPerf {
			select {
			case d.alertCh <- finding:
			default:
				atomic.AddInt64(&d.droppedAlerts, 1)
				fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping %s finding: %s\n", ts(), label, finding.Check)
			}
		}
	}
}
// applyWPCronAutoFix disables WP-Cron and installs a per-user system cron for
// every perf_wp_cron finding, then clears the fixed findings from the
// latest-findings surface and records the actions in history. Findings the
// operator has suppressed are left untouched, so a suppression also stops the
// automated edit of that account's wp-config.php.
//
// A nil d.store is tolerated throughout: suppressions are then not consulted
// and nothing is dismissed or written to history, but the auto-fix itself
// still runs. Previously only the suppression lookup was guarded and the
// later store calls dereferenced d.store unconditionally.
func (d *Daemon) applyWPCronAutoFix(cfg *config.Config, findings []alert.Finding) {
	if d.store != nil {
		if suppressions := d.store.LoadSuppressions(); len(suppressions) > 0 {
			findings = filterUnsuppressedFindings(d.store, findings, suppressions)
		}
	}
	actions, fixedKeys := autoFixWPCron(cfg, findings)
	if d.store == nil {
		// No state store to update; the fix has been applied, nothing to persist.
		return
	}
	for _, key := range fixedKeys {
		d.store.DismissLatestFinding(key)
	}
	if len(actions) > 0 {
		d.store.AppendHistory(actions)
	}
}
// filterUnsuppressedFindings returns the findings that store.IsSuppressed
// does not match against suppressions, in their original order. With no
// suppression rules the input slice is returned unchanged; otherwise a new
// slice is built (nil when every finding is suppressed).
func filterUnsuppressedFindings(store *state.Store, findings []alert.Finding, suppressions []state.SuppressionRule) []alert.Finding {
	if len(suppressions) == 0 {
		return findings
	}
	var kept []alert.Finding
	for _, candidate := range findings {
		if store.IsSuppressed(candidate, suppressions) {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// criticalScanner runs the critical check tier on a fixed 10-minute ticker
// until d.stopCh closes.
func (d *Daemon) criticalScanner() {
	defer d.wg.Done()

	tick := time.NewTicker(10 * time.Minute)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			d.runPeriodicChecks(checks.TierCritical)
		case <-d.stopCh:
			return
		}
	}
}
// deepScanner runs the deep check tier at the configured interval (default
// 60 min) until d.stopCh closes.
//
// Which checks run depends on fanotify: with the file monitor active only the
// reduced set it cannot replace is run; without it (fallback mode) the full
// deep tier runs for timer-mode parity. A pending forceFullRescan overrides
// the reduced set for one cycle.
func (d *Daemon) deepScanner() {
	defer d.wg.Done()

	for {
		// The wait is recomputed every cycle so a SIGHUP that changes
		// thresholds.deep_scan_interval_min applies to the next scan.
		// time.After is used instead of a ticker because a ticker sized
		// at startup has no clean re-size path.
		wait := time.Duration(d.currentCfg().Thresholds.DeepScanIntervalMin) * time.Minute
		if wait <= 0 {
			// Defensive: a zeroed threshold would otherwise turn this
			// into a tight spin loop. 60 minutes matches the default
			// from config.Load.
			wait = 60 * time.Minute
		}

		select {
		case <-d.stopCh:
			return
		case <-time.After(wait):
		}

		// Update threat intelligence feeds (once per day)
		if threatDB := checks.GetThreatDB(); threatDB != nil {
			_ = threatDB.UpdateFeeds()
		}
		// Prune expired attack DB records (90-day retention)
		if adb := attackdb.Global(); adb != nil {
			adb.PruneExpired()
		}

		cfg := d.currentCfg()
		// forceFullRescan is armed by the signature watcher when any rule
		// file's mtime advances. Consuming it here bypasses the fanotify
		// short-list so the new ruleset gets a full sweep against existing
		// files; otherwise only files changed AFTER the rule update would
		// be checked against the new patterns.
		fullRescan := d.forceFullRescan.CompareAndSwap(true, false)

		var (
			findings    []alert.Finding
			purgeChecks []string
		)
		if !fullRescan && d.fileMonitor != nil {
			// fanotify is active: only run the checks it cannot replace.
			findings, purgeChecks = checks.RunReducedDeepWithContext(d.scanContext(), cfg, d.store)
		} else {
			// Forced rescan, or fanotify unavailable: run the full deep tier.
			findings, purgeChecks = checks.RunTierWithContext(d.scanContext(), cfg, d.store, checks.TierDeep)
			if fullRescan {
				observeSignatureRescan()
			}
		}
		d.processScanFindings(cfg, findings, purgeChecks, "deep")
	}
}
// runPeriodicChecks performs one tick of the given tier: it verifies
// binary/config integrity, purges dry-run block records older than seven
// days, runs the tier's checks, and hands the results to
// processScanFindings. If integrity verification fails, a Critical
// "integrity" finding is queued (or counted as dropped when the alert
// channel is full) and the rest of the tick is skipped.
func (d *Daemon) runPeriodicChecks(tier checks.Tier) {
	// The live config is snapshotted ONCE for the whole tick. Reading
	// d.currentCfg() separately for integrity and for RunTier would let a
	// SIGHUP land between the reads and split the tick between old-policy
	// integrity verification and new-policy detection. Same snapshot
	// pattern as dispatchBatch.
	//
	// Integrity is verified against that snapshot rather than d.cfg (the
	// startup snapshot): a SIGHUP reload re-signs integrity.config_hash on
	// disk and updates config.Active, so the hash stored in d.cfg would be
	// stale and every tick after a successful reload would raise a Critical
	// tamper alert. If a reload completes while Verify is hashing, the
	// helper retries once against the latest live config, and the config it
	// returns is the one used for the rest of the tick.
	cfg, err := d.verifyPeriodicIntegritySnapshot(d.currentCfg())
	if err != nil {
		tamper := alert.Finding{
			Severity:  alert.Critical,
			Check:     "integrity",
			Message:   fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", err),
			Timestamp: time.Now(),
		}
		select {
		case d.alertCh <- tamper:
		default:
			atomic.AddInt64(&d.droppedAlerts, 1)
			fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping integrity finding\n", ts())
		}
		return
	}

	// Age out stale dry-run-block records so the status surface shows
	// recent activity rather than months-old entries from a previous
	// dry-run window. The 7-day rolling window matches the operator
	// workflow of reviewing a week of would-have-been-blocks before
	// flipping to live.
	if sdb := store.Global(); sdb != nil {
		cutoff := time.Now().Add(-7 * 24 * time.Hour)
		sdb.PurgeDryRunBlocksOlderThan(cutoff)
	}

	findings, purgeChecks := checks.RunTierWithContext(d.scanContext(), cfg, d.store, tier)
	d.processScanFindings(cfg, findings, purgeChecks, "periodic")
}
// verifyPeriodicIntegritySnapshot runs integrity.Verify for the daemon binary
// against cfg. On failure it re-reads the live config; if that is non-nil and
// a different pointer from cfg, verification is retried once against it.
//
// It returns the config that was last verified together with that
// verification's error (nil on success). When no retry is possible, cfg and
// the original error are returned.
func (d *Daemon) verifyPeriodicIntegritySnapshot(cfg *config.Config) (*config.Config, error) {
	err := integrity.Verify(d.binaryPath, cfg)
	if err == nil {
		return cfg, nil
	}

	latest := d.currentCfg()
	if latest == nil || latest == cfg {
		// No newer snapshot to retry against: the failure stands.
		return cfg, err
	}

	if retryErr := integrity.Verify(d.binaryPath, latest); retryErr != nil {
		return latest, retryErr
	}
	return latest, nil
}
// heartbeat pings the dead man's switch every 10 minutes until d.stopCh
// closes, and uses the same tick for housekeeping: hijack-detector cleanup,
// expiry of temporary firewall allows and subnets, and pruning of expired
// threat-DB whitelist entries.
func (d *Daemon) heartbeat() {
	defer d.wg.Done()

	beat := time.NewTicker(10 * time.Minute)
	defer beat.Stop()

	for {
		select {
		case <-d.stopCh:
			return
		case <-beat.C:
		}

		alert.SendHeartbeat(d.currentCfg())
		d.hijackDetector.Cleanup()

		// Clean expired temporary allows
		if fw := d.fwEngine; fw != nil {
			fw.CleanExpiredAllows()
			fw.CleanExpiredSubnets()
		}

		// Clean expired temporary whitelist entries
		if threatDB := checks.GetThreatDB(); threatDB != nil {
			threatDB.PruneExpiredWhitelist()
		}
	}
}
// startPHPRelay implements the platform gate from Stage 1 spec section
// 9 (O1). Emits a Warning if the host is not cPanel; otherwise locates
// the exim binary for AutoFreeze and (in O2) wires the spool watcher,
// pipeline, and Flow E ticker.
//
// Startup notices are queued with a non-blocking send. A notice that cannot
// be queued because the alert channel is full is counted in droppedAlerts and
// logged to stderr, like every other non-blocking producer in this file;
// previously these drops were silent and DroppedAlerts() undercounted.
func (d *Daemon) startPHPRelay() {
	// notify queues a Warning finding without blocking daemon startup.
	notify := func(check, message string) {
		select {
		case d.alertCh <- alert.Finding{
			Severity:  alert.Warning,
			Check:     check,
			Message:   message,
			Timestamp: time.Now(),
		}:
		default:
			atomic.AddInt64(&d.droppedAlerts, 1)
			fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping %s finding\n", ts(), check)
		}
	}

	info := platform.Detect()
	if !info.IsCPanel() {
		notify("email_php_relay_disabled", "php_relay disabled: not a cPanel host")
		return
	}
	if !d.cfg.EmailProtection.PHPRelay.Enabled {
		return
	}
	if path, err := exec.LookPath("exim"); err == nil {
		eximBinary = path
	} else {
		// Detection still gets wired below; only the exim-backed
		// auto-action is unavailable.
		notify("email_php_relay_no_exim", "php_relay auto-action disabled: exim binary not in PATH")
	}
	// Bridge to the linux-only wiring (Phase O2). On non-linux GOOS
	// the stub in php_relay_wiring_other.go is a no-op.
	startPHPRelayLinux(d)
}
// startLogWatchers registers and starts every real-time log source: the
// platform's auth log, cPanel session/access/FTP logs on cPanel hosts, the
// exim mainlog, the mail log reader (file or journal), PHP Shield events,
// the ModSecurity error log and the web access log. It also starts the
// background eviction/purge loops for the per-source tracker state.
//
// A log file that does not exist yet gets a retry goroutine that polls every
// 60 seconds and attaches a watcher when the file appears. Every goroutine
// started here is accounted for in d.wg and exits on d.stopCh.
//
// d.logWatchers is only ever appended to under d.logWatchersMu. The retry
// goroutines started in this function append under that mutex, so the
// start-up loop at the bottom takes it as well rather than writing the slice
// unsynchronized.
func (d *Daemon) startLogWatchers() {
	// Session log handler wrapper - feeds events to both the alert handler and hijack detector
	sessionHandler := func(line string, cfg *config.Config) []alert.Finding {
		// Feed to hijack detector (tracks password changes + correlates with logins)
		ParseSessionLineForHijack(line, d.hijackDetector)
		// Regular session log handling
		return parseSessionLogLine(line, cfg)
	}
	hostInfo := platform.Detect()
	type logFile struct {
		name    string
		path    string
		handler func(string, *config.Config) []alert.Finding
	}
	var logFiles []logFile
	// Generic Linux auth log. RHEL-family uses /var/log/secure, Debian
	// family uses /var/log/auth.log. Only register the log appropriate
	// for the detected OS so we don't spam "not found, retrying" forever.
	if hostInfo.IsDebianFamily() {
		logFiles = append(logFiles, logFile{"", "/var/log/auth.log", parseSecureLogLine})
	} else {
		logFiles = append(logFiles, logFile{"", "/var/log/secure", parseSecureLogLine})
	}
	// eximHandler wraps parseEximLogLine (unchanged) and augments the result
	// with smtpAuthTracker findings for dovecot authenticator failures and
	// smtpProbeTracker findings for raw connect-rate abuse (scanners that
	// probe-and-disconnect without ever reaching AUTH).
	eximHandler := func(line string, cfg *config.Config) []alert.Finding {
		findings := parseEximLogLine(line, cfg)
		// Connect-rate signal fires before any AUTH attempt.
		if probeIP := parseEximSMTPConnectIP(line); probeIP != "" {
			if parsed := net.ParseIP(probeIP); parsed != nil {
				if v4 := parsed.To4(); v4 != nil {
					probeIP = v4.String()
				}
			}
			if !isInfraIPDaemon(probeIP, cfg.InfraIPs) && !isPrivateOrLoopback(probeIP) {
				if d.smtpProbeTracker != nil {
					findings = append(findings, d.smtpProbeTracker.Record(probeIP)...)
				}
			}
		}
		if strings.Contains(line, "authenticator failed") && strings.Contains(line, "dovecot") {
			ip := extractBracketedIP(line)
			account := extractSetID(line)
			// Canonicalize IPv4-mapped IPv6 (::ffff:a.b.c.d) to plain IPv4 so the
			// tracker doesn't double-count the same attacker as two IPs.
			if ip != "" {
				if parsed := net.ParseIP(ip); parsed != nil {
					if v4 := parsed.To4(); v4 != nil {
						ip = v4.String()
					}
				}
			}
			if ip != "" && !isInfraIPDaemon(ip, cfg.InfraIPs) && !isPrivateOrLoopback(ip) {
				if d.smtpAuthTracker != nil {
					findings = append(findings, d.smtpAuthTracker.Record(ip, account)...)
				}
			}
		}
		return findings
	}
	// mailHandler composes parseDovecotLogLine (preserving email_suspicious_geo)
	// with mailAuthTracker augmentation for IMAP/POP3/ManageSieve brute-force,
	// subnet spray, account spray, and compromise detection.
	mailHandler := func(line string, cfg *config.Config) []alert.Finding {
		findings := parseDovecotLogLine(line, cfg)
		if !isMailAuthLine(line) {
			return findings
		}
		ip, account, success := extractMailLoginEvent(line)
		if ip == "" {
			return findings
		}
		if parsed := net.ParseIP(ip); parsed != nil {
			if v4 := parsed.To4(); v4 != nil {
				ip = v4.String()
			}
		}
		if isInfraIPDaemon(ip, cfg.InfraIPs) || isPrivateOrLoopback(ip) {
			return findings
		}
		if d.mailAuthTracker == nil {
			return findings
		}
		if success {
			findings = append(findings, d.mailAuthTracker.RecordSuccess(ip, account)...)
		} else {
			findings = append(findings, d.mailAuthTracker.Record(ip, account)...)
		}
		return findings
	}
	// cPanel-specific logs only watch these on cPanel hosts. On plain
	// Ubuntu/AlmaLinux they do not exist and the old code spammed
	// "not found, will retry every 60s" forever.
	if hostInfo.IsCPanel() {
		logFiles = append(logFiles,
			logFile{"", "/usr/local/cpanel/logs/session_log", sessionHandler},
			logFile{"", "/usr/local/cpanel/logs/access_log", parseAccessLogLineEnhanced},
			logFile{"", "/var/log/messages", parseFTPLogLine},
		)
	}
	if shouldWatchEximMainlog(hostInfo, os.Stat) {
		logFiles = append(logFiles, logFile{"", eximMainlogPath, eximHandler})
	}
	// Mail-log reader: factory selects file vs journal based on cfg.MailLogs.
	// Replaces the old cPanel-only /var/log/maillog registration; now works
	// on all platforms using the platform-default path or journal fallback.
	{
		mailReader, mlErr := maillog.New(d.cfg.MailLogs, hostInfo.MailLogPath())
		if mlErr != nil {
			csmlog.Warn("mail log reader disabled", "err", mlErr)
			d.MarkWatcher("maillog", false)
		} else {
			// A file-backed reader can go dark if its log path disappears
			// mid-run (syslog->journald migration). Surface that instead of
			// silently tailing a dead fd: mark the watcher unhealthy and
			// emit a finding so the operator knows mail detection degraded.
			if fr, ok := mailReader.(*maillog.FileReader); ok {
				fr.SetOnGone(d.handleMailLogSourceGone)
				fr.SetOnRestored(d.handleMailLogSourceRestored)
			}
			ctx, cancel := context.WithCancel(context.Background())
			// Tie the reader's context to daemon shutdown.
			go func() { <-d.stopCh; cancel() }()
			mailLines, mlErr := mailReader.Run(ctx)
			if mlErr != nil {
				cancel()
				csmlog.Warn("mail log reader failed to start", "err", mlErr)
				d.MarkWatcher("maillog", false)
			} else {
				d.MarkWatcher("maillog", true)
				d.wg.Add(1)
				obs.Go("maillog-consumer", func() {
					defer d.wg.Done()
					for line := range mailLines {
						if !d.dispatchMailLogLine(line, mailHandler) {
							return
						}
					}
				})
			}
		}
	}
	// Only watch PHP Shield events if enabled in config
	if d.cfg.PHPShield.Enabled {
		logFiles = append(logFiles, logFile{"", phpEventsLogPath, parsePHPShieldLogLine})
	}
	// ModSecurity error log - auto-discover path based on detected web server.
	if modsecPath := discoverModSecLogPath(d.cfg); modsecPath != "" {
		logFiles = append(logFiles, logFile{"modsec", modsecPath, parseModSecLogLineDeduped})
	} else if hostInfo.WebServer != platform.WSNone {
		// Only bother with the retry loop if a web server is actually
		// present. Headless hosts don't need this.
		fmt.Fprintf(os.Stderr, "[%s] ModSecurity error log not found (checked %v), will retry every 60s\n", ts(), hostInfo.ErrorLogPaths)
		d.MarkWatcher("modsec", false)
		d.wg.Add(1)
		obs.Go("logwatch-modsec-retry", func() {
			defer d.wg.Done()
			ticker := time.NewTicker(60 * time.Second)
			defer ticker.Stop()
			for {
				select {
				case <-d.stopCh:
					return
				case <-ticker.C:
					path := discoverModSecLogPath(d.cfg)
					if path == "" {
						continue
					}
					w, err := NewLogWatcher(path, d.cfg, parseModSecLogLineDeduped, d.alertCh)
					if err != nil {
						continue
					}
					d.logWatchersMu.Lock()
					d.logWatchers = append(d.logWatchers, w)
					d.logWatchersMu.Unlock()
					d.wg.Add(1)
					obs.Go("logwatch-modsec", func() {
						defer d.wg.Done()
						w.Run(d.stopCh)
					})
					csmlog.Info("watching log (appeared after retry)", "path", path)
					d.MarkWatcher("modsec", true)
					return
				}
			}
		})
	}
	// Real-time access log watcher for wp-login/xmlrpc brute force detection.
	// Auto-discover path from platform info (Apache/Nginx/cPanel aware).
	if accessLogPath := discoverAccessLogPath(); accessLogPath != "" {
		logFiles = append(logFiles, logFile{"", accessLogPath, parseAccessLogBruteForce})
	} else if hostInfo.WebServer != platform.WSNone && len(hostInfo.AccessLogPaths) > 0 {
		csmlog.Warn("access log not found, will retry every 60s", "candidates", fmt.Sprintf("%v", hostInfo.AccessLogPaths))
		d.wg.Add(1)
		accessPath := hostInfo.AccessLogPaths[0]
		obs.Go("logwatch-access-retry", func() { d.retryLogWatcher(accessPath, parseAccessLogBruteForce) })
	}
	// Start background eviction for modsec dedup/escalation state
	StartModSecEviction(d.stopCh, func() *config.Config { return d.currentCfg() })
	// Start background eviction for access log brute force state
	StartAccessLogEviction(d.stopCh)
	// Start background eviction for email rate limiting state
	StartEmailRateEviction(d.stopCh)
	// Start background eviction for cloud-relay per-user windows so the
	// sync.Map does not grow linearly with every distinct authenticated
	// sender ever seen.
	StartCloudRelayEviction(d.stopCh)
	// Start background purge for SMTP brute-force tracker
	d.wg.Add(1)
	obs.Go("smtp-tracker-purge", func() {
		defer d.wg.Done()
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		tick := 0
		for {
			select {
			case <-d.stopCh:
				return
			case <-ticker.C:
				if d.smtpAuthTracker != nil {
					d.smtpAuthTracker.Purge()
				}
				if d.smtpProbeTracker != nil {
					d.smtpProbeTracker.Purge()
				}
				// Diagnostic: surface whether the SMTP/mail brute-force
				// trackers are actually seeing auth failures and emitting
				// blockable findings. A nonzero record_calls with zero
				// findings_emitted over a sustained attack means the
				// threshold path, not the wiring, is the gap to chase.
				tick++
				if tick%10 == 0 {
					if d.smtpAuthTracker != nil {
						sc, se := d.smtpAuthTracker.Stats()
						csmlog.Info("smtp brute tracker stats",
							"record_calls", sc, "findings_emitted", se, "tracked", d.smtpAuthTracker.Size())
					}
					if d.mailAuthTracker != nil {
						mc, me := d.mailAuthTracker.Stats()
						csmlog.Info("mail brute tracker stats",
							"record_calls", mc, "findings_emitted", me, "tracked", d.mailAuthTracker.Size())
					}
				}
			}
		}
	})
	// Start background purge for mail (IMAP/POP3) brute-force tracker
	d.wg.Add(1)
	obs.Go("mail-tracker-purge", func() {
		defer d.wg.Done()
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-d.stopCh:
				return
			case <-ticker.C:
				if d.mailAuthTracker != nil {
					d.mailAuthTracker.Purge()
				}
			}
		}
	})
	for _, lf := range logFiles {
		w, err := NewLogWatcher(lf.path, d.cfg, lf.handler, d.alertCh)
		if err != nil {
			if os.IsNotExist(err) {
				// File doesn't exist yet - retry periodically until it appears
				if lf.name != "" {
					d.MarkWatcher(lf.name, false)
				}
				d.wg.Add(1)
				path, handler, name := lf.path, lf.handler, lf.name
				obs.Go("logwatch-retry", func() { d.retryLogWatcherNamed(path, handler, name) })
			} else {
				fmt.Fprintf(os.Stderr, "[%s] Warning: could not watch %s: %v\n", ts(), lf.path, err)
				if lf.name != "" {
					d.MarkWatcher(lf.name, false)
				}
			}
			continue
		}
		// Retry goroutines started above (and earlier in this loop) append
		// to d.logWatchers under logWatchersMu; take the same lock here so
		// the slice is never written without synchronization.
		d.logWatchersMu.Lock()
		d.logWatchers = append(d.logWatchers, w)
		d.logWatchersMu.Unlock()
		d.wg.Add(1)
		watcher := w
		obs.Go("logwatch", func() {
			defer d.wg.Done()
			watcher.Run(d.stopCh)
		})
		csmlog.Info("watching log", "path", lf.path)
		if lf.name != "" {
			d.MarkWatcher(lf.name, true)
		}
	}
}
// handleMailLogSourceGone is the callback startLogWatchers registers with
// SetOnGone on a file-backed mail log reader. It marks the "maillog" watcher
// unhealthy and tries, without blocking, to queue a Warning finding
// describing the degradation. If the finding cannot be queued it is counted
// as dropped and a line is written to stderr.
func (d *Daemon) handleMailLogSourceGone(err error) {
	d.MarkWatcher("maillog", false)

	message := fmt.Sprintf("Mail log source unavailable: %v; brute-force and rate detection degraded until it returns or the daemon restarts", err)
	degraded := alert.Finding{
		Severity:  alert.Warning,
		Check:     "mail_log_source_unavailable",
		Message:   message,
		Timestamp: time.Now(),
	}

	select {
	case <-d.stopCh:
	case d.alertCh <- degraded:
	default:
		atomic.AddInt64(&d.droppedAlerts, 1)
		fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping maillog source finding\n", ts())
	}
}
// handleMailLogSourceRestored is the counterpart of handleMailLogSourceGone,
// registered with SetOnRestored on a file-backed mail log reader. It marks
// the "maillog" watcher healthy again; no finding is emitted.
func (d *Daemon) handleMailLogSourceRestored() {
d.MarkWatcher("maillog", true)
}
// dispatchMailLogLine runs handler over one mail log line using the current
// live config and forwards each resulting finding to the alert channel.
//
// Unlike the scan producers, the send blocks until the channel accepts the
// finding or the daemon is stopping. It returns false as soon as d.stopCh is
// observed, telling the caller to stop consuming, and true once every
// finding for the line has been queued.
func (d *Daemon) dispatchMailLogLine(line maillog.Line, handler LogLineHandler) bool {
	for _, finding := range handler(line.Message, d.currentCfg()) {
		select {
		case <-d.stopCh:
			return false
		case d.alertCh <- finding:
		}
	}
	return true
}
// retryLogWatcher polls for a missing log file every 60 seconds.
// When the file appears, it starts a watcher and returns.
//
// It is retryLogWatcherNamed with an empty watcher name, so no watcher
// health status is updated. Like the named variant it calls d.wg.Done on
// exit; the caller must have done d.wg.Add(1) before starting it.
func (d *Daemon) retryLogWatcher(path string, handler LogLineHandler) {
d.retryLogWatcherNamed(path, handler, "")
}
// shouldWatchEximMainlog reports whether the exim mainlog should be
// registered for watching. cPanel hosts always watch it. On other hosts the
// log is watched unless stat reports that eximMainlogPath does not exist; any
// other stat error still yields true. A nil stat falls back to os.Stat.
func shouldWatchEximMainlog(hostInfo platform.Info, stat func(string) (os.FileInfo, error)) bool {
	if hostInfo.IsCPanel() {
		return true
	}
	probe := stat
	if probe == nil {
		probe = os.Stat
	}
	_, err := probe(eximMainlogPath)
	// Only a definite "does not exist" disables the watcher.
	return err == nil || !os.IsNotExist(err)
}
// retryLogWatcherNamed polls every 60 seconds until a log watcher for path
// can be created, then registers it, starts it in its own tracked goroutine,
// and returns. If name is non-empty, that watcher is marked healthy once
// attached. The loop exits early when d.stopCh closes.
//
// The caller must have called d.wg.Add(1); this function calls d.wg.Done
// when it returns.
func (d *Daemon) retryLogWatcherNamed(path string, handler LogLineHandler, name string) {
	defer d.wg.Done()
	csmlog.Warn("log not found, will retry every 60s", "path", path)

	poll := time.NewTicker(60 * time.Second)
	defer poll.Stop()

	for {
		select {
		case <-d.stopCh:
			return
		case <-poll.C:
		}

		late, err := NewLogWatcher(path, d.cfg, handler, d.alertCh)
		if err != nil {
			// still missing, keep retrying
			continue
		}

		d.logWatchersMu.Lock()
		d.logWatchers = append(d.logWatchers, late)
		d.logWatchersMu.Unlock()

		d.wg.Add(1)
		obs.Go("logwatch-late", func() {
			defer d.wg.Done()
			late.Run(d.stopCh)
		})
		csmlog.Info("watching log (appeared after retry)", "path", path)
		if name != "" {
			d.MarkWatcher(name, true)
		}
		return
	}
}
// startWebUI builds and starts the web UI server when it is enabled in the
// config. The server is given the daemon version, the signature rule count
// (when a global scanner exists), the daemon as health provider, the finding
// bus, the incident correlator, a snapshot of watcher health, and the
// firewall engine as IP blocker when one is present. Serving happens in a
// goroutine tracked by d.wg; init and serve errors are logged only.
func (d *Daemon) startWebUI() {
	if !d.cfg.WebUI.Enabled {
		return
	}

	ui, err := webui.New(d.cfg, d.store)
	if err != nil {
		csmlog.Error("webui init error", "err", err)
		return
	}

	// Set version and signature count for status API
	ui.SetVersion(d.version)
	if sigs := signatures.Global(); sigs != nil {
		ui.SetSigCount(sigs.RuleCount())
	}

	d.webServer = ui
	ui.SetHealthProvider(d)
	ui.SetFindingBus(d.findingBus)
	ui.SetIncidentCorrelator(IncidentCorrelator())

	// Count the watchers registered so far; retry goroutines may still be
	// appending, so the slice is read under its mutex.
	d.logWatchersMu.Lock()
	watcherCount := len(d.logWatchers)
	d.logWatchersMu.Unlock()
	ui.SetHealthInfo(d.fileMonitor != nil, watcherCount)

	if d.fwEngine != nil {
		ui.SetIPBlocker(d.fwEngine)
	}

	d.wg.Add(1)
	obs.Go("webui", func() {
		defer d.wg.Done()
		if serveErr := ui.Start(); serveErr != nil {
			csmlog.Error("webui server error", "err", serveErr)
		}
	})
}
// startPAMListener creates the PAM listener and runs it in a goroutine
// tracked by d.wg. On failure the "pamlistener" watcher is marked unhealthy
// and the daemon continues without it. On success the watcher is marked
// healthy and the listener's UpstreamResult is registered as an upstream
// probe.
func (d *Daemon) startPAMListener() {
	listener, err := NewPAMListener(d.cfg, d.alertCh)
	if err != nil {
		csmlog.Warn("PAM listener not available", "err", err)
		d.MarkWatcher("pamlistener", false)
		return
	}

	d.pamListener = listener
	d.MarkWatcher("pamlistener", true)
	d.RegisterUpstreamProbe("pamlistener", listener.UpstreamResult)

	d.wg.Add(1)
	obs.Go("pam-listener", func() {
		defer d.wg.Done()
		listener.Run(d.stopCh)
	})
	csmlog.Info("PAM listener active", "socket", pamSocketPath)
}
// startControlListener creates the control-socket listener and runs it in a
// goroutine tracked by d.wg. A failure to create it is logged at error level
// and the daemon carries on without the socket.
func (d *Daemon) startControlListener() {
	listener, err := NewControlListener(d)
	if err != nil {
		// Periodic scans and the webui keep running without the socket,
		// so the daemon is still useful — but the CLI expects the socket
		// and will hard-error. Log loudly.
		csmlog.Error("control listener not available", "err", err)
		return
	}

	d.controlListener = listener

	d.wg.Add(1)
	obs.Go("control-listener", func() {
		defer d.wg.Done()
		listener.Run(d.stopCh)
	})
	csmlog.Info("control listener active", "socket", controlSocketPath)
}
// startFileMonitor creates the fanotify file monitor and runs it in a
// goroutine tracked by d.wg. If the monitor cannot be created, the "fanotify"
// watcher is marked unhealthy and d.fileMonitor stays nil, which makes
// deepScanner run the full deep tier instead of the reduced set.
func (d *Daemon) startFileMonitor() {
	monitor, err := NewFileMonitor(d.cfg, d.alertCh)
	if err != nil {
		csmlog.Warn("fanotify not available, falling back to periodic deep scan", "err", err)
		d.MarkWatcher("fanotify", false)
		return
	}

	monitor.registerMetrics()
	d.fileMonitor = monitor

	d.wg.Add(1)
	obs.Go("fanotify", func() {
		defer d.wg.Done()
		monitor.Run(d.stopCh)
	})
	csmlog.Info("fanotify file monitor active", "paths", "/home, /tmp, /dev/shm")
	d.MarkWatcher("fanotify", true)
}
// startSpoolWatcher wires up email AV scanning when it is enabled in the
// config: a ClamAV scanner and a YARA-X scanner behind one orchestrator, an
// email quarantine, and the spool watcher that feeds them. The watcher runs
// under runSpoolWatcherLoop and a quarantine cleanup goroutine is started
// alongside it; both are tracked by d.wg. If the watcher cannot be created,
// "email_av_spool" is marked unhealthy and nothing is started.
func (d *Daemon) startSpoolWatcher() {
	if !d.cfg.EmailAV.Enabled {
		return
	}

	// Both engines behind one orchestrator: ClamAV first, then YARA-X.
	//
	// The YARA-X scanner sits on whichever backend initYaraBackend
	// installed. Active() transparently resolves to the in-process
	// *yara.Scanner or to the out-of-process supervisor depending on
	// signatures.yara_worker_enabled; severity metadata now travels on
	// matches so either backend yields the same verdict shape.
	engines := []emailav.Scanner{
		emailav.NewClamdScanner(d.cfg.EmailAV.ClamdSocket),
		emailav.NewYaraXScanner(yara.Active()),
	}
	orchestrator := emailav.NewOrchestrator(engines, d.cfg.EmailAV.ScanTimeoutDuration())

	// Quarantine is recorded on the daemon before the watcher is created.
	quarantine := emailav.NewQuarantine("/opt/csm/quarantine/email")
	d.emailQuarantine = quarantine

	// Create and start spool watcher
	watcher, err := NewSpoolWatcher(d.cfg, d.alertCh, orchestrator, quarantine)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher not available: %v\n", ts(), err)
		d.MarkWatcher("email_av_spool", false)
		return
	}
	d.setSpoolWatcher(watcher)
	d.MarkWatcher("email_av_spool", true)

	d.wg.Add(1)
	obs.Go("spool-watcher", func() {
		defer d.wg.Done()
		d.runSpoolWatcherLoop(watcher, orchestrator, quarantine)
	})

	// Start quarantine cleanup goroutine
	d.wg.Add(1)
	obs.Go("email-quarantine-cleanup", d.emailQuarantineCleanup)

	fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher active\n", ts())
}
// superviseWatcherRun calls run() and blocks until it returns. A helper
// goroutine waits in parallel: if daemonStop closes first it calls stop() so
// that a blocked run() can unwind; if run() returns first the helper simply
// exits. The function only returns after the helper has exited, so no
// goroutine outlives the call.
//
// This ties shutdown to the watcher instance that is actually running. After
// a crash-restart swaps in a fresh instance, the external shutdown path only
// stops the instance registered via setSpoolWatcher and can miss the live
// one, which would leave wg.Wait hanging forever.
func superviseWatcherRun(daemonStop <-chan struct{}, run, stop func()) {
	runReturned := make(chan struct{})
	helperExited := make(chan struct{})

	go func() {
		defer close(helperExited)
		select {
		case <-runReturned:
			// run() finished by itself; nothing to stop.
		case <-daemonStop:
			stop()
		}
	}()

	run()

	// Release the helper (if it is still waiting) and reap it.
	close(runReturned)
	<-helperExited
}
// spoolWatcherRuntime is the slice of a spool watcher that the restart loop
// needs: something it can run and something it can stop. The loop accepts
// this interface rather than the concrete watcher so a factory can supply
// any implementation.
type spoolWatcherRuntime interface {
	// Run is invoked by the restart loop through superviseWatcherRun, which
	// blocks until it returns.
	Run()
	// Stop is invoked by superviseWatcherRun when the daemon stop channel
	// closes while Run is still executing.
	Stop()
}
// runSpoolWatcherLoop supervises sw with a 2-second restart delay. Each
// replacement watcher is built from the same config, alert channel,
// orchestrator and quarantine, and is registered via setSpoolWatcher before
// it is handed back to the loop.
func (d *Daemon) runSpoolWatcherLoop(sw *SpoolWatcher, orch *emailav.Orchestrator, quar *emailav.Quarantine) {
	const restartDelay = 2 * time.Second

	rebuild := func() (spoolWatcherRuntime, error) {
		replacement, err := NewSpoolWatcher(d.cfg, d.alertCh, orch, quar)
		if err != nil {
			// Return an untyped nil so the interface value itself is nil.
			return nil, err
		}
		d.setSpoolWatcher(replacement)
		return replacement, nil
	}

	d.runSpoolWatcherLoopWithFactory(sw, restartDelay, rebuild)
}
// runSpoolWatcherLoopWithFactory runs current under superviseWatcherRun and,
// whenever it returns while the daemon is still running, replaces it with a
// watcher from newWatcher and runs that one.
//
// A replacement is attempted only after restartDelay has elapsed, and the
// wait is repeated before every retry when newWatcher fails. The function
// returns as soon as d.stopCh is observed closed, either right after a run
// ends or while waiting out the delay.
func (d *Daemon) runSpoolWatcherLoopWithFactory(current spoolWatcherRuntime, restartDelay time.Duration, newWatcher func() (spoolWatcherRuntime, error)) {
	for {
		superviseWatcherRun(d.stopCh, current.Run, current.Stop)

		// A run that ended because of shutdown is not a crash.
		select {
		case <-d.stopCh:
			return
		default:
		}

		fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher stopped unexpectedly; restarting in %s\n", ts(), restartDelay)

		// Keep trying until a replacement is built or shutdown begins.
		for restarted := false; !restarted; {
			select {
			case <-d.stopCh:
				return
			case <-time.After(restartDelay):
			}

			replacement, err := newWatcher()
			if err != nil {
				fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher restart failed: %v\n", ts(), err)
				continue
			}
			current, restarted = replacement, true
		}
	}
}
// setSpoolWatcher records sw as the daemon's current spool watcher under
// spoolWatcherMu, then pushes the email AV state to the web server. The sync
// runs after the mutex is released because syncEmailAVWebState takes the same
// mutex through getSpoolWatcher.
func (d *Daemon) setSpoolWatcher(sw *SpoolWatcher) {
	func() {
		d.spoolWatcherMu.Lock()
		defer d.spoolWatcherMu.Unlock()
		d.spoolWatcher = sw
	}()
	d.syncEmailAVWebState()
}
// getSpoolWatcher returns the spool watcher most recently stored by
// setSpoolWatcher (nil if none), reading it under spoolWatcherMu.
func (d *Daemon) getSpoolWatcher() *SpoolWatcher {
	d.spoolWatcherMu.Lock()
	current := d.spoolWatcher
	d.spoolWatcherMu.Unlock()
	return current
}
// startForwarderWatcher creates the forwarder watcher (inotify on
// /etc/valiases/) from the configured known-forwarders list and runs it in a
// wg-tracked goroutine. On failure a warning goes to stderr and the
// "forwarder" watcher is marked as not attached.
func (d *Daemon) startForwarderWatcher() {
	watcher, err := NewForwarderWatcher(d.alertCh, d.cfg.EmailProtection.KnownForwarders)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Warning: forwarder watcher not started: %v\n", ts(), err)
		d.MarkWatcher("forwarder", false)
		return
	}

	d.forwarderWatcher = watcher

	d.wg.Add(1)
	watch := func() {
		defer d.wg.Done()
		watcher.Run(d.stopCh)
	}
	obs.Go("forwarder-watcher", watch)

	csmlog.Info("watching log (inotify forwarder watcher)", "path", "/etc/valiases/")
	d.MarkWatcher("forwarder", true)
}
// syncEmailAVWebState hands the email quarantine to the web server and, when
// a spool watcher is registered, reports its mode as "permission" or
// "notification". It is a no-op until both the web server and the quarantine
// exist.
func (d *Daemon) syncEmailAVWebState() {
	ready := d.webServer != nil && d.emailQuarantine != nil
	if !ready {
		return
	}

	d.webServer.SetEmailQuarantine(d.emailQuarantine)

	watcher := d.getSpoolWatcher()
	if watcher == nil {
		return
	}

	mode := "notification"
	if watcher.PermissionMode() {
		mode = "permission"
	}
	d.webServer.SetEmailAVWatcherMode(mode)
}
// emailQuarantineCleanup runs until d.stopCh closes. Once an hour it asks the
// email quarantine to remove messages older than 30 days and reports either
// the error or a non-zero removal count on stderr. It calls d.wg.Done on
// exit, so the caller must have called d.wg.Add(1).
func (d *Daemon) emailQuarantineCleanup() {
	defer d.wg.Done()

	const (
		sweepEvery = 1 * time.Hour
		retention  = 30 * 24 * time.Hour
	)

	sweep := time.NewTicker(sweepEvery)
	defer sweep.Stop()

	for {
		select {
		case <-d.stopCh:
			return
		case <-sweep.C:
		}

		if d.emailQuarantine == nil {
			continue
		}

		removed, err := d.emailQuarantine.CleanExpired(retention)
		switch {
		case err != nil:
			fmt.Fprintf(os.Stderr, "[%s] Email quarantine cleanup error: %v\n", ts(), err)
		case removed > 0:
			fmt.Fprintf(os.Stderr, "[%s] Email quarantine cleanup: removed %d expired messages\n", ts(), removed)
		}
	}
}
// startChallengeServer starts the challenge HTTP server when challenge is
// enabled and a firewall engine exists; without the engine it logs the reason
// and skips. It creates the challenge IP list (attaching an nginx map when
// the detected web server is nginx), installs the optional port gate,
// publishes the list to the checks package and serves in a wg-tracked
// goroutine.
func (d *Daemon) startChallengeServer() {
	if !d.cfg.Challenge.Enabled {
		return
	}
	if d.fwEngine == nil {
		fmt.Fprintf(os.Stderr, "[%s] Challenge server requires firewall to be enabled (for escalation and allow). Skipping.\n", ts())
		return
	}

	unblocker := challenge.IPUnblocker(d.fwEngine)

	d.ipList = challenge.NewIPListWithMapPath(d.cfg.StatePath, challenge.DefaultMapPath)
	if onNginx := platform.Detect().WebServer == platform.WSNginx; onNginx {
		d.ipList.SetNginxMap(challenge.DefaultNginxMapPath, d.reloadChallengeNginxMap)
	}
	d.attachChallengePortGate()
	checks.SetChallengeIPList(d.ipList)

	server := challenge.New(d.cfg, unblocker, d.ipList)
	d.challengeServer = server

	d.wg.Add(1)
	serve := func() {
		defer d.wg.Done()
		csmlog.Info("challenge server active", "port", d.cfg.Challenge.ListenPort)
		err := server.Start()
		if err == nil {
			return
		}
		// An error whose text is "http: Server closed" is not reported.
		if err.Error() != "http: Server closed" {
			csmlog.Error("challenge server error", "err", err)
		}
	}
	obs.Go("challenge-server", serve)
}
// attachChallengePortGate installs the port gate for the challenge listener
// when challenge.port_gate is enabled. A nil gate with no error means the
// gate was skipped (loopback-only listener, where no off-host traffic can
// reach it anyway, or a non-Linux build). An install error is printed to
// stderr and the listener stays publicly reachable. On success the gate is
// stored on the daemon and attached to the challenge IP list.
func (d *Daemon) attachChallengePortGate() {
	if !d.cfg.Challenge.PortGate.Enabled {
		return
	}

	gateCfg := challenge.PortGateConfig{
		ListenAddr: d.cfg.Challenge.ListenAddr,
		ListenPort: d.cfg.Challenge.ListenPort,
		InfraCIDRs: d.cfg.InfraIPs,
	}
	gate, err := challenge.NewPortGate(gateCfg)

	switch {
	case err != nil:
		fmt.Fprintf(os.Stderr, "[%s] challenge port-gate install failed: %v (listener stays publicly reachable)\n", ts(), err)
		return
	case gate == nil:
		csmlog.Info("challenge port-gate skipped (loopback listener or non-Linux build)",
			"listen_addr", d.cfg.Challenge.ListenAddr)
		return
	}

	d.challengeGate = gate
	d.ipList.SetPortGate(gate)
	csmlog.Info("challenge port-gate active", "port", d.cfg.Challenge.ListenPort)
}
// reloadChallengeNginxMap runs `nginx -s reload`. On failure the returned
// error wraps the exec error and carries the command's trimmed combined
// output.
func (d *Daemon) reloadChallengeNginxMap() error {
	// #nosec G204 -- static binary and arguments; no operator input is passed.
	reload := exec.Command("nginx", "-s", "reload")
	output, err := reload.CombinedOutput()
	if err == nil {
		return nil
	}
	return fmt.Errorf("nginx -s reload: %w: %s", err, strings.TrimSpace(string(output)))
}
// challengeEscalator runs until d.stopCh closes. Every 60 seconds it lets the
// challenge server clean up expired state, then hard-blocks each expired
// challenge-list entry through the firewall engine. It calls d.wg.Done on
// exit.
func (d *Daemon) challengeEscalator() {
	defer d.wg.Done()

	tick := time.NewTicker(60 * time.Second)
	defer tick.Stop()

	for {
		select {
		case <-d.stopCh:
			return
		case <-tick.C:
		}

		// The block expiry is re-read on every tick so a SIGHUP that changes
		// auto_response.block_expiry applies to the next escalation without
		// a restart.
		blockFor := parseBlockExpiry(d.currentCfg().AutoResponse.BlockExpiry)

		if d.challengeServer != nil {
			d.challengeServer.CleanExpired()
		}

		for _, entry := range d.ipList.ExpiredEntries() {
			// The engine is checked per entry; with no engine the entry is
			// skipped.
			if d.fwEngine == nil {
				continue
			}

			why := fmt.Sprintf("CSM challenge-timeout: %s", truncateStr(entry.Reason, 100))
			err := d.fwEngine.BlockIP(entry.IP, why, blockFor)
			if err != nil {
				fmt.Fprintf(os.Stderr, "[%s] challenge-escalate: error blocking %s: %v\n", ts(), entry.IP, err)
				continue
			}
			fmt.Fprintf(os.Stderr, "[%s] CHALLENGE-ESCALATE: %s timed out, hard-blocked\n", ts(), entry.IP)
		}
	}
}
// parseBlockExpiry converts a duration string such as "12h" or "1h30m" into a
// time.Duration. An empty or unparseable string yields the 24-hour default.
// Any value that parses is returned as-is, including zero and negative
// durations.
func parseBlockExpiry(s string) time.Duration {
	const fallback = 24 * time.Hour

	if s == "" {
		return fallback
	}
	if parsed, err := time.ParseDuration(s); err == nil {
		return parsed
	}
	return fallback
}
// truncateStr returns s shortened to at most max bytes.
//
// The cut never lands inside a multi-byte UTF-8 sequence: if byte max is a
// continuation byte (10xxxxxx), the cut moves back to the start of that rune,
// so the result may be a few bytes shorter than max but a valid UTF-8 input
// always yields a valid UTF-8 result. Plain byte slicing could emit half a
// rune into firewall block reasons.
//
// A max of zero or less yields "" instead of panicking on a negative slice
// bound.
func truncateStr(s string, max int) string {
	if max <= 0 {
		return ""
	}
	if len(s) <= max {
		return s
	}

	cut := max
	// s[cut] is the first byte that would be dropped. While it is a
	// continuation byte, the rune it belongs to started earlier and would be
	// split by cutting here.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}
// initGeoIP opens the GeoIP databases under <state path>/geoip. When a
// database is returned it is stored on the daemon, published via setGeoIPDB
// (so log watcher handlers can do country filtering) and handed to the web
// server if one exists. A nil database leaves everything untouched.
func (d *Daemon) initGeoIP() {
	opened := geoip.Open(filepath.Join(d.cfg.StatePath, "geoip"))
	if opened == nil {
		return
	}

	d.geoipDB = opened
	setGeoIPDB(opened)
	if d.webServer == nil {
		return
	}
	d.webServer.SetGeoIPDB(opened)
}
// publishGeoIP makes updated GeoIP data live. If a database is already open
// it is reloaded in place and the outcome is reported on stderr. Otherwise
// (no database existed at startup) it tries to open the files under
// <state path>/geoip and, on success, publishes the new database to the
// daemon, setGeoIPDB and the web server.
//
// The whole operation runs under geoipMu, so the geoipUpdater goroutine and
// the SIGHUP handler may call it concurrently.
func (d *Daemon) publishGeoIP() {
	d.geoipMu.Lock()
	defer d.geoipMu.Unlock()

	if d.geoipDB != nil {
		err := d.geoipDB.Reload()
		if err == nil {
			fmt.Fprintf(os.Stderr, "[%s] GeoIP databases reloaded\n", ts())
		} else {
			fmt.Fprintf(os.Stderr, "[%s] GeoIP reload error: %v\n", ts(), err)
		}
		return
	}

	// First load: nothing was open at startup, so try the freshly downloaded
	// files.
	fresh := geoip.OpenFresh(filepath.Join(d.cfg.StatePath, "geoip"))
	if fresh == nil {
		return
	}

	d.geoipDB = fresh
	setGeoIPDB(fresh)
	if d.webServer != nil {
		d.webServer.SetGeoIPDB(fresh)
	}
	fmt.Fprintf(os.Stderr, "[%s] GeoIP databases loaded for the first time\n", ts())
}
// geoipUpdater runs doGeoIPUpdate on a fixed cadence until d.stopCh closes.
// It exits immediately when the account ID or license key is missing, or
// when geoip.auto_update is explicitly false. The cadence is
// geoip.update_interval when that parses to at least one hour, otherwise 24
// hours. The first update runs five minutes after start. It calls d.wg.Done
// on exit.
func (d *Daemon) geoipUpdater() {
	defer d.wg.Done()

	// Both credentials are required.
	haveCredentials := d.cfg.GeoIP.AccountID != "" && d.cfg.GeoIP.LicenseKey != ""
	if !haveCredentials {
		return
	}

	// A nil AutoUpdate means "not set" and does not disable updates.
	if auto := d.cfg.GeoIP.AutoUpdate; auto != nil && !*auto {
		return
	}

	interval := 24 * time.Hour
	if raw := d.cfg.GeoIP.UpdateInterval; raw != "" {
		parsed, err := time.ParseDuration(raw)
		if err == nil && parsed >= time.Hour {
			interval = parsed
		}
	}

	// Give the daemon five minutes to stabilize before the first attempt.
	select {
	case <-d.stopCh:
		return
	case <-time.After(5 * time.Minute):
	}

	tick := time.NewTicker(interval)
	defer tick.Stop()

	for {
		d.doGeoIPUpdate()

		select {
		case <-d.stopCh:
			return
		case <-tick.C:
		}
	}
}
// doGeoIPUpdate runs one GeoIP download pass for the configured editions and
// reports each "updated" or "error" result on stderr ("up_to_date" and any
// other status stay silent). If at least one edition was updated the
// databases are republished via publishGeoIP. A nil result set ends the pass
// without output.
func (d *Daemon) doGeoIPUpdate() {
	dbDir := filepath.Join(d.cfg.StatePath, "geoip")
	results := geoip.Update(dbDir, d.cfg.GeoIP.AccountID, d.cfg.GeoIP.LicenseKey, d.cfg.GeoIP.Editions)
	if results == nil {
		return
	}

	updatedEditions := 0
	for _, res := range results {
		if res.Status == "updated" {
			fmt.Fprintf(os.Stderr, "[%s] GeoIP auto-update: %s updated\n", ts(), res.Edition)
			updatedEditions++
		} else if res.Status == "error" {
			fmt.Fprintf(os.Stderr, "[%s] GeoIP auto-update: %s error: %v\n", ts(), res.Edition, res.Err)
		}
	}

	if updatedEditions > 0 {
		d.publishGeoIP()
	}
}
// startFirewall brings up the firewall engine when firewall.enabled is set
// and wires everything that depends on it. In order it:
//
//  1. builds a merged copy of the firewall config (top-level infra IPs plus
//     the challenge port-gate adjustment) and creates the engine from it;
//  2. installs the dry-run recorder, the dry-run switch and the verdict
//     asker on the engine;
//  3. applies the ruleset;
//  4. attaches a shutdown context that is cancelled when d.stopCh closes;
//  5. exposes the engine through setFirewallEngine, checks.SetIPBlocker and
//     SetIncidentSprayBlocker;
//  6. starts the DynDNS resolver when there are hostnames to resolve, and the
//     Cloudflare refresh loop when cloudflare.enabled is set.
//
// An error from engine creation or from Apply is printed to stderr and the
// function returns without exposing the engine.
func (d *Daemon) startFirewall() {
if !d.cfg.Firewall.Enabled {
return
}
// Merge top-level infra IPs into firewall's list. Top-level controls
// alert suppression (tight: only admin IPs), firewall may include
// additional CIDRs (e.g. server's own range) that need port access
// but should still be tracked for security alerts.
//
// Use a shallow copy rather than mutating d.cfg.Firewall in place.
// Mutating the live config poisons config.Diff during a SIGHUP
// reload: reload loads a fresh Config whose firewall.infra_ips has
// NOT been merged, and reflect.DeepEqual then reports the firewall
// subtree as changed even when nothing in csm.yaml was edited. The
// reload is classified restart_required and every operator reload
// turns into a spurious warning.
mergedFirewall := *d.cfg.Firewall
mergedFirewall.InfraIPs = mergeInfraIPs(d.cfg.InfraIPs, d.cfg.Firewall.InfraIPs)
ensureChallengePortGateFirewallAccess(d.cfg, &mergedFirewall)
engine, err := firewall.NewEngine(&mergedFirewall, d.cfg.StatePath)
if err != nil {
fmt.Fprintf(os.Stderr, "[%s] Firewall engine init error: %v\n", ts(), err)
return
}
// Wire dry-run + verdict callbacks BEFORE Apply() and before the
// engine is exposed via d.fwEngine / checks.SetIPBlocker. The
// auto_response.dry_run safety default is "on": if any code path
// reaches engine.BlockIP while these callbacks are still nil, the
// engine treats dry-run as off and the block lands live, defeating
// the operator's stated intent. Wiring before exposure removes the
// boot-time race window entirely.
engine.SetDryRunRecorder(func(ip, reason string, timeout time.Duration) {
// Recording is skipped when no global store is available.
if db := store.Global(); db != nil {
db.RecordDryRunBlock(ip, reason, timeout)
}
})
engine.SetDryRunEnabledFunc(d.autoResponseDryRunEnabled)
engine.SetVerdictAsker(d.askVerdictCallback)
if err := engine.Apply(); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Firewall apply error: %v\n", ts(), err)
return
}
// Apply does not consult the verdict callback. Install the shutdown
// context only after a successful firewall setup so a failed init
// does not leave behind a stopCh waiter.
verdictCtx, cancelVerdict := context.WithCancel(context.Background())
// This goroutine is not tracked by d.wg; it exits once d.stopCh closes.
go func() {
<-d.stopCh
cancelVerdict()
}()
engine.SetShutdownContext(verdictCtx)
d.setFirewallEngine(engine)
// Set firewall engine for auto-blocking
checks.SetIPBlocker(engine)
// Wire the incident firewall hand-off through BlockIPOutcome so the
// correlator can distinguish live nftables mutation from dry-run,
// verdict-allow, and other no-op outcomes.
SetIncidentSprayBlocker(func(ip, reason string, timeout time.Duration) (bool, error) {
outcome, err := engine.BlockIPOutcome(ip, reason, timeout)
return outcome == firewall.BlockOutcomeLive, err
})
// NOTE(review): the LoadState error is discarded and fwState is
// dereferenced right below -- confirm LoadState returns a usable value
// on failure.
fwState, _ := firewall.LoadState(d.cfg.StatePath)
csmlog.Info("firewall active",
"blocked_ips", len(fwState.Blocked),
"allowed_ips", len(fwState.Allowed),
)
// Start Dynamic DNS resolver if configured. The same resolver
// loop also services hostnames listed under infra_ips so they get
// DNS-refreshed into the engine's infra-block guard; otherwise the
// hostname entries would only protect operators whose IPs never
// move, which defeats the point of listing them by name.
infraHosts := infraHostnames(mergedFirewall.InfraIPs)
// Copy the configured list so appending infra hosts never writes into
// the config's backing array.
dynHosts := append([]string{}, d.cfg.Firewall.DynDNSHosts...)
for _, h := range infraHosts {
if !containsString(dynHosts, h) {
dynHosts = append(dynHosts, h)
}
}
if len(dynHosts) > 0 {
resolver := firewall.NewDynDNSResolver(dynHosts, engine)
resolver.SetInfraEngine(engine)
for _, h := range infraHosts {
resolver.RegisterInfraHost(h)
}
resolver.SetFindingSink(func(host string) {
// Non-blocking send: when the alert channel is full the finding is
// dropped, counted in droppedAlerts and reported on stderr.
select {
case d.alertCh <- dynDNSUnresolvableFinding(host):
default:
atomic.AddInt64(&d.droppedAlerts, 1)
fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping dyndns guard finding: %s\n", ts(), host)
}
})
d.wg.Add(1)
obs.Go("dyndns-resolver", func() {
defer d.wg.Done()
resolver.Run(d.stopCh)
})
csmlog.Info("DynDNS resolver active", "hosts", len(dynHosts), "infra_hosts", len(infraHosts))
}
// Start Cloudflare IP whitelist refresh if configured
if d.cfg.Cloudflare.Enabled {
// cloudflareRefreshLoop calls d.wg.Done itself.
d.wg.Add(1)
obs.Go("cloudflare-refresh", d.cloudflareRefreshLoop)
csmlog.Info("cloudflare IP whitelist enabled", "refresh_hours", d.cfg.Cloudflare.RefreshHours)
}
}
// ensureChallengePortGateFirewallAccess adjusts fw so the challenge listener
// port is reachable when the port gate is in use: the port is added to
// fw.TCPIn (if absent) and removed from fw.RestrictedTCP.
//
// Nothing changes when either argument is nil, when the challenge server or
// its port gate is disabled, when the listen port is not positive, or when
// the listen address is loopback.
func ensureChallengePortGateFirewallAccess(cfg *config.Config, fw *firewall.FirewallConfig) {
	if cfg == nil || fw == nil {
		return
	}

	gateInUse := cfg.Challenge.Enabled && cfg.Challenge.PortGate.Enabled
	if !gateInUse {
		return
	}

	port := cfg.Challenge.ListenPort
	if port <= 0 {
		return
	}
	if challengeListenAddrIsLoopback(cfg.Challenge.ListenAddr) {
		return
	}

	fw.TCPIn = appendUniquePort(fw.TCPIn, port)
	fw.RestrictedTCP = removePort(fw.RestrictedTCP, port)
}
// challengeListenAddrIsLoopback reports whether addr names a loopback
// listener. addr may be a bare host or host:port, with optional brackets
// around an IPv6 literal.
//
// A blank addr counts as loopback. "localhost" (any case) and loopback IP
// literals count as loopback. An empty host part (for example ":8080"), a
// non-loopback IP and any other hostname do not.
func challengeListenAddrIsLoopback(addr string) bool {
	trimmed := strings.TrimSpace(addr)
	if trimmed == "" {
		return true
	}

	// Drop the port when there is one; otherwise treat the whole value as
	// the host.
	host := trimmed
	if h, _, err := net.SplitHostPort(trimmed); err == nil {
		host = h
	}
	host = strings.Trim(host, "[]")

	switch {
	case host == "":
		return false
	case strings.EqualFold(host, "localhost"):
		return true
	}

	parsed := net.ParseIP(host)
	if parsed == nil {
		return false
	}
	return parsed.IsLoopback()
}
// appendUniquePort returns ports with port added at the end. If port is
// already present the input slice is returned unchanged; otherwise the result
// is a fresh slice, so the caller's backing array is never written to.
func appendUniquePort(ports []int, port int) []int {
	for i := range ports {
		if ports[i] == port {
			return ports
		}
	}

	grown := make([]int, 0, len(ports)+1)
	grown = append(grown, ports...)
	grown = append(grown, port)
	return grown
}
// removePort returns a new slice holding every element of ports except those
// equal to port, in their original order. The input is left untouched and the
// result is never nil, even for a nil or fully filtered input.
func removePort(ports []int, port int) []int {
	kept := make([]int, 0, len(ports))
	for i := 0; i < len(ports); i++ {
		if ports[i] == port {
			continue
		}
		kept = append(kept, ports[i])
	}
	return kept
}
// autoResponseDryRunEnabled reports the dry-run setting of the config
// returned by activeOrStartupCfg. It is looked up on every call rather than
// cached, and is installed on the firewall engine as its dry-run switch.
func (d *Daemon) autoResponseDryRunEnabled() bool {
	cfg := d.activeOrStartupCfg()
	return cfg.AutoResponseDryRunEnabled()
}
// askVerdictCallback asks the configured verdict endpoint what to do about
// ip. It returns the response's verdict, tenant ID and note.
//
// When no config is available or auto_response.verdict_callback is disabled,
// it returns three empty strings and a nil error. A failed request returns
// three empty strings and the error. The client is built from the current
// config on every call.
func (d *Daemon) askVerdictCallback(ctx context.Context, ip, reason string) (string, string, string, error) {
	cfg := d.activeOrStartupCfg()
	if cfg == nil || !cfg.AutoResponse.VerdictCallback.Enabled {
		return "", "", "", nil
	}

	settings := cfg.AutoResponse.VerdictCallback
	client := verdict.New(verdict.Config{
		URL:                      settings.URL,
		HMACSecret:               settings.HMACSecret,
		HMACSecretEnv:            settings.HMACSecretEnv,
		RequireResponseSignature: settings.RequireResponseSignature,
		AllowUnsigned:            settings.AllowUnsigned,
		Timeout:                  time.Duration(settings.TimeoutSec) * time.Second,
	})

	request := verdict.Request{
		IP:       ip,
		Reason:   reason,
		Severity: "auto",
		Source:   "auto_response",
	}
	answer, err := client.Ask(ctx, request)
	if err != nil {
		return "", "", "", err
	}
	return answer.Verdict, answer.TenantID, answer.Note, nil
}
// dynDNSUnresolvableFinding builds the warning-level
// "infra_ips_unresolvable" finding raised for a dynamic firewall host that
// did not resolve within the grace period. The finding is timestamped with
// the current time.
func dynDNSUnresolvableFinding(host string) alert.Finding {
	var finding alert.Finding
	finding.Check = "infra_ips_unresolvable"
	finding.Severity = alert.Warning
	finding.Message = fmt.Sprintf("dynamic firewall host %s has not resolved within grace period", host)
	finding.Details = "Verify DNS for the host or remove it from infra_ips or firewall.dyndns_hosts. While unresolvable, the previous resolved IP remains protected and a rotated IP will not be protected."
	finding.Timestamp = time.Now()
	return finding
}
// cloudflareRefreshLoop refreshes the Cloudflare IP sets once immediately and
// then every cloudflare.refresh_hours until d.stopCh closes. It calls
// d.wg.Done on exit.
//
// time.NewTicker panics on a non-positive duration, so a zero or negative
// refresh_hours falls back to a 24-hour cadence (with a warning) instead of
// crashing the goroutine after the first fetch.
func (d *Daemon) cloudflareRefreshLoop() {
	defer d.wg.Done()

	const fallbackInterval = 24 * time.Hour

	interval := time.Duration(d.cfg.Cloudflare.RefreshHours) * time.Hour
	if interval <= 0 {
		csmlog.Warn("cloudflare refresh_hours is not positive, using fallback interval",
			"refresh_hours", d.cfg.Cloudflare.RefreshHours,
			"fallback", fallbackInterval.String())
		interval = fallbackInterval
	}

	// Fetch immediately on startup.
	d.refreshCloudflareIPs()

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-ticker.C:
			d.refreshCloudflareIPs()
		}
	}
}
// refreshCloudflareIPs performs one Cloudflare refresh: fetch the IPv4 and
// IPv6 ranges, load them into the firewall engine (when one exists), save
// them to the state path and publish the combined list to the checks
// package. A fetch error or a firewall set update error is reported and
// aborts the refresh before anything is saved or published.
func (d *Daemon) refreshCloudflareIPs() {
	v4, v6, err := firewall.FetchCloudflareIPs()
	if err != nil {
		csmlog.Error("cloudflare IP fetch error", "err", err)
		return
	}

	if d.fwEngine != nil {
		updateErr := d.fwEngine.UpdateCloudflareSet(v4, v6)
		if updateErr != nil {
			fmt.Fprintf(os.Stderr, "[%s] Cloudflare set update error: %v\n", ts(), updateErr)
			return
		}
	}

	firewall.SaveCFState(d.cfg.StatePath, v4, v6, time.Now())

	// The checks package needs the ranges so AutoBlockIPs/ChallengeRouteIPs
	// skip Cloudflare IPs: blocking a CF edge IP would block thousands of
	// legitimate users.
	combined := make([]string, 0, len(v4)+len(v6))
	combined = append(append(combined, v4...), v6...)
	checks.SetCloudflareNets(combined)
}
// signatureUpdater drives the two rule update sources until d.stopCh closes:
// the YAML rules (enabled when signatures.update_url is set) and YARA Forge
// (enabled when signatures.yara_forge.enabled is set and yara.Available()
// reports true). With neither enabled it exits immediately. It calls
// d.wg.Done on exit.
//
// After a five-minute startup delay each enabled source is updated once, and
// again whenever its own interval has elapsed. Intervals default to 24h
// (YAML) and 168h (Forge); a configured value replaces the default only if it
// parses to at least one hour. One ticker serves both sources, running at the
// shorter interval of the enabled sources.
func (d *Daemon) signatureUpdater() {
	defer d.wg.Done()

	yamlEnabled := d.cfg.Signatures.UpdateURL != ""
	forgeEnabled := d.cfg.Signatures.YaraForge.Enabled && yara.Available()
	if !(yamlEnabled || forgeEnabled) {
		return
	}

	select {
	case <-d.stopCh:
		return
	case <-time.After(5 * time.Minute):
	}

	// intervalOr returns the parsed value of raw when it is a valid duration
	// of at least one hour, and fallback otherwise.
	intervalOr := func(raw string, fallback time.Duration) time.Duration {
		if raw == "" {
			return fallback
		}
		if parsed, err := time.ParseDuration(raw); err == nil && parsed >= time.Hour {
			return parsed
		}
		return fallback
	}
	yamlInterval := intervalOr(d.cfg.Signatures.UpdateInterval, 24*time.Hour)
	forgeInterval := intervalOr(d.cfg.Signatures.YaraForge.UpdateInterval, 168*time.Hour)

	var tickInterval time.Duration
	switch {
	case !yamlEnabled:
		tickInterval = forgeInterval
	case forgeEnabled && forgeInterval < yamlInterval:
		tickInterval = forgeInterval
	default:
		tickInterval = yamlInterval
	}

	// The zero times make both sources due on the first pass.
	var lastYAML, lastForge time.Time

	tick := time.NewTicker(tickInterval)
	defer tick.Stop()

	for {
		now := time.Now()
		if yamlEnabled && now.Sub(lastYAML) >= yamlInterval {
			d.doSignatureUpdate()
			lastYAML = now
		}
		if forgeEnabled && now.Sub(lastForge) >= forgeInterval {
			d.doForgeUpdate()
			lastForge = now
		}

		select {
		case <-d.stopCh:
			return
		case <-tick.C:
		}
	}
}
// doSignatureUpdate downloads the rule set from signatures.update_url into
// the rules directory and, on success, reports the rule count and reloads the
// scanners. A failed download is reported on stderr and nothing is reloaded.
func (d *Daemon) doSignatureUpdate() {
	downloaded, err := signatures.Update(
		d.cfg.Signatures.RulesDir,
		d.cfg.Signatures.UpdateURL,
		d.cfg.Signatures.SigningKey,
	)
	if err == nil {
		fmt.Fprintf(os.Stderr, "[%s] Signature auto-update: %d rules downloaded\n", ts(), downloaded)
		d.reloadSignatures()
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] Signature auto-update failed: %v\n", ts(), err)
}
// doForgeUpdate runs one YARA Forge update for the configured tier.
//
// It does nothing when no YARA backend is active. Otherwise it downloads the
// rule file (passing the version stored in the global store, if any), reloads
// the YARA backend and, on success, stores the new version. The version is
// deliberately NOT stored when the reload fails or when the rule count drops
// after the reload; in the latter case the Forge file is removed and the
// rules are reloaded again. A zero rule count from the download ends the pass
// quietly.
func (d *Daemon) doForgeUpdate() {
	// yara.Active() resolves to the in-process scanner or the worker
	// supervisor depending on signatures.yara_worker_enabled. Both support
	// the Reload and RuleCount calls made here, so this path does not care
	// which backend is in use.
	scanner := yara.Active()
	if scanner == nil {
		// No YARA backend (build without the yara tag, or no rules dir), so
		// rules could not be loaded anyway.
		return
	}

	tier := d.cfg.Signatures.YaraForge.Tier
	versionKey := "forge_version_" + tier

	db := store.Global()
	installed := ""
	if db != nil {
		installed = db.GetMetaString(versionKey)
	}

	latest, ruleCount, err := signatures.ForgeUpdateFromURL(
		d.cfg.Signatures.RulesDir,
		tier,
		installed,
		d.cfg.Signatures.SigningKey,
		d.cfg.Signatures.YaraForge.DownloadURL,
		d.cfg.Signatures.DisabledRules,
	)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] YARA Forge update failed: %v\n", ts(), err)
		return
	}
	if ruleCount == 0 {
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] YARA Forge update: %d rules (version %s)\n", ts(), ruleCount, latest)

	// Compare the rule count across the reload to catch a conflict with
	// existing .yar files.
	before := scanner.RuleCount()
	if reloadErr := scanner.Reload(); reloadErr != nil {
		fmt.Fprintf(os.Stderr, "[%s] YARA rule reload after Forge update error: %v\n", ts(), reloadErr)
		return // version not stored: retry next cycle
	}
	after := scanner.RuleCount()

	if after < before {
		// A drop means the Forge file likely conflicts with existing rules.
		// Roll back by removing it and reloading; keep the old version.
		forgeFile := filepath.Join(d.cfg.Signatures.RulesDir, fmt.Sprintf("yara-forge-%s.yar", tier))
		fmt.Fprintf(os.Stderr, "[%s] YARA Forge rollback: rule count dropped %d -> %d (conflict with existing rules), removing %s\n",
			ts(), before, after, forgeFile)
		_ = os.Remove(forgeFile)
		_ = scanner.Reload()
		return
	}

	fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YARA rules after Forge update\n", ts(), after)
	if db != nil {
		_ = db.SetMetaString(versionKey, latest)
	}
}
// reloadSignatures reloads the YAML rule scanner and the active YARA backend,
// each only if present, and reports the outcome of each on stderr. After a
// successful YAML reload the web server (if any) receives the new rule count.
// A failure in one scanner does not stop the other from reloading.
func (d *Daemon) reloadSignatures() {
	if yamlScanner := signatures.Global(); yamlScanner != nil {
		err := yamlScanner.Reload()
		switch {
		case err != nil:
			fmt.Fprintf(os.Stderr, "[%s] YAML rule reload error: %v\n", ts(), err)
		default:
			fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YAML rules (version %d)\n", ts(), yamlScanner.RuleCount(), yamlScanner.Version())
			if d.webServer != nil {
				d.webServer.SetSigCount(yamlScanner.RuleCount())
			}
		}
	}

	if yaraBackend := yara.Active(); yaraBackend != nil {
		err := yaraBackend.Reload()
		if err == nil {
			fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YARA rule file(s)\n", ts(), yaraBackend.RuleCount())
		} else {
			fmt.Fprintf(os.Stderr, "[%s] YARA rule reload error: %v\n", ts(), err)
		}
	}
}
// deployConfigs writes embedded config files to their system locations on
// startup, so the WHM plugin CGI and ModSecurity rules stay current after
// binary upgrades.
//
// Every file written here is a system integration point consumed by a
// different process (WHM, Apache, nginx); the permissions intentionally
// allow the right external reader. Gosec G301/G306 warnings on this
// function are suppressed inline with the specific integration target.
//
// Deployment is best-effort: no failure aborts the function. Every failed
// write is logged, so a stale CGI, deploy script or ModSecurity config no
// longer goes unnoticed.
func deployConfigs() {
	// WHM plugin CGI - embedded in binary. Only attempted on cPanel hosts.
	if _, err := os.Stat("/usr/local/cpanel"); err == nil {
		dst := "/usr/local/cpanel/whostmgr/docroot/cgi/addon_csm.cgi"
		// #nosec G306 -- WHM CGI endpoint; 0755 is required so cPanel's
		// webserver can execute it.
		if err := os.WriteFile(dst, embeddedWHMCGI, 0755); err != nil {
			csmlog.Warn("WHM plugin CGI deploy failed", "path", dst, "err", err)
		} else {
			csmlog.Info("WHM plugin CGI deployed", "path", dst)
		}

		// Write the AppConfig file, then register it with WHM.
		// Writing the file alone does NOT make the plugin appear in the
		// sidebar: WHM's AppConfig system maintains a registration
		// database that is updated via `register_appconfig`. Skipping
		// that step was a long-standing bug; the plugin file existed on
		// disk but never showed up in the menu.
		//
		// A MkdirAll failure is not reported separately because the
		// WriteFile below fails and is logged in that case.
		// #nosec G301 -- cPanel standard /var/cpanel/apps directory.
		_ = os.MkdirAll("/var/cpanel/apps", 0755)
		confPath := "/var/cpanel/apps/csm.conf"
		// #nosec G306 -- WHM AppConfig; read by cPanel tooling, 0644 is convention.
		if err := os.WriteFile(confPath, embeddedWHMConf, 0644); err != nil {
			csmlog.Error("WHM AppConfig write failed", "path", confPath, "err", err)
		} else if err := registerWHMPlugin(confPath); err != nil {
			// Non-fatal: the conf is on disk, register_appconfig failure is
			// logged so operators can fix it manually. Most common failure
			// is register_appconfig not being in PATH (old cPanel versions).
			csmlog.Warn("WHM plugin registration failed", "err", err)
		} else {
			csmlog.Info("WHM plugin registered with AppConfig")
		}
	}

	// Deploy script (self-updating)
	const deployScript = "/opt/csm/deploy.sh"
	// #nosec G306 -- Shell script executed by operators and by the CSM
	// upgrade path; needs to be executable, not private.
	if err := os.WriteFile(deployScript, embeddedDeployScript, 0755); err != nil {
		csmlog.Warn("deploy script write failed", "path", deployScript, "err", err)
	}

	// ModSecurity virtual patches: the first candidate whose directory
	// exists wins; later candidates are not touched.
	for _, dst := range []string{
		"/etc/apache2/conf.d/modsec/modsec2.user.conf",
		"/usr/local/apache/conf/modsec2.user.conf",
	} {
		if _, err := os.Stat(filepath.Dir(dst)); err == nil {
			// #nosec G306 -- Apache reads this ModSecurity config; webserver
			// runs as a different user.
			if err := os.WriteFile(dst, embeddedModSec, 0644); err != nil {
				csmlog.Warn("ModSecurity config write failed", "path", dst, "err", err)
			}
			overridesFile := filepath.Join(filepath.Dir(dst), "modsec2.csm-overrides.conf")
			modsec.EnsureOverridesInclude(dst, overridesFile)
			break
		}
	}
}
// registerWHMPlugin runs cPanel's register_appconfig helper on confPath so
// the CSM plugin shows up in the WHM sidebar. WHM keeps a cached registration
// database separate from the /var/cpanel/apps/ conf files; a conf file that
// was only written to disk never appears in the menu.
//
// It returns an error when the helper is missing at its fixed path, or when
// the helper fails or exceeds the 30-second timeout; the latter error carries
// the helper's trimmed combined output.
//
// Re-running against an already-registered plugin just updates the entry.
// Callers treat a failure as non-fatal: deployment continues and the operator
// can rerun the helper manually.
func registerWHMPlugin(confPath string) error {
	const helper = "/usr/local/cpanel/bin/register_appconfig"

	if _, statErr := os.Stat(helper); statErr != nil {
		return fmt.Errorf("register_appconfig not found at %s: %w", helper, statErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// #nosec G204 -- helper is the fixed cPanel path validated by os.Stat above;
	// confPath was just written by deployConfigs from an embedded constant.
	output, runErr := exec.CommandContext(ctx, helper, confPath).CombinedOutput()
	if runErr == nil {
		return nil
	}
	return fmt.Errorf("%s %s: %w: %s", helper, confPath, runErr, strings.TrimSpace(string(output)))
}
// watchdogNotifier sends systemd watchdog keepalives on its own ticker until
// d.stopCh closes. It is independent of the scan goroutines. It calls
// d.wg.Done on exit.
//
// It is a no-op unless WATCHDOG_USEC holds a positive integer and
// NOTIFY_SOCKET is set. Keepalives go out at exactly half the watchdog
// timeout. There is deliberately no lower bound on the period: a floor such
// as 10s would, for any WatchdogSec under 20s, push the keepalive period past
// half the timeout (and, at 10s or less, up to or past the timeout itself),
// so systemd would kill a healthy daemon. Pinging more often than necessary
// is harmless; pinging too rarely is fatal.
func (d *Daemon) watchdogNotifier() {
	defer d.wg.Done()

	usecStr := os.Getenv("WATCHDOG_USEC")
	if usecStr == "" {
		return // watchdog not configured
	}
	if os.Getenv("NOTIFY_SOCKET") == "" {
		return // nowhere to send keepalives
	}
	usec, err := strconv.ParseInt(usecStr, 10, 64)
	if err != nil || usec <= 0 {
		return
	}

	// Clamp before converting so the microsecond-to-nanosecond
	// multiplication cannot overflow int64 on an absurdly large value.
	const maxUsec = (1<<63 - 1) / int64(time.Microsecond)
	if usec > maxUsec {
		usec = maxUsec
	}

	// Half the watchdog timeout. usec >= 1 makes this at least 500ns, so
	// time.NewTicker never sees a non-positive duration.
	interval := time.Duration(usec) * time.Microsecond / 2

	csmlog.Info("systemd watchdog active", "interval", interval.String())

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-ticker.C:
			if _, err := sdnotify.Watchdog(); err != nil {
				csmlog.Warn("systemd watchdog notify failed", "err", err)
			}
		}
	}
}
// sdNotify sends msg as a single datagram to the unix socket at addr. It is
// best-effort: a failure to create the socket or to send is ignored, and the
// socket is closed before returning.
func sdNotify(addr, msg string) {
	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
	if err != nil {
		return
	}
	defer func() {
		_ = syscall.Close(fd)
	}()

	target := &syscall.SockaddrUnix{Name: addr}
	payload := []byte(msg)
	_ = syscall.Sendmsg(fd, payload, nil, target, 0)
}
// ts returns the current local time formatted as "YYYY-MM-DD HH:MM:SS". It is
// used as the bracketed prefix of the daemon's stderr log lines.
func ts() string {
	const layout = "2006-01-02 15:04:05"
	now := time.Now()
	return now.Format(layout)
}
// orUnknown returns v, or the placeholder "unknown" when v is empty.
func orUnknown(v string) string {
	if v != "" {
		return v
	}
	return "unknown"
}
// orNone returns v, or the placeholder "none" when v is empty.
func orNone(v string) string {
	if v != "" {
		return v
	}
	return "none"
}
// countAttachedWatchers returns the number of entries in statuses whose value
// is true, i.e. the watchers currently attached. It feeds the one-line
// systemd status string. A nil or empty map yields 0.
func countAttachedWatchers(statuses map[string]bool) int {
	attachedCount := 0
	for name := range statuses {
		if !statuses[name] {
			continue
		}
		attachedCount++
	}
	return attachedCount
}
//go:build !(linux && bpf)
package daemon
import (
"context"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
)
// execBPF is the placeholder for the BPF exec-monitor backend in builds
// without the linux and bpf build tags. The real type, with a tracepoint link
// and ringbuf reader, lives in exec_bpf.go behind //go:build linux && bpf.
// The placeholder carries no state.
type execBPF struct{}

// Mode reports the backend name, "bpf".
func (e *execBPF) Mode() string {
	return "bpf"
}

// EventCount always reports zero for the placeholder.
func (e *execBPF) EventCount() uint64 {
	return 0
}

// Run does nothing and returns immediately.
func (e *execBPF) Run(_ context.Context) {
}
// startExecBPF is the variant for builds without BPF support. It ignores its
// arguments and always returns a nil backend together with bpf.ErrNotBuilt.
func startExecBPF(_ context.Context, _ chan<- alert.Finding, _ *config.Config) (*execBPF, error) {
	var none *execBPF
	return none, bpf.ErrNotBuilt
}
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64
package exec_bpfprog
import (
"bytes"
_ "embed"
"fmt"
"io"
"structs"
"github.com/cilium/ebpf"
)
// ExecExecEvent is the generated Go form of the exec event record: three
// 32-bit IDs followed by fixed-size byte arrays for the command name, the
// parent command name and the filename. The blank structs.HostLayout field
// marks the struct as using the host's memory layout.
type ExecExecEvent struct {
_ structs.HostLayout
Uid uint32
Pid uint32
Ppid uint32
Comm [16]uint8
ParentComm [16]uint8
Filename [256]uint8
}
// LoadExec returns the embedded CollectionSpec for Exec.
//
// The spec is parsed from the embedded _ExecBytes object file. A parse
// failure is returned wrapped with the prefix "can't load Exec".
func LoadExec() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_ExecBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load Exec: %w", err)
}
// err is nil on this path.
return spec, err
}
// LoadExecObjects loads Exec and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *ExecObjects
// *ExecPrograms
// *ExecMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
//
// It returns the error from LoadExec if the embedded spec cannot be parsed,
// otherwise the result of LoadAndAssign.
func LoadExecObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := LoadExec()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// ExecSpecs contains maps and programs before they are loaded into the kernel.
// It embeds the program, map and variable spec structs.
//
// It can be passed to ebpf.CollectionSpec.Assign.
type ExecSpecs struct {
ExecProgramSpecs
ExecMapSpecs
ExecVariableSpecs
}
// ExecProgramSpecs contains programs before they are loaded into the kernel.
//
// It can be passed to ebpf.CollectionSpec.Assign.
type ExecProgramSpecs struct {
// CsmOnExec is bound to the program named "csm_on_exec" via its struct tag.
CsmOnExec *ebpf.ProgramSpec `ebpf:"csm_on_exec"`
}
// ExecMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed to ebpf.CollectionSpec.Assign.
type ExecMapSpecs struct {
// Events is bound to the map named "events" via its struct tag.
Events *ebpf.MapSpec `ebpf:"events"`
}
// ExecVariableSpecs contains global variables before they are loaded into the kernel.
//
// It can be passed to ebpf.CollectionSpec.Assign.
type ExecVariableSpecs struct {
// Unused is bound to the global variable named "unused" via its struct tag.
Unused *ebpf.VariableSpec `ebpf:"unused"`
}
// ExecObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to LoadExecObjects or ebpf.CollectionSpec.LoadAndAssign.
type ExecObjects struct {
ExecPrograms
ExecMaps
ExecVariables
}
// Close closes the embedded programs and then the embedded maps, returning
// the first error. ExecVariables is not passed to the closer.
func (o *ExecObjects) Close() error {
return _ExecClose(
&o.ExecPrograms,
&o.ExecMaps,
)
}
// ExecMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to LoadExecObjects or ebpf.CollectionSpec.LoadAndAssign.
type ExecMaps struct {
// Events is bound to the map named "events" via its struct tag.
Events *ebpf.Map `ebpf:"events"`
}
// Close closes the Events map.
func (m *ExecMaps) Close() error {
return _ExecClose(
m.Events,
)
}
// ExecVariables contains all global variables after they have been loaded into the kernel.
//
// It can be passed to LoadExecObjects or ebpf.CollectionSpec.LoadAndAssign.
type ExecVariables struct {
// Unused is bound to the global variable named "unused" via its struct tag.
Unused *ebpf.Variable `ebpf:"unused"`
}
// ExecPrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to LoadExecObjects or ebpf.CollectionSpec.LoadAndAssign.
type ExecPrograms struct {
// CsmOnExec is bound to the program named "csm_on_exec" via its struct tag.
CsmOnExec *ebpf.Program `ebpf:"csm_on_exec"`
}
// Close closes the CsmOnExec program.
func (p *ExecPrograms) Close() error {
return _ExecClose(
p.CsmOnExec,
)
}
// _ExecClose closes the given closers in order and returns the first error.
// Closers after a failing one are not closed.
func _ExecClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// _ExecBytes holds the contents of exec_x86_bpfel.o, embedded at build time;
// LoadExec parses the collection spec from it.
//
// Do not access this directly.
//
//go:embed exec_x86_bpfel.o
var _ExecBytes []byte
package daemon
import (
"encoding/binary"
"errors"
)
// ExecEvent is the decoded form of one exec event record (struct exec_event
// in exec.bpf.c). On the wire the record is three little-endian uint32
// scalars (host order on amd64/arm64), then two 16-byte null-padded names
// and a 256-byte null-padded filename. Here the padded byte fields have
// already been converted to Go strings by decodeExecEvent.
type ExecEvent struct {
	// UID, PID and PPID come from wire bytes 0-3, 4-7 and 8-11.
	UID, PID, PPID uint32
	// Comm, ParentComm and Filename come from the 16-, 16- and 256-byte
	// null-padded fields that follow the scalars.
	Comm, ParentComm, Filename string
}
// execEventSize is the minimum length in bytes of a raw exec event: three
// 4-byte scalars (uid, pid, ppid), two 16-byte comm fields, and a 256-byte
// filename.
const execEventSize = 4 + 4 + 4 + 16 + 16 + 256
// decodeExecEvent parses one raw exec event record into an ExecEvent.
//
// The first 12 bytes are read as three little-endian uint32 values (uid,
// pid, ppid); bytes 12-27 and 28-43 are the comm and parent comm fields, and
// bytes 44-299 are the filename, each passed through nullTerm. Bytes beyond
// execEventSize are ignored. An error is returned, together with a zero
// ExecEvent, when b is shorter than execEventSize.
func decodeExecEvent(b []byte) (ExecEvent, error) {
	if len(b) < execEventSize {
		return ExecEvent{}, errors.New("exec event short buffer")
	}
	le := binary.LittleEndian
	return ExecEvent{
		UID:        le.Uint32(b[0:4]),
		PID:        le.Uint32(b[4:8]),
		PPID:       le.Uint32(b[8:12]),
		Comm:       nullTerm(b[12:28]),
		ParentComm: nullTerm(b[28:44]),
		Filename:   nullTerm(b[44:300]),
	}, nil
}
package daemon
import (
"context"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/state"
)
// execPoller wraps the periodic CheckSuspiciousProcesses + CheckFakeKernelThreads
// in a goroutine. Used when the BPF backend is unavailable or operator-disabled.
// Detection latency for already-running processes equals the poll interval (default
// 30 minutes, matching the existing deep-tier cadence).
type execPoller struct {
// cfg is handed to both checks and supplies the poll interval.
cfg *config.Config
// alertCh receives findings; Run sends without blocking and drops a
// finding when the channel is full.
alertCh chan<- alert.Finding
// count is incremented once per finding produced by a poll, whether or
// not the finding was accepted by alertCh.
count atomic.Uint64
}
// newExecPoller builds an execPoller bound to cfg and alertCh. The poller
// does nothing until its Run method is called.
func newExecPoller(cfg *config.Config, alertCh chan<- alert.Finding) *execPoller {
	p := new(execPoller)
	p.cfg = cfg
	p.alertCh = alertCh
	return p
}
// Mode reports the backend kind of this poller, which is always "legacy".
func (p *execPoller) Mode() string { return "legacy" }
// EventCount returns the number of findings produced by polls so far,
// read atomically from the counter that Run increments.
func (p *execPoller) EventCount() uint64 { return p.count.Load() }
// Run polls until ctx is cancelled. On every tick of the interval returned
// by execPollerInterval it runs CheckSuspiciousProcesses followed by
// CheckFakeKernelThreads and forwards each finding to alertCh.
//
// Forwarding never blocks: when alertCh is full the finding is dropped and a
// warning is logged. The event counter is bumped for every finding, dropped
// or not. The first poll happens one full interval after Run starts.
func (p *execPoller) Run(ctx context.Context) {
	ticker := time.NewTicker(execPollerInterval(p.cfg))
	defer ticker.Stop()

	forward := func(findings []alert.Finding) {
		for i := range findings {
			p.count.Add(1)
			select {
			case p.alertCh <- findings[i]:
			default:
				csmlog.Warn("exec legacy: alert channel full, dropping finding")
			}
		}
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
		// Both checks are invoked with a nil state store.
		forward(checks.CheckSuspiciousProcesses(ctx, p.cfg, (*state.Store)(nil)))
		forward(checks.CheckFakeKernelThreads(ctx, p.cfg, (*state.Store)(nil)))
	}
}
// execPollerInterval returns the configured exec-monitor poll interval, or
// 30 minutes when the configured value is zero or negative.
func execPollerInterval(cfg *config.Config) time.Duration {
	const fallback = 30 * time.Minute

	configured := cfg.Detection.ExecMonitorPollInterval
	if configured <= 0 {
		return fallback
	}
	return configured
}
package daemon
import (
"context"
"errors"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/processctx"
)
// StartExecMonitor selects the active exec-monitor backend based on
// cfg.Detection.ExecMonitorBackend (trimmed, case-insensitive; empty means
// "auto") and host capability:
//
//	"auto" (default) -- try BPF, fall back to legacy polling.
//	"bpf"            -- require BPF; return nil if unavailable (no fallback).
//	"legacy"         -- pin legacy polling.
//	"none"           -- disable the live monitor (periodic checks still run).
//
// Unknown values fall back to "auto" with a warning. The metric
// csm_bpf_backend{feature="exec_monitor", kind="..."} reflects the
// chosen path.
//
// The returned backend is the BPF backend, an execPoller, or nil when the
// monitor is disabled or "bpf" was required but could not start. Whenever a
// BPF start attempt fails with an error, a BPF-unavailable finding is emitted
// on alertCh. This function only constructs the backend; it does not call
// Run on the poller.
func StartExecMonitor(alertCh chan<- alert.Finding, cfg *config.Config) bpf.Backend {
// Normalise the operator's choice.
choice := strings.ToLower(strings.TrimSpace(cfg.Detection.ExecMonitorBackend))
if choice == "" {
choice = bpf.BackendAuto
}
switch choice {
case bpf.BackendAuto, bpf.BackendBPF, bpf.BackendLegacy, bpf.BackendNone:
default:
csmlog.Warn("exec_monitor: unknown backend choice, using auto", "value", choice)
choice = bpf.BackendAuto
}
if choice == bpf.BackendNone {
csmlog.Info("exec_monitor: disabled by config")
bpf.SetActive("exec_monitor", bpf.BackendNone)
return nil
}
// bpfErr remembers a failed BPF attempt so the legacy fallback below can
// still report it as a finding.
var bpfErr error
if choice == bpf.BackendAuto || choice == bpf.BackendBPF {
// NOTE(review): a (nil, nil) result from tryStartExecBPFFn matches
// neither branch below, so it falls through to the legacy poller even
// when choice is "bpf" -- confirm that cannot happen in practice.
if b, err := tryStartExecBPFFn(context.Background(), alertCh, cfg); err == nil && b != nil {
csmlog.Info("exec_monitor", "backend", "bpf", "choice", choice)
bpf.SetActive("exec_monitor", bpf.BackendBPF)
return b
} else if err != nil {
bpfErr = err
// Distinguish "binary built without BPF" from "host cannot run BPF"
// in the log line.
level := "bpf-unsupported"
if errors.Is(err, bpf.ErrNotBuilt) {
level = "bpf-not-built"
}
csmlog.Info("exec_monitor: BPF unavailable", "state", level, "reason", err.Error(), "choice", choice)
if choice == bpf.BackendBPF {
// Operator pinned BPF: do not fall back, report and disable.
csmlog.Warn("exec_monitor: backend=bpf but BPF unavailable; no live monitor", "reason", err.Error())
bpf.SetActive("exec_monitor", bpf.BackendNone)
emitBPFUnavailableFinding(alertCh, "exec_monitor", choice, "", err)
return nil
}
}
}
// Legacy polling: either pinned by config or the "auto" fallback.
poller := newExecPoller(cfg, alertCh)
csmlog.Info("exec_monitor", "backend", "legacy", "choice", choice)
bpf.SetActive("exec_monitor", bpf.BackendLegacy)
if bpfErr != nil {
emitBPFUnavailableFinding(alertCh, "exec_monitor", choice, bpf.BackendLegacy, bpfErr)
}
return poller
}
// tryStartExecBPFFn is the package-level indirection so tests can substitute
// a fake without the bpf build tag. StartExecMonitor calls through this
// variable rather than calling tryStartExecBPF directly.
var tryStartExecBPFFn = tryStartExecBPF
// tryStartExecBPF starts the BPF exec backend via startExecBPF and returns
// it as a bpf.Backend. On failure the backend result is discarded and an
// untyped nil backend is returned alongside the error.
func tryStartExecBPF(ctx context.Context, ch chan<- alert.Finding, cfg *config.Config) (bpf.Backend, error) {
	backend, err := startExecBPF(ctx, ch, cfg)
	if err == nil {
		return backend, nil
	}
	return nil, err
}
// populateProcessCtxFromExec records a BPF exec event in the process context
// cache by forwarding its pid, ppid, uid, comm, filename and the supplied
// start time to PutFromExecStartedAt. No identity work is done here so the
// event loop stays cheap. Events with PID 0 are ignored (synthetic boot-time
// noise).
func populateProcessCtxFromExec(cache *processctx.Cache, ev ExecEvent, startedAt time.Time) {
	pid := int(ev.PID)
	if pid == 0 {
		return
	}
	ppid, uid := int(ev.PPID), int(ev.UID)
	cache.PutFromExecStartedAt(pid, ppid, uid, ev.Comm, ev.Filename, startedAt)
}
// attachProcessCtxToExecFinding asks the cache for a verified process
// snapshot matching the exec event and, when one is returned, stores it in
// f.Process. The finding is left untouched when no snapshot is available;
// the second result of MaterializeVerifiedSnapshot is ignored.
func attachProcessCtxToExecFinding(cache *processctx.Cache, f *alert.Finding, ev ExecEvent) {
	snapshot, _ := cache.MaterializeVerifiedSnapshot(processctxRequestFromExec(ev))
	if snapshot == nil {
		return
	}
	f.Process = snapshot
}
// processctxRequestFromExec maps a BPF exec event onto the EnrichRequest
// used by the exec consumer: PID, UID (marked as known) and Comm come from
// the event, and StartedAt is looked up with processCtxStartedAt for the
// event's PID. It carries no build tag so unit tests on darwin can verify
// the field mapping without a Linux+bpf build.
func processctxRequestFromExec(ev ExecEvent) processctx.EnrichRequest {
	req := processctx.EnrichRequest{
		PID:      int(ev.PID),
		UID:      int(ev.UID),
		UIDKnown: true,
		Comm:     ev.Comm,
	}
	req.StartedAt = processCtxStartedAt(req.PID)
	return req
}
//go:build linux
package daemon
import (
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/wpcheck"
"github.com/pidginhost/csm/internal/yara"
)
// fanotify constants (not all in Go stdlib).
//
// NewFileMonitor passes the mark flags and event mask bits to
// unix.FanotifyMark and the init flags to unix.FanotifyInit.
const (
// Mark flags.
FAN_MARK_ADD = 0x00000001
FAN_MARK_MOUNT = 0x00000010
// Event mask bits.
FAN_CLOSE_WRITE = 0x00000008
FAN_CREATE = 0x00000100
// Init flags.
FAN_CLASS_NOTIF = 0x00000000
FAN_CLOEXEC = 0x00000001
FAN_NONBLOCK = 0x00000002
)
// fanotifyEventMetadata is the header for each fanotify event. processEvents
// overlays this struct directly onto the bytes read from the fanotify fd, so
// the field order and widths must not change.
type fanotifyEventMetadata struct {
EventLen uint32 // length of this event; processEvents advances by it to reach the next event
Vers uint8
Reserved uint8
MetadataLen uint16
Mask uint64
Fd int32 // fd for the file the event refers to; processEvents skips negative values
Pid int32 // pid reported with the event; forwarded to handleEvent
}
// metadataSize is the in-memory size of fanotifyEventMetadata in bytes. It is
// the smallest read worth parsing and the lower bound accepted for EventLen.
const metadataSize = int(unsafe.Sizeof(fanotifyEventMetadata{}))
// knownWebshells is the set of lower-case file names commonly used by
// webshells. M1 - kept at package level to avoid a per-call allocation.
var knownWebshells = map[string]bool{
"h4x0r.php": true, "c99.php": true, "r57.php": true,
"wso.php": true, "alfa.php": true, "b374k.php": true,
"shell.php": true, "cmd.php": true, "backdoor.php": true,
"webshell.php": true,
}
// M3 - plugin stat cache with TTL.
//
// pluginCacheEntry is one cached result: whether the plugin directory
// existed and when that was recorded.
type pluginCacheEntry struct {
exists bool
ts time.Time
}
// pluginStatCache holds the cached results, safe for concurrent use.
var pluginStatCache sync.Map // key: pluginDir string → value: pluginCacheEntry
// pluginCacheTTL is the TTL for pluginStatCache entries.
const pluginCacheTTL = 5 * time.Minute
// cronSpoolWatchDir is the directory the realtime crontab watcher marks.
// Paths under it are treated as user crontabs by isInteresting and
// analyzeFile. Declared as a var (not const) so tests can redirect it under
// t.TempDir() without touching the real /var/spool/cron.
var cronSpoolWatchDir = "/var/spool/cron"
// alertDedupTTL is the cooldown period (30 seconds) for duplicate alerts on
// the same check+filepath combination. Prevents alert storms from rapid
// writes to one file.
const alertDedupTTL = 30 * time.Second
// FileMonitor watches mount points for file creation/modification using fanotify.
// Create one with NewFileMonitor, drive it with Run, and shut it down with
// Stop (or by closing the stop channel passed to Run).
type FileMonitor struct {
// fd is the fanotify file descriptor returned by fanotify_init.
fd int
// cfg is the configuration captured at construction; currentCfg prefers
// the live config.Active() value and falls back to this one.
cfg *config.Config
// alertCh is the outbound channel for findings.
alertCh chan<- alert.Finding
// analyzerCh is the bounded queue between the event loop and the
// analyzer workers; it is closed by drainAndClose.
analyzerCh chan fileEvent
// M7 - separate counters for dropped events and alerts
// (droppedEvents is updated with sync/atomic).
droppedEvents int64
droppedAlerts int64
// C4 - pipe for epoll stop signaling
pipeFds [2]int // [0]=read, [1]=write
pipeClosed int32 // atomic flag: 1 = pipe fds closed by drainAndClose
// C2 - sync.Once for safe Stop
stopOnce sync.Once
// drainOnce makes drainAndClose idempotent.
drainOnce sync.Once
stopCh chan struct{} // internal stop channel
// wg tracks the analyzer workers and the overflow reporter so
// drainAndClose can wait for them.
wg sync.WaitGroup
// Per-path alert deduplication: "check:filepath" → last alert time
alertDedup sync.Map
// WordPress core checksum verifier - skips detection on unmodified WP core files
wpCache *wpcheck.Cache
// Drop-recovery reconcile: directories that had fanotify events dropped
// because the analyzer queue was full. The overflow reporter walks this
// set once a minute and scans any interesting file modified within
// reconcileWindow so bulk filesystem operations (unzip, backup restore)
// do not blind detection to actual threats landing in the storm.
reconcileMu sync.Mutex
reconcileDirs map[string]time.Time
// reconcileSig is a buffered cap-1 channel that lets sendEvent's drop
// branch nudge overflowReporter to run reconcileDrops out of cycle
// when sustained drops cross eagerReconcileDropThreshold. The cap-1
// shape collapses multiple triggers in the same window into one and
// keeps sendEvent non-blocking on the event-loop hot path.
reconcileSig chan struct{}
// metricsOnce guards one-time registration of the fanotify-scoped
// Prometheus metrics. Each FileMonitor registers its own hooks when
// it first starts; subsequent calls are a no-op.
metricsOnce sync.Once
}
const (
// reconcileDirCap bounds the dirty-region tracker fed by sendEvent's
// drop branch. A 2026-04-28 cpanel package restore overflowed the
// previous 64-entry cap inside seconds (every wp-content subdir was
// a distinct parent), evicting older dirs before reconcileDrops ran.
// 1024 entries fits typical cpanel restore bursts comfortably while
// staying tiny in memory (each entry is a string-pointer + time, so
// the whole map peaks under ~100 KiB even at full cap).
reconcileDirCap = 1024
// reconcileWindow scopes which files reconcileDrops will rescan: only
// files whose mtime is within this window of "now". Sized just over
// the minute tick so a drop near the start of a tick is still picked
// up by the reconcile that runs at tick end.
reconcileWindow = 70 * time.Second
// analyzerChBufferSize sizes the channel feeding the analyzer pool,
// counted in events (not bytes). A cpanel package restore in
// production observed ~4189 events in a few seconds; a 16384-entry
// buffer absorbs that burst plus headroom without overflowing.
// Memory cost is bounded (fileEvent is a path string + fd + pid,
// ~40 bytes each, so <1 MiB at full buffer).
analyzerChBufferSize = 16384
// eagerReconcileDropThreshold triggers an out-of-cycle reconcile
// when sustained drops cross this count within a single minute tick.
// Without this, drops happening just after a tick wait the full
// interval before reconcileDrops walks them - long enough for the
// reconcileWindow to expire on the earliest dropped files.
eagerReconcileDropThreshold = 500
)
// Package-level Prometheus metrics for fanotify. Instantiated once per
// process; one FileMonitor per daemon instance reuses them. All three are
// nil until registerMetrics has run, so callers nil-check before use.
var (
// fanotifyDroppedTotal counts events dropped because the analyzer queue was full.
fanotifyDroppedTotal *metrics.Counter
// fanotifyReconcileDur observes how long each reconcileDrops pass takes, in seconds.
fanotifyReconcileDur *metrics.Histogram
// contentScanTruncated counts content scans that saw only a leading
// window of the file, labelled by check name.
contentScanTruncated *metrics.CounterVec
)
// fanotifyMetricsInit guards the process-wide creation and registration of
// the fanotify metrics, so a second FileMonitor in the same process cannot
// register them twice.
var fanotifyMetricsInit sync.Once

// registerMetrics creates and registers the fanotify Prometheus metrics.
//
// It is safe to call repeatedly: fm.metricsOnce makes it a no-op after the
// first call on a given FileMonitor, and the package-level
// fanotifyMetricsInit guards the actual registrations across monitors.
//
// The queue-depth gauge closes over the FileMonitor that performs the
// registration, so it reports that monitor's analyzerCh only. Its help text
// derives the capacity from analyzerChBufferSize so the two cannot drift
// apart.
func (fm *FileMonitor) registerMetrics() {
	fm.metricsOnce.Do(func() {
		fanotifyMetricsInit.Do(func() {
			fanotifyDroppedTotal = metrics.NewCounter(
				"csm_fanotify_events_dropped_total",
				"Fanotify events dropped because the analyzer queue was full. Sustained growth indicates an event storm (bulk unzip, backup restore) or an attack producing more file activity than the scanner can analyse; the reconcile pass still rescans affected directories, so dropped events do not vanish from detection, they arrive delayed.",
			)
			metrics.MustRegister("csm_fanotify_events_dropped_total", fanotifyDroppedTotal)

			fanotifyReconcileDur = metrics.NewHistogram(
				"csm_fanotify_reconcile_latency_seconds",
				"How long the post-overflow reconcile pass takes to walk drop-affected directories and rescan recent files. Buckets sized for the observed range; alert if p99 crosses tens of seconds (reconcile is stealing CPU from real-time analysis).",
				[]float64{0.01, 0.05, 0.1, 0.5, 1, 5, 10, 30, 60},
			)
			metrics.MustRegister("csm_fanotify_reconcile_latency_seconds", fanotifyReconcileDur)

			metrics.RegisterGaugeFunc(
				"csm_fanotify_queue_depth",
				fmt.Sprintf("Current number of queued fanotify events waiting for the analyzer pool. Capacity is %d; queue approaching that cap means drops are imminent.", analyzerChBufferSize),
				func() float64 {
					// Guard against a nil monitor or a monitor built
					// without a queue (e.g. in tests).
					if fm == nil || fm.analyzerCh == nil {
						return 0
					}
					return float64(len(fm.analyzerCh))
				},
			)

			contentScanTruncated = metrics.NewCounterVec(
				"csm_realtime_content_scan_truncated_total",
				"Real-time fanotify content checks whose file was larger than the main read window, so the full-rule pass saw only the leading window. Labels: check (phpcontent_inline, phpcontent_uploads, php_check, crontab, htaccess, user_ini, html_phishing, cgi_backdoor).",
				[]string{"check"},
			)
			metrics.MustRegister("csm_realtime_content_scan_truncated_total", contentScanTruncated)
		})
	})
}
// recordReadTruncation bumps csm_realtime_content_scan_truncated_total for
// the given check label when the file behind fd is larger than maxBytes.
// The size comes from a single fstat; an fstat failure is ignored. It is a
// no-op when the counter has not been created yet (test setups that skip
// registerMetrics).
func recordReadTruncation(fd int, maxBytes int, check string) {
	if contentScanTruncated == nil {
		return
	}
	var st unix.Stat_t
	if unix.Fstat(fd, &st) != nil {
		return
	}
	if st.Size <= int64(maxBytes) {
		return
	}
	contentScanTruncated.With(check).Inc()
}
// fileEvent is one unit of work for the analyzer: a file that passed the
// isInteresting filter.
type fileEvent struct {
path string // path of the file to analyse
fd int // open descriptor for the file; content is read from it rather than by path
pid int32 // pid that triggered the event; zero when the event was synthesised by reconcileDrops
}
// currentCfg returns the configuration the monitor should act on right now:
// the live config.Active() value when one is set, otherwise the config
// captured at construction. It returns nil only when no active config exists
// and the receiver itself is nil.
func (fm *FileMonitor) currentCfg() *config.Config {
	active := config.Active()
	switch {
	case active != nil:
		return active
	case fm == nil:
		return nil
	default:
		return fm.cfg
	}
}
// NewFileMonitor creates a fanotify-based file monitor.
//
// It initialises a fanotify group, places mount marks on /home, /tmp,
// /dev/shm and /var/tmp, adds a best-effort directory mark on
// cronSpoolWatchDir, and creates the pipe used to wake the event loop on
// stop. No goroutines are started; call Run for that.
//
// Returns error if the kernel doesn't support the required features: when
// fanotify_init fails, when none of the mount points could be marked, or
// when the stop pipe cannot be created. The fanotify fd is closed on the
// latter two failures.
func NewFileMonitor(cfg *config.Config, alertCh chan<- alert.Finding) (*FileMonitor, error) {
// H1 - use golang.org/x/sys/unix for fanotify_init
fd, err := unix.FanotifyInit(FAN_CLASS_NOTIF|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
if err != nil {
return nil, fmt.Errorf("fanotify_init: %w (kernel may not support fanotify)", err)
}
// Mark mount points; M2 - track successful mounts
mountPaths := []string{"/home", "/tmp", "/dev/shm", "/var/tmp"}
mountOK := 0
for _, path := range mountPaths {
// H1 - use golang.org/x/sys/unix for fanotify_mark
err = unix.FanotifyMark(fd, FAN_MARK_ADD|FAN_MARK_MOUNT, FAN_CLOSE_WRITE|FAN_CREATE, -1, path)
if err != nil {
// Try without FAN_CREATE (older kernels)
err = unix.FanotifyMark(fd, FAN_MARK_ADD|FAN_MARK_MOUNT, FAN_CLOSE_WRITE, -1, path)
if err != nil {
// A single unwatchable mount is only a warning; the others may still work.
fmt.Fprintf(os.Stderr, "[%s] Warning: cannot watch %s: %v\n", ts(), path, err)
continue
}
}
mountOK++
}
// M2 - error on zero successful mounts
if mountOK == 0 {
_ = unix.Close(fd)
return nil, fmt.Errorf("no mount points could be watched (tried %v)", mountPaths)
}
// Directory-scoped watch on /var/spool/cron so any user crontab write
// reaches analyzeFile in real time. Best-effort: cron may live under a
// different path on non-cPanel hosts (the platform layer normalises),
// and we'd rather lose the realtime crontab signal than fail daemon
// startup. The polled CheckCrontabs run still covers this case via
// the next scheduled scan. Mask matches spoolwatch.go (the proven
// production pattern for directory-scoped marks): FAN_CLOSE_WRITE
// alone catches both new and modified crontabs, since the close
// after O_CREAT|O_WRONLY|... fires the close-write event. FAN_CREATE
// is omitted because it has stricter kernel requirements with
// directory-scoped (non-MOUNT) marks and adds no coverage here.
if _, statErr := os.Stat(cronSpoolWatchDir); statErr == nil {
if err := unix.FanotifyMark(fd, FAN_MARK_ADD,
FAN_CLOSE_WRITE|FAN_EVENT_ON_CHILD, -1, cronSpoolWatchDir); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Warning: cannot watch %s: %v\n", ts(), cronSpoolWatchDir, err)
}
}
// C4 - create pipe for epoll stop signaling
var pipeFds [2]int
if err := unix.Pipe2(pipeFds[:], unix.O_NONBLOCK|unix.O_CLOEXEC); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("pipe2: %w", err)
}
fm := &FileMonitor{
fd: fd,
cfg: cfg,
alertCh: alertCh,
analyzerCh: make(chan fileEvent, analyzerChBufferSize),
pipeFds: pipeFds,
stopCh: make(chan struct{}),
reconcileDirs: make(map[string]time.Time),
reconcileSig: make(chan struct{}, 1),
}
// The WordPress checksum cache is handed the monitor's stop channel.
fm.wpCache = wpcheck.NewCache(cfg.StatePath)
fm.wpCache.SetStopCh(fm.stopCh)
return fm, nil
}
// Run starts the file monitor event loop and analyzer workers, and blocks
// until the monitor is stopped.
//
// It launches the analyzer workers and the overflow reporter, then waits on
// an epoll instance covering the fanotify fd and the read end of the stop
// pipe. Each wake-up performs one read from the fanotify fd and hands the
// bytes to processEvents. If any part of the epoll setup fails, Run delegates
// to runPollFallback instead. Closing stopCh, or calling Stop, ends the loop;
// on the way out drainAndClose closes the queue and waits for the workers.
func (fm *FileMonitor) Run(stopCh <-chan struct{}) {
// H7 - worker count is runtime.NumCPU clamped to the range [4, 16]
numWorkers := runtime.NumCPU()
if numWorkers < 4 {
numWorkers = 4
}
if numWorkers > 16 {
numWorkers = 16
}
for i := 0; i < numWorkers; i++ {
fm.wg.Add(1)
obs.Go("fanotify-analyzer", fm.analyzerWorker)
}
// Start overflow reporter
fm.wg.Add(1)
obs.Go("fanotify-overflow", fm.overflowReporter)
// C4 - create epoll instance, watch fanotify fd + pipe read end
epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
if err != nil {
fmt.Fprintf(os.Stderr, "[%s] epoll_create1 failed: %v, falling back to poll loop\n", ts(), err)
fm.runPollFallback(stopCh)
return
}
defer func() { _ = unix.Close(epfd) }()
// Add fanotify fd to epoll
if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, fm.fd, &unix.EpollEvent{
Events: unix.EPOLLIN,
// #nosec G115 -- POSIX fd fits in int32 (rlimit caps fds at ~1024).
Fd: int32(fm.fd),
}); err != nil {
fmt.Fprintf(os.Stderr, "[%s] epoll_ctl(fanotify): %v\n", ts(), err)
fm.runPollFallback(stopCh)
return
}
// Add pipe read end to epoll (for stop signaling)
if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, fm.pipeFds[0], &unix.EpollEvent{
Events: unix.EPOLLIN,
// #nosec G115 -- POSIX fd fits in int32.
Fd: int32(fm.pipeFds[0]),
}); err != nil {
fmt.Fprintf(os.Stderr, "[%s] epoll_ctl(pipe): %v\n", ts(), err)
fm.runPollFallback(stopCh)
return
}
// Forward external stopCh to our internal mechanism. The second case lets
// this goroutine exit when Stop was called directly.
obs.SafeGo("fanotify-stop-forward", func() {
select {
case <-stopCh:
fm.Stop()
case <-fm.stopCh:
}
})
buf := make([]byte, 4096*24) // Large buffer for event batches
events := make([]unix.EpollEvent, 4)
for {
n, err := unix.EpollWait(epfd, events, 500) // 500ms timeout
if err != nil {
if err == unix.EINTR {
continue
}
// Check if we've been stopped
select {
case <-fm.stopCh:
fm.drainAndClose()
return
default:
}
// Unexpected epoll failure: log and back off rather than spin.
fmt.Fprintf(os.Stderr, "[%s] epoll_wait error: %v\n", ts(), err)
time.Sleep(1 * time.Second)
continue
}
// Check for stop first
select {
case <-fm.stopCh:
fm.drainAndClose()
return
default:
}
for i := 0; i < n; i++ {
// #nosec G115 -- POSIX fd fits in int32; comparing against epoll event fd.
if events[i].Fd == int32(fm.pipeFds[0]) {
// Stop signal received via pipe
fm.drainAndClose()
return
}
// #nosec G115 -- POSIX fd fits in int32.
if events[i].Fd == int32(fm.fd) {
// fanotify events ready — single read per epoll wake
nr, readErr := unix.Read(fm.fd, buf)
if readErr != nil {
// EAGAIN/EINTR are expected on a non-blocking fd and stay silent.
if readErr != unix.EAGAIN && readErr != unix.EINTR {
fmt.Fprintf(os.Stderr, "[%s] fanotify read error: %v\n", ts(), readErr)
}
} else if nr >= metadataSize {
fm.processEvents(buf[:nr])
}
}
}
}
}
// runPollFallback is the event loop used when epoll setup fails. It reads
// the non-blocking fanotify fd directly and sleeps between attempts:
// 100ms after EAGAIN/EINTR, 1s after any other read error (which is also
// logged). Reads shorter than one metadata header are discarded. The loop
// ends, via drainAndClose, once the internal stop channel is closed.
func (fm *FileMonitor) runPollFallback(stopCh <-chan struct{}) {
	// Bridge the caller's stop channel onto Stop(); the second case lets
	// this goroutine exit when Stop was called directly.
	obs.SafeGo("fanotify-stop-forward", func() {
		select {
		case <-stopCh:
			fm.Stop()
		case <-fm.stopCh:
		}
	})

	readBuf := make([]byte, 4096*24)
	for {
		select {
		case <-fm.stopCh:
			fm.drainAndClose()
			return
		default:
		}

		got, readErr := unix.Read(fm.fd, readBuf)
		switch {
		case readErr == unix.EAGAIN || readErr == unix.EINTR:
			time.Sleep(100 * time.Millisecond)
		case readErr != nil:
			fmt.Fprintf(os.Stderr, "[%s] fanotify read error: %v\n", ts(), readErr)
			time.Sleep(1 * time.Second)
		case got >= metadataSize:
			fm.processEvents(readBuf[:got])
		}
	}
}
// processEvents parses a buffer of fanotify event metadata and dispatches each event.
//
// Events are walked back to back, each one advancing the offset by its own
// EventLen. Parsing stops when fewer than metadataSize bytes remain or when
// an event reports an EventLen smaller than the header. Events whose Fd is
// negative are skipped; every other event is handed to handleEvent, which
// takes ownership of the fd.
func (fm *FileMonitor) processEvents(buf []byte) {
offset := 0
for offset+metadataSize <= len(buf) {
// #nosec G103 -- fanotify delivers a packed binary stream on the
// fd; we must reinterpret the byte buffer as the kernel struct.
// The metadataSize bounds check above guarantees we have enough
// bytes for the struct.
event := (*fanotifyEventMetadata)(unsafe.Pointer(&buf[offset]))
// A length shorter than the header would stall or corrupt the walk,
// so abandon the rest of the buffer.
if event.EventLen < uint32(metadataSize) {
break
}
if event.Fd >= 0 {
fm.handleEvent(int(event.Fd), event.Pid)
}
offset += int(event.EventLen)
}
}
// drainAndClose closes analyzerCh, waits for every goroutine tracked by
// fm.wg to finish, and then closes both ends of the stop pipe. The workers
// consume whatever is still queued before they exit.
// C1 - ensures no fd leak on shutdown.
// Guarded by drainOnce, so repeated calls are harmless. It is called from
// the event loop, the only sender on analyzerCh.
func (fm *FileMonitor) drainAndClose() {
fm.drainOnce.Do(func() {
close(fm.analyzerCh)
fm.wg.Wait()
// Mark pipe as closed before actually closing, so Stop() won't
// write to an already-closed fd (H2 fix).
atomic.StoreInt32(&fm.pipeClosed, 1)
_ = unix.Close(fm.pipeFds[0])
_ = unix.Close(fm.pipeFds[1])
})
}
// Stop signals the monitor to shut down.
// C2 - sync.Once ensures safe concurrent calls; does not close analyzerCh directly.
//
// It closes the internal stop channel, writes one byte to the stop pipe to
// wake the epoll loop, and closes the fanotify fd. Draining the queue and
// closing the pipe are left to drainAndClose, which the event loop calls on
// its way out.
func (fm *FileMonitor) Stop() {
fm.stopOnce.Do(func() {
close(fm.stopCh)
// Wake epoll so Run() exits and calls drainAndClose.
// Only write if pipe hasn't been closed by drainAndClose yet.
if atomic.LoadInt32(&fm.pipeClosed) == 0 {
_, _ = unix.Write(fm.pipeFds[1], []byte{0})
}
// Close fanotify fd - causes any pending Read/EpollWait to return
// NOTE(review): the event loop may still be using fm.fd when it is
// closed here; it relies on the loop noticing stopCh before its next
// read -- confirm no other fd can be opened under the same number in
// that gap.
_ = unix.Close(fm.fd)
})
}
// handleEvent takes ownership of one fanotify event fd. It resolves the
// file's path through /proc/self/fd, discards the event if the path cannot
// be resolved, ends in "/" (M5 - directory events) or fails the
// isInteresting filter, and otherwise queues it for the analyzer pool.
//
// The fd is closed here in every case except a successful enqueue, after
// which the analyzer worker closes it. Enqueueing never blocks: when the
// queue is full the event is dropped and counted, a log line is written
// every 100th drop, the parent directory is remembered for the reconcile
// pass, and an eager reconcile may be requested. Without that bookkeeping
// every file in a burst past buffer capacity would be invisible to
// detection forever.
func (fm *FileMonitor) handleEvent(fd int, pid int32) {
	path, err := os.Readlink(fmt.Sprintf("/proc/self/fd/%d", fd))
	if err != nil || strings.HasSuffix(path, "/") || !fm.isInteresting(path) {
		_ = unix.Close(fd)
		return
	}

	// Send to analyzer pool (with backpressure).
	select {
	case fm.analyzerCh <- fileEvent{path: path, fd: fd, pid: pid}:
		return
	default:
	}

	// Queue full: account for the drop and schedule recovery.
	dropped := atomic.AddInt64(&fm.droppedEvents, 1)
	if fanotifyDroppedTotal != nil {
		fanotifyDroppedTotal.Inc()
	}
	if dropped%100 == 0 {
		fmt.Fprintf(os.Stderr, "[%s] fanotify: %d events dropped (analyzer queue full)\n", ts(), dropped)
	}
	fm.recordDroppedDir(path)
	fm.maybeTriggerEagerReconcile(dropped)
	_ = unix.Close(fd)
}
// maybeTriggerEagerReconcile nudges overflowReporter to run reconcileDrops
// immediately when sustained drops cross eagerReconcileDropThreshold within
// a single minute window. droppedSoFar is the running drop count returned by
// the atomic increment in handleEvent. Delegates to the free function
// signalEagerReconcile so the trigger logic stays testable from a
// cross-platform test file.
func (fm *FileMonitor) maybeTriggerEagerReconcile(droppedSoFar int64) {
signalEagerReconcile(fm.reconcileSig, droppedSoFar, eagerReconcileDropThreshold)
}
// recordDroppedDir remembers the parent directory of a file whose fanotify
// event was dropped, stamping it with the current time. The set is capped at
// reconcileDirCap entries: when an insert pushes it over the cap, the entry
// with the oldest timestamp is evicted. Safe for concurrent use; the map is
// lazily created if it is nil.
func (fm *FileMonitor) recordDroppedDir(path string) {
	parent := filepath.Dir(path)

	fm.reconcileMu.Lock()
	defer fm.reconcileMu.Unlock()

	if fm.reconcileDirs == nil {
		fm.reconcileDirs = make(map[string]time.Time)
	}
	fm.reconcileDirs[parent] = time.Now()
	if len(fm.reconcileDirs) <= reconcileDirCap {
		return
	}

	// Over the cap: find and drop the stalest directory.
	var (
		victim   string
		earliest time.Time
		seen     bool
	)
	for dir, stamp := range fm.reconcileDirs {
		if !seen || stamp.Before(earliest) {
			victim, earliest, seen = dir, stamp, true
		}
	}
	delete(fm.reconcileDirs, victim)
}
// reconcileDrops walks every directory with a recent dropped event and
// analyses any interesting regular file modified within reconcileWindow.
// Converts lost events into delayed events rather than invisible ones.
// Called from overflowReporter after the minute-granularity overflow alert.
//
// The set of tracked directories is swapped out under reconcileMu, so drops
// recorded while a pass is running land in the next pass. When the histogram
// is registered, the duration of each non-empty pass is observed.
//
// Only regular files are reopened. Directory entries are classified from
// lstat-style metadata, so symlinks, FIFOs, sockets and device nodes are
// skipped: opening a FIFO read-only blocks until a writer appears, which
// would let any local user who can create one in a watched directory (e.g.
// /tmp) stall the reconcile pass indefinitely, and following a symlink would
// scan a file outside the directory that actually saw the drop.
func (fm *FileMonitor) reconcileDrops() {
	fm.reconcileMu.Lock()
	dirs := fm.reconcileDirs
	fm.reconcileDirs = make(map[string]time.Time)
	fm.reconcileMu.Unlock()
	if len(dirs) == 0 {
		return
	}
	if fanotifyReconcileDur != nil {
		start := time.Now()
		defer func() {
			fanotifyReconcileDur.Observe(time.Since(start).Seconds())
		}()
	}
	cutoff := time.Now().Add(-reconcileWindow)
	for dir := range dirs {
		entries, err := os.ReadDir(dir)
		if err != nil {
			// Directory vanished or is unreadable; nothing to rescan.
			continue
		}
		for _, e := range entries {
			if e.IsDir() {
				continue
			}
			info, err := e.Info()
			if err != nil {
				continue
			}
			// Never open anything but a regular file (see the doc comment).
			if !info.Mode().IsRegular() {
				continue
			}
			if info.ModTime().Before(cutoff) {
				continue
			}
			fullPath := filepath.Join(dir, e.Name())
			if !fm.isInteresting(fullPath) {
				continue
			}
			// Open+analyse+close wrapped so a panic in analyzeFile does not
			// leak the fd; otherwise `defer f.Close()` in a loop body would
			// defer until reconcileDrops returns, accumulating fds across
			// every entry in every tracked dir.
			func() {
				// #nosec G304 -- fullPath is a directory entry under a dir
				// the kernel already notified us about; reconcile owns
				// reopening because the original fanotify fd is gone.
				f, err := os.Open(fullPath)
				if err != nil {
					return
				}
				defer func() { _ = f.Close() }()
				// #nosec G115 -- POSIX fd fits in int32 (rlimit caps fds at ~1024).
				fm.analyzeFile(fileEvent{path: fullPath, fd: int(f.Fd())})
			}()
		}
	}
}
// isInteresting is the fast filter that decides, from the path alone,
// whether a file is worth sending to the analyzer. It performs no I/O.
//
// Atomic-write staging files are always rejected. cPanel's fileTransfer and
// restore tools that write-then-rename stage content as
// `.temp.<nanoseconds>.<name>.<ext>` before rename(2) to the final path.
// The fanotify mask here is CLOSE_WRITE + CREATE only (no FAN_MOVED_TO), so
// the post-rename file is never rescanned in real time, and scanning the
// transient staging path produces a false-positive storm on legitimate
// WordPress restores because the staged content IS genuine WP core. A file
// that lingers under a staging name is left for the periodic deep scan.
//
// Any other path is accepted when at least one rule below matches; suffix
// rules are case-insensitive, prefix and substring rules are case-sensitive.
func (fm *FileMonitor) isInteresting(path string) bool {
	if looksLikeAtomicWriteStage(filepath.Base(path)) {
		return false
	}

	lower := strings.ToLower(path)
	base := filepath.Base(lower)
	underHome := strings.HasPrefix(path, "/home/")
	hasSuffix := func(suffixes ...string) bool {
		for _, s := range suffixes {
			if strings.HasSuffix(lower, s) {
				return true
			}
		}
		return false
	}

	switch {
	case isPHPExtension(base):
		// PHP files.
	case hasSuffix(".haxor", ".cgix"):
		// Webshell extensions.
	case underHome && hasSuffix(".pl", ".cgi", ".py", ".sh", ".rb"):
		// CGI scripts in web-accessible directories — detect
		// Perl/Python/Bash backdoors.
	case hasSuffix(".htaccess", ".user.ini"):
		// .htaccess and .user.ini files.
	case underHome && hasSuffix(".html", ".htm"):
		// HTML files in /home (phishing pages).
	case credentialLogNames[base]:
		// Credential log files - known phishing harvest filenames.
	case underHome && hasSuffix(".zip"):
		// ZIP archives in /home (phishing kit uploads).
	case strings.Contains(path, "/.config/"):
		// Anything in .config directories.
	case strings.HasPrefix(path, cronSpoolWatchDir+"/"):
		// User crontabs surfaced via the directory-scoped fanotify mark in
		// NewFileMonitor; analyzeFile dispatches them to checkCrontab.
	case strings.HasPrefix(path, "/tmp/"),
		strings.HasPrefix(path, "/dev/shm/"),
		strings.HasPrefix(path, "/var/tmp/"):
		// Any file in /tmp, /dev/shm or /var/tmp.
	case (strings.Contains(path, "/.ssh/") || strings.Contains(path, "/.cpanel/") ||
		strings.Contains(path, "/mail/") || strings.Contains(path, "/.gnupg/") ||
		strings.Contains(path, "/.cagefs/")) && isPHPExtension(strings.ToLower(filepath.Base(path))):
		// PHP in sensitive directories that should never contain PHP.
		// NOTE(review): appears to be subsumed by the PHP-extension case at
		// the top of this switch -- confirm before removing.
	default:
		return false
	}
	return true
}
// credentialLogNames are filenames commonly used by phishing kits to store
// harvested credentials. Checked in isInteresting() for real-time detection.
// Keys are lower-case because isInteresting looks up the lower-cased base
// name, making the match case-insensitive.
var credentialLogNames = map[string]bool{
"results.txt": true, "result.txt": true, "log.txt": true,
"logs.txt": true, "emails.txt": true, "data.txt": true,
"passwords.txt": true, "creds.txt": true, "credentials.txt": true,
"victims.txt": true, "output.txt": true, "harvested.txt": true,
"results.log": true, "emails.log": true, "data.log": true,
"results.csv": true, "emails.csv": true, "data.csv": true,
}
// analyzerWorker is the body of one analyzer goroutine. It takes events off
// analyzerCh, analyses each one, and closes the event's fd afterwards.
// C1 - once the channel is closed the worker still processes whatever is
// left in the buffer, closing each fd, before it returns and signals fm.wg.
func (fm *FileMonitor) analyzerWorker() {
	defer fm.wg.Done()
	for {
		ev, open := <-fm.analyzerCh
		if !open {
			return
		}
		fm.analyzeFile(ev)
		_ = unix.Close(ev.fd)
	}
}
// readFromFd returns up to maxBytes read from the start of the file behind
// fd, using a single pread at offset 0. It returns nil when nothing could be
// read (empty file or read error).
// C3 - avoids TOCTOU by reading from the original fanotify event fd rather
// than reopening the path. unix.Pread is used directly to avoid
// os.NewFile's GC finalizer, which would close the fd out-of-band and race
// with the worker's explicit close.
func readFromFd(fd int, maxBytes int) []byte {
	head := make([]byte, maxBytes)
	got, _ := unix.Pread(fd, head, 0)
	if got <= 0 {
		return nil
	}
	return head[:got]
}
// isBenignPHPStubData reports whether data is classified as a benign PHP
// stub by checks.IsBenignPHPStubBytesComplete. Empty data is never benign.
// The "complete" flag passed to the check is true only when fstat on fd
// succeeds and the file is no larger than data, i.e. data holds the whole
// file.
func isBenignPHPStubData(fd int, data []byte) bool {
	if len(data) == 0 {
		return false
	}
	var st unix.Stat_t
	wholeFile := unix.Fstat(fd, &st) == nil && st.Size <= int64(len(data))
	return checks.IsBenignPHPStubBytesComplete(data, wholeFile)
}
// readTailFromFd returns the last maxBytes of the file behind fd, read with
// a single pread. It returns nil when fstat fails, when the file is no
// larger than maxBytes (a head read of the same size already covers it), or
// when nothing could be read.
func readTailFromFd(fd int, maxBytes int) []byte {
	var st unix.Stat_t
	if unix.Fstat(fd, &st) != nil {
		return nil
	}
	window := int64(maxBytes)
	if st.Size <= window {
		return nil // head read already covers the entire file
	}
	tail := make([]byte, maxBytes)
	got, _ := unix.Pread(fd, tail, st.Size-window)
	if got <= 0 {
		return nil
	}
	return tail[:got]
}
// resolveProcessInfo builds a "pid=N cmd=name uid=N" string for alert
// enrichment from /proc/<pid>/comm and /proc/<pid>/status.
//
// It returns "" when pid is not positive or comm cannot be read (the process
// may already have exited). If status cannot be read, or has no "Uid:" line
// with a value, the uid part is omitted. The uid reported is the first
// value on the "Uid:" line.
func resolveProcessInfo(pid int32) string {
	if pid <= 0 {
		return ""
	}
	procDir := fmt.Sprintf("/proc/%d", pid)

	// Process name.
	// #nosec G304 -- /proc/<pid>/comm; kernel pseudo-FS, pid is int32 from fanotify event.
	rawComm, err := os.ReadFile(procDir + "/comm")
	if err != nil {
		return ""
	}
	var out strings.Builder
	fmt.Fprintf(&out, "pid=%d cmd=%s", pid, strings.TrimSpace(string(rawComm)))

	// UID, used to map the process to a cPanel username.
	// #nosec G304 -- /proc/<pid>/status; kernel pseudo-FS.
	rawStatus, err := os.ReadFile(procDir + "/status")
	if err != nil {
		return out.String()
	}
	remaining := string(rawStatus)
	for remaining != "" {
		var line string
		line, remaining, _ = strings.Cut(remaining, "\n")
		if !strings.HasPrefix(line, "Uid:") {
			continue
		}
		// Only the first "Uid:" line is considered.
		if fields := strings.Fields(line); len(fields) >= 2 {
			out.WriteString(" uid=" + fields[1])
		}
		break
	}
	return out.String()
}
// analyzeFile routes one file event through the realtime detection rules.
//
// Rules run in a fixed order and most branches return once they have handled
// the event, so the ordering below is significant:
//  1. configured path suppressions and verified WordPress core/plugin files,
//  2. user crontabs under the cron spool directory,
//  3. filename/location rules (PHP in sensitive dirs, known webshell names,
//     webshell extensions, .htaccess, .user.ini),
//  4. executables under /.config/ and under /tmp, /dev/shm, /var/tmp,
//  5. PHP in uploads / languages / upgrade directories,
//  6. generic checks chosen by extension or filename (PHP, HTML, credential
//     logs, ZIP archives, CGI scripts).
//
// Where a rule inspects content it reads from event.fd; event.path is used
// for matching and alert text. Process info is resolved once up front and is
// best-effort (it is empty when the process has already exited).
func (fm *FileMonitor) analyzeFile(event fileEvent) {
path := event.path
name := filepath.Base(path)
nameLower := strings.ToLower(name)
// Resolve process info from PID (best-effort - process may have exited)
procInfo := resolveProcessInfo(event.pid)
// H2 - suppression path matching using filepath.Match. Read the live config
// (config.Active via currentCfg) so a SIGHUP change to suppressions.ignore_paths
// takes effect without a restart, matching the rest of this analyzer.
for _, ignore := range fm.currentCfg().Suppressions.IgnorePaths {
if matchSuppression(ignore, path) {
return
}
}
// Skip verified WordPress core files - checksum matches official WP.org checksums.
// Content is read from the event fd (not path) to preserve TOCTOU safety.
if fm.wpCache != nil && fm.wpCache.IsVerifiedCoreFile(event.fd, path) {
return
}
// Verified WordPress plugin files: hash matches the plugin's official
// wordpress.org ZIP for its declared version. Stops signature/YARA FPs
// on stock plugin code (Wordfence, Contact Form 7, etc.). Cache miss
// triggers a background fetch; misses fall through to rule evaluation.
if fm.wpCache != nil && fm.wpCache.IsVerifiedPluginFile(event.fd, path) {
return
}
// User crontab written under /var/spool/cron/<user>. Scan content
// from the event fd via the shared deep matcher and emit Critical on
// any hit. The polled CheckCrontabs run still tracks root crontab
// hash drift, so we skip root here to avoid duplicate signal.
// Everything under the spool directory ends here; no later rule runs.
if strings.HasPrefix(path, cronSpoolWatchDir+"/") {
fm.checkCrontab(event.fd, path, procInfo)
return
}
// Location-based severity escalation: PHP in dirs that should NEVER have PHP.
// This is a path-only rule: the alert fires without reading the content.
if isPHPExtension(nameLower) {
for _, sensitive := range []string{"/.ssh/", "/.cpanel/", "/mail/", "/.gnupg/", "/.cagefs/"} {
if strings.Contains(path, sensitive) {
fm.sendAlertWithPath(alert.Critical, "php_in_sensitive_dir_realtime",
fmt.Sprintf("PHP file in critical directory: %s", path),
fmt.Sprintf("PHP should never exist in %s - likely webshell or backdoor", sensitive), path, procInfo)
return
}
}
}
// Known webshell filenames (M1 - package-level var). Filename alone is
// too weak: WordPress core ships wp-includes/Text/Diff/Engine/shell.php
// (the Pear Text_Diff library using shell_exec to call Unix `diff`).
// Confirm with content: the file must also exhibit webshell markers
// (request superglobal flowing into a dangerous function, or an
// eval/assert wrapping a base64/gzinflate decoder).
// A name match without content markers does not return: the file keeps
// flowing through the remaining rules.
if knownWebshells[nameLower] {
recordReadTruncation(event.fd, 65536, "phpcontent_inline")
if data := readFromFd(event.fd, 65536); looksLikePHPWebshell(data) {
fm.sendAlertWithPath(alert.Critical, "webshell_realtime",
fmt.Sprintf("Webshell file created: %s", path), "", path, procInfo)
return
}
}
// Webshell extensions
if strings.HasSuffix(nameLower, ".haxor") || strings.HasSuffix(nameLower, ".cgix") {
fm.sendAlertWithPath(alert.Critical, "webshell_realtime",
fmt.Sprintf("Suspicious CGI file created: %s", path), "", path, procInfo)
return
}
// .htaccess modification - check for injection (C3 - read from fd).
// Checked before the /tmp early-return so a malicious .htaccess anywhere
// (including /tmp) is still analyzed for dangerous directives.
if nameLower == ".htaccess" {
fm.checkHtaccess(event.fd, path, procInfo)
return
}
// .user.ini modification - check for dangerous PHP settings (C3 - read from fd).
// Also checked before /tmp so malicious .user.ini is detected anywhere.
if nameLower == ".user.ini" {
fm.checkUserINI(event.fd, path, procInfo)
return
}
// Executables in .config - checked before the /tmp block so a miner
// dropped at /tmp/.config/* is flagged as executable_in_config_realtime
// (more specific) rather than executable_in_tmp_realtime.
// Uses unix.Fstat on the event fd (not os.Stat by path) for TOCTOU
// safety: an attacker cannot chmod -x or swap the file after the event.
// NOTE(review): the return below is unconditional, so any path containing
// "/.config/" skips every later rule, including PHP and HTML content
// analysis. Unlike the /tmp block there is no PHP fall-through here -
// confirm this is intended.
if strings.Contains(path, "/.config/") {
var cfgStat unix.Stat_t
if err := unix.Fstat(event.fd, &cfgStat); err == nil {
isDir := cfgStat.Mode&unix.S_IFMT == unix.S_IFDIR
if !isDir && cfgStat.Mode&0111 != 0 {
fm.sendAlertWithPath(alert.Critical, "executable_in_config_realtime",
fmt.Sprintf("Executable created in .config: %s", path),
fmt.Sprintf("Size: %d", cfgStat.Size), path, procInfo)
}
}
return
}
// Executables in /tmp, /dev/shm or /var/tmp - detect dropped malware/miners
// Uses unix.Fstat on event fd for TOCTOU safety (attacker can't chmod -x after event)
if strings.HasPrefix(path, "/tmp/") || strings.HasPrefix(path, "/dev/shm/") || strings.HasPrefix(path, "/var/tmp/") {
var tmpStat unix.Stat_t
if err := unix.Fstat(event.fd, &tmpStat); err == nil {
isDir := tmpStat.Mode&unix.S_IFMT == unix.S_IFDIR
isExec := tmpStat.Mode&0111 != 0
if !isDir && isExec {
// Skip known root-owned work directories:
// - cPanel: SpamAssassin compiles .so regex modules, UPCP stages scripts
// - dracut: rebuilds initramfs after kernel updates, copies system binaries
// Non-root files in these paths are still suspicious.
isCpanelWork := strings.Contains(path, "/cpanel.TMP.work.") || strings.Contains(path, "/cPanel-")
isDracutWork := strings.Contains(path, "/dracut.")
if (isCpanelWork || isDracutWork) && tmpStat.Uid == 0 {
// Root-owned executable in system work dir - legitimate, skip
} else {
fm.sendAlertWithPath(alert.Critical, "executable_in_tmp_realtime",
fmt.Sprintf("Executable created in %s: %s", filepath.Dir(path), path),
fmt.Sprintf("Size: %d, Mode: %04o, UID: %d", tmpStat.Size, tmpStat.Mode&0777, tmpStat.Uid), path, procInfo)
}
}
}
// Fall through to PHP checks below for .php files in /tmp; every other
// file type in these directories stops here.
if !isPHPExtension(nameLower) {
return
}
}
// PHP in uploads directories.
// Any PHP file here is anomalous: /wp-content/uploads/ is meant for
// media, not code. Plugin-update temp dirs are recognised structurally
// via looksLikePluginUpdate only after content checks have had first
// refusal, so a decoy update directory cannot downgrade a webshell.
// Operators suppress legitimate caching daemons through the path-scoped
// suppressions_api, not an implicit substring allowlist in the daemon.
if strings.Contains(path, "/wp-content/uploads/") && isPHPExtension(nameLower) {
// Content-aware severity: PHP in uploads is anomalous but not
// always malicious (TinyMCE smile_fonts/charmap.php is glyph
// data shipped by WP's bundled editor). Emit Critical for
// direct webshell markers, otherwise run the broader PHP
// content/signature/YARA path before downgrading clean PHP to
// a Warning.
recordReadTruncation(event.fd, 65536, "phpcontent_uploads")
data := readFromFd(event.fd, 65536)
if looksLikePHPWebshell(data) {
fm.sendAlertWithPath(alert.Critical, "php_in_uploads_realtime",
fmt.Sprintf("PHP file created in uploads: %s", path),
"Webshell markers in content (request superglobal -> dangerous function, or eval/assert + decoder chain)",
path, procInfo)
} else {
// checkPHPContent returns true once it has raised its own alert,
// in which case no location warning is added on top.
if fm.checkPHPContent(event.fd, path, procInfo) {
return
}
// Content-shape gate: file whose reachable code is
// whitespace+comments, or that terminates with
// die/exit/__halt_compiler before any statement,
// cannot execute attacker-controlled code via web
// request. BackWPup writes its working-job and
// folder-cache state files this way. The earlier
// signature/YARA pass and the path-only warning
// below are the layers that fire on real droppers;
// a structurally inert stub adds no signal.
if isBenignPHPStubData(event.fd, data) {
return
}
if looksLikePluginUpdate(path) {
// Verified plugin update - emit one low-severity alert per temp directory.
// The alert path is the first directory component below uploads/, so
// repeated files in the same temp dir share one dedup key.
uploadsIdx := strings.Index(path, "/wp-content/uploads/")
afterUploads := path[uploadsIdx+len("/wp-content/uploads/"):]
tempDir := afterUploads
if slashIdx := strings.Index(afterUploads, "/"); slashIdx > 0 {
tempDir = afterUploads[:slashIdx]
}
updateDir := path[:uploadsIdx] + "/wp-content/uploads/" + tempDir
fm.sendAlertWithPath(alert.Warning, "php_in_uploads_realtime",
fmt.Sprintf("Plugin update in uploads: %s", updateDir),
"Verified: matching plugin exists in plugins/", updateDir, procInfo)
return
}
// Suppress the path-only "anomalous location" warning
// when the file is structurally a duplicate (cPanel
// restore staging) or a known plugin probe shape that
// never carries executable input. Signature/YARA scans
// already ran above, so any real malicious content is
// reported through its own pipeline.
if looksLikeCpanelRestoreStaging(path) {
return
}
if looksLikeWPOptimizeProbe(path, data) {
return
}
fm.sendAlertWithPath(alert.Warning, "php_in_uploads_realtime",
fmt.Sprintf("PHP file in uploads (no webshell markers): %s", path),
"Anomalous location for PHP, but content is clean",
path, procInfo)
}
return
}
// PHP in languages/upgrade directories.
// Path-only Critical buried real alerts under location noise (WPML
// translation queues, WP auto-update staging). Run content analysis
// on every file -- a real rule fires Critical, clean real code gets a
// Warning, and inert stubs stay quiet. No filename allowlist: an attacker
// must not be able to hide a backdoor by naming it like a translation or
// index file.
if (strings.Contains(path, "/wp-content/languages/") || strings.Contains(path, "/wp-content/upgrade/")) &&
isPHPExtension(nameLower) {
if !fm.checkPHPContent(event.fd, path, procInfo) {
data := readFromFd(event.fd, 65536)
if isBenignPHPStubData(event.fd, data) {
return
}
fm.sendAlertWithPath(alert.Warning, "php_in_sensitive_dir_realtime",
fmt.Sprintf("PHP file created in sensitive WP directory (content clean): %s", path), "", path, procInfo)
}
return
}
// PHP content analysis (C3 - read from fd; M4 - 32KB scan size).
// .htaccess, .user.ini, and .config executable checks are handled
// earlier in this function (before the /tmp early-return) so specific
// file types take precedence over the /tmp generic block.
if isPHPExtension(nameLower) {
fm.checkPHPContent(event.fd, path, procInfo)
return
}
// HTML phishing page detection (uses event fd for content, unix.Fstat for size)
if strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm") {
fm.checkHTMLPhishing(event.fd, path, procInfo)
return
}
// Credential log files (content read from the event fd)
if credentialLogNames[nameLower] {
fm.checkCredentialLog(event.fd, path, procInfo)
return
}
// Phishing kit ZIP archives (path-based; the archive content is not read)
if strings.HasSuffix(nameLower, ".zip") {
fm.checkPhishingZip(path, nameLower, procInfo)
return
}
// CGI scripts in web-accessible directories (Perl, Python, Bash, Ruby)
// Detect backdoor toolkits like LEVIATHAN that use non-PHP scripts.
if strings.HasPrefix(path, "/home/") && isCGIExtension(nameLower) {
fm.checkCGIBackdoor(event.fd, path, procInfo)
return
}
}
// Structural exclusions used by checkHtaccess. Each pattern is anchored to
// the directive itself (its argument, or the start of the line) rather than
// to a loose substring that an attacker could paste anywhere on the line.

// htaccessAutoPrependSafeTarget matches an auto_prepend_file or
// auto_append_file directive whose argument is one of the known security
// plugin files (wordfence-waf.php, sucuri.php, advanced-headers.php),
// optionally quoted and optionally preceded by a directory path. Because the
// file name must directly follow the directive, a trailing comment such as
// "# litespeed" cannot make an arbitrary target look safe.
var htaccessAutoPrependSafeTarget = regexp.MustCompile(
	`(?i)auto_(?:prepend|append)_file\s*=?\s*['"]?(?:[^\s'"]*/)?` +
		`(?:wordfence-waf|sucuri|advanced-headers)\.php(?:['"]|\s|$)`,
)

// htaccessRewriteDirective matches a line that begins (after optional
// whitespace) with RewriteCond or RewriteRule. base64_decode / eval( inside
// such a line is a pattern in an attack-query blocklist (e.g. Really Simple
// SSL hardening), not PHP code.
var htaccessRewriteDirective = regexp.MustCompile(
	`(?i)^\s*Rewrite(?:Cond|Rule)\s`,
)
// checkCrontab inspects a crontab that was just written under the cron spool
// directory. The owning user is taken from the last path element; the root
// crontab and the spool directory itself are ignored (root drift is covered
// by the polled CheckCrontabs hash baseline). Up to 64 KiB is read from the
// event fd - never re-opened by path, so swapping the file after the event
// cannot redirect the scan - and handed to checks.MatchCrontabPatternsDeep
// together with the live config. Any match raises one Critical
// "suspicious_crontab" alert listing the matched patterns.
func (fm *FileMonitor) checkCrontab(fd int, path, procInfo string) {
	owner := filepath.Base(path)
	switch owner {
	case "", "root", filepath.Base(cronSpoolWatchDir):
		return
	}

	const window = 65536
	recordReadTruncation(fd, window, "crontab")
	body := readFromFd(fd, window)
	if body == nil {
		return
	}

	hits := checks.MatchCrontabPatternsDeep(string(body), fm.currentCfg())
	if len(hits) == 0 {
		return
	}

	summary := fmt.Sprintf("Suspicious crontab written for user %s: %v", owner, hits)
	detail := fmt.Sprintf("File: %s\nPatterns matched: %v", path, hits)
	fm.sendAlertWithPath(alert.Critical, "suspicious_crontab", summary, detail, path, procInfo)
}
// checkHtaccess examines a written .htaccess for injected directives.
// C3 - content comes from the event fd, not the path (first 16 KiB).
//
// Blank lines and "#" comment lines are skipped. For every other line:
//   - an auto_prepend_file / auto_append_file directive raises a High alert
//     unless its target matches htaccessAutoPrependSafeTarget;
//   - "eval(" or "base64_decode" raises a High alert unless the line is a
//     RewriteCond/RewriteRule (htaccessRewriteDirective).
//
// The first offending line alerts and ends the check. When no line trips
// either rule, the content is passed to the signature/YARA scan.
func (fm *FileMonitor) checkHtaccess(fd int, path, procInfo string) {
	const window = 16384
	recordReadTruncation(fd, window, "htaccess")
	data := readFromFd(fd, window)
	if data == nil {
		return
	}

	for _, raw := range strings.Split(string(data), "\n") {
		directive := strings.TrimSpace(raw)
		if directive == "" || strings.HasPrefix(directive, "#") {
			continue
		}
		folded := strings.ToLower(directive)
		setsAutoFile := strings.Contains(folded, "auto_prepend_file") ||
			strings.Contains(folded, "auto_append_file")
		mentionsPHPFunc := strings.Contains(folded, "eval(") ||
			strings.Contains(folded, "base64_decode")

		switch {
		case setsAutoFile:
			// Suspicious unless the directive argument is a known-legit
			// security plugin file. A safe line is not examined further.
			if htaccessAutoPrependSafeTarget.MatchString(directive) {
				continue
			}
			fm.sendAlertWithPath(alert.High, "htaccess_injection_realtime",
				fmt.Sprintf("Suspicious .htaccess modification: %s", path),
				"auto_prepend_file/auto_append_file target not recognised", path, procInfo)
			return
		case mentionsPHPFunc:
			// .htaccess is not a PHP execution context, so the only legit
			// appearance of these tokens is as regex patterns inside
			// mod_rewrite attack-blocklists.
			if htaccessRewriteDirective.MatchString(directive) {
				continue
			}
			fm.sendAlertWithPath(alert.High, "htaccess_injection_realtime",
				fmt.Sprintf("Suspicious .htaccess modification: %s", path),
				"PHP function reference outside RewriteCond/RewriteRule", path, procInfo)
			return
		}
	}

	// Nothing structural fired: run signature/YARA scanning on the content.
	fm.runSignatureScan(data, path, ".htaccess", procInfo)
}
// checkUserINI reads .user.ini content from the event fd and checks for
// dangerous PHP settings.
// C3 - reads from fd, not path (first 4 KiB). H6 - the value of the specific
// directive line is parsed instead of searching for "on" anywhere.
//
// Checks, in order (the first hit alerts at Critical and returns):
//  1. allow_url_include set to a value PHP treats as true: on, true, yes or
//     1, bare or quoted, with or without a trailing ";" comment.
//  2. disable_functions set to an empty value (bare, "" or ”) or to "none".
//
// If neither fires, the raw content goes through the signature/YARA scan.
// Lines starting with ";" are never matched because the key is compared
// exactly after trimming.
func (fm *FileMonitor) checkUserINI(fd int, path, procInfo string) {
	recordReadTruncation(fd, 4096, "user_ini")
	data := readFromFd(fd, 4096)
	if data == nil {
		return
	}
	// Keys and boolean keywords are case-insensitive, so match on a
	// lowercased copy; the original bytes are kept for the signature scan.
	content := strings.ToLower(string(data))

	for _, val := range userINIValues(content, "allow_url_include") {
		switch val {
		case "on", "1", "true", "yes":
			fm.sendAlertWithPath(alert.Critical, "php_config_realtime",
				fmt.Sprintf("PHP allow_url_include enabled: %s", path),
				"Remote PHP file inclusion is now possible", path, procInfo)
			return
		}
	}

	for _, val := range userINIValues(content, "disable_functions") {
		if val == "" || val == "none" {
			fm.sendAlertWithPath(alert.Critical, "php_config_realtime",
				fmt.Sprintf("PHP disable_functions cleared: %s", path),
				"All dangerous PHP functions enabled - shell execution possible", path, procInfo)
			return
		}
	}

	// Run signature/YARA scanning on .user.ini content
	fm.runSignatureScan(data, path, ".ini", procInfo)
}

// userINIValues returns the normalised value of every "key = value" line in
// content whose key, after trimming whitespace, equals key exactly. Lines
// without "=" are ignored. Values appear in file order.
func userINIValues(content, key string) []string {
	var vals []string
	for _, line := range strings.Split(content, "\n") {
		parts := strings.SplitN(line, "=", 2)
		if len(parts) != 2 || strings.TrimSpace(parts[0]) != key {
			continue
		}
		vals = append(vals, normalizeUserINIValue(parts[1]))
	}
	return vals
}

// normalizeUserINIValue reduces the right-hand side of an INI assignment to
// its effective value: surrounding whitespace is trimmed, a quoted value is
// unwrapped up to its closing quote, and for unquoted values everything from
// the first ";" (inline comment) onward is dropped.
func normalizeUserINIValue(raw string) string {
	val := strings.TrimSpace(raw)
	if val == "" {
		return ""
	}
	if quote := val[0]; quote == '"' || quote == '\'' {
		if end := strings.IndexByte(val[1:], quote); end >= 0 {
			return strings.TrimSpace(val[1 : 1+end])
		}
		return val // unterminated quote: leave the text untouched
	}
	if semi := strings.IndexByte(val, ';'); semi >= 0 {
		val = strings.TrimSpace(val[:semi])
	}
	return val
}
// checkPHPContent scans PHP content read from the event fd for malicious
// patterns and reports whether an alert was raised (or a signature matched).
// C3 - reads from fd, not path. M4 - 32KB scan size.
//
// Heuristics run in this order on the lowercased first 32 KiB; the first one
// that fires sends a Critical alert and returns true:
//  1. looksLikePHPWebshell on the raw bytes,
//  2. any paste-site raw URL,
//  3. a GitHub raw/gist host on the same line as a write/exec call,
//  4. eval(/assert( on the same line as a decoder function,
//  5. fragmented base64_decode ("base" + "64_dec" pieces plus eval(),
//  6. more than 50 `.= "` concatenations plus eval(,
//  7. a shell function and a request superglobal on the same line.
//
// For files larger than 32 KiB, heuristics 4-6 are repeated on the last
// 32 KiB. If nothing fires, verified CMS files return false; everything else
// is handed to the signature/YARA scan, whose result is returned. It returns
// false when nothing could be read from fd.
func (fm *FileMonitor) checkPHPContent(fd int, path, procInfo string) bool {
	const scanWindow = 32768
	recordReadTruncation(fd, scanWindow, "php_check")
	data := readFromFd(fd, scanWindow)
	if data == nil {
		return false
	}
	content := strings.ToLower(string(data))
	headLines := strings.Split(content, "\n")

	// critical sends a Critical alert for this file and yields true so call
	// sites can report and stop in a single statement.
	critical := func(check, message, details string) bool {
		fm.sendAlertWithPath(alert.Critical, check, message, details, path, procInfo)
		return true
	}

	if looksLikePHPWebshell(data) {
		return critical("webshell_content_realtime",
			fmt.Sprintf("Webshell pattern detected: %s", path),
			"Request input reaches a dangerous PHP execution primitive")
	}

	// Remote payload fetching: paste sites are always suspicious.
	for _, site := range []string{"pastebin.com/raw", "paste.ee/r/", "ghostbin.co/paste/", "hastebin.com/raw/"} {
		if strings.Contains(content, site) {
			return critical("php_dropper_realtime",
				fmt.Sprintf("PHP dropper with paste site URL: %s", path),
				fmt.Sprintf("Fetches from: %s", site))
		}
	}

	// GitHub raw URLs only count when a dangerous call shares the line
	// (legitimate plugins use GitHub for update checks).
	githubHosts := []string{"gist.githubusercontent.com", "raw.githubusercontent.com"}
	writeOrExecCalls := []string{"file_put_contents(", "fwrite(", "shell_", "passthru(", "popen("}
	for _, host := range githubHosts {
		if !strings.Contains(content, host) {
			continue
		}
		for _, line := range headLines {
			if !strings.Contains(line, host) {
				continue
			}
			for _, call := range writeOrExecCalls {
				if strings.Contains(line, call) {
					return critical("php_dropper_realtime",
						fmt.Sprintf("PHP dropper fetching from GitHub with dangerous call: %s", path),
						fmt.Sprintf("URL: %s, Function: %s", host, call))
				}
			}
		}
	}

	evalStr := "eval("     // search target for PHP eval function calls
	assertStr := "assert(" // search target for PHP assert function calls
	decoders := []string{"base64_decode", "gzinflate", "gzuncompress", "str_rot13", "gzdecode"}

	// evalWithDecoder returns the decoder found on the first line that also
	// carries eval(/assert(, or "" when no line combines the two. Same-line
	// nesting is required to avoid FPs on plugins that use these functions
	// in unrelated places.
	evalWithDecoder := func(lines []string) string {
		for _, line := range lines {
			if !strings.Contains(line, evalStr) && !strings.Contains(line, assertStr) {
				continue
			}
			for _, dec := range decoders {
				if strings.Contains(line, dec) {
					return dec
				}
			}
		}
		return ""
	}
	// fragmentedBase64 spots the $a="base"; $b="64_decode"; $c=$a.$b; evasion
	// combined with an eval( anywhere in the same text.
	fragmentedBase64 := func(text string) bool {
		quotedBase := strings.Contains(text, "\"base\"") || strings.Contains(text, "'base'")
		return quotedBase && strings.Contains(text, "64_dec") && strings.Contains(text, evalStr)
	}

	if dec := evalWithDecoder(headLines); dec != "" {
		return critical("obfuscated_php_realtime",
			fmt.Sprintf("Obfuscated PHP detected: %s", path),
			fmt.Sprintf("PHP code execution with %s on same line", dec))
	}
	if fragmentedBase64(content) {
		return critical("obfuscated_php_realtime",
			fmt.Sprintf("Fragmented base64_decode evasion detected: %s", path),
			"base64_decode function name split across string variables")
	}
	// Massive variable concatenation payload ($z .= "xxxx"; repeated many times).
	if concats := strings.Count(content, ".= \""); concats > 50 && strings.Contains(content, evalStr) {
		return critical("obfuscated_php_realtime",
			fmt.Sprintf("Concatenation payload detected: %s (%d concat ops)", path, concats),
			"Variable built from hundreds of string concatenations then executed")
	}

	// Shell execution with request input. containsFunc is used for the
	// function names to avoid substring false positives
	// (e.g. "WP_Filesystem(" or "preg_match(" being taken for "exec(").
	shellFuncs := []string{"system(", "passthru(", "exec(", "shell_exec(", "popen("}
	requestVars := []string{"$_request", "$_post", "$_get", "$_cookie", "$_server"}
	callsShell := func(text string) bool {
		for _, fn := range shellFuncs {
			if containsFunc(text, fn) {
				return true
			}
		}
		return false
	}
	readsRequest := func(text string) bool {
		for _, v := range requestVars {
			if strings.Contains(text, v) {
				return true
			}
		}
		return false
	}
	if callsShell(content) && readsRequest(content) {
		// Both tokens exist somewhere in the file; alert only when they meet
		// on one line. Same-line narrowing is the actual detection: admin
		// panels with both tokens in unrelated contexts stay quiet, and no
		// file-wide allowlist token exists for a webshell to forge.
		for _, line := range headLines {
			if callsShell(line) && readsRequest(line) {
				return critical("webshell_content_realtime",
					fmt.Sprintf("Webshell pattern detected: %s", path),
					fmt.Sprintf("Shell execution with request input on same line: %s", strings.TrimSpace(line)))
			}
		}
	}

	// Tail scan: for large files, also check the last 32KB. Attackers append
	// payloads (eval+base64) at the end of legitimate PHP files, beyond the
	// head window. Only the cheap heuristics run here, not signature scanning.
	if tailData := readTailFromFd(fd, scanWindow); tailData != nil {
		tail := strings.ToLower(string(tailData))
		if dec := evalWithDecoder(strings.Split(tail, "\n")); dec != "" {
			return critical("obfuscated_php_realtime",
				fmt.Sprintf("Obfuscated PHP appended to file tail: %s", path),
				fmt.Sprintf("PHP code execution with %s found at end of file", dec))
		}
		if fragmentedBase64(tail) {
			return critical("obfuscated_php_realtime",
				fmt.Sprintf("Fragmented base64_decode evasion in file tail: %s", path),
				"Payload appended at end of legitimate PHP file")
		}
		if concats := strings.Count(tail, ".= \""); concats > 50 && strings.Contains(tail, evalStr) {
			return critical("obfuscated_php_realtime",
				fmt.Sprintf("Concatenation payload in file tail: %s (%d concat ops)", path, concats),
				"Payload appended at end of legitimate PHP file")
		}
	}

	// Skip signature/YARA scanning for verified CMS core files: the wp_core
	// periodic check validates them against official checksums, so signature
	// hits there are false positives (e.g. $_POST in wp-includes, mail() in
	// PHPMailer, fsockopen() in POP3.php).
	if checks.IsVerifiedCMSFile(path) {
		return false
	}

	// External signature + YARA scanning on the head window.
	return fm.runSignatureScan(data, path, filepath.Ext(path), procInfo)
}
// checkHTMLPhishing looks for credential-harvesting pages among HTML files
// written under a /public_html/ path. Size comes from unix.Fstat on the event
// fd and content from the same fd (TOCTOU-safe); only files of 500 to
// 500000 bytes are examined, and only their first 16 KiB.
//
// A Critical "phishing_realtime" alert requires all of:
//   - a <form or <input tag,
//   - a credential input (email/password type or name, or a "you@"
//     placeholder),
//   - a brand-impersonation phrase,
//   - a redirect/exfiltration pattern or a trust-badge phrase.
//
// Pages that pass the first three gates but not the last are handed to the
// signature/YARA scan instead.
//
// There is deliberately no path allowlist: the content gates reject
// legitimate framework HTML on their own, whereas an earlier allowlist for
// /wp-admin/, /wp-content/themes/, /wp-content/plugins/, /node_modules/,
// /vendor/ and /.well-known/ let an attacker who compromised any of those
// directories drop a harvesting page with full suppression.
func (fm *FileMonitor) checkHTMLPhishing(fd int, path, procInfo string) {
	// Only check files in web-accessible directories.
	if !strings.Contains(path, "/public_html/") {
		return
	}

	var info unix.Stat_t
	if err := unix.Fstat(fd, &info); err != nil {
		return
	}
	size := info.Size
	if size < 500 || size > 500000 {
		return
	}

	const window = 16384
	recordReadTruncation(fd, window, "html_phishing")
	data := readFromFd(fd, window)
	if data == nil {
		return
	}
	content := strings.ToLower(string(data))

	// hasAny reports whether the lowercased page contains any of the needles.
	hasAny := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(content, needle) {
				return true
			}
		}
		return false
	}

	// Gate 1: some form markup must be present.
	if !hasAny("<form", "<input") {
		return
	}
	// Gate 2: a field that collects credentials.
	if !hasAny(
		"type=\"email\"",
		"type=\"password\"",
		"type='email'",
		"type='password'",
		"name=\"email\"",
		"name=\"password\"",
		"placeholder=\"you@",
	) {
		return
	}

	// Gate 3: brand impersonation. The first brand with a matching phrase wins.
	brands := []struct {
		name     string
		patterns []string
	}{
		{"Microsoft/SharePoint", []string{"sharepoint", "onedrive", "microsoft 365", "outlook web", "office 365"}},
		{"Google", []string{"google drive", "google docs", "accounts.google", "gmail"}},
		{"Dropbox", []string{"dropbox"}},
		{"DocuSign", []string{"docusign"}},
		{"Adobe", []string{"adobe sign", "adobe document"}},
		{"Apple/iCloud", []string{"icloud", "apple id"}},
		{"Webmail", []string{"roundcube", "horde", "webmail login", "zimbra"}},
		{"Generic", []string{"secure access", "verify your", "confirm your identity", "account verification"}},
	}
	impersonated := ""
	for _, brand := range brands {
		if hasAny(brand.patterns...) {
			impersonated = brand.name
			break
		}
	}
	if impersonated == "" {
		return
	}

	// Gate 4: redirect/exfiltration code, or a trust badge (strong phishing
	// signal). The badge list carries both hyphen variants of "256-bit".
	exfiltrates := hasAny(
		"window.location.href", "window.location.replace", "window.location =",
		".workers.dev", "fetch(", "xmlhttprequest",
	)
	showsTrustBadge := hasAny(
		"secured by microsoft",
		"secured by google",
		"256-bit encrypted",
		"256‑bit encrypted",
	)
	if !exfiltrates && !showsTrustBadge {
		// Not caught by the phishing heuristics: fall back to signature/YARA.
		fm.runSignatureScan(data, path, ".html", procInfo)
		return
	}

	fm.sendAlertWithPath(alert.Critical, "phishing_realtime",
		fmt.Sprintf("Phishing page created (%s impersonation): %s", impersonated, path),
		fmt.Sprintf("Size: %d bytes", size), path, procInfo)
}
// checkCredentialLog decides whether a text file holds harvested
// email:password pairs, i.e. the output of an active phishing kit. The path
// is only used for location/extension filtering; the first 4 KiB of content
// is read from the event fd (not re-opened by path) so the file cannot be
// swapped between the event and the read.
//
// Only paths containing /public_html/ are examined; paths under /etc/ and
// files ending in .conf, .cfg, .ini, .yaml or .yml are skipped because they
// legitimately contain email-like text.
//
// Every non-blank line containing "@" counts as an email line. Such a line
// also counts as a credential line when, for one of the delimiters ":", "|",
// tab or ",", the first field contains "@" and the second field is non-empty
// and free of spaces. Five or more credential lines raise a Critical alert;
// otherwise ten or more email lines raise a High alert.
func (fm *FileMonitor) checkCredentialLog(fd int, path, procInfo string) {
	if !strings.Contains(path, "/public_html/") || strings.HasPrefix(path, "/etc/") {
		return
	}
	for _, ext := range []string{".conf", ".cfg", ".ini", ".yaml", ".yml"} {
		if strings.HasSuffix(path, ext) {
			return
		}
	}

	data := readFromFd(fd, 4096)
	if data == nil {
		return
	}

	// isCredentialPair reports whether row splits into "<email><delim><secret>"
	// for any supported delimiter.
	isCredentialPair := func(row string) bool {
		for _, delim := range []string{":", "|", "\t", ","} {
			fields := strings.SplitN(row, delim, 3)
			if len(fields) < 2 {
				continue
			}
			account := strings.TrimSpace(fields[0])
			secret := strings.TrimSpace(fields[1])
			if strings.Contains(account, "@") && len(secret) > 0 && !strings.Contains(secret, " ") {
				return true
			}
		}
		return false
	}

	var pairs, emails int
	for _, raw := range strings.Split(string(data), "\n") {
		row := strings.TrimSpace(raw)
		if row == "" || !strings.Contains(row, "@") {
			continue
		}
		emails++
		if isCredentialPair(row) {
			pairs++
		}
	}

	switch {
	case pairs >= 5:
		fm.sendAlertWithPath(alert.Critical, "credential_log_realtime",
			fmt.Sprintf("Harvested credential log detected: %s", path),
			fmt.Sprintf("%d credential lines (email:password format) found", pairs), path, procInfo)
	case emails >= 10:
		fm.sendAlertWithPath(alert.High, "credential_log_realtime",
			fmt.Sprintf("Possible harvested email list: %s", path),
			fmt.Sprintf("%d email addresses found in %s", emails, filepath.Base(path)), path, procInfo)
	}
}
// checkPhishingZip flags a newly created ZIP under /public_html/ whose file
// name looks like a phishing-kit archive. The decision is purely name-based:
// nameLower must contain BOTH a brand token and a phishing-suggestive token,
// as in "office365-login.zip", "paypal-verify.zip" or "microsoft-secure.zip".
// Plain plugin distribution backups (google-site-kit.zip, mailchimp.zip)
// carry a brand without such a token and stay quiet. A hit raises a High
// "phishing_kit_realtime" alert naming the first brand and first indicator
// found, in list order.
func (fm *FileMonitor) checkPhishingZip(path, nameLower, procInfo string) {
	if !strings.Contains(path, "/public_html/") {
		return
	}

	// firstContained returns the first token that occurs in the file name,
	// or "" when none does.
	firstContained := func(tokens []string) string {
		for _, tok := range tokens {
			if strings.Contains(nameLower, tok) {
				return tok
			}
		}
		return ""
	}

	// Brand impersonation targets: services users log in to.
	brand := firstContained([]string{
		"office365", "office 365", "sharepoint", "onedrive",
		"microsoft", "outlook", "google", "gmail",
		"dropbox", "docusign", "adobe", "wetransfer",
		"paypal", "apple", "icloud", "netflix",
		"facebook", "instagram", "linkedin",
		"webmail", "roundcube", "cpanel",
	})
	if brand == "" {
		return
	}

	// Phishing-suggestive verbs/nouns; these must co-occur with a brand.
	// "kit" is intentionally absent -- many official WordPress plugin slugs
	// end in -kit (google-site-kit, mailchimp-for-wp-kit) and were the
	// dominant false-positive source.
	indicator := firstContained([]string{
		"login", "signin", "sign-in", "sign_in",
		"verify", "verification",
		"secure", "security",
		"phish", "scam",
		"bank", "account",
		"capture", "harvest", "steal",
	})
	if indicator == "" {
		return
	}

	fm.sendAlertWithPath(alert.High, "phishing_kit_realtime",
		fmt.Sprintf("Suspected phishing kit archive uploaded: %s", path),
		fmt.Sprintf("Filename combines brand '%s' with phishing indicator '%s'", brand, indicator),
		path, procInfo)
}
// runSignatureScan applies the YAML signature scanner and then the YARA
// scanner to data, and reports whether either produced a match.
//
// YAML signatures (signatures.Global, when configured): only the first match
// is reported. Its severity is Critical when the rule says "critical",
// otherwise High.
//   - Non-critical matches are deduplicated per rule+directory, so many files
//     in one plugin directory hitting the same rule yield a single alert; a
//     match suppressed by that dedup still returns true.
//   - Critical matches always alert per file and additionally go through
//     checks.InlineQuarantineGated; when that succeeds an "auto_response"
//     alert records where the file was quarantined.
//
// YARA (yara.Active, when configured) runs only if no YAML signature matched;
// any match raises a Critical "yara_match_realtime" alert naming the first
// rule and the number of rules matched.
func (fm *FileMonitor) runSignatureScan(data []byte, path, ext, procInfo string) bool {
	const check = "signature_match_realtime"

	if scanner := signatures.Global(); scanner != nil {
		if hits := scanner.ScanContent(data, ext); len(hits) > 0 {
			top := hits[0]
			level := alert.High
			if top.Severity == "critical" {
				level = alert.Critical
			}
			isCritical := level == alert.Critical

			// Non-critical: dedup by rule+directory so 30 files in the same
			// plugin matching the same rule produce one alert, not 30.
			if !isCritical && !fm.shouldAlert(check, top.RuleName+":"+filepath.Dir(path)) {
				return true // suppressed by dedup, but still counts as "matched"
			}

			details := fmt.Sprintf("Category: %s\nDescription: %s\nMatched: %s",
				top.Category, top.Description, strings.Join(top.Matched, ", "))
			fm.sendAlertWithPath(level, check,
				fmt.Sprintf("Signature match [%s]: %s", top.RuleName, path),
				details, path, procInfo)

			// Inline quarantine: move high-confidence malware immediately
			// rather than waiting for the batch dispatcher. The gated helper
			// applies the same validation and auto-response policy as the
			// batch path, so the realtime path never moves a file the batch
			// path would leave in place.
			if isCritical {
				finding := alert.Finding{
					Severity: level,
					Check:    check,
					Details:  details,
					FilePath: path,
				}
				if dest, moved := checks.InlineQuarantineGated(fm.currentCfg(), finding, path, data); moved {
					fm.sendAlert(alert.Critical, "auto_response",
						fmt.Sprintf("AUTO-QUARANTINE (inline): %s moved to quarantine", path),
						fmt.Sprintf("Quarantined to: %s\nRule: %s", dest, top.RuleName))
				}
			}
			return true
		}
	}

	if engine := yara.Active(); engine != nil {
		if hits := engine.ScanBytes(data); len(hits) > 0 {
			fm.sendAlertWithPath(alert.Critical, "yara_match_realtime",
				fmt.Sprintf("YARA rule match [%s]: %s", hits[0].RuleName, path),
				fmt.Sprintf("Matched %d YARA rule(s)", len(hits)), path, procInfo)
			return true
		}
	}

	return false
}
// M7 - sendAlert queues a path-less finding on the alert channel without
// blocking. When the channel is full the finding is discarded and the
// droppedAlerts counter (kept separately from droppedEvents) is incremented.
// No dedup is applied here: this is meant for system-level alerts such as
// overflow reporting that are already ticker-gated. Alerts about a specific
// file should go through sendAlertWithPath.
func (fm *FileMonitor) sendAlert(severity alert.Severity, check, message, details string) {
	select {
	case fm.alertCh <- alert.Finding{
		Severity:  severity,
		Check:     check,
		Message:   message,
		Details:   details,
		Timestamp: time.Now(),
	}:
		// queued
	default:
		// Channel full: never block the caller, just count the loss.
		atomic.AddInt64(&fm.droppedAlerts, 1)
	}
}
// sendAlertWithPath queues a finding that also carries FilePath and
// ProcessInfo, so auto-response receives them as structured fields.
// The check+filePath pair is passed through shouldAlert first; a suppressed
// pair returns without queuing anything. The send itself never blocks: if the
// alert channel is full the finding is dropped and droppedAlerts is bumped.
func (fm *FileMonitor) sendAlertWithPath(severity alert.Severity, check, message, details, filePath, processInfo string) {
	if fm.shouldAlert(check, filePath) {
		select {
		case fm.alertCh <- alert.Finding{
			Severity:    severity,
			Check:       check,
			Message:     message,
			Details:     details,
			FilePath:    filePath,
			ProcessInfo: processInfo,
			Timestamp:   time.Now(),
		}:
			// queued
		default:
			atomic.AddInt64(&fm.droppedAlerts, 1)
		}
	}
}
// shouldAlert reports whether an alert for this check+path combination may be
// emitted now, i.e. none was emitted within the last alertDedupTTL. It
// prevents duplicate alerts from rapid writes to the same file.
//
// An empty filePath disables dedup and always returns true.
//
// Both the initial insertion and the post-expiry refresh are atomic, so when
// several analyzer workers race on the same key exactly one of them gets
// true:
//   - LoadOrStore handles the first sighting of a key.
//   - CompareAndSwap handles an expired entry: only the worker that replaces
//     the exact stale timestamp it observed wins. A worker that loses the
//     swap retries, and then either sees the winner's fresh timestamp
//     (suppressed) or re-inserts the key if it was evicted in the meantime.
func (fm *FileMonitor) shouldAlert(check, filePath string) bool {
	if filePath == "" {
		return true // no path = no dedup possible
	}
	key := check + ":" + filePath
	now := time.Now()
	for {
		prev, loaded := fm.alertDedup.LoadOrStore(key, now)
		if !loaded {
			return true // first alert for this key
		}
		if now.Sub(prev.(time.Time)) < alertDedupTTL {
			return false // still inside the dedup window
		}
		// Entry expired: refresh it only if nobody else has touched it
		// since we read it, so a single caller claims the new window.
		if fm.alertDedup.CompareAndSwap(key, prev, now) {
			return true
		}
	}
}
// M7 - overflowReporter reports dropped events and alerts separately.
// It loops until fm.stopCh fires and signals fm.wg on exit.
//
// Two event sources feed this loop:
//   - 1-minute ticker: swaps both drop counters to zero; if events were
//     dropped it emits the periodic fanotify_overflow alert and runs
//     reconcileDrops; if alerts were dropped it logs the count to stderr;
//     it then evicts stale alert dedup entries and stale plugin stat cache
//     entries.
//   - reconcileSig: out-of-cycle reconcile triggered by sendEvent when
//     sustained drops cross eagerReconcileDropThreshold within the
//     current tick. Closes the latency gap between a drop and its
//     reconcile read so the file's mtime is still inside reconcileWindow.
func (fm *FileMonitor) overflowReporter() {
defer fm.wg.Done()
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-fm.stopCh:
return
case <-fm.reconcileSig:
// Eager reconcile: do not reset counters, do not emit the
// minute-tick alert. Just walk the tracked dirs and surface
// any interesting file inside reconcileWindow. The minute
// tick will still fire its alert + drain the counters when
// it arrives.
fm.reconcileDrops()
case <-ticker.C:
// Read-and-reset both counters atomically so drops recorded while
// this tick is being processed are counted in the next minute.
droppedEv := atomic.SwapInt64(&fm.droppedEvents, 0)
droppedAl := atomic.SwapInt64(&fm.droppedAlerts, 0)
if droppedEv > 0 {
fm.sendAlert(alert.Warning, "fanotify_overflow",
fmt.Sprintf("fanotify event queue overflowed: %d events dropped in last minute", droppedEv),
"Possible event storm (backup, bulk update) or high-volume attack")
// Recover coverage: scan files in directories that saw drops
// so a threat landing during the storm is still detected.
fm.reconcileDrops()
}
// Dropped alerts are reported on stderr rather than through the alert
// channel, which is the thing that was full.
if droppedAl > 0 {
fmt.Fprintf(os.Stderr, "[%s] alert channel full: %d alerts dropped in last minute\n", ts(), droppedAl)
}
// Evict stale dedup entries every minute
// NOTE(review): this sweep uses a plain Delete, whereas
// evictStalePluginStatCache uses CompareAndDelete; an entry refreshed
// concurrently between the age check and the Delete would be removed.
// Confirm whether that is acceptable here.
now := time.Now()
fm.alertDedup.Range(func(key, value any) bool {
if now.Sub(value.(time.Time)) > alertDedupTTL {
fm.alertDedup.Delete(key)
}
return true
})
evictStalePluginStatCache(now)
}
}
}
// evictStalePluginStatCache keeps the package-level plugin update stat cache
// bounded. Entries older than twice pluginCacheTTL (measured against now) are
// removed, as is any value that is not a pluginCacheEntry.
// Stale entries are removed with CompareAndDelete so that a fresh stat result
// stored by an analyzer worker after Range observed the older entry survives
// the sweep.
func evictStalePluginStatCache(now time.Time) {
	maxAge := 2 * pluginCacheTTL
	pluginStatCache.Range(func(key, value any) bool {
		switch cached := value.(type) {
		case pluginCacheEntry:
			if now.Sub(cached.ts) > maxAge {
				pluginStatCache.CompareAndDelete(key, cached)
			}
		default:
			// Unexpected value type: drop it unconditionally.
			pluginStatCache.Delete(key)
		}
		return true
	})
}
// containsFunc reports whether content contains funcCall as a standalone
// call, i.e. at least one occurrence that is not the tail of a longer
// identifier. An occurrence counts when it sits at the very start of content
// or when the byte right before it is not an ASCII letter, digit, or
// underscore. This keeps "WP_Filesystem(" from matching "system(" and
// "shell_exec(" from matching "exec(".
func containsFunc(content, funcCall string) bool {
	isIdentByte := func(c byte) bool {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '_':
			return true
		}
		return false
	}

	searchFrom := 0
	for {
		rel := strings.Index(content[searchFrom:], funcCall)
		if rel < 0 {
			return false
		}
		at := searchFrom + rel
		if at == 0 || !isIdentByte(content[at-1]) {
			return true
		}
		// Embedded in a longer identifier: resume after this occurrence.
		searchFrom = at + len(funcCall)
		if searchFrom >= len(content) {
			return false
		}
	}
}
// isPHPExtension reports whether nameLower is treated as an executable PHP
// file name. The decision is delegated to checks.IsExecutablePHPName, the
// single source of truth shared with the periodic content scanners, so the
// realtime and batch paths never drift on which extensions execute PHP.
func isPHPExtension(nameLower string) bool {
	executable := checks.IsExecutablePHPName(nameLower)
	return executable
}
// isCGIExtension reports whether nameLower ends in one of the script
// extensions handled as CGI: .pl, .cgi, .py, .sh or .rb. The comparison is
// case-sensitive, so the caller is expected to pass a lower-cased name.
func isCGIExtension(nameLower string) bool {
	for _, ext := range [...]string{".pl", ".cgi", ".py", ".sh", ".rb"} {
		if strings.HasSuffix(nameLower, ext) {
			return true
		}
	}
	return false
}
// checkCGIBackdoor reads up to 32 KiB of a CGI script from fd and looks for
// backdoor traits (Perl/Python/Bash backdoors such as the LEVIATHAN toolkit).
//
// Checks run in this order and the first one that fires ends the function:
//  1. Indicator count: the lower-cased content is searched for a fixed list
//     of substrings; four or more distinct hits raise a Critical
//     cgi_backdoor_realtime alert.
//  2. Location: a script whose path contains an asset directory (img, images,
//     css, js, fonts, icons) raises a High cgi_suspicious_location_realtime
//     alert.
//  3. Otherwise the content is handed to runSignatureScan.
//
// path and procInfo are forwarded to the alerts unchanged.
func (fm *FileMonitor) checkCGIBackdoor(fd int, path, procInfo string) {
	const readLimit = 32768
	recordReadTruncation(fd, readLimit, "cgi_backdoor")
	data := readFromFd(fd, readLimit)
	if data == nil {
		return
	}
	content := strings.ToLower(string(data))

	// Each indicator counts once, all equally. Generic patterns like
	// "request_method" and "cmd" were removed because they match every CGI
	// script.
	indicators := [...]struct {
		pattern string
		desc    string
	}{
		{"system(", "system() call"},
		{"os.popen", "os.popen() call"},
		{"`$", "backtick execution with variable"},
		{"content_length", "reads POST body length"},
		{"base64_decode", "base64 decoding"},
		{"$_post", "PHP POST input"},
		{"$_get", "PHP GET input"},
		{"param(", "CGI parameter read"},
		{"qs.parse", "query string parsing"},
	}
	var matched []string
	for i := range indicators {
		if strings.Contains(content, indicators[i].pattern) {
			matched = append(matched, indicators[i].desc)
		}
	}

	// 4+ indicators = likely backdoor
	if hits := len(matched); hits >= 4 {
		fm.sendAlertWithPath(alert.Critical, "cgi_backdoor_realtime",
			fmt.Sprintf("CGI backdoor detected: %s", path),
			fmt.Sprintf("Indicators (%d): %s", hits, strings.Join(matched, ", ")), path, procInfo)
		return
	}

	// CGI scripts in unusual locations (images, css, js directories)
	for _, assetDir := range [...]string{"/img/", "/images/", "/css/", "/js/", "/fonts/", "/icons/"} {
		if strings.Contains(path, assetDir) {
			fm.sendAlertWithPath(alert.High, "cgi_suspicious_location_realtime",
				fmt.Sprintf("CGI script in non-CGI directory: %s", path),
				"Scripts should not exist in image/css/js directories", path, procInfo)
			return
		}
	}

	// Nothing heuristic fired: fall back to the signature scanners.
	fm.runSignatureScan(data, path, filepath.Ext(path), procInfo)
}
// matchSuppression reports whether path is covered by a suppression pattern
// such as "*/cache/*", "*/vendor/*" or "*.log".
//
// Matching is attempted in this order:
//  1. filepath.Match against the full path, then against the base name
//     (malformed patterns simply do not match).
//  2. A pattern without any glob metacharacter is a plain substring test.
//  3. A pattern whose only metacharacter is '*' and whose first path segment
//     consists solely of '*' is treated as "any depth": the stars are removed
//     and the remaining literal must appear in path, provided that literal
//     contains a '/' and is not made up of slashes only.
//
// Everything else, including the empty pattern, does not match.
func matchSuppression(pattern, path string) bool {
	if pattern == "" {
		return false
	}
	for _, candidate := range []string{path, filepath.Base(path)} {
		if ok, _ := filepath.Match(pattern, candidate); ok {
			return true
		}
	}
	switch {
	case !strings.ContainsAny(pattern, "*?["):
		return strings.Contains(path, pattern)
	case strings.ContainsAny(pattern, "?["):
		return false
	case !hasLeadingAnyDepthSuppressionGlob(pattern):
		return false
	}
	literal := strings.ReplaceAll(pattern, "*", "")
	return strings.Contains(literal, "/") &&
		strings.Trim(literal, "/") != "" &&
		strings.Contains(path, literal)
}
// hasLeadingAnyDepthSuppressionGlob reports whether the first path segment of
// pattern (the text before the first '/') is non-empty and consists only of
// '*' characters, as in "*/cache/*" or "**/vendor/*". Patterns without a '/'
// or starting with '/' yield false.
func hasLeadingAnyDepthSuppressionGlob(pattern string) bool {
	head, _, hasSlash := strings.Cut(pattern, "/")
	if !hasSlash || head == "" {
		return false
	}
	for i := 0; i < len(head); i++ {
		if head[i] != '*' {
			return false
		}
	}
	return true
}
// readHead opens the file at path and returns up to its first maxBytes bytes.
// It returns nil when maxBytes is not positive, when the file cannot be
// opened or read, or when nothing was read.
// Kept for path-based checks (HTML phishing, credential logs, ZIP checks)
// that need os.Stat for file size anyway.
func readHead(path string, maxBytes int) []byte {
	if maxBytes < 1 {
		return nil
	}
	// #nosec G304 -- readHead scans files surfaced by fanotify/scanner;
	// reading user files for signature analysis is the daemon's purpose.
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil
	}
	defer func() { _ = file.Close() }()

	// Drain a size-capped reader instead of issuing one Read into a
	// pre-sized buffer: a single short read would hand the detectors only a
	// prefix and silently miss content deeper in the file.
	capped := io.LimitReader(file, int64(maxBytes))
	head, readErr := io.ReadAll(capped)
	if readErr == nil && len(head) > 0 {
		return head
	}
	return nil
}
// looksLikePluginUpdate reports whether path lies inside what looks like a
// WordPress plugin-update temp directory under wp-content/uploads (for
// example uploads/elementor_t0q9y/...).
//
// The first directory below uploads/ is taken as "{plugin}_{random}": when
// the text after its last underscore is 4 to 10 bytes long, that suffix is
// stripped to obtain the plugin name; otherwise the whole directory name is
// used. The result is true when wp-content/plugins/{plugin} exists in the
// same WordPress root. No plugin whitelist is involved.
//
// M3 - the os.Stat outcome is cached per plugin directory in pluginStatCache
// and reused while it is younger than pluginCacheTTL.
func looksLikePluginUpdate(path string) bool {
	const uploadsMarker = "/wp-content/uploads/"

	wpRoot, belowUploads, inUploads := strings.Cut(path, uploadsMarker)
	if !inUploads {
		return false
	}
	// A file directly in uploads/ has no directory component to inspect.
	dirName, _, nested := strings.Cut(belowUploads, "/")
	if !nested {
		return false
	}

	// "header-footer_7ocsd" -> "header-footer" when the suffix length fits.
	pluginName := dirName
	if cut := strings.LastIndex(dirName, "_"); cut > 0 {
		if suffixLen := len(dirName) - cut - 1; suffixLen >= 4 && suffixLen <= 10 {
			pluginName = dirName[:cut]
		}
	}

	pluginDir := wpRoot + "/wp-content/plugins/" + pluginName

	// M3 - serve from cache while the entry is fresh.
	if cached, ok := pluginStatCache.Load(pluginDir); ok {
		if entry := cached.(pluginCacheEntry); time.Since(entry.ts) < pluginCacheTTL {
			return entry.exists
		}
	}

	_, statErr := os.Stat(pluginDir)
	fresh := pluginCacheEntry{
		exists: statErr == nil,
		ts:     time.Now(),
	}
	pluginStatCache.Store(pluginDir, fresh)
	return fresh.exists
}
package daemon
import (
"time"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/mailfwd/adapter"
"github.com/pidginhost/csm/internal/mailfwd/guard"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// forwardGuardBadIPScore is the reputation score at/above which a sender IP is
// treated as bad for the forward-guard's bad_sender_ip signal. The comparison
// in forwardGuardBadIPs is inclusive (Score >= forwardGuardBadIPScore).
const forwardGuardBadIPScore = 50
// forwardGuardRefreshInterval is how often the bad-IP lookup file is refreshed
// from the reputation DB; it is the ticker period used by
// forwardGuardRefresher. exim reads the lsearch file per lookup, so this only
// rewrites a file -- no exim rebuild or reload.
const forwardGuardRefreshInterval = 15 * time.Minute
// forwardGuardReconciler assembles a guard.Reconciler for the current host:
// an exim adapter as the guard, Active set from cPanel platform detection,
// and forwardGuardBadIPs as the bad-IP source. The guard is only active on
// cPanel/exim; elsewhere Reconcile/RefreshBadIPs are no-ops.
func (d *Daemon) forwardGuardReconciler() guard.Reconciler {
	var r guard.Reconciler
	r.Guard = adapter.NewEximAdapter()
	r.Active = platform.Detect().IsCPanel()
	r.BadIPs = d.forwardGuardBadIPs
	return r
}
// forwardGuardBadIPs lists every IP whose reputation score is at or above
// forwardGuardBadIPScore. The result is nil when the global store is
// unavailable -- the guard then simply holds nothing on the bad-IP signal
// (the null-sender signal is unaffected).
func (d *Daemon) forwardGuardBadIPs() []string {
	var bad []string
	if db := store.Global(); db != nil {
		for addr, rep := range db.AllReputation() {
			if rep.Score >= forwardGuardBadIPScore {
				bad = append(bad, addr)
			}
		}
	}
	return bad
}
// reconcileForwardGuard brings the exim forward-guard in line with the
// current EmailProtection.ForwardGuard config by running the reconciler.
// A failure is only logged: a guard problem must never take the daemon down
// or block mail (fail-open).
func (d *Daemon) reconcileForwardGuard() {
	guardCfg := d.currentCfg().EmailProtection.ForwardGuard
	reconciler := d.forwardGuardReconciler()
	err := reconciler.Reconcile(guardCfg)
	if err == nil {
		return
	}
	csmlog.Error("forward-guard reconcile failed", "err", err)
}
// forwardGuardRefresher runs until d.stopCh fires, and on every
// forwardGuardRefreshInterval tick asks the reconciler to refresh the bad-IP
// lookup file using the config current at that moment. Refresh errors are
// logged and the loop keeps going. It signals d.wg when it exits.
func (d *Daemon) forwardGuardRefresher() {
	defer d.wg.Done()

	tick := time.NewTicker(forwardGuardRefreshInterval)
	defer tick.Stop()

	for {
		select {
		case <-d.stopCh:
			return
		case <-tick.C:
		}

		guardCfg := d.currentCfg().EmailProtection.ForwardGuard
		err := d.forwardGuardReconciler().RefreshBadIPs(guardCfg)
		if err != nil {
			csmlog.Error("forward-guard bad-IP refresh failed", "err", err)
		}
	}
}
package daemon
import (
"bufio"
"fmt"
"os"
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// parseValiasFileForFindings parses one valiases file and returns every
// finding it yields, external-forwarder findings included. It is shorthand
// for parseValiasFileForFindingsFiltered with includeExternal set to true.
// Used by both the realtime watcher and tests.
func parseValiasFileForFindings(path, domain string, localDomains map[string]bool, knownForwarders []string) []alert.Finding {
	const includeExternal = true
	return parseValiasFileForFindingsFiltered(path, domain, localDomains, knownForwarders, includeExternal)
}
// parseValiasFileForFindingsFiltered scans the valiases file at path line by
// line and returns a finding for every suspicious forwarder destination.
//
// Each usable line has the form "local: dest[, dest...]"; blank lines,
// '#' comments, lines without ':' and lines with an empty side are skipped.
// For every non-empty destination that is not on the knownForwarders
// suppression list:
//   - a destination starting with "|" yields a Critical email_pipe_forwarder;
//   - the destination "/dev/null" yields a High email_suspicious_forwarder;
//   - when includeExternal is true, an address whose domain (text after the
//     last '@', lower-cased) is not in localDomains yields a High
//     email_suspicious_forwarder, worded as a wildcard catch-all when the
//     local part is "*".
//
// domain is used for messages and suppression matching only. A file that
// cannot be opened yields nil.
func parseValiasFileForFindingsFiltered(path, domain string, localDomains map[string]bool, knownForwarders []string, includeExternal bool) []alert.Finding {
	// #nosec G304 -- path from cPanel valiases directory walk; operator-scoped.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer f.Close()

	var findings []alert.Finding
	lines := bufio.NewScanner(f)
	for lines.Scan() {
		line := strings.TrimSpace(lines.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		rawLocal, rawDest, hasColon := strings.Cut(line, ":")
		if !hasColon {
			continue
		}
		localPart, dest := strings.TrimSpace(rawLocal), strings.TrimSpace(rawDest)
		if localPart == "" || dest == "" {
			continue
		}

		for _, target := range strings.Split(dest, ",") {
			target = strings.TrimSpace(target)
			if target == "" {
				continue
			}
			// Suppression check
			if isKnownForwarderWatcher(localPart, domain, target, knownForwarders) {
				continue
			}

			switch {
			case strings.HasPrefix(target, "|"):
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_pipe_forwarder",
					Message:  fmt.Sprintf("Pipe forwarder detected: %s@%s -> %s", localPart, domain, target),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, target, path),
				})

			case target == "/dev/null":
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  fmt.Sprintf("Mail blackhole: %s@%s -> /dev/null", localPart, domain),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: /dev/null\nFile: %s", domain, localPart, path),
				})

			case includeExternal:
				// Only addresses with a non-empty domain part are considered.
				at := strings.LastIndexByte(target, '@')
				if at < 0 || at >= len(target)-1 {
					continue
				}
				if localDomains[strings.ToLower(target[at+1:])] {
					continue
				}
				var msg string
				if localPart == "*" {
					msg = fmt.Sprintf("Wildcard catch-all to external: *@%s -> %s", domain, target)
				} else {
					msg = fmt.Sprintf("External forwarder: %s@%s -> %s", localPart, domain, target)
				}
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  msg,
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, target, path),
				})
			}
		}
	}
	return findings
}
// isKnownForwarderWatcher reports whether the forwarder is on the suppression
// list. The forwarder is rendered as "local@domain: dest" and compared
// case-insensitively against each knownForwarders entry after trimming the
// entry's surrounding whitespace.
// Separate from checks package to avoid import cycles.
func isKnownForwarderWatcher(localPart, domain, dest string, knownForwarders []string) bool {
	want := fmt.Sprintf("%s@%s: %s", localPart, domain, dest)
	for i := range knownForwarders {
		candidate := strings.TrimSpace(knownForwarders[i])
		if strings.EqualFold(candidate, want) {
			return true
		}
	}
	return false
}
//go:build linux
package daemon
import (
"crypto/sha256"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/store"
)
// valiasesDir is the inotify watch root: NewForwarderWatcher adds the watch
// on it and handleFileChange joins event names onto it to locate the changed
// file. It is a var (not const) so tests can redirect it under t.TempDir()
// without touching the real /etc/valiases directory; mirrors
// cronSpoolWatchDir in fanotify.go.
var valiasesDir = "/etc/valiases"
// ForwarderWatcher watches /etc/valiases/ for changes using inotify.
// Create it with NewForwarderWatcher and drive it with Run.
type ForwarderWatcher struct {
// alertCh receives findings; sends are non-blocking and a full channel
// drops the finding (see handleFileChange).
alertCh chan<- alert.Finding
// knownForwarders is the suppression list handed to the valiases parser.
knownForwarders []string
// inotifyFd is the non-blocking inotify descriptor; Run closes it on stop.
inotifyFd int
}
// NewForwarderWatcher opens a non-blocking, close-on-exec inotify instance
// and registers an IN_CLOSE_WRITE watch on valiasesDir. Findings will be
// delivered on alertCh; knownForwarders is the suppression list.
// It returns a wrapped error if inotify cannot be initialised or the watch
// cannot be added; in the latter case the descriptor is closed first.
func NewForwarderWatcher(alertCh chan<- alert.Finding, knownForwarders []string) (*ForwarderWatcher, error) {
	fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
	if err != nil {
		return nil, fmt.Errorf("inotify_init1: %w", err)
	}
	if _, err := unix.InotifyAddWatch(fd, valiasesDir, unix.IN_CLOSE_WRITE); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("inotify_add_watch(%s): %w", valiasesDir, err)
	}

	fw := &ForwarderWatcher{inotifyFd: fd}
	fw.alertCh = alertCh
	fw.knownForwarders = knownForwarders
	return fw, nil
}
// Run is the watch loop. It blocks until stopCh is closed, at which point it
// closes the inotify descriptor and returns.
// The descriptor is polled every two seconds rather than waited on: combining
// an inotify fd with stopCh would otherwise require epoll, and polling keeps
// this simple.
func (fw *ForwarderWatcher) Run(stopCh <-chan struct{}) {
	eventBuf := make([]byte, 4096)
	poll := time.NewTicker(2 * time.Second)
	defer poll.Stop()

	for {
		select {
		case <-poll.C:
			fw.readEvents(eventBuf)
			continue
		case <-stopCh:
		}
		_ = unix.Close(fw.inotifyFd)
		return
	}
}
// readEvents drains the inotify descriptor into buf until a read fails or
// returns no data (EAGAIN on the non-blocking fd included), decoding each
// batch as a packed sequence of inotify_event headers followed by an optional
// NUL-padded name. handleFileChange is called for every event whose name is
// non-empty and does not start with '.'.
func (fw *ForwarderWatcher) readEvents(buf []byte) {
	const headerSize = unix.SizeofInotifyEvent
	for {
		n, err := unix.Read(fw.inotifyFd, buf)
		if err != nil || n <= 0 {
			return // EAGAIN or error - no more events
		}
		// Walk the batch; stop as soon as a full header no longer fits.
		for pos := 0; pos < n && pos+headerSize <= n; {
			// #nosec G103 -- inotify returns a packed binary stream;
			// reinterpretation is required and bounded by the header-size
			// check in the loop condition.
			event := (*unix.InotifyEvent)(unsafe.Pointer(&buf[pos]))
			nameLen := int(event.Len)
			nameStart := pos + headerSize
			nameEnd := nameStart + nameLen
			pos = nameEnd // always advance, even when the name is skipped

			if nameLen <= 0 || nameEnd > n {
				continue
			}
			// The kernel pads names with trailing NUL bytes.
			name := strings.TrimRight(string(buf[nameStart:nameEnd]), "\x00")
			if name == "" || strings.HasPrefix(name, ".") {
				continue
			}
			fw.handleFileChange(name)
		}
	}
}
// handleFileChange processes a change to the valiases file named domain under
// valiasesDir and queues any resulting findings on the alert channel.
//
// 2026-04-27: alerts are suppressed on the first observation of a valiases
// file. Account transfers via WHM rsync write the entire file fresh, and
// alerting on every pre-existing forwarder buries operators in noise. When
// the global store is available and the file is readable, its SHA-256 is
// recorded under "valiases:<domain>" and compared with the previous value:
//   - no previous hash: external destinations are baselined (not reported);
//   - same hash: nothing changed, return without parsing;
//   - different hash: external destinations are reported.
//
// Pipe and /dev/null destinations are always checked because they are
// dangerous even on a first observation. When the store is unavailable or the
// file cannot be read, external destinations are reported. Mirrors
// auditValiasFile's behaviour in internal/checks/forwarder.go.
//
// Sends are non-blocking; a finding that does not fit in the channel is
// dropped with a warning on stderr.
func (fw *ForwarderWatcher) handleFileChange(domain string) {
	path := filepath.Join(valiasesDir, domain)

	includeExternal := true
	if db := store.Global(); db != nil {
		// #nosec G304 -- path is filepath.Join(valiasesDir, domain) where the
		// inotify event already restricted domain to a single path component.
		if data, err := os.ReadFile(path); err == nil {
			hashKey := "valiases:" + domain
			sum := fmt.Sprintf("%x", sha256.Sum256(data))
			prev, seen := db.GetForwarderHash(hashKey)
			_ = db.SetForwarderHash(hashKey, sum)
			if seen && prev == sum {
				return
			}
			// First sight baselines; a changed hash reports externals.
			includeExternal = seen
		}
	}

	// Load local domains for external detection
	localDomains := loadLocalDomainsForWatcher()
	findings := parseValiasFileForFindingsFiltered(path, domain, localDomains, fw.knownForwarders, includeExternal)
	for i := range findings {
		finding := findings[i]
		finding.Timestamp = time.Now()
		finding.Details += "\n(detected in realtime via inotify)"
		select {
		case fw.alertCh <- finding:
		default:
			fmt.Fprintf(os.Stderr, "[%s] Warning: alert channel full, dropping forwarder finding for %s\n",
				time.Now().Format("2006-01-02 15:04:05"), domain)
		}
	}
}
// loadLocalDomainsForWatcher builds the set of local domains from
// /etc/localdomains and /etc/virtualdomains. Blank lines and '#' comments are
// ignored; when a line contains ':' after its first character only the part
// before it is kept; names are stored lower-cased. Unreadable files are
// skipped, so the result may be empty but is never nil.
// Separate from the checks package version to avoid import cycles.
func loadLocalDomainsForWatcher() map[string]bool {
	domains := map[string]bool{}
	sources := [...]string{"/etc/localdomains", "/etc/virtualdomains"}
	for _, src := range sources {
		// #nosec G304 -- src iterates a literal list of cPanel system files.
		raw, err := os.ReadFile(src)
		if err != nil {
			continue
		}
		for _, entry := range strings.Split(string(raw), "\n") {
			entry = strings.TrimSpace(entry)
			if entry == "" || strings.HasPrefix(entry, "#") {
				continue
			}
			if colon := strings.IndexByte(entry, ':'); colon > 0 {
				entry = strings.TrimSpace(entry[:colon])
			}
			domains[strings.ToLower(entry)] = true
		}
	}
	return domains
}
package daemon
import (
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/geoip"
)
// daemonGeoIPDB is a package-level atomic pointer to the GeoIP database. It
// is written through setGeoIPDB (which also accepts nil to clear it) and read
// through getGeoIPDB, e.g. by log watcher handlers for country filtering.
var daemonGeoIPDB atomic.Pointer[geoip.DB]
// setGeoIPDB publishes db as the daemon-wide GeoIP database and rewires the
// ASN resolver used by the bad_asn_outbound detector to look up ASN and
// organisation through it. Passing nil clears the database and removes the
// resolver, which disables ASN classification.
func setGeoIPDB(db *geoip.DB) {
	daemonGeoIPDB.Store(db)
	if db != nil {
		checks.SetASNLookup(func(ip string) (uint, string) {
			result := db.Lookup(ip)
			return result.ASN, result.ASOrg
		})
		return
	}
	checks.SetASNLookup(nil)
}
// getGeoIPDB loads the GeoIP database pointer last stored by setGeoIPDB.
// The result is nil when no database is set.
func getGeoIPDB() *geoip.DB {
	db := daemonGeoIPDB.Load()
	return db
}
// isTrustedCountry reports whether the country GeoIP resolves for ip appears
// in trustedCountries (compared case-insensitively).
// It returns false when the list is empty, when no GeoIP database is set, or
// when the lookup yields no country.
func isTrustedCountry(ip string, trustedCountries []string) bool {
	if len(trustedCountries) == 0 {
		return false
	}
	db := getGeoIPDB()
	if db == nil {
		return false
	}
	country := db.Lookup(ip).Country
	if country == "" {
		return false
	}
	for i := range trustedCountries {
		if strings.EqualFold(country, trustedCountries[i]) {
			return true
		}
	}
	return false
}
package daemon
import (
"fmt"
"net"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// parseAccessLogLineEnhanced inspects a single access_log line and returns
// zero or more findings. It covers:
//   - cPanel File Manager write actions (cpanel_file_upload_realtime),
//   - cPanel API authentication failures (api_auth_failure_realtime),
//   - webmail login POSTs (webmail_login_realtime),
//   - WHM login POSTs (whm_login_realtime),
//   - tokenless WHM /scripts requests (whm_unauth_scripts_realtime).
//
// wp-login brute force and xmlrpc abuse are not detected in this function.
//
// It returns nil when the line cannot be parsed by accessLogIPMethodPath or
// when the client IP is an infra IP or 127.0.0.1. cfg supplies the infra IP
// list and the suppression flags. One line can produce several findings.
func parseAccessLogLineEnhanced(line string, cfg *config.Config) []alert.Finding {
var findings []alert.Finding
ip, _, _, ok := accessLogIPMethodPath(line)
if !ok {
return nil
}
if isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
return nil
}
lineLower := strings.ToLower(line)
// File Manager write operations (port 2083)
// Only match actual write actions - not read-only calls like get_homedir.
// Skip 401/403 responses - the server rejected the request, no write occurred.
// Match against the request URI only (between first pair of quotes), not the
// full line which includes the referer URL that can contain "upload" in paths.
// NOTE(review): the port test is a plain substring match for "2083" anywhere
// in the line, not anchored to the vhost field the way isWHMLogVhost is.
if strings.Contains(line, "2083") && !strings.Contains(line, "\" 401 ") && !strings.Contains(line, "\" 403 ") {
requestURI := extractRequestURI(lineLower)
filemanWriteActions := []string{
"fileman/save_file", "fileman/upload_files",
"fileman/paste", "fileman/rename", "fileman/delete",
}
// At most one finding per line: stop at the first matching action.
for _, action := range filemanWriteActions {
if strings.Contains(requestURI, action) {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "cpanel_file_upload_realtime",
Message: fmt.Sprintf("cPanel File Manager write from non-infra IP: %s", ip),
Details: truncateDaemon(line, 300),
SourceIP: ip,
})
break
}
}
}
// API authentication failures (401/403)
// Suppress 401s that are stale-session artifacts from a recent password change.
// When a user changes their password, in-flight browser AJAX requests (notification
// polls, etc.) will 401 against the now-invalidated session - that's expected, not
// an attack. Real API abuse won't correlate with a recent purge for the same account.
// NOTE(review): the purge check is applied to 403 responses as well, and is
// keyed by client IP (purgeTracker.isPostPurge401(ip)).
if strings.Contains(line, "\" 401 ") || strings.Contains(line, "\" 403 ") {
if strings.Contains(lineLower, "json-api") || strings.Contains(lineLower, "/execute/") {
if !purgeTracker.isPostPurge401(ip) {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "api_auth_failure_realtime",
Message: fmt.Sprintf("cPanel API auth failure from %s", ip),
Details: truncateDaemon(line, 300),
SourceIP: ip,
})
}
}
}
// Webmail login attempts (port 2095/2096)
// NOTE(review): both the port numbers and "post" are matched as substrings
// of the whole line, so they can also match inside a referer or user-agent.
if !cfg.Suppressions.SuppressWebmail {
if strings.Contains(line, "2095") || strings.Contains(line, "2096") {
if strings.Contains(lineLower, "post") {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "webmail_login_realtime",
Message: fmt.Sprintf("Webmail login attempt from non-infra IP: %s", ip),
Details: truncateDaemon(line, 200),
SourceIP: ip,
})
}
}
}
// WHM login attempts (port 2086/2087). CVE-2026-41940 step 1 creates the
// preauth session via a POST to the WHM login endpoint; the CRLF
// injection lands when cpsrvd writes that session file. Surfacing every
// non-infra POST gives ops a brute-force/recon signal even on patched
// hosts. Suppressed under the cPanel-login suppression flag because WHM
// is the admin face of cPanel and shares the same noise profile.
isWHMPort := isWHMLogVhost(line)
if isWHMPort && !cfg.Suppressions.SuppressCpanelLogin {
if strings.Contains(lineLower, "post /login/?login_only=1") {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "whm_login_realtime",
Message: fmt.Sprintf("WHM login attempt from non-infra IP: %s", ip),
Details: truncateDaemon(line, 200),
SourceIP: ip,
})
}
}
// CVE-2026-41940 step 4 fingerprint: a tokenless request to a
// token-required WHM path triggers do_token_denied(), which the watchTowr
// PoC abuses to promote a CRLF-injected session record into the JSON
// cache. Legitimate WHM clients always prefix /scripts*/* with the
// /cpsessXXXXXX/ security token, so the bare prefix on a WHM port is a
// hard signature, not a heuristic. Matches both /scripts/ and /scripts2/
// because the do_token_denied() trigger is path-agnostic - the watchTowr
// PoC happens to use listaccts but any token-required endpoint works.
// Fires regardless of suppression - this is an attack IOC, not a login.
// The request line is taken from the original (not lower-cased) line here.
if isWHMPort && isUnauthWHMScriptsRequest(extractRequestURI(line)) {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "whm_unauth_scripts_realtime",
Message: fmt.Sprintf("Tokenless WHM scripts request from %s (CVE-2026-41940 IOC)", ip),
Details: truncateDaemon(line, 300),
SourceIP: ip,
})
}
return findings
}
// isWHMLogVhost reports whether the access-log line was served by WHM, i.e.
// whether its final double-quoted field ends in ":2087" (SSL) or ":2086"
// (plain). cPanel's combined log format ends every line with the served vhost
// as that final field (e.g. "host:2087"); anchoring on its suffix avoids
// matching a port-like substring inside a referer URL or user-agent.
func isWHMLogVhost(line string) bool {
	vhost := lastQuotedField(line)
	for _, portSuffix := range [...]string{":2087", ":2086"} {
		if strings.HasSuffix(vhost, portSuffix) {
			return true
		}
	}
	return false
}
// lastQuotedField returns the text between the last two double quotes on the
// line, or "" when the line does not contain such a pair. cPanel's log writer
// always emits the served vhost as that final quoted field.
func lastQuotedField(line string) string {
	closing := strings.LastIndex(line, "\"")
	if closing < 1 {
		// No quote at all, or the only quote is the first byte.
		return ""
	}
	head := line[:closing]
	opening := strings.LastIndex(head, "\"")
	if opening == -1 {
		return ""
	}
	return head[opening+1:]
}
// isUnauthWHMScriptsRequest reports whether the request line targets a path
// under /scripts/ or /scripts2/ without a /cpsessXXXXXX/ security-token
// prefix - the literal step-4 fingerprint of CVE-2026-41940.
//
// requestURI is the request line as logged ("METHOD /path HTTP/x.y"); the
// second space-separated field is taken as the target and its query string
// is stripped before comparison. A request line with fewer than two fields
// yields false.
//
// The security token is always the first path segment, so a tokenized
// request begins with "/cpsess" and can never begin with "/scripts". The
// prefix test alone therefore separates the two cases. "/cpsess" appearing
// later in the path is deliberately NOT treated as a token: the path is
// attacker-controlled, and honouring it anywhere would let a request such as
// "/scripts/listaccts/cpsess0" slip past this check.
func isUnauthWHMScriptsRequest(requestURI string) bool {
	parts := strings.SplitN(requestURI, " ", 3)
	if len(parts) < 2 {
		return false
	}
	path, _, _ := strings.Cut(parts[1], "?")
	return strings.HasPrefix(path, "/scripts/") || strings.HasPrefix(path, "/scripts2/")
}
// parseFTPLogLine turns a pure-ftpd entry from /var/log/messages into
// findings: a High ftp_auth_failure_realtime for a failed authentication and
// a High ftp_login_realtime for a successful login. Lines that do not mention
// pure-ftpd, that yield no client address, or whose client is an infra IP
// produce nil.
//
// Client address extraction: pure-ftpd's standard syslog format prefixes the
// client as (user@addr), where addr is either an IP (DontResolve=yes) or a
// reverse-resolved hostname (cPanel's default with DontResolve=no). The
// pure-ftpd prefix is tried first; when it yields nothing, the generic
// "whitespace field starting with a digit" scanner is used as a fallback.
// Reverse-DNS lookups are deliberately not done in the log hot path, and a
// hostname alone is not something an attacker can be held accountable by.
func parseFTPLogLine(line string, cfg *config.Config) []alert.Finding {
	if !strings.Contains(line, "pure-ftpd") {
		return nil
	}

	clientIP := extractPureFTPDClientIP(line)
	if clientIP == "" {
		clientIP = extractIPFromLogDaemon(line)
	}
	if clientIP == "" || isInfraIPDaemon(clientIP, cfg.InfraIPs) {
		return nil
	}

	var findings []alert.Finding

	// Failed authentication
	authFailed := strings.Contains(line, "Authentication failed") || strings.Contains(line, "auth failed")
	if authFailed {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "ftp_auth_failure_realtime",
			Message:  fmt.Sprintf("FTP authentication failed from %s", clientIP),
			Details:  truncateDaemon(line, 200),
			SourceIP: clientIP,
		})
	}

	// Successful login from non-infra
	loggedIn := strings.Contains(line, "is now logged in")
	if loggedIn {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "ftp_login_realtime",
			Message:  fmt.Sprintf("FTP login from non-infra IP: %s", clientIP),
			Details:  truncateDaemon(line, 200),
			SourceIP: clientIP,
		})
	}
	return findings
}
// extractPureFTPDClientIP parses the "(user@addr)" prefix that pure-ftpd
// prepends to every log message and returns addr only if it parses as an IP.
//
// The prefix is the text between the first '(' and the next ')'. addr is
// everything after the LAST '@' inside it: the user part may itself contain
// '@' (virtual FTP accounts are commonly named "user@domain"), whereas an IP
// or hostname never does, so splitting on the first '@' would lose the
// address for such logins.
//
// Returns "" if the line contains no prefix at all, the prefix is malformed
// (no closing ')' or no '@'), or addr is a reverse-resolved hostname (in
// which case the caller should not emit a finding since a hostname cannot be
// blocked at the firewall).
func extractPureFTPDClientIP(line string) string {
	_, afterOpen, found := strings.Cut(line, "(")
	if !found {
		return ""
	}
	inner, _, found := strings.Cut(afterOpen, ")")
	if !found {
		return ""
	}
	at := strings.LastIndexByte(inner, '@')
	if at < 0 {
		return ""
	}
	addr := inner[at+1:]
	if net.ParseIP(addr) == nil {
		return "" // hostname, not an IP — nothing we can block
	}
	return addr
}
// extractRequestURI returns the request line of an access log entry: the
// text between the first pair of double quotes.
// Format: ... "METHOD /path HTTP/1.1" ... → returns "METHOD /path HTTP/1.1"
// (the whole request line, not just the path). Returns "" when the line has
// no opening quote or the opening quote is never closed.
func extractRequestURI(line string) string {
	_, afterOpen, opened := strings.Cut(line, "\"")
	if !opened {
		return ""
	}
	requestLine, _, closed := strings.Cut(afterOpen, "\"")
	if !closed {
		return ""
	}
	return requestLine
}
// extractIPFromLogDaemon is a generic scanner that returns the first
// whitespace-separated token that looks like a dotted-quad address: at
// least 7 bytes long, starting with an ASCII digit, and containing exactly
// three dots. Trailing punctuation (",:;)([]") is stripped from the
// returned token. The token is not validated as an IP. Returns "" when no
// token qualifies.
func extractIPFromLogDaemon(line string) string {
	for _, tok := range strings.Fields(line) {
		if len(tok) < 7 {
			continue
		}
		if lead := tok[0]; lead < '0' || lead > '9' {
			continue
		}
		if strings.Count(tok, ".") != 3 {
			continue
		}
		return strings.TrimRight(tok, ",:;)([]")
	}
	return ""
}
package daemon
import (
"fmt"
"os"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
)
// Real-time access log handler for detecting wp-login.php brute force and
// xmlrpc.php abuse. Watches the LiteSpeed/Apache Combined Log Format access
// log and tracks per-IP POST counts using a sliding time window.
//
// Emits the same check names as the periodic CheckWPBruteForce
// (wp_login_bruteforce, xmlrpc_abuse) so the existing auto-block pipeline
// handles them automatically.
const (
// Sliding window for counting requests per IP.
accessLogWindow = 5 * time.Minute
// Thresholds within the window. Lower than periodic checks because
// we're watching in real-time and want fast response.
// accessLogWPLoginThreshold is also the threshold that
// parseAccessLogBruteForce applies to admin-panel login paths; there is
// no separate admin-panel constant.
accessLogWPLoginThreshold = 10
accessLogXMLRPCThreshold = 15
// Eviction: how often to prune expired trackers.
accessLogEvictInterval = 5 * time.Minute
// Cooldown after an IP is flagged: don't re-alert for this long.
// Prevents alert spam while auto-block processes the finding.
// The eviction pass measures it against the tracker's lastSeen, so the
// cooldown only elapses after this much silence from the IP.
accessLogBlockCooldown = 30 * time.Minute
)
// accessLogTracker tracks POST timestamps per endpoint for a single IP.
// In the code visible in this file every access to the fields below (after
// construction) happens with mu held.
type accessLogTracker struct {
mu sync.Mutex
// Timestamps of POSTs still inside the sliding window, one slice per
// watched endpoint family.
wpLoginTimes []time.Time
xmlrpcTimes []time.Time
adminPanelTimes []time.Time
// Set once the matching finding has been emitted; while set, further
// hits for that endpoint are neither counted nor re-alerted. Cleared by
// the eviction pass after the cooldown.
wpLoginAlerted bool
xmlrpcAlerted bool
adminPanelAlerted bool
// Time of the most recent matching POST from this IP.
lastSeen time.Time
// Incremented on every matching POST; lets the cap enforcer detect that
// a tracker changed after it was snapshotted as an eviction candidate.
generation uint64
// Set when the tracker is being removed from accessLogTrackers; a
// handler that finds it set discards the tracker and retries.
evicting bool
}
// accessLogTrackers holds per-IP state. sync.Map for concurrent handler access.
var accessLogTrackers sync.Map // key: IP string → value: *accessLogTracker
// accessLogTrackerCount approximates the live entry count in
// accessLogTrackers. sync.Map exposes no Len(); maintaining a side
// counter is the canonical workaround. Used to trigger eager
// eviction during a DDoS burst, where the 5-min timer alone would
// let the map grow into the hundreds of thousands of unique IPs
// before the next prune.
//
// accessLogEagerEvictTrip is the wake-up channel for the eviction
// goroutine. It is buffered with capacity 1 and written with a
// non-blocking send, so repeated signals coalesce into one pending
// eviction and the hot path never blocks on it.
var (
accessLogTrackerCount atomic.Int64
accessLogEagerEvictTrip = make(chan struct{}, 1)
)
// accessLogEvictSoftCap is the live-entry threshold above which the
// hot path nudges the eviction goroutine to run sooner than the
// 5-min ticker. Picked to be well below typical memory limits even
// at the worst per-tracker size (~256 bytes) so a 100k entry
// burst stays under 32 MB.
//
// accessLogEvictTargetPercent is the percentage of the cap that
// enforceAccessLogTrackerCap trims the map down to once the cap has been
// exceeded (target = cap * percent / 100), so one enforcement pass leaves
// some headroom instead of stopping exactly at the cap.
const (
accessLogEvictSoftCap int64 = 50000
accessLogEvictTargetPercent int64 = 95
)
// accessLogEvictionCandidate is a snapshot of a tracker that survived the
// time-based eviction pass. enforceAccessLogTrackerCap sorts candidates by
// lastSeen and evicts the oldest first when the map is over its cap.
type accessLogEvictionCandidate struct {
// key is the IP string the tracker is stored under in accessLogTrackers.
key string
tracker *accessLogTracker
// lastSeen and generation are copied while the tracker lock is held.
// They are re-compared under the lock before deletion so a tracker that
// received traffic after the snapshot is not evicted.
lastSeen time.Time
generation uint64
}
// discoverAccessLogPath walks the access log candidates reported by the
// platform detector, in order, and returns the first one that os.Stat can
// reach without error. Returns "" when none of them can be stat'ed.
func discoverAccessLogPath() string {
	candidates := platform.Detect().AccessLogPaths
	for i := range candidates {
		if _, statErr := os.Stat(candidates[i]); statErr != nil {
			continue
		}
		return candidates[i]
	}
	return ""
}
// parseAccessLogBruteForce is the LogLineHandler for the Apache/LiteSpeed
// Combined Log Format access log. It parses each line, tracks per-IP POST
// counts to wp-login.php, xmlrpc.php and the admin-panel login paths
// recognised by isAdminPanelPath, and emits findings when thresholds are
// crossed.
//
// Returns nil for lines that are not POSTs to a watched path, that cannot
// be parsed, or that come from loopback or infra IPs. Otherwise returns
// zero or more Critical findings: at most one per endpoint family, and only
// on the event that crosses the threshold (the per-family `alerted` flag
// suppresses repeats until the eviction pass clears it).
//
// Concurrency: the per-IP tracker is returned locked by
// loadAccessLogTracker and stays locked until this function returns.
func parseAccessLogBruteForce(line string, cfg *config.Config) []alert.Finding {
// Fast reject: only care about POST requests to known attack targets.
if !strings.Contains(line, "POST") {
return nil
}
ip, method, path, ok := accessLogIPMethodPath(line)
if !ok {
return nil
}
// Skip infra IPs and loopback.
if ip == "127.0.0.1" || ip == "::1" || isInfraIPDaemon(ip, cfg.InfraIPs) {
return nil
}
// The substring prefilter above also matches "POST" in a URL or user
// agent; the parsed method field is authoritative.
if method != "POST" {
return nil
}
isWPLogin := strings.Contains(path, "wp-login.php")
isXMLRPC := strings.Contains(path, "xmlrpc.php")
isAdminPanel := isAdminPanelPath(path)
if !isWPLogin && !isXMLRPC && !isAdminPanel {
return nil
}
now := time.Now()
// loadAccessLogTracker hands the tracker back with tracker.mu already
// held, hence an Unlock here without a matching Lock.
tracker := loadAccessLogTracker(ip, now)
defer tracker.mu.Unlock()
tracker.lastSeen = now
// Bumping generation invalidates any eviction candidate snapshot taken
// before this event.
tracker.generation++
cutoff := now.Add(-accessLogWindow)
var results []alert.Finding
// Once a per-tier alert has fired, skip pruneAndAppend until cooldown
// clears the `alerted` flag. The slice would otherwise grow on every
// event during a sustained burst (potentially tens of thousands of
// entries for a 5-min window at 100 rps), wasting CPU on prune passes
// whose result is never consumed: the `alerted` flag already prevents
// re-alerts, and the eviction loop trims the slice on its own schedule.
//
// Safety: evictAccessLogState resets `alerted` once `lastSeen` is older
// than `cooldownCutoff` (30 min of silence by default). By that point
// the same eviction call has also pruned the slice to empty (window is
// 5 min, so any remaining timestamp is far past cutoff), so the next
// matching event correctly starts a fresh count from 1.
if isWPLogin && !tracker.wpLoginAlerted {
tracker.wpLoginTimes = pruneAndAppend(tracker.wpLoginTimes, cutoff, now)
if len(tracker.wpLoginTimes) >= accessLogWPLoginThreshold {
tracker.wpLoginAlerted = true
results = append(results, alert.Finding{
Severity: alert.Critical,
Check: "wp_login_bruteforce",
Message: fmt.Sprintf("WordPress login brute force from %s: %d POSTs in %v (real-time)", ip, len(tracker.wpLoginTimes), accessLogWindow),
Details: "Real-time detection: high rate of POST requests to wp-login.php",
Timestamp: now,
SourceIP: ip,
})
}
}
if isXMLRPC && !tracker.xmlrpcAlerted {
tracker.xmlrpcTimes = pruneAndAppend(tracker.xmlrpcTimes, cutoff, now)
if len(tracker.xmlrpcTimes) >= accessLogXMLRPCThreshold {
tracker.xmlrpcAlerted = true
results = append(results, alert.Finding{
Severity: alert.Critical,
Check: "xmlrpc_abuse",
Message: fmt.Sprintf("XML-RPC abuse from %s: %d POSTs in %v (real-time)", ip, len(tracker.xmlrpcTimes), accessLogWindow),
Details: "Real-time detection: high rate of POST requests to xmlrpc.php (brute force or amplification)",
Timestamp: now,
SourceIP: ip,
})
}
}
// Admin-panel paths reuse the wp-login threshold.
if isAdminPanel && !tracker.adminPanelAlerted {
tracker.adminPanelTimes = pruneAndAppend(tracker.adminPanelTimes, cutoff, now)
if len(tracker.adminPanelTimes) >= accessLogWPLoginThreshold {
tracker.adminPanelAlerted = true
results = append(results, alert.Finding{
Severity: alert.Critical,
Check: "admin_panel_bruteforce",
Message: fmt.Sprintf("Admin panel brute force from %s: %d POSTs in %v (real-time)", ip, len(tracker.adminPanelTimes), accessLogWindow),
Details: "Real-time detection: high rate of POST requests to common admin panel login paths (phpMyAdmin / Joomla)",
Timestamp: now,
SourceIP: ip,
})
}
}
return results
}
// loadAccessLogTracker returns the tracker for ip, creating it if needed,
// with tracker.mu ALREADY LOCKED. The caller owns the lock and must unlock
// it.
//
// The lookup tries a plain Load first and only falls back to LoadOrStore
// when the key is missing. LoadOrStore needs a candidate value up front, so
// calling it unconditionally allocates a throwaway tracker on every
// matching log line, even for IPs that are already tracked.
//
// When a new entry is inserted the live-entry counter is incremented, and
// the eviction goroutine is signalled if the counter passes
// accessLogEvictSoftCap.
//
// If the tracker found is marked evicting, it is dropped from the map (the
// counter is decremented only if this call performed the removal) and the
// lookup is retried, so the caller never receives a tracker that is on its
// way out.
func loadAccessLogTracker(ip string, now time.Time) *accessLogTracker {
	for {
		val, loaded := accessLogTrackers.Load(ip)
		if !loaded {
			val, loaded = accessLogTrackers.LoadOrStore(ip, &accessLogTracker{lastSeen: now})
			if !loaded && accessLogTrackerCount.Add(1) > accessLogEvictSoftCap {
				signalAccessLogEagerEviction()
			}
		}
		tracker := val.(*accessLogTracker)
		tracker.mu.Lock()
		if !tracker.evicting {
			return tracker
		}
		tracker.mu.Unlock()
		// CompareAndDelete only removes the entry if it is still this exact
		// tracker, so a replacement stored concurrently is left alone.
		if accessLogTrackers.CompareAndDelete(ip, tracker) {
			decrementAccessLogTrackerCount()
		}
	}
}
// signalAccessLogEagerEviction asks the eviction goroutine to run as soon
// as possible. The send is non-blocking: if a signal is already pending in
// the one-slot channel, this call is a no-op, so the log-handling hot path
// never waits on the evictor.
func signalAccessLogEagerEviction() {
select {
case accessLogEagerEvictTrip <- struct{}{}:
default:
}
}
// pruneAndAppend drops every timestamp strictly before cutoff, then adds
// now at the end. Entries equal to cutoff are kept. The surviving entries
// are compacted in place, so the returned slice reuses (and overwrites) the
// backing array of times.
func pruneAndAppend(times []time.Time, cutoff, now time.Time) []time.Time {
	kept := 0
	for _, ts := range times {
		if ts.Before(cutoff) {
			continue
		}
		times[kept] = ts
		kept++
	}
	return append(times[:kept], now)
}
// StartAccessLogEviction starts a background goroutine that prunes expired
// tracker entries to prevent unbounded memory growth. Same pattern as
// StartModSecEviction.
//
// The worker runs an eviction pass on every accessLogEvictInterval tick and
// additionally whenever the hot path signals accessLogEagerEvictTrip. It
// exits when stopCh is closed (or receives a value).
func StartAccessLogEviction(stopCh <-chan struct{}) {
obs.Go("accesslog-eviction", func() {
ticker := time.NewTicker(accessLogEvictInterval)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case now := <-ticker.C:
// Use the tick's own timestamp as the reference time.
evictAccessLogState(now)
case <-accessLogEagerEvictTrip:
// Soft-cap signal from the hot path. Run an
// immediate eviction so a DDoS burst of unique
// IPs cannot grow the tracker map past memory
// budget before the next 5-min tick.
evictAccessLogState(time.Now())
}
}
})
}
// evictAccessLogState runs one eviction pass over accessLogTrackers using
// the production soft cap (accessLogEvictSoftCap). now is the reference
// time for the window and cooldown cutoffs.
func evictAccessLogState(now time.Time) {
evictAccessLogStateWithCap(now, accessLogEvictSoftCap)
}
// evictAccessLogStateWithCap performs one eviction pass with an explicit
// entry cap. For every tracker that is not already being evicted it:
//
//   - drops timestamps older than the sliding window,
//   - clears all three alerted flags once the IP has been silent for the
//     block cooldown, so it can be detected again later,
//   - deletes the tracker when it has no timestamps left and has been
//     silent for the cooldown; otherwise records it as a candidate.
//
// The surviving candidates are then handed to enforceAccessLogTrackerCap,
// which evicts the least recently seen ones if the map is still over cap.
func evictAccessLogStateWithCap(now time.Time, cap int64) {
	windowStart := now.Add(-accessLogWindow)
	quietSince := now.Add(-accessLogBlockCooldown)
	survivors := []accessLogEvictionCandidate{}

	accessLogTrackers.Range(func(k, v any) bool {
		addr, tr := k.(string), v.(*accessLogTracker)

		tr.mu.Lock()
		defer tr.mu.Unlock()
		if tr.evicting {
			return true
		}

		// Trim each per-endpoint slice to the current window.
		tr.wpLoginTimes = pruneSlice(tr.wpLoginTimes, windowStart)
		tr.xmlrpcTimes = pruneSlice(tr.xmlrpcTimes, windowStart)
		tr.adminPanelTimes = pruneSlice(tr.adminPanelTimes, windowStart)

		// After the cooldown of silence the IP becomes eligible for a fresh
		// alert on every endpoint family.
		cooledDown := tr.lastSeen.Before(quietSince)
		if cooledDown {
			tr.wpLoginAlerted = false
			tr.xmlrpcAlerted = false
			tr.adminPanelAlerted = false
		}

		drained := len(tr.wpLoginTimes) == 0 &&
			len(tr.xmlrpcTimes) == 0 &&
			len(tr.adminPanelTimes) == 0
		if drained && cooledDown {
			deleteAccessLogTrackerLocked(addr, tr)
			return true
		}

		survivors = append(survivors, accessLogEvictionCandidate{
			key:        addr,
			tracker:    tr,
			lastSeen:   tr.lastSeen,
			generation: tr.generation,
		})
		return true
	})

	enforceAccessLogTrackerCap(survivors, cap)
}
// enforceAccessLogTrackerCap evicts the least recently seen candidates
// until the live-entry counter drops to accessLogEvictTargetPercent of cap.
// It does nothing when cap is non-positive or the counter is already at or
// below cap.
//
// candidates is sorted in place, oldest lastSeen first. A candidate is only
// deleted if, under its lock, it is not already being evicted and its
// generation and lastSeen still match the snapshot, meaning it has seen no
// traffic since it was collected.
func enforceAccessLogTrackerCap(candidates []accessLogEvictionCandidate, cap int64) {
	if cap <= 0 {
		return
	}
	if accessLogTrackerCount.Load() <= cap {
		return
	}

	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].lastSeen.Before(candidates[b].lastSeen)
	})

	floor := cap * accessLogEvictTargetPercent / 100
	for idx := range candidates {
		if accessLogTrackerCount.Load() <= floor {
			break
		}
		snap := &candidates[idx]
		tr := snap.tracker

		tr.mu.Lock()
		untouched := !tr.evicting &&
			tr.generation == snap.generation &&
			tr.lastSeen.Equal(snap.lastSeen)
		if untouched {
			deleteAccessLogTrackerLocked(snap.key, tr)
		}
		tr.mu.Unlock()
	}
}
// deleteAccessLogTrackerLocked marks tracker as evicting and removes it
// from accessLogTrackers if the map still holds this exact tracker under
// key. The live-entry counter is decremented only when the removal
// happened. The caller must hold tracker.mu. Reports whether the entry was
// removed by this call.
func deleteAccessLogTrackerLocked(key string, tracker *accessLogTracker) bool {
	tracker.evicting = true
	removed := accessLogTrackers.CompareAndDelete(key, tracker)
	if removed {
		decrementAccessLogTrackerCount()
	}
	return removed
}
// decrementAccessLogTrackerCount lowers the live-entry counter by one
// without ever letting it go below zero. It retries the compare-and-swap
// until it succeeds or the counter is observed at zero or less.
func decrementAccessLogTrackerCount() {
	for seen := accessLogTrackerCount.Load(); seen > 0; seen = accessLogTrackerCount.Load() {
		if accessLogTrackerCount.CompareAndSwap(seen, seen-1) {
			return
		}
	}
}
// accessLogIPMethodPath pulls the client IP, request method and request
// path out of an Apache/LiteSpeed Combined Log Format line. It scans the
// first seven whitespace-separated fields by hand instead of calling
// strings.Fields, so it builds no slice; the returned strings are
// sub-slices of line.
//
// Field 0 is returned as the IP, field 5 (with any leading and trailing
// double quotes removed) as the method, and field 6 as the path. ok is
// false, and the other results empty, when the line has fewer than seven
// fields.
func accessLogIPMethodPath(line string) (ip, method, path string, ok bool) {
	const wanted = 7
	var tok [wanted]string

	pos, end := 0, len(line)
	for k := range tok {
		// Skip the run of separators in front of the field.
		for pos < end && isAccessLogSpace(line[pos]) {
			pos++
		}
		if pos >= end {
			return "", "", "", false
		}
		begin := pos
		for pos < end && !isAccessLogSpace(line[pos]) {
			pos++
		}
		tok[k] = line[begin:pos]
	}

	// The method field normally arrives as `"POST`; strip quotes on both
	// ends.
	m := tok[5]
	lo, hi := 0, len(m)
	for lo < hi && m[lo] == '"' {
		lo++
	}
	for hi > lo && m[hi-1] == '"' {
		hi--
	}
	return tok[0], m[lo:hi], tok[6], true
}

// isAccessLogSpace reports whether b is one of the ASCII whitespace bytes
// treated as a field separator: space, tab, newline, vertical tab, form
// feed or carriage return.
func isAccessLogSpace(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\v', '\f', '\r':
		return true
	default:
		return false
	}
}
// isAdminPanelPath reports whether path contains one of the
// high-confidence non-WordPress admin login paths that are suitable for
// hard-block auto-response. The match is a case-sensitive substring test.
//
// Drupal /user/login, Tomcat /manager/html, generic /admin/login.php and
// /mysql/ are intentionally EXCLUDED: they are either too generic (false
// positive risk on shared hosting) or use a different attack shape (Basic
// auth vs. POST forms). See spec Component 5 for the full rationale.
func isAdminPanelPath(path string) bool {
	loginPaths := [...]string{
		"/phpmyadmin/index.php",
		"/pma/index.php",
		"/phpMyAdmin/index.php",
		"/administrator/index.php",
	}
	for _, needle := range loginPaths {
		if strings.Contains(path, needle) {
			return true
		}
	}
	return false
}
// pruneSlice drops every timestamp strictly before cutoff and returns the
// rest in their original order. Entries equal to cutoff are kept. The
// survivors are compacted in place, so the result shares (and overwrites)
// the backing array of times.
func pruneSlice(times []time.Time, cutoff time.Time) []time.Time {
	kept := 0
	for _, ts := range times {
		if ts.Before(cutoff) {
			continue
		}
		times[kept] = ts
		kept++
	}
	return times[:kept]
}
package daemon
import (
"fmt"
"net"
"os"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/store"
)
// Tunables for the new-country login detection in parseDovecotLogLine.
// geoHistoryMaxAge is expressed in seconds because it is compared against
// Unix timestamps; geoAlertCooldownH is a plain number of hours that is
// converted to a time.Duration at the point of use.
const (
geoHistoryMaxAge = 30 * 24 * 60 * 60 // 30 days in seconds
geoMinLoginCount = 5 // minimum logins before alerting on new country
geoAlertCooldownH = 24 // hours between alerts per account
)
// parseDovecotLogLine handles Dovecot login lines from /var/log/maillog.
// It tracks per-mailbox login countries and alerts on new-country logins.
//
// Returns nil unless ALL of the following hold: the line is a successful
// Dovecot login with a parseable user and remote IP; the IP is public and
// not an infra IP; a GeoIP database and the bbolt store are available; the
// IP resolves to a country that is not in the trusted list; the country is
// not in the mailbox's (pruned) history; the mailbox has at least
// geoMinLoginCount recorded logins; and no alert was recorded for the
// mailbox within the last geoAlertCooldownH hours.
//
// Side effects: once a login passes the trusted-country filter and the
// store is available, its geo history (login count and country timestamp)
// is written back to the store whether or not a finding is returned, and
// the alert timestamp is recorded when a finding is emitted.
func parseDovecotLogLine(line string, cfg *config.Config) []alert.Finding {
// Only process successful login lines
if !strings.Contains(line, "Login: user=<") {
return nil
}
if !strings.Contains(line, "dovecot:") {
return nil
}
user, ip := parseDovecotLoginFields(line)
if user == "" || ip == "" {
return nil
}
// Skip private/loopback IPs
if isPrivateOrLoopback(ip) {
return nil
}
// Skip infra IPs
if isInfraIPDaemon(ip, cfg.InfraIPs) {
return nil
}
// GeoIP lookup
db := getGeoIPDB()
if db == nil {
return nil
}
info := db.Lookup(ip)
country := info.Country
if country == "" {
return nil
}
// Skip trusted countries
for _, tc := range cfg.Suppressions.TrustedCountries {
if strings.EqualFold(country, tc) {
return nil
}
}
// Load history from bbolt
boltDB := store.Global()
if boltDB == nil {
return nil
}
now := time.Now().Unix()
// The read error is deliberately ignored and processing continues with
// whatever value was returned.
// NOTE(review): presumably GetGeoHistory returns a zero-value history on
// error or missing key - confirm against the store package.
history, _ := boltDB.GetGeoHistory(user)
if history.Countries == nil {
history.Countries = make(map[string]int64)
}
// Prune old country entries
history.Countries = pruneOldCountries(history.Countries, now, geoHistoryMaxAge)
// Increment login count
history.LoginCount++
// Check if this is a new country
// The login-count floor keeps a mailbox with little history from
// alerting on its first few countries.
_, countryKnown := history.Countries[country]
isNewCountry := !countryKnown && history.LoginCount >= geoMinLoginCount
// Update country timestamp
history.Countries[country] = now
// Persist updated history
// A failed write is logged to stderr but does not stop the alert path.
if err := boltDB.SetGeoHistory(user, history); err != nil {
fmt.Fprintf(os.Stderr, "[%s] Warning: failed to save geo history for %s: %v\n",
time.Now().Format("2006-01-02 15:04:05"), user, err)
}
if !isNewCountry {
return nil
}
// Rate limit: max 1 alert per account per 24h
alertKey := "email:geo_alert:" + user
lastAlertStr := boltDB.GetMetaString(alertKey)
if lastAlertStr != "" {
// An unparseable stored timestamp is treated as "no previous alert".
if lastAlert, err := time.Parse(time.RFC3339, lastAlertStr); err == nil {
if time.Since(lastAlert) < time.Duration(geoAlertCooldownH)*time.Hour {
return nil
}
}
}
// Record alert time
// Best effort: a failed write only means the cooldown is not recorded.
_ = boltDB.SetMetaString(alertKey, time.Now().Format(time.RFC3339))
// Build "previously seen" country list
// Map iteration order is unspecified in Go, so the order of countries in
// the resulting list can differ between otherwise identical alerts.
var knownCountries []string
for c := range history.Countries {
if c != country {
knownCountries = append(knownCountries, c)
}
}
previousList := "none"
if len(knownCountries) > 0 {
previousList = strings.Join(knownCountries, ", ")
}
// Fall back to the country code when the GeoIP record has no name.
countryName := info.CountryName
if countryName == "" {
countryName = country
}
mailbox, domain, tenant := splitMailAccount(user)
return []alert.Finding{{
Severity: alert.High,
Check: "email_suspicious_geo",
Message: fmt.Sprintf("Suspicious email login for %s from %s (%s) - previously seen: %s",
user, countryName, ip, previousList),
Details: fmt.Sprintf("Country: %s (%s)\nIP: %s\nLogin count: %d\nPreviously seen countries: %s",
country, countryName, ip, history.LoginCount, previousList),
SourceIP: ip,
Mailbox: mailbox,
Domain: domain,
TenantID: tenant,
}}
}
// parseDovecotLoginFields pulls the mailbox user and the remote IP out of
// a Dovecot login line of the form
// "... Login: user=<user@domain>, ... rip=1.2.3.4, ...".
//
// user is the text between the first "user=<" and the next '>'. ip is the
// text after the first "rip=" in the line, up to the next comma, space,
// tab or newline (or the end of the line). Both results are "" if either
// piece is missing or empty.
func parseDovecotLoginFields(line string) (user, ip string) {
	_, afterUser, hasUser := strings.Cut(line, "user=<")
	if !hasUser {
		return "", ""
	}
	name, _, closed := strings.Cut(afterUser, ">")
	if !closed {
		return "", ""
	}

	// rip= is searched for in the whole line, not only after the user.
	_, afterRip, hasRip := strings.Cut(line, "rip=")
	if !hasRip {
		return "", ""
	}
	addr := afterRip
	if stop := strings.IndexAny(afterRip, ", \t\n"); stop >= 0 {
		addr = afterRip[:stop]
	}

	if name == "" || addr == "" {
		return "", ""
	}
	return name, addr
}
// isPrivateOrLoopback reports whether ipStr should be treated as a
// non-public address: a loopback address, an RFC 1918 private IPv4
// address, or an RFC 4193 unique-local IPv6 address (fc00::/7).
//
// A string that does not parse as an IP also returns true, so callers that
// use this as a "skip" filter ignore malformed input instead of acting on
// it.
//
// The private-range test is delegated to net.IP.IsPrivate rather than
// rebuilding three net.IPNet values on every call; that also covers the
// IPv6 unique-local range, which has no public geolocation either.
func isPrivateOrLoopback(ipStr string) bool {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return true // invalid = skip
	}
	return ip.IsLoopback() || ip.IsPrivate()
}
// pruneOldCountries returns a new map holding only the entries of
// countries whose timestamp is at least now-maxAge (all values in
// seconds). An entry exactly maxAge old is kept. The input map is not
// modified.
func pruneOldCountries(countries map[string]int64, now, maxAge int64) map[string]int64 {
	oldest := now - maxAge
	kept := make(map[string]int64, len(countries))
	for code, seenAt := range countries {
		if seenAt < oldest {
			continue
		}
		kept[code] = seenAt
	}
	return kept
}
package daemon
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// modsecIPCounter tracks deny timestamps for a single IP.
// times and escalated are accessed with mu held by both recordModSecDeny
// and evictModSecState.
type modsecIPCounter struct {
mu sync.Mutex
// times holds the deny timestamps still inside the escalation window.
times []time.Time
escalated bool // set once escalation fires - prevents repeated findings per window
}
// Shared realtime ModSecurity state. modsecDedup suppresses repeated base
// findings for the same IP/rule pair; modsecBlockCount feeds the escalation
// threshold. Both are pruned by evictModSecState.
var (
modsecDedup sync.Map // key: "IP:ruleID" → value: time.Time
modsecBlockCount sync.Map // key: IP → value: *modsecIPCounter
)
// modsecDedupTTL is how long a base finding for one IP/rule pair is
// suppressed after it was last emitted. modsecEvictInterval is the period
// of the background prune. The two modsecDefaultEscalation* values are the
// fallbacks used by modsecEscalationParams when the config does not set a
// positive value.
const (
modsecDedupTTL = 60 * time.Second
modsecEvictInterval = 10 * time.Minute
modsecDefaultEscalationWin = 10 * time.Minute
modsecDefaultEscalationHits = 3
)
// modsecEscalationParams resolves the (hits, window) pair used for
// ModSecurity escalation. Each value comes from cfg.Thresholds when it is
// set to a positive number and from the shipped default otherwise; the two
// are resolved independently. A nil cfg yields both defaults, so test
// wiring without a config still behaves predictably.
func modsecEscalationParams(cfg *config.Config) (int, time.Duration) {
	if cfg == nil {
		return modsecDefaultEscalationHits, modsecDefaultEscalationWin
	}

	hits := modsecDefaultEscalationHits
	if configured := cfg.Thresholds.ModSecEscalationHits; configured > 0 {
		hits = configured
	}

	win := modsecDefaultEscalationWin
	if minutes := cfg.Thresholds.ModSecEscalationWindowMin; minutes > 0 {
		win = time.Duration(minutes) * time.Minute
	}

	return hits, win
}
// parseModSecLogLine turns one ModSecurity line from an Apache or
// LiteSpeed error log into at most one finding.
//
// Returns nil when the line carries neither the "ModSecurity:" nor the
// "[MODSEC]" marker, or when the client IP is loopback or an infra IP.
// Otherwise returns a single finding whose Check is
// "modsec_block_realtime" or "modsec_warning_realtime".
func parseModSecLogLine(line string, cfg *config.Config) []alert.Finding {
	// Fast reject: not a ModSecurity line.
	fromLiteSpeed := strings.Contains(line, "[MODSEC]")
	if !fromLiteSpeed && !strings.Contains(line, "ModSecurity:") {
		return nil
	}

	// Only the client address is formatted differently by the two servers.
	var ip string
	if fromLiteSpeed {
		ip = extractLiteSpeedIP(line)
	} else {
		// Apache format: [client IP] or [client IP:port]. Strip the port
		// only when there is exactly one colon, so an IPv6 literal is left
		// untouched.
		ip = extractModSecField(line, "[client ", "]")
		if strings.Count(ip, ":") == 1 {
			if cut := strings.LastIndex(ip, ":"); cut > 0 {
				ip = ip[:cut]
			}
		}
	}

	// The bracketed key/value fields are identical in both formats.
	ruleID := extractModSecField(line, `[id "`, `"]`)
	msg := extractModSecField(line, `[msg "`, `"]`)
	hostname := extractModSecField(line, `[hostname "`, `"]`)
	uri := extractModSecField(line, `[uri "`, `"]`)

	// Skip infra and loopback IPs - consistent with other realtime handlers
	// (handlers.go:22, autoblock.go:149). Prevents noisy findings and
	// false escalation from proxied or locally forwarded Apache traffic.
	if ip != "" {
		if isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" || ip == "::1" {
			return nil
		}
	}

	// Determine check name.
	//
	// Apache mod_security writes the action verbatim into the message, so
	// "Access denied" is a reliable block signal. LiteSpeed's mod_security
	// front-end writes every match as "triggered!" with no action context,
	// regardless of whether the rule's declared action denied the request
	// or merely incremented a counter. Without further context every match
	// would be counted as a deny, escalating to a 24-hour auto-block after
	// three pass-action triggers from the same IP. The rule-action registry
	// built at daemon start is consulted instead: pass/log/allow rules
	// produce a warning, deny/drop/block produce a block, and an unknown
	// rule ID keeps the legacy default-to-block behaviour for safety.
	check := "modsec_warning_realtime"
	switch {
	case strings.Contains(line, "Access denied"):
		check = "modsec_block_realtime"
	case fromLiteSpeed && strings.Contains(line, "triggered!"):
		check = classifyLiteSpeedTrigger(ruleID)
	}

	// Determine severity from rule ID.
	// Individual blocks are informational - ModSecurity already denied the
	// request. Only the escalation finding (3+ from same IP) is CRITICAL
	// because it triggers auto-block at the firewall level.
	severity := alert.Warning
	if ruleNum, err := strconv.Atoi(ruleID); err == nil {
		csmCustom := ruleNum >= 900000 && ruleNum <= 900999 // CSM custom rules
		owaspCRS := ruleNum >= 910000                       // OWASP CRS
		if csmCustom || owaspCRS {
			severity = alert.High
		}
	}

	// Build message; optional parts are appended only when present.
	message := "ModSecurity rule " + ruleID
	if check == "modsec_block_realtime" {
		message = "ModSecurity blocked request: rule " + ruleID
	}
	if ip != "" {
		message += " from " + ip
	}
	if hostname != "" {
		message += " on " + hostname
	}
	if uri != "" {
		message += " uri=" + uri
	}
	if msg != "" {
		message += " - " + msg
	}

	// Store structured details so the web UI can extract fields consistently
	// regardless of whether the source was Apache or LiteSpeed format.
	details := fmt.Sprintf("Rule: %s\nMessage: %s\nHostname: %s\nURI: %s\nRaw: %s",
		ruleID, msg, hostname, uri, truncateDaemon(line, 300))

	finding := alert.Finding{
		Severity: severity,
		Check:    check,
		Message:  message,
		Details:  details,
		SourceIP: ip,
		Domain:   domainOrEmpty(hostname),
	}
	return []alert.Finding{finding}
}
// domainOrEmpty passes hostname through unchanged unless it is empty or is
// a bare IP literal (v4 or v6, optionally wrapped in square brackets), in
// which case it returns "". Vhosts served on a raw IP would otherwise key
// the incident bucket on the IP literal, merging unrelated victim sites
// that happen to be reachable over their public IPs into a single bucket.
func domainOrEmpty(hostname string) string {
	bare := strings.TrimSuffix(strings.TrimPrefix(hostname, "["), "]")
	if hostname == "" || net.ParseIP(bare) != nil {
		return ""
	}
	return hostname
}
// classifyLiteSpeedTrigger maps a LiteSpeed mod_security "triggered!" line
// to a check name using the rule's declared action from the rule-action
// registry. It returns "modsec_warning_realtime" only when the rule is
// known to the registry and its action is not a blocking one. Every other
// case (non-numeric rule ID, no registry, unknown rule, blocking action)
// returns "modsec_block_realtime", which preserves coverage when the
// registry has not been populated yet.
func classifyLiteSpeedTrigger(ruleID string) string {
	const (
		asBlock   = "modsec_block_realtime"
		asWarning = "modsec_warning_realtime"
	)

	num, err := strconv.Atoi(ruleID)
	if err != nil {
		return asBlock
	}
	reg := modsec.Global()
	if reg == nil {
		return asBlock
	}
	if action, known := reg.Action(num); known && !modsec.IsBlockingAction(action) {
		return asWarning
	}
	return asBlock
}
// extractModSecField returns the text that sits between the first
// occurrence of start in line and the first occurrence of end after it.
// Returns "" when either delimiter is missing.
func extractModSecField(line, start, end string) string {
	_, tail, opened := strings.Cut(line, start)
	if !opened {
		return ""
	}
	value, _, closed := strings.Cut(tail, end)
	if !closed {
		return ""
	}
	return value
}
// extractLiteSpeedIP returns the client IP from a LiteSpeed log line.
// The connection field has the form [IP:PORT-CONN#VHOST], e.g.
// [122.9.114.57:41920-13#APVH_*_server.example.com].
//
// Bracketed fields are examined left to right. A field qualifies when it
// contains '#' , ':' and '-', and the text before its first ':' is
// non-empty and contains exactly three dots; that text is returned without
// further validation. Returns "" when no field qualifies or a '[' is never
// closed.
func extractLiteSpeedIP(line string) string {
	rest := line
	for {
		lb := strings.IndexByte(rest, '[')
		if lb < 0 {
			return ""
		}
		rb := strings.IndexByte(rest[lb:], ']')
		if rb < 0 {
			return ""
		}
		field := rest[lb+1 : lb+rb]
		// Continue after this field's closing bracket on the next round.
		rest = rest[lb+rb+1:]

		if !strings.Contains(field, "#") || !strings.Contains(field, "-") {
			continue
		}
		host, _, hasColon := strings.Cut(field, ":")
		if !hasColon || host == "" {
			continue
		}
		if strings.Count(host, ".") == 3 {
			return host
		}
	}
}
// parseModSecLogLineDeduped wraps parseModSecLogLine with dedup and block
// threshold escalation. It is the handler registered with the log watcher.
//
// Order of operations (critical for correctness):
// 1. Parse the raw line.
// 2. ALWAYS increment the block escalation counter (even if dedup will suppress).
// 3. Then check dedup - suppress the base finding if a duplicate, but still
// return any escalation finding from step 2.
//
// Returns nil when the line yields no base finding, or when the base
// finding is a duplicate within modsecDedupTTL and no escalation fired.
// Otherwise the result holds the escalation finding (if any) followed by
// the base finding (unless suppressed).
func parseModSecLogLineDeduped(line string, cfg *config.Config) []alert.Finding {
// --- Step 1: parse ---
raw := parseModSecLogLine(line, cfg)
if len(raw) == 0 {
return nil
}
f := raw[0]
now := time.Now()
var results []alert.Finding
// --- Step 2: block escalation (before dedup) ---
// Extract IP and rule ID directly from the raw log line - NOT from the
// finding message, which could be manipulated via log injection.
ip := extractModSecField(line, "[client ", "]")
if ip == "" {
ip = extractLiteSpeedIP(line)
}
// Strip port from Apache 2.4 format (IP:port)
// Only when there is exactly one colon, so IPv6 literals are left alone.
if strings.Count(ip, ":") == 1 {
if idx := strings.LastIndex(ip, ":"); idx > 0 {
ip = ip[:idx]
}
}
ruleID := extractModSecField(line, `[id "`, `"]`)
// A non-numeric or missing rule ID leaves ruleNum at 0.
ruleNum, _ := strconv.Atoi(ruleID)
isBlock := f.Check == "modsec_block_realtime"
isCSM := isBlock && ruleNum >= 900000 && ruleNum <= 900999
// Record hit for per-rule stats (24h hourly buckets)
// Counted for every CSM-range rule match, whether or not it was a block.
if ruleNum >= 900000 && ruleNum <= 900999 {
if sdb := store.Global(); sdb != nil {
sdb.IncrModSecRuleHit(ruleNum, now)
}
}
// Check if this rule is excluded from auto-block escalation.
// Configurable via the Rules page in the web UI.
noEscalate := false
if db := store.Global(); db != nil {
noEscalate = db.GetModSecNoEscalateRules()[ruleNum]
}
if isBlock && ip != "" && ruleID != "" && !noEscalate {
hits, win := modsecEscalationParams(cfg)
// recordModSecDeny returns true only on the event that first reaches
// the threshold, so at most one escalation fires per window.
if recordModSecDeny(ip, now, hits, win) {
check := "modsec_block_escalation"
label := "ModSecurity"
if isCSM {
check = "modsec_csm_block_escalation"
label = "CSM rule"
}
results = append(results, alert.Finding{
Severity: alert.Critical,
Check: check,
Message: fmt.Sprintf("%s escalation: %d+ denies from %s within %v", label, hits, ip, win),
Details: truncateDaemon(line, 400),
SourceIP: ip,
})
}
}
// --- Step 3: Dedup ---
dedupKey := ip + ":" + ruleID
if prev, loaded := modsecDedup.Load(dedupKey); loaded {
if now.Sub(prev.(time.Time)) < modsecDedupTTL {
// Suppress the base finding but still return any escalation.
if len(results) > 0 {
return results
}
return nil
}
}
// First sighting, or the previous one has expired: remember this one and
// emit the base finding after any escalation finding.
modsecDedup.Store(dedupKey, now)
results = append(results, f)
return results
}
// recordModSecDeny registers one deny event for ip at time now and reports
// whether this event is the one that reaches the escalation threshold:
// at least hits denies inside window, with no escalation already recorded
// for the counter. Once it has returned true it keeps returning false for
// that counter until the escalated flag is cleared elsewhere.
//
// hits and window are operator-configurable knobs; no defaulting happens
// here. Callers obtain defaults via modsecEscalationParams so every code
// path uses the same fallback rule.
func recordModSecDeny(ip string, now time.Time, hits int, window time.Duration) bool {
	entry, _ := modsecBlockCount.LoadOrStore(ip, &modsecIPCounter{})
	counter := entry.(*modsecIPCounter)

	counter.mu.Lock()
	defer counter.mu.Unlock()

	// Compact away timestamps that fell out of the escalation window, then
	// record the current event.
	oldest := now.Add(-window)
	kept := 0
	for _, ts := range counter.times {
		if ts.Before(oldest) {
			continue
		}
		counter.times[kept] = ts
		kept++
	}
	counter.times = append(counter.times[:kept], now)

	if counter.escalated || len(counter.times) < hits {
		return false
	}
	counter.escalated = true
	return true
}
// StartModSecEviction starts a background goroutine that prunes expired dedup
// and counter entries every modsecEvictInterval to prevent unbounded memory
// growth. It returns when stopCh is closed. cfgFn supplies the live
// thresholds at each tick so SIGHUP edits to the escalation window take
// effect without restarting the evictor.
//
// A nil cfgFn is replaced by one that returns a nil config, which makes
// every tick use the default escalation parameters.
func StartModSecEviction(stopCh <-chan struct{}, cfgFn func() *config.Config) {
if cfgFn == nil {
cfgFn = func() *config.Config { return nil }
}
obs.Go("modsec-eviction", func() {
ticker := time.NewTicker(modsecEvictInterval)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case now := <-ticker.C:
// Re-read the thresholds on every tick instead of capturing them once.
hits, win := modsecEscalationParams(cfgFn())
evictModSecState(now, hits, win)
}
}
})
}
// discoverModSecLogPath picks the web server error log that CSM should
// tail for ModSecurity denies. A non-empty cfg.ModSecErrorLog is returned
// as is, without checking that it exists. Otherwise the result is the
// first error log candidate from platform detection that exists on disk,
// or "" if none does.
func discoverModSecLogPath(cfg *config.Config) string {
	if override := cfg.ModSecErrorLog; override != "" {
		return override
	}
	detected := platform.Detect()
	return firstExistingPath(detected.ErrorLogPaths)
}
// firstExistingPath scans candidates in order and returns the first one
// that os.Stat can reach without error, or "" if there is none. It has no
// other inputs, so tests can exercise it directly.
func firstExistingPath(candidates []string) string {
	for i := range candidates {
		if _, statErr := os.Stat(candidates[i]); statErr != nil {
			continue
		}
		return candidates[i]
	}
	return ""
}
// evictModSecState prunes expired entries from modsecDedup and modsecBlockCount.
// hits and window mirror the live thresholds so the cooldown reset matches
// what recordModSecDeny would compute on the next event.
//
// now is the reference time: dedup entries at least modsecDedupTTL old are
// removed, counter timestamps older than now-window are dropped, and a
// counter left with no timestamps is removed from the map.
func evictModSecState(now time.Time, hits int, window time.Duration) {
// Prune dedup entries older than modsecDedupTTL.
modsecDedup.Range(func(key, value any) bool {
if now.Sub(value.(time.Time)) >= modsecDedupTTL {
modsecDedup.Delete(key)
}
return true
})
// Prune counter entries.
cutoff := now.Add(-window)
modsecBlockCount.Range(func(key, value any) bool {
ctr := value.(*modsecIPCounter)
ctr.mu.Lock()
// Filter in place, reusing the counter's backing array.
recent := ctr.times[:0]
for _, t := range ctr.times {
if !t.Before(cutoff) {
recent = append(recent, t)
}
}
ctr.times = recent
empty := len(recent) == 0
if len(recent) < hits {
ctr.escalated = false // reset cooldown when counter drops below threshold
}
ctr.mu.Unlock()
// NOTE(review): the entry is deleted after the counter lock is released.
// A recordModSecDeny call that obtained this counter before the Delete
// could append to a counter that is no longer in the map, losing that
// hit - confirm whether this window is acceptable.
if empty {
modsecBlockCount.Delete(key)
}
return true
})
}
package daemon
import (
"context"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/firewall/rollback"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/integrity"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/updatecheck"
)
// Hostname implements health.Provider. It reports the Hostname field of the
// current config, or "" when no config is available.
func (d *Daemon) Hostname() string {
	if cfg := d.currentCfg(); cfg != nil {
		return cfg.Hostname
	}
	return ""
}
// StartedAt implements health.Provider. It returns the daemon's startTime
// field unchanged.
func (d *Daemon) StartedAt() time.Time {
return d.startTime
}
// LatestScan implements health.Provider. It returns the store's latest scan
// time, or the zero time.Time when the daemon has no store attached.
func (d *Daemon) LatestScan() time.Time {
	var latest time.Time
	if d.store != nil {
		latest = d.store.LatestScanTime()
	}
	return latest
}
// BaselineAt implements health.Provider. It returns the value of the store's
// BaselineAt(), or the zero time.Time when the daemon has no store attached.
// Per the original note, that value is the first-start timestamp recorded by
// EnsureBaseline on the daemon's first successful boot against this state
// directory, and reinstalls and upgrades preserve it.
func (d *Daemon) BaselineAt() time.Time {
	var baseline time.Time
	if d.store != nil {
		baseline = d.store.BaselineAt()
	}
	return baseline
}
// StoreHealthy implements health.Provider. It is true only when a global
// store exists and that store reports itself healthy.
func (d *Daemon) StoreHealthy() bool {
	db := store.Global()
	return db != nil && db.IsHealthy()
}
// StoreSizeMB implements health.Provider. It converts the global store's
// SizeBytes() to units of 1024*1024 bytes, and returns 0 when there is no
// global store.
func (d *Daemon) StoreSizeMB() float64 {
	const bytesPerMiB = 1024 * 1024

	db := store.Global()
	if db == nil {
		return 0
	}
	return float64(db.SizeBytes()) / bytesPerMiB
}
// SeverityCounts implements health.Provider. It tallies the store's latest
// findings into the keys "critical", "high" and "warning". All three keys are
// always present; without a store every count is zero. Any severity other
// than alert.Critical or alert.High is counted under "warning".
func (d *Daemon) SeverityCounts() map[string]int {
	var critical, high, warning int
	if d.store != nil {
		for _, finding := range d.store.LatestFindings() {
			switch finding.Severity {
			case alert.Critical:
				critical++
			case alert.High:
				high++
			default:
				warning++
			}
		}
	}
	return map[string]int{
		"critical": critical,
		"high":     high,
		"warning":  warning,
	}
}
// BlocklistSize implements health.Provider. It returns the firewall engine's
// BlockedCount(), or 0 when no engine is attached.
//
// Background (from the original note): the firewall engine is treated as the
// authoritative source because production code never writes the parallel
// bbolt `fw:blocked` bucket that an earlier implementation read, which made
// /api/v1/status report a stale count. Engine.BlockedCount() is described
// there as reading the same state file as Status() and `csm firewall status`,
// with expired entries pruned.
func (d *Daemon) BlocklistSize() int {
	if engine := d.fwEngine; engine != nil {
		return engine.BlockedCount()
	}
	return 0
}
// IncidentsOpen implements health.Provider. It returns the correlator's
// OpenCount() (open + contained incidents, per the original note), or 0 while
// the package-level correlator has not been constructed.
func (d *Daemon) IncidentsOpen() int {
	correlator := incidentCorrelator
	if correlator == nil {
		return 0
	}
	return correlator.OpenCount()
}
// BPFEnforcementActive implements health.Provider (Phase 4 of the BPF
// Incident Response Roadmap). It is true only when all of the following hold:
// a config is available, BPFEnforcement.Enabled and
// BPFEnforcement.DirectSMTPEgress are both set in it, and the active backend
// kind for "connection_tracker" is bpf.BackendBPF. The backend is consulted
// only after the config conditions pass.
func (d *Daemon) BPFEnforcementActive() bool {
	cfg := d.currentCfg()
	switch {
	case cfg == nil:
		return false
	case !cfg.BPFEnforcement.Enabled:
		return false
	case !cfg.BPFEnforcement.DirectSMTPEgress:
		return false
	}
	return bpf.ActiveKind("connection_tracker") == bpf.BackendBPF
}
// HistoryCount implements health.Provider. It returns the global store's
// HistoryCount(), or 0 when there is no global store.
func (d *Daemon) HistoryCount() int {
	if db := store.Global(); db != nil {
		return db.HistoryCount()
	}
	return 0
}
// ConfigHash implements health.Provider. It returns Integrity.ConfigHash from
// the current config, or "" when no config is available.
func (d *Daemon) ConfigHash() string {
	if cfg := d.currentCfg(); cfg != nil {
		return cfg.Integrity.ConfigHash
	}
	return ""
}
// BinaryHash implements health.Provider. It hashes the file at d.binaryPath
// with integrity.HashFile each time it is called (this method does not cache
// the result). It returns "" when no binary path is known or hashing fails.
func (d *Daemon) BinaryHash() string {
	if d.binaryPath == "" {
		return ""
	}
	if sum, err := integrity.HashFile(d.binaryPath); err == nil {
		return sum
	}
	return ""
}
// DryRunBlocksCount implements health.Provider. It returns the global store's
// DryRunBlocksCount() (firewall blocks intercepted by dry_run, per the
// original note), or 0 when there is no global store.
func (d *Daemon) DryRunBlocksCount() int {
	if db := store.Global(); db != nil {
		return db.DryRunBlocksCount()
	}
	return 0
}
// AutomationStatus implements health.Provider. It assembles a snapshot from
// several sources, each of which is optional: the current config (auto-response
// and challenge toggles), the challenge IP list (pending count), the challenge
// gate (active when non-nil) and the global firewall rollback manager. Fields
// whose source is absent keep their zero value.
func (d *Daemon) AutomationStatus() health.AutomationStatus {
	cfg := d.currentCfg()

	var status health.AutomationStatus
	status.DryRunBlocks = d.DryRunBlocksCount()
	status.LastAction = d.lastAutomationAction()

	if cfg != nil {
		status.AutoResponseEnabled = cfg.AutoResponse.Enabled
		status.AutoResponseBlockIPs = cfg.AutoResponse.BlockIPs
		status.AutoResponseDryRun = cfg.AutoResponseDryRunEnabled()
		status.ChallengeEnabled = cfg.Challenge.Enabled
		status.ChallengePortGateEnabled = cfg.Challenge.PortGate.Enabled
	}

	if d.ipList != nil {
		status.ChallengePending = d.ipList.Count()
	}
	status.ChallengePortGateActive = d.challengeGate != nil

	if mgr := rollback.Global(); mgr != nil {
		pending := mgr.Status()
		status.FirewallRollbackPending = pending.Pending
		status.FirewallRollbackSecondsRemain = pending.SecondsRemaining
	}
	return status
}
// lastAutomationAction returns the most recent automation action, or nil when
// the daemon has no store. The result of computeLastAutomationAction is cached
// under automationActionMu and reused until it is lastAutomationActionTTL old;
// a nil result is cached like any other.
func (d *Daemon) lastAutomationAction() *health.AutomationAction {
	if d.store == nil {
		return nil
	}

	d.automationActionMu.Lock()
	defer d.automationActionMu.Unlock()

	stamp := d.automationActionCached
	fresh := !stamp.IsZero() && time.Since(stamp) < lastAutomationActionTTL
	if !fresh {
		d.automationActionCache = d.computeLastAutomationAction()
		d.automationActionCached = time.Now()
	}
	return d.automationActionCache
}
// computeLastAutomationAction scans the store's latest findings and then up to
// 100 history entries (offset 0) for findings whose check passes
// isAutomationActionCheck, and returns the one with the latest timestamp. On
// equal timestamps the finding seen first wins. It returns nil when nothing
// matches. The caller must ensure d.store is non-nil.
func (d *Daemon) computeLastAutomationAction() *health.AutomationAction {
	var (
		latest alert.Finding
		found  bool
	)

	scan := func(batch []alert.Finding) {
		for i := range batch {
			candidate := batch[i]
			if !isAutomationActionCheck(candidate.Check) {
				continue
			}
			if found && !candidate.Timestamp.After(latest.Timestamp) {
				continue
			}
			latest, found = candidate, true
		}
	}

	scan(d.store.LatestFindings())

	// The ReadHistory error is discarded: when the read yields nothing, only
	// the latest findings are considered.
	history, _ := d.store.ReadHistory(100, 0)
	if len(history) > 0 {
		scan(history)
	}

	if !found {
		return nil
	}
	action := new(health.AutomationAction)
	action.Check = latest.Check
	action.Message = latest.Message
	action.Timestamp = latest.Timestamp
	return action
}
// isAutomationActionCheck reports whether check names an automation action:
// exactly "auto_block", "auto_response" or "challenge_route", or any name
// beginning with "email_php_relay_action_". Matching is case-sensitive.
func isAutomationActionCheck(check string) bool {
	if check == "auto_block" || check == "auto_response" || check == "challenge_route" {
		return true
	}
	return strings.HasPrefix(check, "email_php_relay_action_")
}
// UpdateInfo implements health.Provider. It copies the update checker's latest
// cached result into a health.UpdateInfo. When no checker was started the zero
// value is returned.
func (d *Daemon) UpdateInfo() health.UpdateInfo {
	var out health.UpdateInfo
	if checker := d.updateChecker; checker != nil {
		latest := checker.Latest()
		out.LatestVersion = latest.LatestVersion
		out.Available = latest.Available
		out.Source = latest.Source
		out.CheckedAt = latest.CheckedAt
		out.Err = latest.Err
	}
	return out
}
// startUpdateChecker builds the updatecheck.Checker and runs it on a
// background goroutine tracked by d.wg. It is a no-op when the startup config
// is nil or updates.check_enabled is false (operator opt-out, e.g. air-gapped
// deployments).
//
// The running version is not special-cased here: the checker is started with
// CurrentVersion set to d.version whatever its value. NOTE(review): the
// previous comment implied that "dev" builds always show the update banner;
// confirm that behaviour in updatecheck.
func (d *Daemon) startUpdateChecker() {
	cfg := d.cfg
	if cfg == nil || !cfg.UpdatesCheckEnabled() {
		return
	}
	interval := cfg.UpdatesInterval()
	pkgProbe := selectPackageProbe(cfg.UpdatesPackageName())
	d.updateChecker = updatecheck.New(updatecheck.Options{
		CurrentVersion: d.version,
		Interval:       interval,
		GitHubAPIURL:   cfg.Updates.GitHubAPIURL,
		PackageProbe:   pkgProbe,
		LogErr: func(source string, err error) {
			csmlog.Debug("update check probe failed", "source", source, "err", err)
		},
	})
	d.wg.Add(1)
	obs.Go("update-checker", func() {
		defer d.wg.Done()
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		// Bridge the daemon's stop channel onto ctx. The bridge also watches
		// ctx.Done() so it exits once Run returns and the deferred cancel
		// fires, instead of staying parked on d.stopCh until shutdown.
		go func() {
			select {
			case <-d.stopCh:
				cancel()
			case <-ctx.Done():
			}
		}()
		d.updateChecker.Run(ctx)
	})
}
// selectPackageProbe chooses the package-manager probe for packageName from
// the detected platform: an apt probe on the Debian family, a dnf probe on the
// RHEL family, and nil otherwise (binary installs, source builds, etc.). The
// Debian check is made first.
func selectPackageProbe(packageName string) updatecheck.PackageProbe {
	host := platform.Detect()
	if host.IsDebianFamily() {
		return updatecheck.AptProbe(packageName)
	}
	if host.IsRHELFamily() {
		return updatecheck.DnfProbe(packageName)
	}
	return nil
}
// Compile-time assertion that *Daemon satisfies health.Provider. The blank
// identifier means no value is kept; a missing method fails the build here.
var _ health.Provider = (*Daemon)(nil)
package daemon
import (
"fmt"
"os"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/store"
)
// Reload outcome labels for csm_config_reloads_total. These are the values
// passed to recordReloadResult for the metric's "result" label. Keep these in
// sync with docs/src/metrics.md.
const (
reloadResultSuccess = "success"
reloadResultError = "error"
reloadResultRestartRequired = "restart_required"
reloadResultNoop = "noop"
)
// reloadMetric is the csm_config_reloads_total counter vector. It is created
// and registered lazily by recordReloadResult; reloadMetricOnce guards that
// one-time setup.
var (
reloadMetric *metrics.CounterVec
reloadMetricOnce sync.Once
)
// recordReloadResult bumps csm_config_reloads_total by one for the
// given outcome label, registering the metric on first use.
func recordReloadResult(result string) {
reloadMetricOnce.Do(func() {
reloadMetric = metrics.NewCounterVec(
"csm_config_reloads_total",
"SIGHUP config reload attempts, by outcome. result=success when safe fields were swapped in place; restart_required when the edit touched a field that needs a full restart (live config unchanged); error on YAML parse, validation, or re-sign failure (live config unchanged); noop when the edit was semantically identical to the running config.",
[]string{"result"},
)
metrics.MustRegister("csm_config_reloads_total", reloadMetric)
})
reloadMetric.With(result).Inc()
}
// reloadConfig re-reads the on-disk csm.yaml plus configured drop-ins,
// validates it, diffs against the current live config, and - if every
// change is marked safe for live reload - installs the new config via
// publishActiveConfig (which calls config.SetActive). Only the main config
// file is re-signed; drop-ins are never written back into csm.yaml.
//
// Every path records exactly one csm_config_reloads_total outcome through
// recordReloadResult.
//
// Failure modes all leave the live config untouched:
//
// - YAML parse error: Critical `config_reload_error` finding.
// - Validation error: Critical `config_reload_error` finding.
// - Restart-required fields changed: Warning
// `config_reload_restart_required` finding, listing the offending
// field names. (The live config's Integrity block is still refreshed
// when re-signing succeeds; see the comment in that branch.)
// - Re-signing failure: Critical `config_reload_error` finding.
// - Account extractor update failure: Critical `config_reload_error`
// finding.
//
// ROADMAP item 7.
func (d *Daemon) reloadConfig() {
// Diff baseline: the live config, or the startup snapshot if no config
// has been published yet.
oldCfg := d.activeOrStartupCfg()
cfgPath := oldCfg.ConfigFile
fmt.Fprintf(os.Stderr, "[%s] SIGHUP: reloading config from %s\n", ts(), cfgPath)
newCfg, err := config.LoadWithDir(cfgPath, oldCfg.ConfigDir)
if err != nil {
recordReloadResult(reloadResultError)
d.emitReloadFinding(alert.Critical, "config_reload_error",
fmt.Sprintf("SIGHUP reload: parse failed (%v); keeping old config", err))
return
}
// The first error-level validation result aborts the reload; results at
// any other level are not acted on here.
for _, r := range config.Validate(newCfg) {
if r.Level == "error" {
recordReloadResult(reloadResultError)
d.emitReloadFinding(alert.Critical, "config_reload_error",
fmt.Sprintf("SIGHUP reload: validation error on %q: %s; keeping old config",
r.Field, r.Message))
return
}
}
changes := config.Diff(oldCfg, newCfg)
if len(changes) == 0 {
recordReloadResult(reloadResultNoop)
fmt.Fprintf(os.Stderr, "[%s] SIGHUP: no config changes detected\n", ts())
return
}
if config.RestartRequired(changes) {
// Collect every change not tagged safe so the finding names them.
var offenders []string
for _, c := range changes {
if c.Tag != config.TagSafe {
offenders = append(offenders, c.Field)
}
}
// The edit passed Validate, so the file on disk is
// loadable. Re-sign integrity.config_hash to match the
// edited content -- otherwise the next daemon restart
// (systemd, manual, crash recovery) trips the startup
// integrity check and crash-loops. Update the live cfg's
// stored ConfigHash in lock-step so the periodic
// integrity.Verify(currentCfg()) does not see a disk /
// memory divergence and fire spurious tamper alerts.
//
// Any error re-signing degrades to "live config unchanged,
// file unchanged, operator must rehash manually"; we still
// emit the warning so they know to act.
if err := d.signAndSaveReloadedConfig(oldCfg, newCfg); err == nil {
// Publish a copy of the OLD config carrying only the refreshed
// Integrity block, so no restart-required field goes live.
resynced := *oldCfg
resynced.Integrity = newCfg.Integrity
resynced.ConfigFile = cfgPath
resynced.ConfigDir = oldCfg.ConfigDir
publishActiveConfig(&resynced, "SIGHUP")
} else {
fmt.Fprintf(os.Stderr, "[%s] config_reload_restart_required: re-sign failed (%v); file and live hash remain mismatched until operator runs `csm rehash`\n",
ts(), err)
}
recordReloadResult(reloadResultRestartRequired)
d.emitReloadFinding(alert.Warning, "config_reload_restart_required",
fmt.Sprintf("SIGHUP reload: restart-required fields changed: %v; live config unchanged, main config re-signed if needed for next restart",
offenders))
return
}
// Safe-only change set: re-sign first, then install.
if err := d.signAndSaveReloadedConfig(oldCfg, newCfg); err != nil {
recordReloadResult(reloadResultError)
d.emitReloadFinding(alert.Critical, "config_reload_error",
fmt.Sprintf("SIGHUP reload: re-signing config failed: %v; live config unchanged", err))
return
}
newCfg.ConfigFile = cfgPath
newCfg.ConfigDir = oldCfg.ConfigDir
// NOTE(review): signAndSaveReloadedConfig has already run by this point.
// If it rewrote the on-disk signature and the extractor install below
// then fails, the live config keeps its old Integrity values -- confirm
// this cannot leave disk and memory hashes diverged.
if err := installAccountExtractorFromConfig(newCfg); err != nil {
recordReloadResult(reloadResultError)
d.emitReloadFinding(alert.Critical, "config_reload_error",
fmt.Sprintf("SIGHUP reload: account extractor update failed: %v; live config unchanged", err))
return
}
publishActiveConfig(newCfg, "SIGHUP")
recordReloadResult(reloadResultSuccess)
var names []string
for _, c := range changes {
names = append(names, c.Field)
}
fmt.Fprintf(os.Stderr, "[%s] SIGHUP: config reloaded; safe fields updated: %v\n", ts(), names)
// A forward-guard config change is a safe (hot-reload) field; re-reconcile
// so enabling/disabling or toggling enforce takes effect without a restart.
d.reconcileForwardGuard()
}
// activeOrStartupCfg returns the config published through config.Active, or
// d.cfg (the startup snapshot) while nothing has been published yet.
//
// Reload paths rely on this so the first reload diffs against the startup
// config and each later reload diffs against whatever the previous successful
// reload installed. currentCfg delegates here as well, so a SIGHUP-driven
// threshold change reaches the next tick without a restart.
func (d *Daemon) activeOrStartupCfg() *config.Config {
	live := config.Active()
	if live == nil {
		return d.cfg
	}
	return live
}
// currentCfg is the per-tick config accessor for hot paths. It delegates to
// activeOrStartupCfg, so it returns the live config when one has been
// published and the startup snapshot otherwise. See ROADMAP item 7 for the
// threshold-tuning motivation.
func (d *Daemon) currentCfg() *config.Config {
return d.activeOrStartupCfg()
}
// publishActiveConfig installs cfg as the live config via config.SetActive and
// then runs purgeDryRunBlocksIfAutoResponseLive for it. source is passed
// through to that purge, which uses it as a label in its log line.
func publishActiveConfig(cfg *config.Config, source string) {
config.SetActive(cfg)
purgeDryRunBlocksIfAutoResponseLive(cfg, source)
}
// purgeDryRunBlocksIfAutoResponseLive clears the store's dry_run_blocks
// records once auto-response is no longer in dry-run mode. It does nothing
// when cfg is nil, when cfg still reports dry-run enabled, or when there is no
// global store. A line is written to stderr only if at least one record was
// removed; source labels that line.
func purgeDryRunBlocksIfAutoResponseLive(cfg *config.Config, source string) {
	if cfg == nil || cfg.AutoResponseDryRunEnabled() {
		return
	}
	sdb := store.Global()
	if sdb == nil {
		return
	}
	removed := sdb.PurgeAllDryRunBlocks()
	if removed <= 0 {
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] %s: purged %d dry_run_blocks records (auto-response live)\n", ts(), source, removed)
}
// signAndSaveReloadedConfig brings newCfg.Integrity in line with the files on
// disk after a SIGHUP.
//
// It hashes the main config file and the conf.d directory named by oldCfg. If
// both hashes still equal the values stored in oldCfg.Integrity, nothing is
// signed and newCfg simply inherits oldCfg's whole Integrity block. Otherwise
// integrity.SignConfigFilePreserving is called and its two resulting hashes
// are stored on newCfg.
//
// Drop-in fragment content is intentionally not written back into csm.yaml;
// only its digest is folded into confd_hash so a later Verify still passes.
// The binary hash is always carried over from oldCfg: a SIGHUP reload cannot
// upgrade the binary, so it must not drift.
//
// Any hashing or signing error is returned unchanged.
func (d *Daemon) signAndSaveReloadedConfig(oldCfg, newCfg *config.Config) error {
	prior := &oldCfg.Integrity

	mainHash, err := integrity.HashConfigStable(oldCfg.ConfigFile)
	if err != nil {
		return err
	}
	dropInHash, err := integrity.HashConfDir(oldCfg.ConfigDir)
	if err != nil {
		return err
	}

	// Nothing on disk moved relative to the recorded hashes: reuse them.
	if mainHash == prior.ConfigHash && dropInHash == prior.ConfdHash {
		newCfg.Integrity = *prior
		return nil
	}

	signedMain, signedConfd, err := integrity.SignConfigFilePreserving(oldCfg.ConfigFile, oldCfg.ConfigDir, prior.BinaryHash)
	if err != nil {
		return err
	}
	newCfg.Integrity.BinaryHash = prior.BinaryHash
	newCfg.Integrity.ConfigHash = signedMain
	newCfg.Integrity.ConfdHash = signedConfd
	return nil
}
// emitReloadFinding writes "check: msg" to stderr and then offers a Finding
// with the given severity, check and message (timestamped now) to the daemon's
// alert channel. The send never blocks: if the channel cannot accept the
// finding immediately it is dropped here. Per the original note, the daemon's
// existing drop counter tracks such saturation.
func (d *Daemon) emitReloadFinding(sev alert.Severity, check, msg string) {
	fmt.Fprintf(os.Stderr, "[%s] %s: %s\n", ts(), check, msg)

	var finding alert.Finding
	finding.Severity = sev
	finding.Check = check
	finding.Message = msg
	finding.Timestamp = time.Now()

	select {
	case d.alertCh <- finding:
	default:
		// Channel not ready; drop rather than stall the reload path.
	}
}
package daemon
import (
"net"
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/incident"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/threatintel"
)
// Package-level state behind the IncidentCorrelator() singleton.
var (
// incidentOnce guards the one-time construction of incidentCorrelator.
incidentOnce sync.Once
// incidentCorrelator is the daemon-wide correlator; nil until
// IncidentCorrelator() has run.
incidentCorrelator *incident.Correlator
// incidentRegistry supplies the metrics registry the correlator's metrics
// are registered on; the test reset seam swaps in metrics.NewRegistry.
incidentRegistry = metrics.Default
// incidentRetentionCancel and incidentAutoCloseCancel stop the background
// loops started by IncidentCorrelator(); nil when the loop is not running.
incidentRetentionCancel func()
incidentAutoCloseCancel func()
// incidentSprayBlocker is the firewall hand-off the daemon wires in
// before IncidentCorrelator() is first invoked. The bool reports
// whether nftables was actually mutated; dry-run, already-blocked,
// and verdict-allow outcomes return false so incident timelines do
// not record a live block that never landed.
// nil means "no blocker wired" (early startup or unit tests); the
// singleton then skips wiring OnSprayBlock and OnIncidentBlock, so the
// detectors stay detection-only even with BlockAtSeverity set.
incidentSprayBlocker func(ip, reason string, timeout time.Duration) (bool, error)
)
// SetIncidentSprayBlocker installs the firewall-side hand-off used by the
// incident auto-block paths. Call once after the firewall engine is built
// and before the first IncidentCorrelator() call: the singleton captures the
// blocker when it is constructed, so a later call does not rewire an
// already-built correlator.
// Passing nil clears the binding.
func SetIncidentSprayBlocker(fn func(ip, reason string, timeout time.Duration) (bool, error)) {
incidentSprayBlocker = fn
}
// incidentAutoCloseInterval is how often the daemon scans Open / Contained
// incidents for staleness. One hour is fast enough that 24h-idle incidents
// close within ~25h worst-case, but slow enough that the per-kind walk
// stays cheap on hosts with thousands of incidents.
const incidentAutoCloseInterval = 1 * time.Hour
// incidentAutoCloseWarmup delays the first sweep just long enough for the
// startup finding burst to settle. Incidents are restored synchronously
// before the loop starts, so a long warm-up only parks a stale backlog as
// "open" after every restart (observed on a frequently-upgraded prod host:
// thousands of >24h incidents sitting open for the full warm-up). The
// warm-up is therefore kept short.
const incidentAutoCloseWarmup = 2 * time.Minute
// incidentAutoCloseDrainDelay is the cadence used when a sweep hit its
// per-sweep cap, so a large backlog drains over a few quick passes instead
// of waiting a full interval between each capped sweep.
const incidentAutoCloseDrainDelay = 30 * time.Second
// incidentAutoCloseMaxPerSweep bounds live closes per sweep so a big
// post-restart backlog does not hold the correlator lock or burst
// thousands of bbolt persists in one tick. The safety-cap sweeps use the
// same bound.
const incidentAutoCloseMaxPerSweep = 1000
// incidentRetentionPeriod is how long resolved/dismissed incidents are
// kept before compaction prunes them. Named constant per project
// convention; config exposure deferred until operators ask.
const incidentRetentionPeriod = 30 * 24 * time.Hour
// incidentOpenThreshold is the number of correlated findings required
// before a non-Critical finding opens an incident. Two means an
// isolated probe (a single dictionary-attack guess, one modsec hit
// from a wandering scanner) is treated as a finding only and never
// promoted to an incident on its own; the next correlated event
// inside the merge window does the promotion. Declared as var so
// tests that exercise the correlator wiring (where one finding is
// expected to land in one incident immediately) can pin it to 1
// via resetIncidentForTest. Production code never mutates this.
var incidentOpenThreshold = 2
// IncidentCorrelator returns the daemon-wide incident correlator.
// On first call: builds the correlator, restores prior state from
// the bbolt store (when available), registers metrics, and starts the
// retention loop (plus the auto-close loop when a config is available).
// Later calls return the same instance. Construction runs under sync.Once,
// so it is safe for concurrent callers.
//
// Config-derived settings (spray suppression, auto-block kinds, the static
// whitelist) are read once at construction; only the auto_response gate is
// re-read from the live config at decision time.
func IncidentCorrelator() *incident.Correlator {
incidentOnce.Do(func() {
db := store.Global()
// Persist hook: left nil when there is no store.
var persist func(incident.Incident)
if db != nil {
persist = func(inc incident.Incident) {
if err := db.SaveIncident(inc); err != nil {
// The in-memory correlator has already advanced, so
// failed writes mean the next restore may replay stale
// incident state unless operators repair the store.
csmlog.Warn("incident persist failed",
"id", inc.ID, "kind", string(inc.Kind),
"status", string(inc.Status), "err", err)
}
}
}
// Resolve spray-suppression knobs from the active config. nil
// config (early test wiring) leaves the detector disabled.
var spray incident.SpraySuppressionConfig
var autoBlock incident.IncidentAutoBlockConfig
var whitelisted func(string) bool
var onSprayBlock func(ip, reason string) bool
var onIncidentBlock func(ip, reason string) bool
if cfg := globalCfgForIncidents(); cfg != nil {
spray = incident.SpraySuppressionConfig{
Enabled: cfg.Incidents.SpraySuppression.Enabled,
DryRun: cfg.Incidents.SpraySuppression.DryRun,
DistinctMailboxes: cfg.Incidents.SpraySuppression.DistinctMailboxes,
SeverityEscalateAt: cfg.Incidents.SpraySuppression.SeverityEscalateAt,
PerCheck: cfg.IncidentsSpraySuppressionPerCheck(),
MaxTrackedIPs: cfg.Incidents.SpraySuppression.MaxTrackedIPs,
BlockAtSeverity: cfg.Incidents.SpraySuppression.BlockAtSeverity,
}
// Only wire the firewall hand-off when block-on-spray is
// configured and the daemon has a blocker installed. The live
// auto_response gate is checked at decision time so SIGHUP
// changes to enabled/block_ips take effect without rebuilding
// the singleton.
if spray.BlockAtSeverity != "" && incidentSprayBlocker != nil {
// Capture the blocker now so later changes to the package variable
// do not affect this hook.
blocker := incidentSprayBlocker
onSprayBlock = func(ip, reason string) bool {
liveCfg := globalCfgForIncidents()
if liveCfg == nil || !liveCfg.AutoResponse.Enabled || !liveCfg.AutoResponse.BlockIPs {
return false
}
// An unparsable or non-positive block_expiry falls back to 24h.
timeout, perr := time.ParseDuration(liveCfg.AutoResponse.BlockExpiry)
if perr != nil || timeout <= 0 {
timeout = 24 * time.Hour
}
live, err := blocker(ip, "CSM credential_spray: "+reason, timeout)
if err != nil {
csmlog.Warn("credential_spray block failed", "ip", ip, "err", err)
return false
}
return live
}
}
// Generic incident-driven auto-block. Reuses the same firewall
// blocker as the spray path; the reason prefix differs so audit
// log rows distinguish which detector triggered the block.
kindsRaw := cfg.IncidentsAutoBlockKinds()
kinds := make(map[incident.Kind]bool, len(kindsRaw))
for k := range kindsRaw {
kinds[incident.Kind(k)] = true
}
autoBlock = incident.IncidentAutoBlockConfig{
Enabled: cfg.Incidents.AutoBlock.Enabled,
BlockAtSeverity: cfg.Incidents.AutoBlock.BlockAtSeverity,
Kinds: kinds,
}
if autoBlock.Enabled && autoBlock.BlockAtSeverity != "" && incidentSprayBlocker != nil {
blocker := incidentSprayBlocker
onIncidentBlock = func(ip, reason string) bool {
liveCfg := globalCfgForIncidents()
if liveCfg == nil || !liveCfg.AutoResponse.Enabled || !liveCfg.AutoResponse.BlockIPs {
return false
}
// Same 24h fallback as the spray hook above.
timeout, perr := time.ParseDuration(liveCfg.AutoResponse.BlockExpiry)
if perr != nil || timeout <= 0 {
timeout = 24 * time.Hour
}
live, err := blocker(ip, "CSM incident: "+reason, timeout)
if err != nil {
csmlog.Warn("incident auto-block failed", "ip", ip, "err", err)
return false
}
return live
}
}
// Whitelist accessor: check the static reputation.whitelist
// list, then the bbolt-backed live whitelist operators add at
// runtime. db nil-check inside the closure so store resolution
// stays current across daemon restarts.
// The static list is matched by exact string; empty entries are skipped.
staticAllow := make(map[string]bool, len(cfg.Reputation.Whitelist))
for _, ip := range cfg.Reputation.Whitelist {
if ip != "" {
staticAllow[ip] = true
}
}
whitelisted = func(ip string) bool {
if ip == "" {
return false
}
if staticAllow[ip] {
return true
}
if d := store.Global(); d != nil && d.IsWhitelisted(ip) {
return true
}
// Backstop: a verified-crawler IP from a published range
// (Googlebot/Bingbot/Applebot) should never anchor a
// correlated incident. CDN edge ranges are intentionally
// excluded -- see threatintel.IPInAnyBot.
return threatintel.DefaultRanges().IPInAnyBot(net.ParseIP(ip))
}
}
// Hooks and config left at their zero values (no config, no blocker) are
// passed through as-is. The CanSprayBlock / CanIncidentBlock gates always
// consult the live config.
incidentCorrelator = incident.NewCorrelator(incident.CorrelatorConfig{
Persist: persist,
OpenThreshold: incidentOpenThreshold,
SpraySuppression: spray,
AutoBlock: autoBlock,
IsWhitelisted: whitelisted,
CanSprayBlock: func() bool {
cfg := globalCfgForIncidents()
return cfg != nil && cfg.AutoResponse.Enabled && cfg.AutoResponse.BlockIPs
},
CanIncidentBlock: func() bool {
cfg := globalCfgForIncidents()
return cfg != nil && cfg.AutoResponse.Enabled && cfg.AutoResponse.BlockIPs
},
OnSprayBlock: onSprayBlock,
OnIncidentBlock: onIncidentBlock,
})
// Restore persisted incidents; a listing failure is logged and the
// correlator starts empty.
if db != nil {
list, err := db.ListIncidents()
if err != nil {
csmlog.Warn("incident restore failed", "err", err)
} else {
incidentCorrelator.Restore(list)
}
}
incident.RegisterMetrics(incidentRegistry(), incidentCorrelator)
incidentRetentionCancel = startIncidentRetentionLoop(incidentCorrelator)
// Auto-close runs on its own timer (hourly in steady state) so the
// daily retention loop is not coupled to the close cadence; a 24h-idle
// incident closes within at most ~25h. The loop is only started when a
// config is available at construction time.
if cfg := globalCfgForIncidents(); cfg != nil {
incidentAutoCloseCancel = startIncidentAutoCloseLoop(incidentCorrelator, cfg)
}
})
return incidentCorrelator
}
// globalCfgForIncidents is the config accessor used by the incident
// singleton and its block hooks. The default returns nil. Production wiring
// replaces it through SetIncidentConfigSource with a closure over the
// daemon's loaded config; tests replace it to plug in a synthetic config.
// While it still returns nil the auto-close loop simply does not start (no
// panic); the retention loop keeps running unchanged.
var globalCfgForIncidents = func() *config.Config { return nil }
// SetIncidentConfigSource sets the accessor the incident singleton uses to
// reach the daemon's config. Per the original note it is called once from
// cmd/csm/serve before IncidentCorrelator() is first invoked. Each call
// overwrites the previous source, so reload paths can rebind it without
// rebuilding the singleton. A nil get installs an accessor that always
// returns nil.
func SetIncidentConfigSource(get func() *config.Config) {
	source := get
	if source == nil {
		source = func() *config.Config { return nil }
	}
	globalCfgForIncidents = source
}
// startIncidentAutoCloseLoop starts the goroutine that periodically calls
// runIncidentAutoClose for c with cfg, and returns a cancel func.
//
// Scheduling uses one resettable timer: the first sweep fires after
// incidentAutoCloseWarmup; each later sweep fires after
// incidentAutoCloseInterval, unless the previous sweep reported a remaining
// backlog, in which case incidentAutoCloseDrainDelay is used instead so a
// post-restart backlog clears in minutes rather than hours.
//
// The cancel func blocks until the goroutine has exited, so the daemon's
// shutdown sequence can close the store right after cancelling without a
// sweep still being mid-write. It must be called at most once.
func startIncidentAutoCloseLoop(c *incident.Correlator, cfg *config.Config) func() {
	quit := make(chan struct{})
	done := make(chan struct{})

	// nextDelay maps a sweep's backlog flag to the wait before the next sweep.
	nextDelay := func(backlog bool) time.Duration {
		if backlog {
			return incidentAutoCloseDrainDelay
		}
		return incidentAutoCloseInterval
	}

	go func() {
		defer close(done)
		sweepTimer := time.NewTimer(incidentAutoCloseWarmup)
		defer sweepTimer.Stop()
		for {
			select {
			case <-quit:
				return
			case <-sweepTimer.C:
				backlog := runIncidentAutoClose(c, cfg)
				sweepTimer.Reset(nextDelay(backlog))
			}
		}
	}()

	return func() {
		close(quit)
		<-done
	}
}
// runIncidentAutoClose performs one tick of the auto-close loop and reports
// whether a prompt follow-up sweep is wanted.
//
// The safety cap runs first on every tick, regardless of the operator's
// auto-close toggle or per-kind thresholds: it is a hard backstop against
// unbounded growth of Open/Contained incidents when auto-close is off or a
// kind is omitted. The stale sweep itself only runs when cfg is non-nil, has
// auto-close enabled, and defines at least one per-kind threshold. With
// dry_run set the sweep only counts decisions; otherwise stale incidents are
// closed (resolved with "auto:stale" attribution, per the original note).
//
// more is true when either the safety cap or the stale sweep left a backlog.
func runIncidentAutoClose(c *incident.Correlator, cfg *config.Config) (more bool) {
	backlogFromCap := runIncidentSafetyCap(c)

	autoCloseOn := cfg != nil && cfg.IncidentsAutoCloseEnabled()
	if !autoCloseOn {
		return backlogFromCap
	}
	configured := cfg.IncidentsAutoCloseThresholds()
	if len(configured) == 0 {
		return backlogFromCap
	}

	// Re-key the configured thresholds by incident.Kind for the correlator.
	perKind := make(map[incident.Kind]time.Duration, len(configured))
	for name, idle := range configured {
		perKind[incident.Kind(name)] = idle
	}

	simulate := cfg.Incidents.AutoClose.DryRun
	closed, simulated, scanned, staleLeft := c.CloseStaleLimited(time.Now(), perKind, simulate, incidentAutoCloseMaxPerSweep)

	// Stay silent when the sweep neither closed nor would have closed anything.
	if closed > 0 || simulated > 0 {
		csmlog.Info("incident auto-close",
			"closed", closed,
			"dry_run_decisions", simulated,
			"scanned", scanned,
			"dry_run", simulate,
			"backlog_remaining", staleLeft,
		)
	}
	return staleLeft || backlogFromCap
}
// incidentSafetyMaxAge is the hard age cap: any Open/Contained incident idle
// longer than this is force-closed regardless of auto-close config. It is
// the age passed to CloseStaleByAge by runIncidentSafetyCap.
const incidentSafetyMaxAge = 30 * 24 * time.Hour
// incidentSafetyMaxActive bounds how many Open/Contained incidents are held in
// memory at once; the oldest over this are force-closed. It is the ceiling
// passed to EnforceActiveCap by runIncidentSafetyCap.
const incidentSafetyMaxActive = 50000
// runIncidentSafetyCap applies the two hard limits on active incidents: it
// closes incidents past incidentSafetyMaxAge and then enforces the
// incidentSafetyMaxActive ceiling. Both steps share one timestamp and are
// each bounded by incidentAutoCloseMaxPerSweep. It does not consult the
// operator's auto-close settings. A log line is emitted only when something
// was closed. more is true if either step left a backlog, so the loop
// schedules a prompt follow-up.
func runIncidentSafetyCap(c *incident.Correlator) (more bool) {
	sweepAt := time.Now()
	agedOut, agePending := c.CloseStaleByAge(sweepAt, incidentSafetyMaxAge, incidentAutoCloseMaxPerSweep)
	trimmed, trimPending := c.EnforceActiveCap(sweepAt, incidentSafetyMaxActive, incidentAutoCloseMaxPerSweep)

	more = agePending || trimPending
	if agedOut > 0 || trimmed > 0 {
		csmlog.Info("incident safety cap",
			"closed_by_age", agedOut,
			"closed_by_active_cap", trimmed,
			"backlog_remaining", more,
		)
	}
	return more
}
// startIncidentRetentionLoop starts the goroutine that runs
// runIncidentCompaction for c, and returns a cancel func. It is started from
// IncidentCorrelator() once the singleton exists.
//
// Two clocks drive it, both created when the goroutine starts: a one-shot
// timer that fires the first sweep after one hour (so the daemon has time to
// settle before touching the store under retention rules) and a 24-hour
// ticker for the recurring sweeps.
//
// The cancel func blocks until the goroutine has exited, so a compaction in
// flight finishes before the daemon closes the store. It must be called at
// most once.
func startIncidentRetentionLoop(c *incident.Correlator) func() {
	quit := make(chan struct{})
	done := make(chan struct{})

	go func() {
		defer close(done)

		daily := time.NewTicker(24 * time.Hour)
		defer daily.Stop()
		initial := time.NewTimer(time.Hour)
		defer initial.Stop()

		sweep := func() { runIncidentCompaction(c) }
		for {
			select {
			case <-quit:
				return
			case <-initial.C:
				sweep()
			case <-daily.C:
				sweep()
			}
		}
	}()

	return func() {
		close(quit)
		<-done
	}
}
// runIncidentCompaction is one retention sweep. It asks the global store to
// compact incidents older than incidentRetentionPeriod, then prunes the
// correlator's own closed, stale-pending and stale-spray state as of the same
// instant. When the store pruned at least one record it bumps the correlator's
// compacted total and logs the count.
//
// It returns early without touching the correlator when there is no global
// store or when the store compaction fails; the failure is logged, never
// fatal.
func runIncidentCompaction(c *incident.Correlator) {
	db := store.Global()
	if db == nil {
		return
	}

	sweepAt := time.Now()
	removed, err := db.CompactIncidents(sweepAt, incidentRetentionPeriod)
	if err != nil {
		csmlog.Warn("incident retention compaction failed", "err", err)
		return
	}

	// The correlator prune results are not used here.
	_ = c.PruneClosedOlderThan(sweepAt, incidentRetentionPeriod)
	_ = c.PruneStalePending(sweepAt)
	_ = c.PruneStaleSpray(sweepAt)

	if removed <= 0 {
		return
	}
	c.IncrementCompactedTotal(removed)
	csmlog.Info("incident retention compaction", "pruned", removed)
}
// StopIncidentBackgroundLoops cancels the incident retention loop and then the
// auto-close loop, clearing each stored cancel func after calling it. The
// daemon calls it during shutdown, before closing the store, so neither loop
// performs a bbolt write against an already-closed database. It is a no-op
// for any loop whose cancel func is nil, which includes the case where the
// singleton was never constructed.
func StopIncidentBackgroundLoops() {
	if cancel := incidentRetentionCancel; cancel != nil {
		cancel()
		incidentRetentionCancel = nil
	}
	if cancel := incidentAutoCloseCancel; cancel != nil {
		cancel()
		incidentAutoCloseCancel = nil
	}
}
// resetIncidentForTest is a test seam. It delegates to
// resetIncidentForTestWithThreshold with an open threshold of 1, which stops
// any background workers, zeros the singleton, and pins the registry to a
// private one so subsequent IncidentCorrelator() calls do not collide on
// metrics.Default.
func resetIncidentForTest() {
resetIncidentForTestWithThreshold(1)
}
// resetIncidentForTestWithThreshold is the same seam but lets a test pin
// the open threshold to a specific value. Used by the wiring test that
// proves the production default (2) is honored end-to-end through the
// IncidentCorrelator() singleton constructor.
//
// It cancels both background loops (if set), clears the correlator
// singleton and its sync.Once, swaps the registry factory, replaces the
// config accessor with one that returns nil, clears the spray blocker, and
// finally stores threshold in incidentOpenThreshold.
func resetIncidentForTestWithThreshold(threshold int) {
// Stop the workers before dropping the singleton they operate on.
if incidentRetentionCancel != nil {
incidentRetentionCancel()
incidentRetentionCancel = nil
}
if incidentAutoCloseCancel != nil {
incidentAutoCloseCancel()
incidentAutoCloseCancel = nil
}
incidentCorrelator = nil
// A fresh Once lets the next IncidentCorrelator() call construct again.
incidentOnce = sync.Once{}
incidentRegistry = metrics.NewRegistry
globalCfgForIncidents = func() *config.Config { return nil }
incidentSprayBlocker = nil
// Most tests assert that one finding lands in one incident; the
// production threshold of 2 would defer creation to the second
// correlated event and break those wiring assertions. Pin to the
// caller-supplied value; production callers never invoke this seam.
incidentOpenThreshold = threshold
}
package daemon
import (
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// mailIPEntry tracks failed-auth timestamps and suppression state for one IP.
type mailIPEntry struct {
// times holds the failure timestamps still inside the detection window;
// older ones are dropped by pruneTimes in Record and Purge.
times []time.Time
// suppressed is the instant until which no further mail_bruteforce
// finding is emitted for this IP.
suppressed time.Time
// lastSeen is the time of the most recent Record call for this IP; used
// for stale-entry purging and least-recently-seen eviction.
lastSeen time.Time
}
// mailSubnetEntry tracks unique attacker IPs within a /24.
type mailSubnetEntry struct {
// ips maps each source IP seen in this /24 to its latest failure time.
ips map[string]time.Time
// suppressed is the instant until which no further mail_subnet_spray
// finding is emitted for this prefix.
suppressed time.Time
// lastSeen is the time of the most recent failure recorded for the prefix.
lastSeen time.Time
}
// mailAccountEntry tracks unique attacker IPs per mailbox, plus a separate
// suppression clock for compromise findings emitted by RecordSuccess.
type mailAccountEntry struct {
// ips maps each source IP that failed against this mailbox to its latest
// failure time.
ips map[string]time.Time
// suppressed gates mail_account_spray findings (set by Record).
suppressed time.Time
// compromiseSuppressed gates mail_account_compromised findings (set by
// RecordSuccess), independently of suppressed.
compromiseSuppressed time.Time
// lastSeen is the time of the most recent failure recorded for the mailbox.
lastSeen time.Time
}
// mailAuthTracker aggregates dovecot IMAP/POP3/ManageSieve auth events into
// four detection signals: per-IP brute force, per-/24 password spray,
// per-mailbox account spray, and per-account compromise (success after
// recent failures).
//
// Thread-safe; Record/RecordSuccess may be called concurrently from multiple
// log readers.
type mailAuthTracker struct {
// mu guards every field below.
mu sync.Mutex
// perIPThreshold is the number of in-window failures from one IP that
// triggers mail_bruteforce.
perIPThreshold int
// subnetThreshold is the number of distinct in-window IPs in one /24 that
// triggers mail_subnet_spray.
subnetThreshold int
// accountSprayThreshold is the number of distinct in-window IPs against
// one mailbox that triggers mail_account_spray.
accountSprayThreshold int
// window is the look-back period for counting failures.
window time.Duration
// suppression is how long a signal stays quiet after it has fired.
suppression time.Duration
// maxTracked caps len(ips)+len(subnets)+len(accounts); see enforceMaxTracked.
maxTracked int
// now is the injected clock (time.Now in production).
now func() time.Time
ips map[string]*mailIPEntry
subnets map[string]*mailSubnetEntry
accounts map[string]*mailAccountEntry
// Diagnostic counters (guarded by mu): cumulative Record invocations and
// findings emitted, logged periodically by the daemon to pin whether the
// non-cPanel dovecot brute-force path sees traffic and escalates.
recordCalls int64
findingsEmitted int64
}
// newMailAuthTracker builds a tracker with empty IP, subnet and account
// tables and the supplied thresholds and durations.
//
// now is the clock used for every timestamp decision; tests inject a
// deterministic one. A nil now falls back to time.Now.
func newMailAuthTracker(
	perIPThreshold int,
	subnetThreshold int,
	accountSprayThreshold int,
	window time.Duration,
	suppression time.Duration,
	maxTracked int,
	now func() time.Time,
) *mailAuthTracker {
	clock := now
	if clock == nil {
		clock = time.Now
	}

	tracker := new(mailAuthTracker)
	tracker.perIPThreshold = perIPThreshold
	tracker.subnetThreshold = subnetThreshold
	tracker.accountSprayThreshold = accountSprayThreshold
	tracker.window = window
	tracker.suppression = suppression
	tracker.maxTracked = maxTracked
	tracker.now = clock
	tracker.ips = map[string]*mailIPEntry{}
	tracker.subnets = map[string]*mailSubnetEntry{}
	tracker.accounts = map[string]*mailAccountEntry{}
	return tracker
}
// Size returns the total number of tracked entities (IPs + subnets + accounts).
// This is the same sum that enforceMaxTracked compares against maxTracked.
// It takes t.mu, so it must not be called while the lock is already held.
func (t *mailAuthTracker) Size() int {
t.mu.Lock()
defer t.mu.Unlock()
return len(t.ips) + len(t.subnets) + len(t.accounts)
}
// Record processes one dovecot IMAP/POP3/ManageSieve auth-failure observation.
// Returns zero or more findings that callers should append.
//
// ip MUST be non-private, non-loopback, and non-infra — callers enforce this
// before invoking Record.
//
// An empty ip is ignored (nil result, no counters touched). An empty account
// skips only the per-account tracker. One call can emit up to three findings
// (mail_bruteforce, mail_subnet_spray, mail_account_spray); each signal then
// stays quiet for t.suppression on the entry that fired.
func (t *mailAuthTracker) Record(ip, account string) []alert.Finding {
if ip == "" {
return nil
}
t.mu.Lock()
defer t.mu.Unlock()
t.recordCalls++
now := t.now()
// Anything older than cutoff no longer counts toward a threshold.
cutoff := now.Add(-t.window)
var findings []alert.Finding
// --- Per-IP tracker ---
e, ok := t.ips[ip]
if !ok {
e = &mailIPEntry{}
t.ips[ip] = e
}
// Drop expired failures first so the count below is in-window only.
e.times = pruneTimes(e.times, cutoff)
e.times = append(e.times, now)
e.lastSeen = now
// !now.Before(suppressed) is true for the zero time, so a fresh entry
// can fire as soon as it reaches the threshold.
if len(e.times) >= t.perIPThreshold && !now.Before(e.suppressed) {
e.suppressed = now.Add(t.suppression)
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "mail_bruteforce",
Message: fmt.Sprintf("Mail auth brute force from %s: %d failed auths in %v",
ip, len(e.times), t.window),
Details: "Real-time detection of dovecot imap/pop3/managesieve auth failures",
Timestamp: now,
SourceIP: ip,
})
}
// --- Per-/24 subnet tracker (IPv4 only) ---
// An empty prefix from extractPrefix24Daemon skips this tracker.
if prefix := extractPrefix24Daemon(ip); prefix != "" {
s, ok := t.subnets[prefix]
if !ok {
s = &mailSubnetEntry{ips: make(map[string]time.Time)}
t.subnets[prefix] = s
}
// Expire member IPs whose last failure fell out of the window.
for ipKey, ts := range s.ips {
if ts.Before(cutoff) {
delete(s.ips, ipKey)
}
}
s.ips[ip] = now
s.lastSeen = now
if len(s.ips) >= t.subnetThreshold && !now.Before(s.suppressed) {
s.suppressed = now.Add(t.suppression)
// The finding's SourceIP carries the whole /24 in CIDR form.
cidr := prefix + ".0/24"
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "mail_subnet_spray",
Message: fmt.Sprintf("Mail password spray from %s.0/24: %d unique IPs in %v",
prefix, len(s.ips), t.window),
Details: "Real-time detection of mail auth failures from many IPs in one /24",
Timestamp: now,
SourceIP: cidr,
})
}
}
// --- Per-account spray tracker ---
if account != "" {
a, ok := t.accounts[account]
if !ok {
a = &mailAccountEntry{ips: make(map[string]time.Time)}
t.accounts[account] = a
}
// Expire source IPs whose last failure fell out of the window.
for ipKey, ts := range a.ips {
if ts.Before(cutoff) {
delete(a.ips, ipKey)
}
}
a.ips[ip] = now
a.lastSeen = now
if len(a.ips) >= t.accountSprayThreshold && !now.Before(a.suppressed) {
a.suppressed = now.Add(t.suppression)
_, acctDomain := alert.SplitEmail(account)
// SourceIP here is only the IP of the observation that crossed the
// threshold, not the full set of spraying IPs.
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "mail_account_spray",
Message: fmt.Sprintf("Mail password spray targeting %s: %d unique IPs in %v",
account, len(a.ips), t.window),
Details: "Distributed login attempts across many IPs against one mailbox (visibility only — no auto-block).",
Timestamp: now,
SourceIP: ip,
Domain: acctDomain,
Mailbox: account,
})
}
}
// Cap memory after the inserts above, then account for what was emitted.
t.enforceMaxTracked()
t.findingsEmitted += int64(len(findings))
return findings
}
// Stats returns cumulative Record invocations and findings emitted since
// startup. Used by the daemon's periodic diagnostic log.
//
// calls counts Record invocations that passed the empty-ip guard; emits is
// the total number of findings returned by Record. Findings returned by
// RecordSuccess are not included in emits.
func (t *mailAuthTracker) Stats() (calls, emits int64) {
t.mu.Lock()
defer t.mu.Unlock()
return t.recordCalls, t.findingsEmitted
}
// RecordSuccess processes a successful mail login. Emits mail_account_compromised
// when the successful IP has recent failed auths for the same account — a
// zero-FP compromise signal: the attacker literally failed N times from that IP
// for that mailbox, then guessed the password.
//
// "Recent" means the IP's last recorded failure for this account is still
// inside t.window. Entries in the per-account IP map are only pruned by
// Record and Purge, so the timestamp is checked here as well; otherwise a
// failure that has already aged out of the window, but has not been pruned
// yet, would still be reported as a compromise.
//
// ip and account MUST both be non-empty (empty values return nil). Caller
// filters infra/private/loopback IPs before invoking.
//
// At most one finding is returned, and further compromise findings for the
// same account are suppressed for t.suppression.
func (t *mailAuthTracker) RecordSuccess(ip, account string) []alert.Finding {
	if ip == "" || account == "" {
		return nil
	}
	t.mu.Lock()
	defer t.mu.Unlock()

	a, ok := t.accounts[account]
	if !ok {
		return nil
	}
	failedAt, failed := a.ips[ip]
	if !failed {
		return nil
	}

	now := t.now()
	if failedAt.Before(now.Add(-t.window)) {
		// Same expiry rule Record and Purge apply: the failure is outside
		// the window, so it no longer links this success to an attack.
		delete(a.ips, ip)
		return nil
	}
	if now.Before(a.compromiseSuppressed) {
		return nil
	}
	a.compromiseSuppressed = now.Add(t.suppression)

	_, compDomain := alert.SplitEmail(account)
	return []alert.Finding{{
		Severity: alert.Critical,
		Check:    "mail_account_compromised",
		Message: fmt.Sprintf("Mail account compromise: successful login for %s from %s after recent auth failures",
			account, ip),
		Details:   "Attacker succeeded after one or more failed attempts from the same IP for this mailbox. Rotate password and revoke sessions.",
		Timestamp: now,
		SourceIP:  ip,
		Domain:    compDomain,
		Mailbox:   account,
	}}
}
// Purge trims expired data from all three tables. Inside each entry,
// timestamps older than the detection window are dropped. An entry itself is
// deleted once it holds no in-window data and its lastSeen is not after
// now-(window+suppression), so suppression state outlives the window by the
// suppression period.
//
// Intended to be called periodically from a background goroutine.
func (t *mailAuthTracker) Purge() {
	t.mu.Lock()
	defer t.mu.Unlock()

	now := t.now()
	activityCutoff := now.Add(-(t.window + t.suppression))
	windowCutoff := now.Add(-t.window)

	// idle reports whether an entry has been quiet long enough to drop.
	idle := func(last time.Time) bool { return !last.After(activityCutoff) }
	// expire removes per-IP timestamps that fell out of the window.
	expire := func(seen map[string]time.Time) {
		for addr, at := range seen {
			if at.Before(windowCutoff) {
				delete(seen, addr)
			}
		}
	}

	for addr, entry := range t.ips {
		entry.times = pruneTimes(entry.times, windowCutoff)
		if len(entry.times) == 0 && idle(entry.lastSeen) {
			delete(t.ips, addr)
		}
	}
	for prefix, entry := range t.subnets {
		expire(entry.ips)
		if len(entry.ips) == 0 && idle(entry.lastSeen) {
			delete(t.subnets, prefix)
		}
	}
	for mailbox, entry := range t.accounts {
		expire(entry.ips)
		if len(entry.ips) == 0 && idle(entry.lastSeen) {
			delete(t.accounts, mailbox)
		}
	}
}
// enforceMaxTracked bounds memory use. When the combined number of tracked
// IPs, subnets and accounts exceeds maxTracked, it evicts entries in
// least-recently-seen order (across all three tables) until the combined
// count is at or below 95% of maxTracked. Evicting below the cap in one batch
// means the next few inserts do not trigger another sort.
//
// Caller must hold t.mu.
func (t *mailAuthTracker) enforceMaxTracked() {
	tracked := func() int { return len(t.ips) + len(t.subnets) + len(t.accounts) }

	total := tracked()
	if total <= t.maxTracked {
		return
	}
	target := t.maxTracked * 95 / 100

	type victim struct {
		kind string // "ip" | "subnet" | "account"
		key  string
		seen time.Time
	}
	candidates := make([]victim, 0, total)
	for key, entry := range t.ips {
		candidates = append(candidates, victim{"ip", key, entry.lastSeen})
	}
	for key, entry := range t.subnets {
		candidates = append(candidates, victim{"subnet", key, entry.lastSeen})
	}
	for key, entry := range t.accounts {
		candidates = append(candidates, victim{"account", key, entry.lastSeen})
	}

	// Oldest lastSeen first.
	sort.Slice(candidates, func(a, b int) bool {
		return candidates[a].seen.Before(candidates[b].seen)
	})

	for _, v := range candidates {
		if tracked() <= target {
			break
		}
		switch v.kind {
		case "ip":
			delete(t.ips, v.key)
		case "subnet":
			delete(t.subnets, v.key)
		case "account":
			delete(t.accounts, v.key)
		}
	}
}
// isMailAuthLine reports whether line is a dovecot login-service log line:
// it must contain "dovecot:" and at least one of the imap-login, pop3-login
// or managesieve-login markers.
func isMailAuthLine(line string) bool {
	if !strings.Contains(line, "dovecot:") {
		return false
	}
	for _, marker := range [...]string{"imap-login:", "pop3-login:", "managesieve-login:"} {
		if strings.Contains(line, marker) {
			return true
		}
	}
	return false
}
// extractMailLoginEvent parses a dovecot login line into (ip, account,
// success). A line that is neither a success nor an auth failure yields
// ("", "", false).
//
// Real dovecot wire format (validated against production logs):
//
//	Success: "imap-login: Logged in: user=<alice@x.ro>, method=PLAIN, rip=..."
//	Failure: "imap-login: Login aborted: ... (auth failed, N attempts ...): user=<...>, method=..., rip=..."
//
// Success is recognised by "-login: Logged in" (NOT "Login:"; an earlier
// parser used that marker, skipped every successful login, and thereby broke
// RecordSuccess-based compromise detection). Failure is recognised by
// "(auth failed" including the opening paren, which separates real auth
// failures from other Login-aborted reasons such as "no auth attempts" or
// TLS handshake errors.
//
// The account comes from the currently installed AccountExtractor. The IP is
// the text after the first "rip=" up to the next comma, space or newline (or
// end of line). Either may be empty when the field is missing.
func extractMailLoginEvent(line string) (ip, account string, success bool) {
	if strings.Contains(line, "-login: Logged in") {
		success = true
	} else if !strings.Contains(line, "(auth failed") {
		return "", "", false
	}

	account = currentAccountExtractor().Extract(line)

	const ripKey = "rip="
	if at := strings.Index(line, ripKey); at >= 0 {
		tail := line[at+len(ripKey):]
		if stop := strings.IndexAny(tail, ", \n"); stop >= 0 {
			ip = tail[:stop]
		} else {
			ip = tail
		}
	}
	return ip, account, success
}
package daemon
import (
"fmt"
"regexp"
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/config"
)
// AccountExtractor pulls the account/mailbox identifier out of a mail
// server log line. Used by mailbrute for per-account scoring. Selected
// by cfg.Thresholds.MailBruteAccountKey at daemon startup and SIGHUP
// reload.
type AccountExtractor struct {
// mode selects the extraction strategy: "dovecot-user", "postfix-sasl",
// or "regex" (see NewAccountExtractor and Extract).
mode string
// re is the compiled pattern for mode "regex"; its first capture group is
// the account. Left nil for the builtin modes.
re *regexp.Regexp
}
// NewAccountExtractor builds an extractor from the spec string taken from
// cfg.Thresholds.MailBruteAccountKey.
//
// Accepted specs:
//   - "" or "builtin:dovecot-user": the legacy dovecot user=<...> extractor
//   - "builtin:postfix-sasl": postfix sasl_username=... extractor
//   - "regex:<pattern>": a custom pattern with at least one capture group
//
// An error is returned for a regex that fails to compile, a regex without a
// capture group, or any other spec.
func NewAccountExtractor(spec string) (*AccountExtractor, error) {
	const regexPrefix = "regex:"

	switch spec {
	case "", "builtin:dovecot-user":
		return &AccountExtractor{mode: "dovecot-user"}, nil
	case "builtin:postfix-sasl":
		return &AccountExtractor{mode: "postfix-sasl"}, nil
	}

	if !strings.HasPrefix(spec, regexPrefix) {
		return nil, fmt.Errorf("unknown extractor spec: %s", spec)
	}

	pattern, compileErr := regexp.Compile(strings.TrimPrefix(spec, regexPrefix))
	if compileErr != nil {
		return nil, fmt.Errorf("invalid regex: %w", compileErr)
	}
	if pattern.NumSubexp() < 1 {
		return nil, fmt.Errorf("regex must contain at least one capture group")
	}
	return &AccountExtractor{mode: "regex", re: pattern}, nil
}
// Extract returns the account/mailbox key found in line, or "" when the
// line does not match. In regex mode the result is the first capture group.
// An unrecognised mode also yields "".
func (e *AccountExtractor) Extract(line string) string {
	if e.mode == "dovecot-user" {
		return extractAngleBracket(line, "user=")
	}
	if e.mode == "postfix-sasl" {
		return extractEqualsValue(line, "sasl_username=")
	}
	if e.mode != "regex" {
		return ""
	}
	groups := e.re.FindStringSubmatch(line)
	if len(groups) < 2 {
		return ""
	}
	return groups[1]
}
// extractAngleBracket returns the text between the first occurrence of
// `key<` and the next '>' after it. It returns "" when `key<` is absent or
// no closing '>' follows. Nested brackets are not balanced: the first '>'
// ends the value.
func extractAngleBracket(line, key string) string {
	opener := key + "<"
	start := strings.Index(line, opener)
	if start < 0 {
		return ""
	}
	value := line[start+len(opener):]
	closeAt := strings.IndexByte(value, '>')
	if closeAt < 0 {
		return ""
	}
	return value[:closeAt]
}
// extractEqualsValue returns the text that follows the first occurrence of
// key (which callers pass including its trailing '=', e.g. "sasl_username=")
// up to the next space, comma, tab or newline, or to the end of the line
// when no delimiter follows (postfix log format). It returns "" when key is
// not present.
func extractEqualsValue(line, key string) string {
	start := strings.Index(line, key)
	if start < 0 {
		return ""
	}
	value := line[start+len(key):]
	if stop := strings.IndexAny(value, " ,\t\n"); stop >= 0 {
		return value[:stop]
	}
	return value
}
// defaultAccountExtractor is the package-level singleton set at daemon
// startup and safe config reloads. It is written through SetAccountExtractor
// and read through currentAccountExtractor; a nil pointer means no extractor
// has been installed yet.
var defaultAccountExtractor atomic.Pointer[AccountExtractor]
// installAccountExtractorFromConfig builds an extractor from
// cfg.Thresholds.MailBruteAccountKey and installs it as the package-level
// extractor. When the spec is invalid, nothing is installed and the error is
// returned wrapped with the config key name.
func installAccountExtractorFromConfig(cfg *config.Config) error {
	extractor, buildErr := NewAccountExtractor(cfg.Thresholds.MailBruteAccountKey)
	if buildErr == nil {
		SetAccountExtractor(extractor)
		return nil
	}
	return fmt.Errorf("invalid mail_brute_account_key: %w", buildErr)
}
// SetAccountExtractor installs the configured extractor; called from
// Daemon.Run() and safe reload after applyDefaults has set the spec.
//
// The store is atomic, so it may run concurrently with readers. Storing nil
// clears the installed extractor; currentAccountExtractor then falls back to
// the default again.
func SetAccountExtractor(ex *AccountExtractor) {
defaultAccountExtractor.Store(ex)
}
// currentAccountExtractor returns the installed extractor, or lazily
// initializes the default so test code that doesn't call SetAccountExtractor
// gets the legacy dovecot-user behavior.
//
// The lazy default is published with CompareAndSwap(nil, ...) rather than an
// unconditional Store. With a plain Store, a reader that observed nil could
// overwrite an extractor installed by a concurrent SetAccountExtractor
// (startup or reload) and silently revert the configured account key to the
// default.
func currentAccountExtractor() *AccountExtractor {
	if ex := defaultAccountExtractor.Load(); ex != nil {
		return ex
	}

	// The empty spec selects the builtin dovecot-user mode, which
	// NewAccountExtractor returns without an error, so it is safe to ignore.
	fallback, _ := NewAccountExtractor("")
	if defaultAccountExtractor.CompareAndSwap(nil, fallback) {
		return fallback
	}

	// Someone else installed an extractor in the meantime; prefer theirs.
	if ex := defaultAccountExtractor.Load(); ex != nil {
		return ex
	}
	// The pointer was reset to nil again; still hand back a usable value.
	return fallback
}
package daemon
import (
"time"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/platform"
)
// modsecRegistryRefresh controls how often the rule-action registry is
// rebuilt from disk. ModSec rule files change rarely (vendor pack updates,
// cPanel modsec_assemble nightly run), so a coarse interval keeps the cost
// negligible while still picking up operator edits within minutes.
// It is the ticker period used by modsecRegistryRefreshLoop.
const modsecRegistryRefresh = 5 * time.Minute
// initModSecRegistry builds the rule-action registry once at startup,
// installs it as the package-level singleton, and then starts the periodic
// refresh goroutine. The registry tells the
// LiteSpeed log-line classifier which "triggered!" matches actually denied
// the request and which were pass-action informational rules. Without this,
// pass-action vendor rules (Comodo CWAF id 210710, 214930, ...) would be
// counted as denies, falsely escalating to a 24-hour auto-block of any IP
// that hits them three times in ten minutes.
//
// The build is failure-soft: missing rule directories yield an empty
// registry, which the classifier interprets the same way as the legacy
// behaviour (default to block on unknown rule). New installs and hosts
// without ModSecurity therefore see no behaviour change.
func (d *Daemon) initModSecRegistry() {
// Synchronous first build so the registry is set before this returns.
d.refreshModSecRegistry()
// Register with the daemon WaitGroup before the goroutine starts; the
// loop calls d.wg.Done when it exits.
d.wg.Add(1)
obs.Go("modsec-registry-refresh", d.modsecRegistryRefreshLoop)
}
// refreshModSecRegistry rebuilds the rule-action registry from the rule
// directories of the detected platform and installs the result as the global
// registry. A build error is logged as a warning but does not stop the
// install: whatever registry BuildRegistry returned is still set.
func (d *Daemon) refreshModSecRegistry() {
dirs := modsec.RuleDirs(platform.Detect())
reg, err := modsec.BuildRegistry(dirs)
if err != nil {
// NOTE(review): reg is used even on error, so BuildRegistry is presumably
// expected to return a non-nil (possibly partial) registry -- confirm.
csmlog.Warn("modsec rule-action registry build had errors", "err", err, "rules_loaded", reg.Len())
}
modsec.SetGlobal(reg)
csmlog.Info("modsec rule-action registry loaded", "rules", reg.Len(), "dirs", len(dirs))
}
// modsecRegistryRefreshLoop rebuilds the ModSecurity rule-action registry
// every modsecRegistryRefresh until d.stopCh is closed. It marks itself done
// on d.wg when it returns, matching the Add(1) made by initModSecRegistry.
func (d *Daemon) modsecRegistryRefreshLoop() {
	defer d.wg.Done()

	tick := time.NewTicker(modsecRegistryRefresh)
	defer tick.Stop()

	for {
		// A tick falls through to the refresh; only stopCh exits.
		select {
		case <-d.stopCh:
			return
		case <-tick.C:
		}
		d.refreshModSecRegistry()
	}
}
package daemon
import (
"bufio"
"fmt"
"net"
"os"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/obs"
)
// pamSocketPath is the Unix socket the PAM listener binds and removes on Stop.
const pamSocketPath = "/var/run/csm/pam.sock"
// Fallbacks used by pamThresholds when the config is nil or leaves the
// corresponding threshold at zero or below.
const (
// defaultPAMFailureThreshold is the per-IP failure count that triggers
// pam_bruteforce.
defaultPAMFailureThreshold = 5
// defaultPAMFailureWindowMin is the failure window, in minutes.
defaultPAMFailureWindowMin = 10
// defaultCredStuffingDistinctAccounts is the distinct-account count handed
// to the credential-stuffing detector.
defaultCredStuffingDistinctAccounts = 5
)
// PAMListener listens on a Unix socket for authentication events from the
// pam_csm.so PAM module. Tracks failures per IP and triggers CSM auto-blocking.
type PAMListener struct {
// cfg is the configuration the listener was constructed with; see
// currentCfg for when the active config is used instead.
cfg *config.Config
// alertCh receives every finding produced by the listener (see emit).
alertCh chan<- alert.Finding
// listener is the bound Unix socket.
listener net.Listener
// mu guards failures and stuffing.
mu sync.Mutex
// failures maps a source IP to its failure-tracking state.
failures map[string]*pamFailureTracker
// stopCh is set by Run so emit can abort a send when the daemon is
// shutting down. Per-connection goroutines are not tracked by a
// WaitGroup, so without this escape a goroutine blocked on the
// undrained alert channel would leak past shutdown. Nil for
// hand-constructed listeners in tests, where emit keeps blocking-send
// semantics.
stopCh <-chan struct{}
// useActiveConfig is true for the daemon-owned listener, so SIGHUP
// threshold changes apply without rebuilding the socket. Unit tests that
// assemble a listener by hand keep using their local cfg.
useActiveConfig bool
// stuffing flags one source IP failing against many distinct accounts
// (credential stuffing / password spraying breadth) -- complementary to
// the per-IP failure-count brute-force trigger below. Guarded by mu.
stuffing *credentialStuffingDetector
// startedAt anchors the upstream-probe verdict so the dashboard does
// not flag a freshly-started daemon as "deaf" before the PAM module
// has had a chance to emit anything. Compared against lastPeerNanos
// to decide whether silence is informative.
startedAt time.Time
// lastPeerNanos is an atomic UnixNano timestamp updated on every
// inbound connection (regardless of payload validity). Used by the
// upstream probe to detect a missing PAM module hook.
lastPeerNanos atomic.Int64
}
// pamFailureTracker is the per-IP failure state kept by PAMListener.
type pamFailureTracker struct {
// count is the number of failures recorded in the current window.
count int
// firstSeen marks the start of the current window.
firstSeen time.Time
// lastSeen is the most recent failure; cleanupLoop expires trackers by it.
lastSeen time.Time
// users and services are the sets of user names and PAM services seen in
// the current window.
users map[string]bool
services map[string]bool
// blocked is set once pam_bruteforce has been emitted for this window so
// the finding fires only once.
blocked bool
}
// NewPAMListener creates a Unix socket listener for PAM events.
//
// It ensures /var/run/csm exists (mode 0750), removes any stale socket file,
// binds pamSocketPath, and restricts the socket to mode 0600. The returned
// listener uses the active config (useActiveConfig) and carries a
// credential-stuffing detector sized from pamThresholds(cfg).
//
// An error is returned when the directory cannot be created or the socket
// cannot be bound. A failing chmod is not fatal, because connections are
// still vetted by isTrustedPAMPeer, but it is reported on stderr instead of
// being silently discarded so operators can see that the socket may have
// wider permissions than intended.
func NewPAMListener(cfg *config.Config, alertCh chan<- alert.Finding) (*PAMListener, error) {
	// Ensure socket directory exists
	if err := os.MkdirAll("/var/run/csm", 0750); err != nil {
		return nil, fmt.Errorf("creating socket dir: %w", err)
	}

	// Remove stale socket. The result is deliberately ignored: a missing
	// file is the normal case, and any other failure surfaces as a bind
	// error from net.Listen below.
	_ = os.Remove(pamSocketPath)

	listener, err := net.Listen("unix", pamSocketPath)
	if err != nil {
		return nil, fmt.Errorf("listening on %s: %w", pamSocketPath, err)
	}

	// The PAM module runs inside privileged auth stacks, so the socket can stay
	// root-only instead of accepting arbitrary local writers.
	if err := os.Chmod(pamSocketPath, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] PAM listener: chmod 0600 on %s failed: %v\n", ts(), pamSocketPath, err)
	}

	_, window, distinct := pamThresholds(cfg)
	return &PAMListener{
		cfg:             cfg,
		alertCh:         alertCh,
		listener:        listener,
		failures:        make(map[string]*pamFailureTracker),
		startedAt:       time.Now(),
		useActiveConfig: true,
		stuffing:        newCredentialStuffingDetector(distinct, window, nil),
	}, nil
}
// UpstreamResult returns a health.UpstreamResult describing whether the PAM
// module hook is feeding the socket. The probe is cheap (single atomic
// load + clock read) so it is safe to wire into the components API.
//
// Fresh verdict:
//   - At least one inbound connection within pamUpstreamFreshWindow.
//   - Or no connection yet and the daemon has been up for less than
//     pamUpstreamGracePeriod (so a freshly-started daemon is not flagged
//     before the first real auth happens).
//
// LastActivity is the most recent connection time, or the daemon start
// time when no connection has arrived. Reason explains a !Fresh verdict
// to operators.
//
// lastPeerNanos is loaded exactly once. Re-loading it for each decision
// could mix a "no connection yet" snapshot with a later non-zero value when
// a connection arrives mid-call, producing a not-fresh verdict with an
// epoch (1970) LastActivity.
func (p *PAMListener) UpstreamResult() health.UpstreamResult {
	const reason = "no PAM module hook feeding the socket; install pam_csm.so and add `session optional pam_csm.so` to the relevant /etc/pam.d/ files"

	nanos := p.lastPeerNanos.Load()
	if nanos == 0 {
		// Nothing has ever connected: only the startup grace period can
		// keep the verdict fresh.
		if time.Since(p.startedAt) < pamUpstreamGracePeriod {
			return health.UpstreamResult{Fresh: true, LastActivity: p.startedAt}
		}
		return health.UpstreamResult{Fresh: false, LastActivity: p.startedAt, Reason: reason}
	}

	last := time.Unix(0, nanos)
	if time.Since(last) < pamUpstreamFreshWindow {
		return health.UpstreamResult{Fresh: true, LastActivity: last}
	}
	return health.UpstreamResult{Fresh: false, LastActivity: last, Reason: reason}
}
// Timing constants for the PAM upstream probe (UpstreamResult).
const (
// pamUpstreamFreshWindow is how recently a peer connection must have
// arrived before the upstream is considered alive. Sized to comfortably
// span the longest realistic gap between auth events on a host that
// has the PAM module installed.
pamUpstreamFreshWindow = 24 * time.Hour
// pamUpstreamGracePeriod is the post-start window during which a
// silent socket is not yet flagged as deaf. It only applies while no
// connection has ever been seen.
pamUpstreamGracePeriod = 15 * time.Minute
)
// Run starts the listener's background work and blocks until stopCh is
// closed. It records stopCh on the listener (so emit can abandon sends during
// shutdown), starts the cleanup loop, and starts the accept loop, which hands
// each accepted connection to its own goroutine.
//
// An Accept error ends the accept loop only when stopCh is already closed;
// otherwise it is written to stderr and Accept is retried after 100ms.
func (p *PAMListener) Run(stopCh <-chan struct{}) {
	p.stopCh = stopCh

	// Expire old failure records in the background.
	obs.Go("pam-cleanup", func() { p.cleanupLoop(stopCh) })

	acceptLoop := func() {
		for {
			conn, acceptErr := p.listener.Accept()
			if acceptErr == nil {
				obs.SafeGo("pam-conn", func() { p.handleConnection(conn) })
				continue
			}

			// Non-blocking shutdown check: a closed stopCh means the error
			// is the expected result of the listener being closed.
			select {
			case <-stopCh:
				return
			default:
			}
			fmt.Fprintf(os.Stderr, "[%s] PAM listener accept error: %v\n", ts(), acceptErr)
			time.Sleep(100 * time.Millisecond)
		}
	}
	obs.Go("pam-accept", acceptLoop)

	<-stopCh
}
// Stop closes the listener and removes the socket file.
// Both steps are best-effort: the Close error is discarded and the result of
// os.Remove is not checked.
func (p *PAMListener) Stop() {
_ = p.listener.Close()
os.Remove(pamSocketPath)
}
// handleConnection serves one socket connection. It stamps lastPeerNanos,
// rejects peers that isTrustedPAMPeer does not accept, and otherwise feeds
// every line received within a one-second deadline to processEvent. The
// connection is always closed on return.
func (p *PAMListener) handleConnection(conn net.Conn) {
	defer func() { _ = conn.Close() }()

	// Stamp the connection before the trust check, so the upstream probe
	// can tell "PAM hook not installed" (no connections at all) apart from
	// "PAM hook present but rejected as untrusted" (connections arrive but
	// never deliver payload).
	p.lastPeerNanos.Store(time.Now().UnixNano())

	if trusted := isTrustedPAMPeer(conn); !trusted {
		return
	}

	_ = conn.SetDeadline(time.Now().Add(time.Second))
	for lines := bufio.NewScanner(conn); lines.Scan(); {
		p.processEvent(lines.Text())
	}
}
// processEvent handles a single PAM event line.
// Format:
//
//	FAIL ip=1.2.3.4 user=root service=sshd
//	OK ip=1.2.3.4 user=root service=sshd
//
// Lines without a key=value part are ignored, as are events with no ip,
// ip "-", any loopback address, or an address inside the configured infra
// ranges. FAIL events update the failure trackers and emit whatever findings
// that produces. OK events clear the IP's failure state and emit a pam_login
// finding. Other event types are dropped.
//
// Loopback is detected with net.IP.IsLoopback instead of comparing against
// the literal "127.0.0.1", so ::1 and the rest of 127.0.0.0/8 can no longer
// raise alerts or feed the brute-force trigger. Infra ranges come from
// currentCfg(), like the thresholds, so a reloaded config is honoured and a
// nil config no longer panics.
func (p *PAMListener) processEvent(line string) {
	parts := strings.SplitN(strings.TrimSpace(line), " ", 2)
	if len(parts) < 2 {
		return
	}
	eventType := parts[0]
	kvPart := parts[1]

	var ip, user, service string
	for _, kv := range strings.Fields(kvPart) {
		switch {
		case strings.HasPrefix(kv, "ip="):
			ip = strings.TrimPrefix(kv, "ip=")
		case strings.HasPrefix(kv, "user="):
			user = strings.TrimPrefix(kv, "user=")
		case strings.HasPrefix(kv, "service="):
			service = strings.TrimPrefix(kv, "service=")
		}
	}

	if ip == "" || ip == "-" {
		return
	}
	// Local logins are never attacker traffic, whichever loopback form the
	// PAM module reports.
	if parsed := net.ParseIP(ip); parsed != nil && parsed.IsLoopback() {
		return
	}

	// Skip infra IPs
	var infraNets []string
	if cfg := p.currentCfg(); cfg != nil {
		infraNets = cfg.InfraIPs
	}
	if isInfraIP(ip, infraNets) {
		return
	}

	switch eventType {
	case "FAIL":
		p.emit(p.recordFailure(ip, user, service))
	case "OK":
		p.clearFailures(ip)
		// Successful login from non-infra IP - informational alert
		p.emit([]alert.Finding{{
			Severity:  alert.High,
			Check:     "pam_login",
			Message:   fmt.Sprintf("Login success from non-infra IP: %s (user: %s, service: %s)", ip, user, service),
			Timestamp: time.Now(),
			SourceIP:  ip,
		}})
	}
}
// emit forwards findings to the alert channel. It must be called WITHOUT
// p.mu held: a stalled alert consumer blocks the send, and holding the lock
// across that send would wedge recordFailure, clearFailures, and the cleanup
// loop, letting failure trackers grow without bound.
//
// Findings are sent in order. If p.stopCh is closed while a send is
// pending, that finding and all remaining ones are dropped.
func (p *PAMListener) emit(findings []alert.Finding) {
for _, f := range findings {
select {
case p.alertCh <- f:
case <-p.stopCh:
// Shutting down and the dispatcher has stopped draining;
// drop the remaining findings rather than leak this
// goroutine. A nil stopCh (hand-constructed listener)
// makes this case unselectable, preserving blocking send.
return
}
}
}
// recordFailure updates per-IP failure state and returns any findings the
// update produced. Findings are returned rather than sent so the caller can
// emit them after releasing p.mu (see emit).
//
// Two independent signals can fire:
//   - credential_stuffing, when the stuffing detector reports that ip has
//     failed against enough distinct accounts;
//   - pam_bruteforce, once per window, when the failure count for ip reaches
//     the configured threshold inside the window.
//
// The user and service lists in both findings are sorted, so the Details
// text is deterministic instead of following Go's randomised map iteration
// order. That keeps repeated findings byte-identical for anything that
// compares or hashes them.
func (p *PAMListener) recordFailure(ip, user, service string) []alert.Finding {
	p.mu.Lock()
	defer p.mu.Unlock()

	var findings []alert.Finding
	cfg := p.currentCfg()
	threshold, window, distinct := pamThresholds(cfg)
	now := time.Now()

	tracker, exists := p.failures[ip]
	if !exists {
		tracker = &pamFailureTracker{
			firstSeen: now,
			users:     make(map[string]bool),
			services:  make(map[string]bool),
		}
		p.failures[ip] = tracker
	}

	tracker.count++
	tracker.lastSeen = now
	tracker.users[user] = true
	tracker.services[service] = true

	// Credential-stuffing breadth signal: one source IP failing against many
	// distinct accounts. Independent of the per-IP failure-count brute-force
	// trigger below, so a low-and-slow campaign that stays under the count
	// threshold per account is still caught. Fires once per window.
	p.ensureCredentialStuffingDetectorLocked(distinct, window, now)
	if accounts, fire := p.stuffing.Record(ip, user); fire {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "credential_stuffing",
			Message:  fmt.Sprintf("Credential stuffing: %s failed logins against %d distinct accounts", ip, len(accounts)),
			Details: fmt.Sprintf("Accounts targeted: %s\nService(s): %s",
				strings.Join(accounts, ", "), strings.Join(sortedBoolKeys(tracker.services), ", ")),
			Timestamp: now,
			SourceIP:  ip,
		})
	}

	// Only block if within the time window
	if now.Sub(tracker.firstSeen) > window {
		// Window expired - reset tracker so this failure starts a new window.
		tracker.count = 1
		tracker.firstSeen = now
		tracker.users = map[string]bool{user: true}
		tracker.services = map[string]bool{service: true}
		tracker.blocked = false
		return findings
	}

	if tracker.count >= threshold && !tracker.blocked {
		// Latch so the finding is emitted once per window.
		tracker.blocked = true

		users := sortedBoolKeys(tracker.users)
		services := sortedBoolKeys(tracker.services)

		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "pam_bruteforce",
			Message:  fmt.Sprintf("PAM brute-force detected: %s (%d failures in %ds)", ip, tracker.count, int(now.Sub(tracker.firstSeen).Seconds())),
			Details: fmt.Sprintf("Users targeted: %s\nServices: %s",
				strings.Join(users, ", "), strings.Join(services, ", ")),
			Timestamp: now,
			SourceIP:  ip,
		})
	}
	return findings
}
// clearFailures forgets all failure state for ip: its per-IP tracker and,
// when a credential-stuffing detector exists, that detector's record for the
// IP. It is called after a successful login.
//
// The detector is nil on hand-constructed listeners until the first
// recordFailure creates it, so it is guarded here the same way cleanupLoop
// guards it.
func (p *PAMListener) clearFailures(ip string) {
	p.mu.Lock()
	defer p.mu.Unlock()

	delete(p.failures, ip)
	if p.stuffing != nil {
		p.stuffing.Clear(ip)
	}
}
// cleanupLoop runs once a minute until stopCh is closed. Each pass, under
// p.mu, deletes failure trackers whose last failure is more than 30 minutes
// old and, when a credential-stuffing detector exists, re-applies the current
// thresholds to it and prunes its stale state.
func (p *PAMListener) cleanupLoop(stopCh <-chan struct{}) {
	tick := time.NewTicker(1 * time.Minute)
	defer tick.Stop()

	for {
		// A tick falls through to the sweep; only stopCh exits.
		select {
		case <-stopCh:
			return
		case <-tick.C:
		}

		p.mu.Lock()
		expiry := time.Now().Add(-30 * time.Minute)
		for addr, state := range p.failures {
			if state.lastSeen.Before(expiry) {
				delete(p.failures, addr)
			}
		}
		if detector := p.stuffing; detector != nil {
			at := time.Now()
			_, window, distinct := pamThresholds(p.currentCfg())
			detector.Configure(distinct, window, at)
			detector.PruneStale(at)
		}
		p.mu.Unlock()
	}
}
// currentCfg returns the configuration the listener should use right now:
// the process-wide active config when useActiveConfig is set and one is
// available, otherwise the config the listener was constructed with.
func (p *PAMListener) currentCfg() *config.Config {
	if !p.useActiveConfig {
		return p.cfg
	}
	if live := config.Active(); live != nil {
		return live
	}
	return p.cfg
}
// pamThresholds resolves the PAM detection settings from cfg.
//
// threshold comes from Thresholds.MultiIPLoginThreshold, window from
// Thresholds.MultiIPLoginWindowMin (minutes), and distinct from
// Thresholds.CredStuffingDistinctAccounts. A nil cfg, or any value that is
// not positive, falls back to the matching default constant.
func pamThresholds(cfg *config.Config) (threshold int, window time.Duration, distinct int) {
	// positiveOr keeps a configured value only when it is > 0.
	positiveOr := func(configured, fallback int) int {
		if configured > 0 {
			return configured
		}
		return fallback
	}

	threshold = defaultPAMFailureThreshold
	windowMin := defaultPAMFailureWindowMin
	distinct = defaultCredStuffingDistinctAccounts
	if cfg == nil {
		return threshold, time.Duration(windowMin) * time.Minute, distinct
	}

	threshold = positiveOr(cfg.Thresholds.MultiIPLoginThreshold, threshold)
	windowMin = positiveOr(cfg.Thresholds.MultiIPLoginWindowMin, windowMin)
	distinct = positiveOr(cfg.Thresholds.CredStuffingDistinctAccounts, distinct)
	return threshold, time.Duration(windowMin) * time.Minute, distinct
}
// ensureCredentialStuffingDetectorLocked makes sure p.stuffing exists and
// carries the given settings. An existing detector is reconfigured in place;
// a missing one is created with a nil third argument, and now is not used on
// that path. Caller must hold p.mu.
func (p *PAMListener) ensureCredentialStuffingDetectorLocked(distinct int, window time.Duration, now time.Time) {
	if p.stuffing != nil {
		p.stuffing.Configure(distinct, window, now)
		return
	}
	p.stuffing = newCredentialStuffingDetector(distinct, window, nil)
}
// sortedBoolKeys returns the keys of m in ascending order. Values are
// ignored, so keys mapped to false are included too. The result is never nil;
// an empty or nil map yields an empty slice.
func sortedBoolKeys(m map[string]bool) []string {
	keys := make([]string, len(m))
	next := 0
	for name := range m {
		keys[next] = name
		next++
	}
	sort.Strings(keys)
	return keys
}
// isInfraIP reports whether ip lies inside any of the CIDR ranges in
// infraNets. An ip that does not parse is never infra, and entries that are
// not valid CIDR notation are skipped.
// Duplicated here to avoid import cycle with checks package.
func isInfraIP(ip string, infraNets []string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	for i := range infraNets {
		if _, block, err := net.ParseCIDR(infraNets[i]); err == nil && block.Contains(addr) {
			return true
		}
	}
	return false
}
//go:build linux
package daemon
import (
"net"
"golang.org/x/sys/unix"
)
// isTrustedPAMPeer reports whether the process on the other end of conn is
// running as uid 0. It reads the peer credentials of the Unix socket with
// SO_PEERCRED. Any failure along the way (conn is not a *net.UnixConn, the
// raw connection cannot be obtained, the getsockopt call fails, or Control
// itself returns an error) yields false.
func isTrustedPAMPeer(conn net.Conn) bool {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return false
}
rawConn, err := unixConn.SyscallConn()
if err != nil {
return false
}
trusted := false
// Control runs the callback with the socket's file descriptor.
controlErr := rawConn.Control(func(fd uintptr) {
// #nosec G115 -- socket fd from net.Conn.SyscallConn; POSIX fd fits in int.
cred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
if err == nil && cred != nil && cred.Uid == 0 {
trusted = true
}
})
return controlErr == nil && trusted
}
package daemon
import (
"fmt"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// PasswordHijackDetector tracks WHM password changes from non-infra IPs
// and correlates them with subsequent cPanel logins to detect the attack
// pattern: attacker changes password via WHM, then immediately logs in.
//
// Legitimate flow (excluded):
//
//	Portal (infra IP) changes password via xml-api, then the user logs in
//
// Attack flow (detected):
//
//	Attacker (non-infra IP) changes password via whostmgr, then logs in
//	within hijackWindow (120s)
type PasswordHijackDetector struct {
// mu guards recentChanges.
mu sync.Mutex
recentChanges map[string]*passwordChange // account -> change info
// cfg supplies the infra IP ranges used to exclude legitimate sources.
cfg *config.Config
// alertCh receives the findings produced by the detector (see emit).
alertCh chan<- alert.Finding
stopCh <-chan struct{} // closed on daemon shutdown; lets sends escape
}
// passwordChange records one WHM password change from a non-infra IP, kept
// until a login for the same account is correlated against it.
type passwordChange struct {
// account is the account whose password was changed.
account string
// ip is the source IP of the change.
ip string
// timestamp is when the change was recorded (time.Now at handling time).
timestamp time.Time
}
// hijackWindow is the maximum time between a recorded password change and a
// login for the two to be correlated as a hijack.
const hijackWindow = 120 * time.Second // time window to correlate password change + login
// NewPasswordHijackDetector returns a detector with an empty change table.
// cfg supplies the infra IP ranges, alertCh receives findings, and stopCh is
// the shutdown signal that lets a pending finding send be abandoned.
func NewPasswordHijackDetector(cfg *config.Config, alertCh chan<- alert.Finding, stopCh <-chan struct{}) *PasswordHijackDetector {
	detector := new(PasswordHijackDetector)
	detector.recentChanges = map[string]*passwordChange{}
	detector.cfg = cfg
	detector.alertCh = alertCh
	detector.stopCh = stopCh
	return detector
}
// emit delivers a finding without ever blocking the (synchronous) session-log
// watcher goroutine forever. It applies backpressure while the daemon is
// running, but gives up the moment shutdown is signaled -- after the alert
// dispatcher stops draining, a plain send would otherwise wedge d.wg.Wait().
// Callers must NOT hold d.mu while calling this.
//
// When stopCh wins the select, the finding is dropped. A nil stopCh makes
// that case unselectable, so the send simply blocks until it is received.
func (d *PasswordHijackDetector) emit(f alert.Finding) {
select {
case d.alertCh <- f:
// Delivered to the alert pipeline.
case <-d.stopCh:
// Shutdown: abandon the send.
}
}
// HandlePasswordChange records a WHM password change for account made from
// ip and raises a Critical finding for it. Changes from infra IPs, from
// "127.0.0.1" or from the literal source "internal" are treated as
// legitimate and ignored.
func (d *PasswordHijackDetector) HandlePasswordChange(account, ip string) {
	switch {
	case isInfraIPDaemon(ip, d.cfg.InfraIPs), ip == "127.0.0.1", ip == "internal":
		return // legitimate - portal or admin action
	}

	d.mu.Lock()
	d.recentChanges[account] = &passwordChange{account: account, ip: ip, timestamp: time.Now()}
	d.mu.Unlock()

	// A non-infra WHM password change is always suspicious on its own. The
	// finding is emitted after the lock is released so a saturated alert
	// channel can never wedge the mutex.
	finding := alert.Finding{
		Severity:  alert.Critical,
		Check:     "whm_password_change_noninfra",
		Message:   fmt.Sprintf("WHM password change from non-infra IP: %s (account: %s)", ip, account),
		Details:   "Password was changed via WHM from an IP outside your infrastructure. This is a strong indicator of account takeover.",
		Timestamp: time.Now(),
		SourceIP:  ip,
		TenantID:  account,
	}
	d.emit(finding)
}
// HandleLogin checks whether a cPanel login for account from loginIP follows
// a recorded non-infra password change within hijackWindow and, if so, emits
// a Critical "password_hijack_confirmed" finding.
//
// Logins from infra IPs are ignored and leave any recorded change in place.
// Any other login consumes the recorded change for the account, whether or
// not it falls inside the window.
func (d *PasswordHijackDetector) HandleLogin(account, loginIP string) {
	if isInfraIPDaemon(loginIP, d.cfg.InfraIPs) {
		return
	}

	d.mu.Lock()
	change, exists := d.recentChanges[account]
	if exists {
		delete(d.recentChanges, account)
	}
	d.mu.Unlock()
	if !exists {
		return
	}

	// Measure the gap exactly once so the window check, the message, the
	// details and the finding timestamp all describe the same instant.
	now := time.Now()
	elapsed := now.Sub(change.timestamp)
	if elapsed > hijackWindow {
		return
	}
	secs := int(elapsed.Seconds())

	// CONFIRMED ATTACK: password changed from non-infra IP, login within 120s.
	// Emitted after the lock is released so a full alert channel cannot
	// wedge the mutex.
	d.emit(alert.Finding{
		Severity:  alert.Critical,
		Check:     "password_hijack_confirmed",
		Message:   fmt.Sprintf("CONFIRMED ACCOUNT HIJACK: %s - password changed from %s, login from %s within %ds", account, change.ip, loginIP, secs),
		Details:   fmt.Sprintf("Attack pattern: WHM password change from non-infra IP followed by immediate cPanel login.\nPassword change IP: %s\nLogin IP: %s\nTime between: %ds\n\nBoth IPs should be permanently blocked.", change.ip, loginIP, secs),
		Timestamp: now,
		SourceIP:  loginIP,
		TenantID:  account,
	})
}
// Cleanup evicts recorded password changes that are older than twice
// hijackWindow.
func (d *PasswordHijackDetector) Cleanup() {
	const maxAge = hijackWindow * 2

	d.mu.Lock()
	for acct, rec := range d.recentChanges {
		if time.Since(rec.timestamp) <= maxAge {
			continue
		}
		delete(d.recentChanges, acct)
	}
	d.mu.Unlock()
}
// ParseSessionLineForHijack inspects one session-log line and forwards any
// WHM password change or cPanel login it describes to detector.
//
// Recognised shapes:
//
//	[timestamp] info [whostmgr] IP PURGE account:token password_change
//	[timestamp] info [cpaneld] IP NEW account:token ...
//
// cpaneld lines carrying "method=create_user_session" (API sessions) are
// skipped.
func ParseSessionLineForHijack(line string, detector *PasswordHijackDetector) {
	has := func(sub string) bool { return strings.Contains(line, sub) }

	// WHM password change.
	if has("[whostmgr]") && has("PURGE") && has("password_change") {
		if ip, account := parseWHMPurge(line); ip != "" && account != "" {
			detector.HandlePasswordChange(account, ip)
		}
	}

	// cPanel login.
	if !has("[cpaneld]") || !has(" NEW ") {
		return
	}
	if has("method=create_user_session") {
		return
	}
	if ip, account := parseCpanelSessionLogin(line); ip != "" && account != "" {
		detector.HandleLogin(account, ip)
	}
}
// parseWHMPurge pulls the source IP and account name out of a whostmgr
// PURGE line of the form
//
//	[timestamp] info [whostmgr] 198.51.100.50 PURGE account:token password_change
//
// The IP is the first whitespace-separated field after "[whostmgr]"; the
// account is the part before ':' in the field following the first "PURGE"
// that is not the last field. Both results are empty when the marker is
// missing or fewer than three fields follow it; account alone is empty when
// no usable PURGE field is found.
func parseWHMPurge(line string) (ip, account string) {
	_, tail, found := strings.Cut(line, "[whostmgr]")
	if !found {
		return "", ""
	}
	tokens := strings.Fields(tail)
	if len(tokens) < 3 {
		return "", ""
	}
	ip = tokens[0]
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] != "PURGE" {
			continue
		}
		account, _, _ = strings.Cut(tokens[i+1], ":")
		break
	}
	return ip, account
}
package daemon
import (
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// phpEventsLogPath is the path of the PHP Shield event log.
const phpEventsLogPath = "/var/log/csm-php-shield/events.log"
// parsePHPShieldLogLine adapts parsePHPShieldLine to the log watcher handler
// signature: it returns a one-element slice when the line yields a finding
// and nil otherwise. The config argument is unused.
func parsePHPShieldLogLine(line string, _ *config.Config) []alert.Finding { //nolint:unparam
	if finding := parsePHPShieldLine(line); finding != nil {
		return []alert.Finding{*finding}
	}
	return nil
}
// parsePHPShieldLine turns one PHP Shield event-log line into a finding.
//
// Format: [2026-03-25 10:00:00] EVENT_TYPE ip=X script=Y uri=Z ua=A details=B
//
// It returns nil for blank lines, lines that do not start with '[', lines
// with nothing usable after the closing ']', and event types other than
// BLOCK_PATH, WEBSHELL_PARAM and EVAL_FATAL.
func parsePHPShieldLine(line string) *alert.Finding {
	trimmed := strings.TrimSpace(line)
	if !strings.HasPrefix(trimmed, "[") {
		return nil
	}
	end := strings.Index(trimmed, "]")
	if end < 0 || end+2 >= len(trimmed) {
		return nil
	}

	// The event type is the first word after the timestamp bracket; the
	// remainder (if any) holds the key=value pairs.
	eventType, kvPart, hasKV := strings.Cut(strings.TrimSpace(trimmed[end+1:]), " ")

	values := make(map[string]string, 5)
	if hasKV {
		for _, kv := range splitKV(kvPart) {
			values[kv[0]] = kv[1]
		}
	}
	ip, script, details := values["ip"], values["script"], values["details"]

	var finding alert.Finding
	switch eventType {
	case "BLOCK_PATH":
		finding.Severity = alert.Critical
		finding.Check = "php_shield_block"
		finding.Message = fmt.Sprintf("PHP Shield blocked execution from dangerous path: %s", script)
		finding.Details = fmt.Sprintf("IP: %s\n%s", ip, details)
	case "WEBSHELL_PARAM":
		finding.Severity = alert.Critical
		finding.Check = "php_shield_webshell"
		finding.Message = fmt.Sprintf("PHP Shield detected webshell command parameter: %s", script)
		finding.Details = fmt.Sprintf("IP: %s\n%s", ip, details)
	case "EVAL_FATAL":
		finding.Severity = alert.High
		finding.Check = "php_shield_eval"
		finding.Message = fmt.Sprintf("PHP Shield detected eval() chain failure: %s", script)
		finding.Details = details
	default:
		return nil
	}
	return &finding
}
// splitKV splits "key1=val1 key2=val2" into (key, value) pairs, allowing
// values that contain spaces. Only the keys ip, script, uri, ua and details
// are recognised, and results are returned in that order; keys that are
// absent are omitted.
//
// A key is only recognised at the start of s or directly after a space, so
// text such as "zip=1" inside a URI is not mistaken for the "ip" field. A
// value runs until the next recognised key that comes later in the fixed key
// order (matched as " key="), or to the end of s.
func splitKV(s string) [][2]string {
	keys := [...]string{"ip=", "script=", "uri=", "ua=", "details="}

	// keyAt returns the offset of the first occurrence of key that begins a
	// token (offset 0 or preceded by a space), or -1 if there is none.
	keyAt := func(key string) int {
		for from := 0; from < len(s); {
			rel := strings.Index(s[from:], key)
			if rel < 0 {
				return -1
			}
			at := from + rel
			if at == 0 || s[at-1] == ' ' {
				return at
			}
			from = at + 1
		}
		return -1
	}

	var result [][2]string
	for i, key := range keys {
		idx := keyAt(key)
		if idx < 0 {
			continue
		}
		valStart := idx + len(key)
		// Value ends at the next key or end of string.
		valEnd := len(s)
		for _, nextKey := range keys[i+1:] {
			if rel := strings.Index(s[valStart:], " "+nextKey); rel >= 0 {
				valEnd = valStart + rel
				break
			}
		}
		result = append(result, [2]string{
			strings.TrimSuffix(key, "="),
			strings.TrimSpace(s[valStart:valEnd]),
		})
	}
	return result
}
package daemon
import (
"bufio"
"bytes"
"encoding/gob"
"fmt"
"os"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailspool"
"github.com/pidginhost/csm/internal/store"
)
// phpRelayEvaluatorRef holds the package-wide evaluator pointer that
// SetPHPRelayEvaluator stores and PHPRelayEvaluator loads.
var phpRelayEvaluatorRef atomic.Pointer[evaluator]
// SetPHPRelayEvaluator is called by daemon wiring once everything is up.
// It atomically stores e in phpRelayEvaluatorRef; passing nil disables
// php_relay path dispatch from parseEximLogLine.
func SetPHPRelayEvaluator(e *evaluator) {
phpRelayEvaluatorRef.Store(e)
}
// PHPRelayEvaluator returns the evaluator last stored with
// SetPHPRelayEvaluator, or nil if none has been stored.
func PHPRelayEvaluator() *evaluator {
return phpRelayEvaluatorRef.Load()
}
// scriptKey identifies one PHP script. It is a plain string alias of the form
// host(X-PHP-Script) + ":" + path(X-PHP-Script), as built by parseXPHPScript.
type scriptKey = string
// scriptEvent is one accepted outbound mail from a PHP script. The booleans
// are computed at acceptance time (Flow A); evaluatePaths counts them in
// the sliding window.
type scriptEvent struct {
At time.Time // event time; compared against the window start when counting
MsgID string
Subject string // used as the sample subject in relay breakdowns
FromMismatch bool
AdditionalSignal bool // Reply-To external mismatch OR X-Mailer suspicious-not-safe
SourceIP string // matched against the source IP in the fanout breakdown
}
// rejectionEvent records a remote-MTA policy-block rejection for Path 3.
// Stage 1 defines the type for completeness; Stage 2 wires it.
//
//nolint:unused // wired by Path 3 in Stage 2
type rejectionEvent struct {
At time.Time // when the rejection was recorded
MsgID string
MTACode string
Snippet string
}
// Per-script bounds. newScriptState copies the three caps into each
// scriptState.
const (
phpRelayMaxEventsPerScript = 256 // cap on scriptState.events
phpRelayMaxRejectionsPerScript = 64 // cap on scriptState.rejections
phpRelayMaxActiveMsgsPerScript = 4096 // cap on scriptState.activeMsgs
phpRelayScriptIdleHorizon = 25 * time.Hour //nolint:unused // consumed by Flow E in Task O2
)
// scriptState tracks one script's recent activity. The methods in this file
// take mu before reading or writing any of the other fields.
type scriptState struct {
mu sync.Mutex
events []scriptEvent // recent events, oldest first; bounded by maxEvents
rejections []rejectionEvent //nolint:unused // wired by Path 3 in Stage 2
firedAt map[string]time.Time // path label -> last time shouldFire allowed a fire
activeMsgs map[string]time.Time // msgID -> acceptedAt
activeMsgsCapped bool // set once recordActive hits maxActiveMsgs
lastEvent time.Time // latest At seen by append; read by SweepIdle
maxEvents int
maxRejections int //nolint:unused // wired by Path 3 in Stage 2
maxActiveMsgs int
}
// newScriptState returns an empty scriptState with its maps allocated and its
// caps set from the phpRelayMax* constants.
func newScriptState() *scriptState {
	st := new(scriptState)
	st.maxEvents = phpRelayMaxEventsPerScript
	st.maxRejections = phpRelayMaxRejectionsPerScript
	st.maxActiveMsgs = phpRelayMaxActiveMsgsPerScript
	st.firedAt = make(map[string]time.Time, 4)
	st.activeMsgs = make(map[string]time.Time, 64)
	return st
}
// append records e at the end of the event list, dropping the oldest entry
// first when the list already holds maxEvents, and advances lastEvent when
// e.At is later.
func (s *scriptState) append(e scriptEvent) {
	s.mu.Lock()
	defer s.mu.Unlock()

	kept := s.events
	if len(kept) >= s.maxEvents {
		kept = kept[1:]
	}
	s.events = append(kept, e)

	if s.lastEvent.Before(e.At) {
		s.lastEvent = e.At
	}
}
// qualifyingCount counts the stored events that are not older than since and
// that match accepts. match is only called for events inside the window.
func (s *scriptState) qualifyingCount(since time.Time, match func(scriptEvent) bool) int {
	s.mu.Lock()
	defer s.mu.Unlock()

	total := 0
	for i := range s.events {
		ev := s.events[i]
		if !ev.At.Before(since) && match(ev) {
			total++
		}
	}
	return total
}
// volumeCount counts every stored event at or after since, ignoring the
// signal flags.
//
//nolint:unused // consumed by Path 2 evaluator in Task F2
func (s *scriptState) volumeCount(since time.Time) int {
	acceptAll := func(scriptEvent) bool { return true }
	return s.qualifyingCount(since, acceptAll)
}
// relayHit summarises the events at or after since that match accepts into a
// single RelayScriptHit for script k: the hit count, the latest event time,
// and the (truncated) subject of the latest matching event that has a
// non-empty subject. The second result is false, with a zero hit, when no
// event matched.
func (s *scriptState) relayHit(k scriptKey, since time.Time, match func(scriptEvent) bool) (alert.RelayScriptHit, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	summary := alert.RelayScriptHit{ScriptKey: string(k)}
	var subjectAt time.Time
	for i := range s.events {
		ev := &s.events[i]
		if ev.At.Before(since) {
			continue
		}
		if !match(*ev) {
			continue
		}
		summary.Hits++
		if summary.LastSeen.Before(ev.At) {
			summary.LastSeen = ev.At
		}
		if ev.Subject == "" {
			continue
		}
		if subjectAt.IsZero() || subjectAt.Before(ev.At) {
			summary.SampleSubject = truncateDaemon(ev.Subject, phpRelayBreakdownSubjectMax)
			subjectAt = ev.At
		}
	}

	if summary.Hits == 0 {
		return alert.RelayScriptHit{}, false
	}
	return summary, true
}
// recordActive stores msgID with its acceptance time. When msgID is new and
// the table is already at maxActiveMsgs, the entry with the earliest time is
// evicted first (an entry whose id is the empty string is never evicted) and
// activeMsgsCapped is set.
func (s *scriptState) recordActive(msgID string, at time.Time) {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, known := s.activeMsgs[msgID]
	if !known && len(s.activeMsgs) >= s.maxActiveMsgs {
		// Find the oldest entry to make room.
		var (
			victim   string
			victimAt time.Time
			picked   bool
		)
		for id, acceptedAt := range s.activeMsgs {
			if picked && !acceptedAt.Before(victimAt) {
				continue
			}
			victim, victimAt, picked = id, acceptedAt, true
		}
		if victim != "" {
			delete(s.activeMsgs, victim)
		}
		s.activeMsgsCapped = true
	}
	s.activeMsgs[msgID] = at
}
// removeActive forgets msgID; removing an id that is not tracked is a no-op.
func (s *scriptState) removeActive(msgID string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.activeMsgs, msgID)
}
// snapshotActiveMsgs returns the currently tracked message ids (in map
// iteration order) together with the capped flag. The slice is a fresh copy,
// so callers may modify it without touching the scriptState.
func (s *scriptState) snapshotActiveMsgs() ([]string, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	ids := make([]string, len(s.activeMsgs))
	next := 0
	for id := range s.activeMsgs {
		ids[next] = id
		next++
	}
	return ids, s.activeMsgsCapped
}
// pruneActiveMsgsOlderThan deletes every activeMsgs entry accepted before
// cutoff and returns how many were deleted. Flow E's GC calls it to bound
// the lifetime of ids that were never reaped.
func (s *scriptState) pruneActiveMsgsOlderThan(cutoff time.Time) int {
	s.mu.Lock()
	defer s.mu.Unlock()

	before := len(s.activeMsgs)
	for id, acceptedAt := range s.activeMsgs {
		if !acceptedAt.Before(cutoff) {
			continue
		}
		delete(s.activeMsgs, id)
	}
	return before - len(s.activeMsgs)
}
// shouldFire reports whether path may fire at now: it is allowed when the
// path has never fired or at least cooldown has passed since it last did.
// An allowed fire is recorded by setting firedAt[path] to now.
//
//nolint:unused // consumed by Path 1/2/3 evaluators in Tasks F1/F2/G1
func (s *scriptState) shouldFire(path string, now time.Time, cooldown time.Duration) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	if last, fired := s.firedAt[path]; !fired || now.Sub(last) >= cooldown {
		s.firedAt[path] = now
		return true
	}
	return false
}
// perScriptWindow keeps a scriptState per scriptKey. Entries are created on
// demand by getOrCreate and removed by SweepIdle.
type perScriptWindow struct {
states sync.Map // map[scriptKey]*scriptState
}
// newPerScriptWindow returns an empty per-script window.
func newPerScriptWindow() *perScriptWindow {
	return new(perScriptWindow)
}
// getOrCreate returns the scriptState for k, creating and registering a
// fresh one if none exists. A plain Load is tried first so the common case
// does not allocate; LoadOrStore resolves concurrent creation.
func (w *perScriptWindow) getOrCreate(k scriptKey) *scriptState {
	v, ok := w.states.Load(k)
	if !ok {
		v, _ = w.states.LoadOrStore(k, newScriptState())
	}
	return v.(*scriptState)
}
// SweepIdle removes every scriptState whose lastEvent is before cutoff and
// returns the number removed. Called by Flow E.
//
//nolint:unused // consumed by Flow E GC in Task O2
func (w *perScriptWindow) SweepIdle(cutoff time.Time) int {
	dropped := 0
	w.states.Range(func(key, val any) bool {
		st := val.(*scriptState)
		st.mu.Lock()
		last := st.lastEvent
		st.mu.Unlock()
		if last.Before(cutoff) {
			w.states.Delete(key)
			dropped++
		}
		return true
	})
	return dropped
}
// PruneActiveMsgs walks every retained scriptState and prunes activeMsgs
// entries older than cutoff, returning the total removed across all scripts.
// Flow E uses it so still-active scripts don't accumulate ghost activeMsgs
// whose messages left the queue without a "Completed" log line being parsed.
func (w *perScriptWindow) PruneActiveMsgs(cutoff time.Time) int {
	var removed int
	visit := func(_, val any) bool {
		st := val.(*scriptState)
		removed += st.pruneActiveMsgsOlderThan(cutoff)
		return true
	}
	w.states.Range(visit)
	return removed
}
// Snapshot copies the current key -> state associations into a new map (for
// csm phprelay status). The scriptState values are the live pointers, not
// copies.
//
//nolint:unused // consumed by csm phprelay status command in Task M1
func (w *perScriptWindow) Snapshot() map[scriptKey]*scriptState {
	snap := map[scriptKey]*scriptState{}
	w.states.Range(func(key, val any) bool {
		name, st := key.(scriptKey), val.(*scriptState)
		snap[name] = st
		return true
	})
	return snap
}
// ipState tracks which scripts one source IP has been seen with. scripts and
// lastEvent are accessed under mu by the perIPWindow methods.
type ipState struct {
mu sync.Mutex
scripts map[scriptKey]*ipScriptState // per-script record for this IP
lastEvent time.Time // latest event time for this IP; read by SweepIdle
}
// ipScriptState is the per-(IP, script) record kept inside ipState.
type ipScriptState struct {
lastSeen time.Time // latest event time for this script from this IP
sampleAt time.Time // event time the stored sample subject belongs to
sampleSubject string // truncated subject of the latest event that had one
}
// perIPWindow keeps an ipState per source IP, with at most capPerIP scripts
// tracked for each IP.
type perIPWindow struct {
states sync.Map // map[string]*ipState
capPerIP int
}
// newPerIPWindow returns an empty per-IP window that tracks at most capPerIP
// scripts per IP. A non-positive capPerIP selects the default of 64.
func newPerIPWindow(capPerIP int) *perIPWindow {
	const fallbackCap = 64
	w := &perIPWindow{capPerIP: fallbackCap}
	if capPerIP > 0 {
		w.capPerIP = capPerIP
	}
	return w
}
// append notes that source ip triggered script k at time at. An empty ip is
// ignored. Only the first element of subject, if any, is used: it is
// truncated and kept as the script's sample subject when it is non-empty and
// newer than the stored sample.
//
// When k is new for this IP and the IP already tracks capPerIP scripts, the
// script with the earliest lastSeen is evicted to make room.
func (w *perIPWindow) append(ip string, k scriptKey, at time.Time, subject ...string) {
	if ip == "" {
		return
	}
	loaded, _ := w.states.LoadOrStore(ip, &ipState{scripts: make(map[scriptKey]*ipScriptState, 8)})
	st := loaded.(*ipState)
	st.mu.Lock()
	defer st.mu.Unlock()

	entry, tracked := st.scripts[k]
	if !tracked && len(st.scripts) >= w.capPerIP {
		var (
			victim     scriptKey
			victimSeen time.Time
			picked     bool
		)
		for key, cand := range st.scripts {
			if !picked || cand.lastSeen.Before(victimSeen) {
				victim, victimSeen, picked = key, cand.lastSeen, true
			}
		}
		delete(st.scripts, victim)
	}
	if entry == nil {
		entry = new(ipScriptState)
		st.scripts[k] = entry
	}

	var sample string
	if len(subject) != 0 {
		sample = truncateDaemon(subject[0], phpRelayBreakdownSubjectMax)
	}

	if entry.lastSeen.Before(at) {
		entry.lastSeen = at
	}
	if sample != "" && (entry.sampleAt.IsZero() || entry.sampleAt.Before(at)) {
		entry.sampleAt = at
		entry.sampleSubject = sample
	}
	if st.lastEvent.Before(at) {
		st.lastEvent = at
	}
}
// distinctScriptsSince returns how many scripts ip has been seen with at or
// after since. An IP with no recorded state yields 0.
func (w *perIPWindow) distinctScriptsSince(ip string, since time.Time) int {
	loaded, found := w.states.Load(ip)
	if !found {
		return 0
	}
	st := loaded.(*ipState)
	st.mu.Lock()
	defer st.mu.Unlock()

	count := 0
	for _, entry := range st.scripts {
		if entry.lastSeen.Before(since) {
			continue
		}
		count++
	}
	return count
}
// relaySamplesSince returns one RelayScriptHit (with Hits fixed at 1) for
// every script ip was seen with at or after since. The sample subject is
// included only when it was itself recorded inside the window. The result is
// nil for an unknown IP and an empty, non-nil slice when the IP is known but
// nothing falls inside the window.
func (w *perIPWindow) relaySamplesSince(ip string, since time.Time) []alert.RelayScriptHit {
	loaded, found := w.states.Load(ip)
	if !found {
		return nil
	}
	st := loaded.(*ipState)
	st.mu.Lock()
	defer st.mu.Unlock()

	hits := make([]alert.RelayScriptHit, 0, len(st.scripts))
	for key, entry := range st.scripts {
		if entry.lastSeen.Before(since) {
			continue
		}
		hit := alert.RelayScriptHit{
			ScriptKey: string(key),
			Hits:      1,
			LastSeen:  entry.lastSeen,
		}
		if !entry.sampleAt.Before(since) {
			hit.SampleSubject = entry.sampleSubject
		}
		hits = append(hits, hit)
	}
	return hits
}
// SweepIdle removes every ipState whose lastEvent is before cutoff and
// returns the number removed.
func (w *perIPWindow) SweepIdle(cutoff time.Time) int {
	dropped := 0
	w.states.Range(func(key, val any) bool {
		st := val.(*ipState)
		st.mu.Lock()
		last := st.lastEvent
		st.mu.Unlock()
		if last.Before(cutoff) {
			w.states.Delete(key)
			dropped++
		}
		return true
	})
	return dropped
}
// accountState tracks one account's recent outbound-mail timestamps. The
// perAccountWindow methods access its fields under mu.
type accountState struct {
mu sync.Mutex
events []time.Time // event times, oldest first; bounded by maxEvents
firedAt time.Time // last time shouldFire allowed a fire; zero if never
lastEvent time.Time // latest event time; read by SweepIdle
maxEvents int
}
// perAccountWindow keeps an accountState per user name.
type perAccountWindow struct {
states sync.Map // user name -> *accountState
cap int // per-account event cap handed to new accountStates
}
// newPerAccountWindow returns an empty per-account window that keeps at most
// capPerAccount timestamps per account. A non-positive capPerAccount selects
// the default of 5000.
func newPerAccountWindow(capPerAccount int) *perAccountWindow {
	const fallbackCap = 5000
	w := &perAccountWindow{cap: fallbackCap}
	if capPerAccount > 0 {
		w.cap = capPerAccount
	}
	return w
}
// append records an event for user at time at, creating the account's state
// on first use. The oldest timestamp is dropped first when the account is
// already at its cap. An empty user is ignored.
func (w *perAccountWindow) append(user string, at time.Time) {
	if user == "" {
		return
	}
	loaded, _ := w.states.LoadOrStore(user, &accountState{maxEvents: w.cap})
	st := loaded.(*accountState)
	st.mu.Lock()
	defer st.mu.Unlock()

	kept := st.events
	if len(kept) >= st.maxEvents {
		kept = kept[1:]
	}
	st.events = append(kept, at)

	if st.lastEvent.Before(at) {
		st.lastEvent = at
	}
}
// volumeSince returns how many recorded events for user happened at or after
// since. An unknown user yields 0.
func (w *perAccountWindow) volumeSince(user string, since time.Time) int {
	loaded, found := w.states.Load(user)
	if !found {
		return 0
	}
	st := loaded.(*accountState)
	st.mu.Lock()
	defer st.mu.Unlock()

	count := 0
	for i := range st.events {
		if st.events[i].Before(since) {
			continue
		}
		count++
	}
	return count
}
// shouldFire reports whether a finding may fire for user at now. It is false
// for an unknown user and while the previous fire is less than cooldown old;
// otherwise the fire time is recorded as now and true is returned.
func (w *perAccountWindow) shouldFire(user string, now time.Time, cooldown time.Duration) bool {
	loaded, found := w.states.Load(user)
	if !found {
		return false
	}
	st := loaded.(*accountState)
	st.mu.Lock()
	defer st.mu.Unlock()

	if st.firedAt.IsZero() || now.Sub(st.firedAt) >= cooldown {
		st.firedAt = now
		return true
	}
	return false
}
// SweepIdle removes every accountState whose lastEvent is before cutoff and
// returns the number removed.
func (w *perAccountWindow) SweepIdle(cutoff time.Time) int {
	dropped := 0
	w.states.Range(func(key, val any) bool {
		st := val.(*accountState)
		st.mu.Lock()
		last := st.lastEvent
		st.mu.Unlock()
		if last.Before(cutoff) {
			w.states.Delete(key)
			dropped++
		}
		return true
	})
	return dropped
}
// signals is the per-event boolean fingerprint Flow A appends to
// scriptState. The numeric scriptKey lookup happens against the same
// emailspool helpers used elsewhere so subdomain handling is consistent.
// All fields are filled in by computeSignals.
type signals struct {
ScriptKey scriptKey // parsed from the X-PHP-Script header
SourceIP string // parsed from the X-PHP-Script header
FromMismatch bool // From domain is not one of the authorised domains
AdditionalSignal bool // Reply-To domain differs from From, or X-Mailer is suspicious and not safe
XMailer string // raw X-Mailer header value
}
// computeSignals resolves the per-event flags for an accepted message.
//
// authDomains is the cPanel user's authorised domain set. When it is empty
// (for example after a resolver error) the From-mismatch check is skipped
// and FromMismatch stays false. pol may be nil, in which case the X-Mailer
// check contributes nothing.
//
// AdditionalSignal is set when the Reply-To and From domains are both
// non-empty and differ, or when pol classifies the X-Mailer as suspicious
// and not safe.
func computeSignals(h emailspool.Headers, authDomains map[string]struct{}, pol *emailspool.Policies) signals {
	key, ip := parseXPHPScript(h.XPHPScript)
	out := signals{ScriptKey: key, SourceIP: ip, XMailer: h.XMailer}

	// From-mismatch contribution.
	if len(authDomains) != 0 {
		if d := emailspool.ExtractDomain(h.From); d != "" && !IsAuthorisedFromDomain(d, authDomains) {
			out.FromMismatch = true
		}
	}

	// Reply-To external mismatch contribution.
	replyMismatch := false
	if h.ReplyTo != "" && h.From != "" {
		replyDom := emailspool.ExtractDomain(h.ReplyTo)
		fromDom := emailspool.ExtractDomain(h.From)
		replyMismatch = replyDom != "" && fromDom != "" && replyDom != fromDom
	}

	// X-Mailer suspicious contribution.
	suspiciousMailer := pol != nil && pol.MailerSuspicious(h.XMailer) && !pol.MailerSafe(h.XMailer)

	out.AdditionalSignal = replyMismatch || suspiciousMailer
	return out
}
// parseXPHPScript splits an X-PHP-Script header value into (scriptKey, sourceIP).
// Format: "<host>/<path> for <ip>". The key is "<host>:/<path>" with any
// query string removed; a value without a path yields "<host>:/". The IP is
// empty when there is no " for " part. Returns ("", "") for a blank value.
func parseXPHPScript(v string) (scriptKey, string) {
	v = strings.TrimSpace(v)
	if v == "" {
		return "", ""
	}

	target, ip := v, ""
	const sep = " for "
	if at := strings.LastIndex(v, sep); at > 0 {
		target = strings.TrimSpace(v[:at])
		ip = strings.TrimSpace(v[at+len(sep):])
	}

	// Strip any query string.
	if q := strings.IndexByte(target, '?'); q > 0 {
		target = target[:q]
	}

	host, rest, hasPath := strings.Cut(target, "/")
	if !hasPath {
		// Bare host with no path.
		return scriptKey(target + ":/"), ip
	}
	return scriptKey(host + ":/" + rest), ip
}
const (
// phpRelayPathCooldown is the per-(script, path) cooldown evaluatePaths
// passes to shouldFire.
phpRelayPathCooldown = 30 * time.Minute
// phpRelayBreakdownSubjectMax is the length limit passed to truncateDaemon
// for sample subjects in relay breakdowns.
phpRelayBreakdownSubjectMax = 160
)
// evaluator combines windows + config + alerter in one object so the
// detector code path is callable from inotify watcher, retro scan, and
// startup spool walker without rebuilding the dependency graph each time.
type evaluator struct {
scripts *perScriptWindow // per-script event windows
ips *perIPWindow // per-source-IP script tracking (fanout path)
accounts *perAccountWindow // per-account volume tracking (Path 2b)
cfg *config.Config
metrics *phpRelayMetrics // optional; nil in unit tests
policies *emailspool.Policies // set via SetPolicies; may be nil
effectiveAccountLimit int // set via SetEffectiveAccountLimit; <= 0 disables Path 2b
}
// newEvaluator wires the three windows, the config and the (optional)
// metrics into an evaluator. Policies and the effective account limit are
// left unset; see SetPolicies and SetEffectiveAccountLimit.
func newEvaluator(s *perScriptWindow, i *perIPWindow, a *perAccountWindow, cfg *config.Config, m *phpRelayMetrics) *evaluator {
	ev := new(evaluator)
	ev.scripts, ev.ips, ev.accounts = s, i, a
	ev.cfg = cfg
	ev.metrics = m
	return ev
}
// SetPolicies is called by daemon wiring once the policies file has loaded.
// It replaces the evaluator's policies pointer with p (plain assignment, no
// locking).
func (e *evaluator) SetPolicies(p *emailspool.Policies) { e.policies = p }
// evaluatePaths inspects the window state of script k (and the IP window for
// sourceIP) and returns the findings that fire at now. It returns nil when
// PHP relay detection is disabled. Each path is subject to a
// phpRelayPathCooldown per (script, path), so repeated calls do not emit
// duplicates.
//
// Paths checked, in order:
//   - "header": qualifying events (From-mismatch AND additional signal) in
//     the configured rate window reach HeaderScoreVolumeMin.
//   - "volume": events in the last 60 minutes reach AbsoluteVolumePerHour.
//   - "fanout": sourceIP hit at least FanoutDistinctScripts distinct scripts
//     in the fanout window; skipped for an empty IP or a policy proxy IP.
func (e *evaluator) evaluatePaths(k scriptKey, sourceIP, cpuser string, now time.Time) []alert.Finding {
	if !e.cfg.EmailProtection.PHPRelay.Enabled {
		return nil
	}
	state := e.scripts.getOrCreate(k)
	var out []alert.Finding

	// record bumps the per-path findings counter (when metrics are wired)
	// and queues the finding for return.
	record := func(path string, f alert.Finding) {
		if e.metrics != nil {
			e.metrics.Findings.With(path).Inc()
		}
		out = append(out, f)
	}

	// Path 1: sustained qualifying events.
	headerWin := time.Duration(e.cfg.EmailProtection.PHPRelay.RateWindowMin) * time.Minute
	headerSince := now.Add(-headerWin)
	isQualifying := func(ev scriptEvent) bool { return ev.FromMismatch && ev.AdditionalSignal }
	qualifying := state.qualifyingCount(headerSince, isQualifying)
	if qualifying >= e.cfg.EmailProtection.PHPRelay.HeaderScoreVolumeMin &&
		state.shouldFire("header", now, phpRelayPathCooldown) {
		f := e.makeFinding(k, "header", sourceIP, cpuser, state, fmtHeaderMessage(qualifying, headerWin), now)
		f.RelayTotal = qualifying
		f.RelayBreakdown = e.scriptRelayBreakdown(k, headerSince, isQualifying)
		record("header", f)
	}

	// Path 2: absolute volume per script in the last 60 min.
	hourAgo := now.Add(-60 * time.Minute)
	absVol := state.volumeCount(hourAgo)
	if absVol >= e.cfg.EmailProtection.PHPRelay.AbsoluteVolumePerHour &&
		state.shouldFire("volume", now, phpRelayPathCooldown) {
		f := e.makeFinding(k, "volume", sourceIP, cpuser, state,
			fmt.Sprintf("Path 2: %d outbound mails from one script in last 60 min", absVol), now)
		f.RelayTotal = absVol
		f.RelayBreakdown = e.scriptRelayBreakdown(k, hourAgo, func(scriptEvent) bool { return true })
		record("volume", f)
	}

	// Path 4: HTTP-IP fanout. Skipped silently for proxy IPs.
	if sourceIP == "" {
		return out
	}
	if e.policies != nil && e.policies.IsProxyIP(sourceIP) {
		return out
	}
	fanoutWin := time.Duration(e.cfg.EmailProtection.PHPRelay.FanoutWindowMin) * time.Minute
	fanoutSince := now.Add(-fanoutWin)
	distinct := e.ips.distinctScriptsSince(sourceIP, fanoutSince)
	if distinct >= e.cfg.EmailProtection.PHPRelay.FanoutDistinctScripts &&
		state.shouldFire("fanout", now, phpRelayPathCooldown) {
		f := e.makeFinding(k, "fanout", sourceIP, cpuser, state,
			fmt.Sprintf("Path 4: HTTP source IP %s triggered %d distinct scripts in last %s", sourceIP, distinct, fanoutWin), now)
		f.RelayTotal = distinct
		f.RelayBreakdown = e.fanoutRelayBreakdown(sourceIP, fanoutSince)
		record("fanout", f)
	}
	return out
}
// makeFinding builds the Critical "email_php_relay_abuse" finding for path.
// It attaches up to ten of the script's currently active message ids as a
// sample so the alert payload stays bounded; AutoFreezePHPRelayQueue takes
// its own complete snapshot.
func (e *evaluator) makeFinding(k scriptKey, path, sourceIP, cpuser string, s *scriptState, message string, now time.Time) alert.Finding {
	const sampleCap = 10
	ids, _ := s.snapshotActiveMsgs()
	if len(ids) > sampleCap {
		ids = ids[:sampleCap]
	}

	var f alert.Finding
	f.Severity = alert.Critical
	f.Check = "email_php_relay_abuse"
	f.Path = path
	f.Message = message
	f.ScriptKey = string(k)
	f.SourceIP = sourceIP
	f.CPUser = cpuser
	f.MsgIDs = ids
	f.Timestamp = now
	return f
}
// scriptRelayBreakdown returns a one-element breakdown summarising the events
// of script k at or after since that match accepts, or nil when none match.
func (e *evaluator) scriptRelayBreakdown(k scriptKey, since time.Time, match func(scriptEvent) bool) []alert.RelayScriptHit {
	state := e.scripts.getOrCreate(k)
	if hit, ok := state.relayHit(k, since, match); ok {
		return []alert.RelayScriptHit{hit}
	}
	return nil
}
// fanoutRelayBreakdown lists the scripts sourceIP touched at or after since,
// sorted with sortRelayScriptHits. Each entry starts from the IP window's
// sample (Hits of 1) and is replaced by the script window's own count of
// events from sourceIP when that script has state and at least one such
// event. Returns nil when there is no IP window, no IP, or no sample.
func (e *evaluator) fanoutRelayBreakdown(sourceIP string, since time.Time) []alert.RelayScriptHit {
	if sourceIP == "" || e.ips == nil {
		return nil
	}
	samples := e.ips.relaySamplesSince(sourceIP, since)
	if len(samples) == 0 {
		return nil
	}

	fromThisIP := func(ev scriptEvent) bool { return ev.SourceIP == sourceIP }
	hits := make([]alert.RelayScriptHit, len(samples))
	for i, sample := range samples {
		hits[i] = sample
		if e.scripts == nil {
			continue
		}
		loaded, found := e.scripts.states.Load(scriptKey(sample.ScriptKey))
		if !found {
			continue
		}
		if counted, ok := loaded.(*scriptState).relayHit(scriptKey(sample.ScriptKey), since, fromThisIP); ok {
			hits[i] = counted
		}
	}
	sortRelayScriptHits(hits)
	return hits
}
// sortRelayScriptHits orders out in place: most hits first, then most
// recently seen first, then ascending script key as the final tie-break.
func sortRelayScriptHits(out []alert.RelayScriptHit) {
	less := func(i, j int) bool {
		a, b := &out[i], &out[j]
		switch {
		case a.Hits != b.Hits:
			return a.Hits > b.Hits
		case !a.LastSeen.Equal(b.LastSeen):
			return a.LastSeen.After(b.LastSeen)
		default:
			return a.ScriptKey < b.ScriptKey
		}
	}
	sort.Slice(out, less)
}
// fmtHeaderMessage renders the Path 1 finding message from the number of
// qualifying mails and the window they were counted over.
func fmtHeaderMessage(qualifying int, win time.Duration) string {
	const layout = "Path 1: %d qualifying outbound mails (From-mismatch AND suspicious header) in last %s"
	msg := fmt.Sprintf(layout, qualifying, win)
	return msg
}
// cpanelLimitStatus classifies the outcome of readCpanelHourlyLimit.
type cpanelLimitStatus int
const (
cpanelLimitOK cpanelLimitStatus = iota // key found and parsed as a usable number
cpanelLimitMissing // file could not be opened, or key not found
cpanelLimitUnparsable // key found but its value is not a number
cpanelLimitDisabled // key found with value 0
)
// readCpanelHourlyLimit returns (parsed-value, status). The key in
// /var/cpanel/cpanel.config is `maxemailsperhour` (no underscores, matches
// internal/checks/hardening_audit.go usage). Only the first line carrying
// that key is considered.
//
// OK -> integer > 0; the cap is in force.
// Missing -> file or key absent; caller assumes default 100 + Warning.
// Unparsable -> key present but not a non-negative integer; caller assumes default 100 + Warning.
// Disabled -> key present and == 0; cpanel hourly limit explicitly off.
//
// The returned value is 0 for every status other than OK. A negative value
// is reported as Unparsable rather than OK, so it can never be used as a
// limit by the caller.
func readCpanelHourlyLimit(path string) (int, cpanelLimitStatus) {
	// #nosec G304 -- path is the cPanel config path (default /var/cpanel/cpanel.config); operator-controlled, root-owned.
	f, err := os.Open(path)
	if err != nil {
		return 0, cpanelLimitMissing
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	for sc.Scan() {
		key, val, ok := strings.Cut(sc.Text(), "=")
		if !ok || strings.TrimSpace(key) != "maxemailsperhour" {
			continue
		}
		n, err := strconv.Atoi(strings.TrimSpace(val))
		switch {
		case err != nil, n < 0:
			return 0, cpanelLimitUnparsable
		case n == 0:
			return 0, cpanelLimitDisabled
		}
		return n, cpanelLimitOK
	}
	// Key not found (a read error while scanning also ends up here).
	return 0, cpanelLimitMissing
}
// deriveEffectiveAccountLimit implements spec section 6.1's three-step
// derivation. Returns (effective, enabled, cappedFromOperator).
//   - cpanelLimit/status come from readCpanelHourlyLimit.
//   - missing/unparsable callers should also emit a Warning startup finding.
//   - returned enabled=false means Path 2b should not run this session.
//
// With a known cPanel limit (the parsed value, or the cPanel default of 100
// when missing/unparsable) the result may not exceed 95% of it (minimum 1).
// An operator value of 0 selects 60% of the limit clamped to [20, 60]; any
// other operator value is used as given and capped if it exceeds the
// ceiling. When the cPanel limit is disabled, a positive operator value is
// used unchanged and anything else disables the path.
func deriveEffectiveAccountLimit(cfg *config.Config, cpanelLimit int, status cpanelLimitStatus) (effective int, enabled bool, capped bool) {
	operator := cfg.EmailProtection.PHPRelay.AccountVolumePerHour

	// Step 1: classify the cPanel limit.
	var assumed int
	switch status {
	case cpanelLimitOK:
		assumed = cpanelLimit
	case cpanelLimitMissing, cpanelLimitUnparsable:
		// Caller emits Warning; we use the cPanel default 100.
		assumed = 100
	default:
		// Cpanel limit explicitly disabled: only an operator value applies.
		if operator > 0 {
			return operator, true, false
		}
		return 0, false, false
	}

	// Step 2: derive effective.
	ceiling := assumed * 95 / 100
	if ceiling < 1 {
		ceiling = 1
	}

	limit, wasCapped := operator, false
	if operator == 0 {
		limit = assumed * 60 / 100
		switch {
		case limit < 20:
			limit = 20
		case limit > 60:
			limit = 60
		}
		if limit > ceiling {
			limit = ceiling
		}
	} else if operator > ceiling {
		limit, wasCapped = ceiling, true
	}

	if limit <= 0 {
		return 0, false, false
	}
	return limit, true, wasCapped
}
// Path 2b (per-account volume) timing, used by parsePHPRelayAccountVolume.
const (
phpRelayAccountWindowDur = 60 * time.Minute // look-back window for the volume count
phpRelayAccountFireCooldown = 30 * time.Minute // minimum gap between findings per account
)
// SetEffectiveAccountLimit is called by daemon wiring after derivation.
// Tests may also call it directly. A non-positive value disables Path 2b
// (parsePHPRelayAccountVolume returns nil while the limit is <= 0).
func (e *evaluator) SetEffectiveAccountLimit(n int) {
e.effectiveAccountLimit = n
}
// parsePHPRelayAccountVolume processes one outbound `<= ` exim_mainlog line.
// Lines are only counted when they also carry " B=redirect_resolver" and a
// non-empty U= field. The event is recorded for that account and, once the
// account's volume over the last phpRelayAccountWindowDur reaches the
// effective limit, a single "volume_account" finding is returned (subject to
// phpRelayAccountFireCooldown). Returns nil in every other case, including
// when Path 2b is disabled.
func (e *evaluator) parsePHPRelayAccountVolume(line string, now time.Time) []alert.Finding {
	limit := e.effectiveAccountLimit
	if limit <= 0 || e.accounts == nil {
		return nil
	}
	if !strings.Contains(line, " <= ") || !strings.Contains(line, " B=redirect_resolver") {
		return nil
	}
	user := extractUField(line)
	if user == "" {
		return nil
	}

	e.accounts.append(user, now)
	sent := e.accounts.volumeSince(user, now.Add(-phpRelayAccountWindowDur))
	if sent < limit || !e.accounts.shouldFire(user, now, phpRelayAccountFireCooldown) {
		return nil
	}

	if e.metrics != nil {
		e.metrics.Findings.With("volume_account").Inc()
	}
	finding := alert.Finding{
		Severity:   alert.Critical,
		Check:      "email_php_relay_abuse",
		Path:       "volume_account",
		Message:    fmt.Sprintf("Path 2b: account %s sent >= %d outbound mails in last hour", user, limit),
		CPUser:     user,
		RelayTotal: sent,
		Timestamp:  now,
	}
	return []alert.Finding{finding}
}
// extractUField returns the cpuser from the first " U=<name>" in an exim log
// line: everything after "U=" up to the next space or tab (or end of line).
// Returns "" if the field is absent.
func extractUField(line string) string {
	_, rest, found := strings.Cut(line, " U=")
	if !found {
		return ""
	}
	if end := strings.IndexAny(rest, " \t"); end >= 0 {
		return rest[:end]
	}
	return rest
}
// ignoreEntry records an operator-issued ignore on a script. A zero
// ExpiresAt means "never expires"; otherwise Has/List/SweepExpired drop
// the entry once now > ExpiresAt. Entries are gob-encoded for persistence.
type ignoreEntry struct {
ScriptKey string // script the ignore applies to; also the map key
AddedAt time.Time // when the entry was created
ExpiresAt time.Time // zero means no expiry
AddedBy string
Reason string
}
// ignoreList is the in-memory operator allowlist for php_relay scripts.
// L2 (bbolt persistence) wraps this with --persist semantics; Flow E
// (O2) calls SweepExpired periodically.
type ignoreList struct {
mu sync.Mutex // guards entries
entries map[string]ignoreEntry // script key -> entry
db *store.DB // optional persistence backend set via SetStore; nil means in-memory only
}
// newIgnoreList returns an empty in-memory ignore list with no store
// attached.
func newIgnoreList() *ignoreList {
	il := new(ignoreList)
	il.entries = map[string]ignoreEntry{}
	return il
}
// Add inserts or replaces the in-memory ignore for script k. AddedAt is set
// to the current time; a zero expiresAt means the entry never expires.
func (il *ignoreList) Add(k scriptKey, expiresAt time.Time, by, reason string) {
	key := string(k)

	il.mu.Lock()
	defer il.mu.Unlock()

	var entry ignoreEntry
	entry.ScriptKey = key
	entry.AddedAt = time.Now()
	entry.ExpiresAt = expiresAt
	entry.AddedBy = by
	entry.Reason = reason
	il.entries[key] = entry
}
// Remove deletes the in-memory ignore for script k, if any.
func (il *ignoreList) Remove(k scriptKey) {
	il.mu.Lock()
	defer il.mu.Unlock()
	delete(il.entries, string(k))
}
// Has reports whether script k currently has an unexpired ignore. An entry
// found to be expired is deleted as a side effect.
func (il *ignoreList) Has(k scriptKey) bool {
	key := string(k)

	il.mu.Lock()
	defer il.mu.Unlock()

	entry, found := il.entries[key]
	switch {
	case !found:
		return false
	case entry.ExpiresAt.IsZero(), !time.Now().After(entry.ExpiresAt):
		return true
	}
	delete(il.entries, key)
	return false
}
// List returns a copy of all unexpired entries (in map iteration order) and
// deletes any expired entries it encounters. The result is never nil.
func (il *ignoreList) List() []ignoreEntry {
	il.mu.Lock()
	defer il.mu.Unlock()

	live := make([]ignoreEntry, 0, len(il.entries))
	now := time.Now()
	for key, entry := range il.entries {
		expired := !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt)
		if !expired {
			live = append(live, entry)
			continue
		}
		delete(il.entries, key)
	}
	return live
}
// SweepExpired deletes every in-memory entry whose ExpiresAt is set and
// earlier than now, returning the number deleted. Called by Flow E ticker.
//
//nolint:unused // wired in O2 Flow E ticker
func (il *ignoreList) SweepExpired(now time.Time) int {
	il.mu.Lock()
	defer il.mu.Unlock()

	before := len(il.entries)
	for key, entry := range il.entries {
		if entry.ExpiresAt.IsZero() || !now.After(entry.ExpiresAt) {
			continue
		}
		delete(il.entries, key)
	}
	return before - len(il.entries)
}
// ignoreBucket is the bucket name handed to the store's PHPRelay* helpers
// for persisted ignore entries.
const ignoreBucket = "phprelay:ignore"
// SetStore attaches (or, with nil, detaches) the persistence backend used
// by AddPersist, RemovePersist, Restore and SweepBolt. The assignment is
// not synchronised with il.mu.
func (il *ignoreList) SetStore(db *store.DB) {
	il.db = db
}
// AddPersist records an ignore for k in memory and, when a store is
// attached, writes the same record to the ignoreBucket.
//
// The record is built once so the in-memory and persisted copies carry an
// identical AddedAt. The in-memory insert happens first and is kept even
// if encoding or the store write fails; the error is returned so the
// caller can tell the operator that the ignore will not survive a restart.
// With no store attached the call is memory-only and returns nil.
func (il *ignoreList) AddPersist(k scriptKey, expiresAt time.Time, by, reason string) error {
	entry := ignoreEntry{
		ScriptKey: string(k),
		AddedAt:   time.Now(),
		ExpiresAt: expiresAt,
		AddedBy:   by,
		Reason:    reason,
	}

	il.mu.Lock()
	il.entries[entry.ScriptKey] = entry
	il.mu.Unlock()

	if il.db == nil {
		return nil
	}
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(entry); err != nil {
		return err
	}
	return il.db.PHPRelayPut(ignoreBucket, entry.ScriptKey, buf.Bytes())
}
// RemovePersist drops the ignore for k from memory and, when a store is
// attached, deletes the matching row from the ignoreBucket. The in-memory
// removal happens regardless of the store outcome.
//
//nolint:unused // wired in M2/O2
func (il *ignoreList) RemovePersist(k scriptKey) error {
	il.Remove(k)
	if store := il.db; store != nil {
		return store.PHPRelayDelete(ignoreBucket, string(k))
	}
	return nil
}
// Restore re-populates the in-memory list from bbolt at daemon start.
// Corrupt rows are skipped silently; expired rows are skipped (the bbolt
// row stays put until SweepBolt prunes it on the next Flow E tick).
//
// Decoded entries are inserted verbatim so the persisted AddedAt survives
// the restart; going through Add would re-stamp every entry with the
// restart time. Returns nil without doing anything when no store is
// attached, and the store's error if listing the bucket fails.
func (il *ignoreList) Restore() error {
	if il.db == nil {
		return nil
	}
	rows, err := il.db.PHPRelayList(ignoreBucket)
	if err != nil {
		return err
	}
	now := time.Now()

	il.mu.Lock()
	defer il.mu.Unlock()
	for _, raw := range rows {
		var e ignoreEntry
		if err := gob.NewDecoder(bytes.NewReader(raw)).Decode(&e); err != nil {
			continue // corrupt row: left for SweepBolt to prune
		}
		if !e.ExpiresAt.IsZero() && now.After(e.ExpiresAt) {
			continue
		}
		il.entries[e.ScriptKey] = e
	}
	return nil
}
// SweepBolt prunes the persisted ignore bucket on the Flow E ticker. A row
// is dropped when it cannot be gob-decoded or when its non-zero ExpiresAt
// lies strictly before now. Returns whatever the store's sweep reports;
// with no store attached it returns (0, nil).
//
//nolint:unused // wired in M2/O2
func (il *ignoreList) SweepBolt(now time.Time) (int, error) {
	if il.db == nil {
		return 0, nil
	}
	shouldDrop := func(_, value []byte) bool {
		var entry ignoreEntry
		if decodeErr := gob.NewDecoder(bytes.NewReader(value)).Decode(&entry); decodeErr != nil {
			return true // drop corrupt rows
		}
		if entry.ExpiresAt.IsZero() {
			return false
		}
		return now.After(entry.ExpiresAt)
	}
	return il.db.PHPRelaySweep(ignoreBucket, shouldDrop)
}
package daemon
import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailspool"
)
// msgIDPattern guards exim -Mf invocations against header-injected garbage
// that slipped past parseHeaders. Exim msgIDs are <= 23 chars in practice
// but we accept up to 32 to allow for future format changes; the lower
// bound of 16 rules out any short string an attacker could try to slip in.
//
// The pattern is anchored at both ends and admits only ASCII letters,
// digits and '-', so a matching value cannot start with anything but an
// alphanumeric or hyphen and cannot contain whitespace or path separators.
var msgIDPattern = regexp.MustCompile(`^[A-Za-z0-9-]{16,32}$`)
// eximBinary is resolved at module init via exec.LookPath. Empty means
// auto-action is permanently disabled (a Warning finding is emitted at
// startup -- see Phase O).
//
// NOTE(review): nothing in this part of the file assigns or reads this
// variable; autoFreezer and PHPRelayController carry their own eximBin
// field. Confirm it is still needed.
//
//nolint:unused // populated in K5 by AutoFreezePHPRelayQueue init via exec.LookPath
var eximBinary string
// actionRateLimiter budgets exim -M* invocations. It hands out up to
// maxPerMin tokens per window; the window opens on the first call and is
// replaced by a fresh, full one on the first call made at least one minute
// after the current window opened.
type actionRateLimiter struct {
	mu         sync.Mutex
	maxPerMin  int              // tokens granted whenever the window is refilled
	bucket     int              // tokens still available in the current window
	refilledAt time.Time        // when the current window opened; zero before first use
	now        func() time.Time // clock source; replaceable by tests
}

// newActionRateLimiter returns a limiter with a full bucket of maxPerMin
// tokens that reads the wall clock via time.Now.
func newActionRateLimiter(maxPerMin int) *actionRateLimiter {
	rl := new(actionRateLimiter)
	rl.maxPerMin = maxPerMin
	rl.bucket = maxPerMin
	rl.now = time.Now
	return rl
}

// consumeN returns true if n tokens were available and consumed. The
// request is all-or-nothing: when fewer than n tokens remain, nothing is
// deducted and false is returned.
func (rl *actionRateLimiter) consumeN(n int) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	current := rl.now()
	windowExpired := rl.refilledAt.IsZero() || current.Sub(rl.refilledAt) >= time.Minute
	if windowExpired {
		rl.refilledAt = current
		rl.bucket = rl.maxPerMin
	}
	if n > rl.bucket {
		return false
	}
	rl.bucket -= n
	return true
}
// freezeErrIsAlreadyGone reports whether exim's stderr says the message
// had already left the queue between snapshot and freeze. The match is a
// case-insensitive substring test; such outcomes are normal queue churn,
// not action failures.
func freezeErrIsAlreadyGone(stderr string) bool {
	lowered := strings.ToLower(stderr)
	for _, fragment := range [...]string{
		"message not found",
		"spool file not found",
		"no such message",
	} {
		if strings.Contains(lowered, fragment) {
			return true
		}
	}
	return false
}
// spoolScanMatchingScript lists the msgIDs of spooled messages whose
// X-PHP-Script header resolves to scriptKey k. AutoFreezePHPRelayQueue
// falls back to it when activeMsgs was capped or when a late reputation
// finding has no in-memory activeMsgs left.
//
// Both spool layouts required by spec section 5.8 are covered without
// consulting any exim configuration -- the directory tree is the source
// of truth:
//
//  1. Split (cPanel default, split_spool_directory=true): every directory
//     directly under spoolRoot is read one level deep.
//  2. Unsplit: "-H" files sitting directly in spoolRoot are inspected.
//
// Only names ending in "-H" are parsed. Files that fail to parse, carry no
// X-PHP-Script header, resolve to a different script, or whose name stem
// fails msgIDPattern are ignored. Unreadable subdirectories are skipped;
// an unreadable spoolRoot yields nil.
func spoolScanMatchingScript(spoolRoot string, k scriptKey) []string {
	var matched []string
	// #nosec G304 -- spoolRoot is operator-configured / hardcoded to cPanel default.
	rootEntries, err := os.ReadDir(spoolRoot)
	if err != nil {
		return nil
	}

	// consider appends the msgID of one candidate header file when it
	// belongs to script k.
	consider := func(path, base string) {
		if !strings.HasSuffix(base, "-H") {
			return
		}
		hdr, parseErr := emailspool.ParseHeaders(path)
		if parseErr != nil {
			return
		}
		if hdr.XPHPScript == "" {
			return
		}
		if parsed, _ := parseXPHPScript(hdr.XPHPScript); parsed != k {
			return
		}
		if msgID := strings.TrimSuffix(base, "-H"); msgIDPattern.MatchString(msgID) {
			matched = append(matched, msgID)
		}
	}

	for _, entry := range rootEntries {
		entryPath := filepath.Join(spoolRoot, entry.Name())
		if !entry.IsDir() {
			// Unsplit layout: -H files at the root of spoolRoot.
			consider(entryPath, entry.Name())
			continue
		}
		// Split layout: descend one level.
		// #nosec G304 -- spoolRoot is operator-configured / hardcoded to cPanel default.
		children, readErr := os.ReadDir(entryPath)
		if readErr != nil {
			continue
		}
		for _, child := range children {
			consider(filepath.Join(entryPath, child.Name()), child.Name())
		}
	}
	return matched
}
// runner abstracts exec.CommandContext so tests can inject a stub.
//
// Run executes bin with args under ctx and returns whatever the command
// wrote to stderr together with the execution error (nil on success).
type runner interface {
Run(ctx context.Context, bin string, args []string) (stderr string, err error)
}
// defaultRunner is the production runner: it executes the real binary.
type defaultRunner struct{}

// Run starts bin with args, waits for it to finish (or for ctx to end),
// and returns the captured stderr alongside the error from running the
// command. Stdout is not captured.
func (defaultRunner) Run(ctx context.Context, bin string, args []string) (string, error) {
	var captured strings.Builder
	// #nosec G204 -- bin is the operator-configured exim binary path
	// (autoFreezer.eximBin, default /usr/sbin/exim); args are exim flags +
	// validated msg-IDs from the spool. Not attacker-controlled.
	cmd := exec.CommandContext(ctx, bin, args...)
	cmd.Stderr = &captured
	if runErr := cmd.Run(); runErr != nil {
		return captured.String(), runErr
	}
	return captured.String(), nil
}
// auditEntry is the per-action record handed to the auditor. It is the
// in-memory shape only; structuredAuditor.Write turns it into one JSONL
// line.
type auditEntry struct {
// Ts is when the action was attempted (callers pass time.Now()).
Ts time.Time
// MsgID is the exim message ID the action targeted.
MsgID string
// ScriptKey and Path are copied from the triggering finding; the thaw
// path leaves both empty.
ScriptKey string
Path string
// DryRun is true when no exim command was actually executed.
DryRun bool
// Exit is 0 unless the writer marks a failure; autoFreezer.Apply sets
// it to 1 for a failed freeze rather than the real process exit code.
Exit int
// Stderr is the captured stderr of the exim invocation, if any.
Stderr string
Action string // "freeze" | "thaw" | "freeze_dry_run"
}
// auditor is the sink for auditEntry records. Write returns nothing, so
// implementations have no way to report a failure to the caller.
type auditor interface {
Write(e auditEntry)
}
// autoFreezer holds the wiring needed to translate findings into exim -Mf
// invocations. Constructed once at daemon start; Apply is invoked per
// post-emit AutoResponse pass.
//
// dryRunFn returns the EFFECTIVE dry-run state. The CLI's runtime
// override + bbolt override + csm.yaml fallback are resolved by the
// PHPRelayController.effectiveDryRun and threaded in via dryRunFn so
// `csm phprelay dry-run on|off|reset` actually changes freeze
// behaviour. autoFreezer never reads cfg.AutoResponse.PHPRelay.DryRun
// directly.
//
//nolint:unused // wired in O2 by daemon controller
type autoFreezer struct {
// scripts yields the per-script state whose activeMsgs are frozen.
scripts *perScriptWindow
// cfg supplies the AutoResponse / PHPRelay enable switches.
cfg *config.Config
// spoolRoot is the directory scanned by the spool-scan fallback.
spoolRoot string
// eximBin is the exim binary to run; empty disables Apply entirely.
eximBin string
// runner executes exim; newAutoFreezer substitutes defaultRunner for nil.
runner runner
// auditor receives one entry per attempted or dry-run action. Apply
// calls it unconditionally, so it must not be nil.
auditor auditor
// rateLim budgets real (non-dry-run) freezes.
rateLim *actionRateLimiter
// metrics is optional; every use is nil-guarded.
metrics *phpRelayMetrics
// dryRunFn is evaluated once per Apply call.
dryRunFn func() bool
}
// newAutoFreezer wires an autoFreezer. Defaults applied here:
//   - a nil runner becomes defaultRunner{};
//   - a non-positive MaxActionsPerMinute becomes 60;
//   - a nil dryRunFn falls back to cfg.PHPRelayDryRunEnabled, i.e. the safe
//     YAML-level dry-run state (which defaults to TRUE).
//
// cfg must not be nil. a (the auditor) is stored as given.
//
//nolint:unused // wired in O2 by daemon controller
func newAutoFreezer(scripts *perScriptWindow, cfg *config.Config, spoolRoot, eximBin string, r runner, a auditor, m *phpRelayMetrics, dryRunFn func() bool) *autoFreezer {
	const fallbackMaxPerMinute = 60

	limit := cfg.AutoResponse.PHPRelay.MaxActionsPerMinute
	if limit <= 0 {
		limit = fallbackMaxPerMinute
	}

	af := &autoFreezer{
		scripts:   scripts,
		cfg:       cfg,
		spoolRoot: spoolRoot,
		eximBin:   eximBin,
		runner:    r,
		auditor:   a,
		rateLim:   newActionRateLimiter(limit),
		metrics:   m,
		dryRunFn:  dryRunFn,
	}
	if af.runner == nil {
		af.runner = defaultRunner{}
	}
	if af.dryRunFn == nil {
		// Defensive default: no resolver wired, so use the YAML-level state.
		af.dryRunFn = cfg.PHPRelayDryRunEnabled
	}
	return af
}
// Apply iterates findings, snapshots each script's activeMsgs, optionally
// extends with a spool-scan fallback, and freezes via exim -Mf. Returns
// any new findings produced (Warning/Critical for action outcomes). Pure
// from the perspective of finding emission -- caller forwards them to the
// alert pipeline.
//
// Only findings whose Check is "email_php_relay_abuse" are considered.
// Returns nil without acting when auto-response or PHP-relay freezing is
// disabled in cfg, or when no exim binary is configured. The dry-run state
// is sampled once per call, before the loop.
//
// Per finding, at most one follow-up finding is emitted:
//   - email_php_relay_action_skipped (Warning): Path carries no scriptKey;
//   - email_php_relay_action_dry_run (Warning): dry-run, nothing executed;
//   - email_php_relay_rate_limit_hit (Warning): the limiter refused the batch;
//   - email_php_relay_action_failed (Critical): one or more freezes failed.
//
//nolint:unused // wired in O2 by daemon controller
func (a *autoFreezer) Apply(findings []alert.Finding) []alert.Finding {
var emitted []alert.Finding
if !a.cfg.AutoResponse.Enabled || !a.cfg.PHPRelayFreezeEnabled() {
return nil
}
if a.eximBin == "" {
return nil
}
dryRun := a.dryRunFn()
for _, f := range findings {
if f.Check != "email_php_relay_abuse" {
continue
}
if !canActOnPath(f.Path) {
emitted = append(emitted, alert.Finding{
Severity: alert.Warning,
Check: "email_php_relay_action_skipped",
Path: f.Path,
Message: fmt.Sprintf("AutoFreeze skipped: path %q has no scriptKey", f.Path),
Timestamp: time.Now(),
})
continue
}
s := a.scripts.getOrCreate(scriptKey(f.ScriptKey))
ids, capped := s.snapshotActiveMsgs()
// Fall back to a spool scan when the in-memory list was truncated, or
// when a reputation finding arrives with nothing tracked in memory.
if capped || (len(ids) == 0 && f.Path == "reputation") {
if a.metrics != nil {
if capped {
a.metrics.SpoolScanFallbacks.With("capped").Inc()
} else {
a.metrics.SpoolScanFallbacks.With("reputation").Inc()
}
}
extra := spoolScanMatchingScript(a.spoolRoot, scriptKey(f.ScriptKey))
ids = unionStrings(ids, extra)
}
if len(ids) == 0 {
continue
}
// Dry-run: audit what would have been frozen, run nothing, and do not
// spend rate-limit tokens.
if dryRun {
for _, id := range ids {
a.auditor.Write(auditEntry{
Ts: time.Now(), MsgID: id, ScriptKey: f.ScriptKey,
Path: f.Path, DryRun: true, Action: "freeze_dry_run",
})
}
emitted = append(emitted, alert.Finding{
Severity: alert.Warning,
Check: "email_php_relay_action_dry_run",
Path: f.Path,
Message: fmt.Sprintf("AutoFreeze dry-run: would freeze %d msgs from %s", len(ids), f.ScriptKey),
ScriptKey: f.ScriptKey,
Timestamp: time.Now(),
})
continue
}
// The limiter is all-or-nothing for the whole batch: if fewer than
// len(ids) tokens remain, no message of this finding is frozen.
if !a.rateLim.consumeN(len(ids)) {
emitted = append(emitted, alert.Finding{
Severity: alert.Warning,
Check: "email_php_relay_rate_limit_hit",
Path: f.Path,
Message: fmt.Sprintf("AutoFreeze rate limit prevented %d freezes for %s", len(ids), f.ScriptKey),
ScriptKey: f.ScriptKey,
Timestamp: time.Now(),
})
continue
}
var failed []string
for _, id := range ids {
// Malformed IDs are skipped silently: no audit entry, no metric. Their
// rate-limit tokens were already consumed above.
if !msgIDPattern.MatchString(id) {
continue
}
// Each exim invocation gets its own 5s budget.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
stderr, err := a.runner.Run(ctx, a.eximBin, []string{"-Mf", id})
cancel()
entry := auditEntry{
Ts: time.Now(), MsgID: id, ScriptKey: f.ScriptKey,
Path: f.Path, DryRun: false, Stderr: stderr, Action: "freeze",
}
if err != nil {
// "Already gone" is queue churn, not a failure: audited with Exit 0
// and counted separately. The id is not removed from activeMsgs here.
if freezeErrIsAlreadyGone(stderr) {
if a.metrics != nil {
a.metrics.ActionGone.Inc()
}
a.auditor.Write(entry)
continue
}
// Exit is a failure marker, not the real process exit code.
entry.Exit = 1
failed = append(failed, id)
if a.metrics != nil {
a.metrics.Actions.With("freeze", "fail").Inc()
}
}
a.auditor.Write(entry)
// On successful freeze, drop the id from activeMsgs so we
// don't re-freeze it on the next finding emission.
if err == nil {
if a.metrics != nil {
a.metrics.Actions.With("freeze", "ok").Inc()
}
s.removeActive(id)
}
}
if len(failed) > 0 {
emitted = append(emitted, alert.Finding{
Severity: alert.Critical,
Check: "email_php_relay_action_failed",
Path: f.Path,
Message: fmt.Sprintf("exim -Mf failed for %d msgs from %s", len(failed), f.ScriptKey),
ScriptKey: f.ScriptKey,
MsgIDs: failed,
Timestamp: time.Now(),
})
}
}
return emitted
}
// canActOnPath reports whether a finding with trigger label p carries a
// scriptKey and can therefore be acted on by AutoFreeze. header, volume,
// fanout, baseline and reputation do; volume_account (which fires per
// cpuser) and any unknown label do not. The comparison is exact and
// case-sensitive.
//
//nolint:unused // wired in O2 by daemon controller
func canActOnPath(p string) bool {
	return p == "header" ||
		p == "volume" ||
		p == "fanout" ||
		p == "baseline" ||
		p == "reputation"
}
// unionStrings returns the distinct strings of a followed by the distinct
// strings of b that were not already seen, preserving first-occurrence
// order. The result is always a freshly allocated, non-nil slice.
//
// Neither input is modified. The two slices are walked separately instead
// of via append(a, b...), which would write b's elements into a's backing
// array whenever a has spare capacity.
//
//nolint:unused // wired in O2 by daemon controller
func unionStrings(a, b []string) []string {
	seen := make(map[string]struct{}, len(a)+len(b))
	out := make([]string, 0, len(a)+len(b))
	for _, src := range [2][]string{a, b} {
		for _, s := range src {
			if _, dup := seen[s]; dup {
				continue
			}
			seen[s] = struct{}{}
			out = append(out, s)
		}
	}
	return out
}
// structuredAuditor is the auditor implementation that serialises each
// auditEntry as one JSON object per line (JSONL) onto w.
type structuredAuditor struct {
// mu serialises writes to w so concurrent Write calls do not interleave.
mu sync.Mutex
// w is the destination stream.
w io.Writer
}
// newStructuredAuditor returns a JSONL auditor that writes to w. w is used
// as given and must not be nil.
func newStructuredAuditor(w io.Writer) *structuredAuditor {
	auditor := new(structuredAuditor)
	auditor.w = w
	return auditor
}
// Write serialises e as a single JSON object followed by '\n' and sends it
// to the underlying writer. Ts is normalised to UTC; stderr is omitted
// from the JSON when empty.
//
// The record and its newline go out in ONE Write call. On an O_APPEND
// file that keeps a line from being torn or interleaved with another
// writer's output, and avoids emitting a stray newline when the record
// itself failed to write.
//
// Auditing is best-effort: marshal and write errors are dropped because
// the auditor interface has no error return.
func (a *structuredAuditor) Write(e auditEntry) {
	payload := struct {
		Ts        time.Time `json:"ts"`
		MsgID     string    `json:"msg_id"`
		ScriptKey string    `json:"script_key"`
		Path      string    `json:"path"`
		Action    string    `json:"action"`
		DryRun    bool      `json:"dry_run"`
		Exit      int       `json:"exit"`
		Stderr    string    `json:"stderr,omitempty"`
	}{
		Ts: e.Ts.UTC(), MsgID: e.MsgID, ScriptKey: e.ScriptKey,
		Path: e.Path, Action: e.Action, DryRun: e.DryRun,
		Exit: e.Exit, Stderr: e.Stderr,
	}
	line, err := json.Marshal(payload)
	if err != nil {
		return
	}
	line = append(line, '\n')

	a.mu.Lock()
	defer a.mu.Unlock()
	_, _ = a.w.Write(line)
}
package daemon
import (
"io"
"os"
"sync"
"github.com/pidginhost/csm/internal/config"
)
// eximAuditWriterAt picks the writer behind the structured JSONL auditor.
//
// With auto-freeze enabled the file at path is opened immediately, which
// preserves the existing live-action failure visibility. With auto-freeze
// disabled (or a nil cfg) a lazy writer is returned instead: daemon
// startup then leaves no 0-byte orphan log, yet a manual thaw can still
// create the file on its first real audit entry.
func eximAuditWriterAt(cfg *config.Config, path string) io.Writer {
	freezeEnabled := cfg != nil && cfg.PHPRelayFreezeEnabled()
	if freezeEnabled {
		return openEximAuditWriterAt(path)
	}
	return &lazyEximAuditWriter{path: path}
}
// lazyEximAuditWriter is an io.Writer that defers opening the audit log
// until the first Write call.
type lazyEximAuditWriter struct {
// mu guards w and serialises writes.
mu sync.Mutex
// path is the audit log location passed to openEximAuditWriterAt.
path string
// w is the opened destination; nil until the first Write.
w io.Writer
}
// Write opens the destination on first use and forwards p to it. The
// destination is resolved exactly once: whatever openEximAuditWriterAt
// returned on the first call is reused for every later write.
func (w *lazyEximAuditWriter) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	sink := w.w
	if sink == nil {
		sink = openEximAuditWriterAt(w.path)
		w.w = sink
	}
	return sink.Write(p)
}
// openEximAuditWriterAt opens path for appending, creating it with mode
// 0640 if it does not exist, and returns the file. If the open fails the
// function returns os.Stderr instead, so audit lines are never silently
// discarded. The returned file is never closed by this package code path.
func openEximAuditWriterAt(path string) io.Writer {
	const flags = os.O_CREATE | os.O_APPEND | os.O_WRONLY
	// #nosec G304 G302 -- G304: path is the compile-time constant phpRelayAuditPath in production; tests pass t.TempDir-derived paths. G302: 0640 is intentional; SIEM log shippers (Vector, Filebeat, Fluentbit) commonly run as a non-root user that needs group-read access. 0600 would force the shipper to run as root. Same rationale as internal/alert/audit_jsonl.go.
	f, err := os.OpenFile(path, flags, 0640)
	if err != nil {
		return os.Stderr
	}
	return f
}
package daemon
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/store"
)
// PHPRelayController aggregates the phprelay-state references the
// control-socket handlers need. One instance per running daemon;
// constructed in daemon.New (Phase O2) and assigned to ControlListener.phprelay.
//
// The spool-pipeline reference (linux-only spoolPipeline) is intentionally
// not held here so this file stays cross-platform. Phase O2 can attach
// the pipeline through a separate linux-gated wiring file if needed.
type PHPRelayController struct {
// eng is the detector engine; Status reads its account limit and script
// windows, effectiveDryRun reads its cfg for the csm.yaml fallback.
eng *evaluator
// msgIndex is optional; when set, Status reports its size.
msgIndex *msgIDIndex
// ignores backs the ignore-script / unignore / ignore-list commands.
ignores *ignoreList
// actionDryRun is the in-memory (CLI) dry-run override.
actionDryRun *runtimeBool
// db holds the persisted dry-run override; nil skips the bbolt lookup in
// effectiveDryRun.
db *store.DB
// runner, eximBin and auditor are used by Thaw to run and record
// `exim -Mt`.
runner runner
eximBin string
auditor auditor
// enabled and platform are echoed verbatim in the Status response.
enabled bool
platform string
}
// runtimeBool is an optional boolean guarded by a mutex. It carries the
// in-memory dry-run override; the precedence CLI > bbolt > csm.yaml is
// resolved by the reader (effectiveDryRun), not here. The zero value is
// ready to use and means "no override set".
type runtimeBool struct {
	mu    sync.Mutex
	set   bool // whether an override is currently in force
	value bool // last value passed to Set; retained across Reset
}

// Set installs v as the override.
func (r *runtimeBool) Set(v bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.value = v
	r.set = true
}

// Reset withdraws the override. The previously stored value is left in
// place but is flagged as not set.
func (r *runtimeBool) Reset() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.set = false
}

// Get returns the stored value and whether an override is in force.
// Callers should ignore value when set is false.
func (r *runtimeBool) Get() (value, set bool) {
	r.mu.Lock()
	value, set = r.value, r.set
	r.mu.Unlock()
	return value, set
}
// effectiveDryRun resolves the dry-run state and names the input that
// decided it. Sources are consulted in order and the first one that has a
// value wins:
//
//	"runtime"  - the in-memory override, when set;
//	"bbolt"    - the persisted override, when a db is attached and the
//	             lookup succeeds and finds a value (lookup errors fall through);
//	"csm.yaml" - the configured value, when the engine and its cfg exist;
//	"default"  - true, the safe choice, when nothing else is available.
func (c *PHPRelayController) effectiveDryRun() (bool, string) {
	if override, isSet := c.actionDryRun.Get(); isSet {
		return override, "runtime"
	}
	if c.db != nil {
		persisted, found, err := readDryRunOverride(c.db)
		if err == nil && found {
			return persisted, "bbolt"
		}
	}
	if c.eng == nil || c.eng.cfg == nil {
		return true, "default"
	}
	return c.eng.cfg.PHPRelayDryRunEnabled(), "csm.yaml"
}
// Status returns a snapshot of detector state for `csm phprelay status`.
//
// Every optional collaborator is nil-guarded, matching effectiveDryRun,
// which already treats the engine as optional: a controller wired without
// an engine, ignore list or msgID index reports zero for the corresponding
// fields instead of panicking the control handler. The context and request
// are currently unused, and the returned error is always nil.
func (c *PHPRelayController) Status(_ context.Context, _ control.PHPRelayStatusRequest) (control.PHPRelayStatusResponse, error) {
	resp := control.PHPRelayStatusResponse{
		Enabled:        c.enabled,
		Platform:       c.platform,
		RecentFindings: map[string]int{}, // populated by metrics in Phase N
	}
	eff, _ := c.effectiveDryRun()
	resp.DryRun = eff

	if c.ignores != nil {
		resp.IgnoresActive = len(c.ignores.List())
	}
	if c.eng != nil {
		resp.EffectiveAccountLimit = c.eng.effectiveAccountLimit
		if c.eng.scripts != nil {
			resp.ScriptsTracked = len(c.eng.scripts.Snapshot())
		}
	}
	if c.msgIndex != nil {
		resp.MsgIDIndexSize = c.msgIndex.Len()
	}
	return resp, nil
}
// handlePHPRelayStatus adapts the dispatcher's raw JSON arguments to the
// typed PHPRelayController.Status call. Empty args are accepted and leave
// the request at its zero value; malformed JSON is rejected with a
// "bad args" error. Fails when no controller has been wired.
func (c *ControlListener) handlePHPRelayStatus(argsRaw json.RawMessage) (any, error) {
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	var req control.PHPRelayStatusRequest
	if len(argsRaw) != 0 {
		err := json.Unmarshal(argsRaw, &req)
		if err != nil {
			return nil, fmt.Errorf("bad args: %w", err)
		}
	}
	return ctl.Status(context.Background(), req)
}
// IgnoreScript adds an operator ignore for req.ScriptKey and returns the
// instant at which it expires.
//
// Defaults: ForHours == 0 means seven days; an empty AddedBy is recorded
// as "operator". With req.Persist the entry is also written to the store
// and a store error is returned; otherwise it lives in memory only. An
// empty ScriptKey is rejected.
func (c *PHPRelayController) IgnoreScript(_ context.Context, req control.PHPRelayIgnoreScriptResponseRequestAlias) (control.PHPRelayIgnoreScriptResponse, error) {
	return control.PHPRelayIgnoreScriptResponse{}, nil
}
// Unignore removes the operator ignore for req.ScriptKey. With req.Persist
// the persisted row is deleted as well and a store error is returned;
// otherwise only the in-memory entry is dropped. An empty ScriptKey is
// rejected.
func (c *PHPRelayController) Unignore(_ context.Context, req control.PHPRelayUnignoreRequest) (struct{}, error) {
	var done struct{}
	if req.ScriptKey == "" {
		return done, errors.New("script_key required")
	}

	key := scriptKey(req.ScriptKey)
	if !req.Persist {
		c.ignores.Remove(key)
		return done, nil
	}
	if err := c.ignores.RemovePersist(key); err != nil {
		return done, err
	}
	return done, nil
}
// IgnoreList returns the currently live ignore entries in wire form.
// AddedAt is not part of the response. Entries is always non-nil, and the
// returned error is always nil.
func (c *PHPRelayController) IgnoreList(_ context.Context, _ struct{}) (control.PHPRelayIgnoreListResponse, error) {
	live := c.ignores.List()
	wire := make([]control.PHPRelayIgnoreEntry, len(live))
	for i := range live {
		wire[i] = control.PHPRelayIgnoreEntry{
			ScriptKey: live[i].ScriptKey,
			ExpiresAt: live[i].ExpiresAt,
			AddedBy:   live[i].AddedBy,
			Reason:    live[i].Reason,
		}
	}
	return control.PHPRelayIgnoreListResponse{Entries: wire}, nil
}
// handlePHPRelayIgnoreScript adapts the dispatcher's raw JSON arguments to
// PHPRelayController.IgnoreScript. Empty args leave the request at its
// zero value; malformed JSON yields a "bad args" error. Fails when no
// controller has been wired.
func (c *ControlListener) handlePHPRelayIgnoreScript(argsRaw json.RawMessage) (any, error) {
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	var req control.PHPRelayIgnoreScriptRequest
	if len(argsRaw) != 0 {
		err := json.Unmarshal(argsRaw, &req)
		if err != nil {
			return nil, fmt.Errorf("bad args: %w", err)
		}
	}
	return ctl.IgnoreScript(context.Background(), req)
}
// handlePHPRelayUnignore adapts the dispatcher's raw JSON arguments to
// PHPRelayController.Unignore. Empty args leave the request at its zero
// value; malformed JSON yields a "bad args" error. Fails when no
// controller has been wired.
func (c *ControlListener) handlePHPRelayUnignore(argsRaw json.RawMessage) (any, error) {
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	var req control.PHPRelayUnignoreRequest
	if len(argsRaw) != 0 {
		err := json.Unmarshal(argsRaw, &req)
		if err != nil {
			return nil, fmt.Errorf("bad args: %w", err)
		}
	}
	return ctl.Unignore(context.Background(), req)
}
// handlePHPRelayIgnoreList serves the ignore-list command. The command
// takes no arguments, so argsRaw is ignored without being parsed. Fails
// when no controller has been wired.
func (c *ControlListener) handlePHPRelayIgnoreList(argsRaw json.RawMessage) (any, error) {
	_ = argsRaw // no args expected
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	return ctl.IgnoreList(context.Background(), struct{}{})
}
// DryRun applies `csm phprelay dry-run <mode>` and reports the resulting
// effective state together with the source that decided it.
//
// Modes:
//
//	"on" / "off" - set the runtime override; with req.Persist also write
//	               the override to the store, attributed to "operator".
//	"reset"      - clear the runtime override; with req.Persist also delete
//	               the persisted override.
//
// Any other mode is rejected. The runtime override is changed before the
// store is touched, so it stays changed even when the store call fails and
// its error is returned.
func (c *PHPRelayController) DryRun(_ context.Context, req control.PHPRelayDryRunRequest) (control.PHPRelayDryRunResponse, error) {
	var failed control.PHPRelayDryRunResponse

	switch mode := req.Mode; mode {
	case "on", "off":
		enable := mode == "on"
		c.actionDryRun.Set(enable)
		if req.Persist {
			if err := writeDryRunOverride(c.db, enable, "operator"); err != nil {
				return failed, err
			}
		}
	case "reset":
		c.actionDryRun.Reset()
		if req.Persist {
			if err := deleteDryRunOverride(c.db); err != nil {
				return failed, err
			}
		}
	default:
		return failed, errors.New("mode must be on|off|reset")
	}

	effective, source := c.effectiveDryRun()
	return control.PHPRelayDryRunResponse{Effective: effective, Source: source}, nil
}
// DryRunFn returns a resolver that re-evaluates the dry-run precedence
// chain each time it is called. Daemon wiring hands it to newAutoFreezer,
// so `csm phprelay dry-run` takes effect on the next freeze pass without
// rebuilding the freezer.
func (c *PHPRelayController) DryRunFn() func() bool {
	resolve := func() bool {
		effective, _ := c.effectiveDryRun()
		return effective
	}
	return resolve
}
// handlePHPRelayDryRun adapts the dispatcher's raw JSON arguments to
// PHPRelayController.DryRun. Empty args leave the request at its zero
// value; malformed JSON yields a "bad args" error. Fails when no
// controller has been wired.
func (c *ControlListener) handlePHPRelayDryRun(argsRaw json.RawMessage) (any, error) {
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	var req control.PHPRelayDryRunRequest
	if len(argsRaw) != 0 {
		err := json.Unmarshal(argsRaw, &req)
		if err != nil {
			return nil, fmt.Errorf("bad args: %w", err)
		}
	}
	return ctl.DryRun(context.Background(), req)
}
// Thaw runs `exim -Mt <msg_id>` to release a frozen message back to the
// queue. msgIDPattern validation guards against header-injected garbage
// even though only operators can hit this endpoint; an ID that fails it is
// rejected before anything is executed or audited.
//
// The command gets a 5s budget derived from ctx. An audit entry is written
// for both success and failure so an operator can later prove what was
// thawed; a failed run is marked with Exit 1 (the same failure marker
// autoFreezer.Apply uses) so the log distinguishes it from a success. The
// captured stderr is returned in the response in both cases, alongside the
// runner's error.
//
// req.By is accepted on the wire for forward compatibility (future
// auditEntry.By field) but is not used by the M4 handler.
func (c *PHPRelayController) Thaw(ctx context.Context, req control.PHPRelayThawRequest) (control.PHPRelayThawResponse, error) {
	if !msgIDPattern.MatchString(req.MsgID) {
		return control.PHPRelayThawResponse{}, errors.New("invalid msg_id")
	}
	sub, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	stderr, err := c.runner.Run(sub, c.eximBin, []string{"-Mt", req.MsgID})

	entry := auditEntry{
		Ts: time.Now(), MsgID: req.MsgID, Action: "thaw",
		Stderr: stderr,
	}
	if err != nil {
		entry.Exit = 1
	}
	c.auditor.Write(entry)

	return control.PHPRelayThawResponse{Stderr: stderr}, err
}
// handlePHPRelayThaw adapts the dispatcher's raw JSON arguments to
// PHPRelayController.Thaw. Empty args leave the request at its zero value
// (which Thaw then rejects as an invalid msg_id); malformed JSON yields a
// "bad args" error. Fails when no controller has been wired.
func (c *ControlListener) handlePHPRelayThaw(argsRaw json.RawMessage) (any, error) {
	ctl := c.phprelay
	if ctl == nil {
		return nil, fmt.Errorf("phprelay controller not wired (Phase O2)")
	}
	var req control.PHPRelayThawRequest
	if len(argsRaw) != 0 {
		err := json.Unmarshal(argsRaw, &req)
		if err != nil {
			return nil, fmt.Errorf("bad args: %w", err)
		}
	}
	return ctl.Thaw(context.Background(), req)
}
package daemon
import (
"sync"
"github.com/pidginhost/csm/internal/metrics"
)
// phpRelayMetrics holds every series the module emits via the local
// internal/metrics OpenMetrics implementation.
//
// Defined as a struct of pointers so callers can pass nil and skip
// observation. All increments at call sites are guarded by
// `if e.metrics != nil` checks.
//
// The series names and help texts are assigned in newPHPRelayMetrics.
type phpRelayMetrics struct {
Findings *metrics.CounterVec // labels: path
Actions *metrics.CounterVec // labels: action, result
PathSkipped *metrics.CounterVec // labels: path, reason
WindowsActive *metrics.GaugeVec // labels: kind (script/ip/account)
MsgIDIndexSize *metrics.GaugeVec // labels: layer (memory/bbolt)
// MsgindexPersistDropped counts persist ops dropped because the
// persister queue was full.
MsgindexPersistDropped *metrics.Counter
// MsgindexPersistErrors counts bbolt commit failures.
MsgindexPersistErrors *metrics.Counter
// InotifyOverflows counts IN_Q_OVERFLOW events.
InotifyOverflows *metrics.Counter
// SpoolReadErrors counts spool -H read errors.
SpoolReadErrors *metrics.Counter
// UserdataErrors counts cpanelUserDomains read errors.
UserdataErrors *metrics.Counter
// ActiveMsgsCapped counts scriptState.activeMsgs cap-hit events.
ActiveMsgsCapped *metrics.Counter
SpoolScanFallbacks *metrics.CounterVec // labels: reason
// ActionGone counts messages already absent at exim -Mf time.
ActionGone *metrics.Counter
}
// Process-wide singleton state for newPHPRelayMetrics.
var (
// phpRelayMetricsOnce ensures the series are built and registered once.
phpRelayMetricsOnce sync.Once
// phpRelayMetricsInstance is the shared set; nil until the first
// newPHPRelayMetrics call completes.
phpRelayMetricsInstance *phpRelayMetrics
)
// newPHPRelayMetrics returns the process-wide php_relay metric set. The
// first call builds every series and registers it with the package default
// registry; later calls return that same instance. sync.Once is what keeps
// metrics.MustRegister from panicking on a duplicate name.
func newPHPRelayMetrics() *phpRelayMetrics {
	phpRelayMetricsOnce.Do(func() {
		// Each series name is spelled once and shared between construction
		// and registration.
		const (
			findingsName       = "csm_php_relay_findings_total"
			actionsName        = "csm_php_relay_actions_total"
			pathSkippedName    = "csm_php_relay_path_skipped_total"
			windowsActiveName  = "csm_php_relay_windows_active"
			indexSizeName      = "csm_php_relay_msgid_index_size"
			persistDroppedName = "csm_php_relay_msgindex_persist_dropped_total"
			persistErrorsName  = "csm_php_relay_msgindex_persist_errors_total"
			inotifyName        = "csm_php_relay_inotify_overflows_total"
			spoolReadName      = "csm_php_relay_spool_read_errors_total"
			userdataName       = "csm_php_relay_userdata_errors_total"
			cappedName         = "csm_php_relay_active_msgs_capped_total"
			fallbacksName      = "csm_php_relay_spool_scan_fallbacks_total"
			goneName           = "csm_php_relay_action_gone_total"
		)

		set := &phpRelayMetrics{
			Findings:               metrics.NewCounterVec(findingsName, "Findings emitted by php_relay paths.", []string{"path"}),
			Actions:                metrics.NewCounterVec(actionsName, "AutoFreeze actions attempted.", []string{"action", "result"}),
			PathSkipped:            metrics.NewCounterVec(pathSkippedName, "Path evaluation skipped.", []string{"path", "reason"}),
			WindowsActive:          metrics.NewGaugeVec(windowsActiveName, "Active windows per kind.", []string{"kind"}),
			MsgIDIndexSize:         metrics.NewGaugeVec(indexSizeName, "msgIDIndex size by storage layer.", []string{"layer"}),
			MsgindexPersistDropped: metrics.NewCounter(persistDroppedName, "Persist queue overflow drops."),
			MsgindexPersistErrors:  metrics.NewCounter(persistErrorsName, "bbolt commit failures."),
			InotifyOverflows:       metrics.NewCounter(inotifyName, "IN_Q_OVERFLOW events."),
			SpoolReadErrors:        metrics.NewCounter(spoolReadName, "Spool -H read errors."),
			UserdataErrors:         metrics.NewCounter(userdataName, "cpanelUserDomains read errors."),
			ActiveMsgsCapped:       metrics.NewCounter(cappedName, "scriptState.activeMsgs cap-hit events."),
			SpoolScanFallbacks:     metrics.NewCounterVec(fallbacksName, "AutoFreeze spool-scan fallback invocations.", []string{"reason"}),
			ActionGone:             metrics.NewCounter(goneName, "Messages already absent at exim -Mf time."),
		}

		metrics.MustRegister(findingsName, set.Findings)
		metrics.MustRegister(actionsName, set.Actions)
		metrics.MustRegister(pathSkippedName, set.PathSkipped)
		metrics.MustRegister(windowsActiveName, set.WindowsActive)
		metrics.MustRegister(indexSizeName, set.MsgIDIndexSize)
		metrics.MustRegister(persistDroppedName, set.MsgindexPersistDropped)
		metrics.MustRegister(persistErrorsName, set.MsgindexPersistErrors)
		metrics.MustRegister(inotifyName, set.InotifyOverflows)
		metrics.MustRegister(spoolReadName, set.SpoolReadErrors)
		metrics.MustRegister(userdataName, set.UserdataErrors)
		metrics.MustRegister(cappedName, set.ActiveMsgsCapped)
		metrics.MustRegister(fallbacksName, set.SpoolScanFallbacks)
		metrics.MustRegister(goneName, set.ActionGone)

		phpRelayMetricsInstance = set
	})
	return phpRelayMetricsInstance
}
package daemon
import (
"sync"
"time"
)
// indexEntry maps a message ID to the per-message attribution recorded at
// acceptance. Used by Path 3 (Stage 2) to map delivery-failure log lines
// back to the originating script. Public field names because gob-encoded
// for bbolt persistence in Task C3.
type indexEntry struct {
// ScriptKey identifies the originating script.
ScriptKey string
HeaderScore int
SourceIP string
CPUser string
// At is the acceptance time. It is the ordering key for eviction
// (oldest At goes first) and the cut-off key for SweepMemory.
At time.Time
}
// msgIDIndex stores indexEntry per msgID, served from memory.
// Bounded by maxEntries; overflow drops the oldest entry by acceptance time.
// Persistence to bbolt is handled by msgIndexPersister (Task C3).
type msgIDIndex struct {
// mu guards entries.
mu sync.Mutex
// entries maps msgID to its attribution record.
entries map[string]indexEntry
// maxEntries is the in-memory size bound enforced by Put.
maxEntries int
persister *msgIndexPersister // nil in unit tests; real in production
}
// newMsgIDIndex builds an empty index bounded to maxEntries records. A
// non-positive maxEntries selects the default of 200,000. persister may be
// nil, in which case nothing is forwarded for persistence.
func newMsgIDIndex(persister *msgIndexPersister, maxEntries int) *msgIDIndex {
	const (
		defaultMaxEntries = 200_000
		initialCapacity   = 4096
	)
	limit := maxEntries
	if limit <= 0 {
		limit = defaultMaxEntries
	}
	idx := new(msgIDIndex)
	idx.entries = make(map[string]indexEntry, initialCapacity)
	idx.maxEntries = limit
	idx.persister = persister
	return idx
}
// Put stores e under msgID, replacing any existing record. When msgID is
// new and the map is already at maxEntries, the entry with the oldest At
// is evicted first. Afterwards, outside the lock, the put is handed to the
// persister if one is attached; persistence is asynchronous and its
// failure does not affect in-memory correctness.
func (i *msgIDIndex) Put(msgID string, e indexEntry) {
	func() {
		i.mu.Lock()
		defer i.mu.Unlock()

		_, exists := i.entries[msgID]
		if !exists && len(i.entries) >= i.maxEntries {
			i.evictOldestLocked()
		}
		i.entries[msgID] = e
	}()

	if p := i.persister; p != nil {
		p.Enqueue(msgID, e)
	}
}
// Get looks msgID up in memory and returns its record plus a presence
// flag. A missing msgID yields the zero indexEntry and false.
func (i *msgIDIndex) Get(msgID string) (indexEntry, bool) {
	i.mu.Lock()
	defer i.mu.Unlock()

	entry, found := i.entries[msgID]
	return entry, found
}
// Has reports whether a record for msgID is currently held in memory.
func (i *msgIDIndex) Has(msgID string) bool {
	i.mu.Lock()
	defer i.mu.Unlock()

	_, found := i.entries[msgID]
	return found
}
// Len reports how many records are held in memory right now.
func (i *msgIDIndex) Len() int {
	i.mu.Lock()
	count := len(i.entries)
	i.mu.Unlock()
	return count
}
// SweepMemory removes every in-memory record whose At is at or before
// cutoff and returns the number removed. Called by Flow E's 1-min ticker
// with cutoff = now - 4h.
func (i *msgIDIndex) SweepMemory(cutoff time.Time) int {
	i.mu.Lock()
	defer i.mu.Unlock()

	before := len(i.entries)
	for id, entry := range i.entries {
		if entry.At.After(cutoff) {
			continue // still fresh
		}
		delete(i.entries, id)
	}
	return before - len(i.entries)
}
// evictOldestLocked deletes the entry with the earliest At. The caller
// must hold i.mu. An empty map is left untouched. Ties on At are broken by
// map iteration order, i.e. arbitrarily.
//
// Whether a victim was found is tracked with an explicit flag rather than
// by testing the ID against "", so an entry stored under the empty msgID
// can still be evicted and cannot defeat the size bound.
func (i *msgIDIndex) evictOldestLocked() {
	var (
		oldestID string
		oldestAt time.Time
		found    bool
	)
	for id, e := range i.entries {
		if !found || e.At.Before(oldestAt) {
			oldestID, oldestAt, found = id, e.At, true
		}
	}
	if found {
		delete(i.entries, oldestID)
	}
}
package daemon
import (
"bytes"
"encoding/gob"
"fmt"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/store"
)
// msgIndexBucket is the bucket name handed to the store's PHPRelay*
// helpers for persisted msgID index entries.
const msgIndexBucket = "phprelay:msgindex"
// All bbolt access goes through `*store.DB`'s phprelay helpers. The
// underlying `*bolt.DB` is unexported by design (internal/store/db.go);
// daemon code never imports go.etcd.io/bbolt directly.
// msgIndexPersister persists msgIDIndex entries to bbolt off the hot path.
// Public methods are safe for concurrent callers. Persistence failure does
// not affect the in-memory index; it only degrades restart-recovery and
// emits a Critical via the alerter callback.
//
// SetErrorCallback must be invoked BEFORE Start to establish a
// happens-before relationship between the writer of `onError` and the
// goroutine that reads it; concurrent callers after Start are not safe.
type msgIndexPersister struct {
// db is the store all reads and writes go through.
db *store.DB
// queue is the buffered hand-off from Enqueue to the background
// goroutine; Enqueue drops the op when it is full.
queue chan persistOp
// flushEvery and batchSize are set by the constructor (defaults 100ms
// and 256); presumably consumed by run, which is not visible here.
flushEvery time.Duration
batchSize int
// stopCh is closed by Stop; doneCh is what Stop then waits on.
stopCh chan struct{}
doneCh chan struct{}
// flushReqCh carries Flush requests; each request is an ack channel the
// caller blocks on.
flushReqCh chan chan struct{}
// droppedTotal and errorsTotal are accessed with sync/atomic.
droppedTotal uint64
errorsTotal uint64
// onError is never nil: it defaults to a no-op.
onError func(alert.Finding)
// metrics is optional; uses are nil-guarded.
metrics *phpRelayMetrics
}
// persistOp is one queued write: store entry under msgID.
type persistOp struct {
msgID string
entry indexEntry
}
// newMsgIndexPersister builds a persister bound to db. It does not start
// the background goroutine; call Start for that.
//
// Defaults: a non-positive queueSize becomes 4096 and a non-positive
// flushEvery becomes 100ms. The batch size is fixed at 256, the error
// callback starts as a no-op, and no metrics sink is attached.
func newMsgIndexPersister(db *store.DB, queueSize int, flushEvery time.Duration) *msgIndexPersister {
	const (
		defaultQueueSize  = 4096
		defaultFlushEvery = 100 * time.Millisecond
		fixedBatchSize    = 256
	)
	if queueSize <= 0 {
		queueSize = defaultQueueSize
	}
	if flushEvery <= 0 {
		flushEvery = defaultFlushEvery
	}

	p := &msgIndexPersister{
		db:         db,
		flushEvery: flushEvery,
		batchSize:  fixedBatchSize,
	}
	p.queue = make(chan persistOp, queueSize)
	p.stopCh = make(chan struct{})
	p.doneCh = make(chan struct{})
	p.flushReqCh = make(chan chan struct{})
	p.onError = func(alert.Finding) {}
	return p
}
// SetErrorCallback installs fn as the sink for Critical findings raised on
// bbolt failures. Passing nil is a no-op: the callback already in place
// (initially a do-nothing function) is kept, which is how tests leave
// emission disabled. Must be called before Start; concurrent invocation
// after Start is not safe.
func (p *msgIndexPersister) SetErrorCallback(fn func(alert.Finding)) {
	if fn == nil {
		return
	}
	p.onError = fn
}
// SetMetrics wires the phpRelayMetrics sink. Optional -- nil disables
// observation (used by tests). Must be called before Start; concurrent
// invocation after Start is not safe.
//
// The field is read without synchronisation by Enqueue and commitBatch,
// which is why the assignment has to happen before any of them can run.
func (p *msgIndexPersister) SetMetrics(m *phpRelayMetrics) {
p.metrics = m
}
// Start launches the background goroutine that batches queued ops and
// commits them. It returns immediately. Pair every Start with one Stop.
func (p *msgIndexPersister) Start() {
go p.run()
}
// Stop signals the background goroutine to exit and blocks until it has
// returned. It must be called at most once (a second call would close an
// already-closed channel) and only after Start, because doneCh is closed
// exclusively by the run goroutine.
func (p *msgIndexPersister) Stop() {
close(p.stopCh)
<-p.doneCh
}
// Enqueue hands one entry to the persister without ever blocking the
// caller. When the queue has no free slot the op is discarded and the
// drop is recorded (DroppedCount, plus the metrics sink when one is set).
func (p *msgIndexPersister) Enqueue(msgID string, e indexEntry) {
	op := persistOp{msgID: msgID, entry: e}
	select {
	case p.queue <- op:
		return
	default:
	}

	// Queue full: account for the drop.
	atomic.AddUint64(&p.droppedTotal, 1)
	if p.metrics != nil {
		p.metrics.MsgindexPersistDropped.Inc()
	}
}
// Flush waits until the background goroutine has drained the queue and
// committed the result. It returns without waiting when the persister has
// already been stopped. Intended for tests and shutdown.
func (p *msgIndexPersister) Flush() {
	ack := make(chan struct{})
	select {
	case <-p.stopCh:
		return
	case p.flushReqCh <- ack:
	}
	// The run goroutine accepted the request; it closes ack after the commit.
	<-ack
}
// DroppedCount returns how many ops Enqueue has discarded because the
// queue was full. Safe to call concurrently (atomic load).
func (p *msgIndexPersister) DroppedCount() uint64 {
return atomic.LoadUint64(&p.droppedTotal)
}
// ErrorCount returns how many encode or commit failures have been
// recorded by commitBatch. Safe to call concurrently (atomic load).
func (p *msgIndexPersister) ErrorCount() uint64 {
return atomic.LoadUint64(&p.errorsTotal)
}
// Lookup fetches the persisted entry stored under msgID.
//
// It returns (entry, true, nil) on a hit, (zero, false, nil) when no row
// exists, and (zero, false, err) when the store read fails or the stored
// bytes cannot be gob-decoded.
func (p *msgIndexPersister) Lookup(msgID string) (indexEntry, bool, error) {
	var none indexEntry

	raw, found, err := p.db.PHPRelayGet(msgIndexBucket, msgID)
	switch {
	case err != nil:
		return none, false, err
	case !found:
		return none, false, nil
	}

	var decoded indexEntry
	dec := gob.NewDecoder(bytes.NewReader(raw))
	if derr := dec.Decode(&decoded); derr != nil {
		return none, false, fmt.Errorf("decode %s: %w", msgID, derr)
	}
	return decoded, true, nil
}
// SweepBolt removes every phprelay:msgindex row whose At is not after
// cutoff, and reports how many rows were removed. Rows that fail to decode
// are removed as well so corrupt data does not accumulate in the bucket.
// Called by Flow E's 1-min ticker.
func (p *msgIndexPersister) SweepBolt(cutoff time.Time) (int, error) {
	shouldDelete := func(_, value []byte) bool {
		var row indexEntry
		if gob.NewDecoder(bytes.NewReader(value)).Decode(&row) != nil {
			// Undecodable row: drop it.
			return true
		}
		return !row.At.After(cutoff)
	}
	return p.db.PHPRelaySweep(msgIndexBucket, shouldDelete)
}
// run is the persister's background loop. It accumulates queued ops and
// commits them when the batch is full, when the flush ticker fires, when a
// Flush request arrives, or when Stop is called. doneCh is closed on exit.
//
// On stop the queue is drained before the final commit. Without that,
// every op that Enqueue had accepted but run had not yet received would be
// silently lost at shutdown, even though the caller was told the op was
// queued rather than dropped.
func (p *msgIndexPersister) run() {
	defer close(p.doneCh)
	ticker := time.NewTicker(p.flushEvery)
	defer ticker.Stop()

	var pending []persistOp

	// drainQueue moves everything currently buffered in the queue into
	// pending without blocking.
	drainQueue := func() {
		for {
			select {
			case op := <-p.queue:
				pending = append(pending, op)
			default:
				return
			}
		}
	}

	for {
		select {
		case <-p.stopCh:
			drainQueue()
			p.commitBatch(pending)
			return
		case op := <-p.queue:
			pending = append(pending, op)
			if len(pending) >= p.batchSize {
				p.commitBatch(pending)
				pending = pending[:0]
			}
		case <-ticker.C:
			if len(pending) > 0 {
				p.commitBatch(pending)
				pending = pending[:0]
			}
		case done := <-p.flushReqCh:
			// Drain queue before flushing so the caller observes every
			// op that was enqueued before it asked.
			drainQueue()
			p.commitBatch(pending)
			pending = pending[:0]
			close(done)
		}
	}
}
// commitBatch gob-encodes ops and writes them to the msgindex bucket in a
// single PHPRelayPutBatch call.
//
// An op that fails to encode is reported and skipped; the rest of the batch
// still goes through. If nothing survives encoding, no store call is made.
// Every failure (encode or commit) bumps errorsTotal, the error metric when
// a sink is set, and emits a Critical finding through onError. The commit
// failure message reports the number of rows actually submitted.
func (p *msgIndexPersister) commitBatch(ops []persistOp) {
	if len(ops) == 0 {
		return
	}

	// reportFailure is the shared bookkeeping for both failure kinds.
	reportFailure := func(msg string) {
		atomic.AddUint64(&p.errorsTotal, 1)
		if p.metrics != nil {
			p.metrics.MsgindexPersistErrors.Inc()
		}
		p.onError(alert.Finding{
			Severity: alert.Critical,
			Check:    "email_php_relay_msgindex_persist_failed",
			Message:  msg,
		})
	}

	kvs := make([]store.PHPRelayKV, 0, len(ops))
	var buf bytes.Buffer
	for i := range ops {
		op := &ops[i]
		buf.Reset()
		// A fresh encoder per row keeps each value independently decodable.
		if err := gob.NewEncoder(&buf).Encode(&op.entry); err != nil {
			// Encoding failure is a code bug, not a transient I/O issue.
			// Skip the offending op and continue with the rest of the batch.
			reportFailure(fmt.Sprintf("encode %s: %v", op.msgID, err))
			continue
		}
		kvs = append(kvs, store.PHPRelayKV{
			Key: []byte(op.msgID),
			// Copy: buf is reused on the next iteration.
			Value: append([]byte(nil), buf.Bytes()...),
		})
	}

	if len(kvs) == 0 {
		// Every op failed to encode; there is nothing to write.
		return
	}
	if err := p.db.PHPRelayPutBatch(msgIndexBucket, kvs); err != nil {
		reportFailure(fmt.Sprintf("phprelay:msgindex commit failed (%d ops): %v", len(kvs), err))
	}
}
package daemon
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"os"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// phpRelayHistoryMaxLineBytes caps how many bytes of a single log line are
// buffered; longer lines are skipped in full.
const phpRelayHistoryMaxLineBytes = 1024 * 1024

// ScanEximHistoryForPHPRelayAccountVolume feeds every line of the file at
// path through eng.parsePHPRelayAccountVolume and hands each resulting
// finding to emit.
//
// The file is streamed through a 64 KiB buffered reader. A line longer than
// phpRelayHistoryMaxLineBytes is discarded in full and scanning resumes with
// the next line. A final line without a trailing newline is still
// processed. If the file cannot be opened, or a read fails with anything
// other than EOF, the function returns silently.
//
// now is passed to the parser unchanged so replays are deterministic. ctx
// is polled before every read; a nil ctx means the scan is never cancelled.
func ScanEximHistoryForPHPRelayAccountVolume(ctx context.Context, path string, eng *evaluator, now time.Time, emit func(alert.Finding)) {
	// #nosec G304 -- path is operator-configured / hardcoded to cPanel default.
	f, err := os.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()

	reader := bufio.NewReaderSize(f, 64*1024)
	var (
		acc      []byte // bytes of the current line gathered so far
		skipping bool   // true while discarding the rest of an oversized line
	)
	for !phpRelayScanContextDone(ctx) {
		chunk, rerr := reader.ReadSlice('\n')
		if !skipping && len(chunk) > 0 {
			if len(acc)+len(chunk) <= phpRelayHistoryMaxLineBytes {
				// ReadSlice's result is only valid until the next read, so copy.
				acc = append(acc, chunk...)
			} else {
				skipping, acc = true, nil
			}
		}

		if errors.Is(rerr, bufio.ErrBufferFull) {
			// Line continues past the reader's buffer; keep reading.
			continue
		}
		if rerr != nil {
			// EOF flushes an unterminated trailing line; any other error
			// just ends the scan.
			if errors.Is(rerr, io.EOF) && !skipping && len(acc) > 0 {
				emitPHPRelayHistoryLine(acc, eng, now, emit)
			}
			return
		}

		// A complete line ended at this chunk.
		if !skipping {
			emitPHPRelayHistoryLine(acc, eng, now, emit)
		}
		acc, skipping = nil, false
	}
}
// phpRelayScanContextDone reports, without blocking, whether ctx has been
// cancelled or has expired. A nil ctx is treated as never done.
func phpRelayScanContextDone(ctx context.Context) bool {
	return ctx != nil && ctx.Err() != nil
}
// emitPHPRelayHistoryLine strips one trailing "\n" and then one trailing
// "\r" from line, runs the result through eng.parsePHPRelayAccountVolume,
// and passes every finding the parser returns to emit.
func emitPHPRelayHistoryLine(line []byte, eng *evaluator, now time.Time, emit func(alert.Finding)) {
	trimmed := bytes.TrimSuffix(bytes.TrimSuffix(line, []byte("\n")), []byte("\r"))
	findings := eng.parsePHPRelayAccountVolume(string(trimmed), now)
	for _, finding := range findings {
		emit(finding)
	}
}
package daemon
import (
"bytes"
"encoding/gob"
"errors"
"time"
"github.com/pidginhost/csm/internal/store"
)
const (
// settingsBucket is the bucket name passed to the store's PHPRelay
// helpers for settings rows.
settingsBucket = "phprelay:settings"
// dryRunOverrideKey is the key of the dry-run override row inside
// settingsBucket.
dryRunOverrideKey = "dry_run_override"
)
// dryRunOverrideRow is the gob-encoded value stored under
// dryRunOverrideKey.
type dryRunOverrideRow struct {
// Value is the override itself.
Value bool
// UpdatedAt is set to time.Now() by writeDryRunOverride.
UpdatedAt time.Time
// UpdatedBy is the caller-supplied identity passed to
// writeDryRunOverride.
UpdatedBy string
}
// writeDryRunOverride stores val as the dry-run override, stamping the row
// with the current time and with by as the author. It returns an error
// when db is nil, when gob encoding fails, or when the store write fails.
func writeDryRunOverride(db *store.DB, val bool, by string) error {
	if db == nil {
		return errors.New("db nil")
	}

	row := dryRunOverrideRow{
		Value:     val,
		UpdatedAt: time.Now(),
		UpdatedBy: by,
	}
	var encoded bytes.Buffer
	enc := gob.NewEncoder(&encoded)
	if err := enc.Encode(row); err != nil {
		return err
	}
	return db.PHPRelayPut(settingsBucket, dryRunOverrideKey, encoded.Bytes())
}
// deleteDryRunOverride removes the dry-run override row. A nil db is
// treated as "nothing to delete" and returns nil.
func deleteDryRunOverride(db *store.DB) error {
	if db != nil {
		return db.PHPRelayDelete(settingsBucket, dryRunOverrideKey)
	}
	return nil
}
// readDryRunOverride loads the dry-run override row.
//
// Returns (value, true, nil) when a row exists and decodes,
// (false, false, nil) when db is nil or no row is stored, and
// (false, false, err) when the store read or the gob decode fails.
func readDryRunOverride(db *store.DB) (bool, bool, error) {
	if db == nil {
		return false, false, nil
	}

	raw, found, err := db.PHPRelayGet(settingsBucket, dryRunOverrideKey)
	switch {
	case err != nil:
		return false, false, err
	case !found:
		return false, false, nil
	}

	var stored dryRunOverrideRow
	dec := gob.NewDecoder(bytes.NewReader(raw))
	if derr := dec.Decode(&stored); derr != nil {
		return false, false, derr
	}
	return stored.Value, true, nil
}
//go:build linux
package daemon
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/emailspool"
)
// spoolWatcher watches /var/spool/exim/input for new -H files. cPanel hashes
// msgIDs into 64+ subdirs; the watcher enumerates them at start and watches
// IN_CREATE on the parent so subdirs that appear later are also picked up.
//
// On every IN_CLOSE_WRITE / IN_MOVED_TO whose name ends in "-H", the
// supplied callback is invoked synchronously with the absolute path. The
// callback must not block long; spawn worker goroutines if needed.
type spoolWatcher struct {
// root is the spool directory whose immediate subdirectories are watched.
root string
// onFile is called from Run with the full path of each "-H" file event.
onFile func(path string)
// fd is the inotify file descriptor (opened non-blocking, close-on-exec).
fd int
// parentW is the watch descriptor registered on root itself.
parentW int
// mu guards subDirs.
mu sync.Mutex
subDirs map[int]string // watch descriptor -> path
// overflowCount is incremented by Run for each IN_Q_OVERFLOW event.
overflowCount uint64
onOverflow func() // invoked from Run() the moment IN_Q_OVERFLOW arrives
// metrics is optional; nil disables metric observation.
metrics *phpRelayMetrics
}
// SetOverflowHandler wires the recovery scan + Critical finding emission
// into the watcher. Caller passes a closure that calls runRecoveryScan
// against the spool root and emits findings via the daemon alerter.
//
// The handler runs synchronously on the Run goroutine, so event draining
// pauses until it returns. The field is read by Run without locking; set
// it before Run is started.
func (w *spoolWatcher) SetOverflowHandler(fn func()) { w.onOverflow = fn }
// SetMetrics wires the phpRelayMetrics sink. Optional -- nil disables
// observation (used by tests). Must be called before Run; concurrent
// invocation after Run is not safe.
//
// Run reads the field without locking when it counts an inotify overflow.
func (w *spoolWatcher) SetMetrics(m *phpRelayMetrics) { w.metrics = m }
// newSpoolWatcher opens a non-blocking inotify instance, registers a watch
// on root for newly created or moved-in entries, and adds a watch for each
// subdirectory that already exists. A subdirectory whose watch cannot be
// added is skipped; failing to initialise inotify, to watch root, or to
// list root is fatal and closes the descriptor before returning the error.
//
// onFile is stored and later invoked by Run for each "-H" file event.
func newSpoolWatcher(root string, onFile func(path string)) (*spoolWatcher, error) {
	fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
	if err != nil {
		return nil, fmt.Errorf("inotify_init1: %w", err)
	}

	parentW, err := unix.InotifyAddWatch(fd, root, uint32(unix.IN_CREATE|unix.IN_MOVED_TO))
	if err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("inotify_add_watch %s: %w", root, err)
	}

	// Enumerate existing subdirs.
	existing, err := os.ReadDir(root)
	if err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("readdir %s: %w", root, err)
	}

	w := &spoolWatcher{
		root:    root,
		onFile:  onFile,
		fd:      fd,
		parentW: parentW,
		subDirs: make(map[int]string),
	}
	for _, de := range existing {
		if de.IsDir() {
			// Non-fatal -- continue with what we have.
			_ = w.addSubdir(filepath.Join(root, de.Name()))
		}
	}
	return w, nil
}
// addSubdir registers an inotify watch on path for IN_CLOSE_WRITE and
// IN_MOVED_TO and records the resulting watch descriptor so Run can map
// events back to the directory. Returns the inotify error, if any.
func (w *spoolWatcher) addSubdir(path string) error {
	wd, err := unix.InotifyAddWatch(w.fd, path, uint32(unix.IN_CLOSE_WRITE|unix.IN_MOVED_TO))
	if err != nil {
		return err
	}

	w.mu.Lock()
	defer w.mu.Unlock()
	w.subDirs[wd] = path
	return nil
}
// Close releases the inotify descriptor. A watcher whose fd is 0 (the
// zero value) is treated as never opened and Close returns nil.
func (w *spoolWatcher) Close() error {
	if w.fd == 0 {
		return nil
	}
	return unix.Close(w.fd)
}
// Run drains inotify events until ctx is cancelled or a read fails with an
// error other than EAGAIN/EINTR. The inotify descriptor is closed on
// return.
//
// For each event:
//   - IN_Q_OVERFLOW bumps the overflow counter and metric and invokes the
//     overflow handler synchronously.
//   - An event on the root watch that names a directory adds a watch for it.
//   - An event on a known subdirectory whose name ends in "-H" is passed to
//     onFile as a full path.
//
// When no events are pending the loop waits with poll(2) for up to 100ms so
// cancellation stays responsive. poll is used rather than select because an
// fd_set only covers descriptors below FD_SETSIZE; indexing FdSet.Bits with
// a higher-numbered inotify fd would panic.
func (w *spoolWatcher) Run(ctx context.Context) {
	defer func() { _ = w.Close() }()

	const idleWaitMillis = 100
	buf := make([]byte, 16*1024)
	for {
		if ctx.Err() != nil {
			return
		}

		n, err := syscall.Read(w.fd, buf)
		if err != nil {
			if !errors.Is(err, syscall.EAGAIN) && !errors.Is(err, syscall.EINTR) {
				// Treat other errors as fatal; supervisor will restart us.
				return
			}
			// Nothing to read yet: sleep until the fd is readable or the
			// timeout elapses, then loop (which re-checks ctx).
			// #nosec G115 -- file descriptors fit in int32.
			pfd := []unix.PollFd{{Fd: int32(w.fd), Events: unix.POLLIN}}
			_, _ = unix.Poll(pfd, idleWaitMillis)
			continue
		}

		offset := 0
		for offset+unix.SizeofInotifyEvent <= n {
			// #nosec G103 -- bounds-checked above; standard inotify decode pattern.
			ev := (*unix.InotifyEvent)(unsafe.Pointer(&buf[offset]))
			nameStart := offset + unix.SizeofInotifyEvent
			nameEnd := nameStart + int(ev.Len)
			if nameEnd > n {
				// Record claims more name bytes than were read; stop
				// decoding this buffer instead of reading stale data.
				break
			}
			name := strings.TrimRight(string(buf[nameStart:nameEnd]), "\x00")
			offset = nameEnd

			if ev.Mask&unix.IN_Q_OVERFLOW != 0 {
				// Atomic: OverflowCount may be called from another goroutine.
				atomic.AddUint64(&w.overflowCount, 1)
				if w.metrics != nil {
					w.metrics.InotifyOverflows.Inc()
				}
				if w.onOverflow != nil {
					w.onOverflow()
				}
				continue
			}

			if int(ev.Wd) == w.parentW {
				if ev.Mask&(unix.IN_CREATE|unix.IN_MOVED_TO) != 0 && name != "" {
					full := filepath.Join(w.root, name)
					if fi, err := os.Stat(full); err == nil && fi.IsDir() {
						_ = w.addSubdir(full)
					}
				}
				continue
			}

			w.mu.Lock()
			dir, ok := w.subDirs[int(ev.Wd)]
			w.mu.Unlock()
			if !ok || name == "" {
				continue
			}
			if !strings.HasSuffix(name, "-H") {
				continue
			}
			w.onFile(filepath.Join(dir, name))
		}
	}
}

// OverflowCount returns the number of IN_Q_OVERFLOW events observed.
// Used by the daemon to drive recovery scans (Task I3). Safe to call while
// Run is active (atomic load).
//
//nolint:unused // consumed by daemon wiring (Task O2)
func (w *spoolWatcher) OverflowCount() uint64 {
	return atomic.LoadUint64(&w.overflowCount)
}
// spoolPipeline ties together: parse headers -> compute signals -> update
// windows -> evaluate paths -> emit findings via alerter callback.
type spoolPipeline struct {
// eng owns the per-script and per-IP windows that OnFile updates and
// evaluates.
eng *evaluator
// domains resolves the envelope user to the value passed to
// computeSignals.
domains *userDomainsResolver
// policies is passed to computeSignals and consulted for proxy IPs;
// may be nil.
policies *emailspool.Policies
// msgIndex deduplicates repeat events for the same msgID; may be nil.
msgIndex *msgIDIndex
// ignores lists script keys whose mail is skipped; may be nil.
ignores *ignoreList
// alerter receives every finding produced by OnFile and by the startup
// re-evaluation pass.
alerter func(alert.Finding)
// rebuilding, while true, makes OnFile update state without evaluating
// or emitting findings.
rebuilding atomic.Bool
}
// newSpoolPipeline installs pol on eng and returns a pipeline wired to the
// supplied collaborators. The pipeline starts with rebuilding disabled.
func newSpoolPipeline(eng *evaluator, domains *userDomainsResolver, pol *emailspool.Policies, idx *msgIDIndex, ignores *ignoreList, alerter func(alert.Finding)) *spoolPipeline {
	eng.SetPolicies(pol)

	p := new(spoolPipeline)
	p.eng = eng
	p.domains = domains
	p.policies = pol
	p.msgIndex = idx
	p.ignores = ignores
	p.alerter = alerter
	return p
}
// SetRebuilding gates finding emission during the startup spool-walker
// rebuild pass. When true: state is updated, findings are NOT emitted.
// The flag is an atomic.Bool, so it may be toggled while OnFile runs on
// another goroutine.
func (p *spoolPipeline) SetRebuilding(v bool) { p.rebuilding.Store(v) }
// OnFile is the inotify callback. Parses, signals, updates state, evaluates.
//
// The message is silently skipped when its headers cannot be parsed (the
// read-error metric is bumped if available), when it carries no
// X-PHP-Script value, when the path is not a "-H" file, when the msgID is
// already in the index, when no script key can be derived, or when the
// script key is on the ignore list. Otherwise the msgID is indexed, the
// per-script and per-IP windows are updated, and -- unless a rebuild is in
// progress -- every finding from evaluatePaths is handed to the alerter.
func (p *spoolPipeline) OnFile(path string) {
h, err := emailspool.ParseHeaders(path)
if err != nil {
if p.eng != nil && p.eng.metrics != nil {
p.eng.metrics.SpoolReadErrors.Inc()
}
return
}
// Only mail that carries an X-PHP-Script value is of interest here.
if h.XPHPScript == "" {
return
}
msgID := msgIDFromPath(path)
if msgID == "" {
return
}
if p.msgIndex != nil && p.msgIndex.Has(msgID) {
return // queue-runner re-write dedup
}
// The resolver's second return value is deliberately ignored; whatever
// it yields for the user is passed straight to computeSignals.
auth, _ := p.domains.Domains(h.EnvelopeUser)
sig := computeSignals(h, auth, p.policies)
if sig.ScriptKey == "" {
return
}
if p.ignores != nil && p.ignores.Has(sig.ScriptKey) {
return
}
// A single timestamp is shared by the index entry, the window events and
// the evaluation below.
now := time.Now()
if p.msgIndex != nil {
p.msgIndex.Put(msgID, indexEntry{
ScriptKey: string(sig.ScriptKey),
SourceIP: sig.SourceIP,
CPUser: h.EnvelopeUser,
At: now,
})
}
state := p.eng.scripts.getOrCreate(sig.ScriptKey)
state.append(scriptEvent{
At: now,
MsgID: msgID,
Subject: truncateDaemon(h.Subject, phpRelayBreakdownSubjectMax),
FromMismatch: sig.FromMismatch,
AdditionalSignal: sig.AdditionalSignal,
SourceIP: sig.SourceIP,
})
state.recordActive(msgID, now)
// Source IPs that the policies classify as proxies are kept out of the
// per-IP window; with no policies loaded every IP is recorded.
if p.policies == nil || !p.policies.IsProxyIP(sig.SourceIP) {
p.eng.ips.append(sig.SourceIP, sig.ScriptKey, now, h.Subject)
}
// During a rebuild the state above is still updated, but nothing is
// evaluated or emitted.
if p.rebuilding.Load() {
return
}
findings := p.eng.evaluatePaths(sig.ScriptKey, sig.SourceIP, h.EnvelopeUser, now)
for _, f := range findings {
p.alerter(f)
}
}
// msgIDFromPath extracts the message ID from a spool header path of the
// form /dir/<msgID>-H. It returns "" when the final path element does not
// end in "-H".
func msgIDFromPath(path string) string {
	const headerSuffix = "-H"
	name := filepath.Base(path)
	if strings.HasSuffix(name, headerSuffix) {
		return name[:len(name)-len(headerSuffix)]
	}
	return ""
}
// runRecoveryScan walks every -H file under spoolRoot/*/, orders them by
// mtime (oldest first) and invokes onFile on at most maxFiles of them. It
// returns the number of files passed to onFile and whether the cap was hit.
//
// When more than maxFiles candidates exist, the NEWEST maxFiles are kept
// and older messages are skipped. A negative maxFiles is treated as 0.
// Files with equal mtimes keep their directory-listing order.
//
// Unreadable subdirectories and files that cannot be stat'ed are skipped;
// an unreadable spoolRoot yields (0, false).
//
//nolint:unused // consumed by daemon wiring (Task O2)
func runRecoveryScan(spoolRoot string, maxFiles int, onFile func(string)) (int, bool) {
	subs, err := os.ReadDir(spoolRoot)
	if err != nil {
		return 0, false
	}
	if maxFiles < 0 {
		// A negative cap would otherwise panic in the slice expression below.
		maxFiles = 0
	}

	type entry struct {
		path string
		mod  time.Time
	}
	var entries []entry
	for _, sub := range subs {
		if !sub.IsDir() {
			continue
		}
		subPath := filepath.Join(spoolRoot, sub.Name())
		files, err := os.ReadDir(subPath)
		if err != nil {
			continue
		}
		for _, f := range files {
			if !strings.HasSuffix(f.Name(), "-H") {
				continue
			}
			full := filepath.Join(subPath, f.Name())
			fi, err := os.Stat(full)
			if err != nil {
				// The file may have been delivered and removed meanwhile.
				continue
			}
			entries = append(entries, entry{path: full, mod: fi.ModTime()})
		}
	}

	sort.SliceStable(entries, func(i, j int) bool { return entries[i].mod.Before(entries[j].mod) })

	truncated := len(entries) > maxFiles
	if truncated {
		// Keep the tail: the most recent files are the ones whose events
		// were most likely lost.
		entries = entries[len(entries)-maxFiles:]
	}
	for _, e := range entries {
		onFile(e.path)
	}
	return len(entries), truncated
}
// runStartupSpoolWalker replays every "-H" file currently under
// spoolRoot/*/ through the pipeline with rebuilding enabled, so window
// state is reconstructed without emitting anything. It then turns
// rebuilding off and evaluates each reconstructed script state exactly
// once, forwarding the resulting findings to the alerter. Findings are
// therefore produced only by that second pass.
//
// An unreadable spoolRoot or subdirectory is skipped silently.
func runStartupSpoolWalker(spoolRoot string, p *spoolPipeline) {
	p.SetRebuilding(true)
	if subs, err := os.ReadDir(spoolRoot); err == nil {
		for _, sub := range subs {
			if !sub.IsDir() {
				continue
			}
			dir := filepath.Join(spoolRoot, sub.Name())
			files, rerr := os.ReadDir(dir)
			if rerr != nil {
				continue
			}
			for _, f := range files {
				if strings.HasSuffix(f.Name(), "-H") {
					p.OnFile(filepath.Join(dir, f.Name()))
				}
			}
		}
	}
	p.SetRebuilding(false)

	// Re-evaluation pass.
	now := time.Now()
	for key, state := range p.eng.scripts.Snapshot() {
		// The snapshot carries no per-script source IP, so evaluatePaths
		// is called with an empty one. The per-IP window was already
		// populated by OnFile above; only the finding produced here lacks
		// the IP.
		//
		// Best-effort cpuser: taken from the index entry of the first
		// active msgID, if there is one.
		var cpuser string
		if p.msgIndex != nil {
			ids, _ := state.snapshotActiveMsgs()
			if len(ids) > 0 {
				if entry, found := p.msgIndex.Get(ids[0]); found {
					cpuser = entry.CPUser
				}
			}
		}
		for _, finding := range p.eng.evaluatePaths(key, "", cpuser, now) {
			p.alerter(finding)
		}
	}
}
// spoolSupervisor runs fn repeatedly, recovering from panics. fn is invoked
// at most maxRestarts+1 times; every return -- panic or normal -- while the
// context is still live counts as one attempt. When the last attempt ends,
// OnFailed is called (used to emit a Critical finding
// email_php_relay_watcher_failed).
type spoolSupervisor struct {
	fn          func(ctx context.Context)
	maxRestarts int
	OnFailed    func()
}

// newSpoolSupervisor returns a supervisor for fn with a no-op OnFailed.
//
//nolint:unused // consumed by daemon wiring (Task O2)
func newSpoolSupervisor(fn func(ctx context.Context), maxRestarts int) *spoolSupervisor {
	s := &spoolSupervisor{fn: fn, maxRestarts: maxRestarts}
	s.OnFailed = func() {}
	return s
}

// Run executes the supervised function until ctx is cancelled or the
// attempts are used up. Between attempts it sleeps for a delay that starts
// at 100ms and doubles after each attempt for as long as it is below 5s.
// OnFailed is not called when ctx is cancelled.
func (s *spoolSupervisor) Run(ctx context.Context) {
	// guarded runs fn once and swallows any panic it raises.
	guarded := func() {
		defer func() {
			_ = recover() // Panic recovered; loop will sleep + retry.
		}()
		s.fn(ctx)
	}

	delay := 100 * time.Millisecond
	for attempt := 0; attempt <= s.maxRestarts; attempt++ {
		if ctx.Err() != nil {
			return
		}

		guarded()

		switch {
		case ctx.Err() != nil:
			return
		case attempt == s.maxRestarts:
			s.OnFailed()
			return
		}

		select {
		case <-ctx.Done():
			return
		case <-time.After(delay):
		}
		if delay < 5*time.Second {
			delay *= 2
		}
	}
}
//go:build linux
package daemon
import (
"context"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/emailspool"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/store"
)
// phpRelayAuditPath is the production audit file. Linux-only because the
// writer is only attached from startPHPRelayLinux; tests pass t.TempDir
// paths directly to eximAuditWriterAt.
const phpRelayAuditPath = "/var/log/csm/php_relay_audit.jsonl"
// startPHPRelayLinux completes the PHP-relay wiring after the platform
// gate (in startPHPRelay) has confirmed cPanel + located exim. It is
// split out from daemon.go so the heavy linux-only types stay in this
// file; the cross-platform stub in php_relay_wiring_other.go keeps the
// darwin build clean.
//
// The whole pipeline is built up exactly once at daemon start; SIGHUP
// only reloads policies (handled in daemon.go where d.policies is
// already wired). Callers must not invoke this twice.
//
// Three goroutines are started and tracked on d.wg: the watcher
// supervisor, the exim_mainlog history scan and the Flow E maintenance
// ticker. All three stop when d.stopCh closes. The persister's Stop is
// appended to d.phpRelayShutdown.
func startPHPRelayLinux(d *Daemon) {
// 1. cPanel hourly limit + Path 2b derivation.
limit, status := readCpanelHourlyLimit("/var/cpanel/cpanel.config")
switch status {
case cpanelLimitMissing, cpanelLimitUnparsable:
emitPHPRelayFinding(d, alert.Warning, "email_php_relay_cpanel_limit_unreadable",
"cpanel.config maxemailsperhour unreadable; assuming 100")
}
// 2. Policies (suspicious mailers, HTTP-proxy ranges). LoadPolicies
// returns a usable Policies even on partial-failure so we always
// have something to install on the daemon for SIGHUP reloads.
pol, _ := emailspool.LoadPolicies(d.cfg.EmailProtection.PHPRelay.PoliciesDir)
d.policies = pol
// 3. Window state (per-script, per-IP, per-account).
psw := newPerScriptWindow()
pip := newPerIPWindow(64)
pacct := newPerAccountWindow(5000)
// 4. Evaluator. Wires the cPanel-derived effective account limit so
// Path 2b activates as soon as the first message arrives.
prMetrics := newPHPRelayMetrics()
eng := newEvaluator(psw, pip, pacct, d.cfg, prMetrics)
eng.SetPolicies(pol)
eff, enabled, capped := deriveEffectiveAccountLimit(d.cfg, limit, status)
if !enabled {
emitPHPRelayFinding(d, alert.Warning, "email_php_relay_path2b_disabled",
"Path 2b disabled: cPanel limit off and no operator override")
}
if capped {
emitPHPRelayFinding(d, alert.Warning, "email_php_relay_account_volume_capped",
"operator AccountVolumePerHour capped to 95% of cPanel hourly limit")
}
eng.SetEffectiveAccountLimit(eff)
SetPHPRelayEvaluator(eng)
// 5. msgIDIndex + persister. Bbolt access is via store.Global() --
// the daemon does not hold a *store.DB directly; the global handle
// is the same singleton the rest of the codebase uses (sigWatcher,
// retention, etc.).
bdb := store.Global()
persister := newMsgIndexPersister(bdb, 4096, 100*time.Millisecond)
// Non-blocking send: the callback runs on the persister goroutine, so a
// full alert channel drops the finding instead of stalling persistence.
persister.SetErrorCallback(func(f alert.Finding) {
select {
case d.alertCh <- f:
default:
}
})
persister.SetMetrics(prMetrics)
persister.Start()
d.phpRelayShutdown = append(d.phpRelayShutdown, persister.Stop)
idx := newMsgIDIndex(persister, 200_000)
// 6. ignoreList with bbolt-backed restore.
ignores := newIgnoreList()
ignores.SetStore(bdb)
// Restore is best-effort: its result is discarded and startup continues.
_ = ignores.Restore()
// 7. cpanel user domains resolver (Path 4 helper).
domains := newUserDomainsResolver()
// 8. Controller (constructed before the freezer so DryRunFn can
// thread the runtime/bbolt/yaml precedence into freeze decisions).
runner := defaultRunner{}
auditor := newStructuredAuditor(eximAuditWriterAt(d.cfg, phpRelayAuditPath))
controller := &PHPRelayController{
eng: eng,
msgIndex: idx,
ignores: ignores,
actionDryRun: &runtimeBool{},
db: bdb,
runner: runner,
eximBin: eximBinary,
auditor: auditor,
enabled: true,
platform: "cpanel",
}
if d.controlListener != nil {
d.controlListener.phprelay = controller
}
// 9. Spool pipeline (Flow A) + autoFreezer (post-emit hook).
// The alerter drops findings when d.alertCh is full rather than blocking
// the watcher goroutine that calls OnFile.
pipeline := newSpoolPipeline(eng, domains, pol, idx, ignores, func(f alert.Finding) {
select {
case d.alertCh <- f:
default:
}
})
freezer := newAutoFreezer(psw, d.cfg, "/var/spool/exim/input", eximBinary,
runner, auditor, prMetrics, controller.DryRunFn())
d.autoFreezer = freezer
// 10. Startup walker BEFORE the watcher to rebuild script state for
// messages already on the spool when the daemon starts.
runStartupSpoolWalker("/var/spool/exim/input", pipeline)
// watcherFn is one supervised attempt: it builds a fresh watcher and
// blocks in Run until the context ends or the watcher stops.
watcherFn := func(ctx context.Context) {
w, err := newSpoolWatcher("/var/spool/exim/input", pipeline.OnFile)
if err != nil {
d.MarkWatcher("phprelay", false)
emitPHPRelayFinding(d, alert.Critical, "email_php_relay_watcher_failed", err.Error())
return
}
d.MarkWatcher("phprelay", true)
w.SetMetrics(prMetrics)
// The overflow handler is invoked synchronously from the watcher's Run
// loop, so the bounded recovery scan completes before further inotify
// events are drained.
w.SetOverflowHandler(func() {
emitPHPRelayFinding(d, alert.Critical, "email_php_relay_inotify_overflow",
"inotify queue overflow; running bounded recovery scan")
const phpRelayOverflowScanMax = 1000
n, truncated := runRecoveryScan("/var/spool/exim/input", phpRelayOverflowScanMax, pipeline.OnFile)
if truncated {
emitPHPRelayFinding(d, alert.Critical, "email_php_relay_overflow_scan_truncated",
fmt.Sprintf("overflow recovery capped at %d files; older messages skipped (Path 2b backstops)", phpRelayOverflowScanMax))
}
emitPHPRelayFinding(d, alert.Warning, "email_php_relay_inotify_overflow_recovered",
fmt.Sprintf("recovery scan processed %d -H files", n))
})
w.Run(ctx)
}
sup := newSpoolSupervisor(watcherFn, 5)
sup.OnFailed = func() {
d.MarkWatcher("phprelay", false)
emitPHPRelayFinding(d, alert.Critical, "email_php_relay_watcher_failed", "supervisor exhausted restarts")
}
// One context, cancelled when d.stopCh closes, is shared by all three
// background goroutines below.
ctx := stopChContext(d)
d.wg.Add(1)
obs.Go("php-relay-supervisor", func() {
defer d.wg.Done()
sup.Run(ctx)
})
// 11. Retrospective Path 2b scan over exim_mainlog so account-volume
// alerts fire on the first hour boundary even after a daemon
// restart that lost in-memory state. Threaded through ctx so a
// large mainlog on a busy host cannot outlive shutdown.
d.wg.Add(1)
obs.Go("php-relay-history-scan", func() {
defer d.wg.Done()
ScanEximHistoryForPHPRelayAccountVolume(ctx, "/var/log/exim_mainlog", eng, time.Now(), func(f alert.Finding) {
select {
case d.alertCh <- f:
default:
}
})
})
// 12. Flow E maintenance ticker.
d.wg.Add(1)
obs.Go("php-relay-flow-e", func() {
defer d.wg.Done()
runPHPRelayFlowE(d, ctx, psw, pip, pacct, idx, persister, ignores, prMetrics)
})
}
// runPHPRelayFlowE is the Flow E maintenance loop: the single place where
// php_relay TTL sweeps and size gauges are driven. It blocks until ctx is
// cancelled (which happens when d.stopCh closes).
//
// Every minute: sweep the in-memory msgID index (4h cutoff), publish its
// size when metrics are wired, sweep the persisted msgID index (25h
// cutoff) and the ignore list's persisted rows, then expire in-memory
// ignores. Each failed persisted sweep emits a Warning finding.
//
// Every five minutes: prune active msgIDs and idle per-script state (25h),
// idle per-IP state (1h) and idle per-account state (24h).
func runPHPRelayFlowE(
	d *Daemon,
	ctx context.Context,
	psw *perScriptWindow,
	pip *perIPWindow,
	pacct *perAccountWindow,
	idx *msgIDIndex,
	persister *msgIndexPersister,
	ignores *ignoreList,
	m *phpRelayMetrics,
) {
	const sweepFailedCheck = "email_php_relay_sweep_failed"

	everyMinute := time.NewTicker(1 * time.Minute)
	defer everyMinute.Stop()
	everyFiveMinutes := time.NewTicker(5 * time.Minute)
	defer everyFiveMinutes.Stop()

	for {
		select {
		case <-ctx.Done():
			return

		case <-everyMinute.C:
			tick := time.Now()
			_ = idx.SweepMemory(tick.Add(-4 * time.Hour))
			if m != nil {
				m.MsgIDIndexSize.With("memory").Set(float64(idx.Len()))
			}
			if _, err := persister.SweepBolt(tick.Add(-25 * time.Hour)); err != nil {
				emitPHPRelayFinding(d, alert.Warning, sweepFailedCheck, err.Error())
			}
			if _, err := ignores.SweepBolt(tick); err != nil {
				emitPHPRelayFinding(d, alert.Warning, sweepFailedCheck, err.Error())
			}
			ignores.SweepExpired(tick)

		case <-everyFiveMinutes.C:
			tick := time.Now()
			scriptCutoff := tick.Add(-25 * time.Hour)
			psw.PruneActiveMsgs(scriptCutoff)
			psw.SweepIdle(scriptCutoff)
			pip.SweepIdle(tick.Add(-1 * time.Hour))
			pacct.SweepIdle(tick.Add(-24 * time.Hour))
		}
	}
}
// emitPHPRelayFinding builds a finding stamped with the current time and
// offers it to the daemon's alert channel without blocking. If the channel
// cannot accept it right now the finding is dropped silently (matches the
// existing startup-time alert pattern in daemon.go).
func emitPHPRelayFinding(d *Daemon, sev alert.Severity, check, msg string) {
	finding := alert.Finding{
		Severity:  sev,
		Check:     check,
		Message:   msg,
		Timestamp: time.Now(),
	}
	select {
	case d.alertCh <- finding:
	default:
	}
}
// stopChContext adapts the daemon's stopCh (chan struct{}) to a
// context.Context for components that take a ctx (spool watcher, Flow E
// ticker). A helper goroutine waits for stopCh to close and then cancels
// the returned context; that goroutine lives until stopCh closes.
func stopChContext(d *Daemon) context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		<-d.stopCh
	}()
	return ctx
}
package daemon
import (
"math"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/processctx"
)
// Tunables for the daemon-wide process-context cache and enricher built
// by ProcessCtx.
const (
// Capacity and TTL handed to processctx.NewCache.
processCtxCacheCap = 16384
processCtxCacheTTL = 30 * time.Minute
// Worker count and queue capacity handed to the enricher config.
processCtxEnrichWorkers = 2
processCtxEnrichQueueCap = 1024
// Root directory and read deadline handed to processctx.NewProcReader.
processCtxProcRoot = "/proc"
processCtxProcReadDeadline = 10 * time.Millisecond
)
// Lazily initialised singletons; populated inside processCtxOnce by
// ProcessCtx and cleared by resetProcessCtxForTest.
var (
processCtxOnce sync.Once
processCtxCache *processctx.Cache
processCtxEnr *processctx.Enricher
// processCtxRegistry returns the registry that metrics are registered
// on; tests swap it via resetProcessCtxForTest.
processCtxRegistry = metrics.Default
)
// processCtxReadStartedAt is the function processCtxStartedAt calls to
// read a process start time. It is a variable rather than a direct call;
// resetProcessCtxForTest restores it to the default.
var processCtxReadStartedAt = defaultProcessCtxReadStartedAt
// defaultProcessCtxReadStartedAt reads the start time of pid through a
// freshly constructed proc reader rooted at processCtxProcRoot. The boolean
// is whatever the reader reports alongside the time.
func defaultProcessCtxReadStartedAt(pid int) (time.Time, bool) {
	reader := processctx.NewProcReader(processCtxProcRoot, processCtxProcReadDeadline)
	return reader.ReadStartedAt(pid)
}
// processCtxStartedAt returns the start time of pid as reported by
// processCtxReadStartedAt, or the zero time.Time when the reader signals
// that no value is available.
func processCtxStartedAt(pid int) time.Time {
	if startedAt, ok := processCtxReadStartedAt(pid); ok {
		return startedAt
	}
	return time.Time{}
}
// ProcessCtx returns the daemon-wide process-context cache and enricher.
// The first call builds both, registers their metrics on the registry
// returned by processCtxRegistry, starts the enricher, and wires the
// ancestry probe; later calls return the same pair. Safe for concurrent
// callers (guarded by sync.Once).
func ProcessCtx() (*processctx.Cache, *processctx.Enricher) {
	processCtxOnce.Do(func() {
		processCtxCache = processctx.NewCache(processCtxCacheCap, processCtxCacheTTL)

		procReader := processctx.NewProcReader(processCtxProcRoot, processCtxProcReadDeadline)
		enricherCfg := processctx.EnricherConfig{
			Workers:  processCtxEnrichWorkers,
			QueueCap: processCtxEnrichQueueCap,
			Resolver: daemonProcessIdentityResolver{},
		}
		processCtxEnr = processctx.NewEnricher(processCtxCache, procReader, enricherCfg)

		processctx.RegisterMetrics(processCtxRegistry(), processCtxCache, processCtxEnr)
		processCtxEnr.Start()
		wireAncestryProbeIfAvailable(processCtxCache)
	})
	return processCtxCache, processCtxEnr
}
// daemonProcessIdentityResolver is the identity resolver handed to the
// process-context enricher.
type daemonProcessIdentityResolver struct{}

// Resolve maps uid to a (user, account) pair. A uid outside the uint32
// range yields two empty strings. Otherwise the user comes from
// checks.LookupUser and the account from resolveLocalAccountForUID.
func (daemonProcessIdentityResolver) Resolve(uid int) (string, string) {
	if uid < 0 || uid > math.MaxUint32 {
		return "", ""
	}
	name := checks.LookupUser(uint32(uid))
	return name, resolveLocalAccountForUID(uid, name)
}
// resolveLocalAccountForUID returns the hosting account for a process
// owner, or "" when none applies.
//
// First phase: for normal hosted account UIDs (1000 and above) the
// username doubles as the account on cPanel and plain Linux fallback
// hosts. An empty username, or one beginning with "uid:", is treated as
// unresolved. Later phases can swap in a platform-backed account enumerator
// without touching the processctx package.
func resolveLocalAccountForUID(uid int, user string) string {
	switch {
	case uid < 1000, user == "", strings.HasPrefix(user, "uid:"):
		return ""
	}
	return user
}
// resetProcessCtxForTest is a test seam. Callers in tests must run with
// t.Setenv or similar isolation; production code never invokes this.
//
// It stops the current enricher (if any), re-arms the sync.Once so the
// next ProcessCtx call rebuilds everything, clears the cached singletons,
// and restores the default start-time reader. It is not synchronised with
// concurrent ProcessCtx callers.
func resetProcessCtxForTest() {
if processCtxEnr != nil {
processCtxEnr.Stop()
}
processCtxOnce = sync.Once{}
processCtxCache = nil
processCtxEnr = nil
// After a reset, metrics are registered on whatever metrics.NewRegistry
// returns instead of the default registry.
// NOTE(review): presumably to avoid duplicate registration across tests;
// confirm.
processCtxRegistry = metrics.NewRegistry
processCtxReadStartedAt = defaultProcessCtxReadStartedAt
}
package daemon
import (
"sync"
"time"
)
// purgeSuppressionWindow is how long after a password purge a 401 from an
// IP tied to that account is considered a stale-session artifact.
const purgeSuppressionWindow = 60 * time.Second

// purgeTracker correlates password purge events with subsequent
// stale-session 401 errors to suppress false-positive alerts.
//
// When a cPanel user changes their password, all existing sessions are
// invalidated. Any in-flight browser requests (AJAX polls, notifications,
// etc.) will return 401 - these are expected side effects, not attacks.
//
// Flow: login (NEW) records IP→account, PURGE records account→time,
// 401 handler checks IP→account→purgeTime to decide suppression.
var purgeTracker = &purgeState{
	purges:   make(map[string]time.Time),
	sessions: make(map[string]string),
}

// purgeState holds the two lookup tables behind purgeTracker. All access
// goes through mu.
type purgeState struct {
	mu       sync.Mutex
	purges   map[string]time.Time // account → last purge time
	sessions map[string]string    // IP → last known account
}

// recordLogin tracks which account an IP most recently logged into and
// opportunistically prunes stale state.
//
// The mapping is written again after cleanup because cleanupLocked resets
// the whole sessions map once it exceeds its cap; without the second write
// the very login that triggered the reset would be forgotten and a
// following purge for that IP could never be correlated.
func (ps *purgeState) recordLogin(ip, account string) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.sessions[ip] = account
	ps.cleanupLocked()
	ps.sessions[ip] = account
}

// recordPurge records a password purge event for an account, stamped with
// the current time.
func (ps *purgeState) recordPurge(account string) {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	ps.purges[account] = time.Now()
}

// isPostPurge401 returns true if the IP's 401 is likely a stale session
// artifact from a recent password change: the IP has a recorded login and
// that account was purged less than purgeSuppressionWindow ago.
func (ps *purgeState) isPostPurge401(ip string) bool {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	account, ok := ps.sessions[ip]
	if !ok {
		return false
	}
	purgeTime, ok := ps.purges[account]
	if !ok {
		return false
	}
	return time.Since(purgeTime) < purgeSuppressionWindow
}

// cleanupLocked removes stale entries. Caller must hold ps.mu.
//
// Purges older than twice the suppression window are deleted. The sessions
// map carries no timestamps, so it cannot be trimmed selectively: once it
// grows beyond 500 entries it is discarded wholesale to bound memory.
func (ps *purgeState) cleanupLocked() {
	const maxSessions = 500

	cutoff := time.Now().Add(-2 * purgeSuppressionWindow)
	for k, t := range ps.purges {
		if t.Before(cutoff) {
			delete(ps.purges, k)
		}
	}
	if len(ps.sessions) > maxSessions {
		ps.sessions = make(map[string]string)
	}
}
package daemon
import (
"bytes"
"github.com/pidginhost/csm/internal/checks"
)
// signalEagerReconcile posts a wake-up on sig when count lands exactly on
// threshold. Cross-platform helper extracted so the trigger can be
// unit-tested from a non-linux test file.
//
//   - The send never blocks: if sig (a buffered cap-1 channel) already
//     holds a pending signal, or nobody is receiving, the notification is
//     simply skipped, so a stalled receiver never wedges the caller.
//   - Only count == threshold fires (not >=). A long burst above the
//     threshold within one window must not refire after the signal has
//     been drained; the next window's first count reaching threshold
//     rearms it once the receiver has reset counters.
//   - A nil sig is a no-op (some unit tests construct partial structs
//     that omit it).
func signalEagerReconcile(sig chan struct{}, count, threshold int64) {
	if sig == nil || count != threshold {
		return
	}
	select {
	case sig <- struct{}{}:
	default:
	}
}
// Recognisers that suppress the lowest-tier "anomalous PHP location"
// warning for two specific shapes, without skipping content scanning:
//
// 1. Files inside cPanel's pkgacct/restorepkg staging tree. cPanel
// extracts the user backup as root into /home/cpanelpkgrestore.TMP.
// work.<id>/ for inspection, then re-extracts it under the user
// identity into /home/<account>/. Both extractions raise fanotify
// events; the user-context one carries the real signal, so the
// staging-side warning is a duplicate. The signature/YARA scanners
// still run on the staging file - only the path-only warning is
// dropped.
//
// 2. WP-Optimize probe files at wp-content/uploads/wpo/*. The plugin
// writes tiny <?php files to test whether the host honours certain
// Apache/Nginx directives (Server-Signature, mod_headers, mod_rewrite).
// They contain no input handling and no execution primitives; the
// anomalous-location warning is noise on every site running this
// plugin. As above, the signature/YARA scanners still run.
// looksLikeCpanelRestoreStaging delegates to the shared recogniser in
// internal/checks/sitedetect.go. The deep-scan path uses the same
// helper so realtime and scheduled scans agree on which files are
// duplicates of the user-context extraction.
//
// It returns exactly what checks.LooksLikeCpanelRestoreStaging reports
// for path; no additional logic is applied here.
func looksLikeCpanelRestoreStaging(path string) bool {
return checks.LooksLikeCpanelRestoreStaging(path)
}
// wpOptimizeProbeMaxSize bounds the size of files the recogniser will
// accept. WP-Optimize probes are header()/echo one-liners; anything
// larger fails the shape gate and falls through to the standard
// anomalous-location warning.
//
// The value is a byte count and the bound is inclusive: a body of exactly
// this many bytes still passes the size gate.
const wpOptimizeProbeMaxSize = 512
// wpOptimizeProbeDangerous is the deny list of byte sequences that, if
// present, disqualify a file from being treated as a WP-Optimize probe.
// Probes never use PHP superglobals or execution primitives; an attacker
// payload that does (the only realistic way to abuse a 512-byte file in
// /uploads/wpo/) trips this gate and continues to the standard alert.
//
// Tokens are matched case-insensitively against the file body. They are
// kept separate from the signature scanner above this recogniser so the
// gate stays valid even if a future signature update changes coverage.
//
// Every token must be written in lower case: the matcher lower-cases the
// file body only, then does a plain substring search for each token, so
// an upper-case token would never match.
var wpOptimizeProbeDangerous = [][]byte{
[]byte("$_"), // any PHP superglobal: $_POST, $_GET, $_REQUEST, $_COOKIE, $_SERVER...
[]byte("ev" + "al"), // split to keep the source-tree security hook happy
[]byte("ass" + "ert"),
[]byte("include"),
[]byte("require"),
[]byte("sys" + "tem"),
[]byte("p" + "assthru"),
[]byte("sh" + "ell_exec"),
[]byte("po" + "pen"),
[]byte("proc_open"),
[]byte("e" + "xec"), // plain exec() and any *exec* variant
[]byte("base64"), // any encoder/decoder pair
[]byte("phpinfo"), // information disclosure
[]byte("create_func"), // create_function deprecated lambda primitive
[]byte("file_get"), // file_get_contents (file disclosure)
[]byte("file_put"), // file_put_contents (write primitive)
[]byte("fwrite"),
[]byte("readfile"),
[]byte("`"), // backtick command substitution
}
// looksLikeWPOptimizeProbe is the realtime, content-aware recogniser for
// WP-Optimize probe files. A file qualifies only when all three gates
// pass, evaluated cheapest first:
//
//  1. checks.LooksLikeWPOptimizeProbeByPath accepts the path (the shared
//     path-only gate also used by the deep-scan side).
//  2. The body is at most wpOptimizeProbeMaxSize bytes.
//  3. The lower-cased body contains none of wpOptimizeProbeDangerous.
//
// Gates 2 and 3 are what stop a webshell dropped at a probe path from
// silencing the realtime warning: a payload large or interesting enough
// to be useful trips one of them. The result only decides whether the
// path-based warning is suppressed.
func looksLikeWPOptimizeProbe(path string, content []byte) bool {
	if !checks.LooksLikeWPOptimizeProbeByPath(path) || len(content) > wpOptimizeProbeMaxSize {
		return false
	}
	body := bytes.ToLower(content)
	for i := range wpOptimizeProbeDangerous {
		if bytes.Contains(body, wpOptimizeProbeDangerous[i]) {
			return false
		}
	}
	return true
}
package daemon
import (
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/store"
)
// RetentionResult is the outcome of one RunRetentionOnce invocation: the
// number of entries each bucket sweep removed, plus any errors the sweeps
// returned.
type RetentionResult struct {
	History      int
	AttackEvents int
	Reputation   int
	Errors       []error
}

// Deleted returns the combined number of entries removed across the
// history, attack-event and reputation sweeps.
func (r RetentionResult) Deleted() int {
	total := 0
	for _, removed := range [...]int{r.History, r.AttackEvents, r.Reputation} {
		total += removed
	}
	return total
}
// The mapping below records which config knob drives which bucket sweep.
// Keeping it explicit rather than inferring it from field names keeps the
// mapping documented next to the orchestrator.
//
// Retention bucket mapping:
// - HistoryDays → `history` bucket (the finding archive every scan appends to)
// - FindingsDays → `attacks:events` bucket (per-attack event trail feeding scoring)
// - ReputationDays → `reputation` bucket (AbuseIPDB / local lookup cache, keyed by IP)
//
// Blocked IPs are deliberately NOT on a TTL: `fw:blocked` is pruned when
// an operator or auto-response unblocks an IP, and temp-ban expiry is
// already handled by LoadFirewallState.
// RunRetentionOnce runs a single sweep cycle over the retention-managed
// buckets and reports how many entries each sweep removed.
//
// It is a no-op returning the zero RetentionResult when db or cfg is nil
// or when cfg.Retention.Enabled is false.
//
// Each bucket is swept only when its *Days setting is positive; zero (or
// a negative value, which config validation is expected to reject) means
// "leave this bucket alone on this cycle". The cutoff for a bucket is
// now minus days*24h. A sweep error is appended to result.Errors and the
// remaining buckets are still swept.
func RunRetentionOnce(db *store.DB, cfg *config.Config, now time.Time) RetentionResult {
	const day = 24 * time.Hour

	var result RetentionResult
	if db == nil || cfg == nil {
		return result
	}
	if !cfg.Retention.Enabled {
		return result
	}

	if days := cfg.Retention.HistoryDays; days > 0 {
		removed, err := db.SweepHistoryOlderThan(now.Add(-time.Duration(days) * day))
		if err != nil {
			result.Errors = append(result.Errors, err)
		}
		result.History = removed
	}

	if days := cfg.Retention.FindingsDays; days > 0 {
		removed, err := db.SweepAttackEventsOlderThan(now.Add(-time.Duration(days) * day))
		if err != nil {
			result.Errors = append(result.Errors, err)
		}
		result.AttackEvents = removed
	}

	if days := cfg.Retention.ReputationDays; days > 0 {
		removed, err := db.SweepReputationOlderThan(now.Add(-time.Duration(days) * day))
		if err != nil {
			result.Errors = append(result.Errors, err)
		}
		result.Reputation = removed
	}

	return result
}
// Retention metrics state.
//
// retentionSweepDurationOnce guards /metrics registration of the
// retention counters so repeated daemon starts in a test binary are
// idempotent. The two counters are nil until registerRetentionMetrics
// has run.
var (
	retentionSweepDurationOnce sync.Once
	retentionSweepCounter      *metrics.Counter
	retentionDeletedCounter    *metrics.Counter
)
// registerRetentionMetrics creates and registers the retention counters.
// The work runs at most once per process; later calls return immediately.
func registerRetentionMetrics() {
	retentionSweepDurationOnce.Do(func() {
		const (
			sweepsName  = "csm_retention_sweeps_total"
			deletedName = "csm_retention_deleted_total"
		)

		retentionSweepCounter = metrics.NewCounter(
			sweepsName,
			"Number of retention sweep cycles completed since daemon start.",
		)
		metrics.MustRegister(sweepsName, retentionSweepCounter)

		retentionDeletedCounter = metrics.NewCounter(
			deletedName,
			"Number of bucket entries deleted by the retention sweep.",
		)
		metrics.MustRegister(deletedName, retentionDeletedCounter)
	})
}
// retentionScanner is the daemon goroutine that drives the retention
// sweep on the configured SweepInterval. It is started from Run() only
// when cfg.Retention.Enabled is true; otherwise no timer is ever armed.
// It returns when d.stopCh is closed and marks d.wg done on exit.
//
// Compaction is NOT triggered from here: reclaiming space safely requires
// the daemon to close and reopen the bbolt handle under coordinated
// exclusive access, which is the job of `csm store compact` with the
// daemon stopped. This goroutine only emits an info log when the file
// crosses CompactMinSizeMB so operators know a compact is due.
func (d *Daemon) retentionScanner() {
	defer d.wg.Done()
	registerRetentionMetrics()

	// The first sweep waits out a short settle period so a restart storm
	// does not hammer bbolt; every later sweep uses the configured
	// interval, re-read after each tick.
	const settle = 5 * time.Minute
	timer := time.NewTimer(settle)
	defer timer.Stop()

	for {
		select {
		case <-timer.C:
			d.runRetentionTick()
			timer.Reset(d.retentionInterval())
		case <-d.stopCh:
			return
		}
	}
}
// retentionInterval returns the sweep cadence parsed from
// Retention.SweepInterval in the live config. It falls back to 24h when
// there is no config, when the value does not parse as a duration, or
// when it is zero or negative.
//
// The live config is consulted on every call so a SIGHUP can adjust the
// cadence without a restart (the retention struct itself stays
// hotreload:"restart", but the timer picks up edits on the next cycle).
func (d *Daemon) retentionInterval() time.Duration {
	const fallback = 24 * time.Hour

	cfg := d.currentCfg()
	if cfg == nil {
		return fallback
	}
	if parsed, err := time.ParseDuration(cfg.Retention.SweepInterval); err == nil && parsed > 0 {
		return parsed
	}
	return fallback
}
// runRetentionTick runs one sweep plus a store size check and emits the
// corresponding metrics and logs.
//
//   - The sweep counter is incremented on every tick, including ticks where
//     RunRetentionOnce no-ops.
//   - The deleted counter and the "retention sweep" info log are emitted
//     only when at least one entry was removed.
//   - Each per-bucket sweep error is logged as a warning.
//   - When retention is enabled, CompactMinSizeMB is positive and the store
//     is open, the file size is compared against that floor and a
//     "compaction recommended" hint is logged once it is reached. A failure
//     to read the size is logged instead of being silently dropped, so a
//     broken size probe does not look like "store is small".
func (d *Daemon) runRetentionTick() {
	// Idempotent (sync.Once). Ensures the counters dereferenced below are
	// non-nil even when a tick is driven without retentionScanner having
	// registered them first.
	registerRetentionMetrics()

	cfg := d.currentCfg()
	db := store.Global()

	result := RunRetentionOnce(db, cfg, time.Now())
	retentionSweepCounter.Inc()
	if n := result.Deleted(); n > 0 {
		retentionDeletedCounter.Add(float64(n))
		csmlog.Info("retention sweep",
			"deleted_total", n,
			"history", result.History,
			"attacks_events", result.AttackEvents,
			"reputation", result.Reputation,
		)
	}
	for _, err := range result.Errors {
		csmlog.Warn("retention sweep bucket error", "err", err)
	}

	// Size check: surface a human-readable hint when the file has grown
	// past the configured floor so operators can plan a maintenance
	// window. Actual compaction is operator-driven via `csm store compact`.
	if cfg == nil || !cfg.Retention.Enabled || cfg.Retention.CompactMinSizeMB <= 0 || db == nil {
		return
	}
	size, err := db.Size()
	if err != nil {
		csmlog.Warn("retention: store size check failed", "err", err)
		return
	}
	minBytes := int64(cfg.Retention.CompactMinSizeMB) * 1024 * 1024
	if size >= minBytes {
		csmlog.Info("retention: compaction recommended; run `csm store compact` during a maintenance window",
			"size_bytes", size,
			"min_bytes", minBytes,
		)
	}
}
//go:build !(linux && bpf)
package daemon
import (
"context"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
)
// sensitiveFileBPF is the no-tag placeholder for the BPF sensitive-file
// monitor. The real type with BPF map handles, link, and the ringbuf reader
// lives in sensitive_file_bpf.go behind //go:build linux && bpf.
//
// The placeholder carries no state; its methods exist only so the type
// has the same method set in builds without the BPF implementation.
type sensitiveFileBPF struct{}
// Mode returns the fixed backend label "bpf".
func (s *sensitiveFileBPF) Mode() string { return "bpf" }
// EventCount always reports zero in this build.
func (s *sensitiveFileBPF) EventCount() uint64 { return 0 }
// Run returns immediately; there is nothing to run in this build.
func (s *sensitiveFileBPF) Run(_ context.Context) {}
// startSensitiveFileBPF never starts a monitor in this build: it ignores
// its arguments and always returns (nil, bpf.ErrNotBuilt).
func startSensitiveFileBPF(_ context.Context, _ chan<- alert.Finding, _ *config.Config) (*sensitiveFileBPF, error) {
return nil, bpf.ErrNotBuilt
}
// Code generated by bpf2go; DO NOT EDIT.
//go:build 386 || amd64
package sensitive_file_bpfprog
import (
"bytes"
_ "embed"
"fmt"
"io"
"structs"
"github.com/cilium/ebpf"
)
// SensitiveFileFileid is a (Dev, Ino) pair, laid out with the host's
// struct layout (structs.HostLayout).
type SensitiveFileFileid struct {
_ structs.HostLayout
Dev uint64
Ino uint64
}
// SensitiveFileSensitiveEvent is the host-layout event record: three
// 32-bit fields, four bytes of explicit padding, two 64-bit fields, and a
// 16-byte Comm array.
type SensitiveFileSensitiveEvent struct {
_ structs.HostLayout
Uid uint32
Pid uint32
Mask uint32
_ [4]byte
Dev uint64
Ino uint64
Comm [16]uint8
}
// LoadSensitiveFile returns the embedded CollectionSpec for SensitiveFile.
//
// The spec is parsed from the embedded _SensitiveFileBytes on every call;
// a parse failure is returned wrapped with a "can't load SensitiveFile"
// prefix.
func LoadSensitiveFile() (*ebpf.CollectionSpec, error) {
reader := bytes.NewReader(_SensitiveFileBytes)
spec, err := ebpf.LoadCollectionSpecFromReader(reader)
if err != nil {
return nil, fmt.Errorf("can't load SensitiveFile: %w", err)
}
// err is nil here; it is returned as-is by the generator's template.
return spec, err
}
// LoadSensitiveFileObjects loads SensitiveFile and converts it into a struct.
//
// The following types are suitable as obj argument:
//
// *SensitiveFileObjects
// *SensitiveFilePrograms
// *SensitiveFileMaps
//
// See ebpf.CollectionSpec.LoadAndAssign documentation for details.
func LoadSensitiveFileObjects(obj interface{}, opts *ebpf.CollectionOptions) error {
spec, err := LoadSensitiveFile()
if err != nil {
return err
}
return spec.LoadAndAssign(obj, opts)
}
// SensitiveFileSpecs contains maps and programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// It embeds the program, map and variable spec structs declared below.
type SensitiveFileSpecs struct {
SensitiveFileProgramSpecs
SensitiveFileMapSpecs
SensitiveFileVariableSpecs
}
// SensitiveFileProgramSpecs contains programs before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The ebpf struct tag names the object's program ("csm_file_perm") the
// field is bound to.
type SensitiveFileProgramSpecs struct {
CsmFilePerm *ebpf.ProgramSpec `ebpf:"csm_file_perm"`
}
// SensitiveFileMapSpecs contains maps before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The ebpf struct tags name the object's maps ("events", "watched").
type SensitiveFileMapSpecs struct {
Events *ebpf.MapSpec `ebpf:"events"`
Watched *ebpf.MapSpec `ebpf:"watched"`
}
// SensitiveFileVariableSpecs contains global variables before they are loaded into the kernel.
//
// It can be passed ebpf.CollectionSpec.Assign.
//
// The only variable in this object is tagged "unused".
type SensitiveFileVariableSpecs struct {
Unused *ebpf.VariableSpec `ebpf:"unused"`
}
// SensitiveFileObjects contains all objects after they have been loaded into the kernel.
//
// It can be passed to LoadSensitiveFileObjects or ebpf.CollectionSpec.LoadAndAssign.
type SensitiveFileObjects struct {
SensitiveFilePrograms
SensitiveFileMaps
SensitiveFileVariables
}
// Close closes the embedded programs and then the embedded maps, returning
// the first error. The embedded variables are not passed to the closer.
func (o *SensitiveFileObjects) Close() error {
return _SensitiveFileClose(
&o.SensitiveFilePrograms,
&o.SensitiveFileMaps,
)
}
// SensitiveFileMaps contains all maps after they have been loaded into the kernel.
//
// It can be passed to LoadSensitiveFileObjects or ebpf.CollectionSpec.LoadAndAssign.
type SensitiveFileMaps struct {
Events *ebpf.Map `ebpf:"events"`
Watched *ebpf.Map `ebpf:"watched"`
}
// Close closes Events and then Watched, returning the first error.
func (m *SensitiveFileMaps) Close() error {
return _SensitiveFileClose(
m.Events,
m.Watched,
)
}
// SensitiveFileVariables contains all global variables after they have been loaded into the kernel.
//
// It can be passed to LoadSensitiveFileObjects or ebpf.CollectionSpec.LoadAndAssign.
type SensitiveFileVariables struct {
Unused *ebpf.Variable `ebpf:"unused"`
}
// SensitiveFilePrograms contains all programs after they have been loaded into the kernel.
//
// It can be passed to LoadSensitiveFileObjects or ebpf.CollectionSpec.LoadAndAssign.
type SensitiveFilePrograms struct {
CsmFilePerm *ebpf.Program `ebpf:"csm_file_perm"`
}
// Close closes the CsmFilePerm program.
func (p *SensitiveFilePrograms) Close() error {
return _SensitiveFileClose(
p.CsmFilePerm,
)
}
// _SensitiveFileClose closes each closer in argument order and returns the
// first error it encounters. Closers after a failing one are not closed.
// It returns nil when every Close succeeds or when no closers are given.
func _SensitiveFileClose(closers ...io.Closer) error {
for _, closer := range closers {
if err := closer.Close(); err != nil {
return err
}
}
return nil
}
// _SensitiveFileBytes holds the contents of sensitivefile_x86_bpfel.o,
// embedded at build time; LoadSensitiveFile parses it.
//
// Do not access this directly.
//
//go:embed sensitivefile_x86_bpfel.o
var _SensitiveFileBytes []byte
package daemon
import (
"encoding/binary"
"errors"
)
// SensitiveFileEvent is the decoded, userspace form of struct
// sensitive_event in sensitive_file.bpf.c. Userspace looks up (Dev, Ino)
// in the in-memory mirror of the BPF watchset map to recover the path
// string at finding time.
//
// This Go struct is NOT laid out like the wire record: Comm is a Go string
// and there is no padding field. The wire layout is what
// sensitiveFileEventSize describes.
type SensitiveFileEvent struct {
UID uint32
PID uint32
Mask uint32
Dev uint64
Ino uint64
Comm string
}
// sensitiveFileEventSize is the minimum wire record length in bytes:
// uid(4) + pid(4) + mask(4) + padding(4) + dev(8) + ino(8) + comm(16) = 48.
const sensitiveFileEventSize = 4 + 4 + 4 + 4 + 8 + 8 + 16
// decodeSensitiveFileEvent parses one little-endian wire record into a
// SensitiveFileEvent.
//
// b must hold at least sensitiveFileEventSize bytes; a shorter buffer
// yields the zero event and an error. Bytes beyond the record are
// ignored. Comm is taken from the 16-byte field at offset 32 via nullTerm.
func decodeSensitiveFileEvent(b []byte) (SensitiveFileEvent, error) {
	if len(b) < sensitiveFileEventSize {
		return SensitiveFileEvent{}, errors.New("sensitive file event short buffer")
	}

	le := binary.LittleEndian

	var ev SensitiveFileEvent
	ev.UID = le.Uint32(b[0:4])
	ev.PID = le.Uint32(b[4:8])
	ev.Mask = le.Uint32(b[8:12])
	// b[12:16] is skipped: it is the padding slot counted in
	// sensitiveFileEventSize.
	ev.Dev = le.Uint64(b[16:24])
	ev.Ino = le.Uint64(b[24:32])
	ev.Comm = nullTerm(b[32:48])
	return ev, nil
}
package daemon
import (
"context"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/state"
)
// sensitiveFilePoller wraps checks.CheckSensitiveFiles in a goroutine.
// Used when the BPF backend is unavailable or operator-disabled. Detection
// latency equals the poll interval (default 5 minutes).
type sensitiveFilePoller struct {
cfg *config.Config // passed to every CheckSensitiveFiles call and used to pick the interval
store *state.Store // passed to every CheckSensitiveFiles call
alertCh chan<- alert.Finding // findings are offered here with a non-blocking send
count atomic.Uint64 // findings produced by polls, including ones dropped on a full alertCh
}
// newSensitiveFilePoller builds a poller around the given config, state
// store and alert channel. It does not start polling; call Run for that.
func newSensitiveFilePoller(cfg *config.Config, store *state.Store, alertCh chan<- alert.Finding) *sensitiveFilePoller {
return &sensitiveFilePoller{cfg: cfg, store: store, alertCh: alertCh}
}
// Mode returns the fixed backend label "legacy".
func (p *sensitiveFilePoller) Mode() string { return "legacy" }
// EventCount returns the current value of the findings counter.
func (p *sensitiveFilePoller) EventCount() uint64 { return p.count.Load() }
// Run polls checks.CheckSensitiveFiles once per interval until ctx is
// cancelled. The first poll happens one full interval after Run starts.
//
// Every finding bumps the event counter and is then offered to alertCh
// with a non-blocking send; when the channel is full the finding is
// dropped and a warning is logged.
func (p *sensitiveFilePoller) Run(ctx context.Context) {
	ticker := time.NewTicker(sensitiveFilePollerInterval(p.cfg))
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return
		}

		for _, finding := range checks.CheckSensitiveFiles(ctx, p.cfg, p.store) {
			p.count.Add(1)
			select {
			case p.alertCh <- finding:
			default:
				csmlog.Warn("sensitive_file legacy: alert channel full, dropping finding")
			}
		}
	}
}
// sensitiveFilePollerInterval returns the configured
// Detection.SensitiveFilesPollInterval when it is positive, and five
// minutes otherwise. cfg must be non-nil.
func sensitiveFilePollerInterval(cfg *config.Config) time.Duration {
	configured := cfg.Detection.SensitiveFilesPollInterval
	if configured <= 0 {
		return 5 * time.Minute
	}
	return configured
}
package daemon
import (
"context"
"errors"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/state"
)
// StartSensitiveFileMonitor picks the sensitive-file write monitor from
// cfg.Detection.SensitiveFilesBackend (trimmed, case-insensitive) and what
// the host can actually run, and returns the chosen backend.
//
//	"auto" (default) -- try BPF, fall back to legacy hash-comparison polling.
//	"bpf"            -- require BPF; return nil if unavailable (no fallback).
//	"legacy"         -- pin legacy polling.
//	"none"           -- disable the live monitor (the periodic check still runs).
//
// An empty value means "auto"; an unknown value is logged and treated as
// "auto". The decision is published through bpf.SetActive for the
// "sensitive_files" feature. When BPF was attempted and failed, a
// BPF-unavailable finding is emitted on alertCh, whether the outcome is
// the legacy fallback or no monitor at all.
//
// The returned backend is not started here. The result is nil for "none"
// and for "bpf" when BPF cannot be started.
func StartSensitiveFileMonitor(alertCh chan<- alert.Finding, cfg *config.Config, store *state.Store) bpf.Backend {
	choice := strings.ToLower(strings.TrimSpace(cfg.Detection.SensitiveFilesBackend))
	if choice == "" {
		choice = bpf.BackendAuto
	}
	recognised := choice == bpf.BackendAuto || choice == bpf.BackendBPF ||
		choice == bpf.BackendLegacy || choice == bpf.BackendNone
	if !recognised {
		csmlog.Warn("sensitive_files: unknown backend choice, using auto", "value", choice)
		choice = bpf.BackendAuto
	}

	if choice == bpf.BackendNone {
		csmlog.Info("sensitive_files: disabled by config")
		bpf.SetActive("sensitive_files", bpf.BackendNone)
		return nil
	}

	// bpfErr remembers why BPF was unavailable so the legacy fallback can
	// still surface it as a finding.
	var bpfErr error
	if choice == bpf.BackendAuto || choice == bpf.BackendBPF {
		b, err := tryStartSensitiveFileBPFFn(context.Background(), alertCh, cfg)
		switch {
		case err == nil && b != nil:
			csmlog.Info("sensitive_files", "backend", "bpf", "choice", choice)
			bpf.SetActive("sensitive_files", bpf.BackendBPF)
			return b
		case err != nil:
			bpfErr = err
			unavailable := "bpf-unsupported"
			if errors.Is(err, bpf.ErrNotBuilt) {
				unavailable = "bpf-not-built"
			}
			csmlog.Info("sensitive_files: BPF unavailable", "state", unavailable, "reason", err.Error(), "choice", choice)
			if choice == bpf.BackendBPF {
				// Operator pinned BPF: do not silently degrade to polling.
				csmlog.Warn("sensitive_files: backend=bpf but BPF unavailable; no live monitor", "reason", err.Error())
				bpf.SetActive("sensitive_files", bpf.BackendNone)
				emitBPFUnavailableFinding(alertCh, "sensitive_files", choice, "", err)
				return nil
			}
		}
	}

	poller := newSensitiveFilePoller(cfg, store, alertCh)
	csmlog.Info("sensitive_files", "backend", "legacy", "choice", choice)
	bpf.SetActive("sensitive_files", bpf.BackendLegacy)
	if bpfErr != nil {
		emitBPFUnavailableFinding(alertCh, "sensitive_files", choice, bpf.BackendLegacy, bpfErr)
	}
	return poller
}
// tryStartSensitiveFileBPFFn is the indirection StartSensitiveFileMonitor
// calls to start the BPF backend. It is a package variable, presumably so
// tests can substitute a fake starter.
var tryStartSensitiveFileBPFFn = tryStartSensitiveFileBPF

// tryStartSensitiveFileBPF starts the BPF sensitive-file monitor and
// returns it as a bpf.Backend.
//
// On error it returns (nil, err). If the starter reports success but
// hands back a nil pointer, an untyped nil interface is returned instead
// of wrapping the pointer: a nil *sensitiveFileBPF stored in a bpf.Backend
// compares != nil, which would defeat the caller's `b != nil` check and
// hand out a backend whose receiver is nil.
func tryStartSensitiveFileBPF(ctx context.Context, ch chan<- alert.Finding, cfg *config.Config) (bpf.Backend, error) {
	b, err := startSensitiveFileBPF(ctx, ch, cfg)
	if err != nil {
		return nil, err
	}
	if b == nil {
		return nil, nil
	}
	return b, nil
}
package daemon
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/store"
)
// Signature-update-driven retroactive rescan.
//
// CSM's signature rules update independently of the deep-tier
// scanner. Without this watcher, a fresh ruleset only catches files
// that change AFTER the update -- existing files that newly match
// stay silent until the next time they happen to be touched. Real
// attacks aren't that polite.
//
// The watcher polls cfg.Signatures.RulesDir every sigWatchInterval,
// stat()s every *.yaml / *.yml / *.yar / *.yara file, and sets the
// daemon's forceFullRescan flag whenever any tracked file's mtime
// advances. The next deep-tier tick reads + clears the flag and runs
// the full account tree instead of the fanotify short-list.
//
// The mtime map is persisted in bbolt (sig_watch bucket) so a daemon
// restart does not look like "all files are new" and trigger a
// phantom rescan on first tick.
// sigWatchInterval is the default polling cadence newSigWatcher gives a
// watcher.
const sigWatchInterval = 60 * time.Second
// sigWatchExtensions lists the file extensions the watcher tracks. Entries
// are lower-case with a leading dot; candidates are lower-cased before
// being compared against them.
var sigWatchExtensions = []string{".yaml", ".yml", ".yar", ".yara"}
// sigRescansTotal is created and registered lazily, exactly once, under
// sigRescansTotalOnce; it is nil until observeSignatureRescan first runs.
var (
sigRescansTotalOnce sync.Once
sigRescansTotal *metrics.Counter
)
// observeSignatureRescan adds one to csm_signature_rescans_total. The
// counter is created and registered on the first call in a process and
// reused on every later call.
//
// It is called from the deep-tier path AFTER a full retro-sweep completes,
// so the counter measures completed sweeps, not queued ones.
func observeSignatureRescan() {
	const metricName = "csm_signature_rescans_total"

	sigRescansTotalOnce.Do(func() {
		sigRescansTotal = metrics.NewCounter(
			metricName,
			"Signature-update-driven full deep-tier rescans completed. Incremented when the deep-tier scheduler picks up the forceFullRescan flag set by the signature watcher and finishes a sweep against the new ruleset.",
		)
		metrics.MustRegister(metricName, sigRescansTotal)
	})
	sigRescansTotal.Inc()
}
// sigWatcher carries the watcher's loop state. The daemon owns one
// instance; the goroutine in (*Daemon).signatureWatcher drives it.
//
// cfg and store are re-resolved per tick, not captured at
// construction. Originally we cached cfg.Signatures.RulesDir and
// store.Global() into struct fields and discovered two ways that
// could go wrong: a hot-reload that changed the rules dir would
// silently keep walking the old path, and a daemon ordering quirk
// where store.Global() is nil at goroutine spawn would leave the
// watcher persistence-blind for the rest of its lifetime. Live
// resolution closes both.
//
// The struct has no lock; tick mutates lastMtimes, so it is only safe to
// drive from a single goroutine.
type sigWatcher struct {
cfgFunc func() *config.Config // returns the live config; called once per tick
storeFunc func() *store.DB // returns the live store handle (may be nil); called once per tick
rescanFlag *atomic.Bool // set to true when a tracked file's mtime differs from the last one seen
alertCh chan<- alert.Finding // receives one informational finding per changed file (non-blocking send)
interval time.Duration // tick cadence used by the driving goroutine
// Initialised on first tick from store.GetSignatureMtimes(); the
// in-memory map is the authoritative working copy for the loop.
lastMtimes map[string]time.Time
}
// newSigWatcher builds a watcher with the production tick interval
// (sigWatchInterval). cfgFunc and storeFunc are invoked on every tick, so
// config hot-reloads and a lazily initialised bbolt handle are picked up
// without rebuilding the watcher. lastMtimes is left nil; the first tick
// populates it. Tests may overwrite the interval field after construction.
func newSigWatcher(cfgFunc func() *config.Config, storeFunc func() *store.DB, flag *atomic.Bool, alertCh chan<- alert.Finding) *sigWatcher {
	w := new(sigWatcher)
	w.cfgFunc = cfgFunc
	w.storeFunc = storeFunc
	w.rescanFlag = flag
	w.alertCh = alertCh
	w.interval = sigWatchInterval
	return w
}
// loadInitial seeds w.lastMtimes from the persisted mtime map. tick calls
// it while w.lastMtimes is still nil.
//
// With a nil store, or when the read fails, the watcher starts from an
// empty map. A read error is logged and otherwise non-fatal: the next
// tick re-persists, so the cost of a transient bbolt error is at most one
// phantom rescan.
func (w *sigWatcher) loadInitial(sdb *store.DB) {
	// Start from the empty fallback; it is replaced only on a clean read.
	w.lastMtimes = map[string]time.Time{}
	if sdb == nil {
		return
	}
	persisted, err := sdb.GetSignatureMtimes()
	if err != nil {
		csmlog.Warn("sig_watch: loading persisted mtimes", "err", err)
		return
	}
	w.lastMtimes = persisted
}
// tick performs one walk of the rules dir and arms the rescan flag when
// any already-tracked file's mtime differs from the one last recorded
// (the comparison is "not equal", so a timestamp that moves backwards
// counts too).
//
// It does nothing when there is no live config, when the watcher is
// disabled by config, or when no rules dir is configured. The nil-config
// guard matters because sigWatchEnabled treats a nil config as "enabled",
// and the rules dir is read from the config right after.
//
// Files seen for the first time are recorded without arming. Removed
// files drop out of the tracked and persisted map without triggering a
// rescan -- the spec calls out only an mtime change as a trigger. After
// every walk the fresh map replaces the in-memory copy and is persisted
// when a store is available.
//
// For each changed file one Warning finding is offered to alertCh with a
// non-blocking send.
func (w *sigWatcher) tick() {
	cfg := w.cfgFunc()
	if cfg == nil || !sigWatchEnabled(cfg) {
		return
	}
	rulesDir := cfg.Signatures.RulesDir
	if rulesDir == "" {
		return
	}

	sdb := w.storeFunc()
	// Defer the first-time persistence load until we have a non-nil
	// store. A nil store (race against bbolt open) means this tick works
	// purely in-memory from an empty map.
	if w.lastMtimes == nil && sdb != nil {
		w.loadInitial(sdb)
	}
	if w.lastMtimes == nil {
		w.lastMtimes = map[string]time.Time{}
	}

	current := walkRulesDir(rulesDir)
	var changed []sigWatchChange
	for path, mtime := range current {
		old, tracked := w.lastMtimes[path]
		if !tracked {
			// New file. The spec treats first-observation as a
			// non-event so a fresh `update-rules` install does not
			// cause a rescan when the daemon also starts cold.
			// Track the mtime forward without arming.
			continue
		}
		if !mtime.Equal(old) {
			changed = append(changed, sigWatchChange{Path: path, Old: old, New: mtime})
		}
	}

	w.lastMtimes = current
	if sdb != nil {
		if err := sdb.PutSignatureMtimes(current); err != nil {
			csmlog.Warn("sig_watch: persisting mtimes", "err", err)
		}
	}

	if len(changed) == 0 {
		return
	}
	w.rescanFlag.Store(true)
	for _, c := range changed {
		select {
		case w.alertCh <- alert.Finding{
			Severity:  alert.Warning,
			Check:     "signature_update_rescan_queued",
			Message:   fmt.Sprintf("Signature update detected, full deep rescan queued: %s", filepath.Base(c.Path)),
			Details:   fmt.Sprintf("File: %s\nOld mtime: %s\nNew mtime: %s", c.Path, c.Old.UTC().Format(time.RFC3339), c.New.UTC().Format(time.RFC3339)),
			FilePath:  c.Path,
			Timestamp: time.Now(),
		}:
		default:
			// Alert channel is full; the rescan flag is already set
			// so the operator-visible "what happened" record is the
			// less critical loss here.
		}
	}
}
// walkRulesDir returns the modification time of every signature file under
// dir, keyed by the path filepath.Walk reports. A file counts as a
// signature file when its lower-cased extension is accepted by
// sigWatchExtMatches.
//
// Sub-directories are walked too -- the YARA Forge updater puts files
// under tier-named subfolders. A missing or unreadable dir yields an
// empty map; the function never fails.
func walkRulesDir(dir string) map[string]time.Time {
	out := map[string]time.Time{}
	// The callback never returns a non-nil error, so Walk itself cannot
	// report one; the result is discarded deliberately.
	_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			// Missing dir, failed lstat on an entry, or EACCES reading a
			// sub-tree. Swallow it so one bad path neither crashes the
			// watcher nor stops sibling traversal. Walk already declines
			// to descend into a directory it could not read, so nil is
			// enough. filepath.SkipDir must NOT be returned here: for a
			// non-directory entry its documented meaning is "skip the
			// rest of the parent directory", which would drop siblings.
			return nil
		}
		if info.IsDir() {
			return nil
		}
		ext := strings.ToLower(filepath.Ext(path))
		if !sigWatchExtMatches(ext) {
			return nil
		}
		out[path] = info.ModTime()
		return nil
	})
	return out
}
// sigWatchExtMatches reports whether ext is one of the extensions listed
// in sigWatchExtensions. The comparison is exact, so callers must pass
// ext already lower-cased and including its leading dot.
func sigWatchExtMatches(ext string) bool {
	for i := 0; i < len(sigWatchExtensions); i++ {
		if sigWatchExtensions[i] == ext {
			return true
		}
	}
	return false
}
// sigWatchEnabled resolves the tri-state
// Detection.RescanOnSignatureUpdate flag: an unset (nil) flag means on,
// and a set flag is used as-is. A nil cfg also reports on. Same shape as
// dbObjectScanningEnabled in the checks package.
func sigWatchEnabled(cfg *config.Config) bool {
	if cfg == nil {
		return true
	}
	if flag := cfg.Detection.RescanOnSignatureUpdate; flag != nil {
		return *flag
	}
	return true
}
// sigWatchChange records one tracked file whose mtime changed between two
// ticks. It feeds the alert detail message.
type sigWatchChange struct {
Path string // path of the changed file as reported by the directory walk
Old time.Time // mtime recorded on the previous observation
New time.Time // mtime observed on this tick
}
// signatureWatcher is the daemon's signature-watch goroutine. It ticks
// once immediately on start -- so the watcher converges quickly when the
// daemon comes up shortly after an `update-rules` invocation -- and then
// once per watcher interval until d.stopCh is closed. A tick that finds a
// changed rule file sets d.forceFullRescan.
//
// Cfg and store are handed to the watcher as getters (not captured
// values) so a hot-reload of signatures.rules_dir takes effect on the
// next tick and a late-initialised bbolt is picked up automatically.
func (d *Daemon) signatureWatcher() {
	defer d.wg.Done()

	liveCfg := func() *config.Config { return d.currentCfg() }
	w := newSigWatcher(liveCfg, store.Global, &d.forceFullRescan, d.alertCh)

	ticker := time.NewTicker(w.interval)
	defer ticker.Stop()

	for {
		w.tick()
		select {
		case <-ticker.C:
		case <-d.stopCh:
			return
		}
	}
}
package daemon
import (
"fmt"
"sort"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// smtpIPEntry tracks failed-auth timestamps and suppression state for one IP.
type smtpIPEntry struct {
times []time.Time // failure timestamps; entries older than the window are pruned on each Record
suppressed time.Time // no further per-IP finding is emitted before this instant
lastSeen time.Time // time of the most recent Record for this IP
}
// smtpSubnetEntry tracks unique attacker IPs within a /24.
type smtpSubnetEntry struct {
ips map[string]time.Time // ip -> firstSeen in window
suppressed time.Time // no further subnet finding is emitted before this instant
lastSeen time.Time // time of the most recent Record that touched this subnet
}
// smtpAccountEntry tracks unique attacker IPs per mailbox.
type smtpAccountEntry struct {
ips map[string]time.Time // attacker IP -> timestamp, pruned against the window on each Record
suppressed time.Time // no further account-spray finding is emitted before this instant
lastSeen time.Time // time of the most recent Record that named this account
}
// smtpAuthTracker aggregates dovecot auth-failure events into three
// detection signals: per-IP brute force, per-/24 password spray, and
// per-mailbox account spray.
//
// Thread-safe; Record may be called concurrently from multiple log readers.
type smtpAuthTracker struct {
mu sync.Mutex // guards every map and counter below
perIPThreshold int // failures from one IP within window that trigger smtp_bruteforce
subnetThreshold int // unique IPs in one /24 within window that trigger smtp_subnet_spray
accountSprayThreshold int // unique IPs against one mailbox within window that trigger smtp_account_spray
window time.Duration // look-back period for all three signals
suppression time.Duration // quiet period applied to an entry after it emits a finding
maxTracked int // NOTE(review): presumably a cap on tracked entries; enforcement is not visible here
now func() time.Time // clock; injectable for tests
ips map[string]*smtpIPEntry // keyed by source IP
subnets map[string]*smtpSubnetEntry // keyed by /24 prefix
accounts map[string]*smtpAccountEntry // keyed by account name
// Diagnostic counters (guarded by mu): cumulative Record invocations and
// findings emitted. The daemon logs these so a "zero smtp_bruteforce in
// production despite thousands of auth failures" can be pinned to either
// "Record never called" or "called but never crosses threshold".
recordCalls int64
findingsEmitted int64
}
// newSMTPAuthTracker builds a tracker with empty per-IP, per-subnet and
// per-account tables and the given thresholds and durations. `now` is the
// clock the tracker reads; tests inject a deterministic one, production
// passes time.Now. A nil `now` falls back to time.Now.
func newSMTPAuthTracker(
	perIPThreshold int,
	subnetThreshold int,
	accountSprayThreshold int,
	window time.Duration,
	suppression time.Duration,
	maxTracked int,
	now func() time.Time,
) *smtpAuthTracker {
	clock := now
	if clock == nil {
		clock = time.Now
	}

	t := &smtpAuthTracker{
		ips:      make(map[string]*smtpIPEntry),
		subnets:  make(map[string]*smtpSubnetEntry),
		accounts: make(map[string]*smtpAccountEntry),
	}
	t.perIPThreshold = perIPThreshold
	t.subnetThreshold = subnetThreshold
	t.accountSprayThreshold = accountSprayThreshold
	t.window = window
	t.suppression = suppression
	t.maxTracked = maxTracked
	t.now = clock
	return t
}
// Size reports how many entities are currently tracked: the number of
// IPs plus subnets plus accounts, read under the tracker lock.
func (t *smtpAuthTracker) Size() int {
	t.mu.Lock()
	defer t.mu.Unlock()

	total := len(t.ips)
	total += len(t.subnets)
	total += len(t.accounts)
	return total
}
// Record ingests one dovecot authentication failure from ip (and, when
// known, the account it targeted) and returns the findings this observation
// pushed over a threshold, if any. Callers append the result to their own
// finding slice.
//
// Up to three findings can be produced by a single call, one per dimension:
// per-IP failure count, distinct IPs inside the same IPv4 /24, and distinct
// IPs against the same account. Each dimension is suppressed independently
// for t.suppression after it fires.
//
// ip MUST be non-private, non-loopback, and non-infra — callers enforce this
// before invoking Record. An empty ip is ignored.
func (t *smtpAuthTracker) Record(ip, account string) []alert.Finding {
	if ip == "" {
		return nil
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	t.recordCalls++
	at := t.now()
	oldest := at.Add(-t.window)
	var out []alert.Finding

	// Dimension 1: failures coming from this exact address.
	ipEntry, known := t.ips[ip]
	if !known {
		ipEntry = &smtpIPEntry{}
		t.ips[ip] = ipEntry
	}
	ipEntry.times = append(pruneTimes(ipEntry.times, oldest), at)
	ipEntry.lastSeen = at
	if failures := len(ipEntry.times); failures >= t.perIPThreshold && !at.Before(ipEntry.suppressed) {
		ipEntry.suppressed = at.Add(t.suppression)
		out = append(out, alert.Finding{
			Severity: alert.Critical,
			Check:    "smtp_bruteforce",
			Message: fmt.Sprintf("SMTP brute force from %s: %d failed auths in %v",
				ip, failures, t.window),
			Details:   "Real-time detection of dovecot_login auth failures",
			Timestamp: at,
			SourceIP:  ip,
		})
	}

	// Dimension 2: distinct addresses inside the same /24 (IPv4 only; the
	// prefix helper yields "" for anything else).
	if prefix := extractPrefix24Daemon(ip); prefix != "" {
		subnet, known := t.subnets[prefix]
		if !known {
			subnet = &smtpSubnetEntry{ips: make(map[string]time.Time)}
			t.subnets[prefix] = subnet
		}
		pruneSubnetIPs(subnet, oldest)
		subnet.ips[ip] = at
		subnet.lastSeen = at
		if unique := len(subnet.ips); unique >= t.subnetThreshold && !at.Before(subnet.suppressed) {
			subnet.suppressed = at.Add(t.suppression)
			out = append(out, alert.Finding{
				Severity: alert.Critical,
				Check:    "smtp_subnet_spray",
				Message: fmt.Sprintf("SMTP password spray from %s.0/24: %d unique IPs in %v",
					prefix, unique, t.window),
				Details:   "Real-time detection of dovecot_login auth failures from many IPs in one /24",
				Timestamp: at,
				SourceIP:  prefix + ".0/24",
			})
		}
	}

	// Dimension 3: distinct addresses hammering the same mailbox.
	if account != "" {
		acct, known := t.accounts[account]
		if !known {
			acct = &smtpAccountEntry{ips: make(map[string]time.Time)}
			t.accounts[account] = acct
		}
		pruneAccountIPs(acct, oldest)
		acct.ips[ip] = at
		acct.lastSeen = at
		if unique := len(acct.ips); unique >= t.accountSprayThreshold && !at.Before(acct.suppressed) {
			acct.suppressed = at.Add(t.suppression)
			_, acctDomain := alert.SplitEmail(account)
			out = append(out, alert.Finding{
				Severity: alert.High,
				Check:    "smtp_account_spray",
				Message: fmt.Sprintf("SMTP password spray targeting %s: %d unique IPs in %v",
					account, unique, t.window),
				Details:   "Distributed login attempts across many IPs against one mailbox (visibility only — no auto-block).",
				Timestamp: at,
				SourceIP:  ip,
				Domain:    acctDomain,
				Mailbox:   account,
			})
		}
	}

	t.findingsEmitted += int64(len(out))
	t.enforceMaxTracked()
	return out
}
// Stats reports how many times Record has been called and how many findings
// it has emitted since the tracker was created. The daemon's periodic
// diagnostic log reads these.
func (t *smtpAuthTracker) Stats() (calls, emits int64) {
	t.mu.Lock()
	calls, emits = t.recordCalls, t.findingsEmitted
	t.mu.Unlock()
	return calls, emits
}
// pruneTimes keeps only the timestamps that are at or after cutoff,
// preserving their order. The filtering is done in place: the returned slice
// shares the backing array of times, so the caller must use the result
// instead of the argument.
func pruneTimes(times []time.Time, cutoff time.Time) []time.Time {
	kept := 0
	for i := range times {
		if times[i].Before(cutoff) {
			continue
		}
		times[kept] = times[i]
		kept++
	}
	return times[:kept]
}
// extractPrefix24Daemon returns everything before the third '.' of ip, i.e.
// the "a.b.c" part of a dotted-quad address. It returns "" when ip has fewer
// than three dots or when a ':' appears before the third dot (IPv6 and
// IPv4-mapped IPv6 forms). Octets are not validated.
func extractPrefix24Daemon(ip string) string {
	dots := 0
	for i := 0; i < len(ip); i++ {
		switch ip[i] {
		case ':':
			// A colon ahead of the third dot means this is not plain IPv4.
			return ""
		case '.':
			dots++
			if dots == 3 {
				return ip[:i]
			}
		}
	}
	return ""
}
// pruneSubnetIPs removes from s.ips every address whose recorded timestamp
// is strictly before cutoff.
func pruneSubnetIPs(s *smtpSubnetEntry, cutoff time.Time) {
	for addr, seenAt := range s.ips {
		if !seenAt.Before(cutoff) {
			continue
		}
		delete(s.ips, addr)
	}
}
// pruneAccountIPs removes from a.ips every address whose recorded timestamp
// is strictly before cutoff.
func pruneAccountIPs(a *smtpAccountEntry, cutoff time.Time) {
	for addr, seenAt := range a.ips {
		if !seenAt.Before(cutoff) {
			continue
		}
		delete(a.ips, addr)
	}
}
// Purge trims every entry to the current window and then drops the entries
// that are both empty and idle for at least window + suppression. Keeping an
// empty entry around for the suppression period preserves its suppressed
// timestamp. Called from a background goroutine every minute.
func (t *smtpAuthTracker) Purge() {
	t.mu.Lock()
	defer t.mu.Unlock()

	current := t.now()
	windowStart := current.Add(-t.window)
	idleSince := current.Add(-(t.window + t.suppression))

	for key, entry := range t.ips {
		entry.times = pruneTimes(entry.times, windowStart)
		if len(entry.times) > 0 || entry.lastSeen.After(idleSince) {
			continue
		}
		delete(t.ips, key)
	}
	for key, entry := range t.subnets {
		pruneSubnetIPs(entry, windowStart)
		if len(entry.ips) > 0 || entry.lastSeen.After(idleSince) {
			continue
		}
		delete(t.subnets, key)
	}
	for key, entry := range t.accounts {
		pruneAccountIPs(entry, windowStart)
		if len(entry.ips) > 0 || entry.lastSeen.After(idleSince) {
			continue
		}
		delete(t.accounts, key)
	}
}
// enforceMaxTracked evicts the least-recently-seen entries until the total
// number of tracked entities (IPs + subnets + accounts) is at most 95% of
// maxTracked. It does nothing while the total is within maxTracked.
//
// A non-positive maxTracked means "no cap", matching smtpProbeTracker.
// Without that guard an unset cap would evict every entry on every Record
// call, so no counter could ever accumulate past one observation.
//
// Caller must hold t.mu.
func (t *smtpAuthTracker) enforceMaxTracked() {
	if t.maxTracked <= 0 {
		return
	}
	total := len(t.ips) + len(t.subnets) + len(t.accounts)
	if total <= t.maxTracked {
		return
	}

	type victim struct {
		kind string // "ip" | "subnet" | "account"
		key  string
		seen time.Time
	}
	victims := make([]victim, 0, total)
	for k, v := range t.ips {
		victims = append(victims, victim{"ip", k, v.lastSeen})
	}
	for k, v := range t.subnets {
		victims = append(victims, victim{"subnet", k, v.lastSeen})
	}
	for k, v := range t.accounts {
		victims = append(victims, victim{"account", k, v.lastSeen})
	}
	// Oldest activity first, across all three dimensions.
	sort.Slice(victims, func(i, j int) bool { return victims[i].seen.Before(victims[j].seen) })

	// Evict to 95% of cap so subsequent inserts don't re-trigger the sort.
	target := t.maxTracked * 95 / 100
	for _, v := range victims {
		if len(t.ips)+len(t.subnets)+len(t.accounts) <= target {
			break
		}
		switch v.kind {
		case "ip":
			delete(t.ips, v.key)
		case "subnet":
			delete(t.subnets, v.key)
		case "account":
			delete(t.accounts, v.key)
		}
	}
}
package daemon
import (
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// smtpProbeBlockExpiryString reports the block expiry to show in an
// `smtp_probe_abuse` finding. It is non-empty only when auto-response will
// really block the source IP: auto_response.enabled AND block_ips true AND
// dry-run off. In every other case, including when no active config is
// published, it returns "".
//
// The value is the string the operator wrote in csm.yaml ("24h", "12h"...)
// rather than Go's canonical Duration formatting, so the alert text matches
// the configuration. An unset expiry is reported as "24h".
func smtpProbeBlockExpiryString() string {
	cfg := config.Active()
	if cfg == nil {
		return ""
	}
	willBlock := cfg.AutoResponse.Enabled && cfg.AutoResponse.BlockIPs && !cfg.AutoResponseDryRunEnabled()
	if !willBlock {
		return ""
	}
	if expiry := cfg.AutoResponse.BlockExpiry; expiry != "" {
		return expiry
	}
	return "24h"
}
// smtpProbeEntry records connect timestamps and suppression for one IP.
type smtpProbeEntry struct {
// times holds the connect timestamps still inside the rolling window.
times []time.Time
// suppressed is the instant until which no further finding is emitted
// for this IP; set to now + suppression when a finding fires.
suppressed time.Time
// lastSeen is the time of the most recent connect; used for idle purging
// and least-recently-seen eviction.
lastSeen time.Time
}
// smtpProbeTracker counts raw SMTP connect events per source IP and emits an
// `smtp_probe_abuse` finding when an IP exceeds the threshold inside the
// rolling window.
//
// This is the connection-rate complement to smtpAuthTracker: scanners that
// probe-and-disconnect (no AUTH attempt) never trigger the auth tracker, so
// they need their own signal. The thresholds are deliberately set well above
// any legitimate MUA usage; Thunderbird/iPhone bursts of 10-15 parallel
// sessions per send fall comfortably under, scanner storms with hundreds of
// connect/min are caught.
type smtpProbeTracker struct {
// mu guards ips and every entry stored in it.
mu sync.Mutex
// threshold is the number of connects inside window that triggers a
// finding; a non-positive value disables Record entirely.
threshold int
// window is the length of the rolling observation window.
window time.Duration
// suppression is how long an IP stays quiet after it produced a finding.
suppression time.Duration
// maxTracked caps the number of tracked IPs; non-positive means no cap.
maxTracked int
// now is the clock; injected so tests can control time.
now func() time.Time
// expiryStrFn returns the operator-visible block expiry (e.g. "24h") when
// live auto-blocking is enabled, or "" when no auto-block will run. Read
// at finding time so a SIGHUP reload of auto_response.* is reflected in
// the next emitted finding's Details.
expiryStrFn func() string
// ips holds per-source-IP state keyed by the IP string.
ips map[string]*smtpProbeEntry
}
// newSMTPProbeTracker builds an empty probe tracker. now is the clock used
// for all timestamps (nil falls back to time.Now). expiryStrFn may be nil,
// in which case findings always advise a manual block.
func newSMTPProbeTracker(threshold int, window, suppression time.Duration, maxTracked int, now func() time.Time, expiryStrFn func() string) *smtpProbeTracker {
	tracker := &smtpProbeTracker{
		threshold:   threshold,
		window:      window,
		suppression: suppression,
		maxTracked:  maxTracked,
		now:         time.Now,
		expiryStrFn: expiryStrFn,
		ips:         make(map[string]*smtpProbeEntry),
	}
	if now != nil {
		tracker.now = now
	}
	return tracker
}
// Size reports how many source IPs are currently tracked.
func (t *smtpProbeTracker) Size() int {
	t.mu.Lock()
	count := len(t.ips)
	t.mu.Unlock()
	return count
}
// Record notes one SMTP connect from ip and returns at most one
// `smtp_probe_abuse` finding (never more than one per call). A finding is
// produced when the IP has reached the threshold inside the rolling window
// and is not currently suppressed.
//
// ip MUST be non-private, non-loopback, and non-infra; callers enforce this
// before invoking Record. An empty ip or a non-positive threshold makes the
// call a no-op.
func (t *smtpProbeTracker) Record(ip string) []alert.Finding {
	if t.threshold <= 0 || ip == "" {
		return nil
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	at := t.now()
	entry, known := t.ips[ip]
	if !known {
		entry = &smtpProbeEntry{}
		t.ips[ip] = entry
	}
	entry.times = append(pruneTimes(entry.times, at.Add(-t.window)), at)
	entry.lastSeen = at

	var out []alert.Finding
	if connects := len(entry.times); connects >= t.threshold && !at.Before(entry.suppressed) {
		entry.suppressed = at.Add(t.suppression)

		// Details is built before AutoBlockIPs runs in dispatchBatch, so it
		// can only state the *intent* (scheduled for auto-block). The real
		// outcome (blocked / rate-limited / already blocked / challenged) is
		// carried by the companion `auto_block` finding that
		// checks.AutoBlockIPs emits in the same batch.
		expiry := ""
		if t.expiryStrFn != nil {
			expiry = t.expiryStrFn()
		}
		details := "Sustained SMTP connect rate above the configured threshold. Likely scanner / dictionary probe;"
		if expiry != "" {
			details += fmt.Sprintf(" scheduled for auto-block (%s).", expiry)
		} else {
			details += " consider manual block."
		}

		out = append(out, alert.Finding{
			Severity: alert.High,
			Check:    "smtp_probe_abuse",
			Message: fmt.Sprintf("SMTP probe abuse from %s: %d connections in %v",
				ip, connects, t.window),
			Details:   details,
			Timestamp: at,
			SourceIP:  ip,
		})
	}

	t.enforceMaxTracked()
	return out
}
// Purge trims each IP's timestamps to the current window and deletes the
// IPs that are left empty and have been idle for at least
// window + suppression. Called periodically to prevent unbounded growth.
func (t *smtpProbeTracker) Purge() {
	t.mu.Lock()
	defer t.mu.Unlock()

	current := t.now()
	windowStart := current.Add(-t.window)
	idleSince := current.Add(-(t.window + t.suppression))

	for ip, entry := range t.ips {
		entry.times = pruneTimes(entry.times, windowStart)
		if len(entry.times) > 0 || entry.lastSeen.After(idleSince) {
			continue
		}
		delete(t.ips, ip)
	}
}
// enforceMaxTracked keeps memory bounded: once the number of tracked IPs
// exceeds maxTracked, the least-recently-seen ones are dropped until the map
// is down to 95% of the cap. A non-positive maxTracked disables the cap.
// Caller must hold t.mu.
func (t *smtpProbeTracker) enforceMaxTracked() {
	if t.maxTracked <= 0 {
		return
	}
	if len(t.ips) <= t.maxTracked {
		return
	}

	type candidate struct {
		ip       string
		lastSeen time.Time
	}
	byAge := make([]candidate, 0, len(t.ips))
	for ip, entry := range t.ips {
		byAge = append(byAge, candidate{ip: ip, lastSeen: entry.lastSeen})
	}
	sort.Slice(byAge, func(a, b int) bool { return byAge[a].lastSeen.Before(byAge[b].lastSeen) })

	keep := t.maxTracked * 95 / 100
	for _, c := range byAge {
		if len(t.ips) <= keep {
			break
		}
		delete(t.ips, c.ip)
	}
}
// parseEximSMTPConnectIP extracts the connecting source IP from an exim
// mainlog "SMTP connection from ..." line. Returns "" when the line is not
// a connect event or carries no `[ip]:port` peer token.
//
// Exim formats vary:
//
//	SMTP connection from [1.2.3.4]:65417 (TCP/IP connection count = 7)
//	SMTP connection from (helo.example.com) [1.2.3.4]:43018 lost D=5s
//	SMTP connection from ([helo-as-ip]) [1.2.3.4]:38294 lost D=15s
//	SMTP connection from ([192.168.0.94]) [1.2.3.4]:64547 D=5s closed by QUIT
//	SMTP connection from [1.2.3.4]:65417 I=[10.0.0.1]:25 (TCP/IP connection count = 7)
//
// The connecting peer is the LAST `[ip]:port` token that is not the local
// interface. With the `incoming_interface` log selector exim appends
// `I=[local-ip]:port` after the peer; that token is skipped, otherwise every
// connect would be attributed to the server's own address.
func parseEximSMTPConnectIP(line string) string {
	const marker = "SMTP connection from "
	idx := strings.Index(line, marker)
	if idx < 0 {
		return ""
	}
	rest := line[idx+len(marker):]

	// Walk all `[...]` tokens; remember the last one that is followed by
	// `:<digit>` (the source port) and is not an `I=` interface token.
	var peer string
	for {
		open := strings.IndexByte(rest, '[')
		if open < 0 {
			break
		}
		end := strings.IndexByte(rest[open:], ']')
		if end < 0 {
			break
		}
		candidate := rest[open+1 : open+end]
		after := rest[open+end+1:]

		isInterface := strings.HasSuffix(rest[:open], "I=")
		hasPort := len(after) >= 2 && after[0] == ':' && after[1] >= '0' && after[1] <= '9'
		if hasPort && !isInterface {
			peer = candidate
		}
		rest = after
	}
	return peer
}
//go:build linux
package daemon
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
emime "github.com/pidginhost/csm/internal/mime"
"github.com/pidginhost/csm/internal/obs"
)
// fanotify constants for permission events (not in Go stdlib).
const (
// FAN_CLASS_CONTENT is the fanotify_init class requested first by
// NewSpoolWatcher; it is the class that allows permission events.
FAN_CLASS_CONTENT = 0x00000004
// FAN_OPEN_PERM is the event mask bit for "permission to open requested".
FAN_OPEN_PERM = 0x00010000
// FAN_ALLOW / FAN_DENY are the two verdicts written back in a
// fanotifyResponse.
FAN_ALLOW = 0x01
FAN_DENY = 0x02
// FAN_EVENT_ON_CHILD is OR-ed into the mark mask so events on files
// inside the marked directory are reported.
FAN_EVENT_ON_CHILD = 0x08000000
)
// fanotifyResponse is the struct written back to the fanotify fd
// to allow or deny a permission event. Its in-memory layout is written to
// the kernel verbatim by writeResponse, so field order and widths must not
// change.
type fanotifyResponse struct {
// Fd is the per-event file descriptor the verdict applies to.
Fd int32
// Response is FAN_ALLOW or FAN_DENY.
Response uint32
}
// responseSize is the exact byte size of fanotifyResponse, used to view the
// struct as a byte array for the write.
const responseSize = int(unsafe.Sizeof(fanotifyResponse{}))
// SpoolWatcher monitors Exim spool directories for new messages using a
// dedicated fanotify instance with permission events (FAN_OPEN_PERM).
// It is completely separate from the FileMonitor.
type SpoolWatcher struct {
// fd is the fanotify descriptor; closed at most once via closeFd.
fd int
// cfg supplies the EmailAV settings read while scanning.
cfg *config.Config
// alertCh receives findings; sends are non-blocking (see emitFinding).
alertCh chan<- alert.Finding
// orchestrator runs the AV engines over extracted parts.
orchestrator *emailav.Orchestrator
// quarantine moves infected messages out of the spool.
quarantine *emailav.Quarantine
permissionMode bool // true if using FAN_OPEN_PERM, false if fallback to FAN_CLOSE_WRITE
// emailAVTempDir is the staging directory CreateTemp uses for
// extracted attachments. Established once at watcher construction
// (0700, daemon-owned) so an unprivileged local uid cannot race
// the scanner via /tmp symlink swaps. Empty is only for hand-built
// test watchers; production construction requires state_path.
emailAVTempDir string
// scanCh hands events from the event loop to the scan workers.
scanCh chan spoolEvent
// pipeFds is a self-pipe: Stop writes to [1] to wake the epoll loop
// waiting on [0].
pipeFds [2]int
// stopOnce / drainOnce make Stop and drainAndClose idempotent.
stopOnce sync.Once
drainOnce sync.Once
// stopCh is closed by Stop to signal shutdown.
stopCh chan struct{}
// wg tracks the scan workers started by Run.
wg sync.WaitGroup
pipeClosed int32 // atomic
fdClosed int32 // atomic - guards sw.fd against double-close
// degradedMu guards lastDegradedAt, the time of the last emitted
// email_av_degraded finding (rate limiting).
degradedMu sync.Mutex
lastDegradedAt time.Time
}
// spoolEvent is one fanotify event queued for a scan worker. The worker
// that receives it owns fd: it must answer the permission request (when
// needResp is set) and close the descriptor.
type spoolEvent struct {
// path is the spool file path resolved from the event fd via /proc/self/fd.
path string
fd int // fanotify event fd (for permission response)
// pid is the Pid value copied from the fanotify event metadata.
pid int32
needResp bool // true if permission event requiring response
}
// NewSpoolWatcher creates a dedicated fanotify instance for Exim spool scanning.
// Attempts FAN_CLASS_CONTENT with FAN_OPEN_PERM first; falls back to
// FAN_CLASS_NOTIF with FAN_CLOSE_WRITE if permission events are unavailable.
//
// It returns an error when the private temp directory cannot be prepared,
// when neither fanotify class can be initialised, when none of the known
// spool directories could be marked, or when the stop pipe cannot be
// created. On the last two failures the fanotify descriptor is closed
// before returning. The returned watcher is idle until Run is called.
func NewSpoolWatcher(cfg *config.Config, alertCh chan<- alert.Finding, orch *emailav.Orchestrator, quar *emailav.Quarantine) (*SpoolWatcher, error) {
// Resolve the staging directory first: it also rejects a nil cfg, so the
// cfg dereferences below are safe.
emailAVTempDir, err := resolveEmailAVTempDir(cfg)
if err != nil {
return nil, err
}
sw := &SpoolWatcher{
cfg: cfg,
alertCh: alertCh,
orchestrator: orch,
quarantine: quar,
emailAVTempDir: emailAVTempDir,
scanCh: make(chan spoolEvent, 256),
stopCh: make(chan struct{}),
}
// Try permission-capable class first
fd, err := unix.FanotifyInit(FAN_CLASS_CONTENT|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
if err == nil {
sw.fd = fd
sw.permissionMode = true
fmt.Fprintf(os.Stderr, "[%s] spool watcher: permission events enabled (FAN_OPEN_PERM)\n", ts())
} else {
// Fallback to notification-only
fd, err = unix.FanotifyInit(FAN_CLASS_NOTIF|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
if err != nil {
return nil, fmt.Errorf("fanotify_init: %w (neither permission nor notification mode available)", err)
}
sw.fd = fd
sw.permissionMode = false
fmt.Fprintf(os.Stderr, "[%s] spool watcher: WARNING - permission events unavailable, using notification mode (small delivery race window possible)\n", ts())
// Without permission events delivery cannot be held back, so a
// configured tempfail mode is only warned about here.
if sw.cfg.EmailAV.FailMode == "tempfail" {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: WARNING - fail_mode=tempfail requested but cannot be honoured without FAN_OPEN_PERM; operating fail-open\n", ts())
}
}
// Mark spool directories
spoolDirs := []string{"/var/spool/exim/input", "/var/spool/exim4/input"}
var eventMask uint64
if sw.permissionMode {
eventMask = FAN_OPEN_PERM | FAN_EVENT_ON_CHILD
} else {
eventMask = FAN_CLOSE_WRITE | FAN_EVENT_ON_CHILD
}
marked := 0
for _, dir := range spoolDirs {
// A missing spool directory is not an error; only the existing ones
// are watched.
if _, err := os.Stat(dir); err != nil {
continue
}
// Use FAN_MARK_ADD (not FAN_MARK_MOUNT) to scope to the directory
err := unix.FanotifyMark(sw.fd, FAN_MARK_ADD, eventMask, -1, dir)
if err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: cannot watch %s: %v\n", ts(), dir, err)
continue
}
marked++
fmt.Fprintf(os.Stderr, "[%s] spool watcher: watching %s\n", ts(), dir)
}
if marked == 0 {
_ = unix.Close(sw.fd)
return nil, fmt.Errorf("no Exim spool directories found to watch")
}
// Create pipe for stop signaling
if err := unix.Pipe2(sw.pipeFds[:], unix.O_NONBLOCK|unix.O_CLOEXEC); err != nil {
_ = unix.Close(sw.fd)
return nil, fmt.Errorf("creating pipe: %w", err)
}
return sw, nil
}
// Run starts the event loop and scanner workers. Blocks until Stop() is called.
//
// The loop waits on two descriptors with epoll: the fanotify fd (events to
// read) and the read end of the stop pipe (wake-up from Stop). Every exit
// path goes through drainAndClose. If the epoll setup itself fails, Run
// logs the error, cleans up and returns without starting any worker.
func (sw *SpoolWatcher) Run() {
// Event loop. Set up epoll before starting any workers so a setup
// failure has nothing to unwind beyond drainAndClose: workers started
// first would park on scanCh forever and hang daemon shutdown.
epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
if err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_create: %v\n", ts(), err)
sw.drainAndClose()
return
}
defer func() { _ = unix.Close(epfd) }()
// #nosec G115 -- POSIX fd fits in int32 (rlimit ~1024). Same for all fd→int32 in this file.
if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, sw.fd, &unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(sw.fd)}); err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_ctl(fanotify fd): %v\n", ts(), err)
sw.drainAndClose()
return
}
// #nosec G115 -- POSIX fd fits in int32.
if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, sw.pipeFds[0], &unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(sw.pipeFds[0])}); err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_ctl(pipe fd): %v\n", ts(), err)
sw.drainAndClose()
return
}
// Start scanner workers
// A configured concurrency below 1 falls back to 4 workers.
concurrency := sw.cfg.EmailAV.ScanConcurrency
if concurrency < 1 {
concurrency = 4
}
for i := 0; i < concurrency; i++ {
sw.wg.Add(1)
obs.Go("spool-scanner", sw.scanWorker)
}
events := make([]unix.EpollEvent, 16)
buf := make([]byte, 4096)
for {
// Non-blocking stop check at the top of every iteration.
select {
case <-sw.stopCh:
sw.drainAndClose()
return
default:
}
// 500 ms timeout so the stop check above runs regularly even when
// no descriptor becomes ready.
n, err := unix.EpollWait(epfd, events, 500)
if err != nil {
if err == unix.EINTR {
continue
}
// Any other error: exit if stopping, otherwise retry.
select {
case <-sw.stopCh:
sw.drainAndClose()
return
default:
continue
}
}
for i := 0; i < n; i++ {
// #nosec G115 -- POSIX fd fits in int32.
if events[i].Fd == int32(sw.pipeFds[0]) {
// Readable stop pipe means Stop() was called.
sw.drainAndClose()
return
}
// #nosec G115 -- POSIX fd fits in int32.
if events[i].Fd == int32(sw.fd) {
sw.readEvents(buf)
}
}
}
}
// readEvents reads fanotify events from sw.fd into buf until a read fails
// or returns less than one metadata record, and dispatches each event:
//
//   - events without a file descriptor are skipped;
//   - events whose path cannot be resolved, or that are not message body
//     files (suffix "-D"), are allowed immediately and their fd closed;
//   - "-D" events are queued on scanCh for a scan worker, which then owns
//     the fd and the permission response.
//
// A record whose EventLen is smaller than the metadata header cannot be
// stepped over, so the rest of the buffer is discarded. Without that check
// a zero length would spin this loop forever while Exim stays blocked.
func (sw *SpoolWatcher) readEvents(buf []byte) {
	// release answers a permission event with FAN_ALLOW (only needed in
	// permission mode) and closes the per-event descriptor.
	release := func(fd int32) {
		if sw.permissionMode {
			sw.writeResponse(fd, FAN_ALLOW)
		}
		_ = unix.Close(int(fd))
	}

	for {
		n, err := unix.Read(sw.fd, buf)
		if err != nil || n < metadataSize {
			return
		}

		offset := 0
		for offset+metadataSize <= n {
			// #nosec G103 -- fanotify delivers a packed binary stream;
			// reinterpretation is required and bounded by metadataSize above.
			meta := (*fanotifyEventMetadata)(unsafe.Pointer(&buf[offset]))

			eventLen := int(meta.EventLen)
			if eventLen < metadataSize {
				fmt.Fprintf(os.Stderr, "[%s] spool watcher: malformed fanotify event (event_len=%d at offset %d) - discarding rest of buffer\n", ts(), eventLen, offset)
				return
			}
			// Advance now so every branch below can simply continue.
			offset += eventLen

			if meta.Fd < 0 {
				continue
			}

			path, err := os.Readlink(fmt.Sprintf("/proc/self/fd/%d", meta.Fd))
			if err != nil {
				// Must respond even on error paths (permission mode)
				release(meta.Fd)
				continue
			}

			// Only process *-D files (message body)
			if !strings.HasSuffix(path, "-D") {
				release(meta.Fd)
				continue
			}

			// Send to scan workers - blocks if pool is full.
			// This is intentional: backpressure on Exim's delivery runner
			// is the correct behavior per the spec. Exim is designed to
			// handle delivery delays; unscanned delivery is not acceptable.
			evt := spoolEvent{
				path:     path,
				fd:       int(meta.Fd),
				pid:      meta.Pid,
				needResp: sw.permissionMode,
			}
			select {
			case sw.scanCh <- evt:
				// Worker will handle response and fd close
			case <-sw.stopCh:
				// Shutting down - allow and close
				release(meta.Fd)
			}
		}
	}
}
// scanWorker is the body of one scanner goroutine. It handles queued spool
// events (MIME parse, scan, quarantine/allow) until either scanCh is closed
// or stopCh is closed, and marks itself done on the wait group on exit.
func (sw *SpoolWatcher) scanWorker() {
	defer sw.wg.Done()
	for {
		select {
		case evt, open := <-sw.scanCh:
			if !open {
				return
			}
			sw.handleSpoolEvent(evt)
		case <-sw.stopCh:
			return
		}
	}
}
// handleSpoolEvent scans one spool message and decides whether Exim may
// open it. The event's fd is always closed on return, and in permission
// mode a response is always written: FAN_ALLOW unless one of the branches
// below switched it to FAN_DENY.
//
// FAN_DENY is chosen when (a) the message is infected and quarantine
// succeeded, or (b) fail_mode is "tempfail" in permission mode and the scan
// was incomplete (parse error, partial extraction, all engines down, engine
// timeout/error) or quarantine failed. Everything else is delivered.
func (sw *SpoolWatcher) handleSpoolEvent(evt spoolEvent) {
// CRITICAL: deferred FAN_ALLOW - every code path must allow by default.
// Only overridden to FAN_DENY when policy requires deferral or quarantine.
response := uint32(FAN_ALLOW)
defer func() {
if evt.needResp {
// #nosec G115 -- evt.fd is a POSIX fd; fits in int32.
sw.writeResponse(int32(evt.fd), response)
}
_ = unix.Close(evt.fd)
}()
// Derive message ID: strip -D suffix and directory
base := filepath.Base(evt.path)
msgID := strings.TrimSuffix(base, "-D")
spoolDir := filepath.Dir(evt.path)
// The header file sits next to the body file as <msgID>-H.
headerPath := filepath.Join(spoolDir, msgID+"-H")
bodyPath := evt.path
// MIME parse - fail-open on error
limits := emime.Limits{
MaxAttachmentSize: sw.cfg.EmailAV.MaxAttachmentSize,
MaxArchiveDepth: sw.cfg.EmailAV.MaxArchiveDepth,
MaxArchiveFiles: sw.cfg.EmailAV.MaxArchiveFiles,
MaxExtractionSize: sw.cfg.EmailAV.MaxExtractionSize,
TempDir: sw.emailAVTempDir,
}
// tempfail is only effective when a response can actually be written.
tempfail := sw.cfg.EmailAV.FailMode == "tempfail" && evt.needResp
extraction, err := emime.ParseSpoolMessage(headerPath, bodyPath, limits)
if err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: MIME parse error for %s: %v\n", ts(), msgID, err)
sw.emitFinding("email_av_parse_error", alert.Warning, fmt.Sprintf("MIME parse failed for message %s: %v", msgID, err))
if tempfail {
response = FAN_DENY // tempfail: Exim retries later
}
return
}
// Clean up temp files when done
defer func() {
for _, p := range extraction.Parts {
os.Remove(p.TempPath)
}
}()
if extraction.Partial {
partialResult := &emailav.ScanResult{PartialExtraction: true}
if shouldTempfailEmailDelivery(tempfail, partialResult, nil) {
response = FAN_DENY
sw.emitDegradedWarning(fmt.Sprintf("Incomplete email attachment extraction for message %s (%s) - delivery deferred (tempfail mode)", msgID, partialExtractionReason(extraction)))
return
}
sw.emitDegradedWarning(fmt.Sprintf("Incomplete email attachment extraction for message %s (%s) - delivery allowed (fail-open mode)", msgID, partialExtractionReason(extraction)))
}
if len(extraction.Parts) == 0 {
return // No attachments to scan - allow
}
// Scan
result := sw.orchestrator.ScanParts(msgID, extraction.Parts, extraction.Partial)
// Emit degraded/timeout findings for operator visibility
if result.AllEnginesDown {
if shouldTempfailEmailDelivery(tempfail, result, nil) {
response = FAN_DENY // tempfail: defer delivery until engines recover
sw.emitDegradedWarning(fmt.Sprintf("All AV engines unavailable - message %s deferred (tempfail mode)", msgID))
return
}
sw.emitDegradedWarning(fmt.Sprintf("All AV engines unavailable - message %s delivered unscanned", msgID))
}
if len(result.TimedOutEngines) > 0 {
sw.emitFinding("email_av_timeout", alert.Warning,
fmt.Sprintf("Scan timeout for message %s on engine(s): %s", msgID, strings.Join(result.TimedOutEngines, ", ")))
if shouldTempfailEmailDelivery(tempfail, result, nil) {
sw.emitDegradedWarning(fmt.Sprintf("Incomplete AV scan - message %s deferred after engine timeout", msgID))
response = FAN_DENY
return
}
}
if len(result.ErroredEngines) > 0 {
sw.emitFinding("email_av_scan_error", alert.Warning,
fmt.Sprintf("Scan error for message %s on engine(s): %s", msgID, strings.Join(result.ErroredEngines, ", ")))
if shouldTempfailEmailDelivery(tempfail, result, nil) {
sw.emitDegradedWarning(fmt.Sprintf("Incomplete AV scan - message %s deferred after engine error", msgID))
response = FAN_DENY
return
}
}
if !result.Infected {
return // Clean - allow
}
// Infected - attempt quarantine
env := emailav.QuarantineEnvelope{
From: extraction.From,
To: extraction.To,
Subject: extraction.Subject,
Direction: extraction.Direction,
}
// With quarantine disabled the response stays FAN_ALLOW; the infection
// is still reported through the email_malware finding below.
if sw.cfg.EmailAV.QuarantineInfected {
if err := sw.quarantine.QuarantineMessage(msgID, spoolDir, result, env); err != nil {
fmt.Fprintf(os.Stderr, "[%s] spool watcher: quarantine failed for %s: %v\n", ts(), msgID, err)
sw.emitFinding("email_av_quarantine_error", alert.Warning,
fmt.Sprintf("Quarantine failed for infected message %s: %v", msgID, err))
if shouldTempfailEmailDelivery(tempfail, result, err) {
sw.emitDegradedWarning(fmt.Sprintf("Quarantine failed for infected message %s - delivery deferred (tempfail mode)", msgID))
response = FAN_DENY
}
} else {
// Quarantine succeeded - deny the open so Exim can't deliver
response = FAN_DENY
}
}
// Emit alert finding
// Each detection is rendered as "signature(engine)".
sigNames := make([]string, len(result.Findings))
for i, f := range result.Findings {
sigNames[i] = fmt.Sprintf("%s(%s)", f.Signature, f.Engine)
}
msg := fmt.Sprintf("Malware detected in %s email from %s to %s: %s [subject: %s]",
extraction.Direction, extraction.From, strings.Join(extraction.To, ","),
strings.Join(sigNames, ", "), extraction.Subject)
sw.emitFinding("email_malware", alert.Critical, msg)
}
// partialExtractionReason returns the reason recorded on a partial
// extraction, or the generic text "partial extraction" when none was set.
func partialExtractionReason(extraction *emime.ExtractionResult) string {
	reason := extraction.PartialReason
	if reason == "" {
		reason = "partial extraction"
	}
	return reason
}
// writeResponse answers the permission event identified by fd with response
// (FAN_ALLOW or FAN_DENY) by writing a fanotifyResponse to the fanotify
// descriptor.
//
// Once the fanotify descriptor has been closed (Stop, drainAndClose or an
// earlier failed write) the call is a no-op. Workers can still be finishing
// a scan at that point, and the old descriptor number may already belong to
// an unrelated file or socket, so writing to it could corrupt that
// descriptor. Closing the fanotify fd is what releases outstanding
// permission events (see the failure path below), so nothing is left to
// answer.
func (sw *SpoolWatcher) writeResponse(fd int32, response uint32) {
	if atomic.LoadInt32(&sw.fdClosed) != 0 {
		return
	}

	resp := fanotifyResponse{Fd: fd, Response: response}
	// #nosec G103 -- serializing the fanotify response struct for the
	// kernel write; unsafe cast to a byte slice of the exact struct size.
	respBytes := (*[responseSize]byte)(unsafe.Pointer(&resp))[:]
	if _, err := unix.Write(sw.fd, respBytes); err != nil {
		// The kernel holds blocked processes until a response is written or
		// the fanotify fd is closed. A failed write means the fd is broken -
		// close it to release ALL pending permission events (fail-open),
		// then signal the event loop to exit so the daemon can restart us.
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: FATAL - fanotify response write failed: %v - closing fd to release pending events\n", ts(), err)
		sw.closeFd()
		sw.Stop()
	}
}
// closeFd closes the fanotify descriptor. The fdClosed flag is flipped with
// a compare-and-swap, so only the first caller performs the close no matter
// how many shutdown paths reach this point.
func (sw *SpoolWatcher) closeFd() {
	if !atomic.CompareAndSwapInt32(&sw.fdClosed, 0, 1) {
		return
	}
	_ = unix.Close(sw.fd)
}
// emitFinding queues a finding with the given check name, severity and
// message on the alert channel. The send never blocks: when the channel is
// full the finding is dropped.
func (sw *SpoolWatcher) emitFinding(check string, severity alert.Severity, message string) {
	finding := alert.Finding{
		Severity: severity,
		Check:    check,
		Message:  message,
	}
	select {
	case sw.alertCh <- finding:
	default:
		// Alert channel full - drop
	}
}
// emitDegradedWarning reports an email_av_degraded warning, at most once per
// minute. The limit keeps the alert channel from being flooded while clamd
// is down; warnings arriving inside the quiet minute are discarded.
func (sw *SpoolWatcher) emitDegradedWarning(message string) {
	sw.degradedMu.Lock()
	throttled := time.Since(sw.lastDegradedAt) < time.Minute
	if !throttled {
		sw.lastDegradedAt = time.Now()
	}
	sw.degradedMu.Unlock()

	if throttled {
		return
	}
	sw.emitFinding("email_av_degraded", alert.Warning, message)
}
// drainAndClose tears the watcher down exactly once: it closes scanCh,
// waits for the scan workers, releases any events they left behind, and
// closes the fanotify descriptor and both ends of the stop pipe.
//
// Workers exit as soon as stopCh is closed, so events can still be buffered
// in scanCh when they are gone. Each of those carries an open per-event
// descriptor that nobody else will close; they are closed here so a
// stop/restart cycle does not leak descriptors. No response is written for
// them: the closeFd call that follows closes the fanotify descriptor, which
// releases every outstanding permission event.
//
// It must only be called from the goroutine that sends on scanCh (Run).
func (sw *SpoolWatcher) drainAndClose() {
	sw.drainOnce.Do(func() {
		close(sw.scanCh)
		sw.wg.Wait()

		// scanCh is closed, so this loop ends once the buffer is empty.
		for evt := range sw.scanCh {
			_ = unix.Close(evt.fd)
		}

		sw.closeFd()
		if atomic.CompareAndSwapInt32(&sw.pipeClosed, 0, 1) {
			_ = unix.Close(sw.pipeFds[0])
			_ = unix.Close(sw.pipeFds[1])
		}
	})
}
// PermissionMode reports which fanotify mode the watcher ended up in: true
// for FAN_OPEN_PERM permission events, false for the FAN_CLOSE_WRITE
// notification fallback.
func (sw *SpoolWatcher) PermissionMode() bool {
	usingPermissionEvents := sw.permissionMode
	return usingPermissionEvents
}
// resolveEmailAVTempDir creates (if needed) and locks down the private
// staging directory "<state_path>/emailav-tmp" that CreateTemp uses for
// extracted email parts, and returns its path.
//
// It fails when cfg is nil, when state_path is empty, when the directory
// cannot be created, or when secureEmailAVTempDir rejects it.
func resolveEmailAVTempDir(cfg *config.Config) (string, error) {
	switch {
	case cfg == nil:
		return "", fmt.Errorf("email AV temp dir: nil config")
	case cfg.StatePath == "":
		return "", fmt.Errorf("email AV temp dir: state_path is empty")
	}

	tempDir := filepath.Join(cfg.StatePath, "emailav-tmp")
	if mkErr := os.MkdirAll(tempDir, 0o700); mkErr != nil {
		return "", fmt.Errorf("creating email AV temp dir %s: %w", tempDir, mkErr)
	}
	if secErr := secureEmailAVTempDir(tempDir); secErr != nil {
		return "", secErr
	}
	return tempDir, nil
}
// secureEmailAVTempDir makes sure dir is a real directory (not a symlink),
// owned by the effective uid and with mode 0700, fixing ownership and mode
// where it can and re-checking everything afterwards.
//
// Ownership is only repaired when running as root; a non-root daemon facing
// a directory owned by someone else gets an error instead. The final Lstat
// re-validates type, mode and owner after the chmod.
func secureEmailAVTempDir(dir string) error {
	before, err := os.Lstat(dir)
	if err != nil {
		return fmt.Errorf("checking email AV temp dir %s: %w", dir, err)
	}
	switch {
	case before.Mode()&os.ModeSymlink != 0:
		return fmt.Errorf("email AV temp dir %s is a symlink", dir)
	case !before.IsDir():
		return fmt.Errorf("email AV temp dir %s is not a directory", dir)
	}

	if uid, gid, ok := fileOwner(before); ok {
		euid := os.Geteuid()
		wrongOwner := uid != euid
		if wrongOwner && euid != 0 {
			return fmt.Errorf("email AV temp dir %s is owned by uid %d, want uid %d", dir, uid, euid)
		}
		// Root takes the directory over when the owner is wrong, or when
		// only the group differs.
		if wrongOwner || (gid != os.Getegid() && euid == 0) {
			if chownErr := os.Chown(dir, euid, os.Getegid()); chownErr != nil {
				return fmt.Errorf("owning email AV temp dir %s: %w", dir, chownErr)
			}
		}
	}

	// #nosec G302 -- 0700 is the intended directory mode: daemon-only
	// execute+read+write so unprivileged uids cannot enumerate or race
	// staged email attachments. gosec's <=0600 rule does not distinguish
	// directories from regular files.
	if chmodErr := os.Chmod(dir, 0o700); chmodErr != nil {
		return fmt.Errorf("chmod email AV temp dir %s: %w", dir, chmodErr)
	}

	after, err := os.Lstat(dir)
	if err != nil {
		return fmt.Errorf("checking email AV temp dir %s after chmod: %w", dir, err)
	}
	switch {
	case after.Mode()&os.ModeSymlink != 0:
		return fmt.Errorf("email AV temp dir %s became a symlink", dir)
	case !after.IsDir():
		return fmt.Errorf("email AV temp dir %s is not a directory", dir)
	case after.Mode().Perm() != 0o700:
		return fmt.Errorf("email AV temp dir %s mode is %o, want 700", dir, after.Mode().Perm())
	}
	if uid, _, ok := fileOwner(after); ok && uid != os.Geteuid() {
		return fmt.Errorf("email AV temp dir %s is owned by uid %d, want uid %d", dir, uid, os.Geteuid())
	}
	return nil
}
// fileOwner extracts the owning uid and gid from info. ok is false, and
// both ids are 0, when the underlying data is not a *syscall.Stat_t.
func fileOwner(info os.FileInfo) (uid, gid int, ok bool) {
	if st, isStat := info.Sys().(*syscall.Stat_t); isStat {
		return int(st.Uid), int(st.Gid), true
	}
	return 0, 0, false
}
// Stop asks the event loop to exit. It is safe to call more than once; only
// the first call acts. It closes stopCh, writes one byte to the stop pipe
// (unless the pipe is already closed) to wake the epoll wait, and closes
// the fanotify descriptor.
func (sw *SpoolWatcher) Stop() {
	sw.stopOnce.Do(func() {
		close(sw.stopCh)
		pipeOpen := atomic.LoadInt32(&sw.pipeClosed) == 0
		if pipeOpen {
			wake := []byte{0}
			_, _ = unix.Write(sw.pipeFds[1], wake)
		}
		sw.closeFd()
	})
}
package daemon
import "github.com/pidginhost/csm/internal/emailav"
// shouldTempfailEmailDelivery decides whether delivery of a message must be
// deferred. It is always false when tempfail mode is off. With tempfail on,
// a quarantine error defers delivery; otherwise delivery is deferred only
// if result is non-nil and reports an incomplete scan (partial extraction,
// all engines down, or any timed-out or errored engine).
func shouldTempfailEmailDelivery(tempfail bool, result *emailav.ScanResult, quarantineErr error) bool {
	switch {
	case !tempfail:
		return false
	case quarantineErr != nil:
		return true
	case result == nil:
		return false
	}
	if result.PartialExtraction || result.AllEnginesDown {
		return true
	}
	return len(result.TimedOutEngines) > 0 || len(result.ErroredEngines) > 0
}
package daemon
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/store"
)
// Tunables for the log watcher.
const (
// NOTE(review): not referenced in the visible part of this file; by its
// name, a hold window for recently sent outgoing mail - confirm at use site.
recentOutgoingMailHoldWindow = 2 * time.Hour
// NOTE(review): presumably the maximum length of one log line the watcher
// will buffer (256 KiB) - confirm in readNewLines.
logWatcherMaxLineBytes = 256 * 1024
// logWatcherOffsetMarkerBytes is the maximum number of bytes immediately
// before the saved offset that readOffsetMarker captures.
logWatcherOffsetMarkerBytes = 256
)
// LogLineHandler parses a log line and returns findings (if any).
// line is the log line to parse and cfg the configuration to evaluate it
// against; a handler may return nil.
type LogLineHandler func(line string, cfg *config.Config) []alert.Finding
// LogWatcher tails a log file and processes new lines.
// NOTE(review): the doc comment on Run says the implementation polls rather
// than using inotify.
type LogWatcher struct {
// path is the log file path given at construction.
path string
// cfg is the startup config snapshot; currentCfg prefers the live config.
cfg *config.Config
// handler turns one log line into findings.
handler LogLineHandler
// alertCh is the channel findings are sent on.
alertCh chan<- alert.Finding
// file is the open log file; owned by Run for its lifetime.
file *os.File
// offset is the read position; initialised to the end of the file so only
// new lines are processed.
offset int64
// fileID is the device/inode identity of file at construction time.
fileID logFileID
// marker holds the bytes just before offset (see readOffsetMarker); nil
// when no trustworthy marker could be read.
marker []byte
// NOTE(review): presumably guards closeFile against double close - confirm.
closeOnce sync.Once
}
// logFileID identifies a file by device and inode number so two FileInfo
// values can be compared for "same underlying file".
type logFileID struct {
// dev and ino come from the file's syscall.Stat_t.
dev uint64
ino uint64
// known is false when the identity could not be determined (no Stat_t, or
// both dev and ino are zero); an unknown ID never compares as same.
known bool
}
// fileID derives the device/inode identity of the file described by info.
// The zero (unknown) logFileID is returned when info is nil or does not
// carry a *syscall.Stat_t; an all-zero dev/ino pair is also marked unknown.
func fileID(info os.FileInfo) logFileID {
	var id logFileID
	if info == nil {
		return id
	}
	st, ok := info.Sys().(*syscall.Stat_t)
	if !ok {
		return id
	}
	id.dev = uint64(st.Dev) // #nosec G115 -- device IDs are non-negative on supported Unix hosts
	id.ino = uint64(st.Ino) // #nosec G115 -- inode numbers are non-negative
	id.known = st.Dev != 0 || st.Ino != 0
	return id
}
// same reports whether id and other refer to the same file. Two IDs match
// only when both are known and their device and inode numbers are equal.
func (id logFileID) same(other logFileID) bool {
	if !id.known || !other.known {
		return false
	}
	return id.dev == other.dev && id.ino == other.ino
}
// readOffsetMarker reads the bytes of f that end exactly at offset, up to
// logWatcherOffsetMarkerBytes of them (fewer when offset is smaller).
//
// It returns the bytes read, whether the full requested span was available,
// and any read error other than io.EOF. A nil file or a non-positive offset
// yields (nil, true, nil): there is nothing to capture.
func readOffsetMarker(f *os.File, offset int64) ([]byte, bool, error) {
	if f == nil || offset <= 0 {
		return nil, true, nil
	}

	size := offset
	if size > logWatcherOffsetMarkerBytes {
		size = logWatcherOffsetMarkerBytes
	}

	tail := make([]byte, size)
	got, err := f.ReadAt(tail, offset-size)
	if err != nil && !errors.Is(err, io.EOF) {
		return nil, false, err
	}
	// A short read (EOF before the span was filled) is reported through
	// the boolean, not as an error.
	return tail[:got], got == len(tail), nil
}
// NewLogWatcher opens path and returns a watcher positioned at the current
// end of the file, so only lines appended afterwards are processed. The file
// identity and the marker bytes preceding the end offset are captured so later
// reads can detect rotation and truncation. An error is returned when the file
// cannot be opened, seeked, stat'ed, or when reading the marker fails; the
// handle is closed on every error path.
func NewLogWatcher(path string, cfg *config.Config, handler LogLineHandler, alertCh chan<- alert.Finding) (*LogWatcher, error) {
	// #nosec G304 -- path is operator-configured log path from csm.yaml.
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// abort releases the handle on any failure after the open succeeded.
	abort := func(cause error) (*LogWatcher, error) {
		_ = f.Close()
		return nil, cause
	}

	end, err := f.Seek(0, io.SeekEnd)
	if err != nil {
		return abort(err)
	}
	info, err := f.Stat()
	if err != nil {
		return abort(err)
	}
	marker, complete, err := readOffsetMarker(f, end)
	if err != nil {
		return abort(fmt.Errorf("read offset marker for %s: %w", path, err))
	}
	if !complete {
		// The file rotated or shrank between Seek and ReadAt. Starting without
		// a marker is preferable to failing: a constructor error would disable
		// this watcher until daemon restart, whereas readNewLines treats an
		// offset without a marker as untrusted and reads from the beginning.
		marker = nil
	}

	w := &LogWatcher{
		path:    path,
		cfg:     cfg,
		handler: handler,
		alertCh: alertCh,
		file:    f,
		offset:  end,
		fileID:  fileID(info),
		marker:  marker,
	}
	return w, nil
}
// currentCfg returns the config the log-line handlers should use right now:
// the active config published by config.Active when there is one, otherwise
// the snapshot the watcher was constructed with. Preferring the active config
// lets reloaded thresholds, infra_ips, trusted_countries, and suppression
// settings reach the handlers without a restart.
func (w *LogWatcher) currentCfg() *config.Config {
	live := config.Active()
	if live == nil {
		return w.cfg
	}
	return live
}
// Run watches the log file until stopCh is closed or receives a value. It
// polls for new lines every 2 seconds rather than using inotify, and reopens
// the file every 5 minutes so log rotation is picked up.
//
// Run owns the file for its lifetime and closes it on return. Closing here,
// rather than from a separate shutdown goroutine, keeps w.file
// single-threaded: a concurrent close would race readNewLines/reopen and the
// freed fd could be reused by another goroutine mid-Stat/Read.
func (w *LogWatcher) Run(stopCh <-chan struct{}) {
	defer w.closeFile()

	const (
		pollInterval   = 2 * time.Second
		reopenInterval = 5 * time.Minute
	)

	poll := time.NewTicker(pollInterval)
	defer poll.Stop()

	rotate := time.NewTicker(reopenInterval)
	defer rotate.Stop()

	for {
		select {
		case <-stopCh:
			return
		case <-rotate.C:
			w.reopen()
		case <-poll.C:
			w.readNewLines()
		}
	}
}
// Stop closes the watcher's file. Safe to call when Run was never started
// (unit tests open a watcher just to drive readNewLines directly). When Run is
// active it closes the file itself on stopCh, so the daemon must not call Stop
// concurrently with a running Run.
//
// Stop delegates to closeFile, which is guarded by closeOnce, so repeated
// calls close the handle only once.
func (w *LogWatcher) Stop() {
w.closeFile()
}
// closeFile closes the currently open handle, if any. The work runs under
// closeOnce, so only the first call has any effect; the Close error is
// deliberately ignored because the file is only ever read.
func (w *LogWatcher) closeFile() {
	w.closeOnce.Do(func() {
		if w.file == nil {
			return
		}
		_ = w.file.Close()
	})
}
// readNewLines processes whatever was appended to the log since the previous
// call. It reopens the file when there is no handle, when Stat fails, or when
// the file is now shorter than the saved offset; it restarts from offset 0
// when the bytes preceding the offset no longer match the saved marker. Every
// non-empty line within the size cap is passed to the handler, and findings
// are sent to alertCh without blocking (dropped with a warning when full).
// Finally the offset and marker are advanced to the current file position.
//
// NOTE(review): a final line that has no trailing newline yet is handed to the
// handler as-is and the offset moves past it, so a line that is still being
// written may be seen as two fragments -- confirm writers append whole lines.
func (w *LogWatcher) readNewLines() {
if w.file == nil {
w.reopen()
if w.file == nil {
return
}
}
info, err := w.file.Stat()
if err != nil {
w.reopen()
return
}
// File was truncated or rotated (smaller than our offset)
if info.Size() < w.offset {
w.reopen()
return
}
// The content before the saved offset changed (or no marker was saved for a
// non-zero offset): distrust the offset and read the file from the start.
if !w.offsetMarkerMatches(w.file) {
w.offset = 0
w.marker = nil
}
// No new data
if info.Size() == w.offset {
return
}
// Seek to where we left off
_, err = w.file.Seek(w.offset, io.SeekStart)
if err != nil {
return
}
reader := bufio.NewReaderSize(w.file, 64*1024)
for {
rawLine, truncated, readErr := readBoundedWatcherLine(reader, logWatcherMaxLineBytes)
if truncated {
fmt.Fprintf(os.Stderr, "[%s] Warning: skipped oversized log line from %s at %d bytes\n", ts(), w.path, logWatcherMaxLineBytes)
}
// Oversized lines are skipped entirely rather than parsed in part.
if len(rawLine) > 0 && !truncated {
line := trimWatcherLineEnding(rawLine)
if line == "" {
if readErr != nil {
break
}
continue
}
findings := w.handler(line, w.currentCfg())
for _, f := range findings {
// Handlers may leave the timestamp unset; stamp it at emit time.
if f.Timestamp.IsZero() {
f.Timestamp = time.Now()
}
select {
case w.alertCh <- f:
default:
// Channel full - drop (backpressure)
fmt.Fprintf(os.Stderr, "[%s] Warning: alert channel full, dropping finding from %s\n", ts(), w.path)
}
}
}
// Any read error (normally io.EOF) ends this pass.
if readErr != nil {
break
}
}
// Update offset
newOffset, err := w.file.Seek(0, io.SeekCurrent)
if err == nil {
w.offset = newOffset
w.refreshOffsetMarker()
}
}
func readBoundedWatcherLine(r *bufio.Reader, maxBytes int) (string, bool, error) {
var b strings.Builder
truncated := false
for {
chunk, err := r.ReadSlice('\n')
if len(chunk) > 0 {
switch {
case truncated:
case b.Len()+len(chunk) <= maxBytes:
b.Write(chunk)
default:
if room := maxBytes - b.Len(); room > 0 {
b.Write(chunk[:room])
}
truncated = true
}
}
if errors.Is(err, bufio.ErrBufferFull) {
continue
}
return b.String(), truncated, err
}
}
// trimWatcherLineEnding strips a single trailing "\n" and then a single
// trailing "\r", so both LF and CRLF line endings are removed. Nothing else
// is trimmed.
func trimWatcherLineEnding(line string) string {
	withoutLF := strings.TrimSuffix(line, "\n")
	withoutCR := strings.TrimSuffix(withoutLF, "\r")
	return withoutCR
}
// reopen closes the current handle and opens w.path again, then decides
// whether the saved offset is still valid for the file now behind that name.
// The offset is reset to 0 when the file identity changed (rotation by
// rename+create), when the file is shorter than the offset (truncation in
// place), or when the bytes preceding the offset no longer match the saved
// marker. Otherwise the offset is kept, so lines written since the last read
// are not skipped. If the open or Stat fails the watcher is left without a
// file and the next readNewLines retries.
func (w *LogWatcher) reopen() {
	if old := w.file; old != nil {
		_ = old.Close()
		// Drop the closed handle so a failed open below does not leave a dead
		// fd behind for the next readNewLines to Stat in a loop.
		w.file = nil
	}

	f, err := os.Open(w.path)
	if err != nil {
		return
	}
	info, err := f.Stat()
	if err != nil {
		_ = f.Close()
		return
	}
	w.file = f
	id := fileID(info)

	// Checked in this order, stopping at the first hit:
	//   rotated   - a different file now lives at the path;
	//   shrunk    - copytruncate-style truncation in place;
	//   mismatch  - truncate-and-regrow between ticks, or inode reuse.
	rotated := w.fileID.known && id.known && !w.fileID.same(id)
	if rotated || info.Size() < w.offset || !w.offsetMarkerMatches(f) {
		w.offset = 0
	}

	if w.offset != 0 {
		w.refreshOffsetMarker()
	} else {
		w.marker = nil
	}
	w.fileID = id
}
// offsetMarkerMatches reports whether the bytes preceding w.offset in f are
// still the ones saved in w.marker. An offset of 0 always matches. A non-zero
// offset with no saved marker never matches, and neither does a read that
// fails or comes back short.
func (w *LogWatcher) offsetMarkerMatches(f *os.File) bool {
	switch {
	case w.offset == 0:
		return true
	case len(w.marker) == 0:
		return false
	}
	current, complete, err := readOffsetMarker(f, w.offset)
	if err != nil || !complete {
		return false
	}
	return bytes.Equal(current, w.marker)
}
// refreshOffsetMarker re-reads the marker bytes preceding w.offset from the
// current file and stores them. If the read fails or is short, the marker is
// cleared instead.
func (w *LogWatcher) refreshOffsetMarker() {
	fresh, complete, err := readOffsetMarker(w.file, w.offset)
	if err == nil && complete {
		w.marker = fresh
		return
	}
	w.marker = nil
}
// --- Log line handlers ---
// parseSessionLogLine handles one cPanel session-log line and returns the
// findings it triggers.
//
// A "[cpaneld] ... NEW ..." line is a new session: the IP/account pair is
// recorded for purge correlation before any filtering, and a WARNING
// cpanel_login_realtime finding is emitted unless cPanel login alerts are
// suppressed, the session was created through create_user_session /
// create_session, or the IP is an infra IP or from a trusted country.
//
// A line containing both "PURGE" and "password_change" records the purge and
// emits a HIGH cpanel_password_purge_realtime finding for the account.
func parseSessionLogLine(line string, cfg *config.Config) []alert.Finding {
	var findings []alert.Finding

	if strings.Contains(line, "[cpaneld]") && strings.Contains(line, " NEW ") {
		// Parse once; the same pair feeds purge correlation and alerting.
		ip, account := parseCpanelSessionLogin(line)
		identified := ip != "" && account != ""
		if identified {
			// Track IP→account for purge correlation (before any filtering).
			purgeTracker.recordLogin(ip, account)
		}

		// Portal-created sessions are not alerted on. The bare token also
		// covers the "method=create_user_session" spelling.
		portalSession := strings.Contains(line, "create_user_session") ||
			strings.Contains(line, "method=create_session")

		if !cfg.Suppressions.SuppressCpanelLogin && !portalSession && identified &&
			!isInfraIPDaemon(ip, cfg.InfraIPs) &&
			!isTrustedCountry(ip, cfg.Suppressions.TrustedCountries) {
			// WARNING severity - logins are useful for the audit trail but are
			// not paging-level. Multi-IP correlation and brute-force stay at
			// CRITICAL/HIGH via their own checks.
			method := "unknown"
			switch at := strings.Index(line, "method="); {
			case strings.Contains(line, "method=handle_form_login"):
				method = "direct form login"
			case at >= 0:
				tail := line[at+len("method="):]
				if end := strings.IndexAny(tail, ",\n "); end > 0 {
					method = tail[:end]
				}
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "cpanel_login_realtime",
				Message:  fmt.Sprintf("cPanel direct login from non-infra IP: %s (account: %s, method: %s)", ip, account, method),
				Details:  truncateDaemon(line, 300),
				SourceIP: ip,
				TenantID: account,
			})
		}
	}

	// Password purge
	if strings.Contains(line, "PURGE") && strings.Contains(line, "password_change") {
		if account := parsePurgeDaemon(line); account != "" {
			purgeTracker.recordPurge(account)
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "cpanel_password_purge_realtime",
				Message:  fmt.Sprintf("cPanel password purge for: %s", account),
				TenantID: account,
			})
		}
	}

	return findings
}
// parseSecureLogLine handles one auth-log line. For a line containing
// "Accepted" it takes the token after the first "from" as the client IP and
// the token after the first "for" as the user, and emits a CRITICAL
// ssh_login_realtime finding unless the IP is an infra IP or 127.0.0.1. When
// no user token is found the user is reported as "unknown" and the finding
// carries no TenantID.
func parseSecureLogLine(line string, cfg *config.Config) []alert.Finding {
	if !strings.Contains(line, "Accepted") {
		return nil
	}

	tokens := strings.Fields(line)
	// following returns the token right after the first occurrence of keyword
	// that actually has a successor.
	following := func(keyword string) (string, bool) {
		for i := 0; i+1 < len(tokens); i++ {
			if tokens[i] == keyword {
				return tokens[i+1], true
			}
		}
		return "", false
	}

	ip, ok := following("from")
	if !ok {
		return nil
	}
	if isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
		return nil
	}

	user, tenant := "unknown", ""
	if name, found := following("for"); found {
		user = name
		if name != "unknown" {
			tenant = name
		}
	}

	return []alert.Finding{{
		Severity: alert.Critical,
		Check:    "ssh_login_realtime",
		Message:  fmt.Sprintf("SSH login from non-infra IP: %s (user: %s)", ip, user),
		Details:  truncateDaemon(line, 200),
		SourceIP: ip,
		TenantID: tenant,
	}}
}
// parseEximLogLine inspects one exim log line and returns every finding it
// triggers. The numbered checks below run in order and are independent, so a
// single line can produce several findings. The one exception is check 3,
// which returns early (skipping the later checks) when the domain was recently
// seen on an outgoing mail hold.
//
// Besides returning findings, the function has side effects: it records mail
// holds and compromised domains, writes dedup timestamps to the global store,
// may hold outgoing mail through maybeHoldOutgoingMail, feeds the per-user
// send-rate tracking via checkEmailRate, and calls
// handleCloudRelayCredentialAbuse once per cloud-relay finding.
//
// The deduplicated checks (2, 7, 8) only emit a finding when store.Global()
// returns a non-nil store.
func parseEximLogLine(line string, cfg *config.Config) []alert.Finding {
var findings []alert.Finding
// 1. Frozen bounces - spam indicator
if strings.Contains(line, "frozen") {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "exim_frozen_realtime",
Message: "Exim frozen message detected",
Details: truncateDaemon(line, 200),
})
}
// 2. Outgoing mail hold - account is held by cPanel.
// Format: "Sender office@example.com has an outgoing mail hold"
// or: "Domain example.org has an outgoing mail hold"
//
// Exim emits this rejection from the enforce_mail_permissions router on
// EVERY queued-message retry while the hold is active, so re-applying the
// hold here creates a feedback loop: an operator who clears a
// false-positive hold (e.g. caused by external transit defers like the
// 2026-05-11 Microsoft edge outage) sees CSM re-set the hold within
// seconds because old queued messages keep retrying. cPanel's
// TailWatch::Eximstats is the authoritative source for setting
// the hold. CSM records the hold so later retry-limit noise from
// the held domain is not promoted to a fresh spam outbreak.
if strings.Contains(line, "outgoing mail hold") {
sender := extractMailHoldSender(line)
if sender == "" {
sender = extractEximSender(line)
}
domain := extractDomainFromEmail(sender)
if domain == "" {
domain = sender // may already be a bare domain
}
if domain != "" {
recordRecentOutgoingMailHold(domain)
RecordCompromisedDomain(domain)
}
// Alert only once per domain per hour
dedupKey := "email_hold:" + domain
if db := store.Global(); db != nil {
lastAlert := db.GetMetaString(dedupKey)
if lastAlert != "" && !isDedupExpired(lastAlert, 1*time.Hour) {
// Already alerted for this domain recently - skip finding
} else {
_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "email_compromised_account",
Message: fmt.Sprintf("Account %s is on cPanel outgoing mail hold", sender),
Details: truncateDaemon(line, 300),
Mailbox: mailboxOnly(sender),
Domain: domain,
})
}
}
}
// 3. Max defers/failures exceeded.
//
// cPanel's TailWatch::Eximstats has already throttled the domain by the
// time exim emits this line from enforce_mail_permissions, so the line is
// not independent evidence of an outbound spam blast: the same governor
// trips on inbound junk, full mailboxes, and forwarder bounces. Escalate
// to a compromise (CRITICAL + auto-hold) only when CSM's own
// authenticated-send rate window corroborates a real outbound blast for
// the domain. Otherwise report a deliverability event and leave the hold
// to cPanel, so an operator who clears a false-positive hold is not
// immediately re-held.
if strings.Contains(line, "max defers and failures per hour") {
domain := extractEximDomain(line)
// A hold was recorded for this domain within the recent-hold window: treat
// the line as retry noise and stop processing it here.
if recentOutgoingMailHold(domain) {
return findings
}
if domainHasOutboundBlast(domain, cfg) {
held := maybeHoldOutgoingMail(cfg, domain)
if held {
recordRecentOutgoingMailHold(domain)
}
message := fmt.Sprintf("Spam outbreak: %s exceeded max defers/failures with high outbound volume", domain)
if held {
message += " - outgoing mail auto-suspended"
}
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "email_spam_outbreak",
Message: message,
Details: truncateDaemon(line, 300),
Domain: domain,
})
if domain != "" {
RecordCompromisedDomain(domain)
}
} else {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "email_defer_fail_governor",
Message: fmt.Sprintf("%s hit the cPanel defer/fail governor; no outbound spam volume observed", domain),
Details: truncateDaemon(line, 300),
Domain: domain,
})
}
}
// 4. SMTP credentials leaked in subject - compromised account
// Pattern: T="...host:port,user@domain,PASSWORD..." in the subject field
if strings.Contains(line, " <= ") && strings.Contains(line, "T=\"") {
subject := extractEximSubject(line)
subjectLower := strings.ToLower(subject)
// Detect credential patterns: host:port,user,password or
// SMTP credentials in subject (common in credential stuffing attacks)
if (strings.Contains(subject, ":587,") || strings.Contains(subject, ":465,") ||
strings.Contains(subject, ":25,")) &&
strings.Contains(subject, "@") {
sender := extractEximSender(line)
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "email_credential_leak",
Message: fmt.Sprintf("SMTP credentials leaked in email subject from %s", sender),
Details: fmt.Sprintf("The email subject contains what appears to be SMTP credentials (host:port,user,password). This account is likely compromised by a bulk mail service.\nSubject: %s", truncateDaemon(subject, 100)),
Mailbox: mailboxOnly(sender),
Domain: extractDomainFromEmail(sender),
})
}
// Also detect common spam subject patterns
if strings.Contains(subjectLower, "password") && strings.Contains(subjectLower, "smtp") {
sender := extractEximSender(line)
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "email_credential_leak",
Message: fmt.Sprintf("Suspicious email subject with SMTP/password keywords from %s", sender),
Details: truncateDaemon(line, 300),
Mailbox: mailboxOnly(sender),
Domain: extractDomainFromEmail(sender),
})
}
}
// 5. Authentication from known bulk mail services
if strings.Contains(line, " <= ") && strings.Contains(line, "A=dovecot_") {
knownSpamServices := []string{
"truelist.io", "sendinblue.com", "mailspree.co",
"bulkmailer.", "massmailsoftware.", "sendblaster.",
}
lineLower := strings.ToLower(line)
// Only the first matching service is reported for a given line.
for _, service := range knownSpamServices {
if strings.Contains(lineLower, service) {
sender := extractEximSender(line)
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "email_compromised_account",
Message: fmt.Sprintf("Compromised email account %s authenticated from bulk mail service %s", sender, service),
Details: truncateDaemon(line, 300),
Mailbox: mailboxOnly(sender),
Domain: extractDomainFromEmail(sender),
})
break
}
}
}
// 6. Dovecot auth failure - brute force indicator
// Format: "dovecot_login authenticator failed for H=(hostname) [IP]:port: 535 ... (set_id=user@domain)"
if strings.Contains(line, "authenticator failed") && strings.Contains(line, "dovecot") {
ip := extractBracketedIP(line)
account := extractSetID(line)
msg := "Email authentication failure"
if account != "" {
msg += " for " + account
}
if ip != "" {
msg += " from " + ip
}
// cPanel-local mailboxes log set_id as a bare local part with no
// "@domain"; treating it as a Mailbox would leave the structured
// field empty (mailboxOnly drops bare names) and force the
// correlator to fall back to SourceIP, splitting one targeted
// account across many attacker IPs. Route the bare form to
// TenantID so the incident groups by account.
mailbox, domain, tenant := splitMailAccount(account)
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "email_auth_failure_realtime",
Message: msg,
Details: truncateDaemon(line, 300),
SourceIP: ip,
Mailbox: mailbox,
Domain: domain,
TenantID: tenant,
})
}
// 7. DKIM signing failures
// Reported at most once per domain per 24 hours.
if dkimDomain := parseDKIMFailureDomain(line); dkimDomain != "" {
dedupKey := "dkim_fail:" + dkimDomain
if db := store.Global(); db != nil {
lastAlert := db.GetMetaString(dedupKey)
if lastAlert == "" || isDedupExpired(lastAlert, 24*time.Hour) {
_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "email_dkim_failure",
Message: fmt.Sprintf("DKIM signing failed for %s - check key file and DNS TXT record", dkimDomain),
Details: truncateDaemon(line, 300),
Timestamp: time.Now(),
Domain: dkimDomain,
})
}
}
}
// 8. SPF/DMARC outbound rejections
// Reported at most once per sender domain per 24 hours.
if spfDomain, spfReason := parseSPFDMARCRejection(line); spfDomain != "" {
dedupKey := "spf_reject:" + spfDomain
if db := store.Global(); db != nil {
lastAlert := db.GetMetaString(dedupKey)
if lastAlert == "" || isDedupExpired(lastAlert, 24*time.Hour) {
_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "email_spf_rejection",
Message: fmt.Sprintf("Outbound mail from %s rejected due to SPF/DMARC failure", spfDomain),
Details: fmt.Sprintf("Reason: %s\n%s", spfReason, truncateDaemon(line, 200)),
Timestamp: time.Now(),
Domain: spfDomain,
})
}
}
}
// 9. Outbound rate limiting for authenticated users
if strings.Contains(line, " <= ") && strings.Contains(line, "A=dovecot_") {
authUser := extractAuthUser(line)
if authUser != "" {
rateFindings := checkEmailRate(authUser, cfg)
findings = append(findings, rateFindings...)
}
}
// 10. Cloud-relay credential abuse (multiple authenticated sends from
// distinct cloud-provider IPs for the same mailbox).
// Use the AUTH identity (A=dovecot_*:<user>), not the envelope-from:
// the envelope-from can be forged by the attacker, while the AUTH
// identity is the credential actually being abused and the one we
// must lock out.
for _, f := range parseCloudRelayFinding(line, cfg) {
handleCloudRelayCredentialAbuse(cfg, extractAuthUser(line))
findings = append(findings, f)
}
// PHP-relay account-volume evaluation, only when an evaluator is available.
if eng := PHPRelayEvaluator(); eng != nil {
findings = append(findings, eng.parsePHPRelayAccountVolume(line, time.Now())...)
}
return findings
}
// extractEximSender returns the envelope sender of an exim arrival line: the
// first whitespace-delimited token after " <= ".
// Format: "... <= sender@domain.com H=..."
// It returns "" when the line has no " <= " or nothing follows it.
func extractEximSender(line string) string {
	_, afterArrow, found := strings.Cut(line, " <= ")
	if !found {
		return ""
	}
	tokens := strings.Fields(afterArrow)
	if len(tokens) == 0 {
		return ""
	}
	return tokens[0]
}
// extractEximDomain returns the token that follows the first "Domain " in an
// exim log line, as in "Domain X has exceeded ...". The token ends at the next
// space; without one, the rest of the line is returned. It returns "" when
// the line does not contain "Domain ".
func extractEximDomain(line string) string {
	_, rest, found := strings.Cut(line, "Domain ")
	if !found {
		return ""
	}
	// A space at position 0 is not treated as a terminator: in that case the
	// remainder is returned unchanged.
	if end := strings.IndexByte(rest, ' '); end >= 1 {
		return rest[:end]
	}
	return rest
}
// extractEximSubject returns the text between T=" and the next double quote
// in an exim log line. If the closing quote is missing, everything after T="
// is returned; if there is no T=" at all, the result is "".
func extractEximSubject(line string) string {
	_, afterTag, found := strings.Cut(line, "T=\"")
	if !found {
		return ""
	}
	// Cut hands back the whole input when no quote follows, which is exactly
	// the unterminated-subject behaviour wanted here.
	subject, _, _ := strings.Cut(afterTag, "\"")
	return subject
}
// --- Helpers (avoid import cycle with checks package) ---
// parseCpanelSessionLogin pulls the client IP and account out of a cPanel
// session-log line. The IP is the first token after "[cpaneld]"; the account
// is the part before ':' of the token following the first "NEW". Both results
// are "" when "[cpaneld]" is absent or fewer than three tokens follow it; the
// account alone is "" when no "NEW <token>" pair is present.
func parseCpanelSessionLogin(line string) (ip, account string) {
	_, tail, found := strings.Cut(line, "[cpaneld]")
	if !found {
		return "", ""
	}
	tokens := strings.Fields(tail)
	if len(tokens) < 3 {
		return "", ""
	}
	ip = tokens[0]
	for i := 0; i+1 < len(tokens); i++ {
		if tokens[i] != "NEW" {
			continue
		}
		account, _, _ = strings.Cut(tokens[i+1], ":")
		break
	}
	return ip, account
}
// parsePurgeDaemon returns the account named in a session-log PURGE line: the
// part before ':' of the first token after "PURGE". It returns "" when the
// line has no "PURGE" or nothing follows it.
func parsePurgeDaemon(line string) string {
	_, tail, found := strings.Cut(line, "PURGE")
	if !found {
		return ""
	}
	tokens := strings.Fields(tail)
	if len(tokens) == 0 {
		return ""
	}
	account, _, _ := strings.Cut(tokens[0], ":")
	return account
}
// isInfraIPDaemon reports whether ip belongs to the operator's infrastructure
// list. Each entry of infraNets is either a CIDR ("10.0.0.0/8") or a single
// address ("192.0.2.7"). It returns false when ip is not a valid IP address.
//
// Single-address entries are compared as parsed IPs rather than as strings,
// so the same address written in a different textual form (an expanded IPv6
// address, or an IPv4-mapped IPv6 address) still matches. Entries that are
// neither a CIDR nor an IP are ignored.
func isInfraIPDaemon(ip string, infraNets []string) bool {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return false
	}
	for _, entry := range infraNets {
		if _, network, err := net.ParseCIDR(entry); err == nil {
			if network.Contains(parsed) {
				return true
			}
			continue
		}
		// Not a CIDR: try the entry as a plain IP.
		if single := net.ParseIP(entry); single != nil && single.Equal(parsed) {
			return true
		}
	}
	return false
}
// mergeInfraIPs returns the entries of topLevel followed by those of
// fwSpecific, with duplicates removed and first-seen order kept. This lets
// the firewall include additional CIDRs (e.g. the server's own range) that
// need port access but should not suppress alerts. The result is nil when
// both inputs are empty.
func mergeInfraIPs(topLevel, fwSpecific []string) []string {
	// Bound the map size hint so a malformed config cannot push the sum near
	// max int; the lists are operator-supplied infra/firewall CIDRs and tens
	// of entries is the realistic ceiling.
	const infraIPHintCap = 1 << 16
	hint := len(topLevel) + len(fwSpecific)
	if hint < 0 || hint > infraIPHintCap {
		hint = infraIPHintCap
	}

	seen := make(map[string]struct{}, hint)
	var merged []string
	for _, list := range [][]string{topLevel, fwSpecific} {
		for _, entry := range list {
			if _, dup := seen[entry]; dup {
				continue
			}
			seen[entry] = struct{}{}
			merged = append(merged, entry)
		}
	}
	return merged
}
// outgoingMailHoldUsersPath is the cPanel file listing users currently
// under OUTGOING_MAIL_HOLD. Read-only; cPanel/WHM owns mutation. var
// (not const) so tests can point it at a fixture.
// userOnOutgoingMailHold reads it line by line and compares each trimmed line
// with the username.
var outgoingMailHoldUsersPath = "/etc/outgoing_mail_hold_users"
// whmapi1HoldExec runs `whmapi1 hold_outgoing_email user=<user>` with a
// 10-second timeout and returns the command's combined stdout/stderr together
// with its error. It is a package variable so tests can substitute a fake and
// avoid spawning whmapi1.
var whmapi1HoldExec = func(user string) ([]byte, error) {
	const timeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// #nosec G204 -- whmapi1 is the cPanel API binary; user is a cPanel
	// account name resolved from /etc/userdomains (cpanel-managed).
	cmd := exec.CommandContext(ctx, "whmapi1", "hold_outgoing_email", "user="+user)
	return cmd.CombinedOutput()
}
// userOnOutgoingMailHold reports whether user appears, as a whole
// whitespace-trimmed line, in the file at outgoingMailHoldUsersPath. It is
// used to skip redundant whmapi1 calls when exim keeps re-emitting "exceeded
// max defers/failures" while a hold is already active. An empty user or an
// unreadable file yields false.
func userOnOutgoingMailHold(user string) bool {
	if user == "" {
		return false
	}
	f, err := os.Open(outgoingMailHoldUsersPath)
	if err != nil {
		return false
	}
	defer func() { _ = f.Close() }()

	for lines := bufio.NewScanner(f); lines.Scan(); {
		if strings.TrimSpace(lines.Text()) == user {
			return true
		}
	}
	return false
}
// maybeHoldOutgoingMail holds outgoing mail for domainOrEmail only when
// auto-response is enabled and not in dry-run. Holding a customer's outbound
// mail is a customer-impacting action, so it follows the same master switch
// and dry-run default as the other automatic responses: in monitor mode the
// function only logs what it would have done.
//
// It returns true only when a hold was actually applied or was already
// active, so callers can keep their hold-dedup bookkeeping accurate. A nil
// cfg is treated like auto-response being disabled.
func maybeHoldOutgoingMail(cfg *config.Config, domainOrEmail string) bool {
	active := cfg != nil && cfg.AutoResponse.Enabled && !cfg.AutoResponseDryRunEnabled()
	if active {
		return autoSuspendOutgoingMail(domainOrEmail)
	}
	fmt.Fprintf(os.Stderr, "[%s] auto-suspend: would hold outgoing mail for %s (auto_response disabled or dry-run)\n",
		time.Now().Format("2006-01-02 15:04:05"), domainOrEmail)
	return false
}
// autoSuspendOutgoingMail holds outgoing mail for the cPanel account that
// owns the given domain or email address and reports whether the hold is
// applied or already active. It is a package variable so tests can swap in a
// recorder without spawning whmapi1.
var autoSuspendOutgoingMail = autoSuspendOutgoingMailReal

// autoSuspendOutgoingMailReal is the production implementation behind
// autoSuspendOutgoingMail. It resolves the domain to its cPanel user and, if
// that user is not already listed as held, asks whmapi1 to hold the user's
// outgoing mail. It returns false for empty input, an unknown domain, or a
// failed whmapi1 call, and true when the hold exists afterwards.
func autoSuspendOutgoingMailReal(domainOrEmail string) bool {
	if domainOrEmail == "" {
		return false
	}
	stamp := func() string { return time.Now().Format("2006-01-02 15:04:05") }

	// Everything after the last '@' is the domain. With no '@' the index is
	// -1, so the slice starts at 0 and the input is used unchanged.
	domain := domainOrEmail[strings.LastIndexByte(domainOrEmail, '@')+1:]

	user := lookupCPanelUser(domain)
	if user == "" {
		fmt.Fprintf(os.Stderr, "[%s] auto-suspend: could not find cPanel user for domain %s\n",
			stamp(), domain)
		return false
	}

	// Skip if cPanel already lists this user as held. Re-issuing the hold has
	// no operational effect, but during a sustained exim retry loop the
	// redundant whmapi1 calls would produce a stream of "AUTO-SUSPEND" log
	// lines that look like a fresh incident.
	if userOnOutgoingMailHold(user) {
		return true
	}

	out, err := whmapi1HoldExec(user)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] auto-suspend: whmapi1 hold_outgoing_email failed for %s: %v\n%s\n",
			stamp(), user, err, string(out))
		return false
	}
	fmt.Fprintf(os.Stderr, "[%s] AUTO-SUSPEND: outgoing mail held for cPanel user %s (domain: %s)\n",
		stamp(), user, domain)
	return true
}
// userdomainsPath is the cPanel domain→user map file. var (not const)
// so tests can point it at a fixture under t.TempDir(). Production must
// not mutate at runtime.
// lookupCPanelUser reads it, expecting one "domain: user" pair per line.
var userdomainsPath = "/etc/userdomains"
// lookupCPanelUser returns the cPanel username that owns domain, looked up in
// the file at userdomainsPath, which holds one "domain: user" pair per line.
// Domains are compared case-insensitively and lines without a ':' are
// skipped. It returns "" when the file cannot be opened or no line matches.
func lookupCPanelUser(domain string) string {
	f, err := os.Open(userdomainsPath)
	if err != nil {
		return ""
	}
	defer func() { _ = f.Close() }()

	for lines := bufio.NewScanner(f); lines.Scan(); {
		owned, owner, ok := strings.Cut(lines.Text(), ":")
		if !ok {
			continue
		}
		if strings.EqualFold(strings.TrimSpace(owned), domain) {
			return strings.TrimSpace(owner)
		}
	}
	return ""
}
// extractMailHoldSender returns the account or domain named in an "outgoing
// mail hold" message.
//
// Two formats are recognised, tried in this order:
//
//	"Sender user@domain has an outgoing mail hold" -> "user@domain"
//	"Domain example.com has an outgoing mail hold" -> "example.com"
//
// The value ends at the next space; without one, the rest of the line is
// returned. It returns "" when neither label is present.
func extractMailHoldSender(line string) string {
	for _, label := range []string{"Sender ", "Domain "} {
		at := strings.Index(line, label)
		if at < 0 {
			continue
		}
		rest := line[at+len(label):]
		if end := strings.IndexByte(rest, ' '); end > 0 {
			return rest[:end]
		}
		return rest
	}
	return ""
}
// recordRecentOutgoingMailHold stores the current time (RFC 3339) under the
// key "email_hold_seen:<domain>" in the global store, marking the domain as
// recently seen on an outgoing mail hold. It does nothing for an empty domain
// or when no global store is available; a failed write is ignored.
func recordRecentOutgoingMailHold(domain string) {
	if domain == "" {
		return
	}
	if db := store.Global(); db != nil {
		_ = db.SetMetaString("email_hold_seen:"+domain, time.Now().Format(time.RFC3339))
	}
}
// recentOutgoingMailHold reports whether a hold was recorded for domain
// within recentOutgoingMailHoldWindow. It returns false for an empty domain,
// when no global store is available, or when nothing was recorded.
func recentOutgoingMailHold(domain string) bool {
	if domain == "" {
		return false
	}
	db := store.Global()
	if db == nil {
		return false
	}
	seenAt := db.GetMetaString("email_hold_seen:" + domain)
	if seenAt == "" {
		return false
	}
	return !isDedupExpired(seenAt, recentOutgoingMailHoldWindow)
}
// extractBracketedIP returns the IP address found in the last "[...]" group
// of an exim log line, as in "[IP]:port" or "[IP]". The last bracket is used
// because it holds the client IP, not the hostname.
//
// The bracket content must parse as an IPv4 or IPv6 address. That accepts
// IPv6 addresses beginning with a hex letter (e.g. "fe80::1") and rejects
// non-address content that merely starts with a digit, such as a timestamp.
// It returns "" when there is no complete bracket group or its content is not
// an IP address.
func extractBracketedIP(line string) string {
	open := strings.LastIndex(line, "[")
	if open < 0 {
		return ""
	}
	rest := line[open+1:]
	end := strings.IndexByte(rest, ']')
	if end < 0 {
		return ""
	}
	candidate := rest[:end]
	if net.ParseIP(candidate) == nil {
		return ""
	}
	return candidate
}
// extractSetID returns the account that follows "set_id=" in an exim log
// line, as in "(set_id=user@domain)" or "(set_id=user)". The value ends at
// the first ')', newline, or space; without one of those, the rest of the
// line is returned. It returns "" when "set_id=" is absent.
func extractSetID(line string) string {
	_, rest, found := strings.Cut(line, "set_id=")
	if !found {
		return ""
	}
	if end := strings.IndexAny(rest, ")\n "); end >= 0 {
		return rest[:end]
	}
	return rest
}
// truncateDaemon shortens s to at most maxLen bytes and appends "..." when
// anything was cut; strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 character: if byte maxLen is
// a continuation byte, the cut moves back (by at most three bytes) to the
// start of that character, so the result stays valid UTF-8 for valid input.
// A negative maxLen is treated as 0 instead of panicking.
func truncateDaemon(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// s[cut] exists because cut <= maxLen < len(s). A byte of the form
	// 10xxxxxx is a UTF-8 continuation byte, meaning cut is mid-character.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "..."
}
// parseDKIMFailureDomain returns the domain named in a
// "DKIM: signing failed for {domain}" log line. The domain ends at the first
// ':', space, tab, or newline; without one of those, the rest of the line is
// returned. It returns "" when the line carries no such message.
func parseDKIMFailureDomain(line string) string {
	_, rest, found := strings.Cut(line, "DKIM: signing failed for ")
	if !found {
		return ""
	}
	if end := strings.IndexAny(rest, ": \t\n"); end >= 0 {
		return rest[:end]
	}
	return rest
}
// parseSPFDMARCRejection extracts the sender domain and the rejection reason
// from an exim " ** " permanent-failure line. The sender is taken from the
// first <...> group at or after the " ** " marker, and the reason is the text
// after the last " : " in the line, capped at 200 bytes.
//
// Both results are "" unless the line has the marker, a non-empty <sender>
// with a domain part, and a reason that isSPFDMARCRelated accepts.
func parseSPFDMARCRejection(line string) (senderDomain, reason string) {
	marker := strings.Index(line, " ** ")
	if marker < 0 {
		return "", ""
	}

	// Search AFTER the ** marker so earlier <> fields (e.g. H=<hostname>) are
	// not mistaken for the envelope sender.
	tail := line[marker:]
	open := strings.Index(tail, "<")
	closing := strings.Index(tail, ">")
	if open < 0 || closing < 0 || closing <= open+1 {
		return "", ""
	}
	sender := tail[open+1 : closing]

	at := strings.LastIndexByte(sender, '@')
	if at < 0 || at+1 >= len(sender) {
		return "", ""
	}
	domain := sender[at+1:]

	// The rejection reason follows the last " : ".
	sep := strings.LastIndex(line, " : ")
	if sep < 0 {
		return "", ""
	}
	reason = strings.TrimSpace(line[sep+3:])
	if !isSPFDMARCRelated(reason) {
		return "", ""
	}

	const maxReasonBytes = 200
	if len(reason) > maxReasonBytes {
		reason = reason[:maxReasonBytes]
	}
	return domain, reason
}
// isSPFDMARCRelated reports whether a rejection reason points at an SPF,
// DMARC, or DKIM problem. Matching is case-insensitive and succeeds when the
// reason mentions "spf", "dmarc", or "dkim", or carries one of the status
// codes 5.7.23, 5.7.25, 5.7.26. A generic 5.7.1 alone is NOT sufficient: it
// only counts together with an explicit authentication keyword.
func isSPFDMARCRelated(reason string) bool {
	if reason == "" {
		return false
	}
	lower := strings.ToLower(reason)
	mentionsAny := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(lower, needle) {
				return true
			}
		}
		return false
	}

	if mentionsAny("spf", "dmarc", "dkim", "5.7.23", "5.7.25", "5.7.26") {
		return true
	}
	return strings.Contains(lower, "5.7.1") &&
		mentionsAny("authentication", "ptr record", "sender policy", "alignment")
}
// isDedupExpired reports whether the RFC 3339 timestamp in stored lies more
// than window in the past. A value that does not parse counts as expired, so
// a corrupt or missing entry never suppresses an alert.
func isDedupExpired(stored string, window time.Duration) bool {
	recordedAt, err := time.Parse(time.RFC3339, stored)
	return err != nil || time.Since(recordedAt) > window
}
// --- Outbound email rate limiting ---
// rateWindow tracks send timestamps for a single authenticated user.
type rateWindow struct {
// mu is the lock that countInWindow and prune expect their caller to hold.
mu sync.Mutex
// times holds the recorded send timestamps in the order they were added;
// prune drops the entries that fell out of the window.
times []time.Time
alerted string // last threshold level alerted ("warn" or "crit") - prevents repeated alerts per window
}
// add appends a timestamp to the window. It does not lock rw.mu itself;
// presumably the caller must hold it, as the sibling methods require --
// TODO confirm against the call site.
func (rw *rateWindow) add(t time.Time) {
rw.times = append(rw.times, t)
}
// countInWindow returns how many recorded timestamps are strictly newer than
// now-window. It does not modify the window. Caller must hold rw.mu.
func (rw *rateWindow) countInWindow(now time.Time, window time.Duration) int {
	oldest := now.Add(-window)
	n := 0
	for i := range rw.times {
		if rw.times[i].After(oldest) {
			n++
		}
	}
	return n
}
// prune drops every timestamp that is not strictly newer than now-window,
// compacting the survivors in place so the backing array is reused. It only
// touches rw.times; the alerted field is left as it is. Caller must hold
// rw.mu.
func (rw *rateWindow) prune(now time.Time, window time.Duration) {
	cutoff := now.Add(-window)
	kept := 0
	for _, stamp := range rw.times {
		if stamp.After(cutoff) {
			rw.times[kept] = stamp
			kept++
		}
	}
	rw.times = rw.times[:kept]
}
// emailRateWindows tracks per-user send rate windows.
// Keys are the authenticated user strings and values are *rateWindow, per the
// type note on the declaration.
var emailRateWindows sync.Map // map[string]*rateWindow
// extractAuthUser returns the authenticated user of an exim arrival ("<=")
// line: the value following A=dovecot_login: or, failing that,
// A=dovecot_plain:, up to the next space, tab, or newline. It returns "" when
// the line is not an arrival line or carries neither authenticator tag.
func extractAuthUser(line string) string {
	if !strings.Contains(line, " <= ") {
		return ""
	}
	for _, tag := range [...]string{"A=dovecot_login:", "A=dovecot_plain:"} {
		_, rest, found := strings.Cut(line, tag)
		if !found {
			continue
		}
		if end := strings.IndexAny(rest, " \t\n"); end >= 0 {
			return rest[:end]
		}
		return rest
	}
	return ""
}
// isHighVolumeSender reports whether user matches any entry of allowlist,
// comparing case-insensitively.
func isHighVolumeSender(user string, allowlist []string) bool {
	for i := 0; i < len(allowlist); i++ {
		if !strings.EqualFold(user, allowlist[i]) {
			continue
		}
		return true
	}
	return false
}
// extractDomainFromEmail returns everything after the last '@' in email.
// It returns "" when there is no '@' or when nothing follows it.
func extractDomainFromEmail(email string) string {
	at := strings.LastIndexByte(email, '@')
	if at >= 0 && at < len(email)-1 {
		return email[at+1:]
	}
	return ""
}
// mailboxOnly passes s through unchanged when it contains an '@' and returns
// "" otherwise. Realtime emit sites may hold either "user@domain" or a bare
// domain; a bare domain belongs in the Domain field rather than Mailbox, so
// that the correlator does not collapse distinct mailboxes onto a domain key.
func mailboxOnly(s string) string {
	if strings.Contains(s, "@") {
		return s
	}
	return ""
}
// splitMailAccount sorts an authenticated mail account string into the three
// correlation fields:
//   - "" yields three empty strings;
//   - a value without '@' (cPanel-style bare local part) is returned as the
//     tenant, so the incident correlator groups by account rather than by
//     attacker SourceIP;
//   - a full "user@domain" is returned as the mailbox together with the
//     domain extracted from it.
func splitMailAccount(account string) (mailbox, domain, tenant string) {
	switch {
	case account == "":
		return "", "", ""
	case !strings.Contains(account, "@"):
		return "", "", account
	default:
		return account, extractDomainFromEmail(account), ""
	}
}
// hasRecentCompromisedFinding reports whether domain was recorded in
// emailRateSuppressed less than an hour ago; callers use this to suppress
// rate alerts for domains with a recent email_compromised_account or
// email_spam_outbreak finding. An entry found to be an hour old or more is
// removed as a side effect.
func hasRecentCompromisedFinding(domain string) bool {
	emailRateSuppressed.mu.Lock()
	defer emailRateSuppressed.mu.Unlock()

	recordedAt, found := emailRateSuppressed.domains[domain]
	if !found {
		return false
	}
	if time.Since(recordedAt) >= time.Hour {
		delete(emailRateSuppressed.domains, domain)
		return false
	}
	return true
}
// emailRateSuppressed tracks domains with recent compromised/spam findings.
// domains maps a domain to the time it was last recorded; mu guards the map.
var emailRateSuppressed = struct {
mu sync.Mutex
domains map[string]time.Time
}{domains: make(map[string]time.Time)}
// RecordCompromisedDomain stamps domain with the current time in
// emailRateSuppressed, replacing any earlier stamp.
// Called from parseEximLogLine when email_compromised_account or
// email_spam_outbreak fires.
func RecordCompromisedDomain(domain string) {
	suppressed := &emailRateSuppressed
	suppressed.mu.Lock()
	defer suppressed.mu.Unlock()
	suppressed.domains[domain] = time.Now()
}
// domainHasOutboundBlast reports whether the authenticated senders tracked
// under domain have, between them, sent at least RateWarnThreshold messages
// inside the configured rate window. This is the corroborating evidence for
// a real spam outbreak; a cPanel defer/fail governor trip on its own is not.
//
// It returns false for an empty domain, a nil cfg, or when either the warn
// threshold or the window is not positive, so an operator who has not tuned
// the rate window never auto-holds on a bare governor line.
func domainHasOutboundBlast(domain string, cfg *config.Config) bool {
	if cfg == nil || domain == "" {
		return false
	}
	limit := cfg.EmailProtection.RateWarnThreshold
	span := time.Duration(cfg.EmailProtection.RateWindowMin) * time.Minute
	if limit <= 0 || span <= 0 {
		return false
	}

	var (
		seen = 0
		asOf = time.Now()
	)
	emailRateWindows.Range(func(k, v any) bool {
		sender, isString := k.(string)
		if !isString {
			return true
		}
		// Only senders whose address domain matches (case-insensitively)
		// contribute to the total.
		if !strings.EqualFold(extractDomainFromEmail(sender), domain) {
			return true
		}
		window, isWindow := v.(*rateWindow)
		if !isWindow {
			return true
		}
		window.mu.Lock()
		seen += window.countInWindow(asOf, span)
		window.mu.Unlock()
		// Returning false ends the walk as soon as the threshold is met.
		return seen < limit
	})
	return seen >= limit
}
// checkEmailRate records one outbound message for user and returns at most
// one finding when the per-user count inside the rate window reaches the
// warning or critical threshold.
//
// Nothing is recorded and nil is returned when either threshold is not
// positive, when user is on the high-volume allowlist, or when the user's
// domain has a recent compromised/spam finding. Each level alerts once: the
// window remembers the last level alerted and forgets it again once the
// count falls below the warning threshold.
func checkEmailRate(user string, cfg *config.Config) []alert.Finding {
	warnAt := cfg.EmailProtection.RateWarnThreshold
	critAt := cfg.EmailProtection.RateCritThreshold
	// Guard: non-positive thresholds mean misconfigured or disabled.
	if warnAt <= 0 || critAt <= 0 {
		return nil
	}
	if isHighVolumeSender(user, cfg.EmailProtection.HighVolumeSenders) {
		return nil
	}

	// Fetch this user's window, creating it on first sight.
	stored, _ := emailRateWindows.LoadOrStore(user, &rateWindow{})
	rw := stored.(*rateWindow)
	now := time.Now()
	span := time.Duration(cfg.EmailProtection.RateWindowMin) * time.Minute

	rw.mu.Lock()
	defer rw.mu.Unlock()

	// The suppression check comes BEFORE the send is recorded so that
	// suppressed domains do not accumulate phantom rate.
	if suppressedDomain := extractDomainFromEmail(user); suppressedDomain != "" && hasRecentCompromisedFinding(suppressedDomain) {
		return nil
	}

	rw.add(now)
	count := rw.countInWindow(now, span)

	// Below the warning level the remembered alert state is cleared, which
	// allows re-alerting on the next burst after the window slides.
	if count < warnAt {
		rw.alerted = ""
	}

	var finding alert.Finding
	switch {
	case count >= critAt:
		if rw.alerted == "crit" {
			return nil
		}
		rw.alerted = "crit"
		finding = alert.Finding{
			Severity: alert.Critical,
			Check:    "email_rate_critical",
			Message:  fmt.Sprintf("Email rate CRITICAL: %s sent %d messages in %d minutes (threshold: %d)", user, count, cfg.EmailProtection.RateWindowMin, critAt),
			Details:  fmt.Sprintf("User: %s\nMessages in window: %d\nWindow: %d minutes\nThreshold: %d", user, count, cfg.EmailProtection.RateWindowMin, critAt),
		}
	case count >= warnAt:
		if rw.alerted == "warn" || rw.alerted == "crit" {
			return nil
		}
		rw.alerted = "warn"
		finding = alert.Finding{
			Severity: alert.High,
			Check:    "email_rate_warning",
			Message:  fmt.Sprintf("Email rate WARNING: %s sent %d messages in %d minutes (threshold: %d)", user, count, cfg.EmailProtection.RateWindowMin, warnAt),
			Details:  fmt.Sprintf("User: %s\nMessages in window: %d\nWindow: %d minutes\nThreshold: %d", user, count, cfg.EmailProtection.RateWindowMin, warnAt),
		}
	default:
		return nil
	}

	finding.Mailbox, finding.Domain, finding.TenantID = splitMailAccount(user)
	return []alert.Finding{finding}
}
// StartEmailRateEviction starts a background goroutine that prunes expired
// rate windows every 10 minutes. Same pattern as StartModSecEviction.
func StartEmailRateEviction(stopCh <-chan struct{}) {
obs.Go("email-rate-eviction", func() {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case now := <-ticker.C:
evictEmailRateWindows(now)
}
}
})
}
// evictEmailRateWindows prunes every per-user rate window against now,
// removes windows that end up empty from emailRateWindows, and then drops
// suppressed-domain entries that are more than an hour old.
func evictEmailRateWindows(now time.Time) {
	// A generous 60-minute retention avoids premature deletion; the actual
	// rate window is applied when the rate is evaluated.
	const retention = 60 * time.Minute

	emailRateWindows.Range(func(user, entry any) bool {
		window := entry.(*rateWindow)
		window.mu.Lock()
		window.prune(now, retention)
		if len(window.times) > 0 {
			window.mu.Unlock()
			return true
		}
		window.alerted = ""
		window.mu.Unlock()
		emailRateWindows.Delete(user)
		return true
	})

	// Also prune the suppressed domains map.
	suppressed := &emailRateSuppressed
	suppressed.mu.Lock()
	defer suppressed.mu.Unlock()
	for name, recordedAt := range suppressed.domains {
		if time.Since(recordedAt) <= time.Hour {
			continue
		}
		delete(suppressed.domains, name)
	}
}
package daemon
import "regexp"
// looksLikePHPWebshell reports whether PHP file content shows one of the
// canonical webshell shapes that can be detected in realtime:
//  1. Request-controlled input (a $_GET / $_POST / $_REQUEST / $_COOKIE /
//     $_FILES / $_SERVER superglobal, or php://input) flowing into a
//     code-execution primitive (eval / assert / system / passthru / exec /
//     shell_exec / proc_open / popen / create_function), optionally through
//     decoder layers (base64_decode / gzinflate / str_rot13).
//  2. eval/assert wrapping a decoder applied to a long base64 literal (the
//     obfuscated-payload primitive).
//
// Empty input is never a match. The content is first tried against each
// pattern in webshellContentRegexes, then handed to
// requestVariableFlowsToDangerousFunction for the assign-then-call shape.
//
// Legitimate code that merely calls a dangerous function outside an attack
// context is meant to return false (Pear Text_Diff/Engine/shell.php's
// shell_exec call to Unix `diff`, TinyMCE charmap.php's static glyph data
// array).
func looksLikePHPWebshell(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	// There is deliberately no byte cap here. The fanotify callers own the
	// read window and record when it is hit; a second cap at this level
	// would silently shorten a larger buffer a caller legitimately passed,
	// hiding the truncation from that caller while adding no protection
	// beyond the upstream reader limit.
	for i := range webshellContentRegexes {
		if webshellContentRegexes[i].Match(data) {
			return true
		}
	}
	return requestVariableFlowsToDangerousFunction(data)
}
// webshellContentRegexes are the realtime-detection-grade content patterns
// for looksLikePHPWebshell. Compiled once at package init. All three are
// case-insensitive ((?i)); a match on any one of them is sufficient.
var webshellContentRegexes = []*regexp.Regexp{
// Request superglobal directly piped into a code-execution primitive
// in the same expression (with optional decoder layers). The superglobal
// set here is GET, POST, REQUEST, COOKIE, FILES and SERVER.
regexp.MustCompile(`(?i)\b(?:eval|assert|system|passthru|exec|shell_exec|proc_open|popen|create_function)\s*\(\s*(?:gzinflate\s*\(\s*|str_rot13\s*\(\s*|base64_decode\s*\(\s*|@\s*)*\s*\$_(?:GET|POST|REQUEST|COOKIE|FILES|SERVER)\b`),
// php://input (read via file_get_contents) piped into a code-execution
// primitive in the same expression.
regexp.MustCompile(`(?i)\b(?:eval|assert|system|passthru|exec|shell_exec|proc_open|popen|create_function)\s*\(\s*(?:gzinflate\s*\(\s*|str_rot13\s*\(\s*|base64_decode\s*\(\s*|@\s*)*\s*file_get_contents\s*\(\s*['"]php://input`),
// eval/assert wrapping base64_decode (optionally inside gzinflate or
// str_rot13) of a quoted literal of at least 40 base64 characters
// (obfuscated-payload primitive).
regexp.MustCompile(`(?i)\b(?:eval|assert)\s*\(\s*(?:gzinflate\s*\(\s*|str_rot13\s*\(\s*)?base64_decode\s*\(\s*['"][A-Za-z0-9+/=]{40,}`),
}
// Patterns used by requestVariableFlowsToDangerousFunction. Both are
// case-insensitive and let '.' match newlines ((?is)).
var (
// requestAssignmentRegex matches "$name = <request input> ... ;" where the
// input is a request superglobal (optionally indexed) or
// file_get_contents('php://input'). Capture group 1 is the variable name.
requestAssignmentRegex = regexp.MustCompile(`(?is)\$([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(?:@?\s*)?(?:\$_(?:GET|POST|REQUEST|COOKIE|FILES|SERVER)\b(?:\s*\[[^\]]{0,200}\])?|file_get_contents\s*\(\s*['"]php://input['"]\s*\))[^;]{0,200};`)
// dangerousVariableCallRegex matches a code-execution primitive whose first
// argument is a plain variable. Capture group 1 is the variable name.
dangerousVariableCallRegex = regexp.MustCompile(`(?is)\b(?:eval|assert|system|passthru|exec|shell_exec|proc_open|popen|create_function)\s*\(\s*(?:@?\s*)?\$([A-Za-z_][A-Za-z0-9_]*)\b`)
)
// requestVariableFlowsToDangerousFunction detects the two-step webshell
// shape: a variable is assigned from request input (requestAssignmentRegex)
// and, within the 800 bytes that follow the assignment, the same variable is
// passed as the first argument of a code-execution primitive
// (dangerousVariableCallRegex). Variable names are compared exactly.
func requestVariableFlowsToDangerousFunction(data []byte) bool {
	const lookahead = 800
	for _, loc := range requestAssignmentRegex.FindAllSubmatchIndex(data, -1) {
		if len(loc) < 4 {
			continue
		}
		tainted := string(data[loc[2]:loc[3]])
		// Search only the bytes after the assignment, clamped to the
		// end of the buffer.
		from := loc[1]
		to := from + lookahead
		if to > len(data) {
			to = len(data)
		}
		calls := dangerousVariableCallRegex.FindAllSubmatch(data[from:to], -1)
		for _, groups := range calls {
			if len(groups) != 2 {
				continue
			}
			if string(groups[1]) == tainted {
				return true
			}
		}
	}
	return false
}
package daemon
import (
"context"
"fmt"
"os"
"sync"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/yara"
"github.com/pidginhost/csm/internal/yaraworker"
)
// yaraMetricsOnce guards registration of the yara-worker restart
// counter hook so a baseline re-run or a second daemon instance in the
// same test binary does not panic with "duplicate registration".
// It is used by initYaraBackend.
var yaraMetricsOnce sync.Once
// yaraWorkerOn decides whether YARA-X runs in a supervised child process.
// Signatures.YaraWorkerEnabled is a *bool tri-state:
//   - nil   -> system default, which is true (ROADMAP item 2 follow-up)
//   - *true -> explicit opt-in
//   - *false -> explicit opt-out
//
// A nil cfg yields false so that a pathological caller does not spin up a
// worker by accident; production never passes nil.
func yaraWorkerOn(cfg *config.Config) bool {
	if cfg == nil {
		return false
	}
	if enabled := cfg.Signatures.YaraWorkerEnabled; enabled != nil {
		return *enabled
	}
	return true
}
// initYaraBackend wires up either the out-of-process YARA-X supervisor
// (default, per ROADMAP item 2 follow-up) or the in-process scanner
// (when config.Signatures.YaraWorkerEnabled is explicitly *false).
// Both paths register themselves as the yara package's active backend
// so existing callers work unchanged.
//
// In worker mode, yara.Init is deliberately NOT called: the rule
// compile happens inside the child process, not the daemon. This is
// the point of the feature (ROADMAP item 2) — a cgo crash while
// compiling or scanning stays contained to the child. Matches carry
// string-valued rule metadata (see yara.Match.Meta / yaraipc.Match.Meta)
// so the emailav YARA-X adapter works identically under both backends
// — severity no longer needs the in-process *yara_x.Rules object.
//
// Returns nil on the in-process path regardless of whether yara.Init
// produced a scanner. On the worker path it returns a wrapped error when
// the supervisor cannot be created or started; d.yaraSup is only assigned
// after a successful start.
func (d *Daemon) initYaraBackend() error {
// In-process path: worker mode is off.
if !yaraWorkerOn(d.cfg) {
if yaraScanner := yara.Init(d.cfg.Signatures.RulesDir); yaraScanner != nil {
fmt.Fprintf(os.Stderr, "[%s] YARA-X scanner active: %d rule file(s)\n", ts(), yaraScanner.RuleCount())
}
return nil
}
// Worker path: build the supervisor with its timeouts and restart
// interval settings, and route its log lines to stderr with a timestamp
// prefix.
sup, err := yaraworker.NewSupervisor(yaraworker.SupervisorConfig{
BinaryPath: d.binaryPath,
SocketPath: yaraworker.DefaultSocketPath(),
RulesDir: d.cfg.Signatures.RulesDir,
StartTimeout: 10 * time.Second,
MinRestartInterval: time.Second,
MaxRestartInterval: 60 * time.Second,
StableDuration: 30 * time.Second,
ClientTimeout: 30 * time.Second,
OnRestart: d.onYaraWorkerRestart,
Logf: func(format string, args ...any) {
fmt.Fprintf(os.Stderr, "[%s] yara-worker: "+format+"\n", append([]any{ts()}, args...)...)
},
})
if err != nil {
return fmt.Errorf("creating yara-worker supervisor: %w", err)
}
if err := sup.Start(context.Background()); err != nil {
return fmt.Errorf("starting yara-worker: %w", err)
}
// Publish the supervisor only once it has started successfully.
d.yaraSup = sup
yara.SetActive(sup)
// Expose the supervisor's cumulative restart count to Prometheus.
// Registered once per process; subsequent calls re-point nothing
// (the closure captures `sup`, and a second daemon.Run in the
// same process would need to arrange for the metric to follow).
yaraMetricsOnce.Do(func() {
metrics.RegisterCounterFunc(
"csm_yara_worker_restarts_total",
"Number of times the YARA-X worker subprocess has been restarted by its supervisor.",
func() float64 { return float64(sup.RestartCount()) },
)
})
fmt.Fprintf(os.Stderr,
"[%s] YARA-X worker active: %d rule(s) compiled in child process (pid=%d)\n",
ts(), sup.RuleCount(), sup.ChildPID())
return nil
}
// stopYaraBackend runs during daemon shutdown. With no supervisor set
// (worker mode off) it does nothing. Otherwise it stops the supervisor,
// reports a stop error on stderr, and clears the yara package's active
// backend.
func (d *Daemon) stopYaraBackend() {
	sup := d.yaraSup
	if sup == nil {
		return
	}
	if stopErr := sup.Stop(); stopErr != nil {
		fmt.Fprintf(os.Stderr, "[%s] yara-worker stop: %v\n", ts(), stopErr)
	}
	yara.SetActive(nil)
}
// onYaraWorkerRestart is called once per unplanned worker exit. Every exit
// is logged to stderr. A Critical finding is emitted for the first exit and
// then at most once per minute, to avoid spamming alerts while a broken rule
// package is in place.
//
// d.yaraLastCrashAlert records when the last alert was emitted, not when the
// last crash happened. It is advanced only when this call actually alerts;
// refreshing it on suppressed crashes would keep pushing the quiet period
// forward, so a worker crashing steadily at sub-minute intervals would never
// alert again after the first time.
//
// exitCode and sig describe how the child ended; ranFor is how long it ran
// (rounded to the millisecond in log and alert text).
func (d *Daemon) onYaraWorkerRestart(exitCode int, sig syscall.Signal, ranFor time.Duration) {
	now := time.Now()

	d.yaraCrashMu.Lock()
	last := d.yaraLastCrashAlert
	suppressed := !last.IsZero() && now.Sub(last) < time.Minute
	if !suppressed {
		d.yaraLastCrashAlert = now
	}
	d.yaraCrashMu.Unlock()

	fmt.Fprintf(os.Stderr, "[%s] yara-worker exited code=%d signal=%v ran=%s\n",
		ts(), exitCode, sig, ranFor.Round(time.Millisecond))
	if suppressed {
		return
	}

	finding := alert.Finding{
		Severity: alert.Critical,
		Check:    "yara_worker_crashed",
		Message:  fmt.Sprintf("YARA-X worker crashed (exit=%d signal=%v after %s); supervisor restarted it, real-time scanning recovered.", exitCode, sig, ranFor.Round(time.Millisecond)),
	}
	// Non-blocking send: this callback must not stall on a full channel.
	select {
	case d.alertCh <- finding:
	default:
		// Channel saturated; the daemon's general drop-counter path
		// already tracks this.
	}
}
package emailav
import (
"encoding/binary"
"fmt"
"io"
"net"
"os"
"strings"
"time"
)
// ClamdScanner is a scan engine backed by a clamd daemon reached over a Unix
// socket; file content is submitted with the INSTREAM command.
type ClamdScanner struct {
	socketPath string
}

// NewClamdScanner returns a ClamdScanner that talks to clamd at socketPath.
// No connection is made until Available or Scan is called.
func NewClamdScanner(socketPath string) *ClamdScanner {
	scanner := new(ClamdScanner)
	scanner.socketPath = socketPath
	return scanner
}

// Name returns the fixed engine identifier "clamav".
func (s *ClamdScanner) Name() string {
	const engineName = "clamav"
	return engineName
}

// Available reports whether clamd can be reached: it dials the socket with a
// two-second timeout and closes the connection again straight away.
func (s *ClamdScanner) Available() bool {
	const probeTimeout = 2 * time.Second
	probe, dialErr := net.DialTimeout("unix", s.socketPath, probeTimeout)
	if dialErr == nil {
		_ = probe.Close()
		return true
	}
	return false
}
// Scan sends the file at path to clamd via INSTREAM and returns the verdict.
//
// The exchange is: the newline-delimited "nINSTREAM" command, then the file
// content as a series of chunks each prefixed with its length as a 4-byte
// big-endian integer, then a zero-length chunk as terminator, then a single
// read of clamd's reply, which is handed to parseClamdResponse.
//
// Any failure to open the file, connect, set the deadline, write, or read
// is returned as a wrapped error together with a zero Verdict.
func (s *ClamdScanner) Scan(path string) (Verdict, error) {
// #nosec G304 -- path is mail queue file path from mail scanner walk.
f, err := os.Open(path)
if err != nil {
return Verdict{}, fmt.Errorf("opening file: %w", err)
}
defer f.Close()
conn, err := net.DialTimeout("unix", s.socketPath, 5*time.Second)
if err != nil {
return Verdict{}, fmt.Errorf("connecting to clamd: %w", err)
}
defer func() { _ = conn.Close() }()
// One 30-second deadline covers the whole exchange: command, streaming
// and the reply.
err = conn.SetDeadline(time.Now().Add(30 * time.Second))
if err != nil {
return Verdict{}, fmt.Errorf("setting deadline: %w", err)
}
// Send INSTREAM command
_, err = conn.Write([]byte("nINSTREAM\n"))
if err != nil {
return Verdict{}, fmt.Errorf("sending INSTREAM: %w", err)
}
// Stream file content in chunks
buf := make([]byte, 8192)
lenBuf := make([]byte, 4)
for {
n, readErr := f.Read(buf)
// Bytes returned alongside an error are still sent before the error is
// examined.
if n > 0 {
// #nosec G115 -- n is bounded by len(buf)=8192; fits in uint32.
binary.BigEndian.PutUint32(lenBuf, uint32(n))
_, err = conn.Write(lenBuf)
if err != nil {
return Verdict{}, fmt.Errorf("sending chunk length: %w", err)
}
_, err = conn.Write(buf[:n])
if err != nil {
return Verdict{}, fmt.Errorf("sending chunk data: %w", err)
}
}
if readErr == io.EOF {
break
}
if readErr != nil {
return Verdict{}, fmt.Errorf("reading file: %w", readErr)
}
}
// Send terminator (4 zero bytes)
binary.BigEndian.PutUint32(lenBuf, 0)
_, err = conn.Write(lenBuf)
if err != nil {
return Verdict{}, fmt.Errorf("sending terminator: %w", err)
}
// Read response. A single Read of up to 4096 bytes is performed; io.EOF
// is tolerated and whatever was read (possibly nothing) is parsed.
resp := make([]byte, 4096)
n, err := conn.Read(resp)
if err != nil && err != io.EOF {
return Verdict{}, fmt.Errorf("reading response: %w", err)
}
return parseClamdResponse(string(resp[:n]))
}
// parseClamdResponse interprets one clamd INSTREAM reply after trimming
// surrounding whitespace:
//
//	"stream: OK\n"                          -> clean verdict
//	"stream: Win.Trojan.Agent-123 FOUND\n"  -> infected, signature extracted,
//	                                           severity "critical"
//
// Anything that ends in neither "OK" nor "FOUND" is returned as an error
// quoting the trimmed reply.
func parseClamdResponse(resp string) (Verdict, error) {
	line := strings.TrimSpace(resp)
	switch {
	case strings.HasSuffix(line, "OK"):
		return Verdict{Infected: false}, nil
	case strings.HasSuffix(line, "FOUND"):
		// Shape is "stream: <sig> FOUND"; strip both ends to get <sig>.
		name := strings.TrimSuffix(strings.TrimPrefix(line, "stream: "), " FOUND")
		return Verdict{
			Infected:  true,
			Signature: name,
			Severity:  "critical",
		}, nil
	}
	return Verdict{}, fmt.Errorf("unexpected clamd response: %q", line)
}
package emailav
import (
"context"
"fmt"
"os"
"sync"
"time"
emime "github.com/pidginhost/csm/internal/mime"
"github.com/pidginhost/csm/internal/obs"
)
// Orchestrator runs multiple scanners in parallel against extracted email parts.
type Orchestrator struct {
scanners []Scanner // engines probed for availability on every ScanParts call
scanTimeout time.Duration // time budget shared by all engines scanning one part
}
// NewOrchestrator builds an Orchestrator over scanners. scanTimeout is the
// time budget applied to each scanned part.
func NewOrchestrator(scanners []Scanner, scanTimeout time.Duration) *Orchestrator {
	o := new(Orchestrator)
	o.scanners = scanners
	o.scanTimeout = scanTimeout
	return o
}
// ScanParts runs every available engine over every extracted part and
// collects the outcome in a ScanResult.
//
// The policy is fail-open: engines that are unavailable, time out or return
// an error are recorded in the result but never make the message count as
// infected. When no engine at all is available the result is returned with
// AllEnginesDown set and no part is scanned.
func (o *Orchestrator) ScanParts(messageID string, parts []emime.ExtractedPart, partial bool) *ScanResult {
	out := &ScanResult{
		MessageID:         messageID,
		ScannedAt:         time.Now(),
		PartialExtraction: partial,
	}

	// Probe each engine once per call and split them into usable/failed.
	var live []Scanner
	for _, engine := range o.scanners {
		if !engine.Available() {
			out.FailedEngines = append(out.FailedEngines, engine.Name())
			fmt.Fprintf(os.Stderr, "[emailav] engine %s unavailable\n", engine.Name())
			continue
		}
		live = append(live, engine)
		out.EnginesUsed = append(out.EnginesUsed, engine.Name())
	}
	if len(live) == 0 {
		// fail-open: nothing can scan, so report that and stop here.
		out.AllEnginesDown = true
		return out
	}

	for i := range parts {
		hits, slow, failed := o.scanPart(parts[i], live)
		out.Findings = append(out.Findings, hits...)
		out.TimedOutEngines = append(out.TimedOutEngines, slow...)
		out.ErroredEngines = append(out.ErroredEngines, failed...)
	}
	out.Infected = len(out.Findings) > 0
	return out
}
// scanPart runs all given engines concurrently against one part, under a
// single deadline of o.scanTimeout shared by every engine. It returns the
// findings plus the names of engines that timed out and of engines that
// returned an error. Failed or timed-out engines contribute no finding
// (fail-open).
func (o *Orchestrator) scanPart(part emime.ExtractedPart, scanners []Scanner) ([]Finding, []string, []string) {
	type outcome struct {
		engine   string
		verdict  Verdict
		err      error
		timedOut bool
	}

	ctx, cancel := context.WithTimeout(context.Background(), o.scanTimeout)
	defer cancel()

	// Buffered to one slot per engine so no reporter ever blocks on send.
	outcomes := make(chan outcome, len(scanners))
	var pending sync.WaitGroup
	for i := range scanners {
		engine := scanners[i]
		pending.Add(1)
		obs.SafeGo("emailav-scan", func() {
			defer pending.Done()
			// The engine call runs in its own goroutine so this one
			// can give up at the deadline; the one-slot buffer lets a
			// late engine finish its send without a receiver.
			finished := make(chan outcome, 1)
			obs.SafeGo("emailav-engine", func() {
				verdict, scanErr := engine.Scan(part.TempPath)
				finished <- outcome{engine: engine.Name(), verdict: verdict, err: scanErr}
			})
			select {
			case <-ctx.Done():
				outcomes <- outcome{engine: engine.Name(), err: fmt.Errorf("scan timeout"), timedOut: true}
			case got := <-finished:
				outcomes <- got
			}
		})
	}

	// Close the outcome channel once every engine has reported.
	obs.SafeGo("emailav-drain", func() {
		pending.Wait()
		close(outcomes)
	})

	var (
		findings []Finding
		timedOut []string
		errored  []string
	)
	for res := range outcomes {
		if res.err == nil {
			if !res.verdict.Infected {
				continue
			}
			// Parts that came out of an archive are reported as
			// "<archive>/<member>".
			name := part.Filename
			if part.Nested {
				name = part.ArchiveName + "/" + part.Filename
			}
			findings = append(findings, Finding{
				Filename:  name,
				Engine:    res.engine,
				Signature: res.verdict.Signature,
				Severity:  res.verdict.Severity,
			})
			continue
		}
		// fail-open: log, record which engine failed, and carry on.
		fmt.Fprintf(os.Stderr, "[emailav] %s scan error on %s: %v\n", res.engine, part.Filename, res.err)
		if res.timedOut {
			timedOut = append(timedOut, res.engine)
		} else {
			errored = append(errored, res.engine)
		}
	}
	return findings, timedOut, errored
}
package emailav
import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"syscall"
"time"
)
// movedFile records one completed move so it can be undone later:
// rollbackMovedFiles moves dst back to src.
type movedFile struct {
src string // path the file was moved from
dst string // path the file was moved to
}
// vars so tests can force EXDEV and source swaps without depending on
// the host filesystem layout.
var (
// moveFileRename is the rename primitive moveFile tries first.
moveFileRename = os.Rename
// moveFileAfterCrossDeviceCopy is invoked by moveFile with (src, dst) after
// the cross-device copy has been written and closed, before the source
// identity check. The default does nothing.
moveFileAfterCrossDeviceCopy = func(string, string) error { return nil }
)
// QuarantineEnvelope holds the email envelope info for quarantine metadata.
// QuarantineMessage copies each field into the QuarantineMetadata it writes.
type QuarantineEnvelope struct {
From string
To []string
Subject string
Direction string
}
// QuarantineMetadata is the JSON sidecar written with quarantined messages.
// It is stored as metadata.json inside the message's quarantine directory.
type QuarantineMetadata struct {
MessageID string `json:"message_id"`
Direction string `json:"direction"`
From string `json:"from"`
To []string `json:"to"`
Subject string `json:"subject"`
QuarantinedAt time.Time `json:"quarantined_at"` // compared against the cutoff by CleanExpired
OriginalSpoolDir string `json:"original_spool_dir"` // validated by ReleaseMessage before files are moved back
Findings []Finding `json:"findings"`
PartialScan bool `json:"partial_scan"` // copied from ScanResult.PartialExtraction
EnginesUsed []string `json:"engines_used"`
}
// Quarantine manages the per-message email quarantine directory.
// Each quarantined message lives in its own subdirectory of baseDir, named
// after the message ID.
type Quarantine struct {
baseDir string // e.g. /opt/csm/quarantine/email
allowedSpoolDirs []string // spool directories ReleaseMessage may move files back into
}
// NewQuarantine returns a quarantine manager rooted at baseDir. Releases are
// restricted to the two default exim input spool directories.
func NewQuarantine(baseDir string) *Quarantine {
	q := &Quarantine{baseDir: baseDir}
	q.allowedSpoolDirs = []string{"/var/spool/exim/input", "/var/spool/exim4/input"}
	return q
}
// QuarantineMessage moves spool files into a per-message quarantine directory
// and writes metadata.json.
//
// msgID must be a single path segment (see cleanQuarantineMessageID). The
// "<msgID>-H" and "<msgID>-D" files are moved from spoolDir, in that order;
// a missing file is skipped, but if neither exists an error is returned.
// On any other move failure, or if the metadata cannot be written, files
// already moved are rolled back to the spool and the quarantine directory
// is removed. result is dereferenced without a nil check.
func (q *Quarantine) QuarantineMessage(msgID, spoolDir string, result *ScanResult, env QuarantineEnvelope) error {
msgID, err := cleanQuarantineMessageID(msgID)
if err != nil {
return err
}
msgDir := filepath.Join(q.baseDir, msgID)
meta := QuarantineMetadata{
MessageID: msgID,
Direction: env.Direction,
From: env.From,
To: env.To,
Subject: env.Subject,
QuarantinedAt: time.Now(),
OriginalSpoolDir: spoolDir,
Findings: result.Findings,
PartialScan: result.PartialExtraction,
EnginesUsed: result.EnginesUsed,
}
// Marshal before touching the filesystem so an encoding failure leaves
// the spool untouched.
metaData, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return fmt.Errorf("marshaling metadata: %w", err)
}
if err := os.MkdirAll(msgDir, 0700); err != nil {
return fmt.Errorf("creating quarantine dir: %w", err)
}
var moved []movedFile
for _, suffix := range []string{"-H", "-D"} {
src := filepath.Join(spoolDir, msgID+suffix)
dst := filepath.Join(msgDir, msgID+suffix)
if err := moveFile(src, dst); err != nil {
// A missing spool file is tolerated; anything else aborts and
// undoes the moves made so far.
if !os.IsNotExist(err) {
rollbackErr := rollbackMovedFiles(moved)
_ = os.RemoveAll(msgDir)
if rollbackErr != nil {
return fmt.Errorf("moving spool file %s: %w (rollback failed: %v)", suffix, err, rollbackErr)
}
return fmt.Errorf("moving spool file %s: %w", suffix, err)
}
continue
}
moved = append(moved, movedFile{src: src, dst: dst})
}
if len(moved) == 0 {
// Nothing was quarantined; drop the directory that was just created.
os.Remove(msgDir)
return fmt.Errorf("no spool files found for %s", msgID)
}
metaPath := filepath.Join(msgDir, "metadata.json")
if err := os.WriteFile(metaPath, metaData, 0600); err != nil {
rollbackErr := rollbackMovedFiles(moved)
_ = os.RemoveAll(msgDir)
if rollbackErr != nil {
return fmt.Errorf("writing metadata: %w (rollback failed: %v)", err, rollbackErr)
}
return fmt.Errorf("writing metadata: %w", err)
}
return nil
}
// ListMessages returns the metadata of every quarantined message found under
// the base directory. Non-directory entries and directories whose metadata
// cannot be read are skipped. A base directory that does not exist yields
// (nil, nil).
func (q *Quarantine) ListMessages() ([]QuarantineMetadata, error) {
	dirEntries, readErr := os.ReadDir(q.baseDir)
	switch {
	case readErr == nil:
		// fall through to the listing below
	case os.IsNotExist(readErr):
		return nil, nil
	default:
		return nil, fmt.Errorf("reading quarantine dir: %w", readErr)
	}

	var out []QuarantineMetadata
	for _, de := range dirEntries {
		if !de.IsDir() {
			continue
		}
		if meta, metaErr := q.readMetadata(de.Name()); metaErr == nil {
			out = append(out, *meta)
		}
	}
	return out, nil
}
// GetMessage loads the metadata of one quarantined message. msgID is
// validated first; an invalid ID or unreadable metadata is returned as an
// error.
func (q *Quarantine) GetMessage(msgID string) (*QuarantineMetadata, error) {
	id, idErr := cleanQuarantineMessageID(msgID)
	if idErr != nil {
		return nil, idErr
	}
	return q.readMetadata(id)
}
// ReleaseMessage moves spool files back to the original spool directory
// and removes the quarantine directory.
//
// The target directory comes from the stored metadata and must pass
// validateReleaseSpoolDir. A spool file missing from the quarantine
// directory is skipped (partial quarantine). If moving a file fails for any
// other reason, files already returned to the spool by this call are moved
// back into quarantine, so a failed release does not leave half a message
// in the live spool; the returned error names the failed suffix and notes
// any rollback failure.
func (q *Quarantine) ReleaseMessage(msgID string) error {
	msgID, err := cleanQuarantineMessageID(msgID)
	if err != nil {
		return err
	}
	meta, err := q.readMetadata(msgID)
	if err != nil {
		return fmt.Errorf("reading metadata: %w", err)
	}
	spoolDir, err := q.validateReleaseSpoolDir(meta.OriginalSpoolDir)
	if err != nil {
		return err
	}
	msgDir := filepath.Join(q.baseDir, msgID)

	var moved []movedFile
	for _, suffix := range []string{"-H", "-D"} {
		src := filepath.Join(msgDir, msgID+suffix)
		dst := filepath.Join(spoolDir, msgID+suffix)
		if err := moveFile(src, dst); err != nil {
			// If source doesn't exist, skip (partial quarantine)
			if os.IsNotExist(err) {
				continue
			}
			// Undo the moves made so far; the quarantine directory is
			// kept so the message can be released again later.
			if rollbackErr := rollbackMovedFiles(moved); rollbackErr != nil {
				return fmt.Errorf("moving %s back to spool: %w (rollback failed: %v)", suffix, err, rollbackErr)
			}
			return fmt.Errorf("moving %s back to spool: %w", suffix, err)
		}
		moved = append(moved, movedFile{src: src, dst: dst})
	}
	return os.RemoveAll(msgDir)
}
// DeleteMessage permanently removes the quarantine directory of msgID. It
// differs from ReleaseMessage in not needing metadata: clearing out an
// orphaned or partially-written entry is a legitimate cleanup operation.
func (q *Quarantine) DeleteMessage(msgID string) error {
	id, idErr := cleanQuarantineMessageID(msgID)
	if idErr != nil {
		return idErr
	}
	return os.RemoveAll(filepath.Join(q.baseDir, id))
}
// CleanExpired removes quarantine directories whose metadata records a
// QuarantinedAt older than maxAge and returns how many were actually
// removed.
//
// A missing base directory yields (0, nil). Non-directory entries and
// directories without readable metadata are left alone. A directory that
// cannot be removed is reported on stderr and is not counted, so the return
// value reflects real deletions only; the sweep then continues with the
// remaining entries.
func (q *Quarantine) CleanExpired(maxAge time.Duration) (int, error) {
	entries, err := os.ReadDir(q.baseDir)
	if err != nil {
		if os.IsNotExist(err) {
			return 0, nil
		}
		return 0, err
	}
	cleaned := 0
	cutoff := time.Now().Add(-maxAge)
	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		meta, err := q.readMetadata(entry.Name())
		if err != nil {
			continue
		}
		if !meta.QuarantinedAt.Before(cutoff) {
			continue
		}
		if err := os.RemoveAll(filepath.Join(q.baseDir, entry.Name())); err != nil {
			fmt.Fprintf(os.Stderr, "[emailav] removing expired quarantine entry %s: %v\n", entry.Name(), err)
			continue
		}
		cleaned++
	}
	return cleaned, nil
}
// readMetadata loads and decodes <baseDir>/<msgID>/metadata.json. msgID is
// validated as a single path segment before it is used in the path. Read
// and decode errors are returned unwrapped.
func (q *Quarantine) readMetadata(msgID string) (*QuarantineMetadata, error) {
	id, err := cleanQuarantineMessageID(msgID)
	if err != nil {
		return nil, err
	}
	// #nosec G304 -- msgID is restricted to a single path segment under baseDir.
	raw, err := os.ReadFile(filepath.Join(q.baseDir, id, "metadata.json"))
	if err != nil {
		return nil, err
	}
	meta := new(QuarantineMetadata)
	if err := json.Unmarshal(raw, meta); err != nil {
		return nil, err
	}
	return meta, nil
}
// cleanQuarantineMessageID accepts msgID only when it is usable as a single
// path segment and returns it unchanged in that case. It rejects the empty
// string ("message id is required") and, as "invalid message id", the values
// "." and "..", anything containing '/' or '\', and anything that
// filepath.Base would alter.
func cleanQuarantineMessageID(msgID string) (string, error) {
	switch {
	case msgID == "":
		return "", fmt.Errorf("message id is required")
	case msgID == ".", msgID == "..":
		return "", fmt.Errorf("invalid message id")
	case strings.ContainsAny(msgID, `/\`), filepath.Base(msgID) != msgID:
		return "", fmt.Errorf("invalid message id")
	}
	return msgID, nil
}
// moveFile renames src to dst, falling back to a fd-bound copy only when
// the rename fails with EXDEV (cross-device); every other rename error is
// returned unchanged. Callers
// construct dst by filepath.Join under the quarantine base dir
// (config-owned) plus a validated single-segment identifier.
//
// The cross-device path opens src with O_NOFOLLOW, copies content
// from the open fd into a freshly-created dst, and unlinks src by the
// path only after verifying the path still names the same inode. An
// attacker who swaps src for a symlink or replaces the file
// mid-copy is rejected.
//
// On the cross-device path dst must not already exist (it is created with
// O_EXCL), and a dst created by this call is removed again if any later
// step fails.
func moveFile(src, dst string) error {
if err := moveFileRename(src, dst); err == nil {
return nil
} else if !errors.Is(err, syscall.EXDEV) {
// Non-EXDEV rename errors are not a cross-device condition;
// surface them directly so the caller does not silently
// fall back to copy semantics for permission or path errors.
return err
}
// #nosec G304 -- src is mail queue path from scanner walk;
// O_NOFOLLOW plus the identity check below catch symlink swaps.
fd, err := os.OpenFile(src, os.O_RDONLY|syscall.O_NOFOLLOW, 0)
if err != nil {
return fmt.Errorf("opening cross-device source: %w", err)
}
defer fd.Close()
// Stat through the open descriptor: this is the identity the path is
// compared against after the copy.
srcInfo, err := fd.Stat()
if err != nil {
return fmt.Errorf("stat cross-device source: %w", err)
}
if !srcInfo.Mode().IsRegular() {
return fmt.Errorf("refusing cross-device copy of non-regular %s", src)
}
// #nosec G304 G306 -- dst is constructed under the quarantine baseDir;
// 0600 keeps the copy private.
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
if err != nil {
return fmt.Errorf("creating cross-device destination: %w", err)
}
// From here on every failure path removes the destination; the flag is
// cleared only after the source has been unlinked.
removeDst := true
defer func() {
if removeDst {
_ = os.Remove(dst)
}
}()
if _, copyErr := io.Copy(dstFile, fd); copyErr != nil {
_ = dstFile.Close()
return fmt.Errorf("cross-device copy body: %w", copyErr)
}
if closeErr := dstFile.Close(); closeErr != nil {
return fmt.Errorf("cross-device close: %w", closeErr)
}
// Test seam; the default hook is a no-op.
if hookErr := moveFileAfterCrossDeviceCopy(src, dst); hookErr != nil {
return fmt.Errorf("cross-device post-copy check: %w", hookErr)
}
// Lstat (not Stat) so a symlink planted at src is compared as itself
// rather than followed.
pathInfo, err := os.Lstat(src)
if err != nil {
return fmt.Errorf("stat cross-device source path: %w", err)
}
if !sameUnixInode(srcInfo, pathInfo) {
return fmt.Errorf("cross-device source %s changed during copy", src)
}
// A source that has already vanished is accepted as removed.
if err := os.Remove(src); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("removing cross-device source: %w", err)
}
removeDst = false
return nil
}
// sameUnixInode reports whether a and b describe the same file, which lets
// the cross-device move confirm that the source path still names the inode
// it opened. os.SameFile is consulted first; failing that, the device and
// inode numbers of the underlying syscall.Stat_t values are compared. The
// result is false when either argument is nil or when either Sys() value is
// not a *syscall.Stat_t.
func sameUnixInode(a, b os.FileInfo) bool {
	switch {
	case a == nil || b == nil:
		return false
	case os.SameFile(a, b):
		return true
	}
	statA, okA := a.Sys().(*syscall.Stat_t)
	if !okA {
		return false
	}
	statB, okB := b.Sys().(*syscall.Stat_t)
	if !okB {
		return false
	}
	return statA.Dev == statB.Dev && statA.Ino == statB.Ino
}
// rollbackMovedFiles undoes a batch of moves by moving each destination back
// to its source, walking the slice from the most recent entry to the oldest.
// It stops at the first moveFile error and returns it; entries earlier in
// the slice are then left in place. Returns nil when every entry was
// reverted or when moved is empty.
func rollbackMovedFiles(moved []movedFile) error {
	for remaining := len(moved); remaining > 0; remaining-- {
		entry := moved[remaining-1]
		err := moveFile(entry.dst, entry.src)
		if err != nil {
			return err
		}
	}
	return nil
}
// validateReleaseSpoolDir checks that spoolDir is an absolute path whose
// symlink-resolved form equals the symlink-resolved form of one of
// q.allowedSpoolDirs, and returns that resolved path.
//
// The check fails closed: an input that cannot be resolved is rejected
// rather than compared as an unresolved literal, because a path whose
// identity cannot be confirmed must not be released into. Allowed entries
// that do not resolve (for example an exim4 default on an exim-only host)
// are skipped, so a missing allowed entry can never match the input.
func (q *Quarantine) validateReleaseSpoolDir(spoolDir string) (string, error) {
	cleanDir := filepath.Clean(spoolDir)
	if absolute := filepath.IsAbs(cleanDir); cleanDir == "" || !absolute {
		return "", fmt.Errorf("invalid original spool directory")
	}

	target, resolveErr := filepath.EvalSymlinks(cleanDir)
	if resolveErr != nil {
		return "", fmt.Errorf("resolving original spool directory %q: %w", cleanDir, resolveErr)
	}

	for _, allowed := range q.allowedSpoolDirs {
		// Non-resolving allowed entries are ignored, never compared
		// as literals: that would silently widen the trust boundary.
		trusted, err := filepath.EvalSymlinks(filepath.Clean(allowed))
		if err == nil && trusted == target {
			return target, nil
		}
	}
	return "", fmt.Errorf("original spool directory is not trusted: %s", cleanDir)
}
//go:build !yara
package emailav
import "github.com/pidginhost/csm/internal/yara"
// YaraXScanner is a no-op stub when YARA-X is not compiled in (this file is
// built only without the "yara" build tag). It carries no state. The
// constructor still accepts a yara.Backend so callers can pass
// yara.Active() uniformly across build tags.
type YaraXScanner struct{}
// NewYaraXScanner returns a stub scanner that is never available. The
// backend argument is accepted for signature parity and ignored.
func NewYaraXScanner(_ yara.Backend) *YaraXScanner {
	return new(YaraXScanner)
}
// Name returns the scanner's fixed identifier, "yara-x".
func (s *YaraXScanner) Name() string {
	const scannerName = "yara-x"
	return scannerName
}
// Available always reports false for the stub build.
func (s *YaraXScanner) Available() bool {
	return false
}
// Scan ignores its argument and returns the zero Verdict with a nil error.
func (s *YaraXScanner) Scan(_ string) (Verdict, error) {
	var none Verdict
	return none, nil
}
package emailspool
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"golang.org/x/net/idna"
)
// ExtractDomain returns the lowercased domain part of an address given
// either bare ("user@example.com") or in display-name form
// ("Name <user@example.com>"). It returns "" when the input is blank or
// contains no unquoted '@'.
//
// When the input has a '<' followed later by a '>', only the text between
// the last '<' and the last '>' is considered. The domain is whatever
// follows the last '@' that is not inside double quotes, so quoted local
// parts such as "a@b"@example.com are handled. The domain is converted with
// idna.ToASCII; if that conversion fails the lowercased domain is returned
// unconverted.
func ExtractDomain(s string) string {
	addr := strings.TrimSpace(s)
	if addr == "" {
		return ""
	}

	open := strings.LastIndex(addr, "<")
	if open >= 0 {
		if end := strings.LastIndex(addr, ">"); end > open {
			addr = addr[open+1 : end]
		}
	}

	at := lastUnquotedAt(addr)
	if at < 0 {
		return ""
	}

	host := strings.ToLower(strings.TrimSpace(addr[at+1:]))
	ascii, err := idna.ToASCII(host)
	if err != nil {
		return host
	}
	return ascii
}
// lastUnquotedAt scans s byte by byte and returns the index of the rightmost
// '@' that lies outside a double-quoted segment, or -1 when there is none.
// A backslash consumes the byte after it (quoted-pair), so an escaped quote
// does not toggle quoting and an escaped '@' is never reported.
func lastUnquotedAt(s string) int {
	pos := -1
	quoted := false
	for idx := 0; idx < len(s); idx++ {
		ch := s[idx]
		if ch == '\\' {
			idx++ // step over the escaped byte
			continue
		}
		if ch == '"' {
			quoted = !quoted
			continue
		}
		if ch == '@' && !quoted {
			pos = idx
		}
	}
	return pos
}
// IsSubdomainOrEqual reports whether candidate is base or a subdomain of base.
// Both inputs are compared case-insensitively and one trailing dot (the DNS
// root label) is ignored on each. Inputs that are empty, or that become
// empty once the trailing dot is removed (for example "."), return false:
// without that check a base of "." would normalise to "" and match any
// candidate ending in "..".
func IsSubdomainOrEqual(candidate, base string) bool {
	c := strings.ToLower(strings.TrimSuffix(candidate, "."))
	b := strings.ToLower(strings.TrimSuffix(base, "."))
	// Test emptiness after normalisation so "." is treated like "".
	if c == "" || b == "" {
		return false
	}
	if c == b {
		return true
	}
	// Require a label boundary so "notexample.com" does not match
	// "example.com".
	return strings.HasSuffix(c, "."+b)
}
// MaxSpoolHeaderBytes is the largest single line (bufio.Scanner token) that
// ParseHeadersReader accepts from an Exim -H file; it is passed as the
// scanner's maximum buffer size. It caps per-line memory, not the total
// number of bytes read from the file.
// 32 KiB covers rich messages with DKIM, ARC, and folded MIME headers
// without unbounded memory.
const MaxSpoolHeaderBytes = 32 * 1024
// Headers is the parsed envelope + interesting RFC 5322 fields from a
// cPanel-Exim spool -H file. EnvelopeUser comes from the file's line 2
// (the local UID under which Exim accepted the message); the RFC 5322
// fields come from the message's mail headers section. Empty string means
// "header absent" -- callers should not synthesise defaults from absence.
type Headers struct {
// EnvelopeUser is the first whitespace-separated field of line 2.
EnvelopeUser string
// EnvelopeUID is the second field of line 2 parsed as an integer. It
// stays 0 when that field is not numeric, so 0 is ambiguous between
// "uid 0" and "unparseable".
EnvelopeUID int
// From is the value of the From header.
From string
// ReplyTo is the value of the Reply-To header.
ReplyTo string
// Subject is the value of the Subject header.
Subject string
// XPHPScript is the value of the X-PHP-Script header.
XPHPScript string
// XMailer is the value of the X-Mailer header.
XMailer string
// UserAgent is the value of the User-Agent header.
UserAgent string
// MessageID is the value of the Message-ID header.
MessageID string
}
// ParseHeaders opens the Exim -H file at path and parses it with
// ParseHeadersReader. An open failure is returned as-is; a parse failure is
// returned wrapped as "parse <path>: ...". In both cases the Headers result
// is the zero value. Headers that are simply absent from a well-formed file
// leave the corresponding field empty and are not an error.
func ParseHeaders(path string) (Headers, error) {
	// #nosec G304 -- path is supplied by the daemon's Exim spool
	// watcher and resolved from /var/spool/exim/input/, an
	// operator-trusted directory enumerated by the spool walker.
	spool, openErr := os.Open(path)
	if openErr != nil {
		return Headers{}, openErr
	}
	defer spool.Close()

	parsed, parseErr := ParseHeadersReader(spool)
	if parseErr == nil {
		return parsed, nil
	}
	return Headers{}, fmt.Errorf("parse %s: %w", path, parseErr)
}
// ParseHeadersReader is the io.Reader form of ParseHeaders for callers that
// already have the spool bytes in memory or behind a custom seam (e.g. the
// checks package's osFS abstraction). It applies the same Exim -H parsing
// rules as ParseHeaders -- envelope preamble, blank-line separator,
// count/flag prefixed RFC 5322 headers -- and is bounded by
// MaxSpoolHeaderBytes per line; an oversize line returns an error wrapping
// bufio.ErrTooLong.
//
// Header names are matched case-insensitively (RFC 5322 field names are not
// case-sensitive, and real spools contain spellings such as "Message-Id"),
// and when a header appears more than once the last occurrence wins. Folded
// continuation lines are not appended to the preceding value.
//
// Errors: "empty spool file" when there is no first line, "missing envelope
// user line" / "malformed envelope user line" for line 2, "missing header
// section separator" when no blank line is found, and "scan spool: ..." for
// any reader or scanner failure. On error the returned Headers is the zero
// value.
func ParseHeadersReader(r io.Reader) (Headers, error) {
	var h Headers
	// Per-line memory is bounded by the scanner's max buffer
	// (MaxSpoolHeaderBytes). We deliberately do NOT wrap r in an
	// io.LimitReader: when LimitReader returns EOF mid-token, bufio.Scanner
	// emits the partial token without error and oversize spool files would
	// be silently truncated, hiding the failure from operators.
	sc := bufio.NewScanner(r)
	sc.Buffer(make([]byte, 0, 8192), MaxSpoolHeaderBytes)

	// Line 1: msgID-H (we ignore the value; presence is enough).
	if !sc.Scan() {
		// Distinguish a read/oversize failure from a genuinely empty file.
		if err := sc.Err(); err != nil {
			return Headers{}, fmt.Errorf("scan spool: %w", err)
		}
		return Headers{}, errors.New("empty spool file")
	}

	// Line 2: "<user> <uid> <gid>".
	if !sc.Scan() {
		if err := sc.Err(); err != nil {
			return Headers{}, fmt.Errorf("scan spool: %w", err)
		}
		return Headers{}, errors.New("missing envelope user line")
	}
	fields := splitFields(sc.Text())
	if len(fields) < 2 {
		return Headers{}, fmt.Errorf("malformed envelope user line: %q", sc.Text())
	}
	h.EnvelopeUser = fields[0]
	// A non-numeric uid field is tolerated and leaves EnvelopeUID at 0.
	if uid, err := strconv.Atoi(fields[1]); err == nil {
		h.EnvelopeUID = uid
	}

	// Skip remaining envelope metadata until the blank line that separates
	// it from the RFC 5322 header section.
	inHeaders := false
	for sc.Scan() {
		line := sc.Text()
		if !inHeaders {
			if line == "" {
				inHeaders = true
			}
			continue
		}
		// Lines that are not "<count><flag> <name>: <value>" (folded
		// continuations, garbage) come back with an empty name and fall
		// through the switch untouched.
		name, value := parseEximHeaderLine(line)
		switch strings.ToLower(name) {
		case "from":
			h.From = value
		case "reply-to":
			h.ReplyTo = value
		case "subject":
			h.Subject = value
		case "x-php-script":
			h.XPHPScript = value
		case "x-mailer":
			h.XMailer = value
		case "user-agent":
			h.UserAgent = value
		case "message-id":
			h.MessageID = value
		}
	}
	if err := sc.Err(); err != nil {
		return Headers{}, fmt.Errorf("scan spool: %w", err)
	}
	if !inHeaders {
		return Headers{}, errors.New("missing header section separator")
	}
	return h, nil
}
// parseEximHeaderLine returns (header-name, header-value) for an Exim -H
// header line of the form "NNNF Header-Name: value", or ("", "") if the
// line cannot be parsed (folded continuations, garbage).
//
// The prefix is a decimal byte count of at least three digits (it is
// zero-padded to three and grows to four or more for headers of 1000 bytes
// and up), then one flag byte -- an ASCII letter such as 'T', 'F', 'R', or a
// space when the header has no flag -- then a single space separator. Any
// other flag byte (for example '*') is rejected. The prefix is stripped
// before splitting on the first colon; leading spaces and tabs are trimmed
// from the value.
func parseEximHeaderLine(line string) (string, string) {
	// Consume the count. Requiring exactly three digits would silently drop
	// every header whose stored length needs four.
	digits := 0
	for digits < len(line) && isDigit(line[digits]) {
		digits++
	}
	// Need >= 3 digits plus the flag byte and the separator.
	if digits < 3 || len(line) < digits+2 {
		return "", ""
	}
	// The flag position is a space for unflagged Exim headers (e.g.
	// X-PHP-Script in real cPanel-Exim spool output) and a letter for
	// flagged ones. Both shapes appear; the parser accepts both.
	flag, sep := line[digits], line[digits+1]
	if (!isLetter(flag) && flag != ' ') || sep != ' ' {
		return "", ""
	}
	rest := line[digits+2:]
	colon := indexByte(rest, ':')
	if colon < 0 {
		return "", ""
	}
	name := rest[:colon]
	value := ""
	if colon+1 < len(rest) {
		value = trimLeadingSpace(rest[colon+1:])
	}
	return name, value
}
// isDigit reports whether b is an ASCII decimal digit ('0'..'9').
func isDigit(b byte) bool {
	return '0' <= b && b <= '9'
}
// isLetter reports whether b is an ASCII letter ('a'..'z' or 'A'..'Z').
func isLetter(b byte) bool {
	// Setting bit 0x20 folds 'A'..'Z' onto 'a'..'z' and maps no other byte
	// into that range.
	folded := b | 0x20
	return 'a' <= folded && folded <= 'z'
}
// indexByte returns the index of the first occurrence of b in s, or -1 if b
// is not present. It delegates to strings.IndexByte rather than a
// hand-written loop.
func indexByte(s string, b byte) int {
	return strings.IndexByte(s, b)
}
// trimLeadingSpace returns s without its leading spaces and tabs. Other
// whitespace (newlines, carriage returns) and all trailing characters are
// left untouched. It delegates to strings.TrimLeft rather than a
// hand-written loop.
func trimLeadingSpace(s string) string {
	return strings.TrimLeft(s, " \t")
}
// splitFields splits s into the maximal runs of bytes that are neither a
// space nor a tab. Only those two bytes act as separators; newlines and
// other whitespace stay inside a field. It returns nil when s contains no
// field.
func splitFields(s string) []string {
	var fields []string
	tokenStart := -1
	// Walk one position past the end so the final field is flushed by the
	// same branch that handles a separator.
	for pos := 0; pos <= len(s); pos++ {
		atBoundary := pos == len(s) || s[pos] == ' ' || s[pos] == '\t'
		switch {
		case atBoundary && tokenStart >= 0:
			fields = append(fields, s[tokenStart:pos])
			tokenStart = -1
		case !atBoundary && tokenStart < 0:
			tokenStart = pos
		}
	}
	return fields
}
package emailspool
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"gopkg.in/yaml.v3"
)
// Policies is the loaded contents of policies/email/*.yaml. Used by Stage 1
// (mailer classes, http_proxy_ranges) and Stage 2 (policy_blocks). Each
// data file has its own schema; failing to parse one file does not abort
// the whole load -- each category degrades independently.
type Policies struct {
// mu guards every field below; load and RefreshSelfIPs take the write
// lock, the Mailer*/IsProxyIP queries take the read lock.
mu sync.RWMutex
// suspiciousMail holds the lowercased, trimmed "suspicious" entries of
// mailer_classes.yaml; matched as substrings of the X-Mailer value.
suspiciousMail []string
// safeMail holds the lowercased, trimmed "safe" entries of
// mailer_classes.yaml; matched as substrings of the X-Mailer value.
safeMail []string
// proxyNets holds the successfully parsed CIDRs of
// http_proxy_ranges.yaml.
proxyNets []*net.IPNet
// selfIPs are the host's own interface addresses. PHP-relay Path 4
// (HTTP-IP fanout) treats them as proxy-equivalent so WordPress cron
// and any other local loopback-to-public traffic does not page on
// "one IP triggered N scripts". Populated by RefreshSelfIPs.
selfIPs []net.IP
}
// hostIPsFunc returns the set of non-loopback host IPs to treat as self.
// It defaults to enumerateHostIPs, which also leaves out link-local
// unicast addresses.
// Package-level so tests can inject deterministic addresses.
var hostIPsFunc = enumerateHostIPs
// enumerateHostIPs returns the addresses reported by net.InterfaceAddrs,
// leaving out loopback and link-local unicast addresses and any entry that
// is neither a *net.IPNet nor a *net.IPAddr. If net.InterfaceAddrs fails the
// result is nil; IsProxyIP still recognises loopback addresses on its own.
func enumerateHostIPs() []net.IP {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return nil
	}

	ips := make([]net.IP, 0, len(addrs))
	for i := range addrs {
		var candidate net.IP
		if ipNet, ok := addrs[i].(*net.IPNet); ok {
			candidate = ipNet.IP
		} else if ipAddr, ok := addrs[i].(*net.IPAddr); ok {
			candidate = ipAddr.IP
		}

		usable := candidate != nil && !candidate.IsLoopback() && !candidate.IsLinkLocalUnicast()
		if usable {
			ips = append(ips, candidate)
		}
	}
	return ips
}
// RefreshSelfIPs re-enumerates the host's addresses through hostIPsFunc and
// replaces the cached self-IP set. Enumeration happens before the write lock
// is taken, so readers such as IsProxyIP are only blocked for the swap.
// Safe to call concurrently with IsProxyIP. Loading or reloading policies
// refreshes the same set, so SIGHUP picks up cPanel alias IP additions;
// this method is exposed for tests and the daemon Flow E ticker.
func (p *Policies) RefreshSelfIPs() {
	fresh := hostIPsFunc()

	p.mu.Lock()
	defer p.mu.Unlock()
	p.selfIPs = fresh
}
// mailerClassesYAML is the on-disk schema of mailer_classes.yaml: a version
// number plus the "suspicious" and "safe" X-Mailer substring lists.
type mailerClassesYAML struct {
Version int `yaml:"version"`
Suspicious []string `yaml:"suspicious"`
Safe []string `yaml:"safe"`
}
// httpProxyRangesYAML is the on-disk schema of http_proxy_ranges.yaml: a
// version number plus a list of CIDR strings.
//
//nolint:unused // consumed by E2/E3
type httpProxyRangesYAML struct {
Version int `yaml:"version"`
CIDRs []string `yaml:"cidrs"`
}
// LoadPolicies builds a Policies from the policy files in dir. A missing
// file contributes no entries. A corrupt or unreadable file does not abort
// the load of the other categories; the first such error is returned
// alongside the (still usable) *Policies, which is never nil. The returned
// Policies is safe for concurrent use.
func LoadPolicies(dir string) (*Policies, error) {
	loaded := new(Policies)
	err := loaded.load(dir)
	return loaded, err
}
// load (re)reads the policy files under dir and refreshes the cached host
// addresses. It returns the first error encountered, or nil.
//
// Per-category behaviour:
//   - file missing: the category keeps its current value, no error;
//   - file unreadable or YAML invalid: the category keeps its current
//     value and the error is recorded;
//   - http_proxy_ranges.yaml with some invalid CIDRs: the valid entries
//     replace the category and the first invalid entry is recorded.
//
// All slow work (interface enumeration, file reads, YAML parsing) happens
// before the write lock is taken, so concurrent readers are blocked only
// for the final field swap, and a panic in that work cannot leave the
// mutex locked.
func (p *Policies) load(dir string) error {
	var firstErr error
	setErr := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}

	// Refresh self IPs alongside file-backed policies so SIGHUP picks up
	// any cPanel alias IP additions that landed since startup.
	selfIPs := hostIPsFunc()

	var (
		suspicious, safe []string
		mailersLoaded    bool
		nets             []*net.IPNet
		netsLoaded       bool
	)

	// mailer_classes.yaml
	// #nosec G304 -- dir is the operator-supplied policy directory; filename is a fixed literal under it.
	if data, err := os.ReadFile(filepath.Join(dir, "mailer_classes.yaml")); err == nil {
		var raw mailerClassesYAML
		if uerr := yaml.Unmarshal(data, &raw); uerr != nil {
			setErr(fmt.Errorf("parse mailer_classes.yaml: %w", uerr))
		} else {
			suspicious = lowerList(raw.Suspicious)
			safe = lowerList(raw.Safe)
			mailersLoaded = true
		}
	} else if !os.IsNotExist(err) {
		setErr(fmt.Errorf("read mailer_classes.yaml: %w", err))
	}

	// http_proxy_ranges.yaml
	// #nosec G304 -- dir is the operator-supplied policy directory; filename is a fixed literal under it.
	if data, err := os.ReadFile(filepath.Join(dir, "http_proxy_ranges.yaml")); err == nil {
		var raw httpProxyRangesYAML
		if uerr := yaml.Unmarshal(data, &raw); uerr != nil {
			setErr(fmt.Errorf("parse http_proxy_ranges.yaml: %w", uerr))
		} else {
			nets = make([]*net.IPNet, 0, len(raw.CIDRs))
			for _, c := range raw.CIDRs {
				_, n, perr := net.ParseCIDR(strings.TrimSpace(c))
				if perr != nil {
					setErr(fmt.Errorf("invalid CIDR %q in http_proxy_ranges.yaml: %w", c, perr))
					continue
				}
				nets = append(nets, n)
			}
			netsLoaded = true
		}
	} else if !os.IsNotExist(err) {
		setErr(fmt.Errorf("read http_proxy_ranges.yaml: %w", err))
	}

	// Publish the results; only categories that parsed are replaced.
	p.mu.Lock()
	defer p.mu.Unlock()
	p.selfIPs = selfIPs
	if mailersLoaded {
		p.suspiciousMail = suspicious
		p.safeMail = safe
	}
	if netsLoaded {
		p.proxyNets = nets
	}
	return firstErr
}
// Reload re-reads the policy files in dir into p. A category whose file is
// missing, unreadable, or not valid YAML keeps its previous value; the
// first error encountered is returned. Used by SIGHUP.
//
//nolint:unused // consumed by E3
func (p *Policies) Reload(dir string) error {
	err := p.load(dir)
	return err
}
// MailerSuspicious reports whether the X-Mailer value contains any of the
// configured suspicious strings. Matching is case-insensitive and by
// substring, not exact match. An empty value never matches.
func (p *Policies) MailerSuspicious(xMailer string) bool {
	if xMailer == "" {
		return false
	}
	needleHaystack := strings.ToLower(xMailer)

	p.mu.RLock()
	defer p.mu.RUnlock()
	for i := range p.suspiciousMail {
		if strings.Contains(needleHaystack, p.suspiciousMail[i]) {
			return true
		}
	}
	return false
}
// MailerSafe reports whether the X-Mailer value contains any of the
// configured safe strings. Matching is case-insensitive and by substring.
// An empty value never matches.
func (p *Policies) MailerSafe(xMailer string) bool {
	if xMailer == "" {
		return false
	}
	haystack := strings.ToLower(xMailer)

	p.mu.RLock()
	defer p.mu.RUnlock()
	for i := range p.safeMail {
		if strings.Contains(haystack, p.safeMail[i]) {
			return true
		}
	}
	return false
}
// IsProxyIP reports whether ip is a loopback address, lies inside one of
// the configured CDN/proxy CIDRs, or equals one of the host's own cached
// interface addresses. An empty or unparseable ip returns false.
// Used by Path 4 to skip fanout counting for IPs that are CDN front IPs,
// the local host (WordPress cron, panel-internal callbacks), or 127/::1.
func (p *Policies) IsProxyIP(ip string) bool {
	// net.ParseIP returns nil for "" as well as for malformed input.
	addr := net.ParseIP(ip)
	switch {
	case addr == nil:
		return false
	case addr.IsLoopback():
		return true
	}

	p.mu.RLock()
	defer p.mu.RUnlock()
	for i := range p.proxyNets {
		if p.proxyNets[i].Contains(addr) {
			return true
		}
	}
	for i := range p.selfIPs {
		if p.selfIPs[i].Equal(addr) {
			return true
		}
	}
	return false
}
// lowerList returns a new slice holding each entry of in with surrounding
// whitespace removed and letters lowercased, dropping entries that are
// empty after trimming. The result is never nil, even for nil input.
func lowerList(in []string) []string {
	normalised := make([]string, 0, len(in))
	for i := range in {
		entry := strings.TrimSpace(in[i])
		if entry == "" {
			continue
		}
		normalised = append(normalised, strings.ToLower(entry))
	}
	return normalised
}
package firewall
import (
"bufio"
"encoding/json"
"log"
"os"
"path/filepath"
"time"
)
// maxAuditFileSize is the size above which AppendAudit renames audit.jsonl
// to audit.jsonl.1 before appending the next entry.
const maxAuditFileSize = 10 * 1024 * 1024 // 10 MB
// AuditEntry records a firewall modification for compliance and forensics.
// One entry is serialised as one JSON line of the audit log; string fields
// other than Action are omitted from the JSON when empty.
type AuditEntry struct {
// Timestamp is when the entry was created (set by AppendAudit).
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"` // block, unblock, allow, remove_allow, flush, apply
IP string `json:"ip,omitempty"`
Reason string `json:"reason,omitempty"`
// Source is the caller-supplied provenance, or the value of
// InferProvenance(action, reason) when the caller passed "".
Source string `json:"source,omitempty"`
// Duration is time.Duration.String() of the requested duration; empty
// when the duration was zero or negative.
Duration string `json:"duration,omitempty"`
}
// AppendAudit appends one JSON line describing a firewall change to
// <statePath>/audit.jsonl, creating the file with mode 0600 if needed.
//
// An empty source is replaced by InferProvenance(action, reason); a
// non-positive duration is omitted from the entry. Before appending, a log
// larger than 10 MB is renamed to audit.jsonl.1 (best effort). The function
// returns nothing: open, write and close failures are reported through the
// standard logger so a dropped entry is visible to operators, while a JSON
// encoding failure returns silently.
func AppendAudit(statePath, action, ip, reason, source string, duration time.Duration) {
	if source == "" {
		source = InferProvenance(action, reason)
	}

	var rec AuditEntry
	rec.Timestamp = time.Now()
	rec.Action = action
	rec.IP = ip
	rec.Reason = reason
	rec.Source = source
	if duration > 0 {
		rec.Duration = duration.String()
	}

	logPath := filepath.Join(statePath, "audit.jsonl")
	line, err := json.Marshal(rec)
	if err != nil {
		return
	}
	line = append(line, '\n')

	// Rotate when the current log has grown past the size cap.
	info, statErr := os.Stat(logPath)
	if statErr == nil && info.Size() > maxAuditFileSize {
		_ = os.Rename(logPath, logPath+".1")
	}

	// #nosec G304 -- path is filepath.Join under operator-configured statePath.
	out, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		// Permission, disk-full and inode-exhaustion failures surface
		// here; log them so the lost entry leaves a trace.
		log.Printf("firewall: audit open failed for %s: %v", logPath, err)
		return
	}

	if _, writeErr := out.Write(line); writeErr != nil {
		_ = out.Close()
		log.Printf("firewall: audit write failed for %s: %v", logPath, writeErr)
		return
	}

	// A close error on a writable file is the disk-full / fsync signal.
	closeErr := out.Close()
	if closeErr == nil {
		return
	}
	log.Printf("firewall: audit close failed for %s: %v", logPath, closeErr)
}
// ReadAuditLog reads <statePath>/firewall/audit.jsonl and returns its
// entries in file order. Lines that are not valid JSON for an AuditEntry are
// skipped. When limit is positive only the last limit entries are returned;
// a limit of zero or less returns everything. Returns nil when the file
// cannot be opened.
//
// NOTE(review): this joins "firewall" onto statePath while AppendAudit does
// not; presumably the two are called with different base directories --
// confirm against the callers.
func ReadAuditLog(statePath string, limit int) []AuditEntry {
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	f, err := os.Open(filepath.Join(statePath, "firewall", "audit.jsonl"))
	if err != nil {
		return nil
	}
	defer f.Close()

	var entries []AuditEntry
	for lines := bufio.NewScanner(f); lines.Scan(); {
		var rec AuditEntry
		if err := json.Unmarshal(lines.Bytes(), &rec); err != nil {
			continue
		}
		entries = append(entries, rec)
	}

	if excess := len(entries) - limit; limit > 0 && excess > 0 {
		entries = entries[excess:]
	}
	return entries
}
//go:build linux
package firewall
import (
"fmt"
"net"
"os"
"github.com/google/nftables"
)
// UpdateCloudflareSet flushes and repopulates the Cloudflare nftables sets.
//
// ipv4 and ipv6 are lists of CIDR strings. Each CIDR is converted to a
// start/end address pair and added as interval elements; entries that do
// not parse, or whose start/end cannot be derived, are skipped without an
// error. The engine mutex is held for the whole call.
//
// Returns an error when either whitelist set is missing, when adding
// elements fails, or when the final conn.Flush fails.
func (e *Engine) UpdateCloudflareSet(ipv4, ipv6 []string) error {
e.mu.Lock()
defer e.mu.Unlock()
// Both sets are created together by createSets, but ConnectExisting
// loads them independently, so one can be present while the other is
// nil. Require both -- FlushSet/SetAddElements panic on a nil set.
if e.setCFWhitelist == nil || e.setCFWhitelist6 == nil {
return fmt.Errorf("cf_whitelist sets not initialized")
}
// Flush existing entries
// NOTE(review): presumably these calls only queue the flush and nothing
// reaches the kernel until e.conn.Flush() below -- confirm against the
// nftables library; if so, the early error returns further down leave
// queued operations pending on the connection.
e.conn.FlushSet(e.setCFWhitelist)
e.conn.FlushSet(e.setCFWhitelist6)
// Populate IPv4 CIDRs
var elems4 []nftables.SetElement
for _, cidr := range ipv4 {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
// To4 yields nil for a non-IPv4 network, which is skipped below.
start := network.IP.To4()
end := lastIPInRange(network)
if start == nil || end == nil {
continue
}
elems4 = appendIntervalSetElements(elems4, start, end)
}
if len(elems4) > 0 {
if err := e.conn.SetAddElements(e.setCFWhitelist, elems4); err != nil {
return fmt.Errorf("adding CF IPv4 elements: %w", err)
}
}
// Populate IPv6 CIDRs
var elems6 []nftables.SetElement
for _, cidr := range ipv6 {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
start := network.IP.To16()
end := lastIPInRange(network)
if start == nil || end == nil {
continue
}
elems6 = appendIntervalSetElements(elems6, start, end)
}
if len(elems6) > 0 {
if err := e.conn.SetAddElements(e.setCFWhitelist6, elems6); err != nil {
return fmt.Errorf("adding CF IPv6 elements: %w", err)
}
}
if err := e.conn.Flush(); err != nil {
return fmt.Errorf("flushing CF whitelist: %w", err)
}
// The counts logged are the lengths of the input lists, so they include
// any entries that were skipped above.
fmt.Fprintf(os.Stderr, "firewall: cloudflare whitelist updated: %d IPv4, %d IPv6 CIDRs\n",
len(ipv4), len(ipv6))
return nil
}
// CloudflareIPs returns the Cloudflare CIDRs recorded in the cached state
// file under the engine's state path (see LoadCFState). Both results are
// nil when that file cannot be opened.
func (e *Engine) CloudflareIPs() (ipv4, ipv6 []string) {
	ipv4, ipv6 = LoadCFState(e.statePath)
	return ipv4, ipv6
}
package firewall
import (
"bufio"
"fmt"
"net"
"net/http"
"os"
"strings"
"time"
)
// URLs fetched by FetchCloudflareIPs; each is expected to serve one CIDR
// per line.
const (
cfIPv4URL = "https://www.cloudflare.com/ips-v4"
cfIPv6URL = "https://www.cloudflare.com/ips-v6"
)
// FetchCloudflareIPs downloads the current Cloudflare IPv4 and IPv6 ranges
// using an HTTP client with a 30-second timeout. The IPv4 list is fetched
// first. If either fetch fails, both lists are returned as nil together
// with an error naming the family that failed.
func FetchCloudflareIPs() (ipv4, ipv6 []string, err error) {
	client := &http.Client{Timeout: 30 * time.Second}

	v4, v4Err := fetchCIDRList(client, cfIPv4URL)
	if v4Err != nil {
		return nil, nil, fmt.Errorf("fetching CF IPv4: %w", v4Err)
	}

	v6, v6Err := fetchCIDRList(client, cfIPv6URL)
	if v6Err != nil {
		return nil, nil, fmt.Errorf("fetching CF IPv6: %w", v6Err)
	}

	return v4, v6, nil
}
// fetchCIDRList GETs url with client and returns the valid CIDRs found in
// the response body, one per line (see parseCloudflareResponse for what is
// skipped).
//
// It returns an error when the request fails, when the status is not 200,
// or when reading the body fails part-way (connection reset, timeout, a
// line too long for the scanner). The last case matters because the caller
// replaces the whole whitelist with the result: a silently truncated list
// would drop legitimate ranges.
func fetchCIDRList(client *http.Client, url string) ([]string, error) {
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
	}

	scanner := bufio.NewScanner(resp.Body)
	cidrs := parseCloudflareResponse(scanner)
	// Scan() returning false is either EOF or a failure; only Err() tells
	// them apart.
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading response from %s: %w", url, err)
	}
	return cidrs, nil
}
// parseCloudflareResponse reads every line from scanner and returns, in
// input order, the whitespace-trimmed lines that net.ParseCIDR accepts.
// Blank lines, lines starting with '#', and anything that is not a CIDR are
// dropped. The scanner's error state is not inspected here. Returns nil when
// nothing valid was found.
func parseCloudflareResponse(scanner *bufio.Scanner) []string {
	var valid []string
	for scanner.Scan() {
		entry := strings.TrimSpace(scanner.Text())
		switch {
		case entry == "":
			// blank line
		case strings.HasPrefix(entry, "#"):
			// comment
		default:
			if _, _, err := net.ParseCIDR(entry); err == nil {
				valid = append(valid, entry)
			}
		}
	}
	return valid
}
// SaveCFState persists the Cloudflare CIDRs for status display.
//
// The file is <statePath>/firewall/cf_whitelist.txt ("/firewall" is only
// appended when statePath does not already end with it). The directory is
// created with mode 0700 and the file written with mode 0600. Layout:
//
//	# refreshed: <RFC 3339 timestamp>
//	# ipv4
//	<one CIDR per line>
//	# ipv6
//	<one CIDR per line>
//
// The function has no error return; failures to create the directory or
// write the file are reported on stderr rather than discarded, so a stale
// or missing status file can be diagnosed.
func SaveCFState(statePath string, ipv4, ipv6 []string, refreshed time.Time) {
	path := statePath
	if !strings.HasSuffix(path, "/firewall") {
		path += "/firewall"
	}
	if err := os.MkdirAll(path, 0700); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: creating cloudflare state dir %s: %v\n", path, err)
		return
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "# refreshed: %s\n", refreshed.Format(time.RFC3339))
	sb.WriteString("# ipv4\n")
	for _, cidr := range ipv4 {
		sb.WriteString(cidr)
		sb.WriteByte('\n')
	}
	sb.WriteString("# ipv6\n")
	for _, cidr := range ipv6 {
		sb.WriteString(cidr)
		sb.WriteByte('\n')
	}

	file := path + "/cf_whitelist.txt"
	if err := os.WriteFile(file, []byte(sb.String()), 0600); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: writing cloudflare state %s: %v\n", file, err)
	}
}
// LoadCFState reads the cached Cloudflare CIDRs from
// <statePath>/firewall/cf_whitelist.txt ("/firewall" is only appended when
// statePath does not already end with it).
//
// Lines are trimmed. "# ipv4" and "# ipv6" switch the current section,
// blank lines and any other '#' line are ignored, and every remaining line
// is appended verbatim to the current section's list. Lines before any
// section marker count as IPv4. Entries are not validated as CIDRs. Both
// results are nil when the file cannot be opened.
func LoadCFState(statePath string) (ipv4, ipv6 []string) {
	dir := statePath
	if !strings.HasSuffix(dir, "/firewall") {
		dir += "/firewall"
	}
	// #nosec G304 -- fixed filename under operator-configured statePath.
	f, err := os.Open(dir + "/cf_whitelist.txt")
	if err != nil {
		return nil, nil
	}
	defer f.Close()

	inV6 := false
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		switch line := strings.TrimSpace(sc.Text()); {
		case line == "# ipv4":
			inV6 = false
		case line == "# ipv6":
			inV6 = true
		case line == "", strings.HasPrefix(line, "#"):
			// blank or comment
		case inV6:
			ipv6 = append(ipv6, line)
		default:
			ipv4 = append(ipv4, line)
		}
	}
	return ipv4, ipv6
}
// LoadCFRefreshTime returns the refresh timestamp recorded on the first
// line of <statePath>/firewall/cf_whitelist.txt ("/firewall" is only
// appended when statePath does not already end with it). The first line
// must be "# refreshed: " followed by an RFC 3339 time. The zero time.Time
// is returned when the file cannot be opened, is empty, or its first line
// does not have that form.
func LoadCFRefreshTime(statePath string) time.Time {
	const marker = "# refreshed: "

	dir := statePath
	if !strings.HasSuffix(dir, "/firewall") {
		dir += "/firewall"
	}
	// #nosec G304 -- fixed filename under operator-configured statePath.
	f, err := os.Open(dir + "/cf_whitelist.txt")
	if err != nil {
		return time.Time{}
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	if !sc.Scan() {
		return time.Time{}
	}
	first := sc.Text()
	if !strings.HasPrefix(first, marker) {
		return time.Time{}
	}
	stamp, err := time.Parse(time.RFC3339, first[len(marker):])
	if err != nil {
		return time.Time{}
	}
	return stamp
}
package firewall
// FirewallConfig defines the nftables firewall configuration. It is
// populated from YAML (see the field tags); DefaultConfig supplies the
// baseline values.
type FirewallConfig struct {
// Enabled turns the firewall on; DefaultConfig leaves it false.
Enabled bool `yaml:"enabled"`
// Open ports (IPv4)
TCPIn []int `yaml:"tcp_in"`
TCPOut []int `yaml:"tcp_out"`
UDPIn []int `yaml:"udp_in"`
UDPOut []int `yaml:"udp_out"`
// IPv6 - enable dual-stack filtering
IPv6 bool `yaml:"ipv6"`
TCP6In []int `yaml:"tcp6_in"` // if empty, uses tcp_in
TCP6Out []int `yaml:"tcp6_out"` // if empty, uses tcp_out
UDP6In []int `yaml:"udp6_in"` // if empty, uses udp_in
UDP6Out []int `yaml:"udp6_out"` // if empty, uses udp_out
// Ports restricted to infra IPs only
RestrictedTCP []int `yaml:"restricted_tcp"`
// Passive FTP range
PassiveFTPStart int `yaml:"passive_ftp_start"`
PassiveFTPEnd int `yaml:"passive_ftp_end"`
// Infra IPs (CIDR notation)
InfraIPs []string `yaml:"infra_ips"`
// Rate limiting (per-source-IP via nftables meters)
ConnRateLimit int `yaml:"conn_rate_limit"` // new connections per minute per IP
SYNFloodProtection bool `yaml:"syn_flood_protection"`
ConnLimit int `yaml:"conn_limit"` // max concurrent connections per IP (0 = disabled)
// Per-port flood protection - per-source rate limit per port and IP family.
PortFlood []PortFloodRule `yaml:"port_flood"`
// UDP flood protection - per-source rate limit on UDP packets
UDPFlood bool `yaml:"udp_flood"`
UDPFloodRate int `yaml:"udp_flood_rate"` // packets per second
UDPFloodBurst int `yaml:"udp_flood_burst"` // burst allowance
// Country blocking
CountryBlock []string `yaml:"country_block"` // ISO country codes
CountryDBPath string `yaml:"country_db_path"`
// Ports to drop silently without logging (reduces log noise from scanners)
DropNoLog []int `yaml:"drop_nolog"`
// Max blocked IPs (prevents memory exhaustion, 0 = unlimited)
DenyIPLimit int `yaml:"deny_ip_limit"`
DenyTempIPLimit int `yaml:"deny_temp_ip_limit"`
// Outbound SMTP restriction - block outgoing mail except from allowed users
SMTPBlock bool `yaml:"smtp_block"`
SMTPAllowUsers []string `yaml:"smtp_allow_users"` // usernames allowed to send
SMTPPorts []int `yaml:"smtp_ports"`
// Dynamic DNS - resolve hostnames to IPs, update allowed set periodically
DynDNSHosts []string `yaml:"dyndns_hosts"`
// Logging
LogDropped bool `yaml:"log_dropped"`
LogRate int `yaml:"log_rate"` // log entries per minute
}
// PortFloodRule defines per-port connection rate limiting: at most Hits new
// connections to Port over Proto within a window of Seconds. The shipped
// defaults use 600 hits per 300 seconds on the SMTP ports.
type PortFloodRule struct {
Port int `yaml:"port"`
Proto string `yaml:"proto"` // "tcp" or "udp"
Hits int `yaml:"hits"` // max new connections
Seconds int `yaml:"seconds"` // time window in seconds
}
// DefaultConfig returns a freshly allocated baseline configuration aimed at
// a typical cPanel server. The firewall itself is disabled by default, and
// every field not assigned below keeps its zero value.
func DefaultConfig() *FirewallConfig {
	cfg := new(FirewallConfig)
	cfg.Enabled = false

	// SSH (22) is intentionally not opened here; operators who run sshd
	// on 22 must add it explicitly. TCP 853 enables DNS-over-TLS; UDP 853
	// enables DNS-over-QUIC.
	cfg.TCPIn = []int{
		20, 21, 25, 26, 53, 80, 110, 143, 443, 465, 587,
		853, 993, 995, 2077, 2078, 2079, 2080, 2082, 2083,
		2091, 2095, 2096,
	}
	cfg.TCPOut = []int{
		20, 21, 25, 26, 37, 43, 53, 80, 110, 113, 443,
		465, 587, 853, 873, 993, 995, 2082, 2083, 2086, 2087,
		2089, 2195, 2325, 2703,
	}
	cfg.UDPIn = []int{53, 443, 853}
	// 6277/24441 are the DCC/Pyzor network checks used by SpamAssassin;
	// without them outbound spam-scoring queries silently fail.
	cfg.UDPOut = []int{53, 113, 123, 443, 853, 873, 6277, 24441}

	cfg.RestrictedTCP = []int{2086, 2087, 2325, 9443}
	cfg.PassiveFTPStart = 49152
	cfg.PassiveFTPEnd = 65534

	// 200 new connections per minute per source IP: sized to tolerate
	// shared CGNAT egress, where many subscribers share one public
	// address and a 30/min limit produced spurious drops.
	cfg.ConnRateLimit = 200
	cfg.SYNFloodProtection = true
	// 400 concurrent connections per source IP: room for power users
	// (webmail tabs, IMAP IDLE on several devices, parallel sends) plus
	// headroom for shared CGNAT egress. Operators with heavier
	// IDLE-style workloads can raise it.
	cfg.ConnLimit = 400

	// 600 hits / 300 s = 120 new connections per minute per source IP.
	// Tolerates normal mail-client bursts while still catching true
	// single-IP floods; low-and-slow scanners are left to userspace
	// detection.
	smtpFlood := func(port int) PortFloodRule {
		return PortFloodRule{Port: port, Proto: "tcp", Hits: 600, Seconds: 300}
	}
	cfg.PortFlood = []PortFloodRule{smtpFlood(25), smtpFlood(465), smtpFlood(587)}

	cfg.UDPFlood = true
	cfg.UDPFloodRate = 100
	cfg.UDPFloodBurst = 500

	cfg.DropNoLog = []int{23, 67, 68, 111, 113, 135, 136, 137, 138, 139, 445, 500, 513, 520}
	cfg.DenyIPLimit = 3000
	cfg.DenyTempIPLimit = 500

	cfg.SMTPBlock = false
	cfg.SMTPPorts = []int{25, 465, 587}

	cfg.LogDropped = true
	cfg.LogRate = 5
	return cfg
}
package firewall
import (
"context"
"fmt"
"net"
"os"
"sort"
"sync"
"time"
)
// hostHealth tracks per-host resolution state for the re-resolve guard.
// Instances live in DynDNSResolver.hostHealth and are accessed under
// DynDNSResolver.muGuard.
type hostHealth struct {
// lastSuccess is the time of the most recent successful resolution.
lastSuccess time.Time
// firstFailure: presumably the start of the current run of failed
// resolutions, measured against gracePeriod -- confirm against the
// resolve loop.
firstFailure time.Time
// findingEmitted: presumably set once the finding sink has been called
// for the current failure run, to avoid repeats -- confirm against the
// resolve loop.
findingEmitted bool
}
// DynDNSResolver periodically resolves hostnames and updates the firewall allowed set.
//
// Two locks are used: mu is taken by the methods that modify hosts,
// infraHosts and infraEngine; muGuard is taken by the methods that touch
// the re-resolve guard state (hostHealth, unresolvable, findingSink).
type DynDNSResolver struct {
mu sync.Mutex
// hosts is the list of hostnames to resolve; AddHost appends to it.
hosts []string
resolved map[string][]string // hostname -> all resolved IPs
// infraHosts is the subset of hosts that also feed the engine's
// infra-IP guard. When a host appears here, every successful
// resolution additionally calls engine.UpdateInfraResolved so the
// hostname stays blockable-refusable even when the IP behind it
// rotates. Hosts not in this map only feed the allowed-IPs set.
infraHosts map[string]struct{}
// engine is the firewall the resolved addresses are applied to.
engine interface {
AllowIP(ip string, reason string) error
RemoveAllowIPBySource(ip string, source string) error
}
// infraEngine receives infra-mode updates. Optional - separate
// interface so the AllowIP/RemoveAllowIPBySource consumers don't
// have to grow when an operator does not declare any infra hosts.
infraEngine interface {
UpdateInfraResolved(host string, ips []string)
DropInfraResolved(host string)
}
// lookupFn is the context-bound DNS lookup function. Defaults to
// net.DefaultResolver.LookupHost so callers can bound a stuck
// resolver by cancelling the parent context. Tests may replace this
// with a stub.
lookupFn func(ctx context.Context, host string) ([]string, error)
// Guard fields for the DNS re-resolve guard (Task 3).
muGuard sync.RWMutex
// hostHealth maps hostname to its resolution state.
hostHealth map[string]*hostHealth
// gracePeriod is how long a host may stay unresolvable before it is
// reported; NewDynDNSResolver sets it to 10 minutes.
gracePeriod time.Duration
// unresolvable is the set of hosts returned by UnresolvableHosts.
unresolvable map[string]struct{}
// findingSink is the callback installed by SetFindingSink; nil until
// then.
findingSink func(name string)
}
// NewDynDNSResolver creates a resolver for the given hostnames that applies
// its results to engine. The hosts slice is stored as passed (not copied).
// The resolver starts with empty state maps, a 10-minute grace period, and
// net.DefaultResolver.LookupHost as its lookup function; no infra hosts or
// infra engine are configured.
func NewDynDNSResolver(hosts []string, engine interface {
	AllowIP(ip string, reason string) error
	RemoveAllowIPBySource(ip string, source string) error
}) *DynDNSResolver {
	return &DynDNSResolver{
		hosts:        hosts,
		engine:       engine,
		lookupFn:     net.DefaultResolver.LookupHost,
		gracePeriod:  10 * time.Minute,
		resolved:     make(map[string][]string),
		hostHealth:   make(map[string]*hostHealth),
		unresolvable: make(map[string]struct{}),
	}
}
// AddHost appends host to the list of hostnames the resolver handles. No
// de-duplication is performed. Safe to call concurrently; used in tests to
// add hosts after construction.
func (d *DynDNSResolver) AddHost(host string) {
	d.mu.Lock()
	d.hosts = append(d.hosts, host)
	d.mu.Unlock()
}
// RegisterInfraHost marks host as an infra hostname, so that its resolved
// IPs also feed the infra-block guard through the infra engine in addition
// to the allowed-IPs set. Calling it again for the same host changes
// nothing. Wire an infra engine with SetInfraEngine before the resolver's
// first tick for this to take effect.
func (d *DynDNSResolver) RegisterInfraHost(host string) {
	d.mu.Lock()
	defer d.mu.Unlock()

	// The set is created lazily on first registration.
	if d.infraHosts == nil {
		d.infraHosts = map[string]struct{}{host: {}}
		return
	}
	d.infraHosts[host] = struct{}{}
}
// SetInfraEngine installs the engine that receives infra-mode resolution
// updates, replacing any previous one. Passing nil disables infra routing
// without affecting the regular allowed-IPs path.
func (d *DynDNSResolver) SetInfraEngine(eng interface {
	UpdateInfraResolved(host string, ips []string)
	DropInfraResolved(host string)
}) {
	d.mu.Lock()
	d.infraEngine = eng
	d.mu.Unlock()
}
// markLastSuccess stamps host's lastSuccess with the current time, creating
// the health record when the host has none yet. Tests use it to fake an
// earlier successful resolution without performing a DNS lookup.
func (d *DynDNSResolver) markLastSuccess(host string) {
	d.muGuard.Lock()
	defer d.muGuard.Unlock()

	if health := d.hostHealth[host]; health != nil {
		health.lastSuccess = time.Now()
		return
	}
	d.hostHealth[host] = &hostHealth{lastSuccess: time.Now()}
}
// SetFindingSink registers the callback that receives a hostname once that
// host has stayed unresolvable for longer than gracePeriod. The callback is
// stored under muGuard and is later invoked by updateGuardFailure after the
// lock has been released.
func (d *DynDNSResolver) SetFindingSink(sink func(host string)) {
	d.muGuard.Lock()
	d.findingSink = sink
	d.muGuard.Unlock()
}
// UnresolvableHosts returns, in sorted order, the hostnames currently
// marked as failing to resolve beyond the grace period. The result is a
// fresh, non-nil slice the caller may modify.
func (d *DynDNSResolver) UnresolvableHosts() []string {
	// Snapshot the set under the read lock; sorting needs no lock.
	d.muGuard.RLock()
	names := make([]string, 0, len(d.unresolvable))
	for name := range d.unresolvable {
		names = append(names, name)
	}
	d.muGuard.RUnlock()

	sort.Strings(names)
	return names
}
// Run starts the periodic resolver. Blocks until stopCh is closed.
//
// One resolution cycle runs immediately, then one more on every fire of a
// 5-minute ticker. stopCh is bridged onto a cancellable context so that a
// cycle that is in flight when shutdown is requested is cancelled instead
// of being waited for.
func (d *DynDNSResolver) Run(stopCh <-chan struct{}) {
// parent is the root context for every tick; cancelling it both aborts an
// in-flight tick and terminates the loop below.
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
// done lets the watcher goroutine exit when Run returns for a reason
// other than stopCh closing, so the goroutine never outlives Run.
done := make(chan struct{})
go func() {
select {
case <-stopCh:
// Shutdown requested: propagate it to the tick context.
cancelParent()
case <-done:
}
}()
defer close(done)
// Non-blocking check: if stopCh is already closed, skip the initial
// resolution entirely rather than racing the watcher goroutine.
select {
case <-stopCh:
cancelParent()
return
default:
}
// Resolve once up front so hosts are handled without waiting a full
// ticker period.
d.runTick(parent)
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-parent.Done():
return
case <-ticker.C:
d.runTick(parent)
}
}
}
// runTick is the periodic Run helper. It bounds a single resolution
// cycle so a stuck DNS server cannot hold the ticker beyond its budget
// and stack late ticks. The per-tick budget (dyndnsTickBudget) is
// purposely larger than the default 30-second per-resolver timeout but
// smaller than the 5-minute ticker period.
//
// The derived context also inherits cancellation from parent, so the cycle
// ends early when parent is cancelled.
func (d *DynDNSResolver) runTick(parent context.Context) {
ctx, cancel := context.WithTimeout(parent, dyndnsTickBudget)
// Release the timeout's timer as soon as the cycle finishes.
defer cancel()
d.tickOnce(ctx)
}
// dyndnsTickBudget is the maximum wall-clock time runTick grants one
// resolution cycle over all hosts before its context expires.
const dyndnsTickBudget = 60 * time.Second
// tickOnce runs a single resolution pass over every registered host. Run
// invokes it once per period and tests may call it directly. ctx bounds the
// whole pass: once it is done, the remaining hosts are skipped, so one stuck
// DNS server cannot stall the pass for the resolver's implicit timeout
// (~30s) and stack late ticks onto the 5-minute period.
func (d *DynDNSResolver) tickOnce(ctx context.Context) {
	// Work on a private copy so AddHost can run while lookups are in flight.
	d.mu.Lock()
	snapshot := append([]string(nil), d.hosts...)
	d.mu.Unlock()

	for i := range snapshot {
		if ctx.Err() != nil {
			return
		}
		d.resolveHost(ctx, snapshot[i])
	}
}
// resolveAll is retained for backward compatibility with existing tests.
// It runs one full resolution cycle with a background context, i.e. with
// no deadline or cancellation applied to the cycle.
func (d *DynDNSResolver) resolveAll() {
d.tickOnce(context.Background())
}
// resolveHost resolves one hostname and reconciles the firewall allow-list
// with the answer.
//
// On lookup failure (an error or an empty answer) the failure is logged and
// recorded in the guard state, unless ctx was cancelled, in which case the
// call returns silently. On success, IPs that disappeared from DNS have
// their dyndns-sourced allow entry removed, new IPs are allowed, the
// per-host resolved list is replaced by the IPs that are now allowed, infra
// hosts additionally publish their resolved IPs to the infra engine, and
// the guard state for the host is cleared.
func (d *DynDNSResolver) resolveHost(ctx context.Context, host string) {
	newIPs, err := d.lookupFn(ctx, host)
	if err != nil || len(newIPs) == 0 {
		// Cancellation means the resolver is stopping, not that the host
		// is unhealthy, so it must not count against the grace period.
		if ctx.Err() == context.Canceled {
			return
		}
		fmt.Fprintf(os.Stderr, "dyndns: failed to resolve %s: %v\n", host, err)
		d.updateGuardFailure(host)
		return
	}
	sort.Strings(newIPs)

	d.mu.Lock()
	oldIPs := d.resolved[host]
	d.mu.Unlock()

	oldSet := make(map[string]bool, len(oldIPs))
	for _, ip := range oldIPs {
		oldSet[ip] = true
	}
	newSet := make(map[string]bool, len(newIPs))
	// infraIPs stays nil when nothing parses, exactly as an empty answer
	// would be reported to the infra engine.
	var infraIPs []string
	for _, ip := range newIPs {
		newSet[ip] = true
		if parsed := net.ParseIP(ip); parsed != nil {
			infraIPs = append(infraIPs, parsed.String())
		}
	}

	// Remove IPs no longer in DNS (only the dyndns source entry). A failed
	// removal is reported as such instead of being logged as "removed".
	for _, ip := range oldIPs {
		if newSet[ip] {
			continue
		}
		if err := d.engine.RemoveAllowIPBySource(ip, SourceDynDNS); err != nil {
			fmt.Fprintf(os.Stderr, "dyndns: error removing %s (%s): %v\n", ip, host, err)
			continue
		}
		fmt.Fprintf(os.Stderr, "dyndns: %s removed %s (no longer resolves)\n", host, ip)
	}

	// Add new IPs; IPs already allowed by an earlier cycle are kept as-is.
	reason := fmt.Sprintf("dyndns: %s", host)
	var successIPs []string
	for _, ip := range newIPs {
		if oldSet[ip] {
			successIPs = append(successIPs, ip)
			continue
		}
		if err := d.engine.AllowIP(ip, reason); err != nil {
			fmt.Fprintf(os.Stderr, "dyndns: error allowing %s (%s): %v\n", ip, host, err)
			continue
		}
		successIPs = append(successIPs, ip)
		fmt.Fprintf(os.Stderr, "dyndns: %s resolved to %s (added)\n", host, ip)
	}

	// Only record IPs that were actually allowed, so a failed AllowIP is
	// retried on the next cycle.
	d.mu.Lock()
	d.resolved[host] = successIPs
	_, isInfra := d.infraHosts[host]
	infraEngine := d.infraEngine
	d.mu.Unlock()

	// Infra mode feeds the block guard from DNS itself, not from the
	// allow-list mutation result. A transient nftables write failure
	// should not make a resolved management hostname blockable.
	if isInfra && infraEngine != nil {
		infraEngine.UpdateInfraResolved(host, infraIPs)
	}

	// Successful resolution: clear any guard state.
	d.updateGuardSuccess(host)
}
// updateGuardSuccess records a successful resolution for host in the guard
// state: the last-success time is refreshed, the failure window and the
// finding-emitted flag are reset, and a host previously flagged as
// unresolvable is un-flagged with a recovery message on stderr.
func (d *DynDNSResolver) updateGuardSuccess(host string) {
	d.muGuard.Lock()
	defer d.muGuard.Unlock()

	health := d.hostHealth[host]
	if health == nil {
		health = new(hostHealth)
		d.hostHealth[host] = health
	}
	health.lastSuccess = time.Now()
	health.firstFailure = time.Time{}
	health.findingEmitted = false

	if _, flagged := d.unresolvable[host]; !flagged {
		return
	}
	delete(d.unresolvable, host)
	fmt.Fprintf(os.Stderr, "dyndns: %s recovered (resolution succeeded)\n", host)
}
// updateGuardFailure records a failed resolution for host. The failure
// window is measured from the last success or, when the host has never
// resolved, from its first recorded failure. Once that window exceeds
// gracePeriod and no finding has been emitted for the current outage, the
// host is flagged unresolvable and the finding sink (if any) is called.
func (d *DynDNSResolver) updateGuardFailure(host string) {
	// The locked section only decides which sink to call, if any.
	notify := func() func(string) {
		d.muGuard.Lock()
		defer d.muGuard.Unlock()

		health := d.hostHealth[host]
		if health == nil {
			health = new(hostHealth)
			d.hostHealth[host] = health
		}
		now := time.Now()
		if health.firstFailure.IsZero() {
			health.firstFailure = now
		}
		reference := health.lastSuccess
		if reference.IsZero() {
			reference = health.firstFailure
		}
		if health.findingEmitted || now.Sub(reference) <= d.gracePeriod {
			return nil
		}
		d.unresolvable[host] = struct{}{}
		health.findingEmitted = true
		return d.findingSink
	}()

	// The sink runs without muGuard held so it may call back into the
	// resolver without deadlocking.
	if notify != nil {
		notify(host)
	}
}
//go:build linux
package firewall
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/pidginhost/csm/internal/atomicio"
)
// Engine manages the nftables firewall ruleset via netlink. It holds the
// netlink connection, handles to the table, chains and sets that make up
// the "csm" ruleset, the location of the persisted firewall state, and the
// callbacks the daemon injects after construction.
type Engine struct {
// mu serialises access to the engine; the setters and Apply in this file
// all take it before touching any field.
mu sync.Mutex
// conn is the netlink connection obtained from nftables.New.
conn *nftables.Conn
// cfg is the firewall configuration passed to the constructor.
cfg *FirewallConfig
// dryRunRecorder is called by BlockIP when auto_response.dry_run is
// active. Set by SetDryRunRecorder after construction so the firewall
// package does not import internal/store (which would be a cycle).
dryRunRecorder func(ip, reason string, timeout time.Duration)
// dryRunEnabled reports whether auto_response.dry_run is active.
// Set by the daemon so this package does not import internal/config.
dryRunEnabled func() bool
// verdictAsker, when set, is consulted after local block validation and
// before the dry-run gate. The callback returns (verdict, tenantID, note,
// error). Verdict "allow" short-circuits the block; "block" or empty
// proceeds with the default flow.
// Errors are fail-open: the daemon proceeds with the default block and
// logs the failure. Daemon owns the underlying verdict.Client; this
// package stays free of the internal/verdict import.
verdictAsker func(ctx context.Context, ip, reason string) (verdict, tenantID, note string, err error)
// shutdownCtx, when set, scopes the lifetime of any in-flight verdict
// callback to daemon shutdown. Without it, BlockIPOutcome used
// context.Background() and a wedged panel callback kept the
// auto-block caller waiting for the full http.Client.Timeout during
// graceful restart. Nil falls back to context.Background() so unit
// tests that build the Engine literal without a daemon keep working.
shutdownCtx context.Context
// table and the chain handles describe the live "csm" inet table.
table *nftables.Table
chainIn *nftables.Chain
chainOut *nftables.Chain
// IPv4 sets: blocked single IPs, blocked CIDR ranges, allowed IPs,
// infra IPs, and country-blocked ranges (nil unless country blocking
// is configured).
setBlocked *nftables.Set
setBlockedNet *nftables.Set
setAllowed *nftables.Set
setInfra *nftables.Set
setCountry *nftables.Set
// Cloudflare IP whitelist sets (interval for CIDR matching)
setCFWhitelist *nftables.Set // IPv4
setCFWhitelist6 *nftables.Set // IPv6
// IPv6 sets (nil if IPv6 disabled)
setBlocked6 *nftables.Set
setBlockedNet6 *nftables.Set
setAllowed6 *nftables.Set
setInfra6 *nftables.Set
setCountry6 *nftables.Set
// Meters for per-IP rate limiting. The port-flood maps are keyed by
// meter set name.
meterSYN *nftables.Set
meterConn *nftables.Set
meterUDP *nftables.Set
meterConnlim *nftables.Set
meterPortFlood4 map[string]*nftables.Set
meterPortFlood6 map[string]*nftables.Set
// statePath is the "firewall" directory below the state path given to
// the constructor.
statePath string
// Cached parsed firewall state. Population is lazy: the first call
// to loadStateFile under e.mu reads + parses state.json once, then
// subsequent calls return a deep copy from the in-memory cache as
// long as the on-disk metadata key is unchanged.
//
// Before this cache existed, every single mutator (BlockIP,
// AllowIP, saveBlockedEntry, etc.) reloaded and re-parsed the full
// 325 KiB state.json from disk, and every IsBlocked / IsAllowed
// call did a linear scan over the parsed slices. On a busy
// production host that meant ~72 state.json opens per second
// steady state, which showed up as the dominant CPU hot spot in
// roadmap audit 7.1.
//
// The cache, its key and the three index maps are written only while
// e.mu is held. The index maps are rebuilt every time stateCache is
// repopulated so O(1) lookups stay coherent with the cached slices.
// blockedIPIndex stores the slice position for each canonical IP.
stateCache *FirewallState
stateCacheKey stateFileCacheKey
blockedIPIndex map[string]int
allowedIPIndex map[string]struct{}
blockedCIDRIndex map[string]struct{}
// liveBlockLookup reports whether key is present in set.
// NOTE(review): presumably a test seam like liveBlockCounts below, with
// nil meaning "query nftables" -- confirm against its call sites.
liveBlockLookup func(set *nftables.Set, key []byte) (bool, error)
// liveBlockCounts, when non-nil, returns the live (perm, temp)
// element counts across blocked v4 + v6 sets. Tests inject this
// to avoid spinning up a real nft connection. Nil falls back to
// GetSetElements queries.
liveBlockCounts func() (perm, temp int, err error)
// infraResolved maps a hostname declared under cfg.InfraIPs to its
// last successfully-resolved set of IPs. blockIPTarget refuses to
// block any of these so a transient DNS pause cannot let an
// attacker block CSM's own panel hostname into a lockout. Mutated
// under e.mu by UpdateInfraResolved / DropInfraResolved from the
// DynDNS resolver.
infraResolved map[string]map[string]struct{}
// localAddrs caches the host's own non-loopback interface addresses.
// The block guard refuses to block any of these regardless of
// cfg.InfraIPs so a misconfigured config or a stray scan that loops
// back to the daemon cannot brick the host. Refreshed lazily under
// e.mu when localAddrsExpiresAt has elapsed.
localAddrs map[string]struct{}
localAddrsExpiresAt time.Time
// localAddrsLookup, when non-nil, replaces net.InterfaceAddrs() for
// tests. Returning an error leaves the cache untouched.
localAddrsLookup func() ([]string, error)
}
// stateFileCacheKey identifies one on-disk version of the state file by its
// metadata. Two stats of the file that yield equal keys are treated as the
// same content, so the parsed cache can be reused without re-reading it.
type stateFileCacheKey struct {
	modTime    time.Time // modification time reported by os.FileInfo
	changeTime time.Time // inode change time (ctime); zero when unavailable
	size       int64     // file size in bytes
	dev        uint64    // device number; zero when unavailable
	ino        uint64    // inode number; zero when unavailable
}

// stateFileCacheKeyFromInfo builds the cache key for info. Size and
// modification time always come from info itself; ctime, device and inode
// are filled in only when info.Sys() is a *syscall.Stat_t and are left zero
// otherwise.
func stateFileCacheKeyFromInfo(info os.FileInfo) stateFileCacheKey {
	key := stateFileCacheKey{
		modTime: info.ModTime(),
		size:    info.Size(),
	}
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		// Timespec.Unix widens Sec/Nsec to int64, so this also compiles on
		// 32-bit Linux targets where the raw fields are int32.
		key.changeTime = time.Unix(stat.Ctim.Unix())
		key.dev = uint64(stat.Dev)
		key.ino = uint64(stat.Ino)
	}
	return key
}

// matches reports whether info describes the same file version as k: every
// component of the key derived from info must be equal. Times are compared
// with Equal so differing monotonic readings or locations do not matter.
func (k stateFileCacheKey) matches(info os.FileInfo) bool {
	other := stateFileCacheKeyFromInfo(info)
	return k.size == other.size &&
		k.dev == other.dev &&
		k.ino == other.ino &&
		k.modTime.Equal(other.modTime) &&
		k.changeTime.Equal(other.changeTime)
}
// BlockedEntry represents a blocked IP with metadata. It is the element
// type of FirewallState.Blocked and is serialised to JSON under the tags
// below.
type BlockedEntry struct {
IP string `json:"ip"`
// Reason is the free-text justification recorded with the block.
Reason string `json:"reason"`
// Source names what created the block; omitted from JSON when empty.
Source string `json:"source,omitempty"`
BlockedAt time.Time `json:"blocked_at"`
ExpiresAt time.Time `json:"expires_at"` // zero = permanent
}
// AllowedEntry represents an allowed IP with metadata. It is the element
// type of FirewallState.Allowed.
type AllowedEntry struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// Source names what created the allow entry; omitted from JSON when empty.
Source string `json:"source,omitempty"`
Port int `json:"port,omitempty"` // 0 = all ports
// NOTE(review): encoding/json's omitempty never treats a struct as empty,
// so a zero ExpiresAt is presumably still written out -- confirm whether
// that matters to readers of the state file.
ExpiresAt time.Time `json:"expires_at,omitempty"` // zero = permanent
}
// SubnetEntry represents a blocked CIDR range. It is the element type of
// FirewallState.BlockedNet.
type SubnetEntry struct {
CIDR string `json:"cidr"`
Reason string `json:"reason"`
// Source names what created the block; omitted from JSON when empty.
Source string `json:"source,omitempty"`
BlockedAt time.Time `json:"blocked_at"`
// NOTE(review): omitempty has no effect on a time.Time struct value, so a
// zero ExpiresAt is presumably still serialised -- confirm.
ExpiresAt time.Time `json:"expires_at,omitempty"`
}
// PortAllowEntry represents a port-specific IP allow (e.g. tcp|in|d=PORT|s=IP).
// It is the element type of FirewallState.PortAllowed; the input chain
// turns each entry into an accept rule matching source IP, protocol and
// destination port.
type PortAllowEntry struct {
IP string `json:"ip"`
Port int `json:"port"`
Proto string `json:"proto"` // "tcp" or "udp"
Reason string `json:"reason"`
// Source names what created the entry; omitted from JSON when empty.
Source string `json:"source,omitempty"`
}
// FirewallState is persisted to disk for restore on restart. It groups the
// four kinds of operator- or daemon-created entries: blocked IPs, blocked
// CIDR ranges, allowed IPs, and port-specific allows.
type FirewallState struct {
Blocked []BlockedEntry `json:"blocked"`
BlockedNet []SubnetEntry `json:"blocked_nets"`
Allowed []AllowedEntry `json:"allowed"`
PortAllowed []PortAllowEntry `json:"port_allowed"`
}
// portU16 narrows an operator-supplied int port to the uint16 that nftables
// works with. Anything outside 0..65535 maps to 0: port 0 is not routable
// for TCP/UDP, so a misconfigured rule matches no traffic (fails closed)
// instead of wrapping around to some valid but unintended port.
func portU16(p int) uint16 {
	if p >= 0 && p <= 65535 {
		// #nosec G115 -- range checked on the line above.
		return uint16(p)
	}
	return 0
}
// NewEngine creates a new nftables firewall engine.
//
// It opens a netlink connection and prepares the "firewall" directory below
// statePath for persisted state. No ruleset is installed until Apply is
// called. An error is returned only when the netlink connection cannot be
// opened; failure to create the state directory is reported on stderr and
// is otherwise non-fatal, so the engine is still returned.
func NewEngine(cfg *FirewallConfig, statePath string) (*Engine, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}
	e := &Engine{
		conn:      conn,
		cfg:       cfg,
		statePath: filepath.Join(statePath, "firewall"),
	}
	// Best effort: construction must not fail over a directory problem, but
	// the operator needs to see it instead of it being dropped silently.
	if err := os.MkdirAll(e.statePath, 0700); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: warning creating state dir %s: %v\n", e.statePath, err)
	}
	return e, nil
}
// ConnectExisting attaches to a CSM firewall that is already loaded in the
// kernel, so CLI commands can edit the live ruleset without re-applying it.
//
// The inet table "csm" and the four core sets (blocked_ips, blocked_nets,
// allowed_ips, infra_ips) must exist; a missing one is an error. The
// Cloudflare whitelist sets and the IPv6 sets are optional and their
// handles stay nil when the lookup fails.
func ConnectExisting(cfg *FirewallConfig, statePath string) (*Engine, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}

	// Locate the existing CSM table.
	tables, err := conn.ListTables()
	if err != nil {
		return nil, fmt.Errorf("listing tables: %w", err)
	}
	var table *nftables.Table
	for i := range tables {
		if tables[i].Name == "csm" && tables[i].Family == nftables.TableFamilyINet {
			table = tables[i]
			break
		}
	}
	if table == nil {
		return nil, fmt.Errorf("CSM firewall not running (table 'csm' not found) - run 'csm firewall restart' first")
	}

	e := &Engine{
		conn:      conn,
		cfg:       cfg,
		table:     table,
		statePath: filepath.Join(statePath, "firewall"),
	}

	type setSlot struct {
		name string
		dst  **nftables.Set
	}

	// Core sets: every one of them has to be present.
	required := []setSlot{
		{"blocked_ips", &e.setBlocked},
		{"blocked_nets", &e.setBlockedNet},
		{"allowed_ips", &e.setAllowed},
		{"infra_ips", &e.setInfra},
	}
	for _, slot := range required {
		set, err := conn.GetSetByName(table, slot.name)
		if err != nil {
			return nil, fmt.Errorf("%s set not found: %w", slot.name, err)
		}
		*slot.dst = set
	}

	// Optional sets: Cloudflare whitelist first, then the IPv6 sets, which
	// may not exist when IPv6 is disabled.
	optional := []setSlot{
		{"cf_whitelist", &e.setCFWhitelist},
		{"cf_whitelist6", &e.setCFWhitelist6},
		{"blocked_ips6", &e.setBlocked6},
		{"blocked_nets6", &e.setBlockedNet6},
		{"allowed_ips6", &e.setAllowed6},
		{"infra_ips6", &e.setInfra6},
	}
	for _, slot := range optional {
		if set, err := conn.GetSetByName(table, slot.name); err == nil {
			*slot.dst = set
		}
	}
	return e, nil
}
// SetDryRunRecorder registers the callback BlockIP invokes while
// auto_response.dry_run is active. The daemon wires
// store.RecordDryRunBlock in through this setter after construction, which
// keeps internal/firewall from importing internal/store and forming an
// import cycle.
func (e *Engine) SetDryRunRecorder(fn func(ip, reason string, timeout time.Duration)) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.dryRunRecorder = fn
}
// SetDryRunEnabledFunc registers the predicate BlockIP consults to decide
// whether auto_response.dry_run should intercept an automatic block. A nil
// predicate means blocks are applied live.
func (e *Engine) SetDryRunEnabledFunc(fn func() bool) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.dryRunEnabled = fn
}
// SetVerdictAsker registers the verdict callback built by the daemon at
// startup. Passing nil switches the verdict gate off entirely.
func (e *Engine) SetVerdictAsker(fn func(ctx context.Context, ip, reason string) (string, string, string, error)) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.verdictAsker = fn
}
// verdictAskerFn returns the currently installed verdict callback (possibly
// nil), read under e.mu so it is safe against a concurrent SetVerdictAsker.
func (e *Engine) verdictAskerFn() func(ctx context.Context, ip, reason string) (string, string, string, error) {
	e.mu.Lock()
	defer e.mu.Unlock()
	return e.verdictAsker
}
// SetShutdownContext registers the context whose cancellation aborts any
// verdict callback that is still in flight. The daemon ties it to its
// stopCh so graceful shutdown need not wait for an unresponsive panel
// callback to return.
func (e *Engine) SetShutdownContext(ctx context.Context) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.shutdownCtx = ctx
}
// verdictContext returns the context verdict callbacks should run under:
// the installed shutdown context when there is one, context.Background()
// otherwise. The field is read under e.mu.
func (e *Engine) verdictContext() context.Context {
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.shutdownCtx != nil {
		return e.shutdownCtx
	}
	return context.Background()
}
// Apply builds and atomically applies the complete nftables ruleset.
// All operations (delete old table + create new table/rules +
// populate persisted block/allow entries) are batched into a single
// netlink transaction. If the flush fails, the kernel keeps whatever
// ruleset was running before - the server is never left without a
// firewall. Equally important: the new ruleset never appears with
// EMPTY blocked sets between the table-swap and the persisted-state
// load; an attacker IP from state.json is blocked from the moment
// the new table becomes the live one.
//
// Apply holds e.mu for its whole duration. It returns a wrapped error when
// set creation, chain creation, queueing of the persisted state, or the
// final Flush fails; nil means the Flush succeeded.
func (e *Engine) Apply() error {
e.mu.Lock()
defer e.mu.Unlock()
// Compute the elements to seed each set from persisted state
// BEFORE touching nft. Pure computation; if state.json is
// missing or malformed the slices stay empty and the new table
// still applies (no firewall regression on a fresh install).
initial := e.computeInitialBlockStateLocked()
// Check if existing CSM table needs replacing.
// If so, include the delete in the same atomic batch as the new table.
// NOTE(review): the ListTables error is discarded, so when the listing
// fails no delete is queued and the batch below is built as if no "csm"
// table existed -- confirm this is intended.
tables, _ := e.conn.ListTables()
for _, t := range tables {
if t.Name == "csm" && t.Family == nftables.TableFamilyINet {
e.conn.DelTable(t)
break
}
}
// Create table - all operations below are batched, nothing is sent until Flush()
e.table = e.conn.AddTable(&nftables.Table{
Name: "csm",
Family: nftables.TableFamilyINet,
})
// Create IP sets
if err := e.createSets(); err != nil {
return fmt.Errorf("creating sets: %w", err)
}
// Create chains and rules
if err := e.createInputChain(); err != nil {
return fmt.Errorf("creating input chain: %w", err)
}
if err := e.createOutputChain(); err != nil {
return fmt.Errorf("creating output chain: %w", err)
}
// Queue initial set elements from persisted state into the same
// netlink batch as the table+set+chain creation above. Without
// this, Apply previously Flushed an empty-set ruleset, then a
// separate loadState() Flush populated the sets - leaving a
// brief window where the new table existed without the
// persisted blocks.
if err := e.queueInitialBlockStateLocked(initial); err != nil {
return fmt.Errorf("queueing initial firewall state: %w", err)
}
// Apply atomically - if this fails, nftables keeps whatever was running before
if err := e.conn.Flush(); err != nil {
return fmt.Errorf("applying ruleset: %w", err)
}
return nil
}
// createSets queues creation of every nftables named set the ruleset uses
// and stores the handles on the engine: the blocked/allowed/infra sets, the
// Cloudflare whitelist sets, the optional country sets, the IPv6
// counterparts (only when cfg.IPv6 is set) and the per-IP meter sets for
// the enabled rate limits.
//
// Handles of optional sets and meters are cleared first, so a feature that
// was enabled on a previous Apply but is disabled now does not leave a
// stale handle pointing at a set the new table will not contain.
//
// A failure to queue a core, IPv6, whitelist or meter set is returned as an
// error. A failure to queue a country set is only logged and that handle is
// left nil.
func (e *Engine) createSets() error {
	// Reset every optional handle; the unconditional sets are overwritten
	// below anyway.
	e.setCountry, e.setCountry6 = nil, nil
	e.setBlocked6, e.setBlockedNet6, e.setAllowed6, e.setInfra6 = nil, nil, nil, nil
	e.meterSYN, e.meterConn, e.meterUDP, e.meterConnlim = nil, nil, nil, nil
	e.meterPortFlood4, e.meterPortFlood6 = nil, nil

	// addMeter queues one dynamic per-IP meter set. A zero timeout creates
	// the set without element timeouts. Errors are returned rather than
	// dropped so a rule never references a meter that was not queued.
	addMeter := func(name string, keyType nftables.SetDatatype, timeout time.Duration) (*nftables.Set, error) {
		set := &nftables.Set{
			Table: e.table, Name: name, KeyType: keyType,
			Dynamic: true, HasTimeout: timeout > 0, Timeout: timeout,
		}
		if err := e.conn.AddSet(set, nil); err != nil {
			return nil, fmt.Errorf("%s set: %w", name, err)
		}
		return set, nil
	}

	// Blocked IPs set (with per-element timeout)
	e.setBlocked = &nftables.Set{
		Table:      e.table,
		Name:       "blocked_ips",
		KeyType:    nftables.TypeIPAddr,
		HasTimeout: true,
		Timeout:    24 * time.Hour,
	}
	if err := e.conn.AddSet(e.setBlocked, nil); err != nil {
		return fmt.Errorf("blocked set: %w", err)
	}

	// Blocked subnets set (interval for CIDR ranges, permanent)
	e.setBlockedNet = &nftables.Set{
		Table:    e.table,
		Name:     "blocked_nets",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setBlockedNet, nil); err != nil {
		return fmt.Errorf("blocked nets set: %w", err)
	}

	// Allowed IPs set
	e.setAllowed = &nftables.Set{
		Table:   e.table,
		Name:    "allowed_ips",
		KeyType: nftables.TypeIPAddr,
	}
	if err := e.conn.AddSet(e.setAllowed, nil); err != nil {
		return fmt.Errorf("allowed set: %w", err)
	}

	// Infra IPs set (interval for CIDR support) - split IPv4 and IPv6.
	// Each cfg.InfraIPs entry is either a CIDR or a single IP; entries that
	// parse as neither are skipped, and IPv6 entries are skipped when IPv6
	// is disabled.
	e.setInfra = &nftables.Set{
		Table:    e.table,
		Name:     "infra_ips",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	var infraElements []nftables.SetElement
	var infra6Elements []nftables.SetElement
	for _, cidr := range e.cfg.InfraIPs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			ip := net.ParseIP(cidr)
			if ip == nil {
				continue
			}
			if ip4 := ip.To4(); ip4 != nil {
				infraElements = appendIntervalSetElements(infraElements, ip4, ip4)
			} else if e.cfg.IPv6 {
				ip16 := ip.To16()
				infra6Elements = appendIntervalSetElements(infra6Elements, ip16, ip16)
			}
			continue
		}
		if network.IP.To4() != nil {
			start := network.IP.To4()
			end := lastIPInRange(network)
			if start != nil && end != nil {
				infraElements = appendIntervalSetElements(infraElements, start, end)
			}
		} else if e.cfg.IPv6 {
			start := network.IP.To16()
			end := lastIPInRange(network)
			if start != nil && end != nil {
				infra6Elements = appendIntervalSetElements(infra6Elements, start, end)
			}
		}
	}
	if err := e.conn.AddSet(e.setInfra, infraElements); err != nil {
		return fmt.Errorf("infra set: %w", err)
	}

	// Cloudflare IP whitelist sets (interval for CIDR ranges, accept on 80/443 only)
	e.setCFWhitelist = &nftables.Set{
		Table:    e.table,
		Name:     "cf_whitelist",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setCFWhitelist, nil); err != nil {
		return fmt.Errorf("cf_whitelist set: %w", err)
	}
	e.setCFWhitelist6 = &nftables.Set{
		Table:    e.table,
		Name:     "cf_whitelist6",
		KeyType:  nftables.TypeIP6Addr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setCFWhitelist6, nil); err != nil {
		return fmt.Errorf("cf_whitelist6 set: %w", err)
	}

	// Country-blocked IPs set (interval for CIDR ranges). Failure here is
	// non-fatal: the handle is cleared and the rest of the ruleset proceeds.
	if len(e.cfg.CountryBlock) > 0 && e.cfg.CountryDBPath != "" {
		e.setCountry = &nftables.Set{
			Table:    e.table,
			Name:     "country_blocked",
			KeyType:  nftables.TypeIPAddr,
			Interval: true,
		}
		var countryElements []nftables.SetElement
		for _, code := range e.cfg.CountryBlock {
			countryElements = append(countryElements, loadCountryCIDRs(e.cfg.CountryDBPath, code)...)
		}
		if err := e.conn.AddSet(e.setCountry, countryElements); err != nil {
			fmt.Fprintf(os.Stderr, "firewall: warning creating country set: %v\n", err)
			e.setCountry = nil
		} else if len(countryElements) > 0 {
			fmt.Fprintf(os.Stderr, "firewall: loaded %d country block ranges for %v\n",
				len(countryElements)/2, e.cfg.CountryBlock)
		}
	}

	// IPv6 sets
	if e.cfg.IPv6 {
		e.setBlocked6 = &nftables.Set{
			Table: e.table, Name: "blocked_ips6",
			KeyType: nftables.TypeIP6Addr, HasTimeout: true, Timeout: 24 * time.Hour,
		}
		if err := e.conn.AddSet(e.setBlocked6, nil); err != nil {
			return fmt.Errorf("blocked6 set: %w", err)
		}
		e.setBlockedNet6 = &nftables.Set{
			Table: e.table, Name: "blocked_nets6",
			KeyType: nftables.TypeIP6Addr, Interval: true,
		}
		if err := e.conn.AddSet(e.setBlockedNet6, nil); err != nil {
			return fmt.Errorf("blocked_nets6 set: %w", err)
		}
		e.setAllowed6 = &nftables.Set{
			Table: e.table, Name: "allowed_ips6",
			KeyType: nftables.TypeIP6Addr,
		}
		if err := e.conn.AddSet(e.setAllowed6, nil); err != nil {
			return fmt.Errorf("allowed6 set: %w", err)
		}
		e.setInfra6 = &nftables.Set{
			Table: e.table, Name: "infra_ips6",
			KeyType: nftables.TypeIP6Addr, Interval: true,
		}
		if err := e.conn.AddSet(e.setInfra6, infra6Elements); err != nil {
			return fmt.Errorf("infra6 set: %w", err)
		}

		// IPv6 country-block set, mirroring the IPv4 set above. Without it an
		// attacker on an IPv6 address from a blocked country was never dropped.
		if len(e.cfg.CountryBlock) > 0 && e.cfg.CountryDBPath != "" {
			e.setCountry6 = &nftables.Set{
				Table: e.table, Name: "country_blocked6",
				KeyType: nftables.TypeIP6Addr, Interval: true,
			}
			var country6Elements []nftables.SetElement
			for _, code := range e.cfg.CountryBlock {
				country6Elements = append(country6Elements, loadCountryCIDRs6(e.cfg.CountryDBPath, code)...)
			}
			if err := e.conn.AddSet(e.setCountry6, country6Elements); err != nil {
				fmt.Fprintf(os.Stderr, "firewall: warning creating IPv6 country set: %v\n", err)
				e.setCountry6 = nil
			} else if len(country6Elements) > 0 {
				fmt.Fprintf(os.Stderr, "firewall: loaded %d IPv6 country block ranges for %v\n",
					len(country6Elements)/2, e.cfg.CountryBlock)
			}
		}
	}

	// Meter sets for per-IP rate limiting (dynamic sets)
	if e.cfg.SYNFloodProtection {
		set, err := addMeter("meter_syn", nftables.TypeIPAddr, time.Minute)
		if err != nil {
			return err
		}
		e.meterSYN = set
	}
	if e.cfg.ConnRateLimit > 0 {
		set, err := addMeter("meter_conn", nftables.TypeIPAddr, time.Minute)
		if err != nil {
			return err
		}
		e.meterConn = set
	}
	if e.cfg.UDPFlood && e.cfg.UDPFloodRate > 0 {
		set, err := addMeter("meter_udp", nftables.TypeIPAddr, time.Minute)
		if err != nil {
			return err
		}
		e.meterUDP = set
	}
	if e.cfg.ConnLimit > 0 {
		// The connection-limit meter is created without a timeout.
		set, err := addMeter("meter_connlimit", nftables.TypeIPAddr, 0)
		if err != nil {
			return err
		}
		e.meterConnlim = set
	}

	// Port-flood meters: one set per distinct meter name and address
	// family, created only for rules that are usable.
	if portFloodNeedsMeter(e.cfg.PortFlood) {
		e.meterPortFlood4 = make(map[string]*nftables.Set)
		if e.cfg.IPv6 {
			e.meterPortFlood6 = make(map[string]*nftables.Set)
		}
		for _, pf := range e.cfg.PortFlood {
			if !usablePortFloodRule(pf) {
				continue
			}
			name4 := portFloodMeterName(pf, portFloodIPv4)
			if _, ok := e.meterPortFlood4[name4]; !ok {
				set, err := addMeter(name4, nftables.TypeIPAddr, time.Minute)
				if err != nil {
					return err
				}
				e.meterPortFlood4[name4] = set
			}
			if e.cfg.IPv6 {
				name6 := portFloodMeterName(pf, portFloodIPv6)
				if _, ok := e.meterPortFlood6[name6]; !ok {
					set, err := addMeter(name6, nftables.TypeIP6Addr, time.Minute)
					if err != nil {
						return err
					}
					e.meterPortFlood6[name6] = set
				}
			}
		}
	}
	return nil
}
// portFloodNeedsMeter tells whether at least one port_flood rule is usable,
// so that meter sets are only created when some rule will reference them.
// An empty or nil rule list yields false.
func portFloodNeedsMeter(rules []PortFloodRule) bool {
	for i := range rules {
		if usablePortFloodRule(rules[i]) {
			return true
		}
	}
	return false
}
// usablePortFloodRule reports whether pf is complete enough to enforce:
// its hit count, time window and port must all be positive.
func usablePortFloodRule(pf PortFloodRule) bool {
	if pf.Hits <= 0 || pf.Seconds <= 0 || pf.Port <= 0 {
		return false
	}
	return true
}
// createInputChain builds the input filter chain with proper rule ordering.
func (e *Engine) createInputChain() error {
policy := nftables.ChainPolicyDrop
e.chainIn = e.conn.AddChain(&nftables.Chain{
Name: "input",
Table: e.table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
Policy: &policy,
})
// Rule 1: Allow loopback
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte("lo\x00"),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
// Rule 2: Drop INVALID conntrack state (malformed packets)
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
&expr.Bitwise{
SourceRegister: 1, DestRegister: 1, Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitINVALID),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
// Rule 3: Allow infra IPs FIRST - infra must NEVER be blocked, even accidentally
e.addSetMatchRule(e.setInfra, expr.VerdictAccept)
e.addSetMatchRuleV6(e.setInfra6, expr.VerdictAccept)
// Rule 4: Cloudflare IP whitelist - accept on TCP 80/443 only.
// CF IPs can still be blocked on other ports (unlike infra).
e.addCFWhitelistRule(e.setCFWhitelist, false)
e.addCFWhitelistRule(e.setCFWhitelist6, true)
// Rule 5: Drop blocked IPs before established/related so active
// keep-alive connections do not bypass a new block.
e.addSetMatchRule(e.setBlocked, expr.VerdictDrop)
e.addSetMatchRuleV6(e.setBlocked6, expr.VerdictDrop)
// Rule 6: Drop blocked subnets (interval set for CIDR ranges)
e.addSetMatchRule(e.setBlockedNet, expr.VerdictDrop)
e.addSetMatchRuleV6(e.setBlockedNet6, expr.VerdictDrop)
// Rule 7: Allow established/related connections after block checks.
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
// Rule 8: Allow explicitly allowed IPs
e.addSetMatchRule(e.setAllowed, expr.VerdictAccept)
e.addSetMatchRuleV6(e.setAllowed6, expr.VerdictAccept)
// Rule 9: Port-specific allows (IP+port, e.g. MySQL access for specific IPs)
state := e.loadStateFile()
for _, pa := range state.PortAllowed {
parsed := net.ParseIP(pa.IP)
if parsed == nil {
continue
}
proto := byte(6) // TCP
if pa.Proto == "udp" {
proto = 17
}
if ip4 := parsed.To4(); ip4 != nil {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ip4},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pa.Port))},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
} else if e.cfg.IPv6 {
ip16 := parsed.To16()
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{10}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ip16},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pa.Port))},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
}
}
// ICMPv4 echo-request (type 8)
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{1}}, // ICMPv4
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{8}}, // echo-request
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
// ICMPv6 echo-request (type 128) + neighbor discovery (types 133-137, required for IPv6)
if e.cfg.IPv6 {
for _, icmp6Type := range []byte{128, 133, 134, 135, 136, 137} {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{58}}, // ICMPv6
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmp6Type}},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
}
}
// Per-IP SYN flood protection via meter
if e.cfg.SYNFloodProtection && e.meterSYN != nil {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 13, Len: 1},
&expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 1, Mask: []byte{0x12}, Xor: []byte{0x00}},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{0x02}}, // SYN only
// Load source IP for per-IP metering
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
&expr.Dynset{
SrcRegKey: 1,
SetName: e.meterSYN.Name,
SetID: e.meterSYN.ID,
Operation: 1, // NFT_DYNSET_OP_UPDATE
Exprs: []expr.Any{
&expr.Limit{Type: expr.LimitTypePkts, Rate: 25, Unit: expr.LimitTimeSecond, Burst: 100, Over: true},
},
},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
// Per-IP new connection rate limit via meter
if e.cfg.ConnRateLimit > 0 && e.meterConn != nil {
// #nosec G115 -- ConnRateLimit is an operator-configured int (typical 10–1000);
// /2 is non-negative and well below uint32 max.
burst := uint32(e.cfg.ConnRateLimit / 2)
if burst < 5 {
burst = 5
}
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
&expr.Bitwise{
SourceRegister: 1, DestRegister: 1, Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
// Load source IP for per-IP metering
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
&expr.Dynset{
SrcRegKey: 1,
SetName: e.meterConn.Name,
SetID: e.meterConn.ID,
Operation: 1,
Exprs: []expr.Any{
&expr.Limit{Type: expr.LimitTypePkts, Rate: uint64(e.cfg.ConnRateLimit), Unit: expr.LimitTimeMinute, Burst: burst, Over: true},
},
},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
// Per-IP concurrent connection limit (CONNLIMIT)
if e.cfg.ConnLimit > 0 && e.meterConnlim != nil {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
&expr.Bitwise{
SourceRegister: 1, DestRegister: 1, Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
&expr.Dynset{
SrcRegKey: 1,
SetName: e.meterConnlim.Name,
SetID: e.meterConnlim.ID,
Operation: 1,
Exprs: []expr.Any{
// #nosec G115 -- ConnLimit is operator-configured non-negative int; fits in uint32.
&expr.Connlimit{Count: uint32(e.cfg.ConnLimit), Flags: 1}, // 1 = over
},
},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
// Rule 9: Country block - drop traffic from blocked countries
if e.setCountry != nil {
e.addSetMatchRule(e.setCountry, expr.VerdictDrop)
}
if e.setCountry6 != nil {
e.addSetMatchRuleV6(e.setCountry6, expr.VerdictDrop)
}
// Per-port flood protection - rate-limit new connections per source IP.
// Each rule has separate IPv4 and IPv6 meters so ports and families do not
// consume each other's token buckets.
for _, pf := range e.cfg.PortFlood {
items := []struct {
family portFloodIPFamily
meter *nftables.Set
}{
{family: portFloodIPv4, meter: e.meterPortFlood4[portFloodMeterName(pf, portFloodIPv4)]},
}
if e.cfg.IPv6 {
items = append(items, struct {
family portFloodIPFamily
meter *nftables.Set
}{family: portFloodIPv6, meter: e.meterPortFlood6[portFloodMeterName(pf, portFloodIPv6)]})
}
for _, item := range items {
exprs := buildPortFloodExprs(pf, item.meter, item.family)
if exprs == nil {
continue
}
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: exprs,
})
}
}
// Per-IP UDP flood protection via meter
if e.cfg.UDPFlood && e.cfg.UDPFloodRate > 0 && e.meterUDP != nil {
// #nosec G115 -- UDPFloodBurst is operator-configured non-negative int.
burst := uint32(e.cfg.UDPFloodBurst)
if burst < 10 {
burst = 10
}
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{17}}, // UDP
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
&expr.Dynset{
SrcRegKey: 1,
SetName: e.meterUDP.Name,
SetID: e.meterUDP.ID,
Operation: 1,
Exprs: []expr.Any{
&expr.Limit{Type: expr.LimitTypePkts, Rate: uint64(e.cfg.UDPFloodRate), Unit: expr.LimitTimeSecond, Burst: burst, Over: true},
},
},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
// Build restricted port set - these are only reachable via infra IPs (rule 4)
restricted := make(map[int]bool)
for _, p := range e.cfg.RestrictedTCP {
restricted[p] = true
}
// Open TCP ports (public) - restricted ports excluded
for _, port := range e.cfg.TCPIn {
if restricted[port] {
continue
}
e.addPortAcceptRule(port, true)
}
// Open UDP ports (public)
for _, port := range e.cfg.UDPIn {
e.addPortAcceptRule(port, false)
}
// Passive FTP range
if e.cfg.PassiveFTPStart > 0 && e.cfg.PassiveFTPEnd > 0 {
e.addPortRangeAcceptRule(e.cfg.PassiveFTPStart, e.cfg.PassiveFTPEnd, true)
}
// Silent drop for commonly-scanned ports (no logging)
for _, port := range e.cfg.DropNoLog {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{17}}, // UDP
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
// Rate-limited log for remaining dropped packets
if e.cfg.LogDropped {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainIn,
Exprs: []expr.Any{
&expr.Limit{
Type: expr.LimitTypePkts,
Rate: uint64(max(e.cfg.LogRate, 1)),
Unit: expr.LimitTimeMinute,
Burst: 5,
},
&expr.Log{Key: 1, Data: []byte("CSM-DROP: ")},
},
})
}
// Default policy is DROP - anything not matched above is dropped
return nil
}
// createOutputChain builds the "output" filter chain on e.table and stores it
// in e.chainOut.
//
// When neither cfg.TCPOut nor cfg.UDPOut lists a port, the chain is created
// with an ACCEPT policy and no rules. Otherwise the chain policy is DROP and
// rules are added in this order: established/related traffic, loopback,
// per-UID SMTP accept followed by an SMTP drop (only when cfg.SMTPBlock is set
// and cfg.SMTPPorts is non-empty), the configured outbound TCP and UDP ports,
// ICMP echo-reply/echo-request, ICMPv6 echo and neighbor discovery (only when
// cfg.IPv6 is set), and a final TCP reject.
// Restricting outbound to configured ports prevents C2 on non-standard ports.
//
// The function calls AddChain/AddRule on e.conn but never Flush, and every
// path returns nil.
func (e *Engine) createOutputChain() error {
if len(e.cfg.TCPOut) == 0 && len(e.cfg.UDPOut) == 0 {
// No outbound restrictions configured - accept all
policy := nftables.ChainPolicyAccept
e.chainOut = e.conn.AddChain(&nftables.Chain{
Name: "output",
Table: e.table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
Policy: &policy,
})
return nil
}
// Outbound filtering enabled
policy := nftables.ChainPolicyDrop
e.chainOut = e.conn.AddChain(&nftables.Chain{
Name: "output",
Table: e.table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
Policy: &policy,
})
// Allow established/related outbound
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
// Mask the conntrack state down to the ESTABLISHED|RELATED bits; a
// non-zero result means at least one of them is set.
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
// Allow loopback outbound
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
// Compare against the NUL-terminated output interface name "lo".
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte("lo\x00")},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
// SMTP block - restrict outbound mail to allowed users only.
// resolveSMTPAllowedUIDs unconditionally includes root and mailnull
// (exim's queue runner); without mailnull on the allow list queued
// mail is silently dropped while CSM still reports healthy.
//
// smtpBlocked records every port handled here so the generic TCPOut loop
// below does not add an unconditional accept for it.
smtpBlocked := make(map[int]bool)
if e.cfg.SMTPBlock && len(e.cfg.SMTPPorts) > 0 {
allowedUIDs := resolveSMTPAllowedUIDs(e.cfg.SMTPAllowUsers)
for _, port := range e.cfg.SMTPPorts {
smtpBlocked[port] = true
// Accept from each allowed UID
for _, uid := range allowedUIDs {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
// L4 protocol 6 = TCP, then destination port (transport header
// offset 2), then the socket owner's UID.
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(uid)},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
}
// Drop SMTP from everyone else
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
&expr.Verdict{Kind: expr.VerdictDrop},
},
})
}
}
// Allow configured outbound TCP ports (skip SMTP-blocked ports - handled above)
for _, port := range e.cfg.TCPOut {
if smtpBlocked[port] {
continue
}
e.addOutboundPortRule(port, true)
}
// Allow configured outbound UDP ports
for _, port := range e.cfg.UDPOut {
e.addOutboundPortRule(port, false)
}
// Allow only safe ICMP outbound (echo-reply + echo-request, block dest-unreachable)
// Blocking ICMP type 3 (dest-unreachable) prevents leaking closed port info to scanners
for _, icmpType := range []byte{0, 8} { // 0=echo-reply, 8=echo-request
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{1}}, // ICMP
// The ICMP type is the first byte of the transport header.
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmpType}},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
}
// ICMPv6 outbound - allow echo-reply (129) + echo-request (128) + ND (133-137)
if e.cfg.IPv6 {
for _, icmp6Type := range []byte{128, 129, 133, 134, 135, 136, 137} {
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{58}}, // ICMPv6
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmp6Type}},
&expr.Verdict{Kind: expr.VerdictAccept},
},
})
}
}
// REJECT outbound TCP with RST (faster failure than silent DROP)
// UDP still silently drops via chain policy.
// This must stay the last rule: it matches every TCP packet that no earlier
// rule accepted or dropped.
e.conn.AddRule(&nftables.Rule{
Table: e.table,
Chain: e.chainOut,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
&expr.Reject{Type: 1}, // NFT_REJECT_TCP_RST
},
})
return nil
}
// --- Helper methods ---
// resolveIPSet picks the set and raw key bytes that correspond to ip.
//
// An address with a 4-byte form maps to set4 with a 4-byte key. Any other
// valid address maps to set6 with a 16-byte key, or to an error when set6 is
// nil. An unparseable ip is also an error.
func (e *Engine) resolveIPSet(ip string, set4, set6 *nftables.Set) (*nftables.Set, []byte, error) {
	addr := net.ParseIP(ip)
	if addr == nil {
		return nil, nil, fmt.Errorf("invalid IP: %s", ip)
	}

	v4 := addr.To4()
	switch {
	case v4 != nil:
		return set4, v4, nil
	case set6 == nil:
		return nil, nil, fmt.Errorf("IPv6 not enabled in firewall config: %s", ip)
	default:
		return set6, addr.To16(), nil
	}
}
// canonicalFirewallIP parses ip and returns the textual form produced by
// net.IP.String. It returns an error when ip is not a valid textual IP
// address.
func canonicalFirewallIP(ip string) (string, error) {
	if addr := net.ParseIP(ip); addr != nil {
		return addr.String(), nil
	}
	return "", fmt.Errorf("invalid IP: %s", ip)
}
// canonicalIPKey returns the net.IP.String form of ip and true, or "" and
// false when ip does not parse as an IP address.
func canonicalIPKey(ip string) (string, bool) {
	if addr := net.ParseIP(ip); addr != nil {
		return addr.String(), true
	}
	return "", false
}
// stateIPKey returns the canonical form of ip when it parses as an IP
// address, and ip unchanged otherwise.
func stateIPKey(ip string) string {
	canon, valid := canonicalIPKey(ip)
	if !valid {
		return ip
	}
	return canon
}
// sameIPString reports whether a and b denote the same address. Identical
// strings always match, even when they are not valid IPs; otherwise both must
// parse and share the same canonical form.
func sameIPString(a, b string) bool {
	if a == b {
		return true
	}
	keyA, validA := canonicalIPKey(a)
	if !validA {
		return false
	}
	if keyB, validB := canonicalIPKey(b); validB {
		return keyA == keyB
	}
	return false
}
// preferLongerLivedBlock returns whichever of the two entries lasts longer.
// A zero ExpiresAt is treated as never expiring. current wins when it is
// permanent or when candidate does not outlive it.
func preferLongerLivedBlock(current, candidate BlockedEntry) BlockedEntry {
	switch {
	case current.ExpiresAt.IsZero():
		return current
	case candidate.ExpiresAt.IsZero(), candidate.ExpiresAt.After(current.ExpiresAt):
		return candidate
	default:
		return current
	}
}
// preferLongerLivedAllow returns whichever of the two entries lasts longer.
// A zero ExpiresAt is treated as never expiring. current wins when it is
// permanent or when candidate does not outlive it.
func preferLongerLivedAllow(current, candidate AllowedEntry) AllowedEntry {
	switch {
	case current.ExpiresAt.IsZero():
		return current
	case candidate.ExpiresAt.IsZero(), candidate.ExpiresAt.After(current.ExpiresAt):
		return candidate
	default:
		return current
	}
}
// normalizeFirewallStateIPs rewrites the IP of every Blocked, Allowed and
// PortAllowed entry to its canonical form and collapses duplicates, filtering
// each slice in place.
//
// Entries whose IP does not parse are kept untouched and never deduplicated.
// Duplicate blocked entries (same IP) and duplicate allowed entries (same IP
// and Source) are merged with preferLongerLivedBlock / preferLongerLivedAllow
// at the position of the first occurrence. Duplicate port allows (same IP,
// port and proto) keep the first occurrence only.
func normalizeFirewallStateIPs(state *FirewallState) {
	// --- Blocked ---
	blockedIdx := make(map[string]int, len(state.Blocked))
	keptBlocked := state.Blocked[:0]
	for _, b := range state.Blocked {
		canon, valid := canonicalIPKey(b.IP)
		if valid {
			b.IP = canon
			if at, dup := blockedIdx[canon]; dup {
				keptBlocked[at] = preferLongerLivedBlock(keptBlocked[at], b)
				continue
			}
			blockedIdx[canon] = len(keptBlocked)
		}
		keptBlocked = append(keptBlocked, b)
	}
	state.Blocked = keptBlocked

	// --- Allowed ---
	type allowIdent struct {
		ip     string
		source string
	}
	allowedIdx := make(map[allowIdent]int, len(state.Allowed))
	keptAllowed := state.Allowed[:0]
	for _, a := range state.Allowed {
		canon, valid := canonicalIPKey(a.IP)
		if valid {
			a.IP = canon
			ident := allowIdent{ip: canon, source: a.Source}
			if at, dup := allowedIdx[ident]; dup {
				keptAllowed[at] = preferLongerLivedAllow(keptAllowed[at], a)
				continue
			}
			allowedIdx[ident] = len(keptAllowed)
		}
		keptAllowed = append(keptAllowed, a)
	}
	state.Allowed = keptAllowed

	// --- PortAllowed ---
	type portIdent struct {
		ip    string
		port  int
		proto string
	}
	seenPorts := make(map[portIdent]struct{}, len(state.PortAllowed))
	keptPorts := state.PortAllowed[:0]
	for _, p := range state.PortAllowed {
		canon, valid := canonicalIPKey(p.IP)
		if valid {
			p.IP = canon
			ident := portIdent{ip: canon, port: p.Port, proto: p.Proto}
			if _, dup := seenPorts[ident]; dup {
				continue
			}
			seenPorts[ident] = struct{}{}
		}
		keptPorts = append(keptPorts, p)
	}
	state.PortAllowed = keptPorts
}
// addSetMatchRule appends a rule to the input chain that applies verdict to
// IPv4 packets whose source address is a member of set. A nil set adds
// nothing.
func (e *Engine) addSetMatchRule(set *nftables.Set, verdict expr.VerdictKind) {
	const (
		family    = 2  // nfproto value passed for the IPv4 rule
		srcOffset = 12 // source address offset within the network header
		srcLen    = 4  // source address length in bytes
	)
	if match := buildSetMatchRuleExprs(set, verdict, family, srcOffset, srcLen); match != nil {
		e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: match})
	}
}
// addSetMatchRuleV6 appends a rule to the input chain that applies verdict to
// IPv6 packets whose source address is a member of set. A nil set adds
// nothing.
func (e *Engine) addSetMatchRuleV6(set *nftables.Set, verdict expr.VerdictKind) {
	const (
		family    = 10 // nfproto value passed for the IPv6 rule
		srcOffset = 8  // source address offset within the network header
		srcLen    = 16 // source address length in bytes
	)
	if match := buildSetMatchRuleExprs(set, verdict, family, srcOffset, srcLen); match != nil {
		e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: match})
	}
}
// buildSetMatchRuleExprs returns the expression list for a "source address in
// set => verdict" rule: match the netfilter protocol family against nfproto,
// load sourceLen bytes at sourceOffset of the network header, look them up in
// set, and apply verdict. It returns nil when set is nil.
func buildSetMatchRuleExprs(set *nftables.Set, verdict expr.VerdictKind, nfproto byte, sourceOffset, sourceLen uint32) []expr.Any {
	if set == nil {
		return nil
	}

	rule := make([]expr.Any, 0, 5)
	// Restrict the rule to one address family.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{nfproto}},
	)
	// Load the source address and test set membership.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: sourceOffset, Len: sourceLen},
		&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
	)
	return append(rule, &expr.Verdict{Kind: verdict})
}
// addCFWhitelistRule adds accept rules on the input chain for packets whose
// source address is in set and whose TCP destination port is 80 or 443.
// Equivalent to: ip saddr @cf_whitelist tcp dport {80, 443} accept
// (or ip6 saddr ... when ipv6 is true). One rule is added per port. A nil set
// adds nothing.
//
// Both address families are guarded by an explicit NFPROTO match, the same
// way buildSetMatchRuleExprs does it. Without the guard on the IPv4 variant,
// an IPv6 packet would have four bytes from the middle of its source address
// (network header offset 12) looked up in the IPv4 set and could be accepted
// by accident.
func (e *Engine) addCFWhitelistRule(set *nftables.Set, ipv6 bool) {
	if set == nil {
		return
	}

	// Family selector and the location of the source address in the
	// network header.
	nfproto, srcOffset, srcLen := byte(2), uint32(12), uint32(4) // NFPROTO_IPV4
	if ipv6 {
		nfproto, srcOffset, srcLen = 10, 8, 16 // NFPROTO_IPV6
	}

	for _, port := range []uint16{80, 443} {
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{nfproto}},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: srcOffset, Len: srcLen},
				&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
				// Destination port sits at offset 2 of the transport header.
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
				&expr.Verdict{Kind: expr.VerdictAccept},
			},
		})
	}
}
// addPortAcceptRule adds an input-chain rule that accepts packets addressed
// to destination port `port`, for TCP when tcp is true and UDP otherwise.
func (e *Engine) addPortAcceptRule(port int, tcp bool) {
	var l4proto byte = 17 // UDP
	if tcp {
		l4proto = 6 // TCP
	}
	dport := binaryutil.BigEndian.PutUint16(portU16(port))

	match := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{l4proto}},
		// Destination port sits at offset 2 of the transport header.
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: dport},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: match})
}
// addPortRangeAcceptRule adds an input-chain rule that accepts packets whose
// destination port lies in [startPort, endPort], for TCP when tcp is true and
// UDP otherwise. The bounds are not reordered or validated here.
func (e *Engine) addPortRangeAcceptRule(startPort, endPort int, tcp bool) {
	var l4proto byte = 17 // UDP
	if tcp {
		l4proto = 6 // TCP
	}
	low := binaryutil.BigEndian.PutUint16(portU16(startPort))
	high := binaryutil.BigEndian.PutUint16(portU16(endPort))

	match := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{l4proto}},
		// The destination port is loaded once and compared to both bounds.
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpGte, Register: 1, Data: low},
		&expr.Cmp{Op: expr.CmpOpLte, Register: 1, Data: high},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: match})
}
// addOutboundPortRule adds an output-chain rule that accepts packets
// addressed to destination port `port`, for TCP when tcp is true and UDP
// otherwise.
func (e *Engine) addOutboundPortRule(port int, tcp bool) {
	var l4proto byte = 17 // UDP
	if tcp {
		l4proto = 6 // TCP
	}
	dport := binaryutil.BigEndian.PutUint16(portU16(port))

	match := []expr.Any{
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{l4proto}},
		// Destination port sits at offset 2 of the transport header.
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: dport},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
	e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainOut, Exprs: match})
}
// --- Public API ---
// BlockIP adds an IP to the blocked set with an optional timeout; a timeout
// of 0 requests a permanent block.
//
// It delegates to BlockIPOutcome and throws the outcome away, so callers that
// only care about success or failure keep working. Auto-response callers
// should call BlockIPOutcome directly: the outcome tells them whether the
// kernel was actually touched, which they need in order to suppress local
// side effects (state mutation, AUTO-BLOCK alert) when it was not.
func (e *Engine) BlockIP(ip string, reason string, timeout time.Duration) error {
	if _, err := e.BlockIPOutcome(ip, reason, timeout); err != nil {
		return err
	}
	return nil
}
// BlockIPOutcome is the AUTO-RESPONSE entry point. It applies the same
// guards, verdict-callback consultation and dry-run gating as BlockIP, and
// additionally reports which path was taken so the caller can decide whether
// to record local state. See the BlockOutcome godoc for the meaning of each
// value.
//
// Outcomes produced here:
//   - BlockOutcomeNoop with an error: canonicalisation, validation or the
//     live block failed.
//   - BlockOutcomeNoop with nil: the IP was already blocked.
//   - BlockOutcomeAllowed: the verdict callback answered "allow".
//   - BlockOutcomeDryRun: dry-run is enabled; the block was only recorded.
//   - BlockOutcomeLive: the block was applied.
//
// Operator-initiated commands (csm firewall block, Web UI manual block) must
// call BlockIPForce instead, which skips the dry-run gate unconditionally.
func (e *Engine) BlockIPOutcome(ip string, reason string, timeout time.Duration) (BlockOutcome, error) {
	const stampLayout = "2006-01-02 15:04:05"

	normalized, err := canonicalFirewallIP(ip)
	if err != nil {
		return BlockOutcomeNoop, err
	}
	ip = normalized

	// Local safety checks run before the external callback is consulted.
	// The callback may downgrade a block decision, but it cannot bypass the
	// malformed-IP, IPv6-disabled, infra-IP or block-limit guards.
	blocked, err := e.validateBlockIP(ip, timeout, true)
	switch {
	case err != nil:
		return BlockOutcomeNoop, err
	case blocked:
		return BlockOutcomeNoop, nil
	}

	// Verdict gate: sits after local validation and before the dry-run gate
	// so an "allow" verdict short-circuits everything. It fails open: a
	// callback error falls through to the default block. A nil callback
	// skips the gate entirely.
	if ask := e.verdictAskerFn(); ask != nil {
		verdict, tenant, note, askErr := ask(e.verdictContext(), ip, reason)
		if askErr != nil {
			fmt.Fprintf(os.Stderr, "[%s] verdict callback failed for %s: %v - proceeding with default block\n",
				time.Now().Format(stampLayout), ip, askErr)
		} else if verdict == "allow" {
			fmt.Fprintf(os.Stderr, "[%s] verdict callback returned allow for %s (tenant=%q note=%q) - not blocking\n",
				time.Now().Format(stampLayout), ip, tenant, note)
			return BlockOutcomeAllowed, nil
		} else if tenant != "" || note != "" {
			fmt.Fprintf(os.Stderr, "[%s] verdict callback returned block for %s (tenant=%q note=%q) - proceeding with default block\n",
				time.Now().Format(stampLayout), ip, tenant, note)
		}
		// Any other verdict (including "block" or empty) continues below.
	}

	// Dry-run gate: the daemon callback reads the current daemon config at
	// call time, so a SIGHUP takes effect without a restart. A nil callback
	// means live.
	if e.autoResponseDryRunEnabled() {
		fmt.Fprintf(os.Stderr, "[%s] auto_response dry_run: would have blocked %s (%s)\n",
			time.Now().Format(stampLayout), ip, reason)
		e.recordDryRunBlock(ip, reason, timeout)
		return BlockOutcomeDryRun, nil
	}

	if err := e.blockIPLocked(ip, reason, timeout, true); err != nil {
		return BlockOutcomeNoop, err
	}
	return BlockOutcomeLive, nil
}
// autoResponseDryRunEnabled reports whether the installed dry-run callback
// currently returns true. The callback is read under e.mu but invoked after
// the lock is released. With no callback installed the answer is false.
func (e *Engine) autoResponseDryRunEnabled() bool {
	e.mu.Lock()
	enabled := e.dryRunEnabled
	e.mu.Unlock()

	if enabled == nil {
		return false
	}
	return enabled()
}
// BlockIPForce blocks ip without consulting the auto_response.dry_run gate.
// It is meant for operator-initiated commands (CLI, Web UI manual block)
// where the operator has explicitly decided to block.
func (e *Engine) BlockIPForce(ip string, reason string, timeout time.Duration) error {
	// Do not short-circuit when the IP is already recorded as blocked.
	const skipExisting = false
	return e.blockIPLocked(ip, reason, timeout, skipExisting)
}
// PromoteToPermanentBlock upgrades an existing temporary block on ip to a
// permanent one: it clears the kernel timeout by deleting the timed element
// and re-adding it without a timeout, and zeroes ExpiresAt in state. The
// ordinary block path cannot do this during PermBlock escalation because it
// skips an already-blocked IP, so the kernel timeout would otherwise expire
// the block the operator wanted made permanent.
//
// "Currently blocked" is decided from the state file's Blocked list, not from
// the kernel set. Errors are returned when ip does not parse, when no set can
// be resolved for it, when it is not listed as blocked (nothing to promote),
// when promoting a temporary entry would exceed cfg.DenyIPLimit, when the new
// state cannot be saved, or when any kernel step fails. On a kernel failure
// the prior state is restored through restoreBlockStateAfterFailureLocked.
// The whole operation runs under e.mu.
func (e *Engine) PromoteToPermanentBlock(ip, reason string) error {
canonical, err := canonicalFirewallIP(ip)
if err != nil {
return err
}
ip = canonical
e.mu.Lock()
defer e.mu.Unlock()
targetSet, key, err := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
if err != nil {
return err
}
priorState := e.loadStateFile()
// Defaults used only until the existing row is found; BlockedAt and a
// non-empty Source are then carried over from that row.
entry := BlockedEntry{IP: ip, Reason: reason, Source: SourceSystem, BlockedAt: time.Now()}
found := false
wasTemporary := false
for _, b := range priorState.Blocked {
if sameIPString(b.IP, ip) {
found = true
wasTemporary = !b.ExpiresAt.IsZero()
entry.BlockedAt = b.BlockedAt
if b.Source != "" {
entry.Source = b.Source
}
break
}
}
if !found {
return fmt.Errorf("cannot promote %s: not currently blocked", ip)
}
// Only a temporary -> permanent transition adds to the permanent count, so
// an entry that is already permanent skips the limit check.
if e.cfg != nil && wasTemporary && e.cfg.DenyIPLimit > 0 {
perm, _, ok := e.livePermTempCountsLocked(priorState)
if !ok {
// Live counts unavailable: fall back to counting the state file.
perm = countPermanentBlockedEntries(priorState)
}
if perm >= e.cfg.DenyIPLimit {
return fmt.Errorf("permanent deny limit reached (%d)", e.cfg.DenyIPLimit)
}
}
// ExpiresAt left zero: permanent.
// State is saved before the kernel is touched; every kernel failure below
// attempts to roll the state back to priorState.
nextState := copyFirewallState(priorState)
upsertBlockedEntryInState(&nextState, entry)
if err := e.saveState(&nextState); err != nil {
return fmt.Errorf("persisting permanent promotion for %s: %w", ip, err)
}
// Delete the timed element and re-add it without a timeout in one
// transaction, so the address is never unblocked in between.
if err := e.conn.SetDeleteElements(targetSet, []nftables.SetElement{{Key: key}}); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("promoting %s: delete timed element: %w (state restore failed: %v)", ip, err, restoreErr)
}
return fmt.Errorf("promoting %s: delete timed element: %w", ip, err)
}
if err := e.conn.SetAddElements(targetSet, []nftables.SetElement{{Key: key}}); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("promoting %s: re-add permanent element: %w (state restore failed: %v)", ip, err, restoreErr)
}
return fmt.Errorf("promoting %s: re-add permanent element: %w", ip, err)
}
// Single Flush for both the delete and the re-add issued above.
if err := e.conn.Flush(); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("promoting %s: flush: %w (state restore failed: %v)", ip, err, restoreErr)
}
return fmt.Errorf("promoting %s: flush: %w", ip, err)
}
AppendAudit(e.statePath, "permblock", ip, reason, entry.Source, 0)
return nil
}
// recordDryRunBlock hands a dry-run block to the daemon-installed recorder so
// operators can review the count via /api/v1/status. The recorder is read
// under e.mu and invoked after the lock is released. It is a no-op when no
// recorder is installed.
func (e *Engine) recordDryRunBlock(ip, reason string, timeout time.Duration) {
	e.mu.Lock()
	record := e.dryRunRecorder
	e.mu.Unlock()

	if record == nil {
		return
	}
	record(ip, reason, timeout)
}
// blockIPLocked is the real block implementation, called by BlockIPOutcome
// (and therefore BlockIP) and by BlockIPForce.
//
// Despite the "Locked" suffix it acquires e.mu itself and holds it until it
// returns; callers must NOT already hold the lock.
//
// ip is canonicalised first. timeout 0 means a permanent block. With
// skipExisting set, an IP that blockIPTarget reports as already blocked is
// left alone and nil is returned. When blockIPTarget names a temporary block
// to evict, that entry is removed from state and from its kernel set as part
// of the same operation.
//
// Order of effects: save state, add the kernel element, delete the evicted
// element (if any), Flush, then append audit records. Any failure after the
// state save attempts to restore the prior state via
// restoreBlockStateAfterFailureLocked.
func (e *Engine) blockIPLocked(ip string, reason string, timeout time.Duration, skipExisting bool) error {
canonical, err := canonicalFirewallIP(ip)
if err != nil {
return err
}
ip = canonical
e.mu.Lock()
defer e.mu.Unlock()
targetSet, key, alreadyBlocked, evictTempIP, err := e.blockIPTarget(ip, timeout, skipExisting)
if err != nil {
return err
}
if alreadyBlocked {
return nil
}
// Persist to state BEFORE adding the kernel element. state.json is the
// seed source on the next Apply, so a crash after this point but before the
// kernel add leaves a durable record that Apply re-applies. The previous
// ordering (kernel first, then state) left a window where a process kill
// produced a permanent kernel block with no state row: it never expired
// (timeout 0) yet a later state-seeded Apply would silently drop it.
// Zero ExpiresAt means permanent.
entry := BlockedEntry{
IP: ip,
Reason: reason,
Source: InferProvenance("block", reason),
BlockedAt: time.Now(),
}
if timeout > 0 {
entry.ExpiresAt = time.Now().Add(timeout)
}
// Resolve the eviction target before anything is written, so a bad victim
// aborts the block with no side effects.
var evictSet *nftables.Set
var evictKey []byte
if evictTempIP != "" {
evictSet, evictKey, err = e.resolveIPSet(evictTempIP, e.setBlocked, e.setBlocked6)
if err != nil {
return fmt.Errorf("resolving temp block eviction target %s: %w", evictTempIP, err)
}
}
// priorState is kept untouched for rollback; all edits go to the copy.
priorState := e.loadStateFile()
nextState := copyFirewallState(priorState)
if evictTempIP != "" {
removeBlockedIPFromState(&nextState, evictTempIP)
}
upsertBlockedEntryInState(&nextState, entry)
if err := e.saveState(&nextState); err != nil {
return fmt.Errorf("persisting block for %s: %w", ip, err)
}
elem := []nftables.SetElement{{Key: key, Timeout: timeout}}
if err := e.conn.SetAddElements(targetSet, elem); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("adding to blocked set: %w (state restore failed: %v)", err, restoreErr)
}
return fmt.Errorf("adding to blocked set: %w", err)
}
if evictTempIP != "" {
if err := e.conn.SetDeleteElements(evictSet, []nftables.SetElement{{Key: evictKey}}); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("evicting temp block %s: %w (state restore failed: %v)", evictTempIP, err, restoreErr)
}
return fmt.Errorf("evicting temp block %s: %w", evictTempIP, err)
}
}
// Single Flush for the add and the optional eviction delete issued above.
if err := e.conn.Flush(); err != nil {
if restoreErr := e.restoreBlockStateAfterFailureLocked(priorState, ip); restoreErr != nil {
return fmt.Errorf("flushing: %w (state restore failed: %v)", err, restoreErr)
}
return fmt.Errorf("flushing: %w", err)
}
// Audit records are written only after the kernel change succeeded.
if evictTempIP != "" {
AppendAudit(e.statePath, "evict_temp", evictTempIP, "temp deny limit reached; evicted soonest-expiring entry", SourceSystem, 0)
}
AppendAudit(e.statePath, "block", ip, reason, entry.Source, timeout)
return nil
}
// validateBlockIP runs the block guards for ip under e.mu and keeps only the
// two results a caller needs before deciding to block: whether the IP is
// already blocked, and any error from canonicalisation or blockIPTarget. The
// resolved set, key and eviction candidate are discarded.
func (e *Engine) validateBlockIP(ip string, timeout time.Duration, skipExisting bool) (bool, error) {
	normalized, parseErr := canonicalFirewallIP(ip)
	if parseErr != nil {
		return false, parseErr
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	_, _, blocked, _, err := e.blockIPTarget(normalized, timeout, skipExisting)
	return blocked, err
}
// blockIPTarget runs the pre-block guards for ip and resolves where the block
// would be written. It does not take e.mu itself; both callers in this file
// (blockIPLocked and validateBlockIP) hold the lock around the call. The body
// makes no state-save or set-mutation calls.
//
// Guards run in this order: parse ip, refuse cfg.InfraIPs matches (CIDR or
// single address), refuse addresses resolved from infra hostnames, refuse the
// host's own interface addresses, resolve the target set, short-circuit on an
// existing block (only when skipExisting is set), then enforce the deny
// limits.
//
// Returns, in order:
//   - the blocked set and raw key bytes chosen by resolveIPSet;
//   - alreadyBlocked: true when skipExisting is set, the state file lists ip
//     as blocked, and the live probe either confirms it or fails;
//   - the IP of a temporary block to evict when the temporary limit has been
//     reached, "" otherwise;
//   - an error when ip is invalid, protected, has no usable set, or a deny
//     limit cannot be satisfied.
//
// NOTE(review): e.cfg is dereferenced without a nil check here, while
// PromoteToPermanentBlock guards with e.cfg != nil - confirm cfg is always
// set on this path.
func (e *Engine) blockIPTarget(ip string, timeout time.Duration, skipExisting bool) (*nftables.Set, []byte, bool, string, error) {
parsed := net.ParseIP(ip)
if parsed == nil {
return nil, nil, false, "", fmt.Errorf("invalid IP: %s", ip)
}
ip = parsed.String()
// SAFETY: never block infra IPs - prevents admin lockout.
// Runs before resolveIPSet so that an IPv6-disabled config (set6 == nil)
// cannot bypass the infra guard when the caller passes the canonical
// IPv6 form of a listed infra address.
for _, cidr := range e.cfg.InfraIPs {
_, network, cidrErr := net.ParseCIDR(cidr)
if cidrErr != nil {
// Not CIDR notation: treat the entry as a single address and compare
// canonical forms. Entries that are neither are skipped.
if infraIP := net.ParseIP(cidr); infraIP != nil && infraIP.String() == parsed.String() {
return nil, nil, false, "", fmt.Errorf("refusing to block infra IP: %s", ip)
}
continue
}
if network.Contains(parsed) {
return nil, nil, false, "", fmt.Errorf("refusing to block infra IP: %s (in %s)", ip, cidr)
}
}
// Hostnames in cfg.InfraIPs are resolved by the DynDNS loop and
// pushed in via UpdateInfraResolved. Without this check a hostname
// listed as infra would only be honoured when the operator also
// pinned the IP, so a moving panel IP would silently drop out of
// the lockout guard.
if host, ok := e.infraIPResolvedHostLocked(ip); ok {
return nil, nil, false, "", fmt.Errorf("refusing to block infra IP: %s (resolved from %s)", ip, host)
}
// Daemon's own interface addresses are always off-limits. Without
// this guard a stray request from the host to itself (cron, panel
// callback, internal probe) could trigger an auto-block that
// firewalls every customer hosted on the same IP.
if e.isLocalAddrLocked(ip) {
return nil, nil, false, "", fmt.Errorf("refusing to block local host IP: %s (own interface address)", ip)
}
targetSet, key, err := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
if err != nil {
return nil, nil, false, "", err
}
st := e.loadStateFile()
// Set when state lists ip as blocked but the live probe says it is not;
// used below to leave ip out of the fallback count.
cachedBlockMissingLive := false
if skipExisting && firewallStateHasBlocked(st, ip) {
liveBlocked, liveErr := e.isBlockedLiveLocked(ip)
if liveErr != nil || liveBlocked {
// Treat a probe error as "still blocked" so we never demote a
// cached block on transient netlink trouble. Returning nil
// here is intentional and the conservative posture.
return targetSet, key, true, "", nil //nolint:nilerr // intentional fail-safe on netlink probe error
}
cachedBlockMissingLive = true
}
// Enforce deny IP limits. Prefer counts from the live nft set
// so an entry the kernel already expired no longer counts
// against the cap. Fall back to the cached state.json count
// (existing behaviour) if the live query is unavailable.
if e.cfg.DenyIPLimit > 0 || e.cfg.DenyTempIPLimit > 0 {
perm, temp, ok := e.livePermTempCountsLocked(st)
if !ok {
excludeIP := ""
if cachedBlockMissingLive {
excludeIP = ip
}
perm, temp = blockedStatePermTempCounts(st, excludeIP)
}
// A permanent request (timeout == 0) over the limit is refused outright.
if timeout == 0 && e.cfg.DenyIPLimit > 0 && perm >= e.cfg.DenyIPLimit {
return nil, nil, false, "", fmt.Errorf("permanent deny limit reached (%d)", e.cfg.DenyIPLimit)
}
// A temporary request over the limit nominates a victim instead of failing.
if timeout > 0 && e.cfg.DenyTempIPLimit > 0 && temp >= e.cfg.DenyTempIPLimit {
// Validation must stay read-only because verdict and dry-run gates run
// after it. The live block path applies this eviction target.
victim, ok := soonestExpiringTempIP(st, ip)
if !ok {
return nil, nil, false, "", fmt.Errorf("temporary deny limit reached (%d) and no temp entry to evict", e.cfg.DenyTempIPLimit)
}
// Only checks that the victim resolves to a set; the result is discarded.
if _, _, err := e.resolveIPSet(victim, e.setBlocked, e.setBlocked6); err != nil {
return nil, nil, false, "", fmt.Errorf("temporary deny limit reached (%d) and eviction target %s is unusable: %w", e.cfg.DenyTempIPLimit, victim, err)
}
return targetSet, key, false, victim, nil
}
}
return targetSet, key, false, "", nil
}
// soonestExpiringTempIP selects the temporary block whose expiry comes first.
// Entries without an expiry (permanent blocks) and entries matching excludeIP
// are never candidates. The boolean is false when no temporary entry
// qualifies. It only reads st, so the eviction policy can be unit-tested
// without nftables.
func soonestExpiringTempIP(st FirewallState, excludeIP string) (string, bool) {
	var (
		victim   string
		earliest time.Time
		have     bool
	)
	for _, entry := range st.Blocked {
		switch {
		case sameIPString(entry.IP, excludeIP), entry.ExpiresAt.IsZero():
			// Excluded address or permanent block: skip.
		case !have || entry.ExpiresAt.Before(earliest):
			// Strictly earlier only, so the first of several equal expiries wins.
			victim, earliest, have = entry.IP, entry.ExpiresAt, true
		}
	}
	return victim, have
}
// firewallStateHasBlocked reports whether state contains a blocked entry for
// ip. Addresses are compared with sameIPString rather than by raw string.
func firewallStateHasBlocked(state FirewallState, ip string) bool {
	for i := range state.Blocked {
		if sameIPString(state.Blocked[i].IP, ip) {
			return true
		}
	}
	return false
}
// countPermanentBlockedEntries returns the permanent count from
// blockedStatePermTempCounts with no IP excluded.
func countPermanentBlockedEntries(state FirewallState) int {
	permanent, _ := blockedStatePermTempCounts(state, "")
	return permanent
}
// blockedStatePermTempCounts counts distinct blocked addresses in state,
// split into permanent (no expiry) and temporary (expiry set).
//
// Rows whose IP has no canonical key are ignored, as is excludeIP when it is
// non-empty. Several rows for the same address count once, and the address
// counts as permanent as soon as any of its rows is permanent.
func blockedStatePermTempCounts(state FirewallState, excludeIP string) (perm, temp int) {
	tempByKey := make(map[string]bool, len(state.Blocked))
	for _, entry := range state.Blocked {
		key, ok := canonicalIPKey(entry.IP)
		if !ok {
			continue
		}
		if excludeIP != "" && sameIPString(key, excludeIP) {
			continue
		}
		isTemp := !entry.ExpiresAt.IsZero()
		wasTemp, duplicate := tempByKey[key]
		if duplicate {
			// Stays temporary only if every row seen so far is temporary.
			isTemp = wasTemp && isTemp
		}
		tempByKey[key] = isTemp
	}
	for _, isTemp := range tempByKey {
		switch isTemp {
		case true:
			temp++
		default:
			perm++
		}
	}
	return perm, temp
}
// UnblockIP deletes ip from the blocked nft set and from persisted state.
//
// The state row is removed and saved before the kernel delete: a crash
// between the two steps then converges on the operator's intent (unblocked)
// at the next Apply instead of re-blocking from a stale row. If the kernel
// step fails, the prior state is saved back so the entry stays visible for a
// retry. A not-found error from the flush means the element was already gone
// and is treated as success.
func (e *Engine) UnblockIP(ip string) error {
	canonical, err := canonicalFirewallIP(ip)
	if err != nil {
		return err
	}
	ip = canonical

	e.mu.Lock()
	defer e.mu.Unlock()

	targetSet, key, err := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
	if err != nil {
		return err
	}

	priorState := e.loadStateFile()
	nextState := copyFirewallState(priorState)
	removeBlockedIPFromState(&nextState, ip)
	if saveErr := e.saveState(&nextState); saveErr != nil {
		return fmt.Errorf("persisting unblock for %s: %w", ip, saveErr)
	}

	victim := []nftables.SetElement{{Key: key}}
	if delErr := e.conn.SetDeleteElements(targetSet, victim); delErr != nil {
		if restoreErr := e.saveState(&priorState); restoreErr != nil {
			return fmt.Errorf("removing from blocked set: %w (state restore failed: %v)", delErr, restoreErr)
		}
		return fmt.Errorf("removing from blocked set: %w", delErr)
	}

	flushErr := e.conn.Flush()
	if flushErr == nil || isNftNotFound(flushErr) {
		// Either the delete landed, or the element had already disappeared
		// from nft; in both cases the persisted unblock is what should remain.
		AppendAudit(e.statePath, "unblock", ip, "", "", 0)
		return nil
	}
	if restoreErr := e.saveState(&priorState); restoreErr != nil {
		return fmt.Errorf("flushing: %w (state restore failed: %v)", flushErr, restoreErr)
	}
	return fmt.Errorf("flushing: %w", flushErr)
}
// IsBlocked reports whether ip is present in the engine's blocked-IP index.
//
// The answer comes from the blockedIPIndex map, which ensureStateCacheLocked
// prepares from the cached state, so each check is a map lookup rather than
// a scan over the persisted block list.
//
// The index is keyed by canonical address text. When the string as given
// misses, the lookup is retried with net.ParseIP(ip).String() so that a
// non-canonical spelling such as the IPv4-mapped "::ffff:1.2.3.4" from a
// dual-stack listener still matches.
func (e *Engine) IsBlocked(ip string) bool {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.ensureStateCacheLocked()
	if _, hit := e.blockedIPIndex[ip]; hit {
		return true
	}

	parsed := net.ParseIP(ip)
	if parsed == nil {
		return false
	}
	canon := parsed.String()
	if canon == ip {
		// Already canonical; the first lookup was authoritative.
		return false
	}
	_, hit := e.blockedIPIndex[canon]
	return hit
}
// IsBlockedLive answers from the live nftables set instead of the in-memory
// cache built from state.json. The cache can lag the kernel when nft expires
// entries before CSM rewrites state.json, or after an out-of-band flush.
// Reconcile loops should use this method so local tracking shrinks together
// with the kernel; per-packet hot paths should keep using IsBlocked because
// this call costs a netlink round trip.
//
// A malformed ip is reported as not blocked with a nil error. Netlink and
// engine-initialisation failures are returned, letting callers keep their
// cached answer rather than dropping local state on a transient failure.
func (e *Engine) IsBlockedLive(ip string) (bool, error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	blocked, err := e.isBlockedLiveLocked(ip)
	return blocked, err
}
// livePermTempCountsLocked counts permanent and temporary elements across the
// blocked IPv4 and IPv6 nft sets, so blockIPTarget can enforce deny limits
// against what the kernel holds rather than state.json rows the kernel has
// already expired.
//
// Live keys that also appear in state are classified from state, because nft
// timeout attributes can reflect inherited or default set behaviour instead
// of the operator's intent. Keys absent from state fall back to the kernel's
// timeout/expiry attributes.
//
// ok is false when no live answer is available: the liveBlockCounts hook
// failed, there is no connection, a set query failed, or neither set exists.
//
// Must be called with e.mu held; blockIPTarget already holds the lock at the
// only call site.
func (e *Engine) livePermTempCountsLocked(state FirewallState) (perm, temp int, ok bool) {
	if counter := e.liveBlockCounts; counter != nil {
		// An injected counter replaces the netlink query entirely.
		hookPerm, hookTemp, err := counter()
		if err != nil {
			return 0, 0, false
		}
		return hookPerm, hookTemp, true
	}
	if e.conn == nil {
		return 0, 0, false
	}

	tempByIP := blockedStateTempByIP(state)
	queried := false
	for _, set := range []*nftables.Set{e.setBlocked, e.setBlocked6} {
		if set == nil {
			continue
		}
		elements, err := e.conn.GetSetElements(set)
		if err != nil {
			// A partial count would be misleading; report "unavailable".
			return 0, 0, false
		}
		queried = true
		setPerm, setTemp := countLiveBlockElements(elements, tempByIP)
		perm += setPerm
		temp += setTemp
	}
	if !queried {
		return 0, 0, false
	}
	return perm, temp, true
}
// blockedStateTempByIP indexes the persisted blocks by IP and records, for
// each address, whether state tracks it as temporary (true) or permanent
// (false). Each row is indexed under its raw string and, when it parses,
// under its canonical net.IP text, so a lookup by either spelling hits.
// Rows with an empty IP are skipped.
//
// When state holds several rows for one address, a permanent row wins
// regardless of row order. blockedStatePermTempCounts resolves duplicates the
// same way, so the live-count path and the cached-state fallback classify an
// address identically instead of depending on which row comes last.
func blockedStateTempByIP(state FirewallState) map[string]bool {
	byIP := make(map[string]bool, len(state.Blocked)*2)
	record := func(key string, temp bool) {
		if prior, seen := byIP[key]; seen && !prior {
			// Already known as permanent; a later temporary row must not
			// demote it.
			return
		}
		byIP[key] = temp
	}
	for _, entry := range state.Blocked {
		if entry.IP == "" {
			continue
		}
		temp := !entry.ExpiresAt.IsZero()
		record(entry.IP, temp)
		if parsed := net.ParseIP(entry.IP); parsed != nil {
			record(parsed.String(), temp)
		}
	}
	return byIP
}
// countLiveBlockElements splits live set elements into permanent and
// temporary counts. An element whose key decodes to an IP present in
// stateTempByIP takes its classification from that map; any other element is
// temporary when the kernel reports a positive Timeout or Expires, and
// permanent otherwise.
func countLiveBlockElements(elements []nftables.SetElement, stateTempByIP map[string]bool) (perm, temp int) {
	for i := range elements {
		el := &elements[i]
		isTemp := el.Timeout > 0 || el.Expires > 0
		if ip, decoded := setElementIPString(el.Key); decoded {
			if stateTemp, known := stateTempByIP[ip]; known {
				// State knows this address, so state decides.
				isTemp = stateTemp
			}
		}
		if isTemp {
			temp++
			continue
		}
		perm++
	}
	return perm, temp
}
// setElementIPString renders a raw set key as IP text. Only keys of exactly
// net.IPv4len or net.IPv6len bytes are accepted; any other length yields
// ("", false).
func setElementIPString(key []byte) (string, bool) {
	if n := len(key); n != net.IPv4len && n != net.IPv6len {
		return "", false
	}
	return net.IP(key).String(), true
}
// isBlockedLiveLocked checks the live blocked set for ip.
//
// An unparsable ip yields (false, nil). IPv4 addresses (including those that
// net.IP.To4 can shorten) are looked up in setBlocked under their 4-byte key;
// everything else is looked up in setBlocked6 under the 16-byte key. A
// missing set or missing connection is returned as an error. When the
// liveBlockLookup hook is set it answers instead of the netlink query.
//
// NOTE(review): the name and its callers indicate e.mu must be held — confirm.
func (e *Engine) isBlockedLiveLocked(ip string) (bool, error) {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return false, nil
	}

	var key []byte
	set := e.setBlocked6
	if v4 := parsed.To4(); v4 != nil {
		set, key = e.setBlocked, v4
	} else {
		key = parsed.To16()
	}
	if set == nil {
		return false, fmt.Errorf("blocked set unavailable for %s", ip)
	}

	if lookup := e.liveBlockLookup; lookup != nil {
		return lookup(set, key)
	}
	if e.conn == nil {
		return false, fmt.Errorf("nftables connection unavailable")
	}

	elements, err := e.conn.GetSetElements(set)
	if err != nil {
		return false, fmt.Errorf("listing blocked set: %w", err)
	}
	for i := range elements {
		if bytes.Equal(elements[i].Key, key) {
			return true, nil
		}
	}
	return false, nil
}
// AllowIP adds ip to the allowed set and persists the allow. It delegates to
// allowIP with the "allow" action, which also removes any existing block for
// the same IP.
func (e *Engine) AllowIP(ip string, reason string) error {
	// The "allow" action never sets an expiry, so no timeout is passed.
	const noTimeout time.Duration = 0
	return e.allowIP(ip, reason, noTimeout, "allow")
}
// TempAllowIP adds an allow that expires after timeout. It shares the allowed
// set with permanent allows; the expiry lives only in state, and
// CleanExpiredAllows removes lapsed entries periodically.
func (e *Engine) TempAllowIP(ip string, reason string, timeout time.Duration) error {
	const action = "temp_allow"
	return e.allowIP(ip, reason, timeout, action)
}
// allowIP is the shared implementation behind AllowIP and TempAllowIP.
//
// It canonicalises ip, then under e.mu:
//  1. saves a new state with any block for ip removed and the allow entry
//     upserted,
//  2. queues the delete from the blocked set (when one resolves) and the add
//     to the allowed set,
//  3. flushes the batch.
//
// If a kernel step fails, the prior state is saved back before the error is
// returned. A flush failure is first handed to
// retryAllowAfterBenignFlushError; if that retry succeeds the call succeeds.
//
// action is recorded in the audit log and selects expiry handling: an expiry
// is set only for action "temp_allow" with a positive timeout.
func (e *Engine) allowIP(ip string, reason string, timeout time.Duration, action string) error {
canonical, err := canonicalFirewallIP(ip)
if err != nil {
return err
}
ip = canonical
e.mu.Lock()
defer e.mu.Unlock()
// The blocked-set lookup error is deliberately dropped: a nil blockedSet
// simply skips the block removal below.
blockedSet, blockedKey, _ := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
allowedSet, allowedKey, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
if err != nil {
return err
}
if allowedSet == nil {
return fmt.Errorf("allowed set unavailable for %s", ip)
}
entry := AllowedEntry{IP: ip, Reason: reason, Source: InferProvenance(action, reason)}
if action == "temp_allow" && timeout > 0 {
entry.ExpiresAt = time.Now().Add(timeout)
}
// State is written before the kernel is touched; priorState is kept so every
// failure path below can roll the file back.
priorState := e.loadStateFile()
nextState := copyFirewallState(priorState)
removeBlockedIPFromState(&nextState, ip)
upsertAllowedEntryInState(&nextState, entry)
if err := e.saveState(&nextState); err != nil {
return fmt.Errorf("persisting %s for %s: %w", action, ip, err)
}
if blockedSet != nil {
if err := e.conn.SetDeleteElements(blockedSet, []nftables.SetElement{{Key: blockedKey}}); err != nil {
if restoreErr := e.saveState(&priorState); restoreErr != nil {
return fmt.Errorf("removing from blocked set: %w (state restore failed: %v)", err, restoreErr)
}
logNftSetOpErr(action+" remove from blocked", ip, err)
return fmt.Errorf("removing from blocked set: %w", err)
}
}
if err := e.conn.SetAddElements(allowedSet, []nftables.SetElement{{Key: allowedKey}}); err != nil {
if restoreErr := e.saveState(&priorState); restoreErr != nil {
return fmt.Errorf("adding to allowed set: %w (state restore failed: %v)", err, restoreErr)
}
return fmt.Errorf("adding to allowed set: %w", err)
}
if err := e.conn.Flush(); err != nil {
// The retry helper only acts on not-found flush errors; for any other error
// it returns the error unchanged and the rollback below runs.
retryErr := e.retryAllowAfterBenignFlushError(blockedSet, blockedKey, allowedSet, allowedKey, err)
if retryErr == nil {
AppendAudit(e.statePath, action, ip, reason, entry.Source, timeout)
return nil
}
if restoreErr := e.saveState(&priorState); restoreErr != nil {
return fmt.Errorf("flushing: %w (state restore failed: %v)", err, restoreErr)
}
// Only a not-found flush actually attempted a retry, so only then is the
// retry failure worth reporting alongside the original error.
if isNftNotFound(err) {
return fmt.Errorf("flushing: %w (retry failed: %v)", err, retryErr)
}
return fmt.Errorf("flushing: %w", err)
}
AppendAudit(e.statePath, action, ip, reason, entry.Source, timeout)
return nil
}
// retryAllowAfterBenignFlushError re-runs the allow after a flush that failed
// with a not-found error. Any other flushErr is returned unchanged without
// retrying.
//
// The retry splits the original batch in two. The blocked-set delete (when
// blockedSet is non-nil) is flushed on its own, and a not-found result from
// that flush is tolerated. The allowed-set add is then flushed separately,
// and any error from it is returned. A nil result means the allow is in
// place.
func (e *Engine) retryAllowAfterBenignFlushError(blockedSet *nftables.Set, blockedKey []byte, allowedSet *nftables.Set, allowedKey []byte, flushErr error) error {
	if !isNftNotFound(flushErr) {
		return flushErr
	}

	if blockedSet != nil {
		stale := []nftables.SetElement{{Key: blockedKey}}
		if delErr := e.conn.SetDeleteElements(blockedSet, stale); delErr != nil {
			return fmt.Errorf("retry removing from blocked set: %w", delErr)
		}
		if err := e.conn.Flush(); err != nil && !isNftNotFound(err) {
			return fmt.Errorf("retry flushing blocked delete: %w", err)
		}
	}

	allow := []nftables.SetElement{{Key: allowedKey}}
	if addErr := e.conn.SetAddElements(allowedSet, allow); addErr != nil {
		return fmt.Errorf("retry adding to allowed set: %w", addErr)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("retry flushing allowed add: %w", err)
	}
	return nil
}
// CleanExpiredAllows removes expired temporary allows from the set and state.
// An IP is only removed from nftables if no non-expired entries remain for it.
// Called periodically by the daemon.
//
// Returns the number of state rows dropped. It returns 0 when the raw state
// cannot be loaded, when nothing has expired, or when the nft flush fails (in
// which case state is left untouched for the next tick).
func (e *Engine) CleanExpiredAllows() int {
e.mu.Lock()
defer e.mu.Unlock()
state, ok := e.loadStateFileRawLocked()
if !ok {
return 0
}
now := time.Now()
var active []AllowedEntry
expiredIPs := make(map[string]bool)
var expired []AllowedEntry
removed := 0
// Partition rows into still-active and expired; a zero ExpiresAt never expires.
for _, entry := range state.Allowed {
if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
expiredIPs[stateIPKey(entry.IP)] = true
expired = append(expired, entry)
removed++
} else {
active = append(active, entry)
}
}
if removed > 0 {
// Only remove from nftables if no active entries remain for the IP
activeIPs := make(map[string]bool)
for _, entry := range active {
activeIPs[stateIPKey(entry.IP)] = true
}
queueFailedIPs := make(map[string]bool)
queuedDeletes := false
for ip := range expiredIPs {
if !activeIPs[ip] {
// An IP whose set cannot be resolved has nothing to delete in the kernel;
// its state rows are still dropped below.
if set, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6); err == nil {
if err := e.conn.SetDeleteElements(set, []nftables.SetElement{{Key: key}}); err != nil {
logNftSetOpErr("CleanExpiredAllows remove", ip, err)
queueFailedIPs[ip] = true
} else {
queuedDeletes = true
}
}
}
}
if queuedDeletes {
if err := e.conn.Flush(); err != nil {
// The netlink batch is atomic: a failed flush applied nothing,
// so keep every expired row in state and retry next tick.
fmt.Fprintf(os.Stderr, "firewall: nft flush after expired-allow cleanup failed: %v\n", err)
return 0
}
}
// Drop state rows whose kernel element was removed (or had none to
// remove). Rows whose queue op failed stay in state so the next tick
// retries the kernel delete instead of wedging forever: the previous
// all-or-nothing handling kept already-flushed deletes in state, and
// re-deleting their missing elements failed every following tick.
dropped := make([]AllowedEntry, 0, len(expired))
for _, entry := range expired {
if queueFailedIPs[stateIPKey(entry.IP)] {
active = append(active, entry)
continue
}
dropped = append(dropped, entry)
}
state.Allowed = active
// NOTE(review): the save error is discarded here; if the write fails the
// expired rows stay on disk while their kernel elements are gone — confirm
// this best-effort behaviour is intended.
_ = e.saveState(&state)
for _, entry := range dropped {
AppendAudit(e.statePath, "temp_allow_expired", entry.IP, "", SourceSystem, 0)
}
return len(dropped)
}
// Nothing expired: removed is 0 here.
return removed
}
// CleanExpiredSubnets removes expired temporary subnet blocks from nftables and state.
//
// Returns the number of state rows dropped. It returns 0 when the raw state
// cannot be loaded, when nothing has expired, or when the nft flush fails (in
// which case state is left untouched for the next tick).
func (e *Engine) CleanExpiredSubnets() int {
e.mu.Lock()
defer e.mu.Unlock()
state, ok := e.loadStateFileRawLocked()
if !ok {
return 0
}
now := time.Now()
var active []SubnetEntry
var expired []SubnetEntry
removed := 0
queueFailedCIDRs := make(map[string]bool)
queuedDeletes := false
for _, entry := range state.BlockedNet {
if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
// A kernel delete is queued only when the CIDR parses, a matching set
// exists and the interval can be encoded; otherwise the row is still
// treated as expired and dropped from state below.
if _, network, err := net.ParseCIDR(entry.CIDR); err == nil {
if set, start, end := e.resolveSubnetSet(network); set != nil {
if elements := intervalSetElements(start, end); len(elements) > 0 {
if err := e.conn.SetDeleteElements(set, elements); err != nil {
logNftSetOpErr("CleanExpiredSubnets remove", entry.CIDR, err)
queueFailedCIDRs[entry.CIDR] = true
} else {
queuedDeletes = true
}
}
}
}
expired = append(expired, entry)
removed++
continue
}
active = append(active, entry)
}
if removed > 0 {
if queuedDeletes {
if err := e.conn.Flush(); err != nil {
// The netlink batch is atomic: a failed flush applied nothing,
// so keep every expired row in state and retry next tick.
fmt.Fprintf(os.Stderr, "firewall: nft flush after expired-subnet cleanup failed: %v\n", err)
return 0
}
}
// Keep queue-failed rows in state for a retry next tick; drop the
// rest. See CleanExpiredAllows for the wedge this avoids.
dropped := make([]SubnetEntry, 0, len(expired))
for _, entry := range expired {
if queueFailedCIDRs[entry.CIDR] {
active = append(active, entry)
continue
}
dropped = append(dropped, entry)
}
state.BlockedNet = active
// NOTE(review): the save error is discarded here; if the write fails the
// expired rows stay on disk while their kernel elements are gone — confirm
// this best-effort behaviour is intended.
_ = e.saveState(&state)
for _, entry := range dropped {
AppendAudit(e.statePath, "temp_subnet_expired", entry.CIDR, "", SourceSystem, 0)
}
return len(dropped)
}
// Nothing expired: removed is 0 here.
return removed
}
// RemoveAllowIP removes ip from the allowed nft set and from persisted state.
//
// The kernel delete runs first; state and the audit log are updated only once
// the kernel no longer holds the element. A not-found error from the flush
// means the element was already gone (for example after an out-of-band
// flush), so it is treated as success. Otherwise a state row whose kernel
// element had vanished could never be removed through this call. UnblockIP
// handles the same condition the same way.
//
// Returns an error when ip is invalid, no allowed set resolves for it, or the
// kernel operation fails for any reason other than not-found.
func (e *Engine) RemoveAllowIP(ip string) error {
	canonical, err := canonicalFirewallIP(ip)
	if err != nil {
		return err
	}
	ip = canonical

	e.mu.Lock()
	defer e.mu.Unlock()

	targetSet, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
	if err != nil {
		return err
	}
	if err := e.conn.SetDeleteElements(targetSet, []nftables.SetElement{{Key: key}}); err != nil {
		return fmt.Errorf("removing from allowed set: %w", err)
	}
	if err := e.conn.Flush(); err != nil && !isNftNotFound(err) {
		return fmt.Errorf("flushing: %w", err)
	}

	e.removeAllowedState(ip)
	AppendAudit(e.statePath, "remove_allow", ip, "", "", 0)
	return nil
}
// RemoveAllowIPBySource removes only the allow entries for ip that were
// recorded with the given source. The IP is deleted from the nft allowed set
// only when removeAllowedStateBySource reports that no entry remains for it.
//
// State is updated before the kernel delete. A not-found error from the flush
// means the kernel element was already absent, so it is treated as success.
// Returning an error there would leave the completed state change without an
// audit record, even though the kernel already matches the requested outcome.
//
// Returns an error when ip is invalid, no allowed set resolves for it, or the
// kernel operation fails for any reason other than not-found.
func (e *Engine) RemoveAllowIPBySource(ip, source string) error {
	canonical, err := canonicalFirewallIP(ip)
	if err != nil {
		return err
	}
	ip = canonical

	e.mu.Lock()
	defer e.mu.Unlock()

	if ipGone := e.removeAllowedStateBySource(ip, source); ipGone {
		targetSet, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
		if err != nil {
			return err
		}
		if err := e.conn.SetDeleteElements(targetSet, []nftables.SetElement{{Key: key}}); err != nil {
			return fmt.Errorf("removing from allowed set: %w", err)
		}
		if err := e.conn.Flush(); err != nil && !isNftNotFound(err) {
			return fmt.Errorf("flushing: %w", err)
		}
	}

	AppendAudit(e.statePath, "remove_allow", ip, "source: "+source, source, 0)
	return nil
}
// AllowIPPort adds a port-specific IP allow. The rule is persisted to state
// and applied on the next Apply(). For immediate effect, call Apply() after.
//
// port must be in 1..65535. Any proto other than "tcp" or "udp" is stored as
// "tcp". Adding a rule that already exists (same IP, port and proto) is a
// no-op that returns nil.
//
// Returns an error when ip or port is invalid, or when the state cannot be
// saved. A failed save is reported instead of being audited as a success,
// because the rule would otherwise never reach the next Apply().
func (e *Engine) AllowIPPort(ip string, port int, proto string, reason string) error {
	canonical, err := canonicalFirewallIP(ip)
	if err != nil {
		return err
	}
	ip = canonical
	if port < 1 || port > 65535 {
		return fmt.Errorf("invalid port: %d", port)
	}
	if proto != "tcp" && proto != "udp" {
		proto = "tcp"
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	st := e.loadStateFile()
	for _, existing := range st.PortAllowed {
		if sameIPString(existing.IP, ip) && existing.Port == port && existing.Proto == proto {
			return nil // already exists
		}
	}

	source := InferProvenance("allow_port", reason)
	st.PortAllowed = append(st.PortAllowed, PortAllowEntry{
		IP: ip, Port: port, Proto: proto, Reason: reason, Source: source,
	})
	if err := e.saveState(&st); err != nil {
		return fmt.Errorf("persisting port allow for %s:%d/%s: %w", ip, port, proto, err)
	}
	AppendAudit(e.statePath, "allow_port", fmt.Sprintf("%s:%d/%s", ip, port, proto), reason, source, 0)
	return nil
}
// RemoveAllowIPPort removes a port-specific IP allow from state.
//
// AllowIPPort stores any proto other than "tcp" or "udp" as "tcp". The same
// normalisation is applied here, so a rule added with an empty or unknown
// proto can be removed with the same argument. Rows whose stored proto equals
// the argument exactly are still matched as well, so a row persisted with an
// unnormalised proto stays removable.
//
// Returns an error when ip is invalid, when no matching rule exists, or when
// the updated state cannot be saved.
func (e *Engine) RemoveAllowIPPort(ip string, port int, proto string) error {
	canonical, err := canonicalFirewallIP(ip)
	if err != nil {
		return err
	}
	ip = canonical

	normalized := proto
	if normalized != "tcp" && normalized != "udp" {
		normalized = "tcp"
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	st := e.loadStateFile()
	var remaining []PortAllowEntry
	found := false
	for _, entry := range st.PortAllowed {
		protoMatches := entry.Proto == proto || entry.Proto == normalized
		if sameIPString(entry.IP, ip) && entry.Port == port && protoMatches {
			found = true
			continue
		}
		remaining = append(remaining, entry)
	}
	if !found {
		return fmt.Errorf("port allow not found: %s:%d/%s", ip, port, proto)
	}

	st.PortAllowed = remaining
	if err := e.saveState(&st); err != nil {
		return fmt.Errorf("persisting port allow removal for %s:%d/%s: %w", ip, port, normalized, err)
	}
	AppendAudit(e.statePath, "remove_port_allow", fmt.Sprintf("%s:%d/%s", ip, port, normalized), "", "", 0)
	return nil
}
// FlushBlocked removes all IPs from the blocked set and clears persisted state.
//
// Both the IPv4 and the IPv6 blocked sets are flushed when present; a nil set
// is skipped. The kernel flush runs first, then the blocked list in state is
// cleared and saved.
//
// Returns an error when the nft flush fails or when the cleared state cannot
// be saved. The save error is reported rather than ignored because other
// paths in this file describe state as the seed for the next Apply, so rows
// left on disk could reinstate the blocks. No audit record is written in that
// case.
func (e *Engine) FlushBlocked() error {
	e.mu.Lock()
	defer e.mu.Unlock()

	for _, set := range []*nftables.Set{e.setBlocked, e.setBlocked6} {
		if set != nil {
			e.conn.FlushSet(set)
		}
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing blocked set: %w", err)
	}

	state := e.loadStateFile()
	count := len(state.Blocked)
	state.Blocked = nil
	if err := e.saveState(&state); err != nil {
		return fmt.Errorf("clearing blocked state after flush: %w", err)
	}
	AppendAudit(e.statePath, "flush", "", fmt.Sprintf("cleared %d entries", count), SourceSystem, 0)
	return nil
}
// subnetSafetyGuardLocked rejects a CIDR block that would cut off traffic the
// daemon must keep reachable. It refuses, in this order:
//   - the default route (a zero-length prefix),
//   - any overlap with a configured infra range, or containment of a
//     configured infra IP,
//   - containment of an IP resolved for an infra hostname,
//   - containment of one of the host's own interface addresses,
//   - containment of a fully allowed or port-allowed IP from state.
//
// A subnet block covers many addresses and the output chain has no infra
// carve-out, so an unsafe subnet can lock out operators or kill the daemon's
// own egress.
//
// Must be called with e.mu held.
func (e *Engine) subnetSafetyGuardLocked(network *net.IPNet) error {
	if prefixLen, _ := network.Mask.Size(); prefixLen == 0 {
		return fmt.Errorf("refusing to block default route: %s", network.String())
	}

	// covers reports whether raw parses as an IP that lies inside network.
	covers := func(raw string) bool {
		ip := net.ParseIP(raw)
		return ip != nil && network.Contains(ip)
	}

	for _, raw := range e.cfg.InfraIPs {
		_, infraNet, cidrErr := net.ParseCIDR(raw)
		if cidrErr != nil {
			// Not a range: treat the entry as a single address.
			if covers(raw) {
				return fmt.Errorf("refusing to block subnet %s: contains infra IP %s", network.String(), raw)
			}
			continue
		}
		// Two CIDR ranges overlap exactly when one contains the other's base.
		if network.Contains(infraNet.IP) || infraNet.Contains(network.IP) {
			return fmt.Errorf("refusing to block subnet %s: overlaps infra range %s", network.String(), raw)
		}
	}

	for host, resolved := range e.infraResolved {
		for addr := range resolved {
			if covers(addr) {
				return fmt.Errorf("refusing to block subnet %s: contains infra IP %s (resolved from %s)", network.String(), addr, host)
			}
		}
	}

	e.refreshLocalAddrsLocked()
	for addr := range e.localAddrs {
		if covers(addr) {
			return fmt.Errorf("refusing to block subnet %s: contains local host IP %s", network.String(), addr)
		}
	}

	state := e.loadStateFile()
	for _, allowed := range state.Allowed {
		if covers(allowed.IP) {
			return fmt.Errorf("refusing to block subnet %s: contains allowed IP %s", network.String(), allowed.IP)
		}
	}
	for _, portAllowed := range state.PortAllowed {
		if covers(portAllowed.IP) {
			return fmt.Errorf("refusing to block subnet %s: contains port-allowed IP %s", network.String(), portAllowed.IP)
		}
	}
	return nil
}
// BlockSubnet blocks a CIDR range (IPv4 or IPv6) via the blocked subnets set.
// A timeout of 0 makes the block permanent; a positive timeout sets an
// expiry on the state entry.
//
// The request is refused when the CIDR is invalid, fails
// subnetSafetyGuardLocked, has no matching set, or cannot be encoded as an
// interval. A CIDR already present in state returns nil without changes.
//
// The entry is persisted before the kernel add, mirroring blockIPLocked:
// state seeds the next Apply, so a crash between a kernel add and a later
// state write would leave a block that silently disappears on restart. If the
// kernel step fails, the prior state is saved back.
func (e *Engine) BlockSubnet(cidr string, reason string, timeout time.Duration) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return fmt.Errorf("invalid CIDR: %s", cidr)
	}
	if guardErr := e.subnetSafetyGuardLocked(network); guardErr != nil {
		return guardErr
	}
	if e.isSubnetBlockedStateLocked(network.String()) {
		return nil
	}

	targetSet, start, end := e.resolveSubnetSet(network)
	if targetSet == nil {
		return fmt.Errorf("no matching set for %s (IPv6 disabled?)", cidr)
	}
	elements := intervalSetElements(start, end)
	if len(elements) == 0 {
		return fmt.Errorf("CIDR has no safe interval end: %s", network.String())
	}

	var entry SubnetEntry
	entry.CIDR = network.String()
	entry.Reason = reason
	entry.Source = InferProvenance("block_subnet", reason)
	entry.BlockedAt = time.Now()
	if timeout > 0 {
		entry.ExpiresAt = time.Now().Add(timeout)
	}

	priorState := e.loadStateFile()
	nextState := copyFirewallState(priorState)
	addSubnetEntryIfMissingInState(&nextState, entry)
	if saveErr := e.saveState(&nextState); saveErr != nil {
		return fmt.Errorf("persisting subnet block for %s: %w", network.String(), saveErr)
	}

	if addErr := e.conn.SetAddElements(targetSet, elements); addErr != nil {
		if restoreErr := e.saveState(&priorState); restoreErr != nil {
			return fmt.Errorf("adding to blocked_nets: %w (state restore failed: %v)", addErr, restoreErr)
		}
		return fmt.Errorf("adding to blocked_nets: %w", addErr)
	}
	if flushErr := e.conn.Flush(); flushErr != nil {
		if restoreErr := e.saveState(&priorState); restoreErr != nil {
			return fmt.Errorf("flushing: %w (state restore failed: %v)", flushErr, restoreErr)
		}
		return fmt.Errorf("flushing: %w", flushErr)
	}

	AppendAudit(e.statePath, "block_subnet", network.String(), reason, entry.Source, timeout)
	return nil
}
// IsSubnetBlocked reports whether cidr is present in the persisted subnet
// block state. The lookup uses the network's canonical string as produced by
// net.ParseCIDR; an unparsable cidr is reported as not blocked.
func (e *Engine) IsSubnetBlocked(cidr string) bool {
	_, network, parseErr := net.ParseCIDR(cidr)
	if parseErr != nil {
		return false
	}
	canonical := network.String()

	e.mu.Lock()
	defer e.mu.Unlock()
	return e.isSubnetBlockedStateLocked(canonical)
}
// BlockedSubnetCovering returns the blocked CIDR, if any, that contains ip.
//
// The input chain drops blocked_nets before the allowed_ips accept, so an
// allow for an IP inside a blocked subnet has no effect: the subnet drop
// still fires. Callers use this to avoid telling an operator that an IP is
// reachable while a subnet rule still blocks it. The subnet block remains
// authoritative by design (see subnetSafetyGuardLocked); this method only
// reports and never unblocks.
func (e *Engine) BlockedSubnetCovering(ip string) (string, bool) {
	e.mu.Lock()
	defer e.mu.Unlock()

	state := e.loadStateFile()
	return subnetCovering(state.BlockedNet, ip)
}
// UnblockSubnet removes a CIDR range from the blocked subnets set (IPv4 or
// IPv6) and from persisted state.
//
// When the range cannot be encoded as an interval there is nothing to delete
// in the kernel, so only the state row is removed. Otherwise the kernel
// delete is flushed first and state is updated afterwards. A not-found error
// from the flush means the interval was already absent from nft, so it is
// treated as success; otherwise the stale state row could never be removed
// through this call. UnblockIP handles the same condition the same way.
//
// Returns an error when cidr is invalid, no matching set exists, or the
// kernel operation fails for any reason other than not-found.
func (e *Engine) UnblockSubnet(cidr string) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR: %s", cidr)
	}
	targetSet, start, end := e.resolveSubnetSet(network)
	if targetSet == nil {
		return fmt.Errorf("no matching set for %s (IPv6 disabled?)", cidr)
	}

	if elements := intervalSetElements(start, end); len(elements) > 0 {
		if err := e.conn.SetDeleteElements(targetSet, elements); err != nil {
			return fmt.Errorf("removing from blocked_nets: %w", err)
		}
		if err := e.conn.Flush(); err != nil && !isNftNotFound(err) {
			return fmt.Errorf("flushing: %w", err)
		}
	}

	e.removeSubnetState(network.String())
	AppendAudit(e.statePath, "unblock_subnet", network.String(), "", "", 0)
	return nil
}
// resolveSubnetSet picks the blocked_nets set for a CIDR and returns it with
// the range's start and end addresses.
//
// All three results are nil when lastIPInRange yields no end address (a
// malformed net.IPNet whose IP is neither 4 nor 16 bytes), or when the CIDR
// is IPv6 and no IPv6 set is configured. Callers already test for a nil set,
// which also keeps nextIP(nil) from feeding an empty key to the kernel.
func (e *Engine) resolveSubnetSet(network *net.IPNet) (*nftables.Set, net.IP, net.IP) {
	last := lastIPInRange(network)
	if last == nil {
		return nil, nil, nil
	}

	first4 := network.IP.To4()
	switch {
	case first4 != nil:
		return e.setBlockedNet, first4, last
	case e.setBlockedNet6 != nil:
		return e.setBlockedNet6, network.IP.To16(), last
	default:
		return nil, nil, nil
	}
}
// intervalSetElements builds the set elements for the range start..end in a
// fresh slice. The result is nil when either bound is nil or when
// appendIntervalSetElements cannot encode the range.
func intervalSetElements(start, end net.IP) []nftables.SetElement {
	// appendIntervalSetElements returns its dst unchanged for nil bounds, so
	// a nil dst covers that case without a separate check.
	var none []nftables.SetElement
	return appendIntervalSetElements(none, start, end)
}
// appendIntervalSetElements appends the two elements describing the range
// start..end to dst: an element keyed by start, and an interval-end element
// keyed by the address after end. dst is returned unchanged when either
// bound is nil or when end has no successor.
func appendIntervalSetElements(dst []nftables.SetElement, start, end net.IP) []nftables.SetElement {
	if start == nil || end == nil {
		return dst
	}
	// The interval end marker is exclusive. An all-ones end has no successor,
	// so encoding it would either wrap or widen the range.
	endMarker, hasSuccessor := nextIPSafe(end)
	if !hasSuccessor {
		return dst
	}
	lower := nftables.SetElement{Key: start}
	upper := nftables.SetElement{Key: endMarker, IntervalEnd: true}
	return append(dst, lower, upper)
}
// UpdateInfraResolved stores the IPs most recently resolved for an infra
// hostname, replacing whatever was recorded for that host so a per-tick
// refresh leaves no stale addresses behind.
//
// Each IP is normalised through net.ParseIP, which stores the canonical form
// and discards malformed values. If no usable IP remains (including an empty
// ips slice, e.g. when DNS stopped resolving) the host is removed entirely.
// An empty host is ignored.
func (e *Engine) UpdateInfraResolved(host string, ips []string) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if host == "" {
		return
	}
	if e.infraResolved == nil {
		e.infraResolved = make(map[string]map[string]struct{})
	}

	canonical := make(map[string]struct{}, len(ips))
	for _, raw := range ips {
		parsed := net.ParseIP(raw)
		if parsed == nil {
			continue
		}
		canonical[parsed.String()] = struct{}{}
	}

	if len(canonical) == 0 {
		delete(e.infraResolved, host)
		return
	}
	e.infraResolved[host] = canonical
}
// DropInfraResolved forgets every IP recorded for host. It behaves exactly
// like UpdateInfraResolved(host, nil); the separate name makes the intent
// explicit at call sites that deliberately retire a hostname.
func (e *Engine) DropInfraResolved(host string) {
	var none []string
	e.UpdateInfraResolved(host, none)
}
// infraIPResolvedHostLocked looks ip up among the addresses recorded for all
// tracked infra hostnames and returns the owning host when found.
//
// ip is normalised with net.ParseIP before the lookup, matching the
// canonical form UpdateInfraResolved stores (IPv6 spellings collapse and
// ::ffff:1.2.3.4 becomes 1.2.3.4), so an IPv4-mapped or otherwise
// non-canonical spelling still hits the guard. An unparsable ip never
// matches.
//
// Must be called with e.mu held; the blockIPTarget path already does so.
func (e *Engine) infraIPResolvedHostLocked(ip string) (string, bool) {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return "", false
	}
	canon := parsed.String()
	// Ranging over a nil map runs zero iterations, so an unset infraResolved
	// falls through to the "not found" result.
	for host, addrs := range e.infraResolved {
		if _, listed := addrs[canon]; listed {
			return host, true
		}
	}
	return "", false
}
// localAddrsCacheTTL is the maximum age of the cached set of host-own IPs
// before the next block call rebuilds it.
//
// The cost of caching: a local address assigned after the last rebuild is
// not yet guarded, so it could be auto-blocked for up to this window if a
// flagged source shares it. Local address changes (an operator running
// `ip addr add`) are rare, so one minute keeps both that false-positive
// window and the per-block syscall cost negligible. Rebuilding on every miss
// instead would keep the cache permanently fresh, at full lookup cost, under
// any scan storm.
const localAddrsCacheTTL = time.Minute
// refreshLocalAddrsLocked rebuilds the cached set of the host's own interface
// addresses once the cache TTL has lapsed; while the cache is still fresh it
// returns immediately.
//
// Addresses come from the localAddrsLookup hook when it is set, otherwise
// from net.InterfaceAddrs. Each candidate is filtered through
// localAddrGuardKey. If the lookup fails, the previous cache and its expiry
// are left untouched so a transient failure cannot weaken the lockout guard.
//
// Must be called with e.mu held.
func (e *Engine) refreshLocalAddrsLocked() {
	stillFresh := e.localAddrs != nil &&
		!e.localAddrsExpiresAt.IsZero() &&
		time.Now().Before(e.localAddrsExpiresAt)
	if stillFresh {
		return
	}

	var candidates []string
	if lookup := e.localAddrsLookup; lookup != nil {
		got, err := lookup()
		if err != nil {
			return
		}
		candidates = got
	} else {
		addrs, err := net.InterfaceAddrs()
		if err != nil {
			return
		}
		for _, addr := range addrs {
			// Only *net.IPNet entries carry an IP to record.
			if ipnet, isIPNet := addr.(*net.IPNet); isIPNet {
				candidates = append(candidates, ipnet.IP.String())
			}
		}
	}

	guarded := make(map[string]struct{}, len(candidates))
	for _, raw := range candidates {
		if key, keep := localAddrGuardKey(raw); keep {
			guarded[key] = struct{}{}
		}
	}
	e.localAddrs = guarded
	e.localAddrsExpiresAt = time.Now().Add(localAddrsCacheTTL)
}
// localAddrGuardKey converts a raw interface address into the key used by
// the local-address guard: the canonical text form of the parsed IP.
// It returns ("", false) for values that do not parse and for loopback,
// link-local unicast and link-local multicast addresses, which are excluded
// from the guard.
func localAddrGuardKey(raw string) (string, bool) {
	parsed := net.ParseIP(raw)
	switch {
	case parsed == nil:
		return "", false
	case parsed.IsLoopback(), parsed.IsLinkLocalUnicast(), parsed.IsLinkLocalMulticast():
		return "", false
	}
	return parsed.String(), true
}
// isLocalAddrLocked reports whether ip is one of the host's own addresses as
// recorded in the local-address cache, refreshing that cache first when it
// is due. An unparsable ip is never local. Must be called with e.mu held.
func (e *Engine) isLocalAddrLocked(ip string) bool {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return false
	}
	e.refreshLocalAddrsLocked()
	// A lookup in a nil or empty map simply misses, so no length check is
	// needed before it.
	_, own := e.localAddrs[parsed.String()]
	return own
}
// BlockedCount returns the number of live blocked IP entries the engine
// is enforcing. It reads the same state snapshot Status() uses, so
// `/api/v1/status` and `csm firewall status` agree on the number; expired
// entries have already been pruned by loadStateFile. The IPv6 flag passed
// to countBlockedRules is false when no config is attached.
func (e *Engine) BlockedCount() int {
	e.mu.Lock()
	defer e.mu.Unlock()

	snapshot := e.loadStateFile()
	ipv6Enabled := e.cfg != nil && e.cfg.IPv6
	return countBlockedRules(snapshot.Blocked, ipv6Enabled)
}
// RuleCounts returns the cardinality of every firewall rule category
// taken from the engine state file with expired temp bans pruned. Callers
// that need a live count (e.g. Prometheus gauges) must use this rather
// than the bbolt store, which holds only the migration-time snapshot.
func (e *Engine) RuleCounts() RuleCounts {
	e.mu.Lock()
	defer e.mu.Unlock()

	snapshot := e.loadStateFile()
	ipv6Enabled := e.cfg != nil && e.cfg.IPv6
	return countRuleEntries(snapshot, ipv6Enabled)
}
// Status returns current firewall statistics as a key/value map: the
// configured port lists and flags from e.cfg plus the number of blocked
// and allowed entries in the pruned state snapshot.
//
// e.mu is held for the whole call because loadStateFile mutates the
// shared state cache and its index maps; before that cache existed the
// read was lock-free since every call did its own ReadFile + Unmarshal.
func (e *Engine) Status() map[string]interface{} {
	e.mu.Lock()
	defer e.mu.Unlock()

	snapshot := e.loadStateFile()
	cfg := e.cfg

	stats := make(map[string]interface{}, 9)
	stats["enabled"] = cfg.Enabled
	stats["tcp_in"] = cfg.TCPIn
	stats["tcp_out"] = cfg.TCPOut
	stats["udp_in"] = cfg.UDPIn
	stats["udp_out"] = cfg.UDPOut
	stats["infra_ips"] = cfg.InfraIPs
	stats["blocked"] = len(snapshot.Blocked)
	stats["allowed"] = len(snapshot.Allowed)
	stats["log_dropped"] = cfg.LogDropped
	return stats
}
// --- State persistence ---
// initialBlockState is the pre-computed pool of nft set elements to
// seed a freshly-built csm table from persisted state. Populated by
// computeInitialBlockStateLocked and consumed by
// queueInitialBlockStateLocked. Each pair holds the IPv4 and IPv6
// elements for one set; the IPv6 halves are only filled when the engine
// config enables IPv6.
type initialBlockState struct {
// Single-address block entries (temp bans carry a Timeout).
blocked4, blocked6 []nftables.SetElement
// Single-address allow entries.
allowed4, allowed6 []nftables.SetElement
// Interval elements built from blocked CIDR ranges.
blockedNet4, blockedNet6 []nftables.SetElement
}
// computeInitialBlockStateLocked reads the persisted firewall state and
// returns the nft elements needed to repopulate the blocked / allowed /
// blocked-net sets. Pure computation; does not touch nft. Safe to call
// from Apply before any AddTable / AddSet so the result can be queued
// into the same atomic netlink batch. Must be called with e.mu held.
//
// Expired entries are skipped, duplicate addresses (by canonical key) are
// emitted once, and IPv6 elements are produced only when e.cfg.IPv6 is
// set.
func (e *Engine) computeInitialBlockStateLocked() initialBlockState {
	state := e.loadStateFile()
	now := time.Now()
	var ibs initialBlockState

	restoredBlocked := make(map[string]bool)
	for _, entry := range state.Blocked {
		// Permanent entries are emitted with Timeout == 0. A temp ban must
		// therefore never end up with a zero (or negative) timeout, or it
		// would be restored as if it were permanent. Deriving the
		// remaining time from the same `now` used for the expiry decision
		// (instead of a second clock read) makes the two checks agree.
		var timeout time.Duration
		if !entry.ExpiresAt.IsZero() {
			timeout = entry.ExpiresAt.Sub(now)
			if timeout <= 0 {
				continue // already expired (or expiring right now)
			}
		}
		key, ok := canonicalIPKey(entry.IP)
		if !ok {
			continue
		}
		if restoredBlocked[key] {
			continue
		}
		parsed := net.ParseIP(entry.IP)
		if parsed == nil {
			// Defensive: never emit an element with a nil key.
			continue
		}
		if ip4 := parsed.To4(); ip4 != nil {
			ibs.blocked4 = append(ibs.blocked4, nftables.SetElement{Key: ip4, Timeout: timeout})
		} else if e.cfg.IPv6 {
			ibs.blocked6 = append(ibs.blocked6, nftables.SetElement{Key: parsed.To16(), Timeout: timeout})
		}
		restoredBlocked[key] = true
	}

	restoredAllowed := make(map[string]bool)
	for _, entry := range state.Allowed {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			continue
		}
		key, ok := canonicalIPKey(entry.IP)
		if !ok {
			continue
		}
		if restoredAllowed[key] {
			continue
		}
		parsed := net.ParseIP(entry.IP)
		if parsed == nil {
			// Defensive: never emit an element with a nil key.
			continue
		}
		if ip4 := parsed.To4(); ip4 != nil {
			ibs.allowed4 = append(ibs.allowed4, nftables.SetElement{Key: ip4})
		} else if e.cfg.IPv6 {
			ibs.allowed6 = append(ibs.allowed6, nftables.SetElement{Key: parsed.To16()})
		}
		restoredAllowed[key] = true
	}

	for _, entry := range state.BlockedNet {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			continue
		}
		_, network, err := net.ParseCIDR(entry.CIDR)
		if err != nil {
			continue
		}
		end := lastIPInRange(network)
		if end == nil {
			continue
		}
		// Each CIDR becomes interval elements spanning [network, end].
		if start := network.IP.To4(); start != nil {
			ibs.blockedNet4 = appendIntervalSetElements(ibs.blockedNet4, start, end)
		} else if e.cfg.IPv6 {
			ibs.blockedNet6 = appendIntervalSetElements(ibs.blockedNet6, network.IP.To16(), end)
		}
	}
	return ibs
}
// queueInitialBlockStateLocked queues the previously computed elements
// into the still-pending Apply netlink batch, one set at a time and in a
// fixed order (blocked, allowed, blocked-net; v4 before v6). It does not
// Flush: Apply flushes the whole batch as one transaction. The first
// failure aborts and is returned.
func (e *Engine) queueInitialBlockStateLocked(ibs initialBlockState) error {
	batches := []struct {
		set      *nftables.Set
		elems    []nftables.SetElement
		optional bool // the set may be nil; skip it when it is
	}{
		{e.setBlocked, ibs.blocked4, false},
		{e.setBlocked6, ibs.blocked6, true},
		{e.setAllowed, ibs.allowed4, false},
		{e.setAllowed6, ibs.allowed6, true},
		{e.setBlockedNet, ibs.blockedNet4, false},
		{e.setBlockedNet6, ibs.blockedNet6, true},
	}
	for _, b := range batches {
		if b.optional && b.set == nil {
			continue
		}
		if err := e.addElementsChunked(b.set, b.elems); err != nil {
			return err
		}
	}
	return nil
}
// addElementsChunked issues SetAddElements in fixed-size chunks. A single
// SetAddElements call encodes all elements into one netlink message whose
// size scales linearly with len(elems); the kernel's netlink socket rmem
// (typically 208 KB, tunable via net.core.rmem_max) caps how big that
// message can be before the receive path refuses it with ENOBUFS. At
// ~28 bytes per element worst-case a 1000-element chunk is ~28 KB, well
// under the default rmem and comfortably below any realistic rmem_max.
//
// The chunk size must stay even so interval sets (blocked_net, where each
// CIDR expands to a consecutive {start, IntervalEnd} pair) never split a
// pair across chunks.
//
// On the first failing chunk the error is logged via logNftSetOpErr and
// returned wrapped with the set name and chunk bounds.
func (e *Engine) addElementsChunked(s *nftables.Set, elems []nftables.SetElement) error {
	const chunk = 1000
	total := len(elems)
	for start := 0; start < total; start += chunk {
		stop := start + chunk
		if stop > total {
			stop = total
		}
		err := e.conn.SetAddElements(s, elems[start:stop])
		if err == nil {
			continue
		}
		logNftSetOpErr(
			fmt.Sprintf("add elements to set %q chunk %d-%d", s.Name, start, stop),
			"initial restore",
			err,
		)
		return fmt.Errorf("adding initial elements to set %q chunk %d-%d: %w", s.Name, start, stop, err)
	}
	return nil
}
// logNftSetOpErr writes one line to stderr describing a failed nft set
// operation. Every such line shares the same "firewall: nft ... failed"
// shape so operator-visible logs stay grep-friendly.
func logNftSetOpErr(op, target string, err error) {
	line := fmt.Sprintf("firewall: nft %s for %s failed: %v\n", op, target, err)
	// Best-effort logging: a failed stderr write has nowhere to be reported.
	_, _ = os.Stderr.WriteString(line)
}
// loadStateFile returns a copy of the cached firewall state with expired
// entries pruned. ensureStateCacheLocked decides whether the on-disk
// state.json has to be re-read (file metadata key changed, or the cache
// is empty). If no cache exists afterwards a zero FirewallState is
// returned.
//
// All callers must hold e.mu. The four entry slices of the result are
// fresh copies, so mutating them does not affect the cache; mutators
// write back via saveState, which rebuilds the cache from the struct it
// is given.
//
// Hot read paths (IsBlocked, IsSubnetBlocked, IsAllowed) bypass this
// allocation by consulting the index maps directly.
func (e *Engine) loadStateFile() FirewallState {
	e.ensureStateCacheLocked()

	var snapshot FirewallState
	if cached := e.stateCache; cached != nil {
		snapshot.Blocked = append([]BlockedEntry(nil), cached.Blocked...)
		snapshot.BlockedNet = append([]SubnetEntry(nil), cached.BlockedNet...)
		snapshot.Allowed = append([]AllowedEntry(nil), cached.Allowed...)
		snapshot.PortAllowed = append([]PortAllowEntry(nil), cached.PortAllowed...)
	}
	return snapshot
}
// loadStateFileRawLocked reads and decodes state.json directly, without
// pruning expired entries and without touching the cache. Expiry cleanup
// needs the stale rows so it can remove the matching nftables elements
// before writing the active state back.
//
// The boolean is true when the state is usable: either the file decoded
// cleanly, or it does not exist (empty state). Any other read error or a
// JSON decode error yields an empty state and false.
func (e *Engine) loadStateFileRawLocked() (FirewallState, bool) {
	raw, readErr := os.ReadFile(filepath.Join(e.statePath, "state.json")) // #nosec G304 -- filepath.Join under operator-configured statePath.
	switch {
	case readErr == nil:
		// fall through to decoding
	case os.IsNotExist(readErr):
		return FirewallState{}, true
	default:
		return FirewallState{}, false
	}

	var decoded FirewallState
	if json.Unmarshal(raw, &decoded) != nil {
		return FirewallState{}, false
	}
	return decoded, true
}
// ensureStateCacheLocked populates or refreshes e.stateCache. On a cache
// hit it costs one stat plus an expiry pass. On a miss it performs the
// full ReadFile + json.Unmarshal that the pre-cache implementation did on
// every call.
//
// Outcomes by stat result:
//   - stat ok, cache key matches: keep the cache, prune expired entries.
//   - stat ok, key differs or no cache: reload from disk; a read or decode
//     failure falls back to keepPriorStateCacheLocked.
//   - file does not exist: install an empty state with a zero cache key.
//   - other stat error with a cache present: keep the prior cache rather
//     than dropping all blocks.
//   - other stat error with no cache: install an empty state.
//
// After a (re)load the state is normalised, expired entries are pruned so
// index-backed lookups never report a stale block, and the index maps are
// rebuilt.
func (e *Engine) ensureStateCacheLocked() {
	path := filepath.Join(e.statePath, "state.json")
	info, statErr := os.Stat(path)

	var loaded FirewallState
	if statErr == nil {
		if e.stateCache != nil && e.stateCacheKey.matches(info) {
			e.applyExpiryLocked()
			return
		}
		raw, readErr := os.ReadFile(path) // #nosec G304 -- filepath.Join under operator-configured statePath.
		if readErr != nil {
			e.keepPriorStateCacheLocked()
			return
		}
		if decodeErr := json.Unmarshal(raw, &loaded); decodeErr != nil {
			e.keepPriorStateCacheLocked()
			return
		}
		e.stateCacheKey = stateFileCacheKeyFromInfo(info)
	} else if os.IsNotExist(statErr) {
		e.stateCacheKey = stateFileCacheKey{}
	} else if e.stateCache != nil {
		// Transient stat error (permission, EIO): keep what we have.
		e.applyExpiryLocked()
		return
	}

	normalizeFirewallStateIPs(&loaded)
	e.stateCache = &loaded
	e.applyExpiryLocked()
	e.rebuildIndexLocked()
}
// keepPriorStateCacheLocked is the fallback used when state.json could
// not be read or decoded. An existing cache is kept (after pruning
// expired entries); without one, an empty state with a zero cache key is
// installed and the index maps are rebuilt.
func (e *Engine) keepPriorStateCacheLocked() {
	if e.stateCache == nil {
		e.stateCache = &FirewallState{}
		e.stateCacheKey = stateFileCacheKey{}
		e.rebuildIndexLocked()
		return
	}
	e.applyExpiryLocked()
}
// applyExpiryLocked prunes expired blocked, blocked-net and allowed
// entries from the cached state in place. It returns nothing; when at
// least one entry was dropped the index maps are rebuilt. A nil cache is
// a no-op.
func (e *Engine) applyExpiryLocked() {
	cache := e.stateCache
	if cache == nil {
		return
	}
	now := time.Now()

	blocked, blockedDropped := pruneBlocked(cache.Blocked, now)
	if blockedDropped {
		cache.Blocked = blocked
	}
	nets, netsDropped := pruneBlockedNet(cache.BlockedNet, now)
	if netsDropped {
		cache.BlockedNet = nets
	}
	allowed, allowedDropped := pruneAllowed(cache.Allowed, now)
	if allowedDropped {
		cache.Allowed = allowed
	}

	if blockedDropped || netsDropped || allowedDropped {
		e.rebuildIndexLocked()
	}
}
// rebuildIndexLocked regenerates the three lookup maps from the cached
// state and must be called whenever e.stateCache is mutated. With no
// cache all three maps are reset to nil.
//
// Blocked and allowed addresses are indexed under both their stored
// spelling and, when canonicalIPKey accepts them, their canonical form.
// The blocked index maps to the entry's position in the Blocked slice.
// Subnets are indexed by their CIDR string as stored.
func (e *Engine) rebuildIndexLocked() {
	cache := e.stateCache
	if cache == nil {
		e.blockedIPIndex, e.allowedIPIndex, e.blockedCIDRIndex = nil, nil, nil
		return
	}

	byBlockedIP := make(map[string]int, len(cache.Blocked))
	for pos := range cache.Blocked {
		raw := cache.Blocked[pos].IP
		byBlockedIP[raw] = pos
		if canon, ok := canonicalIPKey(raw); ok {
			byBlockedIP[canon] = pos
		}
	}

	allowedSet := make(map[string]struct{}, len(cache.Allowed))
	for pos := range cache.Allowed {
		raw := cache.Allowed[pos].IP
		allowedSet[raw] = struct{}{}
		if canon, ok := canonicalIPKey(raw); ok {
			allowedSet[canon] = struct{}{}
		}
	}

	cidrSet := make(map[string]struct{}, len(cache.BlockedNet))
	for pos := range cache.BlockedNet {
		cidrSet[cache.BlockedNet[pos].CIDR] = struct{}{}
	}

	e.blockedIPIndex = byBlockedIP
	e.allowedIPIndex = allowedSet
	e.blockedCIDRIndex = cidrSet
}
// pruneBlocked returns the blocked-IP list without its expired entries,
// plus a flag telling whether anything was dropped. An entry is expired
// when it has a non-zero ExpiresAt that is not after now. When nothing
// expired the input slice itself is returned, so the steady-state hot
// path allocates nothing.
func pruneBlocked(in []BlockedEntry, now time.Time) ([]BlockedEntry, bool) {
	isLive := func(e BlockedEntry) bool {
		return e.ExpiresAt.IsZero() || e.ExpiresAt.After(now)
	}

	live := 0
	for i := range in {
		if isLive(in[i]) {
			live++
		}
	}
	if live == len(in) {
		return in, false
	}

	kept := make([]BlockedEntry, 0, live)
	for i := range in {
		if isLive(in[i]) {
			kept = append(kept, in[i])
		}
	}
	return kept, true
}
// pruneBlockedNet returns the blocked-subnet list without its expired
// entries, plus a flag telling whether anything was dropped. An entry is
// expired when it has a non-zero ExpiresAt that is not after now. When
// nothing expired the input slice itself is returned.
func pruneBlockedNet(in []SubnetEntry, now time.Time) ([]SubnetEntry, bool) {
	isLive := func(e SubnetEntry) bool {
		return e.ExpiresAt.IsZero() || e.ExpiresAt.After(now)
	}

	live := 0
	for i := range in {
		if isLive(in[i]) {
			live++
		}
	}
	if live == len(in) {
		return in, false
	}

	kept := make([]SubnetEntry, 0, live)
	for i := range in {
		if isLive(in[i]) {
			kept = append(kept, in[i])
		}
	}
	return kept, true
}
// pruneAllowed returns the allow list without its expired entries, plus a
// flag telling whether anything was dropped. An entry is expired when it
// has a non-zero ExpiresAt that is not after now. When nothing expired
// the input slice itself is returned.
func pruneAllowed(in []AllowedEntry, now time.Time) ([]AllowedEntry, bool) {
	isLive := func(e AllowedEntry) bool {
		return e.ExpiresAt.IsZero() || e.ExpiresAt.After(now)
	}

	live := 0
	for i := range in {
		if isLive(in[i]) {
			live++
		}
	}
	if live == len(in) {
		return in, false
	}

	kept := make([]AllowedEntry, 0, live)
	for i := range in {
		if isLive(in[i]) {
			kept = append(kept, in[i])
		}
	}
	return kept, true
}
// writeFirewallStateJSON is the writer saveState calls to persist
// state.json; it defaults to atomicio.AtomicWriteJSON.
// NOTE(review): held in a package-level variable presumably so tests can
// substitute a failing writer — confirm against the test suite.
var writeFirewallStateJSON = atomicio.AtomicWriteJSON
// saveState persists the firewall state to state.json through
// writeFirewallStateJSON (atomic write: .tmp then rename) and rebuilds
// the in-memory cache from the snapshot that was written. Callers must
// hold e.mu.
//
// The input is copied and normalised before use, so a caller that keeps
// mutating its local FirewallState after saveState returns cannot corrupt
// the cache.
//
// If the writer reports an error but the file on disk already matches the
// snapshot (content and 0600 mode), the write is treated as committed: a
// warning is logged and nil is returned. Otherwise the cache is cleared
// and the write error is returned.
func (e *Engine) saveState(s *FirewallState) error {
	snapshot := copyFirewallState(*s)
	normalizeFirewallStateIPs(&snapshot)
	target := filepath.Join(e.statePath, "state.json")

	writeErr := writeFirewallStateJSON(target, 0o600, &snapshot)
	switch {
	case writeErr == nil:
		e.setStateCacheLocked(target, &snapshot)
		return nil
	case firewallStateFileMatches(target, 0o600, &snapshot):
		fmt.Fprintf(os.Stderr, "firewall: state.json committed with persistence warning: %v\n", writeErr)
		e.setStateCacheLocked(target, &snapshot)
		return nil
	default:
		fmt.Fprintf(os.Stderr, "firewall: persist state.json failed: %v\n", writeErr)
		e.clearStateCacheLocked()
		return writeErr
	}
}
// firewallStateFileMatches reports whether the file at path has exactly
// the permission bits perm and contains exactly the indented JSON
// encoding of s. Any stat, marshal or read failure counts as "does not
// match".
func firewallStateFileMatches(path string, perm os.FileMode, s *FirewallState) bool {
	st, statErr := os.Stat(path)
	if statErr != nil {
		return false
	}
	if st.Mode().Perm() != perm {
		return false
	}

	expected, marshalErr := json.MarshalIndent(s, "", " ")
	if marshalErr != nil {
		return false
	}

	onDisk, readErr := os.ReadFile(path) // #nosec G304 -- path is the engine-owned state file path.
	if readErr != nil {
		return false
	}
	return bytes.Equal(onDisk, expected)
}
// setStateCacheLocked installs a normalised private copy of s as the
// in-memory state cache, records the cache key derived from the current
// metadata of path (a zero key if the stat fails), and rebuilds the index
// maps. Must be called with e.mu held.
func (e *Engine) setStateCacheLocked(path string, s *FirewallState) {
	key := stateFileCacheKey{}
	if st, err := os.Stat(path); err == nil {
		key = stateFileCacheKeyFromInfo(st)
	}

	normalized := copyFirewallState(*s)
	normalizeFirewallStateIPs(&normalized)

	cached := new(FirewallState)
	cached.Blocked = append([]BlockedEntry(nil), normalized.Blocked...)
	cached.BlockedNet = append([]SubnetEntry(nil), normalized.BlockedNet...)
	cached.Allowed = append([]AllowedEntry(nil), normalized.Allowed...)
	cached.PortAllowed = append([]PortAllowEntry(nil), normalized.PortAllowed...)

	e.stateCache = cached
	e.stateCacheKey = key
	e.rebuildIndexLocked()
}
// clearStateCacheLocked drops the in-memory state cache and its cache
// key, then rebuilds the index maps (which resets them to nil because no
// cache remains). The next ensureStateCacheLocked call reloads from disk.
func (e *Engine) clearStateCacheLocked() {
	e.stateCache, e.stateCacheKey = nil, stateFileCacheKey{}
	e.rebuildIndexLocked()
}
// copyFirewallState returns a FirewallState whose Blocked, BlockedNet,
// Allowed and PortAllowed slices are fresh copies of those in s, so the
// result can be mutated without touching s's backing arrays.
func copyFirewallState(s FirewallState) FirewallState {
	var dup FirewallState
	dup.Blocked = append([]BlockedEntry(nil), s.Blocked...)
	dup.BlockedNet = append([]SubnetEntry(nil), s.BlockedNet...)
	dup.Allowed = append([]AllowedEntry(nil), s.Allowed...)
	dup.PortAllowed = append([]PortAllowEntry(nil), s.PortAllowed...)
	return dup
}
// upsertBlockedEntryInState inserts entry into state.Blocked, replacing
// any existing entries for the same IP. The entry's IP is canonicalised
// first when canonicalIPKey accepts it. The replacement takes the
// position of the first match, later duplicates are dropped, and an
// unmatched entry is appended. The slice is filtered in place, reusing
// its backing array.
func upsertBlockedEntryInState(state *FirewallState, entry BlockedEntry) {
	if canonical, ok := canonicalIPKey(entry.IP); ok {
		entry.IP = canonical
	}

	kept := state.Blocked[:0]
	placed := false
	for _, current := range state.Blocked {
		switch {
		case !sameIPString(current.IP, entry.IP):
			kept = append(kept, current)
		case !placed:
			kept = append(kept, entry)
			placed = true
		}
		// A further match after the replacement is a duplicate: drop it.
	}
	if !placed {
		kept = append(kept, entry)
	}
	state.Blocked = kept
}
// removeBlockedIPFromState deletes every entry in state.Blocked whose IP
// matches ip according to sameIPString. The slice is compacted in place,
// reusing its backing array.
func removeBlockedIPFromState(state *FirewallState, ip string) {
	write := 0
	for _, entry := range state.Blocked {
		if !sameIPString(entry.IP, ip) {
			state.Blocked[write] = entry
			write++
		}
	}
	state.Blocked = state.Blocked[:write]
}
// upsertAllowedEntryInState inserts entry into state.Allowed, replacing
// existing entries that share both its IP and its Source; entries for the
// same IP from a different source are kept. The entry's IP is
// canonicalised first when canonicalIPKey accepts it. The replacement
// takes the position of the first match, later duplicates are dropped,
// and an unmatched entry is appended. The slice is filtered in place.
func upsertAllowedEntryInState(state *FirewallState, entry AllowedEntry) {
	if canonical, ok := canonicalIPKey(entry.IP); ok {
		entry.IP = canonical
	}

	kept := state.Allowed[:0]
	placed := false
	for _, current := range state.Allowed {
		sameSlot := sameIPString(current.IP, entry.IP) && current.Source == entry.Source
		switch {
		case !sameSlot:
			kept = append(kept, current)
		case !placed:
			kept = append(kept, entry)
			placed = true
		}
		// A further match after the replacement is a duplicate: drop it.
	}
	if !placed {
		kept = append(kept, entry)
	}
	state.Allowed = kept
}
// addSubnetEntryIfMissingInState appends entry to state.BlockedNet unless
// an entry with the identical CIDR string is already present. It reports
// whether the entry was added.
func addSubnetEntryIfMissingInState(state *FirewallState, entry SubnetEntry) bool {
	for i := range state.BlockedNet {
		if state.BlockedNet[i].CIDR == entry.CIDR {
			return false
		}
	}
	state.BlockedNet = append(state.BlockedNet, entry)
	return true
}
// isNftNotFound reports whether err is, or wraps, syscall.ENOENT.
func isNftNotFound(err error) bool {
	const notFound = syscall.ENOENT
	return errors.Is(err, notFound)
}
// restoreBlockStateAfterFailureLocked writes state back via saveState
// after a block of ip has failed. A save failure is logged to stderr with
// the affected IP and returned; success returns nil.
func (e *Engine) restoreBlockStateAfterFailureLocked(state FirewallState, ip string) error {
	err := e.saveState(&state)
	if err == nil {
		return nil
	}
	fmt.Fprintf(os.Stderr, "firewall: restore state after failed block for %s failed: %v\n", ip, err)
	return err
}
// saveBlockedEntry upserts entry into the persisted blocked list and
// saves the state. A missing Source is inferred from the reason text, and
// the IP is canonicalised when canonicalIPKey accepts it. Returns the
// saveState error, if any.
func (e *Engine) saveBlockedEntry(entry BlockedEntry) error {
	if entry.Source == "" {
		entry.Source = InferProvenance("block", entry.Reason)
	}
	if canonical, ok := canonicalIPKey(entry.IP); ok {
		entry.IP = canonical
	}

	current := e.loadStateFile()
	upsertBlockedEntryInState(&current, entry)
	return e.saveState(&current)
}
// removeBlockedState drops every persisted blocked entry matching ip and
// saves the result. The saveState error is deliberately discarded here;
// saveState itself logs failures to stderr.
func (e *Engine) removeBlockedState(ip string) {
	current := e.loadStateFile()

	var kept []BlockedEntry
	for i := range current.Blocked {
		if sameIPString(current.Blocked[i].IP, ip) {
			continue
		}
		kept = append(kept, current.Blocked[i])
	}
	current.Blocked = kept

	_ = e.saveState(&current)
}
// saveAllowedEntry upserts entry into the persisted allow list and saves
// the state. A missing Source is inferred from the reason text, and the
// IP is canonicalised when canonicalIPKey accepts it. The saveState error
// is deliberately discarded here; saveState itself logs failures.
func (e *Engine) saveAllowedEntry(entry AllowedEntry) {
	if entry.Source == "" {
		entry.Source = InferProvenance("allow", entry.Reason)
	}
	if canonical, ok := canonicalIPKey(entry.IP); ok {
		entry.IP = canonical
	}

	current := e.loadStateFile()
	upsertAllowedEntryInState(&current, entry)
	_ = e.saveState(&current)
}
// removeAllowedState drops every persisted allow entry matching ip,
// regardless of source, and saves the result. The saveState error is
// deliberately discarded here; saveState itself logs failures.
func (e *Engine) removeAllowedState(ip string) {
	current := e.loadStateFile()

	var kept []AllowedEntry
	for i := range current.Allowed {
		if sameIPString(current.Allowed[i].IP, ip) {
			continue
		}
		kept = append(kept, current.Allowed[i])
	}
	current.Allowed = kept

	_ = e.saveState(&current)
}
// removeAllowedStateBySource removes only the allow entries that match
// both ip and source, then saves the state.
//
// It returns true when a matching entry was removed and no entry for that
// IP remains (the caller should then remove the IP from nftables). It
// returns false when no ip+source entry existed (nothing is saved), and
// also when entries for the IP from other sources are still present.
func (e *Engine) removeAllowedStateBySource(ip, source string) bool {
	current := e.loadStateFile()

	var kept []AllowedEntry
	matched, othersRemain := false, false
	for _, entry := range current.Allowed {
		sameIP := sameIPString(entry.IP, ip)
		if sameIP && entry.Source == source {
			matched = true
			continue
		}
		kept = append(kept, entry)
		othersRemain = othersRemain || sameIP
	}
	if !matched {
		return false // IP+source not in state, nothing to do
	}

	current.Allowed = kept
	_ = e.saveState(&current)
	return !othersRemain
}
// saveSubnetEntry adds entry to the persisted blocked-subnet list unless
// its CIDR is already present, in which case nothing is written. A
// missing Source is inferred from the reason text. The saveState error is
// deliberately discarded here; saveState itself logs failures.
func (e *Engine) saveSubnetEntry(entry SubnetEntry) {
	if entry.Source == "" {
		entry.Source = InferProvenance("block_subnet", entry.Reason)
	}
	current := e.loadStateFile()
	if addSubnetEntryIfMissingInState(&current, entry) {
		_ = e.saveState(&current)
	}
}
// isSubnetBlockedStateLocked reports whether cidr, compared as an exact
// string, is present in the blocked-subnet index. The state cache is
// refreshed first. Must be called with e.mu held.
func (e *Engine) isSubnetBlockedStateLocked(cidr string) bool {
	e.ensureStateCacheLocked()
	if _, present := e.blockedCIDRIndex[cidr]; present {
		return true
	}
	return false
}
// removeSubnetState drops every persisted blocked-subnet entry whose CIDR
// string equals cidr and saves the result. The saveState error is
// deliberately discarded here; saveState itself logs failures.
func (e *Engine) removeSubnetState(cidr string) {
	current := e.loadStateFile()

	var kept []SubnetEntry
	for i := range current.BlockedNet {
		if current.BlockedNet[i].CIDR == cidr {
			continue
		}
		kept = append(kept, current.BlockedNet[i])
	}
	current.BlockedNet = kept

	_ = e.saveState(&current)
}
// IP helpers (nextIP, lastIPInRange, fileExistsFirewall) moved to ip_helpers.go (no build tag).
// loadCountryCIDRs reads the IPv4 CIDR ranges for one country from
// {dbPath}/{CODE}.cidr (code upper-cased) and converts them to interval
// set elements. The file holds one CIDR per line; blank lines, lines
// starting with '#', unparsable lines and non-IPv4 ranges are skipped.
// A file that cannot be read yields nil.
func loadCountryCIDRs(dbPath, countryCode string) []nftables.SetElement {
	cidrPath := filepath.Join(dbPath, strings.ToUpper(countryCode)+".cidr")
	raw, err := os.ReadFile(cidrPath) // #nosec G304 -- filepath.Join under operator-configured GeoIP dbPath.
	if err != nil {
		return nil
	}

	var out []nftables.SetElement
	for _, rawLine := range strings.Split(string(raw), "\n") {
		text := strings.TrimSpace(rawLine)
		if text == "" || text[0] == '#' {
			continue
		}
		_, ipnet, parseErr := net.ParseCIDR(text)
		if parseErr != nil {
			continue
		}
		first, last := ipnet.IP.To4(), lastIPInRange(ipnet)
		if first == nil || last == nil {
			continue
		}
		out = appendIntervalSetElements(out, first, last)
	}
	return out
}
// loadCountryCIDRs6 reads the IPv6 CIDR ranges for one country from
// {dbPath}/{CODE}.cidr6 (code upper-cased) and converts them to interval
// set elements. Blank lines, lines starting with '#' and unparsable lines
// are skipped, as are any IPv4 CIDRs in the file, so only IPv6 ranges
// reach the IPv6 set. A file that cannot be read yields nil.
func loadCountryCIDRs6(dbPath, countryCode string) []nftables.SetElement {
	cidrPath := filepath.Join(dbPath, strings.ToUpper(countryCode)+".cidr6")
	raw, err := os.ReadFile(cidrPath) // #nosec G304 -- filepath.Join under operator-configured GeoIP dbPath.
	if err != nil {
		return nil
	}

	var out []nftables.SetElement
	for _, rawLine := range strings.Split(string(raw), "\n") {
		text := strings.TrimSpace(rawLine)
		if text == "" || text[0] == '#' {
			continue
		}
		_, ipnet, parseErr := net.ParseCIDR(text)
		if parseErr != nil || ipnet.IP.To4() != nil {
			continue // unparsable, or not an IPv6 range
		}
		first, last := ipnet.IP.To16(), lastIPInRange(ipnet)
		if first == nil || last == nil {
			continue
		}
		out = appendIntervalSetElements(out, first, last.To16())
	}
	return out
}
package firewall
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// Base URLs of the per-country CIDR lists. updateGeoIPDBWithClient appends
// "{cc}.cidr" (lower-case two-letter code) to each base to form the
// download URL.
const (
// geoIPBaseURL is the base for the IPv4 lists, stored locally as {CC}.cidr.
geoIPBaseURL = "https://raw.githubusercontent.com/herrbischoff/country-ip-blocks/master/ipv4/"
// geoIPBaseURLv6 is the base for the IPv6 lists, stored locally as {CC}.cidr6.
geoIPBaseURLv6 = "https://raw.githubusercontent.com/herrbischoff/country-ip-blocks/master/ipv6/"
)
// UpdateGeoIPDB downloads country CIDR lists from a public source using
// an HTTP client with a 30-second timeout. It creates one file per
// country code per family: {dbPath}/{CC}.cidr (IPv4) and
// {dbPath}/{CC}.cidr6 (IPv6). IPv6 is best-effort, so a country with no
// v6 allocation does not fail the update. The returned count is the
// number of CIDR files refreshed.
func UpdateGeoIPDB(dbPath string, countryCodes []string) (int, error) {
	return updateGeoIPDBWithClient(dbPath, countryCodes, &http.Client{Timeout: 30 * time.Second})
}
func updateGeoIPDBWithClient(dbPath string, countryCodes []string, client *http.Client) (int, error) {
if err := os.MkdirAll(dbPath, 0700); err != nil {
return 0, fmt.Errorf("creating geoip directory: %w", err)
}
updated := 0
for _, code := range countryCodes {
code = strings.ToLower(strings.TrimSpace(code))
if len(code) != 2 {
continue
}
cc := strings.ToUpper(code)
if downloadCIDRFile(client, geoIPBaseURL+code+".cidr", filepath.Join(dbPath, cc+".cidr")) {
updated++
}
if downloadCIDRFile(client, geoIPBaseURLv6+code+".cidr", filepath.Join(dbPath, cc+".cidr6")) {
updated++
}
}
return updated, nil
}
// downloadCIDRFile fetches url into outPath atomically: the body is
// streamed to outPath+".tmp" and renamed into place only after it was
// written and closed cleanly and is at least 10 bytes long. It returns
// false (after logging to stderr) on any HTTP error, non-200 status,
// write/close/rename failure, or too-small payload, so the caller can
// treat each address family independently. On failure after the staging
// file was created, the staging file is removed and an existing outPath
// is left untouched.
func downloadCIDRFile(client *http.Client, url, outPath string) bool {
	resp, getErr := client.Get(url)
	if getErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: error downloading %s: %v\n", url, getErr)
		return false
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		fmt.Fprintf(os.Stderr, "geoip: %s returned HTTP %d\n", url, status)
		return false
	}

	staging := outPath + ".tmp"
	out, createErr := os.Create(staging) // #nosec G304 -- filepath.Join under operator-configured dbPath; code from fixed list.
	if createErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: error creating %s: %v\n", staging, createErr)
		return false
	}

	// fail discards the staging file, logs the reason and reports failure.
	fail := func(format string, args ...interface{}) bool {
		_ = os.Remove(staging)
		fmt.Fprintf(os.Stderr, format, args...)
		return false
	}

	written, copyErr := io.Copy(out, resp.Body)
	closeErr := out.Close() // always close, even when the copy failed
	switch {
	case copyErr != nil:
		return fail("geoip: error writing %s: %v\n", staging, copyErr)
	case closeErr != nil:
		return fail("geoip: error closing %s: %v\n", staging, closeErr)
	case written < 10:
		return fail("geoip: %s too small (%d bytes), skipping\n", url, written)
	}

	if renameErr := os.Rename(staging, outPath); renameErr != nil {
		return fail("geoip: error installing %s: %v\n", outPath, renameErr)
	}
	fmt.Fprintf(os.Stderr, "geoip: updated %s (%d bytes)\n", outPath, written)
	return true
}
// LookupIP finds which country CIDR files under dbPath contain the given
// IP and returns the matching country codes (the file names minus their
// suffix). IPv4 addresses, including v4-mapped ones, are matched against
// *.cidr files; all other addresses against *.cidr6 files. It returns nil
// for an unparsable address or an unreadable directory.
func LookupIP(dbPath string, ip string) []string {
	addr := net.ParseIP(ip)
	if addr == nil {
		return nil
	}

	// Default to the IPv6 family, then switch to IPv4 when the address
	// has a 4-byte form.
	suffix, needle := ".cidr6", addr.To16()
	if v4 := addr.To4(); v4 != nil {
		suffix, needle = ".cidr", v4
	}
	if needle == nil {
		return nil
	}

	listing, err := os.ReadDir(dbPath)
	if err != nil {
		return nil
	}

	var matches []string
	for _, item := range listing {
		name := item.Name()
		if item.IsDir() || !strings.HasSuffix(name, suffix) {
			continue
		}
		if containsIP(filepath.Join(dbPath, name), needle) {
			matches = append(matches, strings.TrimSuffix(name, suffix))
		}
	}
	return matches
}
// containsIP reports whether any CIDR listed in cidrFile contains ip. The
// file is scanned line by line; blank lines, lines starting with '#' and
// lines that do not parse as a CIDR are ignored. A file that cannot be
// opened counts as "not contained".
func containsIP(cidrFile string, ip net.IP) bool {
	fh, err := os.Open(cidrFile) // #nosec G304 -- cidrFile is filepath.Join under operator-configured dbPath.
	if err != nil {
		return false
	}
	defer fh.Close()

	for lines := bufio.NewScanner(fh); lines.Scan(); {
		entry := strings.TrimSpace(lines.Text())
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		if _, block, parseErr := net.ParseCIDR(entry); parseErr == nil && block.Contains(ip) {
			return true
		}
	}
	return false
}
package firewall
import (
"net"
"os"
)
// subnetCovering returns the CIDR string of the first entry whose network
// contains ip, and true; otherwise "" and false. Entries whose CIDR does
// not parse are skipped, and an unparsable ip never matches. It is a pure
// helper for BlockedSubnetCovering so the containment logic is testable
// without a kernel-attached engine.
func subnetCovering(entries []SubnetEntry, ip string) (string, bool) {
	target := net.ParseIP(ip)
	if target == nil {
		return "", false
	}
	for i := range entries {
		cidr := entries[i].CIDR
		if _, block, err := net.ParseCIDR(cidr); err == nil && block.Contains(target) {
			return cidr, true
		}
	}
	return "", false
}
// nextIP returns the IP address immediately following ip. It is
// nextIPSafe with the "successor exists" flag discarded: for the all-ones
// address of a family the result clamps to ip instead of wrapping to
// all-zeros. Callers that construct nftables interval ranges should use
// nextIPSafe so saturated ends can be skipped rather than encoded as a
// wrapped interval.
func nextIP(ip net.IP) net.IP {
	successor, _ := nextIPSafe(ip)
	return successor
}
// nextIPSafe returns the successor of ip and whether that successor
// exists within the same address family. The input is never modified;
// the result is a fresh 4-byte (IPv4, including v4-mapped) or 16-byte
// slice. An invalid ip yields (nil, false). The all-ones address has no
// successor and yields (a copy of ip, false).
func nextIPSafe(ip net.IP) (net.IP, bool) {
	succ := canonicalIPBytes(ip)
	if succ == nil {
		return nil, false
	}

	// Propagate the carry through trailing 0xff bytes.
	idx := len(succ) - 1
	for idx >= 0 && succ[idx] == 0xff {
		succ[idx] = 0
		idx--
	}
	if idx < 0 {
		// Every byte was 0xff: saturated, no successor in this family.
		return canonicalIPBytes(ip), false
	}
	succ[idx]++
	return succ, true
}
// canonicalIPBytes returns a fresh copy of ip in its shortest canonical
// byte form: 4 bytes when ip has an IPv4 form (including v4-mapped IPv6),
// otherwise 16 bytes. It returns nil when ip is not a valid address.
func canonicalIPBytes(ip net.IP) net.IP {
	src := ip.To4()
	if src == nil {
		src = ip.To16()
	}
	if src == nil {
		return nil
	}
	dup := make(net.IP, len(src))
	copy(dup, src)
	return dup
}
// lastIPInRange returns the last (highest) IP address in a CIDR range:
// the network address with every host bit set. The result is 4 bytes for
// an IPv4 network and 16 bytes otherwise. It returns nil for a nil
// network or one whose IP is not a valid address.
//
// An IPv4 network may carry a 16-byte mask (the v4-in-v6 form whose first
// 12 bytes are all 0xff). Such a mask is reduced to its low 4 bytes
// before use; applying its leading bytes to the 4-byte address would
// otherwise leave every host bit clear and return the network address.
func lastIPInRange(network *net.IPNet) net.IP {
	if network == nil {
		return nil
	}
	ip := network.IP.To4()
	if ip == nil {
		ip = network.IP.To16()
	}
	if ip == nil {
		return nil
	}

	mask := network.Mask
	if len(ip) == net.IPv4len && len(mask) == net.IPv6len {
		const prefix = net.IPv6len - net.IPv4len
		mapped := true
		for _, b := range mask[:prefix] {
			if b != 0xff {
				mapped = false
				break
			}
		}
		if mapped {
			mask = mask[prefix:]
		}
	}

	last := make(net.IP, len(ip))
	for i := range ip {
		if i < len(mask) {
			last[i] = ip[i] | ^mask[i]
		} else {
			// Mask shorter than the address: leave the byte untouched.
			last[i] = ip[i]
		}
	}
	return last
}
// fileExistsFirewall reports whether os.Stat succeeds for path. Any stat
// error, including a permission error, counts as "does not exist";
// directories count as existing.
func fileExistsFirewall(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
//go:build linux
package firewall
import (
"fmt"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
)
// portFloodIPFamily describes one IP family for per-port flood rules:
// how to match the family and where the source address sits in the
// network header.
type portFloodIPFamily struct {
// name is the family tag appended to the meter set name ("v4"/"v6").
name string
// nfproto is the value the rule compares against the NFPROTO meta key.
nfproto byte
// sourceOffset and sourceLen locate the source address within the
// network header; the loaded bytes become the meter key.
sourceOffset uint32
sourceLen uint32
}
// Family descriptors passed to buildPortFloodExprs and portFloodMeterName.
// The offsets and lengths match the source-address field of the IPv4
// header (bytes 12-15) and the IPv6 header (bytes 8-23).
// NOTE(review): nfproto 2 and 10 are presumably NFPROTO_IPV4 and
// NFPROTO_IPV6 — confirm against the kernel headers.
var (
portFloodIPv4 = portFloodIPFamily{name: "v4", nfproto: 2, sourceOffset: 12, sourceLen: 4}
portFloodIPv6 = portFloodIPFamily{name: "v6", nfproto: 10, sourceOffset: 8, sourceLen: 16}
)
// buildPortFloodExprs returns the nftables expressions for one per-port
// flood-protection rule. The rule rate-limits new TCP/UDP connections to
// pf.Port per source address by updating a dynamic meter set and dropping
// packets that exceed the limit. The caller supplies a family-specific
// meter, so IPv4 and IPv6 do not share buckets.
//
// The rate is pf.Hits per pf.Seconds converted to packets per minute
// (minimum 1), with a burst of a quarter of that rate (minimum 2). Any
// pf.Proto other than "udp" is treated as TCP.
//
// Returning nil signals the caller to skip rule creation (missing meter,
// or a non-positive hit count, interval or port).
func buildPortFloodExprs(pf PortFloodRule, meter *nftables.Set, family portFloodIPFamily) []expr.Any {
	switch {
	case meter == nil, pf.Hits <= 0, pf.Seconds <= 0, pf.Port <= 0:
		return nil
	}

	var l4proto byte = 6 // TCP
	if pf.Proto == "udp" {
		l4proto = 17
	}

	// hits/seconds -> packets per minute (multiply first to keep precision).
	perMinute := uint64(pf.Hits) * 60 / uint64(pf.Seconds)
	if perMinute < 1 {
		perMinute = 1
	}
	burst := uint32(perMinute / 4)
	if burst < 2 {
		burst = 2
	}

	rule := make([]expr.Any, 0, 12)

	// Restrict to the family that matches the meter key type.
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{family.nfproto}},
	)

	// Match new connections only.
	rule = append(rule,
		&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
		&expr.Bitwise{
			SourceRegister: 1, DestRegister: 1, Len: 4,
			Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
			Xor:  binaryutil.NativeEndian.PutUint32(0),
		},
		&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
	)

	// L4 protocol filter (TCP or UDP).
	rule = append(rule,
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{l4proto}},
	)

	// Destination port.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pf.Port))},
	)

	// Load source address into reg1; this is the meter key.
	rule = append(rule,
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: family.sourceOffset, Len: family.sourceLen},
	)

	// Update the meter entry for this source IP, evaluating its own token
	// bucket, then drop whatever is over the limit.
	rule = append(rule,
		&expr.Dynset{
			SrcRegKey: 1,
			SetName:   meter.Name,
			SetID:     meter.ID,
			Operation: 1, // NFT_DYNSET_OP_UPDATE
			Exprs: []expr.Any{
				&expr.Limit{Type: expr.LimitTypePkts, Rate: perMinute, Unit: expr.LimitTimeMinute, Burst: burst, Over: true},
			},
		},
		&expr.Verdict{Kind: expr.VerdictDrop},
	)
	return rule
}
// portFloodMeterName returns the name of the meter set for a flood rule:
// "meter_pf_<proto>_<port>_<family>", where proto is "tcp" or "udp" as
// resolved by portFloodProto.
func portFloodMeterName(pf PortFloodRule, family portFloodIPFamily) string {
	proto := portFloodProto(pf)
	return fmt.Sprintf("meter_pf_%s_%d_%s", proto, pf.Port, family.name)
}
// portFloodProto returns the protocol label for a flood rule: "udp" when
// pf.Proto is exactly "udp", and "tcp" for every other value, including
// the empty string.
func portFloodProto(pf PortFloodRule) string {
	switch pf.Proto {
	case "udp":
		return "udp"
	default:
		return "tcp"
	}
}
package firewall
import "strings"
// Provenance labels for firewall entries. InferProvenance returns all of
// these except SourceUnknown's siblings that are only set explicitly by
// callers; SourceUnknown is its fallback when nothing matches.
const (
// SourceUnknown is returned when neither reason nor action is recognised.
SourceUnknown = "unknown"
// SourceWebUI marks actions taken through the web UI.
SourceWebUI = "web_ui"
// SourceCLI marks actions taken "via cli".
SourceCLI = "cli"
// SourceAutoResponse marks automatic blocks (auto-block, permblock, auto-netblock).
SourceAutoResponse = "auto_response"
// SourceChallenge marks entries produced by the challenge flow (passed or timed out).
SourceChallenge = "challenge"
// SourceWhitelist marks whitelist / customer-IP entries.
SourceWhitelist = "whitelist"
// SourceDynDNS marks entries whose reason carries a "dyndns:" tag.
SourceDynDNS = "dyndns"
// SourceSystem marks engine housekeeping (temp-allow expiry, flush).
SourceSystem = "system"
)
// InferProvenance classifies the source of a firewall entry from its
// structured action and free-text reason. Both inputs are trimmed and
// lower-cased before matching. Keeping this logic in one place avoids
// spreading fragile string checks through the web UI and firewall call
// sites.
//
// Reason substrings are checked in a fixed priority order (dyndns,
// challenge, whitelist, auto-response, CLI, web UI); the first group with
// a matching substring wins. If no reason matches, the actions
// "temp_allow_expired" and "flush" map to SourceSystem; anything else is
// SourceUnknown.
func InferProvenance(action, reason string) string {
	action = strings.ToLower(strings.TrimSpace(action))
	reason = strings.ToLower(strings.TrimSpace(reason))

	// Order matters: earlier groups take precedence over later ones.
	reasonRules := []struct {
		source  string
		needles []string
	}{
		{SourceDynDNS, []string{"dyndns:"}},
		{SourceChallenge, []string{"passed challenge", "challenge-timeout", "challenge timeout"}},
		{SourceWhitelist, []string{"temp whitelist", "whitelist", "bulk whitelist", "customer ip"}},
		{SourceAutoResponse, []string{"auto-block", "permbblock", "permblock", "auto-netblock"}},
		{SourceCLI, []string{"via cli"}},
		{SourceWebUI, []string{"via csm web ui", "via ui", "allowed from firewall lookup", "manual block"}},
	}
	for _, rule := range reasonRules {
		for _, needle := range rule.needles {
			if strings.Contains(reason, needle) {
				return rule.source
			}
		}
	}

	if action == "temp_allow_expired" || action == "flush" {
		return SourceSystem
	}
	return SourceUnknown
}
// Package rollback implements the firewall settings tentative-apply
// workflow: a save with a deadline that auto-reverts unless the operator
// confirms before the timer expires. The manager survives daemon restarts
// (state is persisted in bbolt) so that the apply itself can take down the
// daemon for a config reload without losing the rollback intent.
package rollback
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"sync"
"time"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/store"
)
// MinTimeout and MaxTimeout bound the operator-supplied window. The lower
// bound exists so a misclick cannot leave the operator one second to react;
// the upper bound caps how long a botched apply can sit on disk before
// auto-recovery kicks in.
//
// DefaultTimeout is the window clampTimeout substitutes when the caller
// passes a zero duration.
const (
MinTimeout = 1 * time.Minute
MaxTimeout = 30 * time.Minute
DefaultTimeout = 5 * time.Minute
)
// Restarter performs a daemon restart. Production wires this to a
// systemctl exec; tests substitute a fake.
//
// A non-nil error is wrapped and returned by applyRevertLocked, which then
// leaves the pending rollback record in place. A nil Restarter on the
// Manager means the restart step is skipped.
type Restarter func(ctx context.Context) error
// Status describes the current pending rollback for status APIs and the
// Web UI banner. AppliedAt and ExpiresAt are UTC; SecondsRemaining is a
// derived hint computed at call time.
//
// The zero value means "nothing pending"; statusFromRecord is the only
// constructor for a populated value and always sets Pending.
//
// NOTE(review): encoding/json's omitempty does not omit struct values such
// as time.Time, so applied_at/expires_at are presumably still emitted (as
// the zero time) when nothing is pending -- confirm clients tolerate that.
type Status struct {
// Pending is true whenever a rollback record exists.
Pending bool `json:"pending"`
// AppliedAt is when the tentative apply was recorded.
AppliedAt time.Time `json:"applied_at,omitempty"`
// ExpiresAt is the deadline after which the auto-revert fires.
ExpiresAt time.Time `json:"expires_at,omitempty"`
// SecondsRemaining is ExpiresAt minus "now" in whole seconds, floored at 0.
SecondsRemaining int64 `json:"seconds_remaining,omitempty"`
// AppliedBy is the caller-supplied audit label passed to Apply.
AppliedBy string `json:"applied_by,omitempty"`
// PrevHash and NewHash are HashYAML digests of the config before and
// after the apply.
PrevHash string `json:"prev_hash,omitempty"`
NewHash string `json:"new_hash,omitempty"`
}
// Manager owns the active timer and serialises Apply/Confirm/Revert/Recover
// against each other. Storage and restart are injected so the manager can be
// driven in tests without a real bbolt or systemctl.
type Manager struct {
// mu guards timer and serialises every operation that reads or writes
// the persisted rollback record.
mu sync.Mutex
// db persists the pending rollback record across daemon restarts.
db *store.DB
// configPath is the file the snapshot is written back to on revert.
configPath string
// restart triggers the daemon restart after a revert; nil skips it.
restart Restarter
// now is the clock; NewManager defaults it to time.Now.
now func() time.Time
// timer is the armed auto-revert timer, or nil when none is armed.
timer *time.Timer
}
// Process-wide singleton. The daemon installs one Manager at startup; the
// Web UI handlers and the control-socket commands look it up by calling
// Global so they do not need to thread the manager through every layer.
var (
// globalMu guards global.
globalMu sync.Mutex
// global is the installed manager; nil until SetGlobal is called.
global *Manager
)
// SetGlobal publishes m as the process-wide manager returned by Global.
// It is intended for daemon startup; a later call simply replaces the
// previously installed manager (passing nil uninstalls it).
func SetGlobal(m *Manager) {
	globalMu.Lock()
	defer globalMu.Unlock()
	global = m
}
// Global returns the manager installed by SetGlobal, or nil when none has
// been installed (for example in CLI commands that load config without
// starting the daemon). Callers must nil-check the result.
func Global() *Manager {
	globalMu.Lock()
	installed := global
	globalMu.Unlock()
	return installed
}
// NewManager builds a Manager around the given store, config path and
// restart hook. Pass SystemctlRestart as restart in production. A nil now
// selects time.Now as the clock.
func NewManager(db *store.DB, configPath string, restart Restarter, now func() time.Time) *Manager {
	mgr := &Manager{
		db:         db,
		configPath: configPath,
		restart:    restart,
		now:        time.Now,
	}
	if now != nil {
		mgr.now = now
	}
	return mgr
}
// SystemctlRestart runs `systemctl restart csm.service`. ctx bounds how long
// we wait for systemctl itself; the daemon restart it triggers is
// asynchronous from systemctl's perspective. On failure the returned error
// wraps the exec error and includes systemctl's combined output.
func SystemctlRestart(ctx context.Context) error {
	// #nosec G204 -- fixed argv, no operator input interpolated.
	cmd := exec.CommandContext(ctx, "systemctl", "restart", "csm.service")
	output, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}
	return fmt.Errorf("systemctl restart csm: %w (%s)", err, string(output))
}
// HashYAML returns "sha256:" followed by the lowercase hex SHA-256 digest of
// data. The Web UI and CLI show this before/after fingerprint to operators
// instead of the full file contents.
func HashYAML(data []byte) string {
	hasher := sha256.New()
	// hash.Hash documents that Write never returns an error.
	_, _ = hasher.Write(data)
	return "sha256:" + hex.EncodeToString(hasher.Sum(nil))
}
// clampTimeout forces d into [MinTimeout, MaxTimeout]. Exactly zero is
// treated as "unset" and becomes DefaultTimeout; any other value below the
// minimum (negative durations included) is raised to MinTimeout and any
// value above the maximum is lowered to MaxTimeout.
func clampTimeout(d time.Duration) time.Duration {
	switch {
	case d == 0:
		return DefaultTimeout
	case d < MinTimeout:
		return MinTimeout
	case d > MaxTimeout:
		return MaxTimeout
	default:
		return d
	}
}
// Apply registers a tentative config change. It stores prevYAML as the
// snapshot to restore on expiry together with the hashes of both versions,
// persists that record, and arms the in-process expiry timer. Writing
// newYAML to disk and restarting the daemon remain the caller's job.
//
// timeout is clamped by clampTimeout (zero selects DefaultTimeout). applyBy
// is stored with the record (e.g. a token name or "cli") so audits can
// trace who started the change.
//
// If a rollback is already pending, Apply changes nothing and returns that
// existing record's Status together with an error. If persisting fails the
// zero Status is returned and no timer is armed.
func (m *Manager) Apply(prevYAML, newYAML []byte, timeout time.Duration, applyBy string) (Status, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	pending, found := m.db.GetFirewallRollback()
	if found {
		return statusFromRecord(pending, m.now()), fmt.Errorf("rollback already pending; confirm or revert first")
	}

	window := clampTimeout(timeout)
	appliedAt := m.now().UTC()
	record := store.FirewallRollback{
		PrevYAML:  prevYAML,
		PrevHash:  HashYAML(prevYAML),
		NewHash:   HashYAML(newYAML),
		AppliedAt: appliedAt,
		ExpiresAt: appliedAt.Add(window),
		AppliedBy: applyBy,
	}

	// Persist before arming so a crash between the two cannot leave a
	// timer running without a record to act on.
	if err := m.db.SaveFirewallRollback(record); err != nil {
		return Status{}, fmt.Errorf("persist rollback: %w", err)
	}
	m.armTimerLocked(window)
	return statusFromRecord(record, m.now()), nil
}
// Confirm accepts the new config: it cancels the expiry timer, if one is
// armed, and clears the persisted rollback record. The config on disk is
// left alone and no restart is triggered. The timer step is a no-op when
// nothing is armed; the returned error is whatever the store's clear
// operation reports.
func (m *Manager) Confirm() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if armed := m.timer; armed != nil {
		armed.Stop()
		m.timer = nil
	}
	return m.db.ClearFirewallRollback()
}
// Revert puts the saved snapshot back on disk and triggers a daemon
// restart. With nothing pending it returns an error instead of silently
// succeeding, so callers can report a clean "nothing to revert".
func (m *Manager) Revert(ctx context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if record, pending := m.db.GetFirewallRollback(); pending {
		return m.applyRevertLocked(ctx, record)
	}
	return fmt.Errorf("no pending rollback")
}
// RecoverOnStartup resumes a rollback that was pending when the daemon last
// stopped. It is meant to be called once during startup.
//
//   - No pending record: nothing happens, (false, nil).
//   - Deadline still in the future: the timer is re-armed for the remaining
//     window, (false, nil).
//   - Deadline reached or passed: the revert runs immediately. On success
//     the result is (true, nil) so the caller can skip further startup work
//     while the triggered restart takes effect; on failure (false, err).
func (m *Manager) RecoverOnStartup(ctx context.Context) (reverted bool, err error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	record, pending := m.db.GetFirewallRollback()
	if !pending {
		return false, nil
	}

	current := m.now()
	if current.Before(record.ExpiresAt) {
		m.armTimerLocked(record.ExpiresAt.Sub(current))
		return false, nil
	}

	if revertErr := m.applyRevertLocked(ctx, record); revertErr != nil {
		return false, revertErr
	}
	return true, nil
}
// Status returns a snapshot of the pending rollback for the
// /api/v1/.../rollback endpoint and the CLI status command. The zero Status
// (Pending=false) means nothing is in flight.
func (m *Manager) Status() Status {
	m.mu.Lock()
	defer m.mu.Unlock()

	if record, pending := m.db.GetFirewallRollback(); pending {
		return statusFromRecord(record, m.now())
	}
	return Status{}
}
// armTimerLocked replaces any armed timer with one that fires after d and
// runs timerExpired. Failures in the expiry path are reported on stderr
// because there is no caller to return them to. Caller must hold m.mu.
func (m *Manager) armTimerLocked(d time.Duration) {
	if previous := m.timer; previous != nil {
		previous.Stop()
	}

	onExpiry := func() {
		// The timer goroutine has no inherited context (whatever ctx
		// RecoverOnStartup was given would have been cancelled by the
		// time this fires), so give the revert its own bounded one.
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()

		err := m.timerExpired(ctx)
		if err == nil {
			return
		}
		fmt.Fprintf(os.Stderr, "rollback: timer expiry failed: %v\n", err)
	}
	m.timer = time.AfterFunc(d, onExpiry)
}
// timerExpired is the timer callback body: it takes m.mu and reverts the
// pending rollback. If the record is already gone by the time the lock is
// acquired, it does nothing and returns nil.
func (m *Manager) timerExpired(ctx context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if record, pending := m.db.GetFirewallRollback(); pending {
		return m.applyRevertLocked(ctx, record)
	}
	return nil
}
// applyRevertLocked performs the revert itself. In order it:
//
//  1. cancels the armed timer, if any;
//  2. refuses to continue when the record carries no snapshot bytes;
//  3. atomically writes the snapshot back to m.configPath;
//  4. runs the restart hook, when one is configured;
//  5. clears the persisted record.
//
// The record is cleared last, so a failed write or restart leaves it in
// place. Caller must hold m.mu.
//
// NOTE(review): the timer is cancelled in step 1 and not re-armed when a
// later step fails, so after a failed revert nothing visible here retries
// before the next RecoverOnStartup -- confirm that is intended.
func (m *Manager) applyRevertLocked(ctx context.Context, rb store.FirewallRollback) error {
	if armed := m.timer; armed != nil {
		armed.Stop()
		m.timer = nil
	}

	snapshot := rb.PrevYAML
	if len(snapshot) == 0 {
		return fmt.Errorf("rollback record has empty prev_yaml; cannot restore")
	}

	if err := integrity.WriteConfigBytesAtomic(m.configPath, snapshot); err != nil {
		return fmt.Errorf("restore previous config: %w", err)
	}

	if restart := m.restart; restart != nil {
		if err := restart(ctx); err != nil {
			return fmt.Errorf("trigger restart after revert: %w", err)
		}
	}

	if err := m.db.ClearFirewallRollback(); err != nil {
		return fmt.Errorf("clear rollback after revert: %w", err)
	}
	return nil
}
// statusFromRecord converts a persisted rollback record into the Status
// handed to API and CLI callers. Pending is always true. SecondsRemaining is
// the time from now until the record's deadline, truncated to whole seconds
// and reported as 0 once the deadline has passed.
func statusFromRecord(rb store.FirewallRollback, now time.Time) Status {
	status := Status{
		Pending:   true,
		AppliedAt: rb.AppliedAt,
		ExpiresAt: rb.ExpiresAt,
		AppliedBy: rb.AppliedBy,
		PrevHash:  rb.PrevHash,
		NewHash:   rb.NewHash,
	}
	if secs := int64(rb.ExpiresAt.Sub(now).Seconds()); secs > 0 {
		status.SecondsRemaining = secs
	}
	return status
}
package firewall
// RuleCounts holds firewall rule cardinalities sourced from the engine
// state file, which is the authoritative store. The parallel bbolt
// fw:* buckets are written only during migration, so anything counting
// live rules must read the engine, not the store. Expired temp bans are
// excluded.
type RuleCounts struct {
	Blocked     int
	Allowed     int
	Subnets     int
	PortAllowed int
}

// Total adds up every category in c.
func (c RuleCounts) Total() int {
	sum := 0
	for _, n := range [...]int{c.Blocked, c.Allowed, c.Subnets, c.PortAllowed} {
		sum += n
	}
	return sum
}
//go:build linux
package firewall
import "net"
// countRuleEntries tallies each rule category of state. ipv6Enabled is
// forwarded to the per-category counters, which skip IPv6 entries when it
// is false.
func countRuleEntries(state FirewallState, ipv6Enabled bool) RuleCounts {
	var counts RuleCounts
	counts.Blocked = countBlockedRules(state.Blocked, ipv6Enabled)
	counts.Allowed = countAllowedRules(state.Allowed, ipv6Enabled)
	counts.Subnets = countSubnetRules(state.BlockedNet, ipv6Enabled)
	counts.PortAllowed = countPortAllowRules(state.PortAllowed, ipv6Enabled)
	return counts
}
// countBlockedRules returns the number of distinct blocked IPs. Entries
// whose IP ruleIPKey rejects (unparseable, or IPv6 while ipv6Enabled is
// false) are ignored, and duplicates of the same address count once.
func countBlockedRules(entries []BlockedEntry, ipv6Enabled bool) int {
	distinct := make(map[string]struct{}, len(entries))
	for i := range entries {
		if key, ok := ruleIPKey(entries[i].IP, ipv6Enabled); ok {
			distinct[key] = struct{}{}
		}
	}
	return len(distinct)
}
// countAllowedRules returns the number of distinct allowed IPs, using the
// same filtering and de-duplication rules as countBlockedRules: entries
// rejected by ruleIPKey are skipped and repeated addresses count once.
func countAllowedRules(entries []AllowedEntry, ipv6Enabled bool) int {
	distinct := make(map[string]struct{}, len(entries))
	for i := range entries {
		key, usable := ruleIPKey(entries[i].IP, ipv6Enabled)
		if usable {
			distinct[key] = struct{}{}
		}
	}
	return len(distinct)
}
// countSubnetRules returns the number of distinct blocked networks. CIDRs
// that fail to parse are skipped, as are IPv6 networks when ipv6Enabled is
// false. Entries are keyed by the parsed network, so different spellings of
// the same network (e.g. a host address inside it) count once.
func countSubnetRules(entries []SubnetEntry, ipv6Enabled bool) int {
	distinct := make(map[string]struct{}, len(entries))
	for i := range entries {
		_, ipNet, err := net.ParseCIDR(entries[i].CIDR)
		if err != nil {
			continue
		}
		isV6 := ipNet.IP.To4() == nil
		if isV6 && !ipv6Enabled {
			continue
		}
		distinct[ipNet.String()] = struct{}{}
	}
	return len(distinct)
}
// countPortAllowRules returns how many port-allow entries carry an IP that
// ruleIPKey accepts. Unlike the per-IP counters it does not de-duplicate:
// every usable entry counts.
func countPortAllowRules(entries []PortAllowEntry, ipv6Enabled bool) int {
	usable := 0
	for i := range entries {
		if _, ok := ruleIPKey(entries[i].IP, ipv6Enabled); ok {
			usable++
		}
	}
	return usable
}
// ruleIPKey turns a textual IP into a canonical de-duplication key.
//
// IPv4 addresses, including IPv4-mapped IPv6 forms, yield "4:<dotted quad>".
// Other IPv6 addresses yield "6:<canonical text>" but only when ipv6Enabled
// is true. The second result is false when raw does not parse or when it is
// an IPv6 address and IPv6 is disabled.
func ruleIPKey(raw string, ipv6Enabled bool) (string, bool) {
	addr := net.ParseIP(raw)
	switch {
	case addr == nil:
		return "", false
	case addr.To4() != nil:
		return "4:" + addr.To4().String(), true
	case !ipv6Enabled:
		return "", false
	}

	v6 := addr.To16()
	if v6 == nil {
		return "", false
	}
	return "6:" + v6.String(), true
}
package firewall
import (
"fmt"
"os"
"os/user"
"strconv"
)
// smtpAllowlistLookupUser is the user-database accessor that
// resolveSMTPAllowedUIDs goes through. Production binds it to
// user.Lookup; tests can swap a fixture so unit tests do not depend on
// /etc/passwd of the host they run on.
var smtpAllowlistLookupUser = user.Lookup

// resolveSMTPAllowedUIDs returns the de-duplicated UIDs permitted to open
// outbound SMTP connections while smtp_block is on, in first-seen order.
//
// Two accounts are always included ahead of the operator's list:
//   - UID 0 (root): operator commands, csm itself, system tools.
//   - mailnull, when it resolves: exim's queue runner uses this account on
//     cPanel. Dropping it would leave queued mail stuck on the host even
//     though cPanel believes the hold was released, and silently breaking
//     outbound mail is the worst failure mode for a security firewall, so
//     it is not left to the operator to remember under smtp_allow_users.
//     A failed mailnull lookup is skipped without a message.
//
// allowUsers is the operator-supplied list. Each name is resolved through
// smtpAllowlistLookupUser; names that do not resolve, or whose UID is not a
// valid 32-bit unsigned number, are reported on stderr (matching the legacy
// behavior in createOutputChain) and skipped so a YAML typo cannot crash
// the firewall engine.
func resolveSMTPAllowedUIDs(allowUsers []string) []uint32 {
	// Bound the pre-allocation so a pathological config (or a future
	// caller bug) cannot request a multi-gigabyte map; the +2 below makes
	// room for root and mailnull.
	const smtpAllowHintCap = 1 << 16
	capacity := len(allowUsers)
	if capacity > smtpAllowHintCap {
		capacity = smtpAllowHintCap
	}
	capacity += 2

	known := make(map[uint32]struct{}, capacity)
	uids := make([]uint32, 0, capacity)
	include := func(uid uint32) {
		if _, dup := known[uid]; !dup {
			known[uid] = struct{}{}
			uids = append(uids, uid)
		}
	}

	include(0)

	if acct, err := smtpAllowlistLookupUser("mailnull"); err == nil {
		if n, err := strconv.ParseUint(acct.Uid, 10, 32); err == nil {
			include(uint32(n))
		}
	}

	for _, name := range allowUsers {
		acct, lookupErr := smtpAllowlistLookupUser(name)
		if lookupErr != nil {
			fmt.Fprintf(os.Stderr, "firewall: smtp_allow_users: unknown user %q\n", name)
			continue
		}
		n, parseErr := strconv.ParseUint(acct.Uid, 10, 32)
		if parseErr != nil {
			fmt.Fprintf(os.Stderr, "firewall: smtp_allow_users: invalid uid for %s: %v\n", name, parseErr)
			continue
		}
		include(uint32(n))
	}
	return uids
}
package firewall
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
// LoadState reads <statePath>/firewall/state.json directly, without needing
// a running engine, and returns it with expired entries removed.
//
// A missing file is the normal state of a fresh host and yields an empty
// FirewallState with a nil error. Any other read error, or malformed JSON,
// is returned as-is. Blocked, BlockedNet and Allowed entries are kept when
// their ExpiresAt is zero (no expiry) or still in the future; the file on
// disk is not rewritten.
func LoadState(statePath string) (*FirewallState, error) {
	path := filepath.Join(statePath, "firewall", "state.json")
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	raw, err := os.ReadFile(path)
	switch {
	case err == nil:
		// fall through to decoding
	case os.IsNotExist(err):
		return &FirewallState{}, nil
	default:
		return nil, err
	}

	state := new(FirewallState)
	if err := json.Unmarshal(raw, state); err != nil {
		return nil, err
	}

	// One timestamp for all three lists so they are filtered consistently.
	cutoff := time.Now()
	live := func(expiresAt time.Time) bool {
		return expiresAt.IsZero() || cutoff.Before(expiresAt)
	}

	var blocked []BlockedEntry
	for _, e := range state.Blocked {
		if live(e.ExpiresAt) {
			blocked = append(blocked, e)
		}
	}
	state.Blocked = blocked

	var subnets []SubnetEntry
	for _, e := range state.BlockedNet {
		if live(e.ExpiresAt) {
			subnets = append(subnets, e)
		}
	}
	state.BlockedNet = subnets

	var allowed []AllowedEntry
	for _, e := range state.Allowed {
		if live(e.ExpiresAt) {
			allowed = append(allowed, e)
		}
	}
	state.Allowed = allowed

	return state, nil
}
// Package forensic produces evidence archives for incident response.
//
// A snapshot bundles the structured outputs an operator needs to hand a
// customer after a database-layer compromise: full trigger / event /
// routine definitions, the admin user roster, active session metadata,
// and the recent-mtime list under the account's document roots. Wrapped
// in a tar+gzip with a manifest and a SHA256 sidecar.
//
// The snapshot intentionally excludes credentials. Password rotation is
// a separate runbook step; bundling new credentials with the evidence
// archive would conflate two opposing flows (hand-to-customer for
// audit vs hand-to-ops for rotation) and risk credential leakage if
// the archive is later shared casually.
package forensic
import (
"archive/tar"
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"time"
)
// SchemaTarget identifies one MySQL schema to include in the snapshot
// along with the WordPress table prefix needed to enumerate the user
// roster correctly. Discovered from wp-config.php in the account's
// document roots by the production hook.
type SchemaTarget struct {
// Schema is the database name; Snapshot.Write uses it in archive entry
// names and rejects values that fail schemaNameValid.
Schema string
// TablePrefix is passed to the admin and session listers; it must
// satisfy tablePrefixValid.
TablePrefix string
// ConfigPath, when non-empty, is recorded in the manifest next to the
// schema line.
ConfigPath string
}
// DiscoveryAudit records why a snapshot captured the targets it did.
// It is written into the manifest so operators can validate the
// evidence boundary without inspecting the host again.
type DiscoveryAudit struct {
// AccountRoot is emitted as discovery_root when non-empty.
AccountRoot string
// PrivatePathsExcluded emits private_paths_excluded=true when set.
PrivatePathsExcluded bool
// PrivateTopPaths is emitted sorted and comma-joined as
// private_top_excluded when non-empty.
PrivateTopPaths []string
// SkippedPaths yields one skipped_path manifest line per element.
SkippedPaths []SkippedPath
}
// SkippedPath records one discovery path that was not safe or useful
// to capture.
type SkippedPath struct {
// Path is the skipped location, written quoted into the manifest.
Path string
// Reason is the free-form explanation written next to it.
Reason string
}
// Sources lets the caller swap each I/O dependency for a test double.
// Production wiring lives in cmd/csm; tests pass deterministic
// closures.
//
// DiscoverTargets is mandatory for Snapshot.Write. Every other hook is
// optional: a nil hook is skipped, and a hook that returns an error has the
// error text stored in a "<name>.err" archive entry instead of its data.
type Sources struct {
// DiscoverTargets returns the schemas to capture for the account.
DiscoverTargets func(account string) []SchemaTarget
// DumpSchema output is stored as schema/<schema>-routines.sql.
DumpSchema func(schema string) ([]byte, error)
// ListAdmins output is stored as schema/<schema>-admins.tsv.
ListAdmins func(schema, tablePrefix string) ([]byte, error)
// ListSessions output is stored as schema/<schema>-sessions.tsv.
ListSessions func(schema, tablePrefix string) ([]byte, error)
// ListRecentFiles output is stored as files/recent-mtimes.tsv; Write
// calls it with /home/<account> and a cutoff seven days before the
// snapshot timestamp.
ListRecentFiles func(accountRoot string, since time.Time) ([]byte, error)
}
// Snapshot is the operator-facing configuration. Account and OutPath
// are required; Sources is required when running outside the default
// production wiring (see cmd/csm for the defaults).
type Snapshot struct {
// Account is the account name; it must satisfy AccountNameValid.
Account string
// OutPath is where the archive is written; the checksum sidecar goes to
// OutPath + ".sha256". It must pass ValidateOutPath.
OutPath string
// Timestamp is used for the manifest and for every archive entry's
// mtime. The zero value makes Write use the current UTC time.
Timestamp time.Time
// DiscoveryAudit is copied into the manifest.
DiscoveryAudit DiscoveryAudit
// Sources supplies the data-gathering hooks.
Sources Sources
}
// Allow-lists for the identifiers that end up in archive entry names,
// manifest lines and SQL built without a shell.
var (
	accountNamePattern = regexp.MustCompile(`^[A-Za-z0-9_-]{1,32}$`)
	schemaNamePattern  = regexp.MustCompile(`^[A-Za-z0-9_@]+$`)
	tablePrefixPattern = regexp.MustCompile(`^[A-Za-z0-9_]+$`)
)

// AccountNameValid reports whether name is conservative enough to use in
// archive entry names, manifest keys, and shell-free SQL queries: 1 to 32
// characters drawn from ASCII letters, digits, underscore and hyphen.
func AccountNameValid(name string) bool {
	ok := accountNamePattern.MatchString(name)
	return ok
}

// schemaNameValid reports whether name is a non-empty run of ASCII letters,
// digits, underscore or '@'.
func schemaNameValid(name string) bool {
	ok := schemaNamePattern.MatchString(name)
	return ok
}

// tablePrefixValid reports whether name is a non-empty run of ASCII letters,
// digits or underscore.
func tablePrefixValid(name string) bool {
	ok := tablePrefixPattern.MatchString(name)
	return ok
}
// Write builds the evidence archive at s.OutPath plus a `<out>.sha256`
// sidecar and returns the archive path and the archive's SHA-256 hex digest.
//
// Account, OutPath and Sources.DiscoverTargets are validated first; a
// failure there returns before anything is built. The tar+gzip stream is
// assembled entirely in memory and is written to disk only after it has
// been closed cleanly, so errors while collecting data leave no file
// behind. The sidecar is written after the archive: if that last write
// fails, the archive written just before it stays on disk.
//
// Entries are added in a fixed order -- per schema (sorted by name) the
// routine dump, admin roster and session list, then the recent-files
// listing, then manifest.txt -- and all carry the snapshot timestamp as
// mtime. A data hook that fails does not abort the snapshot; its error text
// is stored under the entry name with ".err" appended and counted in the
// manifest. Targets with an invalid schema name or table prefix are
// recorded as invalid and none of their hooks are called.
func (s Snapshot) Write() (string, string, error) {
	if !AccountNameValid(s.Account) {
		return "", "", fmt.Errorf("invalid account name: %q", s.Account)
	}
	if s.OutPath == "" {
		return "", "", errors.New("forensic snapshot: OutPath required")
	}
	if err := ValidateOutPath(s.Account, s.OutPath); err != nil {
		return "", "", err
	}
	if s.Sources.DiscoverTargets == nil {
		return "", "", errors.New("forensic snapshot: Sources.DiscoverTargets required")
	}

	stamp := s.Timestamp
	if stamp.IsZero() {
		stamp = time.Now().UTC()
	}

	targets := s.Sources.DiscoverTargets(s.Account)
	sort.Slice(targets, func(a, b int) bool { return targets[a].Schema < targets[b].Schema })

	var archive bytes.Buffer
	gz := gzip.NewWriter(&archive)
	tw := tar.NewWriter(gz)

	var manifest strings.Builder
	fmt.Fprintf(&manifest, "account=%s\n", s.Account)
	fmt.Fprintf(&manifest, "timestamp=%s\n", stamp.UTC().Format(time.RFC3339))
	fmt.Fprintf(&manifest, "schema_count=%d\n", len(targets))
	writeDiscoveryAudit(&manifest, s.DiscoveryAudit)

	var (
		validTargets, invalidTargets int
		dumpOK, dumpErr              int
		adminsOK, adminsErr          int
		sessionsOK, sessionsErr      int
	)
	recentStatus := "disabled"

	// capture stores one hook result. On a hook error the entry is renamed
	// to "<name>.err" and carries the error text; the matching counter is
	// bumped either way. Only an archive write failure is returned.
	capture := func(name string, data []byte, hookErr error, okCount, errCount *int) error {
		if hookErr != nil {
			*errCount++
			name += ".err"
			data = []byte(hookErr.Error() + "\n")
		} else {
			*okCount++
		}
		return writeArchiveEntry(tw, name, data, stamp)
	}

	for i, tgt := range targets {
		if !schemaNameValid(tgt.Schema) || !tablePrefixValid(tgt.TablePrefix) {
			invalidTargets++
			fmt.Fprintf(&manifest, "invalid_target=%d schema=%q table_prefix=%q\n", i, tgt.Schema, tgt.TablePrefix)
			entry := "schema/invalid-target-" + strconv.Itoa(i) + ".err"
			body := []byte(fmt.Sprintf("invalid schema target: schema=%q table_prefix=%q\n", tgt.Schema, tgt.TablePrefix))
			if err := writeArchiveEntry(tw, entry, body, stamp); err != nil {
				return "", "", err
			}
			continue
		}

		validTargets++
		fmt.Fprintf(&manifest, "schema=%s table_prefix=%s", tgt.Schema, tgt.TablePrefix)
		if tgt.ConfigPath != "" {
			fmt.Fprintf(&manifest, " config_path=%q", tgt.ConfigPath)
		}
		fmt.Fprint(&manifest, "\n")

		base := "schema/" + tgt.Schema

		// Schema dump.
		if dump := s.Sources.DumpSchema; dump != nil {
			data, err := dump(tgt.Schema)
			if werr := capture(base+"-routines.sql", data, err, &dumpOK, &dumpErr); werr != nil {
				return "", "", werr
			}
		}

		// Admin roster.
		if admins := s.Sources.ListAdmins; admins != nil {
			data, err := admins(tgt.Schema, tgt.TablePrefix)
			if werr := capture(base+"-admins.tsv", data, err, &adminsOK, &adminsErr); werr != nil {
				return "", "", werr
			}
		}

		// Sessions.
		if sessions := s.Sources.ListSessions; sessions != nil {
			data, err := sessions(tgt.Schema, tgt.TablePrefix)
			if werr := capture(base+"-sessions.tsv", data, err, &sessionsOK, &sessionsErr); werr != nil {
				return "", "", werr
			}
		}
	}

	// Recent files: one listing for the whole account, covering the seven
	// days before the snapshot timestamp.
	if recent := s.Sources.ListRecentFiles; recent != nil {
		entry := "files/recent-mtimes.tsv"
		data, err := recent("/home/"+s.Account, stamp.Add(-7*24*time.Hour))
		if err == nil {
			recentStatus = "ok"
		} else {
			recentStatus = "error"
			entry += ".err"
			data = []byte(err.Error() + "\n")
		}
		if werr := writeArchiveEntry(tw, entry, data, stamp); werr != nil {
			return "", "", werr
		}
	}

	for _, counter := range []struct {
		key string
		n   int
	}{
		{"valid_target_count", validTargets},
		{"invalid_target_count", invalidTargets},
		{"dump_success_count", dumpOK},
		{"dump_error_count", dumpErr},
		{"admins_success_count", adminsOK},
		{"admins_error_count", adminsErr},
		{"sessions_success_count", sessionsOK},
		{"sessions_error_count", sessionsErr},
	} {
		fmt.Fprintf(&manifest, "%s=%d\n", counter.key, counter.n)
	}
	fmt.Fprintf(&manifest, "recent_mtimes_status=%s\n", recentStatus)

	// The manifest goes in last so it reflects what was actually processed.
	if err := writeArchiveEntry(tw, "manifest.txt", []byte(manifest.String()), stamp); err != nil {
		return "", "", err
	}
	if err := tw.Close(); err != nil {
		return "", "", fmt.Errorf("closing tar: %w", err)
	}
	if err := gz.Close(); err != nil {
		return "", "", fmt.Errorf("closing gzip: %w", err)
	}

	payload := archive.Bytes()
	if err := os.WriteFile(s.OutPath, payload, 0o600); err != nil {
		return "", "", fmt.Errorf("writing archive: %w", err)
	}

	digest := sha256.Sum256(payload)
	hexSum := hex.EncodeToString(digest[:])
	sidecarBody := fmt.Sprintf("%s %s\n", hexSum, filepath.Base(s.OutPath))
	if err := os.WriteFile(s.OutPath+".sha256", []byte(sidecarBody), 0o600); err != nil {
		return "", "", fmt.Errorf("writing sidecar: %w", err)
	}
	return s.OutPath, hexSum, nil
}
// writeDiscoveryAudit appends the discovery audit to the manifest being
// built in b. Unset fields are omitted: discovery_root only when
// AccountRoot is non-empty, private_paths_excluded only when true, and
// private_top_excluded only when there are paths (emitted sorted, without
// reordering the caller's slice). Each skipped path gets its own indexed
// line.
func writeDiscoveryAudit(b *strings.Builder, audit DiscoveryAudit) {
	if root := audit.AccountRoot; root != "" {
		fmt.Fprintf(b, "discovery_root=%q\n", root)
	}
	if audit.PrivatePathsExcluded {
		b.WriteString("private_paths_excluded=true\n")
	}
	if n := len(audit.PrivateTopPaths); n > 0 {
		sorted := make([]string, n)
		copy(sorted, audit.PrivateTopPaths)
		sort.Strings(sorted)
		fmt.Fprintf(b, "private_top_excluded=%q\n", strings.Join(sorted, ","))
	}
	for idx := range audit.SkippedPaths {
		entry := audit.SkippedPaths[idx]
		fmt.Fprintf(b, "skipped_path=%d path=%q reason=%q\n", idx, entry.Path, entry.Reason)
	}
}
// ValidateOutPath rejects an output destination that falls inside the
// target account's home directory: evidence written where the suspect user
// can read it defeats the point.
//
// The comparison is made against /home/<account> and, when it resolves, its
// symlink target. The destination is checked in up to three forms: the
// absolute path, its symlink-resolved form, and its file name joined onto
// the symlink-resolved parent directory (which covers an output file that
// does not exist yet). Resolution failures are ignored and the remaining
// forms are still checked.
func ValidateOutPath(account, outPath string) error {
	if !AccountNameValid(account) {
		return fmt.Errorf("invalid account name: %q", account)
	}
	abs, err := filepath.Abs(outPath)
	if err != nil {
		return fmt.Errorf("resolving out path: %w", err)
	}

	home := filepath.Clean("/home/" + account)
	roots := []string{home}
	if resolved, rerr := filepath.EvalSymlinks(home); rerr == nil {
		roots = append(roots, resolved)
	}

	candidates := []string{abs}
	if resolved, rerr := filepath.EvalSymlinks(abs); rerr == nil {
		candidates = append(candidates, resolved)
	}
	if parent, rerr := filepath.EvalSymlinks(filepath.Dir(abs)); rerr == nil {
		candidates = append(candidates, filepath.Join(parent, filepath.Base(abs)))
	}

	for _, root := range roots {
		for _, candidate := range candidates {
			if pathWithin(root, candidate) {
				return fmt.Errorf("out path must not be inside /home/%s/", account)
			}
		}
	}
	return nil
}
// pathWithin reports whether path is root itself or lies beneath it, judged
// lexically on the cleaned paths (no filesystem access). When no relative
// path from root to path can be computed, the answer is false.
func pathWithin(root, path string) bool {
	rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(path))
	switch {
	case err != nil:
		return false
	case rel == ".":
		return true
	case rel == "..", filepath.IsAbs(rel):
		return false
	}
	// Only a leading ".." path element escapes root; a name that merely
	// starts with two dots (e.g. "..cache") is still inside.
	return !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// writeArchiveEntry appends one regular file to the tar stream. The mtime
// is pinned to ts (in UTC) so identical inputs give byte-identical
// archives, and the mode is 0600 because forensic content is operator-only.
// Names rejected by unsafeArchiveEntryName are refused before anything is
// written.
func writeArchiveEntry(tw *tar.Writer, name string, data []byte, ts time.Time) error {
	if unsafeArchiveEntryName(name) {
		return fmt.Errorf("unsafe archive entry name: %q", name)
	}

	header := tar.Header{
		Name:    name,
		Size:    int64(len(data)),
		Mode:    0o600,
		ModTime: ts.UTC(),
	}
	if err := tw.WriteHeader(&header); err != nil {
		return fmt.Errorf("tar header %s: %w", name, err)
	}
	if _, err := tw.Write(data); err != nil {
		return fmt.Errorf("tar body %s: %w", name, err)
	}
	return nil
}
// unsafeArchiveEntryName reports whether name must not be used as a tar
// entry name. A name is unsafe when it is empty, absolute, contains a
// backslash, is not already in cleaned form (so "./a", "a//b" and "a/../b"
// are all refused), or is ".", ".." or begins with a ".." path element.
func unsafeArchiveEntryName(name string) bool {
	switch {
	case name == "", filepath.IsAbs(name), strings.Contains(name, `\`):
		return true
	}

	normalized := filepath.Clean(name)
	if normalized != name {
		return true
	}
	if normalized == "." || normalized == ".." {
		return true
	}
	return strings.HasPrefix(normalized, ".."+string(filepath.Separator))
}
package geoip
// KnownEditions lists the MaxMind database editions CSM offers in the
// settings UI: first the free GeoLite2 family, then the paid GeoIP2 family.
// The lists are curated to what MaxMind actually publishes via the
// geoipupdate protocol; adding an edition here makes it selectable in the
// Settings → GeoIP → Database editions dropdown. Fresh slices are returned
// on every call.
func KnownEditions() (free, commercial []string) {
	return []string{
			"GeoLite2-City",
			"GeoLite2-Country",
			"GeoLite2-ASN",
		}, []string{
			"GeoIP2-City",
			"GeoIP2-Country",
			"GeoIP2-ISP",
			"GeoIP2-Domain",
			"GeoIP2-Connection-Type",
			"GeoIP2-Anonymous-IP",
			"GeoIP2-Enterprise",
		}
}
// Package geoip provides IP geolocation via MaxMind GeoLite2 databases
// and on-demand RDAP lookups for detailed ISP/org information.
package geoip
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"os"
"path/filepath"
"sync"
"time"
"github.com/oschwald/maxminddb-golang/v2"
)
// Info contains geolocation and network information for an IP.
//
// Lookup fills IP plus whatever the local city/ASN databases return; the
// RDAP* fields are only populated by LookupWithRDAP. Fields for which no
// data was found keep their zero value.
type Info struct {
// IP is the input string exactly as passed by the caller.
IP string `json:"ip"`
Country string `json:"country"` // ISO 3166-1 alpha-2 (e.g. "US")
CountryName string `json:"country_name"` // Full name (e.g. "United States")
// City is the English city name from the city database.
City string `json:"city,omitempty"`
ASN uint `json:"asn,omitempty"` // Autonomous System Number
ASOrg string `json:"as_org,omitempty"` // AS Organization (ISP)
Network string `json:"network,omitempty"` // CIDR range
RDAPOrg string `json:"rdap_org,omitempty"` // Detailed org from RDAP (on-demand)
RDAPName string `json:"rdap_name,omitempty"` // Network name from RDAP
RDAPCountry string `json:"rdap_country,omitempty"` // Country from RDAP
}
// DB holds the MaxMind database readers.
type DB struct {
// mu guards cityDB and asnDB: Lookup takes it for reading, Close and
// Reload take it for writing.
mu sync.RWMutex
// cityDB is the GeoLite2-City reader, nil when that file was not loaded.
cityDB *maxminddb.Reader
// asnDB is the GeoLite2-ASN reader, nil when that file was not loaded.
asnDB *maxminddb.Reader
// dbDir is the directory Open was given; Reload re-reads from it.
dbDir string
// rdapMu guards rdapTTL.
rdapMu sync.Mutex
// rdapTTL caches RDAP results keyed by the IP string passed to
// LookupWithRDAP.
rdapTTL map[string]rdapCacheEntry
}
// rdapCacheEntry is one cached RDAP result together with the time it was
// fetched, which LookupWithRDAP compares against its 24-hour freshness
// window.
type rdapCacheEntry struct {
info Info
fetched time.Time
}
// MaxMind GeoLite2 record structures

// cityRecord is the subset of a city database record that Lookup decodes.
// Lookup reads the "en" entry of each Names map.
type cityRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
Names map[string]string `maxminddb:"names"`
} `maxminddb:"country"`
City struct {
Names map[string]string `maxminddb:"names"`
} `maxminddb:"city"`
}
// asnRecord is the subset of an ASN database record that Lookup decodes
// into Info.ASN and Info.ASOrg.
type asnRecord struct {
ASN uint `maxminddb:"autonomous_system_number"`
Org string `maxminddb:"autonomous_system_organization"`
}
// Open loads GeoLite2-City.mmdb and/or GeoLite2-ASN.mmdb from dbDir and
// logs each successfully opened file to stderr. A file that fails to open
// is skipped without a per-file message. Open returns nil when dbDir is
// empty or when neither database could be opened (graceful degradation);
// the nil *DB is safe to pass to Lookup, Close and Reload.
func Open(dbDir string) *DB {
	if dbDir == "" {
		return nil
	}

	db := &DB{
		dbDir:   dbDir,
		rdapTTL: make(map[string]rdapCacheEntry),
	}

	load := func(file string) *maxminddb.Reader {
		path := filepath.Join(dbDir, file)
		reader, err := maxminddb.Open(path)
		if err != nil {
			return nil
		}
		fmt.Fprintf(os.Stderr, "geoip: loaded %s\n", path)
		return reader
	}
	db.cityDB = load("GeoLite2-City.mmdb")
	db.asnDB = load("GeoLite2-ASN.mmdb")

	if db.cityDB == nil && db.asnDB == nil {
		fmt.Fprintf(os.Stderr, "geoip: no databases found in %s (download GeoLite2-City.mmdb and GeoLite2-ASN.mmdb)\n", dbDir)
		return nil
	}
	return db
}
// Close closes whichever database readers are open, ignoring their close
// errors. It is a no-op on a nil *DB. The reader fields are left set, so
// the DB should not be used for lookups afterwards.
func (db *DB) Close() {
	if db == nil {
		return
	}
	db.mu.Lock()
	defer db.mu.Unlock()

	for _, reader := range []*maxminddb.Reader{db.cityDB, db.asnDB} {
		if reader != nil {
			_ = reader.Close()
		}
	}
}
// Reload re-opens both databases from db.dbDir and swaps the new readers in.
//
// The replacements are opened before the write lock is taken, so lookups
// are blocked only for the swap itself. If both opens fail, nothing changes
// and an error is returned. If only one fails, the other is swapped in, the
// failed one keeps its existing reader, a warning goes to stderr and the
// result is nil. Each replaced reader is closed under the lock.
func (db *DB) Reload() error {
	if db == nil {
		return fmt.Errorf("geoip: cannot reload nil DB")
	}

	// Open replacements before taking lock
	newCity, cityErr := maxminddb.Open(filepath.Join(db.dbDir, "GeoLite2-City.mmdb"))
	newASN, asnErr := maxminddb.Open(filepath.Join(db.dbDir, "GeoLite2-ASN.mmdb"))
	if cityErr != nil && asnErr != nil {
		return fmt.Errorf("geoip: reload failed - city: %v, asn: %v", cityErr, asnErr)
	}

	// swap installs fresh into *slot, closing whatever was there before.
	swap := func(slot **maxminddb.Reader, fresh *maxminddb.Reader) {
		if fresh == nil {
			return
		}
		if old := *slot; old != nil {
			_ = old.Close()
		}
		*slot = fresh
	}

	db.mu.Lock()
	swap(&db.cityDB, newCity)
	swap(&db.asnDB, newASN)
	db.mu.Unlock()

	if cityErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: reload warning - city DB failed: %v\n", cityErr)
	}
	if asnErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: reload warning - ASN DB failed: %v\n", asnErr)
	}
	return nil
}
// OpenFresh builds a brand-new DB from the databases currently on disk. It
// is for the case where no DB existed at startup and the files have been
// downloaded since. It delegates to Open and so also returns nil when no
// database is found.
func OpenFresh(dbDir string) *DB {
	fresh := Open(dbDir)
	return fresh
}
// Lookup returns what the local MaxMind databases know about ip. It makes
// no network calls.
//
// The result always carries IP set to the input string. On a nil *DB or an
// unparseable ip nothing else is filled in. Otherwise country, English
// country/city names and the matched network come from the city database,
// and ASN number/organisation from the ASN database; a database that is not
// loaded, or a record that fails to decode, just leaves its fields empty.
func (db *DB) Lookup(ip string) Info {
	info := Info{IP: ip}
	if db == nil {
		return info
	}
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return info
	}

	db.mu.RLock()
	defer db.mu.RUnlock()

	if city := db.cityDB; city != nil {
		var rec cityRecord
		hit := city.Lookup(addr)
		if hit.Decode(&rec) == nil {
			info.Country = rec.Country.ISOCode
			info.CountryName = rec.Country.Names["en"]
			info.City = rec.City.Names["en"]
			if network := hit.Prefix(); network.IsValid() {
				info.Network = network.String()
			}
		}
	}

	if asn := db.asnDB; asn != nil {
		var rec asnRecord
		if asn.Lookup(addr).Decode(&rec) == nil {
			info.ASN = rec.ASN
			info.ASOrg = rec.Org
		}
	}
	return info
}
// LookupWithRDAP returns the local Lookup result enriched with RDAP
// organisation, network name and country. RDAP answers, including empty
// ones from a failed fetch, are cached per IP string for 24 hours.
//
// Like Lookup, Close and Reload it is safe on a nil *DB (which Open returns
// when no database is available): the RDAP cache lives on the receiver, so
// with no receiver the plain Lookup result is returned and no RDAP request
// is made.
//
// The cache lock is not held across the network fetch, so concurrent calls
// for the same uncached IP may each issue a request; the last one to finish
// wins the cache slot.
func (db *DB) LookupWithRDAP(ip string) Info {
	info := db.Lookup(ip)
	if db == nil {
		return info
	}

	// Check RDAP cache
	db.rdapMu.Lock()
	cached, ok := db.rdapTTL[ip]
	db.rdapMu.Unlock()
	if ok && time.Since(cached.fetched) < 24*time.Hour {
		info.RDAPOrg = cached.info.RDAPOrg
		info.RDAPName = cached.info.RDAPName
		info.RDAPCountry = cached.info.RDAPCountry
		return info
	}

	// Fetch from RDAP
	rdapInfo := fetchRDAP(ip)
	info.RDAPOrg = rdapInfo.RDAPOrg
	info.RDAPName = rdapInfo.RDAPName
	info.RDAPCountry = rdapInfo.RDAPCountry

	// Cache
	db.rdapMu.Lock()
	db.rdapTTL[ip] = rdapCacheEntry{info: rdapInfo, fetched: time.Now()}
	db.evictRDAPLocked()
	db.rdapMu.Unlock()
	return info
}
// maxRDAPCacheEntries is the largest number of entries evictRDAPLocked
// leaves in the RDAP cache.
const maxRDAPCacheEntries = 10000
// evictRDAPLocked bounds the RDAP cache. It does nothing while the map is at
// or under maxRDAPCacheEntries. Once over the cap it first drops entries older
// than the 24h TTL; if the map is still over the cap (every entry fresh, e.g.
// a burst of distinct lookups inside the TTL), it evicts the oldest entries
// one at a time until the map is back at the cap, so it cannot grow without
// bound. Caller holds db.rdapMu.
func (db *DB) evictRDAPLocked() {
	if len(db.rdapTTL) <= maxRDAPCacheEntries {
		return
	}
	for k, v := range db.rdapTTL {
		if time.Since(v.fetched) > 24*time.Hour {
			delete(db.rdapTTL, k)
		}
	}
	for len(db.rdapTTL) > maxRDAPCacheEntries {
		var (
			oldestKey string
			oldest    time.Time
			found     bool
		)
		for k, v := range db.rdapTTL {
			// Track "have we picked a candidate yet" explicitly. The empty
			// string is a legal map key, so it cannot double as the sentinel:
			// doing so made an entry keyed "" impossible to evict.
			if !found || v.fetched.Before(oldest) {
				oldestKey, oldest, found = k, v.fetched, true
			}
		}
		if !found {
			break
		}
		delete(db.rdapTTL, oldestKey)
	}
}
// fetchRDAP asks the rdap.org redirector for the registration record of ip
// and returns an Info carrying only the RDAP fields: RDAPName and RDAPCountry
// from the top-level object, and RDAPOrg from the vCard "fn" property of an
// entity whose role is "registrant" or "abuse" (the last such value wins).
//
// The lookup is best-effort: an unparseable ip, a transport error, a non-200
// response, or an undecodable body all yield a zero Info. The request uses a
// 5 second timeout.
func fetchRDAP(ip string) Info {
	var info Info

	// Only ever place a syntactically valid address in the request path. The
	// caller's string may originate from logs, and an arbitrary string would
	// otherwise be interpolated into the URL sent to a third party. Any IPv6
	// zone is dropped: it is meaningless to a registry and may contain
	// arbitrary characters.
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return info
	}
	url := fmt.Sprintf("https://rdap.org/ip/%s", addr.WithZone("").String())

	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return info
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return info
	}

	var rdap struct {
		Name     string `json:"name"`
		Country  string `json:"country"`
		Handle   string `json:"handle"`
		Entities []struct {
			VCardArray []interface{} `json:"vcardArray"`
			Roles      []string      `json:"roles"`
		} `json:"entities"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&rdap); err != nil {
		return info
	}
	info.RDAPName = rdap.Name
	info.RDAPCountry = rdap.Country

	// Extract the org name from the registrant/abuse entities.
	for _, entity := range rdap.Entities {
		if !rdapEntityHasOrgRole(entity.Roles) {
			continue
		}
		if org, ok := rdapVCardFullName(entity.VCardArray); ok {
			info.RDAPOrg = org
		}
	}
	return info
}

// rdapEntityHasOrgRole reports whether roles contains "registrant" or "abuse".
func rdapEntityHasOrgRole(roles []string) bool {
	for _, role := range roles {
		if role == "registrant" || role == "abuse" {
			return true
		}
	}
	return false
}

// rdapVCardFullName returns the value of the last "fn" property in a decoded
// vcardArray (["vcard", [[name, params, type, value], ...]]) and whether one
// with a string value was present.
func rdapVCardFullName(vcard []interface{}) (string, bool) {
	if len(vcard) < 2 {
		return "", false
	}
	props, ok := vcard[1].([]interface{})
	if !ok {
		return "", false
	}
	var (
		fullName string
		found    bool
	)
	for _, prop := range props {
		fields, ok := prop.([]interface{})
		if !ok || len(fields) < 4 {
			continue
		}
		if name, ok := fields[0].(string); !ok || name != "fn" {
			continue
		}
		if val, ok := fields[3].(string); ok {
			fullName, found = val, true
		}
	}
	return fullName, found
}
package geoip
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/oschwald/maxminddb-golang/v2"
)
// Endpoint, size limits and timeout used by the GeoLite2 updater.
const (
// maxMindBaseURL is the base of the MaxMind direct-download endpoint; the
// edition name and "/download?suffix=tar.gz" are appended to it.
maxMindBaseURL = "https://download.maxmind.com/geoip/databases"
// maxDownloadSize caps the compressed tar.gz response: it is checked against
// Content-Length and also enforced with a LimitReader on the body.
maxDownloadSize = 150 * 1024 * 1024 // 150MB
// maxExtractedSize bounds the size of the .mmdb entry we extract from
// the tar.gz. Real GeoLite2 files are ~70MB; 500MiB is generous. The
// check sits on the tar header and prevents io.Copy from writing a
// bomb-compressed entry to disk.
maxExtractedSize = 500 * 1024 * 1024
// downloadTimeout is the timeout set on the http.Client that Update creates
// for its requests.
downloadTimeout = 120 * time.Second
)
// EditionResult reports the outcome of updating a single GeoLite2 edition.
// Update returns one of these per requested edition, in the same order as the
// editions slice it was given.
type EditionResult struct {
Edition string // e.g. "GeoLite2-City"
Status string // "updated", "up_to_date", "error"
Err error // nil unless Status == "error"
}
// Update refreshes the requested GeoLite2 editions in dbDir using MaxMind's
// direct download API and reports one EditionResult per edition, in input
// order.
//
// It returns nil without doing anything when either credential is empty. If
// dbDir cannot be created, every edition is reported with Status "error" and
// the wrapped directory-creation error. Otherwise the editions are updated one
// after another with a single shared HTTP client.
func Update(dbDir, accountID, licenseKey string, editions []string) []EditionResult {
	if accountID == "" || licenseKey == "" {
		return nil
	}

	out := make([]EditionResult, len(editions))

	if mkErr := os.MkdirAll(dbDir, 0700); mkErr != nil {
		for idx := range editions {
			out[idx] = EditionResult{
				Edition: editions[idx],
				Status:  "error",
				Err:     fmt.Errorf("creating directory: %w", mkErr),
			}
		}
		return out
	}

	httpClient := &http.Client{Timeout: downloadTimeout}
	for idx := range editions {
		out[idx] = updateEdition(httpClient, dbDir, accountID, licenseKey, editions[idx])
	}
	return out
}
// updateEdition updates one edition against the production MaxMind endpoint.
// It is updateEditionWithURL with baseURL fixed to maxMindBaseURL.
func updateEdition(client *http.Client, dbDir, accountID, licenseKey, edition string) EditionResult {
	outcome := updateEditionWithURL(client, dbDir, accountID, licenseKey, edition, maxMindBaseURL)
	return outcome
}
// updateEditionWithURL updates a single edition from baseURL and reports the
// outcome.
//
// Steps, in order:
//  1. Read the Last-Modified value stored in dbDir/.last-modified-<edition>
//     (a missing or unreadable marker is treated as "never downloaded").
//  2. Issue an authenticated HEAD request. 401 and 429 get dedicated error
//     messages; any other non-200 status is reported generically.
//  3. If the stored and remote Last-Modified values match, return
//     "up_to_date" without downloading.
//  4. Otherwise GET the tar.gz, extract the .mmdb to <edition>.mmdb.tmp,
//     validate it, and rename it over <edition>.mmdb.
//  5. Record the new Last-Modified value (best effort).
//
// The temporary file is removed on every failure after it may have been
// created.
func updateEditionWithURL(client *http.Client, dbDir, accountID, licenseKey, edition, baseURL string) EditionResult {
	// fail builds the "error" result for this edition.
	fail := func(err error) EditionResult {
		return EditionResult{Edition: edition, Status: "error", Err: err}
	}

	endpoint := fmt.Sprintf("%s/%s/download?suffix=tar.gz", baseURL, edition)
	marker := filepath.Join(dbDir, ".last-modified-"+edition)

	// Last-Modified value remembered from the previous successful download.
	var previousLM string
	// #nosec G304 -- filepath.Join under operator-configured dbDir.
	if raw, readErr := os.ReadFile(marker); readErr == nil {
		previousLM = strings.TrimSpace(string(raw))
	}

	// HEAD first: cheap way to learn whether the remote file changed.
	probe, err := http.NewRequest("HEAD", endpoint, nil)
	if err != nil {
		return fail(err)
	}
	probe.SetBasicAuth(accountID, licenseKey)
	probeResp, err := client.Do(probe)
	if err != nil {
		return fail(fmt.Errorf("HEAD request: %w", err))
	}
	probeResp.Body.Close()

	switch probeResp.StatusCode {
	case http.StatusOK:
		// proceed
	case http.StatusUnauthorized:
		return fail(fmt.Errorf("invalid MaxMind credentials"))
	case http.StatusTooManyRequests:
		return fail(fmt.Errorf("rate limited by MaxMind"))
	default:
		return fail(fmt.Errorf("HEAD returned HTTP %d", probeResp.StatusCode))
	}

	currentLM := probeResp.Header.Get("Last-Modified")
	if previousLM != "" && previousLM == currentLM {
		return EditionResult{Edition: edition, Status: "up_to_date"}
	}

	// Remote file is new or unknown: download it.
	download, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return fail(err)
	}
	download.SetBasicAuth(accountID, licenseKey)
	dlResp, err := client.Do(download)
	if err != nil {
		return fail(fmt.Errorf("download: %w", err))
	}
	defer dlResp.Body.Close()

	if dlResp.StatusCode != http.StatusOK {
		return fail(fmt.Errorf("download returned HTTP %d", dlResp.StatusCode))
	}
	// Reject oversized responses upfront when Content-Length is known.
	if dlResp.ContentLength > maxDownloadSize {
		return fail(fmt.Errorf("download too large: %d bytes (max %d)", dlResp.ContentLength, maxDownloadSize))
	}

	tmpPath := filepath.Join(dbDir, edition+".mmdb.tmp")
	// The LimitReader is the safety net for responses without Content-Length.
	limited := io.LimitReader(dlResp.Body, maxDownloadSize)
	if extractErr := extractMMDB(limited, tmpPath, edition); extractErr != nil {
		os.Remove(tmpPath)
		return fail(fmt.Errorf("extract: %w", extractErr))
	}
	if validateErr := validateMMDB(tmpPath); validateErr != nil {
		os.Remove(tmpPath)
		return fail(fmt.Errorf("validate: %w", validateErr))
	}

	// Atomic install: rename the validated temp file over the live database.
	finalPath := filepath.Join(dbDir, edition+".mmdb")
	if renameErr := os.Rename(tmpPath, finalPath); renameErr != nil {
		os.Remove(tmpPath)
		return fail(fmt.Errorf("install: %w", renameErr))
	}

	// Remember the Last-Modified value; failing to write it only costs a
	// redundant download next time, so the error is deliberately ignored.
	if currentLM != "" {
		_ = os.WriteFile(marker, []byte(currentLM), 0600)
	}
	return EditionResult{Edition: edition, Status: "updated"}
}
// validateMMDB checks that the file at path can be opened as a MaxMind
// database. It opens the file and closes it again straight away; the returned
// error is the open error if opening failed, otherwise the result of Close.
func validateMMDB(path string) error {
	reader, openErr := maxminddb.Open(path)
	if openErr != nil {
		return openErr
	}
	closeErr := reader.Close()
	return closeErr
}
// extractMMDB reads a gzip-compressed tar stream from r and writes the first
// regular-file entry whose name ends in "<edition>.mmdb" to destPath.
//
// MaxMind archives hold a single dated directory with the database inside,
// e.g. GeoLite2-City_20260328/GeoLite2-City.mmdb. Entries that are not regular
// files, or whose name does not end with the wanted suffix, are skipped.
//
// Errors are returned when the stream is not gzip, the tar is unreadable, no
// matching entry exists, the entry's declared size exceeds maxExtractedSize,
// or destPath cannot be created, written, or closed. destPath is created with
// mode 0600 and truncated if it already exists.
func extractMMDB(r io.Reader, destPath, edition string) error {
	zr, err := gzip.NewReader(r)
	if err != nil {
		return fmt.Errorf("gzip: %w", err)
	}
	defer func() { _ = zr.Close() }()

	wanted := edition + ".mmdb"
	archive := tar.NewReader(zr)
	for {
		entry, nextErr := archive.Next()
		switch {
		case nextErr == io.EOF:
			return fmt.Errorf("no %s found in archive", wanted)
		case nextErr != nil:
			return fmt.Errorf("reading tar: %w", nextErr)
		}

		isTarget := entry.Typeflag == tar.TypeReg && strings.HasSuffix(entry.Name, wanted)
		if !isTarget {
			continue
		}
		// Refuse to start writing an entry that declares an absurd size.
		if entry.Size > maxExtractedSize {
			return fmt.Errorf("archive entry too large: %d bytes", entry.Size)
		}

		// #nosec G304 -- destPath is filepath.Join under operator-configured dbDir.
		out, openErr := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
		if openErr != nil {
			return fmt.Errorf("creating %s: %w", destPath, openErr)
		}

		// Always close the file, but report the copy error first if both fail.
		_, writeErr := io.Copy(out, io.LimitReader(archive, maxExtractedSize))
		closeErr := out.Close()
		switch {
		case writeErr != nil:
			return fmt.Errorf("writing mmdb: %w", writeErr)
		case closeErr != nil:
			return fmt.Errorf("closing mmdb: %w", closeErr)
		}
		return nil
	}
}
package health
import "time"
// Provider is the contract the daemon (or a stub for tests) implements
// so the snapshot builder doesn't depend on internal/daemon directly.
// Build calls each accessor once per snapshot and copies the result into the
// correspondingly named Snapshot field.
type Provider interface {
// Identity and timing.
Hostname() string
// StartedAt is also used by Build to derive the uptime; a zero time yields
// an uptime of 0.
StartedAt() time.Time
LatestScan() time.Time
BaselineAt() time.Time
// WatcherStatuses maps watcher name to whether it is attached. Build copies
// the map, so the provider's own map is not retained by the snapshot.
WatcherStatuses() map[string]bool
StoreHealthy() bool
StoreSizeMB() float64
// SeverityCounts maps severity name to finding count. Build copies the map.
SeverityCounts() map[string]int
BlocklistSize() int
IncidentsOpen() int
BPFEnforcementActive() bool
HistoryCount() int
ConfigHash() string
BinaryHash() string
// DryRunBlocksCount feeds Snapshot.DryRunBlocks.
DryRunBlocksCount() int
AutomationStatus() AutomationStatus
UpdateInfo() UpdateInfo
}
// Build produces a Snapshot by querying every accessor on p once and adding
// the static version string and capability list.
//
// UptimeSec is the whole number of seconds since p.StartedAt(), or 0 when that
// time is zero. The capability slice and the severity and watcher maps are
// copied, so later changes by the caller or provider do not leak into the
// returned value. Build itself takes no locks; the provider's accessors are
// expected to be safe to call from any goroutine.
func Build(p Provider, version string, capabilities []string) Snapshot {
	started := p.StartedAt()

	var uptimeSec int64
	if !started.IsZero() {
		uptimeSec = int64(time.Since(started).Seconds())
	}

	var snap Snapshot
	snap.Version = version
	snap.Capabilities = append([]string(nil), capabilities...)
	snap.StartedAt = started
	snap.UptimeSec = uptimeSec

	// Provider accessors are invoked in the same order as the Snapshot fields
	// are declared.
	snap.Hostname = p.Hostname()
	snap.LatestScan = p.LatestScan()
	snap.BaselineAt = p.BaselineAt()
	snap.BlocklistSize = p.BlocklistSize()
	snap.IncidentsOpen = p.IncidentsOpen()
	snap.BPFEnforcementActive = p.BPFEnforcementActive()
	snap.HistoryCount = p.HistoryCount()
	snap.Severities = cloneIntMap(p.SeverityCounts())
	snap.Watchers = cloneBoolMap(p.WatcherStatuses())
	snap.StoreHealthy = p.StoreHealthy()
	snap.StoreSizeMB = p.StoreSizeMB()
	snap.ConfigHash = p.ConfigHash()
	snap.BinaryHash = p.BinaryHash()
	snap.DryRunBlocks = p.DryRunBlocksCount()
	snap.Automation = p.AutomationStatus()
	snap.Update = p.UpdateInfo()
	return snap
}
// cloneIntMap returns an independent copy of m. The result is always a
// non-nil map, even when m is nil or empty.
func cloneIntMap(m map[string]int) map[string]int {
	dup := make(map[string]int, len(m))
	for key := range m {
		dup[key] = m[key]
	}
	return dup
}
// cloneBoolMap returns an independent copy of m. The result is always a
// non-nil map, even when m is nil or empty.
func cloneBoolMap(m map[string]bool) map[string]bool {
	dup := make(map[string]bool, len(m))
	for key := range m {
		dup[key] = m[key]
	}
	return dup
}
package health
import (
"github.com/pidginhost/csm/internal/bpf"
"github.com/pidginhost/csm/internal/maillog"
)
// Capabilities returns the feature strings this build supports. Phpanel reads
// the list via /api/v1/capabilities to feature-detect without version
// sniffing. Add a string when shipping a feature and remove it when ripping
// one out. Keep the base order stable; conditional capabilities are appended
// after the base list only when this binary actually supports them.
//
// Appended after the base list, in this order:
//   - "mail.source.journal.v1" when maillog.JournalSupported() is true;
//   - the `bpf-...` program-type strings from appendBPFCaps;
//   - the `bpf-...` live-monitor strings from appendActiveBPFFeatures.
//
// The BPF strings depend on build tag and host kernel and are therefore not
// stable across deployments.
func Capabilities() []string {
	features := []string{
		"confd.dropins.v1",          // P1
		"profile.phpanel-agent.v1",  // P1
		"status.json.v1",            // P2
		"capabilities.v1",           // P2
		"doctor.v1",                 // P2
		"config.schema.v1",          // P2
		"sd_notify.ready",           // P2
		"audit.fields.tenant.v1",    // P3
		"webhook.phpanel.v1",        // P3
		"events.sse.v1",             // P3
		"token.scope.readonly.v1",   // P3
		"mail.brute.account_key.v1", // P4
		"ti.source.rspamd.v1",       // P4
		"auto_response.dry_run.v1",  // P5
		"infra_ips.guard.v1",        // P5
		"store.backup.v1",           // P5
		"ti.source.upstream.v1",     // P6
		"verdict.callback.v1",       // P7
		"systemd.dropin.example.v1", // P7
		"incidents.v1",
		"bpf_enforcement.available.v1",
		"webui.prefs.v1",       // operator preferences + saved views
		"webui.undo.v1",        // bulk-action undo
		"mail.filter.exfil.v1", // BEC mail-filter exfiltration detection
		"mail.queue.composition.v1",
		"mail.forward_guard.v1", // opt-in MTA-native forward-guard (hold spam/backscatter forward copies)
	}

	if journal := maillog.JournalSupported(); journal {
		features = append(features, "mail.source.journal.v1")
	}

	// Kernel program types first, then the monitors currently running on BPF.
	return appendActiveBPFFeatures(appendBPFCaps(features))
}
// appendActiveBPFFeatures appends one capability string for each BPF-backed
// live monitor whose active backend is bpf.BackendBPF (i.e. it is running on
// the kernel-side path rather than its userspace fallback) and returns the
// extended slice. Monitors are checked, and their strings appended, in a fixed
// order: connection tracker, AF_ALG, exec monitor, sensitive files.
func appendActiveBPFFeatures(out []string) []string {
	// add appends capability when the monitor is running on BPF.
	add := func(onBPF bool, capability string) {
		if onBPF {
			out = append(out, capability)
		}
	}

	add(bpf.ActiveKind("connection_tracker") == bpf.BackendBPF, "bpf-connection-tracker")
	add(bpf.ActiveKind("af_alg") == bpf.BackendBPF, "bpf-af-alg-live")
	add(bpf.ActiveKind("exec_monitor") == bpf.BackendBPF, "bpf-exec-monitor")
	add(bpf.ActiveKind("sensitive_files") == bpf.BackendBPF, "bpf-sensitive-files")
	return out
}
// bpfCapabilities returns the result of bpf.Probe. Tests use it to assert that
// the capability strings stay in sync with the shared BPF probe.
func bpfCapabilities() bpf.Capabilities {
	probed := bpf.Probe()
	return probed
}
// appendBPFCaps appends one capability string for each BPF program type the
// probe reports as accepted and returns the extended slice. The order is
// fixed: LSM attach, cgroup sock, tracepoint, ringbuf. Per-feature "monitor is
// currently running on BPF" strings are produced by a separate helper.
func appendBPFCaps(out []string) []string {
	probe := bpf.Probe()

	// add appends capability when the probe flag is set.
	add := func(supported bool, capability string) {
		if supported {
			out = append(out, capability)
		}
	}

	add(probe.LSMAttach, "bpf-lsm-attach")
	add(probe.CgroupSock, "bpf-cgroup-sock")
	add(probe.Tracepoint, "bpf-tracepoint")
	add(probe.Ringbuf, "bpf-ringbuf")
	return out
}
package health
import "time"
// Snapshot is the unified machine-readable health view assembled from the
// running daemon (or, on a cold lookup, from on-disk state). It is the
// single source of truth for /api/v1/status, csm status --json, csm doctor,
// and the sd_notify readiness gate.
//
// NOTE(review): several fields of struct type (the time.Time values,
// Automation, Update) carry `omitempty`. encoding/json never treats a struct
// value as empty, so those fields are always emitted; confirm whether clients
// rely on that before changing the tags.
type Snapshot struct {
Version string `json:"version"`
Hostname string `json:"hostname"`
StartedAt time.Time `json:"started_at"`
// UptimeSec is whole seconds since StartedAt at the time the snapshot was built.
UptimeSec int64 `json:"uptime_sec"`
LatestScan time.Time `json:"latest_scan,omitempty"`
BaselineAt time.Time `json:"baseline_at,omitempty"`
BlocklistSize int `json:"blocklist_size"`
IncidentsOpen int `json:"incidents_open"`
BPFEnforcementActive bool `json:"bpf_enforcement_active"`
HistoryCount int `json:"history_count"`
Severities map[string]int `json:"severities"` // "critical","high","warning"
Watchers map[string]bool `json:"watchers"` // name -> attached
StoreHealthy bool `json:"store_healthy"`
StoreSizeMB float64 `json:"store_size_mb"`
ConfigHash string `json:"config_hash,omitempty"`
BinaryHash string `json:"binary_hash,omitempty"`
Capabilities []string `json:"capabilities,omitempty"`
// DryRunBlocks is the count of firewall blocks that were intercepted by
// auto_response.dry_run and logged rather than applied to nftables.
// Cleared whenever auto-response is live; dry-run mode keeps a recent
// rolling window for operator review.
DryRunBlocks int `json:"dry_run_blocks,omitempty"`
// Automation is the operator-facing safety surface for automatic action
// rollout. It groups dry-run state, challenge routing, pending firewall
// rollback, and the last recorded automation action in one stable payload.
Automation AutomationStatus `json:"automation,omitempty"`
// Update reports whether a newer CSM release is available upstream.
// Populated by internal/updatecheck. Zero value means the checker has
// not yet completed a poll (very early startup) or is disabled in
// config.
Update UpdateInfo `json:"update,omitempty"`
}
// AutomationStatus summarizes the live automation safety state. It is
// intentionally compact so status clients can decide whether the host is
// observe-only, actively mutating the firewall, or waiting for operator
// confirmation after a tentative firewall apply.
//
// All fields are always present in the JSON output except
// FirewallRollbackSecondsRemain (omitted when zero) and LastAction (omitted
// when nil).
type AutomationStatus struct {
AutoResponseEnabled bool `json:"auto_response_enabled"`
AutoResponseBlockIPs bool `json:"auto_response_block_ips"`
AutoResponseDryRun bool `json:"auto_response_dry_run"`
DryRunBlocks int `json:"dry_run_blocks"`
ChallengeEnabled bool `json:"challenge_enabled"`
ChallengePortGateEnabled bool `json:"challenge_port_gate_enabled"`
ChallengePortGateActive bool `json:"challenge_port_gate_active"`
ChallengePending int `json:"challenge_pending"`
FirewallRollbackPending bool `json:"firewall_rollback_pending"`
FirewallRollbackSecondsRemain int64 `json:"firewall_rollback_seconds_remaining,omitempty"`
// LastAction is a pointer so "no action recorded yet" can be represented
// as nil and left out of the payload.
LastAction *AutomationAction `json:"last_action,omitempty"`
}
// AutomationAction is the newest action-like finding CSM recorded. It is
// exposed through AutomationStatus.LastAction.
type AutomationAction struct {
// Check and Message presumably mirror the originating finding's check name
// and message -- confirm against the code that populates this.
Check string `json:"check"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// UpdateInfo mirrors updatecheck.Info for the health snapshot. Kept
// as a separate type so internal/health does not import
// internal/updatecheck and create a cycle.
//
// NOTE(review): `omitempty` on CheckedAt has no effect because encoding/json
// never considers a time.Time struct empty; a zero time is still emitted.
type UpdateInfo struct {
LatestVersion string `json:"latest_version,omitempty"`
Available bool `json:"available,omitempty"`
Source string `json:"source,omitempty"`
CheckedAt time.Time `json:"checked_at,omitempty"`
// Err is a string rather than an error so it serializes as plain text.
Err string `json:"err,omitempty"`
}
// TotalFindings adds up the counts of every severity bucket in s.Severities.
// A nil or empty map yields 0.
func (s Snapshot) TotalFindings() int {
	var sum int
	for severity := range s.Severities {
		sum += s.Severities[severity]
	}
	return sum
}
// AllWatchersAttached reports whether every watcher in s.Watchers is marked
// attached. With no watchers registered it returns false: readiness is never
// claimed before anything has been probed.
func (s Snapshot) AllWatchersAttached() bool {
	detached := 0
	for _, ok := range s.Watchers {
		if !ok {
			detached++
		}
	}
	return len(s.Watchers) > 0 && detached == 0
}
// OverallStatus reduces the snapshot to a single word:
//   - "down" when the snapshot looks never-assembled (zero StartedAt and no
//     watchers at all);
//   - "ok" when the store is healthy and AllWatchersAttached is true;
//   - "degraded" in every other case, i.e. an unhealthy store, a detached
//     watcher, or no watchers registered on a started daemon.
func (s Snapshot) OverallStatus() string {
	neverAssembled := s.StartedAt.IsZero() && len(s.Watchers) == 0
	switch {
	case neverAssembled:
		return "down"
	case s.StoreHealthy && s.AllWatchersAttached():
		return "ok"
	default:
		return "degraded"
	}
}
package incident
import (
"fmt"
"sort"
"strings"
"time"
)
// BulkStatusFilter selects stale active incidents for a bounded operator
// transition. The caller must supply at least one age guard and a positive
// limit so a broad filter cannot accidentally close every incident.
type BulkStatusFilter struct {
// FromStatuses lists the statuses an incident may currently have to match.
// At least one valid status is required.
FromStatuses []Status
// To is the target status; only StatusResolved and StatusDismissed are accepted.
To Status
// OlderThan, when positive, excludes incidents whose UpdatedAt is after
// Now minus this duration.
OlderThan time.Duration
// LastSeenBefore, when non-zero, excludes incidents whose UpdatedAt is
// after this instant.
LastSeenBefore time.Time
// Kind, when non-empty, must equal the incident kind exactly.
Kind Kind
// Domain, Account and Mailbox, when non-empty, are compared
// case-insensitively against the incident's fields.
Domain string
Account string
Mailbox string
// Limit caps how many matched incidents are previewed or changed; must be positive.
Limit int
// DryRun reports what would change without mutating any incident.
DryRun bool
// Details is appended to the audit action recorded on each changed incident.
Details string
// Now overrides the clock for age evaluation and timestamps; the zero
// value means "use the correlator's clock".
Now time.Time
}
// BulkStatusItem is a small preview row for bulk incident status changes.
// Status is the incident's status before the transition and NewStatus is the
// requested target.
type BulkStatusItem struct {
ID string `json:"id"`
Kind string `json:"kind"`
Status string `json:"status"`
NewStatus string `json:"new_status"`
Severity string `json:"severity"`
Domain string `json:"domain,omitempty"`
Account string `json:"account,omitempty"`
Mailbox string `json:"mailbox,omitempty"`
CreatedAt time.Time `json:"created_at"`
// LastSeenAt is filled from the incident's UpdatedAt.
LastSeenAt time.Time `json:"last_seen_at"`
}
// BulkStatusResult reports how many incidents matched the filter and how
// many were changed. Items is capped by BulkStatusFilter.Limit.
type BulkStatusResult struct {
// Matched counts every incident that satisfied the filter, including those
// beyond the limit.
Matched int
// Updated counts incidents actually transitioned; it stays 0 for a dry run.
Updated int
// Items holds one row per incident that was (or, in a dry run, would be)
// changed, oldest UpdatedAt first.
Items []BulkStatusItem
}
// BulkSetStatus previews or applies one closing transition to matching
// incidents. Matching and mutation happen under the correlator lock so a
// fresh finding cannot update LastSeen between filter evaluation and close.
//
// The filter is rejected with an error when the target status is not
// resolved/dismissed, when neither age guard is set, when Limit is not
// positive, or when FromStatuses is empty or contains an invalid status.
//
// Matching incidents are ordered by UpdatedAt (oldest first, ID as the
// tie-break) and at most filter.Limit of them are processed. result.Matched
// counts all matches; result.Items and result.Updated cover only the processed
// ones. With DryRun set, Items is filled but nothing is modified.
func (c *Correlator) BulkSetStatus(filter BulkStatusFilter) (BulkStatusResult, error) {
// Validate the request before taking the lock.
if filter.To != StatusResolved && filter.To != StatusDismissed {
return BulkStatusResult{}, fmt.Errorf("incident: bulk target status must be resolved or dismissed")
}
if filter.OlderThan <= 0 && filter.LastSeenBefore.IsZero() {
return BulkStatusResult{}, fmt.Errorf("incident: bulk status requires older-than or last-seen-before")
}
if filter.Limit <= 0 {
return BulkStatusResult{}, fmt.Errorf("incident: bulk status requires a positive limit")
}
// A caller-supplied Now wins over the correlator's clock.
now := filter.Now
if now.IsZero() {
now = c.now()
}
// Deduplicate the source statuses into a set for O(1) membership checks.
statusSet := make(map[Status]struct{}, len(filter.FromStatuses))
for _, status := range filter.FromStatuses {
if !validStatus(status) {
return BulkStatusResult{}, fmt.Errorf("incident: invalid status %q", status)
}
statusSet[status] = struct{}{}
}
if len(statusSet) == 0 {
return BulkStatusResult{}, fmt.Errorf("incident: bulk status requires a source status")
}
// persist collects the writes queued under the lock; they run after unlock.
var persist []queuedPersist
result := BulkStatusResult{Items: make([]BulkStatusItem, 0, filter.Limit)}
c.mu.Lock()
matched := make([]*Incident, 0, len(c.incidents))
for _, inc := range c.incidents {
if bulkStatusMatches(inc, filter, statusSet, now) {
matched = append(matched, inc)
}
}
// Map iteration order is random; sort so the limit always keeps the
// stalest incidents and the output is deterministic.
sort.Slice(matched, func(i, j int) bool {
if !matched[i].UpdatedAt.Equal(matched[j].UpdatedAt) {
return matched[i].UpdatedAt.Before(matched[j].UpdatedAt)
}
return matched[i].ID < matched[j].ID
})
result.Matched = len(matched)
for _, inc := range matched {
if len(result.Items) >= filter.Limit {
break
}
// Capture the pre-transition status for the audit trail; the preview row
// is built before the incident is mutated.
from := inc.Status
result.Items = append(result.Items, bulkStatusItem(inc, filter.To))
if filter.DryRun {
continue
}
inc.Status = filter.To
inc.UpdatedAt = now
inc.ClosedAt = now
inc.ClosedBy = "operator"
inc.Actions = append(inc.Actions, IncidentAction{
Time: now,
Action: "incident_status_changed",
Result: "ok",
Details: string(from) + " -> " + string(filter.To) + ": " + filter.Details,
})
c.counters.statusChangedTotal.Add(1)
// Drop the closed incident's bindings (correlation key and, when the
// spray detector exists, its spray binding).
c.unbindLocked(inc.ID)
if c.spray != nil {
c.spray.UnbindIncident(inc.ID)
}
// Queue a copy of the incident for persistence; the write itself happens
// after the lock is released.
if req, ok := c.queuePersistLocked(*inc); ok {
persist = append(persist, req)
}
result.Updated++
}
c.mu.Unlock()
// Run the queued persistence outside the correlator lock.
for _, req := range persist {
c.runQueuedPersist(req)
}
return result, nil
}
// bulkStatusMatches reports whether inc satisfies every condition of filter.
//
// The incident's status must be in statusSet. Both age guards compare against
// inc.UpdatedAt: with OlderThan set the incident must not have been updated
// after now-OlderThan, and with LastSeenBefore set it must not have been
// updated after that instant (an update exactly at the cutoff still matches).
// Kind is compared exactly; Domain, Account and Mailbox case-insensitively.
// Empty filter fields place no constraint.
func bulkStatusMatches(inc *Incident, filter BulkStatusFilter, statusSet map[Status]struct{}, now time.Time) bool {
	if _, wanted := statusSet[inc.Status]; !wanted {
		return false
	}

	// Age guards.
	if filter.OlderThan > 0 && inc.UpdatedAt.After(now.Add(-filter.OlderThan)) {
		return false
	}
	if !filter.LastSeenBefore.IsZero() && inc.UpdatedAt.After(filter.LastSeenBefore) {
		return false
	}

	// Optional attribute filters.
	kindOK := filter.Kind == "" || inc.Kind == filter.Kind
	domainOK := filter.Domain == "" || strings.EqualFold(inc.Domain, filter.Domain)
	accountOK := filter.Account == "" || strings.EqualFold(inc.Account, filter.Account)
	mailboxOK := filter.Mailbox == "" || strings.EqualFold(inc.Mailbox, filter.Mailbox)
	return kindOK && domainOK && accountOK && mailboxOK
}
// bulkStatusItem builds the preview row for inc as it stands now, with to as
// the requested new status. It must be called before inc is mutated so Status
// reflects the pre-transition value. LastSeenAt is taken from inc.UpdatedAt.
func bulkStatusItem(inc *Incident, to Status) BulkStatusItem {
	var row BulkStatusItem
	row.ID = inc.ID
	row.Kind = string(inc.Kind)
	row.Status = string(inc.Status)
	row.NewStatus = string(to)
	row.Severity = inc.Severity.String()
	row.Domain = inc.Domain
	row.Account = inc.Account
	row.Mailbox = inc.Mailbox
	row.CreatedAt = inc.CreatedAt
	row.LastSeenAt = inc.UpdatedAt
	return row
}
package incident
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// ErrIncidentNotFound is returned when SetStatus or other lookups
// target an unknown incident id. It is a sentinel value; callers should
// compare with errors.Is.
var ErrIncidentNotFound = errors.New("incident: not found")
// maxIncidentFindings and maxIncidentTimeline cap the per-incident
// fingerprint slice and operator-visible timeline so a long-running
// open incident with sustained low-severity traffic does not grow
// memory and persistence payloads without bound. Eviction keeps the
// first half (incident-opening context an operator needs to root-
// cause the incident) and the most recent half (so the timeline
// reflects current activity). When a cap fires, a marker is appended so
// operators reading the timeline can see that a gap exists.
const (
maxIncidentFindings = 5000
maxIncidentTimeline = 500
// incidentFingerprintTruncatedMark and incidentTimelineTruncatedKind are the
// marker values used to flag a truncation (presumably as a fingerprint prefix
// and an IncidentEvent kind respectively -- confirm at the use sites).
incidentFingerprintTruncatedMark = "...truncated:"
incidentTimelineTruncatedKind = "truncated"
)
// incidentMergeWindow is the time gap inside which two findings with the
// same correlation key are considered the same incident. Named constant
// per project convention; config exposure deferred until operators ask.
// NewCorrelator also passes it to the credential-spray detector.
const incidentMergeWindow = 15 * time.Minute
// CorrelatorConfig carries the correlator's tunables and the callbacks the
// daemon wires in: the persistence hook, the open threshold, credential-spray
// suppression, the IP whitelist check, and the firewall hand-off gates and
// callbacks for both the spray and the generic auto-block paths. The zero
// value is usable and yields an in-memory correlator with every optional path
// disabled.
type CorrelatorConfig struct {
// Persist is invoked after every create/update. Implementations
// must be quick and idempotent. nil means "in-memory only".
Persist func(Incident)
// OpenThreshold is the number of correlated findings required before
// a non-Critical finding opens an incident. Critical-severity
// findings always open immediately so escalations page on first
// hit. Values <= 0 default to 1 (open on first finding) for
// backwards compatibility with callers that expect the original
// behavior; the daemon explicitly configures 2 to suppress
// one-shot scanner noise.
OpenThreshold int
// SpraySuppression turns on the credential-spray super-incident
// path. When zero the detector is not constructed and OnFinding
// follows the legacy per-mailbox correlation path. Default-off.
SpraySuppression SpraySuppressionConfig
// IsWhitelisted is consulted before a source-IP finding can anchor
// incident correlation and by the spray detector to skip IPs the
// operator has marked as known-good (e.g. internal mail relays).
// nil short-circuits to "no IPs whitelisted".
IsWhitelisted func(ip string) bool
// CanSprayBlock is consulted immediately before recording a
// credential_spray block request. nil means "allowed" when
// OnSprayBlock is present. Implementations must be quick and must
// not call back into the correlator.
CanSprayBlock func() bool
// OnSprayBlock is invoked once per IP when the credential_spray
// detector decides the IP should be hard-blocked, based on
// SpraySuppression.BlockAtSeverity. The callback runs after the
// correlator mutex is released so firewall or verdict latency does
// not stall incident ingestion. nil disables the hand-off; the spray
// super-incident still opens and escalates, but no firewall action
// fires. The return value reports whether the firewall actually
// recorded the block: false means dry-run, transient failure, or an
// upstream gate refused, and the audit "credential_spray_block_requested"
// action is not appended in that case so operators cannot mistake a
// declined request for an enforced block.
OnSprayBlock func(ip, reason string) bool
// AutoBlock turns on the generic incident-driven firewall hand-off
// for non-spray kinds. Independent of SpraySuppression; applies when
// an incident has exactly one unambiguous remote IP and its kind is
// allowed by AutoBlock.Kinds (empty = any). Default-zero means the
// path is dormant.
AutoBlock IncidentAutoBlockConfig
// CanIncidentBlock is consulted immediately before recording a
// generic incident block request. nil means "allowed" when
// OnIncidentBlock is present. Lets the daemon recheck
// auto_response.enabled / block_ips at decision time so SIGHUP
// edits take effect without rebuilding the correlator.
CanIncidentBlock func() bool
// OnIncidentBlock fires when the generic auto-block gate trips. The
// callback runs after the correlator mutex is released and returns
// true only when a live block request was accepted. Dry-run,
// disabled, and failed attempts must return false so the correlator
// can retry on the next finding instead of permanently latching the
// incident. nil disables the path even when AutoBlock is configured.
OnIncidentBlock func(ip, reason string) bool
}
// IncidentAutoBlockConfig drives the generic incident-driven firewall
// hand-off independent of credential-spray suppression. Operators turn
// it on once they have validated that incident severity is trustworthy
// (the daemon does not promote a finding to High/Critical without
// either an explicit per-check signal or the correlator's threshold
// gate).
type IncidentAutoBlockConfig struct {
// Enabled is the master switch for the generic auto-block path.
Enabled bool
// BlockAtSeverity is the minimum incident severity that triggers
// a firewall hand-off. "" / "high" / "critical". Comparison is
// case-insensitive. Any other value is ignored so operator typos
// cannot accidentally engage blocking.
BlockAtSeverity string
// Kinds, when non-empty, restricts the auto-block path to the
// listed incident kinds. Empty means "every kind that carries one
// unambiguous remote IP". Credential_spray is implicitly excluded
// since the dedicated spray hand-off owns it.
Kinds map[Kind]bool
}
// IsZero reports whether the config is entirely unset: not enabled, no
// BlockAtSeverity, and no kinds listed (a nil and an empty Kinds map count the
// same). The correlator reads a zero config as "generic auto-block disabled"
// without applying any defaults.
func (c IncidentAutoBlockConfig) IsZero() bool {
	if c.Enabled {
		return false
	}
	if c.BlockAtSeverity != "" {
		return false
	}
	return len(c.Kinds) == 0
}
// counters holds the atomic tallies exposed via RegisterMetrics. Kept
// on the Correlator so a single instance owns its own metric state and
// tests can build isolated correlators without touching globals.
// Every field is an atomic.Uint64, so increments need no extra locking.
type counters struct {
createdTotal atomic.Uint64
// severityChangedTotal is bumped when an incident's severity is raised.
severityChangedTotal atomic.Uint64
// statusChangedTotal is bumped once per incident whose status is changed.
statusChangedTotal atomic.Uint64
findingsMergedTotal atomic.Uint64
compactedTotal atomic.Uint64
autoClosedTotal atomic.Uint64
autoCloseDryRunTotal atomic.Uint64
// sprayOpenedTotal counts credential-spray incidents opened;
// spraySuppressedTotal counts findings folded into an existing one.
sprayOpenedTotal atomic.Uint64
spraySuppressedTotal atomic.Uint64
sprayDryRunTotal atomic.Uint64
}
// Correlator groups findings into incidents. In-memory state; the
// daemon is responsible for wiring it to a store via CorrelatorConfig.Persist.
// Build one with NewCorrelator, which initializes the maps and the clock.
type Correlator struct {
// mu is held while incidents are matched and mutated.
mu sync.Mutex
// persistMu protects persistTail. Persist callbacks wait on the
// previously queued write instead of holding this lock, so re-entrant
// callbacks can still take c.mu while later writers are queued.
persistMu sync.Mutex
persistTail chan struct{}
cfg CorrelatorConfig
// incidents holds every known incident keyed by id.
incidents map[string]*Incident
// byKey presumably maps a correlation key string to the id of the incident
// it is bound to -- confirm against the bind/unbind helpers.
byKey map[string]string
// pending holds findings that have not yet met the open threshold.
pending map[string]pendingFinding
pendingSprayBlocks map[string]struct{}
pendingIncidentBlocks map[string]struct{}
// openThreshold is cfg.OpenThreshold clamped to a minimum of 1.
openThreshold int
// now is the clock; it is time.Now unless replaced.
now func() time.Time
counters counters
spray *sprayDetector
}
// pendingFinding is a finding seen for a key that has not yet met the
// open threshold. Stored only on the create path; merge into open
// incidents stays unconditional.
type pendingFinding struct {
// finding is the stashed finding itself.
finding alert.Finding
// at is presumably when the finding was seen, used to age it out of the
// merge window -- confirm at the use site.
at time.Time
}
// NewCorrelator builds a Correlator from cfg. The result is ready to use
// immediately: the type is purely callback-driven and has nothing to start.
//
// cfg.OpenThreshold values below 1 are raised to 1. All internal maps are
// initialized, the persist tail starts out already closed, and the clock
// defaults to time.Now. The spray detector is created from
// cfg.SpraySuppression with incidentMergeWindow as its window; it reads the
// time through the correlator's clock, so replacing c.now later affects it
// too.
func NewCorrelator(cfg CorrelatorConfig) *Correlator {
	c := new(Correlator)
	c.cfg = cfg
	c.persistTail = closedPersistTail()
	c.incidents = make(map[string]*Incident)
	c.byKey = make(map[string]string)
	c.pending = make(map[string]pendingFinding)
	c.pendingSprayBlocks = make(map[string]struct{})
	c.pendingIncidentBlocks = make(map[string]struct{})
	c.now = time.Now

	c.openThreshold = cfg.OpenThreshold
	if c.openThreshold < 1 {
		c.openThreshold = 1
	}

	// Indirect through c.now rather than capturing time.Now directly.
	clock := func() time.Time { return c.now() }
	c.spray = newSprayDetector(cfg.SpraySuppression, incidentMergeWindow, clock, cfg.IsWhitelisted)
	return c
}
// closedPersistTail returns a fresh channel that is already closed, so a
// receive on it completes immediately. NewCorrelator uses it as the initial
// persistTail.
func closedPersistTail() chan struct{} {
	tail := make(chan struct{})
	defer close(tail)
	return tail
}
// OnFinding ingests a Finding. Returns the incident id (if attributable)
// and whether a new incident was created. Unattributable findings yield
// ("", false, nil). Non-Critical findings whose key has fewer than
// OpenThreshold prior findings inside the merge window are stashed in
// the pending map and yield ("", false, nil) too; they will only open
// an incident if the threshold is met inside the window.
//
// The error result is nil on every return path in this implementation.
//
// Any firewall-block callback produced while c.mu is held is deferred
// and invoked only after the mutex has been released.
//
// NOTE(review): the pending map holds a single finding per key, so for
// any openThreshold > 1 the second sighting inside the window promotes
// the key -- confirm that this matches the intended OpenThreshold
// semantics for values above 2.
func (c *Correlator) OnFinding(f alert.Finding) (string, bool, error) {
	key := KeyFor(f)
	if key.IsEmpty() {
		return "", false, nil
	}
	// Findings with no host dimension whose source IP is whitelisted are
	// dropped before any correlator state is touched.
	if key.Host == "" && f.SourceIP != "" && c.cfg.IsWhitelisted != nil && c.cfg.IsWhitelisted(f.SourceIP) {
		return "", false, nil
	}
	// afterUnlock carries the block callback (if any); the deferred func
	// below runs it strictly after c.mu.Unlock().
	var afterUnlock func()
	c.mu.Lock()
	defer func() {
		c.mu.Unlock()
		if afterUnlock != nil {
			afterUnlock()
		}
	}()
	keyStr := keyString(key)
	now := c.now()
	// Credential-spray super-incident path. When one source IP brute-forces
	// many distinct mailboxes/accounts inside the merge window, collapse
	// the per-mailbox fan-out into a single credential_spray incident
	// keyed on RemoteIP. The detector returns sprayDecisionNone for
	// non-spray traffic so the legacy correlation continues unchanged.
	if c.spray != nil {
		decision, hits := c.spray.Decide(f)
		switch decision {
		case sprayDecisionOpen:
			sprayKey := Key{RemoteIP: f.SourceIP}
			id := c.createSprayIncidentLocked(sprayKey, f, now, hits)
			c.spray.BindIncident(f.SourceIP, id)
			c.counters.sprayOpenedTotal.Add(1)
			if cb := c.maybeBlockSprayLocked(c.incidents[id], f.SourceIP, hits, now, "spray opened"); cb != nil {
				afterUnlock = cb
			}
			return id, true, nil
		case sprayDecisionSuppress:
			id := c.spray.IncidentForIP(f.SourceIP)
			inc, ok := c.incidents[id]
			if ok && incidentStatusActive(inc.Status) {
				c.mergeLocked(inc, f, now, true)
				// Escalate to CRITICAL once the distinct-mailbox count
				// reaches the configured SeverityEscalateAt.
				if hits >= c.spray.cfg.SeverityEscalateAt && inc.Severity < alert.Critical {
					from := inc.Severity
					inc.Severity = alert.Critical
					c.counters.severityChangedTotal.Add(1)
					inc.Actions = append(inc.Actions, IncidentAction{
						Time:    now,
						Action:  "incident_severity_changed",
						Result:  "ok",
						Details: from.String() + " -> CRITICAL: spray sustained " + strconv.Itoa(hits) + " mailboxes",
					})
					c.persistLocked(*inc)
				}
				// Re-evaluate the block gate on every merged finding. The
				// configured BlockAtSeverity may have been armed AFTER the
				// incident was opened or escalated, in which case the
				// transition-time hook already fired without effect; the
				// helper is idempotent via triggerSprayBlockLocked's
				// action-presence and in-flight guards so a no-op call is
				// harmless.
				if cb := c.maybeBlockSprayLocked(inc, f.SourceIP, hits, now, "spray ongoing"); cb != nil {
					afterUnlock = cb
				}
				c.counters.spraySuppressedTotal.Add(1)
				return id, false, nil
			}
			// Bound incident vanished (purged) or is no longer active
			// (operator resolved/dismissed). Clear the perIP binding so
			// subsequent findings don't keep falling into the same dead
			// lookup, then fall through to legacy so the finding still
			// produces an incident rather than silently disappearing.
			if id != "" {
				c.spray.UnbindIncident(id)
			}
		case sprayDecisionNone:
			// Dry-run accounting only: count findings that reached the
			// distinct-mailbox trip point while the detector is in DryRun.
			if c.spray.cfg.DryRun && hits >= c.spray.cfg.DistinctMailboxes {
				c.counters.sprayDryRunTotal.Add(1)
			}
		}
	}
	// Legacy per-key correlation: join the incident already bound to this
	// key when it was updated inside the merge window.
	if id, ok := c.byKey[keyStr]; ok {
		if inc, exists := c.incidents[id]; exists && now.Sub(inc.UpdatedAt) <= incidentMergeWindow {
			c.mergeLocked(inc, f, now, true)
			delete(c.pending, keyStr)
			if cb := c.maybeBlockIncidentLocked(inc, now, "merge"); cb != nil {
				afterUnlock = cb
			}
			return id, false, nil
		}
		// Stale binding -- the incident is older than the merge window.
		// Drop the binding and fall through to create so a fresh incident
		// owns the key going forward.
		delete(c.byKey, keyStr)
	}
	// Threshold gate. Non-Critical findings need OpenThreshold sightings
	// inside the merge window before opening an incident, except High
	// host-integrity findings. Those are not scanner noise and must
	// surface on the first sighting like Critical escalations.
	if c.openThreshold > 1 && !opensIncidentImmediately(f) {
		if pf, ok := c.pending[keyStr]; ok && now.Sub(pf.at) <= incidentMergeWindow {
			// Promote: the incident is created from the stashed finding
			// (with its original timestamp) and the current finding is
			// then merged into it.
			delete(c.pending, keyStr)
			id := c.createIncidentLocked(key, keyStr, pf.finding, pf.at)
			inc := c.incidents[id]
			c.mergeLocked(inc, f, now, true)
			if cb := c.maybeBlockIncidentLocked(inc, now, "threshold promote"); cb != nil {
				afterUnlock = cb
			}
			return id, true, nil
		}
		// First sighting (or the previous one aged out): stash and wait.
		c.pending[keyStr] = pendingFinding{finding: f, at: now}
		return "", false, nil
	}
	// No gate applies: open an incident on this finding directly.
	id := c.createIncidentLocked(key, keyStr, f, now)
	delete(c.pending, keyStr)
	if cb := c.maybeBlockIncidentLocked(c.incidents[id], now, "incident opened"); cb != nil {
		afterUnlock = cb
	}
	return id, true, nil
}
// createSprayIncidentLocked opens a credential_spray incident for key,
// registers it in the id and key indexes, and seeds it with f through
// mergeLocked. The caller must hold c.mu. The incident never opens below
// HIGH; escalation to CRITICAL happens later on the merge path once
// SpraySuppressionConfig.SeverityEscalateAt distinct mailboxes are seen.
// Returns the new incident id.
func (c *Correlator) createSprayIncidentLocked(key Key, f alert.Finding, now time.Time, hits int) string {
	severity := alert.High
	if f.Severity > severity {
		severity = f.Severity
	}

	opened := IncidentAction{
		Time:    now,
		Action:  "credential_spray_opened",
		Result:  "ok",
		Details: f.SourceIP + " hit " + strconv.Itoa(hits) + " distinct mailboxes inside window",
	}

	id := newIncidentID()
	inc := &Incident{
		ID:             id,
		Kind:           KindCredentialSpray,
		Status:         StatusOpen,
		Severity:       severity,
		CorrelationKey: cloneKey(key),
		Findings:       []string{},
		Timeline:       []IncidentEvent{},
		Actions:        []IncidentAction{opened},
		CreatedAt:      now,
		UpdatedAt:      now,
	}

	c.incidents[id] = inc
	c.byKey[keyString(key)] = id
	c.counters.createdTotal.Add(1)
	// merged=false: this seeds the first finding, it is not a join.
	c.mergeLocked(inc, f, now, false)
	return id
}
// createIncidentLocked opens a new incident for key, registers it under
// both its id and keyStr, and seeds it with f through mergeLocked. The
// caller must hold c.mu. Persisting is left entirely to mergeLocked so
// that creation triggers exactly one Persist call. Returns the new id.
func (c *Correlator) createIncidentLocked(key Key, keyStr string, f alert.Finding, now time.Time) string {
	id := newIncidentID()

	// Prefer the finding's own mailbox/domain for display; fall back to
	// the correlation key when the finding yields neither.
	mailbox, domain := displayMailboxDomain(f.Mailbox, f.Domain)
	if mailbox == "" && domain == "" {
		mailbox = key.Mailbox
		domain = key.Domain
	}

	inc := &Incident{
		ID:             id,
		Kind:           ClassifyKind(f),
		Status:         StatusOpen,
		Severity:       f.Severity,
		Account:        key.Account,
		Domain:         domain,
		Mailbox:        mailbox,
		CorrelationKey: cloneKey(key),
		Findings:       []string{},
		Timeline:       []IncidentEvent{},
		Actions:        []IncidentAction{},
		CreatedAt:      now,
		UpdatedAt:      now,
	}

	c.incidents[id] = inc
	c.byKey[keyStr] = id
	c.counters.createdTotal.Add(1)
	// merged=false: this seeds the first finding, it is not a join.
	c.mergeLocked(inc, f, now, false)
	return id
}
// PendingCount reports how many findings are currently parked in the
// threshold-gate pending map. Intended for metrics and tests.
func (c *Correlator) PendingCount() int {
	c.mu.Lock()
	held := len(c.pending)
	c.mu.Unlock()
	return held
}
// PruneStalePending drops every pending finding stashed more than one
// merge window before now and returns how many were dropped. The
// daemon's retention loop calls it so that sustained one-shot scanner
// traffic cannot grow the pending map without bound.
func (c *Correlator) PruneStalePending(now time.Time) int {
	oldestKept := now.Add(-incidentMergeWindow)

	c.mu.Lock()
	defer c.mu.Unlock()

	before := len(c.pending)
	for key, held := range c.pending {
		if !held.at.Before(oldestKept) {
			continue
		}
		delete(c.pending, key)
	}
	return before - len(c.pending)
}
// PruneStaleSpray asks the spray detector to drop state it considers
// stale relative to now and returns the detector's pruned count. Wired
// into the same retention sweep as PruneStalePending.
//
// A Correlator without a spray detector (c.spray == nil) reports 0
// instead of dereferencing the nil detector, matching the nil guard
// every other spray call site in this type applies.
func (c *Correlator) PruneStaleSpray(now time.Time) int {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.spray == nil {
		return 0
	}
	return c.spray.PruneStale(now)
}
// SprayTrackedIPs reports the count of source IPs currently held in
// the spray detector. Safe to call when the detector is disabled or
// absent: a Correlator with no detector (c.spray == nil) reports 0,
// matching the nil guard the other spray call sites in this type apply.
func (c *Correlator) SprayTrackedIPs() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.spray == nil {
		return 0
	}
	return c.spray.TrackedIPs()
}
// OpenCount reports how many incidents are currently Open or Contained.
// It backs the csm_incidents_open gauge and is computed from in-memory
// state on every call, so it cannot drift from that state.
func (c *Correlator) OpenCount() int {
	c.mu.Lock()
	defer c.mu.Unlock()

	active := 0
	for _, inc := range c.incidents {
		switch inc.Status {
		case StatusOpen, StatusContained:
			active++
		}
	}
	return active
}
// Get looks up an incident by id and returns a deep copy of it. The
// boolean is false (with a zero Incident) when the id is unknown.
func (c *Correlator) Get(id string) (Incident, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if inc, found := c.incidents[id]; found {
		return cloneIncident(*inc), true
	}
	return Incident{}, false
}
// SnapshotPage returns one page of incidents filtered by a single
// status; the empty status selects every incident. It is a thin wrapper
// over SnapshotPageStatuses and shares its contract:
//
//   - the returned total counts every record matching the filter,
//     independent of offset and limit, so callers can render "X of Y";
//   - limit <= 0 returns everything after offset (callers such as the
//     web UI / phpanel are expected to apply their own ceiling);
//   - a negative offset is treated as zero;
//   - returned items are deep copies and may be mutated freely.
func (c *Correlator) SnapshotPage(status Status, offset, limit int) ([]Incident, int) {
	var filter []Status
	if status != "" {
		filter = []Status{status}
	}
	return c.SnapshotPageStatuses(filter, offset, limit)
}
// SnapshotPageStatuses returns a page of incidents matching any status
// in statuses, together with the total number of matches. An empty
// status list (or one containing only empty strings) means all
// statuses. Sorting and slicing happen against internal pointers first;
// only the returned page is deep-copied.
//
// A negative offset is clamped to zero; an offset at or past the total
// yields an empty, non-nil page. limit <= 0 returns every match after
// offset.
func (c *Correlator) SnapshotPageStatuses(statuses []Status, offset, limit int) ([]Incident, int) {
	c.mu.Lock()
	defer c.mu.Unlock()
	statusSet := make(map[Status]struct{}, len(statuses))
	for _, st := range statuses {
		if st != "" {
			statusSet[st] = struct{}{}
		}
	}
	matched := make([]*Incident, 0, len(c.incidents))
	for _, inc := range c.incidents {
		if len(statusSet) > 0 {
			if _, ok := statusSet[inc.Status]; !ok {
				continue
			}
		}
		matched = append(matched, inc)
	}
	sortIncidentRefs(matched)
	total := len(matched)
	if offset < 0 {
		offset = 0
	}
	if offset >= total {
		return []Incident{}, total
	}
	// offset < total here, so total-offset is positive. Comparing limit
	// against the remaining count (rather than computing offset+limit
	// first) keeps a very large limit from wrapping around to a negative
	// end index and panicking on the slice expression below.
	end := total
	if limit > 0 && limit < total-offset {
		end = offset + limit
	}
	out := make([]Incident, 0, end-offset)
	for _, inc := range matched[offset:end] {
		out = append(out, cloneIncident(*inc))
	}
	return out, total
}
// sortIncidentRefs orders refs in place using incidentRefLess.
func sortIncidentRefs(refs []*Incident) {
	less := func(i, j int) bool {
		return incidentRefLess(refs[i], refs[j])
	}
	sort.Slice(refs, less)
}
// incidentRefLess reports whether a sorts before b: the more recently
// updated incident comes first, and equal UpdatedAt values are ordered
// by descending ID.
func incidentRefLess(a, b *Incident) bool {
	if a.UpdatedAt.Equal(b.UpdatedAt) {
		return a.ID > b.ID
	}
	return a.UpdatedAt.After(b.UpdatedAt)
}
// Snapshot returns a deep copy of every incident, most recently updated
// first. It takes c.mu for the duration of the call, and because the
// result shares no slices with internal state the API layer can
// serialize it without coordinating with mutators.
func (c *Correlator) Snapshot() []Incident {
	c.mu.Lock()
	defer c.mu.Unlock()

	ordered := make([]*Incident, 0, len(c.incidents))
	for _, inc := range c.incidents {
		ordered = append(ordered, inc)
	}
	sortIncidentRefs(ordered)

	copies := make([]Incident, len(ordered))
	for i, inc := range ordered {
		copies[i] = cloneIncident(*inc)
	}
	return copies
}
// mergeLocked folds f into inc. merged=true means this is a join into
// an existing incident (bumps findings_merged_total); merged=false means
// the caller already created the incident and is using mergeLocked only
// to seed the first finding -- in that case the create path owns the
// "did a new incident appear" tally.
//
// The caller must hold c.mu. The final persistLocked call releases and
// re-acquires c.mu around the Persist callback, so other goroutines may
// run against the correlator before mergeLocked returns.
func (c *Correlator) mergeLocked(inc *Incident, f alert.Finding, now time.Time, merged bool) {
	if merged {
		c.counters.findingsMergedTotal.Add(1)
	}
	// Re-classify before appending so timeline-aware compound rules
	// see the unchanged history; the new finding is passed in
	// explicitly so its Check participates in the compound check.
	priorKind := inc.Kind
	maybeReclassifyKind(inc, f)
	if inc.Kind != priorKind {
		inc.Actions = append(inc.Actions, IncidentAction{
			Time:    now,
			Action:  "incident_kind_changed",
			Result:  "ok",
			Details: string(priorKind) + " -> " + string(inc.Kind),
		})
	}
	inc.Findings = appendCappedFingerprint(inc.Findings, f.Fingerprint())
	// The timeline event carries the finding's own timestamp, not the
	// correlator clock; `now` is used only for actions and UpdatedAt.
	ev := IncidentEvent{
		Time:    f.Timestamp,
		Kind:    "finding",
		Check:   f.Check,
		Message: f.Message,
	}
	// Optional dimensions are copied only when the finding carries them.
	if f.Process != nil {
		ev.PID = f.Process.PID
		ev.UID = f.Process.UID
		ev.Process = f.Process.Comm
	}
	if f.FilePath != "" {
		ev.Path = f.FilePath
	}
	if f.SourceIP != "" {
		ev.RemoteIP = f.SourceIP
	}
	inc.Timeline = appendCappedTimeline(inc.Timeline, ev)
	// Severity only ratchets upward; a lower-severity finding never
	// downgrades the incident.
	if f.Severity > inc.Severity {
		from := inc.Severity
		inc.Severity = f.Severity
		c.counters.severityChangedTotal.Add(1)
		inc.Actions = append(inc.Actions, IncidentAction{
			Time:    now,
			Action:  "incident_severity_changed",
			Result:  "ok",
			Details: from.String() + " -> " + f.Severity.String(),
		})
	}
	inc.UpdatedAt = now
	c.persistLocked(*inc)
}
// appendCappedFingerprint appends fp to fps and, once the result grows
// past maxIncidentFindings, rebuilds the list as head + truncation
// marker + tail. The head keeps the earliest fingerprints, the tail the
// latest, and the marker records how many middle entries were dropped
// (including counts carried over from earlier markers).
func appendCappedFingerprint(fps []string, fp string) []string {
	fps = append(fps, fp)
	if len(fps) <= maxIncidentFindings {
		return fps
	}

	// Split genuine fingerprints from earlier truncation markers,
	// accumulating the counts those markers recorded.
	kept := make([]string, 0, len(fps))
	dropped := 0
	for _, entry := range fps {
		n, isMarker := fingerprintTruncationCount(entry)
		if isMarker {
			dropped += n
		} else {
			kept = append(kept, entry)
		}
	}
	if len(kept) <= maxIncidentFindings {
		return fingerprintsWithTruncationMarker(kept, dropped)
	}

	headLen := maxIncidentFindings / 2
	tailStart := len(kept) - (maxIncidentFindings - headLen)
	dropped += tailStart - headLen

	out := make([]string, 0, maxIncidentFindings+1)
	out = append(out, kept[:headLen]...)
	out = append(out, formatFingerprintTruncation(dropped))
	out = append(out, kept[tailStart:]...)
	return out
}
// appendCappedTimeline appends ev to events and, once the result grows
// past maxIncidentTimeline, rebuilds the list as head + synthetic
// truncation event + tail. The synthetic event carries the total number
// of dropped events so the UI can render an "X events elided" row. Its
// timestamp is the earliest timestamp among pre-existing markers, or --
// when there were none -- that of the first dropped event.
func appendCappedTimeline(events []IncidentEvent, ev IncidentEvent) []IncidentEvent {
	events = append(events, ev)
	if len(events) <= maxIncidentTimeline {
		return events
	}

	kept := make([]IncidentEvent, 0, len(events))
	var (
		dropped int
		gapAt   time.Time
	)
	for _, entry := range events {
		n, isMarker := timelineTruncationCount(entry)
		if !isMarker {
			kept = append(kept, entry)
			continue
		}
		dropped += n
		if gapAt.IsZero() || entry.Time.Before(gapAt) {
			gapAt = entry.Time
		}
	}
	if len(kept) <= maxIncidentTimeline {
		return timelineWithTruncationMarker(kept, dropped, gapAt)
	}

	headLen := maxIncidentTimeline / 2
	tailStart := len(kept) - (maxIncidentTimeline - headLen)
	if gapAt.IsZero() {
		gapAt = kept[headLen].Time
	}
	dropped += tailStart - headLen

	out := make([]IncidentEvent, 0, maxIncidentTimeline+1)
	out = append(out, kept[:headLen]...)
	out = append(out, timelineTruncationMarker(dropped, gapAt))
	out = append(out, kept[tailStart:]...)
	return out
}
// fingerprintsWithTruncationMarker returns fps with a truncation marker
// for `elided` entries inserted at the midpoint. When elided is zero
// the input slice is returned unchanged; otherwise a new slice is built.
func fingerprintsWithTruncationMarker(fps []string, elided int) []string {
	if elided == 0 {
		return fps
	}
	mid := len(fps) / 2
	withGap := make([]string, 0, len(fps)+1)
	withGap = append(withGap, fps[:mid]...)
	withGap = append(withGap, formatFingerprintTruncation(elided))
	return append(withGap, fps[mid:]...)
}
// fingerprintTruncationCount recognises a truncation marker produced by
// formatFingerprintTruncation. The boolean is false for ordinary
// fingerprints. For a marker, the count is the leading integer after the
// prefix; a missing, unparsable, or non-positive number counts as 1.
func fingerprintTruncationCount(fp string) (int, bool) {
	if !strings.HasPrefix(fp, incidentFingerprintTruncatedMark) {
		return 0, false
	}
	tokens := strings.Fields(fp[len(incidentFingerprintTruncatedMark):])
	if len(tokens) > 0 {
		if n, err := strconv.Atoi(tokens[0]); err == nil && n >= 1 {
			return n, true
		}
	}
	return 1, true
}
// formatFingerprintTruncation renders the marker string that stands in
// for count dropped fingerprints: the truncation prefix, the count, and
// a fixed " findings elided" suffix.
func formatFingerprintTruncation(count int) string {
	digits := strconv.FormatInt(int64(count), 10)
	return incidentFingerprintTruncatedMark + digits + " findings elided"
}
// timelineWithTruncationMarker returns events with a synthetic
// truncation event for `elided` entries inserted at the midpoint. When
// elided is zero the input is returned unchanged. A zero markerTime is
// replaced by the timestamp of the midpoint event when one exists.
func timelineWithTruncationMarker(events []IncidentEvent, elided int, markerTime time.Time) []IncidentEvent {
	if elided == 0 {
		return events
	}
	mid := len(events) / 2
	if markerTime.IsZero() && len(events) > 0 {
		markerTime = events[mid].Time
	}
	withGap := make([]IncidentEvent, 0, len(events)+1)
	withGap = append(withGap, events[:mid]...)
	withGap = append(withGap, timelineTruncationMarker(elided, markerTime))
	return append(withGap, events[mid:]...)
}
// timelineTruncationCount recognises a synthetic truncation event by
// its Kind. The boolean is false for ordinary events. For a marker, the
// count is the leading integer of its Message; a missing, unparsable,
// or non-positive number counts as 1.
func timelineTruncationCount(ev IncidentEvent) (int, bool) {
	if ev.Kind != incidentTimelineTruncatedKind {
		return 0, false
	}
	if tokens := strings.Fields(ev.Message); len(tokens) > 0 {
		if n, err := strconv.Atoi(tokens[0]); err == nil && n >= 1 {
			return n, true
		}
	}
	return 1, true
}
// timelineTruncationMarker builds the synthetic timeline event that
// stands in for count dropped events, stamped with the given time.
func timelineTruncationMarker(count int, at time.Time) IncidentEvent {
	var marker IncidentEvent
	marker.Time = at
	marker.Kind = incidentTimelineTruncatedKind
	marker.Message = strconv.Itoa(count) + " events elided to cap incident size"
	return marker
}
// queuedPersist is one reserved slot in the persist chain: a snapshot
// to write plus the signals that order it after the write queued before
// it. Built by queuePersistLocked and consumed by runQueuedPersist.
type queuedPersist struct {
	// previous is the done channel of the write queued immediately
	// before this one; runQueuedPersist waits on it before persisting.
	previous <-chan struct{}
	// done is closed by runQueuedPersist once this write has finished,
	// releasing the next write in the chain.
	done chan struct{}
	// snap is the deep copy of the incident taken when the slot was
	// reserved.
	snap Incident
	// persist is the Persist callback captured from the correlator
	// config when the slot was reserved.
	persist func(Incident)
}
// queuePersistLocked reserves a slot for persisting snap, fixing this
// write's position in mutation order while the caller still holds c.mu.
// The incident is deep-copied here so later mutations cannot leak into
// the write. The returned request must be passed to runQueuedPersist
// after c.mu has been released. The boolean is false when no Persist
// callback is configured, in which case nothing is queued.
func (c *Correlator) queuePersistLocked(snap Incident) (queuedPersist, bool) {
	persist := c.cfg.Persist
	if persist == nil {
		return queuedPersist{}, false
	}

	req := queuedPersist{
		done:    make(chan struct{}),
		snap:    cloneIncident(snap),
		persist: persist,
	}

	// Link behind the current tail and become the new tail.
	c.persistMu.Lock()
	req.previous = c.persistTail
	c.persistTail = req.done
	c.persistMu.Unlock()

	return req, true
}
// runQueuedPersist executes one reserved persist: it waits for the
// previously queued write to finish, invokes the Persist callback with
// the captured snapshot, and then closes req.done to release the next
// write. The close is deferred, so it happens even if the callback
// panics.
func (c *Correlator) runQueuedPersist(req queuedPersist) {
	defer close(req.done)
	<-req.previous
	req.persist(req.snap)
}
// persistLocked persists snap from a context that already holds c.mu.
// The write's slot is reserved under the lock; the mutex is then
// released for the duration of the Persist callback -- so a callback
// that reads Correlator state cannot deadlock -- and re-acquired before
// returning. Callers therefore still hold c.mu afterwards, but other
// goroutines may have run in between. With no Persist callback
// configured this is a no-op and the lock is never released.
func (c *Correlator) persistLocked(snap Incident) {
	if req, queued := c.queuePersistLocked(snap); queued {
		c.mu.Unlock()
		defer c.mu.Lock()
		c.runQueuedPersist(req)
	}
}
// SetStatus moves the incident identified by id to status and records
// the transition (with the caller's details text) in its action log.
//
// Moving to Resolved or Dismissed removes the incident from the active
// key index and from the spray detector, so later findings for the same
// correlation key open a fresh incident. Moving back to Open or
// Contained clears the close attribution and re-binds the key.
//
// Returns an error for a status outside the known set,
// ErrIncidentNotFound for an unknown id, and nil (without touching
// anything) when the incident already has the requested status.
func (c *Correlator) SetStatus(id string, status Status, details string) error {
	if !validStatus(status) {
		return fmt.Errorf("incident: invalid status %q", status)
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	inc, found := c.incidents[id]
	if !found {
		return ErrIncidentNotFound
	}
	if inc.Status == status {
		return nil
	}

	now := c.now()
	previous := inc.Status
	inc.Status = status
	inc.UpdatedAt = now
	inc.Actions = append(inc.Actions, IncidentAction{
		Time:    now,
		Action:  "incident_status_changed",
		Result:  "ok",
		Details: string(previous) + " -> " + string(status) + ": " + details,
	})
	c.counters.statusChangedTotal.Add(1)

	switch status {
	case StatusResolved, StatusDismissed:
		// Attribute the close to the operator unless a close time is
		// already recorded on the incident, in which case the existing
		// attribution is kept.
		if inc.ClosedAt.IsZero() {
			inc.ClosedAt = now
			inc.ClosedBy = "operator"
		}
		c.unbindLocked(id)
		if c.spray != nil {
			c.spray.UnbindIncident(id)
		}
	default:
		// Back to an active status: wipe the close attribution so the
		// next close is attributed afresh, and reclaim the key.
		inc.ClosedAt = time.Time{}
		inc.ClosedBy = ""
		c.bindLocked(inc)
	}

	c.persistLocked(*inc)
	return nil
}
// CloseStale auto-resolves Open / Contained incidents that have been
// idle longer than the threshold configured for their kind in
// idleThresholds. Kinds missing from the map are never closed. With
// dryRun=true nothing is mutated and only the would-close decisions are
// counted, letting an operator validate thresholds before going live.
// Returns (closed, dryRun, total-scanned).
//
// It is CloseStaleLimited with no per-call cap; see that method for the
// unbinding performed on each close.
func (c *Correlator) CloseStale(now time.Time, idleThresholds map[Kind]time.Duration, dryRun bool) (closed, dryRunCount, scanned int) {
	const unlimited = 0
	nClosed, nDry, nScanned, _ := c.CloseStaleLimited(now, idleThresholds, dryRun, unlimited)
	return nClosed, nDry, nScanned
}
// CloseStaleLimited is CloseStale with a per-call cap on how many
// incidents are resolved. limit <= 0 means no cap. When the cap stops
// the sweep, more=true tells the caller that stale incidents remain so
// it can schedule a prompt follow-up instead of waiting a full
// auto-close interval. The cap bounds how long c.mu is held and how
// many persists are issued in one tick after a large backlog builds up.
// It applies to live closes only: a dry-run pass always walks the full
// set so its counters stay accurate.
//
// Each live close resolves the incident with ClosedBy "auto:stale",
// unbinds its key and spray-detector state, and queues a persist. The
// queued persists run after c.mu has been released.
func (c *Correlator) CloseStaleLimited(now time.Time, idleThresholds map[Kind]time.Duration, dryRun bool, limit int) (closed, dryRunCount, scanned int, more bool) {
	if len(idleThresholds) == 0 {
		return 0, 0, 0, false
	}

	capped := !dryRun && limit > 0
	var queue []queuedPersist

	c.mu.Lock()
	for id, inc := range c.incidents {
		switch inc.Status {
		case StatusOpen, StatusContained:
		default:
			continue
		}
		maxIdle, listed := idleThresholds[inc.Kind]
		if !listed || maxIdle <= 0 {
			continue
		}
		idleFor := now.Sub(inc.UpdatedAt)
		if idleFor <= maxIdle {
			continue
		}
		if capped && closed >= limit {
			more = true
			break
		}

		scanned++
		if dryRun {
			c.counters.autoCloseDryRunTotal.Add(1)
			dryRunCount++
			continue
		}

		previous := inc.Status
		inc.Status = StatusResolved
		inc.UpdatedAt = now
		inc.ClosedAt = now
		inc.ClosedBy = "auto:stale"
		inc.Actions = append(inc.Actions, IncidentAction{
			Time:    now,
			Action:  "incident_auto_closed",
			Result:  "ok",
			Details: string(previous) + " -> resolved: stale " + idleFor.Truncate(time.Second).String(),
		})
		c.counters.statusChangedTotal.Add(1)
		c.counters.autoClosedTotal.Add(1)
		c.unbindLocked(id)
		if c.spray != nil {
			c.spray.UnbindIncident(id)
		}
		if req, queued := c.queuePersistLocked(*inc); queued {
			queue = append(queue, req)
		}
		closed++
	}
	c.mu.Unlock()

	for i := range queue {
		c.runQueuedPersist(queue[i])
	}
	return closed, dryRunCount, scanned, more
}
// closeIncidentLocked resolves inc on behalf of an automatic policy.
// It stamps the close time and attribution (by), logs an
// "incident_auto_closed" action whose details end with detail, bumps the
// status-changed and auto-closed counters, and unbinds the incident from
// the key index and the spray detector. The caller must hold c.mu and is
// responsible for queueing and running the persist of *inc.
func (c *Correlator) closeIncidentLocked(inc *Incident, id string, now time.Time, by, detail string) {
	closedAction := IncidentAction{
		Time:    now,
		Action:  "incident_auto_closed",
		Result:  "ok",
		Details: string(inc.Status) + " -> resolved: " + detail,
	}

	inc.Status = StatusResolved
	inc.UpdatedAt = now
	inc.ClosedAt = now
	inc.ClosedBy = by
	inc.Actions = append(inc.Actions, closedAction)

	c.counters.statusChangedTotal.Add(1)
	c.counters.autoClosedTotal.Add(1)

	c.unbindLocked(id)
	if c.spray != nil {
		c.spray.UnbindIncident(id)
	}
}
// CloseStaleByAge is a safety cap that ignores incident kind: every
// Open or Contained incident whose UpdatedAt lies more than maxAge
// before now is force-closed with ClosedBy "auto:age_cap", regardless
// of the operator's per-kind auto-close thresholds. It exists so that
// disabling auto-close, or leaving a kind out of the threshold map,
// cannot let active incidents pile up without bound in memory and bbolt
// on a host under sustained attack.
//
// maxAge <= 0 disables the sweep. limit bounds the closures per call
// (0 = unbounded); more reports that stale incidents were left behind.
// Queued persists run after c.mu has been released.
func (c *Correlator) CloseStaleByAge(now time.Time, maxAge time.Duration, limit int) (closed int, more bool) {
	if maxAge <= 0 {
		return 0, false
	}

	detail := "stale age cap " + maxAge.Truncate(time.Second).String()
	var queue []queuedPersist

	c.mu.Lock()
	for id, inc := range c.incidents {
		switch inc.Status {
		case StatusOpen, StatusContained:
		default:
			continue
		}
		if now.Sub(inc.UpdatedAt) <= maxAge {
			continue
		}
		if limit > 0 && closed >= limit {
			more = true
			break
		}
		c.closeIncidentLocked(inc, id, now, "auto:age_cap", detail)
		if req, queued := c.queuePersistLocked(*inc); queued {
			queue = append(queue, req)
		}
		closed++
	}
	c.mu.Unlock()

	for i := range queue {
		c.runQueuedPersist(queue[i])
	}
	return closed, more
}
// EnforceActiveCap limits how many Open/Contained incidents are kept
// active. When more than maxActive are active, the least recently
// updated ones are force-closed (ClosedBy "auto:active_cap") until the
// count is back at maxActive or the per-call limit has been reached.
// This guards against a flood of distinct incidents arriving inside the
// age-cap window.
//
// maxActive <= 0 disables the cap; limit <= 0 means no per-call limit.
// more reports that the limit stopped the sweep while incidents were
// still over the cap. Queued persists run after c.mu has been released.
func (c *Correlator) EnforceActiveCap(now time.Time, maxActive, limit int) (closed int, more bool) {
	if maxActive <= 0 {
		return 0, false
	}

	type candidate struct {
		id      string
		updated time.Time
	}

	c.mu.Lock()
	var live []candidate
	for id, inc := range c.incidents {
		switch inc.Status {
		case StatusOpen, StatusContained:
			live = append(live, candidate{id: id, updated: inc.UpdatedAt})
		}
	}
	excess := len(live) - maxActive
	if excess <= 0 {
		c.mu.Unlock()
		return 0, false
	}

	// Oldest first, so the stalest incidents are the ones sacrificed.
	sort.Slice(live, func(i, j int) bool { return live[i].updated.Before(live[j].updated) })

	var queue []queuedPersist
	for _, cand := range live {
		if excess <= 0 {
			break
		}
		if limit > 0 && closed >= limit {
			more = true
			break
		}
		inc := c.incidents[cand.id]
		if inc == nil {
			continue
		}
		if inc.Status != StatusOpen && inc.Status != StatusContained {
			continue
		}
		c.closeIncidentLocked(inc, cand.id, now, "auto:active_cap", "active incident cap "+strconv.Itoa(maxActive))
		if req, queued := c.queuePersistLocked(*inc); queued {
			queue = append(queue, req)
		}
		closed++
		excess--
	}
	c.mu.Unlock()

	for i := range queue {
		c.runQueuedPersist(queue[i])
	}
	return closed, more
}
// validStatus reports whether s is one of the four defined statuses
// (Open, Contained, Resolved, Dismissed). SetStatus uses it to keep
// arbitrary strings out of the persisted timeline: the control-socket
// and webui handlers reject bad input early as well, but the correlator
// owns the type and does not trust its callers.
func validStatus(s Status) bool {
	return s == StatusOpen ||
		s == StatusContained ||
		s == StatusResolved ||
		s == StatusDismissed
}
// incidentStatusActive reports whether s is an active status, i.e.
// Open or Contained.
func incidentStatusActive(s Status) bool {
	switch s {
	case StatusOpen, StatusContained:
		return true
	default:
		return false
	}
}
// IncrementCompactedTotal adds n to the compaction counter. The
// daemon-side retention scheduler calls it after store.CompactIncidents
// removes records. A negative n is ignored so that a buggy caller
// cannot wrap the monotonic counter.
func (c *Correlator) IncrementCompactedTotal(n int) {
	if n >= 0 {
		c.counters.compactedTotal.Add(uint64(n))
	}
}
// PruneClosedOlderThan evicts Resolved/Dismissed incidents whose
// UpdatedAt is more than retention before now from the in-memory map,
// along with any key or spray-detector bindings that still point at
// them. Store compaction removes the durable records; this keeps
// API/control snapshots from serving those incidents until the next
// daemon restart. Returns the number of incidents evicted.
func (c *Correlator) PruneClosedOlderThan(now time.Time, retention time.Duration) int {
	c.mu.Lock()
	defer c.mu.Unlock()

	oldestKept := now.Add(-retention)
	evicted := 0
	var sprayIDs []string
	for id, inc := range c.incidents {
		isClosed := inc.Status == StatusResolved || inc.Status == StatusDismissed
		if !isClosed || !inc.UpdatedAt.Before(oldestKept) {
			continue
		}
		delete(c.incidents, id)
		c.unbindLocked(id)
		evicted++
		if c.spray != nil {
			sprayIDs = append(sprayIDs, id)
		}
	}
	// Spray bindings are released in one batch after the sweep.
	if c.spray != nil {
		c.spray.UnbindIncidents(sprayIDs)
	}
	return evicted
}
// Restore loads incidents previously read from the store into the
// correlator. Every incident is registered by id, so Get returns it.
// Only Open and Contained incidents are additionally bound in the key
// index, so a finding arriving inside the merge window joins them;
// Resolved/Dismissed incidents do not claim their key and future
// findings for it start a fresh incident.
func (c *Correlator) Restore(incidents []Incident) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := range incidents {
		// Each entry gets its own heap copy of the struct.
		restored := new(Incident)
		*restored = incidents[i]
		c.incidents[restored.ID] = restored

		switch restored.Status {
		case StatusOpen, StatusContained:
		default:
			continue
		}
		c.bindLocked(restored)

		// credential_spray super-incidents are persisted, but the spray
		// detector's per-IP state lives only in memory. Re-seeding it
		// here keeps a daemon restart in the middle of a spray from
		// re-tripping the detector and opening a duplicate
		// super-incident while the original is still active.
		if c.spray == nil || restored.Kind != KindCredentialSpray || restored.CorrelationKey == nil {
			continue
		}
		if ip := restored.CorrelationKey.RemoteIP; ip != "" {
			c.spray.Rehydrate(ip, restored.ID, restored.UpdatedAt)
		}
	}
}
// bindLocked points the incident's correlation key at its id in the
// byKey index. Incidents for which incidentKey yields no usable key are
// left unbound. The caller must hold c.mu.
func (c *Correlator) bindLocked(inc *Incident) {
	if key, usable := incidentKey(*inc); usable {
		c.byKey[keyString(key)] = inc.ID
	}
}
// unbindLocked removes every byKey entry that points at id. The caller
// must hold c.mu.
func (c *Correlator) unbindLocked(id string) {
	// The index is scanned by value instead of recomputing the key:
	// incidents may be keyed by account, domain, mailbox, process,
	// remote IP, or a combination of those. byKey only holds active
	// incidents, which bounds the scan.
	for key, boundID := range c.byKey {
		if boundID != id {
			continue
		}
		delete(c.byKey, key)
	}
}
// incidentKey derives the canonical correlation key for inc. A stored,
// non-empty CorrelationKey wins; otherwise the key is rebuilt from the
// incident's Account, Domain and Mailbox fields. The boolean is false
// when neither source yields a non-empty key.
func incidentKey(inc Incident) (Key, bool) {
	if stored := inc.CorrelationKey; stored != nil && !stored.IsEmpty() {
		return canonicalizeKey(*stored), true
	}

	rebuilt := canonicalizeKey(Key{
		Account: inc.Account,
		Domain:  inc.Domain,
		Mailbox: inc.Mailbox,
	})
	if rebuilt.IsEmpty() {
		return Key{}, false
	}
	return rebuilt, true
}
// cloneIncident returns a copy of in that shares no slice backing
// arrays and no CorrelationKey pointer with the original. Scalar fields
// are copied by plain struct assignment. Because each slice is rebuilt
// by appending onto a nil slice, an empty source slice comes out nil.
func cloneIncident(in Incident) Incident {
	dup := in
	dup.Findings = append([]string(nil), in.Findings...)
	dup.Timeline = append([]IncidentEvent(nil), in.Timeline...)
	dup.Actions = append([]IncidentAction(nil), in.Actions...)
	if src := in.CorrelationKey; src != nil {
		keyCopy := *src
		dup.CorrelationKey = &keyCopy
	}
	return dup
}
// cloneKey returns a pointer to a fresh copy of k, or nil when k is
// empty.
func cloneKey(k Key) *Key {
	if !k.IsEmpty() {
		dup := k
		return &dup
	}
	return nil
}
// keyString serializes a Key into a stable string for the byKey map.
// All fields selected by KeyFor must be encoded so distinct findings
// (e.g. different PID-only processes or different remote IPs) do not
// collapse to the same bucket and falsely merge.
//
// Layout: seven "|"-separated fields in the order Host, Account,
// Mailbox, Domain, UID, PID, RemoteIP. Each text field is written as
// "<byte length>:<value>"; UID and PID are written as plain decimals.
// The length prefix keeps a value that itself contains "|" or ":" from
// being mistaken for a field boundary.
func keyString(k Key) string {
	return fmt.Sprintf("%d:%s|%d:%s|%d:%s|%d:%s|%d|%d|%d:%s",
		len(k.Host), k.Host,
		len(k.Account), k.Account,
		len(k.Mailbox), k.Mailbox,
		len(k.Domain), k.Domain,
		k.UID,
		k.PID,
		len(k.RemoteIP), k.RemoteIP,
	)
}
// opensIncidentImmediately reports whether f bypasses the open-threshold
// gate: Critical (or higher) findings always do, and High findings do
// when they classify as a host-integrity risk. Anything below High
// never does.
func opensIncidentImmediately(f alert.Finding) bool {
	switch {
	case f.Severity >= alert.Critical:
		return true
	case f.Severity >= alert.High:
		return ClassifyKind(f) == KindHostIntegrityRisk
	default:
		return false
	}
}
// newIncidentID returns a fresh incident identifier: the prefix "inc_"
// followed by 6 random bytes rendered as 12 lowercase hex digits.
func newIncidentID() string {
	raw := make([]byte, 6)
	// The read error is deliberately discarded; on failure the bytes
	// would stay zero and the id would still be well-formed.
	_, _ = rand.Read(raw)
	return "inc_" + hex.EncodeToString(raw)
}
// maybeBlockSprayLocked decides whether a credential_spray incident is
// owed a firewall block. It is evaluated on every spray decision (open,
// merge, escalate) rather than only on severity transitions, so an
// operator who arms BlockAtSeverity after an incident has already
// reached that severity still gets a block on the next matching
// finding. Repeated calls are harmless: triggerSprayBlockLocked's
// action-presence and in-flight guards keep at most one live firewall
// call outstanding per incident.
//
// The caller must hold c.mu. The returned callback must be invoked
// only after c.mu has been released; nil means no block is owed.
func (c *Correlator) maybeBlockSprayLocked(inc *Incident, ip string, hits int, now time.Time, reason string) func() {
	if inc == nil || c.spray == nil || c.cfg.OnSprayBlock == nil {
		return nil
	}
	if !incidentStatusActive(inc.Status) || !c.sprayBlockAllowed() {
		return nil
	}

	// Translate the configured gate into a minimum severity; any other
	// value (including empty) leaves blocking disarmed.
	var minSeverity alert.Severity
	switch strings.ToLower(c.spray.cfg.BlockAtSeverity) {
	case "high":
		minSeverity = alert.High
	case "critical":
		minSeverity = alert.Critical
	default:
		return nil
	}
	if inc.Severity < minSeverity {
		return nil
	}
	return c.triggerSprayBlockLocked(inc, ip, hits, now, reason)
}
// sprayBlockAllowed consults the optional CanSprayBlock hook. With no
// hook configured, spray blocks are allowed.
func (c *Correlator) sprayBlockAllowed() bool {
	gate := c.cfg.CanSprayBlock
	if gate == nil {
		return true
	}
	return gate()
}
// maybeBlockIncidentLocked decides whether a non-spray incident is owed
// a firewall block through the generic auto-block path. It runs on
// every create and merge, so an operator who enables AutoBlock after an
// incident has already crossed the gate still gets a block on the next
// finding. Repeated calls are harmless thanks to the action-presence
// and in-flight guards in triggerIncidentBlockLocked.
//
// credential_spray incidents, and incidents the spray detector owns,
// are skipped: the dedicated spray hand-off handles those, which avoids
// double-firing.
//
// The caller must hold c.mu. The returned callback must be invoked
// only after c.mu has been released; nil means no block is owed.
func (c *Correlator) maybeBlockIncidentLocked(inc *Incident, now time.Time, why string) func() {
	if inc == nil || c.cfg.OnIncidentBlock == nil || !c.cfg.AutoBlock.Enabled {
		return nil
	}
	if !incidentStatusActive(inc.Status) {
		return nil
	}
	if inc.Kind == KindCredentialSpray || c.sprayOwnsIncident(inc) {
		return nil
	}
	if !c.incidentBlockAllowed() {
		return nil
	}

	target := incidentBlockCandidate(inc)
	if target == "" {
		return nil
	}

	// A non-empty Kinds set acts as an allow-list of incident kinds.
	if kinds := c.cfg.AutoBlock.Kinds; len(kinds) > 0 && !kinds[inc.Kind] {
		return nil
	}

	// Translate the configured gate into a minimum severity; any other
	// value (including empty) leaves blocking disarmed.
	var minSeverity alert.Severity
	switch strings.ToLower(c.cfg.AutoBlock.BlockAtSeverity) {
	case "high":
		minSeverity = alert.High
	case "critical":
		minSeverity = alert.Critical
	default:
		return nil
	}
	if inc.Severity < minSeverity {
		return nil
	}
	return c.triggerIncidentBlockLocked(inc, target, now, why)
}
// incidentBlockAllowed consults the optional CanIncidentBlock hook. With
// no hook configured the generic auto-block path is always allowed.
func (c *Correlator) incidentBlockAllowed() bool {
	if gate := c.cfg.CanIncidentBlock; gate != nil {
		return gate()
	}
	return true
}
// triggerIncidentBlockLocked is the per-incident emit point for the
// generic auto-block path. It marks the incident as in-flight, returns
// the deferred callback, then appends "incident_block_requested" only
// when the callback reports a live block request. Dry-run attempts are
// intentionally not latched so a later finding can retry after the
// operator disables dry-run.
//
// Returns nil when inc is nil, no OnIncidentBlock hook is configured, the
// incident already carries an "incident_block_requested" action, or a
// block for the same incident ID is still in flight. The returned
// callback acquires c.mu itself, so it must not be invoked while the
// caller still holds that lock.
func (c *Correlator) triggerIncidentBlockLocked(inc *Incident, ip string, now time.Time, why string) func() {
if inc == nil || c.cfg.OnIncidentBlock == nil {
return nil
}
// Already recorded as requested: nothing more to do for this incident.
if hasIncidentAction(inc.Actions, "incident_block_requested") {
return nil
}
// A previously returned callback has not finished yet.
if _, ok := c.pendingIncidentBlocks[inc.ID]; ok {
return nil
}
c.pendingIncidentBlocks[inc.ID] = struct{}{}
// Capture everything the callback needs by value so it does not touch
// inc after the lock is released.
incidentID := inc.ID
reason := "incident " + string(inc.Kind) + " " + inc.Severity.String() + " (" + why + ")"
onBlock := c.cfg.OnIncidentBlock
return func() {
var live bool
// callbackReturned stays false when onBlock panics, which keeps the
// deferred cleanup from recording an action for a call that never
// completed.
callbackReturned := false
// The in-flight slot must clear even if onBlock panics. Otherwise,
// later findings keep seeing the incident as already in flight and
// skip the auto-block path for the rest of the incident lifetime.
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.pendingIncidentBlocks, incidentID)
if !callbackReturned || !live {
return
}
// Re-resolve the incident by ID: it may have been removed, or the
// action may have been recorded, while the lock was not held.
current, ok := c.incidents[incidentID]
if !ok || hasIncidentAction(current.Actions, "incident_block_requested") {
return
}
current.Actions = append(current.Actions, IncidentAction{
Time: now,
Action: "incident_block_requested",
Result: "ok",
Details: ip + " " + reason,
})
c.persistLocked(*current)
}()
live = onBlock(ip, reason)
callbackReturned = true
}
}
// incidentBlockCandidate picks the single remote IP that an auto-block
// for inc should target, or "" when there is no unambiguous choice.
//
// The correlation key's RemoteIP wins when it normalises to a usable
// address. Otherwise the timeline is scanned: every event that carries a
// usable IP must agree on the same address. A truncation marker in the
// timeline means the evidence is incomplete, so no candidate is offered.
func incidentBlockCandidate(inc *Incident) string {
	if inc == nil {
		return ""
	}
	if key := inc.CorrelationKey; key != nil {
		if keyed := normalizeIncidentRemoteIP(key.RemoteIP); keyed != "" {
			return keyed
		}
	}

	sole := ""
	for _, ev := range inc.Timeline {
		if ev.Kind == incidentTimelineTruncatedKind {
			return ""
		}
		seen := normalizeIncidentRemoteIP(ev.RemoteIP)
		switch {
		case seen == "":
			// Event carries no usable address; it neither helps nor hurts.
		case sole == "":
			sole = seen
		case sole != seen:
			// Two different sources: refuse to guess which one to block.
			return ""
		}
	}
	return sole
}
// sprayOwnsIncident reports whether the spray hand-off is responsible for
// inc: at least one timeline event carries a RemoteIP, and every such
// event comes from a check enabled in the spray config's PerCheck set.
// Events without a RemoteIP are ignored. A nil receiver, a missing spray
// detector, or a nil incident all yield false.
func (c *Correlator) sprayOwnsIncident(inc *Incident) bool {
	if c == nil || c.spray == nil || inc == nil {
		return false
	}
	tracked := c.spray.cfg.PerCheck
	sawRemote := false
	for _, ev := range inc.Timeline {
		switch {
		case ev.RemoteIP == "":
			// No source address: irrelevant to ownership.
		case tracked[ev.Check]:
			sawRemote = true
		default:
			return false
		}
	}
	return sawRemote
}
// normalizeIncidentRemoteIP reduces a raw remote-address string to the
// canonical textual form of its IP, or "" when it is not usable.
//
// Surrounding whitespace is ignored, a trailing ":port" is dropped when
// the input parses as host:port, and enclosing square brackets are
// removed. Inputs that do not parse as an IP, and loopback or
// unspecified addresses, yield "".
func normalizeIncidentRemoteIP(raw string) string {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}
	host, _, err := net.SplitHostPort(candidate)
	if err == nil {
		candidate = host
	}
	parsed := net.ParseIP(strings.Trim(candidate, "[]"))
	switch {
	case parsed == nil, parsed.IsLoopback(), parsed.IsUnspecified():
		return ""
	}
	return parsed.String()
}
// hasIncidentAction reports whether any entry in actions has its Action
// field equal to action.
func hasIncidentAction(actions []IncidentAction, action string) bool {
	for i := range actions {
		if actions[i].Action == action {
			return true
		}
	}
	return false
}
// triggerSprayBlockLocked prepares the hand-off to the operator-supplied
// OnSprayBlock callback. At most one hand-off per incident is in flight
// at a time, and the "credential_spray_block_requested" audit action is
// appended only when the callback returns true. Caller holds c.mu. The
// returned callback must be invoked after unlocking because it acquires
// c.mu itself. Idempotent after success; declined callbacks can retry on
// a later finding after the in-flight marker is cleared.
//
// Returns nil when inc is nil, no OnSprayBlock hook is configured, the
// incident already carries the audit action, or a hand-off for the same
// incident ID is still in flight.
func (c *Correlator) triggerSprayBlockLocked(inc *Incident, ip string, hits int, now time.Time, why string) func() {
if inc == nil || c.cfg.OnSprayBlock == nil {
return nil
}
// Already recorded as requested: nothing more to do for this incident.
if hasIncidentAction(inc.Actions, "credential_spray_block_requested") {
return nil
}
// A previously returned callback has not finished yet.
if _, ok := c.pendingSprayBlocks[inc.ID]; ok {
return nil
}
c.pendingSprayBlocks[inc.ID] = struct{}{}
reason := "credential_spray: " + strconv.Itoa(hits) + " distinct mailboxes (" + why + ")"
// Capture the hook and the ID by value so the callback does not touch
// inc or c.cfg after the lock is released.
onSprayBlock := c.cfg.OnSprayBlock
incidentID := inc.ID
return func() {
var live bool
// callbackReturned stays false when onSprayBlock panics, which keeps
// the deferred cleanup from recording an action for a call that never
// completed.
callbackReturned := false
// Mirror the panic-safety guarantee from triggerIncidentBlockLocked:
// the in-flight slot must clear even if onSprayBlock panics so a
// recurring panic class does not latch the credential_spray
// incident out of the auto-block path forever.
defer func() {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.pendingSprayBlocks, incidentID)
if !callbackReturned || !live {
return
}
// Re-resolve the incident by ID: it may have been removed, or the
// action may have been recorded, while the lock was not held.
current, ok := c.incidents[incidentID]
if !ok || hasIncidentAction(current.Actions, "credential_spray_block_requested") {
return
}
current.Actions = append(current.Actions, IncidentAction{
Time: now,
Action: "credential_spray_block_requested",
Result: "ok",
Details: ip + " " + reason,
})
c.persistLocked(*current)
}()
live = onSprayBlock(ip, reason)
callbackReturned = true
}
}
package incident
import (
"sort"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// IncidentGroupsScanCap is the hard upper bound on matching incidents
// grouped per BuildGroups call. Rows excluded by status/kind filters do
// not consume the cap. When a further matching incident is encountered
// after the cap has been reached, BuildGroups stops scanning and reports
// Truncated=true.
const IncidentGroupsScanCap = 10000
// Group is one row of the grouped incident view: a (kind, source)
// bucket plus rolled-up counters. Source identifies what the bucket is
// keyed on -- IP for credential-spray patterns, account/domain/mailbox
// when the source IP is unknown.
type Group struct {
// Key is the bucket identifier BuildGroups builds as
// "<kind>|<source_kind>:<source>".
Key string `json:"key"`
Kind Kind `json:"kind"`
// SourceKind and Source are the pair returned by groupSource for the
// first incident that created the bucket.
SourceKind string `json:"source_kind"`
Source string `json:"source"`
// IncidentCount is the number of scanned incidents in the bucket; the
// four *Count fields below break it down by status.
IncidentCount int `json:"incident_count"`
OpenCount int `json:"open_count"`
ContainedCount int `json:"contained_count"`
ResolvedCount int `json:"resolved_count"`
DismissedCount int `json:"dismissed_count"`
// SeverityMax is the highest member severity. It is excluded from
// JSON; SeverityLabel carries its String() form under "severity_max".
SeverityMax alert.Severity `json:"-"`
SeverityLabel string `json:"severity_max"`
// FirstSeen is the earliest member CreatedAt and LastSeen the latest
// member UpdatedAt.
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
// SampleIDs holds the IDs of up to three most recently updated members.
SampleIDs []string `json:"sample_ids"`
}
// GroupsResponse is the wire shape for /api/v1/incidents/groups.
type GroupsResponse struct {
// Groups is the requested page of groups (after Offset / MaxGroups).
Groups []Group `json:"groups"`
// TotalGroups is the number of groups before paging was applied.
TotalGroups int `json:"total_groups"`
// ScannedIncidents is the number of incidents that passed the filters
// and were bucketed.
ScannedIncidents int `json:"scanned_incidents"`
// Truncated reports that the scan stopped at IncidentGroupsScanCap with
// matching incidents still unprocessed.
Truncated bool `json:"truncated"`
}
// GroupFilter narrows what BuildGroups buckets and selects which page of
// the sorted result is returned. Empty fields mean "no filter on that
// dimension".
type GroupFilter struct {
// StatusSet, when non-empty, restricts incidents to the listed
// statuses. Pass {StatusOpen, StatusContained} to surface only
// active incidents (the default web UI mode).
StatusSet []Status
// Kind, when non-empty, restricts incidents to a specific kind.
Kind Kind
// Offset is the starting index into the sorted group list. Used by
// the UI to page through groups when TotalGroups exceeds MaxGroups.
// Zero or negative values start at the first group; an offset at or
// past the end yields an empty page.
Offset int
// MaxGroups caps the returned slice. Zero or negative means "no
// cap"; the handler still applies a sane upper bound.
MaxGroups int
}
// BuildGroups buckets the supplied incidents by (kind, source) and
// returns the rolled-up groups sorted by incident_count desc, then
// severity_max desc, then last_seen desc, then bucket key asc. The key
// tie-break makes the order total: groups are collected from a map, whose
// iteration order is unspecified, so without it groups that tie on the
// first three criteria could swap places between calls and paging with
// Offset / MaxGroups could skip or repeat rows. SampleIDs holds up to
// three of the most recently updated members of each group so the UI can
// drill in without a follow-up call.
//
// `incidents` may be the full correlator snapshot. The function caps
// its scan at IncidentGroupsScanCap after status/kind filtering; the
// returned Truncated flag reports whether the cap clipped matching
// incidents. Callers fed by Correlator.Snapshot() get an already-newest-
// first slice, so the truncation drops the oldest matching entries
// first, which is what an operator wants.
//
// TotalGroups in the response counts groups before Offset / MaxGroups
// are applied; ScannedIncidents counts the incidents that were bucketed.
func BuildGroups(incidents []Incident, filter GroupFilter) GroupsResponse {
	// Default: every status passes. A non-empty StatusSet swaps in a set
	// membership test.
	statusAllowed := func(Status) bool { return true }
	if len(filter.StatusSet) > 0 {
		set := make(map[Status]struct{}, len(filter.StatusSet))
		for _, s := range filter.StatusSet {
			set[s] = struct{}{}
		}
		statusAllowed = func(s Status) bool {
			_, ok := set[s]
			return ok
		}
	}

	type sampleEntry struct {
		id        string
		updatedAt time.Time
	}
	type aggregator struct {
		group     Group
		samples   []sampleEntry
		statusMap map[Status]int
	}
	bucketKey := func(kind Kind, sourceKind, source string) string {
		return string(kind) + "|" + sourceKind + ":" + source
	}

	buckets := make(map[string]*aggregator)
	scanned := 0
	truncated := false
	for _, inc := range incidents {
		if !statusAllowed(inc.Status) {
			continue
		}
		if filter.Kind != "" && inc.Kind != filter.Kind {
			continue
		}
		// Only incidents that pass the filters count against the cap.
		if scanned >= IncidentGroupsScanCap {
			truncated = true
			break
		}
		scanned++

		sourceKind, source := groupSource(inc)
		k := bucketKey(inc.Kind, sourceKind, source)
		agg, ok := buckets[k]
		if !ok {
			agg = &aggregator{
				group: Group{
					Key:        k,
					Kind:       inc.Kind,
					SourceKind: sourceKind,
					Source:     source,
					FirstSeen:  inc.CreatedAt,
					LastSeen:   inc.UpdatedAt,
				},
				statusMap: map[Status]int{},
			}
			buckets[k] = agg
		}
		agg.group.IncidentCount++
		agg.statusMap[inc.Status]++
		if inc.Severity > agg.group.SeverityMax {
			agg.group.SeverityMax = inc.Severity
		}
		if inc.CreatedAt.Before(agg.group.FirstSeen) || agg.group.FirstSeen.IsZero() {
			agg.group.FirstSeen = inc.CreatedAt
		}
		if inc.UpdatedAt.After(agg.group.LastSeen) {
			agg.group.LastSeen = inc.UpdatedAt
		}
		agg.samples = append(agg.samples, sampleEntry{id: inc.ID, updatedAt: inc.UpdatedAt})
	}

	out := make([]Group, 0, len(buckets))
	for _, agg := range buckets {
		agg.group.OpenCount = agg.statusMap[StatusOpen]
		agg.group.ContainedCount = agg.statusMap[StatusContained]
		agg.group.ResolvedCount = agg.statusMap[StatusResolved]
		agg.group.DismissedCount = agg.statusMap[StatusDismissed]
		agg.group.SeverityLabel = agg.group.SeverityMax.String()

		// Top-3 most recently updated members. The stable sort keeps input
		// order among members with identical UpdatedAt.
		sort.SliceStable(agg.samples, func(i, j int) bool {
			return agg.samples[i].updatedAt.After(agg.samples[j].updatedAt)
		})
		n := len(agg.samples)
		if n > 3 {
			n = 3
		}
		ids := make([]string, n)
		for i := 0; i < n; i++ {
			ids[i] = agg.samples[i].id
		}
		agg.group.SampleIDs = ids
		out = append(out, agg.group)
	}

	sort.SliceStable(out, func(i, j int) bool {
		a, b := &out[i], &out[j]
		if a.IncidentCount != b.IncidentCount {
			return a.IncidentCount > b.IncidentCount
		}
		if a.SeverityMax != b.SeverityMax {
			return a.SeverityMax > b.SeverityMax
		}
		if !a.LastSeen.Equal(b.LastSeen) {
			return a.LastSeen.After(b.LastSeen)
		}
		// Keys are unique per bucket, so this makes the ordering
		// independent of map iteration order.
		return a.Key < b.Key
	})

	totalGroups := len(out)
	if filter.Offset > 0 {
		if filter.Offset >= len(out) {
			out = out[:0]
		} else {
			out = out[filter.Offset:]
		}
	}
	if filter.MaxGroups > 0 && len(out) > filter.MaxGroups {
		out = out[:filter.MaxGroups]
	}
	return GroupsResponse{
		Groups:           out,
		TotalGroups:      totalGroups,
		ScannedIncidents: scanned,
		Truncated:        truncated,
	}
}
// groupSource derives the (source_kind, source) pair the UI uses to
// label and drill into a group. Cascade order: correlation-key host >
// correlation-key remote IP > most frequent timeline remote IP >
// account > domain > mailbox > "_unkeyed" (with an empty source).
func groupSource(inc Incident) (sourceKind, source string) {
	if key := inc.CorrelationKey; key != nil {
		switch {
		case key.Host != "":
			return "host", key.Host
		case key.RemoteIP != "":
			return "ip", key.RemoteIP
		}
	}
	if seen := timelineRemoteIP(inc); seen != "" {
		return "ip", seen
	}
	switch {
	case inc.Account != "":
		return "account", inc.Account
	case inc.Domain != "":
		return "domain", inc.Domain
	case inc.Mailbox != "":
		return "mailbox", inc.Mailbox
	}
	return "_unkeyed", ""
}
// timelineRemoteIP returns the RemoteIP that appears most often across
// the incident's timeline events, or "" when no event carries one. Ties
// go to the lexicographically smallest address so the result does not
// depend on map iteration order.
func timelineRemoteIP(inc Incident) string {
	tally := make(map[string]int)
	for _, ev := range inc.Timeline {
		if ev.RemoteIP == "" {
			continue
		}
		tally[ev.RemoteIP]++
	}

	winner, top := "", 0
	for addr, n := range tally {
		ahead := n > top
		tiedButSmaller := n == top && (winner == "" || addr < winner)
		if ahead || tiedButSmaller {
			winner, top = addr, n
		}
	}
	return winner
}
package incident
import (
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// Key is the correlation key derived from a Finding. Empty fields mean
// "not provided"; the correlator uses the most specific non-empty fields.
// Host is the synthetic local-host actor for findings whose blast radius
// is the machine itself rather than one tenant, process, or remote IP.
type Key struct {
// Host is set by KeyFor to the literal "host" for host-integrity
// findings; such keys carry no other field.
Host string `json:"host,omitempty"`
Account string `json:"account,omitempty"`
// Domain and Mailbox are stored in the canonical form produced by
// canonicalizeMailboxDomain.
Domain string `json:"domain,omitempty"`
Mailbox string `json:"mailbox,omitempty"`
// UID and PID are fallback identities: KeyFor fills them only when no
// account, domain, or mailbox is known, and PID only when UID is zero.
UID int `json:"uid,omitempty"`
PID int `json:"pid,omitempty"`
// RemoteIP is the last-resort identity: KeyFor fills it only when every
// field above it (other than Host) is empty.
RemoteIP string `json:"remote_ip,omitempty"`
}
// IsEmpty reports whether the key has nothing to correlate on, i.e.
// every field holds its zero value. Such findings are emitted normally
// but do not join an incident.
func (k Key) IsEmpty() bool {
	switch {
	case k.Host != "", k.Account != "", k.Domain != "", k.Mailbox != "":
		return false
	case k.UID != 0, k.PID != 0:
		return false
	case k.RemoteIP != "":
		return false
	}
	return true
}
// KeyFor extracts a correlation key from a Finding.
//
// Findings that classify as host-integrity risks are keyed to the
// synthetic local host ("host") so unattributed root/system events still
// become incidents. For everything else:
//
//   - Account comes from the first non-empty of TenantID,
//     Process.Account, CPUser, and the /home[N]/<account>/ path
//     heuristic applied to FilePath.
//   - Mailbox and Domain are canonicalised so emitters that use the full
//     local@domain form and emitters that use the split form
//     (Mailbox=local, Domain=site) land on the same key instead of
//     splitting one mailbox across two incidents.
//   - Process UID (or PID when the UID is zero) is used only when no
//     account, domain, or mailbox is known.
//   - SourceIP is used only when none of the above produced an identity,
//     so it never splits account/domain/mailbox incidents.
func KeyFor(f alert.Finding) Key {
	if ClassifyKind(f) == KindHostIntegrityRisk {
		return Key{Host: "host"}
	}

	// Account cascade: each step only applies while nothing earlier
	// supplied a value.
	account := f.TenantID
	if account == "" && f.Process != nil {
		account = f.Process.Account
	}
	if account == "" {
		account = f.CPUser
	}
	if account == "" {
		account = accountFromHomePath(f.FilePath)
	}

	var key Key
	key.Account = account
	key.Mailbox, key.Domain = canonicalizeMailboxDomain(f.Mailbox, f.Domain)

	if proc := f.Process; proc != nil && !hasStableActor(key) {
		if proc.UID != 0 {
			key.UID = proc.UID
		} else {
			key.PID = proc.PID
		}
	}

	if !hasStableActor(key) && key.UID == 0 && key.PID == 0 {
		key.RemoteIP = f.SourceIP
	}
	return key
}
// hasStableActor reports whether the key already names an account,
// domain, or mailbox, in which case process and IP identities are not
// needed.
func hasStableActor(k Key) bool {
	for _, identity := range [...]string{k.Account, k.Domain, k.Mailbox} {
		if identity != "" {
			return true
		}
	}
	return false
}
// canonicalizeMailboxDomain folds a (Mailbox, Domain) pair into one
// stable key pair, whichever emit convention the caller used:
//
//   - No mailbox: the (trimmed, lower-cased) domain alone is the key.
//   - Mailbox of the form local@domain: it is authoritative. The result
//     is local + "@" + lower-cased domain, and the separate Domain is
//     dropped so a conflicting site cannot double-key the incident.
//   - Mailbox containing "@" but not splittable (empty local part or
//     empty domain part): returned as-is, Domain dropped.
//   - Bare local part with a Domain: the two are spliced into
//     local@domain and Domain is dropped, since it is now encoded in
//     the mailbox.
//   - Bare local part without a Domain: returned as-is.
//
// Domain names are case-insensitive, so only the domain component is
// lower-cased; the local part is left intact apart from trimming the
// surrounding whitespace of the whole mailbox.
func canonicalizeMailboxDomain(mailbox, domain string) (string, string) {
	box := strings.TrimSpace(mailbox)
	site := normalizeDomainForKey(domain)
	if box == "" {
		return "", site
	}

	local, embedded := splitMailboxForKey(box)
	switch {
	case embedded != "":
		return local + "@" + embedded, ""
	case strings.Contains(box, "@") || site == "":
		return box, ""
	}
	return box + "@" + site, ""
}
// canonicalizeKey returns a copy of k whose Mailbox and Domain have been
// passed through canonicalizeMailboxDomain; all other fields are kept.
func canonicalizeKey(k Key) Key {
	mailbox, domain := canonicalizeMailboxDomain(k.Mailbox, k.Domain)
	k.Mailbox = mailbox
	k.Domain = domain
	return k
}
// displayMailboxDomain produces the (mailbox, domain) pair for display.
// Unlike canonicalizeMailboxDomain it keeps a domain in the result:
//
//   - No mailbox: ("", trimmed domain), case preserved.
//   - Mailbox of the form local@domain: the lower-cased embedded domain
//     is used for both the mailbox and the returned domain.
//   - No domain supplied: the trimmed mailbox alone.
//   - Mailbox containing "@" that cannot be split: mailbox and trimmed
//     domain are returned unchanged.
//   - Bare local part plus a domain: spliced into local@domain using the
//     lower-cased domain, which is also returned.
func displayMailboxDomain(mailbox, domain string) (string, string) {
	box := strings.TrimSpace(mailbox)
	site := strings.TrimSpace(domain)
	if box == "" {
		return "", site
	}

	local, embedded := splitMailboxForKey(box)
	switch {
	case embedded != "":
		return local + "@" + embedded, embedded
	case site == "":
		return box, ""
	case strings.Contains(box, "@"):
		return box, site
	}
	site = normalizeDomainForKey(site)
	return box + "@" + site, site
}
// splitMailboxForKey splits mailbox at its last "@" into the local part
// and the normalised (trimmed, lower-cased) domain part. It returns two
// empty strings when there is no "@", when nothing precedes it, or when
// nothing follows it.
func splitMailboxForKey(mailbox string) (string, string) {
	sep := strings.LastIndexByte(mailbox, '@')
	hasLocal := sep > 0
	hasDomain := sep != len(mailbox)-1
	if !hasLocal || !hasDomain {
		return "", ""
	}
	local, rest := mailbox[:sep], mailbox[sep+1:]
	return local, normalizeDomainForKey(rest)
}
// normalizeDomainForKey trims surrounding whitespace from domain and
// lower-cases it so equal domains compare equal as key components.
func normalizeDomainForKey(domain string) string {
	trimmed := strings.TrimSpace(domain)
	return strings.ToLower(trimmed)
}
// accountFromHomePath extracts <account> from an absolute path of the
// form /home[N]/<account>[/...], where [N] is an optional run of ASCII
// digits. It returns "" for every other shape, including relative paths
// and paths with no segment after the home directory. Pure string
// parsing; the filesystem is never consulted.
func accountFromHomePath(p string) string {
	// An absolute path splits into "", <root>, <account>, <rest>.
	segs := strings.SplitN(p, "/", 4)
	if len(segs) < 3 || segs[0] != "" {
		return ""
	}
	root := segs[1]
	if !strings.HasPrefix(root, "home") {
		return ""
	}
	// Whatever follows "home" must consist of digits only (home, home2, ...).
	suffix := strings.TrimPrefix(root, "home")
	for i := 0; i < len(suffix); i++ {
		if suffix[i] < '0' || suffix[i] > '9' {
			return ""
		}
	}
	return segs[2]
}
package incident
import (
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// ClassifyKind maps a Finding to its incident Kind. Rules are applied in
// order and the first match wins:
//
//  1. Mailbox takeover: the finding names a Mailbox, or its check is a
//     mailbox-auth check. Mailbox attribution is the strongest single
//     signal, so it beats the check-name rules below. Some
//     authenticated-mail findings route bare cPanel-local accounts to
//     TenantID instead of Mailbox; the check-name test keeps those here
//     without sweeping in domain/config mail checks or PHP relay
//     findings.
//  2. Host integrity: the check is in the explicit host-scope list, so
//     account-attributed checks that merely share a substring stay in
//     the tenant bucket.
//  3. Post-exploit process: the process executable lives under /tmp/,
//     /var/tmp/, or /dev/shm/, which indicates a staged-then-executed
//     payload regardless of which tenant owns the parent.
//  4. Default: account-scoped web compromise, the modal incident shape.
func ClassifyKind(f alert.Finding) Kind {
	name := strings.ToLower(f.Check)
	switch {
	case f.Mailbox != "", isMailboxTakeoverCheck(name):
		return KindMailboxTakeover
	case isHostIntegrityCheck(name):
		return KindHostIntegrityRisk
	}

	if proc := f.Process; proc != nil {
		for _, dir := range [...]string{"/tmp/", "/var/tmp/", "/dev/shm/"} {
			if strings.HasPrefix(proc.Exe, dir) {
				return KindPostExploitProcess
			}
		}
	}
	return KindWebAccountCompromise
}
// hostIntegrityChecks lists check names whose scope is the host itself
// (kernel modules, system daemon configs, root-owned credential stores)
// rather than a single tenant. Findings matching one of these jump
// straight to KindHostIntegrityRisk so incident severity reflects the
// blast radius.
//
// Keys are lower-case; isHostIntegrityCheck trims and lower-cases the
// candidate before the lookup, so new entries must be lower-case too.
var hostIntegrityChecks = map[string]bool{
"bulk_password_change": true,
"sensitive_file_write": true,
"sensitive_file_modified": true,
"fake_kernel_thread": true,
"auditd_disabled": true,
"modsec_disabled": true,
"shadow_change": true,
"sshd_config_change": true,
"root_password_change": true,
"uid0_account": true,
"suid_binary": true,
"bad_asn_outbound": true,
"kernel_module": true,
"crontab_change": true,
"crond_change": true,
}
// isHostIntegrityCheck reports whether check, after trimming surrounding
// whitespace and lower-casing, is listed in hostIntegrityChecks.
func isHostIntegrityCheck(check string) bool {
	name := strings.TrimSpace(check)
	name = strings.ToLower(name)
	return hostIntegrityChecks[name]
}
// isMailboxTakeoverCheck reports whether check names a mailbox/SMTP
// authentication signal: any check prefixed "smtp_", "sasl_", or
// "mail_", or one of the explicitly listed "email_*" checks. Other
// "email_*" checks (for example PHP relay findings) are deliberately not
// matched.
//
// The name is trimmed and lower-cased here, matching what
// isHostIntegrityCheck does, so the result does not depend on the caller
// having normalised the check name first.
func isMailboxTakeoverCheck(check string) bool {
	check = strings.ToLower(strings.TrimSpace(check))
	if strings.HasPrefix(check, "smtp_") || strings.HasPrefix(check, "sasl_") ||
		strings.HasPrefix(check, "mail_") {
		return true
	}
	switch check {
	case "email_auth_failure_realtime",
		"email_cloud_relay_abuse",
		"email_compromised_account",
		"email_credential_leak",
		"email_rate_critical",
		"email_rate_warning",
		"email_spam_outbreak",
		"email_suspicious_geo":
		return true
	default:
		return false
	}
}
package incident
import "github.com/pidginhost/csm/internal/metrics"
// RegisterMetrics binds the correlator's counters and gauges to reg.
// Production callers should pass metrics.Default(); tests pass
// metrics.NewRegistry() to keep registration isolated.
//
// Every series reads its value lazily from c at collection time. The
// table below is registered top to bottom.
func RegisterMetrics(reg *metrics.Registry, c *Correlator) {
	type series struct {
		gauge bool // true: RegisterGaugeFunc, false: RegisterCounterFunc
		name  string
		help  string
		read  func() float64
	}

	table := []series{
		{
			gauge: true,
			name:  "csm_incidents_open",
			help:  "Open and Contained incidents currently in correlator state.",
			read:  func() float64 { return float64(c.OpenCount()) },
		},
		{
			name: "csm_incidents_created_total",
			help: "Total incidents created by the correlator.",
			read: func() float64 { return float64(c.counters.createdTotal.Load()) },
		},
		{
			name: "csm_incidents_severity_changed_total",
			help: "Incident severity escalations (severity does not downgrade, so this is monotonic).",
			read: func() float64 { return float64(c.counters.severityChangedTotal.Load()) },
		},
		{
			name: "csm_incidents_status_changed_total",
			help: "Incident status transitions (open/contained/resolved/dismissed).",
			read: func() float64 { return float64(c.counters.statusChangedTotal.Load()) },
		},
		{
			name: "csm_incidents_findings_merged_total",
			help: "Findings merged into an existing incident (not counted on incident create).",
			read: func() float64 { return float64(c.counters.findingsMergedTotal.Load()) },
		},
		{
			name: "csm_incidents_compacted_total",
			help: "Incidents pruned by retention compaction (resolved/dismissed beyond TTL).",
			read: func() float64 { return float64(c.counters.compactedTotal.Load()) },
		},
		{
			gauge: true,
			name:  "csm_incidents_pending",
			help:  "Findings held in the threshold gate, awaiting a second correlated finding before opening an incident.",
			read:  func() float64 { return float64(c.PendingCount()) },
		},
		{
			name: "csm_incidents_auto_closed_total",
			help: "Open or contained incidents auto-resolved after exceeding their per-kind idle threshold.",
			read: func() float64 { return float64(c.counters.autoClosedTotal.Load()) },
		},
		{
			name: "csm_incidents_auto_close_dry_run_total",
			help: "Auto-close decisions counted while dry_run was active (state unchanged).",
			read: func() float64 { return float64(c.counters.autoCloseDryRunTotal.Load()) },
		},
		{
			name: "csm_credential_spray_opened_total",
			help: "Credential-spray super-incidents opened (one source IP brute-forcing many mailboxes).",
			read: func() float64 { return float64(c.counters.sprayOpenedTotal.Load()) },
		},
		{
			name: "csm_credential_spray_suppressed_mailbox_takeover_total",
			help: "Per-mailbox incidents suppressed because a credential_spray incident already covers the source IP.",
			read: func() float64 { return float64(c.counters.spraySuppressedTotal.Load()) },
		},
		{
			name: "csm_credential_spray_dry_run_total",
			help: "Spray decisions counted while dry_run was active (routing unchanged).",
			read: func() float64 { return float64(c.counters.sprayDryRunTotal.Load()) },
		},
		{
			gauge: true,
			name:  "csm_credential_spray_tracked_ips",
			help:  "Source IPs currently held in the spray detector's per-IP map.",
			read:  func() float64 { return float64(c.SprayTrackedIPs()) },
		},
	}

	for _, s := range table {
		if s.gauge {
			reg.RegisterGaugeFunc(s.name, s.help, s.read)
			continue
		}
		reg.RegisterCounterFunc(s.name, s.help, s.read)
	}
}
package incident
import (
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// kindRank orders Kind values from weakest to strongest so the merge
// path can upgrade an incident's Kind when a stronger pattern appears
// later, but never downgrades. Higher number = stronger / more
// operator-attention.
//
// KindPostExploitProcess and KindCredentialSpray share rank 3.
// maybeReclassifyKind only replaces a Kind when the new rank is strictly
// greater, so neither of the two replaces the other. A Kind missing from
// this map reads as rank 0.
var kindRank = map[Kind]int{
KindWebAccountCompromise: 1,
KindMailboxTakeover: 2,
KindPostExploitProcess: 3,
KindHostIntegrityRisk: 4,
KindCredentialSpray: 3,
KindHostTakeover: 5,
}
// compoundPostExploitWebChecks is the set of check names that arm the
// Webshell compound flag in updateCompoundFlags. Names are lower-case;
// the lookup side trims and lower-cases the candidate first.
var compoundPostExploitWebChecks = map[string]struct{}{
"webshell": {},
"webshell_realtime": {},
"webshell_content_realtime": {},
"new_webshell_file": {},
"obfuscated_php": {},
"obfuscated_php_realtime": {},
"php_shield_webshell": {},
}
// compoundPostExploitNetworkChecks is the set of check names that arm
// the C2 compound flag in updateCompoundFlags. Names are lower-case; the
// lookup side trims and lower-cases the candidate first.
var compoundPostExploitNetworkChecks = map[string]struct{}{
"c2_connection": {},
"backdoor_port": {},
"backdoor_port_outbound": {},
}
// compoundHostPrivescUID0Checks, compoundHostPrivescSUIDChecks, and
// compoundHostPrivescBadASNChecks are the three host-takeover legs the
// takeover rule correlates: a new uid-0 account, a planted suid binary, and
// an outbound connection to a bad/unexpected ASN. updateCompoundFlags
// maps them to the UID0, SUID, and BadASNOutbound flags respectively.
var compoundHostPrivescUID0Checks = map[string]struct{}{
"uid0_account": {},
}
var compoundHostPrivescSUIDChecks = map[string]struct{}{
"suid_binary": {},
}
var compoundHostPrivescBadASNChecks = map[string]struct{}{
"bad_asn_outbound": {},
}
// allCompoundFlagsSet reports whether all five compound signals
// (Webshell, C2, UID0, SUID, BadASNOutbound) are already recorded, which
// lets the timeline hydrate loop stop early.
func allCompoundFlagsSet(f CompoundFlags) bool {
	if !f.Webshell || !f.C2 {
		return false
	}
	if !f.UID0 || !f.SUID {
		return false
	}
	return f.BadASNOutbound
}
// hostTakeoverLegs returns how many of the three distinct host-takeover
// legs (UID0, SUID, BadASNOutbound) the flags record, from 0 to 3. Two or
// more legs escalate an incident to KindHostTakeover.
func hostTakeoverLegs(f CompoundFlags) int {
	// Start from all three and subtract every leg not yet observed.
	legs := 3
	if !f.UID0 {
		legs--
	}
	if !f.SUID {
		legs--
	}
	if !f.BadASNOutbound {
		legs--
	}
	return legs
}
// maybeReclassifyKind upgrades inc.Kind in place, never downgrading it.
// Three sources can raise the Kind, each applied only when its rank in
// kindRank is strictly higher than the incident's current one:
//
//   - the Kind the new finding classifies as on its own;
//   - webshell + outbound C2 connection -> KindPostExploitProcess;
//   - any two of the three host-takeover legs (new uid-0 account,
//     planted suid binary, outbound connection to a bad ASN)
//     -> KindHostTakeover.
//
// The compound rules work from the incident's sticky CompoundFlags,
// which this function also maintains so callers need no separate pass.
// Because the flags survive timeline trimming, an early webshell still
// arms the rule when a much later C2 finding arrives. Calling with a
// weaker finding is a no-op for the Kind.
func maybeReclassifyKind(inc *Incident, f alert.Finding) {
	if inc == nil {
		return
	}

	// Hydrate sticky flags from the current timeline first so incidents
	// restored from bbolt (predating sticky flags) or built directly in
	// tests still arm the compound rules. The scan is bounded by
	// maxIncidentTimeline, then the new finding's own check is folded in.
	flags := &inc.CompoundFlags
	hydrateCompoundFlagsFromTimeline(flags, inc.Timeline)
	updateCompoundFlags(flags, f.Check)

	// promote raises inc.Kind to candidate only when it ranks strictly
	// higher than the current Kind.
	promote := func(candidate Kind) {
		if kindRank[candidate] > kindRank[inc.Kind] {
			inc.Kind = candidate
		}
	}

	promote(ClassifyKind(f))
	if flags.Webshell && flags.C2 {
		promote(KindPostExploitProcess)
	}
	if hostTakeoverLegs(*flags) >= 2 {
		promote(KindHostTakeover)
	}
}
// hydrateCompoundFlagsFromTimeline OR-merges the signals derivable from
// the timeline's check names into flags. It acts as a migration for
// legacy/persisted incidents whose timeline holds compound-relevant
// events but whose CompoundFlags were never populated. A trimmed
// timeline may have lost events, but whatever is still present remains a
// valid signal. Flags are only ever set, never cleared, and the scan
// stops as soon as every flag is set. A nil flags pointer is a no-op.
func hydrateCompoundFlagsFromTimeline(flags *CompoundFlags, events []IncidentEvent) {
	if flags == nil {
		return
	}
	for i := 0; i < len(events) && !allCompoundFlagsSet(*flags); i++ {
		updateCompoundFlags(flags, events[i].Check)
	}
}
// updateCompoundFlags raises the sticky compound flags that correspond
// to a check name. The name is trimmed and lower-cased before it is
// looked up in the compound check sets. Flags are only ever set to true
// here, so reclassify decisions are not silently disarmed by later
// timeline trimming. A nil flags pointer is a no-op.
func updateCompoundFlags(flags *CompoundFlags, check string) {
	if flags == nil {
		return
	}
	name := strings.ToLower(strings.TrimSpace(check))
	listedIn := func(set map[string]struct{}) bool {
		_, found := set[name]
		return found
	}

	if listedIn(compoundPostExploitWebChecks) {
		flags.Webshell = true
	}
	if listedIn(compoundPostExploitNetworkChecks) {
		flags.C2 = true
	}
	if listedIn(compoundHostPrivescUID0Checks) {
		flags.UID0 = true
	}
	if listedIn(compoundHostPrivescSUIDChecks) {
		flags.SUID = true
	}
	if listedIn(compoundHostPrivescBadASNChecks) {
		flags.BadASNOutbound = true
	}
}
package incident
import (
"time"
"github.com/pidginhost/csm/internal/alert"
)
// SpraySuppressionConfig is the operator-tunable knob set for the
// credential-spray detector. Defaults are conservative: enabled=false and
// dry_run=true so the detector ships dark, increments counters, and only
// changes incident routing once an operator opts in.
type SpraySuppressionConfig struct {
// Enabled and DryRun: the detector's Decide returns "not spray" without
// inspecting the finding when both are false.
Enabled bool
DryRun bool
// DistinctMailboxes is the distinct-target threshold per source IP.
// newSprayDetector substitutes 10 when it is zero or negative.
DistinctMailboxes int
// SeverityEscalateAt is the hit count for the CRITICAL bump.
// newSprayDetector substitutes 5x DistinctMailboxes when it is not
// strictly greater than DistinctMailboxes.
SeverityEscalateAt int
// PerCheck is the set of check names treated as spray candidates;
// findings from any other check are ignored by the detector.
PerCheck map[string]bool
// MaxTrackedIPs bounds the per-IP state map. newSprayDetector
// substitutes 10000 when it is zero or negative.
MaxTrackedIPs int
// BlockAtSeverity gates the firewall hand-off. Values:
// "" / unset - detection-only (legacy behavior).
// "high" - block on incident open (DistinctMailboxes trip).
// "critical" - block on severity escalation (SeverityEscalateAt).
// Comparison is case-insensitive. Any other value is ignored so an
// operator typo cannot accidentally engage blocking.
BlockAtSeverity string
}
// IsZero reports whether the config is entirely unset: both switches
// off, every numeric knob zero, no PerCheck entries, and an empty
// BlockAtSeverity. The correlator treats such a value as "spray detection
// disabled" without touching defaults.
func (c SpraySuppressionConfig) IsZero() bool {
	switch {
	case c.Enabled, c.DryRun:
		return false
	case c.DistinctMailboxes != 0, c.SeverityEscalateAt != 0:
		return false
	case len(c.PerCheck) != 0:
		return false
	case c.MaxTrackedIPs != 0:
		return false
	case c.BlockAtSeverity != "":
		return false
	}
	return true
}
// sprayDecision is the verdict the detector returns to OnFinding for a
// candidate spray-class finding. The zero value is sprayDecisionNone.
type sprayDecision int
const (
// sprayDecisionNone means the finding is not spray traffic; the
// caller should run the normal per-mailbox correlation path.
sprayDecisionNone sprayDecision = iota
// sprayDecisionOpen means this finding tipped the IP over the
// distinct-mailbox threshold; the caller must open a new
// credential_spray incident keyed on RemoteIP and report the id back
// to the detector via BindIncident.
sprayDecisionOpen
// sprayDecisionSuppress means an active spray incident already
// exists for the IP and the finding should be attached to it
// instead of opening a per-mailbox incident.
sprayDecisionSuppress
)
// ipSprayState tracks the distinct-mailbox set hit by a single source IP
// inside the merge window, plus the bound spray incident id once the
// threshold trips.
type ipSprayState struct {
// mailboxes is the distinct-target set for this IP.
// NOTE(review): Decide falls back to tenant id, cPanel user, or message
// text when a finding has no mailbox, so entries are presumably not
// always mailbox addresses -- confirm against the rest of Decide.
mailboxes map[string]struct{}
// firstSeen and lastSeen bracket the hits recorded for this IP. Decide
// compares lastSeen against the window to expire unbound state.
firstSeen time.Time
lastSeen time.Time
// incident is the bound credential_spray incident id once the
// threshold has tripped. Empty until then. While it is set, Decide does
// not expire the state on window age.
incident string
}
// sprayDetector keeps a per-IP sliding window of distinct mailboxes hit
// by spray-class checks and decides whether a finding should open a new
// credential_spray super-incident, attach to an existing one, or fall
// through to the normal per-mailbox correlator path.
//
// Concurrency: the correlator's mutex serializes all calls into the
// detector. The detector itself does not take a separate lock so the
// "single mutex protects all correlator state" invariant holds.
type sprayDetector struct {
// cfg is the operator config with newSprayDetector's defaults applied.
cfg SpraySuppressionConfig
// window is the maximum age of an unbound IP's lastSeen before Decide
// treats its state as stale and resets it.
window time.Duration
// now is the clock; newSprayDetector falls back to time.Now when nil.
now func() time.Time
// isWhitelisted reports source IPs the detector must ignore;
// newSprayDetector falls back to "never whitelisted" when nil.
isWhitelisted func(string) bool
// perIP is the live state map. Bounded by cfg.MaxTrackedIPs; entries
// that fall outside the window during Decide are pruned in place,
// and once the map exceeds the cap the oldest-by-lastSeen entry is
// evicted before insert.
perIP map[string]*ipSprayState
}
// newSprayDetector builds a detector from cfg, or returns nil when cfg
// is zero-valued so the correlator can no-op the fast path on hosts that
// have not opted in.
//
// Defaults applied to the detector's private copy of cfg:
//   - MaxTrackedIPs <= 0 becomes 10000.
//   - DistinctMailboxes <= 0 becomes 10.
//   - SeverityEscalateAt not strictly above DistinctMailboxes becomes
//     5x DistinctMailboxes.
//
// A nil now falls back to time.Now, and a nil isWhitelisted to a
// predicate that whitelists nothing.
func newSprayDetector(cfg SpraySuppressionConfig, window time.Duration, now func() time.Time, isWhitelisted func(string) bool) *sprayDetector {
	if cfg.IsZero() {
		return nil
	}

	const (
		fallbackTrackedIPs   = 10000
		fallbackMailboxes    = 10
		escalationMultiplier = 5
	)
	if cfg.MaxTrackedIPs < 1 {
		cfg.MaxTrackedIPs = fallbackTrackedIPs
	}
	if cfg.DistinctMailboxes < 1 {
		cfg.DistinctMailboxes = fallbackMailboxes
	}
	// Escalation must sit above the open threshold so a CRITICAL bump
	// only fires on genuinely sustained sprays, not on the immediate
	// trip.
	if cfg.SeverityEscalateAt <= cfg.DistinctMailboxes {
		cfg.SeverityEscalateAt = cfg.DistinctMailboxes * escalationMultiplier
	}

	det := &sprayDetector{
		cfg:           cfg,
		window:        window,
		now:           time.Now,
		isWhitelisted: func(string) bool { return false },
		perIP:         make(map[string]*ipSprayState),
	}
	if now != nil {
		det.now = now
	}
	if isWhitelisted != nil {
		det.isWhitelisted = isWhitelisted
	}
	return det
}
// Decide records a spray-candidate finding against its source IP and
// tells the correlator what to do with it. Caller must hold the
// correlator mutex.
//
// State is always updated (target recorded, lastSeen refreshed) once the
// finding passes the eligibility checks, in dry_run as well as live
// mode, so counters and audit logs reflect what the live path would
// have done. Only the returned decision is gated: in dry_run it is
// always sprayDecisionNone.
//
// hitCount is the number of distinct targets the IP has hit inside the
// window after this finding is recorded. The caller uses it to pick a
// severity for sprayDecisionOpen and to escalate on merge.
func (d *sprayDetector) Decide(f alert.Finding) (decision sprayDecision, hitCount int) {
	if d == nil {
		return sprayDecisionNone, 0
	}
	if !d.cfg.Enabled && !d.cfg.DryRun {
		return sprayDecisionNone, 0
	}

	// Eligibility: needs a source IP, an opted-in check, and a
	// non-whitelisted source. Evaluated in this order so the whitelist
	// callback only runs for findings that could otherwise count.
	ip := f.SourceIP
	if ip == "" || !d.cfg.PerCheck[f.Check] || d.isWhitelisted(ip) {
		return sprayDecisionNone, 0
	}

	// Identity for the distinct-set: mailbox first, then tenant id
	// (per-account brute force), then cPanel user (php-relay
	// attribution), and finally the message text so findings with no
	// structured identity still count as distinct attempts.
	target, _ := canonicalizeMailboxDomain(f.Mailbox, f.Domain)
	for _, fallback := range []string{f.TenantID, f.CPUser, f.Message} {
		if target != "" {
			break
		}
		target = fallback
	}
	if target == "" {
		return sprayDecisionNone, 0
	}

	now := d.now()

	// An unbound entry that went quiet for longer than the window is
	// stale: drop it so a fresh attack does not inherit cold counts.
	// Bound entries are kept until the correlator unbinds them;
	// otherwise a quiet but still-open super-incident would lose
	// suppression and duplicate on the next finding.
	state, tracked := d.perIP[ip]
	if tracked && state.incident == "" && now.Sub(state.lastSeen) > d.window {
		delete(d.perIP, ip)
		state = nil
	}

	if state == nil {
		// Keep the live set bounded: make room before inserting. The
		// eviction is an O(N) scan that only runs at insert time.
		if len(d.perIP) >= d.cfg.MaxTrackedIPs {
			d.evictOldestLocked()
		}
		state = &ipSprayState{
			mailboxes: make(map[string]struct{}),
			firstSeen: now,
		}
		d.perIP[ip] = state
	}

	state.mailboxes[target] = struct{}{}
	state.lastSeen = now
	hitCount = len(state.mailboxes)

	switch {
	case state.incident != "":
		// Already bound to a spray incident: keep suppressing.
		decision = sprayDecisionSuppress
	case hitCount >= d.cfg.DistinctMailboxes:
		// Threshold tripped: open a new spray incident.
		decision = sprayDecisionOpen
	default:
		return sprayDecisionNone, hitCount
	}

	// dry_run observes the workload without changing routing.
	if d.cfg.DryRun {
		return sprayDecisionNone, hitCount
	}
	return decision, hitCount
}
// BindIncident attaches the spray incident id that the correlator
// created for a sprayDecisionOpen to the tracked state for ip. While
// the binding is in place, later findings from that IP get
// sprayDecisionSuppress from Decide. An IP with no tracked state is
// left untouched. Caller holds the correlator mutex.
func (d *sprayDetector) BindIncident(ip, id string) {
	if d == nil {
		return
	}
	state, tracked := d.perIP[ip]
	if !tracked {
		return
	}
	state.incident = id
}
// Rehydrate re-creates the binding between ip and an open
// credential_spray incident at daemon startup, so an incident restored
// from bbolt keeps suppressing per-mailbox fan-out instead of letting a
// duplicate super-incident open.
//
// The seeded entry starts with an empty mailbox set: the suppress path
// in Decide only consults state.incident. firstSeen and lastSeen are
// both set to lastSeen so metrics and later unbound expiry have a
// stable anchor if the incident closes. Any existing entry for ip is
// replaced. Calls with an empty ip or id are ignored.
// Caller holds the correlator mutex.
func (d *sprayDetector) Rehydrate(ip, id string, lastSeen time.Time) {
	if d == nil {
		return
	}
	if ip == "" || id == "" {
		return
	}
	seeded := &ipSprayState{
		mailboxes: make(map[string]struct{}),
		firstSeen: lastSeen,
		lastSeen:  lastSeen,
		incident:  id,
	}
	d.perIP[ip] = seeded
}
// IncidentForIP reports the spray incident id currently bound to ip.
// It returns "" when the detector is nil, the IP is not tracked, or the
// tracked state has no incident bound. Callers use it after a suppress
// decision to learn which incident the finding should be merged into.
func (d *sprayDetector) IncidentForIP(ip string) string {
	if d == nil {
		return ""
	}
	state, tracked := d.perIP[ip]
	if !tracked {
		return ""
	}
	return state.incident
}
// UnbindIncident removes the tracked state of every IP bound to id.
// After this, a finding from one of those IPs can no longer be routed
// to the closed or missing incident through the suppress merge path,
// and the IP starts again from a zero distinct-mailbox count.
//
// An empty id is ignored (it would otherwise match every unbound
// entry). The scan is linear over perIP, which is bounded by
// MaxTrackedIPs. Caller holds the correlator mutex.
func (d *sprayDetector) UnbindIncident(id string) {
	if d == nil {
		return
	}
	if id == "" {
		return
	}
	for ip, state := range d.perIP {
		if state.incident != id {
			continue
		}
		delete(d.perIP, ip)
	}
}
// UnbindIncidents is the batch form of UnbindIncident: it removes the
// tracked state of every IP bound to any of the supplied incident ids
// in a single pass over perIP. Empty ids in the slice are skipped so
// they cannot match unbound entries. Caller holds the correlator mutex.
func (d *sprayDetector) UnbindIncidents(ids []string) {
	if d == nil {
		return
	}
	switch len(ids) {
	case 0:
		return
	case 1:
		// No point building a set for a single id.
		d.UnbindIncident(ids[0])
		return
	}

	wanted := make(map[string]struct{}, len(ids))
	for _, id := range ids {
		if id == "" {
			continue
		}
		wanted[id] = struct{}{}
	}
	if len(wanted) == 0 {
		return
	}

	for ip, state := range d.perIP {
		if _, hit := wanted[state.incident]; hit {
			delete(d.perIP, ip)
		}
	}
}
// PruneStale removes unbound entries whose lastSeen is more than the
// window before now, and returns how many were removed. The daemon
// retention loop calls it so the map does not keep growing on hosts
// with churning attacker IPs.
//
// Entries bound to a spray incident are never pruned here, whatever
// their age: dropping one would let the next finding from that IP
// bypass suppression while the super-incident is still active in the
// correlator.
func (d *sprayDetector) PruneStale(now time.Time) int {
	if d == nil {
		return 0
	}
	removed := 0
	for ip, state := range d.perIP {
		stale := state.incident == "" && now.Sub(state.lastSeen) > d.window
		if !stale {
			continue
		}
		delete(d.perIP, ip)
		removed++
	}
	return removed
}
// TrackedIPs reports how many source IPs the detector currently holds
// state for (0 for a nil detector). It backs the
// csm_credential_spray_tracked_ips gauge.
func (d *sprayDetector) TrackedIPs() int {
	if d != nil {
		return len(d.perIP)
	}
	return 0
}
// evictOldestLocked makes room for one new perIP entry by dropping the
// entry with the smallest lastSeen. Caller holds the correlator mutex.
// Linear scan over at most MaxTrackedIPs entries.
//
// Entries bound to a spray incident are only evicted as a last resort:
// the oldest unbound entry goes first. Evicting a bound entry removes
// the suppression for its IP, so the next finding from that attacker
// could open a duplicate incident while the super-incident is still
// active. Rehydrated entries are especially exposed because they carry
// the oldest timestamps after a restart. A bound entry is dropped only
// when every tracked IP is bound, which keeps the MaxTrackedIPs memory
// bound intact.
func (d *sprayDetector) evictOldestLocked() {
	var (
		unboundIP, boundIP     string
		unboundAt, boundAt     time.Time
		haveUnbound, haveBound bool
	)
	for ip, state := range d.perIP {
		if state.incident == "" {
			if !haveUnbound || state.lastSeen.Before(unboundAt) {
				unboundIP, unboundAt, haveUnbound = ip, state.lastSeen, true
			}
			continue
		}
		if !haveBound || state.lastSeen.Before(boundAt) {
			boundIP, boundAt, haveBound = ip, state.lastSeen, true
		}
	}
	switch {
	case haveUnbound:
		delete(d.perIP, unboundIP)
	case haveBound:
		delete(d.perIP, boundIP)
	}
}
// Counters for the credential-spray decisions live on the Correlator's
// existing `counters` struct (see correlator.go). Keeping them there
// means there is exactly one place that owns counter mutations, which
// keeps the metrics story coherent and avoids surprising operators
// who already grep for `csm_incidents_*`.
// Package incident groups related security findings into a single
// "story" with a timeline. Original findings are not mutated or
// suppressed; the Incident is layered on top so operators read one
// escalating object instead of stitching findings together by hand.
package incident
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// Status is the lifecycle position of an incident. The string values
// are what appears in the JSON "status" field of Incident.
type Status string

const (
	// StatusOpen and StatusContained are the active states: per the
	// Incident.ClosedAt documentation, ClosedAt is set when an incident
	// transitions out of Open/Contained.
	StatusOpen Status = "open"
	StatusContained Status = "contained"
	// StatusResolved and StatusDismissed are the remaining states.
	// NOTE(review): presumably both are terminal ("closed") states --
	// confirm against SetStatus / CloseStale, which are not visible here.
	StatusResolved Status = "resolved"
	StatusDismissed Status = "dismissed"
)
// Kind is the high-level taxonomy a correlator assigns at create time.
// Stable strings; downstream tooling pins on these, so renaming a value
// is a breaking change for consumers of the JSON "kind" field.
type Kind string

const (
	// The four kinds below carry no further description in this file;
	// the rules that assign them live in the correlator.
	KindWebAccountCompromise Kind = "web_account_compromise"
	KindMailboxTakeover Kind = "mailbox_takeover"
	KindPostExploitProcess Kind = "post_exploit_process"
	KindHostIntegrityRisk Kind = "host_integrity_risk"
	// KindCredentialSpray collapses a single source IP that is brute-forcing
	// many distinct mailboxes/accounts inside the merge window into one
	// super-incident keyed on the source IP. Prevents the per-mailbox fan-out
	// that turns one attacker into thousands of mailbox_takeover incidents.
	KindCredentialSpray Kind = "credential_spray" // #nosec G101 -- taxonomy label, not a secret
	// KindHostTakeover is the compound escalation when more than one
	// host-privilege-escalation leg (a new uid-0 account, a planted suid
	// binary, or bad-ASN outbound connection) is seen for the same host
	// inside the merge window. It ranks above KindHostIntegrityRisk so a
	// confirmed multi-leg takeover stands out from a single host-integrity
	// finding.
	KindHostTakeover Kind = "host_takeover"
)
// Incident is the wire shape every consumer (API, control socket,
// audit propagation) sees. Optional fields are absent from JSON when
// zero so consumers ignore optional context cleanly.
//
// Severity is an int enum in Go but is rendered as its string token on
// the wire; see MarshalJSON / UnmarshalJSON.
type Incident struct {
	// ID, Kind and Status are always emitted.
	ID string `json:"id"`
	Kind Kind `json:"kind"`
	Status Status `json:"status"`
	Severity alert.Severity `json:"severity"`
	// Account, Domain and Mailbox are optional context; each is omitted
	// from JSON when empty.
	Account string `json:"account,omitempty"`
	Domain string `json:"domain,omitempty"`
	Mailbox string `json:"mailbox,omitempty"`
	CorrelationKey *Key `json:"correlation_key,omitempty"`
	Summary string `json:"summary,omitempty"`
	Confidence int `json:"confidence,omitempty"`
	Findings []string `json:"findings,omitempty"`
	Timeline []IncidentEvent `json:"timeline,omitempty"`
	Actions []IncidentAction `json:"actions,omitempty"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
	// ClosedAt is set when an incident transitions out of Open/Contained.
	// Populated by SetStatus and CloseStale; zero for active incidents so
	// existing webhook/SIEM consumers see no diff.
	//
	// The tag must be omitzero, not omitempty: encoding/json never treats
	// a struct value such as time.Time as "empty", so omitempty would
	// still emit "closed_at":"0001-01-01T00:00:00Z" on every active
	// incident. omitzero drops the field when the time is the zero time.
	ClosedAt time.Time `json:"closed_at,omitzero"`
	// ClosedBy attributes the close. "operator" for SetStatus calls,
	// "auto:stale" for CloseStale. Empty for active incidents.
	ClosedBy string `json:"closed_by,omitempty"`
	// CompoundFlags carries sticky bits used by the timeline-aware
	// reclassifier. Once set, they survive timeline trimming so an
	// early webshell or C2 signal still drives the compound rule when
	// the matching counterpart arrives much later.
	CompoundFlags CompoundFlags `json:"compound_flags,omitzero"`
}
// CompoundFlags records the union of compound-pattern signals an
// Incident has ever observed. Fields are sticky once true; they are
// not derived from the (possibly trimmed) timeline so reclassify is
// not silently disarmed by head+tail eviction.
//
// Every field is omitted from JSON while false.
type CompoundFlags struct {
	// Webshell and C2 are the two signals named in the Incident
	// documentation as driving the compound rule ("an early webshell or
	// C2 signal"). NOTE(review): the exact rule that combines them lives
	// in the reclassifier, which is not visible here.
	Webshell bool `json:"webshell,omitempty"`
	C2 bool `json:"c2,omitempty"`
	// UID0, SUID, and BadASNOutbound record the three host-takeover legs:
	// a new uid-0 account, a planted suid binary, and an outbound connection
	// to a bad/unexpected ASN. When any two are set on one incident the
	// reclassifier escalates to KindHostTakeover.
	UID0 bool `json:"uid0,omitempty"`
	SUID bool `json:"suid,omitempty"`
	BadASNOutbound bool `json:"bad_asn_outbound,omitempty"`
}
// MarshalJSON emits the incident with Severity as its uppercase token
// ("WARNING", "HIGH", "CRITICAL") rather than the enum's integer value.
// Consumers (web UI, control socket, audit propagation) expect the same
// human-readable token that audit_sink and webhook dispatch already
// produce.
//
// The local alias type has Incident's fields but none of its methods,
// which stops json.Marshal from recursing back into this method. The
// outer Severity field shadows the embedded one because it sits at a
// shallower depth.
func (i Incident) MarshalJSON() ([]byte, error) {
	type plain Incident
	wire := struct {
		plain
		Severity string `json:"severity"`
	}{
		plain:    plain(i),
		Severity: i.Severity.String(),
	}
	return json.Marshal(wire)
}
// UnmarshalJSON is the inverse of MarshalJSON: every field decodes
// normally except Severity, which is read as a string token and mapped
// back to alert.Severity.
//
//   - "WARNING", "HIGH", "CRITICAL" set the matching severity.
//   - A missing or empty severity is accepted and leaves i.Severity
//     as it was (Warning on a fresh value), which keeps partial
//     decodes in tests and API snippets working.
//   - Any other token is an error, so SIEM-side schema drift is loud
//     rather than silent.
func (i *Incident) UnmarshalJSON(data []byte) error {
	type plain Incident
	aux := struct {
		*plain
		Severity string `json:"severity"`
	}{plain: (*plain)(i)}

	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	if aux.Severity == "" {
		return nil
	}

	tokens := map[string]alert.Severity{
		"WARNING":  alert.Warning,
		"HIGH":     alert.High,
		"CRITICAL": alert.Critical,
	}
	sev, known := tokens[aux.Severity]
	if !known {
		return fmt.Errorf("incident: unknown severity %q", aux.Severity)
	}
	i.Severity = sev
	return nil
}
// IncidentEvent is one entry in an incident's timeline. Built from a
// Finding when it joins the incident; carries enough context to
// render the timeline without re-reading the original record.
//
// Time, Kind and Message are always emitted; every other field is
// omitted from JSON when it holds its zero value.
type IncidentEvent struct {
	Time time.Time `json:"time"`
	// Kind is a free-form event category label. NOTE(review): the set of
	// values is defined by the code that builds events, not visible here.
	Kind string `json:"kind"`
	Check string `json:"check,omitempty"`
	Message string `json:"message"`
	FindingID string `json:"finding_id,omitempty"`
	// PID and UID use omitempty, so a literal 0 (e.g. uid 0) is
	// indistinguishable from "not set" on the wire.
	PID int `json:"pid,omitempty"`
	UID int `json:"uid,omitempty"`
	Process string `json:"process,omitempty"`
	Path string `json:"path,omitempty"`
	RemoteIP string `json:"remote_ip,omitempty"`
}
// IncidentAction is an automated or operator action that touched the
// incident. Appended to the timeline; surfaced separately so dashboards
// can filter by what the system did vs what it observed.
//
// Time, Action and Result are always emitted; Details is omitted from
// JSON when empty.
type IncidentAction struct {
	Time time.Time `json:"time"`
	Action string `json:"action"`
	Result string `json:"result"`
	Details string `json:"details,omitempty"`
}
package webserver
import (
"context"
"fmt"
"os/exec"
"time"
"github.com/pidginhost/csm/internal/platform"
)
// apacheHandler covers cPanel + plain Apache. cPanel ships `apachectl`
// pointing at the EasyApache binary; plain Apache on Debian/Ubuntu has
// `apache2ctl`. The selector picks the right one at construction.
type apacheHandler struct {
	// snippetPath is the absolute path of the CSM drop-in config file.
	snippetPath string
	// ctlBinary is the control binary Validate runs with "configtest".
	ctlBinary string // "apachectl" or "apache2ctl"
	// reloadAction is the full reload command: element 0 is the binary,
	// the rest are its arguments. Reload runs it as-is.
	reloadAction []string
	// cmdRunner executes external commands; injected so tests can mock exec.
	cmdRunner cmdRunner
}
// newApacheHandler builds the Apache handler for the detected platform.
// Three layouts are distinguished, checked in this order:
//
//   - cPanel: always EasyApache; conf.d is the canonical drop-in
//     directory and reload goes through `apachectl graceful`.
//   - Debian family: apache2 with conf-enabled, reloaded via systemctl.
//   - anything else: treated as RHEL-family httpd with
//     /etc/httpd/conf.d, reloaded via systemctl.
//
// r is the command runner used for configtest and reload.
func newApacheHandler(info platform.Info, r cmdRunner) *apacheHandler {
	var (
		snippet string
		ctl     string
		reload  []string
	)
	if info.IsCPanel() {
		snippet = "/etc/apache2/conf.d/csm-challenge.conf"
		ctl = "apachectl"
		reload = []string{"apachectl", "graceful"}
	} else if info.IsDebianFamily() {
		snippet = "/etc/apache2/conf-enabled/csm-challenge.conf"
		ctl = "apache2ctl"
		reload = []string{"systemctl", "reload", "apache2"}
	} else {
		snippet = "/etc/httpd/conf.d/csm-challenge.conf"
		ctl = "apachectl"
		reload = []string{"systemctl", "reload", "httpd"}
	}
	return &apacheHandler{
		snippetPath:  snippet,
		ctlBinary:    ctl,
		reloadAction: reload,
		cmdRunner:    r,
	}
}
// Kind returns the handler's stable identifier, "apache".
func (h *apacheHandler) Kind() string {
	return "apache"
}

// SnippetPath returns the drop-in path chosen at construction.
func (h *apacheHandler) SnippetPath() string {
	return h.snippetPath
}

// Template returns the Apache snippet template body.
func (h *apacheHandler) Template() string {
	return apacheTemplate
}

// PostInstallInstructions returns "": the Apache handler has no manual
// follow-up steps to report after install.
func (h *apacheHandler) PostInstallInstructions() string {
	return ""
}
// Validate runs `<ctlBinary> configtest` with a 15 second timeout and
// reports whether Apache accepts the current on-disk configuration.
//
// On failure the returned error carries the command's combined output
// and wraps the underlying error (%w), so callers can still detect
// causes such as context.DeadlineExceeded with errors.Is. The message
// text is unchanged from the previous %v formatting.
func (h *apacheHandler) Validate() error {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	out, err := h.cmdRunner.Run(ctx, h.ctlBinary, "configtest")
	if err != nil {
		return fmt.Errorf("apache configtest failed: %w\n%s", err, out)
	}
	return nil
}
// Reload runs the platform's reload command (reloadAction) with a
// 30 second timeout.
//
// A handler with no reload command configured returns an error instead
// of panicking on the slice index; handlers built by newApacheHandler
// always have one. On command failure the returned error carries the
// combined output and wraps the underlying error (%w) so callers can
// inspect the cause with errors.Is / errors.As.
func (h *apacheHandler) Reload() error {
	if len(h.reloadAction) == 0 {
		return fmt.Errorf("apache reload failed: no reload command configured")
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	out, err := h.cmdRunner.Run(ctx, h.reloadAction[0], h.reloadAction[1:]...)
	if err != nil {
		return fmt.Errorf("apache reload failed: %w\n%s", err, out)
	}
	return nil
}
// cmdRunner is the injection seam tests use to mock exec. The real
// implementation in realCmdRunner just shells out via os/exec.
type cmdRunner interface {
	// Run executes name with args under ctx and returns the command's
	// output together with any execution error. Handlers include the
	// returned bytes in their error messages.
	Run(ctx context.Context, name string, args ...string) ([]byte, error)
}
// realCmdRunner is the production cmdRunner; it carries no state.
type realCmdRunner struct{}

// Run executes name with args via os/exec, bound to ctx, and returns
// stdout and stderr interleaved in a single buffer.
func (realCmdRunner) Run(ctx context.Context, name string, args ...string) ([]byte, error) {
	// #nosec G204 -- name + args come from the installer's static handler
	// definitions (apachectl/apache2ctl/nginx/lswsctrl + verbs). No
	// user-controlled strings reach this path.
	cmd := exec.CommandContext(ctx, name, args...)
	return cmd.CombinedOutput()
}
// Package webserver auto-installs the CSM challenge webserver glue
// (Apache / LSWS / Nginx) with a write-validate-reload-or-revert flow.
// The operator runs `csm webserver-integration {install|upgrade|...}`;
// the package picks the right handler for the host and never reloads
// the webserver with a snippet that does not pass configtest.
package webserver
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"text/template"
"github.com/pidginhost/csm/internal/challenge"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
)
// templateHeaderPrefix starts the single comment line that is prepended
// to every rendered snippet; the template version number follows it on
// the same line. The installer reads the version back from that first
// line without parsing the body, and the presence of the line is what
// identifies "CSM owns this file". Manual edits that wipe or change the
// header trip ErrManualEdits at upgrade time and the file is left
// untouched.
const templateHeaderPrefix = "# csm-managed-version: "
// RenderConfig contains the daemon settings that have to be baked
// into webserver snippets. Values come from csm.yaml at install time
// where they are operator-configurable.
type RenderConfig struct {
	// ChallengeMapPath and ChallengeNginxMapPath are the challenge map
	// files the snippet references. Blank values fall back to
	// challenge.DefaultMapPath / challenge.DefaultNginxMapPath when the
	// template data is built.
	ChallengeMapPath string
	ChallengeNginxMapPath string
	// ChallengeListenAddr is the address the challenge listener binds
	// to. It is not rendered into the snippet; validation uses it to
	// reject installs where the listener is loopback-only. An empty
	// value is treated as loopback.
	ChallengeListenAddr string
	// ChallengePublicURL is the fully-qualified redirect target the
	// webserver snippet emits. Required on install; empty value blocks
	// the installer with a clear error. Operator typically points it at
	// the host's TLS-valid control-panel domain on a CSM-owned port,
	// e.g. https://server.example.com:8439/challenge.
	ChallengePublicURL string
}
// templateData is the value passed to the snippet template at render
// time. It is built by Installer.templateData from RenderConfig with
// whitespace trimmed and default map paths filled in.
type templateData struct {
	ChallengeMapPath string
	ChallengeNginxMapPath string
	ChallengePublicURL string
}
// ErrMissingPublicURL is returned by Install when challenge.public_url
// is empty (after trimming whitespace). The webserver snippet has no
// fallback redirect target so the integration cannot function without
// it.
var ErrMissingPublicURL = errors.New("webserver integration: challenge.public_url is not set")
// ErrInvalidPublicURL is returned when challenge.public_url cannot be
// used as the direct browser redirect target: it must be an absolute
// http or https URL with a host, no userinfo, query or fragment, a path
// of exactly /challenge, and a host that is not loopback. Validation
// wraps this error with detail, so match it with errors.Is.
var ErrInvalidPublicURL = errors.New("webserver integration: challenge.public_url must be an absolute http(s) URL ending in /challenge")
// ErrLoopbackPublicURL is returned when the installed snippets would
// redirect browsers to a listener that only accepts loopback traffic,
// i.e. challenge.listen_addr is empty or resolves to a loopback host.
var ErrLoopbackPublicURL = errors.New("webserver integration: challenge.listen_addr must be non-loopback when challenge.public_url redirects are enabled")
// Result is the structured outcome of an installer run. JSON-friendly
// shape so the CLI can render either human text or `--json` output.
type Result struct {
	Action string `json:"action"` // "install" | "upgrade" | "remove" | "status" | "validate"
	Status string `json:"status"` // "ok" | "no-op" | "skipped" | "fail"; the status action also reports "missing" | "modified" | "stale"
	Webserver string `json:"webserver"` // detected handler kind, "" if none
	SnippetPath string `json:"snippet_path"` // "" if no handler
	// OnDiskVer is the version parsed from the snippet header on disk;
	// 0 (omitted from JSON) when the file is absent or has no valid
	// header. ShippedVer is the version this build would write.
	OnDiskVer int `json:"on_disk_version,omitempty"`
	ShippedVer int `json:"shipped_version,omitempty"`
	Message string `json:"message,omitempty"`
	// FollowUp is the operator-facing setup the integration could not
	// automate on this stack. Empty when install is fully automatic.
	// The CLI prints it after the result block on successful install.
	FollowUp string `json:"follow_up,omitempty"`
}
// Installer drives the install / upgrade / status / remove flow. All
// I/O goes through injected hooks so unit tests can run on darwin or
// against a temp tree without touching real webserver paths.
type Installer struct {
	// Handler supplies the snippet path, template, configtest and
	// reload for the detected webserver.
	Handler Handler
	// Config holds the daemon settings baked into the snippet.
	Config RenderConfig
	// MkdirAll creates parent directories for the challenge map files.
	// A nil hook falls back to os.MkdirAll.
	MkdirAll func(path string, mode os.FileMode) error
	// WriteAt writes the snippet and the placeholder map files. It is
	// called without a nil check, so it must be set.
	WriteAt func(path string, data []byte, mode os.FileMode) error
	// ReadAt reads the on-disk snippet. Must be set.
	ReadAt func(path string) ([]byte, error)
	// StatAt checks whether a challenge map file already exists. A nil
	// hook falls back to os.Stat.
	StatAt func(path string) (os.FileInfo, error)
	// RemoveAt deletes the snippet on remove and on rollback of a
	// first-time install. Must be set.
	RemoveAt func(path string) error
	// Stderr receives best-effort rollback failure messages.
	Stderr io.Writer
}
// New builds an Installer for live operation: real filesystem reads,
// atomic snippet writes and the real exec runner. The handler is chosen
// from info.WebServer; an unsupported webserver yields the error from
// pickHandler and a nil Installer. cfg may be nil, in which case the
// render config uses built-in defaults.
func New(info platform.Info, cfg *config.Config) (*Installer, error) {
	handler, err := pickHandler(info, realCmdRunner{})
	if err != nil {
		return nil, err
	}

	inst := &Installer{
		Handler: handler,
		Config:  renderConfigFrom(cfg),
	}
	inst.MkdirAll = os.MkdirAll
	inst.WriteAt = atomicWrite
	inst.ReadAt = os.ReadFile
	inst.StatAt = os.Stat
	inst.RemoveAt = os.Remove
	inst.Stderr = os.Stderr
	return inst, nil
}
// Install renders the snippet and puts it in place, either for the
// first time or over a stale CSM-managed copy, without ever leaving the
// webserver running on a config that failed its own configtest.
//
// Sequence:
//
//  1. Validate challenge.public_url / listen_addr.
//  2. Read the current snippet (absence is fine). A file that exists
//     but has no CSM version header is treated as operator-edited and
//     is never overwritten (ErrManualEdits).
//  3. Make sure the challenge map files exist.
//  4. Render; if the bytes match what is on disk, report "no-op".
//  5. Write the new snippet, then run configtest.
//  6. configtest fails: put the previous state back and return the error.
//  7. Reload. If reload fails: put the previous state back, attempt one
//     best-effort recovery reload, and return the original reload error.
//
// The returned Result is populated on every path, including failures.
func (i *Installer) Install() (Result, error) {
	res := Result{
		Action:      "install",
		Webserver:   i.Handler.Kind(),
		SnippetPath: i.Handler.SnippetPath(),
		ShippedVer:  TemplateVersion,
	}
	// fail stamps the result as failed and hands back (res, err).
	fail := func(msg string, err error) (Result, error) {
		res.Status = "fail"
		res.Message = msg
		return res, err
	}

	if err := validateChallengePublicURL(i.Config); err != nil {
		return fail(err.Error(), err)
	}

	prevBytes, prevExists, prevVer, err := i.readSnippet()
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return fail(err.Error(), err)
	}
	res.OnDiskVer = prevVer
	if prevExists && prevVer == 0 {
		return fail(ErrManualEdits.Error()+": "+i.Handler.SnippetPath(), ErrManualEdits)
	}

	if err := i.ensureChallengeMapFiles(); err != nil {
		return fail("runtime files: "+err.Error(), err)
	}

	rendered, err := i.renderTemplate()
	if err != nil {
		return fail("render: "+err.Error(), err)
	}
	if prevExists && bytes.Equal(prevBytes, rendered) {
		res.Status = "no-op"
		res.Message = "snippet already current"
		return res, nil
	}

	if err := i.WriteAt(i.Handler.SnippetPath(), rendered, 0o644); err != nil {
		return fail("write: "+err.Error(), err)
	}
	if err := i.Handler.Validate(); err != nil {
		i.restore(prevBytes, prevExists)
		return fail("configtest: "+err.Error(), err)
	}
	if err := i.Handler.Reload(); err != nil {
		i.restore(prevBytes, prevExists)
		// Best-effort recovery reload. Even if it fails, the file is
		// already back to the last-good content.
		_ = i.Handler.Reload()
		return fail("reload: "+err.Error()+" (rolled back)", err)
	}

	res.Status = "ok"
	res.Message = fmt.Sprintf("snippet installed (v%d)", TemplateVersion)
	if prevExists {
		res.Message = fmt.Sprintf("snippet upgraded v%d -> v%d", prevVer, TemplateVersion)
	}
	res.FollowUp = i.Handler.PostInstallInstructions()
	return res, nil
}
// Upgrade is an alias for Install with a more honest CLI verb. The
// underlying flow is the same: idempotent install + version compare.
// The only difference is that the returned Result reports its Action
// as "upgrade", on success and on failure alike.
func (i *Installer) Upgrade() (Result, error) {
	res, err := i.Install()
	// Relabel after the fact; Install always stamps "install".
	res.Action = "upgrade"
	return res, err
}
// Status inspects the on-disk snippet and classifies it without
// writing anything. Post-upgrade hooks and operator diagnostics use it
// to detect drift. A missing snippet is a normal outcome (status
// "missing", nil error); only other read failures are returned as
// errors with status "fail".
func (i *Installer) Status() (Result, error) {
	res := Result{
		Action:      "status",
		Webserver:   i.Handler.Kind(),
		SnippetPath: i.Handler.SnippetPath(),
		ShippedVer:  TemplateVersion,
	}

	_, present, onDisk, err := i.readSnippet()
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		res.Status, res.Message = "fail", err.Error()
		return res, err
	}

	res.OnDiskVer = onDisk
	res.Status, res.Message = classifyStatus(present, onDisk, TemplateVersion)
	return res, nil
}
// classifyStatus maps (file present?, on-disk version, shipped version)
// to a status token and operator message. It is kept free of I/O and of
// the TemplateVersion constant so each branch can be unit-tested.
//
//   - not present          -> "missing"
//   - present, version 0   -> "modified" (header absent or unreadable)
//   - on-disk < shipped    -> "stale"
//   - otherwise            -> "ok"
func classifyStatus(exists bool, onDisk, shipped int) (status, message string) {
	if !exists {
		return "missing", "no snippet installed; run `csm webserver-integration install`"
	}
	if onDisk == 0 {
		return "modified", ErrManualEdits.Error()
	}
	if onDisk < shipped {
		return "stale", fmt.Sprintf("on-disk v%d < shipped v%d; run `csm webserver-integration upgrade`", onDisk, shipped)
	}
	return "ok", fmt.Sprintf("snippet at v%d", onDisk)
}
// Remove takes the CSM snippet out of the webserver config with the
// same rollback discipline as Install.
//
//   - No snippet on disk: "no-op".
//   - Snippet without a CSM version header: refused with
//     ErrManualEdits; an operator-edited file is never deleted.
//   - Otherwise the file is deleted and configtest is run. If
//     configtest or the following reload fails, the original bytes are
//     written back (plus one best-effort recovery reload after a reload
//     failure) and the error is returned.
func (i *Installer) Remove() (Result, error) {
	res := Result{
		Action:      "remove",
		Webserver:   i.Handler.Kind(),
		SnippetPath: i.Handler.SnippetPath(),
		ShippedVer:  TemplateVersion,
	}
	// fail stamps the result as failed and hands back (res, err).
	fail := func(msg string, err error) (Result, error) {
		res.Status = "fail"
		res.Message = msg
		return res, err
	}

	prevBytes, prevExists, prevVer, err := i.readSnippet()
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return fail(err.Error(), err)
	}
	res.OnDiskVer = prevVer

	if !prevExists {
		res.Status = "no-op"
		res.Message = "snippet not present"
		return res, nil
	}
	if prevVer == 0 {
		return fail(ErrManualEdits.Error()+": refusing to delete an operator-edited file", ErrManualEdits)
	}

	if err := i.RemoveAt(i.Handler.SnippetPath()); err != nil {
		return fail("delete: "+err.Error(), err)
	}
	if err := i.Handler.Validate(); err != nil {
		i.restore(prevBytes, prevExists)
		return fail("configtest after remove: "+err.Error(), err)
	}
	if err := i.Handler.Reload(); err != nil {
		i.restore(prevBytes, prevExists)
		// Best-effort recovery reload against the restored file.
		_ = i.Handler.Reload()
		return fail("reload: "+err.Error()+" (rolled back)", err)
	}

	res.Status = "ok"
	res.Message = "snippet removed"
	return res, nil
}
// Validate runs the webserver's own configtest against whatever is on
// disk right now. It is a dry run: nothing is written and the
// webserver is not reloaded.
func (i *Installer) Validate() (Result, error) {
	res := Result{
		Action:      "validate",
		Webserver:   i.Handler.Kind(),
		SnippetPath: i.Handler.SnippetPath(),
		ShippedVer:  TemplateVersion,
	}

	err := i.Handler.Validate()
	if err == nil {
		res.Status = "ok"
		res.Message = "configtest passed"
		return res, nil
	}
	res.Status = "fail"
	res.Message = err.Error()
	return res, err
}
// renderTemplate produces the bytes to write to the snippet path: one
// header line (templateHeaderPrefix + TemplateVersion) followed by the
// handler's template executed against the installer's template data.
// The header is what status/upgrade later read back to learn the
// installed version. A parse or execute error returns nil bytes.
func (i *Installer) renderTemplate() ([]byte, error) {
	parsed, err := template.New(i.Handler.Kind()).Parse(i.Handler.Template())
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer
	out.WriteString(templateHeaderPrefix + strconv.Itoa(TemplateVersion) + "\n")
	if err := parsed.Execute(&out, i.templateData()); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// renderConfigFrom projects the daemon config onto the settings the
// snippet templates need. The map paths always come from the challenge
// package defaults. The listen address defaults to 127.0.0.1 when cfg
// is nil or the configured value is blank. The public URL is copied
// with surrounding whitespace removed and stays empty for a nil cfg.
func renderConfigFrom(cfg *config.Config) RenderConfig {
	const defaultListenAddr = "127.0.0.1"

	rc := RenderConfig{
		ChallengeMapPath:      challenge.DefaultMapPath,
		ChallengeNginxMapPath: challenge.DefaultNginxMapPath,
		ChallengeListenAddr:   defaultListenAddr,
	}
	if cfg != nil {
		if addr := strings.TrimSpace(cfg.Challenge.ListenAddr); addr != "" {
			rc.ChallengeListenAddr = addr
		}
		rc.ChallengePublicURL = strings.TrimSpace(cfg.Challenge.PublicURL)
	}
	return rc
}
// RenderConfigFromConfig exposes the daemon-to-template projection for
// diagnostics that need to validate the same inputs before install.
// It returns exactly what New stores in Installer.Config for the same
// cfg; a nil cfg yields the built-in defaults.
func RenderConfigFromConfig(cfg *config.Config) RenderConfig {
	return renderConfigFrom(cfg)
}
// templateData builds the value handed to the snippet template. Each
// setting is whitespace-trimmed; a blank map path falls back to the
// challenge package default. The public URL has no fallback and may be
// empty here (Install rejects that case before rendering).
func (i *Installer) templateData() templateData {
	// orDefault trims configured and substitutes fallback when blank.
	orDefault := func(configured, fallback string) string {
		if v := strings.TrimSpace(configured); v != "" {
			return v
		}
		return fallback
	}

	var td templateData
	td.ChallengeMapPath = orDefault(i.Config.ChallengeMapPath, challenge.DefaultMapPath)
	td.ChallengeNginxMapPath = orDefault(i.Config.ChallengeNginxMapPath, challenge.DefaultNginxMapPath)
	td.ChallengePublicURL = strings.TrimSpace(i.Config.ChallengePublicURL)
	return td
}
// ensureChallengeMapFiles makes sure both challenge map files exist
// before the snippet that references them is written. For each path
// it creates the parent directory, leaves an existing file untouched,
// and writes a one-line placeholder when the file is missing.
//
// MkdirAll and StatAt fall back to their os equivalents when the hook
// is nil; WriteAt is used as configured. The first error aborts.
func (i *Installer) ensureChallengeMapFiles() error {
	const placeholder = "# Generated by CSM.\n"

	mkdirAll := i.MkdirAll
	if mkdirAll == nil {
		mkdirAll = os.MkdirAll
	}
	statAt := i.StatAt
	if statAt == nil {
		statAt = os.Stat
	}

	data := i.templateData()
	for _, path := range []string{data.ChallengeMapPath, data.ChallengeNginxMapPath} {
		if err := mkdirAll(filepath.Dir(path), 0o755); err != nil {
			return err
		}
		_, err := statAt(path)
		switch {
		case err == nil:
			// Already there: never clobber live map contents.
			continue
		case !errors.Is(err, os.ErrNotExist):
			return err
		}
		if err := i.WriteAt(path, []byte(placeholder), 0o644); err != nil {
			return err
		}
	}
	return nil
}
// validateChallengePublicURL checks that rc describes a redirect target
// browsers can actually reach. Checks run in this order and the first
// failure wins:
//
//  1. public URL blank                         -> ErrMissingPublicURL
//  2. not parseable / not absolute / no host   -> ErrInvalidPublicURL
//  3. scheme other than http or https          -> ErrInvalidPublicURL
//  4. userinfo, query or fragment present, or
//     path not exactly /challenge              -> ErrInvalidPublicURL
//  5. URL host is loopback                     -> ErrInvalidPublicURL
//  6. listen address is loopback (or blank)    -> ErrLoopbackPublicURL
//
// ErrInvalidPublicURL is always wrapped with detail; use errors.Is.
func validateChallengePublicURL(rc RenderConfig) error {
	raw := strings.TrimSpace(rc.ChallengePublicURL)
	if raw == "" {
		return ErrMissingPublicURL
	}

	parsed, err := url.Parse(raw)
	if err != nil || !parsed.IsAbs() || parsed.Host == "" {
		return fmt.Errorf("%w: %q", ErrInvalidPublicURL, raw)
	}

	switch parsed.Scheme {
	case "http", "https":
	default:
		return fmt.Errorf("%w: unsupported scheme %q", ErrInvalidPublicURL, parsed.Scheme)
	}

	hasExtras := parsed.User != nil || parsed.RawQuery != "" || parsed.Fragment != ""
	if hasExtras || parsed.Path != "/challenge" {
		return fmt.Errorf("%w: %q", ErrInvalidPublicURL, raw)
	}

	if host := parsed.Hostname(); isLoopbackHost(host) {
		return fmt.Errorf("%w: host %q is loopback", ErrInvalidPublicURL, host)
	}
	if isLoopbackListenAddr(rc.ChallengeListenAddr) {
		return ErrLoopbackPublicURL
	}
	return nil
}
// ValidateChallengePublicURL applies the same direct-redirect validation used
// by Install so diagnostics and tests do not drift from installer behavior.
// It returns nil when rc is acceptable, ErrMissingPublicURL for a blank
// URL, ErrLoopbackPublicURL for a loopback listen address, and an error
// wrapping ErrInvalidPublicURL for every other rejection.
func ValidateChallengePublicURL(rc RenderConfig) error {
	return validateChallengePublicURL(rc)
}
// isLoopbackListenAddr reports whether a listen address only accepts
// loopback traffic. addr may be a bare host or host:port, with or
// without IPv6 brackets.
//
// A blank address counts as loopback (the configured default is
// 127.0.0.1). An address with an empty host part, such as ":8439",
// binds every interface and is therefore not loopback.
func isLoopbackListenAddr(addr string) bool {
	trimmed := strings.TrimSpace(addr)
	if trimmed == "" {
		return true
	}

	// No port present (or unparseable): treat the whole value as the host.
	host, _, err := net.SplitHostPort(trimmed)
	if err != nil {
		host = trimmed
	}
	host = strings.Trim(host, "[]")
	return host != "" && isLoopbackHost(host)
}

// isLoopbackHost reports whether host is the name "localhost" (any
// case) or an IP literal in a loopback range. Surrounding IPv6 brackets
// and whitespace are ignored. Other hostnames are not resolved and are
// reported as non-loopback.
func isLoopbackHost(host string) bool {
	name := strings.TrimSpace(strings.Trim(host, "[]"))
	if strings.EqualFold(name, "localhost") {
		return true
	}
	if ip := net.ParseIP(name); ip != nil {
		return ip.IsLoopback()
	}
	return false
}
// readSnippet loads the on-disk snippet and extracts the version from
// its header line. It returns the raw bytes, whether the file was
// read, the parsed version (0 when the file has no valid marker), and
// any read error. Every read error, including "file does not exist",
// is returned unchanged with (nil, false, 0); callers decide whether
// os.ErrNotExist is acceptable.
func (i *Installer) readSnippet() ([]byte, bool, int, error) {
	raw, err := i.ReadAt(i.Handler.SnippetPath())
	if err != nil {
		return nil, false, 0, err
	}
	return raw, true, parseHeaderVersion(raw), nil
}
// restore puts the snippet path back to its pre-change state after a
// failed validate/reload step: it rewrites prevBytes when the file
// existed before, or deletes the new file when it did not (a file that
// is already gone is not an error).
//
// Best-effort: I/O errors here are reported on Stderr but do not change
// the installer's return code, because the caller already knows the
// original error. A nil Stderr falls back to os.Stderr so an Installer
// built without that hook cannot panic in the middle of a rollback.
func (i *Installer) restore(prevBytes []byte, prevExists bool) {
	var errOut io.Writer = os.Stderr
	if i.Stderr != nil {
		errOut = i.Stderr
	}

	if !prevExists {
		if err := i.RemoveAt(i.Handler.SnippetPath()); err != nil && !errors.Is(err, os.ErrNotExist) {
			fmt.Fprintf(errOut, "webserver integration: rollback delete failed: %v\n", err)
		}
		return
	}
	if err := i.WriteAt(i.Handler.SnippetPath(), prevBytes, 0o644); err != nil {
		fmt.Fprintf(errOut, "webserver integration: rollback write failed: %v\n", err)
	}
}
// parseHeaderVersion extracts the template version from the first line
// of a snippet. It returns 0, the "no valid CSM marker" sentinel, when
// the first line does not start with templateHeaderPrefix, when the
// remainder is not an integer, or when the integer is not positive.
//
// Rejecting non-positive values matters because callers treat any
// non-zero version as proof that CSM owns the file: a hand-written
// negative number must not let Install overwrite an operator's file.
// Leading and trailing whitespace on the line (including a trailing
// "\r") is ignored.
func parseHeaderVersion(data []byte) int {
	first, _, _ := bytes.Cut(data, []byte("\n"))
	line := strings.TrimSpace(string(first))
	rest, ok := strings.CutPrefix(line, templateHeaderPrefix)
	if !ok {
		return 0
	}
	v, err := strconv.Atoi(strings.TrimSpace(rest))
	if err != nil || v <= 0 {
		return 0
	}
	return v
}
// pickHandler maps the detected webserver in info to its Handler
// implementation, wiring r in as the command runner. It returns
// ErrUnknownWebserver when info.WebServer is none of the supported kinds.
func pickHandler(info platform.Info, r cmdRunner) (Handler, error) {
	var chosen Handler
	switch info.WebServer {
	case platform.WSNginx:
		chosen = newNginxHandler(r)
	case platform.WSLiteSpeed:
		chosen = newLSWSHandler(info, r)
	case platform.WSApache:
		chosen = newApacheHandler(info, r)
	default:
		return nil, ErrUnknownWebserver
	}
	return chosen, nil
}
// atomicWrite writes data to a sibling temp file and then renames it onto
// path, so the webserver never sees a half-written snippet. The temp file is
// fsynced and closed before its mode is set and the rename is attempted. The
// parent directory itself is not fsynced here.
//
// On any failure, including a failed rename, the temp file is removed so
// repeated failed installs do not litter the config directory with
// ".csm-ws-install-*" orphans. The first error encountered is returned
// unwrapped.
func atomicWrite(path string, data []byte, mode os.FileMode) error {
	tmp, err := os.CreateTemp(filepath.Dir(path), ".csm-ws-install-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	// discard unlinks the temp file; removal errors are ignored because the
	// caller is already being handed the primary failure.
	discard := func() { _ = os.Remove(tmpName) }

	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()
		discard()
		return err
	}
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		discard()
		return err
	}
	if err := tmp.Close(); err != nil {
		discard()
		return err
	}
	if err := os.Chmod(tmpName, mode); err != nil {
		discard()
		return err
	}
	if err := os.Rename(tmpName, path); err != nil {
		// The temp file was never moved into place, so it is still ours to
		// clean up.
		discard()
		return err
	}
	return nil
}
package webserver
import (
"context"
"fmt"
"time"
"github.com/pidginhost/csm/internal/platform"
)
// lswsHandler manages the LiteSpeed integration. The right snippet path
// depends on how LSWS is wired:
//
// - cPanel + LSWS: LSWS runs with <loadApacheConf>1</loadApacheConf>
// and reads cPanel's Apache config tree. The snippet drops at
// /etc/apache2/conf.d/csm-challenge.conf, same as plain Apache,
// and LSWS picks it up automatically.
//
// - Plain LSWS (no cPanel): the operator runs LSWS in native mode
// with /usr/local/lsws/conf/httpd_config.xml. There is no auto-
// include dir for text-style rewrite rules; the snippet goes in
// /usr/local/lsws/conf/templates/ and the operator must include
// it manually via the LSWS WebAdmin Console -> Server -> General
// -> Rewrite -> External Rewrite Rules. The installer writes the
// file but emits a stderr note pointing at the manual step.
type lswsHandler struct {
// cmdRunner executes the external LSWS binaries used to validate the
// config and restart the server.
cmdRunner cmdRunner
// cpanel selects the snippet location: true uses the cPanel Apache
// conf.d path, false the native LSWS templates directory.
cpanel bool
}
// newLSWSHandler builds a LiteSpeed handler that runs commands through r and
// records whether the host is a cPanel install, as reported by
// info.IsCPanel.
func newLSWSHandler(info platform.Info, r cmdRunner) *lswsHandler {
	h := new(lswsHandler)
	h.cmdRunner = r
	h.cpanel = info.IsCPanel()
	return h
}
// Kind returns the short identifier for this handler, "lsws".
func (h *lswsHandler) Kind() string {
	const kind = "lsws"
	return kind
}
// SnippetPath returns where the challenge snippet is written: the native
// LSWS templates directory on a plain install, or the Apache conf.d
// directory when the handler was created for a cPanel host.
func (h *lswsHandler) SnippetPath() string {
	if !h.cpanel {
		return "/usr/local/lsws/conf/templates/csm-challenge.conf"
	}
	return "/etc/apache2/conf.d/csm-challenge.conf"
}
// Template returns the snippet template text used for LiteSpeed.
func (h *lswsHandler) Template() string {
	return lswsTemplate
}
// Validate runs the LSWS config test with a 15-second timeout and returns an
// error carrying the command output when it fails. The underlying error is
// wrapped with %w so callers can inspect it with errors.Is / errors.As (for
// example to detect context.DeadlineExceeded).
func (h *lswsHandler) Validate() error {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	// LSWS has no `lswsctrl conftest` verb (the documented surface is
	// start|stop|restart|reload|condrestart|try-restart|status). The
	// equivalent configtest is `lshttpd -t`, which parses the active
	// config and exits non-zero on syntax errors without touching the
	// running listener.
	out, err := h.cmdRunner.Run(ctx, "/usr/local/lsws/bin/lshttpd", "-t")
	if err != nil {
		return fmt.Errorf("lsws lshttpd -t failed: %w\n%s", err, out)
	}
	return nil
}
// PostInstallInstructions returns the operator-facing follow-up text shown
// after a successful install. It is empty for LiteSpeed: v4 of the snippet
// redirects challenged IPs to challenge.public_url directly, so no LSWS
// External App is required and the earlier WebAdmin Console steps are
// obsolete.
func (h *lswsHandler) PostInstallInstructions() string {
	var none string
	return none
}
// Reload makes LSWS pick up the new configuration by running
// `lswsctrl restart` with a 60-second timeout. On failure the returned error
// carries the command output and wraps the underlying error with %w so
// callers can inspect it with errors.Is / errors.As.
func (h *lswsHandler) Reload() error {
	// LSWS does not have a graceful reload equivalent; `restart` is the
	// supported way to pick up new config without dropping established
	// listener sockets (LSWS's internal supervisor handles the hand-
	// off). The full-restart path is what the operator's own toolchain
	// invokes on config change too.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	out, err := h.cmdRunner.Run(ctx, "/usr/local/lsws/bin/lswsctrl", "restart")
	if err != nil {
		return fmt.Errorf("lsws restart failed: %w\n%s", err, out)
	}
	return nil
}
package webserver
import (
"context"
"fmt"
"time"
)
// nginxHandler manages the nginx integration. The snippet lands in
// /etc/nginx/conf.d/ where stock nginx auto-includes everything via
// the default http{} include glob. Validation uses `nginx -t`; reload
// uses `systemctl reload nginx` so existing connections drain
// gracefully.
type nginxHandler struct {
// cmdRunner executes the external nginx and systemctl commands.
cmdRunner cmdRunner
}
// newNginxHandler builds an nginx handler that runs its commands through r.
func newNginxHandler(r cmdRunner) *nginxHandler {
	h := new(nginxHandler)
	h.cmdRunner = r
	return h
}
// Kind returns the short identifier for this handler, "nginx".
func (h *nginxHandler) Kind() string {
	const kind = "nginx"
	return kind
}
// SnippetPath returns the fixed location of the challenge snippet inside
// nginx's conf.d directory.
func (h *nginxHandler) SnippetPath() string {
	const snippet = "/etc/nginx/conf.d/csm-challenge.conf"
	return snippet
}
// Template returns the snippet template text used for nginx.
func (h *nginxHandler) Template() string {
	return nginxTemplate
}
// PostInstallInstructions returns the operator-facing follow-up text for
// nginx. The http{} snippet only ships the shared map; each guarded server{}
// block has to opt in with a one-line if-redirect, because nginx cannot
// apply this safely from an http{} drop-in. The returned text shows that
// opt-in stanza and the validate-and-reload command to run afterwards.
func (h *nginxHandler) PostInstallInstructions() string {
return `Nginx detected. Per-server opt-in required so http{} drop-ins do
not blind-redirect already-protected hosts. For each server{} that
should respect the challenge:
server {
...
if ($csm_challenged) {
return 302 <challenge.public_url>?dest=$scheme://$host$request_uri;
}
}
Then run: nginx -t && systemctl reload nginx`
}
// Validate runs `nginx -t` with a 15-second timeout and returns an error
// carrying the command output when the config test fails. The underlying
// error is wrapped with %w so callers can inspect it with errors.Is /
// errors.As (for example to detect context.DeadlineExceeded).
func (h *nginxHandler) Validate() error {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()
	out, err := h.cmdRunner.Run(ctx, "nginx", "-t")
	if err != nil {
		return fmt.Errorf("nginx configtest failed: %w\n%s", err, out)
	}
	return nil
}
// Reload runs `systemctl reload nginx` with a 30-second timeout. On failure
// the returned error carries the command output and wraps the underlying
// error with %w so callers can inspect it with errors.Is / errors.As.
func (h *nginxHandler) Reload() error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	out, err := h.cmdRunner.Run(ctx, "systemctl", "reload", "nginx")
	if err != nil {
		return fmt.Errorf("nginx reload failed: %w\n%s", err, out)
	}
	return nil
}
package integrity
import (
"fmt"
"os"
"reflect"
"strings"
"github.com/pidginhost/csm/internal/config"
)
// WriteConfigBytesAtomic stores data at path through atomicWriteFile with
// mode 0600, the same write path SignAndSaveAtomic uses. It is meant for
// bytes that are already signed (for example a snapshot being restored,
// whose hash already matches its content), where re-signing would rewrite
// the integrity block that must be preserved.
func WriteConfigBytesAtomic(path string, data []byte) error {
	const mode os.FileMode = 0o600
	return atomicWriteFile(path, data, mode)
}
// SignAndSavePreserving writes editedBytes to path after patching
// integrity.binary_hash, integrity.config_hash and integrity.confd_hash
// inside the byte stream itself, not by re-marshaling the cfg. Operator
// comments and untouched formatting outside the integrity block survive
// byte-for-byte.
//
// Before writing, it verifies that the text outside the integrity block is
// unchanged by the patch, and that the final bytes decode via
// config.LoadBytes and match intendedClone under reflect.DeepEqual (with the
// integrity fields normalised). Any mismatch aborts the write.
//
// The write goes through atomicWriteFile, as in SignAndSaveAtomic.
// intendedClone.Integrity.BinaryHash, .ConfigHash and .ConfdHash are updated
// in place only after that write has succeeded, so a failed call leaves
// intendedClone untouched and never advertises hashes that did not reach
// disk. intendedClone.ConfigFile must equal path and
// intendedClone.ConfigDir must equal confDir.
func SignAndSavePreserving(path, confDir string, editedBytes []byte, intendedClone *config.Config, binaryHash string) error {
	if intendedClone == nil {
		return fmt.Errorf("intendedClone is nil")
	}
	if intendedClone.ConfigFile != path {
		return fmt.Errorf("intendedClone.ConfigFile=%q does not match path=%q", intendedClone.ConfigFile, path)
	}
	if intendedClone.ConfigDir != confDir {
		return fmt.Errorf("intendedClone.ConfigDir=%q does not match confDir=%q", intendedClone.ConfigDir, confDir)
	}

	// Hash the operator-edited bytes before the integrity scalars are
	// rewritten. HashConfigStableBytes ignores the integrity block, so
	// the stored hash still matches the final file after the integrity
	// patch.
	newConfigHash := HashConfigStableBytes(editedBytes)

	// Cover the conf.d fragments merged on top of this main config. Empty
	// when there are none, leaving conf.d-free configs byte-identical to
	// their prior baseline.
	newConfdHash, err := HashConfDir(confDir)
	if err != nil {
		return fmt.Errorf("hashing conf.d: %w", err)
	}

	patched, err := config.YAMLEdit(editedBytes, []config.YAMLChange{
		{Path: []string{"integrity", "binary_hash"}, Value: binaryHash},
		{Path: []string{"integrity", "config_hash"}, Value: newConfigHash},
		{Path: []string{"integrity", "confd_hash"}, Value: newConfdHash},
	})
	if err != nil {
		return fmt.Errorf("patch integrity scalars: %w", err)
	}
	// Textual guard: the patch may only touch the integrity block.
	if stripIntegrityBlock(string(patched)) != stripIntegrityBlock(string(editedBytes)) {
		return fmt.Errorf("integrity patch drift: bytes outside integrity block changed")
	}

	// Semantic guard: the patched bytes must decode to exactly the config
	// the caller intended, apart from the freshly computed hashes.
	decoded, err := config.LoadBytes(patched)
	if err != nil {
		return fmt.Errorf("verify decode: %w", err)
	}
	decoded.ConfigFile = path
	decoded.ConfigDir = confDir
	expected := *intendedClone
	expected.Integrity.BinaryHash = binaryHash
	expected.Integrity.ConfigHash = newConfigHash
	expected.Integrity.ConfdHash = newConfdHash
	if !reflect.DeepEqual(decoded, &expected) {
		return fmt.Errorf("yaml rewrite drift: decoded config does not match intended clone")
	}

	if err := atomicWriteFile(path, patched, 0o600); err != nil {
		return err
	}
	// Publish the new hashes to the caller only once they are on disk.
	intendedClone.Integrity.BinaryHash = binaryHash
	intendedClone.Integrity.ConfigHash = newConfigHash
	intendedClone.Integrity.ConfdHash = newConfdHash
	return nil
}
// SignConfigFilePreserving signs the config at path in place without
// re-marshaling it. It reads the file, decodes it with config.LoadBytes, and
// hands both to SignAndSavePreserving. Only the main config file is written,
// but the conf.d fragments under confDir are folded into
// integrity.confd_hash so a later edit to any fragment is detected by
// Verify. On success it returns the config hash and conf.d hash that were
// stored.
func SignConfigFilePreserving(path, confDir, binaryHash string) (configHash, confdHash string, err error) {
	// #nosec G304 -- operator-configured config path.
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return "", "", fmt.Errorf("read config: %w", readErr)
	}
	loaded, loadErr := config.LoadBytes(raw)
	if loadErr != nil {
		return "", "", loadErr
	}
	loaded.ConfigFile, loaded.ConfigDir = path, confDir
	if signErr := SignAndSavePreserving(path, confDir, raw, loaded, binaryHash); signErr != nil {
		return "", "", signErr
	}
	return loaded.Integrity.ConfigHash, loaded.Integrity.ConfdHash, nil
}
// stripIntegrityBlock returns s with the top-level `integrity:` mapping
// removed. A line starting with "integrity:" opens the block; every
// following line that is empty or begins with a space or tab belongs to it.
// The first other line closes the block and is kept. SignAndSavePreserving's
// drift guard and the tests share this function so both agree on what counts
// as "outside the integrity block".
func stripIntegrityBlock(s string) string {
	var kept []string
	skipping := false
	for _, ln := range strings.Split(s, "\n") {
		switch {
		case strings.HasPrefix(ln, "integrity:"):
			skipping = true
			continue
		case skipping && (ln == "" || ln[0] == ' ' || ln[0] == '\t'):
			continue
		}
		skipping = false
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n")
}
package integrity
import (
"bufio"
"bytes"
"crypto/sha256"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
"github.com/pidginhost/csm/internal/config"
)
// HashFile streams the file at path through SHA-256 and returns the digest
// formatted as "sha256:<hex>". Open and read errors are returned unchanged.
func HashFile(path string) (string, error) {
	// #nosec G304 -- integrity hashing of operator-configured binary/config paths.
	src, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer func() { _ = src.Close() }()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, src); copyErr != nil {
		return "", copyErr
	}
	sum := digest.Sum(nil)
	return fmt.Sprintf("sha256:%x", sum), nil
}
// HashConfigStable reads the config file at path and returns its stable
// hash: the digest of the file with the integrity section excluded, so
// writing hashes back into the config does not change the hash.
//
// The hashing itself is done by HashConfigStableBytes, which caps line
// length at 1 MiB. A longer line is treated as a corrupted config and the
// digest reflects whatever was scanned up to that point. Such a
// corruption-driven hash will trip integrity.Verify on the next start, which
// is the intended response.
func HashConfigStable(path string) (string, error) {
	// #nosec G304 -- operator-supplied config file path.
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return "", fmt.Errorf("reading config: %w", readErr)
	}
	digest := HashConfigStableBytes(raw)
	return digest, nil
}
// HashConfigStableBytes is the in-memory counterpart to HashConfigStable:
// it computes the stable hash from an already-serialised config body
// without touching disk, so a prospective file content can be hashed before
// it is committed.
//
// The input is read line by line. A line starting with "integrity:" and the
// empty or space/tab-indented lines that follow it are skipped; every other
// line is hashed followed by a single "\n". The result has the same
// "sha256:..." shape as HashConfigStable.
//
// The scanner accepts lines up to 1 MiB. A longer line is treated as
// corruption: scanning stops and the digest reflects only what was scanned
// before it. The resulting mismatch against any stored hash will trip
// integrity.Verify, which is the correct response, so the scanner error is
// deliberately not surfaced.
func HashConfigStableBytes(data []byte) string {
	digest := sha256.New()
	lines := bufio.NewScanner(bytes.NewReader(data))
	lines.Buffer(make([]byte, 0, 64*1024), 1<<20)

	skipping := false
	for lines.Scan() {
		text := lines.Text()
		switch {
		case strings.HasPrefix(text, "integrity:"):
			// Start of the integrity section.
			skipping = true
			continue
		case skipping && (text == "" || text[0] == ' ' || text[0] == '\t'):
			// Still inside the integrity block.
			continue
		}
		skipping = false
		_, _ = digest.Write([]byte(text))
		_, _ = digest.Write([]byte{'\n'})
	}
	return fmt.Sprintf("sha256:%x", digest.Sum(nil))
}
// SignAndSaveAtomic re-computes integrity.config_hash for cfg and writes the
// signed result to cfg.ConfigFile through atomicWriteFile. The on-disk file
// therefore carries either the prior content (operation failed) or the
// fully-signed new content (operation succeeded); there is no window where
// the file exists with an empty or stale config_hash, which an interrupted
// two-save sequence could otherwise leave behind.
//
// integrity.binary_hash is set to the supplied binaryHash. The CALLER is
// responsible for picking the right value: `csm rehash` hashes /opt/csm/csm
// afresh; SIGHUP reload preserves the prior daemon's binary hash because a
// reload cannot upgrade the binary.
//
// Steps: blank ConfigHash, record the conf.d digest, marshal, hash the
// stable form of those bytes, store that hash, marshal again, write. The
// stable hash strips the integrity block, so both marshals round-trip to the
// same hash. cfg is mutated in place as these steps run.
func SignAndSaveAtomic(cfg *config.Config, binaryHash string) error {
	cfg.Integrity.BinaryHash = binaryHash
	cfg.Integrity.ConfigHash = ""

	fragmentsDigest, err := HashConfDir(cfg.ConfigDir)
	if err != nil {
		return fmt.Errorf("hashing conf.d: %w", err)
	}
	cfg.Integrity.ConfdHash = fragmentsDigest

	unsigned, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("marshal (pre-hash): %w", err)
	}
	cfg.Integrity.ConfigHash = HashConfigStableBytes(unsigned)

	signed, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("marshal (post-hash): %w", err)
	}
	return atomicWriteFile(cfg.ConfigFile, signed, 0o600)
}
// atomicWriteFile replaces path with data without exposing a partial file.
// The destination is first resolved by atomicWriteTarget. The data is then
// written to a temp file in the destination's own directory (not /tmp), so
// the final rename stays on one filesystem, where POSIX rename is atomic.
// The temp file is given perm, fsynced and closed before being renamed onto
// the destination, after which the directory is synced via syncDirectory.
//
// Any failure up to and including the rename removes the temp file so the
// directory does not fill with orphans.
func atomicWriteFile(path string, data []byte, perm os.FileMode) error {
	target, err := atomicWriteTarget(path)
	if err != nil {
		return err
	}
	parent := filepath.Dir(target)

	f, err := os.CreateTemp(parent, ".csm-cfg-*.tmp")
	if err != nil {
		return fmt.Errorf("create temp: %w", err)
	}
	tempPath := f.Name()
	// abandon drops the temp file; closeFirst is set while the handle is
	// still open. Both steps are best-effort.
	abandon := func(closeFirst bool) {
		if closeFirst {
			_ = f.Close()
		}
		_ = os.Remove(tempPath)
	}

	if _, werr := f.Write(data); werr != nil {
		abandon(true)
		return fmt.Errorf("write temp: %w", werr)
	}
	if merr := f.Chmod(perm); merr != nil {
		abandon(true)
		return fmt.Errorf("chmod temp: %w", merr)
	}
	if serr := f.Sync(); serr != nil {
		abandon(true)
		return fmt.Errorf("fsync temp: %w", serr)
	}
	if cerr := f.Close(); cerr != nil {
		abandon(false)
		return fmt.Errorf("close temp: %w", cerr)
	}
	if rerr := os.Rename(tempPath, target); rerr != nil {
		abandon(false)
		return fmt.Errorf("rename: %w", rerr)
	}
	return syncDirectory(parent)
}
// atomicWriteTarget decides which path an atomic write should rename onto.
// A path that does not exist, or that is not a symlink, is returned as-is.
// A symlink is resolved with filepath.EvalSymlinks and the resolved path is
// returned, so the rename would land on the link's target rather than on
// the link. Lstat failures other than "not exist" and resolution failures
// are returned wrapped.
func atomicWriteTarget(path string) (string, error) {
	info, err := os.Lstat(path)
	switch {
	case err != nil && os.IsNotExist(err):
		return path, nil
	case err != nil:
		return "", fmt.Errorf("stat config path: %w", err)
	case info.Mode()&os.ModeSymlink == 0:
		return path, nil
	}
	real, evalErr := filepath.EvalSymlinks(path)
	if evalErr != nil {
		return "", fmt.Errorf("resolve config symlink: %w", evalErr)
	}
	return real, nil
}
// syncDirectory opens dir, fsyncs it and closes it. Each step's failure is
// returned wrapped with the name of the step; a failed fsync still closes
// the handle before returning.
func syncDirectory(dir string) error {
	// #nosec G304 -- dir is derived from the caller-owned config path.
	handle, openErr := os.Open(dir)
	if openErr != nil {
		return fmt.Errorf("open dir: %w", openErr)
	}
	if syncErr := handle.Sync(); syncErr != nil {
		_ = handle.Close()
		return fmt.Errorf("fsync dir: %w", syncErr)
	}
	closeErr := handle.Close()
	if closeErr == nil {
		return nil
	}
	return fmt.Errorf("close dir: %w", closeErr)
}
// HashConfDir returns a stable digest over the conf.d fragments reported by
// config.ConfDirFragmentDigestInput for confDir, in the order that function
// yields them. When there are no fragments it returns the empty string, so a
// config without conf.d keeps the empty digest and its pre-existing baseline
// still verifies after upgrade.
//
// Each fragment contributes a header line carrying its name and byte length
// followed by its data. That domain separation prevents two fragment sets
// from colliding by shuffling bytes across the filename boundary.
func HashConfDir(confDir string) (string, error) {
	fragments, err := config.ConfDirFragmentDigestInput(confDir)
	if err != nil {
		return "", err
	}
	if len(fragments) == 0 {
		return "", nil
	}

	digest := sha256.New()
	for _, frag := range fragments {
		fmt.Fprintf(digest, "confd-fragment:%s:%d\n", frag.Name, len(frag.Data))
		_, _ = digest.Write(frag.Data)
	}
	sum := digest.Sum(nil)
	return fmt.Sprintf("sha256:%x", sum), nil
}
// Verify checks the binary and config file integrity against the hashes
// stored in cfg.Integrity.
//
// An empty BinaryHash means no baseline has been taken and nothing is
// checked. Otherwise the binary at binaryPath must hash to BinaryHash. When
// ConfigHash is set, the stable hash of cfg.ConfigFile must match it and the
// conf.d digest of cfg.ConfigDir must match ConfdHash. The first mismatch or
// hashing failure is returned as an error.
func Verify(binaryPath string, cfg *config.Config) error {
	stored := cfg.Integrity
	if stored.BinaryHash == "" {
		return nil // Not yet baselined
	}

	binarySum, err := HashFile(binaryPath)
	if err != nil {
		return fmt.Errorf("hashing binary: %w", err)
	}
	if binarySum != stored.BinaryHash {
		return fmt.Errorf("binary hash mismatch: expected %s, got %s", stored.BinaryHash, binarySum)
	}

	if stored.ConfigHash == "" {
		return nil
	}
	configSum, err := HashConfigStable(cfg.ConfigFile)
	if err != nil {
		return fmt.Errorf("hashing config: %w", err)
	}
	if configSum != stored.ConfigHash {
		return fmt.Errorf("config hash mismatch: expected %s, got %s", stored.ConfigHash, configSum)
	}

	// conf.d fragments are merged on top of the main config at load time,
	// so they must be covered too. A symmetric comparison closes the gap
	// both ways: a tampered or added fragment makes the computed digest
	// diverge from the stored one, and a baseline taken without conf.d
	// stays empty == empty. Operators who already use conf.d must re-run
	// `csm rehash` once after upgrade to populate confd_hash.
	confdSum, err := HashConfDir(cfg.ConfigDir)
	if err != nil {
		return fmt.Errorf("hashing conf.d: %w", err)
	}
	if confdSum != stored.ConfdHash {
		return fmt.Errorf("conf.d hash mismatch: expected %s, got %s", stored.ConfdHash, confdSum)
	}
	return nil
}
// Package log provides a structured-logging wrapper around log/slog.
//
// Rationale: CSM's daemon currently emits timestamped log lines via direct
// fmt.Fprintf(os.Stderr, ...) calls. That works for journalctl but loses
// structure when operators ship logs to Loki, ELK, or Datadog. This
// package provides a drop-in replacement that:
//
// - Emits the legacy "[YYYY-MM-DD HH:MM:SS] msg" format in text mode so
// mixing csmlog calls with legacy fmt.Fprintf calls produces a
// uniform log stream (important during incremental migration)
// - Switches to JSON on CSM_LOG_FORMAT=json for log-shipping pipelines,
// emitting the slog-native level/msg/time/fields structure
// - Honors CSM_LOG_LEVEL={debug|info|warn|error} (default: info)
//
// Usage (preferred for new code):
//
// log.Info("daemon starting", "version", v, "pid", os.Getpid())
// log.Warn("log not found, will retry", "path", path)
// log.Error("alert dispatch failed", "err", err)
//
// Legacy call sites that still use fmt.Fprintf will keep working —
// migration is incremental. See docs/src/development.md for guidance.
package log
import (
"context"
"io"
"log/slog"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
// global is the package-level logger. It starts out nil and is populated by
// Init, called either explicitly or lazily by the first L() call. It is held
// in an atomic.Pointer so Init can swap it without data races.
var global atomic.Pointer[slog.Logger]
// Init builds a logger from the environment and installs it both as this
// package's global logger and as slog's default:
//
//	CSM_LOG_FORMAT = "text" (default) | "json"
//	CSM_LOG_LEVEL  = "debug" | "info" (default) | "warn" | "error"
//
// It may be called repeatedly; each call re-reads the environment and swaps
// in a fresh logger. The installed logger is returned so callers can pass it
// to subsystems that take a *slog.Logger.
func Init() *slog.Logger {
	format := os.Getenv("CSM_LOG_FORMAT")
	threshold := parseLevel(os.Getenv("CSM_LOG_LEVEL"))

	installed := slog.New(buildHandler(format, threshold))
	global.Store(installed)
	slog.SetDefault(installed)
	return installed
}
// L returns the current global logger. When none has been installed yet it
// calls Init and returns the result; otherwise the cost is a single atomic
// load.
func L() *slog.Logger {
	current := global.Load()
	if current == nil {
		current = Init()
	}
	return current
}
// The helpers below mirror slog's level methods on the global logger so call
// sites do not need to write log.L().Info(...) on every line.

// Debug logs msg with args at debug level on the global logger.
func Debug(msg string, args ...any) {
	L().Debug(msg, args...)
}

// Info logs msg with args at info level on the global logger.
func Info(msg string, args ...any) {
	L().Info(msg, args...)
}

// Warn logs msg with args at warn level on the global logger.
func Warn(msg string, args ...any) {
	L().Warn(msg, args...)
}

// Error logs msg with args at error level on the global logger.
func Error(msg string, args ...any) {
	L().Error(msg, args...)
}
// parseLevel converts a level name to a slog.Level. Matching ignores case
// and surrounding whitespace. "debug", "warn"/"warning" and "error"/"err"
// map to their slog levels; anything else, including the empty string,
// yields slog.LevelInfo.
func parseLevel(s string) slog.Level {
	known := map[string]slog.Level{
		"debug":   slog.LevelDebug,
		"warn":    slog.LevelWarn,
		"warning": slog.LevelWarn,
		"error":   slog.LevelError,
		"err":     slog.LevelError,
	}
	name := strings.ToLower(strings.TrimSpace(s))
	if lvl, ok := known[name]; ok {
		return lvl
	}
	return slog.LevelInfo
}
// buildHandler picks the slog.Handler for the requested format, writing to
// os.Stderr at the given minimum level. "json" (ignoring case and
// surrounding whitespace) selects slog's JSON handler; every other value
// selects the legacy text handler.
func buildHandler(format string, level slog.Level) slog.Handler {
	kind := strings.ToLower(strings.TrimSpace(format))
	if kind == "json" {
		return slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: level})
	}
	return newLegacyTextHandler(os.Stderr, level)
}
// legacyTextHandler emits log records in CSM's historical "[timestamp] msg"
// format so callers migrating from fmt.Fprintf produce the same output. Key
// differences from slog.NewTextHandler:
//
// - No "time=... level=... msg=..." prefix; just "[YYYY-MM-DD HH:MM:SS] msg"
// - Structured fields are appended as " key=value" when present
// - Level is prepended as "DEBUG:" / "WARN:" / "ERROR:" only for non-info records
//
// This lets operators mix csmlog calls with the ~180 remaining fmt.Fprintf
// call sites in the daemon without introducing a mixed-format log stream.
type legacyTextHandler struct {
// w is the destination each formatted record is written to.
w io.Writer
// mu serialises writes to w. It is a pointer so handlers derived through
// WithAttrs / WithGroup share the same lock.
mu *sync.Mutex
// level is the minimum level Enabled accepts.
level slog.Level
// attrs are the attributes pre-bound through WithAttrs; Handle writes them
// before the record's own attributes.
attrs []slog.Attr
// group is the name recorded by WithGroup.
group string
}
// newLegacyTextHandler returns a handler that writes to w and accepts
// records at or above level. It starts with a fresh mutex and no bound
// attributes or group.
func newLegacyTextHandler(w io.Writer, level slog.Level) *legacyTextHandler {
	h := &legacyTextHandler{w: w, level: level}
	h.mu = new(sync.Mutex)
	return h
}
// Enabled reports whether a record at the given level should be emitted,
// i.e. whether it is at or above the handler's configured minimum.
func (h *legacyTextHandler) Enabled(_ context.Context, level slog.Level) bool {
	return h.level <= level
}
// Handle formats r as one "[YYYY-MM-DD HH:MM:SS] msg key=value..." line and
// writes it to the handler's writer under the shared mutex. A record with a
// zero time is stamped with time.Now(). Pre-bound attributes are written
// before the record's own. The write error, if any, is returned.
func (h *legacyTextHandler) Handle(_ context.Context, r slog.Record) error {
	when := r.Time
	if when.IsZero() {
		when = time.Now()
	}

	var line strings.Builder
	line.Grow(128)
	line.WriteByte('[')
	line.WriteString(when.Format("2006-01-02 15:04:05"))
	line.WriteString("] ")

	// Only non-info records get a level marker, so the info path exactly
	// matches the legacy "[ts] msg" format.
	marker := ""
	switch r.Level {
	case slog.LevelDebug:
		marker = "DEBUG: "
	case slog.LevelWarn:
		marker = "WARN: "
	case slog.LevelError:
		marker = "ERROR: "
	}
	line.WriteString(marker)
	line.WriteString(r.Message)

	// Pre-bound attrs first, then the record's own, each as " key=value".
	for idx := range h.attrs {
		writeAttr(&line, h.attrs[idx])
	}
	r.Attrs(func(attr slog.Attr) bool {
		writeAttr(&line, attr)
		return true
	})
	line.WriteByte('\n')

	h.mu.Lock()
	defer h.mu.Unlock()
	_, werr := io.WriteString(h.w, line.String())
	return werr
}
// attrValueEscaper escapes the characters that would otherwise break a
// quoted key="value" pair: backslash and double quote (which make the
// closing quote ambiguous) and CR/LF (which would split one record across
// several log lines).
var attrValueEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
	"\n", `\n`,
	"\r", `\r`,
)

// writeAttr appends a as " key=value" to sb. Attributes with an empty key
// are skipped. The value is resolved and rendered with Value.String.
//
// Values containing a space, tab, CR, LF or double quote are wrapped in
// double quotes, with backslash, double quote, CR and LF escaped, so every
// record stays on a single line and the pairs stay parseable when operators
// grep the log. This matters for error values, which often carry multi-line
// command output. All other values are written verbatim.
func writeAttr(sb *strings.Builder, a slog.Attr) {
	if a.Key == "" {
		return
	}
	sb.WriteByte(' ')
	sb.WriteString(a.Key)
	sb.WriteByte('=')

	val := a.Value.Resolve().String()
	if !strings.ContainsAny(val, " \t\r\n\"") {
		sb.WriteString(val)
		return
	}
	sb.WriteByte('"')
	sb.WriteString(attrValueEscaper.Replace(val))
	sb.WriteByte('"')
}
// WithAttrs returns a new handler whose pre-bound attributes are the
// receiver's followed by attrs. The attribute slice is freshly allocated, so
// the receiver is not modified; the writer, mutex, level and group are
// shared with it.
func (h *legacyTextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	combined := make([]slog.Attr, 0, len(h.attrs)+len(attrs))
	combined = append(append(combined, h.attrs...), attrs...)

	derived := *h
	derived.attrs = combined
	return &derived
}
// WithGroup returns a handler that records name as its group. Following the
// slog.Handler contract, an empty name returns the receiver unchanged.
// Nested groups are joined with "." instead of the inner name replacing the
// outer one.
//
// NOTE: Handle does not currently read the recorded group, so attribute keys
// are emitted unprefixed. This is sufficient for CSM's usage (groups are not
// used today). The qualified name is kept so a future Handle can apply it.
func (h *legacyTextHandler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}
	qualified := name
	if h.group != "" {
		qualified = h.group + "." + name
	}
	derived := *h
	derived.group = qualified
	return &derived
}
// Compile-time assertion that *legacyTextHandler implements slog.Handler: the
// build fails here if a required method is missing or has the wrong
// signature.
var _ slog.Handler = (*legacyTextHandler)(nil)
// Package adapter renders and applies the MTA-native forward-guard rule. On
// cPanel/exim it writes a router + transport into the cPanel-preserved
// /etc/exim.conf.local include sections and regenerates exim.conf via
// buildeximconf, so the rule survives cPanel exim rebuilds. CSM is never in the
// live mail path: exim evaluates the rule and writes held copies to the
// CSM-owned quarantine Maildir itself.
//
// The exact router/transport here was validated on a real cPanel exim 4.99
// host (null-sender forward held while the local copy delivers; normal mail
// forwarded unchanged; Remove restores normal forwarding).
package adapter
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/mailfwd/policy"
)
// Status reports whether the guard rule is currently installed.
type Status struct {
// Installed is true when the managed forward-guard blocks are present in
// the MTA's local configuration.
Installed bool `json:"installed"`
}
// ForwardGuard renders and (un)installs the MTA forward-guard rule.
type ForwardGuard interface {
// Apply installs or refreshes the rule for cfg, using badIPs as the
// bad-sender lookup list.
Apply(cfg policy.Config, badIPs []string) error
// Remove uninstalls the rule.
Remove() error
// Status reports whether the rule is currently installed.
Status() (Status, error)
// RefreshBadIPs replaces the bad-sender lookup list.
RefreshBadIPs(badIPs []string) error
}
// Default on-disk locations (overridable in tests).
const (
// defaultLocalConf is the cPanel-preserved exim include file that receives
// the managed router and transport blocks.
defaultLocalConf = "/etc/exim.conf.local"
// defaultBadIPsPath is the lookup file of bad sender IPs referenced by the
// router condition.
defaultBadIPsPath = "/var/lib/csm/forward_guard/bad_ips"
// defaultQuarantineDir is the directory the hold transport delivers held
// copies into.
defaultQuarantineDir = "/var/lib/csm/forward_quarantine/held"
// transportUser delivers held copies. It must NOT be root: cPanel lists
// root on exim's never_users, so an appendfile as root fails. mailnull is
// exim's own non-root identity and exists on every cPanel host.
transportUser = "mailnull"
)
// Managed-block sentinels. Apply replaces whatever is between them, so a
// re-apply is idempotent and Remove can strip the block cleanly. Status
// detects an install by looking for the two BEGIN markers.
const (
// routerBegin / routerEnd delimit the managed router block.
routerBegin = "# CSM-FORWARD-GUARD ROUTER BEGIN (managed by csm; do not edit)"
routerEnd = "# CSM-FORWARD-GUARD ROUTER END"
// transportBegin / transportEnd delimit the managed transport block.
transportBegin = "# CSM-FORWARD-GUARD TRANSPORT BEGIN (managed by csm; do not edit)"
transportEnd = "# CSM-FORWARD-GUARD TRANSPORT END"
)
// eximLocalSkeleton is the full set of cPanel exim.conf.local section markers,
// one per line. Apply uses it as the starting content when the host has no
// exim.conf.local yet or the existing file is blank. cPanel's buildeximconf
// injects whatever follows each marker into the generated exim.conf.
const eximLocalSkeleton = `@AUTH@
@BEGINACL@
@CONFIG@
@DIRECTOREND@
@DIRECTORMIDDLE@
@DIRECTORSTART@
@ENDACL@
@RETRYEND@
@RETRYSTART@
@REWRITE@
@ROUTEREND@
@ROUTERSTART@
@TRANSPORTEND@
@TRANSPORTMIDDLE@
@TRANSPORTSTART@
`
// EximAdapter is the cPanel/exim ForwardGuard.
type EximAdapter struct {
// localConf is the path of the exim.conf.local file that is edited.
localConf string
// badIPsPath is the path of the bad-IP lookup file.
badIPsPath string
// quarantineDir is the directory held copies are delivered into.
quarantineDir string
// Injected side effects (real implementations on a live host; fakes in tests).
rebuild func() error // runs buildeximconf
chown func(path, user string) error // chowns the quarantine dir
mkdirAll func(path string, perm os.FileMode) error
}
// NewEximAdapter returns an adapter wired to the standard cPanel locations
// and to the real side-effect implementations: runBuildEximConf for the
// rebuild, chownToUser for ownership changes and os.MkdirAll for directory
// creation.
func NewEximAdapter() *EximAdapter {
	a := new(EximAdapter)
	a.localConf = defaultLocalConf
	a.badIPsPath = defaultBadIPsPath
	a.quarantineDir = defaultQuarantineDir
	a.rebuild = runBuildEximConf
	a.chown = chownToUser
	a.mkdirAll = os.MkdirAll
	return a
}
// Apply installs (or refreshes) the forward-guard rule for an enabled,
// non-dry-run policy; a disabled or dry-run cfg is rejected up front.
//
// Nothing touches disk until the router and transport blocks have been
// rendered and injected into the current exim.conf.local, or into the
// section skeleton when the file is missing or blank. Then, in order: the
// quarantine directory is created and chowned to the transport user, the
// bad-IP lookup file is written, the new exim.conf.local is written, and
// exim's config is rebuilt. If the rebuild fails, the previous
// exim.conf.local state is restored via a.restore so a failed apply does not
// leave a half-installed rule.
func (a *EximAdapter) Apply(cfg policy.Config, badIPs []string) error {
	switch {
	case !cfg.Enabled:
		return fmt.Errorf("forward-guard adapter: cannot apply disabled policy")
	case cfg.DryRun:
		return fmt.Errorf("forward-guard adapter: cannot apply dry-run policy")
	}

	routerBlock, err := a.renderRouter(cfg.HoldSignals)
	if err != nil {
		return err
	}
	original, existed, err := a.readLocalConf()
	if err != nil {
		return err
	}

	seed := original
	if !existed || strings.TrimSpace(original) == "" {
		seed = eximLocalSkeleton
	}
	withRouter, err := injectBlock(seed, "@ROUTERSTART@", routerBlock)
	if err != nil {
		return err
	}
	updated, err := injectBlock(withRouter, "@TRANSPORTSTART@", a.renderTransport())
	if err != nil {
		return err
	}

	// Quarantine dir must exist and be writable by the transport user before
	// exim can deliver into it.
	if mkErr := a.mkdirAll(a.quarantineDir, 0700); mkErr != nil {
		return fmt.Errorf("creating quarantine dir: %w", mkErr)
	}
	if ownErr := a.chown(a.quarantineDir, transportUser); ownErr != nil {
		return fmt.Errorf("chowning quarantine dir to %s: %w", transportUser, ownErr)
	}
	if ipErr := a.writeBadIPs(badIPs); ipErr != nil {
		return ipErr
	}
	if confErr := writeFileAtomic(a.localConf, []byte(updated)); confErr != nil {
		return confErr
	}

	rebuildErr := a.rebuild()
	if rebuildErr == nil {
		return nil
	}
	// Roll back to the prior config so mail keeps flowing as before.
	if restoreErr := a.restore(original, existed); restoreErr != nil {
		return fmt.Errorf("buildeximconf failed: %w; rollback failed: %v", rebuildErr, restoreErr)
	}
	return fmt.Errorf("buildeximconf failed, rolled back: %w", rebuildErr)
}
// Remove strips the managed router and transport blocks from
// exim.conf.local and rebuilds exim's config, restoring normal forwarding.
// It is a no-op when the file is missing or contains neither block. If the
// rebuild fails, the original file content is restored via a.restore.
func (a *EximAdapter) Remove() error {
	current, exists, err := a.readLocalConf()
	if err != nil {
		return err
	}
	if !exists {
		return nil // nothing installed
	}

	cleaned := stripBlock(current, routerBegin, routerEnd)
	cleaned = stripBlock(cleaned, transportBegin, transportEnd)
	if cleaned == current {
		return nil // not installed; no rebuild needed
	}
	if writeErr := writeFileAtomic(a.localConf, []byte(cleaned)); writeErr != nil {
		return writeErr
	}

	rebuildErr := a.rebuild()
	if rebuildErr == nil {
		return nil
	}
	if restoreErr := a.restore(current, true); restoreErr != nil {
		return fmt.Errorf("buildeximconf failed during remove: %w; rollback failed: %v", rebuildErr, restoreErr)
	}
	return fmt.Errorf("buildeximconf failed during remove, rolled back: %w", rebuildErr)
}
// Status reports whether the guard is installed, judged by the presence of
// the router and transport BEGIN markers in exim.conf.local. A missing file
// means not installed. Finding exactly one of the two markers is reported as
// an error describing the partial install.
func (a *EximAdapter) Status() (Status, error) {
	content, exists, err := a.readLocalConf()
	if err != nil {
		return Status{}, err
	}
	if !exists {
		return Status{}, nil
	}

	hasRouter := strings.Contains(content, routerBegin)
	hasTransport := strings.Contains(content, transportBegin)
	if hasRouter == hasTransport {
		return Status{Installed: hasRouter}, nil
	}
	return Status{}, fmt.Errorf("forward-guard adapter: partial install in exim.conf.local (router=%t transport=%t)", hasRouter, hasTransport)
}
// renderRouter builds the managed router block, wrapped in its BEGIN/END
// sentinels. The router's condition requires $parent_local_part to be
// defined and at least one enabled hold signal to match: an empty sender
// address (bounce backscatter) and/or the sender host address being listed
// in the bad-IP lookup file. With no signal enabled it returns an error
// rather than a router that would hold everything or nothing.
func (a *EximAdapter) renderRouter(sig policy.HoldSignals) (string, error) {
	var tests []string
	if sig.BounceBackscatter {
		tests = append(tests, "{eq{$sender_address}{}}")
	}
	if sig.BadSenderIP {
		lookup := fmt.Sprintf("{eq{${lookup{$sender_host_address}lsearch{%s}{1}{0}}}{1}}", a.badIPsPath)
		tests = append(tests, lookup)
	}
	if len(tests) == 0 {
		// Config validation forbids enforce mode with neither signal; guard here
		// so the adapter never installs a router that holds everything or nothing.
		return "", fmt.Errorf("forward-guard adapter: no routing-time-enforceable signal enabled (need bounce_backscatter or bad_sender_ip)")
	}

	condition := fmt.Sprintf("${if and{ {def:parent_local_part} {or{ %s } } }{yes}{no}}", strings.Join(tests, " "))
	block := []string{
		routerBegin,
		"csm_forward_guard:",
		" driver = accept",
		" domains = ! +local_domains",
		" condition = " + condition,
		" transport = csm_forward_hold",
		routerEnd,
	}
	return strings.Join(block, "\n"), nil
}
// renderTransport builds the managed csm_forward_hold transport stanza,
// wrapped in the transportBegin/transportEnd sentinels. The stanza is an
// appendfile transport writing maildir-format files under a.quarantineDir as
// transportUser, and it adds four X-CSM-* headers to each held message.
func (a *EximAdapter) renderTransport() string {
	// The header lines are joined with the two-character sequence backslash-n,
	// not a real newline, so the whole headers_add value stays on one line.
	const headerSep = `\n`
	headerLines := []string{
		"X-CSM-Forwarder: $parent_local_part@$parent_domain",
		"X-CSM-Recipient: $local_part@$domain",
		"X-CSM-Sender: $sender_address",
		// An empty sender is labelled bounce_backscatter; anything else
		// is labelled bad_sender_ip.
		"X-CSM-Reasons: ${if eq{$sender_address}{}{bounce_backscatter}{bad_sender_ip}}",
	}
	headersAdd := ` headers_add = "` + strings.Join(headerLines, headerSep) + `"`

	stanza := []string{
		transportBegin,
		"csm_forward_hold:",
		" driver = appendfile",
		" directory = " + a.quarantineDir,
		" maildir_format",
		" create_directory",
		" directory_mode = 0700",
		" mode = 0600",
		" user = " + transportUser,
		headersAdd,
		transportEnd,
	}
	return strings.Join(stanza, "\n")
}
// RefreshBadIPs rewrites only the bad-IP lookup file. exim reads the lsearch
// file at lookup time, so the change takes effect immediately with no rebuild
// or reload -- cheap to call on a schedule as the attack DB changes.
//
// It delegates to writeBadIPs and returns that function's error unchanged.
func (a *EximAdapter) RefreshBadIPs(ips []string) error {
return a.writeBadIPs(ips)
}
// writeBadIPs (re)creates the lookup file at a.badIPsPath with one
// "<ip>: 1" line per usable entry, creating the parent directory first.
//
// Entries are trimmed; empty entries and entries containing whitespace or a
// colon are skipped, which also drops anything written with colons such as
// IPv6 literals. The file is replaced through writeFileAtomic.
func (a *EximAdapter) writeBadIPs(ips []string) error {
	dir := filepath.Dir(a.badIPsPath)
	if err := a.mkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("creating bad IP lookup dir: %w", err)
	}

	var out bytes.Buffer
	for _, raw := range ips {
		key := strings.TrimSpace(raw)
		if key == "" {
			continue
		}
		if strings.ContainsAny(key, " \t\r\n:") {
			continue // lsearch keys are one bare token per line
		}
		out.WriteString(key)
		out.WriteString(": 1\n")
	}
	return writeFileAtomic(a.badIPsPath, out.Bytes())
}
// readLocalConf returns the content of a.localConf and whether the file
// exists. A missing file is not an error: it yields ("", false, nil). Any
// other read failure is wrapped with the path.
func (a *EximAdapter) readLocalConf() (string, bool, error) {
	raw, err := os.ReadFile(a.localConf) // #nosec G304 -- operator-fixed exim.conf.local path
	switch {
	case err == nil:
		return string(raw), true, nil
	case errors.Is(err, os.ErrNotExist):
		return "", false, nil
	default:
		return "", false, fmt.Errorf("reading %s: %w", a.localConf, err)
	}
}
// restore puts a.localConf back to its earlier state and rebuilds the exim
// configuration. When had is true the file is rewritten with prev; otherwise
// the file is removed (already being absent is fine). Each failing step is
// returned wrapped with a description of that step.
func (a *EximAdapter) restore(prev string, had bool) error {
	if !had {
		// There was no file before, so the one written since must go.
		rmErr := os.Remove(a.localConf)
		if rmErr != nil && !errors.Is(rmErr, os.ErrNotExist) {
			return fmt.Errorf("removing new exim.conf.local: %w", rmErr)
		}
	} else if wrErr := writeFileAtomic(a.localConf, []byte(prev)); wrErr != nil {
		return fmt.Errorf("restoring exim.conf.local: %w", wrErr)
	}

	if rebuildErr := a.rebuild(); rebuildErr != nil {
		return fmt.Errorf("rebuilding restored exim config: %w", rebuildErr)
	}
	return nil
}
// injectBlock returns conf with block inserted directly after the first
// occurrence of marker followed by a newline. Any existing copy of the same
// kind of managed block is stripped first, so re-applying the same block does
// not duplicate it. An error is returned when the marker line is absent.
func injectBlock(conf, marker, block string) (string, error) {
	// Strip a prior copy of this block so re-apply doesn't duplicate it.
	beginTag, endTag := blockSentinels(block)
	cleaned := stripBlock(conf, beginTag, endTag)

	anchor := marker + "\n"
	pos := strings.Index(cleaned, anchor)
	if pos < 0 {
		return "", fmt.Errorf("exim.conf.local missing %s marker", marker)
	}

	insertAt := pos + len(anchor)
	var b strings.Builder
	b.Grow(len(cleaned) + len(block) + 1)
	b.WriteString(cleaned[:insertAt])
	b.WriteString(block)
	b.WriteByte('\n')
	b.WriteString(cleaned[insertAt:])
	return b.String(), nil
}
// blockSentinelRe recognises a managed block's opening sentinel at the start
// of a line: "# CSM-FORWARD-GUARD <KIND> BEGIN".
var blockSentinelRe = regexp.MustCompile(`^(# CSM-FORWARD-GUARD \w+ BEGIN)`)

// blockSentinels returns the first line of block as begin and the matching
// "# CSM-FORWARD-GUARD <KIND> END" sentinel as end. When the first line does
// not start with a recognised BEGIN sentinel, end is the empty string.
func blockSentinels(block string) (begin, end string) {
	begin = block
	if nl := strings.IndexByte(block, '\n'); nl >= 0 {
		begin = block[:nl]
	}

	match := blockSentinelRe.FindStringSubmatch(begin)
	if match == nil {
		return begin, ""
	}
	// The matched text is "# CSM-FORWARD-GUARD <KIND> BEGIN"; the kind is its
	// third whitespace-separated word (ROUTER or TRANSPORT).
	words := strings.Fields(match[1])
	return begin, "# CSM-FORWARD-GUARD " + words[2] + " END"
}
// stripBlock removes every inclusive begin..end region from conf, together
// with the single newline that directly follows each region, and returns the
// result.
//
// conf is returned unchanged from the point where no further begin is found,
// or where a begin has no end after it. An empty begin or end identifies no
// block, so conf is returned untouched in that case; without this guard an
// empty end (which blockSentinels yields for an unrecognised block) would
// match at every begin, remove nothing, and loop forever.
func stripBlock(conf, begin, end string) string {
	if begin == "" || end == "" {
		return conf
	}
	for {
		bi := strings.Index(conf, begin)
		if bi < 0 {
			return conf
		}
		// Search for end from the begin sentinel onwards so an END that
		// appears earlier in the file cannot close this block.
		ei := strings.Index(conf[bi:], end)
		if ei < 0 {
			return conf
		}
		stop := bi + ei + len(end)
		if stop < len(conf) && conf[stop] == '\n' {
			stop++
		}
		conf = conf[:bi] + conf[stop:]
	}
}
// writeFileAtomic replaces path with data without ever exposing a partially
// written file: the bytes go to a temporary file in the same directory
// (mode 0644), are flushed to stable storage, and the temp file is then
// renamed over path.
//
// The flush before the rename matters for crash safety: without it the rename
// can reach disk before the data does, leaving an empty or truncated file
// under the final name after a power loss. On any failure the temporary file
// is removed and the error is returned wrapped with the step that failed.
func writeFileAtomic(path string, data []byte) error {
	dir := filepath.Dir(path)
	f, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.csmtmp") // #nosec G304 -- caller owns the destination path
	if err != nil {
		return fmt.Errorf("opening temp file for %s: %w", path, err)
	}
	tmp := f.Name()

	// Until the rename succeeds the temp file is garbage. Closing again after
	// a successful Close only returns an error, which is ignored here.
	committed := false
	defer func() {
		if !committed {
			_ = f.Close()
			_ = os.Remove(tmp)
		}
	}()

	if err := f.Chmod(0644); err != nil {
		return fmt.Errorf("chmod temp file for %s: %w", path, err)
	}
	n, err := f.Write(data)
	if err != nil {
		return fmt.Errorf("writing %s: %w", path, err)
	}
	if n != len(data) {
		return fmt.Errorf("writing %s: %w", path, io.ErrShortWrite)
	}
	if err := f.Sync(); err != nil {
		return fmt.Errorf("syncing temp file for %s: %w", path, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("closing temp file for %s: %w", path, err)
	}
	if err := os.Rename(tmp, path); err != nil {
		return fmt.Errorf("committing %s: %w", path, err)
	}
	committed = true
	return nil
}
// runBuildEximConf runs /scripts/buildeximconf with no arguments and returns
// the command's error. The run is bounded by a 120-second timeout.
func runBuildEximConf() error {
	const buildTimeout = 120 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), buildTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, "/scripts/buildeximconf")
	return cmd.Run()
}
// chownToUser runs "chown -R user:user path", giving path and everything
// below it to the named user and the group of the same name. The run is
// bounded by a 30-second timeout.
func chownToUser(path, user string) error {
	const chownTimeout = 30 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), chownTimeout)
	defer cancel()

	owner := user + ":" + user
	// #nosec G204 -- user is the constant transportUser ("mailnull") and path is
	// the operator-fixed quarantine dir; neither is attacker-controlled.
	cmd := exec.CommandContext(ctx, "chown", "-R", owner, path)
	return cmd.Run()
}
// Package guard glues the operator config to the pure forward-guard policy.
// It exists so internal/mailfwd/policy stays dependency-free of internal/config
// (policy is a low-level leaf; only this glue knows about both).
package guard
import (
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/mailfwd/policy"
)
// PolicyFromConfig copies the fields of the operator's forward-guard config
// that the policy needs: the Enabled and DryRun switches and the five hold
// signals. Skip-list and retention are consumed by the adapter and
// quarantine, not by Verdict, so they are not mapped.
func PolicyFromConfig(fg config.ForwardGuardConfig) policy.Config {
	src := fg.HoldSignals
	signals := policy.HoldSignals{
		BounceBackscatter: src.BounceBackscatter,
		SpamFlagged:       src.SpamFlagged,
		Malware:           src.Malware,
		BadSenderIP:       src.BadSenderIP,
		AuthFail:          src.AuthFail,
	}

	var out policy.Config
	out.Enabled = fg.Enabled
	out.DryRun = fg.DryRun
	out.HoldSignals = signals
	return out
}
package guard
import (
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/mailfwd/adapter"
)
// Reconciler drives the exim forward-guard from operator config. The daemon
// calls Reconcile on startup and on every config reload, and RefreshBadIPs on a
// schedule. It is the only thing that decides apply-vs-remove, so the live mail
// path and the dry-run path can never both be active.
//
// The zero value is inert: with Active false or Guard nil, both methods return
// nil without doing anything.
type Reconciler struct {
// Guard is the MTA adapter (nil on platforms without one).
Guard adapter.ForwardGuard
// Active gates the whole reconciler; the daemon sets it true only on a
// cPanel/exim host. When false, Reconcile/RefreshBadIPs are no-ops.
Active bool
// BadIPs supplies the current bad-sender-IP set (from the reputation DB).
// A nil BadIPs is treated as an empty set.
BadIPs func() []string
}
// Reconcile brings the MTA in line with fg: the guard is applied when fg is
// enabled and not in dry-run, and removed in every other case. Dry-run never
// installs an MTA rule -- its accounting is CSM-side only. It is a no-op when
// the reconciler is inactive or has no adapter.
func (r Reconciler) Reconcile(fg config.ForwardGuardConfig) error {
	if !r.Active || r.Guard == nil {
		return nil
	}

	enforcing := fg.Enabled && !fg.DryRun
	if !enforcing {
		return r.Guard.Remove()
	}
	return r.Guard.Apply(PolicyFromConfig(fg), r.badIPs())
}
// RefreshBadIPs pushes the current bad-IP set to the adapter, but only while
// the guard is enforcing (enabled and not dry-run). In every other state, and
// when the reconciler is inactive or has no adapter, it returns nil without
// doing anything, so it is safe to call on a timer regardless of config.
func (r Reconciler) RefreshBadIPs(fg config.ForwardGuardConfig) error {
	if !r.Active || r.Guard == nil {
		return nil
	}
	if !fg.Enabled || fg.DryRun {
		return nil
	}
	return r.Guard.RefreshBadIPs(r.badIPs())
}
// badIPs returns the set produced by the BadIPs supplier, or nil when no
// supplier is configured.
func (r Reconciler) badIPs() []string {
	if supplier := r.BadIPs; supplier != nil {
		return supplier()
	}
	return nil
}
package intel
import (
"context"
"os/exec"
"strings"
"time"
)
// FlushResult reports how many messages a flush removed.
type FlushResult struct {
// Removed is the number of messages removed by the flush.
Removed int `json:"removed"`
}
// QueueFlusher removes safe-to-delete backscatter from the mail queue.
type QueueFlusher interface {
// FlushBackscatter performs one flush and reports how many messages it
// removed, or the error that stopped it.
FlushBackscatter() (FlushResult, error)
}
// FrozenBackscatterIDs scans `exim -bp` output line by line and returns, in
// input order, the message ID of every queue header that is BOTH frozen AND
// null-sender (<>). This is the only set the flush touches: a frozen
// null-sender message is undeliverable bounce backscatter, so removing it
// cannot lose a real sender's mail or interrupt a live retry. The result is
// nil when nothing qualifies.
func FrozenBackscatterIDs(out string) []string {
	var matches []string
	for _, line := range strings.Split(out, "\n") {
		id, _, _, isBounce, isFrozen, isHeader := parseQueueHeader(line)
		if !isHeader || !isBounce || !isFrozen {
			continue
		}
		matches = append(matches, id)
	}
	return matches
}
// eximRemoveBatch bounds how many message IDs are passed to one `exim -Mrm`
// invocation so a huge queue cannot overflow the command line. Larger sets are
// removed in consecutive batches of at most this size.
const eximRemoveBatch = 100
// EximQueueFlusher lists the queue, selects frozen null-sender messages, and
// removes them with `exim -Mrm`.
type EximQueueFlusher struct {
// list returns the raw queue listing that FlushBackscatter parses.
list func() ([]byte, error)
// remove deletes the given message IDs from the queue.
remove func(ids []string) error
}
// NewEximQueueFlusher returns a flusher backed by the live exim binary: it
// lists with runEximBp and removes with runEximRemove.
func NewEximQueueFlusher() *EximQueueFlusher {
	flusher := new(EximQueueFlusher)
	flusher.list = runEximBp
	flusher.remove = runEximRemove
	return flusher
}
// FlushBackscatter lists the queue, picks every frozen null-sender message,
// and removes them all. A failure to list or remove is returned with a zero
// result: the operator triggered this action explicitly and must know if it
// did not fully apply. When nothing qualifies, nothing is removed and the
// result is zero with a nil error.
func (f *EximQueueFlusher) FlushBackscatter() (FlushResult, error) {
	listing, err := f.list()
	if err != nil {
		return FlushResult{}, err
	}

	targets := FrozenBackscatterIDs(string(listing))
	if len(targets) == 0 {
		return FlushResult{}, nil
	}

	if err = f.remove(targets); err != nil {
		return FlushResult{}, err
	}
	return FlushResult{Removed: len(targets)}, nil
}
// runEximRemove deletes the given messages with `exim -Mrm`, passing at most
// eximRemoveBatch IDs per invocation. Each invocation has its own 30-second
// timeout. The first failing invocation stops the loop and its error is
// returned; batches already run are not undone.
func runEximRemove(ids []string) error {
	for pending := ids; len(pending) > 0; {
		size := eximRemoveBatch
		if size > len(pending) {
			size = len(pending)
		}
		batch := pending[:size]
		pending = pending[size:]

		args := make([]string, 0, len(batch)+1)
		args = append(args, "-Mrm")
		args = append(args, batch...)

		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		// #nosec G204 -- ids are exim message IDs parsed by parseQueueHeader's
		// fixed legacy/new-format regex; no attacker-controlled text reaches argv.
		runErr := exec.CommandContext(ctx, "exim", args...).Run()
		cancel()
		if runErr != nil {
			return runErr
		}
	}
	return nil
}
// Package intel turns exim_mainlog deferral lines into operator-facing
// reputation signals: which outbound IPs are being throttled, by which mail
// providers, and for what stated reason. It answers "why is the queue backing
// up" without CSM sitting in the mail path -- it only reads the log the MTA
// already writes.
//
// The parser is deliberate about attacker-controlled content: a deferral line
// echoes a remote server's free-text error, so every parsed field is bounded
// and the line is rejected unless it has the exact exim deferral shape.
package intel
import (
"net"
"regexp"
"sort"
"strings"
"time"
"unicode/utf8"
"github.com/pidginhost/csm/internal/mailfwd/inventory"
)
// eximTimeLayout is exim_mainlog's default timestamp ("2026-06-07 10:15:23").
const eximTimeLayout = "2006-01-02 15:04:05"
// Length bounds, in bytes, for fields parsed out of a log line. The remote
// error text in particular is attacker-influenced, so nothing is stored
// unbounded.
const (
// maxAddressLen is the longest recipient address accepted.
maxAddressLen = 254
// maxDomainLen is the longest recipient domain accepted.
maxDomainLen = 253
// maxHostLen is a bound for host names.
// NOTE(review): not referenced in the visible code; hostBoundaryRe repeats
// the same limit as a literal -- confirm whether they should be tied.
maxHostLen = 253
// maxIPLen is the longest IP literal parseIPLiteral accepts.
maxIPLen = 45
// maxTextLen bounds the stored remote error text so a hostile MTA cannot
// bloat the report with a multi-kilobyte error string.
maxTextLen = 240
)
// Deferral is one parsed exim "==" deferral event.
type Deferral struct {
Time time.Time // log timestamp parsed in the local zone; zero if it did not parse
Recipient string // deferred recipient address with surrounding <> removed
Domain string // lower-cased part of Recipient after the last '@'
Provider inventory.ProviderClass // result of inventory.ClassifyAddress(Recipient)
RemoteHost string // the deferring MX (from H=)
RemoteIP string // the deferring MX address, "" if it is not a valid IP literal
OutboundIP string // this server's sending IP, as echoed in the error, "" if absent
SMTPCode string // "421", "" if none
ReasonCode string // bracketed code (TSS04) or a keyword label (spamhaus, rate_limit)
Text string // bounded remote error text
}
// Patterns used to pick fields out of a deferral line.
var (
// hostBoundaryRe matches "H=<host> [<ip>]:" plus any whitespace after the
// colon, capturing the host (1) and the bracketed address (2).
hostBoundaryRe = regexp.MustCompile(`\bH=(\S{1,253})\s+\[([0-9a-fA-F:.]{1,45})\]:\s*`)
// smtpCodeRe matches a standalone three-digit number starting with 4 or 5.
smtpCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
// reasonRe matches a bracketed token of 2-8 alphanumerics starting with a
// letter, e.g. "[TSS04]".
reasonRe = regexp.MustCompile(`\[([A-Za-z][A-Za-z0-9]{1,7})\]`)
// ipv4Re matches dotted-quad candidates; octet ranges are not checked here.
ipv4Re = regexp.MustCompile(`\b\d{1,3}(?:\.\d{1,3}){3}\b`)
// whitespaceR matches a run of whitespace.
whitespaceR = regexp.MustCompile(`\s+`)
)
// parseDeferralLine parses one exim_mainlog line. ok is false for anything
// that is not a deferral (delivery "=>", arrival "<=", failure "**", blanks,
// junk), and for deferrals whose recipient or domain is missing or over the
// length bounds.
//
// A timestamp that does not parse leaves Time zero rather than rejecting the
// line. When the line has an "H=<host> [<ip>]:" section, the error text is
// everything after it and the outbound IP is looked for in that text;
// otherwise the text comes from fallbackDeferralText and OutboundIP stays
// empty.
func parseDeferralLine(line string) (Deferral, bool) {
	tokens := strings.Fields(line)
	// <date> <time> <msgid> == <recipient> ...
	if len(tokens) < 5 {
		return Deferral{}, false
	}
	if tokens[3] != "==" {
		return Deferral{}, false
	}

	rcpt := strings.Trim(tokens[4], "<>")
	switch {
	case rcpt == "", len(rcpt) > maxAddressLen, !strings.Contains(rcpt, "@"):
		return Deferral{}, false
	}

	var d Deferral
	d.Recipient = rcpt
	d.Provider = inventory.ClassifyAddress(rcpt)

	if at := strings.LastIndexByte(rcpt, '@'); at >= 0 && at+1 < len(rcpt) {
		d.Domain = strings.ToLower(rcpt[at+1:])
	}
	if d.Domain == "" || len(d.Domain) > maxDomainLen {
		return Deferral{}, false
	}

	stamp := tokens[0] + " " + tokens[1]
	if parsed, err := time.ParseInLocation(eximTimeLayout, stamp, time.Local); err == nil {
		d.Time = parsed
	}

	var errText string
	loc := hostBoundaryRe.FindStringSubmatchIndex(line)
	if loc != nil {
		d.RemoteHost = line[loc[2]:loc[3]]
		d.RemoteIP = parseIPLiteral(line[loc[4]:loc[5]])
		errText = line[loc[1]:]
	} else {
		errText = fallbackDeferralText(line)
	}

	d.SMTPCode = firstSMTPCode(errText)
	if loc != nil {
		// Only text that follows a host section is searched for our own IP.
		d.OutboundIP = firstIPv4(errText)
	}
	d.ReasonCode = classifyReason(errText)
	d.Text = boundText(errText)
	return d, true
}
// fallbackDeferralText returns the text that follows the first "):" after
// " defer (" in line, trimmed of surrounding whitespace. When line has no
// " defer (" or no "):" after it, the whole line is returned unchanged.
func fallbackDeferralText(line string) string {
	const opener = " defer ("
	start := strings.Index(line, opener)
	if start < 0 {
		return line
	}
	rest := line[start:]
	closeAt := strings.Index(rest, "):")
	if closeAt < 0 {
		return line
	}
	return strings.TrimSpace(rest[closeAt+2:])
}
// parseIPLiteral returns s unchanged when it is at most maxIPLen bytes long
// and net.ParseIP accepts it; otherwise it returns "".
func parseIPLiteral(s string) string {
	if len(s) > maxIPLen || net.ParseIP(s) == nil {
		return ""
	}
	return s
}
// firstSMTPCode returns the first 4xx/5xx number in s that is not a piece of
// an address (see smtpCodeIsAddressFragment), or "" when there is none.
func firstSMTPCode(s string) string {
	for _, loc := range smtpCodeRe.FindAllStringSubmatchIndex(s, -1) {
		from, to := loc[2], loc[3]
		if !smtpCodeIsAddressFragment(s, from, to) {
			return s[from:to]
		}
	}
	return ""
}
// smtpCodeIsAddressFragment reports whether s[start:end] is directly preceded
// or followed by '.' or ':', i.e. looks like one component of a dotted or
// colon-separated address rather than a standalone code.
func smtpCodeIsAddressFragment(s string, start, end int) bool {
	isSeparator := func(c byte) bool { return c == '.' || c == ':' }

	if start > 0 && isSeparator(s[start-1]) {
		return true
	}
	return end < len(s) && isSeparator(s[end])
}
// firstIPv4 returns the first dotted-quad in s that net.ParseIP accepts as an
// IPv4 address, or "" when there is none.
func firstIPv4(s string) string {
	for _, candidate := range ipv4Re.FindAllString(s, -1) {
		parsed := net.ParseIP(candidate)
		if parsed == nil || parsed.To4() == nil {
			continue
		}
		return candidate
	}
	return ""
}
// classifyReason resolves a stable reason token for errText. The first
// bracketed token accepted by validReasonCode (e.g. TSS04) wins. Failing that,
// the text is lower-cased and checked against keyword groups in a fixed
// order -- spamhaus, rate_limit, complaint, greylist, blocked -- and the
// first group with a matching keyword supplies the label. Returns "" when
// nothing recognizable is found.
func classifyReason(errText string) string {
	for _, match := range reasonRe.FindAllStringSubmatch(errText, -1) {
		if code := match[1]; validReasonCode(code) {
			return code
		}
	}

	// Order matters: earlier groups take precedence over later ones.
	keywordLabels := []struct {
		label    string
		keywords []string
	}{
		{"spamhaus", []string{"spamhaus"}},
		{"rate_limit", []string{"unusual rate", "rate limit", "too many", "unexpected volume"}},
		{"complaint", []string{"complaint"}},
		{"greylist", []string{"greylist", "grey-list", "try again later"}},
		{"blocked", []string{"blocked", "blacklist", "listed"}},
	}

	lowered := strings.ToLower(errText)
	for _, group := range keywordLabels {
		for _, kw := range group.keywords {
			if strings.Contains(lowered, kw) {
				return group.label
			}
		}
	}
	return ""
}
// validReasonCode reports whether code looks like a provider reason code: it
// must end in at least two ASCII digits and must not start with "TLS" in any
// letter case.
func validReasonCode(code string) bool {
	if strings.HasPrefix(strings.ToUpper(code), "TLS") {
		return false
	}
	// Whatever TrimRight removes is exactly the run of trailing digits.
	withoutDigits := strings.TrimRight(code, "0123456789")
	return len(code)-len(withoutDigits) >= 2
}
// boundText normalises s for storage: invalid UTF-8 is replaced with "?",
// every run of whitespace collapses to one space, surrounding whitespace is
// trimmed, and the result is cut to at most maxTextLen bytes.
func boundText(s string) string {
	cleaned := strings.ToValidUTF8(s, "?")
	cleaned = whitespaceR.ReplaceAllString(cleaned, " ")
	cleaned = strings.TrimSpace(cleaned)

	if len(cleaned) > maxTextLen {
		// Truncate on a rune boundary: the deferral text is attacker-influenced,
		// so a byte-position cut could split a multi-byte rune and store invalid
		// UTF-8 that JSON would then mangle.
		end := maxTextLen
		for end > 0 && !utf8.RuneStart(cleaned[end]) {
			end--
		}
		cleaned = cleaned[:end]
	}
	return cleaned
}
// ReasonCount is a reason token and how often it occurred.
type ReasonCount struct {
Code string `json:"code"` // reason token, as produced by classifyReason
Count int `json:"count"` // number of deferrals carrying that token
}
// ProviderCount is a provider class and how often it appeared.
type ProviderCount struct {
Provider string `json:"provider"` // provider class, as a string
Count int `json:"count"` // number of deferrals for that class
}
// ProviderRollup aggregates deferrals to one provider class.
type ProviderRollup struct {
Provider string `json:"provider"` // provider class, as a string
Deferrals int `json:"deferrals"` // total deferrals for this provider
Reasons []ReasonCount `json:"reasons"` // per-reason counts, most frequent first
LastSeen time.Time `json:"last_seen"` // latest deferral timestamp seen
Sample string `json:"sample"` // bounded error text of the first deferral seen
}
// OutboundIPRollup aggregates deferrals affecting one of this server's sending
// IPs -- the reputation picture for that address.
type OutboundIPRollup struct {
IP string `json:"ip"` // outbound IPv4 address as echoed in the remote error
Deferrals int `json:"deferrals"` // total deferrals mentioning this IP
Providers []ProviderCount `json:"providers"` // per-provider counts, most frequent first
Reasons []ReasonCount `json:"reasons"` // per-reason counts, most frequent first
LastSeen time.Time `json:"last_seen"` // latest deferral timestamp seen
}
// Report is the aggregated deferral picture over a window of log lines.
type Report struct {
Deferrals int `json:"deferrals"` // number of lines that parsed as deferrals
Providers []ProviderRollup `json:"providers"` // one entry per provider class, busiest first
OutboundIPs []OutboundIPRollup `json:"outbound_ips"` // one entry per outbound IP, busiest first
}
// emptyReport returns a report with zero deferrals whose slices are empty but
// non-nil, so that it serializes as [] rather than null.
func emptyReport() Report {
	var rep Report
	rep.Providers = make([]ProviderRollup, 0)
	rep.OutboundIPs = make([]OutboundIPRollup, 0)
	return rep
}
// BuildReport parses every line and aggregates the deferrals it finds, once
// per provider class and once per outbound IP (only for deferrals that carry
// one). Non-deferral lines are ignored. Both rollup lists are ordered by
// deferral count, highest first, with ties broken by name so the output is
// deterministic.
func BuildReport(lines []string) Report {
	byProvider := make(map[string]*provAccum)
	byIP := make(map[string]*ipAccum)
	total := 0

	for _, raw := range lines {
		d, ok := parseDeferralLine(raw)
		if !ok {
			continue
		}
		total++

		provKey := string(d.Provider)
		pAcc, seen := byProvider[provKey]
		if !seen {
			pAcc = &provAccum{reasons: make(map[string]int)}
			byProvider[provKey] = pAcc
		}
		pAcc.add(d)

		if d.OutboundIP == "" {
			continue
		}
		iAcc, seen := byIP[d.OutboundIP]
		if !seen {
			iAcc = &ipAccum{providers: make(map[string]int), reasons: make(map[string]int)}
			byIP[d.OutboundIP] = iAcc
		}
		iAcc.add(d)
	}

	rep := emptyReport()
	rep.Deferrals = total

	for name, acc := range byProvider {
		rollup := ProviderRollup{
			Provider:  name,
			Deferrals: acc.count,
			Reasons:   sortedReasons(acc.reasons),
			LastSeen:  acc.lastSeen,
			Sample:    acc.sample,
		}
		rep.Providers = append(rep.Providers, rollup)
	}
	for addr, acc := range byIP {
		rollup := OutboundIPRollup{
			IP:        addr,
			Deferrals: acc.count,
			Providers: sortedProviders(acc.providers),
			Reasons:   sortedReasons(acc.reasons),
			LastSeen:  acc.lastSeen,
		}
		rep.OutboundIPs = append(rep.OutboundIPs, rollup)
	}

	// Map iteration order is random; sort on (count desc, name asc).
	sort.Slice(rep.Providers, func(i, j int) bool {
		a, b := rep.Providers[i], rep.Providers[j]
		if a.Deferrals == b.Deferrals {
			return a.Provider < b.Provider
		}
		return a.Deferrals > b.Deferrals
	})
	sort.Slice(rep.OutboundIPs, func(i, j int) bool {
		a, b := rep.OutboundIPs[i], rep.OutboundIPs[j]
		if a.Deferrals == b.Deferrals {
			return a.IP < b.IP
		}
		return a.Deferrals > b.Deferrals
	})
	return rep
}
// provAccum is the running aggregate for one provider class while a report is
// being built.
type provAccum struct {
	count    int            // deferrals added so far
	reasons  map[string]int // occurrences per non-empty reason code
	lastSeen time.Time      // latest deferral time added
	sample   string         // text of the first deferral that had any
}

// add folds one deferral into the aggregate. The sample is taken from the
// first deferral whose text is non-empty and never replaced afterwards.
func (p *provAccum) add(d Deferral) {
	p.count++
	if code := d.ReasonCode; code != "" {
		p.reasons[code]++
	}
	if p.lastSeen.Before(d.Time) {
		p.lastSeen = d.Time
	}
	if p.sample == "" {
		p.sample = d.Text
	}
}
// ipAccum is the running aggregate for one outbound IP while a report is
// being built.
type ipAccum struct {
	count     int            // deferrals added so far
	providers map[string]int // occurrences per provider class
	reasons   map[string]int // occurrences per non-empty reason code
	lastSeen  time.Time      // latest deferral time added
}

// add folds one deferral into the aggregate.
func (a *ipAccum) add(d Deferral) {
	a.count++
	a.providers[string(d.Provider)]++
	if code := d.ReasonCode; code != "" {
		a.reasons[code]++
	}
	if a.lastSeen.Before(d.Time) {
		a.lastSeen = d.Time
	}
}
// sortedReasons turns a reason->count map into a slice ordered by count,
// highest first, with equal counts ordered by code. The result is never nil.
func sortedReasons(m map[string]int) []ReasonCount {
	ranked := make([]ReasonCount, len(m))
	next := 0
	for code, count := range m {
		ranked[next] = ReasonCount{Code: code, Count: count}
		next++
	}
	sort.Slice(ranked, func(i, j int) bool {
		a, b := ranked[i], ranked[j]
		if a.Count == b.Count {
			return a.Code < b.Code
		}
		return a.Count > b.Count
	})
	return ranked
}
// sortedProviders turns a provider->count map into a slice ordered by count,
// highest first, with equal counts ordered by provider name. The result is
// never nil.
func sortedProviders(m map[string]int) []ProviderCount {
	ranked := make([]ProviderCount, len(m))
	next := 0
	for name, count := range m {
		ranked[next] = ProviderCount{Provider: name, Count: count}
		next++
	}
	sort.Slice(ranked, func(i, j int) bool {
		a, b := ranked[i], ranked[j]
		if a.Count == b.Count {
			return a.Provider < b.Provider
		}
		return a.Count > b.Count
	})
	return ranked
}
package intel
import (
"context"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
"time"
)
// QueueComposition is the makeup of the exim queue: how much is real mail still
// trying to deliver versus null-sender bounce backscatter, how much is frozen,
// and which recipients are stuck the most.
type QueueComposition struct {
Total int `json:"total"` // every queue header parsed; always Bounce + Real
Bounce int `json:"bounce"` // null-sender <> messages (backscatter)
Real int `json:"real"` // messages with any other sender
Frozen int `json:"frozen"` // messages marked "*** frozen ***"
FlushableBackscatter int `json:"flushable_backscatter"` // frozen AND null-sender: safe to flush
OldestAge string `json:"oldest_age"` // age token of the oldest message, "" for an empty queue
TopRecipients []RecipientCount `json:"top_recipients"` // most-targeted undelivered recipients, never nil
}
// RecipientCount is a recipient address and how many queued messages target it.
type RecipientCount struct {
Address string `json:"address"` // recipient address with surrounding <> removed
Count int `json:"count"` // number of queued recipient lines naming it
}
const (
// topRecipientLimit caps how many recipients a composition lists.
topRecipientLimit = 10
// Queue headers are padded for age alignment. Recipient and continuation
// lines are indented much deeper and must not be parsed as new messages.
// A line indented by more than this many spaces is never a header.
maxQueueHeaderIndent = 4
)
var (
// A queue header line: "<age> <size> <msgid> [(user)] <sender> [*** frozen ***]".
// Accept both the legacy 6-6-2 message id and the longer base62 form exim
// 4.97+ emits (6-11-4, e.g. "1wVR8E-0000000C9po-1DDg").
queueMsgIDRe = regexp.MustCompile(`^[0-9A-Za-z]{6}-(?:[0-9A-Za-z]{6}-[0-9A-Za-z]{2}|[0-9A-Za-z]{11}-[0-9A-Za-z]{4})$`)
// queueAgeRe matches an age token: digits followed by one of s, m, h, d, w.
queueAgeRe = regexp.MustCompile(`^\d+[smhdw]$`)
// queueSizeRe matches a size token: a number with an optional fraction and
// an optional k/m/g/t suffix in either case.
queueSizeRe = regexp.MustCompile(`(?i)^\d+(?:\.\d+)?[kmgt]?$`)
)
// ParseQueue parses `exim -bp` output into a composition summary.
//
// Every line that parses as a queue header is counted. The lines after a
// header are treated as that message's recipients until something other than
// a blank line or a recipient line shows up: a header-like line, a line
// indented no deeper than a header, or a line that does not start with
// whitespace. Recipient lines that do not yield a single address (including
// delivered "D <addr>" entries) are skipped without ending the message.
func ParseQueue(out string) QueueComposition {
	summary := QueueComposition{TopRecipients: []RecipientCount{}}
	perRecipient := make(map[string]int)
	oldest := -1
	insideMessage := false

	for _, raw := range strings.Split(out, "\n") {
		_, age, ageSec, isBounce, isFrozen, isHeader := parseQueueHeader(raw)
		if isHeader {
			summary.Total++
			if isBounce {
				summary.Bounce++
			} else {
				summary.Real++
			}
			if isFrozen {
				summary.Frozen++
				if isBounce {
					summary.FlushableBackscatter++
				}
			}
			if ageSec > oldest {
				oldest = ageSec
				summary.OldestAge = age
			}
			insideMessage = true
			continue
		}

		if !insideMessage {
			continue
		}
		body := strings.TrimSpace(raw)
		if body == "" {
			continue
		}

		// Anything that cannot be a recipient line closes the current message.
		endsMessage := queueHeaderCandidate(raw) ||
			queueHeaderIndent(raw) <= maxQueueHeaderIndent ||
			!queueRecipientLine(raw)
		if endsMessage {
			insideMessage = false
			continue
		}

		if addr, ok := queueRecipientAddress(body); ok {
			perRecipient[addr]++
		}
	}

	summary.TopRecipients = topRecipients(perRecipient)
	return summary
}
// parseQueueHeader parses one queue header line of the form
// "<age> <size> <msgid> [(user)] <sender> [*** frozen ***]".
//
// ok is false, with every other result zero, unless the line is indented by
// at most maxQueueHeaderIndent, has 4, 5, 7 or 8 fields, and its age, size
// and message id fields match their patterns. The optional fourth field must
// be a parenthesised local user, and the optional trailing three fields must
// read "*** frozen ***". bounce is true when the sender field is "<>".
func parseQueueHeader(line string) (msgID, age string, ageSeconds int, bounce, frozen, ok bool) {
	if queueHeaderIndent(line) > maxQueueHeaderIndent {
		return "", "", 0, false, false, false
	}

	fields := strings.Fields(line)
	var hasUser, hasFrozenMark bool
	switch len(fields) {
	case 4:
		// <age> <size> <msgid> <sender>
	case 5:
		hasUser = true
	case 7:
		hasFrozenMark = true
	case 8:
		hasUser, hasFrozenMark = true, true
	default:
		return "", "", 0, false, false, false
	}

	if !queueAgeRe.MatchString(fields[0]) || !queueSizeRe.MatchString(fields[1]) || !queueMsgIDRe.MatchString(fields[2]) {
		return "", "", 0, false, false, false
	}

	senderAt := 3
	if hasUser {
		if !queueLocalUserField(fields[3]) {
			return "", "", 0, false, false, false
		}
		senderAt = 4
	}

	if hasFrozenMark {
		// The marker occupies the three fields right after the sender.
		mark := fields[senderAt+1 : senderAt+4]
		if mark[0] != "***" || mark[1] != "frozen" || mark[2] != "***" {
			return "", "", 0, false, false, false
		}
	}

	isBounce := fields[senderAt] == "<>"
	return fields[2], fields[0], ageToSeconds(fields[0]), isBounce, hasFrozenMark, true
}
// queueHeaderCandidate reports whether line has the rough shape of a queue
// header: indented by at most maxQueueHeaderIndent, exactly 4 or 7 fields,
// and a valid message id in the third field. Age, size and sender are not
// checked.
func queueHeaderCandidate(line string) bool {
	if queueHeaderIndent(line) > maxQueueHeaderIndent {
		return false
	}
	fields := strings.Fields(line)
	switch len(fields) {
	case 4, 7:
		return queueMsgIDRe.MatchString(fields[2])
	default:
		return false
	}
}
// queueRecipientLine reports whether line starts with a space or a tab.
func queueRecipientLine(line string) bool {
	return strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t")
}
// queueRecipientAddress extracts the address from a trimmed recipient line.
// The line must consist of exactly one field; after removing surrounding <>
// that field must contain "@" and be at most maxAddressLen bytes long.
// Anything else yields ("", false).
func queueRecipientAddress(trimmed string) (string, bool) {
	parts := strings.Fields(trimmed)
	// Exim prefixes an already-delivered recipient with "D"; it is not stuck.
	// Such a line has two fields, so the single-field rule below rejects it
	// along with every other multi-field line.
	if len(parts) != 1 {
		return "", false
	}

	candidate := strings.Trim(parts[0], "<>")
	if !strings.Contains(candidate, "@") || len(candidate) > maxAddressLen {
		return "", false
	}
	return candidate, true
}
// queueLocalUserField reports whether field is a parenthesised local user
// such as "(root)": at least one character between the parentheses, and no
// parenthesis or whitespace inside.
func queueLocalUserField(field string) bool {
	n := len(field)
	if n < 3 {
		return false
	}
	if !strings.HasPrefix(field, "(") || !strings.HasSuffix(field, ")") {
		return false
	}
	inner := field[1 : n-1]
	return !strings.ContainsAny(inner, "() \t\r\n")
}
// queueHeaderIndent returns the number of leading spaces in line. A tab met
// before any other character makes the result maxQueueHeaderIndent+1, so a
// tab-indented line is never taken for a header. A line made only of spaces
// (or an empty one) yields its length.
func queueHeaderIndent(line string) int {
	for pos := 0; pos < len(line); pos++ {
		if c := line[pos]; c == '\t' {
			return maxQueueHeaderIndent + 1
		} else if c != ' ' {
			return pos
		}
	}
	return len(line)
}
// ageToSeconds converts an exim age token (e.g. "25m", "4d") to seconds. The
// token is an integer followed by one unit letter: s, m, h, d or w. A token
// that is too short, has a non-integer count, or has an unknown unit
// returns 0.
func ageToSeconds(age string) int {
	if len(age) < 2 {
		return 0
	}
	last := len(age) - 1
	count, err := strconv.Atoi(age[:last])
	if err != nil {
		return 0
	}

	var perUnit int
	switch age[last] {
	case 's':
		perUnit = 1
	case 'm':
		perUnit = 60
	case 'h':
		perUnit = 3600
	case 'd':
		perUnit = 86400
	case 'w':
		perUnit = 604800
	default:
		return 0
	}
	return count * perUnit
}
// topRecipients ranks the recipients in m by message count, highest first,
// with equal counts ordered by address, and returns at most
// topRecipientLimit of them. The result is never nil.
func topRecipients(m map[string]int) []RecipientCount {
	ranked := make([]RecipientCount, 0, len(m))
	for addr, count := range m {
		ranked = append(ranked, RecipientCount{Address: addr, Count: count})
	}
	sort.Slice(ranked, func(i, j int) bool {
		a, b := ranked[i], ranked[j]
		if a.Count == b.Count {
			return a.Address < b.Address
		}
		return a.Count > b.Count
	})
	if len(ranked) <= topRecipientLimit {
		return ranked
	}
	return ranked[:topRecipientLimit]
}
// QueueReporter produces the queue composition for the host.
type QueueReporter interface {
// Composition returns the current makeup of the mail queue.
Composition() (QueueComposition, error)
}
// EmptyQueueReporter yields an empty composition. It stands in on platforms
// with no exim queue (non-cPanel).
type EmptyQueueReporter struct{}
// Composition returns a zero composition whose TopRecipients is an empty,
// non-nil slice, and a nil error.
func (EmptyQueueReporter) Composition() (QueueComposition, error) {
return QueueComposition{TopRecipients: []RecipientCount{}}, nil
}
// EximQueueSource runs `exim -bp` and parses the result.
type EximQueueSource struct {
// run returns the raw queue listing; NewEximQueueSource sets it to runEximBp.
run func() ([]byte, error)
}
// NewEximQueueSource returns a source that reads the live exim queue by
// running runEximBp.
func NewEximQueueSource() *EximQueueSource {
	src := new(EximQueueSource)
	src.run = runEximBp
	return src
}
// Composition lists the queue and summarizes it with ParseQueue. The returned
// error is always nil: when the listing fails, an empty composition is
// returned instead, because this is a read-only visibility surface.
func (s *EximQueueSource) Composition() (QueueComposition, error) {
	listing, err := s.run()
	if err == nil {
		return ParseQueue(string(listing)), nil
	}
	// exim absent or failing means no observable queue, not an API error;
	// surface an empty composition rather than a 500.
	return QueueComposition{TopRecipients: []RecipientCount{}}, nil //nolint:nilerr
}
// runEximBp runs `exim -bp` and returns its standard output. The run is
// bounded by a 10-second timeout.
func runEximBp() ([]byte, error) {
	const listTimeout = 10 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), listTimeout)
	defer cancel()

	cmd := exec.CommandContext(ctx, "exim", "-bp")
	return cmd.Output()
}
package intel
import (
"io"
"os"
"strings"
)
// eximMainLog is the cPanel/exim delivery log. Deferrals to remote providers
// are recorded here; CSM only reads it.
const eximMainLog = "/var/log/exim_mainlog"
// Default read bounds: enough recent history to see a throttle pattern without
// reading an unbounded multi-gigabyte log.
const (
// defaultTailBytes is how many trailing bytes of the log are read at most.
defaultTailBytes = 8 << 20 // 8 MiB
// defaultTailLines is how many trailing lines are kept at most.
defaultTailLines = 20000
)
// Reporter produces a deferral Report for the host.
type Reporter interface {
// Report returns the aggregated deferral picture.
Report() (Report, error)
}
// EmptyReporter yields an empty report. It stands in on platforms with no exim
// log (non-cPanel) until their adapters land.
type EmptyReporter struct{}
// Report returns emptyReport() and a nil error.
func (EmptyReporter) Report() (Report, error) { return emptyReport(), nil }
// EximSource reads the tail of exim_mainlog and builds a deferral report.
type EximSource struct {
// path is the log file to read.
path string
// tailBytes is the largest number of trailing bytes to read.
tailBytes int64
// tailLines is the largest number of trailing lines to keep.
tailLines int
// readTail returns the trailing lines of path within the two bounds;
// NewEximSource sets it to tailLines.
readTail func(path string, maxBytes int64, maxLines int) []string
}
// NewEximSource returns a source reading the standard exim_mainlog location
// with the default byte and line bounds.
func NewEximSource() *EximSource {
	src := new(EximSource)
	src.path = eximMainLog
	src.tailBytes = defaultTailBytes
	src.tailLines = defaultTailLines
	src.readTail = tailLines
	return src
}
// Report reads the recent tail of the log and aggregates deferrals. A missing
// or unreadable log yields an empty report, not an error: no log just means no
// observed deferrals, and this is a read-only visibility surface.
//
// The returned error is always nil.
func (s *EximSource) Report() (Report, error) {
return BuildReport(s.readTail(s.path, s.tailBytes, s.tailLines)), nil
}
// tailLines returns up to maxLines trailing lines of the file at path,
// reading at most the last maxBytes of it so a huge log never loads whole
// into memory.
//
// When the file is larger than maxBytes, reading starts mid-file, so
// everything up to and including the first newline is dropped as a likely
// partial line. If that window contains no newline at all it is nothing but
// the middle of one oversized line, and nil is returned rather than handing
// the fragment to callers as if it were a line.
//
// The result is nil when either bound is non-positive, when the file cannot
// be opened, inspected or positioned, and when there is nothing to return. A
// read error part-way through is tolerated: whatever was read is used.
func tailLines(path string, maxBytes int64, maxLines int) []string {
	if maxBytes <= 0 || maxLines <= 0 {
		return nil
	}
	f, err := os.Open(path) // #nosec G304 -- fixed exim_mainlog path, operator-scoped.
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()

	info, err := f.Stat()
	if err != nil {
		return nil
	}
	windowed := info.Size() > maxBytes
	if windowed {
		if _, err := f.Seek(-maxBytes, io.SeekEnd); err != nil {
			return nil
		}
	}

	// Best effort by design: a partial read still gives a usable tail.
	data, _ := io.ReadAll(f)
	// Convert once and slice the string from here on; the window can be
	// several MiB, so repeated conversions would copy it each time.
	text := string(data)

	if windowed {
		nl := strings.IndexByte(text, '\n')
		if nl < 0 {
			return nil
		}
		text = text[nl+1:]
	}

	text = strings.TrimSuffix(text, "\n")
	if text == "" {
		return nil
	}
	lines := strings.Split(text, "\n")
	if len(lines) > maxLines {
		lines = lines[len(lines)-maxLines:]
	}
	return lines
}
package inventory
import (
"path/filepath"
"sort"
"strings"
)
// FS is the minimal filesystem surface the enumerator needs. Injected so the
// cPanel source can be tested against fixture directories without root.
type FS interface {
// Glob returns the paths matching pattern.
Glob(pattern string) ([]string, error)
// ReadFile returns the whole content of the named file.
ReadFile(name string) ([]byte, error)
}
// Source enumerates the forwarders configured on a host.
type Source interface {
// Forwarders returns every forwarder the source can see.
Forwarders() ([]Forwarder, error)
}
// EmptySource reports no forwarders. It stands in on platforms whose
// enumeration is not wired yet (non-cPanel), so callers always hold a usable
// Source instead of a nil.
type EmptySource struct{}
// Forwarders returns an empty, non-nil slice and a nil error.
func (EmptySource) Forwarders() ([]Forwarder, error) { return []Forwarder{}, nil }
// CPanelSource reads forwarders from cPanel's /etc/valiases directory, with
// local domains from /etc/localdomains and /etc/virtualdomains and owners from
// /etc/userdomains.
type CPanelSource struct {
// fs performs all file access.
fs FS
// valiasGlob is the glob pattern selecting the per-domain valias files.
valiasGlob string
// localDomainsPath and virtualDomainsPath list the domains treated as
// local; an empty path is skipped.
localDomainsPath string
virtualDomainsPath string
// userDomainsPath holds the "domain: user" ownership map.
userDomainsPath string
}
// NewCPanelSource returns a source reading the standard cPanel locations
// under /etc through the real filesystem.
func NewCPanelSource() *CPanelSource {
	src := new(CPanelSource)
	src.fs = osFS{}
	src.valiasGlob = "/etc/valiases/*"
	src.localDomainsPath = "/etc/localdomains"
	src.virtualDomainsPath = "/etc/virtualdomains"
	src.userDomainsPath = "/etc/userdomains"
	return src
}
// Forwarders enumerates every forwarder across all hosted domains, ordered by
// Source. Each file matched by the valias glob is one domain, named after the
// file; every line of it that parseForwarderLine accepts becomes a Forwarder
// owned by that domain's entry in the owners map.
//
// A missing or unreadable valias file is skipped, not fatal: partial inventory
// beats no inventory on a server with thousands of domains. Only a failing
// glob is returned as an error.
//
// On success the slice is never nil, even when no forwarder exists, so it
// serializes as [] rather than null -- the same shape EmptySource returns.
func (s *CPanelSource) Forwarders() ([]Forwarder, error) {
	localDomains := s.loadLocalDomains()
	owners := s.loadOwners()
	files, err := s.fs.Glob(s.valiasGlob)
	if err != nil {
		return nil, err
	}

	out := []Forwarder{}
	for _, path := range files {
		domain := normalizeDomain(filepath.Base(path))
		content, err := s.fs.ReadFile(path)
		if err != nil {
			continue // unreadable file: skip this domain only
		}
		owner := owners[domain]
		for _, line := range strings.Split(string(content), "\n") {
			fwd, ok := parseForwarderLine(domain, line, localDomains)
			if !ok {
				continue
			}
			fwd.Owner = owner
			out = append(out, fwd)
		}
	}
	sort.Slice(out, func(i, j int) bool { return out[i].Source < out[j].Source })
	return out, nil
}
// loadLocalDomains builds the normalized set of domains hosted on this
// server from the local-domains and virtual-domains files. Unset paths and
// unreadable files are ignored, so the result may be empty; in that case
// every destination classifies as external, which is the safe direction for
// a reputation tool (over-report external, never hide it).
func (s *CPanelSource) loadLocalDomains() map[string]bool {
	set := map[string]bool{}
	sources := [...]string{s.localDomainsPath, s.virtualDomainsPath}
	for i := range sources {
		name := sources[i]
		if name == "" {
			continue
		}
		raw, err := s.fs.ReadFile(name)
		if err != nil {
			continue
		}
		for _, entry := range strings.Split(string(raw), "\n") {
			entry = strings.TrimSpace(entry)
			// Blank lines and '#' comments carry no domain.
			if entry == "" || entry[0] == '#' {
				continue
			}
			if d := configDomain(entry); d != "" {
				set[d] = true
			}
		}
	}
	return set
}
// loadOwners parses the user-domains file ("domain: user" per line) into a
// domain->owner map keyed by normalized domain. An unreadable file yields an
// empty map; blank lines, comments, and lines without both a domain and an
// owner are ignored.
func (s *CPanelSource) loadOwners() map[string]string {
	byDomain := map[string]string{}
	raw, err := s.fs.ReadFile(s.userDomainsPath)
	if err != nil {
		return byDomain
	}
	for _, entry := range strings.Split(string(raw), "\n") {
		entry = strings.TrimSpace(entry)
		if entry == "" || entry[0] == '#' {
			continue
		}
		name, user, found := strings.Cut(entry, ":")
		if !found || name == "" {
			continue
		}
		domain := configDomain(name)
		user = strings.TrimSpace(user)
		if domain == "" || user == "" {
			continue
		}
		byDomain[domain] = user
	}
	return byDomain
}
// configDomain extracts the domain from one line of a cPanel domain list.
// Anything from the first ':' onwards is discarded, the remainder is
// normalized, and "" is returned when nothing usable is left or the result
// still contains whitespace or path/separator characters.
func configDomain(line string) string {
	head, _, _ := strings.Cut(strings.TrimSpace(line), ":")
	domain := normalizeDomain(head)
	switch {
	case domain == "":
		return ""
	case strings.ContainsAny(domain, " \t\r\n/:\\"):
		return ""
	}
	return domain
}
package inventory
import (
"os"
"path/filepath"
)
// osFS is the production FS: every call goes straight to the real
// filesystem.
type osFS struct{}

// Glob returns the paths matching pattern, as reported by filepath.Glob.
func (osFS) Glob(pattern string) ([]string, error) {
	matches, err := filepath.Glob(pattern)
	return matches, err
}

// ReadFile returns the contents of the named file. Paths come from a fixed
// glob of operator-owned mail config directories, not from untrusted input.
func (osFS) ReadFile(name string) ([]byte, error) {
	data, err := os.ReadFile(name) // #nosec G304 -- name is a valias path from a fixed glob, operator-scoped.
	return data, err
}
// Package inventory enumerates mail forwarders on a host and classifies their
// destinations, so operators can see which accounts relay mail off-server and
// to which providers. It is the canonical home for forwarder-domain logic in
// CSM. Enumeration is platform-abstracted: cPanel reads /etc/valiases, other
// hosts read /etc/aliases, ~/.forward, or postfix virtual maps (wired
// incrementally). Reading the filesystem is injected so the parsing logic is
// testable without root or a live mail server.
package inventory
import (
"strings"
"golang.org/x/net/publicsuffix"
)
// ProviderClass labels a forwarder destination by where it delivers. Free
// providers (Yahoo/Gmail/Outlook) are split out because forwarding spam to
// them is what degrades a server's outbound reputation; local stays on-server.
type ProviderClass string
const (
// ProviderLocal: the destination has no domain part or its domain is one of
// the server's local domains.
ProviderLocal ProviderClass = "local"
// ProviderYahoo, ProviderGmail and ProviderOutlook: the destination domain
// (or its registrable domain) is listed in freeProviderExact.
ProviderYahoo ProviderClass = "yahoo"
ProviderGmail ProviderClass = "gmail"
ProviderOutlook ProviderClass = "outlook"
// ProviderExternal: any other off-server destination.
ProviderExternal ProviderClass = "external"
)
// freeProviderExact maps known free-provider mail domains to their class.
// Lowercase keys; lookups lowercase the input. classifyProvider consults it
// twice: first with the destination domain as written, then with that
// domain's registrable domain, so subdomains of a listed domain match too.
var freeProviderExact = map[string]ProviderClass{
"gmail.com": ProviderGmail,
"googlemail.com": ProviderGmail,
"rocketmail.com": ProviderYahoo,
"yahoo.ca": ProviderYahoo,
"yahoo.co.in": ProviderYahoo,
"yahoo.co.jp": ProviderYahoo,
"yahoo.co.uk": ProviderYahoo,
"yahoo.com": ProviderYahoo,
"yahoo.com.au": ProviderYahoo,
"yahoo.com.br": ProviderYahoo,
"yahoo.com.mx": ProviderYahoo,
"yahoo.de": ProviderYahoo,
"yahoo.es": ProviderYahoo,
"yahoo.fr": ProviderYahoo,
"yahoo.it": ProviderYahoo,
"yahoo.ro": ProviderYahoo,
"ymail.com": ProviderYahoo,
"hotmail.co.uk": ProviderOutlook,
"hotmail.com": ProviderOutlook,
"hotmail.de": ProviderOutlook,
"hotmail.fr": ProviderOutlook,
"live.co.uk": ProviderOutlook,
"live.com": ProviderOutlook,
"live.com.au": ProviderOutlook,
"live.de": ProviderOutlook,
"live.fr": ProviderOutlook,
"live.it": ProviderOutlook,
"live.ro": ProviderOutlook,
"msn.com": ProviderOutlook,
"outlook.com": ProviderOutlook,
"outlook.de": ProviderOutlook,
}
// Destination is one resolved target of a forwarder.
type Destination struct {
// Address is the target as written, with surrounding whitespace removed
// and a fully quoted address unwrapped.
Address string `json:"address"`
// Domain is the normalized part after the last '@', or "" when the
// address has no domain.
Domain string `json:"domain"`
// Provider is the class assigned by classifyProvider.
Provider ProviderClass `json:"provider"`
}
// Forwarder is a single source address and everything it relays to.
// parseForwarderLine fills every field except Owner, which the enumerating
// source sets afterwards. KeepLocal and ForwardOnly are always opposites:
// KeepLocal is true when at least one destination classified as local.
type Forwarder struct {
Source string `json:"source"` // local_part@domain
Domain string `json:"domain"` // hosting domain
Owner string `json:"owner"` // panel account, "" if unknown
Destinations []Destination `json:"destinations"` // address targets only
KeepLocal bool `json:"keep_local"` // also delivers to a local mailbox
ForwardOnly bool `json:"forward_only"` // only remote targets, no local copy
}
// HasExternal reports whether at least one destination is classified as
// anything other than ProviderLocal, i.e. mail leaves the server.
func (f Forwarder) HasExternal() bool {
	for i := range f.Destinations {
		if f.Destinations[i].Provider == ProviderLocal {
			continue
		}
		return true
	}
	return false
}
// HasFreeProvider reports whether at least one destination is a Yahoo,
// Gmail or Outlook mailbox -- the case that puts outbound reputation at
// risk.
func (f Forwarder) HasFreeProvider() bool {
	for i := range f.Destinations {
		p := f.Destinations[i].Provider
		if p == ProviderYahoo || p == ProviderGmail || p == ProviderOutlook {
			return true
		}
	}
	return false
}
// ClassifyAddress classifies addr without any knowledge of which domains
// are hosted locally. It is meant for addresses already known to be remote
// recipients (for example a deferral target parsed from exim_mainlog), so
// that a free-provider domain is reported as that provider and never as
// local.
func ClassifyAddress(addr string) ProviderClass {
	var noLocalDomains map[string]bool
	return classifyProvider(addr, noLocalDomains)
}
// classifyProvider maps a destination address to its provider class.
// localDomains holds the domains hosted on this server (lowercased keys) and
// may be nil. Order of precedence: no domain part -> local; local domain ->
// local; exact free-provider match; free-provider match on the registrable
// domain; otherwise external.
func classifyProvider(addr string, localDomains map[string]bool) ProviderClass {
	trimmed := strings.TrimSpace(addr)
	sep := strings.LastIndexByte(trimmed, '@')
	if sep < 0 || sep+1 >= len(trimmed) {
		// Bare local part (or trailing '@'): delivered to the local mailbox.
		return ProviderLocal
	}

	host := normalizeDomain(trimmed[sep+1:])
	if localDomains[host] {
		return ProviderLocal
	}
	if class, found := freeProviderExact[host]; found {
		return class
	}

	registered, ok := registeredDomain(host)
	if !ok {
		return ProviderExternal
	}
	if class, found := freeProviderExact[registered]; found {
		return class
	}
	return ProviderExternal
}
// registeredDomain returns the normalized registrable domain (effective TLD
// plus one label) of domain. The boolean is false when the public suffix
// lookup rejects the input.
func registeredDomain(domain string) (string, bool) {
	etldPlusOne, err := publicsuffix.EffectiveTLDPlusOne(domain)
	if err == nil {
		return normalizeDomain(etldPlusOne), true
	}
	return "", false
}
// normalizeDomain canonicalizes a domain for comparison: surrounding
// whitespace is removed, the text is lowercased, and a single trailing dot
// (the DNS root label) is dropped.
func normalizeDomain(domain string) string {
	lowered := strings.ToLower(strings.TrimSpace(domain))
	return strings.TrimSuffix(lowered, ".")
}
// parseForwarderLine turns one alias/valias line ("local_part: dest[, dest]")
// into a Forwarder for the given hosting domain. ok is false for an empty
// domain, blank lines, comments, lines without a ':' or with an empty side,
// and lines whose targets are all non-address deliveries (pipes, :fail:,
// :blackhole:, /dev/null) -- those relay nothing to a mailbox and carry no
// reputation risk. Purely local aliases are still returned so the inventory
// stays complete; callers filter on HasExternal as needed.
func parseForwarderLine(domain, line string, localDomains map[string]bool) (Forwarder, bool) {
	var none Forwarder

	host := normalizeDomain(domain)
	if host == "" {
		return none, false
	}
	entry := strings.TrimSpace(line)
	if entry == "" || entry[0] == '#' {
		return none, false
	}
	name, targets, found := strings.Cut(entry, ":")
	if !found {
		return none, false
	}
	name, targets = strings.TrimSpace(name), strings.TrimSpace(targets)
	if name == "" || targets == "" {
		return none, false
	}

	fwd := Forwarder{Source: name + "@" + host, Domain: host}
	keepsLocal := false
	for _, field := range strings.Split(targets, ",") {
		target := strings.TrimSpace(field)
		if target == "" || !isAddressDestination(target) {
			// Pipe / :fail: / :blackhole: / /dev/null / autoresponder:
			// not an address relay, so it contributes no destination.
			continue
		}
		d := newDestination(target, localDomains)
		if d.Provider == ProviderLocal {
			keepsLocal = true
		}
		fwd.Destinations = append(fwd.Destinations, d)
	}
	if len(fwd.Destinations) == 0 {
		return none, false
	}

	fwd.KeepLocal = keepsLocal
	fwd.ForwardOnly = !keepsLocal
	return fwd, true
}
// isAddressDestination reports whether a valias target is a mailbox address
// rather than a pipe, discard, fail, or file delivery. An address with a
// quoted local part is accepted outright; otherwise one pair of surrounding
// quotes is removed and the target is rejected when it is empty or starts
// with '|' (program), ':' (:fail:, :blackhole:, :defer:), '/' (file or
// /dev/null) or '"' (malformed quote or quoted directive).
func isAddressDestination(dest string) bool {
	candidate := strings.TrimSpace(dest)
	if quotedLocalPartAddress(candidate) {
		return true
	}
	if n := len(candidate); n >= 2 && candidate[0] == '"' && candidate[n-1] == '"' {
		candidate = strings.TrimSpace(candidate[1 : n-1])
	}
	if candidate == "" {
		return false
	}
	return strings.IndexByte("|:/\"", candidate[0]) < 0
}
// quotedLocalPartAddress reports whether dest has the shape "local part"@...,
// i.e. it opens with a double quote and the quote that closes the local part
// is immediately followed by '@'.
//
// Backslash escapes inside the quotes are honoured (RFC 5322 quoted-pair),
// so an escaped quote such as the one in "a\"b"@example.com does not end the
// local part early. A string whose only later quote is escaped has no
// closing quote and is therefore not an address.
func quotedLocalPartAddress(dest string) bool {
	if !strings.HasPrefix(dest, "\"") {
		return false
	}
	for i := 1; i < len(dest); i++ {
		switch dest[i] {
		case '\\':
			i++ // quoted-pair: the next byte is literal, even when it is a quote
		case '"':
			// First unescaped quote closes the local part.
			return i+1 < len(dest) && dest[i+1] == '@'
		}
	}
	return false
}
// newDestination builds the Destination for one address target: the address
// is cleaned up, classified against localDomains, and its domain (the part
// after the last '@', normalized) is recorded when present.
func newDestination(addr string, localDomains map[string]bool) Destination {
	clean := normalizeAddressDestination(addr)
	dest := Destination{
		Address:  clean,
		Provider: classifyProvider(clean, localDomains),
	}
	if sep := strings.LastIndexByte(clean, '@'); sep >= 0 && sep+1 < len(clean) {
		dest.Domain = normalizeDomain(clean[sep+1:])
	}
	return dest
}
// normalizeAddressDestination trims addr and, when the whole value is
// wrapped in double quotes and the quoted text contains an '@', returns the
// trimmed inner text. Anything else is returned trimmed but otherwise
// untouched.
func normalizeAddressDestination(addr string) string {
	trimmed := strings.TrimSpace(addr)
	if len(trimmed) < 2 || trimmed[0] != '"' || trimmed[len(trimmed)-1] != '"' {
		return trimmed
	}
	if unquoted := strings.TrimSpace(trimmed[1 : len(trimmed)-1]); strings.Contains(unquoted, "@") {
		return unquoted
	}
	return trimmed
}
// Package policy is the single source of truth for the forward-guard hold
// decision: given the signals observed for a message, should the external
// forward copy be held? The same Verdict function feeds both the dry-run
// "would-hold" accounting and (in Phase 2) the generated MTA rule, so the two
// can never drift apart.
//
// The verdict is layered: any enabled signal that matches holds the message.
// Holding is conservative-by-omission -- a signal whose toggle is off, or a
// message with no matching signal, is never held.
package policy
// MessageMeta is the set of signals known about a single forwarded message.
// All fields default false (unknown == not flagged), so a partially-populated
// meta can only ever reduce the chance of a hold, never invent one.
type MessageMeta struct {
NullSender bool // envelope sender is <> (bounce/backscatter)
SpamFlagged bool // SpamAssassin marked it spam
MalwareHit bool // ClamAV / YARA-X matched
SenderIPBad bool // sender IP is in the CSM attack DB / reputation
// Authentication results. Verdict reports "auth_fail" only when SPFFail,
// DKIMFail and DMARCFail are all true.
SPFFail bool
DKIMFail bool
DMARCFail bool
}
// HoldSignals toggles which layered signals are allowed to hold a message.
// Each is individually switchable so operators can roll out one signal at a time.
type HoldSignals struct {
BounceBackscatter bool // gates MessageMeta.NullSender ("bounce_backscatter")
SpamFlagged bool // gates MessageMeta.SpamFlagged ("spam_flagged")
Malware bool // gates MessageMeta.MalwareHit ("malware")
BadSenderIP bool // gates MessageMeta.SenderIPBad ("bad_sender_ip")
AuthFail bool // gates the combined SPF+DKIM+DMARC failure ("auth_fail")
}
// Config is the forward-guard policy input. Enabled is the master switch; when
// off, Verdict never holds regardless of signals. DryRun does not affect the
// verdict itself -- it tells callers whether to enforce the hold or only
// account for it -- so it lives here for callers but is not read by Verdict.
type Config struct {
Enabled bool // master switch; false short-circuits Verdict
DryRun bool // caller-side only; Verdict ignores it
HoldSignals HoldSignals // per-signal toggles consulted by Verdict
}
// Verdict decides whether the external-forward copy of a message should be
// held. It returns hold=true together with the reason codes of every signal
// that is both enabled in cfg.HoldSignals and set in meta. With cfg.Enabled
// off, or with no matching signal, it returns (false, nil). Reason codes are
// emitted in a fixed order and are stable identifiers safe to surface in the
// UI, log, and generated MTA rule.
func Verdict(meta MessageMeta, cfg Config) (hold bool, reasons []string) {
	if !cfg.Enabled {
		return false, nil
	}
	toggles := cfg.HoldSignals
	// The table order is the output order; keeping it in a slice literal
	// (not a map) is what makes the reason list deterministic.
	checks := [...]struct {
		matched bool
		code    string
	}{
		{toggles.AuthFail && meta.SPFFail && meta.DKIMFail && meta.DMARCFail, "auth_fail"},
		{toggles.BadSenderIP && meta.SenderIPBad, "bad_sender_ip"},
		{toggles.BounceBackscatter && meta.NullSender, "bounce_backscatter"},
		{toggles.Malware && meta.MalwareHit, "malware"},
		{toggles.SpamFlagged && meta.SpamFlagged, "spam_flagged"},
	}
	for _, c := range checks {
		if c.matched {
			reasons = append(reasons, c.code)
		}
	}
	return len(reasons) > 0, reasons
}
// Package quarantine is the CSM-owned Maildir that holds external forward
// copies the forward-guard decided to withhold. The exim transport (Phase 2
// Slice C) appends held copies here with X-CSM-* control headers; CSM lists
// them, releases (re-injects to the original external recipient) or deletes,
// and prunes by age. CSM is never in the live delivery path -- exim writes the
// file, CSM only acts on it afterwards.
package quarantine
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"sync/atomic"
"syscall"
"time"
)
// Control headers the exim transport adds and CSM parses. They are stripped
// before a message is re-injected so they never leak to the recipient.
// Header names are matched case-insensitively when parsed and stripped.
const (
hdrForwarder = "X-CSM-Forwarder" // local source address that forwarded
hdrRecipient = "X-CSM-Recipient" // external destination that was held
hdrSender = "X-CSM-Sender" // envelope sender; empty for a null sender
hdrReasons = "X-CSM-Reasons" // comma-separated reason codes
hdrPrefix = "X-CSM-" // any header with this prefix is treated as control data
)
// HeldMessage is the operator-facing view of one held forward copy. When a
// message is read back from disk, ID is its Maildir file name, HeldAt is the
// file's modification time and Size is the on-disk size (control headers
// included).
type HeldMessage struct {
ID string `json:"id"`
Forwarder string `json:"forwarder"` // local source address that forwarded
Recipient string `json:"recipient"` // external destination that was held
Sender string `json:"sender"` // envelope sender ("" = null-sender bounce)
Reasons []string `json:"reasons"`
HeldAt time.Time `json:"held_at"`
Size int64 `json:"size"`
}
// Quarantine manages the held-forward Maildir.
type Quarantine struct {
// base is the Maildir root; tmp/, new/ and cur/ live directly under it.
base string
// counter is incremented per Hold call and made part of the message id.
counter atomic.Uint64
// sendmail re-injects a released message; New sets it to runSendmail.
sendmail func(sender, recipient string, body []byte) error
}
// New creates a Quarantine rooted at dir, which is treated as a Maildir
// (new/cur/tmp are created on demand by Hold). Released messages are
// re-injected by shelling out to the platform sendmail.
func New(dir string) *Quarantine {
	q := new(Quarantine)
	q.base = dir
	q.sendmail = runSendmail
	return q
}
// sub returns the path of the named Maildir subdirectory under the base.
func (q *Quarantine) sub(name string) string {
	return filepath.Join(q.base, name)
}
// Hold writes a held copy into the Maildir with the X-CSM-* control headers
// prepended to body and returns its id (the Maildir filename). In production
// exim's appendfile writes these files; Hold produces the identical format
// for CSM-side paths and tests.
//
// The message is written to tmp/ first and then renamed into new/, so
// readers never observe a half-written file. If either step fails the
// temporary file is removed: PruneOlderThan only sweeps new/ and cur/, so a
// leftover in tmp/ would otherwise never be cleaned up.
func (q *Quarantine) Hold(m HeldMessage, body []byte) (string, error) {
	for _, d := range []string{"tmp", "new", "cur"} {
		if err := os.MkdirAll(q.sub(d), 0700); err != nil {
			return "", fmt.Errorf("creating maildir %s: %w", d, err)
		}
	}

	// headerValue strips CR/LF so a hostile address cannot inject headers.
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "%s: %s\r\n", hdrForwarder, headerValue(m.Forwarder))
	fmt.Fprintf(&buf, "%s: %s\r\n", hdrRecipient, headerValue(m.Recipient))
	fmt.Fprintf(&buf, "%s: %s\r\n", hdrSender, headerValue(m.Sender))
	fmt.Fprintf(&buf, "%s: %s\r\n", hdrReasons, headerValue(strings.Join(m.Reasons, ",")))
	buf.Write(body)

	// Timestamp plus a per-Quarantine counter keeps ids unique even when two
	// holds land in the same nanosecond.
	id := fmt.Sprintf("%d.%d.csm", time.Now().UnixNano(), q.counter.Add(1))
	tmp := filepath.Join(q.sub("tmp"), id)
	if err := os.WriteFile(tmp, buf.Bytes(), 0600); err != nil { // #nosec G306 -- 0600 is intended
		// A failed write can leave a truncated file behind; best-effort
		// cleanup, the write error is what the caller needs to see.
		_ = os.Remove(tmp)
		return "", fmt.Errorf("writing held message: %w", err)
	}
	dst := filepath.Join(q.sub("new"), id)
	if err := os.Rename(tmp, dst); err != nil {
		_ = os.Remove(tmp)
		return "", fmt.Errorf("committing held message: %w", err)
	}
	return id, nil
}
// List returns all held messages from new/ and cur/ with their parsed
// metadata, oldest first. A Maildir subdirectory that does not exist simply
// contributes nothing; any other directory read error aborts the listing.
// Entries that cannot be read are skipped. The result is never nil.
func (q *Quarantine) List() ([]HeldMessage, error) {
	held := make([]HeldMessage, 0)
	for _, sub := range [...]string{"new", "cur"} {
		entries, err := os.ReadDir(q.sub(sub))
		if err != nil && !os.IsNotExist(err) {
			return nil, err
		}
		if err != nil {
			continue
		}
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			// Unreadable or partial entries are skipped rather than failing
			// the whole list.
			if msg, readErr := q.read(sub, entry.Name()); readErr == nil {
				held = append(held, msg)
			}
		}
	}
	sort.Slice(held, func(a, b int) bool {
		return held[a].HeldAt.Before(held[b].HeldAt)
	})
	return held, nil
}
// read loads the held message id from the given Maildir subdirectory and
// builds its HeldMessage from the control headers and the file's
// modification time and size.
func (q *Quarantine) read(dir, id string) (HeldMessage, error) {
	data, info, err := readMessageFile(filepath.Join(q.sub(dir), id))
	if err != nil {
		return HeldMessage{}, err
	}
	control := parseControlHeaders(data)

	var msg HeldMessage
	msg.ID = id
	msg.Forwarder = control[hdrForwarder]
	msg.Recipient = control[hdrRecipient]
	msg.Sender = control[hdrSender]
	msg.Reasons = splitReasons(control[hdrReasons])
	msg.HeldAt = info.ModTime()
	msg.Size = info.Size()
	return msg, nil
}
// Release re-injects the held copy to its original external recipient (operator
// decided it was a false positive), then removes it. The message is removed
// only after a successful re-injection, so a sendmail failure never loses mail.
//
// A held file without a recipient control header is refused: re-injecting it
// would hand the mailer an empty recipient, and on the off chance that call
// succeeded the held copy would then be deleted without reaching anyone.
func (q *Quarantine) Release(id string) error {
	_, path := q.locate(id)
	if path == "" {
		return fmt.Errorf("held message %q not found", id)
	}
	data, _, err := readMessageFile(path)
	if err != nil {
		return err
	}
	hdr := parseControlHeaders(data)
	recipient := hdr[hdrRecipient]
	if recipient == "" {
		return fmt.Errorf("held message %q has no %s header; refusing to re-inject", id, hdrRecipient)
	}
	// Control headers are CSM-internal and must not reach the recipient.
	clean := stripControlHeaders(data)
	if err := q.sendmail(hdr[hdrSender], recipient, clean); err != nil {
		return fmt.Errorf("re-injecting held message: %w", err)
	}
	return os.Remove(path)
}
// Delete removes the held copy with the given id without delivering it. It
// fails when no regular file with that id exists in new/ or cur/.
func (q *Quarantine) Delete(id string) error {
	if _, path := q.locate(id); path != "" {
		return os.Remove(path)
	}
	return fmt.Errorf("held message %q not found", id)
}
// PruneOlderThan deletes regular files in new/ and cur/ whose modification
// time is more than maxAge in the past and returns how many were removed. A
// missing subdirectory is skipped; any other directory read error stops the
// sweep and is returned along with the count so far. Entries that cannot be
// inspected or removed are left alone and not counted.
func (q *Quarantine) PruneOlderThan(maxAge time.Duration) (int, error) {
	var (
		threshold = time.Now().Add(-maxAge)
		count     int
	)
	for _, name := range [...]string{"new", "cur"} {
		dirPath := q.sub(name)
		entries, err := os.ReadDir(dirPath)
		switch {
		case err == nil:
		case os.IsNotExist(err):
			continue
		default:
			return count, err
		}
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			info, infoErr := entry.Info()
			if infoErr != nil || !info.Mode().IsRegular() || !info.ModTime().Before(threshold) {
				continue
			}
			if os.Remove(filepath.Join(dirPath, entry.Name())) == nil {
				count++
			}
		}
	}
	return count, nil
}
// CountsByForwarder tallies the currently held copies per forwarder
// address, as reported by List.
func (q *Quarantine) CountsByForwarder() (map[string]int, error) {
	held, err := q.List()
	if err != nil {
		return nil, err
	}
	tally := map[string]int{}
	for i := range held {
		tally[held[i].Forwarder]++
	}
	return tally, nil
}
// pathOf resolves a held message id to its on-disk path. The result is ""
// when no such message exists.
func (q *Quarantine) pathOf(id string) string {
	_, found := q.locate(id)
	return found
}
// locate finds the held message id in new/ or cur/ (checked in that order)
// and returns the subdirectory name and full path, or two empty strings when
// it is absent. Only the base name of id is used, which defends against path
// traversal in a caller-supplied id, and only regular files qualify
// (symlinks and directories are ignored).
func (q *Quarantine) locate(id string) (dir, path string) {
	name := filepath.Base(id)
	switch name {
	case "", ".", "..", string(filepath.Separator):
		return "", ""
	}
	for _, sub := range [...]string{"new", "cur"} {
		candidate := filepath.Join(q.sub(sub), name)
		if info, err := os.Lstat(candidate); err == nil && info.Mode().IsRegular() {
			return sub, candidate
		}
	}
	return "", ""
}
// readMessageFile opens path without following symlinks and returns its
// full contents together with the FileInfo of the opened descriptor.
// Anything that is not a regular file is rejected.
func readMessageFile(path string) ([]byte, os.FileInfo, error) {
	const flags = os.O_RDONLY | syscall.O_NOFOLLOW
	// #nosec G304 -- path is a located Maildir entry under the CSM-owned base;
	// O_NOFOLLOW rejects symlink entries and symlink swaps before reading.
	fh, err := os.OpenFile(path, flags, 0)
	if err != nil {
		return nil, nil, err
	}
	defer fh.Close()

	// Stat the descriptor, not the path, so the check applies to the file
	// that is actually read.
	meta, err := fh.Stat()
	switch {
	case err != nil:
		return nil, nil, err
	case !meta.Mode().IsRegular():
		return nil, nil, fmt.Errorf("held message %q is not a regular file", filepath.Base(path))
	}

	contents, err := io.ReadAll(fh)
	if err != nil {
		return nil, nil, err
	}
	return contents, meta, nil
}
// headerValue flattens a control-header value onto one line: every CR and
// LF is deleted (so a hostile address cannot smuggle in extra headers) and
// surrounding whitespace is trimmed.
func headerValue(v string) string {
	flattened := strings.NewReplacer("\r", "", "\n", "").Replace(v)
	return strings.TrimSpace(flattened)
}
// splitReasons parses a comma-separated reason list. Each item is trimmed
// and empty items are dropped. A blank input yields nil.
func splitReasons(v string) []string {
	if strings.TrimSpace(v) == "" {
		return nil
	}
	reasons := make([]string, 0, strings.Count(v, ",")+1)
	for rest, more := v, true; more; {
		var field string
		field, rest, more = strings.Cut(rest, ",")
		if field = strings.TrimSpace(field); field != "" {
			reasons = append(reasons, field)
		}
	}
	return reasons
}
// headerBlockEnd locates the end of the header block: the offset just past
// the first blank line (a line holding only CR/LF), or len(data) when no
// such line exists. If the very first line is blank the header block is
// empty and everything after it is body. Scanning line by line, instead of
// searching for "\n\n", keeps parsing and stripping in agreement on
// degenerate inputs.
func headerBlockEnd(data []byte) int {
	pos := 0
	for pos < len(data) {
		rel := bytes.IndexByte(data[pos:], '\n')
		if rel < 0 {
			// Unterminated last line: everything is header, there is no body.
			break
		}
		next := pos + rel + 1
		if len(strings.TrimRight(string(data[pos:next]), "\r\n")) == 0 {
			return next
		}
		pos = next
	}
	return len(data)
}
// parseControlHeaders extracts the X-CSM-* control headers from the header
// block of a held message. Known headers are keyed by their canonical name,
// other X-CSM-* headers by the name as written. When a header occurs more
// than once the first occurrence wins.
//
// Lines that start with a space or tab are folded continuations of the
// previous header, not header fields, and are ignored. Without that check an
// indented "X-CSM-Recipient: ..." tucked into a continuation of some other
// header (content the message author controls) would be read as control
// metadata. stripControlHeaders treats such lines as continuations as well.
func parseControlHeaders(data []byte) map[string]string {
	out := map[string]string{}
	header := data[:headerBlockEnd(data)]
	for _, line := range strings.Split(string(header), "\n") {
		line = strings.TrimRight(line, "\r")
		if line == "" || line[0] == ' ' || line[0] == '\t' {
			continue
		}
		colon := strings.IndexByte(line, ':')
		if colon < 0 {
			continue
		}
		key, ok := canonicalControlHeader(line[:colon])
		if !ok {
			continue
		}
		if _, exists := out[key]; exists {
			continue
		}
		out[key] = headerValue(line[colon+1:])
	}
	return out
}
// canonicalControlHeader maps a header name to the key used for it in the
// parsed control-header map. The four known control headers match
// case-insensitively and map to their canonical spelling; any other name
// with the X-CSM- prefix is returned trimmed but otherwise as written.
// ok is false for everything else.
func canonicalControlHeader(name string) (string, bool) {
	trimmed := strings.TrimSpace(name)
	for _, known := range [...]string{hdrForwarder, hdrRecipient, hdrSender, hdrReasons} {
		if strings.EqualFold(trimmed, known) {
			return known, true
		}
	}
	if hasControlHeaderPrefix(trimmed) {
		return trimmed, true
	}
	return "", false
}
// hasControlHeaderPrefix reports whether line begins with the X-CSM-
// control-header prefix, ignoring case.
func hasControlHeaderPrefix(line string) bool {
	n := len(hdrPrefix)
	if len(line) < n {
		return false
	}
	return strings.EqualFold(line[:n], hdrPrefix)
}
// stripControlHeaders returns data with every X-CSM-* header removed,
// together with any folded continuation lines belonging to those headers.
// Filtering stops at the first blank line: the separator and the body after
// it are copied through byte-for-byte, so the recipient sees exactly the
// original message.
func stripControlHeaders(data []byte) []byte {
	var out bytes.Buffer
	skipping := false // true while inside a control header and its continuations
	for pos := 0; pos < len(data); {
		end := len(data)
		if rel := bytes.IndexByte(data[pos:], '\n'); rel >= 0 {
			end = pos + rel + 1
		}
		line := data[pos:end]
		pos = end

		text := strings.TrimRight(string(line), "\r\n")
		if text == "" {
			// Blank line ends the header block; keep it and the body verbatim.
			out.Write(line)
			out.Write(data[pos:])
			return out.Bytes()
		}
		// A line starting with whitespace continues the previous header and
		// shares its fate; any other line starts a new header.
		if first := line[0]; first != ' ' && first != '\t' {
			skipping = hasControlHeaderPrefix(text)
		}
		if !skipping {
			out.Write(line)
		}
	}
	return out.Bytes()
}
// runSendmail is the default re-injector: it pipes body into the platform
// sendmail binary with the given envelope sender and recipient, giving up
// after 30 seconds. Both addresses are flattened to a single line first.
func runSendmail(sender, recipient string, body []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// -i: do not treat lone "." as end; -f: envelope sender ("" yields null
	// sender); -- terminates options so a hostile recipient cannot be a flag.
	args := []string{"-i", "-f", headerValue(sender), "--", headerValue(recipient)}
	cmd := exec.CommandContext(ctx, sendmailPath, args...) // #nosec G204 -- recipient guarded by --, args are envelope addresses
	cmd.Stdin = bytes.NewReader(body)
	return cmd.Run()
}
// sendmailPath is the binary runSendmail executes to re-inject released
// messages.
const sendmailPath = "/usr/sbin/sendmail"
package maillog
import (
"errors"
"fmt"
"os"
"github.com/pidginhost/csm/internal/config"
)
// New picks the mail-log Reader for cfg.Source. The file path is cfg.File
// when set, otherwise platformDefaultFile (computed by the caller from
// internal/platform.Detect()); pass "" to skip the platform default, which
// is useful in tests.
//
// auto - use the file when it exists; otherwise fall back to the journal,
// which requires units to be configured.
// file - the file must exist, otherwise an error is returned.
// journal - units must be configured, otherwise an error is returned.
//
// Any other source value is an error.
func New(cfg config.MailLogsConfig, platformDefaultFile string) (Reader, error) {
	path := platformDefaultFile
	if cfg.File != "" {
		path = cfg.File
	}
	haveUnits := len(cfg.Units) > 0

	switch cfg.Source {
	case "file":
		_, err := os.Stat(path)
		if err == nil {
			return NewFileReader(path), nil
		}
		return nil, fmt.Errorf("mail_logs.source=file but %s: %w", path, err)
	case "journal":
		if haveUnits {
			return NewJournalReader(cfg.Units), nil
		}
		return nil, fmt.Errorf("mail_logs.source=journal requires units")
	case "auto":
		if path != "" {
			if _, err := os.Stat(path); err == nil {
				return NewFileReader(path), nil
			}
		}
		if haveUnits {
			return NewJournalReader(cfg.Units), nil
		}
		return nil, errors.New("mail_logs.source=auto: log file not found and no units configured for journal fallback")
	}
	return nil, fmt.Errorf("mail_logs.source=%q: unknown", cfg.Source)
}
package maillog
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"time"
)
// maxLogLineBytes caps a single mail-log line. Real syslog lines top
// out around 8 KB; 64 KB is generous yet bounded. Without this cap a
// malformed source could ship a multi-gigabyte "line" and turn the
// reader into an OOM vector. Lines over the cap are skipped by the
// reader loop, not delivered truncated.
const maxLogLineBytes = 64 * 1024
// defaultGoneGrace is how long the source path must stay missing before
// the reader declares the source gone. Long enough to ride out a
// logrotate create-delay (rename old -> create new), short enough that an
// operator notices a real syslog->journald migration quickly.
// NewFileReader installs it as the reader's grace period.
const defaultGoneGrace = 90 * time.Second
// FileReader tails a single log file. It uses a 2-second polling loop
// because rsyslog/syslog-ng don't reliably trigger inotify events on
// every line written, and periodic path re-stat checks for log rotation.
//
// On context cancel the reader closes the output channel and returns.
type FileReader struct {
// path is the log file being tailed; it is re-stat'ed to detect rotation.
path string
// onGone, when set, fires once when the source path has been missing
// continuously for goneGrace. A FileReader whose path vanishes mid-run
// (e.g. a syslog->journald migration) otherwise tails a dead fd
// silently; the callback lets the daemon surface a finding and mark the
// watcher unhealthy. onRestored fires after the path returns and the
// reader can use it again.
onGone func(error)
onRestored func()
// goneGrace is the continuous-missing duration that triggers onGone.
goneGrace time.Duration
// nowFn supplies the current time; NewFileReader sets it to time.Now.
nowFn func() time.Time
// gone-tracking state, touched only by the single loop goroutine.
// firstMissing is when the current missing streak began (zero when the
// path was last seen present), goneFired records that the gone
// transition has happened, and restoreReady marks that the path came
// back after it and recordRestored should fire onRestored.
firstMissing time.Time
goneFired bool
restoreReady bool
}
// NewFileReader returns a FileReader that tails path, using the default
// gone-grace period and the wall clock. No callbacks are installed.
func NewFileReader(path string) *FileReader {
	r := new(FileReader)
	r.path = path
	r.goneGrace = defaultGoneGrace
	r.nowFn = time.Now
	return r
}
// SetOnGone registers fn to be called once the source path has stayed
// missing for the grace period; fn receives the stat error. Must be called
// before Run.
func (r *FileReader) SetOnGone(fn func(error)) {
	r.onGone = fn
}
// SetOnRestored registers fn to be called once a source path that was
// declared gone has come back and the reader is using it again. Must be
// called before Run.
func (r *FileReader) SetOnRestored(fn func()) {
	r.onRestored = fn
}
// recordStat feeds one stat outcome into the missing-source state machine.
// While the path keeps being reported missing it remembers when the streak
// began and, once the streak reaches the grace period, marks the source gone
// and calls onGone with missErr (at most once per outage). When the path is
// reported present again after a gone transition it arms the restore
// callback, which recordRestored delivers.
func (r *FileReader) recordStat(missing bool, missErr error) {
	switch {
	case !missing:
		wasMissing := !r.firstMissing.IsZero()
		if wasMissing && r.goneFired {
			r.restoreReady = true
		}
		r.firstMissing = time.Time{}
	default:
		now := r.nowFn()
		if r.firstMissing.IsZero() {
			r.firstMissing = now
		}
		// Already reported, or a restore is still pending: stay quiet.
		if r.goneFired || r.restoreReady {
			return
		}
		if now.Sub(r.firstMissing) < r.goneGrace {
			return
		}
		r.goneFired = true
		if r.onGone != nil {
			r.onGone(missErr)
		}
	}
}
// recordRestored completes a pending restore: if recordStat armed one, the
// gone state is cleared and onRestored (when set) is called. Otherwise it
// does nothing.
func (r *FileReader) recordRestored() {
	if r.restoreReady {
		r.restoreReady, r.goneFired = false, false
		if cb := r.onRestored; cb != nil {
			cb()
		}
	}
}
// Run opens the log file positioned at its end, starts the polling
// goroutine, and returns the channel on which new lines are delivered. An
// error is returned only when the file cannot be opened; errors during
// polling are reported on stderr on a best-effort basis and do not stop the
// reader. The channel is closed when ctx is cancelled.
func (r *FileReader) Run(ctx context.Context) (<-chan Line, error) {
	file, buffered, ino, err := r.open()
	if err != nil {
		return nil, fmt.Errorf("open %s: %w", r.path, err)
	}
	lines := make(chan Line, 64)
	go r.loop(ctx, lines, file, buffered, ino)
	return lines, nil
}
// readBoundedLine returns the next line from r, including its trailing
// '\n' when one was read. At most maxBytes are kept: a longer line is cut
// at maxBytes, the rest of it is consumed so the following call starts on
// the next line, and truncated is reported as true. Callers must treat
// truncated records as untrusted and skip them. The error is whatever the
// final read returned (for example io.EOF together with an unterminated
// tail).
func readBoundedLine(r *bufio.Reader, maxBytes int) (string, bool, error) {
	var (
		line     strings.Builder
		overflow bool
	)
	for {
		piece, err := r.ReadSlice('\n')
		// Once over the cap, further pieces are only drained, never stored.
		if len(piece) > 0 && !overflow {
			if free := maxBytes - line.Len(); len(piece) <= free {
				line.Write(piece)
			} else {
				if free > 0 {
					line.Write(piece[:free])
				}
				overflow = true
			}
		}
		// ErrBufferFull only means the line is longer than the reader's
		// buffer; keep reading until the newline or a real error.
		if !errors.Is(err, bufio.ErrBufferFull) {
			return line.String(), overflow, err
		}
	}
}
// open opens the log positioned at its current end, so only lines written
// from now on are read.
func (r *FileReader) open() (*os.File, *bufio.Reader, uint64, error) {
	const fromEnd = io.SeekEnd
	return r.openAt(0, fromEnd)
}
// openRotated opens the log positioned at its start; used after a rotation
// so the new file is read from its first line.
func (r *FileReader) openRotated() (*os.File, *bufio.Reader, uint64, error) {
	const fromStart = io.SeekStart
	return r.openAt(0, fromStart)
}
// openAt opens the reader's path, seeks to offset relative to whence, and
// returns the file, a buffered reader over it, and the inode of the opened
// file. The file is closed again if seeking or stat'ing fails.
func (r *FileReader) openAt(offset int64, whence int) (*os.File, *bufio.Reader, uint64, error) {
	fh, err := os.Open(r.path) // #nosec G304 -- operator-supplied log path
	if err != nil {
		return nil, nil, 0, err
	}
	abort := func(cause error) (*os.File, *bufio.Reader, uint64, error) {
		_ = fh.Close()
		return nil, nil, 0, cause
	}
	if _, err = fh.Seek(offset, whence); err != nil {
		return abort(err)
	}
	info, err := fh.Stat()
	if err != nil {
		return abort(err)
	}
	return fh, bufio.NewReader(fh), inode(info), nil
}
// loop is the reader goroutine started by Run. Every 2 seconds it drains all
// complete lines currently available from the file and sends them on out;
// once a minute, and whenever a read hits EOF or an error, it re-stats the
// path to detect rotation or disappearance. It owns f and reader for its
// whole lifetime, closes the current file on exit, and closes out when ctx
// is cancelled.
//
// Note: when a read returns an error (including io.EOF) any text returned
// with it is not sent, so an unterminated tail at end of file is dropped;
// the rest of that line is delivered as its own line once it is written.
func (r *FileReader) loop(ctx context.Context, out chan<- Line, f *os.File, reader *bufio.Reader, lastIno uint64) {
defer close(out)
// The closure reads f at exit time, so it closes whichever file is current
// after any rotation-driven reopen.
defer func() {
if f != nil {
_ = f.Close()
}
}()
poll := time.NewTicker(2 * time.Second)
defer poll.Stop()
// Rotation safety-net: even if every poll tick finds zero EOFs (a
// continuously-active log), still re-stat once per minute so a
// rotation that happens during a sustained write burst is caught
// without waiting for the next idle period.
rotate := time.NewTicker(time.Minute)
defer rotate.Stop()
// reopenOnRotate re-stats the path, feeds the result to the gone/restored
// state machine, and switches f/reader/lastIno to the new file (read from
// its start) when the inode differs from the one being tailed.
reopenOnRotate := func() {
st, err := os.Stat(r.path)
if err != nil {
// Track persistent disappearance so a source that vanishes
// mid-run (syslog->journald migration) surfaces instead of
// tailing a dead fd in silence.
// Only a not-exist error counts as missing; any other stat error
// is recorded as "not missing".
r.recordStat(os.IsNotExist(err), err)
return
}
r.recordStat(false, nil)
if inode(st) == lastIno {
// Same file as before: nothing to reopen.
r.recordRestored()
return
}
nf, nr, ino, err := r.openRotated()
if err != nil {
// Keep tailing the old descriptor; the next re-stat retries.
fmt.Fprintf(os.Stderr, "maillog file_reader %s reopen: %v\n", r.path, err)
return
}
_ = f.Close()
f = nf
reader = nr
lastIno = ino
r.recordRestored()
}
for {
select {
case <-ctx.Done():
return
case <-poll.C:
// Drain everything available; the inner loop ends at the first read
// error (normally EOF).
for {
line, truncated, err := readBoundedLine(reader, maxLogLineBytes)
if err != nil {
if truncated {
fmt.Fprintf(os.Stderr, "maillog file_reader %s: oversized line skipped at %d bytes\n", r.path, maxLogLineBytes)
}
// Tight rotation detection: every time the reader
// hits EOF or any I/O error we re-stat the path so a
// post-rotate log is picked up by the next poll tick
// rather than waiting for the safety-net ticker.
reopenOnRotate()
break
}
if truncated {
fmt.Fprintf(os.Stderr, "maillog file_reader %s: oversized line skipped at %d bytes\n", r.path, maxLogLineBytes)
continue
}
// Block on the consumer, but never past cancellation.
select {
case out <- Line{Source: "file", Message: line}:
case <-ctx.Done():
return
}
}
case <-rotate.C:
reopenOnRotate()
}
}
}
//go:build unix
package maillog
import (
"os"
"syscall"
)
// inode returns the inode number carried by fi's underlying
// *syscall.Stat_t, or 0 when fi.Sys() is not of that type.
func inode(fi os.FileInfo) uint64 {
	raw, isStat := fi.Sys().(*syscall.Stat_t)
	if !isStat {
		return 0
	}
	return raw.Ino
}
//go:build !linux || !journal
package maillog
import (
"context"
"errors"
)
// JournalReader is a no-op stub on builds without the `journal` tag.
// The factory (T4) uses this to produce a clear error rather than
// silently downgrading to file mode when the operator explicitly asked
// for journald. The stub carries no state.
type JournalReader struct{}
// NewJournalReader returns an empty stub reader; its argument is
// ignored. The signature matches the linux+journal build's constructor
// so the factory and tests compile identically on default builds.
func NewJournalReader(_ []string) *JournalReader {
	return new(JournalReader)
}
// JournalSupported reports whether a journal reader is compiled into
// this build; the stub build always answers false.
func JournalSupported() bool {
	const supported = false
	return supported
}
// ErrJournalUnsupported is returned when the build was produced without
// the `journal` tag (default builds). It is a sentinel value, so callers
// can match it with errors.Is.
var ErrJournalUnsupported = errors.New("journal reader not compiled in (build with JOURNAL=1)")
// Run never starts a reader on stub builds: it hands back a nil channel
// together with ErrJournalUnsupported, without touching the context.
func (*JournalReader) Run(_ context.Context) (<-chan Line, error) {
	var lines <-chan Line // stays nil: the stub never produces lines
	return lines, ErrJournalUnsupported
}
// Package metrics is CSM's local OpenMetrics implementation. It exists
// so the daemon can expose a `/metrics` endpoint (ROADMAP item 4)
// without pulling in `github.com/prometheus/client_golang`, which
// would add ~20 transitive dependencies for the handful of counters,
// gauges, and histograms this project actually needs.
//
// The surface is intentionally narrow: Counter, Gauge, Histogram, and
// their labelled vector siblings. No summaries, no collectors, no
// custom exposition formats. Metric objects are safe for concurrent
// use; registering the same name twice panics.
package metrics
import (
"errors"
"fmt"
"io"
"math"
"sort"
"strings"
"sync"
"sync/atomic"
)
// metricType discriminates the OpenMetrics TYPE line.
type metricType string
// The type tokens written after "# TYPE <name> " by writeMeta.
const (
typeCounter metricType = "counter"
typeGauge metricType = "gauge"
typeHistogram metricType = "histogram"
)
// collectable is any metric that can write its exposition form to w.
// Kept unexported; registry drives it. Because writeTo is unexported,
// only types declared in this package can satisfy the interface.
type collectable interface {
// writeTo renders the metric's metadata and samples into w.
writeTo(w *bufferedWriter)
}
// Registry holds a set of metrics plus callback-driven gauges and
// counters whose values are computed by calling the registered function
// at scrape time. Scraping takes a read lock; registration takes a
// write lock.
type Registry struct {
mu sync.RWMutex // guards every field below
entries []registered // metrics added through MustRegister
names map[string]struct{} // every registered name (entries and hooks); used for duplicate detection
gaugeHooks []gaugeHook // callbacks added through RegisterGaugeFunc
counterHooks []counterHook // callbacks added through RegisterCounterFunc
}
// registered pairs a metric with the name it was registered under; the
// name is what WriteOpenMetrics sorts entries by.
type registered struct {
name string
c collectable
}
// gaugeHook is a gauge whose value is obtained by calling fn each time
// the registry is scraped.
type gaugeHook struct {
name string // metric name used in the HELP/TYPE lines and the sample
help string // HELP text
fn func() float64 // invoked once per scrape
}
// counterHook is a counter whose value is obtained by calling fn each
// time the registry is scraped.
type counterHook struct {
name string // metric name used in the HELP/TYPE lines and the sample
help string // HELP text
fn func() float64 // invoked once per scrape
}
// NewRegistry builds a Registry with no metrics and an initialised name
// set, ready for registration.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.names = make(map[string]struct{})
	return reg
}
// MustRegister adds c to the registry under name. A name that is
// already taken (by a metric or by a callback hook) is treated as a
// programming error and panics; daemons are expected to register each
// metric once at startup.
func (r *Registry) MustRegister(name string, c Collector) {
	r.mu.Lock()
	defer r.mu.Unlock()

	_, taken := r.names[name]
	if taken {
		panic(fmt.Sprintf("metrics: duplicate registration %q", name))
	}
	entry := registered{name: name, c: c}
	r.names[name] = struct{}{}
	r.entries = append(r.entries, entry)
}
// RegisterGaugeFunc registers a gauge whose sample is produced by
// calling fn on every scrape, for values (such as an on-disk file size)
// that would be wrong if cached. Panics when name is already registered.
func (r *Registry) RegisterGaugeFunc(name, help string, fn func() float64) {
	r.mu.Lock()
	defer r.mu.Unlock()

	_, taken := r.names[name]
	if taken {
		panic(fmt.Sprintf("metrics: duplicate registration %q", name))
	}
	hook := gaugeHook{name: name, help: help, fn: fn}
	r.names[name] = struct{}{}
	r.gaugeHooks = append(r.gaugeHooks, hook)
}
// RegisterCounterFunc registers a counter whose sample is produced by
// calling fn on every scrape. The registry does not check the values:
// keeping them monotonically non-decreasing is the caller's
// responsibility. Panics when name is already registered.
func (r *Registry) RegisterCounterFunc(name, help string, fn func() float64) {
	r.mu.Lock()
	defer r.mu.Unlock()

	_, taken := r.names[name]
	if taken {
		panic(fmt.Sprintf("metrics: duplicate registration %q", name))
	}
	hook := counterHook{name: name, help: help, fn: fn}
	r.names[name] = struct{}{}
	r.counterHooks = append(r.counterHooks, hook)
}
// WriteOpenMetrics renders a scrape in the OpenMetrics text format.
// The output ends with the `# EOF` marker that Prometheus requires
// when served as `Content-Type: application/openmetrics-text`.
//
// The registry lock is held only long enough to snapshot the registered
// metrics and hooks. Rendering (which calls user-supplied hook functions
// and writes to w) runs without the lock, so a hook that registers
// another metric cannot self-deadlock on the non-reentrant RWMutex and
// a slow writer cannot stall registration.
//
// It returns the first error reported by w, if any.
func (r *Registry) WriteOpenMetrics(w io.Writer) error {
	r.mu.RLock()
	entries := append([]registered(nil), r.entries...)
	gauges := append([]gaugeHook(nil), r.gaugeHooks...)
	counters := append([]counterHook(nil), r.counterHooks...)
	r.mu.RUnlock()

	bw := newBufferedWriter(w)

	// Stable ordering matters for diffs and for human reading. Sort
	// by name at scrape time; registration order is not stable across
	// restarts because goroutines may register concurrently.
	sort.Slice(entries, func(i, j int) bool { return entries[i].name < entries[j].name })
	for _, e := range entries {
		e.c.writeTo(bw)
	}

	sort.Slice(gauges, func(i, j int) bool { return gauges[i].name < gauges[j].name })
	for _, h := range gauges {
		bw.writeMeta(h.name, h.help, typeGauge)
		bw.writeSample(h.name, nil, h.fn())
	}

	sort.Slice(counters, func(i, j int) bool { return counters[i].name < counters[j].name })
	for _, h := range counters {
		bw.writeMeta(h.name, h.help, typeCounter)
		bw.writeSample(h.name, nil, h.fn())
	}

	bw.writeEOF()
	return bw.err
}
// -----------------------------------------------------------------------
// Counter
// -----------------------------------------------------------------------
// Counter is a monotonically non-decreasing float value. All access to
// the value goes through sync/atomic operations on bits.
type Counter struct {
name string
help string
// Stored as bits of a float64 so Add can safely work on values
// that never need to fit into int64 (e.g., byte counts).
bits uint64
}
// NewCounter returns a zero-valued Counter carrying the given name and
// help text. The counter is not registered anywhere; pass it to
// Registry.MustRegister(name, c) to expose it.
func NewCounter(name, help string) *Counter {
	c := new(Counter)
	c.name, c.help = name, help
	return c
}
// Add increases the counter by v using a compare-and-swap retry loop on
// the float64 bit pattern. A negative v violates the monotonic contract
// and panics.
func (c *Counter) Add(v float64) {
	if v < 0 {
		panic(fmt.Sprintf("metrics: counter %q Add(%g): negative delta", c.name, v))
	}
	for swapped := false; !swapped; {
		prev := atomic.LoadUint64(&c.bits)
		next := math.Float64bits(math.Float64frombits(prev) + v)
		swapped = atomic.CompareAndSwapUint64(&c.bits, prev, next)
	}
}
// Inc bumps the counter by exactly one.
func (c *Counter) Inc() {
	const one = 1
	c.Add(one)
}
// Value atomically reads the counter and returns it as a float64.
// Useful in tests.
func (c *Counter) Value() float64 {
	raw := atomic.LoadUint64(&c.bits)
	return math.Float64frombits(raw)
}
// writeTo emits the counter's HELP/TYPE lines followed by one
// unlabelled sample holding the current value.
func (c *Counter) writeTo(w *bufferedWriter) {
	w.writeMeta(c.name, c.help, typeCounter)
	current := c.Value()
	w.writeSample(c.name, nil, current)
}
// -----------------------------------------------------------------------
// Gauge
// -----------------------------------------------------------------------
// Gauge is a point-in-time numeric value that can go up or down.
type Gauge struct {
name string
help string
bits uint64 // float64 bit pattern; read and written via sync/atomic
}
// NewGauge returns a zero-valued Gauge carrying the given name and help
// text. The gauge is not registered anywhere.
func NewGauge(name, help string) *Gauge {
	g := new(Gauge)
	g.name, g.help = name, help
	return g
}
// Set atomically overwrites the gauge with v.
func (g *Gauge) Set(v float64) {
	encoded := math.Float64bits(v)
	atomic.StoreUint64(&g.bits, encoded)
}
// Add shifts the gauge by v, which may be negative, using a
// compare-and-swap retry loop on the float64 bit pattern.
func (g *Gauge) Add(v float64) {
	for swapped := false; !swapped; {
		prev := atomic.LoadUint64(&g.bits)
		next := math.Float64bits(math.Float64frombits(prev) + v)
		swapped = atomic.CompareAndSwapUint64(&g.bits, prev, next)
	}
}
// Inc raises the gauge by one.
func (g *Gauge) Inc() {
	g.Add(1)
}

// Dec lowers the gauge by one.
func (g *Gauge) Dec() {
	g.Add(-1)
}
// Value atomically reads the gauge and returns it as a float64.
func (g *Gauge) Value() float64 {
	raw := atomic.LoadUint64(&g.bits)
	return math.Float64frombits(raw)
}
// writeTo emits the gauge's HELP/TYPE lines followed by one unlabelled
// sample holding the current value.
func (g *Gauge) writeTo(w *bufferedWriter) {
	w.writeMeta(g.name, g.help, typeGauge)
	current := g.Value()
	w.writeSample(g.name, nil, current)
}
// -----------------------------------------------------------------------
// Histogram
// -----------------------------------------------------------------------
// Histogram is a cumulative histogram with fixed upper-bound buckets.
// Buckets must be strictly increasing; the implicit +Inf bucket is
// appended automatically.
type Histogram struct {
name string
help string
upper []float64 // finite bucket upper bounds, in the order supplied to the constructor
bucketCnt []uint64 // atomic counters per bucket (last entry is +Inf)
sum uint64 // atomic float64 bits
count uint64 // atomic total count
}
// NewHistogram returns an unregistered Histogram whose finite buckets
// are a private copy of upperBounds; one extra counter is allocated for
// the implicit +Inf bucket. Panics when any bound is less than or equal
// to the one before it.
func NewHistogram(name, help string, upperBounds []float64) *Histogram {
	for i := 1; i < len(upperBounds); i++ {
		prev, cur := upperBounds[i-1], upperBounds[i]
		if cur <= prev {
			panic(fmt.Sprintf("metrics: histogram %q bounds must be strictly increasing", name))
		}
	}
	bounds := make([]float64, len(upperBounds))
	copy(bounds, upperBounds)

	h := &Histogram{name: name, help: help}
	h.upper = bounds
	h.bucketCnt = make([]uint64, len(bounds)+1) // one extra for +Inf
	return h
}
// Observe records one sample: every finite bucket whose bound is at
// least v is incremented, the +Inf bucket and the total count are always
// incremented, and v is added to the running sum with a compare-and-swap
// retry loop.
func (h *Histogram) Observe(v float64) {
	infIdx := len(h.upper)
	for i := 0; i < infIdx; i++ {
		if h.upper[i] >= v {
			atomic.AddUint64(&h.bucketCnt[i], 1)
		}
	}
	// The +Inf bucket counts every sample (cumulative semantics).
	atomic.AddUint64(&h.bucketCnt[infIdx], 1)
	atomic.AddUint64(&h.count, 1)

	for swapped := false; !swapped; {
		prev := atomic.LoadUint64(&h.sum)
		next := math.Float64bits(math.Float64frombits(prev) + v)
		swapped = atomic.CompareAndSwapUint64(&h.sum, prev, next)
	}
}
// writeTo emits the histogram's HELP/TYPE lines, one `_bucket` sample
// per finite bound (labelled with `le`), the `+Inf` bucket, and finally
// the `_sum` and `_count` samples.
func (h *Histogram) writeTo(w *bufferedWriter) {
	w.writeMeta(h.name, h.help, typeHistogram)

	bucketName := h.name + "_bucket"
	for i := range h.upper {
		le := []labelPair{{"le", formatFloat(h.upper[i])}}
		hits := atomic.LoadUint64(&h.bucketCnt[i])
		w.writeSample(bucketName, le, float64(hits))
	}
	infHits := atomic.LoadUint64(&h.bucketCnt[len(h.upper)])
	w.writeSample(bucketName, []labelPair{{"le", "+Inf"}}, float64(infHits))

	total := math.Float64frombits(atomic.LoadUint64(&h.sum))
	w.writeSample(h.name+"_sum", nil, total)
	samples := atomic.LoadUint64(&h.count)
	w.writeSample(h.name+"_count", nil, float64(samples))
}
// -----------------------------------------------------------------------
// Labelled variants (vectors)
// -----------------------------------------------------------------------
// CounterVec is a family of counters indexed by a fixed set of label
// keys. Label values are provided per sample.
type CounterVec struct {
name string
help string
labelKeys []string // label names, in the order With expects values
mu sync.Mutex // guards children, keys and maxChildren
children map[string]*Counter // keyed by the joined label values
keys []string // insertion-ordered; stable for scrape ordering within a vec
maxChildren int // 0 = unlimited
}
// defaultVecCardinalityCap bounds the per-vec child count. Operators
// can raise the limit per metric via SetMaxChildren. The cap exists
// because user-controlled labels (per-IP, per-domain) can otherwise
// grow the children map without bound and exhaust memory. The vec
// constructors install this value as the initial maxChildren.
const defaultVecCardinalityCap = 1000
// overflowLabelValue collapses all label values past the cap into a
// single sentinel bucket so cardinality stays bounded. Every label of
// the overflow child carries this value.
const overflowLabelValue = "_overflow_"
// NewCounterVec returns a labelled counter family with the default
// cardinality cap and a private copy of labelKeys. An empty labelKeys
// panics; use NewCounter for an unlabelled counter.
func NewCounterVec(name, help string, labelKeys []string) *CounterVec {
	if len(labelKeys) == 0 {
		panic(fmt.Sprintf("metrics: counter vec %q needs at least one label key", name))
	}
	keysCopy := make([]string, len(labelKeys))
	copy(keysCopy, labelKeys)

	cv := &CounterVec{name: name, help: help}
	cv.labelKeys = keysCopy
	cv.children = make(map[string]*Counter)
	cv.maxChildren = defaultVecCardinalityCap
	return cv
}
// SetMaxChildren replaces the vec's cardinality cap with n. A value of
// 0 removes the cap altogether, which is only safe when label values
// come from a fixed enum the operator controls.
func (cv *CounterVec) SetMaxChildren(n int) {
	cv.mu.Lock()
	cv.maxChildren = n
	cv.mu.Unlock()
}
// ChildCount reports how many children the vec currently holds; the
// overflow sentinel child is counted once it exists. Exposed for tests
// and operator health checks.
func (cv *CounterVec) ChildCount() int {
	cv.mu.Lock()
	n := len(cv.children)
	cv.mu.Unlock()
	return n
}
// With looks up, or lazily creates, the child counter for the given
// label values, which must be supplied in the same order as the
// labelKeys passed to NewCounterVec (a count mismatch panics). An
// existing child is always returned as-is. When a new child would
// exceed the cardinality cap, the lookup is redirected to a single
// "_overflow_" child so the map stays bounded.
func (cv *CounterVec) With(values ...string) *Counter {
	if want := len(cv.labelKeys); len(values) != want {
		panic(fmt.Sprintf("metrics: counter vec %q: got %d label values, want %d", cv.name, len(values), want))
	}
	key := joinLabelValues(values)

	cv.mu.Lock()
	defer cv.mu.Unlock()

	if existing, found := cv.children[key]; found {
		return existing
	}
	atCap := cv.maxChildren > 0 && len(cv.children) >= explicitChildLimit(cv.maxChildren)
	if atCap {
		key = joinLabelValues(overflowLabelValuesForArity(len(cv.labelKeys)))
		if overflow, found := cv.children[key]; found {
			return overflow
		}
	}
	child := &Counter{name: cv.name, help: cv.help}
	cv.keys = append(cv.keys, key)
	cv.children[key] = child
	return child
}
// overflowLabelValuesForArity builds an n-element slice in which every
// entry is the overflow sentinel, giving the cap path one canonical key
// for any number of labels.
func overflowLabelValuesForArity(n int) []string {
	sentinels := make([]string, n)
	for i := 0; i < n; i++ {
		sentinels[i] = overflowLabelValue
	}
	return sentinels
}
// explicitChildLimit returns how many children may be created from
// caller-supplied label values under a cap of maxChildren: one less than
// the cap, or 0 when the cap is 1 or lower.
func explicitChildLimit(maxChildren int) int {
	if maxChildren > 1 {
		return maxChildren - 1
	}
	return 0
}
// writeTo emits the family's HELP/TYPE lines and then one sample per
// child, ordered by the joined label-value key. The child set is copied
// under the vec lock and rendered after the lock is released.
func (cv *CounterVec) writeTo(w *bufferedWriter) {
	w.writeMeta(cv.name, cv.help, typeCounter)

	cv.mu.Lock()
	ordered := make([]string, len(cv.keys))
	copy(ordered, cv.keys)
	snapshot := make(map[string]*Counter, len(cv.children))
	for key, child := range cv.children {
		snapshot[key] = child
	}
	cv.mu.Unlock()

	sort.Strings(ordered)
	for _, key := range ordered {
		values := splitLabelValues(key)
		pairs := make([]labelPair, 0, len(cv.labelKeys))
		for i := range cv.labelKeys {
			pairs = append(pairs, labelPair{key: cv.labelKeys[i], value: values[i]})
		}
		w.writeSample(cv.name, pairs, snapshot[key].Value())
	}
}
// HistogramVec is the labelled variant of Histogram. All children
// share the same upper-bound set.
type HistogramVec struct {
name string
help string
labelKeys []string // label names, in the order With expects values
upper []float64 // bucket bounds handed to every child
mu sync.Mutex // guards children, keys and maxChildren
children map[string]*Histogram // keyed by the joined label values
keys []string // child keys in creation order
maxChildren int // cardinality cap; values <= 0 disable it
}
// NewHistogramVec returns a labelled histogram family with the default
// cardinality cap and private copies of labelKeys and upperBounds.
// Panics when labelKeys is empty or when any bound is less than or
// equal to the one before it.
func NewHistogramVec(name, help string, labelKeys []string, upperBounds []float64) *HistogramVec {
	if len(labelKeys) == 0 {
		panic(fmt.Sprintf("metrics: histogram vec %q needs at least one label key", name))
	}
	for i := 1; i < len(upperBounds); i++ {
		prev, cur := upperBounds[i-1], upperBounds[i]
		if cur <= prev {
			panic(fmt.Sprintf("metrics: histogram vec %q bounds must be strictly increasing", name))
		}
	}
	keysCopy := make([]string, len(labelKeys))
	copy(keysCopy, labelKeys)
	bounds := make([]float64, len(upperBounds))
	copy(bounds, upperBounds)

	hv := &HistogramVec{name: name, help: help}
	hv.labelKeys = keysCopy
	hv.upper = bounds
	hv.children = make(map[string]*Histogram)
	hv.maxChildren = defaultVecCardinalityCap
	return hv
}
// SetMaxChildren replaces the vec's cardinality cap with n.
func (hv *HistogramVec) SetMaxChildren(n int) {
	hv.mu.Lock()
	hv.maxChildren = n
	hv.mu.Unlock()
}
// ChildCount reports how many children the vec currently holds.
func (hv *HistogramVec) ChildCount() int {
	hv.mu.Lock()
	n := len(hv.children)
	hv.mu.Unlock()
	return n
}
// With looks up, or lazily creates, the child histogram for the given
// label values (a count mismatch against the label keys panics). New
// children share the vec's bucket bounds. When a new child would exceed
// the cardinality cap, the lookup is redirected to the single overflow
// child, exactly as CounterVec.With does.
func (hv *HistogramVec) With(values ...string) *Histogram {
	if want := len(hv.labelKeys); len(values) != want {
		panic(fmt.Sprintf("metrics: histogram vec %q: got %d label values, want %d", hv.name, len(values), want))
	}
	key := joinLabelValues(values)

	hv.mu.Lock()
	defer hv.mu.Unlock()

	if existing, found := hv.children[key]; found {
		return existing
	}
	atCap := hv.maxChildren > 0 && len(hv.children) >= explicitChildLimit(hv.maxChildren)
	if atCap {
		key = joinLabelValues(overflowLabelValuesForArity(len(hv.labelKeys)))
		if overflow, found := hv.children[key]; found {
			return overflow
		}
	}
	child := &Histogram{name: hv.name, help: hv.help, upper: hv.upper}
	child.bucketCnt = make([]uint64, len(hv.upper)+1)
	hv.keys = append(hv.keys, key)
	hv.children[key] = child
	return child
}
// writeTo emits the family's HELP/TYPE lines and then, for each child
// in joined-key order, its `_bucket` samples (finite bounds followed by
// `+Inf`), `_sum` and `_count`. Bucket samples carry the child's labels
// plus a trailing `le` label. The child set is copied under the vec
// lock and rendered after the lock is released.
func (hv *HistogramVec) writeTo(w *bufferedWriter) {
	w.writeMeta(hv.name, hv.help, typeHistogram)

	hv.mu.Lock()
	ordered := make([]string, len(hv.keys))
	copy(ordered, hv.keys)
	snapshot := make(map[string]*Histogram, len(hv.children))
	for key, child := range hv.children {
		snapshot[key] = child
	}
	hv.mu.Unlock()

	// withLE returns a fresh slice: base followed by an `le` label.
	withLE := func(base []labelPair, le string) []labelPair {
		out := make([]labelPair, 0, len(base)+1)
		out = append(out, base...)
		return append(out, labelPair{"le", le})
	}

	bucketName := hv.name + "_bucket"
	sort.Strings(ordered)
	for _, key := range ordered {
		values := splitLabelValues(key)
		base := make([]labelPair, 0, len(hv.labelKeys))
		for i := range hv.labelKeys {
			base = append(base, labelPair{key: hv.labelKeys[i], value: values[i]})
		}
		child := snapshot[key]
		for i := range child.upper {
			pairs := withLE(base, formatFloat(child.upper[i]))
			w.writeSample(bucketName, pairs, float64(atomicLoad(&child.bucketCnt[i])))
		}
		infPairs := withLE(base, "+Inf")
		w.writeSample(bucketName, infPairs, float64(atomicLoad(&child.bucketCnt[len(child.upper)])))
		w.writeSample(hv.name+"_sum", base, math.Float64frombits(atomicLoad(&child.sum)))
		w.writeSample(hv.name+"_count", base, float64(atomicLoad(&child.count)))
	}
}
// atomicLoad wraps atomic.LoadUint64 so HistogramVec.writeTo can read
// child fields atomically without spelling out the sync/atomic call on
// every line.
func atomicLoad(p *uint64) uint64 {
	loaded := atomic.LoadUint64(p)
	return loaded
}
// GaugeVec is the labelled variant of Gauge.
type GaugeVec struct {
name string
help string
labelKeys []string // label names, in the order With expects values
mu sync.Mutex // guards children, keys and maxChildren
children map[string]*Gauge // keyed by the joined label values
keys []string // child keys in creation order
maxChildren int // cardinality cap; values <= 0 disable it
}
// NewGaugeVec returns a labelled gauge family with the default
// cardinality cap and a private copy of labelKeys. An empty labelKeys
// panics.
func NewGaugeVec(name, help string, labelKeys []string) *GaugeVec {
	if len(labelKeys) == 0 {
		panic(fmt.Sprintf("metrics: gauge vec %q needs at least one label key", name))
	}
	keysCopy := make([]string, len(labelKeys))
	copy(keysCopy, labelKeys)

	gv := &GaugeVec{name: name, help: help}
	gv.labelKeys = keysCopy
	gv.children = make(map[string]*Gauge)
	gv.maxChildren = defaultVecCardinalityCap
	return gv
}
// SetMaxChildren replaces the vec's cardinality cap with n.
func (gv *GaugeVec) SetMaxChildren(n int) {
	gv.mu.Lock()
	gv.maxChildren = n
	gv.mu.Unlock()
}
// ChildCount reports how many children the vec currently holds.
func (gv *GaugeVec) ChildCount() int {
	gv.mu.Lock()
	n := len(gv.children)
	gv.mu.Unlock()
	return n
}
// With looks up, or lazily creates, the child gauge for the given label
// values (a count mismatch against the label keys panics). When a new
// child would exceed the cardinality cap, the lookup is redirected to
// the single overflow child, exactly as CounterVec.With does.
func (gv *GaugeVec) With(values ...string) *Gauge {
	if want := len(gv.labelKeys); len(values) != want {
		panic(fmt.Sprintf("metrics: gauge vec %q: got %d label values, want %d", gv.name, len(values), want))
	}
	key := joinLabelValues(values)

	gv.mu.Lock()
	defer gv.mu.Unlock()

	if existing, found := gv.children[key]; found {
		return existing
	}
	atCap := gv.maxChildren > 0 && len(gv.children) >= explicitChildLimit(gv.maxChildren)
	if atCap {
		key = joinLabelValues(overflowLabelValuesForArity(len(gv.labelKeys)))
		if overflow, found := gv.children[key]; found {
			return overflow
		}
	}
	child := &Gauge{name: gv.name, help: gv.help}
	gv.keys = append(gv.keys, key)
	gv.children[key] = child
	return child
}
// writeTo emits the family's HELP/TYPE lines and then one sample per
// child, ordered by the joined label-value key. The child set is copied
// under the vec lock and rendered after the lock is released.
func (gv *GaugeVec) writeTo(w *bufferedWriter) {
	w.writeMeta(gv.name, gv.help, typeGauge)

	gv.mu.Lock()
	ordered := make([]string, len(gv.keys))
	copy(ordered, gv.keys)
	snapshot := make(map[string]*Gauge, len(gv.children))
	for key, child := range gv.children {
		snapshot[key] = child
	}
	gv.mu.Unlock()

	sort.Strings(ordered)
	for _, key := range ordered {
		values := splitLabelValues(key)
		pairs := make([]labelPair, 0, len(gv.labelKeys))
		for i := range gv.labelKeys {
			pairs = append(pairs, labelPair{key: gv.labelKeys[i], value: values[i]})
		}
		w.writeSample(gv.name, pairs, snapshot[key].Value())
	}
}
// -----------------------------------------------------------------------
// Internal exposition
// -----------------------------------------------------------------------
// labelPair is one label of a sample: key is written verbatim, value is
// escaped by writeSample before being quoted.
type labelPair struct {
key, value string
}
// bufferedWriter wraps the scrape destination and remembers the first
// write error so callers can render unconditionally and check err once
// at the end. Despite the name, writes are passed straight to w.
type bufferedWriter struct {
w io.Writer
err error // first error returned by w; once set, later writes are skipped
}
// newBufferedWriter wraps w in a bufferedWriter with no recorded error.
func newBufferedWriter(w io.Writer) *bufferedWriter {
	bw := new(bufferedWriter)
	bw.w = w
	return bw
}
// writef formats into the underlying writer unless an earlier write
// already failed; the first failure is kept in bw.err.
func (bw *bufferedWriter) writef(format string, args ...any) {
	if bw.err != nil {
		return
	}
	// bw.err is nil here, so storing Fprintf's result records the
	// first error (or leaves it nil on success).
	_, bw.err = fmt.Fprintf(bw.w, format, args...)
}
// writeMeta writes the `# HELP` line (help text escaped) followed by
// the `# TYPE` line for one metric family.
func (bw *bufferedWriter) writeMeta(name, help string, typ metricType) {
	escaped := escapeHelp(help)
	bw.writef("# HELP %s %s\n", name, escaped)
	bw.writef("# TYPE %s %s\n", name, typ)
}
// writeSample writes one sample line: the metric name, an optional
// `{k="v",...}` label set (values escaped, keys written as given), a
// space, and the formatted value.
func (bw *bufferedWriter) writeSample(name string, labels []labelPair, value float64) {
	var line strings.Builder
	line.WriteString(name)
	for i, lp := range labels {
		if i == 0 {
			line.WriteByte('{')
		} else {
			line.WriteByte(',')
		}
		line.WriteString(lp.key)
		line.WriteString(`="`)
		line.WriteString(escapeLabel(lp.value))
		line.WriteByte('"')
	}
	if len(labels) > 0 {
		line.WriteByte('}')
	}
	bw.writef("%s %s\n", line.String(), formatFloat(value))
}
// writeEOF writes the terminating `# EOF` line of a scrape.
func (bw *bufferedWriter) writeEOF() {
	const eofLine = "# EOF\n"
	bw.writef(eofLine)
}
// labelSep is the ASCII unit separator (0x1F) placed between label
// values when they are folded into a single map key. Using a control
// character means values such as `a|b` and `ab|` cannot collide.
const labelSep = "\x1f"

// joinLabelValues folds vs into one key, separated by labelSep. A value
// that itself contains the separator would make the key ambiguous, so
// it panics instead.
func joinLabelValues(vs []string) string {
	for i := range vs {
		if hasSep := strings.Contains(vs[i], labelSep); hasSep {
			panic("metrics: label value contains unit-separator")
		}
	}
	return strings.Join(vs, labelSep)
}

// splitLabelValues is the inverse of joinLabelValues: it cuts a key
// back into its individual label values.
func splitLabelValues(k string) []string {
	values := strings.Split(k, labelSep)
	return values
}
// formatFloat renders v for the exposition output. NaN and the two
// infinities use the tokens "NaN", "+Inf" and "-Inf"; whole numbers
// with magnitude below 1e15 are printed as integers; everything else
// uses %g.
func formatFloat(v float64) string {
	if math.IsNaN(v) {
		return "NaN"
	}
	if math.IsInf(v, 1) {
		return "+Inf"
	}
	if math.IsInf(v, -1) {
		return "-Inf"
	}
	if whole := v == math.Trunc(v); whole && math.Abs(v) < 1e15 {
		return fmt.Sprintf("%d", int64(v))
	}
	return fmt.Sprintf("%g", v)
}
// escapeHelp escapes the characters that would break a HELP line in
// the OpenMetrics text format: backslash becomes `\\`, double quote
// becomes `\"`, and newline becomes `\n`. OpenMetrics defines HELP text
// as an escaped string, so an unescaped double quote is not valid there.
// Strings the caller supplies are not attacker-controlled (they are
// developer literals), but defensive escaping keeps the scrape
// well-formed even if a future contributor gets creative.
func escapeHelp(s string) string {
	// Backslash first, so the backslashes introduced below are not
	// themselves re-escaped.
	s = strings.ReplaceAll(s, `\`, `\\`)
	s = strings.ReplaceAll(s, `"`, `\"`)
	s = strings.ReplaceAll(s, "\n", `\n`)
	return s
}
// escapeLabel prepares s for use inside a double-quoted label value by
// rewriting backslash, double quote and newline as `\\`, `\"` and `\n`.
// All other bytes are copied through unchanged.
func escapeLabel(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	for i := 0; i < len(s); i++ {
		switch ch := s[i]; ch {
		case '\\':
			out.WriteString(`\\`)
		case '"':
			out.WriteString(`\"`)
		case '\n':
			out.WriteString(`\n`)
		default:
			out.WriteByte(ch)
		}
	}
	return out.String()
}
// ErrNotRegistered can be returned by callers that look up a metric by
// name without finding it. Currently unused internally; kept for the
// external helper surface so test doubles can standardise on it.
// It is a sentinel value, so it can be matched with errors.Is.
var ErrNotRegistered = errors.New("metrics: not registered")
// -----------------------------------------------------------------------
// Process-wide default registry
// -----------------------------------------------------------------------
// defaultRegistry is the Registry the daemon shares across packages.
// Tests that need isolation should construct their own NewRegistry().
// The package-level MustRegister, RegisterGaugeFunc,
// RegisterCounterFunc and WriteOpenMetrics helpers all operate on it.
var defaultRegistry = NewRegistry()
// Default hands out the shared, process-wide Registry.
func Default() *Registry {
	return defaultRegistry
}
// MustRegister registers c under name on the process-wide registry; it
// is shorthand for Default().MustRegister.
func MustRegister(name string, c Collector) {
	Default().MustRegister(name, c)
}
// RegisterGaugeFunc is shorthand for Default().RegisterGaugeFunc.
func RegisterGaugeFunc(name, help string, fn func() float64) {
defaultRegistry.RegisterGaugeFunc(name, help, fn)
}
// RegisterCounterFunc is shorthand for Default().RegisterCounterFunc.
func RegisterCounterFunc(name, help string, fn func() float64) {
defaultRegistry.RegisterCounterFunc(name, help, fn)
}
// WriteOpenMetrics is shorthand for Default().WriteOpenMetrics.
func WriteOpenMetrics(w io.Writer) error {
return defaultRegistry.WriteOpenMetrics(w)
}
// Collector is the type accepted by Registry.MustRegister. Exported
// so external packages have a name for the interface, even though the
// useful method on it is unexported (only metrics-package types can
// implement it, which is the intent). It is a type alias, so Collector
// and collectable are the same type.
type Collector = collectable
package mime
import (
"archive/tar"
"archive/zip"
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"fmt"
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net/mail"
"net/textproto"
"os"
"path/filepath"
"strings"
"unicode"
)
// ExtractedPart represents a single extracted attachment.
type ExtractedPart struct {
Filename string // attachment name (sanitized on the single-part path)
ContentType string // parsed media type of the part
Size int64 // number of decoded bytes staged at TempPath
TempPath string // temp file holding the decoded content; the caller removes it
Nested bool // NOTE(review): presumably set for parts pulled out of an archive -- confirm
ArchiveName string // NOTE(review): presumably the containing archive's name -- confirm
}
// ExtractionResult holds all extracted parts and envelope metadata.
type ExtractionResult struct {
Parts []ExtractedPart // attachments staged to temp files
Partial bool // true when extraction was incomplete
PartialReason string // why Partial was set
Direction string // "inbound" or "outbound", from detectDirection
From string // From header value
To []string // To header split on commas, each entry trimmed
Subject string // Subject header value
}
// Limits controls resource bounds during extraction.
type Limits struct {
// MaxAttachmentSize is the largest decoded attachment, in bytes,
// that will be staged for scanning.
MaxAttachmentSize int64
MaxArchiveDepth int
MaxArchiveFiles int
// MaxExtractionSize feeds the body read budget (see bodyReadLimit).
MaxExtractionSize int64
// TempDir is the directory CreateTemp uses for extracted parts.
// Empty falls back to os.TempDir() (/tmp on Linux). Operators
// should set this to a daemon-owned 0700 path so extracted email
// attachments are not staged in a world-writable directory where
// another local uid can race the scanner via symlink swaps.
TempDir string
}
// DefaultLimits returns the stock extraction limits: 25 MiB per
// attachment, archive depth 1, 50 files per archive and 100 MiB of
// total extraction. TempDir is left empty.
func DefaultLimits() Limits {
	const mib = 1024 * 1024

	var defaults Limits
	defaults.MaxAttachmentSize = 25 * mib
	defaults.MaxArchiveDepth = 1
	defaults.MaxArchiveFiles = 50
	defaults.MaxExtractionSize = 100 * mib
	return defaults
}
// ParseSpoolMessage parses an Exim spool message (-H and -D files) and
// extracts attachments to a temp directory. Caller must remove the temp
// files in result.Parts[*].TempPath when done.
//
// An error is returned only when the header or body file cannot be
// read. Every later problem is fail-open: the function returns the
// result gathered so far with a nil error, setting Partial and
// PartialReason where the code below does so. text/* bodies and
// messages with an unparseable Content-Type yield no parts.
func ParseSpoolMessage(headerPath, bodyPath string, limits Limits) (*ExtractionResult, error) {
envelope, hdrs, err := parseEximHeader(headerPath)
if err != nil {
return nil, fmt.Errorf("parsing header file: %w", err)
}
result := &ExtractionResult{
From: envelope.from,
To: envelope.to,
Subject: envelope.subject,
}
// Determine direction from Received headers
result.Direction = detectDirection(hdrs)
maxBodyBytes := bodyReadLimit(limits)
bodyData, partial, err := readBodyFileLimited(bodyPath, maxBodyBytes)
if err != nil {
return nil, fmt.Errorf("reading body file: %w", err)
}
if partial {
// Body larger than the read budget: report envelope data only.
result.Partial = true
result.PartialReason = "message body exceeds parser memory budget"
return result, nil
}
ct := hdrs.Get("Content-Type")
if ct == "" {
// A missing Content-Type is treated as text/plain.
ct = "text/plain"
}
mediaType, params, parseErr := mime.ParseMediaType(ct)
if parseErr != nil {
// Unparseable content type - treat as plain text, no attachments
return result, nil //nolint:nilerr // fail-open by design
}
if strings.HasPrefix(mediaType, "multipart/") {
boundary := params["boundary"]
if boundary == "" {
// Without a boundary the parts cannot be located.
return result, nil
}
var totalSize int64
extractErr := extractMultipart(bytes.NewReader(bodyData), boundary, limits, result, &totalSize, 0)
if extractErr != nil {
return result, nil //nolint:nilerr // fail-open by design: return what we extracted so far
}
} else if !strings.HasPrefix(mediaType, "text/") {
// Single-part non-text message (e.g. application/octet-stream,
// application/pdf, image/*). These are attachment-like payloads
// that must be scanned even without a multipart wrapper.
cte := strings.ToLower(hdrs.Get("Content-Transfer-Encoding"))
// Decode with one byte of headroom so a payload of exactly
// MaxAttachmentSize bytes is not reported as truncated.
decoded, truncated := decodeSinglePart(bodyData, cte, limits.MaxAttachmentSize+1)
if !truncated && int64(len(decoded)) <= limits.MaxAttachmentSize {
tmpFile, tmpErr := os.CreateTemp(limits.TempDir, "csm-emailav-single-*")
if tmpErr == nil {
n, writeErr := tmpFile.Write(decoded)
closeErr := tmpFile.Close()
if writeErr != nil || closeErr != nil || n != len(decoded) {
// Do not leave a half-written file behind.
os.Remove(tmpFile.Name())
markPartial(result, "could not stage single-part attachment for scanning")
} else {
// The filename comes from the Content-Type "name" parameter.
filename := params["name"]
if filename == "" {
filename = "attachment"
}
filename = sanitizeAttachmentName(filename)
result.Parts = append(result.Parts, ExtractedPart{
Filename: filename,
ContentType: mediaType,
Size: int64(len(decoded)),
TempPath: tmpFile.Name(),
})
}
} else {
markPartial(result, "could not stage single-part attachment for scanning")
}
} else {
// Reached for oversized payloads and also when decoding failed,
// because decodeSinglePart reports a decode error as truncated.
result.Partial = true
result.PartialReason = "single-part attachment exceeds max size"
}
}
// text/* bodies are not attachments - skip
return result, nil
}
// bodyReadLimit picks the byte budget for reading a spool body: the
// larger of MaxExtractionSize and twice MaxAttachmentSize, falling back
// to the default MaxExtractionSize when that comes out at zero or below.
func bodyReadLimit(limits Limits) int64 {
	budget := limits.MaxExtractionSize
	if doubled := limits.MaxAttachmentSize * 2; budget < doubled {
		budget = doubled
	}
	if budget > 0 {
		return budget
	}
	return DefaultLimits().MaxExtractionSize
}
// readFileLimited reads the file at path and returns its contents when
// they fit within limit bytes. A file longer than limit is reported as
// an error rather than being truncated; open and read failures are
// returned unchanged.
func readFileLimited(path string, limit int64) ([]byte, error) {
	// #nosec G304 -- path is mail queue file path from scanner walk.
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	// One byte past the limit is enough to tell "exactly at the limit"
	// from "over the limit".
	capped := io.LimitReader(file, limit+1)
	contents, err := io.ReadAll(capped)
	switch {
	case err != nil:
		return nil, err
	case int64(len(contents)) > limit:
		return nil, fmt.Errorf("message body exceeds parser memory budget")
	}
	return contents, nil
}
// readBodyFileLimited opens the spool body file once and reads at most
// limit+1 bytes from it. A file longer than limit yields
// (nil, true, nil) so the caller can mark the result as partial;
// otherwise the full contents are returned with partial=false.
// The size check happens on the same file descriptor as the read, which
// closes the TOCTOU window a separate Stat-then-Open sequence would
// leave: an attacker could swap in a larger file between the two
// syscalls and bypass the limit.
func readBodyFileLimited(path string, limit int64) ([]byte, bool, error) {
	// #nosec G304 -- path is mail queue file path from scanner walk.
	file, err := os.Open(path)
	if err != nil {
		return nil, false, err
	}
	defer file.Close()

	capped := io.LimitReader(file, limit+1)
	contents, err := io.ReadAll(capped)
	switch {
	case err != nil:
		return nil, false, err
	case int64(len(contents)) > limit:
		return nil, true, nil
	}
	return contents, false, nil
}
// markPartial flags result as partial. The first reason recorded wins:
// reason is stored only when no PartialReason has been set yet.
func markPartial(result *ExtractionResult, reason string) {
	if result.PartialReason == "" {
		result.PartialReason = reason
	}
	result.Partial = true
}
// decodeSinglePart decodes a single-part message body according to its
// Content-Transfer-Encoding and caps the result at limit bytes.
//
// cte is matched exactly against "base64" and "quoted-printable"; any
// other value returns bodyData without decoding. The second result is
// true when the output was cut at limit or when decoding failed (in
// which case the data is nil).
//
// A negative limit is treated as 0 instead of panicking on the slice
// expression, and a limit of math.MaxInt64 no longer overflows the
// one-byte read-ahead used to detect truncation.
func decodeSinglePart(bodyData []byte, cte string, limit int64) ([]byte, bool) {
	const maxInt64 = 1<<63 - 1
	if limit < 0 {
		limit = 0
	}

	var reader io.Reader
	switch cte {
	case "base64":
		reader = base64.NewDecoder(base64.StdEncoding, bytes.NewReader(bodyData))
	case "quoted-printable":
		reader = quotedprintable.NewReader(bytes.NewReader(bodyData))
	default:
		if int64(len(bodyData)) > limit {
			return bodyData[:limit], true
		}
		return bodyData, false
	}

	// Read one byte past the limit so "exactly at the limit" can be
	// told apart from "over the limit", without letting the +1 wrap
	// around to a negative read budget.
	readCap := limit
	if readCap < maxInt64 {
		readCap++
	}
	decoded, err := io.ReadAll(io.LimitReader(reader, readCap))
	if err != nil {
		return nil, true
	}
	if int64(len(decoded)) > limit {
		return decoded[:limit], true
	}
	return decoded, false
}
// envelope carries the sender, recipients and subject that
// parseEximHeader reads from the message's From, To and Subject headers.
type envelope struct {
from string
to []string
subject string
}
// parseEximHeader reads an Exim -H file and extracts envelope info and headers.
// Exim -H files have a specific format: the first lines are Exim internal metadata,
// followed by RFC 2822 headers.
//
// Everything before the first line that begins with one of a fixed set
// of well-known header names is skipped; from that line up to and
// including the next blank line is parsed with net/mail. An error is
// returned only when the file cannot be read. If the collected block
// does not parse, an empty envelope and empty header map are returned
// with a nil error.
func parseEximHeader(path string) (*envelope, textproto.MIMEHeader, error) {
// #nosec G304 -- path is Exim -H file path from mail queue scanner walk.
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, err
}
env := &envelope{}
// Parse as mail message to extract headers - find the header block
// Exim -H files contain envelope info then headers separated by a blank line pattern.
// We look for standard RFC headers.
reader := bufio.NewReader(bytes.NewReader(data))
// Collect the raw header text by finding lines that look like RFC 822 headers
var headerBuf bytes.Buffer
inHeaders := false
for {
line, readErr := reader.ReadString('\n')
// A final line without a trailing newline arrives together with
// the error and is still processed; only an empty read stops.
if readErr != nil && line == "" {
break
}
trimmed := strings.TrimRight(line, "\r\n")
// Detect start of RFC headers (lines like "From:", "To:", "Subject:", etc.)
if !inHeaders {
if len(trimmed) > 0 && strings.Contains(trimmed, ":") {
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "from:") ||
strings.HasPrefix(lower, "to:") ||
strings.HasPrefix(lower, "subject:") ||
strings.HasPrefix(lower, "date:") ||
strings.HasPrefix(lower, "mime-version:") ||
strings.HasPrefix(lower, "content-type:") ||
strings.HasPrefix(lower, "received:") ||
strings.HasPrefix(lower, "message-id:") {
inHeaders = true
}
}
}
if inHeaders {
// The blank terminator line is written too, so the buffer ends
// with the header/body separator.
headerBuf.WriteString(line)
if trimmed == "" {
break // end of headers
}
}
}
// Parse the collected headers
msg, msgErr := mail.ReadMessage(&headerBuf)
if msgErr != nil {
// Try to extract what we can from the raw data
return env, make(textproto.MIMEHeader), nil //nolint:nilerr // fail-open: return empty headers
}
env.from = msg.Header.Get("From")
env.subject = msg.Header.Get("Subject")
if to := msg.Header.Get("To"); to != "" {
// Plain comma split: a comma inside a quoted display name also
// splits the entry.
for _, addr := range strings.Split(to, ",") {
env.to = append(env.to, strings.TrimSpace(addr))
}
}
// Convert mail.Header to textproto.MIMEHeader
hdrs := make(textproto.MIMEHeader)
for k, v := range msg.Header {
hdrs[k] = v
}
return env, hdrs, nil
}
// detectDirection classifies a message as "inbound" or "outbound" from its
// Received headers. A message with no Received header at all is treated as
// locally generated and therefore outbound. Otherwise only the first
// Received value is inspected: if it mentions an authenticated submission
// ("(authenticated" or "auth=", case-insensitive) the message is outbound,
// else inbound.
func detectDirection(hdrs textproto.MIMEHeader) string {
	hops := hdrs.Values("Received")
	if len(hops) == 0 {
		return "outbound"
	}
	topmost := strings.ToLower(hops[0])
	for _, marker := range []string{"(authenticated", "auth="} {
		if strings.Contains(topmost, marker) {
			return "outbound"
		}
	}
	return "inbound"
}
// maxMIMENestingDepth caps multipart-in-multipart recursion. Legitimate
// mail rarely nests beyond mixed > alternative > related; a crafted message
// can nest arbitrarily and would otherwise consume one stack frame per
// wrapper while hiding attachments below any scanner's patience.
// extractMultipartNested stops descending and marks the result partial once
// its MIME depth reaches this value.
const maxMIMENestingDepth = 16
// extractMultipart is the entry point for walking a multipart body and
// extracting its attachments. It starts the recursive walk at MIME nesting
// level zero.
//
// depth is the archive nesting level (zip-in-zip), not the MIME nesting
// level: an archive attached five multipart levels down is still archive
// depth 0.
func extractMultipart(r io.Reader, boundary string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) error {
	const topLevelMIMEDepth = 0
	return extractMultipartNested(r, boundary, limits, result, totalSize, depth, topLevelMIMEDepth)
}
// extractMultipartNested walks one multipart body and stages every
// attachment it finds as a temp file under limits.TempDir, appending an
// ExtractedPart to result.Parts for each one. Nested multipart parts are
// followed recursively with mimeDepth+1; once mimeDepth reaches
// maxMIMENestingDepth the subtree is skipped and result is marked partial.
//
// depth is the archive nesting level handed on to extractZIP/extractTarGz
// and is not changed by MIME recursion. totalSize accumulates the staged
// bytes across the whole message so limits.MaxExtractionSize applies
// globally.
//
// Text parts without a filename are treated as message bodies and skipped.
// The filename comes from Content-Disposition first and, failing that, from
// the Content-Type "name" parameter, so attachments declared only the legacy
// way keep their real name and therefore their archive suffix.
//
// An error is returned when the multipart stream cannot be read or a temp
// file cannot be created. Per-part copy failures and limit overruns are
// recorded on result instead (fail-open).
func extractMultipartNested(r io.Reader, boundary string, limits Limits, result *ExtractionResult, totalSize *int64, depth, mimeDepth int) error {
	if mimeDepth >= maxMIMENestingDepth {
		markPartial(result, "MIME nesting exceeds depth limit")
		return nil
	}
	mr := multipart.NewReader(r, boundary)
	for {
		part, err := mr.NextPart()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		ct := part.Header.Get("Content-Type")
		if ct == "" {
			ct = "text/plain"
		}
		// A malformed Content-Type leaves mediaType/params empty; the part
		// is then handled like any other non-text part.
		mediaType, params, _ := mime.ParseMediaType(ct)
		// Recurse into nested multipart
		if strings.HasPrefix(mediaType, "multipart/") {
			if b := params["boundary"]; b != "" {
				if nestedErr := extractMultipartNested(part, b, limits, result, totalSize, depth, mimeDepth+1); nestedErr != nil {
					return nestedErr
				}
			}
			continue
		}
		// Skip inline text bodies - only extract attachments
		disp := part.Header.Get("Content-Disposition")
		filename := part.FileName()
		if filename == "" {
			// Try Content-Disposition filename param
			if disp != "" {
				_, dparams, _ := mime.ParseMediaType(disp)
				filename = dparams["filename"]
			}
		}
		if filename == "" {
			// Many senders name the attachment only through the
			// Content-Type "name" parameter. Without this fallback such a
			// part would be staged as "unnamed_attachment" (losing its
			// .zip/.tgz suffix) or, for text/*, skipped as a body.
			filename = params["name"]
		}
		if filename == "" {
			// No filename and text/* content - this is a body part, skip
			if strings.HasPrefix(mediaType, "text/") {
				continue
			}
			// Non-text without filename - use generic name
			filename = "unnamed_attachment"
		}
		rawFilename := filename
		filename = sanitizeAttachmentName(filename)
		// Decode the part body based on Content-Transfer-Encoding
		cte := strings.ToLower(part.Header.Get("Content-Transfer-Encoding"))
		var bodyReader io.Reader = part
		switch cte {
		case "base64":
			bodyReader = base64.NewDecoder(base64.StdEncoding, part)
		case "quoted-printable":
			bodyReader = quotedprintable.NewReader(part)
		}
		// Write to temp file with size limit
		tmpFile, err := os.CreateTemp(limits.TempDir, "csm-emailav-*")
		if err != nil {
			markPartial(result, "could not stage attachment for scanning")
			return fmt.Errorf("creating temp file: %w", err)
		}
		// Copy one byte past the cap so an oversized part is detectable.
		limited := io.LimitReader(bodyReader, limits.MaxAttachmentSize+1)
		n, err := io.Copy(tmpFile, limited)
		closeErr := tmpFile.Close()
		if err != nil || closeErr != nil {
			os.Remove(tmpFile.Name())
			markPartial(result, "could not stage attachment for scanning")
			continue // fail-open: skip this part
		}
		if n > limits.MaxAttachmentSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("attachment %q exceeds max size %d", filename, limits.MaxAttachmentSize)
			continue
		}
		*totalSize += n
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return nil // stop extracting
		}
		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    filename,
			ContentType: mediaType,
			Size:        n,
			TempPath:    tmpFile.Name(),
		})
		// Attempt archive extraction
		if depth < limits.MaxArchiveDepth {
			switch archiveKindForAttachmentName(rawFilename, filename) {
			case "zip":
				extractZIP(tmpFile.Name(), filename, limits, result, totalSize, depth+1)
			case "tar.gz":
				extractTarGz(tmpFile.Name(), filename, limits, result, totalSize, depth+1)
			}
		}
	}
}
// extractZIP unpacks the regular files of the zip archive at zipPath into
// temp files under limits.TempDir and appends them to result.Parts as nested
// parts attributed to archiveName.
//
// An archive that cannot be opened or parsed is skipped silently
// (fail-open). Once the archive is readable, every entry that cannot be
// scanned marks result partial: an entry that cannot be opened or staged, an
// entry larger than limits.MaxAttachmentSize, more than
// limits.MaxArchiveFiles files, or a running total above
// limits.MaxExtractionSize. Directory entries are ignored and do not count
// toward the file limit.
//
// depth is accepted for signature parity with the caller and is not read
// here; archives found inside the archive are not unpacked further.
func extractZIP(zipPath, archiveName string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) {
	// #nosec G304 -- zipPath is CreateTemp-produced path from the caller.
	f, err := os.Open(zipPath)
	if err != nil {
		return // fail-open: skip corrupt archives
	}
	defer f.Close()
	info, err := f.Stat()
	if err != nil {
		return
	}
	zr, err := zip.NewReader(f, info.Size())
	if err != nil {
		return
	}
	extracted := 0
	for _, zf := range zr.File {
		// Skip directories before the limit check so a trailing directory
		// entry cannot flag a fully extracted archive as partial.
		if zf.FileInfo().IsDir() {
			continue
		}
		if extracted >= limits.MaxArchiveFiles {
			result.Partial = true
			result.PartialReason = fmt.Sprintf("archive %q exceeds max files %d", archiveName, limits.MaxArchiveFiles)
			return
		}
		safeName := sanitizeAttachmentName(zf.Name)
		rc, err := zf.Open()
		if err != nil {
			// The entry exists but its content will never be scanned
			// (e.g. unsupported compression method); say so instead of
			// reporting a complete extraction.
			markPartial(result, fmt.Sprintf("could not open file %q in archive for scanning", safeName))
			continue
		}
		tmpFile, err := os.CreateTemp(limits.TempDir, "csm-emailav-zip-*")
		if err != nil {
			rc.Close()
			markPartial(result, fmt.Sprintf("could not stage file %q in archive for scanning", safeName))
			continue
		}
		// Copy one byte past the cap so an oversized entry is detectable
		// without inflating it fully.
		limited := io.LimitReader(rc, limits.MaxAttachmentSize+1)
		n, err := io.Copy(tmpFile, limited)
		closeErr := tmpFile.Close()
		rc.Close()
		if err != nil || closeErr != nil || n > limits.MaxAttachmentSize {
			os.Remove(tmpFile.Name())
			if n > limits.MaxAttachmentSize {
				result.Partial = true
				result.PartialReason = fmt.Sprintf("file %q in archive exceeds max size", safeName)
			} else {
				markPartial(result, fmt.Sprintf("could not stage file %q in archive for scanning", safeName))
			}
			continue
		}
		*totalSize += n
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return
		}
		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    safeName,
			ContentType: "application/octet-stream",
			Size:        n,
			TempPath:    tmpFile.Name(),
			Nested:      true,
			ArchiveName: archiveName,
		})
		extracted++
	}
}
// extractTarGz unpacks the regular files of the gzip-compressed tar archive
// at tgzPath into temp files under limits.TempDir and appends them to
// result.Parts as nested parts attributed to archiveName.
//
// An archive that cannot be opened or is not valid gzip is skipped silently,
// and iteration simply stops at the first tar read error or end of archive.
// Non-regular entries are ignored. result is marked partial when an entry
// cannot be staged, an entry exceeds limits.MaxAttachmentSize, the archive
// holds more than limits.MaxArchiveFiles regular files, or the running total
// passes limits.MaxExtractionSize.
//
// depth is accepted for signature parity with the caller and is not read
// here.
func extractTarGz(tgzPath, archiveName string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) {
	// #nosec G304 -- tgzPath is CreateTemp-produced path from the caller.
	archive, err := os.Open(tgzPath)
	if err != nil {
		return
	}
	defer archive.Close()

	gz, err := gzip.NewReader(archive)
	if err != nil {
		return
	}
	defer func() { _ = gz.Close() }()

	entries := tar.NewReader(gz)
	for staged := 0; ; {
		hdr, nextErr := entries.Next()
		if nextErr != nil {
			return // EOF or error - done
		}
		if hdr.Typeflag != tar.TypeReg {
			continue
		}
		if staged >= limits.MaxArchiveFiles {
			result.Partial = true
			result.PartialReason = fmt.Sprintf("archive %q exceeds max files %d", archiveName, limits.MaxArchiveFiles)
			return
		}

		safeName := sanitizeAttachmentName(hdr.Name)
		out, createErr := os.CreateTemp(limits.TempDir, "csm-emailav-tgz-*")
		if createErr != nil {
			markPartial(result, fmt.Sprintf("could not stage file %q in archive for scanning", safeName))
			continue
		}
		stagedPath := out.Name()

		// Copy one byte past the cap so an oversized entry is detectable.
		written, copyErr := io.Copy(out, io.LimitReader(entries, limits.MaxAttachmentSize+1))
		closeErr := out.Close()

		switch {
		case written > limits.MaxAttachmentSize:
			os.Remove(stagedPath)
			result.Partial = true
			result.PartialReason = fmt.Sprintf("file %q in archive exceeds max size", safeName)
			continue
		case copyErr != nil || closeErr != nil:
			os.Remove(stagedPath)
			markPartial(result, fmt.Sprintf("could not stage file %q in archive for scanning", safeName))
			continue
		}

		*totalSize += written
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(stagedPath)
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return
		}

		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    safeName,
			ContentType: "application/octet-stream",
			Size:        written,
			TempPath:    stagedPath,
			Nested:      true,
			ArchiveName: archiveName,
		})
		staged++
	}
}
// sanitizeAttachmentName reduces an attachment or archive-entry name to a
// value that is safe to show in logs, alerts and JSON responses.
//
// Backslashes are treated as path separators, only the base name is kept,
// and the name is cut at the first control character so that an entry such
// as "good.txt\nFAKE-LOG-LINE" cannot smuggle a forged log record behind the
// visible filename. Surrounding whitespace is trimmed. A name that ends up
// empty or as a bare path element (".", "..", "/") becomes "_unnamed_".
func sanitizeAttachmentName(name string) string {
	base := filepath.Base(strings.ReplaceAll(name, "\\", "/"))
	if cut := strings.IndexFunc(base, unicode.IsControl); cut >= 0 {
		base = base[:cut]
	}
	base = strings.TrimSpace(base)
	if base == "" || base == "." || base == ".." || base == "/" {
		return "_unnamed_"
	}
	return base
}
// archiveKindForAttachmentName reports which archive unpacker applies to an
// attachment: "zip", "tar.gz" (also for a .tgz suffix) or "" for none. The
// sanitized name is checked first, then the raw name with control characters
// removed; the suffix match is case-insensitive.
func archiveKindForAttachmentName(rawName, safeName string) string {
	candidates := [...]string{safeName, archiveDetectionName(rawName)}
	for _, candidate := range candidates {
		lowered := strings.ToLower(candidate)
		if strings.HasSuffix(lowered, ".zip") {
			return "zip"
		}
		if strings.HasSuffix(lowered, ".tar.gz") || strings.HasSuffix(lowered, ".tgz") {
			return "tar.gz"
		}
	}
	return ""
}

// archiveDetectionName returns the base name of name with every control
// character removed rather than truncated at, then trimmed of surrounding
// whitespace. That way a filename like "payload\u0085.zip" still gets
// unpacked while the public filename stays log-safe.
func archiveDetectionName(name string) string {
	base := filepath.Base(strings.ReplaceAll(name, "\\", "/"))
	var kept strings.Builder
	for _, r := range base {
		if unicode.IsControl(r) {
			continue
		}
		kept.WriteRune(r)
	}
	return strings.TrimSpace(kept.String())
}
package modsec
import "github.com/pidginhost/csm/internal/platform"
// RuleDirs lists the directories in which vendor ModSecurity rules may live
// for the web server / panel combination described by info. Entries are
// ordered from most-specific to least-specific so that a caller walking them
// meets the operator's installed pack before any system fallback. The result
// is nil when no branch applies.
func RuleDirs(info platform.Info) []string {
	const (
		apache2VendorDir = "/etc/apache2/conf.d/modsec_vendor_configs/"
		legacyVendorDir  = "/usr/local/apache/conf/modsec_vendor_configs/"
		etcModsecDir     = "/etc/modsecurity/"
		crsRulesDir      = "/usr/share/modsecurity-crs/rules/"
	)

	var dirs []string
	switch info.WebServer {
	case platform.WSApache:
		if info.IsDebianFamily() {
			dirs = append(dirs, apache2VendorDir, etcModsecDir, crsRulesDir)
		}
		if info.IsRHELFamily() {
			dirs = append(dirs,
				"/etc/httpd/modsecurity.d/",
				"/etc/httpd/modsecurity.d/activated_rules/",
				crsRulesDir,
			)
		}
		dirs = append(dirs, legacyVendorDir)
	case platform.WSNginx:
		dirs = append(dirs, "/etc/nginx/modsec/", etcModsecDir, crsRulesDir)
	}

	// cPanel + LiteSpeed: cPanel's modsec_assemble job writes vendor rules
	// into the apache2 tree even when LiteSpeed is the front-end, so those
	// trees are probed for that combination as well.
	if info.IsCPanel() && info.WebServer == platform.WSLiteSpeed {
		dirs = append(dirs, apache2VendorDir, legacyVendorDir)
	}
	return dirs
}
package modsec
import (
"bufio"
"fmt"
"io"
"log"
"os"
"sort"
"strconv"
"strings"
)
// overridesHeader is the banner written at the top of the CSM-managed
// overrides file.
const overridesHeader = "# CSM ModSecurity Rule Overrides\n# Managed by CSM - do not edit manually.\n"

// WriteOverrides writes the overrides file at path: the banner followed by
// one "SecRuleRemoveById <id>" line per entry of disabledIDs, in ascending
// order so the output is deterministic.
//
// The content is written to path+".tmp" and renamed over path, so readers
// never see a half-written file. disabledIDs itself is not modified. An
// error is returned if the temp file cannot be written or the rename fails;
// in both cases the temp file is removed on a best-effort basis.
func WriteOverrides(path string, disabledIDs []int) error {
	// Sort a private copy: reordering the caller's slice would be a
	// surprising side effect of persisting it.
	ids := append([]int(nil), disabledIDs...)
	sort.Ints(ids)

	var sb strings.Builder
	sb.WriteString(overridesHeader)
	for _, id := range ids {
		fmt.Fprintf(&sb, "SecRuleRemoveById %d\n", id)
	}

	tmpPath := path + ".tmp"
	// #nosec G306 -- ModSec config files are read by the webserver process
	// (Apache/nginx) which runs as a different user. 0640 lets root write
	// and the webserver group read; world-read stays off.
	if err := os.WriteFile(tmpPath, []byte(sb.String()), 0640); err != nil {
		// A failed write can leave a truncated temp file behind.
		os.Remove(tmpPath)
		return fmt.Errorf("writing overrides tmp: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		os.Remove(tmpPath)
		return fmt.Errorf("renaming overrides: %w", err)
	}
	return nil
}
// ReadOverrides returns the rule IDs disabled by the overrides file at path.
//
// Only lines of the form "SecRuleRemoveById <id>" are considered, and only
// IDs in the CSM-owned range 900000-900999 are kept; anything else is
// ignored. A missing file yields (nil, nil). Any other open error, or a scan
// error, is returned to the caller.
func ReadOverrides(path string) ([]int, error) {
	// #nosec G304 -- path is operator-configured ModSec overrides file.
	f, err := os.Open(path)
	if os.IsNotExist(err) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	defer f.Close()

	const directive = "SecRuleRemoveById "
	var ids []int
	lines := bufio.NewScanner(f)
	lines.Buffer(make([]byte, 0, 64*1024), maxModsecLineBytes)
	for lines.Scan() {
		line := strings.TrimSpace(lines.Text())
		if !strings.HasPrefix(line, directive) {
			continue
		}
		id, convErr := strconv.Atoi(strings.TrimSpace(line[len(directive):]))
		if convErr != nil || id < 900000 || id > 900999 {
			continue
		}
		ids = append(ids, id)
	}
	return ids, lines.Err()
}
// ReadOverridesRaw returns the raw bytes of the overrides file at path so
// they can be written back during a rollback. It returns nil when the file
// cannot be read for any reason, including when it does not exist; no error
// is surfaced.
func ReadOverridesRaw(path string) []byte {
	// #nosec G304 -- path is operator-configured ModSec overrides file.
	if raw, err := os.ReadFile(path); err == nil {
		return raw
	}
	return nil
}
// RestoreOverrides puts the overrides file at path back into a previously
// captured state (for rollback).
//
// A nil content means the file did not exist before, so it is removed; a
// file that is already gone counts as success, any other removal failure is
// returned. Otherwise content is written to path+".tmp" and renamed over
// path, so a crash cannot leave a partially written file in place. On a
// write or rename failure the temp file is removed on a best-effort basis
// and the error is returned.
func RestoreOverrides(path string, content []byte) error {
	if content == nil {
		// File didn't exist before - remove it. Reporting success while
		// the file is still on disk would leave the rollback incomplete
		// without anyone noticing.
		if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("removing overrides for rollback: %w", err)
		}
		return nil
	}
	tmpPath := path + ".tmp"
	// #nosec G306 -- see note in WriteOverrides: webserver-readable ModSec config.
	if err := os.WriteFile(tmpPath, content, 0640); err != nil {
		// A failed write can leave a truncated temp file behind.
		os.Remove(tmpPath)
		return fmt.Errorf("writing rollback tmp: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		os.Remove(tmpPath)
		return fmt.Errorf("renaming rollback: %w", err)
	}
	return nil
}
// EnsureOverridesInclude makes sure the ModSecurity config rulesFile pulls
// in overridesFile: it creates overridesFile with the standard banner if it
// does not exist and appends an Include directive for it to rulesFile unless
// rulesFile already mentions overridesFile. Idempotent: the content check
// and the append use the same file descriptor.
//
// The overrides file is created before the Include is appended, and the
// Include is not written if that creation fails. Otherwise a failed create
// would leave rulesFile referencing a file that does not exist.
//
// Nothing is returned. If rulesFile cannot be opened or read the function
// does nothing; later failures are logged.
func EnsureOverridesInclude(rulesFile, overridesFile string) {
	// Open for read+write to check-then-append atomically (same fd).
	// #nosec G302 G304 -- webserver-readable ModSec rules file; 0640 means root
	// can write and the webserver group can read. No world read.
	f, err := os.OpenFile(rulesFile, os.O_RDWR|os.O_APPEND, 0640)
	if err != nil {
		return
	}
	data, err := io.ReadAll(f)
	if err != nil {
		_ = f.Close()
		return
	}
	// Create empty overrides file if it doesn't exist, before anything
	// points at it.
	if _, statErr := os.Stat(overridesFile); os.IsNotExist(statErr) {
		// #nosec G306 -- see note in WriteOverrides: webserver-readable.
		if createErr := os.WriteFile(overridesFile, []byte(overridesHeader), 0640); createErr != nil {
			log.Printf("modsec: creating overrides file %s failed: %v", overridesFile, createErr)
			_ = f.Close()
			return
		}
	}
	if !strings.Contains(string(data), overridesFile) {
		if _, writeErr := fmt.Fprintf(f, "\n# CSM overrides - managed by CSM rule management\nInclude %s\n", overridesFile); writeErr != nil {
			log.Printf("modsec: appending overrides include to %s failed: %v", rulesFile, writeErr)
			_ = f.Close()
			return
		}
	}
	// Close error here drops the appended Include directive, leaving the
	// override file unreferenced -- worth surfacing instead of swallowing.
	if closeErr := f.Close(); closeErr != nil {
		log.Printf("modsec: overrides include close failed for %s: %v", rulesFile, closeErr)
	}
}
package modsec
import (
"bufio"
"fmt"
"os"
"regexp"
"strconv"
"strings"
)
// maxModsecLineBytes bounds a single logical line for the rule scanners. Far
// above any legitimate directive, but finite so a pathological file cannot
// drive an unbounded allocation. It is used both as the bufio.Scanner token
// ceiling and as the cap on a continuation-joined logical line.
const maxModsecLineBytes = 8 << 20 // 8 MiB
// Rule represents a parsed ModSecurity rule from the CSM custom config.
// Values are filled in by parseBlock from one rule block (a SecRule plus any
// chained rules and trailing lines); fields whose pattern is not found in
// the block keep their zero value.
type Rule struct {
ID int // e.g. 900112
Description string // from msg:'...' field
Action string // disposition keyword: deny|drop|block|redirect|proxy|pause|allow|pass; "" if rule has only metadata (log, msg, ...) and inherits SecDefaultAction
StatusCode int // 403, 429, 0 (for pass)
Phase int // 1 or 2
Raw string // full rule text including chains
IsCounter bool // true if pass,nolog (bookkeeping rule, hidden in UI)
}
// Patterns used by parseBlock to pull fields out of a rule block.
var (
	// reID captures the numeric rule id. The id action must directly follow
	// the opening quote of the action list or a comma; whitespace after that
	// separator (as left behind by joined continuation lines) and a single
	// quote around the value (id:'981231', as written by some vendor packs)
	// are both accepted.
	reID = regexp.MustCompile(`[,"]\s*id:\s*'?(\d+)`)
	// reMsg captures the text of the first msg:'...' literal.
	reMsg = regexp.MustCompile(`msg:'([^']*)'`)
	// rePhase captures a single-digit phase number.
	rePhase = regexp.MustCompile(`phase:(\d)`)
	// reStatus captures the numeric HTTP status of a status: action.
	reStatus = regexp.MustCompile(`status:(\d+)`)
)
// dispositionPriority lists the ModSecurity action keywords that decide what
// happens to the request, ordered most-disruptive first. When several of
// them appear as standalone tokens in one action string, the one earliest in
// this list is reported.
// NOTE(review): the original comment states this matches how ModSecurity
// itself resolves multiple disruptive actions in one rule -- confirm against
// the ModSecurity reference manual.
//
// Metadata keywords (log, msg, severity, tag, ...) are intentionally
// excluded - a rule that carries only metadata inherits SecDefaultAction,
// which CSM does not parse, so the registry leaves Action empty and the
// LiteSpeed classifier defaults that rule to block.
var dispositionPriority = []string{
	"deny", "drop", "block", "redirect", "proxy", "pause", "allow", "pass",
}

// dispositionSet holds the same keywords as dispositionPriority as a set,
// built once at package initialisation for constant-time membership checks
// while tokenising action strings.
var dispositionSet = func() map[string]struct{} {
	set := make(map[string]struct{}, len(dispositionPriority))
	for i := range dispositionPriority {
		set[dispositionPriority[i]] = struct{}{}
	}
	return set
}()
// ParseRulesFile parses the ModSecurity config file at path and returns only
// the CSM-owned rules, i.e. those whose ID lies in 900000-900999. Parsing is
// delegated to ParseRulesFileAll and its error is passed through unchanged.
// Use ParseRulesFileAll directly for the rule-action registry, which needs
// every rule including vendor packs.
func ParseRulesFile(path string) ([]Rule, error) {
	parsed, err := ParseRulesFileAll(path)
	if err != nil {
		return nil, err
	}
	var owned []Rule
	for i := range parsed {
		if id := parsed[i].ID; id < 900000 || id > 900999 {
			continue
		}
		owned = append(owned, parsed[i])
	}
	return owned, nil
}
// ParseRulesFileAll reads a ModSecurity config file and extracts every rule,
// regardless of ID range. Handles line continuations (\) and chained rules
// (chain action keyword). Vendor packs (Comodo, OWASP CRS, Imunify360) all
// use this entrypoint via the rule-action registry so the daemon can tell
// pass-action rules apart from deny rules.
//
// Parsing runs in three phases: join continuation lines into logical lines,
// group logical lines into rule blocks, then parse each block. Blocks for
// which parseBlock reports failure are dropped from the result.
//
// An error is returned when the file cannot be opened, when a logical line
// would exceed maxModsecLineBytes, or when the scanner reports an error; in
// all of these cases no rules are returned.
func ParseRulesFileAll(path string) ([]Rule, error) {
// #nosec G304 -- path is operator-configured ModSec rules file.
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("opening rules file: %w", err)
}
defer f.Close()
// Phase 1: Read lines, joining backslash continuations into logical lines.
var logicalLines []string
var current strings.Builder
// appendCurrent adds part to the logical line being built and refuses to
// grow it past maxModsecLineBytes.
appendCurrent := func(part string) error {
if current.Len()+len(part) > maxModsecLineBytes {
return fmt.Errorf("logical modsec line exceeds %d bytes", maxModsecLineBytes)
}
current.WriteString(part)
return nil
}
scanner := bufio.NewScanner(f)
// Vendor packs (OWASP CRS, Comodo, Imunify360, cPanel modsec_assemble)
// ship assembled/minified directives that can exceed the default 64 KB
// token. Without a larger buffer Scan stops at ErrTooLong and the file's
// rules drop out of the action registry, where unknown IDs then default
// to "deny" and skew the modsec signal. Raise the ceiling so real files
// parse in full.
scanner.Buffer(make([]byte, 0, 64*1024), maxModsecLineBytes)
for scanner.Scan() {
raw := scanner.Text()
// Every physical line is whitespace-trimmed before it is joined, so
// logical lines never carry leading indentation.
trimmed := strings.TrimSpace(raw)
if strings.HasSuffix(trimmed, "\\") {
// Continuation: strip trailing \ and keep accumulating.
if err := appendCurrent(strings.TrimSuffix(trimmed, "\\")); err != nil {
return nil, err
}
// The pieces of a continued line are joined with a single space.
if err := appendCurrent(" "); err != nil {
return nil, err
}
continue
}
if err := appendCurrent(trimmed); err != nil {
return nil, err
}
logicalLines = append(logicalLines, current.String())
current.Reset()
}
// A file that ends on a continuation line leaves an unterminated logical
// line behind; keep it.
if current.Len() > 0 {
logicalLines = append(logicalLines, current.String())
}
if err := scanner.Err(); err != nil {
return nil, err
}
// Phase 2: Group logical lines into blocks.
// Each block starts with a SecRule and may include chained SecRules.
// When a directive's action string contains "chain", the next SecRule
// is part of the same block.
var blocks []string
var block strings.Builder
chainPending := false
// flushBlock closes the block being built (if any) and clears the chain
// flag.
flushBlock := func() {
if block.Len() > 0 {
blocks = append(blocks, block.String())
block.Reset()
}
chainPending = false
}
for _, line := range logicalLines {
if strings.HasPrefix(line, "SecRule ") {
// A new SecRule starts a new block unless the previous SecRule
// asked for a chain.
if block.Len() > 0 && !chainPending {
flushBlock()
}
} else if block.Len() == 0 {
continue // skip comments and blank lines outside blocks
}
// Any line seen while a block is open is appended to it, whatever
// kind of line it is; the block stays open until the next unchained
// SecRule or the end of the file.
if block.Len() > 0 || strings.HasPrefix(line, "SecRule ") {
block.WriteString(line)
block.WriteString("\n")
// Only update chainPending for SecRule lines - non-SecRule
// directives between chained rules must not reset the flag.
if strings.HasPrefix(line, "SecRule ") {
chainPending = hasChainAction(line)
}
}
}
flushBlock()
// Phase 3: Parse each block into a Rule.
var rules []Rule
for _, b := range blocks {
if r, ok := parseBlock(b); ok {
rules = append(rules, r)
}
}
return rules, nil
}
// hasChainAction reports whether a logical line (continuations already
// joined) carries "chain" as an action keyword. Spaces and tabs are removed
// first so that an action list split across continuation lines, such as
// "id:900004,..., chain", is still recognised. The keyword counts only when
// it is delimited by a comma or double quote on the left and by a comma or
// quote on the right.
func hasChainAction(line string) bool {
	compact := strings.NewReplacer(" ", "", "\t", "").Replace(line)
	markers := []string{
		",chain\"",
		",chain'",
		",chain,",
		"\"chain\"",
		"\"chain,",
	}
	for _, marker := range markers {
		if strings.Contains(compact, marker) {
			return true
		}
	}
	return false
}
// parseBlock turns one rule block into a Rule. The second result is false
// when the block carries no rule id, in which case the block is not a rule
// CSM can track.
//
// Description, Phase and StatusCode come from the first msg:'...', phase:N
// and status:N found anywhere in the block. Action is different: it is read
// only from the quoted action list that contains the rule's id, token by
// token. Substring matching against the whole block would false-match on
// action-like text inside regex operators or msg literals (e.g. "passive"
// looks like "pass"). IsCounter is set for pass rules whose action list
// mentions nolog.
func parseBlock(block string) (Rule, bool) {
	idMatch := reID.FindStringSubmatch(block)
	if idMatch == nil {
		return Rule{}, false
	}
	ruleID, _ := strconv.Atoi(idMatch[1])

	rule := Rule{
		ID:  ruleID,
		Raw: strings.TrimSpace(block),
	}
	if m := reMsg.FindStringSubmatch(block); m != nil {
		rule.Description = m[1]
	}
	if m := rePhase.FindStringSubmatch(block); m != nil {
		rule.Phase, _ = strconv.Atoi(m[1])
	}
	if m := reStatus.FindStringSubmatch(block); m != nil {
		rule.StatusCode, _ = strconv.Atoi(m[1])
	}

	actions := extractActionString(block, ruleID)
	if actions == "" {
		// Action list not located: leave Action empty.
		return rule, true
	}
	rule.Action = pickDisposition(actions)
	// nolog/log are honoured only as flags, never as the action.
	rule.IsCounter = rule.Action == "pass" && strings.Contains(strings.ToLower(actions), "nolog")
	return rule, true
}
// extractActionString returns the contents of the first double-quoted
// segment of block whose action tokens include "id:<ruleID>". Quoted
// segments are scanned left to right and a backslash escapes the character
// after it, so an escaped quote does not end a segment. ModSecurity rule
// blocks have one such segment per rule (chained sub-rules carry action
// lists with no id), so the id-bearing quotes identify the primary action
// list. It returns "" when no segment qualifies; the caller then leaves
// Action empty and the registry treats the rule as unknown.
func extractActionString(block string, ruleID int) string {
	needle := "id:" + strconv.Itoa(ruleID)
	pos := 0
	for pos < len(block) {
		if block[pos] != '"' {
			pos++
			continue
		}
		// pos is an opening quote: look for the matching closing quote.
		start := pos + 1
		end := -1
		escaped := false
		for j := start; j < len(block) && end < 0; j++ {
			switch {
			case escaped:
				escaped = false
			case block[j] == '\\':
				escaped = true
			case block[j] == '"':
				end = j
			}
		}
		if end < 0 {
			// Unterminated quote: move on to the next character.
			pos++
			continue
		}
		segment := block[start:end]
		if actionStringHasRuleID(segment, needle) {
			return segment
		}
		// Resume after the closing quote so it is not taken for an opener.
		pos = end + 1
	}
	return ""
}
// actionStringHasRuleID reports whether actionStr contains an id action
// whose value equals the id carried by needle (given as "id:<n>"). The
// action name is matched case-insensitively and the value may be wrapped in
// single or double quotes.
func actionStringHasRuleID(actionStr, needle string) bool {
	wantID := strings.TrimPrefix(needle, "id:")
	for _, tok := range tokenizeActionString(actionStr) {
		name, value, found := strings.Cut(tok, ":")
		if !found {
			continue
		}
		if strings.ToLower(strings.TrimSpace(name)) != "id" {
			continue
		}
		if strings.Trim(strings.TrimSpace(value), `'"`) == wantID {
			return true
		}
	}
	return false
}
// pickDisposition returns the disposition keyword found in actionStr. When
// several are present, the one ranked earliest in dispositionPriority is
// returned (defensive: a rule labelled "deny" wins over a stray "allow").
// Only the part of a token before its first ':' is compared, and the
// comparison is case-insensitive. It returns "" when the action string
// carries only metadata, meaning the rule inherits SecDefaultAction.
func pickDisposition(actionStr string) string {
	present := make(map[string]bool)
	for _, tok := range tokenizeActionString(actionStr) {
		keyword, _, _ := strings.Cut(tok, ":")
		keyword = strings.ToLower(strings.TrimSpace(keyword))
		if _, known := dispositionSet[keyword]; known {
			present[keyword] = true
		}
	}
	for _, candidate := range dispositionPriority {
		if present[candidate] {
			return candidate
		}
	}
	return ""
}
// tokenizeActionString splits a ModSecurity action list into its actions.
// The list is cut at commas, except for commas inside single-quoted values
// (msg:'foo, bar', logdata:'...'), which belong to the literal. A backslash
// keeps the following character as-is, so an escaped quote does not open or
// close a literal. Tokens are whitespace-trimmed and empty tokens are
// dropped.
func tokenizeActionString(s string) []string {
	var (
		tokens   []string
		pending  strings.Builder
		quoted   bool
		escaping bool
	)
	// emit finishes the token collected so far.
	emit := func() {
		if tok := strings.TrimSpace(pending.String()); tok != "" {
			tokens = append(tokens, tok)
		}
		pending.Reset()
	}
	for _, r := range s {
		if escaping {
			pending.WriteRune(r)
			escaping = false
			continue
		}
		if r == ',' && !quoted {
			emit()
			continue
		}
		if r == '\\' {
			escaping = true
		} else if r == '\'' {
			quoted = !quoted
		}
		pending.WriteRune(r)
	}
	emit()
	return tokens
}
// IsBlockingAction reports whether action denies the request or otherwise
// diverts it away from normal processing. The LiteSpeed log-line classifier
// relies on it: error_log records every match as "triggered!" regardless of
// action, so the action lookup is the only way to tell a real deny from a
// pass-action informational rule. redirect, proxy and pause count as
// blocking because the original request never reaches the upstream
// application as intended. The match is exact, so callers pass the
// lower-case keyword.
func IsBlockingAction(action string) bool {
	for _, disruptive := range [...]string{"deny", "drop", "block", "redirect", "proxy", "pause"} {
		if action == disruptive {
			return true
		}
	}
	return false
}
package modsec
import (
"errors"
"io/fs"
"path/filepath"
"strings"
"sync/atomic"
)
// Registry maps the IDs of parsed ModSecurity rules that have a decisive
// disposition to that action (deny, drop, block, redirect, proxy, pause,
// pass, allow). Rules with only metadata actions are intentionally absent so
// that callers fall back to their unknown-rule default.
//
// It is consulted by the LiteSpeed log-line classifier: error_log records
// every match as "triggered!" whether or not the rule denied the request, so
// the action lookup is the only signal that separates a real deny from a
// noisy pass-action informational rule.
type Registry struct {
	actions map[int]string
}

// Action returns the declared action for ruleID and whether the ID is known.
// It is safe to call on a nil *Registry, which knows no rules. An unknown ID
// is the safe default: callers should treat it as a potential block to keep
// coverage while the rule files have not been parsed yet.
func (r *Registry) Action(ruleID int) (action string, known bool) {
	if r != nil {
		action, known = r.actions[ruleID]
	}
	return action, known
}

// Len returns the number of rules in the registry; a nil *Registry has none.
// Useful for startup telemetry: zero typically means the rule directories
// are missing.
func (r *Registry) Len() int {
	if r != nil {
		return len(r.actions)
	}
	return 0
}
// BuildRegistry walks every directory in dirs (recursively), parses each
// .conf file with ParseRulesFileAll, and returns a Registry mapping rule IDs
// to actions. Empty directory names are skipped, and a directory that does
// not exist contributes nothing. Per-file parse errors are swallowed: one
// malformed file in a vendor pack should not blank the whole registry.
//
// Precedence: dirs is treated as most-specific-first. Within one directory a
// rule ID that occurs more than once takes the value of the occurrence
// visited last (the original comment describes this as lexical walk order,
// mirroring how ModSecurity resolves two SecRule directives sharing an ID);
// if that occurrence has no decisive action, the ID ends up absent. Across
// directories the first directory to define an ID keeps it, even when it
// defined the ID without an action, so an operator override in
// /etc/apache2/conf.d/modsec_vendor_configs/ is not silently replaced by a
// stale system fallback in /usr/share/modsecurity-crs/rules/.
//
// If a walk reports an error other than "not exist", the registry built so
// far is returned together with that error.
func BuildRegistry(dirs []string) (*Registry, error) {
	merged := make(map[int]string)
	taken := make(map[int]struct{})

	for _, dir := range dirs {
		if dir == "" {
			continue
		}
		dirActions := make(map[int]string)
		dirSeen := make(map[int]struct{})

		visit := func(path string, d fs.DirEntry, err error) error {
			switch {
			case err != nil && errors.Is(err, fs.ErrNotExist):
				return filepath.SkipDir
			case err != nil:
				return nil
			case d.IsDir():
				return nil
			case !strings.HasSuffix(strings.ToLower(d.Name()), ".conf"):
				return nil
			}
			// The parse error is deliberately discarded so that one
			// malformed vendor file (truncated mid-rule, encoding glitch,
			// in-flight modsec_assemble overwrite) does not blank the
			// whole registry. Nothing is reported from here; the daemon's
			// startup log shows the resulting rule count.
			rules, _ := ParseRulesFileAll(path)
			for _, rule := range rules {
				dirSeen[rule.ID] = struct{}{}
				if rule.Action == "" {
					delete(dirActions, rule.ID)
					continue
				}
				dirActions[rule.ID] = rule.Action
			}
			return nil
		}
		walkErr := filepath.WalkDir(dir, visit)

		// Promote this directory's rules only for IDs that no earlier
		// (more-specific) directory has already claimed.
		for id := range dirSeen {
			if _, already := taken[id]; already {
				continue
			}
			taken[id] = struct{}{}
			if action, ok := dirActions[id]; ok {
				merged[id] = action
			}
		}

		if walkErr != nil && !errors.Is(walkErr, fs.ErrNotExist) {
			return &Registry{actions: merged}, walkErr
		}
	}
	return &Registry{actions: merged}, nil
}
// globalRegistry holds the daemon-wide registry behind an atomic pointer so
// it can be swapped and read without additional locking.
var globalRegistry atomic.Pointer[Registry]

// SetGlobal installs r as the daemon-wide registry, replacing any previous
// one. Callers are expected to rebuild and re-set on a refresh interval.
// Safe for concurrent use.
func SetGlobal(r *Registry) {
	globalRegistry.Store(r)
}

// Global returns the currently installed registry, or nil if none has been
// set yet (e.g. during very early daemon startup, or in unit tests that did
// not seed one). Callers must nil-check.
func Global() *Registry {
	current := globalRegistry.Load()
	return current
}

// ResetGlobalForTest clears the global registry. Test-only helper.
func ResetGlobalForTest() {
	SetGlobal(nil)
}
package modsec
import (
"context"
"errors"
"fmt"
"os/exec"
"time"
)
// reloadTimeout is the longest a reload command may run before it is
// cancelled.
const reloadTimeout = 30 * time.Second

// Reload runs the configured web server reload command through "sh -c" and
// returns its combined stdout+stderr.
//
// An empty command is rejected without running anything. The command is
// cancelled once reloadTimeout has elapsed, which is reported as a timeout
// error; any other failure is wrapped together with the captured output. The
// output is returned in every case where the command was started.
func Reload(command string) (string, error) {
	if command == "" {
		return "", errors.New("reload command is empty")
	}

	ctx, cancel := context.WithTimeout(context.Background(), reloadTimeout)
	defer cancel()

	// Run through shell to support compound commands, quoted paths, etc.
	// #nosec G204 -- `command` is the operator-configured reload command
	// from csm.yaml (e.g. "apachectl graceful"), loaded at daemon startup
	// from a root-owned config. Not webui-settable.
	raw, runErr := exec.CommandContext(ctx, "sh", "-c", command).CombinedOutput()
	output := string(raw)

	switch {
	case ctx.Err() == context.DeadlineExceeded:
		return output, fmt.Errorf("reload timed out after %v", reloadTimeout)
	case runErr != nil:
		return output, fmt.Errorf("reload failed: %w (output: %s)", runErr, output)
	}
	return output, nil
}
// Package mysqlclient wraps the database/sql + go-sql-driver/mysql
// pair for the read-only queries CSM issues against host-local
// MySQL/MariaDB. It replaces the per-call `mysql -e <query>`
// shell-outs across CMS DB scans, performance metrics, and forensic
// dumps so the daemon no longer forks and tears down a child process
// (plus its libc/libmariadbclient relocations) every query.
//
// Two cred modes are supported:
//
// - Root: implicit, mirrors the historical `mysql` CLI behaviour
// that reads /root/.my.cnf [client] section when present. RootDB
// parses that file and returns a *sql.DB pointed at unix-socket
// auth, falling back to the mysql CLI's root socket-auth default
// when absent.
//
// - Per-account: the caller passes explicit user / password / host
// / dbname / port / socket via DSN. PerAccountQuery opens,
// runs, closes.
//
// Output shape mirrors the previous `mysql -N -B -e` shell-out: rows
// are returned as []string where each entry is the tab-joined column
// values of one result row, with MySQL batch-mode escaping applied.
// Existing scanner code paths consume this format unchanged.
package mysqlclient
import (
"bufio"
"context"
"database/sql"
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/go-sql-driver/mysql"
)
// queryTimeout caps a single statement to a defensive 30 s -- well under
// the legacy 2 min CLI cmdTimeout but enough for any realistic CMS
// scan / SHOW STATUS / SHOW PROCESSLIST result. It is used as the
// per-call context deadline, as the driver read timeout, and as the
// connection max lifetime of the short-lived per-call handles.
const queryTimeout = 30 * time.Second
// Creds carries per-account database credentials. Mirrors wpDBCreds in
// internal/checks so callers can pass the same shape without an extra
// adapter type.
type Creds struct {
// User and Password are copied verbatim into the driver config.
User string
Password string
// Host selects the endpoint. Empty or "localhost" means a Unix socket;
// a value starting with "/" is itself treated as a socket path;
// "name:/path" carries an embedded socket path; anything else is a TCP
// host, optionally with ":port".
Host string
// Port is the TCP port used when Host carries none; 0 means 3306.
Port int
// Socket is an explicit Unix socket path. It is honoured only when
// Host is empty or "localhost".
Socket string
// DBName is the default schema for the connection; may be empty.
DBName string
}
// dsn renders the credentials as a go-sql-driver/mysql connection string.
// The network and address come from networkAddr, so an empty Host resolves
// to a Unix socket the same way the mysql CLI did when -h was omitted.
func (c Creds) dsn() string {
	network, addr := c.networkAddr()

	cfg := mysql.NewConfig()
	cfg.Net = network
	cfg.Addr = addr
	cfg.User, cfg.Passwd, cfg.DBName = c.User, c.Password, c.DBName
	cfg.Loc = time.Local

	// Connect and write are bounded tightly; reads get the full
	// queryTimeout budget.
	cfg.Timeout = 5 * time.Second
	cfg.WriteTimeout = 5 * time.Second
	cfg.ReadTimeout = queryTimeout

	return cfg.FormatDSN()
}
// networkAddr resolves the credentials to a (network, address) pair for
// the driver. Resolution order:
//
//  1. explicit Socket, when Host is empty or "localhost";
//  2. Host that is itself an absolute path;
//  3. Host of the form "name:/socket/path";
//  4. empty or "localhost" Host -> defaultUnixSocket;
//  5. otherwise TCP, with the port taken from Host or c.Port.
func (c Creds) networkAddr() (string, string) {
	host := strings.TrimSpace(c.Host)
	socket := strings.TrimSpace(c.Socket)
	local := host == "" || host == "localhost"

	switch {
	case socket != "" && local:
		return "unix", socket
	case strings.HasPrefix(host, "/"):
		return "unix", host
	}
	if _, embedded, ok := splitHostSocket(host); ok {
		return "unix", embedded
	}
	if local {
		return "unix", defaultUnixSocket()
	}

	name, port := splitHostPort(host, c.Port)
	return "tcp", net.JoinHostPort(name, strconv.Itoa(port))
}
// splitHostSocket splits a "name:/socket/path" value at the last ":/"
// into the host name and the socket path (which keeps its leading "/").
// ok is false, and both strings empty, when host contains no ":/".
func splitHostSocket(host string) (string, string, bool) {
	colon := strings.LastIndex(host, ":/")
	switch {
	case colon < 0, colon+1 == len(host):
		return "", "", false
	}
	name, socket := host[:colon], host[colon+1:]
	return name, socket, true
}
// splitHostPort separates a TCP host from its port.
//
// It first tries net.SplitHostPort, then a plain single-colon split; a
// numeric port found either way wins. Otherwise the whole input is taken
// as the host (IPv6 brackets stripped) and fallbackPort is used, with 0
// meaning the MySQL default 3306.
func splitHostPort(host string, fallbackPort int) (string, int) {
	if name, portText, err := net.SplitHostPort(host); err == nil {
		if port, convErr := strconv.Atoi(portText); convErr == nil {
			return trimIPv6Brackets(name), port
		}
	}

	if strings.Count(host, ":") == 1 {
		name, portText, _ := strings.Cut(host, ":")
		if port, convErr := strconv.Atoi(portText); convErr == nil {
			return name, port
		}
	}

	port := fallbackPort
	if port == 0 {
		port = 3306
	}
	return trimIPv6Brackets(host), port
}
// trimIPv6Brackets removes one leading "[" and one trailing "]" from
// host, each independently of the other.
func trimIPv6Brackets(host string) string {
	host = strings.TrimPrefix(host, "[")
	return strings.TrimSuffix(host, "]")
}
// defaultUnixSocket probes the known mysql socket locations in order and
// returns the first one that os.Stat can see. When none exists it returns
// the cPanel canonical path anyway so callers always get a usable string.
func defaultUnixSocket() string {
	const canonical = "/var/lib/mysql/mysql.sock"

	for _, candidate := range [...]string{
		canonical,
		"/tmp/mysql.sock",
		"/var/run/mysqld/mysqld.sock",
	} {
		if _, err := os.Stat(candidate); err != nil {
			continue
		}
		return candidate
	}
	return canonical
}
// PerAccountQuery opens a short-lived handle with the supplied
// credentials, runs the query, closes the handle, and returns each row
// as a tab-joined string (see runQuery for the exact row format).
//
// A query that matches no rows yields an empty result and a nil error;
// callers should test len(rows) rather than rows == nil. Errors cover
// open, query, scan, and iteration failures.
//
// Tests can intercept via SetPerAccountQueryForTest.
func PerAccountQuery(ctx context.Context, creds Creds, query string, args ...any) ([]string, error) {
	if mock := getPerAccountQueryMock(); mock != nil {
		return mock(ctx, creds, query, args...)
	}

	db, openErr := sql.Open("mysql", creds.dsn())
	if openErr != nil {
		return nil, fmt.Errorf("mysqlclient: open: %w", openErr)
	}
	defer func() { _ = db.Close() }()

	// One call, one connection: keep nothing idle so per-account scans
	// across many tenants do not leak connections.
	db.SetConnMaxLifetime(queryTimeout)
	db.SetMaxIdleConns(0)
	db.SetMaxOpenConns(1)

	return runQuery(ctx, db, query, args...)
}
// PerAccountQueryFunc is the signature SetPerAccountQueryForTest accepts.
// It receives exactly the arguments PerAccountQuery was called with and
// its return values are passed straight back to the caller.
type PerAccountQueryFunc func(ctx context.Context, creds Creds, query string, args ...any) ([]string, error)
// Test interceptor state for PerAccountQuery.
var (
// perAccountMockMu guards perAccountMock.
perAccountMockMu sync.RWMutex
// perAccountMock, when non-nil, replaces the real database path.
perAccountMock PerAccountQueryFunc
)
// SetPerAccountQueryForTest swaps in fn as the interceptor consulted by
// PerAccountQuery. Passing nil removes the interceptor and restores the
// real database/sql path. Production code paths must NOT call this.
func SetPerAccountQueryForTest(fn PerAccountQueryFunc) {
	perAccountMockMu.Lock()
	perAccountMock = fn
	perAccountMockMu.Unlock()
}
// getPerAccountQueryMock returns the currently installed PerAccountQuery
// interceptor, or nil when none is set. The read is taken under the
// read lock.
func getPerAccountQueryMock() PerAccountQueryFunc {
	perAccountMockMu.RLock()
	fn := perAccountMock
	perAccountMockMu.RUnlock()
	return fn
}
// runQuery executes query on db under a queryTimeout deadline derived
// from ctx and flattens the result set into one string per row.
//
// Each row is its column values joined with a tab. Non-NULL values go
// through mysqlBatchEscape; SQL NULL becomes the literal "NULL", matching
// what `mysql -N -B` printed so legacy parsers see the same bytes. The
// returned slice is non-nil (possibly empty) on success and nil on error.
//
// It is kept internal so a RootDB-backed singleton and the per-account
// path can share the row-iteration code.
func runQuery(ctx context.Context, db *sql.DB, query string, args ...any) ([]string, error) {
	qctx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	rows, err := db.QueryContext(qctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("mysqlclient: query: %w", err)
	}
	defer func() { _ = rows.Close() }()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("mysqlclient: columns: %w", err)
	}

	// Scan every column as a nullable string; dest holds the pointers
	// rows.Scan writes through.
	width := len(columns)
	cells := make([]sql.NullString, width)
	dest := make([]any, width)
	for idx := range cells {
		dest[idx] = &cells[idx]
	}

	result := make([]string, 0)
	for rows.Next() {
		if scanErr := rows.Scan(dest...); scanErr != nil {
			return nil, fmt.Errorf("mysqlclient: scan: %w", scanErr)
		}
		fields := make([]string, width)
		for idx := range cells {
			if !cells[idx].Valid {
				fields[idx] = "NULL"
				continue
			}
			fields[idx] = mysqlBatchEscape(cells[idx].String)
		}
		result = append(result, strings.Join(fields, "\t"))
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("mysqlclient: iterate: %w", iterErr)
	}
	return result, nil
}
// mysqlBatchEscape rewrites s the way this package emulates mysql batch
// output: NUL, newline, carriage return, tab and backslash each become a
// two-character backslash sequence (\0, \n, \r, \t, \\). All other bytes
// are copied through unchanged.
func mysqlBatchEscape(s string) string {
	var out strings.Builder
	start := 0 // beginning of the pending run of unescaped bytes
	for i := 0; i < len(s); i++ {
		var esc string
		switch s[i] {
		case 0:
			esc = `\0`
		case '\n':
			esc = `\n`
		case '\r':
			esc = `\r`
		case '\t':
			esc = `\t`
		case '\\':
			esc = `\\`
		default:
			continue
		}
		out.WriteString(s[start:i])
		out.WriteString(esc)
		start = i + 1
	}
	out.WriteString(s[start:])
	return out.String()
}
// --- Root creds via /root/.my.cnf ---------------------------------------
var (
// rootMu guards rootDB and rootPath.
rootMu sync.Mutex
// rootDB is the lazily created, pooled root handle; nil until the
// first successful rootSingleton call or after SetRootCnfPath.
rootDB *sql.DB
// rootPath is the option file loadRootCreds reads.
rootPath = "/root/.my.cnf"
)
// SetRootCnfPath points root credential loading at p instead of the
// default option file. Intended to be called before the first RootDB()
// call; tests set this to a tempdir-rooted .my.cnf.
//
// Any cached root handle is closed and dropped so the next RootDB()
// rebuilds from the new path.
func SetRootCnfPath(p string) {
	rootMu.Lock()
	defer rootMu.Unlock()

	rootPath = p
	if rootDB == nil {
		return
	}
	_ = rootDB.Close()
	rootDB = nil
}
// RootQuery runs query as root on the pooled RootDB handle, using the
// credentials from /root/.my.cnf (or the path set via SetRootCnfPath).
// It mirrors the historical `mysql -N -B -e <query>` invocation that
// picked up /root/.my.cnf implicitly; rows come back in runQuery's
// tab-joined format.
//
// Tests can intercept via SetRootQueryForTest; the interceptor is called
// with an empty schema.
func RootQuery(ctx context.Context, query string, args ...any) ([]string, error) {
	if mock := getRootQueryMock(); mock != nil {
		return mock(ctx, "", query, args...)
	}
	pool, err := RootDB()
	if err == nil {
		return runQuery(ctx, pool, query, args...)
	}
	return nil, err
}
// RootExec runs a non-SELECT statement (DDL/DML) as root on the pooled
// RootDB handle, bounded by queryTimeout, and returns the affected-row
// count. Mirrors `mysql -e <stmt>` for callers that previously relied on
// the exit code only.
//
// Tests can intercept via SetRootExecForTest; the interceptor is called
// with an empty schema.
func RootExec(ctx context.Context, stmt string, args ...any) (int64, error) {
	if mock := getRootExecMock(); mock != nil {
		return mock(ctx, "", stmt, args...)
	}

	pool, err := RootDB()
	if err != nil {
		return 0, err
	}

	execCtx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	result, err := pool.ExecContext(execCtx, stmt, args...)
	if err != nil {
		return 0, fmt.Errorf("mysqlclient: exec: %w", err)
	}
	// The statement itself succeeded; an error from RowsAffected is
	// deliberately ignored and the reported count returned as-is.
	affected, _ := result.RowsAffected()
	return affected, nil
}
// RootExecSchema runs a non-SELECT statement as root with schema as the
// connection's default database. Mirrors `mysql <schema> -e <stmt>`.
//
// Unlike RootExec it does not use the pooled handle: it loads the root
// credentials, opens a dedicated single-connection handle for this call,
// and closes it before returning. The call is bounded by queryTimeout.
//
// Tests can intercept via SetRootExecForTest.
func RootExecSchema(ctx context.Context, schema, stmt string, args ...any) (int64, error) {
	if mock := getRootExecMock(); mock != nil {
		return mock(ctx, schema, stmt, args...)
	}

	creds, err := loadRootCreds()
	if err != nil {
		return 0, err
	}
	creds.DBName = schema

	handle, err := sql.Open("mysql", creds.dsn())
	if err != nil {
		return 0, fmt.Errorf("mysqlclient: open %s: %w", schema, err)
	}
	defer func() { _ = handle.Close() }()
	handle.SetConnMaxLifetime(queryTimeout)
	handle.SetMaxIdleConns(0)
	handle.SetMaxOpenConns(1)

	execCtx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	result, err := handle.ExecContext(execCtx, stmt, args...)
	if err != nil {
		return 0, fmt.Errorf("mysqlclient: exec %s: %w", schema, err)
	}
	// RowsAffected's error is deliberately ignored, as in RootExec.
	affected, _ := result.RowsAffected()
	return affected, nil
}
// RootQueryFunc is the signature SetRootQueryForTest accepts. schema is
// empty when the call came from RootQuery and carries the caller's schema
// argument when it came from RootQuerySchema.
type RootQueryFunc func(ctx context.Context, schema, query string, args ...any) ([]string, error)
// RootExecFunc is the signature SetRootExecForTest accepts. schema is
// empty when the call came from RootExec and carries the caller's schema
// argument when it came from RootExecSchema.
type RootExecFunc func(ctx context.Context, schema, stmt string, args ...any) (int64, error)
// Test interceptor state for the Root* entry points. Each interceptor has
// its own lock.
var (
// rootQueryMockMu guards rootQueryMock.
rootQueryMockMu sync.RWMutex
// rootQueryMock, when non-nil, replaces RootQuery / RootQuerySchema.
rootQueryMock RootQueryFunc
// rootExecMockMu guards rootExecMock.
rootExecMockMu sync.RWMutex
// rootExecMock, when non-nil, replaces RootExec / RootExecSchema.
rootExecMock RootExecFunc
)
// SetRootQueryForTest swaps in fn as the interceptor consulted by
// RootQuery and RootQuerySchema. Passing nil removes it and restores the
// real database/sql path. Production code paths must NOT call this.
func SetRootQueryForTest(fn RootQueryFunc) {
	rootQueryMockMu.Lock()
	rootQueryMock = fn
	rootQueryMockMu.Unlock()
}
// SetRootExecForTest swaps in fn as the interceptor consulted by
// RootExec and RootExecSchema. Passing nil removes it and restores the
// real database/sql path. Production code paths must NOT call this.
func SetRootExecForTest(fn RootExecFunc) {
	rootExecMockMu.Lock()
	rootExecMock = fn
	rootExecMockMu.Unlock()
}
// getRootQueryMock returns the installed root-query interceptor, or nil
// when none is set. The read is taken under the read lock.
func getRootQueryMock() RootQueryFunc {
	rootQueryMockMu.RLock()
	fn := rootQueryMock
	rootQueryMockMu.RUnlock()
	return fn
}
// getRootExecMock returns the installed root-exec interceptor, or nil
// when none is set. The read is taken under the read lock.
func getRootExecMock() RootExecFunc {
	rootExecMockMu.RLock()
	fn := rootExecMock
	rootExecMockMu.RUnlock()
	return fn
}
// RootDB returns the pooled root MySQL handle backed by /root/.my.cnf
// when present (or the path set via SetRootCnfPath), falling back to
// the mysql CLI's root socket-auth default. sql.Open does not contact
// the server until the first query, so a nil error here does not mean
// the server is reachable.
//
// It is a thin exported wrapper around rootSingleton; repeated calls
// return the same handle until SetRootCnfPath drops it.
func RootDB() (*sql.DB, error) {
return rootSingleton()
}
// RootQuerySchema runs query as root with schema as the connection's
// default database. Mirrors `mysql <schema> -e <query>`.
//
// It loads the root credentials, sets DBName to schema, and delegates to
// PerAccountQuery, so it uses a short-lived handle rather than the
// RootDB pool. Tests can intercept via SetRootQueryForTest.
func RootQuerySchema(ctx context.Context, schema, query string, args ...any) ([]string, error) {
	if mock := getRootQueryMock(); mock != nil {
		return mock(ctx, schema, query, args...)
	}
	rootCreds, err := loadRootCreds()
	if err == nil {
		rootCreds.DBName = schema
		return PerAccountQuery(ctx, rootCreds, query, args...)
	}
	return nil, err
}
// rootSingleton returns the cached root handle, creating it on first use.
//
// The credentials are loaded and the handle opened without holding
// rootMu (loadRootCreds takes the lock itself). The cache is re-checked
// under the lock before publishing: if another goroutine won the race,
// the freshly opened handle is closed and the winner's is returned.
func rootSingleton() (*sql.DB, error) {
	rootMu.Lock()
	cached := rootDB
	rootMu.Unlock()
	if cached != nil {
		return cached, nil
	}

	creds, err := loadRootCreds()
	if err != nil {
		return nil, err
	}
	fresh, err := sql.Open("mysql", creds.dsn())
	if err != nil {
		return nil, fmt.Errorf("mysqlclient: root open: %w", err)
	}
	// Conservative pool: root scans run at most a few queries per
	// minute, sharing one idle connection is plenty.
	fresh.SetConnMaxLifetime(5 * time.Minute)
	fresh.SetMaxIdleConns(1)
	fresh.SetMaxOpenConns(2)

	rootMu.Lock()
	defer rootMu.Unlock()
	if rootDB == nil {
		rootDB = fresh
		return rootDB, nil
	}
	_ = fresh.Close()
	return rootDB, nil
}
// loadRootCreds builds the root credentials from /root/.my.cnf (or the
// path set via SetRootCnfPath).
//
// The file is read as a simple INI subset: blank lines and lines starting
// with "#" or ";" are skipped, "[section]" headers switch sections, and
// only key=value pairs inside a [client] or [mysql] section (matched
// case-insensitively) are honoured. Values may be unquoted, single
// quoted, or double quoted; a matching pair of quotes is stripped.
// Recognised keys are user, password/pass, host, port and socket; a port
// that is not an integer is ignored.
//
// A missing file, or a file without a non-empty user key, falls back to
// user=root with an empty password so Unix socket-auth setups keep
// matching `mysql -e`. Any other open or read error is returned.
func loadRootCreds() (Creds, error) {
	rootMu.Lock()
	cnfPath := rootPath
	rootMu.Unlock()

	result := Creds{User: "root"}

	// #nosec G304 -- path is operator-configured at init time, not
	// attacker-controlled.
	file, err := os.Open(cnfPath)
	switch {
	case err == nil:
	case os.IsNotExist(err):
		return result, nil
	default:
		return Creds{}, fmt.Errorf("mysqlclient: open %s: %w", cnfPath, err)
	}
	defer func() { _ = file.Close() }()

	lines := bufio.NewScanner(file)
	lines.Buffer(make([]byte, 0, 4096), 1<<20)

	wanted := false // true while inside a [client] or [mysql] section
	for lines.Scan() {
		line := strings.TrimSpace(lines.Text())
		switch {
		case line == "",
			strings.HasPrefix(line, "#"),
			strings.HasPrefix(line, ";"):
			continue
		case strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]"):
			wanted = strings.EqualFold(line, "[client]") || strings.EqualFold(line, "[mysql]")
			continue
		case !wanted:
			continue
		}

		rawKey, rawVal, found := strings.Cut(line, "=")
		if !found {
			continue
		}
		value := unquote(strings.TrimSpace(rawVal))
		switch strings.ToLower(strings.TrimSpace(rawKey)) {
		case "user":
			// An empty user keeps the root default.
			if value != "" {
				result.User = value
			}
		case "password", "pass":
			result.Password = value
		case "host":
			result.Host = value
		case "port":
			if port, convErr := strconv.Atoi(value); convErr == nil {
				result.Port = port
			}
		case "socket":
			result.Socket = value
		}
	}
	if scanErr := lines.Err(); scanErr != nil {
		return Creds{}, fmt.Errorf("mysqlclient: read %s: %w", cnfPath, scanErr)
	}
	return result, nil
}
// unquote strips one pair of matching surrounding quotes (both double or
// both single) from s. Strings shorter than two bytes, unquoted strings,
// and strings with mismatched quotes are returned unchanged.
func unquote(s string) string {
	if len(s) < 2 {
		return s
	}
	switch quote := s[0]; quote {
	case '"', '\'':
		if s[len(s)-1] == quote {
			return s[1 : len(s)-1]
		}
	}
	return s
}
// Package obs centralises crash reporting and selective error capture via
// Sentry. Init is a one-shot called from the daemon entry point; after
// that, callers use Go/SafeGo to launch goroutines with panic recovery
// and Capture/CaptureMsg to forward selected errors.
//
// When Sentry is disabled or the DSN is empty, every function becomes a
// no-op wrapper with the same semantics as a plain `go func()` call,
// so guarded call sites work unchanged in tests and in operator
// builds that opt out of telemetry.
package obs
import (
"fmt"
"os"
"sync/atomic"
"time"
"github.com/getsentry/sentry-go"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
)
// enabled is set to true once Init has configured the Sentry SDK
// successfully. Every other function in this package checks it and
// skips its Sentry calls while it is false.
var enabled atomic.Bool
// flushTimeout bounds how long a flush blocks before giving up. It is
// passed to sentry.Flush both on shutdown (Flush) and after a recovered
// panic (report), so stuck HTTP calls don't hang systemd past
// TimeoutStopSec.
const flushTimeout = 2 * time.Second
// Init configures the Sentry SDK once at daemon startup.
//
// It returns nil without doing anything when cfg is nil, Sentry is
// disabled, or the DSN is empty, so callers can treat init as optional.
// Otherwise it initialises the SDK (environment defaults to "production",
// a non-positive sample rate becomes 1.0, the release is "csm@<version>"
// with "+<buildHash>" appended when the hash is known), tags the global
// scope with the detected platform, and marks the package enabled.
// An SDK initialisation failure is returned wrapped.
func Init(cfg *config.Config, version, buildHash string) error {
	if cfg == nil || cfg.Sentry.DSN == "" || !cfg.Sentry.Enabled {
		return nil
	}

	opts := sentry.ClientOptions{
		Dsn:              cfg.Sentry.DSN,
		Environment:      cfg.Sentry.Environment,
		Release:          "csm@" + version,
		SampleRate:       cfg.Sentry.SampleRate,
		TracesSampleRate: 0.0,
		Debug:            cfg.Sentry.Debug,
		AttachStacktrace: true,
	}
	if opts.Environment == "" {
		opts.Environment = "production"
	}
	if opts.SampleRate <= 0 {
		opts.SampleRate = 1.0
	}
	if buildHash != "" && buildHash != "unknown" {
		opts.Release += "+" + buildHash
	}
	// Best effort: an unknown hostname just leaves ServerName empty.
	opts.ServerName, _ = os.Hostname()

	if err := sentry.Init(opts); err != nil {
		return fmt.Errorf("sentry init: %w", err)
	}

	host := platform.Detect()
	sentry.ConfigureScope(func(scope *sentry.Scope) {
		scope.SetTag("os", string(host.OS))
		// Optional dimensions are only tagged when detection found them.
		if v := host.OSVersion; v != "" {
			scope.SetTag("os_version", v)
		}
		if p := host.Panel; p != "" {
			scope.SetTag("panel", string(p))
		}
		if ws := host.WebServer; ws != "" {
			scope.SetTag("webserver", string(ws))
		}
	})

	enabled.Store(true)
	return nil
}
// Enabled reports whether Init succeeded and Sentry is live. It is a
// plain atomic read of the package-level flag.
func Enabled() bool { return enabled.Load() }
// Flush blocks for up to flushTimeout while queued events are sent. Call
// it from shutdown paths before exit. It does nothing when Sentry is
// disabled.
func Flush() {
	if enabled.Load() {
		sentry.Flush(flushTimeout)
	}
}
// Go launches fn in a new goroutine. A panic in fn is captured with the
// given component tag, the event is flushed, and the panic is
// re-raised so the existing crash-and-systemd-restart behavior is
// preserved. Use this for long-lived supervisor goroutines where a
// silent death would leave the daemon in a degraded state.
//
// When Sentry is disabled report is a no-op, so the panic is simply
// re-raised.
func Go(component string, fn func()) {
go func() {
// recover must be called directly by the deferred function, so the
// handler stays inline rather than being factored into a helper.
defer func() {
if r := recover(); r != nil {
report(component, r)
// Re-raise with the original value so the process still crashes.
panic(r)
}
}()
fn()
}()
}
// SafeGo is Go but swallows the panic after capture: the goroutine ends
// and the rest of the process keeps running. Use this for per-request
// handlers (socket accept, HTTP request) where one bad input should not
// crash the whole daemon.
//
// When Sentry is disabled report is a no-op, so the panic is swallowed
// without being recorded anywhere by this package.
func SafeGo(component string, fn func()) {
go func() {
// recover must be called directly by the deferred function.
defer func() {
if r := recover(); r != nil {
report(component, r)
}
}()
fn()
}()
}
// Capture forwards err to Sentry inside a temporary scope tagged with
// component. It does nothing when err is nil or Sentry is disabled.
// Reserve it for unexpected states and invariant violations;
// expected-failure errors (permission denied, transient network) should
// stay out of Sentry to avoid noise.
func Capture(component string, err error) {
	if err == nil || !enabled.Load() {
		return
	}
	sentry.WithScope(func(s *sentry.Scope) {
		s.SetTag("component", component)
		sentry.CaptureException(err)
	})
}
// CaptureMsg is the string-only counterpart of Capture, for events such
// as invariant violations that have no wrapped error. The message is sent
// inside a temporary scope tagged with component. It does nothing when
// Sentry is disabled.
func CaptureMsg(component, msg string) {
	if enabled.Load() {
		sentry.WithScope(func(s *sentry.Scope) {
			s.SetTag("component", component)
			sentry.CaptureMessage(msg)
		})
	}
}
// report sends a recovered panic value r to Sentry, tagged with
// component, and then flushes for up to flushTimeout so the event has a
// chance to leave the process before the caller re-panics or returns.
// It does nothing when Sentry is disabled.
func report(component string, r any) {
	if enabled.Load() {
		sentry.WithScope(func(s *sentry.Scope) {
			s.SetTag("component", component)
			sentry.CurrentHub().Recover(r)
		})
		sentry.Flush(flushTimeout)
	}
}
package platform
// MTAIdents lists local users and process basenames belonging to the
// host's Mail Transfer Agent stack. Direct SMTP egress detection uses
// this allowlist to skip legitimate local MTA traffic instead of
// path-allowlisting a directory.
type MTAIdents struct {
// Users holds the usernames matched exactly by IsMTAUser.
Users []string
// Processes holds the process basenames matched exactly by
// IsMTAProcess.
Processes []string
}
// IsMTAUser reports whether name appears in m.Users. The comparison is
// exact and case-sensitive, as Linux usernames are.
func (m MTAIdents) IsMTAUser(name string) bool {
	for idx := range m.Users {
		if m.Users[idx] == name {
			return true
		}
	}
	return false
}
// IsMTAProcess reports whether basename appears in m.Processes. The
// comparison is exact; the caller passes comm or basename(exe), not a
// full path.
func (m MTAIdents) IsMTAProcess(basename string) bool {
	for idx := range m.Processes {
		if m.Processes[idx] == basename {
			return true
		}
	}
	return false
}
// LocalMTAIdentities returns the MTA users and process basenames that
// should be considered legitimate on the detected platform. Every host
// gets the postfix/dovecot baseline; cPanel hosts additionally get the
// exim user and the exim/exim4 process names.
func LocalMTAIdentities(info Info) MTAIdents {
	idents := MTAIdents{
		Users: []string{
			"mail",
			"mailnull",
			"postfix",
			"dovecot",
			"dovenull",
			"mailman",
		},
		Processes: []string{
			"postfix",
			"smtpd",
			"smtp",
			"qmgr",
			"pickup",
			"cleanup",
			"local",
			"dovecot",
			"imap-login",
			"pop3-login",
			"lmtp",
		},
	}
	if !info.IsCPanel() {
		return idents
	}
	idents.Users = append(idents.Users, "exim")
	idents.Processes = append(idents.Processes, "exim", "exim4")
	return idents
}
// Package platform detects the host OS, control panel, and web server so
// CSM checks can pick the right config/log paths instead of hardcoding
// cPanel+Apache layouts.
package platform
import (
"bufio"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
)
// OSFamily identifies the host distribution. Apart from OSUnknown, the
// values are the lowercased ID strings from /etc/os-release that
// detectOS recognises.
type OSFamily string
const (
// OSUnknown is the zero value: os-release was unreadable or its ID is
// not one of the distributions below.
OSUnknown OSFamily = ""
OSUbuntu OSFamily = "ubuntu"
OSDebian OSFamily = "debian"
OSAlma OSFamily = "almalinux"
OSRocky OSFamily = "rocky"
OSCentOS OSFamily = "centos"
OSRHEL OSFamily = "rhel"
OSCloudLinux OSFamily = "cloudlinux"
)
// Panel identifies the hosting control panel found on the host.
type Panel string
const (
// PanelNone is the zero value: no known panel marker file was found,
// or the operator forced a panel-less view via Overrides.
PanelNone Panel = ""
PanelCPanel Panel = "cpanel"
PanelPlesk Panel = "plesk"
PanelDA Panel = "directadmin"
)
// WebServer identifies the primary web server chosen by selectWebServer.
type WebServer string
const (
// WSNone is the zero value: no web server is running and no Apache or
// nginx binary was found, or the operator forced it via Overrides.
WSNone WebServer = ""
WSApache WebServer = "apache"
WSNginx WebServer = "nginx"
WSLiteSpeed WebServer = "litespeed"
)
// Info holds everything a check needs to locate web server resources.
type Info struct {
// OS and OSVersion come from the ID and VERSION_ID keys of
// /etc/os-release; both stay at their zero value when the file cannot
// be read.
OS OSFamily
OSVersion string
// Panel is the control panel detected from marker files.
Panel Panel
// WebServer is the primary web server (see selectWebServer).
WebServer WebServer
// Config locations for the detected web server. Left empty when the
// directory does not exist.
ApacheConfigDir string // e.g. /etc/apache2 or /etc/httpd
NginxConfigDir string // e.g. /etc/nginx
// Candidate log files. Populated based on detected web server + OS.
// Entries are candidates only; they are not checked for existence here.
AccessLogPaths []string
ErrorLogPaths []string
ModSecAuditLogPaths []string
// DomlogGlobs is the list of glob patterns used to enumerate
// per-vhost access logs. Populated per panel + OS in populatePaths.
DomlogGlobs []string
// Binary paths useful for reload/control. Recorded whenever the binary
// is found, even if that server is not the primary WebServer.
ApacheBinary string
NginxBinary string
}
// IsCPanel reports whether the detected (or overridden) panel is cPanel.
// It is a convenience for checks that still need to gate cPanel-only
// behavior (WHM API calls, /home/*/public_html enumeration, exim log
// tailing, etc.) without re-detecting each time.
func (i Info) IsCPanel() bool { return i.Panel == PanelCPanel }
// IsRHELFamily reports whether the OS uses rpm/dnf and /etc/httpd style
// paths, i.e. AlmaLinux, Rocky, CentOS, RHEL or CloudLinux.
func (i Info) IsRHELFamily() bool {
	return i.OS == OSAlma ||
		i.OS == OSRocky ||
		i.OS == OSCentOS ||
		i.OS == OSRHEL ||
		i.OS == OSCloudLinux
}
// IsDebianFamily reports whether the OS uses dpkg/apt and /etc/apache2
// style paths, i.e. Ubuntu or Debian.
func (i Info) IsDebianFamily() bool {
	switch i.OS {
	case OSUbuntu, OSDebian:
		return true
	}
	return false
}
// MailLogPath returns the platform-default mail log file:
// /var/log/mail.log on Debian-family hosts and /var/log/maillog on
// everything else (including an unknown OS). The path is a default only;
// its existence is not checked.
func (i Info) MailLogPath() string {
	path := "/var/log/maillog"
	if i.IsDebianFamily() {
		path = "/var/log/mail.log"
	}
	return path
}
// AuthLogPath returns the platform-default SSH/PAM auth log file:
// /var/log/auth.log on Debian-family hosts and /var/log/secure on
// everything else (RHEL-family, and also an unknown OS).
func (i Info) AuthLogPath() string {
	path := "/var/log/secure"
	if i.IsDebianFamily() {
		path = "/var/log/auth.log"
	}
	return path
}
// Overrides lets the operator override auto-detected values from csm.yaml.
// Any field left blank or nil falls back to the auto-detected value.
//
// Panel and WebServer use pointer types so callers can distinguish "leave
// auto-detected" (nil) from "explicitly override to none" (pointer to
// PanelNone / WSNone). The non-pointer string/slice fields use the
// zero-value-means-unset convention since they have no legitimate "none"
// value to override to.
type Overrides struct {
// Panel, when non-nil, replaces the detected panel.
Panel *Panel
// WebServer, when non-nil, replaces the detected web server and
// triggers a rebuild of the log path lists.
WebServer *WebServer
// The path lists below replace (not extend) the corresponding Info
// list when non-empty.
AccessLogPaths []string
ErrorLogPaths []string
ModSecAuditLogPaths []string
DomlogGlobs []string
// Config directories replace the detected value when non-empty.
ApacheConfigDir string
NginxConfigDir string
}
// Cached detection state shared by Detect, SetOverrides and ResetForTest.
var (
// detected is the cached result returned by Detect.
detected Info
// detectedOnce makes Detect run the probe at most once per process
// (until ResetForTest replaces it).
detectedOnce sync.Once
// overrideMu guards pendingOverride.
overrideMu sync.Mutex
// pendingOverride is the operator override merged into the first
// Detect result; nil when none was installed.
pendingOverride *Overrides
)
// SetOverrides installs config-supplied overrides to be merged into the
// cached Detect() result. Call it once from daemon startup, BEFORE the
// first Detect() call, so the merged info is what every check sees. A
// later SetOverrides call that still precedes Detect() replaces the
// earlier override.
//
// It returns true when the override was stored before Detect() cached a
// result, and false when Detect() had already run. In the false case the
// cached Info is not changed; callers are expected to surface a warning
// based on the return value.
func SetOverrides(o Overrides) bool {
	overrideMu.Lock()
	defer overrideMu.Unlock()

	// sync.Once has no "was it called" query, so detection state is read
	// from the separate flag behind isDetected.
	overwriteBlocked := pendingOverride != nil && isDetected()
	if !overwriteBlocked {
		pendingOverride = &o
		return !isDetected()
	}
	return false
}
// detectedFlag is set once Detect() has cached a result. It is kept as a
// separate flag because sync.Once has no query API.
var detectedFlag atomic.Bool
// isDetected reports whether Detect() has already cached a result.
func isDetected() bool { return detectedFlag.Load() }
// Detect inspects the host and returns platform info. The result is cached
// for the process lifetime — callers that need a fresh probe should use
// DetectFresh instead.
//
// On the first call it runs DetectFresh, merges any override installed
// via SetOverrides, and caches the merged value. Later calls return the
// cached value without probing again.
func Detect() Info {
	detectedOnce.Do(func() {
		// Probe outside the lock: detection shells out and can be slow,
		// and SetOverrides should not block on it.
		fresh := DetectFresh()

		overrideMu.Lock()
		defer overrideMu.Unlock()
		if pendingOverride != nil {
			fresh = applyOverrides(fresh, *pendingOverride)
		}
		detected = fresh
		// Publish the flag while still holding overrideMu. SetOverrides
		// reads it under the same lock, so it can no longer slip in after
		// the merge, store an override that will never be applied, and
		// report success.
		detectedFlag.Store(true)
	})
	return detected
}
// DetectFresh always re-runs detection, ignoring any cached result.
// Intended for tests and for operator-triggered rescan. Does not apply
// config overrides — use Detect() for the operator-visible view.
//
// The probes run in a fixed order because later ones read what earlier
// ones wrote: OS, then panel, then web server, then derived paths.
func DetectFresh() Info {
	var info Info
	for _, probe := range []func(*Info){
		detectOS,
		detectPanel,
		detectWebServer,
		populatePaths,
	} {
		probe(&info)
	}
	return info
}
// applyOverrides merges non-empty override fields into info and returns
// the merged copy. info is received by value and its slices are only
// ever reassigned, never written in place, so the caller's Info is not
// modified. Paths are replaced, not appended: if the operator configured
// an explicit access-log list, the auto-detected list is discarded so
// operators have full control.
//
// Order matters: Panel first, then the WebServer-driven path rebuild,
// then the DomlogGlobs recompute, and only then the explicit path and
// directory overrides so they win over anything recomputed above.
func applyOverrides(info Info, o Overrides) Info {
// Panel override must happen before path rebuild so populatePaths
// picks up the cPanel overlay (or drops it) correctly. Nil means
// "leave auto-detected"; a non-nil pointer always wins, even when it
// points at PanelNone, so operators can explicitly force a host to
// look panel-less.
// NOTE(review): the log path lists are only rebuilt in the WebServer
// branch below. With a Panel-only override the access/error/ModSec
// lists keep their auto-detected contents and only DomlogGlobs is
// recomputed -- confirm this is intended.
if o.Panel != nil {
info.Panel = *o.Panel
}
if o.WebServer != nil {
// Web server type changed -- rebuild paths from scratch unless the
// operator also supplied path overrides below. Same nil-vs-pointer
// semantics as Panel: a pointer at WSNone forces "no web server"
// instead of being silently ignored.
info.WebServer = *o.WebServer
// Clear first: populatePaths leaves the lists untouched for WSNone
// and prepends to them for cPanel.
info.AccessLogPaths = nil
info.ErrorLogPaths = nil
info.ModSecAuditLogPaths = nil
populatePaths(&info)
}
// When Panel or WebServer changed, recompute DomlogGlobs so the globs
// stay consistent with the new panel/web-server combination. Explicit
// DomlogGlobs override below wins over this recompute.
if o.Panel != nil || o.WebServer != nil {
info.DomlogGlobs = nil
populateDomlogGlobs(&info)
}
// Explicit list overrides are copied so the returned Info does not
// alias the caller's Overrides slices.
if len(o.AccessLogPaths) > 0 {
info.AccessLogPaths = append([]string(nil), o.AccessLogPaths...)
}
if len(o.ErrorLogPaths) > 0 {
info.ErrorLogPaths = append([]string(nil), o.ErrorLogPaths...)
}
if len(o.ModSecAuditLogPaths) > 0 {
info.ModSecAuditLogPaths = append([]string(nil), o.ModSecAuditLogPaths...)
}
if len(o.DomlogGlobs) > 0 {
info.DomlogGlobs = append([]string(nil), o.DomlogGlobs...)
}
if o.ApacheConfigDir != "" {
info.ApacheConfigDir = o.ApacheConfigDir
}
if o.NginxConfigDir != "" {
info.NginxConfigDir = o.NginxConfigDir
}
return info
}
// ResetForTest returns the package to its pre-detection state: the cached
// Info, the once-guard, any pending override and the detected flag are
// all cleared, so tests can re-run Detect with different fixtures. Never
// call from production code.
func ResetForTest() {
	overrideMu.Lock()
	detected, detectedOnce, pendingOverride = Info{}, sync.Once{}, nil
	detectedFlag.Store(false)
	overrideMu.Unlock()
}
// detectOS fills i.OS and i.OSVersion from /etc/os-release.
//
// Only the ID and VERSION_ID keys are read; surrounding single or double
// quotes are trimmed and, if a key repeats, the last occurrence wins.
// i.OSVersion is set to VERSION_ID verbatim (empty when absent). i.OS is
// set only when the lowercased ID is a known distribution and is left
// untouched otherwise. If the file cannot be opened, i is not modified.
func detectOS(i *Info) {
	f, err := os.Open("/etc/os-release")
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()

	fields := map[string]string{}
	for sc := bufio.NewScanner(f); sc.Scan(); {
		key, val, ok := strings.Cut(sc.Text(), "=")
		if !ok || (key != "ID" && key != "VERSION_ID") {
			continue
		}
		fields[key] = strings.Trim(val, `"'`)
	}

	i.OSVersion = fields["VERSION_ID"]

	known := map[string]OSFamily{
		"ubuntu":     OSUbuntu,
		"debian":     OSDebian,
		"almalinux":  OSAlma,
		"rocky":      OSRocky,
		"centos":     OSCentOS,
		"rhel":       OSRHEL,
		"cloudlinux": OSCloudLinux,
	}
	if family, ok := known[strings.ToLower(fields["ID"])]; ok {
		i.OS = family
	}
}
// detectPanel sets i.Panel from the first control-panel marker file that
// exists, checked in the order cPanel, Plesk, DirectAdmin. i.Panel is
// left unchanged when none of the markers is present.
func detectPanel(i *Info) {
	markers := []struct {
		path  string
		panel Panel
	}{
		{"/usr/local/cpanel/version", PanelCPanel},
		{"/usr/local/psa/version", PanelPlesk},
		{"/usr/local/directadmin/directadmin", PanelDA},
	}
	for _, m := range markers {
		if _, err := os.Stat(m.path); err == nil {
			i.Panel = m.panel
			return
		}
	}
}
// detectWebServer records the nginx and Apache binary paths in i and
// then sets i.WebServer via selectWebServer.
//
// Running services are preferred; installed binaries are the fallback
// when nothing is running yet (first boot, non-systemd env). i.Panel must
// already be populated because the selection order depends on it.
func detectWebServer(i *Info) {
	running := runningServices()

	// Always record binary paths for reload/control, even if that server
	// does not end up being the primary one.
	if nginx, err := exec.LookPath("nginx"); err == nil {
		i.NginxBinary = nginx
	}
	for _, name := range []string{"apache2", "httpd"} {
		if apache, err := exec.LookPath(name); err == nil {
			i.ApacheBinary = apache
			break
		}
	}
	// cPanel compiles its own httpd under /usr/local/apache/bin/httpd,
	// which isn't always in PATH for root under the CSM service unit.
	if i.ApacheBinary == "" {
		const cpHttpd = "/usr/local/apache/bin/httpd"
		if _, err := os.Stat(cpHttpd); err == nil {
			i.ApacheBinary = cpHttpd
		}
	}

	haveApache, haveNginx := i.ApacheBinary != "", i.NginxBinary != ""
	i.WebServer = selectWebServer(i.Panel, running, haveApache, haveNginx)
}
// runningServices reports which web server units are currently active by
// running `systemctl is-active --quiet <unit>` for each known unit name.
// A unit is included (mapped to true) only when the command exits
// successfully; any failure, including systemctl being unavailable,
// leaves it out. The returned map is never nil.
func runningServices() map[string]bool {
	units := [...]string{"nginx", "apache2", "httpd", "litespeed", "lshttpd", "lsws"}
	active := map[string]bool{}
	for _, unit := range units {
		// #nosec G204 -- systemctl hardcoded; unit iterates a literal slice.
		probe := exec.Command("systemctl", "is-active", "--quiet", unit)
		if probe.Run() != nil {
			continue
		}
		active[unit] = true
	}
	return active
}
// selectWebServer picks the primary web server from what is running and
// what is installed.
//
// Running servers always beat merely-installed binaries. On cPanel the
// running preference is LiteSpeed, Apache, nginx: cPanel commonly runs
// nginx as a reverse proxy in front of Apache, and the origin server's
// logs are the ones the real-time access and ModSecurity watchers need.
// Elsewhere the preference is nginx, Apache, LiteSpeed. With nothing
// running, an Apache binary wins over an nginx binary; with neither,
// WSNone is returned.
func selectWebServer(panel Panel, running map[string]bool, hasApacheBinary, hasNginxBinary bool) WebServer {
	apacheUp := running["apache2"] || running["httpd"]
	litespeedUp := running["litespeed"] || running["lshttpd"] || running["lsws"]
	nginxUp := running["nginx"]

	type candidate struct {
		ok bool
		ws WebServer
	}
	var preference []candidate
	if panel == PanelCPanel {
		preference = []candidate{
			{litespeedUp, WSLiteSpeed},
			{apacheUp, WSApache},
			{nginxUp, WSNginx},
		}
	} else {
		preference = []candidate{
			{nginxUp, WSNginx},
			{apacheUp, WSApache},
			{litespeedUp, WSLiteSpeed},
		}
	}
	// The installed-binary fallbacks are the same for every panel.
	preference = append(preference,
		candidate{hasApacheBinary, WSApache},
		candidate{hasNginxBinary, WSNginx},
	)

	for _, c := range preference {
		if c.ok {
			return c.ws
		}
	}
	return WSNone
}
// populatePaths derives config directories and candidate log paths from
// the OS, panel and web server already recorded in i, then fills the
// per-vhost globs via populateDomlogGlobs.
//
// Config directories are set only when the directory exists. Log path
// lists are replaced for a known web server and left as they are for
// WSNone; on cPanel the cPanel-specific paths are prepended to whatever
// the list holds at that point.
func populatePaths(i *Info) {
	// Apache config dir. cPanel compiles Apache from source and installs
	// under /usr/local/apache, separate from the OS package tree; that
	// override wins over the distro default when cPanel is present.
	if i.Panel == PanelCPanel && dirExists("/usr/local/apache/conf") {
		i.ApacheConfigDir = "/usr/local/apache/conf"
	} else if i.IsDebianFamily() {
		if dirExists("/etc/apache2") {
			i.ApacheConfigDir = "/etc/apache2"
		}
	} else if i.IsRHELFamily() {
		if dirExists("/etc/httpd") {
			i.ApacheConfigDir = "/etc/httpd"
		}
	}
	if dirExists("/etc/nginx") {
		i.NginxConfigDir = "/etc/nginx"
	}

	// Log paths: pick candidates based on detected web server and OS
	// layout. ALL plausible locations are included so log watchers can
	// try each; missing paths are handled upstream by the retry logic.
	switch ws := i.WebServer; {
	case ws == WSApache && i.IsDebianFamily():
		i.AccessLogPaths = []string{"/var/log/apache2/access.log", "/var/log/apache2/other_vhosts_access.log"}
		i.ErrorLogPaths = []string{"/var/log/apache2/error.log"}
		i.ModSecAuditLogPaths = []string{"/var/log/apache2/modsec_audit.log"}
	case ws == WSApache:
		i.AccessLogPaths = []string{"/var/log/httpd/access_log"}
		i.ErrorLogPaths = []string{"/var/log/httpd/error_log"}
		i.ModSecAuditLogPaths = []string{"/var/log/httpd/modsec_audit.log"}
	case ws == WSNginx:
		i.AccessLogPaths = []string{"/var/log/nginx/access.log"}
		i.ErrorLogPaths = []string{"/var/log/nginx/error.log"}
		i.ModSecAuditLogPaths = []string{"/var/log/nginx/modsec_audit.log"}
	case ws == WSLiteSpeed:
		i.AccessLogPaths = []string{"/usr/local/lsws/logs/access.log"}
		i.ErrorLogPaths = []string{"/usr/local/lsws/logs/error.log"}
		i.ModSecAuditLogPaths = []string{"/usr/local/lsws/logs/auditmodsec.log"}
	}

	// cPanel overlays its own access/error logs on top of the OS
	// defaults; its entries go first.
	if i.Panel == PanelCPanel {
		cpAccess := []string{
			"/usr/local/apache/logs/access_log",
			"/usr/local/cpanel/logs/access_log",
		}
		cpError := []string{
			"/usr/local/apache/logs/error_log",
		}
		cpAudit := []string{
			"/usr/local/apache/logs/modsec_audit.log",
			"/var/log/modsec_audit.log",
		}
		i.AccessLogPaths = append(cpAccess, i.AccessLogPaths...)
		i.ErrorLogPaths = append(cpError, i.ErrorLogPaths...)
		i.ModSecAuditLogPaths = append(cpAudit, i.ModSecAuditLogPaths...)
	}

	populateDomlogGlobs(i)
}
// populateDomlogGlobs fills i.DomlogGlobs with per-vhost access-log glob
// patterns. A recognised panel decides the layout on its own; only when no
// panel matches do the web server and OS family pick the patterns. For
// combinations not listed (for example LiteSpeed without a panel, or Apache
// on an OS that is neither Debian- nor RHEL-family) DomlogGlobs is left
// untouched.
func populateDomlogGlobs(i *Info) {
	// Panel-owned layouts win regardless of the web server underneath.
	switch i.Panel {
	case PanelCPanel:
		i.DomlogGlobs = []string{
			"/home/*/access-logs/*-ssl_log",
			"/home/*/access-logs/*_log",
		}
		return
	case PanelPlesk:
		i.DomlogGlobs = []string{
			"/var/www/vhosts/*/logs/access_ssl_log",
			"/var/www/vhosts/*/logs/access_log",
			"/var/www/vhosts/*/logs/proxy_access_ssl_log",
		}
		return
	case PanelDA:
		i.DomlogGlobs = []string{"/var/log/httpd/domains/*.log"}
		return
	}

	// No panel: go by web server, and for Apache by OS family.
	if i.WebServer == WSNginx {
		i.DomlogGlobs = []string{
			"/var/log/nginx/*.access.log",
			"/var/log/nginx/*-access.log",
		}
		return
	}
	if i.WebServer != WSApache {
		return
	}
	if i.IsDebianFamily() {
		i.DomlogGlobs = []string{
			"/var/log/apache2/*-access.log",
			"/var/log/apache2/*_access.log",
		}
	} else if i.IsRHELFamily() {
		i.DomlogGlobs = []string{
			"/var/log/httpd/*-access_log",
			"/var/log/httpd/*_access_log",
		}
	}
}
// dirExists reports whether p can be stat'ed and is a directory. Any stat
// error (missing path, permission denied, ...) yields false.
func dirExists(p string) bool {
	info, statErr := os.Stat(p)
	if statErr != nil {
		return false
	}
	return info.IsDir()
}
package processctx
import (
"container/list"
"sync"
"sync/atomic"
"time"
)
// Cache is a bounded LRU cache of processEntry keyed by PID with a per-entry
// TTL. Reads and writes are safe for concurrent use. now is a seam for
// deterministic testing.
type Cache struct {
	mu        sync.Mutex            // held by Put, Get, Len and Stats while touching ll/index
	cap       int                   // hard entry limit; Put evicts from the back once exceeded
	ttl       time.Duration         // max age since lastTouch; Get skips the expiry check when <= 0
	ll        *list.List            // front = most-recently-used
	index     map[int]*list.Element // pid -> element holding *processEntry
	now       func() time.Time      // clock source; NewCache installs time.Now
	evictions atomic.Uint64         // incremented by evictOldestLocked
	ttlPurges atomic.Uint64         // incremented when Get removes an expired entry
	misses    atomic.Uint64         // incremented whenever Get returns no entry
}
// Stats is a point-in-time snapshot of cache counters as returned by
// Cache.Stats. It is a plain value; the fields are not updated after the
// snapshot is taken.
type Stats struct {
	Entries   int    // number of entries in the LRU list when the snapshot was taken
	Evictions uint64 // LRU evictions (cap exceeded)
	TTLPurges uint64 // entries dropped because ttl expired on Get
	Misses    uint64 // Get returned no entry (includes ttl purges)
}
// NewCache builds an empty cache holding at most cap entries, each valid for
// ttl after its last touch. A cap below 1 is raised to 1. The clock is
// time.Now.
func NewCache(cap int, ttl time.Duration) *Cache {
	size := cap
	if size < 1 {
		size = 1
	}
	c := &Cache{cap: size, ttl: ttl, now: time.Now}
	c.ll = list.New()
	c.index = make(map[int]*list.Element, size)
	return c
}
// Put stores e under e.PID, stamping lastTouch with the cache clock. An
// existing entry for the same PID is replaced and promoted to the front;
// a new entry is pushed to the front and the least-recently-used entries are
// evicted until the cache is back within its cap.
func (c *Cache) Put(e processEntry) {
	c.mu.Lock()
	defer c.mu.Unlock()

	e.lastTouch = c.now()
	stored := &e

	if existing, found := c.index[e.PID]; found {
		existing.Value = stored
		c.ll.MoveToFront(existing)
		return
	}

	c.index[e.PID] = c.ll.PushFront(stored)
	for c.ll.Len() > c.cap {
		c.evictOldestLocked()
	}
}
// Get looks up pid. A hit refreshes the entry's lastTouch, promotes it to
// most-recently-used and returns a copy. An unknown PID counts as a miss; an
// entry older than the TTL (when the TTL is positive) is removed and counted
// as both a TTL purge and a miss.
func (c *Cache) Get(pid int) (processEntry, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	el, found := c.index[pid]
	if !found {
		c.misses.Add(1)
		return processEntry{}, false
	}

	cached := el.Value.(*processEntry)
	expired := c.ttl > 0 && c.now().Sub(cached.lastTouch) > c.ttl
	if expired {
		c.removeLocked(el)
		c.ttlPurges.Add(1)
		c.misses.Add(1)
		return processEntry{}, false
	}

	cached.lastTouch = c.now()
	c.ll.MoveToFront(el)
	return *cached, true
}
// Len reports how many entries are currently stored. Expired entries that
// have not yet been purged by a Get are included.
func (c *Cache) Len() int {
	c.mu.Lock()
	n := c.ll.Len()
	c.mu.Unlock()
	return n
}
// Stats returns a snapshot of the entry count and the eviction, TTL-purge
// and miss counters. The entry count is read under the lock; the counters
// are atomic and read after it is released.
func (c *Cache) Stats() Stats {
	var snap Stats

	c.mu.Lock()
	snap.Entries = c.ll.Len()
	c.mu.Unlock()

	snap.Evictions = c.evictions.Load()
	snap.TTLPurges = c.ttlPurges.Load()
	snap.Misses = c.misses.Load()
	return snap
}
// evictOldestLocked drops the least-recently-used entry, if any, and counts
// the eviction. The caller must hold c.mu.
func (c *Cache) evictOldestLocked() {
	if oldest := c.ll.Back(); oldest != nil {
		c.removeLocked(oldest)
		c.evictions.Add(1)
	}
}
// removeLocked unlinks el from both the LRU list and the PID index. The
// caller must hold c.mu.
func (c *Cache) removeLocked(el *list.Element) {
	pid := el.Value.(*processEntry).PID
	c.ll.Remove(el)
	delete(c.index, pid)
}
// PutFromExec caches the PID/PPID/UID/comm/exe carried by an exec event when
// no process start time is available. It is PutFromExecStartedAt with a zero
// start time, so UIDKnown is set even for UID 0.
func (c *Cache) PutFromExec(pid, ppid, uid int, comm, exe string) {
	var unknownStart time.Time
	c.PutFromExecStartedAt(pid, ppid, uid, comm, exe, unknownStart)
}
// PutFromExecStartedAt caches an exec-event entry together with the
// detector-supplied process start time, which later serves PID-reuse
// validation. The entry is marked UIDKnown and is not marked ProcRead.
func (c *Cache) PutFromExecStartedAt(pid, ppid, uid int, comm, exe string, startedAt time.Time) {
	entry := processEntry{
		PID:       pid,
		PPID:      ppid,
		UID:       uid,
		UIDKnown:  true,
		Comm:      comm,
		Exe:       exe,
		StartedAt: startedAt,
	}
	c.Put(entry)
}
// PutFromProc caches a fully populated /proc-style entry without a known
// start time. It exists so tests and non-daemon callers can insert entries
// without access to the unexported processEntry type; it is
// PutFromProcStartedAt with a zero start time.
func (c *Cache) PutFromProc(pid, ppid, uid int, user, account, comm, exe string, cmdline []string) {
	var unknownStart time.Time
	c.PutFromProcStartedAt(pid, ppid, uid, user, account, comm, exe, cmdline, unknownStart)
}
// PutFromProcStartedAt caches a fully populated /proc-style entry with a
// known process start time. The entry is marked both UIDKnown and ProcRead.
// The cmdline slice is stored as passed, not copied.
func (c *Cache) PutFromProcStartedAt(pid, ppid, uid int, user, account, comm, exe string, cmdline []string, startedAt time.Time) {
	entry := processEntry{
		PID:       pid,
		PPID:      ppid,
		UID:       uid,
		UIDKnown:  true,
		User:      user,
		Account:   account,
		Comm:      comm,
		Exe:       exe,
		Cmdline:   cmdline,
		StartedAt: startedAt,
		ProcRead:  true,
	}
	c.Put(entry)
}
package processctx
import (
"errors"
"sync"
"sync/atomic"
"time"
)
// procReader is the slice of ProcReader the Enricher needs. Allows fakes.
type procReader interface {
	// Read returns the entry for pid. Enricher workers treat
	// ErrProcessGone as a silent miss and count any other error.
	Read(pid int) (processEntry, error)
}
// EnrichRequest is the immutable event snapshot queued off the ring-buffer
// path. UID/Comm/StartedAt are used to reject stale PID reuse before caching
// /proc data.
type EnrichRequest struct {
	PID       int       // process to enrich; Enqueue rejects values <= 0
	UID       int       // UID seen by the detector; a non-zero value is treated as known by shouldCache
	UIDKnown  bool      // set when UID is meaningful, which is the only way to express a known UID 0
	Comm      string    // command name seen by the detector; empty disables the comm comparison
	StartedAt time.Time // detector's process start time; zero disables the start-time comparison
}
// IdentityResolver maps a process UID to username/account metadata. It must be
// cache-only in the common path. The daemon implementation uses
// checks.LookupUser's cached /etc/passwd reader and simple local account
// inference; it must not call NSS, LDAP, whmapi1, network services, or any
// blocking account enumerator from the enricher worker.
//
// Implementations SHOULD return within ~1ms in the common case. If a future
// implementation needs a backing data source that can block, refresh it in a
// separate cache outside the worker and have Resolve return ("", "") on cache
// miss rather than stalling the enrichment queue.
type IdentityResolver interface {
	// Resolve returns the user and account for uid. It is called from
	// enricher worker goroutines, once per entry that is about to be cached.
	Resolve(uid int) (user, account string)
}
// noopResolver is the fallback IdentityResolver that NewEnricher installs
// when EnricherConfig.Resolver is nil.
type noopResolver struct{}

// Resolve always reports an empty user and an empty account.
func (noopResolver) Resolve(int) (string, string) { return "", "" }
// EnricherConfig sizes the worker pool and queue.
type EnricherConfig struct {
	Workers  int              // worker goroutines started by Start; NewEnricher uses 2 when <= 0
	QueueCap int              // request channel capacity; NewEnricher uses 1024 when <= 0
	Resolver IdentityResolver // optional; NewEnricher falls back to a no-op resolver when nil
}
// EnricherStats is a snapshot of enricher counters.
type EnricherStats struct {
	Enqueued uint64 // requests accepted into the queue
	Drops    uint64 // requests rejected by Enqueue plus queued requests evicted to make room
	Reads    uint64 // /proc reads attempted by workers, successful or not
	Errors   uint64 // failed reads, excluding ErrProcessGone
	Stale    uint64 // successful reads rejected by shouldCache
}
// Enricher consumes PIDs and populates Cache from ProcReader.Read off the
// hot path. Enqueue is nonblocking. On overflow it drops the oldest queued
// request and records that drop, so the producer keeps moving and the queue
// favors fresher process snapshots.
type Enricher struct {
	cache          *Cache             // destination for entries that pass shouldCache
	reader         procReader         // /proc source used by workers
	resolver       IdentityResolver   // never nil after NewEnricher
	cfg            EnricherConfig     // Workers/QueueCap after defaults were applied
	queue          chan EnrichRequest // buffered with cfg.QueueCap slots
	wg             sync.WaitGroup     // tracks worker goroutines so Stop can wait for them
	stopCh         chan struct{}      // closed exactly once by Stop
	started        atomic.Bool        // makes Start idempotent
	stopped        atomic.Bool        // set by Stop; checked up front by Enqueue
	stopOnce       sync.Once          // makes Stop idempotent
	enqueued       atomic.Uint64      // see EnricherStats.Enqueued
	drops          atomic.Uint64      // see EnricherStats.Drops
	reads          atomic.Uint64      // see EnricherStats.Reads
	errors         atomic.Uint64      // see EnricherStats.Errors
	stale          atomic.Uint64      // see EnricherStats.Stale
	latencyMu      sync.RWMutex       // guards observeLatency
	observeLatency func(float64)      // optional callback receiving read latency in seconds
}
// NewEnricher builds an Enricher that has not been started yet; call Start
// to launch its workers. Non-positive Workers and QueueCap default to 2 and
// 1024, and a nil Resolver is replaced with a resolver that returns empty
// strings.
func NewEnricher(cache *Cache, reader procReader, cfg EnricherConfig) *Enricher {
	if cfg.Workers < 1 {
		cfg.Workers = 2
	}
	if cfg.QueueCap < 1 {
		cfg.QueueCap = 1024
	}

	var resolver IdentityResolver = noopResolver{}
	if cfg.Resolver != nil {
		resolver = cfg.Resolver
	}

	enr := &Enricher{
		cache:    cache,
		reader:   reader,
		resolver: resolver,
		cfg:      cfg,
	}
	enr.queue = make(chan EnrichRequest, cfg.QueueCap)
	enr.stopCh = make(chan struct{})
	return enr
}
// Start launches cfg.Workers worker goroutines. Only the first call has any
// effect; later calls return immediately.
func (e *Enricher) Start() {
	alreadyStarted := e.started.Swap(true)
	if alreadyStarted {
		return
	}
	for remaining := e.cfg.Workers; remaining > 0; remaining-- {
		e.wg.Add(1)
		go e.worker()
	}
}
// Stop marks the enricher stopped, closes stopCh and blocks until every
// worker has returned. Only the first call does the work; a later call
// returns once that first shutdown has finished.
//
// The queue is not drained: requests still buffered when stopCh closes are
// abandoned. After shutdown Stats.Enqueued can therefore exceed the number of
// requests workers actually picked up (Stats.Reads) by up to the queue depth.
// That is acceptable on daemon shutdown because no caller waits for
// completion; operators comparing the counters should expect the gap.
func (e *Enricher) Stop() {
	shutdown := func() {
		e.stopped.Store(true)
		close(e.stopCh)
		e.wg.Wait()
	}
	e.stopOnce.Do(shutdown)
}
// Enqueue adds a request to the work queue without blocking. If the queue is
// full, the oldest pending request is dropped (and counted) and the new one
// is queued in its place.
//
// It returns true when req was queued. It returns false, and counts a drop,
// when req.PID is not positive, when the enricher is stopped, or when the
// queue is still full after one eviction attempt (another producer refilled
// the freed slot first).
func (e *Enricher) Enqueue(req EnrichRequest) bool {
	if req.PID <= 0 || e.stopped.Load() {
		e.drops.Add(1)
		return false
	}
	// Fast path: room in the queue.
	select {
	case e.queue <- req:
		e.enqueued.Add(1)
		return true
	case <-e.stopCh:
		e.drops.Add(1)
		return false
	default:
		// Queue full: evict one pending request, if any is still there
		// (a worker may have emptied a slot in the meantime).
		select {
		case <-e.queue:
			e.drops.Add(1)
		default:
		}
		// Single retry; never block the producer.
		select {
		case e.queue <- req:
			e.enqueued.Add(1)
			return true
		case <-e.stopCh:
			e.drops.Add(1)
			return false
		default:
			e.drops.Add(1)
			return false
		}
	}
}
// SetLatencyObserver installs (or, with nil, removes) the callback that
// receives each worker read's latency in seconds. Used by metrics.
func (e *Enricher) SetLatencyObserver(fn func(float64)) {
	e.latencyMu.Lock()
	e.observeLatency = fn
	e.latencyMu.Unlock()
}
// observe forwards seconds to the installed latency observer, if there is
// one. The callback runs after the read lock has been released.
func (e *Enricher) observe(seconds float64) {
	e.latencyMu.RLock()
	cb := e.observeLatency
	e.latencyMu.RUnlock()

	if cb == nil {
		return
	}
	cb(seconds)
}
// shouldCache decides whether entry, freshly read from /proc, still
// describes the process the detector saw in req. It rejects the entry when
// /proc yielded no UID, when the request carries a UID (UIDKnown, or any
// non-zero UID) that differs, when the request names a comm that differs, or
// when the start times do not match.
func (e *Enricher) shouldCache(req EnrichRequest, entry processEntry) bool {
	// An entry without a UID is never cached, whether or not the request
	// carried one.
	if !entry.UIDKnown {
		return false
	}
	if (req.UIDKnown || req.UID != 0) && req.UID != entry.UID {
		return false
	}
	if req.Comm != "" && req.Comm != entry.Comm {
		return false
	}
	// PID-reuse guard: a detector-supplied start time must be matched by the
	// /proc-derived one within tolerance. A missing /proc value fails the
	// check because the same-process claim cannot be proven.
	return processStartMatches(req.StartedAt, entry.StartedAt)
}
// processStartTimeTolerance is the largest difference allowed between a
// detector's process-start snapshot and the start time derived from
// /proc/<pid>/stat. Five seconds covers clock granularity and slow pickup
// without letting a PID-reuse race slip past.
const processStartTimeTolerance = 5 * time.Second

// processStartMatches reports whether got is an acceptable match for want.
// A zero want means the detector captured no start time, so anything
// matches. A zero got with a non-zero want never matches. Otherwise the two
// must lie within processStartTimeTolerance of each other, in either
// direction.
func processStartMatches(want, got time.Time) bool {
	switch {
	case want.IsZero():
		return true
	case got.IsZero():
		return false
	}
	gap := want.Sub(got)
	if gap < 0 {
		gap = -gap
	}
	return gap <= processStartTimeTolerance
}
// enrichIdentity fills entry.User and entry.Account from the resolver. An
// entry whose UID is unknown is left untouched.
func (e *Enricher) enrichIdentity(entry *processEntry) {
	if entry.UIDKnown {
		entry.User, entry.Account = e.resolver.Resolve(entry.UID)
	}
}
// Stats returns a snapshot of the enricher's atomic counters. Each counter
// is loaded on its own, so the snapshot is not taken at one instant.
func (e *Enricher) Stats() EnricherStats {
	var snap EnricherStats
	snap.Enqueued = e.enqueued.Load()
	snap.Drops = e.drops.Load()
	snap.Reads = e.reads.Load()
	snap.Errors = e.errors.Load()
	snap.Stale = e.stale.Load()
	return snap
}
// worker is the body of one enrichment goroutine. It pulls requests off the
// queue until stopCh is closed. For each request it reads /proc, reports the
// read latency, and caches the result only if it still matches the request.
func (e *Enricher) worker() {
	defer e.wg.Done()
	for {
		var req EnrichRequest
		select {
		case <-e.stopCh:
			return
		case req = <-e.queue:
		}

		began := time.Now()
		e.reads.Add(1)
		entry, err := e.reader.Read(req.PID)
		e.observe(time.Since(began).Seconds())

		switch {
		case err == nil:
			// Read succeeded; validate below.
		case errors.Is(err, ErrProcessGone):
			// Short-lived process: a soft miss, not counted as an error.
			continue
		default:
			e.errors.Add(1)
			continue
		}

		if !e.shouldCache(req, entry) {
			e.stale.Add(1)
			continue
		}
		e.enrichIdentity(&entry)
		e.cache.Put(entry)
	}
}
package processctx
import "time"
// Materialize builds a ProcessContext for pid and links its cached ancestors
// by following PPIDs, up to MaxParentDepth levels. It returns nil when pid
// has no live cache entry. The walk is cycle-safe.
func (c *Cache) Materialize(pid int) *ProcessContext {
	if root, found := c.Get(pid); found {
		return c.materializeFromRoot(root)
	}
	return nil
}
// MaterializeVerified is Materialize guarded by a snapshot check: the cached
// root must agree with the event's uid (when uidKnown) and comm (when
// non-empty), otherwise (nil, false) is returned. On success the bool is
// true when the root entry has not yet been populated from /proc and so
// still needs an off-path read.
func (c *Cache) MaterializeVerified(pid, uid int, uidKnown bool, comm string) (*ProcessContext, bool) {
	root, found := c.Get(pid)
	if !found || !matchesSnapshot(root, uid, uidKnown, comm) {
		return nil, false
	}
	needsProcRead := !root.ProcRead
	return c.materializeFromRoot(root), needsProcRead
}
// MaterializeVerifiedSnapshot is MaterializeVerified driven by a full
// EnrichRequest: besides UID and comm, the cached root's start time must
// match the request's when one was captured. It returns (nil, false) on any
// mismatch or cache miss; on success the bool is true when the root entry
// has not yet been populated from /proc.
func (c *Cache) MaterializeVerifiedSnapshot(req EnrichRequest) (*ProcessContext, bool) {
	root, found := c.Get(req.PID)
	if !found ||
		!matchesSnapshot(root, req.UID, req.UIDKnown, req.Comm) ||
		!processStartMatches(req.StartedAt, root.StartedAt) {
		return nil, false
	}
	needsProcRead := !root.ProcRead
	return c.materializeFromRoot(root), needsProcRead
}
// materializeFromRoot converts root into a ProcessContext and chains cached
// ancestors onto it through Parent. The walk ends at the first PPID that is
// non-positive, already visited, or absent from the cache, and the chain
// never holds more than MaxParentDepth contexts including the root. Each
// ancestor lookup goes through Get, so it refreshes that entry and counts
// toward the miss statistics.
func (c *Cache) materializeFromRoot(root processEntry) *ProcessContext {
	head := toContext(root)
	seen := map[int]bool{root.PID: true}

	tail := head
	next := root.PPID
	for hops := 1; hops < MaxParentDepth && next > 0 && !seen[next]; hops++ {
		parent, found := c.Get(next)
		if !found {
			break
		}
		seen[parent.PID] = true
		tail.Parent = toContext(parent)
		tail = tail.Parent
		next = parent.PPID
	}
	return head
}
// matchesSnapshot reports whether cached entry e is consistent with an event
// snapshot. The UID is compared only when uidKnown is set, and then e must
// also have a known UID. The comm is compared only when it is non-empty.
func matchesSnapshot(e processEntry, uid int, uidKnown bool, comm string) bool {
	if uidKnown && (!e.UIDKnown || e.UID != uid) {
		return false
	}
	return comm == "" || e.Comm == comm
}
// toContext converts a cache entry into a standalone ProcessContext. Cmdline
// is copied so the result does not alias the cached slice, and StartedAt is
// nil when the entry has no start time. Parent is left unset.
func toContext(e processEntry) *ProcessContext {
	var startedAt *time.Time
	if !e.StartedAt.IsZero() {
		started := e.StartedAt
		startedAt = &started
	}

	ctx := &ProcessContext{
		PID:     e.PID,
		PPID:    e.PPID,
		UID:     e.UID,
		User:    e.User,
		Account: e.Account,
		Comm:    e.Comm,
		Exe:     e.Exe,
	}
	ctx.Cmdline = append([]string(nil), e.Cmdline...)
	ctx.StartedAt = startedAt
	return ctx
}
package processctx
import "github.com/pidginhost/csm/internal/metrics"
// RegisterMetrics binds cache and enricher counters/gauges to reg. The
// registry argument allows tests to use a private registry. Production
// callers should pass metrics.Default().
//
// Every gauge/counter is a callback that takes a fresh Stats snapshot on each
// invocation. The function also creates the enrichment latency histogram,
// registers it with MustRegister, and installs its Observe method as enr's
// latency observer, replacing any observer set earlier.
func RegisterMetrics(reg *metrics.Registry, cache *Cache, enr *Enricher) {
	// Cache metrics.
	reg.RegisterGaugeFunc(
		"csm_process_context_cache_entries",
		"Live process-context cache entries.",
		func() float64 { return float64(cache.Stats().Entries) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_cache_evictions_total",
		"Process-context cache LRU evictions (cap exceeded).",
		func() float64 { return float64(cache.Stats().Evictions) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_cache_ttl_purges_total",
		"Process-context cache entries dropped because TTL expired on lookup.",
		func() float64 { return float64(cache.Stats().TTLPurges) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_cache_misses_total",
		"Process-context cache lookup misses (includes TTL purges).",
		func() float64 { return float64(cache.Stats().Misses) },
	)
	// Enricher metrics. The drops counter is also incremented when Enqueue
	// rejects a request for a non-positive PID or because the enricher is
	// stopped, not only on queue overflow.
	reg.RegisterCounterFunc(
		"csm_process_context_enrich_queue_drops_total",
		"Process-context enrichment requests dropped because queue was full.",
		func() float64 { return float64(enr.Stats().Drops) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_enrich_reads_total",
		"Process-context /proc reads attempted by enricher workers.",
		func() float64 { return float64(enr.Stats().Reads) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_enrich_errors_total",
		"Process-context /proc read errors (excluding ProcessGone).",
		func() float64 { return float64(enr.Stats().Errors) },
	)
	reg.RegisterCounterFunc(
		"csm_process_context_enrich_stale_total",
		"Process-context enrichment results rejected as stale PID reuse.",
		func() float64 { return float64(enr.Stats().Stale) },
	)
	// Latency histogram with bucket bounds from 1ms to 1s; fed by the
	// enricher's workers through the observer installed below.
	latency := metrics.NewHistogram(
		"csm_process_context_enrich_latency_seconds",
		"Process-context enrichment worker latency in seconds.",
		[]float64{0.001, 0.005, 0.01, 0.05, 0.1, 1},
	)
	reg.MustRegister("csm_process_context_enrich_latency_seconds", latency)
	enr.SetLatencyObserver(latency.Observe)
}
package processctx
import (
"bytes"
"errors"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
// ErrProcessGone is returned by ProcReader.Read when the /proc/<pid>
// directory no longer exists. Callers must treat this as a soft miss, not an
// error finding: short-lived processes are expected. Compare with errors.Is.
var ErrProcessGone = errors.New("process gone")
// ProcReader reads /proc/<pid>/{status,cmdline,exe,stat} with a per-file
// deadline.
type ProcReader struct {
	root            string        // proc mount point; per-PID directories are resolved beneath it
	perFileDeadline time.Duration // bound for each individual read; <= 0 makes reads run inline with no deadline
}
// NewProcReader returns a reader that resolves PIDs under procRoot ("/proc"
// in production, a temporary directory in tests). perFileDeadline bounds
// each individual file read.
func NewProcReader(procRoot string, perFileDeadline time.Duration) *ProcReader {
	r := new(ProcReader)
	r.root = procRoot
	r.perFileDeadline = perFileDeadline
	return r
}
// Read assembles a processEntry for pid from status, cmdline, exe and stat
// under /proc/<pid>. Each source is best-effort: a file that cannot be read
// within the deadline leaves its fields at their zero value and the entry is
// still returned. ErrProcessGone is returned when the /proc/<pid> directory
// does not exist; any other stat failure is returned unchanged.
func (r *ProcReader) Read(pid int) (processEntry, error) {
	dir := filepath.Join(r.root, strconv.Itoa(pid))

	_, statErr := os.Stat(dir)
	switch {
	case statErr == nil:
		// Directory is present; carry on.
	case errors.Is(statErr, fs.ErrNotExist):
		return processEntry{}, ErrProcessGone
	default:
		return processEntry{}, statErr
	}

	entry := processEntry{PID: pid, ProcRead: true}

	if raw, ok := readFileWithDeadline(filepath.Join(dir, "status"), r.perFileDeadline); ok {
		status := string(raw)
		entry.PPID = parseStatusPPID(status)
		entry.UID, entry.UIDKnown = parseStatusUIDKnown(status)
		entry.Comm = parseStatusName(status)
	}
	if raw, ok := readFileWithDeadline(filepath.Join(dir, "cmdline"), r.perFileDeadline); ok {
		entry.Cmdline = parseCmdline(raw)
	}
	if exe, ok := readlinkWithDeadline(filepath.Join(dir, "exe"), r.perFileDeadline); ok {
		entry.Exe = exe
	}
	if raw, ok := readFileWithDeadline(filepath.Join(dir, "stat"), r.perFileDeadline); ok {
		if started, parsed := r.parseStartedAt(raw); parsed {
			entry.StartedAt = started
		}
	}
	return entry, nil
}
// ReadStartedAt reads nothing but /proc/<pid>/stat and returns the process
// start time parsed from it. Detector hot paths use it to capture a cheap
// PID-reuse token without touching status, cmdline, exe or identity data.
// The bool is false for a non-positive pid, an unreadable file or an
// unparsable start time.
func (r *ProcReader) ReadStartedAt(pid int) (time.Time, bool) {
	if pid <= 0 {
		return time.Time{}, false
	}
	statPath := filepath.Join(r.root, strconv.Itoa(pid), "stat")
	if raw, ok := readFileWithDeadline(statPath, r.perFileDeadline); ok {
		return r.parseStartedAt(raw)
	}
	return time.Time{}, false
}
// procStatStartTime returns field 22 (starttime, in clock ticks since boot)
// of a /proc/<pid>/stat line. The comm field is parenthesised and may itself
// contain spaces and parentheses, so parsing anchors on the last ')' and
// counts whitespace-separated fields from there. The bool is false when the
// line has no ')', too few fields, or a starttime that is not a non-negative
// integer.
func procStatStartTime(data []byte) (int64, bool) {
	closeParen := bytes.LastIndexByte(data, ')')
	if closeParen < 0 || closeParen+1 >= len(data) {
		return 0, false
	}

	// The first field after ')' is field 3 (state), so field 22 sits at
	// index 19 of the remainder.
	const starttimeIdx = 19
	tail := strings.Fields(string(data[closeParen+1:]))
	if len(tail) <= starttimeIdx {
		return 0, false
	}

	ticks, err := strconv.ParseInt(tail[starttimeIdx], 10, 64)
	if err != nil || ticks < 0 {
		return 0, false
	}
	return ticks, true
}
// parseStartedAt turns the starttime field of a /proc/<pid>/stat line into
// an absolute time: host boot time plus ticks divided by the clock tick
// rate. It returns (zero, false) when the field cannot be parsed, the boot
// time is unavailable, or the tick rate is not positive, so callers leave
// the field unset instead of emitting a bogus timestamp.
func (r *ProcReader) parseStartedAt(stat []byte) (time.Time, bool) {
	var none time.Time

	ticks, ok := procStatStartTime(stat)
	if !ok {
		return none, false
	}
	boot, ok := r.bootTime()
	if !ok {
		return none, false
	}
	hz := clockTicksPerSecond()
	if hz <= 0 {
		return none, false
	}

	// Whole seconds and the sub-second remainder are converted separately.
	whole := time.Duration(ticks/hz) * time.Second
	frac := time.Duration((ticks % hz) * int64(time.Second) / hz)
	return boot.Add(whole + frac), true
}
// readFileWithDeadline returns at most the first 4 KiB of path; it yields
// (data, true) on success or (nil, false) on any error or deadline expiry
// (with d <= 0 a failed read can return partial data alongside false).
// Note that the whole file is read with os.ReadFile and only then truncated,
// so the 4 KiB cap bounds what is returned, not what is read. /proc files are
// small; using ReadFile keeps the normal path simple. The generic deadline
// helper is tested with an injected slow function instead of a FIFO because
// general filesystem opens cannot be cancelled safely on every platform.
func readFileWithDeadline(path string, d time.Duration) ([]byte, bool) {
	return runBytesWithDeadline(d, func() ([]byte, error) {
		// #nosec G304 -- path is constructed from ProcReader.root + numeric PID;
		// callers only pass procfs entries under r.root.
		data, err := os.ReadFile(path)
		// Truncate even when err is non-nil; the deadline path of
		// runBytesWithDeadline discards the data on error anyway.
		if len(data) > 4096 {
			data = data[:4096]
		}
		return data, err
	})
}
// procReadConcurrency is the ceiling on deadline-bound /proc reads in flight
// at once. Go cannot cancel a goroutine blocked in a syscall, so a wedged
// /proc entry (NFS-backed, D-state) keeps its goroutine until the kernel
// returns, possibly forever. The ceiling turns what would be an unbounded
// leak under PID-reuse churn into a fixed one: once it is reached, further
// reads fail fast instead of spawning more goroutines that may be abandoned.
// A slot is released only when the goroutine's syscall finally returns, so
// stuck reads keep counting against the ceiling.
const procReadConcurrency = 64

// procReadSem is the counting semaphore behind procReadConcurrency; each
// buffered element is one occupied slot.
var procReadSem = make(chan struct{}, procReadConcurrency)

// acquireProcReadSlot tries to take a slot without blocking and reports
// whether it got one.
func acquireProcReadSlot() bool {
	acquired := false
	select {
	case procReadSem <- struct{}{}:
		acquired = true
	default:
	}
	return acquired
}

// releaseProcReadSlot gives back a slot taken by acquireProcReadSlot.
func releaseProcReadSlot() {
	<-procReadSem
}
// runBytesWithDeadline runs fn and returns its bytes with true on success.
//
// With d <= 0, fn runs inline with no deadline and no concurrency slot, and
// whatever fn returned is passed through alongside err == nil. Otherwise fn
// runs in its own goroutine holding a proc-read slot; (nil, false) is
// returned if no slot is free, if fn fails, or if d elapses first. After a
// timeout the goroutine is left running and releases its slot once fn
// returns; the result channel is buffered so that late send never blocks.
func runBytesWithDeadline(d time.Duration, fn func() ([]byte, error)) ([]byte, bool) {
	if d <= 0 {
		out, err := fn()
		return out, err == nil
	}
	if !acquireProcReadSlot() {
		return nil, false
	}

	type outcome struct {
		data []byte
		err  error
	}
	done := make(chan outcome, 1)
	go func() {
		defer releaseProcReadSlot()
		var o outcome
		o.data, o.err = fn()
		done <- o
	}()

	deadline := time.NewTimer(d)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		return nil, false
	case o := <-done:
		if o.err == nil {
			return o.data, true
		}
		return nil, false
	}
}
// readlinkWithDeadline resolves the symlink at path, giving up after d.
//
// With d <= 0 the readlink runs inline with no deadline. Otherwise it runs
// in a goroutine that holds a proc-read slot; ("", false) is returned if no
// slot is free, if the readlink fails, or if d elapses first. After a
// timeout the goroutine keeps its slot until the syscall returns.
func readlinkWithDeadline(path string, d time.Duration) (string, bool) {
	if d <= 0 {
		dest, err := os.Readlink(path)
		return dest, err == nil
	}
	if !acquireProcReadSlot() {
		return "", false
	}

	type outcome struct {
		target string
		err    error
	}
	done := make(chan outcome, 1)
	go func() {
		defer releaseProcReadSlot()
		var o outcome
		o.target, o.err = os.Readlink(path)
		done <- o
	}()

	deadline := time.NewTimer(d)
	defer deadline.Stop()

	select {
	case <-deadline.C:
		return "", false
	case o := <-done:
		if o.err == nil {
			return o.target, true
		}
		return "", false
	}
}
// parseStatusName returns the whitespace-trimmed value of the first line in
// s that starts with "Name:" followed by a tab, or "" when there is none.
func parseStatusName(s string) string {
	const key = "Name:\t"
	for _, line := range strings.Split(s, "\n") {
		if !strings.HasPrefix(line, key) {
			continue
		}
		return strings.TrimSpace(line[len(key):])
	}
	return ""
}
// parseStatusPPID returns the integer value of the first line in s that
// starts with "PPid:" followed by a tab. It returns 0 when there is no such
// line; when the value does not parse, whatever strconv.Atoi produced
// alongside the error is returned (0 for malformed text).
func parseStatusPPID(s string) int {
	const key = "PPid:\t"
	for _, line := range strings.Split(s, "\n") {
		if !strings.HasPrefix(line, key) {
			continue
		}
		// The parse error is deliberately ignored.
		ppid, _ := strconv.Atoi(strings.TrimSpace(line[len(key):]))
		return ppid
	}
	return 0
}
// parseStatusUID returns the first value on the "Uid:" line of s, discarding
// the flag that says whether it was found and parsed. Use
// parseStatusUIDKnown when a missing UID must be told apart from UID 0.
func parseStatusUID(s string) int {
	uid, _ := parseStatusUIDKnown(s)
	return uid
}
// parseStatusUIDKnown parses the first whitespace-separated value on the
// first line in s that starts with "Uid:" followed by a tab. The bool is true
// only when such a line exists, has at least one value, and that value is a
// valid integer, which lets callers distinguish UID 0 from "unknown".
func parseStatusUIDKnown(s string) (int, bool) {
	for _, line := range strings.Split(s, "\n") {
		rest, found := strings.CutPrefix(line, "Uid:\t")
		if !found {
			continue
		}
		cols := strings.Fields(rest)
		if len(cols) == 0 {
			return 0, false
		}
		firstUID, err := strconv.Atoi(cols[0])
		return firstUID, err == nil
	}
	return 0, false
}
// parseCmdline splits the NUL-separated contents of /proc/<pid>/cmdline into
// arguments, discards empty ones, and passes the rest through
// sanitizeCmdline. It returns nil for empty input or input that holds only
// NUL bytes.
func parseCmdline(b []byte) []string {
	if len(b) == 0 {
		return nil
	}
	chunks := bytes.Split(b, []byte{0})
	args := make([]string, 0, len(chunks))
	for _, chunk := range chunks {
		if len(chunk) > 0 {
			args = append(args, string(chunk))
		}
	}
	if len(args) == 0 {
		return nil
	}
	return sanitizeCmdline(args)
}
// maxCmdlineArgLen caps the byte length of every sanitized argument.
const maxCmdlineArgLen = 256

// sensitiveCmdlineKeys lists the lowercase key names whose values are
// redacted from command lines. Matching is case-insensitive.
var sensitiveCmdlineKeys = []string{"password", "passwd", "secret", "token", "api_key", "apikey"}

// sanitizeCmdline returns a copy of args with credentials redacted and each
// argument capped at maxCmdlineArgLen bytes.
//
// Two forms are recognised, case-insensitively, for every entry in
// sensitiveCmdlineKeys:
//   - an argument containing "<key>=" is replaced with everything before its
//     first '=' followed by "=<redacted>";
//   - an argument that is exactly "--<key>" or "-<key>" is kept, and the
//     argument after it is replaced with "<redacted>".
//
// The result always has the same number of elements as args.
func sanitizeCmdline(args []string) []string {
	out := make([]string, 0, len(args))
	redactNext := false
	for _, arg := range args {
		if redactNext {
			// Value of a "--<key> <value>" pair.
			out = append(out, "<redacted>")
			redactNext = false
			continue
		}
		lower := strings.ToLower(arg)
		redacted := false
		for _, key := range sensitiveCmdlineKeys {
			if strings.Contains(lower, key+"=") {
				// Cut at the first '=' so nothing after it survives, even
				// when the key appears later in the argument.
				prefix, _, _ := strings.Cut(arg, "=")
				out = append(out, truncateCmdlineArg(prefix+"=<redacted>"))
				redacted = true
				break
			}
			if lower == "--"+key || lower == "-"+key {
				out = append(out, truncateCmdlineArg(arg))
				redactNext = true
				redacted = true
				break
			}
		}
		if redacted {
			continue
		}
		out = append(out, truncateCmdlineArg(arg))
	}
	return out
}

// truncateCmdlineArg caps arg at maxCmdlineArgLen bytes without splitting a
// UTF-8 encoded rune: when the cut would land inside a multi-byte sequence
// it is moved back to the start of that sequence, so the result can be up to
// three bytes shorter than the cap. Input that is not well-formed UTF-8 near
// the cut is truncated at exactly maxCmdlineArgLen bytes.
func truncateCmdlineArg(arg string) string {
	if len(arg) <= maxCmdlineArgLen {
		return arg
	}
	// A UTF-8 continuation byte has the bit pattern 10xxxxxx. If the first
	// dropped byte is one, the cut is inside a rune. A rune has at most
	// three continuation bytes, so at most three steps back are needed.
	isContinuation := func(b byte) bool { return b&0xC0 == 0x80 }
	cut := maxCmdlineArgLen
	for back := 0; back < 3 && cut > 0 && isContinuation(arg[cut]); back++ {
		cut--
	}
	if isContinuation(arg[cut]) {
		// More continuation bytes than any valid rune has: the data is not
		// UTF-8, so fall back to a plain byte cut.
		cut = maxCmdlineArgLen
	}
	return arg[:cut]
}
package processctx
import (
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// bootTimeCache caches the host boot time per proc root.
// Boot time is invariant for the life of a kernel; rereading /proc/stat
// on every Read() would only spend syscalls for an unchanging value.
// The lookup runs once per root, so a failed lookup is cached as well.
type bootTimeCache struct {
	once sync.Once // ensures the boot time is read at most once
	t    time.Time // boot time; meaningful only when ok is true
	ok   bool      // whether the one-off read succeeded
}
// procReaderBoot holds one bootTimeCache per proc root, shared by every
// ProcReader in the process.
var procReaderBoot sync.Map // root -> *bootTimeCache

// bootTime returns the boot time for r's proc root, reading <root>/stat the
// first time it is asked about that root and reusing the outcome, success or
// failure, from then on.
func (r *ProcReader) bootTime() (time.Time, bool) {
	slot, _ := procReaderBoot.LoadOrStore(r.root, new(bootTimeCache))
	cache := slot.(*bootTimeCache)
	cache.once.Do(func() {
		statPath := filepath.Join(r.root, "stat")
		cache.t, cache.ok = readBootTime(statPath)
	})
	return cache.t, cache.ok
}
// readBootTime extracts the boot time from a /proc/stat-style file: the
// first line starting with "btime " carries it as seconds since the Unix
// epoch. The bool is false when the file cannot be read, has no btime line,
// or the value is not an integer.
func readBootTime(statPath string) (time.Time, bool) {
	// #nosec G304 -- statPath is procReader.root + "stat"; root is operator-pinned.
	raw, err := os.ReadFile(statPath)
	if err != nil {
		return time.Time{}, false
	}

	const key = "btime "
	for _, line := range strings.Split(string(raw), "\n") {
		if !strings.HasPrefix(line, key) {
			continue
		}
		secs, parseErr := strconv.ParseInt(strings.TrimSpace(line[len(key):]), 10, 64)
		if parseErr != nil {
			return time.Time{}, false
		}
		return time.Unix(secs, 0), true
	}
	return time.Time{}, false
}
// clockTicksPerSecondOverride, when positive, replaces the built-in tick
// rate. It exists so tests can inject a known _SC_CLK_TCK without calling
// sysconf on the host.
var clockTicksPerSecondOverride int64

// clockTicksPerSecond returns the number of clock ticks per second used to
// convert /proc/<pid>/stat tick counts into durations: the override when it
// is positive, otherwise defaultClockTicksPerSecond.
func clockTicksPerSecond() int64 {
	if override := clockTicksPerSecondOverride; override > 0 {
		return override
	}
	return defaultClockTicksPerSecond
}

// defaultClockTicksPerSecond is the tick rate assumed when no override is
// set. It is hard-coded to 100 because reading sysconf would need cgo or a
// build-tagged platform shim, and 100 is the value on every distribution
// kernel CSM targets. Tests that want a different value set
// clockTicksPerSecondOverride.
// NOTE(review): the rate that matters here is the userspace tick
// (_SC_CLK_TCK / USER_HZ), which is distinct from the kernel's CONFIG_HZ;
// confirm against proc(5) before relying on this on an unusual platform.
const defaultClockTicksPerSecond = 100
// Package redisinfo wraps the go-redis client for the few read-only
// INFO calls CSM needs (memory metrics, keyspace counts). It replaces
// the `redis-cli info <section>` shell-outs the performance UI used
// to issue per-poll, eliminating libc/libpthread fork churn on hosts
// with a busy metrics dashboard.
//
// The client is a lazy package-level singleton: first call opens a
// connection-pooled client against the local redis socket / TCP, all
// subsequent calls reuse it. No Close path because the daemon runs
// for the host's lifetime; go-redis cleans up on process exit.
//
// Connection target matches redis-cli's default behaviour:
//
// 127.0.0.1:6379, no password, db 0
//
// When REDISCLI_AUTH is set, the client uses it as the password so
// daemon environments that previously made redis-cli work keep working.
// Hosts running redis on a non-default socket can override by calling
// SetAddr before the first MemoryUsage / Keyspace call. Absolute
// paths are treated as Unix sockets.
package redisinfo
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// defaultAddr is the connection target used unless SetAddr overrides it.
const defaultAddr = "127.0.0.1:6379"

// Singleton state; SetAddr, SetClientForTest and client access it under mu.
var (
	mu          sync.Mutex
	defaultDB   *redis.Client // nil until client() builds it or SetClientForTest installs one
	clientBuilt bool          // true once a client exists; makes SetAddr a no-op
	addr        = defaultAddr // target handed to redisOptions when the singleton is built
	password    = ""          // explicit password; empty lets redisOptions fall back to REDISCLI_AUTH
)
// SetAddr changes the connection target and password used when the
// singleton client is built. It has no effect once a client exists, whether
// built by the first MemoryUsage / Keyspace call or installed through
// SetClientForTest; tests that need a different client should swap the
// singleton with SetClientForTest.
func SetAddr(a, pwd string) {
	mu.Lock()
	defer mu.Unlock()
	if !clientBuilt {
		addr, password = a, pwd
	}
}
// SetClientForTest installs c as the singleton client, typically one
// pointed at miniredis or a real test instance. Passing nil removes the
// override, so the next call that needs a client builds the real singleton
// again and SetAddr works once more.
func SetClientForTest(c *redis.Client) {
	mu.Lock()
	defaultDB, clientBuilt = c, c != nil
	mu.Unlock()
}
// client returns the singleton redis client, building it from the current
// addr/password on first use. It never returns nil.
func client() *redis.Client {
	mu.Lock()
	defer mu.Unlock()
	if defaultDB != nil {
		return defaultDB
	}
	defaultDB = redis.NewClient(redisOptions(addr, password))
	clientBuilt = true
	return defaultDB
}
// redisOptions builds the go-redis options for target. A target starting
// with "/" is dialled as a Unix socket, anything else as TCP. An empty pwd
// falls back to the REDISCLI_AUTH environment variable. Database 0 is always
// used.
func redisOptions(target, pwd string) *redis.Options {
	// Fail fast when redis is absent: this client only serves the metrics
	// dashboard, never a hot path. Default MaxRetries=3 + ReadTimeout=3s
	// would block the perfMetrics sampler for >9s on a host without redis
	// (a normal config -- CSM does not require redis to run).
	const failFast = 500 * time.Millisecond

	opts := &redis.Options{
		Network:      "tcp",
		Addr:         target,
		DB:           0,
		DialTimeout:  failFast,
		ReadTimeout:  failFast,
		WriteTimeout: failFast,
		MaxRetries:   -1, // disable retries entirely
		PoolTimeout:  failFast,
		PoolSize:     2,
	}
	if strings.HasPrefix(target, "/") {
		opts.Network = "unix"
	}
	if pwd == "" {
		pwd = os.Getenv("REDISCLI_AUTH")
	}
	opts.Password = pwd
	return opts
}
// MemoryUsage reports redis's used_memory and maxmemory, in bytes, as found
// in `INFO memory`. A field the server omits comes back as zero. err is
// non-nil only on connection / protocol error.
//
// When a mock has been installed with SetMemoryUsageForTest it is called
// instead and redis is never contacted.
func MemoryUsage(ctx context.Context) (used, max uint64, err error) {
	if mock := getMemoryUsageMock(); mock != nil {
		return mock(ctx)
	}

	rdb := client()
	if rdb == nil {
		return 0, 0, fmt.Errorf("redisinfo: client not initialised")
	}

	info, infoErr := rdb.Info(ctx, "memory").Result()
	if infoErr != nil {
		return 0, 0, infoErr
	}
	used, max = parseMemoryInfo(info)
	return used, max, nil
}
// MemoryUsageFunc is the signature SetMemoryUsageForTest accepts. It mirrors
// MemoryUsage.
type MemoryUsageFunc func(ctx context.Context) (used, max uint64, err error)

// KeyspaceStatsFunc is the signature SetKeyspaceStatsForTest accepts.
type KeyspaceStatsFunc func(ctx context.Context) (KeyspaceStat, error)

// ConfigGetFunc is the signature SetConfigGetForTest accepts.
type ConfigGetFunc func(ctx context.Context, name string) (string, error)

// Test interceptors. The Set*ForTest functions write them under mockMu and
// the get*Mock helpers read them under its read lock; nil means "no mock".
var (
	mockMu     sync.RWMutex
	memMock    MemoryUsageFunc
	keyMock    KeyspaceStatsFunc
	configMock ConfigGetFunc
)
// SetMemoryUsageForTest replaces the MemoryUsage implementation with fn
// until it is cleared again by passing nil. Production code paths must NOT
// call this.
func SetMemoryUsageForTest(fn MemoryUsageFunc) {
	mockMu.Lock()
	memMock = fn
	mockMu.Unlock()
}

// SetKeyspaceStatsForTest replaces the KeyspaceStats implementation (and
// therefore Keyspace, which proxies to it) with fn. Pass nil to clear.
func SetKeyspaceStatsForTest(fn KeyspaceStatsFunc) {
	mockMu.Lock()
	keyMock = fn
	mockMu.Unlock()
}

// SetConfigGetForTest replaces the ConfigGet implementation with fn. Pass
// nil to clear.
func SetConfigGetForTest(fn ConfigGetFunc) {
	mockMu.Lock()
	configMock = fn
	mockMu.Unlock()
}
// getMemoryUsageMock returns the installed MemoryUsage hook, or nil when
// none is set. The read happens under the shared mock lock.
func getMemoryUsageMock() MemoryUsageFunc {
	mockMu.RLock()
	hook := memMock
	mockMu.RUnlock()
	return hook
}

// getKeyspaceStatsMock returns the installed KeyspaceStats hook, or nil when
// none is set.
func getKeyspaceStatsMock() KeyspaceStatsFunc {
	mockMu.RLock()
	hook := keyMock
	mockMu.RUnlock()
	return hook
}

// getConfigGetMock returns the installed ConfigGet hook, or nil when none is
// set.
func getConfigGetMock() ConfigGetFunc {
	mockMu.RLock()
	hook := configMock
	mockMu.RUnlock()
	return hook
}
// parseMemoryInfo extracts the used_memory and maxmemory values from the
// text of `INFO memory`. Only the exact field names match (used_memory_rss,
// maxmemory_human and friends are ignored). A missing or unparsable value
// yields zero for that field.
func parseMemoryInfo(raw string) (used, max uint64) {
	for _, row := range strings.Split(raw, "\n") {
		field, value, found := strings.Cut(strings.TrimSpace(row), ":")
		if !found {
			continue
		}
		switch field {
		case "used_memory":
			// Parse errors are deliberately ignored: the result is then zero.
			used, _ = strconv.ParseUint(strings.TrimSpace(value), 10, 64)
		case "maxmemory":
			max, _ = strconv.ParseUint(strings.TrimSpace(value), 10, 64)
		}
	}
	return used, max
}
// Keyspace returns the total key count, i.e. the sum of `keys=N` across
// every db<n> line in `INFO keyspace`. It proxies to KeyspaceStats and
// passes its error through unchanged.
func Keyspace(ctx context.Context) (int64, error) {
	if stats, err := KeyspaceStats(ctx); err == nil {
		return stats.TotalKeys, nil
	} else {
		return 0, err
	}
}
// KeyspaceStat is the aggregated breakdown of `INFO keyspace`, summed over
// every db<n> line.
type KeyspaceStat struct {
// TotalKeys is the sum of the `keys=N` values across all dbs.
TotalKeys int64
// TotalExpires is the sum of the `expires=N` values across all dbs, i.e.
// the subset of keys with a TTL applied.
TotalExpires int64
}
// KeyspaceStats returns the key and expires totals summed over every db line
// of `INFO keyspace`. err is non-nil on connection / protocol error, or when
// no client is available.
//
// Tests can intercept via SetKeyspaceStatsForTest.
func KeyspaceStats(ctx context.Context) (KeyspaceStat, error) {
	if mock := getKeyspaceStatsMock(); mock != nil {
		return mock(ctx)
	}
	rdb := client()
	if rdb == nil {
		return KeyspaceStat{}, fmt.Errorf("redisinfo: client not initialised")
	}
	info, err := rdb.Info(ctx, "keyspace").Result()
	if err != nil {
		return KeyspaceStat{}, err
	}
	stats := parseKeyspaceStats(info)
	return stats, nil
}
// ConfigGet returns the value of a single CONFIG GET parameter (e.g.
// "maxmemory", "save", "maxmemory-policy"). When the reply carries no entry
// for name the result is ("", nil), so callers can tell "unset" apart from
// a connection error.
//
// Tests can intercept via SetConfigGetForTest.
func ConfigGet(ctx context.Context, name string) (string, error) {
	if mock := getConfigGetMock(); mock != nil {
		return mock(ctx, name)
	}
	rdb := client()
	if rdb == nil {
		return "", fmt.Errorf("redisinfo: client not initialised")
	}
	reply, err := rdb.ConfigGet(ctx, name).Result()
	if err != nil {
		return "", err
	}
	// Indexing a map with an absent key yields "", which is exactly the
	// "unset" result.
	return reply[name], nil
}
// parseKeyspaceInfo returns only the total key count from the text of
// `INFO keyspace`.
func parseKeyspaceInfo(raw string) int64 {
	stats := parseKeyspaceStats(raw)
	return stats.TotalKeys
}
// parseKeyspaceStats sums the `keys` and `expires` counters over every line
// of raw that starts with "db" (e.g. "db0:keys=3,expires=1,avg_ttl=0").
// Lines without a colon, pairs without '=', and non-integer values are
// skipped; any other counter name is ignored.
func parseKeyspaceStats(raw string) KeyspaceStat {
	var totals KeyspaceStat
	for _, row := range strings.Split(raw, "\n") {
		row = strings.TrimSpace(row)
		if !strings.HasPrefix(row, "db") {
			continue
		}
		_, fields, hasFields := strings.Cut(row, ":")
		if !hasFields {
			continue
		}
		for _, pair := range strings.Split(fields, ",") {
			name, digits, hasValue := strings.Cut(strings.TrimSpace(pair), "=")
			if !hasValue {
				continue
			}
			n, err := strconv.ParseInt(digits, 10, 64)
			if err != nil {
				continue
			}
			if name == "keys" {
				totals.TotalKeys += n
			} else if name == "expires" {
				totals.TotalExpires += n
			}
		}
	}
	return totals
}
package reporting
import (
"context"
"sync/atomic"
)
// Action is the node's policy for acting on central scored-set data. It is
// deliberately conservative: central data never hard-blocks on its own.
type Action string

const (
	// ActionOff consumes the set for visibility only; never acts.
	ActionOff Action = "off"
	// ActionChallenge elevates suspicion / serves a challenge for listed IPs.
	ActionChallenge Action = "challenge"
	// ActionBlockIfLocalCorroborated hard-blocks only when this node also saw
	// abuse from the IP and the distributed score meets the threshold.
	ActionBlockIfLocalCorroborated Action = "block_if_local_corroborated"
)

// ParseAction converts a config string into an Action. Only the exact
// strings for ActionOff and ActionBlockIfLocalCorroborated select those
// policies; every other value, including the empty string, yields
// ActionChallenge (the safe default; never silently block).
func ParseAction(s string) Action {
	if a := Action(s); a == ActionOff || a == ActionBlockIfLocalCorroborated {
		return a
	}
	return ActionChallenge
}

// Decision is what the node should do about an IP given central data.
type Decision int

const (
	// DecisionIgnore takes no central-driven action.
	DecisionIgnore Decision = iota
	// DecisionChallenge elevates suspicion / serves a challenge.
	DecisionChallenge
	// DecisionBlock hard-blocks (only reachable with local corroboration).
	DecisionBlock
)

// DecisionInput is the per-IP context for a central-data decision.
type DecisionInput struct {
	Found               bool // IP present in the central scored-set
	Score               int  // distributed score 0-100
	Protected           bool // firebreak: infra/CF/crawler/RFC5737/allowlist
	LocallyCorroborated bool // this node independently observed abuse from the IP
}

// Decide returns the node action for an IP.
//
// Firebreaks always win: a protected IP, an IP absent from the central set,
// or the ActionOff policy all yield DecisionIgnore. DecisionBlock requires
// all three of: the ActionBlockIfLocalCorroborated policy, local
// corroboration, and a score of at least blockThreshold. Everything else is
// DecisionChallenge, so a central-only signal can at most challenge.
func Decide(in DecisionInput, action Action, blockThreshold int) Decision {
	switch {
	case in.Protected, !in.Found, action == ActionOff:
		return DecisionIgnore
	case action == ActionBlockIfLocalCorroborated && in.LocallyCorroborated && in.Score >= blockThreshold:
		return DecisionBlock
	default:
		return DecisionChallenge
	}
}
// centralState bundles the current snapshot and its derived lookup set so both
// are swapped in a single atomic store; two separate pointers could be read in
// a torn intermediate state.
type centralState struct {
// snapshot is the verified scored-set; it is the base handed to the Puller
// on the next refresh.
snapshot ScoredSnapshot
// set is the lookup index built from snapshot with NewSet.
set *Set
}
// CentralStore holds the current verified scored-set for concurrent lookups and
// is refreshed by a Puller. The state is swapped atomically so readers on the
// block/challenge path never block on a refresh and never see a torn snapshot.
type CentralStore struct {
// puller fetches and verifies updates for Refresh.
puller *Puller
// state is always non-nil after NewCentralStore (it starts as an empty
// snapshot), so Lookup and Version can dereference it unconditionally.
state atomic.Pointer[centralState]
}
// NewCentralStore builds an empty store backed by puller.
func NewCentralStore(puller *Puller) *CentralStore {
cs := &CentralStore{puller: puller}
empty := ScoredSnapshot{}
cs.state.Store(¢ralState{snapshot: empty, set: NewSet(empty)})
return cs
}
// Lookup returns the scored entry for ip from the state that is current at
// the moment of the call, and whether the IP is listed.
func (cs *CentralStore) Lookup(ip string) (ScoredEntry, bool) {
	current := cs.state.Load()
	return current.set.Lookup(ip)
}

// Version returns the version of the currently loaded set.
func (cs *CentralStore) Version() uint64 {
	current := cs.state.Load()
	return current.set.Version()
}
// Refresh pulls an update and swaps in the new state on change. On a version
// gap it retries once from a cold pull (since=0) so a node that fell behind the
// diff window recovers with a full snapshot. A lower version is rejected so a
// rolled-back or hostile endpoint cannot regress the cache.
func (cs *CentralStore) Refresh(ctx context.Context) error {
cur := cs.state.Load().snapshot
next, changed, err := cs.puller.Refresh(ctx, cur)
if err == ErrSetVersionGap {
next, changed, err = cs.puller.Refresh(ctx, ScoredSnapshot{})
}
if err != nil {
return err
}
if changed {
nextState := ¢ralState{snapshot: next, set: NewSet(next)}
for {
latest := cs.state.Load()
if next.Version < latest.snapshot.Version {
return ErrSetVersionGap
}
if next.Version == latest.snapshot.Version {
return nil
}
if cs.state.CompareAndSwap(latest, nextState) {
return nil
}
}
}
return nil
}
package reporting
import (
"context"
"encoding/hex"
"errors"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// maxScoredSetBytes caps a pulled scored-set payload so a hostile or broken
// endpoint cannot exhaust memory. Puller.Refresh passes it to
// readScoredSetBody as the limit.
const maxScoredSetBytes = 64 << 20 // 64 MiB
var (
// ErrPullStatus means the endpoint returned an unexpected HTTP status,
// i.e. anything other than 200 OK or 304 Not Modified.
ErrPullStatus = errors.New("reporting: scored-set pull bad status")
// ErrPullBodyTooLarge means the endpoint returned more bytes than a node
// will verify and cache.
ErrPullBodyTooLarge = errors.New("reporting: scored-set pull body too large")
)
// Puller fetches and verifies the signed scored-set from the central service.
// It pulls a full snapshot on a cold cache and a one-step diff thereafter,
// verifying the Ed25519 signature before applying anything.
type Puller struct {
	client *http.Client // HTTP client used for every pull
	url    string       // scored-set endpoint
	pubHex string       // central Ed25519 public key, hex-encoded
}

// NewPuller builds a Puller for setURL that verifies payloads against the
// central public key pubHex (hex). When client is nil a dedicated client
// with a 30s timeout is used instead.
func NewPuller(client *http.Client, setURL, pubHex string) *Puller {
	p := &Puller{client: client, url: setURL, pubHex: pubHex}
	if p.client == nil {
		p.client = &http.Client{Timeout: 30 * time.Second}
	}
	return p
}
// Refresh fetches an update relative to current and returns the new snapshot.
//
// With a non-zero current.Version the request carries since=<version>. A 304
// response returns current with changed=false. Any other non-200 status is
// ErrPullStatus. The body is size-capped and the X-CSM-Signature header is
// parsed before the payload is opened; an X-CSM-Kind of "diff" is verified
// and applied onto current, anything else is treated as a full snapshot. A
// diff that does not apply onto current (version gap) falls back by
// returning an error so the caller retries with a full pull (since=0). On
// every error the returned snapshot is current and changed is false.
func (p *Puller) Refresh(ctx context.Context, current ScoredSnapshot) (ScoredSnapshot, bool, error) {
	target := p.url
	if current.Version > 0 {
		withVersion, err := withSince(p.url, current.Version)
		if err != nil {
			return current, false, err
		}
		target = withVersion
	}
	if err := ValidateTargetURL(target); err != nil {
		return current, false, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return current, false, err
	}
	// The scored-set URL is operator-configured (reputation.central.set_url) and
	// the response is Ed25519-verified before use; not attacker-controlled.
	// #nosec G704 -- central set URL is operator config, signature-verified.
	resp, err := p.client.Do(req)
	if err != nil {
		return current, false, err
	}
	defer func() { _ = resp.Body.Close() }()

	switch resp.StatusCode {
	case http.StatusNotModified:
		return current, false, nil
	case http.StatusOK:
		// fall out of the switch and process the payload
	default:
		return current, false, ErrPullStatus
	}

	payload, err := readScoredSetBody(resp.Body, maxScoredSetBytes)
	if err != nil {
		return current, false, err
	}
	sig, err := parseSetSignature(resp.Header.Get("X-CSM-Signature"))
	if err != nil {
		return current, false, err
	}

	if resp.Header.Get("X-CSM-Kind") == "diff" {
		verified, openErr := OpenDiff(payload, sig, p.pubHex)
		if openErr != nil {
			return current, false, openErr
		}
		applied, applyErr := ApplyDiff(current, verified)
		if applyErr != nil {
			return current, false, applyErr // version gap: caller retries full
		}
		return applied, true, nil
	}

	// "snapshot" or unset
	full, err := OpenSnapshot(payload, sig, p.pubHex)
	if err != nil {
		return current, false, err
	}
	return full, true, nil
}
// readScoredSetBody reads all of r, accepting at most limit bytes. It reads
// one byte past the limit so that an oversized body is detected and
// reported as ErrPullBodyTooLarge instead of being silently truncated.
func readScoredSetBody(r io.Reader, limit int64) ([]byte, error) {
	capped := io.LimitReader(r, limit+1)
	data, err := io.ReadAll(capped)
	switch {
	case err != nil:
		return nil, err
	case int64(len(data)) > limit:
		return nil, ErrPullBodyTooLarge
	}
	return data, nil
}
// withSince returns base with its "since" query parameter set to version,
// replacing any existing value. The remaining parameters are kept and the
// query is re-encoded (sorted by key). It fails when base or its query
// string cannot be parsed.
func withSince(base string, version uint64) (string, error) {
	parsed, err := url.Parse(base)
	if err != nil {
		return "", err
	}
	params, err := url.ParseQuery(parsed.RawQuery)
	if err != nil {
		return "", err
	}
	params["since"] = []string{strconv.FormatUint(version, 10)}
	parsed.RawQuery = params.Encode()
	return parsed.String(), nil
}
// parseSetSignature decodes an X-CSM-Signature header value of the form
// "ed25519=<hex>" into the raw signature bytes. A missing "=", any other
// scheme, or invalid hex all yield ErrSetSignature.
func parseSetSignature(h string) ([]byte, error) {
	scheme, encoded, found := strings.Cut(h, "=")
	if !found {
		return nil, ErrSetSignature
	}
	if scheme != "ed25519" {
		return nil, ErrSetSignature
	}
	raw, err := hex.DecodeString(encoded)
	if err != nil {
		return nil, ErrSetSignature
	}
	return raw, nil
}
package reporting
import (
"context"
"encoding/json"
"log"
"time"
)
// Spooler is the production Reporter: it enqueues minimized reports to a durable
// spool and drains them to all configured targets on an interval, retrying from
// the spool when a collector is down. It never blocks the alert path beyond a
// single bbolt write.
type Spooler struct {
// spool is the durable queue reports are written to and drained from.
spool *Spool
// sender signs and delivers one spooled body to one target.
sender *Sender
// targets maps a target name to its configuration; DrainOnce uses it to
// resolve the name stored with each spooled item.
targets map[string]Target
// order lists the distinct target names in configuration order; Enqueue
// writes one spool item per name.
order []string
// interval is the pause between drain passes in Run.
interval time.Duration
// logf receives operational messages; NewSpooler sets it to log.Printf.
logf func(string, ...any)
}
// NewSpooler builds a Spooler over a spool, a sender, and the configured
// targets. A zero or negative interval defaults to one minute. When two
// targets share a name, the name is enqueued to once (at its first position)
// and the last configuration with that name is the one used for delivery.
func NewSpooler(spool *Spool, sender *Sender, targets []Target, interval time.Duration) *Spooler {
	byName := make(map[string]Target, len(targets))
	names := make([]string, 0, len(targets))
	for _, target := range targets {
		if _, known := byName[target.Name]; !known {
			names = append(names, target.Name)
		}
		byName[target.Name] = target
	}
	if interval <= 0 {
		interval = time.Minute
	}
	return &Spooler{
		spool:    spool,
		sender:   sender,
		targets:  byName,
		order:    names,
		interval: interval,
		logf:     log.Printf,
	}
}
// Enqueue persists r for delivery to every configured target. It never
// returns an error to the caller (the alert path); failures are logged
// instead:
//   - a report that cannot be encoded is dropped and logged (only the error
//     is logged, never the report contents);
//   - a failed spool write for one target is logged and the remaining
//     targets are still attempted;
//   - the dropped-count from spool overflow is logged so a sustained outage
//     is visible.
func (s *Spooler) Enqueue(r Report) {
	body, err := json.Marshal(r)
	if err != nil {
		// Previously this was a silent return, which made a lost report
		// indistinguishable from one that was never produced.
		s.logf("reporting: marshal report failed: %v", err)
		return
	}
	for _, name := range s.order {
		dropped, err := s.spool.Enqueue(name, body)
		if err != nil {
			s.logf("reporting: spool enqueue for %s failed: %v", name, err)
			continue
		}
		if dropped > 0 {
			s.logf("reporting: spool over capacity, dropped %d oldest reports for %s", dropped, name)
		}
	}
}
// DrainOnce attempts one delivery pass over the spool. Each spooled item is
// sent to the target it was queued for; an item whose target is no longer
// configured is acknowledged without sending so it leaves the spool. A
// drain error is logged and left for the next pass.
func (s *Spooler) DrainOnce(ctx context.Context) {
	deliver := func(target string, body []byte) error {
		if cfg, configured := s.targets[target]; configured {
			return s.sender.Send(ctx, cfg, body)
		}
		// Target removed from config: drop the item by reporting success.
		return nil
	}
	if _, err := s.spool.Drain(deliver); err != nil {
		s.logf("reporting: drain paused (will retry): %v", err)
	}
}
// Run drains the spool once per interval and returns when ctx is cancelled.
// The first drain happens after one full interval, not immediately.
func (s *Spooler) Run(ctx context.Context) {
	ticker := time.NewTicker(s.interval)
	defer ticker.Stop()

	stop := ctx.Done()
	for {
		select {
		case <-ticker.C:
			s.DrainOnce(ctx)
		case <-stop:
			return
		}
	}
}
package reporting
import (
"net"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// Class is the public abuse classification sent on the wire. It is a closed set
// matching the central database's accepted classes.
type Class string

const (
	ClassBruteforce Class = "bruteforce"
	ClassPHPRelay   Class = "php_relay"
	// #nosec G101 -- abuse-class label, not a credential.
	ClassCredentialStuffing Class = "credential_stuffing"
	ClassBadASNEgress       Class = "bad_asn_egress"
)

// checkClass maps a CSM finding check name to its public abuse class. Only
// confirmed-abuse checks that carry a source IP appear here; anything absent is
// never reported. This is the v1 reportable set (host_takeover is an incident
// Kind, not a finding check, and is added when the gate also taps incidents).
var checkClass = map[string]Class{
	// Login brute-force family.
	"pam_bruteforce":         ClassBruteforce,
	"wp_login_bruteforce":    ClassBruteforce,
	"xmlrpc_abuse":           ClassBruteforce,
	"ftp_bruteforce":         ClassBruteforce,
	"smtp_bruteforce":        ClassBruteforce,
	"mail_bruteforce":        ClassBruteforce,
	"admin_panel_bruteforce": ClassBruteforce,
	// One-to-one mappings.
	"credential_stuffing":   ClassCredentialStuffing,
	"email_php_relay_abuse": ClassPHPRelay,
	"bad_asn_outbound":      ClassBadASNEgress,
}

// Classify looks check up in the reportable set. It returns the abuse class
// and true when the check is reportable, or the empty class and false when
// it is not.
func Classify(check string) (Class, bool) {
	class, reportable := checkClass[check]
	return class, reportable
}
// Report is the minimized payload sent for a confirmed-abuse IP. It carries no
// hostnames, accounts, mailboxes, or paths. The JSON shape matches the central
// ingest contract exactly.
type Report struct {
// IP is the source address; Gate.Consider writes it in the normalized
// textual form produced by net.IP.String.
IP string `json:"ip"`
// Class is the public abuse class of the finding.
Class Class `json:"class"`
// Count is the number of observations; Gate.Consider always sets 1.
Count int `json:"count"`
// FirstSeen and LastSeen bound the observation window. Gate.Consider sets
// both to the finding's timestamp in UTC.
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
}
// Gate decides whether a finding is reportable and, if so, the minimized report
// to send. Only Critical findings whose check is an enabled abuse class and
// that carry a usable source IP are reported.
type Gate struct {
// Enabled is the set of classes the operator has turned on. Empty means
// none are reported. A nil map is safe: lookups on it report false.
Enabled map[Class]bool
}
// Consider returns the minimized report for f, or ok=false when f must not be
// reported. A finding is reported only when all of the following hold: its
// severity is Critical, its check maps to an abuse class, that class is
// enabled, its source IP parses, and its timestamp is set. The minimizer is
// deny-by-default: it copies only the IP, class, count, and timestamps,
// never tenant/domain/mailbox/path/process fields.
func (g Gate) Consider(f alert.Finding) (Report, bool) {
	var none Report

	if f.Severity != alert.Critical {
		return none, false
	}
	class, reportable := Classify(f.Check)
	if !reportable {
		return none, false
	}
	if !g.Enabled[class] {
		return none, false
	}
	source := net.ParseIP(f.SourceIP)
	if source == nil {
		return none, false
	}
	if f.Timestamp.IsZero() {
		return none, false
	}

	seenAt := f.Timestamp.UTC()
	report := Report{
		IP:        source.String(),
		Class:     class,
		Count:     1,
		FirstSeen: seenAt,
		LastSeen:  seenAt,
	}
	return report, true
}
// Reporter accepts minimized reports for asynchronous delivery. Implementations
// must not block the caller (the scan/alert path). Enqueue has no return
// value, so an implementation has to handle delivery failures itself.
type Reporter interface {
Enqueue(Report)
}
// Noop is the default Reporter; it discards reports. Used when reporting is
// disabled so call sites stay unconditional.
type Noop struct{}
// Enqueue discards r. It exists so Noop satisfies Reporter.
func (Noop) Enqueue(Report) {}
package reporting
import (
"bytes"
"crypto/ed25519"
"encoding/hex"
"encoding/json"
"errors"
"io"
"net"
"sort"
"time"
)
// This is the node consume side of the signed scored-set. The encoding here
// MUST stay byte-identical to the central publisher (csm-abuse-db
// internal/publish) so a node re-marshals a decoded payload to exactly the
// signed bytes. TestScoredSetGoldenCanonical pins the canonical form.
var (
// ErrSetSignature means the scored-set signature did not verify. It is
// also returned for a malformed public key or signature header, and by
// ApplyDiff for a diff that did not come from OpenDiff.
ErrSetSignature = errors.New("reporting: scored-set signature invalid")
// ErrSetInvalid means a decoded scored-set was malformed or noncanonical.
ErrSetInvalid = errors.New("reporting: scored-set invalid")
// ErrSetVersionGap means a diff does not apply onto the cached version.
ErrSetVersionGap = errors.New("reporting: scored-set version gap")
)
// ScoredEntry is one scored IP in the distributed set. The constraints noted
// on the fields are the ones canonicalScoredEntry enforces.
type ScoredEntry struct {
IP string `json:"ip"` // parseable IP; canonical form is net.IP.String
Score int `json:"score"` // 0-100 inclusive
Classes []Class `json:"classes"` // non-empty, known classes, sorted, no duplicates
LastSeen time.Time `json:"last_seen"` // non-zero; canonical form is UTC
}
// ScoredSnapshot is the full scored-set at a version. Version 0 is the
// node's empty-cache sentinel and is never accepted from the wire.
type ScoredSnapshot struct {
Version uint64 `json:"version"`
Entries []ScoredEntry `json:"entries"`
}
// ScoredDiff is an incremental update from FromVersion to ToVersion. A
// canonical diff has ToVersion > FromVersion and no IP appearing in more
// than one of Added, Removed, and Changed.
type ScoredDiff struct {
FromVersion uint64 `json:"from_version"`
ToVersion uint64 `json:"to_version"`
Added []ScoredEntry `json:"added"`
Removed []string `json:"removed"`
Changed []ScoredEntry `json:"changed"`
}
// VerifiedScoredDiff has passed Ed25519 verification and canonical decode.
// Its internals stay opaque so raw decoded bytes cannot be fed to ApplyDiff.
// Only OpenDiff sets verified; ApplyDiff rejects the zero value.
type VerifiedScoredDiff struct {
diff ScoredDiff
verified bool
}
// knownClass reports whether c is one of the abuse classes this node
// understands.
func knownClass(c Class) bool {
	return c == ClassBruteforce ||
		c == ClassPHPRelay ||
		c == ClassCredentialStuffing ||
		c == ClassBadASNEgress
}
// canonicalScoredClasses returns a sorted copy of in, leaving in untouched.
// It reports false (with a nil slice) when in is empty, contains an unknown
// class, or contains a duplicate.
func canonicalScoredClasses(in []Class) ([]Class, bool) {
	if len(in) == 0 {
		return nil, false
	}
	sorted := make([]Class, len(in))
	copy(sorted, in)
	sort.Slice(sorted, func(a, b int) bool { return sorted[a] < sorted[b] })

	// After sorting, duplicates are adjacent.
	for idx := range sorted {
		if !knownClass(sorted[idx]) {
			return nil, false
		}
		if idx > 0 && sorted[idx] == sorted[idx-1] {
			return nil, false
		}
	}
	return sorted, true
}
// canonicalScoredEntry validates e and returns its canonical form: the IP in
// net.IP.String form, classes sorted, and LastSeen in UTC. It reports false
// when the IP does not parse, the score is outside 0-100, LastSeen is zero,
// or the classes are not canonicalizable.
func canonicalScoredEntry(e ScoredEntry) (ScoredEntry, bool) {
	var invalid ScoredEntry

	addr := net.ParseIP(e.IP)
	switch {
	case addr == nil:
		return invalid, false
	case e.Score < 0, e.Score > 100:
		return invalid, false
	case e.LastSeen.IsZero():
		return invalid, false
	}
	classes, ok := canonicalScoredClasses(e.Classes)
	if !ok {
		return invalid, false
	}
	return ScoredEntry{
		IP:       addr.String(),
		Score:    e.Score,
		Classes:  classes,
		LastSeen: e.LastSeen.UTC(),
	}, true
}
// canonicalScoredEntries canonicalizes every entry of in and returns them
// sorted by canonical IP string. It reports false (with a nil slice) when
// any entry is invalid or when two entries share a canonical IP. The result
// for an empty input is an empty, non-nil slice.
func canonicalScoredEntries(in []ScoredEntry) ([]ScoredEntry, bool) {
	canon := make([]ScoredEntry, 0, len(in))
	taken := make(map[string]struct{}, len(in))
	for _, raw := range in {
		entry, valid := canonicalScoredEntry(raw)
		if !valid {
			return nil, false
		}
		if _, repeated := taken[entry.IP]; repeated {
			return nil, false
		}
		taken[entry.IP] = struct{}{}
		canon = append(canon, entry)
	}
	sort.Slice(canon, func(a, b int) bool { return canon[a].IP < canon[b].IP })
	return canon, true
}
// canonicalRemovedIPs rewrites every address of in to its net.IP.String form
// and returns them sorted. It reports false (with a nil slice) when an
// address does not parse or when two addresses are the same IP after
// normalization. The result for an empty input is an empty, non-nil slice.
func canonicalRemovedIPs(in []string) ([]string, bool) {
	canon := make([]string, 0, len(in))
	taken := make(map[string]struct{}, len(in))
	for _, raw := range in {
		addr := net.ParseIP(raw)
		if addr == nil {
			return nil, false
		}
		text := addr.String()
		if _, repeated := taken[text]; repeated {
			return nil, false
		}
		taken[text] = struct{}{}
		canon = append(canon, text)
	}
	sort.Strings(canon)
	return canon, true
}
// canonicalScoredDiff validates d and returns its canonical form. It reports
// false when ToVersion does not exceed FromVersion, when any of the Added,
// Changed, or Removed lists fails canonicalization, or when one IP appears
// in more than one of the three lists.
func canonicalScoredDiff(d ScoredDiff) (ScoredDiff, bool) {
	var invalid ScoredDiff

	if d.ToVersion <= d.FromVersion {
		return invalid, false
	}
	canonAdded, addedOK := canonicalScoredEntries(d.Added)
	if !addedOK {
		return invalid, false
	}
	canonChanged, changedOK := canonicalScoredEntries(d.Changed)
	if !changedOK {
		return invalid, false
	}
	canonRemoved, removedOK := canonicalRemovedIPs(d.Removed)
	if !removedOK {
		return invalid, false
	}
	if !scoredDiffOperationsDistinct(canonAdded, canonChanged, canonRemoved) {
		return invalid, false
	}

	canon := ScoredDiff{FromVersion: d.FromVersion, ToVersion: d.ToVersion}
	canon.Added = canonAdded
	canon.Removed = canonRemoved
	canon.Changed = canonChanged
	return canon, true
}
// scoredDiffOperationsDistinct reports whether no IP is touched by more than
// one diff operation. The added IPs are recorded without a check (duplicates
// inside one list are the caller's concern); every changed and removed IP
// must then be new to the set of IPs seen so far.
func scoredDiffOperationsDistinct(added, changed []ScoredEntry, removed []string) bool {
	touched := make(map[string]struct{}, len(added)+len(changed)+len(removed))
	// claim records ip and reports whether it had not been recorded before.
	claim := func(ip string) bool {
		if _, already := touched[ip]; already {
			return false
		}
		touched[ip] = struct{}{}
		return true
	}

	for i := range added {
		touched[added[i].IP] = struct{}{}
	}
	for i := range changed {
		if !claim(changed[i].IP) {
			return false
		}
	}
	for _, ip := range removed {
		if !claim(ip) {
			return false
		}
	}
	return true
}
// validateDiffAgainstBase reports whether d is consistent with base: every
// removed and every changed IP must already be present in base, and every
// added IP must be absent from it. IPs are compared as stored, without
// normalization.
func validateDiffAgainstBase(base ScoredSnapshot, d ScoredDiff) bool {
	existing := indexScoredEntries(base.Entries)
	present := func(ip string) bool {
		_, found := existing[ip]
		return found
	}

	for _, ip := range d.Removed {
		if !present(ip) {
			return false
		}
	}
	for i := range d.Added {
		if present(d.Added[i].IP) {
			return false
		}
	}
	for i := range d.Changed {
		if !present(d.Changed[i].IP) {
			return false
		}
	}
	return true
}
// indexScoredEntries builds a map from each entry's IP string to the entry.
// When two entries share an IP the later one wins.
func indexScoredEntries(entries []ScoredEntry) map[string]ScoredEntry {
	byIP := make(map[string]ScoredEntry, len(entries))
	for i := range entries {
		byIP[entries[i].IP] = entries[i]
	}
	return byIP
}
// MarshalScoredSnapshot deterministically encodes s (for signature/re-marshal):
// the entries are canonicalized first and the result is JSON-encoded. It
// reports false when the entries are not canonicalizable or encoding fails.
func MarshalScoredSnapshot(s ScoredSnapshot) ([]byte, bool) {
	canonEntries, valid := canonicalScoredEntries(s.Entries)
	if !valid {
		return nil, false
	}
	canon := ScoredSnapshot{Version: s.Version, Entries: canonEntries}
	encoded, err := json.Marshal(canon)
	if err != nil {
		return nil, false
	}
	return encoded, true
}
// MarshalScoredDiff deterministically encodes d (for signature/re-marshal):
// the diff is canonicalized first and the result is JSON-encoded. It reports
// false when the diff is not canonicalizable or encoding fails.
func MarshalScoredDiff(d ScoredDiff) ([]byte, bool) {
	canon, valid := canonicalScoredDiff(d)
	if !valid {
		return nil, false
	}
	encoded, err := json.Marshal(canon)
	if err != nil {
		return nil, false
	}
	return encoded, true
}
// VerifyScoredSet checks the Ed25519 signature sig over payload under pubHex
// (the central public key, hex-encoded). It returns ErrSetSignature when
// pubHex is not valid hex, does not decode to a key of the right size, or
// the signature does not verify.
func VerifyScoredSet(payload, sig []byte, pubHex string) error {
	key, err := hex.DecodeString(pubHex)
	if err != nil {
		return ErrSetSignature
	}
	// The length check must come before Verify, which panics on a key of
	// the wrong size.
	if len(key) != ed25519.PublicKeySize {
		return ErrSetSignature
	}
	if ed25519.Verify(ed25519.PublicKey(key), payload, sig) {
		return nil
	}
	return ErrSetSignature
}
// decodeStrict decodes exactly one JSON value from b into v. Unknown fields
// are rejected, and so is anything after the first value: a second
// decodable value yields ErrSetInvalid, other trailing garbage yields the
// decoder's own error.
func decodeStrict(b []byte, v any) error {
	decoder := json.NewDecoder(bytes.NewReader(b))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(v); err != nil {
		return err
	}
	// A clean payload leaves nothing to read, so the second Decode must hit
	// io.EOF.
	var trailing struct{}
	switch err := decoder.Decode(&trailing); err {
	case io.EOF:
		return nil
	case nil:
		return ErrSetInvalid
	default:
		return err
	}
}
// OpenSnapshot verifies sig over payload, then decodes and canonical-checks the
// snapshot. Verification happens before any structural trust is placed in the
// bytes. A signature failure returns the verification error; every later
// failure (strict decode, version 0, non-canonical entries, or a payload
// that does not re-marshal to the same bytes) returns ErrSetInvalid.
func OpenSnapshot(payload, sig []byte, pubHex string) (ScoredSnapshot, error) {
	var none ScoredSnapshot

	if err := VerifyScoredSet(payload, sig, pubHex); err != nil {
		return none, err
	}
	var snap ScoredSnapshot
	if err := decodeStrict(payload, &snap); err != nil {
		return none, ErrSetInvalid
	}
	// A published snapshot is always version >= 1; version 0 is the node's
	// empty-cache sentinel and must never be accepted from the wire (it would
	// let a hostile endpoint pin the node to perpetual cold pulls).
	if snap.Version == 0 {
		return none, ErrSetInvalid
	}
	canonEntries, valid := canonicalScoredEntries(snap.Entries)
	if !valid {
		return none, ErrSetInvalid
	}
	snap.Entries = canonEntries

	// The signed bytes must be exactly the canonical encoding.
	reencoded, encodable := MarshalScoredSnapshot(snap)
	if !encodable {
		return none, ErrSetInvalid
	}
	if !bytes.Equal(reencoded, payload) {
		return none, ErrSetInvalid
	}
	return snap, nil
}
// OpenDiff verifies sig over payload, then decodes and canonical-checks the
// diff. A signature failure returns the verification error; every later
// failure (strict decode, non-canonical diff, or a payload that does not
// re-marshal to the same bytes) returns ErrSetInvalid. Only a diff returned
// from here is accepted by ApplyDiff.
func OpenDiff(payload, sig []byte, pubHex string) (VerifiedScoredDiff, error) {
	var none VerifiedScoredDiff

	if err := VerifyScoredSet(payload, sig, pubHex); err != nil {
		return none, err
	}
	var decoded ScoredDiff
	if err := decodeStrict(payload, &decoded); err != nil {
		return none, ErrSetInvalid
	}
	canon, valid := canonicalScoredDiff(decoded)
	if !valid {
		return none, ErrSetInvalid
	}

	// The signed bytes must be exactly the canonical encoding.
	reencoded, encodable := MarshalScoredDiff(canon)
	if !encodable {
		return none, ErrSetInvalid
	}
	if !bytes.Equal(reencoded, payload) {
		return none, ErrSetInvalid
	}
	return VerifiedScoredDiff{diff: canon, verified: true}, nil
}
// ApplyDiff applies a verified diff onto base, returning the resulting
// snapshot at the diff's ToVersion with canonical, IP-sorted entries.
//
// Errors, in the order they are checked: ErrSetSignature when d did not come
// from OpenDiff; ErrSetInvalid when base's entries are not canonicalizable;
// ErrSetVersionGap when the diff's FromVersion is not base's version;
// ErrSetInvalid when the diff is inconsistent with base or the merged result
// is not canonicalizable.
func ApplyDiff(base ScoredSnapshot, d VerifiedScoredDiff) (ScoredSnapshot, error) {
	var none ScoredSnapshot

	if !d.verified {
		return none, ErrSetSignature
	}
	baseEntries, valid := canonicalScoredEntries(base.Entries)
	if !valid {
		return none, ErrSetInvalid
	}
	base.Entries = baseEntries

	patch := d.diff
	if patch.FromVersion != base.Version {
		return none, ErrSetVersionGap
	}
	if !validateDiffAgainstBase(base, patch) {
		return none, ErrSetInvalid
	}

	// Merge: drop removals, then lay additions and changes over the base.
	merged := indexScoredEntries(base.Entries)
	for _, ip := range patch.Removed {
		delete(merged, ip)
	}
	for i := range patch.Added {
		merged[patch.Added[i].IP] = patch.Added[i]
	}
	for i := range patch.Changed {
		merged[patch.Changed[i].IP] = patch.Changed[i]
	}

	flat := make([]ScoredEntry, 0, len(merged))
	for _, entry := range merged {
		flat = append(flat, entry)
	}
	// Map iteration order is random; canonicalization sorts the result.
	result, valid := canonicalScoredEntries(flat)
	if !valid {
		return none, ErrSetInvalid
	}
	return ScoredSnapshot{Version: patch.ToVersion, Entries: result}, nil
}
// Set is an in-memory lookup over the current scored-set.
type Set struct {
	version uint64                 // version of the snapshot the set was built from
	byIP    map[string]ScoredEntry // entries keyed by IP string
}

// NewSet builds a lookup set from a snapshot. Each entry is stored in
// canonical form when it canonicalizes and as given otherwise, keyed by its
// IP string. The Classes slice is copied so the set does not share backing
// storage with the snapshot.
func NewSet(s ScoredSnapshot) *Set {
	index := make(map[string]ScoredEntry, len(s.Entries))
	for _, entry := range s.Entries {
		stored := entry
		if canon, ok := canonicalScoredEntry(entry); ok {
			stored = canon
		}
		stored.Classes = append([]Class(nil), stored.Classes...)
		index[stored.IP] = stored
	}
	return &Set{version: s.Version, byIP: index}
}
// Version returns the version of the snapshot the set was built from.
func (s *Set) Version() uint64 {
	return s.version
}

// Lookup returns the scored entry for ip, normalizing the textual form
// before the lookup. An unparsable ip is reported as not found. The returned
// entry carries its own copy of Classes, so callers may modify it freely.
func (s *Set) Lookup(ip string) (ScoredEntry, bool) {
	addr := net.ParseIP(ip)
	if addr == nil {
		return ScoredEntry{}, false
	}
	entry, found := s.byIP[addr.String()]
	if !found {
		return entry, false
	}
	entry.Classes = append([]Class(nil), entry.Classes...)
	return entry, true
}

// Len returns the number of scored IPs.
func (s *Set) Len() int {
	return len(s.byIP)
}
package reporting
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
)
// Transport selects how a report is signed for a target.
type Transport string
const (
// TransportEd25519 signs with an Ed25519 node key (federation / central DB).
// The signature header uses the scheme label "ed25519".
TransportEd25519 Transport = "ed25519"
// TransportHMAC signs with a shared HMAC secret (private collector).
// The signature header uses the scheme label "sha256".
TransportHMAC Transport = "hmac"
)
// Target is one reporting destination.
type Target struct {
// Name identifies the target; the spooler keys its configuration and its
// queued items by it.
Name string
// URL is the endpoint reports are POSTed to. It must be https, or http to
// a loopback host.
URL string
// Transport selects the signing scheme; any other value makes Send fail.
Transport Transport
// NodeID and KeyID are sent as the X-CSM-Node and X-CSM-Key headers and
// are part of the signed envelope.
NodeID string
KeyID string
// Ed25519Key is the node's private key for TransportEd25519.
Ed25519Key ed25519.PrivateKey
// HMACSecret is the shared secret for TransportHMAC.
HMACSecret []byte
// BearerToken is an optional Authorization bearer for HMAC collectors.
// NOTE(review): Send attaches it whenever it is non-empty, regardless of
// Transport -- confirm that is intended.
BearerToken string
}
var (
// ErrInsecureURL means a non-HTTPS target was configured for a non-loopback
// host. Reports and their auth context must not cross the network in clear.
// ValidateTargetURL also returns it for a URL that does not parse or has
// no host.
ErrInsecureURL = errors.New("reporting: target URL must be https")
// ErrRejected means the collector rejected the report (non-2xx, non-conflict).
ErrRejected = errors.New("reporting: report rejected")
)
// Sender delivers a signed report body to a target over HTTP.
type Sender struct {
	client *http.Client     // HTTP client used for delivery
	now    func() time.Time // clock for the envelope timestamp
}

// NewSender builds a Sender. A nil client is replaced by a dedicated client
// with a 15s timeout, and a nil now by time.Now.
func NewSender(client *http.Client, now func() time.Time) *Sender {
	s := &Sender{client: client, now: now}
	if s.client == nil {
		s.client = &http.Client{Timeout: 15 * time.Second}
	}
	if s.now == nil {
		s.now = time.Now
	}
	return s
}
// ValidateTargetURL reports whether raw is allowed for report delivery without
// logging or returning the raw URL.
func ValidateTargetURL(raw string) error {
u, err := url.Parse(raw)
if err != nil || !secureURL(u) {
return ErrInsecureURL
}
return nil
}
// Send signs body for t and POSTs it. A 2xx is success; 409 Conflict (the
// collector already has this report) is treated as success. Other statuses and
// transport errors are failures the caller should retry from the spool.
//
// The signed envelope binds the node and key IDs, the POST method, the URL
// path, the body hash, the current UTC timestamp, and a fresh nonce. A URL
// that does not parse, an insecure URL (ErrInsecureURL), or an unknown
// transport fails before any request is made.
func (s *Sender) Send(ctx context.Context, t Target, body []byte) error {
	endpoint, err := url.Parse(t.URL)
	if err != nil {
		return fmt.Errorf("reporting: bad target url: %w", err)
	}
	if !secureURL(endpoint) {
		return ErrInsecureURL
	}

	nonce, err := newNonce()
	if err != nil {
		return err
	}
	envelope := Envelope{
		NodeID:    t.NodeID,
		KeyID:     t.KeyID,
		Method:    http.MethodPost,
		Path:      requestPath(endpoint),
		BodyHash:  HashBody(body),
		Timestamp: s.now().UTC().Unix(),
		Nonce:     nonce,
	}
	signature, scheme, err := s.sign(t, envelope)
	if err != nil {
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.URL, bytes.NewReader(body))
	if err != nil {
		return err
	}
	headers := req.Header
	headers.Set("Content-Type", "application/json")
	headers.Set("X-CSM-Node", envelope.NodeID)
	headers.Set("X-CSM-Key", envelope.KeyID)
	headers.Set("X-CSM-Timestamp", fmt.Sprintf("%d", envelope.Timestamp))
	headers.Set("X-CSM-Nonce", envelope.Nonce)
	headers.Set("X-CSM-Signature", scheme+"="+hex.EncodeToString(signature))
	if t.BearerToken != "" {
		headers.Set("Authorization", "Bearer "+t.BearerToken)
	}

	resp, err := s.do(req)
	if err != nil {
		return err
	}
	defer func() { _ = resp.Body.Close() }()

	status := resp.StatusCode
	if status >= 200 && status < 300 {
		return nil
	}
	if status == http.StatusConflict {
		return nil // collector already recorded this report (replay/dup)
	}
	return fmt.Errorf("%w: status %d", ErrRejected, status)
}
// do sends req with a per-call copy of the configured client that never
// follows redirects; a redirect response is handed back to the caller as
// is. The shared client itself is left unmodified.
func (s *Sender) do(req *http.Request) (*http.Response, error) {
	// The signature is bound to the configured URL path, and auth headers must
	// not be replayed to a redirected endpoint.
	stayPut := func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse
	}
	noRedirect := *s.client
	noRedirect.CheckRedirect = stayPut
	// The destination is an operator-configured report target, validated to
	// HTTPS (or loopback HTTP) by secureURL; it is not attacker-controlled.
	// #nosec G704 -- report target URL is operator config, scheme-validated.
	return noRedirect.Do(req)
}
// sign produces the signature over env for t's transport, plus the scheme
// label that prefixes it in the X-CSM-Signature header: "ed25519" for
// TransportEd25519 and "sha256" for TransportHMAC. The label is returned
// even when signing itself fails. Any other transport is an error.
func (s *Sender) sign(t Target, env Envelope) (sig []byte, scheme string, err error) {
	if t.Transport == TransportEd25519 {
		sig, err = SignEd25519(env, t.Ed25519Key)
		return sig, "ed25519", err
	}
	if t.Transport == TransportHMAC {
		sig, err = SignHMAC(env, t.HMACSecret)
		return sig, "sha256", err
	}
	return nil, "", fmt.Errorf("reporting: unknown transport %q", t.Transport)
}
// secureURL reports whether u may carry reports. A URL without a host is
// never allowed. https is always allowed. http is allowed only to a loopback
// IP or to the name "localhost" (any letter case), for local collectors and
// tests. Every other scheme is rejected.
func secureURL(u *url.URL) bool {
	name := u.Hostname()
	switch {
	case name == "":
		return false
	case u.Scheme == "https":
		return true
	case u.Scheme != "http":
		return false
	}
	// Plain http from here on: loopback only.
	if addr := net.ParseIP(name); addr != nil {
		return addr.IsLoopback()
	}
	return strings.EqualFold(name, "localhost")
}
// requestPath returns the path component of u that is bound into the signed
// envelope, substituting "/" for an empty path.
func requestPath(u *url.URL) string {
	if p := u.Path; p != "" {
		return p
	}
	return "/"
}
// newNonce returns 16 bytes from crypto/rand as a 32-character lowercase hex
// string, or the error from the random source.
func newNonce() (string, error) {
	raw := make([]byte, 16)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
package reporting
import (
"encoding/binary"
"encoding/json"
"sync"
"time"
bolt "go.etcd.io/bbolt"
)
// spoolItem is one queued report: the destination target name and the exact
// minimized body bytes to sign and send.
// Enqueue JSON-encodes it as the value of a spool entry, using the short field
// names in the struct tags.
type spoolItem struct {
// Target is the destination target name handed back to the Drain callback.
Target string `json:"t"`
// Body is the report body handed back to the Drain callback unchanged.
Body []byte `json:"b"`
}
// Spool is a durable, bounded outbound queue for reports, backed by bbolt so a
// down collector or a daemon restart does not drop confirmed-abuse reports.
type Spool struct {
// db is the bbolt database opened by NewSpool and closed by Close.
db *bolt.DB
// bucket is the name of the single bucket holding the queued entries.
bucket []byte
// max is the entry cap enforced by Enqueue (oldest entries are trimmed).
max int
// drain is held for the whole of Drain so only one drain loop runs at a time.
drain sync.Mutex
}
// NewSpool opens (or creates) the bbolt file at path with mode 0600 and a
// two-second open timeout, ensures the named bucket exists, and returns a
// Spool capped at max entries. If the bucket cannot be created the database is
// closed again before the error is returned.
func NewSpool(path, bucket string, max int) (*Spool, error) {
	opts := &bolt.Options{Timeout: 2 * time.Second}
	db, err := bolt.Open(path, 0o600, opts)
	if err != nil {
		return nil, err
	}
	spool := &Spool{db: db, bucket: []byte(bucket), max: max}
	ensureBucket := func(tx *bolt.Tx) error {
		_, createErr := tx.CreateBucketIfNotExists(spool.bucket)
		return createErr
	}
	if err = db.Update(ensureBucket); err != nil {
		_ = db.Close()
		return nil, err
	}
	return spool, nil
}
// Close releases the underlying database by closing the bbolt handle and
// returns that call's error. The Spool must not be used afterwards.
func (s *Spool) Close() error { return s.db.Close() }
// Enqueue appends a report body destined for target. When the queue exceeds its
// cap, the oldest entries are dropped (FIFO) and the dropped count is returned
// so the caller can surface it; reports are best-effort under sustained outage.
//
// The insert, the count and the trim all happen inside a single db.Update
// call; any error from that call is returned as err.
// NOTE(review): dropped can be non-zero together with a non-nil err if a later
// Delete fails part-way through trimming; presumably bbolt rolls the whole
// transaction back in that case -- confirm before relying on dropped on error.
func (s *Spool) Enqueue(target string, body []byte) (dropped int, err error) {
// Encode before opening the write transaction so a marshal failure never
// touches the database.
item := spoolItem{Target: target, Body: body}
enc, err := json.Marshal(item)
if err != nil {
return 0, err
}
err = s.db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(s.bucket)
seq, seqErr := b.NextSequence()
if seqErr != nil {
return seqErr
}
// The key is the bucket sequence number as 8 big-endian bytes, so byte-wise
// key order is insertion order.
var key [8]byte
binary.BigEndian.PutUint64(key[:], seq)
if e := b.Put(key[:], enc); e != nil {
return e
}
// Count current keys via the cursor; Bucket.Stats is not reliable for
// pending changes inside the same write transaction.
count := 0
c := b.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
count++
}
// Trim from the front (oldest keys sort first) until within cap.
for count > s.max {
// A fresh cursor is taken on every pass and positioned on the first key.
tc := b.Cursor()
k, _ := tc.First()
if k == nil {
break
}
if e := b.Delete(k); e != nil {
return e
}
count--
dropped++
}
return nil
})
return dropped, err
}
// Len returns the number of queued items, read from the bucket's KeyN
// statistic inside a read-only transaction. The error from db.View is
// discarded, so a failed transaction reports 0.
func (s *Spool) Len() int {
n := 0
_ = s.db.View(func(tx *bolt.Tx) error {
n = tx.Bucket(s.bucket).Stats().KeyN
return nil
})
return n
}
// Drain processes queued items in FIFO order, calling send for each. An item is
// removed only when send returns nil; on the first send error Drain stops and
// leaves that item (and the rest) for a later retry, preserving order. It
// returns how many were delivered.
//
// An entry whose stored value cannot be decoded can never be delivered, so it
// is removed from the queue and the decode error is returned. Without that, a
// single corrupt head entry would make every later Drain fail on the same key
// and block all reports queued behind it.
//
// Drain holds s.drain for its whole duration, so concurrent calls run one
// after another. Database errors from the read or delete transactions are
// returned as-is together with the count delivered so far.
func (s *Spool) Drain(send func(target string, body []byte) error) (delivered int, err error) {
	s.drain.Lock()
	defer s.drain.Unlock()
	for {
		var (
			key   []byte
			raw   []byte
			found bool
		)
		if e := s.db.View(func(tx *bolt.Tx) error {
			k, v := tx.Bucket(s.bucket).Cursor().First()
			if k == nil {
				return nil
			}
			// Copy both slices: bbolt memory is only valid inside the transaction.
			key = append([]byte(nil), k...)
			raw = append([]byte(nil), v...)
			found = true
			return nil
		}); e != nil {
			return delivered, e
		}
		if !found {
			return delivered, nil
		}

		var item spoolItem
		decodeErr := json.Unmarshal(raw, &item)
		if decodeErr == nil {
			if e := send(item.Target, item.Body); e != nil {
				return delivered, e // keep this item; retry later
			}
		}

		// Remove the entry: it was either delivered or is undeliverable.
		if e := s.db.Update(func(tx *bolt.Tx) error {
			return tx.Bucket(s.bucket).Delete(key)
		}); e != nil {
			return delivered, e
		}
		if decodeErr != nil {
			// Surface the corruption; the next Drain continues with the following item.
			return delivered, decodeErr
		}
		delivered++
	}
}
// Package reporting is the node side of CSM abuse reporting (Layer A). It turns
// confirmed-abuse findings into minimized, signed reports for a central abuse
// database or a private collector.
//
// The signed-envelope wire format here MUST stay byte-identical to the central
// service's verifier (csm-abuse-db internal/envelope). It is duplicated rather
// than imported to keep this repo's build self-contained; the wire-format test
// pins the canonical bytes so any divergence fails the build.
package reporting
import (
"crypto/ed25519"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
// bodyHashLen is the required length of Envelope.BodyHash (one SHA-256 digest).
bodyHashLen = sha256.Size
// maxFieldLen bounds each canonical field; matches the central verifier.
// The value is 64 KiB.
maxFieldLen = 1 << 16
)
// ErrInvalidEnvelope means the envelope is structurally unusable.
// canonical returns it for a wrong-sized body hash or an over-long field, and
// the signing helpers also return it for an unusable key or secret.
var ErrInvalidEnvelope = errors.New("reporting: invalid envelope")
// Envelope is the set of fields signed alongside a report body. Field order and
// encoding match the central verifier exactly.
// canonical encodes the string and hash fields in the order NodeID, KeyID,
// Method, Path, BodyHash, Nonce, and appends Timestamp last.
type Envelope struct {
NodeID string
KeyID string
Method string
Path string
BodyHash []byte // SHA-256 of the report body, 32 bytes
Timestamp int64 // unix seconds, UTC
Nonce string
}
// HashBody computes the SHA-256 digest of body and returns it as a 32-byte
// slice, the form expected in Envelope.BodyHash.
func HashBody(body []byte) []byte {
	digest := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	_, _ = digest.Write(body)
	return digest.Sum(nil)
}
// canonical returns the deterministic, injective encoding signed over. Every
// variable-length field is length-prefixed with a 4-byte big-endian length, and
// the timestamp is appended as 8 big-endian bytes.
//
// It returns ErrInvalidEnvelope when BodyHash is not exactly bodyHashLen bytes
// or when any field is longer than maxFieldLen.
func (e Envelope) canonical() ([]byte, error) {
if len(e.BodyHash) != bodyHashLen {
return nil, ErrInvalidEnvelope
}
// The slice order below is the wire order; do not reorder.
fields := [][]byte{
[]byte(e.NodeID),
[]byte(e.KeyID),
[]byte(e.Method),
[]byte(e.Path),
e.BodyHash,
[]byte(e.Nonce),
}
// Start at 8 for the trailing timestamp, then add 4 bytes of length prefix
// plus the payload for each field, so buf is allocated exactly once.
size := 8
for _, f := range fields {
if len(f) > maxFieldLen {
return nil, ErrInvalidEnvelope
}
size += 4 + len(f)
}
buf := make([]byte, 0, size)
var lenbuf [4]byte
for _, f := range fields {
// len(f) is bounded by maxFieldLen above, well under math.MaxUint32.
// #nosec G115 -- length validated <= maxFieldLen (64 KiB) above.
binary.BigEndian.PutUint32(lenbuf[:], uint32(len(f)))
buf = append(buf, lenbuf[:]...)
buf = append(buf, f...)
}
var ts [8]byte
// Lossless int64 -> uint64 bit reinterpretation for fixed-width encoding.
// #nosec G115 -- intentional bit-pattern reinterpret, not a narrowing cast.
binary.BigEndian.PutUint64(ts[:], uint64(e.Timestamp))
buf = append(buf, ts[:]...)
return buf, nil
}
// SignEd25519 returns the Ed25519 signature of e's canonical encoding under
// priv. It fails with the canonical-encoding error when e is malformed, or
// with ErrInvalidEnvelope when priv is not ed25519.PrivateKeySize bytes long.
func SignEd25519(e Envelope, priv ed25519.PrivateKey) ([]byte, error) {
	payload, err := e.canonical()
	switch {
	case err != nil:
		return nil, err
	case len(priv) != ed25519.PrivateKeySize:
		// ed25519.Sign panics on a wrong-sized key, so reject it here.
		return nil, ErrInvalidEnvelope
	}
	return ed25519.Sign(priv, payload), nil
}
// SignHMAC returns the HMAC-SHA256 tag of e's canonical encoding keyed with
// secret (private collector transport). It fails with the canonical-encoding
// error when e is malformed, or with ErrInvalidEnvelope when secret is empty.
func SignHMAC(e Envelope, secret []byte) ([]byte, error) {
	payload, err := e.canonical()
	switch {
	case err != nil:
		return nil, err
	case len(secret) == 0:
		return nil, ErrInvalidEnvelope
	}
	tagger := hmac.New(sha256.New, secret)
	// hash.Hash.Write is documented never to return an error.
	_, _ = tagger.Write(payload)
	return tagger.Sum(nil), nil
}
// Package sdnotify is a thin wrapper around go-systemd's daemon notification
// helpers. The daemon calls Ready when watchers are attached, Status to
// publish a one-line state visible in `systemctl status`, and Watchdog on a
// recurring ticker so systemd's WatchdogSec= keep-alive doesn't expire.
//
// Every function is a no-op when NOTIFY_SOCKET is unset (the daemon is not
// running under systemd; e.g. dev mode). That contract makes it safe to call
// these helpers unconditionally without runtime gates in the daemon code.
package sdnotify
import (
systemd "github.com/coreos/go-systemd/v22/daemon"
)
// Ready signals systemd that the daemon has finished startup. Returns
// (true, nil) when the notification was delivered, (false, nil) when
// NOTIFY_SOCKET is unset, or (false, err) on a real I/O error.
// It sends the systemd.SdNotifyReady state via SdNotify and passes that call's
// results through unchanged.
// NOTE(review): the leading false is presumably go-systemd's unsetEnvironment
// flag (keep NOTIFY_SOCKET set for later calls) -- confirm against the library.
func Ready() (bool, error) {
return systemd.SdNotify(false, systemd.SdNotifyReady)
}
// Reloading signals systemd that the daemon is reloading its config.
// It sends the systemd.SdNotifyReloading state via SdNotify and returns that
// call's results unchanged.
func Reloading() (bool, error) {
return systemd.SdNotify(false, systemd.SdNotifyReloading)
}
// Status sets a single-line status string visible in `systemctl status csm`.
//
// The notify payload is a newline-separated list of assignments, so a line
// break inside msg would smuggle extra assignments into the datagram. Carriage
// returns and line feeds are therefore replaced with spaces before sending,
// which keeps the message to the single line the contract promises.
//
// The return values are those of SdNotify, passed through unchanged.
func Status(msg string) (bool, error) {
	line := []byte(msg)
	for i, c := range line {
		if c == '\n' || c == '\r' {
			line[i] = ' '
		}
	}
	return systemd.SdNotify(false, "STATUS="+string(line))
}
// Watchdog pings the systemd watchdog. Required when the unit declares
// WatchdogSec=; without periodic pings systemd will restart the daemon
// after WatchdogSec elapses.
// It sends the systemd.SdNotifyWatchdog state via SdNotify and returns that
// call's results unchanged.
func Watchdog() (bool, error) {
return systemd.SdNotify(false, systemd.SdNotifyWatchdog)
}
package signatures
import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/atomicio"
"github.com/pidginhost/csm/internal/yara"
)
const (
// forgeReleasesURL is the GitHub API endpoint queried for the latest release tag.
forgeReleasesURL = "https://api.github.com/repos/YARAHQ/yara-forge/releases/latest"
// forgeHTTPTimeout is the client timeout for the release lookup and the ZIP download.
forgeHTTPTimeout = 30 * time.Second
// forgeMaxZIPSize caps how many bytes of the ZIP response are read (20 MiB).
forgeMaxZIPSize = 20 * 1024 * 1024
// forgeMaxYarSize caps the decompressed size of a single .yar entry. The
// compressed ZIP is bounded by forgeMaxZIPSize, but a zip bomb (or a
// compromised CDN / signing key) can encode a far larger decompressed
// payload; the full ruleset tier is a few MiB, so 64 MiB is generous.
forgeMaxYarSize = 64 * 1024 * 1024
)
// forgeTierAsset maps each supported tier name to the path of its .yar file
// inside the release ZIP. Membership in this map is also what makes a tier
// name valid.
var forgeTierAsset = map[string]string{
"core": "packages/core/yara-rules-core.yar",
"extended": "packages/extended/yara-rules-extended.yar",
"full": "packages/full/yara-rules-full.yar",
}
// forgeAtomicWrite is the function used to install the rules file. It is a
// package variable rather than a direct call to atomicio.AtomicWrite --
// presumably so tests can substitute it; confirm.
var forgeAtomicWrite = atomicio.AtomicWrite
// ForgeUpdate checks for a new YARA Forge release and downloads it if newer.
// A detached signature is fetched from the ZIP URL + ".sig" and verified
// against the raw ZIP content before extraction.
//
// It delegates to ForgeUpdateFromURL with an empty downloadURL. As that
// function rejects an empty downloadURL once the tier and signing key have
// been validated, this entry point currently always returns an error; callers
// that want an update must use ForgeUpdateFromURL with a signed ZIP source.
func ForgeUpdate(rulesDir, tier, currentVersion, signingKey string, disabledRules []string) (newVersion string, ruleCount int, err error) {
return ForgeUpdateFromURL(rulesDir, tier, currentVersion, signingKey, "", disabledRules)
}
// ForgeUpdateFromURL is ForgeUpdate with an explicit signed ZIP source. The
// downloadURL may contain {tier} and {version}; the signature is fetched from
// the resolved ZIP URL plus ".sig".
//
// Steps, in order: validate tier, signing key and downloadURL; look up the
// latest release tag; return early with (currentVersion, 0, nil) when the tag
// equals currentVersion; download the ZIP and its signature; verify the
// signature over the raw ZIP bytes; extract the tier's .yar file; drop
// disabled rules; test-compile; install the file as
// rulesDir/yara-forge-<tier>.yar and remove the files of the other tiers.
//
// On success it returns the new tag and the number of rules counted after
// filtering. On any failure it returns ("", 0, err).
func ForgeUpdateFromURL(rulesDir, tier, currentVersion, signingKey, downloadURL string, disabledRules []string) (newVersion string, ruleCount int, err error) {
if _, ok := forgeTierAsset[tier]; !ok {
return "", 0, fmt.Errorf("unknown YARA Forge tier: %q (valid: core, extended, full)", tier)
}
if e := requireSigningKey(signingKey); e != nil {
return "", 0, e
}
if strings.TrimSpace(downloadURL) == "" {
return "", 0, fmt.Errorf("signatures.yara_forge.download_url is required: upstream YARA Forge does not publish CSM detached signatures")
}
latestTag, err := forgeLatestTag()
if err != nil {
return "", 0, fmt.Errorf("checking YARA Forge release: %w", err)
}
// Already on the latest release: nothing is downloaded and ruleCount is 0.
if latestTag == currentVersion {
return currentVersion, 0, nil
}
zipURL := forgeDownloadURL(downloadURL, tier, latestTag)
zipData, err := forgeDownload(zipURL)
if err != nil {
return "", 0, fmt.Errorf("downloading YARA Forge %s: %w", tier, err)
}
// The signature covers the ZIP exactly as downloaded; it is checked before
// the archive is opened.
sig, err := fetchSignature(zipURL + ".sig")
if err != nil {
return "", 0, fmt.Errorf("YARA Forge signature verification required but failed: %w", err)
}
if e := VerifySignature(signingKey, zipData, sig); e != nil {
return "", 0, fmt.Errorf("YARA Forge signature invalid: %w", e)
}
assetPath := forgeTierAsset[tier]
yarContent, err := forgeExtractYar(zipData, assetPath)
if err != nil {
return "", 0, fmt.Errorf("extracting YARA Forge rules: %w", err)
}
if len(disabledRules) > 0 {
yarContent = filterDisabledRules(yarContent, disabledRules)
}
ruleCount = countRules(yarContent)
outFile := filepath.Join(rulesDir, fmt.Sprintf("yara-forge-%s.yar", tier))
// Record whether this tier's file is already installed; it decides the
// install/cleanup order below.
outFileExisted := false
if _, err := os.Stat(outFile); err == nil {
outFileExisted = true
} else if !os.IsNotExist(err) {
return "", 0, fmt.Errorf("checking existing Forge tier: %w", err)
}
// Compile the filtered rules before touching anything on disk.
if err := yara.TestCompile(string(yarContent)); err != nil {
return "", 0, fmt.Errorf("YARA compilation test failed (keeping existing rules): %w", err)
}
if err := os.MkdirAll(rulesDir, 0700); err != nil {
return "", 0, fmt.Errorf("creating rules dir: %w", err)
}
// Tier file already present: remove the other tiers first, then overwrite.
// NOTE(review): the reason the two branches use opposite orders is not
// evident from this function -- confirm before changing it.
if outFileExisted {
if err := removeInactiveForgeTiers(rulesDir, tier); err != nil {
return "", 0, err
}
if err := forgeAtomicWrite(outFile, 0600, yarContent); err != nil {
return "", 0, fmt.Errorf("installing rules: %w", err)
}
return latestTag, ruleCount, nil
}
// Tier file not yet present: write it first, then remove the other tiers.
if err := forgeAtomicWrite(outFile, 0600, yarContent); err != nil {
return "", 0, fmt.Errorf("installing rules: %w", err)
}
if err := removeInactiveForgeTiers(rulesDir, tier); err != nil {
return "", 0, err
}
return latestTag, ruleCount, nil
}
// removeInactiveForgeTiers deletes the yara-forge-<tier>.yar file of every
// known tier other than activeTier from rulesDir. A file that is already
// absent is not an error; any other removal failure aborts and is returned.
func removeInactiveForgeTiers(rulesDir, activeTier string) error {
	for tier := range forgeTierAsset {
		if tier != activeTier {
			stale := filepath.Join(rulesDir, fmt.Sprintf("yara-forge-%s.yar", tier))
			err := os.Remove(stale)
			if err == nil || os.IsNotExist(err) {
				continue
			}
			return fmt.Errorf("removing inactive Forge tier %s: %w", tier, err)
		}
	}
	return nil
}
// forgeDownloadURL resolves a download URL template: surrounding whitespace is
// trimmed, then every "{tier}" is replaced by tier and, after that, every
// "{version}" by version.
func forgeDownloadURL(tmpl, tier, version string) string {
	resolved := strings.ReplaceAll(strings.TrimSpace(tmpl), "{tier}", tier)
	return strings.ReplaceAll(resolved, "{version}", version)
}
// forgeLatestTag asks forgeReleasesURL for the latest release and returns its
// tag_name. It fails on a transport error, a non-200 status, a body that does
// not decode as JSON (at most 1 MiB is read), or an empty tag.
func forgeLatestTag() (string, error) {
	httpClient := &http.Client{Timeout: forgeHTTPTimeout}
	resp, err := httpClient.Get(forgeReleasesURL)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return "", fmt.Errorf("GitHub API returned %d", code)
	}

	var payload struct {
		TagName string `json:"tag_name"`
	}
	body := io.LimitReader(resp.Body, 1024*1024)
	if decodeErr := json.NewDecoder(body).Decode(&payload); decodeErr != nil {
		return "", fmt.Errorf("parsing release JSON: %w", decodeErr)
	}
	if tag := payload.TagName; tag != "" {
		return tag, nil
	}
	return "", fmt.Errorf("empty tag_name in release")
}
// forgeDownload fetches url and returns the response body, which must be at
// most forgeMaxZIPSize bytes.
//
// It fails on a transport error, a non-200 status, a read error, or a body
// larger than the cap. One byte beyond the cap is requested so an oversize
// response is reported as such. Silently truncating it would only surface
// later as a misleading signature failure on the cut-off ZIP.
func forgeDownload(url string) ([]byte, error) {
	client := &http.Client{Timeout: forgeHTTPTimeout}
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("download returned %d", resp.StatusCode)
	}
	data, err := io.ReadAll(io.LimitReader(resp.Body, forgeMaxZIPSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}
	if len(data) > forgeMaxZIPSize {
		return nil, fmt.Errorf("download exceeds size limit of %d bytes", forgeMaxZIPSize)
	}
	return data, nil
}
// forgeExtractYar opens zipData as a ZIP archive and returns the decompressed
// content of the first entry named exactly assetPath. It fails when the
// archive cannot be opened, the entry is missing or unreadable, or the
// decompressed content is larger than forgeMaxYarSize.
func forgeExtractYar(zipData []byte, assetPath string) ([]byte, error) {
	archive, err := zip.NewReader(bytes.NewReader(zipData), int64(len(zipData)))
	if err != nil {
		return nil, fmt.Errorf("opening ZIP: %w", err)
	}

	var entry *zip.File
	for _, candidate := range archive.File {
		if candidate.Name == assetPath {
			entry = candidate
			break
		}
	}
	if entry == nil {
		return nil, fmt.Errorf("asset %s not found in ZIP", assetPath)
	}

	rc, err := entry.Open()
	if err != nil {
		return nil, fmt.Errorf("opening %s in ZIP: %w", assetPath, err)
	}
	defer rc.Close()

	// Read one byte past the cap so an oversize entry is detected, not truncated.
	payload, err := io.ReadAll(io.LimitReader(rc, forgeMaxYarSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading %s: %w", assetPath, err)
	}
	if len(payload) > forgeMaxYarSize {
		return nil, fmt.Errorf("%s exceeds decompressed size limit of %d bytes", assetPath, forgeMaxYarSize)
	}
	return payload, nil
}
// filterDisabledRules returns content with every rule whose name appears in
// disabled removed. With no disabled names the input is returned unchanged.
//
// Processing is line-based: a line that declares a disabled rule starts a
// skip, and the skip ends on the line where the rule's braces balance again.
// The skip only ends once the opening brace has been seen, so a header
// followed by "{" on the next line removes the whole rule, not just the
// header line.
//
// Limitation: braces are counted textually, so an unbalanced brace inside a
// string or regex literal of a disabled rule can end the skip at the wrong
// line. The filtered output should be test-compiled before it is installed.
func filterDisabledRules(content []byte, disabled []string) []byte {
	if len(disabled) == 0 {
		return content
	}
	disabledSet := make(map[string]struct{}, len(disabled))
	for _, name := range disabled {
		disabledSet[name] = struct{}{}
	}

	lines := strings.Split(string(content), "\n")
	kept := make([]string, 0, len(lines))
	skipping := false // inside a disabled rule
	opened := false   // the disabled rule's opening brace has been seen
	depth := 0        // current brace nesting inside the disabled rule
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if !skipping {
			name := extractRuleName(trimmed)
			if _, off := disabledSet[name]; name == "" || !off {
				kept = append(kept, line)
				continue
			}
			skipping, opened, depth = true, false, 0
		}
		for _, ch := range trimmed {
			switch ch {
			case '{':
				depth++
				opened = true
			case '}':
				depth--
			}
		}
		if opened && depth <= 0 {
			skipping = false
		}
	}
	return []byte(strings.Join(kept, "\n"))
}

// extractRuleName returns the rule name declared on line, or "" when line is
// not a rule declaration. The "private" and "global" modifiers are accepted in
// any order before the "rule" keyword, words may be separated by spaces or
// tabs, and the name ends at the first space, tab, ':' or '{'.
func extractRuleName(line string) string {
	rest := strings.TrimSpace(line)
	for {
		end := strings.IndexAny(rest, " \t")
		if end < 0 {
			// A lone word (including a bare "rule") declares nothing.
			return ""
		}
		word := rest[:end]
		rest = strings.TrimLeft(rest[end:], " \t")
		switch word {
		case "private", "global":
			continue
		case "rule":
			if i := strings.IndexAny(rest, " \t:{"); i >= 0 {
				return rest[:i]
			}
			return rest
		default:
			return ""
		}
	}
}
// countRules returns how many lines of content declare a rule, i.e. how many
// whitespace-trimmed lines extractRuleName yields a non-empty name for.
func countRules(content []byte) int {
	total := 0
	for _, raw := range strings.Split(string(content), "\n") {
		if name := extractRuleName(strings.TrimSpace(raw)); name != "" {
			total++
		}
	}
	return total
}
package signatures
import "sync"
var (
// globalScanner is the process-wide scanner created by Init; nil until then.
globalScanner *Scanner
// globalOnce ensures Init constructs globalScanner at most once.
globalOnce sync.Once
)
// Init initializes the global scanner with rules from the given directory.
// Safe to call multiple times - only the first call takes effect.
// Call Reload() on the returned scanner to reload rules (e.g., on SIGHUP).
// On later calls rulesDir is ignored and the scanner built by the first call
// is returned.
func Init(rulesDir string) *Scanner {
globalOnce.Do(func() {
globalScanner = NewScanner(rulesDir)
})
return globalScanner
}
// Global returns the global scanner, or nil if Init() hasn't been called.
// NOTE(review): this reads globalScanner without going through globalOnce or
// any lock; presumably callers only use it after Init has returned -- confirm
// there is no concurrent first Init.
func Global() *Scanner {
return globalScanner
}
package signatures
import (
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"gopkg.in/yaml.v3"
)
// Rule represents a single malware detection rule loaded from an external file.
// The exported fields are decoded from YAML; the unexported ones are derived
// by compile and are not part of the file format.
type Rule struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
Severity string `yaml:"severity"` // "critical", "high", "warning"
Category string `yaml:"category"` // "webshell", "backdoor", "phishing", "dropper", "exploit"
// An empty FileTypes list applies the rule to every extension.
FileTypes []string `yaml:"file_types"` // [".php", ".html", "*"] - which extensions to scan
Patterns []string `yaml:"patterns"` // literal string patterns (case-insensitive match)
// Regexes and ExcludeRegexes are compiled with a "(?i)" prefix.
Regexes []string `yaml:"regexes"` // regex patterns (for complex matching)
ExcludePatterns []string `yaml:"exclude_patterns"` // if any match, rule is skipped (false positive reduction)
ExcludeRegexes []string `yaml:"exclude_regexes"` // regex exclusions
// MinMatch counts matched literal patterns and matched regexes together.
MinMatch int `yaml:"min_match"` // minimum patterns that must match (default: 1)
RequireRegex bool `yaml:"require_regex"` // if true, at least one regex must match in addition to min_match
// Compiled regexes (populated by Compile())
compiledRegexes []*regexp.Regexp
compiledExcludeRegexes []*regexp.Regexp
}
// RuleFile is the top-level structure of a rules YAML file.
type RuleFile struct {
// Version is the file's version number; the scanner reports the highest one
// across all loaded files.
Version int `yaml:"version"`
// Updated is decoded from the file but not interpreted by the code shown here.
Updated string `yaml:"updated"`
// Rules are the detection rules defined in the file.
Rules []Rule `yaml:"rules"`
}
// Scanner holds compiled rules and provides file scanning.
type Scanner struct {
// mu guards rules and version: Reload swaps them under the write lock and
// the scan/accessor methods read them under the read lock.
mu sync.RWMutex
// rules is the compiled rule set from the last successful Reload.
rules []Rule
// version is the highest RuleFile.Version seen in the last successful Reload.
version int
// rulesDir is the directory Reload reads; an empty value disables loading.
rulesDir string
}
// NewScanner creates a scanner that loads rules from the given directory.
// Returns a scanner with no rules if the directory doesn't exist (not an error).
// Any error from the initial Reload is discarded, so a directory with
// unreadable or invalid rule files also yields a scanner with no rules.
func NewScanner(rulesDir string) *Scanner {
s := &Scanner{rulesDir: rulesDir}
_ = s.Reload() // best-effort load on init
return s
}
// Reload loads/reloads all .yml and .yaml rule files from the rules directory.
//
// The new rule set replaces the old one only when every file reads, parses and
// compiles; on any error the previously loaded rules stay in effect. An empty
// rulesDir or a missing directory is not an error and leaves the scanner
// untouched. A directory without rule files is an error only when rules were
// loaded before. Rule files that together define no rules are an error.
//
// A min_match below 1 (unset, zero or negative) is normalised to 1. A negative
// value would otherwise make the rule match every file of its type, because a
// match count of zero already satisfies the threshold.
func (s *Scanner) Reload() error {
	if s.rulesDir == "" {
		return nil
	}
	entries, err := os.ReadDir(s.rulesDir)
	if err != nil {
		if os.IsNotExist(err) {
			return nil // no rules dir = no rules, not an error
		}
		return fmt.Errorf("reading rules dir %s: %w", s.rulesDir, err)
	}

	var (
		allRules   []Rule
		maxVersion int
		fileCount  int
	)
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		switch strings.ToLower(filepath.Ext(name)) {
		case ".yml", ".yaml":
		default:
			continue
		}
		fileCount++
		path := filepath.Join(s.rulesDir, name)
		// #nosec G304 -- filepath.Join under operator-configured rulesDir.
		data, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("reading %s: %w", path, err)
		}
		var rf RuleFile
		if err := yaml.Unmarshal(data, &rf); err != nil {
			return fmt.Errorf("parsing %s: %w", path, err)
		}
		if rf.Version > maxVersion {
			maxVersion = rf.Version
		}
		for i := range rf.Rules {
			rule := &rf.Rules[i]
			if err := rule.compile(); err != nil {
				return fmt.Errorf("compiling rule %q in %s: %w", rule.Name, path, err)
			}
			if rule.MinMatch < 1 {
				rule.MinMatch = 1
			}
			allRules = append(allRules, *rule)
		}
	}

	if fileCount == 0 {
		s.mu.RLock()
		hadRules := len(s.rules) > 0
		s.mu.RUnlock()
		if hadRules {
			return fmt.Errorf("no signature rule files found in %s", s.rulesDir)
		}
		return nil
	}
	if len(allRules) == 0 {
		return fmt.Errorf("no signature rules loaded from %s", s.rulesDir)
	}

	s.mu.Lock()
	s.rules = allRules
	s.version = maxVersion
	s.mu.Unlock()

	fmt.Fprintf(os.Stderr, "signatures: loaded %d rules (version %d) from %s\n", len(allRules), maxVersion, s.rulesDir)
	return nil
}
// compile (re)builds the rule's compiled regex lists from Regexes and
// ExcludeRegexes. Every pattern is compiled case-insensitively. The first
// pattern that fails to compile aborts with an error naming it.
func (r *Rule) compile() error {
	const caseInsensitive = "(?i)"

	r.compiledRegexes = nil
	for i := range r.Regexes {
		expr := r.Regexes[i]
		compiled, err := regexp.Compile(caseInsensitive + expr)
		if err != nil {
			return fmt.Errorf("invalid regex '%s': %w", expr, err)
		}
		r.compiledRegexes = append(r.compiledRegexes, compiled)
	}

	r.compiledExcludeRegexes = nil
	for i := range r.ExcludeRegexes {
		expr := r.ExcludeRegexes[i]
		compiled, err := regexp.Compile(caseInsensitive + expr)
		if err != nil {
			return fmt.Errorf("invalid exclude regex '%s': %w", expr, err)
		}
		r.compiledExcludeRegexes = append(r.compiledExcludeRegexes, compiled)
	}
	return nil
}
// Match represents a rule that matched a file.
// RuleName, Description, Severity and Category are copied from the rule.
type Match struct {
RuleName string
Description string
Severity string
Category string
// Matched lists the literal patterns that hit, followed by the source text
// of each regex that hit.
Matched []string // which patterns matched
}
// ScanContent evaluates content against every loaded rule that applies to
// fileExt (which should include the dot, e.g. ".php") and returns one Match
// per rule that fires, or nil when no rules are loaded.
//
// A rule is skipped when any exclude pattern or exclude regex matches. It
// fires when the number of matching literal patterns plus matching regexes
// reaches MinMatch and, if RequireRegex is set, at least one regex matched.
// Literal patterns are compared case-insensitively.
func (s *Scanner) ScanContent(content []byte, fileExt string) []Match {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if len(s.rules) == 0 {
		return nil
	}

	lowered := strings.ToLower(string(content))
	ext := strings.ToLower(fileExt)
	containsFold := func(pattern string) bool {
		return strings.Contains(lowered, strings.ToLower(pattern))
	}

	var found []Match
rules:
	for i := range s.rules {
		rule := &s.rules[i]
		if !ruleMatchesExt(*rule, ext) {
			continue
		}

		// Exclusions win over everything else: literal ones first, then regexes.
		for _, pattern := range rule.ExcludePatterns {
			if containsFold(pattern) {
				continue rules
			}
		}
		for _, re := range rule.compiledExcludeRegexes {
			if re.Match(content) {
				continue rules
			}
		}

		var hits []string
		regexHit := false
		for _, pattern := range rule.Patterns {
			if containsFold(pattern) {
				hits = append(hits, pattern)
			}
		}
		for _, re := range rule.compiledRegexes {
			if re.Match(content) {
				hits = append(hits, re.String())
				regexHit = true
			}
		}

		if len(hits) < rule.MinMatch || (rule.RequireRegex && !regexHit) {
			continue
		}
		found = append(found, Match{
			RuleName:    rule.Name,
			Description: rule.Description,
			Severity:    rule.Severity,
			Category:    rule.Category,
			Matched:     hits,
		})
	}
	return found
}
// ScanFile reads up to maxBytes from the file at path and scans that content
// against the loaded rules, using the path's extension as the file type. It
// returns nil when no rules are loaded, maxBytes is not positive, the file
// cannot be opened or read, or the file is empty.
func (s *Scanner) ScanFile(path string, maxBytes int) []Match {
	s.mu.RLock()
	loaded := len(s.rules)
	s.mu.RUnlock()
	if loaded == 0 {
		return nil
	}
	if maxBytes <= 0 {
		return nil
	}

	// #nosec G304 -- ScanFile's whole purpose is to scan a file on disk;
	// `path` comes from the daemon's file index walker or a fanotify event.
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil
	}
	defer func() { _ = file.Close() }()

	// io.ReadAll keeps reading until the limit or EOF, so a short first read
	// cannot hand the scanner a mere prefix of the file. The LimitReader also
	// avoids allocating maxBytes up front.
	data, readErr := io.ReadAll(io.LimitReader(file, int64(maxBytes)))
	if readErr != nil || len(data) == 0 {
		return nil
	}
	return s.ScanContent(data, filepath.Ext(path))
}
// RuleCount returns the number of loaded rules.
// The count is read under the scanner's read lock.
func (s *Scanner) RuleCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.rules)
}
// Version returns the highest version number across loaded rule files.
// It is 0 until a Reload has succeeded; the value is read under the scanner's
// read lock.
func (s *Scanner) Version() int {
s.mu.RLock()
defer s.mu.RUnlock()
return s.version
}
// ruleMatchesExt reports whether rule applies to files with extension ext.
// ext is expected to be lower-cased already. A rule without FileTypes applies
// to everything, as does one listing "*"; otherwise one of its file types must
// equal ext after lower-casing.
func ruleMatchesExt(rule Rule, ext string) bool {
	types := rule.FileTypes
	if len(types) == 0 {
		return true // no filter = match all
	}
	for i := range types {
		switch {
		case types[i] == "*", strings.ToLower(types[i]) == ext:
			return true
		}
	}
	return false
}
package signatures
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"gopkg.in/yaml.v3"
"github.com/pidginhost/csm/internal/atomicio"
)
// Update downloads the rules file at url, verifies it, and installs it as
// rulesDir/malware.yml.
//
// A signing key is mandatory. The detached ed25519 signature is fetched from
// url+".sig" and checked over the downloaded bytes before they are parsed.
// The file must then decode as a RuleFile with at least one rule, and every
// rule must compile. Only after that is it written atomically.
//
// Returns the number of rules in the installed file, or an error.
func Update(rulesDir, url, signingKey string) (int, error) {
	if url == "" {
		return 0, fmt.Errorf("no update URL configured (set signatures.update_url in csm.yaml)")
	}
	if keyErr := requireSigningKey(signingKey); keyErr != nil {
		return 0, keyErr
	}

	// Fetch the rules file.
	httpClient := &http.Client{Timeout: 30 * time.Second}
	resp, err := httpClient.Get(url)
	if err != nil {
		return 0, fmt.Errorf("downloading rules from %s: %w", url, err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("download failed: HTTP %d from %s", resp.StatusCode, url)
	}
	const maxRulesBytes = 10 * 1024 * 1024 // max 10MB
	payload, err := io.ReadAll(io.LimitReader(resp.Body, maxRulesBytes))
	if err != nil {
		return 0, fmt.Errorf("reading response: %w", err)
	}

	// Authenticate the raw bytes before interpreting them.
	signature, err := fetchSignature(url + ".sig")
	if err != nil {
		return 0, fmt.Errorf("signature verification required but failed: %w", err)
	}
	if verifyErr := VerifySignature(signingKey, payload, signature); verifyErr != nil {
		return 0, fmt.Errorf("rules signature invalid: %w", verifyErr)
	}

	// Validate: must parse as YAML rules, contain rules, and every rule must compile.
	var parsed RuleFile
	if parseErr := yaml.Unmarshal(payload, &parsed); parseErr != nil {
		return 0, fmt.Errorf("invalid rules file: %w", parseErr)
	}
	ruleCount := len(parsed.Rules)
	if ruleCount == 0 {
		return 0, fmt.Errorf("rules file contains no rules")
	}
	for i := range parsed.Rules {
		candidate := parsed.Rules[i] // compile a copy; this pass only validates
		if compileErr := candidate.compile(); compileErr != nil {
			return 0, fmt.Errorf("rule '%s' failed validation: %w", candidate.Name, compileErr)
		}
	}

	if mkdirErr := os.MkdirAll(rulesDir, 0700); mkdirErr != nil {
		return 0, fmt.Errorf("creating rules dir: %w", mkdirErr)
	}

	// Atomic write: the daemon reloads these rules on the next tick, so a torn
	// write must never be observable.
	target := filepath.Join(rulesDir, "malware.yml")
	if writeErr := atomicio.AtomicWrite(target, 0600, payload); writeErr != nil {
		return 0, fmt.Errorf("installing rules: %w", writeErr)
	}
	return ruleCount, nil
}
package signatures
import (
"crypto/ed25519"
"encoding/hex"
"fmt"
"io"
"net/http"
"time"
)
// requireSigningKey returns an error when signingKey is empty and nil
// otherwise. It does not check that the key is well-formed.
func requireSigningKey(signingKey string) error {
	if len(signingKey) > 0 {
		return nil
	}
	return fmt.Errorf("signatures.signing_key is required for remote rule updates")
}
// VerifySignature checks that signature is a valid ed25519 signature over
// data under the hex-encoded public key pubKeyHex. It returns nil on success
// and a descriptive error when the key is not valid hex, the key or signature
// has the wrong length, or the signature does not verify.
func VerifySignature(pubKeyHex string, data, signature []byte) error {
	key, err := hex.DecodeString(pubKeyHex)
	switch {
	case err != nil:
		return fmt.Errorf("invalid signing key (bad hex): %w", err)
	case len(key) != ed25519.PublicKeySize:
		return fmt.Errorf("invalid signing key length: got %d bytes, want %d", len(key), ed25519.PublicKeySize)
	case len(signature) != ed25519.SignatureSize:
		return fmt.Errorf("invalid signature length: got %d bytes, want %d", len(signature), ed25519.SignatureSize)
	case !ed25519.Verify(ed25519.PublicKey(key), data, signature):
		return fmt.Errorf("signature verification failed")
	}
	return nil
}
// fetchSignature downloads a detached signature from sigURL; callers pass the
// artifact URL with ".sig" already appended. At most 1024 bytes of the body
// are read. It fails on a transport error, a non-200 status, or a read error;
// the signature length itself is not checked here.
func fetchSignature(sigURL string) ([]byte, error) {
	httpClient := &http.Client{Timeout: 10 * time.Second}
	resp, getErr := httpClient.Get(sigURL)
	if getErr != nil {
		return nil, fmt.Errorf("downloading signature from %s: %w", sigURL, getErr)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("signature download returned HTTP %d from %s", code, sigURL)
	}

	const maxSigBytes = 1024 // ed25519 sig is 64 bytes
	raw, readErr := io.ReadAll(io.LimitReader(resp.Body, maxSigBytes))
	if readErr != nil {
		return nil, fmt.Errorf("reading signature: %w", readErr)
	}
	return raw, nil
}
package state
import (
"fmt"
"os"
"path/filepath"
"syscall"
)
// LockFile provides file-based locking to prevent concurrent CSM runs.
// A value is obtained from AcquireLock and given up with Release.
type LockFile struct {
// path is the lock file's location; Release removes it.
path string
// file is the open handle the flock is held on.
file *os.File
}
// AcquireLock takes a non-blocking exclusive flock on stateDir/csm.lock,
// creating the file (mode 0600) if needed. It returns an error when the file
// cannot be opened or when the lock cannot be taken; the latter is reported as
// another instance already running. On success the current PID is written into
// the file as a debugging aid (best effort).
func AcquireLock(stateDir string) (*LockFile, error) {
	lockPath := filepath.Join(stateDir, "csm.lock")
	// #nosec G304 -- filepath.Join under operator-configured stateDir.
	handle, openErr := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600)
	if openErr != nil {
		return nil, fmt.Errorf("opening lock file: %w", openErr)
	}

	// #nosec G115 -- os.File.Fd returns uintptr but POSIX file descriptors
	// are small non-negative ints (rlimit ~1024); int conversion is lossless.
	fd := int(handle.Fd())
	if lockErr := syscall.Flock(fd, syscall.LOCK_EX|syscall.LOCK_NB); lockErr != nil {
		_ = handle.Close()
		return nil, fmt.Errorf("another CSM instance is already running")
	}

	// Replace whatever a previous holder left behind with our PID.
	_ = handle.Truncate(0)
	fmt.Fprintf(handle, "%d\n", os.Getpid())
	return &LockFile{path: lockPath, file: handle}, nil
}
// Release releases the lock: it removes the lock file, drops the flock and
// closes the handle. All three steps are best effort.
//
// Release is idempotent and safe on a nil receiver. Without that guard a
// second call would remove the lock file by path again, although by then it
// may belong to another instance that has since acquired the lock.
//
// The file is unlinked while the lock is still held. A competing instance
// therefore either fails on the still-locked old file or creates a fresh one.
// It cannot lock the old file and then have it removed underneath it.
func (l *LockFile) Release() {
	if l == nil || l.file == nil {
		return
	}
	f := l.file
	l.file = nil

	_ = os.Remove(l.path)
	// #nosec G115 -- see AcquireLock: POSIX fd fits in int.
	_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
	_ = f.Close()
}
package state
import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
)
// MigrateStateDir copies the contents of oldPath into newPath when all of the
// following hold:
//
//   - both paths are non-empty and differ,
//   - oldPath exists, is a directory, and has at least one entry,
//   - newPath does not exist or has no entries.
//
// It returns (true, nil) after a copy and (false, nil) when no migration was
// needed. Any other I/O error is returned, and the caller (daemon startup)
// should abort rather than start with mixed state. The old directory is left
// in place.
func MigrateStateDir(oldPath, newPath string) (bool, error) {
	if oldPath == "" || newPath == "" || oldPath == newPath {
		return false, nil
	}

	oldInfo, statErr := os.Stat(oldPath)
	switch {
	case statErr == nil:
		// fall through to the directory checks below
	case errors.Is(statErr, fs.ErrNotExist):
		return false, nil
	default:
		return false, fmt.Errorf("stat %s: %w", oldPath, statErr)
	}
	if !oldInfo.IsDir() {
		return false, nil
	}

	toCopy, err := os.ReadDir(oldPath)
	if err != nil {
		return false, fmt.Errorf("reading %s: %w", oldPath, err)
	}
	if len(toCopy) == 0 {
		return false, nil
	}

	existing, readErr := os.ReadDir(newPath)
	switch {
	case readErr == nil:
		if len(existing) > 0 {
			return false, nil // new dir non-empty: no migration
		}
	case !errors.Is(readErr, fs.ErrNotExist):
		return false, fmt.Errorf("reading %s: %w", newPath, readErr)
	}

	if err := os.MkdirAll(newPath, 0o700); err != nil {
		return false, fmt.Errorf("mkdir %s: %w", newPath, err)
	}
	for _, entry := range toCopy {
		name := entry.Name()
		from := filepath.Join(oldPath, name)
		if err := copyEntry(from, filepath.Join(newPath, name)); err != nil {
			return false, fmt.Errorf("copying %s: %w", from, err)
		}
	}
	return true, nil
}
// copyEntry copies src to dst. A directory is recreated with the source's
// permission bits and its entries are copied recursively. Anything else is
// copied as a file: dst is created or truncated with the source's permission
// bits, the content is streamed across, and the file is synced before it is
// closed. The first error encountered is returned.
func copyEntry(src, dst string) error {
	info, err := os.Stat(src)
	if err != nil {
		return err
	}
	perm := info.Mode().Perm()

	if info.IsDir() {
		if err := os.MkdirAll(dst, perm); err != nil {
			return err
		}
		children, err := os.ReadDir(src)
		if err != nil {
			return err
		}
		for _, child := range children {
			name := child.Name()
			if err := copyEntry(filepath.Join(src, name), filepath.Join(dst, name)); err != nil {
				return err
			}
		}
		return nil
	}

	// #nosec G304 -- src derived from os.ReadDir of a trusted state directory.
	reader, err := os.Open(src)
	if err != nil {
		return err
	}
	defer reader.Close()

	// #nosec G304 -- dst is new state directory plus an entry name from os.ReadDir.
	writer, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	_, err = io.Copy(writer, reader)
	if err == nil {
		err = writer.Sync()
	}
	if err != nil {
		_ = writer.Close()
		return err
	}
	return writer.Close()
}
package state
import (
"bytes"
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/atomicio"
"github.com/pidginhost/csm/internal/metrics"
"github.com/pidginhost/csm/internal/store"
)
// findingsTotal counts every finding CSM records, partitioned by
// severity. Registered lazily (via sync.Once) the first time a finding
// lands so tests that open Stores without running a full daemon do not
// panic on duplicate registration.
var (
// findingsTotal is nil until ensureFindingsMetric has run.
findingsTotal *metrics.CounterVec
// findingsTotalOnce guards the one-time creation and registration.
findingsTotalOnce sync.Once
)
// ensureFindingsMetric creates the csm_findings_total counter vector (one
// "severity" label) and registers it, at most once per process. Later calls
// are no-ops.
func ensureFindingsMetric() {
findingsTotalOnce.Do(func() {
findingsTotal = metrics.NewCounterVec(
"csm_findings_total",
"Findings recorded by CSM, partitioned by severity.",
[]string{"severity"},
)
metrics.MustRegister("csm_findings_total", findingsTotal)
})
}
// recordFindings bumps the per-severity findings counter once for each
// finding in the batch. An empty batch returns before the metric is even
// registered.
func recordFindings(findings []alert.Finding) {
	if len(findings) == 0 {
		return
	}
	ensureFindingsMetric()
	for i := range findings {
		label := findings[i].Severity.String()
		findingsTotal.With(label).Inc()
	}
}
// Store holds the alert state for one state directory: the dedup entries
// persisted to state.json and the most recent scan results persisted to
// latest_findings.json. The two halves are protected by separate mutexes.
type Store struct {
// mu is taken by the methods that read or modify entries and dirty.
mu sync.RWMutex
// path is the state directory passed to Open.
path string
// entries maps a finding key (or an internal "_"-prefixed key) to its state.
entries map[string]*Entry
dirty bool // true if state changed since last save
savedHash string // hash of last saved state
// LatestFindings holds the full output of the most recent scan cycle.
// This is what the Findings page shows - "what's wrong right now" -
// separate from the alert dedup state above which controls "what to email."
latestMu sync.RWMutex
latestFindings []alert.Finding
// latestScanTime is set by SetLatestFindings and PurgeAndMergeFindings.
latestScanTime time.Time
}
// Entry is the persisted per-key state record stored in state.json.
type Entry struct {
// Hash is the finding fingerprint, or the raw value for keys written via SetRaw.
Hash string `json:"hash"`
// FirstSeen is when the entry was created.
FirstSeen time.Time `json:"first_seen"`
// LastSeen is refreshed by Update each time the finding is observed again.
LastSeen time.Time `json:"last_seen"`
// AlertSent starts the 24-hour re-alert window checked by FilterNew.
AlertSent time.Time `json:"alert_sent"`
// IsBaseline marks entries recorded by SetBaseline or dismissed via
// DismissFinding; FilterNew skips them while their hash is unchanged.
IsBaseline bool `json:"is_baseline"`
}
// Open creates the state directory if needed (mode 0700) and loads
// state.json and latest_findings.json from it. Each file that can be read is
// first copied to "<name>.bak"; a file that fails to parse only produces a
// warning on stderr and the store starts with whatever could be decoded.
// The only error returned is a failure to create the directory.
func Open(path string) (*Store, error) {
	if err := os.MkdirAll(path, 0700); err != nil {
		return nil, fmt.Errorf("creating state dir: %w", err)
	}
	s := &Store{
		path:    path,
		entries: make(map[string]*Entry),
	}

	stateFile := filepath.Join(path, "state.json")
	// #nosec G304 -- operator-configured statePath + fixed filename.
	data, err := os.ReadFile(stateFile)
	if err == nil {
		// Backup state file before loading in case of corruption
		// #nosec G703 -- stateFile is filepath.Join(path, "state.json") where
		// path is the operator-configured statePath from csm.yaml.
		_ = os.WriteFile(stateFile+".bak", data, 0600)
		if unmarshalErr := json.Unmarshal(data, &s.entries); unmarshalErr != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to parse %s: %v (backup saved to %s.bak)\n", stateFile, unmarshalErr, stateFile)
		}
		// A file containing JSON null decodes to a nil map, and a null
		// value decodes to a nil *Entry. Either would panic on the first
		// write or dereference, so normalize both here.
		if s.entries == nil {
			s.entries = make(map[string]*Entry)
		}
		for key, entry := range s.entries {
			if entry == nil {
				delete(s.entries, key)
			}
		}
	}

	// Load latest findings from disk (survives restart)
	latestFile := filepath.Join(path, "latest_findings.json")
	// #nosec G304 -- operator-configured statePath + fixed filename.
	if latestData, err := os.ReadFile(latestFile); err == nil {
		// Backup latest findings before loading
		// #nosec G703 -- latestFile derived the same way as stateFile above.
		_ = os.WriteFile(latestFile+".bak", latestData, 0600)
		var findings []alert.Finding
		if unmarshalErr := json.Unmarshal(latestData, &findings); unmarshalErr != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to parse %s: %v (backup saved to %s.bak)\n", latestFile, unmarshalErr, latestFile)
		} else {
			s.latestFindings = findings
		}
	}
	return s, nil
}
// Close flushes pending entry changes to state.json. It is a no-op when
// nothing changed since the last save. The write lock is held for the
// duration, matching every other caller of save, so a goroutine still
// mutating entries during shutdown cannot race the JSON encoder.
func (s *Store) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.dirty {
		return nil
	}
	return s.save()
}
// save persists entries to state.json when their encoded form differs from
// the last successful write. The caller is expected to hold s.mu. dirty is
// cleared on success (including the unchanged case) and left untouched when
// the write fails.
func (s *Store) save() error {
	encoded, err := json.MarshalIndent(s.entries, "", " ")
	if err != nil {
		return err
	}

	digest := sha256.Sum256(encoded)
	fingerprint := fmt.Sprintf("%x", digest)
	if fingerprint != s.savedHash {
		// Atomic write with fsync. Power loss between rename and a dir
		// fsync can otherwise leave the file truncated; the hash gate
		// would then skip the next attempt because the in-memory hash is
		// unchanged, so corruption would persist across restarts.
		target := filepath.Join(s.path, "state.json")
		if err := atomicio.AtomicWriteJSON(target, 0o600, s.entries); err != nil {
			return err
		}
		s.savedHash = fingerprint
	}
	s.dirty = false
	return nil
}
// findingKey returns the map key under which f is tracked in Store.entries;
// it delegates to alert.Finding.Key.
func findingKey(f alert.Finding) string {
return f.Key()
}
// findingHash returns the value stored in Entry.Hash for f; it delegates to
// alert.Finding.Fingerprint. A changed hash under the same key makes
// FilterNew treat the finding as new.
func findingHash(f alert.Finding) string {
return f.Fingerprint()
}
// FilterNew returns the subset of findings that should be alerted on:
//   - findings with no state entry,
//   - findings whose fingerprint differs from the stored one,
//   - non-baseline findings whose last alert is more than 24 hours old.
//
// Baseline entries with an unchanged fingerprint are never returned. The
// store is only read, not modified.
func (s *Store) FilterNew(findings []alert.Finding) []alert.Finding {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var fresh []alert.Finding
	for _, finding := range findings {
		key, hash := findingKey(finding), findingHash(finding)
		known, tracked := s.entries[key]
		if !tracked {
			fresh = append(fresh, finding)
			continue
		}
		switch {
		case known.Hash != hash:
			// Same key, different content.
			fresh = append(fresh, finding)
		case known.IsBaseline:
			// Acknowledged and unchanged: stay quiet.
		case !known.AlertSent.IsZero() && time.Since(known.AlertSent) > 24*time.Hour:
			// Same finding, but the dedup window has expired.
			fresh = append(fresh, finding)
		}
	}
	return fresh
}
// Update records the findings of the current cycle: new keys get an entry
// stamped with the current time, existing keys get their hash and LastSeen
// refreshed (and AlertSent filled in if it was still zero). Entries that were
// not seen this cycle, are not baseline, and were last seen more than 24
// hours ago are evicted. The result is saved; a save failure is only logged.
func (s *Store) Update(findings []alert.Finding) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.dirty = true

	stamp := time.Now()
	present := make(map[string]bool)
	for _, finding := range findings {
		key, hash := findingKey(finding), findingHash(finding)
		present[key] = true
		if current, ok := s.entries[key]; ok {
			current.Hash = hash
			current.LastSeen = stamp
			if current.AlertSent.IsZero() {
				current.AlertSent = stamp
			}
			continue
		}
		s.entries[key] = &Entry{
			Hash:      hash,
			FirstSeen: stamp,
			LastSeen:  stamp,
			AlertSent: stamp,
		}
	}

	// Evict stale entries. Keys with a leading underscore are internal
	// housekeeping written via SetRaw (throttles, per-file content-hash
	// baselines, sentinel flags); they never appear in the findings stream,
	// so without this exemption they would always be evicted and silently
	// re-arm one-shot detectors that gate on their presence.
	for key, entry := range s.entries {
		if strings.HasPrefix(key, "_") || present[key] || entry.IsBaseline {
			continue
		}
		if time.Since(entry.LastSeen) > 24*time.Hour {
			delete(s.entries, key)
		}
	}

	if err := s.save(); err != nil {
		fmt.Fprintf(os.Stderr, "state: error saving after update: %v\n", err)
	}
}
// MarkAlerted restarts the 24-hour dedup window for every finding that has a
// state entry by setting its AlertSent to the current time. Call it after
// dispatch with the slice returned by FilterNew; otherwise a finding that has
// passed the 24-hour expiry branch in FilterNew is re-emitted on every tick,
// because Update leaves a non-zero AlertSent untouched. Findings without an
// entry are ignored, and nothing is saved when no entry was touched.
func (s *Store) MarkAlerted(findings []alert.Finding) {
	if len(findings) == 0 {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	stamp := time.Now()
	touched := 0
	for i := range findings {
		entry, ok := s.entries[findingKey(findings[i])]
		if !ok {
			continue
		}
		entry.AlertSent = stamp
		touched++
	}
	if touched == 0 {
		return
	}

	s.dirty = true
	if err := s.save(); err != nil {
		fmt.Fprintf(os.Stderr, "state: error saving after MarkAlerted: %v\n", err)
	}
}
// SetBaseline discards all existing entries and replaces them with one
// baseline entry per finding (AlertSent left zero). The only entry carried
// over is the baseline-timestamp metadata record, which is copied so the
// first-start time survives a re-baseline. The result is saved; a save
// failure is only logged.
func (s *Store) SetBaseline(findings []alert.Finding) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.dirty = true

	prior, hadMeta := s.entries[baselineAtMetaKey]
	var keptMeta *Entry
	if hadMeta {
		clone := *prior
		keptMeta = &clone
	}

	rebuilt := make(map[string]*Entry)
	stamp := time.Now()
	for _, finding := range findings {
		key := findingKey(finding)
		rebuilt[key] = &Entry{
			Hash:       findingHash(finding),
			FirstSeen:  stamp,
			LastSeen:   stamp,
			IsBaseline: true,
		}
	}
	if keptMeta != nil {
		rebuilt[baselineAtMetaKey] = keptMeta
	}
	s.entries = rebuilt

	if err := s.save(); err != nil {
		fmt.Fprintf(os.Stderr, "state: error saving baseline: %v\n", err)
	}
}
// ShouldRunThrottled reports whether the check named checkName may run now,
// allowing at most one run per intervalMin minutes. The first call for a
// name always returns true. Whenever it returns true the run time is
// recorded under the internal key "_throttle:<checkName>" and the store is
// marked dirty so the stamp is persisted by the next save or Close instead
// of being lost on restart.
func (s *Store) ShouldRunThrottled(checkName string, intervalMin int) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	key := fmt.Sprintf("_throttle:%s", checkName)
	entry, exists := s.entries[key]
	if !exists {
		s.entries[key] = &Entry{LastSeen: now}
		s.dirty = true
		return true
	}
	if now.Sub(entry.LastSeen) >= time.Duration(intervalMin)*time.Minute {
		entry.LastSeen = now
		s.dirty = true
		return true
	}
	return false
}
// GetRaw returns the Hash field stored under key and whether the key exists.
// It is the read side of SetRaw.
func (s *Store) GetRaw(key string) (string, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if entry, found := s.entries[key]; found {
		return entry.Hash, true
	}
	return "", false
}
// SetRaw stores value in the Hash field of the entry for key, creating the
// entry if necessary. The store is marked dirty only when something actually
// changed; nothing is written to disk here.
func (s *Store) SetRaw(key, value string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, found := s.entries[key]
	switch {
	case !found:
		s.entries[key] = &Entry{
			Hash:      value,
			FirstSeen: time.Now(),
			LastSeen:  time.Now(),
		}
		s.dirty = true
	case entry.Hash != value:
		entry.Hash = value
		entry.LastSeen = time.Now()
		s.dirty = true
	}
}
// AppendHistory counts the findings in the metrics counter and appends them
// to the history. The bbolt store is used when one is registered globally;
// otherwise the findings go to the JSONL history file.
//
// Deprecated behaviour: the JSONL fallback has a truncation race and exists
// only for installations that have not yet migrated to bbolt. It prints a
// deprecation notice on every call and will be removed in a future release.
func (s *Store) AppendHistory(findings []alert.Finding) {
	if len(findings) == 0 {
		return
	}
	recordFindings(findings)

	db := store.Global()
	if db == nil {
		fmt.Fprintf(os.Stderr, "DEPRECATION: using JSONL history fallback; migrate to bbolt store\n")
		s.appendHistoryFile(findings)
		return
	}
	// bbolt available: JSONL is skipped entirely, even when the write fails.
	if err := db.AppendHistory(findings); err != nil {
		fmt.Fprintf(os.Stderr, "store: append history: %v\n", err)
	}
}
// appendHistoryFile appends findings to the JSONL history file, one JSON
// object per line. When the file has grown past 10MB it is first cut down to
// roughly its newer half, starting at a line boundary.
//
// All records of one call are encoded into a single buffer and written with
// one Write, so a record is never separated from its trailing newline.
// Findings that fail to marshal are skipped. I/O failures are reported on
// stderr; the method itself stays best-effort and returns nothing.
func (s *Store) appendHistoryFile(findings []alert.Finding) {
	const maxHistoryBytes = 10 * 1024 * 1024
	histPath := filepath.Join(s.path, "history.jsonl")

	// Check size, truncate if over 10MB
	if info, err := os.Stat(histPath); err == nil && info.Size() > maxHistoryBytes {
		// #nosec G304 -- histPath is {s.path}/history.jsonl; s.path is the
		// operator-configured statePath set at Store creation.
		data, readErr := os.ReadFile(histPath)
		if readErr != nil {
			fmt.Fprintf(os.Stderr, "state: reading history for truncation: %v\n", readErr)
		} else {
			// Keep the second half, starting after the first newline at
			// or beyond the midpoint. No newline there means no trim.
			mid := len(data) / 2
			if nl := bytes.IndexByte(data[mid:], '\n'); nl >= 0 {
				// #nosec G703 -- histPath comes from state.historyPath, a
				// filepath.Join under the operator-configured statePath.
				if writeErr := os.WriteFile(histPath, data[mid+nl+1:], 0600); writeErr != nil {
					fmt.Fprintf(os.Stderr, "state: truncating history: %v\n", writeErr)
				}
			}
		}
	}

	var batch bytes.Buffer
	for _, finding := range findings {
		line, err := json.Marshal(finding)
		if err != nil {
			continue
		}
		batch.Write(line)
		batch.WriteByte('\n')
	}

	// #nosec G304 -- see histPath derivation above.
	f, err := os.OpenFile(histPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		fmt.Fprintf(os.Stderr, "state: opening history file: %v\n", err)
		return
	}
	if _, err := f.Write(batch.Bytes()); err != nil {
		fmt.Fprintf(os.Stderr, "state: appending history: %v\n", err)
		_ = f.Close()
		return
	}
	if err := f.Close(); err != nil {
		fmt.Fprintf(os.Stderr, "state: closing history file: %v\n", err)
	}
}
// PrintStatus writes a human-readable summary of the state entries to
// stdout: one line per active (non-baseline, non-internal) entry followed by
// the baseline and active counts. Entries are read under the read lock, and
// internal keys are detected with a prefix test so an empty key cannot cause
// an index-out-of-range panic.
func (s *Store) PrintStatus() {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if len(s.entries) == 0 {
		fmt.Println("No state entries. Run 'csm baseline' first.")
		return
	}

	baselineCount := 0
	activeCount := 0
	for key, entry := range s.entries {
		if entry.IsBaseline {
			baselineCount++
			continue
		}
		// Internal housekeeping keys are neither baseline nor active.
		if strings.HasPrefix(key, "_") {
			continue
		}
		activeCount++
		fmt.Printf(" [ACTIVE] %s (first: %s, last: %s)\n",
			key,
			entry.FirstSeen.Format("2006-01-02 15:04"),
			entry.LastSeen.Format("2006-01-02 15:04"),
		)
	}
	fmt.Printf("\nBaseline entries: %d, Active findings: %d\n", baselineCount, activeCount)
}
// Entries returns a snapshot copy of the current state entries (thread-safe).
// Internal keys (leading underscore) are omitted, and every Entry is copied
// so callers cannot mutate the store through the returned map. The prefix
// test also tolerates an empty key, which a direct key[0] index would not.
func (s *Store) Entries() map[string]*Entry {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snapshot := make(map[string]*Entry, len(s.entries))
	for key, entry := range s.entries {
		if strings.HasPrefix(key, "_") {
			continue // skip internal keys
		}
		clone := *entry
		snapshot[key] = &clone
	}
	return snapshot
}
// EntryForKey looks up the dedup entry for a finding key (check:message) and
// returns a copy of it. The boolean is false, and the Entry zero, when the
// key is not tracked.
func (s *Store) EntryForKey(key string) (Entry, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if found, ok := s.entries[key]; ok {
		return *found, true
	}
	return Entry{}, false
}
// ReadHistory returns one page of history, newest first, together with the
// total number of stored findings. offset is the number of newest findings
// to skip and limit the maximum page size.
//
// The bbolt store is used when available and receives limit and offset
// unchanged. In the flat-file JSONL fallback, unparsable lines are skipped,
// a missing or unreadable file yields (nil, 0), and negative limit/offset
// values are treated as 0 rather than causing a slice-bounds panic.
func (s *Store) ReadHistory(limit, offset int) ([]alert.Finding, int) {
	// Use bbolt store when available.
	if db := store.Global(); db != nil {
		return db.ReadHistory(limit, offset)
	}

	// Fallback: flat-file JSONL.
	historyPath := filepath.Join(s.path, "history.jsonl")
	// #nosec G304 -- {s.path}/history.jsonl; s.path from operator config.
	data, err := os.ReadFile(historyPath)
	if err != nil {
		return nil, 0
	}

	var all []alert.Finding
	for _, line := range splitLines(data) {
		if len(line) == 0 {
			continue
		}
		var f alert.Finding
		if err := json.Unmarshal(line, &f); err != nil {
			continue
		}
		all = append(all, f)
	}

	// Reverse (newest first)
	for i, j := 0, len(all)-1; i < j; i, j = i+1, j-1 {
		all[i], all[j] = all[j], all[i]
	}

	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 0
	}
	total := len(all)
	if offset >= total {
		return nil, total
	}
	// Compare against the remaining count instead of computing
	// offset+limit first, so a huge limit cannot overflow.
	end := total
	if limit < total-offset {
		end = offset + limit
	}
	return all[offset:end], total
}
// ReadHistoryFiltered reads matching history entries newest-first.
// Date strings that are not YYYY-MM-DD are ignored before they reach the
// lexicographic bbolt key filter.
// It is ReadHistoryFilteredWithChecks without a check-name constraint
// (checks is passed as nil); see that method for the filter semantics.
func (s *Store) ReadHistoryFiltered(limit, offset int, from, to string, severity int, search string) ([]alert.Finding, int) {
return s.ReadHistoryFilteredWithChecks(limit, offset, from, to, severity, search, nil)
}
// ReadHistoryFilteredWithChecks returns one page of matching history entries,
// newest first, plus the total number of matches.
//
// from and to are inclusive YYYY-MM-DD dates in local time; a value that does
// not parse is dropped, as if it had not been given. With a bbolt store the
// query is delegated to it. In the JSONL fallback a negative severity
// disables the severity filter, a nil checks map disables the check-name
// filter, and search is a case-insensitive substring match against Check,
// Message and Details.
func (s *Store) ReadHistoryFilteredWithChecks(
	limit, offset int,
	from, to string,
	severity int,
	search string,
	checks map[string]bool,
) ([]alert.Finding, int) {
	const dayLayout = "2006-01-02"
	var (
		notBefore, notAfter time.Time
		fromKey, toKey      string
	)
	if from != "" {
		if parsed, err := time.ParseInLocation(dayLayout, from, time.Local); err == nil {
			notBefore, fromKey = parsed, from
		}
	}
	if to != "" {
		if parsed, err := time.ParseInLocation(dayLayout, to, time.Local); err == nil {
			// Extend to the last nanosecond of the day so "to" is inclusive.
			notAfter, toKey = parsed.Add(24*time.Hour-time.Nanosecond), to
		}
	}

	if db := store.Global(); db != nil {
		return db.ReadHistoryFilteredWithChecks(limit, offset, fromKey, toKey, severity, search, checks)
	}

	history, _ := s.ReadHistory(1<<30, 0)
	needle := strings.ToLower(search)
	matches := func(f alert.Finding) bool {
		switch {
		case !notBefore.IsZero() && f.Timestamp.Before(notBefore):
			return false
		case !notAfter.IsZero() && f.Timestamp.After(notAfter):
			return false
		case severity >= 0 && int(f.Severity) != severity:
			return false
		case checks != nil && !checks[f.Check]:
			return false
		case search == "":
			return true
		}
		for _, haystack := range [...]string{f.Check, f.Message, f.Details} {
			if strings.Contains(strings.ToLower(haystack), needle) {
				return true
			}
		}
		return false
	}

	var page []alert.Finding
	total := 0
	for _, f := range history {
		if !matches(f) {
			continue
		}
		total++
		// Keep counting after the page is full so total stays accurate.
		if total > offset && len(page) < limit {
			page = append(page, f)
		}
	}
	return page, total
}
// ReadHistorySince returns every finding whose timestamp is not before since,
// newest first. The bbolt store answers the query when available. Otherwise
// the whole JSONL history is loaded and filtered linearly, so test wiring
// and migration-pending hosts still get time-bounded results.
func (s *Store) ReadHistorySince(since time.Time) []alert.Finding {
	db := store.Global()
	if db != nil {
		return db.ReadHistorySince(since)
	}

	history, _ := s.ReadHistory(1<<30, 0)
	// Filter in place, reusing the slice ReadHistory returned.
	kept := history[:0]
	for i := range history {
		if history[i].Timestamp.Before(since) {
			continue
		}
		kept = append(kept, history[i])
	}
	return kept
}
// SearchHistorySince returns up to limit matching findings since the given
// time, newest-first.
// A limit of zero or less returns nil without touching any store. The query
// goes to the bbolt store when one is registered, otherwise to the JSONL
// history file. match is passed through unchanged to whichever backend runs.
func (s *Store) SearchHistorySince(since time.Time, limit int, match func(alert.Finding) bool) []alert.Finding {
if limit <= 0 {
return nil
}
if db := store.Global(); db != nil {
return db.SearchHistorySince(since, limit, match)
}
return s.searchHistoryFileSince(since, limit, match)
}
// searchHistoryFileSince is the JSONL implementation of SearchHistorySince.
// It reads history.jsonl into memory and walks it backwards one line at a
// time, collecting findings until limit results are gathered, the start of
// the file is reached, or a finding older than since is met. Lines that do
// not parse are skipped; a nil match accepts every finding. An unreadable
// file yields nil.
func (s *Store) searchHistoryFileSince(since time.Time, limit int, match func(alert.Finding) bool) []alert.Finding {
historyPath := filepath.Join(s.path, "history.jsonl")
// #nosec G304 -- {s.path}/history.jsonl; s.path from operator config.
data, err := os.ReadFile(historyPath)
if err != nil {
return nil
}
var results []alert.Finding
// end is the exclusive upper bound of the not-yet-scanned prefix of data.
end := len(data)
for end > 0 && len(results) < limit {
// Drop line terminators (and blank lines) preceding the next record.
for end > 0 && (data[end-1] == '\n' || data[end-1] == '\r') {
end--
}
if end == 0 {
break
}
// The record starts right after the previous newline, or at offset 0
// when LastIndexByte returns -1.
start := bytes.LastIndexByte(data[:end], '\n') + 1
line := data[start:end]
// Advance the cursor before parsing so every continue below makes progress.
if start == 0 {
end = 0
} else {
end = start - 1
}
var f alert.Finding
if err := json.Unmarshal(line, &f); err != nil {
continue
}
// The JSONL fallback appends findings in chronological order, so
// once a reverse scan reaches an old row the remaining rows are older.
if f.Timestamp.Before(since) {
break
}
if match != nil && !match(f) {
continue
}
results = append(results, f)
}
return results
}
// AggregateByHour returns 24 hourly severity buckets for the last 24 hours.
// The work is delegated to the bbolt store; when no store is registered
// (store.Global() is nil) the result is nil, as there is no JSONL fallback.
func (s *Store) AggregateByHour() []store.HourBucket {
if db := store.Global(); db != nil {
return db.AggregateByHour()
}
return nil
}
// AggregateByDay returns 30 daily severity buckets for the last 30 days.
// The work is delegated to the bbolt store; when no store is registered
// (store.Global() is nil) the result is nil.
func (s *Store) AggregateByDay() []store.DayBucket {
if db := store.Global(); db != nil {
return db.AggregateByDay()
}
return nil
}
// AggregateByDayN returns `days` daily severity buckets (oldest first),
// clamped by the underlying store's retention window.
// The work is delegated to the bbolt store; when no store is registered
// (store.Global() is nil) the result is nil.
func (s *Store) AggregateByDayN(days int) []store.DayBucket {
if db := store.Global(); db != nil {
return db.AggregateByDayN(days)
}
return nil
}
// splitLines cuts data at every '\n' and returns the non-empty segments in
// order. The segments alias data (nothing is copied), terminators are not
// included, and a trailing segment without a newline is kept. The result is
// nil when data contains no non-empty line.
func splitLines(data []byte) [][]byte {
	var out [][]byte
	lineStart := 0
	for pos := 0; pos < len(data); pos++ {
		if data[pos] != '\n' {
			continue
		}
		if pos > lineStart {
			out = append(out, data[lineStart:pos])
		}
		lineStart = pos + 1
	}
	if lineStart < len(data) {
		out = append(out, data[lineStart:])
	}
	return out
}
// SetLatestFindings merges a scan's results into the current findings set,
// keyed by Finding.Key: an incoming finding replaces an existing one with
// the same key, everything else is kept, so critical-scan and deep-scan
// results coexist. The merged set is capped at 15,000 findings, the scan
// time is set to now, and the set is written to latest_findings.json
// (write errors are ignored). For a full replace, call ClearLatestFindings
// first.
func (s *Store) SetLatestFindings(findings []alert.Finding) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()

	// Index current findings, then let the new batch overwrite by key.
	byKey := make(map[string]alert.Finding)
	for _, old := range s.latestFindings {
		byKey[old.Key()] = old
	}
	for _, incoming := range findings {
		byKey[incoming.Key()] = incoming
	}

	var merged []alert.Finding
	for key := range byKey {
		merged = append(merged, byKey[key])
	}

	// Cap to prevent unbounded memory growth.
	const maxLatestFindings = 15000
	if len(merged) > maxLatestFindings {
		merged = merged[:maxLatestFindings]
	}

	s.latestFindings = merged
	s.latestScanTime = time.Now()

	target := filepath.Join(s.path, "latest_findings.json")
	_ = atomicio.AtomicWriteJSON(target, 0o600, merged)
}
// PurgeFindingsByChecks drops every latest finding whose Check equals one of
// the given names and writes the remaining set to latest_findings.json so
// purged findings do not reappear after a restart (write errors are
// ignored). It is used to clear stale performance findings before fresh
// results from a scan tier are merged. An empty checks slice is a no-op.
func (s *Store) PurgeFindingsByChecks(checks []string) {
	if len(checks) == 0 {
		return
	}
	s.latestMu.Lock()
	defer s.latestMu.Unlock()

	drop := make(map[string]bool, len(checks))
	for i := range checks {
		drop[checks[i]] = true
	}

	// Filter in place, reusing the existing backing array.
	kept := s.latestFindings[:0]
	for _, f := range s.latestFindings {
		if drop[f.Check] {
			continue
		}
		kept = append(kept, f)
	}
	s.latestFindings = kept

	target := filepath.Join(s.path, "latest_findings.json")
	_ = atomicio.AtomicWriteJSON(target, 0o600, s.latestFindings)
}
// PurgeAndMergeFindings removes the findings selected by purgeChecks and
// merges in the new findings under a single hold of latestMu, so concurrent
// readers never observe the intermediate state with the purged checks
// missing. Purge selection follows shouldPurgeLatestFinding. The merged set
// is capped at 15,000 findings, the scan time is set to now, and the result
// is written to latest_findings.json (write errors are ignored).
func (s *Store) PurgeAndMergeFindings(purgeChecks []string, findings []alert.Finding) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()

	drop := make(map[string]bool, len(purgeChecks))
	for i := range purgeChecks {
		drop[purgeChecks[i]] = true
	}

	// Survivors first, then the new batch overwrites by key.
	byKey := make(map[string]alert.Finding)
	for _, old := range s.latestFindings {
		if shouldPurgeLatestFinding(old, drop) {
			continue
		}
		byKey[old.Key()] = old
	}
	for _, incoming := range findings {
		byKey[incoming.Key()] = incoming
	}

	var merged []alert.Finding
	for key := range byKey {
		merged = append(merged, byKey[key])
	}
	const maxLatestFindings = 15000
	if len(merged) > maxLatestFindings {
		merged = merged[:maxLatestFindings]
	}

	s.latestFindings = merged
	s.latestScanTime = time.Now()

	target := filepath.Join(s.path, "latest_findings.json")
	_ = atomicio.AtomicWriteJSON(target, 0o600, merged)
}
// shouldPurgeLatestFinding reports whether f belongs to one of the checks in
// remove. A finding matches directly by its Check name; a "check_timeout"
// finding also matches when the runner named in its message is in remove.
func shouldPurgeLatestFinding(f alert.Finding, remove map[string]bool) bool {
	switch {
	case remove[f.Check]:
		return true
	case f.Check != "check_timeout":
		return false
	}
	if runner, ok := timeoutFindingRunner(f.Message); ok {
		return remove[runner]
	}
	return false
}
// timeoutFindingRunner extracts the quoted name from a message of the form
// "Check '<name>'...". The boolean is true only when the prefix is present,
// a closing quote follows, and the name is non-empty. Without a closing
// quote the remainder of the message is returned together with false.
func timeoutFindingRunner(msg string) (string, bool) {
	const lead = "Check '"
	if !strings.HasPrefix(msg, lead) {
		return "", false
	}
	rest := msg[len(lead):]
	closing := strings.IndexByte(rest, '\'')
	if closing < 0 {
		return rest, false
	}
	name := rest[:closing]
	return name, name != ""
}
// ClearLatestFindings removes all findings from the latest set.
// Use before SetLatestFindings for a full replace (e.g. initial scan).
// Only the in-memory set is cleared: this method does not rewrite
// latest_findings.json and does not change the recorded scan time.
func (s *Store) ClearLatestFindings() {
s.latestMu.Lock()
defer s.latestMu.Unlock()
s.latestFindings = nil
}
// LatestFindings returns the full results of the most recent scan.
// This is what the Findings page shows - "what's wrong right now."
// The returned slice is a copy made under the read lock, so callers may
// modify it freely; it is non-nil (possibly empty) even when nothing is stored.
func (s *Store) LatestFindings() []alert.Finding {
s.latestMu.RLock()
defer s.latestMu.RUnlock()
result := make([]alert.Finding, len(s.latestFindings))
copy(result, s.latestFindings)
return result
}
// LatestScanTime returns when the last scan completed, i.e. the time of the
// most recent SetLatestFindings or PurgeAndMergeFindings call. It is the
// zero time if neither has run on this Store; Open does not restore it.
func (s *Store) LatestScanTime() time.Time {
s.latestMu.RLock()
defer s.latestMu.RUnlock()
return s.latestScanTime
}
// baselineAtMetaKey is the internal entries key under which EnsureBaseline
// stores the first-start timestamp (in Entry.Hash). Its leading underscores
// keep it out of Entries and exempt it from Update's stale-entry cleanup;
// SetBaseline carries it over explicitly.
const baselineAtMetaKey = "__baseline_at"
// EnsureBaseline stores now (UTC, RFC 3339 with nanoseconds) as the baseline
// timestamp unless one is already recorded. An existing value is never
// overwritten, so reinstalls and upgrades do not reset the baseline, and the
// daemon boot path can call this on every start.
func (s *Store) EnsureBaseline(now time.Time) {
	_, recorded := s.GetRaw(baselineAtMetaKey)
	if !recorded {
		stamp := now.UTC().Format(time.RFC3339Nano)
		s.SetRaw(baselineAtMetaKey, stamp)
	}
}
// BaselineAt returns the persisted baseline timestamp. The zero time is
// returned when EnsureBaseline has not stored one yet or when the stored
// value does not parse as RFC 3339.
func (s *Store) BaselineAt() time.Time {
	if raw, ok := s.GetRaw(baselineAtMetaKey); ok {
		if parsed, err := time.Parse(time.RFC3339Nano, raw); err == nil {
			return parsed
		}
	}
	return time.Time{}
}
// DismissLatestFinding drops every latest finding whose Key equals key and
// rewrites latest_findings.json with what remains (write errors are
// ignored). The dedup entry for the key is not touched; see DismissFinding.
func (s *Store) DismissLatestFinding(key string) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()

	var remaining []alert.Finding
	for _, candidate := range s.latestFindings {
		if candidate.Key() == key {
			continue
		}
		remaining = append(remaining, candidate)
	}
	s.latestFindings = remaining

	target := filepath.Join(s.path, "latest_findings.json")
	_ = atomicio.AtomicWriteJSON(target, 0o600, s.latestFindings)
}
// DismissFinding marks a finding as baseline (acknowledged/dismissed).
// It will no longer appear in active findings or trigger new alerts.
// An unknown key is ignored. The change is only flagged as dirty here; it is
// written to disk by the next save (Update, MarkAlerted, SetBaseline or Close).
func (s *Store) DismissFinding(key string) {
s.mu.Lock()
defer s.mu.Unlock()
if entry, exists := s.entries[key]; exists {
entry.IsBaseline = true
s.dirty = true
}
}
// ParseKey splits a state key "check:message" into its components at the
// first colon. Later colons stay part of message. A key without a colon is
// returned whole as check with an empty message.
func ParseKey(key string) (check, message string) {
	check, message, _ = strings.Cut(key, ":")
	return check, message
}
// --- Suppression rules ---
// SuppressionRule defines a rule for suppressing specific findings.
// Rules are stored as a JSON array in suppressions.json inside the state
// directory and evaluated by IsSuppressed.
type SuppressionRule struct {
ID string `json:"id"`
// Check must equal the finding's Check for the rule to apply.
Check string `json:"check"`
// PathPattern is a filepath.Match pattern; empty means every finding of Check.
PathPattern string `json:"path_pattern,omitempty"`
Reason string `json:"reason"`
CreatedAt time.Time `json:"created_at"`
}
// LoadSuppressions reads suppression rules from suppressions.json in the
// state directory. It returns nil when the file does not exist, cannot be
// read, or does not parse. A missing file is the normal "no rules" case and
// stays silent; the other two failures are reported on stderr, because they
// would otherwise silently disable every suppression.
func (s *Store) LoadSuppressions() []SuppressionRule {
	rulesFile := filepath.Join(s.path, "suppressions.json")
	// #nosec G304 -- operator-configured statePath + fixed filename.
	data, err := os.ReadFile(rulesFile)
	if err != nil {
		if !os.IsNotExist(err) {
			fmt.Fprintf(os.Stderr, "warning: failed to read %s: %v\n", rulesFile, err)
		}
		return nil
	}
	var rules []SuppressionRule
	if err := json.Unmarshal(data, &rules); err != nil {
		fmt.Fprintf(os.Stderr, "warning: failed to parse %s: %v (suppression rules ignored)\n", rulesFile, err)
		return nil
	}
	return rules
}
// SaveSuppressions writes suppression rules to disk atomically with fsync.
// The target is suppressions.json in the state directory, written with mode
// 0600; the previous contents are replaced by rules. The error from
// atomicio.AtomicWriteJSON is returned unchanged.
func (s *Store) SaveSuppressions(rules []SuppressionRule) error {
return atomicio.AtomicWriteJSON(filepath.Join(s.path, "suppressions.json"), 0o600, rules)
}
// IsSuppressed reports whether f matches any of the given suppression rules.
// A rule applies when its Check equals the finding's Check and either it has
// no PathPattern or the pattern (filepath.Match syntax) matches one of the
// finding's candidate paths: FilePath when set, otherwise absolute paths
// found in Message and Details. Malformed patterns never match.
//
// Load rules once with LoadSuppressions() and pass them in to avoid
// re-reading the file for every finding.
func (s *Store) IsSuppressed(f alert.Finding, rules []SuppressionRule) bool {
	for i := range rules {
		rule := &rules[i]
		if rule.Check != f.Check {
			continue
		}
		// No path pattern: the rule covers the whole check type.
		if rule.PathPattern == "" {
			return true
		}
		// The candidate list already starts with FilePath when it is set.
		for _, candidate := range suppressionPathCandidates(f) {
			hit, _ := filepath.Match(rule.PathPattern, candidate)
			if hit {
				return true
			}
		}
	}
	return false
}
// suppressionPathCandidates returns the paths a suppression pattern is
// matched against. A finding with FilePath set yields exactly that path,
// unmodified. Otherwise Message and Details are split on whitespace, each
// token is stripped of surrounding quotes, brackets and punctuation, and
// every token that starts with "/" is cleaned and returned once, in order
// of first appearance. The root path "/" is never a candidate.
func suppressionPathCandidates(f alert.Finding) []string {
	if f.FilePath != "" {
		return []string{f.FilePath}
	}

	const trimSet = `"'():,;[]{}<>`
	var candidates []string
	already := make(map[string]bool)
	for _, token := range strings.Fields(f.Message + " " + f.Details) {
		token = strings.Trim(token, trimSet)
		if !strings.HasPrefix(token, "/") {
			continue
		}
		cleaned := filepath.Clean(token)
		skip := cleaned == "." || cleaned == "/" || already[cleaned]
		if skip {
			continue
		}
		already[cleaned] = true
		candidates = append(candidates, cleaned)
	}
	return candidates
}
package store
import (
"encoding/json"
"strings"
"time"
bolt "go.etcd.io/bbolt"
)
// AdminEmailEntry records that a given email was observed as a WordPress
// administrator on (Account, Schema). One email may carry multiple
// entries when the same person administers several customer sites --
// surfacing that overlap is the whole point of the bucket.
// Entries are stored as a JSON array under the lowercased email address.
type AdminEmailEntry struct {
Account string `json:"account"`
Schema string `json:"schema"`
// LastSeen is the time of the most recent RecordAdminEmail call for this
// (Account, Schema) pair; OverlappingAdminEmails uses it for retention.
LastSeen time.Time `json:"last_seen"`
}
// adminEmailsBucket is the bbolt bucket that maps a lowercased admin email
// to the JSON-encoded []AdminEmailEntry observed for it.
const adminEmailsBucket = "admin:emails"
// RecordAdminEmail upserts an observation that `email` is administrator
// on (account, schema) at `now`. Re-recording the same triple updates
// LastSeen without creating a duplicate row; recording the same email
// across a different (account, schema) appends to the owner list. The
// email is trimmed and lowercased before storage so case-mismatched
// recordings collapse to a single key.
//
// An empty email (after trimming) or an empty account is ignored and nil is
// returned. The bucket is created on demand, so a database that does not
// contain it yet produces a normal write instead of a nil-pointer panic.
func (db *DB) RecordAdminEmail(email, account, schema string, now time.Time) error {
	email = strings.ToLower(strings.TrimSpace(email))
	if email == "" || account == "" {
		return nil
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		b, err := tx.CreateBucketIfNotExists([]byte(adminEmailsBucket))
		if err != nil {
			return err
		}
		var entries []AdminEmailEntry
		if v := b.Get([]byte(email)); v != nil {
			if err := json.Unmarshal(v, &entries); err != nil {
				// Corrupt entry: restart the list rather than fail the
				// whole write -- we'd rather lose one stale observation
				// than block detection on a malformed payload.
				entries = nil
			}
		}
		updated := false
		for i := range entries {
			if entries[i].Account == account && entries[i].Schema == schema {
				entries[i].LastSeen = now
				updated = true
				break
			}
		}
		if !updated {
			entries = append(entries, AdminEmailEntry{
				Account:  account,
				Schema:   schema,
				LastSeen: now,
			})
		}
		payload, err := json.Marshal(entries)
		if err != nil {
			return err
		}
		return b.Put([]byte(email), payload)
	})
}
// AdminEmailOwners returns the full list of (account, schema, last_seen)
// triples recorded for `email`, which is trimmed and lowercased before the
// lookup. The result is a nil slice with a nil error when the email is
// unknown or the bucket does not exist yet. If the stored payload cannot be
// decoded, the decode error is returned together with a nil slice rather
// than a partially filled one.
func (db *DB) AdminEmailOwners(email string) ([]AdminEmailEntry, error) {
	email = strings.ToLower(strings.TrimSpace(email))
	var out []AdminEmailEntry
	err := db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte(adminEmailsBucket))
		if b == nil {
			// Nothing has been recorded in this database yet.
			return nil
		}
		v := b.Get([]byte(email))
		if v == nil {
			return nil
		}
		return json.Unmarshal(v, &out)
	})
	if err != nil {
		return nil, err
	}
	return out, nil
}
// OverlappingAdminEmails returns every email whose owner list has at
// least `minAccounts` distinct accounts after stale entries (older than
// `retention`) are filtered out. Stale entries are silently dropped from
// the returned slice but stay in the bucket -- a separate compaction
// pass could prune them, but for the cross-account correlator the
// in-memory filter is sufficient. Duplicate (account, schema) pairs within
// one owner list are reported once, and payloads that fail to decode are
// skipped.
//
// `minAccounts` smaller than 2 is clamped to 2 because a single-account
// observation is never an overlap by definition.
//
// A database without the admin-emails bucket yields an empty map instead of
// a nil-pointer panic.
func (db *DB) OverlappingAdminEmails(minAccounts int, retention time.Duration) (map[string][]AdminEmailEntry, error) {
	if minAccounts < 2 {
		minAccounts = 2
	}
	cutoff := time.Now().Add(-retention)
	out := map[string][]AdminEmailEntry{}
	err := db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte(adminEmailsBucket))
		if b == nil {
			// Nothing has been recorded in this database yet.
			return nil
		}
		return b.ForEach(func(k, v []byte) error {
			var entries []AdminEmailEntry
			if err := json.Unmarshal(v, &entries); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			// Filter in place: entries is a private decode result.
			fresh := entries[:0]
			seen := map[string]struct{}{}
			for _, e := range entries {
				if e.LastSeen.Before(cutoff) {
					continue
				}
				key := e.Account + "|" + e.Schema
				if _, dup := seen[key]; dup {
					continue
				}
				seen[key] = struct{}{}
				fresh = append(fresh, e)
			}
			distinct := map[string]struct{}{}
			for _, e := range fresh {
				distinct[e.Account] = struct{}{}
			}
			if len(distinct) >= minAccounts {
				// string(k) copies the key, which is only valid inside
				// the transaction; the entries are copied as well.
				out[string(k)] = append([]AdminEmailEntry(nil), fresh...)
			}
			return nil
		})
	})
	return out, err
}
package store
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// SeverityBucket holds aggregated counts by severity.
// AggregateByHour increments Total for every finding it counts, alongside
// the matching per-severity field.
type SeverityBucket struct {
Critical int `json:"critical"`
High int `json:"high"`
Warning int `json:"warning"`
Total int `json:"total"`
}
// HourBucket is a SeverityBucket keyed by hour label.
// Hour is formatted "HH:00" by AggregateByHour; the embedded counts are
// flattened into the same JSON object.
type HourBucket struct {
Hour string `json:"hour"`
SeverityBucket
}
// DayBucket is a SeverityBucket keyed by date.
// Date is formatted "YYYY-MM-DD" by AggregateByDayN; the embedded counts are
// flattened into the same JSON object.
type DayBucket struct {
Date string `json:"date"`
SeverityBucket
}
// AggregateByHour returns 24 hourly buckets (oldest first) covering the
// current hour and the 23 hours before it. The scan seeks straight to the
// first history key that can fall inside the window instead of walking the
// whole bucket. Findings that fail to decode, or whose timestamp lies outside
// the window or in the future, are not counted. A read error or a missing
// history bucket yields 24 empty buckets.
func (db *DB) AggregateByHour() []HourBucket {
	now := time.Now()
	thisHour := now.Truncate(time.Hour)
	oldest := thisHour.Add(-23 * time.Hour)

	// tally[n] holds the counts for the hour n hours before the current one.
	var tally [24]SeverityBucket

	seekKey := fmt.Sprintf("%04d%02d%02d%02d",
		oldest.Year(), oldest.Month(), oldest.Day(), oldest.Hour())

	_ = db.bolt.View(func(tx *bolt.Tx) error {
		history := tx.Bucket([]byte("history"))
		if history == nil {
			return nil
		}
		cursor := history.Cursor()
		for k, v := cursor.Seek([]byte(seekKey)); k != nil; k, v = cursor.Next() {
			var f alert.Finding
			if json.Unmarshal(v, &f) != nil {
				continue
			}
			if f.Timestamp.Before(oldest) || f.Timestamp.After(now) {
				continue
			}
			age := int(thisHour.Sub(f.Timestamp.Truncate(time.Hour)).Hours())
			if age < 0 || age >= 24 {
				continue
			}
			slot := &tally[age]
			slot.Total++
			switch f.Severity {
			case alert.Critical:
				slot.Critical++
			case alert.High:
				slot.High++
			case alert.Warning:
				slot.Warning++
			}
		}
		return nil
	})

	// Emit oldest first: 23 hours ago down to the current hour.
	out := make([]HourBucket, 0, 24)
	for age := 23; age >= 0; age-- {
		label := thisHour.Add(-time.Duration(age) * time.Hour)
		out = append(out, HourBucket{
			Hour:           fmt.Sprintf("%02d:00", label.Hour()),
			SeverityBucket: tally[age],
		})
	}
	return out
}
// ReadHistorySince returns the findings stored at or after the history key
// derived from since (formatted YYYYMMDDhhmmss from its own clock fields),
// newest first. The cursor seeks to that key and reads to the end of the
// bucket; records that fail to decode are skipped. A missing history bucket
// or a read error yields nil.
func (db *DB) ReadHistorySince(since time.Time) []alert.Finding {
	seekKey := fmt.Sprintf("%04d%02d%02d%02d%02d%02d",
		since.Year(), since.Month(), since.Day(),
		since.Hour(), since.Minute(), since.Second())

	var found []alert.Finding
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		history := tx.Bucket([]byte("history"))
		if history == nil {
			return nil
		}
		cursor := history.Cursor()
		k, v := cursor.Seek([]byte(seekKey))
		for ; k != nil; k, v = cursor.Next() {
			var f alert.Finding
			if json.Unmarshal(v, &f) != nil {
				continue
			}
			found = append(found, f)
		}
		return nil
	})

	// Keys ascend chronologically; flip in place to newest-first.
	lo, hi := 0, len(found)-1
	for lo < hi {
		found[lo], found[hi] = found[hi], found[lo]
		lo++
		hi--
	}
	return found
}
// SearchHistorySince returns up to limit findings since the given time,
// newest-first. The matcher runs while the bbolt cursor walks backward so
// callers that only need a bounded result set do not materialize the whole
// time window first.
// A limit of zero or less returns nil. A nil match accepts every finding,
// records that fail to decode are skipped, and a missing history bucket or
// read error yields nil.
func (db *DB) SearchHistorySince(since time.Time, limit int, match func(alert.Finding) bool) []alert.Finding {
if limit <= 0 {
return nil
}
cutoffKey := timeKeyLowerBound(since)
var results []alert.Finding
_ = db.bolt.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("history"))
if b == nil {
return nil
}
c := b.Cursor()
// Walk from the last key backwards until enough results are collected.
for k, v := c.Last(); k != nil && len(results) < limit; k, v = c.Prev() {
// Keys compare lexicographically; the first key below the cutoff
// ends the scan.
if string(k) < cutoffKey {
break
}
var f alert.Finding
if err := json.Unmarshal(v, &f); err != nil {
continue
}
if match != nil && !match(f) {
continue
}
results = append(results, f)
}
return nil
})
return results
}
// AggregateByDay returns 30 daily buckets (oldest first) for the last 30 days.
// Reads from the pre-aggregated stats:daily bucket so the trend chart is
// not affected by history pruning.
// It is shorthand for AggregateByDayN(30), so the same clamping applies.
func (db *DB) AggregateByDay() []DayBucket {
return db.AggregateByDayN(30)
}
// AggregateByDayN returns `days` daily buckets, oldest first, with the last
// bucket being today (local time). days is clamped to
// [1, dailyRetentionDays] and additionally to a hard ceiling of 4096. Each
// bucket's Date is formatted as YYYY-MM-DD; days with no stored (or no
// decodable) stats:daily record keep a zero-value SeverityBucket.
func (db *DB) AggregateByDayN(days int) []DayBucket {
	// Hard ceiling so a future increase of dailyRetentionDays cannot turn
	// into a huge slice allocation here.
	const aggregateBucketHardCap = 4096

	if days < 1 {
		days = 1
	}
	if days > dailyRetentionDays {
		days = dailyRetentionDays
	}
	if days > aggregateBucketHardCap {
		days = aggregateBucketHardCap
	}

	now := time.Now()
	midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local)
	oldest := midnight.AddDate(0, 0, -(days - 1))

	out := make([]DayBucket, days)
	for offset := range out {
		out[offset] = DayBucket{Date: oldest.AddDate(0, 0, offset).Format("2006-01-02")}
	}

	_ = db.bolt.View(func(tx *bolt.Tx) error {
		daily := tx.Bucket([]byte(bucketStatsDaily))
		if daily == nil {
			return nil
		}
		for idx := range out {
			raw := daily.Get([]byte(out[idx].Date))
			if raw == nil {
				continue
			}
			var counts SeverityBucket
			if json.Unmarshal(raw, &counts) != nil {
				continue
			}
			out[idx].SeverityBucket = counts
		}
		return nil
	})
	return out
}
package store
import (
"archive/tar"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/klauspost/compress/zstd"
bolt "go.etcd.io/bbolt"
)
// ArchiveSchemaVersion is the on-wire schema for backup archives. Bump
// when the manifest layout or contents shape changes incompatibly. Old
// CSM binaries refuse archives newer than the version they understand:
// Import compares Manifest.SchemaVersion against this constant and returns
// ErrSchemaVersionTooNew when the archive's value is larger. Export stamps
// this value into the manifest when the caller left SchemaVersion at zero.
const ArchiveSchemaVersion = 1
// Standard entry names inside the tar.
const (
// manifestEntry is the JSON manifest; Export writes it first and Import
// requires it to be the first entry.
manifestEntry = "manifest.json"
// bboltSnapshotEntry is the bbolt database snapshot taken during Export.
bboltSnapshotEntry = "bbolt.snapshot"
// stateEntryPrefix prefixes every file taken from ExportOptions.StatePath.
stateEntryPrefix = "state/"
// rulesEntryPrefix prefixes every file taken from ExportOptions.RulesPath.
rulesEntryPrefix = "rules/"
)
// Sentinel errors so callers can branch on the failure mode instead of
// matching strings. Import wraps them with %w, so test with errors.Is.
var (
// ErrSchemaVersionTooNew: the manifest's schema_version exceeds
// ArchiveSchemaVersion.
ErrSchemaVersionTooNew = errors.New("archive schema version is newer than this binary supports")
// ErrPlatformMismatch: platformMatches rejected the archive's source
// platform and ImportOptions.ForcePlatformMismatch was not set.
ErrPlatformMismatch = errors.New("archive source platform does not match current host")
// ErrManifestMissing: the first tar entry is not manifest.json.
ErrManifestMissing = errors.New("archive does not contain manifest.json")
// ErrCorruptArchive: the archive could not be decompressed/read/parsed,
// contains an unsafe entry name, lacks a required bbolt snapshot, or the
// snapshot's hash does not match the manifest.
ErrCorruptArchive = errors.New("archive is corrupt or not a CSM backup")
)
// Manifest is the JSON header at the top of every archive. Export fills
// SchemaVersion and ExportTS when they are zero and always overwrites
// Contents, BboltBuckets and BboltSHA256; the remaining fields come from the
// caller.
type Manifest struct {
SchemaVersion int `json:"schema_version"` // compared against ArchiveSchemaVersion on import
CSMVersion string `json:"csm_version"` // caller-supplied
SourceHostname string `json:"source_hostname"` // caller-supplied
SourcePlatform map[string]string `json:"source_platform"` // caller-supplied; Import compares the "os", "panel" and "webserver" keys
ExportTS time.Time `json:"export_ts"` // also used as ModTime of every tar entry
Contents []string `json:"contents"` // "bbolt", "state" and, when a rules path was given, "rules"
BboltBuckets []string `json:"bbolt_buckets,omitempty"` // sorted bucket names present at export time
BboltSHA256 string `json:"bbolt_sha256,omitempty"` // lowercase hex SHA-256 of the bbolt snapshot
}
// ExportOptions configures Export. StatePath and DstPath are required
// (Export returns an error when either is empty); RulesPath is optional.
type ExportOptions struct {
StatePath string // /var/lib/csm/state, source for state JSON files
RulesPath string // /opt/csm/rules, source for signature cache (empty -> skip)
DstPath string // .csmbak file to create
Manifest Manifest // caller fills CSMVersion/SourceHostname/SourcePlatform; rest filled here
}
// ExportResult summarises a successful export.
type ExportResult struct {
Path string // the archive path (ExportOptions.DstPath)
Bytes int64 // archive size on disk as reported by os.Stat
ArchiveSHA256 string // lowercase hex SHA-256 of the archive; also written to Path+".sha256"
BboltSHA256 string // lowercase hex SHA-256 of the embedded bbolt snapshot
}
// ImportOptions configures Import. SrcPath and StatePath are required.
// An empty Only is treated as "all"; an empty RulesPath skips rules
// restoration.
type ImportOptions struct {
SrcPath string // archive to read
StatePath string // destination state directory (csm.db lives here)
RulesPath string // destination rules directory; only used when Only is "all"
Only string // "all" | "baseline" | "firewall"
ForcePlatformMismatch bool // skip the source/current platform comparison
CurrentPlatform map[string]string // for the mismatch check
}
// ImportResult summarises a successful import.
type ImportResult struct {
Manifest Manifest // manifest decoded from the archive
BucketsRestored []string // all manifest buckets for only=all; merged fw:* buckets for only=firewall
StateFiles int // number of state files moved into StatePath
RulesFiles int // number of rules files moved into RulesPath
}
// Export writes a tar+zstd archive containing a bbolt snapshot, the
// state directory, and (optionally) the signature-rules directory. The
// daemon is the single source of truth for paths; the caller fills the
// manifest with hostname/version/platform.
//
// Steps, in order:
//  1. snapshot bbolt into a temp file next to DstPath while hashing it;
//  2. write manifest.json, the snapshot, state files and rules files into
//     the tar, hashing the compressed stream as it is written;
//  3. fsync and close the archive, then write DstPath+".sha256".
//
// It returns an error when DstPath or StatePath is empty or when any step
// fails. On any failure after the archive file has been created, both the
// archive and its .sha256 companion are removed. The temp snapshot is always
// removed.
func (db *DB) Export(opts ExportOptions) (*ExportResult, error) {
if opts.DstPath == "" {
return nil, errors.New("DstPath is empty")
}
if opts.StatePath == "" {
return nil, errors.New("StatePath is empty")
}
// Work on a copy of the caller's manifest; only zero-valued
// SchemaVersion/ExportTS are defaulted, the content fields are always
// overwritten.
man := opts.Manifest
if man.SchemaVersion == 0 {
man.SchemaVersion = ArchiveSchemaVersion
}
if man.ExportTS.IsZero() {
man.ExportTS = time.Now().UTC()
}
man.Contents = []string{"bbolt", "state"}
if opts.RulesPath != "" {
man.Contents = append(man.Contents, "rules")
}
man.BboltBuckets = listBuckets(db)
// Snapshot bbolt to a temp file in the same directory so we can hash it
// and stream it into the tar without holding a long bolt transaction.
snapDir := filepath.Dir(opts.DstPath)
snap, err := os.CreateTemp(snapDir, "csm-export-bbolt-*.snap")
if err != nil {
return nil, fmt.Errorf("creating bbolt snapshot temp: %w", err)
}
snapPath := snap.Name()
defer os.Remove(snapPath)
bboltHash := sha256.New()
// tx.WriteTo streams the database inside a read transaction; the
// MultiWriter feeds the same bytes to the temp file and the hash.
err = db.bolt.View(func(tx *bolt.Tx) error {
_, werr := tx.WriteTo(io.MultiWriter(snap, bboltHash))
return werr
})
if err != nil {
_ = snap.Close()
return nil, fmt.Errorf("bbolt snapshot: %w", err)
}
if err = snap.Close(); err != nil {
return nil, fmt.Errorf("closing bbolt snapshot: %w", err)
}
man.BboltSHA256 = hex.EncodeToString(bboltHash.Sum(nil))
// Build the archive on disk; hash it as we write. Close is called
// explicitly below so any close error after fsync is surfaced --
// silently dropping it would mean the operator gets "export
// succeeded" for a file that may not be fully persisted.
out, err := os.OpenFile(opts.DstPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return nil, fmt.Errorf("creating archive: %w", err)
}
// closed gates the deferred Close so the explicit Close below
// can return its error without a double-close. success gates the
// cleanup of the partial archive + companion file: any
// error-return path between OpenFile and the final return removes
// the half-written files so the operator does not mistake them
// for a usable backup.
closed := false
success := false
defer func() {
if !closed {
_ = out.Close()
}
if !success {
_ = os.Remove(opts.DstPath)
_ = os.Remove(opts.DstPath + ".sha256")
}
}()
// Writer chain: tar -> zstd -> (archive file + archive hash). The hash
// therefore covers the compressed bytes exactly as stored on disk.
archHash := sha256.New()
mw := io.MultiWriter(out, archHash)
zw, err := zstd.NewWriter(mw, zstd.WithEncoderLevel(zstd.SpeedDefault))
if err != nil {
return nil, fmt.Errorf("zstd writer: %w", err)
}
tw := tar.NewWriter(zw)
// 1. manifest first so a streaming reader sees schema info before payload.
manBytes, err := json.MarshalIndent(man, "", " ")
if err != nil {
return nil, fmt.Errorf("marshal manifest: %w", err)
}
if err = writeTarFile(tw, manifestEntry, manBytes, man.ExportTS); err != nil {
return nil, err
}
// 2. bbolt snapshot (already on disk).
if err = streamFileToTar(tw, bboltSnapshotEntry, snapPath, man.ExportTS); err != nil {
return nil, err
}
// 3. state files (skip the bbolt file itself, which is captured in step 2).
if _, err = walkDirIntoTar(tw, opts.StatePath, stateEntryPrefix, []string{"csm.db"}, man.ExportTS); err != nil {
return nil, err
}
// 4. rules files (optional).
if opts.RulesPath != "" {
if _, err = walkDirIntoTar(tw, opts.RulesPath, rulesEntryPrefix, nil, man.ExportTS); err != nil {
return nil, err
}
}
// Close inner-to-outer so each layer flushes into the next before the
// file itself is synced and closed.
if err = tw.Close(); err != nil {
return nil, fmt.Errorf("closing tar: %w", err)
}
if err = zw.Close(); err != nil {
return nil, fmt.Errorf("closing zstd: %w", err)
}
if err = out.Sync(); err != nil {
return nil, fmt.Errorf("fsync archive: %w", err)
}
if err = out.Close(); err != nil {
return nil, fmt.Errorf("closing archive: %w", err)
}
closed = true
info, err := os.Stat(opts.DstPath)
if err != nil {
return nil, fmt.Errorf("stat archive: %w", err)
}
archiveSHA := hex.EncodeToString(archHash.Sum(nil))
// Write companion .sha256 file alongside for operator verification.
// The line is "<hex digest> <archive base name>\n".
companion := opts.DstPath + ".sha256"
companionLine := fmt.Sprintf("%s %s\n", archiveSHA, filepath.Base(opts.DstPath))
if err = os.WriteFile(companion, []byte(companionLine), 0600); err != nil {
return nil, fmt.Errorf("writing companion sha256: %w", err)
}
success = true
return &ExportResult{
Path: opts.DstPath,
Bytes: info.Size(),
ArchiveSHA256: archiveSHA,
BboltSHA256: man.BboltSHA256,
}, nil
}
// Import unpacks an archive into the target state and rules paths. Live
// daemons must be stopped first; callers enforce that before invoking.
//
// opts.Only selects what is applied ("" means "all"):
//   - "all":      state files, rules files (when RulesPath is set) and the
//     bbolt snapshot, which replaces StatePath/csm.db wholesale;
//   - "baseline": state files only;
//   - "firewall": only the fw:* buckets, merged into StatePath/csm.db.
//
// The archive is validated (manifest first, schema version, platform) and
// every regular-file entry is staged into a temp directory next to StatePath
// before anything is applied. The staging directory is always removed.
//
// Errors wrap ErrCorruptArchive, ErrManifestMissing, ErrSchemaVersionTooNew
// or ErrPlatformMismatch where applicable; filesystem failures are returned
// wrapped with context.
func Import(opts ImportOptions) (*ImportResult, error) {
if opts.SrcPath == "" {
return nil, errors.New("SrcPath is empty")
}
if opts.StatePath == "" {
return nil, errors.New("StatePath is empty")
}
only := opts.Only
if only == "" {
only = "all"
}
switch only {
case "all", "baseline", "firewall":
default:
return nil, fmt.Errorf("invalid Only value %q (want all|baseline|firewall)", only)
}
in, err := os.Open(opts.SrcPath)
if err != nil {
return nil, fmt.Errorf("opening archive: %w", err)
}
defer in.Close()
zr, err := zstd.NewReader(in)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrCorruptArchive, err)
}
defer zr.Close()
tr := tar.NewReader(zr)
// Manifest must come first.
hdr, err := tr.Next()
if err != nil {
return nil, fmt.Errorf("%w: reading first entry: %v", ErrCorruptArchive, err)
}
if hdr.Name != manifestEntry {
return nil, fmt.Errorf("%w: first entry is %q, want %q", ErrManifestMissing, hdr.Name, manifestEntry)
}
manBytes, err := io.ReadAll(tr)
if err != nil {
return nil, fmt.Errorf("%w: reading manifest: %v", ErrCorruptArchive, err)
}
var man Manifest
if err = json.Unmarshal(manBytes, &man); err != nil {
return nil, fmt.Errorf("%w: parsing manifest: %v", ErrCorruptArchive, err)
}
if man.SchemaVersion > ArchiveSchemaVersion {
return nil, fmt.Errorf("%w: archive=%d binary=%d", ErrSchemaVersionTooNew, man.SchemaVersion, ArchiveSchemaVersion)
}
if !opts.ForcePlatformMismatch && !platformMatches(man.SourcePlatform, opts.CurrentPlatform) {
return nil, fmt.Errorf("%w: archive=%v current=%v (use --force-platform-mismatch to override)", ErrPlatformMismatch, man.SourcePlatform, opts.CurrentPlatform)
}
// Stage every payload into a temp dir and apply only after the whole
// archive has been read successfully, so a truncated or corrupt archive
// never touches the live paths. The apply phase below moves files one at a
// time; a failure there returns immediately without undoing earlier moves.
stage, err := os.MkdirTemp(filepath.Dir(opts.StatePath), "csm-import-stage-*")
if err != nil {
return nil, fmt.Errorf("creating staging dir: %w", err)
}
defer os.RemoveAll(stage)
stagedBbolt := ""
stagedState := []string{}
stagedRules := []string{}
for {
nextHdr, nextErr := tr.Next()
if nextErr == io.EOF {
break
}
if nextErr != nil {
return nil, fmt.Errorf("%w: reading entry: %v", ErrCorruptArchive, nextErr)
}
// Only regular files are staged; directories, links and other entry
// types are ignored.
if nextHdr.Typeflag != tar.TypeReg {
continue
}
clean := filepath.Clean(nextHdr.Name)
if strings.HasPrefix(clean, "..") || filepath.IsAbs(clean) {
return nil, fmt.Errorf("%w: unsafe entry name %q", ErrCorruptArchive, nextHdr.Name)
}
dst := filepath.Join(stage, clean)
// Defence in depth against zip-slip: confirm lexically that the joined
// path is still inside the staging dir. filepath.Rel does not resolve
// symlinks; only regular-file entries are written into the stage.
rel, relErr := filepath.Rel(stage, dst)
if relErr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) || filepath.IsAbs(rel) {
return nil, fmt.Errorf("%w: unsafe entry name %q", ErrCorruptArchive, nextHdr.Name)
}
if mkErr := os.MkdirAll(filepath.Dir(dst), 0700); mkErr != nil {
return nil, fmt.Errorf("staging dir: %w", mkErr)
}
// #nosec G304 -- dst is filepath.Join(stage, clean) where stage is a freshly created MkdirTemp under StatePath's parent and clean has been validated three lines above (no ".." prefix, not absolute). Cannot escape the staging dir.
f, openErr := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if openErr != nil {
return nil, fmt.Errorf("staging file: %w", openErr)
}
// CopyN copies exactly the nextHdr.Size bytes the header declares and
// fails if the entry is shorter.
// NOTE(review): Size is read from the archive header, so this bounds the
// copy to the declared size rather than to an absolute limit.
if _, copyErr := io.CopyN(f, tr, nextHdr.Size); copyErr != nil {
_ = f.Close()
return nil, fmt.Errorf("staging copy %s: %w", clean, copyErr)
}
if closeErr := f.Close(); closeErr != nil {
return nil, fmt.Errorf("closing staged file: %w", closeErr)
}
// Classify the staged entry; anything outside the three known names is
// staged but never applied.
switch {
case clean == bboltSnapshotEntry:
stagedBbolt = dst
case strings.HasPrefix(clean, stateEntryPrefix):
stagedState = append(stagedState, clean)
case strings.HasPrefix(clean, rulesEntryPrefix):
stagedRules = append(stagedRules, clean)
}
}
if (only == "all" || only == "firewall") && stagedBbolt == "" {
return nil, fmt.Errorf("%w: %s import needs bbolt snapshot in archive", ErrCorruptArchive, only)
}
// Verify the staged bbolt snapshot against the hash the manifest recorded
// at export time before any bbolt-consuming import applies payloads.
// zstd's frame CRC catches transport corruption, but a snapshot that was
// altered before compression would otherwise be promoted over csm.db
// unchecked. Empty hash means a pre-hash archive; skip rather than reject.
if (only == "all" || only == "firewall") && stagedBbolt != "" && man.BboltSHA256 != "" {
gotHash, hashErr := sha256File(stagedBbolt)
if hashErr != nil {
return nil, fmt.Errorf("hashing staged bbolt snapshot: %w", hashErr)
}
if gotHash != man.BboltSHA256 {
return nil, fmt.Errorf("%w: bbolt snapshot hash mismatch (archive manifest %s, staged %s)", ErrCorruptArchive, man.BboltSHA256, gotHash)
}
}
res := &ImportResult{Manifest: man}
// Apply state files for only=all and only=baseline; only=firewall leaves
// the state directory's files untouched.
if only == "all" || only == "baseline" {
if err := os.MkdirAll(opts.StatePath, 0700); err != nil {
return nil, fmt.Errorf("creating state path: %w", err)
}
for _, rel := range stagedState {
src := filepath.Join(stage, rel)
dst := filepath.Join(opts.StatePath, strings.TrimPrefix(rel, stateEntryPrefix))
if err := atomicReplace(src, dst); err != nil {
return nil, fmt.Errorf("restoring %s: %w", rel, err)
}
res.StateFiles++
}
}
// Apply rules files (only=all only; baseline and firewall skip rules).
if only == "all" && opts.RulesPath != "" {
if err := os.MkdirAll(opts.RulesPath, 0700); err != nil {
return nil, fmt.Errorf("creating rules path: %w", err)
}
for _, rel := range stagedRules {
src := filepath.Join(stage, rel)
dst := filepath.Join(opts.RulesPath, strings.TrimPrefix(rel, rulesEntryPrefix))
if err := atomicReplace(src, dst); err != nil {
return nil, fmt.Errorf("restoring %s: %w", rel, err)
}
res.RulesFiles++
}
}
// Apply bbolt:
// only=all -> wholesale rename the snapshot over csm.db
// only=firewall -> open snapshot read-only and copy fw:* buckets
// into the target bbolt
// only=baseline -> skip bbolt entirely
switch only {
case "all":
target := filepath.Join(opts.StatePath, "csm.db")
if err := atomicReplace(stagedBbolt, target); err != nil {
return nil, fmt.Errorf("restoring csm.db: %w", err)
}
// Report every bucket that came from the snapshot.
res.BucketsRestored = append([]string(nil), man.BboltBuckets...)
case "firewall":
restored, err := mergeBucketsFromSnapshot(stagedBbolt, opts.StatePath, isFirewallBucket)
if err != nil {
return nil, fmt.Errorf("merging firewall buckets: %w", err)
}
res.BucketsRestored = restored
case "baseline":
// no bbolt work
}
return res, nil
}
// platformMatches reports whether the archive's platform map a is compatible
// with the current host's map b. Only the "os", "panel" and "webserver" keys
// are compared, and a key missing from both sides counts as equal. If either
// map is empty (or nil) the check passes unconditionally, which covers
// callers that supplied no detection results.
func platformMatches(a, b map[string]string) bool {
	if len(a) == 0 || len(b) == 0 {
		return true
	}
	return a["os"] == b["os"] &&
		a["panel"] == b["panel"] &&
		a["webserver"] == b["webserver"]
}
// listBuckets enumerates the top-level buckets that actually exist in the
// open database, sorted by name. It deliberately reads the live database
// rather than the static bucketNames slice, since migrations may have
// removed buckets. The result is never nil; a failed read transaction is
// ignored and yields whatever was collected.
func listBuckets(db *DB) []string {
	names := make([]string, 0)
	collect := func(bucketName []byte, _ *bolt.Bucket) error {
		names = append(names, string(bucketName))
		return nil
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.ForEach(collect)
	})
	sort.Strings(names)
	return names
}
// writeTarFile adds an in-memory payload to the tar as a single entry called
// name, with mode 0600, the given modification time and a size equal to
// len(data). Header and body failures are returned wrapped with the entry
// name for context.
func writeTarFile(tw *tar.Writer, name string, data []byte, modTime time.Time) error {
	header := tar.Header{
		Name:    name,
		Mode:    0600,
		Size:    int64(len(data)),
		ModTime: modTime,
	}
	err := tw.WriteHeader(&header)
	if err != nil {
		return fmt.Errorf("tar header %s: %w", name, err)
	}
	_, err = tw.Write(data)
	if err != nil {
		return fmt.Errorf("tar write %s: %w", name, err)
	}
	return nil
}
// sha256File returns the lowercase hex SHA-256 of the file at path, matching
// the encoding Export records in Manifest.BboltSHA256. The file is streamed
// through the hash, so it is never held in memory as a whole.
func sha256File(path string) (string, error) {
	// #nosec G304 -- path is the staged bbolt snapshot inside the import
	// staging dir (MkdirTemp under StatePath), not attacker-controlled.
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	digest := sha256.New()
	if _, err := io.Copy(digest, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}

// streamFileToTar copies srcPath into the tar under name, with mode 0600 and
// the given modification time. The entry size is taken from a Stat on the
// already-open file. Open, stat, header and copy failures are returned
// wrapped with context.
//
// There is a theoretical Stat-then-Copy race when the source file is
// replaced between f.Stat() and io.Copy() -- the tar writer would then see a
// mismatched byte count. In practice the only files this function reads on a
// live daemon are state JSON written via atomic-replace (so a mid-write read
// sees either the old or the new full content) and the bbolt snapshot
// (already a frozen copy on disk). The risk is bounded enough to live with
// for v1.
func streamFileToTar(tw *tar.Writer, name, srcPath string, modTime time.Time) error {
	// #nosec G304 -- srcPath is supplied by Export's caller (the daemon's control handler, sourced from cfg.StatePath / cfg.Signatures.RulesDir) or is the bbolt snapshot path created via os.CreateTemp earlier in Export. Both are root-controlled; csm runs as root via systemd. Not attacker-controlled.
	f, err := os.Open(srcPath)
	if err != nil {
		return fmt.Errorf("open %s: %w", srcPath, err)
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return fmt.Errorf("stat %s: %w", srcPath, err)
	}

	hdr := &tar.Header{
		Name:    name,
		Mode:    0600,
		Size:    info.Size(),
		ModTime: modTime,
	}
	if err := tw.WriteHeader(hdr); err != nil {
		return fmt.Errorf("tar header %s: %w", name, err)
	}
	if _, err := io.Copy(tw, f); err != nil {
		return fmt.Errorf("tar copy %s: %w", name, err)
	}
	return nil
}
// walkDirIntoTar streams every non-directory entry under srcDir into the tar
// under entryPrefix, skipping any base names in the skip list. Entry names
// are entryPrefix plus the slash-separated path relative to srcDir. It
// returns how many files were written.
//
// A srcDir that does not exist is not an error: the result is (0, nil).
// An entry that disappears between the directory listing and the lstat
// (for example a temp file of an atomic-replace writer on a live daemon) is
// skipped and the walk continues. Previously such a not-exist error aborted
// the walk and was then swallowed, so the archive silently lacked every file
// after the vanished one. All other errors abort the walk and are returned
// together with the count written so far.
func walkDirIntoTar(tw *tar.Writer, srcDir, entryPrefix string, skip []string, modTime time.Time) (int, error) {
	skipSet := make(map[string]bool, len(skip))
	for _, s := range skip {
		skipSet[s] = true
	}

	count := 0
	err := filepath.Walk(srcDir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			// info may be nil here, so decide before touching it. Returning
			// nil tells Walk to carry on with the remaining entries.
			if os.IsNotExist(walkErr) {
				return nil
			}
			return walkErr
		}
		if info.IsDir() || skipSet[info.Name()] {
			return nil
		}
		rel, err := filepath.Rel(srcDir, path)
		if err != nil {
			return err
		}
		if err := streamFileToTar(tw, entryPrefix+filepath.ToSlash(rel), path, modTime); err != nil {
			return err
		}
		count++
		return nil
	})
	if err != nil {
		return count, err
	}
	return count, nil
}
// atomicReplace moves src over dst, creating dst's parent directory (mode
// 0700) if needed, and fsyncs that directory so the rename survives a crash.
//
// os.Rename cannot cross filesystem boundaries. Import stages files next to
// StatePath but restores rules into an independently configured RulesPath,
// which may live on another filesystem. When the direct rename fails, the
// file is therefore copied into a temp file inside dst's directory, synced,
// and renamed into place from there, so the final switch is still a
// same-directory rename. If the fallback fails as well, the returned error
// wraps the original rename error (usable with errors.Is / errors.As) and
// mentions the fallback failure.
func atomicReplace(src, dst string) error {
	dir := filepath.Dir(dst)
	if err := os.MkdirAll(dir, 0700); err != nil {
		return err
	}
	if renameErr := os.Rename(src, dst); renameErr != nil {
		if copyErr := copyThenRenameInto(src, dst); copyErr != nil {
			return fmt.Errorf("%w (copy fallback: %v)", renameErr, copyErr)
		}
	}
	parent, err := os.Open(dir)
	if err != nil {
		return err
	}
	defer parent.Close()
	return parent.Sync()
}

// copyThenRenameInto copies src into a temp file in dst's directory, syncs
// it, renames it over dst, and finally removes src (best effort). The temp
// file is removed on every failure path.
func copyThenRenameInto(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	tmp, err := os.CreateTemp(filepath.Dir(dst), ".csm-replace-*")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	defer func() {
		if err != nil {
			_ = os.Remove(tmpPath)
		}
	}()

	if _, err = io.Copy(tmp, in); err != nil {
		_ = tmp.Close()
		return err
	}
	if err = tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err = tmp.Close(); err != nil {
		return err
	}
	if err = os.Rename(tmpPath, dst); err != nil {
		return err
	}
	// Mirror rename semantics: the source no longer exists afterwards.
	_ = os.Remove(src)
	return nil
}
// mergeBucketsFromSnapshot copies every top-level bucket of the snapshot
// database whose name satisfies match into the store at statePath/csm.db.
// The snapshot is opened read-only; the target is opened through Open, which
// creates it when absent. Each matching bucket is written in its own update
// transaction: the bucket is created if missing and every key/value pair is
// Put, overwriting existing keys and leaving other target keys alone.
//
// It returns the names of the buckets merged, sorted on success. If a copy
// fails, the names merged so far are returned (unsorted) with the error.
func mergeBucketsFromSnapshot(snapshotPath, statePath string, match func(string) bool) ([]string, error) {
	snapshot, err := bolt.Open(snapshotPath, 0600, &bolt.Options{Timeout: 5 * time.Second, ReadOnly: true})
	if err != nil {
		return nil, fmt.Errorf("opening snapshot: %w", err)
	}
	defer func() { _ = snapshot.Close() }()

	target, err := Open(statePath)
	if err != nil {
		return nil, fmt.Errorf("opening target: %w", err)
	}
	defer func() { _ = target.Close() }()

	// copyBucket writes one snapshot bucket into the target inside a single
	// update transaction. Keys and values are cloned because the slices
	// handed out by the snapshot's ForEach belong to its read transaction.
	copyBucket := func(name []byte, from *bolt.Bucket) error {
		return target.bolt.Update(func(wtx *bolt.Tx) error {
			to, createErr := wtx.CreateBucketIfNotExists(name)
			if createErr != nil {
				return createErr
			}
			return from.ForEach(func(k, v []byte) error {
				return to.Put(append([]byte(nil), k...), append([]byte(nil), v...))
			})
		})
	}

	copied := []string{}
	err = snapshot.View(func(rtx *bolt.Tx) error {
		return rtx.ForEach(func(name []byte, from *bolt.Bucket) error {
			bucketName := string(name)
			if !match(bucketName) {
				return nil
			}
			if copyErr := copyBucket(name, from); copyErr != nil {
				return copyErr
			}
			copied = append(copied, bucketName)
			return nil
		})
	})
	if err != nil {
		return copied, err
	}
	sort.Strings(copied)
	return copied, nil
}
// isFirewallBucket reports whether a bucket name starts with the "fw:"
// prefix; Import uses it to select the buckets merged by only=firewall.
func isFirewallBucket(name string) bool {
	const firewallPrefix = "fw:"
	return strings.HasPrefix(name, firewallPrefix)
}
package store
import (
"bytes"
"encoding/json"
"fmt"
"time"
bolt "go.etcd.io/bbolt"
)
// maxAttackEvents is the maximum number of attack events to retain.
// RecordAttackEvent deletes the oldest events once the stored counter
// exceeds this value. It is a var (not const) so tests can override it.
var maxAttackEvents = 100_000
// AttackEvent is the store-layer representation of an attack event. It is
// stored as JSON in attacks:events (keyed by TimeKey) and again in
// attacks:events:ip (keyed by "<IP>/<TimeKey>").
type AttackEvent struct {
Timestamp time.Time `json:"timestamp"` // feeds TimeKey for the primary key
IP string `json:"ip"` // prefix of the secondary-index key
AttackType string `json:"attack_type"`
CheckName string `json:"check_name"`
Severity int `json:"severity"`
Account string `json:"account,omitempty"`
Message string `json:"message,omitempty"`
}
// IPRecord is the store-layer representation of an IP attack record. It is
// stored as JSON in the attacks:records bucket, keyed by the IP field. The
// store only serialises these fields; their values are maintained by the
// callers of SaveIPRecord.
type IPRecord struct {
IP string `json:"ip"` // bucket key
FirstSeen time.Time `json:"first_seen"`
LastSeen time.Time `json:"last_seen"`
EventCount int `json:"event_count"`
AttackCounts map[string]int `json:"attack_counts,omitempty"`
Accounts map[string]int `json:"accounts,omitempty"`
ThreatScore int `json:"threat_score"`
AutoBlocked bool `json:"auto_blocked,omitempty"`
BruteForceWindowStart time.Time `json:"brute_force_window_start,omitempty"`
BruteForceWindowCount int `json:"brute_force_window_count,omitempty"`
BruteForceSustainedAt time.Time `json:"brute_force_sustained_at,omitempty"`
}
// RecordAttackEvent inserts an attack event into both the primary bucket
// (attacks:events, keyed by TimeKey) and the secondary index bucket
// (attacks:events:ip, keyed by IP/TimeKey). It increments the event counter
// and prunes oldest entries if the count exceeds maxAttackEvents.
//
// Everything happens in one update transaction, so a returned error means
// nothing was written. counter is passed to TimeKey together with the event
// timestamp to build the key.
//
// During pruning, an old event whose JSON cannot be decoded is still removed
// from the primary bucket; only its secondary-index entry is left behind,
// because the IP needed to build that key is unknown. Failing instead would
// roll back the transaction and block every later insert for as long as the
// corrupt entry stayed at the head of the bucket.
func (db *DB) RecordAttackEvent(event AttackEvent, counter int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		primary := tx.Bucket([]byte("attacks:events"))
		secondary := tx.Bucket([]byte("attacks:events:ip"))

		key := TimeKey(event.Timestamp, counter)
		val, err := json.Marshal(event)
		if err != nil {
			return err
		}
		if err := primary.Put([]byte(key), val); err != nil {
			return err
		}
		secondaryKey := event.IP + "/" + key
		if err := secondary.Put([]byte(secondaryKey), val); err != nil {
			return err
		}
		if err := incrCounter(tx, "attacks:events:count", 1); err != nil {
			return err
		}

		// Prune oldest entries if count exceeds maxAttackEvents.
		meta := tx.Bucket([]byte("meta"))
		var count int
		if v := meta.Get([]byte("attacks:events:count")); v != nil {
			_, _ = fmt.Sscanf(string(v), "%d", &count)
		}
		if count <= maxAttackEvents {
			return nil
		}

		excess := count - maxAttackEvents
		c := primary.Cursor()
		for k, v := c.First(); k != nil && excess > 0; excess-- {
			// Decode to learn the IP for secondary index cleanup. A corrupt
			// entry is dropped from the primary bucket regardless.
			var ev AttackEvent
			if err := json.Unmarshal(v, &ev); err == nil {
				secKey := ev.IP + "/" + string(k)
				if err := secondary.Delete([]byte(secKey)); err != nil {
					return err
				}
			}
			if err := c.Delete(); err != nil {
				return err
			}
			// Re-seek after delete (bbolt cursor behavior).
			k, v = c.First()
		}
		return setCounter(tx, "attacks:events:count", maxAttackEvents)
	})
}
// QueryAttackEvents returns up to limit attack events for the given IP,
// newest-first. It uses the secondary index bucket (attacks:events:ip) and
// iterates only the keys that start with "<ip>/". Entries that cannot be
// decoded are skipped.
//
// A negative limit is treated as zero. Previously it reached the slice
// expression below and panicked whenever the IP had at least one event.
func (db *DB) QueryAttackEvents(ip string, limit int) []AttackEvent {
	if limit < 0 {
		limit = 0
	}

	var results []AttackEvent
	prefix := []byte(ip + "/")
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("attacks:events:ip"))
		c := b.Cursor()
		// Collect all events matching the prefix, in ascending key order.
		for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
			var ev AttackEvent
			if err := json.Unmarshal(v, &ev); err == nil {
				results = append(results, ev)
			}
		}
		return nil
	})

	// Reverse for newest-first order.
	for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 {
		results[i], results[j] = results[j], results[i]
	}
	// Take up to limit.
	if len(results) > limit {
		results = results[:limit]
	}
	return results
}
// SaveIPRecord writes record, JSON-encoded, into the attacks:records bucket
// under the key record.IP, replacing any existing entry for that IP. It
// returns the encoding or transaction error, if any.
func (db *DB) SaveIPRecord(record IPRecord) error {
	save := func(tx *bolt.Tx) error {
		records := tx.Bucket([]byte("attacks:records"))
		encoded, err := json.Marshal(record)
		if err != nil {
			return err
		}
		return records.Put([]byte(record.IP), encoded)
	}
	return db.bolt.Update(save)
}
// LoadIPRecord looks up ip in the attacks:records bucket. The boolean is
// true only when an entry exists and its JSON decodes cleanly; a missing or
// corrupt entry reports false.
func (db *DB) LoadIPRecord(ip string) (IPRecord, bool) {
	var (
		rec IPRecord
		ok  bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("attacks:records")).Get([]byte(ip))
		// A corrupt entry is treated the same as a missing one.
		if raw != nil && json.Unmarshal(raw, &rec) == nil {
			ok = true
		}
		return nil
	})
	return rec, ok
}
// LoadAllIPRecords decodes every entry of the attacks:records bucket into a
// map keyed by the bucket key (the IP). Entries that fail to decode are left
// out. The returned map is never nil.
func (db *DB) LoadAllIPRecords() map[string]*IPRecord {
	all := map[string]*IPRecord{}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("attacks:records")).ForEach(func(ipKey, raw []byte) error {
			var rec IPRecord
			if json.Unmarshal(raw, &rec) == nil {
				all[string(ipKey)] = &rec
			}
			// Corrupt entries are skipped rather than aborting the scan.
			return nil
		})
	})
	return all
}
// DeleteIPRecord deletes the entry keyed by ip from the attacks:records
// bucket inside an update transaction and returns the transaction's error.
func (db *DB) DeleteIPRecord(ip string) error {
	remove := func(tx *bolt.Tx) error {
		records := tx.Bucket([]byte("attacks:records"))
		return records.Delete([]byte(ip))
	}
	return db.bolt.Update(remove)
}
// ReadAllAttackEvents loads every decodable event from the primary
// attacks:events bucket in key order. Entries that fail to decode are
// skipped. Used for stats computation (hourly/daily bucketing).
func (db *DB) ReadAllAttackEvents() []AttackEvent {
	var all []AttackEvent
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		collect := func(_, raw []byte) error {
			var item AttackEvent
			if json.Unmarshal(raw, &item) == nil {
				all = append(all, item)
			}
			// Corrupt entries are skipped rather than aborting the scan.
			return nil
		}
		return tx.Bucket([]byte("attacks:events")).ForEach(collect)
	})
	return all
}
package store
import (
"encoding/binary"
"net"
"time"
bolt "go.etcd.io/bbolt"
)
// botVerifyEntry layout (the value stored under each key):
//
//	byte 0      verified flag (1 byte)
//	bytes 1..8  expiry as unix nanos (8 bytes big-endian)
//
// Key format:
//
//	<bot-bytes> 0x00 <ip-bytes>
//
// The IP is always written in its 16-byte form (IPv4 promoted via To16).
// The 0x00 separator lets the bot name be any string without a NUL byte
// while keeping keys byte-sortable within a single bucket.
//
// botVerifyKey builds that key. It returns nil when ip is not a valid IPv4
// or IPv6 address (To16 yields nil).
func botVerifyKey(ip net.IP, bot string) []byte {
	ip16 := ip.To16()
	if ip16 == nil {
		return nil
	}
	key := make([]byte, len(bot)+1+len(ip16))
	n := copy(key, bot)
	key[n] = 0x00
	copy(key[n+1:], ip16)
	return key
}
// PutBotVerify records the outcome of a PTR+forward-A verification for
// (ip, bot) together with an explicit expiry. A verified=false entry means
// the IP failed rDNS and will emit http_ua_spoof on the next scan that sees
// it with the same bot UA. The value is 9 bytes: the verified flag followed
// by the expiry in unix nanoseconds, big-endian. An ip that botVerifyKey
// rejects is silently ignored (nil error, nothing written). The botverify
// bucket is created on demand.
func (db *DB) PutBotVerify(ip net.IP, bot string, verified bool, expiresAt time.Time) error {
	key := botVerifyKey(ip, bot)
	if key == nil {
		return nil
	}

	entry := make([]byte, 9)
	if verified {
		entry[0] = 1
	}
	expiryNanos := uint64(expiresAt.UnixNano()) // #nosec G115 -- unix nano stored as bit pattern; sign is irrelevant for expiry comparison
	binary.BigEndian.PutUint64(entry[1:], expiryNanos)

	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket, err := tx.CreateBucketIfNotExists([]byte("botverify"))
		if err != nil {
			return err
		}
		return bucket.Put(key, entry)
	})
}
// EnsureBotVerifyLogicVersion compares version with the cache logic version
// stored in the "meta" bucket under botverify:logic_version (8 bytes,
// big-endian). When they differ, or when no valid marker exists yet, it
// drops and recreates the botverify bucket and records the new version, all
// in one transaction. It returns true when the bucket was dropped.
//
// Call this at daemon startup so that any change to the verifier logic
// (BotDomains suffix list, ClaimedBotFromUA mapping, etc.) automatically
// invalidates entries written under the old rules; operators do not need to
// know the cache exists.
func (db *DB) EnsureBotVerifyLogicVersion(version int) (bool, error) {
	const markerKey = "botverify:logic_version"
	cacheBucket := []byte("botverify")
	want := uint64(version) // #nosec G115 -- logic version is a small positive internal constant

	dropped := false
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		meta, err := tx.CreateBucketIfNotExists([]byte("meta"))
		if err != nil {
			return err
		}

		// All-ones stands in for "no marker stored".
		have := ^uint64(0)
		if raw := meta.Get([]byte(markerKey)); len(raw) == 8 {
			have = binary.BigEndian.Uint64(raw)
		}
		if have == want {
			return nil
		}

		if tx.Bucket(cacheBucket) != nil {
			if err := tx.DeleteBucket(cacheBucket); err != nil {
				return err
			}
		}
		if _, err := tx.CreateBucket(cacheBucket); err != nil {
			return err
		}

		marker := make([]byte, 8)
		binary.BigEndian.PutUint64(marker, want)
		if err := meta.Put([]byte(markerKey), marker); err != nil {
			return err
		}
		dropped = true
		return nil
	})
	if err != nil {
		return false, err
	}
	return dropped, nil
}
// ResetBotVerify empties the PTR+forward-A verification cache by deleting
// and recreating the botverify bucket, and returns how many entries it held
// (taken from the bucket's KeyN statistic before deletion). Use it after a
// verifier-logic upgrade that would invalidate prior negative cache entries.
// A missing bucket is not an error and reports zero; on a transaction error
// the count is reported as zero as well.
func (db *DB) ResetBotVerify() (int, error) {
	name := []byte("botverify")
	removed := 0
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket(name)
		if bucket == nil {
			return nil
		}
		removed = bucket.Stats().KeyN
		if err := tx.DeleteBucket(name); err != nil {
			return err
		}
		_, err := tx.CreateBucket(name)
		return err
	})
	if err != nil {
		return 0, err
	}
	return removed, nil
}
// GetBotVerify looks up the cached verification result for (ip, bot) and
// returns (verified, valid). valid is false when the key cannot be built,
// the bucket or entry is missing, the stored value is not exactly 9 bytes,
// or the entry's expiry lies in the past; the caller should then treat the
// IP as unverified and (optionally) enqueue an async verify job. verified is
// only meaningful when valid is true.
func (db *DB) GetBotVerify(ip net.IP, bot string) (verified, valid bool) {
	key := botVerifyKey(ip, bot)
	if key == nil {
		return false, false
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("botverify"))
		if bucket == nil {
			return nil
		}
		entry := bucket.Get(key)
		if len(entry) != 9 {
			return nil
		}
		expiry := time.Unix(0, int64(binary.BigEndian.Uint64(entry[1:]))) // #nosec G115 -- reinterpret stored bit pattern as signed nanos
		if !time.Now().After(expiry) {
			verified, valid = entry[0] == 1, true
		}
		return nil
	})
	return verified, valid
}
package store
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
bolt "go.etcd.io/bbolt"
)
// Bucket names - all buckets are created on Open(), which calls
// CreateBucketIfNotExists for every entry in this slice. The identifiers
// bucketStatsDaily, adminEmailsBucket and prefsBucket are declared elsewhere
// in the package.
var bucketNames = []string{
"history",
"attacks:records",
"attacks:events",
"attacks:events:ip",
"threats",
"threats:whitelist",
"fw:blocked",
"fw:allowed",
"fw:subnets",
"fw:port_allowed",
"reputation",
"plugins",
"plugins:sites",
"meta",
"email:geo",
"email:fwd",
"db_object_backups",
"sig_watch",
bucketStatsDaily,
"phprelay:meta",
"phprelay:msgindex",
"phprelay:ignore",
"phprelay:settings",
"incidents",
"fw:rollback",
adminEmailsBucket,
"botverify",
prefsBucket,
}
// DB wraps a bbolt database.
type DB struct {
bolt *bolt.DB // underlying bbolt handle used by every store method
path string // full path of the database file ({statePath}/csm.db)
}
// Package-level singleton state.
var (
globalDB *DB // process-wide instance returned by Global
globalMu sync.Mutex // guards globalDB in Global and SetGlobal
ensureOnce sync.Once // makes EnsureOpen open the store at most once
ensureErr error // result of that single open attempt, returned by every EnsureOpen call
)
// Global returns the process-wide DB instance, or nil when none has been
// installed via SetGlobal. The read is performed under globalMu.
func Global() *DB {
	globalMu.Lock()
	current := globalDB
	globalMu.Unlock()
	return current
}
// SetGlobal installs db as the process-wide DB instance, replacing any
// previous one. The write is performed under globalMu.
func SetGlobal(db *DB) {
	globalMu.Lock()
	defer globalMu.Unlock()
	globalDB = db
}
// EnsureOpen opens the store at statePath and installs it as the global
// instance. Safe to call from any CLI path: only the first call does work.
//
// The outcome of that first attempt is remembered, so a failure is returned
// again by every later call without retrying.
func EnsureOpen(statePath string) error {
	ensureOnce.Do(func() {
		opened, err := Open(statePath)
		if err == nil {
			SetGlobal(opened)
			return
		}
		ensureErr = err
	})
	return ensureErr
}
// Open opens (or creates) the bbolt database at {statePath}/csm.db.
//
// Steps, in order:
//  1. create statePath (mode 0700) and open the file (mode 0600, 5s lock timeout);
//  2. create every bucket in bucketNames and initialise the phprelay
//     schema_version if it is absent;
//  3. run the store migration, failing the open if it errors;
//  4. run the stats:daily backfill and the ModSecurity no-escalate seed,
//     whose failures are only reported on stderr.
//
// On any returned error the bbolt handle has already been closed.
func Open(statePath string) (*DB, error) {
	if err := os.MkdirAll(statePath, 0700); err != nil {
		return nil, fmt.Errorf("creating state dir: %w", err)
	}

	dbPath := filepath.Join(statePath, "csm.db")
	handle, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 5 * time.Second})
	if err != nil {
		return nil, fmt.Errorf("opening bbolt: %w", err)
	}

	initSchema := func(tx *bolt.Tx) error {
		for _, name := range bucketNames {
			if _, berr := tx.CreateBucketIfNotExists([]byte(name)); berr != nil {
				return fmt.Errorf("creating bucket %s: %w", name, berr)
			}
		}

		// Initialise phprelay schema_version on first open. Stored as an
		// 8-byte big-endian uint64 so future migrations can read/compare
		// it consistently.
		versionKey := []byte("schema_version")
		relayMeta := tx.Bucket([]byte("phprelay:meta"))
		if relayMeta.Get(versionKey) != nil {
			return nil
		}
		if perr := relayMeta.Put(versionKey, []byte{0, 0, 0, 0, 0, 0, 0, 1}); perr != nil {
			return fmt.Errorf("init phprelay schema_version: %w", perr)
		}
		return nil
	}
	if err = handle.Update(initSchema); err != nil {
		_ = handle.Close()
		return nil, err
	}

	db := &DB{bolt: handle, path: dbPath}

	// A failed migration does not set the "migrated" sentinel, so carrying
	// on would boot the daemon on partial security state and retry the same
	// broken migration on every restart. Fail loud instead.
	if merr := db.migrateIfNeeded(statePath); merr != nil {
		_ = handle.Close()
		return nil, fmt.Errorf("store migration: %w", merr)
	}

	// One-time backfill of stats:daily from existing history, for hosts
	// upgrading from a build that pre-dates the stats:daily bucket; a meta
	// sentinel turns it into a no-op afterwards. Non-fatal.
	if berr := db.BackfillStatsDaily(); berr != nil {
		fmt.Fprintf(os.Stderr, "store: stats:daily backfill warning: %v\n", berr)
	}

	// Seeding the default no-escalate list is likewise non-fatal.
	if serr := db.seedDefaultModSecNoEscalateRules(); serr != nil {
		fmt.Fprintf(os.Stderr, "store: ModSecurity no-escalate seed warning: %v\n", serr)
	}

	return db, nil
}
const (
// modsecNoEscalateSeededKey is the meta-bucket sentinel written once the
// default no-escalate list has been seeded (or found to exist already).
modsecNoEscalateSeededKey = "modsec:no_escalate_seeded"
// defaultModSecNoEscalateWPEnumerationID is the single rule ID placed in
// the no-escalate list when seedDefaultModSecNoEscalateRules seeds it.
defaultModSecNoEscalateWPEnumerationID = 900112
)
// seedDefaultModSecNoEscalateRules writes the default ModSecurity
// no-escalate rule list into the meta bucket exactly once.
//
// If the seeded sentinel is present nothing happens. If a rule list is
// already stored it is left untouched and only the sentinel is written.
// Otherwise the default list is stored and then the sentinel is written,
// all within a single transaction.
func (db *DB) seedDefaultModSecNoEscalateRules() error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))

		seededKey := []byte(modsecNoEscalateSeededKey)
		if meta.Get(seededKey) != nil {
			return nil
		}

		rulesKey := []byte(modsecNoEscalateKey)
		if meta.Get(rulesKey) == nil {
			// WordPress user enumeration is blocked at the HTTP layer only.
			encoded, err := json.Marshal([]int{defaultModSecNoEscalateWPEnumerationID})
			if err != nil {
				return err
			}
			if err := meta.Put(rulesKey, encoded); err != nil {
				return err
			}
		}
		return meta.Put(seededKey, []byte("1"))
	})
}
// Close closes the underlying bbolt database.
//
// It is a no-op returning nil when db itself is nil or was never opened,
// matching the nil-receiver tolerance of IsHealthy and SizeBytes so that a
// deferred Close on a store that failed to open cannot panic.
func (db *DB) Close() error {
	if db == nil || db.bolt == nil {
		return nil
	}
	return db.bolt.Close()
}
// Path returns the on-disk path of the bbolt database file, exactly as
// recorded when the DB value was constructed.
func (db *DB) Path() string {
return db.path
}
// HasBucket reports whether a top-level bucket called name exists in db.
// A failure to start the read transaction is reported as "not found".
func (db *DB) HasBucket(name string) bool {
	var exists bool
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		exists = tx.Bucket([]byte(name)) != nil
		return nil
	})
	return exists
}
// TimeKey renders t and counter as a sortable history key.
//
// Layout: YYYYMMDDHHmmssNNNNNNNNN-CCCC (28 bytes for four-digit years and
// counters in 0..9999; the widths are minimums, so larger values lengthen
// the key). The calendar fields are taken from t as-is, in t's own
// location. For keys built from the same location, lexicographic order
// matches chronological order.
func TimeKey(t time.Time, counter int) string {
	year, month, day := t.Date()
	hour, minute, second := t.Clock()
	return fmt.Sprintf(
		"%04d%02d%02d%02d%02d%02d%09d-%04d",
		year, int(month), day,
		hour, minute, second,
		t.Nanosecond(), counter,
	)
}
// ParseTimeKeyPrefix turns a "YYYY-MM-DD" date into the "YYYYMMDD" prefix
// used for seeking TimeKey-ordered buckets. Only the length and the two
// dash positions are checked; any input that does not match that shape is
// returned unchanged.
func ParseTimeKeyPrefix(date string) string {
	if len(date) != 10 || date[4] != '-' || date[7] != '-' {
		return date
	}
	year, month, day := date[0:4], date[5:7], date[8:10]
	return year + month + day
}
// HistoryCount returns the number of findings in the history bucket, as
// tracked by the "history:count" counter in the meta bucket (it does not
// walk the history bucket itself).
func (db *DB) HistoryCount() int {
return db.getCounter("history:count")
}
// getCounter reads the decimal counter stored under key in the meta bucket.
// It returns 0 when the key is absent, and whatever Sscanf managed to parse
// (0 for non-numeric data) when the stored value is malformed.
func (db *DB) getCounter(key string) int {
	var n int
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(key))
		if raw == nil {
			return nil
		}
		// Best effort: a parse failure simply leaves n at its current value.
		_, _ = fmt.Sscanf(string(raw), "%d", &n)
		return nil
	})
	return n
}
// setCounter stores count, formatted as a decimal string, under key in the
// meta bucket. It runs inside the caller's transaction.
func setCounter(tx *bolt.Tx, key string, count int) error {
	encoded := fmt.Sprintf("%d", count)
	return tx.Bucket([]byte("meta")).Put([]byte(key), []byte(encoded))
}
// IsHealthy reports whether the store is usable: the receiver and its bbolt
// handle are non-nil, a read transaction can be started, and the core
// buckets "history", "fw:blocked" and "meta" all exist.
func (db *DB) IsHealthy() bool {
	if db == nil || db.bolt == nil {
		return false
	}
	required := [...]string{"history", "fw:blocked", "meta"}
	checkErr := db.bolt.View(func(tx *bolt.Tx) error {
		for _, name := range required {
			if tx.Bucket([]byte(name)) == nil {
				return fmt.Errorf("bucket missing: %s", name)
			}
		}
		return nil
	})
	return checkErr == nil
}
// SizeBytes returns the size of the database file as reported by os.Stat.
// It returns 0 when the receiver is nil, no path is recorded, or the file
// cannot be stat'ed.
func (db *DB) SizeBytes() int64 {
	if db == nil || db.path == "" {
		return 0
	}
	if info, err := os.Stat(db.path); err == nil {
		return info.Size()
	}
	return 0
}
// incrCounter increments a counter within an existing transaction.
func incrCounter(tx *bolt.Tx, key string, delta int) error {
b := tx.Bucket([]byte("meta"))
var current int
if v := b.Get([]byte(key)); v != nil {
fmt.Sscanf(string(v), "%d", ¤t)
}
return b.Put([]byte(key), []byte(fmt.Sprintf("%d", current+delta)))
}
// dryRunBlockRecord is the JSON value stored for each entry in the
// "dry_run_blocks" bucket by RecordDryRunBlock.
type dryRunBlockRecord struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// TimeoutSec is the requested block duration truncated to whole seconds.
TimeoutSec int `json:"timeout_sec"`
}
// RecordDryRunBlock stores one "would have blocked" record in the
// "dry_run_blocks" bucket, creating the bucket on first use. Called by the
// firewall engine when auto_response.dry_run is active so operators can
// review what would have been blocked before going live.
//
// The key is "<RFC3339Nano UTC timestamp>:<ip>". The call is best effort:
// a nil/unopened store, a marshal failure, or a write failure is silently
// dropped.
func (db *DB) RecordDryRunBlock(ip, reason string, timeout time.Duration) {
	if db == nil || db.bolt == nil {
		return
	}

	// Log-derived reasons can carry raw control bytes; the JSON encoder
	// emits escape forms that dry-run readers can decode.
	val, err := json.Marshal(dryRunBlockRecord{
		IP:         ip,
		Reason:     reason,
		TimeoutSec: int(timeout.Seconds()),
	})
	if err != nil {
		return
	}

	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		bucket, berr := tx.CreateBucketIfNotExists([]byte("dry_run_blocks"))
		if berr != nil {
			return berr
		}
		stamp := time.Now().UTC().Format(time.RFC3339Nano)
		return bucket.Put([]byte(stamp+":"+ip), val)
	})
}
// PurgeAllDryRunBlocks drops the whole dry_run_blocks bucket and returns
// how many keys it held beforehand (0 when the store is nil/unopened or
// the bucket does not exist). Called when the operator flips
// auto_response.dry_run from true to false so /api/v1/status no longer
// reports a stale count from the previous dry-run window. Slow
// accumulation is handled separately by PurgeDryRunBlocksOlderThan.
func (db *DB) PurgeAllDryRunBlocks() int {
	if db == nil || db.bolt == nil {
		return 0
	}
	var removed int
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		name := []byte("dry_run_blocks")
		if bucket := tx.Bucket(name); bucket != nil {
			// Take the count before the bucket disappears.
			removed = bucket.Stats().KeyN
			return tx.DeleteBucket(name)
		}
		return nil
	})
	return removed
}
// PurgeDryRunBlocksOlderThan deletes every dry_run_blocks record whose key
// timestamp is strictly before cutoff and returns the number deleted.
//
// Keys have the form "<RFC3339Nano>:<ip>". A key without a "Z:" separator,
// or whose timestamp part does not parse, is left alone so that a future
// key-format change cannot silently drop records.
func (db *DB) PurgeDryRunBlocksOlderThan(cutoff time.Time) int {
	if db == nil || db.bolt == nil {
		return 0
	}
	removed := 0
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("dry_run_blocks"))
		if bucket == nil {
			return nil
		}

		// Pass 1: collect copies of the stale keys; deletion happens only
		// after the iteration has finished.
		var stale [][]byte
		_ = bucket.ForEach(func(k, _ []byte) error {
			key := string(k)
			// RecordDryRunBlock writes UTC timestamps, which always end
			// in `Z`, so the first `Z:` separates timestamp from IP. A
			// plain first-colon split would not work: the timestamp has
			// colons and so can a v6 address.
			sep := strings.Index(key, "Z:")
			if sep < 0 {
				return nil
			}
			stamp, perr := time.Parse(time.RFC3339Nano, key[:sep+1])
			if perr != nil {
				// Forward-compat: keep unrecognised keys rather than
				// treating them as stale during a rolling upgrade.
				return nil //nolint:nilerr // intentional skip
			}
			if stamp.Before(cutoff) {
				stale = append(stale, append([]byte(nil), k...))
			}
			return nil
		})

		// Pass 2: delete, counting only the successful removals.
		for _, k := range stale {
			if bucket.Delete(k) == nil {
				removed++
			}
		}
		return nil
	})
	return removed
}
// DryRunBlocksCount returns how many keys the dry_run_blocks bucket holds,
// or 0 when the store is nil/unopened or the bucket does not exist.
func (db *DB) DryRunBlocksCount() int {
	if db == nil || db.bolt == nil {
		return 0
	}
	var total int
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("dry_run_blocks"))
		if bucket == nil {
			return nil
		}
		total = bucket.Stats().KeyN
		return nil
	})
	return total
}
// migrateIfNeeded runs the store migration unless the "migrated" key is
// already present in the meta bucket. It returns runMigration's result, or
// nil when no migration was needed.
func (db *DB) migrateIfNeeded(statePath string) error {
	alreadyMigrated := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		alreadyMigrated = tx.Bucket([]byte("meta")).Get([]byte("migrated")) != nil
		return nil
	})
	if !alreadyMigrated {
		return db.runMigration(statePath)
	}
	return nil
}
package store
import (
"encoding/json"
"fmt"
"time"
bolt "go.etcd.io/bbolt"
)
// DBObjectBackup is the persisted record of a SHOW CREATE captured
// before a manual `csm db-clean drop-object`. The CREATE SQL is the
// backup -- replaying it restores the object verbatim. Fields are
// public so the cleanup-history UI can render them without a
// separate API.
type DBObjectBackup struct {
Account string `json:"account"`
Schema string `json:"schema"`
Kind string `json:"kind"` // trigger | event | procedure | function
Name string `json:"name"`
CreateSQL string `json:"create_sql"`
// DroppedAt defaults to the current UTC time in PutDBObjectBackup when
// left zero, and forms the trailing component of the storage key.
DroppedAt time.Time `json:"dropped_at"`
DroppedBy string `json:"dropped_by"` // operator login or "csm" for daemon-driven
FindingID string `json:"finding_id,omitempty"`
// RestoredAt is set by MarkDBObjectBackupRestored.
// NOTE(review): encoding/json's omitempty does not treat a struct value
// as empty, so a zero RestoredAt is presumably still serialised --
// confirm whether readers rely on the field being absent.
RestoredAt time.Time `json:"restored_at,omitempty"`
}
// PutDBObjectBackup stores a single backup record.
//
// The key is `<account>:<schema>:<kind>:<name>:<unix_nanos>`, so repeated
// drops of the same object name (e.g., re-creates by an attacker) each
// keep their own record. Account, Schema, Kind and Name are mandatory; a
// zero DroppedAt is replaced with the current UTC time. An error is
// returned when a required field is empty, marshalling fails, or the
// db_object_backups bucket is missing.
func (db *DB) PutDBObjectBackup(b DBObjectBackup) error {
	incomplete := b.Account == "" || b.Schema == "" || b.Kind == "" || b.Name == ""
	if incomplete {
		return fmt.Errorf("PutDBObjectBackup: account/schema/kind/name all required")
	}
	if b.DroppedAt.IsZero() {
		b.DroppedAt = time.Now().UTC()
	}

	payload, err := json.Marshal(b)
	if err != nil {
		return fmt.Errorf("marshal backup: %w", err)
	}
	key := fmt.Sprintf("%s:%s:%s:%s:%d",
		b.Account, b.Schema, b.Kind, b.Name, b.DroppedAt.UnixNano())

	return db.bolt.Update(func(tx *bolt.Tx) error {
		if bucket := tx.Bucket([]byte("db_object_backups")); bucket != nil {
			return bucket.Put([]byte(key), payload)
		}
		return fmt.Errorf("db_object_backups bucket missing (store not migrated)")
	})
}
// ListDBObjectBackups returns every decodable record whose key starts with
// "<account>:", in the order the bucket cursor yields them (ascending key
// order). Rows that fail to unmarshal are skipped. A missing bucket yields
// an empty result and no error. Used by the CLI's listing path and the
// cleanup-history UI.
func (db *DB) ListDBObjectBackups(account string) ([]DBObjectBackup, error) {
	var out []DBObjectBackup
	prefix := []byte(account + ":")
	err := db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("db_object_backups"))
		if bucket == nil {
			return nil
		}
		cur := bucket.Cursor()
		k, v := cur.Seek(prefix)
		for k != nil && hasPrefix(k, prefix) {
			var rec DBObjectBackup
			if json.Unmarshal(v, &rec) == nil {
				out = append(out, rec)
			}
			k, v = cur.Next()
		}
		return nil
	})
	return out, err
}
// GetDBObjectBackupByKey looks up one record by its exact bbolt key.
//
// A missing bucket or key is reported as ok=false with a nil error, which
// suits the lookup-then-act flow callers use. A row that exists but cannot
// be decoded yields ok=false together with the unmarshal error.
func (db *DB) GetDBObjectBackupByKey(key string) (DBObjectBackup, bool, error) {
	var (
		rec   DBObjectBackup
		found bool
	)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		var raw []byte
		if bucket := tx.Bucket([]byte("db_object_backups")); bucket != nil {
			raw = bucket.Get([]byte(key))
		}
		if raw == nil {
			return nil
		}
		if uerr := json.Unmarshal(raw, &rec); uerr != nil {
			return uerr
		}
		found = true
		return nil
	})
	return rec, found, err
}
// MarkDBObjectBackupRestored stamps the record stored under key with a
// RestoredAt time (converted to UTC; the current time when restoredAt is
// zero). The row itself stays in place for audit and manual inspection,
// while the WebUI can stop offering repeat restores for that archive.
//
// A key that does not exist is a silent no-op. Errors are returned when
// the bucket is missing or the stored row cannot be decoded/re-encoded.
func (db *DB) MarkDBObjectBackupRestored(key string, restoredAt time.Time) error {
	stamp := restoredAt
	if stamp.IsZero() {
		stamp = time.Now()
	}
	stamp = stamp.UTC()

	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("db_object_backups"))
		if bucket == nil {
			return fmt.Errorf("db_object_backups bucket missing (store not migrated)")
		}

		id := []byte(key)
		raw := bucket.Get(id)
		if raw == nil {
			return nil
		}

		var rec DBObjectBackup
		if uerr := json.Unmarshal(raw, &rec); uerr != nil {
			return uerr
		}
		rec.RestoredAt = stamp

		updated, merr := json.Marshal(rec)
		if merr != nil {
			return fmt.Errorf("marshal backup: %w", merr)
		}
		return bucket.Put(id, updated)
	})
}
// ListDBObjectBackupsAll returns every decodable record in the bucket,
// across all accounts, together with a parallel slice of their keys
// (records[i] was stored under keys[i]). Rows come back in bucket
// iteration order, i.e. ascending by key. Used by the webui
// cleanup-history listing where the operator browses all accounts at once.
func (db *DB) ListDBObjectBackupsAll() ([]DBObjectBackup, []string, error) {
	var (
		records []DBObjectBackup
		keys    []string
	)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("db_object_backups"))
		if bucket == nil {
			return nil
		}
		return bucket.ForEach(func(k, v []byte) error {
			var rec DBObjectBackup
			// Malformed rows are skipped on purpose: surfacing the
			// unmarshal error here would abort the whole iteration, and
			// one bad row should not hide everyone else's history.
			if uerr := json.Unmarshal(v, &rec); uerr != nil {
				return nil //nolint:nilerr // intentional skip
			}
			records = append(records, rec)
			keys = append(keys, string(k))
			return nil
		})
	})
	return records, keys, err
}
// hasPrefix reports whether b begins with prefix. An empty or nil prefix
// matches every b.
func hasPrefix(b, prefix []byte) bool {
	n := len(prefix)
	if n > len(b) {
		return false
	}
	return string(b[:n]) == string(prefix)
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// GeoHistory tracks the countries from which a mailbox has logged in.
// It is stored as JSON in the email:geo bucket, keyed by mailbox.
type GeoHistory struct {
// Countries maps a country identifier to an int64 value.
// NOTE(review): the meaning of the value (count vs. timestamp) is not
// visible here -- confirm against the writer.
Countries map[string]int64 `json:"countries"`
LoginCount int `json:"login_count"`
}
// SetGeoHistory writes h, JSON-encoded, under mailbox in the email:geo
// bucket, replacing any previous value.
func (db *DB) SetGeoHistory(mailbox string, h GeoHistory) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(h)
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("email:geo")).Put([]byte(mailbox), encoded)
	})
}
// GetGeoHistory retrieves geo login history for a mailbox.
//
// It returns the entry and true when a decodable record exists. When the
// mailbox has no record, or the stored value is corrupt, it returns the
// zero GeoHistory and false.
func (db *DB) GetGeoHistory(mailbox string) (GeoHistory, bool) {
	var h GeoHistory
	var found bool
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("email:geo"))
		v := b.Get([]byte(mailbox))
		if v == nil {
			return nil
		}
		// Decode into a scratch value: json.Unmarshal can fill some fields
		// before failing, and a half-decoded history must not leak out
		// alongside found=false.
		var decoded GeoHistory
		if json.Unmarshal(v, &decoded) != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		h, found = decoded, true
		return nil
	})
	return h, found
}
// SetForwarderHash records hash under key in the email:fwd bucket,
// replacing any previous value.
func (db *DB) SetForwarderHash(key, hash string) error {
	store := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("email:fwd")).Put([]byte(key), []byte(hash))
	}
	return db.bolt.Update(store)
}
// GetForwarderHash looks up the forwarder config hash stored under key in
// the email:fwd bucket. It returns (hash, true) when present and
// ("", false) otherwise.
func (db *DB) GetForwarderHash(key string) (string, bool) {
	var (
		hash  string
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if raw := tx.Bucket([]byte("email:fwd")).Get([]byte(key)); raw != nil {
			hash, found = string(raw), true
		}
		return nil
	})
	return hash, found
}
// GetEmailPWLastRefresh returns the last email password-check refresh time
// stored (as RFC 3339 text) under "email:pw_last_refresh" in the meta
// bucket. The zero time is returned when the key is absent or its value
// does not parse.
func (db *DB) GetEmailPWLastRefresh() time.Time {
	var last time.Time
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte("email:pw_last_refresh"))
		if raw == nil {
			return nil
		}
		// A corrupt value is skipped, leaving the zero time.
		if parsed, err := time.Parse(time.RFC3339, string(raw)); err == nil {
			last = parsed
		}
		return nil
	})
	return last
}
// SetEmailPWLastRefresh stores t, formatted with time.RFC3339 (so
// sub-second precision is not kept), under "email:pw_last_refresh" in the
// meta bucket.
func (db *DB) SetEmailPWLastRefresh(t time.Time) error {
	stamp := []byte(t.Format(time.RFC3339))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte("email:pw_last_refresh"), stamp)
	})
}
// GetMetaString returns the value stored under key in the meta bucket as a
// string, or "" when the key does not exist.
func (db *DB) GetMetaString(key string) string {
	var out string
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if raw := tx.Bucket([]byte("meta")).Get([]byte(key)); raw != nil {
			out = string(raw)
		}
		return nil
	})
	return out
}
// SetMetaString stores val under key in the meta bucket, replacing any
// previous value.
func (db *DB) SetMetaString(key, val string) error {
	store := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte(key), []byte(val))
	}
	return db.bolt.Update(store)
}
package store
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/firewall"
bolt "go.etcd.io/bbolt"
)
// FWBlockedEntry represents an IP blocked by the firewall. It is stored as
// JSON in the fw:blocked bucket, keyed by IP.
type FWBlockedEntry struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// Source is filled by BlockIP from firewall.InferProvenance("block", reason).
Source string `json:"source,omitempty"`
BlockedAt time.Time `json:"blocked_at"`
ExpiresAt time.Time `json:"expires_at"` // zero = permanent
}
// FWAllowedEntry represents an IP explicitly allowed through the firewall.
// It is stored as JSON in the fw:allowed bucket, keyed by IP.
type FWAllowedEntry struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// Source is filled by AllowIP from firewall.InferProvenance("allow", reason).
Source string `json:"source,omitempty"`
Port int `json:"port"` // 0 = all ports
ExpiresAt time.Time `json:"expires_at"` // zero = permanent
}
// FWSubnetEntry represents a subnet added to the firewall. It is stored as
// JSON in the fw:subnets bucket, keyed by CIDR.
type FWSubnetEntry struct {
CIDR string `json:"cidr"`
Reason string `json:"reason"`
// Source is filled by AddSubnet from firewall.InferProvenance("block_subnet", reason).
Source string `json:"source,omitempty"`
AddedAt time.Time `json:"added_at"`
}
// FWPortAllowEntry represents a per-IP port allow rule. It is stored as
// JSON in the fw:port_allowed bucket under the same composite key that is
// recorded in Key.
type FWPortAllowEntry struct {
Key string `json:"key"` // IP:port/proto
IP string `json:"ip"`
Port int `json:"port"`
Proto string `json:"proto"`
Reason string `json:"reason"`
// Source is filled by AddPortAllow from firewall.InferProvenance("allow_port", reason).
Source string `json:"source,omitempty"`
}
// FirewallState holds the full state across all 4 firewall buckets, as
// assembled by LoadFirewallState.
type FirewallState struct {
Blocked []FWBlockedEntry // from fw:blocked, expired entries excluded
Allowed []FWAllowedEntry // from fw:allowed, expired entries excluded
Subnets []FWSubnetEntry // from fw:subnets
PortAllowed []FWPortAllowEntry // from fw:port_allowed
}
// portAllowKey builds the fw:port_allowed bucket key "IP:port/proto" for a
// per-IP port rule. The inputs are concatenated as given; nothing is
// validated or normalised.
func portAllowKey(ip string, port int, proto string) string {
	hostPort := fmt.Sprintf("%s:%d", ip, port)
	return hostPort + "/" + proto
}
// BlockIP records ip in the fw:blocked bucket, overwriting any existing
// entry for it. BlockedAt is set to the current time, Source is derived
// from reason via firewall.InferProvenance, and a zero expiresAt means the
// block is permanent.
func (db *DB) BlockIP(ip, reason string, expiresAt time.Time) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(FWBlockedEntry{
			IP:        ip,
			Reason:    reason,
			Source:    firewall.InferProvenance("block", reason),
			BlockedAt: time.Now(),
			ExpiresAt: expiresAt,
		})
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("fw:blocked")).Put([]byte(ip), encoded)
	})
}
// UnblockIP deletes the entry for ip from the fw:blocked bucket.
func (db *DB) UnblockIP(ip string) error {
	remove := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:blocked")).Delete([]byte(ip))
	}
	return db.bolt.Update(remove)
}
// GetBlockedIP looks up ip in the fw:blocked bucket. The bool is true only
// for an entry that decodes and has not expired (a zero ExpiresAt never
// expires); it is false when the IP is absent, the row is corrupt, or the
// block has expired.
func (db *DB) GetBlockedIP(ip string) (FWBlockedEntry, bool) {
	var (
		entry FWBlockedEntry
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("fw:blocked")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		if json.Unmarshal(raw, &entry) != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		// Live means permanent (zero ExpiresAt) or expiring in the future.
		found = entry.ExpiresAt.IsZero() || entry.ExpiresAt.After(time.Now())
		return nil
	})
	return entry, found
}
// AllowIP records ip in the fw:allowed bucket, overwriting any existing
// entry for it. Source is derived from reason via firewall.InferProvenance;
// Port is left at 0 (all ports); a zero expiresAt means the allow is
// permanent.
func (db *DB) AllowIP(ip, reason string, expiresAt time.Time) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(FWAllowedEntry{
			IP:        ip,
			Reason:    reason,
			Source:    firewall.InferProvenance("allow", reason),
			ExpiresAt: expiresAt,
		})
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("fw:allowed")).Put([]byte(ip), encoded)
	})
}
// RemoveAllow deletes the entry for ip from the fw:allowed bucket.
func (db *DB) RemoveAllow(ip string) error {
	remove := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:allowed")).Delete([]byte(ip))
	}
	return db.bolt.Update(remove)
}
// AddSubnet records cidr in the fw:subnets bucket, overwriting any existing
// entry for it. AddedAt is set to the current time and Source is derived
// from reason via firewall.InferProvenance.
func (db *DB) AddSubnet(cidr, reason string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(FWSubnetEntry{
			CIDR:    cidr,
			Reason:  reason,
			Source:  firewall.InferProvenance("block_subnet", reason),
			AddedAt: time.Now(),
		})
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("fw:subnets")).Put([]byte(cidr), encoded)
	})
}
// RemoveSubnet deletes the entry for cidr from the fw:subnets bucket.
func (db *DB) RemoveSubnet(cidr string) error {
	remove := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:subnets")).Delete([]byte(cidr))
	}
	return db.bolt.Update(remove)
}
// AddPortAllow stores a per-IP port allow rule in the fw:port_allowed
// bucket under the composite key from portAllowKey, overwriting any rule
// with the same key. Source is derived from reason via
// firewall.InferProvenance.
func (db *DB) AddPortAllow(ip string, port int, proto, reason string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		ruleKey := portAllowKey(ip, port, proto)
		encoded, err := json.Marshal(FWPortAllowEntry{
			Key:    ruleKey,
			IP:     ip,
			Port:   port,
			Proto:  proto,
			Reason: reason,
			Source: firewall.InferProvenance("allow_port", reason),
		})
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("fw:port_allowed")).Put([]byte(ruleKey), encoded)
	})
}
// RemovePortAllow deletes the rule identified by (ip, port, proto) from the
// fw:port_allowed bucket, using the same composite key as AddPortAllow.
func (db *DB) RemovePortAllow(ip string, port int, proto string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("fw:port_allowed"))
		return bucket.Delete([]byte(portAllowKey(ip, port, proto)))
	})
}
// ListPortAllows returns every decodable rule in the fw:port_allowed
// bucket, in bucket iteration order. Rows that fail to unmarshal are
// skipped.
func (db *DB) ListPortAllows() []FWPortAllowEntry {
	var rules []FWPortAllowEntry
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:port_allowed")).ForEach(func(_, v []byte) error {
			var rule FWPortAllowEntry
			if json.Unmarshal(v, &rule) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			rules = append(rules, rule)
			return nil
		})
	})
	return rules
}
// LoadFirewallState snapshots the four firewall buckets (fw:blocked,
// fw:allowed, fw:subnets, fw:port_allowed) into a FirewallState inside a
// single read transaction.
//
// Blocked and allowed entries whose non-zero ExpiresAt is not after the
// time captured at the start of the call are dropped. Rows that fail to
// unmarshal are skipped in every bucket.
func (db *DB) LoadFirewallState() FirewallState {
	var state FirewallState
	now := time.Now()

	// live reports whether an expiry is permanent (zero) or still ahead.
	live := func(expires time.Time) bool {
		return expires.IsZero() || expires.After(now)
	}

	_ = db.bolt.View(func(tx *bolt.Tx) error {
		// fw:blocked - filter expired
		_ = tx.Bucket([]byte("fw:blocked")).ForEach(func(_, v []byte) error {
			var e FWBlockedEntry
			if json.Unmarshal(v, &e) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if live(e.ExpiresAt) {
				state.Blocked = append(state.Blocked, e)
			}
			return nil
		})

		// fw:allowed - filter expired
		_ = tx.Bucket([]byte("fw:allowed")).ForEach(func(_, v []byte) error {
			var e FWAllowedEntry
			if json.Unmarshal(v, &e) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if live(e.ExpiresAt) {
				state.Allowed = append(state.Allowed, e)
			}
			return nil
		})

		// fw:subnets
		_ = tx.Bucket([]byte("fw:subnets")).ForEach(func(_, v []byte) error {
			var e FWSubnetEntry
			if json.Unmarshal(v, &e) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			state.Subnets = append(state.Subnets, e)
			return nil
		})

		// fw:port_allowed
		_ = tx.Bucket([]byte("fw:port_allowed")).ForEach(func(_, v []byte) error {
			var e FWPortAllowEntry
			if json.Unmarshal(v, &e) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			state.PortAllowed = append(state.PortAllowed, e)
			return nil
		})

		return nil
	})
	return state
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// FirewallRollback is a pending tentative-apply record. The previous
// csm.yaml bytes are stashed verbatim so a recovery path can restore the
// file byte-for-byte without re-rendering through any encoder. Hashes are
// recorded so the daemon can sanity-check the on-disk file matches what
// was applied before deciding to revert.
//
// At most one record exists at a time: it is stored as JSON under a single
// fixed key in the fw:rollback bucket.
type FirewallRollback struct {
PrevYAML []byte `json:"prev_yaml"` // raw bytes of the previous csm.yaml
PrevHash string `json:"prev_hash"`
NewHash string `json:"new_hash"`
AppliedAt time.Time `json:"applied_at"`
ExpiresAt time.Time `json:"expires_at"`
AppliedBy string `json:"applied_by"`
}
// fwRollbackBucket is the bucket holding the pending firewall rollback.
const fwRollbackBucket = "fw:rollback"
// fwRollbackKey is the single key under which that record is stored.
const fwRollbackKey = "pending"
// SaveFirewallRollback stores rb as the pending rollback record. Any
// existing pending entry is overwritten, so callers must clear or revert
// the previous one first if that matters for their flow.
func (db *DB) SaveFirewallRollback(rb FirewallRollback) error {
	encoded, err := json.Marshal(rb)
	if err != nil {
		return err
	}
	store := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte(fwRollbackBucket)).Put([]byte(fwRollbackKey), encoded)
	}
	return db.bolt.Update(store)
}
// GetFirewallRollback returns the pending rollback or (zero, false) if
// none. The bool distinguishes "no record" from a zero-valued record.
// A stored value that fails to unmarshal is treated as "no usable record"
// so the daemon can skip a corrupt entry instead of refusing to start; in
// that case the returned struct is the zero value, never a half-decoded
// record.
func (db *DB) GetFirewallRollback() (FirewallRollback, bool) {
	var rb FirewallRollback
	found := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte(fwRollbackBucket))
		val := b.Get([]byte(fwRollbackKey))
		if val == nil {
			return nil
		}
		// Decode into a scratch value: json.Unmarshal may populate some
		// fields (e.g. PrevYAML) before hitting the error, and those must
		// not be handed back alongside found=false.
		var decoded FirewallRollback
		if uerr := json.Unmarshal(val, &decoded); uerr != nil {
			// Corrupt record: leave found=false so the caller treats
			// it as "no pending rollback" rather than panicking.
			return nil //nolint:nilerr // swallowing is the intent; see comment.
		}
		rb, found = decoded, true
		return nil
	})
	return rb, found
}
// ClearFirewallRollback removes the pending rollback record. Idempotent:
// deleting a non-existent key is not an error in bbolt.
func (db *DB) ClearFirewallRollback() error {
	remove := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte(fwRollbackBucket)).Delete([]byte(fwRollbackKey))
	}
	return db.bolt.Update(remove)
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// AuditResult represents a single hardening check result. It is persisted
// as part of an AuditReport.
type AuditResult struct {
Category string `json:"category"`
Name string `json:"name"`
Title string `json:"title"`
Status string `json:"status"`
Message string `json:"message"`
Fix string `json:"fix,omitempty"` // left out of the JSON when empty
}
// AuditReport is the full result of a hardening audit run. The latest
// report is stored as JSON in the meta bucket (see SaveHardeningReport).
type AuditReport struct {
Timestamp time.Time `json:"timestamp"`
ServerType string `json:"server_type"`
Results []AuditResult `json:"results"`
Score int `json:"score"`
Total int `json:"total"`
}
// hardeningReportKey is the meta-bucket key holding the latest AuditReport.
const hardeningReportKey = "hardening:report"
// SaveHardeningReport stores report, JSON-encoded, as the latest audit
// report in the meta bucket, replacing any previous one.
func (db *DB) SaveHardeningReport(report *AuditReport) error {
	encoded, err := json.Marshal(report)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))
		return meta.Put([]byte(hardeningReportKey), encoded)
	})
}
// LoadHardeningReport reads the latest audit report from the meta bucket.
// The returned pointer is never nil: when no report has been saved it
// points at a zero-value report (nil Results) and the error is nil. A
// stored value that fails to decode is reported through the error.
func (db *DB) LoadHardeningReport() (*AuditReport, error) {
	loaded := new(AuditReport)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(hardeningReportKey))
		if raw != nil {
			return json.Unmarshal(raw, loaded)
		}
		return nil
	})
	return loaded, err
}
package store
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// maxHistoryEntries is the maximum number of history entries to retain;
// AppendHistory prunes the oldest entries once the tracked count exceeds it.
// It is a var (not const) so tests can override it.
var maxHistoryEntries = 100_000
// AppendHistory inserts findings into the history bucket with TimeKey keys.
// It increments the history:count counter and prunes oldest entries if the
// count exceeds maxHistoryEntries.
//
// Everything happens in one bbolt Update transaction: the history rows,
// the per-finding stats:daily increments, the stats:daily prune, the
// counter update and the history prune either all commit or all roll back.
// The first error encountered aborts the transaction and is returned.
func (db *DB) AppendHistory(findings []alert.Finding) error {
return db.bolt.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("history"))
for i, f := range findings {
val, err := json.Marshal(f)
if err != nil {
return err
}
// The slice index is only the starting counter; nextHistoryKey
// probes upward until it finds a key that is not yet in use.
key := nextHistoryKey(b, f.Timestamp, i)
if err := b.Put([]byte(key), val); err != nil {
return err
}
// Same transaction as the history insert: either both land or
// neither does, so the daily aggregate can never drift.
if err := incrStatsDaily(tx, f.Timestamp, f.Severity); err != nil {
return err
}
}
if err := incrCounter(tx, "history:count", len(findings)); err != nil {
return err
}
// Cheap sweep against the bounded stats:daily bucket. Done here
// (rather than on a timer) so the daily-aggregate path has a
// single owner.
if len(findings) > 0 {
if err := pruneStatsDaily(tx, time.Now()); err != nil {
return err
}
}
// Prune oldest entries if count exceeds maxHistoryEntries.
// The counter is re-read here so it reflects the increment just written.
meta := tx.Bucket([]byte("meta"))
var count int
if v := meta.Get([]byte("history:count")); v != nil {
_, _ = fmt.Sscanf(string(v), "%d", &count)
}
if count > maxHistoryEntries {
excess := count - maxHistoryEntries
c := b.Cursor()
k, _ := c.First()
for ; k != nil && excess > 0; excess-- {
// Delete() moves the cursor to the next item, so we
// must NOT call c.Next() after it. The cursor is instead
// re-positioned on the (new) first key after every delete.
if err := c.Delete(); err != nil {
return err
}
k, _ = c.First()
}
// The counter is reset to the cap unconditionally, even if the
// bucket ran out of keys before `excess` reached zero.
if err := setCounter(tx, "history:count", maxHistoryEntries); err != nil {
return err
}
}
return nil
})
}
// nextHistoryKey returns the first TimeKey(timestamp, n), for n counting
// up from start, that is not already present in b.
//
// Shutdown drains can persist separate batches whose findings carry the
// same detector timestamp, so the key space is probed instead of risking
// an overwrite of existing history.
func nextHistoryKey(b *bolt.Bucket, timestamp time.Time, start int) string {
	counter := start
	for {
		candidate := TimeKey(timestamp, counter)
		if b.Get([]byte(candidate)) == nil {
			return candidate
		}
		counter++
	}
}
// ReadHistory pages through the history bucket from newest to oldest.
//
// It skips the offset newest entries, then returns up to limit findings
// together with the total taken from the history:count counter. Rows that
// fail to unmarshal are passed over and do not count towards limit.
func (db *DB) ReadHistory(limit, offset int) ([]alert.Finding, int) {
	total := db.getCounter("history:count")
	var page []alert.Finding
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		cur := tx.Bucket([]byte("history")).Cursor()

		// Walk back from the newest key past the first `offset` entries.
		k, v := cur.Last()
		for skipped := 0; k != nil && skipped < offset; skipped++ {
			k, v = cur.Prev()
		}

		// Collect up to limit entries.
		for k != nil && len(page) < limit {
			var f alert.Finding
			if json.Unmarshal(v, &f) == nil {
				page = append(page, f)
			}
			k, v = cur.Prev()
		}
		return nil
	})
	return page, total
}
// ReadHistoryFiltered reads findings with optional filtering.
// It delegates to ReadHistoryFilteredWithChecks with a nil check set, so
// no check-name filter is applied.
// Parameters:
//   - from, to: date strings "YYYY-MM-DD" for time-range filtering (empty to skip)
//   - severity: filter by severity level (-1 for no filter)
//   - search: case-insensitive substring match on check/message/details (empty to skip)
func (db *DB) ReadHistoryFiltered(limit, offset int, from, to string, severity int, search string) ([]alert.Finding, int) {
return db.ReadHistoryFilteredWithChecks(limit, offset, from, to, severity, search, nil)
}
// ReadHistoryFilteredWithChecks reads findings with optional filters, including
// an exact check-name set when checks is non-nil.
//
// The "history" bucket is walked from newest key to oldest. Every finding that
// passes all active filters is counted; only those after the first offset
// matches, up to limit, are returned.
//
// Parameters:
//   - limit, offset: page size and number of matching findings to skip
//   - from, to: date strings for the key range (empty to skip); converted with
//     ParseTimeKeyPrefix
//   - severity: required severity as an int (negative disables the filter)
//   - search: case-insensitive substring matched against Check, Message and
//     Details (empty to skip)
//   - checks: when non-nil, only findings whose Check maps to true are kept
//
// Returns the page of findings and the total number of findings that matched
// the filters (not just the size of the page). Rows that fail to decode are
// skipped and not counted.
func (db *DB) ReadHistoryFilteredWithChecks(
limit, offset int,
from, to string,
severity int,
search string,
checks map[string]bool,
) ([]alert.Finding, int) {
var results []alert.Finding
matched := 0
// Lowercase the needle once; containsLower lowercases each haystack.
searchLower := strings.ToLower(search)
var fromPrefix, toPrefix string
if from != "" {
fromPrefix = ParseTimeKeyPrefix(from)
}
if to != "" {
// The upper bound is the day's own key prefix with "99" appended, which is
// meant to sort after every key recorded on that day so the whole day is
// included. NOTE(review): relies on the TimeKey layout, which is defined
// elsewhere -- confirm.
toPrefix = ParseTimeKeyPrefix(to) + "99" // "YYYYMMDD99" is > any time on that day
}
_ = db.bolt.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("history"))
c := b.Cursor()
// Seek to the first key past the upper bound, then step back so the
// descending walk starts at the newest in-range entry. Without this the
// loop walked (and skipped) every entry newer than `to`, which is O(N)
// of the whole bucket when querying an old range on a large history.
k, v := c.Last()
if toPrefix != "" {
if sk, _ := c.Seek([]byte(toPrefix)); sk != nil {
// Seek lands on the first key >= toPrefix (out of range);
// the previous key is the newest in-range entry.
k, v = c.Prev()
} else {
// All keys are below toPrefix; Last() is already in range.
k, v = c.Last()
}
}
for ; k != nil; k, v = c.Prev() {
key := string(k)
// Defensive: anything still above the upper bound is out of range.
if toPrefix != "" && key > toPrefix {
continue
}
// Time-range: if key is below fromPrefix, all remaining are older - stop.
if fromPrefix != "" && key < fromPrefix {
break
}
var f alert.Finding
if err := json.Unmarshal(v, &f); err != nil {
// Undecodable rows are ignored rather than aborting the query.
continue
}
// Severity filter.
if severity >= 0 && int(f.Severity) != severity {
continue
}
// Exact check-name filter.
if checks != nil && !checks[f.Check] {
continue
}
// Search filter.
if search != "" && !containsLower(f.Check, searchLower) &&
!containsLower(f.Message, searchLower) &&
!containsLower(f.Details, searchLower) {
continue
}
// Count every match so the caller gets the full total, but only keep
// the ones that fall inside the requested page.
matched++
if matched > offset && len(results) < limit {
results = append(results, f)
}
}
return nil
})
return results, matched
}
// containsLower reports whether s contains substr, ignoring the case of s.
// Only s is lowercased here; the caller must pass substr already lowercased.
func containsLower(s, substr string) bool {
	haystack := strings.ToLower(s)
	return strings.Index(haystack, substr) >= 0
}
package store
import (
"encoding/json"
"errors"
"fmt"
"sort"
"time"
bolt "go.etcd.io/bbolt"
"github.com/pidginhost/csm/internal/incident"
csmlog "github.com/pidginhost/csm/internal/log"
)
// incidentsBucket is the name of the bbolt bucket that holds incident records.
// Rows are JSON-encoded incidents keyed by incident ID (see SaveIncident).
const incidentsBucket = "incidents"
// SaveIncident writes inc to the incidents bucket under inc.ID, replacing any
// record already stored for that ID. The incident is stored as given: the
// caller must set UpdatedAt beforehand, this method does not touch it.
// Returns the JSON encoding error or the bbolt write error, if any.
func (db *DB) SaveIncident(inc incident.Incident) error {
	encoded, marshalErr := json.Marshal(inc)
	if marshalErr != nil {
		return marshalErr
	}
	write := func(tx *bolt.Tx) error {
		return tx.Bucket([]byte(incidentsBucket)).Put([]byte(inc.ID), encoded)
	}
	return db.bolt.Update(write)
}
// GetIncident looks up the incident stored under id.
//
// Returns (incident, true, nil) if found, (zero, false, nil) if no row exists
// for id, and (zero, false, err) when the row cannot be decoded, fails
// validateIncidentRow, or the store reports an error.
func (db *DB) GetIncident(id string) (incident.Incident, bool, error) {
	var (
		inc   incident.Incident
		found bool
	)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		v := tx.Bucket([]byte(incidentsBucket)).Get([]byte(id))
		if v == nil {
			return nil
		}
		// Decode into a scratch value so a partially decoded or invalid row
		// never escapes to the caller together with an error.
		var row incident.Incident
		if err := json.Unmarshal(v, &row); err != nil {
			return err
		}
		if err := validateIncidentRow(id, row); err != nil {
			return err
		}
		inc, found = row, true
		return nil
	})
	if err != nil {
		return incident.Incident{}, false, err
	}
	return inc, found, nil
}
// ListIncidents returns every decodable incident in the bucket, ordered by
// UpdatedAt with the most recent first.
//
// Rows that fail JSON decoding or the storage invariants are skipped (and
// logged by decodeIncidentRow) instead of aborting the listing, so one bad
// row cannot leave the daemon with no restored incidents at startup.
func (db *DB) ListIncidents() ([]incident.Incident, error) {
	var incidents []incident.Incident

	collect := func(key, raw []byte) error {
		if inc, ok := decodeIncidentRow(key, raw); ok {
			incidents = append(incidents, inc)
		}
		return nil
	}
	if err := db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte(incidentsBucket)).ForEach(collect)
	}); err != nil {
		return nil, err
	}

	newestFirst := func(i, j int) bool {
		return incidents[i].UpdatedAt.After(incidents[j].UpdatedAt)
	}
	sort.Slice(incidents, newestFirst)
	return incidents, nil
}
// ListIncidentsByStatus returns the incidents whose Status equals status,
// keeping the newest-UpdatedAt-first order produced by ListIncidents.
func (db *DB) ListIncidentsByStatus(status incident.Status) ([]incident.Incident, error) {
	all, err := db.ListIncidents()
	if err != nil {
		return nil, err
	}
	// Filter in place: kept shares all's backing array and never grows past
	// the index currently being read.
	kept := all[:0]
	for i := range all {
		if all[i].Status != status {
			continue
		}
		kept = append(kept, all[i])
	}
	return kept, nil
}
// CompactIncidents deletes resolved and dismissed incidents whose UpdatedAt
// lies before now-retention. Incidents in any other status are kept no matter
// how old they are, and rows that cannot be decoded are left untouched.
//
// Returns the number of rows deleted together with any store error.
func (db *DB) CompactIncidents(now time.Time, retention time.Duration) (int, error) {
	threshold := now.Add(-retention)
	removed := 0

	err := db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(incidentsBucket))

		// Collect first, delete afterwards: keys are copied because the
		// slices handed to ForEach are not kept beyond the callback.
		var expired [][]byte
		scanErr := bucket.ForEach(func(k, v []byte) error {
			inc, ok := decodeIncidentRow(k, v)
			if !ok {
				return nil
			}
			switch inc.Status {
			case incident.StatusResolved, incident.StatusDismissed:
				if inc.UpdatedAt.Before(threshold) {
					expired = append(expired, append([]byte(nil), k...))
				}
			}
			return nil
		})
		if scanErr != nil {
			return scanErr
		}

		for i := range expired {
			if delErr := bucket.Delete(expired[i]); delErr != nil {
				return delErr
			}
			removed++
		}
		return nil
	})

	return removed, err
}
// decodeIncidentRow decodes one bucket row (key k, JSON value v) into an
// incident and checks it with validateIncidentRow. On any failure it logs a
// warning through warnSkippedIncidentRow and returns (zero, false).
func decodeIncidentRow(k, v []byte) (incident.Incident, bool) {
	id := string(k)

	var decoded incident.Incident
	err := json.Unmarshal(v, &decoded)
	if err == nil {
		err = validateIncidentRow(id, decoded)
	}
	if err != nil {
		warnSkippedIncidentRow(id, err)
		return incident.Incident{}, false
	}
	return decoded, true
}
// validateIncidentRow checks the storage invariants for an incident stored
// under rowID: the key and the incident ID are non-empty and equal, and the
// status is one of the known values. Returns nil when all hold, otherwise an
// error describing the first violated invariant.
func validateIncidentRow(rowID string, inc incident.Incident) error {
	switch {
	case rowID == "":
		return errors.New("empty row key")
	case inc.ID == "":
		return errors.New("empty incident id")
	case inc.ID != rowID:
		return fmt.Errorf("incident id %q does not match row key %q", inc.ID, rowID)
	case !incidentStatusValid(inc.Status):
		return fmt.Errorf("invalid status %q", inc.Status)
	}
	return nil
}
// incidentStatusValid reports whether status is one of the four statuses a
// stored incident may carry: open, contained, resolved or dismissed.
func incidentStatusValid(status incident.Status) bool {
	return status == incident.StatusOpen ||
		status == incident.StatusContained ||
		status == incident.StatusResolved ||
		status == incident.StatusDismissed
}
// warnSkippedIncidentRow emits a warning that the incident row stored under
// rowID was skipped, including the bucket name and the reason err.
func warnSkippedIncidentRow(rowID string, err error) {
	const msg = "store: skipping corrupt incident row"
	csmlog.Warn(msg, "bucket", incidentsBucket, "id", rowID, "err", err)
}
package store
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// All os.Open / os.ReadFile calls below take paths derived from the
// operator-configured statePath (root-owned /opt/csm or /var/lib/csm
// by default) joined with fixed filenames. gosec G304 suppressions on
// each line refer back to this package-level trust model.

// runMigration imports the legacy flat-file state under statePath into bbolt.
//
// Every step is attempted even if an earlier one failed. If any step reported
// an error, the collected messages are returned as one "partial migration"
// error and the "migrated" marker is not written. Otherwise the marker is
// stored in the meta bucket with the current time (best effort) and nil is
// returned. Progress is reported on stderr.
func (db *DB) runMigration(statePath string) error {
	fmt.Fprintf(os.Stderr, "store: migrating flat files to bbolt...\n")

	steps := []struct {
		label string
		run   func(string) error
	}{
		{"history", db.migrateHistory},
		{"attackdb", db.migrateAttackDB},
		{"threatdb", db.migrateThreatDB},
		{"firewall", db.migrateFirewall},
		{"reputation", db.migrateReputation},
	}

	var failures []string
	for _, step := range steps {
		if err := step.run(statePath); err != nil {
			failures = append(failures, fmt.Sprintf("%s: %v", step.label, err))
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("partial migration: %s", strings.Join(failures, "; "))
	}

	// Record completion; a failure to write the marker is deliberately ignored.
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		stamp := time.Now().Format(time.RFC3339)
		return tx.Bucket([]byte("meta")).Put([]byte("migrated"), []byte(stamp))
	})

	fmt.Fprintf(os.Stderr, "store: migration complete\n")
	return nil
}
// migrateHistory imports <statePath>/history.jsonl (one JSON finding per line)
// into the history bucket and then renames the file to "<name>.bak".
//
// A file that cannot be stat'ed is treated as absent and skipped. Lines that
// are not valid JSON findings are ignored. A read error from the scanner
// (including a line longer than the 1 MiB buffer) aborts the migration before
// anything is written, so the source file stays in place instead of being
// renamed after a silently truncated import.
func (db *DB) migrateHistory(statePath string) error {
	path := filepath.Join(statePath, "history.jsonl")
	if _, err := os.Stat(path); err != nil {
		// Nothing to migrate.
		return nil
	}
	// #nosec G304 -- see runMigration trust note.
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	var findings []alert.Finding
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
	for scanner.Scan() {
		var finding alert.Finding
		if err := json.Unmarshal(scanner.Bytes(), &finding); err != nil {
			continue
		}
		findings = append(findings, finding)
	}
	// Scan() returning false may mean an I/O error or an over-long line, not
	// just EOF; without this check the rest of the file would be dropped.
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("reading %s: %w", path, err)
	}

	if len(findings) > 0 {
		if err := db.AppendHistory(findings); err != nil {
			return err
		}
	}
	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated %d history entries\n", len(findings))
	return nil
}
// migrateAttackDB imports the legacy attack database under
// <statePath>/attack_db: records.json (a JSON object of IP records) and
// events.jsonl (one JSON attack event per line). Each file is renamed to
// "<name>.bak" once it has been imported; a missing file is skipped.
//
// JSON null records are skipped rather than dereferenced. Event lines that do
// not decode are ignored. A scanner read error on events.jsonl (including a
// line longer than the 1 MiB buffer) is returned and the file is left in
// place, so a truncated import is not mistaken for a complete one.
func (db *DB) migrateAttackDB(statePath string) error {
	dbDir := filepath.Join(statePath, "attack_db")

	// records.json
	recordsPath := filepath.Join(dbDir, "records.json")
	if _, err := os.Stat(recordsPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		data, err := os.ReadFile(recordsPath)
		if err != nil {
			return err
		}
		var records map[string]*IPRecord
		if err := json.Unmarshal(data, &records); err != nil {
			return err
		}
		migrated := 0
		for _, r := range records {
			if r == nil {
				// `"ip": null` decodes to a nil pointer; nothing to store.
				continue
			}
			if err := db.SaveIPRecord(*r); err != nil {
				return err
			}
			migrated++
		}
		renameToBackup(recordsPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d attack records\n", migrated)
	}

	// events.jsonl
	eventsPath := filepath.Join(dbDir, "events.jsonl")
	if _, err := os.Stat(eventsPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		f, err := os.Open(eventsPath)
		if err != nil {
			return err
		}
		defer f.Close()

		counter := 0
		scanner := bufio.NewScanner(f)
		scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
		for scanner.Scan() {
			var event AttackEvent
			if err := json.Unmarshal(scanner.Bytes(), &event); err != nil {
				continue
			}
			if err := db.RecordAttackEvent(event, counter); err != nil {
				return err
			}
			counter++
		}
		// Distinguish a clean EOF from a read failure before retiring the file.
		if err := scanner.Err(); err != nil {
			return fmt.Errorf("reading %s: %w", eventsPath, err)
		}
		renameToBackup(eventsPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d attack events\n", counter)
	}
	return nil
}
// migrateThreatDB imports the legacy threat database text files under
// <statePath>/threat_db and renames each one to "<name>.bak" afterwards.
// A file that does not exist is skipped. Blank lines and lines starting with
// "#" are ignored in both files.
//
//   - permanent.txt: "<ip>" optionally followed by " # <reason>".
//   - whitelist.txt: "<ip>" followed by optional whitespace-separated tokens.
//     "expires=<RFC3339>" (when it parses) sets the expiry and clears the
//     permanent flag; "permanent" sets the flag. Entries are permanent by
//     default.
func (db *DB) migrateThreatDB(statePath string) error {
	threatDir := filepath.Join(statePath, "threat_db")

	// permanent.txt
	blockFile := filepath.Join(threatDir, "permanent.txt")
	if _, statErr := os.Stat(blockFile); statErr == nil {
		// #nosec G304 -- see runMigration trust note.
		raw, err := os.ReadFile(blockFile)
		if err != nil {
			return err
		}
		migrated := 0
		for _, entry := range strings.Split(string(raw), "\n") {
			entry = strings.TrimSpace(entry)
			if entry == "" || strings.HasPrefix(entry, "#") {
				continue
			}
			fields := strings.SplitN(entry, " # ", 2)
			var reason string
			if len(fields) == 2 {
				reason = strings.TrimSpace(fields[1])
			}
			ip := strings.TrimSpace(fields[0])
			if ip == "" {
				continue
			}
			if err := db.AddPermanentBlock(ip, reason); err != nil {
				return err
			}
			migrated++
		}
		renameToBackup(blockFile)
		fmt.Fprintf(os.Stderr, "store: migrated %d permanent blocks\n", migrated)
	}

	// whitelist.txt
	allowFile := filepath.Join(threatDir, "whitelist.txt")
	if _, statErr := os.Stat(allowFile); statErr == nil {
		// #nosec G304 -- see runMigration trust note.
		raw, err := os.ReadFile(allowFile)
		if err != nil {
			return err
		}
		migrated := 0
		for _, entry := range strings.Split(string(raw), "\n") {
			entry = strings.TrimSpace(entry)
			if entry == "" || strings.HasPrefix(entry, "#") {
				continue
			}
			// entry is non-empty after trimming, so there is at least one token.
			tokens := strings.Fields(entry)
			ip, options := tokens[0], tokens[1:]

			var expiresAt time.Time
			permanent := true
			for _, opt := range options {
				switch {
				case strings.HasPrefix(opt, "expires="):
					when, perr := time.Parse(time.RFC3339, strings.TrimPrefix(opt, "expires="))
					if perr == nil {
						expiresAt = when
						permanent = false
					}
				case opt == "permanent":
					permanent = true
				}
			}
			if err := db.AddWhitelistEntry(ip, expiresAt, permanent); err != nil {
				return err
			}
			migrated++
		}
		renameToBackup(allowFile)
		fmt.Fprintf(os.Stderr, "store: migrated %d whitelist entries\n", migrated)
	}
	return nil
}
// migrateFirewall imports <statePath>/firewall/state.json into the store:
// blocked IPs, allowed IPs, blocked subnets and per-port allows, in that
// order. The file is renamed to "<name>.bak" once everything is imported.
// A file that cannot be stat'ed is treated as absent. The first store error
// aborts the migration and leaves the file in place.
func (db *DB) migrateFirewall(statePath string) error {
	path := filepath.Join(statePath, "firewall", "state.json")
	if _, serr := os.Stat(path); serr != nil {
		return nil //nolint:nilerr // file does not exist, skip migration
	}
	// #nosec G304 -- see runMigration trust note.
	raw, err := os.ReadFile(path)
	if err != nil {
		return err
	}

	// Mirror of the legacy on-disk layout. Fields that are not forwarded to
	// the store are still declared so they are decoded (and type-checked)
	// exactly as before.
	var state struct {
		Blocked []struct {
			IP        string    `json:"ip"`
			Reason    string    `json:"reason"`
			BlockedAt time.Time `json:"blocked_at"`
			ExpiresAt time.Time `json:"expires_at"`
		} `json:"blocked"`
		BlockedNet []struct {
			CIDR      string    `json:"cidr"`
			Reason    string    `json:"reason"`
			BlockedAt time.Time `json:"blocked_at"`
		} `json:"blocked_nets"`
		Allowed []struct {
			IP        string    `json:"ip"`
			Reason    string    `json:"reason"`
			Port      int       `json:"port,omitempty"`
			ExpiresAt time.Time `json:"expires_at,omitempty"`
		} `json:"allowed"`
		PortAllowed []struct {
			IP     string `json:"ip"`
			Port   int    `json:"port"`
			Proto  string `json:"proto"`
			Reason string `json:"reason"`
		} `json:"port_allowed"`
	}
	if err := json.Unmarshal(raw, &state); err != nil {
		return err
	}

	for i := range state.Blocked {
		entry := state.Blocked[i]
		if err := db.BlockIP(entry.IP, entry.Reason, entry.ExpiresAt); err != nil {
			return err
		}
	}
	for i := range state.Allowed {
		entry := state.Allowed[i]
		if err := db.AllowIP(entry.IP, entry.Reason, entry.ExpiresAt); err != nil {
			return err
		}
	}
	for i := range state.BlockedNet {
		entry := state.BlockedNet[i]
		if err := db.AddSubnet(entry.CIDR, entry.Reason); err != nil {
			return err
		}
	}
	for i := range state.PortAllowed {
		entry := state.PortAllowed[i]
		if err := db.AddPortAllow(entry.IP, entry.Port, entry.Proto, entry.Reason); err != nil {
			return err
		}
	}

	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated firewall state (%d blocked, %d allowed, %d subnets, %d port allows)\n",
		len(state.Blocked), len(state.Allowed), len(state.BlockedNet), len(state.PortAllowed))
	return nil
}
// migrateReputation imports <statePath>/reputation_cache.json, a JSON object
// with an "entries" map of IP -> reputation entry, into the store and renames
// the file to "<name>.bak". A file that cannot be stat'ed is treated as
// absent. JSON null entries are skipped instead of being dereferenced. The
// first store error aborts the migration and leaves the file in place.
func (db *DB) migrateReputation(statePath string) error {
	path := filepath.Join(statePath, "reputation_cache.json")
	if _, serr := os.Stat(path); serr != nil {
		return nil //nolint:nilerr // file does not exist, skip migration
	}
	// #nosec G304 -- see runMigration trust note.
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	type rawCache struct {
		Entries map[string]*ReputationEntry `json:"entries"`
	}
	var cache rawCache
	if err := json.Unmarshal(data, &cache); err != nil {
		return err
	}
	count := 0
	for ip, entry := range cache.Entries {
		if entry == nil {
			// `"ip": null` decodes to a nil pointer; nothing to store.
			continue
		}
		if err := db.SetReputation(ip, *entry); err != nil {
			return err
		}
		count++
	}
	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated %d reputation entries\n", count)
	return nil
}
// renameToBackup renames path to path+".bak". It is best effort: a failed
// rename is ignored and not reported to the caller.
func renameToBackup(path string) {
	backup := path + ".bak"
	_ = os.Rename(path, backup)
}
package store
import (
"encoding/json"
"fmt"
"strconv"
"time"
bolt "go.etcd.io/bbolt"
)
// modsecNoEscalateKey is the key in the "meta" bucket under which the
// no-escalate rule IDs are stored as a JSON array of ints.
const modsecNoEscalateKey = "modsec:no_escalate_rules"
// GetModSecNoEscalateRules returns the set of ModSecurity rule IDs that should
// NOT escalate to nftables firewall blocks, as stored in the meta bucket.
// The result is never nil; it is empty when nothing is stored or the stored
// value cannot be decoded.
func (db *DB) GetModSecNoEscalateRules() map[int]bool {
	exempt := map[int]bool{}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(modsecNoEscalateKey))
		if raw == nil {
			return nil
		}
		var stored []int
		if err := json.Unmarshal(raw, &stored); err != nil {
			return nil //nolint:nilerr // skip corrupt data
		}
		for i := range stored {
			exempt[stored[i]] = true
		}
		return nil
	})
	return exempt
}
// SetModSecNoEscalateRules replaces the stored no-escalate set with the keys
// of rules. Every key is written regardless of the bool it maps to. The IDs
// are stored as a JSON array in map-iteration (unspecified) order.
func (db *DB) SetModSecNoEscalateRules(rules map[int]bool) error {
	// Left nil when rules is empty so the stored encoding stays unchanged.
	var ruleIDs []int
	for ruleID := range rules {
		ruleIDs = append(ruleIDs, ruleID)
	}
	store := func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(ruleIDs)
		if err != nil {
			return err
		}
		meta := tx.Bucket([]byte("meta"))
		return meta.Put([]byte(modsecNoEscalateKey), encoded)
	}
	return db.bolt.Update(store)
}
// AddModSecNoEscalateRule adds ruleID to the no-escalate set. The
// read-modify-write runs inside one bbolt Update transaction so concurrent
// callers cannot lose each other's additions. Adding an ID that is already
// present is a no-op that writes nothing.
func (db *DB) AddModSecNoEscalateRule(ruleID int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))

		var current []int
		if raw := meta.Get([]byte(modsecNoEscalateKey)); raw != nil {
			// A decode failure is ignored; whatever was decoded is kept.
			_ = json.Unmarshal(raw, &current)
		}

		present := false
		for i := range current {
			if current[i] == ruleID {
				present = true
				break
			}
		}
		if present {
			return nil
		}

		encoded, err := json.Marshal(append(current, ruleID))
		if err != nil {
			return err
		}
		return meta.Put([]byte(modsecNoEscalateKey), encoded)
	})
}
// RemoveModSecNoEscalateRule removes ruleID from the no-escalate set inside a
// single bbolt Update transaction. The remaining IDs are always written back,
// even when ruleID was not present.
func (db *DB) RemoveModSecNoEscalateRule(ruleID int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))

		var current []int
		if raw := meta.Get([]byte(modsecNoEscalateKey)); raw != nil {
			// A decode failure is ignored; whatever was decoded is kept.
			_ = json.Unmarshal(raw, &current)
		}

		// Built into a fresh slice that stays nil when nothing remains.
		var remaining []int
		for i := range current {
			if current[i] == ruleID {
				continue
			}
			remaining = append(remaining, current[i])
		}

		encoded, err := json.Marshal(remaining)
		if err != nil {
			return err
		}
		return meta.Put([]byte(modsecNoEscalateKey), encoded)
	})
}
// RuleHitStats holds hit count and last-hit time for a ModSecurity rule,
// as reported by GetModSecRuleHits.
type RuleHitStats struct {
// Hits is the sum of the hourly counters that fall inside the reporting window.
Hits int `json:"hits"`
// LastHit is the timestamp of the most recent recorded hit; it can be older
// than the window, in which case Hits is 0.
LastHit time.Time `json:"last_hit"`
}
// ruleHitData is the JSON record stored per rule in the meta bucket under the
// key produced by modsecHitKey. Hits are aggregated into hourly counters.
type ruleHitData struct {
Buckets map[string]int `json:"buckets"` // key: "YYYYMMDDHH" → count
// LastHit is the timestamp passed to the most recent IncrModSecRuleHit call.
LastHit time.Time `json:"last_hit"`
}
// modsecHitKey returns the meta-bucket key holding the hit counters for
// ruleID: "modsec:hits:" followed by the decimal rule ID.
func modsecHitKey(ruleID int) string {
	// Sprint inserts no space here because the first operand is a string.
	return fmt.Sprint("modsec:hits:", ruleID)
}
// hourBucket formats t as a zero-padded "YYYYMMDDHH" string in t's own
// location. The fixed width makes bucket names sort chronologically.
func hourBucket(t time.Time) string {
	year, month, day := t.Date()
	return fmt.Sprintf("%04d%02d%02d%02d", year, int(month), day, t.Hour())
}
// IncrModSecRuleHit increments the hit counter for ruleID in the hour bucket
// that timestamp falls into and records timestamp as the rule's last hit.
//
// The read-modify-write runs in one bbolt Update transaction. A stored record
// that cannot be decoded is replaced by a fresh one. A record that decodes
// without a "buckets" map gets an empty map before the increment, since
// writing to a nil map would panic. Store errors are ignored (best effort);
// the function has no return value.
func (db *DB) IncrModSecRuleHit(ruleID int, timestamp time.Time) {
	key := []byte(modsecHitKey(ruleID))
	bucket := hourBucket(timestamp)
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		var data ruleHitData
		if v := b.Get(key); v != nil {
			if json.Unmarshal(v, &data) != nil {
				// Corrupt record: start over rather than keep partial state.
				data = ruleHitData{}
			}
		}
		if data.Buckets == nil {
			data.Buckets = make(map[string]int)
		}
		data.Buckets[bucket]++
		data.LastHit = timestamp
		val, err := json.Marshal(data)
		if err != nil {
			return err
		}
		return b.Put(key, val)
	})
}
// GetModSecRuleHits returns hit counts and last-hit timestamps for all rules
// within the last 24 hours, and prunes hourly buckets older than that.
// Note: hourly bucket granularity means the window is 24h +/- 1h at boundaries.
//
// A rule is reported when it has hits inside the window or a non-zero LastHit
// (Hits is then 0). Keys whose suffix is not an integer and records that do
// not decode are skipped.
//
// Reading uses a View transaction so the common case takes no write lock.
// Pruning runs in a separate Update that re-reads each record before
// rewriting it; writing back the snapshot taken during the View would discard
// hits recorded by IncrModSecRuleHit between the two transactions. Pruning is
// best effort and its errors are ignored.
func (db *DB) GetModSecRuleHits() map[int]RuleHitStats {
	result := make(map[int]RuleHitStats)
	cutoff := time.Now().Add(-24 * time.Hour)
	cutoffBucket := hourBucket(cutoff)
	prefix := []byte("modsec:hits:")

	// Keys of records that contain at least one expired bucket.
	var staleKeys [][]byte

	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		c := b.Cursor()
		for k, v := c.Seek(prefix); k != nil && len(k) >= len(prefix) && string(k[:len(prefix)]) == string(prefix); k, v = c.Next() {
			idStr := string(k[len(prefix):])
			ruleID, err := strconv.Atoi(idStr)
			if err != nil {
				continue
			}
			var data ruleHitData
			if json.Unmarshal(v, &data) != nil {
				continue
			}
			total := 0
			needsPrune := false
			for bk, count := range data.Buckets {
				// Bucket names are fixed-width, so string order is time order.
				if bk >= cutoffBucket {
					total += count
				} else {
					needsPrune = true
				}
			}
			if needsPrune {
				// Copy the key: bbolt keys are only valid inside the tx.
				staleKeys = append(staleKeys, append([]byte(nil), k...))
			}
			if total > 0 || !data.LastHit.IsZero() {
				result[ruleID] = RuleHitStats{
					Hits:    total,
					LastHit: data.LastHit,
				}
			}
		}
		return nil
	})

	if len(staleKeys) == 0 {
		return result
	}

	// Prune old buckets in a separate write transaction (only if needed).
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		for _, key := range staleKeys {
			raw := b.Get(key)
			if raw == nil {
				continue
			}
			// Work on the current record, not the earlier snapshot.
			var current ruleHitData
			if json.Unmarshal(raw, &current) != nil {
				continue
			}
			pruned := false
			for bk := range current.Buckets {
				if bk < cutoffBucket {
					delete(current.Buckets, bk)
					pruned = true
				}
			}
			if !pruned {
				continue
			}
			val, err := json.Marshal(current)
			if err != nil {
				// Never overwrite a good record with an empty value.
				continue
			}
			_ = b.Put(key, val)
		}
		return nil
	})
	return result
}
package store
import (
"errors"
"fmt"
bolt "go.etcd.io/bbolt"
)
// PHPRelayKV is a single key/value pair for batched writes
// (see PHPRelayPutBatch).
type PHPRelayKV struct {
// Key is the bucket key; PHPRelayPutBatch rejects an empty key.
Key []byte
// Value is the raw value stored under Key; it is copied before being written.
Value []byte
}
// phpRelayBucketAllowlist is the set of bucket names the PHPRelay* helpers may
// touch. It is validated on every helper (via phpRelayCheckBucket) to prevent
// the daemon from accidentally writing into other buckets through these
// generic wrappers.
var phpRelayBucketAllowlist = map[string]struct{}{
"phprelay:meta": {},
"phprelay:msgindex": {},
"phprelay:ignore": {},
"phprelay:settings": {},
"phprelay:baseline": {}, // Stage 3
}
// phpRelayCheckBucket returns nil when name is one of the allowlisted
// php_relay buckets and a descriptive error otherwise.
func phpRelayCheckBucket(name string) error {
	_, allowed := phpRelayBucketAllowlist[name]
	if allowed {
		return nil
	}
	return fmt.Errorf("php_relay: bucket %q is not in the allowlist", name)
}
// PHPRelayPut stores value under key in the named php_relay bucket. The value
// is copied before it is handed to bbolt. Returns an error when the bucket is
// not allowlisted, does not exist, or the write fails.
func (db *DB) PHPRelayPut(bucket, key string, value []byte) error {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return err
	}
	write := func(tx *bolt.Tx) error {
		target := tx.Bucket([]byte(bucket))
		if target == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}
		stored := append([]byte(nil), value...)
		return target.Put([]byte(key), stored)
	}
	return db.bolt.Update(write)
}
// PHPRelayGet returns a copy of the value stored under key in the named
// php_relay bucket. The bool is false when the key is absent. Returns an
// error when the bucket is not allowlisted or does not exist.
func (db *DB) PHPRelayGet(bucket, key string) ([]byte, bool, error) {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return nil, false, err
	}
	var (
		value   []byte
		present bool
	)
	read := func(tx *bolt.Tx) error {
		src := tx.Bucket([]byte(bucket))
		if src == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}
		if raw := src.Get([]byte(key)); raw != nil {
			// Copy out: the slice is only valid while the tx is open.
			value = append([]byte(nil), raw...)
			present = true
		}
		return nil
	}
	err := db.bolt.View(read)
	return value, present, err
}
// PHPRelayDelete deletes key from the named php_relay bucket. Deleting a key
// that does not exist is not an error. Returns an error when the bucket is
// not allowlisted, does not exist, or the delete fails.
func (db *DB) PHPRelayDelete(bucket, key string) error {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return err
	}
	remove := func(tx *bolt.Tx) error {
		target := tx.Bucket([]byte(bucket))
		if target == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}
		return target.Delete([]byte(key))
	}
	return db.bolt.Update(remove)
}
// PHPRelayPutBatch writes all pairs in ops to the named php_relay bucket
// inside one bbolt transaction. Used by the msgIndexPersister to keep IOPS
// bounded. An empty ops is a no-op.
//
// The first empty key or failed Put returns an error from the transaction
// function, so none of the batch is committed. Keys and values are copied
// before being handed to bbolt.
func (db *DB) PHPRelayPutBatch(bucket string, ops []PHPRelayKV) error {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return err
	}
	if len(ops) == 0 {
		return nil
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		target := tx.Bucket([]byte(bucket))
		if target == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}
		for i := range ops {
			k, v := ops[i].Key, ops[i].Value
			if len(k) == 0 {
				return errors.New("php_relay: empty key")
			}
			if err := target.Put(append([]byte(nil), k...), append([]byte(nil), v...)); err != nil {
				return err
			}
		}
		return nil
	})
}
// PHPRelaySweep walks the named php_relay bucket and deletes every entry for
// which shouldDelete(key, value) returns true. Interpreting the value is the
// caller's job. Keys are collected first and deleted after the walk, all in
// the same transaction. Returns the number of keys deleted.
func (db *DB) PHPRelaySweep(bucket string, shouldDelete func(key, value []byte) bool) (int, error) {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return 0, err
	}
	deleted := 0
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		target := tx.Bucket([]byte(bucket))
		if target == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}

		var doomed [][]byte
		// The callback never fails, so ForEach cannot return an error here.
		_ = target.ForEach(func(k, v []byte) error {
			if shouldDelete(k, v) {
				doomed = append(doomed, append([]byte(nil), k...))
			}
			return nil
		})

		for i := range doomed {
			if err := target.Delete(doomed[i]); err != nil {
				return err
			}
		}
		deleted = len(doomed)
		return nil
	})
	return deleted, err
}
// PHPRelayList returns every key in the named php_relay bucket mapped to a
// copy of its value. Used at daemon start to restore the in-memory
// ignoreList. The returned map is non-nil whenever the bucket name passes the
// allowlist check, including when a store error is also returned.
func (db *DB) PHPRelayList(bucket string) (map[string][]byte, error) {
	if err := phpRelayCheckBucket(bucket); err != nil {
		return nil, err
	}
	entries := map[string][]byte{}
	err := db.bolt.View(func(tx *bolt.Tx) error {
		src := tx.Bucket([]byte(bucket))
		if src == nil {
			return fmt.Errorf("bucket %q missing", bucket)
		}
		cur := src.Cursor()
		for k, v := cur.First(); k != nil; k, v = cur.Next() {
			entries[string(k)] = append([]byte(nil), v...)
		}
		return nil
	})
	return entries, err
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// PluginInfo holds cached metadata for a WordPress plugin from the API.
// It is stored as JSON in the "plugins" bucket, keyed by plugin slug.
type PluginInfo struct {
// LatestVersion is the newest plugin version reported upstream.
LatestVersion string `json:"latest_version"`
// TestedUpTo is the upstream "tested up to" value.
// NOTE(review): presumably a WordPress core version -- confirm.
TestedUpTo string `json:"tested_up_to"`
// LastChecked is a Unix timestamp (per the JSON tag) of the last lookup.
// NOTE(review): unit presumably seconds -- confirm against the writer.
LastChecked int64 `json:"last_checked_unix"`
}
// SitePluginEntry describes a single plugin installed on a WordPress site.
// It is embedded in SitePlugins.Plugins.
type SitePluginEntry struct {
// Slug is the plugin slug. NOTE(review): presumably the same slug that
// keys PluginInfo in the "plugins" bucket -- confirm.
Slug string `json:"slug"`
// Name is the plugin's display name.
Name string `json:"name"`
// Status is the plugin's status string as reported for the site.
// NOTE(review): set of possible values is not visible here -- confirm.
Status string `json:"status"`
// InstalledVersion is the version currently installed on the site.
InstalledVersion string `json:"installed_version"`
// UpdateVersion is the version an update would install.
// NOTE(review): presumably empty when no update is pending -- confirm.
UpdateVersion string `json:"update_version"`
}
// SitePlugins holds the full plugin inventory for a WordPress installation.
// It is stored as JSON in the "plugins:sites" bucket, keyed by the
// installation's filesystem path.
type SitePlugins struct {
// Account is the account the installation belongs to.
// NOTE(review): presumably the hosting account name -- confirm.
Account string `json:"account"`
// Domain is the domain associated with the installation.
Domain string `json:"domain"`
// Plugins lists the plugins found in the installation.
Plugins []SitePluginEntry `json:"plugins"`
}
// SetPluginInfo writes info as JSON to the "plugins" bucket under slug,
// replacing any previous entry. Returns the encoding or store error, if any.
func (db *DB) SetPluginInfo(slug string, info PluginInfo) error {
	store := func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(info)
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("plugins")).Put([]byte(slug), encoded)
	}
	return db.bolt.Update(store)
}
// GetPluginInfo looks up the cached metadata stored for slug.
// The bool is true only when an entry exists and decodes cleanly; a missing
// or corrupt entry yields false.
func (db *DB) GetPluginInfo(slug string) (PluginInfo, bool) {
	var (
		info PluginInfo
		ok   bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("plugins")).Get([]byte(slug))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &info); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		ok = true
		return nil
	})
	return info, ok
}
// SetSitePlugins writes the plugin inventory of a WordPress installation as
// JSON to the "plugins:sites" bucket, keyed by the installation's filesystem
// path wpPath. Any previous inventory for that path is replaced.
func (db *DB) SetSitePlugins(wpPath string, site SitePlugins) error {
	store := func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(site)
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("plugins:sites")).Put([]byte(wpPath), encoded)
	}
	return db.bolt.Update(store)
}
// GetSitePlugins looks up the plugin inventory stored for the WordPress
// installation at wpPath. The bool is true only when an entry exists and
// decodes cleanly; a missing or corrupt entry yields false.
func (db *DB) GetSitePlugins(wpPath string) (SitePlugins, bool) {
	var (
		site SitePlugins
		ok   bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("plugins:sites")).Get([]byte(wpPath))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &site); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		ok = true
		return nil
	})
	return site, ok
}
// DeleteSitePlugins deletes the plugin inventory stored for the WordPress
// installation at wpPath from the "plugins:sites" bucket.
func (db *DB) DeleteSitePlugins(wpPath string) error {
	remove := func(tx *bolt.Tx) error {
		sites := tx.Bucket([]byte("plugins:sites"))
		return sites.Delete([]byte(wpPath))
	}
	return db.bolt.Update(remove)
}
// AllSitePlugins returns every decodable plugin inventory in the
// "plugins:sites" bucket, keyed by WordPress path. Entries that fail to
// decode are left out. The returned map is never nil.
func (db *DB) AllSitePlugins() map[string]SitePlugins {
	inventories := map[string]SitePlugins{}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		cur := tx.Bucket([]byte("plugins:sites")).Cursor()
		for k, v := cur.First(); k != nil; k, v = cur.Next() {
			var site SitePlugins
			if err := json.Unmarshal(v, &site); err != nil {
				// skip corrupt entry
				continue
			}
			inventories[string(k)] = site
		}
		return nil
	})
	return inventories
}
// GetPluginRefreshTime returns the last plugin refresh timestamp stored in the
// meta bucket under "plugins:last_refresh". The zero time is returned when
// the key is absent or its value is not a valid RFC 3339 timestamp.
func (db *DB) GetPluginRefreshTime() time.Time {
	var last time.Time
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte("plugins:last_refresh"))
		if raw == nil {
			return nil
		}
		if parsed, err := time.Parse(time.RFC3339, string(raw)); err == nil {
			last = parsed
		}
		return nil
	})
	return last
}
// SetPluginRefreshTime stores t, formatted as RFC 3339, in the meta bucket
// under "plugins:last_refresh".
func (db *DB) SetPluginRefreshTime(t time.Time) error {
	stamp := []byte(t.Format(time.RFC3339))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))
		return meta.Put([]byte("plugins:last_refresh"), stamp)
	})
}
package store
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"time"
bolt "go.etcd.io/bbolt"
)
// prefsBucket is the name of the bbolt bucket that holds per-operator
// preference blobs. Key layout:
// - "<opkey>:user" -> user settings JSON (density, timezone, etc.)
// - "<opkey>:views:<page>" -> saved filter views JSON
// - "<opkey>:table:<tableID>" -> per-table column visibility JSON
// - "<opkey>:undo:<seq>" -> bulk-action undo entry (sequence is
// big-endian uint64 of unix nano, so prefix iteration is chronological)
//
// opkey is a SHA-256 hex of the operator's auth token, computed at request
// time by the webui layer. The store never sees the token itself.
const prefsBucket = "prefs:operator"
// MaxPrefBlobSize caps the size in bytes (64 KiB) of any single preference
// blob. Large enough for saved views with dozens of params, small enough to
// keep abuse bounded. PutOperatorPref applies it to the payload and
// AppendUndoEntry to the encoded undo entry.
const MaxPrefBlobSize = 64 * 1024
// MaxUndoEntries caps how many undo entries one operator may have queued at
// once. Older entries fall out as new ones are recorded.
// NOTE(review): enforcement lives in pruneOperatorUndo, which is defined
// elsewhere -- confirm.
const MaxUndoEntries = 32
// UndoTTL is how long an undo entry remains valid. Matches the banner timeout
// the UI advertises so an operator can never undo an action whose banner has
// already disappeared.
const UndoTTL = 30 * time.Second
// ErrPrefBlobTooLarge is returned when a preference blob exceeds MaxPrefBlobSize.
// It is a sentinel value: callers can compare against it (e.g. with errors.Is).
var ErrPrefBlobTooLarge = errors.New("preference blob too large")
// prefsKey builds the prefs-bucket key "<opkey>:<ns>" for a preference
// namespace of one operator.
func prefsKey(opkey, ns string) []byte {
	key := make([]byte, 0, len(opkey)+1+len(ns))
	key = append(key, opkey...)
	key = append(key, ':')
	return append(key, ns...)
}
// undoKey builds the prefs-bucket key for one undo entry:
// "<opkey>:undo:" followed by seq as 8 big-endian bytes. The fixed-width
// big-endian suffix makes keys of one operator sort by seq.
func undoKey(opkey string, seq uint64) []byte {
	const marker = ":undo:"
	head := len(opkey) + len(marker)
	key := make([]byte, head+8)
	copy(key, opkey)
	copy(key[len(opkey):], marker)
	binary.BigEndian.PutUint64(key[head:], seq)
	return key
}
// undoKeyPrefix returns "<opkey>:undo:", the prefix shared by every undo key
// of one operator.
func undoKeyPrefix(opkey string) []byte {
	const marker = ":undo:"
	prefix := make([]byte, 0, len(opkey)+len(marker))
	prefix = append(prefix, opkey...)
	return append(prefix, marker...)
}
// GetOperatorPref returns a copy of the raw JSON blob stored at (opkey, ns).
// Returns nil with nil error when no entry exists, including when the prefs
// bucket has not been created yet.
//
// Errors: "store unavailable" when db or its bbolt handle is nil, and
// "opkey and namespace required" when either argument is empty.
func (db *DB) GetOperatorPref(opkey, ns string) ([]byte, error) {
	if db == nil || db.bolt == nil {
		return nil, errors.New("store unavailable")
	}
	if opkey == "" || ns == "" {
		return nil, errors.New("opkey and namespace required")
	}
	var out []byte
	err := db.bolt.View(func(tx *bolt.Tx) error {
		// Look the bucket up without creating it: a View transaction is
		// read-only, and a missing bucket just means nothing is stored yet.
		b := tx.Bucket([]byte(prefsBucket))
		if b == nil {
			return nil
		}
		v := b.Get(prefsKey(opkey, ns))
		if v == nil {
			return nil
		}
		// Bolt invalidates the slice after the tx ends; copy out.
		out = append([]byte(nil), v...)
		return nil
	})
	return out, err
}
// PutOperatorPref stores data as the blob for (opkey, ns), replacing any
// previous value.
//
// Errors: "store unavailable" when db or its bbolt handle is nil, "opkey and
// namespace required" when either argument is empty, ErrPrefBlobTooLarge when
// data is longer than MaxPrefBlobSize, or the underlying store error.
func (db *DB) PutOperatorPref(opkey, ns string, data []byte) error {
	switch {
	case db == nil || db.bolt == nil:
		return errors.New("store unavailable")
	case opkey == "" || ns == "":
		return errors.New("opkey and namespace required")
	case len(data) > MaxPrefBlobSize:
		return ErrPrefBlobTooLarge
	}
	write := func(tx *bolt.Tx) error {
		prefs, err := bucketOrCreate(tx, prefsBucket)
		if err != nil {
			return err
		}
		return prefs.Put(prefsKey(opkey, ns), data)
	}
	return db.bolt.Update(write)
}
// DeleteOperatorPref deletes the blob stored at (opkey, ns). Deleting an
// entry that does not exist is not an error.
//
// Errors: "store unavailable" when db or its bbolt handle is nil, "opkey and
// namespace required" when either argument is empty, or the underlying store
// error.
func (db *DB) DeleteOperatorPref(opkey, ns string) error {
	switch {
	case db == nil || db.bolt == nil:
		return errors.New("store unavailable")
	case opkey == "" || ns == "":
		return errors.New("opkey and namespace required")
	}
	remove := func(tx *bolt.Tx) error {
		prefs, err := bucketOrCreate(tx, prefsBucket)
		if err != nil {
			return err
		}
		return prefs.Delete(prefsKey(opkey, ns))
	}
	return db.bolt.Update(remove)
}
// UndoEntry is one record in the bulk-action undo queue.
//
// Callers fill in Action, Inverse, Payload and Summary. AppendUndoEntry
// overwrites ID (the write sequence as 16 lowercase hex digits) and
// RecordedAt (the UTC write time) and rejects an entry with an empty Inverse.
type UndoEntry struct {
ID string `json:"id"` // matches the seq encoded in the bbolt key
RecordedAt time.Time `json:"recorded_at"` // wall time the entry was written
Action string `json:"action"` // e.g. "threat_bulk_block"
Inverse string `json:"inverse"` // inverse action key the runner will dispatch
Payload []byte `json:"payload"` // opaque JSON the runner understands
Summary string `json:"summary"` // human-readable label for the banner
}
// AppendUndoEntry queues an undo entry for opkey and returns the entry as
// saved. ID and RecordedAt are filled in from the current UTC time; the same
// nanosecond timestamp is the sequence number in the entry's key.
// Before the write, pruneOperatorUndo is run for opkey in the same
// transaction to drop expired entries and trim the queue.
//
// Errors: "store unavailable" when db or its bbolt handle is nil, "opkey
// required", "inverse action required", ErrPrefBlobTooLarge when the encoded
// entry exceeds MaxPrefBlobSize, or any encoding/store error. On error the
// zero UndoEntry is returned.
func (db *DB) AppendUndoEntry(opkey string, e UndoEntry) (UndoEntry, error) {
	switch {
	case db == nil || db.bolt == nil:
		return UndoEntry{}, errors.New("store unavailable")
	case opkey == "":
		return UndoEntry{}, errors.New("opkey required")
	case e.Inverse == "":
		return UndoEntry{}, errors.New("inverse action required")
	}

	stamp := time.Now().UTC()
	seq := uint64(stamp.UnixNano())
	e.RecordedAt = stamp
	e.ID = fmt.Sprintf("%016x", seq)

	encoded, err := encodeUndoEntry(e)
	if err != nil {
		return UndoEntry{}, err
	}
	if len(encoded) > MaxPrefBlobSize {
		return UndoEntry{}, ErrPrefBlobTooLarge
	}

	write := func(tx *bolt.Tx) error {
		prefs, err := bucketOrCreate(tx, prefsBucket)
		if err != nil {
			return err
		}
		if err := pruneOperatorUndo(prefs, opkey, stamp); err != nil {
			return err
		}
		return prefs.Put(undoKey(opkey, seq), encoded)
	}
	if err := db.bolt.Update(write); err != nil {
		return UndoEntry{}, err
	}
	return e, nil
}
// LatestUndoEntry returns the most recent non-expired undo entry for opkey,
// or (zero, false, nil) if none exists.
//
// Only keys carrying opkey's undo prefix are visited: the cursor seeks to the
// prefix and walks forward while the prefix still matches, so the cost is
// bounded by the size of this operator's queue rather than by the number of
// keys that happen to sort after it in the bucket. The collected values are
// then examined newest-first. Entries that fail to decode are skipped. If the
// newest decodable entry is older than UndoTTL, nothing is returned (every
// earlier entry is older still).
//
// Errors: "store unavailable" (nil receiver or handle) and "opkey required".
func (db *DB) LatestUndoEntry(opkey string) (UndoEntry, bool, error) {
	if db == nil || db.bolt == nil {
		return UndoEntry{}, false, errors.New("store unavailable")
	}
	if opkey == "" {
		return UndoEntry{}, false, errors.New("opkey required")
	}
	now := time.Now().UTC()
	var (
		entry UndoEntry
		found bool
	)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte(prefsBucket))
		if b == nil {
			return nil
		}
		prefix := undoKeyPrefix(opkey)

		// Gather this operator's values in key order. The slices are only
		// used inside the transaction, so no copy is needed.
		var vals [][]byte
		c := b.Cursor()
		for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
			vals = append(vals, v)
		}

		// Newest entries sort last because the key suffix is the
		// big-endian nano timestamp, so scan from the end.
		for i := len(vals) - 1; i >= 0; i-- {
			e, derr := decodeUndoEntry(vals[i])
			if derr != nil {
				continue
			}
			if now.Sub(e.RecordedAt) > UndoTTL {
				return nil
			}
			entry, found = e, true
			return nil
		}
		return nil
	})
	return entry, found, err
}
// ConsumeUndoEntry looks up the entry with the given id in opkey's undo
// queue, deletes it, and returns the decoded value.
//
// The queue is first passed through pruneOperatorUndo in the same write
// transaction. The remaining entries are scanned in key order; undecodable
// rows and rows with a different ID are skipped. A matching row is always
// deleted, but it is only reported as found when it is not older than
// UndoTTL. The result is (zero, false, nil) when no usable entry matches.
//
// Errors: "store unavailable" (nil receiver or handle) and
// "opkey and id required".
func (db *DB) ConsumeUndoEntry(opkey, id string) (UndoEntry, bool, error) {
	switch {
	case db == nil || db.bolt == nil:
		return UndoEntry{}, false, errors.New("store unavailable")
	case opkey == "" || id == "":
		return UndoEntry{}, false, errors.New("opkey and id required")
	}

	var (
		result UndoEntry
		ok     bool
	)
	at := time.Now().UTC()
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		bucket, err := bucketOrCreate(tx, prefsBucket)
		if err != nil {
			return err
		}
		if err := pruneOperatorUndo(bucket, opkey, at); err != nil {
			return err
		}
		prefix := undoKeyPrefix(opkey)
		cur := bucket.Cursor()
		for k, v := cur.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = cur.Next() {
			candidate, derr := decodeUndoEntry(v)
			if derr != nil || candidate.ID != id {
				continue
			}
			// Hand the entry back only while it is still within the TTL;
			// the row is removed either way.
			if at.Sub(candidate.RecordedAt) <= UndoTTL {
				result, ok = candidate, true
			}
			return bucket.Delete(k)
		}
		return nil
	})
	return result, ok, err
}
// PurgeOperatorUndo removes every undo entry stored for opkey (for example
// on operator logout, or to reset state in tests).
//
// A missing prefs bucket is treated as "nothing to purge". Keys are copied
// during the scan and deleted afterwards so the cursor is never used across
// a mutation.
//
// Errors: "store unavailable" (nil receiver or handle), "opkey required",
// and any error returned by the bucket delete.
func (db *DB) PurgeOperatorUndo(opkey string) error {
	switch {
	case db == nil || db.bolt == nil:
		return errors.New("store unavailable")
	case opkey == "":
		return errors.New("opkey required")
	}
	purge := func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(prefsBucket))
		if bucket == nil {
			return nil
		}
		prefix := undoKeyPrefix(opkey)
		var doomed [][]byte
		cur := bucket.Cursor()
		for k, _ := cur.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = cur.Next() {
			owned := append([]byte(nil), k...)
			doomed = append(doomed, owned)
		}
		for i := range doomed {
			if err := bucket.Delete(doomed[i]); err != nil {
				return err
			}
		}
		return nil
	}
	return db.bolt.Update(purge)
}
// pruneOperatorUndo cleans opkey's undo queue inside bucket b. It must run
// inside a writable transaction.
//
// Rows that fail to decode or whose RecordedAt is more than UndoTTL before
// now are deleted. If MaxUndoEntries or more live rows remain, the oldest
// ones are deleted as well until MaxUndoEntries-1 are left, i.e. there is
// room for exactly one more entry. Keys are in chronological order because
// they sort by their big-endian nano suffix, so "oldest" means "first".
func pruneOperatorUndo(b *bolt.Bucket, opkey string, now time.Time) error {
	prefix := undoKeyPrefix(opkey)

	var live, dead [][]byte
	cur := b.Cursor()
	for k, v := cur.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = cur.Next() {
		owned := append([]byte(nil), k...)
		e, err := decodeUndoEntry(v)
		if err != nil || now.Sub(e.RecordedAt) > UndoTTL {
			dead = append(dead, owned)
			continue
		}
		live = append(live, owned)
	}

	// Queue the oldest live rows for deletion after the expired ones.
	if surplus := len(live) - MaxUndoEntries + 1; surplus > 0 {
		dead = append(dead, live[:surplus]...)
	}

	for _, k := range dead {
		if err := b.Delete(k); err != nil {
			return err
		}
	}
	return nil
}
// bucketOrCreate returns the top-level bucket called name.
//
// In a writable transaction the bucket is created when it does not exist
// yet. In a read-only transaction a missing bucket yields a
// "bucket <name> not initialised" error instead.
func bucketOrCreate(tx *bolt.Tx, name string) (*bolt.Bucket, error) {
	key := []byte(name)
	if !tx.Writable() {
		existing := tx.Bucket(key)
		if existing == nil {
			return nil, fmt.Errorf("bucket %s not initialised", name)
		}
		return existing, nil
	}
	return tx.CreateBucketIfNotExists(key)
}
package store
import "encoding/json"
// encodeUndoEntry serialises e to its JSON storage form.
func encodeUndoEntry(e UndoEntry) ([]byte, error) {
	raw, err := json.Marshal(e)
	return raw, err
}
// decodeUndoEntry parses the JSON storage form produced by encodeUndoEntry.
// On a parse error it returns the zero UndoEntry together with the error.
func decodeUndoEntry(raw []byte) (UndoEntry, error) {
	var parsed UndoEntry
	err := json.Unmarshal(raw, &parsed)
	if err == nil {
		return parsed, nil
	}
	return UndoEntry{}, err
}
package store
import (
"bytes"
"fmt"
"time"
bolt "go.etcd.io/bbolt"
)
// RewriteUndoEntryRecordedAt sets the RecordedAt field of the undo entry
// whose ID equals id to at. It exists so tests can age an entry beyond the
// TTL window without actually sleeping.
//
// The whole prefs bucket is scanned; only keys containing ":undo:" are
// considered, and rows that fail to decode are skipped. The first matching
// entry is re-encoded and written back under its existing key.
//
// Errors: "store unavailable", "id required", "prefs bucket missing",
// "entry <id> not found", or any encode/put error.
func RewriteUndoEntryRecordedAt(db *DB, id string, at time.Time) error {
	switch {
	case db == nil || db.bolt == nil:
		return fmt.Errorf("store unavailable")
	case id == "":
		return fmt.Errorf("id required")
	}
	rewrite := func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte(prefsBucket))
		if bucket == nil {
			return fmt.Errorf("prefs bucket missing")
		}
		marker := []byte(":undo:")
		cur := bucket.Cursor()
		for k, v := cur.First(); k != nil; k, v = cur.Next() {
			if !bytes.Contains(k, marker) {
				continue
			}
			entry, err := decodeUndoEntry(v)
			if err != nil || entry.ID != id {
				continue
			}
			entry.RecordedAt = at
			encoded, err := encodeUndoEntry(entry)
			if err != nil {
				return err
			}
			return bucket.Put(k, encoded)
		}
		return fmt.Errorf("entry %s not found", id)
	}
	return db.bolt.Update(rewrite)
}
package store
import (
"encoding/json"
"fmt"
"sort"
"time"
bolt "go.etcd.io/bbolt"
)
// Meta-bucket keys for AbuseIPDB quota accounting. Persisted so enforcement
// survives daemon restarts and spans across 10-minute scan cycles.
//
// abuseQuotaExhaustedKey stores the quota-reset time as an RFC 3339 UTC
// string (see SetAbuseQuotaExhaustedUntil). abuseDailyCountPrefix is
// completed with a date to form the key of that day's decimal query counter.
const (
abuseQuotaExhaustedKey = "abuse:quota_exhausted_until"
abuseDailyCountPrefix = "abuse:daily_count:" // + YYYY-MM-DD in UTC
)
// ReputationEntry holds the cached reputation data for an IP address.
// It is stored as JSON in the "reputation" bucket, keyed by the IP string.
type ReputationEntry struct {
// Score is the cached reputation score (presumably the AbuseIPDB
// confidence value -- confirm against the code that fills the cache).
Score int `json:"score"`
// Category is a free-form label cached alongside the score.
Category string `json:"category"`
// CheckedAt is the age reference used by the expiry and cap sweeps.
CheckedAt time.Time `json:"checked_at"`
}
// SetReputation writes entry, encoded as JSON, into the "reputation" bucket
// under the key ip, replacing any previous value. It returns the encoding or
// storage error, if any.
func (db *DB) SetReputation(ip string, entry ReputationEntry) error {
	store := func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		payload, err := json.Marshal(entry)
		if err != nil {
			return err
		}
		return bucket.Put([]byte(ip), payload)
	}
	return db.bolt.Update(store)
}
// GetReputation looks up the cached reputation for ip.
// The boolean is true only when a row exists and decodes cleanly; a missing
// or corrupt row reports false. Transaction errors are ignored.
func (db *DB) GetReputation(ip string) (ReputationEntry, bool) {
	var (
		result ReputationEntry
		ok     bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		if raw := bucket.Get([]byte(ip)); raw != nil {
			// A corrupt row is reported as "not found".
			ok = json.Unmarshal(raw, &result) == nil
		}
		return nil
	})
	return result, ok
}
// CleanExpiredReputation deletes reputation entries whose CheckedAt is older
// than maxAge and returns how many were removed.
//
// Keys are collected during the ForEach pass and deleted afterwards, because
// the bucket must not be mutated while ForEach is iterating. Rows that fail
// to decode are left in place.
//
// If the write transaction fails, bbolt rolls every delete back, so the
// function reports 0 rather than a count of deletions that did not persist.
func (db *DB) CleanExpiredReputation(maxAge time.Duration) int {
	cutoff := time.Now().Add(-maxAge)
	var removed int
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("reputation"))

		// Collect keys to delete.
		var toDelete [][]byte
		_ = b.ForEach(func(k, v []byte) error {
			var entry ReputationEntry
			if json.Unmarshal(v, &entry) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if entry.CheckedAt.Before(cutoff) {
				toDelete = append(toDelete, append([]byte(nil), k...))
			}
			return nil
		})

		// Delete collected keys.
		for _, k := range toDelete {
			if err := b.Delete(k); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	if err != nil {
		// The transaction was rolled back; nothing was actually removed.
		return 0
	}
	return removed
}
// AllReputation loads the whole "reputation" bucket into a map keyed by IP.
// Rows that fail to decode are left out. The map is never nil; transaction
// errors are ignored.
func (db *DB) AllReputation() map[string]ReputationEntry {
	result := make(map[string]ReputationEntry)
	collect := func(k, v []byte) error {
		var row ReputationEntry
		if err := json.Unmarshal(v, &row); err == nil {
			result[string(k)] = row
		}
		return nil
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		return bucket.ForEach(collect)
	})
	return result
}
// SetAbuseQuotaExhaustedUntil persists t as the moment the AbuseIPDB quota
// is expected to reset. Callers should skip API queries while now < t.
// The value is stored in the "meta" bucket as an RFC 3339 string in UTC, so
// it survives restarts and multi-hour backoffs.
func (db *DB) SetAbuseQuotaExhaustedUntil(t time.Time) error {
	stamp := []byte(t.UTC().Format(time.RFC3339))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))
		return meta.Put([]byte(abuseQuotaExhaustedKey), stamp)
	})
}
// AbuseQuotaExhaustedUntil reads back the quota-reset time written by
// SetAbuseQuotaExhaustedUntil. It returns the zero time when nothing is
// stored or when the stored value is not valid RFC 3339.
func (db *DB) AbuseQuotaExhaustedUntil() time.Time {
	var until time.Time
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(abuseQuotaExhaustedKey))
		if raw == nil {
			return nil
		}
		// An unparseable value is treated the same as no value.
		if parsed, err := time.Parse(time.RFC3339, string(raw)); err == nil {
			until = parsed
		}
		return nil
	})
	return until
}
// IncrementAbuseQueryCount adds one to the AbuseIPDB query counter stored
// for utcDate (YYYY-MM-DD, UTC) and returns the new value. The counter acts
// as a daily circuit breaker. A missing or unparseable stored value counts
// as zero.
func (db *DB) IncrementAbuseQueryCount(utcDate string) int {
	var total int
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))
		counterKey := []byte(abuseDailyCountPrefix + utcDate)
		stored := meta.Get(counterKey)
		if stored != nil {
			_, _ = fmt.Sscanf(string(stored), "%d", &total)
		}
		total++
		return meta.Put(counterKey, []byte(fmt.Sprintf("%d", total)))
	})
	return total
}
// ReserveAbuseQuerySlots atomically reserves up to requested AbuseIPDB
// query slots for utcDate without increasing the daily counter beyond max.
//
// It returns the number of slots actually reserved: requested, capped at
// the room left under max, or 0 when requested or max is not positive or
// the counter has already reached max. A missing or unparseable stored
// counter counts as zero.
//
// If the write transaction fails the reservation was not persisted, so 0 is
// returned; handing out slots that were never counted would let callers
// exceed the daily cap.
func (db *DB) ReserveAbuseQuerySlots(utcDate string, requested, max int) int {
	if requested <= 0 || max <= 0 {
		return 0
	}
	var reserved int
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		key := []byte(abuseDailyCountPrefix + utcDate)
		count := 0
		if v := b.Get(key); v != nil {
			_, _ = fmt.Sscanf(string(v), "%d", &count)
		}
		if count >= max {
			return nil
		}
		remaining := max - count
		reserved = requested
		if reserved > remaining {
			reserved = remaining
		}
		count += reserved
		return b.Put(key, []byte(fmt.Sprintf("%d", count)))
	})
	if err != nil {
		// Rolled back: the counter was not advanced, so grant nothing.
		return 0
	}
	return reserved
}
// AbuseQueryCount reports the AbuseIPDB query counter stored for utcDate.
// A missing or unparseable value yields 0.
func (db *DB) AbuseQueryCount(utcDate string) int {
	var total int
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		counterKey := []byte(abuseDailyCountPrefix + utcDate)
		stored := tx.Bucket([]byte("meta")).Get(counterKey)
		if stored == nil {
			return nil
		}
		_, _ = fmt.Sscanf(string(stored), "%d", &total)
		return nil
	})
	return total
}
// EnforceReputationCap ensures the reputation bucket has at most max
// decodable entries. If the count exceeds max, the oldest entries (by
// CheckedAt) are deleted. Returns the count of entries removed.
//
// A negative max is treated as 0 (delete everything decodable); without the
// clamp the number of entries to delete would exceed the number collected
// and the delete loop would index out of range. Rows that fail to decode
// are neither counted nor deleted.
//
// If the write transaction fails, bbolt rolls the deletes back and the
// function reports 0.
func (db *DB) EnforceReputationCap(max int) int {
	if max < 0 {
		max = 0
	}
	var removed int
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("reputation"))

		// Collect all entries with their keys.
		type keyed struct {
			key       []byte
			checkedAt time.Time
		}
		var all []keyed
		_ = b.ForEach(func(k, v []byte) error {
			var entry ReputationEntry
			if json.Unmarshal(v, &entry) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			all = append(all, keyed{
				key:       append([]byte(nil), k...),
				checkedAt: entry.CheckedAt,
			})
			return nil
		})
		if len(all) <= max {
			return nil
		}

		// Sort by CheckedAt ascending (oldest first).
		sort.Slice(all, func(i, j int) bool {
			return all[i].checkedAt.Before(all[j].checkedAt)
		})

		// Delete the oldest entries beyond the cap.
		for _, victim := range all[:len(all)-max] {
			if err := b.Delete(victim.key); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	if err != nil {
		// The transaction was rolled back; nothing was actually removed.
		return 0
	}
	return removed
}
package store
import (
"encoding/json"
"errors"
"fmt"
"os"
"time"
bolt "go.etcd.io/bbolt"
)
// timeKeyLowerBound builds the smallest key string that a TimeKey for
// timestamp t can take: the date and clock fields of t as fixed-width
// decimal digits (YYYYMMDDhhmmss), nine digits of nanoseconds, and the
// literal suffix "-0000". Keys for earlier instants compare below the
// result and keys for t or later compare at or above it, provided TimeKey()
// uses this same layout.
func timeKeyLowerBound(t time.Time) string {
	year, month, day := t.Date()
	hour, minute, second := t.Clock()
	return fmt.Sprintf(
		"%04d%02d%02d%02d%02d%02d%09d-0000",
		year, int(month), day,
		hour, minute, second,
		t.Nanosecond(),
	)
}
// SweepHistoryOlderThan deletes history entries whose TimeKey is strictly
// older than cutoff. Returns the number of entries deleted. All work runs
// in a single bbolt transaction so the UI never sees a half-swept state;
// callers pick cutoffs that keep the batch bounded.
func (db *DB) SweepHistoryOlderThan(cutoff time.Time) (int, error) {
cutoffKey := timeKeyLowerBound(cutoff)
var deleted int
err := db.bolt.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("history"))
if b == nil {
return nil
}
c := b.Cursor()
for k, _ := c.First(); k != nil && string(k) < cutoffKey; k, _ = c.First() {
if err := c.Delete(); err != nil {
return err
}
deleted++
}
if deleted == 0 {
return nil
}
// Decrement the history:count counter without letting it underflow.
current := 0
if v := tx.Bucket([]byte("meta")).Get([]byte("history:count")); v != nil {
_, _ = fmt.Sscanf(string(v), "%d", ¤t)
}
newCount := current - deleted
if newCount < 0 {
newCount = 0
}
return setCounter(tx, "history:count", newCount)
})
return deleted, err
}
// SweepAttackEventsOlderThan deletes attacks:events entries older than
// cutoff and the matching entries from the attacks:events:ip secondary
// index. Returns the number of primary-bucket entries deleted.
func (db *DB) SweepAttackEventsOlderThan(cutoff time.Time) (int, error) {
cutoffKey := timeKeyLowerBound(cutoff)
var deleted int
err := db.bolt.Update(func(tx *bolt.Tx) error {
primary := tx.Bucket([]byte("attacks:events"))
secondary := tx.Bucket([]byte("attacks:events:ip"))
if primary == nil {
return nil
}
c := primary.Cursor()
for k, v := c.First(); k != nil && string(k) < cutoffKey; k, v = c.First() {
// The secondary index is keyed "<ip>/<TimeKey>", so we need
// the event's IP to prune it.
var ev AttackEvent
if err := json.Unmarshal(v, &ev); err == nil && secondary != nil {
secKey := []byte(ev.IP + "/" + string(k))
if err := secondary.Delete(secKey); err != nil {
return err
}
}
if err := c.Delete(); err != nil {
return err
}
deleted++
}
if deleted == 0 {
return nil
}
current := 0
if v := tx.Bucket([]byte("meta")).Get([]byte("attacks:events:count")); v != nil {
_, _ = fmt.Sscanf(string(v), "%d", ¤t)
}
newCount := current - deleted
if newCount < 0 {
newCount = 0
}
return setCounter(tx, "attacks:events:count", newCount)
})
return deleted, err
}
// Size reports how many bytes the bbolt file at db.path occupies on disk.
// bbolt does not shrink the file when rows are deleted, so comparing Size()
// before and after a CompactInto call shows how much space compaction would
// reclaim. The stat error, if any, is returned unchanged with a size of 0.
func (db *DB) Size() (int64, error) {
	stat, err := os.Stat(db.path)
	if err == nil {
		return stat.Size(), nil
	}
	return 0, err
}
// CompactInto copies the live DB into a new bbolt file at dstPath via
// bolt.Compact and returns the source file size and the compacted file
// size, both in bytes.
//
// Correctness: bolt.Compact holds a View transaction on the source for the
// whole walk, so a concurrent Update on the source either lands before the
// walk starts (and is part of the snapshot) or after it ends (and is not).
// Callers must quiesce writers between this call and the rename+reopen that
// promotes the new file; writes made after the snapshot are otherwise lost
// silently in the swap.
//
// txMaxSize caps the bytes written to the destination per transaction (see
// the bolt.Compact docs). Zero means a single transaction for the entire
// copy, the fastest option for DBs that fit comfortably in memory.
//
// On any failure after the destination was opened, the destination file is
// removed. The source size is still reported once it has been measured.
func (db *DB) CompactInto(dstPath string, txMaxSize int64) (srcSize, dstSize int64, err error) {
	if dstPath == "" {
		return 0, 0, errors.New("dst path is empty")
	}

	// Measure the source first so the reported number reflects the state
	// before any compaction work touched it.
	before, statErr := os.Stat(db.path)
	if statErr != nil {
		return 0, 0, fmt.Errorf("stat src: %w", statErr)
	}
	srcSize = before.Size()

	target, openErr := bolt.Open(dstPath, 0600, &bolt.Options{Timeout: 5 * time.Second})
	if openErr != nil {
		// bolt.Open may leave a zero-byte file behind when it fails.
		_ = os.Remove(dstPath)
		return srcSize, 0, fmt.Errorf("opening dst: %w", openErr)
	}

	failure := bolt.Compact(target, db.bolt, txMaxSize)
	closeErr := target.Close()
	if failure == nil && closeErr != nil {
		failure = fmt.Errorf("closing dst: %w", closeErr)
	}
	if failure != nil {
		_ = os.Remove(dstPath)
		return srcSize, 0, failure
	}

	after, statErr := os.Stat(dstPath)
	if statErr != nil {
		return srcSize, 0, fmt.Errorf("stat dst: %w", statErr)
	}
	return srcSize, after.Size(), nil
}
// SweepReputationOlderThan removes every reputation entry whose CheckedAt
// lies strictly before cutoff and returns the number of rows deleted.
//
// The bucket is keyed by IP rather than by time, so each value has to be
// decoded and inspected. Rows that fail to decode are skipped and do not
// abort the sweep. A missing bucket is a no-op.
func (db *DB) SweepReputationOlderThan(cutoff time.Time) (int, error) {
	var deleted int
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		if bucket == nil {
			return nil
		}

		var stale [][]byte
		scanErr := bucket.ForEach(func(k, v []byte) error {
			var row ReputationEntry
			if json.Unmarshal(v, &row) != nil {
				// Malformed row: leave it alone and keep scanning.
				return nil
			}
			if !row.CheckedAt.Before(cutoff) {
				return nil
			}
			// k is only valid during the callback, so keep a copy.
			owned := append([]byte(nil), k...)
			stale = append(stale, owned)
			return nil
		})
		if scanErr != nil {
			return scanErr
		}

		for i := range stale {
			if err := bucket.Delete(stale[i]); err != nil {
				return err
			}
			deleted++
		}
		return nil
	})
	return deleted, err
}
package store
import (
"encoding/json"
"errors"
"time"
bolt "go.etcd.io/bbolt"
)
// Persistence helpers for the signature-update watcher's mtime map.
// The daemon stores the last-seen mtime per signature file in bbolt
// so a restart does not trigger a phantom rescan -- without this the
// in-memory map starts empty after every restart and every file
// looks new.
//
// sigWatchKey is the single key in the "sig_watch" bucket under which the
// whole map is stored as one JSON document.
const sigWatchKey = "last_mtimes"
// GetSignatureMtimes returns the persisted mtime map. The result is always
// non-nil -- also when the bucket is missing, when no value has been stored
// yet, or when the stored document is JSON null -- so callers can range
// over it or insert into it without a nil check.
//
// A decode error is returned alongside whatever entries were parsed before
// the failure.
func (db *DB) GetSignatureMtimes() (map[string]time.Time, error) {
	out := map[string]time.Time{}
	err := db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("sig_watch"))
		if b == nil {
			return nil
		}
		raw := b.Get([]byte(sigWatchKey))
		if len(raw) == 0 {
			return nil
		}
		return json.Unmarshal(raw, &out)
	})
	if out == nil {
		// encoding/json sets a map to nil when the document is "null",
		// which is what PutSignatureMtimes(nil) persists. Restore the
		// documented empty-but-non-nil result.
		out = map[string]time.Time{}
	}
	return out, err
}
// PutSignatureMtimes replaces the persisted mtime map with m, stored as one
// JSON document under sigWatchKey in the "sig_watch" bucket. The watcher
// calls it after every walk whether or not a file changed, so that files
// which no longer exist drop out of the store.
//
// It returns the encoding error, a "sig_watch bucket missing" error when the
// bucket does not exist, or the storage error.
func (db *DB) PutSignatureMtimes(m map[string]time.Time) error {
	encoded, err := json.Marshal(m)
	if err != nil {
		return err
	}
	persist := func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("sig_watch"))
		if bucket != nil {
			return bucket.Put([]byte(sigWatchKey), encoded)
		}
		return errors.New("sig_watch bucket missing (store not migrated)")
	}
	return db.bolt.Update(persist)
}
package store
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// stats:daily holds pre-aggregated SeverityBucket counters keyed by
// "YYYY-MM-DD" date. It is updated atomically with every history insert
// so the 30-day trend chart is decoupled from history pruning. The
// bucket has at most dailyRetentionDays rows and grows by ~50 bytes/day.
//
// bucketStatsDaily is the bucket name. metaStatsDailyBackfilled is the key
// in the "meta" bucket that BackfillStatsDaily writes once the one-time
// backfill has been applied.
const (
bucketStatsDaily = "stats:daily"
metaStatsDailyBackfilled = "stats:daily:backfilled"
)
// dailyRetentionDays caps how far back stats:daily keeps per-day rows.
// pruneStatsDaily deletes rows dated before that window.
// Var (not const) so tests can override.
var dailyRetentionDays = 365
// incrStatsDaily bumps the stats:daily row for t's calendar date by one
// finding of severity sev. Total is always incremented; the per-severity
// counter is incremented for Critical, High and Warning.
//
// It runs inside the caller's write transaction and never commits. A stored
// row that fails to decode is replaced by a fresh one rather than blocking
// the update. A missing stats:daily bucket is an error.
func incrStatsDaily(tx *bolt.Tx, t time.Time, sev alert.Severity) error {
	daily := tx.Bucket([]byte(bucketStatsDaily))
	if daily == nil {
		return fmt.Errorf("bucket %s missing", bucketStatsDaily)
	}
	day := []byte(t.Format("2006-01-02"))

	var row SeverityBucket
	if stored := daily.Get(day); stored != nil && json.Unmarshal(stored, &row) != nil {
		// Corrupted row: start over instead of refusing to record.
		row = SeverityBucket{}
	}

	row.Total++
	switch sev {
	case alert.Warning:
		row.Warning++
	case alert.High:
		row.High++
	case alert.Critical:
		row.Critical++
	}

	encoded, err := json.Marshal(row)
	if err != nil {
		return err
	}
	return daily.Put(day, encoded)
}
// pruneStatsDaily deletes stats:daily rows older than dailyRetentionDays.
// Cheap because the bucket is bounded to ~dailyRetentionDays entries and
// keys sort lexicographically as YYYY-MM-DD.
//
// The cutoff is the date dailyRetentionDays-1 days before now; rows whose
// key sorts before it are removed. A missing bucket is a no-op. Must run
// inside a writable transaction.
func pruneStatsDaily(tx *bolt.Tx, now time.Time) error {
	b := tx.Bucket([]byte(bucketStatsDaily))
	if b == nil {
		return nil
	}
	cutoff := now.AddDate(0, 0, -(dailyRetentionDays - 1)).Format("2006-01-02")
	c := b.Cursor()
	// Re-seek to the first key after each delete rather than calling Next:
	// depending on the bbolt version, Next after Delete can step over the
	// row that moved into the deleted slot. This is the same pattern the
	// history and attack-event sweeps use.
	for k, _ := c.First(); k != nil && string(k) < cutoff; k, _ = c.First() {
		if err := c.Delete(); err != nil {
			return err
		}
	}
	return nil
}
// BackfillStatsDaily populates stats:daily from the history bucket the
// first time it runs after an upgrade. A sentinel key in the meta bucket
// marks completion, so later calls return immediately; this also covers
// hosts whose meta:migrated sentinel predates stats:daily.
//
// The work happens in three steps: a read-only check of the sentinel, a
// read-only scan of history that tallies findings per calendar day in
// memory, and one write transaction that re-checks the sentinel, adds the
// tallies onto any existing rows and writes the sentinel. History rows that
// fail to decode are skipped; existing stats rows that fail to decode are
// replaced.
func (db *DB) BackfillStatsDaily() error {
	sentinelKey := []byte(metaStatsDailyBackfilled)

	done := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		done = tx.Bucket([]byte("meta")).Get(sentinelKey) != nil
		return nil
	})
	if done {
		return nil
	}

	// Tally in memory under a read transaction so the write lock is not
	// held while a potentially large history bucket is scanned.
	type tally struct {
		critical, high, warning, total int
	}
	byDay := make(map[string]*tally)
	scanErr := db.bolt.View(func(tx *bolt.Tx) error {
		history := tx.Bucket([]byte("history"))
		if history == nil {
			return nil
		}
		cur := history.Cursor()
		for k, raw := cur.First(); k != nil; k, raw = cur.Next() {
			var finding alert.Finding
			if json.Unmarshal(raw, &finding) != nil {
				continue
			}
			day := finding.Timestamp.Format("2006-01-02")
			t := byDay[day]
			if t == nil {
				t = &tally{}
				byDay[day] = t
			}
			t.total++
			switch finding.Severity {
			case alert.Warning:
				t.warning++
			case alert.High:
				t.high++
			case alert.Critical:
				t.critical++
			}
		}
		return nil
	})
	if scanErr != nil {
		return scanErr
	}

	// Merge the tallies and set the sentinel in one transaction.
	apply := func(tx *bolt.Tx) error {
		meta := tx.Bucket([]byte("meta"))
		// Another process may have finished the backfill in the meantime.
		if meta.Get(sentinelKey) != nil {
			return nil
		}
		daily := tx.Bucket([]byte(bucketStatsDaily))
		if daily == nil {
			return fmt.Errorf("bucket %s missing", bucketStatsDaily)
		}
		for day, t := range byDay {
			var row SeverityBucket
			if stored := daily.Get([]byte(day)); stored != nil && json.Unmarshal(stored, &row) != nil {
				row = SeverityBucket{}
			}
			row.Critical += t.critical
			row.High += t.high
			row.Warning += t.warning
			row.Total += t.total
			encoded, err := json.Marshal(row)
			if err != nil {
				return err
			}
			if err := daily.Put([]byte(day), encoded); err != nil {
				return err
			}
		}
		return tx.Bucket([]byte("meta")).Put(
			sentinelKey,
			[]byte(time.Now().Format(time.RFC3339)),
		)
	}
	return db.bolt.Update(apply)
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// PermanentBlockEntry represents an IP permanently blocked by the threat system.
// Entries are stored as JSON in the "threats" bucket, keyed by the IP string.
type PermanentBlockEntry struct {
IP string `json:"ip"` // same value as the bucket key
Reason string `json:"reason"` // caller-supplied reason passed to AddPermanentBlock
BlockedAt time.Time `json:"blocked_at"` // set to time.Now() each time AddPermanentBlock writes the row
}
// WhitelistEntry represents an IP that should bypass threat checks.
// Entries are stored as JSON in the "threats:whitelist" bucket, keyed by
// the IP string.
type WhitelistEntry struct {
IP string `json:"ip"` // same value as the bucket key
ExpiresAt time.Time `json:"expires_at"` // entry counts as whitelisted while this is in the future
Permanent bool `json:"permanent"` // when true the entry never expires and ExpiresAt is not consulted
}
// AddPermanentBlock writes a permanent-block row for ip with the given
// reason and the current time as BlockedAt, overwriting any existing row.
// The threats:count counter is incremented only when ip was not already
// present.
func (db *DB) AddPermanentBlock(ip, reason string) error {
	key := []byte(ip)
	return db.bolt.Update(func(tx *bolt.Tx) error {
		threats := tx.Bucket([]byte("threats"))
		existed := threats.Get(key) != nil

		payload, err := json.Marshal(PermanentBlockEntry{
			IP:        ip,
			Reason:    reason,
			BlockedAt: time.Now(),
		})
		if err != nil {
			return err
		}
		if err := threats.Put(key, payload); err != nil {
			return err
		}
		if existed {
			return nil
		}
		return incrCounter(tx, "threats:count", 1)
	})
}
// RemovePermanentBlock deletes ip from the permanent block list and lowers
// threats:count by one. When ip is not in the list nothing is changed and
// nil is returned.
func (db *DB) RemovePermanentBlock(ip string) error {
	key := []byte(ip)
	return db.bolt.Update(func(tx *bolt.Tx) error {
		threats := tx.Bucket([]byte("threats"))
		if threats.Get(key) != nil {
			if err := threats.Delete(key); err != nil {
				return err
			}
			return incrCounter(tx, "threats:count", -1)
		}
		return nil
	})
}
// GetPermanentBlock fetches the permanent-block row for ip.
// The boolean is true only when a row exists and decodes cleanly; a missing
// or corrupt row reports false. Transaction errors are ignored.
func (db *DB) GetPermanentBlock(ip string) (PermanentBlockEntry, bool) {
	var (
		result PermanentBlockEntry
		ok     bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		threats := tx.Bucket([]byte("threats"))
		if raw := threats.Get([]byte(ip)); raw != nil {
			// A corrupt row is reported as "not found".
			ok = json.Unmarshal(raw, &result) == nil
		}
		return nil
	})
	return result, ok
}
// AllPermanentBlocks lists every decodable row of the permanent block list
// in bucket key order. Corrupt rows are left out; the result is nil when
// there is nothing to return. Transaction errors are ignored.
func (db *DB) AllPermanentBlocks() []PermanentBlockEntry {
	var result []PermanentBlockEntry
	collect := func(_, v []byte) error {
		var row PermanentBlockEntry
		if err := json.Unmarshal(v, &row); err == nil {
			result = append(result, row)
		}
		return nil
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		threats := tx.Bucket([]byte("threats"))
		return threats.ForEach(collect)
	})
	return result
}
// AddWhitelistEntry stores a whitelist row for ip with the given expiry and
// permanence flag, overwriting any existing row for that IP.
func (db *DB) AddWhitelistEntry(ip string, expiresAt time.Time, permanent bool) error {
	store := func(tx *bolt.Tx) error {
		whitelist := tx.Bucket([]byte("threats:whitelist"))
		payload, err := json.Marshal(WhitelistEntry{
			IP:        ip,
			ExpiresAt: expiresAt,
			Permanent: permanent,
		})
		if err != nil {
			return err
		}
		return whitelist.Put([]byte(ip), payload)
	}
	return db.bolt.Update(store)
}
// RemoveWhitelistEntry deletes the whitelist row stored for ip and returns
// the storage error, if any.
func (db *DB) RemoveWhitelistEntry(ip string) error {
	remove := func(tx *bolt.Tx) error {
		whitelist := tx.Bucket([]byte("threats:whitelist"))
		return whitelist.Delete([]byte(ip))
	}
	return db.bolt.Update(remove)
}
// IsWhitelisted reports whether ip has a whitelist row that is either
// permanent or not yet expired. A missing or corrupt row reports false.
func (db *DB) IsWhitelisted(ip string) bool {
	allowed := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("threats:whitelist")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		var row WhitelistEntry
		if err := json.Unmarshal(raw, &row); err != nil {
			// A corrupt row never grants a bypass.
			return nil
		}
		allowed = row.Permanent || row.ExpiresAt.After(time.Now())
		return nil
	})
	return allowed
}
// ListWhitelist returns every decodable whitelist row in bucket key order,
// expired ones included; filtering is left to the caller. Corrupt rows are
// left out and the result is nil when there is nothing to return.
func (db *DB) ListWhitelist() []WhitelistEntry {
	var result []WhitelistEntry
	collect := func(_, v []byte) error {
		var row WhitelistEntry
		if err := json.Unmarshal(v, &row); err == nil {
			result = append(result, row)
		}
		return nil
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		whitelist := tx.Bucket([]byte("threats:whitelist"))
		return whitelist.ForEach(collect)
	})
	return result
}
// PruneExpiredWhitelist deletes expired non-permanent whitelist entries.
// Returns the count of entries removed. Uses a collect-then-delete pattern
// because the bucket must not be mutated during ForEach iteration.
//
// An entry is expired when its ExpiresAt is not after the time the call
// started. Rows that fail to decode are left in place. If the write
// transaction fails, bbolt rolls the deletes back and the function
// reports 0.
func (db *DB) PruneExpiredWhitelist() int {
	now := time.Now()
	var removed int
	err := db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("threats:whitelist"))

		// Collect keys to delete.
		var toDelete [][]byte
		_ = b.ForEach(func(k, v []byte) error {
			var entry WhitelistEntry
			if json.Unmarshal(v, &entry) != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if !entry.Permanent && !entry.ExpiresAt.After(now) {
				toDelete = append(toDelete, append([]byte(nil), k...))
			}
			return nil
		})

		// Delete collected keys.
		for _, k := range toDelete {
			if err := b.Delete(k); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	if err != nil {
		// The transaction was rolled back; nothing was actually removed.
		return 0
	}
	return removed
}
package threat
import (
"encoding/json"
"os"
"path/filepath"
"time"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/store"
)
// IPIntelligence is the complete picture of an IP from all sources.
// Lookup and LookupBatch fill in everything except the GeoIP fields;
// computeVerdict derives UnifiedScore and Verdict from the other fields.
type IPIntelligence struct {
IP string `json:"ip"`
// Internal attack history
AttackRecord *attackdb.IPRecord `json:"attack_record,omitempty"`
LocalScore int `json:"local_score"` // copied from AttackRecord.ThreatScore when a record exists
// ThreatDB (feeds + permanent blocklist)
InThreatDB bool `json:"in_threat_db"`
ThreatDBSource string `json:"threat_db_source,omitempty"`
// AbuseIPDB cache
AbuseScore int `json:"abuse_score"` // -1 when the IP is not in the cache
AbuseCategory string `json:"abuse_category,omitempty"`
// Firewall state
CurrentlyBlocked bool `json:"currently_blocked"`
BlockReason string `json:"block_reason,omitempty"`
BlockedAt *time.Time `json:"blocked_at,omitempty"` // nil when the block time is not recorded
BlockExpiresAt *time.Time `json:"block_expires_at,omitempty"` // nil when no expiry is recorded
BlockPermanent bool `json:"block_permanent,omitempty"`
// GeoIP (populated by API layer, not by Lookup)
Country string `json:"country,omitempty"`
CountryName string `json:"country_name,omitempty"`
City string `json:"city,omitempty"`
ASN uint `json:"asn,omitempty"`
ASOrg string `json:"as_org,omitempty"`
Network string `json:"network,omitempty"`
// Composite
UnifiedScore int `json:"unified_score"`
Verdict string `json:"verdict"` // "clean", "suspicious", "malicious", "blocked"
}
// Lookup returns the full intelligence picture for an IP.
// All reads are local (no network calls).
//
// It is the single-IP form of LookupBatch and delegates to it, so the two
// entry points always apply the same sources in the same order: shared
// state files are loaded once, then attack DB, ThreatDB, AbuseIPDB cache
// and block state are consulted and the verdict is computed. The result is
// never nil.
func Lookup(ip, statePath string) *IPIntelligence {
	return LookupBatch([]string{ip}, statePath)[0]
}
// LookupBatch builds one intelligence record per input IP, in input order.
//
// The AbuseIPDB cache and firewall block state are loaded a single time
// and shared across all IPs rather than being re-read for each one.
func LookupBatch(ips []string, statePath string) []*IPIntelligence {
	reputation := loadFullAbuseCache(statePath)
	blocks := loadFullBlockState(statePath)

	out := make([]*IPIntelligence, len(ips))
	for idx := range ips {
		addr := ips[idx]
		// -1 marks "no cached AbuseIPDB result".
		record := &IPIntelligence{IP: addr, AbuseScore: -1}

		// Attack DB: local attack history and its threat score.
		if db := attackdb.Global(); db != nil {
			if hit := db.LookupIP(addr); hit != nil {
				record.AttackRecord = hit
				record.LocalScore = hit.ThreatScore
			}
		}

		// ThreatDB: feeds + permanent blocklist.
		if db := checks.GetThreatDB(); db != nil {
			if feed, listed := db.Lookup(addr); listed {
				record.InThreatDB = true
				record.ThreatDBSource = feed
			}
		}

		// AbuseIPDB reputation from the pre-loaded cache.
		if cached, ok := reputation[addr]; ok {
			record.AbuseScore = cached.Score
			record.AbuseCategory = cached.Category
		}

		// Firewall block state from the pre-loaded map, then the verdict.
		applyBlockState(record, blocks)
		computeVerdict(record)
		out[idx] = record
	}
	return out
}
// computeVerdict fills in intel.UnifiedScore and intel.Verdict.
//
// The unified score is the larger of LocalScore and AbuseScore, raised to
// 100 when the IP is listed in the ThreatDB. A currently blocked IP is
// always reported as "blocked"; otherwise the score decides: >= 80 is
// "malicious", >= 40 is "suspicious", anything lower is "clean".
func computeVerdict(intel *IPIntelligence) {
	unified := intel.LocalScore
	if intel.AbuseScore > unified {
		unified = intel.AbuseScore
	}
	if intel.InThreatDB && unified < 100 {
		unified = 100
	}
	intel.UnifiedScore = unified

	verdict := "clean"
	switch {
	case intel.CurrentlyBlocked:
		verdict = "blocked"
	case unified >= 80:
		verdict = "malicious"
	case unified >= 40:
		verdict = "suspicious"
	}
	intel.Verdict = verdict
}
// applyBlockState copies the firewall block entry for intel.IP (if any)
// onto intel. When the IP is absent from blockMap, intel is left untouched.
// Timestamps are exposed as pointers to private copies and only when they
// carry a real value, so zero times are omitted from the JSON output.
func applyBlockState(intel *IPIntelligence, blockMap map[string]*blockEntry) {
	entry, blocked := blockMap[intel.IP]
	if !blocked {
		return
	}

	intel.CurrentlyBlocked = true
	intel.BlockReason = entry.reason
	intel.BlockPermanent = entry.permanent

	if blockedAt := entry.blockedAt; !blockedAt.IsZero() {
		intel.BlockedAt = &blockedAt
	}
	// A year of 1 or lower is treated as "no expiry recorded".
	if expiresAt := entry.expiresAt; !expiresAt.IsZero() && expiresAt.Year() > 1 {
		intel.BlockExpiresAt = &expiresAt
	}
}
// --- AbuseIPDB cache reader ---
// abuseEntry is the trimmed form of a cached AbuseIPDB result held in the
// map returned by loadFullAbuseCache: only the score and category are kept.
type abuseEntry struct {
Score int `json:"score"`
Category string `json:"category"`
}
// loadFullAbuseCache returns every usable AbuseIPDB cache entry keyed by IP.
//
// Entries checked more than six hours ago, or carrying a negative score
// (the error sentinel), are left out. When the global store is available
// it is the only source consulted; otherwise the legacy flat file
// reputation_cache.json under statePath is read. Any read or parse
// failure yields an empty (never nil) map.
func loadFullAbuseCache(statePath string) map[string]*abuseEntry {
	cache := make(map[string]*abuseEntry)
	cutoff := time.Now().Add(-6 * time.Hour)

	// Try bbolt store first - after migration the flat file is renamed to .bak.
	if db := store.Global(); db != nil {
		for addr, rep := range db.AllReputation() {
			// Keep only fresh entries with a real (non-sentinel) score.
			if rep.Score >= 0 && !rep.CheckedAt.Before(cutoff) {
				cache[addr] = &abuseEntry{Score: rep.Score, Category: rep.Category}
			}
		}
		return cache
	}

	// Fallback: flat-file JSON (pre-migration).
	type legacyEntry struct {
		Score     int       `json:"score"`
		Category  string    `json:"category"`
		CheckedAt time.Time `json:"checked_at"`
	}
	type legacyFile struct {
		Entries map[string]*legacyEntry `json:"entries"`
	}

	// #nosec G304 -- filepath.Join under operator-configured statePath.
	raw, err := os.ReadFile(filepath.Join(statePath, "reputation_cache.json"))
	if err != nil {
		return cache
	}
	var parsed legacyFile
	if json.Unmarshal(raw, &parsed) != nil {
		return cache
	}
	if parsed.Entries == nil {
		return cache
	}
	for addr, rep := range parsed.Entries {
		// Keep only fresh entries with a real (non-sentinel) score.
		if rep.Score >= 0 && !rep.CheckedAt.Before(cutoff) {
			cache[addr] = &abuseEntry{Score: rep.Score, Category: rep.Category}
		}
	}
	return cache
}
// --- Firewall block state reader ---
// blockEntry is the in-memory form of one firewall block as assembled by
// loadFullBlockState and consumed by applyBlockState.
type blockEntry struct {
reason string
blockedAt time.Time
expiresAt time.Time
// permanent is set by loadFullBlockState when expiresAt is the zero
// time or has a year of 1 or lower.
permanent bool
}
// loadFullBlockState returns the current block set keyed by IP.
//
// Two sources are merged: the firewall engine state loaded through
// firewall.LoadState, and the legacy blocked_ips.json file under
// statePath. Firewall-engine entries win on conflict; legacy entries are
// added only when not yet expired and not already present. Failures to
// load either source are ignored, so the result may be empty but is
// never nil.
func loadFullBlockState(statePath string) map[string]*blockEntry {
	blocks := make(map[string]*blockEntry)
	now := time.Now()

	// Read the authoritative firewall engine state (flat-file state.json).
	// The bbolt fw:blocked bucket is written only at migration, so it would
	// return a stale snapshot rather than the live block set.
	if fwState, err := firewall.LoadState(statePath); err == nil && fwState != nil {
		for _, b := range fwState.Blocked {
			blocks[b.IP] = &blockEntry{
				reason:    b.Reason,
				blockedAt: b.BlockedAt,
				expiresAt: b.ExpiresAt,
				permanent: b.ExpiresAt.IsZero() || b.ExpiresAt.Year() <= 1,
			}
		}
	}

	// CSM blocked_ips.json (legacy)
	type legacyEntry struct {
		IP        string    `json:"ip"`
		Reason    string    `json:"reason"`
		BlockedAt time.Time `json:"blocked_at"`
		ExpiresAt time.Time `json:"expires_at"`
	}
	type legacyFile struct {
		IPs []legacyEntry `json:"ips"`
	}

	// #nosec G304 -- filepath.Join under operator-configured statePath.
	raw, err := os.ReadFile(filepath.Join(statePath, "blocked_ips.json"))
	if err != nil {
		return blocks
	}
	var legacy legacyFile
	if json.Unmarshal(raw, &legacy) != nil {
		return blocks
	}
	for _, b := range legacy.IPs {
		// Skip legacy entries whose expiry has already passed.
		if !b.ExpiresAt.IsZero() && !now.Before(b.ExpiresAt) {
			continue
		}
		// The firewall engine state takes precedence over the legacy file.
		if _, known := blocks[b.IP]; known {
			continue
		}
		blocks[b.IP] = &blockEntry{
			reason:    b.Reason,
			blockedAt: b.BlockedAt,
			expiresAt: b.ExpiresAt,
			permanent: b.ExpiresAt.IsZero() || b.ExpiresAt.Year() <= 1,
		}
	}
	return blocks
}
package threatintel
import "context"
// AbuseIPDBSource exposes an injected AbuseIPDB reputation lookup through
// the Source interface. The lookup function is supplied at construction
// time rather than imported, which avoids an import cycle with the
// package that owns the real query path.
type AbuseIPDBSource struct {
	lookup func(ctx context.Context, ip string) (int, error)
}

// NewAbuseIPDBSource returns a source that delegates every Score call to
// lookup. The caller (typically internal/checks/reputation.go) passes a
// closure over its existing AbuseIPDB query path.
func NewAbuseIPDBSource(lookup func(context.Context, string) (int, error)) *AbuseIPDBSource {
	src := new(AbuseIPDBSource)
	src.lookup = lookup
	return src
}

// Name returns the fixed identifier of this source, "abuseipdb".
func (a *AbuseIPDBSource) Name() string {
	const name = "abuseipdb"
	return name
}

// Score forwards to the injected lookup function and returns its score
// and error unchanged.
func (a *AbuseIPDBSource) Score(ctx context.Context, ip string) (int, error) {
	score, err := a.lookup(ctx, ip)
	return score, err
}
// Package threatintel -- bot allowlist + verification.
//
// botallowlist.go owns the embedded static IP CIDR ranges and the UA
// substring -> claimed-bot mapping. The static snapshots are a fast positive
// allow path: if the source IP falls inside a published bot range, skip the
// request without rDNS. Refreshed at runtime by the existing
// `csm update-rules` plumbing (separate task -- not in this commit).
package threatintel
import (
_ "embed"
"encoding/json"
"net"
"strings"
"sync"
)
// googlebotJSON holds the raw bytes of the embedded embed/googlebot.json
// snapshot; DefaultRanges parses it into the "googlebot" ranges.
//
//go:embed embed/googlebot.json
var googlebotJSON []byte
// bingbotJSON holds the raw bytes of the embedded embed/bingbot.json
// snapshot; DefaultRanges parses it into the "bingbot" ranges.
//
//go:embed embed/bingbot.json
var bingbotJSON []byte
// applebotJSON holds the raw bytes of the embedded embed/applebot.json
// snapshot; DefaultRanges parses it into the "applebot" ranges.
//
//go:embed embed/applebot.json
var applebotJSON []byte
// BotRanges holds the parsed allowlist data, indexed by claimed-bot
// identity ("googlebot", "bingbot", "applebot").
type BotRanges struct {
// byBot maps a bot identity to the CIDR networks parsed for it. An
// identity with no successfully parsed prefixes has no key.
byBot map[string][]*net.IPNet
}
// embedFile is the JSON shape of an embedded range snapshot: a list of
// prefixes, each carrying an IPv4 and/or IPv6 CIDR string. DefaultRanges
// skips empty strings and strings that fail to parse as a CIDR.
type embedFile struct {
Prefixes []struct {
IPv4 string `json:"ipv4Prefix"`
IPv6 string `json:"ipv6Prefix"`
} `json:"prefixes"`
}
// defaultRanges is the lazily built singleton returned by DefaultRanges;
// rangesOnce ensures the embedded snapshots are parsed exactly once.
var (
defaultRanges *BotRanges
rangesOnce sync.Once
)
// DefaultRanges returns the process-wide BotRanges built from the embedded
// Googlebot, Bingbot and Applebot snapshots. Parsing happens on the first
// call only, guarded by a sync.Once, so concurrent callers are safe.
// A snapshot that fails to decode, and any prefix that is empty or not a
// valid CIDR, is skipped without error.
func DefaultRanges() *BotRanges {
	rangesOnce.Do(func() {
		byBot := make(map[string][]*net.IPNet)
		defaultRanges = &BotRanges{byBot: byBot}

		snapshots := map[string][]byte{
			"googlebot": googlebotJSON,
			"bingbot":   bingbotJSON,
			"applebot":  applebotJSON,
		}
		for identity, blob := range snapshots {
			var doc embedFile
			if json.Unmarshal(blob, &doc) != nil {
				continue
			}
			for _, prefix := range doc.Prefixes {
				// Each prefix may carry an IPv4 CIDR, an IPv6 CIDR, or both.
				for _, cidr := range [2]string{prefix.IPv4, prefix.IPv6} {
					if cidr == "" {
						continue
					}
					_, network, err := net.ParseCIDR(cidr)
					if err != nil {
						continue
					}
					byBot[identity] = append(byBot[identity], network)
				}
			}
		}
	})
	return defaultRanges
}
// IPInBot reports whether ip lies inside one of the static ranges stored
// for the given bot identity. A nil ip, or an identity with no ranges,
// yields false.
func (r *BotRanges) IPInBot(ip net.IP, bot string) bool {
	if ip != nil {
		for _, network := range r.byBot[bot] {
			if network.Contains(ip) {
				return true
			}
		}
	}
	return false
}
// IPInAnyBot reports whether ip lies inside the static range of any bot
// identity held by r (the Googlebot/Bingbot/Applebot snapshots). Unlike
// IPInBot no claimed user agent is needed, so callers that only have an
// IP (e.g. the incident correlator's whitelist backstop) can still
// recognise a published crawler address.
//
// Only crawlers that publish authoritative IP ranges are covered. CDN
// edge ranges are deliberately NOT included: legitimate and malicious
// traffic share a CDN's egress IPs, so whitelisting them would hide
// attacks proxied through the CDN.
func (r *BotRanges) IPInAnyBot(ip net.IP) bool {
	if ip == nil {
		return false
	}
	for identity := range r.byBot {
		if r.IPInBot(ip, identity) {
			return true
		}
	}
	return false
}
// ClaimedBotFromUA maps a User-Agent string to the lower-case identity of
// the bot it claims to be, or "" when no known bot marker is present.
// Matching is a case-insensitive substring test, and the first identity
// in the table below with a matching marker wins. Identities match the
// BotDomains keys in botverify.go so the async verifier can look up the
// right DNS suffix list.
func ClaimedBotFromUA(ua string) string {
	known := [...]struct {
		identity string
		markers  []string
	}{
		{"googlebot", []string{"googlebot"}},
		{"bingbot", []string{"bingbot"}},
		{"applebot", []string{"applebot"}},
		// Appendix A bots: no published static IP range.
		{"duckduckbot", []string{"duckduckbot"}},
		{"amazonbot", []string{"amazonbot"}},
		{"gptbot", []string{"gptbot", "chatgpt-user"}},
		{"claudebot", []string{"claudebot", "claude-searchbot"}},
		{"perplexitybot", []string{"perplexitybot"}},
		{"facebookbot", []string{"meta-externalagent", "meta-webindexer", "facebookexternalhit"}},
		{"bravebot", []string{"bravebot"}},
	}

	lowered := strings.ToLower(ua)
	for _, bot := range known {
		for _, marker := range bot.markers {
			if strings.Contains(lowered, marker) {
				return bot.identity
			}
		}
	}
	return ""
}
package threatintel
import (
"context"
"errors"
"net"
"strings"
"sync"
"time"
)
// resolver is the subset of DNS operations the verifier needs: a reverse
// (PTR) lookup and a forward address lookup. net.DefaultResolver is
// passed where this interface is expected; tests can supply a fake.
type resolver interface {
	LookupAddr(ctx context.Context, ip string) ([]string, error)
	LookupIP(ctx context.Context, network, host string) ([]net.IP, error)
}

// verifier pairs a resolver with the DNS suffix list accepted for one
// bot identity. In practice there is one verifier per bot identity;
// tests construct them directly.
type verifier struct {
	res     resolver
	domains []string // lower-case suffix list, e.g. "googlebot.com"
}

// newVerifier returns a verifier that uses r for DNS and a lower-cased
// copy of domains as its accepted suffix list. The input slice is not
// modified.
func newVerifier(r resolver, domains []string) *verifier {
	normalized := make([]string, 0, len(domains))
	for _, domain := range domains {
		normalized = append(normalized, strings.ToLower(domain))
	}
	return &verifier{res: r, domains: normalized}
}
// LogicVersion identifies the current shape of the bot-verifier logic
// (BotDomains suffix list, ClaimedBotFromUA mapping, no-PTR semantics).
// Bump this whenever a change here would invalidate cache entries
// written by an older build. Examples: adding a new domain suffix
// that turns prior negatives into positives, or adding a new UA -> bot
// identity mapping. The daemon calls store.DB.EnsureBotVerifyLogicVersion
// at startup with this value; a mismatch wipes the botverify bucket so
// the next scan re-verifies every IP under the new rules.
const LogicVersion = 2
// ErrUnverifiable signals that the resolver returned no usable PTR for
// the source IP (a definitive not-found answer or an empty name list),
// so the verifier cannot prove or disprove the claimed bot identity.
// Callers treat this as fail-open: do not cache, do not flag as spoof.
// Genuine spoof signals -- PTR present but outside the bot's domain
// suffix list, or forward-confirm mismatch -- still return (false, nil).
var ErrUnverifiable = errors.New("bot verify: no PTR record for source IP")
// verify performs the PTR + forward-A round trip (Google's official
// method) for ip against this verifier's domain suffix list.
//
// Results:
//   - (true, nil): a PTR name inside the suffix list resolves forward to ip.
//   - (false, nil): definitive negative -- no PTR name is inside the suffix
//     list, the forward lookup reports not-found, or the forward answers
//     do not include ip.
//   - (false, ErrUnverifiable): the IP has no PTR at all.
//   - (false, err): context cancellation (ctx.Err() takes precedence) or a
//     transient resolver failure.
//
// Both error outcomes make the async worker skip the cache write, so
// unverifiable IPs do not get pinned as spoof for the TTL window. Only
// the first PTR name that matches the suffix list is forward-confirmed.
// The bot argument is not consulted here.
func (v *verifier) verify(ctx context.Context, ip net.IP, bot string) (bool, error) {
	ptrNames, err := v.res.LookupAddr(ctx, ip.String())
	if err != nil {
		switch ctxErr := ctx.Err(); {
		case ctxErr != nil:
			return false, ctxErr
		case isDNSNotFound(err):
			return false, ErrUnverifiable
		default:
			return false, err
		}
	}
	if len(ptrNames) == 0 {
		return false, ErrUnverifiable
	}

	// inBotDomain reports whether host equals, or is a subdomain of, one
	// of the accepted suffixes.
	inBotDomain := func(host string) bool {
		for _, suffix := range v.domains {
			if host == suffix || strings.HasSuffix(host, "."+suffix) {
				return true
			}
		}
		return false
	}

	confirmHost := ""
	for _, name := range ptrNames {
		// PTR answers are usually fully qualified; drop the trailing dot.
		host := strings.ToLower(strings.TrimSuffix(name, "."))
		if host != "" && inBotDomain(host) {
			confirmHost = host
			break
		}
	}
	if confirmHost == "" {
		return false, nil
	}

	forward, err := v.res.LookupIP(ctx, "ip", confirmHost)
	if err != nil {
		switch ctxErr := ctx.Err(); {
		case ctxErr != nil:
			return false, ctxErr
		case isDNSNotFound(err):
			// The claimed host does not exist: a definitive negative.
			return false, nil
		default:
			return false, err
		}
	}
	for i := range forward {
		if forward[i].Equal(ip) {
			return true, nil
		}
	}
	return false, nil
}
// isDNSNotFound reports whether err wraps a *net.DNSError that is a
// definitive not-found answer: IsNotFound set, and neither temporary nor
// a timeout.
func isDNSNotFound(err error) bool {
	var dnsErr *net.DNSError
	if !errors.As(err, &dnsErr) {
		return false
	}
	if dnsErr.IsTemporary || dnsErr.IsTimeout {
		return false
	}
	return dnsErr.IsNotFound
}
// AsyncBotVerifier runs PTR+forward-A verify in a single background
// goroutine, deduplicating in-flight jobs. Result writes through the
// put callback (store.DB.PutBotVerify); reads happen from the scan
// hot path via store.DB.GetBotVerify with no goroutine.
type AsyncBotVerifier struct {
// mu guards inflight.
mu sync.Mutex
// inflight holds one "bot|ip" key per job that is queued or being
// processed; Enqueue skips a job whose key is already present.
inflight map[string]struct{}
// ch is the job queue; NewAsyncBotVerifier gives it a buffer of 256.
ch chan verifyJob
v map[string]*verifier // bot identity -> verifier
// put persists a verification result together with its expiry time.
// When nil, results are discarded.
put func(net.IP, string, bool, time.Time) error
}
// verifyJob is one queued verification request: the source IP and the
// bot identity it claims to be.
type verifyJob struct {
IP net.IP
Bot string
}
// BotDomains maps each claimed-bot identity to the DNS suffix list
// used for PTR + forward-A verification. A PTR name is accepted when it
// equals a suffix or ends in "." + suffix. The keys match the identities
// returned by ClaimedBotFromUA; this includes googlebot, bingbot and
// applebot, which additionally have embedded static IP ranges.
var BotDomains = map[string][]string{
"googlebot": {"googlebot.com", "google.com"},
"bingbot": {"search.msn.com"},
"applebot": {"applebot.apple.com", "apple.com"},
"duckduckbot": {"duckduckgo.com"},
"amazonbot": {"amazonbot.amazon", "amazon.com", "developer.amazon.com"},
"gptbot": {"openai.com"},
"claudebot": {"anthropic.com"},
"perplexitybot": {"perplexity.ai"},
"facebookbot": {"fbsv.net", "tfbnw.net", "facebook.com"},
"bravebot": {"brave.com"},
}
// NewAsyncBotVerifier builds an async verifier that resolves through the
// system resolver (net.DefaultResolver) and holds one verifier per
// BotDomains entry. The job queue is buffered to 256 entries. put is
// store.DB.PutBotVerify in production or a test seam.
func NewAsyncBotVerifier(put func(net.IP, string, bool, time.Time) error) *AsyncBotVerifier {
	verifiers := make(map[string]*verifier)
	for identity, suffixes := range BotDomains {
		verifiers[identity] = newVerifier(net.DefaultResolver, suffixes)
	}
	return &AsyncBotVerifier{
		inflight: make(map[string]struct{}),
		ch:       make(chan verifyJob, 256),
		v:        verifiers,
		put:      put,
	}
}
// Enqueue submits a verification job without ever blocking the caller.
// A job whose bot/IP pair is already queued or in progress is ignored.
// When the queue is full the job is dropped and its in-flight marker is
// released, so a later Enqueue for the same pair can try again.
func (a *AsyncBotVerifier) Enqueue(ip net.IP, bot string) {
	dedupKey := bot + "|" + ip.String()

	a.mu.Lock()
	_, pending := a.inflight[dedupKey]
	if !pending {
		a.inflight[dedupKey] = struct{}{}
	}
	a.mu.Unlock()
	if pending {
		return
	}

	select {
	case a.ch <- verifyJob{IP: ip, Bot: bot}:
		return
	default:
	}

	// Queue full: undo the reservation made above.
	a.mu.Lock()
	delete(a.inflight, dedupKey)
	a.mu.Unlock()
}
// Run processes the queue until stopCh closes. Jobs are handled one at a
// time on the calling goroutine, so DNS calls are serialised; volume is
// bounded by the inflight dedup map so bursts do not launch unbounded
// goroutines. The only extra goroutine is a small helper that turns the
// close of stopCh into a context cancellation.
//
// Closing stopCh cancels the parent context, so any in-flight verify
// returns from its DNS lookup immediately rather than holding the Run
// goroutine for the per-job 5s timeout.
func (a *AsyncBotVerifier) Run(stopCh <-chan struct{}) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// bridge is closed when the helper goroutine exits, so the deferred
// cleanup below can wait for it before Run returns.
bridge := make(chan struct{})
go func() {
defer close(bridge)
select {
case <-stopCh:
// Stop requested: cancel ctx so the main loop and any in-flight
// verify unwind.
cancel()
case <-ctx.Done():
// Run is already shutting down; nothing left to bridge.
}
}()
// Runs before the earlier `defer cancel()` (defers are LIFO): cancel
// first so the helper goroutine is released, then wait for it to exit.
defer func() {
cancel()
<-bridge
}()
for {
select {
case <-ctx.Done():
return
case job := <-a.ch:
a.processWithContext(ctx, job)
}
}
}
// process handles a single job with no parent cancellation: it runs the
// job under a background context, leaving only the per-job timeout
// applied by processWithContext.
func (a *AsyncBotVerifier) process(job verifyJob) {
	root := context.Background()
	a.processWithContext(root, job)
}
// processWithContext verifies one job under a 5-second timeout derived
// from parent and, on a definitive answer, stores it through a.put with
// a 24-hour expiry. Nothing is stored when the bot identity has no
// verifier, when verify returns an error, or when a.put is nil. The
// job's in-flight marker is always released on return. The error from
// a.put is intentionally ignored.
func (a *AsyncBotVerifier) processWithContext(parent context.Context, job verifyJob) {
	defer a.finish(job)

	checker, known := a.v[job.Bot]
	if !known {
		return
	}

	verified, err := func() (bool, error) {
		ctx, cancel := context.WithTimeout(parent, 5*time.Second)
		defer cancel()
		return checker.verify(ctx, job.IP, job.Bot)
	}()
	if err != nil {
		return
	}
	if a.put == nil {
		return
	}

	expiry := time.Now().Add(24 * time.Hour)
	_ = a.put(job.IP, job.Bot, verified, expiry)
}
// finish releases the in-flight marker for job so the same bot/IP pair
// can be enqueued again.
func (a *AsyncBotVerifier) finish(job verifyJob) {
	a.mu.Lock()
	defer a.mu.Unlock()
	dedupKey := job.Bot + "|" + job.IP.String()
	delete(a.inflight, dedupKey)
}
package threatintel
import (
"sync"
"sync/atomic"
"github.com/pidginhost/csm/internal/metrics"
)
// upstreamMetrics is the process-wide state behind the upstream
// threat-intel metrics.
var upstreamMetrics = struct {
// mu guards registered.
mu sync.Mutex
// registered records every registry the metric funcs were bound to,
// so RegisterUpstreamMetrics binds each registry only once.
registered map[*metrics.Registry]struct{}
// active is the source consulted by the breaker gauge; nil when none
// is set.
active atomic.Pointer[UpstreamSource]
cacheHitsTotal atomic.Int64
cacheMissesTotal atomic.Int64
backendFailuresTotal atomic.Int64
}{
registered: make(map[*metrics.Registry]struct{}),
}
// RegisterUpstreamMetrics makes src the source read by the breaker gauge
// and binds the upstream counters and gauge to reg so operators can
// observe cache effectiveness and upstream health. A nil reg or src is a
// no-op.
//
// Production callers pass metrics.Default(); tests pass an isolated
// registry. The metric funcs are bound only on the first call for a
// given registry, because reputation checks rebuild the upstream source
// every cycle; later calls only swap the active source.
func RegisterUpstreamMetrics(reg *metrics.Registry, src *UpstreamSource) {
	if reg == nil || src == nil {
		return
	}
	upstreamMetrics.active.Store(src)

	// Check-and-mark under the lock; the registration itself runs after
	// the lock has been released.
	firstTime := func() bool {
		upstreamMetrics.mu.Lock()
		defer upstreamMetrics.mu.Unlock()
		if _, seen := upstreamMetrics.registered[reg]; seen {
			return false
		}
		upstreamMetrics.registered[reg] = struct{}{}
		return true
	}()
	if firstTime {
		registerUpstreamMetricsLocked(reg)
	}
}
// ClearUpstreamMetricsSource clears the source used by the breaker
// gauge when upstream reputation is disabled by a hot-reloaded config.
func ClearUpstreamMetricsSource() {
upstreamMetrics.active.Store(nil)
}
// registerUpstreamMetricsLocked binds the three process-wide upstream
// counters and the breaker gauge to reg. The gauge reads whichever source
// is active at scrape time and reports 0 when there is none.
func registerUpstreamMetricsLocked(reg *metrics.Registry) {
	counters := []struct {
		name string
		help string
		load func() int64
	}{
		{
			name: "csm_threatintel_cache_hits_total",
			help: "Upstream threat-intel cache hits.",
			load: upstreamMetrics.cacheHitsTotal.Load,
		},
		{
			name: "csm_threatintel_cache_misses_total",
			help: "Upstream threat-intel lookups not served from the local cache.",
			load: upstreamMetrics.cacheMissesTotal.Load,
		},
		{
			name: "csm_threatintel_backend_failures_total",
			help: "Upstream threat-intel backend failures (network, 4xx, 5xx, malformed body).",
			load: upstreamMetrics.backendFailuresTotal.Load,
		},
	}
	for _, c := range counters {
		load := c.load // per-iteration copy captured by the closure
		reg.RegisterCounterFunc(c.name, c.help, func() float64 {
			return float64(load())
		})
	}

	reg.RegisterGaugeFunc(
		"csm_threatintel_breaker_open",
		"Circuit breaker for the upstream source; 1 when open (calls refused), 0 when closed or half-open.",
		func() float64 {
			if src := activeUpstreamMetricsSource(); src != nil && src.BreakerOpen() {
				return 1
			}
			return 0
		},
	)
}
// activeUpstreamMetricsSource returns the source currently used by the
// breaker gauge, or nil when none is set.
func activeUpstreamMetricsSource() *UpstreamSource {
	current := upstreamMetrics.active.Load()
	return current
}
// resetUpstreamMetricsForTest clears the active source and zeroes the
// three process-wide counters. The set of already-registered registries
// is left as is.
func resetUpstreamMetricsForTest() {
	upstreamMetrics.active.Store(nil)
	zeroers := []func(int64){
		upstreamMetrics.cacheHitsTotal.Store,
		upstreamMetrics.cacheMissesTotal.Store,
		upstreamMetrics.backendFailuresTotal.Store,
	}
	for _, zero := range zeroers {
		zero(0)
	}
}
package threatintel
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
// rspamdMaxHistoryBytes caps how much of the rspamd history response is
// read before decoding (2 MiB); decodeRspamdHistory applies it through
// io.LimitReader.
const rspamdMaxHistoryBytes = 2 << 20
// RspamdSource queries rspamd's rolling history and returns a score
// 0..100 derived only from rows matching the requested IP.
//
// The controller password is resolved on every Score call rather than at
// construction, so operators can rotate it through the environment
// without restarting the daemon.
type RspamdSource struct {
	url      string
	token    string // static token from config (may be empty)
	tokenEnv string // env var name to consult at query time
	client   *http.Client
}

// NewRspamdSource returns a source for the rspamd controller at url,
// using an HTTP client with a 5-second timeout. token is the static
// fallback password; tokenEnv names an environment variable that, when
// set and non-empty, overrides it at query time.
func NewRspamdSource(url, token, tokenEnv string) *RspamdSource {
	src := &RspamdSource{client: &http.Client{Timeout: 5 * time.Second}}
	src.url, src.token, src.tokenEnv = url, token, tokenEnv
	return src
}

// Name returns the fixed identifier of this source, "rspamd".
func (s *RspamdSource) Name() string {
	const name = "rspamd"
	return name
}

// resolveToken returns the password to send with the next request: the
// value of the configured environment variable when it is set and
// non-empty, otherwise the static token.
func (s *RspamdSource) resolveToken() string {
	if s.tokenEnv == "" {
		return s.token
	}
	if fromEnv := os.Getenv(s.tokenEnv); fromEnv != "" {
		return fromEnv
	}
	return s.token
}
// rspamdHistoryResp is the object form of an rspamd history response.
// The row list may appear under any of the keys "rows", "history" or
// "data"; all three are decoded.
type rspamdHistoryResp struct {
	Rows    []rspamdHistoryRow `json:"rows"`
	History []rspamdHistoryRow `json:"history"`
	Data    []rspamdHistoryRow `json:"data"`
}

// entries returns all decoded rows as one slice, in the order Rows,
// History, Data.
func (r rspamdHistoryResp) entries() []rspamdHistoryRow {
	merged := make([]rspamdHistoryRow, 0, len(r.Rows)+len(r.History)+len(r.Data))
	for _, group := range [][]rspamdHistoryRow{r.Rows, r.History, r.Data} {
		merged = append(merged, group...)
	}
	return merged
}
// rspamdHistoryRow is the subset of one rspamd history row that scoring
// uses: the client IP, the action taken, and the message score.
type rspamdHistoryRow struct {
	IP     string
	Action string
	Score  float64
}

// UnmarshalJSON decodes a history row leniently. The IP is read from the
// first usable key among "ip", "sender_ip", "client_ip"; the action from
// "action" or "metric_action"; the score from "score". Missing or
// mistyped fields are left at their zero value; only a row that is not a
// JSON object produces an error.
func (r *rspamdHistoryRow) UnmarshalJSON(data []byte) error {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(data, &fields); err != nil {
		return err
	}
	*r = rspamdHistoryRow{
		IP:     firstJSONString(fields, "ip", "sender_ip", "client_ip"),
		Action: firstJSONString(fields, "action", "metric_action"),
		Score:  firstJSONFloat(fields, "score"),
	}
	return nil
}
// firstJSONString walks names in order and returns the value of the
// first key that is present in raw and decodes into a string. Keys that
// are absent, or whose value fails to decode as a string, are skipped.
// It returns "" when no key qualifies.
func firstJSONString(raw map[string]json.RawMessage, names ...string) string {
	for i := range names {
		encoded, present := raw[names[i]]
		if !present {
			continue
		}
		var decoded string
		if json.Unmarshal(encoded, &decoded) != nil {
			continue
		}
		return decoded
	}
	return ""
}
// firstJSONFloat walks names in order and returns the value of the first
// key that is present in raw and decodes into a float64. Keys that are
// absent, or whose value fails to decode as a number, are skipped. It
// returns 0 when no key qualifies.
func firstJSONFloat(raw map[string]json.RawMessage, names ...string) float64 {
	for i := range names {
		encoded, present := raw[names[i]]
		if !present {
			continue
		}
		var decoded float64
		if json.Unmarshal(encoded, &decoded) != nil {
			continue
		}
		return decoded
	}
	return 0
}
// Score fetches <url>/history from the rspamd controller with a GET and
// scores only the history rows that belong to ip. The resolved token, if
// any, is sent in the "Password" header. A transport failure, a non-200
// status, or an undecodable body is returned as an error with score 0.
func (s *RspamdSource) Score(ctx context.Context, ip string) (int, error) {
	target, err := s.historyEndpoint()
	if err != nil {
		return 0, err
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return 0, err
	}
	password := s.resolveToken()
	if password != "" {
		request.Header.Set("Password", password)
	}

	response, err := s.client.Do(request)
	if err != nil {
		return 0, fmt.Errorf("rspamd: %w", err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status != http.StatusOK {
		return 0, fmt.Errorf("rspamd HTTP %d", status)
	}

	history, err := decodeRspamdHistory(response.Body)
	if err != nil {
		return 0, fmt.Errorf("rspamd decode: %w", err)
	}
	return scoreRspamdHistory(history, ip), nil
}
// historyEndpoint derives the history URL from the configured base URL:
// trailing slashes are trimmed from the path, "/history" is appended, and
// any query string is dropped. The base URL must parse and carry both a
// scheme and a host.
func (s *RspamdSource) historyEndpoint() (string, error) {
	base, err := url.Parse(s.url)
	if err != nil {
		return "", err
	}
	switch {
	case base.Scheme == "", base.Host == "":
		return "", fmt.Errorf("rspamd URL must include scheme and host")
	}
	base.RawQuery = ""
	base.Path = strings.TrimRight(base.Path, "/") + "/history"
	return base.String(), nil
}
// decodeRspamdHistory reads at most rspamdMaxHistoryBytes from r and
// decodes the first JSON value as rspamd history. Two shapes are
// accepted: a bare array of rows, or an object that carries the rows
// under "rows", "history" or "data".
func decodeRspamdHistory(r io.Reader) ([]rspamdHistoryRow, error) {
	bounded := io.LimitReader(r, rspamdMaxHistoryBytes)
	var payload json.RawMessage
	if err := json.NewDecoder(bounded).Decode(&payload); err != nil {
		return nil, err
	}

	isArray := strings.HasPrefix(strings.TrimSpace(string(payload)), "[")
	if !isArray {
		var wrapped rspamdHistoryResp
		if err := json.Unmarshal(payload, &wrapped); err != nil {
			return nil, err
		}
		return wrapped.entries(), nil
	}

	var rows []rspamdHistoryRow
	if err := json.Unmarshal(payload, &rows); err != nil {
		return nil, err
	}
	return rows, nil
}
// scoreRspamdHistory folds the history rows that belong to ip into a
// single score in the range 0..100.
//
// Rows are matched on the normalised IP. Each matching row contributes
// its action weight (see rspamdActionScore) plus, for a positive rspamd
// score, ceil(score*4) points. The total saturates at 100.
//
// The per-row score contribution is clamped while still a float: the
// value comes from the rspamd response and is unbounded, and converting
// an out-of-range float64 to int yields an implementation-defined result
// that could turn the total negative.
func scoreRspamdHistory(rows []rspamdHistoryRow, ip string) int {
	want := normalizeIP(ip)
	score := 0
	for _, row := range rows {
		if normalizeIP(row.IP) != want {
			continue
		}
		score += rspamdActionScore(row.Action)
		if row.Score > 0 {
			points := math.Ceil(row.Score * 4)
			if points >= 100 {
				// This row alone saturates the score; also covers +Inf and
				// values too large to convert to int safely.
				return 100
			}
			score += int(points)
		}
		if score >= 100 {
			return 100
		}
	}
	return score
}
// rspamdActionScore returns the weight of an rspamd action, compared
// case-insensitively after trimming surrounding whitespace: 50 for
// "reject", 35 for "soft reject", 20 for "greylist", "add header" or
// "rewrite subject", and 0 for anything else.
func rspamdActionScore(action string) int {
	normalized := strings.TrimSpace(action)
	normalized = strings.ToLower(normalized)

	points := 0
	switch normalized {
	case "reject":
		points = 50
	case "soft reject":
		points = 35
	case "greylist", "add header", "rewrite subject":
		points = 20
	}
	return points
}
// normalizeIP returns the canonical textual form of ip so that different
// spellings of the same address compare equal. Surrounding whitespace is
// trimmed first; input that does not parse as an IP address is returned
// trimmed but otherwise unchanged.
func normalizeIP(ip string) string {
	trimmed := strings.TrimSpace(ip)
	if addr := net.ParseIP(trimmed); addr != nil {
		return addr.String()
	}
	return trimmed
}
// Package threatintel defines a pluggable interface for IP reputation
// providers and an Aggregator that combines their scores. CSM uses this
// to consult AbuseIPDB plus optional rspamd / upstream sources without
// hardcoding multiple lookup paths in the reputation check.
package threatintel
import "context"
// Source is a single scoring provider. Score returns 0..100 (higher is
// worse). A source with no opinion on the IP must return 0, nil; such a
// score is recorded in the breakdown but excluded from the aggregator's
// average. Errors are per-source and non-fatal at the aggregator level:
// the remaining sources still run.
type Source interface {
	Name() string
	Score(ctx context.Context, ip string) (int, error)
}

// Aggregator queries a list of sources and averages their positive scores.
type Aggregator struct {
	sources []Source
}

// NewAggregator returns an Aggregator with no sources registered.
func NewAggregator() *Aggregator {
	return new(Aggregator)
}

// Register appends s to the list of sources. Sources are queried in
// registration order; the aggregated score does not depend on that order.
func (a *Aggregator) Register(s Source) {
	a.sources = append(a.sources, s)
}

// Result holds the aggregated value and per-source breakdown.
type Result struct {
	AggregatedScore int            `json:"aggregated_score"`
	Sources         map[string]int `json:"sources"`
}

// Score queries every registered source for ip. A source that returns an
// error is skipped and does not appear in Result.Sources. The aggregated
// score is the integer mean of the scores greater than zero, or 0 when
// there are none. The returned error is always nil.
func (a *Aggregator) Score(ctx context.Context, ip string) (Result, error) {
	breakdown := make(map[string]int)
	total, contributors := 0, 0

	for i := range a.sources {
		src := a.sources[i]
		value, err := src.Score(ctx, ip)
		if err != nil {
			continue
		}
		breakdown[src.Name()] = value
		if value <= 0 {
			continue
		}
		total += value
		contributors++
	}

	result := Result{Sources: breakdown}
	if contributors != 0 {
		result.AggregatedScore = total / contributors
	}
	return result, nil
}
package threatintel
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
const (
// upstreamMaxResponseBytes caps how much of an upstream response body
// is read before JSON decoding (1 MiB).
upstreamMaxResponseBytes = 1 << 20
// defaultMaxCacheEntries caps the per-IP score cache so a sustained
// flood of unique attacker IPs cannot grow the map without bound.
// Once the cap is reached, expired entries are pruned first; if the
// cap is still exceeded the oldest entry (by expires) is evicted.
defaultMaxCacheEntries = 10000
// defaultBreakerTrip is the number of consecutive upstream failures
// after which the source short-circuits subsequent Score calls.
defaultBreakerTrip = 5
// defaultBreakerCooldown is how long the breaker stays open before
// allowing a probe call through.
defaultBreakerCooldown = 60 * time.Second
)
// UpstreamConfig configures the HTTP threat-intel client. TokenEnv (if
// set) is consulted at every Score call so operators can rotate via env
// without restarting the daemon.
type UpstreamConfig struct {
// URL is the base URL; it must use http or https and include a host.
// "/lookup" is appended to its path for each query.
URL string
// Token is the static bearer token, used when TokenEnv is unset or the
// named variable is empty.
Token string
TokenEnv string
// CacheTTL is the cache lifetime for a score when the response does
// not carry a positive ttl_sec. Zero selects 15 minutes.
CacheTTL time.Duration
// Timeout is the HTTP client timeout. Zero selects 5 seconds.
Timeout time.Duration
}
// UpstreamSource queries a panel-side TI cache. The wire contract is
// documented in docs/upstream-threat-intel-contract.md.
//
// GET <URL>/lookup?ip=<ip>
// Authorization: Bearer <token> (omitted if no token resolved)
//
// 200 OK
// {"ip":"1.2.3.4","score":75,"source":"upstream","ttl_sec":900}
//
// Errors of any flavour (network, 4xx, 5xx, malformed JSON) propagate
// up - the aggregator treats them as "no signal" rather than fatal.
type UpstreamSource struct {
cfg UpstreamConfig
client *http.Client
// mu guards cache.
mu sync.RWMutex
// cache maps the IP string passed to Score to its cached score and
// expiry time.
cache map[string]upstreamEntry
// maxCacheEntries caps the in-memory score cache. Exposed for tests
// that want to validate eviction without staging 10k entries.
maxCacheEntries int
// breakerMu guards the circuit-breaker state below.
breakerMu sync.Mutex
// consecutiveErrors counts failures since the last successful call.
consecutiveErrors int
// breakerOpenedAt is the time the breaker last tripped; the zero time
// means the breaker is closed.
breakerOpenedAt time.Time
// breakerProbe is true while a single post-cooldown probe call is
// outstanding; it is cleared when that call's outcome is observed.
breakerProbe bool
// breakerTrip is the consecutive-failure threshold; values <= 0
// disable tripping.
breakerTrip int
breakerCooldown time.Duration
// Per-source counters back MetricsSnapshot. Package-level metrics use
// process-wide counters because reputation checks rebuild this source.
cacheHitsTotal atomic.Int64
cacheMissesTotal atomic.Int64
backendFailuresTotal atomic.Int64
}
// upstreamEntry is one cached score together with the time after which
// cacheGet stops returning it.
type upstreamEntry struct {
score int
expires time.Time
}
// upstreamResponse mirrors the documented panel response shape.
type upstreamResponse struct {
// IP must match the requested IP (after normalisation) or the response
// is rejected.
IP string `json:"ip"`
// Score must lie in 0..100 or the response is rejected.
Score int `json:"score"`
Source string `json:"source,omitempty"`
// TTLSec, when positive, overrides the configured cache TTL for this
// entry.
TTLSec int `json:"ttl_sec,omitempty"`
}
// NewUpstreamSource builds an upstream threat-intel source from cfg.
// A zero Timeout becomes 5 seconds and a zero CacheTTL becomes 15
// minutes. The cache starts empty with the default size cap, and the
// circuit breaker starts closed with the default trip count and cooldown.
func NewUpstreamSource(cfg UpstreamConfig) *UpstreamSource {
	const (
		fallbackTimeout  = 5 * time.Second
		fallbackCacheTTL = 15 * time.Minute
	)
	if cfg.Timeout == 0 {
		cfg.Timeout = fallbackTimeout
	}
	if cfg.CacheTTL == 0 {
		cfg.CacheTTL = fallbackCacheTTL
	}

	src := &UpstreamSource{
		cfg:             cfg,
		cache:           make(map[string]upstreamEntry),
		maxCacheEntries: defaultMaxCacheEntries,
		breakerTrip:     defaultBreakerTrip,
		breakerCooldown: defaultBreakerCooldown,
	}
	src.client = &http.Client{Timeout: cfg.Timeout}
	return src
}
// Name returns the fixed identifier of this source, "upstream".
func (u *UpstreamSource) Name() string {
	const name = "upstream"
	return name
}
// resolveToken returns the bearer token for the next request: the value
// of the TokenEnv environment variable when it is configured and
// non-empty, otherwise the static Token. Reading the environment on
// every call lets operators rotate the token without a daemon restart.
func (u *UpstreamSource) resolveToken() string {
	envName := u.cfg.TokenEnv
	if envName == "" {
		return u.cfg.Token
	}
	if fromEnv := os.Getenv(envName); fromEnv != "" {
		return fromEnv
	}
	return u.cfg.Token
}
// Score returns the upstream reputation score for ip.
//
// A fresh cache entry is returned directly. On a miss the circuit
// breaker is consulted first; while it is open (or a probe is already
// running) the call fails without contacting the upstream. Otherwise a
// GET <URL>/lookup?ip=<ip> is issued, with a bearer token when one
// resolves. A transport error, non-200 status, undecodable body,
// mismatched IP or out-of-range score is reported to the breaker as a
// failure and returned as an error. A valid response resets the breaker
// and is cached for ttl_sec seconds when that is positive, else for the
// configured CacheTTL.
func (u *UpstreamSource) Score(ctx context.Context, ip string) (int, error) {
	if cached, hit := u.cacheGet(ip); hit {
		u.cacheHitsTotal.Add(1)
		upstreamMetrics.cacheHitsTotal.Add(1)
		return cached, nil
	}
	u.cacheMissesTotal.Add(1)
	upstreamMetrics.cacheMissesTotal.Add(1)

	if refused, reopenAt := u.breakerOpen(); refused {
		// A zero reopen time means the cooldown is over but another call
		// already holds the single probe slot.
		if reopenAt.IsZero() {
			return 0, fmt.Errorf("upstream breaker probe already running")
		}
		return 0, fmt.Errorf("upstream breaker open for %s", time.Until(reopenAt).Round(time.Second))
	}

	target, err := u.lookupEndpoint(ip)
	if err != nil {
		return 0, err
	}
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return 0, err
	}
	if bearer := u.resolveToken(); bearer != "" {
		request.Header.Set("Authorization", "Bearer "+bearer)
	}
	request.Header.Set("Accept", "application/json")

	// failed reports a backend failure to the breaker and passes cause on.
	failed := func(cause error) (int, error) {
		u.breakerObserve(false)
		return 0, cause
	}

	response, err := u.client.Do(request)
	if err != nil {
		return failed(fmt.Errorf("upstream request: %w", err))
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		u.breakerObserve(false)
		fmt.Fprintf(os.Stderr, "upstream threat-intel: HTTP %d for %s\n", response.StatusCode, ip)
		return 0, fmt.Errorf("upstream HTTP %d", response.StatusCode)
	}

	var payload upstreamResponse
	limited := io.LimitReader(response.Body, upstreamMaxResponseBytes)
	if err := json.NewDecoder(limited).Decode(&payload); err != nil {
		return failed(fmt.Errorf("upstream decode: %w", err))
	}
	switch {
	case normalizeIP(payload.IP) != normalizeIP(ip):
		return failed(fmt.Errorf("upstream response ip mismatch: got %q want %q", payload.IP, ip))
	case payload.Score < 0 || payload.Score > 100:
		return failed(fmt.Errorf("upstream score out of range: %d", payload.Score))
	}

	u.breakerObserve(true)

	lifetime := u.cfg.CacheTTL
	if payload.TTLSec > 0 {
		lifetime = time.Duration(payload.TTLSec) * time.Second
	}
	u.cachePut(ip, payload.Score, lifetime)
	return payload.Score, nil
}
// breakerOpen decides whether the next upstream call must be refused.
//
//   - Breaker closed: (false, zero time).
//   - Open and still cooling down: (true, time the cooldown ends).
//   - Cooldown elapsed and a probe is already outstanding: (true, zero time).
//   - Cooldown elapsed and no probe outstanding: the probe flag is set and
//     (false, zero time) is returned, letting exactly one call through
//     (half-open state). The flag is cleared when that call's outcome is
//     observed.
func (u *UpstreamSource) breakerOpen() (bool, time.Time) {
	u.breakerMu.Lock()
	defer u.breakerMu.Unlock()

	var none time.Time
	if u.breakerOpenedAt.IsZero() {
		return false, none
	}

	reopenAt := u.breakerOpenedAt.Add(u.breakerCooldown)
	switch {
	case time.Now().Before(reopenAt):
		return true, reopenAt
	case u.breakerProbe:
		return true, none
	default:
		u.breakerProbe = true
		return false, none
	}
}
// resetBreakerLocked returns the breaker to the closed state: no recorded
// failures, no outstanding probe, no open timestamp. The caller must
// hold breakerMu.
func (u *UpstreamSource) resetBreakerLocked() {
	u.breakerOpenedAt = time.Time{}
	u.breakerProbe = false
	u.consecutiveErrors = 0
}
// recordBreakerFailureLocked counts one more consecutive failure, clears
// any outstanding probe, and (re)opens the breaker at now once the
// failure count reaches breakerTrip. A breakerTrip of zero or less
// disables tripping. The caller must hold breakerMu.
func (u *UpstreamSource) recordBreakerFailureLocked(now time.Time) {
	u.breakerProbe = false
	u.consecutiveErrors++
	if u.breakerTrip <= 0 {
		return
	}
	if u.consecutiveErrors >= u.breakerTrip {
		u.breakerOpenedAt = now
	}
}
// closeExpiredBreakerLocked clears the open timestamp when the breaker is
// open, its cooldown has fully elapsed as of now, and no probe is
// outstanding. The caller must hold breakerMu.
func (u *UpstreamSource) closeExpiredBreakerLocked(now time.Time) {
	if u.breakerOpenedAt.IsZero() || u.breakerProbe {
		return
	}
	if now.Sub(u.breakerOpenedAt) >= u.breakerCooldown {
		u.breakerOpenedAt = time.Time{}
	}
}
// breakerObserve feeds the outcome of one upstream call into the circuit
// breaker. A success resets the breaker completely. A failure bumps the
// per-source and process-wide backend-failure counters, closes a breaker
// whose cooldown has already expired, and then records the failure,
// which may trip the breaker again.
func (u *UpstreamSource) breakerObserve(success bool) {
	u.breakerMu.Lock()
	defer u.breakerMu.Unlock()

	if !success {
		u.backendFailuresTotal.Add(1)
		upstreamMetrics.backendFailuresTotal.Add(1)

		failedAt := time.Now()
		u.closeExpiredBreakerLocked(failedAt)
		u.recordBreakerFailureLocked(failedAt)
		return
	}
	u.resetBreakerLocked()
}
// MetricsSnapshot reports the current values of this source's three
// counters: cache hits, cache misses, and backend failures. Each value
// is loaded independently, so the triple is not a single atomic snapshot.
func (u *UpstreamSource) MetricsSnapshot() (cacheHits, cacheMisses, backendFailures int64) {
	cacheHits = u.cacheHitsTotal.Load()
	cacheMisses = u.cacheMissesTotal.Load()
	backendFailures = u.backendFailuresTotal.Load()
	return cacheHits, cacheMisses, backendFailures
}
// BreakerOpen reports whether the breaker is inside its cooldown window,
// i.e. it has been opened and the cooldown has not yet elapsed. It never
// modifies breaker state, so a half-open breaker reads as not open.
func (u *UpstreamSource) BreakerOpen() bool {
	u.breakerMu.Lock()
	defer u.breakerMu.Unlock()

	openedAt := u.breakerOpenedAt
	return !openedAt.IsZero() && time.Now().Before(openedAt.Add(u.breakerCooldown))
}
// lookupEndpoint builds the request URL for ip from the configured base
// URL: it appends "/lookup" to the path (collapsing trailing slashes),
// drops any fragment, and sets the "ip" query parameter to the
// normalized address while keeping other query parameters. It returns an
// error when the base URL does not parse, is not http/https, or has no host.
func (u *UpstreamSource) lookupEndpoint(ip string) (string, error) {
	target, err := url.Parse(u.cfg.URL)
	if err != nil {
		return "", fmt.Errorf("parsing upstream URL: %w", err)
	}
	switch target.Scheme {
	case "http", "https":
		// supported
	default:
		return "", fmt.Errorf("upstream URL must use http or https")
	}
	if target.Host == "" {
		return "", fmt.Errorf("upstream URL must include host")
	}

	params := target.Query()
	params.Set("ip", normalizeIP(ip))
	target.RawQuery = params.Encode()
	target.Fragment = ""
	target.Path = strings.TrimRight(target.Path, "/") + "/lookup"
	return target.String(), nil
}
// cacheGet returns the cached score for ip and true when an entry exists
// and has not expired; otherwise it returns 0 and false. Expired entries
// are left in place (eviction happens on the write path).
func (u *UpstreamSource) cacheGet(ip string) (int, bool) {
	u.mu.RLock()
	defer u.mu.RUnlock()

	entry, found := u.cache[ip]
	if found && !time.Now().After(entry.expires) {
		return entry.score, true
	}
	return 0, false
}
// cachePut stores score for ip with an expiry of now+ttl, replacing any
// existing entry, and then runs eviction so the cache stays within its
// configured bound.
func (u *UpstreamSource) cachePut(ip string, score int, ttl time.Duration) {
	u.mu.Lock()
	defer u.mu.Unlock()

	expiry := time.Now().Add(ttl)
	u.cache[ip] = upstreamEntry{score: score, expires: expiry}
	u.evictLocked()
}
// evictLocked shrinks the cache back to maxCacheEntries. It is a no-op
// when the limit is non-positive or the cache is already within it.
// Otherwise it first removes every expired entry and then, while still
// over the limit, repeatedly removes the entry with the earliest expiry.
// Caller must hold u.mu for writing.
func (u *UpstreamSource) evictLocked() {
	limit := u.maxCacheEntries
	if limit <= 0 || len(u.cache) <= limit {
		return
	}

	now := time.Now()
	for key, entry := range u.cache {
		if now.After(entry.expires) {
			delete(u.cache, key)
		}
	}

	for len(u.cache) > limit {
		var (
			victim       string
			victimExpiry time.Time
			found        bool
		)
		// Full scan per eviction; normally only one entry is over the limit.
		for key, entry := range u.cache {
			if !found || entry.expires.Before(victimExpiry) {
				victim, victimExpiry, found = key, entry.expires, true
			}
		}
		delete(u.cache, victim)
	}
}
// cacheLen returns the number of entries currently in the cache
// (including any that have expired but not yet been evicted). It exists
// so tests can check cache size without touching the map directly.
func (u *UpstreamSource) cacheLen() int {
	u.mu.RLock()
	n := len(u.cache)
	u.mu.RUnlock()
	return n
}
package updatecheck
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// defaultGitHubReleasesURL is the GitHub "latest release" endpoint for
// the CSM repository. New falls back to it when Options.GitHubAPIURL is
// empty.
const defaultGitHubReleasesURL = "https://api.github.com/repos/pidginhost/csm/releases/latest"
// fetchGitHubLatest GETs url with hc and returns the release's tag_name
// with surrounding whitespace and one leading "v" removed. The caller is
// expected to pass a /releases/latest style endpoint, which already
// excludes pre-releases.
//
// It returns an error when the request cannot be built or sent, the
// status is not 2xx, the body (capped at 1 MiB) cannot be read or is not
// valid JSON, or tag_name is empty.
func fetchGitHubLatest(ctx context.Context, hc *http.Client, url string) (string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}
	for header, value := range map[string]string{
		"Accept":     "application/vnd.github+json",
		"User-Agent": "csm-update-check",
	} {
		req.Header.Set(header, value)
	}

	resp, err := hc.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code < 200 || code >= 300 {
		return "", fmt.Errorf("github releases: status %d", code)
	}

	raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", err
	}
	var release struct {
		TagName string `json:"tag_name"`
	}
	if err := json.Unmarshal(raw, &release); err != nil {
		return "", err
	}

	version := strings.TrimSpace(release.TagName)
	if version == "" {
		return "", fmt.Errorf("github releases: empty tag_name")
	}
	return strings.TrimPrefix(version, "v"), nil
}
package updatecheck
import (
"context"
"errors"
"fmt"
"os/exec"
"reflect"
"runtime"
"sort"
"strings"
"time"
)
// AptProbe returns a PackageProbe that runs `apt-cache policy <pkg>`
// with a 30-second timeout and returns the candidate version parsed by
// parseAptPolicy. The probe fails when the command cannot be run or
// exits non-zero, or when no usable candidate is reported.
//
// The probe must remain an anonymous function declared inside AptProbe:
// pkgSourceLabel identifies it by looking for "AptProbe" in its runtime
// function name.
func AptProbe(packageName string) PackageProbe {
	return func(ctx context.Context) (string, error) {
		probeCtx, stop := context.WithTimeout(ctx, 30*time.Second)
		defer stop()

		raw, err := exec.CommandContext(probeCtx, "apt-cache", "policy", packageName).Output() // #nosec G204 -- packageName is operator-controlled config, not attacker input
		if err != nil {
			return "", fmt.Errorf("apt-cache policy: %w", err)
		}
		return parseAptPolicy(string(raw))
	}
}
// DnfProbe returns a PackageProbe that runs
// `dnf --quiet repoquery --queryformat=%{version}\n <pkg>` with a
// 60-second timeout and returns the highest version picked by
// parseDnfRepoquery. The probe fails when the command cannot be run or
// exits non-zero, or when it prints no versions.
//
// The probe must remain an anonymous function declared inside DnfProbe:
// pkgSourceLabel identifies it by looking for "DnfProbe" in its runtime
// function name.
func DnfProbe(packageName string) PackageProbe {
	return func(ctx context.Context) (string, error) {
		probeCtx, stop := context.WithTimeout(ctx, 60*time.Second)
		defer stop()

		raw, err := exec.CommandContext(probeCtx, "dnf", "--quiet", "repoquery", "--queryformat=%{version}\n", packageName).Output() // #nosec G204 -- packageName is operator-controlled config, not attacker input
		if err != nil {
			return "", fmt.Errorf("dnf repoquery: %w", err)
		}
		return parseDnfRepoquery(string(raw))
	}
}
// parseAptPolicy scans `apt-cache policy` output for the first line
// starting with "Candidate:" (after trimming) and returns its value with
// epoch and revision stripped. It returns an error when that value is
// empty or "(none)", or when no such line exists.
func parseAptPolicy(out string) (string, error) {
	const marker = "Candidate:"
	for _, raw := range strings.Split(out, "\n") {
		trimmed := strings.TrimSpace(raw)
		if !strings.HasPrefix(trimmed, marker) {
			continue
		}
		candidate := strings.TrimSpace(trimmed[len(marker):])
		switch candidate {
		case "", "(none)":
			return "", errors.New("apt-cache policy: candidate (none)")
		}
		return aptStripEpochRevision(candidate), nil
	}
	return "", errors.New("apt-cache policy: no candidate line")
}
// parseDnfRepoquery treats every non-blank line of out as one version
// string, orders them newest-first using isNewer, and returns the first.
// It returns an error when out contains no non-blank lines.
func parseDnfRepoquery(out string) (string, error) {
	candidates := []string{}
	for _, raw := range strings.Split(out, "\n") {
		if version := strings.TrimSpace(raw); version != "" {
			candidates = append(candidates, version)
		}
	}
	if len(candidates) == 0 {
		return "", errors.New("dnf repoquery: no versions")
	}
	newestFirst := func(i, j int) bool {
		return isNewer(candidates[i], candidates[j])
	}
	sort.Slice(candidates, newestFirst)
	return candidates[0], nil
}
// aptStripEpochRevision reduces an apt version string to something that
// can be compared against a plain semver tag: everything up to and
// including the first ':' (the epoch) is removed, and then everything
// from the first '-' onward (the Debian revision) is removed.
func aptStripEpochRevision(v string) string {
	if colon := strings.IndexByte(v, ':'); colon >= 0 {
		v = v[colon+1:]
	}
	if dash := strings.IndexByte(v, '-'); dash >= 0 {
		v = v[:dash]
	}
	return v
}
// pkgSourceLabel makes a best-effort guess at which package manager a
// probe talks to by inspecting the probe's runtime function name: a name
// containing "AptProbe" yields "apt", one containing "DnfProbe" yields
// "dnf". A nil probe or any other name yields "package".
func pkgSourceLabel(p PackageProbe) string {
	const fallback = "package"
	if p == nil {
		return fallback
	}
	fn := runtime.FuncForPC(reflect.ValueOf(p).Pointer()).Name()
	if strings.Contains(fn, "AptProbe") {
		return "apt"
	}
	if strings.Contains(fn, "DnfProbe") {
		return "dnf"
	}
	return fallback
}
package updatecheck
import (
"strconv"
"strings"
)
// isNewer reports whether version a is strictly greater than version b.
// Both are trimmed and lose one leading "v" first.
//
//   - An empty a is never newer.
//   - Any non-empty a is newer than an empty b or "dev".
//   - When b looks like git-describe output (e.g. 3.0.0-12-gabcdef0) it is
//     compared as its base tag, so such a build is not reported as older
//     than its own base release but is older than the next tagged one.
//   - Otherwise the versions are compared component by component on ".":
//     numeric components compare as integers, a numeric component beats a
//     non-numeric (or missing) one, and two non-numeric components compare
//     as strings.
func isNewer(a, b string) bool {
	a = strings.TrimPrefix(strings.TrimSpace(a), "v")
	b = strings.TrimPrefix(strings.TrimSpace(b), "v")
	switch {
	case a == "":
		return false
	case b == "", b == "dev":
		return true
	}
	if base, ok := gitDescribeBase(b); ok {
		return isNewer(a, base)
	}

	left, right := strings.Split(a, "."), strings.Split(b, ".")
	width := len(left)
	if len(right) > width {
		width = len(right)
	}
	for i := 0; i < width; i++ {
		// A missing component is the empty string, which is non-numeric.
		var l, r string
		if i < len(left) {
			l = left[i]
		}
		if i < len(right) {
			r = right[i]
		}
		ln, lErr := strconv.Atoi(l)
		rn, rErr := strconv.Atoi(r)
		lNumeric, rNumeric := lErr == nil, rErr == nil

		if lNumeric && rNumeric {
			if ln != rn {
				return ln > rn
			}
			continue
		}
		if lNumeric {
			return true
		}
		if rNumeric {
			return false
		}
		if l != r {
			return l > r
		}
	}
	return false
}
// gitDescribeBase recognises strings shaped like git-describe output,
// "<base>-<count>-g<hash>...", and returns <base> with true. The string
// qualifies when it has at least three "-" separated fields, the second
// parses as an integer, the third starts with "g", and the first is a
// dotted numeric version. Anything else yields "" and false.
func gitDescribeBase(v string) (string, bool) {
	fields := strings.Split(v, "-")
	if len(fields) < 3 {
		return "", false
	}
	base, count, hash := fields[0], fields[1], fields[2]
	if _, err := strconv.Atoi(count); err != nil {
		return "", false
	}
	if strings.HasPrefix(hash, "g") && isNumericDotted(base) {
		return base, true
	}
	return "", false
}
// isNumericDotted reports whether v consists of at least two "."
// separated components, each of which is non-empty and parses with
// strconv.Atoi.
func isNumericDotted(v string) bool {
	segments := strings.Split(v, ".")
	if len(segments) < 2 {
		return false
	}
	for _, seg := range segments {
		if seg == "" {
			return false
		}
		_, err := strconv.Atoi(seg)
		if err != nil {
			return false
		}
	}
	return true
}
// Package updatecheck polls upstream release channels and tells the
// daemon whether a newer CSM version is available so the Web UI can
// surface a banner. It never fetches binaries or modifies the running
// install. Operators upgrade through their normal channel
// (apt, dnf, install.sh, deploy pipeline).
//
// Two sources, tried in order:
//
// 1. GitHub Releases API ("https://api.github.com/repos/pidginhost/csm/releases/latest").
// 2. apt-cache policy or dnf repoquery against the OS package, used
// when the GitHub call fails (network blocked, rate-limited, etc.).
//
// The package never panics on a transient network or exec failure --
// it records the error in Info.Err and keeps the previous successful
// result in place so the banner does not flicker on a single bad poll.
package updatecheck
import (
"context"
"net/http"
"sync/atomic"
"time"
)
// Info is the cached result surfaced to /api/v1/status. Zero value
// means "no check has completed yet"; CheckedAt.IsZero() is the
// canonical signal.
type Info struct {
// LatestVersion is the version string from the last successful poll.
// After a failed poll CheckOnce carries the previous value forward.
LatestVersion string `json:"latest_version,omitempty"`
// Available is true when LatestVersion is newer than the running
// version according to isNewer.
Available bool `json:"available"`
// Source names the channel that produced LatestVersion.
Source string `json:"source,omitempty"` // "github" | "apt" | "dnf"
// CheckedAt is the time of the most recent attempt, successful or not.
// NOTE(review): encoding/json's omitempty never omits a struct value
// such as time.Time, so checked_at is always present in the output.
CheckedAt time.Time `json:"checked_at,omitempty"`
// Err is the error text of the most recent attempt; empty on success.
Err string `json:"err,omitempty"`
}
// PackageProbe queries the OS package manager for the highest
// available version of the configured package. Implementations
// must respect ctx for cancellation and timeout.
//
// CheckOnce only invokes the probe after the GitHub lookup has failed.
// A nil error means the returned string is used as the latest version.
type PackageProbe func(ctx context.Context) (string, error)
// Options configures a Checker. Zero-valued fields take safe defaults
// (applied by New).
type Options struct {
// CurrentVersion is the running daemon's version string ("dev" is
// treated as "always older than any tagged release").
CurrentVersion string
// Interval is how often the checker polls upstream. Clamped to a
// minimum of 1h to avoid hammering the GitHub API. A zero or negative
// value selects the default of 24h.
Interval time.Duration
// GitHubAPIURL overrides the default GitHub releases URL. Used by
// tests; production should leave this empty.
GitHubAPIURL string
// HTTPClient is the HTTP client used for the GitHub request. nil
// gets a sane default with a 15s timeout.
HTTPClient *http.Client
// PackageProbe is the apt/dnf fallback. nil disables fallback.
PackageProbe PackageProbe
// Now lets tests inject a clock. Defaults to time.Now.
Now func() time.Time
// LogErr receives non-fatal probe errors. Optional; nil silences.
// source is "github" for the GitHub lookup and "package" for the
// package-manager fallback.
LogErr func(source string, err error)
}
// Checker holds polling state. Safe for concurrent reads of Latest()
// while the goroutine started by Run is updating the cache.
type Checker struct {
// opts holds the options after New has filled in defaults.
opts Options
// cache points at the most recent Info; CheckOnce replaces the pointer
// on every poll and New seeds it with a zero Info.
cache atomic.Pointer[Info]
}
// New returns a Checker for opts with all defaults applied up front, so
// the polling loop never has to re-validate:
//
//   - Interval: 24h when zero or negative, otherwise at least 1h.
//   - GitHubAPIURL: the public CSM releases endpoint when empty.
//   - HTTPClient: a client with a 15s timeout when nil.
//   - Now: time.Now when nil.
//
// The cache starts out holding a zero Info.
func New(opts Options) *Checker {
	switch {
	case opts.Interval <= 0:
		opts.Interval = 24 * time.Hour
	case opts.Interval < time.Hour:
		opts.Interval = time.Hour
	}
	if opts.GitHubAPIURL == "" {
		opts.GitHubAPIURL = defaultGitHubReleasesURL
	}
	if opts.HTTPClient == nil {
		opts.HTTPClient = &http.Client{Timeout: 15 * time.Second}
	}
	if opts.Now == nil {
		opts.Now = time.Now
	}

	checker := &Checker{opts: opts}
	checker.cache.Store(&Info{})
	return checker
}
// Latest returns a copy of the cached Info: the version data from the
// most recent successful poll together with the error text, if any, of
// the most recent attempt. Because it is a copy, callers may modify it
// freely. A Checker with no cached value yields a zero Info.
func (c *Checker) Latest() Info {
	cached := c.cache.Load()
	if cached == nil {
		return Info{}
	}
	return *cached
}
// CheckOnce performs one synchronous poll, stores the result in the
// cache, and returns it. Run calls it on every tick; tests call it
// directly to avoid the ticker.
//
// GitHub is asked first. If that fails and a PackageProbe is configured,
// the probe is tried; its success replaces the GitHub error, its failure
// becomes the reported error. On overall failure the previous
// LatestVersion, Available, and Source are carried forward and only
// CheckedAt and Err change.
func (c *Checker) CheckOnce(ctx context.Context) Info {
	checkedAt := c.opts.Now()
	report := func(source string, err error) {
		if c.opts.LogErr != nil {
			c.opts.LogErr(source, err)
		}
	}

	source := "github"
	latest, err := fetchGitHubLatest(ctx, c.opts.HTTPClient, c.opts.GitHubAPIURL)
	if err != nil {
		report("github", err)
		if probe := c.opts.PackageProbe; probe != nil {
			version, probeErr := probe(ctx)
			if probeErr != nil {
				report("package", probeErr)
			} else {
				latest, source = version, pkgSourceLabel(probe)
			}
			// nil on probe success, the probe's error otherwise.
			err = probeErr
		}
	}

	result := Info{CheckedAt: checkedAt}
	if err == nil {
		result.LatestVersion = latest
		result.Source = source
		result.Available = isNewer(latest, c.opts.CurrentVersion)
	} else {
		// Keep the last good answer so the banner does not flicker on
		// a single bad poll.
		previous := c.Latest()
		result.LatestVersion = previous.LatestVersion
		result.Available = previous.Available
		result.Source = previous.Source
		result.Err = err.Error()
	}
	c.cache.Store(&result)
	return result
}
// Run blocks until ctx is cancelled. It waits five minutes (so daemon
// startup is not held up by outbound HTTP), performs a first check, and
// then checks again every Interval.
func (c *Checker) Run(ctx context.Context) {
	warmup := time.NewTimer(5 * time.Minute)
	select {
	case <-warmup.C:
	case <-ctx.Done():
		warmup.Stop()
		return
	}

	c.CheckOnce(ctx)

	ticker := time.NewTicker(c.opts.Interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			c.CheckOnce(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// Package verdict implements an HMAC-signed HTTP client for the
// auto_response.verdict_callback hook. CSM POSTs a Request to the panel
// URL before each automatic block; the panel's Response is advisory
// (block / allow / "" -> block default; tenant_id is logged). Errors are
// fail-open: the caller (firewall.Engine.BlockIP) proceeds with the
// default block on any callback failure.
package verdict
import (
"bytes"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
)
const (
// verdictMaxResponseBytes caps the size of a panel reply; Ask rejects
// any response body longer than this.
verdictMaxResponseBytes = 64 << 10 // 64 KB - verdict response is small JSON
// verdictMaxResponseSkew bounds how far the panel's reply timestamp
// may drift from CSM's clock. Anything older or further into the
// future is treated as a replayed or forged response.
verdictMaxResponseSkew = 5 * time.Minute
)
// Config configures the verdict callback client. HMACSecretEnv (if set)
// is consulted at every Ask call so operators can rotate via env without
// restarting the daemon.
//
// RequireResponseSignature controls whether the panel must sign its
// response body with the same HMAC scheme used on the request
// (X-CSM-Signature header) and echo the request nonce + timestamp.
// Default is true: when a secret is configured, CSM rejects unsigned
// or forged responses to prevent an on-path attacker from silently
// downgrading a block to "allow". Set the pointer to a false value
// only during phpanel-side rollouts that have not yet implemented
// response signing. Even on the opt-out path, replay protection is
// best-effort enforced: if the panel does echo nonce or timestamp,
// they must match; a panel that echoes neither still works (legacy
// shape). When no HMAC secret is configured at all, signature and
// replay checks are skipped because there is no key to verify against.
//
// AllowUnsigned is the runtime form of
// auto_response.verdict_callback.allow_unsigned. By default, an unsigned
// "allow" response is rejected because it would disable the block. Set true
// only for configs that explicitly opted in to an unsigned rollout.
type Config struct {
// URL is the panel endpoint Ask POSTs to; it must be http or https and
// include a host.
URL string
// HMACSecret is the static signing key, used when HMACSecretEnv is
// unset or names an empty environment variable.
HMACSecret string
// HMACSecretEnv names an environment variable that, when non-empty,
// takes precedence over HMACSecret.
HMACSecretEnv string
// RequireResponseSignature is tri-state: nil means "use the default"
// (true); see the type comment for the semantics.
RequireResponseSignature *bool
// AllowUnsigned permits an "allow" verdict when no secret is configured.
AllowUnsigned bool
// Timeout is the HTTP client timeout for one callback.
Timeout time.Duration
}
// requireResponseSig resolves the tri-state RequireResponseSignature
// setting: an explicit value wins, and an unset (nil) pointer means true
// so the client is secure by default.
func (c Config) requireResponseSig() bool {
	if explicit := c.RequireResponseSignature; explicit != nil {
		return *explicit
	}
	return true
}
// Request is what CSM asks the panel about. Ask sets Nonce and Timestamp
// for every exchange. They bind each request to its own reply so an
// attacker cannot replay an old "allow" verdict.
type Request struct {
// IP is the address CSM is about to block.
IP string `json:"ip"`
// Reason is the human-readable justification for the block.
Reason string `json:"reason"`
Severity string `json:"severity,omitempty"`
Source string `json:"source,omitempty"` // "auto_response" | "manual"
// Nonce is overwritten by Ask with a fresh random hex string; any value
// supplied by the caller is discarded.
Nonce string `json:"nonce,omitempty"`
// Timestamp is overwritten by Ask with the current unix time in seconds.
Timestamp int64 `json:"timestamp,omitempty"`
}
// Response is what the panel may answer.
//
// verdict = "block" - CSM proceeds with its default action.
// verdict = "allow" - CSM logs the verdict but does NOT block.
// verdict = "" or missing - equivalent to "block" (default).
// tenant_id = optional attribution string CSM logs alongside the decision.
// nonce = MUST equal the Request.Nonce when present. Required when
// response signing is in effect; defeats replay of captured replies.
// timestamp = unix seconds the panel produced the reply. MUST be within
// verdictMaxResponseSkew of CSM's clock when present.
//
// Any other verdict string is rejected by Ask with an error.
type Response struct {
Verdict string `json:"verdict,omitempty"`
TenantID string `json:"tenant_id,omitempty"`
// Note is free-form text from the panel; Ask does not interpret it.
Note string `json:"note,omitempty"`
Nonce string `json:"nonce,omitempty"`
Timestamp int64 `json:"timestamp,omitempty"`
}
// Client posts each block decision to the configured URL and reads the
// (advisory) response. Timeouts and 5xx are returned as errors - the
// caller decides whether to fail open (allow CSM to proceed) or closed.
type Client struct {
// cfg is the configuration captured by New (with the default timeout
// already applied).
cfg Config
// client is the HTTP client used for every callback; its Timeout is
// cfg.Timeout.
client *http.Client
}
// New builds a verdict callback Client from cfg.
//
// A zero or negative cfg.Timeout is replaced by the 2-second default.
// Negative values must not reach http.Client: it only applies a deadline
// when Timeout is positive, so a negative value would silently disable
// the timeout and let a hung panel stall the callback indefinitely.
func New(cfg Config) *Client {
	if cfg.Timeout <= 0 {
		cfg.Timeout = 2 * time.Second
	}
	return &Client{cfg: cfg, client: &http.Client{Timeout: cfg.Timeout}}
}
// resolveSecret returns the HMAC secret to use right now. When
// HMACSecretEnv is configured and that environment variable is
// non-empty, its value wins; otherwise the static HMACSecret is used.
// The environment is read on every call so a rotated secret is picked
// up without restarting the daemon.
func (c *Client) resolveSecret() string {
	envName := c.cfg.HMACSecretEnv
	if envName == "" {
		return c.cfg.HMACSecret
	}
	if fromEnv := os.Getenv(envName); fromEnv != "" {
		return fromEnv
	}
	return c.cfg.HMACSecret
}
// verifyResponseSignature validates an X-CSM-Signature header value
// against body. The header must have the form "sha256=<hex>", where
// <hex> is the HMAC-SHA256 of body keyed with secret. It returns nil on
// a match and a descriptive error when the header is empty, uses a
// different prefix, is not valid hex, or does not match. The comparison
// is constant-time (hmac.Equal).
func verifyResponseSignature(secret string, body []byte, header string) error {
	const scheme = "sha256="
	switch {
	case header == "":
		return fmt.Errorf("verdict callback response missing signature header (X-CSM-Signature)")
	case !strings.HasPrefix(header, scheme):
		return fmt.Errorf("verdict callback response signature has unsupported algorithm")
	}

	claimed, err := hex.DecodeString(header[len(scheme):])
	if err != nil {
		return fmt.Errorf("verdict callback response signature is not valid hex")
	}

	h := hmac.New(sha256.New, []byte(secret))
	h.Write(body)
	if expected := h.Sum(nil); !hmac.Equal(claimed, expected) {
		return fmt.Errorf("verdict callback response signature mismatch")
	}
	return nil
}
// Ask POSTs req (as JSON) to the configured panel URL and returns the
// panel's response, or an error.
//
// Before sending, Ask overwrites req.Nonce with a fresh random nonce and
// req.Timestamp with the current unix time. When a secret is available
// the request body is signed in the X-CSM-Signature header.
//
// An error is returned when the URL is missing or invalid, the request
// fails, the status is not 200, the body exceeds verdictMaxResponseBytes,
// the body is not a single JSON object with a known verdict, or any of
// the signature / replay / unsigned-allow checks below fail. A blank
// body yields a zero Response and no error.
func (c *Client) Ask(ctx context.Context, req Request) (Response, error) {
	// Defense-in-depth URL check (config validation already ran at load,
	// this re-check defends against misconfiguration via cfg corruption).
	rawURL := strings.TrimSpace(c.cfg.URL)
	if rawURL == "" {
		return Response{}, fmt.Errorf("verdict callback URL not configured")
	}
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return Response{}, fmt.Errorf("verdict callback URL parse: %w", err)
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return Response{}, fmt.Errorf("verdict callback URL must use http or https")
	}
	if parsed.Host == "" {
		return Response{}, fmt.Errorf("verdict callback URL must include host")
	}

	nonce, err := newNonce()
	if err != nil {
		return Response{}, fmt.Errorf("verdict callback nonce generation: %w", err)
	}
	req.Nonce = nonce
	req.Timestamp = time.Now().Unix()

	body, err := json.Marshal(req)
	if err != nil {
		return Response{}, err
	}

	secret := c.resolveSecret()

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, rawURL, bytes.NewReader(body))
	if err != nil {
		return Response{}, err
	}
	r.Header.Set("Content-Type", "application/json")
	r.Header.Set("Accept", "application/json")
	if secret != "" {
		mac := hmac.New(sha256.New, []byte(secret))
		mac.Write(body)
		r.Header.Set("X-CSM-Signature", "sha256="+hex.EncodeToString(mac.Sum(nil)))
	}

	resp, err := c.client.Do(r)
	if err != nil {
		return Response{}, fmt.Errorf("verdict callback: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Return the IP in the error so the caller logs it through the
		// structured audit path instead of leaking it to raw stderr.
		return Response{}, fmt.Errorf("verdict callback HTTP %d for %s", resp.StatusCode, req.IP)
	}

	// Read one byte past the cap so an oversized body is detectable.
	data, err := io.ReadAll(io.LimitReader(resp.Body, verdictMaxResponseBytes+1))
	if err != nil {
		return Response{}, fmt.Errorf("verdict callback read: %w", err)
	}
	if int64(len(data)) > verdictMaxResponseBytes {
		return Response{}, fmt.Errorf("verdict callback response exceeds %d bytes", verdictMaxResponseBytes)
	}

	if secret != "" && c.cfg.requireResponseSig() {
		// Verify the panel signed its response with the same secret used
		// on the request. Without this check a network attacker could
		// downgrade block to allow on every call. Rejecting the reply keeps
		// the engine on its default block path.
		if err := verifyResponseSignature(secret, data, resp.Header.Get("X-CSM-Signature")); err != nil {
			return Response{}, err
		}
	}
	strictReplay := secret != "" && c.cfg.requireResponseSig()

	if strings.TrimSpace(string(data)) == "" {
		return Response{}, nil
	}

	var out Response
	dec := json.NewDecoder(bytes.NewReader(data))
	if err := dec.Decode(&out); err != nil {
		return Response{}, fmt.Errorf("verdict callback decode: %w", err)
	}
	// Exactly one JSON value is allowed; anything after it is rejected.
	var trailing struct{}
	if err := dec.Decode(&trailing); err != io.EOF {
		return Response{}, fmt.Errorf("verdict callback decode: trailing JSON")
	}

	// Validate response shape. Unknown verdict strings are rejected
	// defensively rather than silently treated as "block".
	if out.Verdict != "" && out.Verdict != "block" && out.Verdict != "allow" {
		return Response{}, fmt.Errorf("verdict callback returned unknown verdict %q", out.Verdict)
	}

	// Fail closed on an unsigned allow. Without a secret there is no replay or
	// signature protection (the checks below are gated on secret != ""), so an
	// on-path attacker could return "allow" on every call and silently disable
	// auto-blocking. Returning an error keeps the engine on its default block
	// path unless the config explicitly opted in to unsigned callback verdicts.
	if secret == "" && out.Verdict == "allow" && !c.cfg.AllowUnsigned {
		return Response{}, fmt.Errorf("verdict callback: refusing unsigned allow (no HMAC secret configured)")
	}

	// Replay protection runs whenever a secret is configured. Strict
	// mode (response signing required) demands the panel echo nonce and
	// timestamp. Best-effort mode (signing opt-out) still enforces what
	// the panel did echo, so a captured stale reply with a wrong nonce
	// or a long-expired timestamp is rejected even when the operator
	// has not yet enabled response signing on the panel side.
	if secret != "" {
		if out.Nonce != "" {
			if subtle.ConstantTimeCompare([]byte(out.Nonce), []byte(req.Nonce)) != 1 {
				return Response{}, fmt.Errorf("verdict callback response nonce mismatch")
			}
		} else if strictReplay {
			return Response{}, fmt.Errorf("verdict callback response missing nonce")
		}

		if out.Timestamp != 0 {
			drift := time.Since(time.Unix(out.Timestamp, 0))
			if drift < 0 {
				drift = -drift
				if drift < 0 {
					// time.Since saturates at the minimum Duration for a
					// timestamp far in the future, and negating that value
					// overflows back to itself. Clamp to the maximum so the
					// skew check below rejects it instead of letting it pass.
					drift = time.Duration(1<<63 - 1)
				}
			}
			if drift > verdictMaxResponseSkew {
				return Response{}, fmt.Errorf("verdict callback response timestamp drift %s exceeds %s", drift, verdictMaxResponseSkew)
			}
		} else if strictReplay {
			return Response{}, fmt.Errorf("verdict callback response missing timestamp")
		}

		// An "allow" is the only verdict an on-path attacker gains from
		// forging. In best-effort mode (signing not required) the replay
		// checks above only fire when the panel echoed a nonce or timestamp,
		// so an attacker can strip both to slip an unbound allow through.
		// Require an allow to carry at least one replay binding; otherwise
		// treat it like the no-secret case and refuse, keeping the engine on
		// its default block path.
		if out.Verdict == "allow" && out.Nonce == "" && out.Timestamp == 0 {
			return Response{}, fmt.Errorf("verdict callback: refusing allow with no replay binding (nonce/timestamp)")
		}
	}

	return out, nil
}
// newNonce generates 16 random bytes from crypto/rand and returns them
// as a 32-character lowercase hex string. crypto/rand is required here:
// with a predictable generator, an attacker who sees one nonce could
// guess the next and prepare a replayed reply in advance.
func newNonce() (string, error) {
	raw := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
package webui
import (
"net/http"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/checks"
)
// handleAccount renders the account page for the account named in the
// "name" query parameter. A name that fails validateAccountName sends
// the browser to /findings with a 302 instead.
func (s *Server) handleAccount(w http.ResponseWriter, r *http.Request) {
	account := r.URL.Query().Get("name")
	if validateAccountName(account) != nil {
		http.Redirect(w, r, "/findings", http.StatusFound)
		return
	}

	data := map[string]string{
		"Hostname":    s.cfg.Hostname,
		"AccountName": account,
		"HomeDir":     filepath.Join("/home", account),
	}
	s.renderTemplate(w, "account.html", data)
}
// apiAccountDetail writes a JSON summary for the account named in the
// "name" query parameter (400 when the name is invalid). Everything is
// matched against the account's home prefix "/home/<name>/":
//
//   - findings: current findings whose message, details, or file path
//     mention the prefix (internal bookkeeping checks are skipped);
//   - quarantined: quarantine entries, including pre_clean, whose
//     original path starts with the prefix;
//   - history: up to 100 of the 2000 entries read from the history
//     store whose message or details mention the prefix;
//   - whm_url: a link to the account in WHM on this host.
func (s *Server) apiAccountDetail(w http.ResponseWriter, r *http.Request) {
	name := r.URL.Query().Get("name")
	if err := validateAccountName(name); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	homePrefix := "/home/" + name + "/"
	mentionsHome := func(fields ...string) bool {
		for _, field := range fields {
			if strings.Contains(field, homePrefix) {
				return true
			}
		}
		return false
	}

	// Current findings for this account.
	type findingView struct {
		Severity int    `json:"severity"`
		Check    string `json:"check"`
		Message  string `json:"message"`
		HasFix   bool   `json:"has_fix"`
	}
	var accountFindings []findingView
	for _, f := range s.store.LatestFindings() {
		switch f.Check {
		case "auto_response", "auto_block", "check_timeout", "health":
			continue
		}
		if !mentionsHome(f.Message, f.Details, f.FilePath) {
			continue
		}
		accountFindings = append(accountFindings, findingView{
			Severity: int(f.Severity),
			Check:    f.Check,
			Message:  f.Message,
			HasFix:   checks.HasFix(f.Check),
		})
	}

	// Quarantined files for this account.
	type qEntry struct {
		ID           string `json:"id"`
		OriginalPath string `json:"original_path"`
		Size         int64  `json:"size"`
		Reason       string `json:"reason"`
	}
	var quarantined []qEntry
	metaPaths := listMetaFiles("/opt/csm/quarantine")
	metaPaths = append(metaPaths, listMetaFiles(filepath.Join("/opt/csm/quarantine", "pre_clean"))...)
	for _, metaPath := range metaPaths {
		meta, err := readQuarantineMeta(metaPath)
		if err != nil {
			// Unreadable metadata is skipped rather than failing the request.
			continue
		}
		if !strings.HasPrefix(meta.OriginalPath, homePrefix) {
			continue
		}
		quarantined = append(quarantined, qEntry{
			ID:           strings.TrimSuffix(filepath.Base(metaPath), ".meta"),
			OriginalPath: meta.OriginalPath,
			Size:         meta.Size,
			Reason:       meta.Reason,
		})
	}

	// Recent history for this account (last 100 matching entries). A
	// read error is ignored and simply yields fewer or no entries.
	type histEntry struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Timestamp string `json:"timestamp"`
	}
	var history []histEntry
	allHistory, _ := s.store.ReadHistory(2000, 0)
	for _, f := range allHistory {
		if len(history) >= 100 {
			break
		}
		if !mentionsHome(f.Message, f.Details) {
			continue
		}
		history = append(history, histEntry{
			Severity:  int(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Timestamp: f.Timestamp.Format(time.RFC3339),
		})
	}

	writeJSON(w, map[string]interface{}{
		"account":     name,
		"findings":    accountFindings,
		"quarantined": quarantined,
		"history":     history,
		"whm_url":     "https://" + s.cfg.Hostname + ":2087/scripts/domainsdata?user=" + name,
	})
}
package webui
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/state"
)
// reIPReputation matches ip_reputation finding messages and captures the
// IP address (group 1) and the parenthesised source text (group 2).
// dedupIPReputation uses it to group findings by IP.
var reIPReputation = regexp.MustCompile(`Known malicious IP accessing server: (\S+) \((.+)\)`)
// apiStatus returns the daemon's full health snapshot as JSON. Backward
// compatible with prior callers: every field they consumed (hostname,
// uptime, started_at, rules_loaded, scan_running, last_scan_time) is
// still present, with new fields added alongside.
// apiStatus writes the daemon status as JSON.
//
// With a health provider installed it emits the full health snapshot.
// The keys older clients rely on (hostname, uptime, started_at,
// rules_loaded, scan_running, last_scan_time) are kept alongside the
// newer ones, and "update" is included only once an update check has
// completed. Without a provider (test harness) it emits the legacy
// minimal payload with status "down".
func (s *Server) apiStatus(w http.ResponseWriter, _ *http.Request) {
	provider := s.provider

	s.scanMu.Lock()
	scanning := s.scanRunning
	s.scanMu.Unlock()

	if provider != nil {
		snap := health.Build(provider, s.version, health.Capabilities())
		payload := map[string]interface{}{
			"hostname":     snap.Hostname,
			"version":      snap.Version,
			"uptime":       time.Duration(snap.UptimeSec * int64(time.Second)).String(),
			"uptime_sec":   snap.UptimeSec,
			"started_at":   snap.StartedAt.Format(time.RFC3339),
			"rules_loaded": s.sigCount,
			"scan_running": scanning,
			// last_scan_time is the legacy key kept for older clients
			// (cphulk dashboard, status_check.go). latest_scan mirrors the
			// health.Snapshot JSON tag and is the canonical name for new
			// clients. Drop last_scan_time once the legacy consumers move.
			"last_scan_time":         snap.LatestScan.Format(time.RFC3339),
			"latest_scan":            formatRFC3339OrEmpty(snap.LatestScan),
			"baseline_at":            formatRFC3339OrEmpty(snap.BaselineAt),
			"blocklist_size":         snap.BlocklistSize,
			"incidents_open":         snap.IncidentsOpen,
			"bpf_enforcement_active": snap.BPFEnforcementActive,
			"history_count":          snap.HistoryCount,
			"severities":             snap.Severities,
			"watchers":               snap.Watchers,
			"store_healthy":          snap.StoreHealthy,
			"store_size_mb":          snap.StoreSizeMB,
			"config_hash":            snap.ConfigHash,
			"binary_hash":            snap.BinaryHash,
			"capabilities":           snap.Capabilities,
			"dry_run_blocks":         snap.DryRunBlocks,
			"automation":             snap.Automation,
			"status":                 snap.OverallStatus(),
		}
		if !snap.Update.CheckedAt.IsZero() {
			payload["update"] = snap.Update
		}
		writeJSON(w, payload)
		return
	}

	// No daemon-side provider installed (test harness). Fall back to
	// the legacy minimal payload so existing UI code keeps working.
	var lastScan string
	if s.store != nil {
		lastScan = s.store.LatestScanTime().Format(time.RFC3339)
	}
	writeJSON(w, map[string]interface{}{
		"hostname":       s.cfg.Hostname,
		"uptime":         time.Since(s.startTime).String(),
		"started_at":     s.startTime.Format(time.RFC3339),
		"rules_loaded":   s.sigCount,
		"scan_running":   scanning,
		"last_scan_time": lastScan,
		"status":         "down",
	})
}
// formatRFC3339OrEmpty formats t as RFC 3339, except that the zero time
// becomes the empty string instead of "0001-01-01T00:00:00Z".
func formatRFC3339OrEmpty(t time.Time) string {
	if !t.IsZero() {
		return t.Format(time.RFC3339)
	}
	return ""
}
// apiCapabilities writes the capability list reported by
// health.Capabilities together with the server's version string.
func (s *Server) apiCapabilities(w http.ResponseWriter, _ *http.Request) {
	payload := map[string]interface{}{
		"capabilities": health.Capabilities(),
		"version":      s.version,
	}
	writeJSON(w, payload)
}
// apiFindings writes the current findings ("what's wrong right now") as
// a JSON array. Internal bookkeeping checks and suppressed findings are
// left out. first_seen and last_seen come from the store's entry for the
// finding's key when one exists, and otherwise fall back to the
// finding's own timestamp.
func (s *Server) apiFindings(w http.ResponseWriter, _ *http.Request) {
	type entryView struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Details   string `json:"details,omitempty"`
		Time      string `json:"time"`
		FirstSeen string `json:"first_seen"`
		LastSeen  string `json:"last_seen"`
		HasFix    bool   `json:"has_fix"`
	}

	findings := s.store.LatestFindings()
	suppressions := s.store.LoadSuppressions()

	var views []entryView
	for _, f := range findings {
		switch f.Check {
		case "auto_response", "auto_block", "check_timeout", "health":
			continue
		}
		// Skip suppressed findings.
		if s.store.IsSuppressed(f, suppressions) {
			continue
		}

		first, last := f.Timestamp, f.Timestamp
		if tracked, ok := s.store.EntryForKey(f.Key()); ok {
			first, last = tracked.FirstSeen, tracked.LastSeen
		}

		views = append(views, entryView{
			Severity:  int(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Details:   f.Details,
			Time:      f.Timestamp.Format(time.RFC3339),
			FirstSeen: first.Format(time.RFC3339),
			LastSeen:  last.Format(time.RFC3339),
			HasFix:    checks.HasFix(f.Check),
		})
	}
	writeJSON(w, views)
}
// enrichedFinding is the JSON response type for the enriched findings endpoint.
// All values are pre-rendered strings so the client does not have to map
// severities or parse times itself.
type enrichedFinding struct {
// Key is the finding's canonical key (Finding.Key()).
Key string `json:"key"`
// Severity is the label produced by severityLabel; the enriched endpoint
// counts "CRITICAL" and "HIGH" explicitly and treats anything else as a warning.
Severity string `json:"severity"`
// SevClass is the value produced by severityClass for the same severity.
SevClass string `json:"sev_class"`
// Check is the check identifier, e.g. "ip_reputation".
Check string `json:"check"`
// Message is the finding message; dedupIPReputation rewrites it for merged
// ip_reputation entries.
Message string `json:"message"`
FilePath string `json:"file_path,omitempty"`
// Account is the result of extractAccountFromFinding; empty when none was found.
Account string `json:"account,omitempty"`
// FirstSeen and LastSeen are RFC 3339 timestamps.
FirstSeen string `json:"first_seen"`
LastSeen string `json:"last_seen"`
// HasFix reports checks.HasFix for the check; FixDesc is checks.FixDescription.
HasFix bool `json:"has_fix"`
FixDesc string `json:"fix_desc,omitempty"`
}
// dedupIPReputation collapses ip_reputation findings that refer to the same
// IP into a single entry. The merged entry keeps the fields of the first
// finding seen for that IP, widens FirstSeen/LastSeen to cover every merged
// finding, takes the highest-ranked severity, and lists all sources in its
// message. Findings of any other check, and ip_reputation findings whose
// message does not match reIPReputation, are passed through untouched.
//
// Pass-through findings come first in the result (in input order), followed
// by one merged entry per IP in order of first appearance.
func dedupIPReputation(items []enrichedFinding) []enrichedFinding {
	type merged struct {
		finding enrichedFinding
		feeds   []string
	}

	var (
		out       []enrichedFinding
		seenOrder []string
		byIP      = make(map[string]*merged)
	)

	for _, it := range items {
		var match []string
		if it.Check == "ip_reputation" {
			match = reIPReputation.FindStringSubmatch(it.Message)
		}
		if match == nil {
			out = append(out, it)
			continue
		}

		addr, feed := match[1], match[2]
		grp, known := byIP[addr]
		if !known {
			byIP[addr] = &merged{finding: it, feeds: []string{feed}}
			seenOrder = append(seenOrder, addr)
			continue
		}

		grp.feeds = append(grp.feeds, feed)
		// Timestamps are compared as strings, exactly as they were rendered.
		if it.FirstSeen < grp.finding.FirstSeen {
			grp.finding.FirstSeen = it.FirstSeen
		}
		if it.LastSeen > grp.finding.LastSeen {
			grp.finding.LastSeen = it.LastSeen
		}
		if severityRank(it.Severity) > severityRank(grp.finding.Severity) {
			grp.finding.Severity = it.Severity
			grp.finding.SevClass = it.SevClass
		}
	}

	for _, addr := range seenOrder {
		grp := byIP[addr]
		grp.finding.Message = fmt.Sprintf("Known malicious IP accessing server: %s (%s)", addr, strings.Join(grp.feeds, ", "))
		out = append(out, grp.finding)
	}
	return out
}
// apiFindingsEnriched returns findings with IP dedup, account extraction, and severity counts.
//
// The response object carries the enriched findings, the sorted sets of check
// types and accounts present in them, per-severity counts and the total.
// Internal bookkeeping checks and suppressed findings are excluded.
func (s *Server) apiFindingsEnriched(w http.ResponseWriter, _ *http.Request) {
	latest := s.store.LatestFindings()
	suppressions := s.store.LoadSuppressions()

	items := make([]enrichedFinding, 0)
	for _, f := range latest {
		switch f.Check {
		case "auto_response", "auto_block", "check_timeout", "health":
			continue
		}
		if s.store.IsSuppressed(f, suppressions) {
			continue
		}

		firstSeen, lastSeen := f.Timestamp, f.Timestamp
		if entry, ok := s.store.EntryForKey(f.Key()); ok {
			firstSeen, lastSeen = entry.FirstSeen, entry.LastSeen
		}

		items = append(items, enrichedFinding{
			Key:       f.Key(),
			Severity:  severityLabel(f.Severity),
			SevClass:  severityClass(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			FilePath:  f.FilePath,
			Account:   extractAccountFromFinding(f),
			FirstSeen: firstSeen.Format(time.RFC3339),
			LastSeen:  lastSeen.Format(time.RFC3339),
			HasFix:    checks.HasFix(f.Check),
			FixDesc:   checks.FixDescription(f.Check, f.Message, f.FilePath),
		})
	}

	items = dedupIPReputation(items)
	// dedupIPReputation returns a nil slice when nothing is left, which would
	// encode "findings" as null. Keep the empty-array shape that the
	// make(..., 0) above was set up to produce.
	if items == nil {
		items = []enrichedFinding{}
	}

	var critCount, highCount, warnCount int
	checkTypeSet := make(map[string]bool)
	accountSet := make(map[string]bool)
	for _, item := range items {
		switch item.Severity {
		case "CRITICAL":
			critCount++
		case "HIGH":
			highCount++
		default:
			warnCount++
		}
		checkTypeSet[item.Check] = true
		if item.Account != "" {
			accountSet[item.Account] = true
		}
	}

	// Map iteration order is random; sort so the filter lists are stable.
	checkTypes := make([]string, 0, len(checkTypeSet))
	for ct := range checkTypeSet {
		checkTypes = append(checkTypes, ct)
	}
	sort.Strings(checkTypes)

	accounts := make([]string, 0, len(accountSet))
	for a := range accountSet {
		accounts = append(accounts, a)
	}
	sort.Strings(accounts)

	writeJSON(w, map[string]interface{}{
		"findings":       items,
		"check_types":    checkTypes,
		"accounts":       accounts,
		"critical_count": critCount,
		"high_count":     highCount,
		"warning_count":  warnCount,
		"total":          len(items),
	})
}
// apiHistory returns paginated finding history.
//
// Query parameters: "limit" (default 50, capped at 5000), "offset"
// (default 0), and the optional filters "from", "to" (YYYY-MM-DD),
// "severity" (0/1/2), "search", and "checks" (comma-separated check names).
// With no filter set the plain paginated read is used; otherwise the filtered
// read is used and the response additionally carries "truncated": false.
func (s *Server) apiHistory(w http.ResponseWriter, r *http.Request) {
	const maxLimit = 5000

	limit := queryInt(r, "limit", 50)
	if limit > maxLimit {
		limit = maxLimit
	}
	offset := queryInt(r, "offset", 0)

	params := r.URL.Query()
	fromStr, toStr := params.Get("from"), params.Get("to")
	sevStr := params.Get("severity")
	searchStr := params.Get("search")
	checksStr := params.Get("checks")

	var checksFilter map[string]bool
	if checksStr != "" {
		checksFilter = make(map[string]bool)
		for _, name := range strings.Split(checksStr, ",") {
			if name = strings.TrimSpace(name); name != "" {
				checksFilter[name] = true
			}
		}
	}

	// If no filters, use simple paginated read
	filtered := fromStr != "" || toStr != "" || sevStr != "" || searchStr != "" || checksStr != ""
	if !filtered {
		page, count := s.store.ReadHistory(limit, offset)
		writeJSON(w, map[string]interface{}{
			"findings": page,
			"total":    count,
			"limit":    limit,
			"offset":   offset,
		})
		return
	}

	// -1 means "no severity filter".
	sevFilter := -1
	if sevStr != "" {
		sevFilter = queryInt(r, "severity", -1)
	}

	page, count := s.store.ReadHistoryFilteredWithChecks(
		limit, offset, fromStr, toStr, sevFilter, searchStr, checksFilter,
	)
	writeJSON(w, map[string]interface{}{
		"findings":  page,
		"total":     count,
		"limit":     limit,
		"offset":    offset,
		"truncated": false,
	})
}
// quarantineDir is the directory whose metadata files are listed by the
// quarantine endpoint; its "pre_clean" subdirectory is listed as well.
//
// var (not const) so tests can redirect to t.TempDir(). Production
// callers must not mutate at runtime.
var quarantineDir = "/opt/csm/quarantine"
// apiQuarantine lists quarantined files with metadata.
//
// Entries are collected from the metadata files in quarantineDir and in its
// pre_clean subdirectory, and returned newest first. Metadata that cannot be
// read is skipped.
func (s *Server) apiQuarantine(w http.ResponseWriter, _ *http.Request) {
	type quarantineEntry struct {
		ID           string `json:"id"`
		Kind         string `json:"kind"`
		OriginalPath string `json:"original_path"`
		Size         int64  `json:"size"`
		QuarantineAt string `json:"quarantined_at"`
		Reason       string `json:"reason"`
		LiveState    string `json:"live_state"`
	}

	// Scan both root quarantine dir and pre_clean subdirectory
	metaFiles := append(
		listMetaFiles(quarantineDir),
		listMetaFiles(filepath.Join(quarantineDir, "pre_clean"))...,
	)

	var listing []quarantineEntry
	for _, metaPath := range metaFiles {
		meta, readErr := readQuarantineMeta(metaPath)
		if readErr != nil {
			continue
		}

		// Hide entries whose original has been restored byte-identical to
		// the archive: the archive is redundant and the UI should reflect
		// the live filesystem, not the quarantine history. Divergence
		// (missing, different size, different content) keeps the entry
		// visible -- the operator still has to reconcile it.
		state := quarantineLiveState(strings.TrimSuffix(metaPath, ".meta"), meta.OriginalPath)
		if state == "restored_identical" {
			continue
		}

		id := quarantineEntryID(metaPath)
		kind := "quarantine"
		if strings.HasPrefix(id, preCleanQuarantineIDPrefix) {
			kind = "pre_clean"
		}

		listing = append(listing, quarantineEntry{
			ID:           id,
			Kind:         kind,
			OriginalPath: meta.OriginalPath,
			Size:         meta.Size,
			QuarantineAt: meta.QuarantineAt.Format(time.RFC3339),
			Reason:       meta.Reason,
			LiveState:    state,
		})
	}

	// Sort newest first (compares the rendered timestamp strings).
	sort.Slice(listing, func(a, b int) bool {
		return listing[a].QuarantineAt > listing[b].QuarantineAt
	})
	writeJSON(w, listing)
}
// apiStats returns severity counts and per-check breakdown.
//
// Everything is computed from the findings recorded in the last 24 hours:
// severity totals, counts per check, the most recent critical finding,
// accounts at risk (highest severity High or above, capped at 50), the five
// accounts with the most findings, auto-response action counts, and a
// brute-force summary.
//
// The account lists are built from maps, whose iteration order is random, so
// both sorts break ties on the account name. This keeps the ordering, and
// which entries survive truncation, stable from one request to the next.
func (s *Server) apiStats(w http.ResponseWriter, _ *http.Request) {
	last24h := time.Now().Add(-24 * time.Hour)
	findings := s.store.ReadHistorySince(last24h)

	critical, high, warning := 0, 0, 0
	byCheck := make(map[string]int)

	// Most recent critical finding for "time since last critical"
	// (findings are newest-first from ReadHistorySince)
	lastCriticalAgo := "None"
	lastCriticalISO := ""

	// Compute accounts at risk: accounts with critical/high findings in 24h
	accountRisk := make(map[string]int) // account -> highest severity
	// Auto-response summary: count actions by type in 24h
	autoBlocked, autoQuarantined, autoKilled := 0, 0, 0
	// Top targeted accounts
	accountHits := make(map[string]int)
	// Brute force summary
	bruteForceIPs := make(map[string]int)   // IP -> total attempts
	bruteForceTypes := make(map[string]int) // "wp-login" / "xmlrpc" -> count

	for _, f := range findings {
		switch f.Severity {
		case alert.Critical:
			critical++
			if critical == 1 {
				// First critical encountered is the newest one.
				lastCriticalAgo = timeAgo(f.Timestamp)
				lastCriticalISO = f.Timestamp.Format(time.RFC3339)
			}
		case alert.High:
			high++
		case alert.Warning:
			warning++
		}
		byCheck[f.Check]++

		// Extract account from finding path/message
		if acct := extractAccountFromFinding(f); acct != "" {
			accountHits[acct]++
			sev := int(f.Severity)
			if prev, ok := accountRisk[acct]; !ok || sev > prev {
				accountRisk[acct] = sev
			}
		}

		// Count auto-response actions
		switch f.Check {
		case "auto_block":
			autoBlocked++
		case "auto_response":
			if strings.Contains(f.Message, "quarantin") {
				autoQuarantined++
			} else if strings.Contains(f.Message, "kill") || strings.Contains(f.Message, "Kill") {
				autoKilled++
			}
		case "wp_login_bruteforce":
			bruteForceTypes["wp-login"]++
			if ip := checks.ExtractIPFromFinding(f); ip != "" {
				bruteForceIPs[ip]++
			}
		case "xmlrpc_abuse":
			bruteForceTypes["xmlrpc"]++
			if ip := checks.ExtractIPFromFinding(f); ip != "" {
				bruteForceIPs[ip]++
			}
		case "modsec_block_escalation", "modsec_csm_block_escalation":
			if strings.Contains(f.Message, "xmlrpc") || strings.Contains(f.Message, "900006") || strings.Contains(f.Message, "900007") {
				bruteForceTypes["xmlrpc-modsec"]++
			}
		}
	}

	// Accounts at risk: those with critical or high severity.
	// Fields are declared in alphabetical key order so the encoded object is
	// byte-identical to the map[string]interface{} it replaces.
	type riskEntry struct {
		Account  string `json:"account"`
		Findings int    `json:"findings"`
		Severity int    `json:"severity"`
	}
	var atRisk []riskEntry
	for acct, sev := range accountRisk {
		if sev >= int(alert.High) {
			atRisk = append(atRisk, riskEntry{
				Account:  acct,
				Findings: accountHits[acct],
				Severity: sev,
			})
		}
	}
	// Sort by severity desc, then findings desc, then account name.
	sort.Slice(atRisk, func(i, j int) bool {
		a, b := atRisk[i], atRisk[j]
		if a.Severity != b.Severity {
			return a.Severity > b.Severity
		}
		if a.Findings != b.Findings {
			return a.Findings > b.Findings
		}
		return a.Account < b.Account
	})
	if len(atRisk) > 50 {
		atRisk = atRisk[:50]
	}

	// Top targeted accounts (by finding count)
	type acctCount struct {
		Account string `json:"account"`
		Count   int    `json:"count"`
	}
	var topAccounts []acctCount
	for acct, count := range accountHits {
		topAccounts = append(topAccounts, acctCount{acct, count})
	}
	sort.Slice(topAccounts, func(i, j int) bool {
		if topAccounts[i].Count != topAccounts[j].Count {
			return topAccounts[i].Count > topAccounts[j].Count
		}
		return topAccounts[i].Account < topAccounts[j].Account
	})
	if len(topAccounts) > 5 {
		topAccounts = topAccounts[:5]
	}

	result := map[string]interface{}{
		"last_24h": map[string]interface{}{
			"critical": critical,
			"high":     high,
			"warning":  warning,
			"total":    critical + high + warning,
		},
		"by_check":          byCheck,
		"last_critical_ago": lastCriticalAgo,
		"last_critical_iso": lastCriticalISO,
		"accounts_at_risk":  atRisk,
		"auto_response": map[string]int{
			"blocked":     autoBlocked,
			"quarantined": autoQuarantined,
			"killed":      autoKilled,
		},
		"top_accounts": topAccounts,
		"brute_force":  buildBruteForceSummary(bruteForceIPs, bruteForceTypes),
	}
	writeJSON(w, result)
}
// buildBruteForceSummary condenses brute-force counters into the JSON object
// served under "brute_force".
//
// ips maps attacker IP to its number of findings; types maps an attack label
// ("wp-login", "xmlrpc", "xmlrpc-modsec") to its count. The result holds the
// total across all labels, the number of distinct IPs, the per-family counts
// (the two xmlrpc labels are added together) and up to ten top IPs.
//
// top_ips is ordered by count descending, then by IP string ascending. The
// tie-break matters: the list is built from a map, so without it both the
// order of equal-count IPs and the choice of which ones make the top ten
// would change between calls.
func buildBruteForceSummary(ips map[string]int, types map[string]int) map[string]interface{} {
	// Top attacker IPs
	type ipCount struct {
		IP    string `json:"ip"`
		Count int    `json:"count"`
	}
	var topIPs []ipCount
	for ip, count := range ips {
		topIPs = append(topIPs, ipCount{ip, count})
	}
	sort.Slice(topIPs, func(i, j int) bool {
		if topIPs[i].Count != topIPs[j].Count {
			return topIPs[i].Count > topIPs[j].Count
		}
		return topIPs[i].IP < topIPs[j].IP
	})
	if len(topIPs) > 10 {
		topIPs = topIPs[:10]
	}

	total := 0
	for _, v := range types {
		total += v
	}

	return map[string]interface{}{
		"total_attacks":  total,
		"unique_ips":     len(ips),
		"wp_login_count": types["wp-login"],
		"xmlrpc_count":   types["xmlrpc"] + types["xmlrpc-modsec"],
		"top_ips":        topIPs,
	}
}
// apiStatsTrend returns daily finding counts by severity for the trend
// chart. Accepts optional ?days=N (default 30, clamped to the store's
// retention window). A missing, non-numeric or non-positive value falls back
// to the default.
func (s *Server) apiStatsTrend(w http.ResponseWriter, r *http.Request) {
	const defaultDays = 30

	days := defaultDays
	// strconv.Atoi fails on the empty string, so an absent parameter takes
	// the same path as a malformed one.
	if requested, err := strconv.Atoi(r.URL.Query().Get("days")); err == nil && requested > 0 {
		days = requested
	}
	writeJSON(w, s.store.AggregateByDayN(days))
}
// apiStatsTimeline returns 24 hourly buckets for the findings timeline chart.
// Uses efficient bbolt cursor seeking instead of loading all findings into memory.
func (s *Server) apiStatsTimeline(w http.ResponseWriter, _ *http.Request) {
	hourly := s.store.AggregateByHour()
	writeJSON(w, hourly)
}
// apiHealth returns daemon health status: uptime (as a duration string and
// in whole seconds), the loaded rule count, whether fanotify is active and
// the number of log watchers.
func (s *Server) apiHealth(w http.ResponseWriter, _ *http.Request) {
	// Sample the clock once so "uptime" and "uptime_seconds" always describe
	// the same instant.
	uptime := time.Since(s.startTime)

	// Named status rather than health so the imported health package is not
	// shadowed inside this handler.
	status := map[string]interface{}{
		"daemon_mode":    true,
		"uptime":         uptime.String(),
		"uptime_seconds": int(uptime.Seconds()),
		"rules_loaded":   s.sigCount,
		"fanotify":       s.fanotifyActive,
		"log_watchers":   s.logWatcherCount,
	}
	writeJSON(w, status)
}
// apiHistoryCSV exports history as CSV download.
//
// At most the 5000 most recent history entries returned by ReadHistory are
// written, one row per finding, under the header
// "Timestamp,Severity,Check,Message,Details". Severities other than Critical
// and High are labelled WARNING.
func (s *Server) apiHistoryCSV(w http.ResponseWriter, _ *http.Request) {
	// The second return value is the total history size; the export does
	// not report it.
	findings, _ := s.store.ReadHistory(5000, 0)

	w.Header().Set("Content-Type", "text/csv")
	w.Header().Set("Content-Disposition", "attachment; filename=csm-history.csv")

	// CSV header. A failed write means the client is gone, so stop instead
	// of formatting thousands of rows nobody will read.
	if _, err := fmt.Fprint(w, "Timestamp,Severity,Check,Message,Details\n"); err != nil {
		return
	}

	for _, f := range findings {
		sev := "WARNING"
		switch f.Severity {
		case alert.Critical:
			sev = "CRITICAL"
		case alert.High:
			sev = "HIGH"
		}

		// Escape every free-text column, including the check name, so a
		// comma, quote or newline in any of them cannot shift the columns.
		_, err := fmt.Fprintf(w, "%s,%s,%s,%s,%s\n",
			f.Timestamp.Format(time.RFC3339),
			sev,
			csvEscape(f.Check),
			csvEscape(f.Message),
			csvEscape(f.Details),
		)
		if err != nil {
			return
		}
	}
}
// csvEscape quotes s for use as a single CSV field. Values without a comma,
// double quote, CR or LF are returned unchanged; anything else is wrapped in
// double quotes with embedded double quotes doubled.
func csvEscape(s string) string {
	if !strings.ContainsAny(s, ",\"\n\r") {
		return s
	}
	doubled := strings.ReplaceAll(s, `"`, `""`)
	return `"` + doubled + `"`
}
// safeLogString returns s as a double-quoted Go string literal
// (strconv.Quote), so that CR, LF and other control bytes in a
// caller-controlled value appear as escape sequences and cannot forge a
// separate log line.
func safeLogString(s string) string {
	quoted := strconv.Quote(s)
	return quoted
}
// apiFix applies a known remediation action for a finding.
// POST /api/v1/fix body: {"check": "check_type", "message": "...", "details": "..."}
//
// The body may also carry "file_path" and "key". The remediation result is
// always returned to the caller; on success the finding is additionally
// dismissed and the action is audit-logged.
func (s *Server) apiFix(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var payload struct {
		Check    string `json:"check"`
		Message  string `json:"message"`
		Details  string `json:"details"`
		FilePath string `json:"file_path"`
		Key      string `json:"key"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &payload)
	if decodeErr != nil || payload.Check == "" || payload.Message == "" {
		writeJSONError(w, "check and message are required", http.StatusBadRequest)
		return
	}
	if !checks.HasFix(payload.Check) {
		writeJSONError(w, "no automated fix available for this check type", http.StatusBadRequest)
		return
	}

	outcome := checks.ApplyFix(payload.Check, payload.Message, payload.Details, payload.FilePath)
	if !outcome.Success {
		writeJSON(w, outcome)
		return
	}

	// Fix succeeded: dismiss from both alert state and latest findings.
	// Prefer the canonical key sent by the client (matches Finding.Key(), which
	// includes a hash of Details when present); fall back to check:message for
	// older clients and findings with empty Details.
	dismissKey := payload.Key
	if dismissKey == "" {
		dismissKey = payload.Check + ":" + payload.Message
	}
	s.store.DismissFinding(dismissKey)
	s.store.DismissLatestFinding(dismissKey)
	s.auditLog(r, "fix", payload.Check, outcome.Action)

	writeJSON(w, outcome)
}
// apiBulkFix applies fixes to multiple findings at once.
// POST /api/v1/fix-bulk body: [{"check":"...", "message":"...", "details":"..."}, ...]
//
// Each item may also carry "file_path" and "key". Items are processed in
// order; one result is returned per item, plus total/succeeded/failed counts.
// An item whose check has no automated fix yields a result with only Error
// set. Every successful fix dismisses its finding and is audit-logged, the
// same as a single fix applied through apiFix.
func (s *Server) apiBulkFix(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var reqs []struct {
		Check    string `json:"check"`
		Message  string `json:"message"`
		Details  string `json:"details"`
		FilePath string `json:"file_path"`
		Key      string `json:"key"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &reqs); err != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}

	var results []checks.RemediationResult
	succeeded := 0
	for _, req := range reqs {
		if !checks.HasFix(req.Check) {
			results = append(results, checks.RemediationResult{
				Error: fmt.Sprintf("no fix for %s", req.Check),
			})
			continue
		}

		result := checks.ApplyFix(req.Check, req.Message, req.Details, req.FilePath)
		if result.Success {
			// Prefer the client-supplied canonical key; fall back to
			// check:message for clients that do not send one.
			key := req.Key
			if key == "" {
				key = req.Check + ":" + req.Message
			}
			s.store.DismissFinding(key)
			s.store.DismissLatestFinding(key)
			// Remediation changes the system; record it like apiFix does so
			// bulk operations are not invisible in the audit trail.
			s.auditLog(r, "fix", req.Check, result.Action)
			succeeded++
		}
		results = append(results, result)
	}

	writeJSON(w, map[string]interface{}{
		"results":   results,
		"total":     len(results),
		"succeeded": succeeded,
		"failed":    len(results) - succeeded,
	})
}
// apiAccounts returns a list of cPanel account usernames for the scan dropdown.
//
// An account is any directory directly under /home that is not a known
// system or hidden directory and that contains a public_html entry. The
// response is always a JSON array: it is empty, never null, both when /home
// cannot be read and when no directory qualifies.
//
//nolint:unused // registered via mux.Handle in server.go
func (s *Server) apiAccounts(w http.ResponseWriter, _ *http.Request) {
	// Start from an empty, non-nil slice so "no accounts" encodes as [] just
	// like the read-error path below.
	accounts := []string{}

	entries, err := os.ReadDir("/home")
	if err != nil {
		writeJSON(w, accounts)
		return
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		name := entry.Name()
		// Skip system/hidden directories
		if strings.HasPrefix(name, ".") {
			continue
		}
		switch name {
		case "virtfs", "cPanelInstall", "cpanelsolr", "lost+found":
			continue
		}
		// Must have public_html to be a real cPanel account
		if _, statErr := os.Stat(filepath.Join("/home", name, "public_html")); statErr == nil {
			accounts = append(accounts, name)
		}
	}
	writeJSON(w, accounts)
}
// --- Action endpoints ---
// apiBlockIP blocks an IP via the firewall engine.
// POST /api/v1/block-ip body: {"ip": "1.2.3.4", "reason": "..."}
//
// An optional "duration" is passed through parseDuration. A missing reason
// defaults to "Blocked via CSM Web UI". Responds 400 for a bad body or IP,
// 503 when no firewall engine is attached and 500 when the block itself
// fails. A successful block is audit-logged.
func (s *Server) apiBlockIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IP       string `json:"ip"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &body); decodeErr != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if body.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, ipErr := parseAndValidateIP(body.IP); ipErr != nil {
		writeJSONError(w, ipErr.Error(), http.StatusBadRequest)
		return
	}

	reason := body.Reason
	if reason == "" {
		reason = "Blocked via CSM Web UI"
	}
	ttl := parseDuration(body.Duration)

	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	// Operator-initiated: bypass auto_response.dry_run gate.
	if blockErr := blockIPForOperator(s.blocker, body.IP, reason, ttl); blockErr != nil {
		writeJSONError(w, fmt.Sprintf("Block failed: %v", blockErr), http.StatusInternalServerError)
		return
	}

	s.auditLog(r, "block_ip", body.IP, reason)
	writeJSON(w, map[string]string{"status": "blocked", "ip": body.IP})
}
// apiUnblockIP removes an IP from the firewall + cphulk.
// POST /api/v1/unblock-ip body: {"ip": "1.2.3.4"}
//
// Responds 400 for a bad body or IP, 503 when no firewall engine is
// attached and 500 when the unblock fails. After a successful unblock the IP
// is also flushed from cphulk and the action is audit-logged.
func (s *Server) apiUnblockIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IP string `json:"ip"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &body)
	if decodeErr != nil || body.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, ipErr := parseAndValidateIP(body.IP); ipErr != nil {
		writeJSONError(w, ipErr.Error(), http.StatusBadRequest)
		return
	}

	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if unblockErr := s.blocker.UnblockIP(body.IP); unblockErr != nil {
		writeJSONError(w, fmt.Sprintf("Unblock failed: %v", unblockErr), http.StatusInternalServerError)
		return
	}

	// Also flush from cphulk (cPanel brute force detector)
	flushCphulk(body.IP)

	s.auditLog(r, "unblock_ip", body.IP, "manual unblock via UI")
	writeJSON(w, map[string]string{"status": "unblocked", "ip": body.IP})
}
// apiUnblockBulk unblocks multiple IPs at once.
//
// The POST body is {"ips": [...]} with 1-100 entries. Entries that fail
// validation or cannot be unblocked are skipped silently; each IP that is
// unblocked is flushed from cphulk and audit-logged. When at least one IP
// was unblocked an undo entry is recorded and its token is returned.
func (s *Server) apiUnblockBulk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IPs []string `json:"ips"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &body)
	if decodeErr != nil || len(body.IPs) == 0 {
		writeJSONError(w, "IPs array is required", http.StatusBadRequest)
		return
	}
	if len(body.IPs) > 100 {
		writeJSONError(w, "IPs must be 1-100 items", http.StatusBadRequest)
		return
	}
	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	done := make([]string, 0, len(body.IPs))
	for _, addr := range body.IPs {
		if _, ipErr := parseAndValidateIP(addr); ipErr != nil {
			continue
		}
		if unblockErr := s.blocker.UnblockIP(addr); unblockErr != nil {
			continue
		}
		flushCphulk(addr)
		s.auditLog(r, "unblock_ip", addr, "bulk unblock via UI")
		done = append(done, addr)
	}

	var undoToken string
	if len(done) > 0 {
		undoToken = s.recordUndoEntry(
			r,
			"firewall_bulk_unblock",
			undoInverseFirewallUnblock,
			fmt.Sprintf("Unblocked %d IPs", len(done)),
			undoPayloadIPs{IPs: done, Reason: "Undo: re-block via CSM Web UI", Timeout: "24h"},
		)
	}

	writeJSON(w, map[string]interface{}{
		"status":     "completed",
		"total":      len(body.IPs),
		"succeeded":  len(done),
		"undo_token": undoToken,
	})
}
// blockedEntry is a raw blocked IP record from firewall state.
// It is filled either from the firewall engine state or by decoding the
// "ips" array of the legacy blocked_ips.json file.
type blockedEntry struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// Source may be empty; formatBlockedView then infers it from Reason.
Source string `json:"source,omitempty"`
BlockedAt time.Time `json:"blocked_at"`
// ExpiresAt is the zero time for a block that formatBlockedView reports
// as "permanent".
ExpiresAt time.Time `json:"expires_at"`
}
// blockedView is the rendered form of a blockedEntry produced by
// formatBlockedView and returned by the blocked-IPs endpoint.
type blockedView struct {
IP string `json:"ip"`
Reason string `json:"reason"`
// Source is the entry's own source, or the inferred provenance when the
// entry carried none.
Source string `json:"source"`
// BlockedAt is an RFC 3339 timestamp.
BlockedAt string `json:"blocked_at"`
// ExpiresAt is an RFC 3339 timestamp, left empty when the block has no expiry.
ExpiresAt string `json:"expires_at"`
// ExpiresIn is the remaining time as "<hours>h<minutes>m", or "permanent".
ExpiresIn string `json:"expires_in"`
}
// formatBlockedView renders a raw block record for the API. The boolean is
// false, and the view empty, when the record has an expiry that already lies
// in the past. A record without a source gets one inferred from its reason;
// a record without an expiry is reported as "permanent".
func formatBlockedView(b blockedEntry) (blockedView, bool) {
	permanent := b.ExpiresAt.IsZero()
	if !permanent && time.Now().After(b.ExpiresAt) {
		return blockedView{}, false // expired
	}

	out := blockedView{
		IP:        b.IP,
		Reason:    b.Reason,
		Source:    b.Source,
		BlockedAt: b.BlockedAt.Format(time.RFC3339),
		ExpiresIn: "permanent",
	}
	if out.Source == "" {
		out.Source = firewall.InferProvenance("block", b.Reason)
	}
	if permanent {
		return out, true
	}

	left := time.Until(b.ExpiresAt)
	out.ExpiresAt = b.ExpiresAt.Format(time.RFC3339)
	out.ExpiresIn = fmt.Sprintf("%dh%dm", int(left.Hours()), int(left.Minutes())%60)
	return out, true
}
// apiBlockedIPs returns the list of currently blocked IPs.
//
// The firewall engine state is authoritative whenever its state file exists
// (even with no blocks) or it reports at least one block. Only otherwise is
// the legacy blocked_ips.json consulted; if that file is missing or
// malformed the response is an empty array. Expired entries are dropped.
func (s *Server) apiBlockedIPs(w http.ResponseWriter, _ *http.Request) {
	views := []blockedView{}

	enginePath := filepath.Join(s.cfg.StatePath, "firewall", "state.json")
	_, engineStatErr := os.Stat(enginePath) // #nosec G304 -- filepath.Join under operator-configured StatePath.

	engineState, loadErr := firewall.LoadState(s.cfg.StatePath)
	if loadErr == nil && engineState != nil {
		for _, rec := range engineState.Blocked {
			view, live := formatBlockedView(blockedEntry{
				IP:        rec.IP,
				Reason:    rec.Reason,
				Source:    rec.Source,
				BlockedAt: rec.BlockedAt,
				ExpiresAt: rec.ExpiresAt,
			})
			if live {
				views = append(views, view)
			}
		}
		// A present engine state file wins even when empty. blocked_ips.json
		// is only a legacy fallback when the engine file does not exist.
		engineFilePresent := engineStatErr == nil
		if engineFilePresent || len(engineState.Blocked) > 0 {
			writeJSON(w, views)
			return
		}
	}

	// Fall back to blocked_ips.json (legacy)
	legacyPath := filepath.Join(s.cfg.StatePath, "blocked_ips.json")
	// #nosec G304 -- filepath.Join under operator-configured StatePath.
	raw, readErr := os.ReadFile(legacyPath)
	if readErr != nil {
		writeJSON(w, []interface{}{})
		return
	}

	var legacy struct {
		IPs []blockedEntry `json:"ips"`
	}
	if parseErr := json.Unmarshal(raw, &legacy); parseErr != nil {
		writeJSON(w, []interface{}{})
		return
	}

	for _, rec := range legacy.IPs {
		view, live := formatBlockedView(rec)
		if live {
			views = append(views, view)
		}
	}
	writeJSON(w, views)
}
// apiDismissFinding marks a finding as baseline (acknowledged/dismissed).
// POST /api/v1/dismiss body: {"key": "check:message"}
//
// The key is dismissed from both the alert state and the latest findings,
// and the action is audit-logged.
func (s *Server) apiDismissFinding(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Key string `json:"key"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 16*1024, &body)
	if decodeErr != nil || body.Key == "" {
		writeJSONError(w, "Key is required", http.StatusBadRequest)
		return
	}

	key := body.Key
	s.store.DismissFinding(key)
	s.store.DismissLatestFinding(key)
	s.auditLog(r, "dismiss", key, "")

	writeJSON(w, map[string]string{"status": "dismissed", "key": key})
}
// apiQuarantineRestore restores a quarantined file to its original location.
// POST /api/v1/quarantine-restore body: {"id": "filename"}
//
// Flow:
//  1. Resolve the ID to a quarantine entry and read its metadata sidecar
//     (original path, owner uid, group gid, mode string).
//  2. Validate the original path with validateQuarantineRestorePath and
//     create its parent directory if needed.
//  3. Refuse quarantined symlinks. Directories are moved back with
//     os.Rename, and only when nothing exists at the destination. Files are
//     copied into a freshly created destination (O_EXCL|O_NOFOLLOW), with the
//     destination path re-verified against the open handle before the copy,
//     after chmod/chown, and once more after close.
//  4. Remove the quarantined item and the sidecar, audit-log, and respond.
//
// Status codes: 400 for a bad request/ID/path, 404 when the sidecar cannot
// be read, 409 when the destination exists or changed during the restore,
// 500 for other failures. A failed chown is only logged.
func (s *Server) apiQuarantineRestore(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req struct {
ID string `json:"id"`
}
if err := decodeJSONBodyLimited(w, r, 16*1024, &req); err != nil || req.ID == "" {
writeJSONError(w, "ID is required", http.StatusBadRequest)
return
}
entry, err := resolveQuarantineEntry(req.ID)
if err != nil {
writeJSONError(w, err.Error(), http.StatusBadRequest)
return
}
metaData, err := os.ReadFile(entry.MetaPath)
if err != nil {
writeJSONError(w, "Quarantine entry not found", http.StatusNotFound)
return
}
// Only the fields needed for the restore are decoded from the sidecar.
var meta struct {
OriginalPath string `json:"original_path"`
Owner int `json:"owner_uid"`
Group int `json:"group_gid"`
Mode string `json:"mode"`
}
if unmarshalErr := json.Unmarshal(metaData, &meta); unmarshalErr != nil {
writeJSONError(w, "Invalid metadata", http.StatusInternalServerError)
return
}
restorePath, err := validateQuarantineRestorePath(meta.OriginalPath)
if err != nil {
writeJSONError(w, err.Error(), http.StatusBadRequest)
return
}
// Ensure parent directory exists
parentDir := filepath.Dir(restorePath)
// #nosec G301 -- Restoring a quarantined file into a user's public_html
// (validated by validateQuarantineRestorePath). The webserver must be
// able to traverse intermediate directories to serve the file.
if mkdirErr := os.MkdirAll(parentDir, 0755); mkdirErr != nil {
writeJSONError(w, fmt.Sprintf("Cannot create parent directory: %v", mkdirErr), http.StatusInternalServerError)
return
}
// Check if quarantined item is a directory or file
// (Lstat, so a quarantined symlink is seen as a symlink and refused below.)
quarInfo, err := os.Lstat(entry.ItemPath)
if err != nil {
writeJSONError(w, fmt.Sprintf("Cannot stat quarantined file: %v", err), http.StatusInternalServerError)
return
}
if quarInfo.Mode()&os.ModeSymlink != 0 {
writeJSONError(w, "Cannot restore symlink quarantine entry", http.StatusInternalServerError)
return
}
// Parse original mode from metadata (format: "-rw-r--r--" or "drwxr-xr-x")
// Falls back to 0644 when the recorded mode is missing or shorter than 10 characters.
restoredMode := os.FileMode(0644)
if meta.Mode != "" && len(meta.Mode) >= 10 {
restoredMode = parseModeString(meta.Mode)
}
if quarInfo.IsDir() {
// Never merge into or replace something already at the destination.
if _, statErr := os.Lstat(restorePath); statErr == nil {
writeJSONError(w, "Cannot restore - destination already exists", http.StatusConflict)
return
} else if !os.IsNotExist(statErr) {
writeJSONError(w, fmt.Sprintf("Cannot inspect restore destination: %v", statErr), http.StatusInternalServerError)
return
}
// Mode and ownership are applied to the quarantined directory before it is
// moved into place.
if err := os.Chmod(entry.ItemPath, restoredMode); err != nil {
writeJSONError(w, fmt.Sprintf("Cannot restore directory mode: %v", err), http.StatusInternalServerError)
return
}
if err := syscall.Chown(entry.ItemPath, meta.Owner, meta.Group); err != nil {
log.Printf("webui: chown %s before restore failed: %v", safeLogString(entry.ItemPath), err)
}
// Directory restore: use os.Rename (same device)
if err := os.Rename(entry.ItemPath, restorePath); err != nil {
writeJSONError(w, fmt.Sprintf("Cannot restore directory: %v", err), http.StatusInternalServerError)
return
}
} else {
// File restore: use O_EXCL to prevent overwriting an existing file
src, readErr := os.Open(entry.ItemPath)
if readErr != nil {
writeJSONError(w, fmt.Sprintf("Cannot read quarantined file: %v", readErr), http.StatusInternalServerError)
return
}
// #nosec G304 -- restorePath was validated with
// validateQuarantineRestorePath above (see handler); the O_EXCL|
// O_NOFOLLOW flags additionally block TOCTOU replacement.
// Created 0600; the recorded mode is applied only after the copy succeeds.
dst, createErr := os.OpenFile(restorePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL|syscall.O_NOFOLLOW, 0600)
if createErr != nil {
_ = src.Close()
writeJSONError(w, fmt.Sprintf("Cannot restore - file already exists at original path: %v", createErr), http.StatusConflict)
return
}
// Test-only hook (nil in production) that may swap the path right after creation.
if quarantineRestoreAfterCreateForTest != nil {
quarantineRestoreAfterCreateForTest(restorePath)
}
// Verify the path still names the file we just opened before writing to it.
if _, err := ensureOpenFileStillAtPath(dst, restorePath); err != nil {
_ = src.Close()
_ = dst.Close()
writeJSONError(w, "Cannot restore - destination changed during restore", http.StatusConflict)
return
}
_, copyErr := io.Copy(dst, src)
// A close error on the source is reported only if the copy itself succeeded.
if closeErr := src.Close(); copyErr == nil && closeErr != nil {
copyErr = closeErr
}
if copyErr != nil {
// Remove the partial destination, but only if the path still names our file.
removeRestorePathIfSameOpenFile(dst, restorePath)
_ = dst.Close()
writeJSONError(w, fmt.Sprintf("Cannot write restored file: %v", copyErr), http.StatusInternalServerError)
return
}
// Mode and owner are set through the open handle, not the path.
if err := dst.Chmod(restoredMode); err != nil {
removeRestorePathIfSameOpenFile(dst, restorePath)
_ = dst.Close()
writeJSONError(w, fmt.Sprintf("Cannot restore file mode: %v", err), http.StatusInternalServerError)
return
}
if err := dst.Chown(meta.Owner, meta.Group); err != nil {
log.Printf("webui: chown %s after restore failed: %v", safeLogString(restorePath), err)
}
// Second verification, while the handle is still open; the FileInfo is kept
// for the post-close check.
restoredInfo, err := ensureOpenFileStillAtPath(dst, restorePath)
if err != nil {
_ = dst.Close()
writeJSONError(w, "Cannot restore - destination changed during restore", http.StatusConflict)
return
}
if err := dst.Close(); err != nil {
removeRestorePathIfSameInfo(restorePath, restoredInfo)
writeJSONError(w, fmt.Sprintf("Cannot write restored file: %v", err), http.StatusInternalServerError)
return
}
// Final verification after close; the quarantined copy is deleted only once
// this passes.
if err := ensurePathStillNamesInfo(restorePath, restoredInfo); err != nil {
writeJSONError(w, "Cannot restore - destination changed during restore", http.StatusConflict)
return
}
if err := os.Remove(entry.ItemPath); err != nil && !os.IsNotExist(err) {
log.Printf("webui: failed to remove %s: %v", safeLogString(entry.ItemPath), err)
}
}
// Remove metadata sidecar
if err := os.Remove(entry.MetaPath); err != nil && !os.IsNotExist(err) {
log.Printf("webui: failed to remove %s: %v", safeLogString(entry.MetaPath), err)
}
s.auditLog(r, "restore", restorePath, "quarantine restore")
writeJSON(w, map[string]string{
"status": "restored",
"path": restorePath,
"warning": "File restored to original location. Re-scan recommended.",
})
}
// quarantineRestoreAfterCreateForTest lets race tests replace the path
// after O_EXCL creation; nil in production. When set, the restore handler
// calls it with the restore path right after creating the destination file
// and before it first verifies that the path still names that file.
var quarantineRestoreAfterCreateForTest func(string)
// ensureOpenFileStillAtPath checks that path still refers to the file behind
// the open handle f. On success it returns the handle's FileInfo so later
// checks can be made after the handle is closed.
func ensureOpenFileStillAtPath(f *os.File, path string) (os.FileInfo, error) {
	handleInfo, statErr := f.Stat()
	if statErr != nil {
		return nil, fmt.Errorf("cannot stat restored file handle: %w", statErr)
	}
	if mismatch := ensurePathStillNamesInfo(path, handleInfo); mismatch != nil {
		return nil, mismatch
	}
	return handleInfo, nil
}
func ensurePathStillNamesInfo(path string, fileInfo os.FileInfo) error {
pathInfo, err := os.Lstat(path)
if err != nil {
return fmt.Errorf("cannot stat restored file path: %w", err)
}
if !os.SameFile(fileInfo, pathInfo) {
return fmt.Errorf("restore destination changed during restore")
}
return nil
}
// removeRestorePathIfSameOpenFile deletes path only while it still names the
// file behind the open handle f. If the check fails the path is left alone
// and the reason is logged.
func removeRestorePathIfSameOpenFile(f *os.File, path string) {
	handleInfo, checkErr := ensureOpenFileStillAtPath(f, path)
	if checkErr == nil {
		removeRestorePathIfSameInfo(path, handleInfo)
		return
	}
	log.Printf("webui: not removing changed restore path %s: %v", safeLogString(path), checkErr)
}
// removeRestorePathIfSameInfo removes path only if it still names the file
// described by fileInfo. A failed identity check or a removal error is
// logged; a path that has already disappeared is not treated as an error.
func removeRestorePathIfSameInfo(path string, fileInfo os.FileInfo) {
	if changed := ensurePathStillNamesInfo(path, fileInfo); changed != nil {
		log.Printf("webui: not removing changed restore path %s: %v", safeLogString(path), changed)
		return
	}
	rmErr := os.Remove(path)
	if rmErr == nil || os.IsNotExist(rmErr) {
		return
	}
	log.Printf("webui: failed to remove %s: %v", safeLogString(path), rmErr)
}
// apiQuarantinePreview returns the first 8KB of a quarantined file for inspection.
//
// The entry is resolved from the "id" query parameter. Responses:
//   - 400 when the id cannot be resolved,
//   - 404 when the quarantined item cannot be stat'ed,
//   - 500 when the item cannot be opened,
//   - 200 with a placeholder preview for directories,
//   - 200 with the preview, a truncated flag and the total size otherwise.
func (s *Server) apiQuarantinePreview(w http.ResponseWriter, r *http.Request) {
	entry, err := resolveQuarantineEntry(r.URL.Query().Get("id"))
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	info, err := os.Stat(entry.ItemPath)
	if err != nil {
		writeJSONError(w, "not found", http.StatusNotFound)
		return
	}
	if info.IsDir() {
		writeJSON(w, map[string]interface{}{
			"id": entry.ID, "is_dir": true,
			"preview": "[directory - content preview not available]",
		})
		return
	}
	f, err := os.Open(entry.ItemPath)
	if err != nil {
		writeJSONError(w, "cannot read file", http.StatusInternalServerError)
		return
	}
	defer f.Close()

	// A single Read is allowed to return fewer bytes than requested even when
	// more data is available, which would produce a short preview while
	// "truncated" (derived from the file size) still claims a full 8KB window.
	// Keep reading until the buffer is full or the read ends (EOF or error);
	// the preview stays best-effort, so the terminating error is not reported.
	buf := make([]byte, 8192)
	n := 0
	for n < len(buf) {
		m, readErr := f.Read(buf[n:])
		n += m
		if readErr != nil || m == 0 {
			break
		}
	}
	writeJSON(w, map[string]interface{}{
		"id":         entry.ID,
		"preview":    string(buf[:n]),
		"truncated":  info.Size() > 8192,
		"total_size": info.Size(),
	})
}
// apiQuarantineBulkDelete permanently removes quarantined files and their metadata.
//
// POST body: {"ids": [...]} with 1-100 entries (body capped at 64 KiB).
// IDs that cannot be resolved are skipped. For each resolved entry the item
// is removed first; the metadata sidecar is only removed once the item is
// gone (either removed here or already absent). The response and the audit
// record carry the number of items actually removed.
func (s *Server) apiQuarantineBulkDelete(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IDs []string `json:"ids"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if len(req.IDs) == 0 || len(req.IDs) > 100 {
		writeJSONError(w, "IDs must be 1-100 items", http.StatusBadRequest)
		return
	}
	count := 0
	for _, id := range req.IDs {
		entry, err := resolveQuarantineEntry(id)
		if err != nil {
			continue
		}
		_, statErr := os.Lstat(entry.ItemPath)
		switch {
		case statErr == nil:
			if err := os.RemoveAll(entry.ItemPath); err != nil {
				// Keep the sidecar: deleting it while the item survives would
				// orphan the quarantined content without its metadata.
				log.Printf("webui: failed to remove quarantine item %s: %v", safeLogString(entry.ItemPath), err)
				continue
			}
			count++
		case !os.IsNotExist(statErr):
			// Item state is unknown (e.g. permission error); leave everything.
			continue
		}
		if err := os.Remove(entry.MetaPath); err != nil && !os.IsNotExist(err) {
			log.Printf("webui: failed to remove quarantine meta %s: %v", safeLogString(entry.MetaPath), err)
		}
	}
	s.auditLog(r, "quarantine_bulk_delete", fmt.Sprintf("%d files", count), "")
	writeJSON(w, map[string]interface{}{"ok": true, "count": count})
}
// apiTestAlert pushes a synthetic Warning-level finding through
// alert.Dispatch using the server's configuration. POST only. A dispatch
// failure is reported in the JSON body as {"status":"error"}; success is
// audited and reported as {"status":"sent"}.
func (s *Server) apiTestAlert(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	probe := alert.Finding{
		Severity:  alert.Warning,
		Check:     "test_alert",
		Message:   "Test alert from CSM Web UI",
		Details:   fmt.Sprintf("Sent by admin at %s", time.Now().Format("2006-01-02 15:04:05")),
		Timestamp: time.Now(),
	}
	if dispatchErr := alert.Dispatch(s.cfg, []alert.Finding{probe}); dispatchErr != nil {
		writeJSON(w, map[string]interface{}{"status": "error", "error": dispatchErr.Error()})
		return
	}
	s.auditLog(r, "test_alert", "notification", "sent test alert")
	writeJSON(w, map[string]interface{}{"status": "sent"})
}
// apiScanAccount runs an on-demand scan for a single cPanel account.
// POST /api/v1/scan-account body: {"account": "username"}
//
// The account name must be present and pass validateAccountName. Only one
// scan may run at a time; a concurrent request gets 429. The response
// carries the account, the number of findings and the elapsed time.
func (s *Server) apiScanAccount(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		Account string `json:"account"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 32*1024, &req)
	if decodeErr != nil || req.Account == "" {
		writeJSONError(w, "Account name is required", http.StatusBadRequest)
		return
	}
	if nameErr := validateAccountName(req.Account); nameErr != nil {
		writeJSONError(w, nameErr.Error(), http.StatusBadRequest)
		return
	}
	// Rate limit: only one scan at a time
	if !s.acquireScan() {
		writeJSONError(w, "A scan is already in progress. Please wait.", http.StatusTooManyRequests)
		return
	}
	defer s.releaseScan()
	// Extend the write deadline for this long-running request.
	// Account scans can take several minutes; the default WriteTimeout (300s)
	// causes ERR_HTTP2_PROTOCOL_ERROR in browsers when it fires mid-stream.
	_ = http.NewResponseController(w).SetWriteDeadline(time.Now().Add(10 * time.Minute))
	began := time.Now()
	findings := checks.RunAccountScan(s.cfg, s.store, req.Account)
	took := time.Since(began).Round(time.Millisecond)
	writeJSON(w, map[string]interface{}{
		"account": req.Account,
		"count":   len(findings),
		"elapsed": took.String(),
	})
}
// parseModeString converts a permission string like "-rw-r--r--" to os.FileMode.
// Only the last nine characters are interpreted, as owner/group/other
// r/w/x flags: any character other than '-' sets the corresponding bit.
// Strings shorter than ten characters, and strings that yield no permission
// bits at all, fall back to 0644.
func parseModeString(s string) os.FileMode {
	const fallback os.FileMode = 0644
	if len(s) < 10 {
		return fallback
	}
	tail := s[len(s)-9:] // last 9 chars: "rwxr-xr-x"
	var mode os.FileMode
	for pos := 0; pos < 9; pos++ {
		if tail[pos] == '-' {
			continue
		}
		// Position 0 is owner-read (0400); each later position is the next
		// lower permission bit, down to other-execute (0001).
		mode |= os.FileMode(0400) >> uint(pos)
	}
	if mode == 0 {
		return fallback
	}
	return mode
}
// flushCphulk removes brute-force login history for an IP from cPanel's cphulk.
// Callers should pre-validate `ip` with parseAndValidateIP. This function
// re-validates as defense-in-depth so a future caller that forgets cannot
// expose a shell-execution surface even if exec.Command itself does not
// invoke a shell. The command's output and error are discarded.
func flushCphulk(ip string) {
	_, invalid := parseAndValidateIP(ip)
	if invalid != nil {
		return
	}
	// #nosec G204 -- whmapi1 hardcoded; `ip` is validated above with
	// parseAndValidateIP. exec.Command passes args directly to execve
	// without shell interpolation, so no shell-meta risk either way.
	cmd := exec.Command("whmapi1", "flush_cphulk_login_history_for_ips", "ip="+ip)
	_, _ = cmd.Output()
}
// apiExport returns a JSON bundle of exportable state: the stored
// suppression rules and the threat DB's whitelisted IPs, stamped with the
// export time and hostname. Both lists are always emitted as JSON arrays,
// never null. The response is marked as a file download.
func (s *Server) apiExport(w http.ResponseWriter, _ *http.Request) {
	// Collect suppressions
	rules := s.store.LoadSuppressions()
	if rules == nil {
		rules = []state.SuppressionRule{}
	}
	// Collect whitelist
	allowed := []checks.WhitelistIP{}
	if tdb := checks.GetThreatDB(); tdb != nil {
		if ips := tdb.WhitelistedIPs(); ips != nil {
			allowed = ips
		}
	}
	w.Header().Set("Content-Disposition", "attachment; filename=csm-state-export.json")
	writeJSON(w, map[string]interface{}{
		"exported_at":  time.Now().Format(time.RFC3339),
		"hostname":     s.cfg.Hostname,
		"suppressions": rules,
		"whitelist":    allowed,
	})
}
// apiImport merges an exported state bundle into the current state.
//
// POST body (capped at 512 KiB): {"suppressions": [...], "whitelist": [{"ip": ...}]}.
// Suppression rules are merged by ID and whitelist IPs by canonical form;
// entries already present -- in the current state or earlier in the same
// bundle -- are skipped. Whitelist entries that fail parseAndValidateIP are
// dropped. The response and audit record report how many items were added.
func (s *Server) apiImport(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var bundle struct {
		Suppressions []state.SuppressionRule `json:"suppressions"`
		Whitelist    []struct {
			IP string `json:"ip"`
		} `json:"whitelist"`
	}
	if err := decodeJSONBodyLimited(w, r, 512*1024, &bundle); err != nil {
		writeJSONError(w, "invalid JSON body", http.StatusBadRequest)
		return
	}
	imported := 0
	// Merge suppressions (dedup by ID)
	if len(bundle.Suppressions) > 0 {
		existing := s.store.LoadSuppressions()
		existingIDs := make(map[string]bool, len(existing)+len(bundle.Suppressions))
		for _, rule := range existing {
			existingIDs[rule.ID] = true
		}
		for _, rule := range bundle.Suppressions {
			if existingIDs[rule.ID] {
				continue
			}
			existing = append(existing, rule)
			// Record the ID so a bundle that repeats it does not import the
			// same rule twice (mirrors the whitelist merge below).
			existingIDs[rule.ID] = true
			imported++
		}
		if err := s.store.SaveSuppressions(existing); err != nil {
			writeJSONError(w, fmt.Sprintf("failed to save suppressions: %v", err), http.StatusInternalServerError)
			return
		}
	}
	// Merge whitelist IPs
	if len(bundle.Whitelist) > 0 {
		if tdb := checks.GetThreatDB(); tdb != nil {
			existingSet := make(map[string]bool)
			for _, item := range tdb.WhitelistedIPs() {
				existingSet[item.IP] = true
			}
			for _, entry := range bundle.Whitelist {
				// Validate imported IPs like every interactive route does: an
				// unvalidated bundle could otherwise poison the threat DB /
				// firewall allow-list with malformed or attacker-chosen entries
				// (whitelisting bypasses blocking). Use the canonical form.
				ip, err := parseAndValidateIP(entry.IP)
				if err != nil {
					continue
				}
				canonical := ip.String()
				if existingSet[canonical] {
					continue
				}
				tdb.AddWhitelist(canonical)
				existingSet[canonical] = true
				imported++
			}
		}
	}
	s.auditLog(r, "import", "state", fmt.Sprintf("imported %d items", imported))
	writeJSON(w, map[string]interface{}{
		"status":   "imported",
		"imported": imported,
		"summary":  fmt.Sprintf("%d items imported", imported),
	})
}
// apiFindingDetail returns detail about a specific finding including related actions.
//
// Query parameters: "check" (required) and "message". The response holds the
// first/last-seen timestamps of the state entry keyed "check:message" (empty
// strings when there is none), up to 20 matching audit entries, and up to 50
// history findings with the same check taken from the 2000 entries returned
// by ReadHistory.
func (s *Server) apiFindingDetail(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	check, message := query.Get("check"), query.Get("message")
	if check == "" {
		writeJSONError(w, "check is required", http.StatusBadRequest)
		return
	}
	// Get state entry for this finding (first/last seen)
	var firstSeen, lastSeen string
	if entry, found := s.store.EntryForKey(check + ":" + message); found {
		firstSeen = entry.FirstSeen.Format(time.RFC3339)
		lastSeen = entry.LastSeen.Format(time.RFC3339)
	}
	// Search audit log for related actions
	actions := s.searchAuditEntries(check, 20)
	// Search history for related findings (same check type, at most 50)
	type histEntry struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Timestamp string `json:"timestamp"`
	}
	const maxRelated = 50
	var related []histEntry
	history, _ := s.store.ReadHistory(2000, 0)
	for _, f := range history {
		if len(related) >= maxRelated {
			break
		}
		if f.Check != check {
			continue
		}
		related = append(related, histEntry{
			Severity:  int(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Timestamp: f.Timestamp.Format(time.RFC3339),
		})
	}
	writeJSON(w, map[string]interface{}{
		"check":      check,
		"message":    message,
		"first_seen": firstSeen,
		"last_seen":  lastSeen,
		"actions":    actions,
		"related":    related,
	})
}
// extractAccountFromFinding extracts a cPanel account name from a finding
// by checking the message, details, and file path for /home/{user}/ patterns
// or "Account: " / "user: " in the details field (used by login checks).
//
// For the /home/ form only the first occurrence in each field is examined
// and the user segment must be non-empty and followed by a slash. For the
// prefixed form the name runs up to the next space, newline, tab or comma,
// or to the end of the details. It returns "" when nothing matches.
func extractAccountFromFinding(f alert.Finding) string {
	for _, s := range []string{f.Message, f.Details, f.FilePath} {
		if idx := strings.Index(s, "/home/"); idx >= 0 {
			rest := s[idx+6:]
			if end := strings.IndexByte(rest, '/'); end > 0 {
				return rest[:end]
			}
		}
	}
	for _, prefix := range []string{"Account: ", "user: "} {
		idx := strings.Index(f.Details, prefix)
		if idx < 0 {
			continue
		}
		rest := f.Details[idx+len(prefix):]
		end := strings.IndexAny(rest, " \n\t,")
		switch {
		case end > 0:
			return rest[:end]
		case end < 0 && rest != "":
			// No delimiter: the name runs to the end of the details.
			return rest
		}
		// end == 0: a delimiter directly follows the prefix, so there is no
		// name here. Returning the remainder would hand back arbitrary
		// trailing text as an account name; try the next prefix instead.
	}
	return ""
}
// writeJSONError sends {"error": message} as a JSON response with the given
// HTTP status code. Encoding errors are ignored because the status line has
// already been written by then.
func writeJSONError(w http.ResponseWriter, message string, code int) {
	payload := struct {
		Error string `json:"error"`
	}{Error: message}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	_ = json.NewEncoder(w).Encode(payload)
}
// writeJSON encodes data as indented JSON to w with a JSON content type.
// No status code is set explicitly, so the response uses whatever status
// the ResponseWriter defaults to. Encoding errors are ignored.
func writeJSON(w http.ResponseWriter, data interface{}) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	encoder.SetIndent("", " ")
	_ = encoder.Encode(data)
}
// queryInt reads the query parameter key from r as a non-negative integer.
// It returns defaultVal when the parameter is missing or empty, is not a
// valid integer, or is negative.
func queryInt(r *http.Request, key string, defaultVal int) int {
	raw := r.URL.Query().Get(key)
	if raw == "" {
		return defaultVal
	}
	if parsed, err := strconv.Atoi(raw); err == nil && parsed >= 0 {
		return parsed
	}
	return defaultVal
}
package webui
import (
"encoding/json"
"fmt"
"net/http"
"time"
)
// sseWriteTimeout caps how long each SSE write is allowed to block on a
// slow or stuck client. It must stay below the daemon's WebUI shutdown
// budget so an in-flight flush cannot outlive graceful shutdown.
// apiEvents re-arms the response write deadline with this value before
// every frame it writes.
const sseWriteTimeout = 3 * time.Second
// apiEvents streams findings to the client over Server-Sent Events. A
// subscriber connects once and receives a `data: {...}\n\n` block per
// finding plus a periodic `: keepalive\n\n` comment line every 25s so
// intermediate proxies don't time the connection out. Auth is checked
// by the upstream requireRead middleware.
//
// The handler answers 503 when no finding bus is configured or no
// subscriber slot is free, and 500 when the ResponseWriter cannot flush or
// cannot set write deadlines. The stream ends when the request context is
// cancelled, s.pruneDone fires, the subscription channel is closed, or a
// write fails.
func (s *Server) apiEvents(w http.ResponseWriter, r *http.Request) {
// Snapshot the bus under the read lock; the rest of the handler works on
// the local copy without holding s.mu.
s.mu.RLock()
bus := s.findingBus
s.mu.RUnlock()
if bus == nil {
http.Error(w, "event bus not available", http.StatusServiceUnavailable)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming unsupported", http.StatusInternalServerError)
return
}
// Reserve a subscriber slot before sending any stream headers so a flood of
// connections cannot exhaust daemon memory. Done up front so the cap can be
// reported as a clean 503 rather than mid-stream.
sub, ok := bus.TrySubscribe()
if !ok {
http.Error(w, "too many event stream subscribers", http.StatusServiceUnavailable)
return
}
defer bus.Unsubscribe(sub)
rc := http.NewResponseController(w)
// setWriteDeadline pushes the write deadline sseWriteTimeout into the future.
setWriteDeadline := func() error {
return rc.SetWriteDeadline(time.Now().Add(sseWriteTimeout))
}
// Probe deadline support before any stream header is set, so an
// unsupported writer still gets a plain 500 response.
if err := setWriteDeadline(); err != nil {
http.Error(w, "streaming write deadlines unsupported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no") // nginx won't buffer
// writeFrame re-arms the write deadline, writes one formatted frame and
// flushes it. It returns the first error from the deadline or the write.
writeFrame := func(format string, args ...any) error {
if err := setWriteDeadline(); err != nil {
return err
}
if _, err := fmt.Fprintf(w, format, args...); err != nil {
return err
}
flusher.Flush()
return nil
}
// Initial flush establishes the connection and proxies see the headers
// before the first event. Bound it by the same write deadline.
if err := writeFrame(""); err != nil {
return
}
keepalive := time.NewTicker(25 * time.Second)
defer keepalive.Stop()
// NOTE(review): s.pruneDone is read without s.mu here; presumably it is
// set once before serving starts -- confirm.
shutdownDone := s.pruneDone
for {
select {
case <-r.Context().Done():
return
case <-shutdownDone:
return
case <-keepalive.C:
if err := writeFrame(": keepalive\n\n"); err != nil {
return
}
case f, ok := <-sub:
if !ok {
// Subscription channel closed: end the stream.
return
}
body, err := json.Marshal(f)
if err != nil {
// Skip a finding that cannot be encoded rather than ending the stream.
continue
}
if err := writeFrame("data: %s\n\n", body); err != nil {
return
}
}
}
}
package webui
import (
"bufio"
"encoding/json"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
// uiAuditFile is the audit log's file name; it is joined onto the
// configured state path.
uiAuditFile = "ui_audit.jsonl"
// maxUIAuditSize is the size above which auditLog renames the current log
// to "<name>.1" before appending.
maxUIAuditSize = 10 * 1024 * 1024 // 10 MB
)
// UIAuditEntry records a UI action for compliance and accountability.
// Entries are stored one JSON object per line in the UI audit log.
type UIAuditEntry struct {
// Timestamp is when the action was recorded.
Timestamp time.Time `json:"timestamp"`
Action string `json:"action"` // block, unblock, dismiss, fix, whitelist, etc.
Target string `json:"target"` // IP, finding key, file path
Details string `json:"details,omitempty"` // extra context
SourceIP string `json:"source_ip,omitempty"` // admin's IP
}
// auditLog records a UI action to the audit log.
//
// The entry is appended as one JSON line to uiAuditFile under the
// configured state path. When the current log exceeds maxUIAuditSize it is
// first renamed to "<path>.1". Recording is best-effort: the caller is never
// failed, but every failure (encode, rotate, open, write) is logged so a
// broken audit trail does not go unnoticed.
func (s *Server) auditLog(r *http.Request, action, target, details string) {
	entry := UIAuditEntry{
		Timestamp: time.Now(),
		Action:    action,
		Target:    target,
		Details:   details,
		SourceIP:  extractClientIP(r),
	}
	path := filepath.Join(s.cfg.StatePath, uiAuditFile)
	data, err := json.Marshal(entry)
	if err != nil {
		log.Printf("webui: failed to encode audit entry: %v", err)
		return
	}
	data = append(data, '\n')
	// Rotate if too large
	if info, statErr := os.Stat(path); statErr == nil && info.Size() > maxUIAuditSize {
		if renameErr := os.Rename(path, path+".1"); renameErr != nil {
			// Keep appending to the oversized file rather than losing the entry.
			log.Printf("webui: failed to rotate audit log: %v", renameErr)
		}
	}
	// #nosec G304 -- filepath.Join under operator-configured StatePath.
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		log.Printf("webui: failed to open audit log: %v", err)
		return
	}
	defer func() { _ = f.Close() }()
	if _, err := f.Write(data); err != nil {
		log.Printf("webui: failed to write audit log: %v", err)
	}
}
// extractClientIP returns the client address recorded in audit entries,
// derived from r.RemoteAddr via clientIPKey. Forwarding headers are
// deliberately not consulted.
func extractClientIP(r *http.Request) string {
// Use RemoteAddr directly - XFF is trivially spoofable and this is
// a security audit log, so we only trust the TCP connection source.
return clientIPKey(r.RemoteAddr)
}
// readUIAuditLog returns the last N audit entries, newest first.
//
// It reads uiAuditFile under statePath line by line; lines that are not
// valid JSON are skipped. A limit <= 0 returns every entry. It returns nil
// when the file cannot be opened or holds no decodable entries. If scanning
// stops early (for example on a line longer than the 256 KiB buffer) the
// entries read so far are returned and the error is logged.
func readUIAuditLog(statePath string, limit int) []UIAuditEntry {
	path := filepath.Join(statePath, uiAuditFile)
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	var all []UIAuditEntry
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 256*1024), 256*1024)
	for scanner.Scan() {
		var entry UIAuditEntry
		if json.Unmarshal(scanner.Bytes(), &entry) == nil {
			all = append(all, entry)
		}
	}
	// Scan returns false on both EOF and failure; without this check an
	// oversized or unreadable line would silently hide every later entry.
	if err := scanner.Err(); err != nil {
		log.Printf("webui: audit log read stopped early: %v", err)
	}
	// Keep only the newest entries (the tail of the file) before reversing,
	// so the reversal touches at most limit elements.
	if limit > 0 && len(all) > limit {
		all = all[len(all)-limit:]
	}
	// Return newest first
	for i, j := 0, len(all)-1; i < j; i, j = i+1, j-1 {
		all[i], all[j] = all[j], all[i]
	}
	return all
}
// searchAuditEntries returns up to limit audit entries, newest first, whose
// target, details or action contain the search string (case-insensitive).
// Only the most recent limit*10 entries (capped at 5000) are examined.
// It returns nil for an empty search string or a non-positive limit.
func (s *Server) searchAuditEntries(search string, limit int) []UIAuditEntry {
	if search == "" || limit <= 0 {
		return nil
	}
	window := limit * 10
	if window > 5000 {
		window = 5000
	}
	needle := strings.ToLower(search)
	contains := func(field string) bool {
		return strings.Contains(strings.ToLower(field), needle)
	}
	var matched []UIAuditEntry
	for _, e := range readUIAuditLog(s.cfg.StatePath, window) {
		if !contains(e.Target) && !contains(e.Details) && !contains(e.Action) {
			continue
		}
		matched = append(matched, e)
		if len(matched) >= limit {
			break
		}
	}
	return matched
}
// handleAudit renders the audit.html template, passing the configured
// hostname as its only template value.
func (s *Server) handleAudit(w http.ResponseWriter, _ *http.Request) {
s.renderTemplate(w, "audit.html", map[string]string{
"Hostname": s.cfg.Hostname,
})
}
// GET /api/v1/audit - return UI audit log
//
// apiUIAudit responds with the 200 most recent audit entries, newest first.
// An unreadable or empty log yields an empty JSON array rather than null.
func (s *Server) apiUIAudit(w http.ResponseWriter, r *http.Request) {
	if entries := readUIAuditLog(s.cfg.StatePath, 200); entries != nil {
		writeJSON(w, entries)
		return
	}
	writeJSON(w, []UIAuditEntry{})
}
package webui
import (
"net/http"
"sort"
"time"
"github.com/pidginhost/csm/internal/health"
)
// componentsProvider is the optional capability surface the daemon
// exposes for /api/v1/components. Tests and the API-only fallback path
// can omit it; the handler degrades to attached/unknown.
type componentsProvider interface {
// WatcherStatuses maps each watcher name to whether it is attached.
WatcherStatuses() map[string]bool
// WatcherChangedAt maps a watcher name to a timestamp that apiComponents
// reports as the time of its last state change.
WatcherChangedAt() map[string]time.Time
}
// componentsUpstreamProvider is the optional capability surface the
// daemon adds when it has per-watcher upstream probes wired. Returning
// nil / absent for a watcher means "no probe, do not flag deaf".
type componentsUpstreamProvider interface {
// WatcherUpstream maps a watcher name to its latest upstream probe result.
WatcherUpstream() map[string]health.UpstreamResult
}
// componentRow is the JSON shape returned per watcher.
// The optional fields are omitted from the JSON when empty.
type componentRow struct {
// Name is the short watcher key; Label is its operator-facing label.
Name string `json:"name"`
Label string `json:"label"`
Status string `json:"status"` // "ok" | "degraded" | "deaf" | "idle" | "unknown"
Attached bool `json:"attached"`
// ChangedAtISO / ChangedAgo describe the watcher's last state change
// (RFC 3339 timestamp and relative text).
ChangedAtISO string `json:"changed_at_iso,omitempty"`
ChangedAgo string `json:"changed_ago,omitempty"`
// LastEvent* describe the most recent finding attributed to the watcher.
LastEventISO string `json:"last_event_iso,omitempty"`
LastEventAgo string `json:"last_event_ago,omitempty"`
LastEventCheck string `json:"last_event_check,omitempty"`
// Upstream* are only set when an upstream probe result exists for the
// watcher; UpstreamFresh stays nil otherwise.
UpstreamFresh *bool `json:"upstream_fresh,omitempty"`
UpstreamReason string `json:"upstream_reason,omitempty"`
UpstreamSeenISO string `json:"upstream_seen_iso,omitempty"`
}
// componentLabels maps the short watcher name to the operator-facing label.
// Watchers not in the map render with their raw key (see componentLabel).
var componentLabels = map[string]string{
"fanotify": "Fanotify (filesystem)",
"audit": "Auditd",
"modsec": "ModSecurity audit",
"afalg": "AF_ALG kernel monitor",
"phprelay": "PHP relay watcher",
"maillog": "Mail log",
"email_av_spool": "Email AV spool",
"forwarder": "Forwarder watcher",
"pamlistener": "PAM listener",
"connection": "Connection tracker",
"exec": "Exec monitor",
"sensitive": "Sensitive file monitor",
"accesslog": "Access log",
"dovecot_log": "Dovecot log",
"exim_mainlog": "Exim mainlog",
"cpanel_access_log": "cPanel access log",
}
// componentCheckOrigin maps a finding Check name back to the watcher that
// emits it. Only the entries with a clear single-source origin are listed;
// finding names reused by periodic or retroactive scans intentionally have
// no entry so they do not advance a watcher's "last event" clock.
// Keys are finding Check names; values are watcher names as used by
// componentLabels.
var componentCheckOrigin = map[string]string{
"cgi_backdoor_realtime": "fanotify",
"cgi_suspicious_location_realtime": "fanotify",
"credential_log_realtime": "fanotify",
"email_auth_failure_realtime": "maillog",
"email_av_degraded": "email_av_spool",
"email_av_parse_error": "email_av_spool",
"email_av_quarantine_error": "email_av_spool",
"email_av_timeout": "email_av_spool",
"email_compromised_account": "maillog",
"email_credential_leak": "maillog",
"email_dkim_failure": "maillog",
"email_malware": "email_av_spool",
"email_php_relay_action_dry_run": "phprelay",
"email_php_relay_action_failed": "phprelay",
"email_php_relay_action_skipped": "phprelay",
"email_php_relay_abuse": "phprelay",
"email_php_relay_account_volume_capped": "phprelay",
"email_php_relay_cpanel_limit_unreadable": "phprelay",
"email_php_relay_disabled": "phprelay",
"email_php_relay_inotify_overflow": "phprelay",
"email_php_relay_inotify_overflow_recovered": "phprelay",
"email_php_relay_msgindex_persist_failed": "phprelay",
"email_php_relay_no_exim": "phprelay",
"email_php_relay_overflow_scan_truncated": "phprelay",
"email_php_relay_path2b_disabled": "phprelay",
"email_php_relay_policies_reload": "phprelay",
"email_php_relay_rate_limit_hit": "phprelay",
"email_php_relay_sweep_failed": "phprelay",
"email_php_relay_watcher_failed": "phprelay",
"email_defer_fail_governor": "maillog",
"email_rate_critical": "maillog",
"email_rate_warning": "maillog",
"email_spam_outbreak": "maillog",
"email_spf_rejection": "maillog",
"executable_in_config_realtime": "fanotify",
"executable_in_tmp_realtime": "fanotify",
"exim_frozen_realtime": "maillog",
"fanotify_overflow": "fanotify",
"htaccess_injection_realtime": "fanotify",
"mail_account_compromised": "maillog",
"mail_account_spray": "maillog",
"mail_bruteforce": "maillog",
"mail_log_source_unavailable": "maillog",
"mail_subnet_spray": "maillog",
"modsec_block_escalation": "modsec",
"modsec_block_realtime": "modsec",
"modsec_csm_block_escalation": "modsec",
"modsec_warning_realtime": "modsec",
"obfuscated_php_realtime": "fanotify",
"credential_stuffing": "pamlistener",
"pam_bruteforce": "pamlistener",
"pam_login": "pamlistener",
"phishing_kit_realtime": "fanotify",
"phishing_realtime": "fanotify",
"php_config_realtime": "fanotify",
"php_dropper_realtime": "fanotify",
"php_in_sensitive_dir_realtime": "fanotify",
"php_in_uploads_realtime": "fanotify",
"signature_match_realtime": "fanotify",
"smtp_account_spray": "maillog",
"smtp_bruteforce": "maillog",
"smtp_probe_abuse": "maillog",
"smtp_subnet_spray": "maillog",
"webshell_content_realtime": "fanotify",
"webshell_realtime": "fanotify",
"yara_match_realtime": "fanotify",
}
// apiComponents returns one row per registered watcher with its live
// state, time since last state change, and the most recent finding it
// emitted within a 7-day lookback. Drives the dashboard component
// matrix.
//
// When the provider does not implement componentsProvider the response is
// an empty array. Rows are ordered by status rank (degraded, deaf, idle,
// ok) and then by label.
func (s *Server) apiComponents(w http.ResponseWriter, _ *http.Request) {
	cp, ok := s.provider.(componentsProvider)
	if !ok {
		writeJSON(w, []componentRow{})
		return
	}
	attachedByName := cp.WatcherStatuses()
	changedAt := cp.WatcherChangedAt()
	recent := s.lastEventByWatcher(7 * 24 * time.Hour)
	var probes map[string]health.UpstreamResult
	if up, hasProbes := s.provider.(componentsUpstreamProvider); hasProbes {
		probes = up.WatcherUpstream()
	}
	rows := make([]componentRow, 0, len(attachedByName))
	for name, attached := range attachedByName {
		row := componentRow{
			Name:     name,
			Label:    componentLabel(name),
			Attached: attached,
		}
		// A missing map entry yields the zero time, which is skipped as well.
		if at := changedAt[name]; !at.IsZero() {
			row.ChangedAtISO = at.Format(time.RFC3339)
			row.ChangedAgo = timeAgo(at)
		}
		last := recent[name]
		if !last.at.IsZero() {
			row.LastEventISO = last.at.Format(time.RFC3339)
			row.LastEventAgo = timeAgo(last.at)
			row.LastEventCheck = last.check
		}
		if probe, probed := probes[name]; probed {
			fresh := probe.Fresh
			row.UpstreamFresh = &fresh
			row.UpstreamReason = probe.Reason
			if !probe.LastActivity.IsZero() {
				row.UpstreamSeenISO = probe.LastActivity.Format(time.RFC3339)
			}
		}
		row.Status = componentStatus(attached, last.at, row.UpstreamFresh)
		rows = append(rows, row)
	}
	sort.Slice(rows, func(a, b int) bool {
		if rows[a].Status == rows[b].Status {
			return rows[a].Label < rows[b].Label
		}
		return componentStatusRank(rows[a].Status) < componentStatusRank(rows[b].Status)
	})
	writeJSON(w, rows)
}
// watcherEvent is the most recent finding attributed to one watcher.
type watcherEvent struct {
at time.Time // timestamp of the finding
check string // Check name of the finding
}
// lastEventByWatcher walks history within the lookback window and returns
// the most recent finding per known watcher key. Findings whose Check is
// not in componentCheckOrigin are skipped so periodic-scan output does
// not get attributed to a real-time watcher. The result is an empty,
// non-nil map when the server has no store.
func (s *Server) lastEventByWatcher(window time.Duration) map[string]watcherEvent {
	latest := map[string]watcherEvent{}
	if s.store == nil {
		return latest
	}
	cutoff := time.Now().Add(-window)
	// note records the finding for its watcher unless an event at least as
	// recent is already recorded for that watcher.
	note := func(check string, at time.Time) {
		watcher, known := componentCheckOrigin[check]
		if !known {
			return
		}
		if prev, seen := latest[watcher]; seen && !prev.at.Before(at) {
			return
		}
		latest[watcher] = watcherEvent{at: at, check: check}
	}
	for _, f := range s.store.ReadHistorySince(cutoff) {
		note(f.Check, f.Timestamp)
	}
	// Also fold in the latest scan set so freshly-emitted findings appear
	// before they have rolled into history.
	for _, f := range s.store.LatestFindings() {
		if f.Timestamp.Before(cutoff) {
			continue
		}
		note(f.Check, f.Timestamp)
	}
	return latest
}
// componentLabel returns the operator-facing label registered for name in
// componentLabels, or name itself when no label is registered.
func componentLabel(name string) string {
	label, known := componentLabels[name]
	if !known {
		return name
	}
	return label
}
// componentStatus collapses the per-row state into a UI bucket. The checks
// are applied in this order:
//   - degraded: watcher detached (attempted setup, failed or fell off)
//   - deaf: attached but the upstream feeding it has gone silent
//     (probe registered, returned Fresh=false). Operator action needed
//     before this watcher will ever produce events again.
//   - idle: attached, no events in window, and either no probe is
//     wired or the probe still confirms the upstream is alive
//   - ok: attached AND has produced at least one event recently
//
// The "unknown" bucket is reserved and never returned by this function.
func componentStatus(attached bool, lastEvent time.Time, upstreamFresh *bool) string {
	switch {
	case !attached:
		return "degraded"
	case upstreamFresh != nil && !*upstreamFresh:
		return "deaf"
	case lastEvent.IsZero():
		return "idle"
	default:
		return "ok"
	}
}
// componentStatusRank returns the sort position of a status bucket, lowest
// first: degraded (0), deaf (1), idle (2), ok (3). Any other value ranks
// last (4).
func componentStatusRank(status string) int {
	ordered := [...]string{"degraded", "deaf", "idle", "ok"}
	for rank, name := range ordered {
		if status == name {
			return rank
		}
	}
	return len(ordered)
}
package webui
import (
"net/http"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/store"
)
// Cleanup-history handlers for the db_object_backups bbolt bucket.
// htaccess pre_clean backups already surface through the existing
// /api/v1/quarantine listing (their .meta sidecars now match the
// JSON QuarantineMeta shape). The bbolt-backed db_object_backups
// bucket needs its own list + restore endpoints because the data
// lives outside the filesystem-quarantine flow.
//
// All handlers are registered behind requireAuth in server.go;
// the restore handler additionally requires CSRF.
// dbObjectBackupEntry is the JSON shape returned to the cleanup-
// history UI. The Key field is opaque to the UI -- it round-trips
// to the restore endpoint as-is so the lookup is a single bbolt
// Get, not a multi-field reconstruction.
type dbObjectBackupEntry struct {
Key string `json:"key"`
Account string `json:"account"`
Schema string `json:"schema"`
Kind string `json:"kind"`
Name string `json:"name"`
DroppedAt string `json:"dropped_at"` // RFC 3339
DroppedBy string `json:"dropped_by"`
FindingID string `json:"finding_id,omitempty"`
BodyBytes int `json:"body_bytes"` // length of CreateSQL; surfaced for size hint
// RestoredAt is only populated, and Restored only true, when the stored
// record has a non-zero restore time.
RestoredAt string `json:"restored_at,omitempty"`
Restored bool `json:"restored"`
}
// dbObjectBackupPreviewBytes is the maximum number of bytes of CREATE SQL
// returned by the backup preview endpoint.
const dbObjectBackupPreviewBytes = 8 * 1024
// apiDBObjectBackups returns every record in the bucket, newest
// first by DroppedAt. The full CreateSQL is intentionally NOT
// returned in the listing -- those payloads can be large and the
// listing is meant for browse-and-pick. The preview endpoint returns
// one bounded CREATE SQL payload on demand.
//
// When the global store is unavailable the response is an empty array; a
// listing failure is reported as 500.
func (s *Server) apiDBObjectBackups(w http.ResponseWriter, _ *http.Request) {
	sdb := store.Global()
	if sdb == nil {
		writeJSON(w, []dbObjectBackupEntry{})
		return
	}
	records, keys, listErr := sdb.ListDBObjectBackupsAll()
	if listErr != nil {
		writeJSONError(w, "failed to list backups: "+listErr.Error(), http.StatusInternalServerError)
		return
	}
	const stampLayout = "2006-01-02T15:04:05Z"
	out := make([]dbObjectBackupEntry, 0, len(records))
	for i := range records {
		rec := records[i]
		restored := !rec.RestoredAt.IsZero()
		restoredAt := ""
		if restored {
			restoredAt = rec.RestoredAt.UTC().Format(stampLayout)
		}
		out = append(out, dbObjectBackupEntry{
			Key:        keys[i],
			Account:    rec.Account,
			Schema:     rec.Schema,
			Kind:       rec.Kind,
			Name:       rec.Name,
			DroppedAt:  rec.DroppedAt.UTC().Format(stampLayout),
			DroppedBy:  rec.DroppedBy,
			FindingID:  rec.FindingID,
			BodyBytes:  len(rec.CreateSQL),
			RestoredAt: restoredAt,
			Restored:   restored,
		})
	}
	// Newest first -- the bbolt key embeds unix-nanos so a string
	// sort already produces chronological-by-drop-time when
	// reversed. Doing it in Go keeps the contract explicit.
	sortDBObjectBackupsNewestFirst(out)
	writeJSON(w, out)
}
// apiDBObjectBackupPreview returns a bounded CREATE SQL preview for one
// backup. Full restore still round-trips only the opaque key.
//
// GET only; the backup is selected by the "key" query parameter. Responses:
// 400 when the key is missing, 503 when the store is unavailable, 500 when
// the lookup fails, 404 when no such backup exists. On success the preview
// holds at most dbObjectBackupPreviewBytes bytes of the CREATE SQL, together
// with a truncated flag and the total size.
func (s *Server) apiDBObjectBackupPreview(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	key := r.URL.Query().Get("key")
	if key == "" {
		writeJSONError(w, "key is required", http.StatusBadRequest)
		return
	}
	sdb := store.Global()
	if sdb == nil {
		writeJSONError(w, "bbolt store not available", http.StatusServiceUnavailable)
		return
	}
	rec, ok, err := sdb.GetDBObjectBackupByKey(key)
	if err != nil {
		writeJSONError(w, "failed to read backup: "+err.Error(), http.StatusInternalServerError)
		return
	}
	if !ok {
		writeJSONError(w, "backup not found", http.StatusNotFound)
		return
	}
	preview := rec.CreateSQL
	truncated := false
	if len(preview) > dbObjectBackupPreviewBytes {
		// Cutting at a fixed byte offset can split a multi-byte UTF-8
		// sequence, which the JSON encoder would replace with U+FFFD. Back up
		// over continuation bytes (10xxxxxx) so the cut lands on a rune
		// boundary; a UTF-8 sequence has at most three of them.
		cut := dbObjectBackupPreviewBytes
		for back := 0; back < 3 && cut > 0 && preview[cut]&0xC0 == 0x80; back++ {
			cut--
		}
		preview = preview[:cut]
		truncated = true
	}
	writeJSON(w, map[string]any{
		"key":        key,
		"account":    rec.Account,
		"schema":     rec.Schema,
		"kind":       rec.Kind,
		"name":       rec.Name,
		"preview":    preview,
		"truncated":  truncated,
		"total_size": len(rec.CreateSQL),
	})
}
// apiDBObjectBackupRestore re-executes the captured CREATE SQL.
// POST body: {"key": "<bbolt key>"}. The handler delegates to
// checks.RestoreDBObjectBackup; CSRF is enforced upstream in
// server.go's requireCSRF wrapper.
//
// A missing key or undecodable body, and any unsuccessful restore, are
// reported as 400; the latter carries the restore result's message.
func (s *Server) apiDBObjectBackupRestore(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		Key string `json:"key"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 16*1024, &req)
	if decodeErr != nil || req.Key == "" {
		writeJSONError(w, "key is required", http.StatusBadRequest)
		return
	}
	result := checks.RestoreDBObjectBackup(req.Key)
	if result.Success {
		writeJSON(w, map[string]any{
			"success": true,
			"message": result.Message,
			"details": result.Details,
		})
		return
	}
	writeJSONError(w, result.Message, http.StatusBadRequest)
}
// sortDBObjectBackupsNewestFirst sorts in place by DroppedAt
// descending. Local helper rather than relying on sort.Slice so
// the comparator is unambiguous in code review. The DroppedAt strings are
// compared lexically, and entries with equal DroppedAt keep their relative
// order.
func sortDBObjectBackupsNewestFirst(entries []dbObjectBackupEntry) {
	for next := 1; next < len(entries); next++ {
		moving := entries[next]
		slot := next
		// Shift strictly older entries one place right to open a slot.
		for ; slot > 0 && moving.DroppedAt > entries[slot-1].DroppedAt; slot-- {
			entries[slot] = entries[slot-1]
		}
		entries[slot] = moving
	}
}
package webui
import (
"net/http"
"github.com/pidginhost/csm/internal/mailfwd/intel"
"github.com/pidginhost/csm/internal/platform"
)
// selectDeferralReporter picks the deferral-intel source for the host. Only
// cPanel/exim is wired; other platforms get the empty reporter until their
// adapters land (Phase 3).
func selectDeferralReporter() intel.Reporter {
	if !platform.Detect().IsCPanel() {
		return intel.EmptyReporter{}
	}
	return intel.NewEximSource()
}
// apiEmailDeferrals handles GET /api/v1/email/deferrals and returns the
// outbound-deferral picture parsed from exim_mainlog: per-provider deferral
// rollup and per-outbound-IP reputation with stated reason codes.
//
// Without a configured reporter the response is a report with empty
// provider and outbound-IP lists; a reporter failure is reported as 500.
func (s *Server) apiEmailDeferrals(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if s.deferralReporter == nil {
		empty := intel.Report{
			Providers:   []intel.ProviderRollup{},
			OutboundIPs: []intel.OutboundIPRollup{},
		}
		writeJSON(w, empty)
		return
	}
	report, reportErr := s.deferralReporter.Report()
	if reportErr == nil {
		writeJSON(w, report)
		return
	}
	writeJSONError(w, "Failed to read deferral log", http.StatusInternalServerError)
}
package webui
import (
"context"
"io"
"net/http"
"os"
"os/exec"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/yara"
)
// emailStatsResponse is the JSON body written by apiEmailStats. The slice
// fields are always populated with a non-nil value by that handler so they
// encode as [] rather than null.
type emailStatsResponse struct {
QueueSize int `json:"queue_size"` // live queue count from eximQueueSize
QueueWarn int `json:"queue_warn"` // cfg.Thresholds.MailQueueWarn
QueueCrit int `json:"queue_crit"` // cfg.Thresholds.MailQueueCrit
FrozenCount int `json:"frozen_count"` // frozen messages counted by eximQueueDetails
OldestAge string `json:"oldest_age"` // age token of the oldest queued message, "" when unknown
SMTPBlock bool `json:"smtp_block"` // cfg.Firewall.SMTPBlock
SMTPAllowUsers []string `json:"smtp_allow_users"` // cfg.Firewall.SMTPAllowUsers
SMTPPorts []int `json:"smtp_ports"` // cfg.Firewall.SMTPPorts
PortFlood []portFloodEntry `json:"port_flood"` // port-flood rules for ports 25, 465 and 587 only
TopSenders []senderEntry `json:"top_senders"` // result of topMailSenders
}
// portFloodEntry is the JSON view of one firewall port-flood rule as
// reported by apiEmailStats; each field is copied from the matching field
// of the configured rule.
type portFloodEntry struct {
Port int `json:"port"`
Proto string `json:"proto"`
Hits int `json:"hits"`
Seconds int `json:"seconds"`
}
// senderEntry pairs a sender domain with the number of log lines
// topMailSenders attributed to it.
type senderEntry struct {
Domain string `json:"domain"`
Count int `json:"count"`
}
// apiEmailStats writes the mail dashboard snapshot: live exim queue figures,
// the configured queue thresholds, the SMTP-related firewall settings, the
// port-flood rules that target SMTP ports (25, 465, 587) and the top sender
// domains from exim_mainlog. Every list in the response is non-nil so it
// encodes as [] instead of null.
func (s *Server) apiEmailStats(w http.ResponseWriter, _ *http.Request) {
	fw := s.cfg.Firewall

	out := emailStatsResponse{
		QueueSize:      eximQueueSize(),
		QueueWarn:      s.cfg.Thresholds.MailQueueWarn,
		QueueCrit:      s.cfg.Thresholds.MailQueueCrit,
		SMTPBlock:      fw.SMTPBlock,
		SMTPAllowUsers: fw.SMTPAllowUsers,
		SMTPPorts:      fw.SMTPPorts,
		PortFlood:      []portFloodEntry{},
	}
	out.FrozenCount, out.OldestAge = eximQueueDetails()

	if out.SMTPAllowUsers == nil {
		out.SMTPAllowUsers = []string{}
	}
	if out.SMTPPorts == nil {
		out.SMTPPorts = []int{}
	}

	// Only port-flood rules on the SMTP ports are relevant to this view.
	for _, rule := range fw.PortFlood {
		switch rule.Port {
		case 25, 465, 587:
			out.PortFlood = append(out.PortFlood, portFloodEntry{
				Port:    rule.Port,
				Proto:   rule.Proto,
				Hits:    rule.Hits,
				Seconds: rule.Seconds,
			})
		}
	}

	// Top 10 sender domains over the last 500 lines of exim_mainlog.
	if senders := topMailSenders(500, 10); senders != nil {
		out.TopSenders = senders
	} else {
		out.TopSenders = []senderEntry{}
	}

	writeJSON(w, out)
}
// apiEmailQuarantineList serves GET /api/v1/email/quarantine. It returns
// every quarantined message, or an empty JSON array when the quarantine is
// not configured or currently holds nothing.
func (s *Server) apiEmailQuarantineList(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if s.emailQuarantine == nil {
		writeJSON(w, []emailav.QuarantineMetadata{})
		return
	}

	listed, err := s.emailQuarantine.ListMessages()
	switch {
	case err != nil:
		writeJSONError(w, "Failed to list quarantine", http.StatusInternalServerError)
	case listed == nil:
		// Keep the wire format an array even when nothing is quarantined.
		writeJSON(w, []emailav.QuarantineMetadata{})
	default:
		writeJSON(w, listed)
	}
}
// apiEmailQuarantineAction serves /api/v1/email/quarantine/{msgID}[/release]:
//
//	GET    {msgID}          -> metadata for one quarantined message
//	POST   {msgID}/release  -> release the message
//	DELETE {msgID}          -> delete the message
//
// The message ID is validated before anything else is attempted, and a 503
// is returned when the quarantine is not configured.
func (s *Server) apiEmailQuarantineAction(w http.ResponseWriter, r *http.Request) {
	// Everything after the route prefix, e.g. "abc123" or "abc123/release".
	rest := strings.TrimPrefix(r.URL.Path, "/api/v1/email/quarantine/")
	if rest == "" {
		writeJSONError(w, "Missing message ID", http.StatusBadRequest)
		return
	}
	msgID, action, _ := strings.Cut(rest, "/")

	if err := validateEximMessageID(msgID); err != nil {
		writeJSONError(w, "Invalid message ID: "+err.Error(), http.StatusBadRequest)
		return
	}
	if s.emailQuarantine == nil {
		writeJSONError(w, "Email quarantine not configured", http.StatusServiceUnavailable)
		return
	}

	switch r.Method {
	case http.MethodGet:
		if action != "" {
			writeJSONError(w, "Unknown action", http.StatusBadRequest)
			return
		}
		meta, err := s.emailQuarantine.GetMessage(msgID)
		if err != nil {
			writeJSONError(w, "Message not found", http.StatusNotFound)
			return
		}
		writeJSON(w, meta)

	case http.MethodPost:
		// The only POST action is "release".
		if action != "release" {
			writeJSONError(w, "Unknown action; use /release", http.StatusBadRequest)
			return
		}
		if err := s.emailQuarantine.ReleaseMessage(msgID); err != nil {
			writeJSONError(w, "Failed to release message: "+err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "released", "message_id": msgID})

	case http.MethodDelete:
		if action != "" {
			writeJSONError(w, "Unknown action", http.StatusBadRequest)
			return
		}
		if err := s.emailQuarantine.DeleteMessage(msgID); err != nil {
			writeJSONError(w, "Failed to delete message: "+err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "deleted", "message_id": msgID})

	default:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// emailAVStatusResponse is the JSON body written by apiEmailAVStatus.
type emailAVStatusResponse struct {
Enabled bool `json:"enabled"` // cfg.EmailAV.Enabled
ClamdAvailable bool `json:"clamd_available"` // result of probing ClamdSocket
ClamdSocket string `json:"clamd_socket"` // configured socket path, or the built-in default when unset
YaraXAvailable bool `json:"yarax_available"` // yara.Available()
YaraXRuleCount int `json:"yarax_rule_count"` // rule count of the active YARA backend, 0 when none is active
WatcherMode string `json:"watcher_mode"` // "disabled" when the server has no watcher mode recorded
Quarantined int `json:"quarantined"` // number of messages currently listed by the quarantine
}
// apiEmailAVStatus serves GET /api/v1/email/av/status and reports the state
// of the email AV subsystem: ClamAV reachability, YARA-X availability and
// rule count, the watcher mode, and how many messages are quarantined.
func (s *Server) apiEmailAVStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Fall back to the default clamd socket path when none is configured.
	socket := s.cfg.EmailAV.ClamdSocket
	if socket == "" {
		socket = "/var/run/clamd.scan/clamd.sock"
	}

	// The watcher mode is set by the daemon on startup; empty means disabled.
	mode := s.emailAVWatcherMode
	if mode == "" {
		mode = "disabled"
	}

	status := emailAVStatusResponse{
		Enabled:        s.cfg.EmailAV.Enabled,
		ClamdSocket:    socket,
		ClamdAvailable: emailav.NewClamdScanner(socket).Available(),
		YaraXAvailable: yara.Available(),
		WatcherMode:    mode,
	}

	// Active() covers both the in-process scanner and the out-of-process
	// worker, so the rule count is real under either backend.
	if backend := yara.Active(); backend != nil {
		status.YaraXRuleCount = backend.RuleCount()
	}

	// A listing failure leaves the quarantined count at zero.
	if s.emailQuarantine != nil {
		if held, err := s.emailQuarantine.ListMessages(); err == nil {
			status.Quarantined = len(held)
		}
	}

	writeJSON(w, status)
}
// eximQueueSize runs `exim -bpc` with a 5 second timeout and returns the
// number it prints. Any failure to run the command yields 0.
func eximQueueSize() int {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	raw, err := exec.CommandContext(ctx, "exim", "-bpc").Output()
	if err != nil {
		return 0
	}

	// The conversion error is deliberately dropped: whatever Atoi yields for
	// unparseable output is returned as-is.
	count, _ := strconv.Atoi(strings.TrimSpace(string(raw)))
	return count
}
// eximQueueDetails runs `exim -bp` (10 second timeout) and scans its
// listing. frozen is the number of lines containing "*** frozen ***";
// oldestAge is the first age-looking token found, which assumes exim lists
// the oldest message first. On command failure it returns (0, "").
func eximQueueDetails() (frozen int, oldestAge string) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	listing, err := exec.CommandContext(ctx, "exim", "-bp").Output()
	if err != nil {
		return 0, ""
	}

	for _, row := range strings.Split(string(listing), "\n") {
		if strings.Contains(row, "*** frozen ***") {
			frozen++
		}
		if oldestAge != "" {
			// Age already captured; keep scanning only to count frozen rows.
			continue
		}

		// Message header rows have at least three columns and start with an
		// age such as "4d", "15h" or "30m".
		cols := strings.Fields(row)
		if len(cols) < 3 {
			continue
		}
		candidate := cols[0]
		if len(candidate) < 2 {
			continue
		}
		switch candidate[len(candidate)-1] {
		case 'd', 'h', 'm', 's':
			oldestAge = candidate
		}
	}
	return frozen, oldestAge
}
// topMailSenders reads the tail of /var/log/exim_mainlog (at most the final
// 256 KiB), keeps the last tailLines lines, and counts message-arrival lines
// (those containing " <= ") per sender domain. It returns up to topK domains
// ordered by descending count; equal counts are ordered alphabetically so
// the result is stable from one call to the next.
//
// It returns nil when the log cannot be opened, when either argument is not
// positive, or when no sender domain was found. Bounce senders ("<>"),
// senders without an "@", and senders starting with "cPanel" are ignored.
func topMailSenders(tailLines, topK int) []senderEntry {
	if tailLines <= 0 || topK <= 0 {
		// Nothing to report, and negative values would panic when slicing.
		return nil
	}

	f, err := os.Open("/var/log/exim_mainlog")
	if err != nil {
		return nil
	}
	defer f.Close()

	// Only the end of the log matters; skip ahead on large files. If Stat
	// fails the whole file is read, as a best effort.
	const tailWindow = 256 * 1024
	if info, statErr := f.Stat(); statErr == nil && info.Size() > tailWindow {
		if _, err := f.Seek(-tailWindow, io.SeekEnd); err != nil {
			return nil
		}
	}
	// Best effort: whatever was read before an error is still parsed.
	data, _ := io.ReadAll(f)

	// Drop the trailing newline first so the empty string after it does not
	// use up one of the tailLines slots.
	lines := strings.Split(strings.TrimRight(string(data), "\n"), "\n")
	if len(lines) > tailLines {
		lines = lines[len(lines)-tailLines:]
	}

	counts := make(map[string]int)
	for _, line := range lines {
		_, rest, found := strings.Cut(line, " <= ")
		if !found {
			continue
		}
		fields := strings.Fields(rest)
		if len(fields) == 0 {
			continue
		}
		sender := fields[0]
		at := strings.LastIndex(sender, "@")
		if at < 0 {
			continue
		}
		domain := sender[at+1:]
		if domain == "" || sender == "<>" || strings.HasPrefix(sender, "cPanel") {
			continue
		}
		counts[domain]++
	}
	if len(counts) == 0 {
		return nil
	}

	entries := make([]senderEntry, 0, len(counts))
	for domain, count := range counts {
		entries = append(entries, senderEntry{Domain: domain, Count: count})
	}
	// Map iteration order is random, so ties need an explicit tiebreaker to
	// keep the top-K list deterministic.
	sort.Slice(entries, func(i, j int) bool {
		if entries[i].Count != entries[j].Count {
			return entries[i].Count > entries[j].Count
		}
		return entries[i].Domain < entries[j].Domain
	})
	if len(entries) > topK {
		entries = entries[:topK]
	}
	return entries
}
package webui
import (
"net/http"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// emailGroupsScanCap is the hard upper bound on findings inspected per
// /api/v1/email/groups call; apiEmailGroups drops everything past this many
// findings and flags the response as truncated. Bounded reads keep the
// workbench cheap on hosts that store thousands of mail-related findings
// per day.
const emailGroupsScanCap = 5000
// emailGroupsDefaultLimit and emailGroupsMaxLimit bound the number of
// grouped rows returned to the operator UI. A requested limit that is not
// in the range 1..emailGroupsMaxLimit falls back to emailGroupsDefaultLimit.
// The plan caps the email first viewport at ~250 nodes so 200 is the highest
// useful ceiling.
const (
emailGroupsDefaultLimit = 50
emailGroupsMaxLimit = 200
)
// emailGroup is one grouped action row of the email workbench, produced by
// buildEmailGroups from all findings that share a group key.
type emailGroup struct {
Kind string `json:"kind"` // group kind from emailKindForCheck
Severity int `json:"severity"` // highest severity among the grouped findings
Title string `json:"title"` // emailGroupTitle of the first finding in the group
Subject string `json:"subject"` // emailGroupSubject of the first finding in the group
Count int `json:"count"` // number of findings merged into this row
FirstSeen string `json:"first_seen"` // earliest finding timestamp, RFC3339 in UTC
LastSeen string `json:"last_seen"` // latest finding timestamp, RFC3339 in UTC
Summary string `json:"summary"` // e.g. "3 events from 2 IPs"
SampleFindings []alert.Finding `json:"sample_findings"` // first three findings encountered for the group
IPs []string `json:"ips,omitempty"` // distinct source IPs, sorted
TopIPs []string `json:"top_ips,omitempty"` // up to five source IPs by descending hit count
Domains []string `json:"domains,omitempty"` // distinct lower-cased domains, sorted
MessageIDs []string `json:"message_ids,omitempty"` // distinct message IDs, sorted, at most ten
}
// emailGroupsResponse is the JSON body written by apiEmailGroups.
type emailGroupsResponse struct {
Groups []emailGroup `json:"groups"` // grouped rows, already cut to the requested limit
From string `json:"from"` // effective window start, RFC3339 in UTC
To string `json:"to"` // effective window end, RFC3339 in UTC
Scanned int `json:"scanned"` // findings handed to the grouper (at most emailGroupsScanCap)
Truncated bool `json:"truncated"` // true when findings beyond emailGroupsScanCap were dropped
}
// emailCheckKinds is the lookup table behind emailKindForCheck: alert check
// name -> email-workbench group kind.
var emailCheckKinds = map[string]string{
	// Compromised or abusable mailboxes.
	"email_compromised_account":  "compromised_account",
	"email_credential_leak":      "compromised_account",
	"email_weak_password":        "compromised_account",
	"mail_account_compromised":   "compromised_account",
	"email_pipe_forwarder":       "compromised_account",
	"email_suspicious_forwarder": "compromised_account",

	// Outbound spam and relay abuse.
	"email_spam_outbreak":            "spam_outbreak",
	"email_rate_critical":            "spam_outbreak",
	"email_rate_warning":             "spam_outbreak",
	"email_php_relay_abuse":          "spam_outbreak",
	"email_php_relay_action_failed":  "spam_outbreak",
	"email_php_relay_rate_limit_hit": "spam_outbreak",
	"email_cloud_relay_abuse":        "spam_outbreak",

	// Authentication failures and brute force.
	"email_auth_failure_realtime": "auth_failure",
	"email_suspicious_geo":        "auth_failure",
	"mail_bruteforce":             "auth_failure",
	"mail_subnet_spray":           "auth_failure",
	"mail_account_spray":          "auth_failure",
	"smtp_bruteforce":             "auth_failure",
	"smtp_subnet_spray":           "auth_failure",
	"smtp_account_spray":          "auth_failure",
	"smtp_probe_abuse":            "auth_failure",

	// Malicious content and AV pipeline health.
	"email_malware":             "malware",
	"email_phishing_content":    "malware",
	"email_av_degraded":         "malware",
	"email_av_timeout":          "malware",
	"email_av_parse_error":      "malware",
	"email_av_quarantine_error": "malware",

	// Queue pressure.
	"mail_per_account":          "queue_alert",
	"mail_queue":                "queue_alert",
	"email_defer_fail_governor": "queue_alert",
	"exim_frozen_realtime":      "queue_alert",
}

// emailKindForCheck translates an alert check name into its email-workbench
// group kind. It returns "" for checks that are not part of the email
// surface; /api/v1/email/groups skips such findings.
func emailKindForCheck(check string) string {
	// A missing key yields the zero value "", which is the "not email" answer.
	return emailCheckKinds[check]
}
// emailGroupKey builds the dedup key under which findings are merged into
// one grouped row. The identity field preferred depends on the kind:
//
//   - auth_failure: mailbox, then source IP, then domain
//   - queue_alert:  the check name
//   - anything else: mailbox, then domain, then source IP, then message text
//
// Mailbox and domain are trimmed and lower-cased before use.
func emailGroupKey(kind string, f alert.Finding) string {
	box := strings.ToLower(strings.TrimSpace(f.Mailbox))
	dom := strings.ToLower(strings.TrimSpace(f.Domain))

	if kind == "queue_alert" {
		return "queue:" + f.Check
	}

	if kind == "auth_failure" {
		switch {
		case box != "":
			return "mailbox:" + box
		case f.SourceIP != "":
			return "ip:" + f.SourceIP
		case dom != "":
			return "domain:" + dom
		}
		return "auth:unknown"
	}

	switch {
	case box != "":
		return kind + ":mailbox:" + box
	case dom != "":
		return kind + ":domain:" + dom
	case f.SourceIP != "":
		return kind + ":ip:" + f.SourceIP
	}
	// With no identity field at all, key on the message text so two distinct
	// payloads stay two groups instead of collapsing into one.
	return kind + ":msg:" + strings.TrimSpace(f.Message)
}
// emailGroupTitle picks the human-readable label for a grouped row. Three
// queue checks get fixed labels because their finding text varies by host;
// everything else uses the first non-empty value of mailbox, domain, source
// IP, and finally the trimmed message text.
func emailGroupTitle(kind string, f alert.Finding) string {
	isQueue := kind == "queue_alert"
	switch {
	case isQueue && f.Check == "mail_queue":
		return "Mail queue threshold"
	case isQueue && f.Check == "mail_per_account":
		return "Per-account mail volume"
	case isQueue && f.Check == "exim_frozen_realtime":
		return "Frozen mail queue"
	case f.Mailbox != "":
		return f.Mailbox
	case f.Domain != "":
		return f.Domain
	case f.SourceIP != "":
		return f.SourceIP
	}
	return strings.TrimSpace(f.Message)
}
// emailGroupSubject names the identity dimension behind a group: "queue"
// for queue alerts, otherwise "mailbox", "domain" or "ip" for the first of
// those fields that is set, and "unknown" when none is. The UI uses it to
// pick detail-panel tabs without re-reading the raw findings.
func emailGroupSubject(kind string, f alert.Finding) string {
	switch {
	case kind == "queue_alert":
		return "queue"
	case f.Mailbox != "":
		return "mailbox"
	case f.Domain != "":
		return "domain"
	case f.SourceIP != "":
		return "ip"
	default:
		return "unknown"
	}
}
// buildEmailGroups folds the supplied (already bounded) findings into
// grouped rows. Findings outside [from, to] (a zero bound is open), findings
// whose check is not an email check, and findings whose kind differs from a
// non-empty kindFilter are skipped. Rows are returned sorted by severity,
// then count, then last-seen, all descending; full ties keep the order in
// which the groups were first encountered.
//
// Pure function -- the HTTP handler is a thin wrapper so tests can drive
// grouping directly.
func buildEmailGroups(findings []alert.Finding, from, to time.Time, kindFilter string) []emailGroup {
	// bucket carries one row plus the sets needed to finish it.
	type bucket struct {
		row     emailGroup
		ipHits  map[string]int
		domains map[string]struct{}
		msgIDs  map[string]struct{}
		recent  []alert.Finding
	}

	buckets := make(map[string]*bucket)
	var arrival []string // group keys in first-seen order

	for _, f := range findings {
		when := f.Timestamp
		outsideWindow := (!from.IsZero() && when.Before(from)) || (!to.IsZero() && when.After(to))
		if outsideWindow {
			continue
		}
		kind := emailKindForCheck(f.Check)
		if kind == "" || (kindFilter != "" && kind != kindFilter) {
			continue
		}

		// Timestamps are compared in their formatted UTC form, which sorts
		// chronologically at one-second resolution.
		stamp := when.UTC().Format(time.RFC3339)
		sev := int(f.Severity)
		key := emailGroupKey(kind, f)

		b := buckets[key]
		if b == nil {
			b = &bucket{
				row: emailGroup{
					Kind:      kind,
					Severity:  sev,
					Title:     emailGroupTitle(kind, f),
					Subject:   emailGroupSubject(kind, f),
					FirstSeen: stamp,
					LastSeen:  stamp,
				},
				ipHits:  make(map[string]int),
				domains: make(map[string]struct{}),
				msgIDs:  make(map[string]struct{}),
			}
			buckets[key] = b
			arrival = append(arrival, key)
		}

		b.row.Count++
		if sev > b.row.Severity {
			b.row.Severity = sev
		}
		if stamp < b.row.FirstSeen {
			b.row.FirstSeen = stamp
		}
		if stamp > b.row.LastSeen {
			b.row.LastSeen = stamp
		}

		if ip := f.SourceIP; ip != "" {
			b.ipHits[ip]++
		}
		if f.Domain != "" {
			b.domains[strings.ToLower(f.Domain)] = struct{}{}
		}
		for _, id := range f.MsgIDs {
			if id == "" {
				continue
			}
			b.msgIDs[id] = struct{}{}
		}

		// The first three findings per group are kept as samples; they are
		// the most recent ones only if the input is newest-first.
		if len(b.recent) < 3 {
			b.recent = append(b.recent, f)
		}
	}

	rows := make([]emailGroup, 0, len(arrival))
	for _, key := range arrival {
		b := buckets[key]
		row := b.row

		// Summary: event count plus an IP hint (auth failures) or a domain
		// hint (groups that span more than one domain).
		suffix := ""
		switch {
		case row.Kind == "auth_failure" && len(b.ipHits) > 0:
			suffix = " from " + plural(len(b.ipHits), "IP")
		case len(b.domains) > 1:
			suffix = " across " + plural(len(b.domains), "domain")
		}
		row.Summary = plural(row.Count, "event") + suffix
		row.SampleFindings = b.recent

		if len(b.ipHits) > 0 {
			row.IPs = sortedKeys(b.ipHits)
			row.TopIPs = topKeysByCount(b.ipHits, 5)
		}
		if len(b.domains) > 0 {
			row.Domains = sortedSetKeys(b.domains)
		}
		if len(b.msgIDs) > 0 {
			ids := sortedSetKeys(b.msgIDs)
			if len(ids) > 10 {
				ids = ids[:10]
			}
			row.MessageIDs = ids
		}
		rows = append(rows, row)
	}

	sort.SliceStable(rows, func(i, j int) bool {
		a, b := rows[i], rows[j]
		switch {
		case a.Severity != b.Severity:
			return a.Severity > b.Severity
		case a.Count != b.Count:
			return a.Count > b.Count
		default:
			return a.LastSeen > b.LastSeen
		}
	})
	return rows
}
// plural renders "<n> <label>", appending "s" to the label for every count
// except exactly 1 (so 0 and negative counts are pluralised too).
func plural(n int, label string) string {
	if n != 1 {
		return itoa(n) + " " + label + "s"
	}
	return "1 " + label
}
// itoa formats n in base 10. It is a local helper so this file does not
// need strconv just for summary strings.
//
// The digits are produced from the unsigned magnitude of n. Negating the
// most negative int in signed arithmetic overflows back to itself, which
// would leave no digits to emit and yield just "-".
func itoa(n int) string {
	if n == 0 {
		return "0"
	}

	mag := uint64(n)
	if n < 0 {
		// Two's-complement negation in unsigned space is exact for every
		// int, including the minimum value.
		mag = -mag
	}

	var buf [20]byte // enough for "-9223372036854775808"
	i := len(buf)
	for mag > 0 {
		i--
		buf[i] = byte('0' + mag%10)
		mag /= 10
	}
	if n < 0 {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
// sortedKeys returns the keys of m in ascending order. The result is never
// nil, even for a nil or empty map.
func sortedKeys(m map[string]int) []string {
	out := make([]string, len(m))
	next := 0
	for name := range m {
		out[next] = name
		next++
	}
	sort.Strings(out)
	return out
}
// sortedSetKeys returns the members of the set m in ascending order. The
// result is never nil, even for a nil or empty set.
func sortedSetKeys(m map[string]struct{}) []string {
	out := make([]string, len(m))
	next := 0
	for member := range m {
		out[next] = member
		next++
	}
	sort.Strings(out)
	return out
}
// topKeysByCount returns up to k keys of m ordered by descending count, with
// ties broken alphabetically, so the UI shows the dominant attackers first.
// A k of zero or less yields an empty result. The returned slice is never
// nil.
func topKeysByCount(m map[string]int, k int) []string {
	if k < 0 {
		// A negative bound would panic when slicing; treat it as "none".
		k = 0
	}

	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	// Keys are unique, so this comparator is a total order and the outcome
	// does not depend on map iteration order.
	sort.Slice(keys, func(i, j int) bool {
		ci, cj := m[keys[i]], m[keys[j]]
		if ci != cj {
			return ci > cj
		}
		return keys[i] < keys[j]
	})

	if k < len(keys) {
		// Cap the capacity too, so a later append cannot reach the dropped tail.
		keys = keys[:k:k]
	}
	return keys
}
// parseEmailGroupDate parses a window bound given as RFC3339 or as a
// YYYY-MM-DD date in the local time zone. It returns def when s is empty
// (after trimming) or matches neither layout.
//
// For a date-only value, endOfDay selects the last nanosecond of that local
// calendar day instead of its midnight start, so a date-only upper bound
// includes the whole day. The end is computed with calendar arithmetic
// rather than by adding 24 hours, so days that are 23 or 25 hours long
// because of a DST transition are still covered exactly.
func parseEmailGroupDate(s string, def time.Time, endOfDay bool) time.Time {
	s = strings.TrimSpace(s)
	if s == "" {
		return def
	}
	if t, err := time.Parse(time.RFC3339, s); err == nil {
		return t
	}

	day, err := time.ParseInLocation("2006-01-02", s, time.Local)
	if err != nil {
		return def
	}
	if !endOfDay {
		return day
	}
	// Midnight of the next calendar day, minus one nanosecond.
	return day.AddDate(0, 0, 1).Add(-time.Nanosecond)
}
// apiEmailGroups serves GET /api/v1/email/groups: server-side grouped action
// rows for the email workbench. Read-scope tokens may call this endpoint --
// it does not mutate state.
//
// Query parameters: "limit" (1..emailGroupsMaxLimit, anything else falls
// back to the default), "from" and "to" (RFC3339 or YYYY-MM-DD; default is
// the last 24 hours; a reversed pair is swapped), and "kind" (optional
// group-kind filter).
func (s *Server) apiEmailGroups(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	params := r.URL.Query()

	limit := queryInt(r, "limit", emailGroupsDefaultLimit)
	if limit < 1 || limit > emailGroupsMaxLimit {
		limit = emailGroupsDefaultLimit
	}

	now := time.Now()
	from := parseEmailGroupDate(params.Get("from"), now.Add(-24*time.Hour), false)
	to := parseEmailGroupDate(params.Get("to"), now, true)
	if to.Before(from) {
		from, to = to, from
	}

	var findings []alert.Finding
	if s.store != nil {
		findings = s.store.ReadHistorySince(from)
	}

	// Bound the work per request and tell the caller when data was cut off.
	truncated := len(findings) > emailGroupsScanCap
	if truncated {
		findings = findings[:emailGroupsScanCap]
	}
	scanned := len(findings)

	groups := buildEmailGroups(findings, from, to, params.Get("kind"))
	if len(groups) > limit {
		groups = groups[:limit]
	}

	writeJSON(w, emailGroupsResponse{
		Groups:    groups,
		From:      from.UTC().Format(time.RFC3339),
		To:        to.UTC().Format(time.RFC3339),
		Scanned:   scanned,
		Truncated: truncated,
	})
}
package webui
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os/exec"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/firewall"
)
// firewallAllowView is the JSON view of one active allow rule as returned
// by apiFirewallAllowed.
type firewallAllowView struct {
IP string `json:"ip"`
Reason string `json:"reason"`
Source string `json:"source"` // stored source, or one inferred from the reason when none was stored
ExpiresAt string `json:"expires_at,omitempty"` // RFC3339 expiry; omitted for rules without an expiry
ExpiresIn string `json:"expires_in"` // formatRemaining output, e.g. "permanent" or "1h30m"
}
// firewallPortAllowView is the JSON view of one per-port allow exception as
// returned by apiFirewallAllowed.
type firewallPortAllowView struct {
IP string `json:"ip"`
Port int `json:"port"`
Proto string `json:"proto"`
Reason string `json:"reason"`
Source string `json:"source"` // stored source, or one inferred from the reason when none was stored
}
// formatRemaining describes how long is left until expiresAt. A zero time
// means no expiry and yields "permanent"; otherwise the result has the form
// "<hours>h<minutes>m", with times already in the past shown as "0h0m".
func formatRemaining(expiresAt time.Time) string {
	if expiresAt.IsZero() {
		return "permanent"
	}

	left := time.Until(expiresAt)
	if left < 0 {
		left = 0
	}
	wholeHours := int(left.Hours())
	minutesPastHour := int(left.Minutes()) % 60
	return fmt.Sprintf("%dh%dm", wholeHours, minutesPastHour)
}
// apiFirewallStatus reports the firewall engine configuration together with
// a summary of the persisted state: counts of permanent and still-active
// temporary blocks and allows, blocked subnets, port exceptions and infra
// IPs. Entries whose expiry has passed are not counted.
func (s *Server) apiFirewallStatus(w http.ResponseWriter, _ *http.Request) {
	fwCfg := s.cfg.Firewall
	// The load error is dropped and the returned state is used as-is.
	state, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()

	// firewall.infra_ips may be empty in the file because the daemon syncs
	// it at runtime; fall back to the top-level infra_ips in that case.
	infra := fwCfg.InfraIPs
	if len(infra) == 0 {
		infra = s.cfg.InfraIPs
	}

	var permBlocks, tempBlocks int
	for _, b := range state.Blocked {
		switch {
		case b.ExpiresAt.IsZero():
			permBlocks++
		case now.Before(b.ExpiresAt):
			tempBlocks++
		}
	}

	var permAllows, tempAllows int
	for _, a := range state.Allowed {
		switch {
		case a.ExpiresAt.IsZero():
			permAllows++
		case now.Before(a.ExpiresAt):
			tempAllows++
		}
	}

	writeJSON(w, map[string]any{
		"enabled":              fwCfg.Enabled,
		"ipv6":                 fwCfg.IPv6,
		"tcp_in":               fwCfg.TCPIn,
		"tcp_out":              fwCfg.TCPOut,
		"udp_in":               fwCfg.UDPIn,
		"udp_out":              fwCfg.UDPOut,
		"restricted_tcp":       fwCfg.RestrictedTCP,
		"passive_ftp":          [2]int{fwCfg.PassiveFTPStart, fwCfg.PassiveFTPEnd},
		"conn_rate_limit":      fwCfg.ConnRateLimit,
		"conn_limit":           fwCfg.ConnLimit,
		"syn_flood_protection": fwCfg.SYNFloodProtection,
		"udp_flood":            fwCfg.UDPFlood,
		"smtp_block":           fwCfg.SMTPBlock,
		"log_dropped":          fwCfg.LogDropped,
		"deny_ip_limit":        fwCfg.DenyIPLimit,
		"blocked_count":        permBlocks + tempBlocks,
		"blocked_net_count":    len(state.BlockedNet),
		"blocked_permanent":    permBlocks,
		"blocked_temporary":    tempBlocks,
		"allowed_count":        permAllows + tempAllows,
		"allow_permanent":      permAllows,
		"allow_temporary":      tempAllows,
		"port_allow_count":     len(state.PortAllowed),
		"infra_ips":            infra,
		"infra_count":          len(infra),
		"port_flood_rules":     len(fwCfg.PortFlood),
		"country_block":        fwCfg.CountryBlock,
		"dyndns_hosts":         fwCfg.DynDNSHosts,
	})
}
// apiFirewallAllowed returns the active firewall allow rules (sorted by IP)
// and the per-port allow exceptions (sorted by IP, port, protocol). Allow
// rules whose expiry has passed are left out. Entries without a stored
// source get one inferred from their reason.
//
// Both lists are always JSON arrays: "allowed" is an empty array, not null,
// when no rule is active, matching "port_allowed".
func (s *Server) apiFirewallAllowed(w http.ResponseWriter, _ *http.Request) {
	// The load error is dropped and the returned state is used as-is.
	state, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()

	// Non-nil from the start: a nil slice inside the response map would be
	// encoded as null.
	allowed := make([]firewallAllowView, 0, len(state.Allowed))
	for _, entry := range state.Allowed {
		expired := !entry.ExpiresAt.IsZero() && !now.Before(entry.ExpiresAt)
		if expired {
			continue
		}
		view := firewallAllowView{
			IP:        entry.IP,
			Reason:    entry.Reason,
			Source:    entry.Source,
			ExpiresIn: formatRemaining(entry.ExpiresAt),
		}
		if view.Source == "" {
			view.Source = firewall.InferProvenance("allow", entry.Reason)
		}
		if !entry.ExpiresAt.IsZero() {
			view.ExpiresAt = entry.ExpiresAt.Format(time.RFC3339)
		}
		allowed = append(allowed, view)
	}
	sort.Slice(allowed, func(i, j int) bool {
		return allowed[i].IP < allowed[j].IP
	})

	portAllowed := make([]firewallPortAllowView, 0, len(state.PortAllowed))
	for _, entry := range state.PortAllowed {
		view := firewallPortAllowView{
			IP:     entry.IP,
			Port:   entry.Port,
			Proto:  entry.Proto,
			Reason: entry.Reason,
			Source: entry.Source,
		}
		if view.Source == "" {
			view.Source = firewall.InferProvenance("allow_port", entry.Reason)
		}
		portAllowed = append(portAllowed, view)
	}
	sort.Slice(portAllowed, func(i, j int) bool {
		a, b := portAllowed[i], portAllowed[j]
		switch {
		case a.IP != b.IP:
			return a.IP < b.IP
		case a.Port != b.Port:
			return a.Port < b.Port
		default:
			return a.Proto < b.Proto
		}
	})

	writeJSON(w, map[string]interface{}{
		"allowed":      allowed,
		"port_allowed": portAllowed,
	})
}
// apiFirewallAllowIP adds a firewall allow rule for one IP. POST body:
// {"ip": "...", "reason": "...", "duration": "..."}. A duration that parses
// to more than zero creates a temporary rule (status "temp_allowed");
// otherwise the rule has no expiry (status "allowed"). A missing reason
// defaults to "Allowed via CSM Web UI".
func (s *Server) apiFirewallAllowIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IP       string `json:"ip"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &body); err != nil || body.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(body.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if body.Reason == "" {
		body.Reason = "Allowed via CSM Web UI"
	}

	// Both flavours probe the blocker for the capability they need and
	// report 503 when it is missing.
	status := "allowed"
	var allowErr error
	if ttl := parseDuration(body.Duration); ttl > 0 {
		temp, ok := s.blocker.(interface {
			TempAllowIP(string, string, time.Duration) error
		})
		if !ok || temp == nil {
			writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
			return
		}
		status = "temp_allowed"
		allowErr = temp.TempAllowIP(body.IP, body.Reason, ttl)
	} else {
		perm, ok := s.blocker.(interface {
			AllowIP(string, string) error
		})
		if !ok || perm == nil {
			writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
			return
		}
		allowErr = perm.AllowIP(body.IP, body.Reason)
	}

	if allowErr != nil {
		writeJSONError(w, fmt.Sprintf("Allow failed: %v", allowErr), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": status, "ip": body.IP})
}
// apiFirewallRemoveAllow deletes the allow rule for one IP. POST body:
// {"ip": "..."}. It answers 503 when the blocker cannot remove allow rules.
func (s *Server) apiFirewallRemoveAllow(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &body); err != nil || body.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(body.IP); err != nil {
		// The validator's own message is replaced by a fixed one that
		// echoes the submitted value.
		writeJSONError(w, fmt.Sprintf("invalid IP address: %s", body.IP), http.StatusBadRequest)
		return
	}

	remover, ok := s.blocker.(interface {
		RemoveAllowIP(string) error
	})
	if !ok || remover == nil {
		writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
		return
	}

	if err := remover.RemoveAllowIP(body.IP); err != nil {
		writeJSONError(w, fmt.Sprintf("Remove failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "removed", "ip": body.IP})
}
// apiFirewallAudit returns recent firewall audit log entries. Query
// parameters: "limit" (default 100) is passed to the log reader; "action"
// and "source" are exact, case-insensitive filters; "search" is a
// case-insensitive substring match over action, IP, reason and source.
// Filters are applied after the limit, so fewer than limit rows may come
// back.
//
// The response is always a JSON array, including when no entry exists or
// the filters exclude every entry.
func (s *Server) apiFirewallAudit(w http.ResponseWriter, r *http.Request) {
	limit := queryInt(r, "limit", 100)
	entries := firewall.ReadAuditLog(s.cfg.StatePath, limit)

	type auditView struct {
		Timestamp string `json:"timestamp"`
		Action    string `json:"action"`
		IP        string `json:"ip"`
		Reason    string `json:"reason"`
		Source    string `json:"source"`
		Duration  string `json:"duration"`
		TimeAgo   string `json:"time_ago"`
	}

	// Parse the query string once for all three filters.
	query := r.URL.Query()
	search := strings.ToLower(strings.TrimSpace(query.Get("search")))
	actionFilter := strings.ToLower(strings.TrimSpace(query.Get("action")))
	sourceFilter := strings.ToLower(strings.TrimSpace(query.Get("source")))

	// Non-nil from the start so an empty result encodes as [] rather than
	// null.
	result := make([]auditView, 0, len(entries))
	for _, e := range entries {
		// Entries without a stored source get one inferred from the action
		// and reason; filters then apply to that effective source.
		source := e.Source
		if source == "" {
			source = firewall.InferProvenance(e.Action, e.Reason)
		}
		if actionFilter != "" && strings.ToLower(e.Action) != actionFilter {
			continue
		}
		if sourceFilter != "" && strings.ToLower(source) != sourceFilter {
			continue
		}
		if search != "" {
			haystack := strings.ToLower(strings.Join([]string{e.Action, e.IP, e.Reason, source}, " "))
			if !strings.Contains(haystack, search) {
				continue
			}
		}
		result = append(result, auditView{
			Timestamp: e.Timestamp.Format("2006-01-02 15:04:05"),
			Action:    e.Action,
			IP:        e.IP,
			Reason:    e.Reason,
			Source:    source,
			Duration:  e.Duration,
			TimeAgo:   timeAgo(e.Timestamp),
		})
	}
	writeJSON(w, result)
}
// apiFirewallSubnets lists the subnets recorded as blocked in the persisted
// firewall state. Entries without a stored source get one inferred from
// their reason. With no blocked subnets the response is an empty JSON array.
func (s *Server) apiFirewallSubnets(w http.ResponseWriter, _ *http.Request) {
	// The load error is dropped and the returned state is used as-is.
	state, _ := firewall.LoadState(s.cfg.StatePath)

	type subnetView struct {
		CIDR      string `json:"cidr"`
		Reason    string `json:"reason"`
		Source    string `json:"source"`
		BlockedAt string `json:"blocked_at"`
		TimeAgo   string `json:"time_ago"`
		ExpiresIn string `json:"expires_in"`
	}

	if len(state.BlockedNet) == 0 {
		writeJSON(w, []interface{}{})
		return
	}

	views := make([]subnetView, 0, len(state.BlockedNet))
	for _, sn := range state.BlockedNet {
		source := sn.Source
		if source == "" {
			source = firewall.InferProvenance("block_subnet", sn.Reason)
		}
		views = append(views, subnetView{
			CIDR:      sn.CIDR,
			Reason:    sn.Reason,
			Source:    source,
			BlockedAt: sn.BlockedAt.Format(time.RFC3339),
			TimeAgo:   timeAgo(sn.BlockedAt),
			ExpiresIn: formatRemaining(sn.ExpiresAt),
		})
	}
	writeJSON(w, views)
}
// apiFirewallDenySubnet blocks a subnet through the firewall engine. POST
// body: {"cidr": "...", "reason": "...", "duration": "..."}. The parsed
// duration is passed to the engine unchanged; a missing reason defaults to
// "Blocked via CSM Web UI".
func (s *Server) apiFirewallDenySubnet(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		CIDR     string `json:"cidr"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &body); err != nil || body.CIDR == "" {
		writeJSONError(w, "CIDR is required", http.StatusBadRequest)
		return
	}
	if _, err := validateCIDR(body.CIDR); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	reason := body.Reason
	if reason == "" {
		reason = "Blocked via CSM Web UI"
	}
	ttl := parseDuration(body.Duration)

	engine, ok := s.blocker.(interface {
		BlockSubnet(string, string, time.Duration) error
	})
	if !ok || engine == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	if err := engine.BlockSubnet(body.CIDR, reason, ttl); err != nil {
		writeJSONError(w, fmt.Sprintf("Block failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "blocked", "cidr": body.CIDR})
}
// apiFirewallRemoveSubnet lifts a subnet block. POST body: {"cidr": "..."}.
// The CIDR only has to parse with net.ParseCIDR; it answers 503 when the
// blocker cannot unblock subnets.
func (s *Server) apiFirewallRemoveSubnet(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		CIDR string `json:"cidr"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &body); err != nil || body.CIDR == "" {
		writeJSONError(w, "CIDR is required", http.StatusBadRequest)
		return
	}
	if _, _, parseErr := net.ParseCIDR(body.CIDR); parseErr != nil {
		writeJSONError(w, "Invalid CIDR notation", http.StatusBadRequest)
		return
	}

	engine, ok := s.blocker.(interface {
		UnblockSubnet(string) error
	})
	if !ok || engine == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	if err := engine.UnblockSubnet(body.CIDR); err != nil {
		writeJSONError(w, fmt.Sprintf("Remove failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "removed", "cidr": body.CIDR})
}
// apiFirewallFlush asks the firewall engine to clear all blocked IPs. Only
// POST is accepted; it answers 503 when the blocker has no flush capability.
func (s *Server) apiFirewallFlush(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	flusher, ok := s.blocker.(interface{ FlushBlocked() error })
	if !ok || flusher == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	err := flusher.FlushBlocked()
	if err != nil {
		writeJSONError(w, fmt.Sprintf("Flush failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "flushed"})
}
// apiFirewallFlushCphulk clears the cPHulk login history for a single IP and
// leaves CSM firewall state untouched.
// POST only; body: {"ip": "<address>"}. The address must pass
// parseAndValidateIP before flushCphulk is invoked.
func (s *Server) apiFirewallFlushCphulk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var payload struct {
		IP string `json:"ip"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &payload)
	if decodeErr != nil || payload.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}

	_, ipErr := parseAndValidateIP(payload.IP)
	if ipErr != nil {
		writeJSONError(w, ipErr.Error(), http.StatusBadRequest)
		return
	}

	flushCphulk(payload.IP)
	writeJSON(w, map[string]string{"status": "flushed", "ip": payload.IP})
}
// apiFirewallCheck reports whether an IP is blocked by CSM or by cPHulk.
// GET /api/v1/firewall/check?ip=1.2.3.4
// The response mirrors the cpanel-service format for phclient compatibility:
//
//	{"success": true, "ip": "1.2.3.4", "permanent": "reason or null", "temporary": "reason or null", "cphulk": true/false}
//
// A missing or unparseable ip yields {"success": false, "error_msg": ...}
// (still via writeJSON, not an HTTP error status).
func (s *Server) apiFirewallCheck(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	addr := net.ParseIP(ip) // nil for the empty string as well as for malformed input
	if addr == nil {
		writeJSON(w, map[string]interface{}{"success": false, "error_msg": "The ip is not valid or it was not set."})
		return
	}

	var permanent, temporary interface{}

	// Per-IP blocks recorded in the CSM firewall state. An entry without an
	// expiry is permanent; an entry whose expiry has passed is ignored.
	state, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()
	for _, entry := range state.Blocked {
		if entry.IP != ip {
			continue
		}
		switch {
		case entry.ExpiresAt.IsZero():
			permanent = entry.Reason
		case now.Before(entry.ExpiresAt):
			remaining := time.Until(entry.ExpiresAt).Truncate(time.Minute)
			temporary = fmt.Sprintf("%s (expires in %s)", entry.Reason, remaining)
		}
	}

	// A covering subnet block is reported as permanent and overrides any
	// per-IP permanent reason found above.
	for _, sn := range state.BlockedNet {
		_, network, parseErr := net.ParseCIDR(sn.CIDR)
		if parseErr != nil || !network.Contains(addr) {
			continue
		}
		permanent = fmt.Sprintf("Subnet block: %s - %s", sn.CIDR, sn.Reason)
	}

	// Read-only cPHulk lookup. The context timeout bounds a hung whmapi1
	// (cPanel socket busy/down) so it cannot pin this handler goroutine.
	lookupCtx, stop := context.WithTimeout(r.Context(), 8*time.Second)
	defer stop()
	cmd := exec.CommandContext(lookupCtx, "whmapi1", "read_cphulk_records",
		"list_name=black", "--output=json")
	out, runErr := cmd.Output()
	inCphulk := runErr == nil && cphulkBlocksIP(out, ip)

	writeJSON(w, map[string]interface{}{
		"success":   true,
		"ip":        ip,
		"permanent": permanent,
		"temporary": temporary,
		"cphulk":    inCphulk,
	})
}
// cphulkBlocksIP reports whether ip appears in an IP field of any record in
// jsonOut, which is expected to have the shape {"data": {"records": [{...}]}}.
//
// The match is scoped to record IP fields (see isCphulkRecordIPField); a raw
// token search would also match unrelated strings such as operator notes.
// Besides an exact string match, addresses are compared after parsing so that
// equivalent textual forms (expanded vs. compressed IPv6, IPv4-mapped IPv6)
// are recognised as the same address.
//
// Returns false for an empty ip, for unparseable JSON, and for IP fields
// whose value is not a JSON string.
func cphulkBlocksIP(jsonOut []byte, ip string) bool {
	if ip == "" {
		return false
	}
	var payload struct {
		Data struct {
			Records []map[string]json.RawMessage `json:"records"`
		} `json:"data"`
	}
	if err := json.Unmarshal(jsonOut, &payload); err != nil {
		return false
	}
	// want is nil when ip is not a valid address; only the exact string
	// comparison applies in that case.
	want := net.ParseIP(ip)
	for _, record := range payload.Data.Records {
		for key, raw := range record {
			if !isCphulkRecordIPField(key) {
				continue
			}
			var value string
			if err := json.Unmarshal(raw, &value); err != nil {
				continue
			}
			if value == ip {
				return true
			}
			if want == nil {
				continue
			}
			if got := net.ParseIP(value); got != nil && got.Equal(want) {
				return true
			}
		}
	}
	return false
}

// isCphulkRecordIPField reports whether key names a record field that holds
// an IP address. The comparison is case-insensitive.
func isCphulkRecordIPField(key string) bool {
	switch strings.ToLower(key) {
	case "ip", "ip_address":
		return true
	default:
		return false
	}
}
// apiFirewallUnban clears an IP from CSM and cPHulk in a single call.
// POST /api/v1/firewall/unban body: {"ip": "1.2.3.4"}
//
// Steps, in order: unblock the individual IP, remove the first subnet block
// that covers it (if the blocker supports subnets), then flush cPHulk.
// Errors from the unblock calls are deliberately ignored (best effort).
// Validation failures are reported as {"success": false, "error_msg": ...}.
func (s *Server) apiFirewallUnban(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	reject := func(msg string) {
		writeJSON(w, map[string]interface{}{"success": false, "error_msg": msg})
	}

	var payload struct {
		IP string `json:"ip"`
	}
	decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &payload)
	if decodeErr != nil || payload.IP == "" {
		reject("The ip is not valid or it was not set.")
		return
	}
	if _, ipErr := parseAndValidateIP(payload.IP); ipErr != nil {
		reject(ipErr.Error())
		return
	}
	target := payload.IP

	// 1. Individual IP block.
	if s.blocker != nil {
		_ = s.blocker.UnblockIP(target)
	}

	// 2. First covering subnet block, when the engine can remove subnets.
	state, _ := firewall.LoadState(s.cfg.StatePath)
	addr := net.ParseIP(target)
	response := map[string]interface{}{"success": true, "ip": target}
	type subnetUnblocker interface {
		UnblockSubnet(string) error
	}
	if engine, ok := s.blocker.(subnetUnblocker); ok {
		for _, sn := range state.BlockedNet {
			_, network, parseErr := net.ParseCIDR(sn.CIDR)
			if parseErr != nil || !network.Contains(addr) {
				continue
			}
			_ = engine.UnblockSubnet(sn.CIDR)
			response["subnet_removed"] = sn.CIDR
			break
		}
	}

	// 3. cPHulk history.
	flushCphulk(target)

	writeJSON(w, response)
}
package webui
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/firewall/rollback"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/obs"
)
// apiFirewallTentativeApply handles POST /api/v1/settings/firewall/tentative-apply.
// Body shape mirrors the regular settings POST plus an optional
// timeout_min field (1..30, default 5). The handler runs the same
// change validation as the normal save path, snapshots the previous
// csm.yaml bytes into bbolt, writes the new file, and triggers a
// daemon restart. The rollback manager arms an in-process timer; if
// the operator does not POST /confirm before the deadline the daemon
// restores the snapshot and restarts itself.
//
// NOTE(review): timeout_min is passed to mgr.Apply without any range check
// in this handler; the 1..30 / default-5 behaviour is presumably enforced
// inside the rollback manager -- confirm.
//
// Status codes produced here: 405 (non-POST), 503 (no rollback manager),
// 409 (a rollback is already pending), 400 (missing If-Match or bad body),
// 412 (If-Match does not equal the on-disk config hash), 500 (read, parse,
// YAML edit, rollback staging, or save failure). Validation failures are
// reported through writeValidationErrors.
func (s *Server) apiFirewallTentativeApply(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
mgr := rollback.Global()
if mgr == nil {
writeJSONError(w, "rollback manager not available", http.StatusServiceUnavailable)
return
}
// Only one tentative change may be in flight at a time.
if mgr.Status().Pending {
writeJSONError(w, "a firewall rollback is already pending; confirm or revert first", http.StatusConflict)
return
}
section, ok := LookupSettingsSection("firewall")
if !ok {
writeJSONError(w, "firewall section not registered", http.StatusInternalServerError)
return
}
// If-Match carries the config hash the client last saw; it is compared
// against the freshly loaded on-disk hash below (optimistic concurrency).
ifMatch := r.Header.Get("If-Match")
if ifMatch == "" {
writeJSONError(w, "If-Match header required", http.StatusBadRequest)
return
}
var body struct {
Changes map[string]json.RawMessage `json:"changes"`
TimeoutMin int `json:"timeout_min"`
}
if err := decodeJSONBodyLimited(w, r, 256*1024, &body); err != nil {
writeJSONError(w, "invalid body: "+err.Error(), http.StatusBadRequest)
return
}
// Work from the bytes currently on disk rather than the in-memory config:
// diskBytes doubles as the rollback snapshot and as the YAML edit base.
diskBytes, err := os.ReadFile(s.cfg.ConfigFile) // #nosec G304 -- operator-supplied config path
if err != nil {
writeJSONError(w, "read config: "+err.Error(), http.StatusInternalServerError)
return
}
disk, err := config.LoadBytes(diskBytes)
if err != nil {
writeJSONError(w, "parse config: "+err.Error(), http.StatusInternalServerError)
return
}
disk.ConfigFile = s.cfg.ConfigFile
disk.ConfigDir = s.cfg.ConfigDir
if disk.Integrity.ConfigHash != ifMatch {
writeJSONError(w, "config changed on disk, reload", http.StatusPreconditionFailed)
return
}
// rejectIfConfDirChanged writes its own response when it returns true.
if rejectIfConfDirChanged(w, s.cfg.ConfigDir, disk.Integrity.ConfdHash) {
return
}
// Shallow copy of the loaded config; the Firewall struct is copied one
// level deep so edits applied to the clone do not write through the
// pointer shared with disk.
clone := *disk
if disk.Firewall != nil {
fw := *disk.Firewall
clone.Firewall = &fw
}
yamlChanges, errs := buildChangeSet(section, &clone, body.Changes)
if len(errs) > 0 {
writeValidationErrors(w, errs)
return
}
validationResults := append(config.Validate(&clone), config.ValidateDeepSection(&clone, section.ID)...)
fieldErrors, warnings := splitValidationResults(validationResults)
if len(fieldErrors) > 0 {
writeValidationErrors(w, fieldErrors)
return
}
// Lockout warnings are advisory: they are returned to the client but do
// not block the apply.
warnings = append(warnings, firewallLockoutWarnings(&clone)...)
edited, err := config.YAMLEdit(diskBytes, yamlChanges)
if err != nil {
writeJSONError(w, "yaml edit: "+err.Error(), http.StatusInternalServerError)
return
}
// Stage rollback BEFORE the on-disk write so a crash between the two
// leaves the snapshot recoverable: if there is no new file on disk
// yet, the snapshot is a no-op revert. The reverse order would let
// the daemon come back to the new config with no rollback record and
// no way to undo without operator intervention.
timeout := time.Duration(body.TimeoutMin) * time.Minute
st, err := mgr.Apply(diskBytes, edited, timeout, extractClientIP(r))
if err != nil {
writeJSONError(w, "stage rollback: "+err.Error(), http.StatusInternalServerError)
return
}
if err := integrity.SignAndSavePreserving(s.cfg.ConfigFile, s.cfg.ConfigDir, edited, &clone, disk.Integrity.BinaryHash); err != nil {
// Best-effort cleanup: the snapshot is now misleading because
// the on-disk file never changed. Drop it so the operator does
// not see a phantom pending rollback in the UI.
_ = mgr.Confirm()
writeJSONError(w, "save: "+err.Error(), http.StatusInternalServerError)
return
}
s.auditLog(r, "settings-tentative-apply", "firewall", auditDetailsFor(section, body.Changes))
// Defer the restart so the response flushes first; otherwise the
// client sees a connection reset and cannot read the rollback ETA
// it needs to drive the countdown banner.
s.scheduleDaemonRestart(250 * time.Millisecond)
// new_etag is read from the clone after SignAndSavePreserving received a
// pointer to it. NOTE(review): presumably the save call refreshes
// clone.Integrity.ConfigHash -- confirm.
writeJSON(w, map[string]interface{}{
"status": "tentative-apply issued",
"warnings": warnings,
"rollback": st,
"new_etag": clone.Integrity.ConfigHash,
"applied": changedFieldList(body.Changes, section),
"requires_restart": true,
})
}
// changedFieldList maps each key of changes to its dotted YAML path under
// the section (section.YAMLPath + "." + key). An empty key stands for the
// section itself and maps to section.YAMLPath alone. The UI uses the result
// to highlight pending fields. The order of the returned paths follows Go
// map iteration and is therefore unspecified.
func changedFieldList(changes map[string]json.RawMessage, section SettingsSection) []string {
	base := section.YAMLPath
	paths := make([]string, 0, len(changes))
	for key := range changes {
		path := base
		if key != "" {
			path = base + "." + key
		}
		paths = append(paths, path)
	}
	return paths
}
// apiFirewallRollbackStatus serves the current rollback record (GET only).
// When no rollback manager is installed it answers with a zero
// rollback.Status instead of an error, so the countdown banner can poll it
// unconditionally. Read endpoint, no CSRF.
func (s *Server) apiFirewallRollbackStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if mgr := rollback.Global(); mgr != nil {
		writeJSON(w, mgr.Status())
		return
	}
	writeJSON(w, rollback.Status{})
}
// apiFirewallRollbackConfirm handles POST .../confirm: the pending snapshot
// is dropped and the new config stays in place.
// Responds 503 without a rollback manager, 409 when nothing is pending, and
// 500 when the manager fails to confirm. A successful confirm is
// audit-logged.
func (s *Server) apiFirewallRollbackConfirm(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	mgr := rollback.Global()
	switch {
	case mgr == nil:
		writeJSONError(w, "rollback manager not available", http.StatusServiceUnavailable)
		return
	case !mgr.Status().Pending:
		writeJSONError(w, "no pending rollback", http.StatusConflict)
		return
	}

	confirmErr := mgr.Confirm()
	if confirmErr != nil {
		writeJSONError(w, "confirm: "+confirmErr.Error(), http.StatusInternalServerError)
		return
	}
	s.auditLog(r, "settings-rollback-confirm", "firewall", "")
	writeJSON(w, map[string]string{"status": "confirmed"})
}
// apiFirewallRollbackRevert handles POST .../revert. It schedules the revert
// (30 second budget) on a background goroutine so the response can flush,
// audit-logs the request, and answers {"status": "revert issued"}. The
// outcome of the revert itself is not reflected in the response.
// Responds 503 without a rollback manager and 409 when nothing is pending.
func (s *Server) apiFirewallRollbackRevert(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	mgr := rollback.Global()
	switch {
	case mgr == nil:
		writeJSONError(w, "rollback manager not available", http.StatusServiceUnavailable)
		return
	case !mgr.Status().Pending:
		writeJSONError(w, "no pending rollback", http.StatusConflict)
		return
	}

	const revertBudget = 30 * time.Second
	s.scheduleRollbackRevert(mgr, revertBudget)
	s.auditLog(r, "settings-rollback-revert", "firewall", "")
	writeJSON(w, map[string]string{"status": "revert issued"})
}
// scheduleDaemonRestart launches a supervised goroutine that waits for delay
// and then calls restartDaemon. If pruneDone fires first (the server has
// begun shutting down) the goroutine exits without restarting, so an
// operator-initiated stop is not followed by a phantom restart. A restart
// failure is written to stderr.
func (s *Server) scheduleDaemonRestart(delay time.Duration) {
	restartAfterDelay := func() {
		select {
		case <-s.pruneDone:
			// Shutdown won the race; leave the daemon alone.
		case <-time.After(delay):
			_, restartErr := s.restartDaemon()
			if restartErr != nil {
				fmt.Fprintf(os.Stderr, "webui: tentative-apply restart failed: %v\n", restartErr)
			}
		}
	}
	obs.SafeGo("webui-tentative-apply-restart", restartAfterDelay)
}
// scheduleRollbackRevert performs mgr.Revert on a supervised goroutine under
// a hard timeout. The worker first does a non-blocking check of pruneDone
// and gives up if shutdown has already begun. Once it proceeds, the revert
// runs on a context derived from context.Background rather than from the
// server, so the restart it triggers cannot cancel itself via Shutdown.
// A revert failure is written to stderr.
func (s *Server) scheduleRollbackRevert(mgr *rollback.Manager, timeout time.Duration) {
	worker := func() {
		shuttingDown := false
		select {
		case <-s.pruneDone:
			shuttingDown = true
		default:
		}
		if shuttingDown {
			return
		}

		revertCtx, release := context.WithTimeout(context.Background(), timeout)
		defer release()
		revertErr := mgr.Revert(revertCtx)
		if revertErr != nil {
			fmt.Fprintf(os.Stderr, "webui: rollback revert failed: %v\n", revertErr)
		}
	}
	obs.SafeGo("webui-rollback-revert", worker)
}
package webui
import (
"net/http"
"sort"
"github.com/pidginhost/csm/internal/mailfwd/inventory"
"github.com/pidginhost/csm/internal/platform"
)
// JSON shapes served by the email forwarders endpoint.
type (
	// forwarderDestination is a single resolved forwarder target. Provider
	// holds the inventory class string (local/yahoo/gmail/outlook/external)
	// that the table renders as a badge.
	forwarderDestination struct {
		Address  string `json:"address"`
		Domain   string `json:"domain"`
		Provider string `json:"provider"`
	}

	// forwarderEntry describes one source address together with every
	// destination it relays to.
	forwarderEntry struct {
		Source          string                 `json:"source"`
		Domain          string                 `json:"domain"`
		Owner           string                 `json:"owner"`
		Destinations    []forwarderDestination `json:"destinations"`
		Providers       []string               `json:"providers"` // distinct destination classes, sorted
		KeepLocal       bool                   `json:"keep_local"`
		ForwardOnly     bool                   `json:"forward_only"`
		HasExternal     bool                   `json:"has_external"`
		HasFreeProvider bool                   `json:"has_free_provider"`
	}

	// forwardersSummary is the page-header rollup: the number of forwarders
	// and how many of them carry reputation risk (mail leaves the server or
	// targets a free provider).
	forwardersSummary struct {
		Total        int `json:"total"`
		External     int `json:"external"`
		FreeProvider int `json:"free_provider"`
	}

	// forwardersResponse is the full response body: the entries plus their
	// summary.
	forwardersResponse struct {
		Forwarders []forwarderEntry  `json:"forwarders"`
		Summary    forwardersSummary `json:"summary"`
	}
)
// selectForwarderSource chooses the forwarder inventory source for this
// host. cPanel hosts get the cPanel enumerator; every other platform gets
// inventory.EmptySource until its adapter lands (Phase 3).
func selectForwarderSource() inventory.Source {
	host := platform.Detect()
	if !host.IsCPanel() {
		return inventory.EmptySource{}
	}
	return inventory.NewCPanelSource()
}
// apiEmailForwarders handles GET /api/v1/email/forwarders. It returns the
// host's forwarder inventory (each source with its destinations, provider
// classes, owner, and keep-local / forward-only flags) along with summary
// counts. With no forwarder source configured the response is an empty
// list and a zero summary; an enumeration failure yields 500.
func (s *Server) apiEmailForwarders(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Start from a non-nil slice so the JSON field is [] rather than null.
	out := forwardersResponse{Forwarders: []forwarderEntry{}}
	if s.forwarderSource != nil {
		inventoryList, listErr := s.forwarderSource.Forwarders()
		if listErr != nil {
			writeJSONError(w, "Failed to enumerate forwarders", http.StatusInternalServerError)
			return
		}
		summary := &out.Summary
		for _, fwd := range inventoryList {
			out.Forwarders = append(out.Forwarders, toForwarderEntry(fwd))
			summary.Total++
			if fwd.HasExternal() {
				summary.External++
			}
			if fwd.HasFreeProvider() {
				summary.FreeProvider++
			}
		}
	}
	writeJSON(w, out)
}
// toForwarderEntry converts an inventory forwarder into its API shape.
// Destinations keep their original order; Providers lists each distinct
// destination provider class once, sorted alphabetically.
func toForwarderEntry(f inventory.Forwarder) forwarderEntry {
	n := len(f.Destinations)
	entry := forwarderEntry{
		Source:          f.Source,
		Domain:          f.Domain,
		Owner:           f.Owner,
		Destinations:    make([]forwarderDestination, 0, n),
		Providers:       make([]string, 0, n),
		KeepLocal:       f.KeepLocal,
		ForwardOnly:     f.ForwardOnly,
		HasExternal:     f.HasExternal(),
		HasFreeProvider: f.HasFreeProvider(),
	}

	known := make(map[string]struct{}, n)
	for _, dest := range f.Destinations {
		class := string(dest.Provider)
		entry.Destinations = append(entry.Destinations, forwarderDestination{
			Address:  dest.Address,
			Domain:   dest.Domain,
			Provider: class,
		})
		if _, dup := known[class]; dup {
			continue
		}
		known[class] = struct{}{}
		entry.Providers = append(entry.Providers, class)
	}
	sort.Strings(entry.Providers)
	return entry
}
package webui
import (
"net"
"net/http"
"github.com/pidginhost/csm/internal/geoip"
)
// SetGeoIPDB stores db as the GeoIP database used for IP lookups; the GeoIP
// handlers read it back through s.geoIPDB.Load().
// NOTE(review): geoIPDB is accessed via Store/Load, which suggests an atomic
// pointer that can be swapped while requests are in flight -- confirm
// against the Server field declaration.
func (s *Server) SetGeoIPDB(db *geoip.DB) {
s.geoIPDB.Store(db)
}
// apiGeoIPLookup serves geolocation info for one IP.
// GET /api/v1/geoip?ip=1.2.3.4 - fast local lookup (country + ASN)
// GET /api/v1/geoip?ip=1.2.3.4&detail=1 - includes RDAP org/ISP (may take 1-3s)
// Responds 400 for a missing or malformed ip and 503 while no database is
// loaded.
func (s *Server) apiGeoIPLookup(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	ip := query.Get("ip")
	switch {
	case ip == "":
		writeJSONError(w, "ip parameter required", http.StatusBadRequest)
		return
	case net.ParseIP(ip) == nil:
		writeJSONError(w, "invalid IP address", http.StatusBadRequest)
		return
	}

	db := s.geoIPDB.Load()
	if db == nil {
		writeJSONError(w, "GeoIP databases not loaded", http.StatusServiceUnavailable)
		return
	}

	if query.Get("detail") == "1" {
		writeJSON(w, db.LookupWithRDAP(ip))
		return
	}
	writeJSON(w, db.Lookup(ip))
}
// apiGeoIPBatch serves geolocation info for many IPs in one request.
// POST /api/v1/geoip/batch body: {"ips": ["1.2.3.4", "5.6.7.8"]}
// At most 500 IPs are accepted. The response is {"results": {ip: {...}}};
// per-IP problems (bad format, database not loaded) are reported in that
// entry's "error" field rather than failing the whole request.
func (s *Server) apiGeoIPBatch(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var payload struct {
		IPs []string `json:"ips"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 32*1024, &payload); decodeErr != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if len(payload.IPs) > 500 {
		writeJSONError(w, "maximum 500 IPs per request", http.StatusBadRequest)
		return
	}

	type geoResult struct {
		Country     string `json:"country"`
		CountryName string `json:"country_name"`
		City        string `json:"city"`
		ASOrg       string `json:"as_org"`
		Error       string `json:"error,omitempty"`
	}

	db := s.geoIPDB.Load()
	byIP := make(map[string]geoResult, len(payload.IPs))
	for _, candidate := range payload.IPs {
		var entry geoResult
		switch {
		case net.ParseIP(candidate) == nil:
			entry.Error = "invalid IP format"
		case db == nil:
			entry.Error = "GeoIP database not loaded"
		default:
			info := db.Lookup(candidate)
			entry.Country = info.Country
			entry.CountryName = info.CountryName
			entry.City = info.City
			entry.ASOrg = info.ASOrg
		}
		byIP[candidate] = entry
	}
	writeJSON(w, map[string]interface{}{"results": byIP})
}
package webui
import (
"bytes"
"fmt"
"net/http"
"os"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
)
// renderTemplate executes the named template with data and writes the
// result to w. A missing template or an execution error produces a 500 and
// a line on stderr; a failure while writing the rendered bytes is only
// logged.
func (s *Server) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
	page := s.templates[name]
	if page == nil {
		fmt.Fprintf(os.Stderr, "[webui] template %s missing\n", name)
		http.Error(w, "template not found", http.StatusInternalServerError)
		return
	}

	// Execute into memory first: html/template streams to its writer, and
	// once any byte has gone out the status header can no longer become a
	// 500.
	rendered := new(bytes.Buffer)
	execErr := page.ExecuteTemplate(rendered, name, data)
	if execErr != nil {
		fmt.Fprintf(os.Stderr, "[webui] template %s error: %v\n", name, execErr)
		http.Error(w, "template render error", http.StatusInternalServerError)
		return
	}

	_, writeErr := w.Write(rendered.Bytes())
	if writeErr != nil {
		fmt.Fprintf(os.Stderr, "[webui] template %s write error: %v\n", name, writeErr)
	}
}
// View models handed to the HTML page templates.
type (
	// dashboardData feeds dashboard.html: host identity, severity counts,
	// daemon status figures, and the recent-findings feed.
	dashboardData struct {
		Hostname        string
		Uptime          string
		Critical        int
		High            int
		Warning         int
		Total           int
		SigCount        int
		FanotifyActive  bool
		LogWatchers     int
		LastCriticalAgo string
		LastCriticalISO string // RFC3339 of most recent critical, "" if none (so relative time can tick client-side)
		RecentFindings  []historyEntry
	}

	// historyEntry is one finding, pre-formatted for display.
	historyEntry struct {
		Severity     string
		SevClass     string
		Check        string
		Message      string
		Details      string
		Timestamp    string
		TimestampISO string // RFC3339 for JS comparison
		TimeAgo      string
		HasFix       bool
		FixDesc      string
		Key          string // canonical dedup key (matches alert.Finding.Key())
	}

	// quarantineData feeds quarantine.html.
	quarantineData struct {
		Hostname string
		Files    []quarantineEntry
	}

	// quarantineEntry describes a single quarantined file.
	quarantineEntry struct {
		ID           string
		OriginalPath string
		Size         int64
		QuarantineAt string
		Reason       string
	}
)
// handleDashboard renders dashboard.html from the last 24 hours of history:
// per-severity counts over every finding, a feed of up to ten findings that
// excludes internal checks, and the age of the most recent critical finding.
func (s *Server) handleDashboard(w http.ResponseWriter, _ *http.Request) {
	const feedLimit = 10

	since := time.Now().Add(-24 * time.Hour)
	findings := s.store.ReadHistorySince(since)

	var (
		feed                       []historyEntry
		nCritical, nHigh, nWarning int
	)
	for _, f := range findings {
		// Every finding is counted, including the internal ones skipped below.
		switch f.Severity {
		case alert.Critical:
			nCritical++
		case alert.High:
			nHigh++
		case alert.Warning:
			nWarning++
		}

		// Internal checks never appear in the live feed.
		switch f.Check {
		case "auto_response", "auto_block", "check_timeout", "health":
			continue
		}
		if len(feed) >= feedLimit {
			continue
		}
		feed = append(feed, historyEntry{
			Severity:     severityLabel(f.Severity),
			SevClass:     severityClass(f.Severity),
			Check:        f.Check,
			Message:      f.Message,
			Details:      f.Details,
			Timestamp:    f.Timestamp.Format("15:04:05"),
			TimestampISO: f.Timestamp.Format(time.RFC3339),
			TimeAgo:      timeAgo(f.Timestamp),
			HasFix:       checks.HasFix(f.Check),
			FixDesc:      checks.FixDescription(f.Check, f.Message, f.FilePath),
			Key:          f.Key(),
		})
	}

	// Findings arrive newest-first, so the first critical one is the latest.
	sinceCritical, criticalISO := "None", ""
	for _, f := range findings {
		if f.Severity != alert.Critical {
			continue
		}
		sinceCritical = timeAgo(f.Timestamp)
		criticalISO = f.Timestamp.Format(time.RFC3339)
		break
	}

	s.renderTemplate(w, "dashboard.html", dashboardData{
		Hostname:        s.cfg.Hostname,
		Uptime:          time.Since(s.startTime).Round(time.Second).String(),
		Critical:        nCritical,
		High:            nHigh,
		Warning:         nWarning,
		Total:           nCritical + nHigh + nWarning,
		SigCount:        s.sigCount,
		FanotifyActive:  s.fanotifyActive,
		LogWatchers:     s.logWatcherCount,
		LastCriticalAgo: sinceCritical,
		LastCriticalISO: criticalISO,
		RecentFindings:  feed,
	})
}
// handleFindings renders findings.html. The page is JS-driven and pulls its
// data from the enriched API, so only the hostname is passed to the
// template.
func (s *Server) handleFindings(w http.ResponseWriter, _ *http.Request) {
	page := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "findings.html", page)
}
// handleHistoryRedirect keeps old history URLs working: history is now a tab
// on the findings page, so the request is redirected (302) there with any
// incoming query string appended.
func (s *Server) handleHistoryRedirect(w http.ResponseWriter, r *http.Request) {
	const historyTab = "/findings?tab=history"

	target := historyTab
	if r.URL.RawQuery != "" {
		target = historyTab + "&" + r.URL.RawQuery
	}
	// #nosec G710 -- target always starts with the fixed same-origin
	// /findings path; the incoming query can only add parameters.
	http.Redirect(w, r, target, http.StatusFound)
}
// handleQuarantine renders quarantine.html with the hostname only; the Files
// list is left empty here.
func (s *Server) handleQuarantine(w http.ResponseWriter, _ *http.Request) {
	page := quarantineData{Hostname: s.cfg.Hostname}
	s.renderTemplate(w, "quarantine.html", page)
}
// handleCleanupHistory renders cleanup-history.html, passing the hostname.
func (s *Server) handleCleanupHistory(w http.ResponseWriter, _ *http.Request) {
	page := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "cleanup-history.html", page)
}
// handleFirewall renders firewall.html, passing the hostname.
func (s *Server) handleFirewall(w http.ResponseWriter, _ *http.Request) {
	page := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "firewall.html", page)
}
// handleEmail renders email.html, passing the hostname.
func (s *Server) handleEmail(w http.ResponseWriter, _ *http.Request) {
	page := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "email.html", page)
}
// handleSettings renders settings.html, passing the hostname.
func (s *Server) handleSettings(w http.ResponseWriter, _ *http.Request) {
	page := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "settings.html", page)
}
package webui
import (
"net/http"
"time"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/store"
)
// apiHardening serves the most recently stored hardening audit report (GET).
// Without a global store it answers with an empty report; a load failure
// yields 500.
func (s *Server) apiHardening(w http.ResponseWriter, _ *http.Request) {
	if db := store.Global(); db != nil {
		report, loadErr := db.LoadHardeningReport()
		if loadErr != nil {
			writeJSONError(w, "failed to load report: "+loadErr.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, report)
		return
	}
	writeJSON(w, &store.AuditReport{})
}
// apiHardeningRun runs the hardening audit, persists the report when a
// global store is available, and returns it (POST only).
// Only one scan may run at a time: a concurrent request gets 409. If the
// audit completes but the report cannot be saved, the response is a 500 and
// the report is not returned.
func (s *Server) apiHardeningRun(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !s.acquireScan() {
		writeJSONError(w, "A scan is already in progress. Please wait.", http.StatusConflict)
		return
	}
	defer s.releaseScan()

	// The audit is long-running, so push this request's write deadline out.
	// Failure to extend it is ignored.
	deadline := time.Now().Add(3 * time.Minute)
	_ = http.NewResponseController(w).SetWriteDeadline(deadline)

	report := checks.RunHardeningAudit(s.cfg)

	db := store.Global()
	if db != nil {
		saveErr := db.SaveHardeningReport(report)
		if saveErr != nil {
			writeJSONError(w, "audit completed but failed to save report: "+saveErr.Error(), http.StatusInternalServerError)
			return
		}
	}
	writeJSON(w, report)
}
// handleHardening renders the hardening audit page (hardening.html). No
// template data is passed (nil).
func (s *Server) handleHardening(w http.ResponseWriter, _ *http.Request) {
s.renderTemplate(w, "hardening.html", nil)
}
package webui
import (
"net/http"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/mailfwd/quarantine"
"github.com/pidginhost/csm/internal/platform"
)
// heldIDRe bounds a held-message id to a Maildir filename shape: 1 to 128
// characters drawn from ASCII letters, digits, '.', '_' and '-'. The store
// also defends (filepath.Base + regular-file check), but validating at the
// handler boundary rejects traversal/control input before it reaches the
// filesystem. Because '.' is allowed, the pattern alone does not exclude
// "..": apiEmailHeldAction adds a separate strings.Contains check for it.
var heldIDRe = regexp.MustCompile(`^[A-Za-z0-9._-]{1,128}$`)
// heldForwardStore is the held-forward quarantine surface the webui needs.
// *quarantine.Quarantine satisfies it; tests use a fake.
type heldForwardStore interface {
// List returns the currently held messages.
List() ([]quarantine.HeldMessage, error)
// Release re-injects the held copy identified by id (used by the
// POST .../release action).
Release(id string) error
// Delete discards the held copy identified by id.
Delete(id string) error
}
// forwardQuarantineDir is the directory handed to quarantine.New when the
// held-forward store is built for a cPanel host (see selectForwardHeld).
const forwardQuarantineDir = "/var/lib/csm/forward_quarantine/held"
// selectForwardHeld builds the held-forward store for this host. Only
// cPanel/exim writes held copies, so every other platform gets nil.
func selectForwardHeld() heldForwardStore {
	host := platform.Detect()
	if !host.IsCPanel() {
		return nil
	}
	return quarantine.New(forwardQuarantineDir)
}
// apiEmailHeldList handles GET /api/v1/email/held and returns the forward
// copies the guard currently holds. The response is always a JSON array: an
// absent store or a nil listing is reported as []. A listing failure
// yields 500.
func (s *Server) apiEmailHeldList(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	held := []quarantine.HeldMessage{}
	if s.forwardHeld != nil {
		listed, listErr := s.forwardHeld.List()
		if listErr != nil {
			writeJSONError(w, "Failed to list held forwards", http.StatusInternalServerError)
			return
		}
		if listed != nil {
			held = listed
		}
	}
	writeJSON(w, held)
}
// apiEmailHeldAction serves the per-message held-forward operations:
//
//	POST   /api/v1/email/held/{id}/release  re-inject the held copy to its external recipient
//	DELETE /api/v1/email/held/{id}          discard the held copy
//
// Both mutate, so they run under auth + CSRF and are audit-logged. The id is
// validated against heldIDRe and must not contain ".."; a host without a
// held-forward store answers 503.
func (s *Server) apiEmailHeldAction(w http.ResponseWriter, r *http.Request) {
	rest := strings.TrimPrefix(r.URL.Path, "/api/v1/email/held/")
	if rest == "" {
		writeJSONError(w, "Missing held message ID", http.StatusBadRequest)
		return
	}

	// Everything before the first "/" is the id; the remainder, if any, is
	// the action.
	id, action, _ := strings.Cut(rest, "/")
	if strings.Contains(id, "..") || !heldIDRe.MatchString(id) {
		writeJSONError(w, "Invalid held message ID", http.StatusBadRequest)
		return
	}

	held := s.forwardHeld
	if held == nil {
		writeJSONError(w, "Forward guard not available on this host", http.StatusServiceUnavailable)
		return
	}

	switch r.Method {
	case http.MethodPost:
		if action != "release" {
			writeJSONError(w, "Unknown action; use /release", http.StatusBadRequest)
			return
		}
		releaseErr := held.Release(id)
		if releaseErr != nil {
			writeJSONError(w, "Failed to release held forward: "+releaseErr.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "email_held_release", id, "re-injected held forward copy to its external recipient")
		writeJSON(w, map[string]string{"status": "released", "id": id})

	case http.MethodDelete:
		if action != "" {
			writeJSONError(w, "Unknown action", http.StatusBadRequest)
			return
		}
		deleteErr := held.Delete(id)
		if deleteErr != nil {
			writeJSONError(w, "Failed to delete held forward: "+deleteErr.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "email_held_delete", id, "deleted held forward copy")
		writeJSON(w, map[string]string{"status": "deleted", "id": id})

	default:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
package webui
import (
"bytes"
"encoding/json"
"fmt"
"html/template"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/integrity"
)
// jsonForScript encodes v as JSON and returns it as template.JS, ready to be
// substituted directly into a <script> block. The characters < > & and the
// code points U+2028 / U+2029 end up in \uXXXX form, so the output cannot
// close the surrounding <script> tag or trip JS line-terminator parsing.
// If v cannot be marshalled the result is the JS literal "null", which keeps
// the enclosing template parseable.
func jsonForScript(v interface{}) template.JS {
	encoded, err := json.Marshal(v)
	if err != nil {
		return template.JS("null")
	}

	// Defense-in-depth: json.Marshal already escapes these by default, but
	// an explicit pass keeps the guarantee even if that default changes or
	// the input arrived pre-encoded.
	escapes := [...]struct{ raw, escaped string }{
		{"<", `\u003c`},
		{">", `\u003e`},
		{"&", `\u0026`},
		{"\u2028", `\u2028`},
		{"\u2029", `\u2029`},
	}
	for _, e := range escapes {
		encoded = bytes.ReplaceAll(encoded, []byte(e.raw), []byte(e.escaped))
	}

	// #nosec G203 -- Output is JSON bytes with HTML/JS-dangerous codepoints
	// escaped above; safe to hand to html/template as JS.
	return template.JS(encoded)
}
// decodeJSONBodyLimited decodes the request body into dst as exactly one
// JSON value.
//
// The body is capped at limit bytes via http.MaxBytesReader (64 KiB when
// limit <= 0), unknown object fields are rejected, and anything other than
// whitespace after the first value is an error. It returns the decoder's
// error for malformed or oversized input, or a "single JSON value" error for
// trailing data.
func decodeJSONBodyLimited(w http.ResponseWriter, r *http.Request, limit int64, dst interface{}) error {
	if limit <= 0 {
		limit = 64 * 1024
	}
	r.Body = http.MaxBytesReader(w, r.Body, limit)
	dec := json.NewDecoder(r.Body)
	dec.DisallowUnknownFields()
	if err := dec.Decode(dst); err != nil {
		return err
	}
	if dec.More() {
		return fmt.Errorf("request body must contain a single JSON value")
	}
	// More also reports false when the next byte is a stray ']' or '}', so a
	// body such as `{"a":1}}` would slip through on that check alone. Token
	// returns a syntax error for such a delimiter, while a clean end of
	// input surfaces as a non-syntax error (io.EOF) and is accepted.
	if _, err := dec.Token(); err != nil {
		if _, trailing := err.(*json.SyntaxError); trailing {
			return fmt.Errorf("request body must contain a single JSON value")
		}
	}
	return nil
}
// validateEximMessageID guards against message ids that would be unsafe to
// interpolate into filesystem paths. Real Exim ids have the form 6-6-2 on
// Exim 4.96 and older and 6-11-4 on Exim 4.97 and newer; rather than pin
// those shapes, the validator accepts any non-empty id of at most 32 bytes
// made up of ASCII letters, digits, and '-'. That keeps short test fixtures
// usable while still blocking `..`, `/`, `\`, dots, NUL bytes, and other
// shell / path-traversal metacharacters even if the inner Quarantine
// layer's filepath.Base() guard regresses.
func validateEximMessageID(id string) error {
	switch n := len(id); {
	case n == 0:
		return fmt.Errorf("message id is required")
	case n > 32:
		return fmt.Errorf("message id too long (%d chars, max 32)", n)
	}

	for _, c := range id {
		switch {
		case 'a' <= c && c <= 'z',
			'A' <= c && c <= 'Z',
			'0' <= c && c <= '9',
			c == '-':
			continue
		}
		return fmt.Errorf("message id contains invalid character: %q", c)
	}
	return nil
}
// mustBeWithin resolves candidate beneath root and verifies that the result
// still lies inside root once symlinks and `..` segments are resolved. On
// success it returns the cleaned absolute path. Leading separators (and any
// volume name) on candidate are dropped first, so an absolute-looking
// candidate is interpreted relative to root.
//
// Defense-in-depth for handlers that join a user-supplied path fragment
// under a trusted root and pass the result to os.Remove / os.RemoveAll /
// etc. Fails when root is empty, when root cannot be made absolute or its
// symlinks cannot be resolved, or when the candidate escapes root.
func mustBeWithin(root, candidate string) (string, error) {
	if root == "" {
		return "", fmt.Errorf("root is required")
	}

	cleanedRoot, absErr := filepath.Abs(filepath.Clean(root))
	if absErr != nil {
		return "", fmt.Errorf("resolve root: %w", absErr)
	}
	realRoot, linkErr := filepath.EvalSymlinks(cleanedRoot)
	if linkErr != nil {
		return "", fmt.Errorf("resolve root symlinks: %w", linkErr)
	}

	resolved, resolveErr := resolvePathUnderRoot(realRoot, rootRelativeCandidate(candidate))
	if resolveErr != nil {
		return "", resolveErr
	}
	if isPathWithin(resolved, realRoot) {
		return resolved, nil
	}
	return "", fmt.Errorf("path %q escapes root %q", candidate, realRoot)
}
// rootRelativeCandidate turns candidate into a root-relative fragment by
// dropping any leading volume name and all leading path separators. An input
// that is empty after stripping becomes ".".
func rootRelativeCandidate(candidate string) string {
	rel := strings.TrimPrefix(candidate, filepath.VolumeName(candidate))
	rel = strings.TrimLeft(rel, string(filepath.Separator))
	if rel != "" {
		return rel
	}
	return "."
}
// resolvePathUnderRoot walks rel one component at a time starting at root,
// resolving symlinks by hand so that every intermediate step can be checked
// against root. It returns the cleaned resulting path, or an error when a
// `..` component climbs above root, when more than 255 symlinks are
// expanded, or when a component cannot be inspected. Components that do not
// exist on disk are accepted lexically.
func resolvePathUnderRoot(root, rel string) (string, error) {
parts := pathParts(rel)
current := root
symlinkCount := 0
for i := 0; i < len(parts); i++ {
part := parts[i]
switch part {
case ".", "":
continue
case "..":
// Step up one level and refuse as soon as that leaves root.
current = filepath.Dir(current)
if !isPathWithin(current, root) {
return "", fmt.Errorf("path %q escapes root %q", rel, root)
}
continue
}
next := filepath.Join(current, part)
info, err := os.Lstat(next)
if err != nil {
if os.IsNotExist(err) {
// Missing component: nothing to resolve, keep walking lexically.
current = next
continue
}
return "", fmt.Errorf("stat candidate: %w", err)
}
if info.Mode()&os.ModeSymlink == 0 {
current = next
continue
}
// Symlink: bound the number of expansions so a link cycle terminates.
symlinkCount++
if symlinkCount > 255 {
return "", fmt.Errorf("too many symlinks resolving %q", rel)
}
target, err := os.Readlink(next)
if err != nil {
return "", fmt.Errorf("read symlink: %w", err)
}
if !filepath.IsAbs(target) {
// Relative targets are relative to the directory holding the link.
target = filepath.Join(filepath.Dir(next), target)
}
target = filepath.Clean(target)
targetRel, err := filepath.Rel(root, target)
if err != nil {
return "", fmt.Errorf("resolve symlink target: %w", err)
}
// Re-express the link target relative to root, append the components
// that have not been consumed yet, and restart the walk from root.
// A target outside root yields leading ".." parts, which the ".."
// case above rejects on the next pass.
nextParts := append([]string{}, pathParts(targetRel)...)
nextParts = append(nextParts, parts[i+1:]...)
parts = nextParts
current = root
i = -1 // the loop's i++ brings this back to 0
}
return filepath.Clean(current), nil
}
// pathParts splits path on the OS separator and drops empty and "."
// components. An empty or "." path yields nil; any other input yields a
// non-nil (possibly empty) slice. ".." components are kept.
func pathParts(path string) []string {
	switch path {
	case "", ".":
		return nil
	}
	segments := strings.Split(path, string(filepath.Separator))
	kept := make([]string, 0, len(segments))
	for _, seg := range segments {
		if seg == "" || seg == "." {
			continue
		}
		kept = append(kept, seg)
	}
	return kept
}
// validateAccountName checks that name is an acceptable cPanel account name:
// 1-64 bytes long, starting with an ASCII letter, and made up solely of ASCII
// letters, digits and underscores. It returns nil when the name is valid and
// an error describing the first violated rule otherwise.
func validateAccountName(name string) error {
	n := len(name)
	if n == 0 {
		return fmt.Errorf("account name is required")
	}
	if n > 64 {
		return fmt.Errorf("account name too long (%d chars, max 64)", n)
	}
	first := name[0]
	startsWithLetter := (first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z')
	if !startsWithLetter {
		return fmt.Errorf("account name must start with a letter")
	}
	for _, r := range name {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_':
			// allowed
		default:
			return fmt.Errorf("account name contains invalid character: %c", r)
		}
	}
	return nil
}
// parseAndValidateIP trims and parses s as an IP address and refuses
// non-routable addresses: loopback, private, link-local (unicast or
// multicast), multicast, unspecified, and the IPv4 broadcast address
// 255.255.255.255. The RFC 5737 documentation ranges (192.0.2.0/24,
// 198.51.100.0/24, 203.0.113.0/24) are intentionally allowed.
//
// It returns the parsed address, or a nil IP and an error naming the
// category that caused the rejection.
func parseAndValidateIP(s string) (net.IP, error) {
	s = strings.TrimSpace(s)
	if s == "" {
		return nil, fmt.Errorf("IP address is required")
	}
	ip := net.ParseIP(s)
	// The cases are evaluated top to bottom, so the first matching category
	// decides the error message.
	switch {
	case ip == nil:
		return nil, fmt.Errorf("invalid IP address: %s", s)
	case ip.IsLoopback():
		return nil, fmt.Errorf("loopback address not allowed: %s", s)
	case ip.IsPrivate():
		return nil, fmt.Errorf("private address not allowed: %s", s)
	case ip.IsLinkLocalUnicast(), ip.IsLinkLocalMulticast():
		return nil, fmt.Errorf("link-local address not allowed: %s", s)
	case ip.IsMulticast():
		return nil, fmt.Errorf("multicast address not allowed: %s", s)
	case ip.IsUnspecified():
		return nil, fmt.Errorf("unspecified address not allowed: %s", s)
	case ip.Equal(net.IPv4bcast):
		// Broadcast: 255.255.255.255
		return nil, fmt.Errorf("broadcast address not allowed: %s", s)
	}
	return ip, nil
}
// validateCIDR parses s as a CIDR block and refuses prefixes that are too
// broad: anything shorter than /8 for IPv4 and anything shorter than /32 for
// IPv6. It returns the parsed network on success.
func validateCIDR(s string) (*net.IPNet, error) {
	if s == "" {
		return nil, fmt.Errorf("CIDR is required")
	}
	_, network, err := net.ParseCIDR(s)
	if err != nil {
		return nil, fmt.Errorf("invalid CIDR: %w", err)
	}
	prefixLen, totalBits := network.Mask.Size()
	floor := 8
	switch totalBits {
	case 128: // IPv6
		floor = 32
	}
	if prefixLen >= floor {
		return network, nil
	}
	return nil, fmt.Errorf("CIDR prefix /%d is too broad (minimum /%d)", prefixLen, floor)
}
// parseDuration parses a human-friendly duration string from the web UI.
// Supported formats: "24h", "7d", "30d", "0" (permanent), "" (permanent).
//
// A trailing "d" means whole days and must be preceded by decimal digits
// only; anything else is handed to time.ParseDuration. Input that cannot be
// parsed yields 0, the same value used for "permanent". A day count too
// large to be represented as a time.Duration also yields 0 instead of
// silently wrapping around into a negative or garbage duration.
func parseDuration(s string) time.Duration {
	// Largest whole number of days that still fits in a time.Duration.
	const maxDays = int64((1<<63 - 1) / (24 * time.Hour))

	s = strings.TrimSpace(s)
	if s == "" || s == "0" {
		return 0
	}
	if strings.HasSuffix(s, "d") {
		var days int64
		for _, c := range strings.TrimSuffix(s, "d") {
			if c < '0' || c > '9' {
				return 0
			}
			days = days*10 + int64(c-'0')
			// Checking after every digit keeps the accumulator itself from
			// overflowing on very long digit strings.
			if days > maxDays {
				return 0
			}
		}
		return time.Duration(days) * 24 * time.Hour
	}
	if d, err := time.ParseDuration(s); err == nil {
		return d
	}
	return 0
}
// isPathUnder returns true if the cleaned path is strictly under the base
// directory (path equal to base does not count). It prevents path traversal
// via ".." and prefix tricks (e.g., /home/username is not under /home/user).
//
// The comparison is purely lexical: both arguments are cleaned but symlinks
// are not resolved.
func isPathUnder(path, base string) bool {
	cleanPath := filepath.Clean(path)
	cleanBase := filepath.Clean(base)
	if cleanPath == cleanBase {
		return false
	}
	// Ensure base ends with separator so "/home/user" doesn't match
	// "/home/username". A cleaned filesystem root already ends with the
	// separator; appending a second one would make the prefix "//" and
	// nothing would ever be considered under the root.
	prefix := cleanBase
	if !strings.HasSuffix(prefix, string(filepath.Separator)) {
		prefix += string(filepath.Separator)
	}
	return strings.HasPrefix(cleanPath, prefix)
}
// isPathWithin reports whether the cleaned path is base itself or lies
// beneath it. The comparison is lexical (no symlink resolution) and is
// separator-aware, so "/home/username" is not within "/home/user".
func isPathWithin(path, base string) bool {
	cleanPath := filepath.Clean(path)
	cleanBase := filepath.Clean(base)
	if cleanPath == cleanBase {
		return true
	}
	// A cleaned filesystem root already ends with the separator; appending
	// another would produce the prefix "//", which no cleaned path has.
	prefix := cleanBase
	if !strings.HasSuffix(prefix, string(filepath.Separator)) {
		prefix += string(filepath.Separator)
	}
	return strings.HasPrefix(cleanPath, prefix)
}
// quarantineLiveState compares the file currently at originalPath with the
// quarantined copy at archivePath and returns a short state label:
//
//   - "original_missing" / "archive_missing": the respective file does not exist
//   - "original_not_file" / "archive_not_file": the path is not a regular file
//   - "live_differs": sizes or content hashes differ
//   - "restored_identical": sizes and content hashes match
//   - "unknown": a stat or hash operation failed for another reason
//
// The original is inspected first, so its state wins when both are absent.
func quarantineLiveState(archivePath, originalPath string) string {
	liveInfo, err := os.Stat(originalPath)
	switch {
	case os.IsNotExist(err):
		return "original_missing"
	case err != nil:
		return "unknown"
	case !liveInfo.Mode().IsRegular():
		return "original_not_file"
	}

	archivedInfo, err := os.Stat(archivePath)
	switch {
	case os.IsNotExist(err):
		return "archive_missing"
	case err != nil:
		return "unknown"
	case !archivedInfo.Mode().IsRegular():
		return "archive_not_file"
	}

	// Differing sizes settle the question without hashing either file.
	if liveInfo.Size() != archivedInfo.Size() {
		return "live_differs"
	}

	liveSum, err := integrity.HashFile(originalPath)
	if err != nil {
		return "unknown"
	}
	archivedSum, err := integrity.HashFile(archivePath)
	if err != nil {
		return "unknown"
	}
	if liveSum != archivedSum {
		return "live_differs"
	}
	return "restored_identical"
}
// preCleanQuarantineIDPrefix marks quarantine IDs whose files live in the
// "pre_clean" subdirectory of the quarantine directory rather than in the
// quarantine directory itself.
const preCleanQuarantineIDPrefix = "pre_clean:"
// quarantineEntryRef is a resolved quarantine entry: the ID it was looked up
// by plus the on-disk locations derived from that ID.
type quarantineEntryRef struct {
// ID is the whitespace-trimmed ID as supplied by the caller, including
// any "pre_clean:" prefix.
ID string
// ItemPath is the quarantined item's path under the quarantine base
// directory.
ItemPath string
// MetaPath is ItemPath with ".meta" appended (the JSON sidecar).
MetaPath string
}
// quarantineEntryID derives the public quarantine ID from the path of a
// .meta sidecar: the file name without its ".meta" suffix, prefixed with
// "pre_clean:" when the sidecar sits directly in the quarantine directory's
// "pre_clean" subdirectory.
func quarantineEntryID(metaPath string) string {
	preCleanDir := filepath.Join(quarantineDir, "pre_clean")
	base := strings.TrimSuffix(filepath.Base(metaPath), ".meta")
	if filepath.Clean(filepath.Dir(metaPath)) != preCleanDir {
		return base
	}
	return preCleanQuarantineIDPrefix + base
}
// resolveQuarantineEntry maps a quarantine ID to the item and sidecar paths
// it refers to. IDs starting with "pre_clean:" resolve under the quarantine
// directory's "pre_clean" subdirectory; all others resolve directly under
// the quarantine directory. Only the final path element of the ID is used,
// so directory components in the ID cannot steer the result elsewhere.
//
// An error is returned when the ID is empty or does not name an entry
// inside the base directory.
func resolveQuarantineEntry(id string) (quarantineEntryRef, error) {
	rawID := strings.TrimSpace(id)
	if rawID == "" {
		return quarantineEntryRef{}, fmt.Errorf("quarantine ID is required")
	}
	baseDir := quarantineDir
	name := rawID
	if strings.HasPrefix(rawID, preCleanQuarantineIDPrefix) {
		baseDir = filepath.Join(quarantineDir, "pre_clean")
		name = strings.TrimPrefix(rawID, preCleanQuarantineIDPrefix)
	}
	name = filepath.Base(name)
	// filepath.Base returns "." for an empty path and a lone separator for a
	// path made only of separators. Joining either onto baseDir would yield
	// baseDir itself, so an ID such as "/" must not resolve to the whole
	// quarantine directory.
	if name == "" || name == "." || name == ".." || name == string(filepath.Separator) {
		return quarantineEntryRef{}, fmt.Errorf("invalid quarantine ID")
	}
	itemPath := filepath.Join(baseDir, name)
	if !isPathWithin(itemPath, baseDir) {
		return quarantineEntryRef{}, fmt.Errorf("invalid quarantine ID")
	}
	return quarantineEntryRef{
		ID:       rawID,
		ItemPath: itemPath,
		MetaPath: itemPath + ".meta",
	}, nil
}
// quarantineMeta represents the JSON sidecar metadata for a quarantined file.
// Must match checks.QuarantineMeta on-disk format.
type quarantineMeta struct {
OriginalPath string `json:"original_path"` // path the file was quarantined from
Owner int `json:"owner_uid"` // presumably the original owner's numeric UID -- confirm against checks.QuarantineMeta
Group int `json:"group_gid"` // presumably the original group's numeric GID -- confirm against checks.QuarantineMeta
Mode string `json:"mode"` // file mode stored as a string; exact encoding is defined by the writer
Size int64 `json:"size"`
QuarantineAt time.Time `json:"quarantined_at"` // note: the JSON key is "quarantined_at"
Reason string `json:"reason"`
}
// readQuarantineMeta loads the quarantine sidecar at metaPath and decodes its
// JSON content. Read and parse failures are returned wrapped with context.
func readQuarantineMeta(metaPath string) (*quarantineMeta, error) {
	// #nosec G304 -- metaPath is constructed by resolveQuarantineEntry under
	// the quarantine base dir with filepath.Base applied to the ID.
	raw, err := os.ReadFile(metaPath)
	if err != nil {
		return nil, fmt.Errorf("read quarantine meta: %w", err)
	}
	parsed := new(quarantineMeta)
	if err = json.Unmarshal(raw, parsed); err != nil {
		return nil, fmt.Errorf("parse quarantine meta: %w", err)
	}
	return parsed, nil
}
// listMetaFiles collects the full paths of the ".meta" files that sit
// directly in dir; subdirectories are neither descended into nor reported.
// It returns nil when dir cannot be read (e.g., it does not exist) or holds
// no such files.
func listMetaFiles(dir string) []string {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil
	}
	var found []string
	for _, entry := range entries {
		name := entry.Name()
		if !entry.IsDir() && strings.HasSuffix(name, ".meta") {
			found = append(found, filepath.Join(dir, name))
		}
	}
	return found
}
// quarantineRestoreRoots lists the directory trees a quarantined file may be
// restored into; validateQuarantineRestorePath rejects any destination that
// falls outside all of them.
var quarantineRestoreRoots = []string{"/home", "/tmp", "/dev/shm", "/var/tmp"}
// validateQuarantineRestorePath checks that path is an acceptable destination
// for restoring a quarantined file and returns its cleaned form.
//
// The path must be non-empty, absolute, and lexically inside one of
// quarantineRestoreRoots. Its nearest existing ancestor is then resolved
// through symlinks and must still be inside an allowed root; for paths
// under /home/<account>/ the resolved ancestor must also stay inside that
// account's directory. The returned path is the cleaned input, not the
// symlink-resolved one.
func validateQuarantineRestorePath(path string) (string, error) {
	// Test for emptiness before cleaning: filepath.Clean("") returns ".",
	// which would otherwise be reported as "must be absolute".
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return "", fmt.Errorf("restore path is required")
	}
	cleanPath := filepath.Clean(trimmed)
	if !filepath.IsAbs(cleanPath) {
		return "", fmt.Errorf("restore path must be absolute")
	}
	if !pathWithinAny(cleanPath, quarantineRestoreRoots) {
		return "", fmt.Errorf("restore path is outside the allowed restore roots: %s", cleanPath)
	}
	// The destination may not exist yet, so resolve symlinks on the deepest
	// ancestor that does.
	ancestor, err := nearestExistingAncestor(cleanPath)
	if err != nil {
		return "", err
	}
	resolvedAncestor, err := filepath.EvalSymlinks(ancestor)
	if err != nil {
		return "", fmt.Errorf("cannot validate restore path: %w", err)
	}
	if !pathWithinAny(resolvedAncestor, quarantineRestoreRoots) {
		return "", fmt.Errorf("restore path escapes the allowed restore roots: %s", cleanPath)
	}
	if accountRoot := homeAccountRoot(cleanPath); accountRoot != "" && !isPathWithin(resolvedAncestor, accountRoot) {
		return "", fmt.Errorf("restore path escapes the account boundary: %s", cleanPath)
	}
	return cleanPath, nil
}
// pathWithinAny reports whether path lies within at least one of bases. Each
// base is tried as given and, when it can be resolved, in its
// symlink-resolved form as well.
func pathWithinAny(path string, bases []string) bool {
	for _, base := range bases {
		if isPathWithin(path, base) {
			return true
		}
		resolved, err := filepath.EvalSymlinks(base)
		if err != nil {
			continue
		}
		if isPathWithin(path, resolved) {
			return true
		}
	}
	return false
}
// nearestExistingAncestor climbs from path towards the filesystem root and
// returns the first path that exists according to os.Lstat, which may be
// path itself. It fails when a stat error other than "does not exist"
// occurs or when no ancestor exists at all.
func nearestExistingAncestor(path string) (string, error) {
	candidate := filepath.Clean(path)
	for {
		_, err := os.Lstat(candidate)
		switch {
		case err == nil:
			return candidate, nil
		case !os.IsNotExist(err):
			return "", fmt.Errorf("cannot stat restore path: %w", err)
		}
		up := filepath.Dir(candidate)
		if up == candidate {
			// Dir no longer changes the path: the top has been reached.
			return "", fmt.Errorf("restore path has no existing ancestor: %s", path)
		}
		candidate = up
	}
}
// homeAccountRoot returns "/home/<account>" for a path that lies strictly
// below an account directory (at least one element after the account name).
// It returns "" for paths outside /home/ and for the account directory
// itself.
func homeAccountRoot(path string) string {
	cleaned := filepath.Clean(path)
	if !strings.HasPrefix(cleaned, "/home/") {
		return ""
	}
	// "/home/user/x" splits into ["", "home", "user", "x"].
	segments := strings.Split(cleaned, string(filepath.Separator))
	if len(segments) >= 4 {
		return filepath.Join("/home", segments[2])
	}
	return ""
}
package webui
import (
"fmt"
"net/http"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/incident"
)
// timelineEvent is one row of the incident timeline returned by apiIncident.
type timelineEvent struct {
Timestamp string `json:"timestamp"` // event time formatted as RFC3339
Type string `json:"type"` // "finding", "action", "block"
Severity int `json:"severity"` // integer severity of the source finding/incident (alert.Severity: 0=warning, 1=high, 2=critical); audit actions use 0
Summary string `json:"summary"`
Details string `json:"details,omitempty"`
Source string `json:"source"` // "history", "audit", "firewall", or "incident:<id>" for correlator events
}
// incidentTimelineEventLimit is the maximum number of events a timeline
// response carries. It also bounds how many matching history rows are
// requested from the store.
const incidentTimelineEventLimit = 200
// incidentSnapshotScanCap bounds how many incidents the timeline walk
// will inspect from the correlator snapshot per request. Correlators on
// hot hosts can hold thousands of open + recently-closed incidents and
// each one carries a full event timeline; without this cap a single
// /api/v1/incident call can walk hundreds of thousands of timeline
// events before paginating down to incidentTimelineEventLimit. The cap
// is generous (most timelines hit the 200-event ceiling well before the
// 1000-incident ceiling) but bounds worst-case wall time and memory.
// When the correlator holds more incidents than this, the response is
// flagged as truncated.
const incidentSnapshotScanCap = 1000
// handleIncident renders the incident investigation page. The template only
// needs the configured hostname.
func (s *Server) handleIncident(w http.ResponseWriter, _ *http.Request) {
	pageData := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "incident.html", pageData)
}
// apiIncident serves the investigation timeline for an IP and/or account.
//
// Query parameters:
//   - ip, account: at least one is required (400 otherwise).
//   - hours: look-back window, default 72; non-positive values fall back to
//     the default and values above 720 (30 days) are clamped.
//
// Events are gathered from three sources -- finding history, the incident
// correlator (when configured) and the UI audit log -- then ordered newest
// first by their actual instants and capped at incidentTimelineEventLimit.
// When the cap or the correlator scan limit cut data off, the response
// carries "truncated": true and the X-CSM-Truncated header.
func (s *Server) apiIncident(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	account := r.URL.Query().Get("account")
	if ip == "" && account == "" {
		writeJSONError(w, "ip or account parameter is required", http.StatusBadRequest)
		return
	}

	hours := queryInt(r, "hours", 72)
	if hours <= 0 {
		// A zero or negative window would put the cutoff at or after "now"
		// and silently return an empty timeline.
		hours = 72
	}
	if hours > 720 {
		hours = 720 // max 30 days
	}
	cutoff := time.Now().Add(-time.Duration(hours) * time.Hour)

	// stampedEvent keeps the real instant next to the wire representation so
	// ordering does not depend on comparing formatted strings, which breaks
	// when timestamps carry different UTC offsets and drops sub-second
	// precision.
	type stampedEvent struct {
		at    time.Time
		event timelineEvent
	}
	var stamped []stampedEvent

	// Build search terms
	var searchTerms []string
	if ip != "" {
		searchTerms = append(searchTerms, ip)
	}
	if account != "" {
		searchTerms = append(searchTerms, "/home/"+account+"/", account)
	}

	// dedup is shared between the history and correlator passes so the same
	// event reported by both sources appears only once.
	dedup := make(map[string]struct{})
	dedupKey := func(t time.Time, summary string) string {
		return t.UTC().Format(time.RFC3339Nano) + "|" + summary
	}

	matchesHistoryQuery := func(f alert.Finding) bool {
		matched := false
		for _, term := range searchTerms {
			if strings.Contains(f.Message, term) || strings.Contains(f.Details, term) {
				matched = true
				break
			}
		}
		if !matched {
			return false
		}
		summary := f.Check + ": " + f.Message
		key := dedupKey(f.Timestamp, summary)
		if _, seen := dedup[key]; seen {
			return false
		}
		dedup[key] = struct{}{}
		return true
	}

	// Search newest-first and stop once the timeline has enough matching
	// history rows. Busy hosts can retain large 30-day windows, so response
	// size alone is not a safe bound for the read path.
	allHistory := s.store.SearchHistorySince(cutoff, incidentTimelineEventLimit, matchesHistoryQuery)
	for _, f := range allHistory {
		stamped = append(stamped, stampedEvent{
			at: f.Timestamp,
			event: timelineEvent{
				Timestamp: f.Timestamp.Format(time.RFC3339),
				Type:      "finding",
				Severity:  int(f.Severity),
				Summary:   f.Check + ": " + f.Message,
				Details:   f.Details,
				Source:    "history",
			},
		})
	}

	// Fold in events from the incident correlator. The finding history
	// bucket rotates aggressively on busy hosts so a Critical incident
	// from two days ago may have no surviving history row, but the
	// incident object still carries the full timeline. Walk every
	// incident, match by RemoteIP for IP queries or by Account / Mailbox /
	// Domain for account queries, and emit each matching timeline event.
	truncated := false
	if s.incidentCorrelator != nil {
		snap, totalIncidents := s.incidentCorrelator.SnapshotPageStatuses(nil, 0, incidentSnapshotScanCap)
		truncated = totalIncidents > len(snap)
		for _, inc := range snap {
			incMatches := incidentMatchesAccount(inc, account)
			for _, ev := range inc.Timeline {
				if ev.Time.Before(cutoff) {
					continue
				}
				ipMatches := ip != "" && ev.RemoteIP == ip
				if !ipMatches && !incMatches {
					continue
				}
				summary := ev.Check
				if ev.Message != "" {
					if summary != "" {
						summary += ": "
					}
					summary += ev.Message
				}
				key := dedupKey(ev.Time, summary)
				if _, seen := dedup[key]; seen {
					continue
				}
				dedup[key] = struct{}{}
				stamped = append(stamped, stampedEvent{
					at: ev.Time,
					event: timelineEvent{
						Timestamp: ev.Time.Format(time.RFC3339),
						Type:      "finding",
						Severity:  int(inc.Severity),
						Summary:   summary,
						Details:   "From incident " + inc.ID + " (" + string(inc.Kind) + ", " + string(inc.Status) + ")",
						Source:    "incident:" + inc.ID,
					},
				})
			}
		}
	}

	// Search UI audit log
	auditEntries := readUIAuditLog(s.cfg.StatePath, 500)
	for _, a := range auditEntries {
		if a.Timestamp.Before(cutoff) {
			continue
		}
		matched := false
		for _, term := range searchTerms {
			if strings.Contains(a.Target, term) || strings.Contains(a.Details, term) {
				matched = true
				break
			}
		}
		if !matched {
			continue
		}
		stamped = append(stamped, stampedEvent{
			at: a.Timestamp,
			event: timelineEvent{
				Timestamp: a.Timestamp.Format(time.RFC3339),
				Type:      "action",
				Severity:  0,
				Summary:   a.Action + ": " + a.Target,
				Details:   a.Details,
				Source:    "audit",
			},
		})
	}

	// Sort by timestamp descending (newest first). The stable sort keeps
	// events that share an instant in collection order.
	sort.SliceStable(stamped, func(i, j int) bool {
		return stamped[i].at.After(stamped[j].at)
	})
	if len(stamped) > incidentTimelineEventLimit {
		stamped = stamped[:incidentTimelineEventLimit]
		truncated = true
	}

	// events stays nil when nothing matched so the JSON shape of an empty
	// result is unchanged.
	var events []timelineEvent
	if len(stamped) > 0 {
		events = make([]timelineEvent, len(stamped))
		for i := range stamped {
			events[i] = stamped[i].event
		}
	}

	if truncated {
		w.Header().Set("X-CSM-Truncated", "1")
	}
	writeJSON(w, map[string]interface{}{
		"events":        events,
		"total":         len(events),
		"query_ip":      ip,
		"query_account": account,
		"hours":         hours,
		"truncated":     truncated,
	})
}
// incidentMatchesAccount tells whether account equals one of the incident's
// identity fields (Account, Mailbox or Domain). An empty account never
// matches, which keeps IP-only queries from pulling in unrelated incidents.
func incidentMatchesAccount(inc incident.Incident, account string) bool {
	switch account {
	case "":
		return false
	case inc.Account, inc.Mailbox, inc.Domain:
		return true
	}
	return false
}
// maxIncidentPageSize caps the page size a client may request so a
// misbehaving consumer cannot OOM the daemon by asking for the whole
// world in one round-trip. The web UI's default page size is well
// below this; the ceiling exists for defense in depth. Larger requested
// limits are clamped to this value rather than rejected.
const maxIncidentPageSize = 500
// defaultIncidentPageSize is applied when the client requests a paged
// shape (any of limit/offset/status set) but does not pass an explicit
// limit, or passes a non-positive one. Tuned to fit comfortably on one
// screen.
const defaultIncidentPageSize = 50
// apiIncidentList serves GET /api/v1/incidents.
//
// Default (no query parameters): returns the full Snapshot as a bare
// JSON array, preserving the wire shape the existing API consumers
// (phpanel, SIEM tooling) decode against.
//
// When the client passes any of ?limit=, ?offset=, ?status=, the
// response switches to an envelope: {"items":[...], "total":N,
// "offset":N, "limit":N, "status":"..."}. The envelope always includes
// all five fields so the client can render an accurate page header
// without a second probe. This also holds when the incident correlator
// is not configured: a paged request then receives an envelope with no
// items and total 0 rather than a bare array, and its parameters are
// still validated.
//
// status accepts the four spec values (open/contained/resolved/dismissed)
// plus the UI-only convenience "active" that means
// open+contained. An empty string means all statuses. Anything else is
// rejected with 400 Bad Request rather than silently widening to all,
// which would hide a typo like ?status=opn.
//
// limit defaults to defaultIncidentPageSize when absent or non-positive
// and is clamped to maxIncidentPageSize; a negative offset is treated as 0.
func (s *Server) apiIncidentList(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	hasPagingParams := q.Has("limit") || q.Has("offset") || q.Has("status")
	if !hasPagingParams {
		if s.incidentCorrelator == nil {
			writeJSON(w, []incident.Incident{})
			return
		}
		writeJSON(w, s.incidentCorrelator.Snapshot())
		return
	}

	statusParam := q.Get("status")
	statuses, err := parseIncidentStatusFilter(statusParam)
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	limit := queryInt(r, "limit", defaultIncidentPageSize)
	if limit <= 0 {
		limit = defaultIncidentPageSize
	}
	if limit > maxIncidentPageSize {
		limit = maxIncidentPageSize
	}
	offset := queryInt(r, "offset", 0)
	if offset < 0 {
		offset = 0
	}

	// Without a correlator there is nothing to page through, but the client
	// asked for the envelope shape and must still get it.
	items := []incident.Incident{}
	total := 0
	if s.incidentCorrelator != nil {
		items, total = s.incidentPage(statuses, offset, limit)
	}
	writeJSON(w, map[string]any{
		"items":  items,
		"total":  total,
		"offset": offset,
		"limit":  limit,
		"status": statusParam,
	})
}
// parseIncidentStatusFilter checks the status query parameter and expands it
// into the statuses it selects. The empty string selects everything (nil
// slice); "active" is the UI-only shorthand for open+contained; each of the
// four concrete statuses selects itself. Any other value is an error.
func parseIncidentStatusFilter(s string) ([]incident.Status, error) {
	if s == "" {
		return nil, nil
	}
	if s == "active" {
		return []incident.Status{incident.StatusOpen, incident.StatusContained}, nil
	}
	known := []incident.Status{
		incident.StatusOpen,
		incident.StatusContained,
		incident.StatusResolved,
		incident.StatusDismissed,
	}
	for _, candidate := range known {
		if s == string(candidate) {
			return []incident.Status{candidate}, nil
		}
	}
	return nil, fmt.Errorf("invalid status %q", s)
}
// incidentPage returns a status-filtered page. The "active" filter
// expands to open+contained and is handled by the correlator in one
// sorted pass so pagination stays stable across statuses.
//
// It forwards directly to the correlator's SnapshotPageStatuses and returns
// that call's two results unchanged (used by apiIncidentList as the page
// items and the total). The correlator is dereferenced unconditionally, so
// callers must make sure s.incidentCorrelator is non-nil.
func (s *Server) incidentPage(statuses []incident.Status, offset, limit int) ([]incident.Incident, int) {
return s.incidentCorrelator.SnapshotPageStatuses(statuses, offset, limit)
}
// apiIncidentShow serves GET /api/v1/incidents/<id> and writes the incident
// as JSON. It answers 404 when the id is empty, the correlator is not
// configured, or no incident with that id exists.
func (s *Server) apiIncidentShow(w http.ResponseWriter, r *http.Request) {
	id := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/api/v1/incidents/"), "/")
	if id != "" && s.incidentCorrelator != nil {
		if inc, found := s.incidentCorrelator.Get(id); found {
			writeJSON(w, inc)
			return
		}
	}
	http.NotFound(w, r)
}
// apiIncidentStatus serves POST /api/v1/incidents/<id>/status. Body
// {"status": "resolved", "details": "..."}.
//
// Responses: 400 for an unreadable body or a status change the correlator
// rejects, 503 when the correlator is not configured, and 200 with the JSON
// body {"ok":true} on success.
func (s *Server) apiIncidentStatus(w http.ResponseWriter, r *http.Request) {
	id := strings.TrimPrefix(r.URL.Path, "/api/v1/incidents/")
	id = strings.TrimSuffix(id, "/status")
	id = strings.Trim(id, "/")
	var body struct {
		Status  string `json:"status"`
		Details string `json:"details"`
	}
	// Cap the request body like every other mutating handler; a bare
	// json.NewDecoder(r.Body) would buffer an unbounded body into memory.
	if err := decodeJSONBodyLimited(w, r, 16*1024, &body); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	if s.incidentCorrelator == nil {
		http.Error(w, "incidents not enabled", http.StatusServiceUnavailable)
		return
	}
	if err := s.incidentCorrelator.SetStatus(id, incident.Status(body.Status), body.Details); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Declare the payload type explicitly; without it net/http sniffs the
	// body and labels this JSON document as text/plain.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(`{"ok":true}`))
}
// apiIncidentRouter fans /api/v1/incidents/<id>[...] requests out to the
// matching handler: a POST whose path ends in "/status" goes to
// apiIncidentStatus, every other request goes to apiIncidentShow.
func (s *Server) apiIncidentRouter(w http.ResponseWriter, r *http.Request) {
	isStatusUpdate := r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/status")
	if !isStatusUpdate {
		s.apiIncidentShow(w, r)
		return
	}
	s.apiIncidentStatus(w, r)
}
package webui
import (
"net/http"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/incident"
)
// incidentGroupsDefaultLimit / Max bound how many group rows the UI
// receives. A typical busy production host emits ~12 attacker-IP rows;
// the cap of 200 leaves headroom for the long tail without unbounded
// payload size. apiIncidentGroups uses the default both when no limit is
// given and when the requested limit is non-positive or above the max.
const (
incidentGroupsDefaultLimit = 50
incidentGroupsMaxLimit = 200
)
// apiIncidentGroups handles GET /api/v1/incidents/groups. Buckets the
// in-memory incident snapshot by (kind, source) and returns rolled-up
// group rows. Read-scope eligible; never mutates state.
//
// Query parameters:
//   - limit: number of groups; values that are non-positive or above
//     incidentGroupsMaxLimit fall back to incidentGroupsDefaultLimit.
//   - offset: negative values are treated as 0.
//   - kind: optional incident kind, whitespace-trimmed.
//   - status: matched case-insensitively after trimming. "" or "active"
//     selects open+contained, "all" disables the filter, and each concrete
//     status selects itself; anything else is a 400.
func (s *Server) apiIncidentGroups(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	q := r.URL.Query()
	limit := queryInt(r, "limit", incidentGroupsDefaultLimit)
	if limit <= 0 || limit > incidentGroupsMaxLimit {
		limit = incidentGroupsDefaultLimit
	}
	offset := queryInt(r, "offset", 0)
	if offset < 0 {
		offset = 0
	}
	filter := incident.GroupFilter{
		Kind:      incident.Kind(strings.TrimSpace(q.Get("kind"))),
		Offset:    offset,
		MaxGroups: limit,
	}
	// Normalize once and build the filter from the normalized value. Using
	// the raw parameter would let "?status=OPEN" pass validation and then
	// filter on a status string that matches no incident.
	status := strings.ToLower(strings.TrimSpace(q.Get("status")))
	switch status {
	case "", "active":
		// Default surface: open + contained, the UI's primary tab.
		filter.StatusSet = []incident.Status{incident.StatusOpen, incident.StatusContained}
	case "all":
		// No status filter; the operator wants the full picture.
	case string(incident.StatusOpen),
		string(incident.StatusContained),
		string(incident.StatusResolved),
		string(incident.StatusDismissed):
		filter.StatusSet = []incident.Status{incident.Status(status)}
	default:
		writeJSONError(w, "unknown status: "+strconv.Quote(q.Get("status")), http.StatusBadRequest)
		return
	}
	if s.incidentCorrelator == nil {
		writeJSON(w, incident.GroupsResponse{Groups: []incident.Group{}})
		return
	}
	resp := incident.BuildGroups(s.incidentCorrelator.Snapshot(), filter)
	writeJSON(w, resp)
}
package webui
import (
"crypto/subtle"
"fmt"
"net/http"
"os"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/metrics"
)
// handleMetrics exposes the process default metrics registry in Prometheus
// text exposition format. ROADMAP item 4.
//
// Only GET and HEAD are accepted (405 otherwise). Authentication is
// delegated to isMetricsAuthenticated:
//
//   - If `webui.metrics_token` is set in the live config, a matching
//     `Authorization: Bearer` header unlocks the endpoint. The token
//     is read from config.Active() per request so a SIGHUP rotation
//     (the field is tagged `hotreload:"safe"`) takes effect without
//     a restart.
//   - As a fallback, a valid UI session cookie or the UI AuthToken
//     Bearer is accepted so the dashboard can self-scrape without a
//     second credential.
//
// Unauthenticated requests get 401 with a Bearer challenge. HEAD requests
// receive the response headers only.
//
// No CSRF required: metrics is read-only and idempotent.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet, http.MethodHead:
		// supported
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !s.isMetricsAuthenticated(r) {
		w.Header().Set("WWW-Authenticate", `Bearer realm="csm-metrics"`)
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
	hdr.Set("Cache-Control", "no-store")
	if r.Method == http.MethodHead {
		return
	}
	err := metrics.WriteOpenMetrics(w)
	if err == nil {
		return
	}
	// The response body is already partially written; there is no
	// meaningful HTTP status to flip. Drop a server log line.
	fmt.Fprintf(os.Stderr, "webui: metrics WriteOpenMetrics: %v\n", err)
}
// isMetricsAuthenticated decides whether r may read the metrics endpoint. A
// request is accepted when it presents the configured metrics token as a
// Bearer credential, or, failing that, when it passes the regular UI
// authentication check.
func (s *Server) isMetricsAuthenticated(r *http.Request) bool {
	// Read the metrics token from config.Active() when available so a
	// SIGHUP-driven rotation of webui.metrics_token takes effect on
	// the next request. MetricsToken is tagged `hotreload:"safe"` for
	// this reason, even though its WebUI parent is restart-required
	// for the listener/TLS/auth-token fields. Fall back to s.cfg on
	// the cold-start window before SetActive is called.
	expected := s.cfg.WebUI.MetricsToken
	if live := config.Active(); live != nil {
		expected = live.WebUI.MetricsToken
	}

	const scheme = "Bearer "
	header := r.Header.Get("Authorization")
	// An empty configured token disables this path entirely; the header
	// must carry at least one character after the scheme.
	if expected != "" && len(header) > len(scheme) && header[:len(scheme)] == scheme {
		presented := header[len(scheme):]
		if subtle.ConstantTimeCompare([]byte(presented), []byte(expected)) == 1 {
			return true
		}
	}

	// Fall back to the UI session / AuthToken path so the dashboard
	// can scrape itself without a second credential.
	return s.isAuthenticated(r)
}
package webui
import (
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/store"
)
// handleModSec renders the ModSecurity page. The template only needs the
// configured hostname.
func (s *Server) handleModSec(w http.ResponseWriter, _ *http.Request) {
	pageData := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "modsec.html", pageData)
}
// modsecBlockView is an aggregated view of blocks per IP+rule. Phase 8.4
// extends the response with first_seen / last_seen_iso (RFC3339), top_uris,
// domain_count, and sample_events so the workbench can drive the detail
// panel without a second round trip. The extra fields are additive; legacy
// field names keep their JSON keys.
type modsecBlockView struct {
IP string `json:"ip"`
RuleID string `json:"rule_id"`
Description string `json:"description"`
Domains string `json:"domains"` // sorted domain names joined with ", ", cut to 80 bytes with a trailing "..."
DomainList []string `json:"domain_list,omitempty"` // sorted, untruncated domain names
DomainCount int `json:"domain_count"`
Hits int `json:"hits"`
LastSeen string `json:"last_seen"` // legacy time-of-day form, "15:04:05"; empty when unknown
FirstSeen string `json:"first_seen"` // RFC3339 in UTC; empty when unknown
LastSeenISO string `json:"last_seen_iso"` // RFC3339 in UTC; empty when unknown
TopURIs []string `json:"top_uris"` // up to 5 URIs selected by topKeysByCount
SampleEvents []modsecSampleEvent `json:"sample_events"` // at most 3 events per aggregate
Escalated bool `json:"escalated"` // true when an escalation finding exists for this IP
}
// modsecSampleEvent is a compact per-IP event included in the grouped
// blocks response so the UI can show recent activity without a second
// call to /api/v1/modsec/events.
type modsecSampleEvent struct {
Time string `json:"time"` // RFC3339 in UTC
RuleID string `json:"rule_id"`
Hostname string `json:"hostname"`
URI string `json:"uri"`
Severity string `json:"severity"` // finding severity rendered via its String method
}
// modsecEventView is a single ModSecurity event as returned by the events
// endpoint.
type modsecEventView struct {
Time string `json:"time"` // time of day, "15:04:05"
IP string `json:"ip"`
RuleID string `json:"rule_id"`
Hostname string `json:"hostname"`
URI string `json:"uri"`
Severity string `json:"severity"`
}
// apiModSecStats returns 24h summary stats for ModSecurity blocks: the
// number of (deduplicated) findings, the number of distinct source IPs, the
// number of escalation findings, and the most frequently hit rule ("--"
// when no finding carries a rule). When several rules share the highest
// count, the lexicographically smallest rule ID is reported so repeated
// requests over the same data give the same answer.
func (s *Server) apiModSecStats(w http.ResponseWriter, _ *http.Request) {
	findings := deduplicateModSecFindings(s.modsecFindings24h())
	uniqueIPs := make(map[string]bool)
	ruleCounts := make(map[string]int)
	escalated := 0
	for _, f := range findings {
		ip := extractModSecIP(f)
		if ip != "" {
			uniqueIPs[ip] = true
		}
		rule := extractModSecRule(f)
		if rule != "" {
			ruleCounts[rule]++
		}
		if isModSecEscalation(f.Check) {
			escalated++
		}
	}
	topRule := "--"
	topCount := 0
	for rule, count := range ruleCounts {
		// Map iteration order is random; the tie-break keeps the winner
		// independent of it.
		if count > topCount || (count == topCount && rule < topRule) {
			topCount = count
			topRule = rule
		}
	}
	writeJSON(w, map[string]interface{}{
		"total":      len(findings),
		"unique_ips": len(uniqueIPs),
		"escalated":  escalated,
		"top_rule":   topRule,
	})
}
// Bounds applied by the ModSecurity read endpoints.
const (
// modsecFindingsScanCap bounds the 24h history read before per-IP
// aggregation starts. The aggregate map has its own cap below, but
// the handler still needs a findings cap for hosts seeing one or two
// hot IP+rule pairs millions of times.
modsecFindingsScanCap = 10000
// modsecBlocksMaxAggregates caps the IP+rule aggregation map so a host
// with millions of unique ModSec rule hits cannot OOM the daemon by
// asking for /api/v1/modsec/blocks. Existing aggregates keep updating
// after the cap is reached; new IP+rule keys past the cap are dropped
// from the response, and the X-CSM-Truncated response header is set so
// monitoring can flag the condition. Default sized for ~50 MB peak:
// 50000 entries times a few hundred bytes per aggregate.
modsecBlocksMaxAggregates = 50000
)
// apiModSecBlocks returns aggregated blocks per IP+rule for the last 24h.
//
// Findings are deduplicated, then folded into one aggregate per (IP, rule)
// pair; escalation findings do not count as hits but mark every aggregate
// of their IP as escalated (an escalated IP without any block gets a row of
// its own with zero hits). At most modsecBlocksMaxAggregates aggregates are
// kept. Rows are ordered by hits (descending), then escalated first, then
// most recent last-seen, then IP and rule ID. The X-CSM-Truncated header is
// set when the findings read or the aggregate cap cut data off.
func (s *Server) apiModSecBlocks(w http.ResponseWriter, _ *http.Request) {
	findings, truncated := s.modsecFindings24hWithTruncation()
	findings = deduplicateModSecFindings(findings)

	type blockAgg struct {
		ip          string
		ruleID      string
		description string
		domains     map[string]bool
		uriCounts   map[string]int
		hits        int
		firstSeen   time.Time
		lastSeen    time.Time
		escalated   bool
		samples     []modsecSampleEvent // first 3 findings encountered for this key
	}
	byBlock := make(map[string]*blockAgg)
	escalatedIPs := make(map[string]bool)
	// NUL cannot occur in an IP, so it separates the two key parts safely.
	blockKey := func(ip, rule string) string {
		return ip + "\x00" + rule
	}

	for _, f := range findings {
		ip := extractModSecIP(f)
		if isModSecEscalation(f.Check) {
			if ip != "" {
				escalatedIPs[ip] = true
			}
			continue
		}
		if ip == "" {
			continue
		}

		rule := extractModSecRule(f)
		desc := extractModSecDescription(f)
		domain := extractModSecHostname(f)
		uri := extractModSecURI(f)

		key := blockKey(ip, rule)
		agg, ok := byBlock[key]
		if !ok {
			if len(byBlock) >= modsecBlocksMaxAggregates {
				truncated = true
				continue
			}
			agg = &blockAgg{
				ip:          ip,
				ruleID:      rule,
				description: desc,
				domains:     make(map[string]bool),
				uriCounts:   make(map[string]int),
				firstSeen:   f.Timestamp,
			}
			byBlock[key] = agg
		}
		agg.hits++
		if agg.firstSeen.IsZero() || f.Timestamp.Before(agg.firstSeen) {
			agg.firstSeen = f.Timestamp
		}
		if f.Timestamp.After(agg.lastSeen) {
			// The most recent finding supplies the displayed rule and
			// description.
			agg.lastSeen = f.Timestamp
			if rule != "" {
				agg.ruleID = rule
			}
			if desc != "" {
				agg.description = desc
			}
		}
		if agg.description == "" && desc != "" {
			agg.description = desc
		}
		// Skip server IPs and empty hostnames - only show actual domain names.
		// ModSecurity logs the server IP as hostname when the request doesn't
		// match a specific vhost (e.g. direct IP access, SNI mismatch).
		if domain != "" && !looksLikeIP(domain) {
			agg.domains[domain] = true
		}
		if uri != "" {
			agg.uriCounts[uri]++
		}
		if len(agg.samples) < 3 {
			agg.samples = append(agg.samples, modsecSampleEvent{
				Time:     f.Timestamp.UTC().Format(time.RFC3339),
				RuleID:   rule,
				Hostname: domain,
				URI:      uri,
				Severity: f.Severity.String(),
			})
		}
	}

	// Mark escalated aggregates in a single pass over the map instead of
	// rescanning every aggregate once per escalated IP, remembering which
	// escalated IPs already own at least one row.
	ipsWithBlock := make(map[string]struct{}, len(escalatedIPs))
	for _, agg := range byBlock {
		if escalatedIPs[agg.ip] {
			agg.escalated = true
			ipsWithBlock[agg.ip] = struct{}{}
		}
	}
	// Escalated IPs without any block still get a row of their own, subject
	// to the same aggregate cap.
	for ip := range escalatedIPs {
		if _, ok := ipsWithBlock[ip]; ok {
			continue
		}
		if len(byBlock) >= modsecBlocksMaxAggregates {
			truncated = true
			continue
		}
		byBlock[blockKey(ip, "")] = &blockAgg{
			ip:        ip,
			escalated: true,
			domains:   make(map[string]bool),
			uriCounts: make(map[string]int),
		}
	}

	var result []modsecBlockView
	for _, agg := range byBlock {
		if agg.hits == 0 && !agg.escalated {
			continue
		}

		var domainList []string
		for d := range agg.domains {
			domainList = append(domainList, d)
		}
		sort.Strings(domainList)
		domains := strings.Join(domainList, ", ")
		if len(domains) > 80 {
			domains = domains[:77] + "..."
		}

		lastSeen := ""
		lastSeenISO := ""
		if !agg.lastSeen.IsZero() {
			lastSeen = agg.lastSeen.Format("15:04:05")
			lastSeenISO = agg.lastSeen.UTC().Format(time.RFC3339)
		}
		firstSeenISO := ""
		if !agg.firstSeen.IsZero() {
			firstSeenISO = agg.firstSeen.UTC().Format(time.RFC3339)
		}

		topURIs := topKeysByCount(agg.uriCounts, 5)

		result = append(result, modsecBlockView{
			IP:           agg.ip,
			RuleID:       agg.ruleID,
			Description:  agg.description,
			Domains:      domains,
			DomainList:   domainList,
			DomainCount:  len(agg.domains),
			Hits:         agg.hits,
			LastSeen:     lastSeen,
			FirstSeen:    firstSeenISO,
			LastSeenISO:  lastSeenISO,
			TopURIs:      topURIs,
			SampleEvents: agg.samples,
			Escalated:    agg.escalated,
		})
	}

	// Total order: the final IP / rule ID comparisons make the result
	// independent of map iteration order.
	sort.Slice(result, func(i, j int) bool {
		if result[i].Hits != result[j].Hits {
			return result[i].Hits > result[j].Hits
		}
		if result[i].Escalated != result[j].Escalated {
			return result[i].Escalated
		}
		if result[i].LastSeenISO != result[j].LastSeenISO {
			return result[i].LastSeenISO > result[j].LastSeenISO
		}
		if result[i].IP != result[j].IP {
			return result[i].IP < result[j].IP
		}
		return result[i].RuleID < result[j].RuleID
	})

	if truncated {
		w.Header().Set("X-CSM-Truncated", "1")
	}
	writeJSON(w, result)
}
// apiModSecEvents returns the most recent individual ModSecurity events as
// JSON. Escalation findings are left out. The optional "limit" query
// parameter accepts 1-500; anything else silently keeps the default of 100.
func (s *Server) apiModSecEvents(w http.ResponseWriter, r *http.Request) {
	maxRows := 100
	if raw := r.URL.Query().Get("limit"); raw != "" {
		parsed, convErr := strconv.Atoi(raw)
		if convErr == nil && parsed >= 1 && parsed <= 500 {
			maxRows = parsed
		}
	}

	deduped := deduplicateModSecFindings(s.modsecFindings24h())

	// Allocated non-nil so an empty answer is an empty slice, not nil.
	events := make([]modsecEventView, 0, maxRows)
	for i := 0; i < len(deduped) && len(events) < maxRows; i++ {
		ev := deduped[i]
		if isModSecEscalation(ev.Check) {
			continue
		}
		events = append(events, modsecEventView{
			Time:     ev.Timestamp.Format("15:04:05"),
			IP:       extractModSecIP(ev),
			RuleID:   extractModSecRule(ev),
			Hostname: extractModSecHostname(ev),
			URI:      extractModSecURI(ev),
			Severity: ev.Severity.String(),
		})
	}
	writeJSON(w, events)
}
// deduplicateModSecFindings collapses findings that share the same client IP,
// rule ID and timestamp (truncated to the second, in UTC) into one entry.
// The original comment gives the motivation: Apache and LiteSpeed both log the
// same block within the same second.
//
// The first finding seen for a key is kept. A later duplicate only donates its
// Details, and only when it carries a hostname while the kept entry does not.
func deduplicateModSecFindings(findings []alert.Finding) []alert.Finding {
	type eventID struct {
		at   time.Time
		ip   string
		rule string
	}

	position := make(map[eventID]int) // event identity -> index in merged
	var merged []alert.Finding
	for _, cur := range findings {
		id := eventID{
			at:   cur.Timestamp.UTC().Truncate(time.Second),
			ip:   extractModSecIP(cur),
			rule: extractModSecRule(cur),
		}

		at, duplicate := position[id]
		if !duplicate {
			position[id] = len(merged)
			merged = append(merged, cur)
			continue
		}

		// Prefer the variant whose Details include a hostname.
		kept := &merged[at]
		if extractModSecHostname(cur) != "" && extractModSecHostname(*kept) == "" {
			kept.Details = cur.Details
		}
	}
	return merged
}
// isModSecEscalation reports whether check names one of the two ModSecurity
// escalation findings rather than an individual block event.
func isModSecEscalation(check string) bool {
	switch check {
	case "modsec_block_escalation", "modsec_csm_block_escalation":
		return true
	default:
		return false
	}
}
// modsecFindings24h returns the modsec findings from the last 24 hours,
// discarding the truncation flag reported by
// modsecFindings24hWithTruncation.
func (s *Server) modsecFindings24h() []alert.Finding {
	recent, _ := s.modsecFindings24hWithTruncation()
	return recent
}
// modsecFindings24hWithTruncation searches the global store's history for
// findings whose Check starts with "modsec_" and that are newer than 24 hours.
// It requests one row more than modsecFindingsScanCap so that it can tell
// whether the cap was hit; the boolean result is true when the returned slice
// was cut down to the cap. With no global store it returns (nil, false).
func (s *Server) modsecFindings24hWithTruncation() ([]alert.Finding, bool) {
	db := store.Global()
	if db == nil {
		return nil, false
	}

	isModSec := func(f alert.Finding) bool {
		return strings.HasPrefix(f.Check, "modsec_")
	}
	since := time.Now().Add(-24 * time.Hour)
	matches := db.SearchHistorySince(since, modsecFindingsScanCap+1, isModSec)

	overCap := len(matches) > modsecFindingsScanCap
	if overCap {
		matches = matches[:modsecFindingsScanCap]
	}
	return matches, overCap
}
// --- Field extraction from finding Details ---
// Details format: "Rule: NNNN\nMessage: ...\nHostname: ...\nURI: ..."

// extractModSecIP returns the client address of a ModSecurity finding, or ""
// when none can be located. Three sources are tried in order:
//
//  1. the word following " from " in Message, accepted only when it is at
//     least 7 bytes long, starts with a digit and contains exactly three dots;
//  2. the text between "[client " and "]" in Details, with a ":port" suffix
//     removed when the value contains exactly one colon;
//  3. a whitespace-separated Details token that starts with "[" and contains
//     "#" (LiteSpeed's [IP:PORT-CONN#VHOST]), taking the part before the first
//     colon when it is at least 7 bytes long and starts with a digit.
func extractModSecIP(f alert.Finding) string {
	// 1) Message text: "... from IP on ..." or "... from IP ...".
	const marker = " from "
	if at := strings.Index(f.Message, marker); at >= 0 {
		candidate := f.Message[at+len(marker):]
		if end := strings.IndexAny(candidate, " \n"); end >= 0 {
			candidate = candidate[:end]
		}
		startsWithDigit := len(candidate) >= 7 && candidate[0] >= '0' && candidate[0] <= '9'
		if startsWithDigit && strings.Count(candidate, ".") == 3 {
			return candidate
		}
	}

	// 2) Raw log line in Details: [client IP] or, on Apache 2.4, [client IP:port].
	if client := extractBetween(f.Details, "[client ", "]"); client != "" {
		if strings.Count(client, ":") == 1 {
			if colon := strings.LastIndex(client, ":"); colon > 0 {
				client = client[:colon]
			}
		}
		return client
	}

	// 3) LiteSpeed format: IP in [IP:PORT-CONN#VHOST].
	for _, token := range strings.Fields(f.Details) {
		if !strings.HasPrefix(token, "[") || !strings.Contains(token, "#") {
			continue
		}
		body := token[1:] // drop the leading "["
		colon := strings.Index(body, ":")
		if colon <= 0 {
			continue
		}
		addr := body[:colon]
		if len(addr) >= 7 && addr[0] >= '0' && addr[0] <= '9' {
			return addr
		}
	}
	return ""
}
// extractModSecRule returns the rule ID of a finding: the structured
// "Rule: " line in Details when present, otherwise the value inside
// [id "NNNN"] from a raw log line. It returns "" when neither is found.
func extractModSecRule(f alert.Finding) string {
	ruleID := extractDetailField(f.Details, "Rule: ")
	if ruleID == "" {
		// Fallback: raw log line stored in Details.
		ruleID = extractBetween(f.Details, `[id "`, `"]`)
	}
	return ruleID
}
// csmRuleDescriptions provides fallback descriptions for CSM custom rules.
// LiteSpeed error logs omit the [msg "..."] field, so the log-extracted
// description is often empty. This map ensures the UI always shows a
// meaningful description for rules we define ourselves.
//
// Keys are rule IDs as decimal strings, matching what extractModSecRule
// returns. Besides the CSM 9000xx/9001xx rules the table also carries a few
// common vendor (Comodo) and OWASP CRS IDs.
var csmRuleDescriptions = map[string]string{
	"900001": "Blocked LEVIATHAN CGI extension access",
	"900002": "Blocked LEVIATHAN directory access",
	"900003": "Blocked PHP execution in uploads directory",
	"900004": "Blocked PHP execution in languages directory",
	"900005": "Blocked direct wp-config.php access",
	"900007": "XML-RPC rate limit exceeded",
	"900008": "Blocked known webshell filename access",
	"900009": "Blocked GSocket User-Agent",
	"900100": "WP-Automatic SQLi (CVE-2024-27956)",
	"900101": "LayerSlider SQLi (CVE-2024-2879)",
	"900102": "Really Simple Security auth bypass (CVE-2024-10924)",
	"900103": "LiteSpeed Cache directory traversal (CVE-2024-4345)",
	"900104": "Ultimate Member SQLi (CVE-2024-1071)",
	"900105": "Backup Migration RCE (CVE-2023-6553)",
	"900106": "GiveWP object injection (CVE-2024-5932)",
	// NOTE(review): CVE-2024-3400 appears to be a PAN-OS GlobalProtect
	// advisory rather than a WP File Manager one -- confirm the CVE ID
	// against the rule definition before relying on this label.
	"900107": "WP File Manager arbitrary upload (CVE-2024-3400)",
	"900110": "PHP object injection attempt",
	"900111": "Blocked PHP in wp-content/upgrade",
	"900112": "WordPress user enumeration blocked",
	// 900113 and 900114 intentionally share one description here.
	"900113": "wp-login brute force rate limit",
	"900114": "wp-login brute force rate limit",
	"900115": "Blocked .env file access",
	"900116": "Blocked scanner probe",
	"900120": "Blocked wp-coder preview endpoint",
	"900121": "Blocked wp-coder attributes endpoint",
	// Comodo WAF (CWAF) common rules. Rule IDs in the 21xxxx range are
	// from the Comodo vendor ruleset (e.g. /etc/apache2/conf.d/
	// modsec_vendor_configs/comodo_litespeed/), NOT from OWASP CRS.
	// They were previously mislabeled as "OWASP:" here.
	"210710": "Comodo WAF: HTTP Request Smuggling",
	"210381": "Comodo WAF: HTTP Request Smuggling",
	"214930": "Comodo WAF: Inbound anomaly score threshold exceeded",
	"214940": "Comodo WAF: Outbound anomaly score threshold exceeded",
	"218420": "Comodo WAF: Request content type restriction",
	// OWASP CRS common rules. IDs in the 9xxxxx range are the standard
	// OWASP CRS 3.x schema (920xxx protocol, 930xxx LFI, 941xxx XSS,
	// 942xxx SQLi).
	"920170": "OWASP: Validate GET/HEAD request",
	"920420": "OWASP: Request content type is not allowed by policy",
	"920600": "OWASP: Illegal Accept header",
	"930100": "OWASP: Path traversal attack",
	"930110": "OWASP: Path traversal attack",
	"930120": "OWASP: OS file access attempt",
	"941100": "OWASP: XSS attack detected via libinjection",
	"941160": "OWASP: XSS Filter - Category 1",
	"942100": "OWASP: SQL injection attack detected via libinjection",
}
// extractModSecDescription returns a human-readable description for a
// finding. It prefers the structured "Message: " line in Details, then the
// [msg "..."] field of a raw log line, and finally the static
// csmRuleDescriptions entry for the finding's rule ID. It returns "" when all
// three come up empty.
func extractModSecDescription(f alert.Finding) string {
	desc := extractDetailField(f.Details, "Message: ")
	if desc == "" {
		desc = extractBetween(f.Details, `[msg "`, `"]`)
	}
	if desc != "" {
		return desc
	}
	// Some log formats (e.g. LiteSpeed) omit [msg "..."]; fall back to the
	// static table. A missing key yields the empty string.
	return csmRuleDescriptions[extractModSecRule(f)]
}
// extractModSecHostname returns the hostname of a finding: the structured
// "Hostname: " line in Details when present, otherwise the value inside
// [hostname "..."] from a raw log line, otherwise "".
func extractModSecHostname(f alert.Finding) string {
	host := extractDetailField(f.Details, "Hostname: ")
	if host == "" {
		host = extractBetween(f.Details, `[hostname "`, `"]`)
	}
	return host
}
// extractModSecURI returns the request URI of a finding: the structured
// "URI: " line in Details when present, otherwise the value inside
// [uri "..."] from a raw log line, otherwise "".
func extractModSecURI(f alert.Finding) string {
	uri := extractDetailField(f.Details, "URI: ")
	if uri == "" {
		uri = extractBetween(f.Details, `[uri "`, `"]`)
	}
	return uri
}
// extractDetailField scans details line by line (split on "\n") and returns
// the remainder of the first line that starts with prefix. It returns "" when
// no line carries the prefix.
func extractDetailField(details, prefix string) string {
	remaining := details
	for {
		line, rest, more := strings.Cut(remaining, "\n")
		if strings.HasPrefix(line, prefix) {
			return line[len(prefix):]
		}
		if !more {
			return ""
		}
		remaining = rest
	}
}
// looksLikeIP reports whether s has the shape of a dotted-quad address: at
// least 7 bytes, made up solely of ASCII digits and dots, with exactly three
// dots. It is a shape test only; octet ranges are not validated.
func looksLikeIP(s string) bool {
	if len(s) < 7 || strings.Count(s, ".") != 3 {
		return false
	}
	for i := 0; i < len(s); i++ {
		if c := s[i]; c != '.' && !(c >= '0' && c <= '9') {
			return false
		}
	}
	return true
}
// extractBetween returns the text that sits between the first occurrence of
// start and the next occurrence of end after it. It returns "" when either
// delimiter is missing. It serves as the fallback for old findings whose
// Details hold the raw log line.
func extractBetween(s, start, end string) string {
	_, tail, found := strings.Cut(s, start)
	if !found {
		return ""
	}
	value, _, closed := strings.Cut(tail, end)
	if !closed {
		return ""
	}
	return value
}
package webui
import (
"fmt"
"net/http"
"os"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/store"
)
// validateModSecDisabledRules checks a requested disable list against the
// parsed rules. It returns an error for the first ID that is either absent
// from allRules or belongs to a counter (bookkeeping) rule, and nil when every
// ID may be disabled.
func validateModSecDisabledRules(allRules []modsec.Rule, disabled []int) error {
	// protected holds every known rule ID; the value is true when any rule
	// with that ID is a counter rule.
	protected := make(map[int]bool)
	for _, rule := range allRules {
		if rule.IsCounter {
			protected[rule.ID] = true
			continue
		}
		if !protected[rule.ID] {
			protected[rule.ID] = false // record the ID as known
		}
	}

	for _, id := range disabled {
		isCounter, known := protected[id]
		switch {
		case !known:
			return fmt.Errorf("rule ID %d is not a known CSM rule", id)
		case isCounter:
			return fmt.Errorf("rule ID %d is a protected bookkeeping rule", id)
		}
	}
	return nil
}
// handleModSecRules renders the ModSecurity rules page, passing the
// configured hostname to the template.
func (s *Server) handleModSecRules(w http.ResponseWriter, _ *http.Request) {
	pageData := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "modsec-rules.html", pageData)
}
// GET /api/v1/modsec/rules - list all CSM rules with status
//
// When any of rules_file, overrides_file or reload_command is unset the
// response is {"configured": false, "missing": [...]}. Otherwise the rules
// file is parsed and each non-counter rule is reported together with its
// enabled state (from the overrides file), its escalation state and hit
// statistics (both from the global store, when one is available).
func (s *Server) apiModSecRules(w http.ResponseWriter, _ *http.Request) {
	modCfg := s.cfg.ModSec

	// Collect the name of every required setting that is empty.
	var missing []string
	require := func(unset bool, name string) {
		if unset {
			missing = append(missing, name)
		}
	}
	require(modCfg.RulesFile == "", "rules_file")
	require(modCfg.OverridesFile == "", "overrides_file")
	require(modCfg.ReloadCommand == "", "reload_command")
	if len(missing) > 0 {
		writeJSON(w, map[string]interface{}{
			"configured": false,
			"missing":    missing,
		})
		return
	}

	parsed, err := modsec.ParseRulesFile(modCfg.RulesFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "modsec: parse rules failed: %v\n", err)
		writeJSONError(w, "Failed to parse rules file", http.StatusInternalServerError)
		return
	}

	// A read error is ignored here; the rules are then all reported enabled.
	overrideIDs, _ := modsec.ReadOverrides(modCfg.OverridesFile)
	isDisabled := make(map[int]bool)
	for _, id := range overrideIDs {
		isDisabled[id] = true
	}

	// Escalation exclusions and hit counts live in the store; both maps stay
	// nil (every lookup misses) when no store is available.
	var (
		skipEscalation map[int]bool
		hitStats       map[int]store.RuleHitStats
	)
	if db := store.Global(); db != nil {
		skipEscalation = db.GetModSecNoEscalateRules()
		hitStats = db.GetModSecRuleHits()
	}

	type ruleView struct {
		ID          int    `json:"id"`
		Description string `json:"description"`
		Action      string `json:"action"`
		StatusCode  int    `json:"status_code"`
		Phase       int    `json:"phase"`
		Enabled     bool   `json:"enabled"`
		Escalate    bool   `json:"escalate"`
		Hits24h     int    `json:"hits_24h"`
		LastHit     string `json:"last_hit,omitempty"`
	}

	var (
		views        []ruleView
		enabledCount int
	)
	for _, rule := range parsed {
		if rule.IsCounter {
			continue // bookkeeping rules are hidden from the listing
		}
		enabled := !isDisabled[rule.ID]
		if enabled {
			enabledCount++
		}
		view := ruleView{
			ID:          rule.ID,
			Description: rule.Description,
			Action:      rule.Action,
			StatusCode:  rule.StatusCode,
			Phase:       rule.Phase,
			Enabled:     enabled,
			Escalate:    !skipEscalation[rule.ID],
		}
		if stat, seen := hitStats[rule.ID]; seen {
			view.Hits24h = stat.Hits
			if !stat.LastHit.IsZero() {
				view.LastHit = stat.LastHit.Format("2006-01-02T15:04:05Z07:00")
			}
		}
		views = append(views, view)
	}

	writeJSON(w, map[string]interface{}{
		"rules":      views,
		"total":      len(views),
		"active":     enabledCount,
		"disabled":   overrideIDs,
		"configured": true,
	})
}
// POST /api/v1/modsec/rules/apply - write overrides and reload
//
// The request body is {"disabled": [ids...]}. Every ID must belong to a
// non-counter rule in the parsed rules file. The handler snapshots the current
// overrides file, writes the new one and runs the reload command. If the
// reload fails it tries to restore the snapshot; "rolled_back" in the response
// reports whether that restore actually succeeded.
func (s *Server) apiModSecRulesApply(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Serialize apply operations - the write+reload+rollback sequence
	// must not interleave with concurrent applies.
	s.modSecApplyMu.Lock()
	defer s.modSecApplyMu.Unlock()

	cfg := s.cfg.ModSec
	if cfg.RulesFile == "" || cfg.OverridesFile == "" || cfg.ReloadCommand == "" {
		writeJSONError(w, "ModSecurity not configured", http.StatusBadRequest)
		return
	}

	var req struct {
		Disabled []int `json:"disabled"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Validate: only allow disabling known CSM rule IDs from the parsed rules file
	allRules, err := modsec.ParseRulesFile(cfg.RulesFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "modsec: parse rules failed: %v\n", err)
		writeJSONError(w, "Failed to parse rules file", http.StatusInternalServerError)
		return
	}
	if err := validateModSecDisabledRules(allRules, req.Disabled); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Save previous state for rollback
	previousContent := modsec.ReadOverridesRaw(cfg.OverridesFile)

	// Write new overrides
	if writeErr := modsec.WriteOverrides(cfg.OverridesFile, req.Disabled); writeErr != nil {
		fmt.Fprintf(os.Stderr, "modsec: write overrides failed: %v\n", writeErr)
		writeJSONError(w, "Failed to write overrides", http.StatusInternalServerError)
		return
	}

	// Reload web server
	output, reloadErr := modsec.Reload(cfg.ReloadCommand)
	if reloadErr != nil {
		// Roll back, and tell the operator the truth about whether the
		// rollback worked: a failed restore leaves the new overrides on disk
		// while the web server still runs the old configuration.
		rolledBack := true
		clientErr := "Web server reload failed, changes rolled back"
		if restoreErr := modsec.RestoreOverrides(cfg.OverridesFile, previousContent); restoreErr != nil {
			rolledBack = false
			clientErr = "Web server reload failed and rollback failed; overrides file may not match the running configuration"
			fmt.Fprintf(os.Stderr, "modsec: rollback after failed reload also failed: %v\n", restoreErr)
		}
		fmt.Fprintf(os.Stderr, "modsec: reload failed (rolled_back=%t): %v\noutput: %s\n", rolledBack, reloadErr, output)

		// Truncate output for client - may contain sensitive system paths
		clientOutput := output
		if len(clientOutput) > 500 {
			clientOutput = clientOutput[:500] + "... (truncated)"
		}
		writeJSON(w, map[string]interface{}{
			"ok":            false,
			"error":         clientErr,
			"reload_output": clientOutput,
			"rolled_back":   rolledBack,
		})
		return
	}

	writeJSON(w, map[string]interface{}{
		"ok":             true,
		"disabled_count": len(req.Disabled),
	})
}
// POST /api/v1/modsec/rules/escalation - toggle escalation for a single rule
//
// The body is {"rule_id": N, "escalate": bool}. Only IDs in the CSM custom
// range 900000-900999 are accepted. escalate=true removes the rule from the
// store's no-escalate list; escalate=false adds it.
func (s *Server) apiModSecRulesEscalation(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	db := store.Global()
	if db == nil {
		writeJSONError(w, "Store not available", http.StatusInternalServerError)
		return
	}

	var payload struct {
		RuleID   int  `json:"rule_id"`
		Escalate bool `json:"escalate"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 4*1024, &payload); decodeErr != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	inCSMRange := payload.RuleID >= 900000 && payload.RuleID <= 900999
	if !inCSMRange {
		writeJSONError(w, "Rule ID must be a CSM custom rule (900000-900999)", http.StatusBadRequest)
		return
	}

	var updateErr error
	if !payload.Escalate {
		updateErr = db.AddModSecNoEscalateRule(payload.RuleID)
	} else {
		updateErr = db.RemoveModSecNoEscalateRule(payload.RuleID)
	}
	if updateErr != nil {
		fmt.Fprintf(os.Stderr, "modsec: escalation update failed: %v\n", updateErr)
		writeJSONError(w, "Failed to update escalation setting", http.StatusInternalServerError)
		return
	}

	writeJSON(w, map[string]interface{}{"ok": true})
}
package webui
import (
"bufio"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/mysqlclient"
"github.com/pidginhost/csm/internal/redisinfo"
)
// --- Local response types ---
// perfResponse is the JSON body written by apiPerformance: the latest sampled
// metrics snapshot plus the current perf_ findings.
type perfResponse struct {
	Metrics  *perfMetrics      `json:"metrics"`  // nil when no snapshot has been stored yet
	Findings []perfFindingView `json:"findings"` // perf_ findings, sorted by severity descending
}
// perfMetrics is one snapshot of host metrics as filled in by sampleMetrics.
// Memory and swap figures are in MiB (kB values from /proc divided by 1024).
type perfMetrics struct {
	LoadAvg     [3]float64  `json:"load_avg"`        // first three fields of /proc/loadavg
	CPUCores    int         `json:"cpu_cores"`       // from cachedCores
	MemTotalMB  uint64      `json:"mem_total_mb"`    // MemTotal
	MemUsedMB   uint64      `json:"mem_used_mb"`     // MemTotal - MemFree - Buffers - Cached
	MemAvailMB  uint64      `json:"mem_avail_mb"`    // MemAvailable
	SwapTotalMB uint64      `json:"swap_total_mb"`   // SwapTotal
	SwapUsedMB  uint64      `json:"swap_used_mb"`    // SwapTotal - SwapFree
	PHPProcs    int         `json:"php_procs_total"` // processes whose cmdline contains "lsphp"
	TopPHPUsers []userProcs `json:"top_php_users"`   // at most 10 users, highest process count first
	// MySQL telemetry is best-effort. Both fields are nil when csm could
	// not read mysqld's pidfile or the mysql client failed (no /root/.my.cnf,
	// no socket auth, mysqld absent). The webui renders "n/a" in that case
	// so operators can tell "MySQL is idle" from "we couldn't ask".
	MySQLMemMB *uint64 `json:"mysql_mem_mb"`
	MySQLConns *int    `json:"mysql_conns"`
	RedisMemMB uint64  `json:"redis_mem_mb"`    // used memory; stays 0 when the Redis query fails
	RedisMaxMB uint64  `json:"redis_maxmem_mb"` // configured max memory; stays 0 when the Redis query fails
	RedisKeys  int64   `json:"redis_keys"`      // keyspace total; stays 0 when the Redis query fails
	Uptime     string  `json:"uptime"`          // formatted as "<days>d <hours>h"
}
// userProcs pairs an account with its number of running PHP processes.
type userProcs struct {
	User  string `json:"user"`  // username, or the raw UID string when it cannot be resolved
	Count int    `json:"count"` // number of matching processes owned by User
}
// perfFindingView is the JSON shape of one perf_ finding in the performance
// API response.
type perfFindingView struct {
	Severity  int    `json:"severity"`  // numeric alert severity
	SevClass  string `json:"sev_class"` // value returned by severityClass for the severity
	Check     string `json:"check"`
	Message   string `json:"message"`
	Details   string `json:"details,omitempty"`
	Key       string `json:"key"`        // finding key, sent back by the fix endpoints to dismiss it
	FirstSeen string `json:"first_seen"` // RFC 3339
	LastSeen  string `json:"last_seen"`  // RFC 3339
}
// --- Cached values ---
// Process-lifetime caches for values that are read from the filesystem once.
var (
	perfCoresOnce   sync.Once         // guards the one-time fill of perfCoresCache
	perfCoresCache  int               // processor count, set by cachedCores
	perfUIDMapOnce  sync.Once         // guards the one-time fill of perfUIDMapCache
	perfUIDMapCache map[string]string // UID string -> username, set by cachedUID
)
// cachedCores returns the number of "processor\t" lines in /proc/cpuinfo.
// The file is read on the first call only; later calls return the cached
// value. The result is 1 when the file cannot be opened or lists no
// processors.
func cachedCores() int {
	perfCoresOnce.Do(func() {
		perfCoresCache = 1 // fallback, replaced below when processors are found

		cpuinfo, err := os.Open("/proc/cpuinfo")
		if err != nil {
			return
		}
		defer func() { _ = cpuinfo.Close() }()

		found := 0
		for lines := bufio.NewScanner(cpuinfo); lines.Scan(); {
			if strings.HasPrefix(lines.Text(), "processor\t") {
				found++
			}
		}
		if found > 0 {
			perfCoresCache = found
		}
	})
	return perfCoresCache
}
// cachedUID maps a numeric UID string to its username using /etc/passwd.
// The file is parsed on the first call only, so accounts created later are
// not picked up. Unknown UIDs, and every UID when the file is unreadable,
// are returned unchanged.
func cachedUID(uid string) string {
	perfUIDMapOnce.Do(func() {
		byUID := make(map[string]string)
		if passwd, err := os.ReadFile("/etc/passwd"); err == nil {
			for _, entry := range strings.Split(string(passwd), "\n") {
				// passwd layout: name:password:uid:...
				cols := strings.Split(entry, ":")
				if len(cols) < 3 {
					continue
				}
				byUID[cols[2]] = cols[0]
			}
		}
		perfUIDMapCache = byUID
	})

	name, known := perfUIDMapCache[uid]
	if !known {
		return uid
	}
	return name
}
// --- Metrics sampler ---
// runCmdQuick executes name with args under a 5-second deadline and returns
// its standard output. When the deadline expires it returns a nil slice and a
// "command timed out" error; otherwise it returns whatever the command's
// Output call produced. Callers are expected to pass constant binary names
// and literal argument lists only -- no HTTP-controlled input may reach this
// function.
func runCmdQuick(name string, args ...string) ([]byte, error) {
	deadline, release := context.WithTimeout(context.Background(), 5*time.Second)
	defer release()

	// #nosec G204 -- see function-level comment: constant names/args only.
	cmd := exec.CommandContext(deadline, name, args...)
	stdout, runErr := cmd.Output()
	if deadline.Err() == context.DeadlineExceeded {
		return nil, fmt.Errorf("command timed out: %s", name)
	}
	return stdout, runErr
}
// sampleMetrics gathers live system metrics and returns a populated perfMetrics.
//
// Every source is best-effort: a file that cannot be read or a query that
// fails leaves the corresponding fields at their zero value (nil for the
// MySQL pointers). Sources, in order: /proc/loadavg, /proc/cpuinfo (cached),
// /proc/meminfo, /proc/*/cmdline + status for lsphp processes, mysqld's
// pidfile and a Threads_connected query, Redis memory/keyspace, /proc/uptime.
func sampleMetrics() *perfMetrics {
	m := &perfMetrics{}

	// Load averages
	if data, err := os.ReadFile("/proc/loadavg"); err == nil {
		fields := strings.Fields(string(data))
		if len(fields) >= 3 {
			for i := 0; i < 3; i++ {
				v, _ := strconv.ParseFloat(fields[i], 64)
				m.LoadAvg[i] = v
			}
		}
	}

	// CPU cores
	m.CPUCores = cachedCores()

	// Memory from /proc/meminfo
	{
		f, err := os.Open("/proc/meminfo")
		if err == nil {
			var memTotal, memAvail, memFree, memBuffers, memCached, swapTotal, swapFree uint64
			scanner := bufio.NewScanner(f)
			for scanner.Scan() {
				line := scanner.Text()
				fields := strings.Fields(line)
				if len(fields) < 2 {
					continue
				}
				val, _ := strconv.ParseUint(fields[1], 10, 64)
				switch fields[0] {
				case "MemTotal:":
					memTotal = val
				case "MemAvailable:":
					memAvail = val
				case "MemFree:":
					memFree = val
				case "Buffers:":
					memBuffers = val
				case "Cached:":
					memCached = val
				case "SwapTotal:":
					swapTotal = val
				case "SwapFree:":
					swapFree = val
				}
			}
			_ = f.Close()

			m.MemTotalMB = memTotal / 1024
			m.MemAvailMB = memAvail / 1024
			// Used = Total - Free - Buffers - Cached. The guard avoids an
			// unsigned underflow; in that case the total is reported as used.
			used := memTotal
			if memFree+memBuffers+memCached <= memTotal {
				used = memTotal - memFree - memBuffers - memCached
			}
			m.MemUsedMB = used / 1024
			m.SwapTotalMB = swapTotal / 1024
			if swapFree <= swapTotal {
				m.SwapUsedMB = (swapTotal - swapFree) / 1024
			}
		}
	}

	// PHP processes: scan /proc/*/cmdline for lsphp
	{
		cmdlinePaths, _ := filepath.Glob("/proc/[0-9]*/cmdline")
		userCounts := make(map[string]int)
		total := 0
		for _, cmdPath := range cmdlinePaths {
			// #nosec G304 -- cmdPath from /proc/*/cmdline glob; kernel pseudo-FS.
			data, err := os.ReadFile(cmdPath)
			if err != nil {
				continue
			}
			cmdStr := strings.ReplaceAll(string(data), "\x00", " ")
			if !strings.Contains(cmdStr, "lsphp") {
				continue
			}
			pid := filepath.Base(filepath.Dir(cmdPath))
			// #nosec G304 -- /proc/<pid>/status; kernel pseudo-FS, pid from /proc glob.
			statusData, _ := os.ReadFile(filepath.Join("/proc", pid, "status"))
			uid := ""
			for _, line := range strings.Split(string(statusData), "\n") {
				if strings.HasPrefix(line, "Uid:\t") {
					// The first value on the Uid line is used.
					f := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
					if len(f) > 0 {
						uid = f[0]
					}
					break
				}
			}
			if uid == "" {
				uid = "unknown"
			}
			username := cachedUID(uid)
			userCounts[username]++
			total++
		}
		m.PHPProcs = total

		// Build sorted top-10 list
		type up struct {
			user  string
			count int
		}
		var ups []up
		for u, c := range userCounts {
			ups = append(ups, up{u, c})
		}
		// Highest count first. Ties are broken by user name so the list (and
		// which users make the top-10 cut) does not change between samples
		// merely because map iteration order is random.
		sort.Slice(ups, func(i, j int) bool {
			if ups[i].count != ups[j].count {
				return ups[i].count > ups[j].count
			}
			return ups[i].user < ups[j].user
		})
		if len(ups) > 10 {
			ups = ups[:10]
		}
		m.TopPHPUsers = make([]userProcs, len(ups))
		for i, u := range ups {
			m.TopPHPUsers[i] = userProcs{User: u.user, Count: u.count}
		}
	}

	// MySQL: PID -> VmRSS, plus Threads_connected. Both fields stay nil
	// when the lookup fails so the webui can show "n/a" instead of a
	// misleading 0.
	{
		pidData, err := os.ReadFile("/var/run/mysqld/mysqld.pid")
		if err == nil {
			mysqlPID := strings.TrimSpace(string(pidData))
			if mysqlPID != "" {
				// #nosec G304 G703 -- mysqlPID is read from mysqld's own
				// /var/run/mysqld/mysqld.pid and we're reading the kernel
				// /proc pseudo-filesystem.
				statusData, _ := os.ReadFile(filepath.Join("/proc", mysqlPID, "status"))
				for _, line := range strings.Split(string(statusData), "\n") {
					if strings.HasPrefix(line, "VmRSS:") {
						fields := strings.Fields(line)
						if len(fields) >= 2 {
							if kb, perr := strconv.ParseUint(fields[1], 10, 64); perr == nil {
								mb := kb / 1024
								m.MySQLMemMB = &mb
							}
						}
						break
					}
				}
			}
		}

		// Connection count. mysqlclient open returns nil on auth failure,
		// missing socket, or absent server -- in every such case we
		// leave MySQLConns nil rather than reporting a fake 0.
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		rows, err := mysqlclient.RootQuery(ctx, "SHOW STATUS LIKE 'Threads_connected'")
		cancel()
		if err == nil && len(rows) > 0 {
			// The count is taken from the second whitespace-separated field.
			fields := strings.Fields(rows[0])
			if len(fields) >= 2 {
				if n, perr := strconv.Atoi(fields[1]); perr == nil {
					m.MySQLConns = &n
				}
			}
		}
	}

	// Redis: memory + keyspace via in-process client (no redis-cli fork).
	{
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if used, max, err := redisinfo.MemoryUsage(ctx); err == nil {
			m.RedisMemMB = used / (1024 * 1024)
			m.RedisMaxMB = max / (1024 * 1024)
		}
		if total, err := redisinfo.Keyspace(ctx); err == nil {
			m.RedisKeys = total
		}
	}

	// Uptime from /proc/uptime
	{
		data, err := os.ReadFile("/proc/uptime")
		if err == nil {
			fields := strings.Fields(string(data))
			if len(fields) >= 1 {
				secs, _ := strconv.ParseFloat(fields[0], 64)
				d := time.Duration(secs) * time.Second
				days := int(d.Hours()) / 24
				hours := int(d.Hours()) % 24
				m.Uptime = fmt.Sprintf("%dd %dh", days, hours)
			}
		}
	}

	return m
}
// sampleMetricsLoop stores a fresh metrics snapshot right away and then again
// every 10 seconds until ctx is cancelled.
func (s *Server) sampleMetricsLoop(ctx context.Context) {
	refresh := func() {
		s.perfSnapshot.Store(sampleMetrics())
	}
	refresh()

	tick := time.NewTicker(10 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			refresh()
		case <-ctx.Done():
			return
		}
	}
}
// apiPerformance returns the latest performance snapshot plus perf_ findings.
//
// The optional "limit" query parameter caps the number of findings: the
// default is 100, values above 500 are clamped to 500, and negative values
// fall back to the default. Suppressed findings are omitted. Findings are
// ordered by severity, highest first; findings of equal severity keep the
// order in which the store returned them.
func (s *Server) apiPerformance(w http.ResponseWriter, r *http.Request) {
	limit := queryInt(r, "limit", 100)
	switch {
	case limit < 0:
		// queryInt's own validation is not visible from here; a negative
		// value would make views[:limit] below panic, so use the default.
		limit = 100
	case limit > 500:
		limit = 500
	}

	snapshot := s.perfSnapshot.Load()

	latest := s.store.LatestFindings()
	suppressions := s.store.LoadSuppressions()

	var views []perfFindingView
	for _, f := range latest {
		if !strings.HasPrefix(f.Check, "perf_") {
			continue
		}
		if s.store.IsSuppressed(f, suppressions) {
			continue
		}

		key := f.Key()
		// Fall back to the finding's own timestamp when the store has no
		// first/last-seen entry for this key.
		firstSeen, lastSeen := f.Timestamp, f.Timestamp
		if entry, ok := s.store.EntryForKey(key); ok {
			firstSeen = entry.FirstSeen
			lastSeen = entry.LastSeen
		}

		views = append(views, perfFindingView{
			Severity:  int(f.Severity),
			SevClass:  severityClass(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Details:   f.Details,
			Key:       key,
			FirstSeen: firstSeen.Format(time.RFC3339),
			LastSeen:  lastSeen.Format(time.RFC3339),
		})
	}

	// Sort by severity descending. The stable variant keeps equal-severity
	// findings in store order so the list does not reshuffle between polls.
	sort.SliceStable(views, func(i, j int) bool {
		return views[i].Severity > views[j].Severity
	})

	if len(views) > limit {
		views = views[:limit]
	}

	writeJSON(w, perfResponse{
		Metrics:  snapshot,
		Findings: views,
	})
}
// apiPerfFixErrorLog truncates an account-owned error_log identified by
// the perf_error_logs finding. Admin scope; CSRF enforced at the route.
//
// The POST body is {"path": ..., "key": ...}. The path is handed to the fix
// together with the allowed web roots. On success the finding named by key is
// dismissed and the action is written to the audit log. The fix result is
// returned to the client either way.
func (s *Server) apiPerfFixErrorLog(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Path string `json:"path"`
		Key  string `json:"key"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 1<<14, &body); decodeErr != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if body.Path == "" {
		writeJSONError(w, "path is required", http.StatusBadRequest)
		return
	}

	outcome := checks.FixErrorLogBloatInRoots(body.Path, s.perfFixAllowedRoots())
	if outcome.Success {
		s.dismissPerfFinding(body.Key)
		s.auditLog(r, "perf_fix_error_log", body.Path, outcome.Description)
	}
	writeJSON(w, outcome)
}
// apiPerfFixDisplayErrors disables display_errors in an account-owned
// .user.ini / php.ini / .htaccess identified by the perf_wp_config
// finding's Details field. Admin scope; CSRF enforced at the route.
//
// The POST body is {"path": ..., "key": ...}. On success the finding named by
// key is dismissed and the action is written to the audit log. The fix result
// is returned to the client either way.
//
//nolint:dupl // mirrors apiPerfFixErrorLog; separate handlers keep audit actions explicit.
func (s *Server) apiPerfFixDisplayErrors(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Path string `json:"path"`
		Key  string `json:"key"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 1<<14, &body); decodeErr != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if body.Path == "" {
		writeJSONError(w, "path is required", http.StatusBadRequest)
		return
	}

	outcome := checks.FixDisplayErrorsOnInRoots(body.Path, s.perfFixAllowedRoots())
	if outcome.Success {
		s.dismissPerfFinding(body.Key)
		s.auditLog(r, "perf_fix_display_errors", body.Path, outcome.Description)
	}
	writeJSON(w, outcome)
}
// apiPerfFixWPCron disables WP-Cron in an account-owned wp-config.php
// identified by a perf_wp_cron finding and installs a per-user system cron
// that runs wp-cron.php. Admin scope; CSRF enforced at the route.
//
// The POST body is {"path": ..., "key": ...}. The fix receives the allowed web
// roots and the options from wpCronFixOptions. On success the finding named
// by key is dismissed and the action is written to the audit log. The fix
// result is returned to the client either way.
func (s *Server) apiPerfFixWPCron(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		Path string `json:"path"`
		Key  string `json:"key"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 1<<14, &body); decodeErr != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if body.Path == "" {
		writeJSONError(w, "path is required", http.StatusBadRequest)
		return
	}

	outcome := checks.FixDisableWPCronInRoots(body.Path, s.perfFixAllowedRoots(), s.wpCronFixOptions())
	if outcome.Success {
		s.dismissPerfFinding(body.Key)
		s.auditLog(r, "perf_fix_wp_cron", body.Path, outcome.Description)
	}
	writeJSON(w, outcome)
}
// wpCronFixOptions builds the options for the WP-Cron fix from the
// performance section of the server config. With no config loaded it returns
// the zero-value options.
func (s *Server) wpCronFixOptions() checks.WPCronFixOptions {
	var opts checks.WPCronFixOptions
	if s.cfg != nil {
		opts.IntervalMinutes = s.cfg.Performance.WPCronFix.IntervalMinutes
		opts.PHPBin = s.cfg.Performance.WPCronFix.PHPBin
	}
	return opts
}
// perfFixAllowedRoots returns the directory roots handed to the perf fix
// helpers: the web roots resolved from the config, or just "/home" when no
// config is loaded.
func (s *Server) perfFixAllowedRoots() []string {
	if s.cfg != nil {
		return checks.ResolveWebRoots(s.cfg)
	}
	return []string{"/home"}
}
// dismissPerfFinding dismisses the finding with the given key from the store,
// including its entry among the latest findings. A key that is empty after
// trimming whitespace is ignored.
func (s *Server) dismissPerfFinding(key string) {
	if trimmed := strings.TrimSpace(key); trimmed != "" {
		s.store.DismissFinding(trimmed)
		s.store.DismissLatestFinding(trimmed)
	}
}
// handlePerformance serves the performance dashboard page. The template gets
// no data.
func (s *Server) handlePerformance(w http.ResponseWriter, _ *http.Request) {
	const page = "performance.html"
	s.renderTemplate(w, page, nil)
}
package webui
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"net/http"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/store"
)
// errNoStore is the sentinel error for operations that need the store when
// none is available.
var errNoStore = errors.New("store unavailable")
// nowUnix returns the current time as seconds since the Unix epoch.
func nowUnix() int64 {
	current := time.Now().UTC()
	return current.Unix()
}
// operatorKey returns a SHA-256 hex of the auth credential carried by r. It
// serves as the per-operator partition key for the preferences store, so the
// store only ever sees the hash and never the raw token.
//
// An admin-scoped bearer token takes precedence over an admin-scoped cookie
// token. Returns "" when neither is present (the handlers run after
// requireAuth, so this normally only happens in tests that bypass
// middleware).
func (s *Server) operatorKey(r *http.Request) string {
	token, found := s.bearerTokenWithScope(r, "admin")
	if !found {
		token, found = s.cookieTokenWithScope(r, "admin")
	}
	if !found {
		return ""
	}
	return hashOperatorToken(token)
}
// hashOperatorToken returns the lowercase hex encoding of the SHA-256 digest
// of raw.
func hashOperatorToken(raw string) string {
	digest := sha256.New()
	digest.Write([]byte(raw)) // hash.Hash writes never return an error
	return hex.EncodeToString(digest.Sum(nil))
}
// Namespaces under which operator preferences are stored.
const (
	prefsNamespaceUser  = "user"  // the userPrefsBlob document
	prefsNamespaceViews = "views" // NOTE(review): its use is outside this section -- presumably saved views; confirm
)
// userPrefsBlob is the JSON document the client posts to /api/v1/prefs/user.
// Fields are validated and clipped to a small enum where appropriate so an
// attacker cannot smuggle arbitrary template data into the layout.
// sanitizeUserPrefs performs that validation; values it rejects become the
// zero value.
type userPrefsBlob struct {
	Density      string              `json:"density"`                 // "compact" or "comfortable"
	Timezone     string              `json:"timezone"`                // "server", "local", or an IANA-shaped zone name
	AutoRefresh  string              `json:"auto_refresh"`            // "on" or "off"
	TableColumns map[string][]string `json:"table_columns,omitempty"` // table identifier -> column identifiers
}
// sanitizeUserPrefs returns a copy of in that contains only validated values.
//
//   - Density is kept when it is "compact" or "comfortable".
//   - Timezone is kept when it is "server", "local", or passes isIANAish.
//   - AutoRefresh is kept when it is "on" or "off".
//   - TableColumns keeps tables whose name passes isSimpleIdent and that list
//     at most 64 columns; within a kept table only column names passing
//     isSimpleIdent survive.
//
// Anything else is dropped, leaving the field at its zero value. A kept table
// always maps to a non-nil slice, so a table with no surviving columns
// encodes as the JSON array [] rather than null.
func sanitizeUserPrefs(in userPrefsBlob) userPrefsBlob {
	out := userPrefsBlob{}

	switch in.Density {
	case "compact", "comfortable":
		out.Density = in.Density
	}

	switch in.Timezone {
	case "server", "local":
		out.Timezone = in.Timezone
	default:
		// Allow IANA-shaped strings (e.g. "Europe/Bucharest"). Reject anything
		// containing whitespace or control characters; the value gets reflected
		// in JS where Intl.DateTimeFormat will reject malformed zones anyway.
		if isIANAish(in.Timezone) {
			out.Timezone = in.Timezone
		}
	}

	switch in.AutoRefresh {
	case "on", "off":
		out.AutoRefresh = in.AutoRefresh
	}

	if len(in.TableColumns) > 0 {
		cleaned := make(map[string][]string, len(in.TableColumns))
		for k, vals := range in.TableColumns {
			if !isSimpleIdent(k) || len(vals) > 64 {
				continue
			}
			// Non-nil on purpose: encoding/json writes a nil slice as null,
			// which would turn a stored empty column list into null on the
			// way back to the client.
			v := make([]string, 0, len(vals))
			for _, name := range vals {
				if isSimpleIdent(name) {
					v = append(v, name)
				}
			}
			cleaned[k] = v
		}
		if len(cleaned) > 0 {
			out.TableColumns = cleaned
		}
	}

	return out
}
// isIANAish reports whether s has the shape of an IANA time zone name:
// 1 to 64 bytes, each an ASCII letter, an ASCII digit, or one of "_+-/".
// It checks shape only; it does not verify that the zone exists.
func isIANAish(s string) bool {
	if n := len(s); n == 0 || n > 64 {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		letter := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
		digit := c >= '0' && c <= '9'
		punct := c == '_' || c == '+' || c == '-' || c == '/'
		if !letter && !digit && !punct {
			return false
		}
	}
	return true
}
// isSimpleIdent reports whether s is a plain identifier: between 1 and 64
// bytes long and built only from ASCII letters, digits, '_', '-' and '.'.
func isSimpleIdent(s string) bool {
	if n := len(s); n == 0 || n > 64 {
		return false
	}
	for _, r := range s {
		isLetter := (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z')
		isDigit := r >= '0' && r <= '9'
		isPunct := r == '_' || r == '-' || r == '.'
		if !isLetter && !isDigit && !isPunct {
			return false
		}
	}
	return true
}
// apiPrefsUser serves the operator's user-preference document. GET reads it,
// PUT replaces it, and every other method is answered with 405.
func (s *Server) apiPrefsUser(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		s.handleGetUserPrefs(w, r)
		return
	}
	if r.Method == http.MethodPut {
		s.handlePutUserPrefs(w, r)
		return
	}
	writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleGetUserPrefs writes the calling operator's stored preferences.
//
// It answers 401 when no operator key can be derived from the request and 500
// when the store lookup fails. A missing store, a missing document, or a
// document that does not decode all produce an empty userPrefsBlob; a stored
// document that decodes is passed through sanitizeUserPrefs before it is
// returned.
func (s *Server) handleGetUserPrefs(w http.ResponseWriter, r *http.Request) {
	operator := s.operatorKey(r)
	if operator == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	// prefs stays at its zero value unless a stored document decodes cleanly.
	var prefs userPrefsBlob
	if db := store.Global(); db != nil {
		stored, err := db.GetOperatorPref(operator, prefsNamespaceUser)
		if err != nil {
			writeJSONError(w, "Store error", http.StatusInternalServerError)
			return
		}
		if stored != nil {
			var decoded userPrefsBlob
			if json.Unmarshal(stored, &decoded) == nil {
				prefs = sanitizeUserPrefs(decoded)
			}
		}
	}
	writeJSON(w, prefs)
}
// handlePutUserPrefs replaces the calling operator's preferences with the
// sanitised form of the posted document and echoes that sanitised form back.
//
// Responses: 401 without an operator key, 400 when the body cannot be decoded
// within store.MaxPrefBlobSize, 500 when encoding or the store write fails,
// and 503 when no store is available.
func (s *Server) handlePutUserPrefs(w http.ResponseWriter, r *http.Request) {
	operator := s.operatorKey(r)
	if operator == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	var submitted userPrefsBlob
	if decodeErr := decodeJSONBodyLimited(w, r, store.MaxPrefBlobSize, &submitted); decodeErr != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	sanitized := sanitizeUserPrefs(submitted)
	encoded, encodeErr := json.Marshal(sanitized)
	if encodeErr != nil {
		writeJSONError(w, "Encoding failed", http.StatusInternalServerError)
		return
	}

	// Cases are evaluated top to bottom, so the write is only attempted once
	// the store is known to exist.
	db := store.Global()
	switch {
	case db == nil:
		writeJSONError(w, "Store unavailable", http.StatusServiceUnavailable)
	case db.PutOperatorPref(operator, prefsNamespaceUser, encoded) != nil:
		writeJSONError(w, "Store error", http.StatusInternalServerError)
	default:
		writeJSON(w, sanitized)
	}
}
// savedView represents one user-named filter combination for a page.
// A view is identified by its (Page, Name) pair: handlePutSavedView updates
// the existing entry when both match and appends a new one otherwise.
type savedView struct {
// Name is the operator-chosen label.
Name string `json:"name"`
// Page is the identifier of the page the view belongs to.
Page string `json:"page"`
// Params holds the filter parameters captured for the view.
Params map[string]string `json:"params"`
// Updated is the nowUnix() value recorded when the view was last written.
Updated int64 `json:"updated"`
}
// maxSavedViewsPerOperator caps how many saved views one operator may hold;
// handlePutSavedView refuses to add a new view once the cap is reached, while
// updates to existing views are still accepted.
const maxSavedViewsPerOperator = 200
// apiPrefsViews routes saved-view requests by HTTP method: GET lists views,
// PUT creates or updates one, DELETE removes one. Any other method is
// answered with 405.
func (s *Server) apiPrefsViews(w http.ResponseWriter, r *http.Request) {
	var handle http.HandlerFunc
	switch r.Method {
	case http.MethodGet:
		handle = s.handleListSavedViews
	case http.MethodPut:
		handle = s.handlePutSavedView
	case http.MethodDelete:
		handle = s.handleDeleteSavedView
	}
	if handle == nil {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	handle(w, r)
}
// loadSavedViews returns the saved views stored for opkey. It returns nil when
// no store is available, the lookup fails, nothing is stored, or the stored
// document does not decode -- callers cannot tell these cases apart.
func (s *Server) loadSavedViews(opkey string) []savedView {
	if db := store.Global(); db != nil {
		if stored, err := db.GetOperatorPref(opkey, prefsNamespaceViews); err == nil && stored != nil {
			var decoded []savedView
			if json.Unmarshal(stored, &decoded) == nil {
				return decoded
			}
		}
	}
	return nil
}
// saveSavedViews persists views for opkey. The slice is stably sorted in place
// by page and then by name before it is encoded, so the caller's slice is
// reordered as a side effect. It returns errNoStore when no store is
// available, otherwise any encoding or store error.
func (s *Server) saveSavedViews(opkey string, views []savedView) error {
	sort.SliceStable(views, func(a, b int) bool {
		left, right := &views[a], &views[b]
		if left.Page == right.Page {
			return left.Name < right.Name
		}
		return left.Page < right.Page
	})

	encoded, err := json.Marshal(views)
	if err != nil {
		return err
	}

	if db := store.Global(); db != nil {
		return db.PutOperatorPref(opkey, prefsNamespaceViews, encoded)
	}
	return errNoStore
}
// handleListSavedViews writes the calling operator's saved views as a JSON
// array. When the "page" query parameter is non-empty after trimming, only
// views for that page are returned. Answers 401 without an operator key.
func (s *Server) handleListSavedViews(w http.ResponseWriter, r *http.Request) {
	operator := s.operatorKey(r)
	if operator == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	wanted := strings.TrimSpace(r.URL.Query().Get("page"))
	stored := s.loadSavedViews(operator)

	// Allocated even when nothing matches so the response is [] rather than null.
	matches := make([]savedView, 0, len(stored))
	for i := range stored {
		if wanted == "" || stored[i].Page == wanted {
			matches = append(matches, stored[i])
		}
	}
	writeJSON(w, matches)
}
// handlePutSavedView creates or updates one saved view for the calling
// operator. The view is identified by its trimmed (page, name) pair: a match
// has its params and timestamp replaced, otherwise a new view is appended
// unless the operator already holds maxSavedViewsPerOperator views.
//
// Validation (each failure answers 400): page must satisfy isSimpleIdent; name
// must be 1-80 bytes and satisfy isPrintableLabel; at most 32 params; every
// param key must satisfy isSimpleIdent and every value be at most 256 bytes.
// Answers 401 without an operator key and 500 when saving fails.
func (s *Server) handlePutSavedView(w http.ResponseWriter, r *http.Request) {
	operator := s.operatorKey(r)
	if operator == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	var req struct {
		Name   string            `json:"name"`
		Page   string            `json:"page"`
		Params map[string]string `json:"params"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, store.MaxPrefBlobSize, &req); decodeErr != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	name := strings.TrimSpace(req.Name)
	page := strings.TrimSpace(req.Page)

	// The first failing check wins; the order matches the documented contract.
	var problem string
	switch {
	case !isSimpleIdent(page):
		problem = "Invalid page"
	case name == "" || len(name) > 80:
		problem = "Name must be 1-80 characters"
	case !isPrintableLabel(name):
		problem = "Name contains invalid characters"
	case len(req.Params) > 32:
		problem = "Too many params"
	}
	if problem != "" {
		writeJSONError(w, problem, http.StatusBadRequest)
		return
	}

	params := make(map[string]string, len(req.Params))
	for key, value := range req.Params {
		if !isSimpleIdent(key) || len(value) > 256 {
			writeJSONError(w, "Invalid param", http.StatusBadRequest)
			return
		}
		params[key] = value
	}

	views := s.loadSavedViews(operator)
	stamp := nowUnix()

	// Locate an existing view with the same identity, if any.
	existing := -1
	for i := range views {
		if views[i].Page == page && views[i].Name == name {
			existing = i
			break
		}
	}

	switch {
	case existing >= 0:
		views[existing].Params = params
		views[existing].Updated = stamp
	case len(views) >= maxSavedViewsPerOperator:
		writeJSONError(w, "Saved view limit reached", http.StatusBadRequest)
		return
	default:
		views = append(views, savedView{
			Name:    name,
			Page:    page,
			Params:  params,
			Updated: stamp,
		})
	}

	if saveErr := s.saveSavedViews(operator, views); saveErr != nil {
		writeJSONError(w, "Store error", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "ok"})
}
// handleDeleteSavedView removes the saved view identified by the posted
// (page, name) pair from the calling operator's list.
//
// Both identifiers are trimmed of surrounding whitespace before matching,
// mirroring what handlePutSavedView does before it stores a view; without
// that, a client sending the same padded name it used on PUT would get a 404
// on DELETE.
//
// Responses: 401 without an operator key, 400 when the body cannot be decoded
// within 8 KiB, 404 when no view matches, 500 when saving the remaining views
// fails.
func (s *Server) handleDeleteSavedView(w http.ResponseWriter, r *http.Request) {
	opkey := s.operatorKey(r)
	if opkey == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	var body struct {
		Name string `json:"name"`
		Page string `json:"page"`
	}
	if err := decodeJSONBodyLimited(w, r, 8*1024, &body); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	// Normalise exactly like the PUT handler so both agree on a view's identity.
	name := strings.TrimSpace(body.Name)
	page := strings.TrimSpace(body.Page)

	views := s.loadSavedViews(opkey)

	// Filter in place: remaining reuses views' backing array, which is safe
	// because the write index never overtakes the read index.
	remaining := views[:0]
	removed := false
	for _, v := range views {
		if v.Page == page && v.Name == name {
			removed = true
			continue
		}
		remaining = append(remaining, v)
	}
	if !removed {
		writeJSONError(w, "View not found", http.StatusNotFound)
		return
	}

	if err := s.saveSavedViews(opkey, remaining); err != nil {
		writeJSONError(w, "Store error", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "ok"})
}
// isPrintableLabel reports whether s is free of control characters and is
// therefore safe to show as a label. It rejects the C0 controls
// (U+0000-U+001F), DEL (U+007F) and the C1 controls (U+0080-U+009F); together
// these are the Unicode "Cc" category. The C1 range matters because it
// contains NEL (U+0085, a line break) and CSI (U+009B, which starts terminal
// escape sequences). The empty string is considered printable; length limits
// are the caller's responsibility.
func isPrintableLabel(s string) bool {
	for _, r := range s {
		if r < 0x20 || (r >= 0x7f && r <= 0x9f) {
			return false
		}
	}
	return true
}
package webui
import (
"fmt"
"net/http"
"github.com/pidginhost/csm/internal/mailfwd/intel"
"github.com/pidginhost/csm/internal/platform"
)
// selectQueueReporter chooses where queue-composition data comes from on this
// host: the exim queue source when the detected platform is cPanel, and the
// empty reporter everywhere else.
func selectQueueReporter() intel.QueueReporter {
	var reporter intel.QueueReporter = intel.EmptyQueueReporter{}
	if platform.Detect().IsCPanel() {
		reporter = intel.NewEximQueueSource()
	}
	return reporter
}
// selectQueueFlusher chooses the backscatter-flush executor for this host. It
// returns the exim flusher when the detected platform is cPanel and nil
// otherwise; a nil flusher makes the flush route report itself unavailable.
func selectQueueFlusher() intel.QueueFlusher {
	if !platform.Detect().IsCPanel() {
		return nil
	}
	return intel.NewEximQueueFlusher()
}
// apiEmailFlushBackscatter serves POST /api/v1/email/queue/flush-backscatter.
// It asks the configured queue flusher to remove frozen null-sender messages
// (undeliverable bounce backscatter) from the mail queue, records the removal
// count in the audit log, and returns the flusher's result as JSON.
//
// Responses: 405 for non-POST, 503 when no flusher is configured for this
// host, 500 when the flush fails (nothing is audit-logged in that case).
func (s *Server) apiEmailFlushBackscatter(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method != http.MethodPost:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	case s.queueFlusher == nil:
		writeJSONError(w, "Mail queue flush not available on this host", http.StatusServiceUnavailable)
	default:
		outcome, err := s.queueFlusher.FlushBackscatter()
		if err != nil {
			writeJSONError(w, "Failed to flush backscatter", http.StatusInternalServerError)
			return
		}
		summary := fmt.Sprintf("removed %d frozen null-sender message(s)", outcome.Removed)
		s.auditLog(r, "email_flush_backscatter", "mail-queue", summary)
		writeJSON(w, outcome)
	}
}
// apiEmailQueueComposition serves GET /api/v1/email/queue-composition and
// returns whatever the configured queue reporter says about the mail queue's
// makeup. Without a reporter it returns an empty composition whose
// TopRecipients is an empty list.
//
// Responses: 405 for non-GET, 500 when the reporter fails.
func (s *Server) apiEmailQueueComposition(w http.ResponseWriter, r *http.Request) {
	switch {
	case r.Method != http.MethodGet:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	case s.queueReporter == nil:
		empty := intel.QueueComposition{TopRecipients: []intel.RecipientCount{}}
		writeJSON(w, empty)
	default:
		composition, err := s.queueReporter.Composition()
		if err != nil {
			writeJSONError(w, "Failed to read mail queue", http.StatusInternalServerError)
			return
		}
		writeJSON(w, composition)
	}
}
package webui
import (
"net/http"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// relayAbuseResponse is the JSON body returned by apiEmailRelayAbuse.
type relayAbuseResponse struct {
// Entries holds the matching findings, newest first, cut to the request's
// limit. It is always a non-nil slice.
Entries []relayAbuseEntry `json:"entries"`
// From and To are the bounds of the queried time window, formatted as
// RFC 3339 in UTC.
From string `json:"from"`
To string `json:"to"`
// Matched is the number of matching findings before the limit was applied.
Matched int `json:"matched"`
// Truncated is true when the history scan hit its row cap, meaning older
// matches may exist beyond what was inspected.
Truncated bool `json:"truncated"`
}
// relayAbuseEntry is one email_php_relay_abuse finding in the shape sent to
// the client; toRelayAbuseEntry builds it from an alert.Finding.
type relayAbuseEntry struct {
// Path is the finding's raw trigger path; PathLabel is its display label
// from relayPathLabel.
Path string `json:"path"`
PathLabel string `json:"path_label"`
// Severity is the finding's alert.Severity as a plain integer.
Severity int `json:"severity"`
SourceIP string `json:"source_ip,omitempty"`
CPUser string `json:"cp_user,omitempty"`
// TriggerCount is computed by relayTriggerCount.
TriggerCount int `json:"trigger_count"`
// DetectedAt is the finding's timestamp.
DetectedAt time.Time `json:"detected_at"`
// Sites has one element per entry of the finding's relay breakdown and is
// always a non-nil slice.
Sites []relaySiteEntry `json:"sites"`
// MsgSample carries the finding's message IDs.
MsgSample []string `json:"msg_sample,omitempty"`
}
// relaySiteEntry is one relay-breakdown row of a relayAbuseEntry.
type relaySiteEntry struct {
// Site and Script are the two halves of the breakdown row's script key as
// split by splitScriptKey; Site is empty when the key has no ":/" boundary.
Site string `json:"site"`
Script string `json:"script"`
// Hits, LastSeen and SampleSubject are copied from the breakdown row.
Hits int `json:"hits"`
LastSeen time.Time `json:"last_seen"`
SampleSubject string `json:"sample_subject,omitempty"`
}
// relayAbuseDefaultLimit is the number of entries apiEmailRelayAbuse returns
// when the request gives no usable "limit".
const relayAbuseDefaultLimit = 20
// relayAbuseMaxLimit is the largest "limit" apiEmailRelayAbuse accepts; a
// larger (or non-positive) value falls back to relayAbuseDefaultLimit.
const relayAbuseMaxLimit = 100
// relayPathLabel maps a finding's trigger path to the label shown in the UI.
// An empty path becomes "Unknown path"; a path with no dedicated label is
// returned unchanged.
func relayPathLabel(path string) string {
	label := path // paths without a dedicated label are shown verbatim
	switch path {
	case "":
		label = "Unknown path"
	case "header":
		label = "Suspicious headers"
	case "volume":
		label = "High volume script"
	case "volume_account":
		label = "High volume account"
	case "fanout":
		label = "Spam outbreak (IP fanout)"
	}
	return label
}
// apiEmailRelayAbuse serves GET /api/v1/email/relay-abuse. It is read-only and
// works from persisted history, because the realtime dispatch path does not
// populate LatestFindings.
//
// Query parameters:
//   - limit: maximum entries to return; values outside 1..relayAbuseMaxLimit
//     fall back to relayAbuseDefaultLimit.
//   - from, to: window bounds parsed by parseEmailGroupDate, defaulting to the
//     last 24 hours; the two are swapped if given in reverse order.
//
// The response lists matching findings newest first, with ties broken by
// path, source IP and then cPanel user so the order is deterministic.
func (s *Server) apiEmailRelayAbuse(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	query := r.URL.Query()
	maxEntries := queryInt(r, "limit", relayAbuseDefaultLimit)
	if maxEntries <= 0 || maxEntries > relayAbuseMaxLimit {
		maxEntries = relayAbuseDefaultLimit
	}

	now := time.Now()
	windowStart := parseEmailGroupDate(query.Get("from"), now.Add(-24*time.Hour), false)
	windowEnd := parseEmailGroupDate(query.Get("to"), now, true)
	if windowEnd.Before(windowStart) {
		windowStart, windowEnd = windowEnd, windowStart
	}

	resp := relayAbuseResponse{
		Entries: []relayAbuseEntry{},
		From:    windowStart.UTC().Format(time.RFC3339),
		To:      windowEnd.UTC().Format(time.RFC3339),
	}
	if s.store == nil {
		writeJSON(w, resp)
		return
	}

	// The scan is what gets bounded, not the number of matches: take history
	// since windowStart, inspect at most emailGroupsScanCap rows (the budget
	// /email/groups uses), then filter. Truncated therefore means the cap was
	// hit and older matches may lie beyond the rows that were looked at.
	scanned := s.store.ReadHistorySince(windowStart)
	if len(scanned) > emailGroupsScanCap {
		scanned = scanned[:emailGroupsScanCap]
		resp.Truncated = true
	}

	var matches []alert.Finding
	for _, finding := range scanned {
		if finding.Check != "email_php_relay_abuse" {
			continue
		}
		if finding.Timestamp.Before(windowStart) || finding.Timestamp.After(windowEnd) {
			continue
		}
		matches = append(matches, finding)
	}
	resp.Matched = len(matches)

	sort.SliceStable(matches, func(a, b int) bool {
		x, y := &matches[a], &matches[b]
		switch {
		case !x.Timestamp.Equal(y.Timestamp):
			return x.Timestamp.After(y.Timestamp)
		case x.Path != y.Path:
			return x.Path < y.Path
		case x.SourceIP != y.SourceIP:
			return x.SourceIP < y.SourceIP
		default:
			return x.CPUser < y.CPUser
		}
	})

	if len(matches) > maxEntries {
		matches = matches[:maxEntries]
	}
	for _, finding := range matches {
		resp.Entries = append(resp.Entries, toRelayAbuseEntry(finding))
	}
	writeJSON(w, resp)
}
// toRelayAbuseEntry converts a finding into its API representation. Each
// relay-breakdown row becomes one site entry, with the row's script key split
// into site and script by splitScriptKey. Sites is never nil, so it encodes as
// [] when the finding has no breakdown.
func toRelayAbuseEntry(f alert.Finding) relayAbuseEntry {
	sites := []relaySiteEntry{}
	for _, row := range f.RelayBreakdown {
		host, scriptPath := splitScriptKey(row.ScriptKey)
		sites = append(sites, relaySiteEntry{
			Site:          host,
			Script:        scriptPath,
			Hits:          row.Hits,
			LastSeen:      row.LastSeen,
			SampleSubject: row.SampleSubject,
		})
	}

	return relayAbuseEntry{
		Path:         f.Path,
		PathLabel:    relayPathLabel(f.Path),
		Severity:     int(f.Severity),
		SourceIP:     f.SourceIP,
		CPUser:       f.CPUser,
		TriggerCount: relayTriggerCount(f),
		DetectedAt:   f.Timestamp,
		Sites:        sites,
		MsgSample:    f.MsgIDs,
	}
}
// relayTriggerCount picks the best available count for a finding, in order of
// preference: the finding's RelayTotal when positive, else the sum of the
// positive hit counts in its relay breakdown when that sum is positive, else
// the number of message IDs it carries.
func relayTriggerCount(f alert.Finding) int {
	if total := f.RelayTotal; total > 0 {
		return total
	}

	hits := 0
	for _, row := range f.RelayBreakdown {
		if row.Hits <= 0 {
			continue
		}
		hits += row.Hits
	}
	if hits > 0 {
		return hits
	}
	return len(f.MsgIDs)
}
// splitScriptKey breaks a "host:/path" script key into its host and path
// parts. Keys are formed as host + ":" + path where path begins with "/", so
// the first ":/" marks the boundary; cutting there rather than at the first
// colon leaves host:port and IPv6-literal hosts whole. The returned script
// keeps its leading "/". A key with no ":/" comes back as ("", key) so the row
// can still be displayed.
func splitScriptKey(k string) (site, script string) {
	host, rest, found := strings.Cut(k, ":/")
	if !found {
		return "", k
	}
	// Cut drops the whole ":/" separator; put the path's leading slash back.
	return host, "/" + rest
}
package webui
import (
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/yara"
)
// handleRules renders the rules page, passing the configured hostname to the
// "rules.html" template.
func (s *Server) handleRules(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "rules.html", data)
}
// apiRulesStatus serves GET /api/v1/rules/status: the number of loaded YAML
// and YARA rules, the YAML rule-set version, whether YARA is available, and
// the signature settings (rules directory, update URL and interval).
// Counts stay at zero when the corresponding scanner is not initialised.
func (s *Server) apiRulesStatus(w http.ResponseWriter, _ *http.Request) {
	var yamlCount, yamlVersion, yaraCount int

	if scanner := signatures.Global(); scanner != nil {
		yamlCount, yamlVersion = scanner.RuleCount(), scanner.Version()
	}
	if backend := yara.Active(); backend != nil {
		yaraCount = backend.RuleCount()
	}

	sigCfg := s.cfg.Signatures
	writeJSON(w, map[string]any{
		"yaml_rules":      yamlCount,
		"yara_rules":      yaraCount,
		"yara_available":  yara.Available(),
		"yaml_version":    yamlVersion,
		"rules_dir":       sigCfg.RulesDir,
		"auto_update":     sigCfg.UpdateURL != "",
		"update_url":      sigCfg.UpdateURL,
		"update_interval": sigCfg.UpdateInterval,
	})
}
// apiRulesList serves GET /api/v1/rules/list: the rule files found directly
// inside the configured rules directory, each with its name, type ("yaml" for
// .yml/.yaml, "yara" for .yar/.yara; the extension match is case-insensitive)
// and size in bytes. Sub-directories, files with other extensions, and
// entries whose metadata cannot be read are skipped.
//
// The response is always a JSON array: a missing directory or a directory
// without rule files yields [] (assuming writeJSON marshals with
// encoding/json, where a nil slice would become null instead). Any other
// error reading the directory answers 500.
func (s *Server) apiRulesList(w http.ResponseWriter, _ *http.Request) {
	rulesDir := s.cfg.Signatures.RulesDir

	type ruleFileInfo struct {
		Name string `json:"name"`
		Type string `json:"type"` // "yaml" or "yara"
		Size int64  `json:"size"`
	}

	// Non-nil from the start so "no files" encodes as [] rather than null.
	files := []ruleFileInfo{}

	entries, err := os.ReadDir(rulesDir)
	if err != nil {
		if os.IsNotExist(err) {
			writeJSON(w, files)
			return
		}
		writeJSONError(w, fmt.Sprintf("reading rules directory: %v", err), http.StatusInternalServerError)
		return
	}

	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}

		name := entry.Name()
		var fileType string
		switch strings.ToLower(filepath.Ext(name)) {
		case ".yml", ".yaml":
			fileType = "yaml"
		case ".yar", ".yara":
			fileType = "yara"
		default:
			continue // skip non-rule files
		}

		// Info can fail if the file vanished after ReadDir; listing is
		// best-effort, so such entries are left out.
		info, err := entry.Info()
		if err != nil {
			continue
		}

		files = append(files, ruleFileInfo{
			Name: name,
			Type: fileType,
			Size: info.Size(),
		})
	}

	writeJSON(w, files)
}
// apiRulesReload serves POST /api/v1/rules/reload. It reloads the YAML
// signature scanner and then the YARA backend (whichever are initialised),
// refreshes the cached signature count from the YAML scanner, and reports the
// resulting rule counts. "ok" is true only when neither reload failed; any
// failure messages are listed under "errors". Non-POST requests get 405.
func (s *Server) apiRulesReload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var (
		yamlErr, yaraErr     error
		yamlCount, yaraCount int
	)

	if scanner := signatures.Global(); scanner != nil {
		yamlErr = scanner.Reload()
		yamlCount = scanner.RuleCount()
		// Keep the signature count shown on the dashboard/status in step.
		s.SetSigCount(yamlCount)
	}
	if backend := yara.Active(); backend != nil {
		yaraErr = backend.Reload()
		yaraCount = backend.RuleCount()
	}

	var failures []string
	if yamlErr != nil {
		failures = append(failures, fmt.Sprintf("YAML reload: %v", yamlErr))
	}
	if yaraErr != nil {
		failures = append(failures, fmt.Sprintf("YARA reload: %v", yaraErr))
	}

	result := map[string]any{
		"ok":         len(failures) == 0,
		"yaml_rules": yamlCount,
		"yara_rules": yaraCount,
	}
	if len(failures) > 0 {
		result["errors"] = failures
	}
	writeJSON(w, result)
}
// apiModSecEscalation serves /api/v1/rules/modsec-escalation, which manages
// the ModSecurity rule IDs excluded from auto-block escalation.
//
// POST replaces the stored set with the IDs in the body's "rules" array
// (duplicates collapse) and returns the resulting count; it answers 500 when
// the store is unavailable or the save fails and 400 on an undecodable body.
// Every other method is treated as a read and returns the stored IDs under
// "rules" -- always an array, in no particular order, and empty when there is
// no store.
func (s *Server) apiModSecEscalation(w http.ResponseWriter, r *http.Request) {
	db := store.Global()

	if r.Method != http.MethodPost {
		ids := []int{}
		if db != nil {
			for id := range db.GetModSecNoEscalateRules() {
				ids = append(ids, id)
			}
		}
		writeJSON(w, map[string]any{"rules": ids})
		return
	}

	if db == nil {
		writeJSONError(w, "Store not available", http.StatusInternalServerError)
		return
	}

	var payload struct {
		Rules []int `json:"rules"`
	}
	if decodeErr := decodeJSONBodyLimited(w, r, 64*1024, &payload); decodeErr != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	excluded := make(map[int]bool, len(payload.Rules))
	for _, ruleID := range payload.Rules {
		excluded[ruleID] = true
	}

	if saveErr := db.SetModSecNoEscalateRules(excluded); saveErr != nil {
		writeJSONError(w, fmt.Sprintf("Save failed: %v", saveErr), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]any{"ok": true, "count": len(excluded)})
}
package webui
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"html/template"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/broadcast"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/geoip"
"github.com/pidginhost/csm/internal/health"
"github.com/pidginhost/csm/internal/incident"
"github.com/pidginhost/csm/internal/mailfwd/intel"
"github.com/pidginhost/csm/internal/mailfwd/inventory"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/state"
)
// IPBlocker abstracts the firewall engine for block/unblock operations.
type IPBlocker interface {
// BlockIP blocks ip, recording reason. timeout is handed to the
// implementation as-is; NOTE(review): presumably it is how long the block
// lasts -- confirm against the engine.
BlockIP(ip string, reason string, timeout time.Duration) error
// UnblockIP removes the block on ip.
UnblockIP(ip string) error
}
// forceBlocker is an optional extension of IPBlocker for operator-initiated
// blocks that must bypass the auto_response.dry_run gate. The firewall engine
// implements this; test stubs need not.
// blockIPForOperator detects it with a type assertion at call time.
type forceBlocker interface {
// BlockIPForce takes the same arguments as IPBlocker.BlockIP.
BlockIPForce(ip string, reason string, timeout time.Duration) error
}
// blockIPForOperator performs a block requested by an operator. When b also
// implements forceBlocker (the engine on live systems) the forced variant is
// used so the block is never silenced by dry_run; otherwise (test stubs) it
// falls back to the plain BlockIP.
func blockIPForOperator(b IPBlocker, ip, reason string, timeout time.Duration) error {
	switch engine := b.(type) {
	case forceBlocker:
		return engine.BlockIPForce(ip, reason, timeout)
	default:
		return b.BlockIP(ip, reason, timeout)
	}
}
// noListDir decorates an http.FileSystem so that http.FileServer can never
// produce a directory index. Individual files open as usual, while opening a
// directory fails with fs.ErrNotExist, which FileServer reports as a 404.
// That lets the unauthenticated /static/ assets stay reachable for the login
// page without exposing a listing of every shipped file.
type noListDir struct{ fs http.FileSystem }

// Open opens name on the wrapped file system and returns it only when it is
// not a directory. If the file cannot be stat'ed, or turns out to be a
// directory, it is closed again before the error is returned.
func (d noListDir) Open(name string) (http.File, error) {
	file, err := d.fs.Open(name)
	if err != nil {
		return nil, err
	}

	info, statErr := file.Stat()
	if statErr == nil && !info.IsDir() {
		return file, nil
	}

	// Either Stat failed or this is a directory: the handle is not handed out,
	// so release it here. A Close error has nothing useful to add.
	_ = file.Close()
	if statErr != nil {
		return nil, statErr
	}
	return nil, fs.ErrNotExist
}
// Server is the web UI HTTP server. Serves API always; serves HTML pages
// and static files only if the UI directory exists on disk.
type Server struct {
cfg *config.Config
store *state.Store
httpSrv *http.Server
// templates holds the parsed page templates keyed by file name
// (e.g. "rules.html"); populated by New only when the templates directory
// exists.
templates map[string]*template.Template
hasUI bool // true if UI directory with templates exists
uiDir string // path to UI directory on disk
// startTime is set to the current time when New constructs the server.
startTime time.Time
sigCount int // loaded signature rule count
fanotifyActive bool
logWatcherCount int
blocker IPBlocker
geoIPDB atomic.Pointer[geoip.DB]
emailQuarantine *emailav.Quarantine
emailAVWatcherMode string
// The sources below are chosen once in New by the matching select* helper.
forwarderSource inventory.Source
deferralReporter intel.Reporter
// queueReporter supplies mail-queue composition data.
queueReporter intel.QueueReporter
// queueFlusher removes backscatter from the mail queue; nil when flushing
// is not available on this host.
queueFlusher intel.QueueFlusher
forwardHeld heldForwardStore
version string
perfSnapshot atomic.Pointer[perfMetrics]
perfCancel context.CancelFunc
incidentCorrelator *incident.Correlator
// Rate limiting
loginMu sync.Mutex
loginAttempts map[string][]time.Time
apiMu sync.Mutex
apiRequests map[string][]time.Time // per-IP API rate limiting
scanMu sync.Mutex
scanRunning bool // only one scan at a time
modSecApplyMu sync.Mutex // serializes modsec rules apply (write+reload+rollback)
provider health.Provider // set by Daemon when it starts the WebUI
mu sync.RWMutex
findingBus *broadcast.Bus // set by Daemon via SetFindingBus
// Graceful shutdown signal for background goroutines and streaming handlers.
// shutdownOnce makes Shutdown idempotent; closing pruneDone twice panics.
pruneDone chan struct{}
shutdownOnce sync.Once
// restartDaemon is called by apiSettingsRestart. Tests override this.
restartDaemon func() (output []byte, err error)
}
// New creates a new web UI server.
func New(cfg *config.Config, store *state.Store) (*Server, error) {
s := &Server{
cfg: cfg,
store: store,
startTime: time.Now(),
loginAttempts: make(map[string][]time.Time),
apiRequests: make(map[string][]time.Time),
pruneDone: make(chan struct{}),
forwarderSource: selectForwarderSource(),
deferralReporter: selectDeferralReporter(),
queueReporter: selectQueueReporter(),
queueFlusher: selectQueueFlusher(),
forwardHeld: selectForwardHeld(),
}
// Check if UI directory exists on disk
s.uiDir = cfg.WebUI.UIDir
if s.uiDir == "" {
s.uiDir = "/opt/csm/ui"
}
funcMap := template.FuncMap{
"severityClass": severityClass,
"severityLabel": severityLabel,
"timeAgo": timeAgo,
"formatTime": formatTime,
"csrfToken": s.csrfToken,
"csmConfig": func() template.JS { return jsonForScript(s.csmConfig()) },
"json": jsonForScript,
"multiply": func(a, b int) int { return a * b },
"add": func(a, b int) int { return a + b },
"subtract": func(a, b int) int { return a - b },
"divisibleBy": func(a, b int) bool { return b != 0 && a%b == 0 },
}
// Try to load templates from disk
templateDir := filepath.Join(s.uiDir, "templates")
staticDir := filepath.Join(s.uiDir, "static")
if _, err := os.Stat(templateDir); err == nil {
s.templates = make(map[string]*template.Template)
layoutPath := filepath.Join(templateDir, "layout.html")
for _, page := range []string{"dashboard", "findings", "quarantine", "cleanup-history", "firewall", "modsec", "modsec-rules", "threat", "rules", "audit", "account", "incident", "email", "performance", "hardening", "settings"} {
pagePath := filepath.Join(templateDir, page+".html")
t, err := template.New(page+".html").Funcs(funcMap).ParseFiles(layoutPath, pagePath)
if err != nil {
return nil, fmt.Errorf("parsing template %s from %s: %w", page, templateDir, err)
}
s.templates[page+".html"] = t
}
loginPath := filepath.Join(templateDir, "login.html")
loginTmpl, err := template.New("login.html").Funcs(funcMap).ParseFiles(loginPath)
if err != nil {
return nil, fmt.Errorf("parsing login template: %w", err)
}
s.templates["login.html"] = loginTmpl
s.hasUI = true
fmt.Fprintf(os.Stderr, "WebUI: loaded templates from %s\n", templateDir)
} else {
fmt.Fprintf(os.Stderr, "WebUI: UI directory not found at %s - running in API-only mode\n", s.uiDir)
}
// Set up routes
mux := http.NewServeMux()
// Static files and HTML pages - only if UI directory exists
if s.hasUI {
// Static assets must stay reachable pre-auth (the login page loads its
// own CSS/JS), so they are not behind requireAuth. They must not be
// enumerable, though: noListDir makes directory requests 404 instead of
// returning an index listing of every shipped file.
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(noListDir{http.Dir(staticDir)})))
mux.HandleFunc("/login", s.handleLogin)
mux.Handle("/", s.requireAuth(http.HandlerFunc(s.handleDashboard)))
mux.Handle("/dashboard", s.requireAuth(http.HandlerFunc(s.handleDashboard)))
mux.Handle("/findings", s.requireAuth(http.HandlerFunc(s.handleFindings)))
mux.Handle("/history", s.requireAuth(http.HandlerFunc(s.handleHistoryRedirect)))
mux.Handle("/quarantine", s.requireAuth(http.HandlerFunc(s.handleQuarantine)))
mux.Handle("/cleanup-history", s.requireAuth(http.HandlerFunc(s.handleCleanupHistory)))
mux.Handle("/blocked", s.requireAuth(http.HandlerFunc(s.handleFirewall))) // redirect old URL
mux.Handle("/firewall", s.requireAuth(http.HandlerFunc(s.handleFirewall)))
mux.Handle("/threat", s.requireAuth(http.HandlerFunc(s.handleThreat)))
mux.Handle("/rules", s.requireAuth(http.HandlerFunc(s.handleRules)))
mux.Handle("/audit", s.requireAuth(http.HandlerFunc(s.handleAudit)))
mux.Handle("/account", s.requireAuth(http.HandlerFunc(s.handleAccount)))
mux.Handle("/incident", s.requireAuth(http.HandlerFunc(s.handleIncident)))
mux.Handle("/email", s.requireAuth(http.HandlerFunc(s.handleEmail)))
mux.Handle("/performance", s.requireAuth(http.HandlerFunc(s.handlePerformance)))
mux.Handle("/hardening", s.requireAuth(http.HandlerFunc(s.handleHardening)))
mux.Handle("/settings", s.requireAuth(http.HandlerFunc(s.handleSettings)))
mux.Handle("/modsec", s.requireAuth(http.HandlerFunc(s.handleModSec)))
mux.Handle("/modsec/rules", s.requireAuth(http.HandlerFunc(s.handleModSecRules)))
}
// Auth-protected API - read (read-scope tokens accepted)
mux.Handle("/api/v1/events", s.requireRead(http.HandlerFunc(s.apiEvents)))
mux.Handle("/api/v1/status", s.requireRead(http.HandlerFunc(s.apiStatus)))
mux.Handle("/api/v1/findings", s.requireRead(http.HandlerFunc(s.apiFindings)))
mux.Handle("/api/v1/findings/enriched", s.requireRead(http.HandlerFunc(s.apiFindingsEnriched)))
mux.Handle("/api/v1/history", s.requireRead(http.HandlerFunc(s.apiHistory)))
mux.Handle("/api/v1/stats", s.requireRead(http.HandlerFunc(s.apiStats)))
mux.Handle("/api/v1/stats/trend", s.requireRead(http.HandlerFunc(s.apiStatsTrend)))
mux.Handle("/api/v1/stats/timeline", s.requireRead(http.HandlerFunc(s.apiStatsTimeline)))
mux.Handle("/api/v1/blocked-ips", s.requireRead(http.HandlerFunc(s.apiBlockedIPs)))
mux.Handle("/api/v1/capabilities", s.requireRead(http.HandlerFunc(s.apiCapabilities)))
mux.Handle("/api/v1/health", s.requireRead(http.HandlerFunc(s.apiHealth)))
mux.Handle("/api/v1/components", s.requireRead(http.HandlerFunc(s.apiComponents)))
// Auth-protected API - admin-only reads (data with write-adjacent sensitivity)
mux.Handle("/api/v1/quarantine", s.requireAuth(http.HandlerFunc(s.apiQuarantine)))
mux.Handle("/api/v1/modsec/stats", s.requireRead(http.HandlerFunc(s.apiModSecStats)))
mux.Handle("/api/v1/modsec/blocks", s.requireRead(http.HandlerFunc(s.apiModSecBlocks)))
mux.Handle("/api/v1/modsec/events", s.requireRead(http.HandlerFunc(s.apiModSecEvents)))
mux.Handle("/api/v1/modsec/rules", s.requireAuth(http.HandlerFunc(s.apiModSecRules)))
mux.Handle("/api/v1/modsec/rules/apply", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecRulesApply))))
mux.Handle("/api/v1/modsec/rules/escalation", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecRulesEscalation))))
mux.Handle("/api/v1/accounts", s.requireAuth(http.HandlerFunc(s.apiAccounts)))
mux.Handle("/api/v1/account", s.requireAuth(http.HandlerFunc(s.apiAccountDetail)))
mux.Handle("/api/v1/history/csv", s.requireAuth(http.HandlerFunc(s.apiHistoryCSV)))
mux.Handle("/api/v1/export", s.requireAuth(http.HandlerFunc(s.apiExport)))
mux.Handle("/api/v1/import", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiImport))))
mux.Handle("/api/v1/incident", s.requireAuth(http.HandlerFunc(s.apiIncident)))
// Admin-scope on both routes: ServeMux cannot disambiguate by HTTP method,
// so the POST .../status mutator forces admin; reads under the same prefix
// inherit it (admin is a superset of read). The sub-path also runs CSRF
// because the router can dispatch POST .../status; requireCSRF only acts
// on unsafe methods so GET .../<id> still passes through.
mux.Handle("/api/v1/incidents", s.requireAuth(http.HandlerFunc(s.apiIncidentList)))
mux.Handle("/api/v1/incidents/groups", s.requireRead(http.HandlerFunc(s.apiIncidentGroups)))
mux.Handle("/api/v1/incidents/", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiIncidentRouter))))
mux.Handle("/api/v1/email/stats", s.requireAuth(http.HandlerFunc(s.apiEmailStats)))
mux.Handle("/api/v1/email/quarantine", s.requireAuth(http.HandlerFunc(s.apiEmailQuarantineList)))
mux.Handle("/api/v1/email/quarantine/", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiEmailQuarantineAction))))
mux.Handle("/api/v1/email/av/status", s.requireAuth(http.HandlerFunc(s.apiEmailAVStatus)))
mux.Handle("/api/v1/email/groups", s.requireRead(http.HandlerFunc(s.apiEmailGroups)))
mux.Handle("/api/v1/email/relay-abuse", s.requireRead(http.HandlerFunc(s.apiEmailRelayAbuse)))
mux.Handle("/api/v1/email/forwarders", s.requireRead(http.HandlerFunc(s.apiEmailForwarders)))
mux.Handle("/api/v1/email/deferrals", s.requireRead(http.HandlerFunc(s.apiEmailDeferrals)))
mux.Handle("/api/v1/email/queue-composition", s.requireRead(http.HandlerFunc(s.apiEmailQueueComposition)))
mux.Handle("/api/v1/email/queue/flush-backscatter", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiEmailFlushBackscatter))))
mux.Handle("/api/v1/email/held", s.requireAuth(http.HandlerFunc(s.apiEmailHeldList)))
mux.Handle("/api/v1/email/held/", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiEmailHeldAction))))
mux.Handle("/api/v1/performance", s.requireAuth(http.HandlerFunc(s.apiPerformance)))
mux.Handle("/api/v1/perf/fix-error-log", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiPerfFixErrorLog))))
mux.Handle("/api/v1/perf/fix-display-errors", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiPerfFixDisplayErrors))))
mux.Handle("/api/v1/perf/fix-wp-cron", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiPerfFixWPCron))))
mux.Handle("/api/v1/hardening", s.requireAuth(http.HandlerFunc(s.apiHardening)))
mux.Handle("/api/v1/hardening/run", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiHardeningRun))))
// Threat Intelligence API
mux.Handle("/api/v1/threat/stats", s.requireAuth(http.HandlerFunc(s.apiThreatStats)))
mux.Handle("/api/v1/threat/top-attackers", s.requireAuth(http.HandlerFunc(s.apiThreatTopAttackers)))
mux.Handle("/api/v1/threat/ip", s.requireAuth(http.HandlerFunc(s.apiThreatIP)))
mux.Handle("/api/v1/threat/events", s.requireAuth(http.HandlerFunc(s.apiThreatEvents)))
mux.Handle("/api/v1/threat/db-stats", s.requireAuth(http.HandlerFunc(s.apiThreatDBStats)))
mux.Handle("/api/v1/audit", s.requireAuth(http.HandlerFunc(s.apiUIAudit)))
mux.Handle("/api/v1/finding-detail", s.requireAuth(http.HandlerFunc(s.apiFindingDetail)))
mux.Handle("/api/v1/threat/whitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatWhitelistIP))))
mux.Handle("/api/v1/threat/whitelist", s.requireAuth(http.HandlerFunc(s.apiThreatWhitelist)))
mux.Handle("/api/v1/threat/unwhitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatUnwhitelistIP))))
mux.Handle("/api/v1/threat/block-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatBlockIP))))
mux.Handle("/api/v1/threat/clear-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatClearIP))))
mux.Handle("/api/v1/threat/temp-whitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatTempWhitelistIP))))
mux.Handle("/api/v1/threat/bulk-action", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatBulkAction))))
// Rules API
mux.Handle("/api/v1/rules/status", s.requireAuth(http.HandlerFunc(s.apiRulesStatus)))
mux.Handle("/api/v1/rules/list", s.requireAuth(http.HandlerFunc(s.apiRulesList)))
mux.Handle("/api/v1/rules/reload", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiRulesReload))))
mux.Handle("/api/v1/rules/modsec-escalation", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecEscalation))))
// Suppressions API
mux.Handle("/api/v1/suppressions", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiSuppressions))))
// Firewall API
mux.Handle("/api/v1/firewall/status", s.requireAuth(http.HandlerFunc(s.apiFirewallStatus)))
mux.Handle("/api/v1/firewall/allowed", s.requireAuth(http.HandlerFunc(s.apiFirewallAllowed)))
mux.Handle("/api/v1/firewall/audit", s.requireAuth(http.HandlerFunc(s.apiFirewallAudit)))
mux.Handle("/api/v1/firewall/subnets", s.requireAuth(http.HandlerFunc(s.apiFirewallSubnets)))
mux.Handle("/api/v1/firewall/check", s.requireAuth(http.HandlerFunc(s.apiFirewallCheck)))
// Settings API
mux.Handle("/api/v1/settings/restart", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiSettingsRestart))))
mux.Handle("/api/v1/settings/firewall/tentative-apply", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallTentativeApply))))
mux.Handle("/api/v1/settings/firewall/confirm", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRollbackConfirm))))
mux.Handle("/api/v1/settings/firewall/revert", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRollbackRevert))))
mux.Handle("/api/v1/settings/firewall/rollback", s.requireAuth(http.HandlerFunc(s.apiFirewallRollbackStatus)))
mux.Handle("/api/v1/settings", s.requireAuth(http.HandlerFunc(s.apiSettingsSections)))
mux.Handle("/api/v1/settings/", s.requireAuth(http.HandlerFunc(s.apiSettings)))
// GeoIP API
mux.Handle("/api/v1/geoip", s.requireAuth(http.HandlerFunc(s.apiGeoIPLookup)))
mux.Handle("/api/v1/geoip/batch", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiGeoIPBatch))))
// Auth-protected API - actions (with CSRF validation)
mux.Handle("/api/v1/fix", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFix))))
mux.Handle("/api/v1/fix-bulk", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiBulkFix))))
mux.Handle("/api/v1/scan-account", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiScanAccount))))
mux.Handle("/api/v1/test-alert", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiTestAlert))))
mux.Handle("/api/v1/block-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiBlockIP))))
mux.Handle("/api/v1/unblock-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiUnblockIP))))
mux.Handle("/api/v1/unblock-bulk", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiUnblockBulk))))
mux.Handle("/api/v1/dismiss", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiDismissFinding))))
mux.Handle("/api/v1/quarantine-preview", s.requireAuth(http.HandlerFunc(s.apiQuarantinePreview)))
mux.Handle("/api/v1/quarantine-restore", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiQuarantineRestore))))
mux.Handle("/api/v1/quarantine/bulk-delete", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiQuarantineBulkDelete))))
mux.Handle("/api/v1/db-object-backups", s.requireAuth(http.HandlerFunc(s.apiDBObjectBackups)))
mux.Handle("/api/v1/db-object-backup-preview", s.requireAuth(http.HandlerFunc(s.apiDBObjectBackupPreview)))
mux.Handle("/api/v1/db-object-backup-restore", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiDBObjectBackupRestore))))
mux.Handle("/api/v1/firewall/deny-subnet", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallDenySubnet))))
mux.Handle("/api/v1/firewall/allow-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallAllowIP))))
mux.Handle("/api/v1/firewall/remove-allow", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRemoveAllow))))
mux.Handle("/api/v1/firewall/remove-subnet", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRemoveSubnet))))
mux.Handle("/api/v1/firewall/cphulk-clear", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallFlushCphulk))))
mux.Handle("/api/v1/firewall/flush", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallFlush))))
mux.Handle("/api/v1/firewall/unban", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallUnban))))
// Operator preferences (P5.2 saved views, P5.4 user prefs).
mux.Handle("/api/v1/prefs/user", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiPrefsUser))))
mux.Handle("/api/v1/prefs/views", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiPrefsViews))))
// Bulk-action undo (P5.3).
mux.Handle("/api/v1/undo/pending", s.requireAuth(http.HandlerFunc(s.apiUndoPending)))
mux.Handle("/api/v1/undo/run", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiUndoRun))))
// Logout (clears cookie, requires auth to prevent logout CSRF)
mux.Handle("/logout", s.requireAuth(http.HandlerFunc(s.handleLogout)))
// /metrics (ROADMAP item 4) has its own auth: the handler accepts
// cfg.WebUI.MetricsToken as a dedicated Bearer token so Prometheus
// scrapers get a credential that does not also unlock the UI, and
// falls back to the existing AuthToken/session path so the UI can
// self-scrape. No CSRF: read-only endpoint.
mux.HandleFunc("/metrics", s.handleMetrics)
s.httpSrv = &http.Server{
Addr: cfg.WebUI.Listen,
Handler: s.securityHeaders(mux),
ReadHeaderTimeout: 10 * time.Second, // slowloris protection
ReadTimeout: 30 * time.Second, // max time to read full request
WriteTimeout: 300 * time.Second, // account scans can take several minutes
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20, // 1MB
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
// Disable HTTP/2: Go's HTTP/2 implementation applies WriteTimeout
// to the entire connection, not per-stream. Long-running handlers
// (account scans ~5min) cause ERR_HTTP2_PROTOCOL_ERROR in browsers
// when the timeout fires. HTTP/1.1 handles per-request deadlines
// correctly via ResponseController.SetWriteDeadline.
NextProtos: []string{"http/1.1"},
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
},
}
s.restartDaemon = defaultRestartDaemon
return s, nil
}
// pruneLoginAttempts runs until s.pruneDone is closed. On every
// five-minute tick it drops timestamps older than one minute from the
// login rate-limit table and then from the API rate-limit table, deleting
// any client key that is left with no recent timestamps.
func (s *Server) pruneLoginAttempts() {
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()

	// keepAfter rewrites each entry of table so that it only holds
	// timestamps newer than cutoff; entries left empty are removed.
	keepAfter := func(table map[string][]time.Time, cutoff time.Time) {
		for key, stamps := range table {
			var fresh []time.Time
			for _, ts := range stamps {
				if ts.After(cutoff) {
					fresh = append(fresh, ts)
				}
			}
			if len(fresh) > 0 {
				table[key] = fresh
				continue
			}
			delete(table, key)
		}
	}

	for {
		select {
		case <-s.pruneDone:
			return
		case <-tick.C:
			s.loginMu.Lock()
			cutoff := time.Now().Add(-time.Minute)
			keepAfter(s.loginAttempts, cutoff)
			s.loginMu.Unlock()

			// The API limiter is pruned against the same cutoff.
			s.apiMu.Lock()
			keepAfter(s.apiRequests, cutoff)
			s.apiMu.Unlock()
		}
	}
}
// Start brings up the HTTPS listener and blocks until it stops. When no
// TLS certificate path is configured, both the certificate and the key
// default to webui.crt / webui.key under the state directory. Before
// serving it ensures the certificate exists and launches the rate-limit
// pruner and the metrics sampler as background tasks.
func (s *Server) Start() error {
	certPath, keyPath := s.cfg.WebUI.TLSCert, s.cfg.WebUI.TLSKey
	if certPath == "" {
		stateDir := s.cfg.StatePath
		certPath, keyPath = filepath.Join(stateDir, "webui.crt"), filepath.Join(stateDir, "webui.key")
	}

	err := EnsureTLSCert(certPath, keyPath, s.cfg.Hostname)
	if err != nil {
		return fmt.Errorf("TLS cert setup: %w", err)
	}

	obs.Go("webui-prune-logins", s.pruneLoginAttempts)

	// The sampler's context is cancelled from Shutdown via s.perfCancel.
	sampleCtx, stopSampling := context.WithCancel(context.Background())
	s.perfCancel = stopSampling
	obs.Go("webui-metrics-sample", func() { s.sampleMetricsLoop(sampleCtx) })

	fmt.Fprintf(os.Stderr, "WebUI listening on https://%s\n", s.cfg.WebUI.Listen)
	return s.httpSrv.ListenAndServeTLS(certPath, keyPath)
}
// Shutdown stops the background tasks and then gracefully shuts down the
// HTTP server, if one was built. It may be called repeatedly: closing
// pruneDone is wrapped in shutdownOnce, so a second call cannot panic on
// a double close.
func (s *Server) Shutdown(ctx context.Context) error {
	if cancel := s.perfCancel; cancel != nil {
		cancel()
	}
	s.shutdownOnce.Do(func() { close(s.pruneDone) })

	if srv := s.httpSrv; srv != nil {
		return srv.Shutdown(ctx)
	}
	return nil
}
// canonicalAllowedOrigin builds the one origin accepted on /api/
// requests. It is assembled from the configured hostname and listen
// port, never from the request's Host header, so a forged Host cannot
// influence the comparison. The port is omitted when it is empty or 443.
func (s *Server) canonicalAllowedOrigin() string {
	const scheme = "https://"

	host := canonicalOriginHost(s.cfg.Hostname)
	port := webUIListenPort(s.cfg.WebUI.Listen)
	if port == "" || port == "443" {
		return scheme + host
	}

	// net.JoinHostPort brackets IPv6 literals itself, so strip the
	// brackets canonicalOriginHost added to avoid doubling them.
	if bracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"); bracketed {
		host = host[1 : len(host)-1]
	}
	return scheme + net.JoinHostPort(host, port)
}
// canonicalOriginHost normalises a configured hostname for use in an
// origin string. Surrounding whitespace is trimmed. IP literals, with or
// without brackets, are re-rendered canonically: IPv4 (including
// IPv4-mapped IPv6) as dotted quad, IPv6 wrapped in brackets. Anything
// else is returned lower-cased.
func canonicalOriginHost(host string) string {
	host = strings.TrimSpace(host)

	// render prints an already-parsed IP in origin form.
	render := func(ip net.IP) string {
		if v4 := ip.To4(); v4 != nil {
			return v4.String()
		}
		return "[" + ip.String() + "]"
	}

	// Bracketed form: parse whatever sits between '[' and the first ']'.
	if strings.HasPrefix(host, "[") {
		if closeAt := strings.Index(host, "]"); closeAt > 0 {
			if ip := net.ParseIP(host[1:closeAt]); ip != nil {
				return render(ip)
			}
		}
	}

	if ip := net.ParseIP(host); ip != nil {
		return render(ip)
	}
	return strings.ToLower(host)
}
// webUIListenPort extracts the port from a listen address. A well-formed
// host:port is split with net.SplitHostPort; otherwise everything after
// the last ':' is returned, and "" when the address has no colon at all.
func webUIListenPort(listen string) string {
	_, port, err := net.SplitHostPort(listen)
	if err == nil {
		return port
	}

	colon := strings.LastIndex(listen, ":")
	if colon < 0 {
		return ""
	}
	return listen[colon+1:]
}
// sameOrigin reports whether got and want name the same web origin.
// Both must parse as bare origins (see originHeaderURL); scheme and
// hostname are compared case-insensitively and ports are compared after
// filling in the scheme default.
func sameOrigin(got, want string) bool {
	// parse accepts raw only when it is a syntactically bare origin.
	parse := func(raw string) (*url.URL, bool) {
		u, err := url.Parse(raw)
		if err != nil || !originHeaderURL(u) {
			return nil, false
		}
		return u, true
	}

	gotURL, ok := parse(got)
	if !ok {
		return false
	}
	wantURL, ok := parse(want)
	if !ok {
		return false
	}

	if !strings.EqualFold(gotURL.Scheme, wantURL.Scheme) {
		return false
	}
	if !strings.EqualFold(gotURL.Hostname(), wantURL.Hostname()) {
		return false
	}
	return originPort(gotURL) == originPort(wantURL)
}
// originHeaderURL reports whether u has the shape of an Origin header
// value: a scheme and host, with no userinfo, path, query or fragment.
func originHeaderURL(u *url.URL) bool {
	if u.Scheme == "" || u.Host == "" || u.User != nil {
		return false
	}
	return u.Path == "" && u.RawQuery == "" && u.Fragment == ""
}
// originPort returns the effective port of u: the explicit port when one
// is present, otherwise 443 for https, 80 for http, and "" for any other
// scheme. The scheme is matched case-insensitively.
func originPort(u *url.URL) string {
	if explicit := u.Port(); explicit != "" {
		return explicit
	}
	defaults := map[string]string{
		"https": "443",
		"http":  "80",
	}
	// A scheme missing from the table yields the zero value "".
	return defaults[strings.ToLower(u.Scheme)]
}
// Broadcast is a no-op kept for daemon compatibility; dashboard uses polling.
// The findings argument is ignored and nothing is stored or sent.
func (s *Server) Broadcast(_ []alert.Finding) {}
// SetSigCount sets the loaded signature count for the status API.
// NOTE(review): the field is written without taking a lock here;
// presumably this is called during startup before requests are served -- confirm.
func (s *Server) SetSigCount(count int) {
s.sigCount = count
}
// HasUI returns true if UI templates were loaded from disk.
// It simply reports the s.hasUI flag; where that flag is set is outside
// this part of the file.
func (s *Server) HasUI() bool {
return s.hasUI
}
// SetIPBlocker sets the firewall engine for block/unblock operations.
// The value is stored in s.blocker as-is; no nil check is performed here.
func (s *Server) SetIPBlocker(b IPBlocker) {
s.blocker = b
}
// SetHealthInfo records the daemon health facts surfaced by the health
// API: whether fanotify is active and how many log watchers are running.
func (s *Server) SetHealthInfo(fanotifyActive bool, logWatchers int) {
	s.fanotifyActive, s.logWatcherCount = fanotifyActive, logWatchers
}
// SetEmailQuarantine sets the email quarantine for the email AV API endpoints.
// The pointer is stored in s.emailQuarantine without copying.
func (s *Server) SetEmailQuarantine(q *emailav.Quarantine) {
s.emailQuarantine = q
}
// SetEmailAVWatcherMode sets the watcher mode string for the email AV status API.
// The string is stored verbatim; no validation of mode happens here.
func (s *Server) SetEmailAVWatcherMode(mode string) {
s.emailAVWatcherMode = mode
}
// SetVersion sets the application version for display in the UI.
// csmConfig exposes the stored value to the frontend under the "version" key.
func (s *Server) SetVersion(v string) {
s.version = v
}
// SetHealthProvider installs the daemon's health provider. The webui
// constructs without one so unit tests can run without a daemon; the
// daemon must call this before any request hits /api/v1/status.
// NOTE(review): the assignment is not guarded by a lock, which matches the
// "call before serving" requirement above -- confirm callers honour it.
func (s *Server) SetHealthProvider(p health.Provider) {
s.provider = p
}
// SetFindingBus installs the broadcaster that the SSE event stream
// subscribes to. A Server is constructed without one so unit tests work
// without a daemon; the daemon must call this before any request reaches
// /api/v1/events. The field is written while holding s.mu.
func (s *Server) SetFindingBus(bus *broadcast.Bus) {
	s.mu.Lock()
	s.findingBus = bus
	s.mu.Unlock()
}
// SetIncidentCorrelator wires the incident correlator. Called once at
// startup; treated as immutable after first set.
// Nothing in this method enforces the set-once rule; it is a caller contract.
func (s *Server) SetIncidentCorrelator(c *incident.Correlator) {
s.incidentCorrelator = c
}
// csmConfigJSON returns a JSON string of feature flags for the frontend.
// It marshals the map produced by csmConfig.
func (s *Server) csmConfigJSON() string {
// The Marshal error is deliberately dropped. NOTE(review): csmConfig appears
// to hold only strings, booleans and a string map, which cannot fail to
// encode -- confirm if new value types are ever added to that map.
b, _ := json.Marshal(s.csmConfig())
return string(b)
}
// csmConfig returns the feature-flag map used by the frontend. Template
// rendering goes through jsonForScript; the csmConfigJSON wrapper exists
// for code paths that need a pre-marshaled JSON string.
//
// Besides the version and hostname, each flag is derived from the live
// configuration or server state at call time, and "checkNames" maps
// finding-type IDs to the labels rendered in the UI.
func (s *Server) csmConfig() map[string]interface{} {
return map[string]interface{}{
"version": s.version,
"emailAV": s.cfg.EmailAV.Enabled,
// Firewall is a pointer section, so guard against it being absent.
"firewall": s.cfg.Firewall != nil && s.cfg.Firewall.Enabled,
"autoResponse": s.cfg.AutoResponse.Enabled,
// The next three flags are inferred from a setting being non-empty / positive.
"threatIntel": s.cfg.Reputation.AbuseIPDBKey != "",
"signatures": s.cfg.Signatures.RulesDir != "",
"challenge": s.cfg.Challenge.Difficulty > 0,
"fanotify": s.fanotifyActive,
"hostname": s.cfg.Hostname,
// Always reported as "admin" from this function.
"authScope": "admin",
// #nosec G101 -- Not credentials. This is a lookup from
// finding-type ID (waf_block, credential_leak, etc.) to the
// human-readable label rendered in the UI.
"checkNames": map[string]string{
"waf_block": "WAF Block",
"brute_force": "Brute Force",
"webshell": "Web Shell",
"phishing": "Phishing",
"spam": "Spam",
"cpanel_login": "cPanel Login",
"file_upload": "File Upload",
"recon": "Reconnaissance",
"c2": "C2 Communication",
"other": "Other",
"perf_load": "Load",
"perf_php_processes": "PHP Processes",
"perf_memory": "Memory",
"perf_php_handler": "PHP Handler",
"perf_mysql_config": "MySQL Config",
"perf_redis_config": "Redis Config",
"perf_error_logs": "Error Logs",
"perf_wp_config": "WP Config",
"perf_wp_transients": "WP Transients",
"perf_wp_cron": "WP Cron",
"integrity": "Integrity",
"db_siteurl_hijack": "DB URL Hijack",
"db_options_injection": "DB Options Injection",
"db_post_injection": "DB Post Injection",
"db_spam_injection": "DB Spam Injection",
"db_rogue_admin": "DB Rogue Admin",
"db_suspicious_admin_email": "DB Suspicious Admin",
"mail_queue": "Mail Queue",
"mail_per_account": "Mail Volume",
"email_phishing_content": "Email Phishing",
"email_malware": "Email Malware",
"email_compromised_account": "Compromised Account",
"email_cloud_relay_abuse": "Cloud-Relay Credential Abuse",
"email_spam_outbreak": "Spam Outbreak",
"email_defer_fail_governor": "Defer/Fail Governor",
"email_credential_leak": "Credential Leak",
"email_auth_failure_realtime": "Auth Failure",
"smtp_bruteforce": "SMTP Brute Force",
"smtp_subnet_spray": "SMTP Subnet Spray",
"smtp_account_spray": "SMTP Account Spray",
"mail_bruteforce": "Mail Brute Force",
"mail_subnet_spray": "Mail Subnet Spray",
"admin_panel_bruteforce": "Admin Panel Brute Force",
"mail_account_spray": "Mail Account Spray",
"mail_account_compromised": "Mail Account Compromised",
"http_request_flood": "HTTP Request Flood",
"http_ua_spoof": "HTTP UA Spoof",
"http_distributed_flood": "Distributed HTTP Flood",
"exim_frozen_realtime": "Frozen Message",
"email_suspicious_geo": "Suspicious Geo Login",
"email_rate_critical": "Email Rate Critical",
"email_rate_warning": "Email Rate Warning",
"email_dkim_failure": "DKIM Failure",
"email_spf_rejection": "SPF Rejection",
"email_pipe_forwarder": "Pipe Forwarder",
"email_suspicious_forwarder": "Suspicious Forwarder",
"cpanel_login_realtime": "cPanel Login",
"cpanel_password_purge_realtime": "Password Purge",
"ssh_login_realtime": "SSH Login",
"pam_login": "PAM Login",
"pam_bruteforce": "PAM Brute Force",
// Both ModSec escalation IDs intentionally share one label.
"modsec_block_escalation": "ModSec Escalation",
"modsec_csm_block_escalation": "ModSec Escalation",
"whm_password_change_noninfra": "WHM Password Change",
"password_hijack_confirmed": "Password Hijack",
},
}
}
// --- Authentication ---
// tokenHasScope reports whether the credentials carried by r grant at
// least the scope in want. The browser session cookie is consulted
// first and the Authorization bearer token second; either one matching
// a configured token that allows the scope is sufficient. "read" is
// granted by read- and admin-scope tokens, "admin" only by admin-scope
// tokens (see webUITokenAllows).
func (s *Server) tokenHasScope(r *http.Request, want string) bool {
	_, viaCookie := s.cookieTokenWithScope(r, want)
	if viaCookie {
		return true
	}
	_, viaBearer := s.bearerTokenWithScope(r, want)
	return viaBearer
}
// cookieTokenWithScope looks up the csm_auth cookie on r and returns its
// value together with true when it matches a configured token that
// allows the wanted scope. A missing or empty cookie, or no matching
// token, yields ("", false).
func (s *Server) cookieTokenWithScope(r *http.Request, want string) (string, bool) {
	cookie, err := r.Cookie("csm_auth")
	if err != nil {
		return "", false
	}
	presented := cookie.Value
	if presented == "" {
		return "", false
	}

	for _, candidate := range s.cfg.WebUI.Tokens {
		if !webUITokenMatches(presented, candidate) {
			continue
		}
		if webUITokenAllows(candidate, want) {
			return presented, true
		}
	}
	return "", false
}
// bearerTokenWithScope extracts the token from an "Authorization:
// Bearer <token>" header and returns it with true when it matches a
// configured token that allows the wanted scope. The "Bearer " prefix
// is matched case-sensitively; a missing prefix, an empty token, or no
// matching configured token yields ("", false).
func (s *Server) bearerTokenWithScope(r *http.Request, want string) (string, bool) {
	const scheme = "Bearer "

	header := r.Header.Get("Authorization")
	if !strings.HasPrefix(header, scheme) {
		return "", false
	}
	supplied := header[len(scheme):]
	if supplied == "" {
		return "", false
	}

	for _, candidate := range s.cfg.WebUI.Tokens {
		if !webUITokenMatches(supplied, candidate) {
			continue
		}
		if webUITokenAllows(candidate, want) {
			return supplied, true
		}
	}
	return "", false
}
// webUITokenMatches reports whether supplied equals the secret of tok.
// Empty values on either side never match; otherwise the comparison is
// done with subtle.ConstantTimeCompare.
func webUITokenMatches(supplied string, tok config.WebUIToken) bool {
	if supplied == "" || tok.Token == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(supplied), []byte(tok.Token)) == 1
}
// webUITokenAllows reports whether tok's scope satisfies want. An admin
// token satisfies both "admin" and "read"; a read token satisfies only
// "read". Any other requested scope is refused.
func webUITokenAllows(tok config.WebUIToken, want string) bool {
	isAdmin := tok.Scope == "admin"
	switch {
	case want == "admin":
		return isAdmin
	case want == "read":
		return isAdmin || tok.Scope == "read"
	}
	return false
}
// requireAuth wraps next so that only requests carrying admin-scope
// credentials reach it. Unauthenticated requests under /api/ receive a
// 401 JSON error; all others are redirected to /login.
func (s *Server) requireAuth(next http.Handler) http.Handler {
	// reject answers a request that failed the admin check.
	reject := func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/api/") {
			writeJSONError(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		http.Redirect(w, r, "/login", http.StatusFound)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !s.tokenHasScope(r, "admin") {
			reject(w, r)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// requireRead wraps next so that any request carrying read-scope (or
// admin-scope) credentials reaches it, but only for GET. Authenticated
// requests using another method get a 405 JSON error. Unauthenticated
// requests under /api/ receive a 401 JSON error; all others are
// redirected to /login.
func (s *Server) requireRead(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authorized := s.tokenHasScope(r, "read")
		switch {
		case authorized && r.Method == http.MethodGet:
			next.ServeHTTP(w, r)
		case authorized:
			writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		case strings.HasPrefix(r.URL.Path, "/api/"):
			// API callers get a machine-readable 401.
			writeJSONError(w, "Unauthorized", http.StatusUnauthorized)
		default:
			// Browser navigation is sent to the login page.
			http.Redirect(w, r, "/login", http.StatusFound)
		}
	})
}
// isAuthenticated is a thin shim used by handleLogin and metrics_api.
// New callers should prefer tokenHasScope directly.
// It is exactly tokenHasScope(r, "admin"), so read-scope credentials
// are reported as not authenticated.
func (s *Server) isAuthenticated(r *http.Request) bool {
return s.tokenHasScope(r, "admin")
}
// clientIPKey derives a per-client rate-limit key from a net/http
// RemoteAddr by dropping the port; bracketed IPv6 is handled, so
// "[::1]:443" becomes "::1". When the value cannot be split as host:port
// it is returned unchanged, so inputs without a port stay distinct
// rather than collapsing onto a single key.
func clientIPKey(remoteAddr string) string {
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return remoteAddr
	}
	return host
}
// handleLogin serves the login page and processes login submissions.
//
// Already-authenticated requests are redirected to /dashboard. GET
// renders the form; any method other than GET or POST gets 405. A POST
// is rate limited to five attempts per minute per client key (every
// attempt is recorded before the token is checked), and succeeds only
// when the submitted "token" form value matches an admin-scope token.
// On success the token is stored in the csm_auth cookie for 24 hours.
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	if s.isAuthenticated(r) {
		http.Redirect(w, r, "/dashboard", http.StatusFound)
		return
	}

	switch r.Method {
	case http.MethodGet:
		s.renderTemplate(w, "login.html", nil)
		return
	case http.MethodPost:
		// Handled below.
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Rate limit keyed on the client address with the port removed.
	clientKey := clientIPKey(r.RemoteAddr)
	s.loginMu.Lock()
	now := time.Now()
	var inWindow []time.Time
	for _, at := range s.loginAttempts[clientKey] {
		if now.Sub(at) < time.Minute {
			inWindow = append(inWindow, at)
		}
	}
	throttled := len(inWindow) >= 5
	if !throttled {
		s.loginAttempts[clientKey] = append(inWindow, now)
	}
	s.loginMu.Unlock()
	if throttled {
		http.Error(w, "Too many login attempts", http.StatusTooManyRequests)
		return
	}

	token := r.FormValue("token")

	// Only admin-scope tokens may be used with the browser form.
	isAdminToken := func(candidate string) bool {
		if candidate == "" {
			return false
		}
		for _, tok := range s.cfg.WebUI.Tokens {
			if tok.Scope == "admin" && webUITokenMatches(candidate, tok) {
				return true
			}
		}
		return false
	}
	if !isAdminToken(token) {
		s.renderTemplate(w, "login.html", map[string]string{"Error": "Invalid token"})
		return
	}

	session := http.Cookie{
		Name:     "csm_auth",
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		MaxAge:   86400, // 24 hours
	}
	http.SetCookie(w, &session)
	http.Redirect(w, r, "/dashboard", http.StatusFound)
}
// --- Template helpers ---
// severityClass maps a finding severity to the lower-case class name
// used by the templates; unrecognised severities map to "info".
func severityClass(sev alert.Severity) string {
	if sev == alert.Critical {
		return "critical"
	}
	if sev == alert.High {
		return "high"
	}
	if sev == alert.Warning {
		return "warning"
	}
	return "info"
}
// severityLabel maps a finding severity to its upper-case display
// label; unrecognised severities map to "INFO".
func severityLabel(sev alert.Severity) string {
	if sev == alert.Critical {
		return "CRITICAL"
	}
	if sev == alert.High {
		return "HIGH"
	}
	if sev == alert.Warning {
		return "WARNING"
	}
	return "INFO"
}
// severityRank converts an upper-case severity label into a sortable
// rank where a larger number is more severe: CRITICAL=3, HIGH=2,
// WARNING=1 and anything else 0. Matching is case-sensitive.
func severityRank(label string) int {
	ranks := map[string]int{
		"CRITICAL": 3,
		"HIGH":     2,
		"WARNING":  1,
	}
	// Unknown labels fall out as the zero value.
	return ranks[label]
}
// timeAgo renders how long ago t was, relative to now, in a compact
// form: "just now" under a minute (including times in the future), then
// whole minutes, whole hours, and whole days, each truncated.
func timeAgo(t time.Time) string {
	elapsed := time.Since(t)
	if elapsed < time.Minute {
		return "just now"
	}
	if elapsed < time.Hour {
		return fmt.Sprintf("%dm ago", int(elapsed.Minutes()))
	}
	if elapsed < 24*time.Hour {
		return fmt.Sprintf("%dh ago", int(elapsed.Hours()))
	}
	return fmt.Sprintf("%dd ago", int(elapsed.Hours()/24))
}
// formatTime renders t as "YYYY-MM-DD HH:MM:SS" in t's own location.
func formatTime(t time.Time) string {
	const layout = "2006-01-02 15:04:05"
	return t.Format(layout)
}
// --- Security headers middleware ---
// securityHeaders is the outermost middleware. For every request it
// sets a fixed group of hardening response headers. For paths under
// /api/ it additionally:
//
//   - rejects requests whose Origin header is present but is not the
//     configured origin (403), and echoes CORS allow headers otherwise;
//   - answers OPTIONS with 204 without calling next;
//   - enforces a limit of 600 requests per minute per client key (429).
//
// The allowed origin comes from configuration rather than r.Host: a
// forged Host header that matches a forged Origin would otherwise pass
// the equality check trivially.
func (s *Server) securityHeaders(next http.Handler) http.Handler {
	fixed := [][2]string{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Strict-Transport-Security", "max-age=31536000; includeSubDomains"},
		{"Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self'; connect-src 'self'; img-src 'self' data:; font-src 'self'"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()"},
		{"Cache-Control", "no-store"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		for _, kv := range fixed {
			hdr.Set(kv[0], kv[1])
		}

		// Everything below applies to API routes only.
		if !strings.HasPrefix(r.URL.Path, "/api/") {
			next.ServeHTTP(w, r)
			return
		}

		if origin := r.Header.Get("Origin"); origin != "" {
			if !sameOrigin(origin, s.canonicalAllowedOrigin()) {
				http.Error(w, "Cross-origin request blocked", http.StatusForbidden)
				return
			}
			hdr.Set("Access-Control-Allow-Origin", origin)
			hdr.Set("Access-Control-Allow-Credentials", "true")
		}

		// Preflight requests stop here and never reach a handler.
		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		// Sliding one-minute window per client key.
		clientKey := clientIPKey(r.RemoteAddr)
		s.apiMu.Lock()
		now := time.Now()
		windowStart := now.Add(-time.Minute)
		var inWindow []time.Time
		for _, at := range s.apiRequests[clientKey] {
			if at.After(windowStart) {
				inWindow = append(inWindow, at)
			}
		}
		overLimit := len(inWindow) >= 600
		if !overLimit {
			s.apiRequests[clientKey] = append(inWindow, now)
		}
		s.apiMu.Unlock()

		if overLimit {
			http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// --- CSRF protection ---
// csrfToken derives the CSRF token for this process. It is an
// HMAC-SHA256, keyed with the secret from csrfSecret, over a fixed label
// plus the server start time, truncated to the first 32 hex characters.
// Because the start time is part of the message, the token changes on
// every daemon restart. Returns "" when no secret is available.
func (s *Server) csrfToken() string {
	key := s.csrfSecret()
	if key == "" {
		return ""
	}

	h := hmac.New(sha256.New, []byte(key))
	_, _ = h.Write([]byte(fmt.Sprintf("csm-csrf-v1:%d", s.startTime.Unix())))

	// 16 digest bytes encode to exactly 32 hex characters.
	digest := h.Sum(nil)
	return hex.EncodeToString(digest[:16])
}
// csrfSecret picks the key used to derive CSRF tokens: the first
// configured admin-scope token with a non-empty value. When the token
// list is completely empty it falls back to cfg.WebUI.AuthToken; when
// tokens exist but none is a usable admin token it returns "".
func (s *Server) csrfSecret() string {
	tokens := s.cfg.WebUI.Tokens
	for _, tok := range tokens {
		if tok.Scope != "admin" || tok.Token == "" {
			continue
		}
		return tok.Token
	}
	if len(tokens) > 0 {
		return ""
	}
	return s.cfg.WebUI.AuthToken
}
// validateCSRF enforces the browser-session CSRF boundary on
// state-changing requests.
//
// Safe methods always pass. Requests whose bearer token itself grants
// admin scope also pass: a cross-origin browser request cannot attach
// an Authorization header without script access to the token. A
// read-scope bearer sent alongside an admin cookie does NOT qualify, so
// it cannot turn a cookie-authenticated request into a CSRF-exempt one.
//
// Otherwise the request must present the expected token, either in the
// X-CSRF-Token header or, when that header is empty, in the csrf_token
// form field. With no admin secret configured there is no expected
// token and the request is refused.
func (s *Server) validateCSRF(r *http.Request) bool {
	if !isUnsafeCSRFMethod(r.Method) || s.isAdminBearerAuth(r) {
		return true
	}

	expected := s.csrfToken()
	if expected == "" {
		return false
	}

	// The header (used by JS API calls) wins; the form field is only
	// consulted when the header is absent.
	presented := r.Header.Get("X-CSRF-Token")
	if presented == "" {
		presented = r.FormValue("csrf_token")
	}
	if presented == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(presented), []byte(expected)) == 1
}
// requireCSRF wraps next so that POST, PUT, PATCH and DELETE requests
// must pass validateCSRF, answering 403 otherwise. Requests
// authenticated by an admin-scope bearer token are exempt, since
// API-to-API callers need no CSRF protection; read-scope bearer tokens
// get no exemption.
func (s *Server) requireCSRF(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		needsCheck := isUnsafeCSRFMethod(r.Method) && !s.isAdminBearerAuth(r)
		if needsCheck && !s.validateCSRF(r) {
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// isUnsafeCSRFMethod reports whether method is one of the
// state-changing HTTP methods subject to CSRF validation: POST, PUT,
// DELETE or PATCH. The comparison is exact (case-sensitive).
func isUnsafeCSRFMethod(method string) bool {
	return method == http.MethodPost ||
		method == http.MethodPut ||
		method == http.MethodDelete ||
		method == http.MethodPatch
}
// isBearerAuth reports whether r carries an Authorization bearer token
// that matches a configured token allowing at least "read" scope.
func (s *Server) isBearerAuth(r *http.Request) bool {
_, ok := s.bearerTokenWithScope(r, "read")
return ok
}
// isAdminBearerAuth reports whether r carries an Authorization bearer
// token that matches a configured admin-scope token. The session cookie
// is not considered.
func (s *Server) isAdminBearerAuth(r *http.Request) bool {
_, ok := s.bearerTokenWithScope(r, "admin")
return ok
}
// --- Logout ---
// handleLogout ends the browser session by overwriting the csm_auth
// cookie with an empty, already-expired one (same attributes as the
// login cookie) and redirecting to /login.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	expired := http.Cookie{
		Name:     "csm_auth",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		MaxAge:   -1, // negative MaxAge tells the browser to delete it
	}
	http.SetCookie(w, &expired)
	http.Redirect(w, r, "/login", http.StatusFound)
}
// --- Scan rate limiting ---
// acquireScan claims the single scan slot. It returns true when the
// caller now owns the slot and false when a scan was already running.
// A successful caller must pair this with releaseScan.
func (s *Server) acquireScan() bool {
	s.scanMu.Lock()
	defer s.scanMu.Unlock()

	// The flag ends up true either way; only a false->true transition
	// counts as a successful claim.
	alreadyRunning := s.scanRunning
	s.scanRunning = true
	return !alreadyRunning
}
// releaseScan frees the scan slot claimed by acquireScan.
func (s *Server) releaseScan() {
	s.scanMu.Lock()
	defer s.scanMu.Unlock()
	s.scanRunning = false
}
package webui
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/integrity"
"gopkg.in/yaml.v3"
)
// settingsURLPrefix is the URL path prefix stripped from a request path
// to obtain the settings section ID.
const settingsURLPrefix = "/api/v1/settings/"
// pendingSettingsSection identifies a settings section that has on-disk
// changes not yet applied to the running configuration. It is the JSON
// element type of the "pending_sections" response field.
type pendingSettingsSection struct {
ID string `json:"id"` // section identifier
Title string `json:"title"` // human-readable section title
}
// pendingRestartSections lists the settings sections whose on-disk
// values differ from the live configuration in a way that is not tagged
// config.TagSafe. A changed field belongs to a section when it equals
// the section's YAML path or lies beneath it. The result follows the
// order of settingsSections and is nil when either config is nil or
// nothing is pending.
func pendingRestartSections(live, disk *config.Config) []pendingSettingsSection {
	if live == nil || disk == nil {
		return nil
	}

	touched := make(map[string]struct{})
	for _, change := range config.Diff(live, disk) {
		if change.Tag == config.TagSafe {
			continue
		}
		// Attribute the change to the first section that contains it.
		for _, sec := range settingsSections {
			underSection := change.Field == sec.YAMLPath ||
				strings.HasPrefix(change.Field, sec.YAMLPath+".")
			if !underSection {
				continue
			}
			touched[sec.ID] = struct{}{}
			break
		}
	}
	if len(touched) == 0 {
		return nil
	}

	pending := make([]pendingSettingsSection, 0, len(touched))
	for _, sec := range settingsSections {
		if _, hit := touched[sec.ID]; !hit {
			continue
		}
		pending = append(pending, pendingSettingsSection{ID: sec.ID, Title: sec.Title})
	}
	return pending
}
// cloneConfigForSettingsApply returns a copy of *src whose Firewall
// section, when present, points at its own copy rather than sharing
// src's. All other fields are copied by plain struct assignment.
func cloneConfigForSettingsApply(src *config.Config) config.Config {
	dup := *src
	if src.Firewall == nil {
		return dup
	}
	firewallCopy := *src.Firewall
	dup.Firewall = &firewallCopy
	return dup
}
// apiSettingsSections answers GET with the settings section catalogue:
// the group ordering under "groups" and every section under "sections".
// Any other method gets a 405 JSON error.
func (s *Server) apiSettingsSections(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		payload := map[string]interface{}{
			"groups":   SectionGroupOrder,
			"sections": AllSettingsSections(),
		}
		writeJSON(w, payload)
		return
	}
	writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
}
// apiSettings dispatches requests under the settings prefix by method.
// GET reads a section; POST updates one and is the only path wrapped in
// CSRF validation. Every other method gets a 405 JSON error.
func (s *Server) apiSettings(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		s.apiSettingsGet(w, r)
		return
	}
	if r.Method == http.MethodPost {
		guarded := s.requireCSRF(http.HandlerFunc(s.apiSettingsPost))
		guarded.ServeHTTP(w, r)
		return
	}
	writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
}
func (s *Server) apiSettingsGet(w http.ResponseWriter, r *http.Request) {
sectionID := strings.TrimPrefix(r.URL.Path, settingsURLPrefix)
if sectionID == "" || strings.Contains(sectionID, "/") {
writeJSONError(w, "section required", http.StatusBadRequest)
return
}
if sectionID == "restart" {
writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
section, ok := LookupSettingsSection(sectionID)
if !ok {
writeJSONError(w, "unknown section", http.StatusNotFound)
return
}
resolveFieldOptions(§ion)
diskBytes, err := os.ReadFile(s.cfg.ConfigFile) // #nosec G304 -- operator-configured config path
if err != nil {
writeJSONError(w, "read config: "+err.Error(), http.StatusInternalServerError)
return
}
disk, err := config.LoadBytes(diskBytes)
if err != nil {
writeJSONError(w, "load config: "+err.Error(), http.StatusInternalServerError)
return
}
disk.ConfigFile = s.cfg.ConfigFile
redacted := config.Redact(disk)
values, err := extractSectionValues(diskBytes, redacted, section)
if err != nil {
writeJSONError(w, "extract: "+err.Error(), http.StatusInternalServerError)
return
}
var pendingFields []string
var pendingSections []pendingSettingsSection
if live := config.Active(); live != nil {
diff := config.Diff(live, disk)
for _, c := range diff {
if c.Tag != config.TagSafe && (c.Field == section.YAMLPath || strings.HasPrefix(c.Field, section.YAMLPath+".")) {
pendingFields = append(pendingFields, c.Field)
}
}
pendingSections = pendingRestartSections(live, disk)
}
w.Header().Set("ETag", disk.Integrity.ConfigHash)
writeJSON(w, map[string]interface{}{
"section": section,
"values": values,
"etag": disk.Integrity.ConfigHash,
"pending_restart": len(pendingFields) > 0,
"pending_fields": pendingFields,
"pending_sections": pendingSections,
})
}
// extractSectionValues builds the value map returned for one settings
// section. It starts from the section's effective values taken from
// effectiveCfg, then lets overlayNullableState adjust nullable fields
// using what is literally present in the raw config bytes.
func extractSectionValues(rawBytes []byte, effectiveCfg *config.Config, section SettingsSection) (map[string]interface{}, error) {
	effective, err := extractSectionEffectiveValues(effectiveCfg, section)
	if err != nil {
		return nil, err
	}
	onDisk, err := extractSectionRawValues(rawBytes, section)
	if err != nil {
		return nil, err
	}

	// Work on a copy so the overlay never mutates the effective map.
	merged := make(map[string]interface{}, len(effective))
	for key := range effective {
		merged[key] = effective[key]
	}
	overlayNullableState(section, merged, onDisk)
	return merged, nil
}
// extractSectionEffectiveValues returns the section's values as seen in
// cfg. The config is round-tripped through YAML into a generic map and
// the entry at section.YAMLPath is returned: as-is when it is a map,
// wrapped under its own key when it is a scalar or list, and as an empty
// map when the key is absent.
func extractSectionEffectiveValues(cfg *config.Config, section SettingsSection) (map[string]interface{}, error) {
	encoded, err := yaml.Marshal(cfg)
	if err != nil {
		return nil, err
	}
	var doc map[string]interface{}
	if err := yaml.Unmarshal(encoded, &doc); err != nil {
		return nil, err
	}

	node, present := doc[section.YAMLPath]
	if !present {
		return map[string]interface{}{}, nil
	}
	if asMap, isMap := node.(map[string]interface{}); isMap {
		return asMap, nil
	}
	return map[string]interface{}{section.YAMLPath: node}, nil
}
// extractSectionRawValues returns the section's values exactly as
// written in the raw YAML bytes, without defaults applied. The entry at
// section.YAMLPath is returned as-is when it is a map, wrapped under its
// own key otherwise, and as an empty map when the key is absent.
func extractSectionRawValues(rawBytes []byte, section SettingsSection) (map[string]interface{}, error) {
	var doc map[string]interface{}
	if err := yaml.Unmarshal(rawBytes, &doc); err != nil {
		return nil, err
	}

	node, present := doc[section.YAMLPath]
	if !present {
		return map[string]interface{}{}, nil
	}
	if asMap, isMap := node.(map[string]interface{}); isMap {
		return asMap, nil
	}
	return map[string]interface{}{section.YAMLPath: node}, nil
}
// overlayNullableState rewrites the entries of values that belong to
// nullable, top-level fields of section so they mirror raw: the raw value
// when the key is present in raw, nil otherwise. Nullable fields whose
// YAMLPath contains a dot are left untouched (all v1 nullable fields are
// direct children of the section).
func overlayNullableState(section SettingsSection, values, raw map[string]interface{}) {
	for i := range section.Fields {
		path := section.Fields[i].YAMLPath
		if !section.Fields[i].Nullable || strings.Contains(path, ".") {
			continue
		}
		// Indexing a map with a missing key yields a nil interface, which is
		// exactly the "not set on disk" marker we want to store.
		values[path] = raw[path]
	}
}
// apiSettingsPost handles a POST that edits one settings section. The section
// ID is the URL path remainder after settingsURLPrefix. The request must
// carry an If-Match header equal to the on-disk config hash and a JSON body
// of the form {"changes": {<field yaml path>: <json value>, ...}}.
//
// Flow: validate the request -> re-read and parse the config file ->
// optimistic-concurrency checks (config hash and conf.d hash) -> apply the
// changes to a clone and validate it -> edit the YAML bytes, sign and save ->
// update the active in-memory config -> audit-log and report what was applied.
//
// Status codes written here: 405 wrong method, 400 bad section / missing
// If-Match / bad body, 404 unknown section, 412 config hash mismatch,
// 500 read/parse/edit/save failures. Validation failures go through
// writeValidationErrors.
func (s *Server) apiSettingsPost(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
// Exactly one path segment is accepted as the section ID.
sectionID := strings.TrimPrefix(r.URL.Path, settingsURLPrefix)
if sectionID == "" || strings.Contains(sectionID, "/") {
writeJSONError(w, "section required", http.StatusBadRequest)
return
}
section, ok := LookupSettingsSection(sectionID)
if !ok {
writeJSONError(w, "unknown section", http.StatusNotFound)
return
}
// If-Match is mandatory; it is compared against the on-disk hash below.
ifMatch := r.Header.Get("If-Match")
if ifMatch == "" {
writeJSONError(w, "If-Match header required", http.StatusBadRequest)
return
}
var body struct {
Changes map[string]json.RawMessage `json:"changes"`
}
// Request body is capped at 256 KiB.
if err := decodeJSONBodyLimited(w, r, 256*1024, &body); err != nil {
writeJSONError(w, "invalid body: "+err.Error(), http.StatusBadRequest)
return
}
// Always work from the file on disk, not from the in-memory config.
diskBytes, err := os.ReadFile(s.cfg.ConfigFile) // #nosec G304 -- operator-supplied config path
if err != nil {
writeJSONError(w, "read config: "+err.Error(), http.StatusInternalServerError)
return
}
disk, err := config.LoadBytes(diskBytes)
if err != nil {
writeJSONError(w, "parse config: "+err.Error(), http.StatusInternalServerError)
return
}
disk.ConfigFile = s.cfg.ConfigFile
disk.ConfigDir = s.cfg.ConfigDir
// Optimistic concurrency: refuse the edit if the file changed since the
// client fetched its ETag.
if disk.Integrity.ConfigHash != ifMatch {
writeJSONError(w, "config changed on disk, reload", http.StatusPreconditionFailed)
return
}
// Same idea for the conf.d directory; the helper writes the response.
if rejectIfConfDirChanged(w, s.cfg.ConfigDir, disk.Integrity.ConfdHash) {
return
}
// Apply the requested changes to a clone so validation sees the candidate
// config while disk stays untouched for the later diff.
clone := cloneConfigForSettingsApply(disk)
yamlChanges, errs := buildChangeSet(section, &clone, body.Changes)
if len(errs) > 0 {
writeValidationErrors(w, errs)
return
}
validationResults := append(config.Validate(&clone), config.ValidateDeepSection(&clone, section.ID)...)
fieldErrors, warnings := splitValidationResults(validationResults)
if len(fieldErrors) > 0 {
writeValidationErrors(w, fieldErrors)
return
}
// Firewall saves additionally get lockout warnings (non-blocking).
if section.ID == "firewall" {
warnings = append(warnings, firewallLockoutWarnings(&clone)...)
}
// Every changed field whose tag is not TagSafe needs a restart to apply.
diff := config.Diff(disk, &clone)
var restartFields []string
for _, c := range diff {
if c.Tag != config.TagSafe {
restartFields = append(restartFields, c.Field)
}
}
edited, err := config.YAMLEdit(diskBytes, yamlChanges)
if err != nil {
writeJSONError(w, "yaml edit: "+err.Error(), http.StatusInternalServerError)
return
}
if err := integrity.SignAndSavePreserving(s.cfg.ConfigFile, s.cfg.ConfigDir, edited, &clone, disk.Integrity.BinaryHash); err != nil {
writeJSONError(w, "save: "+err.Error(), http.StatusInternalServerError)
return
}
// clone.Integrity is read only after the save call above; the new ETag
// returned to the client comes from it.
newETag := clone.Integrity.ConfigHash
newIntegrity := clone.Integrity
// Update the active config. Three cases:
//   - no restart-required fields and a live config exists: re-apply the same
//     changes on a clone of the live config; if that fails, only the
//     integrity block is refreshed;
//   - no restart-required fields and no live config: the clone becomes active;
//   - restart-required fields: only the integrity block of the live config is
//     refreshed, the field values stay as they were.
if len(restartFields) == 0 {
if live := config.Active(); live != nil {
liveClone := cloneConfigForSettingsApply(live)
if _, liveErrs := buildChangeSet(section, &liveClone, body.Changes); len(liveErrs) == 0 {
liveClone.Integrity = newIntegrity
config.SetActive(&liveClone)
} else {
livePatched := *live
livePatched.Integrity = newIntegrity
config.SetActive(&livePatched)
}
} else {
config.SetActive(&clone)
}
} else if live := config.Active(); live != nil {
livePatched := *live
livePatched.Integrity = newIntegrity
config.SetActive(&livePatched)
}
// applied lists every field reported by the diff, restart-required or not.
var applied []string
for _, c := range diff {
applied = append(applied, c.Field)
}
s.auditLog(r, "settings-save", sectionID, auditDetailsFor(section, body.Changes))
// NOTE(review): config.Active() can still be nil here (restart-required
// fields with no live config); confirm pendingRestartSections tolerates a
// nil first argument.
writeJSON(w, map[string]interface{}{
"applied": applied,
"requires_restart": restartFields,
"pending_restart": len(restartFields) > 0,
"pending_sections": pendingRestartSections(config.Active(), &clone),
"warnings": warnings,
"new_etag": newETag,
})
}
// rejectIfConfDirChanged recomputes the hash of confDir and compares it with
// storedHash. It returns true when it has already written an error response
// (500 if hashing failed, 412 if the hashes differ), in which case the caller
// must stop. It returns false, writing nothing, when the hashes match.
func rejectIfConfDirChanged(w http.ResponseWriter, confDir, storedHash string) bool {
	currentHash, err := integrity.HashConfDir(confDir)
	switch {
	case err != nil:
		writeJSONError(w, "hash conf.d: "+err.Error(), http.StatusInternalServerError)
		return true
	case currentHash != storedHash:
		writeJSONError(w, "conf.d changed on disk, reload", http.StatusPreconditionFailed)
		return true
	default:
		return false
	}
}
// fieldError is a single per-field message in settings API responses. The
// same shape carries both blocking validation errors and non-blocking
// warnings.
type fieldError struct {
// Field identifies the setting the message refers to.
Field string `json:"field"`
// Message is the human-readable explanation.
Message string `json:"message"`
}
// splitValidationResults partitions validator output by level: results with
// Level "error" go to errs, results with Level "warn" go to warnings, and any
// other level is dropped. Relative order within each bucket is preserved.
func splitValidationResults(results []config.ValidationResult) (errs []fieldError, warnings []fieldError) {
	for i := range results {
		entry := fieldError{Field: results[i].Field, Message: results[i].Message}
		switch results[i].Level {
		case "error":
			errs = append(errs, entry)
		case "warn":
			warnings = append(warnings, entry)
		}
	}
	return errs, warnings
}
// writeValidationErrors answers with HTTP 422 and a JSON body of the form
// {"errors": [...]} built from errs.
func writeValidationErrors(w http.ResponseWriter, errs []fieldError) {
	payload := map[string]interface{}{"errors": errs}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnprocessableEntity)
	// The status line is already committed, so an encode failure has nowhere
	// to be reported and is deliberately ignored.
	_ = json.NewEncoder(w).Encode(payload)
}
// buildChangeSet validates the submitted changes against the section schema,
// applies each accepted change to clone, and returns the matching list of
// YAML edits together with any per-field errors.
//
// changes maps a field's YAMLPath (relative to the section) to its new JSON
// value. Keys are processed in sorted order so the returned edits and errors
// are deterministic across calls. A key that fails validation contributes
// only an error; a key whose value cannot be applied to clone contributes
// both its edit and an error, so callers must discard the edits whenever the
// error slice is non-empty.
//
// Per-type handling:
//   - Secret fields: the "***REDACTED***" sentinel means "unchanged" and is
//     skipped.
//   - "enum" / "[]enum": values are checked against the field's options.
//   - "[]int": elements are range-checked, deduplicated and sorted.
//   - "float": JSON strings holding a number are coerced to JSON numbers.
//     JSON null is NOT coerced, because json.Unmarshal treats null as a no-op
//     and would silently turn it into 0; it is passed through so
//     decodeJSONForYAML can apply the field's Nullable rule.
func buildChangeSet(section SettingsSection, clone *config.Config, changes map[string]json.RawMessage) ([]config.YAMLChange, []fieldError) {
	keys := make([]string, 0, len(changes))
	for key := range changes {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	var out []config.YAMLChange
	var errs []fieldError
	for _, key := range keys {
		raw := changes[key]
		field := lookupSchemaField(section, key)
		if field == nil {
			errs = append(errs, fieldError{Field: key, Message: "unknown field"})
			continue
		}
		if field.Secret {
			// The UI echoes back the redaction sentinel for untouched secrets.
			var sv string
			if err := json.Unmarshal(raw, &sv); err == nil && sv == "***REDACTED***" {
				continue
			}
		}
		if field.Type == "[]enum" {
			if badValues, ok := validateEnumArray(field, raw); !ok {
				for _, bv := range badValues {
					errs = append(errs, fieldError{Field: key, Message: "unknown value: " + bv})
				}
				continue
			}
		}
		if field.Type == "enum" {
			if badValue, ok := validateEnumScalar(field, raw); !ok {
				errs = append(errs, fieldError{Field: key, Message: "unknown value: " + badValue})
				continue
			}
		}
		if field.Type == "[]int" {
			normalised, badValues, perr := normaliseIntArray(field, raw)
			if perr != nil {
				errs = append(errs, fieldError{Field: key, Message: perr.Error()})
				continue
			}
			if len(badValues) > 0 {
				for _, bv := range badValues {
					errs = append(errs, fieldError{Field: key, Message: "invalid value: " + bv})
				}
				continue
			}
			raw = normalised
		}
		// For float fields, coerce JSON string -> JSON number so the downstream
		// json.Unmarshal into *float64 (in applyToClone) succeeds. null is left
		// alone so the nullable check in decodeJSONForYAML still runs.
		if field.Type == "float" && string(raw) != "null" {
			normalised, ok := coerceFloatRaw(raw)
			if !ok {
				errs = append(errs, fieldError{Field: key, Message: "decode: expected float"})
				continue
			}
			raw = normalised
		}
		fullPath := section.YAMLPath
		if key != "" {
			fullPath = section.YAMLPath + "." + key
		}
		decoded, err := decodeJSONForYAML(raw, field)
		if err != nil {
			errs = append(errs, fieldError{Field: key, Message: "decode: " + err.Error()})
			continue
		}
		out = append(out, config.YAMLChange{Path: strings.Split(fullPath, "."), Value: decoded})
		if err := applyToClone(clone, strings.Split(fullPath, "."), raw); err != nil {
			errs = append(errs, fieldError{Field: key, Message: err.Error()})
		}
	}
	return out, errs
}
// validateEnumScalar checks that raw is a JSON string naming one of the
// field's resolved options.
//
// It returns ("", true) when the value is allowed or when the field resolves
// to no options at all (nothing to validate against). Otherwise ok is false
// and bad holds the rejected value, or "(not a string)" when raw does not
// decode as a string.
func validateEnumScalar(field *SettingsField, raw json.RawMessage) (bad string, ok bool) {
	var candidate string
	if json.Unmarshal(raw, &candidate) != nil {
		return "(not a string)", false
	}

	options := resolvedOptionsForField(field)
	if len(options) == 0 {
		return "", true
	}
	for i := range options {
		if options[i] == candidate {
			return "", true
		}
	}
	return candidate, false
}
// validateEnumArray checks that raw is a JSON array of strings and that every
// element is one of the field's resolved options (see
// resolvedOptionsForField, which keeps this independent of GET ordering).
//
// An empty array is accepted (it clears the list), and a field that resolves
// to no options accepts anything. On failure ok is false and bad lists each
// unknown value once, in order of first appearance; a payload that is not a
// string array yields the single entry "(not a string array)".
func validateEnumArray(field *SettingsField, raw json.RawMessage) (bad []string, ok bool) {
	var submitted []string
	if err := json.Unmarshal(raw, &submitted); err != nil {
		return []string{"(not a string array)"}, false
	}

	options := resolvedOptionsForField(field)
	if len(options) == 0 {
		return nil, true
	}

	permitted := make(map[string]bool, len(options))
	for _, opt := range options {
		permitted[opt] = true
	}

	// reported guards against listing the same unknown value twice.
	reported := make(map[string]bool)
	for _, value := range submitted {
		if permitted[value] || reported[value] {
			continue
		}
		reported[value] = true
		bad = append(bad, value)
	}
	return bad, len(bad) == 0
}
// resolvedOptionsForField returns the flat list of allowed values for an
// enum or []enum field. Static Options win when present; otherwise the list
// is produced from the field's OptionsSource. Keeps POST-side validation
// independent of the GET-time mutation done by resolveFieldOptions.
//
// Every source that resolveFieldOptions understands must be handled here as
// well: the validators treat an empty result as "nothing to validate
// against", so a source missing from this switch would let any value through.
// An unrecognised or empty OptionsSource yields an empty list.
func resolvedOptionsForField(field *SettingsField) []string {
	if len(field.Options) > 0 {
		return field.Options
	}
	// Resolve into a scratch field so the shared schema entry is not mutated.
	tmp := &SettingsField{Type: field.Type, OptionsSource: field.OptionsSource}
	switch field.OptionsSource {
	case "check_names":
		applyCheckNameOptions(tmp)
	case "disabled_check_names":
		applyDisabledCheckNameOptions(tmp)
	case "geoip_editions":
		applyGeoIPEditionOptions(tmp)
	}
	return tmp.Options
}
// firewallLockoutWarnings reports firewall configurations that would block
// access to the WebUI port:
//
//   - the WebUI port is missing from tcp_in;
//   - IPv6 is enabled, tcp6_in is non-empty, and the WebUI port is missing
//     from tcp6_in.
//
// The port is taken from cfg.WebUI.Listen, falling back to "0.0.0.0:9443"
// when that is empty. A nil cfg, a nil Firewall section, or a listen address
// whose port cannot be parsed produces no warnings. The result is advisory
// only -- never errors -- so the UI can ask for confirmation without blocking
// deliberate changes.
func firewallLockoutWarnings(cfg *config.Config) []fieldError {
	if cfg == nil || cfg.Firewall == nil {
		return nil
	}

	addr := cfg.WebUI.Listen
	if addr == "" {
		addr = "0.0.0.0:9443"
	}
	_, rawPort, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		return nil
	}
	port, convErr := strconv.Atoi(rawPort)
	if convErr != nil {
		return nil
	}

	// allows reports whether the WebUI port appears in list.
	allows := func(list []int) bool {
		for i := range list {
			if list[i] == port {
				return true
			}
		}
		return false
	}

	fw := cfg.Firewall
	var warnings []fieldError
	if !allows(fw.TCPIn) {
		warnings = append(warnings, fieldError{
			Field:   "tcp_in",
			Message: fmt.Sprintf("WebUI listens on %d but the port is not in tcp_in. Restart will lock you out of the WebUI.", port),
		})
	}
	if fw.IPv6 && len(fw.TCP6In) > 0 && !allows(fw.TCP6In) {
		warnings = append(warnings, fieldError{
			Field:   "tcp6_in",
			Message: fmt.Sprintf("IPv6 dual-stack is enabled and tcp6_in does not include WebUI port %d.", port),
		})
	}
	return warnings
}
// normaliseIntArray turns raw -- a JSON array whose elements are integers or
// strings holding base-10 integers -- into a JSON array of distinct integers
// in ascending order.
//
// Each element must lie within [field.Min, field.Max]; a missing bound
// defaults to 1 (Min) or 65535 (Max), i.e. port-list semantics. Strings are
// trimmed first and blank strings are silently dropped.
//
// Returns:
//   - (encoded, nil, nil) on success;
//   - (nil, bad, nil) when one or more elements are unparseable or out of
//     range, with bad holding their textual form;
//   - (nil, nil, err) when raw is not a JSON array at all.
func normaliseIntArray(field *SettingsField, raw json.RawMessage) (json.RawMessage, []string, error) {
	var elems []json.RawMessage
	if err := json.Unmarshal(raw, &elems); err != nil {
		return nil, nil, fmt.Errorf("expected array of ints: %s", err.Error())
	}

	lo, hi := int64(1), int64(65535)
	if field.Min != nil {
		lo = *field.Min
	}
	if field.Max != nil {
		hi = *field.Max
	}

	distinct := make(map[int64]struct{}, len(elems))
	accepted := make([]int64, 0, len(elems))
	var rejected []string
	for _, elem := range elems {
		var n int64
		if json.Unmarshal(elem, &n) != nil {
			// Not a JSON integer: fall back to a string that holds one.
			var text string
			if json.Unmarshal(elem, &text) != nil {
				rejected = append(rejected, string(elem))
				continue
			}
			text = strings.TrimSpace(text)
			if text == "" {
				continue
			}
			parsed, convErr := strconv.ParseInt(text, 10, 64)
			if convErr != nil {
				rejected = append(rejected, text)
				continue
			}
			n = parsed
		}

		_, duplicate := distinct[n]
		switch {
		case n < lo || n > hi:
			rejected = append(rejected, strconv.FormatInt(n, 10))
		case duplicate:
			// Already recorded; keep the first occurrence only.
		default:
			distinct[n] = struct{}{}
			accepted = append(accepted, n)
		}
	}
	if len(rejected) > 0 {
		return nil, rejected, nil
	}

	sort.Slice(accepted, func(a, b int) bool { return accepted[a] < accepted[b] })
	encoded, err := json.Marshal(accepted)
	if err != nil {
		return nil, nil, err
	}
	return encoded, nil, nil
}
// coerceFloatRaw returns a JSON-number representation of raw if raw is either
// a JSON number already or a JSON string that parses to float64. The second
// return is false if neither form is valid.
//
// Values that cannot be written as a JSON number are rejected too: strings
// such as "NaN" or "Inf" parse with strconv.ParseFloat but have no JSON
// encoding, so they yield (nil, false) instead of a nil payload flagged as
// valid.
//
// Note: JSON null decodes into a float64 as a no-op, so it is reported as
// "0". Callers that must distinguish null have to check for it before
// calling.
func coerceFloatRaw(raw json.RawMessage) (json.RawMessage, bool) {
	var f float64
	if err := json.Unmarshal(raw, &f); err != nil {
		// Not a JSON number; accept a JSON string that holds one.
		var text string
		if err := json.Unmarshal(raw, &text); err != nil {
			return nil, false
		}
		parsed, err := strconv.ParseFloat(text, 64)
		if err != nil {
			return nil, false
		}
		f = parsed
	}
	encoded, err := json.Marshal(f)
	if err != nil {
		// NaN and +/-Inf are not representable in JSON.
		return nil, false
	}
	return encoded, true
}
func lookupSchemaField(section SettingsSection, key string) *SettingsField {
for i := range section.Fields {
if section.Fields[i].YAMLPath == key {
return §ion.Fields[i]
}
}
return nil
}
// decodeJSONForYAML converts a submitted JSON value into the Go value that
// will be written into the YAML document.
//
//   - The literal null is accepted only for Nullable fields and decodes to
//     nil; otherwise it is an error.
//   - "float" fields accept a JSON number or a JSON string containing a
//     number and always decode to float64.
//   - Everything else is decoded generically into interface{}.
func decodeJSONForYAML(raw json.RawMessage, field *SettingsField) (interface{}, error) {
	switch {
	case string(raw) == "null":
		if field.Nullable {
			return nil, nil
		}
		return nil, fmt.Errorf("null is only allowed for nullable fields")

	case field.Type == "float":
		var number float64
		if json.Unmarshal(raw, &number) == nil {
			return number, nil
		}
		var text string
		if json.Unmarshal(raw, &text) != nil {
			return nil, fmt.Errorf("expected float, got %s", string(raw))
		}
		parsed, convErr := strconv.ParseFloat(text, 64)
		if convErr != nil {
			return nil, fmt.Errorf("not a float: %q", text)
		}
		return parsed, nil
	}

	var generic interface{}
	if err := json.Unmarshal(raw, &generic); err != nil {
		return nil, err
	}
	return generic, nil
}
// applyToClone sets the field of cfg addressed by path (a sequence of yaml
// tag names, outermost first) to the value encoded in raw.
//
// While descending, nil pointers on the way are allocated so the target
// always exists. The value is decoded into a fresh instance of the target
// field's type and assigned only if decoding succeeds, so a failed decode
// leaves the target field itself unchanged (pointers allocated during the
// descent stay allocated).
//
// An error is returned when a path element is reached on a non-struct value,
// when no field carries the requested yaml tag, or when raw does not decode
// into the target type.
func applyToClone(cfg *config.Config, path []string, raw json.RawMessage) error {
	cursor := reflect.ValueOf(cfg).Elem()
	for depth := range path {
		if cursor.Kind() == reflect.Pointer {
			if cursor.IsNil() {
				cursor.Set(reflect.New(cursor.Type().Elem()))
			}
			cursor = cursor.Elem()
		}
		if cursor.Kind() != reflect.Struct {
			return fmt.Errorf("path %v: element %d is not a struct", path, depth)
		}
		structField, found := fieldByYAMLTag(cursor.Type(), path[depth])
		if !found {
			return fmt.Errorf("no yaml field %q under %s", path[depth], strings.Join(path[:depth], "."))
		}
		cursor = cursor.FieldByIndex(structField.Index)
	}

	// Decode into a scratch value first; assign only on success.
	scratch := reflect.New(cursor.Type())
	if err := json.Unmarshal(raw, scratch.Interface()); err != nil {
		return fmt.Errorf("unmarshal into %s: %w", cursor.Type(), err)
	}
	cursor.Set(scratch.Elem())
	return nil
}
// fieldByYAMLTag finds the direct field of struct type t whose `yaml` tag
// name (the part before the first comma) equals yamlName. Fields without a
// yaml tag are never matched. The boolean result is false when no field
// matches.
func fieldByYAMLTag(t reflect.Type, yamlName string) (reflect.StructField, bool) {
	for i, count := 0, t.NumField(); i < count; i++ {
		candidate := t.Field(i)
		tag, tagged := candidate.Tag.Lookup("yaml")
		if !tagged || tag == "" {
			continue
		}
		// Drop options such as ",omitempty" that follow the name.
		if name, _, _ := strings.Cut(tag, ","); name == yamlName {
			return candidate, true
		}
	}
	return reflect.StructField{}, false
}
// auditDetailsFor renders the submitted changes as a JSON object string for
// the audit log. Values of fields marked Secret in the section schema are
// replaced by "***". Other values are decoded and re-encoded; a value that
// fails to decode is logged as null.
func auditDetailsFor(section SettingsSection, changes map[string]json.RawMessage) string {
	summary := make(map[string]interface{}, len(changes))
	for name, payload := range changes {
		var shown interface{}
		if f := lookupSchemaField(section, name); f != nil && f.Secret {
			shown = "***"
		} else {
			// Best effort: an undecodable payload simply stays nil.
			_ = json.Unmarshal(payload, &shown)
		}
		summary[name] = shown
	}
	encoded, _ := json.Marshal(summary)
	return string(encoded)
}
// defaultRestartDaemon is the production restart hook: it runs
// `systemctl restart csm` with a 30 second deadline and returns the command's
// combined stdout/stderr together with its error. Tests replace
// s.restartDaemon with a fake instead of calling this.
func defaultRestartDaemon() ([]byte, error) {
	const restartTimeout = 30 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), restartTimeout)
	defer cancel()

	// #nosec G204 -- fixed argv, no operator input interpolated.
	return exec.CommandContext(ctx, "systemctl", "restart", "csm").CombinedOutput()
}
// apiSettingsRestart handles POST /api/v1/settings/restart. Returns 202
// on successful systemctl invocation (the server process may die
// mid-response, so the frontend treats a connection reset as expected).
// Returns 500 + stderr truncated to 4 KiB on failure before teardown.
// Any other method is answered with 405.
//
// Both the success and the failure body are JSON and are labelled as such;
// without an explicit Content-Type the 202 body would be content-sniffed and
// served as text/plain.
func (s *Server) apiSettingsRestart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Record the request before acting: a successful restart may kill this
	// process before anything else can be logged.
	s.auditLog(r, "settings-restart", "daemon", "")
	output, err := s.restartDaemon()
	if err != nil {
		truncated := output
		if len(truncated) > 4096 {
			truncated = truncated[:4096]
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"error":  err.Error(),
			"stderr": string(truncated),
		})
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusAccepted)
	// The connection may already be going away; a failed write is expected.
	_, _ = w.Write([]byte(`{"status":"restart issued"}`))
}
package webui
import (
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/geoip"
)
// resolveFieldOptions populates Options / OptionGroups for any []enum field
// that declares an OptionsSource. Called once per GET /api/v1/settings/:id
// so the UI always sees a fresh list (for check_names this matters if the
// registry grows over time).
//
// The section's Fields slice header points into the package-level
// settingsSections backing array. Writing into it directly would race
// across concurrent requests. Copy the slice first so mutations are local
// to this request.
func resolveFieldOptions(section *SettingsSection) {
fields := make([]SettingsField, len(section.Fields))
copy(fields, section.Fields)
section.Fields = fields
for i := range section.Fields {
f := §ion.Fields[i]
if f.OptionsSource == "" {
continue
}
switch f.OptionsSource {
case "check_names":
applyCheckNameOptions(f)
case "disabled_check_names":
applyDisabledCheckNameOptions(f)
case "geoip_editions":
applyGeoIPEditionOptions(f)
}
}
}
// applyCheckNameOptions fills f.Options and f.OptionGroups from
// checks.PublicCheckInfos. Checks are grouped by Category; groups appear in
// the order their category is first seen, and names keep their input order
// within a group. f.Options is the concatenation of all groups in that same
// order.
func applyCheckNameOptions(f *SettingsField) {
	infos := checks.PublicCheckInfos()

	namesByCategory := make(map[string][]string)
	var categories []string
	for _, info := range infos {
		names, known := namesByCategory[info.Category]
		if !known {
			categories = append(categories, info.Category)
		}
		namesByCategory[info.Category] = append(names, info.Name)
	}

	grouped := make([]OptionGroup, 0, len(categories))
	all := make([]string, 0, len(infos))
	for _, category := range categories {
		names := namesByCategory[category]
		grouped = append(grouped, OptionGroup{Label: category, Values: names})
		all = append(all, names...)
	}
	f.Options, f.OptionGroups = all, grouped
}
// applyDisabledCheckNameOptions fills f.Options and f.OptionGroups with the
// subset of checks.PublicCheckInfos whose name is also returned by
// checks.DisabledCheckNames. Grouping and ordering follow the same rules as
// applyCheckNameOptions: groups in first-seen category order, names in input
// order, and f.Options as the concatenation of the groups.
func applyDisabledCheckNameOptions(f *SettingsField) {
	eligible := make(map[string]bool)
	for _, name := range checks.DisabledCheckNames() {
		eligible[name] = true
	}

	namesByCategory := make(map[string][]string)
	var categories []string
	for _, info := range checks.PublicCheckInfos() {
		if !eligible[info.Name] {
			continue
		}
		names, known := namesByCategory[info.Category]
		if !known {
			categories = append(categories, info.Category)
		}
		namesByCategory[info.Category] = append(names, info.Name)
	}

	grouped := make([]OptionGroup, 0, len(categories))
	all := make([]string, 0, len(eligible))
	for _, category := range categories {
		names := namesByCategory[category]
		grouped = append(grouped, OptionGroup{Label: category, Values: names})
		all = append(all, names...)
	}
	f.Options, f.OptionGroups = all, grouped
}
// applyGeoIPEditionOptions fills f.Options and f.OptionGroups from
// geoip.KnownEditions. The flat list holds the free editions followed by the
// commercial ones; the groups keep the two sets apart under the labels
// "GeoLite2 (free)" and "GeoIP2 (paid)".
func applyGeoIPEditionOptions(f *SettingsField) {
	free, commercial := geoip.KnownEditions()

	// Build a fresh slice so the flat list never aliases the edition slices.
	combined := make([]string, 0, len(free)+len(commercial))
	combined = append(combined, free...)
	combined = append(combined, commercial...)

	f.Options = combined
	f.OptionGroups = []OptionGroup{
		{Label: "GeoLite2 (free)", Values: free},
		{Label: "GeoIP2 (paid)", Values: commercial},
	}
}
package webui
import "github.com/pidginhost/csm/internal/config"
// OptionGroup is an ordered label + values pair used to render grouped
// multi-select options (e.g. "Authentication & Login" → [cpanel_login, ...]).
// It is serialised to JSON as {"label": ..., "values": [...]}.
type OptionGroup struct {
// Label is the heading shown for the group.
Label string `json:"label"`
// Values are the option values that belong to the group, in display order.
Values []string `json:"values"`
}
// SettingsField describes a single editable leaf within a settings
// section. YAMLPath is the dotted key path relative to the section's
// YAMLPath. For example inside the Alerts section, the field with
// YAMLPath "email.enabled" has full path "alerts.email.enabled".
//
// For Type "enum" and "[]enum" fields, Options and/or OptionGroups are
// resolved at request time. A field may either declare a static Options list
// or set OptionsSource to have the handler populate Options/OptionGroups from
// a registry ("check_names", "disabled_check_names", "geoip_editions").
type SettingsField struct {
// YAMLPath is the dotted key path relative to the owning section.
YAMLPath string `json:"yaml_path"`
// Type is the field's type tag. Values used by the schema and handlers in
// this package include "bool", "string", "[]string", "int", "float",
// "enum", "[]enum" and "[]int".
Type string `json:"type"`
// Label is the human-readable field name.
Label string `json:"label"`
// Help is optional explanatory text.
Help string `json:"help,omitempty"`
// Secret marks sensitive values: a submitted "***REDACTED***" is treated as
// "unchanged" on save, and the value is masked in the audit log.
Secret bool `json:"secret,omitempty"`
// Nullable permits a JSON null for this field on save.
Nullable bool `json:"nullable,omitempty"`
// Min and Max are optional inclusive bounds. For "[]int" fields they bound
// each element, defaulting to 1 and 65535 when nil.
// NOTE(review): enforcement for scalar "int" fields is not visible in this
// part of the file -- confirm where it happens.
Min *int64 `json:"min,omitempty"`
Max *int64 `json:"max,omitempty"`
// Options is the flat list of allowed values for enum-like fields.
Options []string `json:"options,omitempty"`
// OptionGroups is the same set of values arranged into labelled groups.
OptionGroups []OptionGroup `json:"option_groups,omitempty"`
// OptionsSource names the registry used to fill Options/OptionGroups when
// no static Options are declared.
OptionsSource string `json:"options_source,omitempty"`
// Placeholder is an example value (e.g. "smtp.example.com:587").
Placeholder string `json:"placeholder,omitempty"`
// FieldGroup is the inner subdivider label rendered as a fieldset
// inside a single section (e.g. firewall fields split into Access
// ports / Rate limits / Logging). Empty string means the field
// renders ungrouped under the section's flat grid.
FieldGroup string `json:"field_group,omitempty"`
}
// SettingsSection groups the fields of one top-level Config sub-tree.
// YAMLPath is the root key in csm.yaml (e.g. "auto_response"). ID is
// the URL-path identifier used by the API. Restart is a UI hint based
// on the current hotreload struct tag; final safe-vs-restart authority
// comes from config.Diff at runtime. Icon is a Tabler icon suffix (e.g.
// "bell" for "ti ti-bell"); Group is the nav category the section lives
// in, one of the SectionGroup* constants ("Alerting", "Detection",
// "Firewall", "Integrations", "Operations").
type SettingsSection struct {
// ID is the URL-path identifier of the section.
ID string `json:"id"`
// Title is the human-readable section name.
Title string `json:"title"`
// YAMLPath is the section's root key in the config file.
YAMLPath string `json:"yaml_path"`
// Restart is only a hint for the UI; it is serialised as "restart_hint".
Restart bool `json:"restart_hint"`
// NOTE(review): ReloadTag is not referenced in the visible part of the
// file; presumably it mirrors the section's hot-reload tag -- confirm.
ReloadTag string `json:"reload_tag,omitempty"`
// Icon is a Tabler icon suffix.
Icon string `json:"icon,omitempty"`
// Group is the sidebar category.
Group string `json:"group,omitempty"`
// Fields lists the editable leaves of the section.
Fields []SettingsField `json:"fields"`
}
// Section groups for the sidebar. These are the values a SettingsSection's
// Group field takes. The display order of the groups is defined separately
// by SectionGroupOrder.
const (
SectionGroupAlerting = "Alerting"
SectionGroupDetection = "Detection"
SectionGroupFirewall = "Firewall"
SectionGroupIntegrations = "Integrations"
// Note: the constant is named "Ops" but the displayed label is "Operations".
SectionGroupOps = "Operations"
)
// SectionGroupOrder is the display order of sidebar group headers. It lists
// every SectionGroup* constant exactly once.
var SectionGroupOrder = []string{
SectionGroupAlerting,
SectionGroupDetection,
SectionGroupFirewall,
SectionGroupIntegrations,
SectionGroupOps,
}
// Field-group labels used to subdivide large sections. These are the values
// a SettingsField's FieldGroup takes. Reuse them via the constants so the
// schema does not scatter free strings (Phase 6) and the static test can
// lock the labels in.
const (
FieldGroupAccessPorts = "Access ports"
FieldGroupIPv6 = "IPv6"
FieldGroupRateLimits = "Rate limits"
FieldGroupFloodProtection = "Flood protection"
FieldGroupGeoDynDNS = "Geo and DynDNS"
FieldGroupSMTPControls = "SMTP controls"
FieldGroupLogging = "Logging"
FieldGroupLimits = "Limits"
FieldGroupScanIntervals = "Scan intervals"
FieldGroupWebBruteForce = "Web brute force" // #nosec G101 -- UI label, not a credential.
FieldGroupMailBruteForce = "Mail brute force"
FieldGroupSMTPBruteForce = "SMTP brute force"
FieldGroupAccountSpray = "Account spray"
FieldGroupStateRetention = "State retention"
FieldGroupAbuseReporting = "Abuse reporting"
FieldGroupCentralDB = "Central database"
)
// int64p returns a pointer to a fresh copy of v. It exists so schema literals
// can set the optional Min/Max bounds inline, e.g. Min: int64p(0).
func int64p(v int64) *int64 {
	boxed := new(int64)
	*boxed = v
	return boxed
}
// settingsSections is the static schema of the settings UI: one
// SettingsSection per configuration area, each listing the fields
// exposed for it. Slice order is the order reported by
// SettingsSectionIDs and AllSettingsSections.
//
// The Restart values declared here are defaults only. withReloadPolicy
// replaces Restart (and fills ReloadTag) whenever
// config.HotReloadManifest has an entry whose Field equals the
// section's YAMLPath, so callers should read sections through
// LookupSettingsSection or AllSettingsSections rather than this slice.
var settingsSections = []SettingsSection{
{
ID: "alerts",
Title: "Alerts",
YAMLPath: "alerts",
Icon: "bell",
Group: SectionGroupAlerting,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "email.enabled", Type: "bool", Label: "Email alerts enabled"},
{YAMLPath: "email.to", Type: "[]string", Label: "Recipients", Help: "One email address per line"},
{YAMLPath: "email.from", Type: "string", Label: "From address"},
{YAMLPath: "email.smtp", Type: "string", Label: "SMTP server", Placeholder: "smtp.example.com:587"},
{YAMLPath: "email.disabled_checks", Type: "[]enum", Label: "Disabled check names", OptionsSource: "check_names", Help: "Findings with these check names never trigger email alerts."},
{YAMLPath: "webhook.enabled", Type: "bool", Label: "Webhook alerts enabled"},
{YAMLPath: "webhook.url", Type: "string", Label: "Webhook URL"},
{YAMLPath: "webhook.type", Type: "enum", Label: "Webhook type", Options: []string{"slack", "discord", "generic", "phpanel"}},
{YAMLPath: "webhook.hmac_secret", Type: "string", Label: "Webhook HMAC secret", Secret: true},
{YAMLPath: "webhook.hmac_secret_env", Type: "string", Label: "Webhook HMAC secret env"},
{YAMLPath: "webhook.per_finding", Type: "bool", Label: "Per-finding webhook delivery"},
{YAMLPath: "heartbeat.enabled", Type: "bool", Label: "Heartbeat enabled"},
{YAMLPath: "heartbeat.url", Type: "string", Label: "Heartbeat URL"},
{YAMLPath: "max_per_hour", Type: "int", Label: "Max alerts per hour", Min: int64p(0), Max: int64p(10000)},
},
},
{
ID: "thresholds",
Title: "Thresholds",
YAMLPath: "thresholds",
Icon: "adjustments",
Group: SectionGroupAlerting,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "mail_queue_warn", Type: "int", Label: "Mail queue warn", Min: int64p(0), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "mail_queue_crit", Type: "int", Label: "Mail queue critical", Min: int64p(0), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "state_expiry_hours", Type: "int", Label: "State expiry (hours)", Min: int64p(1), FieldGroup: FieldGroupStateRetention},
{YAMLPath: "deep_scan_interval_min", Type: "int", Label: "Deep scan interval (min)", Min: int64p(1), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "wp_core_check_interval_min", Type: "int", Label: "WP core check interval (min)", Min: int64p(1), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "webshell_scan_interval_min", Type: "int", Label: "Webshell scan interval (min)", Min: int64p(1), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "filesystem_scan_interval_min", Type: "int", Label: "Filesystem scan interval (min)", Min: int64p(1), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "multi_ip_login_threshold", Type: "int", Label: "Multi-IP login threshold", Min: int64p(1), FieldGroup: FieldGroupAccountSpray},
{YAMLPath: "multi_ip_login_window_min", Type: "int", Label: "Multi-IP login window (min)", Min: int64p(1), FieldGroup: FieldGroupAccountSpray},
{YAMLPath: "cred_stuffing_distinct_accounts", Type: "int", Label: "Credential stuffing distinct accounts", Min: int64p(2), Max: int64p(200), FieldGroup: FieldGroupAccountSpray, Help: "Distinct failed accounts from one source IP inside the auth window before credential_stuffing fires. Default 5."},
{YAMLPath: "plugin_check_interval_min", Type: "int", Label: "Plugin check interval (min)", Min: int64p(1), FieldGroup: FieldGroupScanIntervals},
{YAMLPath: "brute_force_window", Type: "int", Label: "Brute force window", Min: int64p(1), FieldGroup: FieldGroupWebBruteForce},
{YAMLPath: "domlog_max_files", Type: "int", Label: "Domlog max files", Min: int64p(1), Max: int64p(100000), FieldGroup: FieldGroupWebBruteForce},
{YAMLPath: "domlog_tail_lines", Type: "int", Label: "Domlog tail lines", Min: int64p(10), Max: int64p(100000), FieldGroup: FieldGroupWebBruteForce, Help: "Trailing lines tailed from each per-domain access log per WP brute-force cycle. Default 500 covers ~10 minutes of traffic on a busy site."},
{YAMLPath: "domlog_max_age_min", Type: "int", Label: "Domlog max age (min)", Min: int64p(1), Max: int64p(1440), FieldGroup: FieldGroupWebBruteForce, Help: "Skip per-domain access logs untouched for this many minutes. Default 30. Raise on low-traffic hosts where a slow-burn dictionary attack against a quiet domain still needs to fall inside the freshness window."},
{YAMLPath: "mail_log_tail_lines", Type: "int", Label: "Mail log tail lines", Min: int64p(10), Max: int64p(100000), FieldGroup: FieldGroupMailBruteForce, Help: "Trailing lines of /var/log/exim_mainlog read by the per-account mail rate scanner. Default 500. Raise on busy mail hosts where a single account's spam burst spreads across more than 500 lines per cycle."},
{YAMLPath: "syslog_messages_tail_lines", Type: "int", Label: "Syslog messages tail lines", Min: int64p(10), Max: int64p(100000), FieldGroup: FieldGroupScanIntervals, Help: "Trailing lines of /var/log/messages read by the FTP brute-force scanner. Default 200. Raise on hosts that share /var/log/messages with noisy services so pure-ftpd failure lines do not fall outside the window."},
{YAMLPath: "account_scan_max_files", Type: "int", Label: "Account scan max files", Min: int64p(1), Max: int64p(100000), FieldGroup: FieldGroupLimits, Help: "Per-cycle cap for account and mail-domain scanner paths. Default 10000. Raise on very large multi-tenant hosts."},
{YAMLPath: "crontab_base64_blob_max_bytes", Type: "int", Label: "Crontab base64 blob max bytes", Min: int64p(1024), Max: int64p(1048576), FieldGroup: FieldGroupLimits, Help: "Encoded-byte cap for one crontab base64 candidate before decoded-content matching. Default 16384. Must be a multiple of 4."},
{YAMLPath: "http_flood_threshold", Type: "int", Label: "HTTP flood threshold", Min: int64p(0), FieldGroup: FieldGroupWebBruteForce, Help: "Per-IP requests per window that emits http_request_flood. 0 disables. Sample baseline first."},
{YAMLPath: "http_flood_window_min", Type: "int", Label: "HTTP flood window (min)", Min: int64p(1), FieldGroup: FieldGroupWebBruteForce},
{YAMLPath: "http_ua_spoof_threshold", Type: "int", Label: "UA spoof threshold", Min: int64p(1), FieldGroup: FieldGroupWebBruteForce},
{YAMLPath: "http_distributed_min_ips", Type: "int", Label: "Distributed HTTP min IPs", Min: int64p(0), FieldGroup: FieldGroupWebBruteForce, Help: "Distinct already-abusive source IPs per vhost before the distributed HTTP flood rollup fires. 0 disables."},
{YAMLPath: "smtp_bruteforce_threshold", Type: "int", Label: "SMTP bruteforce threshold", Min: int64p(1), FieldGroup: FieldGroupSMTPBruteForce},
{YAMLPath: "smtp_bruteforce_window_min", Type: "int", Label: "SMTP bruteforce window (min)", Min: int64p(1), FieldGroup: FieldGroupSMTPBruteForce},
{YAMLPath: "smtp_bruteforce_suppress_min", Type: "int", Label: "SMTP bruteforce suppress (min)", Min: int64p(1), FieldGroup: FieldGroupSMTPBruteForce},
{YAMLPath: "smtp_bruteforce_subnet_threshold", Type: "int", Label: "SMTP bruteforce /24 threshold", Min: int64p(1), FieldGroup: FieldGroupSMTPBruteForce},
{YAMLPath: "smtp_account_spray_threshold", Type: "int", Label: "SMTP account spray threshold", Min: int64p(1), FieldGroup: FieldGroupAccountSpray},
{YAMLPath: "smtp_bruteforce_max_tracked", Type: "int", Label: "SMTP bruteforce max tracked", Min: int64p(100), FieldGroup: FieldGroupStateRetention},
{YAMLPath: "mail_bruteforce_threshold", Type: "int", Label: "Mail bruteforce threshold", Min: int64p(1), FieldGroup: FieldGroupMailBruteForce},
{YAMLPath: "mail_bruteforce_window_min", Type: "int", Label: "Mail bruteforce window (min)", Min: int64p(1), FieldGroup: FieldGroupMailBruteForce},
{YAMLPath: "mail_bruteforce_suppress_min", Type: "int", Label: "Mail bruteforce suppress (min)", Min: int64p(1), FieldGroup: FieldGroupMailBruteForce},
{YAMLPath: "mail_bruteforce_subnet_threshold", Type: "int", Label: "Mail bruteforce /24 threshold", Min: int64p(1), FieldGroup: FieldGroupMailBruteForce},
{YAMLPath: "mail_account_spray_threshold", Type: "int", Label: "Mail account spray threshold", Min: int64p(1), FieldGroup: FieldGroupAccountSpray},
{YAMLPath: "mail_bruteforce_max_tracked", Type: "int", Label: "Mail bruteforce max tracked", Min: int64p(100), FieldGroup: FieldGroupStateRetention},
{YAMLPath: "mail_brute_account_key", Type: "string", Label: "Mail brute account-key extractor", FieldGroup: FieldGroupMailBruteForce, Placeholder: "builtin:dovecot-user", Help: "How to derive the account from a dovecot/postfix log line: builtin:dovecot-user (default), builtin:postfix-sasl, or regex:<pattern> where group 1 is the account."},
},
},
{
ID: "mail_logs",
Title: "Mail logs",
YAMLPath: "mail_logs",
Icon: "mail",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "source", Type: "enum", Label: "Log source", Options: []string{"auto", "file", "journal"}, Help: "auto: try platform default file then fall back to journal. file: require log file. journal: read systemd-journald (needs journal build tag)."},
{YAMLPath: "file", Type: "string", Label: "Log file override", Placeholder: "/var/log/maillog", Help: "Override the platform-default file path. Leave blank to keep the default."},
{YAMLPath: "units", Type: "[]string", Label: "Journal units", Help: "Systemd units to match when source=journal. One per line (e.g. postfix, dovecot)."},
},
},
{
ID: "suppressions",
Title: "Suppressions",
YAMLPath: "suppressions",
Icon: "volume-off",
Group: SectionGroupAlerting,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "upcp_window_start", Type: "string", Label: "UPCP window start (HH:MM)"},
{YAMLPath: "upcp_window_end", Type: "string", Label: "UPCP window end (HH:MM)"},
{YAMLPath: "known_api_tokens", Type: "[]string", Label: "Known API tokens (hashed)"},
{YAMLPath: "ignore_paths", Type: "[]string", Label: "Ignore paths"},
{YAMLPath: "suppress_webmail_alerts", Type: "bool", Label: "Suppress webmail login alerts"},
{YAMLPath: "suppress_cpanel_login_alerts", Type: "bool", Label: "Suppress cPanel login alerts"},
{YAMLPath: "suppress_blocked_alerts", Type: "bool", Label: "Suppress alerts on auto-blocked IPs"},
{YAMLPath: "trusted_countries", Type: "[]string", Label: "Trusted countries (ISO 3166-1 alpha-2)"},
},
},
{
ID: "auto_response",
Title: "Auto-Response",
YAMLPath: "auto_response",
Icon: "bolt",
Group: SectionGroupDetection,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Auto-response enabled"},
{YAMLPath: "kill_processes", Type: "bool", Label: "Kill malicious processes"},
{YAMLPath: "quarantine_files", Type: "bool", Label: "Quarantine malicious files"},
{YAMLPath: "block_ips", Type: "bool", Label: "Block attacker IPs"},
{YAMLPath: "block_expiry", Type: "string", Label: "Block expiry", Placeholder: "24h"},
{YAMLPath: "max_blocks_per_hour", Type: "int", Label: "Max IP blocks per hour", Min: int64p(0), Help: "0 uses the default 50/hour cap."},
{YAMLPath: "enforce_permissions", Type: "bool", Label: "Auto-chmod 644 world/group-writable PHP"},
{YAMLPath: "fix_wp_cron", Type: "bool", Label: "Auto-disable WP-Cron + install system cron", Help: "On perf_wp_cron findings, edit wp-config.php and add a per-user cron. Tune interval/php under Performance."},
{YAMLPath: "block_cpanel_logins", Type: "bool", Label: "Block on cPanel/webmail login alerts"},
{YAMLPath: "netblock", Type: "bool", Label: "Auto-block /24 on threshold"},
{YAMLPath: "netblock_threshold", Type: "int", Label: "Netblock threshold", Min: int64p(1)},
{YAMLPath: "permblock", Type: "bool", Label: "Auto-promote to permanent"},
{YAMLPath: "permblock_count", Type: "int", Label: "Temp blocks before permanent", Min: int64p(1)},
{YAMLPath: "permblock_interval", Type: "string", Label: "Permblock window", Placeholder: "24h"},
{YAMLPath: "clean_database", Type: "bool", Label: "Auto-clean DB injections"},
{YAMLPath: "dry_run", Type: "bool", Label: "Auto-block dry run (logs only, no nftables)"},
{YAMLPath: "verdict_callback.enabled", Type: "bool", Label: "Verdict callback hook"},
{YAMLPath: "verdict_callback.url", Type: "string", Label: "Verdict callback URL"},
{YAMLPath: "verdict_callback.hmac_secret", Type: "string", Label: "Verdict callback HMAC secret", Secret: true},
{YAMLPath: "verdict_callback.hmac_secret_env", Type: "string", Label: "Verdict callback HMAC secret env"},
{YAMLPath: "verdict_callback.allow_unsigned", Type: "bool", Label: "Allow unsigned verdict callback"},
{YAMLPath: "verdict_callback.require_response_signature", Type: "bool", Label: "Require signed verdict response"},
{YAMLPath: "verdict_callback.timeout_sec", Type: "int", Label: "Verdict callback timeout (sec)", Min: int64p(1), Max: int64p(30)},
},
},
{
ID: "reputation",
Title: "Reputation",
YAMLPath: "reputation",
Icon: "shield-check",
Group: SectionGroupIntegrations,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "abuseipdb_key", Type: "string", Label: "AbuseIPDB API key", Secret: true},
{YAMLPath: "whitelist", Type: "[]string", Label: "Whitelisted IPs", Help: "Never flagged as malicious"},
{YAMLPath: "bot_verify_enabled", Type: "bool", Label: "Verify search-engine bots via rDNS", Nullable: true},
{YAMLPath: "rspamd.enabled", Type: "bool", Label: "Rspamd threat-intel"},
{YAMLPath: "rspamd.url", Type: "string", Label: "Rspamd controller URL"},
{YAMLPath: "rspamd.token", Type: "string", Label: "Rspamd controller password", Secret: true},
{YAMLPath: "rspamd.token_env", Type: "string", Label: "Rspamd password env var"},
{YAMLPath: "upstream.enabled", Type: "bool", Label: "Upstream threat-intel cache"},
{YAMLPath: "upstream.url", Type: "string", Label: "Upstream URL"},
{YAMLPath: "upstream.token", Type: "string", Label: "Upstream bearer token", Secret: true},
{YAMLPath: "upstream.token_env", Type: "string", Label: "Upstream token env var"},
{YAMLPath: "upstream.cache_ttl_min", Type: "int", Label: "Upstream cache TTL (min)", Min: int64p(1), Max: int64p(1440)},
{YAMLPath: "upstream.timeout_sec", Type: "int", Label: "Upstream request timeout (sec)", Min: int64p(1), Max: int64p(60)},
{YAMLPath: "report.enabled", Type: "bool", Label: "Report confirmed abuse", Help: "Send signed, minimized reports for confirmed-abuse IPs. Targets are configured in YAML.", FieldGroup: FieldGroupAbuseReporting},
{YAMLPath: "report.classes", Type: "[]string", Label: "Reported classes", Help: "bruteforce, php_relay, credential_stuffing, bad_asn_egress", FieldGroup: FieldGroupAbuseReporting},
{YAMLPath: "report.spool_max", Type: "int", Label: "Report spool size", Min: int64p(0), FieldGroup: FieldGroupAbuseReporting},
{YAMLPath: "central.enabled", Type: "bool", Label: "Consume central abuse database", FieldGroup: FieldGroupCentralDB},
{YAMLPath: "central.set_url", Type: "string", Label: "Scored-set URL", FieldGroup: FieldGroupCentralDB},
{YAMLPath: "central.pubkey_env", Type: "string", Label: "Central public-key env var", FieldGroup: FieldGroupCentralDB},
{YAMLPath: "central.action", Type: "string", Label: "Action on listed IPs", Help: "off | challenge | block_if_local_corroborated", FieldGroup: FieldGroupCentralDB},
{YAMLPath: "central.block_threshold", Type: "int", Label: "Block score threshold", Min: int64p(0), Max: int64p(100), FieldGroup: FieldGroupCentralDB},
{YAMLPath: "central.refresh_interval", Type: "string", Label: "Refresh interval", Help: "e.g. 6h", FieldGroup: FieldGroupCentralDB},
},
},
{
ID: "email_protection",
Title: "Email Protection",
YAMLPath: "email_protection",
Icon: "mail-shield",
Group: SectionGroupDetection,
Restart: false,
Fields: []SettingsField{
{YAMLPath: "password_check_interval_min", Type: "int", Label: "Password check interval (min)", Min: int64p(1)},
{YAMLPath: "high_volume_senders", Type: "[]string", Label: "High-volume senders"},
{YAMLPath: "rate_warn_threshold", Type: "int", Label: "Rate warn threshold", Min: int64p(1)},
{YAMLPath: "rate_crit_threshold", Type: "int", Label: "Rate critical threshold", Min: int64p(1)},
{YAMLPath: "rate_window_min", Type: "int", Label: "Rate window (min)", Min: int64p(1)},
{YAMLPath: "known_forwarders", Type: "[]string", Label: "Known forwarders"},
},
},
{
ID: "challenge",
Title: "Challenge",
YAMLPath: "challenge",
Icon: "user-question",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Challenge pages enabled"},
{YAMLPath: "listen_port", Type: "int", Label: "Listen port", Min: int64p(1), Max: int64p(65535)},
{YAMLPath: "difficulty", Type: "int", Label: "PoW difficulty (0-5)", Min: int64p(0), Max: int64p(5)},
{YAMLPath: "trusted_proxies", Type: "[]string", Label: "Trusted proxy IPs"},
// challenge.secret is auto-generated at daemon startup; intentionally
// omitted so the UI cannot overwrite or leak the HMAC key.
},
},
{
ID: "php_shield",
Title: "PHP Shield",
YAMLPath: "php_shield",
Icon: "brand-php",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "PHP Shield enabled"},
},
},
{
ID: "signatures",
Title: "Signatures",
YAMLPath: "signatures",
Icon: "scan",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "auto_update", Type: "bool", Label: "Auto-update rules"},
{YAMLPath: "update_interval", Type: "string", Label: "Update interval", Placeholder: "24h"},
{YAMLPath: "yara_forge.enabled", Type: "bool", Label: "YARA-Forge enabled"},
{YAMLPath: "yara_forge.tier", Type: "enum", Label: "YARA-Forge tier", Options: []string{"core", "extended", "full"}},
{YAMLPath: "yara_forge.update_interval", Type: "string", Label: "YARA-Forge interval", Placeholder: "168h"},
{YAMLPath: "yara_forge.download_url", Type: "string", Label: "YARA-Forge signed ZIP URL"},
{YAMLPath: "disabled_rules", Type: "[]string", Label: "Disabled rule names"},
{YAMLPath: "yara_worker_enabled", Type: "bool", Label: "Run YARA-X in supervised worker"},
},
},
{
ID: "email_av",
Title: "Email AV",
YAMLPath: "email_av",
Icon: "virus",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Email AV enabled"},
{YAMLPath: "clamd_socket", Type: "string", Label: "clamd socket"},
{YAMLPath: "scan_timeout", Type: "string", Label: "Scan timeout", Placeholder: "30s"},
{YAMLPath: "max_attachment_size", Type: "int", Label: "Max attachment bytes", Min: int64p(1024)},
{YAMLPath: "max_archive_depth", Type: "int", Label: "Max archive depth", Min: int64p(0)},
{YAMLPath: "max_archive_files", Type: "int", Label: "Max archive files", Min: int64p(1)},
{YAMLPath: "max_extraction_size", Type: "int", Label: "Max extraction bytes", Min: int64p(1024)},
{YAMLPath: "quarantine_infected", Type: "bool", Label: "Quarantine infected"},
{YAMLPath: "scan_concurrency", Type: "int", Label: "Scan concurrency", Min: int64p(1), Max: int64p(64)},
},
},
{
ID: "modsec",
Title: "ModSecurity",
YAMLPath: "modsec",
Icon: "shield-lock",
Group: SectionGroupDetection,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "rules_file", Type: "string", Label: "Rules file path"},
{YAMLPath: "overrides_file", Type: "string", Label: "Overrides file path"},
{YAMLPath: "reload_command", Type: "string", Label: "Reload command"},
},
},
{
ID: "performance",
Title: "Performance",
YAMLPath: "performance",
Icon: "activity",
Group: SectionGroupOps,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Performance checks", Nullable: true, Help: "Leave unset to inherit default (on)"},
{YAMLPath: "load_high_multiplier", Type: "float", Label: "Load high multiplier"},
{YAMLPath: "load_critical_multiplier", Type: "float", Label: "Load critical multiplier"},
{YAMLPath: "php_process_warn_per_user", Type: "int", Label: "PHP process warn per user", Min: int64p(1)},
{YAMLPath: "php_process_critical_total_multiplier", Type: "int", Label: "PHP process crit multiplier", Min: int64p(1)},
{YAMLPath: "error_log_warn_size_mb", Type: "int", Label: "Error log warn size (MB)", Min: int64p(1)},
{YAMLPath: "mysql_join_buffer_max_mb", Type: "int", Label: "MySQL join buffer max (MB)", Min: int64p(1)},
{YAMLPath: "mysql_wait_timeout_max", Type: "int", Label: "MySQL wait timeout max (s)", Min: int64p(1)},
{YAMLPath: "mysql_max_connections_per_user", Type: "int", Label: "MySQL max connections per user", Min: int64p(1)},
{YAMLPath: "redis_bgsave_min_interval", Type: "int", Label: "Redis bgsave min interval (s)", Min: int64p(1)},
{YAMLPath: "redis_large_dataset_gb", Type: "int", Label: "Redis large dataset (GB)", Min: int64p(1)},
{YAMLPath: "wp_memory_limit_max_mb", Type: "int", Label: "WP memory limit max (MB)", Min: int64p(32)},
{YAMLPath: "wp_transient_warn_mb", Type: "int", Label: "WP transient warn (MB)", Min: int64p(1)},
{YAMLPath: "wp_transient_critical_mb", Type: "int", Label: "WP transient critical (MB)", Min: int64p(1)},
{YAMLPath: "wp_cron_fix.interval_minutes", Type: "int", Label: "WP-Cron fix: system cron interval (min)", Min: int64p(1), Max: int64p(60), Help: "How often the installed system cron runs wp-cron.php. Default 5."},
{YAMLPath: "wp_cron_fix.php_bin", Type: "string", Label: "WP-Cron fix: PHP binary", Placeholder: "/usr/local/bin/php", Help: "Interpreter for the cron line. Leave empty to auto-detect."},
},
},
{
ID: "cloudflare",
Title: "Cloudflare",
YAMLPath: "cloudflare",
Icon: "cloud",
Group: SectionGroupIntegrations,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Cloudflare integration"},
{YAMLPath: "refresh_hours", Type: "int", Label: "Refresh interval (hours)", Min: int64p(1), Max: int64p(168)},
},
},
{
ID: "geoip",
Title: "GeoIP",
YAMLPath: "geoip",
Icon: "world",
Group: SectionGroupIntegrations,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "account_id", Type: "string", Label: "MaxMind account ID"},
{YAMLPath: "license_key", Type: "string", Label: "MaxMind license key", Secret: true},
{YAMLPath: "editions", Type: "[]enum", Label: "Database editions", OptionsSource: "geoip_editions", Help: "Which MaxMind databases to download. GeoLite2-* are free; GeoIP2-* require a paid subscription."},
{YAMLPath: "auto_update", Type: "bool", Label: "Auto-update databases", Nullable: true},
{YAMLPath: "update_interval", Type: "string", Label: "Update interval", Placeholder: "24h"},
},
},
{
ID: "infra_ips",
Title: "Infra IPs",
YAMLPath: "infra_ips",
Icon: "server",
Group: SectionGroupOps,
Restart: true,
Fields: []SettingsField{
// NOTE(review): the empty field YAMLPath presumably means the
// section's own YAML value is the list itself -- confirm against
// the settings read/write handler.
{YAMLPath: "", Type: "[]string", Label: "Trusted infra IPs and CIDRs"},
},
},
{
ID: "disabled_checks",
Title: "Disabled checks",
YAMLPath: "disabled_checks",
Icon: "x-circle",
Group: SectionGroupOps,
Restart: false,
Fields: []SettingsField{
// NOTE(review): empty field YAMLPath, same shape as infra_ips --
// presumably the section value is the list itself; confirm.
{YAMLPath: "", Type: "[]enum", Label: "Skip scheduled check runners", OptionsSource: "disabled_check_names", Help: "Selecting a finding name skips the scheduled check runner or runners that emit it, including sibling findings from the same runner. Realtime findings are not affected. For email-only suppression, use Alerts > Disabled check names instead."},
},
},
{
ID: "sentry",
Title: "Sentry",
YAMLPath: "sentry",
Icon: "bug",
Group: SectionGroupIntegrations,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Sentry enabled"},
{YAMLPath: "dsn", Type: "string", Label: "Sentry DSN", Secret: true},
{YAMLPath: "environment", Type: "string", Label: "Environment", Placeholder: "production"},
{YAMLPath: "sample_rate", Type: "float", Label: "Sample rate (0 to 1.0)"},
{YAMLPath: "debug", Type: "bool", Label: "Debug logging"},
},
},
{
ID: "firewall",
Title: "Firewall",
YAMLPath: "firewall",
Icon: "shield-lock",
Group: SectionGroupFirewall,
Restart: true,
Fields: []SettingsField{
{YAMLPath: "enabled", Type: "bool", Label: "Firewall enabled", Help: "Activates the nftables-based firewall on next daemon restart. Verify port lists below before enabling to avoid lockout.", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "ipv6", Type: "bool", Label: "IPv6 dual-stack", FieldGroup: FieldGroupIPv6},
{YAMLPath: "tcp_in", Type: "[]int", Label: "Inbound TCP ports", Help: "SSH (22) is intentionally not in the default; add it explicitly if sshd listens on 22. WebUI port must be present or you will lose remote access on restart.", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "tcp_out", Type: "[]int", Label: "Outbound TCP ports", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "udp_in", Type: "[]int", Label: "Inbound UDP ports", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "udp_out", Type: "[]int", Label: "Outbound UDP ports", Help: "Includes 6277/24441 by default for SpamAssassin DCC/Pyzor; do not remove unless rspamd-only.", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "tcp6_in", Type: "[]int", Label: "Inbound TCP6 ports", Help: "Empty inherits tcp_in. Only set when IPv6 should differ from IPv4.", FieldGroup: FieldGroupIPv6},
{YAMLPath: "tcp6_out", Type: "[]int", Label: "Outbound TCP6 ports", Help: "Empty inherits tcp_out.", FieldGroup: FieldGroupIPv6},
{YAMLPath: "udp6_in", Type: "[]int", Label: "Inbound UDP6 ports", Help: "Empty inherits udp_in.", FieldGroup: FieldGroupIPv6},
{YAMLPath: "udp6_out", Type: "[]int", Label: "Outbound UDP6 ports", Help: "Empty inherits udp_out.", FieldGroup: FieldGroupIPv6},
{YAMLPath: "restricted_tcp", Type: "[]int", Label: "Restricted TCP (infra-only)", Help: "Reachable only from infra_ips. Manage infra_ips in its own section.", FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "passive_ftp_start", Type: "int", Label: "Passive FTP range start", Min: int64p(1024), Max: int64p(65535), FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "passive_ftp_end", Type: "int", Label: "Passive FTP range end", Min: int64p(1024), Max: int64p(65535), FieldGroup: FieldGroupAccessPorts},
{YAMLPath: "drop_nolog", Type: "[]int", Label: "Silent-drop ports", Help: "Dropped without logging to keep scanner noise out of the log.", FieldGroup: FieldGroupLogging},
{YAMLPath: "conn_rate_limit", Type: "int", Label: "Conn rate limit (per IP/min)", Min: int64p(0), Max: int64p(100000), Help: "0 disables. 200 tolerates shared CGNAT egress.", FieldGroup: FieldGroupRateLimits},
{YAMLPath: "conn_limit", Type: "int", Label: "Concurrent connections per IP", Min: int64p(0), Max: int64p(100000), Help: "0 disables.", FieldGroup: FieldGroupRateLimits},
{YAMLPath: "syn_flood_protection", Type: "bool", Label: "SYN flood protection", FieldGroup: FieldGroupFloodProtection},
{YAMLPath: "udp_flood", Type: "bool", Label: "UDP flood protection", FieldGroup: FieldGroupFloodProtection},
{YAMLPath: "udp_flood_rate", Type: "int", Label: "UDP packets/sec", Min: int64p(1), Max: int64p(100000), FieldGroup: FieldGroupFloodProtection},
{YAMLPath: "udp_flood_burst", Type: "int", Label: "UDP burst allowance", Min: int64p(1), Max: int64p(1000000), FieldGroup: FieldGroupFloodProtection},
{YAMLPath: "deny_ip_limit", Type: "int", Label: "Permanent block cap", Min: int64p(0), Max: int64p(1000000), Help: "0 = unlimited.", FieldGroup: FieldGroupLimits},
{YAMLPath: "deny_temp_ip_limit", Type: "int", Label: "Temporary block cap", Min: int64p(0), Max: int64p(1000000), FieldGroup: FieldGroupLimits},
{YAMLPath: "country_block", Type: "[]string", Label: "Country block (ISO-3166)", Help: "Two-letter codes, one per line.", FieldGroup: FieldGroupGeoDynDNS},
{YAMLPath: "country_db_path", Type: "string", Label: "Country DB path override", Placeholder: "(uses geoip section if empty)", FieldGroup: FieldGroupGeoDynDNS},
{YAMLPath: "dyndns_hosts", Type: "[]string", Label: "DynDNS hosts", Help: "Resolved every 5 minutes and merged into the trusted set.", FieldGroup: FieldGroupGeoDynDNS},
{YAMLPath: "smtp_block", Type: "bool", Label: "Block outbound SMTP", Help: "When enabled, only smtp_allow_users may originate outbound mail. Verify allow list first.", FieldGroup: FieldGroupSMTPControls},
{YAMLPath: "smtp_allow_users", Type: "[]string", Label: "SMTP allow users", Help: "root is always allowed.", FieldGroup: FieldGroupSMTPControls},
{YAMLPath: "smtp_ports", Type: "[]int", Label: "SMTP ports", FieldGroup: FieldGroupSMTPControls},
{YAMLPath: "log_dropped", Type: "bool", Label: "Log dropped packets", FieldGroup: FieldGroupLogging},
{YAMLPath: "log_rate", Type: "int", Label: "Log entries per minute", Min: int64p(0), Max: int64p(10000), FieldGroup: FieldGroupLogging},
},
},
}
// SettingsSectionIDs lists the ID of every section in settingsSections,
// preserving declaration order. The returned slice is freshly allocated
// on each call.
func SettingsSectionIDs() []string {
	ids := make([]string, len(settingsSections))
	for i := range settingsSections {
		ids[i] = settingsSections[i].ID
	}
	return ids
}
// LookupSettingsSection finds the section whose ID equals id and returns
// a copy of it with the hot-reload policy applied (see withReloadPolicy).
// The boolean is false, and the section is the zero value, when no
// section has that ID.
func LookupSettingsSection(id string) (SettingsSection, bool) {
	for i := range settingsSections {
		if settingsSections[i].ID != id {
			continue
		}
		return withReloadPolicy(settingsSections[i]), true
	}
	return SettingsSection{}, false
}
// AllSettingsSections returns every section, in declaration order, with
// the hot-reload policy applied to each. Intended for read-only
// consumers such as the dashboard navigation: the section structs are
// copies, but each Fields slice still shares its backing array with the
// package-level schema, so callers must not modify the fields.
func AllSettingsSections() []SettingsSection {
	sections := make([]SettingsSection, 0, len(settingsSections))
	for i := range settingsSections {
		sections = append(sections, withReloadPolicy(settingsSections[i]))
	}
	return sections
}
// withReloadPolicy returns section with Restart and ReloadTag taken from
// the first config.HotReloadManifest entry whose Field equals the
// section's YAMLPath. When no entry matches, the section is returned
// unchanged, keeping the Restart value declared in the schema.
func withReloadPolicy(section SettingsSection) SettingsSection {
	for _, entry := range config.HotReloadManifest() {
		if entry.Field != section.YAMLPath {
			continue
		}
		section.Restart, section.ReloadTag = entry.RestartRequired, entry.Tag
		return section
	}
	return section
}
package webui
import (
"crypto/rand"
"encoding/hex"
"net/http"
"time"
"github.com/pidginhost/csm/internal/state"
)
// apiSuppressions serves the suppression-rule collection.
//
//   - GET returns all stored rules as a JSON array (an empty array when
//     the store has none).
//   - POST creates a rule from a JSON body with "check" (required),
//     "path_pattern" and "reason", and responds with the generated ID.
//   - DELETE removes the rule whose ID matches the "id" in the JSON body.
//
// Any other method gets 405 Method Not Allowed.
func (s *Server) apiSuppressions(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		rules := s.store.LoadSuppressions()
		if rules == nil {
			// Encode "no rules" as [] rather than null.
			rules = []state.SuppressionRule{}
		}
		writeJSON(w, rules)

	case http.MethodPost:
		var req struct {
			Check       string `json:"check"`
			PathPattern string `json:"path_pattern"`
			Reason      string `json:"reason"`
		}
		if err := decodeJSONBodyLimited(w, r, 32*1024, &req); err != nil || req.Check == "" {
			writeJSONError(w, "check field is required", http.StatusBadRequest)
			return
		}

		// Generate a random 8-byte ID. A failed read must not be ignored:
		// it would leave the buffer zeroed, every rule would get the same
		// ID, and a later DELETE would remove all of them at once.
		b := make([]byte, 8)
		if _, err := rand.Read(b); err != nil {
			writeJSONError(w, "failed to generate suppression id: "+err.Error(), http.StatusInternalServerError)
			return
		}
		id := hex.EncodeToString(b)

		rules := s.store.LoadSuppressions()
		rules = append(rules, state.SuppressionRule{
			ID:          id,
			Check:       req.Check,
			PathPattern: req.PathPattern,
			Reason:      req.Reason,
			CreatedAt:   time.Now(),
		})
		if err := s.store.SaveSuppressions(rules); err != nil {
			writeJSONError(w, "failed to save suppression: "+err.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "suppress", req.Check, "pattern: "+req.PathPattern)
		writeJSON(w, map[string]string{"status": "created", "id": id})

	case http.MethodDelete:
		var req struct {
			ID string `json:"id"`
		}
		if err := decodeJSONBodyLimited(w, r, 16*1024, &req); err != nil || req.ID == "" {
			writeJSONError(w, "id is required", http.StatusBadRequest)
			return
		}

		// Keep every rule except the one being removed. An unknown ID is
		// not an error: the list is saved unchanged and "deleted" is
		// still reported.
		rules := s.store.LoadSuppressions()
		var filtered []state.SuppressionRule
		for _, rule := range rules {
			if rule.ID != req.ID {
				filtered = append(filtered, rule)
			}
		}
		if err := s.store.SaveSuppressions(filtered); err != nil {
			writeJSONError(w, "failed to save suppressions: "+err.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "unsuppress", req.ID, "removed suppression rule")
		writeJSON(w, map[string]string{"status": "deleted"})

	default:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
package webui
import (
"fmt"
"net"
"net/http"
"time"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/threat"
)
// handleThreat renders the threat page template, passing it the
// configured hostname. The request itself is not inspected.
func (s *Server) handleThreat(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "threat.html", data)
}
// GET /api/v1/threat/stats
//
// Writes the attack database statistics as JSON. When the database has
// not been initialized, it writes an {"error": ...} object instead, via
// the same writeJSON path.
func (s *Server) apiThreatStats(w http.ResponseWriter, _ *http.Request) {
	if adb := attackdb.Global(); adb != nil {
		writeJSON(w, adb.Stats())
		return
	}
	writeJSON(w, map[string]string{"error": "attack database not initialized"})
}
// GET /api/v1/threat/top-attackers?limit=25
//
// Returns the attack database's top attackers as a JSON array, each
// record enriched with unified threat intelligence and, when a GeoIP
// database is loaded, country and AS organization. The limit defaults
// to 25 and is capped at 200. An empty array is returned when the
// attack database is not initialized.
func (s *Server) apiThreatTopAttackers(w http.ResponseWriter, r *http.Request) {
	adb := attackdb.Global()
	if adb == nil {
		writeJSON(w, []struct{}{})
		return
	}
	limit := queryInt(r, "limit", 25)
	if limit > 200 {
		limit = 200
	}
	recs := adb.TopAttackers(limit)

	// Enrich with unified intelligence
	ips := make([]string, len(recs))
	for i, rec := range recs {
		ips[i] = rec.IP
	}
	intels := threat.LookupBatch(ips, s.cfg.StatePath)

	type enriched struct {
		*attackdb.IPRecord
		UnifiedScore int    `json:"unified_score"`
		Verdict      string `json:"verdict"`
		AbuseScore   int    `json:"abuse_score"`
		InThreatDB   bool   `json:"in_threat_db"`
		Blocked      bool   `json:"currently_blocked"`
		Country      string `json:"country,omitempty"`
		ASOrg        string `json:"as_org,omitempty"`
	}

	// Load the GeoIP database once so every row in this response is
	// resolved against the same snapshot, instead of reloading per row.
	gdb := s.geoIPDB.Load()

	results := make([]enriched, len(recs))
	for i, rec := range recs {
		row := enriched{IPRecord: rec}
		// Guard the index: if the batch lookup ever returns fewer entries
		// than requested, the remaining rows keep zero-valued intelligence
		// fields rather than panicking the handler.
		if i < len(intels) {
			intel := intels[i]
			row.UnifiedScore = intel.UnifiedScore
			row.Verdict = intel.Verdict
			row.AbuseScore = intel.AbuseScore
			row.InThreatDB = intel.InThreatDB
			row.Blocked = intel.CurrentlyBlocked
		}
		if gdb != nil {
			geo := gdb.Lookup(rec.IP)
			row.Country = geo.Country
			row.ASOrg = geo.ASOrg
		}
		results[i] = row
	}
	writeJSON(w, results)
}
// GET /api/v1/threat/ip?ip=1.2.3.4
//
// apiThreatIP writes the threat intelligence record for one IP address. A
// missing or unparsable ip parameter yields 400. When a GeoIP database is
// loaded, its location and network fields are copied onto the record before
// it is written.
func (s *Server) apiThreatIP(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	if valid := ip != "" && net.ParseIP(ip) != nil; !valid {
		writeJSONError(w, "invalid or missing ip parameter", http.StatusBadRequest)
		return
	}

	intel := threat.Lookup(ip, s.cfg.StatePath)

	// Enrich with GeoIP data if available
	gdb := s.geoIPDB.Load()
	if gdb == nil {
		writeJSON(w, intel)
		return
	}
	geo := gdb.Lookup(ip)
	intel.Country = geo.Country
	intel.CountryName = geo.CountryName
	intel.City = geo.City
	intel.ASN = geo.ASN
	intel.ASOrg = geo.ASOrg
	intel.Network = geo.Network
	writeJSON(w, intel)
}
// GET /api/v1/threat/events?ip=1.2.3.4&limit=50
//
// apiThreatEvents writes the attack events recorded for one IP address. The
// ip parameter must parse as an IP (400 otherwise); limit defaults to 50 and
// is capped at 500. The response is always a JSON array, empty when there is
// no attack database or no events.
func (s *Server) apiThreatEvents(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	if valid := ip != "" && net.ParseIP(ip) != nil; !valid {
		writeJSONError(w, "invalid or missing ip parameter", http.StatusBadRequest)
		return
	}

	const maxEvents = 500
	limit := queryInt(r, "limit", 50)
	if limit > maxEvents {
		limit = maxEvents
	}

	adb := attackdb.Global()
	if adb == nil {
		writeJSON(w, []struct{}{})
		return
	}

	events := adb.QueryEvents(ip, limit)
	if events == nil {
		// Encode "no events" as [] rather than null.
		events = []attackdb.Event{}
	}
	writeJSON(w, events)
}
// GET /api/v1/threat/db-stats
//
// apiThreatDBStats writes a JSON object with a "threat_db" entry (the threat
// DB's Stats) and an "attack_db" entry (total IP count and formatted top
// line). Each entry is present only when the corresponding database is
// available.
func (s *Server) apiThreatDBStats(w http.ResponseWriter, r *http.Request) {
	result := map[string]interface{}{}

	tdb := checks.GetThreatDB()
	if tdb != nil {
		result["threat_db"] = tdb.Stats()
	}

	adb := attackdb.Global()
	if adb != nil {
		attackSummary := map[string]interface{}{
			"total_ips": adb.TotalIPs(),
			"top_line":  adb.FormatTopLine(),
		}
		result["attack_db"] = attackSummary
	}

	writeJSON(w, result)
}
// POST /api/v1/threat/whitelist-ip - mark an IP as a known customer
// Unblocks, removes from threat DB + attack DB, adds to whitelist.
//
// The JSON body carries a single "ip" field. Firewall steps are best effort:
// a failed unblock or allow is simply left out of the reported "actions".
// The request is audit-logged as "whitelist_ip".
func (s *Server) apiThreatWhitelistIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	ip := req.IP
	var actions []string
	record := func(msg string) { actions = append(actions, msg) }

	// 1. Unblock from firewall
	if s.blocker != nil {
		if s.blocker.UnblockIP(ip) == nil {
			record("unblocked from firewall")
		}
		// Also add to firewall allow list so it doesn't get re-blocked
		type allower interface {
			AllowIP(string, string) error
		}
		if fw, ok := s.blocker.(allower); ok && fw.AllowIP(ip, "CSM whitelist: customer IP") == nil {
			record("added to firewall allow list")
		}
	}

	// 2. Remove from threat DB permanent blocklist + add to whitelist
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(ip)
		tdb.AddWhitelist(ip)
		record("removed from threat DB, added to whitelist")
	}

	// 3. Remove from attack DB
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(ip)
		record("removed from attack DB")
	}

	// 4. Flush cphulk
	flushCphulk(ip)

	s.auditLog(r, "whitelist_ip", ip, "permanent whitelist")
	writeJSON(w, map[string]interface{}{
		"status":  "whitelisted",
		"ip":      ip,
		"actions": actions,
	})
}
// GET /api/v1/threat/whitelist - list all whitelisted IPs
//
// Writes the threat DB's whitelisted IPs, or an empty JSON array when no
// threat DB is available.
func (s *Server) apiThreatWhitelist(w http.ResponseWriter, r *http.Request) {
	if tdb := checks.GetThreatDB(); tdb != nil {
		writeJSON(w, tdb.WhitelistedIPs())
		return
	}
	writeJSON(w, []string{})
}
// POST /api/v1/threat/unwhitelist-ip - remove an IP from the whitelist
//
// The JSON body carries a single "ip" field. The IP is removed from the
// threat DB whitelist and, when the firewall engine supports it, from the
// firewall allow list (best effort). The request is audit-logged as
// "unwhitelist_ip" so the removal shows up in the trail like the other
// mutating threat endpoints.
func (s *Server) apiThreatUnwhitelistIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemoveWhitelist(req.IP)
	}

	// Also remove from firewall allow list
	if s.blocker != nil {
		if remover, ok := s.blocker.(interface {
			RemoveAllowIP(string) error
		}); ok {
			// Best effort: the threat DB whitelist is the source of truth,
			// so a firewall failure does not fail the request.
			_ = remover.RemoveAllowIP(req.IP)
		}
	}

	s.auditLog(r, "unwhitelist_ip", req.IP, "removed from whitelist")
	writeJSON(w, map[string]string{"status": "removed", "ip": req.IP})
}
// POST /api/v1/threat/block-ip - manually block an IP for 24 hours.
//
// The JSON body carries a single "ip" field. The firewall block uses a 24h
// expiry; the IP is additionally added to the threat DB via AddPermanent and
// marked blocked in the attack DB when those databases are available.
// Responds 503 when no firewall engine is configured and 500 when the
// firewall block fails. The request is audit-logged as "block_ip".
func (s *Server) apiThreatBlockIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	// 1. Block in firewall with 24h expiry
	if s.blocker == nil {
		writeJSONError(w, "firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	// Operator-initiated: bypass auto_response.dry_run gate.
	const reason = "Manually blocked via CSM Web UI"
	if err := blockIPForOperator(s.blocker, req.IP, reason, 24*time.Hour); err != nil {
		writeJSONError(w, fmt.Sprintf("block failed: %v", err), http.StatusInternalServerError)
		return
	}
	actions := []string{"blocked in firewall for 24h"}

	// 2. Add to threat DB permanent blocklist
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.AddPermanent(req.IP, reason)
		actions = append(actions, "added to threat DB")
	}

	// 3. Record in attack DB
	if adb := attackdb.Global(); adb != nil {
		adb.MarkBlocked(req.IP)
		actions = append(actions, "recorded in attack DB")
	}

	s.auditLog(r, "block_ip", req.IP, "manual block 24h")
	writeJSON(w, map[string]interface{}{
		"status":  "blocked",
		"ip":      req.IP,
		"actions": actions,
	})
}
// POST /api/v1/threat/clear-ip - unblock + clear from all DBs without whitelisting.
// For dynamic IP customers: one-time cleanup, IP can be re-blocked later.
//
// The JSON body carries a single "ip" field. A failed firewall unblock is
// left out of the reported "actions" rather than failing the request. The
// request is audit-logged as "clear_ip".
func (s *Server) apiThreatClearIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	ip := req.IP
	var actions []string

	// 1. Unblock from firewall (but don't add to allow list)
	if s.blocker != nil && s.blocker.UnblockIP(ip) == nil {
		actions = append(actions, "unblocked from firewall")
	}

	// 2. Remove from threat DB permanent blocklist (but don't whitelist)
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(ip)
		actions = append(actions, "removed from threat DB")
	}

	// 3. Remove from attack DB
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(ip)
		actions = append(actions, "removed from attack DB")
	}

	// 4. Flush cphulk
	flushCphulk(ip)
	actions = append(actions, "flushed cPanel login history")

	s.auditLog(r, "clear_ip", ip, "unblock & clear")
	writeJSON(w, map[string]interface{}{
		"status":  "cleared",
		"ip":      ip,
		"actions": actions,
	})
}
// POST /api/v1/threat/temp-whitelist-ip - whitelist for a specified duration.
//
// The JSON body carries "ip" and an optional "hours". Non-positive hours
// default to 24 and values above 168 (7 days) are clamped to 168. Firewall
// steps are best effort and only appear in "actions" when they succeed. The
// request is audit-logged as "temp_whitelist_ip".
func (s *Server) apiThreatTempWhitelistIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP    string `json:"ip"`
		Hours int    `json:"hours"` // default 24
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}

	hours := req.Hours
	switch {
	case hours <= 0:
		hours = 24
	case hours > 168: // max 7 days
		hours = 168
	}
	ttl := time.Duration(hours) * time.Hour
	ip := req.IP

	var actions []string

	// 1. Unblock from firewall
	if s.blocker != nil {
		if s.blocker.UnblockIP(ip) == nil {
			actions = append(actions, "unblocked from firewall")
		}
		// Temp allow in firewall too
		type tempAllower interface {
			TempAllowIP(string, string, time.Duration) error
		}
		if fw, ok := s.blocker.(tempAllower); ok && fw.TempAllowIP(ip, "CSM temp whitelist", ttl) == nil {
			actions = append(actions, fmt.Sprintf("temp allowed in firewall for %dh", hours))
		}
	}

	// 2. Remove from threat DB + temp whitelist
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(ip)
		tdb.TempWhitelist(ip, ttl)
		actions = append(actions, fmt.Sprintf("temp whitelisted for %dh", hours))
	}

	// 3. Remove from attack DB
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(ip)
		actions = append(actions, "removed from attack DB")
	}

	// 4. Flush cphulk
	flushCphulk(ip)

	s.auditLog(r, "temp_whitelist_ip", ip, fmt.Sprintf("%dh temp whitelist", hours))
	writeJSON(w, map[string]interface{}{
		"status":  "temp_whitelisted",
		"ip":      ip,
		"hours":   hours,
		"actions": actions,
	})
}
// POST /api/v1/threat/bulk-action - block or whitelist multiple IPs at once.
//
// The JSON body carries "ips" (1-100 entries) and "action" ("block" or
// "whitelist"). Entries that fail parseAndValidateIP, and block attempts the
// firewall rejects, are skipped silently. The response reports how many IPs
// were processed and, when at least one was, the ID of the recorded undo
// entry.
func (s *Server) apiThreatBulkAction(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IPs    []string `json:"ips"`
		Action string   `json:"action"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if n := len(req.IPs); n == 0 || n > 100 {
		writeJSONError(w, "IPs must be 1-100 items", http.StatusBadRequest)
		return
	}
	isBlock := req.Action == "block"
	if !isBlock && req.Action != "whitelist" {
		writeJSONError(w, "Action must be 'block' or 'whitelist'", http.StatusBadRequest)
		return
	}
	if isBlock && s.blocker == nil {
		writeJSONError(w, "firewall engine not available", http.StatusServiceUnavailable)
		return
	}

	succeeded := make([]string, 0, len(req.IPs))
	for _, ipStr := range req.IPs {
		if _, err := parseAndValidateIP(ipStr); err != nil {
			continue
		}

		if isBlock {
			// Mirror apiThreatBlockIP flow
			if s.blocker != nil {
				// Operator-initiated bulk block: bypass auto_response.dry_run gate.
				if err := blockIPForOperator(s.blocker, ipStr, "Bulk blocked via CSM Web UI", 24*time.Hour); err != nil {
					continue
				}
			}
			if tdb := checks.GetThreatDB(); tdb != nil {
				tdb.AddPermanent(ipStr, "Bulk blocked via CSM Web UI")
			}
			if adb := attackdb.Global(); adb != nil {
				adb.MarkBlocked(ipStr)
			}
		} else {
			// Mirror apiThreatWhitelistIP flow
			if s.blocker != nil {
				_ = s.blocker.UnblockIP(ipStr)
				if allower, ok := s.blocker.(interface {
					AllowIP(string, string) error
				}); ok {
					_ = allower.AllowIP(ipStr, "CSM bulk whitelist")
				}
			}
			if tdb := checks.GetThreatDB(); tdb != nil {
				tdb.RemovePermanent(ipStr)
				tdb.AddWhitelist(ipStr)
			}
			if adb := attackdb.Global(); adb != nil {
				adb.RemoveIP(ipStr)
			}
			flushCphulk(ipStr)
		}
		succeeded = append(succeeded, ipStr)
	}

	count := len(succeeded)
	s.auditLog(r, "threat_bulk_"+req.Action, fmt.Sprintf("%d IPs", count), "")

	// Only record an undo entry when something was actually changed.
	undoToken := ""
	switch {
	case count == 0:
		// nothing to undo
	case isBlock:
		undoToken = s.recordUndoEntry(r, "threat_bulk_block", undoInverseThreatBlock,
			fmt.Sprintf("Blocked %d IPs", count), undoPayloadIPs{IPs: succeeded})
	default:
		undoToken = s.recordUndoEntry(r, "threat_bulk_whitelist", undoInverseThreatWhitelist,
			fmt.Sprintf("Whitelisted %d IPs", count), undoPayloadIPs{IPs: succeeded})
	}

	writeJSON(w, map[string]interface{}{
		"ok":         true,
		"count":      count,
		"undo_token": undoToken,
	})
}
// writeJSON is defined in api.go
package webui
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"time"
)
// EnsureTLSCert generates a self-signed ECDSA P-256 certificate if the
// cert and key files don't both exist. Includes localhost and any extraNames
// in the certificate SANs; the first non-empty extra name (typically the
// server hostname) becomes the CommonName, falling back to "localhost".
//
// The certificate is valid for 365 days from the time of generation. The key
// file is created with mode 0600. An error is returned when key or serial
// generation, certificate creation, or writing/closing either file fails.
func EnsureTLSCert(certPath, keyPath string, extraNames ...string) error {
	// If both exist, nothing to do
	if fileExists(certPath) && fileExists(keyPath) {
		return nil
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generating key: %w", err)
	}

	// A failed read from the random source must not fall through with a nil
	// serial number.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return fmt.Errorf("generating serial: %w", err)
	}

	// Use first extra name (hostname) as CN, fall back to localhost
	cn := "localhost"
	if len(extraNames) > 0 && extraNames[0] != "" {
		cn = extraNames[0]
	}

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"CSM Security Monitor"},
			CommonName:   cn,
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              buildDNSNames(extraNames),
		IPAddresses:           buildIPList(extraNames),
	}

	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		return fmt.Errorf("creating certificate: %w", err)
	}

	// Marshal the key before touching the filesystem so a marshaling failure
	// cannot leave a certificate on disk without its key.
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return fmt.Errorf("marshaling key: %w", err)
	}

	// Write cert
	// #nosec G304 -- certPath is derived from config-owned statePath at the
	// callsite; this function is the cert *generator*.
	certFile, err := os.Create(certPath)
	if err != nil {
		return fmt.Errorf("creating cert file: %w", err)
	}
	if encErr := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); encErr != nil {
		_ = certFile.Close() // the encode error is the one worth reporting
		return fmt.Errorf("encoding cert: %w", encErr)
	}
	// Close explicitly: a deferred, ignored Close can hide a failed flush and
	// report success for a truncated file.
	if closeErr := certFile.Close(); closeErr != nil {
		return fmt.Errorf("closing cert file: %w", closeErr)
	}

	// Write key
	// #nosec G304 -- same as certPath: generator for a config-owned path.
	keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("creating key file: %w", err)
	}
	if encErr := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); encErr != nil {
		_ = keyFile.Close() // the encode error is the one worth reporting
		return fmt.Errorf("encoding key: %w", encErr)
	}
	if closeErr := keyFile.Close(); closeErr != nil {
		return fmt.Errorf("closing key file: %w", closeErr)
	}

	return nil
}
// buildDNSNames returns the DNS subject alternative names for the generated
// certificate: "localhost" followed by every entry of extra that is a
// hostname. Entries that parse as IP addresses are left out (buildIPList
// handles those), as are empty strings and names already in the list, so the
// certificate never carries an empty or repeated dNSName.
func buildDNSNames(extra []string) []string {
	names := []string{"localhost"}
	seen := map[string]bool{"localhost": true}
	for _, n := range extra {
		if n == "" || seen[n] || net.ParseIP(n) != nil {
			continue
		}
		seen[n] = true
		names = append(names, n)
	}
	return names
}
// buildIPList returns the IP subject alternative names for the generated
// certificate: the IPv4 and IPv6 loopback addresses followed by every entry
// of extra that parses as an IP address. Non-IP entries are skipped.
func buildIPList(extra []string) []net.IP {
	ips := make([]net.IP, 0, len(extra)+2)
	ips = append(ips, net.ParseIP("127.0.0.1"), net.ParseIP("::1"))
	for i := range extra {
		parsed := net.ParseIP(extra[i])
		if parsed == nil {
			continue
		}
		ips = append(ips, parsed)
	}
	return ips
}
// fileExists reports whether os.Stat succeeds for path. Any Stat error,
// including permission errors, is reported as "does not exist".
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
package webui
import (
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/store"
)
// Recognised inverse-action keys. Each handler that records an undo entry
// sets one of these on the entry; apiUndoRun dispatches based on the value.
// The constant name identifies the action being undone, while the string
// value names the operation that reverses it (for example, the entry
// recorded for a bulk block carries "threat_bulk_unblock").
const (
undoInverseThreatBlock = "threat_bulk_unblock"
undoInverseThreatUnblock = "threat_bulk_block"
undoInverseThreatWhitelist = "threat_bulk_unwhitelist"
undoInverseThreatUnwhitelist = "threat_bulk_whitelist"
undoInverseFirewallUnblock = "firewall_bulk_reblock"
)
// undoPayloadIPs is the payload schema for every undo entry we currently
// generate: a list of IPs plus an optional reason and timeout. Future undo
// kinds can add their own payload structs alongside this one.
//
// Reason and Timeout are only consulted by the re-block inverses; when they
// are empty the undo runner substitutes its own defaults.
type undoPayloadIPs struct {
IPs []string `json:"ips"`
Reason string `json:"reason,omitempty"`
Timeout string `json:"timeout,omitempty"` // ParseDuration-compatible
}
// recordUndoEntry persists an undo entry for the operator who issued r and
// returns the new entry's ID so the calling handler can surface it as the
// undo token in the same response.
//
// It never fails the caller: a nil request, an empty operator key, a missing
// global store, or a payload that cannot be marshalled all yield "". An
// error from the store's AppendUndoEntry is logged and also yields "".
func (s *Server) recordUndoEntry(r *http.Request, action, inverse, summary string, payload undoPayloadIPs) string {
	if r == nil {
		return ""
	}

	operator := s.operatorKey(r)
	if operator == "" {
		return ""
	}

	db := store.Global()
	if db == nil {
		return ""
	}

	encoded, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		return ""
	}

	pending := store.UndoEntry{
		Action:  action,
		Inverse: inverse,
		Payload: encoded,
		Summary: summary,
	}
	saved, appendErr := db.AppendUndoEntry(operator, pending)
	if appendErr != nil {
		log.Printf("webui: record undo entry: %v", appendErr)
		return ""
	}
	return saved.ID
}
// undoPendingView is the JSON shape returned to the client. The payload is
// stripped because the client only needs identity + summary to render the
// banner; the server keeps the payload for the actual undo run.
//
// ExpiresAt is derived by the handler as RecordedAt plus store.UndoTTL.
type undoPendingView struct {
ID string `json:"id"`
Action string `json:"action"`
Inverse string `json:"inverse"`
Summary string `json:"summary"`
RecordedAt time.Time `json:"recorded_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// apiUndoPending returns the latest undo entry the store reports for the
// operator, or an empty object when no entry is queued or no store is
// available. Only GET is accepted (405 otherwise); a request without an
// operator key gets 401 and a store failure gets 500.
func (s *Server) apiUndoPending(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	operator := s.operatorKey(r)
	if operator == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}

	db := store.Global()
	if db == nil {
		writeJSON(w, map[string]interface{}{})
		return
	}

	entry, found, err := db.LatestUndoEntry(operator)
	switch {
	case err != nil:
		writeJSONError(w, "Store error", http.StatusInternalServerError)
	case !found:
		writeJSON(w, map[string]interface{}{})
	default:
		view := undoPendingView{
			ID:         entry.ID,
			Action:     entry.Action,
			Inverse:    entry.Inverse,
			Summary:    entry.Summary,
			RecordedAt: entry.RecordedAt,
			ExpiresAt:  entry.RecordedAt.Add(store.UndoTTL),
		}
		writeJSON(w, view)
	}
}
// undoRunRequest is the JSON body accepted by apiUndoRun. An empty ID selects
// the operator's most recent undo entry.
type undoRunRequest struct {
ID string `json:"id"`
}
// undoRunResponse is the JSON body returned after an undo has run. Action and
// Inverse are copied from the consumed entry; Count is the number of items
// the inverse operation processed.
type undoRunResponse struct {
Status string `json:"status"`
Action string `json:"action"`
Inverse string `json:"inverse"`
Count int `json:"count"`
}
// apiUndoRun consumes the named undo entry (or the most recent one when id
// is empty) and dispatches its inverse. Each successful undo also writes a
// "undo_<original>" audit entry so the trail records the reversal.
//
// Only POST is accepted (405 otherwise). A request without an operator key
// gets 401, an undecodable body 400, a missing store 503, a store failure
// 500, and an entry that cannot be consumed 410. The entry is consumed
// before its inverse runs, so a failed inverse (500) is not retryable.
func (s *Server) apiUndoRun(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	opkey := s.operatorKey(r)
	if opkey == "" {
		writeJSONError(w, "Unauthenticated", http.StatusUnauthorized)
		return
	}
	var req undoRunRequest
	if err := decodeJSONBodyLimited(w, r, 4*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	sdb := store.Global()
	if sdb == nil {
		writeJSONError(w, "Store unavailable", http.StatusServiceUnavailable)
		return
	}

	var (
		entry store.UndoEntry
		ok    bool
		err   error
	)
	if req.ID == "" {
		entry, ok, err = sdb.LatestUndoEntry(opkey)
		if err == nil && ok {
			// Trust the consume result, not the earlier lookup: if another
			// request consumed this entry in between, ok comes back false and
			// the inverse must not run a second time.
			entry, ok, err = sdb.ConsumeUndoEntry(opkey, entry.ID)
		}
	} else {
		entry, ok, err = sdb.ConsumeUndoEntry(opkey, req.ID)
	}
	if err != nil {
		writeJSONError(w, "Store error", http.StatusInternalServerError)
		return
	}
	if !ok {
		writeJSONError(w, "Undo window expired", http.StatusGone)
		return
	}

	resp, runErr := s.runUndoEntry(r, entry)
	if runErr != nil {
		writeJSONError(w, runErr.Error(), http.StatusInternalServerError)
		return
	}
	s.auditLog(r, "undo_"+entry.Action, fmt.Sprintf("%d items", resp.Count), entry.Summary)
	writeJSON(w, resp)
}
// runUndoEntry decodes the entry's payload and executes the operation named
// by entry.Inverse. It returns a response carrying the entry's action and
// inverse plus the number of IPs processed, or an error when the payload
// cannot be decoded, the inverse key is unknown, or a re-block is requested
// without a firewall engine.
func (s *Server) runUndoEntry(r *http.Request, entry store.UndoEntry) (undoRunResponse, error) {
	var payload undoPayloadIPs
	if len(entry.Payload) > 0 {
		if err := json.Unmarshal(entry.Payload, &payload); err != nil {
			return undoRunResponse{}, fmt.Errorf("decode payload: %w", err)
		}
	}

	count := 0
	switch entry.Inverse {
	case undoInverseThreatBlock:
		// Original action blocked IPs; inverse unblocks them.
		count = s.undoBulkBlock(payload.IPs)

	case undoInverseThreatUnblock, undoInverseFirewallUnblock:
		// Original unblocked IPs; inverse re-blocks them with the saved
		// reason and timeout, falling back to defaults when either is unset.
		reason := payload.Reason
		if reason == "" {
			reason = "Undo: re-block via CSM Web UI"
		}
		timeout := parseDuration(payload.Timeout)
		if timeout == 0 {
			timeout = 24 * time.Hour
		}
		reblocked, err := s.undoBulkReblock(payload.IPs, reason, timeout)
		if err != nil {
			return undoRunResponse{}, err
		}
		count = reblocked

	case undoInverseThreatWhitelist:
		count = s.undoBulkWhitelist(payload.IPs)

	case undoInverseThreatUnwhitelist:
		count = s.undoBulkUnwhitelist(payload.IPs)

	default:
		return undoRunResponse{}, fmt.Errorf("unknown inverse action %q", entry.Inverse)
	}

	return undoRunResponse{
		Status:  "ok",
		Action:  entry.Action,
		Inverse: entry.Inverse,
		Count:   count,
	}, nil
}
// undoBulkBlock reverses a bulk block. For every IP that passes
// parseAndValidateIP it unblocks the firewall (best effort), removes the IP
// from the threat DB's permanent list, and flushes cphulk. It returns the
// number of valid IPs processed; invalid entries are skipped.
func (s *Server) undoBulkBlock(ips []string) int {
	processed := 0
	for _, addr := range ips {
		if _, err := parseAndValidateIP(addr); err != nil {
			continue
		}
		processed++

		if s.blocker != nil {
			_ = s.blocker.UnblockIP(addr)
		}
		if tdb := checks.GetThreatDB(); tdb != nil {
			tdb.RemovePermanent(addr)
		}
		flushCphulk(addr)
	}
	return processed
}
// undoBulkReblock re-blocks each valid IP in the firewall with the given
// reason and timeout. It returns the number of IPs blocked successfully;
// invalid IPs and individual block failures are skipped. The only error
// returned is the absence of a firewall engine.
func (s *Server) undoBulkReblock(ips []string, reason string, timeout time.Duration) (int, error) {
	if s.blocker == nil {
		return 0, fmt.Errorf("firewall engine not available")
	}

	blocked := 0
	for _, addr := range ips {
		_, invalid := parseAndValidateIP(addr)
		if invalid != nil {
			continue
		}
		if blockIPForOperator(s.blocker, addr, reason, timeout) == nil {
			blocked++
		}
	}
	return blocked, nil
}
// undoBulkWhitelist reverses a bulk whitelist. For every IP that passes
// parseAndValidateIP it removes the IP from the threat DB whitelist and, when
// the firewall engine supports it, from the firewall allow list that the bulk
// whitelist added it to (best effort). It returns the number of valid IPs
// processed; invalid entries are skipped.
func (s *Server) undoBulkWhitelist(ips []string) int {
	count := 0
	for _, ip := range ips {
		if _, err := parseAndValidateIP(ip); err != nil {
			continue
		}
		if tdb := checks.GetThreatDB(); tdb != nil {
			tdb.RemoveWhitelist(ip)
		}
		// Mirror apiThreatUnwhitelistIP: without this the IP stays on the
		// firewall allow list after the "undo".
		if s.blocker != nil {
			if remover, ok := s.blocker.(interface {
				RemoveAllowIP(string) error
			}); ok {
				_ = remover.RemoveAllowIP(ip) // best effort, like the single-IP handler
			}
		}
		count++
	}
	return count
}
// undoBulkUnwhitelist reverses a bulk un-whitelist by adding every IP that
// passes parseAndValidateIP back to the threat DB whitelist. It returns the
// number of valid IPs processed; invalid entries are skipped.
func (s *Server) undoBulkUnwhitelist(ips []string) int {
	restored := 0
	for _, addr := range ips {
		if _, err := parseAndValidateIP(addr); err != nil {
			continue
		}
		restored++

		tdb := checks.GetThreatDB()
		if tdb == nil {
			continue
		}
		tdb.AddWhitelist(addr)
	}
	return restored
}
package wpcheck
import (
"crypto/md5" // #nosec G501 -- MD5 is the hash wordpress.org publishes for core file checksums; this is integrity verification against a published reference, not a security primitive.
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/sys/unix"
)
// Cache holds WordPress checksum data in memory, backed by JSON files under
// {statePath}/wp-checksums. The accessor methods in this file take mu around
// reads and writes of checksums, roots and fetching.
type Cache struct {
mu sync.RWMutex
statePath string
checksums map[string]map[string]string // core: "<version>:<locale>" -> relPath -> MD5
pluginChecksums map[string]map[string]string // plugins: "<slug>:<version>" -> relPath -> SHA256
// roots remembers the version and locale last read for a WP install root.
roots map[string]rootEntry
// fetching marks cache keys that already have a background fetch running.
fetching map[string]bool
// pluginNotFoundUntil records slug+version pairs that wordpress.org
// returned 404 for, paired with the absolute time at which the
// suppression expires. Plugins hosted outside wp.org (paid forks,
// custom internal plugins) would otherwise re-arm the 4-attempt
// retry cycle on every cache miss. The TTL ensures wp.org adding
// a plugin later still gets picked up.
pluginNotFoundUntil map[string]time.Time
// stopCh, when non-nil and closed, signals pending retry timers to
// drop their scheduled fetch instead of firing. Wired by the daemon
// to the FileMonitor stopCh so checksum-retry chains do not survive
// daemon shutdown.
stopMu sync.RWMutex
stopCh <-chan struct{}
}
// rootEntry is the version and locale remembered for one WordPress install
// root in Cache.roots.
type rootEntry struct {
version string
locale string
}
// NewCache returns a Cache rooted at statePath with all of its maps
// initialised, pre-populated from any checksum files already on disk.
func NewCache(statePath string) *Cache {
	c := new(Cache)
	c.statePath = statePath
	c.checksums = map[string]map[string]string{}
	c.pluginChecksums = map[string]map[string]string{}
	c.roots = map[string]rootEntry{}
	c.fetching = map[string]bool{}
	c.pluginNotFoundUntil = map[string]time.Time{}

	c.loadFromDisk()
	return c
}
// cacheKey builds the in-memory map key for a core version/locale pair, in
// the form "<version>:<locale>".
func cacheKey(version, locale string) string {
	return strings.Join([]string{version, locale}, ":")
}
// diskFilename builds the on-disk file name for a core version/locale pair,
// in the form "<version>_<locale>.json".
func diskFilename(version, locale string) string {
	return fmt.Sprintf("%s_%s.json", version, locale)
}
// loadFromDisk fills c.checksums from the "*.json" files found in
// {statePath}/wp-checksums. Each file name is read as "<version>_<locale>",
// split at the first underscore. Unreadable directories, unreadable files,
// unparsable content and names without an underscore are skipped silently.
// It does not take c.mu; in this file it is only called from NewCache, before
// the cache is returned.
func (c *Cache) loadFromDisk() {
	dir := filepath.Join(c.statePath, "wp-checksums")
	entries, err := os.ReadDir(dir)
	if err != nil {
		return
	}

	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, ".json") {
			continue
		}

		// #nosec G304 -- dir is {statePath}/wp-checksums; name comes from
		// our own os.ReadDir of that same dir.
		data, readErr := os.ReadFile(filepath.Join(dir, name))
		if readErr != nil {
			continue
		}
		parsed, parseErr := ParseChecksumResponse(data)
		if parseErr != nil {
			continue
		}

		stem := strings.TrimSuffix(name, ".json")
		sep := strings.Index(stem, "_")
		if sep < 0 {
			continue
		}
		version, locale := stem[:sep], stem[sep+1:]
		c.checksums[cacheKey(version, locale)] = parsed
	}
}
// PersistChecksums writes checksum data to disk atomically (tmpfile + rename)
// and populates the in-memory cache. The file is written to {statePath}/wp-checksums/.
//
// version and locale become part of the file name. Because both ultimately
// come from a site's version.php, a pair that would produce a file name
// containing a path separator is rejected instead of being allowed to
// resolve outside the cache directory. The in-memory cache is only updated
// after the file has been renamed into place; any error leaves it untouched.
func (c *Cache) PersistChecksums(version, locale string, rawJSON []byte, checksums map[string]string) error {
	dir := filepath.Join(c.statePath, "wp-checksums")
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("creating wp-checksums dir: %w", err)
	}

	filename := diskFilename(version, locale)
	// A name that is not its own base contains a separator and would escape
	// dir once joined and cleaned.
	if filename != filepath.Base(filename) {
		return fmt.Errorf("invalid checksum cache file name %q", filename)
	}
	tmpPath := filepath.Join(dir, filename+".tmp")
	finalPath := filepath.Join(dir, filename)

	if err := os.WriteFile(tmpPath, rawJSON, 0600); err != nil {
		return fmt.Errorf("writing temp file: %w", err)
	}
	if err := os.Rename(tmpPath, finalPath); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup; the rename error is reported
		return fmt.Errorf("renaming to final: %w", err)
	}

	c.mu.Lock()
	c.checksums[cacheKey(version, locale)] = checksums
	c.mu.Unlock()

	return nil
}
// lookupChecksum returns the cached MD5 hex digest for relativePath in the
// given core version/locale. The boolean is false when either the
// version/locale pair or the path is not in the cache.
func (c *Cache) lookupChecksum(version, locale, relativePath string) (string, bool) {
	key := cacheKey(version, locale)

	c.mu.RLock()
	defer c.mu.RUnlock()

	if files, known := c.checksums[key]; known {
		digest, present := files[relativePath]
		return digest, present
	}
	return "", false
}
// hasChecksums reports whether a non-nil checksum map is cached for the
// given core version/locale.
func (c *Cache) hasChecksums(version, locale string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.checksums[cacheKey(version, locale)] != nil
}
// getRoot returns the version and locale remembered for a WP install root.
// ok is false, and both strings empty, when the root has no entry.
func (c *Cache) getRoot(root string) (version, locale string, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if entry, found := c.roots[root]; found {
		return entry.version, entry.locale, true
	}
	return "", "", false
}
// setRoot remembers the version and locale for a WP install root, replacing
// any previous entry.
func (c *Cache) setRoot(root, version, locale string) {
	entry := rootEntry{version: version, locale: locale}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.roots[root] = entry
}
// invalidateRoot forgets the remembered version/locale for a WP install
// root. It is a no-op when the root has no entry.
func (c *Cache) invalidateRoot(root string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.roots, root)
}
// startBackgroundFetch launches a goroutine that fetches checksums for the
// given core version/locale, unless the cache has been stopped or a fetch for
// the same key is already marked as running.
func (c *Cache) startBackgroundFetch(version, locale string) {
	if c.isStopped() {
		return
	}
	key := cacheKey(version, locale)

	// Claim the key under the lock so only one caller starts the fetch.
	c.mu.Lock()
	alreadyRunning := c.fetching[key]
	if !alreadyRunning {
		c.fetching[key] = true
	}
	c.mu.Unlock()

	if alreadyRunning {
		return
	}
	go c.fetchWithRetry(version, locale, 0)
}
// fetchWithRetry fetches the checksums for one core version/locale and stores
// them in the cache. attempt is the zero-based number of this try.
//
// On a fetch failure it schedules another attempt after a backoff of 1m, 5m,
// 15m and then 1h for every later attempt; the fetching mark is cleared by
// the scheduler's second callback rather than here. If the cache has been
// stopped it clears the mark and returns without fetching.
//
// On a successful fetch the data is persisted to disk. If persisting fails
// the error is logged and the checksums are still placed in the in-memory
// cache, so lookups can use them and do not start a fresh fetch on every
// miss while the disk problem lasts.
func (c *Cache) fetchWithRetry(version, locale string, attempt int) {
	backoffs := []time.Duration{1 * time.Minute, 5 * time.Minute, 15 * time.Minute, 1 * time.Hour}
	key := cacheKey(version, locale)

	if c.isStopped() {
		c.clearFetching(key)
		return
	}

	rawJSON, checksums, err := FetchChecksums(version, locale)
	if err != nil {
		// Past the end of the table, keep using the longest delay.
		delay := backoffs[len(backoffs)-1]
		if attempt < len(backoffs) {
			delay = backoffs[attempt]
		}
		fmt.Fprintf(os.Stderr, "wpcheck: fetch failed for WP %s (%s), retry in %v: %v\n",
			version, locale, delay, err)
		c.scheduleRetry(delay, func() {
			c.fetchWithRetry(version, locale, attempt+1)
		}, func() {
			c.clearFetching(key)
		})
		return
	}

	if err := c.PersistChecksums(version, locale, rawJSON, checksums); err != nil {
		fmt.Fprintf(os.Stderr, "wpcheck: persist failed for WP %s (%s): %v\n", version, locale, err)
		// Keep the fetched data in memory even though it could not be written.
		c.mu.Lock()
		c.checksums[key] = checksums
		c.mu.Unlock()
	}

	c.clearFetching(key)
	fmt.Fprintf(os.Stderr, "wpcheck: cached %d checksums for WP %s (%s)\n", len(checksums), version, locale)
}
// maxFileSize is the size in bytes (2 MiB) of the buffer IsVerifiedCoreFile
// reads a file into; content beyond it is not read or hashed.
const maxFileSize = 2 << 20
// IsVerifiedCoreFile reports whether the content readable through fd matches
// the cached MD5 checksum for path as a WordPress core file.
//
// It returns false when path is not inside a detected WP root, when the
// root's version.php cannot be read, when no checksums are cached yet for the
// root's version/locale (a background fetch is started in that case), when
// the path has no checksum entry, or when nothing can be read from fd. At
// most maxFileSize bytes are read, from offset 0.
//
// On a mismatch the remembered version for the root is dropped and
// version.php is re-read; if the version or locale changed, the same digest
// is compared against the checksums of the new version.
func (c *Cache) IsVerifiedCoreFile(fd int, path string) bool {
	root := DetectWPRoot(path)
	if root == "" {
		return false
	}
	relPath := RelativePath(root, path)
	if relPath == "" {
		return false
	}

	// A write to version.php itself means the remembered version may be stale.
	if relPath == filepath.Join("wp-includes", "version.php") {
		c.invalidateRoot(root)
	}

	// expectedDigest returns the cached digest for relPath, kicking off a
	// background fetch when the version/locale has no checksums yet.
	expectedDigest := func(ver, loc string) (string, bool) {
		if !c.hasChecksums(ver, loc) {
			c.startBackgroundFetch(ver, loc)
			return "", false
		}
		return c.lookupChecksum(ver, loc, relPath)
	}

	version, locale, remembered := c.getRoot(root)
	if !remembered {
		var err error
		version, locale, err = ReadVersionFile(root)
		if err != nil {
			return false
		}
		c.setRoot(root, version, locale)
	}

	want, found := expectedDigest(version, locale)
	if !found {
		return false
	}

	buf := make([]byte, maxFileSize)
	// The error value is not consulted separately: only a positive byte
	// count is accepted, which already excludes failed and empty reads.
	n, _ := unix.Pread(fd, buf, 0)
	if n <= 0 {
		return false
	}

	// #nosec G401 -- MD5 is required here: wordpress.org ships MD5 digests
	// as the canonical integrity reference for core files. We compare
	// against their published values, not derive authority from the hash.
	sum := md5.Sum(buf[:n])
	if constantTimeHexDigestEqual(sum[:], want) {
		return true
	}

	// Mismatch: the install may have been upgraded since the version was
	// remembered. Re-read it and retry only if it actually changed.
	c.invalidateRoot(root)
	newVersion, newLocale, err := ReadVersionFile(root)
	if err != nil || (newVersion == version && newLocale == locale) {
		return false
	}
	c.setRoot(root, newVersion, newLocale)

	newWant, found := expectedDigest(newVersion, newLocale)
	if !found {
		return false
	}
	return constantTimeHexDigestEqual(sum[:], newWant)
}
package wpcheck
import (
"errors"
"os"
"path/filepath"
"regexp"
"strings"
)
// wpRootLevelFiles are filenames that only exist at the WP installation root.
// index.php is excluded - it requires a secondary check (version.php must exist).
var wpRootLevelFiles = map[string]bool{
	"wp-activate.php":      true,
	"wp-blog-header.php":   true,
	"wp-comments-post.php": true,
	"wp-config-sample.php": true,
	"wp-cron.php":          true,
	"wp-links-opml.php":    true,
	"wp-load.php":          true,
	"wp-login.php":         true,
	"wp-mail.php":          true,
	"wp-settings.php":      true,
	"wp-signup.php":        true,
	"wp-trackback.php":     true,
	"xmlrpc.php":           true,
}

// DetectWPRoot returns the WordPress installation root directory for a file
// path, or the empty string when the path does not look like a WP core
// location.
//
// The checks run in this order and the first hit wins:
//   - the path contains "/wp-includes/" (tried first) or "/wp-admin/": the
//     root is everything before the first occurrence of that segment
//   - the file's parent directory is named wp-includes or wp-admin: the root
//     is that directory's parent
//   - the file name is a known root-level WP file: the root is the parent
//     directory
//   - the file name is index.php and wp-includes/version.php exists next to
//     it: the root is the parent directory
func DetectWPRoot(path string) string {
	// Segment match anywhere in the path.
	if at := strings.Index(path, "/wp-includes/"); at >= 0 {
		return path[:at]
	}
	if at := strings.Index(path, "/wp-admin/"); at >= 0 {
		return path[:at]
	}

	// File sitting directly inside wp-includes or wp-admin.
	parent := filepath.Dir(path)
	switch filepath.Base(parent) {
	case "wp-includes", "wp-admin":
		return filepath.Dir(parent)
	}

	file := filepath.Base(path)
	switch {
	case wpRootLevelFiles[file]:
		return parent
	case file == "index.php":
		// index.php is too generic on its own; require version.php as proof.
		marker := filepath.Join(parent, "wp-includes", "version.php")
		if _, err := os.Stat(marker); err == nil {
			return parent
		}
	}
	return ""
}
// RelativePath computes the path of a file relative to the WP root.
// Returns empty string if the file is not under root, or if filepath.Rel
// cannot relate the two paths.
//
// Only a result that is exactly ".." or begins with a ".." path element
// counts as outside root; a file inside root whose name merely starts with
// two dots (for example "..hidden.php") is returned as-is.
func RelativePath(root, path string) string {
	rel, err := filepath.Rel(root, path)
	if err != nil {
		return ""
	}
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return ""
	}
	return rel
}
// Patterns for the two single-quoted assignments read from version.php.
var (
	reVersion = regexp.MustCompile(`\$wp_version\s*=\s*'([^']+)'`)
	reLocale  = regexp.MustCompile(`\$wp_local_package\s*=\s*'([^']+)'`)
)

// ParseVersionContent pulls the WordPress version and locale out of the raw
// bytes of a version.php file.
//
// The version comes from the $wp_version assignment and is mandatory: when it
// is missing an error is returned and both strings are empty. The locale
// comes from $wp_local_package and falls back to "en_US" when that assignment
// is absent.
func ParseVersionContent(data []byte) (version, locale string, err error) {
	versionMatch := reVersion.FindSubmatch(data)
	if versionMatch == nil {
		return "", "", errors.New("wp_version not found in version.php")
	}

	localeMatch := reLocale.FindSubmatch(data)
	if localeMatch == nil {
		return string(versionMatch[1]), "en_US", nil
	}
	return string(versionMatch[1]), string(localeMatch[1]), nil
}
// ReadVersionFile loads {root}/wp-includes/version.php and hands its content
// to ParseVersionContent. A read failure is returned as-is with empty
// version and locale; otherwise the parser's result is returned unchanged.
func ReadVersionFile(root string) (version, locale string, err error) {
	versionFile := filepath.Join(root, "wp-includes", "version.php")
	// #nosec G304 -- path derived from a WordPress install root discovered
	// by the scanner under configured /home/*/public_html paths.
	content, readErr := os.ReadFile(versionFile)
	if readErr != nil {
		return "", "", readErr
	}
	return ParseVersionContent(content)
}
package wpcheck
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
)
// httpClient is the package-wide HTTP client used for wordpress.org requests
// (FetchChecksums and the plugin ZIP download both call httpClient.Get).
// The 10-second Timeout bounds the whole exchange, including reading the body.
// NOTE(review): that same limit applies to plugin ZIP downloads of up to
// maxPluginZipBytes -- confirm it is generous enough for large plugins.
var httpClient = &http.Client{Timeout: 10 * time.Second}
// checksumAPIURL returns the WordPress.org checksum API URL for a version and
// locale.
//
// Both values are query-escaped. They are parsed out of an on-disk
// version.php, which on a compromised site is attacker-controlled, and the
// pattern that extracts them accepts any character except a single quote.
// Without escaping, characters such as '&', '=' or '#' would inject or
// truncate query parameters. Ordinary versions and locales ("6.4.2", "en_US")
// contain only unreserved characters, so their URL is unchanged.
func checksumAPIURL(version, locale string) string {
	return fmt.Sprintf(
		"https://api.wordpress.org/core/checksums/1.0/?version=%s&locale=%s",
		url.QueryEscape(version),
		url.QueryEscape(locale),
	)
}
// checksumResponse mirrors the one field we consume from the WP checksum API
// reply: the "checksums" object.
type checksumResponse struct {
	Checksums map[string]string `json:"checksums"`
}

// ParseChecksumResponse decodes a WP checksum API reply and returns its
// relative_path -> md5_hex map.
//
// It fails when the body is not valid JSON for checksumResponse, and also
// when decoding succeeds but the checksums map is absent or empty.
func ParseChecksumResponse(data []byte) (map[string]string, error) {
	decoded := checksumResponse{}
	err := json.Unmarshal(data, &decoded)
	switch {
	case err != nil:
		return nil, fmt.Errorf("invalid JSON: %w", err)
	case len(decoded.Checksums) == 0:
		return nil, errors.New("empty or missing checksums in response")
	}
	return decoded.Checksums, nil
}
// FetchChecksums asks api.wordpress.org for the official core checksums of
// the given version and locale.
//
// On success it returns the raw response body together with the parsed
// relative_path -> md5 map. It returns an error, with both results nil, when
// the request fails, the status is not 200, the body cannot be read, or the
// body does not parse. At most 2 MiB of the body is read.
func FetchChecksums(version, locale string) (rawJSON []byte, checksums map[string]string, err error) {
	endpoint := checksumAPIURL(version, locale)

	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP request failed: %w", err)
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return nil, nil, fmt.Errorf("HTTP %d from %s", code, endpoint)
	}

	const maxBody = 2 << 20 // 2MB limit
	rawJSON, err = io.ReadAll(io.LimitReader(resp.Body, maxBody))
	if err != nil {
		return nil, nil, fmt.Errorf("reading response: %w", err)
	}

	if checksums, err = ParseChecksumResponse(rawJSON); err != nil {
		return nil, nil, err
	}
	return rawJSON, checksums, nil
}
package wpcheck
import (
"crypto/subtle"
"encoding/hex"
)
// maxHexDigestLength is the longest hex digest the comparison supports:
// 64 hex characters, i.e. a 32-byte digest.
const maxHexDigestLength = 64

// constantTimeHexDigestEqual reports whether the lowercase hex encoding of
// actualDigest equals expectedHex.
//
// The byte comparison goes through subtle.ConstantTimeCompare over buffers
// whose length depends only on actualDigest. An expectedHex of the wrong
// length is truncated or zero-padded to that length before comparing, and the
// length mismatch is folded into the result afterwards, so it still yields
// false. Empty digests and digests longer than maxHexDigestLength/2 bytes
// never match.
func constantTimeHexDigestEqual(actualDigest []byte, expectedHex string) bool {
	hexLen := 2 * len(actualDigest)
	if hexLen == 0 || hexLen > maxHexDigestLength {
		return false
	}

	var actualBuf, expectedBuf [maxHexDigestLength]byte
	hex.Encode(actualBuf[:hexLen], actualDigest)
	// copy stops at the shorter operand: an over-long expectedHex is cut to
	// hexLen, a short one leaves the tail of expectedBuf zeroed.
	copy(expectedBuf[:hexLen], expectedHex)

	sameLength := 0
	if hexLen == len(expectedHex) {
		sameLength = 1
	}
	sameBytes := subtle.ConstantTimeCompare(actualBuf[:hexLen], expectedBuf[:hexLen])
	return sameLength&sameBytes == 1
}
package wpcheck
import (
"archive/zip"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"golang.org/x/sys/unix"
)
// Plugin verification mirrors the core-file verification path: when a file
// under /wp-content/plugins/<slug>/ matches the hash we computed from the
// plugin's official wordpress.org ZIP, signature/YARA rule matches on it are
// false positives and should not fire.

// pluginsSegment is the path fragment that precedes a plugin's slug directory.
const pluginsSegment = "/wp-content/plugins/"

// DetectPluginRoot splits a path that lies under /wp-content/plugins/<slug>/
// into the plugin root directory (up to and including <slug>, no trailing
// slash) and the slug itself.
//
// Both results are empty when the path has no plugins segment, when nothing
// follows the slug (the path names the slug directory rather than a file
// below it), or when the slug would be empty. Only the first plugins segment
// in the path is considered.
func DetectPluginRoot(path string) (root, slug string) {
	segmentAt := strings.Index(path, pluginsSegment)
	if segmentAt < 0 {
		return "", ""
	}
	slugStart := segmentAt + len(pluginsSegment)

	name, _, hasChild := strings.Cut(path[slugStart:], "/")
	if !hasChild || name == "" {
		return "", ""
	}
	return path[:slugStart] + name, name
}
// rePluginVersionHeader matches a "Version: x.y.z" header line, optionally
// preceded by whitespace and a docblock "*", case-insensitively.
var rePluginVersionHeader = regexp.MustCompile(`(?im)^\s*\*?\s*Version:\s*([^\s]+)`)

// ReadPluginVersion extracts the Version: header from the plugin's main
// file (<pluginRoot>/<slug>.php). WordPress requires this header to exist
// on every published plugin.
//
// Only the first 8192 bytes of the file are inspected. An error is returned
// when the file cannot be opened, when nothing could be read from it, or when
// no Version header appears in the inspected prefix.
func ReadPluginVersion(pluginRoot, slug string) (string, error) {
	mainPath := filepath.Join(pluginRoot, slug+".php")
	// #nosec G304 -- pluginRoot is derived from a path the scanner received
	// from fanotify under /wp-content/plugins/; slug is the immediate child
	// segment. The read is bounded by the header-detection limit below.
	f, err := os.Open(mainPath)
	if err != nil {
		return "", err
	}
	defer func() { _ = f.Close() }()

	// io.ReadFull keeps reading until the buffer is full or the file ends, so
	// a short Read cannot hide a header that sits inside the first 8 KiB.
	buf := make([]byte, 8192)
	n, err := io.ReadFull(f, buf)
	if n == 0 {
		// io.EOF with nothing read is an empty file; anything else is a real
		// read failure (e.g. the path is a directory) worth surfacing.
		if err != nil && !errors.Is(err, io.EOF) {
			return "", fmt.Errorf("reading plugin main file %s: %w", mainPath, err)
		}
		return "", errors.New("empty plugin main file")
	}
	// With n > 0 any error (typically io.ErrUnexpectedEOF for a file shorter
	// than the buffer) is ignored on purpose: parse what was read.
	m := rePluginVersionHeader.FindSubmatch(buf[:n])
	if m == nil {
		return "", errors.New("version header not found in plugin main file")
	}
	return string(m[1]), nil
}
// pluginZipURL builds the downloads.wordpress.org ZIP address for a plugin
// slug at a specific version: .../plugin/<slug>.<version>.zip. Both values
// are inserted verbatim.
func pluginZipURL(slug, version string) string {
	const base = "https://downloads.wordpress.org/plugin/"
	return base + slug + "." + version + ".zip"
}
// FetchPluginChecksums downloads the plugin ZIP from wordpress.org,
// extracts each file, and returns a map of relative path -> SHA256 hex.
// The returned paths are relative to the plugin root (the leading
// "<slug>/" prefix from the ZIP entries is stripped).
func FetchPluginChecksums(slug, version string) (map[string]string, error) {
return fetchPluginChecksumsFromURL(pluginZipURL(slug, version), slug)
}
// maxPluginZipBytes is the size ceiling applied by fetchPluginChecksumsFromURL
// both to the downloaded ZIP body and to the decompressed size of each
// individual entry inside it.
const maxPluginZipBytes = 100 << 20 // 100 MB ceiling
// ErrPluginNotInWPOrg is returned when wordpress.org responds with HTTP 404
// for a plugin slug+version. Plugins that are not in the wp.org repository
// (paid forks, custom internal plugins, slugs that simply do not exist) need
// to be distinguished from transient errors so the cache can suppress
// further fetch attempts for a TTL.
//
// 5xx responses, network errors, and malformed responses are NOT this
// error - those keep their normal retry behaviour because the plugin may
// still exist in the catalogue and wp.org may simply be having an outage.
var ErrPluginNotInWPOrg = errors.New("plugin not in wordpress.org repository")
// fetchPluginChecksumsFromURL downloads the plugin ZIP at url and returns a
// map of path-relative-to-plugin-root -> SHA256 hex for every regular entry
// stored under "<slug>/".
//
// Errors:
//   - HTTP 404 wraps ErrPluginNotInWPOrg so callers can tell "not published"
//     from transient failures; any other non-200 status is a plain error.
//   - A body larger than maxPluginZipBytes is rejected explicitly instead of
//     being truncated and then failing as a corrupt archive.
//   - An entry whose decompressed size exceeds maxPluginZipBytes aborts the
//     whole fetch.
//   - An archive that yields no usable entries is an error.
//
// Entries outside "<slug>/" and entries that clean to ".", an absolute path,
// or a path escaping the plugin root are skipped.
func fetchPluginChecksumsFromURL(url, slug string) (map[string]string, error) {
	resp, err := httpClient.Get(url) //nolint:gosec,bodyclose // httpClient has a timeout; body is closed below.
	if err != nil {
		return nil, fmt.Errorf("plugin zip GET failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode == http.StatusNotFound {
		return nil, fmt.Errorf("plugin zip HTTP 404 from %s: %w", url, ErrPluginNotInWPOrg)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("plugin zip HTTP %d from %s", resp.StatusCode, url)
	}

	// Read one byte past the cap so an oversized archive is detected rather
	// than silently cut short (same +1 trick as the per-entry cap below).
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxPluginZipBytes+1))
	if err != nil {
		return nil, fmt.Errorf("reading plugin zip: %w", err)
	}
	if len(body) > maxPluginZipBytes {
		return nil, fmt.Errorf("plugin zip from %s exceeds %d byte download cap", url, maxPluginZipBytes)
	}

	zr, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
	if err != nil {
		return nil, fmt.Errorf("opening plugin zip: %w", err)
	}

	out := make(map[string]string, len(zr.File))
	prefix := slug + "/"
	parentPrefix := ".." + string(filepath.Separator)
	for _, zf := range zr.File {
		if zf.FileInfo().IsDir() {
			continue
		}
		name := zf.Name
		if !strings.HasPrefix(name, prefix) {
			// Malformed ZIP (e.g. nested into a differently-named folder).
			// Skip; callers detect partial results by checking cache emptiness.
			continue
		}
		rel := filepath.Clean(strings.TrimPrefix(name, prefix))
		// Reject path-traversal and absolute paths: a crafted ZIP entry
		// named "<slug>/../../etc/passwd" would otherwise land in the
		// checksum map. Defense-in-depth against a compromised CDN.
		// The ".." test is component-wise so a legitimate file whose name
		// merely begins with two dots is still hashed.
		if rel == "." || rel == ".." || strings.HasPrefix(rel, parentPrefix) || strings.HasPrefix(rel, "/") {
			continue
		}

		rc, err := zf.Open()
		if err != nil {
			return nil, fmt.Errorf("opening zip entry %s: %w", name, err)
		}
		// Cap decompressed size per entry so a zip-bomb whose compressed
		// body fits under maxPluginZipBytes cannot keep us hashing an
		// unbounded stream. +1 lets us detect overflow.
		limited := io.LimitReader(rc, maxPluginZipBytes+1)
		h := sha256.New()
		nCopied, err := io.Copy(h, limited)
		_ = rc.Close()
		if err != nil {
			return nil, fmt.Errorf("hashing zip entry %s: %w", name, err)
		}
		if nCopied > maxPluginZipBytes {
			return nil, fmt.Errorf("zip entry %s exceeds per-entry size cap", name)
		}
		out[rel] = hex.EncodeToString(h.Sum(nil))
	}
	if len(out) == 0 {
		return nil, errors.New("plugin zip yielded no checksums")
	}
	return out, nil
}
// --- Cache plugin support ------------------------------------------------
// pluginKey builds the "<slug>:<version>" string used to key the per-plugin
// cache maps.
func pluginKey(slug, version string) string {
	return strings.Join([]string{slug, version}, ":")
}
// setPluginChecksums stores the checksum map for slug+version under the
// write lock, creating the outer map on first use. Any previous entry for the
// same key is replaced.
func (c *Cache) setPluginChecksums(slug, version string, checksums map[string]string) {
	key := pluginKey(slug, version)

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.pluginChecksums == nil {
		c.pluginChecksums = map[string]map[string]string{}
	}
	c.pluginChecksums[key] = checksums
}
// lookupPluginChecksum returns the cached hash for relPath within
// slug+version. The boolean is false when either the plugin/version has no
// cached map or the map has no entry for relPath.
func (c *Cache) lookupPluginChecksum(slug, version, relPath string) (string, bool) {
	key := pluginKey(slug, version)

	c.mu.RLock()
	defer c.mu.RUnlock()

	if byPath, cached := c.pluginChecksums[key]; cached {
		digest, found := byPath[relPath]
		return digest, found
	}
	return "", false
}
// hasPluginChecksums reports whether a checksum map is cached for
// slug+version, regardless of which paths it contains.
func (c *Cache) hasPluginChecksums(slug, version string) bool {
	key := pluginKey(slug, version)

	c.mu.RLock()
	defer c.mu.RUnlock()

	_, cached := c.pluginChecksums[key]
	return cached
}
// pluginNotFoundTTL bounds how long a wp.org 404 outcome suppresses
// re-fetches for the same slug+version. After expiry the next cache miss
// retries normally, so a plugin that wp.org publishes later will be
// picked up. 72 hours strikes a balance between not flooding wp.org with
// requests for non-existent plugins and propagating corrections in
// reasonable time.
const pluginNotFoundTTL = 72 * time.Hour
// markPluginNotFound remembers that wp.org answered 404 for slug+version by
// storing an expiry of now+ttl; isPluginNotFound honours it until then. The
// ttl is a parameter so tests can shorten or invert it; production callers
// pass pluginNotFoundTTL.
func (c *Cache) markPluginNotFound(slug, version string, ttl time.Duration) {
	key := pluginKey(slug, version)

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.pluginNotFoundUntil == nil {
		c.pluginNotFoundUntil = map[string]time.Time{}
	}
	c.pluginNotFoundUntil[key] = time.Now().Add(ttl)
}
// isPluginNotFound reports whether slug+version carries a wp.org 404 marker
// that has not yet expired. Markers are keyed by slug and version together,
// so a different version of the same slug is unaffected and is still fetched.
// Expired markers are treated as absent.
func (c *Cache) isPluginNotFound(slug, version string) bool {
	c.mu.RLock()
	expiry, marked := c.pluginNotFoundUntil[pluginKey(slug, version)]
	c.mu.RUnlock()

	return marked && time.Now().Before(expiry)
}
// startBackgroundPluginFetch launches one background fetch goroutine for
// slug+version unless one of these holds:
//   - the cache has been stopped;
//   - wp.org already answered 404 for this slug+version and the marker has
//     not expired (otherwise every cache miss for a non-wp.org plugin would
//     re-arm the full retry cycle);
//   - a fetch for the same key is already in flight.
func (c *Cache) startBackgroundPluginFetch(slug, version string) {
	if c.isStopped() || c.isPluginNotFound(slug, version) {
		return
	}

	key := pluginKey(slug, version)

	// Claim the key under the lock; re-setting an already-true flag is a
	// no-op, so only the first caller proceeds to spawn the goroutine.
	c.mu.Lock()
	if c.fetching == nil {
		c.fetching = make(map[string]bool)
	}
	alreadyRunning := c.fetching[key]
	c.fetching[key] = true
	c.mu.Unlock()

	if alreadyRunning {
		return
	}
	go c.fetchPluginWithRetry(slug, version, 0)
}
// fetchPluginWithRetry performs one fetch attempt for slug+version and, on a
// transient failure, schedules the next attempt via scheduleRetry.
//
// The fetching flag for the key stays set across retries so further
// cache-miss events do not spawn extra goroutines; it is cleared on every
// terminal outcome:
//   - cache stopped: clear and return without fetching;
//   - success: store the checksums, clear;
//   - ErrPluginNotInWPOrg (HTTP 404): record a not-found marker for
//     pluginNotFoundTTL, clear, no retries;
//   - attempt has reached the end of the backoff table: clear and give up so
//     a later event can start fresh;
//   - a scheduled retry cancelled by shutdown: clear.
//
// Any other error waits backoffs[attempt] and retries with attempt+1.
func (c *Cache) fetchPluginWithRetry(slug, version string, attempt int) {
	key := pluginKey(slug, version)
	if c.isStopped() {
		c.clearFetching(key)
		return
	}

	backoffs := [...]time.Duration{time.Minute, 5 * time.Minute, 15 * time.Minute, time.Hour}

	checksums, err := FetchPluginChecksums(slug, version)
	switch {
	case err == nil:
		c.setPluginChecksums(slug, version, checksums)
		c.clearFetching(key)
		fmt.Fprintf(os.Stderr, "wpcheck: cached %d checksums for plugin %s %s\n", len(checksums), slug, version)
		return

	case errors.Is(err, ErrPluginNotInWPOrg):
		c.markPluginNotFound(slug, version, pluginNotFoundTTL)
		c.clearFetching(key)
		fmt.Fprintf(os.Stderr, "wpcheck: plugin %s %s not in wp.org repository, suppressing retries for %s\n",
			slug, version, pluginNotFoundTTL)
		return

	case attempt >= len(backoffs):
		c.clearFetching(key)
		fmt.Fprintf(os.Stderr, "wpcheck: plugin fetch abandoned for %s %s after %d attempts: %v\n",
			slug, version, attempt+1, err)
		return
	}

	wait := backoffs[attempt]
	fmt.Fprintf(os.Stderr, "wpcheck: plugin fetch failed for %s %s, retry in %v: %v\n",
		slug, version, wait, err)

	retry := func() { c.fetchPluginWithRetry(slug, version, attempt+1) }
	abandon := func() { c.clearFetching(key) }
	c.scheduleRetry(wait, retry, abandon)
}
// IsVerifiedPluginFile reports whether the content readable through fd hashes
// to the cached wordpress.org checksum for path's plugin and version.
//
// path locates the plugin (root, slug, relative path) and its version is read
// from the plugin's main file. False is returned whenever any step cannot be
// completed: path is not under a plugin, the version cannot be read, no
// checksum is cached for the file, or nothing can be read from fd. When no
// checksums at all are cached for the plugin/version, a background fetch is
// started so a later call can succeed.
//
// A single pread of up to maxFileSize bytes from offset 0 supplies the data
// that is hashed.
func (c *Cache) IsVerifiedPluginFile(fd int, path string) bool {
	pluginRoot, slug := DetectPluginRoot(path)
	if pluginRoot == "" {
		return false
	}

	relPath, relErr := filepath.Rel(pluginRoot, path)
	if relErr != nil || strings.HasPrefix(relPath, "..") {
		return false
	}

	version, verErr := ReadPluginVersion(pluginRoot, slug)
	if verErr != nil || version == "" {
		return false
	}

	want, cached := c.lookupPluginChecksum(slug, version, relPath)
	if !cached {
		// Only fetch when the whole plugin/version is unknown; a known
		// plugin that simply lacks this file needs no new download.
		if !c.hasPluginChecksums(slug, version) {
			c.startBackgroundPluginFetch(slug, version)
		}
		return false
	}

	buf := make([]byte, maxFileSize)
	// The byte count alone decides the outcome: a non-positive count means
	// nothing usable was read, whether or not an error came with it.
	n, _ := unix.Pread(fd, buf, 0)
	if n <= 0 {
		return false
	}

	sum := sha256.Sum256(buf[:n])
	return constantTimeHexDigestEqual(sum[:], want)
}
package wpcheck
import (
"time"
)
// SetStopCh installs the daemon-level cancellation channel on the cache.
// Once that channel is closed, pending checksum-retry timers drop their
// scheduled fetch instead of firing. Intended to be called once after
// NewCache, before any fetches start.
func (c *Cache) SetStopCh(stop <-chan struct{}) {
	c.stopMu.Lock()
	defer c.stopMu.Unlock()
	c.stopCh = stop
}
// currentStopCh returns the stop channel currently installed on the cache,
// read under stopMu. The result is nil when SetStopCh has not been called.
func (c *Cache) currentStopCh() <-chan struct{} {
	c.stopMu.RLock()
	stop := c.stopCh
	c.stopMu.RUnlock()
	return stop
}
// isStopped reports whether the cache's stop channel has been closed. A cache
// with no stop channel installed is never considered stopped.
func (c *Cache) isStopped() bool {
	stop := c.currentStopCh()
	return stopClosed(stop)
}
// stopClosed is a non-blocking check of whether stop is ready to receive
// from, which for a signalling channel means it has been closed. A nil
// channel reports false.
func stopClosed(stop <-chan struct{}) bool {
	// Receiving from a nil channel is never ready, so the default branch
	// handles nil without a separate check.
	select {
	case <-stop:
		return true
	default:
	}
	return false
}
// clearFetching removes key from the in-flight fetch set so a later cache
// miss may start a new fetch for it. Removing an absent key is a no-op.
func (c *Cache) clearFetching(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.fetching, key)
}
// scheduleRetry arranges for fn to run after delay on a new goroutine and
// returns immediately. If the cache's stop channel (sampled once, at call
// time) is closed before fn would start, onCancel runs instead and fn is
// skipped. Exactly one of fn or onCancel is invoked per call.
//
// Cancellation matters because the longest retry backoff (1 hour) would
// otherwise outlive daemon shutdown and fetch against torn-down state.
func (c *Cache) scheduleRetry(delay time.Duration, fn func(), onCancel func()) {
	stop := c.currentStopCh()

	go func() {
		// Already shutting down: do not even arm the timer.
		if stopClosed(stop) {
			runCancel(onCancel)
			return
		}

		timer := time.NewTimer(delay)
		defer timer.Stop()

		select {
		case <-stop:
			runCancel(onCancel)
			return
		case <-timer.C:
		}

		// The timer and the stop signal may become ready together; re-check
		// so a shutdown that raced with the timer still wins.
		if stopClosed(stop) {
			runCancel(onCancel)
			return
		}
		fn()
	}()
}
// runCancel invokes onCancel when one was supplied; a nil callback is ignored.
func runCancel(onCancel func()) {
	if onCancel == nil {
		return
	}
	onCancel()
}
package yara
import "sync/atomic"
// Backend is the consumable scanning surface shared by the in-process
// *Scanner and out-of-process process supervisor. Callers should depend
// on this interface (via Active()) so they keep working when the daemon
// switches backends at startup. String-valued rule metadata travels on
// Match.Meta, so adapters that historically reached for the compiled
// *yara_x.Rules object (e.g. emailav) now work uniformly under both
// backends -- see internal/emailav/yarax.go.
type Backend interface {
// ScanFile scans the file at path and returns the matching rules.
// NOTE(review): maxBytes presumably bounds how much of the file is read;
// confirm against the concrete implementations.
ScanFile(path string, maxBytes int) []Match
// ScanBytes scans an in-memory buffer and returns the matching rules.
ScanBytes(data []byte) []Match
// RuleCount returns the number of rules the backend currently has loaded.
RuleCount() int
// Reload asks the backend to reload its rules and reports any failure.
Reload() error
}
// activeBackend holds the backend installed by SetActive; a nil pointer means
// no override is installed and Active falls back to Global().
var activeBackend atomic.Pointer[backendHolder]
// backendHolder wraps a Backend interface value in a struct so it can be
// stored behind atomic.Pointer, which needs a concrete pointee type.
type backendHolder struct{ b Backend }
// Active returns the scanning backend callers should use.
//
// A backend installed through SetActive takes precedence. Without one, the
// in-process singleton from Global() is used. When neither exists (e.g. a
// !yara build with no supervisor wired up) the result is a nil interface,
// so callers must nil-check before use.
func Active() Backend {
	holder := activeBackend.Load()
	if holder != nil && holder.b != nil {
		return holder.b
	}

	scanner := Global()
	if scanner == nil {
		// Return an untyped nil: wrapping the nil *Scanner in the interface
		// would make the caller's nil-check pass incorrectly.
		return nil
	}
	return scanner
}
// SetActive installs b as the scanning backend returned by Active. Passing a
// nil interface removes the override so Active falls back to Global() again.
// The swap is a single atomic store, so it may be called at any time: a
// reader that already loaded the previous backend keeps using it, and later
// reads observe the new one.
func SetActive(b Backend) {
	var next *backendHolder
	if b != nil {
		next = &backendHolder{b: b}
	}
	activeBackend.Store(next)
}
package yara
import (
"fmt"
"os"
"sync"
)
var (
globalScanner *Scanner
globalOnce sync.Once
)
// Init builds the global YARA-X scanner from rulesDir on its first call and
// returns it; later calls return the same value without rebuilding, whatever
// rulesDir they pass.
//
// The result is nil when YARA-X is not compiled in (Available() is false) or
// when NewScanner did not produce a scanner. A NewScanner error is written to
// stderr and is not retried on subsequent calls.
func Init(rulesDir string) *Scanner {
	if !Available() {
		return nil
	}

	globalOnce.Do(func() {
		scanner, err := NewScanner(rulesDir)
		if err == nil {
			globalScanner = scanner
			return
		}
		fmt.Fprintf(os.Stderr, "yara: init error: %v\n", err)
	})
	return globalScanner
}
// Global returns the global YARA-X scanner, or nil if not initialized.
// It reads the package variable directly, without going through globalOnce.
// NOTE(review): a call racing with the first Init is therefore not
// synchronised with Init's write -- confirm Init always completes during
// startup before any Global caller runs.
func Global() *Scanner {
return globalScanner
}
package yara
import (
"errors"
"fmt"
"os"
"path/filepath"
"syscall"
)
// validateRulesDir vets a YARA rules directory before its rules are compiled,
// so that an account other than root or the running process cannot plant a
// rule that disables detection. It mirrors the trust rules CSM applies to
// /etc/csm/conf.d.
//
// Accepted without error:
//   - an empty dir argument;
//   - a directory that does not exist (rules simply not installed yet).
//
// Rejected with an error:
//   - dir cannot be stat'ed for another reason, is not a directory, or cannot
//     be listed;
//   - dir itself fails checkYaraEntryTrust (ownership / write bits);
//   - any direct child ending in .yar or .yara that is a symlink, is not a
//     regular file, or fails checkYaraEntryTrust.
//
// Subdirectories and files with other extensions are not inspected.
func validateRulesDir(dir string) error {
	if dir == "" {
		return nil
	}

	dirInfo, statErr := os.Stat(dir)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return nil
	case statErr != nil:
		return fmt.Errorf("rules dir stat: %w", statErr)
	case !dirInfo.IsDir():
		return fmt.Errorf("rules dir is not a directory: %s", dir)
	}
	if err := checkYaraEntryTrust(dir, dirInfo); err != nil {
		return err
	}

	entries, err := os.ReadDir(dir)
	if err != nil {
		return fmt.Errorf("rules dir read: %w", err)
	}

	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		switch filepath.Ext(entry.Name()) {
		case ".yar", ".yara":
			// rule file: vet it below
		default:
			continue
		}

		rulePath := filepath.Join(dir, entry.Name())
		// Lstat, not Stat: a symlink must be seen as a symlink so it can be
		// refused instead of silently followed.
		ruleInfo, err := os.Lstat(rulePath)
		if err != nil {
			return fmt.Errorf("rule file stat %s: %w", rulePath, err)
		}

		mode := ruleInfo.Mode()
		if mode&os.ModeSymlink != 0 {
			return fmt.Errorf("rule file is a symlink: %s", rulePath)
		}
		if !mode.IsRegular() {
			return fmt.Errorf("rule file is not a regular file: %s", rulePath)
		}
		if err := checkYaraEntryTrust(rulePath, ruleInfo); err != nil {
			return err
		}
	}
	return nil
}
// checkYaraEntryTrust applies the permission and ownership trust rules to one
// path, given its already-obtained FileInfo. It is used for the rules
// directory itself and for each rule file.
//
// It returns an error when the permission bits include group or world write,
// or when the owner is neither root (uid 0) nor the effective uid of this
// process. When the FileInfo carries no *syscall.Stat_t the ownership check
// is skipped and only the permission check applies.
func checkYaraEntryTrust(path string, info os.FileInfo) error {
	const groupOrWorldWrite = 0o022

	perm := info.Mode().Perm()
	if perm&groupOrWorldWrite != 0 {
		return fmt.Errorf("rules path %s has unsafe mode %04o (group or world writable); set 0750 or stricter", path, perm)
	}

	stat, isUnixStat := info.Sys().(*syscall.Stat_t)
	if !isUnixStat {
		return nil
	}

	selfUID := uint32(os.Geteuid()) // #nosec G115 -- Linux uid_t is uint32; os.Geteuid returns a non-negative process uid.
	if owner := stat.Uid; owner != 0 && owner != selfUID {
		return fmt.Errorf("rules path %s owner uid=%d is neither root nor process uid=%d; refusing to load untrusted rules", path, owner, selfUID)
	}
	return nil
}
//go:build !yara
package yara
// Scanner is a no-op stub when YARA-X is not compiled in. It has no fields;
// every method on it returns a zero result.
type Scanner struct{}
// Match represents a YARA rule that matched. Meta carries string-valued
// rule metadata pulled from `rule.Metadata()` (see yarax build); under
// this stub no rule ever matches, so Meta is never populated.
type Match struct {
// RuleName is the identifier of the rule that matched.
RuleName string
// Meta maps metadata identifiers to their string values.
Meta map[string]string
}
// NewScanner returns nil when YARA-X is not available. Both results are nil:
// there is no scanner and no error, so callers must nil-check the scanner
// rather than rely on the error alone. The rules directory is ignored.
func NewScanner(_ string) (*Scanner, error) {
return nil, nil
}
// Reload is a no-op without YARA-X; it always returns nil.
func (s *Scanner) Reload() error { return nil }
// ScanBytes returns nil without YARA-X; the data is not examined.
func (s *Scanner) ScanBytes(_ []byte) []Match { return nil }
// ScanFile returns nil without YARA-X; the path and byte limit are ignored
// and the file is never opened.
func (s *Scanner) ScanFile(_ string, _ int) []Match { return nil }
// RuleCount returns 0 without YARA-X, since no rules are ever loaded.
func (s *Scanner) RuleCount() int { return 0 }
// GlobalRules returns nil without YARA-X (no compiled rules available).
// The result is a nil interface value.
func (s *Scanner) GlobalRules() interface{} { return nil }
// Available returns false (YARA-X is not compiled in). Init consults it and
// returns nil without attempting to build a scanner.
func Available() bool {
return false
}
// TestCompile is a no-op when YARA-X is not compiled in. The rule source is
// not parsed, so a nil result here says nothing about whether the rule is
// valid.
func TestCompile(source string) error {
return nil
}
package yaraipc
import (
"errors"
"fmt"
"io"
"net"
"sync"
"syscall"
"time"
)
// ErrWorkerClosed means the worker hung up mid-request.
var ErrWorkerClosed = errors.New("yaraipc: worker connection closed")
// Dialer returns a fresh net.Conn to the worker. Decoupled from
// net.Dial so tests can substitute net.Pipe. The Client calls it whenever it
// has no cached connection.
type Dialer func() (net.Conn, error)
// Client is a persistent-connection client. One in-flight request at a
// time: scanner callers do not need concurrency on a single socket and
// serialising simplifies failure semantics.
type Client struct {
// mu serialises round trips and guards conn.
mu sync.Mutex
// conn is the cached connection; nil before the first dial and after the
// connection has been dropped.
conn net.Conn
// dialer opens a new connection when conn is nil.
dialer Dialer
// timeout, when positive, is applied as a deadline on conn at the start of
// each round trip.
timeout time.Duration
}
// NewClient returns a Client that connects lazily to the Unix-domain socket
// at socketPath. timeout bounds both the dial and each round trip.
func NewClient(socketPath string, timeout time.Duration) *Client {
	dialUnix := func() (net.Conn, error) {
		return net.DialTimeout("unix", socketPath, timeout)
	}
	return NewClientWithDialer(dialUnix, timeout)
}
// NewClientWithDialer builds a Client around an arbitrary Dialer, which lets
// tests supply their own connection. No connection is opened until the first
// request.
func NewClientWithDialer(d Dialer, timeout time.Duration) *Client {
	client := new(Client)
	client.dialer = d
	client.timeout = timeout
	return client
}
// Close releases the cached connection, if there is one, and returns the
// error from closing it. The Client remains usable afterwards: the next
// request dials a new connection.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return nil
	}
	closeErr := c.conn.Close()
	c.conn = nil
	return closeErr
}
// ensureConnLocked returns the cached connection, dialing and caching a new
// one first when none exists. A dial failure is wrapped and leaves the cache
// empty. The caller must hold c.mu.
func (c *Client) ensureConnLocked() (net.Conn, error) {
	if c.conn == nil {
		fresh, err := c.dialer()
		if err != nil {
			return nil, fmt.Errorf("yaraipc: dial: %w", err)
		}
		c.conn = fresh
	}
	return c.conn, nil
}
// dropLocked closes and forgets the cached connection, returning the close
// error. It is a no-op returning nil when nothing is cached. The caller must
// hold c.mu.
func (c *Client) dropLocked() error {
	stale := c.conn
	if stale == nil {
		return nil
	}
	c.conn = nil
	return stale.Close()
}
// roundTrip sends req to the worker and returns its response frame, holding
// c.mu for the whole exchange so only one request is in flight.
//
// If the first attempt fails with an error that isRetryableClosedConn
// recognises, the request is sent exactly once more on a fresh connection:
// the server may deliberately close a connection once its message budget is
// spent. Any other error, and the outcome of the second attempt, is returned
// as-is.
func (c *Client) roundTrip(req Frame) (Frame, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	resp, err := c.roundTripLocked(req)
	if err != nil && isRetryableClosedConn(err) {
		return c.roundTripLocked(req)
	}
	return resp, err
}
// roundTripLocked performs a single request/response exchange on the cached
// connection, dialing first if needed. The caller must hold c.mu.
//
// When c.timeout is positive it is set as the connection deadline for the
// exchange. On any write or read failure the connection is dropped so the
// next call redials. A read that ends in io.EOF is reported as
// ErrWorkerClosed. A response frame carrying a non-empty Error is turned into
// an error; the connection is kept in that case.
func (c *Client) roundTripLocked(req Frame) (Frame, error) {
	conn, err := c.ensureConnLocked()
	if err != nil {
		return Frame{}, err
	}

	if c.timeout > 0 {
		_ = conn.SetDeadline(time.Now().Add(c.timeout))
	}

	if err := WriteFrame(conn, req); err != nil {
		_ = c.dropLocked()
		return Frame{}, fmt.Errorf("yaraipc: write: %w", err)
	}

	resp, err := ReadFrame(conn)
	switch {
	case err == nil:
		// fall through to the worker-level error check
	case errors.Is(err, io.EOF):
		_ = c.dropLocked()
		return Frame{}, ErrWorkerClosed
	default:
		_ = c.dropLocked()
		return Frame{}, fmt.Errorf("yaraipc: read: %w", err)
	}

	if resp.Error != "" {
		return Frame{}, fmt.Errorf("yaraipc: worker: %s", resp.Error)
	}
	return resp, nil
}
// isRetryableClosedConn reports whether err indicates that the peer closed
// or reset the connection, in which case roundTrip retries once on a fresh
// connection. Wrapped errors are matched via errors.Is.
func isRetryableClosedConn(err error) bool {
	closedConnErrors := [...]error{
		ErrWorkerClosed,
		io.EOF,
		io.ErrUnexpectedEOF,
		net.ErrClosed,
		syscall.EPIPE,
		syscall.ECONNRESET,
		syscall.ECONNABORTED,
	}
	for _, target := range closedConnErrors {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
// ScanFile sends an OpScanFile request carrying args and decodes the worker's
// reply. A reply with an empty payload yields a zero ScanResult. Any
// encoding, transport, worker, or decoding error is returned with a zero
// ScanResult.
func (c *Client) ScanFile(args ScanFileArgs) (ScanResult, error) {
	var result ScanResult

	req, err := EncodePayload(OpScanFile, args)
	if err != nil {
		return ScanResult{}, err
	}
	resp, err := c.roundTrip(req)
	if err != nil {
		return ScanResult{}, err
	}
	if len(resp.Payload) == 0 {
		return result, nil
	}
	if err := DecodePayload(resp, &result); err != nil {
		return ScanResult{}, err
	}
	return result, nil
}
// ScanBytes sends an OpScanBytes request carrying args (the content inline)
// and decodes the worker's reply. A reply with an empty payload yields a zero
// ScanResult. Any encoding, transport, worker, or decoding error is returned
// with a zero ScanResult.
func (c *Client) ScanBytes(args ScanBytesArgs) (ScanResult, error) {
	var result ScanResult

	req, err := EncodePayload(OpScanBytes, args)
	if err != nil {
		return ScanResult{}, err
	}
	resp, err := c.roundTrip(req)
	if err != nil {
		return ScanResult{}, err
	}
	if len(resp.Payload) == 0 {
		return result, nil
	}
	if err := DecodePayload(resp, &result); err != nil {
		return ScanResult{}, err
	}
	return result, nil
}
// Reload sends an OpReload request carrying args and decodes the worker's
// reply. A reply with an empty payload yields a zero ReloadResult. Any
// encoding, transport, worker, or decoding error is returned with a zero
// ReloadResult.
func (c *Client) Reload(args ReloadArgs) (ReloadResult, error) {
	var result ReloadResult

	req, err := EncodePayload(OpReload, args)
	if err != nil {
		return ReloadResult{}, err
	}
	resp, err := c.roundTrip(req)
	if err != nil {
		return ReloadResult{}, err
	}
	if len(resp.Payload) == 0 {
		return result, nil
	}
	if err := DecodePayload(resp, &result); err != nil {
		return ReloadResult{}, err
	}
	return result, nil
}
// Ping sends an OpPing request, which carries no payload, and decodes the
// worker's reply. A reply with an empty payload yields a zero PingResult.
// Any encoding, transport, worker, or decoding error is returned with a zero
// PingResult.
func (c *Client) Ping() (PingResult, error) {
	var result PingResult

	req, err := EncodePayload(OpPing, nil)
	if err != nil {
		return PingResult{}, err
	}
	resp, err := c.roundTrip(req)
	if err != nil {
		return PingResult{}, err
	}
	if len(resp.Payload) == 0 {
		return result, nil
	}
	if err := DecodePayload(resp, &result); err != nil {
		return PingResult{}, err
	}
	return result, nil
}
// Package yaraipc defines the wire protocol spoken between the CSM daemon
// and the supervised `csm yara-worker` child process. The worker exists to
// isolate the YARA-X cgo surface; a crash in the worker must not take the
// daemon down. See ROADMAP.md item 2 for the decision record.
//
// The protocol is length-prefixed JSON frames on a Unix-domain socket.
// Connections are persistent: the daemon opens one, streams scan and
// reload requests, and reconnects if the worker dies.
package yaraipc
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
)
// MaxFrameBytes caps a single request or response payload. Sized to cover
// a file sent inline via OpScanBytes up to the scanner's usual 8 MiB read
// ceiling, with headroom for JSON base64 expansion.
const MaxFrameBytes = 16 << 20
// Op selects the handler on the worker side. Strings (not iota ints) so
// adding a new op is additive and mismatched client/worker versions fail
// with a recognisable "unknown op" error instead of silently dispatching
// the wrong handler.
const (
OpScanFile = "scan_file"
OpScanBytes = "scan_bytes"
OpReload = "reload"
OpPing = "ping"
)
// Frame is the envelope. Request frames carry an Op and typed args in
// Payload; response frames carry the typed result (or an Error) in
// Payload and leave Op empty.
type Frame struct {
// Op names the requested operation (one of the Op* constants); empty on
// responses.
Op string `json:"op,omitempty"`
// Payload is the still-encoded JSON of the typed args or result; it is
// decoded separately with DecodePayload.
Payload json.RawMessage `json:"payload,omitempty"`
// Error, when non-empty on a response, is the worker's failure message; the
// client turns it into a Go error.
Error string `json:"error,omitempty"`
}
// ScanFileArgs asks the worker to read and scan a file by path. MaxBytes
// bounds the read so the daemon cannot make the worker allocate more than
// it agreed to.
type ScanFileArgs struct {
Path string `json:"path"`
MaxBytes int `json:"max_bytes"`
}
// ScanBytesArgs carries file content inline. Used when the caller already
// has the bytes (fanotify buffered reads) and avoids a second file open in
// the worker.
type ScanBytesArgs struct {
// Data is the content to scan. As a []byte it is base64-encoded by
// encoding/json, so the frame is larger than the raw content; the encoded
// frame must still fit within MaxFrameBytes.
Data []byte `json:"data"`
}
// ReloadArgs triggers a rule recompile. RulesDir is optional; if empty
// the worker reuses the directory it was started with.
type ReloadArgs struct {
RulesDir string `json:"rules_dir,omitempty"`
}
// Match mirrors yara.Match but is part of this package's public wire
// contract so the daemon does not need to import the yara package just to
// speak to its worker.
//
// Meta carries string-valued rule metadata (identifier -> value) pulled
// from yara_x `rule.Metadata()` inside the worker, where the compiled
// rules live. Non-string metadata (int / float / bool / bytes) is
// dropped: wiring only string values is a deliberate policy, not a
// fidelity claim. Consumers that need a specific key document their own
// default; e.g. emailav maps a missing "severity" entry to "high".
// Omitted from the wire when empty so the per-scan payload cost is zero
// for the common clean-file case.
type Match struct {
RuleName string `json:"rule"`
Meta map[string]string `json:"meta,omitempty"`
}
// ScanResult is returned for OpScanFile and OpScanBytes.
type ScanResult struct {
Matches []Match `json:"matches,omitempty"`
}
// ReloadResult is returned for OpReload.
type ReloadResult struct {
RuleCount int `json:"rule_count"`
}
// PingResult is returned for OpPing. Used by the supervisor's liveness
// check and as the first frame after a reconnect to confirm the worker is
// past its rule-compile step before real scan traffic begins.
type PingResult struct {
Alive bool `json:"alive"`
RuleCount int `json:"rule_count"`
}
// WriteFrame serialises f as JSON and writes it to w preceded by a 4-byte
// big-endian length. The header and the body go out as two separate Write
// calls, header first.
//
// It fails without writing anything when f cannot be marshalled or when the
// encoded body exceeds MaxFrameBytes; write errors from w are returned
// unchanged. Deadlines on w are the caller's responsibility.
func WriteFrame(w io.Writer, f Frame) error {
	payload, err := json.Marshal(f)
	if err != nil {
		return fmt.Errorf("marshal frame: %w", err)
	}
	if size := len(payload); size > MaxFrameBytes {
		return fmt.Errorf("frame body %d bytes exceeds cap %d", size, MaxFrameBytes)
	}

	var lengthPrefix [4]byte
	// #nosec G115 -- len(payload) is bounded above by MaxFrameBytes (16 MiB), which fits in uint32.
	binary.BigEndian.PutUint32(lengthPrefix[:], uint32(len(payload)))

	for _, chunk := range [][]byte{lengthPrefix[:], payload} {
		if _, err := w.Write(chunk); err != nil {
			return err
		}
	}
	return nil
}
// ReadFrame reads a single length-prefixed JSON frame from r.
//
// The 4-byte big-endian length is validated before any body buffer is
// allocated: a zero length and a length above MaxFrameBytes are both
// rejected, so a hostile or corrupt peer cannot force an unbounded
// allocation. Errors from reading the header or body are returned unchanged
// (a clean end of stream before the header is therefore io.EOF); a body that
// is not valid JSON for Frame yields a wrapped unmarshal error.
func ReadFrame(r io.Reader) (Frame, error) {
	var lengthPrefix [4]byte
	if _, err := io.ReadFull(r, lengthPrefix[:]); err != nil {
		return Frame{}, err
	}

	size := binary.BigEndian.Uint32(lengthPrefix[:])
	switch {
	case size == 0:
		return Frame{}, errors.New("yaraipc: zero-length frame")
	case size > MaxFrameBytes:
		return Frame{}, fmt.Errorf("yaraipc: frame length %d exceeds cap %d", size, MaxFrameBytes)
	}

	payload := make([]byte, size)
	if _, err := io.ReadFull(r, payload); err != nil {
		return Frame{}, err
	}

	frame := Frame{}
	if err := json.Unmarshal(payload, &frame); err != nil {
		return Frame{}, fmt.Errorf("yaraipc: unmarshal frame: %w", err)
	}
	return frame, nil
}
// DecodePayload decodes the JSON held in f.Payload into out. It exists as a
// shared helper because every handler performs this step, and a json-tag typo
// on one side is the kind of bug that hides until production.
//
// An empty payload is an error; otherwise the json.Unmarshal result is
// returned as-is.
func DecodePayload(f Frame, out any) error {
	if raw := f.Payload; len(raw) > 0 {
		return json.Unmarshal(raw, out)
	}
	return errors.New("yaraipc: empty payload")
}
// EncodePayload builds a request Frame for op whose Payload is the JSON
// encoding of v. A nil v produces a frame with Op set and no payload. A
// marshalling failure returns a zero Frame and a wrapped error.
func EncodePayload(op string, v any) (Frame, error) {
	frame := Frame{Op: op}
	if v == nil {
		return frame, nil
	}

	encoded, err := json.Marshal(v)
	if err != nil {
		return Frame{}, fmt.Errorf("yaraipc: marshal payload: %w", err)
	}
	frame.Payload = encoded
	return frame, nil
}
package yaraipc
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/pidginhost/csm/internal/obs"
)
// Handler is the worker-side interface. The production implementation
// wraps internal/yara; tests supply fakes to drive the Serve loop.
//
// dispatch routes each request frame to one of these methods by its Op.
// A returned error is sent back to the peer as the response frame's
// Error string.
type Handler interface {
// ScanFile serves OpScanFile requests.
ScanFile(ScanFileArgs) (ScanResult, error)
// ScanBytes serves OpScanBytes requests.
ScanBytes(ScanBytesArgs) (ScanResult, error)
// Reload serves OpReload requests; the args are the zero value when
// the request carried no payload.
Reload(ReloadArgs) (ReloadResult, error)
// Ping serves OpPing requests.
Ping() (PingResult, error)
}
// ServeOptions tunes Serve's behaviour. ErrorLog is called for per-frame
// decode or transport errors. Nil is fine; these errors are not fatal to
// the worker process.
type ServeOptions struct {
// ErrorLog, when non-nil, receives connection read/write errors,
// recovered handler panics, and the notice emitted when a connection
// exhausts its message budget.
ErrorLog func(error)
// MaxMessagesPerConn caps how many frames a single connection may
// serve before the server force-closes it. Zero applies the
// default; a negative value disables the cap. Prevents a malicious
// or runaway peer from holding one socket and looping millions of
// frames through a single goroutine.
MaxMessagesPerConn int
}
// defaultMaxMessagesPerConn is the per-connection frame budget when
// the operator does not configure one. Production daemon-to-worker
// traffic burns a handful of frames per scan; 1e6 is several orders of
// magnitude above the legitimate ceiling and still catches a wedged
// peer in finite time.
//
// serveConn substitutes this value when ServeOptions.MaxMessagesPerConn
// is zero.
const defaultMaxMessagesPerConn = 1_000_000
// Serve accepts connections on ln and dispatches frames to h until ctx
// is cancelled or ln returns an error. Per-connection goroutines are
// spawned; each handles its connection serially (single in-flight
// request), which matches the daemon-side client and keeps failure
// semantics simple.
//
// On ctx cancellation Serve closes the listener and any active
// connections. Without this, clients holding a cached conn would keep
// talking to a zombie handler instead of seeing EOF and reconnecting
// to whatever replaces the worker. In production that "zombie" is a
// crashed process (the kernel closes its sockets), but in-process
// tests and a graceful SIGTERM shutdown both need the explicit close.
//
// Serve returns nil when Accept fails after ctx has been cancelled, and
// a wrapped accept error otherwise. In both cases it first waits for
// every per-connection goroutine to finish.
func Serve(ctx context.Context, ln net.Listener, h Handler, opts ServeOptions) error {
// mu guards active and closed; servers counts per-connection goroutines
// so Serve can wait for them before returning.
var (
mu sync.Mutex
active = map[net.Conn]struct{}{}
closed bool
servers sync.WaitGroup
)
// closeAll flips closed (at most once) and closes every tracked
// connection. The set is snapshotted under mu and the Close calls run
// after the lock is released.
closeAll := func() {
mu.Lock()
if closed {
mu.Unlock()
return
}
closed = true
conns := make([]net.Conn, 0, len(active))
for c := range active {
conns = append(conns, c)
}
mu.Unlock()
for _, c := range conns {
_ = c.Close()
}
}
// Watcher: once ctx is cancelled, close the listener (which makes the
// blocked Accept below fail) and then every live connection.
// NOTE(review): when Serve returns through the non-ctx accept-error
// path, this goroutine appears to stay parked on ctx.Done() until the
// caller eventually cancels ctx -- confirm that is acceptable.
obs.SafeGo("yaraipc-ctx", func() {
<-ctx.Done()
_ = ln.Close()
closeAll()
})
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil {
servers.Wait()
// Accept error after ctx cancellation is the expected
// shutdown path (ln.Close from the watcher goroutine);
// swallowing it here keeps Serve callers from needing to
// distinguish "clean stop" from "real failure".
return nil //nolint:nilerr
}
closeAll()
servers.Wait()
return fmt.Errorf("yaraipc: accept: %w", err)
}
mu.Lock()
if closed {
// Shutdown already began; drop a connection that raced with it
// instead of tracking and serving it.
mu.Unlock()
_ = conn.Close()
continue
}
active[conn] = struct{}{}
mu.Unlock()
servers.Add(1)
// Per-iteration copy captured by the goroutine closure.
c := conn
obs.SafeGo("yaraipc-conn", func() {
defer servers.Done()
serveConn(c, h, opts)
// serveConn has returned (and closed c); stop tracking it.
mu.Lock()
delete(active, c)
mu.Unlock()
})
}
}
// serveConn runs the request/response loop for one connection and closes
// it on return. Each iteration reads a frame, dispatches it, and writes
// the reply. The loop ends on the first read or write error (a clean
// io.EOF on read is not logged) or once the per-connection message
// budget is spent. A zero MaxMessagesPerConn selects
// defaultMaxMessagesPerConn; a negative value means no cap.
func serveConn(conn net.Conn, h Handler, opts ServeOptions) {
	defer func() { _ = conn.Close() }()

	// report forwards err to the optional error hook.
	report := func(err error) {
		if opts.ErrorLog != nil {
			opts.ErrorLog(err)
		}
	}

	limit := opts.MaxMessagesPerConn
	if limit == 0 {
		limit = defaultMaxMessagesPerConn
	}

	for handled := 0; limit <= 0 || handled < limit; handled++ {
		req, err := ReadFrame(conn)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				report(fmt.Errorf("read: %w", err))
			}
			return
		}
		reply := safeDispatch(req, h, opts)
		if err := WriteFrame(conn, reply); err != nil {
			report(fmt.Errorf("write: %w", err))
			return
		}
	}
	// Only reachable when a positive budget has been used up.
	report(fmt.Errorf("max messages per connection (%d) reached; closing", limit))
}
// safeDispatch wraps dispatch with panic recovery. A panicking handler
// is turned into an error response frame (and reported through
// opts.ErrorLog when set) so the connection loop still has a reply to
// write. Without that, the client would block until its deadline, retry,
// and hit the same panic while the worker process stays up and is never
// restarted by the supervisor.
func safeDispatch(req Frame, h Handler, opts ServeOptions) (resp Frame) {
	defer func() {
		cause := recover()
		if cause == nil {
			return
		}
		if log := opts.ErrorLog; log != nil {
			log(fmt.Errorf("handler panic on op %q: %v", req.Op, cause))
		}
		// Overwrite the named result so the caller gets an error frame.
		resp = Frame{Error: fmt.Sprintf("handler panic: %v", cause)}
	}()
	resp = dispatch(req, h)
	return resp
}
// dispatch routes req to the Handler method selected by req.Op and
// wraps the outcome in a response frame. Payload decode failures and
// unknown ops are answered with an error frame without calling h.
func dispatch(req Frame, h Handler) Frame {
	switch op := req.Op; op {
	case OpPing:
		return responseFrame(h.Ping())
	case OpReload:
		// Reload payload is optional; an empty frame means reuse the
		// worker's startup RulesDir.
		var args ReloadArgs
		if len(req.Payload) != 0 {
			if err := DecodePayload(req, &args); err != nil {
				return Frame{Error: fmt.Sprintf("decode reload: %v", err)}
			}
		}
		return responseFrame(h.Reload(args))
	case OpScanBytes:
		var args ScanBytesArgs
		if err := DecodePayload(req, &args); err != nil {
			return Frame{Error: fmt.Sprintf("decode scan_bytes: %v", err)}
		}
		return responseFrame(h.ScanBytes(args))
	case OpScanFile:
		var args ScanFileArgs
		if err := DecodePayload(req, &args); err != nil {
			return Frame{Error: fmt.Sprintf("decode scan_file: %v", err)}
		}
		return responseFrame(h.ScanFile(args))
	default:
		return Frame{Error: fmt.Sprintf("yaraipc: unknown op %q", op)}
	}
}
// responseFrame converts a handler outcome into a reply frame. A handler
// error, or a failure to encode result, yields a frame carrying only the
// error text; otherwise the frame carries the encoded result with an
// empty Op.
func responseFrame(result any, err error) Frame {
	if err == nil {
		frame, encErr := EncodePayload("", result)
		if encErr == nil {
			return frame
		}
		err = encErr
	}
	return Frame{Error: err.Error()}
}
// Package yaraworker implements the `csm yara-worker` subcommand: a
// child process that exists only to host the YARA-X cgo surface and
// reply to scan requests over a Unix socket. See ROADMAP.md item 2.
//
// The handler here adapts a yara.Scanner (real in `-tags yara` builds,
// no-op in plain builds) to the yaraipc.Handler wire contract. The
// package is deliberately thin: IPC lives in internal/yaraipc, rule
// compilation lives in internal/yara, and supervision lives in
// internal/daemon.
package yaraworker
import (
"github.com/pidginhost/csm/internal/yara"
"github.com/pidginhost/csm/internal/yaraipc"
)
// Scanner is the subset of *yara.Scanner that the handler uses. An
// interface (rather than the concrete type) so tests can inject a fake
// without pulling in the cgo build tag.
type Scanner interface {
// ScanFile is called with the Path and MaxBytes of a scan_file request.
ScanFile(path string, maxBytes int) []yara.Match
// ScanBytes is called with the Data of a scan_bytes request.
ScanBytes(data []byte) []yara.Match
// Reload is called for a reload request; its error is returned to the peer.
Reload() error
// RuleCount supplies the count reported by Ping and by a successful Reload.
RuleCount() int
}
// NewHandler wraps s in a yaraipc.Handler. Passing nil is allowed: the
// resulting handler answers Ping with Alive=true and every scan with no
// matches, which is what builds without the yara tag and hosts without a
// provisioned rules directory are expected to do.
func NewHandler(s Scanner) yaraipc.Handler {
	h := new(handler)
	h.scanner = s
	return h
}
// handler adapts a Scanner to the yaraipc.Handler method set.
type handler struct {
// scanner may be nil; every method checks for that and then returns
// an empty result (Ping still reports Alive=true).
scanner Scanner
}
// ScanFile scans the file named in a and returns the converted matches.
// With no scanner configured it returns an empty result. The error is
// always nil.
func (h *handler) ScanFile(a yaraipc.ScanFileArgs) (yaraipc.ScanResult, error) {
	var res yaraipc.ScanResult
	if h.scanner != nil {
		res.Matches = convertMatches(h.scanner.ScanFile(a.Path, a.MaxBytes))
	}
	return res, nil
}
// ScanBytes scans the in-memory data carried by a and returns the
// converted matches. With no scanner configured it returns an empty
// result. The error is always nil.
func (h *handler) ScanBytes(a yaraipc.ScanBytesArgs) (yaraipc.ScanResult, error) {
	var res yaraipc.ScanResult
	if h.scanner != nil {
		res.Matches = convertMatches(h.scanner.ScanBytes(a.Data))
	}
	return res, nil
}
// Reload asks the scanner to reload and, on success, reports the
// resulting rule count. The request arguments are ignored. A reload
// error is returned with an empty result; with no scanner configured the
// call succeeds with an empty result.
func (h *handler) Reload(_ yaraipc.ReloadArgs) (yaraipc.ReloadResult, error) {
	var res yaraipc.ReloadResult
	sc := h.scanner
	if sc == nil {
		return res, nil
	}
	if err := sc.Reload(); err != nil {
		return res, err
	}
	res.RuleCount = sc.RuleCount()
	return res, nil
}
// Ping always reports Alive=true. RuleCount comes from the scanner when
// one is configured and stays zero otherwise. The error is always nil.
func (h *handler) Ping() (yaraipc.PingResult, error) {
	res := yaraipc.PingResult{Alive: true}
	if h.scanner != nil {
		res.RuleCount = h.scanner.RuleCount()
	}
	return res, nil
}
// convertMatches copies scanner matches into their wire representation,
// carrying over RuleName and Meta. An empty or nil input yields nil.
func convertMatches(in []yara.Match) []yaraipc.Match {
	var out []yaraipc.Match
	if n := len(in); n > 0 {
		out = make([]yaraipc.Match, 0, n)
		for _, m := range in {
			out = append(out, yaraipc.Match{RuleName: m.RuleName, Meta: m.Meta})
		}
	}
	return out
}
package yaraworker
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"github.com/pidginhost/csm/internal/yara"
"github.com/pidginhost/csm/internal/yaraipc"
)
// Config is what the `csm yara-worker` subcommand receives from its
// parent process. SocketPath and RulesDir are mandatory; ErrorLog is
// optional.
type Config struct {
// SocketPath is the Unix socket Run binds; Run rejects an empty value.
SocketPath string
// RulesDir is handed to yara.NewScanner when the worker starts.
RulesDir string
// ErrorLog, when non-nil, receives the scanner-init error and is also
// passed to the serve loop as its error hook.
ErrorLog func(error)
}
// Run is the body of the `csm yara-worker` subcommand. It prepares the
// socket directory, clears any stale socket file, binds and chmods the
// Unix socket, builds the scanner, and then serves requests until ctx is
// cancelled or accepting fails.
//
// A scanner that fails to initialise does not stop the worker: the
// error is reported through cfg.ErrorLog and the worker serves with zero
// matches, so the supervisor can see the state via Ping and a later
// OpReload can recover. Socket-level failures (empty path, mkdir, stale
// socket removal, listen, chmod) are returned so systemd sees a non-zero
// exit and the supervisor escalates through its backoff.
func Run(ctx context.Context, cfg Config) error {
	sock := cfg.SocketPath
	if sock == "" {
		return fmt.Errorf("yaraworker: socket path is empty")
	}
	if err := os.MkdirAll(filepath.Dir(sock), 0o700); err != nil {
		return fmt.Errorf("yaraworker: mkdir socket dir: %w", err)
	}

	// A stale socket file from a previous worker crash blocks bind.
	// The supervisor only starts one worker at a time, so there is no
	// concurrent binder to race with.
	rmErr := os.Remove(sock)
	if rmErr != nil && !os.IsNotExist(rmErr) {
		return fmt.Errorf("yaraworker: removing stale socket: %w", rmErr)
	}

	listener, err := net.Listen("unix", sock)
	if err != nil {
		return fmt.Errorf("yaraworker: listen: %w", err)
	}
	if err := os.Chmod(sock, 0o600); err != nil {
		_ = listener.Close()
		return fmt.Errorf("yaraworker: chmod socket: %w", err)
	}

	scanner, compileErr := yara.NewScanner(cfg.RulesDir)
	if cfg.ErrorLog != nil && compileErr != nil {
		cfg.ErrorLog(fmt.Errorf("yaraworker: scanner init: %w", compileErr))
	}

	serveOpts := yaraipc.ServeOptions{ErrorLog: cfg.ErrorLog}
	return yaraipc.Serve(ctx, listener, newHandlerNormalised(scanner), serveOpts)
}
// newHandlerNormalised builds the handler for a possibly-nil
// *yara.Scanner. The pointer is nil both on !yara builds (NewScanner
// returns (nil, nil)) and when a yara build fails to initialise
// (NewScanner returns (nil, err)). Converting that nil pointer straight
// into the Scanner interface would give a non-nil interface holding a
// nil pointer, which defeats the handler's nil check and blows up on the
// first method call. So a nil pointer is mapped to a truly nil
// interface value.
func newHandlerNormalised(s *yara.Scanner) yaraipc.Handler {
	var sc Scanner
	if s != nil {
		sc = s
	}
	return NewHandler(sc)
}
package yaraworker
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/pidginhost/csm/internal/obs"
"github.com/pidginhost/csm/internal/yara"
"github.com/pidginhost/csm/internal/yaraipc"
)
// SupervisorConfig parameterises the daemon-side lifecycle manager for
// the `csm yara-worker` child process.
//
// Restart backoff policy:
//
// - The first crash triggers a restart after MinRestartInterval.
// - Each consecutive crash doubles the delay up to MaxRestartInterval.
// - A worker that stays up for StableDuration resets the backoff to
// MinRestartInterval so a single bad rule deploy is not punished
// forever.
type SupervisorConfig struct {
// BinaryPath is the executable launched as the worker. Required.
BinaryPath string
// SocketPath is passed to the worker as --socket and is the path the
// supervisor's client connects to. Required.
SocketPath string
// RulesDir is passed to the worker as --rules-dir.
RulesDir string
// StartTimeout bounds how long a freshly spawned worker may take to
// answer its first Ping. Zero selects 10s.
StartTimeout time.Duration
// MinRestartInterval is the initial restart delay. Zero selects 1s.
MinRestartInterval time.Duration
// MaxRestartInterval caps the doubled restart delay. Zero selects 60s.
MaxRestartInterval time.Duration
// StableDuration is the run time after which an exit resets the delay
// to MinRestartInterval. Zero selects 30s.
StableDuration time.Duration
// ExtraArgs appended to `csm yara-worker`. Tests use this to flag
// the helper process into mock-worker mode.
ExtraArgs []string
// Env override; nil means inherit os.Environ.
Env []string
// OnRestart is called after each unplanned worker exit. exitCode is
// the process exit status, or -1 when the process was killed by a
// signal (signal then carries the signal number) or the status could
// not be determined. Daemons wire this to a finding emitter.
OnRestart func(exitCode int, signal syscall.Signal, runDuration time.Duration)
// Logf is an optional structured-log hook. Supervisor internals log
// restarts + transient errors here. Nil is fine.
Logf func(format string, args ...any)
// ClientTimeout is the per-call read/write deadline the supervisor
// imposes on the worker. Scan calls inherit this. Zero selects 30s.
ClientTimeout time.Duration
}
// Supervisor manages the `csm yara-worker` child process and exposes a
// Scanner-shaped surface to the rest of the daemon. One supervisor per
// daemon.
type Supervisor struct {
// cfg is the configuration with defaults already applied by NewSupervisor.
cfg SupervisorConfig
// mu guards cmd, client, stopped and the restart counters below.
mu sync.Mutex
// cmd is the current worker process; nil while no worker is tracked.
cmd *exec.Cmd
// client talks to the current worker; nil while no worker is tracked.
client *yaraipc.Client
// started is the time the current worker was spawned.
started time.Time
// stopped is set by Stop and never cleared.
stopped bool
// running is set once Start succeeds and cleared by Stop; the
// scan/reload entrypoints check it before touching the client.
running atomic.Bool
// ctx and cancel are created by Start; cancelling ctx ends the
// supervise loop.
ctx context.Context
cancel context.CancelFunc
// done is closed when supervise returns, or by Start when the initial
// spawn fails; Stop waits on it.
done chan struct{}
// Restart counters for observability / tests. Under mu.
restartCount int
lastExitCode int
lastExitSignal syscall.Signal
}
// NewSupervisor checks that cfg names a binary and a socket path, fills
// in defaults for any zero duration, and returns a supervisor that has
// not been started yet.
//
// Defaults: StartTimeout 10s, MinRestartInterval 1s, MaxRestartInterval
// 60s, StableDuration 30s, ClientTimeout 30s.
func NewSupervisor(cfg SupervisorConfig) (*Supervisor, error) {
	switch {
	case cfg.BinaryPath == "":
		return nil, errors.New("yaraworker: BinaryPath is required")
	case cfg.SocketPath == "":
		return nil, errors.New("yaraworker: SocketPath is required")
	}

	// orDefault replaces a zero duration with its fallback.
	orDefault := func(d *time.Duration, fallback time.Duration) {
		if *d == 0 {
			*d = fallback
		}
	}
	orDefault(&cfg.StartTimeout, 10*time.Second)
	orDefault(&cfg.MinRestartInterval, time.Second)
	orDefault(&cfg.MaxRestartInterval, 60*time.Second)
	orDefault(&cfg.StableDuration, 30*time.Second)
	orDefault(&cfg.ClientTimeout, 30*time.Second)

	sup := &Supervisor{cfg: cfg}
	return sup, nil
}
// Start launches the worker and blocks until the first Ping succeeds or
// StartTimeout elapses. A call made after a successful Start, or after
// Stop, returns an error.
//
// The supervisor derives its own cancellable context from ctx; that
// context governs the worker's lifetime until Stop is called or ctx is
// cancelled.
func (s *Supervisor) Start(ctx context.Context) error {
s.mu.Lock()
if s.running.Load() {
s.mu.Unlock()
return errors.New("yaraworker: supervisor already started")
}
if s.stopped {
s.mu.Unlock()
return errors.New("yaraworker: supervisor already stopped")
}
s.ctx, s.cancel = context.WithCancel(ctx)
s.done = make(chan struct{})
s.mu.Unlock()
if err := s.spawnAndWaitReady(); err != nil {
// Startup failed and no supervise goroutine will run: release the
// derived context and close done here so a Stop waiting on it
// returns.
s.cancel()
close(s.done)
return err
}
s.mu.Lock()
// Stop may have been called while the worker was starting; only mark
// the supervisor as running if it was not.
stopped := s.stopped
if !stopped {
s.running.Store(true)
}
s.mu.Unlock()
// supervise is launched even when Stop raced with startup: it is the
// goroutine that waits for the child and closes s.done.
obs.Go("yara-supervisor", s.supervise)
if stopped {
return errors.New("yaraworker: supervisor already stopped")
}
return nil
}
// Stop signals the worker to exit, waits for it, and prevents further
// restarts. Safe to call multiple times; subsequent calls are no-ops.
// It always returns nil.
func (s *Supervisor) Stop() error {
s.mu.Lock()
if s.stopped {
s.mu.Unlock()
return nil
}
s.stopped = true
// Clear running so post-Stop ScanFile/ScanBytes/Reload short-circuit to
// the degraded path instead of redialing the now-closed worker socket on
// every call and logging dial errors.
s.running.Store(false)
// Snapshot under the lock; both are nil when Start was never called.
cancel := s.cancel
done := s.done
s.mu.Unlock()
// Cancel first so the supervise loop treats the upcoming child exit as
// a shutdown rather than a crash to restart from.
if cancel != nil {
cancel()
}
s.mu.Lock()
if s.cmd != nil && s.cmd.Process != nil {
_ = s.cmd.Process.Signal(syscall.SIGTERM)
}
if s.client != nil {
_ = s.client.Close()
}
s.mu.Unlock()
// Block until supervise (or a failed Start) closes done.
if done != nil {
<-done
}
return nil
}
// ScanFile asks the worker to scan the file at path, reading at most
// maxBytes. Every failure mode (supervisor not running, no client,
// worker error) collapses to a nil result, so callers see the same "no
// match" outcome as a clean scan; worker errors are only logged. Tell
// "worker degraded" apart from "nothing matched" through supervisor
// metrics, not through this return value.
func (s *Supervisor) ScanFile(path string, maxBytes int) []yara.Match {
	var c *yaraipc.Client
	if s.running.Load() {
		s.mu.Lock()
		c = s.client
		s.mu.Unlock()
	}
	if c == nil {
		return nil
	}

	args := yaraipc.ScanFileArgs{Path: path, MaxBytes: maxBytes}
	reply, err := c.ScanFile(args)
	if err == nil {
		return toYaraMatches(reply.Matches)
	}
	s.logf("scan_file: %v", err)
	return nil
}
// ScanBytes asks the worker to scan data that is already in memory. It
// returns nil when the supervisor is not running, has no client, or the
// worker call fails (the failure is logged).
func (s *Supervisor) ScanBytes(data []byte) []yara.Match {
	var c *yaraipc.Client
	if s.running.Load() {
		s.mu.Lock()
		c = s.client
		s.mu.Unlock()
	}
	if c == nil {
		return nil
	}

	reply, err := c.ScanBytes(yaraipc.ScanBytesArgs{Data: data})
	if err == nil {
		return toYaraMatches(reply.Matches)
	}
	s.logf("scan_bytes: %v", err)
	return nil
}
// Reload asks the worker to recompile its rules directory. It returns an
// error when the supervisor is not running, when there is no client, or
// when the worker call itself fails.
func (s *Supervisor) Reload() error {
	if !s.running.Load() {
		return errors.New("yaraworker: supervisor not running")
	}

	// Snapshot the client under the lock; the call runs outside it.
	c := func() *yaraipc.Client {
		s.mu.Lock()
		defer s.mu.Unlock()
		return s.client
	}()
	if c == nil {
		return errors.New("yaraworker: no client")
	}

	_, err := c.Reload(yaraipc.ReloadArgs{})
	return err
}
// RuleCount pings the worker and returns the rule count it reports. Any
// failure (not running, no client, ping error) yields zero, matching the
// scanner semantics the daemon already expects.
func (s *Supervisor) RuleCount() int {
	if s.running.Load() {
		s.mu.Lock()
		c := s.client
		s.mu.Unlock()

		if c != nil {
			if pong, err := c.Ping(); err == nil {
				return pong.RuleCount
			}
		}
	}
	return 0
}
// RestartCount returns how many worker exits the supervise loop has
// recorded so far. Intended for metrics and tests.
func (s *Supervisor) RestartCount() int {
	s.mu.Lock()
	n := s.restartCount
	s.mu.Unlock()
	return n
}
// ChildPID reports the pid of the worker currently being tracked, or 0
// when there is none. Meant for operator-facing log lines.
func (s *Supervisor) ChildPID() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	if cmd := s.cmd; cmd != nil && cmd.Process != nil {
		return cmd.Process.Pid
	}
	return 0
}
// RestartWorker sends SIGTERM to the current worker so that the
// supervise loop respawns it against whatever is now on disk (new rules
// directory, updated binary, etc.). Prefer Reload for ordinary rule
// updates; RestartWorker is the escalation path for the rare case where
// an in-process recompile cannot or must not be trusted. The call
// returns as soon as the signal is sent; the restart itself is
// asynchronous and observable via the OnRestart callback.
//
// It returns an error, without signalling anything, when the supervisor
// has been stopped or there is no running child. Otherwise it returns
// the result of delivering the signal.
func (s *Supervisor) RestartWorker() error {
	s.mu.Lock()
	stopped, cmd := s.stopped, s.cmd
	s.mu.Unlock()

	switch {
	case stopped:
		return errors.New("yaraworker: supervisor is stopped")
	case cmd == nil || cmd.Process == nil:
		return errors.New("yaraworker: no running worker")
	}
	return cmd.Process.Signal(syscall.SIGTERM)
}
// supervise watches the current child and restarts it on exit until
// ctx is cancelled.
//
// It closes s.done when it returns. Each unplanned exit bumps the
// restart counters, invokes the OnRestart callback when set, and then
// enters a retry loop that waits out the current backoff before every
// spawn attempt.
func (s *Supervisor) supervise() {
defer close(s.done)
backoff := s.cfg.MinRestartInterval
for {
exitCode, sig := s.waitForChild()
// An exit observed after cancellation is a shutdown, not a crash.
if s.ctx.Err() != nil {
return
}
runDuration := time.Since(s.started)
s.mu.Lock()
s.restartCount++
s.lastExitCode = exitCode
s.lastExitSignal = sig
s.mu.Unlock()
if s.cfg.OnRestart != nil {
s.cfg.OnRestart(exitCode, sig, runDuration)
}
// A stable exit already reset the delay. Short-lived workers and
// failed spawn retries still advance the backoff.
exitBackoffRecorded := runDuration >= s.cfg.StableDuration
if exitBackoffRecorded {
backoff = s.cfg.MinRestartInterval
}
// Retry loop: sleep for the current backoff, try to spawn, and repeat
// until a worker becomes ready or ctx is cancelled.
for {
s.logf("worker exited code=%d signal=%v ran=%s, restarting in %s",
exitCode, sig, runDuration.Round(time.Millisecond), backoff)
select {
case <-time.After(backoff):
case <-s.ctx.Done():
return
}
err := s.spawnAndWaitReady()
// Update the delay for the next wait based on this attempt's outcome.
backoff, exitBackoffRecorded = restartBackoffAfterAttempt(
backoff,
exitBackoffRecorded,
err != nil,
s.cfg.MaxRestartInterval,
)
if err == nil {
// Worker is up; go back to waiting on it in the outer loop.
break
}
if s.ctx.Err() != nil {
return
}
s.logf("restart failed: %v", err)
}
}
}
// restartBackoffAfterAttempt computes the restart delay to use after one
// spawn attempt.
//
// current is the delay that was just waited out. exitBackoffRecorded
// says whether the triggering exit has already been accounted for in
// current. spawnFailed reports whether the attempt failed. max is the
// upper bound for the delay.
//
// When the exit is already accounted for and the spawn succeeded, the
// delay is returned unchanged. In every other case the delay is doubled
// and capped at max. The returned flag is always true.
func restartBackoffAfterAttempt(
	current time.Duration,
	exitBackoffRecorded bool,
	spawnFailed bool,
	max time.Duration,
) (time.Duration, bool) {
	if spawnFailed || !exitBackoffRecorded {
		doubled := 2 * current
		if doubled > max {
			return max, true
		}
		return doubled, true
	}
	// Reaching here implies exitBackoffRecorded was already true.
	return current, true
}
// waitForChild blocks until the tracked worker process exits, clears the
// supervisor's reference to it, closes and drops the client, and
// returns the exit code and signal.
//
// Results: (0, 0) for a clean exit; (-1, sig) when the process was
// killed by a signal; (code, 0) for a non-zero exit status; (-1, 0)
// when there is no tracked process or the wait error is not an exit
// error.
func (s *Supervisor) waitForChild() (int, syscall.Signal) {
	s.mu.Lock()
	child := s.cmd
	s.mu.Unlock()
	if child == nil {
		return -1, 0
	}

	waitErr := child.Wait()

	s.mu.Lock()
	// Only clear cmd if it still refers to the process that was reaped.
	if s.cmd == child {
		s.cmd = nil
	}
	if c := s.client; c != nil {
		_ = c.Close()
		s.client = nil
	}
	s.mu.Unlock()

	if waitErr == nil {
		return 0, 0
	}
	var exitErr *exec.ExitError
	if !errors.As(waitErr, &exitErr) {
		return -1, 0
	}
	status, ok := exitErr.Sys().(syscall.WaitStatus)
	switch {
	case !ok:
		return exitErr.ExitCode(), 0
	case status.Signaled():
		return -1, status.Signal()
	default:
		return status.ExitStatus(), 0
	}
}
// spawnAndWaitReady starts a new worker process and waits until it
// answers a Ping. It drops any previous client, removes a stale socket
// file, launches BinaryPath with the yara-worker arguments, records the
// new process and client on the supervisor, and then polls for
// readiness. If the worker does not become ready it is killed and
// reaped, the supervisor's cmd/client are cleared, and the readiness
// error is returned. It returns early with the context error when the
// supervisor context is already cancelled.
func (s *Supervisor) spawnAndWaitReady() error {
s.mu.Lock()
if s.client != nil {
_ = s.client.Close()
s.client = nil
}
s.mu.Unlock()
if err := s.ctx.Err(); err != nil {
return err
}
// Unlink stale socket here too, even though the worker also does
// it: the worker may fail before reaching its own unlink, leaving
// a stale file that blocks dial attempts during a failed start.
if err := os.Remove(s.cfg.SocketPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("yaraworker: removing stale socket: %w", err)
}
args := []string{"yara-worker",
"--socket", s.cfg.SocketPath,
"--rules-dir", s.cfg.RulesDir,
}
args = append(args, s.cfg.ExtraArgs...)
// #nosec G204 -- BinaryPath is supervisor-operator-configured (see
// cmd/csm/main.go binaryPath), not attacker-controlled.
cmd := exec.Command(s.cfg.BinaryPath, args...)
if s.cfg.Env != nil {
cmd.Env = s.cfg.Env
}
// Both of the worker's output streams go to this process's stderr.
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return fmt.Errorf("yaraworker: start worker: %w", err)
}
s.mu.Lock()
s.cmd = cmd
s.started = time.Now()
client := yaraipc.NewClient(s.cfg.SocketPath, s.cfg.ClientTimeout)
s.client = client
s.mu.Unlock()
if err := s.waitForReady(client); err != nil {
// Not ready in time (or cancelled): kill and reap the child so no
// process is left behind, then forget it.
_ = cmd.Process.Kill()
_ = cmd.Wait()
s.mu.Lock()
s.cmd = nil
_ = s.client.Close()
s.client = nil
s.mu.Unlock()
return err
}
return nil
}
// waitForReady polls until the worker answers a Ping through client. A
// Ping is only attempted once SocketPath exists and is a socket. It
// re-checks every 25ms and gives up with an error once StartTimeout has
// elapsed, or returns the context error as soon as the supervisor
// context is cancelled.
func (s *Supervisor) waitForReady(client *yaraipc.Client) error {
	const pollInterval = 25 * time.Millisecond

	for cutoff := time.Now().Add(s.cfg.StartTimeout); time.Now().Before(cutoff); {
		if err := s.ctx.Err(); err != nil {
			return err
		}

		info, statErr := os.Stat(s.cfg.SocketPath)
		if statErr == nil && info.Mode()&os.ModeSocket != 0 {
			_, pingErr := client.Ping()
			if pingErr == nil {
				return nil
			}
		}

		select {
		case <-s.ctx.Done():
			return s.ctx.Err()
		case <-time.After(pollInterval):
		}
	}
	return fmt.Errorf("yaraworker: worker did not become ready within %s", s.cfg.StartTimeout)
}
// logf forwards a formatted message to the configured Logf hook and does
// nothing when no hook is set.
func (s *Supervisor) logf(format string, args ...any) {
	sink := s.cfg.Logf
	if sink == nil {
		return
	}
	sink(format, args...)
}
// toYaraMatches converts wire-format matches back into yara.Match
// values, carrying over RuleName and Meta. An empty or nil input yields
// nil.
func toYaraMatches(in []yaraipc.Match) []yara.Match {
	var out []yara.Match
	if n := len(in); n > 0 {
		out = make([]yara.Match, 0, n)
		for _, m := range in {
			out = append(out, yara.Match{RuleName: m.RuleName, Meta: m.Meta})
		}
	}
	return out
}
// DefaultSocketPath returns the roadmap-agreed location of the worker
// socket. Exported so the daemon can use it when wiring up config
// defaults.
func DefaultSocketPath() string {
	segments := [...]string{"/var", "run", "csm", "yara-worker.sock"}
	return filepath.Join(segments[:]...)
}