- Added DataConnector base class with OHLCV, InstrumentInfo, and Interval
- Implemented MoomooClient with rate limiting, a circuit breaker, and caching
- Mock mode generates realistic data for development
- Real-time WebSocket subscription support (mock)
- Added examples/demo_moomoo.py showcasing functionality
- Updated requirements.txt with requests, websocket-client, redis, python-dotenv
- Updated README.md with Data Layer documentation
- Added .env.example for configuration
561 lines
20 KiB
Python
561 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Moomoo API connector for OHLCV data.
|
|
|
|
Implementation of a robust Python client for moomoo API with comprehensive
|
|
error handling, rate limiting, and data normalization.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
import random
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Optional, Dict, Any, Callable
|
|
from threading import Lock, Timer
|
|
from queue import Queue
|
|
import json
|
|
|
|
from .base import DataConnector, OHLCV, InstrumentInfo, Interval
|
|
|
|
|
|
# Module-level logger named after this module; CircuitBreaker and MoomooClient log through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimiter:
    """Token bucket rate limiter.

    Tokens refill continuously at ``tokens_per_minute / 60`` per second, capped
    at ``bucket_size``. A request that cannot be served from the bucket borrows
    against future refills (the balance goes negative). Concurrent callers
    therefore queue up behind one another instead of all waiting the same
    short interval and then firing together.
    """

    def __init__(self, tokens_per_minute: int = 60, bucket_size: int = 60):
        """
        Args:
            tokens_per_minute: Maximum tokens per minute (must be > 0)
            bucket_size: Maximum bucket capacity (burst)

        Raises:
            ValueError: If ``tokens_per_minute`` is not positive.
        """
        if tokens_per_minute <= 0:
            # A zero rate would otherwise surface later as ZeroDivisionError in acquire().
            raise ValueError(
                f"tokens_per_minute must be positive, got {tokens_per_minute}"
            )
        self.tokens_per_second = tokens_per_minute / 60.0
        self.bucket_size = bucket_size
        self.tokens = bucket_size
        # Monotonic clock: refill math must not be affected by wall-clock jumps.
        self.last_refill = time.monotonic()
        self.lock = Lock()

    def acquire(self, tokens: int = 1) -> float:
        """
        Reserve ``tokens`` and report how long the caller must wait to use them.

        The reservation is recorded immediately, even when the caller has to
        wait, so later callers are scheduled after it.

        Returns:
            Wait time in seconds (0.0 if the tokens were available immediately)
        """
        with self.lock:
            now = time.monotonic()
            elapsed = now - self.last_refill
            self.last_refill = now

            # Refill; a negative balance (outstanding reservations) is paid down first.
            self.tokens = min(
                self.bucket_size,
                self.tokens + elapsed * self.tokens_per_second
            )
            self.tokens -= tokens

            if self.tokens >= 0:
                return 0.0
            # Wait until refills cover this reservation and every earlier one.
            return -self.tokens / self.tokens_per_second

    def wait_if_needed(self, tokens: int = 1):
        """Block until tokens are available."""
        wait_time = self.acquire(tokens)
        if wait_time > 0:
            time.sleep(wait_time)
