"""WebSocket client for account data streams.
This module provides the HibachiWSAccountClient for streaming real-time
account updates including balance changes and position updates via WebSocket.
"""
import asyncio
import logging
import time
import orjson
from hibachi_xyz.connection import connect_with_retry
from hibachi_xyz.errors import (
DeserializationError,
SerializationError,
ValidationError,
WebSocketConnectionError,
WebSocketMessageError,
)
from hibachi_xyz.executors import DEFAULT_WS_EXECUTOR, WsConnection, WsExecutor
from hibachi_xyz.helpers import DEFAULT_API_URL, create_with, get_hibachi_client
from hibachi_xyz.types import (
AccountSnapshot,
AccountStreamStartResult,
Json,
Position,
WsEventHandler,
)
# Module-level logger; HibachiWSAccountClient uses it for ping and listen diagnostics.
log = logging.getLogger(__name__)
class HibachiWSAccountClient:
    """WebSocket client for streaming Hibachi account data.

    Provides real-time updates for account balances and positions via a
    WebSocket connection. The expected call sequence is ``connect()``, then
    ``stream_start()``, then repeated ``listen()`` calls, and finally
    ``disconnect()``.
    """

    def __init__(
        self,
        api_key: str,
        account_id: str,
        api_endpoint: str = DEFAULT_API_URL,
        executor: WsExecutor | None = None,
    ):
        """Initialize the Hibachi WebSocket Account Client.

        Args:
            api_key: The API key for authentication with the Hibachi API.
                Sent as the ``Authorization`` header by ``connect()``.
            account_id: The account ID to stream data for. It is converted
                to ``int`` and sent as ``accountId`` in every request.
            api_endpoint: The base API endpoint URL. Defaults to
                DEFAULT_API_URL. A leading ``https://`` is converted to
                ``wss://`` (and ``http://`` to ``ws://``), a trailing ``/``
                is dropped, and ``/ws/account`` is appended.
            executor: Optional WebSocket executor for handling connections.
                If None, the result of ``DEFAULT_WS_EXECUTOR()`` is used.

        Raises:
            ValidationError: If ``account_id`` cannot be converted to ``int``.
        """
        self.api_endpoint = self._to_ws_url(api_endpoint) + "/ws/account"
        self._websocket: WsConnection | None = None
        # Counter backing _next_message_id(); one id per outgoing request.
        self.message_id = 0
        self.api_key = api_key
        try:
            self.account_id = int(account_id)
        except (ValueError, TypeError) as e:
            raise ValidationError(f"Invalid account_id format: {e}") from e
        # Set by stream_start(), required by ping(), cleared by disconnect().
        self.listenKey: str | None = None
        # topic -> handlers, awaited in registration order by listen().
        self._event_handlers: dict[str, list[WsEventHandler]] = {}
        self._executor: WsExecutor = (
            executor if executor is not None else DEFAULT_WS_EXECUTOR()
        )

    @staticmethod
    def _to_ws_url(url: str) -> str:
        """Return *url* with its http(s) scheme mapped to ws(s), minus any trailing '/'."""
        # Rewrite only the scheme prefix: a plain str.replace would also
        # rewrite any "https://" that appears later in the URL.
        if url.startswith("https://"):
            url = "wss://" + url[len("https://"):]
        elif url.startswith("http://"):
            url = "ws://" + url[len("http://"):]
        # Avoid producing "...//ws/account" when the base URL ends with "/".
        return url.rstrip("/")

    @property
    def websocket(self) -> WsConnection:
        """Get the active WebSocket connection.

        Returns:
            The active WebSocket connection.

        Raises:
            ValidationError: If no WebSocket connection exists. Call
                ``connect()`` first.
        """
        if self._websocket is None:
            # Attach the message to the raised error itself; the previous
            # `raise ValidationError from ValueError(...)` raised a bare,
            # message-less ValidationError.
            raise ValidationError("No existing ws connection. Call `connect` first")
        return self._websocket

    def on(self, topic: str, handler: WsEventHandler) -> None:
        """Register an event handler for a specific topic.

        The handler is awaited by ``listen()`` for every received message
        whose ``topic`` field equals *topic*. Handlers for the same topic run
        in registration order.

        Args:
            topic: The topic name to listen for (e.g., 'account_update',
                'position_change').
            handler: An async callback function that accepts a message
                dictionary.
        """
        self._event_handlers.setdefault(topic, []).append(handler)

    async def connect(self) -> None:
        """Establish a WebSocket connection to the account data stream.

        Creates an authenticated WebSocket connection with automatic retry
        logic, using the provided API key (``Authorization`` header) and
        account ID (``accountId`` query parameter).

        Raises:
            WebSocketConnectionError: If connection fails after retry attempts.
        """
        # NOTE(review): calling connect() while already connected replaces the
        # previous connection without closing it; call disconnect() first.
        self._websocket = await connect_with_retry(
            web_url=self.api_endpoint
            + f"?accountId={self.account_id}&hibachiClient={get_hibachi_client()}",
            headers=[("Authorization", self.api_key)],
            executor=self._executor,
        )

    def _next_message_id(self) -> int:
        """Increment the internal message counter and return the new value.

        Used as the ``id`` field of outgoing requests.

        Returns:
            The next sequential message ID.
        """
        self.message_id += 1
        return self.message_id

    def _timestamp(self) -> int:
        """Get the current Unix timestamp in seconds.

        Returns:
            The current time as an integer Unix timestamp (seconds since epoch).
        """
        return int(time.time())

    async def stream_start(self) -> AccountStreamStartResult:
        """Start the account data stream and retrieve the initial snapshot.

        Sends a ``stream.start`` request and waits for the next message,
        which is parsed as the response containing the initial account
        snapshot and listen key. The listen key is stored on
        ``self.listenKey`` for subsequent ping operations.

        Returns:
            AccountStreamStartResult containing the account snapshot (balance
            and positions) and the listen key for maintaining the stream.

        Raises:
            ValidationError: If the WebSocket connection is not established.
            SerializationError: If the request cannot be serialized.
            WebSocketMessageError: If sending the request fails.
            DeserializationError: If the response is not valid JSON or lacks
                the expected fields/structure.
        """
        # Resolve the connection before the send guard below; otherwise the
        # ValidationError for a missing connection would be caught by
        # `except Exception` and re-raised as WebSocketMessageError.
        websocket = self.websocket
        message = {
            "id": self._next_message_id(),
            "method": "stream.start",
            "params": {"accountId": self.account_id},
            "timestamp": self._timestamp(),
        }
        try:
            payload = orjson.dumps(message).decode()
        except (ValueError, TypeError) as e:
            raise SerializationError(
                f"Failed to serialize stream.start message: {e}"
            ) from e
        try:
            await websocket.send(payload)
        except Exception as e:
            raise WebSocketMessageError(
                f"Failed to send stream.start message {self.account_id=}"
            ) from e
        response = await websocket.recv()
        try:
            response_data = orjson.loads(response)
        except (ValueError, TypeError) as e:
            raise DeserializationError(
                f"Failed to parse WebSocket response: {e}"
            ) from e
        try:
            result_data = response_data["result"]
            snapshot_data = result_data["accountSnapshot"]
            snapshot = AccountSnapshot(
                account_id=snapshot_data["account_id"],
                balance=snapshot_data["balance"],
                positions=[
                    create_with(Position, pos) for pos in snapshot_data["positions"]
                ],
            )
            result = AccountStreamStartResult(
                accountSnapshot=snapshot,
                listenKey=result_data["listenKey"],
            )
        except KeyError as e:
            raise DeserializationError(
                f"Missing required field in response: {e}"
            ) from e
        except TypeError as e:
            # e.g. "result" is null, positions is not a list, or the payload
            # is not a JSON object at all.
            raise DeserializationError(
                f"Unexpected response structure: {e}"
            ) from e
        self.listenKey = result.listenKey
        return result

    async def ping(self) -> None:
        """Send a ping message to keep the account stream alive.

        Sends a ``stream.ping`` request with the current listen key and reads
        the next message as its response. A reply whose ``status`` is 200 is
        logged at debug level; any other reply is logged as a warning.

        Raises:
            ValidationError: If listenKey is not initialized (call
                ``stream_start()`` first) or the WebSocket connection is not
                established.
            SerializationError: If the request cannot be serialized.
            WebSocketMessageError: If sending the request fails.
            DeserializationError: If the response is not valid JSON.
        """
        if not self.listenKey:
            raise ValidationError("Cannot send ping: listenKey not initialized.")
        # Resolve before the send guard so a missing connection surfaces as
        # ValidationError rather than WebSocketMessageError.
        websocket = self.websocket
        message = {
            "id": self._next_message_id(),
            "method": "stream.ping",
            "params": {"accountId": self.account_id, "listenKey": self.listenKey},
            "timestamp": self._timestamp(),
        }
        try:
            payload = orjson.dumps(message).decode()
        except (ValueError, TypeError) as e:
            raise SerializationError(f"Failed to serialize ping message: {e}") from e
        try:
            await websocket.send(payload)
        except Exception as e:
            raise WebSocketMessageError(
                f"Failed to send ping message {self.account_id=}"
            ) from e
        # NOTE(review): this assumes the next frame is the ping reply. If the
        # server pushes an account event first, that event is consumed here
        # and not dispatched to handlers; the warning below makes it visible.
        response = await websocket.recv()
        try:
            parsed = orjson.loads(response)
        except (ValueError, TypeError) as e:
            raise DeserializationError(f"Failed to parse ping response: {e}") from e
        if isinstance(parsed, dict) and parsed.get("status") == 200:
            log.debug("pong!")
        else:
            log.warning("Unexpected ping response: %s", parsed)

    async def listen(self, *, timeout: float = 15) -> Json | None:
        """Listen for and process a single message from the account stream.

        Waits up to *timeout* seconds for a message. If none arrives, sends a
        ping to keep the stream alive. A received message is dispatched to
        the handlers registered for its ``topic``.

        Args:
            timeout: Seconds to wait for a message before pinging. Defaults
                to 15.

        Returns:
            The parsed message as a JSON dictionary, or None if the wait timed
            out (a ping was sent) or a WebSocketConnectionError occurred (it
            is logged as a warning, not raised).

        Raises:
            ValidationError: If the WebSocket connection is not established.
            DeserializationError: If the received message is not valid JSON.
            Exception: Any other error from receiving, pinging, or a handler.
        """
        try:
            response = await asyncio.wait_for(self.websocket.recv(), timeout=timeout)
        except asyncio.TimeoutError:
            # Only a timeout of recv() itself means the stream is idle. A
            # TimeoutError raised by a handler is no longer mistaken for it.
            await self.ping()
            return None
        except WebSocketConnectionError as e:
            log.warning("WebSocket closed: %s", e)
            return None
        except Exception as e:
            log.error("WebSocket closed: %s", e)
            raise

        try:
            return await self._handle_message(response)
        except WebSocketConnectionError as e:
            log.warning("WebSocket closed: %s", e)
            return None
        except Exception as e:
            log.error("Failed to process WebSocket message: %s", e)
            raise

    async def _handle_message(self, raw) -> Json:
        """Decode one frame from ``recv()`` and await every handler for its topic."""
        try:
            message = orjson.loads(raw)
        except (ValueError, TypeError) as e:
            raise DeserializationError(
                f"Failed to parse WebSocket message: {e}"
            ) from e
        topic = message.get("topic")
        for handler in self._event_handlers.get(topic, []):
            await handler(message)
        return message  # type: ignore

    async def disconnect(self) -> None:
        """Close the WebSocket connection and clean up resources.

        Clears the stored connection and listen key, then closes the previous
        connection (if any). State is reset even if ``close()`` raises. After
        calling this method, the client must call ``connect()`` and
        ``stream_start()`` again before listening for messages.
        """
        websocket, self._websocket = self._websocket, None
        self.listenKey = None
        if websocket is not None:
            await websocket.close()