예제 #1
0
 def test_not_usable_agent(self):
     """Checks that backtesting a non-usable agent raises ValueError.

     A freshly created linear SARSA agent is used (presumably untrained and
     therefore not usable -- compare test_backtest_with_kwargs, which sets
     ``trained = True`` before backtesting the same agent).

     The agent is built outside the ``assertRaises`` context so that only
     ``backtest_agent`` can satisfy the expected ValueError; a ValueError
     raised while constructing the agent is reported as a test error
     instead of being mistaken for the expected failure.
     """
     agent = api.get_agent_object("sarsa_learning_agent_1",
                                  model_name="linear")
     with self.assertRaises(ValueError):
         backtest_agent(agent)
예제 #2
0
 def test_multiple_backtests(self):
     """Verifies that one agent can be backtested repeatedly.

     The default agent is backtested twice in a row, and the output of
     each run must pass ``is_output_valid``.
     """
     agent = api.get_agent_object()
     for _ in range(2):
         self.is_output_valid(backtest_agent(agent))
예제 #3
0
 def test_backtest_with_kwargs(self):
     """Verifies backtesting of an agent whose actions carry extra kwargs.

     The linear SARSA agent is flagged as trained before backtesting. Its
     output must be valid and must contain an ``sa_value`` entry.
     """
     sarsa_agent = api.get_agent_object("sarsa_learning_agent_1",
                                        model_name="linear")
     sarsa_agent.trained = True
     result = backtest_agent(sarsa_agent)
     self.is_output_valid(result)
     self.assertIn("sa_value", result)
예제 #4
0
    def test_backtest(self, agent_name, data_collection_name):
        """Runs a backtest from 2015-02-01 to 2015-03-01 and validates it.

        Args:
            agent_name: name of the agent to backtest.
            data_collection_name: data collection name passed to
                ``api.get_agent_object`` together with ``agent_name``.
        """
        start, end = datetime(2015, 2, 1), datetime(2015, 3, 1)
        agent = api.get_agent_object(agent_name, data_collection_name)
        self.is_output_valid(backtest_agent(agent, from_date=start, to_date=end))
예제 #5
0
    def test_training(self, agent_name, data_collection_name, reward_name, model_name):
        """Checks that a short training run lowers the loss and leaves the agent usable.

        Args:
            agent_name: the agent to test training with.
            data_collection_name: the data collection to test training with.
            reward_name: the reward to test training with.
            model_name: the model to test training with.
        """
        agent = api.get_agent_object(agent_name, data_collection_name, reward_name, model_name)
        _, loss_history = train_agent(agent, episode_batch_size=2, num_episodes=4,
                                      min_duration=10, max_duration=20)
        # Comparing first vs. last loss needs at least two entries; fail with a
        # clear message rather than an IndexError or a self-comparison.
        self.assertGreaterEqual(len(loss_history), 2,
                                "expected at least two recorded loss values")
        # assertGreater reports both values on failure, unlike assertTrue(a > b).
        self.assertGreater(loss_history[0], loss_history[-1])
        self.assertTrue(agent.usable)
예제 #6
0
 def test_get_agent_object(self):
     """Verifies that api.get_agent_object builds the requested agent type.

     Requesting ``following_feature_agent_1`` with the ``default`` data
     collection must produce a FollowingFeatureAgent instance.
     """
     self.assertIsInstance(
         api.get_agent_object("following_feature_agent_1", "default"),
         FollowingFeatureAgent,
     )
예제 #7
0
"""Demonstration of how to use the API.
"""
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from stock_trading_backend import api, train, backtest

agent = api.get_agent_object("q_learning_agent", "generated_1",
                             "net_worth_ratio", "neural_network")
reward_history, loss_history = train.train_agent(agent, episode_batch_size=5, num_episodes=50,
                                                 min_duration=100, max_duration=150,
                                                 commission=0.01)

# Moving-average window used to smooth the reward curve.
N = 15
reward_history = np.array(reward_history)
# A "valid" convolution averages only full windows. mode="same" would
# zero-pad the edges, dragging the smoothed curve toward 0 there, and would
# return N points when the history is shorter than N. The window is clamped
# to the history length so short histories still produce a curve.
window = max(1, min(N, len(reward_history)))
reward_history_avg = np.convolve(reward_history, np.ones(window) / window, mode="valid")
# Centre each averaged point on the span of samples it summarises.
avg_x = np.arange(len(reward_history_avg)) + (window - 1) / 2

fig, axs = plt.subplots(2, figsize=(10, 10))
axs[0].plot(reward_history)
axs[0].plot(avg_x, reward_history_avg)
axs[0].plot([-1, len(reward_history)], [0, 0], 'r--')
axs[0].set_title("Reward history vs batch number")

axs[1].plot(loss_history)
axs[1].set_title("Loss history vs batch number")
axs[1].set_yscale("log")
plt.savefig("demo_training.png")
# Release the figure once it has been written to disk.
plt.close(fig)

# Run the same episode configuration with training=False (presumably an
# evaluation pass that does not update the agent -- confirm in
# train.train_agent) and report the result instead of discarding it.
reward_history, loss_history = train.train_agent(agent, episode_batch_size=5, num_episodes=50,
                                                 min_duration=100, max_duration=150,
                                                 commission=0.01, training=False)
print("Mean reward with training=False: {:.4f}".format(np.mean(reward_history)))
예제 #8
0
 def test_non_trainable_agent(self):
     """Test that training a non-trainable agent raises ValueError.

     ``train_agent`` is called without unpacking its result. Tuple unpacking
     (``_, _ = ...``) can itself raise ValueError when the number of
     returned values differs, which would let the test pass even if
     ``train_agent`` never rejected the agent.
     """
     agent = api.get_agent_object("following_feature_agent_1")
     with self.assertRaises(ValueError):
         train_agent(agent)