diff --git a/asyncssh/__init__.py b/asyncssh/__init__.py index fb316e0..d5c0e15 100644 --- a/asyncssh/__init__.py +++ b/asyncssh/__init__.py @@ -40,7 +40,7 @@ from .config import ConfigParseError -from .forward import SSHForwarder +from .forward import ForwardTracker, SSHForwarder from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions @@ -147,6 +147,7 @@ 'SSHAgentKeyPair', 'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient', 'SSHClientChannel', 'SSHClientConnection', 'SSHClientConnectionOptions', 'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess', + 'ForwardTracker', 'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts', 'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer', 'SSHServerChannel', 'SSHServerConnection', diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 89fdb16..c9393a1 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -85,7 +85,7 @@ from .encryption import encryption_needs_mac from .encryption import get_encryption_params, get_encryption -from .forward import SSHForwarder +from .forward import ForwardTracker, SSHForwarder from .gss import GSSBase, GSSClient, GSSServer, GSSError @@ -3210,7 +3210,8 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder: async def forward_local_port( self, listen_host: str, listen_port: int, dest_host: str, dest_port: int, - accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: + accept_handler: Optional[SSHAcceptHandler] = None, + tracker: Optional['ForwardTracker'] = None) -> SSHListener: """Set up local port forwarding This method is a coroutine which attempts to set up port @@ -3233,11 +3234,17 @@ async def forward_local_port( or not to allow connection forwarding, returning `True` to accept the connection and begin forwarding or `False` to reject and close it. + :param tracker: + Optional hooks for observing per-connection lifecycle + events on the local listener (open/close). See + :class:`ForwardTracker`. ``None`` (default) preserves + existing behavior with no overhead. :type listen_host: `str` :type listen_port: `int` :type dest_host: `str` :type dest_port: `int` :type accept_handler: `callable` or coroutine + :type tracker: :class:`ForwardTracker` or `None` :returns: :class:`SSHListener` @@ -3281,7 +3288,8 @@ async def tunnel_connection( listener = await create_tcp_forward_listener(self, self._loop, tunnel_connection, listen_host, - listen_port) + listen_port, + tracker=tracker) except OSError as exc: self.logger.debug1('Failed to create local TCP listener: %s', exc) raise diff --git a/asyncssh/forward.py b/asyncssh/forward.py index 8470c00..cc15d1f 100644 --- a/asyncssh/forward.py +++ b/asyncssh/forward.py @@ -24,7 +24,7 @@ import socket from types import TracebackType from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional -from typing import Type, cast +from typing import Protocol, Type, cast, runtime_checkable from typing_extensions import Self from .misc import ChannelOpenError, SockAddr @@ -38,6 +38,27 @@ SSHForwarderCoro = Callable[..., Awaitable] +@runtime_checkable +class ForwardTracker(Protocol): + """Optional hooks for observing local-forward connection lifecycle. + + Each method is called from within the asyncio loop. Implementations + MUST NOT block (no I/O, no sleep). All hooks are best-effort - + asyncssh swallows any exception they raise (logged at debug level). + + Pass an instance to :meth:`SSHClientConnection.forward_local_port` + to observe per-connection events on the local listener. + """ + + def connection_made(self, orig_host: str, orig_port: int) -> None: + """A new client TCP connection was accepted on the local listener.""" + + def connection_lost(self, orig_host: str, orig_port: int, + exc: Optional[Exception]) -> None: + """A previously-accepted connection has closed (clean exc=None + or with an error).""" + + class SSHForwarder(asyncio.BaseProtocol): """SSH port forwarding connection handler""" @@ -229,6 +250,13 @@ def forward(self, *args: object) -> None: class SSHLocalPortForwarder(SSHLocalForwarder): """Local TCP port forwarding connection handler""" + def __init__(self, conn: 'SSHConnection', coro: SSHForwarderCoro, + tracker: Optional[ForwardTracker] = None): + super().__init__(conn, coro) + self._tracker = tracker + self._orig_host: str = '' + self._orig_port: int = 0 + def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened connection""" @@ -237,9 +265,31 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: peername = cast(SockAddr, transport.get_extra_info('peername')) if peername: # pragma: no branch - orig_host, orig_port = peername[:2] + self._orig_host, self._orig_port = peername[:2] + + if self._tracker is not None: + try: + self._tracker.connection_made(self._orig_host, + self._orig_port) + except Exception: # pylint: disable=broad-exception-caught + # Tracker is an observer; a buggy one must not break + # the forwarder. Swallow silently (asyncssh has no + # logger on the protocol class itself). + pass + + self.forward(self._orig_host, self._orig_port) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle a closed connection""" + + if self._tracker is not None: + try: + self._tracker.connection_lost(self._orig_host, + self._orig_port, exc) + except Exception: # pylint: disable=broad-exception-caught + pass - self.forward(orig_host, orig_port) + super().connection_lost(exc) class SSHLocalPathForwarder(SSHLocalForwarder): diff --git a/asyncssh/listener.py b/asyncssh/listener.py index e9cc475..317d882 100644 --- a/asyncssh/listener.py +++ b/asyncssh/listener.py @@ -28,7 +28,7 @@ from typing import Sequence, Set, Tuple, Type, Union from typing_extensions import Self -from .forward import SSHForwarderCoro +from .forward import ForwardTracker, SSHForwarderCoro from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder from .misc import HostPort, MaybeAwait from .session import SSHTCPSession, SSHUNIXSession @@ -345,14 +345,15 @@ async def create_tcp_local_listener( async def create_tcp_forward_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_host: str, - listen_port: int) -> \ - 'SSHForwardListener': + listen_port: int, + tracker: Optional[ForwardTracker] = None + ) -> 'SSHForwardListener': """Create a listener to forward traffic from a local TCP port over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a port forwarder for each new local connection""" - return SSHLocalPortForwarder(conn, coro) + return SSHLocalPortForwarder(conn, coro, tracker) return await create_tcp_local_listener(conn, loop, protocol_factory, listen_host, listen_port) diff --git a/docs/api.rst b/docs/api.rst index 046245b..4383a85 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1008,6 +1008,13 @@ Forwarder Classes .. automethod:: close ============================== = +.. autoclass:: ForwardTracker + + ============================== = + .. automethod:: connection_made + .. automethod:: connection_lost + ============================== = + Listener Classes ================ diff --git a/docs/changes.rst b/docs/changes.rst index 5c8973d..0ab75f5 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -3,6 +3,23 @@ Change Log ========== +Unreleased (fork: AlexMKX/asyncssh feat/forward-tracker) +-------------------------------------------------------- + +* Added optional ``tracker`` keyword argument to + :meth:`SSHClientConnection.forward_local_port`, accepting an object + implementing the new :class:`ForwardTracker` :class:`typing.Protocol`. + The tracker's ``connection_made(orig_host, orig_port)`` hook fires on + every client connect to the local listener, and + ``connection_lost(orig_host, orig_port, exc)`` on every disconnect. + Hook exceptions are caught and swallowed so a buggy tracker cannot + break the forwarder. ``tracker=None`` (default) preserves existing + behavior with no overhead. + + Use case: observe per-connection lifecycle on local forwards for + idle-based daemon auto-shutdown, byte counters, or other passive + metrics. + Release 2.23.0 (8 Feb 2026) --------------------------- diff --git a/tests/test_forward.py b/tests/test_forward.py index dbfd792..31e8cdd 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -697,6 +697,60 @@ async def accept_handler(_orig_host: str, _orig_port: int) -> bool: writer.close() await maybe_wait_closed(writer) + @asynctest + async def test_forward_local_port_tracker_fires_on_open_and_close(self): + """ForwardTracker.connection_made and connection_lost both fire""" + + events: list[tuple[str, str, int, object]] = [] + + class _RecordingTracker: + def connection_made(self, orig_host: str, + orig_port: int) -> None: + events.append(('made', orig_host, orig_port, None)) + + def connection_lost(self, orig_host: str, orig_port: int, + exc: object) -> None: + events.append(('lost', orig_host, orig_port, exc)) + + async with self.connect() as conn: + async with conn.forward_local_port( + '', 0, '', 7, + tracker=_RecordingTracker()) as listener: + listen_port = listener.get_port() + reader, writer = await asyncio.open_connection( + '127.0.0.1', listen_port) + writer.close() + await maybe_wait_closed(writer) + # Give asyncssh a moment to fire connection_lost. + await asyncio.sleep(0.1) + + kinds = [event[0] for event in events] + self.assertIn('made', kinds) + self.assertIn('lost', kinds) + + @asynctest + async def test_forward_local_port_tracker_exception_is_swallowed(self): + """A buggy tracker does not break forwarding""" + + class _BuggyTracker: + def connection_made(self, orig_host: str, + orig_port: int) -> None: + raise RuntimeError('made boom') + + def connection_lost(self, orig_host: str, orig_port: int, + exc: object) -> None: + raise RuntimeError('lost boom') + + async with self.connect() as conn: + async with conn.forward_local_port( + '', 0, '', 7, + tracker=_BuggyTracker()) as listener: + # If asyncssh did not swallow tracker exceptions, the + # forward would die here and _check_local_connection + # would raise. + await self._check_local_connection(listener.get_port(), + delay=0.1) + @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest