def setUp(self):
        """
        Build the fixtures used by the tests.

        Creates a test clock and logger, a ``TradingStrategy`` registered
        under the ``TESTER-000`` trader ID, a ``RedisExecutionDatabase`` for
        that trader, and a direct ``redis.Redis`` client on db 0 of the same
        host and port.

        """
        # Fixture Setup
        clock = TestClock()
        logger = TestLogger(clock)

        self.trader_id = TraderId("TESTER", "000")

        self.strategy = TradingStrategy(order_id_tag="001")
        # Register with the same trader ID the database is created for,
        # instead of constructing a second, separate copy of it.
        self.strategy.register_trader(
            self.trader_id,
            clock,
            logger,
        )

        # Single source for the connection details so the database config and
        # the direct client below cannot drift apart.
        redis_host = 'localhost'
        redis_port = 6379

        config = {
            'host': redis_host,
            'port': redis_port,
        }

        self.database = RedisExecutionDatabase(
            trader_id=self.trader_id,
            logger=logger,
            command_serializer=MsgPackCommandSerializer(),
            event_serializer=MsgPackEventSerializer(),
            config=config,
        )

        self.test_redis = redis.Redis(host=redis_host, port=redis_port, db=0)
예제 #2
0
    def __init__(
        self,
        strategies: List[TradingStrategy],
        config: Dict[str, object],
    ):
        """
        Initialize a new instance of the TradingNode class.

        The following sections are read from `config` (each defaults to an
        empty dict when absent):

        - ``"trader"``: ``name`` and ``id_tag`` (both required), and
          ``check_residuals_delay`` (default 5.0).
        - ``"logging"``: ``log_level_console``, ``log_level_file``,
          ``log_level_store``, ``run_in_process`` (default True),
          ``log_thread_id`` (default False), ``log_to_file`` (default False)
          and ``log_file_path`` (default "").
        - ``"exec_database"``: ``type``; when it equals ``"redis"`` then
          ``host`` and ``port`` are also required. Any other value, or a
          missing ``type``, selects the bypass execution database.
        - ``"strategy"``: ``load_state`` and ``save_state`` (default True).
        - ``"adapters"``: passed to ``_setup_adapters``.

        Parameters
        ----------
        strategies : list[TradingStrategy]
            The list of strategies to run on the trading node.
        config : dict[str, object]
            The configuration for the trading node.

        Raises
        ------
        ValueError
            If strategies is None or empty.
        ValueError
            If config is None or empty.
        KeyError
            If the trader config lacks ``name`` or ``id_tag``, or the
            exec database type is ``"redis"`` and ``host`` or ``port`` is
            missing.

        """
        PyCondition.not_none(strategies, "strategies")
        PyCondition.not_none(config, "config")
        PyCondition.not_empty(strategies, "strategies")
        PyCondition.not_empty(config, "config")

        # Extract configs
        config_trader = config.get("trader", {})
        config_log = config.get("logging", {})
        config_exec_db = config.get("exec_database", {})
        config_strategy = config.get("strategy", {})
        config_adapters = config.get("adapters", {})

        self._uuid_factory = UUIDFactory()
        self._loop = asyncio.get_event_loop()
        self._executor = concurrent.futures.ThreadPoolExecutor()
        self._loop.set_default_executor(self._executor)
        self._clock = LiveClock(loop=self._loop)

        # Taken as early as the clock exists so the initialization time
        # reported at the end covers the whole build.
        self.created_time = self._clock.utc_now()
        self._is_running = False

        # Uncomment for debugging
        # self._loop.set_debug(True)

        # Setup identifiers
        self.trader_id = TraderId(
            name=config_trader["name"],
            tag=config_trader["id_tag"],
        )

        # Setup logging
        self._logger = LiveLogger(
            clock=self._clock,
            name=self.trader_id.value,
            level_console=LogLevelParser.from_str_py(config_log.get("log_level_console")),
            level_file=LogLevelParser.from_str_py(config_log.get("log_level_file")),
            level_store=LogLevelParser.from_str_py(config_log.get("log_level_store")),
            run_in_process=config_log.get("run_in_process", True),  # Run logger in a separate process
            log_thread=config_log.get("log_thread_id", False),
            log_to_file=config_log.get("log_to_file", False),
            log_file_path=config_log.get("log_file_path", ""),
        )

        self._log = LoggerAdapter(component_name=self.__class__.__name__, logger=self._logger)
        self._log_header()
        self._log.info("Building...")

        self._setup_loop()  # Requires the logger to be initialized

        self.portfolio = Portfolio(
            clock=self._clock,
            logger=self._logger,
        )

        self._data_engine = LiveDataEngine(
            loop=self._loop,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self.portfolio.register_cache(self._data_engine.cache)
        self.analyzer = PerformanceAnalyzer()

        # `.get` rather than indexing: `config_exec_db` defaults to an empty
        # dict above, so a missing section must fall through to the bypass
        # database instead of raising KeyError.
        if config_exec_db.get("type") == "redis":
            exec_db = RedisExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
                command_serializer=MsgPackCommandSerializer(),
                event_serializer=MsgPackEventSerializer(),
                config={
                    "host": config_exec_db["host"],
                    "port": config_exec_db["port"],
                }
            )
        else:
            exec_db = BypassExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
            )

        self._exec_engine = LiveExecutionEngine(
            loop=self._loop,
            database=exec_db,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self._exec_engine.load_cache()
        self._setup_adapters(config_adapters, self._logger)

        self.trader = Trader(
            trader_id=self.trader_id,
            strategies=strategies,
            portfolio=self.portfolio,
            data_engine=self._data_engine,
            exec_engine=self._exec_engine,
            clock=self._clock,
            logger=self._logger,
        )

        self._check_residuals_delay = config_trader.get("check_residuals_delay", 5.0)
        self._load_strategy_state = config_strategy.get("load_state", True)
        self._save_strategy_state = config_strategy.get("save_state", True)

        if self._load_strategy_state:
            self.trader.load()

        self._log.info("state=INITIALIZED.")
        self.time_to_initialize = self._clock.delta(self.created_time)
        self._log.info(f"Initialized in {self.time_to_initialize.total_seconds():.3f}s.")
예제 #3
0
    def __init__(
        self,
        strategies: List[TradingStrategy],
        config: Dict[str, object],
    ):
        """
        Initialize a new instance of the TradingNode class.

        The following sections are read from `config` (each defaults to an
        empty dict when absent):

        - ``"trader"``: ``name`` and ``id_tag`` (both required).
        - ``"logging"``: ``log_level_console``, ``log_level_file``,
          ``log_level_store``, ``log_thread_id`` (default True),
          ``log_to_file`` (default False) and ``log_file_path`` (default "").
        - ``"exec_database"``: ``type``; when it equals ``"redis"`` then
          ``host`` and ``port`` are also required. Any other value, or a
          missing ``type``, selects the bypass execution database.
        - ``"strategy"``: ``load_state`` and ``save_state`` (default True).
        - ``"data_clients"`` and ``"exec_clients"``: passed to
          ``_setup_data_clients`` and ``_setup_exec_clients`` respectively.

        Parameters
        ----------
        strategies : list[TradingStrategy]
            The list of strategies to run on the trading node. ``None`` is
            treated as an empty list.
        config : dict[str, object]
            The configuration for the trading node.

        Raises
        ------
        KeyError
            If the trader config lacks ``name`` or ``id_tag``, or the
            exec database type is ``"redis"`` and ``host`` or ``port`` is
            missing.

        """
        if strategies is None:
            strategies = []

        config_trader = config.get("trader", {})
        config_log = config.get("logging", {})
        config_exec_db = config.get("exec_database", {})
        config_strategy = config.get("strategy", {})
        config_data_clients = config.get("data_clients", {})
        config_exec_clients = config.get("exec_clients", {})

        self._clock = LiveClock()
        self._uuid_factory = UUIDFactory()
        self._loop = asyncio.get_event_loop()
        self._executor = concurrent.futures.ThreadPoolExecutor()
        self._loop.set_default_executor(self._executor)
        self._loop.set_debug(False)  # TODO: Development
        self._is_running = False

        # Setup identifiers
        self.trader_id = TraderId(
            name=config_trader["name"],
            tag=config_trader["id_tag"],
        )

        # Setup logging
        logger = LiveLogger(
            clock=self._clock,
            name=self.trader_id.value,
            level_console=LogLevelParser.from_str_py(
                config_log.get("log_level_console")),
            level_file=LogLevelParser.from_str_py(
                config_log.get("log_level_file")),
            level_store=LogLevelParser.from_str_py(
                config_log.get("log_level_store")),
            log_thread=config_log.get("log_thread_id", True),
            log_to_file=config_log.get("log_to_file", False),
            log_file_path=config_log.get("log_file_path", ""),
        )

        self._log = LoggerAdapter(component_name=self.__class__.__name__,
                                  logger=logger)
        self._log_header()
        self._log.info("Building...")

        self.portfolio = Portfolio(
            clock=self._clock,
            logger=logger,
        )

        self._data_engine = LiveDataEngine(
            loop=self._loop,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=logger,
        )

        self.portfolio.register_cache(self._data_engine.cache)
        self.analyzer = PerformanceAnalyzer()

        # `.get` rather than indexing: `config_exec_db` defaults to an empty
        # dict above, so a missing section must fall through to the bypass
        # database instead of raising KeyError.
        if config_exec_db.get("type") == "redis":
            exec_db = RedisExecutionDatabase(
                trader_id=self.trader_id,
                logger=logger,
                command_serializer=MsgPackCommandSerializer(),
                event_serializer=MsgPackEventSerializer(),
                config={
                    "host": config_exec_db["host"],
                    "port": config_exec_db["port"],
                })
        else:
            exec_db = BypassExecutionDatabase(
                trader_id=self.trader_id,
                logger=logger,
            )

        self._exec_engine = LiveExecutionEngine(
            loop=self._loop,
            database=exec_db,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=logger,
        )

        self._exec_engine.load_cache()
        self._setup_data_clients(config_data_clients, logger)
        self._setup_exec_clients(config_exec_clients, logger)

        self.trader = Trader(
            trader_id=self.trader_id,
            strategies=strategies,
            data_engine=self._data_engine,
            exec_engine=self._exec_engine,
            clock=self._clock,
            logger=logger,
        )

        self._check_residuals_delay = 2.0  # Hard coded delay (refactor)
        self._load_strategy_state = config_strategy.get("load_state", True)
        self._save_strategy_state = config_strategy.get("save_state", True)

        if self._load_strategy_state:
            self.trader.load()

        self._setup_loop()
        self._log.info("state=INITIALIZED.")
예제 #4
0
    def __init__(
        self,
        strategies: List[TradingStrategy],
        config: Dict[str, object],
    ):
        """
        Initialize a new instance of the TradingNode class.

        The following sections are read from `config` (each defaults to an
        empty dict when absent):

        - ``"trader"``: ``name`` and ``id_tag`` (both required).
        - ``"system"``: ``connection_timeout``, ``disconnection_timeout`` and
          ``check_residuals_delay`` (each default 5.0), and ``loop_debug``
          (default False).
        - ``"logging"``: ``level_stdout``.
        - ``"exec_database"``: ``type``; when it equals ``"redis"`` then
          ``host`` and ``port`` are also required. Any other value, or a
          missing ``type``, selects the bypass execution database.
        - ``"risk"``: passed to the risk engine as its config.
        - ``"strategy"``: ``load_state`` and ``save_state`` (default True).

        Parameters
        ----------
        strategies : list[TradingStrategy]
            The list of strategies to run on the trading node.
        config : dict[str, object]
            The configuration for the trading node.

        Raises
        ------
        ValueError
            If strategies is None or empty.
        ValueError
            If config is None or empty.
        KeyError
            If the trader config lacks ``name`` or ``id_tag``, or the
            exec database type is ``"redis"`` and ``host`` or ``port`` is
            missing.

        """
        PyCondition.not_none(strategies, "strategies")
        PyCondition.not_none(config, "config")
        PyCondition.not_empty(strategies, "strategies")
        PyCondition.not_empty(config, "config")

        self._config = config

        # Extract configs
        config_trader = config.get("trader", {})
        config_system = config.get("system", {})
        config_log = config.get("logging", {})
        config_exec_db = config.get("exec_database", {})
        config_risk = config.get("risk", {})
        config_strategy = config.get("strategy", {})

        # System config
        self._connection_timeout = config_system.get("connection_timeout", 5.0)
        self._disconnection_timeout = config_system.get(
            "disconnection_timeout", 5.0)
        self._check_residuals_delay = config_system.get(
            "check_residuals_delay", 5.0)
        self._load_strategy_state = config_strategy.get("load_state", True)
        self._save_strategy_state = config_strategy.get("save_state", True)

        # Setup loop
        self._loop = asyncio.get_event_loop()
        self._executor = concurrent.futures.ThreadPoolExecutor()
        self._loop.set_default_executor(self._executor)
        self._loop.set_debug(config_system.get("loop_debug", False))

        # Components
        self._clock = LiveClock(loop=self._loop)
        self._uuid_factory = UUIDFactory()
        self.system_id = self._uuid_factory.generate()
        # Taken as early as the clock exists so the initialization time
        # reported at the end covers the whole build.
        self.created_time = self._clock.utc_now()
        self._is_running = False

        # Setup identifiers
        self.trader_id = TraderId(
            name=config_trader["name"],
            tag=config_trader["id_tag"],
        )

        # Setup logging
        level_stdout = LogLevelParser.from_str_py(
            config_log.get("level_stdout"))

        self._logger = LiveLogger(
            loop=self._loop,
            clock=self._clock,
            trader_id=self.trader_id,
            system_id=self.system_id,
            level_stdout=level_stdout,
        )

        self._log = LoggerAdapter(
            component=self.__class__.__name__,
            logger=self._logger,
        )

        self._log_header()
        self._log.info("Building...")

        if platform.system() != "Windows":
            # Requires the logger to be initialized
            # Windows does not support signal handling
            # https://stackoverflow.com/questions/45987985/asyncio-loops-add-signal-handler-in-windows
            self._setup_loop()

        # Build platform
        # ----------------------------------------------------------------------
        self.portfolio = Portfolio(
            clock=self._clock,
            logger=self._logger,
        )

        self._data_engine = LiveDataEngine(
            loop=self._loop,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self.portfolio.register_cache(self._data_engine.cache)
        self.analyzer = PerformanceAnalyzer()

        # `.get` rather than indexing: `config_exec_db` defaults to an empty
        # dict above, so a missing section must fall through to the bypass
        # database instead of raising KeyError.
        if config_exec_db.get("type") == "redis":
            exec_db = RedisExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
                command_serializer=MsgPackCommandSerializer(),
                event_serializer=MsgPackEventSerializer(),
                config={
                    "host": config_exec_db["host"],
                    "port": config_exec_db["port"],
                },
            )
        else:
            exec_db = BypassExecutionDatabase(
                trader_id=self.trader_id,
                logger=self._logger,
            )

        self._exec_engine = LiveExecutionEngine(
            loop=self._loop,
            database=exec_db,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config={"qsize": 10000},
        )

        self._risk_engine = LiveRiskEngine(
            loop=self._loop,
            exec_engine=self._exec_engine,
            portfolio=self.portfolio,
            clock=self._clock,
            logger=self._logger,
            config=config_risk,
        )

        self._exec_engine.load_cache()
        self._exec_engine.register_risk_engine(self._risk_engine)

        self.trader = Trader(
            trader_id=self.trader_id,
            strategies=strategies,
            portfolio=self.portfolio,
            data_engine=self._data_engine,
            exec_engine=self._exec_engine,
            risk_engine=self._risk_engine,
            clock=self._clock,
            logger=self._logger,
        )

        if self._load_strategy_state:
            self.trader.load()

        self._builder = TradingNodeBuilder(
            data_engine=self._data_engine,
            exec_engine=self._exec_engine,
            risk_engine=self._risk_engine,
            clock=self._clock,
            logger=self._logger,
            log=self._log,
        )

        self._log.info("state=INITIALIZED.")
        self.time_to_initialize = self._clock.delta(self.created_time)
        self._log.info(
            f"Initialized in {self.time_to_initialize.total_seconds():.3f}s.")

        self._is_built = False
예제 #5
0
class RedisExecutionDatabaseTests(unittest.TestCase):
    """
    Tests for `RedisExecutionDatabase` run against a Redis server.

    The fixtures connect to Redis on ``localhost:6379``. Every test gets a
    fresh database object in `setUp`, and `tearDown` flushes Redis through a
    direct client (db 0) so state does not carry over between tests.
    """

    def setUp(self):
        # Fixture Setup
        self.clock = TestClock()
        self.logger = Logger(self.clock)
        self.trader_id = TraderId("TESTER", "000")

        # The strategy is only used as a source of orders (via its
        # order_factory) and as the subject of test_delete_strategy.
        self.strategy = TradingStrategy(order_id_tag="001")
        self.strategy.register_trader(self.trader_id, self.clock, self.logger)

        config = {
            "host": "localhost",
            "port": 6379,
        }

        # Object under test
        self.database = RedisExecutionDatabase(
            trader_id=self.trader_id,
            logger=self.logger,
            command_serializer=MsgPackCommandSerializer(),
            event_serializer=MsgPackEventSerializer(),
            config=config,
        )

        # Direct client used by tearDown to flush Redis
        self.test_redis = redis.Redis(host="localhost", port=6379, db=0)

    def tearDown(self):
        # Tests will start failing if redis is not flushed on tear down
        self.test_redis.flushall()  # Comment this line out to preserve data between tests

    def test_add_account(self):
        # Arrange
        event = TestStubs.event_account_state()
        account = Account(event)

        # Act
        self.database.add_account(account)

        # Assert
        self.assertEqual(account, self.database.load_account(account.id))

    def test_add_order(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        # Act
        self.database.add_order(order)

        # Assert
        self.assertEqual(order, self.database.load_order(order.client_order_id))

    def test_add_position(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order)

        position_id = PositionId("P-1")
        fill = TestStubs.event_order_filled(
            order,
            instrument=AUDUSD_SIM,
            position_id=position_id,
            last_px=Price("1.00000"),
        )

        position = Position(fill=fill)

        # Act
        self.database.add_position(position)

        # Assert
        self.assertEqual(position, self.database.load_position(position.id))

    def test_update_account(self):
        # Arrange
        event = TestStubs.event_account_state()
        account = Account(event)
        self.database.add_account(account)

        # Act
        self.database.update_account(account)

        # Assert
        self.assertEqual(account, self.database.load_account(account.id))

    def test_update_order_for_working_order(self):
        # Arrange
        order = self.strategy.order_factory.stop_market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"),
        )

        self.database.add_order(order)

        # Persist after each applied event: submitted, then accepted
        order.apply(TestStubs.event_order_submitted(order))
        self.database.update_order(order)

        order.apply(TestStubs.event_order_accepted(order))

        # Act
        self.database.update_order(order)

        # Assert
        self.assertEqual(order, self.database.load_order(order.client_order_id))

    def test_update_order_for_completed_order(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order)

        # Persist after each applied event: submitted, accepted, then filled
        order.apply(TestStubs.event_order_submitted(order))
        self.database.update_order(order)

        order.apply(TestStubs.event_order_accepted(order))
        self.database.update_order(order)

        fill = TestStubs.event_order_filled(
            order,
            instrument=AUDUSD_SIM,
            last_px=Price("1.00001"),
        )

        order.apply(fill)

        # Act
        self.database.update_order(order)

        # Assert
        self.assertEqual(order, self.database.load_order(order.client_order_id))

    def test_update_position_for_closed_position(self):
        # Arrange
        order1 = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        position_id = PositionId("P-1")
        self.database.add_order(order1)

        order1.apply(TestStubs.event_order_submitted(order1))
        self.database.update_order(order1)

        order1.apply(TestStubs.event_order_accepted(order1))
        self.database.update_order(order1)

        order1.apply(
            TestStubs.event_order_filled(
                order1,
                instrument=AUDUSD_SIM,
                position_id=position_id,
                last_px=Price("1.00001"),
            )
        )
        self.database.update_order(order1)

        # Arrange (continued): store the position built from order1's fill,
        # then apply an offsetting SELL fill of the same quantity to it
        position = Position(fill=order1.last_event)
        self.database.add_position(position)

        order2 = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.SELL,
            Quantity(100000),
        )

        self.database.add_order(order2)

        order2.apply(TestStubs.event_order_submitted(order2))
        self.database.update_order(order2)

        order2.apply(TestStubs.event_order_accepted(order2))
        self.database.update_order(order2)

        filled = TestStubs.event_order_filled(
            order2,
            instrument=AUDUSD_SIM,
            position_id=position_id,
            last_px=Price("1.00001"),
        )

        order2.apply(filled)
        self.database.update_order(order2)

        position.apply(filled)

        # Act
        self.database.update_position(position)

        # Assert
        self.assertEqual(position, self.database.load_position(position.id))

    def test_update_strategy(self):
        # Arrange
        strategy = MockStrategy(TestStubs.bartype_btcusdt_binance_100tick_last())
        strategy.register_trader(self.trader_id, self.clock, self.logger)

        # Act
        self.database.update_strategy(strategy)
        result = self.database.load_strategy(strategy.id)

        # Assert
        # NOTE(review): presumably {"UserState": b"1"} is the state MockStrategy
        # saves - confirm against MockStrategy
        self.assertEqual({"UserState": b"1"}, result)

    def test_load_account_when_no_account_in_database_returns_none(self):
        # Arrange
        # The account is built only for its ID; it is never added
        event = TestStubs.event_account_state()
        account = Account(event)

        # Act
        result = self.database.load_account(account.id)

        # Assert
        self.assertIsNone(result)

    def test_load_account_when_account_in_database_returns_account(self):
        # Arrange
        event = TestStubs.event_account_state()
        account = Account(event)
        self.database.add_account(account)

        # Act
        result = self.database.load_account(account.id)

        # Assert
        self.assertEqual(account, result)

    def test_load_order_when_no_order_in_database_returns_none(self):
        # Arrange
        # The order is built only for its client order ID; it is never added
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        # Act
        result = self.database.load_order(order.client_order_id)

        # Assert
        self.assertIsNone(result)

    def test_load_order_when_market_order_in_database_returns_order(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order)

        # Act
        result = self.database.load_order(order.client_order_id)

        # Assert
        self.assertEqual(order, result)

    def test_load_order_when_limit_order_in_database_returns_order(self):
        # Arrange
        order = self.strategy.order_factory.limit(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"),
        )

        self.database.add_order(order)

        # Act
        result = self.database.load_order(order.client_order_id)

        # Assert
        self.assertEqual(order, result)

    def test_load_order_when_stop_market_order_in_database_returns_order(self):
        # Arrange
        order = self.strategy.order_factory.stop_market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"),
        )

        self.database.add_order(order)

        # Act
        result = self.database.load_order(order.client_order_id)

        # Assert
        self.assertEqual(order, result)

    def test_load_order_when_stop_limit_order_in_database_returns_order(self):
        # Arrange
        order = self.strategy.order_factory.stop_limit(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            price=Price("1.00000"),
            trigger=Price("1.00010"),
        )

        self.database.add_order(order)

        # Act
        result = self.database.load_order(order.client_order_id)

        # Assert
        # Price and trigger are also compared explicitly, in addition to
        # order equality
        self.assertEqual(order, result)
        self.assertEqual(order.price, result.price)
        self.assertEqual(order.trigger, result.trigger)

    def test_load_position_when_no_position_in_database_returns_none(self):
        # Arrange
        position_id = PositionId("P-1")

        # Act
        result = self.database.load_position(position_id)

        # Assert
        self.assertIsNone(result)

    # NOTE(review): the name says "load_order" but this test exercises
    # load_position; consider renaming it to
    # test_load_position_when_position_in_database_returns_position.
    def test_load_order_when_position_in_database_returns_position(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order)

        position_id = PositionId("P-1")
        fill = TestStubs.event_order_filled(
            order,
            instrument=AUDUSD_SIM,
            position_id=position_id,
            last_px=Price("1.00000"),
        )

        position = Position(fill=fill)

        self.database.add_position(position)

        # Act
        result = self.database.load_position(position_id)
        # Assert
        self.assertEqual(position, result)

    def test_load_accounts_when_no_accounts_returns_empty_dict(self):
        # Arrange
        # Act
        result = self.database.load_accounts()

        # Assert
        self.assertEqual({}, result)

    def test_load_accounts_cache_when_one_account_in_database(self):
        # Arrange
        event = TestStubs.event_account_state()
        account = Account(event)
        self.database.add_account(account)

        # Act
        # Assert
        self.assertEqual({account.id: account}, self.database.load_accounts())

    def test_load_orders_cache_when_no_orders(self):
        # Arrange
        # Act
        # The result of this first call is discarded; the assertion below
        # calls load_orders again
        self.database.load_orders()

        # Assert
        self.assertEqual({}, self.database.load_orders())

    def test_load_orders_cache_when_one_order_in_database(self):
        # Arrange
        order = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order)

        # Act
        result = self.database.load_orders()

        # Assert
        self.assertEqual({order.client_order_id: order}, result)

    def test_load_positions_cache_when_no_positions(self):
        # Arrange
        # Act
        # The result of this first call is discarded; the assertion below
        # calls load_positions again
        self.database.load_positions()

        # Assert
        self.assertEqual({}, self.database.load_positions())

    def test_load_positions_cache_when_one_position_in_database(self):
        # Arrange
        order1 = self.strategy.order_factory.stop_market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"),
        )

        self.database.add_order(order1)

        position_id = PositionId("P-1")
        order1.apply(TestStubs.event_order_submitted(order1))
        order1.apply(TestStubs.event_order_accepted(order1))
        order1.apply(
            TestStubs.event_order_filled(
                order1,
                instrument=AUDUSD_SIM,
                position_id=position_id,
                last_px=Price("1.00001"),
            )
        )

        # The position is built from the fill event just applied to order1
        position = Position(fill=order1.last_event)
        self.database.add_position(position)

        # Act
        result = self.database.load_positions()

        # Assert
        self.assertEqual({position.id: position}, result)

    def test_delete_strategy(self):
        # Arrange
        # Act
        # No strategy state is stored by this test before the delete
        self.database.delete_strategy(self.strategy.id)
        result = self.database.load_strategy(self.strategy.id)

        # Assert
        self.assertEqual({}, result)

    def test_flush(self):
        # Arrange
        order1 = self.strategy.order_factory.market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
        )

        self.database.add_order(order1)

        position1_id = PositionId("P-1")
        fill = TestStubs.event_order_filled(
            order1,
            instrument=AUDUSD_SIM,
            position_id=position1_id,
            last_px=Price("1.00000"),
        )

        # The fill is used to build the position; it is not applied to order1
        position1 = Position(fill=fill)
        self.database.update_order(order1)
        self.database.add_position(position1)

        order2 = self.strategy.order_factory.stop_market(
            AUDUSD_SIM.id,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"),
        )

        self.database.add_order(order2)

        order2.apply(TestStubs.event_order_submitted(order2))
        order2.apply(TestStubs.event_order_accepted(order2))

        self.database.update_order(order2)

        # Act
        self.database.flush()

        # Assert
        self.assertIsNone(self.database.load_order(order1.client_order_id))
        self.assertIsNone(self.database.load_order(order2.client_order_id))
        self.assertIsNone(self.database.load_position(position1.id))
