feat(poller): add SSH relay server with WebSocket-to-PTY bridge
Implements the SSH relay server (Task 2.1) that validates single-use Redis tokens via GETDEL, dials SSH to the target device with PTY, and bridges WebSocket binary/text frames to SSH stdin/stdout/stderr with idle timeout and per-user/per-device session limits. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ module github.com/mikrotik-portal/poller
|
|||||||
go 1.24.0
|
go 1.24.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/alicebob/miniredis/v2 v2.37.0
|
||||||
github.com/bsm/redislock v0.9.4
|
github.com/bsm/redislock v0.9.4
|
||||||
github.com/go-routeros/routeros/v3 v3.0.0
|
github.com/go-routeros/routeros/v3 v3.0.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
@@ -76,6 +77,7 @@ require (
|
|||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
|
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25
|
|||||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68=
|
||||||
|
github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
@@ -164,6 +166,8 @@ github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFA
|
|||||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
|
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||||
|
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
|
|||||||
75
poller/internal/sshrelay/bridge.go
Normal file
75
poller/internal/sshrelay/bridge.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package sshrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ControlMsg struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Cols int `json:"cols"`
|
||||||
|
Rows int `json:"rows"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func bridge(ctx context.Context, cancel context.CancelFunc, ws *websocket.Conn,
|
||||||
|
sshSess *ssh.Session, stdin io.WriteCloser, stdout, stderr io.Reader, lastActive *int64) {
|
||||||
|
|
||||||
|
// WebSocket → SSH stdin
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
for {
|
||||||
|
typ, data, err := ws.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(lastActive, time.Now().UnixNano())
|
||||||
|
|
||||||
|
if typ == websocket.MessageText {
|
||||||
|
var ctrl ControlMsg
|
||||||
|
if json.Unmarshal(data, &ctrl) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ctrl.Type == "resize" && ctrl.Cols > 0 && ctrl.Cols <= 500 && ctrl.Rows > 0 && ctrl.Rows <= 200 {
|
||||||
|
sshSess.WindowChange(ctrl.Rows, ctrl.Cols)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stdin.Write(data)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// SSH stdout → WebSocket
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := stdout.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(lastActive, time.Now().UnixNano())
|
||||||
|
ws.Write(ctx, websocket.MessageBinary, buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// SSH stderr → WebSocket
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := stderr.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ws.Write(ctx, websocket.MessageBinary, buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-ctx.Done()
|
||||||
|
}
|
||||||
337
poller/internal/sshrelay/server.go
Normal file
337
poller/internal/sshrelay/server.go
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
package sshrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/mikrotik-portal/poller/internal/store"
|
||||||
|
"github.com/mikrotik-portal/poller/internal/vault"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenPayload is the JSON structure stored in Redis for a single-use SSH session token.
|
||||||
|
type TokenPayload struct {
|
||||||
|
DeviceID string `json:"device_id"`
|
||||||
|
TenantID string `json:"tenant_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
SourceIP string `json:"source_ip"`
|
||||||
|
Cols int `json:"cols"`
|
||||||
|
Rows int `json:"rows"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server is the SSH relay WebSocket server. It validates single-use Redis tokens,
|
||||||
|
// dials SSH to the target device, and bridges WebSocket ↔ SSH PTY.
|
||||||
|
type Server struct {
|
||||||
|
redis *redis.Client
|
||||||
|
credCache *vault.CredentialCache
|
||||||
|
deviceStore *store.DeviceStore
|
||||||
|
sessions map[string]*Session
|
||||||
|
mu sync.Mutex
|
||||||
|
idleTime time.Duration
|
||||||
|
maxSessions int
|
||||||
|
maxPerUser int
|
||||||
|
maxPerDevice int
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds tunable limits for the SSH relay server.
|
||||||
|
type Config struct {
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
MaxSessions int
|
||||||
|
MaxPerUser int
|
||||||
|
MaxPerDevice int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates and starts a new SSH relay server.
|
||||||
|
func NewServer(rc *redis.Client, cc *vault.CredentialCache, ds *store.DeviceStore, cfg Config) *Server {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s := &Server{
|
||||||
|
redis: rc,
|
||||||
|
credCache: cc,
|
||||||
|
deviceStore: ds,
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
idleTime: cfg.IdleTimeout,
|
||||||
|
maxSessions: cfg.MaxSessions,
|
||||||
|
maxPerUser: cfg.MaxPerUser,
|
||||||
|
maxPerDevice: cfg.MaxPerDevice,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
go s.idleLoop(ctx)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns the HTTP handler for the SSH relay server.
|
||||||
|
func (s *Server) Handler() http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/ws/ssh", s.handleSSH)
|
||||||
|
mux.HandleFunc("/healthz", s.handleHealth)
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown cancels the idle loop and closes all active sessions.
|
||||||
|
func (s *Server) Shutdown() {
|
||||||
|
s.cancel()
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
sess.cancel()
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write([]byte(`{"status":"ok"}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSSH(w http.ResponseWriter, r *http.Request) {
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
http.Error(w, "missing token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate single-use token via Redis GETDEL
|
||||||
|
payload, err := s.validateToken(r.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("ssh: token validation failed", "err", err)
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check session limits before upgrading
|
||||||
|
if err := s.checkLimits(payload.UserID, payload.DeviceID); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade to WebSocket
|
||||||
|
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: []string{"*"}, // nginx handles origin validation
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("ssh: websocket upgrade failed", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ws.SetReadLimit(1 << 20)
|
||||||
|
|
||||||
|
// Extract source IP (nginx sets X-Real-IP)
|
||||||
|
sourceIP := r.Header.Get("X-Real-IP")
|
||||||
|
if sourceIP == "" {
|
||||||
|
sourceIP = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up device
|
||||||
|
dev, err := s.deviceStore.GetDevice(r.Context(), payload.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("ssh: device lookup failed", "device_id", payload.DeviceID, "err", err)
|
||||||
|
ws.Close(websocket.StatusInternalError, "device not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt credentials — GetCredentials returns (username, password, error)
|
||||||
|
username, password, err := s.credCache.GetCredentials(
|
||||||
|
dev.ID,
|
||||||
|
payload.TenantID,
|
||||||
|
dev.EncryptedCredentialsTransit,
|
||||||
|
dev.EncryptedCredentials,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("ssh: credential decryption failed", "device_id", payload.DeviceID, "err", err)
|
||||||
|
ws.Close(websocket.StatusInternalError, "credential error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSH dial
|
||||||
|
sshAddr := dev.IPAddress + ":22"
|
||||||
|
sshClient, err := ssh.Dial("tcp", sshAddr, &ssh.ClientConfig{
|
||||||
|
User: username,
|
||||||
|
Auth: []ssh.AuthMethod{ssh.Password(password)},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("ssh: dial failed", "device_id", payload.DeviceID, "addr", sshAddr, "err", err)
|
||||||
|
ws.Close(websocket.StatusInternalError, "ssh connection failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshSess, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
sshClient.Close()
|
||||||
|
ws.Close(websocket.StatusInternalError, "ssh session failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cols, rows := payload.Cols, payload.Rows
|
||||||
|
if cols <= 0 {
|
||||||
|
cols = 80
|
||||||
|
}
|
||||||
|
if rows <= 0 {
|
||||||
|
rows = 24
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sshSess.RequestPty("xterm-256color", rows, cols, ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 1,
|
||||||
|
}); err != nil {
|
||||||
|
sshSess.Close()
|
||||||
|
sshClient.Close()
|
||||||
|
ws.Close(websocket.StatusInternalError, "pty request failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, _ := sshSess.StdinPipe()
|
||||||
|
stdout, _ := sshSess.StdoutPipe()
|
||||||
|
stderr, _ := sshSess.StderrPipe()
|
||||||
|
|
||||||
|
if err := sshSess.Shell(); err != nil {
|
||||||
|
sshSess.Close()
|
||||||
|
sshClient.Close()
|
||||||
|
ws.Close(websocket.StatusInternalError, "shell start failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
sess := &Session{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
DeviceID: payload.DeviceID,
|
||||||
|
TenantID: payload.TenantID,
|
||||||
|
UserID: payload.UserID,
|
||||||
|
SourceIP: sourceIP,
|
||||||
|
StartTime: time.Now(),
|
||||||
|
LastActive: time.Now().UnixNano(),
|
||||||
|
sshClient: sshClient,
|
||||||
|
sshSession: sshSess,
|
||||||
|
ptyCols: cols,
|
||||||
|
ptyRows: rows,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.sessions[sess.ID] = sess
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
slog.Info("ssh session started",
|
||||||
|
"session_id", sess.ID,
|
||||||
|
"device_id", payload.DeviceID,
|
||||||
|
"tenant_id", payload.TenantID,
|
||||||
|
"user_id", payload.UserID,
|
||||||
|
"source_ip", sourceIP,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Bridge WebSocket ↔ SSH (blocks until session ends)
|
||||||
|
bridge(ctx, cancel, ws, sshSess, stdin, stdout, stderr, &sess.LastActive)
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
ws.Close(websocket.StatusNormalClosure, "session ended")
|
||||||
|
sshSess.Close()
|
||||||
|
sshClient.Close()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
delete(s.sessions, sess.ID)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
duration := time.Since(sess.StartTime)
|
||||||
|
slog.Info("ssh session ended",
|
||||||
|
"session_id", sess.ID,
|
||||||
|
"device_id", payload.DeviceID,
|
||||||
|
"duration", duration.String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateToken performs a Redis GETDEL to atomically consume a single-use token.
|
||||||
|
func (s *Server) validateToken(ctx context.Context, token string) (*TokenPayload, error) {
|
||||||
|
key := "ssh:token:" + token
|
||||||
|
val, err := s.redis.GetDel(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token not found or expired")
|
||||||
|
}
|
||||||
|
var payload TokenPayload
|
||||||
|
if err := json.Unmarshal([]byte(val), &payload); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid token payload")
|
||||||
|
}
|
||||||
|
return &payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkLimits returns an error if any session limit would be exceeded.
|
||||||
|
func (s *Server) checkLimits(userID, deviceID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if len(s.sessions) >= s.maxSessions {
|
||||||
|
return fmt.Errorf("max sessions exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
userCount := 0
|
||||||
|
deviceCount := 0
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
if sess.UserID == userID {
|
||||||
|
userCount++
|
||||||
|
}
|
||||||
|
if sess.DeviceID == deviceID {
|
||||||
|
deviceCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if userCount >= s.maxPerUser {
|
||||||
|
return fmt.Errorf("max sessions per user exceeded")
|
||||||
|
}
|
||||||
|
if deviceCount >= s.maxPerDevice {
|
||||||
|
return fmt.Errorf("max sessions per device exceeded")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) idleLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.cleanupIdle()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) cleanupIdle() {
|
||||||
|
s.mu.Lock()
|
||||||
|
var toCancel []*Session
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
if sess.IdleDuration() > s.idleTime {
|
||||||
|
toCancel = append(toCancel, sess)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, sess := range toCancel {
|
||||||
|
slog.Info("ssh session idle timeout", "session_id", sess.ID)
|
||||||
|
sess.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionList returns active SSH sessions for a given device (used by admin APIs).
|
||||||
|
func (s *Server) SessionList(deviceID string) []map[string]interface{} {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var out []map[string]interface{}
|
||||||
|
for _, sess := range s.sessions {
|
||||||
|
if sess.DeviceID == deviceID {
|
||||||
|
out = append(out, map[string]interface{}{
|
||||||
|
"session_id": sess.ID,
|
||||||
|
"idle_seconds": int(sess.IdleDuration().Seconds()),
|
||||||
|
"created_at": sess.StartTime.Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
104
poller/internal/sshrelay/server_test.go
Normal file
104
poller/internal/sshrelay/server_test.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package sshrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupRedis(t *testing.T) (*redis.Client, *miniredis.Miniredis) {
|
||||||
|
t.Helper()
|
||||||
|
mr := miniredis.RunT(t)
|
||||||
|
rc := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
return rc, mr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateToken_Valid(t *testing.T) {
|
||||||
|
rc, _ := setupRedis(t)
|
||||||
|
s := &Server{redis: rc, sessions: make(map[string]*Session)}
|
||||||
|
|
||||||
|
payload := TokenPayload{DeviceID: "d1", TenantID: "t1", UserID: "u1", Cols: 80, Rows: 24, CreatedAt: time.Now().Unix()}
|
||||||
|
data, _ := json.Marshal(payload)
|
||||||
|
rc.Set(context.Background(), "ssh:token:abc123", string(data), 120*time.Second)
|
||||||
|
|
||||||
|
result, err := s.validateToken(context.Background(), "abc123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "d1", result.DeviceID)
|
||||||
|
|
||||||
|
// Token consumed — second use should fail
|
||||||
|
_, err = s.validateToken(context.Background(), "abc123")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateToken_Expired(t *testing.T) {
|
||||||
|
rc, mr := setupRedis(t)
|
||||||
|
s := &Server{redis: rc, sessions: make(map[string]*Session)}
|
||||||
|
|
||||||
|
payload := TokenPayload{DeviceID: "d1", TenantID: "t1", UserID: "u1"}
|
||||||
|
data, _ := json.Marshal(payload)
|
||||||
|
rc.Set(context.Background(), "ssh:token:expired", string(data), 1*time.Millisecond)
|
||||||
|
mr.FastForward(2 * time.Second)
|
||||||
|
|
||||||
|
_, err := s.validateToken(context.Background(), "expired")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLimits_MaxSessions(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
maxSessions: 2,
|
||||||
|
maxPerUser: 10,
|
||||||
|
maxPerDevice: 10,
|
||||||
|
}
|
||||||
|
s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"}
|
||||||
|
s.sessions["s2"] = &Session{UserID: "u2", DeviceID: "d2"}
|
||||||
|
|
||||||
|
err := s.checkLimits("u3", "d3")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "max sessions exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLimits_MaxPerUser(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
maxSessions: 100,
|
||||||
|
maxPerUser: 2,
|
||||||
|
maxPerDevice: 100,
|
||||||
|
}
|
||||||
|
s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"}
|
||||||
|
s.sessions["s2"] = &Session{UserID: "u1", DeviceID: "d2"}
|
||||||
|
|
||||||
|
err := s.checkLimits("u1", "d3")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "per user")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckLimits_MaxPerDevice(t *testing.T) {
|
||||||
|
s := &Server{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
maxSessions: 100,
|
||||||
|
maxPerUser: 100,
|
||||||
|
maxPerDevice: 1,
|
||||||
|
}
|
||||||
|
s.sessions["s1"] = &Session{UserID: "u1", DeviceID: "d1"}
|
||||||
|
|
||||||
|
err := s.checkLimits("u2", "d1")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "per device")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionList(t *testing.T) {
|
||||||
|
s := &Server{sessions: make(map[string]*Session)}
|
||||||
|
s.sessions["s1"] = &Session{ID: "s1", DeviceID: "d1", StartTime: time.Now(), LastActive: time.Now().UnixNano()}
|
||||||
|
s.sessions["s2"] = &Session{ID: "s2", DeviceID: "d1", StartTime: time.Now(), LastActive: time.Now().UnixNano()}
|
||||||
|
s.sessions["s3"] = &Session{ID: "s3", DeviceID: "d2", StartTime: time.Now(), LastActive: time.Now().UnixNano()}
|
||||||
|
|
||||||
|
list := s.SessionList("d1")
|
||||||
|
assert.Len(t, list, 2)
|
||||||
|
}
|
||||||
28
poller/internal/sshrelay/session.go
Normal file
28
poller/internal/sshrelay/session.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package sshrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string
|
||||||
|
DeviceID string
|
||||||
|
TenantID string
|
||||||
|
UserID string
|
||||||
|
SourceIP string
|
||||||
|
StartTime time.Time
|
||||||
|
LastActive int64 // atomic, unix nanoseconds
|
||||||
|
sshClient *ssh.Client
|
||||||
|
sshSession *ssh.Session
|
||||||
|
ptyCols int
|
||||||
|
ptyRows int
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) IdleDuration() time.Duration {
|
||||||
|
return time.Since(time.Unix(0, atomic.LoadInt64(&s.LastActive)))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user