"""
Caching and rate limiting for QCoDeS parameter reads.
"""
import asyncio
import time
import logging
from typing import Dict, Tuple, Any, Optional
# Module-level logger shared by the rate limiter and poller classes in this module.
logger = logging.getLogger(name=__name__)
[docs]
class ReadCache:
    """Cache of QCoDeS parameter values, each paired with the time it was stored.

    Every access goes through a single ``asyncio.Lock``, which serializes
    coroutines running on the same event loop. ``asyncio.Lock`` is not a
    thread lock, so this class does not by itself protect against access
    from other OS threads.
    """

    def __init__(self):
        # Maps (instrument_name, parameter_name) -> (value, timestamp).
        self.data: Dict[Tuple[str, str], Tuple[Any, float]] = {}
        self.lock = asyncio.Lock()

    async def get(self, key: Tuple[str, str]) -> Optional[Tuple[Any, float]]:
        """Return the ``(value, timestamp)`` pair stored for *key*, or ``None``."""
        async with self.lock:
            entry = self.data.get(key)
        return entry

    async def set(
        self, key: Tuple[str, str], value: Any, timestamp: Optional[float] = None
    ):
        """Store *value* for *key*, stamped with *timestamp* or ``time.time()``."""
        stamp = time.time() if timestamp is None else timestamp
        async with self.lock:
            self.data[key] = (value, stamp)

    async def clear(self):
        """Remove every cached entry."""
        async with self.lock:
            self.data.clear()

    async def get_stats(self) -> Dict[str, Any]:
        """Summarize the cache contents.

        Returns:
            A dict with ``size`` (entry count), ``keys`` (list of cache keys),
            and ``oldest_timestamp`` / ``newest_timestamp`` (0 when empty).
        """
        async with self.lock:
            keys = list(self.data)
            stamps = [stamp for _value, stamp in self.data.values()]
        return {
            "size": len(keys),
            "keys": keys,
            "oldest_timestamp": min(stamps, default=0),
            "newest_timestamp": max(stamps, default=0),
        }
[docs]
class RateLimiter:
    """Enforce a minimum interval between accesses to each QCoDeS instrument.

    The limiter remembers, per instrument, the wall-clock time
    (``time.time()``) of the last recorded access. For strict spacing, a
    caller should hold ``get_instrument_lock(name)`` across
    ``wait_if_needed`` -> instrument access -> ``record_access``. Otherwise
    two coroutines could both pass the wait before either records its access.

    Args:
        min_interval_s: Minimum number of seconds between two recorded
            accesses to the same instrument.
    """

    def __init__(self, min_interval_s: float = 0.2):
        self.min_interval_s = min_interval_s
        # instrument_name -> time.time() of the last recorded access
        self.last_access: Dict[str, float] = {}
        # instrument_name -> lock callers use to serialize access to that instrument
        self.locks: Dict[str, asyncio.Lock] = {}
        # Guards last_access; held only for short bookkeeping, never across a sleep.
        self.lock = asyncio.Lock()

    def get_instrument_lock(self, instrument_name: str) -> asyncio.Lock:
        """Return the lock for *instrument_name*, creating it on first use.

        This method never awaits, so within one event loop no other coroutine
        can run between the membership check and the insertion.
        """
        if instrument_name not in self.locks:
            self.locks[instrument_name] = asyncio.Lock()
        return self.locks[instrument_name]

    async def can_access(self, instrument_name: str) -> bool:
        """Return True if ``min_interval_s`` has passed since the last recorded access.

        An instrument with no recorded access is compared against time 0,
        so it is effectively always accessible.
        """
        async with self.lock:
            last_time = self.last_access.get(instrument_name, 0)
            return (time.time() - last_time) >= self.min_interval_s

    async def record_access(self, instrument_name: str):
        """Record ``time.time()`` as the last access time of *instrument_name*."""
        async with self.lock:
            self.last_access[instrument_name] = time.time()

    async def wait_if_needed(self, instrument_name: str):
        """Sleep until *instrument_name* may be accessed again under the rate limit.

        The bookkeeping is read under ``self.lock``, and the lock is released
        before sleeping. A wait for one instrument therefore does not block
        rate-limit checks or ``record_access`` calls for other instruments.
        The wait is capped at ``min_interval_s``. Without the cap, a backwards
        wall-clock jump (which makes the last access appear to lie in the
        future) would cause an arbitrarily long sleep.
        """
        async with self.lock:
            last_time = self.last_access.get(instrument_name, 0)
            elapsed = time.time() - last_time
        # elapsed < 0 only after a wall-clock jump; the cap bounds the sleep.
        wait_time = min(self.min_interval_s - elapsed, self.min_interval_s)
        if wait_time > 0:
            logger.debug(
                f"Rate limiting {instrument_name}: waiting {wait_time:.3f}s"
            )
            await asyncio.sleep(wait_time)
