    def test_not_usable_agent(self):
        """Checks if non-usable agents raise an error with backtest.
        """
        with self.assertRaises(ValueError):
            agent = api.get_agent_object("sarsa_learning_agent_1", model_name="linear")
            backtest_agent(agent)
    def test_multiple_backtests(self):
        """Checks if the backtest can be run multiple times on the same agent.
        """
        agent = api.get_agent_object()
        output = backtest_agent(agent)
        self.is_output_valid(output)
        output = backtest_agent(agent)
        self.is_output_valid(output)
    def test_backtest_with_kwargs(self):
        """Checks if backtesting works for agents that produce kwargs.
        """
        agent = api.get_agent_object("sarsa_learning_agent_1", model_name="linear")
        agent.trained = True
        output = backtest_agent(agent)
        self.is_output_valid(output)
        self.assertIn("sa_value", output)
    def test_backtest(self, agent_name, data_collection_name):
        """Checks if backtesting works.

        Args:
            agent_name: the agent to test backtesting with.
            data_collection_name: the data collection to test backtesting with.
        """
        agent = api.get_agent_object(agent_name, data_collection_name)
        output = backtest_agent(agent, from_date=datetime(2015, 2, 1),
                                to_date=datetime(2015, 3, 1))
        self.is_output_valid(output)
    def test_training(self, agent_name, data_collection_name, reward_name, model_name):
        """Checks if training works.

        Args:
            agent_name: the agent to test training with.
            data_collection_name: the data collection to test training with.
            reward_name: the reward to test training with.
            model_name: the model to test training with.
        """
        agent = api.get_agent_object(agent_name, data_collection_name, reward_name, model_name)
        _, loss_history = train_agent(agent, episode_batch_size=2, num_episodes=4,
                                      min_duration=10, max_duration=20)
        self.assertTrue(loss_history[0] > loss_history[-1])
        self.assertTrue(agent.usable)
    def test_get_agent_object(self):
        """Checks if get_agent_object works properly.
        """
        agent = api.get_agent_object("following_feature_agent_1", "default")
        self.assertIsInstance(agent, FollowingFeatureAgent)
"""Demonstration of how to use the API. """ import matplotlib import matplotlib.pyplot as plt import numpy as np from stock_trading_backend import api, train, backtest agent = api.get_agent_object("q_learning_agent", "generated_1", "net_worth_ratio", "neural_network") reward_history, loss_history = train.train_agent(agent, episode_batch_size=5, num_episodes=50, min_duration=100, max_duration=150, commission=0.01) N = 15 reward_history = np.array(reward_history) reward_history_avg = np.convolve(reward_history, np.ones(N)/N, mode="same") fig, axs = plt.subplots(2, figsize=(10, 10)) axs[0].plot(reward_history) axs[0].plot(reward_history_avg) axs[0].plot([-1, len(reward_history)], [0, 0], 'r--') axs[0].set_title("Reward history vs batch number") axs[1].plot(loss_history) axs[1].set_title("Loss history vs batch number") axs[1].set_yscale("log") plt.savefig("demo_training.png") reward_history, loss_history = train.train_agent(agent, episode_batch_size=5, num_episodes=50, min_duration=100, max_duration=150, commission=0.01, training=False)
    def test_non_trainable_agent(self):
        """Checks if providing a non-trainable agent raises an error.
        """
        agent = api.get_agent_object("following_feature_agent_1")
        with self.assertRaises(ValueError):
            _, _ = train_agent(agent)