def test_catches_config_error(self, config_filename, excpecetd_exception):
    """Verifies that building a data collection rejects a bad configuration.

    Both loading the file and building the collection happen inside the
    assertRaises context, so the expected exception may come from either.

    Args:
        config_filename: path of the configuration file to load.
        excpecetd_exception: exception type expected to be raised.
    """
    with self.assertRaises(excpecetd_exception):
        loaded_config = read_config_file(config_filename)
        create_data_collection(loaded_config)
def test_checks_date_range_set(self):
    """Verifies prepare_data raises ValueError when no date range was set."""
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    # set_date_range is deliberately never called before prepare_data.
    with self.assertRaises(ValueError):
        collection.prepare_data()
def test_catches_circular_dependency(self):
    """Verifies reset raises ValueError for a config whose data depend on each other cyclically."""
    cyclic_config = read_config_file(
        "test/circular_dependency_data_collection.yaml")
    collection = create_data_collection(cyclic_config)
    with self.assertRaises(ValueError):
        collection.reset()
def test_use_relative(self):
    """Verifies that enabling relative stock data adds a data object and a distinct stock data id."""
    relative_config = read_config_file("test/data_collection.yaml")
    relative_config["use_relative_stock_data"] = True
    collection = create_data_collection(relative_config)

    self.assertEqual(4, len(collection.data_objects))
    # With relative data enabled, stock_data_id must no longer point at
    # the absolute stock data.
    self.assertNotEqual(collection.stock_data_id,
                        collection.absolute_stock_data_id)
def test_ignores_duplicates(self):
    """Verifies that appending a second RealStockData does not hand back the new object.

    append() is expected to return a RealStockData that differs from the
    instance passed in.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    candidate = RealStockData()
    result = collection.append(candidate)

    self.assertIsInstance(result, RealStockData)
    self.assertNotEqual(candidate, result)
def test_get_availbale_dates(self):
    """Verifies get_available_dates returns 10 dates for 2016-01-01..2016-02-01."""
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
    collection.prepare_data()

    dates = collection.get_available_dates()
    self.assertEqual(10, len(dates))
def test_adds_randomization(self):
    """Verifies the stock_data_randomization flag controls the randomization layer.

    Checks the number of data objects for two config files with the flag on
    and off. When the layer is present, it must depend on the stock data.
    """
    def build(config_path, randomize):
        # Load a fresh config, set the flag, and build the collection.
        cfg = read_config_file(config_path)
        cfg["stock_data_randomization"] = randomize
        return create_data_collection(cfg)

    self.assertEqual(
        3, len(build("test/data_collection.yaml", True).data_objects))
    self.assertEqual(
        1, len(build("test/simulation.yaml", False).data_objects))

    randomized = build("test/simulation.yaml", True)
    self.assertEqual(2, len(randomized.data_objects))
    randomization_layer = randomized.data_objects[0]
    stock_data = randomized.data_objects[1]
    self.assertEqual(randomization_layer.dependencies[0], stock_data.id_str)
def _add_id_with_hash_values(self):
    """Sets ``self.id_str_with_hash`` to ``self.id_str`` extended with component hashes.

    Builds a data collection from ``self.data_collection_config`` and,
    when their configs are not None, a reward (``self.reward_config``) and
    a model (``self.model_config``). The resulting string has the form
    ``"{id_str}_{hash(data_collection)}_{hash(reward)}_{hash(model)}"``.
    An unconfigured reward or model contributes ``hash(None)``.
    """
    data_collection = create_data_collection(self.data_collection_config)
    # Reward and model are optional; None is hashed when they are absent.
    reward = (create_reward(self.reward_config, None, None)
              if self.reward_config is not None else None)
    model = (create_model(self.model_config)
             if self.model_config is not None else None)
    # NOTE(review): hash(None) is address-based before Python 3.12, and str
    # hashes are salted per process (PYTHONHASHSEED). If this id is meant to
    # be stable across runs (e.g. for caching), confirm the hashes used here
    # are deterministic.
    self.id_str_with_hash = "{}_{}_{}_{}".format(
        self.id_str, hash(data_collection), hash(reward), hash(model))
def test_prepares_data(self):
    """Verifies prepare_data marks every data object as ready.

    Also checks that, for the default test config, stock_data_id equals
    absolute_stock_data_id.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    self.assertEqual(collection.stock_data_id,
                     collection.absolute_stock_data_id)

    collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
    collection.prepare_data()
    self.assertTrue(all(obj.ready for obj in collection.data_objects))
def test_getitem(self):
    """Checks that indexing the collection by a date yields the expected index.

    After preparing January 2016, ``data_collection[first_available_date]``
    is expected to return an object whose ``index.tolist()`` equals the
    stock names followed by the derived ``ra_10_stock_data_*`` entries.
    """
    expected_index = [
        "GOOG", "AMZN", "ra_10_stock_data_GOOG", "ra_10_stock_data_AMZN"
    ]
    from_date = datetime(2016, 1, 1)
    to_date = datetime(2016, 2, 1)
    config = read_config_file("test/data_collection.yaml")
    data_collection = create_data_collection(config)
    data_collection.set_date_range(from_date, to_date)
    data_collection.prepare_data()
    available_dates = data_collection.get_available_dates()
    first_date_entry = data_collection[available_dates[0]]
    # assertEqual reports a list diff on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(expected_index, first_date_entry.index.tolist())
