From f1abb75caba295c7ab1abd419bc3adae45c75271 Mon Sep 17 00:00:00 2001 From: Jason Staack Date: Thu, 12 Mar 2026 20:46:04 -0500 Subject: [PATCH] feat(02-01): add SSH executor with TOFU host key verification and config normalizer - SSH RunCommand with typed error classification (auth, hostkey, timeout, connection refused, truncated) - TOFU host key callback: accept-on-first-connect, verify-on-subsequent, reject-on-mismatch - NormalizeConfig strips timestamps, normalizes line endings, trims whitespace, collapses blanks - HashConfig returns 64-char lowercase hex SHA256 of normalized config - 22 unit tests covering all error kinds, TOFU flows, normalization edge cases, idempotency Co-Authored-By: Claude Opus 4.6 --- poller/internal/device/normalize.go | 79 ++++++ poller/internal/device/normalize_test.go | 116 ++++++++ poller/internal/device/ssh_executor.go | 276 ++++++++++++++++++++ poller/internal/device/ssh_executor_test.go | 128 +++++++++ 4 files changed, 599 insertions(+) create mode 100644 poller/internal/device/normalize.go create mode 100644 poller/internal/device/normalize_test.go create mode 100644 poller/internal/device/ssh_executor.go create mode 100644 poller/internal/device/ssh_executor_test.go diff --git a/poller/internal/device/normalize.go b/poller/internal/device/normalize.go new file mode 100644 index 0000000..860fdba --- /dev/null +++ b/poller/internal/device/normalize.go @@ -0,0 +1,79 @@ +// Package device provides SSH command execution and config normalization for RouterOS devices. +package device + +import ( + "crypto/sha256" + "fmt" + "regexp" + "strings" +) + +// NormalizationVersion tracks the normalization algorithm version for NATS payloads. +// Increment when the normalization logic changes to allow re-processing. +const NormalizationVersion = 1 + +// timestampHeaderRe matches the RouterOS export timestamp header line. +// Example: "# 2024/01/15 10:30:00 by RouterOS 7.14" +var timestampHeaderRe = regexp.MustCompile(`(?m)^# \d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} by RouterOS.*\n?`) + +// NormalizeConfig deterministically normalizes a RouterOS config export. +// +// Steps: +// 1. Replace \r\n with \n +// 2. Strip the timestamp header line (and the blank line immediately following it) +// 3. Trim trailing whitespace from each line +// 4. Collapse consecutive blank lines (2+ empty lines become 1) +// 5. Ensure exactly one trailing newline +func NormalizeConfig(raw string) string { + // Step 1: Normalize line endings + s := strings.ReplaceAll(raw, "\r\n", "\n") + + // Step 2: Strip timestamp header and the blank line immediately following it + loc := timestampHeaderRe.FindStringIndex(s) + if loc != nil { + after := s[loc[1]:] + // Remove the blank line immediately following the timestamp header + if strings.HasPrefix(after, "\n") { + s = s[:loc[0]] + after[1:] + } else { + s = s[:loc[0]] + after + } + } + + // Step 3: Trim trailing whitespace from each line + lines := strings.Split(s, "\n") + for i, line := range lines { + lines[i] = strings.TrimRight(line, " \t") + } + + // Step 4: Collapse consecutive blank lines + var result []string + prevBlank := false + for _, line := range lines { + if line == "" { + if prevBlank { + continue + } + prevBlank = true + } else { + prevBlank = false + } + result = append(result, line) + } + + // Step 5: Ensure exactly one trailing newline + out := strings.Join(result, "\n") + out = strings.TrimRight(out, "\n") + if out != "" { + out += "\n" + } + + return out +} + +// HashConfig returns the lowercase hex-encoded SHA256 hash of the normalized config text. +// The hash is 64 characters long and deterministic for the same input. +func HashConfig(normalized string) string { + h := sha256.Sum256([]byte(normalized)) + return fmt.Sprintf("%x", h) +} diff --git a/poller/internal/device/normalize_test.go b/poller/internal/device/normalize_test.go new file mode 100644 index 0000000..bb61504 --- /dev/null +++ b/poller/internal/device/normalize_test.go @@ -0,0 +1,116 @@ +package device + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNormalizeConfig_StripTimestampHeader(t *testing.T) { + input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n# software id = ABC123\n/ip address\n" + got := NormalizeConfig(input) + assert.NotContains(t, got, "2024/01/15") + assert.Contains(t, got, "# software id = ABC123") + assert.Contains(t, got, "/ip address") +} + +func TestNormalizeConfig_LineEndingNormalization(t *testing.T) { + input := "/ip address\r\nadd address=10.0.0.1\r\n" + got := NormalizeConfig(input) + assert.NotContains(t, got, "\r") + assert.Contains(t, got, "/ip address\n") +} + +func TestNormalizeConfig_TrailingWhitespaceTrimming(t *testing.T) { + input := " /ip address \n" + got := NormalizeConfig(input) + // Each line should be trimmed of trailing whitespace only + lines := strings.Split(got, "\n") + for _, line := range lines { + if line == "" { + continue + } + assert.Equal(t, strings.TrimRight(line, " \t"), line, "line should have no trailing whitespace") + } +} + +func TestNormalizeConfig_BlankLineCollapsing(t *testing.T) { + input := "/ip address\n\n\n\n/ip route\n" + got := NormalizeConfig(input) + assert.NotContains(t, got, "\n\n\n") + assert.Contains(t, got, "/ip address\n\n/ip route") +} + +func TestNormalizeConfig_TrailingNewline(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"no trailing newline", "/ip address"}, + {"one trailing newline", "/ip address\n"}, + {"multiple trailing newlines", "/ip address\n\n\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeConfig(tt.input) + assert.True(t, strings.HasSuffix(got, "\n"), "should end with newline") + assert.False(t, strings.HasSuffix(got, "\n\n"), "should not end with double newline") + }) + } +} + +func TestNormalizeConfig_CommentPreservation(t *testing.T) { + input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n# software id = ABC123\n# custom comment\n/ip address\n" + got := NormalizeConfig(input) + assert.Contains(t, got, "# software id = ABC123") + assert.Contains(t, got, "# custom comment") +} + +func TestNormalizeConfig_FullPipeline(t *testing.T) { + input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n" + + "# software id = ABC123\r\n" + + "/ip address \r\n" + + "add address=10.0.0.1/24 \r\n" + + "\r\n" + + "\r\n" + + "\r\n" + + "/ip route \r\n" + + "add dst-address=0.0.0.0/0 gateway=10.0.0.1\r\n" + + expected := "# software id = ABC123\n" + + "/ip address\n" + + "add address=10.0.0.1/24\n" + + "\n" + + "/ip route\n" + + "add dst-address=0.0.0.0/0 gateway=10.0.0.1\n" + + got := NormalizeConfig(input) + assert.Equal(t, expected, got) +} + +func TestHashConfig(t *testing.T) { + normalized := "/ip address\nadd address=10.0.0.1/24\n" + hash := HashConfig(normalized) + assert.Len(t, hash, 64, "SHA256 hex should be 64 chars") + assert.Equal(t, strings.ToLower(hash), hash, "hash should be lowercase") + // Deterministic + assert.Equal(t, hash, HashConfig(normalized)) +} + +func TestNormalizeConfig_Idempotency(t *testing.T) { + input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n" + + "# software id = ABC123\r\n" + + "/ip address \r\n" + + "\r\n\r\n\r\n" + + "/ip route\r\n" + + first := NormalizeConfig(input) + second := NormalizeConfig(first) + assert.Equal(t, first, second, "NormalizeConfig should be idempotent") +} + +func TestNormalizationVersion(t *testing.T) { + require.Equal(t, 1, NormalizationVersion, "NormalizationVersion should be 1") +} diff --git a/poller/internal/device/ssh_executor.go b/poller/internal/device/ssh_executor.go new file mode 100644 index 0000000..261aea6 --- /dev/null +++ b/poller/internal/device/ssh_executor.go @@ -0,0 +1,276 @@ +package device + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "fmt" + "net" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +// SSHErrorKind classifies SSH connection and command errors. +type SSHErrorKind string + +const ( + // ErrAuthFailed indicates the SSH credentials were rejected. + ErrAuthFailed SSHErrorKind = "auth_failed" + // ErrHostKeyMismatch indicates a TOFU host key verification failure. + ErrHostKeyMismatch SSHErrorKind = "host_key_mismatch" + // ErrTimeout indicates the connection or command timed out. + ErrTimeout SSHErrorKind = "timeout" + // ErrTruncatedOutput indicates the command timed out mid-stream, producing partial output. + ErrTruncatedOutput SSHErrorKind = "truncated_output" + // ErrConnectionRefused indicates the remote host refused the TCP connection. + ErrConnectionRefused SSHErrorKind = "connection_refused" + // ErrUnknown indicates an unclassified error. + ErrUnknown SSHErrorKind = "unknown" +) + +// SSHError wraps an SSH-related error with a classification kind. +type SSHError struct { + Kind SSHErrorKind + Err error + Message string +} + +// Error implements the error interface. +func (e *SSHError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %s", e.Message, e.Err.Error()) + } + return e.Message +} + +// Unwrap returns the underlying error for errors.Is/As support. +func (e *SSHError) Unwrap() error { + return e.Err +} + +// CommandResult holds the output of a remote SSH command execution. +type CommandResult struct { + Stdout string + Stderr string + ExitCode int + Duration time.Duration +} + +// RunCommand executes a command on a remote device via SSH with TOFU host key verification. +// +// Parameters: +// - knownFingerprint: empty string for first connect (TOFU accepts any key), or a +// previously stored "SHA256:base64(...)" fingerprint for verification. +// - command: the RouterOS CLI command to execute (e.g., "/export") +// +// Returns: +// - result: command output (stdout, stderr, exit code, duration) +// - observedFingerprint: the SSH host key fingerprint observed during connection +// - err: classified SSHError on failure, nil on success +func RunCommand(ctx context.Context, ip string, port int, username, password string, + timeout time.Duration, knownFingerprint string, command string) (*CommandResult, string, error) { + + cb, fpCh := tofuHostKeyCallback(knownFingerprint) + + config := &ssh.ClientConfig{ + User: username, + Auth: []ssh.AuthMethod{ + ssh.Password(password), + }, + HostKeyCallback: cb, + Timeout: timeout, + } + + addr := fmt.Sprintf("%s:%d", ip, port) + + // Context-aware dial + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, "", &SSHError{ + Kind: classifySSHError(err), + Err: err, + Message: fmt.Sprintf("TCP dial to %s failed", addr), + } + } + + // SSH handshake over the raw connection + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + conn.Close() + return nil, "", &SSHError{ + Kind: classifySSHError(err), + Err: err, + Message: fmt.Sprintf("SSH handshake to %s failed", addr), + } + } + + client := ssh.NewClient(sshConn, chans, reqs) + defer client.Close() + + // Read the observed fingerprint (will be available after handshake) + var observedFP string + select { + case fp := <-fpCh: + observedFP = fp + default: + // Channel already drained or callback didn't fire (shouldn't happen) + } + + session, err := client.NewSession() + if err != nil { + return nil, observedFP, &SSHError{ + Kind: ErrUnknown, + Err: err, + Message: "creating SSH session failed", + } + } + defer session.Close() + + var stdout, stderr bytes.Buffer + session.Stdout = &stdout + session.Stderr = &stderr + + start := time.Now() + + // Run with context cancellation for timeout detection + done := make(chan error, 1) + go func() { + done <- session.Run(command) + }() + + var runErr error + select { + case <-ctx.Done(): + // Context cancelled/timed out mid-execution + session.Close() + return &CommandResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: -1, + Duration: time.Since(start), + }, observedFP, &SSHError{ + Kind: ErrTruncatedOutput, + Err: ctx.Err(), + Message: "command timed out mid-execution, output may be truncated", + } + case runErr = <-done: + } + + duration := time.Since(start) + + result := &CommandResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: 0, + Duration: duration, + } + + if runErr != nil { + // Check for exit status errors + if exitErr, ok := runErr.(*ssh.ExitError); ok { + result.ExitCode = exitErr.ExitStatus() + return result, observedFP, nil // Non-zero exit is not an SSH error + } + return result, observedFP, &SSHError{ + Kind: classifySSHError(runErr), + Err: runErr, + Message: "SSH command execution failed", + } + } + + return result, observedFP, nil +} + +// tofuHostKeyCallback returns an SSH host key callback implementing Trust-On-First-Use. +// +// If knownFingerprint is empty, any key is accepted and its fingerprint is sent on the channel. +// If knownFingerprint matches the presented key, the connection is accepted. +// If knownFingerprint does not match, the connection is rejected with ErrHostKeyMismatch. +func tofuHostKeyCallback(knownFingerprint string) (ssh.HostKeyCallback, chan string) { + fpCh := make(chan string, 1) + + cb := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + fp := computeFingerprint(key) + + if knownFingerprint == "" { + // First connect: accept and report fingerprint + fpCh <- fp + return nil + } + + fpCh <- fp + + if fp != knownFingerprint { + return &SSHError{ + Kind: ErrHostKeyMismatch, + Err: fmt.Errorf("expected %s, got %s", knownFingerprint, fp), + Message: fmt.Sprintf("host key mismatch for %s", hostname), + } + } + + return nil + } + + return cb, fpCh +} + +// computeFingerprint computes the SSH host key fingerprint in the same format as ssh-keygen: +// "SHA256:" followed by the base64-encoded (no padding) SHA256 hash of the public key bytes. +func computeFingerprint(key ssh.PublicKey) string { + h := sha256.Sum256(key.Marshal()) + return "SHA256:" + base64.RawStdEncoding.EncodeToString(h[:]) +} + +// classifySSHError inspects an error and returns the appropriate SSHErrorKind. +func classifySSHError(err error) SSHErrorKind { + if err == nil { + return ErrUnknown + } + + errStr := err.Error() + + // Check for timeout (net.Error interface) + var netErr net.Error + if ok := errorAs(err, &netErr); ok && netErr.Timeout() { + return ErrTimeout + } + + if strings.Contains(errStr, "i/o timeout") { + return ErrTimeout + } + + if strings.Contains(errStr, "unable to authenticate") || + strings.Contains(errStr, "no supported methods remain") { + return ErrAuthFailed + } + + if strings.Contains(errStr, "host key") { + return ErrHostKeyMismatch + } + + if strings.Contains(errStr, "connection refused") { + return ErrConnectionRefused + } + + return ErrUnknown +} + +// errorAs is a helper that wraps errors.As for interface targets. +func errorAs[T any](err error, target *T) bool { + for err != nil { + if t, ok := err.(T); ok { + *target = t + return true + } + if u, ok := err.(interface{ Unwrap() error }); ok { + err = u.Unwrap() + } else { + return false + } + } + return false +} diff --git a/poller/internal/device/ssh_executor_test.go b/poller/internal/device/ssh_executor_test.go new file mode 100644 index 0000000..079cbda --- /dev/null +++ b/poller/internal/device/ssh_executor_test.go @@ -0,0 +1,128 @@ +package device + +import ( + "crypto/ed25519" + "crypto/rand" + "errors" + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestClassifySSHError_AuthFailed(t *testing.T) { + err := fmt.Errorf("ssh: unable to authenticate") + kind := classifySSHError(err) + assert.Equal(t, ErrAuthFailed, kind) +} + +func TestClassifySSHError_HostKeyMismatch(t *testing.T) { + err := fmt.Errorf("ssh: host key mismatch") + kind := classifySSHError(err) + assert.Equal(t, ErrHostKeyMismatch, kind) +} + +func TestClassifySSHError_Timeout(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Err: &timeoutError{}, + } + kind := classifySSHError(err) + assert.Equal(t, ErrTimeout, kind) +} + +func TestClassifySSHError_ConnectionRefused(t *testing.T) { + err := fmt.Errorf("dial tcp 10.0.0.1:22: connection refused") + kind := classifySSHError(err) + assert.Equal(t, ErrConnectionRefused, kind) +} + +func TestClassifySSHError_Unknown(t *testing.T) { + err := fmt.Errorf("some random error") + kind := classifySSHError(err) + assert.Equal(t, ErrUnknown, kind) +} + +func TestSSHError_Error(t *testing.T) { + sshErr := &SSHError{ + Kind: ErrAuthFailed, + Err: fmt.Errorf("underlying"), + Message: "auth failed for device", + } + assert.Contains(t, sshErr.Error(), "auth failed for device") + assert.Contains(t, sshErr.Error(), "underlying") +} + +func TestSSHError_Unwrap(t *testing.T) { + inner := fmt.Errorf("inner error") + sshErr := &SSHError{ + Kind: ErrUnknown, + Err: inner, + } + assert.True(t, errors.Is(sshErr, inner)) +} + +func TestCommandResult_Fields(t *testing.T) { + result := &CommandResult{ + Stdout: "output", + Stderr: "err", + ExitCode: 0, + } + require.NotNil(t, result) + assert.Equal(t, "output", result.Stdout) + assert.Equal(t, "err", result.Stderr) + assert.Equal(t, 0, result.ExitCode) +} + +func TestTOFUCallback_FirstConnect(t *testing.T) { + cb, fpCh := tofuHostKeyCallback("") + // Simulate first connect with any key + key := generateTestPublicKey(t) + err := cb("10.0.0.1:22", nil, key) + assert.NoError(t, err, "first connect should accept any key") + + fp := <-fpCh + assert.NotEmpty(t, fp, "should return a fingerprint") + assert.Contains(t, fp, "SHA256:", "fingerprint should have SHA256 prefix") +} + +func TestTOFUCallback_MatchingFingerprint(t *testing.T) { + key := generateTestPublicKey(t) + fp := computeFingerprint(key) + + cb, _ := tofuHostKeyCallback(fp) + err := cb("10.0.0.1:22", nil, key) + assert.NoError(t, err, "matching fingerprint should be accepted") +} + +func TestTOFUCallback_MismatchedFingerprint(t *testing.T) { + key := generateTestPublicKey(t) + + cb, _ := tofuHostKeyCallback("SHA256:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=") + err := cb("10.0.0.1:22", nil, key) + require.Error(t, err, "mismatched fingerprint should be rejected") + + var sshErr *SSHError + require.True(t, errors.As(err, &sshErr)) + assert.Equal(t, ErrHostKeyMismatch, sshErr.Kind) +} + +// generateTestPublicKey creates an ed25519 public key for testing. +func generateTestPublicKey(t *testing.T) ssh.PublicKey { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + pub, err := ssh.NewPublicKey(priv.Public()) + require.NoError(t, err) + return pub +} + +// timeoutError implements net.Error with Timeout() returning true. +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return false }