示例#1
0
    def test_catches_config_error(self, config_filename, excpecetd_exception):
        """Verifies that a faulty configuration triggers the expected error.

        Reading the file and building the data collection both happen inside
        the assertion context, so the exception may come from either step.

        Args:
            config_filename: path of the configuration file to load.
            excpecetd_exception: the exception type that must be raised.
        """
        with self.assertRaises(excpecetd_exception):
            faulty_config = read_config_file(config_filename)
            create_data_collection(faulty_config)
示例#2
0
 def test_checks_date_range_set(self):
     """Verifies that preparing data with no date range set raises ValueError.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     with self.assertRaises(ValueError):
         collection.prepare_data()
示例#3
0
 def test_catches_circular_dependency(self):
     """Verifies that resetting a collection built from the
     circular-dependency config raises ValueError.
     """
     config_path = "test/circular_dependency_data_collection.yaml"
     collection = create_data_collection(read_config_file(config_path))
     # Callable form: assertRaises invokes reset() and expects ValueError.
     self.assertRaises(ValueError, collection.reset)
示例#4
0
 def test_use_relative(self):
     """Verifies the collection built with relative stock data enabled.

     With "use_relative_stock_data" switched on, the collection must hold
     four data objects and its stock data id must differ from the absolute
     stock data id.
     """
     settings = read_config_file("test/data_collection.yaml")
     settings["use_relative_stock_data"] = True
     collection = create_data_collection(settings)

     object_count = len(collection.data_objects)
     self.assertEqual(4, object_count)

     relative_id = collection.stock_data_id
     absolute_id = collection.absolute_stock_data_id
     self.assertNotEqual(relative_id, absolute_id)
示例#5
0
 def test_ignores_duplicates(self):
     """Verifies that appending a duplicate data object is ignored.

     The object returned by append() must be a RealStockData that does not
     compare equal to the freshly created one that was passed in.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     duplicate = RealStockData()
     stored = collection.append(duplicate)
     self.assertIsInstance(stored, RealStockData)
     self.assertNotEqual(duplicate, stored)
示例#6
0
 def test_get_availbale_dates(self):
     """Verifies the number of available dates for a fixed date range.

     For the range 2016-01-01 to 2016-02-01 the prepared collection is
     expected to report ten available dates.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
     collection.prepare_data()
     date_count = len(collection.get_available_dates())
     self.assertEqual(10, date_count)
示例#7
0
    def test_adds_randomization(self):
        """Verifies how the randomization flag changes the built collection.

        Three scenarios are covered: the data collection config with
        randomization on (three data objects), the simulation config with
        randomization off (one data object) and the simulation config with
        randomization on (two data objects, the first depending on the
        second).
        """
        def build(config_path, randomize):
            # Load a config, set the randomization flag and build from it.
            settings = read_config_file(config_path)
            settings["stock_data_randomization"] = randomize
            return create_data_collection(settings)

        collection = build("test/data_collection.yaml", True)
        self.assertEqual(3, len(collection.data_objects))

        collection = build("test/simulation.yaml", False)
        self.assertEqual(1, len(collection.data_objects))

        collection = build("test/simulation.yaml", True)
        self.assertEqual(2, len(collection.data_objects))
        layer = collection.data_objects[0]
        underlying = collection.data_objects[1]
        self.assertEqual(layer.dependencies[0], underlying.id_str)
示例#8
0
 def _add_id_with_hash_values(self):
     """Stores ``id_str`` extended with component hashes in ``id_str_with_hash``.

     A data collection is built from ``self.data_collection_config``. A
     reward and a model are built only when ``self.reward_config`` and
     ``self.model_config`` are set; otherwise they stay ``None`` and
     ``hash(None)`` is used for them. The result has the form
     ``"<id_str>_<hash(data_collection)>_<hash(reward)>_<hash(model)>"``.
     """
     data_collection = create_data_collection(self.data_collection_config)
     reward = None
     model = None
     if self.reward_config is not None:
         # The reward is created without a date range (both dates are None).
         reward = create_reward(self.reward_config, None, None)
     if self.model_config is not None:
         model = create_model(self.model_config)
     self.id_str_with_hash = "{}_{}_{}_{}".format(self.id_str,
                                                  hash(data_collection),
                                                  hash(reward), hash(model))
示例#9
0
 def test_prepares_data(self):
     """Verifies that prepare_data() leaves every data object ready.

     Before preparing, the stock data id is also expected to equal the
     absolute stock data id for the default test configuration.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     self.assertEqual(collection.stock_data_id,
                      collection.absolute_stock_data_id)

     collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
     collection.prepare_data()
     for data_object in collection.data_objects:
         self.assertTrue(data_object.ready)
示例#10
0
 def test_getitem(self):
     """Checks if __getitem__ works as expected.

     Indexing the prepared collection by its first available date must
     yield an object whose index, converted to a list, matches the expected
     labels in order.
     """
     expected_index = [
         "GOOG", "AMZN", "ra_10_stock_data_GOOG", "ra_10_stock_data_AMZN"
     ]
     from_date = datetime(2016, 1, 1)
     to_date = datetime(2016, 2, 1)
     config = read_config_file("test/data_collection.yaml")
     data_collection = create_data_collection(config)
     data_collection.set_date_range(from_date, to_date)
     data_collection.prepare_data()
     available_dates = data_collection.get_available_dates()
     actual_index = data_collection[available_dates[0]].index.tolist()
     # assertEqual reports the differing labels when the check fails.
     self.assertEqual(expected_index, actual_index)