[docs]
class ParameterPoller:
    """Background poller that periodically reads subscribed parameters into a cache.

    Each subscription runs as its own ``asyncio`` task. Every read takes the
    per-instrument lock from ``rate_limiter.get_instrument_lock`` and waits
    on ``rate_limiter.wait_if_needed``. The blocking getter then runs in a
    worker thread (``asyncio.to_thread``), and the result is stored in
    ``cache``.

    Args:
        cache: Cache that receives each successfully polled value.
        rate_limiter: Rate limiter used to serialize and space out
            instrument reads.
    """

    def __init__(self, cache: "ReadCache", rate_limiter: "RateLimiter"):
        self.cache = cache
        self.rate_limiter = rate_limiter
        # (instrument_name, parameter_name) -> polling interval in seconds
        self.subscriptions: Dict[Tuple[str, str], float] = {}
        # (instrument_name, parameter_name) -> task running _poll_parameter
        self.tasks: Dict[Tuple[str, str], asyncio.Task] = {}
        # Set True when subscribe() starts a task; reset to False by stop_all().
        self.running = False

    async def subscribe(
        self,
        instrument_name: str,
        parameter_name: str,
        interval_s: float,
        get_parameter_func,
    ):
        """Start (or restart) periodic polling of one parameter.

        Any existing subscription for the same parameter is cancelled first.
        Re-subscribing therefore replaces the previous interval and getter.

        Args:
            instrument_name: Instrument name, passed to *get_parameter_func*.
            parameter_name: Parameter name, passed to *get_parameter_func*.
            interval_s: Seconds to sleep after each poll attempt (any
                rate-limit wait comes on top of this).
            get_parameter_func: Blocking callable
                ``(instrument_name, parameter_name) -> value``; it is run in a
                worker thread via ``asyncio.to_thread``.
        """
        key = (instrument_name, parameter_name)
        # Cancel existing subscription if any
        await self.unsubscribe(instrument_name, parameter_name)
        self.subscriptions[key] = interval_s
        self.tasks[key] = asyncio.create_task(
            self._poll_parameter(
                instrument_name, parameter_name, interval_s, get_parameter_func
            )
        )
        self.running = True
        logger.debug(
            f"Subscribed to {instrument_name}.{parameter_name} at {interval_s}s interval"
        )

    async def unsubscribe(self, instrument_name: str, parameter_name: str):
        """Stop polling one parameter, cancelling and awaiting its task if present.

        Safe to call for a parameter that is not subscribed. The debug
        message is logged either way.
        """
        key = (instrument_name, parameter_name)
        if key in self.tasks:
            task = self.tasks.pop(key)
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # NOTE(review): this also swallows a cancellation aimed at the
                # caller while it waits here; revisit if that matters.
                pass
        self.subscriptions.pop(key, None)
        logger.debug(f"Unsubscribed from {instrument_name}.{parameter_name}")

    async def _poll_parameter(
        self,
        instrument_name: str,
        parameter_name: str,
        interval_s: float,
        get_parameter_func,
    ):
        """Poll one parameter until its subscription is removed or the task is cancelled."""
        key = (instrument_name, parameter_name)
        while key in self.subscriptions:
            try:
                # Serialize all reads of this instrument and respect the rate limit.
                async with self.rate_limiter.get_instrument_lock(instrument_name):
                    await self.rate_limiter.wait_if_needed(instrument_name)
                    try:
                        value = await asyncio.to_thread(
                            get_parameter_func, instrument_name, parameter_name
                        )
                    finally:
                        # Record the access even if the read failed. The
                        # instrument was still queried, and skipping this
                        # would let the next reader hit it with no delay.
                        await self.rate_limiter.record_access(instrument_name)
                    await self.cache.set(key, value)
                logger.debug(f"Polled {instrument_name}.{parameter_name} = {value}")
            except Exception as e:
                # CancelledError derives from BaseException (Python 3.8+), so
                # cancellation still propagates and ends the loop.
                logger.error(f"Error polling {instrument_name}.{parameter_name}: {e}")
            # Wait for next poll
            await asyncio.sleep(interval_s)

    async def stop_all(self):
        """Unsubscribe every polling task and reset ``running`` to False."""
        for instrument_name, parameter_name in list(self.tasks):
            await self.unsubscribe(instrument_name, parameter_name)
        self.running = False

    def get_subscriptions(self) -> Dict[str, Any]:
        """Return subscribed keys, the number of live tasks, and each key's interval."""
        return {
            "subscriptions": list(self.subscriptions.keys()),
            "active_tasks": len(self.tasks),
            "intervals": dict(self.subscriptions),
        }