def test_make_decision(self):
    """A test to see if make decision works properly.

    The agent's decision is checked for injected feature values, first on
    a freshly reset simulation and then again after a ``[2, 2]`` step has
    been applied to the simulation.
    """
    from_date = datetime(2016, 1, 1)
    to_date = datetime(2016, 2, 1)
    data_collection_config = read_config_file("test/simulation.yaml")
    features = [read_config_file("test/running_average_analysis.yaml")]
    agent = FollowingFeatureAgent(
        data_collection_config=data_collection_config,
        reward_config=None,
        features=features)
    simulation = StockMarketSimulation(data_collection_config,
                                       from_date,
                                       to_date,
                                       min_start_balance=100,
                                       max_start_balance=100,
                                       max_stock_owned=2)
    observation = simulation.reset()

    def check(val_1, val_2, exp_1, exp_2):
        """Injects two feature values and compares the agent's action."""
        # `observation` is looked up at call time, so after it is rebound
        # below this helper works on the post-step observation.
        observation["ra_5_stock_data_GOOG"] = val_1
        observation["ra_5_stock_data_AMZN"] = val_2
        expected = [exp_1, exp_2]
        action, _ = agent.make_decision(observation, simulation)
        # assertEqual reports both values when they differ, unlike
        # assertTrue(expected == action).
        self.assertEqual(expected, action)

    # NOTE(review): combinations([0, 1], 2) yields only the pair (0, 1);
    # itertools.product would be needed to cover all four value pairs.
    # Confirm which one is intended.
    for i, j in itertools.combinations([0, 1], 2):
        check(i, j, i + 1, j + 1)

    observation, _, _ = simulation.step([2, 2])

    for i, j in itertools.combinations([0, 1], 2):
        check(i, j, i, j)
def test_save_and_load(self):
    """Checks if the saving and loading functions work properly.

    A trained agent must be savable and loadable, and loading must restore
    the ``trained`` flag. Saving an untrained agent must raise ValueError.
    """
    sim_config = read_config_file("test/simulation.yaml")
    linear_model_config = read_config_file("model/linear.yaml")
    agent = SARSALearningAgent(data_collection_config=sim_config,
                               model_config=linear_model_config)
    market = StockMarketSimulation(sim_config)
    first_observation = market.reset()
    # The returned decision is discarded; only the call itself matters.
    _, _ = agent.make_decision(first_observation, market)

    # Save under a fixed identifier while marked as trained.
    agent.id_str_with_hash = "test"
    agent.trained = True
    agent.save()
    self.assertTrue(agent.can_be_loaded())

    # Loading is expected to bring the trained flag back.
    agent.trained = False
    agent.load()
    self.assertTrue(agent.trained)

    with self.assertRaises(ValueError):
        agent.trained = False
        agent.save()
def test_initializes(self):
    """Checks that a new Q learning agent is initialized properly.

    The agent must keep the data collection config it was given and must
    not be usable right after construction.
    """
    sim_config = read_config_file("test/simulation.yaml")
    linear_model_config = read_config_file("model/linear.yaml")
    agent = QLearningAgent(sim_config, model_config=linear_model_config)

    self.assertEqual(sim_config, agent.data_collection_config)
    self.assertFalse(agent.usable)
def test_initializes(self):
    """Checks that a new following-feature agent is initialized properly.

    The agent must keep the data collection config it was given and expose
    the expected feature template.
    """
    sim_config = read_config_file("test/simulation.yaml")
    running_average = read_config_file("test/running_average_analysis.yaml")
    agent = FollowingFeatureAgent(data_collection_config=sim_config,
                                  reward_config=None,
                                  features=[running_average])

    self.assertEqual(sim_config, agent.data_collection_config)
    self.assertEqual("ra_5_stock_data_{}", agent.feature_template)
def test_creates_agent(self, config_filename, expected_class):
    """Checks that the agent factory builds an agent of the right class.

    Args:
        config_filename: the filename for the agent config file.
        expected_class: the class the created agent must be an instance of.
    """
    default_data_config = read_config_file("data/default.yaml")
    agent_config = read_config_file(config_filename)
    created = create_agent(agent_config, default_data_config, None)
    self.assertIsInstance(created, expected_class)