|
|
|
|
|
|
class CircuitBreaker:
    """Three-state circuit breaker guarding calls to an unreliable API.

    States are the strings "CLOSED" (requests flow), "OPEN" (requests are
    rejected until more than ``recovery_timeout`` seconds have passed since the
    last failure), and "HALF_OPEN" (trial requests flow). While HALF_OPEN,
    ``half_open_max_attempts`` recorded successes close the circuit and a single
    failure re-opens it. Every state read/transition happens under ``self.lock``.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        half_open_max_attempts: int = 3
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_attempts = half_open_max_attempts

        # Mutable breaker state.
        self.failure_count = 0
        self.last_failure_time = 0
        self.state = "CLOSED"  # one of "CLOSED", "OPEN", "HALF_OPEN"
        self.half_open_attempts = 0  # successes counted while HALF_OPEN
        self.lock = Lock()

    def record_success(self):
        """Register a successful call.

        CLOSED: forgives one earlier failure (never below zero).
        HALF_OPEN: counts toward closing the circuit.
        """
        with self.lock:
            current = self.state
            if current == "CLOSED":
                self.failure_count = max(0, self.failure_count - 1)
                return
            if current != "HALF_OPEN":
                return

            self.half_open_attempts += 1
            if self.half_open_attempts < self.half_open_max_attempts:
                return
            # Enough trial successes: fully reset to CLOSED.
            self.state = "CLOSED"
            self.failure_count = 0
            self.half_open_attempts = 0

    def record_failure(self):
        """Register a failed call, possibly tripping the breaker to OPEN."""
        with self.lock:
            self.failure_count += 1
            self.last_failure_time = time.time()

            if self.state == "HALF_OPEN":
                # Any failed trial request re-opens the circuit immediately.
                self.state = "OPEN"
                self.half_open_attempts = 0
            elif self.state == "CLOSED" and self.failure_count >= self.failure_threshold:
                self.state = "OPEN"
                logger.warning(f"Circuit breaker OPEN after {self.failure_count} failures")

    def allow_request(self) -> bool:
        """Return True if a request may proceed now.

        An OPEN breaker whose recovery timeout has elapsed moves to HALF_OPEN
        and lets this request through as a trial.
        """
        with self.lock:
            if self.state in ("CLOSED", "HALF_OPEN"):
                return True
            if self.state != "OPEN":
                return False

            since_failure = time.time() - self.last_failure_time
            if since_failure <= self.recovery_timeout:
                return False

            self.state = "HALF_OPEN"
            self.half_open_attempts = 0
            logger.info("Circuit breaker HALF_OPEN")
            return True

    def get_state(self) -> str:
        """Return the current state string."""
        with self.lock:
            current = self.state
        return current
|
|
|
|
|
|
class MoomooClient(DataConnector):
    """
    Moomoo API client with rate limiting, error handling, and data normalization.

    Supports both REST API for historical data and WebSocket for real-time updates.

    Only mock mode is functional at present. With ``mock_mode=False`` the REST
    path still falls back to mock responses (a one-time warning is logged), and
    no real WebSocket connection is opened.
    """

    # Mock configuration - replace with real API endpoints
    REST_BASE_URL = "https://api.moomoo.com/v1"
    WS_BASE_URL = "wss://api.moomoo.com/ws"

    def __init__(
        self,
        api_key: Optional[str] = None,
        session_token: Optional[str] = None,
        rate_limit_per_minute: int = 60,
        mock_mode: bool = True,
        cache_ttl: int = 300  # 5 minutes
    ):
        """
        Args:
            api_key: Moomoo API key (optional in mock mode)
            session_token: Pre-authenticated session token
            rate_limit_per_minute: API rate limit
            mock_mode: If True, generate mock data instead of calling real API
            cache_ttl: Cache TTL in seconds (0 to disable)
        """
        self.api_key = api_key
        self.session_token = session_token
        self.mock_mode = mock_mode
        self.cache_ttl = cache_ttl

        # Rate limiting
        self.rate_limiter = RateLimiter(tokens_per_minute=rate_limit_per_minute)
        self.circuit_breaker = CircuitBreaker()

        # Connection state
        self._connected = False
        self._request_count = 0
        self._last_request_time = 0

        # Cache (simple in-memory dict, could be Redis): key -> (stored_at, value)
        self._cache: Dict[str, tuple[float, Any]] = {}

        # WebSocket subscription state; keys are "<symbol>:<interval value>"
        self._ws_connected = False
        self._subscriptions: Dict[str, List[Callable]] = {}
        self._ws_thread = None
        self._ws_stop_event = None  # threading.Event owned by the current feed thread

        # Set once the "real REST API not implemented" warning has been logged.
        self._real_api_warned = False

        # Thread safety (guards _subscriptions)
        self._lock = Lock()

        logger.info(f"MoomooClient initialized (mock_mode={mock_mode})")

    def connect(self) -> bool:
        """Establish connection to moomoo API.

        Returns:
            True if connected (or already connected), False if connecting failed.
        """
        if self._connected:
            return True

        if self.mock_mode:
            logger.info("Mock mode: simulating connection")
            self._connected = True
            return True

        # Real API connection logic would go here
        try:
            # TODO: Implement actual connection
            # 1. Validate API key
            # 2. Obtain session token if not provided
            # 3. Test connectivity
            self._connected = True
            logger.info("Connected to moomoo API")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to moomoo API: {e}")
            self._connected = False
            return False

    def disconnect(self):
        """Close connection and stop any real-time feed."""
        self._unsubscribe_all()
        self._connected = False
        logger.info("Disconnected from moomoo API")

    def _make_request(self, endpoint: str, params: Optional[Dict] = None) -> Dict:
        """
        Make HTTP request with rate limiting and error handling.

        Args:
            endpoint: API endpoint (e.g., "/history/kline")
            params: Query parameters

        Returns:
            JSON response as dict

        Raises:
            Exception: If the circuit breaker is open, or (mock mode) on a
                simulated API failure.
        """
        if not self.circuit_breaker.allow_request():
            raise Exception("Circuit breaker is OPEN - API unavailable")

        # Apply rate limiting (may block)
        self.rate_limiter.wait_if_needed()

        if self.mock_mode:
            # Simulate network latency
            time.sleep(random.uniform(0.05, 0.2))

            # Simulate occasional failures (5% chance)
            if random.random() < 0.05:
                self.circuit_breaker.record_failure()
                raise Exception("Mock API failure")

            self.circuit_breaker.record_success()
            self._request_count += 1
            return self._generate_mock_response(endpoint, params)

        # TODO: Implement actual HTTP request, e.g.:
        #   headers = {"Authorization": f"Bearer {self.session_token}"}
        #   response = requests.get(f"{self.REST_BASE_URL}{endpoint}",
        #                           params=params, headers=headers)
        #   response.raise_for_status()
        #   return response.json()
        if not self._real_api_warned:
            # Callers asked for real data; make the mock fallback visible.
            self._real_api_warned = True
            logger.warning("Real moomoo REST API not implemented; returning mock data")

        # Keep request/circuit bookkeeping consistent with the mock path.
        self.circuit_breaker.record_success()
        self._request_count += 1
        return self._generate_mock_response(endpoint, params)

    def _generate_mock_response(self, endpoint: str, params: Optional[Dict]) -> Dict:
        """Generate a mock API response in the ``{"code", "msg", "data"}`` envelope.

        Unknown endpoints return an empty ``data`` payload.
        """
        params = params or {}

        if endpoint == "/history/kline":
            symbol = params.get("symbol", "AAPL")
            interval = params.get("interval", "1d")
            limit = params.get("limit", 100)

            # Random-walk bars ending today. NOTE: bars are always one day apart
            # regardless of ``interval``; start/end dates are ignored.
            now = datetime.now()
            data = []
            base_price = random.uniform(100, 200)
            for i in range(limit):
                timestamp = now - timedelta(days=limit - i)
                open_price = base_price + random.uniform(-5, 5)
                close = open_price + random.uniform(-2, 2)
                # Wicks extend past the candle body so every bar satisfies
                # low <= min(open, close) <= max(open, close) <= high.
                high = max(open_price, close) + random.uniform(0, 3)
                low = min(open_price, close) - random.uniform(0, 3)
                volume = random.uniform(1000000, 5000000)

                data.append({
                    "time": timestamp.strftime("%Y-%m-%d %H:%M:%S"),
                    "open": round(open_price, 2),
                    "high": round(high, 2),
                    "low": round(low, 2),
                    "close": round(close, 2),
                    "volume": int(volume),
                    "amount": round(volume * close, 2)
                })
                base_price = close

            return {
                "code": 0,
                "msg": "success",
                "data": {
                    "symbol": symbol,
                    "interval": interval,
                    "list": data
                }
            }

        if endpoint == "/stock/basicinfo":
            symbol = params.get("symbol", "AAPL")
            return {
                "code": 0,
                "msg": "success",
                "data": {
                    "symbol": symbol,
                    "name": f"Mock {symbol} Inc.",
                    "exchange": "NASDAQ",
                    "currency": "USD",
                    "lotSize": 100,
                    "minPriceIncrement": 0.01,
                    "tradingHours": "09:30-16:00",
                    "isTradable": True
                }
            }

        return {"code": 0, "msg": "success", "data": {}}

    @staticmethod
    def _parse_timestamp(timestamp_str: str) -> Optional[datetime]:
        """Parse a bar timestamp using the known formats; return None if none match."""
        # Formats to adjust once the actual moomoo format is confirmed.
        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d", "%Y%m%d%H%M%S"):
            try:
                return datetime.strptime(timestamp_str, fmt)
            except ValueError:
                continue
        return None

    def _normalize_ohlcv(self, raw_data: Dict, symbol: str, interval: Interval) -> List[OHLCV]:
        """Convert moomoo-specific OHLCV format to standardized format.

        Items with a missing or unparseable ``time`` are skipped; unparseable
        ones are logged.
        """
        normalized = []

        for item in raw_data.get("list", []):
            timestamp_str = item.get("time")
            if not timestamp_str:
                continue

            timestamp = self._parse_timestamp(timestamp_str)
            if timestamp is None:
                logger.warning(f"Could not parse timestamp: {timestamp_str}")
                continue

            normalized.append(OHLCV(
                timestamp=timestamp,
                open=float(item["open"]),
                high=float(item["high"]),
                low=float(item["low"]),
                close=float(item["close"]),
                volume=float(item["volume"]),
                symbol=symbol,
                interval=interval
            ))

        return normalized

    def _cache_get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None if disabled, absent, or expired."""
        if self.cache_ttl <= 0:
            return None
        entry = self._cache.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if time.time() - stored_at >= self.cache_ttl:
            # Evict stale entries instead of keeping them forever.
            self._cache.pop(key, None)
            return None
        logger.debug(f"Cache hit for {key}")
        return value

    def _cache_put(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` when caching is enabled."""
        if self.cache_ttl > 0:
            self._cache[key] = (time.time(), value)

    def get_ohlcv(
        self,
        symbol: str,
        interval: Interval,
        start_date: datetime,
        end_date: datetime,
        limit: Optional[int] = None
    ) -> List[OHLCV]:
        """Fetch historical bars for ``symbol`` (served from cache when fresh).

        Only the calendar dates of ``start_date``/``end_date`` are sent to the API.

        Raises:
            Exception: On circuit-breaker rejection, (mock) request failure, or a
                non-zero API response code.
        """
        # The key mirrors exactly what is sent to the API: dates and limit.
        cache_key = (
            f"ohlcv:{symbol}:{interval.value}:"
            f"{start_date.date()}:{end_date.date()}:{limit}"
        )
        cached = self._cache_get(cache_key)
        if cached is not None:
            # Hand out a copy so callers cannot mutate the cached list.
            return list(cached)

        # Build request parameters
        params = {
            "symbol": symbol,
            "interval": interval.value,
            "start_date": start_date.strftime("%Y-%m-%d"),
            "end_date": end_date.strftime("%Y-%m-%d")
        }
        if limit:
            params["limit"] = limit

        try:
            response = self._make_request("/history/kline", params)
            if response.get("code") != 0:
                raise Exception(f"API error: {response.get('msg')}")
            ohlcv_data = self._normalize_ohlcv(response.get("data", {}), symbol, interval)
        except Exception as e:
            logger.error(f"Failed to fetch OHLCV data: {e}")
            raise

        self._cache_put(cache_key, list(ohlcv_data))
        return ohlcv_data

    def get_instrument_info(self, symbol: str) -> InstrumentInfo:
        """Fetch static instrument metadata for ``symbol`` (cached when fresh).

        Missing fields fall back to the defaults shown below.

        Raises:
            Exception: On circuit-breaker rejection, (mock) request failure, or a
                non-zero API response code.
        """
        cache_key = f"instrument:{symbol}"
        cached = self._cache_get(cache_key)
        if cached is not None:
            return cached

        params = {"symbol": symbol}

        try:
            response = self._make_request("/stock/basicinfo", params)
            if response.get("code") != 0:
                raise Exception(f"API error: {response.get('msg')}")

            raw_data = response.get("data", {})
            info = InstrumentInfo(
                symbol=raw_data.get("symbol", symbol),
                name=raw_data.get("name", symbol),
                exchange=raw_data.get("exchange", "UNKNOWN"),
                currency=raw_data.get("currency", "USD"),
                lot_size=int(raw_data.get("lotSize", 100)),
                min_price_increment=float(raw_data.get("minPriceIncrement", 0.01)),
                trading_hours=raw_data.get("tradingHours"),
                is_tradable=bool(raw_data.get("isTradable", True))
            )
        except Exception as e:
            logger.error(f"Failed to fetch instrument info: {e}")
            raise

        self._cache_put(cache_key, info)
        return info

    def subscribe_ohlcv(self, symbols: List[str], interval: Interval, callback):
        """Subscribe to real-time OHLCV updates.

        ``callback`` receives one ``OHLCV`` object per update. In mock mode a
        background thread emits one random bar per subscription every second.
        """
        # Start the feed on first use (in mock mode this launches the generator).
        if not self._ws_connected:
            self._connect_websocket()

        with self._lock:
            for symbol in symbols:
                key = f"{symbol}:{interval.value}"
                self._subscriptions.setdefault(key, []).append(callback)

        logger.info(f"Subscribed to {len(symbols)} symbols for {interval.value} interval")

    def unsubscribe(self, symbols: List[str]):
        """Cancel subscriptions (all intervals) for the given symbols."""
        wanted = set(symbols)
        with self._lock:
            # rpartition: the symbol is everything before the last ':'.
            keys_to_remove = [
                k for k in self._subscriptions if k.rpartition(":")[0] in wanted
            ]
            for key in keys_to_remove:
                del self._subscriptions[key]

        logger.info(f"Unsubscribed from {len(symbols)} symbols")

    def _connect_websocket(self):
        """Establish WebSocket connection (mock mode: start the generator thread)."""
        if self._ws_connected:
            return

        if self.mock_mode:
            # Start mock data generator thread
            self._ws_connected = True
            self._start_mock_websocket()
            return

        # TODO: Implement actual WebSocket connection
        # self._ws = websocket.WebSocketApp(...)
        # self._ws_thread = threading.Thread(target=self._ws.run_forever)
        # self._ws_thread.start()

    def _start_mock_websocket(self):
        """Start a daemon thread that emits mock real-time bars once per second."""
        from threading import Event, Thread

        stop_event = Event()
        self._ws_stop_event = stop_event

        def mock_data_generator():
            # Event.wait returns True as soon as a stop is requested, so shutdown
            # never has to sit out a full one-second sleep, and a stale thread
            # cannot survive a quick disconnect/reconnect.
            while not stop_event.wait(1) and self._ws_connected:
                # Snapshot under the lock, invoke callbacks outside it: the Lock is
                # not reentrant, so a callback that (un)subscribes would deadlock.
                with self._lock:
                    snapshot = [
                        (key, list(callbacks))
                        for key, callbacks in self._subscriptions.items()
                    ]

                for key, callbacks in snapshot:
                    symbol, _, interval_str = key.rpartition(":")
                    interval = Interval(interval_str)

                    # Generate a mock OHLCV update that is a valid bar.
                    open_price = random.uniform(100, 200)
                    close = random.uniform(100, 200)
                    ohlcv = OHLCV(
                        timestamp=datetime.now(),
                        open=open_price,
                        high=max(open_price, close) + random.uniform(0, 3),
                        low=min(open_price, close) - random.uniform(0, 3),
                        close=close,
                        volume=random.uniform(1000, 5000),
                        symbol=symbol,
                        interval=interval
                    )

                    for callback in callbacks:
                        try:
                            callback(ohlcv)
                        except Exception as e:
                            # Isolate subscriber errors from the feed; keep the traceback.
                            logger.exception(f"Callback error: {e}")

        self._ws_thread = Thread(target=mock_data_generator, daemon=True)
        self._ws_thread.start()

    def _unsubscribe_all(self):
        """Cancel all subscriptions and stop the real-time feed thread."""
        from threading import current_thread

        with self._lock:
            self._subscriptions.clear()

        self._ws_connected = False
        if self._ws_stop_event is not None:
            self._ws_stop_event.set()

        thread = self._ws_thread
        # join() on the current thread raises RuntimeError (e.g. disconnect()
        # called from inside a subscription callback).
        if thread is not None and thread is not current_thread():
            thread.join(timeout=1)
        self._ws_thread = None
        self._ws_stop_event = None

    @property
    def is_connected(self) -> bool:
        return self._connected

    @property
    def request_count(self) -> int:
        return self._request_count

    def get_health_metrics(self) -> Dict[str, Any]:
        """Get client health metrics."""
        with self._lock:
            subscription_count = sum(
                len(callbacks) for callbacks in self._subscriptions.values()
            )
        return {
            "connected": self._connected,
            "request_count": self._request_count,
            "circuit_breaker_state": self.circuit_breaker.get_state(),
            "subscription_count": subscription_count,
            "cache_size": len(self._cache),
            "mock_mode": self.mock_mode
        }
|
|
|
|
|
|
# Convenience function for quick usage
|
|
def create_moomoo_client(
    api_key: Optional[str] = None,
    mock_mode: bool = True,
    **kwargs
) -> MoomooClient:
    """
    Create and connect a MoomooClient instance.

    Args:
        api_key: Moomoo API key (optional in mock mode)
        mock_mode: If True, generate mock data
        **kwargs: Additional arguments for MoomooClient

    Returns:
        Connected MoomooClient instance

    Raises:
        ConnectionError: If ``MoomooClient.connect()`` reports failure, so the
            "connected" guarantee above actually holds.
    """
    client = MoomooClient(api_key=api_key, mock_mode=mock_mode, **kwargs)
    if not client.connect():
        raise ConnectionError("Failed to connect MoomooClient to moomoo API")
    return client