예제 #6
0
class RedisExecutionDatabaseTests(unittest.TestCase):

    def setUp(self):
        """Create the strategy, the Redis execution database and a raw Redis client for each test."""
        # Fixture Setup
        test_clock = TestClock()
        test_uuid_factory = TestUUIDFactory()
        test_logger = TestLogger(test_clock)

        self.trader_id = TraderId("TESTER", "000")

        self.strategy = EmptyStrategy(order_id_tag="001")
        self.strategy.register_trader(TraderId("TESTER", "000"), test_clock, test_uuid_factory, test_logger)

        # Connection settings handed to the database under test
        redis_config = dict(host='localhost', port=6379)

        self.database = RedisExecutionDatabase(
            trader_id=self.trader_id,
            logger=test_logger,
            command_serializer=MsgPackCommandSerializer(),
            event_serializer=MsgPackEventSerializer(),
            config=redis_config,
        )

        # Separate client on the same host/port; tearDown uses it to flush Redis
        self.test_redis = redis.Redis(host="localhost", port=6379, db=0)

    def tearDown(self):
        """Flush Redis through the test client after every test.

        Tests will start failing if redis is not flushed on tear down.
        """
        self.test_redis.flushall()  # Comment this line out to preserve data between tests

    def test_keys(self):
        # Arrange
        # Act
        # Assert
        self.assertEqual("Trader-TESTER-000", self.database.key_trader)
        self.assertEqual("Trader-TESTER-000:Accounts:", self.database.key_accounts)
        self.assertEqual("Trader-TESTER-000:Orders:", self.database.key_orders)
        self.assertEqual("Trader-TESTER-000:Positions:", self.database.key_positions)
        self.assertEqual("Trader-TESTER-000:Strategies:", self.database.key_strategies)

    # TODO: AccountState not finalized (no serialization yet)
    # def test_add_account(self):
    #     # Arrange
    #     event = TestStubs.account_event()
    #     account = Account(event)
    #
    #     # Act
    #     self.database.add_account(account)
    #
    #     # Assert
    #     self.assertEqual(account, self.database.load_account(account.id))

    def test_add_order(self):
        """Check that an order added with a null position id can be loaded back by its client order id."""
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        null_position_id = PositionId.py_null()

        # Act
        self.database.add_order(market_order, null_position_id)

        # Assert
        loaded = self.database.load_order(market_order.cl_ord_id)
        self.assertEqual(market_order, loaded)

    def test_add_position(self):
        """Check that a position added to the database can be loaded back by its id."""
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        position_id = PositionId('P-1')
        self.database.add_order(market_order, position_id)

        # Build the position from a fill event generated for the stored order
        fill = TestStubs.event_order_filled(
            market_order,
            position_id=position_id,
            fill_price=Price("1.00000"),
        )
        position = Position(fill)

        # Act
        self.database.add_position(position)

        # Assert
        loaded = self.database.load_position(position.id)
        self.assertEqual(position, loaded)

    # TODO: Investigate why str is deserializing to a list
    # def test_update_account(self):
    #     # Arrange
    #     event = TestStubs.account_event()
    #     account = Account(event)
    #     self.database.add_account(account)
    #
    #     # Act
    #     self.database.update_account(account)
    #
    #     # Assert
    #     self.assertEqual(account, self.database.load_account(account.id))

    def test_update_order_for_working_order(self):
        """Check that a stop order updated after its working event loads back equal to itself."""
        # Arrange
        factory = self.strategy.order_factory
        stop_order = factory.stop(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000), Price("1.00000"))
        self.database.add_order(stop_order, PositionId('P-1'))

        # Persist the order after each of the submitted and accepted events
        for make_event in (TestStubs.event_order_submitted, TestStubs.event_order_accepted):
            stop_order.apply(make_event(stop_order))
            self.database.update_order(stop_order)

        # Act
        stop_order.apply(TestStubs.event_order_working(stop_order))
        self.database.update_order(stop_order)

        # Assert
        loaded = self.database.load_order(stop_order.cl_ord_id)
        self.assertEqual(stop_order, loaded)

    def test_update_order_for_completed_order(self):
        """Check that a market order updated after its fill event loads back equal to itself."""
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        self.database.add_order(market_order, PositionId('P-1'))

        # Persist the order after each of the submitted and accepted events
        for make_event in (TestStubs.event_order_submitted, TestStubs.event_order_accepted):
            market_order.apply(make_event(market_order))
            self.database.update_order(market_order)

        fill = TestStubs.event_order_filled(market_order, fill_price=Price("1.00001"))
        market_order.apply(fill)

        # Act
        self.database.update_order(market_order)

        # Assert
        loaded = self.database.load_order(market_order.cl_ord_id)
        self.assertEqual(market_order, loaded)

    def test_update_position_for_closed_position(self):
        """Check that a position updated after an opposite-side fill loads back equal to itself."""
        # Arrange
        factory = self.strategy.order_factory
        lifecycle_events = (TestStubs.event_order_submitted, TestStubs.event_order_accepted)
        position_id = PositionId('P-1')

        # BUY order: stored, then persisted after submitted, accepted and filled
        buy_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        self.database.add_order(buy_order, position_id)
        for make_event in lifecycle_events:
            buy_order.apply(make_event(buy_order))
            self.database.update_order(buy_order)

        buy_fill = TestStubs.event_order_filled(
            buy_order,
            position_id=position_id,
            fill_price=Price("1.00001"),
        )
        buy_order.apply(buy_fill)
        self.database.update_order(buy_order)

        # Position is created from the BUY order's last event and stored
        position = Position(buy_order.last_event())
        self.database.add_position(position)

        # SELL order for the same position id, taken through the same lifecycle
        sell_order = factory.market(AUDUSD_FXCM, OrderSide.SELL, Quantity(100000))
        self.database.add_order(sell_order, position_id)
        for make_event in lifecycle_events:
            sell_order.apply(make_event(sell_order))
            self.database.update_order(sell_order)

        sell_fill = TestStubs.event_order_filled(
            sell_order,
            position_id=position_id,
            fill_price=Price("1.00001"),
        )
        sell_order.apply(sell_fill)
        self.database.update_order(sell_order)

        # The SELL fill is applied to the in-memory position before the update under test
        position.apply(sell_fill)

        # Act
        self.database.update_position(position)

        # Assert
        loaded = self.database.load_position(position.id)
        self.assertEqual(position, loaded)

    def test_load_account_when_no_account_in_database_returns_none(self):
        """Check that load_account returns None for an account that was never added."""
        # Arrange
        account = Account(TestStubs.event_account_state())

        # Act
        loaded = self.database.load_account(account.id)

        # Assert
        self.assertIsNone(loaded)

    # TODO: Investigate why str is deserializing to a list
    # def test_load_account_when_account_in_database_returns_account(self):
    #     # Arrange
    #     event = TestStubs.account_event()
    #     account = Account(event)
    #     self.database.add_account(account)
    #
    #     # Act
    #     result = self.database.load_account(account.id)
    #
    #     # Assert
    #     self.assertEqual(account, result)

    def test_load_order_when_no_order_in_database_returns_none(self):
        """Check that load_order returns None for an order that was never added."""
        # Arrange
        factory = self.strategy.order_factory
        unsaved_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))

        # Act
        loaded = self.database.load_order(unsaved_order.cl_ord_id)

        # Assert
        self.assertIsNone(loaded)

    def test_load_order_when_order_in_database_returns_order(self):
        """Check that load_order returns an order equal to the one previously added."""
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        self.database.add_order(market_order, PositionId('P-1'))

        # Act
        loaded = self.database.load_order(market_order.cl_ord_id)

        # Assert
        self.assertEqual(market_order, loaded)

    def test_load_position_when_no_position_in_database_returns_none(self):
        """Check that load_position returns None for a position id that was never added."""
        # Arrange
        unknown_position_id = PositionId('P-1')

        # Act
        loaded = self.database.load_position(unknown_position_id)

        # Assert
        self.assertIsNone(loaded)

    def test_load_order_when_position_in_database_returns_position(self):
        """Check that load_position returns a position equal to the one previously added."""
        # NOTE(review): the method name says "load_order" but the body exercises
        # load_position; the name is left unchanged here.
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        position_id = PositionId('P-1')
        self.database.add_order(market_order, position_id)

        fill = TestStubs.event_order_filled(
            market_order,
            position_id=position_id,
            fill_price=Price("1.00000"),
        )
        position = Position(fill)
        self.database.add_position(position)

        # Act
        loaded = self.database.load_position(position_id)

        # Assert
        self.assertEqual(position, loaded)

    def test_load_accounts_when_no_accounts_returns_empty_dict(self):
        # Arrange
        # Act
        result = self.database.load_accounts()

        # Assert
        self.assertEqual({}, result)

    # TODO: Investigate why str is deserializing to a list
    # def test_load_accounts_cache_when_one_account_in_database(self):
    #     # Arrange
    #     event = TestStubs.account_event()
    #     account = Account(event)
    #     self.database.add_account(account)
    #
    #     # Act
    #
    #     # Assert
    #     self.assertEqual(account, self.database.load_account(account.id))

    def test_load_orders_cache_when_no_orders(self):
        # Arrange
        # Act
        self.database.load_orders()

        # Assert
        self.assertEqual({}, self.database.load_orders())

    def test_load_orders_cache_when_one_order_in_database(self):
        """Check that load_orders returns the single added order keyed by its client order id."""
        # Arrange
        factory = self.strategy.order_factory
        market_order = factory.market(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000))
        self.database.add_order(market_order, PositionId('P-1'))

        # Act
        loaded_orders = self.database.load_orders()

        # Assert
        expected = {market_order.cl_ord_id: market_order}
        self.assertEqual(expected, loaded_orders)

    def test_load_positions_cache_when_no_positions(self):
        # Arrange
        # Act
        self.database.load_positions()

        # Assert
        self.assertEqual({}, self.database.load_positions())

    def test_load_positions_cache_when_one_position_in_database(self):
        """Check that load_positions returns the single added position keyed by its id."""
        # Arrange
        factory = self.strategy.order_factory
        stop_order = factory.stop(AUDUSD_FXCM, OrderSide.BUY, Quantity(100000), Price("1.00000"))
        position_id = PositionId('P-1')
        self.database.add_order(stop_order, position_id)

        # Events are applied to the in-memory order only; update_order is not called here
        pre_fill_events = (
            TestStubs.event_order_submitted,
            TestStubs.event_order_accepted,
            TestStubs.event_order_working,
        )
        for make_event in pre_fill_events:
            stop_order.apply(make_event(stop_order))
        fill = TestStubs.event_order_filled(
            stop_order,
            position_id=position_id,
            fill_price=Price("1.00001"),
        )
        stop_order.apply(fill)

        position = Position(stop_order.last_event())
        self.database.add_position(position)

        # Act
        loaded_positions = self.database.load_positions()

        # Assert
        expected = {position.id: position}
        self.assertEqual(expected, loaded_positions)

    def test_can_delete_strategy(self):
        # Arrange
        # Act
        self.database.delete_strategy(self.strategy.id)
        result = self.database.load_strategy(self.strategy.id)

        # Assert
        self.assertEqual({}, result)

    def test_can_flush(self):
        # Arrange
        order1 = self.strategy.order_factory.market(
            AUDUSD_FXCM,
            OrderSide.BUY,
            Quantity(100000))
        position1_id = PositionId('P-1')
        self.database.add_order(order1, position1_id)

        filled = TestStubs.event_order_filled(order1, position_id=position1_id, fill_price=Price("1.00000"))
        position1 = Position(filled)
        self.database.update_order(order1)
        self.database.add_position(position1)

        order2 = self.strategy.order_factory.stop(
            AUDUSD_FXCM,
            OrderSide.BUY,
            Quantity(100000),
            Price("1.00000"))

        position2_id = PositionId('P-2')
        self.database.add_order(order2, position2_id)

        order2.apply(TestStubs.event_order_submitted(order2))
        order2.apply(TestStubs.event_order_accepted(order2))
        order2.apply(TestStubs.event_order_working(order2))

        self.database.update_order(order2)

        # Act
        self.database.flush()