package alert
import (
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
)
// Severity levels for findings.
type Severity int

// Severity values in ascending order of urgency. The numeric order is
// part of the on-disk JSON contract (Warning=0, High=1, Critical=2), so
// new levels must be appended, never inserted.
const (
	Warning Severity = iota
	High
	Critical
)

// String returns the upper-case label used in alert bodies and subjects.
// Unrecognized values render as "UNKNOWN".
func (s Severity) String() string {
	switch s {
	case Warning:
		return "WARNING"
	case High:
		return "HIGH"
	case Critical:
		return "CRITICAL"
	default:
		return "UNKNOWN"
	}
}
// Finding represents a single security check result.
//
// Check is the machine-readable check identifier (e.g. "ip_reputation",
// "auto_block"); Message is the human-readable one-line summary.
// Details, FilePath, ProcessInfo and PID are optional context populated
// by individual checks.
type Finding struct {
	Severity    Severity  `json:"severity"`
	Check       string    `json:"check"`
	Message     string    `json:"message"`
	Details     string    `json:"details,omitempty"`
	FilePath    string    `json:"file_path,omitempty"`
	ProcessInfo string    `json:"process_info,omitempty"` // "pid=N cmd=name uid=N" from fanotify
	PID         int       `json:"pid,omitempty"`          // structured PID for auto-response
	Timestamp   time.Time `json:"timestamp"`
}
// String renders the finding as a multi-line, human-readable summary:
// a "[SEVERITY] check - message" header, then optional indented details
// and process info, then the timestamp.
func (f Finding) String() string {
	var out strings.Builder
	fmt.Fprintf(&out, "[%s] %s - %s", f.Severity, f.Check, f.Message)
	if f.Details != "" {
		// Indent every detail line by one space under the header.
		out.WriteString("\n " + strings.ReplaceAll(f.Details, "\n", "\n "))
	}
	if f.ProcessInfo != "" {
		fmt.Fprintf(&out, "\n Process: %s", f.ProcessInfo)
	}
	fmt.Fprintf(&out, "\n Time: %s", f.Timestamp.Format("2006-01-02 15:04:05"))
	return out.String()
}
// Key returns a unique key for deduplication.
//
// Findings without details collapse on Check+Message alone; when details
// are present a short SHA-256 prefix of them is appended so distinct
// payloads for the same check/message survive deduplication.
func (f Finding) Key() string {
	base := f.Check + ":" + f.Message
	if f.Details == "" {
		return base
	}
	sum := sha256.Sum256([]byte(f.Details))
	return fmt.Sprintf("%s:%x", base, sum[:4])
}
// Deduplicate removes findings sharing the same Key, keeping the first
// occurrence and preserving the original order.
func Deduplicate(findings []Finding) []Finding {
	var unique []Finding
	seen := make(map[string]struct{}, len(findings))
	for _, f := range findings {
		k := f.Key()
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		unique = append(unique, f)
	}
	return unique
}
// FormatAlert formats a list of findings into a human-readable alert body.
// Sensitive data (passwords, tokens) is redacted before sending.
func FormatAlert(hostname string, findings []Finding) string {
	// Count findings per severity; the array is indexed by Severity value.
	var counts [3]int
	for _, f := range findings {
		switch f.Severity {
		case Critical:
			counts[Critical]++
		case High:
			counts[High]++
		case Warning:
			counts[Warning]++
		}
	}
	divider := strings.Repeat("─", 60)

	var out strings.Builder
	fmt.Fprintf(&out, "SECURITY ALERT - %s\n", hostname)
	fmt.Fprintf(&out, "Timestamp: %s\n", time.Now().Format("2006-01-02 15:04:05 MST"))
	fmt.Fprintf(&out, "Findings: %d critical, %d high, %d warning\n", counts[Critical], counts[High], counts[Warning])
	out.WriteString(divider + "\n\n")

	// Emit findings grouped by severity, most severe first, preserving
	// the input order within each group.
	for _, sev := range []Severity{Critical, High, Warning} {
		for _, f := range findings {
			if f.Severity != sev {
				continue
			}
			out.WriteString(sanitizeFinding(f).String())
			out.WriteString("\n\n")
		}
	}
	out.WriteString(divider + "\n")
	out.WriteString("CSM - Continuous Security Monitor\n")
	return out.String()
}
// sanitizeFinding returns a copy of f with sensitive values (passwords,
// tokens, secrets) redacted from its Message and Details before they are
// included in outgoing alerts.
func sanitizeFinding(f Finding) Finding {
	clean := f
	clean.Message = redactSensitive(clean.Message)
	clean.Details = redactSensitive(clean.Details)
	return clean
}
// redactSensitive replaces password values and tokens in text with [REDACTED].
//
// Password-style keys (password=, pass=, passwd=, ...) are matched
// case-insensitively; their values run until the next '&', space,
// newline, quote, or comma. Token-style keys (token_value=, api_token=)
// stop at '&', space, or newline. Every occurrence of every key is
// redacted — previously only the FIRST occurrence of each token key was
// redacted, leaking later tokens on the same line.
//
// The search base advances past each replacement (or past an
// empty-value occurrence) so we never re-match the same prefix position
// on the next iteration. An earlier version of this code restarted the
// search at position 0 after every replacement, which re-found the same
// prefix and re-wrote `[REDACTED]` -> `[REDACTED]` forever whenever the
// replacement was non-empty. That infinite loop would hang the daemon's
// alert dispatch on any log line that contained a populated password
// field.
func redactSensitive(s string) string {
	if s == "" {
		return s
	}
	// Redact password= values in URLs and POST data.
	for _, prefix := range []string{
		"password=", "pass=", "passwd=", "new_password=",
		"old_password=", "confirmpassword=",
	} {
		s = redactAfterPrefix(s, prefix, "& \n\"',")
	}
	// Redact API token values (long alphanumeric strings after token-like keys).
	for _, prefix := range []string{"token_value=", "api_token="} {
		s = redactAfterPrefix(s, prefix, " \n&")
	}
	return s
}

// redactAfterPrefix replaces every value following a case-insensitive
// occurrence of prefix with [REDACTED]. A value extends until the first
// byte contained in stops (or end of string). Empty values (e.g.
// `password=&`) are left alone, but the scan still advances past them so
// a later populated field is still redacted.
func redactAfterPrefix(s, prefix, stops string) string {
	searchFrom := 0
	for searchFrom < len(s) {
		rel := strings.Index(strings.ToLower(s[searchFrom:]), prefix)
		if rel < 0 {
			break
		}
		valStart := searchFrom + rel + len(prefix)
		valEnd := valStart
		for valEnd < len(s) && strings.IndexByte(stops, s[valEnd]) < 0 {
			valEnd++
		}
		if valEnd > valStart {
			s = s[:valStart] + "[REDACTED]" + s[valEnd:]
			searchFrom = valStart + len("[REDACTED]")
		} else {
			searchFrom = valStart
		}
	}
	return s
}
// filterChecks drops findings whose Check name appears in disabledChecks.
// Entries in disabledChecks are whitespace-trimmed; blanks are ignored.
// Returns the input slice unchanged when there is nothing to filter.
func filterChecks(findings []Finding, disabledChecks []string) []Finding {
	if len(findings) == 0 || len(disabledChecks) == 0 {
		return findings
	}
	skip := make(map[string]bool, len(disabledChecks))
	for _, name := range disabledChecks {
		if trimmed := strings.TrimSpace(name); trimmed != "" {
			skip[trimmed] = true
		}
	}
	if len(skip) == 0 {
		return findings
	}
	kept := make([]Finding, 0, len(findings))
	for _, f := range findings {
		if skip[f.Check] {
			continue
		}
		kept = append(kept, f)
	}
	return kept
}
// buildSubject composes the alert subject line, switching to the
// CRITICAL form when any finding carries critical severity.
func buildSubject(hostname string, findings []Finding) string {
	for _, f := range findings {
		if f.Severity == Critical {
			return fmt.Sprintf("[CSM] CRITICAL - %s - %d finding(s)", hostname, len(findings))
		}
	}
	return fmt.Sprintf("[CSM] %s - %d security finding(s)", hostname, len(findings))
}
// rateLimitState tracks alerts sent per hour.
type rateLimitState struct {
	Hour  string `json:"hour"`
	Count int    `json:"count"`
}

// rateLimitMu serializes the read-modify-write cycle on ratelimit.json.
var rateLimitMu sync.Mutex

// checkRateLimit reports whether another alert may be sent this hour,
// incrementing and persisting the per-hour counter when it allows one.
// The counter resets whenever the wall-clock hour changes; a missing or
// corrupt state file is treated as a fresh hour.
func checkRateLimit(statePath string, maxPerHour int) bool {
	rateLimitMu.Lock()
	defer rateLimitMu.Unlock()

	rlFile := filepath.Join(statePath, "ratelimit.json")
	hour := time.Now().Format("2006-01-02T15")

	var state rateLimitState
	// #nosec G304 -- filepath.Join(statePath, "ratelimit.json"); statePath from operator config.
	if raw, err := os.ReadFile(rlFile); err == nil {
		_ = json.Unmarshal(raw, &state)
	}
	// Reset if new hour.
	if state.Hour != hour {
		state = rateLimitState{Hour: hour}
	}
	if state.Count >= maxPerHour {
		return false
	}
	state.Count++
	if encoded, err := json.Marshal(state); err == nil {
		_ = os.WriteFile(rlFile, encoded, 0600)
	}
	return true
}
// Dispatch sends alerts via all configured channels.
//
// Findings are deduplicated and blocked-IP-suppressed first; an empty
// result sends nothing. Batches without a CRITICAL finding respect the
// hourly rate limit; batches containing one always go out but still
// consume rate-limit budget. Per-channel errors are collected and
// returned as one combined error.
func Dispatch(cfg *config.Config, findings []Finding) error {
	findings = FilterBlockedAlerts(cfg, Deduplicate(findings))
	if len(findings) == 0 {
		return nil
	}

	var emailFindings, webhookFindings []Finding
	if cfg.Alerts.Email.Enabled {
		emailFindings = filterChecks(findings, cfg.Alerts.Email.DisabledChecks)
	}
	if cfg.Alerts.Webhook.Enabled {
		webhookFindings = findings
	}
	if len(emailFindings) == 0 && len(webhookFindings) == 0 {
		return nil
	}

	// Rate limit check — CRITICAL realtime findings (malware, webshells,
	// backdoors) always get through. Only non-critical alerts are rate-limited.
	hasCritical := false
	for _, f := range findings {
		if f.Severity == Critical {
			hasCritical = true
			break
		}
	}
	if hasCritical {
		// Still count critical dispatches toward the rate limit budget.
		checkRateLimit(cfg.StatePath, cfg.Alerts.MaxPerHour)
	} else if !checkRateLimit(cfg.StatePath, cfg.Alerts.MaxPerHour) {
		fmt.Fprintf(os.Stderr, "Alert rate limit reached (%d/hour), skipping non-critical alert dispatch\n", cfg.Alerts.MaxPerHour)
		return nil
	}

	var errMsgs []string
	if len(emailFindings) > 0 {
		subject := buildSubject(cfg.Hostname, emailFindings)
		body := FormatAlert(cfg.Hostname, emailFindings)
		if err := SendEmail(cfg, subject, body); err != nil {
			errMsgs = append(errMsgs, fmt.Sprintf("email: %v", err))
		}
	}
	if len(webhookFindings) > 0 {
		subject := buildSubject(cfg.Hostname, webhookFindings)
		body := FormatAlert(cfg.Hostname, webhookFindings)
		if err := SendWebhook(cfg, subject, body); err != nil {
			errMsgs = append(errMsgs, fmt.Sprintf("webhook: %v", err))
		}
	}
	if len(errMsgs) > 0 {
		return fmt.Errorf("alert dispatch errors: %s", strings.Join(errMsgs, "; "))
	}
	return nil
}
// SendHeartbeat pings a dead man's switch URL. Strictly best-effort:
// failures are logged to stderr and otherwise ignored, and no-op when
// the heartbeat is disabled or has no URL configured.
func SendHeartbeat(cfg *config.Config) {
	if !cfg.Alerts.Heartbeat.Enabled || cfg.Alerts.Heartbeat.URL == "" {
		return
	}
	resp, err := httpClient(10 * time.Second).Get(cfg.Alerts.Heartbeat.URL)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Heartbeat failed: %v\n", err)
		return
	}
	_ = resp.Body.Close()
}
package alert
import (
"fmt"
"net/smtp"
"strings"
"github.com/pidginhost/csm/internal/config"
)
// SendEmail delivers an alert email to every address in
// cfg.Alerts.Email.To via the SMTP server in cfg.Alerts.Email.SMTP
// ("host:port"). It first attempts smtp.SendMail with nil auth; if that
// fails it falls back to a manual SMTP session (Dial / Hello / Mail /
// Rcpt / Data / Quit). NOTE(review): both paths are unauthenticated —
// the fallback differs only in driving the protocol by hand; confirm
// this matches the deployed SMTP servers' expectations.
func SendEmail(cfg *config.Config, subject, body string) error {
	to := cfg.Alerts.Email.To
	from := cfg.Alerts.Email.From
	smtpAddr := cfg.Alerts.Email.SMTP
	if len(to) == 0 {
		return fmt.Errorf("no email recipients configured")
	}
	// Plain-text message with CRLF-separated headers.
	msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nContent-Type: text/plain; charset=UTF-8\r\n\r\n%s",
		from,
		strings.Join(to, ", "),
		subject,
		body,
	)
	// Host part of "host:port", used as the HELO name in the fallback path.
	host := strings.Split(smtpAddr, ":")[0]
	err := smtp.SendMail(smtpAddr, nil, from, to, []byte(msg))
	if err != nil {
		// Retry without auth for local servers
		c, dialErr := smtp.Dial(smtpAddr)
		if dialErr != nil {
			// Report both the dial failure and the original SendMail error.
			return fmt.Errorf("smtp dial %s: %w (original: %v)", smtpAddr, dialErr, err)
		}
		defer func() { _ = c.Close() }()
		if err := c.Hello(host); err != nil {
			return fmt.Errorf("smtp hello: %w", err)
		}
		if err := c.Mail(from); err != nil {
			return fmt.Errorf("smtp mail from: %w", err)
		}
		for _, addr := range to {
			if err := c.Rcpt(addr); err != nil {
				return fmt.Errorf("smtp rcpt %s: %w", addr, err)
			}
		}
		w, err := c.Data()
		if err != nil {
			return fmt.Errorf("smtp data: %w", err)
		}
		if _, err := w.Write([]byte(msg)); err != nil {
			return fmt.Errorf("smtp write: %w", err)
		}
		if err := w.Close(); err != nil {
			return fmt.Errorf("smtp close: %w", err)
		}
		return c.Quit()
	}
	return nil
}
package alert
import (
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/config"
)
// FilterBlockedAlerts removes reputation and auto-block alerts for IPs
// that are currently blocked in CSM firewall (either just blocked or previously blocked).
//
// Returns findings unchanged unless cfg.Suppressions.SuppressBlockedAlerts
// is enabled. Suppression sources, combined in order:
//   - IPs already blocked per state files or the injected store callback,
//   - AUTO-BLOCK / AUTO-BLOCK-SUBNET findings within this same batch,
//   - IPs queued (pending) for blocking.
//
// "auto_block" findings themselves are always dropped when suppression
// is on, since the block itself is the handled outcome.
func FilterBlockedAlerts(cfg *config.Config, findings []Finding) []Finding {
	if !cfg.Suppressions.SuppressBlockedAlerts {
		return findings
	}
	// Load all currently blocked IPs from state
	blockedIPs := loadBlockedIPs(cfg.StatePath)
	// Also collect IPs and subnets blocked in this batch.
	var blockedSubnets []*net.IPNet
	for _, f := range findings {
		if f.Check != "auto_block" {
			continue
		}
		// Messages look like "AUTO-BLOCK: 1.2.3.4 ..." or
		// "AUTO-BLOCK-SUBNET: 1.2.3.0/24 ..."; the token after the
		// marker is the address or CIDR.
		parts := strings.Fields(f.Message)
		for i, p := range parts {
			if p == "AUTO-BLOCK:" && i+1 < len(parts) {
				blockedIPs[parts[i+1]] = true
				break
			}
			if p == "AUTO-BLOCK-SUBNET:" && i+1 < len(parts) {
				if _, ipnet, err := net.ParseCIDR(parts[i+1]); err == nil {
					blockedSubnets = append(blockedSubnets, ipnet)
				}
				break
			}
		}
	}
	// Also suppress alerts for IPs queued for blocking (rate-limited).
	// These will be blocked once the rate limit resets - no need to alert.
	for ip := range loadPendingIPs(cfg.StatePath) {
		blockedIPs[ip] = true
	}
	if len(blockedIPs) == 0 && len(blockedSubnets) == 0 {
		return findings
	}
	// Filter out alerts for IPs that are handled automatically.
	// When suppress_blocked_alerts is on, the operator doesn't want to be
	// notified about IPs that are already dealt with - they only want alerts
	// that require human action.
	var filtered []Finding
	for _, f := range findings {
		if f.Check == "ip_reputation" || f.Check == "local_threat_score" {
			// If auto-blocking is enabled, these IPs are handled automatically.
			// Skip the alert - there's nothing for the operator to do.
			if cfg.AutoResponse.Enabled && cfg.AutoResponse.BlockIPs {
				continue
			}
			// Check if the IP is already blocked by exact IP match.
			// (Substring match against the message; an IP appearing
			// anywhere in the text counts.)
			isBlocked := false
			for ip := range blockedIPs {
				if strings.Contains(f.Message, ip) {
					isBlocked = true
					break
				}
			}
			// Also suppress if the IP falls within a freshly-blocked subnet.
			// AUTO-BLOCK-SUBNET: findings from the same batch must silence
			// per-IP reputation alerts for addresses inside that /24.
			if !isBlocked && len(blockedSubnets) > 0 {
				if parsed := extractIPFromFindingMessage(f.Message); parsed != nil {
					for _, subnet := range blockedSubnets {
						if subnet.Contains(parsed) {
							isBlocked = true
							break
						}
					}
				}
			}
			if isBlocked {
				continue
			}
		}
		if f.Check == "auto_block" {
			continue
		}
		filtered = append(filtered, f)
	}
	return filtered
}
// extractIPFromFindingMessage scans a finding message for the first token
// that parses as a valid IP address and returns it. Returns nil if no IP
// is found. Used to match reputation findings against blocked-subnet CIDRs.
func extractIPFromFindingMessage(msg string) net.IP {
	tokens := strings.Fields(msg)
	for i := range tokens {
		// Strip common punctuation that may trail an IP in a message.
		candidate := strings.TrimRight(tokens[i], ",:;)([]")
		parsed := net.ParseIP(candidate)
		if parsed == nil {
			continue
		}
		return parsed
	}
	return nil
}
// loadPendingIPs reads IPs queued for blocking from blocked_ips.json.
func loadPendingIPs(statePath string) map[string]bool {
ips := make(map[string]bool)
type pending struct {
IP string `json:"ip"`
}
type blockFile struct {
Pending []pending `json:"pending"`
}
// #nosec G304 -- filepath.Join under operator-configured statePath.
data, err := os.ReadFile(filepath.Join(statePath, "blocked_ips.json"))
if err == nil {
var bf blockFile
if json.Unmarshal(data, &bf) == nil {
for _, p := range bf.Pending {
ips[p.IP] = true
}
}
}
return ips
}
// BlockedIPsFunc is an optional callback that returns currently blocked IPs.
// Set by the daemon (or tests) to provide blocked IPs from bbolt store,
// avoiding a circular import between alert and store packages.
// When nil, loadBlockedIPs falls back to reading flat files.
var BlockedIPsFunc func() map[string]bool
// loadBlockedIPs reads blocked IPs from both the firewall engine state
// and the legacy blocked_ips.json file.
func loadBlockedIPs(statePath string) map[string]bool {
ips := make(map[string]bool)
now := time.Now()
// Use injected loader (bbolt-backed) when available.
if BlockedIPsFunc != nil {
for ip, v := range BlockedIPsFunc() {
ips[ip] = v
}
} else {
// Fallback: read from firewall engine state.json (nftables) flat file.
// #nosec G304 -- filepath.Join under operator-configured statePath.
if fwData, err := os.ReadFile(filepath.Join(statePath, "firewall", "state.json")); err == nil {
var fwState struct {
Blocked []struct {
IP string `json:"ip"`
ExpiresAt time.Time `json:"expires_at"`
} `json:"blocked"`
}
if json.Unmarshal(fwData, &fwState) == nil {
for _, entry := range fwState.Blocked {
if entry.ExpiresAt.IsZero() || now.Before(entry.ExpiresAt) {
ips[entry.IP] = true
}
}
}
}
}
// Also read from blocked_ips.json (CSM auto-block)
type blockFile struct {
IPs []struct {
IP string `json:"ip"`
ExpiresAt time.Time `json:"expires_at"`
} `json:"ips"`
}
// #nosec G304 -- filepath.Join under operator-configured statePath.
if data, err := os.ReadFile(filepath.Join(statePath, "blocked_ips.json")); err == nil {
var bf blockFile
if json.Unmarshal(data, &bf) == nil {
for _, entry := range bf.IPs {
if entry.ExpiresAt.IsZero() || now.Before(entry.ExpiresAt) {
ips[entry.IP] = true
}
}
}
}
return ips
}
package alert
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/pidginhost/csm/internal/config"
)
func httpClient(timeout time.Duration) *http.Client {
return &http.Client{Timeout: timeout}
}
// SendWebhook posts the alert to the configured webhook URL, shaping the
// JSON payload for Slack ("text"), Discord ("content"), or a generic
// subject/body receiver. Any HTTP status >= 400 is reported as an error.
func SendWebhook(cfg *config.Config, subject, body string) error {
	endpoint := cfg.Alerts.Webhook.URL
	if endpoint == "" {
		return fmt.Errorf("no webhook URL configured")
	}
	var fields map[string]string
	switch cfg.Alerts.Webhook.Type {
	case "slack":
		fields = map[string]string{
			"text": fmt.Sprintf("*%s*\n```\n%s\n```", subject, body),
		}
	case "discord":
		fields = map[string]string{
			"content": fmt.Sprintf("**%s**\n```\n%s\n```", subject, body),
		}
	default:
		fields = map[string]string{
			"subject": subject,
			"body":    body,
		}
	}
	payload, err := json.Marshal(fields)
	if err != nil {
		return fmt.Errorf("marshaling webhook payload: %w", err)
	}
	resp, err := httpClient(10 * time.Second).Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return fmt.Errorf("webhook POST: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		return fmt.Errorf("webhook returned %d", resp.StatusCode)
	}
	return nil
}
package attackdb
import (
"bufio"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// AttackType categorises observed attacks for grouping and scoring.
type AttackType string

// Known attack categories. Findings are mapped onto these via
// checkToAttack; imports from the permanent blocklist that lack a
// specific category use AttackOther.
const (
	AttackBruteForce  AttackType = "brute_force"
	AttackWAFBlock    AttackType = "waf_block"
	AttackWebshell    AttackType = "webshell"
	AttackPhishing    AttackType = "phishing"
	AttackC2          AttackType = "c2"
	AttackRecon       AttackType = "recon"
	AttackSPAM        AttackType = "spam"
	AttackCPanelLogin AttackType = "cpanel_login"
	AttackFileUpload  AttackType = "file_upload"
	AttackReputation  AttackType = "reputation"
	AttackOther       AttackType = "other"
)
// checkToAttack maps alert.Finding.Check values to attack types.
//
// Findings whose Check value is absent from this map are ignored by
// RecordFinding (it returns early on a missed lookup), so adding a new
// check here is what opts it in to attack-database tracking.
var checkToAttack = map[string]AttackType{
	// Brute force
	"wp_login_bruteforce":       AttackBruteForce,
	"xmlrpc_abuse":              AttackBruteForce,
	"ftp_bruteforce":            AttackBruteForce,
	"ssh_login_unknown_ip":      AttackBruteForce,
	"ssh_login_realtime":        AttackBruteForce,
	"webmail_bruteforce":        AttackBruteForce,
	"api_auth_failure":          AttackBruteForce,
	"api_auth_failure_realtime": AttackBruteForce,
	"ftp_auth_failure_realtime": AttackBruteForce,
	"pam_bruteforce":            AttackBruteForce,
	"smtp_bruteforce":           AttackBruteForce,
	"smtp_subnet_spray":         AttackBruteForce,
	"mail_bruteforce":           AttackBruteForce,
	"mail_subnet_spray":         AttackBruteForce,
	"mail_account_compromised":  AttackBruteForce,
	"admin_panel_bruteforce":    AttackBruteForce,
	// Webshells and malware
	"webshell":                 AttackWebshell,
	"new_webshell_file":        AttackWebshell,
	"obfuscated_php":           AttackWebshell,
	"php_dropper":              AttackWebshell,
	"suspicious_php_content":   AttackWebshell,
	"new_php_in_languages":     AttackWebshell,
	"new_php_in_upgrade":       AttackWebshell,
	"backdoor_binary":          AttackWebshell,
	"new_executable_in_config": AttackWebshell,
	// Phishing
	"phishing_page":           AttackPhishing,
	"phishing_php":            AttackPhishing,
	"phishing_iframe":         AttackPhishing,
	"phishing_redirector":     AttackPhishing,
	"phishing_credential_log": AttackPhishing,
	"phishing_kit_archive":    AttackPhishing,
	"phishing_directory":      AttackPhishing,
	// C2 and suspicious processes
	"fake_kernel_thread":       AttackC2,
	"suspicious_process":       AttackC2,
	"php_suspicious_execution": AttackC2,
	"user_outbound_connection": AttackC2,
	"exfiltration_paste_site":  AttackC2,
	// Recon
	"wp_user_enumeration": AttackRecon,
	// SPAM
	"mail_per_account":     AttackSPAM,
	"exim_frozen_realtime": AttackSPAM,
	// WAF
	"modsec_block": AttackWAFBlock,
	"waf_block":    AttackWAFBlock,
	// cPanel/webmail login
	"cpanel_login":           AttackCPanelLogin,
	"cpanel_login_realtime":  AttackCPanelLogin,
	"cpanel_multi_ip_login":  AttackCPanelLogin,
	"webmail_login_realtime": AttackCPanelLogin,
	"ftp_login":              AttackCPanelLogin,
	"ftp_login_realtime":     AttackCPanelLogin,
	"pam_login":              AttackCPanelLogin,
	// File upload
	"cpanel_file_upload_realtime": AttackFileUpload,
	// Reputation - known malicious IPs from threat database
	"ip_reputation": AttackReputation,
	// NOTE: "local_threat_score" is intentionally excluded - it is a derived
	// finding, not a raw attack. Recording it would create a feedback loop
	// that inflates EventCount by +1 every 10-minute cycle.
}
// Event is a single observed attack incident.
type Event struct {
	Timestamp  time.Time  `json:"ts"`
	IP         string     `json:"ip"`
	AttackType AttackType `json:"type"`
	CheckName  string     `json:"check"`             // originating Finding.Check value
	Severity   int        `json:"sev"`               // int(alert.Severity)
	Account    string     `json:"account,omitempty"` // cPanel account, when extractable
	Message    string     `json:"msg,omitempty"`     // finding message, truncated to 200 runes
}

// IPRecord is the per-IP aggregated intelligence record.
type IPRecord struct {
	IP           string             `json:"ip"`
	FirstSeen    time.Time          `json:"first_seen"`
	LastSeen     time.Time          `json:"last_seen"`
	EventCount   int                `json:"event_count"`
	AttackCounts map[AttackType]int `json:"attack_counts"` // events per attack category
	Accounts     map[string]int     `json:"accounts"`      // events per targeted account
	ThreatScore  int                `json:"threat_score"`  // recomputed via ComputeScore on every update
	AutoBlocked  bool               `json:"auto_blocked"`
}
// DB is the in-memory attack database backed by JSON files.
//
// All mutable fields are guarded by mu. Mutations set dirty so the
// background saver knows to persist; removed IPs are remembered in
// deletedIPs so the persistent store can be purged on the next save.
type DB struct {
	mu            sync.RWMutex
	records       map[string]*IPRecord // keyed by IP string
	deletedIPs    map[string]struct{}  // IPs removed since the last save
	pendingEvents []Event              // events queued until the next Flush
	dbPath        string               // <statePath>/attack_db
	dirty         bool                 // records changed since the last save
	stopCh        chan struct{}        // closed by Stop to end backgroundSaver
	wg            sync.WaitGroup       // tracks the backgroundSaver goroutine
}

var (
	globalDB   *DB       // set exactly once by Init
	dbInitOnce sync.Once // guards the one-time initialization in Init
)
// Init initializes the global attack database.
//
// Safe to call multiple times: only the first call creates the DB under
// <statePath>/attack_db, loads persisted records, prunes stale ones, and
// starts the 30-second background flusher. Later calls return the same
// instance.
func Init(statePath string) *DB {
	dbInitOnce.Do(func() {
		dir := statePath + "/attack_db"
		_ = os.MkdirAll(dir, 0700)
		d := &DB{
			records:    make(map[string]*IPRecord),
			deletedIPs: make(map[string]struct{}),
			dbPath:     dir,
			stopCh:     make(chan struct{}),
		}
		d.load()
		d.pruneExpired()
		// Background saver - flush dirty records every 30 seconds.
		d.wg.Add(1)
		go d.backgroundSaver()
		globalDB = d
	})
	return globalDB
}
// SeedFromPermanentBlocklist imports IPs from the threat DB permanent blocklist
// into the attack database. These are IPs that already attacked and were auto-blocked.
// Only imports IPs not already in the attack DB.
//
// Returns the number of newly imported IPs. A missing or unreadable
// file means nothing to import (returns 0). Scanner errors are not
// checked — the import is best-effort. Note that db.mu is held for the
// entire file scan.
func (db *DB) SeedFromPermanentBlocklist(statePath string) int {
	path := statePath + "/threat_db/permanent.txt"
	// #nosec G304 -- fixed filename under operator-configured statePath.
	f, err := os.Open(path)
	if err != nil {
		return 0
	}
	defer func() { _ = f.Close() }()
	imported := 0
	now := time.Now()
	scanner := bufio.NewScanner(f)
	db.mu.Lock()
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Skip blank lines and full-line comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		fields := strings.Fields(line)
		// line is non-empty here, so Fields returns at least one token.
		ip := fields[0]
		if net.ParseIP(ip) == nil {
			continue
		}
		if _, exists := db.records[ip]; exists {
			continue // already tracked
		}
		// Extract reason from comment: "1.2.3.4 # reason [date]"
		reason := "auto-blocked (historical)"
		if idx := strings.Index(line, "# "); idx > 0 {
			reason = strings.TrimSpace(line[idx+2:])
		}
		// Historical imports have no real first/last-seen data, so both
		// are stamped with the import time and categorised AttackOther.
		db.records[ip] = &IPRecord{
			IP:           ip,
			FirstSeen:    now,
			LastSeen:     now,
			EventCount:   1,
			AttackCounts: map[AttackType]int{AttackOther: 1},
			Accounts:     make(map[string]int),
			AutoBlocked:  true,
		}
		db.records[ip].ThreatScore = ComputeScore(db.records[ip])
		db.pendingEvents = append(db.pendingEvents, Event{
			Timestamp:  now,
			IP:         ip,
			AttackType: AttackOther,
			CheckName:  "permanent_blocklist_import",
			Severity:   2,
			Message:    truncate(reason, 200),
		})
		imported++
	}
	if imported > 0 {
		db.dirty = true
	}
	db.mu.Unlock()
	return imported
}
// Global returns the global attack database instance.
// Returns nil until Init has been called.
func Global() *DB {
	return globalDB
}
// RecordFinding records an attack event from a finding.
// Fire-and-forget: never blocks, never panics.
//
// Findings whose Check is not listed in checkToAttack, or whose message
// yields no parseable source IP, are silently ignored.
func (db *DB) RecordFinding(f alert.Finding) {
	attackType, tracked := checkToAttack[f.Check]
	if !tracked {
		return // not an attack-related check
	}
	srcIP := extractIP(f.Message)
	if srcIP == "" {
		return
	}
	acct := extractAccount(f.Message, f.Details)
	ev := Event{
		Timestamp:  f.Timestamp,
		IP:         srcIP,
		AttackType: attackType,
		CheckName:  f.Check,
		Severity:   int(f.Severity),
		Account:    acct,
		Message:    truncate(f.Message, 200),
	}
	// The aggregate record always gets a real timestamp, even when the
	// finding carried none.
	when := f.Timestamp
	if when.IsZero() {
		when = time.Now()
	}

	db.mu.Lock()
	defer db.mu.Unlock()
	rec := db.records[srcIP]
	if rec == nil {
		rec = &IPRecord{
			IP:           srcIP,
			FirstSeen:    when,
			AttackCounts: make(map[AttackType]int),
			Accounts:     make(map[string]int),
		}
		db.records[srcIP] = rec
	}
	rec.LastSeen = when
	rec.EventCount++
	rec.AttackCounts[attackType]++
	if acct != "" {
		rec.Accounts[acct]++
	}
	rec.ThreatScore = ComputeScore(rec)
	db.pendingEvents = append(db.pendingEvents, ev)
	db.dirty = true
}
// MarkBlocked sets the auto-blocked flag on an IP record and refreshes
// its threat score. IPs not currently tracked are ignored.
func (db *DB) MarkBlocked(ip string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	rec, tracked := db.records[ip]
	if !tracked {
		return
	}
	rec.AutoBlocked = true
	rec.ThreatScore = ComputeScore(rec)
	db.dirty = true
}
// LookupIP returns a defensive deep copy of the record for an IP, or nil
// if the IP is not tracked. Mutating the result does not affect the DB.
func (db *DB) LookupIP(ip string) *IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()
	rec, ok := db.records[ip]
	if !ok {
		return nil
	}
	clone := *rec
	clone.AttackCounts = make(map[AttackType]int, len(rec.AttackCounts))
	for atk, n := range rec.AttackCounts {
		clone.AttackCounts[atk] = n
	}
	clone.Accounts = make(map[string]int, len(rec.Accounts))
	for acct, n := range rec.Accounts {
		clone.Accounts[acct] = n
	}
	return &clone
}
// TopAttackers returns deep copies of the top N IPs ranked by threat
// score (ties broken by event count, via sortRecords). N is clamped to
// the number of tracked records.
func (db *DB) TopAttackers(n int) []*IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()
	snapshot := make([]*IPRecord, 0, len(db.records))
	for _, rec := range db.records {
		clone := *rec
		clone.AttackCounts = make(map[AttackType]int, len(rec.AttackCounts))
		for atk, c := range rec.AttackCounts {
			clone.AttackCounts[atk] = c
		}
		clone.Accounts = make(map[string]int, len(rec.Accounts))
		for acct, c := range rec.Accounts {
			clone.Accounts[acct] = c
		}
		snapshot = append(snapshot, &clone)
	}
	// Sort by threat score descending, then event count.
	sortRecords(snapshot)
	if n > len(snapshot) {
		n = len(snapshot)
	}
	return snapshot[:n]
}
// Flush saves all pending data to disk. Called on daemon shutdown and
// every tick of the background saver.
//
// The pending state is detached under the lock, then written without it,
// so recording can continue while I/O is in flight. Always returns nil.
func (db *DB) Flush() error {
	db.mu.Lock()
	queued := db.pendingEvents
	wasDirty := db.dirty
	db.pendingEvents = nil
	db.dirty = false
	db.mu.Unlock()

	if len(queued) > 0 {
		db.appendEvents(queued)
	}
	if wasDirty {
		db.saveRecords()
	}
	return nil
}
// Stop stops the background saver and flushes.
//
// Order matters: closing stopCh ends backgroundSaver's loop, Wait makes
// sure any in-flight flush has completed, and the final Flush persists
// whatever was queued after the saver's last tick.
func (db *DB) Stop() {
	close(db.stopCh)
	db.wg.Wait()
	_ = db.Flush() // Flush currently always returns nil
}
// backgroundSaver flushes pending events and dirty records every 30
// seconds until stopCh is closed. Runs as a goroutine started by Init;
// signals completion via db.wg.
func (db *DB) backgroundSaver() {
	defer db.wg.Done()
	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			_ = db.Flush()
		case <-db.stopCh:
			return
		}
	}
}
// extractIP pulls an IP address from a finding message.
//
// For each known separator phrase, the first whitespace-delimited token
// after the phrase is stripped of trailing punctuation and any
// "(AbuseIPDB..." score suffix, then parsed; the first token that parses
// as an IP wins. Returns "" when no separator yields a valid IP.
func extractIP(message string) string {
	for _, sep := range []string{" from ", ": ", "accessing server: "} {
		pos := strings.Index(message, sep)
		if pos < 0 {
			continue
		}
		tokens := strings.Fields(message[pos+len(sep):])
		if len(tokens) == 0 {
			continue
		}
		candidate := strings.TrimRight(tokens[0], ",:;)([]")
		// Strip AbuseIPDB score suffix like "(AbuseIPDB".
		if open := strings.Index(candidate, "("); open > 0 {
			candidate = candidate[:open]
		}
		if net.ParseIP(candidate) != nil {
			return candidate
		}
	}
	return ""
}
// extractAccount tries to pull a cPanel account name from message or details.
// Details take priority: the first "Account: name" marker wins; failing
// that, a "/home/username/" path in the message is used. Returns "" when
// neither pattern matches.
func extractAccount(message, details string) string {
	const marker = "Account: "
	for _, text := range []string{details, message} {
		idx := strings.Index(text, marker)
		if idx < 0 {
			continue
		}
		after := strings.Fields(text[idx+len(marker):])
		if len(after) > 0 {
			return after[0]
		}
	}
	// Fall back to a /home/username/ path in the message.
	if idx := strings.Index(message, "/home/"); idx >= 0 {
		tail := message[idx+len("/home/"):]
		if slash := strings.IndexByte(tail, '/'); slash > 0 {
			return tail[:slash]
		}
	}
	return ""
}
// truncate limits s to at most n runes, leaving shorter strings intact.
// Counting runes (not bytes) avoids splitting a multi-byte UTF-8
// character at the cut point.
func truncate(s string, n int) string {
	runes := []rune(s)
	if len(runes) > n {
		return string(runes[:n])
	}
	return s
}
// RemoveIP removes an IP from the attack database entirely, remembering
// the deletion in deletedIPs so the persistent store is purged on the
// next save.
func (db *DB) RemoveIP(ip string) {
	db.mu.Lock()
	defer db.mu.Unlock()
	delete(db.records, ip)
	if db.deletedIPs == nil {
		db.deletedIPs = make(map[string]struct{})
	}
	db.deletedIPs[ip] = struct{}{}
	db.dirty = true
}
// PruneExpired removes records older than 90 days.
// Exported wrapper around pruneExpired for callers outside this package.
func (db *DB) PruneExpired() {
	db.pruneExpired()
}
// pruneExpired drops records whose LastSeen is more than 90 days old and
// queues them for deletion from the persistent store.
func (db *DB) pruneExpired() {
	cutoff := time.Now().Add(-90 * 24 * time.Hour)
	db.mu.Lock()
	defer db.mu.Unlock()
	for ip, rec := range db.records {
		if !rec.LastSeen.Before(cutoff) {
			continue
		}
		// Deleting map entries while ranging is well-defined in Go.
		delete(db.records, ip)
		if db.deletedIPs == nil {
			db.deletedIPs = make(map[string]struct{})
		}
		db.deletedIPs[ip] = struct{}{}
		db.dirty = true
	}
}
// TotalIPs returns the number of tracked IPs.
// This is a point-in-time count; it may change immediately after return.
func (db *DB) TotalIPs() int {
	db.mu.RLock()
	defer db.mu.RUnlock()
	return len(db.records)
}
// AllRecords returns a deep-copy snapshot of all records.
// Mutating the returned records does not affect the database.
func (db *DB) AllRecords() []*IPRecord {
	db.mu.RLock()
	defer db.mu.RUnlock()
	snapshot := make([]*IPRecord, 0, len(db.records))
	for _, rec := range db.records {
		clone := *rec
		clone.AttackCounts = make(map[AttackType]int, len(rec.AttackCounts))
		for atk, c := range rec.AttackCounts {
			clone.AttackCounts[atk] = c
		}
		clone.Accounts = make(map[string]int, len(rec.Accounts))
		for acct, c := range rec.Accounts {
			clone.Accounts[acct] = c
		}
		snapshot = append(snapshot, &clone)
	}
	return snapshot
}
// FormatTopLine returns a one-line summary ("N IPs tracked, M
// auto-blocked") for stderr logging.
func (db *DB) FormatTopLine() string {
	db.mu.RLock()
	defer db.mu.RUnlock()
	autoBlocked := 0
	for _, rec := range db.records {
		if rec.AutoBlocked {
			autoBlocked++
		}
	}
	return fmt.Sprintf("%d IPs tracked, %d auto-blocked", len(db.records), autoBlocked)
}
package attackdb
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/pidginhost/csm/internal/store"
)
// Flat-file storage layout under db.dbPath, used only when no bbolt
// store is available.
const (
	recordsFile    = "records.json"     // aggregated per-IP records
	eventsFile     = "events.jsonl"     // append-only event log, one JSON object per line
	maxEventsBytes = 10 * 1024 * 1024   // 10 MB cap before events.jsonl is rotated
)
// load reads IP records from the bbolt store (if available) or from
// the flat-file records.json.
//
// When a store is present the flat file is ignored entirely (early
// return after the store path). Init calls this before the background
// saver goroutine starts, which is presumably why no locking is taken
// here — do not call it after startup.
func (db *DB) load() {
	if sdb := store.Global(); sdb != nil {
		// Store path: convert the store's string-keyed maps into the
		// package's typed AttackType maps.
		storeRecords := sdb.LoadAllIPRecords()
		for ip, sr := range storeRecords {
			rec := &IPRecord{
				IP:           sr.IP,
				FirstSeen:    sr.FirstSeen,
				LastSeen:     sr.LastSeen,
				EventCount:   sr.EventCount,
				ThreatScore:  sr.ThreatScore,
				AutoBlocked:  sr.AutoBlocked,
				AttackCounts: make(map[AttackType]int),
				Accounts:     make(map[string]int),
			}
			for k, v := range sr.AttackCounts {
				rec.AttackCounts[AttackType(k)] = v
			}
			for k, v := range sr.Accounts {
				rec.Accounts[k] = v
			}
			db.records[ip] = rec
		}
		return
	}
	// Fallback: flat-file records.json.
	path := filepath.Join(db.dbPath, recordsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	data, err := os.ReadFile(path)
	if err != nil {
		// Missing file means a fresh database; keep the empty map.
		return
	}
	var records map[string]*IPRecord
	if err := json.Unmarshal(data, &records); err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error loading %s: %v\n", path, err)
		return
	}
	// Ensure maps are initialized
	for _, rec := range records {
		if rec.AttackCounts == nil {
			rec.AttackCounts = make(map[AttackType]int)
		}
		if rec.Accounts == nil {
			rec.Accounts = make(map[string]int)
		}
	}
	db.records = records
}
// saveRecords writes records to the bbolt store (if available) or to
// the flat-file records.json.
//
// Store path locking: records are iterated and saved under RLock (note
// SaveIPRecord performs I/O while the read lock is held), the set of
// deleted IPs is copied out, and only after releasing RLock are store
// deletes issued; successfully deleted IPs are then cleared from
// db.deletedIPs under a separate write lock. Failed deletes stay queued
// for the next save. Flat-file path: marshal the whole map under RLock,
// then write atomically via temp file + rename.
func (db *DB) saveRecords() {
	if sdb := store.Global(); sdb != nil {
		db.mu.RLock()
		for _, rec := range db.records {
			sr := store.IPRecord{
				IP:           rec.IP,
				FirstSeen:    rec.FirstSeen,
				LastSeen:     rec.LastSeen,
				EventCount:   rec.EventCount,
				ThreatScore:  rec.ThreatScore,
				AutoBlocked:  rec.AutoBlocked,
				AttackCounts: make(map[string]int),
				Accounts:     make(map[string]int),
			}
			for k, v := range rec.AttackCounts {
				sr.AttackCounts[string(k)] = v
			}
			for k, v := range rec.Accounts {
				sr.Accounts[k] = v
			}
			if err := sdb.SaveIPRecord(sr); err != nil {
				fmt.Fprintf(os.Stderr, "attackdb: store save %s: %v\n", rec.IP, err)
			}
		}
		var deleted []string
		for ip := range db.deletedIPs {
			deleted = append(deleted, ip)
		}
		db.mu.RUnlock()
		if len(deleted) > 0 {
			var removed []string
			for _, ip := range deleted {
				if err := sdb.DeleteIPRecord(ip); err != nil {
					fmt.Fprintf(os.Stderr, "attackdb: store delete %s: %v\n", ip, err)
					continue
				}
				removed = append(removed, ip)
			}
			if len(removed) > 0 {
				db.mu.Lock()
				for _, ip := range removed {
					delete(db.deletedIPs, ip)
				}
				db.mu.Unlock()
			}
		}
		return
	}
	// Fallback: flat-file records.json.
	db.mu.RLock()
	data, err := json.Marshal(db.records)
	db.mu.RUnlock()
	if err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error marshaling records: %v\n", err)
		return
	}
	path := filepath.Join(db.dbPath, recordsFile)
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error writing %s: %v\n", tmpPath, err)
		return
	}
	_ = os.Rename(tmpPath, path)
}
// appendEvents writes events to the bbolt store (if available) or appends
// to the JSONL file, rotating it first once it exceeds maxEventsBytes.
//
// NOTE(review): the slice index i is passed to RecordAttackEvent along
// with the event — presumably to disambiguate events that share a
// timestamp within one batch; confirm against the store package.
func (db *DB) appendEvents(events []Event) {
	if sdb := store.Global(); sdb != nil {
		for i, ev := range events {
			// Events queued with a zero timestamp get stamped at write time.
			ts := ev.Timestamp
			if ts.IsZero() {
				ts = time.Now()
			}
			se := store.AttackEvent{
				Timestamp:  ts,
				IP:         ev.IP,
				AttackType: string(ev.AttackType),
				CheckName:  ev.CheckName,
				Severity:   ev.Severity,
				Account:    ev.Account,
				Message:    ev.Message,
			}
			if err := sdb.RecordAttackEvent(se, i); err != nil {
				fmt.Fprintf(os.Stderr, "attackdb: store event: %v\n", err)
			}
		}
		return
	}
	// Fallback: flat-file JSONL.
	path := filepath.Join(db.dbPath, eventsFile)
	// Check file size and rotate if needed
	if info, err := os.Stat(path); err == nil && info.Size() > maxEventsBytes {
		rotateEventsFile(path)
	}
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error opening %s: %v\n", path, err)
		return
	}
	defer func() { _ = f.Close() }()
	// Best-effort append: individual encode/flush errors are ignored.
	w := bufio.NewWriter(f)
	enc := json.NewEncoder(w)
	for _, ev := range events {
		_ = enc.Encode(ev)
	}
	_ = w.Flush()
}
// rotateEventsFile keeps the newest half of the file.
func rotateEventsFile(path string) {
// #nosec G304 -- path is filepath.Join under operator-configured db.dbPath.
data, err := os.ReadFile(path)
if err != nil {
return
}
// Find the midpoint newline
mid := len(data) / 2
for mid < len(data) {
if data[mid] == '\n' {
mid++
break
}
mid++
}
if mid >= len(data) {
return
}
tmpPath := path + ".tmp"
// #nosec G703 -- path is db.path, derived from the operator-configured
// statePath at DB open time (see Open / NewDB).
if err := os.WriteFile(tmpPath, data[mid:], 0600); err != nil {
return
}
_ = os.Rename(tmpPath, path)
}
// QueryEvents reads events for a specific IP from the bbolt store (if
// available) or from the JSONL file. Returns the most recent `limit`
// events in reverse chronological order (limit <= 0 means no limit).
func (db *DB) QueryEvents(ip string, limit int) []Event {
	if sdb := store.Global(); sdb != nil {
		storeEvents := sdb.QueryAttackEvents(ip, limit)
		result := make([]Event, len(storeEvents))
		for i, se := range storeEvents {
			result[i] = Event{
				Timestamp:  se.Timestamp,
				IP:         se.IP,
				AttackType: AttackType(se.AttackType),
				CheckName:  se.CheckName,
				Severity:   se.Severity,
				Account:    se.Account,
				Message:    se.Message,
			}
		}
		return result
	}
	// Fallback: flat-file JSONL.
	path := filepath.Join(db.dbPath, eventsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	var all []Event
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
	for scanner.Scan() {
		var ev Event
		// Malformed lines are skipped deliberately (legacy/corrupt entries).
		if err := json.Unmarshal(scanner.Bytes(), &ev); err != nil {
			continue
		}
		if ev.IP == ip {
			all = append(all, ev)
		}
	}
	// Fix: a scanner error (I/O failure, or a line exceeding the 1 MiB
	// buffer) previously aborted the scan silently, truncating results
	// with no signal. Surface it like the other attackdb I/O errors.
	if err := scanner.Err(); err != nil {
		fmt.Fprintf(os.Stderr, "attackdb: error reading %s: %v\n", path, err)
	}
	// Keep only the newest `limit` entries (file is in chronological order).
	if len(all) > limit && limit > 0 {
		all = all[len(all)-limit:]
	}
	// Reverse so the newest event comes first.
	for i, j := 0, len(all)-1; i < j; i, j = i+1, j-1 {
		all[i], all[j] = all[j], all[i]
	}
	return all
}
package attackdb
import "sort"
// ComputeScore returns a 0-100 local threat score from an IPRecord.
//
// Scoring logic:
//   - Volume: min(event_count * 2, 30)
//   - Attack type bonuses (non-cumulative per type)
//   - Multi-account targeting: +10
//   - Auto-blocked floor: 50
//   - Hard cap: 100
func ComputeScore(r *IPRecord) int {
	// Volume component: two points per event, saturating at 30.
	score := r.EventCount * 2
	if score > 30 {
		score = 30
	}
	// Per-type bonuses, each awarded at most once when the count for
	// that attack type reaches its threshold (>5 for WAF blocks, any
	// occurrence for the rest).
	bonuses := []struct {
		attack    AttackType
		threshold int
		points    int
	}{
		{AttackC2, 1, 35},
		{AttackWebshell, 1, 30},
		{AttackPhishing, 1, 25},
		{AttackBruteForce, 1, 15},
		{AttackWAFBlock, 6, 10},
		{AttackFileUpload, 1, 20},
	}
	for _, b := range bonuses {
		if r.AttackCounts[b.attack] >= b.threshold {
			score += b.points
		}
	}
	// Targeting more than one account earns a flat bonus.
	if len(r.Accounts) > 1 {
		score += 10
	}
	// An auto-blocked IP never scores below 50.
	if r.AutoBlocked && score < 50 {
		score = 50
	}
	// Hard cap.
	if score > 100 {
		score = 100
	}
	return score
}
// sortRecords orders records by threat score descending, breaking ties
// by event count descending.
func sortRecords(recs []*IPRecord) {
	sort.Slice(recs, func(a, b int) bool {
		ra, rb := recs[a], recs[b]
		if ra.ThreatScore == rb.ThreatScore {
			return ra.EventCount > rb.EventCount
		}
		return ra.ThreatScore > rb.ThreatScore
	})
}
package attackdb
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"sync"
"time"
"github.com/pidginhost/csm/internal/store"
)
// statsCacheTTL bounds how often the full events log is re-scanned by Stats.
const statsCacheTTL = 30 * time.Second
// AttackStats contains aggregate statistics for the API and dashboard.
type AttackStats struct {
	TotalIPs      int                `json:"total_ips"`
	TotalEvents   int                `json:"total_events"`
	Last24hEvents int                `json:"last_24h_events"`
	Last7dEvents  int                `json:"last_7d_events"`
	BlockedIPs    int                `json:"blocked_ips"`
	ByType        map[AttackType]int `json:"by_type"`     // lifetime, aggregated from IPRecord.AttackCounts
	ByType24h     map[AttackType]int `json:"by_type_24h"` // last 24h, aggregated from the events log
	TopAttackers  []*IPRecord        `json:"top_attackers"`
	HourlyBuckets [24]int            `json:"hourly_buckets"` // last 24h, index 0 = oldest hour
	DailyBuckets  [7]int             `json:"daily_buckets"`  // last 7 days, index 0 = oldest day
}
// Package-level stats cache guarded by cachedStatsMu.
// NOTE(review): being package-level, the cache is shared by all DB
// instances — fine for a single-DB daemon, surprising otherwise.
var (
	cachedStats     AttackStats
	cachedStatsTime time.Time
	cachedStatsMu   sync.Mutex
)
// Stats returns aggregate statistics, cached for statsCacheTTL (30s) to
// avoid re-scanning the full events log on every API call. Concurrent
// callers with a stale cache may recompute in parallel; that duplicated
// work is acceptable for this cache horizon.
func (db *DB) Stats() AttackStats {
	cachedStatsMu.Lock()
	fresh := time.Since(cachedStatsTime) < statsCacheTTL
	snapshot := cachedStats
	cachedStatsMu.Unlock()
	if fresh {
		return snapshot
	}
	// Recompute outside the lock so a slow scan does not block readers.
	recomputed := db.computeStats()
	cachedStatsMu.Lock()
	cachedStats = recomputed
	cachedStatsTime = time.Now()
	cachedStatsMu.Unlock()
	return recomputed
}
// computeStats builds a full AttackStats snapshot: lifetime aggregates
// come from the in-memory IP records, time-bucketed counters from a
// full scan of the events log.
func (db *DB) computeStats() AttackStats {
	now := time.Now()
	cutoff24h := now.Add(-24 * time.Hour)
	cutoff7d := now.Add(-7 * 24 * time.Hour)
	db.mu.RLock()
	stats := AttackStats{
		TotalIPs:  len(db.records),
		ByType:    make(map[AttackType]int),
		ByType24h: make(map[AttackType]int),
	}
	// Lifetime totals per IP record.
	for _, rec := range db.records {
		stats.TotalEvents += rec.EventCount
		if rec.AutoBlocked {
			stats.BlockedIPs++
		}
		for atype, count := range rec.AttackCounts {
			stats.ByType[atype] += count
		}
	}
	db.mu.RUnlock()
	// Compute time-based stats from events log. readAllEvents does its
	// own locking / I/O, so the record lock is released first.
	events := db.readAllEvents()
	for _, ev := range events {
		if ev.Timestamp.After(cutoff24h) {
			stats.Last24hEvents++
			// Skip events with no attack type — a malformed or legacy JSONL
			// line would otherwise surface as a "" key in the JSON response.
			if ev.AttackType != "" {
				stats.ByType24h[ev.AttackType]++
			}
			// hoursAgo 0 = current hour; bucket 23 is the newest so that
			// index 0 stays the oldest hour, as documented on AttackStats.
			hoursAgo := int(now.Sub(ev.Timestamp).Hours())
			if hoursAgo >= 0 && hoursAgo < 24 {
				stats.HourlyBuckets[23-hoursAgo]++
			}
		}
		if ev.Timestamp.After(cutoff7d) {
			stats.Last7dEvents++
			// Same scheme as the hourly buckets: index 0 = oldest day.
			daysAgo := int(now.Sub(ev.Timestamp).Hours() / 24)
			if daysAgo >= 0 && daysAgo < 7 {
				stats.DailyBuckets[6-daysAgo]++
			}
		}
	}
	stats.TopAttackers = db.TopAttackers(10)
	return stats
}
// readAllEvents loads the complete event history for stats computation,
// preferring the bbolt store and falling back to the JSONL file.
func (db *DB) readAllEvents() []Event {
	if sdb := store.Global(); sdb != nil {
		raw := sdb.ReadAllAttackEvents()
		out := make([]Event, 0, len(raw))
		for _, se := range raw {
			out = append(out, Event{
				Timestamp:  se.Timestamp,
				IP:         se.IP,
				AttackType: AttackType(se.AttackType),
				CheckName:  se.CheckName,
				Severity:   se.Severity,
				Account:    se.Account,
				Message:    se.Message,
			})
		}
		return out
	}
	// Flat-file fallback: events.jsonl, one JSON object per line.
	path := filepath.Join(db.dbPath, eventsFile)
	// #nosec G304 -- filepath.Join under operator-configured db.dbPath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	sc := bufio.NewScanner(f)
	sc.Buffer(make([]byte, 1024*1024), 1024*1024)
	var out []Event
	for sc.Scan() {
		var ev Event
		// Unparseable lines are dropped; stats tolerate gaps.
		if json.Unmarshal(sc.Bytes(), &ev) == nil {
			out = append(out, ev)
		}
	}
	return out
}
package auditd
import (
"os"
"os/exec"
)
// rulesPath is the auditd drop-in file owned by CSM.
const rulesPath = "/etc/audit/rules.d/csm.rules"
// rules is the auditd ruleset installed by Deploy. Every rule carries a
// "csm_" key so matching audit events can be attributed to this monitor.
const rules = `## Continuous Security Monitor - auditd rules
# Password/auth file changes
-w /etc/shadow -p wa -k csm_shadow_change
-w /etc/passwd -p wa -k csm_passwd_change
-w /etc/group -p wa -k csm_group_change
# SSH config and keys
-w /etc/ssh/sshd_config -p wa -k csm_sshd_change
-w /root/.ssh/authorized_keys -p wa -k csm_root_ssh_keys
# WHM API tokens
-w /var/cpanel/authn/api_tokens_v2/ -p wa -k csm_whm_api_tokens
# Crontab modifications
-w /var/spool/cron/ -p wa -k csm_crontab_change
-w /etc/cron.d/ -p wa -k csm_crond_change
# Password change commands
-w /usr/bin/passwd -p x -k csm_passwd_exec
-w /usr/sbin/chpasswd -p x -k csm_chpasswd_exec
# CSM binary self-protection
-w /opt/csm/csm -p wa -k csm_binary_tamper
-w /opt/csm/csm.yaml -p wa -k csm_config_tamper
# Execution from suspicious locations
-a always,exit -F arch=b64 -S execve -F dir=/tmp -k csm_exec_tmp
-a always,exit -F arch=b64 -S execve -F dir=/dev/shm -k csm_exec_shm
# User account modifications
-w /usr/sbin/useradd -p x -k csm_useradd
-w /usr/sbin/usermod -p x -k csm_usermod
-w /usr/sbin/userdel -p x -k csm_userdel
`
// Deploy installs the CSM auditd ruleset and asks auditd to reload it.
// Returns the first error from either the file write or the reload.
func Deploy() error {
	// #nosec G306 -- /etc/audit/rules.d/csm.rules is read by the auditd
	// tooling (augenrules) on reload; 0640 keeps world-read off.
	err := os.WriteFile(rulesPath, []byte(rules), 0640)
	if err != nil {
		return err
	}
	// Make the freshly written rules active.
	return exec.Command("augenrules", "--load").Run()
}
// Remove deletes the CSM auditd ruleset and reloads auditd so the rules
// stop applying. Both steps are best-effort; errors are ignored because
// removal runs during teardown where there is nothing left to do.
func Remove() {
	_ = os.Remove(rulesPath)
	_ = exec.Command("augenrules", "--load").Run()
}
package challenge
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// challengeEntry stores the challenge metadata for a single IP.
type challengeEntry struct {
	ExpiresAt time.Time // when the challenge window closes; past this the IP escalates
	Reason    string    // why the IP was gray-listed, carried through to escalation
}
// ExpiredEntry is returned by ExpiredEntries for escalation.
type ExpiredEntry struct {
	IP     string
	Reason string // original gray-listing reason, for the block audit trail
}
// IPList manages the set of IPs that should see challenge pages.
// Apache RewriteMap reads the file to redirect IPs to the challenge server.
type IPList struct {
	path string                    // on-disk RewriteMap file (challenge_ips.txt)
	ips  map[string]challengeEntry // in-memory source of truth, keyed by IP
	mu   sync.Mutex                // guards ips and serializes flushes
}
// NewIPList creates an IP list writer whose RewriteMap file lives under
// the given state directory.
func NewIPList(statePath string) *IPList {
	return &IPList{
		ips:  map[string]challengeEntry{},
		path: filepath.Join(statePath, "challenge_ips.txt"),
	}
}
// Add marks an IP for challenge with the given reason; the entry expires
// after duration, at which point ExpiredEntries hands it to escalation.
func (l *IPList) Add(ip string, reason string, duration time.Duration) {
	entry := challengeEntry{
		ExpiresAt: time.Now().Add(duration),
		Reason:    reason,
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	l.ips[ip] = entry
	l.flush() // keep the on-disk RewriteMap in sync
}
// Remove stops challenging an IP (it passed the challenge or was removed
// manually) and rewrites the RewriteMap file.
func (l *IPList) Remove(ip string) {
	l.mu.Lock()
	delete(l.ips, ip)
	l.flush()
	l.mu.Unlock()
}
// Contains reports whether the IP is currently on the challenge list.
func (l *IPList) Contains(ip string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	_, present := l.ips[ip]
	return present
}
// ExpiredEntries removes and returns all expired entries for escalation.
// The caller is expected to hard-block the returned IPs. The on-disk
// file is rewritten only when something was actually removed.
func (l *IPList) ExpiredEntries() []ExpiredEntry {
	l.mu.Lock()
	defer l.mu.Unlock()
	var out []ExpiredEntry
	now := time.Now()
	for ip, entry := range l.ips {
		if !now.After(entry.ExpiresAt) {
			continue
		}
		out = append(out, ExpiredEntry{IP: ip, Reason: entry.Reason})
		delete(l.ips, ip) // deleting during range is safe in Go
	}
	if len(out) > 0 {
		l.flush()
	}
	return out
}
// CleanExpired removes expired entries without returning them.
// Use ExpiredEntries() instead when escalation is needed.
func (l *IPList) CleanExpired() {
	// Delegates entirely; the returned slice is intentionally discarded.
	_ = l.ExpiredEntries()
}
// flush writes the IP list to disk in Apache RewriteMap txt format
// ("IP challenge" per line), atomically via temp file + rename.
// Caller must hold l.mu.
func (l *IPList) flush() {
	var buf strings.Builder
	buf.WriteString("# CSM Challenge IP list - auto-generated, do not edit\n")
	buf.WriteString("# Format: IP challenge (for Apache RewriteMap)\n")
	for ip := range l.ips {
		buf.WriteString(ip)
		buf.WriteString(" challenge\n")
	}
	tmpPath := l.path + ".tmp"
	// #nosec G306 -- Apache uses this file as a RewriteMap source; it has
	// to be readable by the webserver user. No sensitive data — only a
	// list of IPs that must re-solve the PoW challenge.
	if err := os.WriteFile(tmpPath, []byte(buf.String()), 0644); err != nil {
		return // best-effort: a failed flush keeps the previous file
	}
	_ = os.Rename(tmpPath, l.path)
}
package challenge
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"html"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/config"
)
// IPUnblocker is the interface for temporarily allowing an IP.
// Implemented by the firewall layer; the challenge server calls it once
// an IP solves the proof-of-work.
type IPUnblocker interface {
	TempAllowIP(ip string, reason string, timeout time.Duration) error
}
// Server serves challenge pages to gray-listed IPs.
// When an IP passes the challenge, it gets a temporary allow.
type Server struct {
	cfg            *config.Config
	secret         []byte          // HMAC key for tokens/cookies; random per-process if unconfigured
	unblocker      IPUnblocker     // may be nil: verification then skips the firewall allow
	ipList         *IPList         // may be nil: verification then skips list removal
	srv            *http.Server    // the challenge HTTP listener
	trustedProxies map[string]bool // peers whose X-Forwarded-For is believed
	// Track recently verified IPs to prevent replay of a solved nonce.
	verified   map[string]time.Time
	verifiedMu sync.Mutex
}
// New creates a challenge server bound to the configured listen port.
// If no HMAC secret is configured, a random per-process key is used
// (tokens then do not survive a restart).
func New(cfg *config.Config, unblocker IPUnblocker, ipList *IPList) *Server {
	key := []byte(cfg.Challenge.Secret)
	if len(key) == 0 {
		key = make([]byte, 32)
		_, _ = rand.Read(key)
	}
	proxies := map[string]bool{}
	for _, p := range cfg.Challenge.TrustedProxies {
		proxies[strings.TrimSpace(p)] = true
	}
	srv := &Server{
		cfg:            cfg,
		secret:         key,
		unblocker:      unblocker,
		ipList:         ipList,
		trustedProxies: proxies,
		verified:       map[string]time.Time{},
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/challenge", srv.handleChallenge)
	mux.HandleFunc("/challenge/verify", srv.handleVerify)
	srv.srv = &http.Server{
		Addr:              fmt.Sprintf(":%d", cfg.Challenge.ListenPort),
		Handler:           mux,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		ReadHeaderTimeout: 5 * time.Second,
	}
	return srv
}
// Start begins serving challenge pages. It blocks until the listener
// stops; after Shutdown it returns http.ErrServerClosed.
func (s *Server) Start() error {
	return s.srv.ListenAndServe()
}
// Shutdown stops the server. Note this uses Close, which drops in-flight
// connections immediately rather than draining them gracefully.
func (s *Server) Shutdown() {
	_ = s.srv.Close()
}
// handleChallenge renders the proof-of-work page for the requesting IP.
// The page embeds a fresh nonce and an HMAC token binding that nonce to
// the client IP, so a solution can only be redeemed by the same IP.
func (s *Server) handleChallenge(w http.ResponseWriter, r *http.Request) {
	ip := s.extractIP(r)
	nonce := generateNonce()
	difficulty := s.cfg.Challenge.Difficulty
	// Generate HMAC token binding nonce to IP
	token := s.makeToken(ip, nonce)
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.Header().Set("Cache-Control", "no-store")
	// #nosec G705 -- All interpolated values are constrained to non-HTML
	// character sets: ip is validated via net.ParseIP / net.SplitHostPort
	// in extractIP (never attacker-controlled string); nonce and token
	// are hex.EncodeToString output (0-9, a-f only); difficulty is an int.
	// The five arguments must stay in sync with the format verbs of
	// challengePageHTML (defined elsewhere in this package).
	fmt.Fprintf(w, challengePageHTML, ip, nonce, token, difficulty, difficulty)
}
// handleVerify redeems a solved challenge. Verification order matters:
// (1) HMAC token proves the nonce was issued to this IP, (2) the PoW
// solution is checked, (3) the nonce is consumed to block replay. Only
// then are the side effects applied: temporary firewall allow, removal
// from the Apache challenge list, and a verification cookie.
func (s *Server) handleVerify(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	ip := s.extractIP(r)
	nonce := r.FormValue("nonce")
	token := r.FormValue("token")
	solution := r.FormValue("solution")
	// Verify HMAC token: recomputed over the *current* client IP, so a
	// token issued to a different IP cannot be redeemed here.
	expected := s.makeToken(ip, nonce)
	if token != expected {
		http.Error(w, "Invalid token", http.StatusForbidden)
		return
	}
	// Verify proof-of-work solution
	if !verifyPoW(nonce, solution, s.cfg.Challenge.Difficulty) {
		http.Error(w, "Invalid solution", http.StatusForbidden)
		return
	}
	// Prevent replay: each nonce may be redeemed once; consumed entries
	// age out via CleanExpired.
	s.verifiedMu.Lock()
	if _, seen := s.verified[nonce]; seen {
		s.verifiedMu.Unlock()
		http.Error(w, "Token already used", http.StatusForbidden)
		return
	}
	s.verified[nonce] = time.Now()
	s.verifiedMu.Unlock()
	// Allow the IP temporarily (4 hours)
	allowDuration := 4 * time.Hour
	if s.unblocker != nil {
		if err := s.unblocker.TempAllowIP(ip, "passed challenge", allowDuration); err != nil {
			fmt.Fprintf(os.Stderr, "[challenge] failed to allow %s: %v\n", ip, err)
		}
	}
	// Remove from challenge list so Apache stops redirecting
	if s.ipList != nil {
		s.ipList.Remove(ip)
	}
	// Set verification cookie. Secure is always on: CSM is designed to run
	// behind HTTPS and the cookie grants a multi-hour bypass of the PoW
	// gate, so leaking it over plaintext is never acceptable.
	cookie := &http.Cookie{
		Name:     "csm_verified",
		Value:    s.makeVerifyCookie(ip),
		Path:     "/",
		MaxAge:   int(allowDuration.Seconds()),
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	}
	http.SetCookie(w, cookie)
	// Redirect to original destination (sanitized to prevent open redirect / XSS)
	dest := sanitizeRedirectDest(r.FormValue("dest"), r.Host)
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	fmt.Fprintf(w, `<!DOCTYPE html><html><head><meta http-equiv="refresh" content="2;url=%s">
<style>body{font-family:system-ui;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#1a2234;color:#c8d3e0}
.ok{text-align:center}.ok h1{color:#2fb344;font-size:3em}p{font-size:1.2em}</style>
</head><body><div class="ok"><h1>✓</h1><p>Verified - redirecting...</p></div></body></html>`, html.EscapeString(dest))
}
// makeToken returns the hex HMAC-SHA256 of "ip:nonce" under the server
// secret, binding an issued nonce to the client IP it was served to.
func (s *Server) makeToken(ip, nonce string) string {
	mac := hmac.New(sha256.New, s.secret)
	fmt.Fprintf(mac, "%s:%s", ip, nonce) // hash.Hash writes never fail
	return hex.EncodeToString(mac.Sum(nil))
}
// makeVerifyCookie derives the csm_verified cookie value: an HMAC-SHA256
// over the client IP and the current wall-clock hour, truncated to 32
// hex characters.
//
// NOTE(review): folding in time.Now().Truncate(time.Hour) means the
// value rotates at every hour boundary; nothing in this file validates
// the cookie, so confirm the consumer accepts values minted in adjacent
// hours (otherwise a cookie set at x:59 looks invalid at x+1:00).
func (s *Server) makeVerifyCookie(ip string) string {
	mac := hmac.New(sha256.New, s.secret)
	mac.Write([]byte("cookie:" + ip + ":" + time.Now().Truncate(time.Hour).Format(time.RFC3339)))
	return hex.EncodeToString(mac.Sum(nil))[:32]
}
// CleanExpired drops verification (anti-replay) records older than four
// hours, matching the lifetime of the temporary allow they correspond to.
func (s *Server) CleanExpired() {
	cutoff := time.Now().Add(-4 * time.Hour)
	s.verifiedMu.Lock()
	defer s.verifiedMu.Unlock()
	for nonce, seenAt := range s.verified {
		if seenAt.Before(cutoff) {
			delete(s.verified, nonce)
		}
	}
}
// extractIP returns the client IP for the request. X-Forwarded-For is
// honored only when the direct peer is in the configured trusted_proxies
// list, and then the rightmost parseable entry wins — that is the hop
// the trusted proxy itself appended, whereas the leftmost entries are
// client-controlled. With no trusted proxies configured, RemoteAddr is
// always used, so an attacker cannot spoof a header to mint firewall
// allow rules for arbitrary addresses.
func (s *Server) extractIP(r *http.Request) string {
	peer, _, _ := net.SplitHostPort(r.RemoteAddr)
	if len(s.trustedProxies) == 0 || !s.trustedProxies[peer] {
		return peer
	}
	xff := r.Header.Get("X-Forwarded-For")
	if xff == "" {
		return peer
	}
	hops := strings.Split(xff, ",")
	for i := len(hops) - 1; i >= 0; i-- {
		if candidate := strings.TrimSpace(hops[i]); net.ParseIP(candidate) != nil {
			return candidate
		}
	}
	return peer
}
// generateNonce returns 16 bytes of crypto-grade randomness encoded as
// a 32-character lowercase hex string.
func generateNonce() string {
	var raw [16]byte
	_, _ = rand.Read(raw[:]) // crypto/rand.Read does not fail in practice
	return hex.EncodeToString(raw[:])
}
// sanitizeRedirectDest validates that a redirect destination is a safe same-origin
// relative path or absolute URL matching the request host. Rejects cross-origin
// redirects, javascript: URIs, backslash-based bypasses, and other open-redirect payloads.
// Returns a reconstructed URL from parsed components to prevent any raw-string injection.
func sanitizeRedirectDest(dest, requestHost string) string {
	if dest == "" {
		return "/"
	}
	// Parse through url.Parse to normalize and detect scheme/host
	parsed, err := url.Parse(dest)
	if err != nil {
		return "/"
	}
	// Scheme whitelist — applies even for opaque URLs with empty Host.
	// Without this, `javascript:alert(1)` produces an opaque URL with
	// Host="" and Scheme="javascript", which would slip past the
	// host-equality check below and end up reconstructed as
	// `"javascript:"`. The only acceptable schemes are the empty
	// string (pure-path relatives) and the two HTTP variants.
	scheme := strings.ToLower(parsed.Scheme)
	if scheme != "" && scheme != "http" && scheme != "https" {
		return "/"
	}
	// Fix: an http/https URL with an EMPTY host ("http:/evil.com", or the
	// opaque "http:evil.com") previously bypassed the host-equality check
	// (parsed.Host == "") and was reconstructed as "http:///evil.com",
	// which lenient browsers normalize to "http://evil.com" — an open
	// redirect. Absolute schemes must carry an explicit host.
	if scheme != "" && parsed.Host == "" {
		return "/"
	}
	// Reject anything with a host component that doesn't match the request host.
	// This catches protocol-relative (//evil.com), backslash tricks (/\evil.com
	// which some browsers normalize to //evil.com), and explicit cross-origin URLs.
	if parsed.Host != "" {
		destHost := parsed.Hostname()
		reqHost := requestHost
		if h, _, err := net.SplitHostPort(requestHost); err == nil {
			reqHost = h
		}
		if destHost != reqHost {
			return "/"
		}
	}
	// For relative paths: reject anything that doesn't start with a clean /
	if parsed.Host == "" && parsed.Scheme == "" {
		if !strings.HasPrefix(parsed.Path, "/") {
			return "/"
		}
		// Reject backslash in path (browser normalization attack)
		if strings.ContainsRune(parsed.Path, '\\') {
			return "/"
		}
	}
	// Reconstruct from parsed components to prevent raw-string injection.
	// Userinfo and Opaque are intentionally dropped.
	safe := &url.URL{
		Scheme:   parsed.Scheme,
		Host:     parsed.Host,
		Path:     parsed.Path,
		RawQuery: parsed.RawQuery,
		Fragment: parsed.Fragment,
	}
	return safe.String()
}
// verifyPoW reports whether the hex rendering of SHA256(nonce + solution)
// begins with `difficulty` zero nibbles. A non-positive difficulty always
// passes; a difficulty beyond the 64-char digest can never pass.
func verifyPoW(nonce, solution string, difficulty int) bool {
	if difficulty <= 0 {
		return true
	}
	sum := sha256.Sum256([]byte(nonce + solution))
	digest := hex.EncodeToString(sum[:])
	if difficulty > len(digest) {
		return false
	}
	return digest[:difficulty] == strings.Repeat("0", difficulty)
}
package checks
import (
"context"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// ScanAccount restricts filesystem-based checks to a single account.
// Protected by scanMu - concurrent scans are serialized to prevent scope bleed.
// RunAccountScan sets/clears it under the lock; GetScanHomeDirs reads it
// under the same lock.
var (
	ScanAccount string     // empty = scan all accounts
	scanMu      sync.Mutex // serializes account scans and guards ScanAccount
)
// RunAccountScan runs all applicable checks scoped to a single cPanel account.
// Returns findings for that account only. Does NOT trigger auto-response actions.
//
// Flow: validate the account's home dir exists, hold scanMu for the whole
// scan (ScanAccount redirects the filesystem checks to this account), run
// the checks with bounded parallelism and a per-check timeout, then filter
// the findings down to paths under /home/<account>/.
func RunAccountScan(cfg *config.Config, store *state.Store, account string) []alert.Finding {
	// Verify account exists
	homeDir := filepath.Join("/home", account)
	if _, err := osFS.Stat(homeDir); os.IsNotExist(err) {
		return []alert.Finding{{
			Severity:  alert.Warning,
			Check:     "account_scan",
			Message:   fmt.Sprintf("Account '%s' not found (no /home/%s directory)", account, account),
			Timestamp: time.Now(),
		}}
	}
	// Acquire scan lock - only one account scan at a time to prevent scope bleed
	scanMu.Lock()
	ScanAccount = account
	defer func() {
		ScanAccount = ""
		scanMu.Unlock()
	}()
	// Account-scoped checks (filesystem + account-specific)
	accountChecks := []namedCheck{
		{"webshells", CheckWebshells},
		{"htaccess", CheckHtaccess},
		{"wp_core", CheckWPCore},
		{"php_content", CheckPHPContent},
		{"phishing", CheckPhishing},
		{"filesystem", CheckFilesystem},
		{"group_writable_php", CheckGroupWritablePHP},
		{"nulled_plugins", CheckNulledPlugins},
		{"open_basedir", CheckOpenBasedir},
		{"symlink_attacks", CheckSymlinkAttacks},
		{"db_content", CheckDatabaseContent},
		{"php_config_changes", CheckPHPConfigChanges},
	}
	// Account-specific checks that need the account name
	accountChecks = append(accountChecks,
		namedCheck{"ssh_keys_account", makeAccountSSHKeyCheck(account)},
		namedCheck{"crontab_account", makeAccountCrontabCheck(account)},
		namedCheck{"backdoor_binaries", makeAccountBackdoorCheck(account)},
	)
	// Run with bounded parallelism - filesystem checks all walk the same
	// directory tree, so too many concurrent checks starve each other on
	// loaded servers with slow I/O.
	var mu sync.Mutex
	var findings []alert.Finding
	var wg sync.WaitGroup
	sem := make(chan struct{}, 4) // max 4 concurrent checks
	for _, nc := range accountChecks {
		wg.Add(1)
		go func(c namedCheck) {
			defer wg.Done()
			sem <- struct{}{}
			defer func() { <-sem }()
			ctx, cancel := context.WithTimeout(context.Background(), checkTimeout)
			// Buffered so a check finishing after the timeout can still
			// send its (discarded) result without blocking forever.
			done := make(chan []alert.Finding, 1)
			go func() {
				done <- c.fn(ctx, cfg, store)
			}()
			select {
			case results := <-done:
				cancel()
				if len(results) > 0 {
					mu.Lock()
					findings = append(findings, results...)
					mu.Unlock()
				}
			case <-ctx.Done():
				// Timed out: report it as a finding; the check goroutine
				// keeps running until c.fn returns and is then dropped.
				cancel()
				mu.Lock()
				findings = append(findings, alert.Finding{
					Severity:  alert.Warning,
					Check:     "check_timeout",
					Message:   fmt.Sprintf("Account scan check '%s' timed out", c.name),
					Timestamp: time.Now(),
				})
				mu.Unlock()
			}
		}(nc)
	}
	wg.Wait()
	// Backfill timestamps on findings the checks left unset.
	now := time.Now()
	for i := range findings {
		if findings[i].Timestamp.IsZero() {
			findings[i].Timestamp = now
		}
	}
	// Filter findings to only include this account's paths
	var filtered []alert.Finding
	accountPrefix := "/home/" + account + "/"
	for _, f := range findings {
		// Include if the finding mentions this account's path, or has no path at all
		if strings.Contains(f.Message, accountPrefix) ||
			strings.Contains(f.Details, accountPrefix) ||
			(!strings.Contains(f.Message, "/home/") && !strings.Contains(f.Details, "/home/")) {
			filtered = append(filtered, f)
		}
	}
	return filtered
}
// GetScanHomeDirs returns the home directories to scan: all of /home by
// default, or a single synthetic entry when an account-scoped scan has
// set ScanAccount (read under scanMu).
func GetScanHomeDirs() ([]os.DirEntry, error) {
	scanMu.Lock()
	target := ScanAccount
	scanMu.Unlock()
	if target == "" {
		return osFS.ReadDir("/home")
	}
	// Wrap a plain Stat result so it satisfies os.DirEntry.
	info, err := osFS.Stat(filepath.Join("/home", target))
	if err != nil {
		return nil, err
	}
	return []os.DirEntry{fakeDirEntry{info}}, nil
}
// ResolveWebRoots returns the list of directory paths CSM should scan for
// web-facing content (wp-config.php, .htaccess, public_html trees, etc.).
//
// Resolution order:
//  1. If cfg.AccountRoots is set, expand each glob and return the result.
//     Explicit config always wins.
//  2. On cPanel hosts (detected via platform.Detect), fall back to
//     /home/*/public_html for backward compatibility.
//  3. On non-cPanel hosts with no config, return an empty list. Callers
//     should treat this as "no scanning" and skip cleanly.
//
// Each returned path is an absolute directory that exists on disk,
// deduplicated across patterns.
func ResolveWebRoots(cfg *config.Config) []string {
	patterns := cfg.AccountRoots
	if len(patterns) == 0 {
		if !platform.Detect().IsCPanel() {
			return nil
		}
		patterns = []string{"/home/*/public_html"}
	}
	seen := make(map[string]struct{})
	var roots []string
	for _, pat := range patterns {
		matches, err := osFS.Glob(pat)
		if err != nil {
			continue
		}
		for _, candidate := range matches {
			if _, dup := seen[candidate]; dup {
				continue
			}
			// Only existing directories qualify.
			st, err := osFS.Stat(candidate)
			if err != nil || !st.IsDir() {
				continue
			}
			seen[candidate] = struct{}{}
			roots = append(roots, candidate)
		}
	}
	return roots
}
// fakeDirEntry wraps os.FileInfo to implement os.DirEntry.
type fakeDirEntry struct {
fi os.FileInfo
}
func (f fakeDirEntry) Name() string { return f.fi.Name() }
func (f fakeDirEntry) IsDir() bool { return f.fi.IsDir() }
func (f fakeDirEntry) Type() os.FileMode { return f.fi.Mode().Type() }
func (f fakeDirEntry) Info() (os.FileInfo, error) { return f.fi, nil }
// makeAccountSSHKeyCheck creates a check for SSH keys of a specific account.
// It compares the current hash of ~account/.ssh/authorized_keys against
// the hash previously recorded under "_ssh_user_keys:<path>".
func makeAccountSSHKeyCheck(account string) CheckFunc {
	return func(_ context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
		var findings []alert.Finding
		keyFile := filepath.Join("/home", account, ".ssh", "authorized_keys")
		// A missing or unreadable key file is not a finding.
		hash, err := hashFileContent(keyFile)
		if err != nil {
			return nil
		}
		key := fmt.Sprintf("_ssh_user_keys:%s", keyFile)
		prev, exists := store.GetRaw(key)
		// First-time observation (no stored hash) stays silent; only a
		// change from a previously recorded hash alerts.
		//
		// NOTE(review): this check never writes the new hash back, so it
		// re-fires on every account scan after a single modification —
		// presumably the full (non-account) scan owns updating the
		// "_ssh_user_keys:*" state; confirm against that code path.
		if exists && prev != hash {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "ssh_keys",
				Message:  fmt.Sprintf("User authorized_keys modified: %s", keyFile),
			})
		}
		return findings
	}
}
// makeAccountCrontabCheck creates a check that scans one account's
// crontab (/var/spool/cron/<account>) for suspicious patterns.
func makeAccountCrontabCheck(account string) CheckFunc {
	return func(_ context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
		crontabFile := filepath.Join("/var/spool/cron", account)
		// No crontab for the account is not a finding.
		raw, err := osFS.ReadFile(crontabFile)
		if err != nil {
			return nil
		}
		text := string(raw)
		var findings []alert.Finding
		for _, hit := range MatchCrontabPatternsDeep(text) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "suspicious_crontab",
				Message:  fmt.Sprintf("Suspicious pattern in crontab for %s: %s", account, hit),
				Details:  fmt.Sprintf("File: /var/spool/cron/%s\nContent:\n%s", account, text),
			})
		}
		return findings
	}
}
// makeAccountBackdoorCheck creates a check that looks for known backdoor
// binary names (gsocket family and its "defunct" disguises) under the
// account's ~/.config tree.
func makeAccountBackdoorCheck(account string) CheckFunc {
	return func(_ context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
		known := map[string]bool{
			"defunct": true, "defunct.dat": true, "gs-netcat": true,
			"gs-sftp": true, "gs-mount": true, "gsocket": true,
		}
		globs := []string{
			filepath.Join("/home", account, ".config", "htop", "*"),
			filepath.Join("/home", account, ".config", "*", "*"),
		}
		var findings []alert.Finding
		for _, g := range globs {
			hits, _ := osFS.Glob(g)
			for _, hit := range hits {
				if !known[filepath.Base(hit)] {
					continue
				}
				// Best-effort metadata; a failed Stat just leaves Details empty.
				var details string
				if st, _ := osFS.Stat(hit); st != nil {
					details = fmt.Sprintf("Size: %d bytes, Mtime: %s", st.Size(), st.ModTime().Format("2006-01-02 15:04:05"))
				}
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "backdoor_binary",
					Message:  fmt.Sprintf("Backdoor binary found: %s", hit),
					Details:  details,
				})
			}
		}
		return findings
	}
}
// LookupUID returns the numeric UID for a system account name, or -1 if
// the account does not exist or its UID field is not a plain integer.
//
// Fix: the Sscanf result was previously ignored, so a malformed UID
// string silently left uid at 0 — the root UID — instead of signalling
// failure with -1.
func LookupUID(account string) int {
	u, err := user.Lookup(account)
	if err != nil {
		return -1
	}
	uid := 0
	if n, err := fmt.Sscanf(u.Uid, "%d", &uid); n != 1 || err != nil {
		return -1
	}
	return uid
}
package checks
import (
"context"
"encoding/json"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckShadowChanges alerts when /etc/shadow changes between scans. It
// tracks the file mtime plus a per-user hash-of-hash map in the state
// store, classifies changes (root / bulk / generic), downgrades severity
// inside the configured upcp maintenance window, enriches details from
// auditd, and suppresses alerts entirely for infra-initiated changes.
// State is always updated at the end (the storeState label) so the next
// run diffs against the current file.
func CheckShadowChanges(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	info, err := osFS.Stat("/etc/shadow")
	if err != nil {
		return nil
	}
	mtime := info.ModTime()
	mtimeKey := "_shadow_mtime"
	hashKey := "_shadow_hash"
	// Load current shadow entries (user:hash pairs, no sensitive data stored)
	currentEntries := parseShadowUsers()
	currentHash := hashBytes([]byte(fmt.Sprintf("%v", currentEntries)))
	prevMtimeRaw, mtimeExists := store.GetRaw(mtimeKey)
	prevHash, hashExists := store.GetRaw(hashKey)
	// Only alert when a previous mtime exists; the first run just seeds state.
	if mtimeExists {
		var lastMtime time.Time
		if err := json.Unmarshal([]byte(prevMtimeRaw), &lastMtime); err == nil {
			if mtime.After(lastMtime) {
				// Shadow file was modified - find what changed
				var details string
				if hashExists && prevHash != currentHash {
					changed := diffShadowChanges(store, currentEntries)
					if len(changed) > 0 {
						details = fmt.Sprintf("Previous: %s\nCurrent: %s\nAccounts changed: %s",
							lastMtime.Format("2006-01-02 15:04:05"),
							mtime.Format("2006-01-02 15:04:05"),
							strings.Join(changed, ", "))
					}
				}
				// Fallback detail when only the mtime moved (content hash same).
				if details == "" {
					details = fmt.Sprintf("Previous: %s\nCurrent: %s",
						lastMtime.Format("2006-01-02 15:04:05"),
						mtime.Format("2006-01-02 15:04:05"))
				}
				// Check if within upcp window: expected nightly maintenance
				// downgrades the generic finding to a warning.
				sev := alert.Critical
				if cfg.Suppressions.UPCPWindowStart != "" {
					now := time.Now()
					h, m := now.Hour(), now.Minute()
					nowMin := h*60 + m
					start := parseTimeMin(cfg.Suppressions.UPCPWindowStart)
					end := parseTimeMin(cfg.Suppressions.UPCPWindowEnd)
					if nowMin >= start && nowMin <= end {
						sev = alert.Warning
					}
				}
				// Check auditd for who made the change
				auditInfo := getAuditShadowInfo()
				if auditInfo != "" {
					details += "\n" + auditInfo
				}
				// Suppress alerts for password changes made by infra IPs
				// (admin-initiated password resets via WHM/xml-api)
				if isInfraShadowChange(cfg) {
					// Still update state, but don't alert
					goto storeState
				}
				// Separate root password change (higher severity)
				changed := diffShadowChanges(store, currentEntries)
				rootChanged := false
				userCount := 0
				for _, c := range changed {
					if c == "root" {
						rootChanged = true
					} else {
						userCount++
					}
				}
				// A root change is reported in addition to the bulk/generic
				// finding below.
				if rootChanged {
					findings = append(findings, alert.Finding{
						Severity: alert.Critical,
						Check:    "root_password_change",
						Message:  "Root password changed",
						Details:  details,
					})
				}
				// Bulk password changes (5+ accounts at once)
				if userCount >= 5 {
					findings = append(findings, alert.Finding{
						Severity: alert.High,
						Check:    "bulk_password_change",
						Message:  fmt.Sprintf("Bulk password change: %d accounts modified", userCount),
						Details:  details,
					})
				} else {
					findings = append(findings, alert.Finding{
						Severity: sev,
						Check:    "shadow_change",
						Message:  "/etc/shadow modified",
						Details:  details,
					})
				}
			}
		}
	}
storeState:
	// Store current state so the next scan diffs against this snapshot.
	mtimeData, _ := json.Marshal(mtime)
	store.SetRaw(mtimeKey, string(mtimeData))
	store.SetRaw(hashKey, currentHash)
	// Store per-user hashes for diff next time
	for user, hash := range currentEntries {
		store.SetRaw("_shadow_user:"+user, hash)
	}
	return findings
}
// parseShadowUsers reads /etc/shadow and returns a map of user -> password hash.
// Only a digest of each password field is stored, never the crypt hash itself.
// Returns nil when /etc/shadow cannot be read.
func parseShadowUsers() map[string]string {
	raw, err := osFS.ReadFile("/etc/shadow")
	if err != nil {
		return nil
	}
	users := make(map[string]string)
	for _, record := range strings.Split(string(raw), "\n") {
		cols := strings.SplitN(record, ":", 3)
		if len(cols) < 2 || cols[0] == "" {
			continue
		}
		// Hash the password field so the real crypt hash never lands in state.
		users[cols[0]] = hashBytes([]byte(cols[1]))
	}
	return users
}
// diffShadowChanges compares current entries against stored per-user hashes.
// Users whose stored hash differs are listed by name; users with no stored
// hash at all are listed with a " (new)" suffix.
func diffShadowChanges(store *state.Store, current map[string]string) []string {
	var changed []string
	for user, hash := range current {
		stored, ok := store.GetRaw("_shadow_user:" + user)
		switch {
		case !ok:
			changed = append(changed, user+" (new)")
		case stored != hash:
			changed = append(changed, user)
		}
	}
	return changed
}
// getAuditShadowInfo checks auditd for recent shadow change events.
// It greps the audit log for the csm_shadow_change key and, from the most
// recent event line, reports the exe= and (hex-decoded when possible)
// comm= fields. Returns "" when nothing useful is found.
func getAuditShadowInfo() string {
	out, err := runCmd("grep", "csm_shadow_change", "/var/log/audit/audit.log")
	if err != nil || len(out) == 0 {
		return ""
	}
	lines := strings.Split(strings.TrimSpace(string(out)), "\n")
	if len(lines) == 0 {
		return ""
	}
	// Only the most recent event matters.
	last := lines[len(lines)-1]
	var exe, comm string
	exeFound, commFound := false, false
	for _, field := range strings.Fields(last) {
		switch {
		case !exeFound && strings.HasPrefix(field, "exe="):
			exeFound = true
			exe = strings.Trim(strings.TrimPrefix(field, "exe="), "\"")
		case !commFound && strings.HasPrefix(field, "comm="):
			commFound = true
			raw := strings.Trim(strings.TrimPrefix(field, "comm="), "\"")
			// auditd hex-encodes some comm values; fall back to the raw text.
			if decoded := decodeHexString(raw); decoded != "" {
				comm = decoded
			} else {
				comm = raw
			}
		}
	}
	if exe == "" && comm == "" {
		return ""
	}
	return fmt.Sprintf("Changed by: %s (command: %s)", exe, comm)
}
// decodeHexString tries to decode a hex-encoded string (auditd encodes some
// comm fields). Returns "" when the input is too short, has odd length,
// contains non-hex characters, or decodes to nothing before a NUL byte.
func decodeHexString(s string) string {
	if len(s) < 4 || len(s)%2 != 0 {
		return ""
	}
	// Reject anything that is not entirely hex digits.
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= '0' && c <= '9', c >= 'a' && c <= 'f', c >= 'A' && c <= 'F':
		default:
			return ""
		}
	}
	var decoded []byte
	for i := 0; i+1 < len(s); i += 2 {
		// #nosec G115 -- hexVal returns 0..15; (h<<4)|h fits in a byte.
		b := byte(hexVal(s[i])<<4 | hexVal(s[i+1]))
		if b == 0 {
			// NUL terminates the decoded value.
			break
		}
		decoded = append(decoded, b)
	}
	if len(decoded) == 0 {
		return ""
	}
	return string(decoded)
}
// CheckUID0Accounts scans /etc/passwd for UID 0 accounts outside the
// standard allow-list (root plus the classic sync/shutdown/halt/operator
// aliases) and raises a Critical finding for each one found.
func CheckUID0Accounts(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	data, err := osFS.ReadFile("/etc/passwd")
	if err != nil {
		return nil
	}
	allowed := map[string]bool{
		"root": true, "sync": true, "shutdown": true,
		"halt": true, "operator": true,
	}
	var findings []alert.Finding
	for _, entry := range strings.Split(string(data), "\n") {
		cols := strings.Split(entry, ":")
		if len(cols) < 4 {
			continue
		}
		// cols[0] = user, cols[2] = uid
		if cols[2] != "0" || allowed[cols[0]] {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "uid0_account",
			Message:  fmt.Sprintf("Unauthorized UID 0 account: %s", cols[0]),
			Details:  entry,
		})
	}
	return findings
}
// CheckSSHKeys alerts when root's or any /home user's authorized_keys file
// changes content between runs, using content hashes kept in the state store.
// Root changes are Critical; per-user changes are High.
func CheckSSHKeys(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Root's authorized_keys first.
	const rootKeys = "/root/.ssh/authorized_keys"
	if hash, err := hashFileContent(rootKeys); err == nil {
		const stateKey = "_ssh_root_keys_hash"
		if prev, ok := store.GetRaw(stateKey); ok && prev != hash {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "ssh_keys",
				Message:  "Root authorized_keys modified",
				Details:  fmt.Sprintf("File: %s", rootKeys),
			})
		}
		store.SetRaw(stateKey, hash)
	}
	// Then every user key file under /home.
	userKeyFiles, _ := osFS.Glob("/home/*/.ssh/authorized_keys")
	for _, keyFile := range userKeyFiles {
		hash, err := hashFileContent(keyFile)
		if err != nil {
			continue
		}
		stateKey := fmt.Sprintf("_ssh_user_keys:%s", keyFile)
		if prev, ok := store.GetRaw(stateKey); ok && prev != hash {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "ssh_keys",
				Message:  fmt.Sprintf("User authorized_keys modified: %s", keyFile),
			})
		}
		store.SetRaw(stateKey, hash)
	}
	return findings
}
// CheckAPITokens watches WHM root API tokens for changes and flags user
// cPanel API tokens that grant full access without an IP whitelist.
// User tokens are read straight from disk (JSON files under
// /home/<user>/.cpanel/api_tokens/) instead of spawning uapi per user.
func CheckAPITokens(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	// WHM root API tokens: any change in the listing output is Critical.
	if out, err := runCmd("whmapi1", "api_token_list"); err == nil {
		const stateKey = "_whm_api_tokens_hash"
		hash := hashBytes(out)
		if prev, ok := store.GetRaw(stateKey); ok && prev != hash {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "api_tokens",
				Message:  "WHM root API tokens changed",
				Details:  "Run 'whmapi1 api_token_list' to review",
			})
		}
		store.SetRaw(stateKey, hash)
	}
	// containsAny reports whether any needle occurs in haystack.
	containsAny := func(haystack string, needles ...string) bool {
		for _, n := range needles {
			if strings.Contains(haystack, n) {
				return true
			}
		}
		return false
	}
	tokenDirs, _ := osFS.Glob("/home/*/.cpanel/api_tokens")
	for _, dir := range tokenDirs {
		// /home/<user>/.cpanel/api_tokens -> <user>
		owner := filepath.Base(filepath.Dir(filepath.Dir(dir)))
		files, _ := osFS.Glob(filepath.Join(dir, "*"))
		for _, file := range files {
			raw, err := osFS.ReadFile(file)
			if err != nil {
				continue
			}
			body := string(raw)
			// Full access with no IP whitelist (absent, null, or empty list).
			fullAccess := containsAny(body, `"has_full_access":1`, `"has_full_access": 1`)
			unrestricted := !strings.Contains(body, "whitelist_ips") ||
				containsAny(body,
					`"whitelist_ips":null`, `"whitelist_ips": null`,
					`"whitelist_ips":[]`, `"whitelist_ips": []`)
			if !fullAccess || !unrestricted {
				continue
			}
			name := filepath.Base(file)
			suppressed := false
			for _, known := range cfg.Suppressions.KnownAPITokens {
				if name == known {
					suppressed = true
					break
				}
			}
			if suppressed {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "api_tokens",
				Message:  fmt.Sprintf("User %s has full-access API token '%s' with no IP whitelist", owner, name),
				Details:  fmt.Sprintf("File: %s", file),
			})
		}
	}
	return findings
}
// isInfraShadowChange checks the cPanel session log for recent password
// change PURGE events from infra IPs. This is more reliable than the
// access log because it captures the actual IP that triggered the change.
//
// Returns true only if ALL recent password purges came from infra IPs.
// If any came from non-infra → returns false (possible attack).
func isInfraShadowChange(cfg *config.Config) bool {
	entries := tailFile("/usr/local/cpanel/logs/session_log", 100)
	sawPurge := false
	onlyInfra := true
	// Walk newest-first over PURGE password_change events.
	for i := len(entries) - 1; i >= 0; i-- {
		entry := entries[i]
		if !strings.Contains(entry, "PURGE") || !strings.Contains(entry, "password_change") {
			continue
		}
		sawPurge = true
		// Extract the source IP. Example formats:
		//   [timestamp] info [xml-api] IP PURGE account:token password_change
		//   [timestamp] info [whostmgr] IP PURGE account:token password_change
		ip := ""
		for _, tag := range []string{"[xml-api]", "[whostmgr]", "[security]"} {
			pos := strings.Index(entry, tag)
			if pos < 0 {
				continue
			}
			after := strings.Fields(strings.TrimSpace(entry[pos+len(tag):]))
			if len(after) > 0 {
				ip = after[0]
			}
			break
		}
		if ip == "" || ip == "internal" {
			continue // internal events are safe
		}
		if ip != "127.0.0.1" && !isInfraIP(ip, cfg.InfraIPs) {
			onlyInfra = false
			break // a non-infra password change - don't suppress
		}
	}
	return sawPurge && onlyInfra
}
// parseTimeMin parses a clock time in "HH:MM" form and returns it as
// minutes after midnight. Malformed input (wrong field count, empty or
// non-digit fields) yields 0, which effectively disables the suppression
// window it describes.
//
// The previous fmt.Sscanf implementation ignored parse errors, so input
// like "1x:30" silently became 90 and "ab:30" became 30; digits are now
// validated strictly and any malformed value returns 0.
func parseTimeMin(s string) int {
	parts := strings.Split(s, ":")
	if len(parts) != 2 {
		return 0
	}
	// atoi parses a non-empty, all-digit string; ok is false otherwise.
	atoi := func(t string) (int, bool) {
		if t == "" {
			return 0, false
		}
		n := 0
		for _, c := range t {
			if c < '0' || c > '9' {
				return 0, false
			}
			n = n*10 + int(c-'0')
		}
		return n, true
	}
	h, okH := atoi(parts[0])
	m, okM := atoi(parts[1])
	if !okH || !okM {
		return 0
	}
	return h*60 + m
}
package checks
import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
const (
	// maxBlocksPerHour caps how many individual IPs AutoBlockIPs will block
	// in one clock hour; overflow goes to the pending queue instead.
	maxBlocksPerHour = 50
	// defaultBlockExpiry is used when the configured block expiry is empty.
	defaultBlockExpiry = "24h"
	// blockStateFile is the block-state JSON file name under cfg.StatePath.
	blockStateFile = "blocked_ips.json"
)
// IPBlocker abstracts the firewall engine for auto-blocking.
// When set, blocks go through nftables firewall engine.
type IPBlocker interface {
	// BlockIP blocks a single IP for the given timeout with a
	// human-readable reason; a zero timeout is used for permanent blocks.
	BlockIP(ip string, reason string, timeout time.Duration) error
	// UnblockIP removes an existing block for ip.
	UnblockIP(ip string) error
	// IsBlocked reports whether the engine currently blocks ip.
	IsBlocked(ip string) bool
}
// fwBlocker is the active firewall engine; nil until SetIPBlocker is called,
// in which case auto-block actions are skipped.
var fwBlocker IPBlocker

// blockStateMu serializes AutoBlockIPs' load-modify-save of the block state.
var blockStateMu sync.Mutex
// SetIPBlocker sets the firewall engine for auto-blocking.
// Until this is called, AutoBlockIPs logs "firewall engine not available"
// and skips actual blocking.
func SetIPBlocker(b IPBlocker) {
	fwBlocker = b
}
// blockedIP is one active block entry persisted in the block-state file.
type blockedIP struct {
	IP        string    `json:"ip"`
	Reason    string    `json:"reason"` // finding message that triggered the block
	BlockedAt time.Time `json:"blocked_at"`
	ExpiresAt time.Time `json:"expires_at"` // informational; the engine enforces expiry natively
}
// pendingIP is an IP that hit the hourly rate limit and is queued to be
// blocked on a later AutoBlockIPs cycle.
type pendingIP struct {
	IP     string `json:"ip"`
	Reason string `json:"reason"`
}
// blockState is the JSON document persisted to blockStateFile.
type blockState struct {
	IPs            []blockedIP `json:"ips"`
	Pending        []pendingIP `json:"pending,omitempty"` // IPs waiting for rate-limit reset
	BlocksThisHour int         `json:"blocks_this_hour"`
	HourKey        string      `json:"hour_key"` // "2006-01-02T15" key of the hour being counted
}
// AutoBlockIPs processes findings and blocks attacker IPs via the firewall engine.
// Note: this should be called with ALL findings (not just new ones)
// for reputation-based blocking to work on repeat offenders.
//
// The whole load/prune/block/save sequence runs under blockStateMu so the
// read-modify-write of the persisted block state cannot interleave.
// Returns the "auto_block" action findings generated this cycle.
func AutoBlockIPs(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.BlockIPs {
		return nil
	}
	blockStateMu.Lock()
	defer blockStateMu.Unlock()
	var actions []alert.Finding
	// Load block state
	state := loadBlockState(cfg.StatePath)
	// Prune IPs that the firewall engine no longer has blocked.
	// The engine handles expiry natively via nftables timeouts -
	// we just sync our state to match.
	var stillBlocked []blockedIP
	for _, b := range state.IPs {
		if fwBlocker != nil {
			if !fwBlocker.IsBlocked(b.IP) {
				// Engine expired this block - clean up our state
				fmt.Fprintf(os.Stderr, "[%s] AUTO-UNBLOCK: %s removed (engine expired)\n", time.Now().Format("2006-01-02 15:04:05"), b.IP)
				continue
			}
		}
		stillBlocked = append(stillBlocked, b)
	}
	state.IPs = stillBlocked
	// Check rate limit (counter resets whenever the clock hour changes)
	currentHour := time.Now().Format("2006-01-02T15")
	if state.HourKey != currentHour {
		state.HourKey = currentHour
		state.BlocksThisHour = 0
	}
	// Collect IPs to block from findings
	ipsToBlock := make(map[string]string) // ip -> reason
	// Always blockable (brute force, C2, known malicious)
	alwaysBlock := map[string]bool{
		"wp_login_bruteforce":         true,
		"xmlrpc_abuse":                true,
		"ftp_bruteforce":              true,
		"smtp_bruteforce":             true,
		"mail_bruteforce":             true,
		"mail_account_compromised":    true,
		"admin_panel_bruteforce":      true,
		"ssh_login_unknown_ip":        true,
		"ssh_login_realtime":          true,
		"c2_connection":               true,
		"ip_reputation":               true,
		"local_threat_score":          true,
		"modsec_csm_block_escalation": true,
		"email_compromised_account":   true,
	}
	// Only blockable when block_cpanel_logins is enabled (disabled by default)
	cpanelWebmailChecks := map[string]bool{
		"cpanel_login":                true,
		"cpanel_login_realtime":       true,
		"cpanel_multi_ip_login":       true,
		"cpanel_file_upload_realtime": true,
		"api_auth_failure":            true,
		"api_auth_failure_realtime":   true,
		"webmail_bruteforce":          true,
		"webmail_login_realtime":      true,
		"ftp_login_realtime":          true,
		"ftp_auth_failure_realtime":   true,
	}
	// Drain pending queue first (IPs from prior rate-limited cycles)
	for _, p := range state.Pending {
		if !isAlreadyBlocked(state, p.IP) {
			ipsToBlock[p.IP] = p.Reason
		}
	}
	state.Pending = nil
	// Subnet fast-path: checks that represent a subnet directly.
	// Independent of the per-IP rate limit, because a single subnet block
	// replaces what would otherwise be hundreds of per-IP blocks.
	for _, f := range findings {
		if f.Check != "smtp_subnet_spray" && f.Check != "mail_subnet_spray" {
			continue
		}
		cidr := extractCIDRFromFinding(f)
		if cidr == "" {
			continue
		}
		if fwBlocker == nil {
			fmt.Fprintf(os.Stderr, "auto-block: firewall engine not available, skipping subnet %s\n", cidr)
			continue
		}
		// Subnet support is optional on the engine; probe via interface assertion.
		sb, ok := fwBlocker.(interface {
			BlockSubnet(string, string, time.Duration) error
		})
		if !ok {
			fmt.Fprintf(os.Stderr, "auto-block: firewall engine does not support subnet blocking, skipping %s\n", cidr)
			continue
		}
		reason := fmt.Sprintf("CSM auto-block (subnet): %s", truncate(f.Message, 100))
		if err := sb.BlockSubnet(cidr, reason, parseExpiry(cfg.AutoResponse.BlockExpiry)); err != nil {
			fmt.Fprintf(os.Stderr, "auto-block: error blocking subnet %s: %v\n", cidr, err)
			continue
		}
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_block",
			Message:   fmt.Sprintf("AUTO-BLOCK-SUBNET: %s blocked", cidr),
			Details:   fmt.Sprintf("Reason: %s", f.Message),
			Timestamp: time.Now(),
		})
	}
	for _, f := range findings {
		isBlockable := alwaysBlock[f.Check]
		if !isBlockable && cfg.AutoResponse.BlockCpanelLogins && cpanelWebmailChecks[f.Check] {
			isBlockable = true
		}
		if !isBlockable {
			continue
		}
		ip := extractIPFromFinding(f)
		if ip == "" {
			continue
		}
		// Never block infra IPs
		if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
			continue
		}
		// Don't re-block already blocked IPs
		if isAlreadyBlocked(state, ip) {
			continue
		}
		// Skip IPs on the challenge list (they'll be challenged, not blocked)
		if cl := GetChallengeIPList(); cl != nil && cl.Contains(ip) {
			continue
		}
		ipsToBlock[ip] = f.Message
	}
	// Block IPs - queue any that can't be blocked due to rate limit
	expiry := parseExpiry(cfg.AutoResponse.BlockExpiry)
	rateLimited := false
	for ip, reason := range ipsToBlock {
		if state.BlocksThisHour >= maxBlocksPerHour {
			// Queue for next cycle instead of dropping
			state.Pending = append(state.Pending, pendingIP{IP: ip, Reason: reason})
			rateLimited = true
			continue
		}
		// Block via firewall engine (nftables)
		blockReason := fmt.Sprintf("CSM auto-block: %s", truncate(reason, 100))
		if fwBlocker == nil {
			fmt.Fprintf(os.Stderr, "auto-block: firewall engine not available, skipping %s\n", ip)
			continue
		}
		if err := fwBlocker.BlockIP(ip, blockReason, expiry); err != nil {
			fmt.Fprintf(os.Stderr, "auto-block: error blocking %s: %v\n", ip, err)
			continue
		}
		state.BlocksThisHour++
		// Add to permanent local threat database
		if db := GetThreatDB(); db != nil {
			db.AddPermanent(ip, reason)
		}
		state.IPs = append(state.IPs, blockedIP{
			IP:        ip,
			Reason:    reason,
			BlockedAt: time.Now(),
			ExpiresAt: time.Now().Add(expiry),
		})
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_block",
			Message:   fmt.Sprintf("AUTO-BLOCK: %s blocked (expires in %s)", ip, expiry),
			Details:   fmt.Sprintf("Reason: %s", reason),
			Timestamp: time.Now(),
		})
		// Permanent block escalation: promote to permanent after N temp blocks
		if cfg.AutoResponse.PermBlock && fwBlocker != nil {
			count := cfg.AutoResponse.PermBlockCount
			if count < 2 {
				count = 4
			}
			interval := parseExpiry(cfg.AutoResponse.PermBlockInterval)
			if interval == 0 {
				interval = 24 * time.Hour
			}
			if checkPermBlockEscalation(cfg.StatePath, ip, count, interval) {
				permReason := fmt.Sprintf("PERMBLOCK: %d temp blocks within %s", count, interval)
				// timeout 0 = permanent block
				if err := fwBlocker.BlockIP(ip, permReason, 0); err == nil {
					actions = append(actions, alert.Finding{
						Severity:  alert.Critical,
						Check:     "auto_block",
						Message:   fmt.Sprintf("AUTO-PERMBLOCK: %s promoted to permanent block (%d temp blocks)", ip, count),
						Timestamp: time.Now(),
					})
				}
			}
		}
	}
	if rateLimited {
		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_block",
			Message:   fmt.Sprintf("Auto-block rate limit reached (%d/hour), %d IPs queued for next cycle", maxBlocksPerHour, len(state.Pending)),
			Timestamp: time.Now(),
		})
	}
	// Subnet auto-blocking: detect /24 patterns
	if cfg.AutoResponse.NetBlock && fwBlocker != nil {
		threshold := cfg.AutoResponse.NetBlockThreshold
		if threshold < 2 {
			threshold = 3
		}
		// Count blocked IPs per /24
		subnetCounts := make(map[string]int)
		subnetBlocked := make(map[string]bool)
		for _, b := range state.IPs {
			prefix := extractPrefix24(b.IP)
			if prefix != "" {
				subnetCounts[prefix]++
			}
		}
		for prefix, count := range subnetCounts {
			if count >= threshold && !subnetBlocked[prefix] {
				cidr := prefix + ".0/24"
				if sb, ok := fwBlocker.(interface {
					BlockSubnet(string, string, time.Duration) error
				}); ok {
					reason := fmt.Sprintf("Auto-netblock: %d IPs from %s", count, cidr)
					if err := sb.BlockSubnet(cidr, reason, 0); err == nil {
						subnetBlocked[prefix] = true
						actions = append(actions, alert.Finding{
							Severity:  alert.Critical,
							Check:     "auto_block",
							Message:   fmt.Sprintf("AUTO-NETBLOCK: %s blocked (%d IPs from same /24)", cidr, count),
							Timestamp: time.Now(),
						})
					}
				}
			}
		}
	}
	// Save state (expired IPs were already pruned at the top of this function)
	saveBlockState(cfg.StatePath, state)
	return actions
}
// ExtractIPFromFinding extracts an IP address from a finding message.
// Exported wrapper around extractIPFromFinding for use by other packages.
func ExtractIPFromFinding(f alert.Finding) string {
	return extractIPFromFinding(f)
}
// extractIPFromFinding pulls a usable IP out of a finding message.
// It scans for the RIGHTMOST " from " / ": " separator because
// log-injected content tends to appear earlier in the message, while the
// structurally-parsed IP from the log parser appears at the end.
// Loopback and unspecified addresses are rejected - never block those.
func extractIPFromFinding(f alert.Finding) string {
	msg := f.Message
	for _, sep := range []string{" from ", ": "} {
		pos := strings.LastIndex(msg, sep)
		if pos < 0 {
			continue
		}
		tokens := strings.Fields(msg[pos+len(sep):])
		if len(tokens) == 0 {
			continue
		}
		candidate := strings.TrimRight(tokens[0], ",:;)([]")
		parsed := net.ParseIP(candidate)
		if parsed == nil || parsed.IsLoopback() || parsed.IsUnspecified() {
			continue
		}
		return parsed.String()
	}
	return ""
}
// isAlreadyBlocked reports whether ip already appears in the block state.
func isAlreadyBlocked(state *blockState, ip string) bool {
	for i := range state.IPs {
		if state.IPs[i].IP == ip {
			return true
		}
	}
	return false
}
// parseExpiry converts a duration string from config into a time.Duration.
// An empty string means defaultBlockExpiry; an unparseable one falls back
// to 24 hours.
func parseExpiry(s string) time.Duration {
	spec := s
	if spec == "" {
		spec = defaultBlockExpiry
	}
	if d, err := time.ParseDuration(spec); err == nil {
		return d
	}
	return 24 * time.Hour
}
// loadBlockState reads the persisted block state from statePath.
// A missing or unreadable file yields a zero-valued state; unmarshal
// errors are deliberately ignored (best-effort state).
func loadBlockState(statePath string) *blockState {
	st := &blockState{}
	if data, err := osFS.ReadFile(filepath.Join(statePath, blockStateFile)); err == nil {
		_ = json.Unmarshal(data, st)
	}
	return st
}
// saveBlockState persists the block state atomically: write to a .tmp
// sidecar, then rename over the real file. Errors are ignored (best-effort).
func saveBlockState(statePath string, state *blockState) {
	final := filepath.Join(statePath, blockStateFile)
	payload, _ := json.MarshalIndent(state, "", " ")
	tmp := final + ".tmp"
	_ = os.WriteFile(tmp, payload, 0600)
	_ = os.Rename(tmp, final)
}
// PendingBlockIPs returns IPs queued for blocking (rate-limited).
// Used by alert.FilterBlockedAlerts to suppress reputation alerts for these IPs.
func PendingBlockIPs(statePath string) map[string]bool {
	st := loadBlockState(statePath)
	pending := make(map[string]bool, len(st.Pending))
	for _, entry := range st.Pending {
		pending[entry.IP] = true
	}
	return pending
}
// extractPrefix24 returns the first 3 octets of an IPv4 address (e.g. "1.2.3").
// Anything without exactly four dot-separated parts (e.g. IPv6) yields "".
func extractPrefix24(ip string) string {
	if strings.Count(ip, ".") != 3 {
		return ""
	}
	return ip[:strings.LastIndex(ip, ".")]
}
// --- Permanent block escalation (LF_PERMBLOCK) ---
// permBlockTracker records temp-block timestamps per IP so repeat offenders
// can be promoted to a permanent block (see checkPermBlockEscalation).
// Persisted as permblock_tracker.json under the state path.
type permBlockTracker struct {
	IPs map[string][]time.Time `json:"ips"` // IP -> list of block timestamps
}
// checkPermBlockEscalation records a new block and returns true if the IP
// has been temp-blocked count times within interval. Old timestamps are
// pruned, and IPs idle for more than 2x the interval are dropped entirely
// before the tracker is saved back to disk.
func checkPermBlockEscalation(statePath, ip string, count int, interval time.Duration) bool {
	tracker := loadPermBlockTracker(statePath)
	now := time.Now()
	cutoff := now.Add(-interval)
	// Record this block, then keep only timestamps inside the window.
	stamps := append(tracker.IPs[ip], now)
	kept := make([]time.Time, 0, len(stamps))
	for _, ts := range stamps {
		if ts.After(cutoff) {
			kept = append(kept, ts)
		}
	}
	tracker.IPs[ip] = kept
	// Drop IPs with no entries, or whose last block is older than 2x interval.
	for key, times := range tracker.IPs {
		if len(times) == 0 {
			delete(tracker.IPs, key)
			continue
		}
		if now.Sub(times[len(times)-1]) > 2*interval {
			delete(tracker.IPs, key)
		}
	}
	savePermBlockTracker(statePath, tracker)
	return len(kept) >= count
}
// loadPermBlockTracker reads permblock_tracker.json from statePath,
// returning an empty tracker when the file is missing or unreadable.
func loadPermBlockTracker(statePath string) *permBlockTracker {
	t := &permBlockTracker{IPs: make(map[string][]time.Time)}
	data, err := osFS.ReadFile(filepath.Join(statePath, "permblock_tracker.json"))
	if err != nil {
		return t
	}
	_ = json.Unmarshal(data, t)
	if t.IPs == nil {
		// Unmarshaling `{"ips":null}` leaves the map nil; restore it.
		t.IPs = make(map[string][]time.Time)
	}
	return t
}
// savePermBlockTracker writes the tracker atomically (tmp file + rename).
// Errors are ignored: escalation tracking is best-effort.
func savePermBlockTracker(statePath string, tracker *permBlockTracker) {
	final := filepath.Join(statePath, "permblock_tracker.json")
	payload, _ := json.MarshalIndent(tracker, "", " ")
	tmp := final + ".tmp"
	_ = os.WriteFile(tmp, payload, 0600)
	_ = os.Rename(tmp, final)
}
// extractCIDRFromFinding returns the CIDR appearing in the message after
// the canonical " from " separator. Returns "" if the value does not parse
// as a CIDR.
func extractCIDRFromFinding(f alert.Finding) string {
	const sep = " from "
	pos := strings.LastIndex(f.Message, sep)
	if pos < 0 {
		return ""
	}
	tokens := strings.Fields(f.Message[pos+len(sep):])
	if len(tokens) == 0 {
		return ""
	}
	trimmed := strings.TrimRight(tokens[0], ",:;)([]")
	_, network, err := net.ParseCIDR(trimmed)
	if err != nil {
		return ""
	}
	// Return the canonicalized network, not the raw candidate.
	return network.String()
}
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// quarantineDir is the destination directory for AutoQuarantineFiles.
// var (not const) so tests can redirect to t.TempDir().
var quarantineDir = "/opt/csm/quarantine"
// QuarantineMeta stores original file metadata alongside quarantined files.
// It is written as a JSON ".meta" sidecar next to each quarantined file,
// preserving what is needed to later restore the file to its original
// location with its original ownership and mode.
type QuarantineMeta struct {
	OriginalPath string    `json:"original_path"` // absolute path the file was taken from
	Owner        int       `json:"owner_uid"`
	Group        int       `json:"group_gid"`
	Mode         string    `json:"mode"` // os.FileMode string form
	Size         int64     `json:"size"`
	QuarantineAt time.Time `json:"quarantined_at"`
	Reason       string    `json:"reason"` // the finding message that triggered quarantine
}
// AutoKillProcesses kills processes that match critical findings.
// Only targets: fake kernel threads, reverse shells, GSocket processes.
// Never kills root system services or cPanel processes.
//
// Returns one "auto_response" action finding per process killed.
// Fix: the action finding now sets Timestamp, matching every other
// auto-response/auto-block action (it was previously left at zero time).
func AutoKillProcesses(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.KillProcesses {
		return nil
	}
	var actions []alert.Finding
	for _, f := range findings {
		// Only act on specific high-confidence critical checks
		switch f.Check {
		case "fake_kernel_thread", "suspicious_process", "php_suspicious_execution":
		default:
			continue
		}
		if f.Severity != alert.Critical {
			continue
		}
		// Use structured PID field when available, fall back to text extraction
		pid := fmt.Sprintf("%d", f.PID)
		if f.PID == 0 {
			pid = extractPID(f.Details)
			if pid == "" {
				continue
			}
		}
		// Safety: verify the process is not root/system
		uid := getProcessUID(pid)
		if uid == "0" || uid == "" {
			continue // never kill root processes automatically
		}
		// Safety: verify it's not a cPanel/system process
		exe := getProcessExe(pid)
		if isSafeProcess(exe) {
			continue
		}
		// Kill it
		pidInt := f.PID
		if pidInt == 0 {
			fmt.Sscanf(pid, "%d", &pidInt)
		}
		if pidInt <= 1 {
			continue // never signal init or an unparsed PID
		}
		err := syscall.Kill(pidInt, syscall.SIGKILL)
		if err != nil {
			continue
		}
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-KILL: Process %s killed (was: %s)", pid, f.Check),
			Details:   fmt.Sprintf("Original finding: %s\nProcess: %s (UID: %s)", f.Message, exe, uid),
			Timestamp: time.Now(),
		})
	}
	return actions
}
// AutoQuarantineFiles moves malicious files to quarantine directory.
// Preserves original path and metadata in a sidecar .meta file.
//
// Per finding: validate check type and severity, resolve the file path,
// surgically clean WP core/plugin/theme files in place when possible,
// otherwise move the file (or directory) into quarantineDir and write a
// QuarantineMeta sidecar. Returns "auto_response" action findings.
// Fix: the AUTO-QUARANTINE action now sets Timestamp, matching the
// AUTO-CLEAN actions emitted by this same function.
func AutoQuarantineFiles(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.QuarantineFiles {
		return nil
	}
	var actions []alert.Finding
	for _, f := range findings {
		// Only quarantine specific file-based findings
		isRealtimeMatch := false
		switch f.Check {
		case "webshell", "backdoor_binary", "new_webshell_file", "new_executable_in_config",
			"obfuscated_php", "php_dropper", "suspicious_php_content",
			"new_php_in_languages", "new_php_in_upgrade",
			"phishing_page", "phishing_directory",
			"htaccess_handler_abuse":
		case "signature_match_realtime":
			isRealtimeMatch = true
		default:
			continue
		}
		if f.Severity != alert.Critical {
			continue
		}
		// Extract file path - prefer structured field, fallback to message parsing
		path := f.FilePath
		if path == "" {
			path = extractFilePath(f.Message) // fallback for legacy findings
		}
		if path == "" {
			continue
		}
		// Realtime signature matches require additional validation to avoid
		// quarantining false positives (e.g. legitimate PHPMailer matching
		// "webshell_marijuana", or zip libraries matching hex patterns).
		// Only quarantine when the file is genuinely obfuscated malware.
		if isRealtimeMatch && !isHighConfidenceRealtimeMatch(f, path, nil) {
			continue
		}
		// Verify file or directory exists and reject symlinks
		info, err := osFS.Lstat(path)
		if err != nil {
			continue
		}
		if info.Mode()&os.ModeSymlink != 0 {
			continue
		}
		// Realtime high-confidence matches are fully obfuscated malware -
		// there is no legitimate code to preserve, skip cleaning and go
		// straight to quarantine.
		if isRealtimeMatch {
			goto quarantine
		}
		// For WP core/plugin/theme files: clean surgically instead of quarantining.
		// This preserves site functionality while removing the injected code.
		if !info.IsDir() && ShouldCleanInsteadOfQuarantine(path) {
			result := CleanInfectedFile(path)
			switch {
			case result.Cleaned:
				actions = append(actions, alert.Finding{
					Severity:  alert.Critical,
					Check:     "auto_response",
					Message:   fmt.Sprintf("AUTO-CLEAN: %s surgically cleaned", path),
					Details:   fmt.Sprintf("Backup: %s\n%s", result.BackupPath, strings.Join(result.Removals, "\n")),
					Timestamp: time.Now(),
				})
				continue // successfully cleaned, skip quarantine
			case result.Error != "":
				// Cleaning failed - fall through to quarantine
				actions = append(actions, alert.Finding{
					Severity:  alert.Warning,
					Check:     "auto_response",
					Message:   fmt.Sprintf("AUTO-CLEAN failed for %s, quarantining instead", path),
					Details:   result.Error,
					Timestamp: time.Now(),
				})
				// Don't continue - fall through to quarantine below
			default:
				continue // no changes needed
			}
		}
	quarantine:
		// Create quarantine directory
		_ = os.MkdirAll(quarantineDir, 0700)
		// Build quarantine destination preserving directory structure
		safeName := strings.ReplaceAll(path, "/", "_")
		ts := time.Now().Format("20060102-150405")
		qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_%s", ts, safeName))
		// Get file ownership
		var uid, gid int
		if stat, ok := info.Sys().(*syscall.Stat_t); ok {
			uid = int(stat.Uid)
			gid = int(stat.Gid)
		}
		// Handle directory quarantine (e.g., LEVIATHAN/ webshell directories)
		if info.IsDir() {
			if err := os.Rename(path, qPath); err != nil {
				// Cross-device: skip directory move (too complex for auto-response)
				continue
			}
		} else {
			// Move file to quarantine
			if err := os.Rename(path, qPath); err != nil {
				// If rename fails (cross-device), copy and delete
				data, readErr := osFS.ReadFile(path)
				if readErr != nil {
					continue
				}
				if writeErr := os.WriteFile(qPath, data, 0600); writeErr != nil {
					continue
				}
				if rmErr := os.Remove(path); rmErr != nil {
					// Remove failed - delete the copy to avoid duplication
					os.Remove(qPath)
					fmt.Fprintf(os.Stderr, "autoresponse: cross-device quarantine failed, cannot remove original %s: %v\n", path, rmErr)
					continue
				}
			}
		}
		// Write metadata sidecar
		meta := QuarantineMeta{
			OriginalPath: path,
			Owner:        uid,
			Group:        gid,
			Mode:         info.Mode().String(),
			Size:         info.Size(),
			QuarantineAt: time.Now(),
			Reason:       f.Message,
		}
		metaData, _ := json.MarshalIndent(meta, "", " ")
		_ = os.WriteFile(qPath+".meta", metaData, 0600)
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-QUARANTINE: %s moved to quarantine", path),
			Details:   fmt.Sprintf("Quarantined to: %s\nOriginal finding: %s", qPath, f.Message),
			Timestamp: time.Now(),
		})
	}
	return actions
}
// AutoFixPermissions sets world/group-writable PHP files to 0644.
// Returns the auto-response action findings and the keys of original findings
// that were successfully fixed (so the caller can dismiss them from the UI).
func AutoFixPermissions(cfg *config.Config, findings []alert.Finding) (actions []alert.Finding, fixedKeys []string) {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.EnforcePermissions {
		return nil, nil
	}
	for _, f := range findings {
		if f.Check != "world_writable_php" && f.Check != "group_writable_php" {
			continue
		}
		target := extractFilePath(f.Message)
		if target == "" {
			continue
		}
		target, info, err := resolveExistingFixPath(target, fixPermissionsAllowedRoots)
		if err != nil || info.IsDir() {
			continue
		}
		prevMode := info.Mode().Perm()
		// #nosec G302 -- same as fixPermissions: restoring canonical web-content
		// mode on a user file flagged for dangerous (e.g. world-writable) perms.
		if err := os.Chmod(target, 0644); err != nil {
			continue
		}
		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-FIX: %s permissions set to 644 (was %o)", target, prevMode),
			Timestamp: time.Now(),
		})
		fixedKeys = append(fixedKeys, f.Check+":"+f.Message)
	}
	return actions, fixedKeys
}
// extractPID returns the token following a "PID: " marker in details.
// The value ends at the first space, comma, tab, or newline so trailing
// words ("PID: 42 exe=/bin/ls") are not swallowed into the PID string.
// Returns "" when no marker is present.
func extractPID(details string) string {
	const marker = "PID: "
	pos := strings.Index(details, marker)
	if pos < 0 {
		return ""
	}
	tail := details[pos+len(marker):]
	if end := strings.IndexAny(tail, " ,\n\t"); end >= 0 {
		tail = tail[:end]
	}
	return strings.TrimSpace(tail)
}
// extractFilePath finds a /home, /tmp, /var/tmp or /dev/shm path embedded
// in a finding message. Prefix order matters: the longer/more-specific
// prefixes (/var/tmp/, /dev/shm/) must come BEFORE /tmp/, otherwise
// "/tmp/" would match inside "/var/tmp/" and misclassify the path.
// The path ends at the first space, comma, or newline.
func extractFilePath(message string) string {
	for _, prefix := range []string{"/var/tmp/", "/dev/shm/", "/home/", "/tmp/"} {
		start := strings.Index(message, prefix)
		if start < 0 {
			continue
		}
		tail := message[start:]
		if end := strings.IndexAny(tail, " ,\n"); end >= 0 {
			tail = tail[:end]
		}
		return tail
	}
	return ""
}
// getProcessUID returns the real UID of pid as a string, read from
// /proc/<pid>/status ("Uid:" line, first column). Returns "" when the
// process cannot be read or the line is absent.
func getProcessUID(pid string) string {
	data, err := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
	if err != nil {
		return ""
	}
	const prefix = "Uid:\t"
	for _, line := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(line, prefix) {
			continue
		}
		if cols := strings.Fields(line[len(prefix):]); len(cols) > 0 {
			return cols[0]
		}
	}
	return ""
}
// getProcessExe returns the executable path of pid via the
// /proc/<pid>/exe symlink, or "" when it cannot be resolved.
func getProcessExe(pid string) string {
	target, err := osFS.Readlink(filepath.Join("/proc", pid, "exe"))
	if err != nil {
		return ""
	}
	return target
}
// isSafeProcess reports whether exe lives under a system/cPanel directory
// that auto-kill must never touch.
func isSafeProcess(exe string) bool {
	for _, prefix := range [...]string{
		"/usr/local/cpanel/",
		"/usr/sbin/",
		"/usr/bin/",
		"/usr/libexec/",
		"/opt/cpanel/",
		"/opt/cloudlinux/",
		"/opt/imunify360/",
	} {
		if strings.HasPrefix(exe, prefix) {
			return true
		}
	}
	return false
}
// isHighConfidenceRealtimeMatch validates whether a realtime signature match
// is truly malicious and safe to auto-quarantine. Prevents false positives
// on legitimate libraries (PHPMailer, zip) and theme code.
//
// The data parameter should be the file content already read by the caller
// (fanotify fd or scanner) to avoid TOCTOU re-reads. Pass nil to read from path.
//
// Criteria:
//  1. Category must be "dropper" or "webshell"
//  2. File must not be in a known library path
//  3. File must be >= 512 bytes (entropy unreliable below this)
//  4. Content must show obfuscation indicators:
//     Shannon entropy >= 4.8 OR hex density > 20%.
//     This applies to BOTH dropper and webshell categories to avoid
//     false-positive quarantine of legitimate plugins that happen to
//     match a dropper rule (e.g. curl_exec + eval on distant lines).
func isHighConfidenceRealtimeMatch(f alert.Finding, path string, data []byte) bool {
	switch extractCategory(f.Details) {
	case "dropper", "webshell":
	default:
		return false
	}
	lower := strings.ToLower(path)
	for _, fragment := range knownLibraryPaths {
		if strings.Contains(lower, fragment) {
			return false
		}
	}
	content := data
	if content == nil {
		var err error
		content, err = osFS.ReadFile(path)
		if err != nil {
			return false
		}
	}
	if len(content) < 512 {
		return false
	}
	// Legitimate plugins sit around entropy ~4.2 and pass through; real
	// obfuscated malware (~5.5+) or hex-heavy payloads get quarantined.
	text := string(content)
	if shannonEntropy(text) >= 4.8 {
		return true
	}
	return hexEncodingDensity(text) > 0.20
}
// hexEncodingDensity returns the fraction of a string's bytes that are part
// of PHP hex escape sequences (\xNN). LEVIATHAN AES-encrypted webshells
// encode their payload as long hex strings - the \x prefix repeats so
// frequently that Shannon entropy drops to ~3.5 (below normal PHP), but
// the hex density reaches 40-60%.
func hexEncodingDensity(s string) float64 {
	total := len(s)
	if total == 0 {
		return 0
	}
	counted := 0
	for i := 0; i+3 < total; {
		// A sequence is exactly: backslash, 'x', two hex digits.
		if s[i] == '\\' && s[i+1] == 'x' && isHexDigit(s[i+2]) && isHexDigit(s[i+3]) {
			counted += 4
			i += 4 // jump over the whole escape
			continue
		}
		i++
	}
	return float64(counted) / float64(total)
}

// isHexDigit reports whether b is an ASCII hexadecimal digit (either case).
func isHexDigit(b byte) bool {
	switch {
	case b >= '0' && b <= '9':
		return true
	case b >= 'a' && b <= 'f':
		return true
	case b >= 'A' && b <= 'F':
		return true
	}
	return false
}
// knownLibraryPaths are directory fragments that indicate a file belongs to
// a well-known third-party library and should never be auto-quarantined.
// Compared case-insensitively (the caller lowercases the path first) with a
// substring match, so they apply at any depth in the tree.
var knownLibraryPaths = []string{
	"/phpmailer/",
	"/vendor/",       // Composer-managed dependencies
	"/node_modules/", // npm-managed dependencies
	"/pear/",
	"/tcpdf/",
	"/dompdf/",
	"/guzzlehttp/",
	"/symfony/",
	"/monolog/",
}
// InlineQuarantine moves a file to quarantine immediately if it passes the
// high-confidence validation gates. Called from fanotify's analyzeFile to
// quarantine malware without waiting for the 5-second batch dispatcher.
// The data parameter is the file content already read by the caller (avoids
// TOCTOU re-read). Pass nil to read from path.
// Returns the quarantine path and true if the file was quarantined.
func InlineQuarantine(f alert.Finding, path string, data []byte) (string, bool) {
	if !isHighConfidenceRealtimeMatch(f, path, data) {
		return "", false
	}
	info, err := osFS.Stat(path)
	if err != nil {
		return "", false
	}
	// Best-effort: the quarantine directory usually already exists.
	_ = os.MkdirAll(quarantineDir, 0700)
	safeName := strings.ReplaceAll(path, "/", "_")
	ts := time.Now().Format("20060102-150405")
	qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_%s", ts, safeName))
	if err := os.Rename(path, qPath); err != nil {
		// Rename fails across filesystems; fall back to copy+delete.
		if info.IsDir() {
			return "", false
		}
		// Use the caller-supplied content when available: re-reading from
		// path here would reopen the TOCTOU window this function's contract
		// promises to avoid. Only hit the disk when data was nil.
		content := data
		if content == nil {
			var readErr error
			content, readErr = osFS.ReadFile(path)
			if readErr != nil {
				return "", false
			}
		}
		if writeErr := os.WriteFile(qPath, content, 0600); writeErr != nil {
			return "", false
		}
		// Best-effort removal of the original; the content is already
		// safely copied into quarantine.
		_ = os.Remove(path)
	}
	// Write metadata sidecar so the file can later be restored with its
	// original ownership, mode, and location.
	var uid, gid int
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		uid = int(stat.Uid)
		gid = int(stat.Gid)
	}
	meta := QuarantineMeta{
		OriginalPath: path,
		Owner:        uid,
		Group:        gid,
		Mode:         info.Mode().String(),
		Size:         info.Size(),
		QuarantineAt: time.Now(),
		Reason:       "Inline quarantine: high-confidence realtime signature match",
	}
	// Sidecar failures are non-fatal: the quarantine itself succeeded.
	metaData, _ := json.MarshalIndent(meta, "", " ")
	_ = os.WriteFile(qPath+".meta", metaData, 0600)
	return qPath, true
}
// extractCategory parses "Category: <value>" from a finding's Details field.
// Returns "" when no line starts with the marker.
func extractCategory(details string) string {
	const marker = "Category: "
	for _, ln := range strings.Split(details, "\n") {
		if strings.HasPrefix(ln, marker) {
			return ln[len(marker):]
		}
	}
	return ""
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
const (
	// Per-IP thresholds for one scan cycle; counts at or above a threshold
	// produce a finding.
	wpLoginThreshold = 20 // POST wp-login.php attempts per IP across all logs
	xmlrpcThreshold  = 30 // POST xmlrpc.php requests per IP
	ftpFailThreshold = 10 // failed FTP authentications per IP
	webmailThreshold = 10 // webmail login/auth POSTs per IP
	apiFailThreshold = 10 // cPanel API 401/403 responses per IP
	// domlogTailLines is how many lines to read from each domlog.
	// 500 lines covers ~10 minutes of traffic on a busy site.
	domlogTailLines = 500
	// domlogMaxAge skips domlogs not modified recently (inactive sites).
	domlogMaxAge = 30 * time.Minute
	// domlogMaxFiles caps the number of domlogs scanned per cycle
	// to prevent unbounded I/O on servers with thousands of domains.
	domlogMaxFiles = 500
)
// CheckWPBruteForce detects brute force attacks against wp-login.php and
// xmlrpc.php by scanning access logs. Always scans per-domain domlogs
// because on LiteSpeed+cPanel, virtual host traffic only appears there.
// The central access log is scanned as a supplement.
//
// Aggregates per-IP counts across ALL domains — catches attackers who
// distribute requests across many sites to stay under per-site thresholds.
func CheckWPBruteForce(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	tailWindow := cfg.Thresholds.BruteForceWindow
	if tailWindow <= 0 {
		tailWindow = 5000
	}
	loginHits := make(map[string]int)
	xmlrpcHits := make(map[string]int)
	enumHits := make(map[string]int)

	// 1. Per-domain domlogs — primary source on LiteSpeed.
	//    Glob both SSL and non-SSL logs: attackers may use HTTP.
	scanned := scanDomlogs(cfg.InfraIPs, loginHits, xmlrpcHits, enumHits)

	// 2. Central access log — supplement for non-vhost traffic.
	//    On LiteSpeed this mostly has WHM/server-level requests.
	//    On Apache it duplicates domlog data — minor double-counting is
	//    acceptable since thresholds are high enough.
	centralLogs := [...]string{
		"/usr/local/apache/logs/access_log",
		"/var/log/apache2/access_log",
		"/etc/apache2/logs/access_log",
	}
	for _, logPath := range centralLogs {
		if lines := tailFile(logPath, tailWindow); len(lines) > 0 {
			countBruteForce(lines, cfg.InfraIPs, loginHits, xmlrpcHits, enumHits)
			break
		}
	}

	// 3. Build findings from the aggregated counters.
	var findings []alert.Finding
	domlogNote := fmt.Sprintf("Aggregated across %d domlog files", scanned)
	for src, n := range loginHits {
		if n >= wpLoginThreshold {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "wp_login_bruteforce",
				Message:  fmt.Sprintf("WordPress login brute force from %s: %d attempts", src, n),
				Details:  domlogNote,
			})
		}
	}
	for src, n := range xmlrpcHits {
		if n >= xmlrpcThreshold {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "xmlrpc_abuse",
				Message:  fmt.Sprintf("XML-RPC abuse from %s: %d requests", src, n),
				Details:  domlogNote,
			})
		}
	}
	for src, n := range enumHits {
		if n >= 5 {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "wp_user_enumeration",
				Message:  fmt.Sprintf("WordPress user enumeration from %s: %d requests", src, n),
				Details:  "Requests to /wp-json/wp/v2/users or ?author=",
			})
		}
	}
	return findings
}
// scanDomlogs globs per-domain access logs, deduplicates symlinks,
// skips stale files, and aggregates brute force counts.
// Returns the number of files actually scanned.
func scanDomlogs(infraIPs []string, wpLogin, xmlrpc, userEnum map[string]int) int {
	var candidates []string
	for _, pattern := range [...]string{
		"/home/*/access-logs/*-ssl_log",
		"/home/*/access-logs/*_log",
	} {
		// Glob errors yield no matches; appending nil is a no-op.
		hits, _ := osFS.Glob(pattern)
		candidates = append(candidates, hits...)
	}
	resolved := make(map[string]bool)
	staleBefore := time.Now().Add(-domlogMaxAge)
	count := 0
	for _, candidate := range candidates {
		if count >= domlogMaxFiles {
			break
		}
		// cPanel often symlinks the SSL and non-SSL logs to one file —
		// resolve and dedupe on the real path.
		target, err := filepath.EvalSymlinks(candidate)
		if err != nil || resolved[target] {
			continue
		}
		resolved[target] = true
		// Inactive sites add no signal — skip logs past the staleness cutoff.
		st, err := osFS.Stat(target)
		if err != nil || st.ModTime().Before(staleBefore) {
			continue
		}
		countBruteForce(tailFile(target, domlogTailLines), infraIPs, wpLogin, xmlrpc, userEnum)
		count++
	}
	return count
}
// countBruteForce parses Combined Log Format lines and increments per-IP
// counters for wp-login.php, xmlrpc.php, and user enumeration attacks.
func countBruteForce(lines []string, infraIPs []string, wpLogin, xmlrpc, userEnum map[string]int) {
	for _, raw := range lines {
		cols := strings.Fields(raw)
		if len(cols) < 7 {
			continue
		}
		src := cols[0]
		// Skip localhost (wp-cron self-requests), placeholder, and infra IPs.
		switch src {
		case "127.0.0.1", "::1", "-":
			continue
		}
		if isInfraIP(src, infraIPs) {
			continue
		}
		verb := strings.Trim(cols[5], "\"")
		target := cols[6]
		if verb == "POST" {
			if strings.Contains(target, "wp-login.php") {
				wpLogin[src]++
			}
			if strings.Contains(target, "xmlrpc.php") {
				xmlrpc[src]++
			}
		}
		// User enumeration — only exclude /users/me (authenticated self-check).
		switch {
		case strings.Contains(target, "?author="):
			userEnum[src]++
		case strings.Contains(target, "/wp-json/wp/v2/users") && !strings.Contains(target, "/users/me"):
			userEnum[src]++
		}
	}
}
// CheckFTPLogins parses /var/log/messages or pure-ftpd log for FTP brute force.
func CheckFTPLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	entries := tailFile("/var/log/messages", 200)
	if len(entries) == 0 {
		return nil
	}
	var findings []alert.Finding
	failures := make(map[string]int)
	for _, entry := range entries {
		// Only pure-ftpd lines are relevant:
		// "pure-ftpd: ... [WARNING] Authentication failed for user"
		if !strings.Contains(entry, "pure-ftpd") {
			continue
		}
		// Count failed authentications per source IP.
		if strings.Contains(entry, "Authentication failed") || strings.Contains(entry, "auth failed") {
			if src := extractIPFromLog(entry); src != "" && !isInfraIP(src, cfg.InfraIPs) {
				failures[src]++
			}
		}
		// A successful login from outside the infrastructure is itself a finding.
		if strings.Contains(entry, "is now logged in") {
			if src := extractIPFromLog(entry); src != "" && !isInfraIP(src, cfg.InfraIPs) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "ftp_login",
					Message:  fmt.Sprintf("FTP login from non-infra IP: %s", src),
					Details:  truncate(entry, 200),
				})
			}
		}
	}
	for src, n := range failures {
		if n >= ftpFailThreshold {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "ftp_bruteforce",
				Message:  fmt.Sprintf("FTP brute force from %s: %d failed attempts", src, n),
			})
		}
	}
	return findings
}
// CheckWebmailLogins parses cPanel access log for webmail logins from non-infra IPs.
func CheckWebmailLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if cfg.Suppressions.SuppressWebmail {
		return nil
	}
	attempts := make(map[string]int)
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 300) {
		// Webmail ports: 2095 (HTTP), 2096 (HTTPS)
		if !strings.Contains(entry, "2095") && !strings.Contains(entry, "2096") {
			continue
		}
		cols := strings.Fields(entry)
		if len(cols) == 0 {
			continue
		}
		src := cols[0]
		if src == "127.0.0.1" || isInfraIP(src, cfg.InfraIPs) {
			continue
		}
		// Count login/auth POSTs per source IP.
		if strings.Contains(entry, "POST") &&
			(strings.Contains(entry, "login") || strings.Contains(entry, "auth")) {
			attempts[src]++
		}
	}
	var findings []alert.Finding
	for src, n := range attempts {
		if n >= webmailThreshold {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "webmail_bruteforce",
				Message:  fmt.Sprintf("Webmail brute force from %s: %d attempts", src, n),
			})
		}
	}
	return findings
}
// CheckAPIAuthFailures parses cPanel access log for failed API authentication.
func CheckAPIAuthFailures(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	rejected := make(map[string]int)
	for _, entry := range tailFile("/usr/local/cpanel/logs/access_log", 300) {
		// Only 401/403 responses are interesting.
		if !strings.Contains(entry, "\" 401 ") && !strings.Contains(entry, "\" 403 ") {
			continue
		}
		// Restrict to API endpoints and session URLs.
		isAPI := strings.Contains(entry, "json-api") ||
			strings.Contains(entry, "/execute/") ||
			strings.Contains(entry, "cpsess")
		if !isAPI {
			continue
		}
		cols := strings.Fields(entry)
		if len(cols) == 0 {
			continue
		}
		src := cols[0]
		if src == "127.0.0.1" || isInfraIP(src, cfg.InfraIPs) {
			continue
		}
		rejected[src]++
	}
	var findings []alert.Finding
	for src, n := range rejected {
		if n >= apiFailThreshold {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "api_auth_failure",
				Message:  fmt.Sprintf("cPanel API auth failures from %s: %d attempts", src, n),
				Details:  "Possible API token brute force or unauthorized API access",
			})
		}
	}
	return findings
}
// extractIPFromLog tries to extract an IPv4 address from a log line using a
// cheap heuristic: the first whitespace-separated token that starts with a
// digit and contains exactly three dots, with trailing punctuation stripped.
func extractIPFromLog(line string) string {
	for _, tok := range strings.Fields(line) {
		if len(tok) < 7 || tok[0] < '0' || tok[0] > '9' {
			continue
		}
		if strings.Count(tok, ".") != 3 {
			continue
		}
		return strings.TrimRight(tok, ",:;)([]")
	}
	return ""
}
package checks
import (
"fmt"
"os"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// ChallengeIPList abstracts the challenge IP list for routing.
type ChallengeIPList interface {
	// Add places ip on the challenge list for the given duration; reason is
	// the triggering finding's message.
	Add(ip string, reason string, duration time.Duration)
	// Contains reports whether ip is currently on the list.
	Contains(ip string) bool
}

// challengeIPList is the process-wide challenge list. It stays nil until
// SetChallengeIPList is called; ChallengeRouteIPs no-ops while it is nil.
var challengeIPList ChallengeIPList

// SetChallengeIPList sets the challenge IP list for routing.
func SetChallengeIPList(list ChallengeIPList) {
	challengeIPList = list
}

// GetChallengeIPList returns the current challenge IP list (for AutoBlockIPs skip check).
func GetChallengeIPList() ChallengeIPList {
	return challengeIPList
}
// hardBlockChecks are exact check names that must NEVER be routed to challenge.
// These represent either confirmed compromise (malware, backdoors, data theft)
// or non-HTTP protocols where a browser challenge cannot be served; in both
// cases a hard block is the only meaningful response.
var hardBlockChecks = map[string]bool{
	"signature_match_realtime":        true,
	"yara_match_realtime":             true,
	"webshell":                        true,
	"backdoor_binary":                 true,
	"cross_account_malware":           true,
	"c2_connection":                   true,
	"backdoor_port":                   true,
	"backdoor_port_outbound":          true,
	"exfiltration_paste_site":         true,
	"htaccess_injection":              true,
	"htaccess_handler_abuse":          true,
	"db_siteurl_hijack":               true,
	"db_options_injection":            true,
	"db_post_injection":               true,
	"db_rogue_admin":                  true,
	"db_spam_injection":               true,
	"phishing_page":                   true,
	"phishing_iframe":                 true,
	"phishing_php":                    true,
	"phishing_redirector":             true,
	"phishing_credential_log":         true,
	"phishing_kit_archive":            true,
	"phishing_directory":              true,
	"php_shield_webshell":             true,
	"php_shield_block":                true,
	"php_shield_eval":                 true,
	"suspicious_crontab":              true,
	"suspicious_process":              true,
	"fake_kernel_thread":              true,
	"php_suspicious_execution":        true,
	"suspicious_file":                 true,
	"password_hijack_confirmed":       true,
	"symlink_attack":                  true,
	"shadow_change":                   true,
	"root_password_change":            true,
	"coordinated_attack":              true,
	"database_dump":                   true,
	"kernel_module":                   true,
	"uid0_account":                    true,
	"suid_binary":                     true,
	"rpm_integrity":                   true,
	"email_spam_outbreak":             true,
	"modsec_csm_block_escalation":     true,
	"api_auth_failure_realtime":       true, // cPanel API brute force — challenge is useless, hard-block
	"ftp_auth_failure_realtime":       true, // FTP brute force — can't challenge non-HTTP
	"pam_bruteforce":                  true, // PAM brute force — can't challenge non-HTTP
	"smtp_bruteforce":                 true, // SMTP brute force — can't challenge non-HTTP
	"smtp_subnet_spray":               true, // SMTP subnet spray — can't challenge non-HTTP
	"mail_bruteforce":                 true, // Mail brute force — can't challenge non-HTTP
	"mail_subnet_spray":               true, // Mail subnet spray — can't challenge non-HTTP
	"mail_account_compromised":        true, // Mail account compromise — instant block, zero-FP signal
	"admin_panel_bruteforce":          true, // Admin panel brute force — tight path set makes FP near-impossible
}
// hardBlockPrefixes match any check name starting with these strings.
// Used by isHardBlockCheck to cover whole families of checks without
// enumerating every exact name in hardBlockChecks.
var hardBlockPrefixes = []string{
	"outgoing_mail_",
	"spam_",
	"modsec_",
	"email_auth_failure", // email brute force - SMTP/IMAP, can't challenge via HTTP
	"email_compromised",  // confirmed compromised email account
	"email_credential",   // credential leak
}
// challengeableChecks lists checks whose findings contain attacker IPs and
// are appropriate for challenge routing. This is a closed allowlist — any
// new check that produces IP-bearing findings must be added here. This is
// safer than a denylist because forgetting to add an informational check
// (like outdated_plugins, whose version strings parse as IPs) defaults to
// "skip" rather than "block an innocent IP".
//
// Note: a check listed here can still be excluded at routing time by
// isHardBlockCheck, which is consulted first in ChallengeRouteIPs.
var challengeableChecks = map[string]bool{
	// Brute force checks — all contain attacker IPs
	"brute_force":          true,
	"wp_login_bruteforce":  true,
	"xmlrpc_abuse":         true,
	"wp_user_enumeration":  true,
	"ftp_bruteforce":       true,
	"webmail_bruteforce":   true,
	"api_auth_failure":     true,
	// Login monitoring — contain source IPs
	"cpanel_login":                true,
	"cpanel_login_realtime":       true,
	"cpanel_multi_ip_login":       true,
	"cpanel_file_upload":          true,
	"cpanel_file_upload_realtime": true,
	"ftp_login":                   true,
	"ftp_login_realtime":          true,
	"ssh_login_realtime":          true,
	"ssh_login_unknown_ip":        true,
	"webmail_login_realtime":      true,
	"whm_password_change":         true,
	// Reputation and threat scoring — contain IPs
	"ip_reputation":      true,
	"local_threat_score": true,
	// Network — contain IPs
	"dns_connection":           true,
	"user_outbound_connection": true,
	// WAF — contain attacker IPs
	"waf_attack_blocked": true,
}
// isChallengeableCheck reports whether check is on the closed allowlist of
// checks whose findings may be routed to the challenge list.
func isChallengeableCheck(check string) bool {
	_, ok := challengeableChecks[check]
	return ok
}
// isHardBlockCheck returns true if the check should be hard-blocked (never challenged).
// A check matches either by exact name or by one of the known prefixes.
func isHardBlockCheck(check string) bool {
	if _, exact := hardBlockChecks[check]; exact {
		return true
	}
	for _, p := range hardBlockPrefixes {
		if strings.HasPrefix(check, p) {
			return true
		}
	}
	return false
}
// challengeDuration is how long a routed IP stays on the challenge list.
const challengeDuration = 30 * time.Minute
// ChallengeRouteIPs processes findings and routes eligible IPs to the challenge
// list instead of hard-blocking them. Must be called BEFORE AutoBlockIPs so
// that challenged IPs are on the list when AutoBlockIPs checks Contains().
// ChallengeRouteIPs processes findings and routes eligible IPs to the challenge
// list instead of hard-blocking them. Must be called BEFORE AutoBlockIPs so
// that challenged IPs are on the list when AutoBlockIPs checks Contains().
// Returns one "challenge_route" finding per IP it routed.
func ChallengeRouteIPs(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if challengeIPList == nil || !cfg.Challenge.Enabled {
		return nil
	}
	seen := make(map[string]bool)
	var routedFindings []alert.Finding
	for _, finding := range findings {
		// Confirmed-compromise / non-HTTP checks are never downgraded to a challenge.
		if isHardBlockCheck(finding.Check) {
			continue
		}
		// Closed allowlist: only checks known to carry attacker IPs are
		// routed. Defaulting to skip prevents version numbers, sizes, and
		// other numeric finding fields from being blocked as IPs.
		if !isChallengeableCheck(finding.Check) {
			continue
		}
		ip := extractIPFromFinding(finding)
		if ip == "" || seen[ip] {
			continue
		}
		if ip == "127.0.0.1" || isInfraIP(ip, cfg.InfraIPs) {
			continue
		}
		if challengeIPList.Contains(ip) {
			continue
		}
		challengeIPList.Add(ip, finding.Message, challengeDuration)
		seen[ip] = true
		fmt.Fprintf(os.Stderr, "[%s] CHALLENGE: %s routed to challenge (check: %s)\n",
			time.Now().Format("2006-01-02 15:04:05"), ip, finding.Check)
		routedFindings = append(routedFindings, alert.Finding{
			Severity:  alert.Warning,
			Check:     "challenge_route",
			Message:   fmt.Sprintf("CHALLENGE: %s sent to PoW challenge (expires in %s)", ip, challengeDuration),
			Details:   fmt.Sprintf("Reason: %s", finding.Message),
			Timestamp: time.Now(),
		})
	}
	return routedFindings
}
package checks
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"
)
// CleanResult describes the outcome of a cleaning attempt.
type CleanResult struct {
	Path       string   // file that cleaning was attempted on
	Cleaned    bool     // true when malicious content was removed and the file rewritten
	BackupPath string   // pre-clean backup location (created before any modification)
	Removals   []string // descriptions of what was removed
	Error      string   // non-empty when cleaning failed or no known pattern matched
}
// CleanInfectedFile attempts to surgically remove malicious code from a PHP file
// while preserving the legitimate content. Always creates a backup first.
//
// Cleaning strategies (tried in order):
//  1. @include injection - remove @include lines pointing to /tmp, eval, base64, or via variables
//  2. Prepend injection - remove malicious code blocks at start of file (entropy-validated)
//  3. Append injection - remove malicious code after closing ?> or end of PSR-12 file
//  4. Inline eval injection - remove eval(base64_decode(...)) single-line injections
//  5. Multi-layer base64 - remove nested/chained base64_decode payload lines
//  6. chr()/pack() construction - remove code built from char codes or hex packing
//  7. Hex-encoded variables - remove hex-escape variable injections
//
// On success, Cleaned is true and Removals lists every removed pattern.
// On failure, Error is set; the original file is left unmodified but the
// backup (and its .meta sidecar) still exists.
func CleanInfectedFile(path string) CleanResult {
	result := CleanResult{Path: path}
	// Read original file
	data, err := osFS.ReadFile(path)
	if err != nil {
		result.Error = fmt.Sprintf("cannot read file: %v", err)
		return result
	}
	// Create backup before any modification — cleaning never proceeds
	// without a restorable copy.
	backupDir := filepath.Join(quarantineDir, "pre_clean")
	_ = os.MkdirAll(backupDir, 0700)
	ts := time.Now().Format("20060102-150405")
	safeName := strings.ReplaceAll(path, "/", "_")
	backupPath := filepath.Join(backupDir, fmt.Sprintf("%s_%s", ts, safeName))
	if err := os.WriteFile(backupPath, data, 0600); err != nil {
		result.Error = fmt.Sprintf("cannot create backup: %v", err)
		return result
	}
	result.BackupPath = backupPath
	// Write metadata sidecar so the WebUI quarantine page can list pre-clean backups.
	// Best-effort: a failed Stat leaves zero-valued metadata fields.
	info, _ := osFS.Stat(path)
	var fileSize int64
	var fileMode string
	var uid, gid int
	if info != nil {
		fileSize = info.Size()
		fileMode = info.Mode().String()
		if stat, ok := info.Sys().(*syscall.Stat_t); ok {
			uid = int(stat.Uid)
			gid = int(stat.Gid)
		}
	}
	meta := map[string]interface{}{
		"original_path":  path,
		"owner_uid":      uid,
		"group_gid":      gid,
		"mode":           fileMode,
		"size":           fileSize,
		"quarantined_at": time.Now(),
		"reason":         "Pre-clean backup (surgical cleaning)",
	}
	// Sidecar write failures are non-fatal; the backup itself already succeeded.
	metaData, _ := json.MarshalIndent(meta, "", " ")
	_ = os.WriteFile(backupPath+".meta", metaData, 0600)
	content := string(data)
	originalLen := len(content)
	var removals []string
	// Strategy 1: Remove @include injections (including variable-based)
	content, removed := removeIncludeInjections(content)
	removals = append(removals, removed...)
	// Strategy 2: Remove prepend injections (with entropy validation)
	content, removed = removePrependInjection(content)
	removals = append(removals, removed...)
	// Strategy 3: Remove append injections (handles files with and without closing ?>)
	content, removed = removeAppendInjection(content)
	removals = append(removals, removed...)
	// Strategy 4: Remove inline eval(base64_decode(...)) injections
	content, removed = removeInlineEvalInjections(content)
	removals = append(removals, removed...)
	// Strategy 5: Remove multi-layer base64 decode chains
	content, removed = removeMultiLayerBase64(content)
	removals = append(removals, removed...)
	// Strategy 6: Remove chr()/pack() constructed code
	content, removed = removeChrPackInjections(content)
	removals = append(removals, removed...)
	// Strategy 7: Remove hex-encoded variable injections
	content, removed = removeHexVarInjections(content)
	removals = append(removals, removed...)
	// If nothing was removed, file couldn't be cleaned. The length check is
	// a belt-and-braces guard against a strategy reporting a removal while
	// leaving the content untouched.
	if len(removals) == 0 || len(content) == originalLen {
		result.Error = "no known injection patterns found - file may need manual review"
		return result
	}
	// Write cleaned file, preserving the original mode when it can be read
	// (falls back to 0644 otherwise).
	info, _ = osFS.Stat(path)
	mode := os.FileMode(0644)
	if info != nil {
		mode = info.Mode()
	}
	if err := os.WriteFile(path, []byte(content), mode); err != nil {
		result.Error = fmt.Sprintf("cannot write cleaned file: %v", err)
		return result
	}
	result.Cleaned = true
	result.Removals = removals
	return result
}
// ShouldCleanInsteadOfQuarantine returns true if the file should be cleaned
// (surgical removal) instead of quarantined (full removal).
// WP core files and plugin files are better cleaned - removing them breaks the site.
// Unknown standalone files (droppers, webshells) should be quarantined.
func ShouldCleanInsteadOfQuarantine(path string) bool {
	switch {
	case strings.Contains(path, "/wp-includes/"), strings.Contains(path, "/wp-admin/"):
		// WP core files - always clean, never quarantine.
		return true
	case strings.Contains(path, "/wp-content/plugins/"), strings.Contains(path, "/wp-content/themes/"):
		// Clean plugin/theme files to preserve site functionality — unless
		// the file itself is a standalone webshell (e.g. h4x0r.php dropped
		// inside a theme), which should be quarantined outright.
		return !isWebshellName(strings.ToLower(filepath.Base(path)))
	}
	// Everything else - quarantine.
	return false
}
// removeIncludeInjections removes @include lines that load malicious files.
// Catches:
//   - @include("/tmp/...") - literal paths to temp dirs
//   - @include(base64_decode("...")) - encoded includes
//   - @include($var) where $var is built from obfuscated strings nearby
func removeIncludeInjections(content string) (string, []string) {
	// Pattern 1: @include with literal malicious paths or encoding functions.
	literalInclude := regexp.MustCompile(
		`(?i)^\s*@include\s*\(\s*(?:` +
			`['"](?:/tmp/|/dev/shm/|/var/tmp/)` + // include from temp dirs
			`|base64_decode\s*\(` + // include with base64
			`|str_rot13\s*\(` + // include with rot13
			`|gzinflate\s*\(` + // include with gzip
			`)`)
	// Pattern 2: @include($variable) - only malicious when nearby code is obfuscated.
	variableInclude := regexp.MustCompile(`(?i)^\s*@include\s*\(\s*\$[a-zA-Z_]+\s*\)`)
	src := strings.Split(content, "\n")
	kept := make([]string, 0, len(src))
	var removals []string
	for idx, ln := range src {
		if literalInclude.MatchString(ln) {
			removals = append(removals, fmt.Sprintf("removed @include injection: %s", strings.TrimSpace(ln)))
			continue
		}
		if variableInclude.MatchString(ln) && includeContextObfuscated(src, idx) {
			removals = append(removals, fmt.Sprintf("removed obfuscated @include: %s", strings.TrimSpace(ln)))
			continue
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n"), removals
}

// includeContextObfuscated reports whether the lines around idx show string
// obfuscation: decoders, chr() construction, hex escapes, or heavy concatenation.
func includeContextObfuscated(lines []string, idx int) bool {
	ctx := strings.ToLower(getLineContext(lines, idx, 3))
	return strings.Contains(ctx, "base64_decode") ||
		strings.Contains(ctx, "str_rot13") ||
		strings.Contains(ctx, "chr(") ||
		strings.Contains(ctx, `"\x`) ||
		strings.Count(ctx, ". ") > 5 // heavy string concatenation
}
// removePrependInjection removes malicious PHP code injected before the
// legitimate file content. Uses entropy analysis to verify the prefix
// is actually obfuscated (not legitimate minified code).
func removePrependInjection(content string) (string, []string) {
	if !strings.HasPrefix(strings.TrimSpace(content), "<?php") {
		return content, nil
	}
	// A prepend injection closes its own block and reopens PHP: `?><?php`.
	boundary := regexp.MustCompile(`\?>\s*<\?php`).FindStringIndex(content)
	if boundary == nil {
		return content, nil
	}
	head := content[:boundary[0]]
	headLower := strings.ToLower(head)
	// The prefix must contain at least one known malicious construct.
	suspicious := false
	for _, needle := range []string{"eval(", "base64_decode", "gzinflate", "str_rot13", "@include"} {
		if strings.Contains(headLower, needle) {
			suspicious = true
			break
		}
	}
	if !suspicious {
		return content, nil
	}
	// Additional safety: require high entropy (obfuscated code) or a long
	// encoded string. Prevents false positives on legitimate minified PHP
	// that happens to use ?><?php patterns.
	entropy := shannonEntropy(head)
	if entropy < 4.5 && !containsLongEncodedString(head, 100) {
		return content, nil
	}
	// Drop everything before the second <?php and re-open the block.
	cleaned := "<?php" + content[boundary[1]:]
	return cleaned, []string{fmt.Sprintf("removed %d-byte prepend injection (entropy: %.2f)", boundary[1], entropy)}
}
// removeAppendInjection removes malicious code appended after the end of
// legitimate PHP content. Handles both files with closing ?> and files
// without (PSR-12 style).
func removeAppendInjection(content string) (string, []string) {
	// Case 1: file has a closing ?> with code trailing after it.
	if closeAt := strings.LastIndex(content, "?>"); closeAt >= 0 {
		tail := content[closeAt+2:]
		if strings.TrimSpace(tail) != "" {
			tailLower := strings.ToLower(strings.TrimSpace(tail))
			bad := false
			for _, needle := range []string{
				"eval(", "base64_decode", "gzinflate",
				"system(", "exec(", "@include",
				"<?php", // second PHP block appended
			} {
				if strings.Contains(tailLower, needle) {
					bad = true
					break
				}
			}
			if bad {
				return content[:closeAt+2] + "\n",
					[]string{fmt.Sprintf("removed %d-byte append injection (after ?>)", len(tail))}
			}
		}
	}
	// Case 2: PSR-12 style file (no closing ?>) — look for a malicious
	// block at the very end, separated by a blank line.
	lines := strings.Split(content, "\n")
	if len(lines) < 5 {
		return content, nil
	}
	// Scan the last 10 lines for the first blank separator.
	firstBlank := -1
	from := len(lines) - 10
	if from < 0 {
		from = 0
	}
	for i := from; i < len(lines); i++ {
		if firstBlank < 0 && strings.TrimSpace(lines[i]) == "" {
			firstBlank = i
		}
	}
	if firstBlank >= 0 {
		tailBlock := strings.Join(lines[firstBlank:], "\n")
		tailLower := strings.ToLower(tailBlock)
		if strings.Contains(tailLower, "eval(") && strings.Contains(tailLower, "base64_decode") {
			return strings.Join(lines[:firstBlank], "\n") + "\n",
				[]string{fmt.Sprintf("removed %d-byte PSR-12 append injection", len(tailBlock))}
		}
	}
	return content, nil
}
// removeInlineEvalInjections removes single-line eval(base64_decode("..."));
// injections that are inserted as standalone lines in PHP files.
func removeInlineEvalInjections(content string) (string, []string) {
	// Standalone eval(<decoder>( ... lines, with an optional leading @.
	pattern := regexp.MustCompile(
		`(?i)^\s*(?:@?)eval\s*\(\s*(?:base64_decode|gzinflate|gzuncompress|str_rot13)\s*\(`)
	var removals []string
	var kept []string
	for _, ln := range strings.Split(content, "\n") {
		candidate := strings.TrimSpace(ln)
		// Injections carry a long encoded payload; a short eval() line is
		// likely legitimate code and is kept.
		if pattern.MatchString(candidate) && len(candidate) > 50 {
			removals = append(removals, fmt.Sprintf("removed inline eval injection (%d chars)", len(candidate)))
			continue
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n"), removals
}
// --- Helper functions ---
// shannonEntropy calculates the Shannon entropy of a string.
// Obfuscated/encoded code typically has entropy > 5.0.
// Normal PHP code typically has entropy 4.0-4.5.
func shannonEntropy(s string) float64 {
if len(s) == 0 {
return 0
}
freq := make(map[byte]float64)
for i := 0; i < len(s); i++ {
freq[s[i]]++
}
length := float64(len(s))
entropy := 0.0
for _, count := range freq {
p := count / length
if p > 0 {
entropy -= p * math.Log2(p)
}
}
return entropy
}
// containsLongEncodedString checks if the text contains a long base64-like
// string (alphanumeric + /+ = without spaces).
//
// Implemented as a single linear scan instead of compiling a dynamic
// `[A-Za-z0-9+/=]{N,}` regexp on every call: behavior is identical (a run
// of at least minLength base64-alphabet bytes returns true) without the
// per-call compilation cost, and a non-positive minLength no longer risks
// an invalid regexp (it is trivially satisfied, as `{0,}` was).
func containsLongEncodedString(s string, minLength int) bool {
	if minLength <= 0 {
		return true
	}
	run := 0
	for i := 0; i < len(s); i++ {
		if isBase64Alphabet(s[i]) {
			run++
			if run >= minLength {
				return true
			}
		} else {
			run = 0
		}
	}
	return false
}

// isBase64Alphabet reports whether b can appear in a base64 payload.
func isBase64Alphabet(b byte) bool {
	switch {
	case b >= 'A' && b <= 'Z', b >= 'a' && b <= 'z', b >= '0' && b <= '9':
		return true
	case b == '+', b == '/', b == '=':
		return true
	}
	return false
}
// getLineContext returns up to `window` lines before and after the line at
// idx (inclusive), joined with newlines, clamped to the slice bounds.
func getLineContext(lines []string, idx, window int) string {
	lo, hi := idx-window, idx+window+1
	if lo < 0 {
		lo = 0
	}
	if hi > len(lines) {
		hi = len(lines)
	}
	return strings.Join(lines[lo:hi], "\n")
}
// FormatCleanResult returns a human-readable summary of a clean operation:
// a one-line failure/no-op message, or a multi-line report listing the
// backup path and every removal.
func FormatCleanResult(r CleanResult) string {
	switch {
	case r.Error != "":
		return fmt.Sprintf("FAILED to clean %s: %s", r.Path, r.Error)
	case !r.Cleaned:
		return fmt.Sprintf("No changes made to %s", r.Path)
	}
	var sb strings.Builder
	fmt.Fprintf(&sb, "CLEANED %s\n", r.Path)
	fmt.Fprintf(&sb, " Backup: %s\n", r.BackupPath)
	for _, item := range r.Removals {
		fmt.Fprintf(&sb, " - %s\n", item)
	}
	return sb.String()
}
// --- Strategy 5: Multi-layer base64 decode chains ---
// Catches: eval(base64_decode(base64_decode("...")))
// Catches: $x=base64_decode("...");$y=base64_decode($x);eval($y);
// removeMultiLayerBase64 strips lines containing nested or chained
// base64_decode calls (Strategy 5).
// Catches: eval(base64_decode(base64_decode("...")))
// Catches: $x=base64_decode(base64_decode(...));
// Lines shorter than 80 chars (after trimming) are never touched.
func removeMultiLayerBase64(content string) (string, []string) {
	// 2+ nested base64_decode calls on one line.
	nestedRE := regexp.MustCompile(`(?i)(?:base64_decode\s*\(\s*){2,}`)
	// Chained base64 assigned through a variable.
	chainRE := regexp.MustCompile(`(?i)\$\w+\s*=\s*base64_decode\s*\(\s*base64_decode`)
	var removals []string
	var kept []string
	for _, ln := range strings.Split(content, "\n") {
		t := strings.TrimSpace(ln)
		if len(t) >= 80 && (nestedRE.MatchString(t) || chainRE.MatchString(t)) {
			removals = append(removals, fmt.Sprintf("removed multi-layer base64 chain (%d chars)", len(t)))
			continue
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n"), removals
}
// --- Strategy 6: chr()/pack() constructed code ---
// Catches: eval(chr(115).chr(121).chr(115)...);
// Catches: $f = pack("H*", "73797374656d"); $f($_POST['cmd']);
// removeChrPackInjections strips lines that build code out of character
// codes (Strategy 6).
// Catches: eval(chr(115).chr(121).chr(115)...);
// Catches: $f = pack("H*", "73797374656d"); $f($_POST['cmd']);
// pack() lines are removed only when paired with an execution marker.
func removeChrPackInjections(content string) (string, []string) {
	// 5+ concatenated chr() calls - function names from char codes.
	chrRE := regexp.MustCompile(`(?i)(?:chr\s*\(\s*\d+\s*\)\s*\.?\s*){5,}`)
	// pack("H*", ...) - hex string to function name construction.
	packRE := regexp.MustCompile(`(?i)pack\s*\(\s*["']H\*["']\s*,`)
	var removals []string
	var kept []string
	for idx, ln := range strings.Split(content, "\n") {
		t := strings.TrimSpace(ln)
		if chrRE.MatchString(t) {
			removals = append(removals, fmt.Sprintf("removed chr() chain injection (line %d, %d chars)", idx+1, len(t)))
			continue
		}
		if packRE.MatchString(t) {
			lc := strings.ToLower(t)
			// Only remove pack() when combined with execution.
			if strings.Contains(lc, "eval") || strings.Contains(lc, "$_") ||
				strings.Contains(lc, "system") || strings.Contains(lc, "exec") {
				removals = append(removals, fmt.Sprintf("removed pack() code construction (line %d)", idx+1))
				continue
			}
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n"), removals
}
// --- Strategy 7: Hex-encoded variable injections ---
// Catches: $GLOBALS["\x61\x64\x6d\x69\x6e"] = eval(...)
// Catches: ${"\x47\x4c\x4f\x42\x41\x4c\x53"}[...] = ...
// removeHexVarInjections strips lines using hex-escaped variable names
// (Strategy 7).
// Catches: $GLOBALS["\x61\x64\x6d\x69\x6e"] = eval(...)
// Catches: ${"\x47\x4c\x4f\x42\x41\x4c\x53"}[...] = ...
// A line must be at least 30 chars and pair the hex escapes with a
// dangerous operation before it is removed.
func removeHexVarInjections(content string) (string, []string) {
	// 3+ consecutive \xNN escapes (regex preserved from original).
	hexRE := regexp.MustCompile(`(?:"\x5c\x78[0-9a-fA-F]{2}){3,}|(?:\\x[0-9a-fA-F]{2}){3,}`)
	dangerMarkers := []string{
		"eval", "system(", "exec(", "base64_decode",
		"assert(", "$_post", "$_request", "$_get",
	}
	var removals []string
	var kept []string
	for idx, ln := range strings.Split(content, "\n") {
		t := strings.TrimSpace(ln)
		if len(t) >= 30 && hexRE.MatchString(t) {
			lc := strings.ToLower(t)
			dangerous := false
			for _, m := range dangerMarkers {
				if strings.Contains(lc, m) {
					dangerous = true
					break
				}
			}
			if dangerous {
				removals = append(removals, fmt.Sprintf("removed hex-encoded variable injection (line %d, %d chars)", idx+1, len(t)))
				continue
			}
		}
		kept = append(kept, ln)
	}
	return strings.Join(kept, "\n"), removals
}
package checks
import (
"crypto/sha256"
"encoding/hex"
"io"
"sync"
)
// CMSHashCache stores SHA256 hashes of verified-clean CMS core files.
// After wp core verify-checksums confirms an installation is clean,
// all its core files are hashed and cached. The real-time scanner
// checks this cache before reporting signature matches - if a file's
// hash is in the cache, it's a known-clean CMS file and signature
// matches on it are false positives.
type CMSHashCache struct {
	// mu guards hashes; RWMutex because lookups (Contains/Size)
	// dominate mutations (Add/Clear).
	mu     sync.RWMutex
	hashes map[string]bool // SHA256 hex → true
}

// Package-level singleton storage. Never access directly - go through
// GlobalCMSCache, which performs the once-only initialization.
var (
	globalCache     *CMSHashCache
	globalCacheOnce sync.Once
)
// GlobalCMSCache returns the singleton cache, creating it on first call.
func GlobalCMSCache() *CMSHashCache {
globalCacheOnce.Do(func() {
globalCache = &CMSHashCache{
hashes: make(map[string]bool),
}
})
return globalCache
}
// Add inserts a file hash into the cache.
func (c *CMSHashCache) Add(hash string) {
c.mu.Lock()
c.hashes[hash] = true
c.mu.Unlock()
}
// Contains checks if a file hash is in the cache.
func (c *CMSHashCache) Contains(hash string) bool {
c.mu.RLock()
ok := c.hashes[hash]
c.mu.RUnlock()
return ok
}
// Size returns the number of cached hashes.
func (c *CMSHashCache) Size() int {
c.mu.RLock()
n := len(c.hashes)
c.mu.RUnlock()
return n
}
// Clear removes all cached hashes (used before rebuilding).
func (c *CMSHashCache) Clear() {
c.mu.Lock()
c.hashes = make(map[string]bool)
c.mu.Unlock()
}
// HashFile computes the SHA256 hash of a file. Returns empty string on error.
func HashFile(path string) string {
f, err := osFS.Open(path)
if err != nil {
return ""
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return ""
}
return hex.EncodeToString(h.Sum(nil))
}
// IsVerifiedCMSFile checks if a file at the given path matches a
// known-clean CMS core file by comparing its SHA256 hash against the cache.
//
// The cache is keyed by SHA256 hash alone (not path+hash) - this is correct:
// - SHA256 preimage resistance makes it computationally infeasible for an
// attacker to craft a file that produces the same hash as a legitimate
// WP core file. Birthday attacks do not apply here because the attacker
// must hit a specific pre-existing hash, not merely find any collision.
// - If file content matches a known WP core file byte-for-byte, it IS that
// file regardless of where it is located on disk. The path is irrelevant
// to whether the content is clean.
// IsVerifiedCMSFile reports whether the file at path is byte-identical
// to a known-clean CMS core file, by comparing its SHA256 against the
// global cache. An empty cache or unreadable file yields false.
//
// The cache is keyed by hash alone (not path+hash), which is sound:
// SHA256 preimage resistance makes forging a specific pre-existing hash
// computationally infeasible (birthday attacks don't apply - the
// attacker must hit one fixed hash, not find any collision), and
// content that matches a known WP core file byte-for-byte IS that file
// regardless of where it sits on disk.
func IsVerifiedCMSFile(path string) bool {
	c := GlobalCMSCache()
	if c.Size() == 0 {
		return false
	}
	if h := HashFile(path); h != "" {
		return c.Contains(h)
	}
	return false
}
package checks
import (
"context"
"encoding/hex"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckOutboundUserConnections looks for non-root user processes making
// outbound connections to IPs that aren't infra or well-known services.
// Catches compromised accounts phoning home.
// CheckOutboundUserConnections looks for non-root user processes making
// outbound connections to IPs that aren't infra or well-known services.
// Catches compromised accounts phoning home.
//
// Both /proc/net/tcp (IPv4) and /proc/net/tcp6 (IPv6) are scanned; only
// ESTABLISHED sockets owned by non-root UIDs are considered.
func CheckOutboundUserConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Known service ports that are always OK for outbound
	safeRemotePorts := map[int]bool{
		53: true, 80: true, 443: true, 25: true, 587: true, 465: true,
		993: true, 995: true, 110: true, 143: true,
	}
	// Known safe service users - system daemons that make outbound connections
	safeUsers := map[string]bool{
		"imunify360-webshield": true,
		"named":                true,
		"mysql":                true,
		"memcached":            true,
		"icinga":               true,
		"dovecot":              true,
		"mailman":              true,
	}
	// Ports this host serves on - if the local side is one of these, we
	// are the server, not the one dialing out. Built once here: the
	// original allocated an identical map literal per ESTABLISHED
	// socket, in both the IPv4 and IPv6 loops.
	knownLocalPorts := map[int]bool{
		21: true, 25: true, 26: true, 53: true, 80: true, 110: true,
		143: true, 443: true, 465: true, 587: true, 993: true, 995: true,
		2082: true, 2083: true, 2086: true, 2087: true, 2095: true, 2096: true,
		3306: true, 4190: true,
		// Imunify360 webshield ports
		52223: true, 52224: true, 52227: true, 52228: true,
		52229: true, 52230: true, 52231: true, 52232: true,
	}
	data, err := osFS.ReadFile("/proc/net/tcp")
	if err != nil {
		return nil
	}
	for _, line := range strings.Split(string(data), "\n") {
		fields := strings.Fields(line)
		if len(fields) < 8 || fields[0] == "sl" {
			continue
		}
		// State 01 = ESTABLISHED
		if fields[3] != "01" {
			continue
		}
		// Get UID (field 7)
		uid := fields[7]
		if uid == "0" {
			continue // skip root
		}
		// Parse local and remote addresses
		_, localPort := parseHexAddr(fields[1])
		remoteIP, remotePort := parseHexAddr(fields[2])
		if remoteIP == "127.0.0.1" || remoteIP == "0.0.0.0" {
			continue
		}
		// Skip if local port is a known service (we're the server)
		if knownLocalPorts[localPort] {
			continue
		}
		// Skip safe remote ports
		if safeRemotePorts[remotePort] {
			continue
		}
		// Skip infra IPs
		if isInfraIP(remoteIP, cfg.InfraIPs) {
			continue
		}
		// Check if this is a known safe service user
		user := uidToUser(uid)
		if safeUsers[user] {
			continue
		}
		// This is a non-root user process connecting to a non-standard
		// port on a non-infra IP - suspicious
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "user_outbound_connection",
			Message:  fmt.Sprintf("Non-root user connecting to unusual destination: %s:%d", remoteIP, remotePort),
			Details:  fmt.Sprintf("UID: %s (%s), Local port: %d", uid, user, localPort),
		})
	}
	// Parse /proc/net/tcp6 for IPv6 connections
	tcp6Data, err := osFS.ReadFile("/proc/net/tcp6")
	if err == nil {
		for _, line := range strings.Split(string(tcp6Data), "\n") {
			fields := strings.Fields(line)
			if len(fields) < 8 || fields[0] == "sl" {
				continue
			}
			// State 01 = ESTABLISHED
			if fields[3] != "01" {
				continue
			}
			uid := fields[7]
			if uid == "0" {
				continue
			}
			_, localPort6 := parseHex6Addr(fields[1])
			remoteIP6, remotePort6 := parseHex6Addr(fields[2])
			if remoteIP6 == nil || remoteIP6.IsLoopback() || remoteIP6.IsUnspecified() {
				continue
			}
			// Skip if local port is a known service (we're the server)
			if knownLocalPorts[localPort6] {
				continue
			}
			if safeRemotePorts[remotePort6] {
				continue
			}
			remoteIPStr := remoteIP6.String()
			if isInfraIP(remoteIPStr, cfg.InfraIPs) {
				continue
			}
			user := uidToUser(uid)
			if safeUsers[user] {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "user_outbound_connection",
				Message:  fmt.Sprintf("Non-root user connecting to unusual destination: [%s]:%d", remoteIPStr, remotePort6),
				Details:  fmt.Sprintf("UID: %s (%s), Local port: %d, Proto: tcp6", uid, user, localPort6),
			})
		}
	}
	return findings
}
// uidToUser tries to resolve a UID to username from /etc/passwd.
func uidToUser(uid string) string {
data, err := osFS.ReadFile("/etc/passwd")
if err != nil {
return uid
}
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Split(line, ":")
if len(fields) >= 3 && fields[2] == uid {
return fields[0]
}
}
return uid
}
// parseHex6Addr parses an IPv6 address:port from /proc/net/tcp6 format.
// IPv6 addresses are 32 hex chars (128 bits) in little-endian 4-byte groups.
func parseHex6Addr(s string) (net.IP, int) {
parts := strings.SplitN(s, ":", 2)
if len(parts) != 2 {
return nil, 0
}
hexIP := parts[0]
hexPort := parts[1]
if len(hexIP) != 32 {
return nil, 0
}
port, _ := strconv.ParseInt(hexPort, 16, 32)
// Parse as 4 little-endian 32-bit words
ip := make(net.IP, 16)
for i := 0; i < 4; i++ {
word := hexIP[i*8 : (i+1)*8]
b, _ := hex.DecodeString(word)
if len(b) != 4 {
return nil, 0
}
// Reverse bytes within each 32-bit word (little-endian to big-endian)
ip[i*4+0] = b[3]
ip[i*4+1] = b[2]
ip[i*4+2] = b[1]
ip[i*4+3] = b[0]
}
return ip, int(port)
}
// CheckSSHDConfig monitors sshd_config for dangerous changes.
func CheckSSHDConfig(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
var findings []alert.Finding
hash, err := hashFileContent(sshdConfigPath)
if err != nil {
return nil
}
current := currentSSHDSettings()
hashKey := "_sshd_config_hash"
passKey := "_sshd_passwordauthentication"
rootKey := "_sshd_permitrootlogin"
prevHash, exists := store.GetRaw(hashKey)
prevPass, _ := store.GetRaw(passKey)
prevRoot, _ := store.GetRaw(rootKey)
if exists && prevHash != hash {
// Only alert when the effective setting changed into a dangerous value.
// This avoids false positives from commented defaults or Match blocks.
if current.PasswordAuthentication == "yes" && prevPass != "yes" {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "sshd_config_change",
Message: "PasswordAuthentication changed to 'yes' in sshd_config",
Details: "This allows password-based SSH login - high risk if passwords are compromised",
})
}
if current.PermitRootLogin == "yes" && prevRoot != "yes" {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "sshd_config_change",
Message: "PermitRootLogin changed to 'yes' in sshd_config",
})
}
// Generic change alert if no specific dangerous setting found
if len(findings) == 0 {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "sshd_config_change",
Message: "sshd_config modified",
})
}
}
store.SetRaw(hashKey, hash)
store.SetRaw(passKey, current.PasswordAuthentication)
store.SetRaw(rootKey, current.PermitRootLogin)
return findings
}
// CheckNulledPlugins scans WordPress plugin directories for signs of
// nulled/pirated plugins: missing licenses, known crack patterns, GPL
// bypass code, and plugins not found on wordpress.org.
// CheckNulledPlugins scans WordPress plugin directories under /home for
// signs of nulled/pirated plugins by looking for known crack/null
// signatures in the head of each plugin's top-level PHP files.
func CheckNulledPlugins(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	// Known crack/null signatures in PHP files
	crackSignatures := []string{
		"nulled by", "cracked by", "gpl-club", "gpldl.com",
		"developer license", "remove license check",
		"license_key_bypass", "activation_bypass",
		"@remove_license", "null_license",
	}
	var findings []alert.Finding
	homes, _ := osFS.ReadDir("/home")
	for _, home := range homes {
		if !home.IsDir() {
			continue
		}
		pluginsRoot := filepath.Join("/home", home.Name(), "public_html", "wp-content", "plugins")
		entries, err := osFS.ReadDir(pluginsRoot)
		if err != nil {
			continue
		}
		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			// Check only the plugin's top-level PHP files, and only
			// their first 10KB - the crack markers sit in headers.
			phpFiles, _ := osFS.Glob(filepath.Join(pluginsRoot, entry.Name(), "*.php"))
			for _, phpFile := range phpFiles {
				head := readFileHead(phpFile, 10*1024)
				if head == nil {
					continue
				}
				lower := strings.ToLower(string(head))
				for _, sig := range crackSignatures {
					if strings.Contains(lower, sig) {
						findings = append(findings, alert.Finding{
							Severity: alert.High,
							Check:    "nulled_plugin",
							Message:  fmt.Sprintf("Possible nulled plugin: %s/%s", home.Name(), entry.Name()),
							Details:  fmt.Sprintf("File: %s\nSignature: %s", phpFile, sig),
						})
						break
					}
				}
			}
		}
	}
	return findings
}
// readFileHead reads the first N bytes of a file.
func readFileHead(path string, maxBytes int) []byte {
f, err := osFS.Open(path)
if err != nil {
return nil
}
defer func() { _ = f.Close() }()
buf := make([]byte, maxBytes)
n, _ := f.Read(buf)
if n == 0 {
return nil
}
return buf[:n]
}
package checks
import (
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// Checks that indicate active security events (not static config issues).
// Only these are used for cross-account correlation.
var securityEventChecks = map[string]bool{
	// Process-based events
	"fake_kernel_thread":       true,
	"suspicious_process":       true,
	"php_suspicious_execution": true,
	// Malicious-file events from periodic scans
	"backdoor_binary":       true,
	"webshell":              true,
	"new_webshell_file":     true,
	"new_executable_in_config": true,
	"new_php_in_uploads":    true,
	"new_php_in_languages":  true,
	"new_php_in_upgrade":    true,
	"obfuscated_php":        true,
	"php_dropper":           true,
	// Real-time (fanotify) file events
	"webshell_realtime":             true,
	"php_in_uploads_realtime":       true,
	"php_in_sensitive_dir_realtime": true,
	"executable_in_config_realtime": true,
	"obfuscated_php_realtime":       true,
	"webshell_content_realtime":     true,
	// Network / credential events
	"c2_connection":               true,
	"cpanel_file_upload_realtime": true,
	"shadow_change":               true,
	"root_password_change":        true,
}
// CorrelateFindings analyzes findings for cross-account attack patterns.
// Only considers active security events, not static config issues
// (like WAF status, open_basedir, world-writable files).
// CorrelateFindings analyzes findings for cross-account attack patterns.
// Only considers active security events, not static config issues
// (like WAF status, open_basedir, world-writable files).
//
// Two patterns are reported:
//   - coordinated_attack: 3+ accounts each have at least one critical
//     security event.
//   - cross_account_malware: the same malware-type check fired in 2+
//     distinct accounts.
func CorrelateFindings(findings []alert.Finding) []alert.Finding {
	var extra []alert.Finding
	// Accounts with at least one critical security event. (The original
	// kept per-account counts and then re-filtered on count > 0, which
	// was always true by construction - a set is sufficient.)
	criticalAccounts := make(map[string]bool)
	for _, f := range findings {
		if f.Severity != alert.Critical || !securityEventChecks[f.Check] {
			continue
		}
		if account := extractAccountFromFinding(f); account != "" {
			criticalAccounts[account] = true
		}
	}
	// If 3+ accounts have critical security events, it's a coordinated attack
	if len(criticalAccounts) >= 3 {
		names := make([]string, 0, len(criticalAccounts))
		for account := range criticalAccounts {
			names = append(names, account)
		}
		extra = append(extra, alert.Finding{
			Severity: alert.Critical,
			Check:    "coordinated_attack",
			Message:  fmt.Sprintf("Possible coordinated attack: %d accounts have critical security events", len(criticalAccounts)),
			Details:  fmt.Sprintf("Affected accounts: %s", strings.Join(names, ", ")),
		})
	}
	// Check for same malware type across accounts
	malwareChecks := map[string]bool{
		"new_executable_in_config": true,
		"backdoor_binary":          true,
		"webshell":                 true,
		"new_webshell_file":        true,
	}
	malwareByCheck := make(map[string][]string)
	for _, f := range findings {
		if !malwareChecks[f.Check] {
			continue
		}
		if account := extractAccountFromFinding(f); account != "" {
			malwareByCheck[f.Check] = append(malwareByCheck[f.Check], account)
		}
	}
	for check, accounts := range malwareByCheck {
		unique := uniqueStrings(accounts)
		if len(unique) >= 2 {
			extra = append(extra, alert.Finding{
				Severity: alert.Critical,
				Check:    "cross_account_malware",
				Message:  fmt.Sprintf("Same malware type (%s) found in %d accounts", check, len(unique)),
				Details:  fmt.Sprintf("Accounts: %s", strings.Join(unique, ", ")),
			})
		}
	}
	return extra
}
// extractAccountFromFinding pulls the cPanel account name out of a
// finding's message or details by locating a "/home/<account>/" path.
// Returns "" when neither field contains one.
func extractAccountFromFinding(f alert.Finding) string {
	for _, text := range []string{f.Message, f.Details} {
		_, after, found := strings.Cut(text, "/home/")
		if !found {
			continue
		}
		if account, _, ok := strings.Cut(after, "/"); ok && account != "" {
			return account
		}
	}
	return ""
}
// uniqueStrings returns input with duplicates removed, preserving the
// order of first appearance.
func uniqueStrings(input []string) []string {
	seen := make(map[string]struct{}, len(input))
	var out []string
	for _, s := range input {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}
package checks
import (
"context"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
const (
	// sessionLogPath is the cPanel session log parsed by CheckCpanelLogins.
	sessionLogPath = "/usr/local/cpanel/logs/session_log"
	// sessionLogTailLines bounds how much of the log is re-read per scan.
	sessionLogTailLines = 1000
	// defaultMultiIPThreshold / defaultMultiIPWindowMin are the fallbacks
	// used when the corresponding config thresholds are unset (<= 0).
	defaultMultiIPThreshold = 3
	defaultMultiIPWindowMin = 60
)
// CheckCpanelLogins parses the cPanel session log for suspicious login activity:
// - cPanel (cpaneld) logins from non-infra IPs
// - Same account logged in from multiple distinct IPs (credential compromise indicator)
// - Password change purge events (attacker or auto-response password resets)
func CheckCpanelLogins(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
var findings []alert.Finding
lines := tailFile(sessionLogPath, sessionLogTailLines)
if len(lines) == 0 {
return nil
}
// Track logins per account for multi-IP correlation
accountIPs := make(map[string]map[string]bool)
var passwordChanges []string
// Determine cutoff - only alert on events within the scan window
cutoff := time.Now().Add(-time.Duration(multiIPWindowMin(cfg)) * time.Minute)
for _, line := range lines {
// Session log format:
// [2026-03-25 07:19:47 +0200] info [cpaneld] 203.0.113.133 NEW user:token address=IP,...
// [2026-03-25 07:38:18 +0200] info [security] internal PURGE user:token password_change
// Parse timestamp
ts := parseSessionTimestamp(line)
if ts.IsZero() || ts.Before(cutoff) {
continue
}
// Detect cPanel logins from non-infra IPs
// Skip API/portal sessions (create_user_session) - only alert on direct form login
if strings.Contains(line, "[cpaneld]") && strings.Contains(line, " NEW ") {
if cfg.Suppressions.SuppressCpanelLogin {
// Still track IPs for multi-IP correlation even when suppressed
} else if strings.Contains(line, "method=create_user_session") ||
strings.Contains(line, "method=create_session") ||
strings.Contains(line, "create_user_session") {
continue
}
ip, account := parseCpanelLogin(line)
if ip == "" || account == "" {
continue
}
if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" || ip == "internal" {
continue
}
// Always track for multi-IP correlation (even when login alerts suppressed)
if accountIPs[account] == nil {
accountIPs[account] = make(map[string]bool)
}
accountIPs[account][ip] = true
// WARNING severity - logins are audit trail, not paging-level.
// Multi-IP correlation stays CRITICAL via its own check below.
if !cfg.Suppressions.SuppressCpanelLogin {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "cpanel_login",
Message: fmt.Sprintf("cPanel direct login from non-infra IP: %s (account: %s)", ip, account),
Details: truncateString(line, 300),
})
}
}
// Detect password change purge events
if strings.Contains(line, "PURGE") && strings.Contains(line, "password_change") {
account := parsePurgeAccount(line)
if account != "" {
passwordChanges = append(passwordChanges, account)
}
}
}
// Multi-IP correlation: same account from 3+ distinct non-infra IPs
threshold := multiIPThreshold(cfg)
for account, ips := range accountIPs {
if len(ips) >= threshold {
ipList := make([]string, 0, len(ips))
for ip := range ips {
ipList = append(ipList, ip)
}
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "cpanel_multi_ip_login",
Message: fmt.Sprintf("Account '%s' logged in from %d distinct IPs (credential compromise likely)", account, len(ips)),
Details: fmt.Sprintf("IPs: %s\nThreshold: %d IPs within %d minutes", strings.Join(ipList, ", "), threshold, multiIPWindowMin(cfg)),
})
}
}
// Password change events - deduplicate by account
seen := make(map[string]bool)
for _, account := range passwordChanges {
if seen[account] {
continue
}
seen[account] = true
// Check if triggered by security module (Imunify auto-response) vs user action
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "cpanel_password_purge",
Message: fmt.Sprintf("cPanel sessions purged via password change for account: %s", account),
Details: "This may indicate an automated security response or attacker-initiated password change",
})
}
return findings
}
// CheckCpanelFileManager parses the cPanel access log for file management
// operations from non-infra IPs (file uploads, edits via cPanel File Manager).
func CheckCpanelFileManager(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
lines := tailFile("/usr/local/cpanel/logs/access_log", 300)
// Only match actual write actions - not read-only calls like get_homedir.
// Skip 401/403 responses - the server rejected the request, no write occurred.
// Match against request URI only, not the full line (referer can contain "upload").
filemanWriteActions := []string{
"fileman/save_file_content",
"fileman/upload_files",
"fileman/save_file",
"fileman/paste",
"fileman/rename",
"fileman/delete",
}
for _, line := range lines {
// Only check cPanel (port 2083) entries
if !strings.Contains(line, "2083") {
continue
}
// Skip rejected requests - no write occurred
if strings.Contains(line, "\" 401 ") || strings.Contains(line, "\" 403 ") {
continue
}
fields := strings.Fields(line)
if len(fields) < 1 {
continue
}
ip := fields[0]
if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
continue
}
// Extract request URI (between first pair of quotes) to avoid
// matching "upload" in referer URLs like upload-ajax.html
requestURI := extractRequestURIChecks(line)
for _, action := range filemanWriteActions {
if strings.Contains(strings.ToLower(requestURI), strings.ToLower(action)) {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "cpanel_file_upload",
Message: fmt.Sprintf("cPanel File Manager write operation from non-infra IP: %s", ip),
Details: truncateString(line, 300),
})
break
}
}
}
return findings
}
// parseSessionTimestamp extracts the timestamp from a session log line.
// Format: [2026-03-25 07:19:47 +0200]
func parseSessionTimestamp(line string) time.Time {
start := strings.Index(line, "[")
end := strings.Index(line, "]")
if start < 0 || end < 0 || end <= start+1 {
return time.Time{}
}
tsStr := line[start+1 : end]
// Try common cPanel session log formats
for _, layout := range []string{
"2006-01-02 15:04:05 -0700",
"2006-01-02 15:04:05 +0000",
} {
if t, err := time.Parse(layout, tsStr); err == nil {
return t
}
}
return time.Time{}
}
// parseCpanelLogin extracts IP and account from a session NEW line.
// Format: [timestamp] info [cpaneld] 203.0.113.133 NEW user:token address=IP,...
func parseCpanelLogin(line string) (ip, account string) {
// Find IP after [cpaneld]
idx := strings.Index(line, "[cpaneld]")
if idx < 0 {
return "", ""
}
rest := strings.TrimSpace(line[idx+len("[cpaneld]"):])
fields := strings.Fields(rest)
if len(fields) < 3 {
return "", ""
}
ip = fields[0]
// Find account from "NEW user:token" or "NEW user:token address=..."
for i, f := range fields {
if f == "NEW" && i+1 < len(fields) {
userToken := fields[i+1]
parts := strings.SplitN(userToken, ":", 2)
if len(parts) >= 1 {
account = parts[0]
}
break
}
}
return ip, account
}
// parsePurgeAccount extracts the account name from a PURGE password_change line.
// Format: [timestamp] info [security] internal PURGE user:token password_change
func parsePurgeAccount(line string) string {
idx := strings.Index(line, "PURGE")
if idx < 0 {
return ""
}
rest := strings.TrimSpace(line[idx+len("PURGE"):])
fields := strings.Fields(rest)
if len(fields) < 1 {
return ""
}
parts := strings.SplitN(fields[0], ":", 2)
if len(parts) >= 1 {
return parts[0]
}
return ""
}
// multiIPThreshold returns the configured multi-IP login threshold,
// falling back to defaultMultiIPThreshold when unset (<= 0).
func multiIPThreshold(cfg *config.Config) int {
	if t := cfg.Thresholds.MultiIPLoginThreshold; t > 0 {
		return t
	}
	return defaultMultiIPThreshold
}
// multiIPWindowMin returns the configured multi-IP correlation window
// in minutes, falling back to defaultMultiIPWindowMin when unset (<= 0).
func multiIPWindowMin(cfg *config.Config) int {
	if w := cfg.Thresholds.MultiIPLoginWindowMin; w > 0 {
		return w
	}
	return defaultMultiIPWindowMin
}
// extractRequestURIChecks extracts the request line from an access log entry.
// Format: ... "METHOD /path HTTP/1.1" ... → returns "METHOD /path HTTP/1.1"
func extractRequestURIChecks(line string) string {
start := strings.Index(line, "\"")
if start < 0 {
return ""
}
end := strings.Index(line[start+1:], "\"")
if end < 0 {
return ""
}
return line[start+1 : start+1+end]
}
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"syscall"
"time"
)
// fixCrontabAllowedRoots limits suspicious_crontab remediation to the cron
// spool. Declared as a var (not const) so tests can redirect it under
// t.TempDir() without touching the real /var/spool/cron. Consumed by
// fixSuspiciousCrontab via resolveExistingFixPath.
var fixCrontabAllowedRoots = []string{"/var/spool/cron"}
// fixSuspiciousCrontab copies a user crontab matching known-bad persistence
// markers into quarantine, writes a restore-ready metadata sidecar, and then
// truncates the live file to empty. Truncation (not deletion) keeps cron(8)
// from re-reading stale content and preserves the caller's ability to
// inspect file perms while the malware is gone.
// fixSuspiciousCrontab copies a user crontab matching known-bad persistence
// markers into quarantine, writes a restore-ready metadata sidecar, and then
// truncates the live file to empty. Truncation (not deletion) keeps cron(8)
// from re-reading stale content and preserves the caller's ability to
// inspect file perms while the malware is gone.
//
// Returns a RemediationResult whose Error field is set (and Success left
// false) at the first failing step; the live crontab is only truncated
// after the quarantine copy has been written successfully.
func fixSuspiciousCrontab(path string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	// Canonicalize and confine the path to the cron spool; also yields
	// the file's stat info for the metadata sidecar.
	path, info, err := resolveExistingFixPath(path, fixCrontabAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	data, err := osFS.ReadFile(path)
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot read: %v", err)}
	}
	// Best-effort mkdir; a failure will surface as a WriteFile error below.
	_ = os.MkdirAll(quarantineDir, 0700)
	ts := time.Now().Format("20060102-150405")
	// Spool files are named after the owning user.
	user := filepath.Base(path)
	qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_crontab_%s", ts, user))
	if err := os.WriteFile(qPath, data, 0600); err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot write quarantine: %v", err)}
	}
	// Capture ownership for the sidecar; zero values are kept if the
	// platform-specific stat is unavailable.
	var uid, gid int
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		uid = int(stat.Uid)
		gid = int(stat.Gid)
	}
	meta := map[string]interface{}{
		"original_path": path,
		"owner_uid":     uid,
		"group_gid":     gid,
		"mode":          info.Mode().String(),
		"size":          info.Size(),
		"quarantine_at": time.Now(),
		"reason":        "suspicious_crontab remediation",
	}
	metaData, _ := json.MarshalIndent(meta, "", "  ")
	// Metadata failure is logged but non-fatal: the quarantine copy
	// itself already exists, so remediation still proceeds.
	if err := os.WriteFile(qPath+".meta", metaData, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "remediate: error writing crontab quarantine metadata %s: %v\n", qPath+".meta", err)
	}
	// Truncate live crontab. 0600 is the mode cron(8) enforces for user
	// spool files; any other mode makes cron skip the file with a warning.
	// #nosec G306 -- cron(8) rejects world-readable user crontabs, so 0600
	// is the only safe mode for /var/spool/cron/<user>.
	if err := os.WriteFile(path, []byte{}, 0600); err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot truncate crontab: %v", err)}
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("quarantined crontab %s -> %s and truncated", path, qPath),
		Description: fmt.Sprintf("Truncated %d-byte crontab; copy saved to quarantine", len(data)),
	}
}
package checks
import (
"context"
"encoding/base64"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckCrontabs scans user crontabs under /var/spool/cron for known-bad
// persistence patterns and tracks root crontab and /etc/cron.d changes
// by content hash in the state store.
func CheckCrontabs(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	spoolFiles, _ := osFS.Glob("/var/spool/cron/*")
	for _, path := range spoolFiles {
		user := filepath.Base(path)
		if user == "root" {
			// Root's crontab is tracked by hash rather than pattern-matched.
			hash, _ := hashFileContent(path)
			const key = "_crontab_root_hash"
			if prev, exists := store.GetRaw(key); exists && prev != hash {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "crontab_change",
					Message:  "Root crontab modified",
					Details:  "Review with: crontab -l",
				})
			}
			store.SetRaw(key, hash)
			continue
		}
		data, err := osFS.ReadFile(path)
		if err != nil {
			continue
		}
		content := string(data)
		for _, pattern := range MatchCrontabPatternsDeep(content) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "suspicious_crontab",
				Message:  fmt.Sprintf("Suspicious pattern in crontab for user %s: %s", user, pattern),
				Details:  fmt.Sprintf("File: %s\nContent:\n%s", path, truncate(content, 500)),
			})
		}
	}
	// New or modified files in /etc/cron.d, tracked per-file by hash.
	cronD, _ := osFS.Glob("/etc/cron.d/*")
	for _, path := range cronD {
		hash, _ := hashFileContent(path)
		key := fmt.Sprintf("_crond:%s", filepath.Base(path))
		if prev, exists := store.GetRaw(key); exists && prev != hash {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "crond_change",
				Message:  fmt.Sprintf("Cron.d file modified: %s", path),
			})
		}
		store.SetRaw(key, hash)
	}
	return findings
}
// truncate shortens s to at most maxLen bytes, appending "..." when
// anything was cut. The limit is in bytes, not runes.
func truncate(s string, maxLen int) string {
	if len(s) > maxLen {
		return s[:maxLen] + "..."
	}
	return s
}
// crontabSuspiciousPatterns is the shared allowlist of case-insensitive
// substrings that mark a crontab line as likely malicious. Single source of
// truth for CheckCrontabs (system scan) and makeAccountCrontabCheck
// (per-account scan) so the two cannot drift apart.
var crontabSuspiciousPatterns = []string{
	// Known malware markers
	"defunct-kernel",
	"SEED PRNG",
	// Encoded-payload execution
	"base64_decode",
	"base64 -d|bash",
	"base64 -d | bash",
	"base64 --decode|bash",
	"eval(",
	// Reverse-shell / tunnel tooling
	"/dev/tcp/",
	"gsocket",
	"gs-netcat",
	"reverse",
	"bash -i",
	"/bin/sh -i",
	"nc -e",
	"ncat -e",
	// Inline interpreter one-liners
	"python -c",
	"perl -e",
}
// matchCrontabPatterns returns patterns from crontabSuspiciousPatterns that
// appear as case-insensitive substrings of content, preserving list order.
func matchCrontabPatterns(content string) []string {
lower := strings.ToLower(content)
var matched []string
for _, pattern := range crontabSuspiciousPatterns {
if strings.Contains(lower, strings.ToLower(pattern)) {
matched = append(matched, pattern)
}
}
return matched
}
// crontabBase64BlobMaxBytes caps a single base64 candidate before decoding.
// 8192 encoded bytes -> ~6KB decoded; comfortably fits any realistic cron
// payload while bounding work on adversarial input.
const crontabBase64BlobMaxBytes = 8192

// crontabBase64BlobMaxCount caps the number of base64 candidates examined
// per crontab. A realistic gsocket cron entry has one outer blob; we
// allow enough headroom for a handful without doing unbounded work.
const crontabBase64BlobMaxCount = 16

// crontabBase64BlobRE matches contiguous standard-alphabet base64 of
// length >= 40 (with optional padding). The 40-char floor avoids matching
// short config IDs and noise like Wordfence cookie names. Compiled once
// at package init; used by MatchCrontabPatternsDeep.
var crontabBase64BlobRE = regexp.MustCompile(`[A-Za-z0-9+/]{40,}={0,2}`)
// MatchCrontabPatternsDeep extends matchCrontabPatterns with a single
// base64-decode pass: candidate base64 blobs are extracted from content,
// decoded, and the decoded bytes are scanned against the same pattern
// list. This catches attackers who wrap the `base64 -d|bash` pipe chain
// in an outer base64 layer so the literal markers never appear in the
// cron file as written. Exactly one decode depth; never recursive.
func MatchCrontabPatternsDeep(content string) []string {
	hits := matchCrontabPatterns(content)
	have := make(map[string]bool, len(hits))
	for _, h := range hits {
		have[h] = true
	}
	for _, blob := range crontabBase64BlobRE.FindAllString(content, crontabBase64BlobMaxCount) {
		// Bound decode work on adversarial input; the cap is a multiple
		// of 4, so truncation keeps the prefix decodable.
		if len(blob) > crontabBase64BlobMaxBytes {
			blob = blob[:crontabBase64BlobMaxBytes]
		}
		decoded, err := base64.StdEncoding.DecodeString(blob)
		if err != nil {
			continue
		}
		for _, h := range matchCrontabPatterns(string(decoded)) {
			if have[h] {
				continue
			}
			have[h] = true
			hits = append(hits, h)
		}
	}
	return hits
}
package checks
import (
"fmt"
"net"
"regexp"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// AutoRespondDBMalware processes database injection findings and takes
// automated action: blocks attacker IPs extracted from WordPress session
// tokens, revokes compromised user sessions, and cleans confirmed
// malicious content from wp_options.
//
// Only acts on high-confidence findings:
//   - db_options_injection with confirmed malicious external script URLs
//   - db_siteurl_hijack (siteurl/home pointing to malicious content)
//
// Does NOT act on:
//   - db_spam_injection (spam posts — needs manual review)
//   - db_post_injection (script in posts — too many FPs from page builders)
//   - db_options_injection without confirmed malicious URLs
//
// Returns nil immediately unless both auto-response and database cleaning
// are enabled in cfg.
func AutoRespondDBMalware(cfg *config.Config, findings []alert.Finding) []alert.Finding {
	if !cfg.AutoResponse.Enabled || !cfg.AutoResponse.CleanDatabase {
		return nil
	}
	var out []alert.Finding
	for _, finding := range findings {
		switch finding.Check {
		case "db_options_injection":
			out = append(out, handleMaliciousOption(cfg, finding)...)
		case "db_siteurl_hijack":
			out = append(out, handleSiteurlHijack(cfg, finding)...)
		}
	}
	return out
}
// handleMaliciousOption checks if a db_options_injection finding contains
// a confirmed malicious external script URL, and if so:
//  1. Extracts attacker IPs from WP sessions and emits block findings
//  2. Revokes sessions for users with non-infra, non-private IPs only
//  3. Backs up and cleans the malicious content from the option
//
// Any failed precondition (unparseable details, implausible option name,
// backup option, unknown database, empty option, no confirmed malicious
// URL) aborts with nil — no DB writes happen unless every gate passes.
func handleMaliciousOption(cfg *config.Config, f alert.Finding) []alert.Finding {
	var actions []alert.Finding
	dbName, optionName := parseDBFindingDetails(f.Details)
	if dbName == "" || optionName == "" {
		return nil
	}
	// Validate option name — must be a plausible WP option name. Also
	// guarantees the name is safe to interpolate into SQL downstream.
	if !isValidOptionName(optionName) {
		return nil
	}
	// Never act on CSM backup options — they preserve original malicious
	// content for recovery. Acting on them causes cascading backup loops.
	if strings.HasPrefix(optionName, "csm_backup_") {
		return nil
	}
	creds := findCredsForDB(dbName)
	if creds.dbName == "" {
		return nil
	}
	prefix := creds.tablePrefix
	if prefix == "" {
		prefix = "wp_"
	}
	// Re-read the FULL option value from the database — the finding's
	// Details field only has a truncated 200-char preview.
	fullValue := readOptionValue(creds, prefix, optionName)
	if fullValue == "" {
		return nil
	}
	// Only act on options with confirmed malicious external script URLs.
	maliciousURL := extractMaliciousScriptURL(fullValue)
	if maliciousURL == "" {
		return nil
	}
	// 1. Extract and block attacker IPs from active WP sessions.
	suspiciousIPs := extractSuspiciousSessionIPs(creds, prefix, cfg.InfraIPs)
	for _, ip := range suspiciousIPs {
		// Emit as auto_block check so AutoBlockIPs processes it.
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_block",
			Message:   fmt.Sprintf("AUTO-BLOCK: %s (active WP session on compromised site, DB: %s)", ip, dbName),
			Timestamp: time.Now(),
		})
	}
	// 2. Revoke sessions only for users with suspicious IPs.
	// This preserves the site admin's session if they're on an infra IP.
	revoked := revokeCompromisedSessions(creds, prefix, cfg.InfraIPs)
	if revoked > 0 {
		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-DB-CLEAN: Revoked %d compromised WordPress sessions (DB: %s)", revoked, dbName),
			Timestamp: time.Now(),
		})
	}
	// 3. Back up the original value, then clean the malicious content.
	cleaned := backupAndCleanOption(creds, prefix, optionName, fullValue, maliciousURL)
	if cleaned {
		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-DB-CLEAN: Removed malicious script from wp_options '%s' (DB: %s, URL: %s)", optionName, dbName, maliciousURL),
			Timestamp: time.Now(),
		})
	}
	return actions
}
// handleSiteurlHijack responds to a siteurl/home hijack finding: attacker
// IPs found in active WP sessions are emitted as auto_block findings, and
// sessions containing suspicious IPs are revoked. The siteurl/home option
// values themselves are deliberately left untouched.
func handleSiteurlHijack(cfg *config.Config, f alert.Finding) []alert.Finding {
	dbName, _ := parseDBFindingDetails(f.Details)
	if dbName == "" {
		return nil
	}
	creds := findCredsForDB(dbName)
	if creds.dbName == "" {
		return nil
	}
	prefix := creds.tablePrefix
	if prefix == "" {
		prefix = "wp_"
	}
	var actions []alert.Finding
	for _, ip := range extractSuspiciousSessionIPs(creds, prefix, cfg.InfraIPs) {
		actions = append(actions, alert.Finding{
			Severity:  alert.Critical,
			Check:     "auto_block",
			Message:   fmt.Sprintf("AUTO-BLOCK: %s (active session on hijacked site, DB: %s)", ip, dbName),
			Timestamp: time.Now(),
		})
	}
	if n := revokeCompromisedSessions(creds, prefix, cfg.InfraIPs); n > 0 {
		actions = append(actions, alert.Finding{
			Severity:  alert.Warning,
			Check:     "auto_response",
			Message:   fmt.Sprintf("AUTO-DB-CLEAN: Revoked %d sessions on hijacked site (DB: %s)", n, dbName),
			Timestamp: time.Now(),
		})
	}
	return actions
}
// --- URL analysis ---

// scriptSrcRe matches <script src="..."> or <script src=...> patterns.
// Accepts https://, http://, and protocol-relative // URLs, since real
// attackers use all three forms to load external payloads. Capture
// group 1 is the URL itself.
var scriptSrcRe = regexp.MustCompile(`(?i)<script[^>]+src\s*=\s*["']?((?:https?:)?//[^"'\s>]+)`)

// knownSafeDomains are legitimate services that embed scripts in wp_options.
// Matching is exact-host or any-subdomain (see isSafeScriptDomain).
var knownSafeDomains = []string{
	// Google properties and CDNs.
	"googletagmanager.com",
	"google-analytics.com",
	"googleapis.com",
	"gstatic.com",
	"google.com",
	// Facebook/Meta.
	"facebook.net",
	"facebook.com",
	"fbcdn.net",
	"connect.facebook.net",
	// Marketing / analytics.
	"chimpstatic.com",
	"mailchimp.com",
	"hotjar.com",
	"clarity.ms",
	// Common JS CDNs.
	"cloudflare.com",
	"cdnjs.cloudflare.com",
	"jquery.com",
	"jsdelivr.net",
	"unpkg.com",
	// WordPress ecosystem.
	"wp.com",
	"wordpress.com",
	"gravatar.com",
	// Live chat / support widgets.
	"tawk.to",
	"crisp.chat",
	"tidio.co",
	"intercom.io",
	"zendesk.com",
	// HubSpot family.
	"hubspot.com",
	"hubspot.net",
	"hs-scripts.com",
	"hs-analytics.net",
	"hsforms.com",
	"mautic.net",
	// Social / sharing.
	"pinterest.com",
	"twitter.com",
	"linkedin.com",
	"addthis.com",
	"sharethis.com",
	// Misc trusted embeds.
	"recaptcha.net",
	"stripe.com",
	"paypal.com",
	"brevo-mail.com",
}
// extractMaliciousScriptURL returns the first <script src="..."> URL in
// content that isAttackerScriptURL classifies as an attacker script, or
// "" when none qualifies.
//
// The classification uses an attack-indicator model (see url_reputation.go):
// a URL flags only when it shows attacker-characteristic markers (raw IP
// host, abused TLD, plaintext HTTP, known-bad exfil host, or no valid
// TLD). The previous allowlist-only model produced HIGH-severity findings
// for legitimate third-party widgets (OneTrust, Issuu, regional video
// embeds, regional tax-form widgets) whose domains were not on the
// allowlist; the attack-indicator model eliminates those false positives
// while still catching the injection patterns attackers actually use.
//
// knownSafeDomains is retained as a fast-path optimisation and operator-
// pre-approved list — see isAttackerScriptURL for the composition order.
func extractMaliciousScriptURL(content string) string {
	for _, m := range scriptSrcRe.FindAllStringSubmatch(content, -1) {
		if len(m) < 2 {
			continue
		}
		if candidate := m[1]; isAttackerScriptURL(candidate) {
			return candidate
		}
	}
	return ""
}
// isSafeScriptDomain reports whether url points at one of the
// knownSafeDomains, either as an exact host match or as any subdomain.
// Handles https://host, http://host, protocol-relative //host, and
// host-with-port forms.
func isSafeScriptDomain(url string) bool {
	host := strings.ToLower(url)
	for _, scheme := range []string{"https://", "http://", "//"} {
		host = strings.TrimPrefix(host, scheme)
	}
	// Cut the path, then any :port suffix.
	if i := strings.IndexByte(host, '/'); i >= 0 {
		host = host[:i]
	}
	if i := strings.IndexByte(host, ':'); i >= 0 {
		host = host[:i]
	}
	for _, safe := range knownSafeDomains {
		if host == safe || strings.HasSuffix(host, "."+safe) {
			return true
		}
	}
	return false
}
// --- Validation ---

// validOptionNameRe accepts alphanumerics plus underscore, hyphen, colon,
// and dot. Anything else (quotes, spaces, SQL metacharacters) is rejected.
var validOptionNameRe = regexp.MustCompile(`^[a-zA-Z0-9_\-:.]+$`)

// isValidOptionName reports whether name is safe for SQL interpolation:
// non-empty, at most 191 characters, and composed solely of allowlisted
// characters.
func isValidOptionName(name string) bool {
	if name == "" || len(name) > 191 {
		return false
	}
	return validOptionNameRe.MatchString(name)
}
// --- DB helpers ---

// parseDBFindingDetails pulls the "Database: " and "Option: " values out
// of a finding's multi-line Details text. Missing fields come back as
// empty strings; when a label repeats, the last occurrence wins.
func parseDBFindingDetails(details string) (dbName, optionName string) {
	for _, raw := range strings.Split(details, "\n") {
		line := strings.TrimSpace(raw)
		// TrimPrefix returns the input unchanged when the prefix is
		// absent, so inequality doubles as the HasPrefix check.
		if v := strings.TrimPrefix(line, "Database: "); v != line {
			dbName = v
		}
		if v := strings.TrimPrefix(line, "Option: "); v != line {
			optionName = v
		}
	}
	return dbName, optionName
}
// findCredsForDB locates the wp-config.php (primary docroot first, then
// addon-domain docroots) whose DB_NAME matches dbName and returns its
// parsed credentials. Returns the zero value when no install matches.
func findCredsForDB(dbName string) wpDBCreds {
	primary, _ := osFS.Glob("/home/*/public_html/wp-config.php")
	addon, _ := osFS.Glob("/home/*/*/wp-config.php")
	for _, path := range append(primary, addon...) {
		if creds := parseWPConfig(path); creds.dbName == dbName {
			return creds
		}
	}
	return wpDBCreds{}
}
// readOptionValue fetches the full option_value for optionName from the
// site's options table. Returns "" when the name fails validation, the
// option is absent, or the query yields no output.
func readOptionValue(creds wpDBCreds, prefix, optionName string) string {
	if !isValidOptionName(optionName) {
		return ""
	}
	query := fmt.Sprintf(
		"SELECT option_value FROM %soptions WHERE option_name='%s' LIMIT 1",
		prefix, escapeSQLString(optionName))
	rows := runMySQLQuery(creds, query)
	if len(rows) > 0 {
		return rows[0]
	}
	return ""
}
// sessionIPRe extracts the "ip" field from PHP-serialized WordPress
// session_tokens meta values, e.g. `"ip";s:11:"203.0.113.9"`. Compiled
// once at package level instead of on every call (the previous version
// recompiled it each invocation).
var sessionIPRe = regexp.MustCompile(`"ip";s:\d+:"([^"]+)"`)

// extractSuspiciousSessionIPs reads every user's session_tokens meta and
// returns the deduplicated session IPs that are NOT infra IPs, not
// private, and not loopback — i.e. candidate attacker addresses.
// Unparseable IP strings are skipped. Order follows first appearance.
func extractSuspiciousSessionIPs(creds wpDBCreds, prefix string, infraIPs []string) []string {
	query := fmt.Sprintf(
		"SELECT meta_value FROM %susermeta WHERE meta_key='session_tokens' AND meta_value != ''",
		prefix)
	lines := runMySQLQuery(creds, query)
	seen := make(map[string]bool)
	var ips []string
	for _, line := range lines {
		for _, m := range sessionIPRe.FindAllStringSubmatch(line, -1) {
			if len(m) < 2 {
				continue
			}
			ip := m[1]
			parsed := net.ParseIP(ip)
			if parsed == nil || parsed.IsLoopback() || parsed.IsPrivate() {
				continue
			}
			if isInfraIP(ip, infraIPs) || seen[ip] {
				continue
			}
			seen[ip] = true
			ips = append(ips, ip)
		}
	}
	return ips
}
// revokeCompromisedSessions clears session_tokens only for WP users whose
// sessions contain non-infra, non-private IPs. Returns count of users revoked.
//
// A user is revoked when at least one session IP is public (parseable,
// not loopback, not private) and not on the operator infra list; ALL of
// that user's sessions are then cleared by blanking the meta_value row.
func revokeCompromisedSessions(creds wpDBCreds, prefix string, infraIPs []string) int {
	// Get user IDs with active sessions.
	query := fmt.Sprintf(
		"SELECT user_id, meta_value FROM %susermeta WHERE meta_key='session_tokens' AND meta_value != ''",
		prefix)
	lines := runMySQLQuery(creds, query)
	// NOTE(review): compiled on every call; could be hoisted to a
	// package-level var if this path ever becomes hot.
	ipRe := regexp.MustCompile(`"ip";s:\d+:"([^"]+)"`)
	revoked := 0
	for _, line := range lines {
		// Rows are tab-separated: user_id<TAB>serialized session blob.
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		userID := strings.TrimSpace(parts[0])
		sessionData := parts[1]
		// Check if this user has any suspicious (non-infra, non-private) IPs.
		hasSuspicious := false
		matches := ipRe.FindAllStringSubmatch(sessionData, -1)
		for _, m := range matches {
			if len(m) < 2 {
				continue
			}
			ip := m[1]
			parsed := net.ParseIP(ip)
			if parsed == nil || parsed.IsLoopback() || parsed.IsPrivate() {
				continue
			}
			if isInfraIP(ip, infraIPs) {
				continue
			}
			hasSuspicious = true
			break
		}
		if hasSuspicious {
			// userID came from MySQL itself, but it is escaped anyway
			// before interpolation as defense in depth.
			revokeQuery := fmt.Sprintf(
				"UPDATE %susermeta SET meta_value='' WHERE user_id=%s AND meta_key='session_tokens'",
				prefix, escapeSQLString(userID))
			runMySQLQuery(creds, revokeQuery)
			revoked++
		}
	}
	return revoked
}
// backupAndCleanOption saves the original value to a backup option, then
// removes malicious script injections from the option value. Returns true
// only when cleaning actually changed the value (and the cleaned value
// was written back).
//
// NOTE(review): the maliciousURL parameter is unused in this body —
// removal is driven entirely by removeMaliciousScripts; it appears to be
// kept for signature parity with callers that pass the detected URL.
func backupAndCleanOption(creds wpDBCreds, prefix, optionName, originalValue, maliciousURL string) bool {
	cleaned := removeMaliciousScripts(originalValue)
	if cleaned == originalValue {
		return false
	}
	// Save original value as a backup option (csm_backup_<name>_<timestamp>).
	// autoload='no' keeps the (possibly large) backup out of WordPress's
	// every-request option load.
	backupName := fmt.Sprintf("csm_backup_%s_%d", optionName, time.Now().Unix())
	if len(backupName) > 191 {
		// Cap at the wp_options option_name column length.
		backupName = backupName[:191]
	}
	backupQuery := fmt.Sprintf(
		"INSERT INTO %soptions (option_name, option_value, autoload) VALUES ('%s', '%s', 'no')",
		prefix, escapeSQLString(backupName), escapeSQLString(originalValue))
	runMySQLQuery(creds, backupQuery)
	// Write the cleaned value.
	updateQuery := fmt.Sprintf(
		"UPDATE %soptions SET option_value='%s' WHERE option_name='%s'",
		prefix, escapeSQLString(cleaned), escapeSQLString(optionName))
	runMySQLQuery(creds, updateQuery)
	return true
}
// --- Script removal ---

// maliciousScriptRe matches the style-break injection pattern:
// </style><script src=...></script><style>
// Matches of this sandwich form are removed unconditionally by
// removeMaliciousScripts.
var maliciousScriptRe = regexp.MustCompile(
	`(?i)</style>\s*<script[^>]*src\s*=\s*[^>]+>\s*</script>\s*<style>`)

// simpleScriptRe matches standalone <script src="..."></script> tags.
// Matches are removed only after the src URL is re-checked for attacker
// indicators (see removeMaliciousScripts).
var simpleScriptRe = regexp.MustCompile(
	`(?i)<script[^>]*src\s*=\s*["']?https?://[^"'\s>]+["']?[^>]*>\s*</script>`)
// removeMaliciousScripts strips malicious <script> injections from content,
// preserving scripts that are not classified as attacker scripts.
//
// Uses the same isAttackerScriptURL predicate as extractMaliciousScriptURL
// so detection and removal stay semantically paired. If the detector
// would not flag a given URL as malicious, the remover must not strip
// it — otherwise an operator running DBCleanOption on an option that
// contains a real injection alongside a legitimate third-party embed
// (OneTrust, Issuu, regional widget) would silently lose the legitimate
// embed along with the attacker's script.
func removeMaliciousScripts(content string) string {
	// Pass one: the </style><script ...></script><style> sandwich is
	// always removed — no legitimate CMS or plugin emits that form.
	content = maliciousScriptRe.ReplaceAllString(content, "")
	// Pass two: standalone script tags are dropped only when their src
	// URL shows attacker indicators.
	stripIfAttacker := func(tag string) string {
		m := scriptSrcRe.FindStringSubmatch(tag)
		if len(m) >= 2 && isAttackerScriptURL(m[1]) {
			return ""
		}
		return tag
	}
	return strings.TrimSpace(simpleScriptRe.ReplaceAllStringFunc(content, stripIfAttacker))
}
// sqlEscaper rewrites the characters MySQL requires escaping inside a
// single-quoted string literal. strings.Replacer performs one
// left-to-right pass over the input, which matches the behavior of the
// previous ReplaceAll chain: the backslash produced for one character is
// never re-escaped.
var sqlEscaper = strings.NewReplacer(
	`\`, `\\`,
	`'`, `\'`,
	"\x00", `\0`,
	"\n", `\n`,
	"\r", `\r`,
	"\x1a", `\Z`,
)

// escapeSQLString escapes special characters for MySQL string interpolation.
func escapeSQLString(s string) string {
	return sqlEscaper.Replace(s)
}
package checks
import (
"fmt"
"regexp"
"strings"
"time"
)
// DBCleanResult describes the outcome of a database cleanup operation
// (option cleaning, user session revocation, or spam deletion) for one
// account. Success=false with an explanatory Message covers both hard
// failures and "nothing to do" preconditions.
type DBCleanResult struct {
	Account     string   // account the operation targeted
	Database    string   // WordPress database name, when resolved
	Action      string   // "clean-option", "revoke-user", "delete-spam"
	Success     bool     // true when the operation (or preview) completed
	Message     string   // one-line human-readable outcome
	Details     []string // individual actions taken
	BackupNames []string // names of backup options created
}
// DBCleanOption removes malicious script injections from a wp_option value.
// Creates a backup option before modifying. Returns a result describing
// what was done. If preview is true, reports what would be done without
// modifying the database.
func DBCleanOption(account, optionName string, preview bool) DBCleanResult {
	result := DBCleanResult{
		Account: account,
		Action:  "clean-option",
	}
	// Reject names that would be unsafe to interpolate into SQL.
	if !isValidOptionName(optionName) {
		result.Message = fmt.Sprintf("Invalid option name: %q", optionName)
		return result
	}
	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName
	// Read current value.
	value := readOptionValue(creds, prefix, optionName)
	if value == "" {
		result.Message = fmt.Sprintf("Option %q not found or empty in %s", optionName, creds.dbName)
		return result
	}
	// Check for malicious content.
	maliciousURL := extractMaliciousScriptURL(value)
	if maliciousURL == "" {
		result.Message = fmt.Sprintf("No malicious external script found in %q", optionName)
		result.Details = append(result.Details, "Option exists but contains no confirmed malicious URLs")
		return result
	}
	cleaned := removeMaliciousScripts(value)
	if cleaned == value {
		result.Message = "Content unchanged after cleaning"
		return result
	}
	result.Details = append(result.Details, fmt.Sprintf("Malicious URL: %s", maliciousURL))
	result.Details = append(result.Details, fmt.Sprintf("Original length: %d, Cleaned length: %d", len(value), len(cleaned)))
	if preview {
		result.Message = fmt.Sprintf("PREVIEW: Would clean malicious script from %q", optionName)
		result.Success = true
		return result
	}
	// Backup and clean.
	if backupAndCleanOption(creds, prefix, optionName, value, maliciousURL) {
		// NOTE(review): the backup name is reconstructed here with a
		// second time.Now().Unix() call; if the clock ticks over between
		// backupAndCleanOption's internal call and this one, the reported
		// name differs from the one actually written. Consider having
		// backupAndCleanOption return the name it used.
		backupName := fmt.Sprintf("csm_backup_%s_%d", optionName, time.Now().Unix())
		if len(backupName) > 191 {
			backupName = backupName[:191]
		}
		result.BackupNames = append(result.BackupNames, backupName)
		result.Details = append(result.Details, fmt.Sprintf("Backup saved as: %s", backupName))
		result.Message = fmt.Sprintf("Cleaned malicious script from %q", optionName)
		result.Success = true
	} else {
		result.Message = "Failed to clean option"
	}
	return result
}
// DBRevokeUser revokes WordPress sessions for a specific user and optionally
// demotes them to subscriber role. If preview is true, reports what would be
// done without modifying the database.
//
// userID is interpolated with %d, so it cannot carry SQL injection; all
// queries run with root authentication via runMySQLQueryRoot.
func DBRevokeUser(account string, userID int, demote, preview bool) DBCleanResult {
	result := DBCleanResult{
		Account: account,
		Action:  "revoke-user",
	}
	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName
	// Verify user exists.
	query := fmt.Sprintf(
		"SELECT user_login, user_email FROM %susers WHERE ID=%d LIMIT 1",
		prefix, userID)
	lines := runMySQLQueryRoot(creds.dbName, query)
	if len(lines) == 0 {
		result.Message = fmt.Sprintf("User ID %d not found in %s", userID, creds.dbName)
		return result
	}
	// Row is tab-separated: login<TAB>email (email may be absent).
	parts := strings.SplitN(lines[0], "\t", 2)
	login := parts[0]
	email := ""
	if len(parts) > 1 {
		email = parts[1]
	}
	result.Details = append(result.Details, fmt.Sprintf("User: %s (email: %s)", login, email))
	// Check current sessions. Only the first 200 bytes are fetched, so
	// the expiration-count below is an estimate for large session blobs.
	sessQuery := fmt.Sprintf(
		"SELECT LEFT(meta_value, 200) FROM %susermeta WHERE user_id=%d AND meta_key='session_tokens'",
		prefix, userID)
	sessLines := runMySQLQueryRoot(creds.dbName, sessQuery)
	sessionCount := 0
	if len(sessLines) > 0 && sessLines[0] != "" {
		sessionCount = strings.Count(sessLines[0], `"expiration"`)
	}
	result.Details = append(result.Details, fmt.Sprintf("Active sessions: %d", sessionCount))
	if preview {
		msg := fmt.Sprintf("PREVIEW: Would revoke %d sessions for user %s (ID %d)", sessionCount, login, userID)
		if demote {
			msg += " and demote to subscriber"
		}
		result.Message = msg
		result.Success = true
		return result
	}
	// Revoke sessions by blanking the session_tokens meta value.
	revokeQuery := fmt.Sprintf(
		"UPDATE %susermeta SET meta_value='' WHERE user_id=%d AND meta_key='session_tokens'",
		prefix, userID)
	runMySQLQueryRoot(creds.dbName, revokeQuery)
	result.Details = append(result.Details, "Sessions revoked")
	// Demote to subscriber.
	if demote {
		// Read current capabilities to find the meta_key (varies by prefix).
		capQuery := fmt.Sprintf(
			"SELECT meta_key FROM %susermeta WHERE user_id=%d AND meta_key LIKE '%%capabilities'",
			prefix, userID)
		capLines := runMySQLQueryRoot(creds.dbName, capQuery)
		if len(capLines) > 0 {
			capKey := capLines[0]
			demoteQuery := fmt.Sprintf(
				"UPDATE %susermeta SET meta_value='a:1:{s:10:\"subscriber\";b:1;}' WHERE user_id=%d AND meta_key='%s'",
				prefix, userID, escapeSQLString(capKey))
			runMySQLQueryRoot(creds.dbName, demoteQuery)
			result.Details = append(result.Details, "Demoted to subscriber role")
		}
	}
	result.Message = fmt.Sprintf("Revoked sessions for user %s (ID %d)", login, userID)
	result.Success = true
	return result
}
// spamPostIDRe validates that a value returned by the spam-post ID query
// is a bare decimal integer before it is spliced into a DELETE ... IN
// list. Hoisted to package level: the previous version recompiled this
// regex inside the per-batch loop, once per batch of every pattern.
var spamPostIDRe = regexp.MustCompile(`^\d+$`)

// DBDeleteSpam deletes published posts matching spam patterns from a WordPress
// database. Only deletes posts of type 'post' with status 'publish' to avoid
// touching pages, attachments, or plugin data. If preview is true, reports
// counts without deleting.
//
// Per-pattern counts can overlap (a post matching two patterns is counted
// twice), so the preview total is an upper bound — hence "up to N".
func DBDeleteSpam(account string, preview bool) DBCleanResult {
	result := DBCleanResult{
		Account: account,
		Action:  "delete-spam",
	}
	creds, prefix := findCredsForAccount(account)
	if creds.dbName == "" {
		result.Message = fmt.Sprintf("No WordPress database found for account %q", account)
		return result
	}
	result.Database = creds.dbName
	// Spam keyword patterns. sqlLike values are fixed trusted literals,
	// so direct interpolation below is safe.
	patterns := []struct {
		keyword string
		sqlLike string
	}{
		{"casino", "%casino-%"},
		{"betting", "%betting%"},
		{"cialis", "%cialis%"},
		{"viagra", "%viagra%"},
		{"pharma", "%pharma%"},
		{"buy-cheap", "%buy-cheap-%"},
		{"crack-serial", "%crack-serial%"},
		{"free-download", "%free-download%"},
	}
	// Count spam posts by pattern for the details list and preview total.
	totalCount := 0
	for _, p := range patterns {
		countQuery := "SELECT COUNT(*) FROM " + prefix + "posts WHERE post_type='post' AND post_status='publish' AND (post_content LIKE '" + p.sqlLike + "' OR post_title LIKE '" + p.sqlLike + "')"
		lines := runMySQLQueryRoot(creds.dbName, countQuery)
		if len(lines) > 0 {
			var count int
			if _, err := fmt.Sscanf(lines[0], "%d", &count); err == nil && count > 0 {
				result.Details = append(result.Details, fmt.Sprintf("%s: %d posts", p.keyword, count))
				totalCount += count
			}
		}
	}
	if totalCount == 0 {
		result.Message = "No spam posts found"
		result.Success = true
		return result
	}
	if preview {
		result.Message = fmt.Sprintf("PREVIEW: Would delete up to %d spam posts", totalCount)
		result.Success = true
		return result
	}
	// Delete spam posts (and their revisions/meta).
	deleted := 0
	for _, p := range patterns {
		// Get IDs of matching posts.
		idQuery := "SELECT ID FROM " + prefix + "posts WHERE post_type='post' AND post_status='publish' AND (post_content LIKE '" + p.sqlLike + "' OR post_title LIKE '" + p.sqlLike + "')"
		idLines := runMySQLQueryRoot(creds.dbName, idQuery)
		if len(idLines) == 0 {
			continue
		}
		// Delete in batches of 100 so the IN (...) lists stay short.
		for i := 0; i < len(idLines); i += 100 {
			end := i + 100
			if end > len(idLines) {
				end = len(idLines)
			}
			// Keep only purely numeric IDs — anything else would be a
			// SQL-injection vector when joined into the IN list.
			var validIDs []string
			for _, id := range idLines[i:end] {
				id = strings.TrimSpace(id)
				if spamPostIDRe.MatchString(id) {
					validIDs = append(validIDs, id)
				}
			}
			if len(validIDs) == 0 {
				continue
			}
			idList := strings.Join(validIDs, ",")
			// Delete postmeta for these posts.
			runMySQLQueryRoot(creds.dbName, fmt.Sprintf(
				"DELETE FROM %spostmeta WHERE post_id IN (%s)", prefix, idList))
			// Delete revisions.
			runMySQLQueryRoot(creds.dbName, fmt.Sprintf(
				"DELETE FROM %sposts WHERE post_parent IN (%s) AND post_type='revision'",
				prefix, idList))
			// Delete the posts themselves (re-checking type/status).
			runMySQLQueryRoot(creds.dbName, fmt.Sprintf(
				"DELETE FROM %sposts WHERE ID IN (%s) AND post_type='post' AND post_status='publish'",
				prefix, idList))
			deleted += len(validIDs)
		}
	}
	result.Message = fmt.Sprintf("Deleted %d spam posts and their metadata", deleted)
	result.Success = true
	return result
}
// FormatDBCleanResult renders r for terminal output: a status line
// ([OK]/[FAILED] with action and message), then the database name and
// each detail line when present.
func FormatDBCleanResult(r DBCleanResult) string {
	status := "FAILED"
	if r.Success {
		status = "OK"
	}
	var b strings.Builder
	fmt.Fprintf(&b, "[%s] %s — %s\n", status, r.Action, r.Message)
	if r.Database != "" {
		fmt.Fprintf(&b, " Database: %s\n", r.Database)
	}
	for _, detail := range r.Details {
		fmt.Fprintf(&b, " %s\n", detail)
	}
	return b.String()
}
// --- helpers ---

// findCredsForAccount resolves WordPress DB credentials for a cPanel
// account, trying the primary docroot first and then addon-domain
// docroots. The returned creds are rewritten for root authentication via
// /root/.my.cnf — wp-config.php passwords are unreliable because cPanel
// password rotations don't always update the file — so user/pass are
// cleared and host forced to localhost. The second return value is the
// table prefix (defaulting to "wp_").
func findCredsForAccount(account string) (wpDBCreds, string) {
	candidates := []string{
		fmt.Sprintf("/home/%s/public_html/wp-config.php", account),
	}
	addon, _ := osFS.Glob(fmt.Sprintf("/home/%s/*/wp-config.php", account))
	candidates = append(candidates, addon...)
	for _, path := range candidates {
		creds := parseWPConfig(path)
		if creds.dbName == "" {
			continue
		}
		prefix := creds.tablePrefix
		if prefix == "" {
			prefix = "wp_"
		}
		creds.dbUser = ""
		creds.dbPass = ""
		creds.dbHost = "localhost"
		return creds, prefix
	}
	return wpDBCreds{}, ""
}
// runMySQLQueryRoot executes query against dbName using the root MySQL
// credentials picked up from /root/.my.cnf (no explicit -u/-p arguments).
// Output is returned as trimmed, non-empty lines; execution errors yield
// nil (best-effort, consistent with runMySQLQuery).
func runMySQLQueryRoot(dbName, query string) []string {
	out, err := runCmd("mysql", "-N", "-B", dbName, "-e", query)
	if err != nil || out == nil {
		return nil
	}
	var rows []string
	for _, raw := range strings.Split(string(out), "\n") {
		if row := strings.TrimSpace(raw); row != "" {
			rows = append(rows, row)
		}
	}
	return rows
}
package checks
import (
"bufio"
"context"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// Malicious patterns in WordPress database content.
//
// requiresExternalScript: when true, a matching row is only reported if
// its content also contains a <script src=...> pointing at a domain NOT
// on the known-safe list. This filters out the legitimate analytics and
// widget embeds that site owners place in page content (Google Tag
// Manager, Google merchant badge, HubSpot, Mailchimp, etc.) without
// weakening detection of attacker-injected external loaders.
var dbMalwarePatterns = []struct {
	pattern                string         // literal substring used as a SQL LIKE pre-filter
	severity               alert.Severity // severity of the emitted finding
	desc                   string         // human-readable finding description
	requiresExternalScript bool           // gate on a non-safe external <script src>
}{
	// The script-tag entry catches BOTH inline <script> blocks and
	// <script src=...> loaders as a fast LIKE pre-filter; the Go post-
	// filter (hasMaliciousExternalScript) verifies the presence of a
	// non-safe-domain external src before raising a finding. Inline
	// obfuscation without an external src is caught by the subsequent
	// code-pattern entries below.
	{"<script", alert.High, "injected <script> tag with non-safe external src", true},
	{"eval(", alert.High, "eval() in database content", false},
	{"base64_decode", alert.High, "base64_decode in database content", false},
	{"document.write(", alert.High, "document.write injection", false},
	{"String.fromCharCode", alert.High, "JavaScript obfuscation (fromCharCode)", false},
	{".workers.dev", alert.Critical, "Cloudflare Workers exfiltration URL", false},
	{"gist.githubusercontent.com", alert.Critical, "GitHub Gist payload URL", false},
	{"pastebin.com/raw", alert.Critical, "Pastebin payload URL", false},
}
// CheckDatabaseContent scans every WordPress install under /home for
// database-level compromise: hijacked siteurl/home options, injected
// scripts or malware in posts, and rogue admin accounts. Installs whose
// wp-config.php yields no DB name or user are skipped.
func CheckDatabaseContent(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	wpConfigs, _ := osFS.Glob("/home/*/public_html/wp-config.php")
	var findings []alert.Finding
	for _, wpConfig := range wpConfigs {
		account := extractUser(filepath.Dir(wpConfig))
		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" || creds.dbUser == "" {
			continue
		}
		prefix := creds.tablePrefix
		if prefix == "" {
			prefix = "wp_"
		}
		// 1. wp_options: siteurl/home hijacking and injected scripts.
		findings = append(findings, checkWPOptions(account, creds, prefix)...)
		// 2. wp_posts: injected scripts/malware.
		findings = append(findings, checkWPPosts(account, creds, prefix)...)
		// 3. wp_users: rogue admin accounts.
		findings = append(findings, checkWPUsers(account, creds, prefix)...)
	}
	return findings
}
// wpDBCreds holds the database connection settings parsed from a site's
// wp-config.php.
type wpDBCreds struct {
	dbName      string // DB_NAME define
	dbUser      string // DB_USER define
	dbPass      string // DB_PASSWORD define
	dbHost      string // DB_HOST define; defaults to "localhost" when absent
	tablePrefix string // $table_prefix assignment; callers default "" to "wp_"
}
// parseWPConfig extracts the DB_NAME/DB_USER/DB_PASSWORD/DB_HOST defines
// and the $table_prefix assignment from a wp-config.php. Unreadable files
// yield the zero value; a missing DB_HOST defaults to "localhost".
func parseWPConfig(path string) wpDBCreds {
	f, err := osFS.Open(path)
	if err != nil {
		return wpDBCreds{}
	}
	defer func() { _ = f.Close() }()
	var creds wpDBCreds
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		line := sc.Text()
		// Matches lines of the form: define( 'DB_NAME', 'value' );
		if v := extractDefine(line, "DB_NAME"); v != "" {
			creds.dbName = v
		}
		if v := extractDefine(line, "DB_USER"); v != "" {
			creds.dbUser = v
		}
		if v := extractDefine(line, "DB_PASSWORD"); v != "" {
			creds.dbPass = v
		}
		if v := extractDefine(line, "DB_HOST"); v != "" {
			creds.dbHost = v
		}
		// Matches: $table_prefix = 'wp_';
		if strings.Contains(line, "$table_prefix") {
			if v := extractPHPString(line); v != "" {
				creds.tablePrefix = v
			}
		}
	}
	if creds.dbHost == "" {
		creds.dbHost = "localhost"
	}
	return creds
}
// extractDefine pulls the value out of a PHP line of the form
// define( 'KEY', 'value' );. Returns "" when the line does not mention
// key or is commented out.
func extractDefine(line, key string) string {
	idx := strings.Index(line, key)
	if idx < 0 {
		return ""
	}
	// Ignore commented-out defines.
	trimmed := strings.TrimSpace(line)
	if strings.HasPrefix(trimmed, "//") ||
		strings.HasPrefix(trimmed, "#") ||
		strings.HasPrefix(trimmed, "/*") {
		return ""
	}
	// After the literal key, step past the first comma so
	// extractPHPString picks up the VALUE's opening quote rather than
	// the KEY's trailing closing quote. Without this, on input
	//     define( 'DB_NAME', 'wordpress_db' );
	// extractPHPString would see `', 'wordpress_db' );` and return
	// `, ` — the substring between the closing quote of 'DB_NAME' and
	// the opening quote of 'wordpress_db'. Every real WordPress
	// install's wp-config.php triggered this, which silently broke
	// the entire WP database scan check.
	rest := line[idx+len(key):]
	if comma := strings.Index(rest, ","); comma >= 0 {
		rest = rest[comma+1:]
	}
	return extractPHPString(rest)
}
// extractPHPString returns the first complete single- or double-quoted
// string in s. Single quotes are tried first; a quote style with no
// matching closing quote is skipped. Returns "" when neither style
// yields a complete quoted value.
func extractPHPString(s string) string {
	for _, q := range []byte{'\'', '"'} {
		open := strings.IndexByte(s, q)
		if open < 0 {
			continue
		}
		body := s[open+1:]
		if end := strings.IndexByte(body, q); end >= 0 {
			return body[:end]
		}
	}
	return ""
}
// runMySQLQuery executes query with the site's own credentials and
// returns the output as trimmed, non-empty lines. The password travels
// via the MYSQL_PWD environment variable rather than argv, so it never
// shows up in process listings. Execution errors yield nil.
func runMySQLQuery(creds wpDBCreds, query string) []string {
	args := []string{
		"-N", "-B", // no headers, tab-separated
		"-u", creds.dbUser,
		"-h", creds.dbHost,
		creds.dbName,
		"-e", query,
	}
	out, err := runCmdWithEnv("mysql", args, "MYSQL_PWD="+creds.dbPass)
	if err != nil || out == nil {
		return nil
	}
	var rows []string
	for _, raw := range strings.Split(string(out), "\n") {
		if row := strings.TrimSpace(raw); row != "" {
			rows = append(rows, row)
		}
	}
	return rows
}
// checkWPOptions checks for siteurl/home hijacking and injected JavaScript.
//
// user is the cPanel account name (used only in finding text); creds and
// prefix identify the WordPress database/table set. Three passes over the
// {prefix}options table, each emitting Critical findings:
//  1. siteurl/home values containing "eval(" or "<script",
//  2. any option with an external <script src=...> whose URL is not on
//     the known-safe list (via extractMaliciousScriptURL),
//  3. core display options that contain an inline "<script" at all.
func checkWPOptions(user string, creds wpDBCreds, prefix string) []alert.Finding {
	var findings []alert.Finding
	// Check siteurl and home for hijacking
	query := fmt.Sprintf(
		"SELECT option_name, option_value FROM %soptions WHERE option_name IN ('siteurl', 'home', 'admin_email') LIMIT 10",
		prefix)
	lines := runMySQLQuery(creds, query)
	for _, line := range lines {
		// mysql -B rows are tab-separated: option_name <TAB> option_value.
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		optName := parts[0]
		optValue := strings.ToLower(parts[1])
		// Check if siteurl/home points to a different domain
		if (optName == "siteurl" || optName == "home") &&
			(strings.Contains(optValue, "eval(") || strings.Contains(optValue, "<script")) {
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "db_siteurl_hijack",
				Message:  fmt.Sprintf("WordPress %s contains malicious code (account: %s)", optName, user),
				// Details preserve the original (non-lowercased) value.
				Details: fmt.Sprintf("Database: %s\n%s = %s", creds.dbName, optName, truncateDB(parts[1], 200)),
			})
		}
	}
	// Path 1: External script URLs in any option — only flag non-safe domains.
	query = fmt.Sprintf(
		"SELECT option_name, option_value FROM %soptions WHERE option_value LIKE '%%<script%%src=%%' LIMIT 20",
		prefix)
	lines = runMySQLQuery(creds, query)
	for _, line := range lines {
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		optName := parts[0]
		optValue := parts[1]
		// Skip CSM backup options — they preserve the original malicious
		// content for recovery and should not be re-detected/re-cleaned.
		if strings.HasPrefix(optName, "csm_backup_") {
			continue
		}
		maliciousURL := extractMaliciousScriptURL(optValue)
		if maliciousURL == "" {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "db_options_injection",
			Message:  fmt.Sprintf("Malicious script injection in wp_options '%s' (account: %s)", optName, user),
			Details:  fmt.Sprintf("Database: %s\nOption: %s\nMalicious URL: %s\nContent preview: %s", creds.dbName, optName, maliciousURL, truncateDB(optValue, 200)),
		})
	}
	// Path 2: Inline script/code injection in core WP options that should
	// NEVER contain JavaScript (siteurl, home, blogname, blogdescription).
	coreOpts := "siteurl', 'home', 'blogname', 'blogdescription', 'admin_email"
	codePatterns := "<script"
	// LEFT(...,500) bounds the value size fetched from the server.
	query = fmt.Sprintf(
		"SELECT option_name, LEFT(option_value, 500) FROM %soptions WHERE option_name IN ('%s') AND option_value LIKE '%%%s%%'",
		prefix, coreOpts, codePatterns)
	lines = runMySQLQuery(creds, query)
	for _, line := range lines {
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "db_options_injection",
			Message:  fmt.Sprintf("Malicious content in core wp_option '%s' (account: %s)", parts[0], user),
			Details:  fmt.Sprintf("Database: %s\nOption: %s\nContent preview: %s", creds.dbName, parts[0], truncateDB(parts[1], 200)),
		})
	}
	return findings
}
// checkWPPosts checks post content for injected scripts and malware.
//
// user is the cPanel account name (finding text only); creds/prefix
// identify the WordPress database and table prefix.
//
// Two classes of false positive are suppressed compared to a naive LIKE-
// based scan:
//
//   - post_types used for plugin-managed storage (form submissions,
//     revisions, templates, minified bundles) are excluded via the
//     shared nonScannablePostTypes denylist. See dbscan_filters.go for
//     the rationale and the full list.
//
//   - Patterns that match too broadly at the SQL layer (the bare
//     <script substring, and bare-word spam keywords like "cialis")
//     are post-filtered in Go against word-boundary regexes and the
//     known-safe-domain list. Legitimate analytics embeds and
//     substring coincidences ("specialist" containing "cialis") no
//     longer produce findings.
//
// The denylist is defense-in-depth: custom post_types created by a
// theme or plugin remain in scope, so attackers cannot evade by
// inventing a new post_type value.
func checkWPPosts(user string, creds wpDBCreds, prefix string) []alert.Finding {
	var findings []alert.Finding
	postTypeExcl := nonScannablePostTypesSQLList()
	for _, mp := range dbMalwarePatterns {
		// Select ID and content so we can post-filter in Go for the
		// patterns that require it. ID comes first so that if the
		// MySQL client wraps long content across lines we can still
		// join reliably on the first tab.
		query := fmt.Sprintf(
			"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND (post_content LIKE '%%%s%%' OR post_content_filtered LIKE '%%%s%%') LIMIT 20",
			prefix, postTypeExcl, mp.pattern, mp.pattern)
		lines := runMySQLQuery(creds, query)
		if len(lines) == 0 {
			continue
		}
		var confirmedIDs []string
		for _, line := range lines {
			parts := strings.SplitN(line, "\t", 2)
			if len(parts) < 2 {
				continue
			}
			postID := parts[0]
			content := parts[1]
			if mp.requiresExternalScript && !hasMaliciousExternalScript(content) {
				// All scripts in this post are inline, or point at a
				// known-safe widget host. Skip.
				continue
			}
			confirmedIDs = append(confirmedIDs, postID)
			// Cap the ID list so finding Details stay readable; one
			// finding per pattern is emitted regardless of row count.
			if len(confirmedIDs) >= 5 {
				break
			}
		}
		if len(confirmedIDs) == 0 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: mp.severity,
			Check:    "db_post_injection",
			Message:  fmt.Sprintf("WordPress posts contain %s (account: %s, %d posts)", mp.desc, user, len(confirmedIDs)),
			Details: fmt.Sprintf("Database: %s\nAffected post IDs: %s\nPattern: %s",
				creds.dbName, strings.Join(confirmedIDs, ", "), mp.pattern),
		})
	}
	// Spam keyword scan. Three-layer filter:
	//
	//  1. SQL LIKE as a fast server-side pre-filter (reduces rows).
	//  2. Word-boundary regex in countCloakedSpamMatches (rejects
	//     substring false positives like "specialist" / "cialis").
	//  3. SEO-context requirement in contentHasSpamContext: a keyword
	//     hit only counts when accompanied by CSS cloaking, an
	//     injection fingerprint, or an external anchor whose URL
	//     path contains the keyword. Bare prose mentions (industry
	//     verticals, advisor bios, product catalogs listing a
	//     pharmaceutical supply chain) do not fire.
	//
	// The context requirement catches the real attack pattern — hidden
	// off-screen div with external commercial link — while leaving
	// legitimate content silent. See spam_context.go for the full
	// signal catalog.
	for _, sp := range dbSpamPatterns {
		query := fmt.Sprintf(
			"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND post_content LIKE '%s' LIMIT 200",
			prefix, postTypeExcl, sp.likeFragment)
		lines := runMySQLQuery(creds, query)
		if len(lines) == 0 {
			continue
		}
		// Keep only the content column; the ID is not reported for
		// spam findings.
		contents := make([]string, 0, len(lines))
		for _, line := range lines {
			parts := strings.SplitN(line, "\t", 2)
			if len(parts) < 2 {
				continue
			}
			contents = append(contents, parts[1])
		}
		n := countCloakedSpamMatches(sp, contents)
		if n == 0 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "db_spam_injection",
			Message:  fmt.Sprintf("WordPress posts contain cloaked spam keyword '%s' (%d posts, account: %s)", sp.keyword, n, user),
			Details:  fmt.Sprintf("Database: %s", creds.dbName),
		})
	}
	return findings
}
// checkWPUsers checks for rogue admin accounts created recently.
//
// Two passes over {prefix}users joined with {prefix}usermeta:
//  1. administrator accounts registered within the last 7 days (Critical),
//  2. administrator accounts whose email contains a disposable-mail
//     provider substring (High).
//
// user is the cPanel account name, used only in finding messages.
func checkWPUsers(user string, creds wpDBCreds, prefix string) []alert.Finding {
	var findings []alert.Finding
	// Find admin users created in the last 7 days
	query := fmt.Sprintf(
		"SELECT u.ID, u.user_login, u.user_email, u.user_registered FROM %susers u "+
			"INNER JOIN %susermeta m ON u.ID = m.user_id "+
			"WHERE m.meta_key = '%scapabilities' AND m.meta_value LIKE '%%administrator%%' "+
			"AND u.user_registered >= DATE_SUB(NOW(), INTERVAL 7 DAY) "+
			"LIMIT 10",
		prefix, prefix, prefix)
	lines := runMySQLQuery(creds, query)
	for _, line := range lines {
		parts := strings.SplitN(line, "\t", 4)
		// Require ID, login and email; user_registered is optional
		// (safeGet returns "" when absent).
		if len(parts) < 3 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "db_rogue_admin",
			Message:  fmt.Sprintf("New WordPress admin account created in last 7 days: %s (account: %s)", parts[1], user),
			Details: fmt.Sprintf("Database: %s\nUser ID: %s\nLogin: %s\nEmail: %s\nRegistered: %s",
				creds.dbName, parts[0], parts[1], parts[2], safeGet(parts, 3)),
		})
	}
	// Check for admin users with suspicious email patterns
	// (disposable/temporary email providers). The list is constant, so
	// it is built once here instead of being reallocated on every row
	// of the loop below.
	suspiciousDomains := []string{
		"tempmail", "guerrillamail", "mailinator", "throwaway",
		"yopmail", "sharklasers", "trashmail", "maildrop",
	}
	query = fmt.Sprintf(
		"SELECT u.user_login, u.user_email FROM %susers u "+
			"INNER JOIN %susermeta m ON u.ID = m.user_id "+
			"WHERE m.meta_key = '%scapabilities' AND m.meta_value LIKE '%%administrator%%' "+
			"LIMIT 50",
		prefix, prefix, prefix)
	lines = runMySQLQuery(creds, query)
	for _, line := range lines {
		parts := strings.SplitN(line, "\t", 2)
		if len(parts) != 2 {
			continue
		}
		email := strings.ToLower(parts[1])
		for _, sd := range suspiciousDomains {
			if strings.Contains(email, sd) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "db_suspicious_admin_email",
					Message:  fmt.Sprintf("WordPress admin '%s' has disposable email (account: %s)", parts[0], user),
					Details:  fmt.Sprintf("Database: %s\nEmail: %s", creds.dbName, email),
				})
				break
			}
		}
	}
	return findings
}
// safeGet returns parts[idx], or "" when idx is past the end of parts.
func safeGet(parts []string, idx int) string {
	if idx >= len(parts) {
		return ""
	}
	return parts[idx]
}
// truncateDB shortens s to at most maxLen bytes for alert previews,
// appending "..." when truncation occurred.
//
// Fix: the cut point is moved back to a UTF-8 rune boundary. A plain
// s[:maxLen] can split a multi-byte sequence, emitting invalid UTF-8
// into finding Details (post content is frequently non-ASCII).
func truncateDB(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// 0b10xxxxxx bytes are UTF-8 continuation bytes; back up until the
	// cut lands on the first byte of a sequence (or the string start).
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// CleanDatabaseSpam removes known spam/malware patterns from WordPress database content.
// Targets wp_posts and wp_options tables. Returns findings for each cleaned row.
//
// Destructive: the UPDATE below strips matching substrings from
// post_content in place. account is the cPanel account whose /home
// tree is globbed for wp-config.php files.
func CleanDatabaseSpam(account string) []alert.Finding {
	var findings []alert.Finding
	wpConfigs, _ := osFS.Glob(filepath.Join("/home", account, "*/wp-config.php"))
	wpConfigs2, _ := osFS.Glob(filepath.Join("/home", account, "public_html/wp-config.php"))
	wpConfigs = append(wpConfigs, wpConfigs2...)
	for _, wpConfig := range wpConfigs {
		creds := parseWPConfig(wpConfig)
		if creds.dbName == "" {
			continue
		}
		prefix := creds.tablePrefix
		if prefix == "" {
			prefix = "wp_"
		}
		// Clean spam from wp_posts
		spamPatterns := []struct {
			pattern string
			desc    string
		}{
			{"<script>", "injected script tag"},
			{"eval(", "eval() in post content"},
			{"base64_decode(", "base64_decode in post content"},
			{"document.write(", "document.write injection"},
		}
		for _, sp := range spamPatterns {
			// Count affected rows first
			countQuery := fmt.Sprintf(
				"SELECT COUNT(*) FROM %sposts WHERE post_content LIKE '%%%s%%'",
				prefix, sp.pattern)
			countLines := runMySQLQuery(creds, countQuery)
			if len(countLines) == 0 || countLines[0] == "0" {
				continue
			}
			// Clean: remove the malicious pattern from post_content
			cleanQuery := fmt.Sprintf(
				"UPDATE %sposts SET post_content = REPLACE(post_content, '%s', '') WHERE post_content LIKE '%%%s%%'",
				prefix, sp.pattern, sp.pattern)
			runMySQLQuery(creds, cleanQuery)
			findings = append(findings, alert.Finding{
				Severity:  alert.High,
				Check:     "db_spam_cleaned",
				Message:   fmt.Sprintf("Cleaned %s from %s posts in %s (account: %s)", sp.desc, countLines[0], creds.dbName, account),
				Timestamp: time.Now(),
			})
		}
		// Scan for spam keywords in wp_posts. Uses the same word-boundary
		// regex + post_type denylist + SEO-context requirement as
		// checkWPPosts so an operator-initiated cleanup surfaces the
		// same set of findings the periodic scan does.
		postTypeExcl := nonScannablePostTypesSQLList()
		for _, sp := range dbSpamPatterns {
			query := fmt.Sprintf(
				"SELECT ID, post_content FROM %sposts WHERE post_status='publish' AND post_type NOT IN (%s) AND post_content LIKE '%s' LIMIT 200",
				prefix, postTypeExcl, sp.likeFragment)
			lines := runMySQLQuery(creds, query)
			if len(lines) == 0 {
				continue
			}
			contents := make([]string, 0, len(lines))
			for _, line := range lines {
				parts := strings.SplitN(line, "\t", 2)
				if len(parts) < 2 {
					continue
				}
				contents = append(contents, parts[1])
			}
			n := countCloakedSpamMatches(sp, contents)
			if n == 0 {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "db_spam_found",
				Message:  fmt.Sprintf("Found spam keyword '%s' in %d published posts in %s (account: %s) - manual review recommended", sp.keyword, n, creds.dbName, account),
				// Fix: stamp this finding like db_spam_cleaned above does;
				// a zero Timestamp renders as "0001-01-01 00:00:00" in
				// Finding.String.
				Timestamp: time.Now(),
			})
		}
	}
	return findings
}
package checks
import (
"regexp"
"strings"
)
// This file contains pure-function helpers used by the database content
// scanner (checkWPPosts in dbscan.go). Keeping them pure (no MySQL, no
// filesystem) makes them deterministically testable and independently
// reusable.
//
// Two classes of false positive were historically observed on real
// production traffic:
//
// 1. db_post_injection fired on every post containing a script tag,
// including site-owner-added analytics and widget embeds (Google Tag
// Manager, Google merchant rating badge, etc.).
//
// 2. db_spam_injection used substring LIKE matching, so "specialist"
// triggered on "cialis", "pharmaceutical" triggered on "pharma",
// "casino resort" triggered on "casino", etc. It also scanned all
// post_types including Contact Form 7 / WPForms / Jetpack stored
// submissions, which routinely contain spambot form fills the site
// owner never displays.
//
// The helpers below encode the decisions needed to eliminate those FPs
// without opening detection holes: word-boundary keyword matching,
// post_type filtering against a denylist (not an allowlist, so attackers
// cannot hide a post by renaming post_type to one we didn't anticipate),
// and safe-domain filtering for external script-tag sources.
// nonScannablePostTypes are post_type values that legitimately store
// non-site-content data (form submissions, revisions, templates, feeds).
// These are excluded from malware and spam scans because their content
// is operator-invisible storage, not material rendered to site visitors.
//
// This is a DENYLIST, not an allowlist. A custom post_type created by a
// theme or plugin (for example WooCommerce `product`, events, portfolios)
// is still scanned. Adding a new value here is safe; an attacker cannot
// hide a post by choosing a new post_type, because we default to
// scanning anything not on this list.
// nonScannablePostTypes are post_type values that legitimately store
// non-site-content data (form submissions, revisions, templates, feeds)
// and are therefore excluded from malware and spam scans: their content
// is operator-invisible storage, not material rendered to site visitors.
//
// This is a DENYLIST, not an allowlist. Custom post_types created by a
// theme or plugin (for example WooCommerce `product`, events, portfolios)
// are still scanned, so an attacker cannot hide a post by inventing a
// post_type we did not anticipate. Adding new values here is safe.
var nonScannablePostTypes = []string{
	// WordPress internals / templates / navigation
	"revision",
	"customize_changeset",
	"oembed_cache",
	"nav_menu_item",
	"wp_template",
	"wp_template_part",
	"wp_global_styles",
	"wp_navigation",
	// Minification plugins (store compiled bundles that legitimately
	// contain JavaScript and obfuscated character sequences).
	"wphb_minify_group",
	// Form builders (store plugin configuration and visitor submissions.
	// Contact-form spam landing here is noise, not site compromise.)
	"wpforms",
	"wpforms_entries",
	"wpforms-log",
	"wpcf7_contact_form",
	"flamingo_inbound",
	"flamingo_outbound",
	"cf7_message",
	"feedback",
	"jetpack_feedback",
}

// isScannablePostType reports whether a post_type participates in malware
// and spam scans. It is the Go-side mirror of the SQL
// `post_type NOT IN (...)` clause built by nonScannablePostTypesSQLList,
// keeping callers and test assertions consistent with the live queries.
func isScannablePostType(postType string) bool {
	for i := range nonScannablePostTypes {
		if nonScannablePostTypes[i] == postType {
			return false
		}
	}
	return true
}

// nonScannablePostTypesSQLList renders the denylist as a comma-separated
// list of SQL string literals ("'revision','customize_changeset',...")
// for use inside a `post_type NOT IN (...)` clause. The values are
// hardcoded and contain only [a-z_-] characters, so SQL injection is not
// a risk here; the escaping below is purely defensive in case a future
// entry ever contains a quote or backslash.
func nonScannablePostTypesSQLList() string {
	var b strings.Builder
	for i, t := range nonScannablePostTypes {
		if i > 0 {
			b.WriteByte(',')
		}
		escaped := strings.ReplaceAll(t, `\`, `\\`)
		escaped = strings.ReplaceAll(escaped, `'`, `\'`)
		b.WriteByte('\'')
		b.WriteString(escaped)
		b.WriteByte('\'')
	}
	return b.String()
}
// dbSpamPattern is one spam-keyword detector. likeFragment is a MySQL
// server-side pre-filter that cheaply narrows the candidate posts; regex
// is then applied in Go with strict word boundaries so that "specialist"
// does not match "cialis", "pharmacy" does not match "pharma", etc.
//
// Patterns ending in a dash already have an implicit right boundary (the
// dash is a non-word character) and only need a left word boundary;
// pure-word patterns need boundaries on both sides.
type dbSpamPattern struct {
	keyword      string         // human-readable keyword used in finding messages
	regex        *regexp.Regexp // applied Go-side to candidate rows
	likeFragment string         // SQL LIKE fragment, always bracketed with '%'
}

// dbSpamPatterns enumerates the keywords flagged as SEO/pharma/gambling
// spam in WordPress post content: each entry pairs a fast SQL LIKE with
// a strict Go-side word-boundary regex.
//
// The regexes are case-insensitive (CIALIS / Cialis / cialis all match).
// The LIKE fragments are lowercase because MySQL LIKE is case-insensitive
// under the default _ci collation used by cPanel MariaDB.
var dbSpamPatterns = []dbSpamPattern{
	{"viagra", regexp.MustCompile(`(?i)\bviagra\b`), "%viagra%"},
	{"cialis", regexp.MustCompile(`(?i)\bcialis\b`), "%cialis%"},
	{"pharma", regexp.MustCompile(`(?i)\bpharma\b`), "%pharma%"},
	{"betting", regexp.MustCompile(`(?i)\bbetting\b`), "%betting%"},
	// Dashed variants: the trailing dash is itself a non-word char and
	// serves as the right boundary. Only a left word-boundary is needed.
	{"casino-", regexp.MustCompile(`(?i)\bcasino-`), "%casino-%"},
	{"buy-cheap-", regexp.MustCompile(`(?i)\bbuy-cheap-`), "%buy-cheap-%"},
	{"free-download", regexp.MustCompile(`(?i)\bfree-download`), "%free-download%"},
	{"crack-serial", regexp.MustCompile(`(?i)\bcrack-serial`), "%crack-serial%"},
}

// countSpamMatches returns how many of the given rows match the
// pattern's word-boundary regex. Callers pass only rows already narrowed
// by the pattern's likeFragment SQL pre-filter; this applies the strict
// Go-side test on top.
func countSpamMatches(pattern dbSpamPattern, contents []string) int {
	matched := 0
	for i := range contents {
		if pattern.regex.MatchString(contents[i]) {
			matched++
		}
	}
	return matched
}
// hasMaliciousExternalScript reports whether content contains a script
// tag with a src attribute pointing at a domain NOT on the known-safe
// list (see knownSafeDomains in db_autoresponse.go).
//
// Inline script blocks without a src attribute are deliberately out of
// scope here; those are covered by the separate code-pattern entries in
// dbMalwarePatterns, which catch common inline obfuscation techniques.
//
// Rationale: a bare script-tag match was the primary source of false
// positives on real traffic. Legitimate analytics embeds (Google Tag
// Manager, Google Analytics, Google merchant rating badge, Mailchimp,
// HubSpot, etc.) install both an external loader tag AND an inline
// initialization block, and flagging the inline block alone produced
// many HIGH severity noise findings on customer sites. Requiring a
// non-safe-domain external src reduced this to zero FPs in practice
// while still catching injected tags that load from untrusted domains.
func hasMaliciousExternalScript(content string) bool {
	url := extractMaliciousScriptURL(content)
	return url != ""
}
package checks
import (
"bufio"
"context"
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// DNS server processes that legitimately connect to many resolvers
// (e.g. BIND doing recursive resolution on a cPanel server).
// Keys are /etc/passwd user names; resolveDNSServerUIDs translates them
// to numeric UIDs for comparison against /proc/net/tcp's uid column.
var dnsServerUsers = map[string]bool{
	"named":   true, // BIND
	"unbound": true, // Unbound
	"pdns":    true, // PowerDNS
}
// CheckDNSConnections looks for established connections to port 53 on
// DNS servers that are NOT in /etc/resolv.conf. This catches DNS
// tunneling, GSocket relay discovery, and malware using hardcoded resolvers.
// Connections owned by known DNS server processes (e.g. named) are skipped.
//
// NOTE(review): only IPv4 TCP is inspected (/proc/net/tcp); DNS over UDP
// or IPv6 is not covered here — confirm whether other checks handle those.
func CheckDNSConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Parse configured resolvers
	resolvers := parseResolvers()
	if len(resolvers) == 0 {
		// No baseline to compare against; stay silent rather than flag
		// every DNS connection.
		return nil
	}
	// Also allow infra IPs and localhost
	allowed := make(map[string]bool)
	allowed["127.0.0.1"] = true
	allowed["0.0.0.0"] = true
	for _, r := range resolvers {
		allowed[r] = true
	}
	// Build a set of UIDs belonging to DNS server processes (named, unbound,
	// etc.) so we can skip their connections without reading /etc/passwd
	// on every loop iteration.
	dnsServerUIDs := resolveDNSServerUIDs()
	// Parse /proc/net/tcp for connections to port 53
	data, err := osFS.ReadFile("/proc/net/tcp")
	if err != nil {
		return nil
	}
	for _, line := range strings.Split(string(data), "\n") {
		fields := strings.Fields(line)
		// Skip the header row ("sl local_address ...") and short lines.
		if len(fields) < 8 || fields[0] == "sl" {
			continue
		}
		// State 01 = ESTABLISHED
		if fields[3] != "01" {
			continue
		}
		// fields[2] is the remote address in kernel hex "ADDR:PORT" form.
		remoteIP, remotePort := parseHexAddr(fields[2])
		if remotePort != 53 {
			continue
		}
		if allowed[remoteIP] {
			continue
		}
		// Check if it's an infra IP
		if isInfraIP(remoteIP, cfg.InfraIPs) {
			continue
		}
		// Skip connections owned by DNS server processes (e.g. named
		// doing recursive resolution talks to many different servers);
		// fields[7] is the socket owner's UID.
		if dnsServerUIDs[fields[7]] {
			continue
		}
		_, localPort := parseHexAddr(fields[1])
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "dns_connection",
			Message:  fmt.Sprintf("DNS connection to non-configured resolver: %s", remoteIP),
			Details:  fmt.Sprintf("Local port: %d, Remote: %s:53\nConfigured resolvers: %s", localPort, remoteIP, strings.Join(resolvers, ", ")),
		})
	}
	return findings
}
// resolveDNSServerUIDs returns the set of numeric UIDs (as strings, to
// match /proc/net/tcp's uid column) belonging to known DNS server users
// (named, unbound, pdns). /etc/passwd is read once per call; an empty
// set is returned if it cannot be read.
func resolveDNSServerUIDs() map[string]bool {
	uids := make(map[string]bool)
	data, err := osFS.ReadFile("/etc/passwd")
	if err != nil {
		return uids
	}
	for _, entry := range strings.Split(string(data), "\n") {
		// passwd format: name:password:uid:gid:... — only the first
		// three fields are needed.
		cols := strings.SplitN(entry, ":", 4)
		if len(cols) < 3 {
			continue
		}
		if dnsServerUsers[cols[0]] {
			uids[cols[2]] = true
		}
	}
	return uids
}
// parseResolvers returns the nameserver IPs listed in /etc/resolv.conf,
// or nil if the file cannot be opened.
func parseResolvers() []string {
	f, err := osFS.Open("/etc/resolv.conf")
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	var resolvers []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Field-split instead of matching the literal prefix
		// "nameserver ": resolv.conf permits any whitespace (tabs are
		// common) between the directive and the address, and trailing
		// tokens must not end up glued onto the IP. The old prefix
		// match missed tab-separated entries entirely, which made
		// CheckDNSConnections flag legitimate resolvers.
		fields := strings.Fields(line)
		if len(fields) >= 2 && fields[0] == "nameserver" {
			resolvers = append(resolvers, fields[1])
		}
	}
	return resolvers
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// maxBulkDNSChanges is the threshold above which zone changes are considered
// a cPanel bulk operation (AutoSSL, serial bump, DNSSEC rotation) and suppressed.
// Only 1-5 zone changes at once are reported - likely targeted modifications.
// Consumed by CheckDNSZoneChanges; raising it reports larger change batches.
const maxBulkDNSChanges = 5
// CheckDNSZoneChanges monitors named zone files for modifications.
// Bulk changes (more than maxBulkDNSChanges zones at once) are treated
// as cPanel maintenance and suppressed; only targeted changes (1-5
// zones), which may indicate tampering, produce findings. Zone content
// hashes are persisted in the state store between runs.
func CheckDNSZoneChanges(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	// cPanel stores zone files in /var/named/
	const zoneDir = "/var/named"
	entries, err := osFS.ReadDir(zoneDir)
	if err != nil {
		return nil
	}
	// First pass: hash every *.db zone, record the new hash, and note
	// which zones differ from the previously stored hash.
	var changed []string
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if !strings.HasSuffix(name, ".db") {
			continue
		}
		hash, hashErr := hashFileContent(filepath.Join(zoneDir, name))
		if hashErr != nil {
			continue
		}
		key := "_dns_zone:" + name
		prev, known := store.GetRaw(key)
		store.SetRaw(key, hash)
		// First sighting of a zone is a baseline, not a change.
		if known && prev != hash {
			changed = append(changed, name)
		}
	}
	// Many zones at once means cPanel maintenance — suppress entirely.
	if len(changed) > maxBulkDNSChanges {
		return nil
	}
	var findings []alert.Finding
	for _, name := range changed {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "dns_zone_change",
			Message:  fmt.Sprintf("DNS zone file modified: %s", name),
			Details:  fmt.Sprintf("File: %s\nThis could indicate DNS hijacking or unauthorized domain changes", filepath.Join(zoneDir, name)),
		})
	}
	return findings
}
// CheckSSLCertIssuance monitors AutoSSL logs for new certificate issuance.
// Attackers may issue certificates for phishing domains using compromised accounts.
//
// Detection is a two-step heuristic: the number of entries in the AutoSSL
// log directory is persisted in the state store, and when the count grows
// the tail of the most recently modified log is scanned for issuance
// keywords.
//
// NOTE(review): a count-only change indicator misses activity when log
// rotation keeps the entry count constant (or shrinks it) — confirm this
// is acceptable for the deployment's rotation policy.
func CheckSSLCertIssuance(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Check AutoSSL log
	logPath := "/var/cpanel/logs/autossl"
	entries, err := osFS.ReadDir(logPath)
	if err != nil {
		return nil
	}
	// Track the count of log files as a simple change indicator
	currentCount := len(entries)
	key := "_ssl_autossl_count"
	prev, exists := store.GetRaw(key)
	store.SetRaw(key, fmt.Sprintf("%d", currentCount))
	if !exists {
		// First run: establish the baseline without alerting.
		return nil
	}
	prevCount := 0
	// Sscanf error deliberately ignored: a malformed stored value leaves
	// prevCount at 0, which at worst causes one extra scan of the log.
	fmt.Sscanf(prev, "%d", &prevCount)
	if currentCount > prevCount {
		// New AutoSSL activity - check the latest log
		var latestLog string
		var latestTime int64
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			info, err := entry.Info()
			if err != nil {
				continue
			}
			if info.ModTime().Unix() > latestTime {
				latestTime = info.ModTime().Unix()
				latestLog = filepath.Join(logPath, entry.Name())
			}
		}
		if latestLog != "" {
			// Read the tail of the latest log for certificate issuance
			lines := tailFile(latestLog, 50)
			for _, line := range lines {
				lineLower := strings.ToLower(line)
				if strings.Contains(lineLower, "installed") || strings.Contains(lineLower, "issued") {
					findings = append(findings, alert.Finding{
						Severity: alert.Warning,
						Check:    "ssl_cert_issued",
						Message:  "New SSL certificate issued via AutoSSL",
						Details:  truncate(line, 300),
					})
				}
			}
		}
	}
	return findings
}
package checks
import (
"bufio"
"context"
"crypto/sha1" // #nosec G505 -- SHA1 is the digest format required by the Have I Been Pwned range API (https://haveibeenpwned.com/API/v3#PwnedPasswords). We send the first 5 chars of the digest and compare remaining chars against the returned list — HIBP does not offer a stronger-hash endpoint.
"crypto/sha256"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"unicode"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// doveadmSemaphore limits concurrent doveadm processes to 3 at a time
// (see verifyDoveadm) so password verification cannot fork-bomb the host.
var doveadmSemaphore = make(chan struct{}, 3)

// hibpClient is used for HIBP API requests; the 10s timeout bounds each
// range lookup.
var hibpClient = &http.Client{Timeout: 10 * time.Second}

// hibpEndpoint is the base URL queried for HIBP password-range lookups.
// Declared as a var (not const) so tests can swap in an httptest server.
// Production callers must not modify this.
var hibpEndpoint = "https://api.pwnedpasswords.com/range/"
// currentYear returns the current calendar year. It is evaluated on every
// audit cycle so long-running daemons don't keep a stale year after Jan 1.
func currentYear() int {
	now := time.Now()
	return now.Year()
}
// weakPasswordCache caches the bundled wordlist (loaded once).
// weakPasswords is populated at most once by loadWeakPasswords via
// weakPasswordOnce and is treated as read-only afterwards.
var (
	weakPasswordOnce sync.Once
	weakPasswords    []string
)
// parseShadowLine parses a Dovecot shadow line "mailbox:{scheme}hash"
// into its mailbox and hash halves. Both results are "" when the line
// is malformed (no colon, empty mailbox, or empty hash).
func parseShadowLine(line string) (mailbox, hash string) {
	sep := strings.IndexByte(line, ':')
	if sep <= 0 {
		return "", "" // no colon, or colon first (empty mailbox)
	}
	if sep == len(line)-1 {
		return "", "" // trailing colon (empty hash)
	}
	return line[:sep], line[sep+1:]
}
// isLockedHash reports whether the stored hash marks a locked/disabled
// account: an empty hash, or one starting with '!' or '*'.
func isLockedHash(hash string) bool {
	if hash == "" {
		return true
	}
	switch hash[0] {
	case '!', '*':
		return true
	}
	return false
}
// generateCandidates creates password candidates from username and domain.
// All candidates are >= 6 characters. No duplicates.
func generateCandidates(username, domain string) []string {
seen := make(map[string]bool)
var candidates []string
add := func(s string) {
if len(s) >= 6 && !seen[s] {
seen[s] = true
candidates = append(candidates, s)
}
}
domainLabel := domain
if idx := strings.IndexByte(domain, '.'); idx > 0 {
domainLabel = domain[:idx]
}
bases := []string{username, domainLabel}
for _, base := range bases {
add(base)
// Capitalize first letter variant
upper := capitalizeFirst(base)
add(upper)
// Year variants: current year +/- 2
year := currentYear()
for y := year - 2; y <= year+2; y++ {
ys := strconv.Itoa(y)
add(base + ys)
add(upper + ys)
}
// Two-digit suffix variants: 00-99
for n := 0; n <= 99; n++ {
suffix := fmt.Sprintf("%02d", n)
add(base + suffix)
add(upper + suffix)
}
}
return candidates
}
// capitalizeFirst returns s with its first rune converted to upper case;
// the empty string is returned unchanged.
func capitalizeFirst(s string) string {
	if s == "" {
		return s
	}
	chars := []rune(s)
	chars[0] = unicode.ToUpper(chars[0])
	return string(chars)
}
// verifyDoveadm checks a candidate password against a stored hash with
// `doveadm pw -t`; a match is signaled by a zero exit status (err == nil).
// Concurrency is bounded by doveadmSemaphore (3 simultaneous processes).
func verifyDoveadm(hash, candidate string) bool {
	doveadmSemaphore <- struct{}{}
	defer func() { <-doveadmSemaphore }()
	// Routed through cmdExec so tests can mock doveadm without a real install.
	_, err := cmdExec.Run("doveadm", "pw", "-t", hash, "-p", candidate)
	return err == nil
}
// hashFingerprint returns the SHA256 hex digest of a stored password
// hash. It is used for change detection: a mailbox is re-audited only
// when its hash fingerprint changes.
func hashFingerprint(hash string) string {
	digest := sha256.Sum256([]byte(hash))
	return fmt.Sprintf("%x", digest[:])
}
// parseHIBPCount scans a HIBP range response body ("SUFFIX:COUNT" lines)
// for the given hash suffix (case-insensitive) and returns its breach
// count. It returns 0 when the suffix is absent, or when the matched
// line's count does not parse as an integer.
func parseHIBPCount(body, suffix string) int {
	want := strings.ToUpper(suffix)
	for _, raw := range strings.Split(body, "\n") {
		entry := strings.TrimSpace(raw)
		fields := strings.SplitN(entry, ":", 2)
		if len(fields) != 2 {
			continue
		}
		if strings.ToUpper(strings.TrimSpace(fields[0])) != want {
			continue
		}
		count, err := strconv.Atoi(strings.TrimSpace(fields[1]))
		if err != nil {
			// Malformed count on the matching line: treat as not found.
			return 0
		}
		return count
	}
	return 0
}
// checkHIBP queries the HIBP Pwned Passwords API for a plaintext password
// using the k-anonymity range scheme: only the first 5 hex chars of the
// SHA1 digest are sent, and the remaining 35 are matched locally against
// the returned candidate list. Returns the breach count (0 if not found
// or on any network/HTTP error — errors and "not breached" are not
// distinguished by design).
func checkHIBP(plaintext string) int {
	// #nosec G401 -- SHA1 is mandated by the HIBP Pwned Passwords range API; see import comment.
	h := sha1.Sum([]byte(plaintext))
	hex := fmt.Sprintf("%X", h[:])
	prefix := hex[:5]
	suffix := hex[5:]
	resp, err := hibpClient.Get(hibpEndpoint + prefix) //nolint:noctx
	if err != nil {
		return 0
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return 0
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0
	}
	return parseHIBPCount(string(body), suffix)
}
// shadowFile identifies one Dovecot shadow file discovered under
// /home/{account}/etc/{domain}/shadow.
type shadowFile struct {
	path    string // absolute path to the shadow file
	account string // cPanel account name (/home/{account})
	domain  string // mail domain (directory name under etc/)
}

// mailboxEntry is one active (non-locked) mailbox credential line read
// from a shadow file by readShadowFile.
type mailboxEntry struct {
	account string // owning cPanel account
	domain  string // mail domain
	mailbox string // mailbox name (left of ':' in the shadow line)
	hash    string // password hash in "{scheme}hash" form
}
// discoverShadowFiles finds all Dovecot shadow files matching
// /home/*/etc/*/shadow and derives the account and domain from the
// path components.
func discoverShadowFiles() []shadowFile {
	matches, _ := osFS.Glob("/home/*/etc/*/shadow")
	var results []shadowFile
	for _, path := range matches {
		// Layout: /home/{account}/etc/{domain}/shadow — segment 0 is the
		// empty string before the leading slash.
		segs := strings.Split(path, "/")
		if len(segs) < 5 {
			continue
		}
		results = append(results, shadowFile{
			path:    path,
			account: segs[2],
			domain:  segs[4],
		})
	}
	return results
}
// readShadowFile reads all usable mailbox entries from a Dovecot shadow
// file, skipping blanks, comments, malformed lines, and locked accounts.
// Returns nil when the file cannot be opened.
func readShadowFile(sf shadowFile) []mailboxEntry {
	f, err := osFS.Open(sf.path)
	if err != nil {
		return nil
	}
	defer f.Close()
	var entries []mailboxEntry
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		raw := strings.TrimSpace(sc.Text())
		if raw == "" || strings.HasPrefix(raw, "#") {
			continue
		}
		mailbox, hash := parseShadowLine(raw)
		if mailbox == "" || hash == "" {
			continue
		}
		// Locked/disabled accounts cannot be logged into; auditing
		// their hashes would be wasted work.
		if isLockedHash(hash) {
			continue
		}
		entries = append(entries, mailboxEntry{
			account: sf.account,
			domain:  sf.domain,
			mailbox: mailbox,
			hash:    hash,
		})
	}
	return entries
}
// loadWeakPasswords reads the bundled wordlist once and caches it in the
// package-level weakPasswords slice, guarded by weakPasswordOnce.
// '#' comment lines and words shorter than 6 characters are skipped.
// Returns the cached slice; nil if the wordlist file is missing.
func loadWeakPasswords() []string {
	weakPasswordOnce.Do(func() {
		f, err := osFS.Open("/opt/csm/configs/weak_passwords.txt")
		if err != nil {
			return // no wordlist bundled: leave the cache empty
		}
		defer f.Close()
		scanner := bufio.NewScanner(f)
		for scanner.Scan() {
			word := strings.TrimSpace(scanner.Text())
			if len(word) < 6 || strings.HasPrefix(word, "#") {
				continue
			}
			weakPasswords = append(weakPasswords, word)
		}
	})
	return weakPasswords
}
// checkWordlist tests a hash against the bundled weak passwords list.
// Returns the matched password or empty string.
func checkWordlist(hash string) string {
	words := loadWeakPasswords()
	for i := range words {
		if verifyDoveadm(hash, words[i]) {
			return words[i]
		}
	}
	return ""
}
// CheckEmailPasswords audits Dovecot email account passwords for weak/predictable
// patterns. Uses internal throttle: skips if last refresh was less than
// password_check_interval_min ago.
//
// Audit layers, cheapest first:
//  1. heuristic candidates derived from the mailbox/domain name,
//  2. the bundled weak-password wordlist,
//  3. HIBP breach-count enrichment (only for already-confirmed matches,
//     so un-cracked hashes never drive a network call).
//
// A per-mailbox fingerprint of the hash is stored so unchanged passwords
// are not re-audited on subsequent runs. Mailboxes are processed by up to
// five concurrent workers; findings are collected under a mutex.
func CheckEmailPasswords(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}
	// Internal throttle -- same pattern as CheckOutdatedPlugins
	if !ForceAll {
		lastRefresh := db.GetEmailPWLastRefresh()
		interval := time.Duration(cfg.EmailProtection.PasswordCheckIntervalMin) * time.Minute
		if time.Since(lastRefresh) < interval {
			return nil
		}
	}
	shadowFiles := discoverShadowFiles()
	if len(shadowFiles) == 0 {
		return nil
	}
	// Collect all mailbox entries
	var allEntries []mailboxEntry
	for _, sf := range shadowFiles {
		allEntries = append(allEntries, readShadowFile(sf)...)
	}
	if len(allEntries) == 0 {
		// Nothing to audit, but still record the refresh so the throttle applies.
		_ = db.SetEmailPWLastRefresh(time.Now())
		return nil
	}
	var mu sync.Mutex // guards findings across worker goroutines
	var findings []alert.Finding
	var wg sync.WaitGroup
	// Process each mailbox concurrently (bounded by semaphore)
	sem := make(chan struct{}, 5) // at most 5 verification workers at once
	for _, entry := range allEntries {
		wg.Add(1)
		go func(e mailboxEntry) {
			defer wg.Done()
			sem <- struct{}{} // acquire slot; released by the deferred receive
			defer func() { <-sem }()
			fullMailbox := e.mailbox + "@" + e.domain
			storeKey := fmt.Sprintf("email:pwaudit:%s:%s", e.account, fullMailbox)
			// Skip if hash hasn't changed since last audit
			fp := hashFingerprint(e.hash)
			if stored := db.GetMetaString(storeKey); stored == fp {
				return
			}
			// Layer 1: Heuristic candidates
			candidates := generateCandidates(e.mailbox, e.domain)
			var matched string
			var matchType string
			for _, c := range candidates {
				if verifyDoveadm(e.hash, c) {
					matched = c
					matchType = "heuristic"
					break
				}
			}
			// Layer 2: Common wordlist (skip if layer 1 matched)
			if matched == "" {
				if w := checkWordlist(e.hash); w != "" {
					matched = w
					matchType = "wordlist"
				}
			}
			if matched != "" {
				details := fmt.Sprintf("Account: %s\nMailbox: %s\nMatch type: %s\nMatched password pattern: %q",
					e.account, fullMailbox, matchType, matched)
				// Layer 3: HIBP enrichment (only for confirmed matches)
				breachCount := checkHIBP(matched)
				if breachCount > 0 {
					details += fmt.Sprintf("\nHIBP: password found in %d data breaches", breachCount)
				}
				mu.Lock()
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_weak_password",
					Message:  fmt.Sprintf("Weak email password for %s (account: %s)", fullMailbox, e.account),
					Details:  details,
				})
				mu.Unlock()
			}
			// Record fingerprint so we don't re-audit until hash changes
			_ = db.SetMetaString(storeKey, fp)
		}(entry)
	}
	wg.Wait()
	_ = db.SetEmailPWLastRefresh(time.Now())
	return findings
}
package checks
import (
"context"
"fmt"
"path/filepath"
"regexp"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// emailBodySampleSize caps how much of each message body is analyzed.
const emailBodySampleSize = 4096 // Read first 4KB of email body

// Suspicious mailer headers that indicate mass mailing scripts.
// Matched case-insensitively as substrings of X-Mailer/User-Agent.
var suspiciousMailers = []string{
	"phpmailer", "swiftmailer", "mass mailer", "bulk mailer",
	"leaf phpmailer", "phpmail", "mail.php",
}

// Known safe mailers that should not be flagged. Consulted before the
// suspicious list: an allowlisted product name short-circuits flagging.
var safeMailers = []string{
	"wordpress", "woocommerce", "roundcube", "squirrelmail",
	"thunderbird", "outlook", "apple mail", "cpanel",
	"postfix", "exim", "dovecot",
}

// Phishing URL patterns in email body: URL shorteners, open-redirect
// query shapes, and throwaway hosting.
var emailPhishPatterns = []string{
	".workers.dev",
	"//bit.ly/", "//tinyurl.com/", "//is.gd/", "//rb.gy/",
	"//t.co/",
	"/redir?url=", "/redirect?url=", "link?url=",
	"effi.redir",
}

// Phishing language in email body. Two or more phrases in a single
// message trigger a credential-harvesting indicator.
var emailPhishLanguage = []string{
	"verify your account",
	"confirm your identity",
	"unusual activity",
	"your account will be",
	"suspended unless",
	"click here to verify",
	"update your payment",
	"confirm your email address",
	"security alert",
	"unauthorized access",
}

// Brand impersonation in email body (when sender doesn't match).
var emailBrandNames = []string{
	"paypal", "microsoft", "apple", "google", "amazon",
	"netflix", "facebook", "instagram", "bank of",
	"wells fargo", "chase", "citibank",
}
// outboundMsgIDRegex matches exim_mainlog message-arrival ("<=") lines,
// capturing the Exim message ID (group 1) and the envelope sender
// (group 2). Compiled once at package scope rather than on every check
// invocation.
var outboundMsgIDRegex = regexp.MustCompile(`^\S+\s+(\S+)\s+<=\s+(\S+)`)

// CheckOutboundEmailContent samples outbound email content from Exim spool
// for phishing URLs, credential harvesting language, suspicious mailers,
// and Reply-To mismatches.
//
// Reads the last 200 exim_mainlog lines, extracts message IDs of outbound
// (non-bounce) messages, and scans each spooled message at most once.
func CheckOutboundEmailContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Read recent outbound messages from exim_mainlog
	lines := tailFile("/var/log/exim_mainlog", 200)
	if len(lines) == 0 {
		return nil
	}
	scanned := make(map[string]bool) // avoid scanning same message twice
	for _, line := range lines {
		matches := outboundMsgIDRegex.FindStringSubmatch(line)
		if len(matches) < 3 {
			continue
		}
		msgID := matches[1]
		sender := matches[2]
		if sender == "<>" || sender == "" {
			continue // bounce message
		}
		if scanned[msgID] {
			continue
		}
		scanned[msgID] = true
		// Scan the spooled message for phishing indicators
		if result := scanEximMessage(msgID, sender, cfg); result != nil {
			findings = append(findings, *result)
		}
	}
	return findings
}
// scanEximMessage reads an Exim spool message and checks for suspicious content.
//
// Header indicators: Reply-To/From domain mismatch, suspicious
// X-Mailer/User-Agent (unless allowlisted), and a brand name in From:
// whose sender domain doesn't contain that brand. Body indicators (first
// 4KB only): phishing URL patterns, two or more credential-harvesting
// phrases, and base64-encoded HTML. Returns nil when the message is no
// longer in the spool or nothing matched; otherwise a High finding,
// escalated to Critical at three or more indicators.
func scanEximMessage(msgID, sender string, cfg *config.Config) *alert.Finding {
	// Exim spool paths
	// Headers: /var/spool/exim/input/{msgID}-H
	// Body: /var/spool/exim/input/{msgID}-D
	spoolDirs := []string{
		"/var/spool/exim/input",
		"/var/spool/exim4/input",
	}
	var headerPath, bodyPath string
	for _, dir := range spoolDirs {
		h := filepath.Join(dir, msgID+"-H")
		b := filepath.Join(dir, msgID+"-D")
		if _, err := osFS.Stat(h); err == nil {
			headerPath = h
			bodyPath = b
			break
		}
	}
	if headerPath == "" {
		return nil // message already delivered/removed from spool
	}
	var indicators []string
	// Read and analyze headers
	headerData, err := osFS.ReadFile(headerPath)
	if err != nil {
		return nil
	}
	headers := string(headerData)
	headersLower := strings.ToLower(headers)
	// Check 1: Reply-To mismatch — replies diverted to a different domain
	// than the visible From: is a classic phishing/BEC pattern.
	from := extractEmailHeader(headers, "From:")
	replyTo := extractEmailHeader(headers, "Reply-To:")
	if from != "" && replyTo != "" {
		fromDomain := extractDomain(from)
		replyDomain := extractDomain(replyTo)
		if fromDomain != "" && replyDomain != "" && fromDomain != replyDomain {
			indicators = append(indicators, fmt.Sprintf("Reply-To mismatch: From=%s, Reply-To=%s", fromDomain, replyDomain))
		}
	}
	// Check 2: Suspicious X-Mailer (falling back to User-Agent).
	// The safeMailers allowlist is consulted first and short-circuits.
	mailer := extractEmailHeader(headers, "X-Mailer:")
	if mailer == "" {
		mailer = extractEmailHeader(headers, "User-Agent:")
	}
	if mailer != "" {
		mailerLower := strings.ToLower(mailer)
		isSafe := false
		for _, safe := range safeMailers {
			if strings.Contains(mailerLower, safe) {
				isSafe = true
				break
			}
		}
		if !isSafe {
			for _, suspicious := range suspiciousMailers {
				if strings.Contains(mailerLower, suspicious) {
					indicators = append(indicators, fmt.Sprintf("suspicious mailer: %s", strings.TrimSpace(mailer)))
					break
				}
			}
		}
	}
	// Check 3: Spoofed display name (brand name in From: but sender is not that brand)
	fromHeader := extractEmailHeader(headers, "From:")
	if fromHeader != "" {
		fromLower := strings.ToLower(fromHeader)
		senderDomain := extractDomain(sender)
		for _, brand := range emailBrandNames {
			if strings.Contains(fromLower, brand) && !strings.Contains(strings.ToLower(senderDomain), brand) {
				indicators = append(indicators, fmt.Sprintf("spoofed brand in From: '%s' (actual sender: %s)", strings.TrimSpace(fromHeader), sender))
				break
			}
		}
	}
	// Read and analyze body (sample first 4KB); a read error is treated
	// the same as an empty body.
	bodyData, _ := osFS.ReadFile(bodyPath)
	if len(bodyData) > emailBodySampleSize {
		bodyData = bodyData[:emailBodySampleSize]
	}
	if len(bodyData) > 0 {
		bodyLower := strings.ToLower(string(bodyData))
		// Check 4: Phishing URLs in body (first match is enough)
		for _, pattern := range emailPhishPatterns {
			if strings.Contains(bodyLower, pattern) {
				indicators = append(indicators, fmt.Sprintf("phishing URL pattern: %s", pattern))
				break
			}
		}
		// Check 5: Credential harvesting language — a single phrase is
		// common in legitimate mail, so require at least two.
		harvestCount := 0
		for _, phrase := range emailPhishLanguage {
			if strings.Contains(bodyLower, phrase) {
				harvestCount++
			}
		}
		if harvestCount >= 2 {
			indicators = append(indicators, fmt.Sprintf("credential harvesting language (%d phrases)", harvestCount))
		}
		// Check 6: Base64-encoded HTML body (used to bypass filters)
		// NOTE(review): this inspects headers only, yet is evaluated only
		// when the body sample is non-empty — confirm that is intended.
		if strings.Contains(headersLower, "content-transfer-encoding: base64") &&
			strings.Contains(headersLower, "text/html") {
			// Check if decoded content has phishing patterns
			indicators = append(indicators, "base64-encoded HTML body (potential filter bypass)")
		}
	}
	if len(indicators) == 0 {
		return nil
	}
	severity := alert.High
	if len(indicators) >= 3 {
		severity = alert.Critical
	}
	return &alert.Finding{
		Severity: severity,
		Check:    "email_phishing_content",
		Message:  fmt.Sprintf("Suspicious outbound email from %s (message: %s)", sender, msgID),
		Details:  fmt.Sprintf("Indicators:\n- %s", strings.Join(indicators, "\n- ")),
	}
}
// extractEmailHeader extracts a header value from raw email headers.
// name must include the trailing colon (e.g. "From:"); matching is
// case-insensitive and the returned value is whitespace-trimmed.
// Returns "" when the header is absent.
func extractEmailHeader(headers, name string) string {
	target := strings.ToLower(name)
	for _, raw := range strings.Split(headers, "\n") {
		if !strings.HasPrefix(strings.ToLower(strings.TrimSpace(raw)), target) {
			continue
		}
		if _, value, ok := strings.Cut(raw, ":"); ok {
			return strings.TrimSpace(value)
		}
	}
	return ""
}
// extractDomain extracts the domain part from an email address or header value.
// Handles both bare addresses ("a@b.com") and the
// "Display Name <email@domain>" form; returns "" when no '@' is present.
func extractDomain(s string) string {
	// Unwrap angle brackets around the address, if present.
	if open := strings.Index(s, "<"); open >= 0 {
		if end := strings.Index(s[open:], ">"); end > 0 {
			s = s[open+1 : open+end]
		}
	}
	// Everything after the last '@' is the domain.
	if at := strings.LastIndex(s, "@"); at >= 0 {
		return strings.TrimSpace(s[at+1:])
	}
	return ""
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckDatabaseDumps detects mysqldump/pg_dump processes running under
// non-root users - potential data exfiltration.
func CheckDatabaseDumps(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	dumpTools := []string{"mysqldump", "pg_dump", "mongodump"}
	cmdPaths, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cp := range cmdPaths {
		pid := filepath.Base(filepath.Dir(cp))
		// Resolve the real UID from /proc/<pid>/status ("Uid:\treal\teff\t...").
		status, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		uid := ""
		for _, ln := range strings.Split(string(status), "\n") {
			if !strings.HasPrefix(ln, "Uid:\t") {
				continue
			}
			if fs := strings.Fields(strings.TrimPrefix(ln, "Uid:\t")); len(fs) > 0 {
				uid = fs[0]
			}
		}
		// Root may run legitimate backups; empty UID means the process
		// vanished or its status was unreadable.
		if uid == "" || uid == "0" {
			continue
		}
		raw, err := osFS.ReadFile(cp)
		if err != nil {
			continue
		}
		cmd := strings.ReplaceAll(string(raw), "\x00", " ")
		for _, tool := range dumpTools {
			if !strings.Contains(cmd, tool) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "database_dump",
				Message:  fmt.Sprintf("Database dump by non-root user: %s (%s)", uidToUser(uid), tool),
				Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uid, strings.TrimSpace(cmd)),
			})
			break
		}
	}
	return findings
}
// CheckOutboundPasteSites detects connections to known paste/exfiltration sites.
//
// Heuristic: scans /proc/<pid>/cmdline of non-root processes for paste-site
// hostnames (catches curl/wget-style uploads where the URL is on the
// command line).
func CheckOutboundPasteSites(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Check running processes for connections to paste sites
	pasteSites := []string{
		"pastebin.com", "hastebin.com", "ghostbin.co",
		"paste.ee", "dpaste.org", "gist.githubusercontent.com",
		"raw.githubusercontent.com", "transfer.sh",
		"file.io", "0x0.st", "ix.io",
	}
	procs, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cmdPath := range procs {
		pid := filepath.Base(filepath.Dir(cmdPath))
		statusData, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		var uid string
		for _, line := range strings.Split(string(statusData), "\n") {
			if strings.HasPrefix(line, "Uid:\t") {
				fields := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
				if len(fields) > 0 {
					uid = fields[0]
				}
			}
		}
		// Skip root, and skip processes whose UID could not be determined
		// (process likely exited mid-scan) -- same policy as
		// CheckDatabaseDumps, avoiding findings attributed to an unknown user.
		if uid == "0" || uid == "" {
			continue
		}
		cmdline, err := osFS.ReadFile(cmdPath)
		if err != nil {
			continue
		}
		cmdStr := strings.ToLower(strings.ReplaceAll(string(cmdline), "\x00", " "))
		// Check if process is connecting to paste sites
		for _, site := range pasteSites {
			if strings.Contains(cmdStr, site) {
				user := uidToUser(uid)
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "exfiltration_paste_site",
					Message:  fmt.Sprintf("Process connecting to paste/exfiltration site: %s (user: %s)", site, user),
					Details:  fmt.Sprintf("PID: %s, cmdline: %s", pid, strings.TrimSpace(cmdStr)),
				})
				break
			}
		}
	}
	return findings
}
package checks
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// fileIndexScanCount tracks the number of CheckFileIndex invocations.
// Every 6th scan forces a full directory rescan, bypassing the mtime cache
// to catch writes that don't update parent directory mtime (e.g. hard links).
// Mutated only via atomic.AddInt32 in CheckFileIndex.
var fileIndexScanCount int32

// suspiciousExtensions are file extensions that should never appear in web roots.
// Keys include the leading dot and are compared against lowercased
// filepath.Ext output.
var suspiciousExtensions = map[string]bool{
	".phtml": true, ".pht": true, ".php5": true,
	".haxor": true, ".cgix": true,
}

// dirMtimeCache maps directory paths to their last-known mtime (unix seconds).
// Directories with unchanged mtime are skipped during scanning.
type dirMtimeCache map[string]int64
// loadDirCache restores the directory-mtime cache from dircache.json in
// the state directory. A missing or unparsable file yields an empty cache.
func loadDirCache(stateDir string) dirMtimeCache {
	out := make(dirMtimeCache)
	if raw, err := osFS.ReadFile(filepath.Join(stateDir, "dircache.json")); err == nil {
		_ = json.Unmarshal(raw, &out)
	}
	return out
}
// saveDirCache persists the directory-mtime cache atomically: write to a
// temp file, then rename over dircache.json. Errors are best-effort
// ignored; a lost cache only costs a re-scan next cycle.
func saveDirCache(stateDir string, cache dirMtimeCache) {
	blob, _ := json.Marshal(cache)
	tmp := filepath.Join(stateDir, "dircache.json.tmp")
	_ = os.WriteFile(tmp, blob, 0600)
	_ = os.Rename(tmp, filepath.Join(stateDir, "dircache.json"))
}
// dirChanged reports whether dir's mtime differs from the cached value,
// recording the fresh mtime either way. Stat failures, never-seen dirs,
// and forceFullScan all report true so the caller re-scans.
func dirChanged(dir string, cache dirMtimeCache, forceFullScan bool) bool {
	st, err := osFS.Stat(dir)
	if err != nil {
		return true // unreadable: scan it to be safe
	}
	current := st.ModTime().Unix()
	previous, seen := cache[dir]
	cache[dir] = current
	switch {
	case forceFullScan:
		return true
	case !seen:
		return true // first time seeing this dir
	default:
		return current != previous
	}
}
// CheckFileIndex builds an index of suspicious files using pure Go directory
// reads, diffs against the previous index, and alerts on new files.
// Uses directory mtime caching: unchanged dirs carry forward previous entries
// without calling ReadDir, while changed dirs are re-scanned.
//
// Flow: build current index -> first run saves a baseline and returns ->
// sanity-check the index against truncation -> diff against the previous
// index -> classify each new path -> rotate current into previous.
func CheckFileIndex(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	scanNum := atomic.AddInt32(&fileIndexScanCount, 1)
	forceFullScan := scanNum%6 == 0 // full rescan every 6th cycle
	indexDir := cfg.StatePath
	currentPath := filepath.Join(indexDir, "fileindex.current")
	previousPath := filepath.Join(indexDir, "fileindex.previous")
	// Load caches
	dirCache := loadDirCache(indexDir)
	// Build a set of previous entries grouped by their top-level scan dir,
	// so unchanged dirs can carry forward their entries without ReadDir.
	previousEntries := loadIndex(previousPath)
	prevByDir := groupEntriesByUploadDir(previousEntries)
	// Build current index
	currentEntries := buildFileIndex(dirCache, prevByDir, forceFullScan)
	// Save updated dir cache
	saveDirCache(indexDir, dirCache)
	// Write current index (atomic)
	writeIndex(currentPath, currentEntries)
	// First run - save baseline
	if _, err := osFS.Stat(previousPath); os.IsNotExist(err) {
		copyFile(currentPath, previousPath)
		return nil
	}
	// Validate index: an empty or sharply shrunken index usually means the
	// scan failed, not that files disappeared; diffing against it would
	// mass re-alert on the next cycle.
	if len(previousEntries) > 10 && len(currentEntries) == 0 {
		fmt.Fprintf(os.Stderr, "file_index: current index empty but previous had %d entries, skipping diff\n", len(previousEntries))
		return nil
	}
	if len(previousEntries) > 0 && len(currentEntries) < len(previousEntries)/2 {
		fmt.Fprintf(os.Stderr, "file_index: current index (%d) < half of previous (%d), skipping diff\n", len(currentEntries), len(previousEntries))
		return nil
	}
	// Diff: anything in current but not in previous is new.
	prevSet := make(map[string]bool, len(previousEntries))
	for _, e := range previousEntries {
		prevSet[e] = true
	}
	var newFiles []string
	for _, e := range currentEntries {
		if !prevSet[e] {
			newFiles = append(newFiles, e)
		}
	}
	// Analyze new files - lazy stat only for alerting
	for _, path := range newFiles {
		name := filepath.Base(path)
		nameLower := strings.ToLower(name)
		suppressed := false
		for _, ignore := range cfg.Suppressions.IgnorePaths {
			if matchGlob(path, ignore) {
				suppressed = true
				break
			}
		}
		if suppressed {
			continue
		}
		// severity < 0 means "not yet classified"; later checks may
		// replace an earlier classification (webshell name wins over
		// a location-based match, for example).
		severity := alert.Severity(-1)
		check := ""
		message := ""
		if strings.Contains(path, "/wp-content/uploads/") && strings.HasSuffix(nameLower, ".php") {
			if isKnownSafeUpload(path, name) {
				continue
			}
			severity = alert.High
			check = "new_php_in_uploads"
			message = fmt.Sprintf("New PHP file in uploads: %s", path)
		}
		// PHP files in wp-content/languages - should only contain .mo/.po/.l10n.php
		if strings.Contains(path, "/wp-content/languages/") && strings.HasSuffix(nameLower, ".php") {
			if !strings.HasSuffix(nameLower, ".l10n.php") && nameLower != "index.php" {
				severity = alert.Critical
				check = "new_php_in_languages"
				message = fmt.Sprintf("New PHP file in wp-content/languages (should only contain translations): %s", path)
			}
		}
		// PHP files in wp-content/upgrade - should be empty except index.php
		if strings.Contains(path, "/wp-content/upgrade/") && strings.HasSuffix(nameLower, ".php") {
			if nameLower != "index.php" {
				severity = alert.Critical
				check = "new_php_in_upgrade"
				message = fmt.Sprintf("New PHP file in wp-content/upgrade (should be empty): %s", path)
			}
		}
		if strings.Contains(path, "/.config/") {
			severity = alert.Critical
			check = "new_executable_in_config"
			message = fmt.Sprintf("New executable in .config: %s", path)
		}
		if isWebshellName(nameLower) {
			severity = alert.Critical
			check = "new_webshell_file"
			message = fmt.Sprintf("New file with webshell name: %s", path)
		}
		if strings.HasSuffix(nameLower, ".php") && isSuspiciousPHPName(nameLower) {
			if severity < 0 { // only if nothing stronger matched above
				severity = alert.High
				check = "new_suspicious_php"
				message = fmt.Sprintf("New suspicious PHP file: %s", path)
			}
		}
		if severity >= 0 {
			details := ""
			if info, err := osFS.Stat(path); err == nil {
				details = fmt.Sprintf("Size: %d, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
			}
			findings = append(findings, alert.Finding{
				Severity: severity,
				Check:    check,
				Message:  message,
				Details:  details,
				FilePath: path,
			})
		}
	}
	// Rotate: current becomes the baseline for the next run.
	copyFile(currentPath, previousPath)
	return findings
}
// groupEntriesByUploadDir groups index entries by their containing directory.
// Used to carry forward entries from unchanged directories without ReadDir.
func groupEntriesByUploadDir(entries []string) map[string][]string {
	byDir := make(map[string][]string)
	for _, p := range entries {
		parent := filepath.Dir(p)
		byDir[parent] = append(byDir[parent], p)
	}
	return byDir
}
// buildFileIndex scans targeted directory subtrees using ReadDir.
// Skips directories whose mtime hasn't changed - carries forward
// their entries from the previous index instead.
// If forceFullScan is true, all directories are re-scanned regardless of mtime.
//
// Coverage per home dir: wp-content/uploads (PHP, depth 6),
// wp-content/{languages,upgrade,mu-plugins} (PHP, depth 4), and ~/.config
// (executables, depth 3) — for public_html and addon-domain docroots alike.
// Plus /tmp, /dev/shm, /var/tmp (suspicious extensions, depth 2).
// The result is sorted so the written index is stable and diffable.
func buildFileIndex(dirCache dirMtimeCache, prevByDir map[string][]string, forceFullScan bool) []string {
	var entries []string
	homeDirs, err := GetScanHomeDirs()
	if err != nil {
		return nil
	}
	for _, homeEntry := range homeDirs {
		if !homeEntry.IsDir() {
			continue
		}
		user := homeEntry.Name()
		homeDir := filepath.Join("/home", user)
		// Scan wp-content/uploads for PHP files
		uploadDirs := []string{
			filepath.Join(homeDir, "public_html", "wp-content", "uploads"),
		}
		// Scan directories that shouldn't normally contain user PHP:
		// languages (translation files only), upgrade (temp dir), mu-plugins
		sensitiveWPDirs := []string{
			filepath.Join(homeDir, "public_html", "wp-content", "languages"),
			filepath.Join(homeDir, "public_html", "wp-content", "upgrade"),
			filepath.Join(homeDir, "public_html", "wp-content", "mu-plugins"),
		}
		// Addon/subdomain docroots live as siblings of public_html;
		// skip well-known non-docroot dirs (mail/etc/logs/ssl/tmp, dotdirs).
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sd := range subDirs {
			if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
				!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
				sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
				uploadsPath := filepath.Join(homeDir, sd.Name(), "wp-content", "uploads")
				if info, err := osFS.Stat(uploadsPath); err == nil && info.IsDir() {
					uploadDirs = append(uploadDirs, uploadsPath)
				}
				// Also track sensitive dirs for addon domains
				for _, subDir := range []string{"languages", "upgrade", "mu-plugins"} {
					sensitiveDir := filepath.Join(homeDir, sd.Name(), "wp-content", subDir)
					if info, err := osFS.Stat(sensitiveDir); err == nil && info.IsDir() {
						sensitiveWPDirs = append(sensitiveWPDirs, sensitiveDir)
					}
				}
			}
		}
		for _, uploadsDir := range uploadDirs {
			scanDirForPHP(uploadsDir, 6, dirCache, prevByDir, forceFullScan, &entries)
		}
		// Scan sensitive WP directories for any PHP files
		for _, sensitiveDir := range sensitiveWPDirs {
			scanDirForPHP(sensitiveDir, 4, dirCache, prevByDir, forceFullScan, &entries)
		}
		// Scan .config for executables
		configDir := filepath.Join(homeDir, ".config")
		scanDirForExecutables(configDir, 3, dirCache, prevByDir, forceFullScan, &entries)
	}
	// Scan tmp dirs
	for _, tmpDir := range []string{"/tmp", "/dev/shm", "/var/tmp"} {
		scanDirForSuspiciousExt(tmpDir, 2, dirCache, prevByDir, forceFullScan, &entries)
	}
	sort.Strings(entries)
	return entries
}
// scanDirForPHP recursively collects .php files (except index.php) and
// files with suspicious extensions under dir, down to maxDepth levels.
// Directories whose mtime is unchanged reuse their previous-index entries.
func scanDirForPHP(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
	if maxDepth <= 0 {
		return
	}
	if !dirChanged(dir, cache, forceFullScan) {
		// Unchanged: carry forward previous entries for this directory.
		*entries = append(*entries, prev[dir]...)
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		childPath := filepath.Join(dir, child.Name())
		if child.IsDir() {
			scanDirForPHP(childPath, maxDepth-1, cache, prev, forceFullScan, entries)
			continue
		}
		lower := strings.ToLower(child.Name())
		if strings.HasSuffix(lower, ".php") && lower != "index.php" {
			*entries = append(*entries, childPath)
		}
		if suspiciousExtensions[filepath.Ext(lower)] {
			*entries = append(*entries, childPath)
		}
	}
}
// scanDirForExecutables recursively collects files with any execute bit
// set under dir (intended for ~/.config), down to maxDepth levels.
// Unchanged directories (by mtime) reuse their previous-index entries.
func scanDirForExecutables(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
	if maxDepth <= 0 {
		return
	}
	if !dirChanged(dir, cache, forceFullScan) {
		*entries = append(*entries, prev[dir]...)
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		full := filepath.Join(dir, child.Name())
		if child.IsDir() {
			scanDirForExecutables(full, maxDepth-1, cache, prev, forceFullScan, entries)
			continue
		}
		info, infoErr := child.Info()
		if infoErr != nil {
			continue
		}
		if info.Mode()&0111 != 0 { // any execute bit: owner, group, or other
			*entries = append(*entries, full)
		}
	}
}
// scanDirForSuspiciousExt recursively collects files whose (lowercased)
// extension appears in suspiciousExtensions under dir, down to maxDepth
// levels. Unchanged directories (by mtime) reuse previous-index entries.
func scanDirForSuspiciousExt(dir string, maxDepth int, cache dirMtimeCache, prev map[string][]string, forceFullScan bool, entries *[]string) {
	if maxDepth <= 0 {
		return
	}
	if !dirChanged(dir, cache, forceFullScan) {
		*entries = append(*entries, prev[dir]...)
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		full := filepath.Join(dir, child.Name())
		if child.IsDir() {
			scanDirForSuspiciousExt(full, maxDepth-1, cache, prev, forceFullScan, entries)
			continue
		}
		if suspiciousExtensions[filepath.Ext(strings.ToLower(child.Name()))] {
			*entries = append(*entries, full)
		}
	}
}
// isKnownSafeUpload reports whether a PHP file found under wp-content/uploads
// is a known-benign artifact: either an index.php placeholder or a path
// containing a fragment dropped by well-known plugins/tools.
func isKnownSafeUpload(path, name string) bool {
	if name == "index.php" {
		return true
	}
	safeFragments := []string{
		"/cache/", "/imunify", "/redux/", "/mailchimp-for-wp/",
		"/sucuri/", "/smush/", "/goldish/", "/wpallexport/",
		"/wpallimport/", "/wph/", "/stm_fonts/", "/smile_fonts/",
		"/bws-custom-code/", "/wp-import-export-lite/",
		"/mc4wp-debug-log.php", "/zn_fonts/", "/companies_documents/",
	}
	for _, frag := range safeFragments {
		if strings.Contains(path, frag) {
			return true
		}
	}
	return false
}
// isWebshellName reports whether name exactly matches a well-known
// webshell filename (caller passes a lowercased name).
func isWebshellName(name string) bool {
	switch name {
	case "h4x0r.php", "c99.php", "r57.php",
		"wso.php", "alfa.php", "b374k.php",
		"shell.php", "cmd.php", "backdoor.php",
		"webshell.php", "hack.php", "0x.php",
		"up.php", "uploader.php", "filemanager.php":
		return true
	}
	return false
}
// isSuspiciousPHPName reports whether a (lowercased) PHP filename looks
// attacker-chosen: it contains a known marker substring, or its stem is
// very short (<= 5 chars) and contains a digit — a common random-name
// pattern for dropped shells.
func isSuspiciousPHPName(name string) bool {
	markers := []string{
		"shell", "cmd", "exec", "hack", "backdoor", "upload",
		"exploit", "reverse", "connect", "proxy", "tunnel",
		"0x", "x0", "eval", "assert", "passthru",
	}
	for _, m := range markers {
		if strings.Contains(name, m) {
			return true
		}
	}
	stem := strings.TrimSuffix(name, ".php")
	return len(stem) <= 5 && strings.ContainsAny(stem, "0123456789")
}
// writeIndex atomically writes the newline-delimited index to path:
// buffered write to path+".tmp", then rename into place. Errors are
// best-effort ignored (the next scan cycle rebuilds the index).
func writeIndex(path string, entries []string) {
	tmpPath := path + ".tmp"
	// Create with 0600 to match the other state files written here
	// (saveDirCache, copyFile); os.Create would default to 0666 pre-umask.
	// #nosec G304 -- path is filepath.Join under operator-configured StatePath.
	f, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return
	}
	w := bufio.NewWriter(f)
	for _, e := range entries {
		_, _ = w.WriteString(e + "\n")
	}
	_ = w.Flush()
	_ = f.Close()
	_ = os.Rename(tmpPath, path)
}
// loadIndex reads a newline-delimited file index, dropping blank lines.
// Returns nil when the file cannot be opened.
func loadIndex(path string) []string {
	fh, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = fh.Close() }()
	var out []string
	sc := bufio.NewScanner(fh)
	// Allow lines (paths) up to 1MiB instead of the default 64KiB cap.
	sc.Buffer(make([]byte, 0, 1024*1024), 1024*1024)
	for sc.Scan() {
		if ln := sc.Text(); ln != "" {
			out = append(out, ln)
		}
	}
	return out
}
// copyFile best-effort copies src to dst with 0600 permissions.
// Errors are ignored by design; the next scan cycle recovers.
func copyFile(src, dst string) {
	if data, err := osFS.ReadFile(src); err == nil {
		_ = os.WriteFile(dst, data, 0600)
	}
}
package checks
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFilesystem uses globs and targeted ReadDir to check for backdoors,
// hidden files, and SUID binaries. No `find` command needed.
//
// Three passes: (1) known backdoor binary names (GSocket family) inside
// ~/.config via glob, (2) hidden files in world-writable tmp dirs minus a
// small allowlist of legitimate socket/file prefixes, (3) SUID binaries in
// tmp dirs plus a shallow walk of each home dir. Context cancellation is
// honored between glob batches and inside scanForSUID.
func CheckFilesystem(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// GSocket / backdoor binaries in .config dirs - glob (instant)
	backdoorNames := map[string]bool{
		"defunct": true, "defunct.dat": true, "gs-netcat": true,
		"gs-sftp": true, "gs-mount": true, "gsocket": true,
	}
	configGlobs := []string{
		"/home/*/.config/htop/*",
		"/home/*/.config/*/*",
	}
	for _, pattern := range configGlobs {
		if ctx.Err() != nil {
			return findings // canceled: return what we have so far
		}
		matches, _ := osFS.Glob(pattern)
		for _, path := range matches {
			if backdoorNames[filepath.Base(path)] {
				info, _ := osFS.Stat(path)
				var details string
				if info != nil {
					details = fmt.Sprintf("Size: %d bytes, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
				}
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "backdoor_binary",
					Message:  fmt.Sprintf("Backdoor binary found: %s", path),
					Details:  details,
					FilePath: path,
				})
			}
		}
	}
	// Hidden files in /tmp, /dev/shm, /var/tmp - glob (instant)
	safeHiddenPrefixes := []string{
		".s.PGSQL", ".font-unix", ".ICE-unix", ".X11-unix",
		".XIM-unix", ".crontab.", ".Test-unix",
	}
	for _, pattern := range []string{"/tmp/.*", "/dev/shm/.*", "/var/tmp/.*"} {
		if ctx.Err() != nil {
			return findings
		}
		matches, _ := osFS.Glob(pattern)
		for _, match := range matches {
			info, err := osFS.Stat(match)
			if err != nil || info.IsDir() {
				continue // only regular hidden files are of interest
			}
			base := filepath.Base(match)
			safe := false
			for _, prefix := range safeHiddenPrefixes {
				if strings.HasPrefix(base, prefix) {
					safe = true
					break
				}
			}
			if safe {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "suspicious_file",
				Message:  fmt.Sprintf("Suspicious hidden file: %s", match),
				Details:  fmt.Sprintf("Size: %d, Mtime: %s", info.Size(), info.ModTime()),
				FilePath: match,
			})
		}
	}
	// SUID binaries in tmp dirs - ReadDir + stat (small dirs, fast)
	for _, dir := range []string{"/tmp", "/var/tmp", "/dev/shm"} {
		scanForSUID(ctx, dir, 3, &findings)
	}
	// SUID in /home - shallow scan only
	homeDirs, _ := GetScanHomeDirs()
	for _, entry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !entry.IsDir() {
			continue
		}
		scanForSUID(ctx, filepath.Join("/home", entry.Name()), 3, &findings)
	}
	return findings
}
// scanForSUID recursively walks dir (to maxDepth levels) via ReadDir and
// appends a Critical finding for every file carrying the setuid bit.
// Known large subtrees (virtfs, mail, public_html) are skipped, and the
// walk aborts early on context cancellation.
func scanForSUID(ctx context.Context, dir string, maxDepth int, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	children, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, child := range children {
		if ctx.Err() != nil {
			return
		}
		name := child.Name()
		full := filepath.Join(dir, name)
		if child.IsDir() {
			// Skip virtfs and known large dirs
			if name == "virtfs" || name == "mail" || name == "public_html" {
				continue
			}
			scanForSUID(ctx, full, maxDepth-1, findings)
			continue
		}
		info, infoErr := child.Info()
		if infoErr != nil {
			continue
		}
		if info.Mode()&os.ModeSetuid == 0 {
			continue
		}
		*findings = append(*findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "suid_binary",
			Message:  fmt.Sprintf("SUID binary in unusual location: %s", full),
			Details:  fmt.Sprintf("Mode: %s, Size: %d", info.Mode(), info.Size()),
			FilePath: full,
		})
	}
}
// CheckWebshells uses pure Go ReadDir to scan for known webshell files
// and directories. No `find` command needed.
//
// Scans each home dir's public_html plus sibling addon-domain docroots
// (skipping mail/etc/logs/ssl/tmp and dotdirs), to a depth of 8. File
// names are matched case-insensitively; directory names case-sensitively
// (see scanForWebshells).
func CheckWebshells(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Known webshell filenames (compared against lowercased names).
	webshellNames := map[string]bool{
		"h4x0r.php": true, "c99.php": true, "r57.php": true,
		"wso.php": true, "alfa.php": true, "b374k.php": true,
		"mini.php": true, "adminer.php": true,
	}
	// Directory names dropped by known webshell kits.
	webshellDirs := map[string]bool{
		"LEVIATHAN": true, "haxorcgiapi": true,
	}
	// Scan each user's public_html and addon domains
	homeDirs, _ := GetScanHomeDirs()
	for _, homeEntry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !homeEntry.IsDir() {
			continue
		}
		homeDir := filepath.Join("/home", homeEntry.Name())
		// Get all potential document roots
		docRoots := []string{filepath.Join(homeDir, "public_html")}
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sd := range subDirs {
			if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
				!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
				sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
				docRoots = append(docRoots, filepath.Join(homeDir, sd.Name()))
			}
		}
		for _, docRoot := range docRoots {
			scanForWebshells(ctx, docRoot, 8, webshellNames, webshellDirs, cfg, &findings)
			if ctx.Err() != nil {
				return findings
			}
		}
	}
	return findings
}
// scanForWebshells recursively reads directories looking for known webshell
// files and directories. Uses ReadDir (getdents) - no stat unless matched.
//
// Detection rules, applied per entry (after cfg.Suppressions.IgnorePaths
// filtering via matchGlob):
//   - directory whose name is in dirs            -> Critical "webshell" (then recursed into)
//   - file whose lowercased name is in names     -> Critical "webshell"
//   - file ending in .haxor or .cgix             -> Critical "webshell"
//   - world-writable .php file                   -> High "world_writable_php"
//
// Recursion is bounded by maxDepth and aborts when ctx is cancelled.
func scanForWebshells(ctx context.Context, dir string, maxDepth int, names map[string]bool, dirs map[string]bool, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil {
		return
	}
	if maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		// Unreadable directory — skip silently.
		return
	}
	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}
		name := entry.Name()
		fullPath := filepath.Join(dir, name)
		// Check suppressed paths
		suppressed := false
		for _, ignore := range cfg.Suppressions.IgnorePaths {
			if matchGlob(fullPath, ignore) {
				suppressed = true
				break
			}
		}
		if suppressed {
			continue
		}
		if entry.IsDir() {
			// Note: directory names are matched case-SENSITIVELY, unlike
			// file names below — presumably deliberate; confirm.
			if dirs[name] {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "webshell",
					Message:  fmt.Sprintf("Webshell directory found: %s", fullPath),
					FilePath: fullPath,
				})
			}
			scanForWebshells(ctx, fullPath, maxDepth-1, names, dirs, cfg, findings)
			continue
		}
		nameLower := strings.ToLower(name)
		if names[nameLower] {
			// Stat only on a match to enrich the finding; a nil info just
			// means Details stays empty.
			info, _ := osFS.Stat(fullPath)
			var details string
			if info != nil {
				details = fmt.Sprintf("Size: %d, Mtime: %s", info.Size(), info.ModTime())
			}
			*findings = append(*findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "webshell",
				Message:  fmt.Sprintf("Known webshell found: %s", fullPath),
				Details:  details,
				FilePath: fullPath,
			})
		}
		// .haxor extension
		if strings.HasSuffix(nameLower, ".haxor") || strings.HasSuffix(nameLower, ".cgix") {
			*findings = append(*findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "webshell",
				Message:  fmt.Sprintf("Suspicious CGI file: %s", fullPath),
				FilePath: fullPath,
			})
		}
		// File permission anomalies - only check PHP files to keep it fast
		if strings.HasSuffix(nameLower, ".php") {
			info, err := entry.Info()
			if err == nil {
				mode := info.Mode()
				// World-writable PHP
				if mode&0002 != 0 {
					*findings = append(*findings, alert.Finding{
						Severity: alert.High,
						Check:    "world_writable_php",
						Message:  fmt.Sprintf("World-writable PHP file: %s", fullPath),
						Details:  fmt.Sprintf("Mode: %s", mode),
						FilePath: fullPath,
					})
				}
				// Note: executable PHP check removed - most PHP files on cPanel
				// have +x due to suPHP/lsapi, making this too noisy.
			}
		}
	}
}
// matchGlob reports whether path matches a suppression pattern. It first
// tries the pattern as a filepath glob against the path's base name; if that
// fails it falls back to a loose substring test with every '*' wildcard
// stripped from the pattern.
func matchGlob(path, pattern string) bool {
	if ok, _ := filepath.Match(pattern, filepath.Base(path)); ok {
		return true
	}
	return strings.Contains(path, strings.ReplaceAll(pattern, "*", ""))
}
package checks
import (
"context"
"fmt"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFirewall verifies that the CSM-managed nftables table exists, contains
// the expected chains/sets, and has not been modified outside of CSM (via a
// structural hash persisted in the store). Returns nil-equivalent (empty)
// findings when firewall management is disabled in config. Also appends the
// dangerous-port checks from checkDangerousPorts.
func CheckFirewall(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	if !cfg.Firewall.Enabled {
		// Firewall not managed by CSM - skip nftables checks
		return findings
	}
	// Verify the CSM nftables table exists and has expected components.
	// Routed through cmdExec so tests can mock the nft response without
	// requiring a real nftables stack.
	out, err := cmdExec.RunAllowNonZero("nft", "list", "table", "inet", "csm")
	if err != nil {
		// Table missing entirely — nothing else to check.
		findings = append(findings, alert.Finding{
			Severity:  alert.Critical,
			Check:     "firewall",
			Message:   "CSM firewall table not found in nftables - rules may not be active",
			Timestamp: time.Now(),
		})
		return findings
	}
	output := string(out)
	// Components CSM always installs; absence of any means manual tampering
	// or a partially applied ruleset.
	required := []string{"chain input", "chain output", "set blocked_ips", "set allowed_ips", "set infra_ips"}
	for _, component := range required {
		if !strings.Contains(output, component) {
			findings = append(findings, alert.Finding{
				Severity:  alert.High,
				Check:     "firewall",
				Message:   fmt.Sprintf("Firewall missing expected component: %s", component),
				Timestamp: time.Now(),
			})
		}
	}
	// Hash the rule structure (excluding dynamic set elements which change on every block/unblock).
	// Only detects if someone manually added/removed chains or rules.
	var stableLines []byte
	for _, line := range strings.Split(string(out), "\n") {
		trimmed := strings.TrimSpace(line)
		// Skip set element lines (start with IPs or "elements = {" or "expires")
		if strings.HasPrefix(trimmed, "elements") || strings.Contains(trimmed, "expires") {
			continue
		}
		// Skip empty element continuation lines (just IPs)
		// NOTE(review): this only catches IPv4/numeric-leading lines; IPv6
		// elements starting with a hex letter would still be hashed — confirm.
		if len(trimmed) > 0 && (trimmed[0] >= '0' && trimmed[0] <= '9') {
			continue
		}
		stableLines = append(stableLines, line...)
		stableLines = append(stableLines, '\n')
	}
	hash := hashBytes(stableLines)
	prev, exists := store.GetRaw("_nftables_rules_hash")
	if exists && prev != hash {
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "firewall",
			Message:   "nftables ruleset modified outside of CSM",
			Timestamp: time.Now(),
		})
	}
	// Always refresh the stored hash so the next run compares against the
	// current state (one alert per modification, not per run).
	store.SetRaw("_nftables_rules_hash", hash)
	// Check for dangerous ports in config
	findings = append(findings, checkDangerousPorts(cfg)...)
	return findings
}
// checkDangerousPorts flags publicly-open firewall TCP ports that are either
// known backdoor ports (cfg.BackdoorPorts) or configured as infra-only
// (cfg.Firewall.RestrictedTCP). A port matching both lists yields two findings.
func checkDangerousPorts(cfg *config.Config) []alert.Finding {
	var findings []alert.Finding
	// Build lookup sets from the config once.
	backdoor := make(map[int]bool, len(cfg.BackdoorPorts))
	for _, p := range cfg.BackdoorPorts {
		backdoor[p] = true
	}
	infraOnly := make(map[int]bool, len(cfg.Firewall.RestrictedTCP))
	for _, p := range cfg.Firewall.RestrictedTCP {
		infraOnly[p] = true
	}
	// All findings share severity/check; only the message differs.
	add := func(msg string) {
		findings = append(findings, alert.Finding{
			Severity:  alert.High,
			Check:     "firewall_ports",
			Message:   msg,
			Timestamp: time.Now(),
		})
	}
	for _, port := range cfg.Firewall.TCPIn {
		if backdoor[port] {
			add(fmt.Sprintf("Known backdoor port %d is open in firewall TCP_IN", port))
		}
		if infraOnly[port] {
			add(fmt.Sprintf("Restricted port %d found in public TCP_IN - should be infra-only", port))
		}
	}
	return findings
}
package checks
import (
"bufio"
"context"
"crypto/sha256"
"fmt"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// parseValiasLine parses a valiases line of the form "local_part: destination".
// Comments, blank lines, and lines without a ':' separator yield ("", "").
func parseValiasLine(line string) (localPart, dest string) {
	trimmed := strings.TrimSpace(line)
	if trimmed == "" || trimmed[0] == '#' {
		return "", ""
	}
	colon := strings.IndexByte(trimmed, ':')
	if colon < 0 {
		return "", ""
	}
	localPart = strings.TrimSpace(trimmed[:colon])
	dest = strings.TrimSpace(trimmed[colon+1:])
	return localPart, dest
}
// isPipeForwarder reports whether dest pipes mail into a command ("|...").
// Known-safe cPanel built-in pipes (autoresponder, BoxTrapper, mailman) are
// excluded, since those are legitimate stock destinations.
func isPipeForwarder(dest string) bool {
	if dest == "" || dest[0] != '|' {
		return false
	}
	builtins := [...]string{
		"/usr/local/cpanel/bin/autorespond",
		"/usr/local/cpanel/bin/boxtrapper",
		"/usr/local/cpanel/bin/mailman",
	}
	for _, b := range builtins {
		if strings.Contains(dest, b) {
			return false
		}
	}
	return true
}
// isDevNullForwarder reports whether dest discards mail into /dev/null.
func isDevNullForwarder(dest string) bool {
	const blackhole = "/dev/null"
	return dest == blackhole
}
// isExternalDest reports whether dest looks like an email address whose
// domain is NOT in the localDomains set. Strings without an '@' followed by
// at least one domain character are never considered external.
func isExternalDest(dest string, localDomains map[string]bool) bool {
	at := strings.LastIndexByte(dest, '@')
	if at < 0 || at == len(dest)-1 {
		return false // no '@', or nothing after it — not a routable address
	}
	return !localDomains[strings.ToLower(dest[at+1:])]
}
// parseVfilterExternalDests extracts external email destinations from vfilter
// content. Only lines of the form `to "dest@domain"` (space- or tab-separated)
// are considered; the first double-quoted token is taken as the destination
// and kept when isExternalDest says its domain is not local.
func parseVfilterExternalDests(content string, localDomains map[string]bool) []string {
	var external []string
	for _, raw := range strings.Split(content, "\n") {
		line := strings.TrimSpace(raw)
		if !strings.HasPrefix(line, "to ") && !strings.HasPrefix(line, "to\t") {
			continue
		}
		// Pull out the first double-quoted token on the line.
		open := strings.IndexByte(line, '"')
		if open < 0 {
			continue
		}
		end := strings.IndexByte(line[open+1:], '"')
		if end < 0 {
			continue
		}
		if dest := line[open+1 : open+1+end]; isExternalDest(dest, localDomains) {
			external = append(external, dest)
		}
	}
	return external
}
// parseLocalDomainsContent parses the content of /etc/localdomains or
// /etc/virtualdomains into a lowercase domain set. Blank lines and '#'
// comments are skipped; virtualdomains-style "domain: user" lines contribute
// only the domain part.
func parseLocalDomainsContent(content string) map[string]bool {
	domains := make(map[string]bool)
	for _, raw := range strings.Split(content, "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || entry[0] == '#' {
			continue
		}
		// virtualdomains format: "domain: user" — keep the domain part.
		if colon := strings.IndexByte(entry, ':'); colon > 0 {
			entry = strings.TrimSpace(entry[:colon])
		}
		domains[strings.ToLower(entry)] = true
	}
	return domains
}
// loadLocalDomains reads /etc/localdomains and /etc/virtualdomains and merges
// their domains into one set. A missing or unreadable file is skipped —
// normal on hosts without Exim/cPanel.
func loadLocalDomains() map[string]bool {
	merged := make(map[string]bool)
	paths := []string{"/etc/localdomains", "/etc/virtualdomains"}
	for _, p := range paths {
		data, err := osFS.ReadFile(p)
		if err != nil {
			continue
		}
		for domain := range parseLocalDomainsContent(string(data)) {
			merged[domain] = true
		}
	}
	return merged
}
// isKnownForwarder reports whether the forwarder rule (localPart@domain ->
// dest) appears in the knownForwarders suppression list. Entries are compared
// case-insensitively after trimming, in the canonical form
// "local@domain: dest".
func isKnownForwarder(localPart, domain, dest string, knownForwarders []string) bool {
	want := localPart + "@" + domain + ": " + dest
	for _, known := range knownForwarders {
		if strings.EqualFold(strings.TrimSpace(known), want) {
			return true
		}
	}
	return false
}
// fileContentHash returns the lowercase hex SHA256 of the file at path, or an
// error if the file cannot be read.
func fileContentHash(path string) (string, error) {
	data, err := osFS.ReadFile(path)
	if err != nil {
		return "", err
	}
	sum := sha256.Sum256(data)
	return fmt.Sprintf("%x", sum[:]), nil
}
// CheckForwarders audits all valiases and vfilters files for dangerous forwarder
// patterns. Uses internal throttle: skips if last refresh was less than
// password_check_interval_min ago (reuses the same interval).
//
// After auditing, the SHA256 of each file is stored keyed by
// "valiases:<domain>"/"vfilters:<domain>" so the next run can tell changed
// files from unchanged ones (the audit helpers use the same keys).
// Returns nil when the global store is unavailable or the throttle is active.
func CheckForwarders(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}
	// Internal throttle (24h) - reuse PasswordCheckIntervalMin
	if !ForceAll {
		lastRefreshStr := db.GetMetaString("email:fwd_last_refresh")
		if lastRefreshStr != "" {
			if lastRefresh, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
				interval := time.Duration(cfg.EmailProtection.PasswordCheckIntervalMin) * time.Minute
				if time.Since(lastRefresh) < interval {
					return nil
				}
			}
		}
	}
	localDomains := loadLocalDomains()
	var findings []alert.Finding
	// Audit valiases
	valiasFiles, _ := osFS.Glob("/etc/valiases/*")
	for _, path := range valiasFiles {
		domain := filepath.Base(path)
		entries := auditValiasFile(path, domain, localDomains, cfg)
		findings = append(findings, entries...)
		// Store hash for change detection (enrichment, not filtering)
		// NOTE(review): auditValiasFile also stores this hash — this second
		// write re-reads and re-hashes the file; presumably harmless
		// redundancy, confirm before removing either.
		hash, err := fileContentHash(path)
		if err == nil {
			_ = db.SetForwarderHash("valiases:"+domain, hash)
		}
	}
	// Audit vfilters
	vfilterFiles, _ := osFS.Glob("/etc/vfilters/*")
	for _, path := range vfilterFiles {
		domain := filepath.Base(path)
		entries := auditVfilterFile(path, domain, localDomains, cfg)
		findings = append(findings, entries...)
		hash, err := fileContentHash(path)
		if err == nil {
			_ = db.SetForwarderHash("vfilters:"+domain, hash)
		}
	}
	// Record the refresh time for the throttle above.
	_ = db.SetMetaString("email:fwd_last_refresh", time.Now().Format(time.RFC3339))
	return findings
}
// auditValiasFile parses a valiases file and returns findings for dangerous entries:
//   - pipe forwarders (Critical) — execute arbitrary commands on incoming mail
//   - /dev/null blackholes (High)
//   - external forwarders (High) — only when the file's hash changed,
//     i.e. the forwarder was newly added
//
// Entries present in cfg.EmailProtection.KnownForwarders are suppressed.
// Returns nil if the file cannot be opened.
func auditValiasFile(path, domain string, localDomains map[string]bool, cfg *config.Config) []alert.Finding {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer f.Close()
	db := store.Global()
	var findings []alert.Finding
	// Check if file hash changed (for "newly added" context).
	// On first scan (no prior hash), store baseline but don't mark as new —
	// only flag as "newly added" when the hash actually changed from a
	// previously known value. This prevents flooding on fresh installs.
	isNew := false
	if db != nil {
		currentHash, hashErr := fileContentHash(path)
		if hashErr == nil {
			oldHash, found := db.GetForwarderHash("valiases:" + domain)
			if found && oldHash != currentHash {
				isNew = true // hash changed — genuinely new forwarder
			}
			// Always store current hash (establishes baseline on first run)
			_ = db.SetForwarderHash("valiases:"+domain, currentHash)
		}
	}
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		localPart, dest := parseValiasLine(scanner.Text())
		if localPart == "" || dest == "" {
			continue
		}
		// Check each destination (may be comma-separated)
		dests := strings.Split(dest, ",")
		for _, d := range dests {
			d = strings.TrimSpace(d)
			if d == "" {
				continue
			}
			// Suppression check
			if isKnownForwarder(localPart, domain, d, cfg.EmailProtection.KnownForwarders) {
				continue
			}
			newContext := ""
			if isNew {
				newContext = " (newly added)"
			}
			// Pipe forwarders are always flagged, new or not.
			if isPipeForwarder(d) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_pipe_forwarder",
					Message:  fmt.Sprintf("Pipe forwarder detected: %s@%s -> %s%s", localPart, domain, d, newContext),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s\nPipe forwarders execute arbitrary commands on incoming mail.", domain, localPart, d, path),
				})
				continue
			}
			// Blackholes are always flagged as well.
			if isDevNullForwarder(d) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  fmt.Sprintf("Mail blackhole: %s@%s -> /dev/null%s", localPart, domain, newContext),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: /dev/null\nFile: %s\nAll mail to this address is silently discarded.", domain, localPart, path),
				})
				continue
			}
			// External forwarders only alert when the file changed — existing
			// ones are treated as normal customer configuration.
			if isExternalDest(d, localDomains) && isNew {
				severity := alert.High
				msg := fmt.Sprintf("External forwarder: %s@%s -> %s%s", localPart, domain, d, newContext)
				if localPart == "*" {
					// Catch-alls forwarding everything off-server are a
					// classic exfiltration pattern — call them out explicitly.
					msg = fmt.Sprintf("Wildcard catch-all to external: *@%s -> %s%s", domain, d, newContext)
				}
				findings = append(findings, alert.Finding{
					Severity: severity,
					Check:    "email_suspicious_forwarder",
					Message:  msg,
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, d, path),
				})
			}
		}
	}
	return findings
}
// auditVfilterFile parses a vfilters file and returns findings for external
// destinations, but only when the file's hash changed since the last scan
// (same first-run baseline logic as auditValiasFile — existing filters are
// normal customer configuration). Entries matching the KnownForwarders
// suppression list (with "*" as the local part) are skipped.
// Returns nil if the file cannot be read.
func auditVfilterFile(path, domain string, localDomains map[string]bool, cfg *config.Config) []alert.Finding {
	data, err := osFS.ReadFile(path)
	if err != nil {
		return nil
	}
	db := store.Global()
	content := string(data)
	// Check if file hash changed — same first-run logic as valiases above.
	isNew := false
	if db != nil {
		currentHash := fmt.Sprintf("%x", sha256.Sum256(data))
		oldHash, found := db.GetForwarderHash("vfilters:" + domain)
		if found && oldHash != currentHash {
			isNew = true
		}
		// Always store the current hash so the first run establishes a baseline.
		_ = db.SetForwarderHash("vfilters:"+domain, currentHash)
	}
	externalDests := parseVfilterExternalDests(content, localDomains)
	var findings []alert.Finding
	for _, dest := range externalDests {
		// Suppression check - use "*" as localPart for vfilter entries
		if isKnownForwarder("*", domain, dest, cfg.EmailProtection.KnownForwarders) {
			continue
		}
		// Only alert when the vfilter file actually changed — existing forwarders
		// are normal customer configuration, not an attack indicator.
		if !isNew {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "email_suspicious_forwarder",
			Message:  fmt.Sprintf("External destination in vfilter: %s -> %s (newly added)", domain, dest),
			Details:  fmt.Sprintf("Domain: %s\nDestination: %s\nFile: %s\nA mail filter rule forwards messages to an external address.", domain, dest, path),
		})
	}
	return findings
}
package checks
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckOpenBasedir verifies that each cPanel account has proper PHP
// isolation via CageFS and/or open_basedir.
// Flags accounts where CageFS is disabled AND open_basedir is not set.
//
// Account names are taken from /var/cpanel/users directory entries.
// NOTE(review): ctx is accepted but never consulted — long scans will not
// abort on cancellation; confirm whether that is acceptable here.
func CheckOpenBasedir(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Check global CageFS mode
	cageFSMode := getCageFSMode()
	// Get list of disabled CageFS users (if mode is "Enable All", check exceptions)
	disabledUsers := getCageFSDisabledUsers(cageFSMode)
	// For users without CageFS, check if open_basedir is set
	userDirs, _ := osFS.ReadDir("/var/cpanel/users")
	for _, userEntry := range userDirs {
		user := userEntry.Name()
		// Skip if CageFS is active for this user
		if !disabledUsers[user] {
			continue
		}
		// CageFS is disabled for this user - check open_basedir
		if !hasOpenBasedir(user) {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "open_basedir",
				Message:  fmt.Sprintf("Account %s has no PHP isolation: CageFS disabled and no open_basedir", user),
				Details:  "This account's PHP scripts can read any file on the server including other accounts' wp-config.php and /etc/shadow",
			})
		}
	}
	return findings
}
// getCageFSMode returns "enabled" when the CageFS mount-points file
// (/etc/cagefs/cagefs.mp) exists and is non-empty, otherwise "unknown".
// A non-empty cagefs.mp is used as the signal that CageFS is installed.
func getCageFSMode() string {
	data, err := osFS.ReadFile("/etc/cagefs/cagefs.mp")
	if err != nil || strings.TrimSpace(string(data)) == "" {
		return "unknown"
	}
	return "enabled"
}
// getCageFSDisabledUsers returns the set of accounts without CageFS
// protection: the users listed by `cagefsctl --list-disabled`, or — when
// CageFS is absent (mode "unknown") and that list is empty — every account
// under /var/cpanel/users.
func getCageFSDisabledUsers(mode string) map[string]bool {
	disabled := make(map[string]bool)
	// Ask CageFS for explicit exceptions.
	if out, _ := runCmd("cagefsctl", "--list-disabled"); out != nil {
		for _, raw := range strings.Split(strings.TrimSpace(string(out)), "\n") {
			if name := strings.TrimSpace(raw); name != "" {
				disabled[name] = true
			}
		}
	}
	// Without CageFS at all, every account counts as unprotected.
	if mode == "unknown" && len(disabled) == 0 {
		entries, _ := osFS.ReadDir("/var/cpanel/users")
		for _, e := range entries {
			disabled[e.Name()] = true
		}
	}
	return disabled
}
// hasOpenBasedir reports whether PHP open_basedir appears to be configured
// for the given cPanel account. It checks, in order:
//  1. ~/public_html/.user.ini
//  2. ~/public_html/.htaccess (php_value open_basedir)
//  3. MultiPHP ini drop-ins under /opt/cpanel/ea-php*/root/etc/php.d
//
// NOTE(review): step 3 reads the same global local.ini for every user, so a
// server-wide open_basedir would make this return true for all accounts —
// confirm that is the intended semantics.
func hasOpenBasedir(user string) bool {
	docRoot := filepath.Join("/home", user, "public_html")
	// Per-directory PHP overrides (case-insensitive match, as .htaccess
	// directives may be written in any case).
	for _, name := range []string{".user.ini", ".htaccess"} {
		if data, err := osFS.ReadFile(filepath.Join(docRoot, name)); err == nil {
			if strings.Contains(strings.ToLower(string(data)), "open_basedir") {
				return true
			}
		}
	}
	// Check per-user PHP config set via cPanel MultiPHP.
	// The glob pattern must NOT end with a path separator: filepath.Glob
	// splits the pattern on its final separator and would then match only
	// entries with an empty name — i.e. it would always return nothing.
	phpConfDirs, _ := osFS.Glob("/opt/cpanel/ea-php*/root/etc/php.d")
	for _, confDir := range phpConfDirs {
		userConf := filepath.Join(confDir, "local.ini")
		if data, err := osFS.ReadFile(userConf); err == nil {
			if strings.Contains(string(data), "open_basedir") {
				return true
			}
		}
	}
	return false
}
// CheckSymlinkAttacks detects symbolic links inside user public_html
// directories that point outside the account's own directory.
// This is a classic shared hosting attack to read other users' files.
//
// Each account's public_html is walked to a depth of 4 by
// scanForMaliciousSymlinks, which appends findings in place.
// NOTE(review): ctx is accepted but never consulted — the walk cannot be
// cancelled; confirm that is acceptable for this check.
func CheckSymlinkAttacks(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	homeDirs, _ := GetScanHomeDirs()
	for _, homeEntry := range homeDirs {
		if !homeEntry.IsDir() {
			continue
		}
		user := homeEntry.Name()
		homeDir := filepath.Join("/home", user)
		docRoot := filepath.Join(homeDir, "public_html")
		scanForMaliciousSymlinks(docRoot, user, homeDir, 4, &findings)
	}
	return findings
}
// scanForMaliciousSymlinks recursively walks dir (up to maxDepth levels) and
// flags symlinks whose resolved target escapes the user's home: links into
// another user's /home/<other> tree and links to sensitive system paths
// (/etc/shadow, /etc/passwd, /root/) both yield Critical findings.
// Targets accepted by isSymlinkSafe (own home, standard cPanel locations)
// are ignored. Relative targets are resolved against the link's directory
// lexically (filepath.Clean) — intermediate symlinks are not chased.
func scanForMaliciousSymlinks(dir, user, homeDir string, maxDepth int, findings *[]alert.Finding) {
	if maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		// Unreadable directory — skip quietly.
		return
	}
	for _, entry := range entries {
		fullPath := filepath.Join(dir, entry.Name())
		// Check if it's a symlink
		if entry.Type()&os.ModeSymlink == 0 {
			// Not a symlink: recurse into plain directories, skip plain files.
			if entry.IsDir() {
				scanForMaliciousSymlinks(fullPath, user, homeDir, maxDepth-1, findings)
			}
			continue
		}
		// It's a symlink - read the target
		target, err := osFS.Readlink(fullPath)
		if err != nil {
			continue
		}
		// Resolve relative targets
		if !filepath.IsAbs(target) {
			target = filepath.Join(filepath.Dir(fullPath), target)
		}
		target = filepath.Clean(target)
		// Check if target is outside the user's home
		if isSymlinkSafe(target, user, homeDir) {
			continue
		}
		// Check if target points to another user's home
		if strings.HasPrefix(target, "/home/") {
			parts := strings.SplitN(target[6:], "/", 2)
			if len(parts) > 0 && parts[0] != user {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "symlink_attack",
					Message:  fmt.Sprintf("Symlink to another user's directory: %s -> %s", fullPath, target),
					Details:  fmt.Sprintf("User: %s, target user: %s\nThis could be used to read other accounts' files", user, parts[0]),
				})
				continue
			}
		}
		// Check if target points to sensitive system files
		sensitiveTargets := []string{"/etc/shadow", "/etc/passwd", "/root/"}
		for _, sens := range sensitiveTargets {
			if strings.HasPrefix(target, sens) {
				*findings = append(*findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "symlink_attack",
					Message:  fmt.Sprintf("Symlink to sensitive system file: %s -> %s", fullPath, target),
					Details:  fmt.Sprintf("User: %s", user),
				})
				break
			}
		}
	}
}
// isSymlinkSafe reports whether a symlink target is considered benign:
// either inside the user's own home directory, or under one of the standard
// locations cPanel itself symlinks into (log dirs, cPanel trees, /usr, etc.).
func isSymlinkSafe(target, user, homeDir string) bool {
	// The user's own home (exact match or anything below it) is always fine.
	if target == homeDir || strings.HasPrefix(target, homeDir+"/") {
		return true
	}
	// Prefixes of symlinks that cPanel creates as part of normal operation.
	for _, prefix := range []string{
		"/etc/apache2/logs/",
		"/usr/local/apache/logs/",
		"/var/cpanel/",
		"/opt/cpanel/",
		"/usr/",
		"/var/lib/mysql/",
		"/var/run/",
	} {
		if strings.HasPrefix(target, prefix) {
			return true
		}
	}
	return false
}
package checks
import (
"bufio"
"context"
"encoding/hex"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// auditCmdTimeout is the per-subprocess timeout for audit checks.
// Audit checks are fast config reads, not heavy scans.
// Enforced by auditRunCmd via context.WithTimeout.
const auditCmdTimeout = 10 * time.Second
// RunHardeningAudit runs all hardening checks and returns a report.
// Pure function — reads system state only, no store access.
//
// The report's Score is the count of "pass" results and Total the number of
// checks run. cPanel-specific checks are skipped on "bare" servers.
func RunHardeningAudit(cfg *config.Config) *store.AuditReport {
	serverType := detectServerType()
	var results []store.AuditResult
	results = append(results, auditSSH()...)
	results = append(results, auditPHP(serverType)...)
	results = append(results, auditWebServer(serverType)...)
	results = append(results, auditMail()...)
	if serverType != "bare" {
		// cPanel/CloudLinux only — these checks don't apply to bare servers.
		results = append(results, auditCPanel(serverType)...)
	}
	results = append(results, auditOS()...)
	results = append(results, auditFirewall()...)
	// Score counts passing checks; "warn" and "fail" both count against it.
	score := 0
	for _, r := range results {
		if r.Status == "pass" {
			score++
		}
	}
	return &store.AuditReport{
		Timestamp:  time.Now(),
		ServerType: serverType,
		Results:    results,
		Score:      score,
		Total:      len(results),
	}
}
// detectServerType classifies the host for the audit: "cloudlinux" for
// cPanel on CloudLinux, "cpanel" for cPanel on any other OS, and "bare"
// when cPanel is not installed.
func detectServerType() string {
	info := platform.Detect()
	switch {
	case !info.IsCPanel():
		return "bare"
	case info.OS == platform.OSCloudLinux:
		return "cloudlinux"
	default:
		return "cpanel"
	}
}
// auditRunCmd executes a command with the audit-specific timeout via the
// cmdExec injector so tests can mock systemctl/cagefsctl/etc. without
// invoking real binaries on the host.
//
// On timeout the partial output is discarded and a "command timed out"
// error is returned instead of whatever error RunContext produced.
func auditRunCmd(name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), auditCmdTimeout)
	defer cancel()
	out, err := cmdExec.RunContext(ctx, name, args...)
	if ctx.Err() == context.DeadlineExceeded {
		return nil, fmt.Errorf("command timed out: %s", name)
	}
	return out, err
}
// --- SSH checks ---
// sshdDefaults are the OpenSSH compiled defaults for settings we audit.
// Used by sshdEffective when a keyword is absent from the parsed config.
// Keys are lowercased keyword names, matching parseSSHDFile's normalization.
var sshdDefaults = map[string]string{
	"port":                   "22",
	"protocol":               "2",
	"passwordauthentication": "yes",
	"permitrootlogin":        "prohibit-password",
	"maxauthtries":           "6",
	"x11forwarding":          "no",
	"usedns":                 "no",
}

// sshdConfigPath is the root sshd configuration file; a variable (not a
// const) so tests can point it at a fixture.
var sshdConfigPath = "/etc/ssh/sshd_config"

// sshdSettings holds the effective values of the sshd options other parts
// of the system consult (see currentSSHDSettings).
type sshdSettings struct {
	PasswordAuthentication string
	PermitRootLogin        string
	X11Forwarding          string
}
// parseSSHDConfig reads sshd_config + Include drop-ins with first-match-wins.
// Match blocks are skipped entirely (audit evaluates global config only).
// Returns a map of lowercased keyword -> raw value; absent keywords should be
// resolved through sshdEffective to pick up compiled defaults.
func parseSSHDConfig() map[string]string {
	effective := make(map[string]string)
	parseSSHDFile(sshdConfigPath, effective)
	return effective
}
// parseSSHDFile reads one sshd_config(5) file into effective, following
// Include directives (globs resolved relative to the including file) and
// skipping Match blocks entirely — a Match block runs until the next Match
// keyword or EOF. The first occurrence of a keyword wins, matching sshd's
// own semantics; keywords are stored lowercased.
func parseSSHDFile(path string, effective map[string]string) {
	f, err := osFS.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	inMatch := false
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Detect Match blocks — a Match block continues until the next
		// Match keyword or EOF, regardless of indentation (per sshd_config(5)).
		lower := strings.ToLower(line)
		if strings.HasPrefix(lower, "match ") {
			inMatch = true
			continue
		}
		if inMatch {
			// Only another Match line (handled above) or EOF ends the block.
			// Everything else inside is a Match-scoped directive — skip it.
			continue
		}
		// Handle Include directives
		if strings.HasPrefix(lower, "include ") {
			pattern := strings.TrimSpace(line[8:])
			if !filepath.IsAbs(pattern) {
				pattern = filepath.Join(filepath.Dir(path), pattern)
			}
			matches, _ := osFS.Glob(pattern)
			for _, m := range matches {
				parseSSHDFile(m, effective)
			}
			continue
		}
		// Parse "Keyword value", "Keyword=value", or tab-separated forms.
		// sshd_config(5) allows whitespace or optional '=' between keyword
		// and argument; the previous SplitN(line, " ", 2) silently dropped
		// tab-separated directives, so split on the first space OR tab,
		// falling back to '=' for the no-whitespace form.
		sep := strings.IndexAny(line, " \t")
		if sep < 0 {
			sep = strings.IndexByte(line, '=')
		}
		if sep <= 0 {
			continue
		}
		keyword := strings.ToLower(strings.TrimSpace(line[:sep]))
		// Tolerate an optional '=' after the keyword ("Port = 22").
		value := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line[sep+1:]), "="))
		if value == "" {
			continue
		}
		// First-match-wins: only record the first occurrence
		if _, exists := effective[keyword]; !exists {
			effective[keyword] = value
		}
	}
}
// sshdEffective returns the lowercased configured value for key from the
// parsed sshd config, or the compiled-in OpenSSH default from sshdDefaults
// when the keyword was not set explicitly.
func sshdEffective(parsed map[string]string, key string) string {
	v, ok := parsed[key]
	if !ok {
		return sshdDefaults[key]
	}
	return strings.ToLower(v)
}
// currentSSHDSettings parses the live sshd configuration and returns the
// effective values (config or compiled default) for the three settings the
// rest of the system consults.
func currentSSHDSettings() sshdSettings {
	parsed := parseSSHDConfig()
	return sshdSettings{
		PasswordAuthentication: sshdEffective(parsed, "passwordauthentication"),
		PermitRootLogin:        sshdEffective(parsed, "permitrootlogin"),
		X11Forwarding:          sshdEffective(parsed, "x11forwarding"),
	}
}
// auditSSH evaluates the effective sshd configuration against hardening
// recommendations and returns one AuditResult per setting:
// port, protocol, password auth, root login, MaxAuthTries, X11 forwarding,
// and UseDNS. Each result is "pass", "warn", or "fail" with a Fix hint on
// non-passing results.
func auditSSH() []store.AuditResult {
	parsed := parseSSHDConfig()
	var results []store.AuditResult
	// ssh_port — default port 22 attracts automated scans (warn only).
	port := sshdEffective(parsed, "port")
	if port == "22" {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_port", Title: "SSH Port",
			Status: "warn", Message: "SSH is running on default port 22",
			Fix: "Change to a non-standard port in /etc/ssh/sshd_config to reduce automated scan noise. Update firewall rules before changing.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_port", Title: "SSH Port",
			Status: "pass", Message: fmt.Sprintf("SSH on non-standard port %s", port),
		})
	}
	// ssh_protocol — any mention of "1" (e.g. "2,1") counts as SSHv1 enabled.
	proto := sshdEffective(parsed, "protocol")
	if strings.Contains(proto, "1") {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_protocol", Title: "SSH Protocol",
			Status: "fail", Message: "SSHv1 protocol is enabled",
			Fix: "Set 'Protocol 2' in /etc/ssh/sshd_config. SSHv1 has known cryptographic weaknesses.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_protocol", Title: "SSH Protocol",
			Status: "pass", Message: "SSHv1 disabled",
		})
	}
	// ssh_password_auth — anything other than an explicit "no" fails.
	passAuth := sshdEffective(parsed, "passwordauthentication")
	if passAuth != "no" {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_password_auth", Title: "SSH PasswordAuthentication",
			Status: "fail", Message: "Password authentication is enabled",
			Fix: "Set 'PasswordAuthentication no' in /etc/ssh/sshd_config and use SSH key authentication only.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_password_auth", Title: "SSH PasswordAuthentication",
			Status: "pass", Message: "Password authentication disabled",
		})
	}
	// ssh_root_login — only the literal "yes" fails; "no" and
	// "prohibit-password" both pass.
	rootLogin := sshdEffective(parsed, "permitrootlogin")
	if rootLogin == "yes" {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_root_login", Title: "SSH PermitRootLogin",
			Status: "fail", Message: "Direct root login is permitted",
			Fix: "Set 'PermitRootLogin no' or 'PermitRootLogin prohibit-password' in /etc/ssh/sshd_config.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_root_login", Title: "SSH PermitRootLogin",
			Status: "pass", Message: fmt.Sprintf("PermitRootLogin set to %s", rootLogin),
		})
	}
	// ssh_max_auth_tries — unparseable values fall back to the sshd default 6.
	maxTries := sshdEffective(parsed, "maxauthtries")
	n, _ := strconv.Atoi(maxTries)
	if n == 0 {
		n = 6 // default
	}
	if n > 4 {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_max_auth_tries", Title: "SSH MaxAuthTries",
			Status: "warn", Message: fmt.Sprintf("MaxAuthTries is %d (recommended: 4 or less)", n),
			Fix: "Set 'MaxAuthTries 4' in /etc/ssh/sshd_config to limit brute-force attempts per connection.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_max_auth_tries", Title: "SSH MaxAuthTries",
			Status: "pass", Message: fmt.Sprintf("MaxAuthTries set to %d", n),
		})
	}
	// ssh_x11_forwarding — anything other than "no" warns.
	x11 := sshdEffective(parsed, "x11forwarding")
	if x11 != "no" {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_x11_forwarding", Title: "SSH X11Forwarding",
			Status: "warn", Message: "X11 forwarding is enabled",
			Fix: "Set 'X11Forwarding no' in /etc/ssh/sshd_config unless X11 forwarding is actively needed.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_x11_forwarding", Title: "SSH X11Forwarding",
			Status: "pass", Message: "X11 forwarding disabled",
		})
	}
	// ssh_use_dns — anything other than "no" warns.
	useDNS := sshdEffective(parsed, "usedns")
	if useDNS != "no" {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_use_dns", Title: "SSH UseDNS",
			Status: "warn", Message: "UseDNS is enabled",
			Fix: "Set 'UseDNS no' in /etc/ssh/sshd_config. Otherwise lfd may not track SSH login failures by IP.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "ssh", Name: "ssh_use_dns", Title: "SSH UseDNS",
			Status: "pass", Message: "UseDNS disabled",
		})
	}
	return results
}
// --- OS hardening checks ---

// auditOS runs baseline OS hardening checks: sticky bit and ownership of the
// world-writable temp directories, /etc/shadow permissions, swap presence,
// distribution EOL status, a crontab for the nobody user, unnecessary enabled
// services, and a table of sysctl hardening values. Each check appends one
// AuditResult; unreadable inputs produce "warn" rather than aborting the run.
func auditOS() []store.AuditResult {
	var results []store.AuditResult
	// /tmp and /var/tmp permissions: both must be mode 1777 and owned by
	// root:root.
	for _, dir := range []struct {
		path, id, title string
	}{
		{"/tmp", "os_tmp_permissions", "/tmp Permissions"},
		{"/var/tmp", "os_var_tmp_permissions", "/var/tmp Permissions"},
	} {
		info, err := osFS.Stat(dir.path)
		if err != nil {
			results = append(results, store.AuditResult{
				Category: "os", Name: dir.id, Title: dir.title,
				Status: "warn", Message: fmt.Sprintf("Cannot stat %s: %v", dir.path, err),
			})
			continue
		}
		// Use the raw Unix mode bits from syscall to get the traditional
		// permission representation (sticky=01000, setuid=04000, etc.).
		// Go's os.ModeSticky uses high bits that don't map to Unix octal,
		// so os.FileMode math produces wrong values for comparison.
		// Only check the lower 12 bits (sticky + rwx) — ignore setuid/setgid
		// which CloudLinux/CageFS may set on virtmp mounts.
		stat, ok := info.Sys().(*syscall.Stat_t)
		if !ok {
			// Sys() did not yield a *syscall.Stat_t (non-Unix FS or a mock).
			results = append(results, store.AuditResult{
				Category: "os", Name: dir.id, Title: dir.title,
				Status: "warn", Message: fmt.Sprintf("Cannot read ownership of %s", dir.path),
			})
			continue
		}
		mode := stat.Mode & 01777 // sticky + rwxrwxrwx, ignore setuid/setgid
		if mode != 01777 || stat.Uid != 0 || stat.Gid != 0 {
			results = append(results, store.AuditResult{
				Category: "os", Name: dir.id, Title: dir.title,
				Status: "fail",
				Message: fmt.Sprintf("%s has mode %04o uid=%d gid=%d (expected 1777 root:root)", dir.path, mode, stat.Uid, stat.Gid),
				Fix: fmt.Sprintf("chmod 1777 %s && chown root:root %s", dir.path, dir.path),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "os", Name: dir.id, Title: dir.title,
				Status: "pass", Message: fmt.Sprintf("%s is 1777 root:root", dir.path),
			})
		}
	}
	// /etc/shadow permissions
	// Accept 0000, 0600 (RHEL/CentOS default), and 0640 (Debian default).
	// All three restrict access to root only. 0600 is the standard on
	// CentOS/CloudLinux — changing it can break passwd/chage.
	if info, err := osFS.Stat("/etc/shadow"); err == nil {
		perm := info.Mode().Perm()
		if perm == 0 || perm == 0o600 || perm == 0o640 {
			results = append(results, store.AuditResult{
				Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
				Status: "pass", Message: fmt.Sprintf("/etc/shadow has mode %04o", perm),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
				Status: "fail", Message: fmt.Sprintf("/etc/shadow has mode %04o (expected 0000, 0600, or 0640)", perm),
				Fix: "chmod 0600 /etc/shadow",
			})
		}
	} else {
		results = append(results, store.AuditResult{
			Category: "os", Name: "os_shadow_permissions", Title: "/etc/shadow Permissions",
			Status: "warn", Message: "Cannot stat /etc/shadow",
		})
	}
	// Swap: /proc/swaps starts with a header line, so more than one line
	// means at least one active swap device.
	if data, err := osFS.ReadFile("/proc/swaps"); err == nil {
		lines := strings.Split(strings.TrimSpace(string(data)), "\n")
		if len(lines) > 1 {
			results = append(results, store.AuditResult{
				Category: "os", Name: "os_swap", Title: "Swap Configured",
				Status: "pass", Message: fmt.Sprintf("%d swap device(s) active", len(lines)-1),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "os", Name: "os_swap", Title: "Swap Configured",
				Status: "warn", Message: "No swap configured",
				Fix: "Configure swap space to prevent OOM kills: fallocate -l 2G /swapfile && chmod 600 /swapfile && mkswap /swapfile && swapon /swapfile",
			})
		}
	}
	// Distro EOL
	results = append(results, checkDistroEOL()...)
	// nobody crontab: the nobody user should never have scheduled jobs —
	// a populated crontab there is a common persistence mechanism.
	if info, err := osFS.Stat("/var/spool/cron/nobody"); err != nil {
		// absent is fine
		results = append(results, store.AuditResult{
			Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
			Status: "pass", Message: "No crontab for nobody user",
		})
	} else if info.Size() == 0 {
		results = append(results, store.AuditResult{
			Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
			Status: "pass", Message: "Nobody crontab is empty",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "os", Name: "os_nobody_cron", Title: "Nobody Crontab",
			Status: "fail", Message: "nobody user has a crontab with content",
			Fix: "Review and remove: crontab -u nobody -r",
		})
	}
	// Unnecessary services
	results = append(results, checkUnnecessaryServices()...)
	// Sysctl checks (table-driven): each entry pairs a /proc/sys path with
	// its expected hardened value.
	sysctlChecks := []struct {
		id, title, path, expected string
	}{
		{"os_sysctl_syncookies", "TCP SYN Cookies", "/proc/sys/net/ipv4/tcp_syncookies", "1"},
		{"os_sysctl_aslr", "Address Space Layout Randomization", "/proc/sys/kernel/randomize_va_space", "2"},
		{"os_sysctl_rp_filter", "Reverse Path Filtering", "/proc/sys/net/ipv4/conf/all/rp_filter", "1"},
		{"os_sysctl_icmp_broadcast", "ICMP Broadcast Ignore", "/proc/sys/net/ipv4/icmp_echo_ignore_broadcasts", "1"},
		{"os_sysctl_symlinks", "Protected Symlinks", "/proc/sys/fs/protected_symlinks", "1"},
		{"os_sysctl_hardlinks", "Protected Hardlinks", "/proc/sys/fs/protected_hardlinks", "1"},
	}
	for _, sc := range sysctlChecks {
		data, err := osFS.ReadFile(sc.path)
		if err != nil {
			results = append(results, store.AuditResult{
				Category: "os", Name: sc.id, Title: sc.title,
				Status: "warn", Message: fmt.Sprintf("Cannot read %s", sc.path),
			})
			continue
		}
		val := strings.TrimSpace(string(data))
		// Convert /proc/sys path to sysctl dotted notation for fix command
		sysctlKey := strings.TrimPrefix(sc.path, "/proc/sys/")
		sysctlKey = strings.ReplaceAll(sysctlKey, "/", ".")
		if val == sc.expected {
			results = append(results, store.AuditResult{
				Category: "os", Name: sc.id, Title: sc.title,
				Status: "pass", Message: fmt.Sprintf("%s = %s", sysctlKey, val),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "os", Name: sc.id, Title: sc.title,
				Status: "fail", Message: fmt.Sprintf("%s = %s (expected %s)", sysctlKey, val, sc.expected),
				Fix: fmt.Sprintf("sysctl -w %s=%s && echo '%s = %s' >> /etc/sysctl.d/99-csm-hardening.conf", sysctlKey, sc.expected, sysctlKey, sc.expected),
			})
		}
	}
	return results
}
// distroEOLPolicy encodes the oldest supported major version per known OS.
// Anything below the minimum is considered EOL by this check; a distro
// missing from this map yields a "no EOL policy configured" warning in
// evaluateDistroEOL rather than a pass/fail.
var distroEOLPolicy = map[platform.OSFamily]int{
	platform.OSAlma:       8,
	platform.OSRocky:      8,
	platform.OSRHEL:       8,
	platform.OSCloudLinux: 7,
	platform.OSUbuntu:     20, // 20.04 is the oldest non-EOL LTS
	platform.OSDebian:     11, // Debian 11 "bullseye"
}
// checkDistroEOL gathers the detected platform info plus the os-release
// PRETTY_NAME string and hands both to evaluateDistroEOL for classification.
func checkDistroEOL() []store.AuditResult {
	detected := platform.Detect()
	pretty := readOSReleasePretty()
	return evaluateDistroEOL(detected, pretty)
}
// evaluateDistroEOL is the pure, testable core of checkDistroEOL. It
// classifies the supplied platform info and PRETTY_NAME string (either may
// be empty) against distroEOLPolicy and returns a single AuditResult.
func evaluateDistroEOL(info platform.Info, prettyName string) []store.AuditResult {
	// All outcomes share the same category/name/title; mk builds one.
	mk := func(status, message, fix string) []store.AuditResult {
		return []store.AuditResult{{
			Category: "os", Name: "os_distro_eol", Title: "Distribution End of Life",
			Status: status, Message: message, Fix: fix,
		}}
	}
	if prettyName == "" && info.OSVersion != "" {
		prettyName = fmt.Sprintf("%s %s", info.OS, info.OSVersion)
	}
	if info.OS == platform.OSUnknown || info.OSVersion == "" {
		return mk("warn", "Unable to determine distribution version", "")
	}
	// CentOS is EOL regardless of version number.
	if info.OS == platform.OSCentOS {
		return mk("fail",
			fmt.Sprintf("%s — CentOS is end-of-life", prettyName),
			"Migrate to a supported replacement such as AlmaLinux, Rocky Linux, or RHEL. CentOS no longer receives security patches.")
	}
	// The major version is the integer prefix: Ubuntu/Debian report
	// "24.04" / "12", the RHEL family "8.6" / "10", and so on.
	majorStr, _, _ := strings.Cut(info.OSVersion, ".")
	major, err := strconv.Atoi(majorStr)
	if err != nil {
		return mk("warn", fmt.Sprintf("%s — unable to parse version %q", prettyName, info.OSVersion), "")
	}
	minVersion, known := distroEOLPolicy[info.OS]
	if !known {
		return mk("warn", fmt.Sprintf("%s — no EOL policy configured for this distro", prettyName), "")
	}
	if major >= minVersion {
		return mk("pass", prettyName, "")
	}
	// Below the supported floor: tailor the fix text to the distro family.
	fix := "Upgrade to a supported release. EOL distributions receive no security patches."
	switch {
	case info.IsDebianFamily():
		fix = fmt.Sprintf("Upgrade to %s %d+ or newer LTS. EOL distributions receive no security patches.", info.OS, minVersion)
	case info.IsRHELFamily():
		fix = fmt.Sprintf("Upgrade to %s %d+ or newer. EOL distributions receive no security patches.", info.OS, minVersion)
	}
	return mk("fail", fmt.Sprintf("%s — major version %d is EOL", prettyName, major), fix)
}
// readOSReleasePretty returns the unquoted PRETTY_NAME value from
// /etc/os-release, or "" when the file is missing or has no such key.
func readOSReleasePretty() string {
	data, err := osFS.ReadFile("/etc/os-release")
	if err != nil {
		return ""
	}
	const prefix = "PRETTY_NAME="
	for _, line := range strings.Split(string(data), "\n") {
		if !strings.HasPrefix(line, prefix) {
			continue
		}
		// Values may be wrapped in single or double quotes.
		return strings.Trim(line[len(prefix):], `"'`)
	}
	return ""
}
// checkUnnecessaryServices flags enabled systemd units that typically have no
// place on a hosting server (desktop, printing, and discovery daemons, plus
// firewalld which conflicts with csf-style rulesets).
func checkUnnecessaryServices() []store.AuditResult {
	unwanted := map[string]bool{
		"avahi-daemon": true, "bluetooth": true, "cups": true, "cupsd": true, "gdm": true,
		"ModemManager": true, "packagekit": true, "rpcbind": true, "wpa_supplicant": true, "firewalld": true,
	}
	out, err := auditRunCmd("systemctl", "list-unit-files", "--state=enabled", "--no-pager", "--no-legend")
	if err != nil {
		return []store.AuditResult{{
			Category: "os", Name: "os_services", Title: "Unnecessary Services",
			Status: "warn", Message: "Cannot query systemd unit files",
		}}
	}
	var enabled []string
	for _, row := range strings.Split(string(out), "\n") {
		cols := strings.Fields(row)
		if len(cols) == 0 {
			continue
		}
		// First column is the unit file name, e.g. "cups.service".
		name := strings.TrimSuffix(cols[0], ".service")
		if unwanted[name] {
			enabled = append(enabled, name)
		}
	}
	if len(enabled) == 0 {
		return []store.AuditResult{{
			Category: "os", Name: "os_services", Title: "Unnecessary Services",
			Status: "pass", Message: "No unnecessary services enabled",
		}}
	}
	return []store.AuditResult{{
		Category: "os", Name: "os_services", Title: "Unnecessary Services",
		Status: "warn",
		Message: fmt.Sprintf("Unnecessary services enabled: %s", strings.Join(enabled, ", ")),
		Fix: fmt.Sprintf("systemctl disable --now %s", strings.Join(enabled, " ")),
	}}
}
// --- Firewall checks ---

// auditFirewall inspects nftables and iptables state and reports whether a
// firewall is active, whether the input path defaults to deny, whether MySQL
// or telnet are exposed, and whether IPv6 traffic is filtered. Both rulesets
// are captured once up front and shared with the sub-checks.
func auditFirewall() []store.AuditResult {
	var results []store.AuditResult
	// Gather nft and iptables state
	nftOut, nftErr := auditRunCmd("nft", "list", "ruleset")
	nftRules := string(nftOut)
	hasNft := nftErr == nil && strings.TrimSpace(nftRules) != ""
	iptOut, iptErr := auditRunCmd("iptables", "-L", "INPUT", "-n")
	iptRules := string(iptOut)
	hasIpt := iptErr == nil && strings.TrimSpace(iptRules) != ""
	// fw_active: any non-empty ruleset from either tool counts as active.
	if hasNft || hasIpt {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_active", Title: "Firewall Active",
			Status: "pass", Message: "Firewall has active rules",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_active", Title: "Firewall Active",
			Status: "fail", Message: "No active firewall rules detected",
			Fix: "Install and configure nftables or iptables with a default-deny policy.",
		})
	}
	// fw_default_policy: look for a drop/reject default, first anywhere in
	// the nft ruleset text, then in the iptables "Chain INPUT (policy ...)"
	// header line.
	defaultDeny := false
	if hasNft {
		lower := strings.ToLower(nftRules)
		if strings.Contains(lower, "policy drop") || strings.Contains(lower, "policy reject") {
			defaultDeny = true
		}
	}
	if !defaultDeny && hasIpt {
		for _, line := range strings.Split(iptRules, "\n") {
			if strings.HasPrefix(line, "Chain INPUT") {
				upper := strings.ToUpper(line)
				if strings.Contains(upper, "POLICY DROP") || strings.Contains(upper, "POLICY REJECT") {
					defaultDeny = true
				}
				break
			}
		}
	}
	if defaultDeny {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_default_policy", Title: "Default INPUT Policy",
			Status: "pass", Message: "INPUT chain has default-deny policy",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_default_policy", Title: "Default INPUT Policy",
			Status: "fail", Message: "INPUT chain does not have a DROP/REJECT policy",
			Fix: "Set the default INPUT policy to DROP: iptables -P INPUT DROP (or nft equivalent).",
		})
	}
	// fw_mysql_exposed
	results = append(results, checkMySQLExposed(hasNft, nftRules, hasIpt, iptRules)...)
	// fw_telnet: anything listening on 23 is treated as telnet.
	if isPortListening(23) {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_telnet", Title: "Telnet Service",
			Status: "fail", Message: "Something is listening on port 23 (telnet)",
			Fix: "Disable telnet: systemctl disable --now telnet.socket xinetd; use SSH instead.",
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "firewall", Name: "fw_telnet", Title: "Telnet Service",
			Status: "pass", Message: "Nothing listening on port 23",
		})
	}
	// fw_ipv6
	results = append(results, checkIPv6Firewall()...)
	return results
}
// getListeningAddr scans /proc/net/tcp and /proc/net/tcp6 for a socket in
// LISTEN state (0A) on the given port and returns its hex-encoded local IP,
// or "" if nothing is listening there.
func getListeningAddr(port int) string {
	want := fmt.Sprintf("%04X", port)
	for _, table := range []string{"/proc/net/tcp", "/proc/net/tcp6"} {
		data, err := osFS.ReadFile(table)
		if err != nil {
			continue
		}
		for _, row := range strings.Split(string(data), "\n") {
			cols := strings.Fields(row)
			if len(cols) < 4 {
				continue
			}
			// cols[1] = local_address (hex_ip:hex_port), cols[3] = state.
			if cols[3] != "0A" { // 0A = LISTEN
				continue
			}
			hexIP, hexPort, ok := strings.Cut(cols[1], ":")
			if ok && hexPort == want {
				return hexIP
			}
		}
	}
	return ""
}
// hexToIPv4 converts a /proc/net/tcp hex-encoded local address (a
// little-endian 32-bit value, e.g. "0100007F") to dotted-quad notation.
// Anything that is not exactly 8 valid hex digits is returned unchanged.
func hexToIPv4(h string) string {
	raw, err := hex.DecodeString(h)
	if len(h) != 8 || err != nil || len(raw) != 4 {
		return h
	}
	// The kernel writes the address in host (little-endian) byte order,
	// so the octets come out reversed.
	return fmt.Sprintf("%d.%d.%d.%d", raw[3], raw[2], raw[1], raw[0])
}
// isPrivateOrLoopback reports whether ipStr parses as an IP address that is
// loopback (127.0.0.0/8, ::1), RFC 1918 private (10.0.0.0/8, 172.16.0.0/12,
// 192.168.0.0/16), or RFC 4193 unique-local (fc00::/7). Unparseable input
// returns false.
func isPrivateOrLoopback(ipStr string) bool {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}
	// net.IP.IsPrivate (Go 1.17+) covers exactly the RFC 1918 and RFC 4193
	// ranges the previous hand-rolled CIDR list checked, without re-parsing
	// four CIDR strings on every call.
	return ip.IsLoopback() || ip.IsPrivate()
}
// isPortListening reports whether any socket in /proc/net/tcp or
// /proc/net/tcp6 is in LISTEN state (0A) on the given port.
func isPortListening(port int) bool {
	want := fmt.Sprintf("%04X", port)
	for _, table := range []string{"/proc/net/tcp", "/proc/net/tcp6"} {
		data, err := osFS.ReadFile(table)
		if err != nil {
			continue
		}
		for _, row := range strings.Split(string(data), "\n") {
			cols := strings.Fields(row)
			// cols[1] = local_address (hex_ip:hex_port), cols[3] = state.
			if len(cols) < 4 || cols[3] != "0A" { // 0A = LISTEN
				continue
			}
			if _, hexPort, ok := strings.Cut(cols[1], ":"); ok && hexPort == want {
				return true
			}
		}
	}
	return false
}
// checkMySQLExposed reports how reachable MySQL (port 3306) is: pass when it
// is not listening at all or is bound to a private/loopback address, warn
// when it is exposed but the firewall appears to block 3306, and fail when
// the port looks reachable. The nft/iptables snapshots are passed in by
// auditFirewall so the rulesets are gathered only once.
func checkMySQLExposed(hasNft bool, nftRules string, hasIpt bool, iptRules string) []store.AuditResult {
	hexAddr := getListeningAddr(3306)
	if hexAddr == "" {
		return []store.AuditResult{{
			Category: "firewall", Name: "fw_mysql_exposed", Title: "MySQL Exposure",
			Status: "pass", Message: "MySQL is not listening on any port",
		}}
	}
	// Convert hex addr to IP and check if private/loopback
	var ip string
	if len(hexAddr) == 8 {
		ip = hexToIPv4(hexAddr)
	} else {
		// IPv6: 32 hex chars, little-endian 4-byte groups
		if b, err := hex.DecodeString(hexAddr); err == nil {
			ipBytes := make(net.IP, len(b))
			// Reverse each 4-byte group for /proc/net/tcp6 little-endian encoding
			for i := 0; i+4 <= len(b); i += 4 {
				ipBytes[i] = b[i+3]
				ipBytes[i+1] = b[i+2]
				ipBytes[i+2] = b[i+1]
				ipBytes[i+3] = b[i]
			}
			ip = ipBytes.String()
		}
	}
	// All zeros = wildcard bind
	allZero := true
	for _, c := range hexAddr {
		if c != '0' {
			allZero = false
			break
		}
	}
	if !allZero && ip != "" && isPrivateOrLoopback(ip) {
		return []store.AuditResult{{
			Category: "firewall", Name: "fw_mysql_exposed", Title: "MySQL Exposure",
			Status: "pass", Message: fmt.Sprintf("MySQL bound to private/loopback address %s", ip),
		}}
	}
	// Wildcard or public bind — check if firewall blocks 3306
	fwBlocks3306 := false
	if hasNft && !strings.Contains(nftRules, "3306") {
		// If nft has rules but doesn't mention 3306 and has default deny, it's blocked
		lower := strings.ToLower(nftRules)
		if strings.Contains(lower, "policy drop") || strings.Contains(lower, "policy reject") {
			fwBlocks3306 = true
		}
	}
	if !fwBlocks3306 && hasIpt && !strings.Contains(iptRules, "3306") {
		// Same heuristic for iptables: no 3306 rule plus a default-deny
		// INPUT policy in the "Chain INPUT (policy ...)" header means blocked.
		for _, line := range strings.Split(iptRules, "\n") {
			if strings.HasPrefix(line, "Chain INPUT") {
				upper := strings.ToUpper(line)
				if strings.Contains(upper, "POLICY DROP") || strings.Contains(upper, "POLICY REJECT") {
					fwBlocks3306 = true
				}
				break
			}
		}
	}
	// NOTE(review): an all-zero IPv6 bind ("::") also gets the IPv4-flavored
	// "wildcard (0.0.0.0)" label — confirm whether that wording is intended.
	bindDesc := "wildcard (0.0.0.0)"
	if !allZero && ip != "" {
		bindDesc = ip
	}
	if fwBlocks3306 {
		return []store.AuditResult{{
			Category: "firewall", Name: "fw_mysql_exposed", Title: "MySQL Exposure",
			Status: "warn",
			Message: fmt.Sprintf("MySQL bound to %s but firewall blocks port 3306", bindDesc),
			Fix: "Bind MySQL to 127.0.0.1 in /etc/my.cnf: bind-address = 127.0.0.1",
		}}
	}
	return []store.AuditResult{{
		Category: "firewall", Name: "fw_mysql_exposed", Title: "MySQL Exposure",
		Status: "fail",
		Message: fmt.Sprintf("MySQL bound to %s and port 3306 appears accessible", bindDesc),
		Fix: "Bind MySQL to 127.0.0.1 in /etc/my.cnf and/or block port 3306 in firewall.",
	}}
}
// checkIPv6Firewall passes when IPv6 is effectively inactive (no global
// addresses), or when either an nftables inet/ip6 input chain or the
// ip6tables INPUT chain has a default-deny policy; otherwise it fails,
// since IPv6 traffic would bypass an IPv4-only firewall.
func checkIPv6Firewall() []store.AuditResult {
	// Check if any non-loopback, non-link-local IPv6 addresses exist
	data, err := osFS.ReadFile("/proc/net/if_inet6")
	if err != nil {
		return []store.AuditResult{{
			Category: "firewall", Name: "fw_ipv6", Title: "IPv6 Firewall",
			Status: "pass", Message: "IPv6 not active (cannot read /proc/net/if_inet6)",
		}}
	}
	hasIPv6 := false
	// /proc/net/if_inet6 rows: addr index prefixlen scope flags ifname.
	for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
		fields := strings.Fields(line)
		if len(fields) < 6 {
			continue
		}
		addr := fields[0]
		iface := fields[5]
		// Skip loopback
		if iface == "lo" {
			continue
		}
		// Skip link-local (fe80::/10)
		if strings.HasPrefix(strings.ToLower(addr), "fe80") {
			continue
		}
		hasIPv6 = true
		break
	}
	if !hasIPv6 {
		return []store.AuditResult{{
			Category: "firewall", Name: "fw_ipv6", Title: "IPv6 Firewall",
			Status: "pass", Message: "No non-link-local IPv6 addresses found",
		}}
	}
	// Check nftables for inet/ip6 family chain with input hook and default deny
	nftChains, err := auditRunCmd("nft", "list", "chains")
	if err == nil {
		chainsStr := strings.ToLower(string(nftChains))
		// Look for chains in inet or ip6 family with filter hook input
		// nft list chains output looks like:
		// table inet filter {
		//   chain input {
		//     type filter hook input priority filter; policy drop;
		//   }
		// }
		// Track the family of the most recent "table <family> <name>" line
		// so chain lines can be attributed to it.
		var currentFamily string
		for _, line := range strings.Split(chainsStr, "\n") {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "table ") {
				parts := strings.Fields(trimmed)
				if len(parts) >= 3 {
					currentFamily = parts[1]
				}
			}
			if (currentFamily == "inet" || currentFamily == "ip6") &&
				strings.Contains(trimmed, "hook input") {
				if strings.Contains(trimmed, "policy drop") || strings.Contains(trimmed, "policy reject") {
					return []store.AuditResult{{
						Category: "firewall", Name: "fw_ipv6", Title: "IPv6 Firewall",
						Status: "pass", Message: fmt.Sprintf("IPv6 active; nftables %s family has default-deny input chain", currentFamily),
					}}
				}
			}
		}
	}
	// Fallback: check ip6tables
	ip6Out, err := auditRunCmd("ip6tables", "-L", "INPUT", "-n")
	if err == nil {
		for _, line := range strings.Split(string(ip6Out), "\n") {
			if strings.HasPrefix(line, "Chain INPUT") {
				upper := strings.ToUpper(line)
				if strings.Contains(upper, "POLICY DROP") || strings.Contains(upper, "POLICY REJECT") {
					return []store.AuditResult{{
						Category: "firewall", Name: "fw_ipv6", Title: "IPv6 Firewall",
						Status: "pass", Message: "IPv6 active; ip6tables INPUT chain has default-deny policy",
					}}
				}
				break
			}
		}
	}
	return []store.AuditResult{{
		Category: "firewall", Name: "fw_ipv6", Title: "IPv6 Firewall",
		Status: "fail", Message: "IPv6 is active but no default-deny input policy found",
		Fix: "Configure ip6tables or nftables inet family with a default DROP policy for INPUT.",
	}}
}
// --- cPanel/WHM and CloudLinux checks ---

// auditCPanel audits cPanel/WHM security settings: a table of boolean
// cpanel.config tweaks, the hourly-email cap, compiler access, anonymous
// FTP, and auto-update cadence. When serverType is "cloudlinux" it also
// appends the CloudLinux-specific checks.
func auditCPanel(serverType string) []store.AuditResult {
	var results []store.AuditResult
	cpConf := parseCpanelConfig("/var/cpanel/cpanel.config")
	// Table-driven boolean checks on cpanel.config.
	// fix is the human-readable remediation shown in the UI.
	type cpCheck struct {
		id, title, key, wantVal string
		invert bool // true = fail when value matches wantVal
		fix string
	}
	checks := []cpCheck{
		{"cp_ssl_only", "Always Redirect to SSL", "alwaysredirecttossl", "1", false,
			"In WHM > Tweak Settings > Redirection, set 'Always redirect to SSL/TLS' to On."},
		{"cp_boxtrapper", "BoxTrapper Disabled", "skipboxtrapper", "1", false,
			"In WHM > Tweak Settings > Mail, set 'Enable BoxTrapper spam trap' to Off."},
		{"cp_password_reset", "Password Reset Disabled", "resetpass", "1", true,
			"In WHM > Tweak Settings > System, set 'Reset Password for cPanel accounts' to Off."},
		{"cp_password_reset_sub", "Subaccount Password Reset Disabled", "resetpass_sub", "1", true,
			"In WHM > Tweak Settings > System, set 'Reset Password for Subaccounts' to Off."},
		{"cp_email_passwords", "Email Passwords Disabled", "emailpasswords", "1", true,
			"In WHM > Tweak Settings > Security, set 'Send passwords when creating a new account' to Off."},
		{"cp_cookie_validation", "Cookie IP Validation", "cookieipvalidation", "strict", false,
			"In WHM > Tweak Settings > Security, set 'Cookie IP validation' to strict."},
		{"cp_remote_domains", "Remote Domains Disabled", "allowremotedomains", "1", true,
			"In WHM > Tweak Settings > Domains, set 'Allow Remote Domains' to Off."},
		{"cp_core_dumps", "Core Dumps Disabled", "coredump", "1", true,
			"In WHM > Tweak Settings > Security, set 'Generate core dumps' to Off."},
		{"cp_nobodyspam", "Nobody Spam Prevention", "nobodyspam", "1", false,
			"In WHM > Tweak Settings > Mail, set 'Prevent nobody from sending mail' to On."},
	}
	for _, c := range checks {
		val := cpConf[c.key]
		// invert=false: pass when the key equals wantVal.
		// invert=true:  pass when the key differs from wantVal (including unset).
		var pass bool
		if c.invert {
			pass = val != c.wantVal
		} else {
			pass = val == c.wantVal
		}
		if pass {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: c.id, Title: c.title,
				Status: "pass", Message: fmt.Sprintf("%s = %s", c.key, val),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: c.id, Title: c.title,
				Status: "fail", Message: fmt.Sprintf("%s = %s", c.key, val),
				Fix: c.fix,
			})
		}
	}
	// cp_max_emails_hour: any non-zero value counts as a configured limit.
	maxEmail := cpConf["maxemailsperhour"]
	if maxEmail != "" && maxEmail != "0" {
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_max_emails_hour", Title: "Max Emails Per Hour",
			Status: "pass", Message: fmt.Sprintf("maxemailsperhour = %s", maxEmail),
		})
	} else {
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_max_emails_hour", Title: "Max Emails Per Hour",
			Status: "fail", Message: "maxemailsperhour is not set or is 0",
			Fix: "In WHM > Tweak Settings, set 'Max emails per hour per domain' to a reasonable limit (e.g., 200).",
		})
	}
	// cp_compilers: check /usr/bin/cc permissions
	if info, err := osFS.Stat("/usr/bin/cc"); err == nil {
		perm := info.Mode().Perm()
		// 0750 or tighter keeps world (i.e., hosted users) from running cc.
		if perm <= 0o750 {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
				Status: "pass", Message: fmt.Sprintf("/usr/bin/cc has mode %04o", perm),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
				Status: "fail", Message: fmt.Sprintf("/usr/bin/cc has mode %04o (should be <= 0750)", perm),
				Fix: "WHM > Security Center > Compiler Access, or: chmod 750 /usr/bin/cc",
			})
		}
	} else {
		// No compiler installed at all — nothing to restrict.
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cp_compilers", Title: "Compiler Access Restricted",
			Status: "pass", Message: "No compiler found at /usr/bin/cc",
		})
	}
	// cp_ftp_anonymous: parse /etc/pure-ftpd.conf
	if data, err := osFS.ReadFile("/etc/pure-ftpd.conf"); err == nil {
		noAnon := false
		for _, line := range strings.Split(string(data), "\n") {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "#") {
				continue
			}
			if strings.HasPrefix(trimmed, "NoAnonymous") {
				parts := strings.Fields(trimmed)
				if len(parts) >= 2 && strings.EqualFold(parts[1], "yes") {
					noAnon = true
				}
			}
		}
		if noAnon {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_ftp_anonymous", Title: "Anonymous FTP Disabled",
				Status: "pass", Message: "NoAnonymous is enabled in pure-ftpd",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_ftp_anonymous", Title: "Anonymous FTP Disabled",
				Status: "fail", Message: "Anonymous FTP may be enabled",
				Fix: "Set 'NoAnonymous yes' in /etc/pure-ftpd.conf and restart pure-ftpd.",
			})
		}
	}
	// cp_updates: parse /etc/cpupdate.conf
	if data, err := osFS.ReadFile("/etc/cpupdate.conf"); err == nil {
		updatesDaily := false
		for _, line := range strings.Split(string(data), "\n") {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "#") {
				continue
			}
			if strings.HasPrefix(strings.ToUpper(trimmed), "UPDATES=") {
				// Strip everything up to and including the first '=' using
				// the original casing (the HasPrefix test above upper-cased
				// a copy, so TrimPrefix needs the as-written prefix).
				val := strings.TrimPrefix(trimmed, trimmed[:strings.Index(trimmed, "=")+1])
				if strings.EqualFold(strings.TrimSpace(val), "daily") {
					updatesDaily = true
				}
			}
		}
		if updatesDaily {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_updates", Title: "cPanel Auto-Updates",
				Status: "pass", Message: "UPDATES=daily in cpupdate.conf",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cp_updates", Title: "cPanel Auto-Updates",
				Status: "warn", Message: "cPanel auto-updates not set to daily",
				Fix: "Set UPDATES=daily in /etc/cpupdate.conf or WHM > Update Preferences.",
			})
		}
	}
	// CloudLinux-specific checks
	if serverType == "cloudlinux" {
		results = append(results, auditCloudLinux()...)
	}
	return results
}
// parseCpanelConfig reads a key=value file (cpanel.config style) into a map.
// Blank lines and #-comments are skipped; a missing or unreadable file
// yields an empty (non-nil) map.
func parseCpanelConfig(path string) map[string]string {
	conf := make(map[string]string)
	data, err := osFS.ReadFile(path)
	if err != nil {
		return conf
	}
	for _, raw := range strings.Split(string(data), "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		key, val, ok := strings.Cut(entry, "=")
		// A line must contain '=' with at least one character before it.
		if !ok || key == "" {
			continue
		}
		conf[strings.TrimSpace(key)] = strings.TrimSpace(val)
	}
	return conf
}
// auditCloudLinux runs CloudLinux-specific checks: CageFS status,
// symlink-owner enforcement, and /proc virtualization. The two sysctl
// checks are silently skipped when their /proc files are absent (i.e. the
// CloudLinux kernel features are not present).
func auditCloudLinux() []store.AuditResult {
	var results []store.AuditResult
	// cl_cagefs
	out, err := auditRunCmd("cagefsctl", "--cagefs-status")
	switch {
	case err != nil:
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "warn", Message: "Cannot check CageFS status",
		})
	case strings.Contains(string(out), "Enabled"):
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "pass", Message: "CageFS is enabled",
		})
	default:
		results = append(results, store.AuditResult{
			Category: "cpanel", Name: "cl_cagefs", Title: "CageFS Enabled",
			Status: "fail", Message: "CageFS is not enabled",
			Fix: "Enable CageFS: cagefsctl --enable-all",
		})
	}
	// cl_symlink_protection: any value >= 1 counts as enabled.
	if data, err := osFS.ReadFile("/proc/sys/fs/enforce_symlinksifowner"); err == nil {
		val := strings.TrimSpace(string(data))
		n, _ := strconv.Atoi(val)
		if n >= 1 {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cl_symlink_protection", Title: "CloudLinux Symlink Protection",
				Status: "pass", Message: fmt.Sprintf("enforce_symlinksifowner = %s", val),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cl_symlink_protection", Title: "CloudLinux Symlink Protection",
				Status: "fail", Message: fmt.Sprintf("enforce_symlinksifowner = %s (expected >= 1)", val),
				Fix: "sysctl -w fs.enforce_symlinksifowner=1",
			})
		}
	}
	// cl_proc_virtualization: 0 means users cannot see other users' processes.
	if data, err := osFS.ReadFile("/proc/sys/fs/proc_can_see_other_uid"); err == nil {
		val := strings.TrimSpace(string(data))
		if val == "0" {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cl_proc_virtualization", Title: "CloudLinux /proc Virtualization",
				Status: "pass", Message: "proc_can_see_other_uid = 0",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "cpanel", Name: "cl_proc_virtualization", Title: "CloudLinux /proc Virtualization",
				Status: "fail", Message: fmt.Sprintf("proc_can_see_other_uid = %s (expected 0)", val),
				Fix: "sysctl -w fs.proc_can_see_other_uid=0",
			})
		}
	}
	return results
}
// --- PHP checks ---

// auditPHP discovers every PHP installation on the host (cPanel EA4 builds,
// CloudLinux alt-php builds, or the system PHP as a last resort) and audits
// each one's php.ini — with FPM pool overrides merged on top where present —
// for EOL version, disable_functions, expose_php, allow_url_fopen, and
// enable_dl. Result names are suffixed with the install's short ID
// (e.g. "81") so multiple versions don't collide.
func auditPHP(serverType string) []store.AuditResult {
	var results []store.AuditResult
	type phpInstall struct {
		version string // e.g. "8.1"
		shortID string // e.g. "81"
		iniPath string
		fpmDir string // for FPM pool override merging
	}
	var installs []phpInstall
	// cPanel EA4 PHP installs
	eaInis, _ := osFS.Glob("/opt/cpanel/ea-php*/root/etc/php.ini")
	for _, ini := range eaInis {
		// Extract version from path: /opt/cpanel/ea-php81/root/etc/php.ini -> "81"
		dir := filepath.Dir(filepath.Dir(filepath.Dir(ini))) // /opt/cpanel/ea-php81/root -> /opt/cpanel/ea-php81
		base := filepath.Base(dir) // ea-php81
		shortID := strings.TrimPrefix(base, "ea-php")
		if len(shortID) >= 2 {
			// "81" -> "8.1": the last digit is the minor version.
			ver := shortID[:len(shortID)-1] + "." + shortID[len(shortID)-1:]
			fpmDir := filepath.Join(dir, "root", "etc", "php-fpm.d")
			installs = append(installs, phpInstall{
				version: ver,
				shortID: shortID,
				iniPath: ini,
				fpmDir: fpmDir,
			})
		}
	}
	// CloudLinux alt-php installs (skip Imunify360's internal PHP builds)
	if serverType == "cloudlinux" {
		altInis, _ := osFS.Glob("/opt/alt/php*/etc/php.ini")
		for _, ini := range altInis {
			dir := filepath.Dir(filepath.Dir(ini)) // /opt/alt/php81
			base := filepath.Base(dir) // php81
			if strings.Contains(base, "-") {
				continue // skip php74-imunify, php81-hardened, etc.
			}
			shortID := strings.TrimPrefix(base, "php")
			if len(shortID) >= 2 {
				ver := shortID[:len(shortID)-1] + "." + shortID[len(shortID)-1:]
				installs = append(installs, phpInstall{
					version: ver,
					shortID: shortID,
					iniPath: ini,
				})
			}
		}
	}
	// Bare server fallback: ask the system PHP binary where its ini lives.
	if len(installs) == 0 {
		out, err := auditRunCmd("php", "-i")
		if err == nil {
			for _, line := range strings.Split(string(out), "\n") {
				if strings.HasPrefix(line, "Loaded Configuration File") {
					parts := strings.SplitN(line, "=>", 2)
					if len(parts) == 2 {
						iniPath := strings.TrimSpace(parts[1])
						if iniPath != "(none)" && iniPath != "" {
							installs = append(installs, phpInstall{
								version: "unknown",
								shortID: "system",
								iniPath: iniPath,
							})
						}
					}
				}
			}
		}
		// Try to get version for bare
		if len(installs) > 0 && installs[0].version == "unknown" {
			vout, verr := auditRunCmd("php", "-v")
			if verr == nil {
				first := strings.SplitN(string(vout), "\n", 2)[0]
				// "PHP 8.2.15 (cli) ..."
				fields := strings.Fields(first)
				if len(fields) >= 2 {
					verParts := strings.SplitN(fields[1], ".", 3)
					if len(verParts) >= 2 {
						installs[0].version = verParts[0] + "." + verParts[1]
						installs[0].shortID = verParts[0] + verParts[1]
					}
				}
			}
		}
	}
	for _, inst := range installs {
		data, err := osFS.ReadFile(inst.iniPath)
		if err != nil {
			continue
		}
		ini := parsePHPIni(string(data))
		// Merge FPM pool overrides if available: php_admin_value/php_value
		// directives in pool configs take precedence over php.ini values.
		if inst.fpmDir != "" {
			poolConfs, _ := osFS.Glob(filepath.Join(inst.fpmDir, "*.conf"))
			for _, pc := range poolConfs {
				pdata, perr := osFS.ReadFile(pc)
				if perr != nil {
					continue
				}
				for _, line := range strings.Split(string(pdata), "\n") {
					line = strings.TrimSpace(line)
					if strings.HasPrefix(line, ";") {
						continue
					}
					// php_admin_value[key] = val or php_value[key] = val
					for _, prefix := range []string{"php_admin_value[", "php_value["} {
						if strings.HasPrefix(line, prefix) {
							rest := strings.TrimPrefix(line, prefix)
							if idx := strings.Index(rest, "]"); idx > 0 {
								key := rest[:idx]
								valPart := rest[idx+1:]
								if eqIdx := strings.Index(valPart, "="); eqIdx >= 0 {
									val := strings.TrimSpace(valPart[eqIdx+1:])
									ini[key] = val
								}
							}
						}
					}
				}
			}
		}
		suffix := inst.shortID
		// php_version check: anything before 8.1 is treated as EOL. A major
		// of 0 means the version could not be determined, so skip the check.
		major, minor := parsePHPVersion(inst.version)
		if major > 0 {
			if major < 8 || (major == 8 && minor < 1) {
				results = append(results, store.AuditResult{
					Category: "php", Name: "php_version_" + suffix, Title: fmt.Sprintf("PHP %s Version", inst.version),
					Status: "fail", Message: fmt.Sprintf("PHP %s is end-of-life", inst.version),
					Fix: fmt.Sprintf("Upgrade PHP %s to 8.1 or later. EOL versions receive no security patches.", inst.version),
				})
			} else {
				results = append(results, store.AuditResult{
					Category: "php", Name: "php_version_" + suffix, Title: fmt.Sprintf("PHP %s Version", inst.version),
					Status: "pass", Message: fmt.Sprintf("PHP %s is supported", inst.version),
				})
			}
		}
		// php_disable_functions
		df := strings.TrimSpace(ini["disable_functions"])
		if df == "" || strings.EqualFold(df, "none") {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_disable_functions_" + suffix, Title: fmt.Sprintf("PHP %s disable_functions", inst.version),
				Status: "fail", Message: "disable_functions is empty",
				Fix: fmt.Sprintf("Set disable_functions in %s to include dangerous functions like exec, system, passthru, shell_exec, popen, proc_open.", inst.iniPath),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_disable_functions_" + suffix, Title: fmt.Sprintf("PHP %s disable_functions", inst.version),
				Status: "pass", Message: "disable_functions is configured",
			})
		}
		// php_expose
		expose := strings.TrimSpace(strings.ToLower(ini["expose_php"]))
		if expose == "off" || expose == "0" {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_expose_" + suffix, Title: fmt.Sprintf("PHP %s expose_php", inst.version),
				Status: "pass", Message: "expose_php is off",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_expose_" + suffix, Title: fmt.Sprintf("PHP %s expose_php", inst.version),
				Status: "warn", Message: "expose_php is on — PHP version disclosed in headers",
				Fix: fmt.Sprintf("Set expose_php = Off in %s", inst.iniPath),
			})
		}
		// php_allow_url_fopen
		auf := strings.TrimSpace(strings.ToLower(ini["allow_url_fopen"]))
		if auf == "off" || auf == "0" {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_allow_url_fopen_" + suffix, Title: fmt.Sprintf("PHP %s allow_url_fopen", inst.version),
				Status: "pass", Message: "allow_url_fopen is off",
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_allow_url_fopen_" + suffix, Title: fmt.Sprintf("PHP %s allow_url_fopen", inst.version),
				Status: "warn", Message: "allow_url_fopen is on — remote file inclusion risk",
				Fix: fmt.Sprintf("Set allow_url_fopen = Off in %s", inst.iniPath),
			})
		}
		// php_enable_dl: fail only when explicitly on; off/unset passes.
		edl := strings.TrimSpace(strings.ToLower(ini["enable_dl"]))
		if edl == "on" || edl == "1" {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_enable_dl_" + suffix, Title: fmt.Sprintf("PHP %s enable_dl", inst.version),
				Status: "fail", Message: "enable_dl is on — allows loading arbitrary shared objects",
				Fix: fmt.Sprintf("Set enable_dl = Off in %s", inst.iniPath),
			})
		} else {
			results = append(results, store.AuditResult{
				Category: "php", Name: "php_enable_dl_" + suffix, Title: fmt.Sprintf("PHP %s enable_dl", inst.version),
				Status: "pass", Message: "enable_dl is off",
			})
		}
	}
	return results
}
// parsePHPIni parses php.ini-style text into a key/value map.
// Blank lines, ";" comments, and "[section]" headers are skipped; keys and
// values are whitespace-trimmed. A later duplicate key overwrites an
// earlier one. Lines with no "=" (or nothing before it) are ignored.
func parsePHPIni(content string) map[string]string {
	settings := make(map[string]string)
	for _, raw := range strings.Split(content, "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || entry[0] == ';' || entry[0] == '[' {
			continue
		}
		parts := strings.SplitN(entry, "=", 2)
		if len(parts) != 2 {
			continue
		}
		key := strings.TrimSpace(parts[0])
		if key == "" {
			continue
		}
		settings[key] = strings.TrimSpace(parts[1])
	}
	return settings
}
// parsePHPVersion extracts the major and minor numbers from a dotted
// version string such as "8.1.27". It returns (0, 0) when the string has
// fewer than two dot-separated components; non-numeric components parse
// as 0 (Atoi errors are deliberately ignored).
func parsePHPVersion(ver string) (int, int) {
	pieces := strings.SplitN(ver, ".", 3)
	if len(pieces) >= 2 {
		maj, _ := strconv.Atoi(pieces[0])
		mnr, _ := strconv.Atoi(pieces[1])
		return maj, mnr
	}
	return 0, 0
}
// --- Web server checks ---

// auditWebServer audits the host web server for information-disclosure
// directives in the main config, globally-enabled directory listing, and
// acceptance of legacy TLS protocols on localhost:443.
// serverType is accepted but not consulted by the current checks.
// Returns one store.AuditResult per check in the "webserver" category.
func auditWebServer(serverType string) []store.AuditResult {
var results []store.AuditResult
// Find main config file
var configPath string
configPaths := []string{
"/etc/apache2/conf/httpd.conf", // cPanel EA4
"/usr/local/lsws/conf/httpd_config.xml", // LiteSpeed
"/etc/httpd/conf/httpd.conf", // RHEL/CentOS bare
"/etc/apache2/apache2.conf", // Debian/Ubuntu bare
}
// First existing candidate wins; the slice order encodes platform priority.
for _, p := range configPaths {
if _, err := osFS.Stat(p); err == nil {
configPath = p
break
}
}
// The directive checks below run only when a config file was found AND
// readable; otherwise they are silently skipped (no result is emitted).
if configPath != "" {
configData, err := osFS.ReadFile(configPath)
if err == nil {
configContent := string(configData)
// Table-driven directive checks
type directiveCheck struct {
id, title, directive string
goodValues []string // case-insensitive acceptable values; [0] is used in the Fix text
}
dirChecks := []directiveCheck{
{"web_server_tokens", "ServerTokens", "ServerTokens", []string{"prod", "productonly"}},
{"web_server_signature", "ServerSignature", "ServerSignature", []string{"off"}},
{"web_trace_enable", "TraceEnable", "TraceEnable", []string{"off"}},
{"web_file_etag", "FileETag", "FileETag", []string{"none"}},
}
for _, dc := range dirChecks {
found := false
var foundVal string
// Scan every non-comment line; when a directive appears more than
// once, the LAST occurrence wins (no break on match).
// NOTE(review): the prefix match requires a literal space after the
// directive name, so a tab-separated "Directive<TAB>value" line is
// not matched — confirm against real-world configs.
for _, line := range strings.Split(configContent, "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "#") {
continue
}
if strings.HasPrefix(strings.ToLower(trimmed), strings.ToLower(dc.directive)+" ") {
parts := strings.Fields(trimmed)
if len(parts) >= 2 {
foundVal = parts[1]
found = true
}
}
}
if !found {
results = append(results, store.AuditResult{
Category: "webserver", Name: dc.id, Title: dc.title,
Status: "warn", Message: fmt.Sprintf("%s not set in %s", dc.directive, configPath),
Fix: fmt.Sprintf("Add '%s %s' to %s", dc.directive, dc.goodValues[0], configPath),
})
continue
}
// Compare the found value case-insensitively against the allow-list.
isGood := false
for _, gv := range dc.goodValues {
if strings.EqualFold(foundVal, gv) {
isGood = true
break
}
}
if isGood {
results = append(results, store.AuditResult{
Category: "webserver", Name: dc.id, Title: dc.title,
Status: "pass", Message: fmt.Sprintf("%s = %s", dc.directive, foundVal),
})
} else {
results = append(results, store.AuditResult{
Category: "webserver", Name: dc.id, Title: dc.title,
Status: "fail", Message: fmt.Sprintf("%s = %s", dc.directive, foundVal),
Fix: fmt.Sprintf("Set '%s %s' in %s", dc.directive, dc.goodValues[0], configPath),
})
}
}
// Directory listing check
// Flags any uncommented line mentioning both "options" and "indexes"
// unless it uses the negated "-indexes" form.
hasIndexes := false
for _, line := range strings.Split(configContent, "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "#") {
continue
}
lower := strings.ToLower(trimmed)
if strings.Contains(lower, "options") && strings.Contains(lower, "indexes") && !strings.Contains(lower, "-indexes") {
hasIndexes = true
break
}
}
if hasIndexes {
results = append(results, store.AuditResult{
Category: "webserver", Name: "web_directory_listing", Title: "Directory Listing",
Status: "warn", Message: "Global config enables directory listing (Options Indexes)",
Fix: "Replace 'Indexes' with '-Indexes' in Options directives.",
})
} else {
results = append(results, store.AuditResult{
Category: "webserver", Name: "web_directory_listing", Title: "Directory Listing",
Status: "pass", Message: "No global directory listing enabled",
})
}
}
}
// TLS version checks: probe with openssl
// Force a handshake with each legacy protocol against localhost:443;
// a successful handshake means the server still accepts it (fail).
for _, tc := range []struct {
id, title, flag, version string
}{
{"web_tls_version", "Legacy TLS Disabled", "-tls1", "TLSv1.0"},
{"web_tls11_version", "TLS 1.1 Disabled", "-tls1_1", "TLSv1.1"},
} {
out, err := auditRunCmd("openssl", "s_client", "-connect", "localhost:443", tc.flag)
output := string(out)
// If the handshake succeeds, output contains "SSL-Session:" without ":error:" on the same handshake
succeeded := err == nil && strings.Contains(output, "SSL-Session:") && !strings.Contains(output, ":error:")
if succeeded {
results = append(results, store.AuditResult{
Category: "webserver", Name: tc.id, Title: tc.title,
Status: "fail", Message: fmt.Sprintf("Server accepts %s connections", tc.version),
Fix: fmt.Sprintf("Disable %s in your web server's SSL configuration. Minimum should be TLSv1.2.", tc.version),
})
} else {
results = append(results, store.AuditResult{
Category: "webserver", Name: tc.id, Title: tc.title,
Status: "pass", Message: fmt.Sprintf("%s is rejected", tc.version),
})
}
}
return results
}
// --- Mail checks ---

// auditMail audits the mail stack: root mail forwarding, Exim logging and
// TLS options (via "exim -bP"), the /etc/exim.conf.localopts override
// file, and Dovecot's minimum TLS protocol. Returns results in the
// "mail" category.
func auditMail() []store.AuditResult {
var results []store.AuditResult
// mail_root_forwarder
// A missing or empty /root/.forward means root's mail piles up unread.
if info, err := osFS.Stat("/root/.forward"); err != nil {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
Status: "warn", Message: "/root/.forward does not exist — root mail may go unread",
Fix: "Create /root/.forward with an email address to receive root's mail.",
})
} else if info.Size() == 0 {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
Status: "warn", Message: "/root/.forward is empty — root mail may go unread",
Fix: "Add an email address to /root/.forward to receive root's mail.",
})
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_root_forwarder", Title: "Root Mail Forwarder",
Status: "pass", Message: "Root mail forwarding is configured",
})
}
// Get exim config for multiple checks
// "exim -bP" dumps effective option values; the text is reused by the
// logging and TLS checks below. eximConfig stays "" on any error.
eximOut, eximErr := auditRunCmd("exim", "-bP")
eximConfig := ""
if eximErr == nil {
eximConfig = string(eximOut)
}
// mail_exim_logging
// Either "+arguments" or "+all" in the dumped config counts as a pass.
if eximConfig != "" {
lower := strings.ToLower(eximConfig)
if strings.Contains(lower, "+arguments") || strings.Contains(lower, "+all") {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_exim_logging", Title: "Exim Argument Logging",
Status: "pass", Message: "Exim logs include +arguments",
})
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_exim_logging", Title: "Exim Argument Logging",
Status: "warn", Message: "Exim log_selector does not include +arguments",
Fix: "Add '+arguments' to log_selector in exim configuration for better forensics.",
})
}
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_exim_logging", Title: "Exim Argument Logging",
Status: "warn", Message: "Cannot query exim configuration",
})
}
// mail_exim_tls: check for SSLv2 in tls_require_ciphers
// +no_sslv2 in openssl_options means SSLv2 is DISABLED (good).
// Only flag if SSLv2 is referenced WITHOUT +no_ prefix.
// NOTE(review): the "+no_sslv2" term is redundant — any text containing
// "+no_sslv2" also contains "no_sslv2", so the second Contains alone
// decides isDisabled. When eximConfig is empty this check emits nothing.
if eximConfig != "" {
lower := strings.ToLower(eximConfig)
hasSslv2 := strings.Contains(lower, "sslv2")
isDisabled := strings.Contains(lower, "+no_sslv2") || strings.Contains(lower, "no_sslv2")
if hasSslv2 && !isDisabled {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_exim_tls", Title: "Exim TLS Ciphers",
Status: "fail", Message: "Exim allows SSLv2 connections",
Fix: "Add '+no_sslv2' to openssl_options in exim configuration.",
})
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_exim_tls", Title: "Exim TLS Ciphers",
Status: "pass", Message: "SSLv2 is disabled in exim TLS configuration",
})
}
}
// mail_secure_auth: check /etc/exim.conf.localopts
// Absence of the override file is treated as secure-by-default.
if data, err := osFS.ReadFile("/etc/exim.conf.localopts"); err == nil {
content := string(data)
if strings.Contains(content, "require_secure_auth=0") {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_secure_auth", Title: "Exim Secure Authentication",
Status: "fail", Message: "require_secure_auth is disabled in /etc/exim.conf.localopts",
Fix: "Remove or set require_secure_auth=1 in /etc/exim.conf.localopts.",
})
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_secure_auth", Title: "Exim Secure Authentication",
Status: "pass", Message: "Secure authentication is not disabled",
})
}
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_secure_auth", Title: "Exim Secure Authentication",
Status: "pass", Message: "No local exim overrides file found (default is secure)",
})
}
// mail_dovecot_tls: check ssl_min_protocol
// Use 'doveconf -a' for the effective config — cPanel manages Dovecot
// settings outside the standard config files, so file parsing misses
// them. Routed through cmdExec so tests can mock the doveconf output.
dovecotTLS := false
if out, err := cmdExec.Run("doveconf", "-a"); err == nil {
for _, line := range strings.Split(string(out), "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "ssl_min_protocol") {
// Strip the key and any "= " separator, then accept TLSv1.2/1.3.
val := strings.TrimSpace(strings.TrimPrefix(trimmed, "ssl_min_protocol"))
val = strings.TrimLeft(val, "= ")
if strings.Contains(val, "TLSv1.2") || strings.Contains(val, "TLSv1.3") {
dovecotTLS = true
}
}
}
} else {
// Fallback: try config files
for _, path := range []string{"/etc/dovecot/conf.d/10-ssl.conf", "/etc/dovecot/dovecot.conf"} {
data, readErr := osFS.ReadFile(path)
if readErr != nil {
continue
}
for _, line := range strings.Split(string(data), "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "#") {
continue
}
if strings.HasPrefix(trimmed, "ssl_min_protocol") {
val := strings.TrimSpace(strings.TrimPrefix(trimmed, "ssl_min_protocol"))
val = strings.TrimLeft(val, "= ")
if strings.Contains(val, "TLSv1.2") || strings.Contains(val, "TLSv1.3") {
dovecotTLS = true
}
}
}
}
}
if dovecotTLS {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
Status: "pass", Message: "Dovecot ssl_min_protocol is TLSv1.2 or higher",
})
} else {
// Distinguish "Dovecot not installed" (warn) from "installed but weak" (fail).
if _, err := osFS.Stat("/etc/dovecot/dovecot.conf"); err != nil {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
Status: "warn", Message: "Dovecot configuration not found",
})
} else {
results = append(results, store.AuditResult{
Category: "mail", Name: "mail_dovecot_tls", Title: "Dovecot TLS Minimum",
Status: "fail", Message: "Dovecot ssl_min_protocol not set to TLSv1.2 or higher",
Fix: "Set 'ssl_min_protocol = TLSv1.2' in /etc/dovecot/conf.d/10-ssl.conf.",
})
}
}
return results
}
package checks
import (
"context"
"fmt"
"os"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// CheckHealth verifies that CSM's dependencies are working.
// It reports missing required/optional external commands, auditd rules
// that are not loaded, and a non-writable state directory.
func CheckHealth(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	info := platform.Detect()

	// Hard requirements vary by platform; plain Linux hosts don't need
	// the Exim/cPanel tooling.
	for _, cmd := range platformRequiredCommands(info) {
		if _, err := cmdExec.LookPath(cmd); err != nil {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  fmt.Sprintf("Required command not found: %s", cmd),
				Details:  "Some checks will be skipped",
			})
		}
	}

	// Soft requirements: absence only disables individual checks.
	// cPanel-specific tools are only expected on cPanel hosts.
	optional := map[string]string{
		"wp": "WordPress core integrity check will be skipped",
	}
	if info.IsCPanel() {
		optional["whmapi1"] = "WHM API token check will be skipped"
	}
	for cmd, impact := range optional {
		if _, err := cmdExec.LookPath(cmd); err != nil {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  fmt.Sprintf("Optional command not found: %s", cmd),
				Details:  impact,
			})
		}
	}

	// Confirm auditd has the CSM rule set loaded (marker rule name).
	if out, _ := runCmd("auditctl", "-l"); out != nil {
		if !strings.Contains(string(out), "csm_shadow_change") {
			findings = append(findings, alert.Finding{
				Severity: alert.Warning,
				Check:    "csm_health",
				Message:  "auditd CSM rules not loaded",
				Details:  "Run 'csm install' to deploy auditd rules, then 'service auditd restart'",
			})
		}
	}

	// Probe the state directory for writability, then clean up.
	const stateDir = "/opt/csm/state"
	probe := stateDir + "/.health_check"
	if err := os.WriteFile(probe, []byte("ok"), 0600); err != nil {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "csm_health",
			Message:  fmt.Sprintf("State directory not writable: %s", stateDir),
			Details:  err.Error(),
		})
	} else {
		os.Remove(probe)
	}
	return findings
}
// platformRequiredCommands returns the external commands CSM needs on the
// detected platform. On plain Linux hosts Exim is not required.
func platformRequiredCommands(info platform.Info) []string {
	base := []string{"find", "auditctl"}
	if !info.IsCPanel() {
		return base
	}
	return append(base, "exim")
}
package checks
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"os"
"os/exec"
"time"
)
// cmdTimeout bounds every external command started by the *Real helpers
// below; expiry is logged to stderr and reported as (nil, nil).
const cmdTimeout = 2 * time.Minute
// hashFileContent returns the hex-encoded SHA-256 digest of the file at
// path, or the read error if the file cannot be loaded.
func hashFileContent(path string) (string, error) {
	content, err := osFS.ReadFile(path)
	if err != nil {
		return "", err
	}
	sum := sha256.Sum256(content)
	return fmt.Sprintf("%x", sum[:]), nil
}
// hashBytes returns the hex-encoded SHA-256 digest of data.
func hashBytes(data []byte) string {
	digest := sha256.Sum256(data)
	return fmt.Sprintf("%x", digest[:])
}
// runCmd delegates to the package-level cmdExec provider.
// Check functions call runCmd; tests swap cmdExec via SetCmdRunner.
func runCmd(name string, args ...string) ([]byte, error) {
return cmdExec.Run(name, args...)
}
// runCmdAllowNonZero runs a command via the provider; the real
// implementation treats a non-zero exit status as success and still
// returns the captured stdout (see runCmdAllowNonZeroReal).
func runCmdAllowNonZero(name string, args ...string) ([]byte, error) {
return cmdExec.RunAllowNonZero(name, args...)
}
// runCmdCombinedContext runs a command via the provider under the
// caller's context; the real implementation captures stdout and stderr
// combined (see runCmdCombinedContextReal).
func runCmdCombinedContext(parent context.Context, name string, args ...string) ([]byte, error) {
return cmdExec.RunContext(parent, name, args...)
}
// runCmdWithEnv runs a command via the provider with extra "KEY=value"
// entries appended to the inherited environment.
func runCmdWithEnv(name string, args []string, extraEnv ...string) ([]byte, error) {
return cmdExec.RunWithEnv(name, args, extraEnv...)
}
// ---------------------------------------------------------------------------
// Real implementations — used by realCmd in provider.go
//
// Every caller of these helpers passes a constant system command name
// (nft, rpm, wp-cli, doveadm, systemctl, etc.) with arguments built from
// either config, filesystem state, or CSM-generated data. Nothing here
// reaches out to HTTP request bodies or webui form inputs. The gosec
// G204 suppressions below are in that trust model.
// ---------------------------------------------------------------------------
// runCmdReal executes name with args under cmdTimeout and returns stdout.
// A timed-out command is logged to stderr and reported as (nil, nil), so
// callers treat it like "no output" rather than a hard failure.
func runCmdReal(name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout)
	defer cancel()
	// #nosec G204 -- see package-level trust note above.
	stdout, runErr := exec.CommandContext(ctx, name, args...).Output()
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
		return nil, nil
	}
	return stdout, runErr
}
// runCmdAllowNonZeroReal executes name with args under cmdTimeout and
// suppresses *exec.ExitError: a command that ran but exited non-zero is
// reported as success along with whatever stdout it produced. A timeout
// is logged to stderr and reported as (nil, nil); any other error (e.g.
// command not found) is returned as-is.
func runCmdAllowNonZeroReal(name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout)
	defer cancel()
	// #nosec G204 -- see package-level trust note above.
	stdout, runErr := exec.CommandContext(ctx, name, args...).Output()
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
		return nil, nil
	}
	var exitErr *exec.ExitError
	if errors.As(runErr, &exitErr) {
		// The process ran; its exit status is deliberately ignored.
		return stdout, nil
	}
	return stdout, runErr
}
// runCmdCombinedContextReal executes name with args, capturing stdout and
// stderr together, bounded by both the parent context and cmdTimeout.
// Our own deadline expiring is logged and reported as (nil, nil); parent
// cancellation is surfaced as the parent's error.
func runCmdCombinedContextReal(parent context.Context, name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(parent, cmdTimeout)
	defer cancel()
	// #nosec G204 -- see package-level trust note above.
	combined, runErr := exec.CommandContext(ctx, name, args...).CombinedOutput()
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		fmt.Fprintf(os.Stderr, "Command timed out: %s %v\n", name, args)
		return nil, nil
	}
	if parentErr := parent.Err(); parentErr != nil {
		return nil, parentErr
	}
	return combined, runErr
}
// runCmdWithEnvReal executes name with args under cmdTimeout, appending
// extraEnv ("KEY=value" entries) to the inherited environment. A timeout
// is logged to stderr and reported as (nil, nil).
func runCmdWithEnvReal(name string, args []string, extraEnv ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), cmdTimeout)
	defer cancel()
	// #nosec G204 -- see package-level trust note above.
	proc := exec.CommandContext(ctx, name, args...)
	proc.Env = append(os.Environ(), extraEnv...)
	stdout, runErr := proc.Output()
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		fmt.Fprintf(os.Stderr, "Command timed out: %s\n", name)
		return nil, nil
	}
	return stdout, runErr
}
package checks
import (
"context"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckLocalThreatScore generates findings for IPs that have accumulated
// a high local threat score (>= 70) but have not yet been blocked.
// Runs every 10 minutes as part of TierCritical. Returns nil when the
// global attack database is unavailable.
func CheckLocalThreatScore(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	adb := attackdb.Global()
	if adb == nil {
		return nil
	}
	blocked := loadAllBlockedIPs(cfg.StatePath)
	var findings []alert.Finding
	// Only the top 50 attackers are considered per run.
	for _, rec := range adb.TopAttackers(50) {
		if blocked[rec.IP] || rec.ThreatScore < 70 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity:  alert.Critical,
			Check:     "local_threat_score",
			Message:   fmt.Sprintf("High local threat score: %s (score %d/100, %d attacks)", rec.IP, rec.ThreatScore, rec.EventCount),
			Details:   fmt.Sprintf("Attack types: %v\nAccounts targeted: %d\nFirst seen: %s\nLast seen: %s", rec.AttackCounts, len(rec.Accounts), rec.FirstSeen.Format("2006-01-02 15:04"), rec.LastSeen.Format("2006-01-02 15:04")),
			Timestamp: time.Now(),
		})
	}
	return findings
}
package checks
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckMailQueue reads the Exim queue length via "exim -bpc" and compares
// it against the configured warning/critical thresholds. Returns nil when
// the queue cannot be queried or the output does not parse as an integer.
func CheckMailQueue(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	raw, err := runCmd("exim", "-bpc")
	if err != nil || raw == nil {
		return nil
	}
	queued, convErr := strconv.Atoi(strings.TrimSpace(string(raw)))
	if convErr != nil {
		return nil
	}
	switch {
	case queued >= cfg.Thresholds.MailQueueCrit:
		return []alert.Finding{{
			Severity: alert.Critical,
			Check:    "mail_queue",
			Message:  fmt.Sprintf("Exim mail queue critical: %d messages", queued),
			Details:  "Possible spam outbreak from compromised account",
		}}
	case queued >= cfg.Thresholds.MailQueueWarn:
		return []alert.Finding{{
			Severity: alert.Warning,
			Check:    "mail_queue",
			Message:  fmt.Sprintf("Exim mail queue elevated: %d messages", queued),
		}}
	}
	return nil
}
package checks
import (
"context"
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
const perAccountMailThreshold = 100 // emails per recent log window
// CheckMailPerAccount reads the tail of exim_mainlog and counts outbound
// emails per sender domain. Alerts High when a single domain exceeds
// perAccountMailThreshold within the sampled window (last 500 lines).
func CheckMailPerAccount(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	// Exim logs outgoing messages as "... <= user@domain.com ...".
	perDomain := make(map[string]int)
	for _, entry := range tailFile("/var/log/exim_mainlog", 500) {
		pos := strings.Index(entry, " <= ")
		if pos < 0 {
			continue
		}
		senderFields := strings.Fields(entry[pos+4:])
		if len(senderFields) == 0 {
			continue
		}
		sender := senderFields[0]
		at := strings.LastIndex(sender, "@")
		if at < 0 {
			continue
		}
		domain := sender[at+1:]
		// Skip system-generated and bounce senders.
		if domain == "" || sender == "<>" || strings.HasPrefix(sender, "cPanel") {
			continue
		}
		perDomain[domain]++
	}
	var findings []alert.Finding
	for domain, n := range perDomain {
		if n < perAccountMailThreshold {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "mail_per_account",
			Message:  fmt.Sprintf("High email volume from %s: %d messages in recent log", domain, n),
			Details:  "Possible spam outbreak or compromised email account",
		})
	}
	return findings
}
package checks
import (
"context"
"fmt"
"net"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckOutboundConnections scans established TCP connections from
// /proc/net/tcp and flags: connections to configured C2 IPs, established
// connections where our local port is a known backdoor port, and outbound
// connections to backdoor ports on non-infrastructure hosts.
// Only /proc/net/tcp is read, so IPv6 connections are not inspected.
func CheckOutboundConnections(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
// Parse /proc/net/tcp for established connections
// Format: sl local_address rem_address st ...
// local_address = IP:port we are listening/connecting from
// rem_address = IP:port of the remote end
data, err := osFS.ReadFile("/proc/net/tcp")
if err != nil {
return nil
}
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Fields(line)
if len(fields) < 4 {
continue
}
// Skip the header row ("sl local_address ...").
if fields[0] == "sl" {
continue
}
// State 01 = ESTABLISHED
if fields[3] != "01" {
continue
}
localAddr := fields[1]
remoteAddr := fields[2]
_, localPort := parseHexAddr(localAddr)
remoteIP, remotePort := parseHexAddr(remoteAddr)
// Ignore unparseable remotes, loopback, and the unspecified address.
if remoteIP == "" || remoteIP == "127.0.0.1" || remoteIP == "0.0.0.0" {
continue
}
// Check remote IP against C2 blocklist
for _, blocked := range cfg.C2Blocklist {
if remoteIP == blocked {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "c2_connection",
Message: fmt.Sprintf("Connection to known C2 IP: %s:%d", remoteIP, remotePort),
Details: fmt.Sprintf("Local port: %d", localPort),
})
}
}
// Check if OUR LOCAL port is a backdoor port (we're listening on it)
// This catches backdoor listeners, not clients connecting from high ports
for _, bp := range cfg.BackdoorPorts {
if localPort == bp {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "backdoor_port",
Message: fmt.Sprintf("Listening on known backdoor port %d, connected from %s:%d", localPort, remoteIP, remotePort),
})
}
}
// Also check if we're connecting OUT to a backdoor port on a remote host
// (e.g. reverse shell calling back to attacker's listener)
// Skip if our local port is a known service (the remote port is just
// the client's ephemeral port, not a backdoor listener)
knownServicePorts := map[int]bool{
21: true, 25: true, 26: true, 53: true, 80: true, 110: true,
143: true, 443: true, 465: true, 587: true, 993: true, 995: true,
2082: true, 2083: true, 2086: true, 2087: true, 2095: true, 2096: true,
3306: true, 4190: true,
}
if knownServicePorts[localPort] {
continue
}
for _, bp := range cfg.BackdoorPorts {
if remotePort == bp {
// Trusted infrastructure and Cloudflare destinations are exempt.
if isInfraIP(remoteIP, cfg.InfraIPs) {
continue
}
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "backdoor_port_outbound",
Message: fmt.Sprintf("Outbound connection to backdoor port: %s:%d", remoteIP, remotePort),
Details: fmt.Sprintf("Local port: %d", localPort),
})
}
}
}
return findings
}
// isInfraIP reports whether ip belongs to trusted infrastructure: either
// an entry in infraNets (CIDR ranges or exact IPs) or a cached Cloudflare
// edge range. Cloudflare IPs are always treated as infrastructure because
// blocking a CF edge IP blocks thousands of legitimate users; detection
// and alerting still fire, only the nftables action is skipped.
func isInfraIP(ip string, infraNets []string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	for _, entry := range infraNets {
		// CIDR entries (e.g. "10.0.0.0/8") take precedence.
		if _, cidr, err := net.ParseCIDR(entry); err == nil {
			if cidr.Contains(addr) {
				return true
			}
			continue
		}
		// Otherwise treat the entry as a plain IP (e.g. "1.2.3.4").
		if entry == ip && net.ParseIP(entry) != nil {
			return true
		}
	}
	return isCloudflareIP(addr)
}
var (
cfNets []*net.IPNet
cfNetsMu sync.RWMutex
)
// SetCloudflareNets updates the cached Cloudflare IP ranges.
// Called by the daemon after fetching CF IPs.
func SetCloudflareNets(cidrs []string) {
var nets []*net.IPNet
for _, cidr := range cidrs {
_, network, err := net.ParseCIDR(cidr)
if err == nil {
nets = append(nets, network)
}
}
cfNetsMu.Lock()
cfNets = nets
cfNetsMu.Unlock()
}
func isCloudflareIP(ip net.IP) bool {
cfNetsMu.RLock()
defer cfNetsMu.RUnlock()
for _, network := range cfNets {
if network.Contains(ip) {
return true
}
}
return false
}
// parseHexAddr decodes a /proc/net/tcp address field of the form
// "IIIIIIII:PPPP": 8 hex digits of IPv4 address in little-endian byte
// order, a colon, then the port in big-endian hex. Returns ("", 0) for
// malformed input. The 32-digit IPv6 form (/proc/net/tcp6) is not handled.
func parseHexAddr(hexAddr string) (string, int) {
	pieces := strings.Split(hexAddr, ":")
	if len(pieces) != 2 || len(pieces[0]) != 8 {
		return "", 0
	}
	ipHex, portHex := pieces[0], pieces[1]
	// The kernel stores the address little-endian, so the textual hex
	// pairs are reversed: chars 6-7 hold the first octet, 0-1 the last.
	var o [4]byte
	for i := range o {
		o[i] = hexToByte(ipHex[6-2*i : 8-2*i])
	}
	addr := net.IPv4(o[0], o[1], o[2], o[3]).String()
	port := 0
	for _, c := range portHex {
		// #nosec G115 -- hexPort is ASCII hex from /proc/net/*; rune→byte is lossless.
		port = port*16 + hexVal(byte(c))
	}
	return addr, port
}

// hexToByte decodes a two-character hex pair into a byte; any other
// length yields 0, and non-hex characters contribute 0 via hexVal.
func hexToByte(s string) byte {
	if len(s) != 2 {
		return 0
	}
	// #nosec G115 -- hexVal returns 0..15; (h<<4)|h fits in a byte (0..255).
	return byte(hexVal(s[0])<<4 | hexVal(s[1]))
}

// hexVal returns the numeric value of one hex digit (either case), or 0
// for anything that is not a hex digit.
func hexVal(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'a' && c <= 'f' {
		return int(c-'a') + 10
	}
	if c >= 'A' && c <= 'F' {
		return int(c-'A') + 10
	}
	return 0
}
package checks
import (
"bufio"
"context"
"encoding/json"
"fmt"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// perfEnabled returns false only if Performance.Enabled is explicitly set
// to false; a nil (unset) pointer is treated as enabled.
func perfEnabled(cfg *config.Config) bool {
	enabled := cfg.Performance.Enabled
	return enabled == nil || *enabled
}
// cpuCoresOnce guards the cached CPU core count.
var (
cpuCoresOnce sync.Once
cpuCoresCache int
)
// getCPUCores reads /proc/cpuinfo and counts "processor\t" lines.
// The result is cached after the first call. Returns 1 on error.
// Because the cache is filled exactly once, a changed core count is not
// noticed until the process restarts.
func getCPUCores() int {
cpuCoresOnce.Do(func() {
f, err := osFS.Open("/proc/cpuinfo")
if err != nil {
// No readable /proc/cpuinfo: fall back to one core so per-core
// threshold arithmetic elsewhere stays sane.
cpuCoresCache = 1
return
}
defer func() { _ = f.Close() }()
count := 0
scanner := bufio.NewScanner(f)
for scanner.Scan() {
if strings.HasPrefix(scanner.Text(), "processor\t") {
count++
}
}
// Guard against an empty or unparseable cpuinfo.
if count == 0 {
count = 1
}
cpuCoresCache = count
})
return cpuCoresCache
}
// parseLoadAvg reads /proc/loadavg and returns the first three load
// average values (1m, 5m, 15m). Read or parse failures are wrapped with
// context and returned alongside a zero-valued array.
func parseLoadAvg() ([3]float64, error) {
	var loads [3]float64
	raw, err := osFS.ReadFile("/proc/loadavg")
	if err != nil {
		return loads, fmt.Errorf("reading /proc/loadavg: %w", err)
	}
	fields := strings.Fields(string(raw))
	if len(fields) < 3 {
		return loads, fmt.Errorf("unexpected /proc/loadavg format: %q", string(raw))
	}
	for i, field := range fields[:3] {
		value, parseErr := strconv.ParseFloat(field, 64)
		if parseErr != nil {
			return loads, fmt.Errorf("parsing load avg field %d: %w", i, parseErr)
		}
		loads[i] = value
	}
	return loads, nil
}
// parseMemInfo reads /proc/meminfo and returns total memory, available
// memory, swap total, and swap free — all in kilobytes. All values are
// zero when the file cannot be opened; fields missing from the file (or
// with unparseable values) remain zero.
func parseMemInfo() (total, available, swapTotal, swapFree uint64) {
	f, err := osFS.Open("/proc/meminfo")
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	// Map each meminfo label to the destination it should fill.
	targets := map[string]*uint64{
		"MemTotal:":     &total,
		"MemAvailable:": &available,
		"SwapTotal:":    &swapTotal,
		"SwapFree:":     &swapFree,
	}
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 {
			continue
		}
		dst, wanted := targets[fields[0]]
		if !wanted {
			continue
		}
		kb, parseErr := strconv.ParseUint(fields[1], 10, 64)
		if parseErr != nil {
			continue
		}
		*dst = kb
	}
	return
}
// humanBytes formats a byte count as a human-readable string.
// Thresholds: >=1G → "1.0G" (one decimal), >=1M → "1M" (integer),
// >=1K → "1K" (integer), anything smaller → "0B".
func humanBytes(b int64) string {
	const (
		kib = int64(1024)
		mib = 1024 * kib
		gib = 1024 * mib
	)
	if b >= gib {
		return fmt.Sprintf("%.1fG", float64(b)/float64(gib))
	}
	if b >= mib {
		return fmt.Sprintf("%dM", b/mib)
	}
	if b >= kib {
		return fmt.Sprintf("%dK", b/kib)
	}
	return "0B"
}
// CheckLoadAverage compares the 1-minute load average against per-core
// thresholds from config. Reports a single Critical or High finding, or
// nil when performance checks are disabled, loadavg is unreadable, or the
// load is below both thresholds.
func CheckLoadAverage(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	loads, err := parseLoadAvg()
	if err != nil {
		return nil
	}
	cores := getCPUCores()
	critLimit := float64(cores) * cfg.Performance.LoadCriticalMultiplier
	highLimit := float64(cores) * cfg.Performance.LoadHighMultiplier

	var sev alert.Severity
	var msg string
	var limit float64
	switch {
	case loads[0] > critLimit:
		sev, msg, limit = alert.Critical, "High load average exceeds critical threshold", critLimit
	case loads[0] > highLimit:
		sev, msg, limit = alert.High, "High load average exceeds high threshold", highLimit
	default:
		return nil
	}
	return []alert.Finding{{
		Severity: sev,
		Check:    "perf_load",
		Message:  msg,
		Details: fmt.Sprintf("Load: %.1f/%.1f/%.1f, Cores: %d, Threshold: %.1f",
			loads[0], loads[1], loads[2], cores, limit),
		Timestamp: time.Now(),
	}}
}
// CheckPHPProcessLoad scans /proc for lsphp processes, groups them by user,
// and fires Critical if total exceeds cores*multiplier, High per user if
// individual count exceeds threshold.
// Processes are matched by "lsphp" appearing anywhere in the cmdline, and
// ownership comes from the first value on the Uid: line of
// /proc/<pid>/status (the real UID per proc(5)).
func CheckPHPProcessLoad(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
if !perfEnabled(cfg) {
return nil
}
cores := getCPUCores()
cmdlinePaths, _ := osFS.Glob("/proc/[0-9]*/cmdline")
// user → list of cmdline samples
userProcs := make(map[string][]string)
total := 0
for _, cmdPath := range cmdlinePaths {
pid := filepath.Base(filepath.Dir(cmdPath))
data, err := osFS.ReadFile(cmdPath)
if err != nil {
// Process likely exited between Glob and read; skip it.
continue
}
// cmdline is NUL-separated; rejoin with spaces for readability.
cmdStr := strings.ReplaceAll(string(data), "\x00", " ")
cmdStr = strings.TrimSpace(cmdStr)
if !strings.Contains(cmdStr, "lsphp") {
continue
}
// Read UID from status
statusData, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
var uid string
for _, line := range strings.Split(string(statusData), "\n") {
if strings.HasPrefix(line, "Uid:\t") {
fields := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
if len(fields) > 0 {
uid = fields[0]
}
break
}
}
if uid == "" {
uid = "unknown"
}
username := uidToUser(uid)
userProcs[username] = append(userProcs[username], cmdStr)
total++
}
// No lsphp processes at all: nothing to report.
if total == 0 {
return nil
}
var findings []alert.Finding
// Critical: total lsphp count exceeds cores * multiplier
critTotalThreshold := cores * cfg.Performance.PHPProcessCriticalTotalMult
if total > critTotalThreshold {
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "perf_php_processes",
Message: "Total lsphp process count exceeds critical threshold",
Details: fmt.Sprintf("Count: %d, Threshold: %d (cores: %d × %d)", total, critTotalThreshold, cores, cfg.Performance.PHPProcessCriticalTotalMult),
Timestamp: time.Now(),
})
}
// High: per-user count exceeds threshold
for username, procs := range userProcs {
if len(procs) > cfg.Performance.PHPProcessWarnPerUser {
// Collect up to 3 sample cmdlines
samples := procs
if len(samples) > 3 {
samples = samples[:3]
}
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "perf_php_processes",
Message: fmt.Sprintf("Excessive lsphp processes for user %s", username),
Details: fmt.Sprintf("Count: %d, Threshold: %d, Sample cmdlines: %s", len(procs), cfg.Performance.PHPProcessWarnPerUser, strings.Join(samples, " | ")),
Timestamp: time.Now(),
})
}
}
return findings
}
// CheckSwapAndOOM checks for OOM killer events in dmesg and elevated swap
// usage from /proc/meminfo. Reports Critical for OOM, High for swap > 50%.
func CheckSwapAndOOM(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	var findings []alert.Finding
	// Check dmesg for OOM events
	// Prefer ISO timestamps so we can filter to the last hour.
	// Fall back to -T (human-readable) on older kernels that don't support --time-format.
	dmesgOut, isoErr := runCmd("dmesg", "--time-format", "iso", "--level=err")
	useISO := isoErr == nil && dmesgOut != nil
	if !useISO {
		// Without ISO timestamps we cannot filter by age, so any matching
		// line (however old) produces a finding below.
		dmesgOut, _ = runCmd("dmesg", "--level=err", "-T")
	}
	if dmesgOut != nil {
		cutoff := time.Now().Add(-1 * time.Hour)
		for _, line := range strings.Split(string(dmesgOut), "\n") {
			lower := strings.ToLower(line)
			// Match both the classic "Out of memory" message and the
			// newer oom_reaper log lines.
			if !strings.Contains(lower, "out of memory") && !strings.Contains(lower, "oom_reaper") {
				continue
			}
			if useISO {
				// ISO format: 2006-01-02T15:04:05,000000+0300
				// The timestamp is the first field before the first space.
				ts := strings.SplitN(line, " ", 2)[0]
				// Normalise: replace comma-decimal with period so time.Parse handles it.
				ts = strings.Replace(ts, ",", ".", 1)
				// Try parsing with timezone offset (+hhmm or +hh:mm).
				var parsed time.Time
				var parseErr error
				for _, layout := range []string{"2006-01-02T15:04:05.000000-0700", "2006-01-02T15:04:05.000000-07:00"} {
					parsed, parseErr = time.Parse(layout, ts)
					if parseErr == nil {
						break
					}
				}
				// Skip unparseable or stale entries rather than alerting
				// on historical OOM events.
				if parseErr != nil || parsed.Before(cutoff) {
					continue
				}
			}
			// The message reflects how confident we are about recency:
			// with ISO timestamps the event is known to be within the hour.
			message := "OOM killer detected in dmesg"
			if useISO {
				message = "OOM killer invoked in the last hour"
			}
			findings = append(findings, alert.Finding{
				Severity:  alert.Critical,
				Check:     "perf_memory",
				Message:   message,
				Details:   strings.TrimSpace(line),
				Timestamp: time.Now(),
			})
			break // one finding is enough
		}
	}
	// Check swap usage
	_, _, swapTotal, swapFree := parseMemInfo()
	if swapTotal > 0 {
		swapUsed := swapTotal - swapFree
		usagePct := float64(swapUsed) / float64(swapTotal) * 100
		if usagePct > 50 {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "perf_memory",
				Message:  "High swap usage",
				// #nosec G115 -- swap sizes from /proc/meminfo are kernel-bounded
				// to physical memory, multiple orders below int64 max even after *1024.
				Details:   fmt.Sprintf("Swap used: %s / %s (%.0f%%)", humanBytes(int64(swapUsed)*1024), humanBytes(int64(swapTotal)*1024), usagePct),
				Timestamp: time.Now(),
			})
		}
	}
	return findings
}
// CheckPHPHandler detects PHP CGI handler usage on LiteSpeed servers.
// On LiteSpeed, CGI is significantly slower than LSAPI; this check fires
// a Critical finding for each PHP version using the CGI handler.
// Throttled to once every 60 minutes; skipped entirely when the LiteSpeed
// binary is not present.
func CheckPHPHandler(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	if !store.ShouldRunThrottled("perf_php_handler", 60) {
		return nil
	}
	// Only relevant on LiteSpeed
	if _, err := osFS.Stat("/usr/local/lsws/bin/litespeed"); err != nil {
		return nil
	}
	var cgiVersions []string
	// Try whmapi1 first
	out, err := runCmd("whmapi1", "php_get_handlers", "--output=json")
	if err == nil && len(out) > 0 {
		// Parse JSON: look for handler entries with type "cgi"
		var result struct {
			Data struct {
				Handlers []struct {
					Version string `json:"version"`
					Handler string `json:"handler"`
					Type    string `json:"type"`
				} `json:"handlers"`
			} `json:"data"`
		}
		if jsonErr := json.Unmarshal(out, &result); jsonErr == nil {
			for _, h := range result.Data.Handlers {
				// Match "cgi" in either field, but exclude the fast
				// handlers (lsapi, fpm) whose names must not be confused
				// with plain CGI.
				t := strings.ToLower(h.Handler + " " + h.Type)
				if strings.Contains(t, "cgi") && !strings.Contains(t, "lsapi") && !strings.Contains(t, "fpm") {
					cgiVersions = append(cgiVersions, h.Version)
				}
			}
		}
	} else {
		// Fallback: read /etc/cpanel/ea4/ea4.conf
		data, readErr := osFS.ReadFile("/etc/cpanel/ea4/ea4.conf")
		if readErr == nil {
			for _, line := range strings.Split(string(data), "\n") {
				line = strings.TrimSpace(line)
				// Lines like: ea-php74.handler = cgi
				if !strings.Contains(line, ".handler") {
					continue
				}
				parts := strings.SplitN(line, "=", 2)
				if len(parts) != 2 {
					continue
				}
				val := strings.TrimSpace(parts[1])
				if val == "cgi" {
					// Report the bare PHP version ("ea-php74"), not the raw
					// config key ("ea-php74.handler"), so both discovery
					// paths produce consistent Details text.
					version := strings.TrimSuffix(strings.TrimSpace(parts[0]), ".handler")
					cgiVersions = append(cgiVersions, version)
				}
			}
		}
	}
	if len(cgiVersions) == 0 {
		return nil
	}
	return []alert.Finding{{
		Severity:  alert.Critical,
		Check:     "perf_php_handler",
		Message:   "PHP handler set to CGI instead of LSAPI on LiteSpeed",
		Details:   fmt.Sprintf("Affected PHP versions: %s", strings.Join(cgiVersions, ", ")),
		Timestamp: time.Now(),
	}}
}
// CheckMySQLConfig inspects MySQL global variables and runtime status for
// performance-impacting misconfigurations. Each issue emits its own finding
// with a stable message so deduplication works correctly.
// All mysql invocations use -N -B (no header, batch mode) so each output
// row parses with strings.Fields as "<name> <value>".
func CheckMySQLConfig(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	if !store.ShouldRunThrottled("perf_mysql_config", 60) {
		return nil
	}
	var findings []alert.Finding
	// --- Global variables ---
	varOut, err := runCmd("mysql", "-N", "-B", "-e",
		"SHOW GLOBAL VARIABLES WHERE Variable_name IN ('join_buffer_size','wait_timeout','interactive_timeout','max_user_connections','slow_query_log')")
	if err == nil && len(varOut) > 0 {
		joinBufThresholdBytes := int64(cfg.Performance.MySQLJoinBufferMaxMB) * 1024 * 1024
		// interactive_timeout is compared against the same configured
		// ceiling as wait_timeout (see the switch below).
		waitTimeoutMax := cfg.Performance.MySQLWaitTimeoutMax
		for _, line := range strings.Split(string(varOut), "\n") {
			fields := strings.Fields(line)
			if len(fields) < 2 {
				continue
			}
			name := fields[0]
			val := fields[1]
			switch name {
			case "join_buffer_size":
				// join_buffer_size is reported in bytes.
				v, convErr := strconv.ParseInt(val, 10, 64)
				if convErr == nil && v > joinBufThresholdBytes {
					findings = append(findings, alert.Finding{
						Severity:  alert.Critical,
						Check:     "perf_mysql_config",
						Message:   "MySQL join_buffer_size exceeds safe maximum",
						Details:   fmt.Sprintf("Current: %s, Max: %s", humanBytes(v), humanBytes(joinBufThresholdBytes)),
						Timestamp: time.Now(),
					})
				}
			case "wait_timeout":
				v, convErr := strconv.Atoi(val)
				if convErr == nil && v > waitTimeoutMax {
					findings = append(findings, alert.Finding{
						Severity:  alert.High,
						Check:     "perf_mysql_config",
						Message:   "MySQL wait_timeout is too high",
						Details:   fmt.Sprintf("Current: %ds, Max: %ds", v, waitTimeoutMax),
						Timestamp: time.Now(),
					})
				}
			case "interactive_timeout":
				v, convErr := strconv.Atoi(val)
				if convErr == nil && v > waitTimeoutMax {
					findings = append(findings, alert.Finding{
						Severity:  alert.High,
						Check:     "perf_mysql_config",
						Message:   "MySQL interactive_timeout is too high",
						Details:   fmt.Sprintf("Current: %ds, Max: %ds", v, waitTimeoutMax),
						Timestamp: time.Now(),
					})
				}
			case "max_user_connections":
				// A value of 0 means no per-user connection cap.
				if val == "0" {
					findings = append(findings, alert.Finding{
						Severity:  alert.Warning,
						Check:     "perf_mysql_config",
						Message:   "MySQL max_user_connections is unlimited",
						Details:   fmt.Sprintf("Current: 0 (unlimited), Recommended: %d", cfg.Performance.MySQLMaxConnectionsPerUser),
						Timestamp: time.Now(),
					})
				}
			case "slow_query_log":
				if strings.ToUpper(val) == "OFF" {
					findings = append(findings, alert.Finding{
						Severity:  alert.Warning,
						Check:     "perf_mysql_config",
						Message:   "MySQL slow query log is disabled",
						Details:   "Set slow_query_log=ON to help diagnose performance issues",
						Timestamp: time.Now(),
					})
				}
			}
		}
	}
	// --- InnoDB buffer pool hit ratio + temporary disk tables ---
	statusOut, err := runCmd("mysql", "-N", "-B", "-e",
		"SHOW GLOBAL STATUS WHERE Variable_name IN ('Innodb_buffer_pool_read_requests','Innodb_buffer_pool_reads','Created_tmp_disk_tables','Created_tmp_tables')")
	if err == nil && len(statusOut) > 0 {
		var readRequests, reads, tmpDiskTables, tmpTables int64
		for _, line := range strings.Split(string(statusOut), "\n") {
			fields := strings.Fields(line)
			if len(fields) < 2 {
				continue
			}
			v, convErr := strconv.ParseInt(fields[1], 10, 64)
			if convErr != nil {
				continue
			}
			switch fields[0] {
			case "Innodb_buffer_pool_read_requests":
				readRequests = v
			case "Innodb_buffer_pool_reads":
				reads = v
			case "Created_tmp_disk_tables":
				tmpDiskTables = v
			case "Created_tmp_tables":
				tmpTables = v
			}
		}
		if tmpTables > 0 && tmpDiskTables > 0 {
			// Ratio of implicit temp tables that spilled to disk.
			diskRatio := float64(tmpDiskTables) / float64(tmpTables) * 100
			if diskRatio > 25.0 {
				findings = append(findings, alert.Finding{
					Severity:  alert.Warning,
					Check:     "perf_mysql_config",
					Message:   "MySQL creating excessive temporary tables on disk",
					Details:   fmt.Sprintf("Disk ratio: %.1f%% (%d disk tables / %d total tables)", diskRatio, tmpDiskTables, tmpTables),
					Timestamp: time.Now(),
				})
			}
		}
		if readRequests > 0 {
			// Hit ratio: requests served from the buffer pool; each
			// Innodb_buffer_pool_reads is a miss that went to disk.
			hitRatio := float64(readRequests-reads) / float64(readRequests) * 100
			if hitRatio < 95.0 {
				findings = append(findings, alert.Finding{
					Severity:  alert.High,
					Check:     "perf_mysql_config",
					Message:   "InnoDB buffer pool hit ratio is low",
					Details:   fmt.Sprintf("Hit ratio: %.1f%% (threshold: 95%%), disk reads: %d", hitRatio, reads),
					Timestamp: time.Now(),
				})
			}
		}
	}
	// --- Per-user connection counts ---
	plOut, err := runCmd("mysql", "-N", "-B", "-e", "SHOW PROCESSLIST")
	if err == nil && len(plOut) > 0 {
		userCounts := make(map[string]int)
		for _, line := range strings.Split(string(plOut), "\n") {
			fields := strings.Fields(line)
			// SHOW PROCESSLIST columns: Id, User, Host, db, Command, Time, State, Info
			if len(fields) < 2 {
				continue
			}
			user := fields[1]
			// Skip blank users and a literal "User" header row (defensive;
			// -N should already suppress headers).
			if user == "" || user == "User" {
				continue
			}
			userCounts[user]++
		}
		maxConn := cfg.Performance.MySQLMaxConnectionsPerUser
		for dbUser, count := range userCounts {
			if count > maxConn {
				findings = append(findings, alert.Finding{
					Severity:  alert.High,
					Check:     "perf_mysql_config",
					Message:   fmt.Sprintf("MySQL user %s holding excessive connections", dbUser),
					Details:   fmt.Sprintf("Connections: %d, Threshold: %d", count, maxConn),
					Timestamp: time.Now(),
				})
			}
		}
	}
	return findings
}
// CheckRedisConfig inspects a local Redis instance for performance-impacting
// misconfigurations: unset maxmemory, noeviction policy, non-expiring keys,
// and an overly aggressive bgsave schedule for the dataset size.
// Silently returns nil when redis-cli is not installed.
func CheckRedisConfig(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) {
		return nil
	}
	if !store.ShouldRunThrottled("perf_redis_config", 60) {
		return nil
	}
	// Locate redis-cli
	redisCLI := ""
	for _, candidate := range []string{"/usr/bin/redis-cli", "/usr/local/bin/redis-cli"} {
		if _, err := osFS.Stat(candidate); err == nil {
			redisCLI = candidate
			break
		}
	}
	if redisCLI == "" {
		return nil
	}
	var findings []alert.Finding
	// --- maxmemory ---
	maxMemOut, err := runCmd(redisCLI, "config", "get", "maxmemory")
	if err == nil && len(maxMemOut) > 0 {
		lines := strings.Fields(string(maxMemOut))
		// redis config get returns two tokens: key value
		if len(lines) >= 2 && lines[1] == "0" {
			findings = append(findings, alert.Finding{
				Severity:  alert.Critical,
				Check:     "perf_redis_config",
				Message:   "Redis maxmemory is not set",
				Details:   "maxmemory=0 means Redis will use all available system memory without bound",
				Timestamp: time.Now(),
			})
		}
	}
	// --- maxmemory-policy ---
	policyOut, err := runCmd(redisCLI, "config", "get", "maxmemory-policy")
	if err == nil && len(policyOut) > 0 {
		lines := strings.Fields(string(policyOut))
		if len(lines) >= 2 && strings.ToLower(lines[1]) == "noeviction" {
			findings = append(findings, alert.Finding{
				Severity:  alert.High,
				Check:     "perf_redis_config",
				Message:   "Redis maxmemory-policy is noeviction",
				Details:   "noeviction causes Redis to return errors when memory is full instead of evicting keys",
				Timestamp: time.Now(),
			})
		}
	}
	// --- Non-expiring keys ratio via keyspace ---
	keyspaceOut, err := runCmd(redisCLI, "info", "keyspace")
	if err == nil && len(keyspaceOut) > 0 {
		var totalKeys, totalExpires int64
		for _, line := range strings.Split(string(keyspaceOut), "\n") {
			line = strings.TrimSpace(line)
			if !strings.HasPrefix(line, "db") {
				continue
			}
			// format: db0:keys=123,expires=45,avg_ttl=...
			parts := strings.SplitN(line, ":", 2)
			if len(parts) < 2 {
				continue
			}
			for _, kv := range strings.Split(parts[1], ",") {
				kv = strings.TrimSpace(kv)
				kvParts := strings.SplitN(kv, "=", 2)
				if len(kvParts) != 2 {
					continue
				}
				v, convErr := strconv.ParseInt(kvParts[1], 10, 64)
				if convErr != nil {
					continue
				}
				switch kvParts[0] {
				case "keys":
					totalKeys += v
				case "expires":
					totalExpires += v
				}
			}
		}
		if totalKeys > 0 {
			nonExpiring := totalKeys - totalExpires
			ratio := float64(nonExpiring) / float64(totalKeys) * 100
			if ratio > 95.0 {
				findings = append(findings, alert.Finding{
					Severity:  alert.Warning,
					Check:     "perf_redis_config",
					Message:   "Redis has excessive non-expiring keys",
					Details:   fmt.Sprintf("Non-expiring: %d / %d total keys (%.1f%%)", nonExpiring, totalKeys, ratio),
					Timestamp: time.Now(),
				})
			}
		}
	}
	// --- bgsave interval vs dataset size ---
	saveOut, _ := runCmd(redisCLI, "config", "get", "save")
	infoOut, _ := runCmd(redisCLI, "info", "memory")
	var usedMemoryBytes int64
	if len(infoOut) > 0 {
		for _, line := range strings.Split(string(infoOut), "\n") {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, "used_memory:") {
				parts := strings.SplitN(line, ":", 2)
				if len(parts) == 2 {
					v, convErr := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
					if convErr == nil {
						usedMemoryBytes = v
					}
				}
				break
			}
		}
	}
	const gbBytes = 1024 * 1024 * 1024
	largeDatasetBytes := int64(cfg.Performance.RedisLargeDatasetGB) * gbBytes
	bgsaveMinInterval := cfg.Performance.RedisBgsaveMinInterval
	if usedMemoryBytes > largeDatasetBytes && len(saveOut) > 0 {
		// save config output: "save" on one line, then the value. Redis
		// reports ALL save points on a single value line, e.g.
		// "3600 1 300 100 60 10000" (pairs of <seconds> <changes>).
		// Walk every <seconds> entry (even indices) — checking only the
		// first pair would miss aggressive trailing rules like "60 10000".
		aggressiveSave := false
		for _, line := range strings.Split(string(saveOut), "\n") {
			fields := strings.Fields(strings.TrimSpace(line))
			// Skip blank lines and the "save" key line itself.
			if len(fields) == 0 || fields[0] == "save" {
				continue
			}
			for i := 0; i < len(fields); i += 2 {
				seconds, convErr := strconv.Atoi(fields[i])
				if convErr == nil && seconds < bgsaveMinInterval {
					aggressiveSave = true
					break
				}
			}
			if aggressiveSave {
				break
			}
		}
		if aggressiveSave {
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "perf_redis_config",
				Message:  "Redis bgsave interval too aggressive for dataset size",
				Details: fmt.Sprintf(
					"Used memory: %s, Threshold: %s, Minimum safe bgsave interval: %ds",
					humanBytes(usedMemoryBytes),
					humanBytes(largeDatasetBytes),
					bgsaveMinInterval,
				),
				Timestamp: time.Now(),
			})
		}
	}
	return findings
}
// ---------------------------------------------------------------------------
// Performance check helpers (WP-specific)
// ---------------------------------------------------------------------------
// safeIdentRe accepts only ASCII letters, digits, and underscores.
var safeIdentRe = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// safeIdentifier reports whether s is a non-empty string consisting solely
// of [a-zA-Z0-9_]. Used to reject values containing shell metacharacters
// before they are used in commands or SQL.
func safeIdentifier(s string) bool {
	if s == "" {
		return false
	}
	return safeIdentRe.MatchString(s)
}
// extractPHPDefine extracts the value argument from a PHP define() line:
//
//	define('KEY', 'value'); or define("KEY", "value");
//
// It is distinct from extractDefine (dbscan.go) which requires a key parameter.
// Returns the empty string if no value can be extracted.
func extractPHPDefine(line string) string {
	line = strings.TrimSpace(line)
	open := strings.Index(line, "(")
	if open < 0 {
		return ""
	}
	// Work on the argument list between the outermost parentheses.
	args := line[open+1:]
	if end := strings.LastIndex(args, ")"); end >= 0 {
		args = args[:end]
	}
	// args now looks like: 'KEY', 'value' — discard the key part.
	pieces := strings.SplitN(args, ",", 2)
	if len(pieces) < 2 {
		return ""
	}
	value := strings.TrimSpace(pieces[1])
	if value == "" {
		return ""
	}
	// Quoted value: return the text between the opening quote and the
	// last occurrence of the same quote character.
	if quote := value[0]; quote == '\'' || quote == '"' {
		if len(value) < 2 {
			return ""
		}
		last := strings.LastIndexByte(value, quote)
		if last <= 0 {
			return ""
		}
		return value[1:last]
	}
	// Unquoted literal (boolean/number constant), e.g.:
	//	define('DISABLE_WP_CRON', true);
	//	define('WP_MEMORY_LIMIT', 256);
	// Cut at the first delimiter and return the bare token.
	if i := strings.IndexAny(value, " \t;),"); i >= 0 {
		return strings.TrimSpace(value[:i])
	}
	return value
}
// ---------------------------------------------------------------------------
// Subdirs to skip in recursive helpers.
// ---------------------------------------------------------------------------

// skipDirs lists directory names that the recursive scanners
// (scanErrorLogs, scanWPConfigs, scanWPCron, findWPTransients) never
// descend into: WordPress core trees plus large cache/dependency
// directories that would dominate walk time.
var skipDirs = map[string]bool{
	"wp-admin":     true,
	"wp-content":   true,
	"wp-includes":  true,
	"cache":        true,
	"node_modules": true,
	"vendor":       true,
}
// ---------------------------------------------------------------------------
// CheckErrorLogBloat
// ---------------------------------------------------------------------------
// scanErrorLogs recursively walks dir up to maxDepth looking for error_log
// files larger than threshold bytes. Results are appended to *findings (capped
// at 20).
func scanErrorLogs(dir string, thresholdBytes int64, depth int, findings *[]alert.Finding) {
	if depth < 0 || len(*findings) >= 20 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, entry := range entries {
		if len(*findings) >= 20 {
			return
		}
		child := filepath.Join(dir, entry.Name())
		switch {
		case entry.IsDir():
			// Recurse, staying out of heavyweight well-known subtrees.
			if !skipDirs[entry.Name()] {
				scanErrorLogs(child, thresholdBytes, depth-1, findings)
			}
		case entry.Name() == "error_log":
			info, statErr := entry.Info()
			if statErr != nil || info.Size() <= thresholdBytes {
				continue
			}
			*findings = append(*findings, alert.Finding{
				Severity:  alert.Warning,
				Check:     "perf_error_logs",
				Message:   fmt.Sprintf("Bloated error_log: %s", child),
				Details:   fmt.Sprintf("Size: %s", humanBytes(info.Size())),
				Timestamp: time.Now(),
			})
		}
	}
}
// CheckErrorLogBloat walks configured web roots (default /home/*/public_html
// on cPanel) looking for error_log files that exceed the configured size
// threshold. Throttled to once every 60 minutes.
func CheckErrorLogBloat(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) || !store.ShouldRunThrottled("perf_error_logs", 60) {
		return nil
	}
	limit := int64(cfg.Performance.ErrorLogWarnSizeMB) * 1024 * 1024
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		scanErrorLogs(root, limit, 3, &findings)
		// scanErrorLogs caps findings at 20; stop walking further roots
		// once the cap is reached.
		if len(findings) >= 20 {
			break
		}
	}
	return findings
}
// ---------------------------------------------------------------------------
// CheckWPConfig
// ---------------------------------------------------------------------------
// parseMemoryLimit converts a PHP memory_limit string (e.g. "256M", "1G")
// to megabytes. Returns 0 if the value cannot be parsed.
func parseMemoryLimit(s string) int {
	s = strings.TrimSpace(strings.ToUpper(s))
	// "-1" is PHP's "unlimited" sentinel; treat it like unparseable.
	if s == "" || s == "-1" {
		return 0
	}
	num := s
	toMB := func(v int) int { return v } // bare numbers and "M" are already MB
	switch s[len(s)-1] {
	case 'K':
		num = s[:len(s)-1]
		toMB = func(v int) int { return v / 1024 } // truncating: <1024K -> 0
	case 'M':
		num = s[:len(s)-1]
	case 'G':
		num = s[:len(s)-1]
		toMB = func(v int) int { return v * 1024 }
	}
	v, err := strconv.Atoi(num)
	if err != nil {
		return 0
	}
	return toMB(v)
}
// scanWPConfigs recursively searches dir (max depth) for wp-config.php files
// and checks WP_MEMORY_LIMIT and co-located config files for issues.
// Findings are appended to *findings; account is used only in the finding
// messages.
func scanWPConfigs(dir, account string, cfg *config.Config, depth int, findings *[]alert.Finding) {
	if depth < 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		// Unreadable directory (permissions, deleted mid-walk): skip.
		return
	}
	for _, e := range entries {
		name := e.Name()
		fullPath := filepath.Join(dir, name)
		if e.IsDir() {
			if skipDirs[name] {
				continue
			}
			scanWPConfigs(fullPath, account, cfg, depth-1, findings)
			continue
		}
		if name != "wp-config.php" {
			continue
		}
		// --- WP_MEMORY_LIMIT ---
		wpData, readErr := osFS.ReadFile(fullPath)
		if readErr == nil {
			for _, line := range strings.Split(string(wpData), "\n") {
				if strings.Contains(line, "WP_MEMORY_LIMIT") {
					val := extractPHPDefine(strings.TrimSpace(line))
					// parseMemoryLimit returns 0 for unparseable values,
					// which can never exceed the configured maximum.
					if mb := parseMemoryLimit(val); mb > cfg.Performance.WPMemoryLimitMaxMB {
						*findings = append(*findings, alert.Finding{
							Severity:  alert.Warning,
							Check:     "perf_wp_config",
							Message:   fmt.Sprintf("Excessive WP_MEMORY_LIMIT for %s", account),
							Details:   fmt.Sprintf("File: %s, Value: %s", fullPath, val),
							Timestamp: time.Now(),
						})
					}
					// Only the first WP_MEMORY_LIMIT line is considered.
					break
				}
			}
		}
		// --- Co-located PHP config files ---
		wpDir := filepath.Dir(fullPath)
		for _, cfgFile := range []string{".htaccess", "php.ini", ".user.ini"} {
			cfgPath := filepath.Join(wpDir, cfgFile)
			data, readErr2 := osFS.ReadFile(cfgPath)
			if readErr2 != nil {
				continue
			}
			// cPanel MultiPHP INI Editor writes .user.ini with a fixed
			// header and owns the file's content. Values inside a
			// cPanel-managed .user.ini (max_execution_time=0 for a
			// backup importer, display_errors=On for a staging account)
			// reflect operator choices made through the cPanel UI and
			// are not attacker actions. Suppress findings for this
			// file in that case — operators do not need alerts for
			// their own configuration. The suppression is scoped
			// strictly to .user.ini: the same signature in php.ini or
			// .htaccess is not authoritative (cPanel does not write
			// those files) and the scanner treats it normally.
			if cfgFile == ".user.ini" && isCpanelManagedUserIni(data) {
				continue
			}
			for _, line := range strings.Split(string(data), "\n") {
				trimmed := strings.TrimSpace(line)
				// Skip comment lines
				if strings.HasPrefix(trimmed, "#") || strings.HasPrefix(trimmed, ";") {
					continue
				}
				lc := strings.ToLower(trimmed)
				switch {
				case strings.Contains(lc, "max_execution_time"):
					// max_execution_time = 0 (or php_value max_execution_time 0)
					// Splitting on '=', space, and tab makes the value the
					// last field for both INI and Apache php_value syntax.
					parts := strings.FieldsFunc(trimmed, func(r rune) bool { return r == '=' || r == ' ' || r == '\t' })
					if len(parts) >= 2 && parts[len(parts)-1] == "0" {
						*findings = append(*findings, alert.Finding{
							Severity:  alert.High,
							Check:     "perf_wp_config",
							Message:   fmt.Sprintf("Unlimited max_execution_time for %s", account),
							Details:   fmt.Sprintf("File: %s, Value: 0", cfgPath),
							Timestamp: time.Now(),
						})
					}
				case strings.Contains(lc, "display_errors"):
					parts := strings.FieldsFunc(trimmed, func(r rune) bool { return r == '=' || r == ' ' || r == '\t' })
					if len(parts) >= 2 && strings.ToLower(parts[len(parts)-1]) == "on" {
						*findings = append(*findings, alert.Finding{
							Severity:  alert.Warning,
							Check:     "perf_wp_config",
							Message:   fmt.Sprintf("display_errors enabled in production for %s", account),
							Details:   fmt.Sprintf("File: %s, Value: On", cfgPath),
							Timestamp: time.Now(),
						})
					}
				}
			}
		}
	}
}
// CheckWPConfig scans /home/*/public_html (max depth 2) for wp-config.php
// files and reports excessive WP_MEMORY_LIMIT values, unlimited
// max_execution_time, and display_errors enabled in production.
// Throttled to once every 60 minutes.
func CheckWPConfig(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) || !store.ShouldRunThrottled("perf_wp_config", 60) {
		return nil
	}
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		scanWPConfigs(root, accountFromPath(root), cfg, 2, &findings)
	}
	return findings
}
// accountFromPath extracts a best-effort account name from a web root path.
// On cPanel (/home/USER/public_html) it returns USER. On other layouts it
// returns the parent directory name, or the final path component if there
// is no meaningful parent. Used for reporting only — never for authorization.
func accountFromPath(dir string) string {
	parts := strings.Split(dir, string(filepath.Separator))
	// cPanel shape: /home/<account>/public_html -> <account>.
	for i, p := range parts {
		if p == "home" && i+1 < len(parts) {
			return parts[i+1]
		}
	}
	// Generic shape: /var/www/<site>, /srv/http/<site>, etc. -> parent dir.
	// Require the parent component to be non-empty so a single-component
	// absolute path like "/srv" never yields an empty account name.
	if len(parts) >= 2 && parts[len(parts)-1] != "" && parts[len(parts)-2] != "" {
		return parts[len(parts)-2]
	}
	// Trailing separator or no usable parent: fall back to the base name.
	return filepath.Base(dir)
}
// ---------------------------------------------------------------------------
// CheckWPTransientBloat
// ---------------------------------------------------------------------------
// findWPTransients recursively searches dir for wp-config.php files and
// queries the WordPress database for bloated transients.
// Credentials come from parseWPConfig (defined elsewhere in this package);
// the DB password is passed via the MYSQL_PWD environment variable so it
// never appears on the command line.
func findWPTransients(dir string, cfg *config.Config, warnBytes, critBytes int64, depth int, findings *[]alert.Finding) {
	if depth < 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, e := range entries {
		name := e.Name()
		fullPath := filepath.Join(dir, name)
		if e.IsDir() {
			if skipDirs[name] {
				continue
			}
			findWPTransients(fullPath, cfg, warnBytes, critBytes, depth-1, findings)
			continue
		}
		if name != "wp-config.php" {
			continue
		}
		info := parseWPConfig(fullPath)
		if info.dbName == "" || info.dbUser == "" {
			continue
		}
		// Apply default table prefix when not set.
		if info.tablePrefix == "" {
			info.tablePrefix = "wp_"
		}
		// Security: validate identifiers before use in SQL.
		if !safeIdentifier(info.dbName) || !safeIdentifier(info.dbUser) || !safeIdentifier(info.tablePrefix) {
			continue
		}
		// The prefix passed safeIdentifier above and warnBytes is a plain
		// integer, so the interpolated query contains no untrusted text.
		query := fmt.Sprintf(
			"SELECT option_name, LENGTH(option_value) as size FROM %soptions WHERE option_name LIKE '_transient_%%' AND LENGTH(option_value) > %d ORDER BY size DESC LIMIT 5",
			info.tablePrefix,
			warnBytes,
		)
		args := []string{
			"-N", "-B",
			"-h", info.dbHost,
			"-u", info.dbUser,
			info.dbName,
			"-e", query,
		}
		out, runErr := runCmdWithEnv("mysql", args, "MYSQL_PWD="+info.dbPass)
		if runErr != nil || len(out) == 0 {
			continue
		}
		for _, line := range strings.Split(string(out), "\n") {
			fields := strings.Fields(line)
			if len(fields) < 2 {
				continue
			}
			optionName := fields[0]
			sizeBytes, convErr := strconv.ParseInt(fields[1], 10, 64)
			if convErr != nil {
				continue
			}
			// Severity scales with size. The default branch is defensive:
			// the SQL already filters out rows at or below warnBytes.
			var sev alert.Severity
			switch {
			case sizeBytes > critBytes:
				sev = alert.High
			case sizeBytes > warnBytes:
				sev = alert.Warning
			default:
				continue
			}
			*findings = append(*findings, alert.Finding{
				Severity:  sev,
				Check:     "perf_wp_transients",
				Message:   fmt.Sprintf("Bloated transient %s in %s", optionName, info.dbName),
				Details:   fmt.Sprintf("Size: %s", humanBytes(sizeBytes)),
				Timestamp: time.Now(),
			})
		}
	}
}
// CheckWPTransientBloat scans configured web roots (default /home/*/public_html
// on cPanel) for WordPress installs and queries each database for oversized
// transients. DB credentials are read from wp-config.php; the password is
// passed via MYSQL_PWD environment variable (never on the command line).
// Throttled to once every 60 minutes.
func CheckWPTransientBloat(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) || !store.ShouldRunThrottled("perf_wp_transients", 60) {
		return nil
	}
	// Size thresholds are configured in MB; the helper works in bytes.
	warn := int64(cfg.Performance.WPTransientWarnMB) * 1024 * 1024
	crit := int64(cfg.Performance.WPTransientCriticalMB) * 1024 * 1024
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		findWPTransients(root, cfg, warn, crit, 2, &findings)
	}
	return findings
}
// ---------------------------------------------------------------------------
// CheckWPCron
// ---------------------------------------------------------------------------
// scanWPCron recursively searches dir for wp-config.php files and checks
// whether DISABLE_WP_CRON is defined and set to true. A finding is emitted
// for every install where it is missing or not truthy (capped at 30).
func scanWPCron(dir, account string, depth int, findings *[]alert.Finding) {
	if depth < 0 || len(*findings) >= 30 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, entry := range entries {
		if len(*findings) >= 30 {
			return
		}
		childPath := filepath.Join(dir, entry.Name())
		if entry.IsDir() {
			if !skipDirs[entry.Name()] {
				scanWPCron(childPath, account, depth-1, findings)
			}
			continue
		}
		if entry.Name() != "wp-config.php" {
			continue
		}
		raw, readErr := osFS.ReadFile(childPath)
		if readErr != nil {
			continue
		}
		// cronDisabled becomes true only when the first DISABLE_WP_CRON
		// define in the file evaluates to true/1.
		cronDisabled := false
		for _, line := range strings.Split(string(raw), "\n") {
			trimmed := strings.TrimSpace(line)
			if !strings.Contains(trimmed, "DISABLE_WP_CRON") {
				continue
			}
			v := strings.ToLower(extractPHPDefine(trimmed))
			cronDisabled = v == "true" || v == "1"
			break
		}
		if cronDisabled {
			continue
		}
		*findings = append(*findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "perf_wp_cron",
			Message:  fmt.Sprintf("WP-Cron not disabled for %s", account),
			Details: fmt.Sprintf(
				"File: %s - add define('DISABLE_WP_CRON', true); and use a real cron job instead",
				childPath,
			),
			Timestamp: time.Now(),
		})
	}
}
// CheckWPCron scans configured web roots (default /home/*/public_html on
// cPanel) for WordPress installs that have not disabled the built-in
// WP-Cron mechanism. Running WP-Cron via HTTP is a common cause of high
// load on busy sites.
// Throttled to once every 60 minutes.
func CheckWPCron(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	if !perfEnabled(cfg) || !store.ShouldRunThrottled("perf_wp_cron", 60) {
		return nil
	}
	var findings []alert.Finding
	for _, root := range ResolveWebRoots(cfg) {
		scanWPCron(root, accountFromPath(root), 2, &findings)
		// scanWPCron caps findings at 30; stop once the cap is reached.
		if len(findings) >= 30 {
			break
		}
	}
	return findings
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"unicode"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// phishingReadSize caps how much of each candidate page is inspected.
const phishingReadSize = 16384 // Read first 16KB - phishing pages are self-contained
// ---------------------------------------------------------------------------
// Brand impersonation patterns
// ---------------------------------------------------------------------------
// phishingBrands lists impersonated brands and the lowercase substrings that
// betray the impersonation. titlePatterns are matched only against the page
// <title> (a stronger signal, scored higher by the analyzers); bodyPatterns
// are matched anywhere in the lowercased content. The analyzers stop at the
// first brand that matches.
var phishingBrands = []struct {
	name          string   // human-readable brand label used in finding messages
	titlePatterns []string // substrings looked for in the <title> text
	bodyPatterns  []string // substrings looked for anywhere in the body
}{
	{
		name:          "Microsoft/SharePoint",
		titlePatterns: []string{"sharepoint", "onedrive", "microsoft 365", "outlook web", "office 365", "ms online"},
		bodyPatterns:  []string{"sharepoint", "onedrive", "secured by microsoft", "microsoft corporation"},
	},
	{
		name:          "Google",
		titlePatterns: []string{"google drive", "google docs", "google sign", "gmail", "google workspace"},
		bodyPatterns:  []string{"google drive", "google docs", "accounts.google", "secured by google"},
	},
	{
		name:          "Dropbox",
		titlePatterns: []string{"dropbox", "shared file", "shared folder"},
		bodyPatterns:  []string{"dropbox", "dropbox.com", "secured by dropbox"},
	},
	{
		name:          "DocuSign",
		titlePatterns: []string{"docusign", "document signing", "e-signature"},
		bodyPatterns:  []string{"docusign", "please review and sign", "e-signature"},
	},
	{
		name:          "Adobe",
		titlePatterns: []string{"adobe sign", "adobe document", "adobe acrobat"},
		bodyPatterns:  []string{"adobe sign", "adobe document cloud", "secured by adobe"},
	},
	{
		name:          "WeTransfer",
		titlePatterns: []string{"wetransfer", "file transfer"},
		bodyPatterns:  []string{"wetransfer", "download your files"},
	},
	{
		name:          "Apple/iCloud",
		titlePatterns: []string{"icloud", "apple id", "find my"},
		bodyPatterns:  []string{"icloud.com", "apple id", "secured by apple"},
	},
	{
		name:          "PayPal",
		titlePatterns: []string{"paypal", "pay pal"},
		bodyPatterns:  []string{"paypal.com", "secured by paypal"},
	},
	{
		name:          "Webmail/Roundcube",
		titlePatterns: []string{"roundcube", "horde", "webmail login", "webmail ::", "squirrelmail", "zimbra"},
		bodyPatterns:  []string{"roundcube webmail", "horde login", "zimbra web client", "webmail login"},
		// Note: bare "squirrelmail"/"roundcube" removed from body - sites legitimately
		// link to their server's webmail (e.g. href="squirrelmail/index.php").
	},
	{
		name:          "cPanel/WHM",
		titlePatterns: []string{"cpanel", "whm login", "webhost manager"},
		bodyPatterns:  []string{"cpanel login", "whm login", "webhost manager"},
	},
	{
		name:          "Banking/Financial",
		titlePatterns: []string{"online banking", "bank login", "secure banking", "account login"},
		bodyPatterns:  []string{"online banking", "bank account", "transaction verification"},
	},
	{
		name:          "Generic Login",
		titlePatterns: []string{"secure access", "verify your", "confirm your identity", "account verification", "email verification", "sign in"},
		bodyPatterns:  []string{"verify your identity", "confirm your account", "unusual activity"},
	},
}
// ---------------------------------------------------------------------------
// Content-based indicators
// ---------------------------------------------------------------------------
// harvestIndicators are lowercase substrings suggesting credential
// harvesting in a page body: client-side redirects after form submit,
// verification bait text, fake security claims, and silent exfiltration
// primitives. Each match adds 1 to the suspicion score.
var harvestIndicators = []string{
	"window.location.href",
	"window.location.replace",
	"window.location =",
	"document.location.href",
	"form.submit()",
	".workers.dev",
	"confirm access",
	"verify your email",
	"confirm your email",
	"verify identity",
	"continue to document",
	"access confirmed, redirecting",
	"secured by microsoft",
	"secured by google",
	"secured by apple",
	"256-bit encrypted",
	"256‑bit encrypted", // variant with non-breaking hyphen (U+2011) seen in kits
	// fetch/XHR exfiltration - silent credential POST without redirect
	"fetch(",
	"xmlhttprequest",
	"$.ajax(",
	"$.post(",
	"navigator.sendbeacon(",
}

// exfilPatterns are redirect/exfiltration URL fragments (Cloudflare Workers,
// URL shorteners, open-redirect endpoints). Each match adds 2 to the score;
// they are also reused by the redirector and iframe checks.
var exfilPatterns = []string{
	".workers.dev",
	"//t.co/",
	"/redir?",
	"/redirect?",
	"effi.redir",
	"link?url=",
	"goto_url=",
	"//bit.ly/",
	"//tinyurl.com/",
	"//rb.gy/",
	"//is.gd/",
	"/servlet/effi.redir",
}

// trustBadgePatterns are fake trust-badge claims - security assertions that
// only make sense on the brand's own domain. Each match adds 1 to the score.
var trustBadgePatterns = []string{
	"secured by microsoft",
	"secured by google",
	"secured by apple",
	"secured by dropbox",
	"secured by adobe",
	"verified by microsoft",
	"protected by microsoft",
	"256-bit encrypted",
	"256‑bit encrypted", // non-breaking hyphen (U+2011) variant
	"ssl secured",
	"bank-level encryption",
	"enterprise security",
}

// urgencyPatterns capture pressure language used on victims; any number of
// matches contributes a single score point (counted, not summed).
var urgencyPatterns = []string{
	"expires in",
	"temporary hold",
	"limited time",
	"unusual activity detected",
	"suspicious activity",
	"your account will be",
	"verify within",
	"action required",
	"immediate action",
	"account suspended",
	"access will be revoked",
}

// embeddedAssetPatterns detect inlined base64 images - phishing kits embed
// brand logos to avoid external asset loads that could expose them.
var embeddedAssetPatterns = []string{
	"data:image/png;base64,",
	"data:image/svg+xml;base64,",
	"data:image/jpeg;base64,",
}
// ---------------------------------------------------------------------------
// Main check entry point
// ---------------------------------------------------------------------------
// CheckPhishing walks user document roots looking for phishing pages.
// Detection is layered:
//  1. Content analysis - brand impersonation plus credential-harvesting patterns
//  2. Structural analysis - self-contained HTML with embedded assets
//  3. Directory anomaly - lone HTML files in otherwise empty directories
//
// For each account under /home it scans public_html plus any additional
// top-level directories that are not well-known service folders, honoring
// context cancellation between accounts and roots.
func CheckPhishing(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	homeDirs, err := GetScanHomeDirs()
	if err != nil {
		return nil
	}
	for _, entry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !entry.IsDir() {
			continue
		}
		account := entry.Name()
		if account == "virtfs" || strings.HasPrefix(account, ".") {
			continue
		}
		base := filepath.Join("/home", account)
		roots := []string{filepath.Join(base, "public_html")}
		children, _ := osFS.ReadDir(base)
		for _, child := range children {
			if !child.IsDir() {
				continue
			}
			switch n := child.Name(); {
			case n == "public_html" || n == "mail" || n == "etc" || n == "logs" || n == "ssl" || n == "tmp":
				// Well-known service directories - never document roots.
			case strings.HasPrefix(n, "."):
				// Hidden directories are skipped.
			default:
				roots = append(roots, filepath.Join(base, n))
			}
		}
		for _, root := range roots {
			scanForPhishing(ctx, root, 3, account, cfg, &findings)
			if ctx.Err() != nil {
				return findings
			}
		}
	}
	return findings
}
// ---------------------------------------------------------------------------
// Directory scanner
// ---------------------------------------------------------------------------
// scanForPhishing recursively scans dir (at most maxDepth levels deep) for
// phishing artifacts belonging to account user, appending results to findings.
//
// Per directory entry it applies, in order:
//   - suppression globs from cfg.Suppressions.IgnorePaths
//   - directories: directory-anomaly analysis, then recursion
//   - *.html/*.htm: full content analysis (3KB-100KB) or iframe check (<3KB)
//   - *.php: content analysis (3KB-100KB) or open-redirector check (<1KB),
//     skipping well-known CMS files
//   - credential-log filenames (<10MB): harvested-credential content check
//   - *.zip (1KB-50MB): phishing-kit archive filename heuristics
//
// Fix: findings created here now set Timestamp (time.Now()), consistent with
// other checks such as CheckWPCron; previously the zero time was rendered by
// alert.Finding.String.
func scanForPhishing(ctx context.Context, dir string, maxDepth int, user string, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil {
		return
	}
	if maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}
		name := entry.Name()
		fullPath := filepath.Join(dir, name)
		// Operator-configured suppressions silence known-good paths.
		suppressed := false
		for _, ignore := range cfg.Suppressions.IgnorePaths {
			if matchGlob(fullPath, ignore) {
				suppressed = true
				break
			}
		}
		if suppressed {
			continue
		}
		if entry.IsDir() {
			if isKnownSafeDir(name) {
				continue
			}
			// --- Directory anomaly detection ---
			dirResult := analyzeDirectoryStructure(fullPath, user)
			if dirResult != nil {
				*findings = append(*findings, *dirResult)
			}
			scanForPhishing(ctx, fullPath, maxDepth-1, user, cfg, findings)
			continue
		}
		nameLower := strings.ToLower(name)
		info, err := entry.Info()
		if err != nil {
			continue
		}
		size := info.Size()
		// --- HTML/HTM phishing pages ---
		if strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm") {
			// Standard phishing page check (3KB-100KB)
			if size >= 3000 && size <= 100000 {
				result := analyzeHTMLForPhishing(fullPath)
				if result != nil {
					*findings = append(*findings, alert.Finding{
						Severity: alert.Critical,
						Check:    "phishing_page",
						Message:  fmt.Sprintf("Phishing page detected (%s impersonation): %s", result.brand, fullPath),
						Details: fmt.Sprintf("Account: %s\nBrand: %s\nScore: %d/10\nIndicators:\n- %s\nSize: %d bytes",
							user, result.brand, result.score, strings.Join(result.indicators, "\n- "), size),
						FilePath:  fullPath,
						Timestamp: time.Now(),
					})
				}
			}
			// --- iframe phishing (tiny HTML files that embed external phishing) ---
			if size > 0 && size < 3000 {
				if result := checkIframePhishing(fullPath); result != "" {
					*findings = append(*findings, alert.Finding{
						Severity:  alert.Critical,
						Check:     "phishing_iframe",
						Message:   fmt.Sprintf("Iframe phishing page detected: %s", fullPath),
						Details:   fmt.Sprintf("Account: %s\n%s", user, result),
						FilePath:  fullPath,
						Timestamp: time.Now(),
					})
				}
			}
			continue
		}
		// --- PHP phishing pages and open redirectors ---
		if strings.HasSuffix(nameLower, ".php") {
			// Skip known CMS files (too many false positives).
			if isKnownCMSFile(nameLower) {
				continue
			}
			// PHP phishing (3KB-100KB) - same brand/content analysis as HTML
			if size >= 3000 && size <= 100000 {
				result := analyzePHPForPhishing(fullPath)
				if result != nil {
					*findings = append(*findings, alert.Finding{
						Severity: alert.Critical,
						Check:    "phishing_php",
						Message:  fmt.Sprintf("PHP phishing page detected (%s): %s", result.brand, fullPath),
						Details: fmt.Sprintf("Account: %s\nBrand: %s\nScore: %d/10\nIndicators:\n- %s\nSize: %d bytes",
							user, result.brand, result.score, strings.Join(result.indicators, "\n- "), size),
						FilePath:  fullPath,
						Timestamp: time.Now(),
					})
				}
			}
			// PHP open redirector (tiny PHP files under 1KB)
			if size > 0 && size < 1024 {
				if result := checkPHPRedirector(fullPath); result != "" {
					*findings = append(*findings, alert.Finding{
						Severity:  alert.High,
						Check:     "phishing_redirector",
						Message:   fmt.Sprintf("PHP open redirector detected: %s", fullPath),
						Details:   fmt.Sprintf("Account: %s\n%s", user, result),
						FilePath:  fullPath,
						Timestamp: time.Now(),
					})
				}
			}
			continue
		}
		// --- Credential log files ---
		if isCredentialLogName(nameLower) && size > 0 && size < 10*1024*1024 {
			if result := checkCredentialLog(fullPath); result != "" {
				*findings = append(*findings, alert.Finding{
					Severity:  alert.Critical,
					Check:     "phishing_credential_log",
					Message:   fmt.Sprintf("Harvested credential log file detected: %s", fullPath),
					Details:   fmt.Sprintf("Account: %s\n%s", user, result),
					FilePath:  fullPath,
					Timestamp: time.Now(),
				})
			}
			continue
		}
		// --- Phishing kit ZIP archives ---
		if strings.HasSuffix(nameLower, ".zip") && size > 1000 && size < 50*1024*1024 {
			if isPhishingKitZip(nameLower) {
				*findings = append(*findings, alert.Finding{
					Severity: alert.High,
					Check:    "phishing_kit_archive",
					Message:  fmt.Sprintf("Suspected phishing kit archive: %s", fullPath),
					Details: fmt.Sprintf("Account: %s\nFilename: %s\nSize: %d bytes",
						user, name, size),
					FilePath:  fullPath,
					Timestamp: time.Now(),
				})
			}
		}
	}
}
// ---------------------------------------------------------------------------
// Layer 1: Content analysis
// ---------------------------------------------------------------------------
// phishingResult is the outcome of content analysis on a single file.
type phishingResult struct {
	brand      string   // impersonated brand name, or "Unknown" when flagged without a brand match
	score      int      // accumulated suspicion score (see analyzer decision thresholds)
	indicators []string // human-readable matched signals, included in the alert details
}
// analyzeHTMLForPhishing reads the first phishingReadSize bytes of the HTML
// file at path and scores it for phishing indicators. It returns a
// phishingResult when the score crosses the decision threshold, or nil when
// the page looks benign or cannot be read.
//
// Hard prerequisites (checked before any scoring): the page must contain a
// form or input element AND a credential-style (email/password) input.
// Pages lacking either are never flagged, regardless of other signals.
func analyzeHTMLForPhishing(path string) *phishingResult {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	buf := make([]byte, phishingReadSize)
	// A single short read suffices: the signals live near the top of
	// self-contained phishing pages.
	n, _ := f.Read(buf)
	if n == 0 {
		return nil
	}
	content := string(buf[:n])
	contentLower := strings.ToLower(content)
	// Must contain a form or input
	hasForm := strings.Contains(contentLower, "<form") ||
		strings.Contains(contentLower, "<input")
	if !hasForm {
		return nil
	}
	// Must contain email/password input
	hasCredentialInput := strings.Contains(contentLower, "type=\"email\"") ||
		strings.Contains(contentLower, "type=\"password\"") ||
		strings.Contains(contentLower, "type='email'") ||
		strings.Contains(contentLower, "type='password'") ||
		strings.Contains(contentLower, "name=\"email\"") ||
		strings.Contains(contentLower, "name=\"pass\"") ||
		strings.Contains(contentLower, "name=\"password\"") ||
		strings.Contains(contentLower, "name=\"login\"") ||
		strings.Contains(contentLower, "placeholder=\"email") ||
		strings.Contains(contentLower, "placeholder=\"you@") ||
		strings.Contains(contentLower, "placeholder=\"your email") ||
		strings.Contains(contentLower, "work or school email") ||
		strings.Contains(contentLower, "corporate email")
	if !hasCredentialInput {
		return nil
	}
	var indicators []string
	score := 0
	// --- Brand impersonation ---
	// First matching brand wins; a title match scores 3, a body-only match
	// scores 2 (body is not double-counted when the title already hit).
	brandMatch := ""
	titleContent := extractTitle(contentLower)
	for _, brand := range phishingBrands {
		titleHit := false
		bodyHit := false
		for _, tp := range brand.titlePatterns {
			if strings.Contains(titleContent, tp) {
				titleHit = true
				indicators = append(indicators, fmt.Sprintf("title impersonates '%s'", tp))
				score += 3
				break
			}
		}
		for _, bp := range brand.bodyPatterns {
			if strings.Contains(contentLower, bp) {
				bodyHit = true
				if !titleHit {
					indicators = append(indicators, fmt.Sprintf("body impersonates '%s'", bp))
					score += 2
				}
				break
			}
		}
		if titleHit || bodyHit {
			brandMatch = brand.name
			break
		}
	}
	// --- Credential harvesting / redirect indicators (+1 each) ---
	for _, pattern := range harvestIndicators {
		if strings.Contains(contentLower, pattern) {
			score++
			indicators = append(indicators, fmt.Sprintf("harvest: '%s'", pattern))
		}
	}
	// --- Exfiltration URL patterns (+2 each) ---
	for _, pattern := range exfilPatterns {
		if strings.Contains(contentLower, pattern) {
			score += 2
			indicators = append(indicators, fmt.Sprintf("exfiltration: '%s'", pattern))
		}
	}
	// --- Fake trust badges (+1 each) ---
	for _, pattern := range trustBadgePatterns {
		if strings.Contains(contentLower, pattern) {
			score++
			indicators = append(indicators, fmt.Sprintf("fake trust badge: '%s'", pattern))
		}
	}
	// --- Urgency language (+1 total, however many patterns matched) ---
	urgencyCount := 0
	for _, pattern := range urgencyPatterns {
		if strings.Contains(contentLower, pattern) {
			urgencyCount++
		}
	}
	if urgencyCount > 0 {
		score++
		indicators = append(indicators, fmt.Sprintf("urgency language (%d patterns)", urgencyCount))
	}
	// --- Embedded Base64 assets (logos embedded to avoid external loading, +1 total) ---
	embeddedCount := 0
	for _, pattern := range embeddedAssetPatterns {
		embeddedCount += countOccurrences(contentLower, pattern)
	}
	if embeddedCount > 0 {
		score++
		indicators = append(indicators, fmt.Sprintf("embedded base64 assets (%d)", embeddedCount))
	}
	// --- Form action pointing to external domain (+2) ---
	if hasExternalFormAction(content) {
		score += 2
		indicators = append(indicators, "form action points to external domain")
	}
	// --- Self-contained page (all CSS inline, no external stylesheets except CDN, +1) ---
	if isSelfContainedHTML(contentLower) {
		score++
		indicators = append(indicators, "self-contained HTML (all styles inline)")
	}
	// --- Person name as filename (+2) ---
	baseName := strings.TrimSuffix(strings.TrimSuffix(filepath.Base(path), ".html"), ".htm")
	if looksLikePersonName(baseName) {
		score += 2
		indicators = append(indicators, fmt.Sprintf("filename looks like person name: '%s'", baseName))
	}
	// --- Decision ---
	// With brand match: need score >= 4 (brand gives 2-3 + at least 1 other signal)
	// Without brand match: need score >= 6 (multiple strong signals)
	if brandMatch != "" && score >= 4 {
		return &phishingResult{brand: brandMatch, score: score, indicators: indicators}
	}
	if brandMatch == "" && score >= 6 {
		return &phishingResult{brand: "Unknown", score: score, indicators: indicators}
	}
	return nil
}
// ---------------------------------------------------------------------------
// Layer 2: Structural analysis helpers
// ---------------------------------------------------------------------------
// extractTitle returns the text between the first <title> and </title> tags
// of the already-lowercased HTML, or "" when no well-formed, non-empty title
// is present (including when </title> precedes <title>).
func extractTitle(contentLower string) string {
	const openTag = "<title>"
	begin := strings.Index(contentLower, openTag)
	end := strings.Index(contentLower, "</title>")
	if begin < 0 || end <= begin+len(openTag) {
		return ""
	}
	return contentLower[begin+len(openTag) : end]
}
// hasExternalFormAction reports whether the first quoted "action=" attribute
// value is an absolute http(s) URL - i.e. the form posts to another site,
// as phishing pages do when shipping credentials to an external collector.
// Relative actions, unquoted values, and missing attributes return false.
//
// Fix: all parsing now happens on the lowercased copy. The previous version
// computed offsets on strings.ToLower(content) but sliced the original
// string; ToLower can change byte length for some Unicode input (e.g.
// U+0130), which misaligned those offsets.
func hasExternalFormAction(content string) bool {
	lower := strings.ToLower(content)
	idx := strings.Index(lower, "action=")
	if idx < 0 {
		return false
	}
	rest := lower[idx+len("action="):]
	if len(rest) == 0 {
		return false
	}
	// The value must be quoted; remember which quote so we can find its mate.
	quote := rest[0]
	if quote != '"' && quote != '\'' {
		return false
	}
	end := strings.IndexByte(rest[1:], quote)
	if end < 0 {
		return false
	}
	url := rest[1 : end+1]
	// External means an absolute scheme, not a relative path.
	return strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://")
}
// isSelfContainedHTML reports whether the (lowercased) page carries its CSS
// inline - an embedded <style> block with at most one external stylesheet
// reference (phishing kits commonly allow a single CDN link such as Font
// Awesome while inlining everything else).
//
// Uses strings.Count instead of the hand-rolled countOccurrences helper; the
// patterns cannot overlap themselves, so non-overlapping counting is exact.
func isSelfContainedHTML(contentLower string) bool {
	if !strings.Contains(contentLower, "<style") {
		return false
	}
	externalCSS := strings.Count(contentLower, `rel="stylesheet"`) +
		strings.Count(contentLower, `rel='stylesheet'`)
	return externalCSS <= 1
}
// looksLikePersonName reports whether a filename base looks like a person's
// name in CamelCase - two or more capitalized words, letters only, and not a
// common web filename (e.g. "PalmerHamilton", "MarilynEsguerra"). Phishing
// kits frequently name the lure page after the targeted victim.
//
// Fix: word starts are now detected on a []rune view. The previous version
// indexed name[i-1] as a raw byte while ranging runes, which reads a UTF-8
// continuation byte for multi-byte letters and miscounts word boundaries.
func looksLikePersonName(name string) bool {
	if len(name) < 6 {
		return false
	}
	runes := []rune(name)
	// A "word start" is an uppercase rune at position 0 or following a
	// lowercase rune (CamelCase boundary).
	upperCount := 0
	for i, r := range runes {
		if unicode.IsUpper(r) && (i == 0 || unicode.IsLower(runes[i-1])) {
			upperCount++
		}
	}
	if upperCount < 2 {
		return false
	}
	// Names are letters only - digits/punctuation disqualify.
	for _, r := range runes {
		if !unicode.IsLetter(r) {
			return false
		}
	}
	// Exclude common web filenames.
	nameLower := strings.ToLower(name)
	webNames := []string{"index", "default", "portal", "login", "home", "main",
		"readme", "changelog", "license", "manifest", "service"}
	for _, w := range webNames {
		if nameLower == w {
			return false
		}
	}
	return true
}
// ---------------------------------------------------------------------------
// Layer 3: Directory anomaly detection
// ---------------------------------------------------------------------------
// analyzeDirectoryStructure checks if a directory looks like a phishing drop:
//   - contains only 1-3 HTML files, at most 1 other file, and no subdirectories
//   - the directory name resembles a business/organization name
//   - at least one HTML file passes a quick credential-input check
//
// It returns a ready-to-append Finding, or nil when the directory does not
// match the pattern (or cannot be read).
//
// Fixes: the Finding now carries Timestamp, consistent with the other checks;
// the unused totalFiles counter was removed.
func analyzeDirectoryStructure(dir string, user string) *alert.Finding {
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return nil
	}
	var htmlFiles []string
	otherFiles := 0
	for _, entry := range entries {
		if entry.IsDir() {
			return nil // Has subdirectories - likely not a simple phishing drop
		}
		name := entry.Name()
		if strings.HasPrefix(name, ".") {
			continue // Skip dotfiles (.htaccess etc.)
		}
		nameLower := strings.ToLower(name)
		if strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm") {
			htmlFiles = append(htmlFiles, name)
		} else {
			otherFiles++
		}
	}
	// Must have exactly 1-3 HTML files and at most 1 other file
	if len(htmlFiles) == 0 || len(htmlFiles) > 3 || otherFiles > 1 {
		return nil
	}
	// Directory name should look like a business/organization (CamelCase or multi-word)
	dirName := filepath.Base(dir)
	if !looksLikeBusinessName(dirName) {
		return nil
	}
	// Verify at least one HTML file has credential inputs (quick check)
	hasPhishingContent := false
	for _, htmlFile := range htmlFiles {
		fullPath := filepath.Join(dir, htmlFile)
		if quickPhishingCheck(fullPath) {
			hasPhishingContent = true
			break
		}
	}
	if !hasPhishingContent {
		return nil
	}
	// Build the indicator list for the alert body.
	indicators := []string{
		fmt.Sprintf("directory '%s' contains only %d HTML file(s)", dirName, len(htmlFiles)),
		fmt.Sprintf("directory name resembles business/organization: '%s'", dirName),
	}
	for _, h := range htmlFiles {
		baseName := strings.TrimSuffix(strings.TrimSuffix(h, ".html"), ".htm")
		if looksLikePersonName(baseName) {
			indicators = append(indicators, fmt.Sprintf("HTML filename looks like person name: '%s'", baseName))
		}
	}
	return &alert.Finding{
		Severity: alert.High,
		Check:    "phishing_directory",
		Message:  fmt.Sprintf("Suspected phishing directory (lone HTML in business-named folder): %s", dir),
		Details: fmt.Sprintf("Account: %s\nDirectory: %s\nHTML files: %s\nIndicators:\n- %s",
			user, dirName, strings.Join(htmlFiles, ", "), strings.Join(indicators, "\n- ")),
		FilePath:  dir,
		Timestamp: time.Now(),
	}
}
// looksLikeBusinessName reports whether a directory name reads like a
// business or organization name rather than a standard web directory.
// It accepts CamelCase names ("WashingtonGolf"), multi-word separated names
// ("federated-lighting"), and long all-letter names (12+ bytes), while
// rejecting tech/tutorial prefixes and well-known web directory names.
//
// Fix: CamelCase detection now works on a []rune view. The previous version
// called rune(name[i-1]) and rune(name[0]) on raw bytes, which is wrong for
// multi-byte (non-ASCII) letters.
func looksLikeBusinessName(name string) bool {
	if len(name) < 5 {
		return false
	}
	nameLower := strings.ToLower(name)
	// Skip names that start with tech/dev terms - these are tutorial
	// or test directories, not business names (e.g. "php-email-form",
	// "PHP-Login", "JavaScript Login")
	techPrefixes := []string{
		"php", "javascript", "js-", "css", "html", "python",
		"java", "node", "react", "vue", "angular", "jquery",
		"bootstrap", "wordpress", "wp-", "laravel",
	}
	for _, prefix := range techPrefixes {
		if strings.HasPrefix(nameLower, prefix) {
			return false
		}
	}
	// Skip standard web directories
	standardDirs := []string{
		"images", "img", "css", "js", "fonts", "assets", "static",
		"media", "uploads", "downloads", "files", "docs", "data",
		"api", "admin", "config", "templates", "scripts", "lib",
		"src", "dist", "build", "public", "private", "backup",
		"old", "new", "test", "dev", "staging", "demo",
	}
	for _, sd := range standardDirs {
		if nameLower == sd {
			return false
		}
	}
	// CamelCase detection (e.g., WashingtonGolf, XRFScientificAmericasInc):
	// a lower->upper transition with an uppercase first letter.
	runes := []rune(name)
	upperTransitions := 0
	for i := 1; i < len(runes); i++ {
		if unicode.IsUpper(runes[i]) && unicode.IsLower(runes[i-1]) {
			upperTransitions++
		}
	}
	if upperTransitions >= 1 && unicode.IsUpper(runes[0]) {
		return true
	}
	// Multi-word with separators (e.g., federated-lighting, northwest_crawlspace)
	if strings.ContainsAny(name, "-_") {
		parts := strings.FieldsFunc(name, func(r rune) bool { return r == '-' || r == '_' })
		if len(parts) >= 2 {
			return true
		}
	}
	// Long all-letter name that doesn't match standard dirs
	// (e.g., "healthcornerpediatrics").
	for _, r := range runes {
		if !unicode.IsLetter(r) {
			return false
		}
	}
	return len(name) >= 12
}
// quickPhishingCheck reads at most the first 4KB of the file at path and
// reports whether it contains both a form/input element and an email or
// password keyword. It is a cheap pre-filter used by the directory-anomaly
// layer; full scoring happens in analyzeHTMLForPhishing.
func quickPhishingCheck(path string) bool {
	f, err := osFS.Open(path)
	if err != nil {
		return false
	}
	defer func() { _ = f.Close() }()
	head := make([]byte, 4096)
	n, _ := f.Read(head)
	if n == 0 {
		return false
	}
	body := strings.ToLower(string(head[:n]))
	hasFormElement := strings.Contains(body, "<form") || strings.Contains(body, "<input")
	mentionsCredential := strings.Contains(body, "email") || strings.Contains(body, "password")
	return hasFormElement && mentionsCredential
}
// ---------------------------------------------------------------------------
// Layer 4: PHP phishing pages
// ---------------------------------------------------------------------------
// analyzePHPForPhishing reads a PHP file and checks for embedded HTML with
// brand impersonation. PHP phishing kits often have PHP code at the top
// (credential handling, emailing) and HTML output below.
//
// Gating: the file must either read credentials from $_POST/$_REQUEST or
// render a credential-style form, AND impersonate a known brand; without a
// brand match it always returns nil (contact forms and CMS admin pages
// legitimately combine $_POST['email'] with mail()).
func analyzePHPForPhishing(path string) *phishingResult {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	buf := make([]byte, phishingReadSize)
	n, _ := f.Read(buf)
	if n == 0 {
		return nil
	}
	content := string(buf[:n])
	contentLower := strings.ToLower(content)
	// PHP phishing indicators: credential handling code.
	// Only truly specific patterns belong here - generic functions like
	// mail() and fwrite() are handled separately with context checks below.
	phpPhishingPatterns := []string{
		"$_post['email']",
		"$_post['password']",
		"$_post['pass']",
		"$_post[\"email\"]",
		"$_post[\"password\"]",
		"$_post[\"pass\"]",
		"$_request['email']",
		"$_request['password']",
	}
	var indicators []string
	score := 0
	// Check for PHP credential handling (+2, first match only)
	phpCredHandling := false
	for _, pattern := range phpPhishingPatterns {
		if strings.Contains(contentLower, pattern) {
			phpCredHandling = true
			indicators = append(indicators, fmt.Sprintf("PHP credential handling: '%s'", pattern))
			score += 2
			break
		}
	}
	// Must have either PHP credential handling OR HTML form output
	hasForm := strings.Contains(contentLower, "<form") || strings.Contains(contentLower, "<input")
	hasCredentialInput := strings.Contains(contentLower, "type=\"email\"") ||
		strings.Contains(contentLower, "type=\"password\"") ||
		strings.Contains(contentLower, "type='email'") ||
		strings.Contains(contentLower, "type='password'") ||
		strings.Contains(contentLower, "name=\"email\"") ||
		strings.Contains(contentLower, "name=\"password\"") ||
		strings.Contains(contentLower, "name=\"pass\"")
	if !phpCredHandling && (!hasForm || !hasCredentialInput) {
		return nil
	}
	// Check brand impersonation (same patterns as the HTML check);
	// first match wins, title (+3) checked before body (+2).
	brandMatch := ""
	titleContent := extractTitle(contentLower)
	for _, brand := range phishingBrands {
		for _, tp := range brand.titlePatterns {
			if strings.Contains(titleContent, tp) {
				brandMatch = brand.name
				indicators = append(indicators, fmt.Sprintf("title impersonates '%s'", tp))
				score += 3
				break
			}
		}
		if brandMatch != "" {
			break
		}
		for _, bp := range brand.bodyPatterns {
			if strings.Contains(contentLower, bp) {
				brandMatch = brand.name
				indicators = append(indicators, fmt.Sprintf("body impersonates '%s'", bp))
				score += 2
				break
			}
		}
		if brandMatch != "" {
			break
		}
	}
	// Check harvest (+1 each) / exfil (+2 each) patterns
	for _, pattern := range harvestIndicators {
		if strings.Contains(contentLower, pattern) {
			score++
			indicators = append(indicators, fmt.Sprintf("harvest: '%s'", pattern))
		}
	}
	for _, pattern := range exfilPatterns {
		if strings.Contains(contentLower, pattern) {
			score += 2
			indicators = append(indicators, fmt.Sprintf("exfiltration: '%s'", pattern))
		}
	}
	// PHP-specific exfil: emailing or writing harvested credentials.
	// These only fire when the file also reads from $_POST/$_REQUEST,
	// because a phishing kit must capture form data before exfiltrating it.
	// Without this gate, any PHP file using fwrite()+config keywords triggers.
	hasPostData := strings.Contains(contentLower, "$_post") || strings.Contains(contentLower, "$_request")
	if hasPostData && strings.Contains(contentLower, "mail(") &&
		(strings.Contains(contentLower, "password") || strings.Contains(contentLower, "email")) {
		score += 2
		indicators = append(indicators, "PHP mail() with credential data")
	}
	if hasPostData &&
		(strings.Contains(contentLower, "fwrite(") || strings.Contains(contentLower, "file_put_contents(")) {
		if strings.Contains(contentLower, "password") || strings.Contains(contentLower, "email") ||
			strings.Contains(contentLower, "result") || strings.Contains(contentLower, "log") {
			score += 2
			indicators = append(indicators, "PHP writes credential data to file")
		}
	}
	// Require brand impersonation to flag as phishing.
	// PHP files with $_POST['email'] + mail() are normal (contact forms, CMS user
	// admin, gallery software). Without brand impersonation, these are almost
	// always legitimate applications.
	if brandMatch == "" {
		return nil
	}
	if score >= 4 {
		return &phishingResult{brand: brandMatch, score: score, indicators: indicators}
	}
	return nil
}
// isKnownCMSFile reports whether nameLower is a standard CMS entry-point or
// configuration file that the phishing scanner must skip - these files
// legitimately contain forms and credential handling and produced too many
// false positives.
//
// A switch over constant strings replaces the previous map literal, which
// was rebuilt (allocated) on every call.
func isKnownCMSFile(nameLower string) bool {
	switch nameLower {
	case "index.php", "wp-config.php", "wp-login.php",
		"wp-cron.php", "wp-settings.php", "wp-load.php",
		"wp-blog-header.php", "wp-links-opml.php",
		"xmlrpc.php", "wp-signup.php", "wp-activate.php",
		"wp-trackback.php", "wp-comments-post.php",
		"wp-mail.php", "configuration.php",
		"config.php", "settings.php":
		return true
	}
	return false
}
// ---------------------------------------------------------------------------
// Layer 5: PHP open redirectors
// ---------------------------------------------------------------------------
// checkPHPRedirector inspects a small PHP file and returns a short
// description when it behaves like an open redirector - a script that sends
// a Location header whose target is user-controlled or matches a known-bad
// destination pattern. It returns "" when the file looks harmless or cannot
// be read. Only the first 1KB is examined (callers gate on file size anyway).
func checkPHPRedirector(path string) string {
	f, err := osFS.Open(path)
	if err != nil {
		return ""
	}
	defer func() { _ = f.Close() }()
	head := make([]byte, 1024)
	n, _ := f.Read(head)
	if n == 0 {
		return ""
	}
	body := strings.ToLower(string(head[:n]))
	// A redirector must emit header("Location: ...").
	if !strings.Contains(body, "header(") {
		return ""
	}
	if !strings.Contains(body, "location:") && !strings.Contains(body, "location :") {
		return ""
	}
	// Pattern 1: the redirect target comes straight from request input.
	// Merely having $_GET somewhere plus header() is too broad - ordinary
	// form handlers read $_POST and then header() to a fixed page - so only
	// the parameterized-target spellings below count.
	userInputMarkers := []string{
		"$_get['url']", "$_get[\"url\"]",
		"$_get['redirect']", "$_get[\"redirect\"]",
		"$_get['r']", "$_get[\"r\"]",
		"$_get['return']", "$_get[\"return\"]",
		"$_get['next']", "$_get[\"next\"]",
		"$_get['goto']", "$_get[\"goto\"]",
		"$_get['link']", "$_get[\"link\"]",
		"$_request['url']", "$_request[\"url\"]",
		"$_request['redirect']", "$_request[\"redirect\"]",
		"header(\"location: \".$_get", "header(\"location: \".$_request",
		"header('location: '.$_get", "header('location: '.$_request",
		"header(\"location:\".$_get", "header('location:'.$_get",
	}
	for _, marker := range userInputMarkers {
		if strings.Contains(body, marker) {
			return "PHP open redirector: header(Location) with user-supplied URL"
		}
	}
	// Pattern 2: hardcoded redirect whose destination matches a known
	// exfiltration/shortener pattern.
	for _, pattern := range exfilPatterns {
		if strings.Contains(body, pattern) {
			return fmt.Sprintf("PHP redirect to suspicious destination matching '%s'", pattern)
		}
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 6: Credential log files
// ---------------------------------------------------------------------------
// isCredentialLogName reports whether a lowercased filename matches the
// naming habits of phishing-kit credential drops: either an exact well-known
// dump filename (results.txt, victims.txt, ...) or any name containing a
// telltale word such as "harvest" or "stolen".
func isCredentialLogName(nameLower string) bool {
	switch nameLower {
	case "results.txt", "result.txt", "log.txt",
		"logs.txt", "emails.txt", "data.txt",
		"passwords.txt", "creds.txt", "credentials.txt",
		"victims.txt", "output.txt", "harvested.txt",
		"results.log", "emails.log", "data.log",
		"results.csv", "emails.csv", "data.csv",
		"results.html":
		return true
	}
	for _, word := range []string{"result", "victim", "harvest", "credential", "creds", "stolen"} {
		if strings.Contains(nameLower, word) {
			return true
		}
	}
	return false
}
}
// checkCredentialLog reads the first 4KB of a text file and returns a short
// description when it looks like harvested credentials: 3+ lines in
// email<delim>password form, or 10+ email addresses in a non-CSV file.
// It returns "" when the file looks benign or cannot be read.
func checkCredentialLog(path string) string {
	f, err := osFS.Open(path)
	if err != nil {
		return ""
	}
	defer func() { _ = f.Close() }()
	// Read first 4KB - enough to establish the line format.
	buf := make([]byte, 4096)
	n, _ := f.Read(buf)
	if n == 0 {
		return ""
	}
	content := string(buf[:n])
	lines := strings.Split(content, "\n")
	credentialLines := 0 // lines shaped like email<delim>secret
	emailCount := 0      // lines containing "@" at all
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		// Pattern: email:password or email|password or email,password
		if strings.Contains(line, "@") {
			emailCount++
			// Check for delimiter after email; first delimiter that yields
			// a plausible email + space-free secret counts, then stop.
			for _, delim := range []string{":", "|", "\t", ","} {
				parts := strings.SplitN(line, delim, 3)
				if len(parts) >= 2 {
					part0 := strings.TrimSpace(parts[0])
					part1 := strings.TrimSpace(parts[1])
					if strings.Contains(part0, "@") && len(part1) > 0 && !strings.Contains(part1, " ") {
						credentialLines++
						break
					}
				}
			}
		}
	}
	// 3+ lines that look like email:password pairs = credential log
	if credentialLines >= 3 {
		return fmt.Sprintf("File contains %d credential-like lines (email:password format) out of %d email lines",
			credentialLines, emailCount)
	}
	// High density of emails alone (10+) in a non-.csv file is suspicious
	// (CSV exports legitimately hold many addresses).
	if emailCount >= 10 && !strings.HasSuffix(strings.ToLower(path), ".csv") {
		return fmt.Sprintf("File contains %d email addresses - possible harvested email list", emailCount)
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 7: Iframe phishing
// ---------------------------------------------------------------------------
// checkIframePhishing checks small HTML files for iframe-based phishing -
// a minimal HTML page that just loads an external phishing page in a
// full-screen iframe. Returns a description of the evidence, or "".
func checkIframePhishing(path string) string {
	f, err := osFS.Open(path)
	if err != nil {
		return ""
	}
	defer func() { _ = f.Close() }()

	head := make([]byte, 3000)
	n, _ := f.Read(head)
	if n == 0 {
		return ""
	}
	contentLower := strings.ToLower(string(head[:n]))

	// Locate the first iframe tag.
	idx := strings.Index(contentLower, "<iframe")
	if idx < 0 {
		return ""
	}
	rest := contentLower[idx:]
	endTag := strings.Index(rest, ">")
	if endTag < 0 {
		return ""
	}
	iframeTag := rest[:endTag+1]

	// Extract the quoted src attribute value from the tag.
	srcIdx := strings.Index(iframeTag, "src=")
	if srcIdx < 0 {
		return ""
	}
	srcRest := iframeTag[srcIdx+4:]
	if srcRest == "" {
		return ""
	}
	quote := srcRest[0]
	if quote != '"' && quote != '\'' {
		return ""
	}
	srcEnd := strings.IndexByte(srcRest[1:], quote)
	if srcEnd < 0 {
		return ""
	}
	src := srcRest[1 : srcEnd+1]

	// Only external targets (http:// or https://) matter.
	if !strings.HasPrefix(src, "http://") && !strings.HasPrefix(src, "https://") {
		return ""
	}

	// Full-screen styling on the tag or anywhere in the page body.
	fullscreen := strings.Contains(iframeTag, "100%") ||
		strings.Contains(contentLower, "width:100%") ||
		strings.Contains(contentLower, "width: 100%") ||
		strings.Contains(contentLower, "position:fixed") ||
		strings.Contains(contentLower, "position: fixed")
	if fullscreen {
		return fmt.Sprintf("Full-screen iframe loading external URL: %s", src)
	}

	// Even non-fullscreen iframes are flagged when the URL matches known
	// phishing/exfil patterns.
	for _, pattern := range exfilPatterns {
		if strings.Contains(src, pattern) {
			return fmt.Sprintf("Iframe loading suspicious external URL matching '%s': %s", pattern, src)
		}
	}
	return ""
}
// ---------------------------------------------------------------------------
// Layer 8: Phishing kit ZIP archives
// ---------------------------------------------------------------------------
// isPhishingKitZip checks if a ZIP filename matches common phishing kit names.
// Brand-impersonation keywords flag on a single hit; generic words such as
// "login" or "secure" appear in legitimate archives (UI kits, CSS kits) and
// need 2+ matches to reduce false positives.
func isPhishingKitZip(nameLower string) bool {
	// High-confidence keywords: brand impersonation in the filename.
	for _, kw := range []string{
		"office365", "office 365", "sharepoint", "onedrive",
		"microsoft", "outlook", "gmail",
		"dropbox", "docusign", "wetransfer",
		"paypal", "icloud", "netflix",
		"facebook", "instagram", "linkedin",
		"roundcube", "cpanel",
		"phish", "scam",
	} {
		if strings.Contains(nameLower, kw) {
			return true
		}
	}
	// Lower-confidence keywords - require at least two distinct matches.
	hits := 0
	for _, kw := range []string{
		"login", "verify", "secure", "bank",
		"google", "apple", "adobe", "webmail",
	} {
		if strings.Contains(nameLower, kw) {
			hits++
			if hits >= 2 {
				return true
			}
		}
	}
	return false
}
// ---------------------------------------------------------------------------
// Safe directory list
// ---------------------------------------------------------------------------
// isKnownSafeDir reports whether name is a well-known directory (WordPress
// core, dependency, or hosting infrastructure directory) that should not be
// flagged by name-based checks.
func isKnownSafeDir(name string) bool {
	// A switch over constant strings avoids rebuilding a lookup map on
	// every call (this runs once per directory entry during scans).
	switch name {
	case "wp-admin", "wp-includes", "wp-content",
		"node_modules", "vendor", ".git",
		"cgi-bin", ".well-known", "mail",
		"cache", "tmp", "logs":
		return true
	}
	return false
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// phpContentReadSize caps how much of each PHP file is read for analysis (32KB).
const phpContentReadSize = 32768

// CheckPHPContent scans new/suspicious PHP files for obfuscation patterns,
// remote payload fetching, and eval chains. This is designed to catch droppers
// like the LEVIATHAN attack's file.php and files.php that use goto spaghetti,
// hex-encoded strings, and call_user_func with string-built function names.
//
// This check scans PHP files in directories that shouldn't normally contain
// user-authored PHP: wp-content/languages, wp-content/upgrade, wp-content/mu-plugins,
// and also checks any PHP files flagged by the file index as new.
func CheckPHPContent(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	homeDirs, err := GetScanHomeDirs()
	if err != nil {
		return nil
	}
	var findings []alert.Finding
	for _, homeEntry := range homeDirs {
		if ctx.Err() != nil {
			return findings
		}
		if !homeEntry.IsDir() {
			continue
		}
		homeDir := filepath.Join("/home", homeEntry.Name())

		// Candidate document roots: public_html plus addon-domain
		// directories directly under the home dir (excluding service dirs).
		roots := []string{filepath.Join(homeDir, "public_html")}
		entries, _ := osFS.ReadDir(homeDir)
		for _, e := range entries {
			if !e.IsDir() || strings.HasPrefix(e.Name(), ".") {
				continue
			}
			switch e.Name() {
			case "public_html", "mail", "etc", "logs", "ssl", "tmp":
				continue
			}
			roots = append(roots, filepath.Join(homeDir, e.Name()))
		}

		for _, root := range roots {
			// Directories that should not normally contain user-authored PHP.
			for _, sub := range []string{
				filepath.Join(root, "wp-content", "languages"),
				filepath.Join(root, "wp-content", "upgrade"),
				filepath.Join(root, "wp-content", "mu-plugins"),
				filepath.Join(root, "wp-content", "plugins"),
				filepath.Join(root, "wp-content", "themes"),
			} {
				scanDirForObfuscatedPHP(ctx, sub, 4, cfg, &findings)
				if ctx.Err() != nil {
					return findings
				}
			}
		}
	}
	return findings
}
// scanDirForObfuscatedPHP recursively scans dir (to maxDepth levels) for PHP
// files with malicious content patterns, appending any findings.
func scanDirForObfuscatedPHP(ctx context.Context, dir string, maxDepth int, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
entryLoop:
	for _, entry := range entries {
		if ctx.Err() != nil {
			return
		}
		name := entry.Name()
		fullPath := filepath.Join(dir, name)

		// Honor operator-configured path suppressions.
		for _, ignore := range cfg.Suppressions.IgnorePaths {
			if matchGlob(fullPath, ignore) {
				continue entryLoop
			}
		}

		if entry.IsDir() {
			scanDirForObfuscatedPHP(ctx, fullPath, maxDepth-1, cfg, findings)
			continue
		}
		if !strings.HasSuffix(strings.ToLower(name), ".php") {
			continue
		}
		// Skip known-legitimate PHP files in WP directories.
		if isSafePHPInWPDir(fullPath, name) {
			continue
		}

		// Read and analyze content; severity < 0 means clean.
		result := analyzePHPContent(fullPath)
		if result.severity < 0 {
			continue
		}
		details := result.details
		if info, _ := osFS.Stat(fullPath); info != nil {
			details += fmt.Sprintf("\nSize: %d, Mtime: %s", info.Size(), info.ModTime().Format("2006-01-02 15:04:05"))
		}
		*findings = append(*findings, alert.Finding{
			Severity: result.severity,
			Check:    result.check,
			Message:  fmt.Sprintf("%s: %s", result.message, fullPath),
			Details:  details,
			FilePath: fullPath,
		})
	}
}
// phpAnalysisResult is the outcome of analyzePHPContent for a single file.
type phpAnalysisResult struct {
	severity alert.Severity // -1 means "clean - no finding"
	check    string         // finding check identifier, e.g. "obfuscated_php"
	message  string         // short human-readable summary
	details  string         // newline-joined indicator list for the alert body
}
// analyzePHPContent reads the first 32KB (phpContentReadSize) of a PHP file
// and checks for obfuscation and malicious patterns.
//
// It returns severity -1 when the file is clean or unreadable; otherwise
// Critical for strong signals (remote payloads, obfuscated call_user_func,
// 2+ indicators) and High for a single weaker indicator.
func analyzePHPContent(path string) phpAnalysisResult {
	f, err := osFS.Open(path)
	if err != nil {
		return phpAnalysisResult{severity: -1}
	}
	defer func() { _ = f.Close() }()
	buf := make([]byte, phpContentReadSize)
	n, _ := f.Read(buf)
	if n == 0 {
		return phpAnalysisResult{severity: -1}
	}
	// Keep both the raw content (case-sensitive patterns like `"\x`) and a
	// lowercased copy (case-insensitive function/URL matching).
	content := string(buf[:n])
	contentLower := strings.ToLower(content)
	var indicators []string
	// --- Critical: Remote payload fetching ---
	// Paste sites are always suspicious in PHP files.
	// GitHub raw URLs are common in legitimate plugin update checkers,
	// so only count them as indicators when they appear on the same line
	// as a dangerous PHP function call.
	pasteHosts := []string{
		"pastebin.com/raw",
		"paste.ee/r/",
		"ghostbin.co/paste/",
		"hastebin.com/raw/",
	}
	for _, host := range pasteHosts {
		if strings.Contains(contentLower, host) {
			indicators = append(indicators, fmt.Sprintf("remote payload URL: %s", host))
		}
	}
	githubHosts := []string{"gist.githubusercontent.com", "raw.githubusercontent.com"}
	dangerousCalls := []string{"file_put_contents(", "fwrite(", "shell_", "passthru(", "popen("}
	for _, host := range githubHosts {
		if !strings.Contains(contentLower, host) {
			continue
		}
		// Same-line = strong signal (critical)
		sameLine := false
		for _, line := range strings.Split(contentLower, "\n") {
			if !strings.Contains(line, host) {
				continue
			}
			for _, fn := range dangerousCalls {
				if strings.Contains(line, fn) {
					indicators = append(indicators, fmt.Sprintf("remote payload URL with dangerous call: %s", host))
					sameLine = true
					break
				}
			}
			if sameLine {
				break
			}
		}
		// Co-presence = weaker signal (contributes to multi-indicator scoring)
		if !sameLine {
			for _, fn := range dangerousCalls {
				if strings.Contains(contentLower, fn) {
					indicators = append(indicators, fmt.Sprintf("remote URL co-present with %s: %s", fn, host))
					break
				}
			}
		}
	}
	// --- Critical: eval() chains with decoding ---
	// Only flag when eval directly wraps a decoder (structural nesting),
	// not when they merely co-exist in the same file (which causes false
	// positives on legitimate plugins that use eval for templates and
	// base64_decode for unrelated data processing).
	decoders := []string{
		"base64_decode", "gzinflate", "gzuncompress", "str_rot13",
		"rawurldecode", "gzdecode", "bzdecompress",
	}
	hasDecoder := false
	hasNestedEvalDecode := false
	for _, d := range decoders {
		if strings.Contains(contentLower, d) {
			hasDecoder = true
		}
	}
	// Check for structural nesting: eval(base64_decode(...)), eval(gzinflate(...)), etc.
	// Scan individual lines for the nesting pattern to avoid flagging unrelated
	// occurrences on distant lines.
	for _, line := range strings.Split(contentLower, "\n") {
		for _, d := range decoders {
			if strings.Contains(line, "eval(") && strings.Contains(line, d+"(") {
				hasNestedEvalDecode = true
				break
			}
			// assert() is an eval-equivalent in older PHP.
			if strings.Contains(line, "assert(") && strings.Contains(line, d+"(") {
				hasNestedEvalDecode = true
				break
			}
		}
		if hasNestedEvalDecode {
			break
		}
	}
	if hasNestedEvalDecode {
		indicators = append(indicators, "eval() directly wrapping encoding/compression function")
	}
	// --- Critical: call_user_func with string-built function names ---
	// This is the exact technique used in the LEVIATHAN droppers
	if strings.Contains(contentLower, "call_user_func") {
		// Check if function names are built from string concatenation
		if countOccurrences(content, `"\x`) > 5 || countOccurrences(content, "\" . \"") > 10 {
			indicators = append(indicators, "call_user_func with obfuscated function names")
		}
	}
	// --- High: Goto obfuscation (LEVIATHAN signature) ---
	gotoCount := countOccurrences(contentLower, "goto ")
	if gotoCount > 10 {
		indicators = append(indicators, fmt.Sprintf("excessive goto statements (%d found - obfuscation pattern)", gotoCount))
	}
	// --- High: Hex-encoded string construction ---
	// Only flag hex strings when accompanied by concatenation - real obfuscation
	// builds function names like "\x63" . "\x75" . "\x72" . "\x6c" (= "curl").
	// Standalone hex arrays (Wordfence IPv6 subnet masks, binary data) are benign.
	hexStringCount := countOccurrences(content, `"\x`)
	dotConcatCount := countOccurrences(content, `" . "`)
	if hexStringCount > 20 && dotConcatCount > 10 {
		indicators = append(indicators, fmt.Sprintf("heavy hex-encoded strings with concatenation (%d hex, %d concat - obfuscation pattern)", hexStringCount, dotConcatCount))
	} else if dotConcatCount > 30 {
		indicators = append(indicators, fmt.Sprintf("excessive string concatenation (%d - function name obfuscation)", dotConcatCount))
	}
	// --- High: Variable function calls with obfuscated names ---
	// call_user_func + decoder alone is too broad - Elementor, WooCommerce, and
	// dozens of plugins use call_user_func_array with base64_decode legitimately.
	// Only flag when combined with actual obfuscation (hex strings, heavy concat).
	if strings.Contains(contentLower, "call_user_func") && hasDecoder {
		if hexStringCount > 5 || dotConcatCount > 5 {
			indicators = append(indicators, "variable function call with decoder and obfuscation")
		}
	}
	// --- High: Shell execution functions combined with request input ---
	// Uses containsStandaloneFunc to avoid substring false positives
	// (e.g. "WP_Filesystem(" matching "exec(", "preg_match(" matching "exec(")
	shellFuncs := []string{"system(", "passthru(", "exec(", "shell_exec(", "popen(", "proc_open(", "pcntl_exec("}
	requestVars := []string{"$_request", "$_post", "$_get", "$_cookie", "$_server"}
	// Two-tier detection:
	// Same line = CRITICAL signal (auto-quarantine eligible)
	// Co-presence = HIGH signal (alert only, not quarantined alone)
	// This prevents bypass by splitting across lines while avoiding
	// false-positive quarantine of legitimate plugins.
	hasShellFunc := false
	hasRequestVar := false
	sameLineShellRequest := false
	for _, sf := range shellFuncs {
		if containsStandaloneFunc(contentLower, sf) {
			hasShellFunc = true
			break
		}
	}
	for _, rv := range requestVars {
		if strings.Contains(contentLower, rv) {
			hasRequestVar = true
			break
		}
	}
	if hasShellFunc && hasRequestVar {
		// Check for same-line (strong signal)
		for _, line := range strings.Split(contentLower, "\n") {
			lineHasShell := false
			for _, sf := range shellFuncs {
				if containsStandaloneFunc(line, sf) {
					lineHasShell = true
					break
				}
			}
			if !lineHasShell {
				continue
			}
			for _, rv := range requestVars {
				if strings.Contains(line, rv) {
					sameLineShellRequest = true
					break
				}
			}
			if sameLineShellRequest {
				break
			}
		}
		if sameLineShellRequest {
			indicators = append(indicators, "shell function with request input on same line")
		} else if !IsVerifiedCMSFile(path) {
			// Co-presence in non-verified file = weaker signal
			indicators = append(indicators, "shell function co-present with request input")
		}
	}
	// --- High: base64 encoding/decoding with execution on same line ---
	// Typical of web-shell command relays: decode a command, run it,
	// re-encode the output for the response.
	if strings.Contains(contentLower, "base64_decode") && strings.Contains(contentLower, "base64_encode") {
		for _, line := range strings.Split(contentLower, "\n") {
			hasBoth := strings.Contains(line, "base64_decode") && strings.Contains(line, "base64_encode")
			hasExec := false
			for _, sf := range shellFuncs {
				if containsStandaloneFunc(line, sf) {
					hasExec = true
					break
				}
			}
			if hasBoth && hasExec {
				indicators = append(indicators, "base64 encode+decode with execution on same line (command relay)")
				break
			}
		}
	}
	// --- Determine severity based on indicators ---
	if len(indicators) == 0 {
		return phpAnalysisResult{severity: -1}
	}
	// Multiple indicators or remote payloads = critical
	if len(indicators) >= 2 || containsAny(indicators, "remote payload", "call_user_func with obfuscated") {
		return phpAnalysisResult{
			severity: alert.Critical,
			check:    "obfuscated_php",
			message:  "Obfuscated/malicious PHP detected",
			details:  fmt.Sprintf("Indicators found:\n- %s", strings.Join(indicators, "\n- ")),
		}
	}
	return phpAnalysisResult{
		severity: alert.High,
		check:    "suspicious_php_content",
		message:  "Suspicious PHP content detected",
		details:  fmt.Sprintf("Indicators found:\n- %s", strings.Join(indicators, "\n- ")),
	}
}
// isSafePHPInWPDir returns true for known legitimate PHP files in WP directories
// like languages (translation files) and upgrade (empty index.php).
func isSafePHPInWPDir(path, name string) bool {
	nameLower := strings.ToLower(name)

	// Translation files (*.l10n.php) and placeholder index.php files are benign.
	if strings.HasSuffix(nameLower, ".l10n.php") || nameLower == "index.php" {
		return true
	}

	if strings.Contains(path, "/wp-content/languages/") {
		// Legitimate translation files follow standard naming.
		if strings.HasPrefix(nameLower, "admin-") ||
			strings.HasPrefix(nameLower, "continents-") ||
			strings.Contains(path, "/languages/plugins/") ||
			strings.Contains(path, "/languages/themes/") {
			return true
		}
		// WP 6.5+ PHP translation files: xx_XX.php (2-5 letter locale codes).
		base := strings.TrimSuffix(nameLower, ".php")
		if strings.Contains(base, "_") && len(base) <= 10 && !strings.ContainsAny(base, " /.\\") {
			return true
		}
	}

	// Common hosting-provider mu-plugins.
	if strings.Contains(path, "/mu-plugins/") {
		for _, safe := range []string{
			"endurance", "starter", "imunify", "wp-toolkit",
			"starter-plugin", "starter_plugin",
			"jetpack", "object-cache", "redis-cache",
			"cloudlinux", "alt-php",
		} {
			if strings.Contains(nameLower, safe) {
				return true
			}
		}
	}

	// vendor/ and node_modules/ under plugins/themes are third-party
	// dependencies and should not be flagged.
	if strings.Contains(path, "/wp-content/plugins/") || strings.Contains(path, "/wp-content/themes/") {
		if strings.Contains(path, "/vendor/") || strings.Contains(path, "/node_modules/") {
			return true
		}
	}
	return false
}
// countOccurrences returns the number of non-overlapping occurrences of
// substr in s. Delegates to strings.Count, which has the same advancing
// semantics as the previous hand-rolled loop but also handles an empty
// substr safely (the old loop never advanced its offset for "" and spun
// forever; strings.Count returns len(s)+1 per the stdlib convention).
func countOccurrences(s, substr string) int {
	return strings.Count(s, substr)
}
// containsStandaloneFunc checks if content contains a function call like "eval("
// without it being part of a longer function name (e.g. "doubleval(").
// Requires the character before the match to be non-alphanumeric or start-of-string.
func containsStandaloneFunc(content, funcCall string) bool {
	for start := 0; ; {
		rel := strings.Index(content[start:], funcCall)
		if rel < 0 {
			return false
		}
		pos := start + rel
		if pos == 0 {
			return true // match at the very start of content
		}
		// An identifier character before the match means this occurrence is
		// the tail of a longer name - keep searching past it.
		switch c := content[pos-1]; {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9', c == '_':
			start = pos + len(funcCall)
			if start >= len(content) {
				return false
			}
		default:
			return true
		}
	}
}
// containsAny reports whether any string in strs contains any of substrs
// as a substring.
func containsAny(strs []string, substrs ...string) bool {
	for _, sub := range substrs {
		for _, s := range strs {
			if strings.Contains(s, sub) {
				return true
			}
		}
	}
	return false
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckPHPConfigChanges monitors .user.ini and .htaccess for PHP configuration
// changes that weaken security (disabling disable_functions, enabling dangerous functions).
// This runs as a deep check. The fanotify watcher also catches .user.ini writes in real-time.
func CheckPHPConfigChanges(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	var findings []alert.Finding
	homeDirs, _ := GetScanHomeDirs()
	for _, homeEntry := range homeDirs {
		if !homeEntry.IsDir() {
			continue
		}
		user := homeEntry.Name()
		home := filepath.Join("/home", user)

		// .user.ini in public_html plus any addon-domain docroots.
		iniPaths := []string{filepath.Join(home, "public_html", ".user.ini")}
		subDirs, _ := osFS.ReadDir(home)
		for _, sd := range subDirs {
			if !sd.IsDir() || strings.HasPrefix(sd.Name(), ".") {
				continue
			}
			switch sd.Name() {
			case "public_html", "mail", "etc", "logs", "ssl", "tmp":
				continue
			}
			candidate := filepath.Join(home, sd.Name(), ".user.ini")
			if _, err := osFS.Stat(candidate); err == nil {
				iniPaths = append(iniPaths, candidate)
			}
		}

		for _, iniPath := range iniPaths {
			// Hash-based change detection: only analyze when content changed
			// since the last stored hash.
			hash, err := hashFileContent(iniPath)
			if err != nil {
				continue
			}
			key := "_phpini:" + iniPath
			prev, exists := store.GetRaw(key)
			store.SetRaw(key, hash)
			if !exists || prev == hash {
				continue
			}
			data, err := osFS.ReadFile(iniPath)
			if err != nil {
				continue
			}
			dangerous := analyzePHPINI(strings.ToLower(string(data)))
			if len(dangerous) == 0 {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "php_config_change",
				Message:  fmt.Sprintf("Dangerous PHP config change: %s (user: %s)", iniPath, user),
				Details:  fmt.Sprintf("Dangerous settings:\n- %s", strings.Join(dangerous, "\n- ")),
			})
		}
	}
	return findings
}
// analyzePHPINI scans (already-lowercased) php.ini/.user.ini content for
// settings that weaken security and returns a human-readable description
// for each dangerous setting found.
func analyzePHPINI(content string) []string {
	var dangerous []string
	lines := strings.Split(content, "\n")

	// disable_functions being cleared or reduced.
	if strings.Contains(content, "disable_functions") {
		for _, line := range lines {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, ";") || strings.HasPrefix(line, "#") {
				continue
			}
			if strings.HasPrefix(line, "disable_functions") {
				// Check if it's being set to empty or very short.
				parts := strings.SplitN(line, "=", 2)
				if len(parts) == 2 {
					val := strings.TrimSpace(parts[1])
					if val == "" || val == "\"\"" || val == "''" || val == "none" {
						dangerous = append(dangerous, "disable_functions cleared (all PHP functions enabled)")
					}
				}
			}
		}
	}

	// Dangerous functions missing from the disable list.
	dangerousFuncs := []string{
		"exec", "system", "passthru", "shell_exec",
		"popen", "proc_open", "pcntl_exec",
	}
	if strings.Contains(content, "disable_functions") {
		for _, line := range lines {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, "disable_functions") {
				for _, fn := range dangerousFuncs {
					// NOTE: substring match - "exec" is considered covered
					// when "shell_exec" is present. A token-level parse
					// would be stricter but also noisier.
					if !strings.Contains(line, fn) {
						dangerous = append(dangerous, fmt.Sprintf("%s not in disable_functions", fn))
					}
				}
				break
			}
		}
	}

	// allow_url_include being enabled (remote code inclusion).
	// Bug fix: the previous check matched "on"/"1" ANYWHERE in the line,
	// which flagged commented-out lines ("; allow_url_include = on") and
	// benign values with inline comments ("allow_url_include = off ; was 1").
	// Parse the actual value instead.
	if strings.Contains(content, "allow_url_include") {
		for _, line := range lines {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, ";") || strings.HasPrefix(line, "#") {
				continue
			}
			if !strings.HasPrefix(line, "allow_url_include") {
				continue
			}
			parts := strings.SplitN(line, "=", 2)
			if len(parts) != 2 {
				continue
			}
			val := strings.TrimSpace(parts[1])
			// Strip a trailing inline comment before reading the value.
			if i := strings.IndexAny(val, ";#"); i >= 0 {
				val = strings.TrimSpace(val[:i])
			}
			val = strings.Trim(val, "\"'")
			switch val {
			case "on", "1", "true", "yes":
				dangerous = append(dangerous, "allow_url_include enabled (remote code inclusion)")
			}
		}
	}

	// open_basedir being removed.
	if strings.Contains(content, "open_basedir") {
		for _, line := range lines {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, "open_basedir") {
				parts := strings.SplitN(line, "=", 2)
				if len(parts) == 2 {
					val := strings.TrimSpace(parts[1])
					if val == "" || val == "/" || val == "\"\"" {
						dangerous = append(dangerous, "open_basedir cleared or set to / (no restriction)")
					}
				}
			}
		}
	}
	return dangerous
}
package checks
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
neturl "net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
// wpOrgHTTPClient is the shared HTTP client for WordPress.org plugin API
// lookups; the 10s timeout bounds each request end-to-end.
var wpOrgHTTPClient = &http.Client{Timeout: 10 * time.Second}

// wpOrgResponse mirrors the fields consumed from the WordPress.org plugin
// information API response.
type wpOrgResponse struct {
	Slug    string `json:"slug"`
	Version string `json:"version"`
	Tested  string `json:"tested"`
	Error   string `json:"error"` // non-empty on API errors, e.g. "Plugin not found."
}
// parseWPOrgPluginResponse parses a JSON body from the WordPress.org plugin
// information API into a store.PluginInfo. It returns an error if the response
// contains an error field (e.g. "Plugin not found.") or if the JSON is invalid.
func parseWPOrgPluginResponse(body []byte) (store.PluginInfo, error) {
	var parsed wpOrgResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: invalid JSON: %w", err)
	}
	if parsed.Error != "" {
		return store.PluginInfo{}, fmt.Errorf("wporg: %s", parsed.Error)
	}
	info := store.PluginInfo{
		LatestVersion: parsed.Version,
		TestedUpTo:    parsed.Tested,
		LastChecked:   time.Now().Unix(),
	}
	return info, nil
}
// fetchWPOrgPluginInfo queries the WordPress.org plugin information API for the
// given slug and returns the parsed PluginInfo.
//
// The status code is checked before the body is read (no point buffering an
// error page), and the read is capped at 1MB as a defensive limit against a
// misbehaving endpoint.
func fetchWPOrgPluginInfo(ctx context.Context, slug string) (store.PluginInfo, error) {
	url := "https://api.wordpress.org/plugins/info/1.2/?action=plugin_information" +
		"&request[slug]=" + neturl.QueryEscape(slug) +
		"&request[fields][version]=1&request[fields][tested]=1"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: building request: %w", err)
	}
	resp, err := wpOrgHTTPClient.Do(req)
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: HTTP request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return store.PluginInfo{}, fmt.Errorf("wporg: unexpected status %d", resp.StatusCode)
	}
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return store.PluginInfo{}, fmt.Errorf("wporg: reading response body: %w", err)
	}
	return parseWPOrgPluginResponse(body)
}
// parseVersion splits a dotted version string like "6.4.2" into []int{6, 4, 2}.
// Non-numeric segments are treated as 0. An empty string yields nil.
func parseVersion(v string) []int {
	if v == "" {
		return nil
	}
	segments := strings.Split(v, ".")
	nums := make([]int, 0, len(segments))
	for _, seg := range segments {
		n, err := strconv.Atoi(seg)
		if err != nil {
			n = 0
		}
		nums = append(nums, n)
	}
	return nums
}
// compareVersions returns whether there is a major version gap and
// how many minor versions behind the installed version is. Versions with
// fewer than two numeric components are treated as incomparable.
func compareVersions(installed, available string) (majorGap bool, minorBehind int) {
	iv, av := parseVersion(installed), parseVersion(available)
	if len(iv) < 2 || len(av) < 2 {
		return false, 0
	}
	switch {
	case av[0] > iv[0]:
		return true, 0
	case av[0] == iv[0] && av[1] > iv[1]:
		return false, av[1] - iv[1]
	default:
		return false, 0
	}
}
// pluginAlertSeverity returns "critical", "high", "warning", or "" for a version gap.
func pluginAlertSeverity(installed, available string) string {
	majorGap, minorBehind := compareVersions(installed, available)
	if majorGap {
		return "critical"
	}
	if minorBehind >= 3 {
		return "high"
	}
	iv := parseVersion(installed)
	av := parseVersion(available)
	if len(iv) < 2 || len(av) < 2 {
		return ""
	}
	// Component-wise comparison, zero-padding the shorter version:
	// available newer at the first differing component -> "warning";
	// installed ahead (custom/premium builds) or identical -> no finding.
	n := len(iv)
	if len(av) > n {
		n = len(av)
	}
	for i := 0; i < n; i++ {
		var have, want int
		if i < len(iv) {
			have = iv[i]
		}
		if i < len(av) {
			want = av[i]
		}
		if want > have {
			return "warning"
		}
		if have > want {
			return ""
		}
	}
	return ""
}
// pluginCheckWorkers bounds the number of concurrent wp-cli inventory runs.
const pluginCheckWorkers = 5

// CheckOutdatedPlugins scans all WordPress installations for plugins with
// available updates and emits findings based on severity of the version gap.
// Results are cached in bbolt with a configurable refresh interval (default 24h).
func CheckOutdatedPlugins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}
	// Refresh the cache only once it has gone stale.
	staleAfter := time.Duration(cfg.Thresholds.PluginCheckIntervalMin) * time.Minute
	if time.Since(db.GetPluginRefreshTime()) > staleAfter {
		refreshPluginCache(ctx, db)
	}
	return evaluatePluginCache(db)
}
// findAllWPInstalls discovers all wp-config.php files under /home, deduplicating
// and skipping cache/backup/staging/trash paths.
func findAllWPInstalls() []string {
	patterns := []string{
		"/home/*/public_html/wp-config.php",
		"/home/*/public_html/*/wp-config.php",
		"/home/*/*/wp-config.php",
	}
	skipSubstrings := []string{"/cache/", "/backup", "/staging", "/.trash/"}
	seen := make(map[string]bool)
	var results []string
	for _, pattern := range patterns {
		matches, _ := osFS.Glob(pattern)
	matchLoop:
		for _, m := range matches {
			lower := strings.ToLower(m)
			for _, sub := range skipSubstrings {
				if strings.Contains(lower, sub) {
					continue matchLoop
				}
			}
			if seen[m] {
				continue
			}
			seen[m] = true
			results = append(results, m)
		}
	}
	return results
}
// wpCLIPluginEntry mirrors the JSON output of `wp plugin list --format=json`.
type wpCLIPluginEntry struct {
	Name          string `json:"name"`           // plugin slug as reported by wp-cli
	Status        string `json:"status"`         // e.g. "active", "inactive", "active-network"
	Version       string `json:"version"`        // installed version
	UpdateVersion string `json:"update_version"` // available update version, "" if none known
}
// refreshPluginCache discovers all WP installs, runs wp-cli to inventory
// plugins for each site, enriches free plugins via the WordPress.org API,
// and stores everything in bbolt.
//
// Inventory work is fanned out to pluginCheckWorkers goroutines over a
// buffered jobs channel; mu guards successCount, slugsSeen, and
// discoveredPaths, which the workers mutate concurrently.
func refreshPluginCache(ctx context.Context, db *store.DB) {
	wpConfigs := findAllWPInstalls()
	if len(wpConfigs) == 0 {
		return
	}
	var mu sync.Mutex
	var wg sync.WaitGroup
	successCount := 0                        // sites whose inventory stored OK (guarded by mu)
	slugsSeen := make(map[string]bool)       // unique plugin slugs across all sites (guarded by mu)
	discoveredPaths := make(map[string]bool) // WP paths seen this run, used for pruning (guarded by mu)
	jobs := make(chan string, len(wpConfigs))
	for i := 0; i < pluginCheckWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for wpConfig := range jobs {
				if ctx.Err() != nil {
					return
				}
				wpPath := filepath.Dir(wpConfig)
				user := extractUser(wpPath)
				domain := extractWPDomain(ctx, wpPath, user)
				if ctx.Err() != nil {
					return
				}
				mu.Lock()
				discoveredPaths[wpPath] = true
				mu.Unlock()
				// Run wp plugin list as the site owner.
				// Use --path flag instead of shell cd to avoid shell injection
				// via crafted directory names on shared hosting.
				// Routed through cmdExec so tests can mock the wp-cli output.
				out, err := cmdExec.RunContext(ctx, "su", "-", user, "-s", "/bin/bash", "-c",
					"wp plugin list --fields=name,status,version,update_version --format=json --path="+shellQuote(wpPath),
				)
				if err != nil {
					if ctx.Err() != nil {
						return
					}
					fmt.Fprintf(os.Stderr, "plugincheck: wp-cli failed for %s: %v\n", wpPath, err)
					continue
				}
				var entries []wpCLIPluginEntry
				if err := json.Unmarshal(out, &entries); err != nil {
					fmt.Fprintf(os.Stderr, "plugincheck: JSON parse failed for %s: %v\n", wpPath, err)
					continue
				}
				sitePlugins := store.SitePlugins{
					Account: user,
					Domain:  domain,
				}
				for _, e := range entries {
					sitePlugins.Plugins = append(sitePlugins.Plugins, store.SitePluginEntry{
						Slug:             e.Name,
						Name:             e.Name,
						Status:           e.Status,
						InstalledVersion: e.Version,
						UpdateVersion:    e.UpdateVersion,
					})
					mu.Lock()
					slugsSeen[e.Name] = true
					mu.Unlock()
				}
				if err := db.SetSitePlugins(wpPath, sitePlugins); err != nil {
					fmt.Fprintf(os.Stderr, "plugincheck: store failed for %s: %v\n", wpPath, err)
					continue
				}
				mu.Lock()
				successCount++
				mu.Unlock()
			}
		}()
	}
	for _, wpConfig := range wpConfigs {
		if ctx.Err() != nil {
			break
		}
		jobs <- wpConfig
	}
	close(jobs)
	wg.Wait()
	if ctx.Err() != nil {
		return
	}
	// Enrich free plugins via WordPress.org API (one lookup per unique slug).
	// All workers have exited (wg.Wait above); the lock is kept for consistency.
	mu.Lock()
	slugList := make([]string, 0, len(slugsSeen))
	for slug := range slugsSeen {
		slugList = append(slugList, slug)
	}
	mu.Unlock()
	for _, slug := range slugList {
		if ctx.Err() != nil {
			return
		}
		// Skip if we have a recent cached entry (< 24h).
		if cached, ok := db.GetPluginInfo(slug); ok {
			if time.Since(time.Unix(cached.LastChecked, 0)) < 24*time.Hour {
				continue
			}
		}
		info, err := fetchWPOrgPluginInfo(ctx, slug)
		if err != nil {
			// Not found on .org = premium/custom plugin, skip silently.
			continue
		}
		_ = db.SetPluginInfo(slug, info)
	}
	// Prune cache entries for WP installs no longer on disk.
	allCached := db.AllSitePlugins()
	for path := range allCached {
		if !discoveredPaths[path] {
			_ = db.DeleteSitePlugins(path)
		}
	}
	// Only mark refresh as complete if the majority of sites refreshed
	// successfully. A partial failure (e.g. one wp-cli timeout on a 100-site
	// server) should not freeze ALL stale data for 24 hours. But if most
	// sites failed (e.g. transient PHP issue), don't mark as fresh - allow
	// retry next cycle.
	mu.Lock()
	sc := successCount
	mu.Unlock()
	failCount := len(wpConfigs) - sc
	if sc == 0 {
		fmt.Fprintf(os.Stderr, "[%s] Plugin cache refresh FAILED: 0/%d sites succeeded, not updating timestamp\n",
			time.Now().Format("2006-01-02 15:04:05"), len(wpConfigs))
		return
	}
	if failCount > sc {
		fmt.Fprintf(os.Stderr, "[%s] Plugin cache refresh PARTIAL: %d/%d sites failed (majority), not updating timestamp\n",
			time.Now().Format("2006-01-02 15:04:05"), failCount, len(wpConfigs))
		return
	}
	_ = db.SetPluginRefreshTime(time.Now())
}
// evaluatePluginCache reads the cached plugin inventory and emits findings
// for outdated or untracked active plugins.
func evaluatePluginCache(db *store.DB) []alert.Finding {
	var findings []alert.Finding
	for wpPath, site := range db.AllSitePlugins() {
		for _, p := range site.Plugins {
			// Only active plugins are reportable.
			if p.Status != "active" && p.Status != "active-network" {
				continue
			}
			// Best available version: prefer wp-cli's update_version,
			// fall back to the WordPress.org API cache.
			available := p.UpdateVersion
			if available == "" {
				if info, ok := db.GetPluginInfo(p.Slug); ok {
					available = info.LatestVersion
				}
			}
			if available == "" {
				// No update source - skip silently (custom/private plugins).
				continue
			}
			sev := pluginAlertSeverity(p.InstalledVersion, available)
			if sev == "" {
				continue
			}
			severity := alert.Warning
			switch sev {
			case "critical":
				severity = alert.Critical
			case "high":
				severity = alert.High
			}
			findings = append(findings, alert.Finding{
				Severity: severity,
				Check:    "outdated_plugins",
				Message:  fmt.Sprintf("Outdated plugin %q on %s (%s): %s -> %s", p.Name, site.Domain, site.Account, p.InstalledVersion, available),
				Details:  fmt.Sprintf("Path: %s\nInstalled: %s\nAvailable: %s\nSeverity: %s", wpPath, p.InstalledVersion, available, sev),
			})
		}
	}
	return findings
}
// extractWPDomain runs `wp option get siteurl` (as the site's user) to
// discover the site's domain. If wp-cli fails or returns nothing, it falls
// back to directory-name heuristics: the path segment after "public_html"
// (addon domain), or finally the account/user name (main domain).
func extractWPDomain(ctx context.Context, wpPath, user string) string {
	out, err := cmdExec.RunContext(ctx, "su", "-", user, "-s", "/bin/bash", "-c",
		"wp option get siteurl --path="+shellQuote(wpPath),
	)
	if err == nil {
		if site := strings.TrimSpace(string(out)); site != "" {
			// Drop the scheme for display purposes.
			site = strings.TrimPrefix(site, "https://")
			return strings.TrimPrefix(site, "http://")
		}
	}
	// Heuristic fallback: addon domains live under .../public_html/<domain>/.
	segments := strings.Split(wpPath, "/")
	for i, seg := range segments {
		if seg == "public_html" && i+1 < len(segments) {
			return segments[i+1]
		}
	}
	return user
}
// shellQuote wraps a string in single quotes for safe shell argument passing.
// Any embedded single quote is replaced with '\'' (end quote, escaped literal
// quote, start quote), which is the standard POSIX-shell quoting trick.
func shellQuote(s string) string {
	return "'" + strings.Join(strings.Split(s, "'"), `'\''`) + "'"
}
package checks
import (
"context"
"fmt"
"path/filepath"
"strconv"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFakeKernelThreads scans /proc for non-root processes whose name or
// command line is wrapped in brackets (e.g. "[kworker/0:1]"). Real kernel
// threads always run as root, so a bracketed name on a non-root process is
// a classic malware masquerade and is reported as Critical.
func CheckFakeKernelThreads(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	statusFiles, _ := osFS.Glob("/proc/[0-9]*/status")
	for _, statusFile := range statusFiles {
		pid := filepath.Base(filepath.Dir(statusFile))
		raw, err := osFS.ReadFile(statusFile)
		if err != nil {
			continue
		}
		// Pull Name: and the real Uid out of /proc/<pid>/status.
		var procName, procUID string
		for _, line := range strings.Split(string(raw), "\n") {
			switch {
			case strings.HasPrefix(line, "Name:\t"):
				procName = strings.TrimPrefix(line, "Name:\t")
			case strings.HasPrefix(line, "Uid:\t"):
				if cols := strings.Fields(strings.TrimPrefix(line, "Uid:\t")); len(cols) > 0 {
					procUID = cols[0]
				}
			}
		}
		// Root (uid 0) processes — including real kernel threads — are fine;
		// an unreadable Uid is also skipped.
		if procUID == "0" || procUID == "" {
			continue
		}
		rawCmd, _ := osFS.ReadFile(filepath.Join("/proc", pid, "cmdline"))
		cmd := strings.TrimRight(strings.ReplaceAll(string(rawCmd), "\x00", " "), " ")
		// Only flag processes that present themselves with a bracketed name.
		if !strings.HasPrefix(cmd, "[") && !strings.HasPrefix(procName, "[") {
			continue
		}
		exe, _ := osFS.Readlink(filepath.Join("/proc", pid, "exe"))
		uidNum, _ := strconv.Atoi(procUID)
		pidNum, _ := strconv.Atoi(pid)
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "fake_kernel_thread",
			Message:  fmt.Sprintf("Non-root process masquerading as kernel thread: [%s]", procName),
			Details:  fmt.Sprintf("PID: %s, UID: %d, exe: %s, cmdline: %s", pid, uidNum, exe, cmd),
			PID:      pidNum,
		})
	}
	return findings
}
// CheckSuspiciousProcesses flags non-root processes whose executable name,
// command line, or on-disk executable path matches known attack-tool
// patterns (interactive reverse shells, gsocket tooling, binaries running
// from /tmp, /dev/shm, or hidden .config directories).
//
// Root processes (uid 0) are skipped entirely. Each of the three pattern
// categories emits at most one finding per process: the name loop now breaks
// on first match, matching the cmdline and path loops (previously a name
// containing several patterns produced duplicate findings).
func CheckSuspiciousProcesses(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	suspiciousNames := []string{"defunct", "gsocket", "gs-netcat", "gs-sftp"}
	suspiciousCmdline := []string{
		"/bin/sh -i", "/bin/bash -i", "bash -i",
		"/dev/tcp/", "semutmerah", "gsocket",
		"reverse", "nc -e", "ncat -e",
	}
	suspiciousPaths := []string{"/tmp/", "/dev/shm/", "/.config/"}
	procs, _ := osFS.Glob("/proc/[0-9]*/exe")
	for _, exePath := range procs {
		pid := filepath.Base(filepath.Dir(exePath))
		pidInt, _ := strconv.Atoi(pid)
		// Real uid comes from the first field of the Uid: line in status.
		statusData, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
		var uid string
		for _, line := range strings.Split(string(statusData), "\n") {
			if strings.HasPrefix(line, "Uid:\t") {
				fields := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
				if len(fields) > 0 {
					uid = fields[0]
				}
			}
		}
		if uid == "0" {
			continue // Skip root processes for this check
		}
		exe, _ := osFS.Readlink(exePath)
		cmdline, _ := osFS.ReadFile(filepath.Join("/proc", pid, "cmdline"))
		cmdStr := strings.TrimRight(strings.ReplaceAll(string(cmdline), "\x00", " "), " ")
		// Check executable name
		exeName := filepath.Base(exe)
		for _, s := range suspiciousNames {
			if strings.Contains(strings.ToLower(exeName), s) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Suspicious process name: %s", exeName),
					Details:  fmt.Sprintf("PID: %s, UID: %s, exe: %s, cmdline: %s", pid, uid, exe, cmdStr),
					PID:      pidInt,
				})
				break // one finding per process for this category
			}
		}
		// Check cmdline for suspicious patterns
		cmdLower := strings.ToLower(cmdStr)
		for _, s := range suspiciousCmdline {
			if strings.Contains(cmdLower, strings.ToLower(s)) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Suspicious cmdline pattern: %s", s),
					Details:  fmt.Sprintf("PID: %s, UID: %s, exe: %s, cmdline: %s", pid, uid, exe, cmdStr),
					PID:      pidInt,
				})
				break
			}
		}
		// Check executable path
		for _, s := range suspiciousPaths {
			if strings.Contains(exe, s) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "suspicious_process",
					Message:  fmt.Sprintf("Process running from suspicious path: %s", exe),
					Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uid, cmdStr),
					PID:      pidInt,
				})
				break
			}
		}
	}
	return findings
}
// CheckPHPProcesses inspects running lsphp processes to detect active
// webshell execution. Only reads /proc cmdline - zero disk I/O.
//
// Any lsphp process whose command line references a script under /tmp,
// /dev/shm, wp-content/uploads, or a hidden .config directory is reported
// as Critical; at most one finding per process.
func CheckPHPProcesses(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	badPaths := []string{
		"/tmp/",
		"/dev/shm/",
		"/wp-content/uploads/",
		"/.config/",
	}
	cmdFiles, _ := osFS.Glob("/proc/[0-9]*/cmdline")
	for _, cmdFile := range cmdFiles {
		pid := filepath.Base(filepath.Dir(cmdFile))
		pidNum, _ := strconv.Atoi(pid)
		raw, err := osFS.ReadFile(cmdFile)
		if err != nil {
			continue
		}
		cmd := strings.ReplaceAll(string(raw), "\x00", " ")
		// This check is scoped to lsphp workers only.
		if !strings.Contains(cmd, "lsphp") {
			continue
		}
		for _, bad := range badPaths {
			if !strings.Contains(cmd, bad) {
				continue
			}
			// Resolve the UID only for flagged processes — the cheap
			// cmdline filter above handles the common case.
			statusRaw, _ := osFS.ReadFile(filepath.Join("/proc", pid, "status"))
			var uid string
			for _, line := range strings.Split(string(statusRaw), "\n") {
				if strings.HasPrefix(line, "Uid:\t") {
					if cols := strings.Fields(strings.TrimPrefix(line, "Uid:\t")); len(cols) > 0 {
						uid = cols[0]
					}
				}
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "php_suspicious_execution",
				Message:  fmt.Sprintf("PHP executing from suspicious path: %s", bad),
				Details:  fmt.Sprintf("PID: %s, UID: %s, cmdline: %s", pid, uid, strings.TrimSpace(cmd)),
				PID:      pidNum,
			})
			break
		}
	}
	return findings
}
package checks
import (
"context"
"os"
"os/exec"
"path/filepath"
)
// ---------------------------------------------------------------------------
// OS abstraction — filesystem read operations
// ---------------------------------------------------------------------------

// OS abstracts filesystem read operations used by check functions.
// Production code uses realOS{}; tests swap in a mockOS via SetOS().
type OS interface {
	// ReadFile returns the full contents of the named file.
	ReadFile(name string) ([]byte, error)
	// ReadDir lists the entries of the named directory.
	ReadDir(name string) ([]os.DirEntry, error)
	// Stat returns file info, following symlinks.
	Stat(name string) (os.FileInfo, error)
	// Lstat returns file info without following symlinks.
	Lstat(name string) (os.FileInfo, error)
	// Readlink returns the target of the named symlink.
	Readlink(name string) (string, error)
	// Open opens the named file for reading.
	Open(name string) (*os.File, error)
	// Glob returns the names matching the shell pattern.
	Glob(pattern string) ([]string, error)
}

// realOS implements OS by delegating straight to the os and path/filepath
// standard-library packages.
type realOS struct{}

// #nosec G304 -- filesystem abstraction; check functions pass trusted paths.
func (realOS) ReadFile(name string) ([]byte, error) { return os.ReadFile(name) }
func (realOS) ReadDir(name string) ([]os.DirEntry, error) { return os.ReadDir(name) }
func (realOS) Stat(name string) (os.FileInfo, error) { return os.Stat(name) }
func (realOS) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) }
func (realOS) Readlink(name string) (string, error) { return os.Readlink(name) }

// #nosec G304 -- filesystem abstraction; check functions pass trusted paths.
func (realOS) Open(name string) (*os.File, error) { return os.Open(name) }
func (realOS) Glob(pattern string) ([]string, error) { return filepath.Glob(pattern) }

// osFS is the package-level filesystem provider. All check functions use
// this instead of calling os.ReadFile / os.ReadDir / etc. directly.
var osFS OS = realOS{}

// SetOS replaces the filesystem provider. Used by tests to inject mocks.
// NOTE(review): plain assignment to a package var — call before any checks
// run; swapping concurrently with running checks would be a data race.
func SetOS(o OS) { osFS = o }
// ---------------------------------------------------------------------------
// CmdRunner abstraction — external command execution
// ---------------------------------------------------------------------------

// CmdRunner abstracts external command execution used by check functions.
// Production code uses realCmd{}; tests swap in a mockCmdRunner via SetCmdRunner().
type CmdRunner interface {
	// Run executes a command and returns its output.
	Run(name string, args ...string) ([]byte, error)
	// RunAllowNonZero executes a command; per its name, a non-zero exit is
	// tolerated by the underlying helper.
	RunAllowNonZero(name string, args ...string) ([]byte, error)
	// RunContext executes a command under the given parent context
	// (cancellation/timeout).
	RunContext(parent context.Context, name string, args ...string) ([]byte, error)
	// RunWithEnv executes a command with extra environment variables.
	RunWithEnv(name string, args []string, extraEnv ...string) ([]byte, error)
	// LookPath resolves an executable name, like exec.LookPath.
	LookPath(file string) (string, error)
}

// realCmd implements CmdRunner by delegating to the package's runCmd*Real
// helpers and exec.LookPath.
type realCmd struct{}

func (realCmd) Run(name string, args ...string) ([]byte, error) {
	return runCmdReal(name, args...)
}

func (realCmd) RunAllowNonZero(name string, args ...string) ([]byte, error) {
	return runCmdAllowNonZeroReal(name, args...)
}

func (realCmd) RunContext(parent context.Context, name string, args ...string) ([]byte, error) {
	return runCmdCombinedContextReal(parent, name, args...)
}

func (realCmd) RunWithEnv(name string, args []string, extraEnv ...string) ([]byte, error) {
	return runCmdWithEnvReal(name, args, extraEnv...)
}

func (realCmd) LookPath(file string) (string, error) {
	return exec.LookPath(file)
}

// cmdExec is the package-level command runner. All check functions use
// this instead of calling runCmd / exec.Command directly.
var cmdExec CmdRunner = realCmd{}

// SetCmdRunner replaces the command runner. Used by tests to inject mocks.
// NOTE(review): plain assignment to a package var — call before checks start.
func SetCmdRunner(r CmdRunner) { cmdExec = r }
package checks
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"
)
// eximMsgIDRegex validates Exim message ID format (e.g., 2jKPFm-000abc-1X).
// Malformed IDs are rejected before being joined into spool-file paths,
// preventing path traversal via a crafted finding message.
var eximMsgIDRegex = regexp.MustCompile(`^[0-9A-Za-z]{6}-[0-9A-Za-z]{6}-[0-9A-Za-z]{2}$`)

// Allowed roots for each fix action. Declared as vars (not consts) so tests
// can redirect remediation under t.TempDir() without writing to real /home,
// /tmp, or /var/spool. Production must not mutate these at runtime.
var (
	fixPermissionsAllowedRoots = []string{"/home"}                                 // chmod fixes: user files only
	fixQuarantineAllowedRoots  = []string{"/home", "/tmp", "/dev/shm", "/var/tmp"} // quarantine: user files plus world-writable dirs
	fixHtaccessAllowedRoots    = []string{"/home"}                                 // .htaccess cleanup: user docroots only
	eximSpoolDirs              = []string{"/var/spool/exim/input", "/var/spool/exim4/input"} // candidate Exim spool locations
)

// RemediationResult describes the outcome of a fix action.
type RemediationResult struct {
	Success     bool   `json:"success"`         // true when the fix was applied
	Action      string `json:"action"`          // human-readable description of what was done
	Description string `json:"description"`     // what fix was applied
	Error       string `json:"error,omitempty"` // set when the fix failed or was refused
}
// FixDescription returns a human-readable description of what the fix will do
// for a given check type and file path. Returns empty string if no fix is
// available (or if the finding lacks the path/message ID the fix needs).
func FixDescription(checkType, message string, filePath ...string) string {
	target := selectFindingPath(message, filePath...)
	switch checkType {
	case "world_writable_php", "group_writable_php":
		if target == "" {
			return ""
		}
		return fmt.Sprintf("Set permissions to 644 on %s", target)
	case "webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory":
		if target == "" {
			return ""
		}
		return fmt.Sprintf("Quarantine %s to /opt/csm/quarantine/", target)
	case "backdoor_binary", "new_executable_in_config":
		if target == "" {
			return ""
		}
		return fmt.Sprintf("Kill process and quarantine %s", target)
	case "suspicious_crontab":
		// Crontab remediation has a generic description even without a path.
		if target == "" {
			return "Quarantine and truncate crontab"
		}
		return fmt.Sprintf("Quarantine and truncate crontab %s", target)
	case "htaccess_injection", "htaccess_handler_abuse":
		if target == "" {
			return ""
		}
		return fmt.Sprintf("Remove malicious directives from %s", target)
	case "email_phishing_content":
		// The target is an Exim message ID embedded in the message text.
		if msgID := extractEximMsgID(message); msgID != "" {
			return fmt.Sprintf("Quarantine Exim spool message %s", msgID)
		}
		return ""
	}
	return ""
}
// HasFix returns true if the check type has a known automated fix.
// The set here must stay in sync with the cases handled by ApplyFix.
func HasFix(checkType string) bool {
	switch checkType {
	case "world_writable_php", "group_writable_php",
		"webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory",
		"backdoor_binary", "new_executable_in_config",
		"htaccess_injection", "htaccess_handler_abuse",
		"email_phishing_content", "suspicious_crontab":
		return true
	}
	return false
}
// ApplyFix executes the remediation action for a finding.
//
// checkType selects the fix strategy; filePath (if non-empty) names the
// target, otherwise a path is parsed out of the finding message. Unknown
// check types return a RemediationResult with Error set and Success false.
func ApplyFix(checkType, message, details string, filePath ...string) RemediationResult {
	path := selectFindingPath(message, filePath...)
	switch checkType {
	case "world_writable_php", "group_writable_php":
		// Permission findings: clamp the file back to 0644.
		return fixPermissions(path)
	case "webshell", "new_webshell_file", "obfuscated_php", "php_dropper",
		"suspicious_php_content", "new_php_in_languages", "new_php_in_upgrade",
		"phishing_page", "phishing_directory":
		// Malicious-content findings: move the file into quarantine.
		return fixQuarantine(path)
	case "backdoor_binary", "new_executable_in_config":
		// Running backdoors: kill the owning process (PID from details),
		// then quarantine the binary.
		return fixKillAndQuarantine(path, details)
	case "htaccess_injection", "htaccess_handler_abuse":
		return fixHtaccess(path, message)
	case "email_phishing_content":
		// Target is an Exim message ID inside the message, not a file path.
		return fixQuarantineSpoolMessage(message)
	case "suspicious_crontab":
		return fixSuspiciousCrontab(path)
	default:
		return RemediationResult{Error: fmt.Sprintf("no automated fix available for check type '%s'", checkType)}
	}
}
// fixPermissions sets file permissions to 0644.
//
// The path must be absolute and inside fixPermissionsAllowedRoots; symlinks
// are rejected by resolveExistingFixPath. On success the result's
// Description records the previous permission bits.
func fixPermissions(path string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	// Validate against allowed roots and resolve/reject symlinks.
	path, info, err := resolveExistingFixPath(path, fixPermissionsAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	oldMode := info.Mode().Perm()
	// #nosec G302 -- Intentional: this is the remediation that sets the
	// canonical "safe web content" mode on a user file after we flagged
	// the file as having dangerous perms (e.g. 0777). 0644 is what the
	// webserver needs to serve static content as the file owner.
	if err := os.Chmod(path, 0644); err != nil {
		return RemediationResult{Error: fmt.Sprintf("chmod failed: %v", err)}
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("chmod 644 %s", path),
		Description: fmt.Sprintf("Changed permissions from %o to 644", oldMode),
	}
}
// fixQuarantine moves a file or directory to quarantine.
//
// The path is validated against fixQuarantineAllowedRoots (symlinks
// rejected), then renamed into quarantineDir under a timestamped,
// slash-flattened name. If the rename fails (e.g. cross-filesystem),
// regular files fall back to copy+remove; directories cannot use that
// fallback and fail. A JSON ".meta" sidecar records the original path,
// ownership, mode, and size so the item can be restored.
func fixQuarantine(path string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	path, info, err := resolveExistingFixPath(path, fixQuarantineAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	// Best effort: the quarantine directory may already exist.
	_ = os.MkdirAll(quarantineDir, 0700)
	// Flatten the path into a single name; the timestamp prefix keeps
	// repeated quarantines of the same path from colliding.
	safeName := strings.ReplaceAll(path, "/", "_")
	ts := time.Now().Format("20060102-150405")
	qPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_%s", ts, safeName))
	if err := os.Rename(path, qPath); err != nil {
		// Cross-device fallback for files
		if !info.IsDir() {
			data, readErr := osFS.ReadFile(path)
			if readErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot read file: %v", readErr)}
			}
			// 0600: quarantined content must not be readable by other users.
			if writeErr := os.WriteFile(qPath, data, 0600); writeErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot write quarantine: %v", writeErr)}
			}
			// Remove failure is tolerated: the copy is already safe in quarantine.
			os.Remove(path)
		} else {
			return RemediationResult{Error: fmt.Sprintf("cannot quarantine directory: %v", err)}
		}
	}
	// Write metadata sidecar for restore
	var uid, gid int
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		uid = int(stat.Uid)
		gid = int(stat.Gid)
	}
	meta := map[string]interface{}{
		"original_path": path,
		"owner_uid":     uid,
		"group_gid":     gid,
		"mode":          info.Mode().String(),
		"size":          info.Size(),
		"quarantine_at": time.Now(),
		"reason":        "Fixed via CSM Web UI",
	}
	metaData, _ := json.MarshalIndent(meta, "", " ")
	// A failed metadata write is logged but does not fail the remediation —
	// the dangerous file is already out of the way.
	if err := os.WriteFile(qPath+".meta", metaData, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "remediate: error writing quarantine metadata %s: %v\n", qPath+".meta", err)
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("quarantined %s → %s", path, qPath),
		Description: fmt.Sprintf("Moved to quarantine: %s", qPath),
	}
}
// fixKillAndQuarantine kills any process using the file, then quarantines it.
//
// The PID is parsed from the finding's details text (via extractPID).
// Root-owned (or unknown-uid) processes are never signalled. Quarantining
// proceeds regardless of whether a PID was found or killed.
func fixKillAndQuarantine(path, details string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path from finding"}
	}
	// Try to extract and kill PID from details
	pid := extractPID(details)
	if pid != "" {
		pidInt := 0
		// Parse failure leaves pidInt at 0, rejected by the > 1 guard below.
		fmt.Sscanf(pid, "%d", &pidInt)
		if pidInt > 1 {
			uid := getProcessUID(pid)
			if uid != "0" && uid != "" { // never kill root
				_ = syscall.Kill(pidInt, syscall.SIGKILL)
			}
		}
	}
	// Then quarantine
	result := fixQuarantine(path)
	// NOTE(review): the Action reports "killed PID" whenever a PID string was
	// found, even if the kill was skipped above (root/unknown uid) — confirm
	// this wording is acceptable for the audit trail.
	if result.Success && pid != "" {
		result.Action = fmt.Sprintf("killed PID %s and %s", pid, result.Action)
		result.Description = "Process killed and file quarantined"
	}
	return result
}
// fixHtaccess removes malicious directives from an .htaccess file while
// preserving comments and known-safe directives (e.g., Wordfence, LiteSpeed).
//
// A line is dropped when it matches a dangerous pattern (PHP auto-prepend,
// eval/base64/gzinflate obfuscation, handler overrides) AND matches none of
// the safe-allowlist patterns. If nothing qualifies for removal, the file is
// left untouched and an error result is returned.
func fixHtaccess(path, message string) RemediationResult {
	if path == "" {
		return RemediationResult{Error: "could not extract file path"}
	}
	// Guard: never rewrite arbitrary files via this fixer.
	if filepath.Base(path) != ".htaccess" {
		return RemediationResult{Error: "automated .htaccess remediation only applies to .htaccess files"}
	}
	path, _, err := resolveExistingFixPath(path, fixHtaccessAllowedRoots)
	if err != nil {
		return RemediationResult{Error: err.Error()}
	}
	data, err := osFS.ReadFile(path)
	if err != nil {
		return RemediationResult{Error: fmt.Sprintf("cannot read: %v", err)}
	}
	// Patterns are matched case-insensitively against each trimmed line.
	dangerous := []string{"auto_prepend_file", "auto_append_file", "eval(", "base64_decode",
		"gzinflate", "str_rot13", "addhandler", "sethandler"}
	// Safe allowlist: legitimate directives that contain a dangerous
	// substring (e.g. AddHandler for real PHP versions, security plugins,
	// MIME/static-asset handlers).
	safe := []string{
		"wordfence-waf.php", "litespeed", "advanced-headers.php", "rsssl",
		"application/x-httpd-php", "application/x-httpd-ea-php", "application/x-httpd-alt-php",
		"-execcgi", "sethandler none", "sethandler default-handler",
		"text/html", "text/css", "text/javascript", "application/javascript",
		"image/", "font/", ".woff", ".woff2", ".ttf", ".eot", ".svg",
		"wordfence",
	}
	var cleaned []string
	removed := 0
	for _, line := range strings.Split(string(data), "\n") {
		lineLower := strings.ToLower(strings.TrimSpace(line))
		// Comments are always preserved verbatim.
		if strings.HasPrefix(lineLower, "#") {
			cleaned = append(cleaned, line)
			continue
		}
		isDangerous := false
		for _, d := range dangerous {
			if strings.Contains(lineLower, d) {
				isSafe := false
				for _, s := range safe {
					if strings.Contains(lineLower, s) {
						isSafe = true
						break
					}
				}
				if !isSafe {
					isDangerous = true
					break
				}
			}
		}
		if isDangerous {
			removed++
		} else {
			cleaned = append(cleaned, line)
		}
	}
	if removed == 0 {
		return RemediationResult{Error: "no malicious directives found to remove"}
	}
	// #nosec G306 -- .htaccess rewritten for a user's public_html; 0644 is
	// the mode the webserver expects for static content.
	if err := os.WriteFile(path, []byte(strings.Join(cleaned, "\n")), 0644); err != nil {
		return RemediationResult{Error: fmt.Sprintf("write failed: %v", err)}
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("removed %d malicious directive(s) from %s", removed, path),
		Description: fmt.Sprintf("Cleaned .htaccess: removed %d line(s)", removed),
	}
}
// extractFilePathFromMessage extracts a file path from a finding message.
// Handles patterns like "World-writable PHP file: /path/to/file" and
// "Webshell found: /path/to/file". Only paths rooted under /home, /tmp,
// /dev/shm, or /var/tmp are recognized; the path ends at the first space,
// comma, newline, or closing parenthesis.
func extractFilePathFromMessage(message string) string {
	for _, root := range []string{"/home/", "/tmp/", "/dev/shm/", "/var/tmp/"} {
		start := strings.Index(message, root)
		if start < 0 {
			continue
		}
		candidate := message[start:]
		if stop := strings.IndexAny(candidate, " ,\n)"); stop >= 0 {
			candidate = candidate[:stop]
		}
		return candidate
	}
	return ""
}
// selectFindingPath picks the remediation target: an explicitly supplied
// (non-blank) filePath wins; otherwise the path is parsed from the message.
func selectFindingPath(message string, filePath ...string) string {
	var explicit string
	if len(filePath) > 0 {
		explicit = strings.TrimSpace(filePath[0])
	}
	if explicit != "" {
		return explicit
	}
	return extractFilePathFromMessage(message)
}
// resolveExistingFixPath validates that path lies inside allowedRoots,
// exists, and neither is nor resolves through a symlink, returning the
// fully resolved path plus its Lstat info.
//
// The allowed-roots check runs twice — on the cleaned input and again on
// the symlink-resolved path — so an attacker-controlled symlinked parent
// cannot redirect remediation outside the sandbox. For /home paths the
// resolved target must additionally stay inside the same /home/<account>
// subtree as the input.
func resolveExistingFixPath(path string, allowedRoots []string) (string, os.FileInfo, error) {
	cleanPath, err := sanitizeFixPath(path, allowedRoots)
	if err != nil {
		return "", nil, err
	}
	info, err := osFS.Lstat(cleanPath)
	if err != nil {
		return "", nil, fmt.Errorf("file not found: %v", err)
	}
	// Refuse to remediate a symlink itself — it could point anywhere.
	if info.Mode()&os.ModeSymlink != 0 {
		return "", nil, fmt.Errorf("symlinked paths are not eligible for automated remediation: %s", cleanPath)
	}
	// Resolve symlinks in any parent directory component.
	resolved, err := filepath.EvalSymlinks(cleanPath)
	if err != nil {
		return "", nil, fmt.Errorf("cannot resolve path: %v", err)
	}
	// Re-check the fully resolved path against the allowed roots.
	resolved, err = sanitizeFixPath(resolved, allowedRoots)
	if err != nil {
		return "", nil, err
	}
	// A /home path must not resolve into a different account's directory.
	if accountRoot := homeAccountRoot(cleanPath); accountRoot != "" && !isPathWithinOrEqual(resolved, accountRoot) {
		return "", nil, fmt.Errorf("resolved path escapes account boundary: %s", resolved)
	}
	// Lstat the resolved path again: fetches the final FileInfo and
	// re-verifies it has not become a symlink in the meantime.
	resolvedInfo, err := osFS.Lstat(resolved)
	if err != nil {
		return "", nil, fmt.Errorf("file not found: %v", err)
	}
	if resolvedInfo.Mode()&os.ModeSymlink != 0 {
		return "", nil, fmt.Errorf("symlinked paths are not eligible for automated remediation: %s", resolved)
	}
	return resolved, resolvedInfo, nil
}
// sanitizeFixPath normalizes a candidate remediation path and verifies it
// is absolute and inside one of the allowed roots, returning the cleaned
// path. It does not touch the filesystem.
func sanitizeFixPath(path string, allowedRoots []string) (string, error) {
	cleaned := filepath.Clean(strings.TrimSpace(path))
	switch {
	case cleaned == "":
		return "", fmt.Errorf("file path is required")
	case !filepath.IsAbs(cleaned):
		return "", fmt.Errorf("file path must be absolute")
	}
	for _, root := range allowedRoots {
		if isPathWithinOrEqual(cleaned, root) {
			return cleaned, nil
		}
	}
	return "", fmt.Errorf("file path is outside the allowed remediation roots: %s", cleaned)
}
// isPathWithinOrEqual reports whether path equals base or sits beneath it.
// Both are cleaned first, and the separator is appended to the base before
// the prefix test so that "/homework" is not treated as inside "/home".
func isPathWithinOrEqual(path, base string) bool {
	p := filepath.Clean(path)
	b := filepath.Clean(base)
	if p == b {
		return true
	}
	return strings.HasPrefix(p, b+string(filepath.Separator))
}
// homeAccountRoot returns the /home/<account> prefix of a cleaned path, or
// "" when the path is not under /home or names only the account directory
// itself (no component after the account).
func homeAccountRoot(path string) string {
	clean := filepath.Clean(path)
	if !strings.HasPrefix(clean, "/home/") {
		return ""
	}
	rest := clean[len("/home/"):]
	slash := strings.IndexByte(rest, '/')
	if slash < 0 {
		// "/home/<account>" with nothing beneath it.
		return ""
	}
	return "/home/" + rest[:slash]
}
// extractEximMsgID extracts an Exim message ID from a finding message.
// Matches the pattern "(message: XXXXXX-XXXXXX-XX)" used by emailscan.go;
// returns "" when the marker or closing parenthesis is absent.
func extractEximMsgID(message string) string {
	const marker = "(message: "
	start := strings.Index(message, marker)
	if start < 0 {
		return ""
	}
	tail := message[start+len(marker):]
	end := strings.IndexByte(tail, ')')
	if end < 0 {
		return ""
	}
	return strings.TrimSpace(tail[:end])
}
// fixQuarantineSpoolMessage moves Exim spool files (-H header and -D body)
// for a message ID into quarantine.
//
// The ID is parsed from the finding message and strictly validated against
// eximMsgIDRegex before being joined into any path, preventing traversal.
// Both known spool locations (exim and exim4) are probed; the first dir
// containing the -H file wins. A JSON ".meta" sidecar records the
// quarantine for later inspection.
func fixQuarantineSpoolMessage(message string) RemediationResult {
	msgID := extractEximMsgID(message)
	if msgID == "" {
		return RemediationResult{Error: "could not extract Exim message ID from finding"}
	}
	// Validate Exim message ID format to prevent path traversal
	if !eximMsgIDRegex.MatchString(msgID) {
		return RemediationResult{Error: fmt.Sprintf("invalid Exim message ID format: %s", msgID)}
	}
	// Locate the spool directory holding this message's header file.
	var spoolDir string
	for _, dir := range eximSpoolDirs {
		if _, err := osFS.Stat(filepath.Join(dir, msgID+"-H")); err == nil {
			spoolDir = dir
			break
		}
	}
	if spoolDir == "" {
		return RemediationResult{Error: fmt.Sprintf("spool message %s not found (already delivered or removed)", msgID)}
	}
	_ = os.MkdirAll(quarantineDir, 0700)
	ts := time.Now().Format("20060102-150405")
	moved := 0
	// Quarantine the header (-H) and body (-D) files; a missing file of
	// either kind is skipped rather than treated as an error.
	for _, suffix := range []string{"-H", "-D"} {
		src := filepath.Join(spoolDir, msgID+suffix)
		if _, err := osFS.Stat(src); err != nil {
			continue
		}
		dst := filepath.Join(quarantineDir, fmt.Sprintf("%s_exim_%s%s", ts, msgID, suffix))
		if err := os.Rename(src, dst); err != nil {
			// Cross-device fallback
			data, readErr := osFS.ReadFile(src)
			if readErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot read %s: %v", src, readErr)}
			}
			if writeErr := os.WriteFile(dst, data, 0600); writeErr != nil {
				return RemediationResult{Error: fmt.Sprintf("cannot write quarantine: %v", writeErr)}
			}
			// Remove failure tolerated: the copy is already in quarantine.
			os.Remove(src)
		}
		moved++
	}
	if moved == 0 {
		return RemediationResult{Error: fmt.Sprintf("no spool files found for message %s", msgID)}
	}
	// Write metadata sidecar
	meta := map[string]interface{}{
		"message_id":    msgID,
		"spool_dir":     spoolDir,
		"quarantine_at": time.Now(),
		"reason":        "Phishing email quarantined via CSM Web UI",
	}
	metaData, _ := json.MarshalIndent(meta, "", " ")
	metaPath := filepath.Join(quarantineDir, fmt.Sprintf("%s_exim_%s.meta", ts, msgID))
	// Metadata write failure is logged but does not fail the remediation.
	if err := os.WriteFile(metaPath, metaData, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "remediate: error writing spool quarantine metadata %s: %v\n", metaPath, err)
	}
	return RemediationResult{
		Success:     true,
		Action:      fmt.Sprintf("quarantined spool message %s (%d files)", msgID, moved),
		Description: fmt.Sprintf("Exim spool files moved to quarantine for message %s", msgID),
	}
}
package checks
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
const (
	reputationCacheFile = "reputation_cache.json" // cache file name under the state directory
	cacheExpiry = 6 * time.Hour                   // how long a successful verdict stays fresh
	errorCacheExpiry = 1 * time.Hour // cache transient API errors to avoid retrying same IP
	abuseConfidenceThreshold = 50    // AbuseIPDB confidence score (0-100) at/above which we alert
	maxQueriesPerCycle = 5 // max AbuseIPDB API calls per 10-min cycle (~720/day, fits free tier)
	maxCacheEntries = 5000 // cap cache size
)

// abuseIPDBEndpoint is the URL queried for IP reputation. Declared as a
// var (not const) so tests can point it at an httptest server. Production
// callers must not modify this.
var abuseIPDBEndpoint = "https://api.abuseipdb.com/api/v2/check"

// abuseIPDBClient is the HTTP client used for AbuseIPDB queries. Declared
// at package scope so tests can swap in a mock client (e.g., one whose
// transport routes all traffic to an httptest server).
var abuseIPDBClient = &http.Client{Timeout: 10 * time.Second}

// reputationCache is the on-disk JSON structure mapping IP → cached verdict.
type reputationCache struct {
	Entries map[string]*reputationEntry `json:"entries"`
}

// reputationEntry records the outcome of one AbuseIPDB lookup.
type reputationEntry struct {
	Score int `json:"score"`             // confidence score 0-100; -1 marks a cached lookup error
	Category string `json:"category"`    // reported category, or "error: ..." for failed lookups
	CheckedAt time.Time `json:"checked_at"` // lookup time; drives cache-expiry decisions
}
// CheckIPReputation looks up non-infra IPs against threat intelligence.
// Four-tier approach:
//  1. Skip if already blocked
//  2. Check local threat DB (permanent blocklist + free feeds)
//  3. Check AbuseIPDB cache
//  4. Query AbuseIPDB for truly unknown IPs (max 5/cycle, ~720/day)
//
// Fix: transient API errors are now cached for errorCacheExpiry (1h) as
// intended. The previous code stamped error entries with a CheckedAt of
// time.Now().Add(cacheExpiry - errorCacheExpiry) — a timestamp 5h in the
// FUTURE — so the freshness test time.Since(CheckedAt) < cacheExpiry kept
// them "fresh" for 11 hours instead of 1. Backdating by the same amount
// (errorCacheExpiry - cacheExpiry) makes the entry go stale exactly 1h
// after the failed lookup.
func CheckIPReputation(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	ips := collectRecentIPs(cfg)
	if len(ips) == 0 {
		return nil
	}
	alreadyBlocked := loadAllBlockedIPs(cfg.StatePath)
	threatDB := GetThreatDB()
	cache := loadReputationCache(cfg.StatePath)
	client := abuseIPDBClient
	quotaExhausted := false
	checked := 0
	for ip, source := range ips {
		// Tier 1: Skip if already blocked
		if alreadyBlocked[ip] {
			continue
		}
		// Tier 2: Check local threat DB
		if threatDB != nil {
			if dbSource, found := threatDB.Lookup(ip); found {
				findings = append(findings, alert.Finding{
					Severity:  alert.Critical,
					Check:     "ip_reputation",
					Message:   fmt.Sprintf("Known malicious IP accessing server: %s (source: %s)", ip, dbSource),
					Details:   fmt.Sprintf("Detected via: %s\nMatched in local threat intelligence database", source),
					Timestamp: time.Now(),
				})
				continue
			}
		}
		// Tier 3: Check AbuseIPDB cache
		if entry, ok := cache.Entries[ip]; ok {
			if time.Since(entry.CheckedAt) < cacheExpiry {
				// Error entries carry Score -1 and never cross the threshold.
				if entry.Score >= abuseConfidenceThreshold {
					findings = append(findings, alert.Finding{
						Severity:  alert.Critical,
						Check:     "ip_reputation",
						Message:   fmt.Sprintf("Known malicious IP accessing server: %s (AbuseIPDB score: %d/100)", ip, entry.Score),
						Details:   fmt.Sprintf("Detected via: %s\nCategory: %s\nThis IP is reported in threat intelligence databases", source, entry.Category),
						Timestamp: time.Now(),
					})
				}
				continue
			}
		}
		// Tier 4: Query AbuseIPDB - skip if no key, quota exhausted, or limit reached
		if cfg.Reputation.AbuseIPDBKey == "" || quotaExhausted || checked >= maxQueriesPerCycle {
			continue
		}
		score, category, err := queryAbuseIPDB(client, ip, cfg.Reputation.AbuseIPDBKey)
		if err != nil {
			// 429/402 mean the API quota is gone for now — stop querying this cycle.
			if strings.Contains(err.Error(), "429") || strings.Contains(err.Error(), "402") {
				fmt.Fprintf(os.Stderr, "abuseipdb: quota exhausted (%v), stopping lookups for this cycle\n", err)
				quotaExhausted = true
				continue
			}
			// Cache the transient error with a backdated CheckedAt so the
			// entry only counts as fresh for errorCacheExpiry (1h).
			cache.Entries[ip] = &reputationEntry{
				Score:     -1,
				Category:  fmt.Sprintf("error: %v", err),
				CheckedAt: time.Now().Add(errorCacheExpiry - cacheExpiry),
			}
			checked++
			continue
		}
		checked++
		cache.Entries[ip] = &reputationEntry{
			Score:     score,
			Category:  category,
			CheckedAt: time.Now(),
		}
		if score >= abuseConfidenceThreshold {
			findings = append(findings, alert.Finding{
				Severity:  alert.Critical,
				Check:     "ip_reputation",
				Message:   fmt.Sprintf("Known malicious IP accessing server: %s (AbuseIPDB score: %d/100)", ip, score),
				Details:   fmt.Sprintf("Detected via: %s\nCategory: %s\nThis IP is reported in threat intelligence databases", source, category),
				Timestamp: time.Now(),
			})
		}
	}
	// Clean and cap cache
	cleanCache(cache)
	saveReputationCache(cfg.StatePath, cache)
	return findings
}
// collectRecentIPs gathers non-infra IPs from multiple log sources.
// Returns map of IP → source description (e.g. "SSH login", "Dovecot IMAP auth failure").
// The first source to see an IP wins; localhost and configured infra IPs
// are filtered out by addIfNotInfra.
func collectRecentIPs(cfg *config.Config) map[string]string {
	ips := make(map[string]string)
	// SSH logins (successful "Accepted ..." lines only).
	for _, line := range tailFile("/var/log/secure", 50) {
		if !strings.Contains(line, "Accepted") {
			continue
		}
		if ip := extractIPAfterKeyword(line, "from"); ip != "" {
			addIfNotInfra(ips, ip, "SSH login", cfg)
		}
	}
	// cPanel access log
	for _, line := range tailFile("/usr/local/cpanel/logs/access_log", 100) {
		if ip := firstField(line); ip != "" {
			addIfNotInfra(ips, ip, "cPanel/WHM access", cfg)
		}
	}
	// Web server access log (LiteSpeed/Apache): only the first path that
	// yields any lines is used (note the break below).
	webLogPaths := []string{
		"/usr/local/apache/logs/access_log",
		"/var/log/apache2/access_log",
		"/etc/apache2/logs/access_log",
	}
	for _, path := range webLogPaths {
		lines := tailFile(path, 100)
		if len(lines) == 0 {
			continue
		}
		for _, line := range lines {
			if ip := firstField(line); ip != "" {
				addIfNotInfra(ips, ip, "HTTP request", cfg)
			}
		}
		break
	}
	// Exim log - SMTP auth failures
	for _, line := range tailFile("/var/log/exim_mainlog", 50) {
		if strings.Contains(line, "authenticator failed") || strings.Contains(line, "rejected RCPT") {
			if ip := extractBracketedIP(line); ip != "" {
				addIfNotInfra(ips, ip, "SMTP auth failure", cfg)
			}
		}
	}
	// Dovecot - IMAP/POP3 auth failures (remote IP is logged as rip=<ip>)
	for _, line := range tailFile("/var/log/maillog", 50) {
		if strings.Contains(line, "auth failed") || strings.Contains(line, "Aborted login") {
			if ip := extractIPAfterKeyword(line, "rip="); ip != "" {
				addIfNotInfra(ips, ip, "Dovecot IMAP/POP3 auth failure", cfg)
			}
		}
	}
	return ips
}
// addIfNotInfra records ip → source unless the IP is empty, a loopback
// address, a configured infrastructure IP, or already present. The first
// source seen for an IP wins.
func addIfNotInfra(ips map[string]string, ip, source string, cfg *config.Config) {
	switch ip {
	case "", "127.0.0.1", "::1":
		return
	}
	if isInfraIP(ip, cfg.InfraIPs) {
		return
	}
	if _, seen := ips[ip]; seen {
		return
	}
	ips[ip] = source
}
// firstField returns the first whitespace-separated token of line when it
// looks like an IP address (three dots for v4, any colon for v6); otherwise
// it returns the empty string.
func firstField(line string) string {
	tokens := strings.Fields(line)
	if len(tokens) == 0 {
		return ""
	}
	candidate := tokens[0]
	if strings.Count(candidate, ".") != 3 && !strings.Contains(candidate, ":") {
		return ""
	}
	return candidate
}
// extractIPAfterKeyword returns the IP-looking token that follows the first
// occurrence of keyword in line (tolerating " " / "=" separators and
// trailing punctuation), or "" when none is found.
func extractIPAfterKeyword(line, keyword string) string {
	_, after, found := strings.Cut(line, keyword)
	if !found {
		return ""
	}
	after = strings.TrimLeft(after, " =")
	tokens := strings.Fields(after)
	if len(tokens) == 0 {
		return ""
	}
	candidate := strings.TrimRight(tokens[0], ",:;)([]")
	if strings.Count(candidate, ".") != 3 && !strings.Contains(candidate, ":") {
		return ""
	}
	return candidate
}
// extractBracketedIP pulls an IP out of the first [1.2.3.4]-style bracket
// pair in line (the format exim uses); returns "" when no bracket pair
// exists or its contents don't look like an IP.
func extractBracketedIP(line string) string {
	open := strings.Index(line, "[")
	if open < 0 {
		return ""
	}
	closeOff := strings.Index(line[open:], "]")
	if closeOff < 0 {
		return ""
	}
	candidate := line[open+1 : open+closeOff]
	if strings.Count(candidate, ".") != 3 && !strings.Contains(candidate, ":") {
		return ""
	}
	return candidate
}
// loadAllBlockedIPs returns all IPs currently blocked in CSM.
// Uses bbolt store when available, falls back to flat files.
func loadAllBlockedIPs(statePath string) map[string]bool {
	result := make(map[string]bool)
	if sdb := store.Global(); sdb != nil {
		// Preferred source: bbolt-backed firewall state.
		for _, e := range sdb.LoadFirewallState().Blocked {
			result[e.IP] = true
		}
	} else if raw, err := osFS.ReadFile(filepath.Join(statePath, "firewall", "state.json")); err == nil {
		// Fallback: firewall engine (nftables) flat-file state.
		var fw struct {
			Blocked []struct {
				IP        string    `json:"ip"`
				ExpiresAt time.Time `json:"expires_at"`
			} `json:"blocked"`
		}
		if json.Unmarshal(raw, &fw) == nil {
			now := time.Now()
			for _, e := range fw.Blocked {
				// Zero expiry means a block with no expiration.
				if e.ExpiresAt.IsZero() || now.Before(e.ExpiresAt) {
					result[e.IP] = true
				}
			}
		}
	}
	// Merge in blocked_ips.json (legacy CSM auto-block format).
	type legacyEntry struct {
		IP        string    `json:"ip"`
		ExpiresAt time.Time `json:"expires_at"`
	}
	var legacy struct {
		IPs []legacyEntry `json:"ips"`
	}
	if raw, err := osFS.ReadFile(filepath.Join(statePath, "blocked_ips.json")); err == nil {
		if json.Unmarshal(raw, &legacy) == nil {
			now := time.Now()
			for _, e := range legacy.IPs {
				if now.Before(e.ExpiresAt) {
					result[e.IP] = true
				}
			}
		}
	}
	return result
}
// abuseIPDBResponse mirrors the subset of the AbuseIPDB check-endpoint
// JSON consumed by queryAbuseIPDB; fields not listed here are ignored
// by the decoder.
type abuseIPDBResponse struct {
	Data struct {
		AbuseConfidenceScore int `json:"abuseConfidenceScore"` // 0-100; compared against abuseConfidenceThreshold by the caller
		UsageType string `json:"usageType"`
		ISP string `json:"isp"`
		TotalReports int `json:"totalReports"`
	} `json:"data"`
	// Errors carries request-level API errors; queryAbuseIPDB surfaces
	// the first Detail as a Go error.
	Errors []struct {
		Detail string `json:"detail"`
		Status int `json:"status"`
	} `json:"errors"`
}
// queryAbuseIPDB looks ip up against the AbuseIPDB check endpoint and
// returns (abuse confidence score, human-readable category, error).
//
// Rate limiting (HTTP 429) and quota exhaustion (HTTP 402) are returned as
// distinct errors whose messages start with "429"/"402" so the caller can
// back off or stop lookups for the cycle. The error strings are part of
// that contract and must not change.
func queryAbuseIPDB(client *http.Client, ip, apiKey string) (int, string, error) {
	req, err := http.NewRequest(http.MethodGet, abuseIPDBEndpoint+"?ipAddress="+ip+"&maxAgeInDays=90", nil)
	if err != nil {
		return 0, "", err
	}
	req.Header.Set("Key", apiKey)
	req.Header.Set("Accept", "application/json")
	resp, err := client.Do(req)
	if err != nil {
		return 0, "", err
	}
	defer func() { _ = resp.Body.Close() }()
	// Use named status constants instead of magic numbers; behavior and
	// error text are unchanged.
	switch resp.StatusCode {
	case http.StatusTooManyRequests: // 429
		return 0, "", fmt.Errorf("429 rate limited")
	case http.StatusPaymentRequired: // 402
		return 0, "", fmt.Errorf("402 quota exceeded")
	case http.StatusOK:
		// proceed to decode below
	default:
		return 0, "", fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	var result abuseIPDBResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, "", err
	}
	if len(result.Errors) > 0 {
		return 0, "", fmt.Errorf("API error: %s", result.Errors[0].Detail)
	}
	// Build a descriptive category string: usage type, ISP, report count.
	category := result.Data.UsageType
	if result.Data.ISP != "" {
		category += " (" + result.Data.ISP + ")"
	}
	if result.Data.TotalReports > 0 {
		category += fmt.Sprintf(", %d reports", result.Data.TotalReports)
	}
	return result.Data.AbuseConfidenceScore, category, nil
}
// cleanCache removes expired entries and caps the cache at maxCacheEntries.
func cleanCache(cache *reputationCache) {
	// Bbolt-backed deployments delegate cleanup to the store itself.
	if sdb := store.Global(); sdb != nil {
		sdb.CleanExpiredReputation(cacheExpiry)
		sdb.EnforceReputationCap(maxCacheEntries)
		return
	}
	// In-memory fallback: drop entries older than the freshness window
	// (same expiry the cache-hit check uses).
	now := time.Now()
	for ip, e := range cache.Entries {
		if now.Sub(e.CheckedAt) > cacheExpiry {
			delete(cache.Entries, ip)
		}
	}
	over := len(cache.Entries) - maxCacheEntries
	if over <= 0 {
		return
	}
	// Still above the cap: evict the oldest entries first.
	type candidate struct {
		ip        string
		checkedAt time.Time
	}
	all := make([]candidate, 0, len(cache.Entries))
	for ip, e := range cache.Entries {
		all = append(all, candidate{ip, e.CheckedAt})
	}
	// Ascending CheckedAt puts the oldest entries at the front.
	sort.Slice(all, func(i, j int) bool {
		return all[i].checkedAt.Before(all[j].checkedAt)
	})
	for _, c := range all[:over] {
		delete(cache.Entries, c.ip)
	}
}
// loadReputationCache returns the IP-reputation cache, preferring the bbolt
// store and falling back to the pre-migration flat-file JSON.
func loadReputationCache(statePath string) *reputationCache {
	out := &reputationCache{Entries: make(map[string]*reputationEntry)}
	if sdb := store.Global(); sdb != nil {
		// Bbolt store is authoritative after migration (the flat file is
		// renamed to .bak at that point).
		for ip, e := range sdb.AllReputation() {
			out.Entries[ip] = &reputationEntry{
				Score:     e.Score,
				Category:  e.Category,
				CheckedAt: e.CheckedAt,
			}
		}
		return out
	}
	// Flat-file fallback; unreadable or unparsable data yields an empty cache.
	if raw, err := osFS.ReadFile(filepath.Join(statePath, reputationCacheFile)); err == nil {
		_ = json.Unmarshal(raw, out)
		if out.Entries == nil {
			out.Entries = make(map[string]*reputationEntry)
		}
	}
	return out
}
// saveReputationCache persists the cache, mirroring loadReputationCache:
// the bbolt store when available, otherwise a temp-file-plus-rename write
// so readers never see a partially written flat file.
func saveReputationCache(statePath string, cache *reputationCache) {
	if sdb := store.Global(); sdb != nil {
		for ip, e := range cache.Entries {
			// Best-effort: one failed write should not abort the rest.
			_ = sdb.SetReputation(ip, store.ReputationEntry{
				Score:     e.Score,
				Category:  e.Category,
				CheckedAt: e.CheckedAt,
			})
		}
		return
	}
	payload, _ := json.MarshalIndent(cache, "", " ")
	tmp := filepath.Join(statePath, reputationCacheFile+".tmp")
	_ = os.WriteFile(tmp, payload, 0600)
	_ = os.Rename(tmp, filepath.Join(statePath, reputationCacheFile))
}
package checks
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// CheckFunc is the signature for all check functions.
// The context is cancelled when the check times out so goroutines can exit.
// Implementations return nil (or an empty slice) when nothing is wrong.
type CheckFunc func(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding

// namedCheck pairs a check function with its name for timeout reporting.
type namedCheck struct {
	name string // stable identifier, also used for perf_* filtering
	fn CheckFunc
}

// ForceAll forces all checks to run regardless of throttle (used by baseline).
var ForceAll bool

// DryRun disables auto-response actions (kill, quarantine, block).
// Used by `check` (read-only) and `baseline` commands.
var DryRun bool
// Tier identifies which set of checks to run.
type Tier string

const (
	TierCritical Tier = "critical" // Fast checks - processes, auth, network (~5 seconds)
	TierDeep Tier = "deep" // Filesystem scans - webshells, htaccess, WP core (~90 seconds)
	TierAll Tier = "all" // Both tiers
)

// checkTimeout is the per-check wall-clock budget; runParallel replaces a
// check that exceeds it with a "check_timeout" warning finding.
const checkTimeout = 5 * time.Minute
// criticalChecks returns the fast tier (TierCritical): process, auth,
// network, log, and performance checks. A fresh slice is returned on each
// call so callers may append to it freely (see RunTier/RunAll).
func criticalChecks() []namedCheck {
	return []namedCheck{
		{"fake_kernel_threads", CheckFakeKernelThreads},
		{"suspicious_processes", CheckSuspiciousProcesses},
		{"php_processes", CheckPHPProcesses},
		{"shadow_changes", CheckShadowChanges},
		{"uid0_accounts", CheckUID0Accounts},
		{"ssh_keys", CheckSSHKeys},
		{"sshd_config", CheckSSHDConfig},
		{"ssh_logins", CheckSSHLogins},
		{"api_tokens", CheckAPITokens},
		{"crontabs", CheckCrontabs},
		{"outbound_connections", CheckOutboundConnections},
		{"user_outbound", CheckOutboundUserConnections},
		{"dns_connections", CheckDNSConnections},
		{"whm_access", CheckWHMAccess},
		{"cpanel_logins", CheckCpanelLogins},
		{"cpanel_filemanager", CheckCpanelFileManager},
		{"firewall", CheckFirewall},
		{"mail_queue", CheckMailQueue},
		{"mail_per_account", CheckMailPerAccount},
		{"kernel_modules", CheckKernelModules},
		{"mysql_users", CheckMySQLUsers},
		{"database_dumps", CheckDatabaseDumps},
		{"exfiltration_paste", CheckOutboundPasteSites},
		{"wp_bruteforce", CheckWPBruteForce},
		{"ftp_logins", CheckFTPLogins},
		{"webmail_logins", CheckWebmailLogins},
		{"api_auth_failures", CheckAPIAuthFailures},
		{"ip_reputation", CheckIPReputation},
		{"local_threat_score", CheckLocalThreatScore},
		{"modsec_audit", CheckModSecAuditLog},
		{"health", CheckHealth},
		// perf_* names are special: PerfCheckNamesForTier collects them
		// for the daemon's purge-and-merge of performance findings.
		{"perf_load", CheckLoadAverage},
		{"perf_php_processes", CheckPHPProcessLoad},
		{"perf_memory", CheckSwapAndOOM},
	}
}
// deepChecks returns the slow tier (TierDeep): filesystem scans, content
// audits, and configuration/performance reviews. A fresh slice is returned
// on each call so callers may append to it freely.
func deepChecks() []namedCheck {
	return []namedCheck{
		{"filesystem", CheckFilesystem},
		{"webshells", CheckWebshells},
		{"htaccess", CheckHtaccess},
		{"wp_core", CheckWPCore},
		{"file_index", CheckFileIndex},
		{"php_content", CheckPHPContent},
		{"phishing", CheckPhishing},
		{"nulled_plugins", CheckNulledPlugins},
		{"rpm_integrity", CheckRPMIntegrity},
		{"group_writable_php", CheckGroupWritablePHP},
		{"open_basedir", CheckOpenBasedir},
		{"symlink_attacks", CheckSymlinkAttacks},
		{"php_config_changes", CheckPHPConfigChanges},
		{"dns_zones", CheckDNSZoneChanges},
		{"ssl_certs", CheckSSLCertIssuance},
		{"waf_status", CheckWAFStatus},
		{"db_content", CheckDatabaseContent},
		{"email_content", CheckOutboundEmailContent},
		{"outdated_plugins", CheckOutdatedPlugins},
		{"email_weak_password", CheckEmailPasswords},
		{"email_forwarder_audit", CheckForwarders},
		// perf_* names are collected by PerfCheckNamesForTier.
		{"perf_php_handler", CheckPHPHandler},
		{"perf_mysql_config", CheckMySQLConfig},
		{"perf_redis_config", CheckRedisConfig},
		{"perf_error_logs", CheckErrorLogBloat},
		{"perf_wp_config", CheckWPConfig},
		{"perf_wp_transients", CheckWPTransientBloat},
		{"perf_wp_cron", CheckWPCron},
	}
}
// PerfCheckNamesForTier returns the perf_* check names registered in the given tier.
// Used by the daemon to perform an atomic purge-and-merge when storing findings.
func PerfCheckNamesForTier(tier Tier) []string {
	var pool []namedCheck
	switch tier {
	case TierCritical:
		pool = criticalChecks()
	case TierDeep:
		pool = deepChecks()
	case TierAll:
		pool = append(criticalChecks(), deepChecks()...)
	}
	var perfNames []string
	for _, check := range pool {
		if strings.HasPrefix(check.name, "perf_") {
			perfNames = append(perfNames, check.name)
		}
	}
	return perfNames
}
// RunTier runs only the specified tier of checks.
func RunTier(cfg *config.Config, store *state.Store, tier Tier) []alert.Finding {
	var selected []namedCheck
	switch tier {
	case TierCritical:
		selected = criticalChecks()
	case TierDeep:
		selected = deepChecks()
	case TierAll:
		selected = append(criticalChecks(), deepChecks()...)
	}
	return runParallel(cfg, store, selected)
}
// RunReducedDeep runs only the deep checks that fanotify can't replace.
// Used by the daemon when fanotify is active.
//
// Skipped (fanotify handles these in real-time):
//
// filesystem, webshells, htaccess, file_index, php_content,
// phishing, php_config_changes
//
// NOTE(review): this list mirrors deepChecks minus the entries above —
// keep the two in sync when adding or removing deep checks.
func RunReducedDeep(cfg *config.Config, store *state.Store) []alert.Finding {
	reduced := []namedCheck{
		{"wp_core", CheckWPCore},
		{"nulled_plugins", CheckNulledPlugins},
		{"rpm_integrity", CheckRPMIntegrity},
		{"group_writable_php", CheckGroupWritablePHP},
		{"open_basedir", CheckOpenBasedir},
		{"symlink_attacks", CheckSymlinkAttacks},
		{"dns_zones", CheckDNSZoneChanges},
		{"ssl_certs", CheckSSLCertIssuance},
		{"waf_status", CheckWAFStatus},
		{"db_content", CheckDatabaseContent},
		{"email_content", CheckOutboundEmailContent},
		{"outdated_plugins", CheckOutdatedPlugins},
		{"email_weak_password", CheckEmailPasswords},
		{"email_forwarder_audit", CheckForwarders},
		{"perf_php_handler", CheckPHPHandler},
		{"perf_mysql_config", CheckMySQLConfig},
		{"perf_redis_config", CheckRedisConfig},
		{"perf_error_logs", CheckErrorLogBloat},
		{"perf_wp_config", CheckWPConfig},
		{"perf_wp_transients", CheckWPTransientBloat},
		{"perf_wp_cron", CheckWPCron},
	}
	return runParallel(cfg, store, reduced)
}
// RunAll runs critical checks always. Deep checks run if throttle allows or ForceAll is set.
func RunAll(cfg *config.Config, store *state.Store) []alert.Finding {
	selected := criticalChecks()
	// Note: || short-circuits, so the throttle is only consulted (and its
	// state only touched) when ForceAll is false.
	if ForceAll || store.ShouldRunThrottled("deep_scan", cfg.Thresholds.DeepScanIntervalMin) {
		selected = append(selected, deepChecks()...)
	}
	return runParallel(cfg, store, selected)
}
func runParallel(cfg *config.Config, store *state.Store, checks []namedCheck) []alert.Finding {
var mu sync.Mutex
var findings []alert.Finding
var wg sync.WaitGroup
// Limit concurrent checks to avoid saturating CPU (keeps WebUI responsive)
sem := make(chan struct{}, 5)
for _, nc := range checks {
wg.Add(1)
go func(c namedCheck) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
// Run with cancellable context so timed-out checks stop
ctx, cancel := context.WithTimeout(context.Background(), checkTimeout)
done := make(chan []alert.Finding, 1)
go func() {
done <- c.fn(ctx, cfg, store)
}()
select {
case results := <-done:
cancel()
if len(results) > 0 {
mu.Lock()
findings = append(findings, results...)
mu.Unlock()
}
case <-ctx.Done():
cancel()
mu.Lock()
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "check_timeout",
Message: fmt.Sprintf("Check '%s' timed out after %s", c.name, checkTimeout),
Timestamp: time.Now(),
})
mu.Unlock()
}
}(nc)
}
wg.Wait()
now := time.Now()
for i := range findings {
if findings[i].Timestamp.IsZero() {
findings[i].Timestamp = now
}
}
// Cross-account correlation
extra := CorrelateFindings(findings)
for i := range extra {
if extra[i].Timestamp.IsZero() {
extra[i].Timestamp = now
}
}
findings = append(findings, extra...)
// Auto-response: skip when DryRun is set (check/baseline commands)
if !DryRun {
killActions := AutoKillProcesses(cfg, findings)
for i := range killActions {
if killActions[i].Timestamp.IsZero() {
killActions[i].Timestamp = now
}
}
findings = append(findings, killActions...)
quarantineActions := AutoQuarantineFiles(cfg, findings)
for i := range quarantineActions {
if quarantineActions[i].Timestamp.IsZero() {
quarantineActions[i].Timestamp = now
}
}
findings = append(findings, quarantineActions...)
blockActions := AutoBlockIPs(cfg, findings)
for i := range blockActions {
if blockActions[i].Timestamp.IsZero() {
blockActions[i].Timestamp = now
}
}
findings = append(findings, blockActions...)
}
return findings
}
package checks
import (
"regexp"
"strings"
)
// SEO-spam context analysis for WordPress post content.
//
// Word-boundary keyword matching (see countSpamMatches) eliminated the
// substring false positives where "specialist" triggered "cialis" and
// "pharmaceutical" triggered "pharma". It did not eliminate a second
// class of false positive: legitimate prose mentions of an industry or
// product category ("our advisor covers consumer goods, energy,
// pharma" or "Industria alimentara si Pharma"). Word-boundary matching
// cannot distinguish the prose mention from the cloaked black-hat SEO
// link the lalimanro attack injected on the same site.
//
// This file classifies a keyword HIT as SPAM only when the surrounding
// HTML shows an attacker signal: CSS cloaking (off-screen absolute
// positioning, display:none, visibility:hidden, text-indent, micro
// height, font-size:0), an injection fingerprint (short hex HTML
// comment bracketing content), or an external anchor whose URL path
// itself contains the keyword ("/buy/pharma/" style commercial paths).
//
// Bare keyword mentions with none of those signals do not fire, so
// legitimate industry-vertical prose is silent.
//
// The proximity window is bounded to ±spamContextWindow bytes around
// the keyword match. Cloaking at the top of a long post unrelated to
// a keyword mention at the bottom does not spuriously associate.
// spamContextWindow bounds how far (in bytes) around a keyword match
// the analyzer looks for cloaking signals. 400 covers a typical
// cloaked-div attack (small <div> with style + <a> + keyword) without
// reaching into unrelated content. Both window edges are clamped to the
// content bounds (see hitHasSpamContext).
const spamContextWindow = 400
// cssCloakPatterns are regexes that each, when matched, identify a
// CSS property value indicative of content cloaking. The list is
// conservative: each entry corresponds to a technique widely used in
// real SEO spam campaigns and rarely to nothing else at production
// scale. `(?i)` makes matching case-insensitive; `\s*` around colons
// and values tolerates the whitespace variants attackers use to evade
// naive string scanners ("display : none" vs "display:none").
// Any single matching pattern is enough to flag a window (see
// windowHasCSSCloaking); ordering is irrelevant.
var cssCloakPatterns = []*regexp.Regexp{
	// display:none and visibility:hidden — classic hide.
	regexp.MustCompile(`(?i)\bdisplay\s*:\s*none\b`),
	regexp.MustCompile(`(?i)\bvisibility\s*:\s*hidden\b`),
	// text-indent with a negative value — pushes text off-screen.
	regexp.MustCompile(`(?i)\btext-indent\s*:\s*-\s*\d+`),
	// Micro height with a positive integer of 0 or 1, paired with
	// overflow:hidden to suppress contents. We require the pair
	// because height:1 alone is occasionally legitimate.
	regexp.MustCompile(`(?i)\bheight\s*:\s*[01](\s*px)?\b[^"'}]*\boverflow\s*:\s*hidden\b`),
	regexp.MustCompile(`(?i)\boverflow\s*:\s*hidden\b[^"'}]*\bheight\s*:\s*[01](\s*px)?\b`),
	// font-size:0 — classic invisible-text technique.
	regexp.MustCompile(`(?i)\bfont-size\s*:\s*0\b`),
	// position:absolute paired with a negative coordinate in the same
	// style attribute (non-greedy [^"'}]* stays inside one attribute).
	// Legitimate uses of position:absolute (menus, tooltips) have
	// non-negative coordinates; the pairing is the signal.
	regexp.MustCompile(`(?i)\bposition\s*:\s*absolute\b[^"'}]*\b(left|top|right|bottom)\s*:\s*-\s*\d+`),
	regexp.MustCompile(`(?i)\b(left|top|right|bottom)\s*:\s*-\s*\d+[^"'}]*\bposition\s*:\s*absolute\b`),
	// Off-screen via large negative margin on a block element.
	regexp.MustCompile(`(?i)\bmargin(-left|-top)?\s*:\s*-\s*\d{4,}`),
}
// injectionFingerprintRe matches the short hex HTML comment (4-8 hex
// chars) attackers use to tag their own injections across a campaign.
// The length bracket is deliberate: shorter matches are too ambiguous,
// longer matches overlap with legitimate WordPress UUIDs.
var injectionFingerprintRe = regexp.MustCompile(`<!--\s*[a-f0-9]{4,8}\s*-->`)

// anchorHrefRe extracts the href attribute value from each anchor tag
// in a fragment. Double- and single-quoted values are both accepted;
// capture group 1 is the URL itself.
var anchorHrefRe = regexp.MustCompile(`(?i)<a\b[^>]+href\s*=\s*["']([^"']+)["']`)

// externalSchemeRe recognises absolute or protocol-relative URLs. A
// relative URL ("/services/pharma/") is same-origin navigation and not
// an external spam link. The (?i) flag tolerates scheme-case variants
// such as "HTTPS://".
var externalSchemeRe = regexp.MustCompile(`^(?i)(https?:)?//`)
// contentHasSpamContext reports whether any occurrence of the keyword in
// content is accompanied by an SEO-spam signal within spamContextWindow
// bytes. A false return means every keyword hit is a bare mention with no
// cloaking context, and the caller should suppress the finding.
func contentHasSpamContext(content string, pattern dbSpamPattern) bool {
	for _, loc := range pattern.regex.FindAllStringIndex(content, -1) {
		if hitHasSpamContext(content, loc[0], loc[1], pattern.keyword) {
			return true
		}
	}
	return false
}
// countCloakedSpamMatches returns how many rows in contents contain the
// spam keyword together with an SEO-spam context signal. checkWPPosts uses
// this to decide whether to emit a db_spam_injection finding: bare prose
// mentions don't count, only cloaked/SEO-style injections do.
//
// A qualifying row is counted once no matter how many keyword hits it
// holds — the finding is per-post, not per-hit.
func countCloakedSpamMatches(pattern dbSpamPattern, contents []string) int {
	count := 0
	for _, body := range contents {
		if contentHasSpamContext(body, pattern) {
			count++
		}
	}
	return count
}
// hitHasSpamContext examines the ±spamContextWindow byte region around one
// keyword hit (content[start:end]) for cloaking, injection-fingerprint, or
// external-spam-link signals. Extracted as a helper so tests can target a
// single hit at a time.
func hitHasSpamContext(content string, start, end int, keyword string) bool {
	lo := start - spamContextWindow
	if lo < 0 {
		lo = 0
	}
	hi := end + spamContextWindow
	if hi > len(content) {
		hi = len(content)
	}
	region := content[lo:hi]
	switch {
	case windowHasCSSCloaking(region):
		return true
	case injectionFingerprintRe.MatchString(region):
		return true
	case windowHasExternalSpamAnchor(region, keyword):
		return true
	default:
		return false
	}
}
// positionAbsoluteRe and negativeCoordRe are evaluated together in
// windowHasCSSCloaking for the "off-screen absolute positioning"
// signal. Keeping them as two independent regexes (rather than one
// paired regex with [^"'}]* between them) closes an evasion where the
// attacker splits the cloak across two CSS rules — e.g.
// `<style>.a{position:absolute}.b{left:-9999px}</style>` — which the
// paired form stops matching at the first `}`. Both signals must
// appear somewhere in the proximity window; the window itself is
// bounded (see spamContextWindow) so the association is not unbounded.
var (
	positionAbsoluteRe = regexp.MustCompile(`(?i)\bposition\s*:\s*absolute\b`)
	// \d{2,} requires at least two digits, so tiny legitimate offsets
	// like "left:-1px" do not count as off-screen.
	negativeCoordRe = regexp.MustCompile(`(?i)\b(left|top|right|bottom|margin-left|margin-top)\s*:\s*-\s*\d{2,}`)
)
// windowHasCSSCloaking returns true when the window matches any single
// declaration from cssCloakPatterns, or when it contains both
// position:absolute and a negative coordinate anywhere in the window
// (the independent-signal form that catches rule-split evasions).
func windowHasCSSCloaking(window string) bool {
	for _, pat := range cssCloakPatterns {
		if pat.MatchString(window) {
			return true
		}
	}
	return positionAbsoluteRe.MatchString(window) && negativeCoordRe.MatchString(window)
}
// windowHasExternalSpamAnchor returns true if the window contains an
// <a href> pointing at an external URL (absolute or protocol-relative)
// whose path contains the spam keyword as a bounded segment. Same-origin
// relative links ("/services/pharma/") are internal navigation, not spam.
func windowHasExternalSpamAnchor(window, keyword string) bool {
	kw := strings.ToLower(keyword)
	for _, m := range anchorHrefRe.FindAllStringSubmatch(window, -1) {
		href := m[1]
		// Skip relative URLs — only external destinations matter.
		if !externalSchemeRe.MatchString(href) {
			continue
		}
		if hrefPathContainsKeyword(href, kw) {
			return true
		}
	}
	return false
}
// hrefPathContainsKeyword checks whether the URL path portion of href
// contains the keyword bounded by non-word characters, so that
// "/buy/pharma/cheap" matches "pharma" while "/pharmaceutical/" does not.
func hrefPathContainsKeyword(href, keywordLower string) bool {
	// Drop the scheme://host (or //host) prefix; externalSchemeRe already
	// guaranteed the URL takes one of those forms.
	hostAndPath := href
	if i := strings.Index(hostAndPath, "://"); i >= 0 {
		hostAndPath = hostAndPath[i+3:]
	} else {
		hostAndPath = strings.TrimPrefix(hostAndPath, "//")
	}
	// The path (plus query) is everything from the first '/' onward.
	path := ""
	if slash := strings.IndexByte(hostAndPath, '/'); slash >= 0 {
		path = hostAndPath[slash:]
	}
	lowerPath := strings.ToLower(path)
	hit := strings.Index(lowerPath, keywordLower)
	if hit < 0 {
		return false
	}
	// Word-boundary check on both sides of the hit.
	if hit > 0 && isURLWordChar(lowerPath[hit-1]) {
		return false
	}
	if tail := hit + len(keywordLower); tail < len(lowerPath) && isURLWordChar(lowerPath[tail]) {
		return false
	}
	return true
}
// isURLWordChar reports whether a byte is an ASCII alphanumeric (or '_')
// for URL-path word-boundary analysis. Underscore is treated as a word
// character by convention; hyphen is NOT, so "buy-cheap-pharma" still
// matches "pharma" bounded by the trailing hyphen/slash.
func isURLWordChar(b byte) bool {
	return (b >= 'a' && b <= 'z') ||
		(b >= 'A' && b <= 'Z') ||
		(b >= '0' && b <= '9') ||
		b == '_'
}
package checks
import (
"bufio"
"context"
"fmt"
"strings"
"syscall"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// CheckKernelModules compares loaded kernel modules against baseline.
// Modules present when the baseline was recorded are considered known;
// only modules that appear AFTER the baseline trigger alerts.
func CheckKernelModules(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	current := loadModuleList()
	if len(current) == 0 {
		return nil
	}
	// First run: record every currently loaded module as the baseline.
	if _, haveBaseline := store.GetRaw("_kmod_baseline_set"); !haveBaseline {
		for _, name := range current {
			store.SetRaw("_kmod:"+name, "baseline")
		}
		store.SetRaw("_kmod_baseline_set", "true")
		return nil
	}
	var findings []alert.Finding
	for _, name := range current {
		key := "_kmod:" + name
		if _, known := store.GetRaw(key); known {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "kernel_module",
			Message:  fmt.Sprintf("New kernel module loaded after baseline: %s", name),
			Details:  "This module was not present when CSM baseline was set. Verify it is legitimate.",
		})
		// Remember the module so the same alert doesn't fire every cycle.
		store.SetRaw(key, "new")
	}
	return findings
}
// loadModuleList returns the names of currently loaded kernel modules by
// reading /proc/modules (the first whitespace-separated field per line).
// Returns nil when the file can't be opened.
func loadModuleList() []string {
	f, err := osFS.Open("/proc/modules")
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	var names []string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		if cols := strings.Fields(sc.Text()); len(cols) > 0 {
			names = append(names, cols[0])
		}
	}
	return names
}
// CheckRPMIntegrity verifies critical system binaries haven't been modified.
// Only a small set of security-critical packages is checked. RHEL-family
// systems use rpm -V; Debian-family systems use debsums/dpkg --verify.
// Unknown platforms are skipped.
func CheckRPMIntegrity(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	plat := platform.Detect()
	if plat.IsRHELFamily() {
		return checkRPMPackageIntegrity(rpmCriticalPackages)
	}
	if plat.IsDebianFamily() {
		return checkDebianPackageIntegrity(debianCriticalPackages)
	}
	return nil
}
// rpmCriticalPackages lists the RHEL-family packages whose files are
// verified by checkRPMPackageIntegrity (auth, privilege, and core tooling).
var rpmCriticalPackages = []string{
	"openssh-server",
	"shadow-utils",
	"sudo",
	"coreutils",
	"util-linux",
	"passwd",
}

// debianCriticalPackages is the Debian/Ubuntu counterpart consumed by
// checkDebianPackageIntegrity.
var debianCriticalPackages = []string{
	"openssh-server",
	"passwd",
	"sudo",
	"coreutils",
	"util-linux",
	"login",
}
// checkRPMPackageIntegrity runs `rpm -V` on each package and reports
// security-critical binaries whose size (S) or checksum (5) no longer
// matches the package manifest.
func checkRPMPackageIntegrity(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// rpm -V exits non-zero when it finds problems; treat that as
		// "findings present" rather than command failure.
		out, err := runCmdAllowNonZero("rpm", "-V", pkg)
		if err != nil || out == nil {
			continue
		}
		output := strings.TrimSpace(string(out))
		if output == "" {
			continue
		}
		// Parse rpm -V output: each line starts with a 9-char flag column
		// (S=size, 5=md5, T=mtime, M=mode, ...) followed by an optional
		// attribute marker and the file path.
		for _, line := range strings.Split(output, "\n") {
			if len(line) < 9 {
				continue
			}
			flags := line[:9]
			file := strings.TrimSpace(line[9:])
			// Skip config files (c) and documentation (d)
			if strings.Contains(line, " c ") || strings.Contains(line, " d ") {
				continue
			}
			// Only size (S) or checksum (5) changes are alert-worthy here.
			if !strings.Contains(flags, "S") && !strings.Contains(flags, "5") {
				continue
			}
			// Consistency fix: use the shared isCriticalSystemPath filter
			// (same prefixes as before) instead of a duplicated inline
			// list, so the rpm and dpkg backends report the same scope.
			if !isCriticalSystemPath(file) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "rpm_integrity",
				Message:  fmt.Sprintf("Modified system binary: %s (package: %s)", file, pkg),
				Details:  fmt.Sprintf("RPM verification flags: %s", flags),
			})
		}
	}
	return findings
}
// checkDebianPackageIntegrity verifies Debian/Ubuntu packages using debsums
// (from the debsums package) when available, falling back to dpkg --verify
// (built into dpkg and always present).
func checkDebianPackageIntegrity(packages []string) []alert.Finding {
	// debsums is the Debian analogue of `rpm -V`: it reports files whose
	// md5 no longer matches the sum shipped by the package.
	if _, lookErr := cmdExec.LookPath("debsums"); lookErr != nil {
		return checkDpkgVerify(packages)
	}
	return checkDebsums(packages)
}
// checkDebsums reports security-critical files that `debsums -c` flags as
// differing from the md5 manifest of their package.
func checkDebsums(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// debsums -c exits 2 when it finds mismatches; treat as findings.
		out, err := runCmdAllowNonZero("debsums", "-c", pkg)
		if err != nil || out == nil {
			continue
		}
		for _, raw := range strings.Split(strings.TrimSpace(string(out)), "\n") {
			path := strings.TrimSpace(raw)
			if path == "" || !isCriticalSystemPath(path) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "dpkg_integrity",
				Message:  fmt.Sprintf("Modified system binary: %s (package: %s)", path, pkg),
				Details:  "debsums reported md5 mismatch against the package manifest.",
			})
		}
	}
	return findings
}
// checkDpkgVerify is the fallback integrity backend using dpkg --verify,
// which is always present. It reports security-critical binaries with an
// md5 (5) or size (S) mismatch.
func checkDpkgVerify(packages []string) []alert.Finding {
	var findings []alert.Finding
	for _, pkg := range packages {
		// dpkg --verify exits 1 when it finds mismatches; treat as findings.
		out, err := runCmdAllowNonZero("dpkg", "--verify", pkg)
		if err != nil || out == nil {
			continue
		}
		// Output lines look like "??5?????? /usr/bin/passwd": a 9-char
		// flag column where '5' means md5 mismatch, then the file path.
		for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
			if len(line) < 10 {
				continue
			}
			flags, file := line[:9], strings.TrimSpace(line[9:])
			// Config files are marked with 'c' after the flag column.
			if strings.Contains(line, " c ") {
				continue
			}
			if !strings.Contains(flags, "5") && !strings.Contains(flags, "S") {
				continue
			}
			if !isCriticalSystemPath(file) {
				continue
			}
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "dpkg_integrity",
				Message:  fmt.Sprintf("Modified system binary: %s (package: %s)", file, pkg),
				Details:  fmt.Sprintf("dpkg --verify flags: %s", flags),
			})
		}
	}
	return findings
}
// isCriticalSystemPath returns true for binaries in standard executable
// directories. The rpm and dpkg backends share this filter so both report
// the same scope of findings.
func isCriticalSystemPath(path string) bool {
	for _, prefix := range []string{"/usr/bin/", "/usr/sbin/", "/bin/", "/sbin/"} {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}
// CheckMySQLUsers queries for MySQL users holding SUPER privilege that
// aren't standard system accounts, and alerts when that set changes
// between runs. The first observation only establishes the baseline.
func CheckMySQLUsers(ctx context.Context, _ *config.Config, store *state.Store) []alert.Finding {
	out, err := runCmd("mysql", "-N", "-B", "-e",
		"SELECT user, host FROM mysql.user WHERE Super_priv='Y' AND user NOT IN ('root','mysql.session','mysql.sys','mysql.infoschema','debian-sys-maint')")
	if err != nil || out == nil {
		return nil
	}
	superUsers := strings.TrimSpace(string(out))
	if superUsers == "" {
		return nil
	}
	const key = "_mysql_super_users"
	current := hashBytes(out)
	var findings []alert.Finding
	// Alert only on a change relative to the previously stored hash.
	if previous, ok := store.GetRaw(key); ok && previous != current {
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "mysql_superuser",
			Message:  "MySQL superuser accounts changed",
			Details:  fmt.Sprintf("Current superusers:\n%s", superUsers),
		})
	}
	store.SetRaw(key, current)
	return findings
}
// CheckGroupWritablePHP scans for PHP files that are group-writable
// where the group is the web server (nobody/www-data). Such files let
// webshells persist because the web server itself can rewrite them.
func CheckGroupWritablePHP(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	webGIDs := getWebServerGIDs()
	if len(webGIDs) == 0 {
		// No web-server groups resolved; nothing to match against.
		return nil
	}
	var findings []alert.Finding
	homeEntries, _ := osFS.ReadDir("/home")
	for _, entry := range homeEntries {
		if !entry.IsDir() {
			continue
		}
		docRoot := fmt.Sprintf("/home/%s/public_html", entry.Name())
		scanGroupWritablePHP(docRoot, 4, webGIDs, &findings)
	}
	return findings
}
// scanGroupWritablePHP walks dir up to maxDepth levels deep and appends
// a finding for every *.php file that is both group-writable and owned
// by one of the web-server groups in webGIDs.
func scanGroupWritablePHP(dir string, maxDepth int, webGIDs map[uint32]bool, findings *[]alert.Finding) {
	if maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, ent := range entries {
		base := ent.Name()
		path := dir + "/" + base
		if ent.IsDir() {
			switch base {
			case "cache", "node_modules", "vendor":
				// Known large/safe dirs - skip.
			default:
				scanGroupWritablePHP(path, maxDepth-1, webGIDs, findings)
			}
			continue
		}
		if !strings.HasSuffix(strings.ToLower(base), ".php") {
			continue
		}
		fi, err := ent.Info()
		if err != nil {
			continue
		}
		if fi.Mode()&0020 == 0 {
			continue // not group-writable
		}
		// Group ownership comes from the raw stat structure.
		st, ok := fi.Sys().(*syscall.Stat_t)
		if !ok || !webGIDs[st.Gid] {
			continue
		}
		*findings = append(*findings, alert.Finding{
			Severity: alert.High,
			Check:    "group_writable_php",
			Message:  fmt.Sprintf("Web-server group-writable PHP: %s", path),
			Details:  fmt.Sprintf("Mode: %s, GID: %d", fi.Mode(), st.Gid),
		})
	}
}
// getWebServerGIDs parses /etc/group and returns the GIDs of groups
// commonly used by web servers (nobody, www-data, apache, www).
// Returns an empty map when /etc/group is unreadable.
func getWebServerGIDs() map[uint32]bool {
	gids := make(map[uint32]bool)
	data, err := osFS.ReadFile("/etc/group")
	if err != nil {
		return gids
	}
	for _, line := range strings.Split(string(data), "\n") {
		fields := strings.Split(line, ":")
		if len(fields) < 3 {
			continue
		}
		name := fields[0]
		if name == "nobody" || name == "www-data" || name == "apache" || name == "www" {
			var gid uint32
			// Check the Sscanf result: previously it was ignored, so a
			// malformed GID field silently left gid at 0 and mapped
			// root's group (GID 0) as a web-server group, producing
			// false positives on root-group files.
			if n, err := fmt.Sscanf(fields[2], "%d", &gid); err != nil || n != 1 {
				continue
			}
			gids[gid] = true
		}
	}
	return gids
}
package checks
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/store"
)
// ThreatDB is a local IP reputation database built from:
// 1. CSM's own block history (permanent)
// 2. Public threat intelligence feeds (updated daily)
// 3. AbuseIPDB as fallback for unknown IPs
//
// All mutable fields are guarded by mu. The exported counter fields
// are snapshots for the WebUI, refreshed by the load/update paths.
type ThreatDB struct {
mu sync.RWMutex
badIPs map[string]string // ip -> source/reason ("permanent-blocklist", a feed name, or an AddPermanent reason)
badNets []*net.IPNet // CIDR ranges from feeds
whitelist map[string]bool // IPs to never flag
whitelistMeta map[string]*whitelistEntry // expiry metadata (zero ExpiresAt = permanent)
lastUpdate time.Time // last feed refresh; also persisted to the last_update file
dbPath string // on-disk cache directory (statePath/threat_db)
// Stats for WebUI
PermanentCount int // entries loaded from the permanent blocklist
FeedIPCount int // single-IP entries loaded from feeds
FeedNetCount int // CIDR entries loaded from feeds
LastFeedUpdate time.Time // timestamp of the last feed refresh
LastUpdated time.Time // tracks when feeds were last successfully loaded
}
// Process-wide singleton, created exactly once by InitThreatDB.
var (
globalThreatDB *ThreatDB
threatDBOnce sync.Once
)
// Minimum expected entries per feed - alerts if feed returns less (corrupted/down).
// Used by UpdateFeeds to reject truncated/partial downloads.
var feedMinEntries = map[string]int{
"spamhaus-drop": 50,
"spamhaus-edrop": 10,
"blocklist-de": 1000,
"cins-army": 5000,
}
// Free public threat intelligence feeds, fetched daily by UpdateFeeds
// and cached on disk under dbPath as <name>.txt.
var threatFeeds = []struct {
name string
url string
}{
{"spamhaus-drop", "https://www.spamhaus.org/drop/drop.txt"},
{"spamhaus-edrop", "https://www.spamhaus.org/drop/edrop.txt"},
{"blocklist-de", "https://lists.blocklist.de/lists/all.txt"},
{"cins-army", "https://cinsscore.com/list/ci-badguys.txt"},
}
// InitThreatDB initializes the global threat database exactly once;
// subsequent calls return the already-built instance.
func InitThreatDB(statePath string, whitelistIPs []string) *ThreatDB {
	threatDBOnce.Do(func() {
		whitelist := make(map[string]bool, len(whitelistIPs))
		for _, ip := range whitelistIPs {
			whitelist[ip] = true
		}
		tdb := &ThreatDB{
			badIPs:    make(map[string]string),
			whitelist: whitelist,
			dbPath:    filepath.Join(statePath, "threat_db"),
		}
		_ = os.MkdirAll(tdb.dbPath, 0700)
		// Load order matters: permanent blocks first, then the
		// whitelist (which evicts bad entries), then cached feeds.
		tdb.loadPermanentBlocklist()
		tdb.loadPersistedWhitelist()
		tdb.loadFeedCache()
		globalThreatDB = tdb
	})
	return globalThreatDB
}
// GetThreatDB returns the global threat database, or nil if
// InitThreatDB has not been called yet.
func GetThreatDB() *ThreatDB {
return globalThreatDB
}
// Lookup checks if an IP is in the local threat database.
// Returns (source, true) if found, ("", false) if unknown.
// Whitelisted IPs always return false.
func (db *ThreatDB) Lookup(ip string) (string, bool) {
	db.mu.RLock()
	defer db.mu.RUnlock()
	// Whitelisted IPs are never reported, even if a feed lists them.
	if db.whitelist[ip] {
		return "", false
	}
	// Exact-IP hit wins and carries its original source tag.
	if reason, hit := db.badIPs[ip]; hit {
		return reason, true
	}
	// Fall back to the CIDR ranges (supports both IPv4 and IPv6).
	addr := net.ParseIP(ip)
	if addr == nil {
		return "", false
	}
	for _, block := range db.badNets {
		if block.Contains(addr) {
			return "threat-feed-cidr", true
		}
	}
	return "", false
}
// AddPermanent adds an IP to the permanent local blocklist.
// Called when CSM auto-blocks an IP - persists across restarts.
func (db *ThreatDB) AddPermanent(ip, reason string) {
	db.mu.Lock()
	_, known := db.badIPs[ip]
	db.badIPs[ip] = reason
	db.mu.Unlock()
	if known {
		// Already persisted once; avoid duplicate on-disk entries.
		return
	}
	sdb := store.Global()
	if sdb != nil {
		_ = sdb.AddPermanentBlock(ip, reason)
		return
	}
	// Fallback: flat-file permanent.txt.
	file, err := os.OpenFile(filepath.Join(db.dbPath, "permanent.txt"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return
	}
	defer func() { _ = file.Close() }()
	fmt.Fprintf(file, "%s # %s [%s]\n", ip, reason, time.Now().Format("2006-01-02"))
}
// RemovePermanent removes an IP from the permanent blocklist and in-memory DB.
func (db *ThreatDB) RemovePermanent(ip string) {
	db.mu.Lock()
	delete(db.badIPs, ip)
	db.mu.Unlock()
	sdb := store.Global()
	if sdb != nil {
		_ = sdb.RemovePermanentBlock(ip)
		return
	}
	// Fallback: rewrite permanent.txt without this IP, preserving
	// comments and blank lines.
	path := filepath.Join(db.dbPath, "permanent.txt")
	data, err := osFS.ReadFile(path)
	if err != nil {
		return
	}
	var retained []string
	for _, raw := range strings.Split(string(data), "\n") {
		trimmed := strings.TrimSpace(raw)
		if trimmed != "" && !strings.HasPrefix(trimmed, "#") {
			if fields := strings.Fields(trimmed); len(fields) > 0 && fields[0] == ip {
				continue // drop the matching entry
			}
		}
		retained = append(retained, raw)
	}
	// Atomic replace: write .tmp, then rename over the original.
	tmp := path + ".tmp"
	_ = os.WriteFile(tmp, []byte(strings.Join(retained, "\n")+"\n"), 0600)
	_ = os.Rename(tmp, path)
}
// whitelistEntry tracks an IP with optional expiry.
type whitelistEntry struct {
ExpiresAt time.Time // zero = permanent
}
// AddWhitelist adds an IP to the permanent whitelist.
func (db *ThreatDB) AddWhitelist(ip string) {
db.addWhitelistEntry(ip, time.Time{}) // zero time = never expires
}
// TempWhitelist adds an IP to the whitelist with a TTL, after which
// PruneExpiredWhitelist will drop it.
func (db *ThreatDB) TempWhitelist(ip string, ttl time.Duration) {
db.addWhitelistEntry(ip, time.Now().Add(ttl))
}
// addWhitelistEntry records ip as whitelisted; a zero expiresAt marks
// the entry permanent. The IP is also evicted from the bad-IP set and
// persisted to the bbolt store when available (flat file otherwise).
func (db *ThreatDB) addWhitelistEntry(ip string, expiresAt time.Time) {
	db.mu.Lock()
	if db.whitelistMeta == nil {
		db.whitelistMeta = make(map[string]*whitelistEntry)
	}
	db.whitelist[ip] = true
	db.whitelistMeta[ip] = &whitelistEntry{ExpiresAt: expiresAt}
	delete(db.badIPs, ip)
	db.mu.Unlock()
	sdb := store.Global()
	if sdb == nil {
		db.saveWhitelistFile()
		return
	}
	_ = sdb.AddWhitelistEntry(ip, expiresAt, expiresAt.IsZero())
}
// RemoveWhitelist removes an IP from the whitelist, both in memory and
// from persistent storage (bbolt store when available, flat file otherwise).
func (db *ThreatDB) RemoveWhitelist(ip string) {
	db.mu.Lock()
	delete(db.whitelist, ip)
	delete(db.whitelistMeta, ip)
	db.mu.Unlock()
	sdb := store.Global()
	if sdb == nil {
		db.saveWhitelistFile()
		return
	}
	_ = sdb.RemoveWhitelistEntry(ip)
}
// PruneExpiredWhitelist removes expired temporary whitelist entries and
// returns how many were dropped. Called periodically from the daemon
// heartbeat.
func (db *ThreatDB) PruneExpiredWhitelist() int {
	now := time.Now()
	var expired []string
	db.mu.Lock()
	for ip, entry := range db.whitelistMeta {
		// Permanent (zero expiry) or still-valid entries are kept.
		if entry.ExpiresAt.IsZero() || !now.After(entry.ExpiresAt) {
			continue
		}
		delete(db.whitelist, ip)
		delete(db.whitelistMeta, ip)
		expired = append(expired, ip)
	}
	db.mu.Unlock()
	pruned := len(expired)
	if pruned == 0 {
		return 0
	}
	if sdb := store.Global(); sdb != nil {
		sdb.PruneExpiredWhitelist()
	} else {
		db.saveWhitelistFile()
	}
	fmt.Fprintf(os.Stderr, "[%s] Pruned %d expired whitelist entries\n",
		time.Now().Format("2006-01-02 15:04:05"), pruned)
	return pruned
}
// WhitelistIP describes one whitelisted IP with its expiry info, as
// returned by WhitelistedIPs for the WebUI.
type WhitelistIP struct {
IP string `json:"ip"`
ExpiresAt *time.Time `json:"expires_at,omitempty"` // nil = permanent
Permanent bool `json:"permanent"`
}
// WhitelistedIPs returns every whitelisted IP with its expiry info,
// sorted lexically by IP for stable WebUI output.
func (db *ThreatDB) WhitelistedIPs() []WhitelistIP {
	db.mu.RLock()
	defer db.mu.RUnlock()
	ips := make([]string, 0, len(db.whitelist))
	for ip := range db.whitelist {
		ips = append(ips, ip)
	}
	sort.Strings(ips)
	out := make([]WhitelistIP, 0, len(ips))
	for _, ip := range ips {
		item := WhitelistIP{IP: ip, Permanent: true}
		// Missing metadata is treated as permanent.
		if meta := db.whitelistMeta[ip]; meta != nil && !meta.ExpiresAt.IsZero() {
			exp := meta.ExpiresAt
			item.ExpiresAt = &exp
			item.Permanent = false
		}
		out = append(out, item)
	}
	return out
}
// saveWhitelistFile atomically rewrites whitelist.txt from the
// in-memory whitelist (write to .tmp, then rename). Lines are sorted
// so the file diffs cleanly between runs.
func (db *ThreatDB) saveWhitelistFile() {
	path := filepath.Join(db.dbPath, "whitelist.txt")
	db.mu.RLock()
	lines := make([]string, 0, len(db.whitelist))
	for ip := range db.whitelist {
		meta := db.whitelistMeta[ip]
		if meta != nil && !meta.ExpiresAt.IsZero() {
			lines = append(lines, fmt.Sprintf("%s expires=%s", ip, meta.ExpiresAt.Format(time.RFC3339)))
			continue
		}
		lines = append(lines, fmt.Sprintf("%s permanent", ip))
	}
	db.mu.RUnlock()
	sort.Strings(lines)
	tmp := path + ".tmp"
	_ = os.WriteFile(tmp, []byte(strings.Join(lines, "\n")+"\n"), 0600)
	_ = os.Rename(tmp, path)
}
// loadPersistedWhitelist loads IPs from the bbolt store (if available)
// or from the flat-file whitelist.txt. Expired temporary entries are
// skipped on load; in the flat-file path the file is rewritten if any
// were dropped. Each loaded IP is also evicted from badIPs so a
// previously blocked address cannot stay flagged while whitelisted.
func (db *ThreatDB) loadPersistedWhitelist() {
if db.whitelistMeta == nil {
db.whitelistMeta = make(map[string]*whitelistEntry)
}
if sdb := store.Global(); sdb != nil {
entries := sdb.ListWhitelist()
now := time.Now()
for _, e := range entries {
// Skip expired entries
if !e.Permanent && !e.ExpiresAt.IsZero() && now.After(e.ExpiresAt) {
continue
}
db.whitelist[e.IP] = true
db.whitelistMeta[e.IP] = &whitelistEntry{ExpiresAt: e.ExpiresAt}
delete(db.badIPs, e.IP)
}
return
}
// Fallback: flat-file whitelist.txt.
// Line format: "<ip>" or "<ip> permanent" or "<ip> expires=<RFC3339>".
path := filepath.Join(db.dbPath, "whitelist.txt")
f, err := osFS.Open(path)
if err != nil {
return
}
defer func() { _ = f.Close() }()
now := time.Now()
needsRewrite := false
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
fields := strings.Fields(line)
if len(fields) == 0 {
continue
}
ip := fields[0]
// Malformed lines (non-IP first field) are silently ignored.
if net.ParseIP(ip) == nil {
continue
}
entry := &whitelistEntry{}
// Parse "expires=2026-03-28T19:00:00Z" if present.
// NOTE: this loop variable f shadows the file handle f above;
// harmless here since the file is only touched via the scanner.
for _, f := range fields[1:] {
if strings.HasPrefix(f, "expires=") {
if t, err := time.Parse(time.RFC3339, f[8:]); err == nil {
entry.ExpiresAt = t
}
}
}
// Skip expired entries
if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
needsRewrite = true
continue
}
db.whitelist[ip] = true
db.whitelistMeta[ip] = entry
delete(db.badIPs, ip)
}
if needsRewrite {
// Synchronous: the load path runs once at startup, so the cost
// is negligible, and a fire-and-forget goroutine would race the
// daemon's shutdown (potentially leaving a `.tmp` file behind or
// writing a half-serialized whitelist.txt if the process is
// killed before the rewrite lands).
db.saveWhitelistFile()
}
}
// Count returns the total number of entries in the database:
// single IPs plus CIDR ranges. Whitelist entries are not counted.
func (db *ThreatDB) Count() int {
db.mu.RLock()
defer db.mu.RUnlock()
return len(db.badIPs) + len(db.badNets)
}
// Stats returns statistics for the WebUI dashboard.
func (db *ThreatDB) Stats() map[string]interface{} {
	db.mu.RLock()
	defer db.mu.RUnlock()
	stats := make(map[string]interface{}, 6)
	stats["permanent_ips"] = db.PermanentCount
	stats["feed_ips"] = db.FeedIPCount
	stats["feed_cidrs"] = db.FeedNetCount
	stats["total"] = len(db.badIPs) + len(db.badNets)
	stats["whitelist"] = len(db.whitelist)
	stats["last_update"] = db.LastFeedUpdate.Format(time.RFC3339)
	return stats
}
// FeedsStale returns true if threat feeds have not been updated in over
// 7 days. LastUpdated is preferred; the legacy lastUpdate field is the
// fallback for state written by older versions.
func (db *ThreatDB) FeedsStale() bool {
	db.mu.RLock()
	defer db.mu.RUnlock()
	const staleAfter = 7 * 24 * time.Hour
	ref := db.LastUpdated
	if ref.IsZero() {
		if db.lastUpdate.IsZero() {
			return true // feeds never fetched
		}
		ref = db.lastUpdate
	}
	return time.Since(ref) > staleAfter
}
// UpdateFeeds downloads fresh threat intelligence feeds.
// Downloads happen outside the lock; data is swapped under the lock so
// lookups never block on network I/O.
//
// Fixes over the previous revision:
//   - entries belonging to a feed whose download failed (or failed
//     validation) are now actually retained, matching the "keeping
//     cached version" log line; previously the swap deleted every
//     non-"permanent-blocklist" entry regardless;
//   - IPs added via AddPermanent carry an arbitrary reason string as
//     their source, not "permanent-blocklist"; they are no longer
//     purged on every feed refresh;
//   - the last_update file is written from a local timestamp instead
//     of re-reading db.lastUpdate outside the lock.
func (db *ThreatDB) UpdateFeeds() error {
	db.mu.RLock()
	lastUpdate := db.lastUpdate
	db.mu.RUnlock()
	// Only update once per day.
	if time.Since(lastUpdate) < 20*time.Hour {
		return nil
	}
	client := &http.Client{Timeout: 30 * time.Second}
	// Download all feeds OUTSIDE the lock.
	newIPs := make(map[string]string)
	var newNets []*net.IPNet
	refreshed := make(map[string]bool, len(threatFeeds)) // feeds that downloaded AND validated
	totalIPs := 0
	totalNets := 0
	for _, feed := range threatFeeds {
		ips, nets, err := downloadFeed(client, feed.url, feed.name)
		if err != nil {
			fmt.Fprintf(os.Stderr, "threatdb: error downloading %s: %v\n", feed.name, err)
			continue
		}
		// Validate feed - reject partial downloads to avoid losing good data.
		minExpected := feedMinEntries[feed.name]
		if minExpected > 0 && len(ips)+len(nets) < minExpected {
			fmt.Fprintf(os.Stderr, "threatdb: WARNING %s returned only %d entries (expected >%d), keeping cached version\n",
				feed.name, len(ips)+len(nets), minExpected)
			continue // keep previous cached data for this feed
		}
		refreshed[feed.name] = true
		for _, ip := range ips {
			newIPs[ip] = feed.name
		}
		newNets = append(newNets, nets...)
		totalIPs += len(ips)
		totalNets += len(nets)
		// Cache to disk so a restart does not lose the feed.
		cachePath := filepath.Join(db.dbPath, feed.name+".txt")
		saveLines(cachePath, ips)
	}
	// Swap data UNDER the lock - fast, no I/O.
	db.mu.Lock()
	// Drop only entries belonging to feeds we successfully refreshed.
	// Permanent blocks (any non-feed source) and cached entries of
	// failed feeds survive. Deleting during range is safe in Go.
	for ip, source := range db.badIPs {
		if refreshed[source] {
			delete(db.badIPs, ip)
		}
	}
	// Add fresh feed data without clobbering existing entries (e.g. an
	// IP that is also permanently blocked keeps its original source).
	for ip, source := range newIPs {
		if _, exists := db.badIPs[ip]; !exists {
			db.badIPs[ip] = source
		}
	}
	// Replace CIDR ranges entirely (fixes accumulation bug).
	db.badNets = newNets
	now := time.Now()
	db.lastUpdate = now
	db.FeedIPCount = totalIPs
	db.FeedNetCount = totalNets
	db.LastFeedUpdate = now
	db.LastUpdated = now
	db.mu.Unlock()
	// Persist the refresh timestamp for loadFeedCache.
	_ = os.WriteFile(filepath.Join(db.dbPath, "last_update"),
		[]byte(now.Format(time.RFC3339)), 0600)
	fmt.Fprintf(os.Stderr, "threatdb: updated %d IPs + %d CIDR ranges from %d feeds\n",
		totalIPs, totalNets, len(threatFeeds))
	return nil
}
// loadPermanentBlocklist loads the permanent blocklist from the bbolt
// store (if available) or from the flat-file permanent.txt, tagging
// each entry with the "permanent-blocklist" source.
func (db *ThreatDB) loadPermanentBlocklist() {
	if sdb := store.Global(); sdb != nil {
		blocks := sdb.AllPermanentBlocks()
		for _, blk := range blocks {
			db.badIPs[blk.IP] = "permanent-blocklist"
		}
		db.PermanentCount = len(blocks)
		return
	}
	// Fallback: flat-file permanent.txt.
	path := filepath.Join(db.dbPath, "permanent.txt")
	file, err := osFS.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = file.Close() }()
	unique := make(map[string]bool)
	sc := bufio.NewScanner(file)
	for sc.Scan() {
		text := strings.TrimSpace(sc.Text())
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		candidate := strings.Fields(text)[0]
		// Accept both IPv4 and IPv6 literals; skip repeats.
		if unique[candidate] || net.ParseIP(candidate) == nil {
			continue
		}
		db.badIPs[candidate] = "permanent-blocklist"
		unique[candidate] = true
	}
	db.PermanentCount = len(unique)
	// Compact the file if it has duplicates (rewrite with unique entries).
	if db.PermanentCount > 0 {
		compactPermanentFile(path, unique)
	}
}
// compactPermanentFile rewrites the permanent blocklist with unique
// entries only, keeping comments/blank lines and the first occurrence
// of each IP. The rewrite is skipped when duplication is minimal.
func compactPermanentFile(path string, uniqueIPs map[string]bool) {
	data, err := osFS.ReadFile(path)
	if err != nil {
		return
	}
	lines := strings.Split(string(data), "\n")
	// Not worth compacting unless the file clearly exceeds the unique
	// entry count (5-line slack for comments/blank lines).
	if len(lines) <= len(uniqueIPs)+5 {
		return
	}
	written := make(map[string]bool)
	keep := make([]string, 0, len(uniqueIPs))
	for _, raw := range lines {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" || strings.HasPrefix(trimmed, "#") {
			keep = append(keep, raw)
			continue
		}
		ip := strings.Fields(trimmed)[0]
		if written[ip] {
			continue
		}
		written[ip] = true
		keep = append(keep, raw)
	}
	// Atomic replace: write .tmp, then rename over the original.
	tmp := path + ".tmp"
	_ = os.WriteFile(tmp, []byte(strings.Join(keep, "\n")+"\n"), 0600)
	_ = os.Rename(tmp, path)
}
// loadFeedCache loads cached feed data from disk: the last_update
// timestamp plus the per-feed <name>.txt IP lists written by
// UpdateFeeds. CIDR ranges are not cached, so badNets stays empty
// until the first live refresh.
func (db *ThreatDB) loadFeedCache() {
	data, err := osFS.ReadFile(filepath.Join(db.dbPath, "last_update"))
	if err == nil {
		if t, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data))); err == nil {
			db.lastUpdate = t
			db.LastFeedUpdate = t
			db.LastUpdated = t
		}
	}
	for _, feed := range threatFeeds {
		cachePath := filepath.Join(db.dbPath, feed.name+".txt")
		lines := loadLines(cachePath)
		loaded := 0
		for _, ip := range lines {
			// Do not clobber entries loaded earlier - notably permanent
			// blocks: previously a cached feed entry overwrote the
			// "permanent-blocklist" source, and the next UpdateFeeds
			// swap then deleted the permanently blocked IP entirely.
			if _, exists := db.badIPs[ip]; !exists {
				db.badIPs[ip] = feed.name
				loaded++
			}
		}
		db.FeedIPCount += loaded
	}
	// Warn on startup if feeds are stale
	if db.LastUpdated.IsZero() && db.FeedIPCount == 0 {
		fmt.Fprintf(os.Stderr, "threatdb: WARNING no threat feed data loaded, feeds have never been fetched\n")
	} else if !db.LastUpdated.IsZero() && time.Since(db.LastUpdated) > 7*24*time.Hour {
		fmt.Fprintf(os.Stderr, "threatdb: WARNING threat feeds are stale (last updated %s, %d days ago)\n",
			db.LastUpdated.Format("2006-01-02"), int(time.Since(db.LastUpdated).Hours()/24))
	}
}
// downloadFeed fetches one feed URL and parses its body into single IPs
// and CIDR ranges. Full-line comments ("#" or ";") and inline trailing
// comments are ignored, and the body is capped at 10 MiB.
func downloadFeed(client *http.Client, url, name string) ([]string, []*net.IPNet, error) {
	resp, err := client.Get(url)
	if err != nil {
		return nil, nil, err
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != 200 {
		return nil, nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	var (
		ips  []string
		nets []*net.IPNet
	)
	sc := bufio.NewScanner(io.LimitReader(resp.Body, 10*1024*1024))
	for sc.Scan() {
		entry := strings.TrimSpace(sc.Text())
		if entry == "" || strings.HasPrefix(entry, "#") || strings.HasPrefix(entry, ";") {
			continue
		}
		// Strip trailing inline comments, e.g. "1.2.3.0/24 ; SBL123".
		if pos := strings.IndexAny(entry, ";#"); pos > 0 {
			entry = strings.TrimSpace(entry[:pos])
		}
		if strings.Contains(entry, "/") {
			if _, cidr, err := net.ParseCIDR(entry); err == nil {
				nets = append(nets, cidr)
			}
			continue
		}
		candidate := strings.Fields(entry)[0]
		if net.ParseIP(candidate) != nil {
			ips = append(ips, candidate)
		}
	}
	return ips, nets, nil
}
func saveLines(path string, lines []string) {
sort.Strings(lines) // sorted for diffing
// #nosec G304 -- path is filepath.Join under operator-configured statePath.
f, err := os.Create(path)
if err != nil {
return
}
defer func() { _ = f.Close() }()
w := bufio.NewWriter(f)
for _, line := range lines {
fmt.Fprintln(w, line)
}
_ = w.Flush()
}
// loadLines reads path and returns its non-empty, non-comment lines,
// trimmed of surrounding whitespace. Returns nil on any open error.
func loadLines(path string) []string {
	file, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = file.Close() }()
	var out []string
	sc := bufio.NewScanner(file)
	for sc.Scan() {
		entry := strings.TrimSpace(sc.Text())
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}
		out = append(out, entry)
	}
	return out
}
package checks
import (
"net"
"net/url"
"strings"
)
// URL reputation — attack-indicator based classifier for external
// <script src="..."> URLs embedded in WordPress content.
//
// PHILOSOPHY
//
// Earlier versions of this package classified script sources against a
// hardcoded allowlist of "known safe" domains (Google Tag Manager,
// Cloudflare CDN, HubSpot, Stripe, etc.). In practice the allowlist is
// unmaintainable: every new widget service (OneTrust, Issuu, regional
// video embeds, tax-form widgets, etc.) adds an entry and operators
// still see HIGH-severity findings for legitimate third-party embeds.
//
// This file takes the inverse approach: rather than asking "is this
// domain on my list of safe services?", it asks "does this URL show
// attacker-characteristic markers?". A <script src> fires a finding
// only when at least one attack indicator is present:
//
// - the host is a raw IP address (attackers dodge domain reputation);
// - the host TLD is on a well-known abused-TLD list (.tk, .ml, .ga,
// .cf, .gq free-abuse; .top, .icu, .click, .pw and similar cheap
// gTLDs per Spamhaus recurring bad-TLD reports);
// - the host is on the existing short known-bad-exfil list
// (Cloudflare Workers free tier, Pastebin raw, GitHub Gist raw —
// legitimate content rarely loads from these hosts);
// - the scheme is plaintext HTTP (an external JS loader without TLS
// is a plaintext MITM target regardless of the destination);
// - the host is empty, contains no dot, or otherwise fails basic
// FQDN validation.
//
// The knownSafeDomains list is retained as a FAST-PATH tie-breaker: if
// the host matches a well-known service we return "not malicious"
// immediately, skipping further analysis. It is an optimization, not
// the primary filter. Unknown hosts on unremarkable TLDs (e.g.
// onetrust.com, issuu.com, trilulilu.ro, formular230.ro) pass because
// they have zero attack indicators — no allowlist growth needed.
//
// TRADE-OFF
//
// An attacker who hosts payload on a compromised mainstream domain
// (.com/.org/.net HTTPS, normal-looking path) is not caught. This gap
// existed under the prior allowlist too — it is a fundamental limit of
// URL-only classification. Closing it requires threat-intelligence
// correlation or content-based JS analysis, both of which are out of
// scope here. The defensive value is in raising the bar for casual
// injection, not in defeating sophisticated adversaries.
// abusedTLDs are top-level domains whose registrations are cheap,
// unverified, and overwhelmingly abused for phishing, malware, and SEO
// spam per recurring Spamhaus and KnowBe4 reports.
//
// The entry bar is high: a TLD is included only if (a) it has shown up
// in top-10 abuse rankings across multiple reporting years, AND
// (b) legitimate business usage is rare. Mixed-use TLDs with
// significant legitimate traffic (.xyz, .online, .site, .live, .space)
// are intentionally excluded to keep false-positive rates low.
//
// Entries are stored without the leading dot; comparison strips the
// leading dot from the observed TLD before lookup.
var abusedTLDs = map[string]bool{
	// Former Freenom TLDs — free registration, no verification.
	// Near-100% abuse rate; Freenom itself was shut down in 2023 but
	// the TLDs remain in DNS.
	"tk": true,
	"ml": true,
	"ga": true,
	"cf": true,
	"gq": true,
	// Cheap new gTLDs consistently in the Spamhaus top-abused list.
	"top":      true,
	"icu":      true,
	"click":    true,
	"pw":       true,
	"loan":     true,
	"work":     true,
	"download": true,
	// Spamhaus badlist recurring entries — legitimate usage is
	// essentially nonexistent at meaningful volume.
	"kim":     true,
	"gdn":     true,
	"stream":  true,
	"bid":     true,
	"racing":  true,
	"win":     true,
	"party":   true,
	"science": true,
	"trade":   true,
}

// knownBadExfilHosts are hosts where legitimate WordPress content
// essentially never loads JavaScript from, but attackers routinely do.
// Entries with a leading dot match any subdomain; entries without one
// match the exact host or any subdomain of it (dot-boundary suffix).
// The list is intentionally short — see knownSafeDomains for the
// inverse fast-path.
var knownBadExfilHosts = []string{
	// Cloudflare Workers free-tier subdomain (e.g. x.workers.dev).
	// Legitimate apps host on custom domains; free .workers.dev is a
	// common payload drop.
	".workers.dev",
	// Pastebin and GitHub Gist raw endpoints — legitimate sites do not
	// load JS from these paths.
	"pastebin.com",
	"gist.githubusercontent.com",
	// bit.ly and similar URL shorteners in a <script src> are a very
	// strong attack signal — no legitimate embed uses a shortener for
	// a JS asset.
	"bit.ly",
	"tinyurl.com",
	"is.gd",
	"cutt.ly",
}

// scriptSrcMaliciousReason classifies a <script src> URL by attack
// indicators. The first return value is true iff at least one indicator
// fires; the second return value is a short tag suitable for inclusion
// in a finding detail (e.g. "raw IP address", "abused TLD: .top",
// "plaintext HTTP").
//
// The function does not consult knownSafeDomains — callers should do
// that first as a fast path. Separating the two concerns keeps this
// function single-purpose and trivially testable.
func scriptSrcMaliciousReason(rawURL string) (bool, string) {
	// Protocol-relative URLs (//host/path) have to be normalised before
	// url.Parse can extract the host.
	normalised := rawURL
	if strings.HasPrefix(normalised, "//") {
		normalised = "https:" + normalised
	}
	u, err := url.Parse(normalised)
	if err != nil || u == nil {
		return true, "unparseable URL"
	}
	host := strings.ToLower(u.Hostname())
	if host == "" {
		return true, "empty host"
	}
	// Plaintext HTTP for an external script is a strong indicator.
	// We do not flag https:// nor protocol-relative (which inherits the
	// page's scheme — an https page yields https, an http page already
	// has other problems).
	if strings.EqualFold(u.Scheme, "http") && !strings.HasPrefix(rawURL, "//") {
		return true, "plaintext HTTP external script"
	}
	// Raw IP host (v4 or v6). net.ParseIP accepts both.
	if ip := net.ParseIP(host); ip != nil {
		return true, "raw IP address host"
	}
	// Known bad exfil hosts. Match the exact host or a subdomain on a
	// dot boundary. The previous bare HasSuffix check false-matched
	// unrelated domains that merely ended with an entry (e.g.
	// "notpastebin.com" matched "pastebin.com").
	for _, bad := range knownBadExfilHosts {
		bare := strings.TrimPrefix(bad, ".")
		if host == bare || strings.HasSuffix(host, "."+bare) {
			return true, "known-bad exfil host: " + bad
		}
	}
	// TLD analysis — require at least one dot and a recognisable TLD.
	// Hosts without a dot (e.g. bare "localhost" or "internalhost") do
	// not belong in external <script src>, so flag them.
	lastDot := strings.LastIndexByte(host, '.')
	if lastDot < 0 || lastDot == len(host)-1 {
		return true, "host without valid TLD"
	}
	tld := host[lastDot+1:]
	if abusedTLDs[tld] {
		return true, "abused TLD: ." + tld
	}
	return false, ""
}
// isAttackerScriptURL is the caller-facing predicate used by
// hasMaliciousExternalScript (in dbscan_filters.go). It combines the
// known-safe fast path with the attack-indicator classifier.
//
// The order matters: the fast path is checked first because it lets us
// short-circuit common legitimate widgets (Google Tag Manager,
// Cloudflare CDN, Stripe) without parsing the URL. Only unknown hosts
// are subjected to the attack-indicator analysis.
//
// Returns true when at least one attack indicator fires for a URL that
// is not on the known-safe list; the indicator tag itself is discarded.
func isAttackerScriptURL(rawURL string) bool {
if isSafeScriptDomain(rawURL) {
return false
}
bad, _ := scriptSrcMaliciousReason(rawURL)
return bad
}
package checks
import "bytes"
// .user.ini cPanel-managed signature detection.
//
// cPanel's MultiPHP INI Editor writes .user.ini files with a fixed
// four-line header that begins:
//
// ; cPanel-generated php ini directives, do not edit
// ; Manual editing of this file may result in unexpected behavior.
// ; To make changes to this file, use the cPanel MultiPHP INI Editor ...
// ; For more information, read our documentation ...
//
// When this header is present the values in the file reflect operator
// choices made through cPanel's UI (max_execution_time=0 for a backup
// importer, display_errors=On for a staging account, etc.). These are
// not attacker signals. The severity of findings on values in a
// cPanel-managed file should be reduced to informational.
//
// When the header is absent we cannot tell whether the site owner
// hand-edited the file or an attacker planted it. In that case we
// preserve the original severity.
//
// Detection rule (minimal and precise):
//
// The signature must appear on the FIRST non-blank line of the file.
//
// Reasoning: cPanel always writes the header at the top of the file
// and rewrites the whole file on every edit through the UI. An
// attacker appending content below keeps the header position intact.
// An attacker prepending content above pushes the header down; the
// file is then no longer a regular cPanel-managed file and the
// attacker's values take precedence anyway — we must NOT accept this
// as "managed" because doing so would let attackers suppress findings
// by inserting one of their own lines and then the real cPanel header.
// cpanelUserIniSignature is the exact string cPanel writes on the first
// comment line of a managed .user.ini. Case-sensitive: alternate
// capitalizations are rejected to avoid accepting forgeries.
const cpanelUserIniSignature = "cPanel-generated php ini directives"

// cpanelUserIniMaxLeadingBlanks caps how many blank lines may precede
// the signature. cPanel writes the signature as the very first line;
// the small tolerance exists only for files round-tripped through
// editors or FTP clients that add stray newlines/CRLFs. The cap also
// closes a forgery route: without a bound, an attacker could prepend a
// long run of blank lines plus the genuine header string and have the
// file classified as managed despite owning everything above it.
const cpanelUserIniMaxLeadingBlanks = 5

// isCpanelManagedUserIni reports whether data begins with the cPanel
// MultiPHP-managed .user.ini header. Empty or whitespace-only input
// returns false. Up to cpanelUserIniMaxLeadingBlanks leading blank
// lines are tolerated; beyond that the file is not considered
// cPanel-managed even if the signature appears later.
func isCpanelManagedUserIni(data []byte) bool {
	sig := []byte(cpanelUserIniSignature)
	blanks := 0
	for _, raw := range bytes.Split(data, []byte{'\n'}) {
		// Strip CR from CRLF line endings, then surrounding whitespace.
		line := bytes.TrimSpace(bytes.TrimRight(raw, "\r"))
		if len(line) > 0 {
			// The decision rests entirely on the first non-blank line.
			return bytes.Contains(line, sig)
		}
		blanks++
		if blanks > cpanelUserIniMaxLeadingBlanks {
			return false
		}
	}
	return false
}
package checks
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/state"
)
// CheckWAFStatus verifies that ModSecurity is loaded, the engine is in
// enforcement mode (not DetectionOnly), OWASP/Comodo rules are active,
// and rules are up to date. On cPanel hosts it additionally deploys
// virtual patches and reports per-domain WAF bypasses. Hosts without a
// web server return no findings.
func CheckWAFStatus(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
var findings []alert.Finding
info := platform.Detect()
// If there is no web server at all, WAF concerns don't apply to this host.
if info.WebServer == platform.WSNone {
return findings
}
modsecActive := modsecDetected(info)
if !modsecActive {
// Missing WAF is the most severe condition; all later checks
// depend on ModSecurity being present, so return immediately.
findings = append(findings, alert.Finding{
Severity: alert.Critical,
Check: "waf_status",
Message: "ModSecurity WAF is not active",
Details: wafInstallHint(info),
})
return findings // no point checking further
}
// --- Engine mode check ---
// DetectionOnly logs attacks without blocking them.
engineMode := checkEngineMode(info)
if engineMode == "detectiononly" {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_detection_only",
Message: "ModSecurity is in DetectionOnly mode - attacks are logged but NOT blocked",
Details: "SecRuleEngine is set to DetectionOnly. Change to 'On' for enforcement:\nWHM > Security Center > ModSecurity > Edit Global Directive",
})
}
// --- Rule vendor check ---
// cPanel-only: whmapi1 vendors. Plain hosts rely on file-system probing.
hasRules := false
if info.IsCPanel() {
out, _ := runCmd("whmapi1", "modsec_get_vendors")
if out != nil {
outStr := string(out)
// Match either vendor name in either capitalization; the
// whmapi1 output casing is not guaranteed.
if strings.Contains(outStr, "comodo") || strings.Contains(outStr, "owasp") ||
strings.Contains(outStr, "OWASP") || strings.Contains(outStr, "Comodo") {
hasRules = true
}
}
}
// File-system probe works on both cPanel and plain hosts.
ruleDirs := modsecRuleDirs(info)
if hasRuleArtifacts(ruleDirs) {
hasRules = true
}
if !hasRules {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_rules",
Message: "ModSecurity has no WAF rules loaded",
Details: wafRulesHint(info),
})
}
// --- Rule age check + auto-update ---
// staleAge is days since the newest rule file; 0 means fresh.
if hasRules {
staleAge := checkRuleAge(ruleDirs)
if staleAge > 0 {
// Attempt auto-update before alerting
updated := false
if info.IsCPanel() {
updated = autoUpdateWAFRules()
}
if updated {
// Re-check age after update
staleAge = checkRuleAge(ruleDirs)
}
// Only alert if the rules are still stale after the update attempt.
if staleAge > 0 {
findings = append(findings, alert.Finding{
Severity: alert.Warning,
Check: "waf_rules_stale",
Message: fmt.Sprintf("ModSecurity rules are %d days old - update recommended", staleAge),
Details: wafRulesStaleHint(info),
})
}
}
}
// --- Virtual patch deployment ---
// Only cPanel has the modsec user config dirs we write into.
if info.IsCPanel() {
deployVirtualPatches()
}
// --- Per-account WAF bypass check ---
// whmapi1-only, skip on non-cPanel hosts.
if info.IsCPanel() {
bypassed := checkPerAccountBypass()
for _, domain := range bypassed {
findings = append(findings, alert.Finding{
Severity: alert.High,
Check: "waf_bypass",
Message: fmt.Sprintf("ModSecurity disabled for domain: %s", domain),
Details: "This domain has ModSecurity bypassed. All web attacks pass through unfiltered.\nCheck: WHM > Security Center > ModSecurity > Domains",
})
}
}
return findings
}
// modsecDetected reports whether a ModSecurity module appears to be loaded
// for the detected web server: the cPanel API is consulted first, then the
// relevant config files are scanned as a fallback.
func modsecDetected(info platform.Info) bool {
	// cPanel fast path: ask whmapi1 directly.
	if info.IsCPanel() {
		out, _ := runCmd("whmapi1", "modsec_is_installed")
		if out != nil && strings.Contains(string(out), "installed: 1") {
			return true
		}
	}
	// Generic file-based probes per web server.
	for _, confPath := range expandPathGlobs(modsecActivationCandidates(info)) {
		data, err := osFS.ReadFile(confPath)
		if err != nil {
			continue
		}
		if modsecEnabledInConfig(info, string(data)) {
			return true
		}
	}
	return false
}
// modsecActivationCandidates returns the config files that can enable the
// ModSecurity module for the detected web server. Entries may contain glob
// patterns; callers expand them with expandPathGlobs.
func modsecActivationCandidates(info platform.Info) []string {
	switch info.WebServer {
	case platform.WSApache:
		base := info.ApacheConfigDir
		if base == "" {
			return nil
		}
		return []string{
			filepath.Join(base, "httpd.conf"),
			filepath.Join(base, "apache2.conf"),
			filepath.Join(base, "modsec2.conf"),
			filepath.Join(base, "conf.d", "modsec2.conf"),
			filepath.Join(base, "mods-enabled", "security2.conf"),
			filepath.Join(base, "conf-enabled", "security2.conf"),
			filepath.Join(base, "conf.d", "mod_security.conf"),
			filepath.Join(base, "conf.modules.d", "10-mod_security.conf"),
			filepath.Join(base, "conf.d", "*.conf"),
			filepath.Join(base, "mods-enabled", "*.conf"),
			filepath.Join(base, "conf-enabled", "*.conf"),
		}
	case platform.WSNginx:
		base := info.NginxConfigDir
		if base == "" {
			return nil
		}
		return []string{
			filepath.Join(base, "nginx.conf"),
			filepath.Join(base, "conf.d", "*.conf"),
			filepath.Join(base, "sites-enabled", "*"),
		}
	case platform.WSLiteSpeed:
		return []string{"/usr/local/lsws/conf/httpd_config.xml"}
	}
	return nil
}
// modsecConfigCandidates returns the set of config files worth scanning for
// ModSecurity directives: the activation candidates plus a few extra paths
// specific to Nginx and LiteSpeed layouts.
func modsecConfigCandidates(info platform.Info) []string {
	candidates := append([]string(nil), modsecActivationCandidates(info)...)
	switch info.WebServer {
	case platform.WSNginx:
		if base := info.NginxConfigDir; base != "" {
			candidates = append(candidates,
				filepath.Join(base, "modules-enabled", "*.conf"),
				filepath.Join(base, "modsec", "main.conf"),
			)
		}
	case platform.WSLiteSpeed:
		candidates = append(candidates, "/usr/local/lsws/conf/modsec2.conf")
	}
	return candidates
}
// modsecEnabledInConfig reports whether contents (one web-server config
// file) contains a directive that activates ModSecurity for the server
// type in info. Comment lines are skipped for Apache/Nginx.
func modsecEnabledInConfig(info platform.Info, contents string) bool {
	sc := bufio.NewScanner(strings.NewReader(contents))
	for sc.Scan() {
		line := strings.ToLower(strings.TrimSpace(sc.Text()))
		if line == "" {
			continue
		}
		switch info.WebServer {
		case platform.WSNginx:
			if strings.HasPrefix(line, "#") {
				continue
			}
			if strings.HasPrefix(line, "modsecurity on") ||
				strings.HasPrefix(line, "modsecurity_rules ") ||
				strings.HasPrefix(line, "modsecurity_rules_file ") {
				return true
			}
		case platform.WSLiteSpeed:
			// LiteSpeed config is XML; a plain substring match suffices.
			if strings.Contains(line, "mod_security") || strings.Contains(line, "modsecurity") {
				return true
			}
		default:
			// Apache-style directives.
			if strings.HasPrefix(line, "#") {
				continue
			}
			if strings.Contains(line, "security2_module") ||
				strings.HasPrefix(line, "secruleengine ") ||
				strings.Contains(line, "mod_security2") {
				return true
			}
		}
	}
	return false
}
// expandPathGlobs expands any glob patterns in paths via osFS.Glob and
// returns the de-duplicated result in first-seen order. Patterns that fail
// to glob (or match nothing) are kept verbatim.
func expandPathGlobs(paths []string) []string {
	var out []string
	seen := make(map[string]struct{})
	add := func(p string) {
		if _, dup := seen[p]; !dup {
			seen[p] = struct{}{}
			out = append(out, p)
		}
	}
	for _, pattern := range paths {
		if strings.ContainsAny(pattern, "*?[") {
			if matches, err := osFS.Glob(pattern); err == nil && len(matches) > 0 {
				for _, m := range matches {
					add(m)
				}
				continue
			}
		}
		add(pattern)
	}
	return out
}
// modsecRuleDirs returns the candidate directories where vendor rules live
// for the detected web server/panel combination. Non-Apache/Nginx servers
// yield nil.
func modsecRuleDirs(info platform.Info) []string {
	switch info.WebServer {
	case platform.WSNginx:
		return []string{
			"/etc/nginx/modsec/",
			"/etc/modsecurity/",
			"/usr/share/modsecurity-crs/rules/",
		}
	case platform.WSApache:
		var dirs []string
		if info.IsDebianFamily() {
			dirs = append(dirs,
				"/etc/apache2/conf.d/modsec_vendor_configs/",
				"/etc/modsecurity/",
				"/usr/share/modsecurity-crs/rules/",
			)
		}
		if info.IsRHELFamily() {
			dirs = append(dirs,
				"/etc/httpd/modsecurity.d/",
				"/etc/httpd/modsecurity.d/activated_rules/",
				"/usr/share/modsecurity-crs/rules/",
			)
		}
		// cPanel vendor config dir applies regardless of distro family.
		return append(dirs, "/usr/local/apache/conf/modsec_vendor_configs/")
	}
	return nil
}
// wafInstallHint returns platform-specific install instructions for the
// missing ModSecurity module; a generic message is used when the platform
// combination is not recognized.
func wafInstallHint(info platform.Info) string {
	if info.IsCPanel() {
		return "No ModSecurity module detected. Install: WHM > Security Center > ModSecurity"
	}
	switch info.WebServer {
	case platform.WSNginx:
		if info.IsDebianFamily() {
			return "No ModSecurity module detected for Nginx.\nInstall: apt install libnginx-mod-http-modsecurity modsecurity-crs"
		}
		if info.IsRHELFamily() {
			return "No ModSecurity module detected for Nginx.\nInstall (requires EPEL): dnf install -y epel-release && dnf install -y nginx-mod-http-modsecurity && systemctl restart nginx"
		}
	case platform.WSApache:
		if info.IsDebianFamily() {
			return "No ModSecurity module detected for Apache.\nInstall: apt install libapache2-mod-security2 modsecurity-crs && a2enmod security2"
		}
		if info.IsRHELFamily() {
			return "No ModSecurity module detected for Apache.\nInstall (requires EPEL): dnf install -y epel-release && dnf install -y mod_security mod_security_crs && systemctl restart httpd"
		}
	}
	return "No ModSecurity module detected. The server has no web application firewall protecting against SQL injection, XSS, and other web attacks."
}
// wafRulesHint returns platform-specific instructions for installing WAF
// rule sets when ModSecurity is present but has no rules.
func wafRulesHint(info platform.Info) string {
	switch {
	case info.IsCPanel():
		return "ModSecurity is installed but has no OWASP or Comodo rules. Add rules: WHM > Security Center > ModSecurity Vendors"
	case info.IsDebianFamily():
		return "ModSecurity is installed but has no rules loaded. Install OWASP CRS: apt install modsecurity-crs"
	case info.IsRHELFamily():
		return "ModSecurity is installed but has no rules loaded. Install OWASP CRS: dnf install --enablerepo=epel modsecurity-crs"
	}
	return "ModSecurity is installed but has no rules loaded."
}
// wafRulesStaleHint returns platform-specific advice for updating stale
// ModSecurity vendor rules.
func wafRulesStaleHint(info platform.Info) string {
	switch {
	case info.IsCPanel():
		return "Vendor rules should be updated at least monthly. Check: WHM > Security Center > ModSecurity Vendors > Update"
	case info.IsDebianFamily():
		return "Vendor rules should be updated at least monthly. Update with: apt update && apt upgrade modsecurity-crs"
	case info.IsRHELFamily():
		return "Vendor rules should be updated at least monthly. Update with: dnf upgrade modsecurity-crs"
	}
	return "Vendor rules should be updated at least monthly."
}
// checkEngineMode reads ModSecurity config files to determine the
// SecRuleEngine setting. Returns "on", "detectiononly", "off", or "" if no
// config file states a value. The first file that yields a value wins.
func checkEngineMode(info platform.Info) string {
	candidates := modsecConfigCandidates(info)
	// Also include the top-level modsecurity.conf installed by distro packages.
	candidates = append(candidates,
		"/etc/modsecurity/modsecurity.conf",
		"/etc/nginx/modsec/modsecurity.conf",
	)
	// modeFromFile scans one file for the first complete SecRuleEngine line.
	modeFromFile := func(path string) string {
		f, err := osFS.Open(path)
		if err != nil {
			return ""
		}
		defer func() { _ = f.Close() }()
		sc := bufio.NewScanner(f)
		for sc.Scan() {
			line := strings.TrimSpace(sc.Text())
			if strings.HasPrefix(line, "#") {
				continue
			}
			fields := strings.Fields(strings.ToLower(line))
			// A directive with no value keeps scanning, as the original did.
			if len(fields) >= 2 && strings.HasPrefix(fields[0], "secruleengine") {
				return fields[1]
			}
		}
		return ""
	}
	for _, path := range candidates {
		if mode := modeFromFile(path); mode != "" {
			return mode
		}
	}
	return ""
}
// checkRuleAge returns the age in days of the oldest rule file, or 0 if
// rules are fresh (30 days or newer) or no rule artifacts were found.
func checkRuleAge(ruleDirs []string) int {
	mtime, ok := oldestRuleArtifact(ruleDirs)
	if !ok {
		return 0
	}
	if days := int(time.Since(mtime).Hours() / 24); days > 30 {
		return days
	}
	return 0
}
// hasRuleArtifacts reports whether any rule artifact exists under ruleDirs.
func hasRuleArtifacts(ruleDirs []string) bool {
	_, ok := oldestRuleArtifact(ruleDirs)
	return ok
}
func oldestRuleArtifact(ruleDirs []string) (time.Time, bool) {
var oldestMtime time.Time
found := false
for _, dir := range ruleDirs {
entries, err := osFS.ReadDir(dir)
if err != nil {
continue
}
for _, entry := range entries {
if !entry.IsDir() {
// Rule file directly in the rule dir (distro CRS layout,
// e.g. /usr/share/modsecurity-crs/rules/REQUEST-*.conf).
info, err := entry.Info()
if err != nil {
continue
}
if !isRuleArtifact(entry.Name()) {
continue
}
if !found || info.ModTime().Before(oldestMtime) {
oldestMtime = info.ModTime()
found = true
}
continue
}
// Subdirectory: scan one level deeper for vendor-packed rules
// (cPanel layout, e.g. /usr/local/apache/conf/modsec_vendor_configs/OWASP/*.conf).
subDir := dir + "/" + entry.Name()
subEntries, err := osFS.ReadDir(subDir)
if err != nil {
continue
}
for _, subEntry := range subEntries {
if subEntry.IsDir() {
continue
}
if !isRuleArtifact(subEntry.Name()) {
continue
}
info, err := subEntry.Info()
if err != nil {
continue
}
if !found || info.ModTime().Before(oldestMtime) {
oldestMtime = info.ModTime()
found = true
}
}
}
}
return oldestMtime, found
}
// isRuleArtifact reports whether a filename looks like a ModSecurity rule
// or data artifact (.conf, .data, .rules) so unrelated files like README
// or LICENSE don't dominate the oldest-mtime calculation. The comparison
// is case-insensitive.
func isRuleArtifact(name string) bool {
	lower := strings.ToLower(name)
	for _, ext := range []string{".conf", ".data", ".rules"} {
		if strings.HasSuffix(lower, ext) {
			return true
		}
	}
	return false
}
// checkPerAccountBypass returns domains for which ModSecurity has been
// disabled, parsed from the YAML-like whmapi1 modsec_get_rules output.
func checkPerAccountBypass() []string {
	out, err := runCmd("whmapi1", "modsec_get_rules")
	if err != nil || out == nil {
		return nil
	}
	// The output format varies, but disabled rules/domains show
	// "disabled: 1" or "active: 0".
	lines := strings.Split(string(out), "\n")
	var bypassed []string
	for idx, raw := range lines {
		cur := strings.ToLower(strings.TrimSpace(raw))
		if !strings.Contains(cur, "disabled: 1") && !strings.Contains(cur, "active: 0") {
			continue
		}
		// Walk back up to 5 lines to find the enclosing "<domain>:" key.
		for back := idx - 1; back >= 0 && back >= idx-5; back-- {
			prev := strings.TrimSpace(lines[back])
			if !strings.HasSuffix(prev, ":") || strings.HasPrefix(prev, "-") {
				continue
			}
			if name := strings.TrimSuffix(prev, ":"); strings.Contains(name, ".") {
				// Contains a dot, so it plausibly is a domain name.
				bypassed = append(bypassed, name)
			}
			break
		}
	}
	return bypassed
}
// deployVirtualPatches ensures CSM's custom ModSec rules are installed.
// These provide virtual patches for known WordPress CVEs.
//
// The source rule file is read from /opt/csm/configs; if it is absent the
// function is a no-op. The first destination whose parent directory exists
// receives the rules and the function returns (at most one deploy per run).
func deployVirtualPatches() {
	// Possible modsec user config paths (cPanel Apache layouts).
	destPaths := []string{
		"/etc/apache2/conf.d/modsec/modsec2.user.conf",
		"/usr/local/apache/conf/modsec2.user.conf",
	}
	srcPath := "/opt/csm/configs/csm_modsec_custom.conf"
	srcData, err := osFS.ReadFile(srcPath)
	if err != nil {
		return // no custom rules to deploy
	}
	for _, dest := range destPaths {
		dir := filepath.Dir(dest)
		// Skip destinations whose parent dir doesn't exist on this host.
		if _, err := osFS.Stat(dir); os.IsNotExist(err) {
			continue
		}
		// Check if CSM rules are already included
		existing, err := osFS.ReadFile(dest)
		if err == nil && strings.Contains(string(existing), "CSM Custom ModSecurity Rules") {
			// Already deployed - check if rules need updating
			if string(existing) == string(srcData) {
				return // up to date
			}
			// NOTE(review): if the marker is present but the content differs
			// (e.g. rules were previously APPENDED to a user config), the
			// code below falls through to the overwrite branch, which
			// replaces the whole file with srcData and would drop any
			// non-CSM content — confirm this is intended.
		}
		// Deploy: if file exists and has non-CSM content, append. Otherwise write.
		if err == nil && len(existing) > 0 && !strings.Contains(string(existing), "CSM Custom ModSecurity Rules") {
			// Append to existing user config
			// #nosec G302 G304 -- WAF rule file read by Apache/nginx as a different user; dest is fixed list above.
			f, err := os.OpenFile(dest, os.O_APPEND|os.O_WRONLY, 0644)
			if err != nil {
				continue
			}
			// Write errors are intentionally ignored (best-effort deploy).
			_, _ = f.Write([]byte("\n\n"))
			_, _ = f.Write(srcData)
			_ = f.Close()
		} else {
			// Write or overwrite
			// #nosec G306 -- same reason: webserver-readable WAF rule file.
			_ = os.WriteFile(dest, srcData, 0644)
		}
		fmt.Fprintf(os.Stderr, "[%s] Virtual patches deployed to %s\n",
			time.Now().Format("2006-01-02 15:04:05"), dest)
		return
	}
}
// autoUpdateWAFRules triggers ModSecurity vendor rule updates via whmapi1.
// Returns true if at least one vendor update was successfully triggered.
func autoUpdateWAFRules() bool {
	// Discover installed vendors first.
	out, err := runCmd("whmapi1", "modsec_get_vendors")
	if err != nil || out == nil {
		return false
	}
	// Collect vendor IDs from "vendor_id:" / "id:" lines of the output.
	var vendorIDs []string
	for _, raw := range strings.Split(string(out), "\n") {
		trimmed := strings.TrimSpace(raw)
		if !strings.HasPrefix(trimmed, "vendor_id:") && !strings.HasPrefix(trimmed, "id:") {
			continue
		}
		kv := strings.SplitN(trimmed, ":", 2)
		if len(kv) != 2 {
			continue
		}
		if id := strings.TrimSpace(kv[1]); id != "" {
			vendorIDs = append(vendorIDs, id)
		}
	}
	// Trigger an update for each vendor; any single success counts.
	updatedAny := false
	for _, id := range vendorIDs {
		resp, err := runCmd("whmapi1", "modsec_update_vendor", "vendor_id="+id)
		if err == nil && resp != nil && strings.Contains(string(resp), "result: 1") {
			fmt.Fprintf(os.Stderr, "[%s] WAF auto-update: vendor %s updated successfully\n",
				time.Now().Format("2006-01-02 15:04:05"), id)
			updatedAny = true
		}
	}
	return updatedAny
}
// CheckModSecAuditLog parses the ModSecurity audit log for blocked attacks.
// High-volume attackers (>= 20 blocked requests in the scanned tail) are
// reported for potential auto-blocking.
func CheckModSecAuditLog(ctx context.Context, cfg *config.Config, store *state.Store) []alert.Finding {
	logPaths := platform.Detect().ModSecAuditLogPaths
	if len(logPaths) == 0 {
		return nil
	}
	// Use the first audit log that yields any content.
	var lines []string
	for _, path := range logPaths {
		if lines = tailFile(path, 200); len(lines) > 0 {
			break
		}
	}
	if len(lines) == 0 {
		return nil
	}
	// Tally blocked requests per source IP, ignoring infra addresses.
	blockedByIP := make(map[string]int)
	for _, entry := range lines {
		looksBlocked := strings.Contains(entry, "403") ||
			strings.Contains(entry, "Access denied") ||
			strings.Contains(entry, "MODSEC") ||
			strings.Contains(entry, "mod_security")
		if !looksBlocked {
			continue
		}
		if ip := extractIPFromLog(entry); ip != "" && !isInfraIP(ip, cfg.InfraIPs) {
			blockedByIP[ip]++
		}
	}
	// Alert on high-volume attackers (auto-block integration via check name).
	var findings []alert.Finding
	for ip, hits := range blockedByIP {
		if hits < 20 {
			continue
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "waf_attack_blocked",
			Message:  fmt.Sprintf("WAF blocking high-volume attacker: %s (%d blocked requests)", ip, hits),
			Details:  fmt.Sprintf("IP %s has been blocked %d times by ModSecurity. Consider permanent block via CSM.", ip, hits),
		})
	}
	return findings
}
package checks
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/state"
)
// wpChecksumWorkers bounds the number of concurrent `wp core verify-checksums`
// invocations in CheckWPCore so hosts with many WordPress sites are not
// overloaded by parallel wp-cli processes.
const wpChecksumWorkers = 5
// CheckHtaccess scans for malicious .htaccess directives using pure Go ReadDir.
//
// Each user home under /home is scanned: public_html plus any addon-domain
// document roots (non-dot directories that are not standard cPanel dirs).
// suspiciousPatterns flags potentially dangerous directives; safePatterns
// whitelists known-benign uses so common plugin/webserver configs don't
// produce noise. Depth is capped at 5 levels per document root.
func CheckHtaccess(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	suspiciousPatterns := []string{
		"auto_prepend_file",
		"auto_append_file",
		"eval(",
		"base64_decode",
		"gzinflate",
		"str_rot13",
		"php_value disable_functions",
		"addhandler",
		"addtype",
		"sethandler",
	}
	safePatterns := []string{
		"wordfence-waf.php",
		"litespeed",
		"advanced-headers.php",
		"rsssl",
		// Standard handler directives for PHP/static files are safe
		"application/x-httpd-php",
		"application/x-httpd-php5",
		"application/x-httpd-ea-php",
		"application/x-httpd-alt-php",
		"text/html",
		"text/css",
		"text/javascript",
		"application/javascript",
		"image/",
		"font/",
		"proxy:unix",
		// Security plugins that use handler directives to BLOCK execution
		"-execcgi",                   // Options -ExecCGI disables CGI (Wordfence pattern)
		"sethandler none",            // Disables all handlers (security measure)
		"sethandler default-handler", // Resets to default (security measure)
		// Legitimate MIME type additions
		"application/font",
		"application/vnd",
		".woff",
		".woff2",
		".ttf",
		".eot",
		".svg",
		// Wordfence code execution protection
		"wordfence",
	}
	// Scan each user's document roots
	homeDirs, _ := GetScanHomeDirs()
	for _, homeEntry := range homeDirs {
		// Bail out promptly when the run has been cancelled.
		if ctx.Err() != nil {
			return findings
		}
		if !homeEntry.IsDir() {
			continue
		}
		homeDir := filepath.Join("/home", homeEntry.Name())
		docRoot := filepath.Join(homeDir, "public_html")
		scanHtaccess(ctx, docRoot, 5, suspiciousPatterns, safePatterns, cfg, &findings)
		// Also check addon domains: any subdirectory of the home dir that
		// isn't a standard cPanel/system directory.
		subDirs, _ := osFS.ReadDir(homeDir)
		for _, sd := range subDirs {
			if sd.IsDir() && sd.Name() != "public_html" && sd.Name() != "mail" &&
				!strings.HasPrefix(sd.Name(), ".") && sd.Name() != "etc" &&
				sd.Name() != "logs" && sd.Name() != "ssl" && sd.Name() != "tmp" {
				scanHtaccess(ctx, filepath.Join(homeDir, sd.Name()), 5, suspiciousPatterns, safePatterns, cfg, &findings)
			}
		}
	}
	return findings
}
// scanHtaccess recursively walks dir down to maxDepth levels, passing every
// .htaccess file not suppressed by cfg.Suppressions.IgnorePaths to
// checkHtaccessFile. Findings are appended through the findings pointer.
func scanHtaccess(ctx context.Context, dir string, maxDepth int, suspicious, safe []string, cfg *config.Config, findings *[]alert.Finding) {
	if ctx.Err() != nil || maxDepth <= 0 {
		return
	}
	entries, err := osFS.ReadDir(dir)
	if err != nil {
		return
	}
	for _, ent := range entries {
		if ctx.Err() != nil {
			return
		}
		path := filepath.Join(dir, ent.Name())
		if ent.IsDir() {
			scanHtaccess(ctx, path, maxDepth-1, suspicious, safe, cfg, findings)
			continue
		}
		if ent.Name() != ".htaccess" {
			continue
		}
		// Honor operator-configured path suppressions.
		skip := false
		for _, pattern := range cfg.Suppressions.IgnorePaths {
			if matchGlob(path, pattern) {
				skip = true
				break
			}
		}
		if skip {
			continue
		}
		checkHtaccessFile(path, suspicious, safe, findings)
	}
}
// checkHtaccessFile scans one .htaccess file for suspicious directives and
// appends findings. The whole file is read first so cross-line context can
// be used (e.g. an AddHandler paired with "Options -ExecCGI" elsewhere in
// the file is a Wordfence-style protection, not an attack).
//
// Fix vs. previous revision: removed the dead local `standardCGI`, which
// was initialized to true and never modified, so `!hasNonStandard &&
// standardCGI` reduced to `!hasNonStandard`.
func checkHtaccessFile(path string, suspicious, safe []string, findings *[]alert.Finding) {
	f, err := osFS.Open(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	// Read entire file to check cross-line context (e.g., AddHandler + Options -ExecCGI).
	// Note: bufio.Scanner's default 64KB line limit applies; .htaccess lines
	// are expected to be far shorter.
	var lines []string
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}
	// Build full file content for context checks
	fullContentLower := strings.ToLower(strings.Join(lines, "\n"))
	// If file contains handler directives paired with -ExecCGI, the whole
	// block is a security measure (e.g., Wordfence execution protection)
	hasExecCGIBlock := strings.Contains(fullContentLower, "-execcgi")
	for lineNum, line := range lines {
		trimmed := strings.TrimSpace(line)
		lineLower := strings.ToLower(trimmed)
		// Skip comments entirely - commented-out directives are not active
		if strings.HasPrefix(trimmed, "#") {
			continue
		}
		for _, pattern := range suspicious {
			if !strings.Contains(lineLower, strings.ToLower(pattern)) {
				continue
			}
			// Check per-line safe patterns
			isSafe := false
			for _, sp := range safe {
				if strings.Contains(lineLower, strings.ToLower(sp)) {
					isSafe = true
					break
				}
			}
			if isSafe {
				continue
			}
			patternLower := strings.ToLower(pattern)
			// For handler directives, apply context-aware checks
			if patternLower == "addhandler" || patternLower == "sethandler" {
				// Skip if paired with -ExecCGI (Wordfence protection)
				if hasExecCGIBlock {
					continue
				}
				// Skip Drupal security handlers
				if strings.Contains(lineLower, "drupal_security") {
					continue
				}
				// Skip SetHandler none/default (disabling handlers = security measure)
				if strings.Contains(lineLower, "sethandler none") ||
					strings.Contains(lineLower, "sethandler default") {
					continue
				}
				// Skip AddHandler for standard CGI/script extensions only
				if strings.Contains(lineLower, "addhandler") {
					// First pass: any known-dangerous extension on the line?
					hasNonStandard := false
					for _, ext := range []string{".haxor", ".cgix", ".phtml", ".php3",
						".php5", ".suspected", ".bak.php", ".shtml", ".sh"} {
						if strings.Contains(lineLower, ext) {
							hasNonStandard = true
							break
						}
					}
					// Second pass: if no dangerous extension, verify every
					// dotted token is a conventional script extension.
					if !hasNonStandard {
						onlyStandard := true
						parts := strings.Fields(lineLower)
						for _, p := range parts {
							if strings.HasPrefix(p, ".") && p != ".cgi" && p != ".pl" && p != ".py" &&
								p != ".php" && p != ".jsp" && p != ".asp" {
								// Has non-standard extension
								onlyStandard = false
								break
							}
						}
						if onlyStandard {
							continue
						}
					}
				}
			}
			// Skip AddType for any MIME type (application/*, text/*, x-mapp-*, etc.)
			if patternLower == "addtype" {
				// AddType is only dangerous if it maps to a PHP/CGI handler
				// Standard MIME type declarations are safe
				if strings.Contains(lineLower, "application/") ||
					strings.Contains(lineLower, "text/") ||
					strings.Contains(lineLower, "image/") ||
					strings.Contains(lineLower, "font/") ||
					strings.Contains(lineLower, "x-mapp-") ||
					strings.Contains(lineLower, "audio/") ||
					strings.Contains(lineLower, "video/") {
					continue
				}
			}
			*findings = append(*findings, alert.Finding{
				Severity: alert.High,
				Check:    "htaccess_injection",
				Message:  fmt.Sprintf("Suspicious .htaccess directive: %s", pattern),
				Details:  fmt.Sprintf("File: %s (line %d)\nContent: %s", path, lineNum+1, trimmed),
				FilePath: path,
			})
		}
	}
	// Special check: AddHandler mapping non-standard extensions WITHOUT -ExecCGI
	// (actual attack pattern - e.g., AddHandler cgi-script .haxor)
	if !hasExecCGIBlock && strings.Contains(fullContentLower, "addhandler") {
		for lineNum, line := range lines {
			lineLower := strings.ToLower(line)
			if !strings.Contains(lineLower, "addhandler") {
				continue
			}
			// Flag if it maps unusual extensions like .haxor, .cgix, etc.
			dangerousExts := []string{".haxor", ".cgix", ".suspected", ".bak.php"}
			for _, ext := range dangerousExts {
				if strings.Contains(lineLower, ext) {
					*findings = append(*findings, alert.Finding{
						Severity: alert.Critical,
						Check:    "htaccess_handler_abuse",
						Message:  fmt.Sprintf("Malicious handler mapping for %s extension", ext),
						Details:  fmt.Sprintf("File: %s (line %d)\nContent: %s", path, lineNum+1, strings.TrimSpace(line)),
						FilePath: path,
					})
				}
			}
		}
	}
}
// CheckWPCore runs wp core verify-checksums for each WordPress installation
// using a bounded worker pool for concurrency.
// Installations that pass verification have their core files cached in
// GlobalCMSCache so the real-time scanner can skip signature matches
// on known-clean CMS files.
//
// Failures are only reported for "should not exist" lines (extra files in
// core directories), excluding error_log noise.
func CheckWPCore(ctx context.Context, _ *config.Config, _ *state.Store) []alert.Finding {
	wpConfigs, _ := osFS.Glob("/home/*/public_html/wp-config.php")
	if len(wpConfigs) == 0 {
		return nil
	}
	cache := GlobalCMSCache()
	// mu guards findings, which the workers append to concurrently.
	var mu sync.Mutex
	var findings []alert.Finding
	var wg sync.WaitGroup
	// Bounded worker pool; the channel is sized to hold every job so the
	// producer loop below never blocks.
	jobs := make(chan string, len(wpConfigs))
	for i := 0; i < wpChecksumWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for wpConfig := range jobs {
				if ctx.Err() != nil {
					return
				}
				wpPath := filepath.Dir(wpConfig)
				user := extractUser(wpPath)
				out, err := runCmdCombinedContext(ctx, "wp", "core", "verify-checksums",
					"--path="+wpPath, "--allow-root")
				// Re-check cancellation: a cancelled run produces unreliable output.
				if ctx.Err() != nil {
					return
				}
				if err == nil {
					// Verification passed - cache all core files
					cacheWPCoreFiles(cache, wpPath)
					continue
				}
				if out == nil {
					continue
				}
				outStr := string(out)
				for _, line := range strings.Split(outStr, "\n") {
					if strings.Contains(line, "should not exist") && !strings.Contains(line, "error_log") {
						mu.Lock()
						findings = append(findings, alert.Finding{
							Severity: alert.High,
							Check:    "wp_core_integrity",
							Message:  fmt.Sprintf("WordPress core integrity failure for %s", user),
							Details:  fmt.Sprintf("Path: %s\n%s", wpPath, line),
						})
						mu.Unlock()
					}
				}
			}
		}()
	}
	for _, wpConfig := range wpConfigs {
		if ctx.Err() != nil {
			break
		}
		jobs <- wpConfig
	}
	// Closing jobs lets the workers drain and exit; Wait blocks until done.
	close(jobs)
	wg.Wait()
	fmt.Fprintf(os.Stderr, "CMS hash cache: %d verified core files cached\n", cache.Size())
	return findings
}
// cacheWPCoreFiles hashes the PHP/JS files in wp-includes/ and wp-admin/
// plus the well-known root-level core files for a verified-clean WordPress
// installation, and adds the hashes to the cache.
//
// Improvements vs. previous revision: filepath.WalkDir avoids a Stat per
// entry (the walk only needs names and dir flags), and walk errors now skip
// the offending entry instead of aborting the rest of that directory tree.
func cacheWPCoreFiles(cache *CMSHashCache, wpPath string) {
	coreDirs := []string{
		filepath.Join(wpPath, "wp-includes"),
		filepath.Join(wpPath, "wp-admin"),
	}
	// Also cache root-level WP core files.
	rootFiles := []string{
		"wp-config.php", "wp-cron.php", "wp-login.php", "wp-settings.php",
		"wp-load.php", "wp-blog-header.php", "wp-links-opml.php",
		"wp-mail.php", "wp-signup.php", "wp-activate.php",
		"wp-comments-post.php", "wp-trackback.php", "xmlrpc.php",
		"index.php",
	}
	for _, name := range rootFiles {
		path := filepath.Join(wpPath, name)
		// HashFile returns "" on failure; such files are simply not cached.
		if hash := HashFile(path); hash != "" {
			cache.Add(hash)
		}
	}
	for _, dir := range coreDirs {
		_ = filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
			if err != nil {
				// Unreadable entry or missing dir: skip it, keep walking.
				return nil
			}
			if d.IsDir() {
				return nil
			}
			name := strings.ToLower(d.Name())
			if strings.HasSuffix(name, ".php") || strings.HasSuffix(name, ".js") {
				if hash := HashFile(path); hash != "" {
					cache.Add(hash)
				}
			}
			return nil
		})
	}
}
// extractUser pulls the account name out of a path like /home/<user>/...,
// returning "unknown" when no "home" component with a successor exists.
func extractUser(path string) string {
	segments := strings.Split(path, "/")
	for idx := 0; idx+1 < len(segments); idx++ {
		if segments[idx] == "home" {
			return segments[idx+1]
		}
	}
	return "unknown"
}
package checks
import (
	"bufio"
	"context"
	"fmt"
	"io"
	"os"
	"strings"

	"github.com/pidginhost/csm/internal/alert"
	"github.com/pidginhost/csm/internal/config"
	"github.com/pidginhost/csm/internal/state"
)
// CheckWHMAccess parses the cPanel access log for WHM (port 2087) logins
// and password change API calls from non-infra IPs.
// Only reads the tail of the log - lightweight.
//
// Fix vs. previous revision: the constant action lists were rebuilt inside
// the per-line loop; they are now allocated once per call.
func CheckWHMAccess(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	// Password-change API actions: critical when triggered from outside infra.
	passwordActions := []string{
		"passwd", "change_root_password", "chpasswd",
		"force_password_change", "resetpass",
	}
	// Account lifecycle actions: high severity from outside infra.
	accountActions := []string{
		"createacct", "killacct", "suspendacct", "unsuspendacct",
	}
	lines := tailFile("/usr/local/cpanel/logs/access_log", 200)
	for _, line := range lines {
		// Only check WHM (port 2087) entries
		if !strings.Contains(line, "2087") {
			continue
		}
		// Extract IP (first field)
		fields := strings.Fields(line)
		if len(fields) < 1 {
			continue
		}
		ip := fields[0]
		// Skip infra IPs and localhost - that traffic is expected.
		if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
			continue
		}
		lineLower := strings.ToLower(line)
		// Check for password change actions
		for _, action := range passwordActions {
			if strings.Contains(lineLower, action) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "whm_password_change",
					Message:  fmt.Sprintf("WHM password change from non-infra IP: %s", ip),
					Details:  truncateString(line, 200),
				})
				break
			}
		}
		// Check for account management from unknown IPs
		for _, action := range accountActions {
			if strings.Contains(lineLower, action) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "whm_account_action",
					Message:  fmt.Sprintf("WHM account action from non-infra IP: %s", ip),
					Details:  truncateString(line, 200),
				})
				break
			}
		}
	}
	return findings
}
// CheckSSHLogins parses /var/log/secure for SSH logins from non-infra IPs.
// Lines look like "Accepted publickey for root from 1.2.3.4 port 12345";
// the tokens after the first "from"/"for" give the source IP and user.
func CheckSSHLogins(ctx context.Context, cfg *config.Config, _ *state.Store) []alert.Finding {
	var findings []alert.Finding
	for _, entry := range tailFile("/var/log/secure", 100) {
		if !strings.Contains(entry, "Accepted") {
			continue
		}
		tokens := strings.Fields(entry)
		// Locate the first "from <ip>" and "for <user>" token pairs.
		ipIdx, userIdx := -1, -1
		for idx, tok := range tokens {
			if idx+1 >= len(tokens) {
				break
			}
			if tok == "from" && ipIdx == -1 {
				ipIdx = idx + 1
			}
			if tok == "for" && userIdx == -1 {
				userIdx = idx + 1
			}
		}
		if ipIdx < 0 {
			continue
		}
		ip := tokens[ipIdx]
		// Infra/localhost logins are expected; skip them.
		if isInfraIP(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
			continue
		}
		user := "unknown"
		if userIdx >= 0 {
			user = tokens[userIdx]
		}
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "ssh_login_unknown_ip",
			Message:  fmt.Sprintf("SSH login from non-infra IP: %s (user: %s)", ip, user),
			Details:  truncateString(entry, 200),
		})
	}
	return findings
}
// tailFile reads the last maxLines lines of path. Files under 1 MiB are
// scanned whole; larger files are read only from the last 256 KiB (enough
// for ~2000 lines, far more than the tails callers request). Returns nil
// if the file cannot be opened or stat'ed.
//
// Fix vs. previous revision: the magic whence value 2 in Seek is replaced
// with the named io.SeekEnd constant.
func tailFile(path string, maxLines int) []string {
	f, err := osFS.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	info, err := f.Stat()
	if err != nil {
		return nil
	}
	// For small files, just read all
	if info.Size() < 1024*1024 {
		return readAllLines(f, maxLines)
	}
	// For large files, read last 256KB (enough for ~2000 lines).
	// The first line after the seek point may be partial; with tails this
	// deep it is normally discarded by the maxLines cut anyway.
	readSize := int64(256 * 1024)
	if readSize > info.Size() {
		readSize = info.Size()
	}
	if _, err := f.Seek(-readSize, io.SeekEnd); err != nil {
		// Seek failed (e.g. unseekable file): fall back to a full scan.
		return readAllLines(f, maxLines)
	}
	return readAllLines(f, maxLines)
}
func readAllLines(f *os.File, maxLines int) []string {
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 0, 64*1024), 256*1024)
var lines []string
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
// Return last N lines
if len(lines) > maxLines {
return lines[len(lines)-maxLines:]
}
return lines
}
// truncateString caps s at maxLen bytes, appending "..." when truncated.
// The cut is byte-based, so a multi-byte rune at the boundary may be split.
func truncateString(s string, maxLen int) string {
	if len(s) > maxLen {
		return s[:maxLen] + "..."
	}
	return s
}
package config
import (
"bytes"
"fmt"
"os"
"gopkg.in/yaml.v3"
"github.com/pidginhost/csm/internal/firewall"
)
// Config is the top-level CSM configuration, loaded from YAML by Load
// (which also fills in defaults) and written back by Save. Unknown
// YAML keys are rejected at load time.
type Config struct {
	ConfigFile string `yaml:"-"` // path the config was loaded from; set by Load, used by Save
	Hostname   string `yaml:"hostname"`
	// Alerts configures how findings are delivered (email, webhook)
	// and the global alert rate limit.
	Alerts struct {
		Email struct {
			Enabled        bool     `yaml:"enabled"`
			To             []string `yaml:"to"`
			From           string   `yaml:"from"`
			SMTP           string   `yaml:"smtp"`
			DisabledChecks []string `yaml:"disabled_checks"`
		} `yaml:"email"`
		Webhook struct {
			Enabled bool   `yaml:"enabled"`
			URL     string `yaml:"url"`
			Type    string `yaml:"type"` // slack, discord, generic
		} `yaml:"webhook"`
		Heartbeat struct {
			Enabled bool   `yaml:"enabled"`
			URL     string `yaml:"url"`
		} `yaml:"heartbeat"`
		MaxPerHour int `yaml:"max_per_hour"` // alert rate limit; default 30
	} `yaml:"alerts"`
	Integrity struct {
		BinaryHash string `yaml:"binary_hash"`
		ConfigHash string `yaml:"config_hash"`
		Immutable  bool   `yaml:"immutable"`
	} `yaml:"integrity"`
	// Thresholds tunes check intervals and brute-force detection knobs.
	// Zero values are replaced with defaults in Load.
	Thresholds struct {
		MailQueueWarn              int `yaml:"mail_queue_warn"`
		MailQueueCrit              int `yaml:"mail_queue_crit"`
		StateExpiryHours           int `yaml:"state_expiry_hours"`
		DeepScanIntervalMin        int `yaml:"deep_scan_interval_min"`
		WPCoreCheckIntervalMin     int `yaml:"wp_core_check_interval_min"`
		WebshellScanIntervalMin    int `yaml:"webshell_scan_interval_min"`
		FilesystemScanIntervalMin  int `yaml:"filesystem_scan_interval_min"`
		MultiIPLoginThreshold      int `yaml:"multi_ip_login_threshold"`
		MultiIPLoginWindowMin      int `yaml:"multi_ip_login_window_min"`
		PluginCheckIntervalMin     int `yaml:"plugin_check_interval_min"`
		BruteForceWindow           int `yaml:"brute_force_window"`
		SMTPBruteForceThreshold    int `yaml:"smtp_bruteforce_threshold"`
		SMTPBruteForceWindowMin    int `yaml:"smtp_bruteforce_window_min"`
		SMTPBruteForceSuppressMin  int `yaml:"smtp_bruteforce_suppress_min"`
		SMTPBruteForceSubnetThresh int `yaml:"smtp_bruteforce_subnet_threshold"`
		SMTPAccountSprayThreshold  int `yaml:"smtp_account_spray_threshold"`
		SMTPBruteForceMaxTracked   int `yaml:"smtp_bruteforce_max_tracked"`
		MailBruteForceThreshold    int `yaml:"mail_bruteforce_threshold"`
		MailBruteForceWindowMin    int `yaml:"mail_bruteforce_window_min"`
		MailBruteForceSuppressMin  int `yaml:"mail_bruteforce_suppress_min"`
		MailBruteForceSubnetThresh int `yaml:"mail_bruteforce_subnet_threshold"`
		MailAccountSprayThreshold  int `yaml:"mail_account_spray_threshold"`
		MailBruteForceMaxTracked   int `yaml:"mail_bruteforce_max_tracked"`
	} `yaml:"thresholds"`
	// InfraIPs are trusted infrastructure addresses; logins from them
	// are not alerted on.
	InfraIPs  []string `yaml:"infra_ips"`
	StatePath string   `yaml:"state_path"` // default /opt/csm/state
	Suppressions struct {
		UPCPWindowStart       string   `yaml:"upcp_window_start"`
		UPCPWindowEnd         string   `yaml:"upcp_window_end"`
		KnownAPITokens        []string `yaml:"known_api_tokens"`
		IgnorePaths           []string `yaml:"ignore_paths"`
		SuppressWebmail       bool     `yaml:"suppress_webmail_alerts"`      // don't alert on webmail logins
		SuppressCpanelLogin   bool     `yaml:"suppress_cpanel_login_alerts"` // don't alert on cPanel direct logins
		SuppressBlockedAlerts bool     `yaml:"suppress_blocked_alerts"`      // don't alert on IPs that were auto-blocked
		TrustedCountries      []string `yaml:"trusted_countries"`            // ISO 3166-1 alpha-2 codes - suppress cPanel login alerts from these countries
	} `yaml:"suppressions"`
	AutoResponse struct {
		Enabled            bool   `yaml:"enabled"`
		KillProcesses      bool   `yaml:"kill_processes"`
		QuarantineFiles    bool   `yaml:"quarantine_files"`
		BlockIPs           bool   `yaml:"block_ips"`
		BlockExpiry        string `yaml:"block_expiry"`        // e.g. "24h", "12h"
		EnforcePermissions bool   `yaml:"enforce_permissions"` // auto-chmod 644 world/group-writable PHP files (default false)
		BlockCpanelLogins  bool   `yaml:"block_cpanel_logins"` // block IPs on cPanel/webmail login alerts (default false)
		NetBlock           bool   `yaml:"netblock"`            // auto-block /24 when threshold IPs from same subnet
		NetBlockThreshold  int    `yaml:"netblock_threshold"`  // IPs from same /24 before subnet block (default 3)
		PermBlock          bool   `yaml:"permblock"`           // auto-promote to permanent after N temp blocks
		PermBlockCount     int    `yaml:"permblock_count"`     // temp blocks before permanent (default 4)
		PermBlockInterval  string `yaml:"permblock_interval"`  // window for counting temp blocks (default "24h")
		CleanDatabase      bool   `yaml:"clean_database"`      // auto-clean malicious DB injections, revoke sessions, block attacker IPs (default false)
	} `yaml:"auto_response"`
	Challenge struct {
		Enabled        bool     `yaml:"enabled"`         // enable challenge pages instead of hard block for some IPs
		ListenPort     int      `yaml:"listen_port"`     // port for challenge server (default: 8439)
		Secret         string   `yaml:"secret"`          // HMAC secret for challenge tokens (auto-generated if empty)
		Difficulty     int      `yaml:"difficulty"`      // proof-of-work difficulty 0-5 (default: 2)
		TrustedProxies []string `yaml:"trusted_proxies"` // IPs allowed to set X-Forwarded-For (empty = trust RemoteAddr only)
	} `yaml:"challenge"`
	PHPShield struct {
		Enabled bool `yaml:"enabled"` // watch php_events.log for PHP Shield alerts (default: false)
	} `yaml:"php_shield"`
	Reputation struct {
		AbuseIPDBKey string   `yaml:"abuseipdb_key"`
		Whitelist    []string `yaml:"whitelist"` // IPs to never flag as malicious
	} `yaml:"reputation"`
	Signatures struct {
		RulesDir       string `yaml:"rules_dir"`
		UpdateURL      string `yaml:"update_url"`
		AutoUpdate     bool   `yaml:"auto_update"`     // auto-download rules daily (default: true if update_url set)
		UpdateInterval string `yaml:"update_interval"` // how often to check (default: "24h")
		SigningKey     string `yaml:"signing_key"`     // hex-encoded ed25519 public key for verifying rule updates
		YaraForge struct {
			Enabled        bool   `yaml:"enabled"`
			Tier           string `yaml:"tier"`            // "core", "extended", "full" (default: "core")
			UpdateInterval string `yaml:"update_interval"` // default: "168h" (weekly)
		} `yaml:"yara_forge"`
		DisabledRules []string `yaml:"disabled_rules"` // YARA rule names to exclude from Forge downloads
	} `yaml:"signatures"`
	WebUI struct {
		Enabled   bool   `yaml:"enabled"`
		Listen    string `yaml:"listen"`
		AuthToken string `yaml:"auth_token"`
		TLSCert   string `yaml:"tls_cert"`
		TLSKey    string `yaml:"tls_key"`
		UIDir     string `yaml:"ui_dir"` // path to UI files on disk (default: /opt/csm/ui)
	} `yaml:"webui"`
	EmailAV EmailAVConfig `yaml:"email_av"`
	EmailProtection struct {
		PasswordCheckIntervalMin int      `yaml:"password_check_interval_min"`
		HighVolumeSenders        []string `yaml:"high_volume_senders"`
		RateWarnThreshold        int      `yaml:"rate_warn_threshold"`
		RateCritThreshold        int      `yaml:"rate_crit_threshold"`
		RateWindowMin            int      `yaml:"rate_window_min"`
		KnownForwarders          []string `yaml:"known_forwarders"`
	} `yaml:"email_protection"`
	// Firewall is a pointer so "absent from YAML" is distinguishable
	// from an explicit empty section; Load substitutes
	// firewall.DefaultConfig() when nil.
	Firewall *firewall.FirewallConfig `yaml:"firewall"`
	GeoIP struct {
		AccountID      string   `yaml:"account_id"`
		LicenseKey     string   `yaml:"license_key"`
		Editions       []string `yaml:"editions"`
		AutoUpdate     *bool    `yaml:"auto_update"`     // nil = true when credentials set
		UpdateInterval string   `yaml:"update_interval"` // default "24h"
	} `yaml:"geoip"`
	ModSecErrorLog string `yaml:"modsec_error_log"`
	ModSec struct {
		RulesFile     string `yaml:"rules_file"`     // path to modsec2.user.conf
		OverridesFile string `yaml:"overrides_file"` // path to csm-overrides.conf
		ReloadCommand string `yaml:"reload_command"` // e.g. "systemctl restart lsws"
	} `yaml:"modsec"`
	// WebServer overrides the auto-detected web server paths. Every field is
	// optional: anything left blank or empty falls back to what
	// platform.Detect() returned at startup. Intended for hosts with a
	// custom layout (reverse proxy in front of a second daemon, non-standard
	// package locations, chroot, etc.).
	WebServer struct {
		Type         string   `yaml:"type"`              // "apache", "nginx", "litespeed" — overrides auto-detect
		ConfigDir    string   `yaml:"config_dir"`        // e.g. /etc/apache2 or /etc/nginx
		AccessLogs   []string `yaml:"access_logs"`       // candidate access-log paths, tried in order
		ErrorLogs    []string `yaml:"error_logs"`        // candidate error-log paths (used for modsec denies)
		ModSecAudits []string `yaml:"modsec_audit_logs"` // candidate ModSecurity audit-log paths
	} `yaml:"web_server"`
	// AccountRoots lets operators point the account-scan based checks at
	// non-cPanel web root layouts. Each entry is a glob pattern expanded
	// at check time. Examples:
	//
	//	account_roots:
	//	  - /var/www/*/public
	//	  - /srv/http/*
	//	  - /home/*/public_html   # cPanel default (implicit when unset on cPanel)
	//
	// When unset, CSM uses the cPanel default of /home/*/public_html on
	// cPanel hosts and an empty list on non-cPanel hosts (account-scan
	// checks skip entirely). See docs/src/configuration.md for the full
	// list of checks that consume this.
	AccountRoots []string `yaml:"account_roots"`
	Performance struct {
		// Enabled is a pointer so an explicit "false" is
		// distinguishable from unset (Load defaults unset to true).
		Enabled                     *bool   `yaml:"enabled"`
		LoadHighMultiplier          float64 `yaml:"load_high_multiplier"`
		LoadCriticalMultiplier      float64 `yaml:"load_critical_multiplier"`
		PHPProcessWarnPerUser       int     `yaml:"php_process_warn_per_user"`
		PHPProcessCriticalTotalMult int     `yaml:"php_process_critical_total_multiplier"`
		ErrorLogWarnSizeMB          int     `yaml:"error_log_warn_size_mb"`
		MySQLJoinBufferMaxMB        int     `yaml:"mysql_join_buffer_max_mb"`
		MySQLWaitTimeoutMax         int     `yaml:"mysql_wait_timeout_max"`
		MySQLMaxConnectionsPerUser  int     `yaml:"mysql_max_connections_per_user"`
		RedisBgsaveMinInterval      int     `yaml:"redis_bgsave_min_interval"`
		RedisLargeDatasetGB         int     `yaml:"redis_large_dataset_gb"`
		WPMemoryLimitMaxMB          int     `yaml:"wp_memory_limit_max_mb"`
		WPTransientWarnMB           int     `yaml:"wp_transient_warn_mb"`
		WPTransientCriticalMB       int     `yaml:"wp_transient_critical_mb"`
	} `yaml:"performance"`
	Cloudflare struct {
		Enabled      bool `yaml:"enabled"`
		RefreshHours int  `yaml:"refresh_hours"`
	} `yaml:"cloudflare"`
	C2Blocklist   []string `yaml:"c2_blocklist"`
	BackdoorPorts []int    `yaml:"backdoor_ports"`
}
// Load reads and parses the YAML config at path, rejecting unknown
// keys, and fills every unset value with its documented default.
// The returned Config has ConfigFile set to path so Save can write it
// back.
func Load(path string) (*Config, error) {
	// #nosec G304 -- path is operator-supplied config file (CLI flag / env).
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading config %s: %w", path, err)
	}
	cfg := &Config{}
	dec := yaml.NewDecoder(bytes.NewReader(data))
	// Unknown YAML keys (typos) surface as parse errors.
	dec.KnownFields(true)
	if err := dec.Decode(cfg); err != nil {
		return nil, fmt.Errorf("parsing config: %w", err)
	}
	cfg.ConfigFile = path
	applyDefaults(cfg)
	return cfg, nil
}

// setIntDefault assigns def to *v when *v is zero (i.e. unset in YAML).
func setIntDefault(v *int, def int) {
	if *v == 0 {
		*v = def
	}
}

// setStrDefault assigns def to *v when *v is empty.
func setStrDefault(v *string, def string) {
	if *v == "" {
		*v = def
	}
}

// setFloatDefault assigns def to *v when *v is zero.
func setFloatDefault(v *float64, def float64) {
	if *v == 0 {
		*v = def
	}
}

// applyDefaults fills every unset field of cfg with its default value.
// A zero/empty YAML value is indistinguishable from "unset", so fields
// whose zero value must be expressible (e.g. Performance.Enabled) use
// pointers instead.
func applyDefaults(cfg *Config) {
	setStrDefault(&cfg.StatePath, "/opt/csm/state")
	setStrDefault(&cfg.Signatures.RulesDir, "/opt/csm/rules")
	setStrDefault(&cfg.Signatures.YaraForge.Tier, "core")
	setStrDefault(&cfg.Signatures.YaraForge.UpdateInterval, "168h")
	setStrDefault(&cfg.WebUI.Listen, "0.0.0.0:9443")

	th := &cfg.Thresholds
	setIntDefault(&th.MailQueueWarn, 500)
	setIntDefault(&th.MailQueueCrit, 2000)
	setIntDefault(&th.StateExpiryHours, 24)
	setIntDefault(&th.DeepScanIntervalMin, 60)
	setIntDefault(&th.WPCoreCheckIntervalMin, 60)
	setIntDefault(&th.WebshellScanIntervalMin, 30)
	setIntDefault(&th.FilesystemScanIntervalMin, 30)
	setIntDefault(&th.PluginCheckIntervalMin, 1440)
	setIntDefault(&th.BruteForceWindow, 5000)
	setIntDefault(&th.SMTPBruteForceThreshold, 5)
	setIntDefault(&th.SMTPBruteForceWindowMin, 10)
	setIntDefault(&th.SMTPBruteForceSuppressMin, 60)
	setIntDefault(&th.SMTPBruteForceSubnetThresh, 8)
	setIntDefault(&th.SMTPAccountSprayThreshold, 12)
	setIntDefault(&th.SMTPBruteForceMaxTracked, 20000)
	setIntDefault(&th.MailBruteForceThreshold, 5)
	setIntDefault(&th.MailBruteForceWindowMin, 10)
	setIntDefault(&th.MailBruteForceSuppressMin, 60)
	setIntDefault(&th.MailBruteForceSubnetThresh, 8)
	setIntDefault(&th.MailAccountSprayThreshold, 12)
	setIntDefault(&th.MailBruteForceMaxTracked, 20000)

	setIntDefault(&cfg.Alerts.MaxPerHour, 30)
	setIntDefault(&cfg.Challenge.ListenPort, 8439)
	// NOTE(review): an explicit difficulty of 0 is overwritten to 2
	// here, although Validate accepts the 0-5 range — confirm whether
	// difficulty 0 is meant to be usable.
	setIntDefault(&cfg.Challenge.Difficulty, 2)

	if cfg.Firewall == nil {
		cfg.Firewall = firewall.DefaultConfig()
	}
	if len(cfg.GeoIP.Editions) == 0 {
		cfg.GeoIP.Editions = []string{"GeoLite2-City", "GeoLite2-ASN"}
	}
	setStrDefault(&cfg.GeoIP.UpdateInterval, "24h")

	EmailAVDefaults(&cfg.EmailAV)

	ep := &cfg.EmailProtection
	setIntDefault(&ep.PasswordCheckIntervalMin, 1440)
	setIntDefault(&ep.RateWarnThreshold, 50)
	setIntDefault(&ep.RateCritThreshold, 100)
	setIntDefault(&ep.RateWindowMin, 10)

	// Performance defaults
	p := &cfg.Performance
	if p.Enabled == nil {
		enabled := true
		p.Enabled = &enabled
	}
	setFloatDefault(&p.LoadHighMultiplier, 1.0)
	setFloatDefault(&p.LoadCriticalMultiplier, 2.0)
	setIntDefault(&p.PHPProcessWarnPerUser, 20)
	setIntDefault(&p.PHPProcessCriticalTotalMult, 5)
	setIntDefault(&p.ErrorLogWarnSizeMB, 50)
	setIntDefault(&p.MySQLJoinBufferMaxMB, 64)
	setIntDefault(&p.MySQLWaitTimeoutMax, 3600)
	setIntDefault(&p.MySQLMaxConnectionsPerUser, 10)
	setIntDefault(&p.RedisBgsaveMinInterval, 900)
	setIntDefault(&p.RedisLargeDatasetGB, 4)
	setIntDefault(&p.WPMemoryLimitMaxMB, 512)
	setIntDefault(&p.WPTransientWarnMB, 1)
	setIntDefault(&p.WPTransientCriticalMB, 10)

	setIntDefault(&cfg.Cloudflare.RefreshHours, 6)
}
// Save marshals cfg back to YAML and writes it to cfg.ConfigFile with
// owner-only (0600) permissions, since the config holds secrets.
func Save(cfg *Config) error {
	out, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("marshaling config: %w", err)
	}
	return os.WriteFile(cfg.ConfigFile, out, 0600)
}
package config
import "time"
// EmailAVConfig holds email antivirus scanning settings.
// Zero/empty fields are filled by EmailAVDefaults at load time.
type EmailAVConfig struct {
	Enabled           bool   `yaml:"enabled"`
	ClamdSocket       string `yaml:"clamd_socket"`        // unix socket of clamd (default /var/run/clamd.scan/clamd.sock)
	ScanTimeout       string `yaml:"scan_timeout"`        // duration string; parsed by ScanTimeoutDuration (default "30s")
	MaxAttachmentSize int64  `yaml:"max_attachment_size"` // bytes (default 25 MB)
	MaxArchiveDepth   int    `yaml:"max_archive_depth"`   // nesting levels to unpack (default 1)
	MaxArchiveFiles   int    `yaml:"max_archive_files"`   // files per archive (default 50)
	MaxExtractionSize int64  `yaml:"max_extraction_size"` // total extracted bytes (default 100 MB)
	QuarantineInfected bool  `yaml:"quarantine_infected"`
	ScanConcurrency   int    `yaml:"scan_concurrency"` // parallel scans (default 4)
	// FailMode controls behavior when scanning cannot complete.
	// "open" (default): deliver mail when engines are down, scans time out, or quarantine fails.
	// "tempfail": defer delivery (Exim retries later) when scanning cannot complete or infected mail cannot be quarantined.
	FailMode string `yaml:"fail_mode"`
}
// ScanTimeoutDuration parses the ScanTimeout string as a time.Duration,
// falling back to 30 seconds when the field is empty or unparseable.
func (c *EmailAVConfig) ScanTimeoutDuration() time.Duration {
	// ParseDuration rejects "" as well, so both the empty and the
	// malformed case take the fallback path.
	if d, err := time.ParseDuration(c.ScanTimeout); err == nil {
		return d
	}
	return 30 * time.Second
}
// EmailAVDefaults applies default values to an EmailAVConfig.
// Only zero/empty fields are touched; operator-set values are kept.
func EmailAVDefaults(c *EmailAVConfig) {
	str := func(v *string, def string) {
		if *v == "" {
			*v = def
		}
	}
	num := func(v *int, def int) {
		if *v == 0 {
			*v = def
		}
	}
	size := func(v *int64, def int64) {
		if *v == 0 {
			*v = def
		}
	}
	str(&c.ClamdSocket, "/var/run/clamd.scan/clamd.sock")
	str(&c.ScanTimeout, "30s")
	size(&c.MaxAttachmentSize, 25*1024*1024)  // 25 MB
	num(&c.MaxArchiveDepth, 1)
	num(&c.MaxArchiveFiles, 50)
	size(&c.MaxExtractionSize, 100*1024*1024) // 100 MB
	num(&c.ScanConcurrency, 4)
}
package config
// redactedValue is the marker Redact substitutes for non-empty secret
// fields so configs can be displayed or exported safely.
const redactedValue = "***REDACTED***"
// Redact returns a copy of the config with sensitive fields replaced.
// Empty fields are left empty (not replaced with the redaction marker).
// The original config is not modified.
func Redact(cfg *Config) *Config {
	out := *cfg // shallow copy; the masks below write only to the copy
	mask := func(field *string) {
		if *field != "" {
			*field = redactedValue
		}
	}
	mask(&out.WebUI.AuthToken)
	mask(&out.GeoIP.LicenseKey)
	mask(&out.Reputation.AbuseIPDBKey)
	mask(&out.Challenge.Secret)
	mask(&out.Integrity.BinaryHash)
	mask(&out.Integrity.ConfigHash)
	// Deep-copy the Firewall pointer so the redacted copy does not
	// share mutable state with the original.
	if cfg.Firewall != nil {
		fw := *cfg.Firewall
		out.Firewall = &fw
	}
	return &out
}
package config
import (
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// ValidationResult represents a single validation finding produced by
// Validate or ValidateDeep.
type ValidationResult struct {
	Level   string // "error", "warn", "ok"
	Field   string // dotted path matching YAML keys, e.g. "alerts.email.to"
	Message string // human-readable explanation or the validated value
}
// String implements fmt.Stringer, rendering "[LEVEL] field: message".
func (v ValidationResult) String() string {
	level := strings.ToUpper(v.Level)
	return fmt.Sprintf("[%s] %s: %s", level, v.Field, v.Message)
}
// Validate checks the config for errors and warnings, and emits an
// "ok" result for each valid section so operators see what was
// checked. It performs no I/O; use ValidateDeep for connectivity
// probes. Results are appended in a fixed section order.
func Validate(cfg *Config) []ValidationResult {
	var results []ValidationResult
	// --- Hostname ---
	if cfg.Hostname == "" || cfg.Hostname == "SET_HOSTNAME_HERE" {
		results = append(results, ValidationResult{"error", "hostname", "hostname is not set"})
	} else {
		results = append(results, ValidationResult{"ok", "hostname", cfg.Hostname})
	}
	// --- Alerts ---
	if !cfg.Alerts.Email.Enabled && !cfg.Alerts.Webhook.Enabled {
		results = append(results, ValidationResult{"error", "alerts", "no alert method enabled (enable email or webhook)"})
	}
	// --- Email alerts ---
	if cfg.Alerts.Email.Enabled {
		if len(cfg.Alerts.Email.To) == 0 {
			results = append(results, ValidationResult{"error", "alerts.email.to", "email alerts enabled but no recipients configured"})
		} else {
			valid := true
			for _, to := range cfg.Alerts.Email.To {
				if to == "SET_EMAIL_HERE" || !strings.Contains(to, "@") {
					results = append(results, ValidationResult{"error", "alerts.email.to", fmt.Sprintf("invalid email recipient: %s", to)})
					valid = false
				}
			}
			if valid {
				results = append(results, ValidationResult{"ok", "alerts.email.to", strings.Join(cfg.Alerts.Email.To, ", ")})
			}
		}
		if cfg.Alerts.Email.From == "" {
			results = append(results, ValidationResult{"error", "alerts.email.from", "email alerts enabled but no from address configured"})
		}
		if cfg.Alerts.Email.SMTP == "" {
			results = append(results, ValidationResult{"error", "alerts.email.smtp", "email alerts enabled but no SMTP server configured"})
		} else {
			results = append(results, ValidationResult{"ok", "alerts.email.smtp", cfg.Alerts.Email.SMTP})
		}
	}
	// --- Webhook ---
	if cfg.Alerts.Webhook.Enabled {
		if cfg.Alerts.Webhook.URL == "" {
			results = append(results, ValidationResult{"error", "alerts.webhook.url", "webhook alerts enabled but no URL configured"})
		} else {
			results = append(results, ValidationResult{"ok", "alerts.webhook.url", cfg.Alerts.Webhook.URL})
		}
	}
	// --- Heartbeat ---
	if cfg.Alerts.Heartbeat.Enabled {
		if cfg.Alerts.Heartbeat.URL == "" {
			results = append(results, ValidationResult{"error", "alerts.heartbeat.url", "heartbeat enabled but no URL configured"})
		} else {
			results = append(results, ValidationResult{"ok", "alerts.heartbeat.url", cfg.Alerts.Heartbeat.URL})
		}
	}
	// --- MaxPerHour ---
	if cfg.Alerts.MaxPerHour <= 0 {
		results = append(results, ValidationResult{"error", "alerts.max_per_hour", "max_per_hour must be > 0"})
	}
	// --- WebUI ---
	// (folded the former duplicated Enabled && AuthToken != "" check
	// into the same if/else shape as the other sections — equivalent
	// behavior, one condition evaluation.)
	if cfg.WebUI.Enabled {
		if cfg.WebUI.AuthToken == "" {
			results = append(results, ValidationResult{"error", "webui.auth_token", "webui enabled but no auth_token configured"})
		} else {
			results = append(results, ValidationResult{"ok", "webui", fmt.Sprintf("listening on %s", cfg.WebUI.Listen)})
		}
	}
	// --- Trusted countries ---
	for _, cc := range cfg.Suppressions.TrustedCountries {
		if len(cc) != 2 {
			results = append(results, ValidationResult{"error", "suppressions.trusted_countries", fmt.Sprintf("invalid country code: %q (expected 2-letter ISO code)", cc)})
		}
	}
	// --- Duration fields (only check if non-empty) ---
	if cfg.AutoResponse.BlockExpiry != "" {
		if _, err := time.ParseDuration(cfg.AutoResponse.BlockExpiry); err != nil {
			results = append(results, ValidationResult{"error", "auto_response.block_expiry", fmt.Sprintf("unparseable duration: %s", cfg.AutoResponse.BlockExpiry)})
		}
	}
	if cfg.Signatures.UpdateInterval != "" {
		if _, err := time.ParseDuration(cfg.Signatures.UpdateInterval); err != nil {
			results = append(results, ValidationResult{"error", "signatures.update_interval", fmt.Sprintf("unparseable duration: %s", cfg.Signatures.UpdateInterval)})
		}
	}
	if cfg.EmailAV.ScanTimeout != "" {
		if _, err := time.ParseDuration(cfg.EmailAV.ScanTimeout); err != nil {
			results = append(results, ValidationResult{"error", "email_av.scan_timeout", fmt.Sprintf("unparseable duration: %s", cfg.EmailAV.ScanTimeout)})
		}
	}
	if cfg.GeoIP.UpdateInterval != "" {
		if _, err := time.ParseDuration(cfg.GeoIP.UpdateInterval); err != nil {
			results = append(results, ValidationResult{"error", "geoip.update_interval", fmt.Sprintf("unparseable duration: %s", cfg.GeoIP.UpdateInterval)})
		}
	}
	if cfg.AutoResponse.PermBlockInterval != "" {
		if _, err := time.ParseDuration(cfg.AutoResponse.PermBlockInterval); err != nil {
			results = append(results, ValidationResult{"error", "auto_response.permblock_interval", fmt.Sprintf("unparseable duration: %s", cfg.AutoResponse.PermBlockInterval)})
		}
	}
	// --- Firewall ---
	if cfg.Firewall != nil && cfg.Firewall.Enabled {
		if cfg.Firewall.ConnRateLimit <= 0 {
			results = append(results, ValidationResult{"error", "firewall.conn_rate_limit", "conn_rate_limit must be > 0 when firewall enabled"})
		}
		if cfg.Firewall.ConnLimit < 0 {
			results = append(results, ValidationResult{"error", "firewall.conn_limit", "conn_limit must be >= 0 when firewall enabled (0 = disabled)"})
		}
		if cfg.Firewall.ConnRateLimit > 0 && cfg.Firewall.ConnLimit >= 0 {
			results = append(results, ValidationResult{"ok", "firewall", fmt.Sprintf("enabled, conn_rate_limit=%d, conn_limit=%d", cfg.Firewall.ConnRateLimit, cfg.Firewall.ConnLimit)})
		}
	}
	// --- Challenge ---
	if cfg.Challenge.Difficulty < 0 || cfg.Challenge.Difficulty > 5 {
		results = append(results, ValidationResult{"error", "challenge.difficulty", fmt.Sprintf("difficulty must be 0-5, got %d", cfg.Challenge.Difficulty)})
	}
	// --- EmailAV ---
	if cfg.EmailAV.Enabled && cfg.EmailAV.MaxAttachmentSize <= 0 {
		results = append(results, ValidationResult{"error", "email_av.max_attachment_size", "max_attachment_size must be > 0 when email_av enabled"})
	}
	if cfg.EmailAV.FailMode != "" && cfg.EmailAV.FailMode != "open" && cfg.EmailAV.FailMode != "tempfail" {
		results = append(results, ValidationResult{"error", "email_av.fail_mode",
			fmt.Sprintf("invalid fail_mode %q: must be \"open\" or \"tempfail\"", cfg.EmailAV.FailMode)})
	}
	// --- Signatures: updates must be verifiable ---
	if cfg.Signatures.UpdateURL != "" && cfg.Signatures.SigningKey == "" {
		results = append(results, ValidationResult{"error", "signatures.signing_key",
			"signing_key is required when signatures.update_url is configured"})
	}
	if cfg.Signatures.YaraForge.Enabled && cfg.Signatures.SigningKey == "" {
		results = append(results, ValidationResult{"error", "signatures.signing_key",
			"signing_key is required when signatures.yara_forge.enabled is true"})
	}
	// --- EmailProtection ---
	if cfg.EmailProtection.RateWarnThreshold > 0 && cfg.EmailProtection.RateWarnThreshold < 10 {
		results = append(results, ValidationResult{"warn", "email_protection.rate_warn_threshold", "rate_warn_threshold < 10 may cause excessive alerts"})
	}
	if cfg.EmailProtection.RateCritThreshold > 0 && cfg.EmailProtection.RateCritThreshold <= cfg.EmailProtection.RateWarnThreshold {
		results = append(results, ValidationResult{"error", "email_protection.rate_crit_threshold", "rate_crit_threshold must be > rate_warn_threshold"})
	}
	if cfg.EmailProtection.RateWindowMin > 0 && (cfg.EmailProtection.RateWindowMin < 5 || cfg.EmailProtection.RateWindowMin > 60) {
		results = append(results, ValidationResult{"error", "email_protection.rate_window_min", "rate_window_min must be between 5 and 60"})
	}
	if cfg.EmailProtection.PasswordCheckIntervalMin > 0 && cfg.EmailProtection.PasswordCheckIntervalMin < 60 {
		results = append(results, ValidationResult{"warn", "email_protection.password_check_interval_min", "password_check_interval_min < 60 may cause high CPU from doveadm"})
	}
	// --- SMTP brute-force thresholds ---
	// A zero value means "use the default from Load", so only non-zero
	// values are range-checked.
	t := cfg.Thresholds
	if t.SMTPBruteForceThreshold != 0 && (t.SMTPBruteForceThreshold < 2 || t.SMTPBruteForceThreshold > 50) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_threshold", "smtp_bruteforce_threshold must be between 2 and 50"})
	}
	if t.SMTPBruteForceWindowMin != 0 && (t.SMTPBruteForceWindowMin < 1 || t.SMTPBruteForceWindowMin > 60) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_window_min", "smtp_bruteforce_window_min must be between 1 and 60"})
	}
	if t.SMTPBruteForceSuppressMin != 0 && (t.SMTPBruteForceSuppressMin < 1 || t.SMTPBruteForceSuppressMin > 1440) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_suppress_min", "smtp_bruteforce_suppress_min must be between 1 and 1440"})
	}
	if t.SMTPBruteForceSubnetThresh != 0 && (t.SMTPBruteForceSubnetThresh < 2 || t.SMTPBruteForceSubnetThresh > 64) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_subnet_threshold", "smtp_bruteforce_subnet_threshold must be between 2 and 64"})
	}
	if t.SMTPAccountSprayThreshold != 0 && (t.SMTPAccountSprayThreshold < 2 || t.SMTPAccountSprayThreshold > 200) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_account_spray_threshold", "smtp_account_spray_threshold must be between 2 and 200"})
	}
	if t.SMTPBruteForceMaxTracked != 0 && (t.SMTPBruteForceMaxTracked < 1000 || t.SMTPBruteForceMaxTracked > 200000) {
		results = append(results, ValidationResult{"error", "thresholds.smtp_bruteforce_max_tracked", "smtp_bruteforce_max_tracked must be between 1000 and 200000"})
	}
	if t.MailBruteForceThreshold != 0 && (t.MailBruteForceThreshold < 2 || t.MailBruteForceThreshold > 50) {
		results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_threshold", "mail_bruteforce_threshold must be between 2 and 50"})
	}
	if t.MailBruteForceWindowMin != 0 && (t.MailBruteForceWindowMin < 1 || t.MailBruteForceWindowMin > 60) {
		results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_window_min", "mail_bruteforce_window_min must be between 1 and 60"})
	}
	if t.MailBruteForceSuppressMin != 0 && (t.MailBruteForceSuppressMin < 1 || t.MailBruteForceSuppressMin > 1440) {
		results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_suppress_min", "mail_bruteforce_suppress_min must be between 1 and 1440"})
	}
	if t.MailBruteForceSubnetThresh != 0 && (t.MailBruteForceSubnetThresh < 2 || t.MailBruteForceSubnetThresh > 64) {
		results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_subnet_threshold", "mail_bruteforce_subnet_threshold must be between 2 and 64"})
	}
	if t.MailAccountSprayThreshold != 0 && (t.MailAccountSprayThreshold < 2 || t.MailAccountSprayThreshold > 200) {
		results = append(results, ValidationResult{"error", "thresholds.mail_account_spray_threshold", "mail_account_spray_threshold must be between 2 and 200"})
	}
	if t.MailBruteForceMaxTracked != 0 && (t.MailBruteForceMaxTracked < 1000 || t.MailBruteForceMaxTracked > 200000) {
		results = append(results, ValidationResult{"error", "thresholds.mail_bruteforce_max_tracked", "mail_bruteforce_max_tracked must be between 1000 and 200000"})
	}
	// --- Warnings ---
	results = append(results, validateWarnings(cfg)...)
	return results
}
// validateWarnings checks for non-fatal configuration issues and
// returns them as "warn"-level results, in a fixed order.
func validateWarnings(cfg *Config) []ValidationResult {
	var out []ValidationResult
	warn := func(field, msg string) {
		out = append(out, ValidationResult{"warn", field, msg})
	}
	// GeoIP credentials present but auto_update explicitly disabled.
	if cfg.GeoIP.AccountID != "" && cfg.GeoIP.LicenseKey != "" &&
		cfg.GeoIP.AutoUpdate != nil && !*cfg.GeoIP.AutoUpdate {
		warn("geoip", "GeoIP credentials configured but auto_update is disabled")
	}
	// Auto-response enabled with every action switched off.
	ar := cfg.AutoResponse
	if ar.Enabled && !ar.KillProcesses && !ar.QuarantineFiles && !ar.BlockIPs {
		warn("auto_response", "auto_response enabled but no actions configured (kill/quarantine/block all false)")
	}
	// No infra IPs configured anywhere.
	haveFirewallInfra := cfg.Firewall != nil && len(cfg.Firewall.InfraIPs) > 0
	haveTopInfra := len(cfg.InfraIPs) > 0
	if !haveTopInfra && !haveFirewallInfra {
		warn("infra_ips", "no infra_ips configured in either top-level or firewall section")
		// With the firewall active on top of that, the operator can
		// lock themselves out.
		if cfg.Firewall != nil && cfg.Firewall.Enabled {
			warn("firewall", "firewall enabled but no infra_ips configured - risk of lockout")
		}
	}
	// Overly aggressive auto-block thresholds.
	if ar.NetBlock && ar.NetBlockThreshold < 2 {
		warn("auto_response.netblock_threshold", fmt.Sprintf("netblock_threshold=%d is very low (< 2), may cause excessive blocking", ar.NetBlockThreshold))
	}
	if ar.PermBlock && ar.PermBlockCount < 2 {
		warn("auto_response.permblock_count", fmt.Sprintf("permblock_count=%d is very low (< 2), may permanently block too quickly", ar.PermBlockCount))
	}
	return out
}
// ValidateDeep performs connectivity probes against configured services.
// It does NOT call Validate(); the caller should invoke both separately.
func ValidateDeep(cfg *Config) []ValidationResult {
	// State directory (always probed).
	out := probeStatePath(cfg.StatePath)
	// Signature rules directory.
	if dir := cfg.Signatures.RulesDir; dir != "" {
		out = append(out, probeRulesDir(dir)...)
	}
	// SMTP server for email alerts.
	if cfg.Alerts.Email.Enabled && cfg.Alerts.Email.SMTP != "" {
		out = append(out, probeSMTP(cfg.Alerts.Email.SMTP)...)
	}
	// ClamAV socket.
	if cfg.EmailAV.Enabled && cfg.EmailAV.ClamdSocket != "" {
		out = append(out, probeClamd(cfg.EmailAV.ClamdSocket)...)
	}
	// TLS cert/key (only when custom paths set).
	out = append(out, probeTLSFile("webui.tls_cert", cfg.WebUI.TLSCert)...)
	out = append(out, probeTLSFile("webui.tls_key", cfg.WebUI.TLSKey)...)
	// Webhook endpoint.
	if cfg.Alerts.Webhook.Enabled && cfg.Alerts.Webhook.URL != "" {
		out = append(out, probeWebhook(cfg.Alerts.Webhook.URL)...)
	}
	// GeoIP database files.
	if cfg.GeoIP.AccountID != "" && cfg.GeoIP.LicenseKey != "" && len(cfg.GeoIP.Editions) > 0 {
		out = append(out, probeGeoIPDBs(cfg.StatePath, cfg.GeoIP.Editions)...)
	}
	return out
}

// probeTLSFile reports whether a custom TLS file path exists.
// An empty path yields no result (no custom file configured).
func probeTLSFile(field, path string) []ValidationResult {
	if path == "" {
		return nil
	}
	if _, err := os.Stat(path); err != nil {
		return []ValidationResult{{"error", field, fmt.Sprintf("file not found: %s", path)}}
	}
	return []ValidationResult{{"ok", field, path}}
}
// probeStatePath checks that the state directory exists and is writable.
// Writability is verified by creating (and immediately removing) a probe
// file inside the directory.
func probeStatePath(path string) []ValidationResult {
	info, err := os.Stat(path)
	if err != nil {
		return []ValidationResult{{"error", "state_path", fmt.Sprintf("directory not found: %s", path)}}
	}
	if !info.IsDir() {
		return []ValidationResult{{"error", "state_path", fmt.Sprintf("not a directory: %s", path)}}
	}
	probe := filepath.Join(path, ".csm-validate-probe")
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	f, err := os.Create(probe)
	if err != nil {
		return []ValidationResult{{"error", "state_path", fmt.Sprintf("directory not writable: %s", path)}}
	}
	// Some filesystems (e.g. NFS) only surface write errors at close
	// time; the original ignored Close's error and could report
	// "writable" on a directory that silently drops writes.
	closeErr := f.Close()
	_ = os.Remove(probe) // best-effort cleanup of the probe file
	if closeErr != nil {
		return []ValidationResult{{"error", "state_path", fmt.Sprintf("directory not writable: %s", path)}}
	}
	return []ValidationResult{{"ok", "state_path", path}}
}
// probeRulesDir checks that the rules directory exists and contains rule files.
func probeRulesDir(path string) []ValidationResult {
	info, err := os.Stat(path)
	switch {
	case err != nil:
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("directory not found: %s", path)}}
	case !info.IsDir():
		return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("not a directory: %s", path)}}
	}
	// Accept the directory as soon as any pattern yields rule files.
	patterns := []string{"*.yaml", "*.yml", "*.yar", "*.yara"}
	for _, pat := range patterns {
		hits, _ := filepath.Glob(filepath.Join(path, pat))
		if len(hits) == 0 {
			continue
		}
		return []ValidationResult{{"ok", "signatures.rules_dir", fmt.Sprintf("%s (%d rule files)", path, len(hits))}}
	}
	return []ValidationResult{{"error", "signatures.rules_dir", fmt.Sprintf("no rule files (.yaml/.yml/.yar/.yara) found in %s", path)}}
}
// probeSMTP attempts a TCP dial to the SMTP server.
func probeSMTP(addr string) []ValidationResult {
	conn, dialErr := net.DialTimeout("tcp", addr, 3*time.Second)
	if dialErr != nil {
		msg := fmt.Sprintf("cannot connect to %s: %v", addr, dialErr)
		return []ValidationResult{{"error", "alerts.email.smtp", msg}}
	}
	_ = conn.Close()
	return []ValidationResult{{"ok", "alerts.email.smtp", fmt.Sprintf("connected to %s", addr)}}
}
// probeClamd attempts to connect to the ClamAV unix socket.
func probeClamd(socket string) []ValidationResult {
	conn, dialErr := net.DialTimeout("unix", socket, 3*time.Second)
	if dialErr != nil {
		msg := fmt.Sprintf("cannot connect to %s: %v", socket, dialErr)
		return []ValidationResult{{"error", "email_av.clamd_socket", msg}}
	}
	_ = conn.Close()
	return []ValidationResult{{"ok", "email_av.clamd_socket", fmt.Sprintf("connected to %s", socket)}}
}
// probeWebhook performs an HTTP HEAD request to verify the webhook endpoint is reachable.
// DNS/TCP/TLS failures are errors; HTTP status codes (even 401/403/404/405) mean reachable.
func probeWebhook(url string) []ValidationResult {
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Head(url)
	if err != nil {
		return []ValidationResult{{"error", "alerts.webhook.url", fmt.Sprintf("cannot reach %s: %v", url, err)}}
	}
	defer resp.Body.Close()
	// Any HTTP response at all proves the endpoint is reachable.
	return []ValidationResult{{"ok", "alerts.webhook.url", fmt.Sprintf("reachable (HTTP %d)", resp.StatusCode)}}
}
// probeGeoIPDBs checks that expected GeoIP database files exist on disk.
func probeGeoIPDBs(statePath string, editions []string) []ValidationResult {
	var missing []ValidationResult
	for _, edition := range editions {
		dbPath := filepath.Join(statePath, "geoip", edition+".mmdb")
		if _, err := os.Stat(dbPath); err != nil {
			missing = append(missing, ValidationResult{"error", "geoip", fmt.Sprintf("database not found: %s", dbPath)})
		}
	}
	// One error per missing edition, or a single ok when all are present.
	if len(missing) > 0 {
		return missing
	}
	return []ValidationResult{{"ok", "geoip", fmt.Sprintf("all %d edition databases present", len(editions))}}
}
package daemon
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/control"
"github.com/pidginhost/csm/internal/integrity"
"github.com/pidginhost/csm/internal/store"
)
// dispatch parses a raw request line, routes to the right handler, and
// wraps the handler's (result, error) pair in the response envelope.
// Unknown commands fail cleanly rather than crash the listener.
func (c *ControlListener) dispatch(line []byte) control.Response {
	var req control.Request
	if err := json.Unmarshal(line, &req); err != nil {
		return control.Response{OK: false, Error: fmt.Sprintf("bad request: %v", err)}
	}

	// Pick the handler for this command; all handlers share the same shape.
	var handler func(json.RawMessage) (any, error)
	switch req.Cmd {
	case control.CmdTierRun:
		handler = c.handleTierRun
	case control.CmdStatus:
		handler = c.handleStatus
	case control.CmdHistoryRead:
		handler = c.handleHistoryRead
	case control.CmdRulesReload:
		handler = c.handleRulesReload
	case control.CmdGeoIPReload:
		handler = c.handleGeoIPReload
	default:
		return control.Response{OK: false, Error: fmt.Sprintf("unknown command: %q", req.Cmd)}
	}

	result, err := handler(req.Args)
	if err != nil {
		return control.Response{OK: false, Error: err.Error()}
	}
	payload, mErr := json.Marshal(result)
	if mErr != nil {
		return control.Response{OK: false, Error: "result marshal: " + mErr.Error()}
	}
	return control.Response{OK: true, Result: payload}
}
// parseTier maps the wire string onto the checks.Tier constants. Kept
// local to the listener so the protocol package does not depend on the
// checks package.
func parseTier(s string) (checks.Tier, error) {
	// Empty string defaults to the full tier.
	if s == "" || s == "all" {
		return checks.TierAll, nil
	}
	if s == "critical" {
		return checks.TierCritical, nil
	}
	if s == "deep" {
		return checks.TierDeep, nil
	}
	return "", fmt.Errorf("unknown tier: %q", s)
}
// handleTierRun runs a tier synchronously and reports the result. The
// flow mirrors Daemon.runPeriodicChecks: integrity verify, RunTier,
// purge-and-merge, then hand findings to the alert pipeline. The only
// deviation is that we return counts to the caller instead of nothing.
func (c *ControlListener) handleTierRun(argsRaw json.RawMessage) (any, error) {
	// Args are optional; an empty payload parses to the zero value
	// (tier "" => TierAll, Alerts false).
	var args control.TierRunArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	tier, err := parseTier(args.Tier)
	if err != nil {
		return nil, err
	}
	if vErr := integrity.Verify(c.d.binaryPath, c.d.cfg); vErr != nil {
		// Integrity failures are escalated through the normal alert
		// pipeline so the on-call path sees them regardless of who
		// kicked the tier run. The client also gets an error so the
		// systemd timer unit fails loudly.
		if args.Alerts {
			// Non-blocking send: a full channel bumps the drop counter
			// instead of stalling the control socket handler.
			select {
			case c.d.alertCh <- alert.Finding{
				Severity:  alert.Critical,
				Check:     "integrity",
				Message:   fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", vErr),
				Timestamp: time.Now(),
			}:
			default:
				atomic.AddInt64(&c.d.droppedAlerts, 1)
			}
		}
		return nil, fmt.Errorf("integrity verify failed: %w", vErr)
	}
	start := time.Now()
	findings := checks.RunTier(c.d.cfg, c.d.store, tier)
	if len(findings) > 0 {
		c.d.store.PurgeAndMergeFindings(checks.PerfCheckNamesForTier(tier), findings)
		if args.Alerts {
			for _, f := range findings {
				// Perf warnings are kept in the store for the UI but are
				// not worth an alert.
				if strings.HasPrefix(f.Check, "perf_") && f.Severity == alert.Warning {
					continue
				}
				select {
				case c.d.alertCh <- f:
				default:
					atomic.AddInt64(&c.d.droppedAlerts, 1)
				}
			}
		}
	}
	// FilterNew takes a snapshot under the state mutex; racing with the
	// dispatcher that will later process the same findings is fine —
	// both observe consistent state, we just report our view.
	newCount := len(c.d.store.FilterNew(findings))
	return control.TierRunResult{
		Findings:    len(findings),
		NewFindings: newCount,
		ElapsedMs:   time.Since(start).Milliseconds(),
	}, nil
}
// handleStatus reports what `csm status` historically printed from
// disk, sourced from the live daemon instead of re-opening the store.
func (c *ControlListener) handleStatus(_ json.RawMessage) (any, error) {
	res := control.StatusResult{
		Version:        c.d.version,
		LatestFindings: len(c.d.store.LatestFindings()),
		DroppedAlerts:  c.d.DroppedAlerts(),
	}
	// A zero scan time means no scan has completed yet; leave the
	// string empty rather than formatting the zero value.
	if t := c.d.store.LatestScanTime(); !t.IsZero() {
		res.LatestScanTime = t.UTC().Format(time.RFC3339)
	}
	if sdb := store.Global(); sdb != nil {
		res.HistoryCount = sdb.HistoryCount()
	}
	if !c.d.startTime.IsZero() {
		res.UptimeSec = int64(time.Since(c.d.startTime).Seconds())
	}
	return res, nil
}
// handleHistoryRead paginates bbolt history. Clamps Limit so a buggy
// client cannot ask for everything at once; 1000 is well above the
// dashboard page size.
func (c *ControlListener) handleHistoryRead(argsRaw json.RawMessage) (any, error) {
	var args control.HistoryReadArgs
	if len(argsRaw) > 0 {
		if err := json.Unmarshal(argsRaw, &args); err != nil {
			return nil, fmt.Errorf("parsing args: %w", err)
		}
	}
	// Out-of-range paging values fall back to sane defaults rather
	// than returning an error.
	switch {
	case args.Limit <= 0, args.Limit > 1000:
		args.Limit = 100
	}
	if args.Offset < 0 {
		args.Offset = 0
	}
	items, total := c.d.store.ReadHistory(args.Limit, args.Offset)
	return control.HistoryReadResult{Findings: items, Total: total}, nil
}
// handleRulesReload replaces `kill -HUP $(pidof csm)`. Returns after
// the reload completes so the client can confirm it happened.
// The args payload is ignored; the command takes no parameters.
func (c *ControlListener) handleRulesReload(_ json.RawMessage) (any, error) {
	c.d.reloadSignatures()
	return map[string]string{"status": "reloaded"}, nil
}
// handleGeoIPReload is the GeoIP equivalent of rules.reload.
// The args payload is ignored; the command takes no parameters.
func (c *ControlListener) handleGeoIPReload(_ json.RawMessage) (any, error) {
	c.d.publishGeoIP()
	return map[string]string{"status": "reloaded"}, nil
}
package daemon
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"time"
csmlog "github.com/pidginhost/csm/internal/log"
)
// controlSocketPath is the Unix socket the daemon binds for
// CLI-to-daemon IPC.
//
// NOTE(review): the value is presumably mirrored by
// internal/control.DefaultSocketPath, but no compile-time reference
// enforcing that is visible here — NewControlListener uses this
// constant directly. Confirm the two stay in sync manually.
const controlSocketPath = "/var/run/csm/control.sock"

// controlRequestTimeout caps how long a single client request can block
// the listener. Handlers that legitimately take longer (tier.run on a
// large server) run on the accepting goroutine, so the timeout applies
// to reading the request line and writing the response, not to the
// handler body itself.
const controlRequestTimeout = 2 * time.Second
// ControlListener serves the local command-line client over a Unix
// socket. One request and one response per connection, line-framed JSON.
// The daemon keeps exclusive ownership of the bbolt store; this listener
// is the only reason any CLI command needs to reach into daemon state.
type ControlListener struct {
	d        *Daemon      // owning daemon; handlers read its store, config, and counters
	listener net.Listener // bound unix socket; closed by Stop
}
// NewControlListener creates the socket, enforces 0600 perms, and
// returns a listener that the daemon wires into its goroutine pool.
func NewControlListener(d *Daemon) (*ControlListener, error) {
	if err := os.MkdirAll(filepath.Dir(controlSocketPath), 0750); err != nil {
		return nil, fmt.Errorf("creating socket dir: %w", err)
	}
	// A stale socket left behind by a previous crash would make Listen
	// fail with EADDRINUSE; remove it before binding. The file is
	// process-owned.
	_ = os.Remove(controlSocketPath)
	ln, err := net.Listen("unix", controlSocketPath)
	if err != nil {
		return nil, fmt.Errorf("listening on %s: %w", controlSocketPath, err)
	}
	// 0600 root-only: the CLI client also runs as root (fanotify,
	// nftables, cpanel APIs all require it), so no group is needed.
	if chmodErr := os.Chmod(controlSocketPath, 0600); chmodErr != nil {
		_ = ln.Close()
		return nil, fmt.Errorf("chmod socket: %w", chmodErr)
	}
	return &ControlListener{d: d, listener: ln}, nil
}
// Run accepts connections until stopCh closes. Each connection is
// handled on its own goroutine so a slow request never stalls the
// accept loop.
func (c *ControlListener) Run(stopCh <-chan struct{}) {
	for {
		conn, err := c.listener.Accept()
		if err == nil {
			go c.handleConnection(conn)
			continue
		}
		// Accept failed: either we are shutting down (stopCh closed,
		// listener closed by Stop) or it was a transient error worth
		// logging and retrying after a short pause.
		select {
		case <-stopCh:
			return
		default:
		}
		csmlog.Warn("control listener accept error", "err", err)
		time.Sleep(100 * time.Millisecond)
	}
}
// Stop closes the listener and removes the socket file. Safe to call
// after Run has already returned. Closing the listener also unblocks
// any Accept call inside Run; errors are ignored because shutdown is
// best-effort.
func (c *ControlListener) Stop() {
	_ = c.listener.Close()
	_ = os.Remove(controlSocketPath)
}
// handleConnection reads one request, dispatches it, writes one
// response, and closes. The short timeout applies to I/O only; the
// handler itself can take as long as the underlying work requires.
func (c *ControlListener) handleConnection(conn net.Conn) {
	defer func() { _ = conn.Close() }()
	// Read deadline. Writes use a separate deadline set after the
	// handler returns so slow scans don't count against the reader.
	_ = conn.SetReadDeadline(time.Now().Add(controlRequestTimeout))
	scanner := bufio.NewScanner(conn)
	// Requests are single-line JSON but can carry larger Args payloads
	// once baseline/history paging enters the command set. 1 MiB is
	// well above any request we expect and well below any DoS concern
	// on a root-only socket.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	if !scanner.Scan() {
		// Empty, oversized, or timed-out request: drop silently; the
		// client observes the connection close.
		return
	}
	line := scanner.Bytes()
	resp := c.dispatch(line)
	payload, err := json.Marshal(resp)
	if err != nil {
		// Marshalling a Response should not fail; if it does, fall back
		// to a minimal error response the client can still parse.
		payload = []byte(`{"ok":false,"error":"internal: response marshal failed"}`)
	}
	// Line-framed protocol: the client reads up to the newline.
	payload = append(payload, '\n')
	_ = conn.SetWriteDeadline(time.Now().Add(controlRequestTimeout))
	_, _ = conn.Write(payload)
}
package daemon
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/challenge"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/geoip"
"github.com/pidginhost/csm/internal/integrity"
csmlog "github.com/pidginhost/csm/internal/log"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/webui"
"github.com/pidginhost/csm/internal/yara"
)
// Daemon is the main persistent monitoring process.
type Daemon struct {
	cfg              *config.Config
	store            *state.Store    // findings/suppression state; daemon is sole owner
	lock             *state.LockFile // singleton lock, released on shutdown
	binaryPath       string          // own binary path, for integrity verification
	logWatchers      []*LogWatcher
	logWatchersMu    sync.Mutex // guards logWatchers
	fileMonitor      *FileMonitor
	hijackDetector   *PasswordHijackDetector
	pamListener      *PAMListener
	controlListener  *ControlListener
	spoolWatcher     *SpoolWatcher
	spoolWatcherMu   sync.Mutex // guards spoolWatcher
	forwarderWatcher *ForwarderWatcher
	emailQuarantine  *emailav.Quarantine
	webServer        *webui.Server
	challengeServer  *challenge.Server
	ipList           *challenge.IPList
	fwEngine         *firewall.Engine
	geoipDB          *geoip.DB
	geoipMu          sync.Mutex // protects geoipDB for publishGeoIP
	version          string
	alertCh          chan alert.Finding
	droppedAlerts    int64 // atomic counter for alert channel backpressure drops
	stopCh           chan struct{}
	wg               sync.WaitGroup
	smtpAuthTracker  *smtpAuthTracker
	mailAuthTracker  *mailAuthTracker
	startTime        time.Time // set at Run start; used for uptime reporting
}
// New creates a new daemon instance.
func New(cfg *config.Config, store *state.Store, lock *state.LockFile, binaryPath string) *Daemon {
	// mins converts a config value expressed in minutes to a Duration.
	mins := func(n int) time.Duration { return time.Duration(n) * time.Minute }
	th := cfg.Thresholds

	d := &Daemon{
		cfg:        cfg,
		store:      store,
		lock:       lock,
		binaryPath: binaryPath,
		alertCh:    make(chan alert.Finding, 500),
		stopCh:     make(chan struct{}),
	}
	d.smtpAuthTracker = newSMTPAuthTracker(
		th.SMTPBruteForceThreshold,
		th.SMTPBruteForceSubnetThresh,
		th.SMTPAccountSprayThreshold,
		mins(th.SMTPBruteForceWindowMin),
		mins(th.SMTPBruteForceSuppressMin),
		th.SMTPBruteForceMaxTracked,
		time.Now,
	)
	d.mailAuthTracker = newMailAuthTracker(
		th.MailBruteForceThreshold,
		th.MailBruteForceSubnetThresh,
		th.MailAccountSprayThreshold,
		mins(th.MailBruteForceWindowMin),
		mins(th.MailBruteForceSuppressMin),
		th.MailBruteForceMaxTracked,
		time.Now,
	)
	return d
}
// SetVersion sets the application version for display in the web UI.
// Call before Run; the field is read without synchronization.
func (d *Daemon) SetVersion(v string) {
	d.version = v
}
// Run starts the daemon and blocks until stopped.
//
// Startup order matters: platform overrides before Detect, integrity
// before anything else runs, real-time watchers before the initial
// scan, and the alert dispatcher only after the initial scan completes
// (so the two never race). Returns nil on clean shutdown, or an error
// if integrity verification fails at startup.
func (d *Daemon) Run() error {
	d.startTime = time.Now()
	// Initialize structured logging from environment (CSM_LOG_FORMAT,
	// CSM_LOG_LEVEL). The default text handler preserves the legacy
	// "[YYYY-MM-DD HH:MM:SS] msg" format so operators mixing csmlog
	// with legacy fmt.Fprintf call sites see a uniform log stream.
	// CSM_LOG_FORMAT=json switches to structured JSON for log shipping.
	csmlog.Init()
	csmlog.Info("CSM daemon starting")
	// Install config-supplied platform overrides BEFORE the first Detect()
	// call so every check sees the merged view. Must happen before any
	// other code calls platform.Detect() in this daemon run.
	//
	// WebServer is a *platform.WebServer pointer — take address only when
	// the operator actually supplied a non-empty type, otherwise leave
	// the auto-detected value alone.
	var wsOverride *platform.WebServer
	if t := d.cfg.WebServer.Type; t != "" {
		ws := platform.WebServer(t)
		wsOverride = &ws
	}
	platform.SetOverrides(platform.Overrides{
		WebServer:           wsOverride,
		ApacheConfigDir:     d.cfg.WebServer.ConfigDir,
		AccessLogPaths:      d.cfg.WebServer.AccessLogs,
		ErrorLogPaths:       d.cfg.WebServer.ErrorLogs,
		ModSecAuditLogPaths: d.cfg.WebServer.ModSecAudits,
	})
	// Log detected platform as a structured record. In text mode this
	// comes out as "[ts] platform detected os=X panel=Y ..."; in
	// JSON mode as {"msg":"platform detected","os":"X","panel":"Y",...}.
	pi := platform.Detect()
	csmlog.Info("platform detected",
		"os", orUnknown(string(pi.OS)),
		"os_version", orUnknown(pi.OSVersion),
		"panel", orNone(string(pi.Panel)),
		"webserver", orNone(string(pi.WebServer)),
	)
	// Verify integrity on startup. Tamper is both alerted and fatal:
	// the daemon refuses to run from a modified binary/config.
	if err := integrity.Verify(d.binaryPath, d.cfg); err != nil {
		tamper := alert.Finding{
			Severity:  alert.Critical,
			Check:     "integrity",
			Message:   fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", err),
			Timestamp: time.Now(),
		}
		_ = alert.Dispatch(d.cfg, []alert.Finding{tamper})
		return fmt.Errorf("integrity check failed: %w", err)
	}
	// Deploy WHM plugin and configs if cPanel is present
	deployConfigs()
	// Initialize signature scanners and threat DB (fast, no I/O scan)
	if yaraScanner := yara.Init(d.cfg.Signatures.RulesDir); yaraScanner != nil {
		fmt.Fprintf(os.Stderr, "[%s] YARA-X scanner active: %d rule file(s)\n", ts(), yaraScanner.RuleCount())
	}
	checks.InitThreatDB(d.cfg.StatePath, d.cfg.Reputation.Whitelist)
	if db := checks.GetThreatDB(); db != nil {
		fmt.Fprintf(os.Stderr, "[%s] Threat DB initialized (%d entries)\n", ts(), db.Count())
	}
	if adb := attackdb.Init(d.cfg.StatePath); adb != nil {
		// Seed from permanent blocklist on first run (when attack DB is empty)
		if adb.TotalIPs() == 0 {
			if n := adb.SeedFromPermanentBlocklist(d.cfg.StatePath); n > 0 {
				fmt.Fprintf(os.Stderr, "[%s] Attack DB seeded %d IPs from permanent blocklist\n", ts(), n)
			}
		}
		fmt.Fprintf(os.Stderr, "[%s] Attack DB initialized (%s)\n", ts(), adb.FormatTopLine())
	}
	// Start firewall engine if enabled
	d.startFirewall()
	// Start challenge server if enabled (gray listing)
	d.startChallengeServer()
	// Start challenge escalation ticker
	if d.ipList != nil {
		d.wg.Add(1)
		go d.challengeEscalator()
	}
	// Create password hijack detector
	d.hijackDetector = NewPasswordHijackDetector(d.cfg, d.alertCh)
	// Start inotify log watchers
	d.startLogWatchers()
	// Start PAM listener for real-time brute-force detection
	d.startPAMListener()
	// Start control socket listener for the thin-client CLI. The
	// daemon is the sole bbolt owner; CLI commands that previously
	// raced for the lock now route through this socket.
	d.startControlListener()
	// Start fanotify file monitor (real-time detection starts immediately)
	d.startFileMonitor()
	// Start email AV spool watcher (separate fanotify for Exim spool).
	// Spool and forwarder watchers are cPanel-only; they watch paths
	// (/var/spool/exim, /etc/valiases) that only exist on cPanel hosts.
	if platform.Detect().IsCPanel() {
		d.startSpoolWatcher()
		d.startForwarderWatcher()
	}
	// Start Web UI server - available immediately, before initial scan
	d.startWebUI()
	// Wire email quarantine to web server (after both start)
	d.syncEmailAVWebState()
	// Initialize GeoIP databases (after webServer so SetGeoIPDB can attach)
	d.initGeoIP()
	// Run initial scan synchronously (before dispatcher starts)
	fmt.Fprintf(os.Stderr, "[%s] Running initial baseline scan...\n", ts())
	initialFindings := checks.RunTier(d.cfg, d.store, checks.TierCritical)
	// Seed the attack database with initial scan findings
	if adb := attackdb.Global(); adb != nil {
		for _, f := range initialFindings {
			adb.RecordFinding(f)
		}
	}
	d.store.AppendHistory(initialFindings)
	newFindings := d.store.FilterNew(initialFindings)
	suppressions := d.store.LoadSuppressions()
	initialAutoResponseFindings := initialFindings
	if len(suppressions) > 0 {
		initialAutoResponseFindings = filterUnsuppressedFindings(d.store, initialFindings, suppressions)
		newFindings = filterUnsuppressedFindings(d.store, newFindings, suppressions)
	}
	// Permission auto-fix runs on ALL findings (not just new) because
	// it's safe/idempotent and should fix baseline findings too.
	permActions, permFixedKeys := checks.AutoFixPermissions(d.cfg, initialAutoResponseFindings)
	// Challenge routing runs on ALL findings unconditionally when enabled.
	challengeActions := checks.ChallengeRouteIPs(d.cfg, initialAutoResponseFindings)
	// Other auto-response only on new findings
	if len(newFindings) > 0 {
		killActions := checks.AutoKillProcesses(d.cfg, newFindings)
		quarantineActions := checks.AutoQuarantineFiles(d.cfg, newFindings)
		blockActions := checks.AutoBlockIPs(d.cfg, initialAutoResponseFindings)
		newFindings = append(newFindings, killActions...)
		newFindings = append(newFindings, quarantineActions...)
		newFindings = append(newFindings, permActions...)
		newFindings = append(newFindings, challengeActions...)
		newFindings = append(newFindings, blockActions...)
		_ = alert.Dispatch(d.cfg, newFindings)
	}
	// Remove auto-fixed findings before storing to UI
	if len(permFixedKeys) > 0 {
		fixedSet := make(map[string]bool, len(permFixedKeys))
		for _, k := range permFixedKeys {
			fixedSet[k] = true
		}
		var filtered []alert.Finding
		for _, f := range initialFindings {
			key := f.Check + ":" + f.Message
			if !fixedSet[key] {
				filtered = append(filtered, f)
			}
		}
		initialFindings = filtered
	}
	d.store.Update(initialFindings)
	// Merge initial scan findings into the existing set. Previous deep scan
	// results (outdated_plugins, wp_core, etc.) persist across restarts until
	// the next deep scan replaces them. ClearLatestFindings is NOT called
	// here - it would wipe deep scan findings that haven't re-run yet.
	d.store.SetLatestFindings(initialFindings)
	csmlog.Info("initial scan complete", "findings", len(initialFindings), "new", len(newFindings))
	// NOW start the alert dispatcher - no more race with initial scan
	d.wg.Add(1)
	go d.alertDispatcher()
	// Start periodic scanners
	d.wg.Add(1)
	go d.criticalScanner()
	d.wg.Add(1)
	go d.deepScanner()
	// Start automatic signature updates
	d.wg.Add(1)
	go d.signatureUpdater()
	d.wg.Add(1)
	go d.geoipUpdater()
	// Start heartbeat
	d.wg.Add(1)
	go d.heartbeat()
	// Start systemd watchdog notifier — independent goroutine with its own
	// ticker so long-running scans don't block the heartbeat.
	d.wg.Add(1)
	go d.watchdogNotifier()
	csmlog.Info("CSM daemon running")
	// Wait for signals. SIGHUP reloads; SIGTERM/SIGINT shut down.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
	for sig := range sigCh {
		if sig == syscall.SIGHUP {
			fmt.Fprintf(os.Stderr, "[%s] SIGHUP received - reloading rules\n", ts())
			d.reloadSignatures()
			d.publishGeoIP()
			if d.fwEngine != nil {
				if err := d.fwEngine.Apply(); err != nil {
					fmt.Fprintf(os.Stderr, "[%s] Firewall reload error: %v\n", ts(), err)
				} else {
					fmt.Fprintf(os.Stderr, "[%s] Firewall rules reloaded\n", ts())
				}
			}
			continue
		}
		break // SIGTERM or SIGINT
	}
	csmlog.Info("shutting down")
	close(d.stopCh)
	// Stop all watchers. Snapshot under the mutex, stop outside it.
	d.logWatchersMu.Lock()
	watchers := d.logWatchers
	d.logWatchersMu.Unlock()
	for _, w := range watchers {
		w.Stop()
	}
	if d.webServer != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		_ = d.webServer.Shutdown(ctx)
		cancel()
	}
	if d.challengeServer != nil {
		d.challengeServer.Shutdown()
	}
	if d.fileMonitor != nil {
		d.fileMonitor.Stop()
	}
	if sw := d.getSpoolWatcher(); sw != nil {
		sw.Stop()
	}
	if d.pamListener != nil {
		d.pamListener.Stop()
	}
	if d.controlListener != nil {
		d.controlListener.Stop()
	}
	// Wait for all goroutines started via d.wg before tearing down
	// shared resources (attack DB, state store).
	d.wg.Wait()
	if adb := attackdb.Global(); adb != nil {
		adb.Stop()
	}
	if err := d.store.Close(); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] error closing state store: %v\n", ts(), err)
	}
	d.lock.Release()
	fmt.Fprintf(os.Stderr, "[%s] CSM daemon stopped\n", ts())
	return nil
}
// DroppedAlerts returns the total number of alerts dropped due to
// channel backpressure since the daemon started. Safe for concurrent
// use; the counter is maintained with sync/atomic.
func (d *Daemon) DroppedAlerts() int64 {
	return atomic.LoadInt64(&d.droppedAlerts)
}
// alertDispatcher batches and dispatches alerts.
//
// Findings arrive on d.alertCh and are accumulated into a batch that is
// flushed every 5 seconds; batching lets dispatchBatch deduplicate and
// correlate related findings instead of alerting one-by-one. On stop,
// the channel is drained before the final flush so findings queued just
// before shutdown are not silently lost (the original dropped them).
func (d *Daemon) alertDispatcher() {
	defer d.wg.Done()
	// Batch alerts: collect for 5 seconds, then dispatch
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	var batch []alert.Finding
	for {
		select {
		case <-d.stopCh:
			// Drain anything still buffered in the channel. Senders use
			// non-blocking sends, so this loop terminates as soon as
			// the buffer is empty.
		drain:
			for {
				select {
				case f := <-d.alertCh:
					batch = append(batch, f)
				default:
					break drain
				}
			}
			// Flush remaining with a timeout to avoid blocking shutdown
			if len(batch) > 0 {
				ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
				done := make(chan struct{})
				go func() {
					d.dispatchBatch(batch)
					close(done)
				}()
				select {
				case <-done:
				case <-ctx.Done():
					fmt.Fprintf(os.Stderr, "[%s] shutdown flush timed out after 30s, %d findings not dispatched\n", ts(), len(batch))
				}
				cancel()
			}
			return
		case f := <-d.alertCh:
			batch = append(batch, f)
		case <-ticker.C:
			if len(batch) > 0 {
				d.dispatchBatch(batch)
				batch = nil
			}
		}
	}
}
// dispatchBatch runs one alert cycle over a batch of findings:
// dedupe, suppression filtering, attack-DB recording, auto-response,
// new-finding filtering, history logging, correlation, and finally
// email/webhook dispatch. The ordering of steps is load-bearing; see
// the inline comments.
func (d *Daemon) dispatchBatch(findings []alert.Finding) {
	findings = alert.Deduplicate(findings)
	suppressions := d.store.LoadSuppressions()
	autoResponseFindings := findings
	if len(suppressions) > 0 {
		autoResponseFindings = filterUnsuppressedFindings(d.store, findings, suppressions)
	}
	// Record ALL findings in attack database (before filtering -
	// repeated attacks from the same IP must still be counted even if
	// the alert is suppressed by FilterNew).
	if adb := attackdb.Global(); adb != nil {
		for _, f := range autoResponseFindings {
			adb.RecordFinding(f)
		}
	}
	// Auto-block and permission fix run on ALL findings (not just new ones).
	// These must execute BEFORE FilterNew because repeat offender IPs and
	// recurring permission issues need to be fixed even if the alert was
	// already sent in a previous cycle.
	// Challenge routing runs FIRST - claims eligible IPs before hard-blocking.
	challengeActions := checks.ChallengeRouteIPs(d.cfg, autoResponseFindings)
	blockActions := checks.AutoBlockIPs(d.cfg, autoResponseFindings)
	permActions, permFixedKeys := checks.AutoFixPermissions(d.cfg, autoResponseFindings)
	// Mark auto-blocked IPs in attack database
	if adb := attackdb.Global(); adb != nil {
		for _, f := range blockActions {
			if ip := checks.ExtractIPFromFinding(f); ip != "" {
				adb.MarkBlocked(ip)
			}
		}
	}
	// Dismiss auto-fixed findings from the Findings page
	for _, key := range permFixedKeys {
		d.store.DismissLatestFinding(key)
	}
	// Filter through state - only new findings get alerted and logged
	newFindings := d.store.FilterNew(findings)
	// Filter out suppressed findings - prevents email/webhook alerts for
	// paths the admin has explicitly suppressed (e.g. false positives).
	// Suppressions are stored in state/suppressions.json, not in rule files.
	if len(suppressions) > 0 {
		newFindings = filterUnsuppressedFindings(d.store, newFindings, suppressions)
	}
	// Append auto-response actions to new findings for alerting
	newFindings = append(newFindings, blockActions...)
	newFindings = append(newFindings, challengeActions...)
	newFindings = append(newFindings, permActions...)
	if len(newFindings) == 0 {
		// Nothing to alert; still update state so repeats stay filtered.
		d.store.Update(findings)
		return
	}
	// Log to history
	d.store.AppendHistory(newFindings)
	// Kill, quarantine, and DB cleanup only run on NEW findings
	killActions := checks.AutoKillProcesses(d.cfg, newFindings)
	quarantineActions := checks.AutoQuarantineFiles(d.cfg, newFindings)
	dbCleanActions := checks.AutoRespondDBMalware(d.cfg, newFindings)
	newFindings = append(newFindings, killActions...)
	newFindings = append(newFindings, quarantineActions...)
	newFindings = append(newFindings, dbCleanActions...)
	// Correlation: derive higher-level findings from the batch; any
	// correlated finding without a timestamp gets the current time.
	extra := checks.CorrelateFindings(newFindings)
	now := time.Now()
	for i := range extra {
		if extra[i].Timestamp.IsZero() {
			extra[i].Timestamp = now
		}
	}
	newFindings = append(newFindings, extra...)
	// Broadcast findings (no-op; dashboard uses polling)
	if d.webServer != nil {
		d.webServer.Broadcast(newFindings)
	}
	// Dispatch via email/webhook - filter out findings that are
	// informational or fully automated (no human action needed).
	// These are all visible in the web UI for forensics.
	var alertable []alert.Finding
	for _, f := range newFindings {
		switch f.Check {
		case "modsec_block_realtime", "modsec_warning_realtime", "modsec_csm_block_escalation":
			continue // ModSecurity: fully automated, visible on /modsec
		case "outdated_plugins":
			continue // informational, visible on findings page
		case "email_dkim_failure", "email_spf_rejection":
			continue // operational email auth issues - visible on findings page
		case "email_auth_failure_realtime", "pam_bruteforce", "exim_frozen_realtime":
			continue // failed logins and frozen bounces - informational, no action needed
		}
		alertable = append(alertable, f)
	}
	if err := alert.Dispatch(d.cfg, alertable); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Alert dispatch error: %v\n", ts(), err)
	}
	d.store.Update(findings)
}
// filterUnsuppressedFindings returns only the findings that do not match
// any of the given suppression rules. With no rules it returns the
// input slice unchanged.
func filterUnsuppressedFindings(store *state.Store, findings []alert.Finding, suppressions []state.SuppressionRule) []alert.Finding {
	if len(suppressions) == 0 {
		return findings
	}
	var kept []alert.Finding
	for _, finding := range findings {
		if store.IsSuppressed(finding, suppressions) {
			continue
		}
		kept = append(kept, finding)
	}
	return kept
}
// criticalScanner runs critical checks every 10 minutes.
func (d *Daemon) criticalScanner() {
	defer d.wg.Done()
	tick := time.NewTicker(10 * time.Minute)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			d.runPeriodicChecks(checks.TierCritical)
		case <-d.stopCh:
			return
		}
	}
}
// deepScanner runs deep checks at the configured interval (default 60 min).
// If fanotify is active, runs only the checks it can't replace (reduced set).
// If fanotify is NOT active (fallback mode), runs the full deep tier for timer-mode parity.
func (d *Daemon) deepScanner() {
	defer d.wg.Done()
	interval := time.Duration(d.cfg.Thresholds.DeepScanIntervalMin) * time.Minute
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-ticker.C:
			// Update threat intelligence feeds (once per day)
			if db := checks.GetThreatDB(); db != nil {
				_ = db.UpdateFeeds()
			}
			// Prune expired attack DB records (90-day retention)
			if adb := attackdb.Global(); adb != nil {
				adb.PruneExpired()
			}
			// Both branches below run deep-tier work, so the tier is the
			// same either way; hoisted out of the if/else (the original
			// assigned checks.TierDeep redundantly in both branches).
			deepTier := checks.TierDeep
			var findings []alert.Finding
			if d.fileMonitor != nil {
				// fanotify active: run only the checks it can't replace.
				findings = checks.RunReducedDeep(d.cfg, d.store)
			} else {
				// fanotify fallback: run the full deep tier.
				findings = checks.RunTier(d.cfg, d.store, deepTier)
			}
			if len(findings) > 0 {
				// Atomically purge stale perf findings and merge new ones.
				d.store.PurgeAndMergeFindings(checks.PerfCheckNamesForTier(deepTier), findings)
				for _, f := range findings {
					// Perf warnings stay in the store for the UI but are
					// not worth an alert.
					if strings.HasPrefix(f.Check, "perf_") && f.Severity == alert.Warning {
						continue
					}
					select {
					case d.alertCh <- f:
					default:
						atomic.AddInt64(&d.droppedAlerts, 1)
						fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping deep finding: %s\n", ts(), f.Check)
					}
				}
			}
		}
	}
}
// runPeriodicChecks verifies binary/config integrity, runs the given
// tier, merges results into the store, and feeds non-informational
// findings to the alert channel. Used by criticalScanner.
func (d *Daemon) runPeriodicChecks(tier checks.Tier) {
	// Verify integrity; on tamper, alert (non-blocking) and skip the scan.
	if err := integrity.Verify(d.binaryPath, d.cfg); err != nil {
		select {
		case d.alertCh <- alert.Finding{
			Severity:  alert.Critical,
			Check:     "integrity",
			Message:   fmt.Sprintf("BINARY/CONFIG TAMPER DETECTED: %v", err),
			Timestamp: time.Now(),
		}:
		default:
			atomic.AddInt64(&d.droppedAlerts, 1)
			fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping integrity finding\n", ts())
		}
		return
	}
	findings := checks.RunTier(d.cfg, d.store, tier)
	if len(findings) > 0 {
		// Atomically purge stale perf findings and merge new ones.
		d.store.PurgeAndMergeFindings(checks.PerfCheckNamesForTier(tier), findings)
		for _, f := range findings {
			// Perf warnings stay visible in the UI but are not alerted.
			if strings.HasPrefix(f.Check, "perf_") && f.Severity == alert.Warning {
				continue
			}
			select {
			case d.alertCh <- f:
			default:
				atomic.AddInt64(&d.droppedAlerts, 1)
				fmt.Fprintf(os.Stderr, "[%s] alert channel full, dropping periodic finding: %s\n", ts(), f.Check)
			}
		}
	}
}
// heartbeat sends periodic pings to the dead man's switch and piggybacks
// 10-minute housekeeping (hijack-detector cleanup, expired firewall allows,
// expired threat-DB whitelist entries) on the same ticker.
func (d *Daemon) heartbeat() {
	defer d.wg.Done()
	t := time.NewTicker(10 * time.Minute)
	defer t.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-t.C:
		}
		alert.SendHeartbeat(d.cfg)
		d.hijackDetector.Cleanup()
		// Drop expired temporary firewall allows and subnets.
		if fw := d.fwEngine; fw != nil {
			fw.CleanExpiredAllows()
			fw.CleanExpiredSubnets()
		}
		// Drop expired temporary whitelist entries from the threat DB.
		if tdb := checks.GetThreatDB(); tdb != nil {
			tdb.PruneExpiredWhitelist()
		}
	}
}
// startLogWatchers registers tail-watchers for every log file relevant to
// the detected platform (OS auth log, cPanel/exim/dovecot logs, PHP Shield
// events, ModSecurity and web-server access logs) and starts the background
// eviction/purge goroutines that keep their in-memory tracking bounded.
//
// Fix: the registration loop at the bottom now appends to d.logWatchers
// under d.logWatchersMu. The retry goroutines launched earlier in this very
// function (retryLogWatcher and the ModSecurity retry loop) append to the
// same slice under the mutex, so the previously unlocked append was a data
// race.
func (d *Daemon) startLogWatchers() {
	// Session log handler wrapper - feeds events to both the alert handler
	// and the hijack detector.
	sessionHandler := func(line string, cfg *config.Config) []alert.Finding {
		// Feed to hijack detector (tracks password changes + correlates with logins)
		ParseSessionLineForHijack(line, d.hijackDetector)
		// Regular session log handling
		return parseSessionLogLine(line, cfg)
	}
	hostInfo := platform.Detect()
	type logFile struct {
		path    string
		handler func(string, *config.Config) []alert.Finding
	}
	var logFiles []logFile
	// Generic Linux auth log. RHEL-family uses /var/log/secure, Debian
	// family uses /var/log/auth.log. Only register the log appropriate
	// for the detected OS so we don't spam "not found, retrying" forever.
	if hostInfo.IsDebianFamily() {
		logFiles = append(logFiles, logFile{"/var/log/auth.log", parseSecureLogLine})
	} else {
		logFiles = append(logFiles, logFile{"/var/log/secure", parseSecureLogLine})
	}
	// eximHandler wraps parseEximLogLine (unchanged) and augments the result
	// with smtpAuthTracker findings for dovecot authenticator failures.
	eximHandler := func(line string, cfg *config.Config) []alert.Finding {
		findings := parseEximLogLine(line, cfg)
		if strings.Contains(line, "authenticator failed") && strings.Contains(line, "dovecot") {
			ip := extractBracketedIP(line)
			account := extractSetID(line)
			// Canonicalize IPv4-mapped IPv6 (::ffff:a.b.c.d) to plain IPv4 so the
			// tracker doesn't double-count the same attacker as two IPs.
			if ip != "" {
				if parsed := net.ParseIP(ip); parsed != nil {
					if v4 := parsed.To4(); v4 != nil {
						ip = v4.String()
					}
				}
			}
			if ip != "" && !isInfraIPDaemon(ip, cfg.InfraIPs) && !isPrivateOrLoopback(ip) {
				if d.smtpAuthTracker != nil {
					findings = append(findings, d.smtpAuthTracker.Record(ip, account)...)
				}
			}
		}
		return findings
	}
	// mailHandler composes parseDovecotLogLine (preserving email_suspicious_geo)
	// with mailAuthTracker augmentation for IMAP/POP3/ManageSieve brute-force,
	// subnet spray, account spray, and compromise detection.
	mailHandler := func(line string, cfg *config.Config) []alert.Finding {
		findings := parseDovecotLogLine(line, cfg)
		if !isMailAuthLine(line) {
			return findings
		}
		ip, account, success := extractMailLoginEvent(line)
		if ip == "" {
			return findings
		}
		// Same IPv4-mapped-IPv6 canonicalization as the exim handler.
		if parsed := net.ParseIP(ip); parsed != nil {
			if v4 := parsed.To4(); v4 != nil {
				ip = v4.String()
			}
		}
		if isInfraIPDaemon(ip, cfg.InfraIPs) || isPrivateOrLoopback(ip) {
			return findings
		}
		if d.mailAuthTracker == nil {
			return findings
		}
		if success {
			findings = append(findings, d.mailAuthTracker.RecordSuccess(ip, account)...)
		} else {
			findings = append(findings, d.mailAuthTracker.Record(ip, account)...)
		}
		return findings
	}
	// cPanel-specific logs — only watch these on cPanel hosts. On plain
	// Ubuntu/AlmaLinux they do not exist and the old code spammed
	// "not found, will retry every 60s" forever.
	if hostInfo.IsCPanel() {
		logFiles = append(logFiles,
			logFile{"/usr/local/cpanel/logs/session_log", sessionHandler},
			logFile{"/usr/local/cpanel/logs/access_log", parseAccessLogLineEnhanced},
			logFile{"/var/log/exim_mainlog", eximHandler},
			logFile{"/var/log/messages", parseFTPLogLine},
			logFile{"/var/log/maillog", mailHandler},
		)
	}
	// Only watch PHP Shield events if enabled in config.
	if d.cfg.PHPShield.Enabled {
		logFiles = append(logFiles, logFile{phpEventsLogPath, parsePHPShieldLogLine})
	}
	// ModSecurity error log - auto-discover path based on detected web server.
	if modsecPath := discoverModSecLogPath(d.cfg); modsecPath != "" {
		logFiles = append(logFiles, logFile{modsecPath, parseModSecLogLineDeduped})
	} else if hostInfo.WebServer != platform.WSNone {
		// Only bother with the retry loop if a web server is actually
		// present. Headless hosts don't need this.
		fmt.Fprintf(os.Stderr, "[%s] ModSecurity error log not found (checked %v), will retry every 60s\n", ts(), hostInfo.ErrorLogPaths)
		d.wg.Add(1)
		go func() {
			defer d.wg.Done()
			ticker := time.NewTicker(60 * time.Second)
			defer ticker.Stop()
			for {
				select {
				case <-d.stopCh:
					return
				case <-ticker.C:
					path := discoverModSecLogPath(d.cfg)
					if path == "" {
						continue
					}
					w, err := NewLogWatcher(path, d.cfg, parseModSecLogLineDeduped, d.alertCh)
					if err != nil {
						continue
					}
					d.logWatchersMu.Lock()
					d.logWatchers = append(d.logWatchers, w)
					d.logWatchersMu.Unlock()
					d.wg.Add(1)
					go func(w *LogWatcher) {
						defer d.wg.Done()
						w.Run(d.stopCh)
					}(w)
					csmlog.Info("watching log (appeared after retry)", "path", path)
					return
				}
			}
		}()
	}
	// Real-time access log watcher for wp-login/xmlrpc brute force detection.
	// Auto-discover path from platform info (Apache/Nginx/cPanel aware).
	if accessLogPath := discoverAccessLogPath(); accessLogPath != "" {
		logFiles = append(logFiles, logFile{accessLogPath, parseAccessLogBruteForce})
	} else if hostInfo.WebServer != platform.WSNone && len(hostInfo.AccessLogPaths) > 0 {
		csmlog.Warn("access log not found, will retry every 60s", "candidates", fmt.Sprintf("%v", hostInfo.AccessLogPaths))
		d.wg.Add(1)
		go d.retryLogWatcher(hostInfo.AccessLogPaths[0], parseAccessLogBruteForce)
	}
	// Start background eviction for modsec dedup/escalation state.
	StartModSecEviction(d.stopCh)
	// Start background eviction for access log brute force state.
	StartAccessLogEviction(d.stopCh)
	// Start background eviction for email rate limiting state.
	StartEmailRateEviction(d.stopCh)
	// Start background purge for SMTP brute-force tracker.
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-d.stopCh:
				return
			case <-ticker.C:
				if d.smtpAuthTracker != nil {
					d.smtpAuthTracker.Purge()
				}
			}
		}
	}()
	// Start background purge for mail (IMAP/POP3) brute-force tracker.
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-d.stopCh:
				return
			case <-ticker.C:
				if d.mailAuthTracker != nil {
					d.mailAuthTracker.Purge()
				}
			}
		}
	}()
	for _, lf := range logFiles {
		w, err := NewLogWatcher(lf.path, d.cfg, lf.handler, d.alertCh)
		if err != nil {
			if os.IsNotExist(err) {
				// File doesn't exist yet - retry periodically until it appears.
				d.wg.Add(1)
				go d.retryLogWatcher(lf.path, lf.handler)
			} else {
				fmt.Fprintf(os.Stderr, "[%s] Warning: could not watch %s: %v\n", ts(), lf.path, err)
			}
			continue
		}
		// Must hold the mutex: retry goroutines launched above may already
		// be appending to d.logWatchers concurrently.
		d.logWatchersMu.Lock()
		d.logWatchers = append(d.logWatchers, w)
		d.logWatchersMu.Unlock()
		d.wg.Add(1)
		go func(w *LogWatcher) {
			defer d.wg.Done()
			w.Run(d.stopCh)
		}(w)
		csmlog.Info("watching log", "path", lf.path)
	}
}
// retryLogWatcher polls for a missing log file every 60 seconds. As soon
// as the file can be opened it registers a watcher, starts its goroutine,
// and exits.
func (d *Daemon) retryLogWatcher(path string, handler LogLineHandler) {
	defer d.wg.Done()
	csmlog.Warn("log not found, will retry every 60s", "path", path)
	t := time.NewTicker(60 * time.Second)
	defer t.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-t.C:
		}
		w, err := NewLogWatcher(path, d.cfg, handler, d.alertCh)
		if err != nil {
			continue // still missing, keep retrying
		}
		d.logWatchersMu.Lock()
		d.logWatchers = append(d.logWatchers, w)
		d.logWatchersMu.Unlock()
		d.wg.Add(1)
		go func(lw *LogWatcher) {
			defer d.wg.Done()
			lw.Run(d.stopCh)
		}(w)
		csmlog.Info("watching log (appeared after retry)", "path", path)
		return
	}
}
// startWebUI boots the embedded web dashboard when enabled in config and
// seeds it with version, signature count, health info and the IP blocker.
func (d *Daemon) startWebUI() {
	if !d.cfg.WebUI.Enabled {
		return
	}
	srv, err := webui.New(d.cfg, d.store)
	if err != nil {
		csmlog.Error("webui init error", "err", err)
		return
	}
	// Expose version and loaded signature count through the status API.
	srv.SetVersion(d.version)
	if sc := signatures.Global(); sc != nil {
		srv.SetSigCount(sc.RuleCount())
	}
	d.webServer = srv
	// Snapshot the watcher count under the mutex for the health endpoint.
	d.logWatchersMu.Lock()
	watcherCount := len(d.logWatchers)
	d.logWatchersMu.Unlock()
	srv.SetHealthInfo(d.fileMonitor != nil, watcherCount)
	if d.fwEngine != nil {
		srv.SetIPBlocker(d.fwEngine)
	}
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		if err := srv.Start(); err != nil {
			csmlog.Error("webui server error", "err", err)
		}
	}()
}
// startPAMListener starts the PAM event socket listener; absence of the
// listener is non-fatal and only logged.
func (d *Daemon) startPAMListener() {
	listener, err := NewPAMListener(d.cfg, d.alertCh)
	if err != nil {
		csmlog.Warn("PAM listener not available", "err", err)
		return
	}
	d.pamListener = listener
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		listener.Run(d.stopCh)
	}()
	csmlog.Info("PAM listener active", "socket", pamSocketPath)
}
// startControlListener starts the Unix-socket control channel used by the CLI.
func (d *Daemon) startControlListener() {
	listener, err := NewControlListener(d)
	if err != nil {
		// The daemon can still function without the socket — periodic
		// scans and webui keep running — but the CLI will hard-error
		// because the socket is the expected path. Log loudly.
		csmlog.Error("control listener not available", "err", err)
		return
	}
	d.controlListener = listener
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		listener.Run(d.stopCh)
	}()
	csmlog.Info("control listener active", "socket", controlSocketPath)
}
// startFileMonitor starts the fanotify-based realtime file monitor. When
// fanotify is unavailable d.fileMonitor stays nil, which makes deepScanner
// run the full deep tier as a fallback.
//
// Fix: the "active" log line now lists /var/tmp too — NewFileMonitor marks
// /home, /tmp, /dev/shm AND /var/tmp, but the old message omitted the last
// one and misled operators about coverage.
func (d *Daemon) startFileMonitor() {
	fm, err := NewFileMonitor(d.cfg, d.alertCh)
	if err != nil {
		csmlog.Warn("fanotify not available, falling back to periodic deep scan", "err", err)
		return
	}
	d.fileMonitor = fm
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		fm.Run(d.stopCh)
	}()
	csmlog.Info("fanotify file monitor active", "paths", "/home, /tmp, /dev/shm, /var/tmp")
}
// startSpoolWatcher wires up the email AV pipeline: ClamAV + YARA-X
// scanners, the orchestrator, quarantine storage, the self-restarting
// spool watcher loop, and the quarantine cleanup goroutine.
func (d *Daemon) startSpoolWatcher() {
	if !d.cfg.EmailAV.Enabled {
		return
	}
	// ClamAV engine talks to the local clamd socket.
	clam := emailav.NewClamdScanner(d.cfg.EmailAV.ClamdSocket)
	// YARA-X engine reuses the globally compiled rules when available.
	var yx *emailav.YaraXScanner
	if gs := yara.Global(); gs != nil {
		yx = emailav.NewYaraXScanner(gs)
	} else {
		yx = emailav.NewYaraXScanner(nil)
	}
	// Orchestrator fans messages out to both engines with a scan timeout.
	orch := emailav.NewOrchestrator([]emailav.Scanner{clam, yx}, d.cfg.EmailAV.ScanTimeoutDuration())
	quar := emailav.NewQuarantine("/opt/csm/quarantine/email")
	d.emailQuarantine = quar
	sw, err := NewSpoolWatcher(d.cfg, d.alertCh, orch, quar)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher not available: %v\n", ts(), err)
		return
	}
	d.setSpoolWatcher(sw)
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		d.runSpoolWatcherLoop(sw, orch, quar)
	}()
	// Periodically remove expired quarantined messages.
	d.wg.Add(1)
	go d.emailQuarantineCleanup()
	fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher active\n", ts())
}
// runSpoolWatcherLoop runs the spool watcher and transparently restarts it
// with a 2-second backoff whenever Run returns while the daemon is still up.
func (d *Daemon) runSpoolWatcherLoop(sw *SpoolWatcher, orch *emailav.Orchestrator, quar *emailav.Quarantine) {
	active := sw
	for {
		active.Run()
		// Run returned: either shutdown is in progress or the watcher died.
		select {
		case <-d.stopCh:
			return
		default:
		}
		fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher stopped unexpectedly; restarting in 2s\n", ts())
		select {
		case <-d.stopCh:
			return
		case <-time.After(2 * time.Second):
		}
		replacement, err := NewSpoolWatcher(d.cfg, d.alertCh, orch, quar)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[%s] Email AV spool watcher restart failed: %v\n", ts(), err)
			continue
		}
		active = replacement
		d.setSpoolWatcher(replacement)
	}
}
// setSpoolWatcher publishes the current spool watcher and refreshes the
// web UI's email AV state. The mutex is released before the sync call:
// syncEmailAVWebState re-acquires it through getSpoolWatcher.
func (d *Daemon) setSpoolWatcher(sw *SpoolWatcher) {
	d.spoolWatcherMu.Lock()
	d.spoolWatcher = sw
	d.spoolWatcherMu.Unlock()
	d.syncEmailAVWebState()
}
// getSpoolWatcher returns the currently active spool watcher (may be nil).
func (d *Daemon) getSpoolWatcher() *SpoolWatcher {
	d.spoolWatcherMu.Lock()
	sw := d.spoolWatcher
	d.spoolWatcherMu.Unlock()
	return sw
}
// startForwarderWatcher starts the inotify watcher for /etc/valiases/.
// Failure to start is logged as a warning and otherwise ignored.
func (d *Daemon) startForwarderWatcher() {
	watcher, err := NewForwarderWatcher(d.alertCh, d.cfg.EmailProtection.KnownForwarders)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Warning: forwarder watcher not started: %v\n", ts(), err)
		return
	}
	d.forwarderWatcher = watcher
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		watcher.Run(d.stopCh)
	}()
	csmlog.Info("watching log (inotify forwarder watcher)", "path", "/etc/valiases/")
}
// syncEmailAVWebState pushes the email quarantine handle and the watcher
// mode to the web UI. No-op until both the web server and the quarantine
// exist.
func (d *Daemon) syncEmailAVWebState() {
	if d.webServer == nil || d.emailQuarantine == nil {
		return
	}
	d.webServer.SetEmailQuarantine(d.emailQuarantine)
	sw := d.getSpoolWatcher()
	if sw == nil {
		return
	}
	mode := "notification"
	if sw.PermissionMode() {
		mode = "permission"
	}
	d.webServer.SetEmailAVWatcherMode(mode)
}
// emailQuarantineCleanup periodically removes expired quarantined email
// messages. Runs hourly; entries older than 30 days are deleted.
func (d *Daemon) emailQuarantineCleanup() {
	defer d.wg.Done()
	const retention = 30 * 24 * time.Hour
	t := time.NewTicker(1 * time.Hour)
	defer t.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-t.C:
		}
		if d.emailQuarantine == nil {
			continue
		}
		cleaned, err := d.emailQuarantine.CleanExpired(retention)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[%s] Email quarantine cleanup error: %v\n", ts(), err)
		} else if cleaned > 0 {
			fmt.Fprintf(os.Stderr, "[%s] Email quarantine cleanup: removed %d expired messages\n", ts(), cleaned)
		}
	}
}
// startChallengeServer boots the browser-challenge HTTP server. It needs
// the firewall engine, which performs the block/allow actions behind it.
func (d *Daemon) startChallengeServer() {
	if !d.cfg.Challenge.Enabled {
		return
	}
	if d.fwEngine == nil {
		fmt.Fprintf(os.Stderr, "[%s] Challenge server requires firewall to be enabled (for escalation and allow). Skipping.\n", ts())
		return
	}
	d.ipList = challenge.NewIPList(d.cfg.StatePath)
	checks.SetChallengeIPList(d.ipList)
	srv := challenge.New(d.cfg, challenge.IPUnblocker(d.fwEngine), d.ipList)
	d.challengeServer = srv
	d.wg.Add(1)
	go func() {
		defer d.wg.Done()
		csmlog.Info("challenge server active", "port", d.cfg.Challenge.ListenPort)
		// "http: Server closed" is the normal shutdown result, not an error.
		if err := srv.Start(); err != nil && err.Error() != "http: Server closed" {
			csmlog.Error("challenge server error", "err", err)
		}
	}()
}
// challengeEscalator hard-blocks IPs whose challenge window expired
// without being solved. Runs every minute and also evicts the challenge
// server's own expired state.
func (d *Daemon) challengeEscalator() {
	defer d.wg.Done()
	t := time.NewTicker(60 * time.Second)
	defer t.Stop()
	expiry := parseBlockExpiry(d.cfg.AutoResponse.BlockExpiry)
	for {
		select {
		case <-d.stopCh:
			return
		case <-t.C:
		}
		if d.challengeServer != nil {
			d.challengeServer.CleanExpired()
		}
		for _, e := range d.ipList.ExpiredEntries() {
			if d.fwEngine == nil {
				continue
			}
			reason := fmt.Sprintf("CSM challenge-timeout: %s", truncateStr(e.Reason, 100))
			if err := d.fwEngine.BlockIP(e.IP, reason, expiry); err != nil {
				fmt.Fprintf(os.Stderr, "[%s] challenge-escalate: error blocking %s: %v\n", ts(), e.IP, err)
				continue
			}
			fmt.Fprintf(os.Stderr, "[%s] CHALLENGE-ESCALATE: %s timed out, hard-blocked\n", ts(), e.IP)
		}
	}
}
// parseBlockExpiry converts the configured auto-response block expiry
// string (time.ParseDuration syntax, e.g. "48h") into a Duration.
// Empty, unparsable, or non-positive values fall back to 24 hours.
//
// Fix: previously a negative duration like "-1h" was accepted and passed
// straight to BlockIP as the firewall block expiry; non-positive values
// now take the 24h fallback as well.
func parseBlockExpiry(s string) time.Duration {
	const fallback = 24 * time.Hour
	if s == "" {
		return fallback
	}
	d, err := time.ParseDuration(s)
	if err != nil || d <= 0 {
		return fallback
	}
	return d
}
// truncateStr shortens s to at most max bytes without splitting a
// multi-byte UTF-8 rune at the cut point.
//
// Fixes: (1) the old byte slice s[:max] could cut a rune in half,
// producing invalid UTF-8 that ends up in firewall reason strings;
// the cut now backs off over UTF-8 continuation bytes (0b10xxxxxx).
// (2) a negative max previously panicked on s[:max]; it is clamped to 0.
func truncateStr(s string, max int) string {
	if max < 0 {
		max = 0
	}
	if len(s) <= max {
		return s
	}
	// Back off while s[max] is a continuation byte so the cut lands on a
	// rune boundary.
	for max > 0 && s[max]&0xC0 == 0x80 {
		max--
	}
	return s[:max]
}
// initGeoIP opens GeoIP databases from <state>/geoip, if present, and
// publishes the handle to the log-watcher handlers and the web UI.
func (d *Daemon) initGeoIP() {
	db := geoip.Open(filepath.Join(d.cfg.StatePath, "geoip"))
	if db == nil {
		return
	}
	d.geoipDB = db
	// Make the DB available to log watcher handlers for country filtering.
	setGeoIPDB(db)
	if d.webServer != nil {
		d.webServer.SetGeoIPDB(db)
	}
}
// publishGeoIP reloads existing GeoIP databases, or opens a fresh DB when
// the databases were downloaded for the first time after startup.
// Mutex-protected: safe to call from the geoipUpdater goroutine and the
// SIGHUP handler concurrently.
func (d *Daemon) publishGeoIP() {
	d.geoipMu.Lock()
	defer d.geoipMu.Unlock()
	// Normal path: a DB handle already exists — just reload its files.
	if d.geoipDB != nil {
		if err := d.geoipDB.Reload(); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] GeoIP reload error: %v\n", ts(), err)
			return
		}
		fmt.Fprintf(os.Stderr, "[%s] GeoIP databases reloaded\n", ts())
		return
	}
	// First-time path: nothing was present at startup; try the freshly
	// downloaded files.
	db := geoip.OpenFresh(filepath.Join(d.cfg.StatePath, "geoip"))
	if db == nil {
		return
	}
	d.geoipDB = db
	setGeoIPDB(db)
	if d.webServer != nil {
		d.webServer.SetGeoIPDB(db)
	}
	fmt.Fprintf(os.Stderr, "[%s] GeoIP databases loaded for the first time\n", ts())
}
// geoipUpdater periodically downloads updated GeoLite2 databases.
// Requires account credentials, honors an explicit auto_update=false, and
// accepts a configured interval of at least one hour (default 24h).
func (d *Daemon) geoipUpdater() {
	defer d.wg.Done()
	// Without credentials there is nothing to fetch.
	if d.cfg.GeoIP.AccountID == "" || d.cfg.GeoIP.LicenseKey == "" {
		return
	}
	// auto_update explicitly disabled in config.
	if au := d.cfg.GeoIP.AutoUpdate; au != nil && !*au {
		return
	}
	interval := 24 * time.Hour
	if s := d.cfg.GeoIP.UpdateInterval; s != "" {
		if parsed, err := time.ParseDuration(s); err == nil && parsed >= time.Hour {
			interval = parsed
		}
	}
	// Give the daemon 5 minutes to stabilize before the first attempt.
	select {
	case <-d.stopCh:
		return
	case <-time.After(5 * time.Minute):
	}
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		d.doGeoIPUpdate()
		select {
		case <-d.stopCh:
			return
		case <-t.C:
		}
	}
}
// doGeoIPUpdate runs one GeoLite2 download cycle and republishes the DB
// if any edition actually changed.
func (d *Daemon) doGeoIPUpdate() {
	results := geoip.Update(
		filepath.Join(d.cfg.StatePath, "geoip"),
		d.cfg.GeoIP.AccountID,
		d.cfg.GeoIP.LicenseKey,
		d.cfg.GeoIP.Editions,
	)
	if results == nil {
		return
	}
	updated := false
	for _, r := range results {
		switch r.Status {
		case "updated":
			fmt.Fprintf(os.Stderr, "[%s] GeoIP auto-update: %s updated\n", ts(), r.Edition)
			updated = true
		case "error":
			fmt.Fprintf(os.Stderr, "[%s] GeoIP auto-update: %s error: %v\n", ts(), r.Edition, r.Err)
		case "up_to_date":
			// Nothing to report.
		}
	}
	if updated {
		d.publishGeoIP()
	}
}
// startFirewall initializes and applies the firewall engine, then starts
// the DynDNS resolver and the Cloudflare whitelist refresh loops when
// configured.
func (d *Daemon) startFirewall() {
	if !d.cfg.Firewall.Enabled {
		return
	}
	// Merge top-level infra IPs into firewall's list. Top-level controls
	// alert suppression (tight: only admin IPs), firewall may include
	// additional CIDRs (e.g. server's own range) that need port access
	// but should still be tracked for security alerts.
	d.cfg.Firewall.InfraIPs = mergeInfraIPs(d.cfg.InfraIPs, d.cfg.Firewall.InfraIPs)
	engine, err := firewall.NewEngine(d.cfg.Firewall, d.cfg.StatePath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Firewall engine init error: %v\n", ts(), err)
		return
	}
	if err := engine.Apply(); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Firewall apply error: %v\n", ts(), err)
		return
	}
	d.fwEngine = engine
	// Let the checks package auto-block through this engine.
	checks.SetIPBlocker(engine)
	fwState, _ := firewall.LoadState(d.cfg.StatePath)
	csmlog.Info("firewall active",
		"blocked_ips", len(fwState.Blocked),
		"allowed_ips", len(fwState.Allowed),
	)
	// Re-resolve configured Dynamic DNS hosts in the background.
	if n := len(d.cfg.Firewall.DynDNSHosts); n > 0 {
		resolver := firewall.NewDynDNSResolver(d.cfg.Firewall.DynDNSHosts, engine)
		d.wg.Add(1)
		go func() {
			defer d.wg.Done()
			resolver.Run(d.stopCh)
		}()
		csmlog.Info("DynDNS resolver active", "hosts", n)
	}
	// Keep the Cloudflare edge-IP whitelist fresh.
	if d.cfg.Cloudflare.Enabled {
		d.wg.Add(1)
		go d.cloudflareRefreshLoop()
		csmlog.Info("cloudflare IP whitelist enabled", "refresh_hours", d.cfg.Cloudflare.RefreshHours)
	}
}
// cloudflareRefreshLoop fetches Cloudflare IPs immediately on startup and
// then refreshes the firewall sets every RefreshHours.
//
// Fix: time.NewTicker panics on a non-positive interval, so a zero or
// missing cloudflare.refresh_hours config value would crash the daemon;
// such values now fall back to a daily refresh.
func (d *Daemon) cloudflareRefreshLoop() {
	defer d.wg.Done()
	interval := time.Duration(d.cfg.Cloudflare.RefreshHours) * time.Hour
	if interval <= 0 {
		interval = 24 * time.Hour
	}
	// Fetch immediately on startup.
	d.refreshCloudflareIPs()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-ticker.C:
			d.refreshCloudflareIPs()
		}
	}
}
// refreshCloudflareIPs fetches the current Cloudflare edge ranges, pushes
// them into the firewall sets, persists them, and informs the checks
// package so auto-block/challenge routing skips CF edge IPs.
func (d *Daemon) refreshCloudflareIPs() {
	ipv4, ipv6, err := firewall.FetchCloudflareIPs()
	if err != nil {
		csmlog.Error("cloudflare IP fetch error", "err", err)
		return
	}
	if d.fwEngine != nil {
		if err := d.fwEngine.UpdateCloudflareSet(ipv4, ipv6); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] Cloudflare set update error: %v\n", ts(), err)
			return
		}
	}
	firewall.SaveCFState(d.cfg.StatePath, ipv4, ipv6, time.Now())
	// Blocking a CF edge IP would block thousands of legitimate users, so
	// AutoBlockIPs/ChallengeRouteIPs must be able to skip these ranges.
	combined := make([]string, 0, len(ipv4)+len(ipv6))
	combined = append(combined, ipv4...)
	combined = append(combined, ipv6...)
	checks.SetCloudflareNets(combined)
}
// signatureUpdater periodically downloads new rules and reloads scanners.
// Two independent feeds share one ticker: the YAML rule feed (enabled by a
// non-empty UpdateURL) and the YARA Forge feed (enabled in config and only
// when YARA support is compiled in). The ticker fires at the shorter of
// the two intervals; per-feed timestamps gate which feed actually runs.
func (d *Daemon) signatureUpdater() {
	defer d.wg.Done()
	yamlEnabled := d.cfg.Signatures.UpdateURL != ""
	forgeEnabled := d.cfg.Signatures.YaraForge.Enabled && yara.Available()
	if !yamlEnabled && !forgeEnabled {
		return
	}
	// Let the daemon stabilize before the first network fetch.
	select {
	case <-d.stopCh:
		return
	case <-time.After(5 * time.Minute):
	}
	// YAML feed interval: default daily; config may override, floored at 1h.
	yamlInterval := 24 * time.Hour
	if d.cfg.Signatures.UpdateInterval != "" {
		if parsed, err := time.ParseDuration(d.cfg.Signatures.UpdateInterval); err == nil && parsed >= time.Hour {
			yamlInterval = parsed
		}
	}
	// Forge feed interval: default weekly (168h); same 1h floor.
	forgeInterval := 168 * time.Hour
	if d.cfg.Signatures.YaraForge.UpdateInterval != "" {
		if parsed, err := time.ParseDuration(d.cfg.Signatures.YaraForge.UpdateInterval); err == nil && parsed >= time.Hour {
			forgeInterval = parsed
		}
	}
	// Tick at the faster of the enabled feeds' intervals; when the YAML
	// feed is disabled, the Forge interval alone drives the ticker.
	tickInterval := yamlInterval
	if forgeEnabled && forgeInterval < tickInterval {
		tickInterval = forgeInterval
	}
	if !yamlEnabled {
		tickInterval = forgeInterval
	}
	var lastYAML, lastForge time.Time
	ticker := time.NewTicker(tickInterval)
	defer ticker.Stop()
	for {
		// Zero-valued last timestamps make both enabled feeds update
		// immediately on the first pass after the 5-minute warm-up.
		now := time.Now()
		if yamlEnabled && now.Sub(lastYAML) >= yamlInterval {
			d.doSignatureUpdate()
			lastYAML = now
		}
		if forgeEnabled && now.Sub(lastForge) >= forgeInterval {
			d.doForgeUpdate()
			lastForge = now
		}
		select {
		case <-d.stopCh:
			return
		case <-ticker.C:
		}
	}
}
// doSignatureUpdate runs one YAML rule feed download and, on success,
// hot-reloads all scanners.
func (d *Daemon) doSignatureUpdate() {
	count, err := signatures.Update(d.cfg.Signatures.RulesDir, d.cfg.Signatures.UpdateURL, d.cfg.Signatures.SigningKey)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Signature auto-update failed: %v\n", ts(), err)
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] Signature auto-update: %d rules downloaded\n", ts(), count)
	d.reloadSignatures()
}
// doForgeUpdate downloads the YARA Forge rule bundle for the configured
// tier, reloads the YARA scanner, and rolls the bundle back if the reload
// shrinks the total rule count (a sign the bundle conflicts with existing
// .yar files). The installed Forge version is persisted only after a
// clean, non-shrinking reload so failed updates retry on the next cycle.
func (d *Daemon) doForgeUpdate() {
	yaraScanner := yara.Global()
	if yaraScanner == nil {
		// No YARA scanner active (build without yara tag or no rules dir).
		// Skip Forge update - rules can't be loaded anyway.
		return
	}
	// Last installed version is tracked per tier; an empty version string
	// is treated by ForgeUpdate as "fetch".
	db := store.Global()
	currentVersion := ""
	if db != nil {
		currentVersion = db.GetMetaString("forge_version_" + d.cfg.Signatures.YaraForge.Tier)
	}
	newVersion, count, err := signatures.ForgeUpdate(
		d.cfg.Signatures.RulesDir,
		d.cfg.Signatures.YaraForge.Tier,
		currentVersion,
		d.cfg.Signatures.SigningKey,
		d.cfg.Signatures.DisabledRules,
	)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] YARA Forge update failed: %v\n", ts(), err)
		return
	}
	// count == 0: nothing new was downloaded, nothing to reload.
	if count == 0 {
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] YARA Forge update: %d rules (version %s)\n", ts(), count, newVersion)
	// Record rule count before reload to detect conflicts with existing .yar files.
	prevCount := yaraScanner.RuleCount()
	if err := yaraScanner.Reload(); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] YARA rule reload after Forge update error: %v\n", ts(), err)
		return // don't store version - retry next cycle
	}
	newCount := yaraScanner.RuleCount()
	// If rule count dropped, the Forge file likely conflicts with existing rules.
	// Roll back: remove the Forge file, reload again, don't store version.
	if newCount < prevCount {
		forgeFile := filepath.Join(d.cfg.Signatures.RulesDir, fmt.Sprintf("yara-forge-%s.yar", d.cfg.Signatures.YaraForge.Tier))
		fmt.Fprintf(os.Stderr, "[%s] YARA Forge rollback: rule count dropped %d -> %d (conflict with existing rules), removing %s\n",
			ts(), prevCount, newCount, forgeFile)
		_ = os.Remove(forgeFile)
		_ = yaraScanner.Reload()
		return // don't store version
	}
	fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YARA rules after Forge update\n", ts(), newCount)
	// Persist the version only now, after a verified clean reload.
	if db != nil {
		_ = db.SetMetaString("forge_version_"+d.cfg.Signatures.YaraForge.Tier, newVersion)
	}
}
// reloadSignatures hot-reloads the YAML and YARA scanners after a rule
// update and refreshes the web UI's signature count on a successful YAML
// reload.
func (d *Daemon) reloadSignatures() {
	if sc := signatures.Global(); sc != nil {
		if err := sc.Reload(); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] YAML rule reload error: %v\n", ts(), err)
		} else {
			fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YAML rules (version %d)\n", ts(), sc.RuleCount(), sc.Version())
			if d.webServer != nil {
				d.webServer.SetSigCount(sc.RuleCount())
			}
		}
	}
	if ys := yara.Global(); ys != nil {
		if err := ys.Reload(); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] YARA rule reload error: %v\n", ts(), err)
		} else {
			fmt.Fprintf(os.Stderr, "[%s] Reloaded %d YARA rule file(s)\n", ts(), ys.RuleCount())
		}
	}
}
// deployConfigs writes embedded config files to their system locations on startup.
// Ensures WHM plugin CGI and ModSec rules stay current after binary upgrades.
//
// Every file written here is a system integration point consumed by a
// different process (WHM, Apache, nginx); the permissions intentionally
// allow the right external reader. Gosec G301/G306 warnings on this
// function are suppressed inline with the specific integration target.
//
// Fix: a failed write of the WHM CGI used to be silently dropped (only
// success was logged); it is now logged like the sibling AppConfig write
// so a broken plugin deployment is diagnosable.
func deployConfigs() {
	// WHM plugin CGI - embedded in binary
	if _, err := os.Stat("/usr/local/cpanel"); err == nil {
		dst := "/usr/local/cpanel/whostmgr/docroot/cgi/addon_csm.cgi"
		// #nosec G306 -- WHM CGI endpoint; 0755 is required so cPanel's
		// webserver can execute it.
		if err := os.WriteFile(dst, embeddedWHMCGI, 0755); err != nil {
			csmlog.Error("WHM plugin CGI deploy failed", "path", dst, "err", err)
		} else {
			csmlog.Info("WHM plugin CGI deployed", "path", dst)
		}
		// Write the AppConfig file, then register it with WHM.
		// Writing the file alone does NOT make the plugin appear in the
		// sidebar — WHM's AppConfig system maintains a registration
		// database that is updated via `register_appconfig`. Skipping
		// that step was a long-standing bug; the plugin file existed on
		// disk but never showed up in the menu.
		// #nosec G301 -- cPanel standard /var/cpanel/apps directory.
		_ = os.MkdirAll("/var/cpanel/apps", 0755)
		confPath := "/var/cpanel/apps/csm.conf"
		// #nosec G306 -- WHM AppConfig; read by cPanel tooling, 0644 is convention.
		if err := os.WriteFile(confPath, embeddedWHMConf, 0644); err != nil {
			csmlog.Error("WHM AppConfig write failed", "path", confPath, "err", err)
		} else if err := registerWHMPlugin(confPath); err != nil {
			// Non-fatal: the conf is on disk, register_appconfig failure is
			// logged so operators can fix it manually. Most common failure
			// is register_appconfig not being in PATH (old cPanel versions).
			csmlog.Warn("WHM plugin registration failed", "err", err)
		} else {
			csmlog.Info("WHM plugin registered with AppConfig")
		}
	}
	// Deploy script (self-updating)
	// #nosec G306 -- Shell script executed by operators and by the CSM
	// upgrade path; needs to be executable, not private.
	_ = os.WriteFile("/opt/csm/deploy.sh", embeddedDeployScript, 0755)
	// ModSecurity virtual patches — install into the first Apache config
	// directory that exists, then stop.
	for _, dst := range []string{
		"/etc/apache2/conf.d/modsec/modsec2.user.conf",
		"/usr/local/apache/conf/modsec2.user.conf",
	} {
		if _, err := os.Stat(filepath.Dir(dst)); err == nil {
			// #nosec G306 -- Apache reads this ModSecurity config; webserver
			// runs as a different user.
			_ = os.WriteFile(dst, embeddedModSec, 0644)
			overridesFile := filepath.Join(filepath.Dir(dst), "modsec2.csm-overrides.conf")
			modsec.EnsureOverridesInclude(dst, overridesFile)
			break
		}
	}
}
// registerWHMPlugin runs cPanel's register_appconfig helper to add the CSM
// plugin to the WHM sidebar. WHM maintains a cached registration database
// separate from the /var/cpanel/apps/ conf files; without running this
// helper, the plugin file exists on disk but the menu never shows it.
//
// Idempotent: re-running against an already-registered plugin just updates
// the entry. Non-fatal on failure — deployment continues and the operator
// can rerun manually. The helper is given 30 seconds to finish.
func registerWHMPlugin(confPath string) error {
	const bin = "/usr/local/cpanel/bin/register_appconfig"
	if _, err := os.Stat(bin); err != nil {
		return fmt.Errorf("register_appconfig not found at %s: %w", bin, err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	// #nosec G204 -- bin is the fixed cPanel path validated by os.Stat above;
	// confPath was just written by deployConfigs from an embedded constant.
	out, err := exec.CommandContext(ctx, bin, confPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("%s %s: %w: %s", bin, confPath, err, strings.TrimSpace(string(out)))
	}
	return nil
}
// watchdogNotifier sends systemd watchdog keepalives on its own ticker.
// It pings at half the WatchdogSec interval (with a 10s floor) so there's
// always margin, and runs independently of scan goroutines so it never
// blocks. Exits immediately when systemd didn't configure a watchdog.
func (d *Daemon) watchdogNotifier() {
	defer d.wg.Done()
	usecStr := os.Getenv("WATCHDOG_USEC")
	addr := os.Getenv("NOTIFY_SOCKET")
	if usecStr == "" || addr == "" {
		return // watchdog not configured
	}
	usec, err := strconv.ParseInt(usecStr, 10, 64)
	if err != nil || usec <= 0 {
		return
	}
	// Notify at half the watchdog interval for safety margin.
	interval := time.Duration(usec) * time.Microsecond / 2
	if interval < 10*time.Second {
		interval = 10 * time.Second
	}
	csmlog.Info("systemd watchdog active", "interval", interval.String())
	t := time.NewTicker(interval)
	defer t.Stop()
	for {
		select {
		case <-d.stopCh:
			return
		case <-t.C:
			sdNotify(addr, "WATCHDOG=1")
		}
	}
}
// sdNotify sends one datagram to the systemd notify socket at addr.
// Errors are deliberately ignored: keepalives are best-effort.
func sdNotify(addr, msg string) {
	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
	if err != nil {
		return
	}
	defer func() { _ = syscall.Close(fd) }()
	dest := &syscall.SockaddrUnix{Name: addr}
	_ = syscall.Sendmsg(fd, []byte(msg), nil, dest, 0)
}
// ts returns the current local time formatted for log prefixes
// (YYYY-MM-DD HH:MM:SS).
func ts() string {
	const layout = "2006-01-02 15:04:05"
	return time.Now().Format(layout)
}
// orUnknown substitutes "unknown" for an empty string.
func orUnknown(v string) string {
	if v != "" {
		return v
	}
	return "unknown"
}
// orNone substitutes "none" for an empty string.
func orNone(v string) string {
	if v != "" {
		return v
	}
	return "none"
}
//go:build linux
package daemon
import (
"fmt"
"os"
"path/filepath"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/wpcheck"
"github.com/pidginhost/csm/internal/yara"
)
// fanotify constants (not all in Go stdlib)
// Values mirror <linux/fanotify.h>. Note that FAN_MARK_* (fanotify_mark
// flags) and FAN_CLOEXEC/FAN_NONBLOCK/FAN_CLASS_NOTIF (fanotify_init
// flags) live in separate flag namespaces, so the numeric overlap between
// FAN_MARK_ADD and FAN_CLOEXEC (both 0x1) is intentional.
const (
	FAN_MARK_ADD    = 0x00000001 // add the given events to a mark's mask
	FAN_MARK_MOUNT  = 0x00000010 // mark the whole mount, not one inode
	FAN_CLOSE_WRITE = 0x00000008 // event: a writable fd on the file was closed
	FAN_CREATE      = 0x00000100 // event: file/dir created (newer kernels; marking falls back without it)
	FAN_CLASS_NOTIF = 0x00000000 // notification-only listener class
	FAN_CLOEXEC     = 0x00000001 // close-on-exec on the fanotify fd
	FAN_NONBLOCK    = 0x00000002 // non-blocking reads on the fanotify fd
)
// fanotifyEventMetadata is the header for each fanotify event.
// Field order and widths are intended to match struct
// fanotify_event_metadata in <linux/fanotify.h>; metadataSize below is
// derived from this layout via unsafe.Sizeof.
type fanotifyEventMetadata struct {
	EventLen    uint32 // total length of this event record
	Vers        uint8  // metadata version reported by the kernel
	Reserved    uint8
	MetadataLen uint16 // length of this fixed header
	Mask        uint64 // event mask (FAN_CLOSE_WRITE, FAN_CREATE, ...)
	Fd          int32  // open fd for the file the event refers to
	Pid         int32  // pid of the process that triggered the event
}
// metadataSize is the byte size of the fanotify event header, used when
// parsing records read from the fanotify fd.
const metadataSize = int(unsafe.Sizeof(fanotifyEventMetadata{}))

// M1 - webshells map at package level (avoid per-call allocation)
// Filenames commonly used by dropped PHP webshells.
var knownWebshells = map[string]bool{
	"h4x0r.php": true, "c99.php": true, "r57.php": true,
	"wso.php": true, "alfa.php": true, "b374k.php": true,
	"shell.php": true, "cmd.php": true, "backdoor.php": true,
	"webshell.php": true,
}

// M3 - plugin stat cache with TTL
// pluginCacheEntry caches whether a plugin directory existed, together
// with when that answer was recorded (compared against pluginCacheTTL).
type pluginCacheEntry struct {
	exists bool      // directory existed at time ts
	ts     time.Time // when the answer was cached
}

var pluginStatCache sync.Map // key: pluginDir string → value: pluginCacheEntry

// pluginCacheTTL bounds how long a cached existence answer is trusted.
const pluginCacheTTL = 5 * time.Minute

// cronSpoolWatchDir is the directory the realtime crontab watcher marks.
// Declared as a var (not const) so tests can redirect it under t.TempDir()
// without touching the real /var/spool/cron.
var cronSpoolWatchDir = "/var/spool/cron"

// alertDedupTTL is the cooldown period for duplicate alerts on the same
// check+filepath combination. Prevents alert storms from rapid writes.
const alertDedupTTL = 30 * time.Second
// FileMonitor watches mount points for file creation/modification using
// fanotify. Events arrive on fd, are queued through analyzerCh as
// fileEvent values, and resulting findings are emitted on alertCh.
type FileMonitor struct {
	fd         int                  // fanotify fd returned by FanotifyInit
	cfg        *config.Config       // daemon configuration
	alertCh    chan<- alert.Finding // daemon's alert pipeline
	analyzerCh chan fileEvent       // queue feeding the analyzer
	// M7 - separate counters for dropped events and alerts
	droppedEvents int64 // events dropped because analyzerCh was full (atomic)
	droppedAlerts int64 // findings dropped because alertCh was full (atomic)
	// C4 - pipe for epoll stop signaling
	pipeFds    [2]int // [0]=read, [1]=write
	pipeClosed int32  // atomic flag: 1 = pipe fds closed by drainAndClose
	// C2 - sync.Once for safe Stop
	stopOnce  sync.Once
	drainOnce sync.Once
	stopCh    chan struct{} // internal stop channel
	wg        sync.WaitGroup
	// Per-path alert deduplication: "check:filepath" → last alert time
	// (entries older than alertDedupTTL allow a fresh alert).
	alertDedup sync.Map
	// WordPress core checksum verifier - skips detection on unmodified WP core files
	wpCache *wpcheck.Cache
	// Drop-recovery reconcile: directories that had fanotify events dropped
	// because the analyzer queue was full. The overflow reporter walks this
	// set once a minute and scans any interesting file modified within
	// reconcileWindow so bulk filesystem operations (unzip, backup restore)
	// do not blind detection to actual threats landing in the storm.
	reconcileMu   sync.Mutex
	reconcileDirs map[string]time.Time
}
const (
reconcileDirCap = 64
reconcileWindow = 70 * time.Second
)
// fileEvent is one unit of work handed from the fanotify event loop to the
// analyzer workers. Whoever receives the event owns fd and must close it.
type fileEvent struct {
	path string // path resolved via /proc/self/fd at event time
	fd   int    // fanotify-provided fd; content reads use it for TOCTOU safety
	pid  int32  // triggering process id from the event metadata
}
// NewFileMonitor creates a fanotify-based file monitor.
// Returns error if the kernel doesn't support the required features.
// On success the caller must eventually call Stop: the fanotify fd, the
// stop pipe, and all mount/directory marks are established here.
func NewFileMonitor(cfg *config.Config, alertCh chan<- alert.Finding) (*FileMonitor, error) {
	// H1 - use golang.org/x/sys/unix for fanotify_init
	fd, err := unix.FanotifyInit(FAN_CLASS_NOTIF|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
	if err != nil {
		return nil, fmt.Errorf("fanotify_init: %w (kernel may not support fanotify)", err)
	}
	// Mark mount points; M2 - track successful mounts
	mountPaths := []string{"/home", "/tmp", "/dev/shm", "/var/tmp"}
	mountOK := 0
	for _, path := range mountPaths {
		// H1 - use golang.org/x/sys/unix for fanotify_mark
		err = unix.FanotifyMark(fd, FAN_MARK_ADD|FAN_MARK_MOUNT, FAN_CLOSE_WRITE|FAN_CREATE, -1, path)
		if err != nil {
			// Try without FAN_CREATE (older kernels)
			err = unix.FanotifyMark(fd, FAN_MARK_ADD|FAN_MARK_MOUNT, FAN_CLOSE_WRITE, -1, path)
			if err != nil {
				// Per-path failure is non-fatal; the mountOK==0 check below
				// decides whether the monitor is viable at all.
				fmt.Fprintf(os.Stderr, "[%s] Warning: cannot watch %s: %v\n", ts(), path, err)
				continue
			}
		}
		mountOK++
	}
	// M2 - error on zero successful mounts
	if mountOK == 0 {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("no mount points could be watched (tried %v)", mountPaths)
	}
	// Directory-scoped watch on /var/spool/cron so any user crontab write
	// reaches analyzeFile in real time. Best-effort: cron may live under a
	// different path on non-cPanel hosts (the platform layer normalises),
	// and we'd rather lose the realtime crontab signal than fail daemon
	// startup. The polled CheckCrontabs run still covers this case via
	// the next scheduled scan. Mask matches spoolwatch.go (the proven
	// production pattern for directory-scoped marks): FAN_CLOSE_WRITE
	// alone catches both new and modified crontabs, since the close
	// after O_CREAT|O_WRONLY|... fires the close-write event. FAN_CREATE
	// is omitted because it has stricter kernel requirements with
	// directory-scoped (non-MOUNT) marks and adds no coverage here.
	if _, statErr := os.Stat(cronSpoolWatchDir); statErr == nil {
		if err := unix.FanotifyMark(fd, FAN_MARK_ADD,
			FAN_CLOSE_WRITE|FAN_EVENT_ON_CHILD, -1, cronSpoolWatchDir); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] Warning: cannot watch %s: %v\n", ts(), cronSpoolWatchDir, err)
		}
	}
	// C4 - create pipe for epoll stop signaling
	var pipeFds [2]int
	if err := unix.Pipe2(pipeFds[:], unix.O_NONBLOCK|unix.O_CLOEXEC); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("pipe2: %w", err)
	}
	fm := &FileMonitor{
		fd:            fd,
		cfg:           cfg,
		alertCh:       alertCh,
		analyzerCh:    make(chan fileEvent, 4000),
		pipeFds:       pipeFds,
		stopCh:        make(chan struct{}),
		reconcileDirs: make(map[string]time.Time),
	}
	fm.wpCache = wpcheck.NewCache(cfg.StatePath)
	return fm, nil
}
// Run starts the file monitor event loop and analyzer workers.
// Blocks until stopCh closes (or Stop is called), then drains workers and
// releases fds via drainAndClose. Uses epoll over the fanotify fd plus a
// self-pipe for prompt wakeup on stop; falls back to a sleep-based poll
// loop if any epoll setup step fails.
func (fm *FileMonitor) Run(stopCh <-chan struct{}) {
	// H7 - configurable workers: min 4, max 16, based on NumCPU
	numWorkers := runtime.NumCPU()
	if numWorkers < 4 {
		numWorkers = 4
	}
	if numWorkers > 16 {
		numWorkers = 16
	}
	for i := 0; i < numWorkers; i++ {
		fm.wg.Add(1)
		go fm.analyzerWorker()
	}
	// Start overflow reporter
	fm.wg.Add(1)
	go fm.overflowReporter()
	// C4 - create epoll instance, watch fanotify fd + pipe read end
	epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] epoll_create1 failed: %v, falling back to poll loop\n", ts(), err)
		fm.runPollFallback(stopCh)
		return
	}
	defer func() { _ = unix.Close(epfd) }()
	// Add fanotify fd to epoll
	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, fm.fd, &unix.EpollEvent{
		Events: unix.EPOLLIN,
		// #nosec G115 -- POSIX fd fits in int32 (rlimit caps fds at ~1024).
		Fd: int32(fm.fd),
	}); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] epoll_ctl(fanotify): %v\n", ts(), err)
		fm.runPollFallback(stopCh)
		return
	}
	// Add pipe read end to epoll (for stop signaling)
	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, fm.pipeFds[0], &unix.EpollEvent{
		Events: unix.EPOLLIN,
		// #nosec G115 -- POSIX fd fits in int32.
		Fd: int32(fm.pipeFds[0]),
	}); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] epoll_ctl(pipe): %v\n", ts(), err)
		fm.runPollFallback(stopCh)
		return
	}
	// Forward external stopCh to our internal mechanism
	go func() {
		select {
		case <-stopCh:
			fm.Stop()
		case <-fm.stopCh:
		}
	}()
	buf := make([]byte, 4096*24) // Large buffer for event batches
	events := make([]unix.EpollEvent, 4)
	for {
		n, err := unix.EpollWait(epfd, events, 500) // 500ms timeout
		if err != nil {
			if err == unix.EINTR {
				continue
			}
			// Check if we've been stopped
			select {
			case <-fm.stopCh:
				fm.drainAndClose()
				return
			default:
			}
			fmt.Fprintf(os.Stderr, "[%s] epoll_wait error: %v\n", ts(), err)
			time.Sleep(1 * time.Second)
			continue
		}
		// Check for stop first
		select {
		case <-fm.stopCh:
			fm.drainAndClose()
			return
		default:
		}
		for i := 0; i < n; i++ {
			// #nosec G115 -- POSIX fd fits in int32; comparing against epoll event fd.
			if events[i].Fd == int32(fm.pipeFds[0]) {
				// Stop signal received via pipe
				fm.drainAndClose()
				return
			}
			// #nosec G115 -- POSIX fd fits in int32.
			if events[i].Fd == int32(fm.fd) {
				// fanotify events ready — single read per epoll wake
				nr, readErr := unix.Read(fm.fd, buf)
				if readErr != nil {
					if readErr != unix.EAGAIN && readErr != unix.EINTR {
						fmt.Fprintf(os.Stderr, "[%s] fanotify read error: %v\n", ts(), readErr)
					}
				} else if nr >= metadataSize {
					fm.processEvents(buf[:nr])
				}
			}
		}
	}
}
// runPollFallback is used when epoll setup fails; falls back to sleep-based polling.
// Reads the non-blocking fanotify fd in a loop, sleeping 100ms on EAGAIN.
// Stop detection happens at the top of each iteration, so worst-case
// shutdown latency is one sleep interval.
func (fm *FileMonitor) runPollFallback(stopCh <-chan struct{}) {
	// Forward external stopCh to our internal mechanism
	go func() {
		select {
		case <-stopCh:
			fm.Stop()
		case <-fm.stopCh:
		}
	}()
	buf := make([]byte, 4096*24)
	for {
		select {
		case <-fm.stopCh:
			fm.drainAndClose()
			return
		default:
		}
		n, err := unix.Read(fm.fd, buf)
		if err != nil {
			if err == unix.EAGAIN || err == unix.EINTR {
				time.Sleep(100 * time.Millisecond)
				continue
			}
			fmt.Fprintf(os.Stderr, "[%s] fanotify read error: %v\n", ts(), err)
			time.Sleep(1 * time.Second)
			continue
		}
		if n < metadataSize {
			continue
		}
		fm.processEvents(buf[:n])
	}
}
// processEvents parses a buffer of fanotify event metadata and dispatches each event.
// The buffer may hold several back-to-back records; each is EventLen bytes.
// Records with a negative Fd are skipped (the kernel supplies no usable fd
// for them). handleEvent takes ownership of each valid fd.
// NOTE(review): if EventLen < header size the loop stops; fds in any
// remaining records would leak. The kernel should never emit that, but
// worth confirming against fanotify(7).
func (fm *FileMonitor) processEvents(buf []byte) {
	offset := 0
	for offset+metadataSize <= len(buf) {
		// #nosec G103 -- fanotify delivers a packed binary stream on the
		// fd; we must reinterpret the byte buffer as the kernel struct.
		// The metadataSize bounds check above guarantees we have enough
		// bytes for the struct.
		event := (*fanotifyEventMetadata)(unsafe.Pointer(&buf[offset]))
		if event.EventLen < uint32(metadataSize) {
			break
		}
		if event.Fd >= 0 {
			fm.handleEvent(int(event.Fd), event.Pid)
		}
		offset += int(event.EventLen)
	}
}
// drainAndClose drains the analyzerCh and waits for workers to finish.
// C1 - ensures no fd leak on shutdown. Guarded by drainOnce so the multiple
// return paths in Run/runPollFallback can all call it safely. Closing
// analyzerCh lets each analyzerWorker finish its loop and close remaining
// event fds before wg.Wait returns.
func (fm *FileMonitor) drainAndClose() {
	fm.drainOnce.Do(func() {
		close(fm.analyzerCh)
		fm.wg.Wait()
		// Mark pipe as closed before actually closing, so Stop() won't
		// write to an already-closed fd (H2 fix).
		atomic.StoreInt32(&fm.pipeClosed, 1)
		_ = unix.Close(fm.pipeFds[0])
		_ = unix.Close(fm.pipeFds[1])
	})
}
// Stop signals the monitor to shut down.
// C2 - sync.Once ensures safe concurrent calls; does not close analyzerCh directly.
// Shutdown order matters: close stopCh (observed by the event loops), wake
// epoll via the self-pipe, then close the fanotify fd to unblock any Read.
func (fm *FileMonitor) Stop() {
	fm.stopOnce.Do(func() {
		close(fm.stopCh)
		// Wake epoll so Run() exits and calls drainAndClose.
		// Only write if pipe hasn't been closed by drainAndClose yet.
		if atomic.LoadInt32(&fm.pipeClosed) == 0 {
			_, _ = unix.Write(fm.pipeFds[1], []byte{0})
		}
		// Close fanotify fd - causes any pending Read/EpollWait to return
		_ = unix.Close(fm.fd)
	})
}
// handleEvent resolves the path behind a fanotify event fd, applies the
// fast path-only filter, and hands interesting files to the analyzer pool.
// The fd is closed here on every rejection path; ownership transfers to an
// analyzer worker only when the channel send succeeds.
func (fm *FileMonitor) handleEvent(fd int, pid int32) {
	// Resolve the event fd back to a path via /proc/self/fd/N.
	path, err := os.Readlink(fmt.Sprintf("/proc/self/fd/%d", fd))
	if err != nil {
		_ = unix.Close(fd)
		return
	}
	// M5 - skip directory events; fast filter rejects uninteresting paths.
	if strings.HasSuffix(path, "/") || !fm.isInteresting(path) {
		_ = unix.Close(fd)
		return
	}
	// Non-blocking send: backpressure instead of stalling the event loop.
	select {
	case fm.analyzerCh <- fileEvent{path: path, fd: fd, pid: pid}:
		return
	default:
	}
	// Queue full - drop event, count, and record the parent dir so the
	// reconcile pass in overflowReporter can rescan it. Without this
	// every file in a bulk burst past buffer capacity is invisible to
	// detection forever.
	dropped := atomic.AddInt64(&fm.droppedEvents, 1)
	if dropped%100 == 0 {
		fmt.Fprintf(os.Stderr, "[%s] fanotify: %d events dropped (analyzer queue full)\n", ts(), dropped)
	}
	fm.recordDroppedDir(path)
	_ = unix.Close(fd)
}
// recordDroppedDir registers a directory whose file had its fanotify event
// dropped, capped at reconcileDirCap entries (oldest evicted).
func (fm *FileMonitor) recordDroppedDir(path string) {
	parent := filepath.Dir(path)
	fm.reconcileMu.Lock()
	defer fm.reconcileMu.Unlock()
	// Lazily initialise in case the map was never allocated.
	if fm.reconcileDirs == nil {
		fm.reconcileDirs = make(map[string]time.Time)
	}
	fm.reconcileDirs[parent] = time.Now()
	if len(fm.reconcileDirs) <= reconcileDirCap {
		return
	}
	// Over cap: evict the entry with the earliest timestamp.
	evictKey := ""
	var evictAt time.Time
	haveCandidate := false
	for dir, seen := range fm.reconcileDirs {
		if !haveCandidate || seen.Before(evictAt) {
			evictKey, evictAt, haveCandidate = dir, seen, true
		}
	}
	delete(fm.reconcileDirs, evictKey)
}
// reconcileDrops walks every directory with a recent dropped event and
// analyses any interesting file modified within reconcileWindow. Converts
// lost events into delayed events rather than invisible ones. Called from
// overflowReporter after the minute-granularity overflow alert.
// The tracked set is swapped out atomically under reconcileMu so new drops
// recorded during the walk land in a fresh map for the next pass.
func (fm *FileMonitor) reconcileDrops() {
	fm.reconcileMu.Lock()
	dirs := fm.reconcileDirs
	fm.reconcileDirs = make(map[string]time.Time)
	fm.reconcileMu.Unlock()
	if len(dirs) == 0 {
		return
	}
	cutoff := time.Now().Add(-reconcileWindow)
	for dir := range dirs {
		entries, err := os.ReadDir(dir)
		if err != nil {
			// Directory may have been removed since the drop; skip silently.
			continue
		}
		for _, e := range entries {
			if e.IsDir() {
				continue
			}
			info, err := e.Info()
			if err != nil {
				continue
			}
			if info.ModTime().Before(cutoff) {
				continue
			}
			fullPath := filepath.Join(dir, e.Name())
			if !fm.isInteresting(fullPath) {
				continue
			}
			// Open+analyse+close wrapped so a panic in analyzeFile does not
			// leak the fd; otherwise `defer f.Close()` in a loop body would
			// defer until reconcileDrops returns, accumulating fds across
			// every entry in every tracked dir.
			func() {
				// #nosec G304 -- fullPath is a directory entry under a dir
				// the kernel already notified us about; reconcile owns
				// reopening because the original fanotify fd is gone.
				f, err := os.Open(fullPath)
				if err != nil {
					return
				}
				defer func() { _ = f.Close() }()
				// #nosec G115 -- POSIX fd fits in int32 (rlimit caps fds at ~1024).
				fm.analyzeFile(fileEvent{path: fullPath, fd: int(f.Fd())})
			}()
		}
	}
}
// isInteresting is the fast filter - zero I/O, pure string matching.
// Decides from the path alone whether a file deserves content analysis.
// Every rule below independently returns true, so check order does not
// affect the result, only which rule short-circuits first.
func (fm *FileMonitor) isInteresting(path string) bool {
	lower := strings.ToLower(path)
	inHome := strings.HasPrefix(path, "/home/")
	// PHP files anywhere.
	for _, ext := range []string{".php", ".phtml", ".pht", ".php5"} {
		if strings.HasSuffix(lower, ext) {
			return true
		}
	}
	// Webshell extensions.
	if strings.HasSuffix(lower, ".haxor") || strings.HasSuffix(lower, ".cgix") {
		return true
	}
	// CGI scripts in web-accessible directories — detect Perl/Python/Bash backdoors.
	if inHome {
		for _, ext := range []string{".pl", ".cgi", ".py", ".sh", ".rb"} {
			if strings.HasSuffix(lower, ext) {
				return true
			}
		}
	}
	// .htaccess and .user.ini files.
	if strings.HasSuffix(lower, ".htaccess") || strings.HasSuffix(lower, ".user.ini") {
		return true
	}
	// HTML in /home (phishing pages) and ZIP archives in /home (kit uploads).
	if inHome && (strings.HasSuffix(lower, ".html") ||
		strings.HasSuffix(lower, ".htm") || strings.HasSuffix(lower, ".zip")) {
		return true
	}
	// Credential log files - known phishing harvest filenames.
	if credentialLogNames[filepath.Base(lower)] {
		return true
	}
	// Anything in .config directories.
	if strings.Contains(path, "/.config/") {
		return true
	}
	// User crontabs surfaced via the directory-scoped fanotify mark in
	// NewFileMonitor. Each write to /var/spool/cron/<user> dispatches to
	// checkCrontab in real time.
	if strings.HasPrefix(path, cronSpoolWatchDir+"/") {
		return true
	}
	// Executables in /tmp or /dev/shm.
	if strings.HasPrefix(path, "/tmp/") || strings.HasPrefix(path, "/dev/shm/") ||
		strings.HasPrefix(path, "/var/tmp/") {
		return true
	}
	// PHP in sensitive directories that should never contain PHP.
	if isPHPExtension(strings.ToLower(filepath.Base(path))) {
		for _, sensitive := range []string{"/.ssh/", "/.cpanel/", "/mail/", "/.gnupg/", "/.cagefs/"} {
			if strings.Contains(path, sensitive) {
				return true
			}
		}
	}
	return false
}
// credentialLogNames are filenames commonly used by phishing kits to store
// harvested credentials. Checked in isInteresting() for real-time detection.
// Keys are lowercased basenames; analyzeFile routes matches to checkCredentialLog.
var credentialLogNames = map[string]bool{
	"results.txt": true, "result.txt": true, "log.txt": true,
	"logs.txt": true, "emails.txt": true, "data.txt": true,
	"passwords.txt": true, "creds.txt": true, "credentials.txt": true,
	"victims.txt": true, "output.txt": true, "harvested.txt": true,
	"results.log": true, "emails.log": true, "data.log": true,
	"results.csv": true, "emails.csv": true, "data.csv": true,
}
// analyzerWorker processes file events from the bounded channel.
// C1 - on channel close, drains remaining events and closes their fds.
// The worker owns each event's fd and closes it after analysis completes.
func (fm *FileMonitor) analyzerWorker() {
	defer fm.wg.Done()
	for event := range fm.analyzerCh {
		fm.analyzeFile(event)
		_ = unix.Close(event.fd)
	}
}
// readFromFd reads up to maxBytes from the fanotify event fd using pread
// at offset 0. C3 - avoids TOCTOU by reading from the original event fd,
// not the path. Uses unix.Pread directly to avoid os.NewFile's GC finalizer
// which would close the fd out-of-band, racing with the worker's explicit
// close. Returns nil when nothing could be read (empty file or read error).
func readFromFd(fd int, maxBytes int) []byte {
	buf := make([]byte, maxBytes)
	// The error is deliberately discarded: a short read with n > 0 still
	// yields scannable content, and n <= 0 already covers both EOF at
	// offset 0 and outright failure. (The previous `err != nil && n == 0`
	// clause was dead code - n == 0 is subsumed by n <= 0.)
	n, _ := unix.Pread(fd, buf, 0)
	if n <= 0 {
		return nil
	}
	return buf[:n]
}
// readTailFromFd reads the last maxBytes of a file via its fd using pread.
// Returns nil if the file is smaller than maxBytes (the head scan via
// readFromFd already covers it in that case), or on stat/read failure.
func readTailFromFd(fd int, maxBytes int) []byte {
	var stat unix.Stat_t
	if err := unix.Fstat(fd, &stat); err != nil {
		return nil
	}
	size := stat.Size
	if size <= int64(maxBytes) {
		return nil // head read already covers the entire file
	}
	offset := size - int64(maxBytes)
	buf := make([]byte, maxBytes)
	// As in readFromFd: n <= 0 is the only failure condition worth acting
	// on; the former `err != nil && n == 0` clause was redundant.
	n, _ := unix.Pread(fd, buf, offset)
	if n <= 0 {
		return nil
	}
	return buf[:n]
}
// resolveProcessInfo reads /proc/<pid>/comm and /proc/<pid>/status
// to build a "pid=N cmd=name uid=N" string for alert enrichment.
// Returns empty string on any error (process may have exited).
func resolveProcessInfo(pid int32) string {
	if pid <= 0 {
		return ""
	}
	procDir := fmt.Sprintf("/proc/%d", pid)
	// Read process name
	// #nosec G304 -- /proc/<pid>/comm; kernel pseudo-FS, pid is int32 from fanotify event.
	commBytes, err := os.ReadFile(procDir + "/comm")
	if err != nil {
		return ""
	}
	out := fmt.Sprintf("pid=%d cmd=%s", pid, strings.TrimSpace(string(commBytes)))
	// Best-effort UID enrichment from the "Uid:" line of /proc/<pid>/status
	// (maps to cPanel username downstream).
	// #nosec G304 -- /proc/<pid>/status; kernel pseudo-FS.
	statusBytes, err := os.ReadFile(procDir + "/status")
	if err != nil {
		return out
	}
	for _, line := range strings.Split(string(statusBytes), "\n") {
		if !strings.HasPrefix(line, "Uid:") {
			continue
		}
		if fields := strings.Fields(line); len(fields) >= 2 {
			out += fmt.Sprintf(" uid=%s", fields[1])
		}
		break
	}
	return out
}
// analyzeFile is the deep-inspection stage, run on the analyzer worker
// pool. Content reads use the original fanotify event fd (C3, TOCTOU-safe).
// Dispatch order is significant: suppressions and verified WordPress files
// exit first; specific file-type checks (.htaccess, .user.ini, .config)
// take precedence over the generic /tmp executable block. The caller
// (analyzerWorker or reconcileDrops) owns event.fd and closes it.
func (fm *FileMonitor) analyzeFile(event fileEvent) {
	path := event.path
	name := filepath.Base(path)
	nameLower := strings.ToLower(name)
	// Resolve process info from PID (best-effort - process may have exited)
	procInfo := resolveProcessInfo(event.pid)
	// H2 - suppression path matching using filepath.Match
	for _, ignore := range fm.cfg.Suppressions.IgnorePaths {
		if matchSuppression(ignore, path) {
			return
		}
	}
	// Skip verified WordPress core files - checksum matches official WP.org checksums.
	// Content is read from the event fd (not path) to preserve TOCTOU safety.
	if fm.wpCache != nil && fm.wpCache.IsVerifiedCoreFile(event.fd, path) {
		return
	}
	// Verified WordPress plugin files: hash matches the plugin's official
	// wordpress.org ZIP for its declared version. Stops signature/YARA FPs
	// on stock plugin code (Wordfence, Contact Form 7, etc.). Cache miss
	// triggers a background fetch; misses fall through to rule evaluation.
	if fm.wpCache != nil && fm.wpCache.IsVerifiedPluginFile(event.fd, path) {
		return
	}
	// User crontab written under /var/spool/cron/<user>. Scan content
	// from the event fd via the shared deep matcher and emit Critical on
	// any hit. The polled CheckCrontabs run still tracks root crontab
	// hash drift, so we skip root here to avoid duplicate signal.
	if strings.HasPrefix(path, cronSpoolWatchDir+"/") {
		fm.checkCrontab(event.fd, path, procInfo)
		return
	}
	// Location-based severity escalation: PHP in dirs that should NEVER have PHP
	if isPHPExtension(nameLower) {
		for _, sensitive := range []string{"/.ssh/", "/.cpanel/", "/mail/", "/.gnupg/", "/.cagefs/"} {
			if strings.Contains(path, sensitive) {
				fm.sendAlertWithPath(alert.Critical, "php_in_sensitive_dir_realtime",
					fmt.Sprintf("PHP file in critical directory: %s", path),
					fmt.Sprintf("PHP should never exist in %s - likely webshell or backdoor", sensitive), path, procInfo)
				return
			}
		}
	}
	// Immediate CRITICAL: known webshell filenames (M1 - package-level var)
	if knownWebshells[nameLower] {
		fm.sendAlertWithPath(alert.Critical, "webshell_realtime",
			fmt.Sprintf("Webshell file created: %s", path), "", path, procInfo)
		return
	}
	// Webshell extensions
	if strings.HasSuffix(nameLower, ".haxor") || strings.HasSuffix(nameLower, ".cgix") {
		fm.sendAlertWithPath(alert.Critical, "webshell_realtime",
			fmt.Sprintf("Suspicious CGI file created: %s", path), "", path, procInfo)
		return
	}
	// .htaccess modification - check for injection (C3 - read from fd).
	// Checked before the /tmp early-return so a malicious .htaccess anywhere
	// (including /tmp) is still analyzed for dangerous directives.
	if nameLower == ".htaccess" {
		fm.checkHtaccess(event.fd, path, procInfo)
		return
	}
	// .user.ini modification - check for dangerous PHP settings (C3 - read from fd).
	// Also checked before /tmp so malicious .user.ini is detected anywhere.
	if nameLower == ".user.ini" {
		fm.checkUserINI(event.fd, path, procInfo)
		return
	}
	// Executables in .config - checked before the /tmp block so a miner
	// dropped at /tmp/.config/* is flagged as executable_in_config_realtime
	// (more specific) rather than executable_in_tmp_realtime.
	// NOTE(review): this stats the path rather than the event fd, unlike
	// the /tmp block below - a fast rename/chmod after the event could
	// skew this one check; confirm whether fd-based Fstat was intended.
	if strings.Contains(path, "/.config/") {
		finfo, err := os.Stat(path)
		if err == nil && finfo.Mode()&0111 != 0 {
			fm.sendAlertWithPath(alert.Critical, "executable_in_config_realtime",
				fmt.Sprintf("Executable created in .config: %s", path),
				fmt.Sprintf("Size: %d", finfo.Size()), path, procInfo)
		}
		return
	}
	// Executables in /tmp or /dev/shm - detect dropped malware/miners
	// Uses unix.Fstat on event fd for TOCTOU safety (attacker can't chmod -x after event)
	if strings.HasPrefix(path, "/tmp/") || strings.HasPrefix(path, "/dev/shm/") || strings.HasPrefix(path, "/var/tmp/") {
		var tmpStat unix.Stat_t
		if err := unix.Fstat(event.fd, &tmpStat); err == nil {
			isDir := tmpStat.Mode&unix.S_IFMT == unix.S_IFDIR
			isExec := tmpStat.Mode&0111 != 0
			if !isDir && isExec {
				// Skip known root-owned work directories:
				// - cPanel: SpamAssassin compiles .so regex modules, UPCP stages scripts
				// - dracut: rebuilds initramfs after kernel updates, copies system binaries
				// Non-root files in these paths are still suspicious.
				isCpanelWork := strings.Contains(path, "/cpanel.TMP.work.") || strings.Contains(path, "/cPanel-")
				isDracutWork := strings.Contains(path, "/dracut.")
				if (isCpanelWork || isDracutWork) && tmpStat.Uid == 0 {
					// Root-owned executable in system work dir - legitimate, skip
				} else {
					fm.sendAlertWithPath(alert.Critical, "executable_in_tmp_realtime",
						fmt.Sprintf("Executable created in %s: %s", filepath.Dir(path), path),
						fmt.Sprintf("Size: %d, Mode: %04o, UID: %d", tmpStat.Size, tmpStat.Mode&0777, tmpStat.Uid), path, procInfo)
				}
			}
		}
		// Fall through to PHP checks below for .php files in /tmp
		if !isPHPExtension(nameLower) {
			return
		}
	}
	// PHP in uploads directories.
	// Any PHP file here is anomalous: /wp-content/uploads/ is meant for
	// media, not code. Plugin-update temp dirs are recognised structurally
	// via looksLikePluginUpdate; everything else is Critical. Operators
	// suppress legitimate caching daemons through the path-scoped
	// suppressions_api, not an implicit substring allowlist in the daemon.
	if strings.Contains(path, "/wp-content/uploads/") && isPHPExtension(nameLower) {
		if nameLower != "index.php" {
			if looksLikePluginUpdate(path) {
				// Verified plugin update - emit one low-severity alert per temp directory.
				// Cannot suppress entirely - attacker could create a decoy plugin dir.
				uploadsIdx := strings.Index(path, "/wp-content/uploads/")
				afterUploads := path[uploadsIdx+len("/wp-content/uploads/"):]
				tempDir := afterUploads
				if slashIdx := strings.Index(afterUploads, "/"); slashIdx > 0 {
					tempDir = afterUploads[:slashIdx]
				}
				updateDir := path[:uploadsIdx] + "/wp-content/uploads/" + tempDir
				fm.sendAlertWithPath(alert.Warning, "php_in_uploads_realtime",
					fmt.Sprintf("Plugin update in uploads: %s", updateDir),
					"Verified: matching plugin exists in plugins/", updateDir, procInfo)
			} else {
				fm.sendAlertWithPath(alert.Critical, "php_in_uploads_realtime",
					fmt.Sprintf("PHP file created in uploads: %s", path), "", path, procInfo)
			}
		}
		return
	}
	// PHP in languages/upgrade directories.
	// Path-only Critical buried real alerts under location noise (WPML
	// translation queues, WP auto-update staging). Run content analysis
	// first: if a real rule fires the Critical lives there; clean files
	// still get a Warning so unexpected PHP in these dirs is visible.
	if (strings.Contains(path, "/wp-content/languages/") || strings.Contains(path, "/wp-content/upgrade/")) &&
		isPHPExtension(nameLower) {
		if nameLower != "index.php" && !strings.HasSuffix(nameLower, ".l10n.php") {
			if !fm.checkPHPContent(event.fd, path, procInfo) {
				fm.sendAlertWithPath(alert.Warning, "php_in_sensitive_dir_realtime",
					fmt.Sprintf("PHP file created in sensitive WP directory (content clean): %s", path), "", path, procInfo)
			}
		}
		return
	}
	// PHP content analysis (C3 - read from fd; M4 - 32KB scan size).
	// .htaccess, .user.ini, and .config executable checks are handled
	// earlier in this function (before the /tmp early-return) so specific
	// file types take precedence over the /tmp generic block.
	if isPHPExtension(nameLower) {
		fm.checkPHPContent(event.fd, path, procInfo)
		return
	}
	// HTML phishing page detection (uses event fd for content, unix.Fstat for size)
	if strings.HasSuffix(nameLower, ".html") || strings.HasSuffix(nameLower, ".htm") {
		fm.checkHTMLPhishing(event.fd, path, procInfo)
		return
	}
	// Credential log files (path-based)
	if credentialLogNames[nameLower] {
		fm.checkCredentialLog(path, procInfo)
		return
	}
	// Phishing kit ZIP archives (path-based)
	if strings.HasSuffix(nameLower, ".zip") {
		fm.checkPhishingZip(path, nameLower, procInfo)
		return
	}
	// CGI scripts in web-accessible directories (Perl, Python, Bash, Ruby)
	// Detect backdoor toolkits like LEVIATHAN that use non-PHP scripts.
	if strings.HasPrefix(path, "/home/") && isCGIExtension(nameLower) {
		fm.checkCGIBackdoor(event.fd, path, procInfo)
		return
	}
}
// Structural exclusions for checkHtaccess. Both anchor to the actual
// directive or regex context, not to loose substrings that an attacker
// can paste anywhere on the line. Compiled once at package scope so the
// per-event hot path pays no regexp compilation cost.
var (
	// Legit auto_(prepend|append)_file directive targets: known product
	// files shipped by security plugins. Match is anchored to the
	// directive argument, so a trailing "# litespeed" comment cannot
	// forge safety.
	htaccessAutoPrependSafeTarget = regexp.MustCompile(
		`(?i)auto_(?:prepend|append)_file\s*=?\s*['"]?(?:[^\s'"]*/)?` +
			`(?:wordfence-waf|sucuri|advanced-headers)\.php(?:['"]|\s|$)`,
	)
	// Apache mod_rewrite directives. base64_decode / eval( appearing
	// inside a RewriteCond or RewriteRule is a pattern in an attack-query
	// blocklist (e.g. Really Simple SSL hardening), not PHP code.
	htaccessRewriteDirective = regexp.MustCompile(
		`(?i)^\s*Rewrite(?:Cond|Rule)\s`,
	)
)
// checkCrontab scans a freshly-written /var/spool/cron/<user> file for the
// known persistence-marker patterns (literal + base64-decoded). Reads from
// the event fd, not the path, so an attacker swapping the file post-event
// cannot redirect us. Root crontab drift is tracked separately via
// hash-baseline by the polled CheckCrontabs.
func (fm *FileMonitor) checkCrontab(fd int, path, procInfo string) {
	owner := filepath.Base(path)
	// Skip root (covered by the polled hash-baseline run) and degenerate
	// paths that resolve to the spool directory itself.
	switch owner {
	case "", "root", filepath.Base(cronSpoolWatchDir):
		return
	}
	data := readFromFd(fd, 65536)
	if data == nil {
		return
	}
	hits := checks.MatchCrontabPatternsDeep(string(data))
	if len(hits) == 0 {
		return
	}
	fm.sendAlertWithPath(alert.Critical, "suspicious_crontab",
		fmt.Sprintf("Suspicious crontab written for user %s: %v", owner, hits),
		fmt.Sprintf("File: %s\nPatterns matched: %v", path, hits),
		path, procInfo)
}
// checkHtaccess reads .htaccess content from the event fd and checks for injection.
// C3 - reads from fd, not path. Each non-comment line is classified: a
// suspicious directive alerts High and stops; otherwise the whole content
// falls through to signature/YARA scanning.
func (fm *FileMonitor) checkHtaccess(fd int, path, procInfo string) {
	data := readFromFd(fd, 16384)
	if data == nil {
		return
	}
	for _, rawLine := range strings.Split(string(data), "\n") {
		line := strings.TrimSpace(rawLine)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		lower := strings.ToLower(line)
		// auto_prepend_file / auto_append_file: suspicious unless the
		// directive target matches a known-legit security plugin file.
		if strings.Contains(lower, "auto_prepend_file") || strings.Contains(lower, "auto_append_file") {
			if htaccessAutoPrependSafeTarget.MatchString(line) {
				continue
			}
			fm.sendAlertWithPath(alert.High, "htaccess_injection_realtime",
				fmt.Sprintf("Suspicious .htaccess modification: %s", path),
				"auto_prepend_file/auto_append_file target not recognised", path, procInfo)
			return
		}
		// eval( / base64_decode outside a RewriteCond/RewriteRule is a
		// tamper signal: .htaccess is not a PHP execution context, so
		// the only legit appearance of these tokens is as regex patterns
		// inside mod_rewrite attack-blocklists.
		if strings.Contains(lower, "eval(") || strings.Contains(lower, "base64_decode") {
			if htaccessRewriteDirective.MatchString(line) {
				continue
			}
			fm.sendAlertWithPath(alert.High, "htaccess_injection_realtime",
				fmt.Sprintf("Suspicious .htaccess modification: %s", path),
				"PHP function reference outside RewriteCond/RewriteRule", path, procInfo)
			return
		}
	}
	// Run signature/YARA scanning on .htaccess content
	fm.runSignatureScan(data, path, ".htaccess", procInfo)
}
// checkUserINI reads .user.ini content from the event fd and checks for
// dangerous PHP settings. C3 - reads from fd, not path. H6 - parses the
// actual directive value instead of substring-matching "on" anywhere.
// Priority is preserved from the original pattern table: allow_url_include
// is reported ahead of disable_functions regardless of line order. When no
// directive alert fires, the content falls through to signature/YARA scanning.
func (fm *FileMonitor) checkUserINI(fd int, path, procInfo string) {
	data := readFromFd(fd, 4096)
	if data == nil {
		return
	}
	// .ini directives are case-insensitive; normalise once up front.
	// (The previous version re-lowered each parsed value — redundant after
	// this whole-content ToLower, so that call has been removed.)
	content := strings.ToLower(string(data))
	lines := strings.Split(content, "\n")
	// allow_url_include = on → remote file inclusion becomes possible.
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if !strings.HasPrefix(line, "allow_url_include") {
			continue
		}
		parts := strings.SplitN(line, "=", 2)
		if len(parts) != 2 {
			continue
		}
		switch strings.TrimSpace(parts[1]) {
		case "on", "1", "\"on\"", "'on'":
			fm.sendAlertWithPath(alert.Critical, "php_config_realtime",
				fmt.Sprintf("PHP allow_url_include enabled: %s", path),
				"Remote PHP file inclusion is now possible", path, procInfo)
			return
		}
	}
	// disable_functions cleared → shell-execution functions re-enabled.
	for _, line := range lines {
		if !strings.HasPrefix(strings.TrimSpace(line), "disable_functions") {
			continue
		}
		parts := strings.SplitN(line, "=", 2)
		if len(parts) != 2 {
			continue
		}
		switch strings.TrimSpace(parts[1]) {
		case "", "\"\"", "none":
			fm.sendAlertWithPath(alert.Critical, "php_config_realtime",
				fmt.Sprintf("PHP disable_functions cleared: %s", path),
				"All dangerous PHP functions enabled - shell execution possible", path, procInfo)
			return
		}
	}
	// Run signature/YARA scanning on .user.ini content
	fm.runSignatureScan(data, path, ".ini", procInfo)
}
// checkPHPContent reads PHP content from the event fd and checks for malicious patterns.
// C3 - reads from fd, not path. M4 - 32KB scan size.
func (fm *FileMonitor) checkPHPContent(fd int, path, procInfo string) bool {
data := readFromFd(fd, 32768)
if data == nil {
return false
}
content := strings.ToLower(string(data))
// Remote payload fetching — paste sites are always suspicious.
// GitHub raw URLs only flag when combined with a dangerous call on
// the same line (legitimate plugins use GitHub for update checks).
pasteURLs := []string{"pastebin.com/raw", "paste.ee/r/", "ghostbin.co/paste/", "hastebin.com/raw/"}
for _, p := range pasteURLs {
if strings.Contains(content, p) {
fm.sendAlertWithPath(alert.Critical, "php_dropper_realtime",
fmt.Sprintf("PHP dropper with paste site URL: %s", path),
fmt.Sprintf("Fetches from: %s", p), path, procInfo)
return true
}
}
githubURLs := []string{"gist.githubusercontent.com", "raw.githubusercontent.com"}
dangerousFns := []string{"file_put_contents(", "fwrite(", "shell_", "passthru(", "popen("}
for _, gh := range githubURLs {
if !strings.Contains(content, gh) {
continue
}
for _, line := range strings.Split(content, "\n") {
if !strings.Contains(line, gh) {
continue
}
for _, fn := range dangerousFns {
if strings.Contains(line, fn) {
fm.sendAlertWithPath(alert.Critical, "php_dropper_realtime",
fmt.Sprintf("PHP dropper fetching from GitHub with dangerous call: %s", path),
fmt.Sprintf("URL: %s, Function: %s", gh, fn), path, procInfo)
return true
}
}
}
}
// eval + decoder combo — require same-line nesting to avoid FPs on
// legitimate plugins that use these functions in unrelated contexts.
evalStr := "eval(" // search target for PHP eval function calls
assertStr := "assert(" // search target for PHP assert function calls
decoders := []string{"base64_decode", "gzinflate", "gzuncompress", "str_rot13", "gzdecode"}
for _, line := range strings.Split(content, "\n") {
lineHasEval := strings.Contains(line, evalStr) || strings.Contains(line, assertStr)
if !lineHasEval {
continue
}
for _, dec := range decoders {
if strings.Contains(line, dec) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Obfuscated PHP detected: %s", path),
fmt.Sprintf("PHP code execution with %s on same line", dec), path, procInfo)
return true
}
}
}
// Fragmented base64 evasion: $a="base"; $b="64_decode"; $c=$a.$b;
if strings.Contains(content, "\"base\"") || strings.Contains(content, "'base'") {
if strings.Contains(content, "64_dec") && strings.Contains(content, evalStr) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Fragmented base64_decode evasion detected: %s", path),
"base64_decode function name split across string variables", path, procInfo)
return true
}
}
// Massive variable concatenation payload ($z .= "xxxx"; repeated thousands of times)
concatCount := strings.Count(content, ".= \"")
if concatCount > 50 && strings.Contains(content, evalStr) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Concatenation payload detected: %s (%d concat ops)", path, concatCount),
"Variable built from hundreds of string concatenations then executed", path, procInfo)
return true
}
// Shell execution with request input
// Uses containsFunc to avoid substring false positives
// (e.g. "WP_Filesystem(" matching "exec(", "preg_match(" matching "exec(")
shellFuncs := []string{"system(", "passthru(", "exec(", "shell_exec(", "popen("}
requestVars := []string{"$_request", "$_post", "$_get", "$_cookie", "$_server"}
hasShell := false
hasInput := false
for _, sf := range shellFuncs {
if containsFunc(content, sf) {
hasShell = true
}
}
for _, rv := range requestVars {
if strings.Contains(content, rv) {
hasInput = true
}
}
if hasShell && hasInput {
// Require shell function + request variable on the SAME line.
// Same-line narrowing is the actual detection: admin panels with
// both tokens in unrelated contexts stay quiet because they never
// co-occur on one line. A file-wide allowlist (e.g. "skip when
// 'wp_filesystem' appears anywhere") would be forgeable by any
// webshell that pastes the token into a comment.
for _, line := range strings.Split(content, "\n") {
lineHasShell := false
lineHasInput := false
for _, sf := range shellFuncs {
if containsFunc(line, sf) {
lineHasShell = true
break
}
}
for _, rv := range requestVars {
if strings.Contains(line, rv) {
lineHasInput = true
break
}
}
if lineHasShell && lineHasInput {
fm.sendAlertWithPath(alert.Critical, "webshell_content_realtime",
fmt.Sprintf("Webshell pattern detected: %s", path),
fmt.Sprintf("Shell execution with request input on same line: %s", strings.TrimSpace(line)), path, procInfo)
return true
}
}
}
// Tail scan: for large files, also check the last 32KB.
// Attackers append payloads (eval+base64) at the end of legitimate PHP files,
// beyond the head scan window. Only do the cheap heuristic checks, not full
// signature scanning (which would be too slow on every large PHP file).
if tailData := readTailFromFd(fd, 32768); tailData != nil {
tail := strings.ToLower(string(tailData))
// Check for eval+decoder on same line in tail
for _, line := range strings.Split(tail, "\n") {
lineHasEval := strings.Contains(line, evalStr) || strings.Contains(line, assertStr)
if !lineHasEval {
continue
}
for _, dec := range decoders {
if strings.Contains(line, dec) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Obfuscated PHP appended to file tail: %s", path),
fmt.Sprintf("PHP code execution with %s found at end of file", dec), path, procInfo)
return true
}
}
}
// Fragmented base64 in tail
if strings.Contains(tail, "\"base\"") || strings.Contains(tail, "'base'") {
if strings.Contains(tail, "64_dec") && strings.Contains(tail, evalStr) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Fragmented base64_decode evasion in file tail: %s", path),
"Payload appended at end of legitimate PHP file", path, procInfo)
return true
}
}
// Concat payload with eval in tail
tailConcatCount := strings.Count(tail, ".= \"")
if tailConcatCount > 50 && strings.Contains(tail, evalStr) {
fm.sendAlertWithPath(alert.Critical, "obfuscated_php_realtime",
fmt.Sprintf("Concatenation payload in file tail: %s (%d concat ops)", path, tailConcatCount),
"Payload appended at end of legitimate PHP file", path, procInfo)
return true
}
}
// Skip signature/YARA scanning for verified CMS core files.
// The wp_core periodic check validates files against official checksums;
// if a file's hash matches a known-clean core file, signature matches
// on it are false positives (e.g. $_POST in wp-includes, mail() in
// PHPMailer, fsockopen() in POP3.php).
if checks.IsVerifiedCMSFile(path) {
return false
}
// External signature + YARA scanning
return fm.runSignatureScan(data, path, filepath.Ext(path), procInfo)
}
// checkHTMLPhishing reads an HTML file and checks for phishing indicators:
// brand impersonation + credential input + redirect/exfiltration.
// Uses event fd for content read and unix.Fstat for size (TOCTOU-safe).
//
// Every gate must pass before a phishing alert fires: web-accessible path,
// size window, form/input markup, a credential-style input, a brand string,
// and finally an exfil pattern or fake trust badge. Files that clear the
// early gates but miss the last one still get signature/YARA scanning.
func (fm *FileMonitor) checkHTMLPhishing(fd int, path, procInfo string) {
	// Only check files in web-accessible directories.
	//
	// No path-allowlist below this point: the content gates (credential
	// inputs + brand impersonation + exfil/trust-badge) reject legitimate
	// framework HTML on their own. A previous allowlist for /wp-admin/,
	// /wp-content/themes/, /wp-content/plugins/, /node_modules/, /vendor/,
	// /.well-known/ let an attacker who compromised any of those dirs drop
	// a credential-harvesting page with full suppression.
	if !strings.Contains(path, "/public_html/") {
		return
	}
	// Stat the already-open fd (not the path) so the size gate and the
	// read below refer to the same inode even if the path is swapped.
	var stat unix.Stat_t
	if err := unix.Fstat(fd, &stat); err != nil {
		return
	}
	size := stat.Size
	// Size window: under 500 bytes or over 500KB is skipped.
	// NOTE(review): thresholds look heuristic — confirm against observed
	// phishing-kit page sizes.
	if size < 500 || size > 500000 {
		return
	}
	// Read only the first 16KB; all markers checked below are substring
	// matches against this head window.
	data := readFromFd(fd, 16384)
	if data == nil {
		return
	}
	content := strings.ToLower(string(data))
	// Must have a form with credential inputs
	if !strings.Contains(content, "<form") && !strings.Contains(content, "<input") {
		return
	}
	// Credential-style inputs: email/password input types or names in either
	// quoting style, or a "you@..." placeholder.
	hasCredInput := strings.Contains(content, "type=\"email\"") ||
		strings.Contains(content, "type=\"password\"") ||
		strings.Contains(content, "type='email'") ||
		strings.Contains(content, "type='password'") ||
		strings.Contains(content, "name=\"email\"") ||
		strings.Contains(content, "name=\"password\"") ||
		strings.Contains(content, "placeholder=\"you@")
	if !hasCredInput {
		return
	}
	// Check for brand impersonation
	brands := []struct {
		name     string
		patterns []string
	}{
		{"Microsoft/SharePoint", []string{"sharepoint", "onedrive", "microsoft 365", "outlook web", "office 365"}},
		{"Google", []string{"google drive", "google docs", "accounts.google", "gmail"}},
		{"Dropbox", []string{"dropbox"}},
		{"DocuSign", []string{"docusign"}},
		{"Adobe", []string{"adobe sign", "adobe document"}},
		{"Apple/iCloud", []string{"icloud", "apple id"}},
		{"Webmail", []string{"roundcube", "horde", "webmail login", "zimbra"}},
		{"Generic", []string{"secure access", "verify your", "confirm your identity", "account verification"}},
	}
	// First brand whose pattern appears wins; its display name goes into
	// the alert message.
	brandMatch := ""
	for _, b := range brands {
		for _, p := range b.patterns {
			if strings.Contains(content, p) {
				brandMatch = b.name
				break
			}
		}
		if brandMatch != "" {
			break
		}
	}
	if brandMatch == "" {
		return
	}
	// Check for redirect/exfiltration patterns
	exfilPatterns := []string{
		"window.location.href", "window.location.replace", "window.location =",
		".workers.dev", "fetch(", "xmlhttprequest",
	}
	hasExfil := false
	for _, p := range exfilPatterns {
		if strings.Contains(content, p) {
			hasExfil = true
			break
		}
	}
	// Also check for trust badges (strong phishing signal).
	// The final variant uses a non-breaking hyphen (U+2011), not ASCII "-";
	// presumably to catch kits using that character — do not "fix" it.
	hasTrustBadge := strings.Contains(content, "secured by microsoft") ||
		strings.Contains(content, "secured by google") ||
		strings.Contains(content, "256-bit encrypted") ||
		strings.Contains(content, "256‑bit encrypted")
	if hasExfil || hasTrustBadge {
		fm.sendAlertWithPath(alert.Critical, "phishing_realtime",
			fmt.Sprintf("Phishing page created (%s impersonation): %s", brandMatch, path),
			fmt.Sprintf("Size: %d bytes", size), path, procInfo)
		return
	}
	// Run signature/YARA scanning on HTML content not caught by phishing heuristics
	fm.runSignatureScan(data, path, ".html", procInfo)
}
// checkCredentialLog reads a text file and checks if it contains harvested
// email:password pairs - output from an active phishing kit.
// Uses path-based reads because it needs path context.
func (fm *FileMonitor) checkCredentialLog(path, procInfo string) {
	if !strings.Contains(path, "/public_html/") {
		return
	}
	// Exclude known config file paths - these legitimately contain email-like patterns.
	if strings.HasPrefix(path, "/etc/") {
		return
	}
	configSuffixes := []string{".conf", ".cfg", ".ini", ".yaml", ".yml"}
	for _, sfx := range configSuffixes {
		if strings.HasSuffix(path, sfx) {
			return
		}
	}
	head := readHead(path, 4096)
	if head == nil {
		return
	}
	var (
		credLines  int
		emailCount int
	)
	delims := []string{":", "|", "\t", ","}
	for _, raw := range strings.Split(string(head), "\n") {
		entry := strings.TrimSpace(raw)
		if entry == "" || !strings.Contains(entry, "@") {
			continue
		}
		emailCount++
		// A credential line is "email<delim>secret" where the secret has
		// no embedded spaces.
		for _, delim := range delims {
			pieces := strings.SplitN(entry, delim, 3)
			if len(pieces) < 2 {
				continue
			}
			left := strings.TrimSpace(pieces[0])
			right := strings.TrimSpace(pieces[1])
			if strings.Contains(left, "@") && right != "" && !strings.Contains(right, " ") {
				credLines++
				break
			}
		}
	}
	switch {
	case credLines >= 5:
		fm.sendAlertWithPath(alert.Critical, "credential_log_realtime",
			fmt.Sprintf("Harvested credential log detected: %s", path),
			fmt.Sprintf("%d credential lines (email:password format) found", credLines), path, procInfo)
	case emailCount >= 10:
		fm.sendAlertWithPath(alert.High, "credential_log_realtime",
			fmt.Sprintf("Possible harvested email list: %s", path),
			fmt.Sprintf("%d email addresses found in %s", emailCount, filepath.Base(path)), path, procInfo)
	}
}
// checkPhishingZip checks if a newly created ZIP file matches known phishing
// kit archive names.
// Uses path-based approach since it only checks the filename.
func (fm *FileMonitor) checkPhishingZip(path, nameLower, procInfo string) {
	if !strings.Contains(path, "/public_html/") {
		return
	}
	// Brand/keyword tokens seen in real phishing kit archive names.
	for _, kit := range []string{
		"office365", "office 365", "sharepoint", "onedrive",
		"microsoft", "outlook", "google", "gmail",
		"dropbox", "docusign", "adobe", "wetransfer",
		"paypal", "apple", "icloud", "netflix",
		"facebook", "instagram", "linkedin",
		"login", "phish", "scam", "kit",
		"webmail", "roundcube", "cpanel",
		"bank", "verify", "secure",
	} {
		if !strings.Contains(nameLower, kit) {
			continue
		}
		// Alert on the first matching token only.
		fm.sendAlertWithPath(alert.High, "phishing_kit_realtime",
			fmt.Sprintf("Suspected phishing kit archive uploaded: %s", path),
			fmt.Sprintf("Filename matches phishing kit pattern: '%s'", kit), path, procInfo)
		return
	}
}
// runSignatureScan runs YAML and YARA signature scanning on file content.
// Returns true if a match was found and an alert was sent.
// Non-critical YAML matches use directory-level dedup to avoid alert floods
// when a plugin directory has many files matching the same rule.
// Critical matches (backdoors, webshells) always alert per-file.
//
// data is the content to scan; path/ext give alert and rule context;
// procInfo is the fanotify-provided writer identity. YARA runs only when
// the YAML signature pass produced no match.
func (fm *FileMonitor) runSignatureScan(data []byte, path, ext, procInfo string) bool {
	if scanner := signatures.Global(); scanner != nil {
		matches := scanner.ScanContent(data, ext)
		if len(matches) > 0 {
			// Only the first match drives the alert; further matches on the
			// same file would be redundant noise.
			m := matches[0]
			sev := alert.High
			if m.Severity == "critical" {
				sev = alert.Critical
			}
			// Non-critical: dedup by rule+directory so 30 files in the same
			// plugin matching the same rule produce one alert, not 30.
			// Critical matches always alert per-file (real path for quarantine).
			if sev != alert.Critical {
				dirKey := m.RuleName + ":" + filepath.Dir(path)
				if !fm.shouldAlert("signature_match_realtime", dirKey) {
					return true // suppressed by dedup, but still counts as "matched"
				}
			}
			details := fmt.Sprintf("Category: %s\nDescription: %s\nMatched: %s",
				m.Category, m.Description, strings.Join(m.Matched, ", "))
			fm.sendAlertWithPath(sev, "signature_match_realtime",
				fmt.Sprintf("Signature match [%s]: %s", m.RuleName, path),
				details, path, procInfo)
			// Inline quarantine: move high-confidence malware to quarantine
			// immediately instead of waiting for the 5-second batch dispatcher.
			// Uses the same 3-gate validation as AutoQuarantineFiles (category +
			// library exclusion + entropy >= 4.8) to prevent false positives.
			if sev == alert.Critical {
				finding := alert.Finding{
					Severity: sev,
					Check:    "signature_match_realtime",
					Details:  details,
					FilePath: path,
				}
				if qPath, ok := checks.InlineQuarantine(finding, path, data); ok {
					// Quarantine announcement goes through sendAlert (no path
					// dedup) so it is never suppressed.
					fm.sendAlert(alert.Critical, "auto_response",
						fmt.Sprintf("AUTO-QUARANTINE (inline): %s moved to quarantine", path),
						fmt.Sprintf("Quarantined to: %s\nRule: %s", qPath, m.RuleName))
				}
			}
			return true
		}
	}
	if yaraScanner := yara.Global(); yaraScanner != nil {
		matches := yaraScanner.ScanBytes(data)
		if len(matches) > 0 {
			fm.sendAlertWithPath(alert.Critical, "yara_match_realtime",
				fmt.Sprintf("YARA rule match [%s]: %s", matches[0].RuleName, path),
				fmt.Sprintf("Matched %d YARA rule(s)", len(matches)), path, procInfo)
			return true
		}
	}
	return false
}
// M7 - sendAlert uses droppedAlerts counter, separate from droppedEvents.
// No dedup - only used for system-level alerts (overflow reporting) that are
// already ticker-gated. File-related alerts should use sendAlertWithPath.
func (fm *FileMonitor) sendAlert(severity alert.Severity, check, message, details string) {
	// Non-blocking send: never stall the caller on a saturated channel;
	// count the drop instead.
	select {
	case fm.alertCh <- alert.Finding{
		Severity:  severity,
		Check:     check,
		Message:   message,
		Details:   details,
		Timestamp: time.Now(),
	}:
	default:
		atomic.AddInt64(&fm.droppedAlerts, 1)
	}
}
// sendAlertWithPath is like sendAlert but also sets the FilePath and
// ProcessInfo fields for structured propagation to auto-response.
// Applies per-path deduplication to prevent alert storms from rapid writes.
func (fm *FileMonitor) sendAlertWithPath(severity alert.Severity, check, message, details, filePath, processInfo string) {
	// Dedup first: a burst of writes to one file yields a single alert.
	if !fm.shouldAlert(check, filePath) {
		return
	}
	// Non-blocking send; a saturated channel increments the drop counter.
	select {
	case fm.alertCh <- alert.Finding{
		Severity:    severity,
		Check:       check,
		Message:     message,
		Details:     details,
		FilePath:    filePath,
		ProcessInfo: processInfo,
		Timestamp:   time.Now(),
	}:
	default:
		atomic.AddInt64(&fm.droppedAlerts, 1)
	}
}
// shouldAlert returns true if this check+path combination hasn't been alerted
// recently. Prevents duplicate alerts from rapid writes to the same file.
// Uses LoadOrStore for atomic initial insertion to avoid TOCTOU races
// between concurrent analyzer workers.
func (fm *FileMonitor) shouldAlert(check, filePath string) bool {
	if filePath == "" {
		return true // no path = no dedup possible
	}
	key := check + ":" + filePath
	now := time.Now()
	prev, loaded := fm.alertDedup.LoadOrStore(key, now)
	if !loaded {
		// First sighting; LoadOrStore already recorded the timestamp.
		return true
	}
	if now.Sub(prev.(time.Time)) < alertDedupTTL {
		return false
	}
	fm.alertDedup.Store(key, now) // refresh TTL on expiry
	return true
}
// M7 - overflowReporter reports dropped events and alerts separately.
// Runs as a goroutine (accounted on fm.wg) until fm.stopCh closes. Each
// minute it: swaps-and-resets the drop counters, alerts on event drops and
// triggers reconcileDrops, logs alert-channel drops to stderr, and evicts
// expired alert-dedup entries.
func (fm *FileMonitor) overflowReporter() {
	defer fm.wg.Done()
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-fm.stopCh:
			return
		case <-ticker.C:
			// Swap-to-zero reads and resets atomically, so each report
			// covers exactly one ticker interval.
			droppedEv := atomic.SwapInt64(&fm.droppedEvents, 0)
			droppedAl := atomic.SwapInt64(&fm.droppedAlerts, 0)
			if droppedEv > 0 {
				fm.sendAlert(alert.Warning, "fanotify_overflow",
					fmt.Sprintf("fanotify event queue overflowed: %d events dropped in last minute", droppedEv),
					"Possible event storm (backup, bulk update) or high-volume attack")
				// Recover coverage: scan files in directories that saw drops
				// so a threat landing during the storm is still detected.
				fm.reconcileDrops()
			}
			if droppedAl > 0 {
				// Alert-channel drops go to stderr, not the alert channel:
				// sending there could itself be dropped.
				fmt.Fprintf(os.Stderr, "[%s] alert channel full: %d alerts dropped in last minute\n", ts(), droppedAl)
			}
			// Evict stale dedup entries every minute
			now := time.Now()
			fm.alertDedup.Range(func(key, value any) bool {
				if now.Sub(value.(time.Time)) > alertDedupTTL {
					fm.alertDedup.Delete(key)
				}
				return true
			})
		}
	}
}
// containsFunc checks if content contains a function call that isn't part of
// a longer identifier. Prevents "WP_Filesystem(" matching "exec(" or
// "preg_match(" matching "exec(". Checks the character before the match
// is not a letter, digit, or underscore.
func containsFunc(content, funcCall string) bool {
	idx := 0
	for {
		pos := strings.Index(content[idx:], funcCall)
		if pos < 0 {
			return false
		}
		absPos := idx + pos
		// A match at the very start of content has no preceding character.
		if absPos == 0 {
			return true
		}
		// Accept only matches whose preceding byte is not an identifier
		// character; otherwise this is the tail of a longer name
		// (e.g. "shell_exec(" when searching for "exec(").
		prev := content[absPos-1]
		if (prev < 'a' || prev > 'z') && (prev < 'A' || prev > 'Z') &&
			(prev < '0' || prev > '9') && prev != '_' {
			return true
		}
		// Skip past this rejected occurrence and keep searching.
		idx = absPos + len(funcCall)
		if idx >= len(content) {
			return false
		}
	}
}

// isPHPExtension returns true for all PHP file extensions that can execute code.
func isPHPExtension(nameLower string) bool {
	return strings.HasSuffix(nameLower, ".php") ||
		strings.HasSuffix(nameLower, ".phtml") ||
		strings.HasSuffix(nameLower, ".pht") ||
		strings.HasSuffix(nameLower, ".php5")
}
// isCGIExtension reports whether nameLower ends in a script extension that
// web servers commonly execute as CGI (Perl, Python, shell, Ruby).
func isCGIExtension(nameLower string) bool {
	for _, ext := range []string{".pl", ".cgi", ".py", ".sh", ".rb"} {
		if strings.HasSuffix(nameLower, ext) {
			return true
		}
	}
	return false
}
// checkCGIBackdoor reads a CGI script and checks for backdoor patterns.
// Detects Perl/Python/Bash backdoors like the LEVIATHAN toolkit.
func (fm *FileMonitor) checkCGIBackdoor(fd int, path, procInfo string) {
	data := readFromFd(fd, 32768)
	if data == nil {
		return
	}
	content := strings.ToLower(string(data))
	// Indicators weighted by suspicion level. Generic patterns like
	// "request_method" and "cmd" removed — they match every CGI script.
	shellPatterns := []struct {
		pattern string
		desc    string
	}{
		{"system(", "system() call"},
		{"os.popen", "os.popen() call"},
		{"`$", "backtick execution with variable"},
		{"content_length", "reads POST body length"},
		{"base64_decode", "base64 decoding"},
		{"$_post", "PHP POST input"},
		{"$_get", "PHP GET input"},
		{"param(", "CGI parameter read"},
		{"qs.parse", "query string parsing"},
	}
	var matched []string
	for _, sp := range shellPatterns {
		if strings.Contains(content, sp.pattern) {
			matched = append(matched, sp.desc)
		}
	}
	// 4+ indicators = likely backdoor
	if len(matched) >= 4 {
		fm.sendAlertWithPath(alert.Critical, "cgi_backdoor_realtime",
			fmt.Sprintf("CGI backdoor detected: %s", path),
			fmt.Sprintf("Indicators (%d): %s", len(matched), strings.Join(matched, ", ")), path, procInfo)
		return
	}
	// CGI scripts in unusual locations (images, css, js directories)
	for _, dir := range []string{"/img/", "/images/", "/css/", "/js/", "/fonts/", "/icons/"} {
		if strings.Contains(path, dir) {
			fm.sendAlertWithPath(alert.High, "cgi_suspicious_location_realtime",
				fmt.Sprintf("CGI script in non-CGI directory: %s", path),
				"Scripts should not exist in image/css/js directories", path, procInfo)
			return
		}
	}
	// Run signature scan on the content
	fm.runSignatureScan(data, path, filepath.Ext(path), procInfo)
}
// matchSuppression checks if a file path matches a suppression glob pattern.
// Supports patterns like "*/cache/*", "*/vendor/*", "*.log".
// Uses filepath.Match per path segment for wildcard patterns.
//
// Matching is attempted three ways, cheapest first: the pattern against the
// full path, against the basename (covers "*.log"), and finally a
// segment-aligned sliding window (covers "*/cache/*").
func matchSuppression(pattern, path string) bool {
	// Direct match against full path
	if m, _ := filepath.Match(pattern, path); m {
		return true
	}
	// Match against basename (e.g. "*.log")
	if m, _ := filepath.Match(pattern, filepath.Base(path)); m {
		return true
	}
	// For patterns like "*/cache/*": align pattern segments against path
	// segments with a sliding window. Empty pattern parts (from
	// leading/trailing/double slashes) match any segment.
	patParts := strings.Split(pattern, "/")
	pathParts := strings.Split(path, "/")
	if len(patParts) < 2 {
		return false
	}
	// A window match only counts if the pattern has at least one non-empty
	// segment. This is invariant across windows, so decide it once up front
	// instead of re-scanning patParts after every candidate window.
	hasNonEmpty := false
	for _, pp := range patParts {
		if pp != "" {
			hasNonEmpty = true
			break
		}
	}
	if !hasNonEmpty {
		return false
	}
	for i := 0; i <= len(pathParts)-len(patParts); i++ {
		allMatch := true
		for j, pp := range patParts {
			if pp == "" {
				continue // empty segment matches anything (acts as wildcard)
			}
			// The loop bound guarantees i+j < len(pathParts), so the
			// indexed access is in range.
			if m, _ := filepath.Match(pp, pathParts[i+j]); !m {
				allMatch = false
				break
			}
		}
		if allMatch {
			return true
		}
	}
	return false
}
// readHead opens a file by path and reads the first maxBytes.
// Kept for path-based checks (HTML phishing, credential logs, ZIP checks)
// that need os.Stat for file size anyway.
func readHead(path string, maxBytes int) []byte {
// #nosec G304 -- readHead scans files surfaced by fanotify/scanner;
// reading user files for signature analysis is the daemon's purpose.
f, err := os.Open(path)
if err != nil {
return nil
}
defer func() { _ = f.Close() }()
buf := make([]byte, maxBytes)
n, _ := f.Read(buf)
if n == 0 {
return nil
}
return buf[:n]
}
// looksLikePluginUpdate checks if a PHP file in uploads looks like a plugin
// update temp directory (e.g., elementor_t0q9y). Returns true if it matches
// the pattern of a known plugin extracting an update.
// M3 - uses sync.Map cache with 5-minute TTL for plugin directory stat results.
func looksLikePluginUpdate(path string) bool {
	// WordPress plugin updates extract to /uploads/{pluginname}_{random}/
	// Detect by extracting the directory name under uploads/ and checking
	// if a matching plugin exists in wp-content/plugins/.
	// No hardcoded whitelist - works for all 60,000+ WP plugins.
	const uploadsMarker = "/wp-content/uploads/"
	markerIdx := strings.Index(path, uploadsMarker)
	if markerIdx < 0 {
		return false
	}
	wpRoot := path[:markerIdx]
	rest := path[markerIdx+len(uploadsMarker):]
	// First directory component under uploads/, e.g. "header-footer_7ocsd".
	slash := strings.Index(rest, "/")
	if slash < 0 {
		return false
	}
	dirName := rest[:slash]
	// Strip the random suffix WordPress appends ("_XXXXX"): the plugin name
	// is everything before the final underscore when the tail looks like a
	// short alphanumeric token (accept 4-10 chars).
	pluginName := dirName
	if us := strings.LastIndex(dirName, "_"); us > 0 {
		if tail := dirName[us+1:]; len(tail) >= 4 && len(tail) <= 10 {
			pluginName = dirName[:us]
		}
	}
	pluginDir := wpRoot + "/wp-content/plugins/" + pluginName
	// M3 - consult the stat cache before hitting the filesystem.
	if cached, ok := pluginStatCache.Load(pluginDir); ok {
		entry := cached.(pluginCacheEntry)
		if time.Since(entry.ts) < pluginCacheTTL {
			return entry.exists
		}
	}
	_, err := os.Stat(pluginDir)
	exists := err == nil
	pluginStatCache.Store(pluginDir, pluginCacheEntry{
		exists: exists,
		ts:     time.Now(),
	})
	return exists
}
package daemon
import (
"bufio"
"fmt"
"os"
"strings"
"github.com/pidginhost/csm/internal/alert"
)
// parseValiasFileForFindings parses a valiases file and returns findings.
// Used by both the realtime watcher and tests.
//
// path is the valiases file for domain; localDomains maps lowercase local
// domain names to true; knownForwarders is the operator suppression list
// ("local@domain: dest" entries). Findings are returned with a zero
// Timestamp — the realtime caller (handleFileChange) stamps them before
// sending. Returns nil when the file cannot be opened.
func parseValiasFileForFindings(path, domain string, localDomains map[string]bool, knownForwarders []string) []alert.Finding {
	// #nosec G304 -- path from cPanel valiases directory walk; operator-scoped.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer f.Close()
	var findings []alert.Finding
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// Skip blank lines and comments.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// valiases line format: "localpart: dest[,dest...]".
		idx := strings.IndexByte(line, ':')
		if idx < 0 {
			continue
		}
		localPart := strings.TrimSpace(line[:idx])
		dest := strings.TrimSpace(line[idx+1:])
		if localPart == "" || dest == "" {
			continue
		}
		dests := strings.Split(dest, ",")
		for _, d := range dests {
			d = strings.TrimSpace(d)
			if d == "" {
				continue
			}
			// Suppression check
			if isKnownForwarderWatcher(localPart, domain, d, knownForwarders) {
				continue
			}
			// "|command" pipes delivery into a program — a classic
			// persistence/exfiltration vector, always critical.
			if strings.HasPrefix(d, "|") {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_pipe_forwarder",
					Message:  fmt.Sprintf("Pipe forwarder detected: %s@%s -> %s", localPart, domain, d),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, d, path),
				})
				continue
			}
			// Blackholed mail can hide password-reset and abuse notices.
			if d == "/dev/null" {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "email_suspicious_forwarder",
					Message:  fmt.Sprintf("Mail blackhole: %s@%s -> /dev/null", localPart, domain),
					Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: /dev/null\nFile: %s", domain, localPart, path),
				})
				continue
			}
			// Check if external: destination domain not served locally.
			atIdx := strings.LastIndexByte(d, '@')
			if atIdx >= 0 && atIdx < len(d)-1 {
				destDomain := strings.ToLower(d[atIdx+1:])
				if !localDomains[destDomain] {
					msg := fmt.Sprintf("External forwarder: %s@%s -> %s", localPart, domain, d)
					// Wildcard catch-all redirecting ALL mail off-server gets
					// a distinct message.
					if localPart == "*" {
						msg = fmt.Sprintf("Wildcard catch-all to external: *@%s -> %s", domain, d)
					}
					findings = append(findings, alert.Finding{
						Severity: alert.High,
						Check:    "email_suspicious_forwarder",
						Message:  msg,
						Details:  fmt.Sprintf("Domain: %s\nLocal part: %s\nDestination: %s\nFile: %s", domain, localPart, d, path),
					})
				}
			}
		}
	}
	// NOTE(review): scanner.Err() is not checked — a mid-file read error
	// silently truncates findings; confirm best-effort behavior is intended.
	return findings
}
// isKnownForwarderWatcher checks if a forwarder matches the suppression list.
// Separate from checks package to avoid import cycles.
func isKnownForwarderWatcher(localPart, domain, dest string, knownForwarders []string) bool {
	// Canonical entry form used by the suppression list: "local@domain: dest".
	want := fmt.Sprintf("%s@%s: %s", localPart, domain, dest)
	for _, candidate := range knownForwarders {
		if strings.EqualFold(strings.TrimSpace(candidate), want) {
			return true
		}
	}
	return false
}
//go:build linux
package daemon
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
)
// valiasesDir is the cPanel directory holding one alias file per domain.
const valiasesDir = "/etc/valiases"

// ForwarderWatcher watches /etc/valiases/ for changes using inotify.
type ForwarderWatcher struct {
	alertCh         chan<- alert.Finding // findings delivered here (non-blocking send in handleFileChange)
	knownForwarders []string             // suppression list of "local@domain: dest" entries
	inotifyFd       int                  // inotify instance fd; closed when Run observes stopCh
}
// NewForwarderWatcher creates a watcher for the valiases directory.
func NewForwarderWatcher(alertCh chan<- alert.Finding, knownForwarders []string) (*ForwarderWatcher, error) {
	fd, err := unix.InotifyInit1(unix.IN_CLOEXEC | unix.IN_NONBLOCK)
	if err != nil {
		return nil, fmt.Errorf("inotify_init1: %w", err)
	}
	// Watch close-after-write only: the interesting moment is when a
	// writer finishes rewriting an alias file.
	if _, err := unix.InotifyAddWatch(fd, valiasesDir, unix.IN_CLOSE_WRITE); err != nil {
		_ = unix.Close(fd)
		return nil, fmt.Errorf("inotify_add_watch(%s): %w", valiasesDir, err)
	}
	fw := &ForwarderWatcher{
		alertCh:         alertCh,
		knownForwarders: knownForwarders,
		inotifyFd:       fd,
	}
	return fw, nil
}
// Run starts the watch loop. Blocks until stopCh is closed.
func (fw *ForwarderWatcher) Run(stopCh <-chan struct{}) {
	// Poll the non-blocking inotify fd on a short ticker rather than
	// wiring it into epoll: coordinating the fd with stopCh via epoll
	// adds complexity for no practical latency win here.
	eventBuf := make([]byte, 4096)
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-stopCh:
			_ = unix.Close(fw.inotifyFd)
			return
		case <-ticker.C:
			fw.readEvents(eventBuf)
		}
	}
}
// readEvents drains all pending inotify events from the non-blocking fd.
// buf is a reusable scratch buffer owned by the caller (Run).
//
// The kernel delivers a packed byte stream: each record is a fixed-size
// InotifyEvent header followed by event.Len bytes of NUL-padded file name.
func (fw *ForwarderWatcher) readEvents(buf []byte) {
	for {
		n, err := unix.Read(fw.inotifyFd, buf)
		if err != nil || n <= 0 {
			return // EAGAIN or error - no more events
		}
		// Walk the packed records in this read.
		offset := 0
		for offset < n {
			// Truncated header at the buffer's end: stop parsing this read.
			if offset+unix.SizeofInotifyEvent > n {
				break
			}
			// #nosec G103 -- inotify returns a packed binary stream;
			// reinterpretation is required and bounded by the SizeofInotifyEvent check above.
			event := (*unix.InotifyEvent)(unsafe.Pointer(&buf[offset]))
			nameLen := int(event.Len)
			// Only process the name if it fits entirely inside this read.
			if nameLen > 0 && offset+unix.SizeofInotifyEvent+nameLen <= n {
				nameBytes := buf[offset+unix.SizeofInotifyEvent : offset+unix.SizeofInotifyEvent+nameLen]
				// Trim null bytes (the kernel pads names with NULs)
				name := strings.TrimRight(string(nameBytes), "\x00")
				// Skip hidden names (dotfiles, e.g. editor temp files).
				if name != "" && !strings.HasPrefix(name, ".") {
					fw.handleFileChange(name)
				}
			}
			offset += unix.SizeofInotifyEvent + nameLen
		}
	}
}
// handleFileChange parses the changed valiases file (named after its domain)
// and forwards any findings on the alert channel without blocking.
func (fw *ForwarderWatcher) handleFileChange(domain string) {
	path := filepath.Join(valiasesDir, domain)
	// Refresh local domains so external-destination detection stays current.
	localDomains := loadLocalDomainsForWatcher()
	for _, f := range parseValiasFileForFindings(path, domain, localDomains, fw.knownForwarders) {
		f.Timestamp = time.Now()
		f.Details += "\n(detected in realtime via inotify)"
		select {
		case fw.alertCh <- f:
		default:
			fmt.Fprintf(os.Stderr, "[%s] Warning: alert channel full, dropping forwarder finding for %s\n",
				time.Now().Format("2006-01-02 15:04:05"), domain)
		}
	}
}
// loadLocalDomainsForWatcher reads local domain files. Separate from the checks
// package version to avoid import cycles.
func loadLocalDomainsForWatcher() map[string]bool {
	domains := make(map[string]bool)
	sources := []string{"/etc/localdomains", "/etc/virtualdomains"}
	for _, src := range sources {
		// #nosec G304 -- src iterates a literal slice of cPanel system files.
		raw, err := os.ReadFile(src)
		if err != nil {
			continue // file absent on this host; skip silently
		}
		for _, entry := range strings.Split(string(raw), "\n") {
			entry = strings.TrimSpace(entry)
			if entry == "" || strings.HasPrefix(entry, "#") {
				continue
			}
			// Lines may carry a ":"-separated tail; keep only the domain part.
			if colon := strings.IndexByte(entry, ':'); colon > 0 {
				entry = strings.TrimSpace(entry[:colon])
			}
			domains[strings.ToLower(entry)] = true
		}
	}
	return domains
}
package daemon
import (
"strings"
"sync/atomic"
"github.com/pidginhost/csm/internal/geoip"
)
// daemonGeoIPDB is a package-level pointer to the GeoIP database, set once
// during daemon init and read by log watcher handlers for country filtering.
// atomic.Pointer makes the store/load pair race-free without a mutex.
var daemonGeoIPDB atomic.Pointer[geoip.DB]

// setGeoIPDB stores the GeoIP database for daemon-wide use.
func setGeoIPDB(db *geoip.DB) {
	daemonGeoIPDB.Store(db)
}

// getGeoIPDB returns the daemon's GeoIP database, or nil.
func getGeoIPDB() *geoip.DB {
	return daemonGeoIPDB.Load()
}
// isTrustedCountry checks if an IP's country is in the trusted list.
// Returns false if GeoIP is unavailable or country can't be resolved.
func isTrustedCountry(ip string, trustedCountries []string) bool {
	// Fail closed: no list, no database, or no resolved country all
	// mean "not trusted".
	if len(trustedCountries) == 0 {
		return false
	}
	db := getGeoIPDB()
	if db == nil {
		return false
	}
	country := db.Lookup(ip).Country
	if country == "" {
		return false
	}
	for _, trusted := range trustedCountries {
		if strings.EqualFold(country, trusted) {
			return true
		}
	}
	return false
}
package daemon
import (
"fmt"
"net"
"strings"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// Enhanced access_log handler - catches File Manager, API failures,
// webmail logins, wp-login brute force, and xmlrpc abuse.
//
// line is one raw access_log entry; cfg supplies the infra-IP allowlist and
// suppression switches. Returns nil for infra/localhost traffic or lines
// with fewer than 7 whitespace-separated fields.
// NOTE(review): port checks use strings.Contains(line, "2083"/"2095"/"2096"),
// which could in principle match other numeric fields (byte counts,
// timestamps) — confirm the log format pins these tokens to the port slot.
func parseAccessLogLineEnhanced(line string, cfg *config.Config) []alert.Finding {
	var findings []alert.Finding
	fields := strings.Fields(line)
	if len(fields) < 7 {
		return nil
	}
	// First field is the client IP in this log format.
	ip := fields[0]
	if isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
		return nil
	}
	lineLower := strings.ToLower(line)
	// File Manager write operations (port 2083)
	// Only match actual write actions - not read-only calls like get_homedir.
	// Skip 401/403 responses - the server rejected the request, no write occurred.
	// Match against the request URI only (between first pair of quotes), not the
	// full line which includes the referer URL that can contain "upload" in paths.
	if strings.Contains(line, "2083") && !strings.Contains(line, "\" 401 ") && !strings.Contains(line, "\" 403 ") {
		requestURI := extractRequestURI(lineLower)
		filemanWriteActions := []string{
			"fileman/save_file", "fileman/upload_files",
			"fileman/paste", "fileman/rename", "fileman/delete",
		}
		for _, action := range filemanWriteActions {
			if strings.Contains(requestURI, action) {
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "cpanel_file_upload_realtime",
					Message:  fmt.Sprintf("cPanel File Manager write from non-infra IP: %s", ip),
					Details:  truncateDaemon(line, 300),
				})
				break // one finding per line for this category
			}
		}
	}
	// API authentication failures (401/403)
	// Suppress 401s that are stale-session artifacts from a recent password change.
	// When a user changes their password, in-flight browser AJAX requests (notification
	// polls, etc.) will 401 against the now-invalidated session - that's expected, not
	// an attack. Real API abuse won't correlate with a recent purge for the same account.
	if strings.Contains(line, "\" 401 ") || strings.Contains(line, "\" 403 ") {
		if strings.Contains(lineLower, "json-api") || strings.Contains(lineLower, "/execute/") {
			if !purgeTracker.isPostPurge401(ip) {
				findings = append(findings, alert.Finding{
					Severity: alert.High,
					Check:    "api_auth_failure_realtime",
					Message:  fmt.Sprintf("cPanel API auth failure from %s", ip),
					Details:  truncateDaemon(line, 300),
				})
			}
		}
	}
	// Webmail login attempts (port 2095/2096)
	// NOTE(review): the "post" check matches anywhere in the lowered line
	// (e.g. a "/post/" URL path), not just the HTTP method field.
	if !cfg.Suppressions.SuppressWebmail {
		if strings.Contains(line, "2095") || strings.Contains(line, "2096") {
			if strings.Contains(lineLower, "post") {
				findings = append(findings, alert.Finding{
					Severity: alert.Warning,
					Check:    "webmail_login_realtime",
					Message:  fmt.Sprintf("Webmail login attempt from non-infra IP: %s", ip),
					Details:  truncateDaemon(line, 200),
				})
			}
		}
	}
	return findings
}
// parseFTPLogLine handles FTP log entries from /var/log/messages.
//
// Only pure-ftpd lines are considered. The client address is resolved from
// the pure-ftpd "(user@addr)" prefix first, falling back to a generic scan
// for an IP-shaped whitespace field. If the prefix holds a reverse-resolved
// hostname instead of an IP, nothing is emitted: we cannot act on a
// hostname at the firewall, and reverse-DNS lookups are off-limits in the
// log hot path.
func parseFTPLogLine(line string, cfg *config.Config) []alert.Finding {
	if !strings.Contains(line, "pure-ftpd") {
		return nil
	}
	clientIP := extractPureFTPDClientIP(line)
	if clientIP == "" {
		clientIP = extractIPFromLogDaemon(line)
	}
	if clientIP == "" || isInfraIPDaemon(clientIP, cfg.InfraIPs) {
		return nil
	}
	var out []alert.Finding
	// Both FTP signals share severity and detail shape; emit centralizes that.
	emit := func(check, msg string) {
		out = append(out, alert.Finding{
			Severity: alert.High,
			Check:    check,
			Message:  msg,
			Details:  truncateDaemon(line, 200),
		})
	}
	// Failed authentication
	if strings.Contains(line, "Authentication failed") || strings.Contains(line, "auth failed") {
		emit("ftp_auth_failure_realtime", fmt.Sprintf("FTP authentication failed from %s", clientIP))
	}
	// Successful login from non-infra
	if strings.Contains(line, "is now logged in") {
		emit("ftp_login_realtime", fmt.Sprintf("FTP login from non-infra IP: %s", clientIP))
	}
	return out
}
// extractPureFTPDClientIP parses the "(user@addr)" prefix that pure-ftpd
// prepends to every log message and returns `addr` only if it parses as
// an IP. Returns empty if the log line contains no prefix at all, the
// prefix is malformed, or addr is a reverse-resolved hostname (in which
// case the caller should not emit a finding since we can't block a
// hostname at the firewall).
func extractPureFTPDClientIP(line string) string {
	// Take everything after the first "(" ...
	_, afterParen, ok := strings.Cut(line, "(")
	if !ok {
		return ""
	}
	// ... up to the first matching ")".
	inner, _, ok := strings.Cut(afterParen, ")")
	if !ok {
		return ""
	}
	// The prefix body is "user@addr"; split at the first "@".
	_, addr, ok := strings.Cut(inner, "@")
	if !ok {
		return ""
	}
	if net.ParseIP(addr) == nil {
		return "" // hostname, not an IP — nothing we can block
	}
	return addr
}
// extractRequestURI extracts the request URI from an access log line.
// Format: ... "METHOD /path HTTP/1.1" ... → returns "/path"
// Returns the content between the first pair of quotes (the request line).
func extractRequestURI(line string) string {
	_, afterQuote, ok := strings.Cut(line, "\"")
	if !ok {
		return ""
	}
	requestLine, _, ok := strings.Cut(afterQuote, "\"")
	if !ok {
		return ""
	}
	return requestLine
}
// extractIPFromLogDaemon scans whitespace-separated fields for the first one
// shaped like a dotted-quad IPv4 address (starts with a digit, exactly three
// dots, at least 7 chars) and returns it with trailing punctuation stripped.
// Returns "" when no candidate field is found.
func extractIPFromLogDaemon(line string) string {
	for _, field := range strings.Fields(line) {
		if len(field) < 7 || field[0] < '0' || field[0] > '9' {
			continue
		}
		if strings.Count(field, ".") != 3 {
			continue
		}
		// Strip log punctuation that commonly trails the address.
		return strings.TrimRight(field, ",:;)([]")
	}
	return ""
}
package daemon
import (
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
)
// Real-time access log handler for detecting wp-login.php brute force and
// xmlrpc.php abuse. Watches the LiteSpeed/Apache Combined Log Format access
// log and tracks per-IP POST counts using a sliding time window.
//
// Emits the same check names as the periodic CheckWPBruteForce
// (wp_login_bruteforce, xmlrpc_abuse) so the existing auto-block pipeline
// handles them automatically.
// Tuning knobs for real-time access-log brute-force detection.
const (
	// Sliding window for counting requests per IP.
	accessLogWindow = 5 * time.Minute
	// Thresholds within the window. Lower than periodic checks because
	// we're watching in real-time and want fast response.
	accessLogWPLoginThreshold = 10 // POSTs to wp-login.php (also reused for admin panels)
	accessLogXMLRPCThreshold  = 15 // POSTs to xmlrpc.php
	// Eviction: how often to prune expired trackers.
	accessLogEvictInterval = 5 * time.Minute
	// Cooldown after an IP is flagged: don't re-alert for this long.
	// Prevents alert spam while auto-block processes the finding.
	accessLogBlockCooldown = 30 * time.Minute
)
// accessLogTracker tracks POST timestamps per endpoint for a single IP.
type accessLogTracker struct {
	mu                sync.Mutex
	wpLoginTimes      []time.Time // POST timestamps to wp-login.php within the window
	xmlrpcTimes       []time.Time // POST timestamps to xmlrpc.php within the window
	adminPanelTimes   []time.Time // POST timestamps to admin-panel login paths
	wpLoginAlerted    bool        // set when wp-login alert fired; cleared by eviction after cooldown
	xmlrpcAlerted     bool        // set when xmlrpc alert fired; cleared by eviction after cooldown
	adminPanelAlerted bool        // set when admin-panel alert fired; cleared by eviction after cooldown
	lastSeen          time.Time   // last event from this IP; drives eviction and cooldown reset
}

// accessLogTrackers holds per-IP state. sync.Map for concurrent handler access.
var accessLogTrackers sync.Map // key: IP string → value: *accessLogTracker
// discoverAccessLogPath returns the first access log path that exists,
// consulting the platform detector for OS/web-server specific candidates.
// Returns "" when none of the candidates is present on disk.
func discoverAccessLogPath() string {
	candidates := platform.Detect().AccessLogPaths
	for _, candidate := range candidates {
		if _, err := os.Stat(candidate); err != nil {
			continue
		}
		return candidate
	}
	return ""
}
// parseAccessLogBruteForce is the LogLineHandler for the Apache/LiteSpeed
// Combined Log Format access log. It parses each line, tracks per-IP POST
// counts to wp-login.php and xmlrpc.php, and emits findings when thresholds
// are crossed.
//
// The prune/threshold/flag sequence was previously triplicated verbatim for
// the three tiers; it now lives in bumpAccessLogTier so the three cannot
// drift apart. Behavior is unchanged.
func parseAccessLogBruteForce(line string, cfg *config.Config) []alert.Finding {
	// Fast reject: only care about POST requests to known attack targets.
	if !strings.Contains(line, "POST") {
		return nil
	}
	fields := strings.Fields(line)
	if len(fields) < 7 {
		return nil
	}
	ip := fields[0]
	// Skip infra IPs and loopback.
	if ip == "127.0.0.1" || ip == "::1" || isInfraIPDaemon(ip, cfg.InfraIPs) {
		return nil
	}
	// Field 5 is `"METHOD` in Combined Log Format; confirm it really is a POST.
	method := strings.Trim(fields[5], "\"")
	if method != "POST" {
		return nil
	}
	path := fields[6]
	isWPLogin := strings.Contains(path, "wp-login.php")
	isXMLRPC := strings.Contains(path, "xmlrpc.php")
	isAdminPanel := isAdminPanelPath(path)
	if !isWPLogin && !isXMLRPC && !isAdminPanel {
		return nil
	}
	now := time.Now()
	val, _ := accessLogTrackers.LoadOrStore(ip, &accessLogTracker{})
	tracker := val.(*accessLogTracker)
	tracker.mu.Lock()
	defer tracker.mu.Unlock()
	tracker.lastSeen = now
	cutoff := now.Add(-accessLogWindow)
	var results []alert.Finding
	// Once a per-tier alert has fired, skip the prune/append work until
	// cooldown clears the `alerted` flag. The slice would otherwise grow on
	// every event during a sustained burst (potentially tens of thousands of
	// entries for a 5-min window at 100 rps), wasting CPU on prune passes
	// whose result is never consumed: the `alerted` flag already prevents
	// re-alerts, and the eviction loop trims the slice on its own schedule.
	//
	// Safety: evictAccessLogState resets `alerted` once `lastSeen` is older
	// than `cooldownCutoff` (30 min of silence by default). By that point
	// the same eviction call has also pruned the slice to empty (window is
	// 5 min, so any remaining timestamp is far past cutoff), so the next
	// matching event correctly starts a fresh count from 1.
	if isWPLogin && !tracker.wpLoginAlerted {
		if n, fired := bumpAccessLogTier(&tracker.wpLoginTimes, &tracker.wpLoginAlerted, cutoff, now, accessLogWPLoginThreshold); fired {
			results = append(results, alert.Finding{
				Severity:  alert.Critical,
				Check:     "wp_login_bruteforce",
				Message:   fmt.Sprintf("WordPress login brute force from %s: %d POSTs in %v (real-time)", ip, n, accessLogWindow),
				Details:   "Real-time detection: high rate of POST requests to wp-login.php",
				Timestamp: now,
			})
		}
	}
	if isXMLRPC && !tracker.xmlrpcAlerted {
		if n, fired := bumpAccessLogTier(&tracker.xmlrpcTimes, &tracker.xmlrpcAlerted, cutoff, now, accessLogXMLRPCThreshold); fired {
			results = append(results, alert.Finding{
				Severity:  alert.Critical,
				Check:     "xmlrpc_abuse",
				Message:   fmt.Sprintf("XML-RPC abuse from %s: %d POSTs in %v (real-time)", ip, n, accessLogWindow),
				Details:   "Real-time detection: high rate of POST requests to xmlrpc.php (brute force or amplification)",
				Timestamp: now,
			})
		}
	}
	if isAdminPanel && !tracker.adminPanelAlerted {
		if n, fired := bumpAccessLogTier(&tracker.adminPanelTimes, &tracker.adminPanelAlerted, cutoff, now, accessLogWPLoginThreshold); fired {
			results = append(results, alert.Finding{
				Severity:  alert.Critical,
				Check:     "admin_panel_bruteforce",
				Message:   fmt.Sprintf("Admin panel brute force from %s: %d POSTs in %v (real-time)", ip, n, accessLogWindow),
				Details:   "Real-time detection: high rate of POST requests to common admin panel login paths (phpMyAdmin / Joomla)",
				Timestamp: now,
			})
		}
	}
	return results
}

// bumpAccessLogTier prunes expired timestamps for one endpoint tier, appends
// the current hit, and reports whether the alert threshold was crossed (in
// which case *alerted is set so the tier stays quiet until the eviction loop
// resets it after cooldown). Caller must hold the tracker's mutex. Returns
// the post-append count so the caller can include it in the alert message.
func bumpAccessLogTier(times *[]time.Time, alerted *bool, cutoff, now time.Time, threshold int) (int, bool) {
	*times = pruneAndAppend(*times, cutoff, now)
	n := len(*times)
	if n < threshold {
		return n, false
	}
	*alerted = true
	return n, true
}
// pruneAndAppend removes entries older than cutoff and appends now.
func pruneAndAppend(times []time.Time, cutoff, now time.Time) []time.Time {
recent := times[:0]
for _, t := range times {
if !t.Before(cutoff) {
recent = append(recent, t)
}
}
return append(recent, now)
}
// StartAccessLogEviction starts a background goroutine that prunes expired
// tracker entries to prevent unbounded memory growth. Same pattern as
// StartModSecEviction. The goroutine exits when stopCh is closed.
func StartAccessLogEviction(stopCh <-chan struct{}) {
	go func() {
		t := time.NewTicker(accessLogEvictInterval)
		defer t.Stop()
		for {
			select {
			case tick := <-t.C:
				evictAccessLogState(tick)
			case <-stopCh:
				return
			}
		}
	}()
}
// evictAccessLogState prunes expired timestamps from every tracker, clears
// alerted flags for IPs silent past the cooldown, and deletes trackers that
// are fully drained. Called periodically by StartAccessLogEviction.
func evictAccessLogState(now time.Time) {
	windowCutoff := now.Add(-accessLogWindow)
	cooldownCutoff := now.Add(-accessLogBlockCooldown)
	accessLogTrackers.Range(func(key, value any) bool {
		tr := value.(*accessLogTracker)
		tr.mu.Lock()
		// Drop timestamps that have aged out of the sliding window.
		tr.wpLoginTimes = pruneSlice(tr.wpLoginTimes, windowCutoff)
		tr.xmlrpcTimes = pruneSlice(tr.xmlrpcTimes, windowCutoff)
		tr.adminPanelTimes = pruneSlice(tr.adminPanelTimes, windowCutoff)
		// After a full cooldown of silence, allow the IP to be re-detected
		// if it comes back after the block expires.
		coolEnough := tr.lastSeen.Before(cooldownCutoff)
		if coolEnough {
			tr.wpLoginAlerted = false
			tr.xmlrpcAlerted = false
			tr.adminPanelAlerted = false
		}
		removable := coolEnough &&
			len(tr.wpLoginTimes) == 0 &&
			len(tr.xmlrpcTimes) == 0 &&
			len(tr.adminPanelTimes) == 0
		tr.mu.Unlock()
		if removable {
			accessLogTrackers.Delete(key)
		}
		return true
	})
}
// isAdminPanelPath returns true for high-confidence non-WP admin panel login
// paths suitable for hard-block auto-response. Drupal /user/login, Tomcat
// /manager/html, generic /admin/login.php, /mysql/ are intentionally EXCLUDED
// because they're either too generic (FP risk on shared hosting) or use a
// different attack shape (Basic auth vs. POST forms). See spec Component 5
// for the full rationale.
//
// Matching is case-insensitive: the previous list enumerated exactly two
// capitalizations of phpMyAdmin ("/phpmyadmin/", "/phpMyAdmin/") and would
// miss other casings attackers probe (/PMA/, /PhpMyAdmin/). Lowercasing the
// path once keeps the same path specificity while covering every casing —
// a strict superset of the old behavior.
func isAdminPanelPath(path string) bool {
	p := strings.ToLower(path)
	return strings.Contains(p, "/phpmyadmin/index.php") ||
		strings.Contains(p, "/pma/index.php") ||
		strings.Contains(p, "/administrator/index.php")
}
func pruneSlice(times []time.Time, cutoff time.Time) []time.Time {
recent := times[:0]
for _, t := range times {
if !t.Before(cutoff) {
recent = append(recent, t)
}
}
return recent
}
package daemon
import (
"fmt"
"net"
"os"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/store"
)
// Tuning knobs for per-mailbox geo-anomaly detection.
const (
	geoHistoryMaxAge  = 30 * 24 * 60 * 60 // 30 days in seconds (Unix-seconds arithmetic, not time.Duration)
	geoMinLoginCount  = 5                 // minimum logins before alerting on new country
	geoAlertCooldownH = 24                // hours between alerts per account
)
// parseDovecotLogLine handles Dovecot login lines from /var/log/maillog.
// It tracks per-mailbox login countries and alerts on new-country logins.
//
// Sequencing matters here: the geo history is updated and persisted BEFORE
// deciding whether to alert, so the store stays accurate even when the
// alert is suppressed by the min-login-count gate or the 24h rate limit.
func parseDovecotLogLine(line string, cfg *config.Config) []alert.Finding {
	// Only process successful login lines
	if !strings.Contains(line, "Login: user=<") {
		return nil
	}
	if !strings.Contains(line, "dovecot:") {
		return nil
	}
	user, ip := parseDovecotLoginFields(line)
	if user == "" || ip == "" {
		return nil
	}
	// Skip private/loopback IPs
	if isPrivateOrLoopback(ip) {
		return nil
	}
	// Skip infra IPs
	if isInfraIPDaemon(ip, cfg.InfraIPs) {
		return nil
	}
	// GeoIP lookup — silently skip when no database is available.
	db := getGeoIPDB()
	if db == nil {
		return nil
	}
	info := db.Lookup(ip)
	country := info.Country
	if country == "" {
		return nil
	}
	// Skip trusted countries (configured via suppressions).
	for _, tc := range cfg.Suppressions.TrustedCountries {
		if strings.EqualFold(country, tc) {
			return nil
		}
	}
	// Load history from bbolt
	boltDB := store.Global()
	if boltDB == nil {
		return nil
	}
	now := time.Now().Unix()
	history, _ := boltDB.GetGeoHistory(user)
	if history.Countries == nil {
		history.Countries = make(map[string]int64)
	}
	// Prune old country entries
	history.Countries = pruneOldCountries(history.Countries, now, geoHistoryMaxAge)
	// Increment login count
	history.LoginCount++
	// Check if this is a new country. Decided BEFORE the country is recorded
	// below; the min-login-count gate avoids alerting while the baseline is
	// still being learned.
	_, countryKnown := history.Countries[country]
	isNewCountry := !countryKnown && history.LoginCount >= geoMinLoginCount
	// Update country timestamp
	history.Countries[country] = now
	// Persist updated history — even when no alert fires, so the baseline grows.
	if err := boltDB.SetGeoHistory(user, history); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] Warning: failed to save geo history for %s: %v\n",
			time.Now().Format("2006-01-02 15:04:05"), user, err)
	}
	if !isNewCountry {
		return nil
	}
	// Rate limit: max 1 alert per account per 24h
	alertKey := "email:geo_alert:" + user
	lastAlertStr := boltDB.GetMetaString(alertKey)
	if lastAlertStr != "" {
		if lastAlert, err := time.Parse(time.RFC3339, lastAlertStr); err == nil {
			if time.Since(lastAlert) < time.Duration(geoAlertCooldownH)*time.Hour {
				return nil
			}
		}
	}
	// Record alert time
	_ = boltDB.SetMetaString(alertKey, time.Now().Format(time.RFC3339))
	// Build "previously seen" country list (everything except the new country).
	var knownCountries []string
	for c := range history.Countries {
		if c != country {
			knownCountries = append(knownCountries, c)
		}
	}
	previousList := "none"
	if len(knownCountries) > 0 {
		previousList = strings.Join(knownCountries, ", ")
	}
	// Prefer the human-readable name; fall back to the ISO code.
	countryName := info.CountryName
	if countryName == "" {
		countryName = country
	}
	return []alert.Finding{{
		Severity: alert.High,
		Check:    "email_suspicious_geo",
		Message: fmt.Sprintf("Suspicious email login for %s from %s (%s) - previously seen: %s",
			user, countryName, ip, previousList),
		Details: fmt.Sprintf("Country: %s (%s)\nIP: %s\nLogin count: %d\nPreviously seen countries: %s",
			country, countryName, ip, history.LoginCount, previousList),
	}}
}
// parseDovecotLoginFields extracts user and remote IP from a Dovecot login line.
// Expected format: "... Login: user=<user@domain>, ... rip=1.2.3.4, ..."
// Returns ("", "") when either field is missing or empty.
func parseDovecotLoginFields(line string) (user, ip string) {
	// user=<...> — take everything up to the closing '>'.
	_, afterUser, found := strings.Cut(line, "user=<")
	if !found {
		return "", ""
	}
	user, _, found = strings.Cut(afterUser, ">")
	if !found {
		return "", ""
	}
	// rip=... — terminated by comma or whitespace, or end of line.
	_, afterRip, found := strings.Cut(line, "rip=")
	if !found {
		return "", ""
	}
	if stop := strings.IndexAny(afterRip, ", \t\n"); stop >= 0 {
		ip = afterRip[:stop]
	} else {
		ip = afterRip
	}
	if user == "" || ip == "" {
		return "", ""
	}
	return user, ip
}
// isPrivateOrLoopback returns true if the IP is loopback or private
// (RFC 1918 for IPv4, RFC 4193 unique-local for IPv6). Unparseable input
// is treated as private so callers skip it rather than alert on garbage.
func isPrivateOrLoopback(ipStr string) bool {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return true // invalid = skip
	}
	// net.IP.IsPrivate covers 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 and
	// fc00::/7, replacing the hand-rolled RFC1918 CIDR checks — which also
	// silently missed private (ULA) IPv6 addresses.
	return ip.IsLoopback() || ip.IsPrivate()
}
// pruneOldCountries removes country entries older than maxAge seconds.
// Returns a fresh map containing only entries last seen at or after
// (now - maxAge); the input map is left untouched.
func pruneOldCountries(countries map[string]int64, now, maxAge int64) map[string]int64 {
	cutoff := now - maxAge
	kept := make(map[string]int64, len(countries))
	for country, lastSeen := range countries {
		if lastSeen < cutoff {
			continue
		}
		kept[country] = lastSeen
	}
	return kept
}
package daemon
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/platform"
"github.com/pidginhost/csm/internal/store"
)
// modsecIPCounter tracks deny timestamps for a single IP.
type modsecIPCounter struct {
	mu        sync.Mutex
	times     []time.Time // deny timestamps within the escalation window
	escalated bool        // set once escalation fires - prevents repeated findings per window
}

// Package-level realtime state shared by the ModSecurity handler and the
// eviction loop.
var (
	modsecDedup      sync.Map // key: "IP:ruleID" → value: time.Time
	modsecCSMCounter sync.Map // key: IP → value: *modsecIPCounter
)

// Tuning knobs for ModSecurity dedup and CSM-rule escalation.
const (
	modsecDedupTTL      = 60 * time.Second // suppress repeat IP+rule findings within this window
	modsecEscalationWin = 10 * time.Minute // sliding window for counting CSM denies per IP
	modsecEscalationHits = 3               // denies within the window that trigger escalation
	modsecEvictInterval = 10 * time.Minute // how often the eviction loop runs
)
// parseModSecLogLine parses a ModSecurity log line from Apache or LiteSpeed
// error logs and returns findings for blocked requests or warnings.
//
// Returns at most one finding per line. The finding's Details field is
// structured (Rule/Message/Hostname/URI/Raw) so downstream consumers can
// extract fields consistently regardless of the source log format.
func parseModSecLogLine(line string, cfg *config.Config) []alert.Finding {
	// Fast reject: not a ModSecurity line.
	if !strings.Contains(line, "ModSecurity:") && !strings.Contains(line, "[MODSEC]") {
		return nil
	}
	// "[MODSEC]" marks LiteSpeed's log format; Apache uses "ModSecurity:".
	isLiteSpeed := strings.Contains(line, "[MODSEC]")
	var ip, ruleID, msg, hostname, uri string
	if isLiteSpeed {
		ip = extractLiteSpeedIP(line)
		ruleID = extractModSecField(line, `[id "`, `"]`)
		msg = extractModSecField(line, `[msg "`, `"]`)
		hostname = extractModSecField(line, `[hostname "`, `"]`)
		uri = extractModSecField(line, `[uri "`, `"]`)
	} else {
		// Apache format: [client IP] or [client IP:port]
		raw := extractModSecField(line, "[client ", "]")
		// Strip port if present (Apache 2.4 uses "IP:port").
		if idx := strings.LastIndex(raw, ":"); idx > 0 {
			// Make sure it's not an IPv6 address (contains multiple colons).
			if strings.Count(raw, ":") == 1 {
				raw = raw[:idx]
			}
		}
		ip = raw
		ruleID = extractModSecField(line, `[id "`, `"]`)
		msg = extractModSecField(line, `[msg "`, `"]`)
		hostname = extractModSecField(line, `[hostname "`, `"]`)
		uri = extractModSecField(line, `[uri "`, `"]`)
	}
	// Skip infra and loopback IPs - consistent with other realtime handlers
	// (handlers.go:22, autoblock.go:149). Prevents noisy findings and
	// false escalation from proxied or locally forwarded Apache traffic.
	if ip != "" && (isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" || ip == "::1") {
		return nil
	}
	// Determine check name: "Access denied" (Apache) or "triggered!"
	// (LiteSpeed) means the request was actually blocked.
	check := "modsec_warning_realtime"
	if strings.Contains(line, "Access denied") {
		check = "modsec_block_realtime"
	} else if isLiteSpeed && strings.Contains(line, "triggered!") {
		check = "modsec_block_realtime"
	}
	// Determine severity from rule ID.
	// Individual blocks are informational - ModSecurity already denied the
	// request. Only the escalation finding (3+ from same IP) is CRITICAL
	// because it triggers auto-block at the firewall level.
	severity := alert.Warning
	if ruleNum, err := strconv.Atoi(ruleID); err == nil {
		switch {
		case ruleNum >= 900000 && ruleNum <= 900999:
			severity = alert.High // CSM custom rules - attack blocked, informational
		case ruleNum >= 910000:
			severity = alert.High // OWASP CRS
		}
	}
	// Build message from whichever fields were successfully extracted.
	message := fmt.Sprintf("ModSecurity rule %s", ruleID)
	if check == "modsec_block_realtime" {
		message = fmt.Sprintf("ModSecurity blocked request: rule %s", ruleID)
	}
	if ip != "" {
		message += fmt.Sprintf(" from %s", ip)
	}
	if hostname != "" {
		message += fmt.Sprintf(" on %s", hostname)
	}
	if uri != "" {
		message += fmt.Sprintf(" uri=%s", uri)
	}
	if msg != "" {
		message += fmt.Sprintf(" - %s", msg)
	}
	// Store structured details so the web UI can extract fields consistently
	// regardless of whether the source was Apache or LiteSpeed format.
	details := fmt.Sprintf("Rule: %s\nMessage: %s\nHostname: %s\nURI: %s\nRaw: %s",
		ruleID, msg, hostname, uri, truncateDaemon(line, 300))
	return []alert.Finding{{
		Severity: severity,
		Check:    check,
		Message:  message,
		Details:  details,
	}}
}
// extractModSecField extracts the value between start and end delimiters.
// Returns empty string if delimiters are not found.
func extractModSecField(line, start, end string) string {
	_, afterStart, ok := strings.Cut(line, start)
	if !ok {
		return ""
	}
	value, _, ok := strings.Cut(afterStart, end)
	if !ok {
		return ""
	}
	return value
}
// extractLiteSpeedIP extracts the client IP from a LiteSpeed log line.
// Format: [IP:PORT-CONN#VHOST] e.g. [122.9.114.57:41920-13#APVH_*_server.example.com]
//
// Scans every bracketed field left-to-right; the connection field is
// recognized by containing '#' (VHOST separator), ':' (port separator) and
// '-' (connection id). Returns "" when no candidate field yields a
// dotted-quad IP.
func extractLiteSpeedIP(line string) string {
	rest := line
	for {
		_, afterOpen, ok := strings.Cut(rest, "[")
		if !ok {
			return ""
		}
		field, afterClose, ok := strings.Cut(afterOpen, "]")
		if !ok {
			return ""
		}
		if strings.Contains(field, "#") && strings.Contains(field, "-") {
			// Everything before the first ':' is the IP candidate.
			if candidate, _, hasColon := strings.Cut(field, ":"); hasColon && candidate != "" {
				// Validate it looks like an IPv4 address (three dots).
				if strings.Count(candidate, ".") == 3 {
					return candidate
				}
			}
		}
		// Keep scanning from just past this field's closing bracket.
		rest = afterClose
	}
}
// parseModSecLogLineDeduped wraps parseModSecLogLine with dedup and CSM-rule
// threshold escalation. It is the handler registered with the log watcher.
//
// Order of operations (critical for correctness):
// 1. Parse the raw line.
// 2. ALWAYS increment the CSM escalation counter (even if dedup will suppress).
// 3. Then check dedup - suppress the base finding if a duplicate, but still
//    return any escalation finding from step 2.
func parseModSecLogLineDeduped(line string, cfg *config.Config) []alert.Finding {
	raw := parseModSecLogLine(line, cfg)
	if len(raw) == 0 {
		return nil
	}
	// parseModSecLogLine returns at most one finding per line.
	f := raw[0]
	now := time.Now()
	var results []alert.Finding
	// --- Step 1: CSM escalation (before dedup) ---
	// Extract IP and rule ID directly from the raw log line - NOT from the
	// finding message, which could be manipulated via log injection.
	ip := extractModSecField(line, "[client ", "]")
	if ip == "" {
		ip = extractLiteSpeedIP(line)
	}
	// Strip port from Apache 2.4 format (IP:port); a single colon rules out
	// IPv6, which always contains more than one.
	if strings.Count(ip, ":") == 1 {
		if idx := strings.LastIndex(ip, ":"); idx > 0 {
			ip = ip[:idx]
		}
	}
	ruleID := extractModSecField(line, `[id "`, `"]`)
	// Atoi failure leaves ruleNum 0, which falls outside every range below.
	ruleNum, _ := strconv.Atoi(ruleID)
	isCSM := f.Check == "modsec_block_realtime" && ruleNum >= 900000 && ruleNum <= 900999
	// Record hit for per-rule stats (24h hourly buckets)
	if ruleNum >= 900000 && ruleNum <= 900999 {
		if sdb := store.Global(); sdb != nil {
			sdb.IncrModSecRuleHit(ruleNum, now)
		}
	}
	// Check if this rule is excluded from auto-block escalation.
	// Configurable via the Rules page in the web UI.
	noEscalate := false
	if db := store.Global(); db != nil {
		noEscalate = db.GetModSecNoEscalateRules()[ruleNum]
	}
	if isCSM && ip != "" && !noEscalate {
		// recordCSMDeny returns true only the first time the threshold is
		// crossed within the window, so escalation fires at most once.
		if recordCSMDeny(ip, now) {
			results = append(results, alert.Finding{
				Severity: alert.Critical,
				Check:    "modsec_csm_block_escalation",
				Message:  fmt.Sprintf("CSM rule escalation: %d+ denies from %s within %v", modsecEscalationHits, ip, modsecEscalationWin),
				Details:  truncateDaemon(line, 400),
			})
		}
	}
	// --- Step 2: Dedup ---
	// Suppress repeat findings for the same IP+rule within modsecDedupTTL.
	dedupKey := ip + ":" + ruleID
	if prev, loaded := modsecDedup.Load(dedupKey); loaded {
		if now.Sub(prev.(time.Time)) < modsecDedupTTL {
			// Suppress the base finding but still return any escalation.
			if len(results) > 0 {
				return results
			}
			return nil
		}
	}
	modsecDedup.Store(dedupKey, now)
	results = append(results, f)
	return results
}
// recordCSMDeny records a deny event for the given IP and returns true if the
// escalation threshold has been reached (>= modsecEscalationHits within the
// escalation window). Fires at most once per window: the escalated flag
// stays set until the eviction loop clears it.
func recordCSMDeny(ip string, now time.Time) bool {
	val, _ := modsecCSMCounter.LoadOrStore(ip, &modsecIPCounter{})
	ctr := val.(*modsecIPCounter)
	ctr.mu.Lock()
	defer ctr.mu.Unlock()
	// In-place filter: keep only timestamps inside the escalation window,
	// then record this deny.
	windowStart := now.Add(-modsecEscalationWin)
	kept := ctr.times[:0]
	for _, ts := range ctr.times {
		if ts.Before(windowStart) {
			continue
		}
		kept = append(kept, ts)
	}
	ctr.times = append(kept, now)
	if ctr.escalated || len(ctr.times) < modsecEscalationHits {
		return false
	}
	ctr.escalated = true
	return true
}
// StartModSecEviction starts a background goroutine that prunes expired dedup
// and counter entries every modsecEvictInterval to prevent unbounded memory
// growth. It returns when stopCh is closed.
func StartModSecEviction(stopCh <-chan struct{}) {
	go func() {
		t := time.NewTicker(modsecEvictInterval)
		defer t.Stop()
		for {
			select {
			case tick := <-t.C:
				evictModSecState(tick)
			case <-stopCh:
				return
			}
		}
	}()
}
// discoverModSecLogPath returns the path to the web server error log that
// CSM should tail for ModSecurity denies. Config override wins, then the
// first candidate from platform detection that actually exists.
func discoverModSecLogPath(cfg *config.Config) string {
	if override := cfg.ModSecErrorLog; override != "" {
		return override
	}
	info := platform.Detect()
	return firstExistingPath(info.ErrorLogPaths)
}
// firstExistingPath returns the first path in the list that exists on disk,
// or "" if none do. Pure function so tests can exercise it directly.
func firstExistingPath(candidates []string) string {
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p
}
}
return ""
}
// evictModSecState prunes expired entries from modsecDedup and modsecCSMCounter.
func evictModSecState(now time.Time) {
	// Drop dedup entries whose TTL has elapsed.
	modsecDedup.Range(func(key, value any) bool {
		if now.Sub(value.(time.Time)) >= modsecDedupTTL {
			modsecDedup.Delete(key)
		}
		return true
	})
	// Prune per-IP counters back to the escalation window.
	windowStart := now.Add(-modsecEscalationWin)
	modsecCSMCounter.Range(func(key, value any) bool {
		ctr := value.(*modsecIPCounter)
		ctr.mu.Lock()
		kept := ctr.times[:0]
		for _, ts := range ctr.times {
			if ts.Before(windowStart) {
				continue
			}
			kept = append(kept, ts)
		}
		ctr.times = kept
		if len(kept) < modsecEscalationHits {
			ctr.escalated = false // reset cooldown when counter drops below threshold
		}
		drop := len(kept) == 0
		ctr.mu.Unlock()
		if drop {
			modsecCSMCounter.Delete(key)
		}
		return true
	})
}
package daemon
import (
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// mailIPEntry tracks failed-auth timestamps and suppression state for one IP.
type mailIPEntry struct {
	times      []time.Time // failure timestamps within the window
	suppressed time.Time   // no new finding for this IP until this time
	lastSeen   time.Time   // last failure; drives Purge eviction
}

// mailSubnetEntry tracks unique attacker IPs within a /24.
type mailSubnetEntry struct {
	ips        map[string]time.Time // attacker IP → last failure time
	suppressed time.Time            // no new finding for this subnet until this time
	lastSeen   time.Time            // last failure from any IP in the /24
}

// mailAccountEntry tracks unique attacker IPs per mailbox, plus a separate
// suppression clock for compromise findings emitted by RecordSuccess.
type mailAccountEntry struct {
	ips                  map[string]time.Time // attacker IP → last failure time
	suppressed           time.Time            // spray-finding suppression clock
	compromiseSuppressed time.Time            // compromise-finding suppression clock (RecordSuccess)
	lastSeen             time.Time            // last failure targeting this mailbox
}
// mailAuthTracker aggregates dovecot IMAP/POP3/ManageSieve auth events into
// four detection signals: per-IP brute force, per-/24 password spray,
// per-mailbox account spray, and per-account compromise (success after
// recent failures).
//
// Thread-safe; Record/RecordSuccess may be called concurrently from multiple
// log readers.
type mailAuthTracker struct {
	mu                    sync.Mutex
	perIPThreshold        int           // failures from one IP that trigger mail_bruteforce
	subnetThreshold       int           // unique IPs in a /24 that trigger mail_subnet_spray
	accountSprayThreshold int           // unique IPs per mailbox that trigger mail_account_spray
	window                time.Duration // sliding window for all counters
	suppression           time.Duration // per-signal quiet period after a finding fires
	maxTracked            int           // cap on tracked IPs; enforced by enforceMaxTracked
	now                   func() time.Time // injectable clock for deterministic tests
	ips                   map[string]*mailIPEntry
	subnets               map[string]*mailSubnetEntry
	accounts              map[string]*mailAccountEntry
}
// newMailAuthTracker constructs a tracker. `now` is injected so tests can
// use deterministic clocks; pass `time.Now` in production. A nil `now`
// falls back to time.Now.
func newMailAuthTracker(
	perIPThreshold int,
	subnetThreshold int,
	accountSprayThreshold int,
	window time.Duration,
	suppression time.Duration,
	maxTracked int,
	now func() time.Time,
) *mailAuthTracker {
	clock := now
	if clock == nil {
		clock = time.Now
	}
	t := &mailAuthTracker{
		perIPThreshold:        perIPThreshold,
		subnetThreshold:       subnetThreshold,
		accountSprayThreshold: accountSprayThreshold,
		window:                window,
		suppression:           suppression,
		maxTracked:            maxTracked,
		now:                   clock,
	}
	t.ips = make(map[string]*mailIPEntry)
	t.subnets = make(map[string]*mailSubnetEntry)
	t.accounts = make(map[string]*mailAccountEntry)
	return t
}
// Size returns the total number of tracked entities (IPs + subnets + accounts).
func (t *mailAuthTracker) Size() int {
	t.mu.Lock()
	defer t.mu.Unlock()
	total := len(t.ips)
	total += len(t.subnets)
	total += len(t.accounts)
	return total
}
// Record processes one dovecot IMAP/POP3/ManageSieve auth-failure observation.
// Returns zero or more findings that callers should append.
//
// ip MUST be non-private, non-loopback, and non-infra — callers enforce this
// before invoking Record.
//
// Each of the three signals (per-IP, per-/24, per-account) has its own
// sliding-window counter and its own suppression clock: once a signal fires,
// it stays quiet for t.suppression before it may fire again.
func (t *mailAuthTracker) Record(ip, account string) []alert.Finding {
	if ip == "" {
		return nil
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	now := t.now()
	cutoff := now.Add(-t.window)
	var findings []alert.Finding
	// --- Per-IP tracker ---
	e, ok := t.ips[ip]
	if !ok {
		e = &mailIPEntry{}
		t.ips[ip] = e
	}
	e.times = pruneTimes(e.times, cutoff)
	e.times = append(e.times, now)
	e.lastSeen = now
	// Fire when the windowed failure count hits the threshold and this IP's
	// suppression clock has expired.
	if len(e.times) >= t.perIPThreshold && !now.Before(e.suppressed) {
		e.suppressed = now.Add(t.suppression)
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "mail_bruteforce",
			Message: fmt.Sprintf("Mail auth brute force from %s: %d failed auths in %v",
				ip, len(e.times), t.window),
			Details:   "Real-time detection of dovecot imap/pop3/managesieve auth failures",
			Timestamp: now,
		})
	}
	// --- Per-/24 subnet tracker (IPv4 only) ---
	if prefix := extractPrefix24Daemon(ip); prefix != "" {
		s, ok := t.subnets[prefix]
		if !ok {
			s = &mailSubnetEntry{ips: make(map[string]time.Time)}
			t.subnets[prefix] = s
		}
		// Expire IPs whose last failure fell out of the window.
		for ipKey, ts := range s.ips {
			if ts.Before(cutoff) {
				delete(s.ips, ipKey)
			}
		}
		s.ips[ip] = now
		s.lastSeen = now
		if len(s.ips) >= t.subnetThreshold && !now.Before(s.suppressed) {
			s.suppressed = now.Add(t.suppression)
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "mail_subnet_spray",
				Message: fmt.Sprintf("Mail password spray from %s.0/24: %d unique IPs in %v",
					prefix, len(s.ips), t.window),
				Details:   "Real-time detection of mail auth failures from many IPs in one /24",
				Timestamp: now,
			})
		}
	}
	// --- Per-account spray tracker ---
	if account != "" {
		a, ok := t.accounts[account]
		if !ok {
			a = &mailAccountEntry{ips: make(map[string]time.Time)}
			t.accounts[account] = a
		}
		// Expire IPs whose last failure fell out of the window.
		for ipKey, ts := range a.ips {
			if ts.Before(cutoff) {
				delete(a.ips, ipKey)
			}
		}
		a.ips[ip] = now
		a.lastSeen = now
		if len(a.ips) >= t.accountSprayThreshold && !now.Before(a.suppressed) {
			a.suppressed = now.Add(t.suppression)
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "mail_account_spray",
				Message: fmt.Sprintf("Mail password spray targeting %s: %d unique IPs in %v",
					account, len(a.ips), t.window),
				Details:   "Distributed login attempts across many IPs against one mailbox (visibility only — no auto-block).",
				Timestamp: now,
			})
		}
	}
	// Cap memory: evict least-recently-seen entries past the limit.
	t.enforceMaxTracked()
	return findings
}
// RecordSuccess processes a successful mail login. Emits mail_account_compromised
// when the successful IP has recent failed auths for the same account — a
// zero-FP compromise signal: the attacker literally failed N times from that IP
// for that mailbox, then guessed the password.
//
// ip and account MUST both be non-empty. Caller filters infra/private/loopback
// IPs before invoking.
func (t *mailAuthTracker) RecordSuccess(ip, account string) []alert.Finding {
	if ip == "" || account == "" {
		return nil
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	entry := t.accounts[account]
	if entry == nil {
		// No failure history for this mailbox at all.
		return nil
	}
	if _, sawFailure := entry.ips[ip]; !sawFailure {
		// Success from an IP with no recent failures — benign login.
		return nil
	}
	now := t.now()
	if now.Before(entry.compromiseSuppressed) {
		return nil
	}
	entry.compromiseSuppressed = now.Add(t.suppression)
	finding := alert.Finding{
		Severity: alert.Critical,
		Check:    "mail_account_compromised",
		Message: fmt.Sprintf("Mail account compromise: successful login for %s from %s after recent auth failures",
			account, ip),
		Details:   "Attacker succeeded after one or more failed attempts from the same IP for this mailbox. Rotate password and revoke sessions.",
		Timestamp: now,
	}
	return []alert.Finding{finding}
}
// Purge removes stale entries older than (window + suppression).
// Called from a background goroutine every minute.
func (t *mailAuthTracker) Purge() {
	t.mu.Lock()
	defer t.mu.Unlock()
	now := t.now()
	activityCutoff := now.Add(-(t.window + t.suppression))
	windowCutoff := now.Add(-t.window)
	// Drop stale per-IP state; deleting during range is safe in Go.
	for key, entry := range t.ips {
		entry.times = pruneTimes(entry.times, windowCutoff)
		if len(entry.times) == 0 && !entry.lastSeen.After(activityCutoff) {
			delete(t.ips, key)
		}
	}
	// Shared pruner for the two ip-set shaped maps.
	dropStale := func(ips map[string]time.Time) {
		for ip, seen := range ips {
			if seen.Before(windowCutoff) {
				delete(ips, ip)
			}
		}
	}
	for key, sub := range t.subnets {
		dropStale(sub.ips)
		if len(sub.ips) == 0 && !sub.lastSeen.After(activityCutoff) {
			delete(t.subnets, key)
		}
	}
	for key, acct := range t.accounts {
		dropStale(acct.ips)
		if len(acct.ips) == 0 && !acct.lastSeen.After(activityCutoff) {
			delete(t.accounts, key)
		}
	}
}
// enforceMaxTracked evicts the least-recently-seen entries until the combined
// entity count (IPs + subnets + accounts) is <= 95% of maxTracked. The batch
// target avoids re-sorting on every subsequent insert. Caller must hold t.mu.
func (t *mailAuthTracker) enforceMaxTracked() {
	total := len(t.ips) + len(t.subnets) + len(t.accounts)
	if total <= t.maxTracked {
		return
	}
	// Evict down to 95% of the cap so the next few inserts are free.
	target := t.maxTracked * 95 / 100
	type candidate struct {
		kind string // "ip" | "subnet" | "account"
		key  string
		seen time.Time
	}
	candidates := make([]candidate, 0, total)
	for key, entry := range t.ips {
		candidates = append(candidates, candidate{"ip", key, entry.lastSeen})
	}
	for key, entry := range t.subnets {
		candidates = append(candidates, candidate{"subnet", key, entry.lastSeen})
	}
	for key, entry := range t.accounts {
		candidates = append(candidates, candidate{"account", key, entry.lastSeen})
	}
	// Oldest activity first — classic LRU eviction order.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].seen.Before(candidates[j].seen)
	})
	for _, c := range candidates {
		if len(t.ips)+len(t.subnets)+len(t.accounts) <= target {
			break
		}
		switch c.kind {
		case "ip":
			delete(t.ips, c.key)
		case "subnet":
			delete(t.subnets, c.key)
		case "account":
			delete(t.accounts, c.key)
		}
	}
}
// isMailAuthLine returns true for dovecot imap/pop3/managesieve login events.
func isMailAuthLine(line string) bool {
	if !strings.Contains(line, "dovecot:") {
		return false
	}
	for _, marker := range []string{"imap-login:", "pop3-login:", "managesieve-login:"} {
		if strings.Contains(line, marker) {
			return true
		}
	}
	return false
}
// extractMailLoginEvent parses a dovecot login line and returns
// (ip, account, success). Returns empty strings and false on parse failure.
//
// Real dovecot wire format (validated against production logs):
//
//	Success: "imap-login: Logged in: user=<alice@x.ro>, method=PLAIN, rip=..."
//	Failure: "imap-login: Login aborted: ... (auth failed, N attempts ...): user=<...>, method=..., rip=..."
//
// The success marker is "Logged in" (NOT "Login:" — an earlier version of
// this parser used the wrong marker and silently skipped every successful
// login, which in turn broke RecordSuccess-based compromise detection).
// The failure marker is "(auth failed" with the opening paren, which
// distinguishes real auth failures from Login-aborted reasons like
// "no auth attempts" or TLS handshake errors.
func extractMailLoginEvent(line string) (ip, account string, success bool) {
	switch {
	case strings.Contains(line, "-login: Logged in"):
		success = true
	case strings.Contains(line, "(auth failed"):
		// Genuine auth failure: success stays false, parsing proceeds.
	default:
		return "", "", false
	}
	// user=<...>: take everything up to the closing angle bracket.
	const userTag = "user=<"
	if start := strings.Index(line, userTag); start >= 0 {
		tail := line[start+len(userTag):]
		if end := strings.IndexByte(tail, '>'); end >= 0 {
			account = tail[:end]
		}
	}
	// rip=...: value ends at the first comma, space, or newline.
	const ripTag = "rip="
	if start := strings.Index(line, ripTag); start >= 0 {
		tail := line[start+len(ripTag):]
		stop := strings.IndexAny(tail, ", \n")
		if stop < 0 {
			stop = len(tail)
		}
		ip = tail[:stop]
	}
	return ip, account, success
}
package daemon
import (
"bufio"
"fmt"
"net"
"os"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// pamSocketPath is the Unix socket the pam_csm.so module writes auth events to.
const pamSocketPath = "/var/run/csm/pam.sock"

// PAMListener listens on a Unix socket for authentication events from the
// pam_csm.so PAM module. Tracks failures per IP and triggers CSM auto-blocking.
type PAMListener struct {
	cfg      *config.Config
	alertCh  chan<- alert.Finding // findings delivered to the daemon's alert pipeline
	listener net.Listener         // Unix-domain listener bound to pamSocketPath
	mu       sync.Mutex           // guards failures
	failures map[string]*pamFailureTracker // source IP -> rolling failure state
}

// pamFailureTracker accumulates auth failures from a single source IP
// within the configured time window.
type pamFailureTracker struct {
	count     int             // failures in the current window
	firstSeen time.Time       // start of the current window
	lastSeen  time.Time       // most recent failure; used by cleanupLoop
	users     map[string]bool // distinct usernames targeted
	services  map[string]bool // distinct PAM services involved
	blocked   bool            // set once the threshold alert has fired (dedup)
}
// NewPAMListener creates a Unix socket listener for PAM events.
// It (re)creates the socket directory, removes any stale socket from a
// previous run, and restricts the socket to root.
func NewPAMListener(cfg *config.Config, alertCh chan<- alert.Finding) (*PAMListener, error) {
	// Make sure the directory holding the socket exists.
	if err := os.MkdirAll("/var/run/csm", 0750); err != nil {
		return nil, fmt.Errorf("creating socket dir: %w", err)
	}
	// Clear out any stale socket left behind by an unclean shutdown.
	os.Remove(pamSocketPath)
	ln, err := net.Listen("unix", pamSocketPath)
	if err != nil {
		return nil, fmt.Errorf("listening on %s: %w", pamSocketPath, err)
	}
	// The PAM module runs inside privileged auth stacks, so the socket can stay
	// root-only instead of accepting arbitrary local writers.
	_ = os.Chmod(pamSocketPath, 0600)
	p := &PAMListener{
		cfg:      cfg,
		alertCh:  alertCh,
		listener: ln,
		failures: make(map[string]*pamFailureTracker),
	}
	return p, nil
}
// Run accepts connections and processes PAM events. Blocks until stopCh
// closes; the cleanup loop and the accept loop each run in their own
// goroutine.
func (p *PAMListener) Run(stopCh <-chan struct{}) {
	// Start cleanup goroutine to expire old failure records
	go p.cleanupLoop(stopCh)
	// Accept connections
	go func() {
		for {
			conn, err := p.listener.Accept()
			if err != nil {
				// Distinguish shutdown from a transient accept error: Stop()
				// closes the listener, which makes Accept fail — if stopCh is
				// also closed we exit quietly; otherwise log, back off briefly,
				// and retry.
				select {
				case <-stopCh:
					return
				default:
					fmt.Fprintf(os.Stderr, "[%s] PAM listener accept error: %v\n", ts(), err)
					time.Sleep(100 * time.Millisecond)
					continue
				}
			}
			// One goroutine per connection; connections are short-lived
			// (handleConnection sets a 1s deadline).
			go p.handleConnection(conn)
		}
	}()
	<-stopCh
}
// Stop closes the listener and removes the socket file.
func (p *PAMListener) Stop() {
	// Remove the socket path after the listener is closed.
	defer os.Remove(pamSocketPath)
	_ = p.listener.Close()
}
// handleConnection reads newline-delimited PAM events from one connection.
// Untrusted (non-root) peers are dropped immediately; a short deadline keeps
// connections from lingering.
func (p *PAMListener) handleConnection(conn net.Conn) {
	defer func() { _ = conn.Close() }()
	if !isTrustedPAMPeer(conn) {
		return
	}
	_ = conn.SetDeadline(time.Now().Add(1 * time.Second))
	sc := bufio.NewScanner(conn)
	for sc.Scan() {
		p.processEvent(sc.Text())
	}
}
// processEvent handles a single PAM event line.
// Format: FAIL ip=1.2.3.4 user=root service=sshd
//
//	OK ip=1.2.3.4 user=root service=sshd
//
// FAIL events feed the per-IP brute-force tracker; OK events clear the
// tracker for that IP and emit an informational finding (any successful
// login from a non-infra IP is reported).
func (p *PAMListener) processEvent(line string) {
	parts := strings.SplitN(strings.TrimSpace(line), " ", 2)
	if len(parts) < 2 {
		return
	}
	eventType := parts[0]
	kvPart := parts[1]
	var ip, user, service string
	for _, kv := range strings.Fields(kvPart) {
		switch {
		case strings.HasPrefix(kv, "ip="):
			ip = kv[3:]
		case strings.HasPrefix(kv, "user="):
			user = kv[5:]
		case strings.HasPrefix(kv, "service="):
			service = kv[8:]
		}
	}
	// Skip events with no usable source address. "::1" (IPv6 loopback) is
	// excluded alongside 127.0.0.1 so local IPv6 logins don't raise alerts.
	if ip == "" || ip == "-" || ip == "127.0.0.1" || ip == "::1" {
		return
	}
	// Skip infra IPs
	if isInfraIP(ip, p.cfg.InfraIPs) {
		return
	}
	switch eventType {
	case "FAIL":
		p.recordFailure(ip, user, service)
	case "OK":
		p.clearFailures(ip)
		// Successful login from non-infra IP - informational alert
		p.alertCh <- alert.Finding{
			Severity:  alert.High,
			Check:     "pam_login",
			Message:   fmt.Sprintf("Login success from non-infra IP: %s (user: %s, service: %s)", ip, user, service),
			Timestamp: time.Now(),
		}
	}
}
// recordFailure registers one PAM auth failure for ip and fires a single
// Critical "pam_bruteforce" finding once the failure count crosses the
// configured threshold inside the rolling window. Thread-safe.
//
// NOTE(review): the finding is sent on p.alertCh while p.mu is held; if the
// channel ever fills, all other PAM event processing blocks behind the
// mutex — confirm alertCh is buffered/drained fast enough.
func (p *PAMListener) recordFailure(ip, user, service string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	tracker, exists := p.failures[ip]
	if !exists {
		tracker = &pamFailureTracker{
			firstSeen: time.Now(),
			users:     make(map[string]bool),
			services:  make(map[string]bool),
		}
		p.failures[ip] = tracker
	}
	tracker.count++
	tracker.lastSeen = time.Now()
	tracker.users[user] = true
	tracker.services[service] = true
	// Check threshold: defaults of 5 failures in 10 minutes, overridable via
	// the MultiIPLogin* config knobs (zero values fall back to the defaults).
	threshold := 5
	windowMin := 10
	if p.cfg.Thresholds.MultiIPLoginThreshold > 0 {
		threshold = p.cfg.Thresholds.MultiIPLoginThreshold
	}
	if p.cfg.Thresholds.MultiIPLoginWindowMin > 0 {
		windowMin = p.cfg.Thresholds.MultiIPLoginWindowMin
	}
	// Only block if within the time window
	if time.Since(tracker.firstSeen) > time.Duration(windowMin)*time.Minute {
		// Window expired - reset tracker: the current failure becomes the
		// first of a new window, and blocked is cleared so a future burst
		// can alert again.
		tracker.count = 1
		tracker.firstSeen = time.Now()
		tracker.users = map[string]bool{user: true}
		tracker.services = map[string]bool{service: true}
		tracker.blocked = false
		return
	}
	// blocked dedups the alert: exactly one finding per window crossing.
	if tracker.count >= threshold && !tracker.blocked {
		tracker.blocked = true
		// Build user/service lists for details
		var users, services []string
		for u := range tracker.users {
			users = append(users, u)
		}
		for s := range tracker.services {
			services = append(services, s)
		}
		p.alertCh <- alert.Finding{
			Severity: alert.Critical,
			Check:    "pam_bruteforce",
			Message:  fmt.Sprintf("PAM brute-force detected: %s (%d failures in %ds)", ip, tracker.count, int(time.Since(tracker.firstSeen).Seconds())),
			Details: fmt.Sprintf("Users targeted: %s\nServices: %s",
				strings.Join(users, ", "), strings.Join(services, ", ")),
			Timestamp: time.Now(),
		}
	}
}
// clearFailures forgets all recorded failures for ip (called on a
// successful login from that IP).
func (p *PAMListener) clearFailures(ip string) {
	p.mu.Lock()
	delete(p.failures, ip)
	p.mu.Unlock()
}
// cleanupLoop removes expired failure trackers every minute until stopCh closes.
func (p *PAMListener) cleanupLoop(stopCh <-chan struct{}) {
	tick := time.NewTicker(1 * time.Minute)
	defer tick.Stop()
	for {
		select {
		case <-stopCh:
			return
		case <-tick.C:
			p.dropExpiredFailures()
		}
	}
}

// dropExpiredFailures deletes trackers whose last failure is older than 30 minutes.
func (p *PAMListener) dropExpiredFailures() {
	p.mu.Lock()
	defer p.mu.Unlock()
	cutoff := time.Now().Add(-30 * time.Minute)
	for ip, tracker := range p.failures {
		if tracker.lastSeen.Before(cutoff) {
			delete(p.failures, ip)
		}
	}
}
// isInfraIP reports whether ip falls inside any of the configured infra CIDRs.
// Unparseable IPs and unparseable CIDRs are treated as non-matching.
// Duplicated here to avoid import cycle with checks package.
func isInfraIP(ip string, infraNets []string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	for _, cidr := range infraNets {
		if _, network, err := net.ParseCIDR(cidr); err == nil && network.Contains(addr) {
			return true
		}
	}
	return false
}
//go:build linux
package daemon
import (
"net"
"golang.org/x/sys/unix"
)
// isTrustedPAMPeer reports whether the connecting peer is root on the local
// host, verified via SO_PEERCRED on the Unix socket. Any non-Unix connection
// or credential-lookup failure is rejected.
func isTrustedPAMPeer(conn net.Conn) bool {
	uc, ok := conn.(*net.UnixConn)
	if !ok {
		return false
	}
	raw, err := uc.SyscallConn()
	if err != nil {
		return false
	}
	rootPeer := false
	ctrlErr := raw.Control(func(fd uintptr) {
		// #nosec G115 -- socket fd from net.Conn.SyscallConn; POSIX fd fits in int.
		cred, credErr := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
		if credErr == nil && cred != nil && cred.Uid == 0 {
			rootPeer = true
		}
	})
	if ctrlErr != nil {
		return false
	}
	return rootPeer
}
package daemon
import (
"fmt"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
)
// PasswordHijackDetector tracks WHM password changes from non-infra IPs
// and correlates them with subsequent cPanel logins to detect the attack
// pattern: attacker changes password via WHM → immediately logs in.
//
// Legitimate flow (excluded):
//
//	Portal (infra IP) changes password via xml-api → user logs in
//
// Attack flow (detected):
//
//	Attacker (non-infra IP) changes password via whostmgr → logs in within 60s
type PasswordHijackDetector struct {
	mu            sync.Mutex
	recentChanges map[string]*passwordChange // account -> change info
	cfg           *config.Config
	alertCh       chan<- alert.Finding
}

// passwordChange is one pending non-infra WHM password change awaiting
// correlation with a subsequent login.
type passwordChange struct {
	account   string
	ip        string    // IP the change came from
	timestamp time.Time // when the change was observed
}

// hijackWindow is the time window to correlate password change + login.
// NOTE(review): the comment above says "within 60s" but the window is 120s —
// confirm which value is intended.
const hijackWindow = 120 * time.Second
// NewPasswordHijackDetector creates a new detector.
func NewPasswordHijackDetector(cfg *config.Config, alertCh chan<- alert.Finding) *PasswordHijackDetector {
	d := &PasswordHijackDetector{
		recentChanges: map[string]*passwordChange{},
		cfg:           cfg,
		alertCh:       alertCh,
	}
	return d
}
// HandlePasswordChange records a WHM password change from a non-infra IP and
// raises a Critical finding immediately: a non-infra WHM password change is
// always suspicious (legitimate portal/admin changes arrive from infra IPs,
// 127.0.0.1, or "internal").
func (d *PasswordHijackDetector) HandlePasswordChange(account, ip string) {
	if isInfraIPDaemon(ip, d.cfg.InfraIPs) || ip == "127.0.0.1" || ip == "internal" {
		return // legitimate - portal or admin action
	}
	// Single timestamp so the stored record and the finding agree exactly.
	now := time.Now()
	d.mu.Lock()
	d.recentChanges[account] = &passwordChange{
		account:   account,
		ip:        ip,
		timestamp: now,
	}
	d.mu.Unlock()
	// Alert on the password change itself - non-infra WHM password change is
	// always suspicious. The send happens AFTER releasing d.mu: a full or
	// unbuffered alertCh would otherwise stall every other detector method
	// behind the mutex.
	d.alertCh <- alert.Finding{
		Severity:  alert.Critical,
		Check:     "whm_password_change_noninfra",
		Message:   fmt.Sprintf("WHM password change from non-infra IP: %s (account: %s)", ip, account),
		Details:   "Password was changed via WHM from an IP outside your infrastructure. This is a strong indicator of account takeover.",
		Timestamp: now,
	}
}
// HandleLogin checks if a cPanel login matches a recent non-infra password
// change for the same account and, if it happened inside hijackWindow, emits
// the confirmed-hijack finding. The pending change record is consumed
// (deleted) whether or not it is still within the window.
func (d *PasswordHijackDetector) HandleLogin(account, loginIP string) {
	if isInfraIPDaemon(loginIP, d.cfg.InfraIPs) {
		return
	}
	d.mu.Lock()
	change, exists := d.recentChanges[account]
	if exists {
		delete(d.recentChanges, account)
	}
	d.mu.Unlock()
	if !exists {
		return
	}
	// Measure the elapsed time once so the window check, the Message, and
	// the Details all report the same value (the original code called
	// time.Since three times and could print inconsistent numbers).
	elapsed := time.Since(change.timestamp)
	if elapsed > hijackWindow {
		return
	}
	secs := int(elapsed.Seconds())
	// CONFIRMED ATTACK: password changed from non-infra IP, login within 120s
	d.alertCh <- alert.Finding{
		Severity:  alert.Critical,
		Check:     "password_hijack_confirmed",
		Message:   fmt.Sprintf("CONFIRMED ACCOUNT HIJACK: %s - password changed from %s, login from %s within %ds", account, change.ip, loginIP, secs),
		Details:   fmt.Sprintf("Attack pattern: WHM password change from non-infra IP followed by immediate cPanel login.\nPassword change IP: %s\nLogin IP: %s\nTime between: %ds\n\nBoth IPs should be permanently blocked.", change.ip, loginIP, secs),
		Timestamp: time.Now(),
	}
}
// Cleanup removes expired entries.
func (d *PasswordHijackDetector) Cleanup() {
	d.mu.Lock()
	defer d.mu.Unlock()
	// Keep records for twice the correlation window before discarding.
	expiry := 2 * hijackWindow
	for account, change := range d.recentChanges {
		if time.Since(change.timestamp) > expiry {
			delete(d.recentChanges, account)
		}
	}
}
// ParseSessionLineForHijack extracts password change and login events
// from session log lines and feeds them to the detector.
func ParseSessionLineForHijack(line string, detector *PasswordHijackDetector) {
	// WHM password change: [timestamp] info [whostmgr] IP PURGE account:token password_change
	isPurge := strings.Contains(line, "[whostmgr]") &&
		strings.Contains(line, "PURGE") &&
		strings.Contains(line, "password_change")
	if isPurge {
		if ip, account := parseWHMPurge(line); ip != "" && account != "" {
			detector.HandlePasswordChange(account, ip)
		}
	}
	// cPanel login: [timestamp] info [cpaneld] IP NEW account:token ...
	if strings.Contains(line, "[cpaneld]") && strings.Contains(line, " NEW ") {
		// API sessions (method=create_user_session) are portal-driven; skip them.
		if strings.Contains(line, "method=create_user_session") {
			return
		}
		if ip, account := parseCpanelSessionLogin(line); ip != "" && account != "" {
			detector.HandleLogin(account, ip)
		}
	}
}
// parseWHMPurge pulls (ip, account) out of a whostmgr PURGE session line.
// Format: [timestamp] info [whostmgr] 198.51.100.50 PURGE account:token password_change
// Returns empty strings when the line doesn't match.
func parseWHMPurge(line string) (ip, account string) {
	const tag = "[whostmgr]"
	idx := strings.Index(line, tag)
	if idx < 0 {
		return "", ""
	}
	fields := strings.Fields(strings.TrimSpace(line[idx+len(tag):]))
	if len(fields) < 3 {
		return "", ""
	}
	ip = fields[0]
	// The account is the part before ":" in the token following "PURGE".
	for i := 0; i+1 < len(fields); i++ {
		if fields[i] == "PURGE" {
			account = strings.SplitN(fields[i+1], ":", 2)[0]
			break
		}
	}
	return ip, account
}
package daemon
import (
	"fmt"
	"strings"
	"time"

	"github.com/pidginhost/csm/internal/alert"
	"github.com/pidginhost/csm/internal/config"
)
// phpEventsLogPath is where the PHP shield extension writes its event log.
const phpEventsLogPath = "/var/run/csm/php_events.log"

// parsePHPShieldLogLine wraps parsePHPShieldLine for the log watcher handler signature.
func parsePHPShieldLogLine(line string, _ *config.Config) []alert.Finding { //nolint:unparam
	if f := parsePHPShieldLine(line); f != nil {
		return []alert.Finding{*f}
	}
	return nil
}
// parsePHPShieldLine parses a line from the PHP shield event log and returns
// a finding if it represents a security event.
//
// Format: [2026-03-25 10:00:00] EVENT_TYPE ip=X script=Y uri=Z ua=A details=B
//
// Returns nil for blank lines, lines lacking the leading "[timestamp]"
// bracket, and unrecognized event types.
func parsePHPShieldLine(line string) *alert.Finding {
	line = strings.TrimSpace(line)
	if line == "" || !strings.HasPrefix(line, "[") {
		return nil
	}
	// Extract event type (first word after the timestamp bracket)
	closeBracket := strings.Index(line, "]")
	if closeBracket < 0 || closeBracket+2 >= len(line) {
		return nil
	}
	rest := strings.TrimSpace(line[closeBracket+1:])
	fields := strings.SplitN(rest, " ", 2)
	if len(fields) < 1 {
		return nil
	}
	eventType := fields[0]
	// Extract key=value pairs
	var ip, script, details string
	if len(fields) > 1 {
		kvPart := fields[1]
		for _, kv := range splitKV(kvPart) {
			switch kv[0] {
			case "ip":
				ip = kv[1]
			case "script":
				script = kv[1]
			case "details":
				details = kv[1]
			}
		}
	}
	// Stamp the finding here: every other finding producer in this daemon
	// sets Timestamp, and a zero value renders as 0001-01-01 in FormatAlert.
	now := time.Now()
	switch eventType {
	case "BLOCK_PATH":
		return &alert.Finding{
			Severity:  alert.Critical,
			Check:     "php_shield_block",
			Message:   fmt.Sprintf("PHP Shield blocked execution from dangerous path: %s", script),
			Details:   fmt.Sprintf("IP: %s\n%s", ip, details),
			Timestamp: now,
		}
	case "WEBSHELL_PARAM":
		return &alert.Finding{
			Severity:  alert.Critical,
			Check:     "php_shield_webshell",
			Message:   fmt.Sprintf("PHP Shield detected webshell command parameter: %s", script),
			Details:   fmt.Sprintf("IP: %s\n%s", ip, details),
			Timestamp: now,
		}
	case "EVAL_FATAL":
		return &alert.Finding{
			Severity:  alert.High,
			Check:     "php_shield_eval",
			Message:   fmt.Sprintf("PHP Shield detected eval() chain failure: %s", script),
			Details:   details,
			Timestamp: now,
		}
	}
	return nil
}
// splitKV splits "key1=val1 key2=val2" respecting values with spaces.
// Only the known keys (ip, script, uri, ua, details) are extracted; a value
// extends until the next known " key=" marker or the end of the string.
func splitKV(s string) [][2]string {
	keys := []string{"ip=", "script=", "uri=", "ua=", "details="}
	var pairs [][2]string
	for i, key := range keys {
		start := strings.Index(s, key)
		if start < 0 {
			continue
		}
		valStart := start + len(key)
		valEnd := len(s)
		// The value stops where a later key begins (preceded by a space).
		for _, later := range keys[i+1:] {
			if rel := strings.Index(s[valStart:], " "+later); rel >= 0 {
				valEnd = valStart + rel
				break
			}
		}
		pairs = append(pairs, [2]string{
			strings.TrimSuffix(key, "="),
			strings.TrimSpace(s[valStart:valEnd]),
		})
	}
	return pairs
}
package daemon
import (
"sync"
"time"
)
// purgeSuppressionWindow is how long after a password purge a 401 from the
// same IP/account pair is considered a stale-session artifact.
const purgeSuppressionWindow = 60 * time.Second

// purgeTracker correlates password purge events with subsequent
// stale-session 401 errors to suppress false-positive alerts.
//
// When a cPanel user changes their password, all existing sessions are
// invalidated. Any in-flight browser requests (AJAX polls, notifications,
// etc.) will return 401 - these are expected side effects, not attacks.
//
// Flow: login (NEW) records IP→account, PURGE records account→time,
// 401 handler checks IP→account→purgeTime to decide suppression.
var purgeTracker = &purgeState{
	purges:   make(map[string]time.Time),
	sessions: make(map[string]string),
}

// purgeState holds the two correlation maps; all access goes through mu.
type purgeState struct {
	mu       sync.Mutex
	purges   map[string]time.Time // account → last purge time
	sessions map[string]string    // IP → last known account
}
// recordLogin tracks which account an IP most recently logged into, and
// opportunistically garbage-collects stale state while the lock is held.
func (ps *purgeState) recordLogin(ip, account string) {
	ps.mu.Lock()
	ps.sessions[ip] = account
	ps.cleanupLocked()
	ps.mu.Unlock()
}
// recordPurge records a password purge event for an account.
func (ps *purgeState) recordPurge(account string) {
	ps.mu.Lock()
	ps.purges[account] = time.Now()
	ps.mu.Unlock()
}
// isPostPurge401 reports whether a 401 from this IP is likely a stale-session
// artifact of a recent password change (within the suppression window).
func (ps *purgeState) isPostPurge401(ip string) bool {
	ps.mu.Lock()
	defer ps.mu.Unlock()
	if account, ok := ps.sessions[ip]; ok {
		if purgedAt, ok := ps.purges[account]; ok {
			return time.Since(purgedAt) < purgeSuppressionWindow
		}
	}
	return false
}
// cleanupLocked removes stale entries. Caller must hold ps.mu.
func (ps *purgeState) cleanupLocked() {
	cutoff := time.Now().Add(-2 * purgeSuppressionWindow)
	for account, purgedAt := range ps.purges {
		if purgedAt.Before(cutoff) {
			delete(ps.purges, account)
		}
	}
	// Hard cap so the sessions map can't grow without bound. When it trips,
	// the whole map is reset (not LRU-evicted) — cheap and rare.
	if len(ps.sessions) > 500 {
		ps.sessions = make(map[string]string)
	}
}
package daemon
import (
"fmt"
"sort"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
)
// smtpIPEntry tracks failed-auth timestamps and suppression state for one IP.
type smtpIPEntry struct {
	times      []time.Time // failure timestamps inside the rolling window
	suppressed time.Time   // no new per-IP finding until this instant
	lastSeen   time.Time   // last activity; drives Purge and LRU eviction
}

// smtpSubnetEntry tracks unique attacker IPs within a /24.
type smtpSubnetEntry struct {
	ips        map[string]time.Time // ip -> firstSeen in window
	suppressed time.Time            // spray-alert suppression deadline
	lastSeen   time.Time            // last activity for this subnet
}

// smtpAccountEntry tracks unique attacker IPs per mailbox.
type smtpAccountEntry struct {
	ips        map[string]time.Time // attacking ip -> last seen in window
	suppressed time.Time            // account-spray suppression deadline
	lastSeen   time.Time            // last activity for this mailbox
}

// smtpAuthTracker aggregates dovecot auth-failure events into three
// detection signals: per-IP brute force, per-/24 password spray, and
// per-mailbox account spray.
//
// Thread-safe; Record may be called concurrently from multiple log readers.
type smtpAuthTracker struct {
	mu                    sync.Mutex
	perIPThreshold        int           // failures from one IP that trigger smtp_bruteforce
	subnetThreshold       int           // unique IPs in a /24 that trigger smtp_subnet_spray
	accountSprayThreshold int           // unique IPs against one mailbox that trigger smtp_account_spray
	window                time.Duration // rolling observation window
	suppression           time.Duration // minimum gap between repeat alerts per entity
	maxTracked            int           // cap on total tracked entities (IPs + subnets + accounts)
	now                   func() time.Time // injectable clock for tests
	ips                   map[string]*smtpIPEntry
	subnets               map[string]*smtpSubnetEntry
	accounts              map[string]*smtpAccountEntry
}
// newSMTPAuthTracker constructs a tracker. `now` is injected so tests can
// use deterministic clocks; pass `time.Now` (or nil) in production.
func newSMTPAuthTracker(
	perIPThreshold int,
	subnetThreshold int,
	accountSprayThreshold int,
	window time.Duration,
	suppression time.Duration,
	maxTracked int,
	now func() time.Time,
) *smtpAuthTracker {
	t := &smtpAuthTracker{
		perIPThreshold:        perIPThreshold,
		subnetThreshold:       subnetThreshold,
		accountSprayThreshold: accountSprayThreshold,
		window:                window,
		suppression:           suppression,
		maxTracked:            maxTracked,
		now:                   time.Now, // default clock; overridden below when injected
		ips:                   make(map[string]*smtpIPEntry),
		subnets:               make(map[string]*smtpSubnetEntry),
		accounts:              make(map[string]*smtpAccountEntry),
	}
	if now != nil {
		t.now = now
	}
	return t
}
// Size returns the total number of tracked entities (IPs + subnets + accounts).
func (t *smtpAuthTracker) Size() int {
	t.mu.Lock()
	total := len(t.ips) + len(t.subnets) + len(t.accounts)
	t.mu.Unlock()
	return total
}
// Record processes one dovecot auth-failure observation. Returns zero or more
// findings that callers should append to their finding slice.
//
// Three independent signals are evaluated in one pass, each with its own
// suppression timestamp so an alert fires at most once per suppression
// period per entity:
//   - per-IP brute force  (smtp_bruteforce, Critical)
//   - per-/24 spray       (smtp_subnet_spray, Critical)
//   - per-mailbox spray   (smtp_account_spray, High, visibility only)
//
// ip MUST be non-private, non-loopback, and non-infra — callers enforce this
// before invoking Record.
func (t *smtpAuthTracker) Record(ip, account string) []alert.Finding {
	if ip == "" {
		return nil
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	now := t.now()
	// Events older than cutoff fall outside every rolling window.
	cutoff := now.Add(-t.window)
	var findings []alert.Finding
	// --- Per-IP tracker ---
	e, ok := t.ips[ip]
	if !ok {
		e = &smtpIPEntry{}
		t.ips[ip] = e
	}
	e.times = pruneTimes(e.times, cutoff)
	e.times = append(e.times, now)
	e.lastSeen = now
	if len(e.times) >= t.perIPThreshold && !now.Before(e.suppressed) {
		// Arm suppression before emitting so repeats within the period stay quiet.
		e.suppressed = now.Add(t.suppression)
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "smtp_bruteforce",
			Message: fmt.Sprintf("SMTP brute force from %s: %d failed auths in %v",
				ip, len(e.times), t.window),
			Details:   "Real-time detection of dovecot_login auth failures",
			Timestamp: now,
		})
	}
	// --- Per-/24 subnet tracker (IPv4 only) ---
	if prefix := extractPrefix24Daemon(ip); prefix != "" {
		s, ok := t.subnets[prefix]
		if !ok {
			s = &smtpSubnetEntry{ips: make(map[string]time.Time)}
			t.subnets[prefix] = s
		}
		pruneSubnetIPs(s, cutoff)
		s.ips[ip] = now
		s.lastSeen = now
		if len(s.ips) >= t.subnetThreshold && !now.Before(s.suppressed) {
			s.suppressed = now.Add(t.suppression)
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "smtp_subnet_spray",
				Message: fmt.Sprintf("SMTP password spray from %s.0/24: %d unique IPs in %v",
					prefix, len(s.ips), t.window),
				Details:   "Real-time detection of dovecot_login auth failures from many IPs in one /24",
				Timestamp: now,
			})
		}
	}
	// --- Per-account spray tracker ---
	if account != "" {
		a, ok := t.accounts[account]
		if !ok {
			a = &smtpAccountEntry{ips: make(map[string]time.Time)}
			t.accounts[account] = a
		}
		pruneAccountIPs(a, cutoff)
		a.ips[ip] = now
		a.lastSeen = now
		if len(a.ips) >= t.accountSprayThreshold && !now.Before(a.suppressed) {
			a.suppressed = now.Add(t.suppression)
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "smtp_account_spray",
				Message: fmt.Sprintf("SMTP password spray targeting %s: %d unique IPs in %v",
					account, len(a.ips), t.window),
				Details:   "Distributed login attempts across many IPs against one mailbox (visibility only — no auto-block).",
				Timestamp: now,
			})
		}
	}
	// Cap memory: evict least-recently-seen entities when over maxTracked.
	t.enforceMaxTracked()
	return findings
}
// pruneTimes drops timestamps older than cutoff. Reuses the backing array.
func pruneTimes(times []time.Time, cutoff time.Time) []time.Time {
recent := times[:0]
for _, ts := range times {
if !ts.Before(cutoff) {
recent = append(recent, ts)
}
}
return recent
}
// extractPrefix24Daemon returns the first three octets of an IPv4 address as
// "a.b.c", or "" if the input isn't an IPv4 address in dotted-quad form.
// Anything containing a colon before the third dot (IPv6, including
// IPv4-mapped forms like "::ffff:1.2.3.4") is rejected.
func extractPrefix24Daemon(ip string) string {
	dots, end := 0, -1
	for i := 0; i < len(ip); i++ {
		if ip[i] != '.' {
			continue
		}
		dots++
		if dots == 3 {
			end = i
			break
		}
	}
	if end < 0 {
		return ""
	}
	for i := 0; i < end; i++ {
		if ip[i] == ':' {
			return ""
		}
	}
	return ip[:end]
}
// pruneSubnetIPs drops per-/24 IP entries whose last-seen is older than cutoff.
func pruneSubnetIPs(s *smtpSubnetEntry, cutoff time.Time) {
	for ip, seen := range s.ips {
		if seen.Before(cutoff) {
			delete(s.ips, ip)
		}
	}
}
// pruneAccountIPs drops per-account IP entries whose last-seen is older than cutoff.
func pruneAccountIPs(a *smtpAccountEntry, cutoff time.Time) {
	for ip, seen := range a.ips {
		if seen.Before(cutoff) {
			delete(a.ips, ip)
		}
	}
}
// Purge removes entries with no recent activity (older than window + suppression).
// Called from a background goroutine every minute.
//
// An entry survives if it still has in-window events OR has been seen since
// the activity cutoff (so suppression state outlives the event window).
func (t *smtpAuthTracker) Purge() {
	t.mu.Lock()
	defer t.mu.Unlock()
	now := t.now()
	activityCutoff := now.Add(-(t.window + t.suppression))
	// Hoisted: the same window cutoff applies to all three maps (this also
	// matches the mailAuthTracker.Purge implementation).
	windowCutoff := now.Add(-t.window)
	for k, e := range t.ips {
		e.times = pruneTimes(e.times, windowCutoff)
		if len(e.times) == 0 && !e.lastSeen.After(activityCutoff) {
			delete(t.ips, k)
		}
	}
	for k, s := range t.subnets {
		pruneSubnetIPs(s, windowCutoff)
		if len(s.ips) == 0 && !s.lastSeen.After(activityCutoff) {
			delete(t.subnets, k)
		}
	}
	for k, a := range t.accounts {
		pruneAccountIPs(a, windowCutoff)
		if len(a.ips) == 0 && !a.lastSeen.After(activityCutoff) {
			delete(t.accounts, k)
		}
	}
}
// enforceMaxTracked evicts the least-recently-seen entries until the total
// number of tracked entities (IPs + subnets + accounts) is within bounds.
// Caller must hold t.mu.
func (t *smtpAuthTracker) enforceMaxTracked() {
	total := len(t.ips) + len(t.subnets) + len(t.accounts)
	if total <= t.maxTracked {
		return
	}
	type candidate struct {
		kind string // "ip" | "subnet" | "account"
		key  string
		seen time.Time
	}
	candidates := make([]candidate, 0, total)
	for key, entry := range t.ips {
		candidates = append(candidates, candidate{"ip", key, entry.lastSeen})
	}
	for key, entry := range t.subnets {
		candidates = append(candidates, candidate{"subnet", key, entry.lastSeen})
	}
	for key, entry := range t.accounts {
		candidates = append(candidates, candidate{"account", key, entry.lastSeen})
	}
	// LRU order: oldest activity first.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].seen.Before(candidates[j].seen)
	})
	// Evict to 95% of cap so subsequent inserts don't re-trigger the sort.
	target := t.maxTracked * 95 / 100
	for _, c := range candidates {
		if len(t.ips)+len(t.subnets)+len(t.accounts) <= target {
			break
		}
		switch c.kind {
		case "ip":
			delete(t.ips, c.key)
		case "subnet":
			delete(t.subnets, c.key)
		case "account":
			delete(t.accounts, c.key)
		}
	}
}
//go:build linux
package daemon
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/unix"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
emime "github.com/pidginhost/csm/internal/mime"
)
// fanotify constants for permission events (not in Go stdlib).
const (
	FAN_CLASS_CONTENT  = 0x00000004 // fanotify_init class that allows permission events
	FAN_OPEN_PERM      = 0x00010000 // ask userspace for a verdict before open() proceeds
	FAN_ALLOW          = 0x01       // permission response: allow the open
	FAN_DENY           = 0x02       // permission response: deny the open
	FAN_EVENT_ON_CHILD = 0x08000000 // report events for immediate children of a marked directory
)

// fanotifyResponse is the struct written back to the fanotify fd
// to allow or deny a permission event. Field order and sizes mirror the
// kernel's struct fanotify_response (int32 fd, uint32 response).
type fanotifyResponse struct {
	Fd       int32
	Response uint32
}

// responseSize is the byte length of fanotifyResponse as written to the fd.
const responseSize = int(unsafe.Sizeof(fanotifyResponse{}))
// SpoolWatcher monitors Exim spool directories for new messages using a
// dedicated fanotify instance with permission events (FAN_OPEN_PERM).
// It is completely separate from the FileMonitor.
type SpoolWatcher struct {
	fd             int // fanotify fd
	cfg            *config.Config
	alertCh        chan<- alert.Finding
	orchestrator   *emailav.Orchestrator
	quarantine     *emailav.Quarantine
	permissionMode bool // true if using FAN_OPEN_PERM, false if fallback to FAN_CLOSE_WRITE
	scanCh         chan spoolEvent // buffered queue feeding the scanWorker pool
	pipeFds        [2]int          // self-pipe used to wake the event loop on shutdown
	stopOnce       sync.Once
	drainOnce      sync.Once
	stopCh         chan struct{}
	wg             sync.WaitGroup
	pipeClosed     int32 // atomic
	fdClosed       int32 // atomic - guards sw.fd against double-close
	degradedMu     sync.Mutex
	lastDegradedAt time.Time // guarded by degradedMu; set elsewhere — presumably rate-limits degraded-mode warnings, confirm at use site
}

// spoolEvent is one fanotify event queued for scanning.
type spoolEvent struct {
	path     string
	fd       int // fanotify event fd (for permission response)
	pid      int32
	needResp bool // true if permission event requiring response
}
// NewSpoolWatcher creates a dedicated fanotify instance for Exim spool scanning.
// Attempts FAN_CLASS_CONTENT with FAN_OPEN_PERM first; falls back to
// FAN_CLASS_NOTIF with FAN_CLOSE_WRITE if permission events are unavailable.
//
// Returns an error if neither fanotify mode can be initialized, if no Exim
// spool directory exists, or if the shutdown pipe cannot be created.
func NewSpoolWatcher(cfg *config.Config, alertCh chan<- alert.Finding, orch *emailav.Orchestrator, quar *emailav.Quarantine) (*SpoolWatcher, error) {
	sw := &SpoolWatcher{
		cfg:          cfg,
		alertCh:      alertCh,
		orchestrator: orch,
		quarantine:   quar,
		scanCh:       make(chan spoolEvent, 256), // buffered so the event loop rarely blocks on workers
		stopCh:       make(chan struct{}),
	}
	// Try permission-capable class first
	fd, err := unix.FanotifyInit(FAN_CLASS_CONTENT|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
	if err == nil {
		sw.fd = fd
		sw.permissionMode = true
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: permission events enabled (FAN_OPEN_PERM)\n", ts())
	} else {
		// Fallback to notification-only: we can still scan, but cannot block
		// the open that delivers the message.
		fd, err = unix.FanotifyInit(FAN_CLASS_NOTIF|FAN_CLOEXEC|FAN_NONBLOCK, unix.O_RDONLY)
		if err != nil {
			return nil, fmt.Errorf("fanotify_init: %w (neither permission nor notification mode available)", err)
		}
		sw.fd = fd
		sw.permissionMode = false
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: WARNING - permission events unavailable, using notification mode (small delivery race window possible)\n", ts())
		// tempfail semantics require holding the open until the scan verdict,
		// which only FAN_OPEN_PERM provides.
		if sw.cfg.EmailAV.FailMode == "tempfail" {
			fmt.Fprintf(os.Stderr, "[%s] spool watcher: WARNING - fail_mode=tempfail requested but cannot be honoured without FAN_OPEN_PERM; operating fail-open\n", ts())
		}
	}
	// Mark spool directories (both exim and exim4 layouts are probed).
	spoolDirs := []string{"/var/spool/exim/input", "/var/spool/exim4/input"}
	var eventMask uint64
	if sw.permissionMode {
		eventMask = FAN_OPEN_PERM | FAN_EVENT_ON_CHILD
	} else {
		eventMask = FAN_CLOSE_WRITE | FAN_EVENT_ON_CHILD
	}
	marked := 0
	for _, dir := range spoolDirs {
		if _, err := os.Stat(dir); err != nil {
			continue // directory absent on this host
		}
		// Use FAN_MARK_ADD (not FAN_MARK_MOUNT) to scope to the directory
		err := unix.FanotifyMark(sw.fd, FAN_MARK_ADD, eventMask, -1, dir)
		if err != nil {
			fmt.Fprintf(os.Stderr, "[%s] spool watcher: cannot watch %s: %v\n", ts(), dir, err)
			continue
		}
		marked++
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: watching %s\n", ts(), dir)
	}
	if marked == 0 {
		_ = unix.Close(sw.fd)
		return nil, fmt.Errorf("no Exim spool directories found to watch")
	}
	// Create pipe for stop signaling (wakes the epoll loop in Run).
	if err := unix.Pipe2(sw.pipeFds[:], unix.O_NONBLOCK|unix.O_CLOEXEC); err != nil {
		_ = unix.Close(sw.fd)
		return nil, fmt.Errorf("creating pipe: %w", err)
	}
	return sw, nil
}
// Run starts the event loop and scanner workers. Blocks until Stop() is called.
//
// The loop multiplexes the fanotify fd and the stop pipe through epoll with
// a 500 ms timeout, so shutdown is observed promptly even when no events
// arrive. Setup errors (epoll create/ctl) are logged and abort the loop.
func (sw *SpoolWatcher) Run() {
	// Start scanner workers
	concurrency := sw.cfg.EmailAV.ScanConcurrency
	if concurrency < 1 {
		concurrency = 4 // default worker count when unset or invalid
	}
	for i := 0; i < concurrency; i++ {
		sw.wg.Add(1)
		go sw.scanWorker()
	}
	// Event loop
	epfd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_create: %v\n", ts(), err)
		return
	}
	defer func() { _ = unix.Close(epfd) }()
	// #nosec G115 -- POSIX fd fits in int32 (rlimit ~1024). Same for all fd→int32 in this file.
	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, sw.fd, &unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(sw.fd)}); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_ctl(fanotify fd): %v\n", ts(), err)
		return
	}
	// #nosec G115 -- POSIX fd fits in int32.
	if err := unix.EpollCtl(epfd, unix.EPOLL_CTL_ADD, sw.pipeFds[0], &unix.EpollEvent{Events: unix.EPOLLIN, Fd: int32(sw.pipeFds[0])}); err != nil {
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: epoll_ctl(pipe fd): %v\n", ts(), err)
		return
	}
	events := make([]unix.EpollEvent, 16)
	buf := make([]byte, 4096) // reused fanotify read buffer
	for {
		// Check for shutdown before (possibly) blocking in epoll.
		select {
		case <-sw.stopCh:
			sw.drainAndClose()
			return
		default:
		}
		n, err := unix.EpollWait(epfd, events, 500)
		if err != nil {
			if err == unix.EINTR {
				continue // interrupted by a signal - just retry
			}
			// Other epoll errors: exit only if shutdown was requested,
			// otherwise keep trying.
			select {
			case <-sw.stopCh:
				sw.drainAndClose()
				return
			default:
				continue
			}
		}
		for i := 0; i < n; i++ {
			// #nosec G115 -- POSIX fd fits in int32.
			if events[i].Fd == int32(sw.pipeFds[0]) {
				// Stop() poked the wakeup pipe - shut down.
				sw.drainAndClose()
				return
			}
			// #nosec G115 -- POSIX fd fits in int32.
			if events[i].Fd == int32(sw.fd) {
				sw.readEvents(buf)
			}
		}
	}
}
// readEvents drains pending fanotify events from sw.fd (reads until error,
// e.g. EAGAIN on the non-blocking fd) and forwards message-body events to
// the scan workers.
//
// fd ownership protocol: every event fd is handled exactly once - either a
// skip path here responds (permission mode) and closes it, or it is packed
// into a spoolEvent and the receiving worker responds and closes it.
func (sw *SpoolWatcher) readEvents(buf []byte) {
	for {
		n, err := unix.Read(sw.fd, buf)
		if err != nil || n < metadataSize {
			return
		}
		offset := 0
		for offset+metadataSize <= n {
			// #nosec G103 -- fanotify delivers a packed binary stream;
			// reinterpretation is required and bounded by metadataSize above.
			meta := (*fanotifyEventMetadata)(unsafe.Pointer(&buf[offset]))
			if meta.Fd < 0 {
				// Negative fd (e.g. queue overflow marker) - nothing to scan.
				offset += int(meta.EventLen)
				continue
			}
			// Resolve the event fd back to a filesystem path.
			path, err := os.Readlink(fmt.Sprintf("/proc/self/fd/%d", meta.Fd))
			if err != nil {
				// Must respond even on error paths (permission mode)
				if sw.permissionMode {
					sw.writeResponse(meta.Fd, FAN_ALLOW)
				}
				_ = unix.Close(int(meta.Fd))
				offset += int(meta.EventLen)
				continue
			}
			// Only process *-D files (message body)
			if !strings.HasSuffix(path, "-D") {
				if sw.permissionMode {
					sw.writeResponse(meta.Fd, FAN_ALLOW)
				}
				_ = unix.Close(int(meta.Fd))
				offset += int(meta.EventLen)
				continue
			}
			// Send to scan workers - blocks if pool is full.
			// This is intentional: backpressure on Exim's delivery runner
			// is the correct behavior per the spec. Exim is designed to
			// handle delivery delays; unscanned delivery is not acceptable.
			evt := spoolEvent{
				path:     path,
				fd:       int(meta.Fd),
				pid:      meta.Pid,
				needResp: sw.permissionMode,
			}
			select {
			case sw.scanCh <- evt:
				// Worker will handle response and fd close
			case <-sw.stopCh:
				// Shutting down - allow and close
				if sw.permissionMode {
					sw.writeResponse(meta.Fd, FAN_ALLOW)
				}
				_ = unix.Close(int(meta.Fd))
			}
			offset += int(meta.EventLen)
		}
	}
}
// scanWorker processes spool events: MIME parse, scan, quarantine/allow.
// One of N worker goroutines; exits when stopCh closes or when scanCh is
// closed during drainAndClose. handleSpoolEvent owns each event's fd
// (permission response + close).
func (sw *SpoolWatcher) scanWorker() {
	defer sw.wg.Done()
	for {
		select {
		case <-sw.stopCh:
			return
		case evt, ok := <-sw.scanCh:
			if !ok {
				return // scanCh closed and drained
			}
			sw.handleSpoolEvent(evt)
		}
	}
}
// handleSpoolEvent scans one spool message and decides allow/deny.
//
// Decision policy (fail-open by default):
//   - MIME parse error:        allow; FAN_DENY only in tempfail mode (Exim retries)
//   - all engines down:        allow + degraded warning, or deny in tempfail mode
//   - engine timeout:          warning; deny in tempfail mode
//   - no attachments / clean:  allow
//   - infected:                quarantine, then deny on success; on quarantine
//     failure allow (or deny in tempfail mode) and always emit a finding
func (sw *SpoolWatcher) handleSpoolEvent(evt spoolEvent) {
	// CRITICAL: deferred FAN_ALLOW - every code path must allow by default.
	// Only overridden to FAN_DENY on confirmed infection + successful quarantine.
	response := uint32(FAN_ALLOW)
	defer func() {
		if evt.needResp {
			// #nosec G115 -- evt.fd is a POSIX fd; fits in int32.
			sw.writeResponse(int32(evt.fd), response)
		}
		_ = unix.Close(evt.fd)
	}()
	// Derive message ID: strip -D suffix and directory
	base := filepath.Base(evt.path)
	msgID := strings.TrimSuffix(base, "-D")
	spoolDir := filepath.Dir(evt.path)
	headerPath := filepath.Join(spoolDir, msgID+"-H")
	bodyPath := evt.path
	// MIME parse - fail-open on error
	limits := emime.Limits{
		MaxAttachmentSize: sw.cfg.EmailAV.MaxAttachmentSize,
		MaxArchiveDepth:   sw.cfg.EmailAV.MaxArchiveDepth,
		MaxArchiveFiles:   sw.cfg.EmailAV.MaxArchiveFiles,
		MaxExtractionSize: sw.cfg.EmailAV.MaxExtractionSize,
	}
	// tempfail can only be honoured when we owe a permission response.
	tempfail := sw.cfg.EmailAV.FailMode == "tempfail" && evt.needResp
	extraction, err := emime.ParseSpoolMessage(headerPath, bodyPath, limits)
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: MIME parse error for %s: %v\n", ts(), msgID, err)
		sw.emitFinding("email_av_parse_error", alert.Warning, fmt.Sprintf("MIME parse failed for message %s: %v", msgID, err))
		if tempfail {
			response = FAN_DENY // tempfail: Exim retries later
		}
		return
	}
	// Clean up temp files when done
	defer func() {
		for _, p := range extraction.Parts {
			os.Remove(p.TempPath)
		}
	}()
	if len(extraction.Parts) == 0 {
		return // No attachments to scan - allow
	}
	// Scan
	result := sw.orchestrator.ScanParts(msgID, extraction.Parts, extraction.Partial)
	// Emit degraded/timeout findings for operator visibility
	if result.AllEnginesDown {
		if shouldTempfailEmailDelivery(tempfail, result, nil) {
			response = FAN_DENY // tempfail: defer delivery until engines recover
			sw.emitDegradedWarning(fmt.Sprintf("All AV engines unavailable - message %s deferred (tempfail mode)", msgID))
			return
		}
		sw.emitDegradedWarning(fmt.Sprintf("All AV engines unavailable - message %s delivered unscanned", msgID))
	}
	if len(result.TimedOutEngines) > 0 {
		sw.emitFinding("email_av_timeout", alert.Warning,
			fmt.Sprintf("Scan timeout for message %s on engine(s): %s", msgID, strings.Join(result.TimedOutEngines, ", ")))
		if shouldTempfailEmailDelivery(tempfail, result, nil) {
			sw.emitDegradedWarning(fmt.Sprintf("Incomplete AV scan - message %s deferred after engine timeout", msgID))
			response = FAN_DENY
			return
		}
	}
	if !result.Infected {
		return // Clean - allow
	}
	// Infected - attempt quarantine
	env := emailav.QuarantineEnvelope{
		From:      extraction.From,
		To:        extraction.To,
		Subject:   extraction.Subject,
		Direction: extraction.Direction,
	}
	if sw.cfg.EmailAV.QuarantineInfected {
		if err := sw.quarantine.QuarantineMessage(msgID, spoolDir, result, env); err != nil {
			fmt.Fprintf(os.Stderr, "[%s] spool watcher: quarantine failed for %s: %v\n", ts(), msgID, err)
			sw.emitFinding("email_av_quarantine_error", alert.Warning,
				fmt.Sprintf("Quarantine failed for infected message %s: %v", msgID, err))
			if shouldTempfailEmailDelivery(tempfail, result, err) {
				sw.emitDegradedWarning(fmt.Sprintf("Quarantine failed for infected message %s - delivery deferred (tempfail mode)", msgID))
				response = FAN_DENY
			}
		} else {
			// Quarantine succeeded - deny the open so Exim can't deliver
			response = FAN_DENY
		}
	}
	// Emit alert finding
	sigNames := make([]string, len(result.Findings))
	for i, f := range result.Findings {
		sigNames[i] = fmt.Sprintf("%s(%s)", f.Signature, f.Engine)
	}
	msg := fmt.Sprintf("Malware detected in %s email from %s to %s: %s [subject: %s]",
		extraction.Direction, extraction.From, strings.Join(extraction.To, ","),
		strings.Join(sigNames, ", "), extraction.Subject)
	sw.emitFinding("email_malware", alert.Critical, msg)
}
// writeResponse writes a fanotify permission response (FAN_ALLOW/FAN_DENY)
// for the given event fd. On write failure it fails open for everything:
// closing the fanotify fd releases all kernel-blocked opens, and Stop()
// lets the daemon restart the watcher.
func (sw *SpoolWatcher) writeResponse(fd int32, response uint32) {
	resp := fanotifyResponse{Fd: fd, Response: response}
	// #nosec G103 -- serializing the fanotify response struct for the
	// kernel write; unsafe cast to a byte slice of the exact struct size.
	respBytes := (*[responseSize]byte)(unsafe.Pointer(&resp))[:]
	_, err := unix.Write(sw.fd, respBytes)
	if err != nil {
		// The kernel holds blocked processes until a response is written or
		// the fanotify fd is closed. A failed write means the fd is broken -
		// close it to release ALL pending permission events (fail-open),
		// then signal the event loop to exit so the daemon can restart us.
		fmt.Fprintf(os.Stderr, "[%s] spool watcher: FATAL - fanotify response write failed: %v - closing fd to release pending events\n", ts(), err)
		sw.closeFd()
		sw.Stop()
	}
}
// closeFd closes the fanotify fd exactly once, even if called from multiple
// paths (writeResponse failure, Stop, drainAndClose). The CAS on fdClosed
// makes the close idempotent and race-free.
func (sw *SpoolWatcher) closeFd() {
	if atomic.CompareAndSwapInt32(&sw.fdClosed, 0, 1) {
		_ = unix.Close(sw.fd)
	}
}
// emitFinding sends a Finding on the alert channel without blocking; when
// the channel is full the finding is dropped so the fanotify event loop and
// scan workers never stall behind a slow alert consumer.
//
// Fix: previously Timestamp was left at its zero value, so Finding.String
// rendered "0001-01-01 00:00:00" unless a downstream stamped it (the
// LogWatcher path does, this path did not). Stamp at emission time instead.
func (sw *SpoolWatcher) emitFinding(check string, severity alert.Severity, message string) {
	f := alert.Finding{
		Severity:  severity,
		Check:     check,
		Message:   message,
		Timestamp: time.Now(),
	}
	select {
	case sw.alertCh <- f:
	default:
		// Alert channel full - drop
	}
}
// emitDegradedWarning emits an email_av_degraded WARNING finding, throttled
// to at most one per minute so a dead scanner daemon cannot flood the alert
// channel with identical degradation notices.
func (sw *SpoolWatcher) emitDegradedWarning(message string) {
	sw.degradedMu.Lock()
	throttled := time.Since(sw.lastDegradedAt) < time.Minute
	if !throttled {
		sw.lastDegradedAt = time.Now()
	}
	sw.degradedMu.Unlock()
	if throttled {
		return
	}
	sw.emitFinding("email_av_degraded", alert.Warning, message)
}
// drainAndClose performs the shutdown sequence exactly once (drainOnce):
//  1. close scanCh so workers exit (via channel close or stopCh),
//  2. wait for all workers to finish,
//  3. close the fanotify fd - also fail-opens any permission events still
//     pending in the kernel,
//  4. close both pipe ends (pipeClosed CAS coordinates with Stop's write).
// Called only from the event loop after it has stopped reading, so no new
// sends can race the close of scanCh.
func (sw *SpoolWatcher) drainAndClose() {
	sw.drainOnce.Do(func() {
		close(sw.scanCh)
		sw.wg.Wait()
		sw.closeFd()
		if atomic.CompareAndSwapInt32(&sw.pipeClosed, 0, 1) {
			_ = unix.Close(sw.pipeFds[0])
			_ = unix.Close(sw.pipeFds[1])
		}
	})
}
// PermissionMode returns true if using FAN_OPEN_PERM (delivery can be
// blocked pre-scan), false if running the FAN_CLOSE_WRITE fallback.
func (sw *SpoolWatcher) PermissionMode() bool {
	return sw.permissionMode
}
// Stop signals the event loop to exit. Idempotent via stopOnce. Order:
// close stopCh (observed by loop and workers), poke the pipe to wake a
// blocked EpollWait, then close the fanotify fd so kernel-blocked
// permission events are released fail-open.
// NOTE(review): the pipeClosed load is not atomic with drainAndClose's
// close of the pipe; a write hitting a just-closed fd appears harmless
// (error ignored, loop already shutting down) - confirm if hardening.
func (sw *SpoolWatcher) Stop() {
	sw.stopOnce.Do(func() {
		close(sw.stopCh)
		if atomic.LoadInt32(&sw.pipeClosed) == 0 {
			_, _ = unix.Write(sw.pipeFds[1], []byte{0})
		}
		sw.closeFd()
	})
}
package daemon
import "github.com/pidginhost/csm/internal/emailav"
// shouldTempfailEmailDelivery decides whether a delivery should be deferred
// (FAN_DENY in tempfail mode). Quarantine failure always defers; otherwise
// only a fully-down or timed-out engine set does. With tempfail disabled or
// no result available, delivery is never deferred.
func shouldTempfailEmailDelivery(tempfail bool, result *emailav.ScanResult, quarantineErr error) bool {
	switch {
	case !tempfail:
		return false
	case quarantineErr != nil:
		return true
	case result == nil:
		return false
	default:
		return result.AllEnginesDown || len(result.TimedOutEngines) > 0
	}
}
package daemon
import (
"bufio"
"context"
"fmt"
"io"
"net"
"os"
"os/exec"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/store"
)
// LogLineHandler parses a log line and returns findings (if any).
type LogLineHandler func(line string, cfg *config.Config) []alert.Finding

// LogWatcher tails a log file and processes new lines. Despite the
// original comment, Run uses polling (every 2s) rather than inotify -
// see Run for the rationale. Not safe for concurrent use: Run,
// readNewLines and reopen mutate file/offset without locking.
type LogWatcher struct {
	path    string
	cfg     *config.Config
	handler LogLineHandler
	alertCh chan<- alert.Finding
	file    *os.File
	offset  int64 // byte offset of the next unread position in file
}
// NewLogWatcher opens the given log file and positions the watcher at
// end-of-file, so only lines appended after startup are ever processed.
func NewLogWatcher(path string, cfg *config.Config, handler LogLineHandler, alertCh chan<- alert.Finding) (*LogWatcher, error) {
	// #nosec G304 -- path is operator-configured log path from csm.yaml.
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// Skip historical content: start tailing from the current end.
	end, err := f.Seek(0, io.SeekEnd)
	if err != nil {
		_ = f.Close()
		return nil, err
	}
	w := &LogWatcher{
		path:    path,
		cfg:     cfg,
		handler: handler,
		alertCh: alertCh,
		file:    f,
		offset:  end,
	}
	return w, nil
}
// Run watches the log file until stopCh fires. It polls for new lines every
// 2 seconds (polling instead of inotify keeps log-rotation handling simple
// and cheap) and additionally reopens the file every 5 minutes so a rotated
// file is eventually picked up.
func (w *LogWatcher) Run(stopCh <-chan struct{}) {
	poll := time.NewTicker(2 * time.Second)
	rotate := time.NewTicker(5 * time.Minute)
	defer poll.Stop()
	defer rotate.Stop()
	for {
		select {
		case <-stopCh:
			return
		case <-rotate.C:
			w.reopen()
		case <-poll.C:
			w.readNewLines()
		}
	}
}
// Stop releases the watcher's file handle, if one is open.
func (w *LogWatcher) Stop() {
	if w.file == nil {
		return
	}
	_ = w.file.Close()
}
// readNewLines reads and handles every COMPLETE line appended since the
// last poll, then leaves w.offset at the first unprocessed byte.
//
// Fix: the previous implementation used bufio.Scanner, which returns a
// trailing unterminated line as if it were complete. When a writer was
// mid-line at poll time, the handler saw a half-written line and the offset
// advanced past it, so the remainder later surfaced as a bogus fragment
// line. (Scanner's 256 KiB cap could also silently wedge the tail on an
// oversized line.) We now read newline-terminated lines only and advance
// the offset per complete line; a partial tail is left for the next poll.
func (w *LogWatcher) readNewLines() {
	info, err := w.file.Stat()
	if err != nil {
		w.reopen()
		return
	}
	// File was truncated or rotated (smaller than our offset)
	if info.Size() < w.offset {
		w.reopen()
		return
	}
	// No new data
	if info.Size() == w.offset {
		return
	}
	// Seek to where we left off
	if _, err := w.file.Seek(w.offset, io.SeekStart); err != nil {
		return
	}
	reader := bufio.NewReaderSize(w.file, 64*1024)
	for {
		raw, err := reader.ReadString('\n')
		if err != nil {
			// No trailing newline yet (partial write) or a read error:
			// do NOT consume it - next poll retries from w.offset.
			return
		}
		// Advance past this complete line (delimiter included).
		w.offset += int64(len(raw))
		// Match bufio.ScanLines: strip "\n" and a single preceding "\r".
		line := strings.TrimSuffix(raw, "\n")
		line = strings.TrimSuffix(line, "\r")
		if line == "" {
			continue
		}
		findings := w.handler(line, w.cfg)
		for _, f := range findings {
			if f.Timestamp.IsZero() {
				f.Timestamp = time.Now()
			}
			select {
			case w.alertCh <- f:
			default:
				// Channel full - drop (backpressure)
				fmt.Fprintf(os.Stderr, "[%s] Warning: alert channel full, dropping finding from %s\n", ts(), w.path)
			}
		}
	}
}
// reopen closes and reopens the log file. After rotation/truncation (the
// file is now smaller than our offset) reading restarts from the beginning;
// otherwise the existing offset is preserved.
//
// Fix: the previous "same file" branch did f.Seek(0, io.SeekEnd) and stored
// that as the offset - despite its comment saying "seek to where we were" -
// which silently discarded any lines appended between the last poll and the
// periodic 5-minute reopen. We now leave w.offset untouched in that case;
// readNewLines seeks to w.offset itself before every read, so no explicit
// seek is needed here.
func (w *LogWatcher) reopen() {
	if w.file != nil {
		_ = w.file.Close()
	}
	// #nosec G304 -- path is operator-configured log path from csm.yaml.
	f, err := os.Open(w.path)
	if err != nil {
		// Keep the closed handle; the next readNewLines Stat fails and
		// retries reopen on the following poll.
		return
	}
	info, err := f.Stat()
	if err != nil {
		_ = f.Close()
		return
	}
	w.file = f
	if info.Size() < w.offset {
		// File shrank - rotated/truncated; read the new file from the start.
		w.offset = 0
	}
	// Same file or a larger one: keep w.offset so unread lines are not lost.
}
// --- Log line handlers ---

// parseSessionLogLine inspects one cPanel session-log line for two events:
//   - direct cPanel form logins from non-infra IPs (WARNING), skipping
//     API/portal-created sessions and any suppressed/trusted sources;
//   - password purges (HIGH), which also feed the purge/login correlator.
// IP→account pairs are always recorded in purgeTracker before filtering so
// suppressed logins can still correlate with later purges.
func parseSessionLogLine(line string, cfg *config.Config) []alert.Finding {
	var findings []alert.Finding
	// cPanel login from non-infra IP - only alert on direct form login,
	// not API-created sessions (from portal create_user_session)
	if strings.Contains(line, "[cpaneld]") && strings.Contains(line, " NEW ") {
		// Track IP→account for purge correlation (before any filtering)
		if loginIP, loginAccount := parseCpanelSessionLogin(line); loginIP != "" && loginAccount != "" {
			purgeTracker.recordLogin(loginIP, loginAccount)
		}
		switch {
		case cfg.Suppressions.SuppressCpanelLogin:
			// Skip all cPanel login alerts
		case strings.Contains(line, "method=create_user_session") ||
			strings.Contains(line, "method=create_session") ||
			strings.Contains(line, "create_user_session"):
			// Portal-created session - no alert
		default:
			ip, account := parseCpanelSessionLogin(line)
			if ip != "" && account != "" && !isInfraIPDaemon(ip, cfg.InfraIPs) &&
				!isTrustedCountry(ip, cfg.Suppressions.TrustedCountries) {
				// WARNING severity - logins are useful for audit trail but
				// not paging-level. Multi-IP correlation and brute-force
				// stay at CRITICAL/HIGH via their own checks.
				method := "unknown"
				if strings.Contains(line, "method=handle_form_login") {
					method = "direct form login"
				} else if idx := strings.Index(line, "method="); idx >= 0 {
					// Extract whatever method= value is present for context.
					rest := line[idx+7:]
					if comma := strings.IndexAny(rest, ",\n "); comma > 0 {
						method = rest[:comma]
					}
				}
				findings = append(findings, alert.Finding{
					Severity: alert.Warning,
					Check:    "cpanel_login_realtime",
					Message:  fmt.Sprintf("cPanel direct login from non-infra IP: %s (account: %s, method: %s)", ip, account, method),
					Details:  truncateDaemon(line, 300),
				})
			}
		}
	}
	// Password purge
	if strings.Contains(line, "PURGE") && strings.Contains(line, "password_change") {
		account := parsePurgeDaemon(line)
		if account != "" {
			purgeTracker.recordPurge(account)
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "cpanel_password_purge_realtime",
				Message:  fmt.Sprintf("cPanel password purge for: %s", account),
			})
		}
	}
	return findings
}
// parseSecureLogLine flags successful SSH logins ("Accepted ..." lines)
// originating from IPs outside the infra allowlist (and not localhost).
// Returns nil for every other line, and for allowlisted sources.
func parseSecureLogLine(line string, cfg *config.Config) []alert.Finding {
	if !strings.Contains(line, "Accepted") {
		return nil
	}
	tokens := strings.Fields(line)
	for i, tok := range tokens {
		if tok != "from" || i+1 >= len(tokens) {
			continue
		}
		ip := tokens[i+1]
		if isInfraIPDaemon(ip, cfg.InfraIPs) || ip == "127.0.0.1" {
			return nil
		}
		// Best-effort username extraction: the token after "for".
		user := "unknown"
		for j, candidate := range tokens {
			if candidate == "for" && j+1 < len(tokens) {
				user = tokens[j+1]
				break
			}
		}
		return []alert.Finding{{
			Severity: alert.Critical,
			Check:    "ssh_login_realtime",
			Message:  fmt.Sprintf("SSH login from non-infra IP: %s (user: %s)", ip, user),
			Details:  truncateDaemon(line, 200),
		}}
	}
	return nil
}
// parseEximLogLine inspects one exim mainlog line for nine classes of
// security-relevant events: frozen messages, outgoing mail holds, spam
// outbreaks, credential leaks in subjects, bulk-mailer authentications,
// dovecot auth failures, DKIM signing failures, SPF/DMARC rejections, and
// outbound rate anomalies. Several checks have side effects beyond the
// returned findings: auto-suspension of outgoing mail (whmapi1), compromise
// tracking (RecordCompromisedDomain), and dedup timestamps in the store.
func parseEximLogLine(line string, cfg *config.Config) []alert.Finding {
	var findings []alert.Finding
	// 1. Frozen bounces - spam indicator
	if strings.Contains(line, "frozen") {
		findings = append(findings, alert.Finding{
			Severity: alert.Warning,
			Check:    "exim_frozen_realtime",
			Message:  "Exim frozen message detected",
			Details:  truncateDaemon(line, 200),
		})
	}
	// 2. Outgoing mail hold - account suspended for spam
	// Format: "Sender office@example.com has an outgoing mail hold"
	// or: "Domain example.org has an outgoing mail hold"
	// Dedup: only alert once per domain per hour (exim retries held messages
	// every few minutes, generating the same log line each time)
	if strings.Contains(line, "outgoing mail hold") {
		sender := extractMailHoldSender(line)
		if sender == "" {
			sender = extractEximSender(line)
		}
		domain := extractDomainFromEmail(sender)
		if domain == "" {
			domain = sender // may already be a bare domain
		}
		// Auto-suspend regardless of dedup - idempotent, ensures hold stays on
		if sender != "" {
			autoSuspendOutgoingMail(sender)
		}
		if domain != "" {
			RecordCompromisedDomain(domain)
		}
		// Alert only once per domain per hour
		dedupKey := "email_hold:" + domain
		if db := store.Global(); db != nil {
			lastAlert := db.GetMetaString(dedupKey)
			if lastAlert != "" && !isDedupExpired(lastAlert, 1*time.Hour) {
				// Already alerted for this domain recently - skip finding
			} else {
				_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_compromised_account",
					Message:  fmt.Sprintf("Account %s has outgoing mail hold - outgoing mail auto-suspended", sender),
					Details:  truncateDaemon(line, 300),
				})
			}
		}
	}
	// 3. Max defers/failures exceeded - active spam outbreak
	if strings.Contains(line, "max defers and failures per hour") {
		domain := extractEximDomain(line)
		// Auto-suspend: confirmed spam outbreak
		autoSuspendOutgoingMail(domain)
		findings = append(findings, alert.Finding{
			Severity: alert.Critical,
			Check:    "email_spam_outbreak",
			Message:  fmt.Sprintf("Spam outbreak: %s exceeded max defers/failures - outgoing mail auto-suspended", domain),
			Details:  truncateDaemon(line, 300),
		})
		if domain != "" {
			RecordCompromisedDomain(domain)
		}
	}
	// 4. SMTP credentials leaked in subject - compromised account
	// Pattern: T="...host:port,user@domain,PASSWORD..." in the subject field
	if strings.Contains(line, " <= ") && strings.Contains(line, "T=\"") {
		subject := extractEximSubject(line)
		subjectLower := strings.ToLower(subject)
		// Detect credential patterns: host:port,user,password or
		// SMTP credentials in subject (common in credential stuffing attacks)
		if (strings.Contains(subject, ":587,") || strings.Contains(subject, ":465,") ||
			strings.Contains(subject, ":25,")) &&
			strings.Contains(subject, "@") {
			sender := extractEximSender(line)
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "email_credential_leak",
				Message:  fmt.Sprintf("SMTP credentials leaked in email subject from %s", sender),
				Details:  fmt.Sprintf("The email subject contains what appears to be SMTP credentials (host:port,user,password). This account is likely compromised by a bulk mail service.\nSubject: %s", truncateDaemon(subject, 100)),
			})
		}
		// Also detect common spam subject patterns
		if strings.Contains(subjectLower, "password") && strings.Contains(subjectLower, "smtp") {
			sender := extractEximSender(line)
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "email_credential_leak",
				Message:  fmt.Sprintf("Suspicious email subject with SMTP/password keywords from %s", sender),
				Details:  truncateDaemon(line, 300),
			})
		}
	}
	// 5. Authentication from known bulk mail services
	if strings.Contains(line, " <= ") && strings.Contains(line, "A=dovecot_") {
		knownSpamServices := []string{
			"truelist.io", "sendinblue.com", "mailspree.co",
			"bulkmailer.", "massmailsoftware.", "sendblaster.",
		}
		lineLower := strings.ToLower(line)
		for _, service := range knownSpamServices {
			if strings.Contains(lineLower, service) {
				sender := extractEximSender(line)
				findings = append(findings, alert.Finding{
					Severity: alert.Critical,
					Check:    "email_compromised_account",
					Message:  fmt.Sprintf("Compromised email account %s authenticated from bulk mail service %s", sender, service),
					Details:  truncateDaemon(line, 300),
				})
				break
			}
		}
	}
	// 6. Dovecot auth failure - brute force indicator
	// Format: "dovecot_login authenticator failed for H=(hostname) [IP]:port: 535 ... (set_id=user@domain)"
	if strings.Contains(line, "authenticator failed") && strings.Contains(line, "dovecot") {
		ip := extractBracketedIP(line)
		account := extractSetID(line)
		msg := "Email authentication failure"
		if account != "" {
			msg += " for " + account
		}
		if ip != "" {
			msg += " from " + ip
		}
		findings = append(findings, alert.Finding{
			Severity: alert.High,
			Check:    "email_auth_failure_realtime",
			Message:  msg,
			Details:  truncateDaemon(line, 300),
		})
	}
	// 7. DKIM signing failures (deduped to one alert per domain per day)
	if dkimDomain := parseDKIMFailureDomain(line); dkimDomain != "" {
		dedupKey := "dkim_fail:" + dkimDomain
		if db := store.Global(); db != nil {
			lastAlert := db.GetMetaString(dedupKey)
			if lastAlert == "" || isDedupExpired(lastAlert, 24*time.Hour) {
				_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
				findings = append(findings, alert.Finding{
					Severity:  alert.Warning,
					Check:     "email_dkim_failure",
					Message:   fmt.Sprintf("DKIM signing failed for %s - check key file and DNS TXT record", dkimDomain),
					Details:   truncateDaemon(line, 300),
					Timestamp: time.Now(),
				})
			}
		}
	}
	// 8. SPF/DMARC outbound rejections (deduped per domain per day)
	if spfDomain, spfReason := parseSPFDMARCRejection(line); spfDomain != "" {
		dedupKey := "spf_reject:" + spfDomain
		if db := store.Global(); db != nil {
			lastAlert := db.GetMetaString(dedupKey)
			if lastAlert == "" || isDedupExpired(lastAlert, 24*time.Hour) {
				_ = db.SetMetaString(dedupKey, time.Now().Format(time.RFC3339))
				findings = append(findings, alert.Finding{
					Severity:  alert.High,
					Check:     "email_spf_rejection",
					Message:   fmt.Sprintf("Outbound mail from %s rejected due to SPF/DMARC failure", spfDomain),
					Details:   fmt.Sprintf("Reason: %s\n%s", spfReason, truncateDaemon(line, 200)),
					Timestamp: time.Now(),
				})
			}
		}
	}
	// 9. Outbound rate limiting for authenticated users
	if strings.Contains(line, " <= ") && strings.Contains(line, "A=dovecot_") {
		authUser := extractAuthUser(line)
		if authUser != "" {
			rateFindings := checkEmailRate(authUser, cfg)
			findings = append(findings, rateFindings...)
		}
	}
	return findings
}
// extractEximSender returns the envelope sender from an exim acceptance
// line ("... <= sender@domain.com H=..."), or "" when the marker is absent
// or nothing follows it.
func extractEximSender(line string) string {
	_, tail, found := strings.Cut(line, " <= ")
	if !found {
		return ""
	}
	tokens := strings.Fields(tail)
	if len(tokens) == 0 {
		return ""
	}
	return tokens[0]
}
// extractEximDomain returns the token following "Domain " in an exim log
// line (e.g. "Domain X has exceeded ..."), or "" if the marker is missing.
func extractEximDomain(line string) string {
	const marker = "Domain "
	pos := strings.Index(line, marker)
	if pos < 0 {
		return ""
	}
	tail := line[pos+len(marker):]
	sp := strings.IndexByte(tail, ' ')
	if sp <= 0 {
		// No space (or a degenerate leading space): return the whole tail,
		// matching the original behaviour.
		return tail
	}
	return tail[:sp]
}
// extractEximSubject returns the contents of the T="..." field in an exim
// log line. If the closing quote is missing, everything after T=" is
// returned; if the field is absent entirely, "" is returned.
func extractEximSubject(line string) string {
	_, tail, found := strings.Cut(line, "T=\"")
	if !found {
		return ""
	}
	subject, _, _ := strings.Cut(tail, "\"")
	return subject
}
// --- Helpers (avoid import cycle with checks package) ---

// parseCpanelSessionLogin extracts the client IP (first field after the
// "[cpaneld]" tag) and the account name (the "user" part of the
// "user:session" token following "NEW"). Either result may be "" when the
// line does not match the expected shape.
func parseCpanelSessionLogin(line string) (ip, account string) {
	pos := strings.Index(line, "[cpaneld]")
	if pos < 0 {
		return "", ""
	}
	fields := strings.Fields(strings.TrimSpace(line[pos+len("[cpaneld]"):]))
	if len(fields) < 3 {
		return "", ""
	}
	ip = fields[0]
	for i := 0; i+1 < len(fields); i++ {
		if fields[i] != "NEW" {
			continue
		}
		// Account is everything before the first ':' of the next token.
		account, _, _ = strings.Cut(fields[i+1], ":")
		break
	}
	return ip, account
}
// parsePurgeDaemon extracts the account name from a PURGE log line: the
// part of the first token after "PURGE" that precedes any ':'. Returns ""
// when the marker or the token is missing.
func parsePurgeDaemon(line string) string {
	_, tail, found := strings.Cut(line, "PURGE")
	if !found {
		return ""
	}
	tokens := strings.Fields(tail)
	if len(tokens) == 0 {
		return ""
	}
	account, _, _ := strings.Cut(tokens[0], ":")
	return account
}
// isInfraIPDaemon reports whether ip falls inside any of the configured
// infra networks. Entries may be CIDR blocks or plain IP literals (plain
// entries are matched by exact string comparison). Unparseable ips are
// never considered infra.
func isInfraIPDaemon(ip string, infraNets []string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}
	for _, entry := range infraNets {
		_, network, err := net.ParseCIDR(entry)
		if err == nil {
			if network.Contains(addr) {
				return true
			}
			continue
		}
		// Not CIDR notation - treat the entry as a single IP literal.
		if entry == ip {
			return true
		}
	}
	return false
}
// mergeInfraIPs combines top-level infra IPs with firewall-specific ones,
// preserving first-seen order and deduplicating entries. This allows the
// firewall to include additional CIDRs (e.g. the server's own range) that
// need port access but shouldn't suppress alerts.
func mergeInfraIPs(topLevel, fwSpecific []string) []string {
	seen := make(map[string]bool, len(topLevel)+len(fwSpecific))
	var merged []string
	for _, group := range [][]string{topLevel, fwSpecific} {
		for _, entry := range group {
			if seen[entry] {
				continue
			}
			seen[entry] = true
			merged = append(merged, entry)
		}
	}
	return merged
}
// autoSuspendOutgoingMail holds outgoing mail (via whmapi1) for the cPanel
// account that owns the given domain or email address. Only invoked on
// cPanel-confirmed spam signals (mail hold or max defers), so holding again
// is safe and idempotent. Failures are logged to stderr, never fatal.
func autoSuspendOutgoingMail(domainOrEmail string) {
	if domainOrEmail == "" {
		return
	}
	stamp := func() string { return time.Now().Format("2006-01-02 15:04:05") }
	// Reduce an email address to its domain part.
	domain := domainOrEmail
	if at := strings.LastIndexByte(domain, '@'); at >= 0 {
		domain = domain[at+1:]
	}
	// Map the domain to its owning cPanel account.
	user := lookupCPanelUser(domain)
	if user == "" {
		fmt.Fprintf(os.Stderr, "[%s] auto-suspend: could not find cPanel user for domain %s\n",
			stamp(), domain)
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// #nosec G204 -- whmapi1 is the cPanel API binary; user is a cPanel
	// account name validated upstream (cpuser regex in the caller).
	out, err := exec.CommandContext(ctx, "whmapi1", "hold_outgoing_email", "user="+user).CombinedOutput()
	if err != nil {
		fmt.Fprintf(os.Stderr, "[%s] auto-suspend: whmapi1 hold_outgoing_email failed for %s: %v\n%s\n",
			stamp(), user, err, string(out))
		return
	}
	fmt.Fprintf(os.Stderr, "[%s] AUTO-SUSPEND: outgoing mail held for cPanel user %s (domain: %s)\n",
		stamp(), user, domain)
}
// userdomainsPath is the cPanel domain→user map file. var (not const)
// so tests can point it at a fixture under t.TempDir(). Production must
// not mutate at runtime.
var userdomainsPath = "/etc/userdomains"

// lookupCPanelUser finds the cPanel username that owns a domain by scanning
// userdomainsPath, which maps "domain: user" per line. Domain comparison is
// case-insensitive; returns "" on any error or when no entry matches.
func lookupCPanelUser(domain string) string {
	f, err := os.Open(userdomainsPath)
	if err != nil {
		return ""
	}
	defer func() { _ = f.Close() }()
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		name, owner, ok := strings.Cut(sc.Text(), ":")
		if !ok {
			continue // malformed line - skip
		}
		if strings.EqualFold(strings.TrimSpace(name), domain) {
			return strings.TrimSpace(owner)
		}
	}
	return ""
}
// extractMailHoldSender extracts the account or domain from outgoing mail
// hold messages. Two shapes are recognised, tried in order:
//
//	"Sender user@domain has an outgoing mail hold" -> "user@domain"
//	"Domain example.com has an outgoing mail hold" -> "example.com"
func extractMailHoldSender(line string) string {
	for _, marker := range []string{"Sender ", "Domain "} {
		pos := strings.Index(line, marker)
		if pos < 0 {
			continue
		}
		tail := line[pos+len(marker):]
		if sp := strings.IndexByte(tail, ' '); sp > 0 {
			return tail[:sp]
		}
		return tail
	}
	return ""
}
// extractBracketedIP extracts an IP from [IP]:port or [IP] notation in exim
// logs, using the LAST bracket pair (the client IP, not the hostname). A
// crude sanity filter requires at least 7 bytes and a leading digit or ':'
// (so short/alpha bracket contents are rejected).
func extractBracketedIP(line string) string {
	open := strings.LastIndex(line, "[")
	if open < 0 {
		return ""
	}
	tail := line[open+1:]
	closing := strings.IndexByte(tail, ']')
	if closing < 0 {
		return ""
	}
	candidate := tail[:closing]
	if len(candidate) < 7 {
		return ""
	}
	if first := candidate[0]; first == ':' || (first >= '0' && first <= '9') {
		return candidate
	}
	return ""
}
// extractSetID extracts the account from "(set_id=user@domain)" or
// "(set_id=user)" in exim logs: everything after "set_id=" up to the first
// ')', newline, or space. Returns "" when the marker is absent.
func extractSetID(line string) string {
	_, tail, found := strings.Cut(line, "set_id=")
	if !found {
		return ""
	}
	if end := strings.IndexAny(tail, ")\n "); end >= 0 {
		return tail[:end]
	}
	return tail
}
// truncateDaemon caps s at maxLen bytes, appending "..." when truncated.
// Byte-based: a multi-byte UTF-8 rune straddling the cut point will be
// split, which is acceptable for log snippets.
func truncateDaemon(s string, maxLen int) string {
	if len(s) > maxLen {
		return s[:maxLen] + "..."
	}
	return s
}
// parseDKIMFailureDomain extracts the domain from a
// "DKIM: signing failed for {domain}" log line: the text after the marker
// up to the first ':', space, tab, or newline. Returns "" when absent.
func parseDKIMFailureDomain(line string) string {
	_, tail, found := strings.Cut(line, "DKIM: signing failed for ")
	if !found {
		return ""
	}
	if end := strings.IndexAny(tail, ": \t\n"); end >= 0 {
		return tail[:end]
	}
	return tail
}
// parseSPFDMARCRejection extracts the SENDER domain and rejection reason
// from exim " ** " permanent-failure lines. The sender is taken from the
// first <envelope_sender> AFTER the ** marker (avoiding earlier <> fields
// such as H=<hostname>); the reason is the text after the line's last " : ",
// capped at 200 bytes, and must look SPF/DMARC-related or both results are
// empty.
func parseSPFDMARCRejection(line string) (senderDomain, reason string) {
	_, afterStar, found := strings.Cut(line, " ** ")
	if !found {
		return "", ""
	}
	lt := strings.Index(afterStar, "<")
	gt := strings.Index(afterStar, ">")
	if lt < 0 || gt < 0 || gt <= lt+1 {
		return "", ""
	}
	sender := afterStar[lt+1 : gt]
	at := strings.LastIndexByte(sender, '@')
	if at < 0 || at >= len(sender)-1 {
		return "", ""
	}
	domain := sender[at+1:]
	sep := strings.LastIndex(line, " : ")
	if sep < 0 {
		return "", ""
	}
	why := strings.TrimSpace(line[sep+3:])
	if !isSPFDMARCRelated(why) {
		return "", ""
	}
	if len(why) > 200 {
		why = why[:200]
	}
	return domain, why
}

// isSPFDMARCRelated checks if a rejection reason is SPF/DMARC related.
// A generic 5.7.1 code alone is NOT sufficient - it must be accompanied by
// explicit authentication keywords.
func isSPFDMARCRelated(reason string) bool {
	if reason == "" {
		return false
	}
	lower := strings.ToLower(reason)
	// Direct keywords and the enhanced status codes that specifically mean
	// authentication failure.
	for _, needle := range []string{"spf", "dmarc", "dkim", "5.7.23", "5.7.25", "5.7.26"} {
		if strings.Contains(lower, needle) {
			return true
		}
	}
	if !strings.Contains(lower, "5.7.1") {
		return false
	}
	for _, hint := range []string{"authentication", "ptr record", "sender policy", "alignment"} {
		if strings.Contains(lower, hint) {
			return true
		}
	}
	return false
}
// isDedupExpired reports whether a stored RFC3339 timestamp is older than
// window. Unparseable values count as expired, so a corrupt dedup entry can
// never suppress alerts indefinitely.
func isDedupExpired(stored string, window time.Duration) bool {
	when, err := time.Parse(time.RFC3339, stored)
	if err != nil {
		return true
	}
	return time.Since(when) > window
}
// --- Outbound email rate limiting ---

// rateWindow tracks send timestamps for a single authenticated user.
// mu guards times and alerted; countInWindow and prune document that the
// caller must hold mu while invoking them.
type rateWindow struct {
	mu      sync.Mutex
	times   []time.Time // timestamps of accepted outbound messages
	alerted string      // last threshold level alerted ("warn" or "crit") - prevents repeated alerts per window
}
// add appends a timestamp to the window.
// NOTE(review): mutates rw.times without taking rw.mu - presumably the
// caller holds the lock, as countInWindow/prune require; confirm at the
// call sites (not visible here).
func (rw *rateWindow) add(t time.Time) {
	rw.times = append(rw.times, t)
}
// countInWindow returns the number of timestamps strictly newer than
// now minus window. Caller must hold rw.mu.
func (rw *rateWindow) countInWindow(now time.Time, window time.Duration) int {
	cutoff := now.Add(-window)
	count := 0
	for _, t := range rw.times {
		if t.After(cutoff) {
			count++
		}
	}
	return count
}
// prune removes timestamps older than the window duration, reusing the
// backing array. Caller must hold rw.mu.
// Note: prune itself does not touch the alerted flag; the eviction loop
// (evictEmailRateWindows) resets it separately once the window is empty.
func (rw *rateWindow) prune(now time.Time, window time.Duration) {
	cutoff := now.Add(-window)
	// Filter in place: kept shares rw.times' backing array.
	kept := rw.times[:0]
	for _, t := range rw.times {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	rw.times = kept
}
// emailRateWindows tracks per-user send rate windows. Entries are pruned and
// empty ones deleted by the eviction goroutine (StartEmailRateEviction).
var emailRateWindows sync.Map // map[string]*rateWindow
// extractAuthUser extracts the authenticated user from an exim "<=" acceptance
// line. It looks for A=dovecot_login:{user} or A=dovecot_plain:{user} and
// returns "" when the line is not an acceptance line or carries no auth marker.
func extractAuthUser(line string) string {
	if !strings.Contains(line, " <= ") {
		return ""
	}
	for _, marker := range []string{"A=dovecot_login:", "A=dovecot_plain:"} {
		pos := strings.Index(line, marker)
		if pos < 0 {
			continue
		}
		user := line[pos+len(marker):]
		// The username runs until the next whitespace (or end of line).
		if cut := strings.IndexAny(user, " \t\n"); cut >= 0 {
			user = user[:cut]
		}
		return user
	}
	return ""
}
// isHighVolumeSender reports whether user appears (case-insensitively) in the
// configured high-volume senders allowlist.
func isHighVolumeSender(user string, allowlist []string) bool {
	for i := range allowlist {
		if strings.EqualFold(allowlist[i], user) {
			return true
		}
	}
	return false
}
// extractDomainFromEmail returns the part after the final '@' of an email
// address, or "" when there is no '@' or nothing follows it.
func extractDomainFromEmail(email string) string {
	at := strings.LastIndexByte(email, '@')
	if at < 0 || at == len(email)-1 {
		return ""
	}
	return email[at+1:]
}
// hasRecentCompromisedFinding reports whether an email_compromised_account or
// email_spam_outbreak finding was recorded for domain within the last hour
// (used to suppress redundant rate alerts). Expired entries are removed as a
// side effect so the map does not grow unbounded.
func hasRecentCompromisedFinding(domain string) bool {
	emailRateSuppressed.mu.Lock()
	defer emailRateSuppressed.mu.Unlock()
	ts, ok := emailRateSuppressed.domains[domain]
	if !ok {
		return false
	}
	if time.Since(ts) < time.Hour {
		return true
	}
	// Entry aged out - drop it.
	delete(emailRateSuppressed.domains, domain)
	return false
}
// emailRateSuppressed tracks domains with recent compromised/spam findings.
// Guarded by mu; entries are considered stale after one hour and are removed
// either on lookup (hasRecentCompromisedFinding) or by the eviction pass.
var emailRateSuppressed = struct {
	mu      sync.Mutex
	domains map[string]time.Time
}{domains: make(map[string]time.Time)}
// RecordCompromisedDomain marks a domain as having a recent compromised
// finding so that subsequent rate-limit alerts for it are suppressed.
// Called from parseEximLogLine when email_compromised_account or
// email_spam_outbreak fires.
func RecordCompromisedDomain(domain string) {
	emailRateSuppressed.mu.Lock()
	emailRateSuppressed.domains[domain] = time.Now()
	emailRateSuppressed.mu.Unlock()
}
// checkEmailRate processes one accepted outbound email for the authenticated
// user and returns findings when the per-user send count inside the sliding
// window crosses the configured warn/critical thresholds.
//
// Dedup: at most one warning and one critical finding per burst, tracked via
// rw.alerted; the flag clears once the count drops back below the warn
// threshold so a later burst can alert again.
func checkEmailRate(user string, cfg *config.Config) []alert.Finding {
	// Guard: skip if thresholds are zero (misconfigured or disabled)
	if cfg.EmailProtection.RateWarnThreshold <= 0 || cfg.EmailProtection.RateCritThreshold <= 0 {
		return nil
	}
	// Allowlisted bulk senders are never rate-alerted.
	if isHighVolumeSender(user, cfg.EmailProtection.HighVolumeSenders) {
		return nil
	}
	// Load or create rate window for this user
	val, _ := emailRateWindows.LoadOrStore(user, &rateWindow{})
	rw := val.(*rateWindow)
	now := time.Now()
	windowDur := time.Duration(cfg.EmailProtection.RateWindowMin) * time.Minute
	rw.mu.Lock()
	defer rw.mu.Unlock()
	// Check domain suppression BEFORE adding to window - prevents
	// phantom rate inflation for suppressed domains.
	// (Lock order here is rw.mu then emailRateSuppressed.mu; no visible
	// caller takes them in the opposite order, so no deadlock.)
	domain := extractDomainFromEmail(user)
	if domain != "" && hasRecentCompromisedFinding(domain) {
		return nil
	}
	rw.add(now)
	count := rw.countInWindow(now, windowDur)
	// Reset alerted state when count drops below warn threshold -
	// allows re-alerting on the next burst after the window slides.
	if count < cfg.EmailProtection.RateWarnThreshold {
		rw.alerted = ""
	}
	var findings []alert.Finding
	if count >= cfg.EmailProtection.RateCritThreshold {
		if rw.alerted != "crit" {
			rw.alerted = "crit"
			findings = append(findings, alert.Finding{
				Severity: alert.Critical,
				Check:    "email_rate_critical",
				Message:  fmt.Sprintf("Email rate CRITICAL: %s sent %d messages in %d minutes (threshold: %d)", user, count, cfg.EmailProtection.RateWindowMin, cfg.EmailProtection.RateCritThreshold),
				Details:  fmt.Sprintf("User: %s\nMessages in window: %d\nWindow: %d minutes\nThreshold: %d", user, count, cfg.EmailProtection.RateWindowMin, cfg.EmailProtection.RateCritThreshold),
			})
		}
	} else if count >= cfg.EmailProtection.RateWarnThreshold {
		// A window that already alerted at "crit" is not downgraded to "warn".
		if rw.alerted != "warn" && rw.alerted != "crit" {
			rw.alerted = "warn"
			findings = append(findings, alert.Finding{
				Severity: alert.High,
				Check:    "email_rate_warning",
				Message:  fmt.Sprintf("Email rate WARNING: %s sent %d messages in %d minutes (threshold: %d)", user, count, cfg.EmailProtection.RateWindowMin, cfg.EmailProtection.RateWarnThreshold),
				Details:  fmt.Sprintf("User: %s\nMessages in window: %d\nWindow: %d minutes\nThreshold: %d", user, count, cfg.EmailProtection.RateWindowMin, cfg.EmailProtection.RateWarnThreshold),
			})
		}
	}
	return findings
}
// StartEmailRateEviction starts a background goroutine that prunes expired
// rate windows every 10 minutes (same pattern as StartModSecEviction).
// The goroutine exits when stopCh is closed.
func StartEmailRateEviction(stopCh <-chan struct{}) {
	go func() {
		tick := time.NewTicker(10 * time.Minute)
		defer tick.Stop()
		for {
			select {
			case now := <-tick.C:
				evictEmailRateWindows(now)
			case <-stopCh:
				return
			}
		}
	}()
}
// evictEmailRateWindows prunes all per-user rate windows, deletes entries
// that became empty, and expires stale entries in the suppressed-domains map.
// Invoked periodically by StartEmailRateEviction.
func evictEmailRateWindows(now time.Time) {
	// Use a generous 60-minute eviction window to avoid premature deletion.
	// The actual rate window is checked during rate evaluation.
	evictWindow := 60 * time.Minute
	emailRateWindows.Range(func(key, val any) bool {
		rw := val.(*rateWindow)
		rw.mu.Lock()
		rw.prune(now, evictWindow)
		empty := len(rw.times) == 0
		if empty {
			// Clear the alert state so a future burst can alert again.
			rw.alerted = ""
		}
		rw.mu.Unlock()
		if empty {
			// NOTE(review): between Unlock and Delete another goroutine may
			// LoadOrStore this same *rateWindow and record a send that is
			// then dropped along with the entry. Worst case one message goes
			// uncounted - looks benign, but worth confirming.
			emailRateWindows.Delete(key)
		}
		return true
	})
	// Also prune the suppressed domains map
	emailRateSuppressed.mu.Lock()
	for domain, ts := range emailRateSuppressed.domains {
		if time.Since(ts) > time.Hour {
			delete(emailRateSuppressed.domains, domain)
		}
	}
	emailRateSuppressed.mu.Unlock()
}
package emailav
import (
"encoding/binary"
"fmt"
"io"
"net"
"os"
"strings"
"time"
)
// ClamdScanner scans files via the clamd Unix socket using INSTREAM.
type ClamdScanner struct {
	socketPath string // clamd control socket, e.g. /var/run/clamav/clamd.ctl - TODO confirm default
}

// NewClamdScanner creates a ClamdScanner that connects to clamd at the given socket path.
func NewClamdScanner(socketPath string) *ClamdScanner {
	return &ClamdScanner{socketPath: socketPath}
}

// Name identifies this engine in scan results.
func (s *ClamdScanner) Name() string { return "clamav" }
// Available reports whether clamd is reachable on the configured Unix socket.
// It only verifies that a connection can be established.
func (s *ClamdScanner) Available() bool {
	conn, err := net.DialTimeout("unix", s.socketPath, 2*time.Second)
	if err != nil {
		return false
	}
	// Probe only; close immediately.
	_ = conn.Close()
	return true
}
// Scan sends the file at path to clamd via INSTREAM and returns the verdict.
//
// Protocol: "nINSTREAM\n" (the "n" prefix selects newline-terminated replies)
// followed by uint32-big-endian length-prefixed chunks and a zero-length
// terminator; clamd answers with a single newline-terminated line such as
// "stream: OK" or "stream: <sig> FOUND".
func (s *ClamdScanner) Scan(path string) (Verdict, error) {
	// #nosec G304 -- path is mail queue file path from mail scanner walk.
	f, err := os.Open(path)
	if err != nil {
		return Verdict{}, fmt.Errorf("opening file: %w", err)
	}
	defer f.Close()
	conn, err := net.DialTimeout("unix", s.socketPath, 5*time.Second)
	if err != nil {
		return Verdict{}, fmt.Errorf("connecting to clamd: %w", err)
	}
	defer func() { _ = conn.Close() }()
	// One overall deadline covers the command, the streamed body, and the reply.
	if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		return Verdict{}, fmt.Errorf("setting deadline: %w", err)
	}
	// Send INSTREAM command
	if _, err := conn.Write([]byte("nINSTREAM\n")); err != nil {
		return Verdict{}, fmt.Errorf("sending INSTREAM: %w", err)
	}
	// Stream file content in chunks
	buf := make([]byte, 8192)
	lenBuf := make([]byte, 4)
	for {
		n, readErr := f.Read(buf)
		if n > 0 {
			// #nosec G115 -- n is bounded by len(buf)=8192; fits in uint32.
			binary.BigEndian.PutUint32(lenBuf, uint32(n))
			if _, err := conn.Write(lenBuf); err != nil {
				return Verdict{}, fmt.Errorf("sending chunk length: %w", err)
			}
			if _, err := conn.Write(buf[:n]); err != nil {
				return Verdict{}, fmt.Errorf("sending chunk data: %w", err)
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return Verdict{}, fmt.Errorf("reading file: %w", readErr)
		}
	}
	// Send terminator (4 zero bytes)
	binary.BigEndian.PutUint32(lenBuf, 0)
	if _, err := conn.Write(lenBuf); err != nil {
		return Verdict{}, fmt.Errorf("sending terminator: %w", err)
	}
	// Read the reply until the terminating newline. A single Read on a
	// stream socket may return a partial response, so loop until we see
	// '\n', EOF, or an error.
	var resp strings.Builder
	chunk := make([]byte, 1024)
	for {
		n, readErr := conn.Read(chunk)
		if n > 0 {
			resp.Write(chunk[:n])
			if chunk[n-1] == '\n' {
				break
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return Verdict{}, fmt.Errorf("reading response: %w", readErr)
		}
	}
	return parseClamdResponse(resp.String())
}
// parseClamdResponse parses a clamd INSTREAM response line.
//
//	"stream: OK"                         -> clean
//	"stream: Win.Trojan.Agent-123 FOUND" -> infected
//
// Any other reply is treated as a protocol error.
func parseClamdResponse(resp string) (Verdict, error) {
	line := strings.TrimSpace(resp)
	switch {
	case strings.HasSuffix(line, "OK"):
		return Verdict{}, nil
	case strings.HasSuffix(line, "FOUND"):
		// Extract the signature from "stream: <sig> FOUND".
		sig := strings.TrimSuffix(strings.TrimPrefix(line, "stream: "), " FOUND")
		return Verdict{
			Infected:  true,
			Signature: sig,
			Severity:  "critical",
		}, nil
	default:
		return Verdict{}, fmt.Errorf("unexpected clamd response: %q", line)
	}
}
package emailav
import (
"context"
"fmt"
"os"
"sync"
"time"
emime "github.com/pidginhost/csm/internal/mime"
)
// Orchestrator runs multiple scanners in parallel against extracted email parts.
type Orchestrator struct {
	scanners    []Scanner     // engines to consult (e.g. clamav, yara-x)
	scanTimeout time.Duration // per-part budget shared by all engines
}

// NewOrchestrator creates an orchestrator with the given scanners and per-scan timeout.
func NewOrchestrator(scanners []Scanner, scanTimeout time.Duration) *Orchestrator {
	return &Orchestrator{
		scanners:    scanners,
		scanTimeout: scanTimeout,
	}
}
// ScanParts scans all extracted parts with all available engines.
// Fail-open: unavailable engines, timeouts, and errors are recorded but do not
// mark the message as infected.
func (o *Orchestrator) ScanParts(messageID string, parts []emime.ExtractedPart, partial bool) *ScanResult {
	result := &ScanResult{
		MessageID:         messageID,
		ScannedAt:         time.Now(),
		PartialExtraction: partial,
	}
	// Probe each engine once up front; per-part scans only use live engines.
	var available []Scanner
	for _, s := range o.scanners {
		if s.Available() {
			available = append(available, s)
			result.EnginesUsed = append(result.EnginesUsed, s.Name())
		} else {
			result.FailedEngines = append(result.FailedEngines, s.Name())
			fmt.Fprintf(os.Stderr, "[emailav] engine %s unavailable\n", s.Name())
		}
	}
	if len(available) == 0 {
		// fail-open: no engines available - rate-limit the warning
		result.AllEnginesDown = true
		return result
	}
	// Scan each part with all available engines
	for _, part := range parts {
		findings, timedOut := o.scanPart(part, available)
		result.Findings = append(result.Findings, findings...)
		result.TimedOutEngines = append(result.TimedOutEngines, timedOut...)
	}
	// A finding from any engine on any part marks the message infected.
	result.Infected = len(result.Findings) > 0
	return result
}
// scanPart scans a single part with all available engines concurrently.
// Returns findings and a list of engine names that timed out.
//
// Each engine runs in its own goroutine under a shared o.scanTimeout context.
// On timeout the wrapper reports the engine as timed out and returns, but the
// underlying Scan call cannot be interrupted: it finishes in the background
// and its result is discarded (done is buffered so that goroutine never
// blocks).
func (o *Orchestrator) scanPart(part emime.ExtractedPart, scanners []Scanner) ([]Finding, []string) {
	// scanResult pairs an engine's verdict with its error/timeout status.
	type scanResult struct {
		engine   string
		verdict  Verdict
		err      error
		timedOut bool
	}
	ctx, cancel := context.WithTimeout(context.Background(), o.scanTimeout)
	defer cancel()
	// Buffered to the number of scanners so every wrapper can send exactly
	// one result without blocking.
	results := make(chan scanResult, len(scanners))
	var wg sync.WaitGroup
	for _, s := range scanners {
		wg.Add(1)
		go func(scanner Scanner) {
			defer wg.Done()
			done := make(chan scanResult, 1)
			go func() {
				v, err := scanner.Scan(part.TempPath)
				done <- scanResult{engine: scanner.Name(), verdict: v, err: err}
			}()
			select {
			case r := <-done:
				results <- r
			case <-ctx.Done():
				results <- scanResult{engine: scanner.Name(), err: fmt.Errorf("scan timeout"), timedOut: true}
			}
		}(s)
	}
	// Close results channel when all scans complete
	go func() {
		wg.Wait()
		close(results)
	}()
	var findings []Finding
	var timedOut []string
	for r := range results {
		if r.err != nil {
			fmt.Fprintf(os.Stderr, "[emailav] %s scan error on %s: %v\n", r.engine, part.Filename, r.err)
			if r.timedOut {
				timedOut = append(timedOut, r.engine)
			}
			continue // fail-open
		}
		if r.verdict.Infected {
			f := Finding{
				Filename:  part.Filename,
				Engine:    r.engine,
				Signature: r.verdict.Signature,
				Severity:  r.verdict.Severity,
			}
			// Files extracted from archives are reported as "archive/file".
			if part.Nested {
				f.Filename = part.ArchiveName + "/" + part.Filename
			}
			findings = append(findings, f)
		}
	}
	return findings, timedOut
}
package emailav
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// movedFile records one completed spool-file move for rollback purposes.
type movedFile struct {
	src string
	dst string
}

// QuarantineEnvelope holds the email envelope info for quarantine metadata.
// Fields are copied verbatim into QuarantineMetadata.
type QuarantineEnvelope struct {
	From      string
	To        []string
	Subject   string
	Direction string
}

// QuarantineMetadata is the JSON sidecar written with quarantined messages.
type QuarantineMetadata struct {
	MessageID        string    `json:"message_id"`
	Direction        string    `json:"direction"`
	From             string    `json:"from"`
	To               []string  `json:"to"`
	Subject          string    `json:"subject"`
	QuarantinedAt    time.Time `json:"quarantined_at"`
	OriginalSpoolDir string    `json:"original_spool_dir"` // where ReleaseMessage restores files to
	Findings         []Finding `json:"findings"`
	PartialScan      bool      `json:"partial_scan"`
	EnginesUsed      []string  `json:"engines_used"`
}

// Quarantine manages the per-message email quarantine directory.
type Quarantine struct {
	baseDir          string   // e.g. /opt/csm/quarantine/email
	allowedSpoolDirs []string // trusted release targets for ReleaseMessage
}

// NewQuarantine creates a quarantine manager for the given base directory.
// The allowed spool directories are fixed to the standard exim locations.
func NewQuarantine(baseDir string) *Quarantine {
	return &Quarantine{
		baseDir:          baseDir,
		allowedSpoolDirs: []string{"/var/spool/exim/input", "/var/spool/exim4/input"},
	}
}
// QuarantineMessage moves spool files into a per-message quarantine directory
// and writes metadata.json describing the scan outcome.
//
// Failure handling: if neither spool file can be moved, the (empty) directory
// is removed and an error is returned; if writing metadata.json fails, the
// already-moved files are rolled back to the spool and the directory removed.
// A single moved file (only -H or only -D present) is accepted as a partial
// quarantine.
func (q *Quarantine) QuarantineMessage(msgID, spoolDir string, result *ScanResult, env QuarantineEnvelope) error {
	msgID = filepath.Base(msgID) // sanitize against path traversal
	msgDir := filepath.Join(q.baseDir, msgID)
	meta := QuarantineMetadata{
		MessageID:        msgID,
		Direction:        env.Direction,
		From:             env.From,
		To:               env.To,
		Subject:          env.Subject,
		QuarantinedAt:    time.Now(),
		OriginalSpoolDir: spoolDir,
		Findings:         result.Findings,
		PartialScan:      result.PartialExtraction,
		EnginesUsed:      result.EnginesUsed,
	}
	// Marshal before touching the filesystem so a marshal error leaves no trace.
	metaData, err := json.MarshalIndent(meta, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling metadata: %w", err)
	}
	if err := os.MkdirAll(msgDir, 0700); err != nil {
		return fmt.Errorf("creating quarantine dir: %w", err)
	}
	var moved []movedFile
	// Move the exim spool file pair (-H and -D); missing files are skipped.
	for _, suffix := range []string{"-H", "-D"} {
		src := filepath.Join(spoolDir, msgID+suffix)
		dst := filepath.Join(msgDir, msgID+suffix)
		if err := moveFile(src, dst); err != nil {
			continue
		}
		moved = append(moved, movedFile{src: src, dst: dst})
	}
	if len(moved) == 0 {
		os.Remove(msgDir)
		return fmt.Errorf("no spool files found for %s", msgID)
	}
	metaPath := filepath.Join(msgDir, "metadata.json")
	if err := os.WriteFile(metaPath, metaData, 0600); err != nil {
		// Undo the moves so the message is not stranded without metadata.
		rollbackErr := rollbackMovedFiles(moved)
		_ = os.RemoveAll(msgDir)
		if rollbackErr != nil {
			return fmt.Errorf("writing metadata: %w (rollback failed: %v)", err, rollbackErr)
		}
		return fmt.Errorf("writing metadata: %w", err)
	}
	return nil
}
// ListMessages returns metadata for every quarantined email message.
// A missing quarantine directory yields an empty result; entries whose
// metadata cannot be read are skipped.
func (q *Quarantine) ListMessages() ([]QuarantineMetadata, error) {
	entries, err := os.ReadDir(q.baseDir)
	if os.IsNotExist(err) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("reading quarantine dir: %w", err)
	}
	var msgs []QuarantineMetadata
	for _, e := range entries {
		if !e.IsDir() {
			continue
		}
		if meta, readErr := q.readMetadata(e.Name()); readErr == nil {
			msgs = append(msgs, *meta)
		}
	}
	return msgs, nil
}
// GetMessage returns the metadata for a single quarantined message.
// The ID is sanitized with filepath.Base to prevent path traversal.
func (q *Quarantine) GetMessage(msgID string) (*QuarantineMetadata, error) {
	return q.readMetadata(filepath.Base(msgID))
}
// ReleaseMessage moves spool files back to the original spool directory
// and removes the quarantine directory. The spool directory recorded in the
// metadata is validated against the trusted allowlist before anything moves.
func (q *Quarantine) ReleaseMessage(msgID string) error {
	msgID = filepath.Base(msgID) // sanitize against path traversal
	meta, err := q.readMetadata(msgID)
	if err != nil {
		return fmt.Errorf("reading metadata: %w", err)
	}
	spoolDir, err := q.validateReleaseSpoolDir(meta.OriginalSpoolDir)
	if err != nil {
		return err
	}
	msgDir := filepath.Join(q.baseDir, msgID)
	// Restore the -H/-D spool file pair.
	for _, suffix := range []string{"-H", "-D"} {
		src := filepath.Join(msgDir, msgID+suffix)
		dst := filepath.Join(spoolDir, msgID+suffix)
		if err := moveFile(src, dst); err != nil {
			// If source doesn't exist, skip (partial quarantine)
			if os.IsNotExist(err) {
				continue
			}
			return fmt.Errorf("moving %s back to spool: %w", suffix, err)
		}
	}
	return os.RemoveAll(msgDir)
}
// DeleteMessage permanently removes a quarantined message directory,
// including its spool files and metadata sidecar. Removing a non-existent
// message is not an error (os.RemoveAll semantics).
func (q *Quarantine) DeleteMessage(msgID string) error {
	msgDir := filepath.Join(q.baseDir, filepath.Base(msgID)) // sanitize
	return os.RemoveAll(msgDir)
}
// CleanExpired removes quarantine directories whose metadata timestamp is
// older than maxAge and returns how many were removed. Entries with
// unreadable metadata are left untouched.
func (q *Quarantine) CleanExpired(maxAge time.Duration) (int, error) {
	entries, err := os.ReadDir(q.baseDir)
	if os.IsNotExist(err) {
		return 0, nil
	}
	if err != nil {
		return 0, err
	}
	deadline := time.Now().Add(-maxAge)
	removed := 0
	for _, e := range entries {
		if !e.IsDir() {
			continue
		}
		meta, metaErr := q.readMetadata(e.Name())
		if metaErr != nil {
			continue
		}
		if meta.QuarantinedAt.Before(deadline) {
			os.RemoveAll(filepath.Join(q.baseDir, e.Name()))
			removed++
		}
	}
	return removed, nil
}
// readMetadata loads and decodes metadata.json for the given message ID.
func (q *Quarantine) readMetadata(msgID string) (*QuarantineMetadata, error) {
	// Defense in depth: strip any path components from the supplied ID.
	metaPath := filepath.Join(q.baseDir, filepath.Base(msgID), "metadata.json")
	// #nosec G304 -- msgID sanitized with filepath.Base; filepath.Join under baseDir.
	raw, err := os.ReadFile(metaPath)
	if err != nil {
		return nil, err
	}
	meta := new(QuarantineMetadata)
	if err := json.Unmarshal(raw, meta); err != nil {
		return nil, err
	}
	return meta, nil
}
// moveFile renames src to dst, falling back to copy+delete for cross-device
// moves. Callers construct dst by filepath.Join under the quarantine base dir
// (config-owned) plus a filepath.Base-sanitized identifier.
func moveFile(src, dst string) error {
if err := os.Rename(src, dst); err != nil {
// Cross-device fallback
// #nosec G304 -- src is mail queue path from scanner walk.
data, readErr := os.ReadFile(src)
if readErr != nil {
return readErr
}
// #nosec G703 -- dst is constructed by the caller under the
// quarantine baseDir with filepath.Base applied to any user-
// supplied component (see readMetadata above for the pattern).
if writeErr := os.WriteFile(dst, data, 0600); writeErr != nil {
return writeErr
}
os.Remove(src)
}
return nil
}
// rollbackMovedFiles moves files back to their original locations in reverse
// order, stopping at the first failure.
func rollbackMovedFiles(moved []movedFile) error {
	for i := len(moved); i > 0; i-- {
		m := moved[i-1]
		if err := moveFile(m.dst, m.src); err != nil {
			return err
		}
	}
	return nil
}
// validateReleaseSpoolDir checks that the spool directory recorded in
// quarantine metadata is one of the trusted exim spool locations, resolving
// symlinks on both sides before comparing. It returns the resolved path.
func (q *Quarantine) validateReleaseSpoolDir(spoolDir string) (string, error) {
	clean := filepath.Clean(spoolDir)
	if clean == "" || !filepath.IsAbs(clean) {
		return "", fmt.Errorf("invalid original spool directory")
	}
	// Resolve symlinks where possible; fall back to the cleaned path.
	resolve := func(p string) string {
		if r, err := filepath.EvalSymlinks(p); err == nil {
			return r
		}
		return p
	}
	target := resolve(clean)
	for _, trusted := range q.allowedSpoolDirs {
		if target == resolve(filepath.Clean(trusted)) {
			return target, nil
		}
	}
	return "", fmt.Errorf("original spool directory is not trusted: %s", clean)
}
//go:build !yara

package emailav

// YaraXScanner is a no-op stub compiled when YARA-X support is not built in
// (i.e. the "yara" build tag is absent).
type YaraXScanner struct{}

// NewYaraXScanner returns a scanner that is never available.
// The argument is ignored; it mirrors the real constructor's signature.
func NewYaraXScanner(_ interface{}) *YaraXScanner {
	return &YaraXScanner{}
}

// Name identifies this engine in scan results.
func (s *YaraXScanner) Name() string { return "yara-x" }

// Available always reports false for the stub.
func (s *YaraXScanner) Available() bool { return false }

// Scan is a no-op; the stub never reports an infection.
func (s *YaraXScanner) Scan(_ string) (Verdict, error) {
	return Verdict{}, nil
}
package firewall
import (
"bufio"
"encoding/json"
"os"
"path/filepath"
"time"
)
// maxAuditFileSize is the rotation threshold for audit.jsonl.
const maxAuditFileSize = 10 * 1024 * 1024 // 10 MB

// AuditEntry records a firewall modification for compliance and forensics.
// Serialized as one JSON object per line (JSONL).
type AuditEntry struct {
	Timestamp time.Time `json:"timestamp"`
	Action    string    `json:"action"` // block, unblock, allow, remove_allow, flush, apply
	IP        string    `json:"ip,omitempty"`
	Reason    string    `json:"reason,omitempty"`
	Source    string    `json:"source,omitempty"`   // provenance; inferred via InferProvenance when empty
	Duration  string    `json:"duration,omitempty"` // Duration.String(), only set for positive durations
}
// AppendAudit writes an audit entry to the JSONL audit log.
// Rotates the log (single .1 backup) when it exceeds 10 MB. All errors are
// deliberately swallowed: auditing is best-effort and must never block the
// firewall operation being recorded.
//
// NOTE(review): this appends to <statePath>/audit.jsonl, while ReadAuditLog
// reads <statePath>/firewall/audit.jsonl. Presumably callers pass different
// base paths to the two functions - confirm, otherwise entries written here
// are never read back.
func AppendAudit(statePath, action, ip, reason, source string, duration time.Duration) {
	if source == "" {
		source = InferProvenance(action, reason)
	}
	entry := AuditEntry{
		Timestamp: time.Now(),
		Action:    action,
		IP:        ip,
		Reason:    reason,
		Source:    source,
	}
	if duration > 0 {
		entry.Duration = duration.String()
	}
	path := filepath.Join(statePath, "audit.jsonl")
	data, err := json.Marshal(entry)
	if err != nil {
		return
	}
	data = append(data, '\n')
	// Rotate if file exceeds max size
	if info, statErr := os.Stat(path); statErr == nil && info.Size() > maxAuditFileSize {
		_ = os.Rename(path, path+".1")
	}
	// #nosec G304 -- path is filepath.Join under operator-configured statePath.
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return
	}
	defer f.Close()
	_, _ = f.Write(data)
}
// ReadAuditLog returns the last limit audit entries from
// <statePath>/firewall/audit.jsonl (all entries when limit <= 0).
// Malformed lines are skipped; a missing log yields nil.
func ReadAuditLog(statePath string, limit int) []AuditEntry {
	path := filepath.Join(statePath, "firewall", "audit.jsonl")
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer f.Close()
	var entries []AuditEntry
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		var e AuditEntry
		if unmarshalErr := json.Unmarshal(sc.Bytes(), &e); unmarshalErr == nil {
			entries = append(entries, e)
		}
	}
	if limit > 0 && len(entries) > limit {
		entries = entries[len(entries)-limit:]
	}
	return entries
}
//go:build linux
package firewall
import (
"fmt"
"net"
"os"
"github.com/google/nftables"
)
// UpdateCloudflareSet flushes and repopulates the Cloudflare nftables sets
// with the given IPv4 and IPv6 CIDR lists. Invalid CIDRs are skipped.
// Each CIDR is stored as a half-open interval [first, lastIP+1).
//
// The IPv6 set is guarded against nil (it may never have been initialized,
// e.g. when dual-stack filtering is off); previously it was flushed and
// populated unconditionally, unlike the IPv4 set which was nil-checked.
func (e *Engine) UpdateCloudflareSet(ipv4, ipv6 []string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.setCFWhitelist == nil {
		return fmt.Errorf("cf_whitelist set not initialized")
	}
	// Flush existing entries
	e.conn.FlushSet(e.setCFWhitelist)
	if e.setCFWhitelist6 != nil {
		e.conn.FlushSet(e.setCFWhitelist6)
	}
	// Populate IPv4 CIDRs
	var elems4 []nftables.SetElement
	for _, cidr := range ipv4 {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			continue
		}
		start := network.IP.To4()
		end := lastIPInRange(network)
		if start == nil || end == nil {
			continue
		}
		elems4 = append(elems4,
			nftables.SetElement{Key: start},
			nftables.SetElement{Key: nextIP(end), IntervalEnd: true},
		)
	}
	if len(elems4) > 0 {
		if err := e.conn.SetAddElements(e.setCFWhitelist, elems4); err != nil {
			return fmt.Errorf("adding CF IPv4 elements: %w", err)
		}
	}
	// Populate IPv6 CIDRs (only when the IPv6 set exists)
	if e.setCFWhitelist6 != nil {
		var elems6 []nftables.SetElement
		for _, cidr := range ipv6 {
			_, network, err := net.ParseCIDR(cidr)
			if err != nil {
				continue
			}
			start := network.IP.To16()
			end := lastIPInRange(network)
			if start == nil || end == nil {
				continue
			}
			elems6 = append(elems6,
				nftables.SetElement{Key: start},
				nftables.SetElement{Key: nextIP(end), IntervalEnd: true},
			)
		}
		if len(elems6) > 0 {
			if err := e.conn.SetAddElements(e.setCFWhitelist6, elems6); err != nil {
				return fmt.Errorf("adding CF IPv6 elements: %w", err)
			}
		}
	}
	// Commit the batched netlink changes.
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing CF whitelist: %w", err)
	}
	fmt.Fprintf(os.Stderr, "firewall: cloudflare whitelist updated: %d IPv4, %d IPv6 CIDRs\n",
		len(ipv4), len(ipv6))
	return nil
}
// CloudflareIPs returns the currently configured Cloudflare CIDRs from the
// cached state file (see LoadCFState); it does not query the network.
func (e *Engine) CloudflareIPs() (ipv4, ipv6 []string) {
	return LoadCFState(e.statePath)
}
package firewall
import (
"bufio"
"fmt"
"net"
"net/http"
"os"
"strings"
"time"
)
// Official Cloudflare published IP range lists (plain text, one CIDR per line).
const (
	cfIPv4URL = "https://www.cloudflare.com/ips-v4"
	cfIPv6URL = "https://www.cloudflare.com/ips-v6"
)

// FetchCloudflareIPs downloads the current Cloudflare IP ranges.
// Both lists must fetch successfully or an error is returned (no partial
// result); a 30-second timeout covers each request.
func FetchCloudflareIPs() (ipv4, ipv6 []string, err error) {
	client := &http.Client{Timeout: 30 * time.Second}
	ipv4, err = fetchCIDRList(client, cfIPv4URL)
	if err != nil {
		return nil, nil, fmt.Errorf("fetching CF IPv4: %w", err)
	}
	ipv6, err = fetchCIDRList(client, cfIPv6URL)
	if err != nil {
		return nil, nil, fmt.Errorf("fetching CF IPv6: %w", err)
	}
	return ipv4, ipv6, nil
}
// fetchCIDRList downloads url and parses the body as one CIDR per line.
func fetchCIDRList(client *http.Client, url string) ([]string, error) {
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, url)
	}
	sc := bufio.NewScanner(resp.Body)
	return parseCloudflareResponse(sc), nil
}
// parseCloudflareResponse parses lines from a scanner, returning valid CIDRs.
// Skips blank lines, comments, and invalid entries.
func parseCloudflareResponse(scanner *bufio.Scanner) []string {
var cidrs []string
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
_, _, err := net.ParseCIDR(line)
if err != nil {
continue
}
cidrs = append(cidrs, line)
}
return cidrs
}
// SaveCFState persists the Cloudflare CIDRs for status display.
func SaveCFState(statePath string, ipv4, ipv6 []string, refreshed time.Time) {
path := statePath
if !strings.HasSuffix(path, "/firewall") {
path += "/firewall"
}
_ = os.MkdirAll(path, 0700)
var sb strings.Builder
fmt.Fprintf(&sb, "# refreshed: %s\n", refreshed.Format(time.RFC3339))
sb.WriteString("# ipv4\n")
for _, cidr := range ipv4 {
sb.WriteString(cidr)
sb.WriteByte('\n')
}
sb.WriteString("# ipv6\n")
for _, cidr := range ipv6 {
sb.WriteString(cidr)
sb.WriteByte('\n')
}
file := path + "/cf_whitelist.txt"
_ = os.WriteFile(file, []byte(sb.String()), 0600)
}
// LoadCFState reads the cached Cloudflare CIDRs.
func LoadCFState(statePath string) (ipv4, ipv6 []string) {
path := statePath
if !strings.HasSuffix(path, "/firewall") {
path += "/firewall"
}
file := path + "/cf_whitelist.txt"
// #nosec G304 -- fixed filename under operator-configured statePath.
f, err := os.Open(file)
if err != nil {
return nil, nil
}
defer f.Close()
scanner := bufio.NewScanner(f)
section := "ipv4"
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "# ipv4" {
section = "ipv4"
continue
}
if line == "# ipv6" {
section = "ipv6"
continue
}
if line == "" || strings.HasPrefix(line, "#") {
continue
}
switch section {
case "ipv4":
ipv4 = append(ipv4, line)
case "ipv6":
ipv6 = append(ipv6, line)
}
}
return ipv4, ipv6
}
// LoadCFRefreshTime reads the last CF refresh time from state.
func LoadCFRefreshTime(statePath string) time.Time {
path := statePath
if !strings.HasSuffix(path, "/firewall") {
path += "/firewall"
}
file := path + "/cf_whitelist.txt"
// #nosec G304 -- fixed filename under operator-configured statePath.
f, err := os.Open(file)
if err != nil {
return time.Time{}
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "# refreshed: ") {
ts := strings.TrimPrefix(line, "# refreshed: ")
if t, err := time.Parse(time.RFC3339, ts); err == nil {
return t
}
}
}
return time.Time{}
}
package firewall
// FirewallConfig defines the nftables firewall configuration.
type FirewallConfig struct {
	// Enabled controls whether the firewall ruleset is applied at all.
	Enabled bool `yaml:"enabled"`
	// Open ports (IPv4)
	TCPIn  []int `yaml:"tcp_in"`  // inbound TCP ports to accept
	TCPOut []int `yaml:"tcp_out"` // outbound TCP ports to allow
	UDPIn  []int `yaml:"udp_in"`
	UDPOut []int `yaml:"udp_out"`
	// IPv6 - enable dual-stack filtering
	IPv6    bool  `yaml:"ipv6"`
	TCP6In  []int `yaml:"tcp6_in"`  // if empty, uses tcp_in
	TCP6Out []int `yaml:"tcp6_out"` // if empty, uses tcp_out
	UDP6In  []int `yaml:"udp6_in"`  // if empty, uses udp_in
	UDP6Out []int `yaml:"udp6_out"` // if empty, uses udp_out
	// Ports restricted to infra IPs only
	RestrictedTCP []int `yaml:"restricted_tcp"`
	// Passive FTP range
	PassiveFTPStart int `yaml:"passive_ftp_start"`
	PassiveFTPEnd   int `yaml:"passive_ftp_end"`
	// Infra IPs (CIDR notation)
	InfraIPs []string `yaml:"infra_ips"`
	// Rate limiting (per-source-IP via nftables meters)
	ConnRateLimit      int  `yaml:"conn_rate_limit"` // new connections per minute per IP
	SYNFloodProtection bool `yaml:"syn_flood_protection"`
	ConnLimit          int  `yaml:"conn_limit"` // max concurrent connections per IP (0 = disabled)
	// Per-port flood protection - per-source rate limit per port
	PortFlood []PortFloodRule `yaml:"port_flood"`
	// UDP flood protection - per-source rate limit on UDP packets
	UDPFlood      bool `yaml:"udp_flood"`
	UDPFloodRate  int  `yaml:"udp_flood_rate"`  // packets per second
	UDPFloodBurst int  `yaml:"udp_flood_burst"` // burst allowance
	// Country blocking
	CountryBlock []string `yaml:"country_block"` // ISO country codes
	CountryDBPath string  `yaml:"country_db_path"`
	// Ports to drop silently without logging (reduces log noise from scanners)
	DropNoLog []int `yaml:"drop_nolog"`
	// Max blocked IPs (prevents memory exhaustion, 0 = unlimited)
	DenyIPLimit     int `yaml:"deny_ip_limit"`
	DenyTempIPLimit int `yaml:"deny_temp_ip_limit"`
	// Outbound SMTP restriction - block outgoing mail except from allowed users
	SMTPBlock      bool     `yaml:"smtp_block"`
	SMTPAllowUsers []string `yaml:"smtp_allow_users"` // usernames allowed to send
	SMTPPorts      []int    `yaml:"smtp_ports"`
	// Dynamic DNS - resolve hostnames to IPs, update allowed set periodically
	DynDNSHosts []string `yaml:"dyndns_hosts"`
	// Logging
	LogDropped bool `yaml:"log_dropped"`
	LogRate    int  `yaml:"log_rate"` // log entries per minute
}

// PortFloodRule defines per-port connection rate limiting.
type PortFloodRule struct {
	Port    int    `yaml:"port"`
	Proto   string `yaml:"proto"`   // "tcp" or "udp"
	Hits    int    `yaml:"hits"`    // max new connections
	Seconds int    `yaml:"seconds"` // time window in seconds
}
// DefaultConfig returns a sensible default firewall configuration
// matching a typical cPanel server. The firewall itself ships disabled
// (Enabled: false); the operator must opt in explicitly.
func DefaultConfig() *FirewallConfig {
	return &FirewallConfig{
		Enabled: false,
		// Standard cPanel service ports (FTP, mail, DNS, web, cPanel/WHM).
		TCPIn: []int{
			20, 21, 25, 26, 53, 80, 110, 143, 443, 465, 587,
			993, 995, 2077, 2078, 2079, 2080, 2082, 2083,
			2091, 2095, 2096,
		},
		TCPOut: []int{
			20, 21, 25, 26, 37, 43, 53, 80, 110, 113, 443,
			465, 587, 873, 993, 995, 2082, 2083, 2086, 2087,
			2089, 2195, 2325, 2703,
		},
		UDPIn:           []int{53, 443},
		UDPOut:          []int{53, 113, 123, 443, 873},
		RestrictedTCP:   []int{2086, 2087, 2325, 9443},
		PassiveFTPStart: 49152,
		PassiveFTPEnd:   65534,
		ConnRateLimit:      30,
		SYNFloodProtection: true,
		// Rate-limit new SMTP connections per source IP.
		PortFlood: []PortFloodRule{
			{Port: 25, Proto: "tcp", Hits: 40, Seconds: 300},
			{Port: 465, Proto: "tcp", Hits: 40, Seconds: 300},
			{Port: 587, Proto: "tcp", Hits: 40, Seconds: 300},
		},
		UDPFlood:      true,
		UDPFloodRate:  100,
		UDPFloodBurst: 500,
		// Commonly-probed legacy/service ports dropped without logging.
		DropNoLog:       []int{23, 67, 68, 111, 113, 135, 136, 137, 138, 139, 445, 500, 513, 520},
		DenyIPLimit:     3000,
		DenyTempIPLimit: 500,
		SMTPBlock:       false,
		SMTPPorts:       []int{25, 465, 587},
		LogDropped:      true,
		LogRate:         5,
	}
}
package firewall
import (
"fmt"
"net"
"os"
"sort"
"sync"
"time"
)
// DynDNSResolver periodically resolves hostnames and updates the firewall allowed set.
type DynDNSResolver struct {
	mu       sync.Mutex          // guards resolved
	hosts    []string            // hostnames to keep resolved
	resolved map[string][]string // hostname -> all resolved IPs currently allowed
	// engine is the minimal firewall surface needed: add/remove allow entries.
	engine interface {
		AllowIP(ip string, reason string) error
		RemoveAllowIPBySource(ip string, source string) error
	}
}

// NewDynDNSResolver creates a resolver for the given hostnames.
func NewDynDNSResolver(hosts []string, engine interface {
	AllowIP(ip string, reason string) error
	RemoveAllowIPBySource(ip string, source string) error
}) *DynDNSResolver {
	return &DynDNSResolver{
		hosts:    hosts,
		resolved: make(map[string][]string),
		engine:   engine,
	}
}
// Run starts the periodic resolver: one immediate resolution pass, then one
// every 5 minutes. Blocks until stopCh is closed.
func (d *DynDNSResolver) Run(stopCh <-chan struct{}) {
	d.resolveAll()
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			d.resolveAll()
		case <-stopCh:
			return
		}
	}
}
// resolveAll runs one resolution pass over every configured hostname.
func (d *DynDNSResolver) resolveAll() {
	for i := range d.hosts {
		d.resolveHost(d.hosts[i])
	}
}
// resolveHost resolves one hostname and reconciles the firewall allow set:
// IPs that dropped out of DNS lose their dyndns allow entry, new IPs gain
// one, and d.resolved is updated to the set of IPs that are actually allowed.
func (d *DynDNSResolver) resolveHost(host string) {
	ips, err := net.LookupHost(host)
	if err != nil || len(ips) == 0 {
		fmt.Fprintf(os.Stderr, "dyndns: failed to resolve %s: %v\n", host, err)
		return
	}
	sort.Strings(ips)

	d.mu.Lock()
	previous := d.resolved[host]
	d.mu.Unlock()

	known := make(map[string]bool, len(previous))
	for _, ip := range previous {
		known[ip] = true
	}
	current := make(map[string]bool, len(ips))
	for _, ip := range ips {
		current[ip] = true
	}

	// Drop allow entries for IPs the hostname no longer resolves to.
	// Only the dyndns-sourced entry is removed; allows from other sources
	// for the same IP are untouched.
	for _, ip := range previous {
		if current[ip] {
			continue
		}
		_ = d.engine.RemoveAllowIPBySource(ip, SourceDynDNS)
		fmt.Fprintf(os.Stderr, "dyndns: %s removed %s (no longer resolves)\n", host, ip)
	}

	// Allow newly resolved IPs; IPs already allowed are kept as-is.
	reason := fmt.Sprintf("dyndns: %s", host)
	var kept []string
	for _, ip := range ips {
		if known[ip] {
			kept = append(kept, ip) // already allowed
			continue
		}
		if err := d.engine.AllowIP(ip, reason); err != nil {
			fmt.Fprintf(os.Stderr, "dyndns: error allowing %s (%s): %v\n", ip, host, err)
			continue
		}
		kept = append(kept, ip)
		fmt.Fprintf(os.Stderr, "dyndns: %s resolved to %s (added)\n", host, ip)
	}

	// Record only IPs that are actually in the firewall, so a failed
	// AllowIP is retried on the next cycle.
	d.mu.Lock()
	d.resolved[host] = kept
	d.mu.Unlock()
}
//go:build linux
package firewall
import (
"encoding/json"
"fmt"
"net"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
)
// Engine manages the nftables firewall ruleset via netlink.
// All set/rule mutations go through one nftables.Conn; mu serializes
// concurrent modifications so batched transactions don't interleave.
type Engine struct {
	mu   sync.Mutex
	conn *nftables.Conn
	cfg  *FirewallConfig
	// Table and chains created by Apply (or discovered by ConnectExisting).
	table    *nftables.Table
	chainIn  *nftables.Chain
	chainOut *nftables.Chain
	// IPv4 named sets: per-IP blocks, CIDR blocks, allows, and infra IPs
	// (infra is matched before any block rule).
	setBlocked    *nftables.Set
	setBlockedNet *nftables.Set
	setAllowed    *nftables.Set
	setInfra      *nftables.Set
	setCountry    *nftables.Set
	// Cloudflare IP whitelist sets (interval for CIDR matching)
	setCFWhitelist  *nftables.Set // IPv4
	setCFWhitelist6 *nftables.Set // IPv6
	// IPv6 sets (nil if IPv6 disabled)
	setBlocked6    *nftables.Set
	setBlockedNet6 *nftables.Set
	setAllowed6    *nftables.Set
	setInfra6      *nftables.Set
	// Meters for per-IP rate limiting (dynamic sets updated by rules)
	meterSYN     *nftables.Set
	meterConn    *nftables.Set
	meterUDP     *nftables.Set
	meterConnlim *nftables.Set
	// statePath is the directory holding the persisted FirewallState.
	statePath string
}
// BlockedEntry represents a blocked IP with metadata, persisted so blocks
// survive a restart.
type BlockedEntry struct {
	IP        string    `json:"ip"`
	Reason    string    `json:"reason"`
	Source    string    `json:"source,omitempty"` // which subsystem added the block
	BlockedAt time.Time `json:"blocked_at"`
	ExpiresAt time.Time `json:"expires_at"` // zero = permanent
}
// AllowedEntry represents an allowed IP with metadata, persisted so allows
// survive a restart.
type AllowedEntry struct {
	IP        string    `json:"ip"`
	Reason    string    `json:"reason"`
	Source    string    `json:"source,omitempty"` // which subsystem added the allow
	Port      int       `json:"port,omitempty"`       // 0 = all ports
	ExpiresAt time.Time `json:"expires_at,omitempty"` // zero = permanent
}
// SubnetEntry represents a blocked CIDR range, persisted so subnet blocks
// survive a restart.
type SubnetEntry struct {
	CIDR      string    `json:"cidr"`
	Reason    string    `json:"reason"`
	Source    string    `json:"source,omitempty"` // which subsystem added the block
	BlockedAt time.Time `json:"blocked_at"`
	ExpiresAt time.Time `json:"expires_at,omitempty"` // zero = permanent
}
// PortAllowEntry represents a port-specific IP allow (e.g. tcp|in|d=PORT|s=IP),
// used for cases like granting one IP access to MySQL only.
type PortAllowEntry struct {
	IP     string `json:"ip"`
	Port   int    `json:"port"`
	Proto  string `json:"proto"` // "tcp" or "udp"
	Reason string `json:"reason"`
	Source string `json:"source,omitempty"` // which subsystem added the allow
}
// FirewallState is persisted to disk for restore on restart. Apply reads it
// back to repopulate the nftables sets and port-allow rules.
type FirewallState struct {
	Blocked     []BlockedEntry   `json:"blocked"`
	BlockedNet  []SubnetEntry    `json:"blocked_nets"`
	Allowed     []AllowedEntry   `json:"allowed"`
	PortAllowed []PortAllowEntry `json:"port_allowed"`
}
// portU16 converts an operator-configured int port to the uint16 nftables
// expects. Returns 0 for out-of-range values; 0 is an unroutable TCP/UDP port
// so a misconfigured rule fails closed (no traffic matches) rather than
// wrapping silently to a valid-but-wrong port.
func portU16(p int) uint16 {
	switch {
	case p < 0, p > 65535:
		return 0
	default:
		// #nosec G115 -- bounded above.
		return uint16(p)
	}
}
// NewEngine creates a new nftables firewall engine.
//
// statePath is the base state directory; persisted firewall state lives in
// its "firewall" subdirectory. The netlink connection is opened immediately,
// but no rules reach the kernel until Apply is called.
func NewEngine(cfg *FirewallConfig, statePath string) (*Engine, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}
	e := &Engine{
		conn:      conn,
		cfg:       cfg,
		statePath: filepath.Join(statePath, "firewall"),
	}
	// State persistence is best-effort: a failure here must not stop the
	// firewall from running, but silently discarding the error (as before)
	// made broken persistence invisible. Warn like the rest of this package.
	if err := os.MkdirAll(e.statePath, 0700); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: warning creating state dir %s: %v\n", e.statePath, err)
	}
	return e, nil
}
// ConnectExisting connects to an already-running CSM firewall.
// Used by CLI commands to modify the live ruleset without reapplying all rules.
// The four core sets must exist; Cloudflare and IPv6 sets are looked up
// opportunistically and left nil when absent.
func ConnectExisting(cfg *FirewallConfig, statePath string) (*Engine, error) {
	conn, err := nftables.New()
	if err != nil {
		return nil, fmt.Errorf("nftables connection: %w", err)
	}
	// Find existing CSM table.
	tables, err := conn.ListTables()
	if err != nil {
		return nil, fmt.Errorf("listing tables: %w", err)
	}
	var table *nftables.Table
	for _, cand := range tables {
		if cand.Name == "csm" && cand.Family == nftables.TableFamilyINet {
			table = cand
			break
		}
	}
	if table == nil {
		return nil, fmt.Errorf("CSM firewall not running (table 'csm' not found) - run 'csm firewall restart' first")
	}

	// required fails hard when a core set is missing from the live table.
	required := func(name string) (*nftables.Set, error) {
		s, lookupErr := conn.GetSetByName(table, name)
		if lookupErr != nil {
			return nil, fmt.Errorf("%s set not found: %w", name, lookupErr)
		}
		return s, nil
	}
	// optional returns nil when a set is absent (feature disabled).
	optional := func(name string) *nftables.Set {
		s, lookupErr := conn.GetSetByName(table, name)
		if lookupErr != nil {
			return nil
		}
		return s
	}

	setBlocked, err := required("blocked_ips")
	if err != nil {
		return nil, err
	}
	setBlockedNet, err := required("blocked_nets")
	if err != nil {
		return nil, err
	}
	setAllowed, err := required("allowed_ips")
	if err != nil {
		return nil, err
	}
	setInfra, err := required("infra_ips")
	if err != nil {
		return nil, err
	}

	e := &Engine{
		conn:          conn,
		cfg:           cfg,
		table:         table,
		setBlocked:    setBlocked,
		setBlockedNet: setBlockedNet,
		setAllowed:    setAllowed,
		setInfra:      setInfra,
		statePath:     filepath.Join(statePath, "firewall"),
	}
	// Cloudflare whitelist sets (optional).
	e.setCFWhitelist = optional("cf_whitelist")
	e.setCFWhitelist6 = optional("cf_whitelist6")
	// IPv6 sets (optional - may not exist if IPv6 disabled).
	e.setBlocked6 = optional("blocked_ips6")
	e.setBlockedNet6 = optional("blocked_nets6")
	e.setAllowed6 = optional("allowed_ips6")
	e.setInfra6 = optional("infra_ips6")
	return e, nil
}
// Apply builds and atomically applies the complete nftables ruleset.
// All operations (delete old table + create new table/rules) are batched
// into a single netlink transaction. If the flush fails, the kernel keeps
// whatever ruleset was running before - the server is never left without
// a firewall.
func (e *Engine) Apply() error {
	e.mu.Lock()
	defer e.mu.Unlock()

	// If a CSM table already exists, queue its deletion in the same atomic
	// batch as the replacement table.
	if tables, listErr := e.conn.ListTables(); listErr == nil {
		for _, tbl := range tables {
			if tbl.Name == "csm" && tbl.Family == nftables.TableFamilyINet {
				e.conn.DelTable(tbl)
				break
			}
		}
	}

	// Create the table - everything below is batched; nothing is sent to the
	// kernel until Flush().
	e.table = e.conn.AddTable(&nftables.Table{
		Name:   "csm",
		Family: nftables.TableFamilyINet,
	})

	// Sets first (rules reference them), then the chains.
	if err := e.createSets(); err != nil {
		return fmt.Errorf("creating sets: %w", err)
	}
	if err := e.createInputChain(); err != nil {
		return fmt.Errorf("creating input chain: %w", err)
	}
	if err := e.createOutputChain(); err != nil {
		return fmt.Errorf("creating output chain: %w", err)
	}

	// Apply atomically - on failure nftables keeps the previous ruleset.
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("applying ruleset: %w", err)
	}

	// Repopulate sets from persisted state; a load failure is non-fatal.
	if err := e.loadState(); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: warning loading state: %v\n", err)
	}
	return nil
}
// createSets creates the nftables named sets for IP management.
//
// Interval sets hold CIDR ranges as [start, first-address-after-end) element
// pairs (hence the IntervalEnd sentinel elements). Dynamic sets back the
// per-IP rate-limit meters and are populated by the kernel at packet time.
// IPv6 twins are only created when cfg.IPv6 is enabled. All AddSet calls are
// batched on e.conn; nothing reaches the kernel until the caller flushes.
func (e *Engine) createSets() error {
	// Blocked IPs set (with per-element timeout so blocks self-expire)
	e.setBlocked = &nftables.Set{
		Table:      e.table,
		Name:       "blocked_ips",
		KeyType:    nftables.TypeIPAddr,
		HasTimeout: true,
		Timeout:    24 * time.Hour,
	}
	if err := e.conn.AddSet(e.setBlocked, nil); err != nil {
		return fmt.Errorf("blocked set: %w", err)
	}
	// Blocked subnets set (interval for CIDR ranges, permanent)
	e.setBlockedNet = &nftables.Set{
		Table:    e.table,
		Name:     "blocked_nets",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setBlockedNet, nil); err != nil {
		return fmt.Errorf("blocked nets set: %w", err)
	}
	// Allowed IPs set
	e.setAllowed = &nftables.Set{
		Table:   e.table,
		Name:    "allowed_ips",
		KeyType: nftables.TypeIPAddr,
	}
	if err := e.conn.AddSet(e.setAllowed, nil); err != nil {
		return fmt.Errorf("allowed set: %w", err)
	}
	// Infra IPs set (interval for CIDR support) - split IPv4 and IPv6.
	// Each config entry may be a CIDR or a bare IP; bare IPs become
	// single-address intervals [ip, ip+1).
	e.setInfra = &nftables.Set{
		Table:    e.table,
		Name:     "infra_ips",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	var infraElements []nftables.SetElement
	var infra6Elements []nftables.SetElement
	for _, cidr := range e.cfg.InfraIPs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			// Not a CIDR - try as a single IP; skip unparseable entries.
			ip := net.ParseIP(cidr)
			if ip == nil {
				continue
			}
			if ip4 := ip.To4(); ip4 != nil {
				infraElements = append(infraElements,
					nftables.SetElement{Key: ip4},
					nftables.SetElement{Key: nextIP(ip4), IntervalEnd: true},
				)
			} else if e.cfg.IPv6 {
				ip16 := ip.To16()
				infra6Elements = append(infra6Elements,
					nftables.SetElement{Key: ip16},
					nftables.SetElement{Key: nextIP(ip16), IntervalEnd: true},
				)
			}
			continue
		}
		// CIDR entry: interval is [network start, address after last].
		if network.IP.To4() != nil {
			start := network.IP.To4()
			end := lastIPInRange(network)
			if start != nil && end != nil {
				infraElements = append(infraElements,
					nftables.SetElement{Key: start},
					nftables.SetElement{Key: nextIP(end), IntervalEnd: true},
				)
			}
		} else if e.cfg.IPv6 {
			start := network.IP.To16()
			end := lastIPInRange(network)
			if start != nil && end != nil {
				infra6Elements = append(infra6Elements,
					nftables.SetElement{Key: start},
					nftables.SetElement{Key: nextIP(end), IntervalEnd: true},
				)
			}
		}
	}
	if err := e.conn.AddSet(e.setInfra, infraElements); err != nil {
		return fmt.Errorf("infra set: %w", err)
	}
	// Cloudflare IP whitelist sets (interval for CIDR ranges, accept on 80/443 only)
	e.setCFWhitelist = &nftables.Set{
		Table:    e.table,
		Name:     "cf_whitelist",
		KeyType:  nftables.TypeIPAddr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setCFWhitelist, nil); err != nil {
		return fmt.Errorf("cf_whitelist set: %w", err)
	}
	e.setCFWhitelist6 = &nftables.Set{
		Table:    e.table,
		Name:     "cf_whitelist6",
		KeyType:  nftables.TypeIP6Addr,
		Interval: true,
	}
	if err := e.conn.AddSet(e.setCFWhitelist6, nil); err != nil {
		return fmt.Errorf("cf_whitelist6 set: %w", err)
	}
	// Country-blocked IPs set (interval for CIDR ranges). Best-effort: a
	// failure to create it only disables country blocking, not the firewall.
	if len(e.cfg.CountryBlock) > 0 && e.cfg.CountryDBPath != "" {
		e.setCountry = &nftables.Set{
			Table:    e.table,
			Name:     "country_blocked",
			KeyType:  nftables.TypeIPAddr,
			Interval: true,
		}
		var countryElements []nftables.SetElement
		for _, code := range e.cfg.CountryBlock {
			countryElements = append(countryElements, loadCountryCIDRs(e.cfg.CountryDBPath, code)...)
		}
		if err := e.conn.AddSet(e.setCountry, countryElements); err != nil {
			fmt.Fprintf(os.Stderr, "firewall: warning creating country set: %v\n", err)
			e.setCountry = nil
		} else if len(countryElements) > 0 {
			// Two elements (start + end sentinel) per range, hence /2.
			fmt.Fprintf(os.Stderr, "firewall: loaded %d country block ranges for %v\n",
				len(countryElements)/2, e.cfg.CountryBlock)
		}
	}
	// IPv6 sets - mirrors of the IPv4 sets above.
	if e.cfg.IPv6 {
		e.setBlocked6 = &nftables.Set{
			Table: e.table, Name: "blocked_ips6",
			KeyType: nftables.TypeIP6Addr, HasTimeout: true, Timeout: 24 * time.Hour,
		}
		if err := e.conn.AddSet(e.setBlocked6, nil); err != nil {
			return fmt.Errorf("blocked6 set: %w", err)
		}
		e.setBlockedNet6 = &nftables.Set{
			Table: e.table, Name: "blocked_nets6",
			KeyType: nftables.TypeIP6Addr, Interval: true,
		}
		if err := e.conn.AddSet(e.setBlockedNet6, nil); err != nil {
			return fmt.Errorf("blocked_nets6 set: %w", err)
		}
		e.setAllowed6 = &nftables.Set{
			Table: e.table, Name: "allowed_ips6",
			KeyType: nftables.TypeIP6Addr,
		}
		if err := e.conn.AddSet(e.setAllowed6, nil); err != nil {
			return fmt.Errorf("allowed6 set: %w", err)
		}
		e.setInfra6 = &nftables.Set{
			Table: e.table, Name: "infra_ips6",
			KeyType: nftables.TypeIP6Addr, Interval: true,
		}
		// Seeded with the IPv6 infra intervals collected above.
		if err := e.conn.AddSet(e.setInfra6, infra6Elements); err != nil {
			return fmt.Errorf("infra6 set: %w", err)
		}
	}
	// Meter sets for per-IP rate limiting (dynamic sets, populated by the
	// kernel; one-minute element timeout keeps them from growing unbounded).
	if e.cfg.SYNFloodProtection {
		e.meterSYN = &nftables.Set{
			Table: e.table, Name: "meter_syn", KeyType: nftables.TypeIPAddr,
			Dynamic: true, HasTimeout: true, Timeout: time.Minute,
		}
		_ = e.conn.AddSet(e.meterSYN, nil)
	}
	if e.cfg.ConnRateLimit > 0 {
		e.meterConn = &nftables.Set{
			Table: e.table, Name: "meter_conn", KeyType: nftables.TypeIPAddr,
			Dynamic: true, HasTimeout: true, Timeout: time.Minute,
		}
		_ = e.conn.AddSet(e.meterConn, nil)
	}
	if e.cfg.UDPFlood && e.cfg.UDPFloodRate > 0 {
		e.meterUDP = &nftables.Set{
			Table: e.table, Name: "meter_udp", KeyType: nftables.TypeIPAddr,
			Dynamic: true, HasTimeout: true, Timeout: time.Minute,
		}
		_ = e.conn.AddSet(e.meterUDP, nil)
	}
	if e.cfg.ConnLimit > 0 {
		// No timeout: connlimit tracks live connection counts, not rates.
		e.meterConnlim = &nftables.Set{
			Table: e.table, Name: "meter_connlimit", KeyType: nftables.TypeIPAddr,
			Dynamic: true,
		}
		_ = e.conn.AddSet(e.meterConnlim, nil)
	}
	return nil
}
// createInputChain builds the input filter chain with proper rule ordering.
//
// Rule order IS the security model: established/loopback fast-path first,
// then infra allows (never blockable), then blocks, then explicit allows,
// then rate limiting, then the public port policy. Default policy is DROP,
// so anything not explicitly accepted is dropped. All rules are batched on
// e.conn; nothing is applied until the caller flushes.
func (e *Engine) createInputChain() error {
	policy := nftables.ChainPolicyDrop
	e.chainIn = e.conn.AddChain(&nftables.Chain{
		Name:     "input",
		Table:    e.table,
		Type:     nftables.ChainTypeFilter,
		Hooknum:  nftables.ChainHookInput,
		Priority: nftables.ChainPriorityFilter,
		Policy:   &policy,
	})
	// Rule 1: Allow established/related connections.
	// (ct state & (ESTABLISHED|RELATED)) != 0 -> accept
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
				Xor:            binaryutil.NativeEndian.PutUint32(0),
			},
			&expr.Cmp{
				Op:       expr.CmpOpNeq,
				Register: 1,
				Data:     binaryutil.NativeEndian.PutUint32(0),
			},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
	// Rule 2: Allow loopback (interface name is NUL-padded in nftables).
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     []byte("lo\x00"),
			},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
	// Rule 3: Drop INVALID conntrack state (malformed packets)
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
			&expr.Bitwise{
				SourceRegister: 1, DestRegister: 1, Len: 4,
				Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitINVALID),
				Xor:  binaryutil.NativeEndian.PutUint32(0),
			},
			&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
			&expr.Verdict{Kind: expr.VerdictDrop},
		},
	})
	// Rule 4: Allow infra IPs FIRST - infra must NEVER be blocked, even accidentally
	e.addSetMatchRule(e.setInfra, expr.VerdictAccept)
	e.addSetMatchRuleV6(e.setInfra6, expr.VerdictAccept)
	// Rule 4b: Cloudflare IP whitelist - accept on TCP 80/443 only.
	// CF IPs can still be blocked on other ports (unlike infra).
	e.addCFWhitelistRule(e.setCFWhitelist, false)
	e.addCFWhitelistRule(e.setCFWhitelist6, true)
	// Rule 5: Drop blocked IPs (O(1) hash set lookup)
	e.addSetMatchRule(e.setBlocked, expr.VerdictDrop)
	e.addSetMatchRuleV6(e.setBlocked6, expr.VerdictDrop)
	// Rule 6: Drop blocked subnets (interval set for CIDR ranges)
	e.addSetMatchRule(e.setBlockedNet, expr.VerdictDrop)
	e.addSetMatchRuleV6(e.setBlockedNet6, expr.VerdictDrop)
	// Rule 7: Allow explicitly allowed IPs
	e.addSetMatchRule(e.setAllowed, expr.VerdictAccept)
	e.addSetMatchRuleV6(e.setAllowed6, expr.VerdictAccept)
	// Rule 8: Port-specific allows (IP+port, e.g. MySQL access for specific IPs).
	// Read from persisted state; each entry becomes one match-all-of
	// (protocol, source IP, destination port) accept rule.
	state := e.loadStateFile()
	for _, pa := range state.PortAllowed {
		parsed := net.ParseIP(pa.IP)
		if parsed == nil {
			continue
		}
		proto := byte(6) // TCP
		if pa.Proto == "udp" {
			proto = 17
		}
		if ip4 := parsed.To4(); ip4 != nil {
			// IPv4 source address = network header offset 12, len 4;
			// destination port = transport header offset 2, len 2.
			e.conn.AddRule(&nftables.Rule{
				Table: e.table,
				Chain: e.chainIn,
				Exprs: []expr.Any{
					&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ip4},
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pa.Port))},
					&expr.Verdict{Kind: expr.VerdictAccept},
				},
			})
		} else if e.cfg.IPv6 {
			// IPv6 source address = network header offset 8, len 16;
			// NFPROTO check (10 = NFPROTO_IPV6) guards the payload load.
			ip16 := parsed.To16()
			e.conn.AddRule(&nftables.Rule{
				Table: e.table,
				Chain: e.chainIn,
				Exprs: []expr.Any{
					&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
					&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{10}},
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ip16},
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pa.Port))},
					&expr.Verdict{Kind: expr.VerdictAccept},
				},
			})
		}
	}
	// ICMPv4 echo-request (type 8)
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{1}}, // ICMPv4
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{8}}, // echo-request
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
	// ICMPv6 echo-request (type 128) + neighbor discovery (types 133-137, required for IPv6)
	if e.cfg.IPv6 {
		for _, icmp6Type := range []byte{128, 133, 134, 135, 136, 137} {
			e.conn.AddRule(&nftables.Rule{
				Table: e.table,
				Chain: e.chainIn,
				Exprs: []expr.Any{
					&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{58}}, // ICMPv6
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmp6Type}},
					&expr.Verdict{Kind: expr.VerdictAccept},
				},
			})
		}
	}
	// Per-IP SYN flood protection via meter: TCP flags byte (transport
	// offset 13) masked with SYN|ACK (0x12) must equal SYN alone (0x02),
	// i.e. new connection attempts only. Over-rate source IPs are dropped.
	if e.cfg.SYNFloodProtection && e.meterSYN != nil {
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 13, Len: 1},
				&expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 1, Mask: []byte{0x12}, Xor: []byte{0x00}},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{0x02}}, // SYN only
				// Load source IP for per-IP metering
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
				&expr.Dynset{
					SrcRegKey: 1,
					SetName:   e.meterSYN.Name,
					SetID:     e.meterSYN.ID,
					Operation: 1, // NFT_DYNSET_OP_UPDATE
					Exprs: []expr.Any{
						&expr.Limit{Type: expr.LimitTypePkts, Rate: 25, Unit: expr.LimitTimeSecond, Burst: 100, Over: true},
					},
				},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Per-IP new connection rate limit via meter
	if e.cfg.ConnRateLimit > 0 && e.meterConn != nil {
		// #nosec G115 -- ConnRateLimit is an operator-configured int (typical 10–1000);
		// /2 is non-negative and well below uint32 max.
		burst := uint32(e.cfg.ConnRateLimit / 2)
		if burst < 5 {
			burst = 5
		}
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
				&expr.Bitwise{
					SourceRegister: 1, DestRegister: 1, Len: 4,
					Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
					Xor:  binaryutil.NativeEndian.PutUint32(0),
				},
				&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
				// Load source IP for per-IP metering
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
				&expr.Dynset{
					SrcRegKey: 1,
					SetName:   e.meterConn.Name,
					SetID:     e.meterConn.ID,
					Operation: 1,
					Exprs: []expr.Any{
						&expr.Limit{Type: expr.LimitTypePkts, Rate: uint64(e.cfg.ConnRateLimit), Unit: expr.LimitTimeMinute, Burst: burst, Over: true},
					},
				},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Per-IP concurrent connection limit (CONNLIMIT)
	if e.cfg.ConnLimit > 0 && e.meterConnlim != nil {
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
				&expr.Bitwise{
					SourceRegister: 1, DestRegister: 1, Len: 4,
					Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
					Xor:  binaryutil.NativeEndian.PutUint32(0),
				},
				&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
				&expr.Dynset{
					SrcRegKey: 1,
					SetName:   e.meterConnlim.Name,
					SetID:     e.meterConnlim.ID,
					Operation: 1,
					Exprs: []expr.Any{
						// #nosec G115 -- ConnLimit is operator-configured non-negative int; fits in uint32.
						&expr.Connlimit{Count: uint32(e.cfg.ConnLimit), Flags: 1}, // 1 = over
					},
				},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Rule 9: Country block - drop traffic from blocked countries
	if e.setCountry != nil {
		e.addSetMatchRule(e.setCountry, expr.VerdictDrop)
	}
	// Per-port flood protection - rate limit new connections per port
	for _, pf := range e.cfg.PortFlood {
		if pf.Hits <= 0 || pf.Seconds <= 0 {
			continue
		}
		proto := byte(6) // TCP
		if pf.Proto == "udp" {
			proto = 17
		}
		// Convert hits/seconds to per-minute rate (multiply first to reduce truncation)
		ratePerMin := uint64(pf.Hits) * 60 / uint64(pf.Seconds)
		if ratePerMin < 1 {
			ratePerMin = 1
		}
		burst := uint32(ratePerMin / 4)
		if burst < 2 {
			burst = 2
		}
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
				&expr.Bitwise{
					SourceRegister: 1, DestRegister: 1, Len: 4,
					Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
					Xor:  binaryutil.NativeEndian.PutUint32(0),
				},
				&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(pf.Port))},
				&expr.Limit{Type: expr.LimitTypePkts, Rate: ratePerMin, Unit: expr.LimitTimeMinute, Burst: burst, Over: true},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Per-IP UDP flood protection via meter
	if e.cfg.UDPFlood && e.cfg.UDPFloodRate > 0 && e.meterUDP != nil {
		// #nosec G115 -- UDPFloodBurst is operator-configured non-negative int.
		burst := uint32(e.cfg.UDPFloodBurst)
		if burst < 10 {
			burst = 10
		}
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{17}}, // UDP
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
				&expr.Dynset{
					SrcRegKey: 1,
					SetName:   e.meterUDP.Name,
					SetID:     e.meterUDP.ID,
					Operation: 1,
					Exprs: []expr.Any{
						&expr.Limit{Type: expr.LimitTypePkts, Rate: uint64(e.cfg.UDPFloodRate), Unit: expr.LimitTimeSecond, Burst: burst, Over: true},
					},
				},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Build restricted port set - these are only reachable via infra IPs (rule 4)
	restricted := make(map[int]bool)
	for _, p := range e.cfg.RestrictedTCP {
		restricted[p] = true
	}
	// Open TCP ports (public) - restricted ports excluded
	for _, port := range e.cfg.TCPIn {
		if restricted[port] {
			continue
		}
		e.addPortAcceptRule(port, true)
	}
	// Open UDP ports (public)
	for _, port := range e.cfg.UDPIn {
		e.addPortAcceptRule(port, false)
	}
	// Passive FTP range
	if e.cfg.PassiveFTPStart > 0 && e.cfg.PassiveFTPEnd > 0 {
		e.addPortRangeAcceptRule(e.cfg.PassiveFTPStart, e.cfg.PassiveFTPEnd, true)
	}
	// Silent drop for commonly-scanned ports (no logging): one TCP and one
	// UDP drop rule per port so these never reach the logging rule below.
	for _, port := range e.cfg.DropNoLog {
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{17}}, // UDP
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
				&expr.Verdict{Kind: expr.VerdictDrop},
			},
		})
	}
	// Rate-limited log for remaining dropped packets (no verdict: logging
	// falls through to the chain's DROP policy)
	if e.cfg.LogDropped {
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainIn,
			Exprs: []expr.Any{
				&expr.Limit{
					Type:  expr.LimitTypePkts,
					Rate:  uint64(max(e.cfg.LogRate, 1)),
					Unit:  expr.LimitTimeMinute,
					Burst: 5,
				},
				&expr.Log{Key: 1, Data: []byte("CSM-DROP: ")},
			},
		})
	}
	// Default policy is DROP - anything not matched above is dropped
	return nil
}
// createOutputChain builds the output filter chain.
// Restricts outbound to configured ports only (prevents C2 on non-standard
// ports). With no outbound ports configured the chain is accept-all;
// otherwise the policy is DROP and only the configured ports, safe ICMP,
// loopback, and established traffic get out. Rules are batched on e.conn
// until the caller flushes.
func (e *Engine) createOutputChain() error {
	if len(e.cfg.TCPOut) == 0 && len(e.cfg.UDPOut) == 0 {
		// No outbound restrictions configured - accept all
		policy := nftables.ChainPolicyAccept
		e.chainOut = e.conn.AddChain(&nftables.Chain{
			Name:     "output",
			Table:    e.table,
			Type:     nftables.ChainTypeFilter,
			Hooknum:  nftables.ChainHookOutput,
			Priority: nftables.ChainPriorityFilter,
			Policy:   &policy,
		})
		return nil
	}
	// Outbound filtering enabled
	policy := nftables.ChainPolicyDrop
	e.chainOut = e.conn.AddChain(&nftables.Chain{
		Name:     "output",
		Table:    e.table,
		Type:     nftables.ChainTypeFilter,
		Hooknum:  nftables.ChainHookOutput,
		Priority: nftables.ChainPriorityFilter,
		Policy:   &policy,
	})
	// Allow established/related outbound
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainOut,
		Exprs: []expr.Any{
			&expr.Ct{Register: 1, SourceRegister: false, Key: expr.CtKeySTATE},
			&expr.Bitwise{
				SourceRegister: 1,
				DestRegister:   1,
				Len:            4,
				Mask:           binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
				Xor:            binaryutil.NativeEndian.PutUint32(0),
			},
			&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(0)},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
	// Allow loopback outbound
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainOut,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte("lo\x00")},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
	// SMTP block - restrict outbound mail to allowed users only.
	// Per-UID accept rules (socket UID match) are emitted before the
	// per-port drop, so only the listed users (plus root) can send mail.
	smtpBlocked := make(map[int]bool)
	if e.cfg.SMTPBlock && len(e.cfg.SMTPPorts) > 0 {
		// Resolve usernames to UIDs; unknown users are skipped with a warning.
		var allowedUIDs []uint32
		for _, username := range e.cfg.SMTPAllowUsers {
			u, err := user.Lookup(username)
			if err != nil {
				fmt.Fprintf(os.Stderr, "firewall: smtp_allow_users: unknown user %q\n", username)
				continue
			}
			uid, parseErr := strconv.ParseUint(u.Uid, 10, 32)
			if parseErr != nil {
				fmt.Fprintf(os.Stderr, "firewall: smtp_allow_users: invalid uid for %s: %v\n", username, parseErr)
				continue
			}
			allowedUIDs = append(allowedUIDs, uint32(uid))
		}
		// Always allow root
		allowedUIDs = append(allowedUIDs, 0)
		for _, port := range e.cfg.SMTPPorts {
			smtpBlocked[port] = true
			// Accept from each allowed UID
			for _, uid := range allowedUIDs {
				e.conn.AddRule(&nftables.Rule{
					Table: e.table,
					Chain: e.chainOut,
					Exprs: []expr.Any{
						&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
						&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
						&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
						&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
						&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
						&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(uid)},
						&expr.Verdict{Kind: expr.VerdictAccept},
					},
				})
			}
			// Drop SMTP from everyone else
			e.conn.AddRule(&nftables.Rule{
				Table: e.table,
				Chain: e.chainOut,
				Exprs: []expr.Any{
					&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
					&expr.Verdict{Kind: expr.VerdictDrop},
				},
			})
		}
	}
	// Allow configured outbound TCP ports (skip SMTP-blocked ports - handled above)
	for _, port := range e.cfg.TCPOut {
		if smtpBlocked[port] {
			continue
		}
		e.addOutboundPortRule(port, true)
	}
	// Allow configured outbound UDP ports
	for _, port := range e.cfg.UDPOut {
		e.addOutboundPortRule(port, false)
	}
	// Allow only safe ICMP outbound (echo-reply + echo-request, block dest-unreachable)
	// Blocking ICMP type 3 (dest-unreachable) prevents leaking closed port info to scanners
	for _, icmpType := range []byte{0, 8} { // 0=echo-reply, 8=echo-request
		e.conn.AddRule(&nftables.Rule{
			Table: e.table,
			Chain: e.chainOut,
			Exprs: []expr.Any{
				&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{1}}, // ICMP
				&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
				&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmpType}},
				&expr.Verdict{Kind: expr.VerdictAccept},
			},
		})
	}
	// ICMPv6 outbound - allow echo-reply (129) + echo-request (128) + ND (133-137)
	if e.cfg.IPv6 {
		for _, icmp6Type := range []byte{128, 129, 133, 134, 135, 136, 137} {
			e.conn.AddRule(&nftables.Rule{
				Table: e.table,
				Chain: e.chainOut,
				Exprs: []expr.Any{
					&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{58}}, // ICMPv6
					&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 1},
					&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{icmp6Type}},
					&expr.Verdict{Kind: expr.VerdictAccept},
				},
			})
		}
	}
	// REJECT outbound TCP with RST (faster failure than silent DROP)
	// UDP still silently drops via chain policy.
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainOut,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}},
			&expr.Reject{Type: 1}, // NFT_REJECT_TCP_RST
		},
	})
	return nil
}
// --- Helper methods ---
// resolveIPSet returns the appropriate set and key bytes for an IP address:
// set4 with the 4-byte key for IPv4, set6 with the 16-byte key otherwise.
// Errors on unparseable input or when an IPv6 address arrives while set6 is
// nil (IPv6 disabled).
func (e *Engine) resolveIPSet(ip string, set4, set6 *nftables.Set) (*nftables.Set, []byte, error) {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return nil, nil, fmt.Errorf("invalid IP: %s", ip)
	}
	if v4 := parsed.To4(); v4 != nil {
		return set4, v4, nil
	}
	if set6 == nil {
		return nil, nil, fmt.Errorf("IPv6 not enabled in firewall config: %s", ip)
	}
	return set6, parsed.To16(), nil
}
// addSetMatchRule installs an input-chain rule that matches the IPv4 source
// address (IP header bytes 12-15) against the given set and applies verdict.
func (e *Engine) addSetMatchRule(set *nftables.Set, verdict expr.VerdictKind) {
	exprs := []expr.Any{
		// Load the IPv4 source address into register 1.
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
		&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
		&expr.Verdict{Kind: verdict},
	}
	e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: exprs})
}
// addSetMatchRuleV6 installs an input-chain rule that matches the IPv6 source
// address (IPv6 header bytes 8-23) against the given set and applies verdict.
// No-op when the set is nil (IPv6 disabled).
func (e *Engine) addSetMatchRuleV6(set *nftables.Set, verdict expr.VerdictKind) {
	if set == nil {
		return
	}
	exprs := []expr.Any{
		// Restrict to IPv6 packets first, then load the source address.
		&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{10}}, // NFPROTO_IPV6
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
		&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
		&expr.Verdict{Kind: verdict},
	}
	e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: exprs})
}
// addCFWhitelistRule adds an accept rule for Cloudflare IPs on TCP ports 80 and 443.
// Equivalent to: ip saddr @cf_whitelist tcp dport {80, 443} accept
// No-op when the set is nil.
func (e *Engine) addCFWhitelistRule(set *nftables.Set, ipv6 bool) {
	if set == nil {
		return
	}
	// Only the source-address loader differs between families; everything
	// after the set lookup is identical.
	var prefix []expr.Any
	if ipv6 {
		prefix = []expr.Any{
			&expr.Meta{Key: expr.MetaKeyNFPROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{10}}, // NFPROTO_IPV6
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16},
		}
	} else {
		prefix = []expr.Any{
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4},
		}
	}
	for _, port := range []uint16{80, 443} {
		exprs := append(append([]expr.Any{}, prefix...),
			&expr.Lookup{SourceRegister: 1, SetName: set.Name, SetID: set.ID},
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{6}}, // TCP
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(port)},
			&expr.Verdict{Kind: expr.VerdictAccept},
		)
		e.conn.AddRule(&nftables.Rule{Table: e.table, Chain: e.chainIn, Exprs: exprs})
	}
}
// addPortAcceptRule installs an input-chain accept rule for one TCP or UDP
// destination port.
func (e *Engine) addPortAcceptRule(port int, tcp bool) {
	var proto byte = 17 // UDP
	if tcp {
		proto = 6 // TCP
	}
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
			// Destination port lives at transport-header offset 2.
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
}
// addPortRangeAcceptRule installs an input-chain accept rule for an inclusive
// destination-port range [startPort, endPort] on TCP or UDP.
func (e *Engine) addPortRangeAcceptRule(startPort, endPort int, tcp bool) {
	var proto byte = 17 // UDP
	if tcp {
		proto = 6 // TCP
	}
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainIn,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
			// Load the dest port once into reg1, then range-check it.
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
			&expr.Cmp{Op: expr.CmpOpGte, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(startPort))},
			&expr.Cmp{Op: expr.CmpOpLte, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(endPort))},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
}
// addOutboundPortRule installs an output-chain accept rule for one TCP or UDP
// destination port.
func (e *Engine) addOutboundPortRule(port int, tcp bool) {
	var proto byte = 17 // UDP
	if tcp {
		proto = 6 // TCP
	}
	e.conn.AddRule(&nftables.Rule{
		Table: e.table,
		Chain: e.chainOut,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
			&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
			&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(portU16(port))},
			&expr.Verdict{Kind: expr.VerdictAccept},
		},
	})
}
// --- Public API ---
// BlockIP adds an IP to the blocked set with optional timeout.
// timeout 0 = permanent block. The entry is persisted to the state file and
// an audit record is appended; infra IPs are refused to prevent admin lockout.
func (e *Engine) BlockIP(ip string, reason string, timeout time.Duration) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	targetSet, key, err := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
	if err != nil {
		return err
	}
	// resolveIPSet already validated the address, so parsed is non-nil.
	parsed := net.ParseIP(ip)
	// SAFETY: never block infra IPs - prevents admin lockout.
	for _, cidr := range e.cfg.InfraIPs {
		_, network, cidrErr := net.ParseCIDR(cidr)
		if cidrErr != nil {
			// Entry is a bare IP rather than a CIDR. Compare parsed addresses
			// so textual variants still match (e.g. "::1" vs
			// "0:0:0:0:0:0:0:1"); previously this was a raw string compare.
			if infra := net.ParseIP(cidr); infra != nil && infra.Equal(parsed) {
				return fmt.Errorf("refusing to block infra IP: %s", ip)
			}
			continue
		}
		if network.Contains(parsed) {
			return fmt.Errorf("refusing to block infra IP: %s (in %s)", ip, cidr)
		}
	}
	// Enforce configured deny-list size limits (permanent vs. temporary).
	if e.cfg.DenyIPLimit > 0 || e.cfg.DenyTempIPLimit > 0 {
		st := e.loadStateFile()
		perm, temp := 0, 0
		for _, b := range st.Blocked {
			if b.ExpiresAt.IsZero() {
				perm++
			} else {
				temp++
			}
		}
		if timeout == 0 && e.cfg.DenyIPLimit > 0 && perm >= e.cfg.DenyIPLimit {
			return fmt.Errorf("permanent deny limit reached (%d)", e.cfg.DenyIPLimit)
		}
		if timeout > 0 && e.cfg.DenyTempIPLimit > 0 && temp >= e.cfg.DenyTempIPLimit {
			return fmt.Errorf("temporary deny limit reached (%d)", e.cfg.DenyTempIPLimit)
		}
	}
	elem := []nftables.SetElement{{Key: key, Timeout: timeout}}
	if err := e.conn.SetAddElements(targetSet, elem); err != nil {
		return fmt.Errorf("adding to blocked set: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	// Persist - zero ExpiresAt means permanent.
	entry := BlockedEntry{
		IP:        ip,
		Reason:    reason,
		Source:    InferProvenance("block", reason),
		BlockedAt: time.Now(),
	}
	if timeout > 0 {
		entry.ExpiresAt = time.Now().Add(timeout)
	}
	e.saveBlockedEntry(entry)
	AppendAudit(e.statePath, "block", ip, reason, entry.Source, timeout)
	return nil
}
// UnblockIP removes an IP from the blocked set and from persisted state,
// recording the change in the audit log.
func (e *Engine) UnblockIP(ip string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	set, key, err := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
	if err != nil {
		return err
	}
	if err := e.conn.SetDeleteElements(set, []nftables.SetElement{{Key: key}}); err != nil {
		return fmt.Errorf("removing from blocked set: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	// Only touch persisted state after the kernel change committed.
	e.removeBlockedState(ip)
	AppendAudit(e.statePath, "unblock", ip, "", "", 0)
	return nil
}
// IsBlocked reports whether the IP is currently in the engine's blocked state.
// Uses the persisted state file (which is cleaned of expired entries on load).
func (e *Engine) IsBlocked(ip string) bool {
	e.mu.Lock()
	defer e.mu.Unlock()
	for _, b := range e.loadStateFile().Blocked {
		if b.IP == ip {
			return true
		}
	}
	return false
}
// AllowIP adds an IP to the allowed set and persists it.
// If the IP is currently blocked, the block is removed first.
func (e *Engine) AllowIP(ip string, reason string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	// The blocked-set lookup may fail (e.g. IPv6 disabled) - that's fine,
	// we only need it for the best-effort unblock below.
	blkSet, blkKey, _ := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
	alwSet, alwKey, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
	if err != nil {
		return err
	}
	// Batch: drop from blocked, add to allowed, then commit once.
	if blkSet != nil {
		_ = e.conn.SetDeleteElements(blkSet, []nftables.SetElement{{Key: blkKey}})
	}
	if err := e.conn.SetAddElements(alwSet, []nftables.SetElement{{Key: alwKey}}); err != nil {
		return fmt.Errorf("adding to allowed set: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	// State mutations only after successful flush.
	e.removeBlockedState(ip)
	entry := AllowedEntry{IP: ip, Reason: reason, Source: InferProvenance("allow", reason)}
	e.saveAllowedEntry(entry)
	AppendAudit(e.statePath, "allow", ip, reason, entry.Source, 0)
	return nil
}
// TempAllowIP adds a temporary allow with expiry. Uses the same allowed set
// but tracks expiry in state - CleanExpiredAllows removes them periodically.
func (e *Engine) TempAllowIP(ip string, reason string, timeout time.Duration) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	// Blocked-set lookup is best-effort; only the allowed-set lookup is fatal.
	blkSet, blkKey, _ := e.resolveIPSet(ip, e.setBlocked, e.setBlocked6)
	alwSet, alwKey, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
	if err != nil {
		return err
	}
	if blkSet != nil {
		_ = e.conn.SetDeleteElements(blkSet, []nftables.SetElement{{Key: blkKey}})
	}
	if err := e.conn.SetAddElements(alwSet, []nftables.SetElement{{Key: alwKey}}); err != nil {
		return fmt.Errorf("adding to allowed set: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	e.removeBlockedState(ip)
	entry := AllowedEntry{IP: ip, Reason: reason, Source: InferProvenance("temp_allow", reason)}
	if timeout > 0 {
		// Zero ExpiresAt would mean a permanent allow.
		entry.ExpiresAt = time.Now().Add(timeout)
	}
	e.saveAllowedEntry(entry)
	AppendAudit(e.statePath, "temp_allow", ip, reason, entry.Source, timeout)
	return nil
}
// CleanExpiredAllows removes expired temporary allows from the set and state.
// An IP is only removed from nftables if no non-expired entries remain for it.
// Called periodically by the daemon. Returns the number of expired entries.
func (e *Engine) CleanExpiredAllows() int {
	e.mu.Lock()
	defer e.mu.Unlock()
	state := e.loadStateFile()
	now := time.Now()
	var active []AllowedEntry
	expiredIPs := make(map[string]bool)
	removed := 0
	// Partition entries into expired vs. still-active; audit each expiry.
	// Zero ExpiresAt means a permanent allow and is always kept.
	for _, entry := range state.Allowed {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			expiredIPs[entry.IP] = true
			removed++
			AppendAudit(e.statePath, "temp_allow_expired", entry.IP, "", SourceSystem, 0)
		} else {
			active = append(active, entry)
		}
	}
	if removed > 0 {
		// Only remove from nftables if no active entries remain for the IP
		// (the same IP may be allowed by several sources).
		activeIPs := make(map[string]bool)
		for _, entry := range active {
			activeIPs[entry.IP] = true
		}
		for ip := range expiredIPs {
			if !activeIPs[ip] {
				if set, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6); err == nil {
					_ = e.conn.SetDeleteElements(set, []nftables.SetElement{{Key: key}})
				}
			}
		}
		if err := e.conn.Flush(); err != nil {
			fmt.Fprintf(os.Stderr, "firewall: error flushing expired allows: %v\n", err)
			return 0 // don't update state; will retry on next tick
		}
		// State is saved only after the kernel change committed.
		state.Allowed = active
		e.saveState(&state)
	}
	return removed
}
// CleanExpiredSubnets removes expired temporary subnet blocks from nftables
// and state. Called periodically by the daemon. Returns the number of expired
// entries removed.
func (e *Engine) CleanExpiredSubnets() int {
	e.mu.Lock()
	defer e.mu.Unlock()
	state := e.loadStateFile()
	now := time.Now()
	var active []SubnetEntry
	removed := 0
	for _, entry := range state.BlockedNet {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			// Queue the interval [start, next(end)) for deletion; committed
			// in one Flush below.
			if _, network, err := net.ParseCIDR(entry.CIDR); err == nil {
				if set, start, end := e.resolveSubnetSet(network); set != nil {
					_ = e.conn.SetDeleteElements(set, []nftables.SetElement{
						{Key: start},
						{Key: nextIP(end), IntervalEnd: true},
					})
				}
			}
			removed++
			AppendAudit(e.statePath, "temp_subnet_expired", entry.CIDR, "", SourceSystem, 0)
			continue
		}
		active = append(active, entry)
	}
	if removed > 0 {
		// Mirror CleanExpiredAllows: if the flush fails, keep the old state so
		// removal is retried on the next tick. Previously the error was
		// discarded and state saved anyway, leaking stale nftables intervals.
		if err := e.conn.Flush(); err != nil {
			fmt.Fprintf(os.Stderr, "firewall: error flushing expired subnets: %v\n", err)
			return 0
		}
		state.BlockedNet = active
		e.saveState(&state)
	}
	return removed
}
// RemoveAllowIP removes an IP from the allowed set and from persisted state.
func (e *Engine) RemoveAllowIP(ip string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	set, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
	if err != nil {
		return err
	}
	if err := e.conn.SetDeleteElements(set, []nftables.SetElement{{Key: key}}); err != nil {
		return fmt.Errorf("removing from allowed set: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	// State update only after the kernel change committed.
	e.removeAllowedState(ip)
	AppendAudit(e.statePath, "remove_allow", ip, "", "", 0)
	return nil
}
// RemoveAllowIPBySource removes only allow entries from a specific source.
// The IP is only removed from the nftables set if no other sources remain.
// NOTE(review): state is mutated (removeAllowedStateBySource persists) before
// the nftables delete+flush; if the flush fails, file state and the kernel
// set diverge until the next reconcile - confirm this is acceptable. The
// audit record is appended even when nothing matched.
func (e *Engine) RemoveAllowIPBySource(ip, source string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	ipGone := e.removeAllowedStateBySource(ip, source)
	if ipGone {
		// No sources left for this IP - drop it from the kernel set too.
		targetSet, key, err := e.resolveIPSet(ip, e.setAllowed, e.setAllowed6)
		if err != nil {
			return err
		}
		if err := e.conn.SetDeleteElements(targetSet, []nftables.SetElement{{Key: key}}); err != nil {
			return fmt.Errorf("removing from allowed set: %w", err)
		}
		if err := e.conn.Flush(); err != nil {
			return fmt.Errorf("flushing: %w", err)
		}
	}
	AppendAudit(e.statePath, "remove_allow", ip, "source: "+source, source, 0)
	return nil
}
// AllowIPPort adds a port-specific IP allow. The rule is persisted to state
// and applied on the next Apply(). For immediate effect, call Apply() after.
// proto must be "tcp" or "udp"; anything else defaults to "tcp".
func (e *Engine) AllowIPPort(ip string, port int, proto string, reason string) error {
	if net.ParseIP(ip) == nil {
		return fmt.Errorf("invalid IP: %s", ip)
	}
	if port < 1 || port > 65535 {
		return fmt.Errorf("invalid port: %d", port)
	}
	if proto != "tcp" && proto != "udp" {
		proto = "tcp"
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	st := e.loadStateFile()
	// Deduplicate: an identical ip/port/proto entry is a no-op.
	for _, existing := range st.PortAllowed {
		if existing.IP == ip && existing.Port == port && existing.Proto == proto {
			return nil // already exists
		}
	}
	// Classify once and reuse for both the state entry and the audit record
	// (previously computed twice, risking divergence if the classifier changes).
	source := InferProvenance("allow_port", reason)
	st.PortAllowed = append(st.PortAllowed, PortAllowEntry{
		IP: ip, Port: port, Proto: proto, Reason: reason, Source: source,
	})
	e.saveState(&st)
	AppendAudit(e.statePath, "allow_port", fmt.Sprintf("%s:%d/%s", ip, port, proto), reason, source, 0)
	return nil
}
// RemoveAllowIPPort removes a port-specific IP allow from state.
// Returns an error if no matching entry exists.
func (e *Engine) RemoveAllowIPPort(ip string, port int, proto string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	st := e.loadStateFile()
	var keep []PortAllowEntry
	matched := false
	for _, entry := range st.PortAllowed {
		if entry.IP == ip && entry.Port == port && entry.Proto == proto {
			matched = true
			continue // drop the matching entry
		}
		keep = append(keep, entry)
	}
	if !matched {
		return fmt.Errorf("port allow not found: %s:%d/%s", ip, port, proto)
	}
	st.PortAllowed = keep
	e.saveState(&st)
	AppendAudit(e.statePath, "remove_port_allow", fmt.Sprintf("%s:%d/%s", ip, port, proto), "", "", 0)
	return nil
}
// FlushBlocked removes all IPs from the blocked sets (both families) and
// clears the persisted blocked state, recording the count in the audit log.
func (e *Engine) FlushBlocked() error {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.conn.FlushSet(e.setBlocked)
	if e.setBlocked6 != nil {
		e.conn.FlushSet(e.setBlocked6)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing blocked set: %w", err)
	}
	// Clear persisted state only after the kernel change committed.
	state := e.loadStateFile()
	count := len(state.Blocked)
	state.Blocked = nil
	e.saveState(&state)
	AppendAudit(e.statePath, "flush", "", fmt.Sprintf("cleared %d entries", count), SourceSystem, 0)
	return nil
}
// BlockSubnet adds a CIDR range to the blocked subnets set (IPv4 or IPv6).
// timeout 0 = permanent block. The canonical network string is persisted and
// audited.
func (e *Engine) BlockSubnet(cidr string, reason string, timeout time.Duration) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR: %s", cidr)
	}
	set, start, end := e.resolveSubnetSet(network)
	if set == nil {
		return fmt.Errorf("no matching set for %s (IPv6 disabled?)", cidr)
	}
	// Interval sets use half-open [start, next(end)) element pairs.
	interval := []nftables.SetElement{
		{Key: start},
		{Key: nextIP(end), IntervalEnd: true},
	}
	if err := e.conn.SetAddElements(set, interval); err != nil {
		return fmt.Errorf("adding to blocked_nets: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	now := time.Now()
	entry := SubnetEntry{
		CIDR:      network.String(),
		Reason:    reason,
		Source:    InferProvenance("block_subnet", reason),
		BlockedAt: now,
	}
	if timeout > 0 {
		entry.ExpiresAt = now.Add(timeout)
	}
	e.saveSubnetEntry(entry)
	AppendAudit(e.statePath, "block_subnet", network.String(), reason, entry.Source, timeout)
	return nil
}
// UnblockSubnet removes a CIDR range from the blocked subnets set (IPv4 or
// IPv6) and from persisted state.
func (e *Engine) UnblockSubnet(cidr string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR: %s", cidr)
	}
	set, start, end := e.resolveSubnetSet(network)
	if set == nil {
		return fmt.Errorf("no matching set for %s (IPv6 disabled?)", cidr)
	}
	// Must mirror the exact [start, next(end)) pair used when adding.
	interval := []nftables.SetElement{
		{Key: start},
		{Key: nextIP(end), IntervalEnd: true},
	}
	if err := e.conn.SetDeleteElements(set, interval); err != nil {
		return fmt.Errorf("removing from blocked_nets: %w", err)
	}
	if err := e.conn.Flush(); err != nil {
		return fmt.Errorf("flushing: %w", err)
	}
	e.removeSubnetState(network.String())
	AppendAudit(e.statePath, "unblock_subnet", network.String(), "", "", 0)
	return nil
}
// resolveSubnetSet returns the correct blocked_nets set plus the start/end
// addresses for a CIDR. Returns (nil, nil, nil) for IPv6 networks when the
// IPv6 set is disabled.
func (e *Engine) resolveSubnetSet(network *net.IPNet) (*nftables.Set, net.IP, net.IP) {
	last := lastIPInRange(network)
	if v4 := network.IP.To4(); v4 != nil {
		return e.setBlockedNet, v4, last
	}
	if e.setBlockedNet6 == nil {
		return nil, nil, nil
	}
	return e.setBlockedNet6, network.IP.To16(), last
}
// Status returns current firewall configuration and state-file statistics.
func (e *Engine) Status() map[string]interface{} {
	// Take the engine lock for consistency with every other public method
	// that reads the state file (previously Status read it unlocked).
	e.mu.Lock()
	defer e.mu.Unlock()
	state := e.loadStateFile()
	return map[string]interface{}{
		"enabled":     e.cfg.Enabled,
		"tcp_in":      e.cfg.TCPIn,
		"tcp_out":     e.cfg.TCPOut,
		"udp_in":      e.cfg.UDPIn,
		"udp_out":     e.cfg.UDPOut,
		"infra_ips":   e.cfg.InfraIPs,
		"blocked":     len(state.Blocked),
		"allowed":     len(state.Allowed),
		"log_dropped": e.cfg.LogDropped,
	}
}
// --- State persistence ---
// loadState replays persisted firewall state (blocked IPs, allowed IPs,
// blocked subnets) into the nftables sets, e.g. after a restart. Expired
// entries and unparseable addresses are skipped. All additions are batched
// and committed with a single Flush at the end.
func (e *Engine) loadState() error {
	state := e.loadStateFile()
	now := time.Now()
	// Restore blocked IPs (skip expired, route to IPv4 or IPv6 set).
	for _, entry := range state.Blocked {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			continue
		}
		parsed := net.ParseIP(entry.IP)
		if parsed == nil {
			continue
		}
		// Re-derive the remaining nftables timeout from the stored expiry;
		// zero means a permanent block.
		timeout := time.Duration(0)
		if !entry.ExpiresAt.IsZero() {
			timeout = time.Until(entry.ExpiresAt)
		}
		if ip4 := parsed.To4(); ip4 != nil {
			_ = e.conn.SetAddElements(e.setBlocked, []nftables.SetElement{{Key: ip4, Timeout: timeout}})
		} else if e.setBlocked6 != nil {
			_ = e.conn.SetAddElements(e.setBlocked6, []nftables.SetElement{{Key: parsed.To16(), Timeout: timeout}})
		}
	}
	// Restore allowed IPs (skip expired, deduplicate across source entries,
	// route to IPv4 or IPv6 set).
	restoredAllowed := make(map[string]bool)
	for _, entry := range state.Allowed {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			continue
		}
		if restoredAllowed[entry.IP] {
			continue // already added to nftables from another source entry
		}
		parsed := net.ParseIP(entry.IP)
		if parsed == nil {
			continue
		}
		if ip4 := parsed.To4(); ip4 != nil {
			_ = e.conn.SetAddElements(e.setAllowed, []nftables.SetElement{{Key: ip4}})
		} else if e.setAllowed6 != nil {
			_ = e.conn.SetAddElements(e.setAllowed6, []nftables.SetElement{{Key: parsed.To16()}})
		}
		restoredAllowed[entry.IP] = true
	}
	// Restore blocked subnets as half-open [start, next(end)) intervals.
	for _, entry := range state.BlockedNet {
		if !entry.ExpiresAt.IsZero() && now.After(entry.ExpiresAt) {
			continue
		}
		_, network, err := net.ParseCIDR(entry.CIDR)
		if err != nil {
			continue
		}
		// lastIPInRange handles both families; compute it once instead of
		// recomputing in the IPv6 branch as before.
		end := lastIPInRange(network)
		if end == nil {
			continue
		}
		if start := network.IP.To4(); start != nil {
			_ = e.conn.SetAddElements(e.setBlockedNet, []nftables.SetElement{
				{Key: start},
				{Key: nextIP(end), IntervalEnd: true},
			})
		} else if e.setBlockedNet6 != nil {
			_ = e.conn.SetAddElements(e.setBlockedNet6, []nftables.SetElement{
				{Key: network.IP.To16()},
				{Key: nextIP(end), IntervalEnd: true},
			})
		}
	}
	return e.conn.Flush()
}
// loadStateFile reads and decodes {statePath}/state.json, dropping entries
// whose expiry has passed (zero ExpiresAt means permanent). Read/decode
// failures yield an empty state.
func (e *Engine) loadStateFile() FirewallState {
	var state FirewallState
	stateFile := filepath.Join(e.statePath, "state.json")
	if !fileExistsFirewall(stateFile) {
		return state
	}
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	data, _ := os.ReadFile(stateFile)
	_ = json.Unmarshal(data, &state)
	// Filter out expired entries from each list.
	now := time.Now()
	live := func(expires time.Time) bool {
		return expires.IsZero() || now.Before(expires)
	}
	var blocked []BlockedEntry
	for _, entry := range state.Blocked {
		if live(entry.ExpiresAt) {
			blocked = append(blocked, entry)
		}
	}
	state.Blocked = blocked
	var nets []SubnetEntry
	for _, entry := range state.BlockedNet {
		if live(entry.ExpiresAt) {
			nets = append(nets, entry)
		}
	}
	state.BlockedNet = nets
	var allowed []AllowedEntry
	for _, entry := range state.Allowed {
		if live(entry.ExpiresAt) {
			allowed = append(allowed, entry)
		}
	}
	state.Allowed = allowed
	return state
}
// saveState atomically persists the firewall state to {statePath}/state.json
// using write-to-temp-then-rename. Failures are logged to stderr instead of
// being silently discarded (state loss was previously invisible); callers
// treat persistence as best-effort, so no error is returned.
func (e *Engine) saveState(state *FirewallState) {
	path := filepath.Join(e.statePath, "state.json")
	data, err := json.MarshalIndent(state, "", " ")
	if err != nil {
		fmt.Fprintf(os.Stderr, "firewall: error encoding state: %v\n", err)
		return
	}
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: error writing state: %v\n", err)
		return
	}
	if err := os.Rename(tmpPath, path); err != nil {
		fmt.Fprintf(os.Stderr, "firewall: error renaming state: %v\n", err)
	}
}
// saveBlockedEntry persists a blocked-IP entry, replacing any existing entry
// for the same IP (so re-blocking refreshes reason/expiry). An empty Source
// is inferred from the reason text.
func (e *Engine) saveBlockedEntry(entry BlockedEntry) {
	if entry.Source == "" {
		entry.Source = InferProvenance("block", entry.Reason)
	}
	state := e.loadStateFile()
	for i := range state.Blocked {
		if state.Blocked[i].IP == entry.IP {
			state.Blocked[i] = entry // update in place
			e.saveState(&state)
			return
		}
	}
	state.Blocked = append(state.Blocked, entry)
	e.saveState(&state)
}
// removeBlockedState drops every blocked entry for the given IP from the
// persisted state file.
func (e *Engine) removeBlockedState(ip string) {
	state := e.loadStateFile()
	var keep []BlockedEntry
	for _, entry := range state.Blocked {
		if entry.IP == ip {
			continue
		}
		keep = append(keep, entry)
	}
	state.Blocked = keep
	e.saveState(&state)
}
// saveAllowedEntry persists an allowed-IP entry. The (IP, Source) pair is the
// identity: an existing entry from the same source is updated in place so its
// reason/expiry are refreshed; different sources keep separate entries.
func (e *Engine) saveAllowedEntry(entry AllowedEntry) {
	if entry.Source == "" {
		entry.Source = InferProvenance("allow", entry.Reason)
	}
	state := e.loadStateFile()
	for i := range state.Allowed {
		if state.Allowed[i].IP == entry.IP && state.Allowed[i].Source == entry.Source {
			state.Allowed[i] = entry // update reason/expiry for same source
			e.saveState(&state)
			return
		}
	}
	state.Allowed = append(state.Allowed, entry)
	e.saveState(&state)
}
// removeAllowedState drops every allowed entry for the given IP (all sources)
// from the persisted state file.
func (e *Engine) removeAllowedState(ip string) {
	state := e.loadStateFile()
	var keep []AllowedEntry
	for _, entry := range state.Allowed {
		if entry.IP == ip {
			continue
		}
		keep = append(keep, entry)
	}
	state.Allowed = keep
	e.saveState(&state)
}
// removeAllowedStateBySource removes only entries matching ip+source.
// Returns true if no entries remain for that IP (caller should remove from nftables).
// Returns false if the IP was not in state at all (no action needed), or if
// other sources still hold an allow for it.
func (e *Engine) removeAllowedStateBySource(ip, source string) bool {
	state := e.loadStateFile()
	var remaining []AllowedEntry
	found := false          // dropped at least one ip+source entry
	ipStillPresent := false // another source still allows this ip
	for _, entry := range state.Allowed {
		if entry.IP == ip && entry.Source == source {
			found = true
			continue // drop this entry
		}
		remaining = append(remaining, entry)
		if entry.IP == ip {
			ipStillPresent = true
		}
	}
	if !found {
		return false // IP+source not in state, nothing to do
	}
	state.Allowed = remaining
	e.saveState(&state)
	return !ipStillPresent
}
// saveSubnetEntry persists a blocked-subnet entry. An existing entry for the
// same CIDR is updated in place - consistent with saveBlockedEntry /
// saveAllowedEntry - so re-blocking a subnet refreshes its reason/expiry
// (previously the new entry was silently dropped and a refreshed ExpiresAt
// was never recorded). An empty Source is inferred from the reason text.
func (e *Engine) saveSubnetEntry(entry SubnetEntry) {
	if entry.Source == "" {
		entry.Source = InferProvenance("block_subnet", entry.Reason)
	}
	state := e.loadStateFile()
	for i := range state.BlockedNet {
		if state.BlockedNet[i].CIDR == entry.CIDR {
			state.BlockedNet[i] = entry
			e.saveState(&state)
			return
		}
	}
	state.BlockedNet = append(state.BlockedNet, entry)
	e.saveState(&state)
}
// removeSubnetState drops the entry for the given CIDR from the persisted
// blocked-subnets state.
func (e *Engine) removeSubnetState(cidr string) {
	state := e.loadStateFile()
	var keep []SubnetEntry
	for _, entry := range state.BlockedNet {
		if entry.CIDR == cidr {
			continue
		}
		keep = append(keep, entry)
	}
	state.BlockedNet = keep
	e.saveState(&state)
}
// IP helpers (nextIP, lastIPInRange, fileExistsFirewall) moved to ip_helpers.go (no build tag).
// loadCountryCIDRs reads CIDR ranges from a country file and converts them to
// nftables interval elements ([start, next(end)) pairs).
// Expected format: one CIDR per line in {dbPath}/{CODE}.cidr. Blank lines,
// "#" comments, unparseable lines, and non-IPv4 ranges are skipped.
func loadCountryCIDRs(dbPath, countryCode string) []nftables.SetElement {
	file := filepath.Join(dbPath, strings.ToUpper(countryCode)+".cidr")
	// #nosec G304 -- filepath.Join under operator-configured GeoIP dbPath.
	data, err := os.ReadFile(file)
	if err != nil {
		return nil
	}
	var out []nftables.SetElement
	for _, raw := range strings.Split(string(data), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		_, network, parseErr := net.ParseCIDR(line)
		if parseErr != nil {
			continue
		}
		start := network.IP.To4()
		end := lastIPInRange(network)
		if start == nil || end == nil {
			continue // IPv4-only loader; skip IPv6 or malformed ranges
		}
		out = append(out,
			nftables.SetElement{Key: start},
			nftables.SetElement{Key: nextIP(end), IntervalEnd: true},
		)
	}
	return out
}
package firewall
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// geoIPBaseURL is the per-country IPv4 CIDR list source; UpdateGeoIPDB appends
// "{cc}.cidr" to it.
const geoIPBaseURL = "https://raw.githubusercontent.com/herrbischoff/country-ip-blocks/master/ipv4/"
// UpdateGeoIPDB downloads country CIDR lists from a public source.
// Creates one file per country code: {dbPath}/{CC}.cidr
// Returns the number of country files successfully updated. Per-country
// failures are logged to stderr and skipped; only directory-setup errors are
// returned.
func UpdateGeoIPDB(dbPath string, countryCodes []string) (int, error) {
	if err := os.MkdirAll(dbPath, 0700); err != nil {
		return 0, fmt.Errorf("creating geoip directory: %w", err)
	}
	client := &http.Client{Timeout: 30 * time.Second}
	updated := 0
	for _, code := range countryCodes {
		code = strings.ToLower(strings.TrimSpace(code))
		if len(code) != 2 {
			continue // country codes are ISO 3166-1 alpha-2
		}
		url := geoIPBaseURL + code + ".cidr"
		resp, err := client.Get(url)
		if err != nil {
			fmt.Fprintf(os.Stderr, "geoip: error downloading %s: %v\n", code, err)
			continue
		}
		if resp.StatusCode != 200 {
			resp.Body.Close()
			fmt.Fprintf(os.Stderr, "geoip: %s returned HTTP %d\n", code, resp.StatusCode)
			continue
		}
		outPath := filepath.Join(dbPath, strings.ToUpper(code)+".cidr")
		tmpPath := outPath + ".tmp"
		// #nosec G304 -- filepath.Join under operator-configured dbPath; code from fixed list.
		f, err := os.Create(tmpPath)
		if err != nil {
			resp.Body.Close()
			continue
		}
		n, copyErr := io.Copy(f, resp.Body)
		closeErr := f.Close()
		resp.Body.Close()
		// BUGFIX: a failed/truncated download must never replace a good file.
		// Previously the io.Copy error was discarded, so a partial body that
		// happened to exceed 10 bytes would be renamed over the existing list.
		if copyErr != nil || closeErr != nil {
			os.Remove(tmpPath)
			fmt.Fprintf(os.Stderr, "geoip: error saving %s: copy=%v close=%v\n", code, copyErr, closeErr)
			continue
		}
		if n < 10 {
			os.Remove(tmpPath)
			fmt.Fprintf(os.Stderr, "geoip: %s too small (%d bytes), skipping\n", code, n)
			continue
		}
		if err := os.Rename(tmpPath, outPath); err != nil {
			fmt.Fprintf(os.Stderr, "geoip: error installing %s: %v\n", code, err)
			continue
		}
		updated++
		fmt.Fprintf(os.Stderr, "geoip: updated %s (%d bytes)\n", strings.ToUpper(code), n)
	}
	return updated, nil
}
// LookupIP finds which country CIDR files contain the given IP.
// Returns matching country codes (file basenames without the .cidr suffix).
// Only IPv4 addresses are supported; anything else yields nil.
func LookupIP(dbPath string, ip string) []string {
	parsed := net.ParseIP(ip)
	if parsed == nil {
		return nil
	}
	ip4 := parsed.To4()
	if ip4 == nil {
		return nil
	}
	entries, err := os.ReadDir(dbPath)
	if err != nil {
		return nil
	}
	var codes []string
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, ".cidr") {
			continue
		}
		if containsIP(filepath.Join(dbPath, name), ip4) {
			codes = append(codes, strings.TrimSuffix(name, ".cidr"))
		}
	}
	return codes
}
func containsIP(cidrFile string, ip net.IP) bool {
// #nosec G304 -- cidrFile is filepath.Join under operator-configured dbPath.
f, err := os.Open(cidrFile)
if err != nil {
return false
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
_, network, err := net.ParseCIDR(line)
if err != nil {
continue
}
if network.Contains(ip) {
return true
}
}
return false
}
package firewall
import (
"net"
"os"
)
// nextIP returns the IP address immediately following the given IP.
// Used for nftables interval set end markers.
func nextIP(ip net.IP) net.IP {
next := make(net.IP, len(ip))
copy(next, ip)
for i := len(next) - 1; i >= 0; i-- {
next[i]++
if next[i] != 0 {
break
}
}
return next
}
// lastIPInRange returns the last IP address in a CIDR range.
func lastIPInRange(network *net.IPNet) net.IP {
ip := network.IP.To4()
if ip == nil {
ip = network.IP.To16()
}
if ip == nil {
return nil
}
mask := network.Mask
last := make(net.IP, len(ip))
for i := range ip {
if i < len(mask) {
last[i] = ip[i] | ^mask[i]
} else {
last[i] = ip[i]
}
}
return last
}
// fileExistsFirewall checks if a file exists.
func fileExistsFirewall(path string) bool {
_, err := os.Stat(path)
return err == nil
}
package firewall
import "strings"
// Provenance labels recorded with firewall entries and audit events,
// identifying which subsystem created an allow/block entry.
const (
	SourceUnknown      = "unknown"       // reason/action could not be classified
	SourceWebUI        = "web_ui"        // created via the CSM web UI
	SourceCLI          = "cli"           // created via the command line
	SourceAutoResponse = "auto_response" // created by automatic blocking
	SourceChallenge    = "challenge"     // created by the challenge flow
	SourceWhitelist    = "whitelist"     // created by whitelist management
	SourceDynDNS       = "dyndns"        // created by dynamic-DNS tracking
	SourceSystem       = "system"        // internal housekeeping (expiry, flush)
)
// InferProvenance classifies a firewall entry source from structured action/reason text.
// This keeps provenance logic centralized instead of spreading fragile string checks
// throughout the web UI and firewall call sites.
// Matching is case-insensitive and first-match-wins, so case ordering matters
// (e.g. "dyndns:" is tested before the generic whitelist phrases).
func InferProvenance(action, reason string) string {
	action = strings.ToLower(strings.TrimSpace(action))
	reason = strings.ToLower(strings.TrimSpace(reason))
	switch {
	case strings.Contains(reason, "dyndns:"):
		return SourceDynDNS
	case strings.Contains(reason, "passed challenge"),
		strings.Contains(reason, "challenge-timeout"),
		strings.Contains(reason, "challenge timeout"):
		return SourceChallenge
	case strings.Contains(reason, "temp whitelist"),
		strings.Contains(reason, "whitelist"),
		strings.Contains(reason, "bulk whitelist"),
		strings.Contains(reason, "customer ip"):
		return SourceWhitelist
	// NOTE(review): "permbblock" looks like a historical typo of "permblock";
	// presumably kept so old reason strings still classify - confirm before
	// removing.
	case strings.Contains(reason, "auto-block"),
		strings.Contains(reason, "permbblock"),
		strings.Contains(reason, "permblock"),
		strings.Contains(reason, "auto-netblock"):
		return SourceAutoResponse
	case strings.Contains(reason, "via cli"):
		return SourceCLI
	case strings.Contains(reason, "via csm web ui"),
		strings.Contains(reason, "via ui"),
		strings.Contains(reason, "allowed from firewall lookup"),
		strings.Contains(reason, "manual block"):
		return SourceWebUI
	case action == "temp_allow_expired":
		return SourceSystem
	case action == "flush":
		return SourceSystem
	default:
		return SourceUnknown
	}
}
package firewall
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
// LoadState reads the firewall state file directly without requiring a running engine.
// Used by CLI commands that only need to display state.
// Note: callers that have access to the store package should check store.Global()
// first for bbolt-backed state. This function reads flat-file state.json only.
// A missing file yields an empty state; decode errors are returned.
func LoadState(statePath string) (*FirewallState, error) {
	stateFile := filepath.Join(statePath, "firewall", "state.json")
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	data, err := os.ReadFile(stateFile)
	if err != nil {
		if os.IsNotExist(err) {
			return &FirewallState{}, nil
		}
		return nil, err
	}
	var state FirewallState
	if err := json.Unmarshal(data, &state); err != nil {
		return nil, err
	}
	// Drop entries whose expiry has passed; zero ExpiresAt means permanent.
	now := time.Now()
	live := func(expires time.Time) bool {
		return expires.IsZero() || now.Before(expires)
	}
	var blocked []BlockedEntry
	for _, b := range state.Blocked {
		if live(b.ExpiresAt) {
			blocked = append(blocked, b)
		}
	}
	state.Blocked = blocked
	var nets []SubnetEntry
	for _, s := range state.BlockedNet {
		if live(s.ExpiresAt) {
			nets = append(nets, s)
		}
	}
	state.BlockedNet = nets
	var allowed []AllowedEntry
	for _, a := range state.Allowed {
		if live(a.ExpiresAt) {
			allowed = append(allowed, a)
		}
	}
	state.Allowed = allowed
	return &state, nil
}
// Package geoip provides IP geolocation via MaxMind GeoLite2 databases
// and on-demand RDAP lookups for detailed ISP/org information.
package geoip
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"os"
"path/filepath"
"sync"
"time"
"github.com/oschwald/maxminddb-golang/v2"
)
// Info contains geolocation and network information for an IP.
// Country/City/ASN fields come from the local MaxMind databases; the RDAP*
// fields are populated separately by the on-demand RDAP lookup path.
type Info struct {
	IP          string `json:"ip"`
	Country     string `json:"country"`           // ISO 3166-1 alpha-2 (e.g. "US")
	CountryName string `json:"country_name"`      // Full name (e.g. "United States")
	City        string `json:"city,omitempty"`
	ASN         uint   `json:"asn,omitempty"`     // Autonomous System Number
	ASOrg       string `json:"as_org,omitempty"`  // AS Organization (ISP)
	Network     string `json:"network,omitempty"` // CIDR range
	RDAPOrg     string `json:"rdap_org,omitempty"`     // Detailed org from RDAP (on-demand)
	RDAPName    string `json:"rdap_name,omitempty"`    // Network name from RDAP
	RDAPCountry string `json:"rdap_country,omitempty"` // Country from RDAP
}
// DB holds the MaxMind database readers plus a small RDAP response cache.
// mu guards the readers (swapped by Reload); rdapMu guards the RDAP cache.
type DB struct {
	mu     sync.RWMutex
	cityDB *maxminddb.Reader // GeoLite2-City.mmdb, may be nil
	asnDB  *maxminddb.Reader // GeoLite2-ASN.mmdb, may be nil
	dbDir  string            // directory the .mmdb files are loaded from
	rdapMu sync.Mutex
	// rdapTTL caches RDAP lookups keyed by IP; despite the name it maps to
	// full cache entries (result + fetch time), not TTL durations.
	rdapTTL map[string]rdapCacheEntry
}

// rdapCacheEntry is one cached RDAP result with its fetch timestamp
// (used by the lookup path to decide staleness).
type rdapCacheEntry struct {
	info    Info
	fetched time.Time
}
// MaxMind GeoLite2 record structures. Only the fields this package reads are
// declared; maxminddb decodes just these from the full records.
type cityRecord struct {
	Country struct {
		ISOCode string            `maxminddb:"iso_code"` // ISO 3166-1 alpha-2
		Names   map[string]string `maxminddb:"names"`    // localized country names
	} `maxminddb:"country"`
	City struct {
		Names map[string]string `maxminddb:"names"` // localized city names
	} `maxminddb:"city"`
}

// asnRecord mirrors the two fields of a GeoLite2-ASN record.
type asnRecord struct {
	ASN uint   `maxminddb:"autonomous_system_number"`
	Org string `maxminddb:"autonomous_system_organization"`
}
// Open loads MaxMind databases from the given directory.
// It looks for GeoLite2-City.mmdb and/or GeoLite2-ASN.mmdb; either may be
// absent. Returns nil when dbDir is empty or neither database opens, so
// callers degrade gracefully to a no-geoip mode.
func Open(dbDir string) *DB {
	if dbDir == "" {
		return nil
	}
	db := &DB{
		dbDir:   dbDir,
		rdapTTL: make(map[string]rdapCacheEntry),
	}
	// Try both editions; a failure to open one is not fatal.
	for _, name := range []string{"GeoLite2-City.mmdb", "GeoLite2-ASN.mmdb"} {
		full := filepath.Join(dbDir, name)
		reader, err := maxminddb.Open(full)
		if err != nil {
			continue
		}
		fmt.Fprintf(os.Stderr, "geoip: loaded %s\n", full)
		if name == "GeoLite2-City.mmdb" {
			db.cityDB = reader
		} else {
			db.asnDB = reader
		}
	}
	if db.cityDB == nil && db.asnDB == nil {
		fmt.Fprintf(os.Stderr, "geoip: no databases found in %s (download GeoLite2-City.mmdb and GeoLite2-ASN.mmdb)\n", dbDir)
		return nil
	}
	return db
}
// Close releases database resources. Safe to call on a nil receiver.
func (db *DB) Close() {
	if db == nil {
		return
	}
	db.mu.Lock()
	defer db.mu.Unlock()
	for _, reader := range []*maxminddb.Reader{db.cityDB, db.asnDB} {
		if reader != nil {
			_ = reader.Close()
		}
	}
}
// Reload opens new database readers and swaps them in atomically.
// Replacement readers are opened before the lock is taken, so lookups keep
// serving during the open. If both editions fail to open, the old readers
// stay in place and an error is returned; if only one fails, the other is
// swapped and a warning is printed.
func (db *DB) Reload() error {
	if db == nil {
		return fmt.Errorf("geoip: cannot reload nil DB")
	}
	// Open replacements before taking lock
	newCity, cityErr := maxminddb.Open(filepath.Join(db.dbDir, "GeoLite2-City.mmdb"))
	newASN, asnErr := maxminddb.Open(filepath.Join(db.dbDir, "GeoLite2-ASN.mmdb"))
	if cityErr != nil && asnErr != nil {
		return fmt.Errorf("geoip: reload failed - city: %v, asn: %v", cityErr, asnErr)
	}
	db.mu.Lock()
	if newCity != nil {
		if old := db.cityDB; old != nil {
			_ = old.Close()
		}
		db.cityDB = newCity
	}
	if newASN != nil {
		if old := db.asnDB; old != nil {
			_ = old.Close()
		}
		db.asnDB = newASN
	}
	db.mu.Unlock()
	// Warnings go out after releasing the lock.
	if cityErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: reload warning - city DB failed: %v\n", cityErr)
	}
	if asnErr != nil {
		fmt.Fprintf(os.Stderr, "geoip: reload warning - ASN DB failed: %v\n", asnErr)
	}
	return nil
}
// OpenFresh creates a new DB from databases on disk.
// Use when no DB existed at startup and databases have since been downloaded
// (e.g. after the first Update cycle succeeds).
// Returns nil if no databases found (same behavior as Open, which it
// delegates to directly).
func OpenFresh(dbDir string) *DB {
	return Open(dbDir)
}
// Lookup returns geolocation info for an IP from the local MaxMind
// databases. Fast (microseconds), no network calls. A nil receiver or an
// unparseable IP yields an Info with only the IP field populated.
func (db *DB) Lookup(ip string) Info {
	info := Info{IP: ip}
	if db == nil {
		return info
	}
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return info
	}
	db.mu.RLock()
	defer db.mu.RUnlock()
	if db.cityDB != nil {
		var city cityRecord
		res := db.cityDB.Lookup(addr)
		if res.Decode(&city) == nil {
			info.Country = city.Country.ISOCode
			info.CountryName = city.Country.Names["en"]
			info.City = city.City.Names["en"]
			if prefix := res.Prefix(); prefix.IsValid() {
				info.Network = prefix.String()
			}
		}
	}
	if db.asnDB != nil {
		var asn asnRecord
		res := db.asnDB.Lookup(addr)
		if res.Decode(&asn) == nil {
			info.ASN = asn.ASN
			info.ASOrg = asn.Org
		}
	}
	return info
}
// LookupWithRDAP returns geolocation info plus on-demand RDAP details.
// The RDAP lookup is a network call (bounded by fetchRDAP's timeout) and
// its result is cached per IP for 24 hours. Safe on a nil receiver: local
// lookup fields stay empty and the RDAP result is fetched uncached.
func (db *DB) LookupWithRDAP(ip string) Info {
	if db == nil {
		// Fix: previously this panicked dereferencing db.rdapMu on a nil
		// receiver (Lookup guards nil, but the cache access did not).
		// With no DB there is nowhere to cache, so fetch directly.
		info := Info{IP: ip}
		r := fetchRDAP(ip)
		info.RDAPOrg = r.RDAPOrg
		info.RDAPName = r.RDAPName
		info.RDAPCountry = r.RDAPCountry
		return info
	}
	info := db.Lookup(ip)
	// Serve from cache when the entry is still fresh.
	db.rdapMu.Lock()
	if cached, ok := db.rdapTTL[ip]; ok && time.Since(cached.fetched) < 24*time.Hour {
		db.rdapMu.Unlock()
		info.RDAPOrg = cached.info.RDAPOrg
		info.RDAPName = cached.info.RDAPName
		info.RDAPCountry = cached.info.RDAPCountry
		return info
	}
	db.rdapMu.Unlock()
	// Fetch outside the lock so a slow RDAP call doesn't serialize lookups.
	rdapInfo := fetchRDAP(ip)
	info.RDAPOrg = rdapInfo.RDAPOrg
	info.RDAPName = rdapInfo.RDAPName
	info.RDAPCountry = rdapInfo.RDAPCountry
	// Cache the result.
	db.rdapMu.Lock()
	db.rdapTTL[ip] = rdapCacheEntry{info: rdapInfo, fetched: time.Now()}
	// Best-effort eviction: only expired entries are dropped, so the map
	// may exceed 10000 while all entries are fresh (bounded by 24h churn).
	if len(db.rdapTTL) > 10000 {
		for k, v := range db.rdapTTL {
			if time.Since(v.fetched) > 24*time.Hour {
				delete(db.rdapTTL, k)
			}
		}
	}
	db.rdapMu.Unlock()
	return info
}
// fetchRDAP queries rdap.org (which redirects to the owning RIR) for
// registration details of ip. Best-effort: any transport, status, or
// decode failure returns a zero Info. The response body read is capped at
// 1MiB as a defense against a misbehaving endpoint.
func fetchRDAP(ip string) Info {
	var info Info
	url := fmt.Sprintf("https://rdap.org/ip/%s", ip)
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return info
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return info
	}
	var rdap struct {
		Name     string `json:"name"`
		Country  string `json:"country"`
		Handle   string `json:"handle"`
		Entities []struct {
			VCardArray []interface{} `json:"vcardArray"`
			Roles      []string      `json:"roles"`
		} `json:"entities"`
	}
	if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&rdap); err != nil {
		return info
	}
	info.RDAPName = rdap.Name
	info.RDAPCountry = rdap.Country
	// Extract the org name (the "fn" vCard property) from the first
	// registrant/abuse entity. Fix: previously later entities overwrote an
	// already-found org; now the first hit wins and the scan stops.
scan:
	for _, entity := range rdap.Entities {
		relevant := false
		for _, role := range entity.Roles {
			if role == "registrant" || role == "abuse" {
				relevant = true
				break
			}
		}
		if !relevant || len(entity.VCardArray) < 2 {
			continue
		}
		props, ok := entity.VCardArray[1].([]interface{})
		if !ok {
			continue
		}
		for _, prop := range props {
			arr, ok := prop.([]interface{})
			if !ok || len(arr) < 4 {
				continue
			}
			if name, _ := arr[0].(string); name != "fn" {
				continue
			}
			if val, ok := arr[3].(string); ok {
				info.RDAPOrg = val
				break scan
			}
		}
	}
	return info
}
package geoip
import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/oschwald/maxminddb-golang/v2"
)
const (
	// maxMindBaseURL is MaxMind's direct-download API root; updateEdition
	// appends "/<edition>/download?suffix=tar.gz".
	maxMindBaseURL = "https://download.maxmind.com/geoip/databases"
	// maxDownloadSize caps the compressed tar.gz we are willing to read.
	maxDownloadSize = 150 * 1024 * 1024 // 150MB
	// maxExtractedSize bounds the size of the .mmdb entry we extract from
	// the tar.gz. Real GeoLite2 files are ~70MB; 500MiB is generous. The
	// check sits on the tar header and prevents io.Copy from writing a
	// bomb-compressed entry to disk.
	maxExtractedSize = 500 * 1024 * 1024
	// downloadTimeout is the whole-request budget for each HEAD/GET call.
	downloadTimeout = 120 * time.Second
)
// EditionResult reports the outcome of updating a single GeoLite2 edition.
type EditionResult struct {
	Edition string // e.g. "GeoLite2-City"
	Status  string // "updated", "up_to_date", "error"
	Err     error  // nil unless Status == "error"
}
// Update downloads GeoLite2 databases from MaxMind's direct download API,
// returning one EditionResult per requested edition. Returns nil when
// either credential is empty (updates disabled).
func Update(dbDir, accountID, licenseKey string, editions []string) []EditionResult {
	if accountID == "" || licenseKey == "" {
		return nil
	}
	results := make([]EditionResult, len(editions))
	if err := os.MkdirAll(dbDir, 0700); err != nil {
		// Directory creation failed: every edition reports the same error.
		for i, ed := range editions {
			results[i] = EditionResult{Edition: ed, Status: "error", Err: fmt.Errorf("creating directory: %w", err)}
		}
		return results
	}
	client := &http.Client{Timeout: downloadTimeout}
	for i, edition := range editions {
		results[i] = updateEdition(client, dbDir, accountID, licenseKey, edition)
	}
	return results
}
// updateEdition updates one edition against the production MaxMind URL.
// It is a thin wrapper over updateEditionWithURL; the URL-taking variant
// presumably exists so tests can point at a stub server — confirm against
// the test suite.
func updateEdition(client *http.Client, dbDir, accountID, licenseKey, edition string) EditionResult {
	return updateEditionWithURL(client, dbDir, accountID, licenseKey, edition, maxMindBaseURL)
}
// updateEditionWithURL performs the check-and-download cycle for one
// edition against baseURL:
//
//  1. HEAD the download URL and compare its Last-Modified against the
//     stored marker file (.last-modified-<edition>); a match short-circuits
//     to "up_to_date".
//  2. GET the tar.gz, extract the .mmdb to "<edition>.mmdb.tmp", validate
//     that it opens as a database, then atomically rename into place.
//  3. Persist the new Last-Modified marker (best effort).
func updateEditionWithURL(client *http.Client, dbDir, accountID, licenseKey, edition, baseURL string) EditionResult {
	url := fmt.Sprintf("%s/%s/download?suffix=tar.gz", baseURL, edition)
	markerPath := filepath.Join(dbDir, ".last-modified-"+edition)
	// Read stored Last-Modified
	storedLM := ""
	// #nosec G304 -- filepath.Join under operator-configured dbDir.
	if data, err := os.ReadFile(markerPath); err == nil {
		storedLM = strings.TrimSpace(string(data))
	}
	// HEAD request to check Last-Modified before paying for the download
	headReq, err := http.NewRequest("HEAD", url, nil)
	if err != nil {
		return EditionResult{Edition: edition, Status: "error", Err: err}
	}
	headReq.SetBasicAuth(accountID, licenseKey)
	headResp, err := client.Do(headReq)
	if err != nil {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("HEAD request: %w", err)}
	}
	headResp.Body.Close()
	if headResp.StatusCode == 401 {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("invalid MaxMind credentials")}
	}
	if headResp.StatusCode == 429 {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("rate limited by MaxMind")}
	}
	if headResp.StatusCode != 200 {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("HEAD returned HTTP %d", headResp.StatusCode)}
	}
	remoteLM := headResp.Header.Get("Last-Modified")
	if storedLM != "" && remoteLM != "" && storedLM == remoteLM {
		return EditionResult{Edition: edition, Status: "up_to_date"}
	}
	// GET request to download
	getReq, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return EditionResult{Edition: edition, Status: "error", Err: err}
	}
	getReq.SetBasicAuth(accountID, licenseKey)
	getResp, err := client.Do(getReq)
	if err != nil {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("download: %w", err)}
	}
	defer getResp.Body.Close()
	if getResp.StatusCode != 200 {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("download returned HTTP %d", getResp.StatusCode)}
	}
	// Reject oversized responses upfront if Content-Length is known
	if getResp.ContentLength > maxDownloadSize {
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("download too large: %d bytes (max %d)", getResp.ContentLength, maxDownloadSize)}
	}
	// LimitReader as safety net for responses without Content-Length
	mmdbTmpPath := filepath.Join(dbDir, edition+".mmdb.tmp")
	if err := extractMMDB(io.LimitReader(getResp.Body, maxDownloadSize), mmdbTmpPath, edition); err != nil {
		os.Remove(mmdbTmpPath)
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("extract: %w", err)}
	}
	// Confirm the extracted file is a loadable database before installing it
	if err := validateMMDB(mmdbTmpPath); err != nil {
		os.Remove(mmdbTmpPath)
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("validate: %w", err)}
	}
	// Atomic install: rename within the same directory so readers never
	// observe a partially written database
	destPath := filepath.Join(dbDir, edition+".mmdb")
	if err := os.Rename(mmdbTmpPath, destPath); err != nil {
		os.Remove(mmdbTmpPath)
		return EditionResult{Edition: edition, Status: "error", Err: fmt.Errorf("install: %w", err)}
	}
	// Save Last-Modified marker (best effort - a write failure just means
	// one redundant re-download next cycle)
	if remoteLM != "" {
		_ = os.WriteFile(markerPath, []byte(remoteLM), 0600)
	}
	return EditionResult{Edition: edition, Status: "updated"}
}
// validateMMDB opens the .mmdb file at path with the maxminddb reader to
// confirm it is a well-formed database before it is installed.
func validateMMDB(path string) error {
	db, err := maxminddb.Open(path)
	if err != nil {
		return err
	}
	return db.Close()
}
// extractMMDB reads a tar.gz stream and extracts the .mmdb file to destPath.
// MaxMind tar.gz archives contain a single directory with the .mmdb inside,
// e.g. GeoLite2-City_20260328/GeoLite2-City.mmdb
func extractMMDB(r io.Reader, destPath, edition string) error {
gz, err := gzip.NewReader(r)
if err != nil {
return fmt.Errorf("gzip: %w", err)
}
defer func() { _ = gz.Close() }()
tr := tar.NewReader(gz)
suffix := edition + ".mmdb"
for {
header, err := tr.Next()
if err == io.EOF {
return fmt.Errorf("no %s found in archive", suffix)
}
if err != nil {
return fmt.Errorf("reading tar: %w", err)
}
if header.Typeflag != tar.TypeReg {
continue
}
if !strings.HasSuffix(header.Name, suffix) {
continue
}
if header.Size > maxExtractedSize {
return fmt.Errorf("archive entry too large: %d bytes", header.Size)
}
// #nosec G304 -- destPath is filepath.Join under operator-configured dbDir.
f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("creating %s: %w", destPath, err)
}
_, copyErr := io.Copy(f, io.LimitReader(tr, maxExtractedSize))
closeErr := f.Close()
if copyErr != nil {
return fmt.Errorf("writing mmdb: %w", copyErr)
}
if closeErr != nil {
return fmt.Errorf("closing mmdb: %w", closeErr)
}
return nil
}
}
package integrity
import (
"bufio"
"crypto/sha256"
"fmt"
"io"
"os"
"strings"
"github.com/pidginhost/csm/internal/config"
)
// HashFile returns the SHA256 hash of a file.
func HashFile(path string) (string, error) {
// #nosec G304 -- integrity hashing of operator-configured binary/config paths.
f, err := os.Open(path)
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil
}
// HashConfigStable hashes the config file excluding the integrity section,
// so that writing hashes back to the config doesn't change the hash.
func HashConfigStable(path string) (string, error) {
// #nosec G304 -- operator-supplied config file path.
f, err := os.Open(path)
if err != nil {
return "", err
}
defer func() { _ = f.Close() }()
h := sha256.New()
scanner := bufio.NewScanner(f)
inIntegrity := false
for scanner.Scan() {
line := scanner.Text()
// Skip the integrity section
if strings.HasPrefix(line, "integrity:") {
inIntegrity = true
continue
}
if inIntegrity {
// Still inside integrity block (indented lines)
if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") || line == "" {
continue
}
inIntegrity = false
}
_, _ = h.Write([]byte(line + "\n"))
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("scanning config: %w", err)
}
return fmt.Sprintf("sha256:%x", h.Sum(nil)), nil
}
// Verify checks the binary and config file integrity against the baselined
// hashes in cfg. An empty binary hash means nothing was baselined yet and
// verification passes trivially; the config hash is only checked when set.
func Verify(binaryPath string, cfg *config.Config) error {
	expectedBin := cfg.Integrity.BinaryHash
	if expectedBin == "" {
		return nil // not yet baselined
	}
	actualBin, err := HashFile(binaryPath)
	if err != nil {
		return fmt.Errorf("hashing binary: %w", err)
	}
	if actualBin != expectedBin {
		return fmt.Errorf("binary hash mismatch: expected %s, got %s", expectedBin, actualBin)
	}
	expectedCfg := cfg.Integrity.ConfigHash
	if expectedCfg == "" {
		return nil
	}
	actualCfg, err := HashConfigStable(cfg.ConfigFile)
	if err != nil {
		return fmt.Errorf("hashing config: %w", err)
	}
	if actualCfg != expectedCfg {
		return fmt.Errorf("config hash mismatch: expected %s, got %s", expectedCfg, actualCfg)
	}
	return nil
}
// Package log provides a structured-logging wrapper around log/slog.
//
// Rationale: CSM's daemon currently emits timestamped log lines via direct
// fmt.Fprintf(os.Stderr, ...) calls. That works for journalctl but loses
// structure when operators ship logs to Loki, ELK, or Datadog. This
// package provides a drop-in replacement that:
//
// - Emits the legacy "[YYYY-MM-DD HH:MM:SS] msg" format in text mode so
// mixing csmlog calls with legacy fmt.Fprintf calls produces a
// uniform log stream (important during incremental migration)
// - Switches to JSON on CSM_LOG_FORMAT=json for log-shipping pipelines,
// emitting the slog-native level/msg/time/fields structure
// - Honors CSM_LOG_LEVEL={debug|info|warn|error} (default: info)
//
// Usage (preferred for new code):
//
// log.Info("daemon starting", "version", v, "pid", os.Getpid())
// log.Warn("log not found, will retry", "path", path)
// log.Error("alert dispatch failed", "err", err)
//
// Legacy call sites that still use fmt.Fprintf will keep working —
// migration is incremental. See docs/src/development.md for guidance.
package log
import (
"context"
"io"
"log/slog"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)
// global is the package-level logger, loaded lazily on first use via L().
// Stored in an atomic.Pointer so Init can swap it without data races while
// other goroutines are logging.
var global atomic.Pointer[slog.Logger]
// Init configures the global logger from environment variables:
//
// CSM_LOG_FORMAT = "text" (default) | "json"
// CSM_LOG_LEVEL = "debug" | "info" (default) | "warn" | "error"
//
// Safe to call multiple times. Returns the installed logger so callers can
// also pass it into subsystems that take a *slog.Logger.
func Init() *slog.Logger {
level := parseLevel(os.Getenv("CSM_LOG_LEVEL"))
handler := buildHandler(os.Getenv("CSM_LOG_FORMAT"), level)
logger := slog.New(handler)
global.Store(logger)
slog.SetDefault(logger)
return logger
}
// L returns the current global logger, initializing it from the
// environment on first call. The hot path is a single atomic load.
func L() *slog.Logger {
	logger := global.Load()
	if logger == nil {
		logger = Init()
	}
	return logger
}
// Helpers that mirror slog's method set on the global logger. Provided so
// call sites can write log.Info(...) instead of log.L().Info(...); args are
// slog-style alternating key/value pairs.
func Debug(msg string, args ...any) { L().Debug(msg, args...) }
func Info(msg string, args ...any)  { L().Info(msg, args...) }
func Warn(msg string, args ...any)  { L().Warn(msg, args...) }
func Error(msg string, args ...any) { L().Error(msg, args...) }
func parseLevel(s string) slog.Level {
switch strings.ToLower(strings.TrimSpace(s)) {
case "debug":
return slog.LevelDebug
case "warn", "warning":
return slog.LevelWarn
case "error", "err":
return slog.LevelError
default:
return slog.LevelInfo
}
}
// buildHandler selects the slog handler for the requested format at the
// given minimum level: "json" (case-insensitive) gets the slog JSON
// handler for log-shipping pipelines; anything else gets the legacy
// "[timestamp] msg" text handler. Both write to stderr.
func buildHandler(format string, level slog.Level) slog.Handler {
	if strings.ToLower(strings.TrimSpace(format)) == "json" {
		return slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: level})
	}
	return newLegacyTextHandler(os.Stderr, level)
}
// legacyTextHandler emits log records in CSM's historical "[timestamp] msg"
// format so callers migrating from fmt.Fprintf produce the same output. Key
// differences from slog.NewTextHandler:
//
//   - No "time=... level=... msg=..." prefix; just "[YYYY-MM-DD HH:MM:SS] msg"
//   - Structured fields are appended as " key=value" when present
//   - Level is prepended as "WARN:" / "ERROR:" only for non-info records
//
// This lets operators mix csmlog calls with the ~180 remaining fmt.Fprintf
// call sites in the daemon without introducing a mixed-format log stream.
type legacyTextHandler struct {
	w     io.Writer
	mu    *sync.Mutex // pointer so handlers derived via WithAttrs/WithGroup share one writer lock
	level slog.Level  // minimum level this handler emits
	attrs []slog.Attr // attrs pre-bound via WithAttrs, written before record attrs
	group string      // recorded by WithGroup (not currently rendered)
}

// newLegacyTextHandler returns a handler writing legacy-format lines to w,
// dropping records below level.
func newLegacyTextHandler(w io.Writer, level slog.Level) *legacyTextHandler {
	return &legacyTextHandler{
		w:     w,
		mu:    &sync.Mutex{},
		level: level,
	}
}
// Enabled reports whether records at level should be handled, i.e. the
// level is at or above the handler's configured minimum.
func (h *legacyTextHandler) Enabled(_ context.Context, level slog.Level) bool {
	return level >= h.level
}
// Handle formats one record as "[YYYY-MM-DD HH:MM:SS] [LEVEL: ]msg key=value..."
// and writes it to the underlying writer as a single line. The whole line is
// built in a local buffer first, so the mutex only guards the write and
// concurrent records never interleave.
func (h *legacyTextHandler) Handle(_ context.Context, r slog.Record) error {
	when := r.Time
	if when.IsZero() {
		when = time.Now()
	}
	var out strings.Builder
	out.Grow(128)
	out.WriteByte('[')
	out.WriteString(when.Format("2006-01-02 15:04:05"))
	out.WriteString("] ")
	// Info records carry no marker so the info path byte-matches the
	// legacy "[ts] msg" format; other levels get a short prefix.
	switch r.Level {
	case slog.LevelDebug:
		out.WriteString("DEBUG: ")
	case slog.LevelWarn:
		out.WriteString("WARN: ")
	case slog.LevelError:
		out.WriteString("ERROR: ")
	}
	out.WriteString(r.Message)
	// Pre-bound attrs first, then the record's own attrs.
	for _, attr := range h.attrs {
		writeAttr(&out, attr)
	}
	r.Attrs(func(attr slog.Attr) bool {
		writeAttr(&out, attr)
		return true
	})
	out.WriteByte('\n')
	h.mu.Lock()
	defer h.mu.Unlock()
	_, err := io.WriteString(h.w, out.String())
	return err
}
func writeAttr(sb *strings.Builder, a slog.Attr) {
if a.Key == "" {
return
}
sb.WriteString(" ")
sb.WriteString(a.Key)
sb.WriteByte('=')
v := a.Value.Resolve()
s := v.String()
// Quote values that contain whitespace so the key=value pairs stay
// parseable when operators grep the log.
if strings.ContainsAny(s, " \t") {
sb.WriteByte('"')
sb.WriteString(s)
sb.WriteByte('"')
} else {
sb.WriteString(s)
}
}
// WithAttrs returns a handler whose records carry attrs in addition to any
// already bound on h, per the slog.Handler contract. The writer, shared
// mutex, and level are inherited so output stays serialized across all
// derived handlers.
func (h *legacyTextHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	clone := *h
	combined := make([]slog.Attr, 0, len(h.attrs)+len(attrs))
	combined = append(combined, h.attrs...)
	combined = append(combined, attrs...)
	clone.attrs = combined
	return &clone
}
// WithGroup returns a handler that records the group name.
// NOTE(review): the stored group is currently never applied — Handle and
// writeAttr ignore it, so grouped attrs render with their bare keys rather
// than a "group.key" prefix. Acceptable today because CSM does not use
// groups; revisit if that changes.
func (h *legacyTextHandler) WithGroup(name string) slog.Handler {
	return &legacyTextHandler{
		w:     h.w,
		mu:    h.mu,
		level: h.level,
		attrs: h.attrs,
		group: name,
	}
}

// Ensure legacyTextHandler satisfies the slog.Handler contract at compile time.
var _ slog.Handler = (*legacyTextHandler)(nil)
package mime
import (
"archive/tar"
"archive/zip"
"bufio"
"bytes"
"compress/gzip"
"encoding/base64"
"fmt"
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net/mail"
"net/textproto"
"os"
"path/filepath"
"strings"
)
// ExtractedPart represents a single extracted attachment.
type ExtractedPart struct {
	Filename    string // attachment filename (or a generic placeholder)
	ContentType string
	Size        int64
	TempPath    string // temp file on disk holding the decoded bytes
	Nested      bool   // true when this part came from inside an archive
	ArchiveName string // enclosing archive's filename when Nested
}

// ExtractionResult holds all extracted parts and envelope metadata.
type ExtractionResult struct {
	Parts         []ExtractedPart
	Partial       bool   // true when limits or errors stopped extraction early
	PartialReason string // human-readable reason when Partial
	Direction     string // "inbound" or "outbound"
	From          string
	To            []string
	Subject       string
}

// Limits controls resource bounds during extraction.
type Limits struct {
	MaxAttachmentSize int64 // per-attachment decoded byte cap
	MaxArchiveDepth   int   // archive nesting levels to descend into
	MaxArchiveFiles   int   // members extracted per archive
	MaxExtractionSize int64 // total decoded bytes across the message
}

// DefaultLimits returns the default extraction limits: 25MiB per
// attachment, one archive level, 50 files per archive, 100MiB total.
func DefaultLimits() Limits {
	var l Limits
	l.MaxAttachmentSize = 25 * 1024 * 1024
	l.MaxArchiveDepth = 1
	l.MaxArchiveFiles = 50
	l.MaxExtractionSize = 100 * 1024 * 1024
	return l
}
// ParseSpoolMessage parses an Exim spool message (-H and -D files) and
// extracts attachments to a temp directory. Caller must remove the temp
// files in result.Parts[*].TempPath when done.
//
// Fail-open philosophy: most parse problems degrade to "no attachments"
// (or set result.Partial with a reason) rather than returning an error, so
// a malformed message never blocks scanning.
func ParseSpoolMessage(headerPath, bodyPath string, limits Limits) (*ExtractionResult, error) {
	envelope, hdrs, err := parseEximHeader(headerPath)
	if err != nil {
		return nil, fmt.Errorf("parsing header file: %w", err)
	}
	result := &ExtractionResult{
		From:    envelope.from,
		To:      envelope.to,
		Subject: envelope.subject,
	}
	// Determine direction from Received headers
	result.Direction = detectDirection(hdrs)
	maxBodyBytes := bodyReadLimit(limits)
	// Cheap pre-check: skip bodies that obviously exceed the memory budget
	// before reading anything.
	if info, statErr := os.Stat(bodyPath); statErr == nil && info.Size() > maxBodyBytes {
		result.Partial = true
		result.PartialReason = "message body exceeds parser memory budget"
		return result, nil
	}
	bodyData, err := readFileLimited(bodyPath, maxBodyBytes)
	if err != nil {
		return nil, fmt.Errorf("reading body file: %w", err)
	}
	ct := hdrs.Get("Content-Type")
	if ct == "" {
		ct = "text/plain"
	}
	mediaType, params, parseErr := mime.ParseMediaType(ct)
	if parseErr != nil {
		// Unparseable content type - treat as plain text, no attachments
		return result, nil //nolint:nilerr // fail-open by design
	}
	if strings.HasPrefix(mediaType, "multipart/") {
		boundary := params["boundary"]
		if boundary == "" {
			return result, nil
		}
		var totalSize int64
		extractErr := extractMultipart(bytes.NewReader(bodyData), boundary, limits, result, &totalSize, 0)
		if extractErr != nil {
			return result, nil //nolint:nilerr // fail-open by design: return what we extracted so far
		}
	} else if !strings.HasPrefix(mediaType, "text/") {
		// Single-part non-text message (e.g. application/octet-stream,
		// application/pdf, image/*). These are attachment-like payloads
		// that must be scanned even without a multipart wrapper.
		cte := strings.ToLower(hdrs.Get("Content-Transfer-Encoding"))
		decoded, truncated := decodeSinglePart(bodyData, cte, limits.MaxAttachmentSize+1)
		if !truncated && int64(len(decoded)) <= limits.MaxAttachmentSize {
			tmpFile, tmpErr := os.CreateTemp("", "csm-emailav-single-*")
			if tmpErr == nil {
				// NOTE(review): write/close errors are ignored here; a failed
				// write would hand the scanner a truncated temp file.
				_, _ = tmpFile.Write(decoded)
				tmpFile.Close()
				filename := params["name"]
				if filename == "" {
					filename = "attachment"
				}
				result.Parts = append(result.Parts, ExtractedPart{
					Filename:    filename,
					ContentType: mediaType,
					Size:        int64(len(decoded)),
					TempPath:    tmpFile.Name(),
				})
			}
		} else {
			result.Partial = true
			result.PartialReason = "single-part attachment exceeds max size"
		}
	}
	// text/* bodies are not attachments - skip
	return result, nil
}
// bodyReadLimit derives the in-memory budget for reading a spool body: at
// least twice the per-attachment cap (so transfer-encoding overhead fits),
// and never non-positive (falls back to the default extraction cap).
func bodyReadLimit(limits Limits) int64 {
	budget := limits.MaxExtractionSize
	if doubled := limits.MaxAttachmentSize * 2; budget < doubled {
		budget = doubled
	}
	if budget <= 0 {
		budget = DefaultLimits().MaxExtractionSize
	}
	return budget
}
func readFileLimited(path string, limit int64) ([]byte, error) {
// #nosec G304 -- path is mail queue file path from scanner walk.
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
data, err := io.ReadAll(io.LimitReader(f, limit+1))
if err != nil {
return nil, err
}
if int64(len(data)) > limit {
return nil, fmt.Errorf("message body exceeds parser memory budget")
}
return data, nil
}
// decodeSinglePart decodes a single-part body according to its
// Content-Transfer-Encoding ("base64", "quoted-printable", or identity for
// anything else), returning at most limit bytes. The bool result reports
// truncation; a decode error counts as truncated with a nil payload.
func decodeSinglePart(bodyData []byte, cte string, limit int64) ([]byte, bool) {
	var src io.Reader
	switch cte {
	case "base64":
		src = base64.NewDecoder(base64.StdEncoding, bytes.NewReader(bodyData))
	case "quoted-printable":
		src = quotedprintable.NewReader(bytes.NewReader(bodyData))
	default:
		// Identity encoding: hand back the raw bytes, clipped at limit.
		if int64(len(bodyData)) > limit {
			return bodyData[:limit], true
		}
		return bodyData, false
	}
	out, err := io.ReadAll(io.LimitReader(src, limit+1))
	if err != nil {
		return nil, true
	}
	if int64(len(out)) > limit {
		return out[:limit], true
	}
	return out, false
}
type envelope struct {
from string
to []string
subject string
}
// parseEximHeader reads an Exim -H file and extracts envelope info and headers.
// Exim -H files have a specific format: the first lines are Exim internal metadata,
// followed by RFC 2822 headers.
func parseEximHeader(path string) (*envelope, textproto.MIMEHeader, error) {
// #nosec G304 -- path is Exim -H file path from mail queue scanner walk.
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, err
}
env := &envelope{}
// Parse as mail message to extract headers - find the header block
// Exim -H files contain envelope info then headers separated by a blank line pattern.
// We look for standard RFC headers.
reader := bufio.NewReader(bytes.NewReader(data))
// Collect the raw header text by finding lines that look like RFC 822 headers
var headerBuf bytes.Buffer
inHeaders := false
for {
line, readErr := reader.ReadString('\n')
if readErr != nil && line == "" {
break
}
trimmed := strings.TrimRight(line, "\r\n")
// Detect start of RFC headers (lines like "From:", "To:", "Subject:", etc.)
if !inHeaders {
if len(trimmed) > 0 && strings.Contains(trimmed, ":") {
lower := strings.ToLower(trimmed)
if strings.HasPrefix(lower, "from:") ||
strings.HasPrefix(lower, "to:") ||
strings.HasPrefix(lower, "subject:") ||
strings.HasPrefix(lower, "date:") ||
strings.HasPrefix(lower, "mime-version:") ||
strings.HasPrefix(lower, "content-type:") ||
strings.HasPrefix(lower, "received:") ||
strings.HasPrefix(lower, "message-id:") {
inHeaders = true
}
}
}
if inHeaders {
headerBuf.WriteString(line)
if trimmed == "" {
break // end of headers
}
}
}
// Parse the collected headers
msg, msgErr := mail.ReadMessage(&headerBuf)
if msgErr != nil {
// Try to extract what we can from the raw data
return env, make(textproto.MIMEHeader), nil //nolint:nilerr // fail-open: return empty headers
}
env.from = msg.Header.Get("From")
env.subject = msg.Header.Get("Subject")
if to := msg.Header.Get("To"); to != "" {
for _, addr := range strings.Split(to, ",") {
env.to = append(env.to, strings.TrimSpace(addr))
}
}
// Convert mail.Header to textproto.MIMEHeader
hdrs := make(textproto.MIMEHeader)
for k, v := range msg.Header {
hdrs[k] = v
}
return env, hdrs, nil
}
// detectDirection guesses inbound vs outbound from Received headers.
func detectDirection(hdrs textproto.MIMEHeader) string {
received := hdrs.Values("Received")
if len(received) == 0 {
return "outbound" // locally generated, no Received headers
}
// If the first (topmost) Received header contains "authenticated" it's outbound
first := strings.ToLower(received[0])
if strings.Contains(first, "(authenticated") || strings.Contains(first, "auth=") {
return "outbound"
}
return "inbound"
}
// extractMultipart recursively walks MIME parts, extracting attachments.
// boundary is the multipart boundary for r; totalSize accumulates decoded
// bytes across the whole message; depth counts *archive* nesting (nested
// multiparts recurse at the same depth).
//
// Fail-open: per-part failures skip the part; size-limit violations mark
// result.Partial. Only top-level multipart read errors are returned.
func extractMultipart(r io.Reader, boundary string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) error {
	mr := multipart.NewReader(r, boundary)
	for {
		part, err := mr.NextPart()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		ct := part.Header.Get("Content-Type")
		if ct == "" {
			ct = "text/plain"
		}
		mediaType, params, _ := mime.ParseMediaType(ct)
		// Recurse into nested multipart (same depth: depth tracks archive
		// nesting, not MIME nesting)
		if strings.HasPrefix(mediaType, "multipart/") {
			if b := params["boundary"]; b != "" {
				if nestedErr := extractMultipart(part, b, limits, result, totalSize, depth); nestedErr != nil {
					return nestedErr
				}
			}
			continue
		}
		// Skip inline text bodies - only extract attachments
		disp := part.Header.Get("Content-Disposition")
		filename := part.FileName()
		if filename == "" {
			// Try Content-Disposition filename param
			if disp != "" {
				_, dparams, _ := mime.ParseMediaType(disp)
				filename = dparams["filename"]
			}
		}
		if filename == "" {
			// No filename and text/* content - this is a body part, skip
			if strings.HasPrefix(mediaType, "text/") {
				continue
			}
			// Non-text without filename - use generic name
			filename = "unnamed_attachment"
		}
		// Decode the part body based on Content-Transfer-Encoding
		cte := strings.ToLower(part.Header.Get("Content-Transfer-Encoding"))
		var bodyReader io.Reader = part
		switch cte {
		case "base64":
			bodyReader = base64.NewDecoder(base64.StdEncoding, part)
		case "quoted-printable":
			bodyReader = quotedprintable.NewReader(part)
		}
		// Write to temp file; the +1 on the limit makes an over-budget part
		// detectable by its byte count below
		tmpFile, err := os.CreateTemp("", "csm-emailav-*")
		if err != nil {
			return fmt.Errorf("creating temp file: %w", err)
		}
		limited := io.LimitReader(bodyReader, limits.MaxAttachmentSize+1)
		n, err := io.Copy(tmpFile, limited)
		tmpFile.Close()
		if err != nil {
			os.Remove(tmpFile.Name())
			continue // fail-open: skip this part
		}
		if n > limits.MaxAttachmentSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("attachment %q exceeds max size %d", filename, limits.MaxAttachmentSize)
			continue
		}
		*totalSize += n
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return nil // stop extracting
		}
		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    filename,
			ContentType: mediaType,
			Size:        n,
			TempPath:    tmpFile.Name(),
		})
		// Attempt archive extraction; depth guards against
		// archive-in-archive recursion beyond MaxArchiveDepth
		lower := strings.ToLower(filename)
		if depth < limits.MaxArchiveDepth {
			if strings.HasSuffix(lower, ".zip") {
				extractZIP(tmpFile.Name(), filename, limits, result, totalSize, depth+1)
			} else if strings.HasSuffix(lower, ".tar.gz") || strings.HasSuffix(lower, ".tgz") {
				extractTarGz(tmpFile.Name(), filename, limits, result, totalSize, depth+1)
			}
		}
	}
}
// extractZIP extracts regular files from the zip at zipPath into temp
// files, appending them to result as nested parts. Fail-open: corrupt
// archives or unreadable members are skipped; limit violations mark the
// result Partial. depth is accepted for symmetry with extractTarGz; no
// further nesting happens here (MaxArchiveDepth is enforced by the caller).
func extractZIP(zipPath, archiveName string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) {
	// #nosec G304 -- zipPath is CreateTemp-produced path from the caller.
	archive, err := os.Open(zipPath)
	if err != nil {
		return // fail-open: skip corrupt archives
	}
	defer archive.Close()
	stat, err := archive.Stat()
	if err != nil {
		return
	}
	reader, err := zip.NewReader(archive, stat.Size())
	if err != nil {
		return
	}
	count := 0
	for _, entry := range reader.File {
		if count >= limits.MaxArchiveFiles {
			result.Partial = true
			result.PartialReason = fmt.Sprintf("archive %q exceeds max files %d", archiveName, limits.MaxArchiveFiles)
			return
		}
		if entry.FileInfo().IsDir() {
			continue
		}
		rc, openErr := entry.Open()
		if openErr != nil {
			continue
		}
		tmp, tmpErr := os.CreateTemp("", "csm-emailav-zip-*")
		if tmpErr != nil {
			rc.Close()
			continue
		}
		// +1 on the limit makes over-budget members detectable by count.
		written, copyErr := io.Copy(tmp, io.LimitReader(rc, limits.MaxAttachmentSize+1))
		tmp.Close()
		rc.Close()
		if copyErr != nil || written > limits.MaxAttachmentSize {
			os.Remove(tmp.Name())
			if written > limits.MaxAttachmentSize {
				result.Partial = true
				result.PartialReason = fmt.Sprintf("file %q in archive exceeds max size", entry.Name)
			}
			continue
		}
		*totalSize += written
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(tmp.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return
		}
		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    entry.Name,
			ContentType: "application/octet-stream",
			Size:        written,
			TempPath:    tmp.Name(),
			Nested:      true,
			ArchiveName: archiveName,
		})
		count++
	}
}
// extractTarGz extracts regular files from the gzip-compressed tar at
// tgzPath into temp files, appending them to result as nested parts.
// Fail-open: unreadable archives are skipped silently; limit violations
// mark the result Partial. depth is accepted for symmetry with extractZIP;
// no further nesting happens here (MaxArchiveDepth is enforced by the caller).
func extractTarGz(tgzPath, archiveName string, limits Limits, result *ExtractionResult, totalSize *int64, depth int) {
	// #nosec G304 -- tgzPath is CreateTemp-produced path from the caller.
	f, err := os.Open(tgzPath)
	if err != nil {
		return
	}
	defer f.Close()
	gr, err := gzip.NewReader(f)
	if err != nil {
		return
	}
	defer func() { _ = gr.Close() }()
	tr := tar.NewReader(gr)
	extracted := 0
	for {
		hdr, err := tr.Next()
		if err != nil {
			return // EOF or error - done
		}
		if hdr.Typeflag != tar.TypeReg {
			continue
		}
		if extracted >= limits.MaxArchiveFiles {
			result.Partial = true
			result.PartialReason = fmt.Sprintf("archive %q exceeds max files %d", archiveName, limits.MaxArchiveFiles)
			return
		}
		tmpFile, err := os.CreateTemp("", "csm-emailav-tgz-*")
		if err != nil {
			continue
		}
		limited := io.LimitReader(tr, limits.MaxAttachmentSize+1)
		n, err := io.Copy(tmpFile, limited)
		tmpFile.Close()
		if err != nil || n > limits.MaxAttachmentSize {
			os.Remove(tmpFile.Name())
			// Consistency fix: extractZIP flags oversized members as a
			// partial extraction; do the same here so scanners know the
			// archive was not fully inspected.
			if n > limits.MaxAttachmentSize {
				result.Partial = true
				result.PartialReason = fmt.Sprintf("file %q in archive exceeds max size", hdr.Name)
			}
			continue
		}
		*totalSize += n
		if *totalSize > limits.MaxExtractionSize {
			os.Remove(tmpFile.Name())
			result.Partial = true
			result.PartialReason = fmt.Sprintf("total extraction size exceeds %d bytes", limits.MaxExtractionSize)
			return
		}
		result.Parts = append(result.Parts, ExtractedPart{
			Filename:    filepath.Base(hdr.Name),
			ContentType: "application/octet-stream",
			Size:        n,
			TempPath:    tmpFile.Name(),
			Nested:      true,
			ArchiveName: archiveName,
		})
		extracted++
	}
}
package modsec
import (
"bufio"
"fmt"
"io"
"os"
"sort"
"strconv"
"strings"
)
const overridesHeader = "# CSM ModSecurity Rule Overrides\n# Managed by CSM - do not edit manually.\n"
// WriteOverrides writes the overrides file with SecRuleRemoveById directives.
// Atomic write (tmp + rename). Sorts IDs for deterministic output.
func WriteOverrides(path string, disabledIDs []int) error {
sort.Ints(disabledIDs)
var sb strings.Builder
sb.WriteString(overridesHeader)
for _, id := range disabledIDs {
fmt.Fprintf(&sb, "SecRuleRemoveById %d\n", id)
}
tmpPath := path + ".tmp"
// #nosec G306 -- ModSec config files are read by the webserver process
// (Apache/nginx) which runs as a different user. 0640 lets root write
// and the webserver group read; world-read stays off.
if err := os.WriteFile(tmpPath, []byte(sb.String()), 0640); err != nil {
return fmt.Errorf("writing overrides tmp: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
os.Remove(tmpPath)
return fmt.Errorf("renaming overrides: %w", err)
}
return nil
}
// ReadOverrides reads the overrides file and returns disabled rule IDs.
// Returns empty list (not error) if the file does not exist.
func ReadOverrides(path string) ([]int, error) {
// #nosec G304 -- path is operator-configured ModSec overrides file.
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
defer f.Close()
var ids []int
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "SecRuleRemoveById ") {
idStr := strings.TrimPrefix(line, "SecRuleRemoveById ")
if id, err := strconv.Atoi(strings.TrimSpace(idStr)); err == nil && id >= 900000 && id <= 900999 {
ids = append(ids, id)
}
}
}
return ids, scanner.Err()
}
// ReadOverridesRaw reads the overrides file content for rollback purposes.
// Returns nil (not error) if the file does not exist.
func ReadOverridesRaw(path string) []byte {
// #nosec G304 -- path is operator-configured ModSec overrides file.
data, err := os.ReadFile(path)
if err != nil {
return nil
}
return data
}
// RestoreOverrides writes raw content back to the overrides file (for rollback).
// Uses atomic tmp+rename to prevent partial writes on crash.
func RestoreOverrides(path string, content []byte) error {
if content == nil {
// File didn't exist before - remove it
os.Remove(path)
return nil
}
tmpPath := path + ".tmp"
// #nosec G306 -- see note in WriteOverrides: webserver-readable ModSec config.
if err := os.WriteFile(tmpPath, content, 0640); err != nil {
return fmt.Errorf("writing rollback tmp: %w", err)
}
if err := os.Rename(tmpPath, path); err != nil {
os.Remove(tmpPath)
return fmt.Errorf("renaming rollback: %w", err)
}
return nil
}
// EnsureOverridesInclude appends an Include directive for the overrides file
// to a ModSecurity config file if not already present, and creates an empty
// overrides file if it doesn't exist. Idempotent: re-reads content under the
// write-open to avoid appending duplicate Include directives. Best-effort:
// all failures are silent.
func EnsureOverridesInclude(rulesFile, overridesFile string) {
	// Open for read+write to check-then-append atomically (same fd).
	// #nosec G302 G304 -- webserver-readable ModSec rules file; 0640 means root
	// can write and the webserver group can read. No world read.
	f, err := os.OpenFile(rulesFile, os.O_RDWR|os.O_APPEND, 0640)
	if err != nil {
		// Missing or unwritable rules file: nothing we can do.
		return
	}
	defer f.Close()
	data, err := io.ReadAll(f)
	if err != nil {
		return
	}
	// Substring check: any mention of the overrides path (even inside a
	// comment) suppresses the append. NOTE(review): a commented-out
	// Include would therefore never be re-added — confirm acceptable.
	if !strings.Contains(string(data), overridesFile) {
		// O_APPEND makes this write land at EOF even though ReadAll
		// advanced the fd's offset.
		fmt.Fprintf(f, "\n# CSM overrides - managed by CSM rule management\nInclude %s\n", overridesFile)
	}
	// Create empty overrides file if it doesn't exist
	if _, err := os.Stat(overridesFile); os.IsNotExist(err) {
		// #nosec G306 -- see note in WriteOverrides: webserver-readable.
		_ = os.WriteFile(overridesFile, []byte(overridesHeader), 0640)
	}
}
package modsec
import (
"bufio"
"fmt"
"os"
"regexp"
"strconv"
"strings"
)
// Rule represents a parsed ModSecurity rule from the CSM custom config.
// Only rules with IDs in the CSM-reserved 900000-900999 range are produced
// (see parseBlock); Raw preserves the full text, including chained
// SecRules, so the rule can be displayed or re-emitted verbatim.
type Rule struct {
	ID          int    // e.g. 900112
	Description string // from msg:'...' field
	Action      string // "deny", "pass", "log"
	StatusCode  int    // 403, 429, 0 (for pass/log)
	Phase       int    // 1 or 2
	Raw         string // full rule text including chains
	IsCounter   bool   // true if pass,nolog (bookkeeping rule, hidden in UI)
}
// Precompiled extractors for the fields CSM surfaces from a rule's action
// string. Compiled once at package init, never inside the parse loop.
var (
	// reID requires a ',' or '"' before "id:" so only the action keyword
	// matches, not "id:" text inside a pattern or variable name.
	reID     = regexp.MustCompile(`[,"]id:(\d+)`)
	reMsg    = regexp.MustCompile(`msg:'([^']*)'`)
	rePhase  = regexp.MustCompile(`phase:(\d)`)
	reStatus = regexp.MustCompile(`status:(\d+)`)
)
// ParseRulesFile reads a ModSecurity config file and extracts all rules
// with IDs in the 900000-900999 range. Handles line continuations (\) and
// chained rules (chain action keyword).
//
// Parsing runs in three phases:
//  1. join backslash-continued physical lines into logical lines;
//  2. group logical lines into blocks, folding a chained SecRule into the
//     block of the rule that declared "chain";
//  3. parse each block into a Rule via parseBlock.
func ParseRulesFile(path string) ([]Rule, error) {
	// #nosec G304 -- path is operator-configured ModSec rules file.
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("opening rules file: %w", err)
	}
	defer f.Close()
	// Phase 1: Read lines, joining backslash continuations into logical lines.
	var logicalLines []string
	var current strings.Builder
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		raw := scanner.Text()
		trimmed := strings.TrimSpace(raw)
		if strings.HasSuffix(trimmed, "\\") {
			// Continuation: strip trailing \ and keep accumulating
			current.WriteString(strings.TrimSuffix(trimmed, "\\"))
			current.WriteString(" ")
			continue
		}
		current.WriteString(trimmed)
		logicalLines = append(logicalLines, current.String())
		current.Reset()
	}
	// Flush a trailing continuation that was never terminated.
	if current.Len() > 0 {
		logicalLines = append(logicalLines, current.String())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	// Phase 2: Group logical lines into blocks.
	// Each block starts with a SecRule and may include chained SecRules.
	// When a directive's action string contains "chain", the next SecRule
	// is part of the same block.
	var blocks []string
	var block strings.Builder
	chainPending := false
	flushBlock := func() {
		if block.Len() > 0 {
			blocks = append(blocks, block.String())
			block.Reset()
		}
		chainPending = false
	}
	for _, line := range logicalLines {
		if strings.HasPrefix(line, "SecRule ") {
			// A new SecRule terminates the previous block unless that
			// block is still waiting for its chained partner.
			if block.Len() > 0 && !chainPending {
				flushBlock()
			}
		} else if block.Len() == 0 {
			continue // skip comments and blank lines outside blocks
		}
		if block.Len() > 0 || strings.HasPrefix(line, "SecRule ") {
			block.WriteString(line)
			block.WriteString("\n")
			// Only update chainPending for SecRule lines - non-SecRule
			// directives between chained rules must not reset the flag.
			if strings.HasPrefix(line, "SecRule ") {
				chainPending = hasChainAction(line)
			}
		}
	}
	flushBlock()
	// Phase 3: Parse each block into a Rule.
	var rules []Rule
	for _, b := range blocks {
		if r, ok := parseBlock(b); ok {
			rules = append(rules, r)
		}
	}
	return rules, nil
}
// hasChainAction checks whether a logical line (continuations already joined)
// contains "chain" as a ModSecurity action keyword. Strips whitespace to handle
// action strings split across continuation lines like: "id:900004,..., chain"
func hasChainAction(line string) bool {
	// Drop every space/tab so `, chain"` normalizes to `,chain"`.
	var b strings.Builder
	for _, r := range line {
		if r != ' ' && r != '\t' {
			b.WriteRune(r)
		}
	}
	compact := b.String()
	// "chain" only counts when delimited like an action keyword, never
	// when it merely appears inside a pattern or variable name.
	for _, marker := range []string{`,chain"`, `,chain'`, `,chain,`, `"chain"`, `"chain,`} {
		if strings.Contains(compact, marker) {
			return true
		}
	}
	return false
}
// parseBlock parses one rule block (a SecRule plus any chained SecRules)
// into a Rule. Returns false when the block carries no id: action or the
// id falls outside the CSM-reserved 900000-900999 range.
func parseBlock(block string) (Rule, bool) {
	idMatch := reID.FindStringSubmatch(block)
	if idMatch == nil {
		return Rule{}, false
	}
	id, _ := strconv.Atoi(idMatch[1])
	// Only CSM's reserved range is surfaced.
	if id < 900000 || id > 900999 {
		return Rule{}, false
	}
	rule := Rule{
		ID:  id,
		Raw: strings.TrimSpace(block),
	}
	if m := reMsg.FindStringSubmatch(block); m != nil {
		rule.Description = m[1]
	}
	if m := rePhase.FindStringSubmatch(block); m != nil {
		rule.Phase, _ = strconv.Atoi(m[1])
	}
	// Action keywords are matched only with a leading comma or quote so
	// they are recognized in action position and not inside variable
	// names or patterns (e.g. "nolog" must not match "log").
	lower := strings.ToLower(block)
	has := func(sub string) bool { return strings.Contains(lower, sub) }
	switch {
	case has(",deny") || has("\"deny"):
		rule.Action = "deny"
	case has(",pass") || has("\"pass"):
		rule.Action = "pass"
	case has(",log,") || has(",log\"") || has("\"log,"):
		rule.Action = "log"
	}
	if m := reStatus.FindStringSubmatch(block); m != nil {
		rule.StatusCode, _ = strconv.Atoi(m[1])
	}
	// pass+nolog rules are bookkeeping counters, hidden in the UI.
	if rule.Action == "pass" && has("nolog") {
		rule.IsCounter = true
	}
	return rule, true
}
package modsec
import (
"context"
"errors"
"fmt"
"os/exec"
"time"
)
// reloadTimeout bounds how long a web server reload may run.
const reloadTimeout = 30 * time.Second

// Reload executes the configured web server reload command.
// Returns combined stdout+stderr output and any error.
func Reload(command string) (string, error) {
	if command == "" {
		return "", errors.New("reload command is empty")
	}
	ctx, cancel := context.WithTimeout(context.Background(), reloadTimeout)
	defer cancel()
	// Run through shell to support compound commands, quoted paths, etc.
	// #nosec G204 -- `command` is the operator-configured reload command
	// from csm.yaml (e.g. "apachectl graceful"), loaded at daemon startup
	// from a root-owned config. Not webui-settable.
	raw, runErr := exec.CommandContext(ctx, "sh", "-c", command).CombinedOutput()
	output := string(raw)
	// Report the deadline case specifically: a killed process also sets
	// runErr, but the timeout message is more actionable for operators.
	if ctx.Err() == context.DeadlineExceeded {
		return output, fmt.Errorf("reload timed out after %v", reloadTimeout)
	}
	if runErr != nil {
		return output, fmt.Errorf("reload failed: %w (output: %s)", runErr, output)
	}
	return output, nil
}
// Package platform detects the host OS, control panel, and web server so
// CSM checks can pick the right config/log paths instead of hardcoding
// cPanel+Apache layouts.
package platform
import (
"bufio"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
)
// OSFamily identifies the Linux distribution, as read from the ID= field
// of /etc/os-release (see detectOS). Empty means detection failed.
type OSFamily string

const (
	OSUnknown    OSFamily = ""
	OSUbuntu     OSFamily = "ubuntu"
	OSDebian     OSFamily = "debian"
	OSAlma       OSFamily = "almalinux"
	OSRocky      OSFamily = "rocky"
	OSCentOS     OSFamily = "centos"
	OSRHEL       OSFamily = "rhel"
	OSCloudLinux OSFamily = "cloudlinux"
)

// Panel identifies the hosting control panel installed on the host,
// detected by marker files (see detectPanel). Empty means no panel.
type Panel string

const (
	PanelNone   Panel = ""
	PanelCPanel Panel = "cpanel"
	PanelPlesk  Panel = "plesk"
	PanelDA     Panel = "directadmin"
)

// WebServer identifies the primary web server on the host, chosen by
// selectWebServer from running units and installed binaries.
type WebServer string

const (
	WSNone      WebServer = ""
	WSApache    WebServer = "apache"
	WSNginx     WebServer = "nginx"
	WSLiteSpeed WebServer = "litespeed"
)
// Info holds everything a check needs to locate web server resources.
// Produced by Detect()/DetectFresh(); operator overrides may replace
// individual fields (see applyOverrides).
type Info struct {
	OS        OSFamily
	OSVersion string // VERSION_ID from /etc/os-release, e.g. "22.04"
	Panel     Panel
	WebServer WebServer
	// Config locations for the detected web server.
	ApacheConfigDir string // e.g. /etc/apache2 or /etc/httpd
	NginxConfigDir  string // e.g. /etc/nginx
	// Candidate log files. Populated based on detected web server + OS.
	// These are candidates only — missing paths are tolerated upstream.
	AccessLogPaths      []string
	ErrorLogPaths       []string
	ModSecAuditLogPaths []string
	// Binary paths useful for reload/control.
	ApacheBinary string
	NginxBinary  string
}
// IsCPanel is a convenience for checks that still need to gate cPanel-only
// behavior (WHM API calls, /home/*/public_html enumeration, exim log
// tailing, etc.) without re-detecting each time.
func (i Info) IsCPanel() bool { return i.Panel == PanelCPanel }

// IsRHELFamily reports whether the OS uses rpm/dnf and /etc/httpd style paths.
func (i Info) IsRHELFamily() bool {
	return i.OS == OSAlma || i.OS == OSRocky || i.OS == OSCentOS ||
		i.OS == OSRHEL || i.OS == OSCloudLinux
}

// IsDebianFamily reports whether the OS uses dpkg/apt and /etc/apache2 style paths.
func (i Info) IsDebianFamily() bool {
	switch i.OS {
	case OSUbuntu, OSDebian:
		return true
	}
	return false
}
// Overrides lets the operator override auto-detected values from csm.yaml.
// Any field left blank or nil falls back to the auto-detected value.
//
// Panel and WebServer use pointer types so callers can distinguish "leave
// auto-detected" (nil) from "explicitly override to none" (pointer to
// PanelNone / WSNone). The non-pointer string/slice fields use the
// zero-value-means-unset convention since they have no legitimate "none"
// value to override to.
type Overrides struct {
	Panel               *Panel     // nil = keep detected; &PanelNone = force panel-less
	WebServer           *WebServer // nil = keep detected; &WSNone = force no web server
	AccessLogPaths      []string   // replaces (not appends to) detected paths
	ErrorLogPaths       []string   // replaces detected paths
	ModSecAuditLogPaths []string   // replaces detected paths
	ApacheConfigDir     string     // "" = keep detected
	NginxConfigDir      string     // "" = keep detected
}
var (
	// detected is the cached Detect() result, valid once detectedFlag is set.
	detected Info
	// detectedOnce guards the one-time detection inside Detect().
	detectedOnce sync.Once
	// overrideMu serializes pendingOverride access between SetOverrides
	// and the merge step inside the first Detect() call.
	overrideMu sync.Mutex
	// pendingOverride holds config-supplied overrides installed before the
	// first Detect(); nil when none are pending.
	pendingOverride *Overrides
)
// SetOverrides installs config-supplied overrides to be merged into the next
// (and all subsequent) Detect() result. Call this once from daemon startup,
// BEFORE the first Detect() call, so the merged info is what every check
// sees. Subsequent SetOverrides calls before Detect() replace the previous
// override.
//
// Returns true if the override was installed in time to take effect, false
// if Detect() had already cached a result (the call is then a no-op).
func SetOverrides(o Overrides) bool {
	overrideMu.Lock()
	defer overrideMu.Unlock()
	// Once Detect() has published its cached result the override can never
	// be applied, so don't store it. (The previous `pendingOverride != nil`
	// pre-check was dead logic: the return value was always !isDetected(),
	// and an override stored after detection was never read.)
	if isDetected() {
		return false
	}
	pendingOverride = &o
	return true
}
// detectedFlag tracks whether Detect() has completed and published its
// cached result; sync.Once has no "was it called" query API.
var detectedFlag atomic.Bool

// isDetected returns true if Detect() has already cached a result.
// Internal helper — uses a separate flag because sync.Once has no query API.
func isDetected() bool { return detectedFlag.Load() }
// Detect inspects the host and returns platform info. The result is cached
// for the process lifetime — callers that need a fresh probe should use
// DetectFresh instead.
func Detect() Info {
	detectedOnce.Do(func() {
		detected = DetectFresh()
		// Merge overrides installed via SetOverrides before this first
		// call. Holding overrideMu pairs with SetOverrides so the pending
		// value cannot change mid-merge.
		overrideMu.Lock()
		if pendingOverride != nil {
			detected = applyOverrides(detected, *pendingOverride)
		}
		overrideMu.Unlock()
		// Publish "detection done" only after the merged result is in
		// place, so SetOverrides correctly reports late calls as no-ops.
		detectedFlag.Store(true)
	})
	return detected
}
// DetectFresh always re-runs detection, ignoring any cached result.
// Intended for tests and for operator-triggered rescan. Does not apply
// config overrides — use Detect() for the operator-visible view.
func DetectFresh() Info {
	var info Info
	// Probe order matters: panel and web server detection feed the path
	// population step at the end.
	for _, probe := range []func(*Info){detectOS, detectPanel, detectWebServer, populatePaths} {
		probe(&info)
	}
	return info
}
// applyOverrides merges non-empty override fields into info. Always returns
// a new Info — never mutates the input (info is a value copy). Paths are
// replaced, not appended: an explicit operator-configured log list discards
// the auto-detected one so operators have full control.
func applyOverrides(info Info, o Overrides) Info {
	clone := func(src []string) []string { return append([]string(nil), src...) }
	// Panel must merge before any path rebuild so populatePaths sees the
	// effective panel. A non-nil pointer always wins — even one pointing
	// at PanelNone — so operators can force a host to look panel-less.
	if o.Panel != nil {
		info.Panel = *o.Panel
	}
	if o.WebServer != nil {
		// Changing the web server invalidates every auto-detected log
		// path: wipe them and rebuild from scratch. Explicit path
		// overrides below still replace the rebuilt values. Same
		// nil-vs-pointer semantics as Panel (WSNone forces "none").
		info.WebServer = *o.WebServer
		info.AccessLogPaths = nil
		info.ErrorLogPaths = nil
		info.ModSecAuditLogPaths = nil
		populatePaths(&info)
	}
	if len(o.AccessLogPaths) > 0 {
		info.AccessLogPaths = clone(o.AccessLogPaths)
	}
	if len(o.ErrorLogPaths) > 0 {
		info.ErrorLogPaths = clone(o.ErrorLogPaths)
	}
	if len(o.ModSecAuditLogPaths) > 0 {
		info.ModSecAuditLogPaths = clone(o.ModSecAuditLogPaths)
	}
	if o.ApacheConfigDir != "" {
		info.ApacheConfigDir = o.ApacheConfigDir
	}
	if o.NginxConfigDir != "" {
		info.NginxConfigDir = o.NginxConfigDir
	}
	return info
}
// ResetForTest clears the cached Detect() result so tests can re-run with
// different fixtures. Never call from production code.
func ResetForTest() {
	overrideMu.Lock()
	defer overrideMu.Unlock()
	detected = Info{}
	// Replacing the sync.Once re-arms the one-time detection in Detect().
	detectedOnce = sync.Once{}
	pendingOverride = nil
	detectedFlag.Store(false)
}
// detectOS fills i.OS and i.OSVersion from /etc/os-release. Leaves the
// zero values in place when the file is missing or the ID is unrecognized.
func detectOS(i *Info) {
	f, err := os.Open("/etc/os-release")
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	var id, version string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		key, val, ok := strings.Cut(sc.Text(), "=")
		if !ok {
			continue
		}
		// Values may be quoted with " or '.
		val = strings.Trim(val, `"'`)
		if key == "ID" {
			id = strings.ToLower(val)
		} else if key == "VERSION_ID" {
			version = val
		}
	}
	i.OSVersion = version
	families := map[string]OSFamily{
		"ubuntu":     OSUbuntu,
		"debian":     OSDebian,
		"almalinux":  OSAlma,
		"rocky":      OSRocky,
		"centos":     OSCentOS,
		"rhel":       OSRHEL,
		"cloudlinux": OSCloudLinux,
	}
	if fam, known := families[id]; known {
		i.OS = fam
	}
}
// detectPanel sets i.Panel by probing each panel's marker file, first
// match wins. Leaves PanelNone when no marker exists.
func detectPanel(i *Info) {
	probes := []struct {
		marker string
		panel  Panel
	}{
		{"/usr/local/cpanel/version", PanelCPanel},
		{"/usr/local/psa/version", PanelPlesk},
		{"/usr/local/directadmin/directadmin", PanelDA},
	}
	for _, p := range probes {
		if _, err := os.Stat(p.marker); err == nil {
			i.Panel = p.panel
			return
		}
	}
}
// detectWebServer records installed web server binary paths and selects
// the primary web server, preferring whatever is actually running.
func detectWebServer(i *Info) {
	// Prefer the process that's actually running. Fall back to installed
	// binaries if nothing is running yet (first boot, non-systemd env).
	running := runningServices()
	// Always record binary paths for reload/control, even if not primary.
	if bin, err := exec.LookPath("nginx"); err == nil {
		i.NginxBinary = bin
	}
	for _, name := range []string{"apache2", "httpd"} {
		if bin, err := exec.LookPath(name); err == nil {
			i.ApacheBinary = bin
			break
		}
	}
	// cPanel compiles its own httpd under /usr/local/apache/bin/httpd,
	// which isn't always in PATH for root under the CSM service unit.
	if i.ApacheBinary == "" {
		const cpHttpd = "/usr/local/apache/bin/httpd"
		if _, err := os.Stat(cpHttpd); err == nil {
			i.ApacheBinary = cpHttpd
		}
	}
	i.WebServer = selectWebServer(i.Panel, running, i.ApacheBinary != "", i.NginxBinary != "")
}
// runningServices returns which web server units are currently active,
// keyed by unit name. Relies on systemctl; when systemctl is unavailable
// every probe fails and the map comes back empty (never nil).
func runningServices() map[string]bool {
	units := []string{"nginx", "apache2", "httpd", "litespeed", "lshttpd", "lsws"}
	active := make(map[string]bool, len(units))
	for _, unit := range units {
		// #nosec G204 -- systemctl hardcoded; unit iterates a literal slice.
		if exec.Command("systemctl", "is-active", "--quiet", unit).Run() == nil {
			active[unit] = true
		}
	}
	return active
}
// selectWebServer picks the primary web server from running units and
// installed binaries. Running servers always beat merely-installed ones.
//
// cPanel commonly runs Nginx as a reverse proxy in front of Apache, so on
// cPanel hosts Apache outranks Nginx: the origin server's logs are the
// ones cPanel actually writes, and that's what the real-time access and
// ModSecurity watchers need to tail.
func selectWebServer(panel Panel, running map[string]bool, hasApacheBinary, hasNginxBinary bool) WebServer {
	apacheUp := running["apache2"] || running["httpd"]
	litespeedUp := running["litespeed"] || running["lshttpd"] || running["lsws"]
	nginxUp := running["nginx"]

	type candidate struct {
		ok bool
		ws WebServer
	}
	var priority []candidate
	if panel == PanelCPanel {
		priority = []candidate{
			{litespeedUp, WSLiteSpeed},
			{apacheUp, WSApache},
			{nginxUp, WSNginx},
			{hasApacheBinary, WSApache},
			{hasNginxBinary, WSNginx},
		}
	} else {
		priority = []candidate{
			{nginxUp, WSNginx},
			{apacheUp, WSApache},
			{litespeedUp, WSLiteSpeed},
			{hasApacheBinary, WSApache},
			{hasNginxBinary, WSNginx},
		}
	}
	for _, c := range priority {
		if c.ok {
			return c.ws
		}
	}
	return WSNone
}
// populatePaths fills config-dir and log-path candidates for the detected
// web server + OS + panel combination. The log lists are candidates, not
// guarantees: missing paths are tolerated by the watchers upstream.
func populatePaths(i *Info) {
	// Apache config dir. cPanel compiles Apache from source and installs
	// under /usr/local/apache, separate from the OS package tree; that
	// override wins over the distro default when cPanel is present.
	switch {
	case i.Panel == PanelCPanel && dirExists("/usr/local/apache/conf"):
		i.ApacheConfigDir = "/usr/local/apache/conf"
	case i.IsDebianFamily():
		if dirExists("/etc/apache2") {
			i.ApacheConfigDir = "/etc/apache2"
		}
	case i.IsRHELFamily():
		if dirExists("/etc/httpd") {
			i.ApacheConfigDir = "/etc/httpd"
		}
	}
	if dirExists("/etc/nginx") {
		i.NginxConfigDir = "/etc/nginx"
	}
	// Log paths: pick candidates based on detected web server and OS layout.
	// We include ALL plausible locations so log watchers can try each;
	// missing paths are handled upstream by the retry logic.
	switch i.WebServer {
	case WSApache:
		if i.IsDebianFamily() {
			i.AccessLogPaths = []string{"/var/log/apache2/access.log", "/var/log/apache2/other_vhosts_access.log"}
			i.ErrorLogPaths = []string{"/var/log/apache2/error.log"}
			i.ModSecAuditLogPaths = []string{"/var/log/apache2/modsec_audit.log"}
		} else {
			// RHEL-family (and unknown-OS) layout.
			i.AccessLogPaths = []string{"/var/log/httpd/access_log"}
			i.ErrorLogPaths = []string{"/var/log/httpd/error_log"}
			i.ModSecAuditLogPaths = []string{"/var/log/httpd/modsec_audit.log"}
		}
	case WSNginx:
		i.AccessLogPaths = []string{"/var/log/nginx/access.log"}
		i.ErrorLogPaths = []string{"/var/log/nginx/error.log"}
		i.ModSecAuditLogPaths = []string{"/var/log/nginx/modsec_audit.log"}
	case WSLiteSpeed:
		i.AccessLogPaths = []string{"/usr/local/lsws/logs/access.log"}
		i.ErrorLogPaths = []string{"/usr/local/lsws/logs/error.log"}
		i.ModSecAuditLogPaths = []string{"/usr/local/lsws/logs/auditmodsec.log"}
	}
	// cPanel overlays its own access/error logs on top of the OS defaults.
	// Prepending puts the cPanel paths first so watchers try them first.
	if i.Panel == PanelCPanel {
		i.AccessLogPaths = append([]string{
			"/usr/local/apache/logs/access_log",
			"/usr/local/cpanel/logs/access_log",
		}, i.AccessLogPaths...)
		i.ErrorLogPaths = append([]string{
			"/usr/local/apache/logs/error_log",
		}, i.ErrorLogPaths...)
		i.ModSecAuditLogPaths = append([]string{
			"/usr/local/apache/logs/modsec_audit.log",
			"/var/log/modsec_audit.log",
		}, i.ModSecAuditLogPaths...)
	}
}
// dirExists reports whether p exists and is a directory.
func dirExists(p string) bool {
	if fi, err := os.Stat(p); err == nil {
		return fi.IsDir()
	}
	return false
}
package signatures
import (
"archive/zip"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/yara"
)
const (
	// forgeReleasesURL is the GitHub API endpoint for the latest YARA Forge release.
	forgeReleasesURL = "https://api.github.com/repos/YARAHQ/yara-forge/releases/latest"
	// forgeDownloadFmt is the release-asset URL template; %s is the tier name.
	forgeDownloadFmt = "https://github.com/YARAHQ/yara-forge/releases/latest/download/yara-forge-rules-%s.zip"
	forgeHTTPTimeout = 30 * time.Second
	// forgeMaxZIPSize caps rule-set downloads at 20 MiB.
	forgeMaxZIPSize = 20 * 1024 * 1024
)

// forgeTierAsset maps a configured tier name to the .yar file path inside
// that tier's release ZIP. Also serves as the set of valid tier names.
var forgeTierAsset = map[string]string{
	"core":     "packages/core/yara-rules-core.yar",
	"extended": "packages/extended/yara-rules-extended.yar",
	"full":     "packages/full/yara-rules-full.yar",
}
// ForgeUpdate checks for a new YARA Forge release and downloads it if newer.
// A detached signature is fetched from the ZIP URL + ".sig" and verified
// against the raw ZIP content before extraction.
//
// Returns the installed release tag, the rule count of the installed file,
// and any error. When currentVersion already matches the latest release it
// returns (currentVersion, 0, nil) without downloading anything.
func ForgeUpdate(rulesDir, tier, currentVersion, signingKey string, disabledRules []string) (newVersion string, ruleCount int, err error) {
	if _, ok := forgeTierAsset[tier]; !ok {
		return "", 0, fmt.Errorf("unknown YARA Forge tier: %q (valid: core, extended, full)", tier)
	}
	// Refuse to proceed without a signing key: unsigned rules never install.
	if e := requireSigningKey(signingKey); e != nil {
		return "", 0, e
	}
	latestTag, err := forgeLatestTag()
	if err != nil {
		return "", 0, fmt.Errorf("checking YARA Forge release: %w", err)
	}
	// Already up to date: nothing to download or install.
	if latestTag == currentVersion {
		return currentVersion, 0, nil
	}
	zipURL := fmt.Sprintf(forgeDownloadFmt, tier)
	zipData, err := forgeDownload(zipURL)
	if err != nil {
		return "", 0, fmt.Errorf("downloading YARA Forge %s: %w", tier, err)
	}
	// The signature is verified over the raw ZIP bytes BEFORE extraction,
	// so unverified content is never parsed.
	sig, err := fetchSignature(zipURL + ".sig")
	if err != nil {
		return "", 0, fmt.Errorf("YARA Forge signature verification required but failed: %w", err)
	}
	if e := VerifySignature(signingKey, zipData, sig); e != nil {
		return "", 0, fmt.Errorf("YARA Forge signature invalid: %w", e)
	}
	assetPath := forgeTierAsset[tier]
	yarContent, err := forgeExtractYar(zipData, assetPath)
	if err != nil {
		return "", 0, fmt.Errorf("extracting YARA Forge rules: %w", err)
	}
	if len(disabledRules) > 0 {
		yarContent = filterDisabledRules(yarContent, disabledRules)
	}
	ruleCount = countRules(yarContent)
	outFile := filepath.Join(rulesDir, fmt.Sprintf("yara-forge-%s.yar", tier))
	tmpFile := outFile + ".tmp"
	if err := os.WriteFile(tmpFile, yarContent, 0600); err != nil {
		return "", 0, fmt.Errorf("writing temp file: %w", err)
	}
	// Compile-test the new rules before touching the active file; a
	// failure leaves the previously installed rule set in place.
	if err := yara.TestCompile(string(yarContent)); err != nil {
		_ = os.Remove(tmpFile)
		return "", 0, fmt.Errorf("YARA compilation test failed (keeping existing rules): %w", err)
	}
	// Remove other tier files (only one tier active at a time)
	for t := range forgeTierAsset {
		if t != tier {
			_ = os.Remove(filepath.Join(rulesDir, fmt.Sprintf("yara-forge-%s.yar", t)))
		}
	}
	// Atomic install: rename the fully written temp file into place.
	if err := os.Rename(tmpFile, outFile); err != nil {
		_ = os.Remove(tmpFile)
		return "", 0, fmt.Errorf("installing rules: %w", err)
	}
	return latestTag, ruleCount, nil
}
// forgeLatestTag queries the GitHub releases API for the newest YARA Forge
// release and returns its tag name.
func forgeLatestTag() (string, error) {
	client := &http.Client{Timeout: forgeHTTPTimeout}
	resp, err := client.Get(forgeReleasesURL)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("GitHub API returned %d", resp.StatusCode)
	}
	// Cap the body read at 1 MiB: the release document is tiny and this
	// guards against a misbehaving endpoint.
	body := io.LimitReader(resp.Body, 1024*1024)
	var release struct {
		TagName string `json:"tag_name"`
	}
	if err = json.NewDecoder(body).Decode(&release); err != nil {
		return "", fmt.Errorf("parsing release JSON: %w", err)
	}
	if release.TagName == "" {
		return "", fmt.Errorf("empty tag_name in release")
	}
	return release.TagName, nil
}
// forgeDownload fetches url and returns the raw response body.
//
// Bodies larger than forgeMaxZIPSize are rejected with an explicit error.
// The previous code limited the read to exactly forgeMaxZIPSize, silently
// truncating oversized downloads — the damage only surfaced later as a
// confusing signature-verification failure on the truncated ZIP.
func forgeDownload(url string) ([]byte, error) {
	client := &http.Client{Timeout: forgeHTTPTimeout}
	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("download returned %d", resp.StatusCode)
	}
	// Read one byte past the cap so an oversized body is detectable.
	data, err := io.ReadAll(io.LimitReader(resp.Body, forgeMaxZIPSize+1))
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}
	if len(data) > forgeMaxZIPSize {
		return nil, fmt.Errorf("download exceeds %d byte limit", forgeMaxZIPSize)
	}
	return data, nil
}
func forgeExtractYar(zipData []byte, assetPath string) ([]byte, error) {
reader, err := zip.NewReader(bytes.NewReader(zipData), int64(len(zipData)))
if err != nil {
return nil, fmt.Errorf("opening ZIP: %w", err)
}
for _, f := range reader.File {
if f.Name == assetPath {
rc, err := f.Open()
if err != nil {
return nil, fmt.Errorf("opening %s in ZIP: %w", assetPath, err)
}
defer rc.Close()
data, err := io.ReadAll(rc)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", assetPath, err)
}
return data, nil
}
}
return nil, fmt.Errorf("asset %s not found in ZIP", assetPath)
}
func filterDisabledRules(content []byte, disabled []string) []byte {
if len(disabled) == 0 {
return content
}
disabledSet := make(map[string]bool, len(disabled))
for _, name := range disabled {
disabledSet[name] = true
}
lines := strings.Split(string(content), "\n")
var result []string
skipping := false
braceDepth := 0
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if skipping {
for _, ch := range trimmed {
switch ch {
case '{':
braceDepth++
case '}':
braceDepth--
}
}
if braceDepth <= 0 {
skipping = false
braceDepth = 0
}
continue
}
ruleName := extractRuleName(trimmed)
if ruleName != "" && disabledSet[ruleName] {
skipping = true
braceDepth = 0
for _, ch := range trimmed {
switch ch {
case '{':
braceDepth++
case '}':
braceDepth--
}
}
if braceDepth <= 0 {
skipping = false
braceDepth = 0
}
continue
}
result = append(result, line)
}
return []byte(strings.Join(result, "\n"))
}
func extractRuleName(line string) string {
s := strings.TrimPrefix(line, "private ")
if !strings.HasPrefix(s, "rule ") {
return ""
}
s = s[5:]
for i, ch := range s {
if ch == ' ' || ch == '\t' || ch == ':' || ch == '{' {
return s[:i]
}
}
return s
}
func countRules(content []byte) int {
count := 0
for _, line := range strings.Split(string(content), "\n") {
trimmed := strings.TrimSpace(line)
if extractRuleName(trimmed) != "" {
count++
}
}
return count
}
package signatures
import "sync"
var (
	// globalScanner is the process-wide scanner instance; nil until Init.
	globalScanner *Scanner
	// globalOnce ensures Init's initialization runs exactly once.
	globalOnce sync.Once
)

// Init initializes the global scanner with rules from the given directory.
// Safe to call multiple times - only the first call takes effect.
// Call Reload() on the returned scanner to reload rules (e.g., on SIGHUP).
func Init(rulesDir string) *Scanner {
	globalOnce.Do(func() {
		globalScanner = NewScanner(rulesDir)
	})
	return globalScanner
}

// Global returns the global scanner, or nil if Init() hasn't been called.
// NOTE(review): reads globalScanner without synchronization — safe only if
// Init() happens-before every Global() call; confirm startup ordering.
func Global() *Scanner {
	return globalScanner
}
package signatures
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"gopkg.in/yaml.v3"
)
// Rule represents a single malware detection rule loaded from an external file.
// Literal patterns are matched case-insensitively; regexes are compiled with
// a (?i) prefix by compile(). A rule fires when at least MinMatch patterns
// hit (plus a regex hit when RequireRegex is set) and no exclusion matches.
type Rule struct {
	Name            string   `yaml:"name"`
	Description     string   `yaml:"description"`
	Severity        string   `yaml:"severity"`         // "critical", "high", "warning"
	Category        string   `yaml:"category"`         // "webshell", "backdoor", "phishing", "dropper", "exploit"
	FileTypes       []string `yaml:"file_types"`       // [".php", ".html", "*"] - which extensions to scan
	Patterns        []string `yaml:"patterns"`         // literal string patterns (case-insensitive match)
	Regexes         []string `yaml:"regexes"`          // regex patterns (for complex matching)
	ExcludePatterns []string `yaml:"exclude_patterns"` // if any match, rule is skipped (false positive reduction)
	ExcludeRegexes  []string `yaml:"exclude_regexes"`  // regex exclusions
	MinMatch        int      `yaml:"min_match"`        // minimum patterns that must match (default: 1)
	RequireRegex    bool     `yaml:"require_regex"`    // if true, at least one regex must match in addition to min_match
	// Compiled regexes (populated by Compile())
	compiledRegexes        []*regexp.Regexp
	compiledExcludeRegexes []*regexp.Regexp
}
// RuleFile is the top-level structure of a rules YAML file.
type RuleFile struct {
	Version int    `yaml:"version"` // schema/content version; the max across files becomes Scanner.version
	Updated string `yaml:"updated"` // informational timestamp string; not parsed
	Rules   []Rule `yaml:"rules"`
}
// Scanner holds compiled rules and provides file scanning.
// mu guards rules and version so Reload can swap the rule set while
// scans are in flight.
type Scanner struct {
	mu       sync.RWMutex
	rules    []Rule
	version  int    // highest Version seen across loaded rule files
	rulesDir string // directory Reload scans for .yml/.yaml files
}
// NewScanner creates a scanner that loads rules from the given directory.
// Returns a scanner with no rules if the directory doesn't exist (not an error).
func NewScanner(rulesDir string) *Scanner {
	s := &Scanner{rulesDir: rulesDir}
	// Load errors are deliberately dropped here; callers that care can
	// invoke Reload() themselves and inspect the error.
	_ = s.Reload() // best-effort load on init
	return s
}
// Reload loads/reloads all .yml and .yaml rule files from the rules directory.
//
// Semantics:
//   - Empty rulesDir or missing directory: no-op, nil error.
//   - No YAML files found: nil if no rules were ever loaded; error if rules
//     were previously loaded (the directory was likely wiped by mistake).
//   - Any read/parse/compile failure aborts the reload with an error and
//     leaves the previously loaded rule set untouched.
func (s *Scanner) Reload() error {
	if s.rulesDir == "" {
		return nil
	}
	entries, err := os.ReadDir(s.rulesDir)
	if err != nil {
		if os.IsNotExist(err) {
			return nil // no rules dir = no rules, not an error
		}
		return fmt.Errorf("reading rules dir %s: %w", s.rulesDir, err)
	}
	var allRules []Rule
	maxVersion := 0
	fileCount := 0
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() {
			continue
		}
		ext := strings.ToLower(filepath.Ext(name))
		if ext != ".yml" && ext != ".yaml" {
			continue
		}
		fileCount++
		path := filepath.Join(s.rulesDir, name)
		// #nosec G304 -- filepath.Join under operator-configured rulesDir.
		data, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("reading %s: %w", path, err)
		}
		var rf RuleFile
		if err := yaml.Unmarshal(data, &rf); err != nil {
			return fmt.Errorf("parsing %s: %w", path, err)
		}
		if rf.Version > maxVersion {
			maxVersion = rf.Version
		}
		// Pre-compile regexes now so scan time never pays compile cost
		// and a bad pattern is caught at load time.
		for i := range rf.Rules {
			rule := &rf.Rules[i]
			if err := rule.compile(); err != nil {
				return fmt.Errorf("compiling rule %q in %s: %w", rule.Name, path, err)
			}
			if rule.MinMatch == 0 {
				rule.MinMatch = 1 // default: a single pattern hit suffices
			}
			allRules = append(allRules, *rule)
		}
	}
	if fileCount == 0 {
		s.mu.RLock()
		hadRules := len(s.rules) > 0
		s.mu.RUnlock()
		if hadRules {
			return fmt.Errorf("no signature rule files found in %s", s.rulesDir)
		}
		return nil
	}
	if len(allRules) == 0 {
		return fmt.Errorf("no signature rules loaded from %s", s.rulesDir)
	}
	s.mu.Lock()
	s.rules = allRules
	s.version = maxVersion
	s.mu.Unlock()
	// len(allRules) > 0 is guaranteed here (checked above), so log
	// unconditionally — the previous `if len(allRules) > 0` guard was dead.
	fmt.Fprintf(os.Stderr, "signatures: loaded %d rules (version %d) from %s\n", len(allRules), maxVersion, s.rulesDir)
	return nil
}
// compile pre-compiles regex patterns for a rule.
//
// Both match and exclude regexes are compiled with a (?i) prefix so all
// regex matching is case-insensitive. The first invalid pattern aborts
// compilation with a descriptive error.
func (r *Rule) compile() error {
	compileAll := func(patterns []string, kind string) ([]*regexp.Regexp, error) {
		var compiled []*regexp.Regexp
		for _, p := range patterns {
			re, err := regexp.Compile("(?i)" + p) // (?i) = case-insensitive
			if err != nil {
				return nil, fmt.Errorf("invalid %s '%s': %w", kind, p, err)
			}
			compiled = append(compiled, re)
		}
		return compiled, nil
	}
	r.compiledRegexes = nil
	r.compiledExcludeRegexes = nil
	regexes, err := compileAll(r.Regexes, "regex")
	if err != nil {
		return err
	}
	excludes, err := compileAll(r.ExcludeRegexes, "exclude regex")
	if err != nil {
		return err
	}
	r.compiledRegexes = regexes
	r.compiledExcludeRegexes = excludes
	return nil
}
// Match represents a rule that matched a file.
type Match struct {
	RuleName    string   // Rule.Name of the matching rule
	Description string   // Rule.Description, copied for display
	Severity    string   // Rule.Severity (string form, as authored in YAML)
	Category    string   // Rule.Category
	Matched     []string // which patterns matched (plain patterns and/or regex sources)
}
// ScanContent scans file content against loaded rules.
// fileExt should include the dot (e.g., ".php").
//
// A rule produces a Match when all of the following hold:
//   - it applies to the file extension (FileTypes),
//   - no exclude pattern or exclude regex hits (false-positive guard),
//   - at least MinMatch total hits across plain patterns and regexes,
//   - if RequireRegex is set, at least one of the hits is a regex.
//
// Plain patterns are compared case-insensitively against a lowercased
// copy of the content; regexes (compiled with (?i)) run on the raw bytes.
func (s *Scanner) ScanContent(content []byte, fileExt string) []Match {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if len(s.rules) == 0 {
		return nil
	}
	lowered := strings.ToLower(string(content))
	ext := strings.ToLower(fileExt)
	var results []Match
	for i := range s.rules {
		rule := s.rules[i]
		if !ruleMatchesExt(rule, ext) {
			continue
		}
		// Exclusions veto the whole rule.
		skip := false
		for _, p := range rule.ExcludePatterns {
			if strings.Contains(lowered, strings.ToLower(p)) {
				skip = true
				break
			}
		}
		for j := 0; !skip && j < len(rule.compiledExcludeRegexes); j++ {
			if rule.compiledExcludeRegexes[j].Match(content) {
				skip = true
			}
		}
		if skip {
			continue
		}
		// Collect hits from plain substrings, then compiled regexes.
		var hits []string
		sawRegex := false
		for _, p := range rule.Patterns {
			if strings.Contains(lowered, strings.ToLower(p)) {
				hits = append(hits, p)
			}
		}
		for _, re := range rule.compiledRegexes {
			if re.Match(content) {
				hits = append(hits, re.String())
				sawRegex = true
			}
		}
		if len(hits) < rule.MinMatch {
			continue
		}
		if rule.RequireRegex && !sawRegex {
			continue
		}
		results = append(results, Match{
			RuleName:    rule.Name,
			Description: rule.Description,
			Severity:    rule.Severity,
			Category:    rule.Category,
			Matched:     hits,
		})
	}
	return results
}
// ScanFile reads a file and scans it against loaded rules.
//
// At most maxBytes bytes are read from the start of the file. Open/read
// errors return nil (best-effort: files may vanish between index and scan).
func (s *Scanner) ScanFile(path string, maxBytes int) []Match {
	s.mu.RLock()
	ruleCount := len(s.rules)
	s.mu.RUnlock()
	if ruleCount == 0 {
		return nil
	}
	// #nosec G304 -- ScanFile's whole purpose is to scan a file on disk;
	// `path` comes from the daemon's file index walker or a fanotify event.
	f, err := os.Open(path)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	// A single Read may legally return fewer bytes than requested (short
	// read), which previously scanned only a prefix of the file. Loop
	// until the buffer is full or the reader is exhausted.
	buf := make([]byte, maxBytes)
	total := 0
	for total < maxBytes {
		n, rerr := f.Read(buf[total:])
		total += n
		if rerr != nil {
			break // io.EOF or a real error: scan whatever we got
		}
	}
	if total == 0 {
		return nil
	}
	ext := filepath.Ext(path)
	return s.ScanContent(buf[:total], ext)
}
// RuleCount returns the number of loaded rules.
func (s *Scanner) RuleCount() int {
	s.mu.RLock()
	n := len(s.rules)
	s.mu.RUnlock()
	return n
}
// Version returns the highest version number across loaded rule files.
func (s *Scanner) Version() int {
	s.mu.RLock()
	v := s.version
	s.mu.RUnlock()
	return v
}
// ruleMatchesExt reports whether a rule applies to a file with the given
// (already lowercased) extension. An empty FileTypes list or a "*" entry
// matches everything.
func ruleMatchesExt(rule Rule, ext string) bool {
	if len(rule.FileTypes) == 0 {
		return true // no filter = match all
	}
	for _, candidate := range rule.FileTypes {
		if candidate == "*" {
			return true
		}
		if strings.ToLower(candidate) == ext {
			return true
		}
	}
	return false
}
package signatures
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"gopkg.in/yaml.v3"
)
// Update downloads the latest rules from the configured URL.
// Validates the downloaded rules before installing.
// A detached ed25519 signature is fetched from url+".sig" and verified
// before the rules are installed.
//
// Flow: download (10MB cap) -> fetch + verify signature -> validate YAML
// and compile every rule -> atomic install as {rulesDir}/malware.yml.
// Any failure leaves previously installed rules untouched.
// Returns the number of rules loaded, or error.
func Update(rulesDir, url, signingKey string) (int, error) {
	if url == "" {
		return 0, fmt.Errorf("no update URL configured (set signatures.update_url in csm.yaml)")
	}
	if err := requireSigningKey(signingKey); err != nil {
		return 0, err
	}
	// Download
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return 0, fmt.Errorf("downloading rules from %s: %w", url, err)
	}
	defer func() { _ = resp.Body.Close() }()
	// http.StatusOK instead of a bare 200 literal, matching fetchSignature.
	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("download failed: HTTP %d from %s", resp.StatusCode, url)
	}
	data, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) // max 10MB
	if err != nil {
		return 0, fmt.Errorf("reading response: %w", err)
	}
	// Verify the detached signature before trusting any byte of the payload.
	sig, err := fetchSignature(url + ".sig")
	if err != nil {
		return 0, fmt.Errorf("signature verification required but failed: %w", err)
	}
	if err := VerifySignature(signingKey, data, sig); err != nil {
		return 0, fmt.Errorf("rules signature invalid: %w", err)
	}
	// Validate: must parse as valid YAML rules
	var rf RuleFile
	if err := yaml.Unmarshal(data, &rf); err != nil {
		return 0, fmt.Errorf("invalid rules file: %w", err)
	}
	if len(rf.Rules) == 0 {
		return 0, fmt.Errorf("rules file contains no rules")
	}
	// Validate each rule compiles before installing anything.
	for _, rule := range rf.Rules {
		if err := rule.compile(); err != nil {
			return 0, fmt.Errorf("rule '%s' failed validation: %w", rule.Name, err)
		}
	}
	// Ensure rules directory exists
	if err := os.MkdirAll(rulesDir, 0700); err != nil {
		return 0, fmt.Errorf("creating rules dir: %w", err)
	}
	// Atomic write: temp file + rename
	destPath := filepath.Join(rulesDir, "malware.yml")
	tmpPath := destPath + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0600); err != nil {
		return 0, fmt.Errorf("writing rules: %w", err)
	}
	if err := os.Rename(tmpPath, destPath); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup, consistent with file style
		return 0, fmt.Errorf("installing rules: %w", err)
	}
	return len(rf.Rules), nil
}
package signatures
import (
"crypto/ed25519"
"encoding/hex"
"fmt"
"io"
"net/http"
"time"
)
// requireSigningKey ensures a signing key is configured before any remote
// rule download is attempted; unsigned updates are refused outright.
func requireSigningKey(signingKey string) error {
	if signingKey != "" {
		return nil
	}
	return fmt.Errorf("signatures.signing_key is required for remote rule updates")
}
// VerifySignature checks an ed25519 signature over data using a hex-encoded public key.
// Returns nil if the signature is valid.
func VerifySignature(pubKeyHex string, data, signature []byte) error {
pubKeyBytes, err := hex.DecodeString(pubKeyHex)
if err != nil {
return fmt.Errorf("invalid signing key (bad hex): %w", err)
}
if len(pubKeyBytes) != ed25519.PublicKeySize {
return fmt.Errorf("invalid signing key length: got %d bytes, want %d", len(pubKeyBytes), ed25519.PublicKeySize)
}
if len(signature) != ed25519.SignatureSize {
return fmt.Errorf("invalid signature length: got %d bytes, want %d", len(signature), ed25519.SignatureSize)
}
pubKey := ed25519.PublicKey(pubKeyBytes)
if !ed25519.Verify(pubKey, data, signature) {
return fmt.Errorf("signature verification failed")
}
return nil
}
// fetchSignature downloads a detached signature from sigURL (url + ".sig").
// The response body is capped at 1KB; a valid ed25519 signature is 64 bytes.
func fetchSignature(sigURL string) ([]byte, error) {
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(sigURL)
	if err != nil {
		return nil, fmt.Errorf("downloading signature from %s: %w", sigURL, err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("signature download returned HTTP %d from %s", resp.StatusCode, sigURL)
	}
	payload, err := io.ReadAll(io.LimitReader(resp.Body, 1024))
	if err != nil {
		return nil, fmt.Errorf("reading signature: %w", err)
	}
	return payload, nil
}
package state
import (
"fmt"
"os"
"path/filepath"
"syscall"
)
// LockFile provides file-based locking to prevent concurrent CSM runs.
type LockFile struct {
	path string   // lock file location ({stateDir}/csm.lock)
	file *os.File // open handle holding the flock; nil makes Release a no-op
}
// AcquireLock creates an exclusive lock. Returns error if already locked.
//
// The lock is a non-blocking flock(2) on {stateDir}/csm.lock. The kernel
// releases it automatically if the process dies.
func AcquireLock(stateDir string) (*LockFile, error) {
	lockPath := filepath.Join(stateDir, "csm.lock")
	// #nosec G304 -- filepath.Join under operator-configured stateDir.
	f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0600)
	if err != nil {
		return nil, fmt.Errorf("opening lock file: %w", err)
	}
	// Try non-blocking exclusive lock
	// #nosec G115 -- os.File.Fd returns uintptr but POSIX file descriptors
	// are small non-negative ints (rlimit ~1024); int conversion is lossless.
	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
		_ = f.Close()
		// Only EWOULDBLOCK means "someone else holds the lock". Other
		// errors (EINTR, EBADF, odd filesystems) were previously
		// misreported as a concurrent instance; surface them as-is.
		if err == syscall.EWOULDBLOCK {
			return nil, fmt.Errorf("another CSM instance is already running")
		}
		return nil, fmt.Errorf("locking %s: %w", lockPath, err)
	}
	// Write PID for debugging (best-effort).
	_ = f.Truncate(0)
	fmt.Fprintf(f, "%d\n", os.Getpid())
	return &LockFile{path: lockPath, file: f}, nil
}
// Release drops the flock, closes the handle, and removes the lock file.
// Safe to call on a LockFile whose file handle is nil.
func (l *LockFile) Release() {
	if l.file == nil {
		return
	}
	// #nosec G115 -- see AcquireLock: POSIX fd fits in int.
	_ = syscall.Flock(int(l.file.Fd()), syscall.LOCK_UN)
	_ = l.file.Close()
	os.Remove(l.path)
}
package state
import (
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/store"
)
// Store persists alert-deduplication state plus the most recent scan
// results under a single operator-configured state directory.
type Store struct {
	mu        sync.RWMutex      // guards entries, dirty, savedHash
	path      string            // state directory on disk
	entries   map[string]*Entry // finding key -> dedup entry (plus "_"-prefixed internal keys)
	dirty     bool              // true if state changed since last save
	savedHash string            // hash of last saved state (skips redundant disk writes)
	// LatestFindings holds the full output of the most recent scan cycle.
	// This is what the Findings page shows - "what's wrong right now" -
	// separate from the alert dedup state above which controls "what to email."
	latestMu       sync.RWMutex
	latestFindings []alert.Finding
	latestScanTime time.Time
}
// Entry is the persisted dedup record for one finding key. The Hash slot
// doubles as a raw value container for GetRaw/SetRaw keys.
type Entry struct {
	Hash       string    `json:"hash"`        // content fingerprint (or raw value for _-prefixed keys)
	FirstSeen  time.Time `json:"first_seen"`  // when this key first appeared
	LastSeen   time.Time `json:"last_seen"`   // most recent observation (also serves as throttle timestamp)
	AlertSent  time.Time `json:"alert_sent"`  // when an alert was last emitted for this key
	IsBaseline bool      `json:"is_baseline"` // acknowledged/expected; suppressed from new alerts
}
// Open loads (or creates) the persistent dedup state under path.
//
// Both state.json and latest_findings.json are backed up to *.bak before
// parsing; a corrupt file is logged and treated as empty rather than
// failing startup.
func Open(path string) (*Store, error) {
	if err := os.MkdirAll(path, 0700); err != nil {
		return nil, fmt.Errorf("creating state dir: %w", err)
	}
	st := &Store{
		path:    path,
		entries: make(map[string]*Entry),
	}
	stateFile := filepath.Join(path, "state.json")
	// #nosec G304 -- operator-configured statePath + fixed filename.
	if raw, err := os.ReadFile(stateFile); err == nil {
		// Backup state file before loading in case of corruption.
		// #nosec G703 -- stateFile is filepath.Join(path, "state.json") where
		// path is the operator-configured statePath from csm.yaml.
		_ = os.WriteFile(stateFile+".bak", raw, 0600)
		if perr := json.Unmarshal(raw, &st.entries); perr != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to parse %s: %v (backup saved to %s.bak)\n", stateFile, perr, stateFile)
		}
	}
	// Load latest findings from disk (survives restart)
	latestFile := filepath.Join(path, "latest_findings.json")
	// #nosec G304 -- operator-configured statePath + fixed filename.
	if raw, err := os.ReadFile(latestFile); err == nil {
		// Backup latest findings before loading.
		// #nosec G703 -- latestFile derived the same way as stateFile above.
		_ = os.WriteFile(latestFile+".bak", raw, 0600)
		var loaded []alert.Finding
		if perr := json.Unmarshal(raw, &loaded); perr != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to parse %s: %v (backup saved to %s.bak)\n", latestFile, perr, latestFile)
		} else {
			st.latestFindings = loaded
		}
	}
	return st, nil
}
// Close flushes any unsaved state to disk.
//
// It now takes s.mu: the previous version read s.dirty and called save()
// (which walks s.entries) unlocked, racing with Update/SetBaseline from
// other goroutines. save() itself expects the caller to hold the lock.
func (s *Store) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.dirty {
		return nil
	}
	return s.save()
}
// save atomically persists s.entries to state.json. It does no locking of
// its own: callers (Update, SetBaseline, Close) hold s.mu.
func (s *Store) save() error {
	data, err := json.MarshalIndent(s.entries, "", " ")
	if err != nil {
		return err
	}
	// Skip the disk write entirely when nothing changed since last save.
	digest := fmt.Sprintf("%x", sha256.Sum256(data))
	if digest == s.savedHash {
		s.dirty = false
		return nil
	}
	// Atomic write: temp file + rename, so readers never see a torn file.
	target := filepath.Join(s.path, "state.json")
	tmp := target + ".tmp"
	if err := os.WriteFile(tmp, data, 0600); err != nil {
		return err
	}
	if err := os.Rename(tmp, target); err != nil {
		os.Remove(tmp)
		return err
	}
	s.savedHash = digest
	s.dirty = false
	return nil
}
// findingKey returns the dedup key for a finding.
//
// It delegates to alert.Finding.Key, which implements the identical
// Check:Message[:details-hash] scheme; the previous copy of that logic
// here could silently drift from the canonical one.
func findingKey(f alert.Finding) string {
	return f.Key()
}
// findingHash fingerprints the full content of a finding (check, message,
// details) as a short hex digest used to detect content changes.
func findingHash(f alert.Finding) string {
	payload := fmt.Sprintf("%s:%s:%s", f.Check, f.Message, f.Details)
	sum := sha256.Sum256([]byte(payload))
	return fmt.Sprintf("%x", sum[:8])
}
// FilterNew returns the subset of findings that should trigger an alert:
// unknown keys, findings whose content changed, and known findings whose
// last alert is older than 24h (re-alert). Unchanged baseline entries are
// always suppressed.
func (s *Store) FilterNew(findings []alert.Finding) []alert.Finding {
	s.mu.RLock()
	defer s.mu.RUnlock()
	var fresh []alert.Finding
	for _, f := range findings {
		hash := findingHash(f)
		entry, known := s.entries[findingKey(f)]
		switch {
		case !known:
			// Never seen before: always alert.
			fresh = append(fresh, f)
		case entry.IsBaseline && entry.Hash == hash:
			// Acknowledged and unchanged: stay quiet.
		case entry.Hash != hash:
			// Content changed since last time: alert again.
			fresh = append(fresh, f)
		case !entry.AlertSent.IsZero() && time.Since(entry.AlertSent) > 24*time.Hour:
			// Same finding, but the last alert has expired: re-alert.
			fresh = append(fresh, f)
		}
	}
	return fresh
}
// Update records the current findings into dedup state: new keys get a
// full entry (with AlertSent stamped now), existing keys are refreshed,
// and non-baseline entries unseen for over 24h are pruned. The state is
// saved to disk at the end.
func (s *Store) Update(findings []alert.Finding) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.dirty = true
	now := time.Now()
	present := make(map[string]bool, len(findings))
	for _, f := range findings {
		key := findingKey(f)
		present[key] = true
		hash := findingHash(f)
		if entry, ok := s.entries[key]; ok {
			entry.Hash = hash
			entry.LastSeen = now
			if entry.AlertSent.IsZero() {
				entry.AlertSent = now
			}
			continue
		}
		s.entries[key] = &Entry{
			Hash:      hash,
			FirstSeen: now,
			LastSeen:  now,
			AlertSent: now,
		}
	}
	// Drop non-baseline entries that have not been seen for over 24h.
	for key, entry := range s.entries {
		if present[key] || entry.IsBaseline {
			continue
		}
		if time.Since(entry.LastSeen) > 24*time.Hour {
			delete(s.entries, key)
		}
	}
	if err := s.save(); err != nil {
		fmt.Fprintf(os.Stderr, "state: error saving after update: %v\n", err)
	}
}
// SetBaseline replaces the entire dedup state with the given findings,
// all marked as baseline (expected/acknowledged), and saves to disk.
// Any previous entries - including throttle keys - are discarded.
func (s *Store) SetBaseline(findings []alert.Finding) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.dirty = true
	now := time.Now()
	fresh := make(map[string]*Entry, len(findings))
	for _, f := range findings {
		fresh[findingKey(f)] = &Entry{
			Hash:       findingHash(f),
			FirstSeen:  now,
			LastSeen:   now,
			IsBaseline: true,
		}
	}
	s.entries = fresh
	if err := s.save(); err != nil {
		fmt.Fprintf(os.Stderr, "state: error saving baseline: %v\n", err)
	}
}
// ShouldRunThrottled reports whether checkName is due to run again (at
// least intervalMin minutes since the last run) and, if so, stamps its
// last-run time. Throttle state lives in reserved "_throttle:<name>" keys.
func (s *Store) ShouldRunThrottled(checkName string, intervalMin int) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	key := "_throttle:" + checkName
	entry, ok := s.entries[key]
	if !ok {
		// First ever run: record it and allow.
		s.entries[key] = &Entry{LastSeen: time.Now()}
		return true
	}
	if time.Since(entry.LastSeen) < time.Duration(intervalMin)*time.Minute {
		return false
	}
	entry.LastSeen = time.Now()
	return true
}
// GetRaw returns the raw string stored under key (kept in the Entry.Hash
// slot) and whether the key exists.
func (s *Store) GetRaw(key string) (string, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if entry, ok := s.entries[key]; ok {
		return entry.Hash, true
	}
	return "", false
}
// SetRaw stores an arbitrary string under key, reusing the Entry.Hash
// slot. The store is marked dirty only when something actually changed.
func (s *Store) SetRaw(key, value string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, ok := s.entries[key]
	switch {
	case !ok:
		now := time.Now()
		s.entries[key] = &Entry{Hash: value, FirstSeen: now, LastSeen: now}
		s.dirty = true
	case entry.Hash != value:
		entry.Hash = value
		entry.LastSeen = time.Now()
		s.dirty = true
	}
}
// AppendHistory writes findings to the bbolt store (if available) or
// falls back to the append-only JSONL history file.
// The JSONL fallback is deprecated and will be removed in a future release.
func (s *Store) AppendHistory(findings []alert.Finding) {
	if len(findings) == 0 {
		return
	}
	db := store.Global()
	if db == nil {
		// Deprecated: flat-file JSONL fallback has a truncation race
		// condition; kept only for installs not yet migrated to bbolt.
		fmt.Fprintf(os.Stderr, "DEPRECATION: using JSONL history fallback; migrate to bbolt store\n")
		s.appendHistoryFile(findings)
		return
	}
	if err := db.AppendHistory(findings); err != nil {
		fmt.Fprintf(os.Stderr, "store: append history: %v\n", err)
	}
}
// appendHistoryFile writes findings to the append-only JSONL history file.
// Caps file at 10MB by truncating the oldest half (on a line boundary).
// All failures are silent: history is best-effort.
func (s *Store) appendHistoryFile(findings []alert.Finding) {
	histPath := filepath.Join(s.path, "history.jsonl")
	// Rotate: when the file exceeds 10MB, keep roughly the newest half.
	if info, err := os.Stat(histPath); err == nil && info.Size() > 10*1024*1024 {
		// #nosec G304 -- histPath is {s.path}/history.jsonl; s.path is the
		// operator-configured statePath set at Store creation.
		if raw, err := os.ReadFile(histPath); err == nil {
			// Advance the midpoint to the next newline so only whole
			// JSONL lines survive the truncation.
			cut := len(raw) / 2
			for cut < len(raw) && raw[cut] != '\n' {
				cut++
			}
			if cut < len(raw) {
				// #nosec G703 -- histPath comes from state.historyPath, a
				// filepath.Join under the operator-configured statePath.
				_ = os.WriteFile(histPath, raw[cut+1:], 0600)
			}
		}
	}
	// #nosec G304 -- see histPath derivation above.
	out, err := os.OpenFile(histPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return
	}
	defer func() { _ = out.Close() }()
	for _, f := range findings {
		encoded, err := json.Marshal(f)
		if err != nil {
			continue
		}
		_, _ = out.Write(encoded)
		_, _ = out.Write([]byte("\n"))
	}
}
// PrintStatus writes a human-readable summary of dedup state to stdout.
//
// Changes from the previous version: takes s.mu.RLock (every other
// accessor of s.entries locks; this one raced with the daemon), and uses
// strings.HasPrefix instead of key[0], which panicked on an empty key.
func (s *Store) PrintStatus() {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if len(s.entries) == 0 {
		fmt.Println("No state entries. Run 'csm baseline' first.")
		return
	}
	baselineCount := 0
	activeCount := 0
	for key, entry := range s.entries {
		if entry.IsBaseline {
			baselineCount++
			continue
		}
		// Internal bookkeeping keys (e.g. _throttle:*) are not findings.
		if strings.HasPrefix(key, "_") {
			continue
		}
		activeCount++
		fmt.Printf(" [ACTIVE] %s (first: %s, last: %s)\n",
			key,
			entry.FirstSeen.Format("2006-01-02 15:04"),
			entry.LastSeen.Format("2006-01-02 15:04"),
		)
	}
	fmt.Printf("\nBaseline entries: %d, Active findings: %d\n", baselineCount, activeCount)
}
// Entries returns a snapshot copy of current state entries (thread-safe).
// Internal keys (prefix "_") are excluded; each Entry is deep-copied so
// callers cannot mutate shared state.
//
// Fixes: the previous version shadowed the builtin `copy` and indexed
// k[0], which panicked on an empty key.
func (s *Store) Entries() map[string]*Entry {
	s.mu.RLock()
	defer s.mu.RUnlock()
	snapshot := make(map[string]*Entry, len(s.entries))
	for k, v := range s.entries {
		if strings.HasPrefix(k, "_") {
			continue // skip internal keys
		}
		dup := *v
		snapshot[k] = &dup
	}
	return snapshot
}
// EntryForKey returns the dedup entry for a finding key (check:message).
// The Entry is returned by value so callers cannot mutate shared state.
func (s *Store) EntryForKey(key string) (Entry, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if e, ok := s.entries[key]; ok {
		return *e, true
	}
	return Entry{}, false
}
// ReadHistory reads the last limit entries, starting at offset.
// Returns the findings (newest first) and total count.
// Uses bbolt store when available, falls back to flat-file JSONL.
func (s *Store) ReadHistory(limit, offset int) ([]alert.Finding, int) {
	if db := store.Global(); db != nil {
		return db.ReadHistory(limit, offset)
	}
	// Fallback: flat-file JSONL.
	histFile := filepath.Join(s.path, "history.jsonl")
	// #nosec G304 -- {s.path}/history.jsonl; s.path from operator config.
	raw, err := os.ReadFile(histFile)
	if err != nil {
		return nil, 0
	}
	var parsed []alert.Finding
	for _, line := range splitLines(raw) {
		if len(line) == 0 {
			continue
		}
		var f alert.Finding
		if json.Unmarshal(line, &f) == nil {
			parsed = append(parsed, f)
		}
	}
	// File order is oldest-first; flip to newest-first.
	for lo, hi := 0, len(parsed)-1; lo < hi; lo, hi = lo+1, hi-1 {
		parsed[lo], parsed[hi] = parsed[hi], parsed[lo]
	}
	total := len(parsed)
	if offset >= total {
		return nil, total
	}
	end := offset + limit
	if end > total {
		end = total
	}
	return parsed[offset:end], total
}
// ReadHistorySince returns all findings since the given time (newest
// first). Only the bbolt store has a time index; without it the JSONL
// fallback cannot answer this query, so nil is returned.
func (s *Store) ReadHistorySince(since time.Time) []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}
	return db.ReadHistorySince(since)
}
// AggregateByHour returns 24 hourly severity buckets for the last 24 hours.
// Delegates to the bbolt store; nil when no store is configured.
func (s *Store) AggregateByHour() []store.HourBucket {
	db := store.Global()
	if db == nil {
		return nil
	}
	return db.AggregateByHour()
}
// AggregateByDay returns 30 daily severity buckets for the last 30 days.
// Delegates to the bbolt store; nil when no store is configured.
func (s *Store) AggregateByDay() []store.DayBucket {
	db := store.Global()
	if db == nil {
		return nil
	}
	return db.AggregateByDay()
}
// splitLines splits data on '\n' into non-empty line slices. Empty lines
// (consecutive newlines) are dropped, and a final segment without a
// trailing newline is still included. Returned slices alias data.
func splitLines(data []byte) [][]byte {
	var lines [][]byte
	begin := 0
	for idx := 0; idx < len(data); idx++ {
		if data[idx] != '\n' {
			continue
		}
		if idx > begin {
			lines = append(lines, data[begin:idx])
		}
		begin = idx + 1
	}
	if begin < len(data) {
		lines = append(lines, data[begin:])
	}
	return lines
}
// SetLatestFindings merges scan results into the current findings set.
// Called by the daemon after each periodic scan completes. Merges rather
// than replaces - critical scan results coexist with deep scan results.
// Use ClearLatestFindings() + SetLatestFindings() for a full replace.
func (s *Store) SetLatestFindings(findings []alert.Finding) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()
	// Key the existing set, then let newer results overwrite by key.
	byKey := make(map[string]alert.Finding, len(s.latestFindings)+len(findings))
	for _, f := range s.latestFindings {
		byKey[f.Key()] = f
	}
	for _, f := range findings {
		byKey[f.Key()] = f
	}
	var merged []alert.Finding
	for _, f := range byKey {
		merged = append(merged, f)
	}
	// Cap at 15,000 findings to prevent unbounded memory growth.
	if len(merged) > 15000 {
		merged = merged[:15000]
	}
	s.latestFindings = merged
	s.latestScanTime = time.Now()
	// Persist (atomic temp-file + rename) so findings survive a restart.
	data, _ := json.Marshal(merged)
	tmpPath := filepath.Join(s.path, "latest_findings.json.tmp")
	_ = os.WriteFile(tmpPath, data, 0600)
	_ = os.Rename(tmpPath, filepath.Join(s.path, "latest_findings.json"))
}
// PurgeFindingsByChecks removes all findings whose Check field matches
// any of the given check names and persists the result to disk.
// Used to clear stale performance findings before merging fresh results
// from a scan tier.
func (s *Store) PurgeFindingsByChecks(checks []string) {
	if len(checks) == 0 {
		return
	}
	s.latestMu.Lock()
	defer s.latestMu.Unlock()
	drop := make(map[string]bool, len(checks))
	for _, name := range checks {
		drop[name] = true
	}
	// In-place filter, reusing the backing array.
	kept := s.latestFindings[:0]
	for _, f := range s.latestFindings {
		if !drop[f.Check] {
			kept = append(kept, f)
		}
	}
	s.latestFindings = kept
	// Persist to disk so purged findings don't reappear after restart.
	// Mirrors the persistence logic at the end of SetLatestFindings().
	data, _ := json.Marshal(s.latestFindings)
	tmpPath := filepath.Join(s.path, "latest_findings.json.tmp")
	_ = os.WriteFile(tmpPath, data, 0600)
	_ = os.Rename(tmpPath, filepath.Join(s.path, "latest_findings.json"))
}
// PurgeAndMergeFindings atomically removes findings matching the given
// check names and then merges the new findings, all under one lock
// acquisition so concurrent readers never observe the intermediate
// (purged-but-not-yet-merged) state.
func (s *Store) PurgeAndMergeFindings(purgeChecks []string, findings []alert.Finding) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()
	drop := make(map[string]bool, len(purgeChecks))
	for _, name := range purgeChecks {
		drop[name] = true
	}
	// Keep existing non-purged findings, keyed for dedup.
	byKey := make(map[string]alert.Finding)
	for _, f := range s.latestFindings {
		if !drop[f.Check] {
			byKey[f.Key()] = f
		}
	}
	// Fresh results overwrite older entries with the same key.
	for _, f := range findings {
		byKey[f.Key()] = f
	}
	var merged []alert.Finding
	for _, f := range byKey {
		merged = append(merged, f)
	}
	if len(merged) > 15000 {
		merged = merged[:15000] // memory cap, same as SetLatestFindings
	}
	s.latestFindings = merged
	s.latestScanTime = time.Now()
	// Persist (atomic temp-file + rename).
	data, _ := json.Marshal(merged)
	tmpPath := filepath.Join(s.path, "latest_findings.json.tmp")
	_ = os.WriteFile(tmpPath, data, 0600)
	_ = os.Rename(tmpPath, filepath.Join(s.path, "latest_findings.json"))
}
// ClearLatestFindings removes all findings from the latest set.
// Use before SetLatestFindings for a full replace (e.g. initial scan).
// In-memory only: the on-disk copy is rewritten by the next
// SetLatestFindings call.
func (s *Store) ClearLatestFindings() {
	s.latestMu.Lock()
	s.latestFindings = nil
	s.latestMu.Unlock()
}
// LatestFindings returns a copy of the full results of the most recent
// scan. This is what the Findings page shows - "what's wrong right now."
func (s *Store) LatestFindings() []alert.Finding {
	s.latestMu.RLock()
	defer s.latestMu.RUnlock()
	snapshot := make([]alert.Finding, len(s.latestFindings))
	copy(snapshot, s.latestFindings)
	return snapshot
}
// LatestScanTime returns when the last scan completed.
func (s *Store) LatestScanTime() time.Time {
	s.latestMu.RLock()
	t := s.latestScanTime
	s.latestMu.RUnlock()
	return t
}
// DismissLatestFinding removes a finding from the latest scan results and
// persists the updated set.
//
// Fixes: the previous version only mutated the in-memory slice, so a
// dismissed finding reappeared after a daemon restart — unlike
// PurgeFindingsByChecks, which persists for exactly that reason.
func (s *Store) DismissLatestFinding(key string) {
	s.latestMu.Lock()
	defer s.latestMu.Unlock()
	var filtered []alert.Finding
	for _, f := range s.latestFindings {
		if f.Key() != key {
			filtered = append(filtered, f)
		}
	}
	s.latestFindings = filtered
	// Persist so the dismissal survives a restart (same atomic
	// temp-file + rename scheme as SetLatestFindings).
	data, _ := json.Marshal(s.latestFindings)
	tmpPath := filepath.Join(s.path, "latest_findings.json.tmp")
	_ = os.WriteFile(tmpPath, data, 0600)
	_ = os.Rename(tmpPath, filepath.Join(s.path, "latest_findings.json"))
}
// DismissFinding marks a finding as baseline (acknowledged/dismissed).
// It will no longer appear in active findings or trigger new alerts.
// Unknown keys are ignored.
func (s *Store) DismissFinding(key string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, ok := s.entries[key]
	if !ok {
		return
	}
	entry.IsBaseline = true
	s.dirty = true
}
// ParseKey splits a state key "check:message" into its components.
// Only the first ':' separates; a key without one yields (key, "").
func ParseKey(key string) (check, message string) {
	if c, m, ok := strings.Cut(key, ":"); ok {
		return c, m
	}
	return key, ""
}
// --- Suppression rules ---

// SuppressionRule defines a rule for suppressing specific findings.
type SuppressionRule struct {
	ID          string    `json:"id"`                     // unique rule identifier
	Check       string    `json:"check"`                  // finding Check this rule applies to (exact match)
	PathPattern string    `json:"path_pattern,omitempty"` // optional filepath.Match glob; empty = suppress every finding for Check
	Reason      string    `json:"reason"`                 // operator-supplied justification
	CreatedAt   time.Time `json:"created_at"`             // when the rule was created
}
// LoadSuppressions reads suppression rules from disk. A missing or
// unparsable suppressions.json yields no rules (best-effort by design).
func (s *Store) LoadSuppressions() []SuppressionRule {
	raw, err := os.ReadFile(filepath.Join(s.path, "suppressions.json"))
	if err != nil {
		return nil
	}
	var rules []SuppressionRule
	if json.Unmarshal(raw, &rules) != nil {
		return nil
	}
	return rules
}
// SaveSuppressions writes suppression rules to disk atomically
// (temp file + rename, so readers never see a partial file).
func (s *Store) SaveSuppressions(rules []SuppressionRule) error {
	payload, err := json.MarshalIndent(rules, "", " ")
	if err != nil {
		return err
	}
	dest := filepath.Join(s.path, "suppressions.json")
	tmp := dest + ".tmp"
	if err := os.WriteFile(tmp, payload, 0600); err != nil {
		return err
	}
	if err := os.Rename(tmp, dest); err != nil {
		os.Remove(tmp)
		return err
	}
	return nil
}
// IsSuppressed checks if a finding matches any loaded suppression rule.
// Load rules once with LoadSuppressions() and pass them in to avoid
// re-reading the file for every finding.
//
// Fixes: when f.FilePath was set, the old code glob-matched it twice —
// once directly and once via suppressionPathCandidates, which returns
// exactly {f.FilePath} in that case. A single candidate loop now covers
// both the explicit-path and extracted-from-message cases.
func (s *Store) IsSuppressed(f alert.Finding, rules []SuppressionRule) bool {
	for _, rule := range rules {
		if f.Check != rule.Check {
			continue
		}
		// If no path pattern, suppress all findings for this check type
		if rule.PathPattern == "" {
			return true
		}
		for _, candidate := range suppressionPathCandidates(f) {
			if matched, _ := filepath.Match(rule.PathPattern, candidate); matched {
				return true
			}
		}
	}
	return false
}
// suppressionPathCandidates extracts absolute-path candidates from a
// finding for glob matching. When FilePath is set it is the sole
// candidate; otherwise message tokens that start with "/" are stripped of
// surrounding punctuation, cleaned, and deduplicated.
func suppressionPathCandidates(f alert.Finding) []string {
	if f.FilePath != "" {
		return []string{f.FilePath}
	}
	var out []string
	unique := make(map[string]bool)
	for _, token := range strings.Fields(f.Message) {
		token = strings.Trim(token, `"'():,;[]{}<>`)
		if !strings.HasPrefix(token, "/") {
			continue
		}
		p := filepath.Clean(token)
		if p == "." || p == "/" || unique[p] {
			continue
		}
		unique[p] = true
		out = append(out, p)
	}
	return out
}
package store
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// SeverityBucket holds aggregated counts by severity.
type SeverityBucket struct {
	Critical int `json:"critical"` // count of alert.Critical findings
	High     int `json:"high"`     // count of alert.High findings
	Warning  int `json:"warning"`  // count of alert.Warning findings
	Total    int `json:"total"`    // all findings in the bucket, regardless of severity
}
// HourBucket is a SeverityBucket keyed by hour label.
type HourBucket struct {
	Hour string `json:"hour"` // "HH:00" label for the bucket's hour
	SeverityBucket
}
// DayBucket is a SeverityBucket keyed by date.
type DayBucket struct {
	Date string `json:"date"` // "2006-01-02" formatted day
	SeverityBucket
}
// AggregateByHour returns 24 hourly buckets (oldest first) for the last 24 hours.
// It seeks directly to the start key in bbolt, scanning only the relevant range.
//
// NOTE(review): the seek prefix assumes history keys begin with a
// YYYYMMDDHH timestamp — confirm against the writer's TimeKey format.
func (db *DB) AggregateByHour() []HourBucket {
	now := time.Now()
	currentHour := now.Truncate(time.Hour)
	cutoff := currentHour.Add(-23 * time.Hour)
	// Map: hours-ago (0=current, 23=oldest) → counts
	counts := make(map[int]*SeverityBucket, 24)
	for i := 0; i < 24; i++ {
		counts[i] = &SeverityBucket{}
	}
	seekPrefix := fmt.Sprintf("%04d%02d%02d%02d",
		cutoff.Year(), cutoff.Month(), cutoff.Day(), cutoff.Hour())
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("history"))
		if b == nil {
			return nil // no history recorded yet
		}
		c := b.Cursor()
		// Seek to the earliest key that could be in our 24h window.
		for k, v := c.Seek([]byte(seekPrefix)); k != nil; k, v = c.Next() {
			var f alert.Finding
			if err := json.Unmarshal(v, &f); err != nil {
				continue // skip unreadable records rather than failing the chart
			}
			// The embedded timestamp is the source of truth, not the key:
			// drop anything outside [cutoff, now].
			if f.Timestamp.Before(cutoff) {
				continue
			}
			if f.Timestamp.After(now) {
				continue
			}
			fHour := f.Timestamp.Truncate(time.Hour)
			// NOTE(review): Sub().Hours() assumes whole-hour spacing; around
			// a DST shift a finding may land one bucket off — acceptable
			// for a dashboard chart, but worth confirming.
			hoursAgo := int(currentHour.Sub(fHour).Hours())
			if hoursAgo < 0 || hoursAgo >= 24 {
				continue
			}
			bucket := counts[hoursAgo]
			bucket.Total++
			switch f.Severity {
			case alert.Critical:
				bucket.Critical++
			case alert.High:
				bucket.High++
			case alert.Warning:
				bucket.Warning++
			}
		}
		return nil
	})
	// Build result oldest→newest (23h ago → 0h ago)
	result := make([]HourBucket, 24)
	for i := 0; i < 24; i++ {
		hoursAgo := 23 - i
		t := currentHour.Add(-time.Duration(hoursAgo) * time.Hour)
		result[i] = HourBucket{
			Hour:           fmt.Sprintf("%02d:00", t.Hour()),
			SeverityBucket: *counts[hoursAgo],
		}
	}
	return result
}
// ReadHistorySince returns all findings since the given time, using bbolt
// cursor seeking for efficiency. Results are newest-first. Records that
// fail to decode are skipped silently.
func (db *DB) ReadHistorySince(since time.Time) []alert.Finding {
	prefix := fmt.Sprintf("%04d%02d%02d%02d%02d%02d",
		since.Year(), since.Month(), since.Day(),
		since.Hour(), since.Minute(), since.Second())
	var found []alert.Finding
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bkt := tx.Bucket([]byte("history"))
		if bkt == nil {
			return nil
		}
		cur := bkt.Cursor()
		// Seek straight to the first key at/after `since`, then scan forward.
		for k, v := cur.Seek([]byte(prefix)); k != nil; k, v = cur.Next() {
			var f alert.Finding
			if json.Unmarshal(v, &f) == nil {
				found = append(found, f)
			}
		}
		return nil
	})
	// Key order is oldest-first; flip to newest-first.
	for i, j := 0, len(found)-1; i < j; i, j = i+1, j-1 {
		found[i], found[j] = found[j], found[i]
	}
	return found
}
// AggregateByDay returns 30 daily buckets (oldest first) for the last 30 days.
// It seeks directly to the start key in bbolt, scanning only the relevant range.
//
// NOTE(review): like AggregateByHour, this assumes history keys begin
// with a YYYYMMDD timestamp — confirm against the writer's TimeKey format.
func (db *DB) AggregateByDay() []DayBucket {
	now := time.Now()
	// Midnight local time today; the window is [today-29d, today].
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local)
	cutoff := today.AddDate(0, 0, -29)
	// Map: date string → bucket index
	dateIndex := make(map[string]int, 30)
	buckets := make([]DayBucket, 30)
	for i := 0; i < 30; i++ {
		d := cutoff.AddDate(0, 0, i)
		key := d.Format("2006-01-02")
		buckets[i] = DayBucket{Date: key}
		dateIndex[key] = i
	}
	seekPrefix := fmt.Sprintf("%04d%02d%02d",
		cutoff.Year(), cutoff.Month(), cutoff.Day())
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("history"))
		if b == nil {
			return nil // no history recorded yet
		}
		c := b.Cursor()
		for k, v := c.Seek([]byte(seekPrefix)); k != nil; k, v = c.Next() {
			var f alert.Finding
			if err := json.Unmarshal(v, &f); err != nil {
				continue // skip unreadable records
			}
			// The embedded timestamp is the source of truth, not the key.
			if f.Timestamp.Before(cutoff) {
				continue
			}
			key := f.Timestamp.Format("2006-01-02")
			idx, ok := dateIndex[key]
			if !ok {
				continue // future-dated or otherwise outside the 30-day index
			}
			buckets[idx].Total++
			switch f.Severity {
			case alert.Critical:
				buckets[idx].Critical++
			case alert.High:
				buckets[idx].High++
			case alert.Warning:
				buckets[idx].Warning++
			}
		}
		return nil
	})
	return buckets
}
package store
import (
"bytes"
"encoding/json"
"fmt"
"time"
bolt "go.etcd.io/bbolt"
)
// maxAttackEvents is the maximum number of attack events to retain in the
// attacks:events bucket; RecordAttackEvent prunes the oldest entries past it.
// It is a var (not const) so tests can override it.
var maxAttackEvents = 100_000
// AttackEvent is the store-layer representation of an attack event.
// RecordAttackEvent persists it to attacks:events (keyed by TimeKey) and
// mirrors it into attacks:events:ip (keyed by "IP/TimeKey").
type AttackEvent struct {
	Timestamp  time.Time `json:"timestamp"`         // when the event occurred; also encoded in the bucket key
	IP         string    `json:"ip"`                // source address; used as the secondary-index prefix
	AttackType string    `json:"attack_type"`       // attack category as reported by the producing check
	CheckName  string    `json:"check_name"`        // name of the check that produced the event
	Severity   int       `json:"severity"`          // numeric severity (mapping defined by the producer)
	Account    string    `json:"account,omitempty"` // targeted account, if any
	Message    string    `json:"message,omitempty"` // human-readable description
}
// IPRecord is the store-layer representation of an IP attack record,
// persisted in the attacks:records bucket keyed by IP.
type IPRecord struct {
	IP           string         `json:"ip"`                      // record key (also stored inline)
	FirstSeen    time.Time      `json:"first_seen"`              // first observed activity from this IP
	LastSeen     time.Time      `json:"last_seen"`               // most recent observed activity
	EventCount   int            `json:"event_count"`             // total events attributed to this IP
	AttackCounts map[string]int `json:"attack_counts,omitempty"` // per-attack-type tallies
	Accounts     map[string]int `json:"accounts,omitempty"`      // per-account tallies
	ThreatScore  int            `json:"threat_score"`            // scoring semantics defined by the caller
	AutoBlocked  bool           `json:"auto_blocked,omitempty"`  // set when an automatic block was applied
}
// RecordAttackEvent inserts an attack event into both the primary bucket
// (attacks:events, keyed by TimeKey) and the secondary index bucket
// (attacks:events:ip, keyed by IP/TimeKey). It increments the event counter
// and prunes oldest entries if the count exceeds maxAttackEvents.
// counter is a caller-supplied tiebreaker folded into TimeKey so events
// sharing a timestamp within one batch get distinct keys.
func (db *DB) RecordAttackEvent(event AttackEvent, counter int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		primary := tx.Bucket([]byte("attacks:events"))
		secondary := tx.Bucket([]byte("attacks:events:ip"))
		key := TimeKey(event.Timestamp, counter)
		val, err := json.Marshal(event)
		if err != nil {
			return err
		}
		// The same serialized value goes into both buckets; the secondary
		// key prefixes the time key with the IP for per-IP range scans.
		if err := primary.Put([]byte(key), val); err != nil {
			return err
		}
		secondaryKey := event.IP + "/" + key
		if err := secondary.Put([]byte(secondaryKey), val); err != nil {
			return err
		}
		if err := incrCounter(tx, "attacks:events:count", 1); err != nil {
			return err
		}
		// Prune oldest entries if count exceeds maxAttackEvents.
		// Re-read the counter from meta since incrCounter just updated it.
		meta := tx.Bucket([]byte("meta"))
		var count int
		if v := meta.Get([]byte("attacks:events:count")); v != nil {
			_, _ = fmt.Sscanf(string(v), "%d", &count)
		}
		if count > maxAttackEvents {
			excess := count - maxAttackEvents
			c := primary.Cursor()
			k, v := c.First()
			for ; k != nil && excess > 0; excess-- {
				// Unmarshal to get the IP for secondary index cleanup.
				var ev AttackEvent
				if err := json.Unmarshal(v, &ev); err != nil {
					return err
				}
				secKey := ev.IP + "/" + string(k)
				if err := secondary.Delete([]byte(secKey)); err != nil {
					return err
				}
				if err := c.Delete(); err != nil {
					return err
				}
				// Re-seek after delete (bbolt cursor behavior).
				k, v = c.First()
			}
			if err := setCounter(tx, "attacks:events:count", maxAttackEvents); err != nil {
				return err
			}
		}
		return nil
	})
}
// QueryAttackEvents returns up to limit attack events for the given IP,
// newest-first. It uses the secondary index bucket for efficient prefix-based
// iteration. A negative limit is treated as 0 (no events).
func (db *DB) QueryAttackEvents(ip string, limit int) []AttackEvent {
	// Guard: a negative limit previously caused results[:limit] to panic
	// with a slice bounds error whenever any event matched.
	if limit < 0 {
		limit = 0
	}
	var results []AttackEvent
	prefix := []byte(ip + "/")
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("attacks:events:ip"))
		c := b.Cursor()
		// Collect all events matching the "IP/" prefix (oldest first,
		// since keys sort chronologically within a prefix).
		for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
			var ev AttackEvent
			if err := json.Unmarshal(v, &ev); err == nil {
				results = append(results, ev)
			}
		}
		return nil
	})
	// Reverse for newest-first order.
	for i, j := 0, len(results)-1; i < j; i, j = i+1, j-1 {
		results[i], results[j] = results[j], results[i]
	}
	// Take up to limit.
	if len(results) > limit {
		results = results[:limit]
	}
	return results
}
// SaveIPRecord stores an IP record in the attacks:records bucket, keyed by IP.
func (db *DB) SaveIPRecord(record IPRecord) error {
	data, err := json.Marshal(record)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("attacks:records")).Put([]byte(record.IP), data)
	})
}
// LoadIPRecord retrieves an IP record from the attacks:records bucket.
// Returns the record and true if found, or a zero value and false if not.
func (db *DB) LoadIPRecord(ip string) (IPRecord, bool) {
	var (
		record IPRecord
		found  bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("attacks:records")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &record); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return record, found
}
// LoadAllIPRecords returns all IP records from the attacks:records bucket,
// keyed by IP. Corrupt entries are skipped.
func (db *DB) LoadAllIPRecords() map[string]*IPRecord {
	out := make(map[string]*IPRecord)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("attacks:records")).ForEach(func(k, v []byte) error {
			rec := new(IPRecord)
			if err := json.Unmarshal(v, rec); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out[string(k)] = rec
			return nil
		})
	})
	return out
}
// DeleteIPRecord removes an IP record from the attacks:records bucket.
// Deleting a missing key is a no-op in bbolt.
func (db *DB) DeleteIPRecord(ip string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("attacks:records")).Delete([]byte(ip))
	})
}
// ReadAllAttackEvents returns all attack events from the primary bucket in
// key (chronological) order. Used for stats computation (hourly/daily
// bucketing). Corrupt entries are skipped.
func (db *DB) ReadAllAttackEvents() []AttackEvent {
	var out []AttackEvent
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("attacks:events")).ForEach(func(_, v []byte) error {
			var ev AttackEvent
			if err := json.Unmarshal(v, &ev); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out = append(out, ev)
			return nil
		})
	})
	return out
}
package store
import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
bolt "go.etcd.io/bbolt"
)
// Bucket names - all buckets are created on Open().
// Grouped by subsystem; ":"-separated names are a flat convention, not
// nested bbolt buckets.
var bucketNames = []string{
	"history",            // alert findings, keyed by TimeKey
	"attacks:records",    // per-IP aggregate records, keyed by IP
	"attacks:events",     // attack events, keyed by TimeKey
	"attacks:events:ip",  // secondary index, keyed by "IP/TimeKey"
	"threats",            // permanent threat blocks
	"threats:whitelist",  // threat whitelist entries
	"fw:blocked",         // firewall-blocked IPs
	"fw:allowed",         // firewall-allowed IPs
	"fw:subnets",         // firewall-blocked subnets
	"fw:port_allowed",    // per-IP port allow rules
	"reputation",         // IP reputation cache
	"plugins",            // plugin state
	"plugins:sites",      // per-site plugin state
	"meta",               // counters, migration flags, misc key/value
	"email:geo",          // per-mailbox geo login history
	"email:fwd",          // forwarder config hashes
}
// DB wraps a bbolt database.
type DB struct {
	bolt *bolt.DB // underlying bbolt handle; nil only for a zero-value DB
	path string   // absolute path of the csm.db file, set by Open
}
var (
	globalDB   *DB        // process-wide singleton, guarded by globalMu
	globalMu   sync.Mutex // protects globalDB
	ensureOnce sync.Once  // makes EnsureOpen's open attempt happen once
	ensureErr  error      // sticky result of the first EnsureOpen attempt
)
// Global returns the singleton DB instance (nil if none has been set).
func Global() *DB {
	globalMu.Lock()
	db := globalDB
	globalMu.Unlock()
	return db
}
// SetGlobal sets the singleton DB instance.
func SetGlobal(db *DB) {
	globalMu.Lock()
	defer globalMu.Unlock()
	globalDB = db
}
// EnsureOpen opens the store if not already open. Safe to call from any CLI
// path. Only the first call attempts the open; its result (including any
// error) is cached and returned to all subsequent callers.
func EnsureOpen(statePath string) error {
	ensureOnce.Do(func() {
		var db *DB
		if db, ensureErr = Open(statePath); ensureErr == nil {
			SetGlobal(db)
		}
	})
	return ensureErr
}
// Open opens or creates the bbolt database at {statePath}/csm.db.
// Creates all buckets if they don't exist. Runs migration if needed.
// The bbolt open uses a 5s timeout so a second process holding the file
// lock fails fast instead of blocking forever.
func Open(statePath string) (*DB, error) {
	dbPath := filepath.Join(statePath, "csm.db")
	if err := os.MkdirAll(statePath, 0700); err != nil {
		return nil, fmt.Errorf("creating state dir: %w", err)
	}
	bdb, err := bolt.Open(dbPath, 0600, &bolt.Options{Timeout: 5 * time.Second})
	if err != nil {
		return nil, fmt.Errorf("opening bbolt: %w", err)
	}
	// Create all buckets
	err = bdb.Update(func(tx *bolt.Tx) error {
		for _, name := range bucketNames {
			if _, berr := tx.CreateBucketIfNotExists([]byte(name)); berr != nil {
				return fmt.Errorf("creating bucket %s: %w", name, berr)
			}
		}
		return nil
	})
	if err != nil {
		_ = bdb.Close()
		return nil, err
	}
	db := &DB{bolt: bdb, path: dbPath}
	// Run migration if needed. Migration failure is deliberately non-fatal:
	// it is logged and the DB is still returned usable.
	if err := db.migrateIfNeeded(statePath); err != nil {
		fmt.Fprintf(os.Stderr, "store: migration warning: %v\n", err)
	}
	// Seed default ModSecurity no-escalate rules (one-time only).
	// Uses a sentinel key so an admin who deliberately empties the set
	// won't have defaults re-added on every restart.
	var seeded bool
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if v := tx.Bucket([]byte("meta")).Get([]byte("modsec:no_escalate_seeded")); v != nil {
			seeded = true
		}
		return nil
	})
	if !seeded {
		_ = db.SetModSecNoEscalateRules(map[int]bool{
			900112: true, // WordPress user enumeration - blocks at HTTP level only
		})
		_ = db.bolt.Update(func(tx *bolt.Tx) error {
			return tx.Bucket([]byte("meta")).Put([]byte("modsec:no_escalate_seeded"), []byte("1"))
		})
	}
	return db, nil
}
// Close closes the underlying bbolt database. Calling Close on a zero-value
// DB (nil handle) is a no-op.
func (db *DB) Close() error {
	if db.bolt != nil {
		return db.bolt.Close()
	}
	return nil
}
// TimeKey produces a fixed-width 28-byte key for chronological ordering.
// Format: YYYYMMDDHHmmssNNNNNNNNN-CCCC
// Lexicographic order equals chronological order.
func TimeKey(t time.Time, counter int) string {
return fmt.Sprintf("%04d%02d%02d%02d%02d%02d%09d-%04d",
t.Year(), t.Month(), t.Day(),
t.Hour(), t.Minute(), t.Second(),
t.Nanosecond(), counter)
}
// ParseTimeKeyPrefix converts a date string "YYYY-MM-DD" to a seek prefix
// "YYYYMMDD". Input not matching that shape is returned unchanged.
func ParseTimeKeyPrefix(date string) string {
	if len(date) != 10 || date[4] != '-' || date[7] != '-' {
		return date
	}
	return date[:4] + date[5:7] + date[8:10]
}
// HistoryCount returns the number of findings in the history bucket,
// as tracked by the history:count counter in the meta bucket.
// (The comment previously here described getCounter, not this function.)
func (db *DB) HistoryCount() int {
	return db.getCounter("history:count")
}
// getCounter reads an integer counter from the meta bucket.
// A missing or unparsable value yields 0.
func (db *DB) getCounter(key string) int {
	count := 0
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if v := tx.Bucket([]byte("meta")).Get([]byte(key)); v != nil {
			_, _ = fmt.Sscanf(string(v), "%d", &count)
		}
		return nil
	})
	return count
}
// setCounter writes a counter to the meta bucket within an existing
// transaction. The value is stored as its decimal string form.
func setCounter(tx *bolt.Tx, key string, count int) error {
	val := []byte(fmt.Sprintf("%d", count))
	return tx.Bucket([]byte("meta")).Put([]byte(key), val)
}
// incrCounter increments a counter within an existing transaction.
func incrCounter(tx *bolt.Tx, key string, delta int) error {
b := tx.Bucket([]byte("meta"))
var current int
if v := b.Get([]byte(key)); v != nil {
fmt.Sscanf(string(v), "%d", ¤t)
}
return b.Put([]byte(key), []byte(fmt.Sprintf("%d", current+delta)))
}
// migrateIfNeeded checks for the "migrated" key in the meta bucket and runs
// the flat-file migration when it is absent.
func (db *DB) migrateIfNeeded(statePath string) error {
	migrated := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		migrated = tx.Bucket([]byte("meta")).Get([]byte("migrated")) != nil
		return nil
	})
	if migrated {
		return nil
	}
	return db.runMigration(statePath)
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// GeoHistory tracks the countries from which a mailbox has logged in.
type GeoHistory struct {
	Countries  map[string]int64 `json:"countries"`   // per-country login tally (keys presumably country codes — confirm against writer)
	LoginCount int              `json:"login_count"` // total logins recorded
}
// SetGeoHistory stores geo login history for a mailbox in the email:geo bucket.
func (db *DB) SetGeoHistory(mailbox string, h GeoHistory) error {
	data, err := json.Marshal(h)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("email:geo")).Put([]byte(mailbox), data)
	})
}
// GetGeoHistory retrieves geo login history for a mailbox.
// Returns the entry and true if found, or a zero value and false if not.
func (db *DB) GetGeoHistory(mailbox string) (GeoHistory, bool) {
	var (
		hist  GeoHistory
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("email:geo")).Get([]byte(mailbox))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &hist); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return hist, found
}
// SetForwarderHash stores a forwarder config hash in the email:fwd bucket.
func (db *DB) SetForwarderHash(key, hash string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("email:fwd")).Put([]byte(key), []byte(hash))
	})
}
// GetForwarderHash retrieves a forwarder config hash.
// Returns the hash and true if found, or an empty string and false if not.
func (db *DB) GetForwarderHash(key string) (string, bool) {
	var (
		hash  string
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if raw := tx.Bucket([]byte("email:fwd")).Get([]byte(key)); raw != nil {
			hash = string(raw)
			found = true
		}
		return nil
	})
	return hash, found
}
// GetEmailPWLastRefresh reads the last email password-check refresh timestamp
// from the meta bucket. Returns the zero time if not set or unparsable.
func (db *DB) GetEmailPWLastRefresh() time.Time {
	var last time.Time
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte("email:pw_last_refresh"))
		if raw == nil {
			return nil
		}
		if parsed, err := time.Parse(time.RFC3339, string(raw)); err == nil {
			last = parsed
		}
		return nil
	})
	return last
}
// SetEmailPWLastRefresh writes the email password-check refresh timestamp
// to the meta bucket in RFC 3339 form.
func (db *DB) SetEmailPWLastRefresh(t time.Time) error {
	val := []byte(t.Format(time.RFC3339))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte("email:pw_last_refresh"), val)
	})
}
// GetMetaString reads a string value from the meta bucket.
// Returns an empty string if the key is not found.
func (db *DB) GetMetaString(key string) string {
	var out string
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		if raw := tx.Bucket([]byte("meta")).Get([]byte(key)); raw != nil {
			out = string(raw)
		}
		return nil
	})
	return out
}
// SetMetaString writes a string value to the meta bucket.
func (db *DB) SetMetaString(key, val string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte(key), []byte(val))
	})
}
package store
import (
"encoding/json"
"fmt"
"time"
"github.com/pidginhost/csm/internal/firewall"
bolt "go.etcd.io/bbolt"
)
// FWBlockedEntry represents an IP blocked by the firewall,
// stored in fw:blocked keyed by IP.
type FWBlockedEntry struct {
	IP        string    `json:"ip"`               // blocked address (also the bucket key)
	Reason    string    `json:"reason"`           // operator/automation supplied reason
	Source    string    `json:"source,omitempty"` // provenance inferred from the reason (see BlockIP)
	BlockedAt time.Time `json:"blocked_at"`       // when the block was recorded
	ExpiresAt time.Time `json:"expires_at"`       // zero = permanent
}
// FWAllowedEntry represents an IP explicitly allowed through the firewall,
// stored in fw:allowed keyed by IP.
type FWAllowedEntry struct {
	IP        string    `json:"ip"`               // allowed address (also the bucket key)
	Reason    string    `json:"reason"`           // operator/automation supplied reason
	Source    string    `json:"source,omitempty"` // provenance inferred from the reason (see AllowIP)
	Port      int       `json:"port"`             // 0 = all ports
	ExpiresAt time.Time `json:"expires_at"`       // zero = permanent
}
// FWSubnetEntry represents a subnet added to the firewall,
// stored in fw:subnets keyed by CIDR.
type FWSubnetEntry struct {
	CIDR    string    `json:"cidr"`             // subnet in CIDR notation (also the bucket key)
	Reason  string    `json:"reason"`           // operator/automation supplied reason
	Source  string    `json:"source,omitempty"` // provenance inferred from the reason (see AddSubnet)
	AddedAt time.Time `json:"added_at"`         // when the subnet was recorded
}
// FWPortAllowEntry represents a per-IP port allow rule,
// stored in fw:port_allowed keyed by the composite Key.
type FWPortAllowEntry struct {
	Key    string `json:"key"`              // composite "IP:port/proto" (also the bucket key)
	IP     string `json:"ip"`               // allowed address
	Port   int    `json:"port"`             // allowed port
	Proto  string `json:"proto"`            // protocol component of the key (e.g. as passed by the caller)
	Reason string `json:"reason"`           // operator/automation supplied reason
	Source string `json:"source,omitempty"` // provenance inferred from the reason (see AddPortAllow)
}
// FirewallState holds the full state across all 4 firewall buckets,
// as assembled by LoadFirewallState.
type FirewallState struct {
	Blocked     []FWBlockedEntry  // fw:blocked, expired entries filtered out
	Allowed     []FWAllowedEntry  // fw:allowed (not expiry-filtered here)
	Subnets     []FWSubnetEntry   // fw:subnets
	PortAllowed []FWPortAllowEntry // fw:port_allowed
}
// portAllowKey returns the composite key "IP:port/proto" used in the
// fw:port_allowed bucket.
func portAllowKey(ip string, port int, proto string) string {
	return ip + ":" + fmt.Sprint(port) + "/" + proto
}
// BlockIP adds an IP to the fw:blocked bucket. A zero expiresAt records a
// permanent block.
func (db *DB) BlockIP(ip, reason string, expiresAt time.Time) error {
	entry := FWBlockedEntry{
		IP:        ip,
		Reason:    reason,
		Source:    firewall.InferProvenance("block", reason),
		BlockedAt: time.Now(),
		ExpiresAt: expiresAt,
	}
	data, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:blocked")).Put([]byte(ip), data)
	})
}
// UnblockIP removes an IP from the fw:blocked bucket (no-op if absent).
func (db *DB) UnblockIP(ip string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:blocked")).Delete([]byte(ip))
	})
}
// GetBlockedIP looks up a blocked IP. Returns false if not found, corrupt,
// or already expired (zero ExpiresAt means permanent and never expires).
func (db *DB) GetBlockedIP(ip string) (FWBlockedEntry, bool) {
	var (
		entry FWBlockedEntry
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("fw:blocked")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &entry); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		expired := !entry.ExpiresAt.IsZero() && !entry.ExpiresAt.After(time.Now())
		found = !expired
		return nil
	})
	return entry, found
}
// AllowIP adds an IP to the fw:allowed bucket. A zero expiresAt records a
// permanent allow.
func (db *DB) AllowIP(ip, reason string, expiresAt time.Time) error {
	entry := FWAllowedEntry{
		IP:        ip,
		Reason:    reason,
		Source:    firewall.InferProvenance("allow", reason),
		ExpiresAt: expiresAt,
	}
	data, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:allowed")).Put([]byte(ip), data)
	})
}
// RemoveAllow removes an IP from the fw:allowed bucket (no-op if absent).
func (db *DB) RemoveAllow(ip string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:allowed")).Delete([]byte(ip))
	})
}
// AddSubnet adds a CIDR to the fw:subnets bucket, keyed by the CIDR string.
func (db *DB) AddSubnet(cidr, reason string) error {
	entry := FWSubnetEntry{
		CIDR:    cidr,
		Reason:  reason,
		Source:  firewall.InferProvenance("block_subnet", reason),
		AddedAt: time.Now(),
	}
	data, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:subnets")).Put([]byte(cidr), data)
	})
}
// RemoveSubnet removes a CIDR from the fw:subnets bucket (no-op if absent).
func (db *DB) RemoveSubnet(cidr string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:subnets")).Delete([]byte(cidr))
	})
}
// AddPortAllow adds a per-IP port allow rule to the fw:port_allowed bucket,
// keyed by the composite "IP:port/proto".
func (db *DB) AddPortAllow(ip string, port int, proto, reason string) error {
	key := portAllowKey(ip, port, proto)
	entry := FWPortAllowEntry{
		Key:    key,
		IP:     ip,
		Port:   port,
		Proto:  proto,
		Reason: reason,
		Source: firewall.InferProvenance("allow_port", reason),
	}
	data, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:port_allowed")).Put([]byte(key), data)
	})
}
// RemovePortAllow removes a per-IP port allow rule from the fw:port_allowed
// bucket (no-op if absent).
func (db *DB) RemovePortAllow(ip string, port int, proto string) error {
	key := []byte(portAllowKey(ip, port, proto))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:port_allowed")).Delete(key)
	})
}
// ListPortAllows returns all entries in the fw:port_allowed bucket.
// Corrupt entries are skipped.
func (db *DB) ListPortAllows() []FWPortAllowEntry {
	var out []FWPortAllowEntry
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("fw:port_allowed")).ForEach(func(_, v []byte) error {
			var e FWPortAllowEntry
			if err := json.Unmarshal(v, &e); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out = append(out, e)
			return nil
		})
	})
	return out
}
// LoadFirewallState reads all 4 firewall buckets and assembles a
// FirewallState. Expired blocked entries are filtered out (only the
// blocked bucket carries expiry filtering here); corrupt entries in any
// bucket are skipped.
func (db *DB) LoadFirewallState() FirewallState {
	var st FirewallState
	now := time.Now()
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		// fw:blocked - drop entries whose non-zero expiry has passed.
		_ = tx.Bucket([]byte("fw:blocked")).ForEach(func(_, v []byte) error {
			var e FWBlockedEntry
			if err := json.Unmarshal(v, &e); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if e.ExpiresAt.IsZero() || e.ExpiresAt.After(now) {
				st.Blocked = append(st.Blocked, e)
			}
			return nil
		})
		// fw:allowed
		_ = tx.Bucket([]byte("fw:allowed")).ForEach(func(_, v []byte) error {
			var e FWAllowedEntry
			if err := json.Unmarshal(v, &e); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			st.Allowed = append(st.Allowed, e)
			return nil
		})
		// fw:subnets
		_ = tx.Bucket([]byte("fw:subnets")).ForEach(func(_, v []byte) error {
			var e FWSubnetEntry
			if err := json.Unmarshal(v, &e); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			st.Subnets = append(st.Subnets, e)
			return nil
		})
		// fw:port_allowed
		_ = tx.Bucket([]byte("fw:port_allowed")).ForEach(func(_, v []byte) error {
			var e FWPortAllowEntry
			if err := json.Unmarshal(v, &e); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			st.PortAllowed = append(st.PortAllowed, e)
			return nil
		})
		return nil
	})
	return st
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// AuditResult represents a single hardening check result.
type AuditResult struct {
	Category string `json:"category"`      // grouping for the check (defined by the audit runner)
	Name     string `json:"name"`          // machine-readable check identifier
	Title    string `json:"title"`         // human-readable check title
	Status   string `json:"status"`        // outcome string (value set defined by the audit runner)
	Message  string `json:"message"`       // explanation of the outcome
	Fix      string `json:"fix,omitempty"` // suggested remediation, if any
}
// AuditReport is the full result of a hardening audit run, persisted as a
// single JSON document under the meta bucket (see SaveHardeningReport).
type AuditReport struct {
	Timestamp  time.Time     `json:"timestamp"`   // when the audit ran
	ServerType string        `json:"server_type"` // detected server role (semantics defined by the auditor)
	Results    []AuditResult `json:"results"`     // individual check outcomes
	Score      int           `json:"score"`       // passing-check tally (scoring defined by the auditor)
	Total      int           `json:"total"`       // total number of checks
}
// hardeningReportKey is the meta-bucket key under which the latest
// hardening audit report is stored.
const hardeningReportKey = "hardening:report"
// SaveHardeningReport persists the latest audit report in the meta bucket,
// replacing any previous report.
func (db *DB) SaveHardeningReport(report *AuditReport) error {
	payload, err := json.Marshal(report)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte(hardeningReportKey), payload)
	})
}
// LoadHardeningReport retrieves the latest audit report from the meta bucket.
// Returns a zero-value report (nil Results) if no report has been saved yet;
// a non-nil error indicates the stored JSON failed to decode.
func (db *DB) LoadHardeningReport() (*AuditReport, error) {
	report := new(AuditReport)
	err := db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(hardeningReportKey))
		if raw == nil {
			return nil
		}
		return json.Unmarshal(raw, report)
	})
	return report, err
}
package store
import (
"encoding/json"
"fmt"
"strings"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// maxHistoryEntries is the maximum number of history entries to retain in
// the history bucket; AppendHistory prunes the oldest entries past it.
// It is a var (not const) so tests can override it.
var maxHistoryEntries = 100_000
// AppendHistory inserts findings into the history bucket with TimeKey keys.
// It increments the history:count counter and prunes oldest entries if the
// count exceeds maxHistoryEntries.
//
// NOTE(review): the slice index i is the TimeKey tiebreaker, and it restarts
// at 0 on every call — two calls whose findings share a timestamp would
// collide on keys and overwrite, drifting the counter. Presumably timestamps
// differ across batches in practice; confirm against callers.
func (db *DB) AppendHistory(findings []alert.Finding) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("history"))
		for i, f := range findings {
			key := TimeKey(f.Timestamp, i)
			val, err := json.Marshal(f)
			if err != nil {
				return err
			}
			if err := b.Put([]byte(key), val); err != nil {
				return err
			}
		}
		if err := incrCounter(tx, "history:count", len(findings)); err != nil {
			return err
		}
		// Prune oldest entries if count exceeds maxHistoryEntries.
		// Re-read the counter from meta since incrCounter just updated it.
		meta := tx.Bucket([]byte("meta"))
		var count int
		if v := meta.Get([]byte("history:count")); v != nil {
			_, _ = fmt.Sscanf(string(v), "%d", &count)
		}
		if count > maxHistoryEntries {
			excess := count - maxHistoryEntries
			c := b.Cursor()
			k, _ := c.First()
			for ; k != nil && excess > 0; excess-- {
				// Delete() moves the cursor to the next item, so we
				// must NOT call c.Next() after it; re-seek First()
				// each round instead.
				if err := c.Delete(); err != nil {
					return err
				}
				k, _ = c.First()
			}
			if err := setCounter(tx, "history:count", maxHistoryEntries); err != nil {
				return err
			}
		}
		return nil
	})
}
// ReadHistory reads findings from the history bucket, newest-first.
// It returns up to limit findings starting at offset, plus the total count
// taken from the history:count counter.
func (db *DB) ReadHistory(limit, offset int) ([]alert.Finding, int) {
	total := db.getCounter("history:count")
	var findings []alert.Finding
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		cur := tx.Bucket([]byte("history")).Cursor()
		k, v := cur.Last()
		// Walk backward past the first `offset` entries.
		for skip := 0; k != nil && skip < offset; skip++ {
			k, v = cur.Prev()
		}
		// Keep walking backward, collecting up to limit findings.
		for ; k != nil && len(findings) < limit; k, v = cur.Prev() {
			var f alert.Finding
			if json.Unmarshal(v, &f) == nil {
				findings = append(findings, f)
			}
		}
		return nil
	})
	return findings, total
}
// ReadHistoryFiltered reads findings with optional filtering.
// Parameters:
// - from, to: date strings "YYYY-MM-DD" for time-range filtering (empty to skip)
// - severity: filter by severity level (-1 for no filter)
// - search: case-insensitive substring match on check/message/details (empty to skip)
// Returns the page of findings (newest-first) and the total number of
// findings matching the filters (not just the page).
func (db *DB) ReadHistoryFiltered(limit, offset int, from, to string, severity int, search string) ([]alert.Finding, int) {
	var results []alert.Finding
	matched := 0
	searchLower := strings.ToLower(search)
	var fromPrefix, toPrefix string
	if from != "" {
		fromPrefix = ParseTimeKeyPrefix(from)
	}
	if to != "" {
		// toPrefix needs to match the entire day, so we use the next day's prefix
		// by appending a high character to ensure all entries on that day are included.
		toPrefix = ParseTimeKeyPrefix(to) + "99" // "YYYYMMDD99" is > any time on that day
	}
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("history"))
		c := b.Cursor()
		// Start from the end (newest) and iterate backward.
		for k, v := c.Last(); k != nil; k, v = c.Prev() {
			key := string(k)
			// Time-range: if key is above toPrefix, skip it.
			// (Lexicographic compare: "YYYYMMDD99" beats any 28-char key
			// on that day because '9' > any hour's first digit.)
			if toPrefix != "" && key > toPrefix {
				continue
			}
			// Time-range: if key is below fromPrefix, all remaining are older - stop.
			if fromPrefix != "" && key < fromPrefix {
				break
			}
			var f alert.Finding
			if err := json.Unmarshal(v, &f); err != nil {
				continue
			}
			// Severity filter.
			if severity >= 0 && int(f.Severity) != severity {
				continue
			}
			// Search filter.
			if search != "" && !containsLower(f.Check, searchLower) &&
				!containsLower(f.Message, searchLower) &&
				!containsLower(f.Details, searchLower) {
				continue
			}
			// Count every match; only keep the requested page window.
			matched++
			if matched > offset && len(results) < limit {
				results = append(results, f)
			}
		}
		return nil
	})
	return results, matched
}
// containsLower reports whether s contains substr, comparing
// case-insensitively. substr must already be lowercase.
func containsLower(s, substr string) bool {
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, substr)
}
package store
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
bolt "go.etcd.io/bbolt"
)
// All os.Open / os.ReadFile calls below take paths derived from the
// operator-configured statePath (root-owned /opt/csm or /var/lib/csm
// by default) joined with fixed filenames. gosec G304 suppressions on
// each line refer back to this package-level trust model.

// runMigration imports all legacy flat-file state into bbolt. Each
// sub-migration is attempted even if an earlier one fails; failures are
// aggregated into a single "partial migration" error. The "migrated"
// sentinel is only written when every step succeeded, so a partial run
// is retried on the next Open.
func (db *DB) runMigration(statePath string) error {
	fmt.Fprintf(os.Stderr, "store: migrating flat files to bbolt...\n")
	var errs []string
	if err := db.migrateHistory(statePath); err != nil {
		errs = append(errs, fmt.Sprintf("history: %v", err))
	}
	if err := db.migrateAttackDB(statePath); err != nil {
		errs = append(errs, fmt.Sprintf("attackdb: %v", err))
	}
	if err := db.migrateThreatDB(statePath); err != nil {
		errs = append(errs, fmt.Sprintf("threatdb: %v", err))
	}
	if err := db.migrateFirewall(statePath); err != nil {
		errs = append(errs, fmt.Sprintf("firewall: %v", err))
	}
	if err := db.migrateReputation(statePath); err != nil {
		errs = append(errs, fmt.Sprintf("reputation: %v", err))
	}
	if len(errs) > 0 {
		return fmt.Errorf("partial migration: %s", strings.Join(errs, "; "))
	}
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte("migrated"), []byte(time.Now().Format(time.RFC3339)))
	})
	fmt.Fprintf(os.Stderr, "store: migration complete\n")
	return nil
}
// migrateHistory imports the legacy history.jsonl (one JSON finding per
// line) into the history bucket, then renames the file to *.bak.
// Missing file is not an error. Malformed lines are skipped.
func (db *DB) migrateHistory(statePath string) error {
	path := filepath.Join(statePath, "history.jsonl")
	if _, err := os.Stat(path); err != nil {
		return nil
	}
	// #nosec G304 -- see runMigration trust note.
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	var findings []alert.Finding
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
	for scanner.Scan() {
		var finding alert.Finding
		if err := json.Unmarshal(scanner.Bytes(), &finding); err != nil {
			continue
		}
		findings = append(findings, finding)
	}
	// Surface read failures (e.g. a line exceeding the 1 MiB buffer ->
	// bufio.ErrTooLong) instead of silently truncating history; the source
	// file is left in place so the migration can be retried.
	if err := scanner.Err(); err != nil {
		return err
	}
	if len(findings) > 0 {
		if err := db.AppendHistory(findings); err != nil {
			return err
		}
	}
	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated %d history entries\n", len(findings))
	return nil
}
// migrateAttackDB imports the legacy attack_db directory: records.json
// (map of IP -> record) and events.jsonl (one JSON event per line). Each
// migrated file is renamed to *.bak. Missing files are not errors;
// malformed event lines are skipped.
func (db *DB) migrateAttackDB(statePath string) error {
	dbDir := filepath.Join(statePath, "attack_db")
	// records.json
	recordsPath := filepath.Join(dbDir, "records.json")
	if _, err := os.Stat(recordsPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		data, err := os.ReadFile(recordsPath)
		if err != nil {
			return err
		}
		var records map[string]*IPRecord
		if err := json.Unmarshal(data, &records); err != nil {
			return err
		}
		for _, r := range records {
			if err := db.SaveIPRecord(*r); err != nil {
				return err
			}
		}
		renameToBackup(recordsPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d attack records\n", len(records))
	}
	// events.jsonl
	eventsPath := filepath.Join(dbDir, "events.jsonl")
	if _, err := os.Stat(eventsPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		f, err := os.Open(eventsPath)
		if err != nil {
			return err
		}
		defer f.Close()
		counter := 0
		scanner := bufio.NewScanner(f)
		scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
		for scanner.Scan() {
			var event AttackEvent
			if err := json.Unmarshal(scanner.Bytes(), &event); err != nil {
				continue
			}
			if err := db.RecordAttackEvent(event, counter); err != nil {
				return err
			}
			counter++
		}
		// Surface read failures (e.g. bufio.ErrTooLong) instead of silently
		// truncating the event log; the file stays in place for a retry.
		if err := scanner.Err(); err != nil {
			return err
		}
		renameToBackup(eventsPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d attack events\n", counter)
	}
	return nil
}
// migrateThreatDB imports the legacy threat_db directory: permanent.txt
// (one "IP [# reason]" per line) and whitelist.txt (one
// "IP [expires=RFC3339|permanent]" per line). Each migrated file is renamed
// to *.bak; missing files and comment/blank lines are skipped.
func (db *DB) migrateThreatDB(statePath string) error {
	dbDir := filepath.Join(statePath, "threat_db")
	// permanent.txt
	permPath := filepath.Join(dbDir, "permanent.txt")
	if _, err := os.Stat(permPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		data, err := os.ReadFile(permPath)
		if err != nil {
			return err
		}
		count := 0
		for _, line := range strings.Split(string(data), "\n") {
			line = strings.TrimSpace(line)
			if line == "" || strings.HasPrefix(line, "#") {
				continue
			}
			// Optional trailing " # reason" comment on the same line.
			parts := strings.SplitN(line, " # ", 2)
			ip := strings.TrimSpace(parts[0])
			reason := ""
			if len(parts) > 1 {
				reason = strings.TrimSpace(parts[1])
			}
			if ip != "" {
				if err := db.AddPermanentBlock(ip, reason); err != nil {
					return err
				}
				count++
			}
		}
		renameToBackup(permPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d permanent blocks\n", count)
	}
	// whitelist.txt
	wlPath := filepath.Join(dbDir, "whitelist.txt")
	if _, err := os.Stat(wlPath); err == nil {
		// #nosec G304 -- see runMigration trust note.
		data, err := os.ReadFile(wlPath)
		if err != nil {
			return err
		}
		count := 0
		for _, line := range strings.Split(string(data), "\n") {
			line = strings.TrimSpace(line)
			if line == "" || strings.HasPrefix(line, "#") {
				continue
			}
			// line is non-empty here, so Fields returns at least one token.
			parts := strings.Fields(line)
			ip := parts[0]
			// Default: permanent; an "expires=" token switches to timed
			// unless a later "permanent" token overrides it back.
			permanent := true
			var expiresAt time.Time
			for _, p := range parts[1:] {
				if strings.HasPrefix(p, "expires=") {
					t, err := time.Parse(time.RFC3339, strings.TrimPrefix(p, "expires="))
					if err == nil {
						expiresAt = t
						permanent = false
					}
				}
				if p == "permanent" {
					permanent = true
				}
			}
			if err := db.AddWhitelistEntry(ip, expiresAt, permanent); err != nil {
				return err
			}
			count++
		}
		renameToBackup(wlPath)
		fmt.Fprintf(os.Stderr, "store: migrated %d whitelist entries\n", count)
	}
	return nil
}
// migrateFirewall imports the legacy firewall/state.json document into the
// four fw:* buckets via the regular BlockIP/AllowIP/AddSubnet/AddPortAllow
// helpers (which also re-infer provenance), then renames the file to *.bak.
// A missing file is not an error. The raw* types mirror the legacy JSON
// schema rather than the current store-layer structs.
func (db *DB) migrateFirewall(statePath string) error {
	path := filepath.Join(statePath, "firewall", "state.json")
	if _, serr := os.Stat(path); serr != nil {
		return nil //nolint:nilerr // file does not exist, skip migration
	}
	// #nosec G304 -- see runMigration trust note.
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	type rawBlocked struct {
		IP        string    `json:"ip"`
		Reason    string    `json:"reason"`
		BlockedAt time.Time `json:"blocked_at"`
		ExpiresAt time.Time `json:"expires_at"`
	}
	type rawAllowed struct {
		IP        string    `json:"ip"`
		Reason    string    `json:"reason"`
		Port      int       `json:"port,omitempty"`
		ExpiresAt time.Time `json:"expires_at,omitempty"`
	}
	type rawSubnet struct {
		CIDR      string    `json:"cidr"`
		Reason    string    `json:"reason"`
		BlockedAt time.Time `json:"blocked_at"`
	}
	type rawPortAllow struct {
		IP     string `json:"ip"`
		Port   int    `json:"port"`
		Proto  string `json:"proto"`
		Reason string `json:"reason"`
	}
	type rawState struct {
		Blocked     []rawBlocked   `json:"blocked"`
		BlockedNet  []rawSubnet    `json:"blocked_nets"`
		Allowed     []rawAllowed   `json:"allowed"`
		PortAllowed []rawPortAllow `json:"port_allowed"`
	}
	var state rawState
	if err := json.Unmarshal(data, &state); err != nil {
		return err
	}
	// Replay each legacy entry through the normal write paths. Note the
	// legacy BlockedAt/Port fields are not carried over by these helpers.
	for _, b := range state.Blocked {
		if err := db.BlockIP(b.IP, b.Reason, b.ExpiresAt); err != nil {
			return err
		}
	}
	for _, a := range state.Allowed {
		if err := db.AllowIP(a.IP, a.Reason, a.ExpiresAt); err != nil {
			return err
		}
	}
	for _, s := range state.BlockedNet {
		if err := db.AddSubnet(s.CIDR, s.Reason); err != nil {
			return err
		}
	}
	for _, p := range state.PortAllowed {
		if err := db.AddPortAllow(p.IP, p.Port, p.Proto, p.Reason); err != nil {
			return err
		}
	}
	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated firewall state (%d blocked, %d allowed, %d subnets, %d port allows)\n",
		len(state.Blocked), len(state.Allowed), len(state.BlockedNet), len(state.PortAllowed))
	return nil
}
// migrateReputation imports the legacy reputation_cache.json under statePath
// into the bbolt store, then renames the file to *.bak. A missing file is a
// no-op; a summary line is written to stderr on success.
func (db *DB) migrateReputation(statePath string) error {
	path := filepath.Join(statePath, "reputation_cache.json")
	if _, serr := os.Stat(path); serr != nil {
		return nil //nolint:nilerr // file does not exist, skip migration
	}
	// #nosec G304 -- see runMigration trust note.
	data, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	// Mirror of the legacy cache file layout.
	type legacyCache struct {
		Entries map[string]*ReputationEntry `json:"entries"`
	}
	var cache legacyCache
	if err := json.Unmarshal(data, &cache); err != nil {
		return err
	}
	migrated := 0
	for ip, entry := range cache.Entries {
		if err := db.SetReputation(ip, *entry); err != nil {
			return err
		}
		migrated++
	}
	renameToBackup(path)
	fmt.Fprintf(os.Stderr, "store: migrated %d reputation entries\n", migrated)
	return nil
}
func renameToBackup(path string) {
_ = os.Rename(path, path+".bak")
}
package store
import (
"encoding/json"
"fmt"
"strconv"
"time"
bolt "go.etcd.io/bbolt"
)
const modsecNoEscalateKey = "modsec:no_escalate_rules"
// GetModSecNoEscalateRules returns the set of ModSecurity rule IDs that should
// NOT escalate to nftables firewall blocks. Stored in the meta bucket.
// Missing or corrupt data yields an empty (non-nil) set.
func (db *DB) GetModSecNoEscalateRules() map[int]bool {
	set := make(map[int]bool)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte(modsecNoEscalateKey))
		if raw == nil {
			return nil
		}
		var ids []int
		if err := json.Unmarshal(raw, &ids); err != nil {
			return nil //nolint:nilerr // skip corrupt data
		}
		for _, id := range ids {
			set[id] = true
		}
		return nil
	})
	return set
}
// SetModSecNoEscalateRules stores the set of rule IDs that should not escalate.
// The map keys are flattened to a JSON integer list in the meta bucket.
func (db *DB) SetModSecNoEscalateRules(rules map[int]bool) error {
	var ids []int
	for id := range rules {
		ids = append(ids, id)
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		encoded, err := json.Marshal(ids)
		if err != nil {
			return err
		}
		return tx.Bucket([]byte("meta")).Put([]byte(modsecNoEscalateKey), encoded)
	})
}
// AddModSecNoEscalateRule atomically adds a single rule ID to the no-escalate
// set. The read-modify-write runs inside one bbolt Update transaction to
// prevent races; adding an already-present ID is a no-op.
func (db *DB) AddModSecNoEscalateRule(ruleID int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("meta"))
		var ids []int
		if raw := bucket.Get([]byte(modsecNoEscalateKey)); raw != nil {
			_ = json.Unmarshal(raw, &ids)
		}
		for _, existing := range ids {
			if existing == ruleID {
				return nil // already present
			}
		}
		encoded, err := json.Marshal(append(ids, ruleID))
		if err != nil {
			return err
		}
		return bucket.Put([]byte(modsecNoEscalateKey), encoded)
	})
}
// RemoveModSecNoEscalateRule atomically removes a single rule ID from the
// no-escalate set within one bbolt Update transaction. Removing an absent
// ID simply rewrites the set unchanged.
func (db *DB) RemoveModSecNoEscalateRule(ruleID int) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("meta"))
		var ids []int
		if raw := bucket.Get([]byte(modsecNoEscalateKey)); raw != nil {
			_ = json.Unmarshal(raw, &ids)
		}
		var kept []int
		for _, id := range ids {
			if id != ruleID {
				kept = append(kept, id)
			}
		}
		encoded, err := json.Marshal(kept)
		if err != nil {
			return err
		}
		return bucket.Put([]byte(modsecNoEscalateKey), encoded)
	})
}
// RuleHitStats holds hit count and last-hit time for a ModSecurity rule.
type RuleHitStats struct {
	Hits    int       `json:"hits"`     // total hits inside the reporting window
	LastHit time.Time `json:"last_hit"` // timestamp of the most recent hit
}
// ruleHitData is the persisted per-rule hit record: hourly counters plus
// the most recent hit time. One record per rule, keyed by modsecHitKey.
type ruleHitData struct {
	Buckets map[string]int `json:"buckets"` // key: "YYYYMMDDHH" → count
	LastHit time.Time      `json:"last_hit"`
}
// modsecHitKey builds the meta-bucket key holding hit counters for a rule.
func modsecHitKey(ruleID int) string {
	return "modsec:hits:" + strconv.Itoa(ruleID)
}
func hourBucket(t time.Time) string {
return fmt.Sprintf("%04d%02d%02d%02d", t.Year(), t.Month(), t.Day(), t.Hour())
}
// IncrModSecRuleHit increments the hit counter for a rule ID in the current
// hour bucket and records the hit time. Best effort: storage errors are
// deliberately ignored so accounting never blocks request handling.
func (db *DB) IncrModSecRuleHit(ruleID int, timestamp time.Time) {
	key := modsecHitKey(ruleID)
	bucket := hourBucket(timestamp)
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		var data ruleHitData
		if v := b.Get([]byte(key)); v != nil {
			if json.Unmarshal(v, &data) != nil {
				data = ruleHitData{Buckets: make(map[string]int)}
			}
		}
		if data.Buckets == nil {
			// Covers both a missing key and a stored record whose "buckets"
			// field is null/absent: incrementing a nil map entry would panic.
			data.Buckets = make(map[string]int)
		}
		data.Buckets[bucket]++
		data.LastHit = timestamp
		val, err := json.Marshal(data)
		if err != nil {
			return err
		}
		return b.Put([]byte(key), val)
	})
}
// GetModSecRuleHits returns hit counts and last-hit timestamps for all rules
// within the last 24 hours, pruning buckets older than 24h from storage.
// Note: hourly bucket granularity means the window is 24h +/- 1h at boundaries.
func (db *DB) GetModSecRuleHits() map[int]RuleHitStats {
	result := make(map[int]RuleHitStats)
	cutoff := time.Now().Add(-24 * time.Hour)
	cutoffBucket := hourBucket(cutoff)
	prefix := []byte("modsec:hits:")
	// Read with View first (no write lock), remembering only the KEYS that
	// carry stale buckets. Pruning re-reads each record inside the later
	// Update transaction so a concurrent IncrModSecRuleHit landing between
	// the two transactions is not clobbered by the stale snapshot.
	var staleKeys [][]byte
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		b := tx.Bucket([]byte("meta"))
		c := b.Cursor()
		for k, v := c.Seek(prefix); k != nil && len(k) >= len(prefix) && string(k[:len(prefix)]) == string(prefix); k, v = c.Next() {
			ruleID, err := strconv.Atoi(string(k[len(prefix):]))
			if err != nil {
				continue // non-numeric suffix under the prefix; skip
			}
			var data ruleHitData
			if json.Unmarshal(v, &data) != nil {
				continue // corrupt record; skip
			}
			total := 0
			needsPrune := false
			for bk, count := range data.Buckets {
				if bk >= cutoffBucket {
					total += count
				} else {
					needsPrune = true
				}
			}
			if needsPrune {
				// Deep copy key (bbolt keys are only valid inside the tx).
				keyCopy := make([]byte, len(k))
				copy(keyCopy, k)
				staleKeys = append(staleKeys, keyCopy)
			}
			if total > 0 || !data.LastHit.IsZero() {
				result[ruleID] = RuleHitStats{
					Hits:    total,
					LastHit: data.LastHit,
				}
			}
		}
		return nil
	})
	// Prune old buckets in a separate write transaction (only if needed),
	// re-reading each record fresh so concurrent increments are preserved.
	if len(staleKeys) > 0 {
		_ = db.bolt.Update(func(tx *bolt.Tx) error {
			b := tx.Bucket([]byte("meta"))
			for _, key := range staleKeys {
				v := b.Get(key)
				if v == nil {
					continue // deleted since the View; nothing to prune
				}
				var data ruleHitData
				if json.Unmarshal(v, &data) != nil {
					continue // corrupt record; leave as-is
				}
				changed := false
				for bk := range data.Buckets {
					if bk < cutoffBucket {
						delete(data.Buckets, bk)
						changed = true
					}
				}
				if !changed {
					continue // already pruned by a concurrent caller
				}
				val, err := json.Marshal(data)
				if err != nil {
					continue
				}
				_ = b.Put(key, val)
			}
			return nil
		})
	}
	return result
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// PluginInfo holds cached metadata for a WordPress plugin from the API.
type PluginInfo struct {
	LatestVersion string `json:"latest_version"`
	TestedUpTo    string `json:"tested_up_to"`
	LastChecked   int64  `json:"last_checked_unix"` // Unix seconds of the last API refresh
}
// SitePluginEntry describes a single plugin installed on a WordPress site.
type SitePluginEntry struct {
	Slug             string `json:"slug"`
	Name             string `json:"name"`
	Status           string `json:"status"` // presumably active/inactive — confirm against producer
	InstalledVersion string `json:"installed_version"`
	UpdateVersion    string `json:"update_version"` // version offered as an update, empty when current
}
// SitePlugins holds the full plugin inventory for a WordPress installation.
type SitePlugins struct {
	Account string            `json:"account"`
	Domain  string            `json:"domain"`
	Plugins []SitePluginEntry `json:"plugins"`
}
// SetPluginInfo stores plugin metadata keyed by slug in the plugins bucket.
func (db *DB) SetPluginInfo(slug string, info PluginInfo) error {
	// Marshal outside the transaction: it has no dependency on tx state.
	encoded, err := json.Marshal(info)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("plugins")).Put([]byte(slug), encoded)
	})
}
// GetPluginInfo retrieves plugin metadata for the given slug.
// Returns the entry and true if found, or a zero value and false if not.
func (db *DB) GetPluginInfo(slug string) (PluginInfo, bool) {
	var (
		info  PluginInfo
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("plugins")).Get([]byte(slug))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &info); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return info, found
}
// SetSitePlugins stores the plugin inventory for a WordPress installation
// keyed by its filesystem path in the plugins:sites bucket.
func (db *DB) SetSitePlugins(wpPath string, site SitePlugins) error {
	// Marshal outside the transaction: it has no dependency on tx state.
	encoded, err := json.Marshal(site)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("plugins:sites")).Put([]byte(wpPath), encoded)
	})
}
// GetSitePlugins retrieves the plugin inventory for a WordPress installation.
// Returns the entry and true if found, or a zero value and false if not.
func (db *DB) GetSitePlugins(wpPath string) (SitePlugins, bool) {
	var (
		site  SitePlugins
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("plugins:sites")).Get([]byte(wpPath))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &site); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return site, found
}
// DeleteSitePlugins removes the plugin inventory for a WordPress installation.
func (db *DB) DeleteSitePlugins(wpPath string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("plugins:sites")).Delete([]byte(wpPath))
	})
}
// AllSitePlugins returns all site plugin inventories keyed by WordPress path.
// Corrupt entries are silently skipped.
func (db *DB) AllSitePlugins() map[string]SitePlugins {
	out := make(map[string]SitePlugins)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("plugins:sites"))
		return bucket.ForEach(func(k, v []byte) error {
			var site SitePlugins
			if err := json.Unmarshal(v, &site); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out[string(k)] = site
			return nil
		})
	})
	return out
}
// GetPluginRefreshTime reads the last plugin refresh timestamp from the meta
// bucket. Returns the zero time if unset or unparseable.
func (db *DB) GetPluginRefreshTime() time.Time {
	var refreshed time.Time
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("meta")).Get([]byte("plugins:last_refresh"))
		if raw == nil {
			return nil
		}
		// A corrupt value is treated the same as "never refreshed".
		if parsed, err := time.Parse(time.RFC3339, string(raw)); err == nil {
			refreshed = parsed
		}
		return nil
	})
	return refreshed
}
// SetPluginRefreshTime writes the plugin refresh timestamp (RFC3339) to the
// meta bucket.
func (db *DB) SetPluginRefreshTime(t time.Time) error {
	encoded := []byte(t.Format(time.RFC3339))
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("meta")).Put([]byte("plugins:last_refresh"), encoded)
	})
}
package store
import (
"encoding/json"
"sort"
"time"
bolt "go.etcd.io/bbolt"
)
// ReputationEntry holds the cached reputation data for an IP address.
type ReputationEntry struct {
	Score     int       `json:"score"` // negative score is treated as an error sentinel by consumers
	Category  string    `json:"category"`
	CheckedAt time.Time `json:"checked_at"` // fetch time; used for expiry/cap decisions
}
// SetReputation stores a reputation entry for the given IP.
func (db *DB) SetReputation(ip string, entry ReputationEntry) error {
	// Marshal outside the transaction: it has no dependency on tx state.
	encoded, err := json.Marshal(entry)
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("reputation")).Put([]byte(ip), encoded)
	})
}
// GetReputation retrieves a reputation entry for the given IP.
// Returns the entry and true if found, or a zero value and false if not.
func (db *DB) GetReputation(ip string) (ReputationEntry, bool) {
	var (
		entry ReputationEntry
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("reputation")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &entry); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return entry, found
}
// CleanExpiredReputation deletes entries older than maxAge.
// Uses a collect-then-delete pattern because bbolt does not allow mutation
// during ForEach iteration. Returns the count of entries removed.
func (db *DB) CleanExpiredReputation(maxAge time.Duration) int {
	cutoff := time.Now().Add(-maxAge)
	removed := 0
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		// Phase 1: collect keys of expired entries (keys copied out of the tx view).
		var expired [][]byte
		_ = bucket.ForEach(func(k, v []byte) error {
			var entry ReputationEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if entry.CheckedAt.Before(cutoff) {
				expired = append(expired, append([]byte(nil), k...))
			}
			return nil
		})
		// Phase 2: delete them.
		for _, key := range expired {
			if err := bucket.Delete(key); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	return removed
}
// AllReputation returns all reputation entries keyed by IP.
// Corrupt entries are silently skipped.
func (db *DB) AllReputation() map[string]ReputationEntry {
	out := make(map[string]ReputationEntry)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("reputation")).ForEach(func(k, v []byte) error {
			var entry ReputationEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out[string(k)] = entry
			return nil
		})
	})
	return out
}
// EnforceReputationCap ensures the reputation bucket has at most max entries.
// If the count exceeds max, the oldest entries (by CheckedAt) are deleted.
// Returns the count of entries removed.
func (db *DB) EnforceReputationCap(max int) int {
	removed := 0
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("reputation"))
		// Snapshot every key together with its CheckedAt timestamp.
		type candidate struct {
			key       []byte
			checkedAt time.Time
		}
		var snapshot []candidate
		_ = bucket.ForEach(func(k, v []byte) error {
			var entry ReputationEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			snapshot = append(snapshot, candidate{
				key:       append([]byte(nil), k...), // keys are only valid inside the tx callback
				checkedAt: entry.CheckedAt,
			})
			return nil
		})
		excess := len(snapshot) - max
		if excess <= 0 {
			return nil
		}
		// Oldest first, then drop exactly the overflow.
		sort.Slice(snapshot, func(i, j int) bool {
			return snapshot[i].checkedAt.Before(snapshot[j].checkedAt)
		})
		for _, victim := range snapshot[:excess] {
			if err := bucket.Delete(victim.key); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	return removed
}
package store
import (
"encoding/json"
"time"
bolt "go.etcd.io/bbolt"
)
// PermanentBlockEntry represents an IP permanently blocked by the threat system.
type PermanentBlockEntry struct {
	IP        string    `json:"ip"`
	Reason    string    `json:"reason"`
	BlockedAt time.Time `json:"blocked_at"` // set to time.Now() when the block is recorded
}
// WhitelistEntry represents an IP that should bypass threat checks.
type WhitelistEntry struct {
	IP        string    `json:"ip"`
	ExpiresAt time.Time `json:"expires_at"` // ignored when Permanent is true
	Permanent bool      `json:"permanent"`  // permanent entries never expire
}
// AddPermanentBlock adds an IP to the permanent block list.
// Only increments threats:count if the key is new; re-adding an existing IP
// refreshes its reason and BlockedAt without changing the counter.
func (db *DB) AddPermanentBlock(ip, reason string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("threats"))
		key := []byte(ip)
		existed := bucket.Get(key) != nil
		encoded, err := json.Marshal(PermanentBlockEntry{
			IP:        ip,
			Reason:    reason,
			BlockedAt: time.Now(),
		})
		if err != nil {
			return err
		}
		if err := bucket.Put(key, encoded); err != nil {
			return err
		}
		if existed {
			return nil
		}
		return incrCounter(tx, "threats:count", 1)
	})
}
// RemovePermanentBlock removes an IP from the permanent block list and
// decrements the count. Removing an absent IP is a no-op.
func (db *DB) RemovePermanentBlock(ip string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("threats"))
		key := []byte(ip)
		if bucket.Get(key) == nil {
			return nil // not present; keep the counter untouched
		}
		if err := bucket.Delete(key); err != nil {
			return err
		}
		return incrCounter(tx, "threats:count", -1)
	})
}
// GetPermanentBlock looks up a permanent block entry by IP.
// Returns the entry and true if found, or a zero value and false if not.
func (db *DB) GetPermanentBlock(ip string) (PermanentBlockEntry, bool) {
	var (
		entry PermanentBlockEntry
		found bool
	)
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("threats")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		if err := json.Unmarshal(raw, &entry); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		found = true
		return nil
	})
	return entry, found
}
// AllPermanentBlocks returns all entries in the permanent block list.
// Corrupt entries are silently skipped.
func (db *DB) AllPermanentBlocks() []PermanentBlockEntry {
	var out []PermanentBlockEntry
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("threats")).ForEach(func(_, v []byte) error {
			var entry PermanentBlockEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out = append(out, entry)
			return nil
		})
	})
	return out
}
// AddWhitelistEntry adds an IP to the whitelist.
func (db *DB) AddWhitelistEntry(ip string, expiresAt time.Time, permanent bool) error {
	// Marshal outside the transaction: it has no dependency on tx state.
	encoded, err := json.Marshal(WhitelistEntry{
		IP:        ip,
		ExpiresAt: expiresAt,
		Permanent: permanent,
	})
	if err != nil {
		return err
	}
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("threats:whitelist")).Put([]byte(ip), encoded)
	})
}
// RemoveWhitelistEntry removes an IP from the whitelist.
func (db *DB) RemoveWhitelistEntry(ip string) error {
	return db.bolt.Update(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("threats:whitelist")).Delete([]byte(ip))
	})
}
// IsWhitelisted checks if an IP is whitelisted and not expired.
// Corrupt entries are treated as not whitelisted.
func (db *DB) IsWhitelisted(ip string) bool {
	whitelisted := false
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		raw := tx.Bucket([]byte("threats:whitelist")).Get([]byte(ip))
		if raw == nil {
			return nil
		}
		var entry WhitelistEntry
		if err := json.Unmarshal(raw, &entry); err != nil {
			return nil //nolint:nilerr // skip corrupt entry
		}
		// Permanent entries never expire; timed entries must still be in the future.
		whitelisted = entry.Permanent || entry.ExpiresAt.After(time.Now())
		return nil
	})
	return whitelisted
}
// ListWhitelist returns all whitelist entries (including expired - caller filters).
// Corrupt entries are silently skipped.
func (db *DB) ListWhitelist() []WhitelistEntry {
	var out []WhitelistEntry
	_ = db.bolt.View(func(tx *bolt.Tx) error {
		return tx.Bucket([]byte("threats:whitelist")).ForEach(func(_, v []byte) error {
			var entry WhitelistEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			out = append(out, entry)
			return nil
		})
	})
	return out
}
// PruneExpiredWhitelist deletes expired non-permanent whitelist entries.
// Returns the count of entries removed. Uses a collect-then-delete pattern
// because bbolt does not allow mutation during ForEach iteration.
func (db *DB) PruneExpiredWhitelist() int {
	now := time.Now()
	removed := 0
	_ = db.bolt.Update(func(tx *bolt.Tx) error {
		bucket := tx.Bucket([]byte("threats:whitelist"))
		// Phase 1: collect keys of non-permanent entries already at/past expiry.
		var expired [][]byte
		_ = bucket.ForEach(func(k, v []byte) error {
			var entry WhitelistEntry
			if err := json.Unmarshal(v, &entry); err != nil {
				return nil //nolint:nilerr // skip corrupt entry
			}
			if !entry.Permanent && !entry.ExpiresAt.After(now) {
				expired = append(expired, append([]byte(nil), k...))
			}
			return nil
		})
		// Phase 2: delete them.
		for _, key := range expired {
			if err := bucket.Delete(key); err != nil {
				return err
			}
			removed++
		}
		return nil
	})
	return removed
}
package threat
import (
"encoding/json"
"os"
"path/filepath"
"time"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/store"
)
// IPIntelligence is the complete picture of an IP from all sources.
type IPIntelligence struct {
	IP string `json:"ip"`
	// Internal attack history
	AttackRecord *attackdb.IPRecord `json:"attack_record,omitempty"`
	LocalScore   int                `json:"local_score"`
	// ThreatDB (feeds + permanent blocklist)
	InThreatDB     bool   `json:"in_threat_db"`
	ThreatDBSource string `json:"threat_db_source,omitempty"`
	// AbuseIPDB cache; AbuseScore is -1 when the IP has no cached score
	AbuseScore    int    `json:"abuse_score"`
	AbuseCategory string `json:"abuse_category,omitempty"`
	// Firewall state
	CurrentlyBlocked bool       `json:"currently_blocked"`
	BlockReason      string     `json:"block_reason,omitempty"`
	BlockedAt        *time.Time `json:"blocked_at,omitempty"`
	BlockExpiresAt   *time.Time `json:"block_expires_at,omitempty"`
	BlockPermanent   bool       `json:"block_permanent,omitempty"`
	// GeoIP (populated by API layer, not by Lookup)
	Country     string `json:"country,omitempty"`
	CountryName string `json:"country_name,omitempty"`
	City        string `json:"city,omitempty"`
	ASN         uint   `json:"asn,omitempty"`
	ASOrg       string `json:"as_org,omitempty"`
	Network     string `json:"network,omitempty"`
	// Composite
	UnifiedScore int    `json:"unified_score"` // max of local/abuse scores; forced to 100 when InThreatDB
	Verdict      string `json:"verdict"`       // "clean", "suspicious", "malicious", "blocked"
}
// Lookup returns the full intelligence picture for an IP.
// All reads are local (no network calls). Implemented as a single-element
// LookupBatch so the per-IP enrichment logic (attack DB, ThreatDB, AbuseIPDB
// cache, firewall block state, verdict) lives in exactly one place; the
// shared state files are still loaded only once per call.
func Lookup(ip, statePath string) *IPIntelligence {
	return LookupBatch([]string{ip}, statePath)[0]
}
// LookupBatch returns intelligence for multiple IPs efficiently.
// The shared state files (AbuseIPDB cache, firewall block map) are loaded
// once up front instead of per-IP; all reads are local.
func LookupBatch(ips []string, statePath string) []*IPIntelligence {
	abuseCache := loadFullAbuseCache(statePath)
	blockMap := loadFullBlockState(statePath)
	out := make([]*IPIntelligence, len(ips))
	for i, ip := range ips {
		intel := &IPIntelligence{
			IP:         ip,
			AbuseScore: -1, // sentinel: no cached AbuseIPDB score
		}
		// Internal attack history.
		if adb := attackdb.Global(); adb != nil {
			if rec := adb.LookupIP(ip); rec != nil {
				intel.AttackRecord = rec
				intel.LocalScore = rec.ThreatScore
			}
		}
		// ThreatDB (feeds + permanent blocklist).
		if tdb := checks.GetThreatDB(); tdb != nil {
			if src, hit := tdb.Lookup(ip); hit {
				intel.InThreatDB = true
				intel.ThreatDBSource = src
			}
		}
		// AbuseIPDB score from the pre-loaded cache.
		if cached, ok := abuseCache[ip]; ok {
			intel.AbuseScore = cached.Score
			intel.AbuseCategory = cached.Category
		}
		// Firewall block state from the pre-loaded map.
		applyBlockState(intel, blockMap)
		computeVerdict(intel)
		out[i] = intel
	}
	return out
}
// computeVerdict fills intel.UnifiedScore and intel.Verdict from the gathered
// signals: the unified score is the max of the local and AbuseIPDB scores,
// raised to 100 for any IP present in the threat DB. A currently blocked IP
// is always reported as "blocked" regardless of score.
func computeVerdict(intel *IPIntelligence) {
	score := intel.LocalScore
	if intel.AbuseScore > score {
		score = intel.AbuseScore
	}
	if intel.InThreatDB && score < 100 {
		score = 100
	}
	intel.UnifiedScore = score
	switch {
	case intel.CurrentlyBlocked:
		intel.Verdict = "blocked"
	case score >= 80:
		intel.Verdict = "malicious"
	case score >= 40:
		intel.Verdict = "suspicious"
	default:
		intel.Verdict = "clean"
	}
}
// applyBlockState copies firewall block details onto intel when its IP is
// present in the pre-loaded block map; otherwise it leaves intel untouched.
func applyBlockState(intel *IPIntelligence, blockMap map[string]*blockEntry) {
	entry, ok := blockMap[intel.IP]
	if !ok {
		return
	}
	intel.CurrentlyBlocked = true
	intel.BlockReason = entry.reason
	intel.BlockPermanent = entry.permanent
	if !entry.blockedAt.IsZero() {
		blockedAt := entry.blockedAt
		intel.BlockedAt = &blockedAt
	}
	// Year <= 1 is treated as "no expiry" (a zero-ish time round-tripped
	// through JSON), so only a real future-dated expiry is exposed.
	if !entry.expiresAt.IsZero() && entry.expiresAt.Year() > 1 {
		expiresAt := entry.expiresAt
		intel.BlockExpiresAt = &expiresAt
	}
}
// --- AbuseIPDB cache reader ---

// abuseEntry is the trimmed per-IP view of a cached AbuseIPDB result used
// by Lookup/LookupBatch.
type abuseEntry struct {
	Score    int    `json:"score"`
	Category string `json:"category"`
}
// loadFullAbuseCache loads every usable AbuseIPDB cache entry keyed by IP.
// The bbolt store is preferred; the legacy flat-file JSON cache is the
// fallback when the store is not initialized. Entries older than six hours
// or carrying a negative score (error sentinel) are dropped.
func loadFullAbuseCache(statePath string) map[string]*abuseEntry {
	result := make(map[string]*abuseEntry)
	expiredBefore := time.Now().Add(-6 * time.Hour)
	// Preferred source: bbolt store — after migration the flat file is
	// renamed to .bak, so the store is authoritative when present.
	if sdb := store.Global(); sdb != nil {
		for ip, entry := range sdb.AllReputation() {
			if entry.Score < 0 || entry.CheckedAt.Before(expiredBefore) {
				continue // error sentinel or expired
			}
			result[ip] = &abuseEntry{Score: entry.Score, Category: entry.Category}
		}
		return result
	}
	// Fallback: flat-file JSON (pre-migration).
	type cacheEntry struct {
		Score     int       `json:"score"`
		Category  string    `json:"category"`
		CheckedAt time.Time `json:"checked_at"`
	}
	type cacheFile struct {
		Entries map[string]*cacheEntry `json:"entries"`
	}
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	raw, err := os.ReadFile(filepath.Join(statePath, "reputation_cache.json"))
	if err != nil {
		return result
	}
	var cf cacheFile
	if json.Unmarshal(raw, &cf) != nil || cf.Entries == nil {
		return result
	}
	for ip, entry := range cf.Entries {
		if entry.Score < 0 || entry.CheckedAt.Before(expiredBefore) {
			continue // error sentinel or expired
		}
		result[ip] = &abuseEntry{Score: entry.Score, Category: entry.Category}
	}
	return result
}
// --- Firewall block state reader ---

// blockEntry is the internal per-IP view of firewall block state consumed
// by applyBlockState.
type blockEntry struct {
	reason    string
	blockedAt time.Time
	expiresAt time.Time
	permanent bool // true when expiresAt is zero or its year is <= 1 (no expiry)
}
// loadFullBlockState builds a map of blocked IPs from the best available
// source. Precedence: the bbolt store when initialized, otherwise the
// flat-file firewall state; the legacy CSM blocked_ips.json is then merged
// in last without overwriting entries from the earlier sources.
func loadFullBlockState(statePath string) map[string]*blockEntry {
	result := make(map[string]*blockEntry)
	now := time.Now()
	// Try bbolt store first.
	if sdb := store.Global(); sdb != nil {
		ss := sdb.LoadFirewallState()
		for _, entry := range ss.Blocked {
			// Zero or year<=1 expiry is treated as a permanent block.
			perm := entry.ExpiresAt.IsZero() || entry.ExpiresAt.Year() <= 1
			result[entry.IP] = &blockEntry{
				reason:    entry.Reason,
				blockedAt: entry.BlockedAt,
				expiresAt: entry.ExpiresAt,
				permanent: perm,
			}
		}
	} else {
		// Fallback: firewall.LoadState() reads flat-file state.json.
		fwState, err := firewall.LoadState(statePath)
		if err == nil && fwState != nil {
			for _, entry := range fwState.Blocked {
				perm := entry.ExpiresAt.IsZero() || entry.ExpiresAt.Year() <= 1
				result[entry.IP] = &blockEntry{
					reason:    entry.Reason,
					blockedAt: entry.BlockedAt,
					expiresAt: entry.ExpiresAt,
					permanent: perm,
				}
			}
		}
	}
	// CSM blocked_ips.json (legacy)
	type csmEntry struct {
		IP        string    `json:"ip"`
		Reason    string    `json:"reason"`
		BlockedAt time.Time `json:"blocked_at"`
		ExpiresAt time.Time `json:"expires_at"`
	}
	type csmFile struct {
		IPs []csmEntry `json:"ips"`
	}
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	if data, err := os.ReadFile(filepath.Join(statePath, "blocked_ips.json")); err == nil {
		var cf csmFile
		if json.Unmarshal(data, &cf) == nil {
			for _, entry := range cf.IPs {
				// Keep only entries that are permanent (zero expiry) or not yet expired.
				if entry.ExpiresAt.IsZero() || now.Before(entry.ExpiresAt) {
					// Earlier sources win: legacy data never overwrites.
					if _, exists := result[entry.IP]; !exists {
						perm := entry.ExpiresAt.IsZero() || entry.ExpiresAt.Year() <= 1
						result[entry.IP] = &blockEntry{
							reason:    entry.Reason,
							blockedAt: entry.BlockedAt,
							expiresAt: entry.ExpiresAt,
							permanent: perm,
						}
					}
				}
			}
		}
	}
	return result
}
package webui
import (
"net/http"
"path/filepath"
"strings"
"time"
"github.com/pidginhost/csm/internal/checks"
)
// handleAccount renders the per-account detail page. An invalid account name
// is not an error page: the user is redirected back to the findings list.
func (s *Server) handleAccount(w http.ResponseWriter, r *http.Request) {
	name := r.URL.Query().Get("name")
	if err := validateAccountName(name); err != nil {
		http.Redirect(w, r, "/findings", http.StatusFound)
		return
	}
	data := map[string]string{
		"Hostname":    s.cfg.Hostname,
		"AccountName": name,
		"HomeDir":     filepath.Join("/home", name),
	}
	s.renderTemplate(w, "account.html", data)
}
// apiAccountDetail returns a JSON summary for one account: current findings,
// quarantined files, and recent history that mention the account's home
// directory, plus a WHM deep link.
func (s *Server) apiAccountDetail(w http.ResponseWriter, r *http.Request) {
	name := r.URL.Query().Get("name")
	if err := validateAccountName(name); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Findings/history are attributed to the account by substring match
	// against its home directory prefix.
	homePrefix := "/home/" + name + "/"
	// Current findings for this account
	type findingView struct {
		Severity int    `json:"severity"`
		Check    string `json:"check"`
		Message  string `json:"message"`
		HasFix   bool   `json:"has_fix"`
	}
	var accountFindings []findingView
	latest := s.store.LatestFindings()
	for _, f := range latest {
		// Skip internal bookkeeping checks that are not user-facing findings.
		if f.Check == "auto_response" || f.Check == "auto_block" || f.Check == "check_timeout" || f.Check == "health" {
			continue
		}
		if strings.Contains(f.Message, homePrefix) || strings.Contains(f.Details, homePrefix) || strings.Contains(f.FilePath, homePrefix) {
			accountFindings = append(accountFindings, findingView{
				Severity: int(f.Severity),
				Check:    f.Check,
				Message:  f.Message,
				HasFix:   checks.HasFix(f.Check),
			})
		}
	}
	// Quarantined files for this account
	type qEntry struct {
		ID           string `json:"id"`
		OriginalPath string `json:"original_path"`
		Size         int64  `json:"size"`
		Reason       string `json:"reason"`
	}
	var quarantined []qEntry
	// Quarantine metadata lives in the root quarantine dir and a pre_clean subdirectory.
	rootMetas := listMetaFiles("/opt/csm/quarantine")
	preCleanMetas := listMetaFiles(filepath.Join("/opt/csm/quarantine", "pre_clean"))
	metas := rootMetas
	metas = append(metas, preCleanMetas...)
	for _, metaPath := range metas {
		meta, err := readQuarantineMeta(metaPath)
		if err != nil {
			continue // unreadable/corrupt meta file; skip
		}
		if strings.HasPrefix(meta.OriginalPath, homePrefix) {
			// The quarantine ID is the meta filename minus its .meta suffix.
			id := strings.TrimSuffix(filepath.Base(metaPath), ".meta")
			quarantined = append(quarantined, qEntry{
				ID: id, OriginalPath: meta.OriginalPath, Size: meta.Size, Reason: meta.Reason,
			})
		}
	}
	// Recent history for this account (last 100 matching entries)
	allHistory, _ := s.store.ReadHistory(2000, 0)
	type histEntry struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Timestamp string `json:"timestamp"`
	}
	var history []histEntry
	for _, f := range allHistory {
		if len(history) >= 100 {
			break
		}
		if strings.Contains(f.Message, homePrefix) || strings.Contains(f.Details, homePrefix) {
			history = append(history, histEntry{
				Severity: int(f.Severity), Check: f.Check, Message: f.Message,
				Timestamp: f.Timestamp.Format(time.RFC3339),
			})
		}
	}
	writeJSON(w, map[string]interface{}{
		"account":     name,
		"findings":    accountFindings,
		"quarantined": quarantined,
		"history":     history,
		"whm_url":     "https://" + s.cfg.Hostname + ":2087/scripts/domainsdata?user=" + name,
	})
}
package webui
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/firewall"
"github.com/pidginhost/csm/internal/state"
"github.com/pidginhost/csm/internal/store"
)
var reIPReputation = regexp.MustCompile(`Known malicious IP accessing server: (\S+) \((.+)\)`)
// apiStatus returns daemon status and uptime.
func (s *Server) apiStatus(w http.ResponseWriter, _ *http.Request) {
	// Snapshot the scan flag under the mutex before building the response.
	s.scanMu.Lock()
	scanning := s.scanRunning
	s.scanMu.Unlock()
	writeJSON(w, map[string]interface{}{
		"hostname":       s.cfg.Hostname,
		"uptime":         time.Since(s.startTime).String(),
		"started_at":     s.startTime.Format(time.RFC3339),
		"rules_loaded":   s.sigCount,
		"scan_running":   scanning,
		"last_scan_time": s.store.LatestScanTime().Format(time.RFC3339),
	})
}
// apiFindings returns current scan results - "what's wrong right now."
// Internal bookkeeping checks and operator-suppressed findings are filtered
// out; each entry is annotated with first/last-seen times and fix availability.
func (s *Server) apiFindings(w http.ResponseWriter, _ *http.Request) {
	latest := s.store.LatestFindings()
	type entryView struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Details   string `json:"details,omitempty"`
		Time      string `json:"time"`
		FirstSeen string `json:"first_seen"`
		LastSeen  string `json:"last_seen"`
		HasFix    bool   `json:"has_fix"`
	}
	suppressions := s.store.LoadSuppressions()
	var result []entryView
	for _, f := range latest {
		// Skip internal bookkeeping checks that are not user-facing findings.
		if f.Check == "auto_response" || f.Check == "auto_block" || f.Check == "check_timeout" || f.Check == "health" {
			continue
		}
		// Skip suppressed findings
		if s.store.IsSuppressed(f, suppressions) {
			continue
		}
		// Default both timestamps to the finding's own time; prefer the
		// store's tracked first/last-seen values when it knows this key.
		firstSeen := f.Timestamp
		lastSeen := f.Timestamp
		if entry, ok := s.store.EntryForKey(f.Key()); ok {
			firstSeen = entry.FirstSeen
			lastSeen = entry.LastSeen
		}
		result = append(result, entryView{
			Severity:  int(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Details:   f.Details,
			Time:      f.Timestamp.Format(time.RFC3339),
			FirstSeen: firstSeen.Format(time.RFC3339),
			LastSeen:  lastSeen.Format(time.RFC3339),
			HasFix:    checks.HasFix(f.Check),
		})
	}
	writeJSON(w, result)
}
// enrichedFinding is the JSON response type for the enriched findings endpoint.
type enrichedFinding struct {
	Key       string `json:"key"`       // dedup key from alert.Finding.Key(); used by the dismiss endpoint
	Severity  string `json:"severity"`  // label via severityLabel, e.g. "CRITICAL"/"HIGH"/"WARNING"
	SevClass  string `json:"sev_class"` // presentation class via severityClass — presumably a UI/CSS hint
	Check     string `json:"check"`
	Message   string `json:"message"`
	FilePath  string `json:"file_path,omitempty"`
	Account   string `json:"account,omitempty"` // cPanel account via extractAccountFromFinding, if any
	FirstSeen string `json:"first_seen"`        // RFC3339
	LastSeen  string `json:"last_seen"`         // RFC3339
	HasFix    bool   `json:"has_fix"`           // true when checks.HasFix knows a remediation
	FixDesc   string `json:"fix_desc,omitempty"`
}
// dedupIPReputation groups ip_reputation findings by IP, merging sources and
// promoting each group to its highest severity. Non-ip_reputation findings
// (and ip_reputation findings whose message doesn't match the expected shape)
// pass through unchanged; merged groups are appended afterwards in
// first-seen IP order.
//
// The result is always non-nil: the caller initializes its slice with
// make(...) specifically so an empty response encodes as JSON [] rather than
// null, and this function must not undo that.
func dedupIPReputation(items []enrichedFinding) []enrichedFinding {
	type ipGroup struct {
		entry   enrichedFinding
		sources []string
	}
	ipGroups := make(map[string]*ipGroup)
	var ipOrder []string
	// Non-nil even when items is empty (fixes "findings": null for empty input).
	result := make([]enrichedFinding, 0, len(items))
	for _, item := range items {
		if item.Check != "ip_reputation" {
			result = append(result, item)
			continue
		}
		m := reIPReputation.FindStringSubmatch(item.Message)
		if m == nil {
			result = append(result, item)
			continue
		}
		ip, source := m[1], m[2]
		if g, ok := ipGroups[ip]; ok {
			g.sources = append(g.sources, source)
			// Widen the observation window; RFC3339 strings compare
			// chronologically, so plain string comparison is correct.
			if item.FirstSeen < g.entry.FirstSeen {
				g.entry.FirstSeen = item.FirstSeen
			}
			if item.LastSeen > g.entry.LastSeen {
				g.entry.LastSeen = item.LastSeen
			}
			if severityRank(item.Severity) > severityRank(g.entry.Severity) {
				g.entry.Severity = item.Severity
				g.entry.SevClass = item.SevClass
			}
		} else {
			ipGroups[ip] = &ipGroup{
				entry:   item,
				sources: []string{source},
			}
			ipOrder = append(ipOrder, ip)
		}
	}
	for _, ip := range ipOrder {
		g := ipGroups[ip]
		g.entry.Message = fmt.Sprintf("Known malicious IP accessing server: %s (%s)", ip, strings.Join(g.sources, ", "))
		result = append(result, g.entry)
	}
	return result
}
// apiFindingsEnriched returns findings with IP dedup, account extraction, and severity counts.
func (s *Server) apiFindingsEnriched(w http.ResponseWriter, _ *http.Request) {
	suppressions := s.store.LoadSuppressions()
	// Non-nil so an empty result encodes as JSON [] rather than null.
	items := make([]enrichedFinding, 0)
	for _, f := range s.store.LatestFindings() {
		// Internal bookkeeping checks are not user-facing findings.
		switch f.Check {
		case "auto_response", "auto_block", "check_timeout", "health":
			continue
		}
		if s.store.IsSuppressed(f, suppressions) {
			continue
		}
		first, last := f.Timestamp, f.Timestamp
		if entry, ok := s.store.EntryForKey(f.Key()); ok {
			first, last = entry.FirstSeen, entry.LastSeen
		}
		items = append(items, enrichedFinding{
			Key:       f.Key(),
			Severity:  severityLabel(f.Severity),
			SevClass:  severityClass(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			FilePath:  f.FilePath,
			Account:   extractAccountFromFinding(f),
			FirstSeen: first.Format(time.RFC3339),
			LastSeen:  last.Format(time.RFC3339),
			HasFix:    checks.HasFix(f.Check),
			FixDesc:   checks.FixDescription(f.Check, f.Message, f.FilePath),
		})
	}
	items = dedupIPReputation(items)
	// Single pass over the deduped list: severity tallies plus the distinct
	// check types and accounts used by the UI's filter dropdowns.
	var critCount, highCount, warnCount int
	checkTypeSet := make(map[string]bool)
	accountSet := make(map[string]bool)
	for _, item := range items {
		switch item.Severity {
		case "CRITICAL":
			critCount++
		case "HIGH":
			highCount++
		default:
			warnCount++
		}
		checkTypeSet[item.Check] = true
		if item.Account != "" {
			accountSet[item.Account] = true
		}
	}
	checkTypes := make([]string, 0, len(checkTypeSet))
	for ct := range checkTypeSet {
		checkTypes = append(checkTypes, ct)
	}
	sort.Strings(checkTypes)
	accounts := make([]string, 0, len(accountSet))
	for a := range accountSet {
		accounts = append(accounts, a)
	}
	sort.Strings(accounts)
	writeJSON(w, map[string]interface{}{
		"findings":       items,
		"check_types":    checkTypes,
		"accounts":       accounts,
		"critical_count": critCount,
		"high_count":     highCount,
		"warning_count":  warnCount,
		"total":          len(items),
	})
}
// apiHistory returns paginated history from history.jsonl.
// Supports optional filtering via "from", "to" (YYYY-MM-DD), "severity"
// (0/1/2), "search" (case-insensitive substring over check/message/details),
// and "checks" (comma-separated check names) query params. Filtered queries
// scan at most the newest 5000 entries.
func (s *Server) apiHistory(w http.ResponseWriter, r *http.Request) {
	// Clamp pagination params. Negative values would previously reach the
	// slicing below (filtered[offset:] / filtered[:limit]) and panic,
	// letting a crafted query crash the handler; oversized limits are
	// capped to bound memory.
	limit := queryInt(r, "limit", 50)
	if limit < 0 {
		limit = 0
	}
	if limit > 5000 {
		limit = 5000
	}
	offset := queryInt(r, "offset", 0)
	if offset < 0 {
		offset = 0
	}
	fromStr := r.URL.Query().Get("from")
	toStr := r.URL.Query().Get("to")
	sevStr := r.URL.Query().Get("severity")
	searchStr := r.URL.Query().Get("search")
	checksStr := r.URL.Query().Get("checks")
	var checksFilter map[string]bool
	if checksStr != "" {
		checksFilter = make(map[string]bool)
		for _, c := range strings.Split(checksStr, ",") {
			c = strings.TrimSpace(c)
			if c != "" {
				checksFilter[c] = true
			}
		}
	}
	// If no filters, use simple paginated read
	if fromStr == "" && toStr == "" && sevStr == "" && searchStr == "" && checksStr == "" {
		findings, total := s.store.ReadHistory(limit, offset)
		writeJSON(w, map[string]interface{}{
			"findings": findings,
			"total":    total,
			"limit":    limit,
			"offset":   offset,
		})
		return
	}
	// With filters: read all history, filter, then paginate
	var fromDate, toDate time.Time
	if fromStr != "" {
		fromDate, _ = time.ParseInLocation("2006-01-02", fromStr, time.Local)
	}
	if toStr != "" {
		// "to" is inclusive: extend to the end of that day.
		toDate, _ = time.ParseInLocation("2006-01-02", toStr, time.Local)
		toDate = toDate.Add(24*time.Hour - time.Second)
	}
	sevFilter := -1
	if sevStr != "" {
		sevFilter = queryInt(r, "severity", -1)
	}
	searchLower := strings.ToLower(searchStr)
	allFindings, _ := s.store.ReadHistory(5000, 0)
	// Non-nil so an empty filtered page encodes as [] rather than null.
	filtered := make([]alert.Finding, 0)
	for _, f := range allFindings {
		if !fromDate.IsZero() && f.Timestamp.Before(fromDate) {
			continue
		}
		if !toDate.IsZero() && f.Timestamp.After(toDate) {
			continue
		}
		if sevFilter >= 0 && int(f.Severity) != sevFilter {
			continue
		}
		if searchStr != "" {
			if !strings.Contains(strings.ToLower(f.Check), searchLower) &&
				!strings.Contains(strings.ToLower(f.Message), searchLower) &&
				!strings.Contains(strings.ToLower(f.Details), searchLower) {
				continue
			}
		}
		if checksFilter != nil && !checksFilter[f.Check] {
			continue
		}
		filtered = append(filtered, f)
	}
	total := len(filtered)
	// Apply offset and limit (both are >= 0 after clamping above).
	if offset > len(filtered) {
		filtered = filtered[:0]
	} else {
		filtered = filtered[offset:]
		if len(filtered) > limit {
			filtered = filtered[:limit]
		}
	}
	writeJSON(w, map[string]interface{}{
		"findings": filtered,
		"total":    total,
		"limit":    limit,
		"offset":   offset,
	})
}
// quarantineDir is the root directory holding quarantined items and their
// metadata sidecars (a "pre_clean" subdirectory is scanned as well).
// var (not const) so tests can redirect to t.TempDir(). Production
// callers must not mutate at runtime.
var quarantineDir = "/opt/csm/quarantine"
// apiQuarantine lists quarantined files with metadata, newest first.
func (s *Server) apiQuarantine(w http.ResponseWriter, _ *http.Request) {
	type quarantineEntry struct {
		ID           string `json:"id"`
		OriginalPath string `json:"original_path"`
		Size         int64  `json:"size"`
		QuarantineAt string `json:"quarantined_at"`
		Reason       string `json:"reason"`
	}
	// Collect metadata sidecars from the root quarantine dir and the
	// pre_clean subdirectory.
	metaFiles := listMetaFiles(quarantineDir)
	metaFiles = append(metaFiles, listMetaFiles(filepath.Join(quarantineDir, "pre_clean"))...)
	var entries []quarantineEntry
	for _, mf := range metaFiles {
		meta, err := readQuarantineMeta(mf)
		if err != nil {
			// Unreadable/corrupt sidecar: skip it rather than failing the listing.
			continue
		}
		entries = append(entries, quarantineEntry{
			ID:           quarantineEntryID(mf),
			OriginalPath: meta.OriginalPath,
			Size:         meta.Size,
			QuarantineAt: meta.QuarantineAt.Format(time.RFC3339),
			Reason:       meta.Reason,
		})
	}
	// RFC3339 strings sort chronologically, so string comparison yields
	// newest-first ordering.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].QuarantineAt > entries[j].QuarantineAt
	})
	writeJSON(w, entries)
}
// apiStats returns severity counts and per-check breakdown.
// All figures cover the trailing 24 hours of history. The response also
// includes accounts at risk (critical/high findings), auto-response action
// counts, top targeted accounts, and a brute-force summary.
func (s *Server) apiStats(w http.ResponseWriter, _ *http.Request) {
	last24h := time.Now().Add(-24 * time.Hour)
	findings := s.store.ReadHistorySince(last24h)
	critical, high, warning := 0, 0, 0
	byCheck := make(map[string]int)
	for _, f := range findings {
		switch f.Severity {
		case alert.Critical:
			critical++
		case alert.High:
			high++
		case alert.Warning:
			warning++
		}
		byCheck[f.Check]++
	}
	// Find most recent critical finding for "time since last critical"
	// (findings are newest-first from ReadHistorySince)
	lastCriticalAgo := "None"
	for _, f := range findings {
		if f.Severity == alert.Critical {
			lastCriticalAgo = timeAgo(f.Timestamp)
			break
		}
	}
	// Compute accounts at risk: accounts with critical/high findings in 24h
	accountRisk := make(map[string]int) // account -> highest severity
	// Auto-response summary: count actions by type in 24h
	autoBlocked, autoQuarantined, autoKilled := 0, 0, 0
	// Top targeted accounts
	accountHits := make(map[string]int)
	// Brute force summary
	bruteForceIPs := make(map[string]int)   // IP -> total attempts
	bruteForceTypes := make(map[string]int) // "wp-login" / "xmlrpc" -> count
	for _, f := range findings {
		// Extract account from finding path/message
		acct := extractAccountFromFinding(f)
		if acct != "" {
			accountHits[acct]++
			sev := int(f.Severity)
			if prev, ok := accountRisk[acct]; !ok || sev > prev {
				accountRisk[acct] = sev
			}
		}
		// Count auto-response actions
		switch f.Check {
		case "auto_block":
			autoBlocked++
		case "auto_response":
			// The message text is the only discriminator between quarantine
			// and process-kill actions; both capitalizations of "kill" are
			// matched because Contains is case-sensitive.
			if strings.Contains(f.Message, "quarantin") {
				autoQuarantined++
			} else if strings.Contains(f.Message, "kill") || strings.Contains(f.Message, "Kill") {
				autoKilled++
			}
		case "wp_login_bruteforce":
			bruteForceTypes["wp-login"]++
			if ip := checks.ExtractIPFromFinding(f); ip != "" {
				bruteForceIPs[ip]++
			}
		case "xmlrpc_abuse":
			bruteForceTypes["xmlrpc"]++
			if ip := checks.ExtractIPFromFinding(f); ip != "" {
				bruteForceIPs[ip]++
			}
		case "modsec_csm_block_escalation":
			// 900006/900007 presumably are the CSM xmlrpc ModSecurity rule
			// IDs — TODO confirm against the rule set.
			if strings.Contains(f.Message, "xmlrpc") || strings.Contains(f.Message, "900006") || strings.Contains(f.Message, "900007") {
				bruteForceTypes["xmlrpc-modsec"]++
			}
		}
	}
	// Accounts at risk: those with critical or high severity
	var atRisk []map[string]interface{}
	for acct, sev := range accountRisk {
		if sev >= int(alert.High) {
			atRisk = append(atRisk, map[string]interface{}{
				"account":  acct,
				"severity": sev,
				"findings": accountHits[acct],
			})
		}
	}
	// Sort by severity desc, then findings desc
	sort.Slice(atRisk, func(i, j int) bool {
		if atRisk[i]["severity"].(int) != atRisk[j]["severity"].(int) {
			return atRisk[i]["severity"].(int) > atRisk[j]["severity"].(int)
		}
		return atRisk[i]["findings"].(int) > atRisk[j]["findings"].(int)
	})
	if len(atRisk) > 50 {
		atRisk = atRisk[:50]
	}
	// Top targeted accounts (by finding count)
	type acctCount struct {
		Account string `json:"account"`
		Count   int    `json:"count"`
	}
	var topAccounts []acctCount
	for acct, count := range accountHits {
		topAccounts = append(topAccounts, acctCount{acct, count})
	}
	sort.Slice(topAccounts, func(i, j int) bool {
		return topAccounts[i].Count > topAccounts[j].Count
	})
	if len(topAccounts) > 5 {
		topAccounts = topAccounts[:5]
	}
	result := map[string]interface{}{
		"last_24h": map[string]interface{}{
			"critical": critical,
			"high":     high,
			"warning":  warning,
			"total":    critical + high + warning,
		},
		"by_check":          byCheck,
		"last_critical_ago": lastCriticalAgo,
		"accounts_at_risk":  atRisk,
		"auto_response": map[string]int{
			"blocked":     autoBlocked,
			"quarantined": autoQuarantined,
			"killed":      autoKilled,
		},
		"top_accounts": topAccounts,
		"brute_force":  buildBruteForceSummary(bruteForceIPs, bruteForceTypes),
	}
	writeJSON(w, result)
}
func buildBruteForceSummary(ips map[string]int, types map[string]int) map[string]interface{} {
// Top attacker IPs
type ipCount struct {
IP string `json:"ip"`
Count int `json:"count"`
}
var topIPs []ipCount
for ip, count := range ips {
topIPs = append(topIPs, ipCount{ip, count})
}
sort.Slice(topIPs, func(i, j int) bool {
return topIPs[i].Count > topIPs[j].Count
})
if len(topIPs) > 10 {
topIPs = topIPs[:10]
}
total := 0
for _, v := range types {
total += v
}
return map[string]interface{}{
"total_attacks": total,
"unique_ips": len(ips),
"wp_login_count": types["wp-login"],
"xmlrpc_count": types["xmlrpc"] + types["xmlrpc-modsec"],
"top_ips": topIPs,
}
}
// apiStatsTrend returns 30-day daily finding counts by severity.
// Aggregation happens inside the store via bbolt cursor seeking, so the
// handler never loads the full finding history into memory.
func (s *Server) apiStatsTrend(w http.ResponseWriter, _ *http.Request) {
	daily := s.store.AggregateByDay()
	writeJSON(w, daily)
}
// apiStatsTimeline returns 24 hourly buckets for the findings timeline chart.
// Aggregation happens inside the store via bbolt cursor seeking, so the
// handler never loads the full finding history into memory.
func (s *Server) apiStatsTimeline(w http.ResponseWriter, _ *http.Request) {
	hourly := s.store.AggregateByHour()
	writeJSON(w, hourly)
}
// apiHealth returns daemon health status.
func (s *Server) apiHealth(w http.ResponseWriter, _ *http.Request) {
	// Compute uptime once so the string and seconds fields agree.
	up := time.Since(s.startTime)
	writeJSON(w, map[string]interface{}{
		"daemon_mode":    true,
		"uptime":         up.String(),
		"uptime_seconds": int(up.Seconds()),
		"rules_loaded":   s.sigCount,
		"fanotify":       s.fanotifyActive,
		"log_watchers":   s.logWatcherCount,
	})
}
// apiHistoryCSV exports history (the newest 5000 entries) as a CSV download.
func (s *Server) apiHistoryCSV(w http.ResponseWriter, _ *http.Request) {
	findings, _ := s.store.ReadHistory(5000, 0)
	w.Header().Set("Content-Type", "text/csv")
	w.Header().Set("Content-Disposition", "attachment; filename=csm-history.csv")
	// CSV header
	fmt.Fprintf(w, "Timestamp,Severity,Check,Message,Details\n")
	for _, f := range findings {
		sev := "WARNING"
		switch f.Severity {
		case alert.Critical:
			sev = "CRITICAL"
		case alert.High:
			sev = "HIGH"
		}
		// Escape every free-form field per RFC 4180. Check is escaped too:
		// previously it was written raw, so a check name containing a comma
		// or quote would have corrupted the row structure. Timestamp and
		// severity are fixed-format and cannot contain delimiters.
		fmt.Fprintf(w, "%s,%s,%s,%s,%s\n",
			f.Timestamp.Format(time.RFC3339), sev,
			csvEscape(f.Check), csvEscape(f.Message), csvEscape(f.Details))
	}
}
// csvEscape quotes a CSV field when it contains a delimiter, double quote,
// or line break (per RFC 4180): the field is wrapped in double quotes and
// embedded quotes are doubled. Plain fields are returned unchanged.
func csvEscape(s string) string {
	if !strings.ContainsAny(s, ",\"\n\r") {
		return s
	}
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	for i := 0; i < len(s); i++ {
		if s[i] == '"' {
			b.WriteByte('"') // double embedded quotes
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('"')
	return b.String()
}
// apiFix applies a known remediation action for a finding.
// POST /api/v1/fix body: {"check": "check_type", "message": "...", "details": "...", "file_path": "..."}
func (s *Server) apiFix(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		Check    string `json:"check"`
		Message  string `json:"message"`
		Details  string `json:"details"`
		FilePath string `json:"file_path"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.Check == "" || req.Message == "" {
		writeJSONError(w, "check and message are required", http.StatusBadRequest)
		return
	}
	if !checks.HasFix(req.Check) {
		writeJSONError(w, "no automated fix available for this check type", http.StatusBadRequest)
		return
	}
	result := checks.ApplyFix(req.Check, req.Message, req.Details, req.FilePath)
	// If fix succeeded, dismiss from both alert state and latest findings.
	if result.Success {
		// Dismiss under both key forms. alert.Finding.Key() appends a details
		// hash when Details is non-empty — that is the key the enriched
		// findings endpoint hands to the UI — while the plain "check:message"
		// form covers findings recorded without details. Previously only the
		// plain form was dismissed, so fixed findings that carried details
		// stayed visible.
		simple := req.Check + ":" + req.Message
		s.store.DismissFinding(simple)
		s.store.DismissLatestFinding(simple)
		full := alert.Finding{Check: req.Check, Message: req.Message, Details: req.Details}.Key()
		if full != simple {
			s.store.DismissFinding(full)
			s.store.DismissLatestFinding(full)
		}
		s.auditLog(r, "fix", req.Check, result.Action)
	}
	writeJSON(w, result)
}
// apiBulkFix applies fixes to multiple findings at once.
// POST /api/v1/fix-bulk body: [{"check":"...", "message":"...", "details":"..."}, ...]
// Returns per-item results plus success/failure totals; items without an
// automated fix are reported as errors rather than aborting the batch.
func (s *Server) apiBulkFix(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var reqs []struct {
		Check    string `json:"check"`
		Message  string `json:"message"`
		Details  string `json:"details"`
		FilePath string `json:"file_path"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &reqs); err != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	var results []checks.RemediationResult
	for _, req := range reqs {
		if !checks.HasFix(req.Check) {
			results = append(results, checks.RemediationResult{
				Error: fmt.Sprintf("no fix for %s", req.Check),
			})
			continue
		}
		result := checks.ApplyFix(req.Check, req.Message, req.Details, req.FilePath)
		if result.Success {
			// Dismiss under both key forms. alert.Finding.Key() appends a
			// details hash when Details is non-empty (the key the enriched
			// findings UI holds), while the plain "check:message" form covers
			// findings recorded without details. Previously only the plain
			// form was dismissed, so fixed findings with details stayed visible.
			simple := req.Check + ":" + req.Message
			s.store.DismissFinding(simple)
			s.store.DismissLatestFinding(simple)
			full := alert.Finding{Check: req.Check, Message: req.Message, Details: req.Details}.Key()
			if full != simple {
				s.store.DismissFinding(full)
				s.store.DismissLatestFinding(full)
			}
		}
		results = append(results, result)
	}
	succeeded := 0
	// res (not r) to avoid shadowing the *http.Request parameter.
	for _, res := range results {
		if res.Success {
			succeeded++
		}
	}
	writeJSON(w, map[string]interface{}{
		"results":   results,
		"total":     len(results),
		"succeeded": succeeded,
		"failed":    len(results) - succeeded,
	})
}
// apiFixPreview returns what a fix would do without applying it.
// GET /api/v1/fix-preview?check=...&message=...
// apiAccounts returns a list of cPanel account usernames for the scan dropdown.
//
//nolint:unused // registered via mux.Handle in server.go
func (s *Server) apiAccounts(w http.ResponseWriter, _ *http.Request) {
	entries, err := os.ReadDir("/home")
	if err != nil {
		// /home missing or unreadable: report an empty list, not an error.
		writeJSON(w, []string{})
		return
	}
	var accounts []string
	for _, e := range entries {
		if !e.IsDir() {
			continue
		}
		name := e.Name()
		// Skip hidden and well-known system directories.
		switch {
		case strings.HasPrefix(name, "."):
			continue
		case name == "virtfs", name == "cPanelInstall", name == "cpanelsolr", name == "lost+found":
			continue
		}
		// Must have public_html to be a real cPanel account
		if _, statErr := os.Stat(filepath.Join("/home", name, "public_html")); statErr == nil {
			accounts = append(accounts, name)
		}
	}
	writeJSON(w, accounts)
}
// --- Action endpoints ---
// apiBlockIP blocks an IP via the firewall engine.
// POST /api/v1/block-ip body: {"ip": "1.2.3.4", "reason": "...", "duration": "..."}
func (s *Server) apiBlockIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP       string `json:"ip"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Default reason for manual UI blocks.
	reason := req.Reason
	if reason == "" {
		reason = "Blocked via CSM Web UI"
	}
	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if err := s.blocker.BlockIP(req.IP, reason, parseDuration(req.Duration)); err != nil {
		writeJSONError(w, fmt.Sprintf("Block failed: %v", err), http.StatusInternalServerError)
		return
	}
	s.auditLog(r, "block_ip", req.IP, reason)
	writeJSON(w, map[string]string{"status": "blocked", "ip": req.IP})
}
// apiUnblockIP removes an IP from the firewall + cphulk.
// POST /api/v1/unblock-ip body: {"ip": "1.2.3.4"}
func (s *Server) apiUnblockIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	err := decodeJSONBodyLimited(w, r, 64*1024, &req)
	if err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, vErr := parseAndValidateIP(req.IP); vErr != nil {
		writeJSONError(w, vErr.Error(), http.StatusBadRequest)
		return
	}
	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if ubErr := s.blocker.UnblockIP(req.IP); ubErr != nil {
		writeJSONError(w, fmt.Sprintf("Unblock failed: %v", ubErr), http.StatusInternalServerError)
		return
	}
	// Also flush from cphulk (cPanel brute force detector)
	flushCphulk(req.IP)
	s.auditLog(r, "unblock_ip", req.IP, "manual unblock via UI")
	writeJSON(w, map[string]string{"status": "unblocked", "ip": req.IP})
}
// apiUnblockBulk unblocks multiple IPs at once.
// POST body: {"ips": [...]} with 1-100 entries. Invalid IPs and failed
// unblocks are skipped rather than aborting the batch; the response reports
// how many succeeded.
func (s *Server) apiUnblockBulk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IPs []string `json:"ips"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || len(req.IPs) == 0 {
		writeJSONError(w, "IPs array is required", http.StatusBadRequest)
		return
	}
	if len(req.IPs) > 100 {
		writeJSONError(w, "IPs must be 1-100 items", http.StatusBadRequest)
		return
	}
	if s.blocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	done := 0
	for _, ip := range req.IPs {
		if _, err := parseAndValidateIP(ip); err != nil {
			continue
		}
		if err := s.blocker.UnblockIP(ip); err != nil {
			continue
		}
		flushCphulk(ip)
		s.auditLog(r, "unblock_ip", ip, "bulk unblock via UI")
		done++
	}
	writeJSON(w, map[string]interface{}{
		"status":    "completed",
		"total":     len(req.IPs),
		"succeeded": done,
	})
}
// blockedEntry is a raw blocked IP record from firewall state.
// The JSON layout matches both the firewall engine's state.json ("blocked"
// array) and the legacy blocked_ips.json ("ips" array).
type blockedEntry struct {
	IP        string    `json:"ip"`
	Reason    string    `json:"reason"`
	Source    string    `json:"source,omitempty"` // provenance; empty in records that predate the field
	BlockedAt time.Time `json:"blocked_at"`
	ExpiresAt time.Time `json:"expires_at"` // zero value means a permanent block
}
// blockedView is the API-facing projection of blockedEntry, with timestamps
// pre-formatted as RFC3339 strings and a human-readable time-to-expiry.
// Built by formatBlockedView.
type blockedView struct {
	IP        string `json:"ip"`
	Reason    string `json:"reason"`
	Source    string `json:"source"`
	BlockedAt string `json:"blocked_at"`
	ExpiresAt string `json:"expires_at"` // empty for permanent blocks
	ExpiresIn string `json:"expires_in"` // e.g. "3h25m", or "permanent"
}
// formatBlockedView converts a raw blocked-IP record into its API view.
// The boolean is false when the block has already expired and should be
// omitted from listings.
func formatBlockedView(b blockedEntry) (blockedView, bool) {
	if !b.ExpiresAt.IsZero() && time.Now().After(b.ExpiresAt) {
		// Expired entry: drop it.
		return blockedView{}, false
	}
	v := blockedView{
		IP:        b.IP,
		Reason:    b.Reason,
		Source:    b.Source,
		BlockedAt: b.BlockedAt.Format(time.RFC3339),
	}
	if v.Source == "" {
		// Records without an explicit source get provenance inferred from
		// the reason text.
		v.Source = firewall.InferProvenance("block", b.Reason)
	}
	if b.ExpiresAt.IsZero() {
		v.ExpiresIn = "permanent"
		return v, true
	}
	left := time.Until(b.ExpiresAt)
	v.ExpiresAt = b.ExpiresAt.Format(time.RFC3339)
	v.ExpiresIn = fmt.Sprintf("%dh%dm", int(left.Hours()), int(left.Minutes())%60)
	return v, true
}
// apiBlockedIPs returns the list of currently blocked IPs.
// Uses bbolt store when available, falls back to flat files.
// Resolution order: bbolt store -> firewall/state.json -> legacy
// blocked_ips.json; each tier short-circuits with a response on success.
// Expired entries are filtered out by formatBlockedView.
func (s *Server) apiBlockedIPs(w http.ResponseWriter, _ *http.Request) {
	var result []blockedView
	// Try bbolt store first.
	if sdb := store.Global(); sdb != nil {
		ss := sdb.LoadFirewallState()
		for _, entry := range ss.Blocked {
			// Source is left empty here; formatBlockedView infers provenance
			// from the reason text when it is missing.
			b := blockedEntry{
				IP:        entry.IP,
				Reason:    entry.Reason,
				BlockedAt: entry.BlockedAt,
				ExpiresAt: entry.ExpiresAt,
			}
			if view, ok := formatBlockedView(b); ok {
				result = append(result, view)
			}
		}
		writeJSON(w, result)
		return
	}
	// Fallback: try firewall engine state.json
	fwFile := filepath.Join(s.cfg.StatePath, "firewall", "state.json")
	// #nosec G304 -- filepath.Join under operator-configured StatePath.
	if fwData, err := os.ReadFile(fwFile); err == nil {
		var fwState struct {
			Blocked []blockedEntry `json:"blocked"`
		}
		if json.Unmarshal(fwData, &fwState) == nil {
			for _, b := range fwState.Blocked {
				if view, ok := formatBlockedView(b); ok {
					result = append(result, view)
				}
			}
			writeJSON(w, result)
			return
		}
		// Unparseable state.json: fall through to the legacy file.
	}
	// Fall back to blocked_ips.json (legacy)
	stateFile := filepath.Join(s.cfg.StatePath, "blocked_ips.json")
	// #nosec G304 -- filepath.Join under operator-configured StatePath.
	data, err := os.ReadFile(stateFile)
	if err != nil {
		writeJSON(w, []interface{}{})
		return
	}
	var blockState struct {
		IPs []blockedEntry `json:"ips"`
	}
	if err := json.Unmarshal(data, &blockState); err != nil {
		writeJSON(w, []interface{}{})
		return
	}
	for _, b := range blockState.IPs {
		if view, ok := formatBlockedView(b); ok {
			result = append(result, view)
		}
	}
	writeJSON(w, result)
}
// apiDismissFinding marks a finding as baseline (acknowledged/dismissed).
// POST /api/v1/dismiss body: {"key": "check:message"}
func (s *Server) apiDismissFinding(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		Key string `json:"key"`
	}
	err := decodeJSONBodyLimited(w, r, 16*1024, &req)
	if err != nil || req.Key == "" {
		writeJSONError(w, "Key is required", http.StatusBadRequest)
		return
	}
	// Remove the key from both the alert baseline and the latest-findings view.
	s.store.DismissFinding(req.Key)
	s.store.DismissLatestFinding(req.Key)
	s.auditLog(r, "dismiss", req.Key, "")
	writeJSON(w, map[string]string{"status": "dismissed", "key": req.Key})
}
// apiQuarantineRestore restores a quarantined file to its original location.
// POST /api/v1/quarantine-restore body: {"id": "filename"}
//
// The destination is validated with validateQuarantineRestorePath, the
// original ownership/mode are re-applied from the .meta sidecar, and the
// sidecar is removed on success. A restore never overwrites an existing
// destination (409 Conflict instead).
func (s *Server) apiQuarantineRestore(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		ID string `json:"id"`
	}
	if err := decodeJSONBodyLimited(w, r, 16*1024, &req); err != nil || req.ID == "" {
		writeJSONError(w, "ID is required", http.StatusBadRequest)
		return
	}
	// Resolve the ID to the quarantined item and its metadata sidecar paths.
	entry, err := resolveQuarantineEntry(req.ID)
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	metaData, err := os.ReadFile(entry.MetaPath)
	if err != nil {
		writeJSONError(w, "Quarantine entry not found", http.StatusNotFound)
		return
	}
	// Fields captured at quarantine time; Mode is symbolic, e.g. "-rw-r--r--".
	var meta struct {
		OriginalPath string `json:"original_path"`
		Owner        int    `json:"owner_uid"`
		Group        int    `json:"group_gid"`
		Mode         string `json:"mode"`
	}
	if unmarshalErr := json.Unmarshal(metaData, &meta); unmarshalErr != nil {
		writeJSONError(w, "Invalid metadata", http.StatusInternalServerError)
		return
	}
	restorePath, err := validateQuarantineRestorePath(meta.OriginalPath)
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Ensure parent directory exists
	parentDir := filepath.Dir(restorePath)
	// #nosec G301 -- Restoring a quarantined file into a user's public_html
	// (validated by validateQuarantineRestorePath). The webserver must be
	// able to traverse intermediate directories to serve the file.
	if mkdirErr := os.MkdirAll(parentDir, 0755); mkdirErr != nil {
		writeJSONError(w, fmt.Sprintf("Cannot create parent directory: %v", mkdirErr), http.StatusInternalServerError)
		return
	}
	// Check if quarantined item is a directory or file
	quarInfo, err := os.Stat(entry.ItemPath)
	if err != nil {
		writeJSONError(w, fmt.Sprintf("Cannot stat quarantined file: %v", err), http.StatusInternalServerError)
		return
	}
	// Parse original mode from metadata (format: "-rw-r--r--" or "drwxr-xr-x");
	// fall back to 0644 when the sidecar has no usable mode string.
	restoredMode := os.FileMode(0644)
	if meta.Mode != "" && len(meta.Mode) >= 10 {
		restoredMode = parseModeString(meta.Mode)
	}
	if quarInfo.IsDir() {
		// Lstat (not Stat) so a symlink at the destination also counts as
		// "already exists" and blocks the restore.
		if _, statErr := os.Lstat(restorePath); statErr == nil {
			writeJSONError(w, "Cannot restore - destination already exists", http.StatusConflict)
			return
		} else if !os.IsNotExist(statErr) {
			writeJSONError(w, fmt.Sprintf("Cannot inspect restore destination: %v", statErr), http.StatusInternalServerError)
			return
		}
		// Directory restore: use os.Rename (same device)
		if err := os.Rename(entry.ItemPath, restorePath); err != nil {
			writeJSONError(w, fmt.Sprintf("Cannot restore directory: %v", err), http.StatusInternalServerError)
			return
		}
	} else {
		// File restore: use O_EXCL to prevent overwriting an existing file
		src, readErr := os.Open(entry.ItemPath)
		if readErr != nil {
			writeJSONError(w, fmt.Sprintf("Cannot read quarantined file: %v", readErr), http.StatusInternalServerError)
			return
		}
		// #nosec G304 -- restorePath was validated with
		// validateQuarantineRestorePath above (see handler); the O_EXCL|
		// O_NOFOLLOW flags additionally block TOCTOU replacement.
		dst, createErr := os.OpenFile(restorePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL|syscall.O_NOFOLLOW, restoredMode)
		if createErr != nil {
			_ = src.Close()
			writeJSONError(w, fmt.Sprintf("Cannot restore - file already exists at original path: %v", createErr), http.StatusConflict)
			return
		}
		_, copyErr := io.Copy(dst, src)
		_ = src.Close()
		_ = dst.Close()
		if copyErr != nil {
			// Partial write: remove the destination so a retry can pass O_EXCL.
			if err := os.Remove(restorePath); err != nil && !os.IsNotExist(err) {
				log.Printf("webui: failed to remove %s: %v", restorePath, err)
			}
			writeJSONError(w, fmt.Sprintf("Cannot write restored file: %v", copyErr), http.StatusInternalServerError)
			return
		}
		// Copy succeeded: remove the quarantined source.
		if err := os.Remove(entry.ItemPath); err != nil && !os.IsNotExist(err) {
			log.Printf("webui: failed to remove %s: %v", entry.ItemPath, err)
		}
	}
	// Restore ownership; best-effort, errors deliberately ignored.
	_ = syscall.Chown(restorePath, meta.Owner, meta.Group)
	// Restore mode explicitly (WriteFile may be affected by umask)
	_ = os.Chmod(restorePath, restoredMode)
	// Remove metadata sidecar
	if err := os.Remove(entry.MetaPath); err != nil && !os.IsNotExist(err) {
		log.Printf("webui: failed to remove %s: %v", entry.MetaPath, err)
	}
	s.auditLog(r, "restore", restorePath, "quarantine restore")
	writeJSON(w, map[string]string{
		"status":  "restored",
		"path":    restorePath,
		"warning": "File restored to original location. Re-scan recommended.",
	})
}
// apiQuarantinePreview returns the first 8KB of a quarantined file for inspection.
// GET with "id" query parameter; directories get a placeholder preview.
func (s *Server) apiQuarantinePreview(w http.ResponseWriter, r *http.Request) {
	entry, err := resolveQuarantineEntry(r.URL.Query().Get("id"))
	if err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	info, err := os.Stat(entry.ItemPath)
	if err != nil {
		writeJSONError(w, "not found", http.StatusNotFound)
		return
	}
	if info.IsDir() {
		writeJSON(w, map[string]interface{}{
			"id": entry.ID, "is_dir": true,
			"preview": "[directory - content preview not available]",
		})
		return
	}
	f, err := os.Open(entry.ItemPath)
	if err != nil {
		writeJSONError(w, "cannot read file", http.StatusInternalServerError)
		return
	}
	defer f.Close()
	// io.ReadFull keeps reading until the buffer is full or EOF. A single
	// f.Read may legally return a short count, which previously could
	// truncate the preview even when more bytes were available. EOF and
	// ErrUnexpectedEOF simply mean the file is shorter than 8KB.
	buf := make([]byte, 8192)
	n, readErr := io.ReadFull(f, buf)
	if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
		writeJSONError(w, "cannot read file", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]interface{}{
		"id":         entry.ID,
		"preview":    string(buf[:n]),
		"truncated":  info.Size() > 8192,
		"total_size": info.Size(),
	})
}
// apiQuarantineBulkDelete permanently removes quarantined files and their metadata.
// POST body: {"ids": [...]} with 1-100 entries; unresolvable IDs are skipped.
func (s *Server) apiQuarantineBulkDelete(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IDs []string `json:"ids"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if len(req.IDs) == 0 || len(req.IDs) > 100 {
		writeJSONError(w, "IDs must be 1-100 items", http.StatusBadRequest)
		return
	}
	deleted := 0
	for _, id := range req.IDs {
		entry, err := resolveQuarantineEntry(id)
		if err != nil {
			continue
		}
		if _, statErr := os.Lstat(entry.ItemPath); statErr == nil {
			if rmErr := os.RemoveAll(entry.ItemPath); rmErr == nil {
				deleted++
			}
		} else if !os.IsNotExist(statErr) {
			// Can't inspect the item; leave its sidecar untouched too.
			continue
		}
		// Remove the .meta sidecar even when the item itself was already gone.
		if rmErr := os.Remove(entry.MetaPath); rmErr != nil && !os.IsNotExist(rmErr) {
			log.Printf("webui: failed to remove quarantine meta %s: %v", entry.MetaPath, rmErr)
		}
	}
	s.auditLog(r, "quarantine_bulk_delete", fmt.Sprintf("%d files", deleted), "")
	writeJSON(w, map[string]interface{}{"ok": true, "count": deleted})
}
// apiTestAlert sends a test finding through all configured alert channels.
func (s *Server) apiTestAlert(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	now := time.Now()
	probe := []alert.Finding{{
		Severity:  alert.Warning,
		Check:     "test_alert",
		Message:   "Test alert from CSM Web UI",
		Details:   fmt.Sprintf("Sent by admin at %s", now.Format("2006-01-02 15:04:05")),
		Timestamp: now,
	}}
	if err := alert.Dispatch(s.cfg, probe); err != nil {
		writeJSON(w, map[string]interface{}{"status": "error", "error": err.Error()})
		return
	}
	s.auditLog(r, "test_alert", "notification", "sent test alert")
	writeJSON(w, map[string]interface{}{"status": "sent"})
}
// apiScanAccount runs an on-demand scan for a single cPanel account.
// POST /api/v1/scan-account body: {"account": "username"}
func (s *Server) apiScanAccount(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		Account string `json:"account"`
	}
	if err := decodeJSONBodyLimited(w, r, 32*1024, &req); err != nil || req.Account == "" {
		writeJSONError(w, "Account name is required", http.StatusBadRequest)
		return
	}
	if err := validateAccountName(req.Account); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Rate limit: only one scan may run at a time.
	if !s.acquireScan() {
		writeJSONError(w, "A scan is already in progress. Please wait.", http.StatusTooManyRequests)
		return
	}
	defer s.releaseScan()
	// Extend the write deadline for this long-running request.
	// Account scans can take several minutes; the default WriteTimeout (300s)
	// causes ERR_HTTP2_PROTOCOL_ERROR in browsers when it fires mid-stream.
	_ = http.NewResponseController(w).SetWriteDeadline(time.Now().Add(10 * time.Minute))
	began := time.Now()
	findings := checks.RunAccountScan(s.cfg, s.store, req.Account)
	writeJSON(w, map[string]interface{}{
		"account": req.Account,
		"count":   len(findings),
		"elapsed": time.Since(began).Round(time.Millisecond).String(),
	})
}
// parseModeString converts a symbolic permission string such as "-rw-r--r--"
// (the form printed by `ls -l`) into an os.FileMode permission value.
//
// Only the nine rwx permission bits are recovered; the leading file-type
// character and the setuid/setgid/sticky mode bits are not represented in
// the result. Strings shorter than 10 characters, and strings that decode
// to no bits at all, fall back to 0644.
func parseModeString(s string) os.FileMode {
	if len(s) < 10 {
		return 0644
	}
	var mode os.FileMode
	perms := s[len(s)-9:] // last 9 chars: "rwxr-xr-x"
	bits := []os.FileMode{
		0400, 0200, 0100, // owner r/w/x
		0040, 0020, 0010, // group r/w/x
		0004, 0002, 0001, // other r/w/x
	}
	for i, b := range bits {
		if i >= len(perms) {
			break
		}
		switch perms[i] {
		case '-':
			// Bit not set.
		case 'S', 'T':
			// Uppercase S/T in an exec column means setuid/setgid/sticky is
			// set WITHOUT the execute bit, so no permission bit is added.
			// (The old code treated any non-'-' char as "bit set".)
		default:
			// 'r', 'w', 'x', and lowercase 's'/'t' (special bit + exec) all
			// mean the corresponding permission bit is set.
			mode |= b
		}
	}
	if mode == 0 {
		mode = 0644 // fallback for fully-masked or unparseable strings
	}
	return mode
}
// flushCphulk removes brute-force login history for an IP from cPanel's cphulk.
// Callers must pre-validate `ip` with parseAndValidateIP before invoking.
// The whmapi1 result is intentionally ignored (best-effort cleanup).
func flushCphulk(ip string) {
	// #nosec G204 -- whmapi1 hardcoded; `ip` is pre-validated with
	// parseAndValidateIP in every caller (firewall_api.go, threat_api.go).
	// exec.Command passes args directly to execve without shell
	// interpolation, so no shell-meta risk even if validation slipped.
	cmd := exec.Command("whmapi1", "flush_cphulk_login_history_for_ips", "ip="+ip)
	_, _ = cmd.Output()
}
// apiExport returns a JSON bundle of exportable state (suppression rules and
// threat-DB whitelist) as a downloadable attachment. Nil collections are
// replaced with empty slices so they serialize as [] rather than null.
func (s *Server) apiExport(w http.ResponseWriter, _ *http.Request) {
	suppressions := s.store.LoadSuppressions()
	if suppressions == nil {
		suppressions = []state.SuppressionRule{}
	}
	whitelist := []checks.WhitelistIP{}
	if tdb := checks.GetThreatDB(); tdb != nil {
		if wl := tdb.WhitelistedIPs(); wl != nil {
			whitelist = wl
		}
	}
	w.Header().Set("Content-Disposition", "attachment; filename=csm-state-export.json")
	writeJSON(w, map[string]interface{}{
		"exported_at":  time.Now().Format(time.RFC3339),
		"hostname":     s.cfg.Hostname,
		"suppressions": suppressions,
		"whitelist":    whitelist,
	})
}
// apiImport merges an exported state bundle (suppression rules plus
// whitelist IPs) into the current state. Suppressions are deduplicated by
// rule ID and whitelist entries by IP; existing entries always win. The
// response reports how many new items were actually imported.
func (s *Server) apiImport(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var bundle struct {
		Suppressions []state.SuppressionRule `json:"suppressions"`
		Whitelist    []struct {
			IP string `json:"ip"`
		} `json:"whitelist"`
	}
	if err := decodeJSONBodyLimited(w, r, 512*1024, &bundle); err != nil {
		writeJSONError(w, "invalid JSON body", http.StatusBadRequest)
		return
	}
	imported := 0
	// Merge suppressions (dedup by ID).
	if len(bundle.Suppressions) > 0 {
		existing := s.store.LoadSuppressions()
		existingIDs := make(map[string]bool, len(existing))
		for _, rule := range existing {
			existingIDs[rule.ID] = true
		}
		for _, rule := range bundle.Suppressions {
			if !existingIDs[rule.ID] {
				existing = append(existing, rule)
				imported++
			}
		}
		if err := s.store.SaveSuppressions(existing); err != nil {
			writeJSONError(w, fmt.Sprintf("failed to save suppressions: %v", err), http.StatusInternalServerError)
			return
		}
	}
	// Merge whitelist IPs (dedup by IP). Skipped silently when no threat DB
	// is loaded, matching the export side.
	if len(bundle.Whitelist) > 0 {
		if tdb := checks.GetThreatDB(); tdb != nil {
			existingSet := make(map[string]bool)
			// Fix: this loop variable used to be named `w`, shadowing the
			// http.ResponseWriter parameter above.
			for _, wl := range tdb.WhitelistedIPs() {
				existingSet[wl.IP] = true
			}
			for _, entry := range bundle.Whitelist {
				if entry.IP != "" && !existingSet[entry.IP] {
					tdb.AddWhitelist(entry.IP)
					imported++
				}
			}
		}
	}
	s.auditLog(r, "import", "state", fmt.Sprintf("imported %d items", imported))
	writeJSON(w, map[string]interface{}{
		"status":   "imported",
		"imported": imported,
		"summary":  fmt.Sprintf("%d items imported", imported),
	})
}
// apiFindingDetail returns detail about a specific finding: first/last-seen
// timestamps from the dedup state, related audit-log actions, and up to 50
// recent history entries sharing the same check type.
func (s *Server) apiFindingDetail(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	check := q.Get("check")
	message := q.Get("message")
	if check == "" {
		writeJSONError(w, "check is required", http.StatusBadRequest)
		return
	}
	// First/last seen come from the state entry keyed by "check:message".
	var firstSeen, lastSeen string
	if entry, ok := s.store.EntryForKey(check + ":" + message); ok {
		firstSeen = entry.FirstSeen.Format(time.RFC3339)
		lastSeen = entry.LastSeen.Format(time.RFC3339)
	}
	// Related UI actions from the audit log.
	actions := s.searchAuditEntries(check, 20)
	// Related findings: scan recent history for the same check type.
	type histEntry struct {
		Severity  int    `json:"severity"`
		Check     string `json:"check"`
		Message   string `json:"message"`
		Timestamp string `json:"timestamp"`
	}
	history, _ := s.store.ReadHistory(2000, 0)
	var related []histEntry
	for _, f := range history {
		if f.Check != check {
			continue
		}
		related = append(related, histEntry{
			Severity:  int(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Timestamp: f.Timestamp.Format(time.RFC3339),
		})
		if len(related) >= 50 {
			break
		}
	}
	writeJSON(w, map[string]interface{}{
		"check":      check,
		"message":    message,
		"first_seen": firstSeen,
		"last_seen":  lastSeen,
		"actions":    actions,
		"related":    related,
	})
}
// extractAccountFromFinding extracts a cPanel account name from a finding
// by checking the message, details, and file path for /home/{user}/ patterns
// or "Account: " / "user: " in the details field (used by login checks).
// Returns "" when no account can be identified.
func extractAccountFromFinding(f alert.Finding) string {
	// A "/home/{user}/..." path in any text field names the account directly.
	for _, text := range []string{f.Message, f.Details, f.FilePath} {
		start := strings.Index(text, "/home/")
		if start < 0 {
			continue
		}
		tail := text[start+len("/home/"):]
		if slash := strings.IndexByte(tail, '/'); slash > 0 {
			return tail[:slash]
		}
	}
	// Login checks embed "Account: name" or "user: name" in the details.
	for _, label := range []string{"Account: ", "user: "} {
		pos := strings.Index(f.Details, label)
		if pos < 0 {
			continue
		}
		value := f.Details[pos+len(label):]
		if cut := strings.IndexAny(value, " \n\t,"); cut > 0 {
			return value[:cut]
		}
		if value != "" {
			return value
		}
	}
	return ""
}
func writeJSONError(w http.ResponseWriter, message string, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
_ = json.NewEncoder(w).Encode(map[string]string{"error": message})
}
func writeJSON(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
_ = enc.Encode(data)
}
func queryInt(r *http.Request, key string, defaultVal int) int {
val := r.URL.Query().Get(key)
if val == "" {
return defaultVal
}
n, err := strconv.Atoi(val)
if err != nil || n < 0 {
return defaultVal
}
return n
}
package webui
import (
"bufio"
"encoding/json"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// UI audit log settings.
const (
	// uiAuditFile is the JSONL file (one entry per line) stored under StatePath.
	uiAuditFile = "ui_audit.jsonl"
	// maxUIAuditSize triggers rotation of the log to "<file>.1" once exceeded.
	maxUIAuditSize = 10 * 1024 * 1024 // 10 MB
)
// UIAuditEntry records a UI action for compliance and accountability.
// Entries are serialized one-per-line (JSONL) into uiAuditFile.
type UIAuditEntry struct {
	Timestamp time.Time `json:"timestamp"`
	Action    string    `json:"action"`              // block, unblock, dismiss, fix, whitelist, etc.
	Target    string    `json:"target"`              // IP, finding key, file path
	Details   string    `json:"details,omitempty"`   // extra context
	SourceIP  string    `json:"source_ip,omitempty"` // admin's IP
}
// auditLog appends a UI action entry to the JSONL audit log, rotating the
// file to "<name>.1" once it grows past maxUIAuditSize. Logging is
// best-effort: marshal/open failures abort silently, write failures are
// logged but never surfaced to the caller.
func (s *Server) auditLog(r *http.Request, action, target, details string) {
	entry := UIAuditEntry{
		Timestamp: time.Now(),
		Action:    action,
		Target:    target,
		Details:   details,
		SourceIP:  extractClientIP(r),
	}
	line, err := json.Marshal(entry)
	if err != nil {
		return
	}
	line = append(line, '\n')
	logPath := filepath.Join(s.cfg.StatePath, uiAuditFile)
	// Single-generation rotation keeps the log bounded.
	if info, statErr := os.Stat(logPath); statErr == nil && info.Size() > maxUIAuditSize {
		_ = os.Rename(logPath, logPath+".1")
	}
	// #nosec G304 -- filepath.Join under operator-configured StatePath.
	f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	if _, err := f.Write(line); err != nil {
		log.Printf("webui: failed to write audit log: %v", err)
	}
}
func extractClientIP(r *http.Request) string {
// Use RemoteAddr directly - XFF is trivially spoofable and this is
// a security audit log, so we only trust the TCP connection source.
host := r.RemoteAddr
// Strip port from "ip:port" or "[ipv6]:port"
if last := strings.LastIndex(host, ":"); last >= 0 {
if host[0] == '[' {
// IPv6: [::1]:port
if bracket := strings.Index(host, "]"); bracket >= 0 {
return host[1:bracket]
}
}
return host[:last]
}
return host
}
// readUIAuditLog loads up to limit audit entries from the JSONL log,
// newest first. Returns nil when the log cannot be opened; lines that fail
// to parse are skipped. The scanner buffer caps individual lines at 256 KiB.
func readUIAuditLog(statePath string, limit int) []UIAuditEntry {
	logPath := filepath.Join(statePath, uiAuditFile)
	// #nosec G304 -- filepath.Join under operator-configured statePath.
	f, err := os.Open(logPath)
	if err != nil {
		return nil
	}
	defer func() { _ = f.Close() }()
	var entries []UIAuditEntry
	sc := bufio.NewScanner(f)
	sc.Buffer(make([]byte, 256*1024), 256*1024)
	for sc.Scan() {
		var e UIAuditEntry
		if unmarshalErr := json.Unmarshal(sc.Bytes(), &e); unmarshalErr == nil {
			entries = append(entries, e)
		}
	}
	// The file is appended oldest-first; reverse in place so newest lead.
	for lo, hi := 0, len(entries)-1; lo < hi; lo, hi = lo+1, hi-1 {
		entries[lo], entries[hi] = entries[hi], entries[lo]
	}
	if limit > 0 && len(entries) > limit {
		entries = entries[:limit]
	}
	return entries
}
// searchAuditEntries returns up to limit audit entries whose action, target,
// or details contain search (case-insensitive). At most 10x limit recent
// entries (capped at 5000) are scanned to bound work on large logs.
func (s *Server) searchAuditEntries(search string, limit int) []UIAuditEntry {
	if search == "" || limit <= 0 {
		return nil
	}
	readLimit := limit * 10
	if readLimit > 5000 {
		readLimit = 5000
	}
	needle := strings.ToLower(search)
	var matches []UIAuditEntry
	for _, entry := range readUIAuditLog(s.cfg.StatePath, readLimit) {
		hit := strings.Contains(strings.ToLower(entry.Target), needle) ||
			strings.Contains(strings.ToLower(entry.Details), needle) ||
			strings.Contains(strings.ToLower(entry.Action), needle)
		if !hit {
			continue
		}
		matches = append(matches, entry)
		if len(matches) >= limit {
			break
		}
	}
	return matches
}
// handleAudit renders the audit log HTML page.
func (s *Server) handleAudit(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "audit.html", data)
}
// apiUIAudit handles GET /api/v1/audit and returns the most recent 200 UI
// audit entries, newest first. Always returns a JSON array, never null.
// The request parameter is unused, so it is named "_" to match the other
// read-only handlers in this package.
func (s *Server) apiUIAudit(w http.ResponseWriter, _ *http.Request) {
	entries := readUIAuditLog(s.cfg.StatePath, 200)
	if entries == nil {
		entries = []UIAuditEntry{}
	}
	writeJSON(w, entries)
}
package webui
import (
"context"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/yara"
)
// emailStatsResponse is the payload for the email stats endpoint: live Exim
// queue metrics plus the SMTP-related firewall configuration.
type emailStatsResponse struct {
	QueueSize      int              `json:"queue_size"`   // current exim queue length (exim -bpc)
	QueueWarn      int              `json:"queue_warn"`   // configured warning threshold
	QueueCrit      int              `json:"queue_crit"`   // configured critical threshold
	FrozenCount    int              `json:"frozen_count"` // frozen messages in the queue
	OldestAge      string           `json:"oldest_age"`   // age token of the oldest queued message, e.g. "4d"
	SMTPBlock      bool             `json:"smtp_block"`
	SMTPAllowUsers []string         `json:"smtp_allow_users"`
	SMTPPorts      []int            `json:"smtp_ports"`
	PortFlood      []portFloodEntry `json:"port_flood"`  // flood rules for SMTP ports only
	TopSenders     []senderEntry    `json:"top_senders"` // top sender domains from exim_mainlog
}

// portFloodEntry describes one port-flood firewall rule.
type portFloodEntry struct {
	Port    int    `json:"port"`
	Proto   string `json:"proto"`
	Hits    int    `json:"hits"`
	Seconds int    `json:"seconds"`
}

// senderEntry is one sender domain with its observed message count.
type senderEntry struct {
	Domain string `json:"domain"`
	Count  int    `json:"count"`
}
// apiEmailStats aggregates live Exim queue metrics with SMTP-related firewall
// configuration into a single emailStatsResponse. Nil slices are normalized
// to empty ones so the JSON payload always carries arrays, never null.
func (s *Server) apiEmailStats(w http.ResponseWriter, _ *http.Request) {
	resp := emailStatsResponse{
		QueueWarn: s.cfg.Thresholds.MailQueueWarn,
		QueueCrit: s.cfg.Thresholds.MailQueueCrit,
	}
	// Live queue size and frozen/oldest via exim
	resp.QueueSize = eximQueueSize()
	resp.FrozenCount, resp.OldestAge = eximQueueDetails()
	// Firewall config
	fw := s.cfg.Firewall
	resp.SMTPBlock = fw.SMTPBlock
	resp.SMTPAllowUsers = fw.SMTPAllowUsers
	if resp.SMTPAllowUsers == nil {
		resp.SMTPAllowUsers = []string{}
	}
	resp.SMTPPorts = fw.SMTPPorts
	if resp.SMTPPorts == nil {
		resp.SMTPPorts = []int{}
	}
	// Port flood rules for SMTP ports only
	smtpPorts := map[int]bool{25: true, 465: true, 587: true}
	for _, pf := range fw.PortFlood {
		if smtpPorts[pf.Port] {
			resp.PortFlood = append(resp.PortFlood, portFloodEntry{
				Port:    pf.Port,
				Proto:   pf.Proto,
				Hits:    pf.Hits,
				Seconds: pf.Seconds,
			})
		}
	}
	if resp.PortFlood == nil {
		resp.PortFlood = []portFloodEntry{}
	}
	// Top senders from exim_mainlog
	resp.TopSenders = topMailSenders(500, 10)
	if resp.TopSenders == nil {
		resp.TopSenders = []senderEntry{}
	}
	writeJSON(w, resp)
}
// apiEmailQuarantineList handles GET /api/v1/email/quarantine and returns all
// quarantined email messages, or an empty array if the quarantine is not configured.
func (s *Server) apiEmailQuarantineList(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if s.emailQuarantine == nil {
		// Quarantine disabled: report an empty list rather than an error.
		writeJSON(w, []emailav.QuarantineMetadata{})
		return
	}
	msgs, err := s.emailQuarantine.ListMessages()
	switch {
	case err != nil:
		writeJSONError(w, "Failed to list quarantine", http.StatusInternalServerError)
	case msgs == nil:
		writeJSON(w, []emailav.QuarantineMetadata{})
	default:
		writeJSON(w, msgs)
	}
}
// apiEmailQuarantineAction handles GET, POST (release), and DELETE operations on
// individual quarantined messages at /api/v1/email/quarantine/{msgID}.
//
// The message ID is taken from the URL tail and run through filepath.Base to
// neutralize path-traversal attempts; an optional "/release" suffix selects
// the POST action.
func (s *Server) apiEmailQuarantineAction(w http.ResponseWriter, r *http.Request) {
	// Extract everything after the prefix, e.g. "abc123" or "abc123/release"
	tail := strings.TrimPrefix(r.URL.Path, "/api/v1/email/quarantine/")
	if tail == "" {
		writeJSONError(w, "Missing message ID", http.StatusBadRequest)
		return
	}
	parts := strings.SplitN(tail, "/", 2)
	msgID := filepath.Base(parts[0]) // sanitize path traversal
	action := ""
	if len(parts) == 2 {
		action = parts[1]
	}
	// filepath.Base can yield "." for empty/dot-only input; reject it.
	if msgID == "" || msgID == "." {
		writeJSONError(w, "Invalid message ID", http.StatusBadRequest)
		return
	}
	if s.emailQuarantine == nil {
		writeJSONError(w, "Email quarantine not configured", http.StatusServiceUnavailable)
		return
	}
	switch r.Method {
	case http.MethodGet:
		// GET: return the stored metadata for one message.
		meta, err := s.emailQuarantine.GetMessage(msgID)
		if err != nil {
			writeJSONError(w, "Message not found", http.StatusNotFound)
			return
		}
		writeJSON(w, meta)
	case http.MethodPost:
		// POST: only the "release" action is supported.
		if action != "release" {
			writeJSONError(w, "Unknown action; use /release", http.StatusBadRequest)
			return
		}
		if err := s.emailQuarantine.ReleaseMessage(msgID); err != nil {
			writeJSONError(w, "Failed to release message: "+err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "released", "message_id": msgID})
	case http.MethodDelete:
		// DELETE: permanently remove the quarantined message.
		if err := s.emailQuarantine.DeleteMessage(msgID); err != nil {
			writeJSONError(w, "Failed to delete message: "+err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "deleted", "message_id": msgID})
	default:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// emailAVStatusResponse is the payload for the email AV status endpoint.
type emailAVStatusResponse struct {
	Enabled        bool   `json:"enabled"`         // email AV enabled in config
	ClamdAvailable bool   `json:"clamd_available"` // clamd socket answered an availability probe
	ClamdSocket    string `json:"clamd_socket"`    // socket path that was probed
	YaraXAvailable bool   `json:"yarax_available"`
	YaraXRuleCount int    `json:"yarax_rule_count"`
	WatcherMode    string `json:"watcher_mode"` // set by the daemon on startup; "disabled" when unset
	Quarantined    int    `json:"quarantined"`  // current quarantined message count
}
// apiEmailAVStatus handles GET /api/v1/email/av/status and returns the current
// state of the email AV subsystem (ClamAV, YARA-X, quarantine count, watcher mode).
func (s *Server) apiEmailAVStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	resp := emailAVStatusResponse{
		Enabled: s.cfg.EmailAV.Enabled,
	}
	// ClamAV availability - probe the configured socket.
	clamdSocket := s.cfg.EmailAV.ClamdSocket
	if clamdSocket == "" {
		// Default socket path used when the config leaves it empty.
		clamdSocket = "/var/run/clamd.scan/clamd.sock"
	}
	resp.ClamdSocket = clamdSocket
	resp.ClamdAvailable = emailav.NewClamdScanner(clamdSocket).Available()
	// YARA-X availability and rule count.
	resp.YaraXAvailable = yara.Available()
	if gs := yara.Global(); gs != nil {
		resp.YaraXRuleCount = gs.RuleCount()
	}
	// Watcher mode (set by daemon on startup).
	resp.WatcherMode = s.emailAVWatcherMode
	if resp.WatcherMode == "" {
		resp.WatcherMode = "disabled"
	}
	// Count of currently quarantined messages.
	if s.emailQuarantine != nil {
		msgs, err := s.emailQuarantine.ListMessages()
		if err == nil {
			resp.Quarantined = len(msgs)
		}
	}
	writeJSON(w, resp)
}
// eximQueueSize returns the current Exim mail queue count via `exim -bpc`,
// or 0 when the command fails or its output is not a number. The command
// is bounded by a 5-second timeout.
func eximQueueSize() int {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	raw, err := exec.CommandContext(ctx, "exim", "-bpc").Output()
	if err != nil {
		return 0
	}
	count, _ := strconv.Atoi(strings.TrimSpace(string(raw)))
	return count
}
// eximQueueDetails returns the frozen message count and the age of the
// oldest message in the queue, parsed from `exim -bp` output. The listing
// is sorted oldest-first, so the first age-like token ("4d", "15h", "30m")
// is taken as the oldest age. Returns (0, "") if the command fails.
func eximQueueDetails() (frozen int, oldestAge string) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	out, err := exec.CommandContext(ctx, "exim", "-bp").Output()
	if err != nil {
		return 0, ""
	}
	for _, line := range strings.Split(string(out), "\n") {
		if strings.Contains(line, "*** frozen ***") {
			frozen++
		}
		fields := strings.Fields(line)
		if len(fields) < 3 {
			continue
		}
		// Queue listing lines lead with an age token whose last character
		// is a time unit; anything else is continuation/recipient output.
		age := fields[0]
		if len(age) < 2 {
			continue
		}
		switch age[len(age)-1] {
		case 'd', 'h', 'm', 's':
			if oldestAge == "" {
				oldestAge = age
			}
		}
	}
	return frozen, oldestAge
}
// topMailSenders parses the tail of /var/log/exim_mainlog and returns the
// topK sender domains by message-arrival (" <= ") count, or nil when the
// log cannot be read or contains no qualifying senders. At most the last
// 256 KiB of the file and the last tailLines lines are examined.
func topMailSenders(tailLines, topK int) []senderEntry {
	f, err := os.Open("/var/log/exim_mainlog")
	if err != nil {
		return nil
	}
	defer f.Close()
	// Seek to at most 256 KiB before EOF so huge logs stay cheap to scan.
	// (io.SeekEnd replaces the previous magic whence constant 2.)
	if info, statErr := f.Stat(); statErr == nil && info.Size() > 256*1024 {
		if _, seekErr := f.Seek(-256*1024, io.SeekEnd); seekErr != nil {
			return nil
		}
	}
	data, err := io.ReadAll(f)
	if err != nil {
		return nil
	}
	lines := strings.Split(string(data), "\n")
	if len(lines) > tailLines {
		lines = lines[len(lines)-tailLines:]
	}
	counts := make(map[string]int)
	for _, line := range lines {
		// " <= " marks a message arrival; the next field is the sender.
		idx := strings.Index(line, " <= ")
		if idx < 0 {
			continue
		}
		fields := strings.Fields(line[idx+4:])
		if len(fields) < 1 {
			continue
		}
		sender := fields[0]
		atIdx := strings.LastIndex(sender, "@")
		if atIdx < 0 {
			// Also skips the bounce sender "<>", which carries no "@".
			continue
		}
		domain := sender[atIdx+1:]
		if domain == "" || strings.HasPrefix(sender, "cPanel") {
			continue
		}
		counts[domain]++
	}
	var entries []senderEntry
	for domain, count := range counts {
		entries = append(entries, senderEntry{Domain: domain, Count: count})
	}
	// Sort by count desc, breaking ties by domain so output is deterministic.
	sort.Slice(entries, func(i, j int) bool {
		if entries[i].Count != entries[j].Count {
			return entries[i].Count > entries[j].Count
		}
		return entries[i].Domain < entries[j].Domain
	})
	if len(entries) > topK {
		entries = entries[:topK]
	}
	return entries
}
package webui
import (
"bytes"
"fmt"
"net"
"net/http"
"os/exec"
"sort"
"strings"
"time"
"github.com/pidginhost/csm/internal/firewall"
)
// firewallAllowView is the JSON shape for one active firewall allow rule.
type firewallAllowView struct {
	IP        string `json:"ip"`
	Reason    string `json:"reason"`
	Source    string `json:"source"`               // provenance; inferred when the stored entry has none
	ExpiresAt string `json:"expires_at,omitempty"` // RFC3339; omitted for permanent rules
	ExpiresIn string `json:"expires_in"`           // "XhYm" countdown or "permanent"
}

// firewallPortAllowView is the JSON shape for one per-port allow exception.
type firewallPortAllowView struct {
	IP     string `json:"ip"`
	Port   int    `json:"port"`
	Proto  string `json:"proto"`
	Reason string `json:"reason"`
	Source string `json:"source"`
}
func formatRemaining(expiresAt time.Time) string {
if expiresAt.IsZero() {
return "permanent"
}
remaining := time.Until(expiresAt)
if remaining < 0 {
remaining = 0
}
return fmt.Sprintf("%dh%dm", int(remaining.Hours()), int(remaining.Minutes())%60)
}
// apiFirewallStatus returns the firewall engine configuration and a summary
// of its runtime state: counts of permanent/temporary blocks and allows
// (expired temporary entries are excluded), blocked subnets, port
// exceptions, and infra IPs.
func (s *Server) apiFirewallStatus(w http.ResponseWriter, _ *http.Request) {
	cfg := s.cfg.Firewall
	state, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()
	// Use top-level infra_ips if firewall.infra_ips is empty (daemon syncs at runtime)
	infraIPs := cfg.InfraIPs
	if len(infraIPs) == 0 {
		infraIPs = s.cfg.InfraIPs
	}
	// Zero ExpiresAt marks a permanent block; temporary blocks only count
	// while still in the future.
	blockedPermanent := 0
	blockedTemporary := 0
	for _, entry := range state.Blocked {
		if entry.ExpiresAt.IsZero() {
			blockedPermanent++
			continue
		}
		if now.Before(entry.ExpiresAt) {
			blockedTemporary++
		}
	}
	// Same permanent/temporary split for allow rules.
	allowPermanent := 0
	allowTemporary := 0
	for _, entry := range state.Allowed {
		if entry.ExpiresAt.IsZero() {
			allowPermanent++
			continue
		}
		if now.Before(entry.ExpiresAt) {
			allowTemporary++
		}
	}
	result := map[string]interface{}{
		"enabled":              cfg.Enabled,
		"ipv6":                 cfg.IPv6,
		"tcp_in":               cfg.TCPIn,
		"tcp_out":              cfg.TCPOut,
		"udp_in":               cfg.UDPIn,
		"udp_out":              cfg.UDPOut,
		"restricted_tcp":       cfg.RestrictedTCP,
		"passive_ftp":          [2]int{cfg.PassiveFTPStart, cfg.PassiveFTPEnd},
		"conn_rate_limit":      cfg.ConnRateLimit,
		"conn_limit":           cfg.ConnLimit,
		"syn_flood_protection": cfg.SYNFloodProtection,
		"udp_flood":            cfg.UDPFlood,
		"smtp_block":           cfg.SMTPBlock,
		"log_dropped":          cfg.LogDropped,
		"deny_ip_limit":        cfg.DenyIPLimit,
		"blocked_count":        blockedPermanent + blockedTemporary,
		"blocked_net_count":    len(state.BlockedNet),
		"blocked_permanent":    blockedPermanent,
		"blocked_temporary":    blockedTemporary,
		"allowed_count":        allowPermanent + allowTemporary,
		"allow_permanent":      allowPermanent,
		"allow_temporary":      allowTemporary,
		"port_allow_count":     len(state.PortAllowed),
		"infra_ips":            infraIPs,
		"infra_count":          len(infraIPs),
		"port_flood_rules":     len(cfg.PortFlood),
		"country_block":        cfg.CountryBlock,
		"dyndns_hosts":         cfg.DynDNSHosts,
	}
	writeJSON(w, result)
}
// apiFirewallAllowed returns active firewall allow rules and per-port
// exceptions, both sorted for stable output. Expired temporary rules are
// filtered out, and missing Source fields are inferred from the reason.
func (s *Server) apiFirewallAllowed(w http.ResponseWriter, _ *http.Request) {
	fwState, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()
	var allowed []firewallAllowView
	for _, entry := range fwState.Allowed {
		// Drop temporary rules whose deadline has already passed.
		if !entry.ExpiresAt.IsZero() && !now.Before(entry.ExpiresAt) {
			continue
		}
		source := entry.Source
		if source == "" {
			source = firewall.InferProvenance("allow", entry.Reason)
		}
		view := firewallAllowView{
			IP:        entry.IP,
			Reason:    entry.Reason,
			Source:    source,
			ExpiresIn: formatRemaining(entry.ExpiresAt),
		}
		if !entry.ExpiresAt.IsZero() {
			view.ExpiresAt = entry.ExpiresAt.Format(time.RFC3339)
		}
		allowed = append(allowed, view)
	}
	sort.Slice(allowed, func(i, j int) bool { return allowed[i].IP < allowed[j].IP })
	portAllowed := make([]firewallPortAllowView, 0, len(fwState.PortAllowed))
	for _, entry := range fwState.PortAllowed {
		source := entry.Source
		if source == "" {
			source = firewall.InferProvenance("allow_port", entry.Reason)
		}
		portAllowed = append(portAllowed, firewallPortAllowView{
			IP:     entry.IP,
			Port:   entry.Port,
			Proto:  entry.Proto,
			Reason: entry.Reason,
			Source: source,
		})
	}
	// Stable ordering: IP, then port, then protocol.
	sort.Slice(portAllowed, func(i, j int) bool {
		a, b := portAllowed[i], portAllowed[j]
		if a.IP != b.IP {
			return a.IP < b.IP
		}
		if a.Port != b.Port {
			return a.Port < b.Port
		}
		return a.Proto < b.Proto
	})
	writeJSON(w, map[string]interface{}{
		"allowed":      allowed,
		"port_allowed": portAllowed,
	})
}
// apiFirewallAllowIP adds a firewall allow rule, temporary when duration > 0.
// POST body: {"ip": "...", "reason": "...", "duration": "..."}
//
// The blocker's optional capabilities are discovered via interface
// assertions: TempAllowIP for timed rules, AllowIP for permanent ones.
// A blocker lacking the needed method yields 503.
func (s *Server) apiFirewallAllowIP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP       string `json:"ip"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if req.Reason == "" {
		req.Reason = "Allowed via CSM Web UI"
	}
	dur := parseDuration(req.Duration)
	if dur > 0 {
		// Temporary allow: requires the TempAllowIP capability.
		allower, ok := s.blocker.(interface {
			TempAllowIP(string, string, time.Duration) error
		})
		if !ok || allower == nil {
			writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
			return
		}
		if err := allower.TempAllowIP(req.IP, req.Reason, dur); err != nil {
			writeJSONError(w, fmt.Sprintf("Allow failed: %v", err), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]string{"status": "temp_allowed", "ip": req.IP})
		return
	}
	// Permanent allow (duration absent or unparseable).
	allower, ok := s.blocker.(interface {
		AllowIP(string, string) error
	})
	if !ok || allower == nil {
		writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
		return
	}
	if err := allower.AllowIP(req.IP, req.Reason); err != nil {
		writeJSONError(w, fmt.Sprintf("Allow failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "allowed", "ip": req.IP})
}
// apiFirewallRemoveAllow removes a firewall allow rule for a single IP.
// POST body: {"ip": "1.2.3.4"}. Requires a blocker with RemoveAllowIP.
func (s *Server) apiFirewallRemoveAllow(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, fmt.Sprintf("invalid IP address: %s", req.IP), http.StatusBadRequest)
		return
	}
	remover, ok := s.blocker.(interface {
		RemoveAllowIP(string) error
	})
	if !ok || remover == nil {
		writeJSONError(w, "Firewall allow rules are not available", http.StatusServiceUnavailable)
		return
	}
	if err := remover.RemoveAllowIP(req.IP); err != nil {
		writeJSONError(w, fmt.Sprintf("Remove failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "removed", "ip": req.IP})
}
// apiFirewallAudit returns recent firewall audit log entries, optionally
// filtered by ?search=, ?action=, and ?source= (all case-insensitive) and
// capped by ?limit= (default 100). Always responds with a JSON array.
func (s *Server) apiFirewallAudit(w http.ResponseWriter, r *http.Request) {
	limit := queryInt(r, "limit", 100)
	entries := firewall.ReadAuditLog(s.cfg.StatePath, limit)
	type auditView struct {
		Timestamp string `json:"timestamp"`
		Action    string `json:"action"`
		IP        string `json:"ip"`
		Reason    string `json:"reason"`
		Source    string `json:"source"`
		Duration  string `json:"duration"`
		TimeAgo   string `json:"time_ago"`
	}
	search := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("search")))
	actionFilter := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("action")))
	sourceFilter := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("source")))
	// Start from a non-nil slice so a fully-filtered result still encodes
	// as [] — previously only the no-entries case was guarded and filters
	// that matched nothing produced JSON null.
	result := []auditView{}
	for _, e := range entries {
		source := e.Source
		if source == "" {
			source = firewall.InferProvenance(e.Action, e.Reason)
		}
		if actionFilter != "" && strings.ToLower(e.Action) != actionFilter {
			continue
		}
		if sourceFilter != "" && strings.ToLower(source) != sourceFilter {
			continue
		}
		if search != "" {
			haystack := strings.ToLower(strings.Join([]string{e.Action, e.IP, e.Reason, source}, " "))
			if !strings.Contains(haystack, search) {
				continue
			}
		}
		result = append(result, auditView{
			Timestamp: e.Timestamp.Format("2006-01-02 15:04:05"),
			Action:    e.Action,
			IP:        e.IP,
			Reason:    e.Reason,
			Source:    source,
			Duration:  e.Duration,
			TimeAgo:   timeAgo(e.Timestamp),
		})
	}
	writeJSON(w, result)
}
// apiFirewallSubnets returns the currently blocked subnets as a JSON array
// (empty array, never null, when no subnets are blocked).
func (s *Server) apiFirewallSubnets(w http.ResponseWriter, _ *http.Request) {
	fwState, _ := firewall.LoadState(s.cfg.StatePath)
	type subnetView struct {
		CIDR      string `json:"cidr"`
		Reason    string `json:"reason"`
		Source    string `json:"source"`
		BlockedAt string `json:"blocked_at"`
		TimeAgo   string `json:"time_ago"`
		ExpiresIn string `json:"expires_in"`
	}
	var result []subnetView
	for _, sn := range fwState.BlockedNet {
		source := sn.Source
		if source == "" {
			source = firewall.InferProvenance("block_subnet", sn.Reason)
		}
		result = append(result, subnetView{
			CIDR:      sn.CIDR,
			Reason:    sn.Reason,
			Source:    source,
			BlockedAt: sn.BlockedAt.Format(time.RFC3339),
			TimeAgo:   timeAgo(sn.BlockedAt),
			ExpiresIn: formatRemaining(sn.ExpiresAt),
		})
	}
	if result == nil {
		writeJSON(w, []interface{}{})
		return
	}
	writeJSON(w, result)
}
// apiFirewallDenySubnet blocks a subnet via the firewall engine.
// POST body: {"cidr": "...", "reason": "...", "duration": "..."}; a zero
// or unparseable duration yields a permanent block.
func (s *Server) apiFirewallDenySubnet(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		CIDR     string `json:"cidr"`
		Reason   string `json:"reason"`
		Duration string `json:"duration"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.CIDR == "" {
		writeJSONError(w, "CIDR is required", http.StatusBadRequest)
		return
	}
	if _, err := validateCIDR(req.CIDR); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if req.Reason == "" {
		req.Reason = "Blocked via CSM Web UI"
	}
	subnetBlocker, ok := s.blocker.(interface {
		BlockSubnet(string, string, time.Duration) error
	})
	if !ok || subnetBlocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if err := subnetBlocker.BlockSubnet(req.CIDR, req.Reason, parseDuration(req.Duration)); err != nil {
		writeJSONError(w, fmt.Sprintf("Block failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "blocked", "cidr": req.CIDR})
}
// apiFirewallRemoveSubnet removes a subnet block.
// POST body: {"cidr": "10.0.0.0/8"}.
func (s *Server) apiFirewallRemoveSubnet(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		CIDR string `json:"cidr"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.CIDR == "" {
		writeJSONError(w, "CIDR is required", http.StatusBadRequest)
		return
	}
	// Syntax-only check: removals accept any well-formed CIDR.
	if _, _, err := net.ParseCIDR(req.CIDR); err != nil {
		writeJSONError(w, "Invalid CIDR notation", http.StatusBadRequest)
		return
	}
	unblocker, ok := s.blocker.(interface {
		UnblockSubnet(string) error
	})
	if !ok || unblocker == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if err := unblocker.UnblockSubnet(req.CIDR); err != nil {
		writeJSONError(w, fmt.Sprintf("Remove failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "removed", "cidr": req.CIDR})
}
// apiFirewallFlush clears all blocked IPs. POST only; requires a blocker
// implementing FlushBlocked.
func (s *Server) apiFirewallFlush(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	flusher, ok := s.blocker.(interface{ FlushBlocked() error })
	if !ok || flusher == nil {
		writeJSONError(w, "Firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	if err := flusher.FlushBlocked(); err != nil {
		writeJSONError(w, fmt.Sprintf("Flush failed: %v", err), http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]string{"status": "flushed"})
}
// apiFirewallFlushCphulk clears cPHulk login history for one IP without
// touching firewall state. POST body: {"ip": "1.2.3.4"}.
func (s *Server) apiFirewallFlushCphulk(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Validation above satisfies flushCphulk's pre-validated-IP contract.
	flushCphulk(req.IP)
	writeJSON(w, map[string]string{"status": "flushed", "ip": req.IP})
}
// apiFirewallCheck checks if an IP is blocked in CSM or cphulk.
// GET /api/v1/firewall/check?ip=1.2.3.4
// Response matches cpanel-service format for phclient compatibility:
//
// {"success": true, "ip": "1.2.3.4", "permanent": "reason or null", "temporary": "reason or null", "cphulk": true/false}
func (s *Server) apiFirewallCheck(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	if ip == "" || net.ParseIP(ip) == nil {
		writeJSON(w, map[string]interface{}{"success": false, "error_msg": "The ip is not valid or it was not set."})
		return
	}
	result := map[string]interface{}{
		"success":   true,
		"ip":        ip,
		"permanent": nil,
		"temporary": nil,
		"cphulk":    false,
	}
	// Check CSM firewall state
	state, _ := firewall.LoadState(s.cfg.StatePath)
	now := time.Now()
	for _, b := range state.Blocked {
		if b.IP == ip {
			// Zero ExpiresAt means permanent; otherwise only report blocks
			// that have not yet expired.
			if b.ExpiresAt.IsZero() {
				result["permanent"] = b.Reason
			} else if now.Before(b.ExpiresAt) {
				result["temporary"] = fmt.Sprintf("%s (expires in %s)", b.Reason,
					time.Until(b.ExpiresAt).Truncate(time.Minute))
			}
		}
	}
	// Check blocked subnets
	parsedIP := net.ParseIP(ip)
	for _, sn := range state.BlockedNet {
		_, network, err := net.ParseCIDR(sn.CIDR)
		if err == nil && network.Contains(parsedIP) {
			// A containing subnet block overwrites any per-IP permanent reason.
			result["permanent"] = fmt.Sprintf("Subnet block: %s - %s", sn.CIDR, sn.Reason)
		}
	}
	// Check cphulk (cPanel brute force detector) - read-only check
	// NOTE(review): bytes.Contains is a raw substring match, so "1.2.3.4"
	// also matches inside "11.2.3.45" — consider matching the exact/quoted
	// form; verify against whmapi1 JSON output before changing.
	cphulkOut, cphulkErr := exec.Command("whmapi1", "read_cphulk_records",
		"list_name=black", "--output=json").Output()
	if cphulkErr == nil {
		if bytes.Contains(cphulkOut, []byte(ip)) {
			result["cphulk"] = true
		}
	}
	writeJSON(w, result)
}
// apiFirewallUnban unblocks an IP from CSM + cphulk in one call.
// POST /api/v1/firewall/unban body: {"ip": "1.2.3.4"}
//
// Response uses the cpanel-service shape: {"success": true, "ip": ...},
// plus "subnet_removed" when a covering subnet block was also lifted.
func (s *Server) apiFirewallUnban(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSON(w, map[string]interface{}{"success": false, "error_msg": "The ip is not valid or it was not set."})
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSON(w, map[string]interface{}{"success": false, "error_msg": err.Error()})
		return
	}
	// 1. Unblock from CSM firewall (individual IP). The error is
	// deliberately ignored: the IP may simply not be blocked individually.
	if s.blocker != nil {
		_ = s.blocker.UnblockIP(req.IP)
	}
	// 2. Also remove from any covering subnet block (first match only).
	state, _ := firewall.LoadState(s.cfg.StatePath)
	parsedIP := net.ParseIP(req.IP)
	subnetRemoved := ""
	// Type assertion on a nil blocker interface yields ok=false, so this
	// is safe even when no firewall engine is configured.
	if sb, ok := s.blocker.(interface{ UnblockSubnet(string) error }); ok {
		for _, sn := range state.BlockedNet {
			_, network, err := net.ParseCIDR(sn.CIDR)
			if err == nil && network.Contains(parsedIP) {
				_ = sb.UnblockSubnet(sn.CIDR)
				subnetRemoved = sn.CIDR
				break
			}
		}
	}
	// 3. Flush from cphulk so its login-failure history is cleared too.
	flushCphulk(req.IP)
	result := map[string]interface{}{"success": true, "ip": req.IP}
	if subnetRemoved != "" {
		result["subnet_removed"] = subnetRemoved
	}
	writeJSON(w, result)
}
package webui
import (
"net"
"net/http"
"github.com/pidginhost/csm/internal/geoip"
)
// SetGeoIPDB sets the GeoIP database for IP lookups.
// The Store call suggests s.geoIPDB is an atomic holder, allowing the
// database to be swapped while lookups are in flight — NOTE(review):
// confirm geoIPDB is an atomic.Pointer/Value on the Server struct.
func (s *Server) SetGeoIPDB(db *geoip.DB) {
	s.geoIPDB.Store(db)
}
// apiGeoIPLookup returns geolocation info for an IP.
// GET /api/v1/geoip?ip=1.2.3.4 - fast local lookup (country + ASN)
// GET /api/v1/geoip?ip=1.2.3.4&detail=1 - includes RDAP org/ISP (may take 1-3s)
func (s *Server) apiGeoIPLookup(w http.ResponseWriter, r *http.Request) {
	addr := r.URL.Query().Get("ip")
	switch {
	case addr == "":
		writeJSONError(w, "ip parameter required", http.StatusBadRequest)
		return
	case net.ParseIP(addr) == nil:
		writeJSONError(w, "invalid IP address", http.StatusBadRequest)
		return
	}
	geo := s.geoIPDB.Load()
	if geo == nil {
		writeJSONError(w, "GeoIP databases not loaded", http.StatusServiceUnavailable)
		return
	}
	// detail=1 adds an RDAP query, which may block for a few seconds.
	var info geoip.Info
	if r.URL.Query().Get("detail") == "1" {
		info = geo.LookupWithRDAP(addr)
	} else {
		info = geo.Lookup(addr)
	}
	writeJSON(w, info)
}
// apiGeoIPBatch returns geolocation info for multiple IPs.
// POST /api/v1/geoip/batch body: {"ips": ["1.2.3.4", "5.6.7.8"]}
// At most 500 IPs per request; per-IP errors are reported inline.
func (s *Server) apiGeoIPBatch(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IPs []string `json:"ips"`
	}
	if err := decodeJSONBodyLimited(w, r, 32*1024, &req); err != nil {
		writeJSONError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if len(req.IPs) > 500 {
		writeJSONError(w, "maximum 500 IPs per request", http.StatusBadRequest)
		return
	}
	type geoResult struct {
		Country     string `json:"country"`
		CountryName string `json:"country_name"`
		City        string `json:"city"`
		ASOrg       string `json:"as_org"`
		Error       string `json:"error,omitempty"`
	}
	db := s.geoIPDB.Load()
	out := make(map[string]geoResult, len(req.IPs))
	for _, addr := range req.IPs {
		switch {
		case net.ParseIP(addr) == nil:
			out[addr] = geoResult{Error: "invalid IP format"}
		case db == nil:
			out[addr] = geoResult{Error: "GeoIP database not loaded"}
		default:
			info := db.Lookup(addr)
			out[addr] = geoResult{
				Country:     info.Country,
				CountryName: info.CountryName,
				City:        info.City,
				ASOrg:       info.ASOrg,
			}
		}
	}
	writeJSON(w, map[string]interface{}{"results": out})
}
package webui
import (
"fmt"
"net/http"
"os"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/checks"
)
// renderTemplate executes the named template with data, logging failures to
// stderr. A name missing from s.templates previously produced a nil
// *template.Template, and calling ExecuteTemplate on it panicked; that case
// is now reported like any other template error.
func (s *Server) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
	tmpl, ok := s.templates[name]
	if !ok || tmpl == nil {
		fmt.Fprintf(os.Stderr, "[webui] template %s error: not found\n", name)
		http.Error(w, "template not found", http.StatusInternalServerError)
		return
	}
	if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
		// Headers may already be written, so we can only log here.
		fmt.Fprintf(os.Stderr, "[webui] template %s error: %v\n", name, err)
	}
}
// dashboardData is the template payload for dashboard.html, built by
// handleDashboard from the last 24h of finding history.
type dashboardData struct {
	Hostname        string
	Uptime          string // daemon uptime, rounded to whole seconds
	Critical        int    // 24h finding counts by severity
	High            int
	Warning         int
	Total           int // Critical + High + Warning
	SigCount        int
	FanotifyActive  bool
	LogWatchers     int
	LastCriticalAgo string // human "time ago" of newest critical, or "None"
	RecentFindings  []historyEntry
}
// historyEntry is one finding rendered in the dashboard's recent-findings
// feed; severity is pre-formatted into label + CSS class for the template.
type historyEntry struct {
	Severity     string
	SevClass     string // CSS class derived from severity
	Check        string
	Message      string
	Details      string
	Timestamp    string // "15:04:05" display form
	TimestampISO string // RFC3339 for JS comparison
	TimeAgo      string
	HasFix       bool   // whether an automated fix exists for this check
	FixDesc      string // description of the fix, when HasFix
}
// quarantineData is the template payload for quarantine.html.
type quarantineData struct {
	Hostname string
	Files    []quarantineEntry
}

// quarantineEntry is one quarantined file as shown in the UI.
type quarantineEntry struct {
	ID           string
	OriginalPath string
	Size         int64
	QuarantineAt string
	Reason       string
}
// handleDashboard renders the main dashboard: severity counts over the last
// 24h, the 10 most recent user-visible findings, and daemon status fields.
func (s *Server) handleDashboard(w http.ResponseWriter, _ *http.Request) {
	last24h := time.Now().Add(-24 * time.Hour)
	findings := s.store.ReadHistorySince(last24h)
	var recent []historyEntry
	critical, high, warning := 0, 0, 0
	for _, f := range findings {
		// Count ALL findings by severity (internal checks included).
		switch f.Severity {
		case alert.Critical:
			critical++
		case alert.High:
			high++
		case alert.Warning:
			warning++
		}
		// Skip internal bookkeeping checks from the live feed only;
		// they still contribute to the counts above.
		if f.Check == "auto_response" || f.Check == "auto_block" || f.Check == "check_timeout" || f.Check == "health" {
			continue
		}
		// Keep the first 10 feed-worthy entries (findings arrive
		// newest-first, so these are the most recent).
		if len(recent) < 10 {
			recent = append(recent, historyEntry{
				Severity:     severityLabel(f.Severity),
				SevClass:     severityClass(f.Severity),
				Check:        f.Check,
				Message:      f.Message,
				Details:      f.Details,
				Timestamp:    f.Timestamp.Format("15:04:05"),
				TimestampISO: f.Timestamp.Format(time.RFC3339),
				TimeAgo:      timeAgo(f.Timestamp),
				HasFix:       checks.HasFix(f.Check),
				FixDesc:      checks.FixDescription(f.Check, f.Message, f.FilePath),
			})
		}
	}
	// Find most recent critical finding (findings are newest-first).
	lastCriticalAgo := "None"
	for _, f := range findings {
		if f.Severity == alert.Critical {
			lastCriticalAgo = timeAgo(f.Timestamp)
			break
		}
	}
	data := dashboardData{
		Hostname:        s.cfg.Hostname,
		Uptime:          time.Since(s.startTime).Round(time.Second).String(),
		Critical:        critical,
		High:            high,
		Warning:         warning,
		Total:           critical + high + warning,
		SigCount:        s.sigCount,
		FanotifyActive:  s.fanotifyActive,
		LogWatchers:     s.logWatcherCount,
		LastCriticalAgo: lastCriticalAgo,
		RecentFindings:  recent,
	}
	s.renderTemplate(w, "dashboard.html", data)
}
// handleFindings renders the findings page shell; the actual data is
// fetched client-side from the enriched API.
func (s *Server) handleFindings(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "findings.html", data)
}
// handleHistoryRedirect redirects the legacy /history URL to the history
// tab on the findings page, preserving any query string.
func (s *Server) handleHistoryRedirect(w http.ResponseWriter, r *http.Request) {
	target := "/findings?tab=history"
	if q := r.URL.RawQuery; q != "" {
		target += "&" + q
	}
	http.Redirect(w, r, target, http.StatusFound)
}
// handleQuarantine renders the quarantine page shell (file list loads via API).
func (s *Server) handleQuarantine(w http.ResponseWriter, _ *http.Request) {
	data := quarantineData{Hostname: s.cfg.Hostname}
	s.renderTemplate(w, "quarantine.html", data)
}
// handleFirewall renders the firewall page shell.
func (s *Server) handleFirewall(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "firewall.html", data)
}
// handleEmail renders the email page shell.
func (s *Server) handleEmail(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "email.html", data)
}
package webui
import (
"net/http"
"time"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/store"
)
// apiHardening returns the last stored audit report (GET).
// When no store is available yet, an empty report is returned rather than
// an error so the UI renders cleanly on first run.
func (s *Server) apiHardening(w http.ResponseWriter, _ *http.Request) {
	db := store.Global()
	if db == nil {
		writeJSON(w, &store.AuditReport{})
		return
	}
	rep, err := db.LoadHardeningReport()
	if err != nil {
		writeJSONError(w, "failed to load report: "+err.Error(), http.StatusInternalServerError)
		return
	}
	writeJSON(w, rep)
}
// apiHardeningRun runs the audit, stores the result, and returns it (POST only).
// Only one scan may run at a time: a concurrent request gets 409 Conflict
// via the acquireScan/releaseScan gate.
func (s *Server) apiHardeningRun(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !s.acquireScan() {
		writeJSONError(w, "A scan is already in progress. Please wait.", http.StatusConflict)
		return
	}
	defer s.releaseScan()
	// Extend the write deadline: the audit can run far longer than the
	// server's default timeout. Errors are ignored (best effort).
	rc := http.NewResponseController(w)
	_ = rc.SetWriteDeadline(time.Now().Add(3 * time.Minute))
	report := checks.RunHardeningAudit(s.cfg)
	// Persist when a store exists; the audit result is still returned even
	// if saving fails... no — save failure is surfaced as a 500 below.
	if db := store.Global(); db != nil {
		if err := db.SaveHardeningReport(report); err != nil {
			writeJSONError(w, "audit completed but failed to save report: "+err.Error(), http.StatusInternalServerError)
			return
		}
	}
	writeJSON(w, report)
}
// handleHardening renders the hardening audit page.
// Pass the hostname like every other page handler does, so the shared
// layout can display it (previously nil was passed, leaving {{.Hostname}}
// empty on this page only).
func (s *Server) handleHardening(w http.ResponseWriter, _ *http.Request) {
	s.renderTemplate(w, "hardening.html", map[string]string{
		"Hostname": s.cfg.Hostname,
	})
}
package webui
import (
"bytes"
"encoding/json"
"fmt"
"html/template"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// jsonForScript marshals v to JSON and returns it as template.JS suitable
// for direct substitution into a <script> block. json.Marshal already
// escapes < > & U+2028 U+2029 to \uXXXX form, but an explicit escaping
// pass is applied as defense-in-depth so the output can never break out of
// the surrounding <script> tag or trip JS line-terminator parsing quirks.
// On marshal failure the JS literal "null" is returned so the enclosing
// template still parses.
func jsonForScript(v interface{}) template.JS {
	out, err := json.Marshal(v)
	if err != nil {
		return template.JS("null")
	}
	replacements := []struct{ from, to string }{
		{"<", `\u003c`},
		{">", `\u003e`},
		{"&", `\u0026`},
		{"\u2028", `\u2028`},
		{"\u2029", `\u2029`},
	}
	for _, rep := range replacements {
		out = bytes.ReplaceAll(out, []byte(rep.from), []byte(rep.to))
	}
	// #nosec G203 -- JSON bytes with HTML/JS-dangerous codepoints escaped
	// above; safe to hand to html/template as JS.
	return template.JS(out)
}
func decodeJSONBodyLimited(w http.ResponseWriter, r *http.Request, limit int64, dst interface{}) error {
if limit <= 0 {
limit = 64 * 1024
}
r.Body = http.MaxBytesReader(w, r.Body, limit)
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(dst); err != nil {
return err
}
if dec.More() {
return fmt.Errorf("request body must contain a single JSON value")
}
return nil
}
// validateAccountName checks that name is a valid cPanel account name:
// 1-64 bytes, must start with an ASCII letter, and may contain only
// ASCII letters, digits, and underscores.
func validateAccountName(name string) error {
	switch {
	case name == "":
		return fmt.Errorf("account name is required")
	case len(name) > 64:
		return fmt.Errorf("account name too long (%d chars, max 64)", len(name))
	}
	first := name[0]
	isLetter := (first >= 'a' && first <= 'z') || (first >= 'A' && first <= 'Z')
	if !isLetter {
		return fmt.Errorf("account name must start with a letter")
	}
	for _, c := range name {
		valid := (c >= 'a' && c <= 'z') ||
			(c >= 'A' && c <= 'Z') ||
			(c >= '0' && c <= '9') ||
			c == '_'
		if !valid {
			return fmt.Errorf("account name contains invalid character: %c", c)
		}
	}
	return nil
}
// parseAndValidateIP parses an IP string and rejects non-routable addresses
// (loopback, private RFC 1918, link-local, multicast, unspecified, broadcast).
// RFC 5737 documentation ranges (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
// are intentionally allowed.
func parseAndValidateIP(s string) (net.IP, error) {
	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return nil, fmt.Errorf("IP address is required")
	}
	addr := net.ParseIP(trimmed)
	if addr == nil {
		return nil, fmt.Errorf("invalid IP address: %s", trimmed)
	}
	switch {
	case addr.IsLoopback():
		return nil, fmt.Errorf("loopback address not allowed: %s", trimmed)
	case addr.IsPrivate():
		return nil, fmt.Errorf("private address not allowed: %s", trimmed)
	case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast():
		return nil, fmt.Errorf("link-local address not allowed: %s", trimmed)
	case addr.IsMulticast():
		return nil, fmt.Errorf("multicast address not allowed: %s", trimmed)
	case addr.IsUnspecified():
		return nil, fmt.Errorf("unspecified address not allowed: %s", trimmed)
	case addr.Equal(net.IPv4bcast): // 255.255.255.255
		return nil, fmt.Errorf("broadcast address not allowed: %s", trimmed)
	}
	return addr, nil
}
// validateCIDR parses a CIDR string and rejects overly broad prefixes:
// minimum /8 for IPv4 and /32 for IPv6.
func validateCIDR(s string) (*net.IPNet, error) {
	if s == "" {
		return nil, fmt.Errorf("CIDR is required")
	}
	_, network, err := net.ParseCIDR(s)
	if err != nil {
		return nil, fmt.Errorf("invalid CIDR: %w", err)
	}
	prefixLen, totalBits := network.Mask.Size()
	floor := 8
	if totalBits == 128 { // IPv6
		floor = 32
	}
	if prefixLen < floor {
		return nil, fmt.Errorf("CIDR prefix /%d is too broad (minimum /%d)", prefixLen, floor)
	}
	return network, nil
}
// parseDuration parses a human-friendly duration string from the web UI.
// Supported formats: "24h", "7d", "30d", "0" (permanent), "" (permanent).
// Any unparseable input yields 0 (permanent).
func parseDuration(s string) time.Duration {
	s = strings.TrimSpace(s)
	if s == "" || s == "0" {
		return 0
	}
	// "Nd" day syntax, which time.ParseDuration does not support.
	if strings.HasSuffix(s, "d") {
		digits := strings.TrimSuffix(s, "d")
		n := 0
		for _, r := range digits {
			if r < '0' || r > '9' {
				return 0
			}
			n = n*10 + int(r-'0')
		}
		return time.Duration(n) * 24 * time.Hour
	}
	d, err := time.ParseDuration(s)
	if err != nil {
		return 0
	}
	return d
}
// isPathUnder returns true if the cleaned path is strictly under the base
// directory (the base itself does not count). Appending the separator to
// the base prevents both ".." traversal and prefix tricks such as
// /home/username matching /home/user.
func isPathUnder(path, base string) bool {
	p := filepath.Clean(path)
	b := filepath.Clean(base) + string(filepath.Separator)
	return strings.HasPrefix(p, b)
}
// isPathWithin reports whether path equals base or lies underneath it,
// after cleaning both. Unlike isPathUnder, the base itself matches.
func isPathWithin(path, base string) bool {
	p := filepath.Clean(path)
	b := filepath.Clean(base)
	if p == b {
		return true
	}
	return strings.HasPrefix(p, b+string(filepath.Separator))
}
// preCleanQuarantineIDPrefix marks quarantine IDs whose files live in the
// pre_clean subdirectory of the quarantine store.
const preCleanQuarantineIDPrefix = "pre_clean:"

// quarantineEntryRef is a resolved quarantine ID: the validated on-disk
// paths for the quarantined item and its JSON .meta sidecar.
type quarantineEntryRef struct {
	ID       string
	ItemPath string
	MetaPath string // ItemPath + ".meta"
}
// quarantineEntryID derives the UI-facing ID from a .meta file path:
// the base name without the ".meta" suffix, prefixed with "pre_clean:"
// when the file lives in the pre_clean subdirectory.
func quarantineEntryID(metaPath string) string {
	base := strings.TrimSuffix(filepath.Base(metaPath), ".meta")
	preCleanDir := filepath.Join(quarantineDir, "pre_clean")
	if filepath.Clean(filepath.Dir(metaPath)) == preCleanDir {
		return preCleanQuarantineIDPrefix + base
	}
	return base
}
// resolveQuarantineEntry maps a user-supplied quarantine ID (optionally
// prefixed with "pre_clean:") to the on-disk item and .meta paths under
// quarantineDir. filepath.Base plus the isPathWithin check prevent path
// traversal out of the quarantine directory via crafted IDs.
func resolveQuarantineEntry(id string) (quarantineEntryRef, error) {
	rawID := strings.TrimSpace(id)
	if rawID == "" {
		return quarantineEntryRef{}, fmt.Errorf("quarantine ID is required")
	}
	baseDir := quarantineDir
	name := rawID
	if strings.HasPrefix(rawID, preCleanQuarantineIDPrefix) {
		baseDir = filepath.Join(quarantineDir, "pre_clean")
		name = strings.TrimPrefix(rawID, preCleanQuarantineIDPrefix)
	}
	// filepath.Base strips any directory components ("../x", "/etc/x").
	name = filepath.Base(name)
	if name == "" || name == "." || name == ".." {
		return quarantineEntryRef{}, fmt.Errorf("invalid quarantine ID")
	}
	itemPath := filepath.Join(baseDir, name)
	// Defense-in-depth: verify the joined path stayed inside baseDir.
	if !isPathWithin(itemPath, baseDir) {
		return quarantineEntryRef{}, fmt.Errorf("invalid quarantine ID")
	}
	return quarantineEntryRef{
		ID:       rawID,
		ItemPath: itemPath,
		MetaPath: itemPath + ".meta",
	}, nil
}
// quarantineMeta represents the JSON sidecar metadata for a quarantined file.
// Must match checks.QuarantineMeta on-disk format — keep field names and
// JSON tags in sync with that type.
type quarantineMeta struct {
	OriginalPath string    `json:"original_path"`
	Owner        int       `json:"owner_uid"`
	Group        int       `json:"group_gid"`
	Mode         string    `json:"mode"` // file mode as a string (format defined by the writer in checks)
	Size         int64     `json:"size"`
	QuarantineAt time.Time `json:"quarantined_at"`
	Reason       string    `json:"reason"`
}
// readQuarantineMeta reads and parses a quarantine .meta JSON file.
func readQuarantineMeta(metaPath string) (*quarantineMeta, error) {
	// #nosec G304 -- metaPath is constructed by resolveQuarantineEntry under
	// the quarantine base dir with filepath.Base applied to the ID.
	raw, err := os.ReadFile(metaPath)
	if err != nil {
		return nil, fmt.Errorf("read quarantine meta: %w", err)
	}
	meta := new(quarantineMeta)
	if err := json.Unmarshal(raw, meta); err != nil {
		return nil, fmt.Errorf("parse quarantine meta: %w", err)
	}
	return meta, nil
}
// listMetaFiles returns the full paths of all .meta files in dir (non-recursive).
// Returns nil on any error (e.g., directory does not exist).
func listMetaFiles(dir string) []string {
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var metas []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".meta") {
metas = append(metas, filepath.Join(dir, e.Name()))
}
}
return metas
}
// quarantineRestoreRoots are the only directory trees a quarantined file
// may be restored into.
var quarantineRestoreRoots = []string{"/home", "/tmp", "/dev/shm", "/var/tmp"}

// validateQuarantineRestorePath validates a user-supplied restore target:
// it must be absolute, inside one of quarantineRestoreRoots, must not
// escape those roots through symlinks in any existing ancestor directory,
// and for /home paths must stay inside the owning account's tree.
// Returns the cleaned path on success.
func validateQuarantineRestorePath(path string) (string, error) {
	cleanPath := filepath.Clean(strings.TrimSpace(path))
	// NOTE(review): filepath.Clean("") yields ".", so this branch looks
	// unreachable; the IsAbs check below rejects empty input anyway.
	if cleanPath == "" {
		return "", fmt.Errorf("restore path is required")
	}
	if !filepath.IsAbs(cleanPath) {
		return "", fmt.Errorf("restore path must be absolute")
	}
	if !pathWithinAny(cleanPath, quarantineRestoreRoots) {
		return "", fmt.Errorf("restore path is outside the allowed restore roots: %s", cleanPath)
	}
	// The full target may not exist yet; resolve symlinks on its nearest
	// existing ancestor to catch symlink-based escapes.
	ancestor, err := nearestExistingAncestor(cleanPath)
	if err != nil {
		return "", err
	}
	resolvedAncestor, err := filepath.EvalSymlinks(ancestor)
	if err != nil {
		return "", fmt.Errorf("cannot validate restore path: %w", err)
	}
	if !pathWithinAny(resolvedAncestor, quarantineRestoreRoots) {
		return "", fmt.Errorf("restore path escapes the allowed restore roots: %s", cleanPath)
	}
	// For /home/<account>/... paths, the resolved ancestor must stay
	// inside that same account's tree.
	if accountRoot := homeAccountRoot(cleanPath); accountRoot != "" && !isPathWithin(resolvedAncestor, accountRoot) {
		return "", fmt.Errorf("restore path escapes the account boundary: %s", cleanPath)
	}
	return cleanPath, nil
}
// pathWithinAny reports whether path lies within any of the bases, checking
// each base both as given and with its symlinks resolved (e.g. /tmp on
// systems where it is a symlink).
func pathWithinAny(path string, bases []string) bool {
	for _, b := range bases {
		if isPathWithin(path, b) {
			return true
		}
		resolved, err := filepath.EvalSymlinks(b)
		if err == nil && isPathWithin(path, resolved) {
			return true
		}
	}
	return false
}
// nearestExistingAncestor walks up from path (cleaned) until it finds a
// component that exists on disk (Lstat, so symlinks themselves count).
// A stat error other than "not exist" is returned; running out of parents
// (which cannot happen for absolute paths, since "/" exists) is an error.
func nearestExistingAncestor(path string) (string, error) {
	cur := filepath.Clean(path)
	for {
		_, err := os.Lstat(cur)
		switch {
		case err == nil:
			return cur, nil
		case !os.IsNotExist(err):
			return "", fmt.Errorf("cannot stat restore path: %w", err)
		}
		next := filepath.Dir(cur)
		if next == cur {
			return "", fmt.Errorf("restore path has no existing ancestor: %s", path)
		}
		cur = next
	}
}
// homeAccountRoot returns "/home/<account>" for paths at least one level
// below an account directory (e.g. "/home/bob/www" → "/home/bob").
// It returns "" for non-/home paths and for "/home/<account>" itself.
func homeAccountRoot(path string) string {
	clean := filepath.Clean(path)
	if !strings.HasPrefix(clean, "/home/") {
		return ""
	}
	// For "/home/bob/www": segs = ["", "home", "bob", "www"].
	segs := strings.Split(clean, string(filepath.Separator))
	if len(segs) < 4 {
		return ""
	}
	return filepath.Join("/home", segs[2])
}
package webui
import (
"net/http"
"sort"
"strings"
"time"
)
// timelineEvent is a single entry on the incident timeline, merged from
// the finding history and the UI audit log by apiIncident.
type timelineEvent struct {
	Timestamp string `json:"timestamp"` // RFC3339
	Type      string `json:"type"`      // "finding", "action", "block"
	Severity  int    `json:"severity"`  // 0=info, 1=high, 2=critical
	Summary   string `json:"summary"`
	Details   string `json:"details,omitempty"`
	Source    string `json:"source"` // "history", "audit", "firewall"
}
// handleIncident renders the incident timeline page shell.
func (s *Server) handleIncident(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "incident.html", data)
}
// apiIncident builds an incident timeline for an IP and/or account by
// substring-searching the finding history and the UI audit log, then
// returns up to 200 events sorted newest-first.
// GET /api/v1/incident?ip=1.2.3.4&account=bob&hours=72 (hours capped at 720).
func (s *Server) apiIncident(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	account := r.URL.Query().Get("account")
	if ip == "" && account == "" {
		writeJSONError(w, "ip or account parameter is required", http.StatusBadRequest)
		return
	}
	hours := queryInt(r, "hours", 72)
	if hours > 720 {
		hours = 720
	} // max 30 days
	cutoff := time.Now().Add(-time.Duration(hours) * time.Hour)
	var events []timelineEvent
	// Build search terms. For an account, both the home-directory path and
	// the bare name are included (the bare name subsumes the path match;
	// the path form is kept for clarity/ordering).
	var searchTerms []string
	if ip != "" {
		searchTerms = append(searchTerms, ip)
	}
	if account != "" {
		searchTerms = append(searchTerms, "/home/"+account+"/", account)
	}
	// Search finding history (bounded to the latest 3000 entries).
	allHistory, _ := s.store.ReadHistory(3000, 0)
	for _, f := range allHistory {
		if f.Timestamp.Before(cutoff) {
			continue
		}
		matched := false
		for _, term := range searchTerms {
			if strings.Contains(f.Message, term) || strings.Contains(f.Details, term) {
				matched = true
				break
			}
		}
		if !matched {
			continue
		}
		events = append(events, timelineEvent{
			Timestamp: f.Timestamp.Format(time.RFC3339),
			Type:      "finding",
			Severity:  int(f.Severity),
			Summary:   f.Check + ": " + f.Message,
			Details:   f.Details,
			Source:    "history",
		})
	}
	// Search UI audit log (operator actions; bounded to 500 entries).
	auditEntries := readUIAuditLog(s.cfg.StatePath, 500)
	for _, a := range auditEntries {
		if a.Timestamp.Before(cutoff) {
			continue
		}
		matched := false
		for _, term := range searchTerms {
			if strings.Contains(a.Target, term) || strings.Contains(a.Details, term) {
				matched = true
				break
			}
		}
		if !matched {
			continue
		}
		events = append(events, timelineEvent{
			Timestamp: a.Timestamp.Format(time.RFC3339),
			Type:      "action",
			Severity:  0,
			Details:   a.Details,
			Summary:   a.Action + ": " + a.Target,
			Source:    "audit",
		})
	}
	// Sort by timestamp descending (newest first). String comparison is
	// safe because RFC3339 timestamps sort lexicographically.
	sort.Slice(events, func(i, j int) bool {
		return events[i].Timestamp > events[j].Timestamp
	})
	// Limit to 200 events
	if len(events) > 200 {
		events = events[:200]
	}
	writeJSON(w, map[string]interface{}{
		"events":        events,
		"total":         len(events),
		"query_ip":      ip,
		"query_account": account,
		"hours":         hours,
	})
}
package webui
import (
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/store"
)
// handleModSec renders the ModSecurity page shell (data loads via API).
func (s *Server) handleModSec(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "modsec.html", data)
}
// modsecBlockView is an aggregated view of blocks per IP+rule, produced by
// apiModSecBlocks for the blocks table.
type modsecBlockView struct {
	IP          string `json:"ip"`
	RuleID      string `json:"rule_id"`
	Description string `json:"description"`
	Domains     string `json:"domains"` // comma-joined, truncated to ~80 chars
	Hits        int    `json:"hits"`
	LastSeen    string `json:"last_seen"` // "15:04:05", empty when unknown
	Escalated   bool   `json:"escalated"` // escalated to a CSM firewall block
}

// modsecEventView is a single ModSecurity event as returned by
// apiModSecEvents.
type modsecEventView struct {
	Time     string `json:"time"` // "15:04:05"
	IP       string `json:"ip"`
	RuleID   string `json:"rule_id"`
	Hostname string `json:"hostname"`
	URI      string `json:"uri"`
	Severity string `json:"severity"`
}
// apiModSecStats returns 24h summary stats for ModSecurity blocks:
// total events, unique client IPs, escalation count, and the rule with the
// most hits ("--" when no rule was seen).
func (s *Server) apiModSecStats(w http.ResponseWriter, _ *http.Request) {
	findings := deduplicateModSecFindings(s.modsecFindings24h())
	ips := make(map[string]bool)
	perRule := make(map[string]int)
	escalations := 0
	for _, f := range findings {
		if ip := extractModSecIP(f); ip != "" {
			ips[ip] = true
		}
		if rule := extractModSecRule(f); rule != "" {
			perRule[rule]++
		}
		if f.Check == "modsec_csm_block_escalation" {
			escalations++
		}
	}
	topRule, best := "--", 0
	for rule, n := range perRule {
		if n > best {
			best, topRule = n, rule
		}
	}
	writeJSON(w, map[string]interface{}{
		"total":      len(findings),
		"unique_ips": len(ips),
		"escalated":  escalations,
		"top_rule":   topRule,
	})
}
// apiModSecBlocks returns aggregated blocks per IP for the last 24h:
// hit count, most recent rule/description, the set of targeted domains,
// and whether the IP was escalated to a CSM firewall block.
// Results are sorted by hit count descending.
func (s *Server) apiModSecBlocks(w http.ResponseWriter, _ *http.Request) {
	findings := deduplicateModSecFindings(s.modsecFindings24h())
	// Aggregate by IP.
	type ipAgg struct {
		ruleID      string
		description string
		domains     map[string]bool
		hits        int
		lastSeen    time.Time
		escalated   bool
	}
	byIP := make(map[string]*ipAgg)
	for _, f := range findings {
		if f.Check == "modsec_csm_block_escalation" {
			// Escalation findings only mark the IP; they do not count
			// as hits.
			ip := extractModSecIP(f)
			if ip != "" {
				if agg, ok := byIP[ip]; ok {
					agg.escalated = true
				} else {
					byIP[ip] = &ipAgg{escalated: true, domains: make(map[string]bool)}
				}
			}
			continue
		}
		ip := extractModSecIP(f)
		if ip == "" {
			continue
		}
		rule := extractModSecRule(f)
		desc := extractModSecDescription(f)
		domain := extractModSecHostname(f)
		agg, ok := byIP[ip]
		if !ok {
			agg = &ipAgg{
				ruleID:      rule,
				description: desc,
				domains:     make(map[string]bool),
			}
			byIP[ip] = agg
		}
		agg.hits++
		if f.Timestamp.After(agg.lastSeen) {
			agg.lastSeen = f.Timestamp
			// Update rule/desc to the most recent non-empty values.
			if rule != "" {
				agg.ruleID = rule
			}
			if desc != "" {
				agg.description = desc
			}
		}
		// Skip server IPs and empty hostnames - only show actual domain names.
		// ModSecurity logs the server IP as hostname when the request doesn't
		// match a specific vhost (e.g. direct IP access, SNI mismatch).
		if domain != "" && !looksLikeIP(domain) {
			agg.domains[domain] = true
		}
	}
	var result []modsecBlockView
	for ip, agg := range byIP {
		// Drop placeholder entries that have neither hits nor an
		// escalation flag.
		if agg.hits == 0 && !agg.escalated {
			continue
		}
		var domainList []string
		for d := range agg.domains {
			domainList = append(domainList, d)
		}
		sort.Strings(domainList)
		domains := strings.Join(domainList, ", ")
		// Truncate the joined domain list for table display.
		if len(domains) > 80 {
			domains = domains[:77] + "..."
		}
		lastSeen := ""
		if !agg.lastSeen.IsZero() {
			lastSeen = agg.lastSeen.Format("15:04:05")
		}
		result = append(result, modsecBlockView{
			IP:          ip,
			RuleID:      agg.ruleID,
			Description: agg.description,
			Domains:     domains,
			Hits:        agg.hits,
			LastSeen:    lastSeen,
			Escalated:   agg.escalated,
		})
	}
	// Sort by hits descending.
	sort.Slice(result, func(i, j int) bool {
		return result[i].Hits > result[j].Hits
	})
	writeJSON(w, result)
}
// apiModSecEvents returns the most recent individual ModSecurity events.
// GET /api/v1/modsec/events?limit=N (default 100, max 500), newest first.
func (s *Server) apiModSecEvents(w http.ResponseWriter, r *http.Request) {
	limit := 100
	if l := r.URL.Query().Get("limit"); l != "" {
		if n, err := strconv.Atoi(l); err == nil && n > 0 && n <= 500 {
			limit = n
		}
	}
	findings := deduplicateModSecFindings(s.modsecFindings24h())
	// Walk backwards (newest first) over the whole slice until we have
	// `limit` events. BUG FIX: the previous version pre-computed a tail
	// window of `limit` entries and then skipped escalation findings
	// inside it, which could return fewer than `limit` events even when
	// older displayable events existed.
	result := make([]modsecEventView, 0, limit)
	for i := len(findings) - 1; i >= 0 && len(result) < limit; i-- {
		f := findings[i]
		if f.Check == "modsec_csm_block_escalation" {
			continue // internal escalation marker, not a WAF event
		}
		result = append(result, modsecEventView{
			Time:     f.Timestamp.Format("15:04:05"),
			IP:       extractModSecIP(f),
			RuleID:   extractModSecRule(f),
			Hostname: extractModSecHostname(f),
			URI:      extractModSecURI(f),
			Severity: f.Severity.String(),
		})
	}
	writeJSON(w, result)
}
// deduplicateModSecFindings merges Apache + LiteSpeed duplicate events.
// Both log the same block within the same second, so events are keyed by
// (second, IP, rule); the first occurrence is kept, and its Details are
// upgraded if a later duplicate carries a hostname the first lacked.
func deduplicateModSecFindings(findings []alert.Finding) []alert.Finding {
	type eventKey struct {
		second, ip, rule string
	}
	index := make(map[eventKey]int) // key → position in merged
	var merged []alert.Finding
	for _, f := range findings {
		k := eventKey{
			second: f.Timestamp.Format("15:04:05"),
			ip:     extractModSecIP(f),
			rule:   extractModSecRule(f),
		}
		pos, dup := index[k]
		if !dup {
			index[k] = len(merged)
			merged = append(merged, f)
			continue
		}
		// Merge richer details into the existing entry.
		if extractModSecHostname(f) != "" && extractModSecHostname(merged[pos]) == "" {
			merged[pos].Details = f.Details
		}
	}
	return merged
}
// modsecFindings24h returns all modsec findings from the last 24 hours,
// or nil when no store is available.
func (s *Server) modsecFindings24h() []alert.Finding {
	db := store.Global()
	if db == nil {
		return nil
	}
	cutoff := time.Now().Add(-24 * time.Hour)
	// ReadHistoryFiltered's "from" parameter has date-level granularity
	// ("YYYY-MM-DD"); query from yesterday's date to cover the full 24h
	// window, then trim to the exact cutoff below.
	all, _ := db.ReadHistoryFiltered(10000, 0, cutoff.Format("2006-01-02"), "", -1, "modsec_")
	var recent []alert.Finding
	for _, f := range all {
		if f.Timestamp.After(cutoff) {
			recent = append(recent, f)
		}
	}
	return recent
}
// --- Field extraction from finding Details ---
// Details format: "Rule: NNNN\nMessage: ...\nHostname: ...\nURI: ..."

// extractModSecIP pulls the client IP out of a finding, trying in order:
// the "... from <IP> ..." phrase in Message, the Apache "[client IP]" tag
// in Details, then the LiteSpeed "[IP:PORT-CONN#VHOST]" field.
// Returns "" when nothing IP-shaped is found.
func extractModSecIP(f alert.Finding) string {
	// Try from message: "... from IP on ..." or "... from IP ..."
	msg := f.Message
	if idx := strings.Index(msg, " from "); idx >= 0 {
		rest := msg[idx+6:]
		if sp := strings.IndexAny(rest, " \n"); sp >= 0 {
			rest = rest[:sp]
		}
		// Crude IPv4 shape check: starts with a digit, exactly three
		// dots, at least as long as "1.1.1.1".
		if len(rest) >= 7 && rest[0] >= '0' && rest[0] <= '9' && strings.Count(rest, ".") == 3 {
			return rest
		}
	}
	// Fallback: parse [client IP] from raw log line in Details
	if ip := extractBetween(f.Details, "[client ", "]"); ip != "" {
		// Strip port if present (Apache 2.4: "IP:port"). Exactly one
		// colon means IPv4+port; more would be an IPv6 literal, which
		// is left untouched.
		if strings.Count(ip, ":") == 1 {
			if idx := strings.LastIndex(ip, ":"); idx > 0 {
				ip = ip[:idx]
			}
		}
		return ip
	}
	// Fallback: LiteSpeed format - IP in [IP:PORT-CONN#VHOST]
	for _, field := range strings.Fields(f.Details) {
		if strings.HasPrefix(field, "[") && strings.Contains(field, "#") {
			inner := strings.TrimPrefix(field, "[")
			if colonIdx := strings.Index(inner, ":"); colonIdx > 0 {
				ip := inner[:colonIdx]
				// Same crude IPv4 shape check as above.
				if len(ip) >= 7 && ip[0] >= '0' && ip[0] <= '9' {
					return ip
				}
			}
		}
	}
	return ""
}
// extractModSecRule returns the numeric rule ID from a finding, preferring
// the structured "Rule: " detail line and falling back to the raw-log
// [id "NNNN"] tag.
func extractModSecRule(f alert.Finding) string {
	if rule := extractDetailField(f.Details, "Rule: "); rule != "" {
		return rule
	}
	return extractBetween(f.Details, `[id "`, `"]`)
}
// csmRuleDescriptions provides fallback descriptions for CSM custom rules.
// LiteSpeed error logs omit the [msg "..."] field, so the log-extracted
// description is often empty. This map ensures the UI always shows a
// meaningful description for rules we define ourselves.
// Keys are ModSecurity rule IDs as decimal strings; keep this table in
// sync with the deployed rule files.
var csmRuleDescriptions = map[string]string{
	"900001": "Blocked LEVIATHAN CGI extension access",
	"900002": "Blocked LEVIATHAN directory access",
	"900003": "Blocked PHP execution in uploads directory",
	"900004": "Blocked PHP execution in languages directory",
	"900005": "Blocked direct wp-config.php access",
	"900007": "XML-RPC rate limit exceeded",
	"900008": "Blocked known webshell filename access",
	"900009": "Blocked GSocket User-Agent",
	"900100": "WP-Automatic SQLi (CVE-2024-27956)",
	"900101": "LayerSlider SQLi (CVE-2024-2879)",
	"900102": "Really Simple Security auth bypass (CVE-2024-10924)",
	"900103": "LiteSpeed Cache directory traversal (CVE-2024-4345)",
	"900104": "Ultimate Member SQLi (CVE-2024-1071)",
	"900105": "Backup Migration RCE (CVE-2023-6553)",
	"900106": "GiveWP object injection (CVE-2024-5932)",
	"900107": "WP File Manager arbitrary upload (CVE-2024-3400)",
	"900110": "PHP object injection attempt",
	"900111": "Blocked PHP in wp-content/upgrade",
	"900112": "WordPress user enumeration blocked",
	"900113": "wp-login brute force rate limit",
	"900114": "wp-login brute force rate limit",
	"900115": "Blocked .env file access",
	"900116": "Blocked scanner probe",
	"900120": "Blocked wp-coder preview endpoint",
	"900121": "Blocked wp-coder attributes endpoint",
	// Comodo WAF (CWAF) common rules. Rule IDs in the 21xxxx range are
	// from the Comodo vendor ruleset (e.g. /etc/apache2/conf.d/
	// modsec_vendor_configs/comodo_litespeed/), NOT from OWASP CRS.
	// They were previously mislabeled as "OWASP:" here.
	"210710": "Comodo WAF: HTTP Request Smuggling",
	"210381": "Comodo WAF: HTTP Request Smuggling",
	"214930": "Comodo WAF: Inbound anomaly score threshold exceeded",
	"214940": "Comodo WAF: Outbound anomaly score threshold exceeded",
	"218420": "Comodo WAF: Request content type restriction",
	// OWASP CRS common rules. IDs in the 9xxxxx range are the standard
	// OWASP CRS 3.x schema (920xxx protocol, 930xxx LFI, 941xxx XSS,
	// 942xxx SQLi).
	"920170": "OWASP: Validate GET/HEAD request",
	"920420": "OWASP: Request content type is not allowed by policy",
	"920600": "OWASP: Illegal Accept header",
	"930100": "OWASP: Path traversal attack",
	"930110": "OWASP: Path traversal attack",
	"930120": "OWASP: OS file access attempt",
	"941100": "OWASP: XSS attack detected via libinjection",
	"941160": "OWASP: XSS Filter - Category 1",
	"942100": "OWASP: SQL injection attack detected via libinjection",
}
// extractModSecDescription returns a human-readable description for a
// ModSecurity finding. It prefers the structured "Message: " detail line,
// then the raw-log [msg "..."] field, and finally falls back to the
// static per-rule table, since LiteSpeed error logs omit [msg "..."].
// Returns "" when no description can be determined.
func extractModSecDescription(f alert.Finding) string {
	if msg := extractDetailField(f.Details, "Message: "); msg != "" {
		return msg
	}
	if msg := extractBetween(f.Details, `[msg "`, `"]`); msg != "" {
		return msg
	}
	// Fallback: use static description for CSM custom rules when the log
	// format (e.g. LiteSpeed) doesn't include the [msg "..."] field.
	ruleID := extractModSecRule(f)
	if desc, found := csmRuleDescriptions[ruleID]; found {
		return desc
	}
	return ""
}
// extractModSecHostname returns the target hostname of a ModSecurity
// finding, preferring the structured "Hostname: " detail line over the
// raw-log [hostname "..."] field. Returns "" when neither is present.
func extractModSecHostname(f alert.Finding) string {
	host := extractDetailField(f.Details, "Hostname: ")
	if host == "" {
		host = extractBetween(f.Details, `[hostname "`, `"]`)
	}
	return host
}
// extractModSecURI returns the request URI of a ModSecurity finding,
// preferring the structured "URI: " detail line over the raw-log
// [uri "..."] field. Returns "" when neither is present.
func extractModSecURI(f alert.Finding) string {
	uri := extractDetailField(f.Details, "URI: ")
	if uri == "" {
		uri = extractBetween(f.Details, `[uri "`, `"]`)
	}
	return uri
}
// extractDetailField scans details line by line and returns the remainder
// of the first line that starts with prefix, or "" when no line matches.
func extractDetailField(details, prefix string) string {
	lines := strings.Split(details, "\n")
	for _, l := range lines {
		if len(l) >= len(prefix) && l[:len(prefix)] == prefix {
			return l[len(prefix):]
		}
	}
	return ""
}
// looksLikeIP reports whether s looks like a dotted-quad IPv4 address
// rather than a domain name: at least 7 characters, only digits and
// dots, and exactly three dots. This is a cheap heuristic, not a full
// validator (it does not check octet ranges).
func looksLikeIP(s string) bool {
	if len(s) < 7 { // shortest possible is "1.2.3.4"
		return false
	}
	dots := 0
	for _, r := range s {
		switch {
		case r == '.':
			dots++
		case r < '0' || r > '9':
			return false
		}
	}
	return dots == 3
}
// extractBetween returns the substring between the first occurrence of
// start and the next occurrence of end, or "" when either delimiter is
// missing. Used as fallback for old findings where Details is the raw
// log line.
func extractBetween(s, start, end string) string {
	_, after, ok := strings.Cut(s, start)
	if !ok {
		return ""
	}
	val, _, ok := strings.Cut(after, end)
	if !ok {
		return ""
	}
	return val
}
package webui
import (
"fmt"
"net/http"
"os"
"github.com/pidginhost/csm/internal/modsec"
"github.com/pidginhost/csm/internal/store"
)
// validateModSecDisabledRules checks that every ID in disabled refers to
// a known, non-counter rule from allRules. Counter ("bookkeeping") rules
// may never be disabled. Returns an error naming the first invalid ID,
// or nil when all IDs are acceptable.
func validateModSecDisabledRules(allRules []modsec.Rule, disabled []int) error {
	// Classify each known rule: 1 = regular, 2 = protected counter.
	// A zero lookup means the ID is unknown.
	kind := make(map[int]int, len(allRules))
	for _, rule := range allRules {
		if rule.IsCounter {
			kind[rule.ID] = 2
		} else {
			kind[rule.ID] = 1
		}
	}
	for _, id := range disabled {
		switch kind[id] {
		case 0:
			return fmt.Errorf("rule ID %d is not a known CSM rule", id)
		case 2:
			return fmt.Errorf("rule ID %d is a protected bookkeeping rule", id)
		}
	}
	return nil
}
// handleModSecRules renders the ModSecurity rule management page.
func (s *Server) handleModSecRules(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "modsec-rules.html", data)
}
// GET /api/v1/modsec/rules - list all CSM rules with status
//
// Responds with {"configured": false, "missing": [...]} when any of the
// three required ModSec config fields is empty. Otherwise parses the
// rules file and returns every non-counter rule with its enabled and
// escalate status plus hit statistics from the store.
func (s *Server) apiModSecRules(w http.ResponseWriter, _ *http.Request) {
	cfg := s.cfg.ModSec
	// Check all three config fields
	var missing []string
	if cfg.RulesFile == "" {
		missing = append(missing, "rules_file")
	}
	if cfg.OverridesFile == "" {
		missing = append(missing, "overrides_file")
	}
	if cfg.ReloadCommand == "" {
		missing = append(missing, "reload_command")
	}
	if len(missing) > 0 {
		writeJSON(w, map[string]interface{}{
			"configured": false,
			"missing":    missing,
		})
		return
	}
	// Parse rules from config file
	allRules, err := modsec.ParseRulesFile(cfg.RulesFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "modsec: parse rules failed: %v\n", err)
		writeJSONError(w, "Failed to parse rules file", http.StatusInternalServerError)
		return
	}
	// Read disabled IDs from overrides file. The read error is ignored:
	// a missing/unreadable overrides file means no rules are disabled.
	disabledIDs, _ := modsec.ReadOverrides(cfg.OverridesFile)
	disabledSet := make(map[int]bool)
	for _, id := range disabledIDs {
		disabledSet[id] = true
	}
	// Read escalation exclusions and hit counts from store. Both maps
	// stay nil when the global store is unavailable; nil-map lookups
	// below are safe and simply report zero values.
	var noEscalate map[int]bool
	var hits map[int]store.RuleHitStats
	if db := store.Global(); db != nil {
		noEscalate = db.GetModSecNoEscalateRules()
		hits = db.GetModSecRuleHits()
	}
	// Build response - filter out counter rules
	type ruleView struct {
		ID          int    `json:"id"`
		Description string `json:"description"`
		Action      string `json:"action"`
		StatusCode  int    `json:"status_code"`
		Phase       int    `json:"phase"`
		Enabled     bool   `json:"enabled"`
		Escalate    bool   `json:"escalate"`
		Hits24h     int    `json:"hits_24h"`
		LastHit     string `json:"last_hit,omitempty"`
	}
	var rules []ruleView
	for _, r := range allRules {
		if r.IsCounter {
			continue // hide bookkeeping rules
		}
		rv := ruleView{
			ID:          r.ID,
			Description: r.Description,
			Action:      r.Action,
			StatusCode:  r.StatusCode,
			Phase:       r.Phase,
			Enabled:     !disabledSet[r.ID],
			Escalate:    !noEscalate[r.ID],
		}
		if h, ok := hits[r.ID]; ok {
			rv.Hits24h = h.Hits
			if !h.LastHit.IsZero() {
				// RFC 3339-style timestamp for the UI.
				rv.LastHit = h.LastHit.Format("2006-01-02T15:04:05Z07:00")
			}
		}
		rules = append(rules, rv)
	}
	// Count rules that are not in the disabled set.
	active := 0
	for _, r := range rules {
		if r.Enabled {
			active++
		}
	}
	writeJSON(w, map[string]interface{}{
		"rules":      rules,
		"total":      len(rules),
		"active":     active,
		"disabled":   disabledIDs,
		"configured": true,
	})
}
// POST /api/v1/modsec/rules/apply - write overrides and reload
//
// Body: {"disabled": [rule IDs]}. Validates the IDs against the parsed
// rules file, writes the overrides file, then reloads the web server.
// If the reload fails, the previous overrides content is restored and
// the failure (with truncated reload output) is reported to the client.
func (s *Server) apiModSecRulesApply(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Serialize apply operations - the write+reload+rollback sequence
	// must not interleave with concurrent applies.
	s.modSecApplyMu.Lock()
	defer s.modSecApplyMu.Unlock()
	cfg := s.cfg.ModSec
	if cfg.RulesFile == "" || cfg.OverridesFile == "" || cfg.ReloadCommand == "" {
		writeJSONError(w, "ModSecurity not configured", http.StatusBadRequest)
		return
	}
	var req struct {
		Disabled []int `json:"disabled"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	// Validate: only allow disabling known CSM rule IDs from the parsed rules file
	allRules, err := modsec.ParseRulesFile(cfg.RulesFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "modsec: parse rules failed: %v\n", err)
		writeJSONError(w, "Failed to parse rules file", http.StatusInternalServerError)
		return
	}
	if err := validateModSecDisabledRules(allRules, req.Disabled); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Save previous state for rollback
	previousContent := modsec.ReadOverridesRaw(cfg.OverridesFile)
	// Write new overrides
	if writeErr := modsec.WriteOverrides(cfg.OverridesFile, req.Disabled); writeErr != nil {
		fmt.Fprintf(os.Stderr, "modsec: write overrides failed: %v\n", writeErr)
		writeJSONError(w, "Failed to write overrides", http.StatusInternalServerError)
		return
	}
	// Reload web server
	output, reloadErr := modsec.Reload(cfg.ReloadCommand)
	if reloadErr != nil {
		// Rollback on failure. The restore error is deliberately ignored:
		// the reload failure is the primary error to surface.
		_ = modsec.RestoreOverrides(cfg.OverridesFile, previousContent)
		fmt.Fprintf(os.Stderr, "modsec: reload failed (rolled back): %v\noutput: %s\n", reloadErr, output)
		// Truncate output for client - may contain sensitive system paths
		clientOutput := output
		if len(clientOutput) > 500 {
			clientOutput = clientOutput[:500] + "... (truncated)"
		}
		writeJSON(w, map[string]interface{}{
			"ok":            false,
			"error":         "Web server reload failed, changes rolled back",
			"reload_output": clientOutput,
			"rolled_back":   true,
		})
		return
	}
	writeJSON(w, map[string]interface{}{
		"ok":             true,
		"disabled_count": len(req.Disabled),
	})
}
// POST /api/v1/modsec/rules/escalation - toggle escalation for a single rule
//
// Body: {"rule_id": N, "escalate": bool}. escalate=true removes the rule
// from the no-escalate set; escalate=false adds it. Only CSM custom rule
// IDs (900000-900999) are accepted.
func (s *Server) apiModSecRulesEscalation(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	db := store.Global()
	if db == nil {
		writeJSONError(w, "Store not available", http.StatusInternalServerError)
		return
	}
	var body struct {
		RuleID   int  `json:"rule_id"`
		Escalate bool `json:"escalate"`
	}
	if err := decodeJSONBodyLimited(w, r, 4*1024, &body); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	// Validate: only accept known CSM rule IDs (900000-900999)
	if body.RuleID < 900000 || body.RuleID > 900999 {
		writeJSONError(w, "Rule ID must be a CSM custom rule (900000-900999)", http.StatusBadRequest)
		return
	}
	var updateErr error
	if body.Escalate {
		updateErr = db.RemoveModSecNoEscalateRule(body.RuleID)
	} else {
		updateErr = db.AddModSecNoEscalateRule(body.RuleID)
	}
	if updateErr != nil {
		fmt.Fprintf(os.Stderr, "modsec: escalation update failed: %v\n", updateErr)
		writeJSONError(w, "Failed to update escalation setting", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]interface{}{"ok": true})
}
package webui
import (
"bufio"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"
)
// --- Local response types ---

// perfResponse is the JSON payload returned by the performance API:
// the latest metrics snapshot plus the current perf_* findings.
type perfResponse struct {
	Metrics  *perfMetrics      `json:"metrics"`
	Findings []perfFindingView `json:"findings"`
}

// perfMetrics is one snapshot of live system metrics as gathered by
// sampleMetrics (load averages, memory, PHP processes, MySQL, Redis,
// uptime). Sizes suffixed MB are megabytes.
type perfMetrics struct {
	LoadAvg     [3]float64  `json:"load_avg"`
	CPUCores    int         `json:"cpu_cores"`
	MemTotalMB  uint64      `json:"mem_total_mb"`
	MemUsedMB   uint64      `json:"mem_used_mb"`
	MemAvailMB  uint64      `json:"mem_avail_mb"`
	SwapTotalMB uint64      `json:"swap_total_mb"`
	SwapUsedMB  uint64      `json:"swap_used_mb"`
	PHPProcs    int         `json:"php_procs_total"`
	TopPHPUsers []userProcs `json:"top_php_users"`
	MySQLMemMB  uint64      `json:"mysql_mem_mb"`
	MySQLConns  int         `json:"mysql_conns"`
	RedisMemMB  uint64      `json:"redis_mem_mb"`
	RedisMaxMB  uint64      `json:"redis_maxmem_mb"`
	RedisKeys   int64       `json:"redis_keys"`
	Uptime      string      `json:"uptime"`
}

// userProcs pairs a username with its PHP process count.
type userProcs struct {
	User  string `json:"user"`
	Count int    `json:"count"`
}

// perfFindingView is the UI-facing projection of a perf_* finding,
// with pre-rendered severity class and RFC 3339 timestamps.
type perfFindingView struct {
	Severity  int    `json:"severity"`
	SevClass  string `json:"sev_class"`
	Check     string `json:"check"`
	Message   string `json:"message"`
	Details   string `json:"details,omitempty"`
	FirstSeen string `json:"first_seen"`
	LastSeen  string `json:"last_seen"`
}
// --- Cached values ---
var (
	// perfCoresOnce/perfCoresCache hold the CPU core count, computed
	// once by cachedCores from /proc/cpuinfo.
	perfCoresOnce  sync.Once
	perfCoresCache int
	// perfUIDMapOnce/perfUIDMapCache hold the UID-to-username map,
	// built once by cachedUID from /etc/passwd.
	perfUIDMapOnce sync.Once
	perfUIDMapCache map[string]string
)
// cachedCores returns the number of logical CPUs, counted once from
// "processor\t" lines in /proc/cpuinfo. Falls back to 1 when the file
// is unreadable or contains no processor entries.
func cachedCores() int {
	perfCoresOnce.Do(func() {
		perfCoresCache = 1 // pessimistic default
		f, err := os.Open("/proc/cpuinfo")
		if err != nil {
			return
		}
		defer func() { _ = f.Close() }()
		n := 0
		sc := bufio.NewScanner(f)
		for sc.Scan() {
			if strings.HasPrefix(sc.Text(), "processor\t") {
				n++
			}
		}
		if n > 0 {
			perfCoresCache = n
		}
	})
	return perfCoresCache
}
// cachedUID resolves a numeric UID string to a username via /etc/passwd,
// which is parsed once and cached. Unknown UIDs (or an unreadable passwd
// file) are returned unchanged.
func cachedUID(uid string) string {
	perfUIDMapOnce.Do(func() {
		perfUIDMapCache = make(map[string]string)
		data, err := os.ReadFile("/etc/passwd")
		if err != nil {
			return
		}
		lines := strings.Split(string(data), "\n")
		for _, l := range lines {
			// passwd format: name:password:UID:GID:...
			parts := strings.Split(l, ":")
			if len(parts) >= 3 {
				perfUIDMapCache[parts[2]] = parts[0]
			}
		}
	})
	if name, ok := perfUIDMapCache[uid]; ok {
		return name
	}
	return uid
}
// --- Metrics sampler ---
// runCmdQuick runs a command with a 5-second timeout and returns its
// stdout. All call sites pass constant binary names (mysql, redis-cli,
// etc.) and literal argument lists — no HTTP-controlled input reaches
// this function.
func runCmdQuick(name string, args ...string) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// #nosec G204 -- see function-level comment: constant names/args only.
	cmd := exec.CommandContext(ctx, name, args...)
	out, err := cmd.Output()
	// Distinguish a timeout from an ordinary command failure.
	if ctx.Err() == context.DeadlineExceeded {
		return nil, fmt.Errorf("command timed out: %s", name)
	}
	return out, err
}
// sampleMetrics gathers live system metrics and returns a populated perfMetrics.
//
// Every section below is best-effort: read/parse errors leave the
// corresponding fields at their zero values rather than failing the
// whole sample. Data sources: /proc/loadavg, /proc/cpuinfo (cached),
// /proc/meminfo, /proc/<pid>/{cmdline,status}, the mysql and redis-cli
// CLIs, and /proc/uptime.
func sampleMetrics() *perfMetrics {
	m := &perfMetrics{}
	// Load averages: first three fields of /proc/loadavg.
	if data, err := os.ReadFile("/proc/loadavg"); err == nil {
		fields := strings.Fields(string(data))
		if len(fields) >= 3 {
			for i := 0; i < 3; i++ {
				v, _ := strconv.ParseFloat(fields[i], 64)
				m.LoadAvg[i] = v
			}
		}
	}
	// CPU cores
	m.CPUCores = cachedCores()
	// Memory from /proc/meminfo
	{
		f, err := os.Open("/proc/meminfo")
		if err == nil {
			// All /proc/meminfo values are in kB.
			var memTotal, memAvail, memFree, memBuffers, memCached, swapTotal, swapFree uint64
			scanner := bufio.NewScanner(f)
			for scanner.Scan() {
				line := scanner.Text()
				fields := strings.Fields(line)
				if len(fields) < 2 {
					continue
				}
				val, _ := strconv.ParseUint(fields[1], 10, 64)
				switch fields[0] {
				case "MemTotal:":
					memTotal = val
				case "MemAvailable:":
					memAvail = val
				case "MemFree:":
					memFree = val
				case "Buffers:":
					memBuffers = val
				case "Cached:":
					memCached = val
				case "SwapTotal:":
					swapTotal = val
				case "SwapFree:":
					swapFree = val
				}
			}
			_ = f.Close()
			m.MemTotalMB = memTotal / 1024
			m.MemAvailMB = memAvail / 1024
			// Used = Total - Free - Buffers - Cached
			// (guarded against uint64 underflow on inconsistent reads).
			used := memTotal
			if memFree+memBuffers+memCached <= memTotal {
				used = memTotal - memFree - memBuffers - memCached
			}
			m.MemUsedMB = used / 1024
			m.SwapTotalMB = swapTotal / 1024
			if swapFree <= swapTotal {
				m.SwapUsedMB = (swapTotal - swapFree) / 1024
			}
		}
	}
	// PHP processes: scan /proc/*/cmdline for lsphp
	{
		cmdlinePaths, _ := filepath.Glob("/proc/[0-9]*/cmdline")
		userCounts := make(map[string]int)
		total := 0
		for _, cmdPath := range cmdlinePaths {
			// #nosec G304 -- cmdPath from /proc/*/cmdline glob; kernel pseudo-FS.
			data, err := os.ReadFile(cmdPath)
			if err != nil {
				continue
			}
			// cmdline uses NUL separators; flatten for substring match.
			cmdStr := strings.ReplaceAll(string(data), "\x00", " ")
			if !strings.Contains(cmdStr, "lsphp") {
				continue
			}
			pid := filepath.Base(filepath.Dir(cmdPath))
			// #nosec G304 -- /proc/<pid>/status; kernel pseudo-FS, pid from /proc glob.
			statusData, _ := os.ReadFile(filepath.Join("/proc", pid, "status"))
			uid := ""
			for _, line := range strings.Split(string(statusData), "\n") {
				if strings.HasPrefix(line, "Uid:\t") {
					// First field of the Uid: line is the real UID.
					f := strings.Fields(strings.TrimPrefix(line, "Uid:\t"))
					if len(f) > 0 {
						uid = f[0]
					}
					break
				}
			}
			if uid == "" {
				uid = "unknown"
			}
			username := cachedUID(uid)
			userCounts[username]++
			total++
		}
		m.PHPProcs = total
		// Build sorted top-10 list
		type up struct {
			user  string
			count int
		}
		var ups []up
		for u, c := range userCounts {
			ups = append(ups, up{u, c})
		}
		sort.Slice(ups, func(i, j int) bool {
			return ups[i].count > ups[j].count
		})
		if len(ups) > 10 {
			ups = ups[:10]
		}
		m.TopPHPUsers = make([]userProcs, len(ups))
		for i, u := range ups {
			m.TopPHPUsers[i] = userProcs{User: u.user, Count: u.count}
		}
	}
	// MySQL: PID → VmRSS, plus Threads_connected
	{
		pidData, err := os.ReadFile("/var/run/mysqld/mysqld.pid")
		if err == nil {
			mysqlPID := strings.TrimSpace(string(pidData))
			if mysqlPID != "" {
				// #nosec G304 G703 -- mysqlPID is read from mysqld's own
				// /var/run/mysqld/mysqld.pid and we're reading the kernel
				// /proc pseudo-filesystem.
				statusData, _ := os.ReadFile(filepath.Join("/proc", mysqlPID, "status"))
				for _, line := range strings.Split(string(statusData), "\n") {
					if strings.HasPrefix(line, "VmRSS:") {
						// VmRSS is reported in kB.
						fields := strings.Fields(line)
						if len(fields) >= 2 {
							kb, _ := strconv.ParseUint(fields[1], 10, 64)
							m.MySQLMemMB = kb / 1024
						}
						break
					}
				}
			}
		}
		// Connection count
		out, err := runCmdQuick("mysql", "-N", "-B", "-e", "SHOW STATUS LIKE 'Threads_connected'")
		if err == nil && len(out) > 0 {
			fields := strings.Fields(string(out))
			if len(fields) >= 2 {
				n, _ := strconv.Atoi(fields[1])
				m.MySQLConns = n
			}
		}
	}
	// Redis: memory + keyspace
	{
		memOut, err := runCmdQuick("redis-cli", "info", "memory")
		if err == nil && len(memOut) > 0 {
			for _, line := range strings.Split(string(memOut), "\n") {
				line = strings.TrimSpace(line)
				if strings.HasPrefix(line, "used_memory:") {
					val := strings.TrimPrefix(line, "used_memory:")
					bytes, _ := strconv.ParseUint(strings.TrimSpace(val), 10, 64)
					m.RedisMemMB = bytes / (1024 * 1024)
				} else if strings.HasPrefix(line, "maxmemory:") {
					val := strings.TrimPrefix(line, "maxmemory:")
					bytes, _ := strconv.ParseUint(strings.TrimSpace(val), 10, 64)
					m.RedisMaxMB = bytes / (1024 * 1024)
				}
			}
		}
		ksOut, err := runCmdQuick("redis-cli", "info", "keyspace")
		if err == nil && len(ksOut) > 0 {
			// Sum the key counts across all databases.
			var totalKeys int64
			for _, line := range strings.Split(string(ksOut), "\n") {
				line = strings.TrimSpace(line)
				// Lines like: db0:keys=1234,expires=5,avg_ttl=0
				if !strings.HasPrefix(line, "db") {
					continue
				}
				parts := strings.SplitN(line, ":", 2)
				if len(parts) < 2 {
					continue
				}
				for _, kv := range strings.Split(parts[1], ",") {
					if strings.HasPrefix(kv, "keys=") {
						n, _ := strconv.ParseInt(strings.TrimPrefix(kv, "keys="), 10, 64)
						totalKeys += n
					}
				}
			}
			m.RedisKeys = totalKeys
		}
	}
	// Uptime from /proc/uptime
	{
		data, err := os.ReadFile("/proc/uptime")
		if err == nil {
			fields := strings.Fields(string(data))
			if len(fields) >= 1 {
				secs, _ := strconv.ParseFloat(fields[0], 64)
				d := time.Duration(secs) * time.Second
				days := int(d.Hours()) / 24
				hours := int(d.Hours()) % 24
				m.Uptime = fmt.Sprintf("%dd %dh", days, hours)
			}
		}
	}
	return m
}
// sampleMetricsLoop takes an immediate metrics sample, then refreshes
// the atomic snapshot every 10 seconds until ctx is cancelled.
func (s *Server) sampleMetricsLoop(ctx context.Context) {
	s.perfSnapshot.Store(sampleMetrics())
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			s.perfSnapshot.Store(sampleMetrics())
		case <-ctx.Done():
			return
		}
	}
}
// apiPerformance returns the latest performance snapshot plus all
// non-suppressed perf_* findings, sorted by severity (highest first)
// and capped at ?limit= (default 100, max 500).
func (s *Server) apiPerformance(w http.ResponseWriter, r *http.Request) {
	maxItems := queryInt(r, "limit", 100)
	if maxItems > 500 {
		maxItems = 500
	}
	snapshot := s.perfSnapshot.Load()
	latest := s.store.LatestFindings()
	suppressions := s.store.LoadSuppressions()
	var findings []perfFindingView
	for _, f := range latest {
		// Only performance checks, and skip anything the user suppressed.
		if !strings.HasPrefix(f.Check, "perf_") || s.store.IsSuppressed(f, suppressions) {
			continue
		}
		first, last := f.Timestamp, f.Timestamp
		if entry, ok := s.store.EntryForKey(f.Key()); ok {
			first, last = entry.FirstSeen, entry.LastSeen
		}
		findings = append(findings, perfFindingView{
			Severity:  int(f.Severity),
			SevClass:  severityClass(f.Severity),
			Check:     f.Check,
			Message:   f.Message,
			Details:   f.Details,
			FirstSeen: first.Format(time.RFC3339),
			LastSeen:  last.Format(time.RFC3339),
		})
	}
	// Most severe first.
	sort.Slice(findings, func(i, j int) bool {
		return findings[i].Severity > findings[j].Severity
	})
	if len(findings) > maxItems {
		findings = findings[:maxItems]
	}
	writeJSON(w, perfResponse{Metrics: snapshot, Findings: findings})
}
// handlePerformance renders the performance dashboard page.
func (s *Server) handlePerformance(w http.ResponseWriter, _ *http.Request) {
	// No template data needed; the page fetches metrics via /api/v1/performance.
	s.renderTemplate(w, "performance.html", nil)
}
package webui
import (
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/pidginhost/csm/internal/signatures"
"github.com/pidginhost/csm/internal/store"
"github.com/pidginhost/csm/internal/yara"
)
// handleRules renders the detection rules management page.
func (s *Server) handleRules(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "rules.html", data)
}
// GET /api/v1/rules/status
//
// apiRulesStatus reports the loaded YAML and YARA rule counts, YARA
// availability, and the signature auto-update configuration. Counts are
// zero when the corresponding global scanner is not initialized.
func (s *Server) apiRulesStatus(w http.ResponseWriter, _ *http.Request) {
	var yamlCount, yamlVersion int
	if sc := signatures.Global(); sc != nil {
		yamlCount = sc.RuleCount()
		yamlVersion = sc.Version()
	}
	var yaraCount int
	if ys := yara.Global(); ys != nil {
		yaraCount = ys.RuleCount()
	}
	writeJSON(w, map[string]interface{}{
		"yaml_rules":      yamlCount,
		"yara_rules":      yaraCount,
		"yara_available":  yara.Available(),
		"yaml_version":    yamlVersion,
		"rules_dir":       s.cfg.Signatures.RulesDir,
		"auto_update":     s.cfg.Signatures.UpdateURL != "",
		"update_url":      s.cfg.Signatures.UpdateURL,
		"update_interval": s.cfg.Signatures.UpdateInterval,
	})
}
// GET /api/v1/rules/list
//
// apiRulesList lists the rule files (*.yml, *.yaml, *.yar, *.yara) in
// the configured signatures directory with their detected type and size.
// A missing directory yields an empty list rather than an error; other
// read failures return HTTP 500.
func (s *Server) apiRulesList(w http.ResponseWriter, _ *http.Request) {
	rulesDir := s.cfg.Signatures.RulesDir
	type ruleFileInfo struct {
		Name string `json:"name"`
		Type string `json:"type"` // "yaml" or "yara"
		Size int64  `json:"size"`
	}
	// Start with a non-nil slice so an empty result encodes as JSON []
	// rather than null (a nil []T marshals to null, which the previous
	// code emitted both for a missing dir and for a dir with no rules).
	files := []ruleFileInfo{}
	entries, err := os.ReadDir(rulesDir)
	if err != nil {
		if os.IsNotExist(err) {
			writeJSON(w, files)
			return
		}
		writeJSONError(w, fmt.Sprintf("reading rules directory: %v", err), http.StatusInternalServerError)
		return
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		var fileType string
		switch strings.ToLower(filepath.Ext(name)) {
		case ".yml", ".yaml":
			fileType = "yaml"
		case ".yar", ".yara":
			fileType = "yara"
		default:
			continue // skip non-rule files
		}
		info, err := entry.Info()
		if err != nil {
			// File disappeared between ReadDir and Info; skip it.
			continue
		}
		files = append(files, ruleFileInfo{
			Name: name,
			Type: fileType,
			Size: info.Size(),
		})
	}
	writeJSON(w, files)
}
// POST /api/v1/rules/reload
//
// apiRulesReload reloads the YAML and YARA rule sets from disk and
// reports the resulting counts. Per-engine reload errors are collected
// and returned in "errors"; "ok" is true only when both succeed.
func (s *Server) apiRulesReload(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var problems []string
	yamlCount, yaraCount := 0, 0
	if sc := signatures.Global(); sc != nil {
		if err := sc.Reload(); err != nil {
			problems = append(problems, fmt.Sprintf("YAML reload: %v", err))
		}
		yamlCount = sc.RuleCount()
		// Update the cached sig count shown in dashboard/status
		s.SetSigCount(yamlCount)
	}
	if ys := yara.Global(); ys != nil {
		if err := ys.Reload(); err != nil {
			problems = append(problems, fmt.Sprintf("YARA reload: %v", err))
		}
		yaraCount = ys.RuleCount()
	}
	result := map[string]interface{}{
		"ok":         len(problems) == 0,
		"yaml_rules": yamlCount,
		"yara_rules": yaraCount,
	}
	if len(problems) > 0 {
		result["errors"] = problems
	}
	writeJSON(w, result)
}
// GET/POST /api/v1/rules/modsec-escalation - manage rules excluded from auto-block
//
// POST replaces the whole exclusion set from {"rules": [ids]}; GET
// returns the current set as {"rules": [ids]} (always a JSON array).
func (s *Server) apiModSecEscalation(w http.ResponseWriter, r *http.Request) {
	db := store.Global()
	if r.Method == http.MethodPost {
		if db == nil {
			writeJSONError(w, "Store not available", http.StatusInternalServerError)
			return
		}
		var body struct {
			Rules []int `json:"rules"`
		}
		if err := decodeJSONBodyLimited(w, r, 64*1024, &body); err != nil {
			writeJSONError(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		ruleSet := make(map[int]bool, len(body.Rules))
		for _, id := range body.Rules {
			ruleSet[id] = true
		}
		if err := db.SetModSecNoEscalateRules(ruleSet); err != nil {
			writeJSONError(w, fmt.Sprintf("Save failed: %v", err), http.StatusInternalServerError)
			return
		}
		writeJSON(w, map[string]interface{}{"ok": true, "count": len(ruleSet)})
		return
	}
	// GET: a pre-initialized slice guarantees [] instead of null when
	// the store is unavailable or the set is empty.
	ids := []int{}
	if db != nil {
		for id := range db.GetModSecNoEscalateRules() {
			ids = append(ids, id)
		}
	}
	writeJSON(w, map[string]interface{}{"rules": ids})
}
package webui
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"html/template"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pidginhost/csm/internal/alert"
"github.com/pidginhost/csm/internal/config"
"github.com/pidginhost/csm/internal/emailav"
"github.com/pidginhost/csm/internal/geoip"
"github.com/pidginhost/csm/internal/state"
)
// IPBlocker abstracts the firewall engine for block/unblock operations.
type IPBlocker interface {
	// BlockIP blocks ip, recording reason. The timeout semantics (e.g.
	// whether 0 means permanent) depend on the implementation —
	// NOTE(review): confirm against the concrete firewall engine.
	BlockIP(ip string, reason string, timeout time.Duration) error
	// UnblockIP removes an existing block for ip.
	UnblockIP(ip string) error
}
// Server is the web UI HTTP server. Serves API always; serves HTML pages
// and static files only if the UI directory exists on disk.
type Server struct {
	cfg       *config.Config
	store     *state.Store
	httpSrv   *http.Server
	templates map[string]*template.Template
	hasUI     bool   // true if UI directory with templates exists
	uiDir     string // path to UI directory on disk
	startTime time.Time
	sigCount  int // loaded signature rule count
	fanotifyActive  bool
	logWatcherCount int
	blocker IPBlocker
	// geoIPDB is swapped atomically so lookups never need a lock.
	geoIPDB atomic.Pointer[geoip.DB]
	emailQuarantine    *emailav.Quarantine
	emailAVWatcherMode string
	version            string
	// perfSnapshot holds the latest metrics from sampleMetricsLoop;
	// perfCancel stops that loop.
	perfSnapshot atomic.Pointer[perfMetrics]
	perfCancel   context.CancelFunc
	// Rate limiting
	loginMu       sync.Mutex
	loginAttempts map[string][]time.Time
	apiMu         sync.Mutex
	apiRequests   map[string][]time.Time // per-IP API rate limiting
	scanMu        sync.Mutex
	scanRunning   bool       // only one scan at a time
	modSecApplyMu sync.Mutex // serializes modsec rules apply (write+reload+rollback)
	// Graceful shutdown signal for background goroutines
	pruneDone chan struct{}
}
// New creates a new web UI server.
func New(cfg *config.Config, store *state.Store) (*Server, error) {
s := &Server{
cfg: cfg,
store: store,
startTime: time.Now(),
loginAttempts: make(map[string][]time.Time),
apiRequests: make(map[string][]time.Time),
pruneDone: make(chan struct{}),
}
// Check if UI directory exists on disk
s.uiDir = cfg.WebUI.UIDir
if s.uiDir == "" {
s.uiDir = "/opt/csm/ui"
}
funcMap := template.FuncMap{
"severityClass": severityClass,
"severityLabel": severityLabel,
"timeAgo": timeAgo,
"formatTime": formatTime,
"csrfToken": s.csrfToken,
"csmConfig": func() template.JS { return jsonForScript(s.csmConfig()) },
"json": jsonForScript,
"multiply": func(a, b int) int { return a * b },
"add": func(a, b int) int { return a + b },
"subtract": func(a, b int) int { return a - b },
"divisibleBy": func(a, b int) bool { return b != 0 && a%b == 0 },
}
// Try to load templates from disk
templateDir := filepath.Join(s.uiDir, "templates")
staticDir := filepath.Join(s.uiDir, "static")
if _, err := os.Stat(templateDir); err == nil {
s.templates = make(map[string]*template.Template)
layoutPath := filepath.Join(templateDir, "layout.html")
for _, page := range []string{"dashboard", "findings", "quarantine", "firewall", "modsec", "modsec-rules", "threat", "rules", "audit", "account", "incident", "email", "performance", "hardening"} {
pagePath := filepath.Join(templateDir, page+".html")
t, err := template.New(page+".html").Funcs(funcMap).ParseFiles(layoutPath, pagePath)
if err != nil {
return nil, fmt.Errorf("parsing template %s from %s: %w", page, templateDir, err)
}
s.templates[page+".html"] = t
}
loginPath := filepath.Join(templateDir, "login.html")
loginTmpl, err := template.New("login.html").Funcs(funcMap).ParseFiles(loginPath)
if err != nil {
return nil, fmt.Errorf("parsing login template: %w", err)
}
s.templates["login.html"] = loginTmpl
s.hasUI = true
fmt.Fprintf(os.Stderr, "WebUI: loaded templates from %s\n", templateDir)
} else {
fmt.Fprintf(os.Stderr, "WebUI: UI directory not found at %s - running in API-only mode\n", s.uiDir)
}
// Set up routes
mux := http.NewServeMux()
// Static files and HTML pages - only if UI directory exists
if s.hasUI {
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir))))
mux.HandleFunc("/login", s.handleLogin)
mux.Handle("/", s.requireAuth(http.HandlerFunc(s.handleDashboard)))
mux.Handle("/dashboard", s.requireAuth(http.HandlerFunc(s.handleDashboard)))
mux.Handle("/findings", s.requireAuth(http.HandlerFunc(s.handleFindings)))
mux.Handle("/history", s.requireAuth(http.HandlerFunc(s.handleHistoryRedirect)))
mux.Handle("/quarantine", s.requireAuth(http.HandlerFunc(s.handleQuarantine)))
mux.Handle("/blocked", s.requireAuth(http.HandlerFunc(s.handleFirewall))) // redirect old URL
mux.Handle("/firewall", s.requireAuth(http.HandlerFunc(s.handleFirewall)))
mux.Handle("/threat", s.requireAuth(http.HandlerFunc(s.handleThreat)))
mux.Handle("/rules", s.requireAuth(http.HandlerFunc(s.handleRules)))
mux.Handle("/audit", s.requireAuth(http.HandlerFunc(s.handleAudit)))
mux.Handle("/account", s.requireAuth(http.HandlerFunc(s.handleAccount)))
mux.Handle("/incident", s.requireAuth(http.HandlerFunc(s.handleIncident)))
mux.Handle("/email", s.requireAuth(http.HandlerFunc(s.handleEmail)))
mux.Handle("/performance", s.requireAuth(http.HandlerFunc(s.handlePerformance)))
mux.Handle("/hardening", s.requireAuth(http.HandlerFunc(s.handleHardening)))
mux.Handle("/modsec", s.requireAuth(http.HandlerFunc(s.handleModSec)))
mux.Handle("/modsec/rules", s.requireAuth(http.HandlerFunc(s.handleModSecRules)))
}
// Auth-protected API - read
mux.Handle("/api/v1/status", s.requireAuth(http.HandlerFunc(s.apiStatus)))
mux.Handle("/api/v1/findings", s.requireAuth(http.HandlerFunc(s.apiFindings)))
mux.Handle("/api/v1/findings/enriched", s.requireAuth(http.HandlerFunc(s.apiFindingsEnriched)))
mux.Handle("/api/v1/history", s.requireAuth(http.HandlerFunc(s.apiHistory)))
mux.Handle("/api/v1/quarantine", s.requireAuth(http.HandlerFunc(s.apiQuarantine)))
mux.Handle("/api/v1/stats", s.requireAuth(http.HandlerFunc(s.apiStats)))
mux.Handle("/api/v1/stats/trend", s.requireAuth(http.HandlerFunc(s.apiStatsTrend)))
mux.Handle("/api/v1/stats/timeline", s.requireAuth(http.HandlerFunc(s.apiStatsTimeline)))
mux.Handle("/api/v1/blocked-ips", s.requireAuth(http.HandlerFunc(s.apiBlockedIPs)))
mux.Handle("/api/v1/modsec/stats", s.requireAuth(http.HandlerFunc(s.apiModSecStats)))
mux.Handle("/api/v1/modsec/blocks", s.requireAuth(http.HandlerFunc(s.apiModSecBlocks)))
mux.Handle("/api/v1/modsec/events", s.requireAuth(http.HandlerFunc(s.apiModSecEvents)))
mux.Handle("/api/v1/modsec/rules", s.requireAuth(http.HandlerFunc(s.apiModSecRules)))
mux.Handle("/api/v1/modsec/rules/apply", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecRulesApply))))
mux.Handle("/api/v1/modsec/rules/escalation", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecRulesEscalation))))
mux.Handle("/api/v1/health", s.requireAuth(http.HandlerFunc(s.apiHealth)))
mux.Handle("/api/v1/accounts", s.requireAuth(http.HandlerFunc(s.apiAccounts)))
mux.Handle("/api/v1/account", s.requireAuth(http.HandlerFunc(s.apiAccountDetail)))
mux.Handle("/api/v1/history/csv", s.requireAuth(http.HandlerFunc(s.apiHistoryCSV)))
mux.Handle("/api/v1/export", s.requireAuth(http.HandlerFunc(s.apiExport)))
mux.Handle("/api/v1/import", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiImport))))
mux.Handle("/api/v1/incident", s.requireAuth(http.HandlerFunc(s.apiIncident)))
mux.Handle("/api/v1/email/stats", s.requireAuth(http.HandlerFunc(s.apiEmailStats)))
mux.Handle("/api/v1/email/quarantine", s.requireAuth(http.HandlerFunc(s.apiEmailQuarantineList)))
mux.Handle("/api/v1/email/quarantine/", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiEmailQuarantineAction))))
mux.Handle("/api/v1/email/av/status", s.requireAuth(http.HandlerFunc(s.apiEmailAVStatus)))
mux.Handle("/api/v1/performance", s.requireAuth(http.HandlerFunc(s.apiPerformance)))
mux.Handle("/api/v1/hardening", s.requireAuth(http.HandlerFunc(s.apiHardening)))
mux.Handle("/api/v1/hardening/run", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiHardeningRun))))
// Threat Intelligence API
mux.Handle("/api/v1/threat/stats", s.requireAuth(http.HandlerFunc(s.apiThreatStats)))
mux.Handle("/api/v1/threat/top-attackers", s.requireAuth(http.HandlerFunc(s.apiThreatTopAttackers)))
mux.Handle("/api/v1/threat/ip", s.requireAuth(http.HandlerFunc(s.apiThreatIP)))
mux.Handle("/api/v1/threat/events", s.requireAuth(http.HandlerFunc(s.apiThreatEvents)))
mux.Handle("/api/v1/threat/db-stats", s.requireAuth(http.HandlerFunc(s.apiThreatDBStats)))
mux.Handle("/api/v1/audit", s.requireAuth(http.HandlerFunc(s.apiUIAudit)))
mux.Handle("/api/v1/finding-detail", s.requireAuth(http.HandlerFunc(s.apiFindingDetail)))
mux.Handle("/api/v1/threat/whitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatWhitelistIP))))
mux.Handle("/api/v1/threat/whitelist", s.requireAuth(http.HandlerFunc(s.apiThreatWhitelist)))
mux.Handle("/api/v1/threat/unwhitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatUnwhitelistIP))))
mux.Handle("/api/v1/threat/block-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatBlockIP))))
mux.Handle("/api/v1/threat/clear-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatClearIP))))
mux.Handle("/api/v1/threat/temp-whitelist-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatTempWhitelistIP))))
mux.Handle("/api/v1/threat/bulk-action", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiThreatBulkAction))))
// Rules API
mux.Handle("/api/v1/rules/status", s.requireAuth(http.HandlerFunc(s.apiRulesStatus)))
mux.Handle("/api/v1/rules/list", s.requireAuth(http.HandlerFunc(s.apiRulesList)))
mux.Handle("/api/v1/rules/reload", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiRulesReload))))
mux.Handle("/api/v1/rules/modsec-escalation", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiModSecEscalation))))
// Suppressions API
mux.Handle("/api/v1/suppressions", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiSuppressions))))
// Firewall API
mux.Handle("/api/v1/firewall/status", s.requireAuth(http.HandlerFunc(s.apiFirewallStatus)))
mux.Handle("/api/v1/firewall/allowed", s.requireAuth(http.HandlerFunc(s.apiFirewallAllowed)))
mux.Handle("/api/v1/firewall/audit", s.requireAuth(http.HandlerFunc(s.apiFirewallAudit)))
mux.Handle("/api/v1/firewall/subnets", s.requireAuth(http.HandlerFunc(s.apiFirewallSubnets)))
mux.Handle("/api/v1/firewall/check", s.requireAuth(http.HandlerFunc(s.apiFirewallCheck)))
// GeoIP API
mux.Handle("/api/v1/geoip", s.requireAuth(http.HandlerFunc(s.apiGeoIPLookup)))
mux.Handle("/api/v1/geoip/batch", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiGeoIPBatch))))
// Auth-protected API - actions (with CSRF validation)
mux.Handle("/api/v1/fix", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFix))))
mux.Handle("/api/v1/fix-bulk", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiBulkFix))))
mux.Handle("/api/v1/scan-account", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiScanAccount))))
mux.Handle("/api/v1/test-alert", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiTestAlert))))
mux.Handle("/api/v1/block-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiBlockIP))))
mux.Handle("/api/v1/unblock-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiUnblockIP))))
mux.Handle("/api/v1/unblock-bulk", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiUnblockBulk))))
mux.Handle("/api/v1/dismiss", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiDismissFinding))))
mux.Handle("/api/v1/quarantine-preview", s.requireAuth(http.HandlerFunc(s.apiQuarantinePreview)))
mux.Handle("/api/v1/quarantine-restore", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiQuarantineRestore))))
mux.Handle("/api/v1/quarantine/bulk-delete", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiQuarantineBulkDelete))))
mux.Handle("/api/v1/firewall/deny-subnet", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallDenySubnet))))
mux.Handle("/api/v1/firewall/allow-ip", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallAllowIP))))
mux.Handle("/api/v1/firewall/remove-allow", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRemoveAllow))))
mux.Handle("/api/v1/firewall/remove-subnet", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallRemoveSubnet))))
mux.Handle("/api/v1/firewall/cphulk-clear", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallFlushCphulk))))
mux.Handle("/api/v1/firewall/flush", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallFlush))))
mux.Handle("/api/v1/firewall/unban", s.requireAuth(s.requireCSRF(http.HandlerFunc(s.apiFirewallUnban))))
// Logout (clears cookie, requires auth to prevent logout CSRF)
mux.Handle("/logout", s.requireAuth(http.HandlerFunc(s.handleLogout)))
s.httpSrv = &http.Server{
Addr: cfg.WebUI.Listen,
Handler: s.securityHeaders(mux),
ReadHeaderTimeout: 10 * time.Second, // slowloris protection
ReadTimeout: 30 * time.Second, // max time to read full request
WriteTimeout: 300 * time.Second, // account scans can take several minutes
IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 20, // 1MB
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
// Disable HTTP/2: Go's HTTP/2 implementation applies WriteTimeout
// to the entire connection, not per-stream. Long-running handlers
// (account scans ~5min) cause ERR_HTTP2_PROTOCOL_ERROR in browsers
// when the timeout fires. HTTP/1.1 handles per-request deadlines
// correctly via ResponseController.SetWriteDeadline.
NextProtos: []string{"http/1.1"},
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
},
}
return s, nil
}
// pruneLoginAttempts periodically cleans up stale rate-limit entries so the
// loginAttempts and apiRequests maps cannot grow without bound.
// It returns when s.pruneDone is closed.
func (s *Server) pruneLoginAttempts() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-s.pruneDone:
			return
		case <-ticker.C:
			// Timestamps older than one minute no longer count toward
			// either rate limiter, so they can be discarded.
			cutoff := time.Now().Add(-time.Minute)
			s.loginMu.Lock()
			pruneStaleTimestamps(s.loginAttempts, cutoff)
			s.loginMu.Unlock()
			// Also prune API rate-limit entries.
			s.apiMu.Lock()
			pruneStaleTimestamps(s.apiRequests, cutoff)
			s.apiMu.Unlock()
		}
	}
}

// pruneStaleTimestamps drops timestamps at or before cutoff from each entry
// in m, deleting entries that become empty. The caller must hold the mutex
// guarding m.
func pruneStaleTimestamps(m map[string][]time.Time, cutoff time.Time) {
	for key, stamps := range m {
		var recent []time.Time
		for _, t := range stamps {
			if t.After(cutoff) {
				recent = append(recent, t)
			}
		}
		if len(recent) == 0 {
			delete(m, key)
		} else {
			m[key] = recent
		}
	}
}
// Start starts the HTTPS server. Blocks until shutdown.
func (s *Server) Start() error {
	cert, key := s.cfg.WebUI.TLSCert, s.cfg.WebUI.TLSKey
	if cert == "" {
		// No cert configured: fall back to a self-signed pair in the state dir.
		cert = filepath.Join(s.cfg.StatePath, "webui.crt")
		key = filepath.Join(s.cfg.StatePath, "webui.key")
	}
	if err := EnsureTLSCert(cert, key, s.cfg.Hostname); err != nil {
		return fmt.Errorf("TLS cert setup: %w", err)
	}
	// Background maintenance: rate-limit pruning and performance sampling.
	go s.pruneLoginAttempts()
	perfCtx, cancel := context.WithCancel(context.Background())
	s.perfCancel = cancel
	go s.sampleMetricsLoop(perfCtx)
	fmt.Fprintf(os.Stderr, "WebUI listening on https://%s\n", s.cfg.WebUI.Listen)
	return s.httpSrv.ListenAndServeTLS(cert, key)
}
// Shutdown gracefully stops the server: it cancels the metrics sampler,
// stops the rate-limit pruner, then drains in-flight requests until ctx
// expires.
// NOTE(review): close(s.pruneDone) panics if Shutdown is called twice —
// presumably the daemon calls it exactly once; confirm at the call site.
func (s *Server) Shutdown(ctx context.Context) error {
	if s.perfCancel != nil {
		s.perfCancel()
	}
	close(s.pruneDone)
	return s.httpSrv.Shutdown(ctx)
}
// Broadcast is a no-op kept for daemon compatibility; the dashboard polls
// the JSON APIs instead of receiving pushed findings.
func (s *Server) Broadcast(_ []alert.Finding) {}
// SetSigCount records the loaded signature count reported by the status API.
func (s *Server) SetSigCount(count int) {
	s.sigCount = count
}
// HasUI reports whether UI templates were loaded from disk.
func (s *Server) HasUI() bool {
	return s.hasUI
}
// SetIPBlocker sets the firewall engine used for block/unblock operations;
// handlers treat a nil blocker as "firewall unavailable".
func (s *Server) SetIPBlocker(b IPBlocker) {
	s.blocker = b
}
// SetHealthInfo records daemon health details (fanotify state and the
// number of active log watchers) for the health API.
func (s *Server) SetHealthInfo(fanotifyActive bool, logWatchers int) {
	s.fanotifyActive = fanotifyActive
	s.logWatcherCount = logWatchers
}
// SetEmailQuarantine sets the email quarantine backing the email AV API
// endpoints.
func (s *Server) SetEmailQuarantine(q *emailav.Quarantine) {
	s.emailQuarantine = q
}
// SetEmailAVWatcherMode records the watcher mode string surfaced by the
// email AV status API.
func (s *Server) SetEmailAVWatcherMode(mode string) {
	s.emailAVWatcherMode = mode
}
// SetVersion sets the application version string displayed in the UI.
func (s *Server) SetVersion(v string) {
	s.version = v
}
// csmConfigJSON returns the frontend feature-flag map pre-marshaled as a
// JSON string. The Marshal error is deliberately ignored: csmConfig only
// contains strings, bools, and a map of string constants, which cannot
// fail to encode.
func (s *Server) csmConfigJSON() string {
	b, _ := json.Marshal(s.csmConfig())
	return string(b)
}
// csmConfig returns the feature-flag map used by the frontend to decide
// which panels and actions to render. Template rendering goes through
// jsonForScript; the csmConfigJSON wrapper exists for code paths that need
// a pre-marshaled JSON string. Flags are derived from config presence
// (e.g. threatIntel is on only when an AbuseIPDB key is configured).
func (s *Server) csmConfig() map[string]interface{} {
	return map[string]interface{}{
		"version":      s.version,
		"emailAV":      s.cfg.EmailAV.Enabled,
		"firewall":     s.cfg.Firewall != nil && s.cfg.Firewall.Enabled,
		"autoResponse": s.cfg.AutoResponse.Enabled,
		"threatIntel":  s.cfg.Reputation.AbuseIPDBKey != "",
		"signatures":   s.cfg.Signatures.RulesDir != "",
		"challenge":    s.cfg.Challenge.Difficulty > 0,
		"hostname":     s.cfg.Hostname,
		// #nosec G101 -- Not credentials. This is a lookup from
		// finding-type ID (waf_block, credential_leak, etc.) to the
		// human-readable label rendered in the UI.
		"checkNames": map[string]string{
			"waf_block":                      "WAF Block",
			"brute_force":                    "Brute Force",
			"webshell":                       "Web Shell",
			"phishing":                       "Phishing",
			"spam":                           "Spam",
			"cpanel_login":                   "cPanel Login",
			"file_upload":                    "File Upload",
			"recon":                          "Reconnaissance",
			"c2":                             "C2 Communication",
			"other":                          "Other",
			"perf_load":                      "Load",
			"perf_php_processes":             "PHP Processes",
			"perf_memory":                    "Memory",
			"perf_php_handler":               "PHP Handler",
			"perf_mysql_config":              "MySQL Config",
			"perf_redis_config":              "Redis Config",
			"perf_error_logs":                "Error Logs",
			"perf_wp_config":                 "WP Config",
			"perf_wp_transients":             "WP Transients",
			"perf_wp_cron":                   "WP Cron",
			"integrity":                      "Integrity",
			"db_siteurl_hijack":              "DB URL Hijack",
			"db_options_injection":           "DB Options Injection",
			"db_post_injection":              "DB Post Injection",
			"db_spam_injection":              "DB Spam Injection",
			"db_rogue_admin":                 "DB Rogue Admin",
			"db_suspicious_admin_email":      "DB Suspicious Admin",
			"mail_queue":                     "Mail Queue",
			"mail_per_account":               "Mail Volume",
			"email_phishing_content":         "Email Phishing",
			"email_malware":                  "Email Malware",
			"email_compromised_account":      "Compromised Account",
			"email_spam_outbreak":            "Spam Outbreak",
			"email_credential_leak":          "Credential Leak",
			"email_auth_failure_realtime":    "Auth Failure",
			"smtp_bruteforce":                "SMTP Brute Force",
			"smtp_subnet_spray":              "SMTP Subnet Spray",
			"smtp_account_spray":             "SMTP Account Spray",
			"mail_bruteforce":                "Mail Brute Force",
			"mail_subnet_spray":              "Mail Subnet Spray",
			"admin_panel_bruteforce":         "Admin Panel Brute Force",
			"mail_account_spray":             "Mail Account Spray",
			"mail_account_compromised":       "Mail Account Compromised",
			"exim_frozen_realtime":           "Frozen Message",
			"email_suspicious_geo":           "Suspicious Geo Login",
			"email_rate_critical":            "Email Rate Critical",
			"email_rate_warning":             "Email Rate Warning",
			"email_dkim_failure":             "DKIM Failure",
			"email_spf_rejection":            "SPF Rejection",
			"email_pipe_forwarder":           "Pipe Forwarder",
			"email_suspicious_forwarder":     "Suspicious Forwarder",
			"cpanel_login_realtime":          "cPanel Login",
			"cpanel_password_purge_realtime": "Password Purge",
			"ssh_login_realtime":             "SSH Login",
			"pam_login":                      "PAM Login",
			"pam_bruteforce":                 "PAM Brute Force",
			"modsec_csm_block_escalation":    "ModSec Escalation",
			"whm_password_change_noninfra":   "WHM Password Change",
			"password_hijack_confirmed":      "Password Hijack",
		},
	}
}
// --- Authentication ---

// requireAuth gates next behind authentication. Unauthenticated API calls
// receive a 401 JSON error; unauthenticated page loads are redirected to
// the login page.
func (s *Server) requireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !s.isAuthenticated(r) {
			if strings.HasPrefix(r.URL.Path, "/api/") {
				writeJSONError(w, "Unauthorized", http.StatusUnauthorized)
			} else {
				http.Redirect(w, r, "/login", http.StatusFound)
			}
			return
		}
		next.ServeHTTP(w, r)
	})
}
// isAuthenticated reports whether r carries a valid credential: either an
// "Authorization: Bearer <token>" header or the csm_auth session cookie
// matching the configured auth token. All comparisons are constant-time.
func (s *Server) isAuthenticated(r *http.Request) bool {
	want := s.cfg.WebUI.AuthToken
	if want == "" {
		// No token configured means nobody can authenticate.
		return false
	}
	// Authorization header takes precedence.
	auth := r.Header.Get("Authorization")
	if len(auth) > 7 && auth[:7] == "Bearer " {
		if subtle.ConstantTimeCompare([]byte(auth[7:]), []byte(want)) == 1 {
			return true
		}
	}
	// Fall back to the browser session cookie.
	cookie, err := r.Cookie("csm_auth")
	return err == nil && subtle.ConstantTimeCompare([]byte(cookie.Value), []byte(want)) == 1
}
// handleLogin serves the login page (GET) and processes token logins (POST).
// POSTs are rate-limited to 5 attempts per minute per client IP. On success
// it sets the csm_auth session cookie for 24 hours and redirects to the
// dashboard.
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	// Redirect already-authenticated users to dashboard.
	if s.isAuthenticated(r) {
		http.Redirect(w, r, "/dashboard", http.StatusFound)
		return
	}
	if r.Method == http.MethodGet {
		s.renderTemplate(w, "login.html", nil)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Refuse logins outright when no auth token is configured. Without this
	// guard an empty submitted token would constant-time-compare equal to
	// an empty configured token and set a useless session cookie
	// (isAuthenticated rejects an empty token, causing a redirect loop).
	if s.cfg.WebUI.AuthToken == "" {
		http.Error(w, "Login disabled: no auth token configured", http.StatusServiceUnavailable)
		return
	}
	// Rate limit: 5 attempts per minute per IP (strip port from RemoteAddr).
	ip := r.RemoteAddr
	if host, _, err := net.SplitHostPort(ip); err == nil {
		ip = host
	}
	s.loginMu.Lock()
	now := time.Now()
	var recent []time.Time
	for _, t := range s.loginAttempts[ip] {
		if now.Sub(t) < time.Minute {
			recent = append(recent, t)
		}
	}
	if len(recent) >= 5 {
		s.loginMu.Unlock()
		http.Error(w, "Too many login attempts", http.StatusTooManyRequests)
		return
	}
	s.loginAttempts[ip] = append(recent, now)
	s.loginMu.Unlock()
	token := r.FormValue("token")
	if subtle.ConstantTimeCompare([]byte(token), []byte(s.cfg.WebUI.AuthToken)) != 1 {
		s.renderTemplate(w, "login.html", map[string]string{"Error": "Invalid token"})
		return
	}
	// Session cookie holds the auth token itself; isAuthenticated compares it.
	http.SetCookie(w, &http.Cookie{
		Name:     "csm_auth",
		Value:    token,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		MaxAge:   86400, // 24 hours
	})
	http.Redirect(w, r, "/dashboard", http.StatusFound)
}
// --- Template helpers ---

// severityClass maps a finding severity to the CSS class used by the UI;
// anything unrecognized falls back to "info".
func severityClass(sev alert.Severity) string {
	switch sev {
	case alert.Warning:
		return "warning"
	case alert.High:
		return "high"
	case alert.Critical:
		return "critical"
	}
	return "info"
}
// severityLabel converts a finding severity into the uppercase label shown
// in the dashboard; unknown severities render as "INFO".
func severityLabel(sev alert.Severity) string {
	switch sev {
	case alert.Warning:
		return "WARNING"
	case alert.High:
		return "HIGH"
	case alert.Critical:
		return "CRITICAL"
	}
	return "INFO"
}
// severityRank orders severity labels numerically so findings can be sorted;
// higher means more severe. Unknown labels rank lowest (0).
func severityRank(label string) int {
	for i, name := range []string{"WARNING", "HIGH", "CRITICAL"} {
		if label == name {
			return i + 1
		}
	}
	return 0
}
func timeAgo(t time.Time) string {
d := time.Since(t)
switch {
case d < time.Minute:
return "just now"
case d < time.Hour:
return fmt.Sprintf("%dm ago", int(d.Minutes()))
case d < 24*time.Hour:
return fmt.Sprintf("%dh ago", int(d.Hours()))
default:
return fmt.Sprintf("%dd ago", int(d.Hours()/24))
}
}
func formatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05")
}
// --- Security headers middleware ---

// securityHeaders is the outermost middleware. It stamps hardening headers
// on every response, enforces a same-origin policy for /api/ requests, and
// applies a per-IP rate limit (600 requests/minute) to the API.
func (s *Server) securityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Frame-Options", "DENY")
		w.Header().Set("X-Content-Type-Options", "nosniff")
		w.Header().Set("X-XSS-Protection", "1; mode=block")
		w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
		w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; connect-src 'self'; img-src 'self' data:; font-src 'self'")
		w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
		w.Header().Set("Cache-Control", "no-store")
		// CORS/origin validation: reject cross-origin API requests.
		if strings.HasPrefix(r.URL.Path, "/api/") {
			origin := r.Header.Get("Origin")
			if origin != "" {
				// Only allow same-origin - use r.Host which includes port.
				host := r.Host
				if host == "" {
					// Fallback: build from config (hostname + listen port).
					host = s.cfg.Hostname
					if listen := s.cfg.WebUI.Listen; listen != "" {
						if idx := strings.LastIndex(listen, ":"); idx >= 0 {
							port := listen[idx+1:]
							if port != "443" {
								host = host + ":" + port
							}
						}
					}
				}
				allowed := "https://" + host
				if origin != allowed {
					http.Error(w, "Cross-origin request blocked", http.StatusForbidden)
					return
				}
				w.Header().Set("Access-Control-Allow-Origin", allowed)
				w.Header().Set("Access-Control-Allow-Credentials", "true")
			}
			// Deny CORS preflight from unknown origins.
			if r.Method == "OPTIONS" {
				w.WriteHeader(http.StatusNoContent)
				return
			}
		}
		// API rate limiting: 600 requests per minute per IP.
		if strings.HasPrefix(r.URL.Path, "/api/") {
			// Strip the port with net.SplitHostPort so keys match the login
			// rate limiter in handleLogin. The previous LastIndex(":")
			// approach produced "[::1]" for IPv6 while handleLogin produced
			// "::1", splitting one client across two map keys.
			ip := r.RemoteAddr
			if host, _, err := net.SplitHostPort(ip); err == nil {
				ip = host
			}
			s.apiMu.Lock()
			now := time.Now()
			cutoff := now.Add(-time.Minute)
			var recent []time.Time
			for _, t := range s.apiRequests[ip] {
				if t.After(cutoff) {
					recent = append(recent, t)
				}
			}
			if len(recent) >= 600 {
				s.apiMu.Unlock()
				http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
				return
			}
			s.apiRequests[ip] = append(recent, now)
			s.apiMu.Unlock()
		}
		next.ServeHTTP(w, r)
	})
}
// --- CSRF protection ---

// csrfToken generates a deterministic CSRF token from the auth token.
// This is safe because the auth token is secret and the CSRF token is
// derived via HMAC-SHA256 - knowing the CSRF token doesn't reveal the
// auth token. Folding in the daemon start time rotates the token on each
// restart.
func (s *Server) csrfToken() string {
	mac := hmac.New(sha256.New, []byte(s.cfg.WebUI.AuthToken))
	mac.Write([]byte(fmt.Sprintf("csm-csrf-v1:%d", s.startTime.Unix())))
	digest := hex.EncodeToString(mac.Sum(nil))
	return digest[:32]
}
// validateCSRF checks the CSRF token on state-changing (POST/DELETE)
// requests. The token may arrive in the X-CSRF-Token header (API calls
// from JS) or the csrf_token form field (traditional form posts).
// Bearer-token requests are exempt: presenting the secret token already
// proves identity, and CSRF only matters for cookie-based browser sessions.
func (s *Server) validateCSRF(r *http.Request) bool {
	switch r.Method {
	case http.MethodPost, http.MethodDelete:
		// state-changing: fall through to validation
	default:
		return true // only POST and DELETE are validated
	}
	auth := r.Header.Get("Authorization")
	if len(auth) > 7 && auth[:7] == "Bearer " {
		if subtle.ConstantTimeCompare([]byte(auth[7:]), []byte(s.cfg.WebUI.AuthToken)) == 1 {
			return true
		}
	}
	expected := []byte(s.csrfToken())
	// Header first (API calls from JS use this).
	if token := r.Header.Get("X-CSRF-Token"); token != "" {
		return subtle.ConstantTimeCompare([]byte(token), expected) == 1
	}
	// Then the form field (traditional form posts).
	if token := r.FormValue("csrf_token"); token != "" {
		return subtle.ConstantTimeCompare([]byte(token), expected) == 1
	}
	return false
}
// requireCSRF wraps next and rejects state-changing requests that fail CSRF
// validation. validateCSRF already passes through non-POST/DELETE methods
// and Bearer-authenticated calls (API-to-API calls don't need CSRF
// protection), so a single check covers every case the original condition
// spelled out explicitly.
func (s *Server) requireCSRF(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !s.validateCSRF(r) {
			http.Error(w, "Invalid CSRF token", http.StatusForbidden)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// isBearerAuth reports whether r carries a non-empty "Authorization:
// Bearer <token>" header matching the configured auth token
// (constant-time comparison).
func (s *Server) isBearerAuth(r *http.Request) bool {
	const prefix = "Bearer "
	auth := r.Header.Get("Authorization")
	if len(auth) <= len(prefix) || !strings.HasPrefix(auth, prefix) {
		return false
	}
	presented := []byte(auth[len(prefix):])
	return subtle.ConstantTimeCompare(presented, []byte(s.cfg.WebUI.AuthToken)) == 1
}
// --- Logout ---

// handleLogout expires the csm_auth session cookie and sends the browser
// back to the login page. It is registered behind requireAuth to prevent
// logout CSRF.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	expired := &http.Cookie{
		Name:     "csm_auth",
		Value:    "",
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
		MaxAge:   -1, // MaxAge<0 tells the browser to delete the cookie
	}
	http.SetCookie(w, expired)
	http.Redirect(w, r, "/login", http.StatusFound)
}
// --- Scan rate limiting ---

// acquireScan tries to claim the single scan slot. Returns false if a scan
// is already running; on success the caller owns the slot and must call
// releaseScan when the scan finishes.
func (s *Server) acquireScan() bool {
	s.scanMu.Lock()
	defer s.scanMu.Unlock()
	if s.scanRunning {
		return false
	}
	s.scanRunning = true
	return true
}
// releaseScan frees the scan slot claimed by acquireScan.
func (s *Server) releaseScan() {
	s.scanMu.Lock()
	defer s.scanMu.Unlock()
	s.scanRunning = false
}
package webui
import (
"crypto/rand"
"encoding/hex"
"net/http"
"time"
"github.com/pidginhost/csm/internal/state"
)
// apiSuppressions handles /api/v1/suppressions:
//
//	GET    - list all suppression rules
//	POST   - create a rule (requires "check"; optional path pattern + reason)
//	DELETE - remove a rule by ID
func (s *Server) apiSuppressions(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		rules := s.store.LoadSuppressions()
		if rules == nil {
			rules = []state.SuppressionRule{} // encode as [] rather than null
		}
		writeJSON(w, rules)
	case http.MethodPost:
		var req struct {
			Check       string `json:"check"`
			PathPattern string `json:"path_pattern"`
			Reason      string `json:"reason"`
		}
		if err := decodeJSONBodyLimited(w, r, 32*1024, &req); err != nil || req.Check == "" {
			writeJSONError(w, "check field is required", http.StatusBadRequest)
			return
		}
		// Generate a random unique ID used for later deletion.
		b := make([]byte, 8)
		_, _ = rand.Read(b) // crypto/rand.Read does not fail in practice
		id := hex.EncodeToString(b)
		rules := s.store.LoadSuppressions()
		rules = append(rules, state.SuppressionRule{
			ID:          id,
			Check:       req.Check,
			PathPattern: req.PathPattern,
			Reason:      req.Reason,
			CreatedAt:   time.Now(),
		})
		if err := s.store.SaveSuppressions(rules); err != nil {
			writeJSONError(w, "failed to save suppression: "+err.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "suppress", req.Check, "pattern: "+req.PathPattern)
		writeJSON(w, map[string]string{"status": "created", "id": id})
	case http.MethodDelete:
		var req struct {
			ID string `json:"id"`
		}
		if err := decodeJSONBodyLimited(w, r, 16*1024, &req); err != nil || req.ID == "" {
			writeJSONError(w, "id is required", http.StatusBadRequest)
			return
		}
		rules := s.store.LoadSuppressions()
		var filtered []state.SuppressionRule
		found := false
		for _, rule := range rules {
			if rule.ID == req.ID {
				found = true
				continue
			}
			filtered = append(filtered, rule)
		}
		// Previously an unknown ID silently rewrote the file, logged a bogus
		// audit entry, and reported success; report 404 instead so clients
		// can surface the mistake.
		if !found {
			writeJSONError(w, "suppression rule not found", http.StatusNotFound)
			return
		}
		if err := s.store.SaveSuppressions(filtered); err != nil {
			writeJSONError(w, "failed to save suppressions: "+err.Error(), http.StatusInternalServerError)
			return
		}
		s.auditLog(r, "unsuppress", req.ID, "removed suppression rule")
		writeJSON(w, map[string]string{"status": "deleted"})
	default:
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
package webui
import (
"fmt"
"net"
"net/http"
"time"
"github.com/pidginhost/csm/internal/attackdb"
"github.com/pidginhost/csm/internal/checks"
"github.com/pidginhost/csm/internal/threat"
)
// handleThreat renders the Threat Intelligence page.
func (s *Server) handleThreat(w http.ResponseWriter, _ *http.Request) {
	data := map[string]string{"Hostname": s.cfg.Hostname}
	s.renderTemplate(w, "threat.html", data)
}
// apiThreatStats serves GET /api/v1/threat/stats with aggregate attack-DB
// statistics, or an error object if the DB was never initialized.
func (s *Server) apiThreatStats(w http.ResponseWriter, r *http.Request) {
	if adb := attackdb.Global(); adb != nil {
		writeJSON(w, adb.Stats())
		return
	}
	writeJSON(w, map[string]string{"error": "attack database not initialized"})
}
// apiThreatTopAttackers serves GET /api/v1/threat/top-attackers?limit=25.
// It returns the top attacking IPs from the attack DB (capped at 200),
// enriched with unified threat-intel scores and, when the GeoIP DB is
// loaded, country/AS information.
func (s *Server) apiThreatTopAttackers(w http.ResponseWriter, r *http.Request) {
	adb := attackdb.Global()
	if adb == nil {
		writeJSON(w, []struct{}{})
		return
	}
	limit := queryInt(r, "limit", 25)
	if limit > 200 {
		limit = 200
	}
	recs := adb.TopAttackers(limit)
	// Enrich with unified intelligence via one batch lookup for all IPs.
	// Assumes LookupBatch returns exactly one entry per input IP — the
	// indexed access below relies on it; confirm in the threat package.
	ips := make([]string, len(recs))
	for i, rec := range recs {
		ips[i] = rec.IP
	}
	intels := threat.LookupBatch(ips, s.cfg.StatePath)
	type enriched struct {
		*attackdb.IPRecord
		UnifiedScore int    `json:"unified_score"`
		Verdict      string `json:"verdict"`
		AbuseScore   int    `json:"abuse_score"`
		InThreatDB   bool   `json:"in_threat_db"`
		Blocked      bool   `json:"currently_blocked"`
		Country      string `json:"country,omitempty"`
		ASOrg        string `json:"as_org,omitempty"`
	}
	// Load the GeoIP handle once instead of on every loop iteration; the
	// per-record load was loop-invariant work.
	gdb := s.geoIPDB.Load()
	results := make([]enriched, len(recs))
	for i, rec := range recs {
		results[i] = enriched{
			IPRecord:     rec,
			UnifiedScore: intels[i].UnifiedScore,
			Verdict:      intels[i].Verdict,
			AbuseScore:   intels[i].AbuseScore,
			InThreatDB:   intels[i].InThreatDB,
			Blocked:      intels[i].CurrentlyBlocked,
		}
		if gdb != nil {
			geo := gdb.Lookup(rec.IP)
			results[i].Country = geo.Country
			results[i].ASOrg = geo.ASOrg
		}
	}
	writeJSON(w, results)
}
// apiThreatIP serves GET /api/v1/threat/ip?ip=1.2.3.4 with unified threat
// intelligence for a single IP, enriched with GeoIP data when available.
func (s *Server) apiThreatIP(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	// ParseIP returns nil for "" too, so one check covers missing + invalid.
	if net.ParseIP(ip) == nil {
		writeJSONError(w, "invalid or missing ip parameter", http.StatusBadRequest)
		return
	}
	intel := threat.Lookup(ip, s.cfg.StatePath)
	// Enrich with GeoIP data if the database is loaded.
	if gdb := s.geoIPDB.Load(); gdb != nil {
		geo := gdb.Lookup(ip)
		intel.Country = geo.Country
		intel.CountryName = geo.CountryName
		intel.City = geo.City
		intel.ASN = geo.ASN
		intel.ASOrg = geo.ASOrg
		intel.Network = geo.Network
	}
	writeJSON(w, intel)
}
// apiThreatEvents serves GET /api/v1/threat/events?ip=1.2.3.4&limit=50 with
// the recorded attack events for one IP (limit capped at 500).
func (s *Server) apiThreatEvents(w http.ResponseWriter, r *http.Request) {
	ip := r.URL.Query().Get("ip")
	// ParseIP returns nil for "" too, so one check covers missing + invalid.
	if net.ParseIP(ip) == nil {
		writeJSONError(w, "invalid or missing ip parameter", http.StatusBadRequest)
		return
	}
	const maxEvents = 500
	limit := queryInt(r, "limit", 50)
	if limit > maxEvents {
		limit = maxEvents
	}
	adb := attackdb.Global()
	if adb == nil {
		writeJSON(w, []struct{}{})
		return
	}
	events := adb.QueryEvents(ip, limit)
	if events == nil {
		events = []attackdb.Event{} // encode as [] rather than null
	}
	writeJSON(w, events)
}
// apiThreatDBStats serves GET /api/v1/threat/db-stats with summary stats
// for the threat DB and the attack DB — whichever of the two are
// initialized; missing DBs are simply omitted from the response object.
func (s *Server) apiThreatDBStats(w http.ResponseWriter, r *http.Request) {
	stats := map[string]interface{}{}
	if tdb := checks.GetThreatDB(); tdb != nil {
		stats["threat_db"] = tdb.Stats()
	}
	if adb := attackdb.Global(); adb != nil {
		stats["attack_db"] = map[string]interface{}{
			"total_ips": adb.TotalIPs(),
			"top_line":  adb.FormatTopLine(),
		}
	}
	writeJSON(w, stats)
}
// apiThreatWhitelistIP handles POST /api/v1/threat/whitelist-ip - mark an
// IP as a known customer. Best-effort sequence: unblock from the firewall,
// add to the firewall allow list, remove from the threat DB permanent
// blocklist and the attack DB, whitelist it, and flush cPanel Hulk state.
// Each step that succeeds is reported back in the "actions" array.
func (s *Server) apiThreatWhitelistIP(w http.ResponseWriter, r *http.Request) {
	// Use the http.MethodPost constant for consistency with
	// apiThreatBulkAction (the file previously mixed "POST" literals in).
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	var actions []string
	// 1. Unblock from firewall.
	if s.blocker != nil {
		if err := s.blocker.UnblockIP(req.IP); err == nil {
			actions = append(actions, "unblocked from firewall")
		}
		// Also add to firewall allow list so it doesn't get re-blocked.
		if allower, ok := s.blocker.(interface {
			AllowIP(string, string) error
		}); ok {
			if err := allower.AllowIP(req.IP, "CSM whitelist: customer IP"); err == nil {
				actions = append(actions, "added to firewall allow list")
			}
		}
	}
	// 2. Remove from threat DB permanent blocklist + add to whitelist.
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(req.IP)
		tdb.AddWhitelist(req.IP)
		actions = append(actions, "removed from threat DB, added to whitelist")
	}
	// 3. Remove from attack DB.
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(req.IP)
		actions = append(actions, "removed from attack DB")
	}
	// 4. Flush cphulk.
	flushCphulk(req.IP)
	s.auditLog(r, "whitelist_ip", req.IP, "permanent whitelist")
	writeJSON(w, map[string]interface{}{
		"status":  "whitelisted",
		"ip":      req.IP,
		"actions": actions,
	})
}
// apiThreatWhitelist serves GET /api/v1/threat/whitelist - all whitelisted
// IPs, or an empty list when the threat DB is not initialized.
func (s *Server) apiThreatWhitelist(w http.ResponseWriter, r *http.Request) {
	if tdb := checks.GetThreatDB(); tdb != nil {
		writeJSON(w, tdb.WhitelistedIPs())
		return
	}
	writeJSON(w, []string{})
}
// apiThreatUnwhitelistIP handles POST /api/v1/threat/unwhitelist-ip -
// remove an IP from the threat DB whitelist and from the firewall allow
// list, so it becomes eligible for blocking again.
func (s *Server) apiThreatUnwhitelistIP(w http.ResponseWriter, r *http.Request) {
	// http.MethodPost constant for consistency with apiThreatBulkAction.
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemoveWhitelist(req.IP)
	}
	// Also remove from firewall allow list (best-effort).
	if s.blocker != nil {
		if remover, ok := s.blocker.(interface {
			RemoveAllowIP(string) error
		}); ok {
			_ = remover.RemoveAllowIP(req.IP)
		}
	}
	// Audit like every other threat mutation endpoint; this handler
	// previously left no audit trail.
	s.auditLog(r, "unwhitelist_ip", req.IP, "removed from whitelist")
	writeJSON(w, map[string]string{"status": "removed", "ip": req.IP})
}
// apiThreatBlockIP handles POST /api/v1/threat/block-ip - manually block an
// IP for 24 hours. The firewall block is mandatory (503 when no engine is
// available, 500 on failure); the threat-DB and attack-DB records are
// best-effort additions reported in "actions".
func (s *Server) apiThreatBlockIP(w http.ResponseWriter, r *http.Request) {
	// http.MethodPost constant for consistency with apiThreatBulkAction.
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	var actions []string
	// 1. Block in firewall with 24h expiry.
	if s.blocker != nil {
		if err := s.blocker.BlockIP(req.IP, "Manually blocked via CSM Web UI", 24*time.Hour); err != nil {
			writeJSONError(w, fmt.Sprintf("block failed: %v", err), http.StatusInternalServerError)
			return
		}
		actions = append(actions, "blocked in firewall for 24h")
	} else {
		writeJSONError(w, "firewall engine not available", http.StatusServiceUnavailable)
		return
	}
	// 2. Add to threat DB permanent blocklist.
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.AddPermanent(req.IP, "Manually blocked via CSM Web UI")
		actions = append(actions, "added to threat DB")
	}
	// 3. Record in attack DB.
	if adb := attackdb.Global(); adb != nil {
		adb.MarkBlocked(req.IP)
		actions = append(actions, "recorded in attack DB")
	}
	s.auditLog(r, "block_ip", req.IP, "manual block 24h")
	writeJSON(w, map[string]interface{}{
		"status":  "blocked",
		"ip":      req.IP,
		"actions": actions,
	})
}
// apiThreatClearIP handles POST /api/v1/threat/clear-ip - unblock and clear
// an IP from all databases WITHOUT whitelisting it. Intended for dynamic-IP
// customers: a one-time cleanup after which the IP can be re-blocked later.
func (s *Server) apiThreatClearIP(w http.ResponseWriter, r *http.Request) {
	// http.MethodPost constant for consistency with apiThreatBulkAction.
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP string `json:"ip"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	var actions []string
	// 1. Unblock from firewall (but don't add to allow list).
	if s.blocker != nil {
		if err := s.blocker.UnblockIP(req.IP); err == nil {
			actions = append(actions, "unblocked from firewall")
		}
	}
	// 2. Remove from threat DB permanent blocklist (but don't whitelist).
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(req.IP)
		actions = append(actions, "removed from threat DB")
	}
	// 3. Remove from attack DB.
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(req.IP)
		actions = append(actions, "removed from attack DB")
	}
	// 4. Flush cphulk.
	flushCphulk(req.IP)
	actions = append(actions, "flushed cPanel login history")
	s.auditLog(r, "clear_ip", req.IP, "unblock & clear")
	writeJSON(w, map[string]interface{}{
		"status":  "cleared",
		"ip":      req.IP,
		"actions": actions,
	})
}
// apiThreatTempWhitelistIP handles POST /api/v1/threat/temp-whitelist-ip -
// whitelist an IP for a limited duration. Hours defaults to 24 and is
// clamped to 168 (7 days). Best-effort sequence mirrors the permanent
// whitelist handler but uses temporary allow/whitelist entries.
func (s *Server) apiThreatTempWhitelistIP(w http.ResponseWriter, r *http.Request) {
	// http.MethodPost constant for consistency with apiThreatBulkAction.
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IP    string `json:"ip"`
		Hours int    `json:"hours"` // default 24
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil || req.IP == "" {
		writeJSONError(w, "IP is required", http.StatusBadRequest)
		return
	}
	if _, err := parseAndValidateIP(req.IP); err != nil {
		writeJSONError(w, err.Error(), http.StatusBadRequest)
		return
	}
	if req.Hours <= 0 {
		req.Hours = 24
	}
	if req.Hours > 168 { // max 7 days
		req.Hours = 168
	}
	ttl := time.Duration(req.Hours) * time.Hour
	var actions []string
	// 1. Unblock from firewall.
	if s.blocker != nil {
		if err := s.blocker.UnblockIP(req.IP); err == nil {
			actions = append(actions, "unblocked from firewall")
		}
		// Temp allow in firewall too.
		if allower, ok := s.blocker.(interface {
			TempAllowIP(string, string, time.Duration) error
		}); ok {
			if err := allower.TempAllowIP(req.IP, "CSM temp whitelist", ttl); err == nil {
				actions = append(actions, fmt.Sprintf("temp allowed in firewall for %dh", req.Hours))
			}
		}
	}
	// 2. Remove from threat DB + temp whitelist.
	if tdb := checks.GetThreatDB(); tdb != nil {
		tdb.RemovePermanent(req.IP)
		tdb.TempWhitelist(req.IP, ttl)
		actions = append(actions, fmt.Sprintf("temp whitelisted for %dh", req.Hours))
	}
	// 3. Remove from attack DB.
	if adb := attackdb.Global(); adb != nil {
		adb.RemoveIP(req.IP)
		actions = append(actions, "removed from attack DB")
	}
	// 4. Flush cphulk.
	flushCphulk(req.IP)
	s.auditLog(r, "temp_whitelist_ip", req.IP, fmt.Sprintf("%dh temp whitelist", req.Hours))
	writeJSON(w, map[string]interface{}{
		"status":  "temp_whitelisted",
		"ip":      req.IP,
		"hours":   req.Hours,
		"actions": actions,
	})
}
// apiThreatBulkAction handles POST /api/v1/threat/bulk-action - block or
// whitelist up to 100 IPs in one request. Invalid IPs are silently skipped;
// the response reports how many IPs were processed.
// NOTE(review): unlike apiThreatBlockIP, the "block" branch still updates
// the threat DB and increments the count when s.blocker is nil — presumably
// acceptable as a record-only block; confirm intended.
func (s *Server) apiThreatBulkAction(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeJSONError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req struct {
		IPs    []string `json:"ips"`
		Action string   `json:"action"`
	}
	if err := decodeJSONBodyLimited(w, r, 64*1024, &req); err != nil {
		writeJSONError(w, "Invalid request body", http.StatusBadRequest)
		return
	}
	if len(req.IPs) == 0 || len(req.IPs) > 100 {
		writeJSONError(w, "IPs must be 1-100 items", http.StatusBadRequest)
		return
	}
	if req.Action != "block" && req.Action != "whitelist" {
		writeJSONError(w, "Action must be 'block' or 'whitelist'", http.StatusBadRequest)
		return
	}
	count := 0
	for _, ipStr := range req.IPs {
		// Skip (don't fail the batch on) unparsable entries.
		if _, err := parseAndValidateIP(ipStr); err != nil {
			continue
		}
		switch req.Action {
		case "block":
			// Mirror apiThreatBlockIP flow; a firewall failure skips the
			// IP entirely (no DB records, not counted).
			if s.blocker != nil {
				if err := s.blocker.BlockIP(ipStr, "Bulk blocked via CSM Web UI", 24*time.Hour); err != nil {
					continue
				}
			}
			if tdb := checks.GetThreatDB(); tdb != nil {
				tdb.AddPermanent(ipStr, "Bulk blocked via CSM Web UI")
			}
			if adb := attackdb.Global(); adb != nil {
				adb.MarkBlocked(ipStr)
			}
			count++
		case "whitelist":
			// Mirror apiThreatWhitelistIP flow; every step is best-effort.
			if s.blocker != nil {
				_ = s.blocker.UnblockIP(ipStr)
				if allower, ok := s.blocker.(interface {
					AllowIP(string, string) error
				}); ok {
					_ = allower.AllowIP(ipStr, "CSM bulk whitelist")
				}
			}
			if tdb := checks.GetThreatDB(); tdb != nil {
				tdb.RemovePermanent(ipStr)
				tdb.AddWhitelist(ipStr)
			}
			if adb := attackdb.Global(); adb != nil {
				adb.RemoveIP(ipStr)
			}
			flushCphulk(ipStr)
			count++
		}
	}
	s.auditLog(r, "threat_bulk_"+req.Action, fmt.Sprintf("%d IPs", count), "")
	writeJSON(w, map[string]interface{}{"ok": true, "count": count})
}
// writeJSON is defined in api.go
package webui
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"time"
)
// EnsureTLSCert generates a self-signed ECDSA P-256 certificate if the
// cert and key files don't exist. Includes localhost and the server hostname
// in the certificate SANs.
// EnsureTLSCert generates a self-signed ECDSA P-256 certificate if the
// cert and key files don't exist. Includes localhost and the server hostname
// in the certificate SANs.
func EnsureTLSCert(certPath, keyPath string, extraNames ...string) error {
	// If both exist, nothing to do. (If only one exists, both are
	// regenerated so the pair always matches.)
	if fileExists(certPath) && fileExists(keyPath) {
		return nil
	}
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("generating key: %w", err)
	}
	// 128-bit random serial. The error was previously dropped, which would
	// have produced a nil SerialNumber (an invalid certificate) if the
	// system RNG ever failed.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return fmt.Errorf("generating serial number: %w", err)
	}
	// Use first extra name (hostname) as CN, fall back to localhost
	cn := "localhost"
	if len(extraNames) > 0 && extraNames[0] != "" {
		cn = extraNames[0]
	}
	template := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			Organization: []string{"CSM Security Monitor"},
			CommonName:   cn,
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(365 * 24 * time.Hour), // 1-year validity
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              buildDNSNames(extraNames),
		IPAddresses:           buildIPList(extraNames),
	}
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		return fmt.Errorf("creating certificate: %w", err)
	}
	// Write cert (public material; default perms are fine).
	// #nosec G304 -- certPath is derived from config-owned statePath at the
	// callsite; this function is the cert *generator*.
	certFile, err := os.Create(certPath)
	if err != nil {
		return fmt.Errorf("creating cert file: %w", err)
	}
	defer func() { _ = certFile.Close() }()
	if encErr := pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); encErr != nil {
		return fmt.Errorf("encoding cert: %w", encErr)
	}
	// Write key with 0600 — private material.
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		return fmt.Errorf("marshaling key: %w", err)
	}
	// #nosec G304 -- same as certPath: generator for a config-owned path.
	keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("creating key file: %w", err)
	}
	defer func() { _ = keyFile.Close() }()
	if err := pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil {
		return fmt.Errorf("encoding key: %w", err)
	}
	return nil
}
// buildDNSNames returns the DNS SAN list for the self-signed cert:
// "localhost" plus every extra name that is a hostname (not an IP).
// Empty strings are skipped — EnsureTLSCert only guards extraNames[0]
// for the CN, so an empty entry here would otherwise become an empty
// (invalid) SAN in the certificate.
func buildDNSNames(extra []string) []string {
	names := []string{"localhost"}
	for _, n := range extra {
		if n == "" {
			continue
		}
		if net.ParseIP(n) == nil { // not an IP - it's a hostname
			names = append(names, n)
		}
	}
	return names
}
// buildIPList returns the IP SAN list for the self-signed cert: the
// IPv4 and IPv6 loopback addresses plus every extra name that parses
// as an IP address (hostnames are ignored here; see buildDNSNames).
func buildIPList(extra []string) []net.IP {
	ips := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
	for _, candidate := range extra {
		parsed := net.ParseIP(candidate)
		if parsed == nil {
			continue
		}
		ips = append(ips, parsed)
	}
	return ips
}
// fileExists reports whether path can be stat-ed (file or directory).
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
package wpcheck
import (
"crypto/md5" // #nosec G501 -- MD5 is the hash wordpress.org publishes for core file checksums; this is integrity verification against a published reference, not a security primitive.
"encoding/hex"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"golang.org/x/sys/unix"
)
// Cache holds WordPress core and plugin checksum data plus per-install
// metadata. All maps are guarded by mu; safe for concurrent use after
// construction via NewCache.
type Cache struct {
	mu              sync.RWMutex
	statePath       string                       // base dir for on-disk persistence
	checksums       map[string]map[string]string // core: "<version>:<locale>" -> relPath -> MD5
	pluginChecksums map[string]map[string]string // plugins: "<slug>:<version>" -> relPath -> SHA256
	roots           map[string]rootEntry         // WP install root dir -> detected version/locale
	fetching        map[string]bool              // cache keys with an in-flight background fetch
}

// rootEntry caches the version/locale parsed from a root's version.php.
type rootEntry struct {
	version string
	locale  string
}
// NewCache builds a Cache rooted at statePath and warms it from any
// checksum files already persisted on disk.
func NewCache(statePath string) *Cache {
	cache := &Cache{
		statePath:       statePath,
		checksums:       map[string]map[string]string{},
		pluginChecksums: map[string]map[string]string{},
		roots:           map[string]rootEntry{},
		fetching:        map[string]bool{},
	}
	cache.loadFromDisk()
	return cache
}
// cacheKey builds the in-memory cache key for a core version+locale pair.
func cacheKey(version, locale string) string {
	return fmt.Sprintf("%s:%s", version, locale)
}

// diskFilename builds the on-disk JSON filename for a core version+locale pair.
func diskFilename(version, locale string) string {
	return fmt.Sprintf("%s_%s.json", version, locale)
}
// loadFromDisk warms the core-checksum cache from {statePath}/wp-checksums/.
// It is called from NewCache before the Cache is shared, which is why it
// writes c.checksums without taking the mutex. Unreadable or malformed
// files are skipped silently — a later background fetch repopulates them.
func (c *Cache) loadFromDisk() {
	dir := filepath.Join(c.statePath, "wp-checksums")
	entries, err := os.ReadDir(dir)
	if err != nil {
		// Directory missing (e.g. first run): nothing cached yet.
		return
	}
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(name, ".json") {
			continue
		}
		// #nosec G304 -- dir is {statePath}/wp-checksums; name comes from
		// our own os.ReadDir of that same dir.
		data, err := os.ReadFile(filepath.Join(dir, name))
		if err != nil {
			continue
		}
		checksums, err := ParseChecksumResponse(data)
		if err != nil {
			continue
		}
		// Filenames are "<version>_<locale>.json"; see diskFilename.
		base := strings.TrimSuffix(name, ".json")
		parts := strings.SplitN(base, "_", 2)
		if len(parts) != 2 {
			continue
		}
		c.checksums[cacheKey(parts[0], parts[1])] = checksums
	}
}
// PersistChecksums writes checksum data to disk atomically (tmpfile + rename)
// and populates the in-memory cache. The file is written to {statePath}/wp-checksums/.
func (c *Cache) PersistChecksums(version, locale string, rawJSON []byte, checksums map[string]string) error {
	dir := filepath.Join(c.statePath, "wp-checksums")
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("creating wp-checksums dir: %w", err)
	}
	name := diskFilename(version, locale)
	target := filepath.Join(dir, name)
	tmp := filepath.Join(dir, name+".tmp")
	if err := os.WriteFile(tmp, rawJSON, 0600); err != nil {
		return fmt.Errorf("writing temp file: %w", err)
	}
	if err := os.Rename(tmp, target); err != nil {
		// Best-effort cleanup of the orphaned temp file.
		os.Remove(tmp)
		return fmt.Errorf("renaming to final: %w", err)
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.checksums[cacheKey(version, locale)] = checksums
	return nil
}
// lookupChecksum returns the published MD5 for relativePath under the
// given core version/locale, if cached.
func (c *Cache) lookupChecksum(version, locale, relativePath string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if m, found := c.checksums[cacheKey(version, locale)]; found {
		sum, ok := m[relativePath]
		return sum, ok
	}
	return "", false
}
// hasChecksums reports whether a checksum map is cached for version/locale.
func (c *Cache) hasChecksums(version, locale string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.checksums[cacheKey(version, locale)] != nil
}
// getRoot returns the cached WP version/locale for an installation root.
func (c *Cache) getRoot(root string) (version, locale string, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if entry, found := c.roots[root]; found {
		return entry.version, entry.locale, true
	}
	return "", "", false
}
// setRoot records the detected WP version/locale for an installation root.
func (c *Cache) setRoot(root, version, locale string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.roots[root] = rootEntry{version: version, locale: locale}
}

// invalidateRoot drops the cached version/locale for an installation root,
// forcing version.php to be re-read on the next verification.
func (c *Cache) invalidateRoot(root string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.roots, root)
}
// startBackgroundFetch kicks off an async checksum download for
// version/locale unless one is already in flight.
func (c *Cache) startBackgroundFetch(version, locale string) {
	key := cacheKey(version, locale)
	c.mu.Lock()
	inFlight := c.fetching[key]
	if !inFlight {
		c.fetching[key] = true
	}
	c.mu.Unlock()
	if inFlight {
		return
	}
	go c.fetchWithRetry(version, locale, 0)
}
// fetchWithRetry downloads core checksums for version/locale, retrying
// with escalating backoff on fetch failure. Unlike the plugin path, the
// core path keeps retrying at the final (1h) delay indefinitely — the
// fetching flag stays set across retries so duplicate goroutines are not
// spawned. NOTE(review): consider aligning with fetchPluginWithRetry's
// abandon-on-exhaustion behavior — confirm intent with the team.
func (c *Cache) fetchWithRetry(version, locale string, attempt int) {
	backoffs := []time.Duration{1 * time.Minute, 5 * time.Minute, 15 * time.Minute, 1 * time.Hour}
	rawJSON, checksums, err := FetchChecksums(version, locale)
	if err != nil {
		// Past the table, keep using the longest delay.
		delay := backoffs[len(backoffs)-1]
		if attempt < len(backoffs) {
			delay = backoffs[attempt]
		}
		fmt.Fprintf(os.Stderr, "wpcheck: fetch failed for WP %s (%s), retry in %v: %v\n",
			version, locale, delay, err)
		time.AfterFunc(delay, func() {
			c.fetchWithRetry(version, locale, attempt+1)
		})
		return
	}
	if err := c.PersistChecksums(version, locale, rawJSON, checksums); err != nil {
		fmt.Fprintf(os.Stderr, "wpcheck: persist failed for WP %s (%s): %v\n", version, locale, err)
		// Fix: PersistChecksums only populates the in-memory cache after a
		// successful disk write. Previously a persist failure discarded the
		// successfully fetched checksums while still clearing the fetching
		// flag and logging success. Keep them in memory so verification
		// works for the rest of this process even if disk persist failed.
		c.mu.Lock()
		c.checksums[cacheKey(version, locale)] = checksums
		c.mu.Unlock()
	}
	c.mu.Lock()
	delete(c.fetching, cacheKey(version, locale))
	c.mu.Unlock()
	fmt.Fprintf(os.Stderr, "wpcheck: cached %d checksums for WP %s (%s)\n", len(checksums), version, locale)
}
// maxFileSize caps how much of a candidate file is read and hashed (2 MiB).
// A file larger than this hashes differently from the published checksum
// and therefore fails verification (fail-closed).
const maxFileSize = 2 << 20

// IsVerifiedCoreFile reports whether the file behind fd (located at path)
// is a pristine WordPress core file, i.e. its MD5 matches the checksum
// wordpress.org publishes for the install's version/locale.
//
// fd is read with pread at offset 0 so the caller's file position is left
// untouched. Returns false on any doubt: unknown root, unreadable
// version.php, checksums not yet fetched, read failure, or hash mismatch.
func (c *Cache) IsVerifiedCoreFile(fd int, path string) bool {
	root := DetectWPRoot(path)
	if root == "" {
		return false
	}
	relPath := RelativePath(root, path)
	if relPath == "" {
		return false
	}
	// A write to version.php likely signals a core upgrade: drop the
	// cached version so it is re-read below.
	if relPath == filepath.Join("wp-includes", "version.php") {
		c.invalidateRoot(root)
	}
	version, locale, ok := c.getRoot(root)
	if !ok {
		var err error
		version, locale, err = ReadVersionFile(root)
		if err != nil {
			return false
		}
		c.setRoot(root, version, locale)
	}
	if !c.hasChecksums(version, locale) {
		// Kick off an async download; this event stays unverified.
		c.startBackgroundFetch(version, locale)
		return false
	}
	expectedMD5, ok := c.lookupChecksum(version, locale, relPath)
	if !ok {
		return false
	}
	data := make([]byte, maxFileSize)
	n, err := unix.Pread(fd, data, 0)
	if n <= 0 || (err != nil && n == 0) {
		return false
	}
	data = data[:n]
	// #nosec G401 -- MD5 is required here: wordpress.org ships MD5 digests
	// as the canonical integrity reference for core files. We compare
	// against their published values, not derive authority from the hash.
	hash := md5.Sum(data)
	actualMD5 := hex.EncodeToString(hash[:])
	if actualMD5 == expectedMD5 {
		return true
	}
	// Mismatch may mean a freshly-upgraded install whose cached version is
	// stale: re-read version.php and retry once against the new version.
	c.invalidateRoot(root)
	newVersion, newLocale, err := ReadVersionFile(root)
	if err != nil || (newVersion == version && newLocale == locale) {
		return false
	}
	c.setRoot(root, newVersion, newLocale)
	if !c.hasChecksums(newVersion, newLocale) {
		c.startBackgroundFetch(newVersion, newLocale)
		return false
	}
	newExpectedMD5, ok := c.lookupChecksum(newVersion, newLocale, relPath)
	if !ok {
		return false
	}
	return actualMD5 == newExpectedMD5
}
package wpcheck
import (
"errors"
"os"
"path/filepath"
"regexp"
"strings"
)
// wpRootLevelFiles are filenames that only exist at the WP installation root.
// index.php is excluded - it requires a secondary check (version.php must exist).
var wpRootLevelFiles = map[string]bool{
	"wp-activate.php":      true,
	"wp-blog-header.php":   true,
	"wp-comments-post.php": true,
	"wp-config-sample.php": true,
	"wp-cron.php":          true,
	"wp-links-opml.php":    true,
	"wp-load.php":          true,
	"wp-login.php":         true,
	"wp-mail.php":          true,
	"wp-settings.php":      true,
	"wp-signup.php":        true,
	"wp-trackback.php":     true,
	"xmlrpc.php":           true,
}

// DetectWPRoot returns the WordPress installation root directory for a file path,
// or empty string if the path is not inside a WP core location.
//
// Detection methods:
//   - Path contains /wp-includes/ or /wp-admin/ → root is everything before that segment
//   - Filename is a known root-level WP file → root is the parent directory
//   - Filename is index.php and version.php exists in wp-includes/ → root is the parent directory
func DetectWPRoot(path string) string {
	// A core-directory segment anywhere in the path pins the root.
	for _, marker := range []string{"/wp-includes/", "/wp-admin/"} {
		if cut := strings.Index(path, marker); cut != -1 {
			return path[:cut]
		}
	}
	parent := filepath.Dir(path)
	// File sitting directly inside wp-includes/ or wp-admin/ (no trailing
	// segment, so the marker loop above did not match).
	switch filepath.Base(parent) {
	case "wp-includes", "wp-admin":
		return filepath.Dir(parent)
	}
	filename := filepath.Base(path)
	// Known root-only core file.
	if wpRootLevelFiles[filename] {
		return parent
	}
	// index.php exists everywhere; require wp-includes/version.php as proof.
	if filename == "index.php" {
		if _, err := os.Stat(filepath.Join(parent, "wp-includes", "version.php")); err == nil {
			return parent
		}
	}
	return ""
}
// RelativePath computes the path of a file relative to the WP root.
// Returns empty string if the file is not under root.
func RelativePath(root, path string) string {
	rel, err := filepath.Rel(root, path)
	if err != nil {
		return ""
	}
	// Reject only true parent escapes (".." or "../x"). The previous bare
	// HasPrefix(rel, "..") check also rejected files whose names merely
	// begin with two dots (e.g. "..hidden.php"), which are legitimately
	// under root.
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return ""
	}
	return rel
}
var (
	// reVersion captures the $wp_version assignment in version.php.
	reVersion = regexp.MustCompile(`\$wp_version\s*=\s*'([^']+)'`)
	// reLocale captures the optional $wp_local_package assignment.
	reLocale = regexp.MustCompile(`\$wp_local_package\s*=\s*'([^']+)'`)
)

// ParseVersionContent extracts the WP version and locale from version.php content.
// Locale defaults to "en_US" if $wp_local_package is not present.
func ParseVersionContent(data []byte) (version, locale string, err error) {
	vm := reVersion.FindSubmatch(data)
	if vm == nil {
		return "", "", errors.New("wp_version not found in version.php")
	}
	locale = "en_US"
	if lm := reLocale.FindSubmatch(data); lm != nil {
		locale = string(lm[1])
	}
	return string(vm[1]), locale, nil
}
// ReadVersionFile reads and parses {root}/wp-includes/version.php.
func ReadVersionFile(root string) (version, locale string, err error) {
	versionPath := filepath.Join(root, "wp-includes", "version.php")
	// #nosec G304 -- path derived from a WordPress install root discovered
	// by the scanner under configured /home/*/public_html paths.
	content, err := os.ReadFile(versionPath)
	if err != nil {
		return "", "", err
	}
	return ParseVersionContent(content)
}
package wpcheck
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
)
// httpClient is the shared HTTP client for wordpress.org requests
// (checksum API calls and plugin ZIP downloads). The 10s Timeout bounds
// the entire request, including reading the response body.
var httpClient = &http.Client{Timeout: 10 * time.Second}
// checksumAPIURL returns the WordPress.org checksum API URL for a version
// and locale. Both values originate from version.php content on disk
// (ReadVersionFile), which a compromised site can control, so they are
// URL-encoded to prevent query-string injection via a crafted
// $wp_version/$wp_local_package string. Note url.Values.Encode emits
// parameters in sorted key order (locale before version); the API reads
// them by name, so ordering is irrelevant.
func checksumAPIURL(version, locale string) string {
	q := url.Values{}
	q.Set("version", version)
	q.Set("locale", locale)
	return "https://api.wordpress.org/core/checksums/1.0/?" + q.Encode()
}
// checksumResponse is the JSON structure returned by the WP checksum API.
type checksumResponse struct {
	Checksums map[string]string `json:"checksums"`
}

// ParseChecksumResponse parses the JSON response from the WP checksum API.
// Returns a map of relative_path -> md5_hex.
func ParseChecksumResponse(data []byte) (map[string]string, error) {
	var parsed checksumResponse
	if err := json.Unmarshal(data, &parsed); err != nil {
		return nil, fmt.Errorf("invalid JSON: %w", err)
	}
	if len(parsed.Checksums) == 0 {
		return nil, errors.New("empty or missing checksums in response")
	}
	return parsed.Checksums, nil
}
// FetchChecksums fetches official checksums from api.wordpress.org for a given
// version and locale. Returns the raw response body and parsed checksums.
func FetchChecksums(version, locale string) (rawJSON []byte, checksums map[string]string, err error) {
	endpoint := checksumAPIURL(version, locale)
	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return nil, nil, fmt.Errorf("HTTP request failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, nil, fmt.Errorf("HTTP %d from %s", resp.StatusCode, endpoint)
	}
	// Cap the body at 2MB; real checksum payloads are far smaller.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	if err != nil {
		return nil, nil, fmt.Errorf("reading response: %w", err)
	}
	parsed, err := ParseChecksumResponse(raw)
	if err != nil {
		return nil, nil, err
	}
	return raw, parsed, nil
}
package wpcheck
import (
"archive/zip"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
"golang.org/x/sys/unix"
)
// Plugin verification mirrors the core-file verification path: when a file
// under /wp-content/plugins/<slug>/ matches the hash we computed from the
// plugin's official wordpress.org ZIP, signature/YARA rule matches on it are
// false positives and should not fire.

// pluginsSegment marks the plugin directory portion of a WP content path.
const pluginsSegment = "/wp-content/plugins/"

// DetectPluginRoot returns the plugin root directory and slug for a path that
// sits under /wp-content/plugins/<slug>/. Returns empty strings if the path
// is not inside a plugin.
func DetectPluginRoot(path string) (root, slug string) {
	segStart := strings.Index(path, pluginsSegment)
	if segStart < 0 {
		return "", ""
	}
	afterSeg := segStart + len(pluginsSegment)
	rest := path[afterSeg:]
	end := strings.Index(rest, "/")
	if end <= 0 {
		// File directly in plugins/ (no slug directory), or empty segment.
		return "", ""
	}
	slug = rest[:end]
	root = path[:afterSeg] + slug
	return root, slug
}
// rePluginVersionHeader matches the "Version:" plugin-header line
// (case-insensitive, optionally preceded by a comment-block asterisk).
var rePluginVersionHeader = regexp.MustCompile(`(?im)^\s*\*?\s*Version:\s*([^\s]+)`)

// ReadPluginVersion extracts the Version: header from the plugin's main
// file (<pluginRoot>/<slug>.php). WordPress requires this header to exist
// on every published plugin.
func ReadPluginVersion(pluginRoot, slug string) (string, error) {
	mainPath := filepath.Join(pluginRoot, slug+".php")
	// #nosec G304 -- pluginRoot is derived from a path the scanner received
	// from fanotify under /wp-content/plugins/; slug is the immediate child
	// segment. The read is bounded by the header-detection limit below.
	f, err := os.Open(mainPath)
	if err != nil {
		return "", err
	}
	defer func() { _ = f.Close() }()
	// Plugin headers live at the very top of the file; 8KB is plenty.
	header := make([]byte, 8192)
	n, _ := f.Read(header)
	if n <= 0 {
		return "", errors.New("empty plugin main file")
	}
	match := rePluginVersionHeader.FindSubmatch(header[:n])
	if match == nil {
		return "", errors.New("version header not found in plugin main file")
	}
	return string(match[1]), nil
}
// pluginZipURL returns the canonical wordpress.org download URL for a given
// plugin slug and version.
func pluginZipURL(slug, version string) string {
	return "https://downloads.wordpress.org/plugin/" + slug + "." + version + ".zip"
}
// FetchPluginChecksums downloads the plugin ZIP from wordpress.org,
// extracts each file, and returns a map of relative path -> SHA256 hex.
// The returned paths are relative to the plugin root (the leading
// "<slug>/" prefix from the ZIP entries is stripped).
func FetchPluginChecksums(slug, version string) (map[string]string, error) {
return fetchPluginChecksumsFromURL(pluginZipURL(slug, version), slug)
}
// maxPluginZipBytes caps both the downloaded ZIP body and each entry's
// decompressed size (see the per-entry LimitReader below).
const maxPluginZipBytes = 100 << 20 // 100 MB ceiling

// fetchPluginChecksumsFromURL downloads a plugin ZIP and returns SHA256
// hex digests keyed by path relative to the plugin root. Directories,
// entries outside the "<slug>/" folder, and path-traversal names are
// skipped; an oversized decompressed entry aborts the whole fetch.
func fetchPluginChecksumsFromURL(url, slug string) (map[string]string, error) {
	resp, err := httpClient.Get(url) //nolint:gosec,bodyclose // httpClient has a timeout; body is closed below.
	if err != nil {
		return nil, fmt.Errorf("plugin zip GET failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("plugin zip HTTP %d from %s", resp.StatusCode, url)
	}
	// Whole archive buffered in memory; the cap bounds peak usage.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxPluginZipBytes))
	if err != nil {
		return nil, fmt.Errorf("reading plugin zip: %w", err)
	}
	zr, err := zip.NewReader(bytes.NewReader(body), int64(len(body)))
	if err != nil {
		return nil, fmt.Errorf("opening plugin zip: %w", err)
	}
	out := make(map[string]string, len(zr.File))
	prefix := slug + "/"
	for _, zf := range zr.File {
		if zf.FileInfo().IsDir() {
			continue
		}
		name := zf.Name
		if !strings.HasPrefix(name, prefix) {
			// Malformed ZIP (e.g. nested into a differently-named folder).
			// Skip; callers detect partial results by checking cache emptiness.
			continue
		}
		rel := filepath.Clean(strings.TrimPrefix(name, prefix))
		// Reject path-traversal and absolute paths: a crafted ZIP entry
		// named "<slug>/../../etc/passwd" would otherwise land in the
		// checksum map. Defense-in-depth against a compromised CDN.
		if rel == "." || strings.HasPrefix(rel, "..") || strings.HasPrefix(rel, "/") {
			continue
		}
		rc, err := zf.Open()
		if err != nil {
			return nil, fmt.Errorf("opening zip entry %s: %w", name, err)
		}
		// Cap decompressed size per entry. Without this, a zip-bomb whose
		// compressed body fits under maxPluginZipBytes can still exhaust
		// memory during io.Copy. +1 lets us detect overflow.
		limited := io.LimitReader(rc, maxPluginZipBytes+1)
		h := sha256.New()
		nCopied, err := io.Copy(h, limited)
		_ = rc.Close()
		if err != nil {
			return nil, fmt.Errorf("hashing zip entry %s: %w", name, err)
		}
		if nCopied > maxPluginZipBytes {
			return nil, fmt.Errorf("zip entry %s exceeds per-entry size cap", name)
		}
		out[rel] = hex.EncodeToString(h.Sum(nil))
	}
	if len(out) == 0 {
		return nil, errors.New("plugin zip yielded no checksums")
	}
	return out, nil
}
// --- Cache plugin support ------------------------------------------------
// pluginKey builds the cache key for a plugin slug+version pair.
func pluginKey(slug, version string) string {
	return fmt.Sprintf("%s:%s", slug, version)
}
// setPluginChecksums stores the checksum map for a plugin slug+version.
// Lazily allocates the map so a zero-value Cache still works.
func (c *Cache) setPluginChecksums(slug, version string, checksums map[string]string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.pluginChecksums == nil {
		c.pluginChecksums = map[string]map[string]string{}
	}
	c.pluginChecksums[pluginKey(slug, version)] = checksums
}

// lookupPluginChecksum returns the SHA256 for relPath under the given
// plugin/version, if cached.
func (c *Cache) lookupPluginChecksum(slug, version, relPath string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	entry, found := c.pluginChecksums[pluginKey(slug, version)]
	if !found {
		return "", false
	}
	sum, ok := entry[relPath]
	return sum, ok
}

// hasPluginChecksums reports whether any checksum map is cached for the
// plugin slug+version.
func (c *Cache) hasPluginChecksums(slug, version string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	_, ok := c.pluginChecksums[pluginKey(slug, version)]
	return ok
}
// startBackgroundPluginFetch kicks off an async ZIP download for
// slug/version unless one is already in flight.
func (c *Cache) startBackgroundPluginFetch(slug, version string) {
	key := pluginKey(slug, version)
	c.mu.Lock()
	if c.fetching == nil {
		c.fetching = map[string]bool{}
	}
	inFlight := c.fetching[key]
	if !inFlight {
		c.fetching[key] = true
	}
	c.mu.Unlock()
	if inFlight {
		return
	}
	go c.fetchPluginWithRetry(slug, version, 0)
}
// fetchPluginWithRetry mirrors the core-checksum fetchWithRetry: the
// fetching flag stays set across retries so cache-miss events for the
// same slug/version do not spawn new goroutines. On exhaustion the flag
// is cleared so a future event can retry fresh.
func (c *Cache) fetchPluginWithRetry(slug, version string, attempt int) {
	// Escalating delays; after the last attempt the fetch is abandoned.
	backoffs := []time.Duration{1 * time.Minute, 5 * time.Minute, 15 * time.Minute, 1 * time.Hour}
	key := pluginKey(slug, version)
	checksums, err := FetchPluginChecksums(slug, version)
	if err == nil {
		// Success: publish to the cache first, then clear the in-flight flag.
		c.setPluginChecksums(slug, version, checksums)
		c.mu.Lock()
		delete(c.fetching, key)
		c.mu.Unlock()
		fmt.Fprintf(os.Stderr, "wpcheck: cached %d checksums for plugin %s %s\n", len(checksums), slug, version)
		return
	}
	if attempt >= len(backoffs) {
		// Give up; clearing the flag lets a later file event start over.
		c.mu.Lock()
		delete(c.fetching, key)
		c.mu.Unlock()
		fmt.Fprintf(os.Stderr, "wpcheck: plugin fetch abandoned for %s %s after %d attempts: %v\n",
			slug, version, attempt+1, err)
		return
	}
	delay := backoffs[attempt]
	fmt.Fprintf(os.Stderr, "wpcheck: plugin fetch failed for %s %s, retry in %v: %v\n",
		slug, version, delay, err)
	// Schedule the next attempt; the timer goroutine owns the retry chain.
	time.AfterFunc(delay, func() {
		c.fetchPluginWithRetry(slug, version, attempt+1)
	})
}
// IsVerifiedPluginFile compares a file against the cached wordpress.org
// checksum for its plugin/version. Returns true only when the on-disk
// content hash matches. Triggers a background fetch on cache miss.
//
// fd is read with pread at offset 0; only the first maxFileSize (2 MiB)
// bytes are hashed, so larger files cannot verify (fail-closed).
func (c *Cache) IsVerifiedPluginFile(fd int, path string) bool {
	root, slug := DetectPluginRoot(path)
	if root == "" {
		return false
	}
	rel, err := filepath.Rel(root, path)
	// NOTE(review): the bare ".." prefix check also rejects files whose
	// names merely start with two dots (e.g. "..x.php"). Fail-safe here —
	// rejection only means "not verified" — but differs from a strict
	// parent-escape check.
	if err != nil || strings.HasPrefix(rel, "..") {
		return false
	}
	// The version is re-read from the plugin's main-file header on every
	// call, so an in-place plugin update is picked up immediately.
	version, err := ReadPluginVersion(root, slug)
	if err != nil || version == "" {
		return false
	}
	expected, ok := c.lookupPluginChecksum(slug, version, rel)
	if !ok {
		// Fetch the plugin ZIP asynchronously; this event stays unverified.
		if !c.hasPluginChecksums(slug, version) {
			c.startBackgroundPluginFetch(slug, version)
		}
		return false
	}
	data := make([]byte, maxFileSize)
	n, err := unix.Pread(fd, data, 0)
	if n <= 0 || (err != nil && n == 0) {
		return false
	}
	h := sha256.Sum256(data[:n])
	return hex.EncodeToString(h[:]) == expected
}
package yara
import (
"fmt"
"os"
"sync"
)
var (
	// globalScanner is the process-wide scanner, written only inside
	// Init's sync.Once callback.
	globalScanner *Scanner
	// globalOnce guards the one-time initialization of globalScanner.
	globalOnce sync.Once
)
// Init initializes the global YARA-X scanner.
// Returns nil scanner if YARA-X is not compiled in or no rules found.
func Init(rulesDir string) *Scanner {
	if !Available() {
		return nil
	}
	globalOnce.Do(func() {
		scanner, err := NewScanner(rulesDir)
		if err == nil {
			globalScanner = scanner
			return
		}
		fmt.Fprintf(os.Stderr, "yara: init error: %v\n", err)
	})
	return globalScanner
}
// Global returns the global YARA-X scanner, or nil if not initialized.
// Callers are expected to run Init first; the value is written only
// inside Init's sync.Once callback.
func Global() *Scanner {
	scanner := globalScanner
	return scanner
}
//go:build !yara
package yara
// Scanner is a no-op stub when YARA-X is not compiled in.
type Scanner struct{}

// Match represents a YARA rule that matched.
type Match struct {
	RuleName string
}

// NewScanner returns nil when YARA-X is not available.
func NewScanner(_ string) (*Scanner, error) {
	return nil, nil
}

// Reload is a no-op without YARA-X.
func (s *Scanner) Reload() error {
	return nil
}

// ScanBytes returns nil without YARA-X.
func (s *Scanner) ScanBytes(_ []byte) []Match {
	return nil
}

// ScanFile returns nil without YARA-X.
func (s *Scanner) ScanFile(_ string, _ int) []Match {
	return nil
}

// RuleCount returns 0 without YARA-X.
func (s *Scanner) RuleCount() int {
	return 0
}

// GlobalRules returns nil without YARA-X (no compiled rules available).
func (s *Scanner) GlobalRules() interface{} {
	return nil
}

// Available returns false (YARA-X is not compiled in).
func Available() bool {
	return false
}

// TestCompile is a no-op when YARA-X is not compiled in.
func TestCompile(_ string) error {
	return nil
}