Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion asyncssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from .config import ConfigParseError

from .forward import SSHForwarder
from .forward import ForwardTracker, SSHForwarder

from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
Expand Down Expand Up @@ -147,6 +147,7 @@
'SSHAgentKeyPair', 'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient',
'SSHClientChannel', 'SSHClientConnection', 'SSHClientConnectionOptions',
'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess',
'ForwardTracker',
'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts',
'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer',
'SSHServerChannel', 'SSHServerConnection',
Expand Down
14 changes: 11 additions & 3 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
from .encryption import encryption_needs_mac
from .encryption import get_encryption_params, get_encryption

from .forward import SSHForwarder
from .forward import ForwardTracker, SSHForwarder

from .gss import GSSBase, GSSClient, GSSServer, GSSError

Expand Down Expand Up @@ -3210,7 +3210,8 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder:
async def forward_local_port(
self, listen_host: str, listen_port: int,
dest_host: str, dest_port: int,
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
accept_handler: Optional[SSHAcceptHandler] = None,
tracker: Optional['ForwardTracker'] = None) -> SSHListener:
"""Set up local port forwarding

This method is a coroutine which attempts to set up port
Expand All @@ -3233,11 +3234,17 @@ async def forward_local_port(
or not to allow connection forwarding, returning `True` to
accept the connection and begin forwarding or `False` to
reject and close it.
:param tracker:
Optional hooks for observing per-connection lifecycle
events on the local listener (open/close). See
:class:`ForwardTracker`. ``None`` (default) preserves
existing behavior with no overhead.
:type listen_host: `str`
:type listen_port: `int`
:type dest_host: `str`
:type dest_port: `int`
:type accept_handler: `callable` or coroutine
:type tracker: :class:`ForwardTracker` or `None`

:returns: :class:`SSHListener`

Expand Down Expand Up @@ -3281,7 +3288,8 @@ async def tunnel_connection(
listener = await create_tcp_forward_listener(self, self._loop,
tunnel_connection,
listen_host,
listen_port)
listen_port,
tracker=tracker)
except OSError as exc:
self.logger.debug1('Failed to create local TCP listener: %s', exc)
raise
Expand Down
56 changes: 53 additions & 3 deletions asyncssh/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import socket
from types import TracebackType
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional
from typing import Type, cast
from typing import Protocol, Type, cast, runtime_checkable
from typing_extensions import Self

from .misc import ChannelOpenError, SockAddr
Expand All @@ -38,6 +38,27 @@
SSHForwarderCoro = Callable[..., Awaitable]


@runtime_checkable
class ForwardTracker(Protocol):
"""Optional hooks for observing local-forward connection lifecycle.

Each method is called from within the asyncio loop. Implementations
MUST NOT block (no I/O, no sleep). All hooks are best-effort -
asyncssh swallows any exception they raise (logged at debug level).

Pass an instance to :meth:`SSHClientConnection.forward_local_port`
to observe per-connection events on the local listener.
"""

def connection_made(self, orig_host: str, orig_port: int) -> None:
"""A new client TCP connection was accepted on the local listener."""

def connection_lost(self, orig_host: str, orig_port: int,
exc: Optional[Exception]) -> None:
"""A previously-accepted connection has closed (clean exc=None
or with an error)."""


class SSHForwarder(asyncio.BaseProtocol):
"""SSH port forwarding connection handler"""

Expand Down Expand Up @@ -229,6 +250,13 @@ def forward(self, *args: object) -> None:
class SSHLocalPortForwarder(SSHLocalForwarder):
"""Local TCP port forwarding connection handler"""

def __init__(self, conn: 'SSHConnection', coro: SSHForwarderCoro,
tracker: Optional[ForwardTracker] = None):
super().__init__(conn, coro)
self._tracker = tracker
self._orig_host: str = ''
self._orig_port: int = 0

def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a newly opened connection"""

Expand All @@ -237,9 +265,31 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
peername = cast(SockAddr, transport.get_extra_info('peername'))

if peername: # pragma: no branch
orig_host, orig_port = peername[:2]
self._orig_host, self._orig_port = peername[:2]

if self._tracker is not None:
try:
self._tracker.connection_made(self._orig_host,
self._orig_port)
except Exception: # pylint: disable=broad-exception-caught
# Tracker is an observer; a buggy one must not break
# the forwarder. Swallow silently (asyncssh has no
# logger on the protocol class itself).
pass

