From 89d904505d95a1a5772370fb7f5fffb78cce9854 Mon Sep 17 00:00:00 2001 From: Jason Staack Date: Sat, 21 Mar 2026 18:28:56 -0500 Subject: [PATCH] feat(16-03): add GetRawCredentials with 4-source fallback, wrap GetCredentials - GetRawCredentials resolves credentials: device transit, device legacy, profile transit, profile legacy - Cache key includes source (device/profile) to prevent cross-source poisoning - GetCredentials is now a backward-compatible wrapper calling GetRawCredentials + ParseRouterOSCredentials - Add DecryptRaw to device package for raw byte decryption without JSON parsing - Invalidate clears both parsed and raw cache entries - All existing callers (PollDevice, CmdResponder, TunnelResponder, BackupResponder, SSHRelay) unchanged Co-Authored-By: Claude Opus 4.6 (1M context) --- poller/internal/device/crypto.go | 32 +++++ poller/internal/vault/cache.go | 211 +++++++++++++++++++------------ 2 files changed, 160 insertions(+), 83 deletions(-) diff --git a/poller/internal/device/crypto.go b/poller/internal/device/crypto.go index d2a76d7..3665760 100644 --- a/poller/internal/device/crypto.go +++ b/poller/internal/device/crypto.go @@ -14,6 +14,38 @@ type credentialsJSON struct { Password string `json:"password"` } +// DecryptRaw decrypts AES-256-GCM encrypted data and returns the raw plaintext bytes. +// Used by GetRawCredentials to obtain credential JSON before type-specific parsing. +// The ciphertext layout is the same as described in DecryptCredentials. +func DecryptRaw(ciphertext []byte, key []byte) ([]byte, error) { + if len(key) != 32 { + return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + } + if len(ciphertext) < 12+16 { + return nil, fmt.Errorf("ciphertext too short: need at least 28 bytes (12 nonce + 16 tag), got %d", len(ciphertext)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("creating AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("creating GCM cipher: %w", err) + } + + nonce := ciphertext[:12] + encryptedData := ciphertext[12:] + + plaintext, err := gcm.Open(nil, nonce, encryptedData, nil) + if err != nil { + return nil, fmt.Errorf("decrypting credentials (wrong key or tampered data): %w", err) + } + + return plaintext, nil +} + // DecryptCredentials decrypts AES-256-GCM encrypted credentials and returns the // username and password stored within. // diff --git a/poller/internal/vault/cache.go b/poller/internal/vault/cache.go index f88f609..3210374 100644 --- a/poller/internal/vault/cache.go +++ b/poller/internal/vault/cache.go @@ -2,7 +2,6 @@ package vault import ( "context" - "encoding/json" "fmt" "log/slog" "strings" @@ -45,13 +44,18 @@ var ( ) // CredentialCache provides cached credential decryption with dual-read support. -// It uses an LRU cache with TTL to avoid redundant OpenBao calls and falls back +// It uses LRU caches with TTL to avoid redundant OpenBao calls and falls back // to legacy AES-256-GCM decryption for credentials not yet migrated to Transit. +// +// Two caches are maintained: +// - cache: parsed RouterOS credentials (CachedCreds) for backward compatibility +// - rawCache: raw decrypted JSON bytes for type-agnostic credential access type CredentialCache struct { - cache *expirable.LRU[string, *CachedCreds] - transit *TransitClient - legacy []byte // legacy AES-256-GCM key (nil if not available) - db *pgxpool.Pool // for key_access_log inserts (nil if not available) + cache *expirable.LRU[string, *CachedCreds] + rawCache *expirable.LRU[string, []byte] // raw decrypted credential JSON bytes + transit *TransitClient + legacy []byte // legacy AES-256-GCM key (nil if not available) + db *pgxpool.Pool // for key_access_log inserts (nil if not available) } // NewCredentialCache creates a bounded LRU cache with the given size and TTL. @@ -59,101 +63,142 @@ type CredentialCache struct { // db may be nil if key access logging is not needed. func NewCredentialCache(size int, ttl time.Duration, transit *TransitClient, legacyKey []byte, db *pgxpool.Pool) *CredentialCache { cache := expirable.NewLRU[string, *CachedCreds](size, nil, ttl) + rawCache := expirable.NewLRU[string, []byte](size, nil, ttl) return &CredentialCache{ - cache: cache, - transit: transit, - legacy: legacyKey, - db: db, + cache: cache, + rawCache: rawCache, + transit: transit, + legacy: legacyKey, + db: db, } } -// GetCredentials returns decrypted credentials for a device, using the cache. -// transitCiphertext is the Transit-encrypted string (nullable), legacyCiphertext is the legacy BYTEA (nullable). -// Returns (username, password, error). +// GetRawCredentials returns raw decrypted credential JSON bytes for a device. +// It resolves credentials using the fallback chain: +// 1. Per-device transitCiphertext (highest priority) +// 2. Per-device legacyCiphertext +// 3. Profile transitCiphertext (from credential_profiles via FetchDevices JOIN) +// 4. Profile legacyCiphertext +// +// The cache key includes the source to prevent poisoning when a device +// switches from per-device to profile credentials. +func (c *CredentialCache) GetRawCredentials( + deviceID, tenantID string, + transitCiphertext *string, + legacyCiphertext []byte, + profileTransitCiphertext *string, + profileLegacyCiphertext []byte, +) ([]byte, error) { + + // Determine which ciphertext source to use and the source label. + var activeTransit *string + var activeLegacy []byte + var source string + + if transitCiphertext != nil && *transitCiphertext != "" && strings.HasPrefix(*transitCiphertext, "vault:v") { + activeTransit = transitCiphertext + source = "device" + } else if len(legacyCiphertext) > 0 { + activeLegacy = legacyCiphertext + source = "device" + } else if profileTransitCiphertext != nil && *profileTransitCiphertext != "" && strings.HasPrefix(*profileTransitCiphertext, "vault:v") { + activeTransit = profileTransitCiphertext + source = "profile" + } else if len(profileLegacyCiphertext) > 0 { + activeLegacy = profileLegacyCiphertext + source = "profile" + } else { + return nil, fmt.Errorf("no credentials available for device %s", deviceID) + } + + // Cache key includes source to prevent poisoning across device/profile switch. + cacheKey := "raw:" + deviceID + ":" + source + + // Check raw cache first. + if cached, ok := c.rawCache.Get(cacheKey); ok { + CacheHits.Inc() + return cached, nil + } + CacheMisses.Inc() + + var raw []byte + + // Decrypt using the selected ciphertext source. + if activeTransit != nil { + if c.transit == nil { + return nil, fmt.Errorf("transit ciphertext present but OpenBao client not configured") + } + + start := time.Now() + plaintext, err := c.transit.Decrypt(tenantID, *activeTransit) + OpenBaoLatency.Observe(time.Since(start).Seconds()) + + if err != nil { + return nil, fmt.Errorf("transit decrypt for device %s (%s): %w", deviceID, source, err) + } + raw = plaintext + + // Fire-and-forget key access log INSERT for audit trail. + if c.db != nil { + go c.logKeyAccess(deviceID, tenantID, "decrypt_credentials", "poller_poll") + } + + } else if len(activeLegacy) > 0 { + if c.legacy == nil { + return nil, fmt.Errorf("legacy ciphertext present but encryption key not configured") + } + + plaintext, err := device.DecryptRaw(activeLegacy, c.legacy) + if err != nil { + return nil, fmt.Errorf("legacy decrypt for device %s (%s): %w", deviceID, source, err) + } + raw = plaintext + LegacyDecrypts.Inc() + } + + // Cache the raw bytes. + c.rawCache.Add(cacheKey, raw) + + slog.Debug("credential decrypted and cached (raw)", + "device_id", deviceID, + "source", source, + ) + + return raw, nil +} + +// GetCredentials returns decrypted RouterOS credentials for a device, using the cache. +// This is a backward-compatible wrapper around GetRawCredentials that maintains the +// original (username, password, error) return signature. All existing callers +// (PollDevice, CmdResponder, TunnelResponder, BackupResponder, SSHRelay) continue +// to work without changes. +// +// transitCiphertext is the Transit-encrypted string (nullable), +// legacyCiphertext is the legacy BYTEA (nullable). func (c *CredentialCache) GetCredentials( deviceID, tenantID string, transitCiphertext *string, legacyCiphertext []byte, ) (string, string, error) { - // Check cache first - if cached, ok := c.cache.Get(deviceID); ok { - CacheHits.Inc() - return cached.Username, cached.Password, nil + raw, err := c.GetRawCredentials(deviceID, tenantID, transitCiphertext, legacyCiphertext, nil, nil) + if err != nil { + return "", "", err } - CacheMisses.Inc() - - var username, password string - - // Prefer Transit ciphertext if available - if transitCiphertext != nil && *transitCiphertext != "" && strings.HasPrefix(*transitCiphertext, "vault:v") { - if c.transit == nil { - return "", "", fmt.Errorf("transit ciphertext present but OpenBao client not configured") - } - - start := time.Now() - plaintext, err := c.transit.Decrypt(tenantID, *transitCiphertext) - OpenBaoLatency.Observe(time.Since(start).Seconds()) - - if err != nil { - return "", "", fmt.Errorf("transit decrypt for device %s: %w", deviceID, err) - } - - var creds struct { - Username string `json:"username"` - Password string `json:"password"` - } - if err := json.Unmarshal(plaintext, &creds); err != nil { - return "", "", fmt.Errorf("unmarshal transit-decrypted credentials: %w", err) - } - username = creds.Username - password = creds.Password - - // Fire-and-forget key access log INSERT for audit trail - if c.db != nil { - go c.logKeyAccess(deviceID, tenantID, "decrypt_credentials", "poller_poll") - } - - } else if len(legacyCiphertext) > 0 { - // Fall back to legacy AES-256-GCM decryption - if c.legacy == nil { - return "", "", fmt.Errorf("legacy ciphertext present but encryption key not configured") - } - - var err error - username, password, err = device.DecryptCredentials(legacyCiphertext, c.legacy) - if err != nil { - return "", "", fmt.Errorf("legacy decrypt for device %s: %w", deviceID, err) - } - LegacyDecrypts.Inc() - - } else { - return "", "", fmt.Errorf("no credentials available for device %s", deviceID) - } - - // Cache the result - c.cache.Add(deviceID, &CachedCreds{Username: username, Password: password}) - - slog.Debug("credential decrypted and cached", - "device_id", deviceID, - "source", func() string { - if transitCiphertext != nil && *transitCiphertext != "" { - return "transit" - } - return "legacy" - }(), - ) - - return username, password, nil + return ParseRouterOSCredentials(raw) } // Invalidate removes a device's cached credentials (e.g., after credential rotation). +// Clears both the parsed credential cache and the raw credential cache. func (c *CredentialCache) Invalidate(deviceID string) { c.cache.Remove(deviceID) + // Clear all raw cache entries for this device (both device and profile sources). + c.rawCache.Remove("raw:" + deviceID + ":device") + c.rawCache.Remove("raw:" + deviceID + ":profile") } -// Len returns the number of cached entries. +// Len returns the number of cached entries in the raw credential cache. func (c *CredentialCache) Len() int { - return c.cache.Len() + return c.rawCache.Len() } // logKeyAccess inserts an immutable audit record for a credential decryption event.