feat(02-01): add SSH executor with TOFU host key verification and config normalizer
- SSH RunCommand with typed error classification (auth, hostkey, timeout, connection refused, truncated) - TOFU host key callback: accept-on-first-connect, verify-on-subsequent, reject-on-mismatch - NormalizeConfig strips timestamps, normalizes line endings, trims whitespace, collapses blanks - HashConfig returns 64-char lowercase hex SHA256 of normalized config - 22 unit tests covering all error kinds, TOFU flows, normalization edge cases, idempotency Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
79
poller/internal/device/normalize.go
Normal file
79
poller/internal/device/normalize.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
// Package device provides SSH command execution and config normalization for RouterOS devices.
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NormalizationVersion tracks the normalization algorithm version for NATS payloads.
|
||||||
|
// Increment when the normalization logic changes to allow re-processing.
|
||||||
|
const NormalizationVersion = 1
|
||||||
|
|
||||||
|
// timestampHeaderRe matches the RouterOS export timestamp header line.
|
||||||
|
// Example: "# 2024/01/15 10:30:00 by RouterOS 7.14"
|
||||||
|
var timestampHeaderRe = regexp.MustCompile(`(?m)^# \d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} by RouterOS.*\n?`)
|
||||||
|
|
||||||
|
// NormalizeConfig deterministically normalizes a RouterOS config export.
|
||||||
|
//
|
||||||
|
// Steps:
|
||||||
|
// 1. Replace \r\n with \n
|
||||||
|
// 2. Strip the timestamp header line (and the blank line immediately following it)
|
||||||
|
// 3. Trim trailing whitespace from each line
|
||||||
|
// 4. Collapse consecutive blank lines (2+ empty lines become 1)
|
||||||
|
// 5. Ensure exactly one trailing newline
|
||||||
|
func NormalizeConfig(raw string) string {
|
||||||
|
// Step 1: Normalize line endings
|
||||||
|
s := strings.ReplaceAll(raw, "\r\n", "\n")
|
||||||
|
|
||||||
|
// Step 2: Strip timestamp header and the blank line immediately following it
|
||||||
|
loc := timestampHeaderRe.FindStringIndex(s)
|
||||||
|
if loc != nil {
|
||||||
|
after := s[loc[1]:]
|
||||||
|
// Remove the blank line immediately following the timestamp header
|
||||||
|
if strings.HasPrefix(after, "\n") {
|
||||||
|
s = s[:loc[0]] + after[1:]
|
||||||
|
} else {
|
||||||
|
s = s[:loc[0]] + after
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Trim trailing whitespace from each line
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
lines[i] = strings.TrimRight(line, " \t")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Collapse consecutive blank lines
|
||||||
|
var result []string
|
||||||
|
prevBlank := false
|
||||||
|
for _, line := range lines {
|
||||||
|
if line == "" {
|
||||||
|
if prevBlank {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prevBlank = true
|
||||||
|
} else {
|
||||||
|
prevBlank = false
|
||||||
|
}
|
||||||
|
result = append(result, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Ensure exactly one trailing newline
|
||||||
|
out := strings.Join(result, "\n")
|
||||||
|
out = strings.TrimRight(out, "\n")
|
||||||
|
if out != "" {
|
||||||
|
out += "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashConfig returns the lowercase hex-encoded SHA256 hash of the normalized config text.
|
||||||
|
// The hash is 64 characters long and deterministic for the same input.
|
||||||
|
func HashConfig(normalized string) string {
|
||||||
|
h := sha256.Sum256([]byte(normalized))
|
||||||
|
return fmt.Sprintf("%x", h)
|
||||||
|
}
|
||||||
116
poller/internal/device/normalize_test.go
Normal file
116
poller/internal/device/normalize_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeConfig_StripTimestampHeader(t *testing.T) {
|
||||||
|
input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n# software id = ABC123\n/ip address\n"
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
assert.NotContains(t, got, "2024/01/15")
|
||||||
|
assert.Contains(t, got, "# software id = ABC123")
|
||||||
|
assert.Contains(t, got, "/ip address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_LineEndingNormalization(t *testing.T) {
|
||||||
|
input := "/ip address\r\nadd address=10.0.0.1\r\n"
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
assert.NotContains(t, got, "\r")
|
||||||
|
assert.Contains(t, got, "/ip address\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_TrailingWhitespaceTrimming(t *testing.T) {
|
||||||
|
input := " /ip address \n"
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
// Each line should be trimmed of trailing whitespace only
|
||||||
|
lines := strings.Split(got, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
assert.Equal(t, strings.TrimRight(line, " \t"), line, "line should have no trailing whitespace")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_BlankLineCollapsing(t *testing.T) {
|
||||||
|
input := "/ip address\n\n\n\n/ip route\n"
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
assert.NotContains(t, got, "\n\n\n")
|
||||||
|
assert.Contains(t, got, "/ip address\n\n/ip route")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_TrailingNewline(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
}{
|
||||||
|
{"no trailing newline", "/ip address"},
|
||||||
|
{"one trailing newline", "/ip address\n"},
|
||||||
|
{"multiple trailing newlines", "/ip address\n\n\n"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := NormalizeConfig(tt.input)
|
||||||
|
assert.True(t, strings.HasSuffix(got, "\n"), "should end with newline")
|
||||||
|
assert.False(t, strings.HasSuffix(got, "\n\n"), "should not end with double newline")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_CommentPreservation(t *testing.T) {
|
||||||
|
input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n# software id = ABC123\n# custom comment\n/ip address\n"
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
assert.Contains(t, got, "# software id = ABC123")
|
||||||
|
assert.Contains(t, got, "# custom comment")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_FullPipeline(t *testing.T) {
|
||||||
|
input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n" +
|
||||||
|
"# software id = ABC123\r\n" +
|
||||||
|
"/ip address \r\n" +
|
||||||
|
"add address=10.0.0.1/24 \r\n" +
|
||||||
|
"\r\n" +
|
||||||
|
"\r\n" +
|
||||||
|
"\r\n" +
|
||||||
|
"/ip route \r\n" +
|
||||||
|
"add dst-address=0.0.0.0/0 gateway=10.0.0.1\r\n"
|
||||||
|
|
||||||
|
expected := "# software id = ABC123\n" +
|
||||||
|
"/ip address\n" +
|
||||||
|
"add address=10.0.0.1/24\n" +
|
||||||
|
"\n" +
|
||||||
|
"/ip route\n" +
|
||||||
|
"add dst-address=0.0.0.0/0 gateway=10.0.0.1\n"
|
||||||
|
|
||||||
|
got := NormalizeConfig(input)
|
||||||
|
assert.Equal(t, expected, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashConfig(t *testing.T) {
|
||||||
|
normalized := "/ip address\nadd address=10.0.0.1/24\n"
|
||||||
|
hash := HashConfig(normalized)
|
||||||
|
assert.Len(t, hash, 64, "SHA256 hex should be 64 chars")
|
||||||
|
assert.Equal(t, strings.ToLower(hash), hash, "hash should be lowercase")
|
||||||
|
// Deterministic
|
||||||
|
assert.Equal(t, hash, HashConfig(normalized))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeConfig_Idempotency(t *testing.T) {
|
||||||
|
input := "# 2024/01/15 10:30:00 by RouterOS 7.14\n" +
|
||||||
|
"# software id = ABC123\r\n" +
|
||||||
|
"/ip address \r\n" +
|
||||||
|
"\r\n\r\n\r\n" +
|
||||||
|
"/ip route\r\n"
|
||||||
|
|
||||||
|
first := NormalizeConfig(input)
|
||||||
|
second := NormalizeConfig(first)
|
||||||
|
assert.Equal(t, first, second, "NormalizeConfig should be idempotent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizationVersion(t *testing.T) {
|
||||||
|
require.Equal(t, 1, NormalizationVersion, "NormalizationVersion should be 1")
|
||||||
|
}
|
||||||
276
poller/internal/device/ssh_executor.go
Normal file
276
poller/internal/device/ssh_executor.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHErrorKind classifies SSH connection and command errors.
|
||||||
|
type SSHErrorKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ErrAuthFailed indicates the SSH credentials were rejected.
|
||||||
|
ErrAuthFailed SSHErrorKind = "auth_failed"
|
||||||
|
// ErrHostKeyMismatch indicates a TOFU host key verification failure.
|
||||||
|
ErrHostKeyMismatch SSHErrorKind = "host_key_mismatch"
|
||||||
|
// ErrTimeout indicates the connection or command timed out.
|
||||||
|
ErrTimeout SSHErrorKind = "timeout"
|
||||||
|
// ErrTruncatedOutput indicates the command timed out mid-stream, producing partial output.
|
||||||
|
ErrTruncatedOutput SSHErrorKind = "truncated_output"
|
||||||
|
// ErrConnectionRefused indicates the remote host refused the TCP connection.
|
||||||
|
ErrConnectionRefused SSHErrorKind = "connection_refused"
|
||||||
|
// ErrUnknown indicates an unclassified error.
|
||||||
|
ErrUnknown SSHErrorKind = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSHError wraps an SSH-related error with a classification kind.
|
||||||
|
type SSHError struct {
|
||||||
|
Kind SSHErrorKind
|
||||||
|
Err error
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface.
|
||||||
|
func (e *SSHError) Error() string {
|
||||||
|
if e.Err != nil {
|
||||||
|
return fmt.Sprintf("%s: %s", e.Message, e.Err.Error())
|
||||||
|
}
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying error for errors.Is/As support.
|
||||||
|
func (e *SSHError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommandResult holds the output of a remote SSH command execution.
|
||||||
|
type CommandResult struct {
|
||||||
|
Stdout string
|
||||||
|
Stderr string
|
||||||
|
ExitCode int
|
||||||
|
Duration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunCommand executes a command on a remote device via SSH with TOFU host key verification.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - knownFingerprint: empty string for first connect (TOFU accepts any key), or a
|
||||||
|
// previously stored "SHA256:base64(...)" fingerprint for verification.
|
||||||
|
// - command: the RouterOS CLI command to execute (e.g., "/export")
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - result: command output (stdout, stderr, exit code, duration)
|
||||||
|
// - observedFingerprint: the SSH host key fingerprint observed during connection
|
||||||
|
// - err: classified SSHError on failure, nil on success
|
||||||
|
func RunCommand(ctx context.Context, ip string, port int, username, password string,
|
||||||
|
timeout time.Duration, knownFingerprint string, command string) (*CommandResult, string, error) {
|
||||||
|
|
||||||
|
cb, fpCh := tofuHostKeyCallback(knownFingerprint)
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: username,
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.Password(password),
|
||||||
|
},
|
||||||
|
HostKeyCallback: cb,
|
||||||
|
Timeout: timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||||
|
|
||||||
|
// Context-aware dial
|
||||||
|
var d net.Dialer
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", &SSHError{
|
||||||
|
Kind: classifySSHError(err),
|
||||||
|
Err: err,
|
||||||
|
Message: fmt.Sprintf("TCP dial to %s failed", addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH handshake over the raw connection
|
||||||
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, "", &SSHError{
|
||||||
|
Kind: classifySSHError(err),
|
||||||
|
Err: err,
|
||||||
|
Message: fmt.Sprintf("SSH handshake to %s failed", addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := ssh.NewClient(sshConn, chans, reqs)
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
// Read the observed fingerprint (will be available after handshake)
|
||||||
|
var observedFP string
|
||||||
|
select {
|
||||||
|
case fp := <-fpCh:
|
||||||
|
observedFP = fp
|
||||||
|
default:
|
||||||
|
// Channel already drained or callback didn't fire (shouldn't happen)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, observedFP, &SSHError{
|
||||||
|
Kind: ErrUnknown,
|
||||||
|
Err: err,
|
||||||
|
Message: "creating SSH session failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
session.Stdout = &stdout
|
||||||
|
session.Stderr = &stderr
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// Run with context cancellation for timeout detection
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- session.Run(command)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var runErr error
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context cancelled/timed out mid-execution
|
||||||
|
session.Close()
|
||||||
|
return &CommandResult{
|
||||||
|
Stdout: stdout.String(),
|
||||||
|
Stderr: stderr.String(),
|
||||||
|
ExitCode: -1,
|
||||||
|
Duration: time.Since(start),
|
||||||
|
}, observedFP, &SSHError{
|
||||||
|
Kind: ErrTruncatedOutput,
|
||||||
|
Err: ctx.Err(),
|
||||||
|
Message: "command timed out mid-execution, output may be truncated",
|
||||||
|
}
|
||||||
|
case runErr = <-done:
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
result := &CommandResult{
|
||||||
|
Stdout: stdout.String(),
|
||||||
|
Stderr: stderr.String(),
|
||||||
|
ExitCode: 0,
|
||||||
|
Duration: duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if runErr != nil {
|
||||||
|
// Check for exit status errors
|
||||||
|
if exitErr, ok := runErr.(*ssh.ExitError); ok {
|
||||||
|
result.ExitCode = exitErr.ExitStatus()
|
||||||
|
return result, observedFP, nil // Non-zero exit is not an SSH error
|
||||||
|
}
|
||||||
|
return result, observedFP, &SSHError{
|
||||||
|
Kind: classifySSHError(runErr),
|
||||||
|
Err: runErr,
|
||||||
|
Message: "SSH command execution failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, observedFP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tofuHostKeyCallback returns an SSH host key callback implementing Trust-On-First-Use.
|
||||||
|
//
|
||||||
|
// If knownFingerprint is empty, any key is accepted and its fingerprint is sent on the channel.
|
||||||
|
// If knownFingerprint matches the presented key, the connection is accepted.
|
||||||
|
// If knownFingerprint does not match, the connection is rejected with ErrHostKeyMismatch.
|
||||||
|
func tofuHostKeyCallback(knownFingerprint string) (ssh.HostKeyCallback, chan string) {
|
||||||
|
fpCh := make(chan string, 1)
|
||||||
|
|
||||||
|
cb := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
|
fp := computeFingerprint(key)
|
||||||
|
|
||||||
|
if knownFingerprint == "" {
|
||||||
|
// First connect: accept and report fingerprint
|
||||||
|
fpCh <- fp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fpCh <- fp
|
||||||
|
|
||||||
|
if fp != knownFingerprint {
|
||||||
|
return &SSHError{
|
||||||
|
Kind: ErrHostKeyMismatch,
|
||||||
|
Err: fmt.Errorf("expected %s, got %s", knownFingerprint, fp),
|
||||||
|
Message: fmt.Sprintf("host key mismatch for %s", hostname),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cb, fpCh
|
||||||
|
}
|
||||||
|
|
||||||
|
// computeFingerprint computes the SSH host key fingerprint in the same format as ssh-keygen:
|
||||||
|
// "SHA256:" followed by the base64-encoded (no padding) SHA256 hash of the public key bytes.
|
||||||
|
func computeFingerprint(key ssh.PublicKey) string {
|
||||||
|
h := sha256.Sum256(key.Marshal())
|
||||||
|
return "SHA256:" + base64.RawStdEncoding.EncodeToString(h[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifySSHError inspects an error and returns the appropriate SSHErrorKind.
|
||||||
|
func classifySSHError(err error) SSHErrorKind {
|
||||||
|
if err == nil {
|
||||||
|
return ErrUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
errStr := err.Error()
|
||||||
|
|
||||||
|
// Check for timeout (net.Error interface)
|
||||||
|
var netErr net.Error
|
||||||
|
if ok := errorAs(err, &netErr); ok && netErr.Timeout() {
|
||||||
|
return ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "i/o timeout") {
|
||||||
|
return ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "unable to authenticate") ||
|
||||||
|
strings.Contains(errStr, "no supported methods remain") {
|
||||||
|
return ErrAuthFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "host key") {
|
||||||
|
return ErrHostKeyMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(errStr, "connection refused") {
|
||||||
|
return ErrConnectionRefused
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorAs is a helper that wraps errors.As for interface targets.
|
||||||
|
func errorAs[T any](err error, target *T) bool {
|
||||||
|
for err != nil {
|
||||||
|
if t, ok := err.(T); ok {
|
||||||
|
*target = t
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if u, ok := err.(interface{ Unwrap() error }); ok {
|
||||||
|
err = u.Unwrap()
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
128
poller/internal/device/ssh_executor_test.go
Normal file
128
poller/internal/device/ssh_executor_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClassifySSHError_AuthFailed(t *testing.T) {
|
||||||
|
err := fmt.Errorf("ssh: unable to authenticate")
|
||||||
|
kind := classifySSHError(err)
|
||||||
|
assert.Equal(t, ErrAuthFailed, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifySSHError_HostKeyMismatch(t *testing.T) {
|
||||||
|
err := fmt.Errorf("ssh: host key mismatch")
|
||||||
|
kind := classifySSHError(err)
|
||||||
|
assert.Equal(t, ErrHostKeyMismatch, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifySSHError_Timeout(t *testing.T) {
|
||||||
|
err := &net.OpError{
|
||||||
|
Op: "dial",
|
||||||
|
Err: &timeoutError{},
|
||||||
|
}
|
||||||
|
kind := classifySSHError(err)
|
||||||
|
assert.Equal(t, ErrTimeout, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifySSHError_ConnectionRefused(t *testing.T) {
|
||||||
|
err := fmt.Errorf("dial tcp 10.0.0.1:22: connection refused")
|
||||||
|
kind := classifySSHError(err)
|
||||||
|
assert.Equal(t, ErrConnectionRefused, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifySSHError_Unknown(t *testing.T) {
|
||||||
|
err := fmt.Errorf("some random error")
|
||||||
|
kind := classifySSHError(err)
|
||||||
|
assert.Equal(t, ErrUnknown, kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHError_Error(t *testing.T) {
|
||||||
|
sshErr := &SSHError{
|
||||||
|
Kind: ErrAuthFailed,
|
||||||
|
Err: fmt.Errorf("underlying"),
|
||||||
|
Message: "auth failed for device",
|
||||||
|
}
|
||||||
|
assert.Contains(t, sshErr.Error(), "auth failed for device")
|
||||||
|
assert.Contains(t, sshErr.Error(), "underlying")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHError_Unwrap(t *testing.T) {
|
||||||
|
inner := fmt.Errorf("inner error")
|
||||||
|
sshErr := &SSHError{
|
||||||
|
Kind: ErrUnknown,
|
||||||
|
Err: inner,
|
||||||
|
}
|
||||||
|
assert.True(t, errors.Is(sshErr, inner))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommandResult_Fields(t *testing.T) {
|
||||||
|
result := &CommandResult{
|
||||||
|
Stdout: "output",
|
||||||
|
Stderr: "err",
|
||||||
|
ExitCode: 0,
|
||||||
|
}
|
||||||
|
require.NotNil(t, result)
|
||||||
|
assert.Equal(t, "output", result.Stdout)
|
||||||
|
assert.Equal(t, "err", result.Stderr)
|
||||||
|
assert.Equal(t, 0, result.ExitCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOFUCallback_FirstConnect(t *testing.T) {
|
||||||
|
cb, fpCh := tofuHostKeyCallback("")
|
||||||
|
// Simulate first connect with any key
|
||||||
|
key := generateTestPublicKey(t)
|
||||||
|
err := cb("10.0.0.1:22", nil, key)
|
||||||
|
assert.NoError(t, err, "first connect should accept any key")
|
||||||
|
|
||||||
|
fp := <-fpCh
|
||||||
|
assert.NotEmpty(t, fp, "should return a fingerprint")
|
||||||
|
assert.Contains(t, fp, "SHA256:", "fingerprint should have SHA256 prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOFUCallback_MatchingFingerprint(t *testing.T) {
|
||||||
|
key := generateTestPublicKey(t)
|
||||||
|
fp := computeFingerprint(key)
|
||||||
|
|
||||||
|
cb, _ := tofuHostKeyCallback(fp)
|
||||||
|
err := cb("10.0.0.1:22", nil, key)
|
||||||
|
assert.NoError(t, err, "matching fingerprint should be accepted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTOFUCallback_MismatchedFingerprint(t *testing.T) {
|
||||||
|
key := generateTestPublicKey(t)
|
||||||
|
|
||||||
|
cb, _ := tofuHostKeyCallback("SHA256:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
|
||||||
|
err := cb("10.0.0.1:22", nil, key)
|
||||||
|
require.Error(t, err, "mismatched fingerprint should be rejected")
|
||||||
|
|
||||||
|
var sshErr *SSHError
|
||||||
|
require.True(t, errors.As(err, &sshErr))
|
||||||
|
assert.Equal(t, ErrHostKeyMismatch, sshErr.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestPublicKey creates an ed25519 public key for testing.
|
||||||
|
func generateTestPublicKey(t *testing.T) ssh.PublicKey {
|
||||||
|
t.Helper()
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pub, err := ssh.NewPublicKey(priv.Public())
|
||||||
|
require.NoError(t, err)
|
||||||
|
return pub
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeoutError implements net.Error with Timeout() returning true.
|
||||||
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
func (e *timeoutError) Error() string { return "i/o timeout" }
|
||||||
|
func (e *timeoutError) Timeout() bool { return true }
|
||||||
|
func (e *timeoutError) Temporary() bool { return false }
|
||||||
Reference in New Issue
Block a user