From c73466c5e05427458498084f7ade026c7dae6d75 Mon Sep 17 00:00:00 2001 From: Jason Staack Date: Thu, 12 Mar 2026 15:33:48 -0500 Subject: [PATCH] feat(poller): add SSH relay server with WebSocket-to-PTY bridge Implements the SSH relay server (Task 2.1) that validates single-use Redis tokens via GETDEL, dials SSH to the target device with PTY, and bridges WebSocket binary/text frames to SSH stdin/stdout/stderr with idle timeout and per-user/per-device session limits. Co-Authored-By: Claude Opus 4.6 --- poller/go.mod | 2 + poller/go.sum | 4 + poller/internal/sshrelay/bridge.go | 75 ++++++ poller/internal/sshrelay/server.go | 337 ++++++++++++++++++++++++ poller/internal/sshrelay/server_test.go | 104 ++++++++ poller/internal/sshrelay/session.go | 28 ++ 6 files changed, 550 insertions(+) create mode 100644 poller/internal/sshrelay/bridge.go create mode 100644 poller/internal/sshrelay/server.go create mode 100644 poller/internal/sshrelay/server_test.go create mode 100644 poller/internal/sshrelay/session.go diff --git a/poller/go.mod b/poller/go.mod index 0ef8243..b9241b6 100644 --- a/poller/go.mod +++ b/poller/go.mod @@ -3,6 +3,7 @@ module github.com/mikrotik-portal/poller go 1.24.0 require ( + github.com/alicebob/miniredis/v2 v2.37.0 github.com/bsm/redislock v0.9.4 github.com/go-routeros/routeros/v3 v3.0.0 github.com/google/uuid v1.6.0 @@ -76,6 +77,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect diff --git a/poller/go.sum b/poller/go.sum index 3e952e0..faec068 100644 --- a/poller/go.sum +++ b/poller/go.sum @@ -6,6 +6,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25 github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= +github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -164,6 +166,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= diff --git a/poller/internal/sshrelay/bridge.go b/poller/internal/sshrelay/bridge.go new file mode 100644 index 0000000..d251ce3 --- /dev/null +++ b/poller/internal/sshrelay/bridge.go @@ -0,0 +1,75 @@ +package sshrelay + +import ( + "context" + "encoding/json" + "io" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" + "nhooyr.io/websocket" +) + +type ControlMsg struct { + Type string `json:"type"` + Cols int `json:"cols"` + Rows int `json:"rows"` +} + +func bridge(ctx context.Context, cancel context.CancelFunc, ws *websocket.Conn, + sshSess *ssh.Session, stdin io.WriteCloser, stdout, stderr io.Reader, lastActive *int64) { + + // WebSocket → SSH stdin + go func() { + defer cancel() + for { + typ, data, err := ws.Read(ctx) + if err != nil { + return + } + atomic.StoreInt64(lastActive, time.Now().UnixNano()) + + if typ == websocket.MessageText { + var ctrl ControlMsg + if json.Unmarshal(data, &ctrl) != nil { + continue + } + if ctrl.Type == "resize" && ctrl.Cols > 0 && ctrl.Cols <= 500 && ctrl.Rows > 0 && ctrl.Rows <= 200 { + sshSess.WindowChange(ctrl.Rows, ctrl.Cols) + } + continue + } + stdin.Write(data) + } + }() + + // SSH stdout → WebSocket + go func() { + defer cancel() + buf := make([]byte, 4096) + for { + n, err := stdout.Read(buf) + if err != nil { + return + } + atomic.StoreInt64(lastActive, time.Now().UnixNano()) + ws.Write(ctx, websocket.MessageBinary, buf[:n]) + } + }() + + // SSH stderr → WebSocket + go func() { + defer cancel() + buf := make([]byte, 4096) + for { + n, err := stderr.Read(buf) + if err != nil { + return + } + ws.Write(ctx, websocket.MessageBinary, buf[:n]) + } + }() + + <-ctx.Done() +} diff --git a/poller/internal/sshrelay/server.go b/poller/internal/sshrelay/server.go new file mode 100644 index 0000000..79e98c1 --- /dev/null +++ b/poller/internal/sshrelay/server.go @@ -0,0 +1,337 @@ +package sshrelay + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + "github.com/mikrotik-portal/poller/internal/store" + "github.com/mikrotik-portal/poller/internal/vault" + "github.com/redis/go-redis/v9" + "golang.org/x/crypto/ssh" + "nhooyr.io/websocket" +) + +// TokenPayload is the JSON structure stored in Redis for a single-use SSH session token. +type TokenPayload struct { + DeviceID string `json:"device_id"` + TenantID string `json:"tenant_id"` + UserID string `json:"user_id"` + SourceIP string `json:"source_ip"` + Cols int `json:"cols"` + Rows int `json:"rows"` + CreatedAt int64 `json:"created_at"` +} + +// Server is the SSH relay WebSocket server. It validates single-use Redis tokens, +// dials SSH to the target device, and bridges WebSocket ↔ SSH PTY. +type Server struct { + redis *redis.Client + credCache *vault.CredentialCache + deviceStore *store.DeviceStore + sessions map[string]*Session + mu sync.Mutex + idleTime time.Duration + maxSessions int + maxPerUser int + maxPerDevice int + cancel context.CancelFunc +} + +// Config holds tunable limits for the SSH relay server. +type Config struct { + IdleTimeout time.Duration + MaxSessions int + MaxPerUser int + MaxPerDevice int +} + +// NewServer creates and starts a new SSH relay server. +func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, cfg Config) *Server { + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + redis: rc, + credCache: cc, + deviceStore: ds, + sessions: make(map[string]*Session), + idleTime: cfg.IdleTimeout, + maxSessions: cfg.MaxSessions, + maxPerUser: cfg.MaxPerUser, + maxPerDevice: cfg.MaxPerDevice, + cancel: cancel, + } + go s.idleLoop(ctx) + return s +} + +// Handler returns the HTTP handler for the SSH relay server. +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ws/ssh", s.handleSSH) + mux.HandleFunc("/healthz", s.handleHealth) + return mux +} + +// Shutdown cancels the idle loop and closes all active sessions. +func (s *Server) Shutdown() { + s.cancel() + s.mu.Lock() + for _, sess := range s.sessions { + sess.cancel() + } + s.mu.Unlock() +} + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status":"ok"}`)) +} + +func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("token") + if token == "" { + http.Error(w, "missing token", http.StatusUnauthorized) + return + } + + // Validate single-use token via Redis GETDEL + payload, err := s.validateToken(r.Context(), token) + if err != nil { + slog.Warn("ssh: token validation failed", "err", err) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check session limits before upgrading + if err := s.checkLimits(payload.UserID, payload.DeviceID); err != nil { + http.Error(w, err.Error(), http.StatusTooManyRequests) + return + } + + // Upgrade to WebSocket + ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, // nginx handles origin validation + }) + if err != nil { + slog.Error("ssh: websocket upgrade failed", "err", err) + return + } + ws.SetReadLimit(1 << 20) + + // Extract source IP (nginx sets X-Real-IP) + sourceIP := r.Header.Get("X-Real-IP") + if sourceIP == "" { + sourceIP = r.RemoteAddr + } + + // Look up device + dev, err := s.deviceStore.GetDevice(r.Context(), payload.DeviceID) + if err != nil { + slog.Error("ssh: device lookup failed", "device_id", payload.DeviceID, "err", err) + ws.Close(websocket.StatusInternalError, "device not found") + return + } + + // Decrypt credentials — GetCredentials returns (username, password, error) + username, password, err := s.credCache.GetCredentials( + dev.ID, + payload.TenantID, + dev.EncryptedCredentialsTransit, + dev.EncryptedCredentials, + ) + if err != nil { + slog.Error("ssh: credential decryption failed", "device_id", payload.DeviceID, "err", err) + ws.Close(websocket.StatusInternalError, "credential error") + return + } + + // SSH dial + sshAddr := dev.IPAddress + ":22" + sshClient, err := ssh.Dial("tcp", sshAddr, &ssh.ClientConfig{ + User: username, + Auth: []ssh.AuthMethod{ssh.Password(password)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + }) + if err != nil { + slog.Error("ssh: dial failed", "device_id", payload.DeviceID, "addr", sshAddr, "err", err) + ws.Close(websocket.StatusInternalError, "ssh connection failed") + return + } + + sshSess, err := sshClient.NewSession() + if err != nil { + sshClient.Close() + ws.Close(websocket.StatusInternalError, "ssh session failed") + return + } + + cols, rows := payload.Cols, payload.Rows + if cols <= 0 { + cols = 80 + } + if rows <= 0 { + rows = 24 + } + + if err := sshSess.RequestPty("xterm-256color", rows, cols, ssh.TerminalModes{ + ssh.ECHO: 1, + }); err != nil { + sshSess.Close() + sshClient.Close() + ws.Close(websocket.StatusInternalError, "pty request failed") + return + } + + stdin, _ := sshSess.StdinPipe() + stdout, _ := sshSess.StdoutPipe() + stderr, _ := sshSess.StderrPipe() + + if err := sshSess.Shell(); err != nil { + sshSess.Close() + sshClient.Close() + ws.Close(websocket.StatusInternalError, "shell start failed") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + sess := &Session{ + ID: uuid.New().String(), + DeviceID: payload.DeviceID, + TenantID: payload.TenantID, + UserID: payload.UserID, + SourceIP: sourceIP, + StartTime: time.Now(), + LastActive: time.Now().UnixNano(), + sshClient: sshClient, + sshSession: sshSess, + ptyCols: cols, + ptyRows: rows, + cancel: cancel, + } + + s.mu.Lock() + s.sessions[sess.ID] = sess + s.mu.Unlock() + + slog.Info("ssh session started", + "session_id", sess.ID, + "device_id", payload.DeviceID, + "tenant_id", payload.TenantID, + "user_id", payload.UserID, + "source_ip", sourceIP, + ) + + // Bridge WebSocket ↔ SSH (blocks until session ends) + bridge(ctx, cancel, ws, sshSess, stdin, stdout, stderr, &sess.LastActive) + + // Cleanup + ws.Close(websocket.StatusNormalClosure, "session ended") + sshSess.Close() + sshClient.Close() + + s.mu.Lock() + delete(s.sessions, sess.ID) + s.mu.Unlock() + + duration := time.Since(sess.StartTime) + slog.Info("ssh session ended", + "session_id", sess.ID, + "device_id", payload.DeviceID, + "duration", duration.String(), + ) +} + +// validateToken performs a Redis GETDEL to atomically consume a single-use token. +func (s *Server) validateToken(ctx context.Context, token string) (*TokenPayload, error) { + key := "ssh:token:" + token + val, err := s.redis.GetDel(ctx, key).Result() + if err != nil { + return nil, fmt.Errorf("token not found or expired") + } + var payload TokenPayload + if err := json.Unmarshal([]byte(val), &payload); err != nil { + return nil, fmt.Errorf("invalid token payload") + } + return &payload, nil +} + +// checkLimits returns an error if any session limit would be exceeded. +func (s *Server) checkLimits(userID, deviceID string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.sessions) >= s.maxSessions { + return fmt.Errorf("max sessions exceeded") + } + + userCount := 0 + deviceCount := 0 + for _, sess := range s.sessions { + if sess.UserID == userID { + userCount++ + } + if sess.DeviceID == deviceID { + deviceCount++ + } + } + if userCount >= s.maxPerUser { + return fmt.Errorf("max sessions per user exceeded") + } + if deviceCount >= s.maxPerDevice { + return fmt.Errorf("max sessions per device exceeded") + } + return nil +} + +func (s *Server) idleLoop(ctx context.Context) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.cleanupIdle() + } + } +} + +func (s *Server) cleanupIdle() { + s.mu.Lock() + var toCancel []*Session + for _, sess := range s.sessions { + if sess.IdleDuration() > s.idleTime { + toCancel = append(toCancel, sess) + } + } + s.mu.Unlock() + + for _, sess := range toCancel { + slog.Info("ssh session idle timeout", "session_id", sess.ID) + sess.cancel() + } +} + +// SessionList returns active SSH sessions for a given device (used by admin APIs). +func (s *Server) SessionList(deviceID string) []map[string]interface{} { + s.mu.Lock() + defer s.mu.Unlock() + var out []map[string]interface{} + for _, sess := range s.sessions { + if sess.DeviceID == deviceID { + out = append(out, map[string]interface{}{ + "session_id": sess.ID, + "idle_seconds": int(sess.IdleDuration().Seconds()), + "created_at": sess.StartTime.Format(time.RFC3339), + }) + } + } + return out +} diff --git a/poller/internal/sshrelay/server_test.go b/poller/internal/sshrelay/server_test.go new file mode 100644 index 0000000..75c8e0c --- /dev/null +++ b/poller/internal/sshrelay/server_test.go @@ -0,0 +1,104 @@ +package sshrelay + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) { + t.Helper() + mr := miniredis.RunT(t) + rc := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return rc, mr +} + +func TestValidateToken_Valid(t *testing.T) { + rc, _ := setupRedis(t) + s := &Server{redis: rc, sessions: make(map[string]*Session)} + + payload := TokenPayload{DeviceID: "d1", TenantID: "t1", UserID: "u1", Cols: 80, Rows: 24, CreatedAt: time.Now().Unix()} + data, _ := json.Marshal(payload) + rc.Set(context.Background(), "ssh:token:abc123", string(data), 120*time.Second) + + result, err := s.validateToken(context.Background(), "abc123") + require.NoError(t, err) + assert.Equal(t, "d1", result.DeviceID) + + // Token consumed — second use should fail + _, err = s.validateToken(context.Background(), "abc123") + assert.Error(t, err) +} + +func TestValidateToken_Expired(t *testing.T) { + rc, mr := setupRedis(t) + s := &Server{redis: rc, sessions: make(map[string]*Session)} + + payload := TokenPayload{DeviceID: "d1", TenantID: "t1", UserID: "u1"} + data, _ := json.Marshal(payload) + rc.Set(context.Background(), "ssh:token:expired", string(data), 1*time.Millisecond) + mr.FastForward(2 * time.Second) + + _, err := s.validateToken(context.Background(), "expired") + assert.Error(t, err) +} + +func TestCheckLimits_MaxSessions(t *testing.T) { + s := &Server{ + sessions: make(map[string]*Session), + maxSessions: 2, + maxPerUser: 10, + maxPerDevice: 10, + } + s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"} + s.sessions["s2"] = &Session{UserID: "u2", DeviceID: "d2"} + + err := s.checkLimits("u3", "d3") + assert.Error(t, err) + assert.Contains(t, err.Error(), "max sessions exceeded") +} + +func TestCheckLimits_MaxPerUser(t *testing.T) { + s := &Server{ + sessions: make(map[string]*Session), + maxSessions: 100, + maxPerUser: 2, + maxPerDevice: 100, + } + s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"} + s.sessions["s2"] = &Session{UserID: "u1", DeviceID: "d2"} + + err := s.checkLimits("u1", "d3") + assert.Error(t, err) + assert.Contains(t, err.Error(), "per user") +} + +func TestCheckLimits_MaxPerDevice(t *testing.T) { + s := &Server{ + sessions: make(map[string]*Session), + maxSessions: 100, + maxPerUser: 100, + maxPerDevice: 1, + } + s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"} + + err := s.checkLimits("u2", "d1") + assert.Error(t, err) + assert.Contains(t, err.Error(), "per device") +} + +func TestSessionList(t *testing.T) { + s := &Server{sessions: make(map[string]*Session)} + s.sessions["s1"] = &Session{ID: "s1", DeviceID: "d1", StartTime: time.Now(), LastActive: time.Now().UnixNano()} + s.sessions["s2"] = &Session{ID: "s2", DeviceID: "d1", StartTime: time.Now(), LastActive: time.Now().UnixNano()} + s.sessions["s3"] = &Session{ID: "s3", DeviceID: "d2", StartTime: time.Now(), LastActive: time.Now().UnixNano()} + + list := s.SessionList("d1") + assert.Len(t, list, 2) +} diff --git a/poller/internal/sshrelay/session.go b/poller/internal/sshrelay/session.go new file mode 100644 index 0000000..5c54324 --- /dev/null +++ b/poller/internal/sshrelay/session.go @@ -0,0 +1,28 @@ +package sshrelay + +import ( + "context" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" +) + +type Session struct { + ID string + DeviceID string + TenantID string + UserID string + SourceIP string + StartTime time.Time + LastActive int64 // atomic, unix nanoseconds + sshClient *ssh.Client + sshSession *ssh.Session + ptyCols int + ptyRows int + cancel context.CancelFunc +} + +func (s *Session) IdleDuration() time.Duration { + return time.Since(time.Unix(0, atomic.LoadInt64(&s.LastActive))) +}