def test_make_decision(self):
    """Checks that the sarsa learning agent can make a decision.

    The returned action must contain two entries.
    """
    sim_config = read_config_file("test/simulation.yaml")
    linear_model_config = read_config_file("model/linear.yaml")
    agent = SARSALearningAgent(data_collection_config=sim_config,
                               model_config=linear_model_config)
    market = StockMarketSimulation(sim_config)
    first_observation = market.reset()

    # Only the action is inspected; the second returned value is ignored.
    decision, _ = agent.make_decision(first_observation, market)
    self.assertEqual(2, len(decision))
def test_too_expensive(self):
    """Test for ignoring expensive purchases.

    Starting with a balance of 10 and at most one stock owned, the same
    observation values are expected after a ``[2, 2]`` step and again after
    a following ``[2, 1]`` step.
    """
    first_day = datetime(2016, 1, 1)
    last_day = datetime(2016, 1, 5)
    sim_config = read_config_file("test/simulation.yaml")
    simulation = StockMarketSimulation(sim_config,
                                       first_day,
                                       last_day,
                                       min_start_balance=10,
                                       max_start_balance=10,
                                       max_stock_owned=1)
    simulation.reset()

    # Observation entries that must hold after each of the two steps.
    expected_entries = (
        ("balance", 0),
        ("net_worth", 10),
        ("owned_GOOG", 0),
        ("owned_AMZN", 1),
    )
    for requested_actions in ([2, 2], [2, 1]):
        observation, _, done = simulation.step(requested_actions)
        for key, expected_value in expected_entries:
            self.assertEqual(expected_value, observation[key])
        self.assertFalse(done)
def test_without_model_errer(self):
    """Checks if the agent raises error when initialized without model config.

    Constructing a SARSA learning agent with only a data collection config
    must raise ValueError.
    """
    # Read the config outside of the assertRaises block, so a ValueError
    # raised while reading it cannot make this test pass by accident.
    data_collection_config = read_config_file("test/simulation.yaml")
    with self.assertRaises(ValueError):
        _ = SARSALearningAgent(
            data_collection_config=data_collection_config)
def test_not_implemented(self):
    """Test not-implemented features.

    Every file helper below must raise NotImplementedError when it is
    called with the ``FileSourceType.aws`` source type.
    """
    aws = FileSourceType.aws
    # (function, positional arguments) pairs, exercised in this order.
    unsupported_calls = (
        (read_config_file, ("test/test_config.yaml", aws)),
        (read_manifest_file, ("data/test/test_manifest.json", aws)),
        (write_manifest_file, ({}, "data/test/test_manifest.json", aws)),
        (read_csv_file, ("data/test/test.csv", aws)),
        (write_csv_file, (None, "data/test/test.csv", aws)),
        (load_torch_model, ("data/test/test.pkl", aws)),
        (save_torch_model, (None, "data/test/test.pkl", aws)),
    )
    for io_function, arguments in unsupported_calls:
        with self.assertRaises(NotImplementedError):
            io_function(*arguments)
def test_config_reader(self):
    """Unit test for config reader.

    The test config must be read as a plain dict whose "test" key maps to
    the string "test".
    """
    loaded = read_config_file("test/test_config.yaml")
    # Exact type comparison on purpose: dict subclasses do not qualify.
    self.assertEqual(dict, type(loaded))
    self.assertIn("test", loaded)
    self.assertEqual("test", loaded["test"])
def test_checks_date_range_set(self):
    """Checks that preparing data without a date range raises ValueError.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    # No date range is set before the data is prepared.
    with self.assertRaises(ValueError):
        collection.prepare_data()
def test_catches_circular_dependency(self):
    """Checks that resetting a circularly dependent collection fails.

    A data collection built from the circular dependency test config must
    raise ValueError when it is reset.
    """
    circular_config = read_config_file(
        "test/circular_dependency_data_collection.yaml")
    collection = create_data_collection(circular_config)
    with self.assertRaises(ValueError):
        collection.reset()
def test_use_relative(self):
    """Checks if data collection uses relative stock data properly.

    With "use_relative_stock_data" switched on, the collection must hold
    four data objects, and the stock data id must differ from the absolute
    stock data id.
    """
    relative_config = read_config_file("test/data_collection.yaml")
    relative_config["use_relative_stock_data"] = True
    collection = create_data_collection(relative_config)

    self.assertEqual(4, len(collection.data_objects))
    self.assertNotEqual(collection.stock_data_id,
                        collection.absolute_stock_data_id)
def test_ignores_duplicates(self):
    """Checks if data collection ignores duplicates.

    Appending a new RealStockData must return a RealStockData instance
    that is not equal to the one that was passed in.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    duplicate = RealStockData()
    returned = collection.append(duplicate)

    self.assertIsInstance(returned, RealStockData)
    self.assertNotEqual(duplicate, returned)