def test_resets_data(self):
    """Checks that reset() un-readies derived data but keeps real stock data ready.

    After the first prepare_data(), every data object is ready and
    ``recursive_counter`` is 3. After reset(), every object except the one
    named ``real_stock_data`` is no longer ready. Preparing again then
    leaves ``recursive_counter`` at 2.
    """
    from_date = datetime(2016, 1, 1)
    to_date = datetime(2016, 2, 1)
    config = read_config_file("test/data_collection.yaml")
    data_collection = create_data_collection(config)
    data_collection.set_date_range(from_date, to_date)
    data_collection.prepare_data()
    for data in data_collection.data_objects:
        self.assertTrue(data.ready)
    self.assertEqual(3, data_collection.recursive_counter)

    data_collection.reset()
    for data in data_collection.data_objects:
        # Only the real stock data is expected to survive a reset.
        if data.name != "real_stock_data":
            self.assertFalse(data.ready)
        else:
            self.assertTrue(data.ready)

    data_collection.prepare_data()
    self.assertEqual(2, data_collection.recursive_counter)
def test_creates_data_collection(self):
    """Verifies create_data_collection builds a DataCollection from the test config."""
    config = read_config_file("test/data_collection.yaml")
    self.assertIsInstance(create_data_collection(config), DataCollection)
def __init__(self, data_collection_config=None, from_date=None, to_date=None,
             min_duration=0, max_duration=0, min_start_balance=1000,
             max_start_balance=1000, commission=0, max_stock_owned=1,
             stock_data_randomization=False, reward_config=None):
    """Initializer for the simulation class.

    Args:
        data_collection_config: configuration of the data collection. If None,
            DEFAULT_DATA_COLLECTION_CONFIG_FILE is read. A caller-supplied
            dict is shallow-copied and is not modified.
        from_date: datetime date for the start of the range.
        to_date: datetime date for the end of the range. If both dates are
            None, the range 2014-01-01..2016-01-01 is used.
        min_duration: minimum length of the episode (if <= 0, equals
            max_duration).
        max_duration: maximum length of the episode (if <= 0, runs for all
            available dates). Capped at the number of available dates.
        min_start_balance: minimum starting balance.
        max_start_balance: maximum starting balance. Balance selected uniformly.
        commission: relative commission for each transaction.
        max_stock_owned: a maximum number of different stocks that can be owned.
        stock_data_randomization: whether to add stock data randomization.
        reward_config: the configuration for the reward. Defaults to
            {"name": "net_worth_ratio_reward"}.

    Raises:
        ValueError: if exactly one of from_date / to_date is None.
    """
    if data_collection_config is None:
        data_collection_config = read_config_file(
            DEFAULT_DATA_COLLECTION_CONFIG_FILE)
    else:
        # Copy so the randomization flag below is not written into the
        # caller's config object.
        data_collection_config = data_collection_config.copy()
    data_collection_config["stock_data_randomization"] = stock_data_randomization
    self.data_collection_config = data_collection_config

    if from_date is None and to_date is None:
        from_date = datetime(2014, 1, 1)
        to_date = datetime(2016, 1, 1)
    elif from_date is None or to_date is None:
        raise ValueError("Either both from and to dates are None or none of them.")

    self.data_collection = create_data_collection(data_collection_config)

    # Setup of simulation data.
    stock_data_id = self.data_collection.absolute_stock_data_id
    self.stock_data = self.data_collection.id_to_data[stock_data_id]
    self.stock_names = self.data_collection.stock_names
    self.balance = 0
    self.net_worth = 0
    self.owned_stocks = None

    # Move the start earlier by the collection's buffer (in days). Presumably
    # this gives dependent data enough history at the requested start.
    buffer = self.data_collection.get_buffer()
    from_date -= timedelta(days=buffer)

    # Setting date range for data collection.
    self.data_collection.set_date_range(from_date, to_date)
    self.data_collection.prepare_data()
    self.available_dates = self.data_collection.get_available_dates()

    # Setting duration range; non-positive values act as "use the maximum".
    max_duration = min(max_duration, len(self.available_dates))
    self.max_duration = max_duration if max_duration > 0 else len(self.available_dates)
    self.min_duration = min_duration if min_duration > 0 else self.max_duration
    self.min_duration = min(self.max_duration, self.min_duration)

    # Setting starting balance range.
    self.min_start_balance = min_start_balance
    self.max_start_balance = max_start_balance

    # Setting date tracking (-1 means "not started").
    self.curr_date_index = -1
    self.from_date_index = -1
    self.to_date_index = -1

    # Setting commission and max stock owned.
    self.commission = commission
    self.max_stock_owned = max_stock_owned

    # Setting up action space: one 3-way discrete choice per stock.
    self.action_space = (
        gym.spaces.MultiDiscrete([3] * len(self.data_collection.stock_names))
    )

    # Setting up reward function.
    if reward_config is None:
        reward_config = {"name": "net_worth_ratio_reward"}
    # NOTE(review): the reward receives the buffer-adjusted from_date, not
    # the caller's original start date — confirm this is intended.
    self.reward_function = create_reward(reward_config, from_date, to_date)

    # Setting up observation cache.
    self.saved_date_index = -1
    self.saved_observation = None
def test_hash(self):
    """Verifies that hash() can be applied to a data collection."""
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    collection_hash = hash(collection)
    self.assertIsNotNone(collection_hash)
def test_get_buffer(self):
    """Verifies get_buffer reports 10 buffer days for the test config."""
    config = read_config_file("test/data_collection.yaml")
    buffer_days = create_data_collection(config).get_buffer()
    self.assertEqual(10, buffer_days)