diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 133a0ae..f96bba1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,6 +13,12 @@ jobs: matrix: python-version: [3.8, 3.9, "3.10", 3.11, 3.12, 3.13] os: [ubuntu-latest] + disable_trio: [""] + include: + - python-version: "3.13" + disable_trio: "-p no:pytest-trio" + - python-version: "3.14-dev" + disable_trio: "-p no:pytest-trio" runs-on: ${{ matrix.os }} steps: - uses: KeisukeYamashita/memcached-actions@v1 @@ -31,4 +37,4 @@ jobs: uv run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=.venv uv run flake8 . --count --exit-zero --max-complexity=10 --statistics --exclude=.venv - name: Test with pytest - run: uv run pytest -v + run: pytest ${{ matrix.disable_trio }} diff --git a/README.md b/README.md index cf20bb9..32ddb02 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,92 @@ Key features: $ pip install memcache ``` +## Usage + +### Basic Usage + +```python +import memcache + +client = memcache.Memcache(("localhost", 11211)) + +client.set("key", "value", expire=60) +value = client.get("key") +client.delete("key") + +# Atomic counters +client.set("counter", 0) +client.incr("counter") # 1 +client.incr("counter", 5) # 6 +client.decr("counter", 2) # 4 + +# Compare-and-swap +value, token = client.gets("key") +client.cas("key", "new_value", token) +``` + +Async usage mirrors the sync API with `AsyncMemcache` and `await`. + +### MetaClient (Advanced) + +> **Experimental.** `MetaClient` lives under `memcache.experiment` and its API +> may change in any minor release. If you depend on it, pin the **minor version** +> in your dependency spec. Patch releases (`x.y.Z`) will not introduce breaking +> changes, but minor releases (`x.Y.0`) might. +> +> **requirements.txt** +> ``` +> memcache~=0.14.0 # allows 0.14.x, blocks 0.15+ +> ``` +> +> **pyproject.toml** +> ```toml +> [project] +> dependencies = [ +> "memcache>=0.14.0,<0.15", +> ] +> ``` + +`MetaClient` exposes the full power of memcached's +[meta protocol](https://github.com/memcached/memcached/blob/master/doc/protocol.txt), +including flags unavailable through the basic API. + +```python +from memcache.experiment import MetaClient + +client = MetaClient(("localhost", 11211)) + +# get returns a GetResult with rich metadata +result = client.get( + "key", + return_cas=True, + return_ttl=True, + return_hit_before=True, +) +if result is not None: + print(result.value) + print(result.cas_token) + print(result.ttl) + print(result.hit_before) + +# Atomic get-and-touch (update TTL in the same round-trip) +value = client.gat("key", expire=120) + +# Store only if key does not exist +client.add("key", "value", expire=60) + +# Store only if key already exists +client.replace("key", "new_value") + +# Increment with auto-create if missing +client.incr("counter", delta=1, initial=0, initial_ttl=3600) + +# Flush with a delay +client.flush_all(delay=30) +``` + +`AsyncMetaClient` is the async counterpart with the same interface. + ## About the Project Memcache is © 2020-2025 by [aisk](https://github.com/aisk). diff --git a/memcache/async_connection.py b/memcache/async_connection.py new file mode 100644 index 0000000..6e942b3 --- /dev/null +++ b/memcache/async_connection.py @@ -0,0 +1,118 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncIterator, Callable, Optional, Tuple + +import anyio +from anyio.streams.buffered import BufferedByteReceiveStream + +from .errors import MemcacheError +from .meta_command import MetaCommand, MetaResult + + +class AsyncConnection: + def __init__( + self, + addr: Tuple[str, int], + *, + username: Optional[str] = None, + password: Optional[str] = None, + ): + self._addr = addr + self._username = username + self._password = password + self._connected = False + + async def _connect(self) -> None: + self.writer = await anyio.connect_tcp(self._addr[0], self._addr[1]) + self.reader = BufferedByteReceiveStream(self.writer) + await self._auth() + self._connected = True + + async def _auth(self) -> None: + if self._username is None or self._password is None: + return + auth_data = b"%s %s" % ( + self._username.encode("utf-8"), + self._password.encode("utf-8"), + ) + await self.writer.send(b"set auth x 0 %d\r\n" % len(auth_data)) + await self.writer.send(auth_data) + await self.writer.send(b"\r\n") + response = await self.reader.receive_until(b"\r\n", max_bytes=1024) + if response != b"STORED": + raise MemcacheError(response) + + async def flush_all(self, delay: int = 0) -> None: + if not self._connected: + await self._connect() + + if delay > 0: + await self.writer.send(b"flush_all %d\r\n" % delay) + else: + await self.writer.send(b"flush_all\r\n") + response = await self.reader.receive_until(b"\r\n", max_bytes=1024) + if response != b"OK": + raise MemcacheError(response) + + async def execute_meta_command(self, command: MetaCommand) -> MetaResult: + try: + return await self._execute_meta_command(command) + except (IndexError, ConnectionResetError, BrokenPipeError): + self._connected = False + return await self._execute_meta_command(command) + + async def _execute_meta_command(self, command: MetaCommand) -> MetaResult: + if not self._connected: + await self._connect() + + await self.writer.send(command.dump_header()) + if command.value: + await self.writer.send(command.value + b"\r\n") + return await self._receive_meta_result() + + async def _receive_meta_result(self) -> MetaResult: + header_line = await self.reader.receive_until(b"\r\n", max_bytes=1024) + result = MetaResult.load_header(header_line) + + if result.rc == b"VA": + if result.datalen is None: + raise MemcacheError("invalid response: missing datalen") + result.value = await self.reader.receive_exactly(result.datalen) + await self.reader.receive_exactly(2) # read the "\r\n" + + return result + + +class AsyncPool: + def __init__( + self, + create_connection: Callable[..., AsyncConnection], + max_size: Optional[int], + timeout: Optional[int], + ) -> None: + self._create_connection = create_connection + self._max_size = max_size + self._timeout = timeout + self._size = 0 + self._lock = asyncio.Lock() + self._connections: asyncio.Queue[AsyncConnection] = asyncio.Queue() + + @asynccontextmanager + async def get(self) -> AsyncIterator[AsyncConnection]: + try: + connection = self._connections.get_nowait() + yield connection + await self._connections.put(connection) + except asyncio.QueueEmpty: + if self._max_size and self._size >= self._max_size: + connection = await asyncio.wait_for( + self._connections.get(), timeout=self._timeout + ) + yield connection + await self._connections.put(connection) + else: + async with self._lock: + self._size += 1 + connection = self._create_connection() + yield connection + await self._connections.put(connection) diff --git a/memcache/async_memcache.py b/memcache/async_memcache.py index 1553e90..355f386 100644 --- a/memcache/async_memcache.py +++ b/memcache/async_memcache.py @@ -1,241 +1,14 @@ -import asyncio from contextlib import asynccontextmanager -from typing import Any, AsyncIterator, Callable, List, Tuple, Union, Optional - -import anyio -import hashring -from anyio.streams.buffered import BufferedByteReceiveStream +from typing import Any, AsyncIterator, List, Optional, Tuple, Union +from .async_connection import AsyncConnection, AsyncPool # noqa: F401 re-export +from .connection import Addr from .errors import MemcacheError -from .memcache import Addr +from .experiment.async_meta_client import AsyncMetaClient from .meta_command import MetaCommand, MetaResult from .serialize import dump, load, DumpFunc, LoadFunc - -class AsyncConnection: - def __init__( - self, - addr: Tuple[str, int], - *, - load_func: LoadFunc = load, - dump_func: DumpFunc = dump, - username: Optional[str] = None, - password: Optional[str] = None, - ): - self._addr = addr - self._load = load_func - self._dump = dump_func - self._username = username - self._password = password - self._connected = False - - async def _connect(self) -> None: - self.writer = await anyio.connect_tcp(self._addr[0], self._addr[1]) - self.reader = BufferedByteReceiveStream(self.writer) - await self._auth() - self._connected = True - - async def _auth(self) -> None: - if self._username is None or self._password is None: - return - auth_data = b"%s %s" % ( - self._username.encode("utf-8"), - self._password.encode("utf-8"), - ) - await self.writer.send(b"set auth x 0 %d\r\n" % len(auth_data)) - await self.writer.send(auth_data) - await self.writer.send(b"\r\n") - response = await self.reader.receive_until(b"\r\n", max_bytes=1024) - if response != b"STORED": - raise MemcacheError(response) - - async def flush_all(self) -> None: - if not self._connected: - await self._connect() - - await self.writer.send(b"flush_all\r\n") - response = await self.reader.receive_until(b"\r\n", max_bytes=1024) - if response != b"OK": - raise MemcacheError(response) - - async def execute_meta_command(self, command: MetaCommand) -> MetaResult: - try: - return await self._execute_meta_command(command) - except (IndexError, ConnectionResetError, BrokenPipeError): - self._connected = False - return await self._execute_meta_command(command) - - async def _execute_meta_command(self, command: MetaCommand) -> MetaResult: - if not self._connected: - await self._connect() - - await self.writer.send(command.dump_header()) - if command.value: - await self.writer.send(command.value + b"\r\n") - return await self._receive_meta_result() - - async def _receive_meta_result(self) -> MetaResult: - header_line = await self.reader.receive_until(b"\r\n", max_bytes=1024) - result = MetaResult.load_header(header_line) - - if result.rc == b"VA": - if result.datalen is None: - raise MemcacheError("invalid response: missing datalen") - result.value = await self.reader.receive_exactly(result.datalen) - await self.reader.receive_exactly(2) # read the "\r\n" - - return result - - async def set( - self, key: Union[bytes, str], value: Any, expire: Optional[int] = None - ) -> None: - value, client_flags = self._dump(key, value) - - flags = [b"F%d" % client_flags] - if expire: - flags.append(b"T%d" % expire) - - command = MetaCommand( - cm=b"ms", key=key, datalen=len(value), flags=flags, value=value - ) - await self.execute_meta_command(command) - - async def cas( - self, - key: Union[bytes, str], - value: Any, - cas_token: int, - *, - expire: Optional[int] = None, - ) -> None: - """ - Store a value using compare-and-swap operation. - - :param key: The key to store - :param value: The value to store - :param cas_token: The CAS token from a previous gets operation - :param expire: Optional expiration time in seconds - :raises MemcacheError: If the CAS token doesn't match or other error occurs - """ - value, client_flags = self._dump(key, value) - - flags = [b"F%d" % client_flags, b"C%d" % cas_token] - if expire: - flags.append(b"T%d" % expire) - - command = MetaCommand( - cm=b"ms", key=key, datalen=len(value), flags=flags, value=value - ) - result = await self.execute_meta_command(command) - - if result.rc != b"HD": - raise MemcacheError("CAS operation failed: token mismatch or other error") - - async def get(self, key: Union[bytes, str]) -> Optional[Any]: - command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"]) - result = await self.execute_meta_command(command) - - if result.value is None: - return None - - client_flags = int(result.flags[0][1:]) - - return self._load(key, result.value, client_flags) - - async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: - """ - Get a value and its CAS token from memcached. - - :param key: The key to retrieve - :return: A tuple of (value, cas_token) or None if key doesn't exist - """ - command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"]) - result = await self.execute_meta_command(command) - - if result.value is None: - return None - - client_flags = int(result.flags[0][1:]) - value = self._load(key, result.value, client_flags) - - # Find CAS token in flags - cas_token = None - for flag in result.flags[1:]: # Skip the first flag (client_flags) - if flag.startswith(b"c"): - cas_token = int(flag[1:]) - break - - if cas_token is None: - raise MemcacheError("CAS token not found in response") - - return value, cas_token - - async def delete(self, key: Union[bytes, str]) -> None: - command = MetaCommand(cm=b"md", key=key, flags=[], value=None) - await self.execute_meta_command(command) - - async def incr(self, key: Union[bytes, str], value: int = 1) -> int: - command = MetaCommand( - cm=b"ma", key=key, flags=[b"D%d" % value, b"v"] - ) - result = await self.execute_meta_command(command) - - if result.rc != b"VA": - raise MemcacheError(f"INCR operation failed: {result.rc.decode()}") - - if result.value is None: - raise MemcacheError("INCR operation failed: no value returned") - - return int(result.value) - - async def decr(self, key: Union[bytes, str], value: int = 1) -> int: - command = MetaCommand( - cm=b"ma", key=key, flags=[b"D%d" % value, b"MD", b"v"] - ) - result = await self.execute_meta_command(command) - - if result.rc != b"VA": - raise MemcacheError(f"DECR operation failed: {result.rc.decode()}") - - if result.value is None: - raise MemcacheError("DECR operation failed: no value returned") - - return int(result.value) - - -class AsyncPool: - def __init__( - self, - create_connection: Callable[..., AsyncConnection], - max_size: Optional[int], - timeout: Optional[int], - ) -> None: - self._create_connection = create_connection - self._max_size = max_size - self._timeout = timeout - self._size = 0 - self._lock = asyncio.Lock() - self._connections: asyncio.Queue[AsyncConnection] = asyncio.Queue() - - @asynccontextmanager - async def get(self) -> AsyncIterator[AsyncConnection]: - try: - connection = self._connections.get_nowait() - yield connection - await self._connections.put(connection) - except asyncio.QueueEmpty: - if self._max_size and self._size >= self._max_size: - connection = await asyncio.wait_for( - self._connections.get(), timeout=self._timeout - ) - yield connection - await self._connections.put(connection) - else: - async with self._lock: - self._size += 1 - connection = self._create_connection() - yield connection - await self._connections.put(connection) +__all__ = ["AsyncConnection", "AsyncPool", "AsyncMemcache"] class AsyncMemcache: @@ -276,67 +49,37 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, ): - addr = addr or ("localhost", 11211) - if isinstance(addr, list): - addrs: List[Addr] = addr - nodes: List[AsyncPool] = [] - for addr in addrs: - create_connection = lambda: AsyncConnection( - addr, - load_func=load_func, - dump_func=dump_func, - username=username, - password=password, - ) - nodes.append( - AsyncPool( - create_connection, max_size=pool_size, timeout=pool_timeout - ) - ) - self._connections = hashring.HashRing(nodes) - elif isinstance(addr, tuple): - a: Addr = addr - create_connection = lambda: AsyncConnection( - a, - load_func=load_func, - dump_func=dump_func, - username=username, - password=password, - ) - self._connections = hashring.HashRing( - [AsyncPool(create_connection, max_size=pool_size, timeout=pool_timeout)] - ) - else: - raise TypeError("invalid type for addr") + self._meta = AsyncMetaClient( + addr, + pool_size=pool_size, + pool_timeout=pool_timeout, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, + ) @asynccontextmanager async def _get_connection( self, key: Union[str, bytes] ) -> AsyncIterator[AsyncConnection]: - if isinstance(key, bytes): - key = key.decode("utf-8") - pool = self._connections.get_node(key) - async with pool.get() as connection: - yield connection + async with self._meta._get_connection(key) as conn: + yield conn async def execute_meta_command(self, command: MetaCommand) -> MetaResult: - async with self._get_connection(command.key) as connection: - return await connection.execute_meta_command(command) + return await self._meta.execute_meta_command(command) async def flush_all(self) -> None: - for pool in self._connections.nodes: - async with pool.get() as connection: - await connection.flush_all() + await self._meta.flush_all() async def set( self, key: Union[bytes, str], value: Any, *, expire: Optional[int] = None ) -> None: - async with self._get_connection(key) as connection: - return await connection.set(key, value, expire=expire) + await self._meta.set(key, value, expire=expire) async def get(self, key: Union[bytes, str]) -> Optional[Any]: - async with self._get_connection(key) as connection: - return await connection.get(key) + r = await self._meta.get(key) + return r.value if r is not None else None async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: """ @@ -345,8 +88,12 @@ async def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: :param key: The key to retrieve :return: A tuple of (value, cas_token) or None if key doesn't exist """ - async with self._get_connection(key) as connection: - return await connection.gets(key) + r = await self._meta.get(key, return_cas=True) + if r is None: + return None + if r.cas_token is None: + raise MemcacheError("CAS token not found in response") + return r.value, r.cas_token async def cas( self, @@ -365,17 +112,15 @@ async def cas( :param expire: Optional expiration time in seconds :raises MemcacheError: If the CAS token doesn't match or other error occurs """ - async with self._get_connection(key) as connection: - await connection.cas(key, value, cas_token, expire=expire) + ok = await self._meta.cas(key, value, cas_token, expire=expire) + if not ok: + raise MemcacheError("CAS operation failed: token mismatch or other error") async def delete(self, key: Union[bytes, str]) -> None: - async with self._get_connection(key) as connection: - return await connection.delete(key) + await self._meta.delete(key) async def incr(self, key: Union[bytes, str], value: int = 1) -> int: - async with self._get_connection(key) as connection: - return await connection.incr(key, value) + return await self._meta.incr(key, value) async def decr(self, key: Union[bytes, str], value: int = 1) -> int: - async with self._get_connection(key) as connection: - return await connection.decr(key, value) + return await self._meta.decr(key, value) diff --git a/memcache/connection.py b/memcache/connection.py new file mode 100644 index 0000000..e02edc9 --- /dev/null +++ b/memcache/connection.py @@ -0,0 +1,121 @@ +import queue +import socket +import threading +from contextlib import contextmanager +from typing import Callable, Iterator, Optional, Tuple + +from .errors import MemcacheError +from .meta_command import MetaCommand, MetaResult + + +NEWLINE = b"\r\n" + +Addr = Tuple[str, int] + + +class Connection: + def __init__( + self, + addr: Addr, + *, + username: Optional[str] = None, + password: Optional[str] = None, + ): + self._addr = addr + self._username = username + self._password = password + self._connect() + + def _connect(self) -> None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.connect(self._addr) + self.stream = self.socket.makefile(mode="rwb") + self._auth() + + def _auth(self) -> None: + if self._username is None or self._password is None: + return + auth_data = b"%s %s" % ( + self._username.encode("utf-8"), + self._password.encode("utf-8"), + ) + self.stream.write(b"set auth x 0 %d\r\n" % len(auth_data)) + self.stream.write(auth_data) + self.stream.write(b"\r\n") + self.stream.flush() + response = self.stream.readline() + if response != b"STORED\r\n": + raise MemcacheError(response.rstrip(NEWLINE)) + + def close(self) -> None: + self.stream.close() + self.socket.close() + + def flush_all(self, delay: int = 0) -> None: + if delay > 0: + self.stream.write(b"flush_all %d\r\n" % delay) + else: + self.stream.write(b"flush_all\r\n") + self.stream.flush() + response = self.stream.readline() + if response != b"OK\r\n": + raise MemcacheError(response.rstrip(NEWLINE)) + + def execute_meta_command(self, command: MetaCommand) -> MetaResult: + try: + return self._execute_meta_command(command) + except (IndexError, ConnectionResetError, BrokenPipeError): + # This happens when connection is closed by memcached. + self._connect() + return self._execute_meta_command(command) + + def _execute_meta_command(self, command: MetaCommand) -> MetaResult: + self.stream.write(command.dump_header()) + if command.value: + self.stream.write(command.value + b"\r\n") + self.stream.flush() + return self._receive_meta_result() + + def _receive_meta_result(self) -> MetaResult: + result = MetaResult.load_header(self.stream.readline()) + + if result.rc == b"VA": + if result.datalen is None: + raise MemcacheError("invalid response: missing datalen") + result.value = self.stream.read(result.datalen) + self.stream.read(2) # read the "\r\n" + + return result + + +class Pool: + def __init__( + self, + create_connection: Callable[..., Connection], + max_size: Optional[int], + timeout: Optional[int], + ) -> None: + self._create_connection = create_connection + self._max_size = max_size + self._timeout = timeout + self._size = 0 + self._lock = threading.Lock() + self._connections: queue.Queue[Connection] = queue.Queue() + + @contextmanager + def get(self) -> Iterator[Connection]: + try: + connection = self._connections.get_nowait() + yield connection + self._connections.put(connection) + except queue.Empty: + if self._max_size and self._size >= self._max_size: + connection = self._connections.get(timeout=self._timeout) + yield connection + self._connections.put(connection) + else: + with self._lock: + self._size += 1 + connection = self._create_connection() + yield connection + self._connections.put(connection) diff --git a/memcache/experiment/__init__.py b/memcache/experiment/__init__.py new file mode 100644 index 0000000..2241871 --- /dev/null +++ b/memcache/experiment/__init__.py @@ -0,0 +1,5 @@ +from .async_meta_client import AsyncMetaClient +from .meta_client import MetaClient +from .result import GetResult + +__all__ = ["MetaClient", "AsyncMetaClient", "GetResult"] diff --git a/memcache/experiment/async_meta_client.py b/memcache/experiment/async_meta_client.py new file mode 100644 index 0000000..0fd0b01 --- /dev/null +++ b/memcache/experiment/async_meta_client.py @@ -0,0 +1,508 @@ +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import hashring + +from ..async_connection import AsyncConnection, AsyncPool +from ..connection import Addr +from ..errors import MemcacheError +from ..meta_command import MetaCommand, MetaResult +from ..serialize import dump, load, DumpFunc, LoadFunc +from .meta_client import _parse_flags +from .result import GetResult + + +class AsyncMetaClient: + """ + Async memcache client with full meta protocol capability. + + Mirror of MetaClient using anyio-based AsyncConnection/AsyncPool. + """ + + def __init__( + self, + addr: Union[Addr, List[Addr], None] = None, + *, + pool_size: Optional[int] = 23, + pool_timeout: Optional[int] = 1, + load_func: LoadFunc = load, + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, + ): + self._load = load_func + self._dump = dump_func + + addr = addr or ("localhost", 11211) + if isinstance(addr, list): + addrs: List[Addr] = addr + nodes: List[AsyncPool] = [] + for a in addrs: + def _make(a: Addr = a) -> AsyncConnection: + return AsyncConnection( + a, + username=username, + password=password, + ) + nodes.append( + AsyncPool(_make, max_size=pool_size, timeout=pool_timeout) + ) + self._connections = hashring.HashRing(nodes) + elif isinstance(addr, tuple): + a_single: Addr = addr + + def _make_single() -> AsyncConnection: + return AsyncConnection( + a_single, + username=username, + password=password, + ) + self._connections = hashring.HashRing( + [AsyncPool(_make_single, max_size=pool_size, timeout=pool_timeout)] + ) + else: + raise TypeError("invalid type for addr") + + @asynccontextmanager + async def _get_connection( + self, key: Union[str, bytes] + ) -> AsyncIterator[AsyncConnection]: + if isinstance(key, bytes): + key = key.decode("utf-8") + pool = self._connections.get_node(key) + async with pool.get() as connection: + yield connection + + @staticmethod + def _to_bytes(key: Union[str, bytes]) -> bytes: + if isinstance(key, str): + return key.encode() + return key + + async def execute_meta_command(self, command: MetaCommand) -> MetaResult: + async with self._get_connection(command.key) as connection: + return await connection.execute_meta_command(command) + + # ------------------------------------------------------------------ # + # Meta Get (mg) # + # ------------------------------------------------------------------ # + + async def get( + self, + key: Union[str, bytes], + *, + return_cas: bool = False, + return_ttl: bool = False, + return_last_access: bool = False, + return_size: bool = False, + return_hit_before: bool = False, + update_ttl: Optional[int] = None, + no_lru_bump: bool = False, + vivify_on_miss_ttl: Optional[int] = None, + recache_ttl_threshold: Optional[int] = None, + check_cas: Optional[int] = None, + ) -> Optional[GetResult[Any]]: + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"v", b"f"] + if return_cas: + flags.append(b"c") + if return_ttl: + flags.append(b"t") + if return_last_access: + flags.append(b"l") + if return_size: + flags.append(b"s") + if return_hit_before: + flags.append(b"h") + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + if no_lru_bump: + flags.append(b"u") + if vivify_on_miss_ttl is not None: + flags.append(b"N%d" % vivify_on_miss_ttl) + if recache_ttl_threshold is not None: + flags.append(b"R%d" % recache_ttl_threshold) + if check_cas is not None: + flags.append(b"C%d" % check_cas) + + command = MetaCommand(cm=b"mg", key=key_bytes, flags=flags) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"EN": + return None + + parsed = _parse_flags(result.flags) + + gr: GetResult[Any] = GetResult() + gr.is_stale = parsed.get("is_stale", False) + gr.won_recache = parsed.get("won_recache", False) + gr.already_won = parsed.get("already_won", False) + if return_cas: + gr.cas_token = parsed.get("cas_token") + if return_ttl: + gr.ttl = parsed.get("ttl") + if return_last_access: + gr.last_access = parsed.get("last_access") + if return_size: + gr.size = parsed.get("size") + if return_hit_before: + gr.hit_before = parsed.get("hit_before") + if "key" in parsed: + gr.key = parsed["key"] + + if result.value is not None and len(result.value) > 0: + client_flags = parsed.get("client_flags", 0) + gr.value = self._load(key_bytes, result.value, client_flags) + + return gr + + async def gat(self, key: Union[str, bytes], expire: int) -> Optional[Any]: + """Atomic get-and-touch: retrieve value and update TTL atomically.""" + key_bytes = self._to_bytes(key) + command = MetaCommand( + cm=b"mg", + key=key_bytes, + flags=[b"v", b"f", b"T%d" % expire], + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.value is None: + return None + + parsed = _parse_flags(result.flags) + client_flags = parsed.get("client_flags", 0) + return self._load(key_bytes, result.value, client_flags) + + async def touch(self, key: Union[str, bytes], expire: int) -> bool: + """Update the TTL of a key without returning its value.""" + key_bytes = self._to_bytes(key) + command = MetaCommand( + cm=b"mg", + key=key_bytes, + flags=[b"T%d" % expire], + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + return result.rc != b"EN" + + async def get_many( + self, + keys: List[Union[str, bytes]], + ) -> Dict[str, GetResult[Any]]: + """Retrieve multiple keys; missing keys are omitted from the result.""" + result: Dict[str, GetResult[Any]] = {} + for key in keys: + key_str = key if isinstance(key, str) else key.decode("utf-8") + gr = await self.get(key) + if gr is not None: + result[key_str] = gr + return result + + # ------------------------------------------------------------------ # + # Meta Set (ms) # + # ------------------------------------------------------------------ # + + async def set( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> None: + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc != b"HD": + raise MemcacheError(f"set failed: {result.rc.decode()}") + + async def add( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> bool: + """Store only if key does not already exist. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"ME", b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"add failed: {result.rc.decode()}") + + async def replace( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> bool: + """Store only if key already exists. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MR", b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"replace failed: {result.rc.decode()}") + + async def append( + self, + key: Union[str, bytes], + value: Any, + *, + vivify_ttl: Optional[int] = None, + ) -> bool: + """Append value to an existing key. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MA", b"F%d" % client_flags] + if vivify_ttl is not None: + flags.append(b"N%d" % vivify_ttl) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"append failed: {result.rc.decode()}") + + async def prepend( + self, + key: Union[str, bytes], + value: Any, + *, + vivify_ttl: Optional[int] = None, + ) -> bool: + """Prepend value to an existing key. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MP", b"F%d" % client_flags] + if vivify_ttl is not None: + flags.append(b"N%d" % vivify_ttl) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"prepend failed: {result.rc.decode()}") + + async def cas( + self, + key: Union[str, bytes], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> bool: + """Compare-and-swap. Returns True on success, False on CAS conflict.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"F%d" % client_flags, b"C%d" % cas_token] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"EX", b"NF"): + return False + raise MemcacheError(f"cas failed: {result.rc.decode()}") + + # ------------------------------------------------------------------ # + # Meta Delete (md) # + # ------------------------------------------------------------------ # + + async def delete( + self, + key: Union[str, bytes], + *, + cas_token: Optional[int] = None, + ) -> bool: + """Delete a key. Returns True on success, False if key not found.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [] + if cas_token is not None: + flags.append(b"C%d" % cas_token) + + command = MetaCommand(cm=b"md", key=key_bytes, flags=flags) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"NF", b"EX"): + return False + raise MemcacheError(f"delete failed: {result.rc.decode()}") + + async def invalidate( + self, + key: Union[str, bytes], + *, + stale_ttl: Optional[int] = None, + cas_token: Optional[int] = None, + ) -> bool: + """Mark a key as stale (stale-while-revalidate pattern).""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"I"] + if stale_ttl is not None: + flags.append(b"T%d" % stale_ttl) + if cas_token is not None: + flags.append(b"C%d" % cas_token) + + command = MetaCommand(cm=b"md", key=key_bytes, flags=flags) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"NF", b"EX"): + return False + raise MemcacheError(f"invalidate failed: {result.rc.decode()}") + + # ------------------------------------------------------------------ # + # Meta Arithmetic (ma) # + # ------------------------------------------------------------------ # + + async def incr( + self, + key: Union[str, bytes], + delta: int = 1, + *, + initial: Optional[int] = None, + initial_ttl: Optional[int] = None, + update_ttl: Optional[int] = None, + ) -> int: + """Increment counter. Raises MemcacheError if key missing and no initial.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"D%d" % delta, b"v"] + if initial is not None: + flags.append(b"J%d" % initial) + if initial_ttl is not None: + flags.append(b"N%d" % initial_ttl) + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + + command = MetaCommand(cm=b"ma", key=key_bytes, flags=flags) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"NF": + raise MemcacheError("key not found") + if result.rc != b"VA": + raise MemcacheError(f"incr failed: {result.rc.decode()}") + if result.value is None: + raise MemcacheError("incr: no value returned") + return int(result.value) + + async def decr( + self, + key: Union[str, bytes], + delta: int = 1, + *, + initial: Optional[int] = None, + initial_ttl: Optional[int] = None, + update_ttl: Optional[int] = None, + ) -> int: + """Decrement counter (floor 0). Raises MemcacheError if key missing.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"D%d" % delta, b"MD", b"v"] + if initial is not None: + flags.append(b"J%d" % initial) + if initial_ttl is not None: + flags.append(b"N%d" % initial_ttl) + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + + command = MetaCommand(cm=b"ma", key=key_bytes, flags=flags) + async with self._get_connection(key_bytes) as connection: + result = await connection.execute_meta_command(command) + + if result.rc == b"NF": + raise MemcacheError("key not found") + if result.rc != b"VA": + raise MemcacheError(f"decr failed: {result.rc.decode()}") + if result.value is None: + raise MemcacheError("decr: no value returned") + return int(result.value) + + # ------------------------------------------------------------------ # + # Other # + # ------------------------------------------------------------------ # + + async def flush_all(self, delay: int = 0) -> None: + for pool in self._connections.nodes: + async with pool.get() as connection: + await connection.flush_all(delay) diff --git a/memcache/experiment/meta_client.py b/memcache/experiment/meta_client.py new file mode 100644 index 0000000..af79107 --- /dev/null +++ b/memcache/experiment/meta_client.py @@ -0,0 +1,533 @@ +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Union + +import hashring + +from ..connection import Addr, Connection, Pool +from ..errors import MemcacheError +from ..meta_command import MetaCommand, MetaResult +from ..serialize import dump, load, DumpFunc, LoadFunc +from .result import GetResult + + +def _parse_flags(flags: List[bytes]) -> Dict[str, Any]: + result: Dict[str, Any] = {} + for flag in flags: + if not flag: + continue + prefix = chr(flag[0]) + rest = flag[1:] + if prefix == "f": + result["client_flags"] = int(rest) + elif prefix == "c": + result["cas_token"] = int(rest) + elif prefix == "t": + result["ttl"] = int(rest) + elif prefix == "l": + result["last_access"] = int(rest) + elif prefix == "s": + result["size"] = int(rest) + elif prefix == "h": + result["hit_before"] = int(rest) != 0 + elif prefix == "W": + result["won_recache"] = True + elif prefix == "X": + result["is_stale"] = True + elif prefix == "Z": + result["already_won"] = True + elif prefix == "k": + result["key"] = rest.decode("utf-8") + return result + + +class MetaClient: + """ + Memcache client with full meta protocol capability. + + Exposes the complete meta protocol (mg/ms/md/ma commands) with all flags. + Access via ``from memcache.experiment import MetaClient``. + """ + + def __init__( + self, + addr: Union[Addr, List[Addr], None] = None, + *, + pool_size: Optional[int] = 23, + pool_timeout: Optional[int] = 1, + load_func: LoadFunc = load, + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, + ): + self._load = load_func + self._dump = dump_func + + addr = addr or ("localhost", 11211) + if isinstance(addr, list): + addrs: List[Addr] = addr + nodes: List[Pool] = [] + for a in addrs: + def _make(a: Addr = a) -> Connection: + return Connection( + a, + username=username, + password=password, + ) + nodes.append(Pool(_make, max_size=pool_size, timeout=pool_timeout)) + self._connections = hashring.HashRing(nodes) + elif isinstance(addr, tuple): + a_single: Addr = addr + + def _make_single() -> Connection: + return Connection( + a_single, + username=username, + password=password, + ) + self._connections = hashring.HashRing( + [Pool(_make_single, max_size=pool_size, timeout=pool_timeout)] + ) + else: + raise TypeError("invalid type for addr") + + @contextmanager + def _get_connection(self, key: Union[str, bytes]) -> Iterator[Connection]: + if isinstance(key, bytes): + key = key.decode("utf-8") + pool = self._connections.get_node(key) + with pool.get() as connection: + yield connection + + @staticmethod + def _to_bytes(key: Union[str, bytes]) -> bytes: + if isinstance(key, str): + return key.encode() + return key + + def execute_meta_command(self, command: MetaCommand) -> MetaResult: + with self._get_connection(command.key) as connection: + return connection.execute_meta_command(command) + + # ------------------------------------------------------------------ # + # Meta Get (mg) # + # ------------------------------------------------------------------ # + + def get( + self, + key: Union[str, bytes], + *, + return_cas: bool = False, + return_ttl: bool = False, + return_last_access: bool = False, + return_size: bool = False, + return_hit_before: bool = False, + update_ttl: Optional[int] = None, + no_lru_bump: bool = False, + vivify_on_miss_ttl: Optional[int] = None, + recache_ttl_threshold: Optional[int] = None, + check_cas: Optional[int] = None, + ) -> Optional[GetResult[Any]]: + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"v", b"f"] + if return_cas: + flags.append(b"c") + if return_ttl: + flags.append(b"t") + if return_last_access: + flags.append(b"l") + if return_size: + flags.append(b"s") + if return_hit_before: + flags.append(b"h") + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + if no_lru_bump: + flags.append(b"u") + if vivify_on_miss_ttl is not None: + flags.append(b"N%d" % vivify_on_miss_ttl) + if recache_ttl_threshold is not None: + flags.append(b"R%d" % recache_ttl_threshold) + if check_cas is not None: + flags.append(b"C%d" % check_cas) + + command = MetaCommand(cm=b"mg", key=key_bytes, flags=flags) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"EN": + return None + + parsed = _parse_flags(result.flags) + + gr: GetResult[Any] = GetResult() + gr.is_stale = parsed.get("is_stale", False) + gr.won_recache = parsed.get("won_recache", False) + gr.already_won = parsed.get("already_won", False) + if return_cas: + gr.cas_token = parsed.get("cas_token") + if return_ttl: + gr.ttl = parsed.get("ttl") + if return_last_access: + gr.last_access = parsed.get("last_access") + if return_size: + gr.size = parsed.get("size") + if return_hit_before: + gr.hit_before = parsed.get("hit_before") + if "key" in parsed: + gr.key = parsed["key"] + + if result.value is not None and len(result.value) > 0: + client_flags = parsed.get("client_flags", 0) + gr.value = self._load(key_bytes, result.value, client_flags) + + return gr + + def gat(self, key: Union[str, bytes], expire: int) -> Optional[Any]: + """Atomic get-and-touch: retrieve value and update TTL atomically.""" + key_bytes = self._to_bytes(key) + command = MetaCommand( + cm=b"mg", + key=key_bytes, + flags=[b"v", b"f", b"T%d" % expire], + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.value is None: + return None + + parsed = _parse_flags(result.flags) + client_flags = parsed.get("client_flags", 0) + return self._load(key_bytes, result.value, client_flags) + + def touch(self, key: Union[str, bytes], expire: int) -> bool: + """Update the TTL of a key without returning its value.""" + key_bytes = self._to_bytes(key) + command = MetaCommand( + cm=b"mg", + key=key_bytes, + flags=[b"T%d" % expire], + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + return result.rc != b"EN" + + def get_many( + self, + keys: List[Union[str, bytes]], + ) -> Dict[str, GetResult[Any]]: + """Retrieve multiple keys; missing keys are omitted from the result.""" + result: Dict[str, GetResult[Any]] = {} + for key in keys: + key_str = key if isinstance(key, str) else key.decode("utf-8") + gr = self.get(key) + if gr is not None: + result[key_str] = gr + return result + + # ------------------------------------------------------------------ # + # Meta Set (ms) # + # ------------------------------------------------------------------ # + + def set( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> None: + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc != b"HD": + raise MemcacheError(f"set failed: {result.rc.decode()}") + + def add( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> bool: + """Store only if key does not already exist. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"ME", b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"add failed: {result.rc.decode()}") + + def replace( + self, + key: Union[str, bytes], + value: Any, + *, + expire: Optional[int] = None, + ) -> bool: + """Store only if key already exists. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MR", b"F%d" % client_flags] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"replace failed: {result.rc.decode()}") + + def append( + self, + key: Union[str, bytes], + value: Any, + *, + vivify_ttl: Optional[int] = None, + ) -> bool: + """Append value to an existing key. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MA", b"F%d" % client_flags] + if vivify_ttl is not None: + flags.append(b"N%d" % vivify_ttl) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"append failed: {result.rc.decode()}") + + def prepend( + self, + key: Union[str, bytes], + value: Any, + *, + vivify_ttl: Optional[int] = None, + ) -> bool: + """Prepend value to an existing key. Returns True on success.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"MP", b"F%d" % client_flags] + if vivify_ttl is not None: + flags.append(b"N%d" % vivify_ttl) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc == b"NS": + return False + raise MemcacheError(f"prepend failed: {result.rc.decode()}") + + def cas( + self, + key: Union[str, bytes], + value: Any, + cas_token: int, + *, + expire: Optional[int] = None, + ) -> bool: + """Compare-and-swap. Returns True on success, False on CAS conflict.""" + key_bytes = self._to_bytes(key) + raw_value, client_flags = self._dump(key_bytes, value) + flags: List[bytes] = [b"F%d" % client_flags, b"C%d" % cas_token] + if expire is not None: + flags.append(b"T%d" % expire) + + command = MetaCommand( + cm=b"ms", + key=key_bytes, + datalen=len(raw_value), + flags=flags, + value=raw_value, + ) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"EX", b"NF"): + return False + raise MemcacheError(f"cas failed: {result.rc.decode()}") + + # ------------------------------------------------------------------ # + # Meta Delete (md) # + # ------------------------------------------------------------------ # + + def delete( + self, + key: Union[str, bytes], + *, + cas_token: Optional[int] = None, + ) -> bool: + """Delete a key. Returns True on success, False if key not found.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [] + if cas_token is not None: + flags.append(b"C%d" % cas_token) + + command = MetaCommand(cm=b"md", key=key_bytes, flags=flags) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"NF", b"EX"): + return False + raise MemcacheError(f"delete failed: {result.rc.decode()}") + + def invalidate( + self, + key: Union[str, bytes], + *, + stale_ttl: Optional[int] = None, + cas_token: Optional[int] = None, + ) -> bool: + """Mark a key as stale (stale-while-revalidate pattern).""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"I"] + if stale_ttl is not None: + flags.append(b"T%d" % stale_ttl) + if cas_token is not None: + flags.append(b"C%d" % cas_token) + + command = MetaCommand(cm=b"md", key=key_bytes, flags=flags) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"HD": + return True + if result.rc in (b"NF", b"EX"): + return False + raise MemcacheError(f"invalidate failed: {result.rc.decode()}") + + # ------------------------------------------------------------------ # + # Meta Arithmetic (ma) # + # ------------------------------------------------------------------ # + + def incr( + self, + key: Union[str, bytes], + delta: int = 1, + *, + initial: Optional[int] = None, + initial_ttl: Optional[int] = None, + update_ttl: Optional[int] = None, + ) -> int: + """Increment counter. Raises MemcacheError if key missing and no initial.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"D%d" % delta, b"v"] + if initial is not None: + flags.append(b"J%d" % initial) + if initial_ttl is not None: + flags.append(b"N%d" % initial_ttl) + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + + command = MetaCommand(cm=b"ma", key=key_bytes, flags=flags) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"NF": + raise MemcacheError("key not found") + if result.rc != b"VA": + raise MemcacheError(f"incr failed: {result.rc.decode()}") + if result.value is None: + raise MemcacheError("incr: no value returned") + return int(result.value) + + def decr( + self, + key: Union[str, bytes], + delta: int = 1, + *, + initial: Optional[int] = None, + initial_ttl: Optional[int] = None, + update_ttl: Optional[int] = None, + ) -> int: + """Decrement counter (floor 0). Raises MemcacheError if key missing.""" + key_bytes = self._to_bytes(key) + flags: List[bytes] = [b"D%d" % delta, b"MD", b"v"] + if initial is not None: + flags.append(b"J%d" % initial) + if initial_ttl is not None: + flags.append(b"N%d" % initial_ttl) + if update_ttl is not None: + flags.append(b"T%d" % update_ttl) + + command = MetaCommand(cm=b"ma", key=key_bytes, flags=flags) + with self._get_connection(key_bytes) as connection: + result = connection.execute_meta_command(command) + + if result.rc == b"NF": + raise MemcacheError("key not found") + if result.rc != b"VA": + raise MemcacheError(f"decr failed: {result.rc.decode()}") + if result.value is None: + raise MemcacheError("decr: no value returned") + return int(result.value) + + # ------------------------------------------------------------------ # + # Other # + # ------------------------------------------------------------------ # + + def flush_all(self, delay: int = 0) -> None: + for pool in self._connections.nodes: + with pool.get() as connection: + connection.flush_all(delay) diff --git a/memcache/experiment/result.py b/memcache/experiment/result.py new file mode 100644 index 0000000..8856ed0 --- /dev/null +++ b/memcache/experiment/result.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar + + +T = TypeVar("T") + + +@dataclass +class GetResult(Generic[T]): + value: Optional[T] = None + key: Optional[str] = None + cas_token: Optional[int] = None + ttl: Optional[int] = None + last_access: Optional[int] = None + size: Optional[int] = None + hit_before: Optional[bool] = None + is_stale: bool = False + won_recache: bool = False + already_won: bool = False diff --git a/memcache/memcache.py b/memcache/memcache.py index f5b0ec2..ab18a53 100644 --- a/memcache/memcache.py +++ b/memcache/memcache.py @@ -1,245 +1,13 @@ -import socket -import threading -import queue from contextlib import contextmanager -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union - -import hashring +from typing import Any, Iterator, List, Optional, Tuple, Union +from .connection import Addr, Connection, Pool # re-export for backward compat from .errors import MemcacheError +from .experiment.meta_client import MetaClient from .meta_command import MetaCommand, MetaResult from .serialize import dump, load, DumpFunc, LoadFunc - -NEWLINE = b"\r\n" - - -class Connection: - def __init__( - self, - addr: Tuple[str, int], - *, - load_func: LoadFunc = load, - dump_func: DumpFunc = dump, - username: Optional[str] = None, - password: Optional[str] = None, - ): - self._addr = addr - self._load = load_func - self._dump = dump_func - self._username = username - self._password = password - self._connect() - - def _connect(self) -> None: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.connect(self._addr) - self.stream = self.socket.makefile(mode="rwb") - self._auth() - - def _auth(self) -> None: - if self._username is None or self._password is None: - return - auth_data = b"%s %s" % ( - self._username.encode("utf-8"), - self._password.encode("utf-8"), - ) - self.stream.write(b"set auth x 0 %d\r\n" % len(auth_data)) - self.stream.write(auth_data) - self.stream.write(b"\r\n") - self.stream.flush() - response = self.stream.readline() - if response != b"STORED\r\n": - raise MemcacheError(response.rstrip(NEWLINE)) - - def close(self) -> None: - self.stream.close() - self.socket.close() - - def flush_all(self) -> None: - self.stream.write(b"flush_all\r\n") - self.stream.flush() - response = self.stream.readline() - if response != b"OK\r\n": - raise MemcacheError(response.rstrip(NEWLINE)) - - def execute_meta_command(self, command: MetaCommand) -> MetaResult: - try: - return self._execute_meta_command(command) - except (IndexError, ConnectionResetError, BrokenPipeError): - # This happens when connection is closed by memcached. - self._connect() - return self._execute_meta_command(command) - - def _execute_meta_command(self, command: MetaCommand) -> MetaResult: - self.stream.write(command.dump_header()) - if command.value: - self.stream.write(command.value + b"\r\n") - self.stream.flush() - return self._receive_meta_result() - - def _receive_meta_result(self) -> MetaResult: - result = MetaResult.load_header(self.stream.readline()) - - if result.rc == b"VA": - if result.datalen is None: - raise MemcacheError("invalid response: missing datalen") - result.value = self.stream.read(result.datalen) - self.stream.read(2) # read the "\r\n" - - return result - - def set( - self, key: Union[bytes, str], value: Any, expire: Optional[int] = None - ) -> None: - value, client_flags = self._dump(key, value) - - flags = [b"F%d" % client_flags] - if expire: - flags.append(b"T%d" % expire) - - command = MetaCommand( - cm=b"ms", key=key, datalen=len(value), flags=flags, value=value - ) - self.execute_meta_command(command) - - def cas( - self, - key: Union[bytes, str], - value: Any, - cas_token: int, - *, - expire: Optional[int] = None, - ) -> None: - """ - Store a value using compare-and-swap operation. - - :param key: The key to store - :param value: The value to store - :param cas_token: The CAS token from a previous gets operation - :param expire: Optional expiration time in seconds - :raises MemcacheError: If the CAS token doesn't match or other error occurs - """ - value, client_flags = self._dump(key, value) - - flags = [b"F%d" % client_flags, b"C%d" % cas_token] - if expire: - flags.append(b"T%d" % expire) - - command = MetaCommand( - cm=b"ms", key=key, datalen=len(value), flags=flags, value=value - ) - result = self.execute_meta_command(command) - - if result.rc != b"HD": - raise MemcacheError("CAS operation failed: token mismatch or other error") - - def get(self, key: Union[bytes, str]) -> Optional[Any]: - command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f"]) - result = self.execute_meta_command(command) - - if result.value is None: - return None - - client_flags = int(result.flags[0][1:]) - - return self._load(key, result.value, client_flags) - - def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: - """ - Get a value and its CAS token from memcached. - - :param key: The key to retrieve - :return: A tuple of (value, cas_token) or None if key doesn't exist - """ - command = MetaCommand(cm=b"mg", key=key, flags=[b"v", b"f", b"c"]) - result = self.execute_meta_command(command) - - if result.value is None: - return None - - client_flags = int(result.flags[0][1:]) - value = self._load(key, result.value, client_flags) - - # Find CAS token in flags - cas_token = None - for flag in result.flags[1:]: # Skip the first flag (client_flags) - if flag.startswith(b"c"): - cas_token = int(flag[1:]) - break - - if cas_token is None: - raise MemcacheError("CAS token not found in response") - - return value, cas_token - - def delete(self, key: Union[bytes, str]) -> None: - command = MetaCommand(cm=b"md", key=key, flags=[], value=None) - self.execute_meta_command(command) - - def incr(self, key: Union[bytes, str], value: int = 1) -> int: - command = MetaCommand( - cm=b"ma", key=key, flags=[b"D%d" % value, b"v"] - ) - result = self.execute_meta_command(command) - - if result.rc != b"VA": - raise MemcacheError(f"INCR operation failed: {result.rc.decode()}") - - if result.value is None: - raise MemcacheError("INCR operation failed: no value returned") - - return int(result.value) - - def decr(self, key: Union[bytes, str], value: int = 1) -> int: - command = MetaCommand( - cm=b"ma", key=key, flags=[b"D%d" % value, b"MD", b"v"] - ) - result = self.execute_meta_command(command) - - if result.rc != b"VA": - raise MemcacheError(f"DECR operation failed: {result.rc.decode()}") - - if result.value is None: - raise MemcacheError("DECR operation failed: no value returned") - - return int(result.value) - - -Addr = Tuple[str, int] - - -class Pool: - def __init__( - self, - create_connection: Callable[..., Connection], - max_size: Optional[int], - timeout: Optional[int], - ) -> None: - self._create_connection = create_connection - self._max_size = max_size - self._timeout = timeout - self._size = 0 - self._lock = threading.Lock() - self._connections: queue.Queue[Connection] = queue.Queue() - - @contextmanager - def get(self) -> Iterator[Connection]: - try: - connection = self._connections.get_nowait() - yield connection - self._connections.put(connection) - except queue.Empty: - if self._max_size and self._size >= self._max_size: - connection = self._connections.get(timeout=self._timeout) - yield connection - self._connections.put(connection) - else: - with self._lock: - self._size += 1 - connection = self._create_connection() - yield connection - self._connections.put(connection) +__all__ = ["Addr", "Connection", "Pool", "Memcache"] class Memcache: @@ -280,63 +48,35 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, ): - addr = addr or ("localhost", 11211) - if isinstance(addr, list): - addrs: List[Addr] = addr - nodes: List[Pool] = [] - for addr in addrs: - create_connection = lambda: Connection( - addr, - load_func=load_func, - dump_func=dump_func, - username=username, - password=password, - ) - nodes.append( - Pool(create_connection, max_size=pool_size, timeout=pool_timeout) - ) - self._connections = hashring.HashRing(nodes) - elif isinstance(addr, tuple): - a: Addr = addr - create_connection = lambda: Connection( - a, - load_func=load_func, - dump_func=dump_func, - username=username, - password=password, - ) - self._connections = hashring.HashRing( - [Pool(create_connection, max_size=pool_size, timeout=pool_timeout)] - ) - else: - raise TypeError("invalid type for addr") + self._meta = MetaClient( + addr, + pool_size=pool_size, + pool_timeout=pool_timeout, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, + ) @contextmanager def _get_connection(self, key: Union[str, bytes]) -> Iterator[Connection]: - if isinstance(key, bytes): - key = key.decode("utf-8") - pool = self._connections.get_node(key) - with pool.get() as connection: - yield connection + with self._meta._get_connection(key) as conn: + yield conn def execute_meta_command(self, command: MetaCommand) -> MetaResult: - with self._get_connection(command.key) as connection: - return connection.execute_meta_command(command) + return self._meta.execute_meta_command(command) def flush_all(self) -> None: - for pool in self._connections.nodes: - with pool.get() as connection: - connection.flush_all() + self._meta.flush_all() def set( self, key: Union[bytes, str], value: Any, *, expire: Optional[int] = None ) -> None: - with self._get_connection(key) as connection: - return connection.set(key, value, expire=expire) + self._meta.set(key, value, expire=expire) def get(self, key: Union[bytes, str]) -> Optional[Any]: - with self._get_connection(key) as connection: - return connection.get(key) + r = self._meta.get(key) + return r.value if r is not None else None def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: """ @@ -345,8 +85,12 @@ def gets(self, key: Union[bytes, str]) -> Optional[Tuple[Any, int]]: :param key: The key to retrieve :return: A tuple of (value, cas_token) or None if key doesn't exist """ - with self._get_connection(key) as connection: - return connection.gets(key) + r = self._meta.get(key, return_cas=True) + if r is None: + return None + if r.cas_token is None: + raise MemcacheError("CAS token not found in response") + return r.value, r.cas_token def cas( self, @@ -365,17 +109,15 @@ def cas( :param expire: Optional expiration time in seconds :raises MemcacheError: If the CAS token doesn't match or other error occurs """ - with self._get_connection(key) as connection: - connection.cas(key, value, cas_token, expire=expire) + ok = self._meta.cas(key, value, cas_token, expire=expire) + if not ok: + raise MemcacheError("CAS operation failed: token mismatch or other error") def delete(self, key: Union[bytes, str]) -> None: - with self._get_connection(key) as connection: - return connection.delete(key) + self._meta.delete(key) def incr(self, key: Union[bytes, str], value: int = 1) -> int: - with self._get_connection(key) as connection: - return connection.incr(key, value) + return self._meta.incr(key, value) def decr(self, key: Union[bytes, str], value: int = 1) -> int: - with self._get_connection(key) as connection: - return connection.decr(key, value) + return self._meta.decr(key, value) diff --git a/memcache/meta_command.py b/memcache/meta_command.py index b1d2c92..982900a 100644 --- a/memcache/meta_command.py +++ b/memcache/meta_command.py @@ -60,7 +60,7 @@ def load_header(line: bytes) -> "MetaResult": flags = [] datalen = None if len(parts) > 1: - if str(parts[1][0]).isdigit(): + if chr(parts[1][0]).isdigit(): datalen = int(parts[1]) flags = parts[2:] else: diff --git a/tests/test_async_meta_client.py b/tests/test_async_meta_client.py new file mode 100644 index 0000000..0bd912a --- /dev/null +++ b/tests/test_async_meta_client.py @@ -0,0 +1,454 @@ +import asyncio + +import pytest +import pytest_asyncio + +from memcache.experiment import AsyncMetaClient, GetResult +from memcache import MemcacheError +from memcache.meta_command import MetaCommand + + +@pytest.fixture() +def client() -> AsyncMetaClient: + return AsyncMetaClient(("localhost", 11211)) + + +@pytest_asyncio.fixture(autouse=True) +async def flush(client: AsyncMetaClient) -> None: + await client.flush_all() + + +# ------------------------------------------------------------------ # +# get / set # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_set_get(client: AsyncMetaClient) -> None: + await client.set("key1", "hello") + r = await client.get("key1") + assert r is not None + assert r.value == "hello" + + +@pytest.mark.asyncio +async def test_get_missing(client: AsyncMetaClient) -> None: + assert await client.get("no_such_key") is None + + +@pytest.mark.asyncio +async def test_set_get_with_expire(client: AsyncMetaClient) -> None: + await client.set("expkey", "val", expire=1) + r = await client.get("expkey") + assert r is not None + assert r.value == "val" + await asyncio.sleep(1.1) + assert await client.get("expkey") is None + + +@pytest.mark.asyncio +async def test_get_returns_getresult_instance(client: AsyncMetaClient) -> None: + await client.set("gr_type", 42) + r = await client.get("gr_type") + assert isinstance(r, GetResult) + assert r.value == 42 + + +@pytest.mark.asyncio +async def test_get_no_lru_bump(client: AsyncMetaClient) -> None: + await client.set("nlb", "v") + r = await client.get("nlb", no_lru_bump=True) + assert r is not None + assert r.value == "v" + + +@pytest.mark.asyncio +async def test_get_update_ttl(client: AsyncMetaClient) -> None: + await client.set("utt", "v", expire=10) + r = await client.get("utt", update_ttl=3600) + assert r is not None + assert r.value == "v" + + +@pytest.mark.asyncio +async def test_get_return_cas(client: AsyncMetaClient) -> None: + await client.set("gr_cas", "v") + r = await client.get("gr_cas", return_cas=True) + assert r is not None + assert isinstance(r.cas_token, int) + assert r.cas_token > 0 + + +@pytest.mark.asyncio +async def test_get_return_ttl(client: AsyncMetaClient) -> None: + await client.set("gr_ttl", "v", expire=3600) + r = await client.get("gr_ttl", return_ttl=True) + assert r is not None + assert r.ttl is not None + assert r.ttl > 0 + + +@pytest.mark.asyncio +async def test_get_no_ttl_requested(client: AsyncMetaClient) -> None: + await client.set("gr_nottl", "v") + r = await client.get("gr_nottl") + assert r is not None + assert r.ttl is None + + +@pytest.mark.asyncio +async def test_get_return_size(client: AsyncMetaClient) -> None: + await client.set("gr_size", b"hello") + r = await client.get("gr_size", return_size=True) + assert r is not None + assert r.size == 5 + + +@pytest.mark.asyncio +async def test_get_return_hit_before(client: AsyncMetaClient) -> None: + await client.set("gr_hit", "v") + r1 = await client.get("gr_hit", return_hit_before=True) + assert r1 is not None + assert r1.hit_before is False + r2 = await client.get("gr_hit", return_hit_before=True) + assert r2 is not None + assert r2.hit_before is True + + +@pytest.mark.asyncio +async def test_get_return_last_access(client: AsyncMetaClient) -> None: + await client.set("gr_la", "v") + r = await client.get("gr_la", return_last_access=True) + assert r is not None + assert r.last_access is not None + assert isinstance(r.last_access, int) + + +# ------------------------------------------------------------------ # +# gat / touch # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_gat(client: AsyncMetaClient) -> None: + await client.set("gat_key", "v", expire=1) + result = await client.gat("gat_key", expire=3600) + assert result == "v" + await asyncio.sleep(1.1) + r = await client.get("gat_key") + assert r is not None + assert r.value == "v" + + +@pytest.mark.asyncio +async def test_gat_miss(client: AsyncMetaClient) -> None: + assert await client.gat("no_gat_key", expire=60) is None + + +@pytest.mark.asyncio +async def test_touch_existing(client: AsyncMetaClient) -> None: + await client.set("touch_key", "v", expire=1) + assert await client.touch("touch_key", expire=3600) is True + await asyncio.sleep(1.1) + r = await client.get("touch_key") + assert r is not None + assert r.value == "v" + + +@pytest.mark.asyncio +async def test_touch_missing(client: AsyncMetaClient) -> None: + assert await client.touch("no_touch_key", expire=60) is False + + +# ------------------------------------------------------------------ # +# get_many # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_get_many(client: AsyncMetaClient) -> None: + await client.set("m1", "a") + await client.set("m2", "b") + result = await client.get_many(["m1", "m2", "m_miss"]) + assert set(result.keys()) == {"m1", "m2"} + assert result["m1"].value == "a" + assert result["m2"].value == "b" + + +@pytest.mark.asyncio +async def test_get_many_all_miss(client: AsyncMetaClient) -> None: + assert await client.get_many(["x1", "x2"]) == {} + + +# ------------------------------------------------------------------ # +# add / replace # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_add_new_key(client: AsyncMetaClient) -> None: + assert await client.add("add_new", "v") is True + r = await client.get("add_new") + assert r is not None + assert r.value == "v" + + +@pytest.mark.asyncio +async def test_add_existing_key(client: AsyncMetaClient) -> None: + await client.set("add_exists", "original") + assert await client.add("add_exists", "new") is False + r = await client.get("add_exists") + assert r is not None + assert r.value == "original" + + +@pytest.mark.asyncio +async def test_replace_existing(client: AsyncMetaClient) -> None: + await client.set("rep_key", "old") + assert await client.replace("rep_key", "new") is True + r = await client.get("rep_key") + assert r is not None + assert r.value == "new" + + +@pytest.mark.asyncio +async def test_replace_missing(client: AsyncMetaClient) -> None: + assert await client.replace("rep_miss", "v") is False + + +# ------------------------------------------------------------------ # +# append / prepend # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_append(client: AsyncMetaClient) -> None: + await client.set("app_key", b"hello") + assert await client.append("app_key", b" world") is True + r = await client.get("app_key") + assert r is not None + assert r.value == b"hello world" + + +@pytest.mark.asyncio +async def test_append_missing_key(client: AsyncMetaClient) -> None: + assert await client.append("app_miss", b"v") is False + + +@pytest.mark.asyncio +async def test_append_vivify(client: AsyncMetaClient) -> None: + assert await client.append("app_vivify", b"data", vivify_ttl=60) is True + r = await client.get("app_vivify") + assert r is not None + assert r.value == b"data" + + +@pytest.mark.asyncio +async def test_prepend(client: AsyncMetaClient) -> None: + await client.set("pre_key", b"world") + assert await client.prepend("pre_key", b"hello ") is True + r = await client.get("pre_key") + assert r is not None + assert r.value == b"hello world" + + +@pytest.mark.asyncio +async def test_prepend_missing_key(client: AsyncMetaClient) -> None: + assert await client.prepend("pre_miss", b"v") is False + + +@pytest.mark.asyncio +async def test_prepend_vivify(client: AsyncMetaClient) -> None: + assert await client.prepend("pre_vivify", b"data", vivify_ttl=60) is True + r = await client.get("pre_vivify") + assert r is not None + assert r.value == b"data" + + +# ------------------------------------------------------------------ # +# cas # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_cas_success(client: AsyncMetaClient) -> None: + await client.set("cas_key", "initial") + r = await client.get("cas_key", return_cas=True) + assert r is not None + assert r.cas_token is not None + assert await client.cas("cas_key", "updated", r.cas_token) is True + r2 = await client.get("cas_key") + assert r2 is not None + assert r2.value == "updated" + + +@pytest.mark.asyncio +async def test_cas_conflict(client: AsyncMetaClient) -> None: + await client.set("cas_conf", "v") + r = await client.get("cas_conf", return_cas=True) + assert r is not None + assert r.cas_token is not None + await client.set("cas_conf", "modified") + assert await client.cas("cas_conf", "new", r.cas_token) is False + r2 = await client.get("cas_conf") + assert r2 is not None + assert r2.value == "modified" + + +@pytest.mark.asyncio +async def test_cas_missing_key(client: AsyncMetaClient) -> None: + assert await client.cas("cas_no_key", "v", 12345) is False + + +# ------------------------------------------------------------------ # +# delete # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_delete_existing(client: AsyncMetaClient) -> None: + await client.set("del_key", "v") + assert await client.delete("del_key") is True + assert await client.get("del_key") is None + + +@pytest.mark.asyncio +async def test_delete_missing(client: AsyncMetaClient) -> None: + assert await client.delete("del_miss") is False + + +@pytest.mark.asyncio +async def test_delete_with_cas(client: AsyncMetaClient) -> None: + await client.set("del_cas", "v") + r = await client.get("del_cas", return_cas=True) + assert r is not None + assert r.cas_token is not None + assert await client.delete("del_cas", cas_token=r.cas_token) is True + + +@pytest.mark.asyncio +async def test_delete_with_wrong_cas(client: AsyncMetaClient) -> None: + await client.set("del_cas_bad", "v") + assert await client.delete("del_cas_bad", cas_token=99999999) is False + + +# ------------------------------------------------------------------ # +# invalidate # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_invalidate(client: AsyncMetaClient) -> None: + await client.set("inv_key", "v") + assert await client.invalidate("inv_key") is True + + +@pytest.mark.asyncio +async def test_invalidate_missing(client: AsyncMetaClient) -> None: + assert await client.invalidate("inv_miss") is False + + +@pytest.mark.asyncio +async def test_invalidate_stale_flag(client: AsyncMetaClient) -> None: + await client.set("inv_stale", "v") + await client.invalidate("inv_stale", stale_ttl=10) + r = await client.get("inv_stale") + assert r is not None + assert r.is_stale is True + + +# ------------------------------------------------------------------ # +# incr / decr # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_incr(client: AsyncMetaClient) -> None: + await client.set("ctr", 10) + assert await client.incr("ctr") == 11 + + +@pytest.mark.asyncio +async def test_incr_with_delta(client: AsyncMetaClient) -> None: + await client.set("ctr2", 5) + assert await client.incr("ctr2", 3) == 8 + + +@pytest.mark.asyncio +async def test_incr_missing(client: AsyncMetaClient) -> None: + with pytest.raises(MemcacheError): + await client.incr("incr_miss") + + +@pytest.mark.asyncio +async def test_incr_with_initial(client: AsyncMetaClient) -> None: + # On miss with J flag, memcached creates item with initial value (delta not applied) + assert await client.incr("incr_init", initial=100, initial_ttl=60) == 100 + + +@pytest.mark.asyncio +async def test_incr_with_initial_existing(client: AsyncMetaClient) -> None: + await client.set("incr_init_ex", 5) + assert await client.incr("incr_init_ex", initial=100, initial_ttl=60) == 6 + + +@pytest.mark.asyncio +async def test_decr(client: AsyncMetaClient) -> None: + await client.set("dctr", 10) + assert await client.decr("dctr") == 9 + + +@pytest.mark.asyncio +async def test_decr_with_delta(client: AsyncMetaClient) -> None: + await client.set("dctr2", 10) + assert await client.decr("dctr2", 3) == 7 + + +@pytest.mark.asyncio +async def test_decr_floor_zero(client: AsyncMetaClient) -> None: + await client.set("dctr3", 1) + assert await client.decr("dctr3", 5) == 0 + + +@pytest.mark.asyncio +async def test_decr_missing(client: AsyncMetaClient) -> None: + with pytest.raises(MemcacheError): + await client.decr("decr_miss") + + +@pytest.mark.asyncio +async def test_decr_with_initial(client: AsyncMetaClient) -> None: + result = await client.decr("decr_init", initial=50, initial_ttl=60) + # On miss with J flag, memcached creates item with initial value (delta not applied) + assert result == 50 + + +# ------------------------------------------------------------------ # +# flush_all # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_flush_all(client: AsyncMetaClient) -> None: + await client.set("flush_k", "v") + await client.flush_all() + assert await client.get("flush_k") is None + + +# ------------------------------------------------------------------ # +# execute_meta_command (low-level pass-through) # +# ------------------------------------------------------------------ # + + +@pytest.mark.asyncio +async def test_execute_meta_command(client: AsyncMetaClient) -> None: + cmd = MetaCommand( + cm=b"ms", key=b"raw_key", datalen=3, flags=[b"T60"], value=b"raw" + ) + result = await client.execute_meta_command(cmd) + assert result.rc == b"HD" + + cmd2 = MetaCommand(cm=b"mg", key=b"raw_key", flags=[b"v"]) + result2 = await client.execute_meta_command(cmd2) + assert result2.rc == b"VA" + assert result2.value == b"raw" diff --git a/tests/test_client.py b/tests/test_client.py index 5ed0280..a32f55d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,4 @@ import queue -import threading import time import pytest diff --git a/tests/test_meta_client.py b/tests/test_meta_client.py new file mode 100644 index 0000000..343e33c --- /dev/null +++ b/tests/test_meta_client.py @@ -0,0 +1,438 @@ +import time + +import pytest + +from memcache.experiment import GetResult, MetaClient +from memcache import MemcacheError +from memcache.meta_command import MetaCommand, MetaResult + + +@pytest.fixture() +def client() -> MetaClient: + c = MetaClient(("localhost", 11211)) + c.flush_all() + return c + + +# ------------------------------------------------------------------ # +# Bug fix: load_header correctly parses flags after datalen # +# ------------------------------------------------------------------ # + + +def test_load_header_with_flags_after_datalen() -> None: + # HD with flag containing non-digit prefix should NOT be parsed as datalen + result = MetaResult.load_header(b"HD t20 c123\r\n") + assert result.rc == b"HD" + assert result.datalen is None + assert b"t20" in result.flags + assert b"c123" in result.flags + + +def test_load_header_va_with_datalen_and_flags() -> None: + result = MetaResult.load_header(b"VA 5 f16 c999\r\n") + assert result.rc == b"VA" + assert result.datalen == 5 + assert b"f16" in result.flags + assert b"c999" in result.flags + + +def test_load_header_en() -> None: + result = MetaResult.load_header(b"EN\r\n") + assert result.rc == b"EN" + assert result.datalen is None + assert result.flags == [] + + +# ------------------------------------------------------------------ # +# get / set # +# ------------------------------------------------------------------ # + + +def test_set_get(client: MetaClient) -> None: + client.set("key1", "hello") + r = client.get("key1") + assert r is not None + assert r.value == "hello" + + +def test_get_missing(client: MetaClient) -> None: + assert client.get("no_such_key") is None + + +def test_set_get_with_expire(client: MetaClient) -> None: + client.set("expkey", "val", expire=1) + r = client.get("expkey") + assert r is not None + assert r.value == "val" + time.sleep(1.1) + assert client.get("expkey") is None + + +def test_get_returns_getresult_instance(client: MetaClient) -> None: + client.set("gr_type", 42) + r = client.get("gr_type") + assert isinstance(r, GetResult) + assert r.value == 42 + + +def test_get_no_lru_bump(client: MetaClient) -> None: + client.set("nlb", "v") + r = client.get("nlb", no_lru_bump=True) + assert r is not None + assert r.value == "v" + + +def test_get_update_ttl(client: MetaClient) -> None: + client.set("utt", "v", expire=10) + r = client.get("utt", update_ttl=3600) + assert r is not None + assert r.value == "v" + + +def test_get_return_cas(client: MetaClient) -> None: + client.set("gr_cas", "v") + r = client.get("gr_cas", return_cas=True) + assert r is not None + assert isinstance(r.cas_token, int) + assert r.cas_token > 0 + + +def test_get_return_ttl(client: MetaClient) -> None: + client.set("gr_ttl", "v", expire=3600) + r = client.get("gr_ttl", return_ttl=True) + assert r is not None + assert r.ttl is not None + assert r.ttl > 0 + + +def test_get_no_ttl_requested(client: MetaClient) -> None: + client.set("gr_nottl", "v") + r = client.get("gr_nottl") + assert r is not None + assert r.ttl is None + + +def test_get_return_size(client: MetaClient) -> None: + client.set("gr_size", b"hello") + r = client.get("gr_size", return_size=True) + assert r is not None + assert r.size == 5 + + +def test_get_return_hit_before(client: MetaClient) -> None: + client.set("gr_hit", "v") + # First get: not hit before + r1 = client.get("gr_hit", return_hit_before=True) + assert r1 is not None + assert r1.hit_before is False + # Second get: hit before + r2 = client.get("gr_hit", return_hit_before=True) + assert r2 is not None + assert r2.hit_before is True + + +def test_get_return_last_access(client: MetaClient) -> None: + client.set("gr_la", "v") + r = client.get("gr_la", return_last_access=True) + assert r is not None + assert r.last_access is not None + assert isinstance(r.last_access, int) + + +# ------------------------------------------------------------------ # +# gat / touch # +# ------------------------------------------------------------------ # + + +def test_gat(client: MetaClient) -> None: + client.set("gat_key", "v", expire=1) + result = client.gat("gat_key", expire=3600) + assert result == "v" + # After gat with long TTL, key should still exist past original TTL + time.sleep(1.1) + r = client.get("gat_key") + assert r is not None + assert r.value == "v" + + +def test_gat_miss(client: MetaClient) -> None: + assert client.gat("no_gat_key", expire=60) is None + + +def test_touch_existing(client: MetaClient) -> None: + client.set("touch_key", "v", expire=1) + assert client.touch("touch_key", expire=3600) is True + time.sleep(1.1) + r = client.get("touch_key") + assert r is not None + assert r.value == "v" + + +def test_touch_missing(client: MetaClient) -> None: + assert client.touch("no_touch_key", expire=60) is False + + +# ------------------------------------------------------------------ # +# get_many # +# ------------------------------------------------------------------ # + + +def test_get_many(client: MetaClient) -> None: + client.set("m1", "a") + client.set("m2", "b") + result = client.get_many(["m1", "m2", "m_miss"]) + assert set(result.keys()) == {"m1", "m2"} + assert result["m1"].value == "a" + assert result["m2"].value == "b" + + +def test_get_many_all_miss(client: MetaClient) -> None: + assert client.get_many(["x1", "x2"]) == {} + + +# ------------------------------------------------------------------ # +# add / replace # +# ------------------------------------------------------------------ # + + +def test_add_new_key(client: MetaClient) -> None: + assert client.add("add_new", "v") is True + r = client.get("add_new") + assert r is not None + assert r.value == "v" + + +def test_add_existing_key(client: MetaClient) -> None: + client.set("add_exists", "original") + assert client.add("add_exists", "new") is False + r = client.get("add_exists") + assert r is not None + assert r.value == "original" + + +def test_replace_existing(client: MetaClient) -> None: + client.set("rep_key", "old") + assert client.replace("rep_key", "new") is True + r = client.get("rep_key") + assert r is not None + assert r.value == "new" + + +def test_replace_missing(client: MetaClient) -> None: + assert client.replace("rep_miss", "v") is False + + +# ------------------------------------------------------------------ # +# append / prepend # +# ------------------------------------------------------------------ # + + +def test_append(client: MetaClient) -> None: + client.set("app_key", b"hello") + assert client.append("app_key", b" world") is True + r = client.get("app_key") + assert r is not None + assert r.value == b"hello world" + + +def test_append_missing_key(client: MetaClient) -> None: + assert client.append("app_miss", b"v") is False + + +def test_append_vivify(client: MetaClient) -> None: + assert client.append("app_vivify", b"data", vivify_ttl=60) is True + r = client.get("app_vivify") + assert r is not None + assert r.value == b"data" + + +def test_prepend(client: MetaClient) -> None: + client.set("pre_key", b"world") + assert client.prepend("pre_key", b"hello ") is True + r = client.get("pre_key") + assert r is not None + assert r.value == b"hello world" + + +def test_prepend_missing_key(client: MetaClient) -> None: + assert client.prepend("pre_miss", b"v") is False + + +def test_prepend_vivify(client: MetaClient) -> None: + assert client.prepend("pre_vivify", b"data", vivify_ttl=60) is True + r = client.get("pre_vivify") + assert r is not None + assert r.value == b"data" + + +# ------------------------------------------------------------------ # +# cas # +# ------------------------------------------------------------------ # + + +def test_cas_success(client: MetaClient) -> None: + client.set("cas_key", "initial") + r = client.get("cas_key", return_cas=True) + assert r is not None + assert r.cas_token is not None + assert client.cas("cas_key", "updated", r.cas_token) is True + r2 = client.get("cas_key") + assert r2 is not None + assert r2.value == "updated" + + +def test_cas_conflict(client: MetaClient) -> None: + client.set("cas_conf", "v") + r = client.get("cas_conf", return_cas=True) + assert r is not None + assert r.cas_token is not None + client.set("cas_conf", "modified") + assert client.cas("cas_conf", "new", r.cas_token) is False + r2 = client.get("cas_conf") + assert r2 is not None + assert r2.value == "modified" + + +def test_cas_missing_key(client: MetaClient) -> None: + assert client.cas("cas_no_key", "v", 12345) is False + + +# ------------------------------------------------------------------ # +# delete # +# ------------------------------------------------------------------ # + + +def test_delete_existing(client: MetaClient) -> None: + client.set("del_key", "v") + assert client.delete("del_key") is True + assert client.get("del_key") is None + + +def test_delete_missing(client: MetaClient) -> None: + assert client.delete("del_miss") is False + + +def test_delete_with_cas(client: MetaClient) -> None: + client.set("del_cas", "v") + r = client.get("del_cas", return_cas=True) + assert r is not None + assert r.cas_token is not None + assert client.delete("del_cas", cas_token=r.cas_token) is True + + +def test_delete_with_wrong_cas(client: MetaClient) -> None: + client.set("del_cas_bad", "v") + assert client.delete("del_cas_bad", cas_token=99999999) is False + + +# ------------------------------------------------------------------ # +# invalidate # +# ------------------------------------------------------------------ # + + +def test_invalidate(client: MetaClient) -> None: + client.set("inv_key", "v") + assert client.invalidate("inv_key") is True + + +def test_invalidate_missing(client: MetaClient) -> None: + assert client.invalidate("inv_miss") is False + + +def test_invalidate_stale_flag(client: MetaClient) -> None: + client.set("inv_stale", "v") + client.invalidate("inv_stale", stale_ttl=10) + # After invalidation, get should see is_stale=True + r = client.get("inv_stale") + assert r is not None + assert r.is_stale is True + + +# ------------------------------------------------------------------ # +# incr / decr # +# ------------------------------------------------------------------ # + + +def test_incr(client: MetaClient) -> None: + client.set("ctr", 10) + assert client.incr("ctr") == 11 + + +def test_incr_with_delta(client: MetaClient) -> None: + client.set("ctr2", 5) + assert client.incr("ctr2", 3) == 8 + + +def test_incr_missing(client: MetaClient) -> None: + with pytest.raises(MemcacheError): + client.incr("incr_miss") + + +def test_incr_with_initial(client: MetaClient) -> None: + # On miss with J flag, memcached creates item with initial value (delta not applied) + assert client.incr("incr_init", initial=100, initial_ttl=60) == 100 + + +def test_incr_with_initial_existing(client: MetaClient) -> None: + client.set("incr_init_ex", 5) + # initial is ignored when key exists + assert client.incr("incr_init_ex", initial=100, initial_ttl=60) == 6 + + +def test_decr(client: MetaClient) -> None: + client.set("dctr", 10) + assert client.decr("dctr") == 9 + + +def test_decr_with_delta(client: MetaClient) -> None: + client.set("dctr2", 10) + assert client.decr("dctr2", 3) == 7 + + +def test_decr_floor_zero(client: MetaClient) -> None: + client.set("dctr3", 1) + assert client.decr("dctr3", 5) == 0 + + +def test_decr_missing(client: MetaClient) -> None: + with pytest.raises(MemcacheError): + client.decr("decr_miss") + + +def test_decr_with_initial(client: MetaClient) -> None: + result = client.decr("decr_init", initial=50, initial_ttl=60) + # On miss with J flag, memcached creates item with initial value (delta not applied) + assert result == 50 + + +# ------------------------------------------------------------------ # +# flush_all # +# ------------------------------------------------------------------ # + + +def test_flush_all(client: MetaClient) -> None: + client.set("flush_k", "v") + client.flush_all() + assert client.get("flush_k") is None + + +def test_flush_all_with_delay(client: MetaClient) -> None: + client.set("flush_d", "v") + client.flush_all(delay=0) + assert client.get("flush_d") is None + + +# ------------------------------------------------------------------ # +# execute_meta_command (low-level pass-through) # +# ------------------------------------------------------------------ # + + +def test_execute_meta_command(client: MetaClient) -> None: + cmd = MetaCommand(cm=b"ms", key=b"raw_key", datalen=3, flags=[b"T60"], value=b"raw") + result = client.execute_meta_command(cmd) + assert result.rc == b"HD" + + cmd2 = MetaCommand(cm=b"mg", key=b"raw_key", flags=[b"v"]) + result2 = client.execute_meta_command(cmd2) + assert result2.rc == b"VA" + assert result2.value == b"raw" diff --git a/tests/test_trio_client.py b/tests/test_trio_client.py index e356615..d179df0 100644 --- a/tests/test_trio_client.py +++ b/tests/test_trio_client.py @@ -1,6 +1,5 @@ import pytest import trio -import time import memcache diff --git a/uv.lock b/uv.lock index 0d8fd13..5374152 100644 --- a/uv.lock +++ b/uv.lock @@ -417,7 +417,7 @@ wheels = [ [[package]] name = "memcache" -version = "0.14.0b1" +version = "0.14.0" source = { editable = "." } dependencies = [ { name = "anyio" },