self.forward(self._orig_host, self._orig_port)

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle a closed connection"""

if self._tracker is not None:
try:
self._tracker.connection_lost(self._orig_host,
self._orig_port, exc)
except Exception: # pylint: disable=broad-exception-caught
pass

self.forward(orig_host, orig_port)
super().connection_lost(exc)


class SSHLocalPathForwarder(SSHLocalForwarder):
Expand Down
9 changes: 5 additions & 4 deletions asyncssh/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Sequence, Set, Tuple, Type, Union
from typing_extensions import Self

from .forward import SSHForwarderCoro
from .forward import ForwardTracker, SSHForwarderCoro
from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder
from .misc import HostPort, MaybeAwait
from .session import SSHTCPSession, SSHUNIXSession
Expand Down Expand Up @@ -345,14 +345,15 @@ async def create_tcp_local_listener(
async def create_tcp_forward_listener(conn: 'SSHConnection',
loop: asyncio.AbstractEventLoop,
coro: SSHForwarderCoro, listen_host: str,
listen_port: int) -> \
'SSHForwardListener':
listen_port: int,
tracker: Optional[ForwardTracker] = None
) -> 'SSHForwardListener':
"""Create a listener to forward traffic from a local TCP port over SSH"""

def protocol_factory() -> asyncio.BaseProtocol:
"""Start a port forwarder for each new local connection"""

return SSHLocalPortForwarder(conn, coro)
return SSHLocalPortForwarder(conn, coro, tracker)

return await create_tcp_local_listener(conn, loop, protocol_factory,
listen_host, listen_port)
Expand Down
7 changes: 7 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,13 @@ Forwarder Classes
.. automethod:: close
============================== =

.. autoclass:: ForwardTracker

============================== =
.. automethod:: connection_made
.. automethod:: connection_lost
============================== =


Listener Classes
================
Expand Down
17 changes: 17 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@
Change Log
==========

Unreleased (fork: AlexMKX/asyncssh feat/forward-tracker)
--------------------------------------------------------

* Added optional ``tracker`` keyword argument to
:meth:`SSHClientConnection.forward_local_port`, accepting an object
implementing the new :class:`ForwardTracker` :class:`typing.Protocol`.
The tracker's ``connection_made(orig_host, orig_port)`` hook fires on
every client connect to the local listener, and
``connection_lost(orig_host, orig_port, exc)`` on every disconnect.
Hook exceptions are caught and swallowed so a buggy tracker cannot
break the forwarder. ``tracker=None`` (default) preserves existing
behavior with no overhead.

Use case: observe per-connection lifecycle on local forwards for
idle-based daemon auto-shutdown, byte counters, or other passive
metrics.

Release 2.23.0 (8 Feb 2026)
---------------------------

Expand Down
54 changes: 54 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,60 @@ async def accept_handler(_orig_host: str, _orig_port: int) -> bool:
writer.close()
await maybe_wait_closed(writer)

@asynctest
async def test_forward_local_port_tracker_fires_on_open_and_close(self):
"""ForwardTracker.connection_made and connection_lost both fire"""

events: list[tuple[str, str, int, object]] = []

class _RecordingTracker:
def connection_made(self, orig_host: str,
orig_port: int) -> None:
events.append(('made', orig_host, orig_port, None))

def connection_lost(self, orig_host: str, orig_port: int,
exc: object) -> None:
events.append(('lost', orig_host, orig_port, exc))

async with self.connect() as conn:
async with conn.forward_local_port(
'', 0, '', 7,
tracker=_RecordingTracker()) as listener:
listen_port = listener.get_port()
reader, writer = await asyncio.open_connection(
'127.0.0.1', listen_port)
writer.close()
await maybe_wait_closed(writer)
# Give asyncssh a moment to fire connection_lost.
await asyncio.sleep(0.1)

kinds = [event[0] for event in events]
self.assertIn('made', kinds)
self.assertIn('lost', kinds)

@asynctest
async def test_forward_local_port_tracker_exception_is_swallowed(self):
"""A buggy tracker does not break forwarding"""

class _BuggyTracker:
def connection_made(self, orig_host: str,
orig_port: int) -> None:
raise RuntimeError('made boom')

def connection_lost(self, orig_host: str, orig_port: int,
exc: object) -> None:
raise RuntimeError('lost boom')

async with self.connect() as conn:
async with conn.forward_local_port(
'', 0, '', 7,
tracker=_BuggyTracker()) as listener:
# If asyncssh did not swallow tracker exceptions, the
# forward would die here and _check_local_connection
# would raise.
await self._check_local_connection(listener.get_port(),
delay=0.1)

@unittest.skipIf(sys.platform == 'win32',
'skip UNIX domain socket tests on Windows')
@asynctest
Expand Down