def test_catches_config_error(self, config_filename, excpecetd_exception):
    """Checks that a faulty config raises the expected exception.

    Both reading the config file and creating the data collection happen
    inside the assertRaises block.

    Args:
        config_filename: the filename for the config file.
        excpecetd_exception: the exception type that must be raised.
    """
    with self.assertRaises(excpecetd_exception):
        faulty_config = read_config_file(config_filename)
        create_data_collection(faulty_config)
def test_creates_data(self, config_filename, expected_class):
    """Checks that the data factory builds an object of the right class.

    Args:
        config_filename: the filename for the data config file.
        expected_class: the class the created data must be an instance of.
    """
    data_config = read_config_file(config_filename)
    created = create_data(data_config)
    self.assertIsInstance(created, expected_class)
def test_initializes(self):
    """Test for simulation initialization.

    When no balance arguments are passed, the simulation's maximum start
    balance must be 1000.
    """
    first_day = datetime(2016, 1, 1)
    last_day = datetime(2016, 2, 1)
    collection_config = read_config_file("test/data_collection.yaml")
    market = StockMarketSimulation(collection_config, first_day, last_day)
    self.assertEqual(1000, market.max_start_balance)
def test_creates_agent(self, config_filename, expected_class):
    """Checks that the model factory builds a model of the right class.

    Args:
        config_filename: the filename for the model config file.
        expected_class: the class the created model must be an instance of.
    """
    model_config = read_config_file(config_filename)
    created = create_model(model_config)
    self.assertIsInstance(created, expected_class)
def get_reward_config(reward_name):
    """Reads the config file "reward/<reward_name>.yaml".

    Args:
        reward_name: string, name of the reward (returned from
            get_available_rewards).

    Returns:
        A config to use for reward factory.
    """
    config_filename = reward_name + ".yaml"
    config_path = os.path.join("reward", config_filename)
    return read_config_file(config_path)
def get_agent_config(agent_name):
    """Reads the config file "agent/<agent_name>.yaml".

    Args:
        agent_name: string, name of the agent (returned from
            get_available_agents).

    Returns:
        A config to use for agent factory.
    """
    config_filename = agent_name + ".yaml"
    config_path = os.path.join("agent", config_filename)
    return read_config_file(config_path)
