diff --git a/poller/internal/tunnel/tunnel.go b/poller/internal/tunnel/tunnel.go new file mode 100644 index 0000000..6454088 --- /dev/null +++ b/poller/internal/tunnel/tunnel.go @@ -0,0 +1,109 @@ +package tunnel + +import ( + "context" + "io" + "log/slog" + "net" + "sync" + "sync/atomic" + "time" +) + +// Tunnel represents an active WinBox TCP tunnel to a single router. +type Tunnel struct { + ID string + DeviceID string + TenantID string + UserID string + LocalPort int + RemoteAddr string // router IP:port + CreatedAt time.Time + LastActive int64 // atomic, unix nanoseconds + + listener net.Listener + ctx context.Context + cancel context.CancelFunc + conns sync.WaitGroup + activeConns int64 // atomic +} + +// Close shuts down the tunnel in the correct order. +func (t *Tunnel) Close() { + t.listener.Close() + t.cancel() + t.conns.Wait() + slog.Info("tunnel closed", "tunnel_id", t.ID, "device_id", t.DeviceID, "port", t.LocalPort) +} + +// IdleDuration returns how long the tunnel has been idle. +func (t *Tunnel) IdleDuration() time.Duration { + return time.Since(time.Unix(0, atomic.LoadInt64(&t.LastActive))) +} + +// ActiveConns returns the number of active TCP connections. +func (t *Tunnel) ActiveConns() int64 { + return atomic.LoadInt64(&t.activeConns) +} + +func (t *Tunnel) accept() { + for { + conn, err := t.listener.Accept() + if err != nil { + return // listener closed + } + t.conns.Add(1) + atomic.AddInt64(&t.activeConns, 1) + go t.handleConn(conn) + } +} + +func (t *Tunnel) handleConn(clientConn net.Conn) { + defer t.conns.Done() + defer atomic.AddInt64(&t.activeConns, -1) + + slog.Info("tunnel client connected", "tunnel_id", t.ID, "device_id", t.DeviceID) + + routerConn, err := net.DialTimeout("tcp", t.RemoteAddr, 10*time.Second) + if err != nil { + slog.Warn("tunnel dial failed", "tunnel_id", t.ID, "remote", t.RemoteAddr, "err", err) + clientConn.Close() + return + } + + ctx, cancel := context.WithCancel(t.ctx) + defer cancel() + + go func() { + io.Copy(routerConn, newActivityReader(clientConn, &t.LastActive)) + cancel() + }() + go func() { + io.Copy(clientConn, newActivityReader(routerConn, &t.LastActive)) + cancel() + }() + + <-ctx.Done() + clientConn.Close() + routerConn.Close() + + slog.Info("tunnel client disconnected", "tunnel_id", t.ID, "device_id", t.DeviceID) +} + +// activityReader wraps an io.Reader and updates a shared timestamp on every Read. +type activityReader struct { + r io.Reader + lastActive *int64 +} + +func newActivityReader(r io.Reader, lastActive *int64) *activityReader { + return &activityReader{r: r, lastActive: lastActive} +} + +func (a *activityReader) Read(p []byte) (int, error) { + n, err := a.r.Read(p) + if n > 0 { + atomic.StoreInt64(a.lastActive, time.Now().UnixNano()) + } + return n, err +} diff --git a/poller/internal/tunnel/tunnel_test.go b/poller/internal/tunnel/tunnel_test.go new file mode 100644 index 0000000..7f892b8 --- /dev/null +++ b/poller/internal/tunnel/tunnel_test.go @@ -0,0 +1,161 @@ +package tunnel + +import ( + "context" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockRouter simulates a RouterOS device accepting TCP connections +func mockRouter(t *testing.T) (string, func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + io.Copy(c, c) // echo server + }(conn) + } + }() + return ln.Addr().String(), func() { ln.Close() } +} + +func TestTunnel_ProxyBidirectional(t *testing.T) { + routerAddr, cleanup := mockRouter(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tun := &Tunnel{ + ID: "test-1", + RemoteAddr: routerAddr, + LastActive: time.Now().UnixNano(), + cancel: cancel, + ctx: ctx, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tun.listener = ln + + go tun.accept() + + // Connect as a WinBox client + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer conn.Close() + + // Write and read back (echo) + msg := []byte("hello winbox") + _, err = conn.Write(msg) + require.NoError(t, err) + + buf := make([]byte, len(msg)) + _, err = io.ReadFull(conn, buf) + require.NoError(t, err) + assert.Equal(t, msg, buf) +} + +func TestTunnel_ActivityTracking(t *testing.T) { + routerAddr, cleanup := mockRouter(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + before := time.Now().UnixNano() + tun := &Tunnel{ + ID: "test-2", + RemoteAddr: routerAddr, + LastActive: before, + cancel: cancel, + ctx: ctx, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tun.listener = ln + go tun.accept() + + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + conn.Write([]byte("data")) + buf := make([]byte, 4) + io.ReadFull(conn, buf) + conn.Close() + + time.Sleep(50 * time.Millisecond) + after := atomic.LoadInt64(&tun.LastActive) + assert.Greater(t, after, before) +} + +func TestTunnel_Close(t *testing.T) { + routerAddr, cleanup := mockRouter(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + + tun := &Tunnel{ + ID: "test-3", + RemoteAddr: routerAddr, + LastActive: time.Now().UnixNano(), + cancel: cancel, + ctx: ctx, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tun.listener = ln + go tun.accept() + + // Open a connection + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + // Close tunnel — should terminate everything + tun.Close() + + // Connection should be dead + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, err = conn.Read(make([]byte, 1)) + assert.Error(t, err) +} + +func TestTunnel_DialFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tun := &Tunnel{ + ID: "test-4", + RemoteAddr: "127.0.0.1:1", // nothing listening + LastActive: time.Now().UnixNano(), + cancel: cancel, + ctx: ctx, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + tun.listener = ln + go tun.accept() + + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + // Should be closed quickly since dial to router fails + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + assert.Error(t, err) +}