feat(poller): add TCP tunnel with bidirectional proxy and activity tracking
Implements Tunnel type that listens on a local port, accepts WinBox client connections, dials the remote RouterOS device, and proxies traffic bidirectionally. Uses activityReader to atomically update LastActive on each read for idle timeout detection. Per-connection contexts derive from the tunnel context so Close() terminates all connections cleanly. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
109
poller/internal/tunnel/tunnel.go
Normal file
109
poller/internal/tunnel/tunnel.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tunnel represents an active WinBox TCP tunnel to a single router.
|
||||
type Tunnel struct {
|
||||
ID string
|
||||
DeviceID string
|
||||
TenantID string
|
||||
UserID string
|
||||
LocalPort int
|
||||
RemoteAddr string // router IP:port
|
||||
CreatedAt time.Time
|
||||
LastActive int64 // atomic, unix nanoseconds
|
||||
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
conns sync.WaitGroup
|
||||
activeConns int64 // atomic
|
||||
}
|
||||
|
||||
// Close shuts down the tunnel in the correct order.
|
||||
func (t *Tunnel) Close() {
|
||||
t.listener.Close()
|
||||
t.cancel()
|
||||
t.conns.Wait()
|
||||
slog.Info("tunnel closed", "tunnel_id", t.ID, "device_id", t.DeviceID, "port", t.LocalPort)
|
||||
}
|
||||
|
||||
// IdleDuration returns how long the tunnel has been idle.
|
||||
func (t *Tunnel) IdleDuration() time.Duration {
|
||||
return time.Since(time.Unix(0, atomic.LoadInt64(&t.LastActive)))
|
||||
}
|
||||
|
||||
// ActiveConns returns the number of active TCP connections.
|
||||
func (t *Tunnel) ActiveConns() int64 {
|
||||
return atomic.LoadInt64(&t.activeConns)
|
||||
}
|
||||
|
||||
func (t *Tunnel) accept() {
|
||||
for {
|
||||
conn, err := t.listener.Accept()
|
||||
if err != nil {
|
||||
return // listener closed
|
||||
}
|
||||
t.conns.Add(1)
|
||||
atomic.AddInt64(&t.activeConns, 1)
|
||||
go t.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleConn(clientConn net.Conn) {
|
||||
defer t.conns.Done()
|
||||
defer atomic.AddInt64(&t.activeConns, -1)
|
||||
|
||||
slog.Info("tunnel client connected", "tunnel_id", t.ID, "device_id", t.DeviceID)
|
||||
|
||||
routerConn, err := net.DialTimeout("tcp", t.RemoteAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
slog.Warn("tunnel dial failed", "tunnel_id", t.ID, "remote", t.RemoteAddr, "err", err)
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(t.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
io.Copy(routerConn, newActivityReader(clientConn, &t.LastActive))
|
||||
cancel()
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(clientConn, newActivityReader(routerConn, &t.LastActive))
|
||||
cancel()
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
clientConn.Close()
|
||||
routerConn.Close()
|
||||
|
||||
slog.Info("tunnel client disconnected", "tunnel_id", t.ID, "device_id", t.DeviceID)
|
||||
}
|
||||
|
||||
// activityReader wraps an io.Reader and updates a shared timestamp on every Read.
|
||||
type activityReader struct {
|
||||
r io.Reader
|
||||
lastActive *int64
|
||||
}
|
||||
|
||||
func newActivityReader(r io.Reader, lastActive *int64) *activityReader {
|
||||
return &activityReader{r: r, lastActive: lastActive}
|
||||
}
|
||||
|
||||
func (a *activityReader) Read(p []byte) (int, error) {
|
||||
n, err := a.r.Read(p)
|
||||
if n > 0 {
|
||||
atomic.StoreInt64(a.lastActive, time.Now().UnixNano())
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
161
poller/internal/tunnel/tunnel_test.go
Normal file
161
poller/internal/tunnel/tunnel_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockRouter simulates a RouterOS device accepting TCP connections
|
||||
func mockRouter(t *testing.T) (string, func()) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(c, c) // echo server
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
return ln.Addr().String(), func() { ln.Close() }
|
||||
}
|
||||
|
||||
func TestTunnel_ProxyBidirectional(t *testing.T) {
|
||||
routerAddr, cleanup := mockRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tun := &Tunnel{
|
||||
ID: "test-1",
|
||||
RemoteAddr: routerAddr,
|
||||
LastActive: time.Now().UnixNano(),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tun.listener = ln
|
||||
|
||||
go tun.accept()
|
||||
|
||||
// Connect as a WinBox client
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Write and read back (echo)
|
||||
msg := []byte("hello winbox")
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
_, err = io.ReadFull(conn, buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, msg, buf)
|
||||
}
|
||||
|
||||
func TestTunnel_ActivityTracking(t *testing.T) {
|
||||
routerAddr, cleanup := mockRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
before := time.Now().UnixNano()
|
||||
tun := &Tunnel{
|
||||
ID: "test-2",
|
||||
RemoteAddr: routerAddr,
|
||||
LastActive: before,
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tun.listener = ln
|
||||
go tun.accept()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
conn.Write([]byte("data"))
|
||||
buf := make([]byte, 4)
|
||||
io.ReadFull(conn, buf)
|
||||
conn.Close()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
after := atomic.LoadInt64(&tun.LastActive)
|
||||
assert.Greater(t, after, before)
|
||||
}
|
||||
|
||||
func TestTunnel_Close(t *testing.T) {
|
||||
routerAddr, cleanup := mockRouter(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tun := &Tunnel{
|
||||
ID: "test-3",
|
||||
RemoteAddr: routerAddr,
|
||||
LastActive: time.Now().UnixNano(),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tun.listener = ln
|
||||
go tun.accept()
|
||||
|
||||
// Open a connection
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close tunnel — should terminate everything
|
||||
tun.Close()
|
||||
|
||||
// Connection should be dead
|
||||
conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTunnel_DialFailure(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tun := &Tunnel{
|
||||
ID: "test-4",
|
||||
RemoteAddr: "127.0.0.1:1", // nothing listening
|
||||
LastActive: time.Now().UnixNano(),
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tun.listener = ln
|
||||
go tun.accept()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be closed quickly since dial to router fails
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user