def test_get_availbale_dates(self):
    """Checks if get available dates works properly.

    After preparing data for 2016-01-01 to 2016-02-01, ten available dates
    are expected.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
    collection.prepare_data()

    dates = collection.get_available_dates()
    self.assertEqual(10, len(dates))
def test_resets(self):
    """Test for simulation reset.

    After a reset, both the start date index and the current date index
    must be 0.
    """
    first_day = datetime(2016, 1, 1)
    last_day = datetime(2016, 2, 1)
    collection_config = read_config_file("test/data_collection.yaml")
    market = StockMarketSimulation(collection_config, first_day, last_day)
    market.reset()

    for date_index in (market.from_date_index, market.curr_date_index):
        self.assertEqual(0, date_index)
def setUp(self):
    """Set up for the unit tests.

    Stores the simulation test config and a simulation over 2016-01-01 to
    2016-02-01 with a fixed start balance of 100 and at most two stocks
    owned.
    """
    self.data_collection_config = read_config_file("test/simulation.yaml")
    simulation_options = {
        "min_start_balance": 100,
        "max_start_balance": 100,
        "max_stock_owned": 2,
    }
    self.simulation = StockMarketSimulation(self.data_collection_config,
                                            datetime(2016, 1, 1),
                                            datetime(2016, 2, 1),
                                            **simulation_options)
def get_model_config(model_name):
    """Reads the config file "model/<model_name>.yaml", if a name is given.

    Args:
        model_name: string, name of the model to use (returned from
            get_available_models), or None.

    Returns:
        A config to use for model factory, or None when model_name is None.
    """
    if model_name is not None:
        config_path = os.path.join("model", model_name + ".yaml")
        return read_config_file(config_path)
    return None
def test_adds_randomization(self):
    """Checks if data collection adds randomization properly.

    The number of data objects is checked for three combinations of config
    file and "stock_data_randomization" flag. For the last one, the first
    data object must list the second one's id as its first dependency.
    """

    def build_collection(config_filename, randomize):
        """Creates a data collection with the randomization flag set."""
        config = read_config_file(config_filename)
        config["stock_data_randomization"] = randomize
        return create_data_collection(config)

    collection = build_collection("test/data_collection.yaml", True)
    self.assertEqual(3, len(collection.data_objects))

    collection = build_collection("test/simulation.yaml", False)
    self.assertEqual(1, len(collection.data_objects))

    collection = build_collection("test/simulation.yaml", True)
    self.assertEqual(2, len(collection.data_objects))
    randomization_layer = collection.data_objects[0]
    stock_data = collection.data_objects[1]
    self.assertEqual(randomization_layer.dependencies[0],
                     stock_data.id_str)
def test_prepares_data(self):
    """Checks if data collection prepares the data properly.

    With the unmodified test config, the stock data id must equal the
    absolute stock data id. After preparing data for a date range, every
    data object must report that it is ready.
    """
    collection = create_data_collection(
        read_config_file("test/data_collection.yaml"))
    self.assertEqual(collection.stock_data_id,
                     collection.absolute_stock_data_id)

    collection.set_date_range(datetime(2016, 1, 1), datetime(2016, 2, 1))
    collection.prepare_data()
    for data_object in collection.data_objects:
        self.assertTrue(data_object.ready)
def test_apply_learning(self):
    """A test to see if sarsa agent can apply learning.

    The agent must be usable after learning once, and must still be usable
    after learning five more times on the same data.
    """
    data_collection_config = read_config_file("test/simulation.yaml")
    model_config = read_config_file("model/linear.yaml")
    agent = SARSALearningAgent(
        data_collection_config=data_collection_config,
        model_config=model_config,
        discount_factor=0.5)
    observations = pd.DataFrame([[0]] * 6, columns=["balance"])
    actions = [[0]] * 3 + [[1]] * 3
    rewards = [0] * 3 + [1] * 3
    sa_values = [0] * 3 + [1] * 3

    # Testing whether the sarsa learning agent becomes usable.
    agent.apply_learning([observations], [actions], [rewards], [sa_values])
    self.assertTrue(agent.usable)

    # Test applying learning multiple times. The loop variable is unused,
    # so it is named `_` instead of silencing the linter.
    for _ in range(5):
        agent.apply_learning([observations], [actions], [rewards],
                             [sa_values])
    self.assertTrue(agent.usable)
def get_data_collection_config(data_collection_name):
    """Reads the config file "data/<data_collection_name>.yaml".

    Args:
        data_collection_name: string, name of the data collection set up.

    Returns:
        A config to use for data collection factory.
    """
    config_filename = data_collection_name + ".yaml"
    config_path = os.path.join("data", config_filename)
    return read_config_file(config_path)
def test_apply_learning(self):
    """A test to see if q learning agent can apply learning.

    The agent must be usable after learning once, and usable after
    learning five more times on the same data.
    """
    sim_config = read_config_file("test/simulation.yaml")
    linear_model_config = read_config_file("model/linear.yaml")
    agent = QLearningAgent(sim_config, model_config=linear_model_config)

    observations = pd.DataFrame([[1]] * 6, columns=["shmalance"])
    actions = [[0]] * 3 + [[1]] * 3
    rewards = [-1] * 3 + [1] * 3
    sa_values = [0] * 3 + [1] * 3

    def learn_once():
        """Applies learning with freshly built single-element lists."""
        agent.apply_learning([observations], [actions], [rewards],
                             [sa_values], [sa_values])

    # Testing whether the q learning agent becomes usable.
    learn_once()
    self.assertTrue(agent.usable)

    # Test applying learning multiple times.
    for _ in range(5):
        learn_once()
    self.assertTrue(agent.usable)
def test_getitem(self):
    """Checks if __getitem__ works as expected.

    Indexing the prepared data collection with its first available date
    must give an object whose index lists the two stocks followed by their
    "ra_10_stock_data_" entries.
    """
    expected_index = [
        "GOOG", "AMZN", "ra_10_stock_data_GOOG", "ra_10_stock_data_AMZN"
    ]
    from_date = datetime(2016, 1, 1)
    to_date = datetime(2016, 2, 1)
    config = read_config_file("test/data_collection.yaml")
    data_collection = create_data_collection(config)
    data_collection.set_date_range(from_date, to_date)
    data_collection.prepare_data()
    available_dates = data_collection.get_available_dates()

    first_date = available_dates[0]
    # assertEqual shows the differing elements on failure, unlike
    # assertTrue(expected == actual).
    self.assertEqual(expected_index,
                     data_collection[first_date].index.tolist())