示例#11
0
 def test_resets_data(self):
     """Verifies that reset() clears prepared data objects as expected.

     After the first prepare_data() all objects are ready and the recursive
     counter is 3. After reset() only the object named "real_stock_data"
     stays ready. A second prepare_data() leaves the counter at 2.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))

     collection.prepare_data()
     for data_object in collection.data_objects:
         self.assertTrue(data_object.ready)
     self.assertEqual(3, collection.recursive_counter)

     collection.reset()
     for data_object in collection.data_objects:
         if data_object.name == "real_stock_data":
             self.assertTrue(data_object.ready)
         else:
             self.assertFalse(data_object.ready)

     collection.prepare_data()
     self.assertEqual(2, collection.recursive_counter)
 def test_creates_data_collection(self):
     """Verifies that the factory returns a DataCollection instance.
     """
     config = read_config_file("test/data_collection.yaml")
     created = create_data_collection(config)
     self.assertIsInstance(created, DataCollection)
    def __init__(self, data_collection_config=None, from_date=None, to_date=None, min_duration=0,
                 max_duration=0, min_start_balance=1000, max_start_balance=1000, commission=0,
                 max_stock_owned=1, stock_data_randomization=False, reward_config=None):
        """Sets up the simulation: data, episode limits, action space and reward.

        Args:
            data_collection_config: configuration of the data collection. When None,
                the default config file is read. The config is updated in place with
                the "stock_data_randomization" entry.
            from_date: datetime start of the range (2014-01-01 when both dates are None).
            to_date: datetime end of the range (2016-01-01 when both dates are None).
            min_duration: minimum length of the episode; non-positive values fall back
                to the resolved maximum duration.
            max_duration: maximum length of the episode; non-positive values mean all
                available dates, larger values are capped at the number of dates.
            min_start_balance: lower bound of the starting balance.
            max_start_balance: upper bound of the starting balance.
            commission: relative commission for each transaction.
            max_stock_owned: maximum number of different stocks that can be owned.
            stock_data_randomization: value stored in the data collection config under
                "stock_data_randomization".
            reward_config: configuration for the reward; defaults to the
                "net_worth_ratio_reward" reward when None.

        Raises:
            ValueError: if exactly one of from_date and to_date is None.
        """
        # Resolve the data collection config and record the randomization flag in it.
        if data_collection_config is None:
            data_collection_config = read_config_file(DEFAULT_DATA_COLLECTION_CONFIG_FILE)
        data_collection_config["stock_data_randomization"] = stock_data_randomization
        self.data_collection_config = data_collection_config

        # The two dates must be given together or omitted together.
        from_missing = from_date is None
        to_missing = to_date is None
        if from_missing != to_missing:
            raise ValueError("Either both from and to dates are None or none of them.")
        if from_missing:
            from_date, to_date = datetime(2014, 1, 1), datetime(2016, 1, 1)

        collection = create_data_collection(data_collection_config)
        self.data_collection = collection

        # Simulation data and initial portfolio state.
        self.stock_data = collection.id_to_data[collection.absolute_stock_data_id]
        self.stock_names = collection.stock_names
        self.balance = 0
        self.net_worth = 0
        self.owned_stocks = None

        # Move the start date back by the number of buffer days the collection reports.
        from_date = from_date - timedelta(days=collection.get_buffer())

        # Load the data for the (buffered) range and record the usable dates.
        collection.set_date_range(from_date, to_date)
        collection.prepare_data()
        self.available_dates = collection.get_available_dates()

        # Episode duration limits, bounded by the number of available dates.
        date_count = len(self.available_dates)
        capped_max = min(max_duration, date_count)
        self.max_duration = capped_max if capped_max > 0 else date_count
        requested_min = min_duration if min_duration > 0 else self.max_duration
        self.min_duration = min(self.max_duration, requested_min)

        # Starting balance limits.
        self.min_start_balance = min_start_balance
        self.max_start_balance = max_start_balance

        # Date tracking; -1 marks "not set yet".
        self.curr_date_index = -1
        self.from_date_index = -1
        self.to_date_index = -1

        # Trading constraints.
        self.commission = commission
        self.max_stock_owned = max_stock_owned

        # One three-valued discrete action per stock name.
        stock_count = len(collection.stock_names)
        self.action_space = gym.spaces.MultiDiscrete([3] * stock_count)

        # Reward function; note that from_date already includes the buffer shift here.
        if reward_config is None:
            reward_config = {"name": "net_worth_ratio_reward"}
        self.reward_function = create_reward(reward_config, from_date, to_date)

        # Observation cache; -1 / None mark an empty cache.
        self.saved_date_index = -1
        self.saved_observation = None
示例#14
0
 def test_hash(self):
     """Verifies that hashing a data collection yields a value.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     hash_value = hash(collection)
     self.assertIsNotNone(hash_value)
示例#15
0
 def test_get_buffer(self):
     """Verifies that the test configuration yields a buffer of ten days.
     """
     collection = create_data_collection(
         read_config_file("test/data_collection.yaml"))
     buffer_days = collection.get_buffer()
     self.assertEqual(10, buffer_days)