def test_read_manifest(self):
     """Test for manifest reader.

     Reads the fixture manifest at ``data/test/test_manifest.json`` and
     checks that the result is a dict whose ``"stock"`` entry is ``"test"``.
     """
     manifest = read_manifest_file("data/test/test_manifest.json")
     # assertIsInstance names the offending object on failure, which a
     # comparison of two type objects does not.
     self.assertIsInstance(manifest, dict)
     self.assertIn("stock", manifest)
     self.assertEqual("test", manifest["stock"])
示例#2
0
 def can_be_loaded(self):
     """Tell whether a saved version of this agent is available.

     Agents that do not require learning are always considered loadable.
     Otherwise the answer is whether ``self.id_str_with_hash`` appears in
     the manifest stored as ``manifest.json`` under ``AGENT_PATH``.
     """
     if self.requires_learning:
         manifest_path = join(AGENT_PATH, "manifest.json")
         saved_agents = read_manifest_file(manifest_path)
         return self.id_str_with_hash in saved_agents
     # Nothing has to be restored for agents that do not learn.
     return True
    def test_not_implemented(self):
        """Check that the AWS file source is rejected by every file helper.

        Each reader/writer below is invoked with ``FileSourceType.aws`` and
        is expected to raise ``NotImplementedError``.
        """
        # Config and manifest helpers.
        self.assertRaises(NotImplementedError, read_config_file,
                          "test/test_config.yaml", FileSourceType.aws)
        self.assertRaises(NotImplementedError, read_manifest_file,
                          "data/test/test_manifest.json", FileSourceType.aws)
        self.assertRaises(NotImplementedError, write_manifest_file,
                          {}, "data/test/test_manifest.json",
                          FileSourceType.aws)

        # CSV helpers.
        self.assertRaises(NotImplementedError, read_csv_file,
                          "data/test/test.csv", FileSourceType.aws)
        self.assertRaises(NotImplementedError, write_csv_file,
                          None, "data/test/test.csv", FileSourceType.aws)

        # Torch model helpers.
        self.assertRaises(NotImplementedError, load_torch_model,
                          "data/test/test.pkl", FileSourceType.aws)
        self.assertRaises(NotImplementedError, save_torch_model,
                          None, "data/test/test.pkl", FileSourceType.aws)
 def test_write_new_manifest(self):
     """Test for manifest writer for non-existent manifest.

     Writes a one-entry manifest to a scratch file, reads it back and
     checks that the ``"stock"`` entry survived the round trip.
     """
     file_path = "data/test/write_test.json"
     write_manifest_file({"stock": "write_test"}, file_path)
     try:
         manifest = read_manifest_file(file_path)
     finally:
         # Remove the scratch file even when reading fails, so a broken
         # reader cannot leave it behind for later tests.
         os.remove(file_path)
     self.assertIn("stock", manifest)
     self.assertEqual("write_test", manifest["stock"])
 def test_manifest_works_with_dates(self):
     """Test for manifest reader and writer for dates.

     Writes a manifest holding a ``datetime`` value to a scratch file,
     reads it back and checks that the value compares equal to the
     original ``datetime``.
     """
     file_path = "data/test/write_test.json"
     stock_date = datetime(2015, 1, 1)
     write_manifest_file({"stock_date": stock_date}, file_path)
     try:
         manifest = read_manifest_file(file_path)
     finally:
         # Remove the scratch file even when reading fails, so a broken
         # reader cannot leave it behind for later tests.
         os.remove(file_path)
     self.assertIn("stock_date", manifest)
     self.assertEqual(stock_date, manifest["stock_date"])
示例#6
0
    def load(self):
        """Restore the previously saved state of the agent.

        Does nothing for agents that do not require learning.  Otherwise
        the agent must be listed in ``manifest.json`` under ``AGENT_PATH``;
        its model is then loaded from the path built from
        ``MODEL_PATH_TEMPLATE`` and the agent is marked as trained.

        Raises ``LookupError`` when the manifest has no entry for
        ``self.id_str_with_hash``.
        """
        if self.requires_learning:
            manifest_file = join(AGENT_PATH, "manifest.json")
            saved_agents = read_manifest_file(manifest_file)

            if self.id_str_with_hash not in saved_agents:
                raise LookupError("Couldn't find the saved version.")

            model_file = join(
                MODEL_PATH_TEMPLATE.format(self.id_str_with_hash))
            self.model.load(model_file)
            self.trained = True
示例#7
0
    def save(self):
        """Persist the agent so that it can be loaded again later.

        Does nothing for agents that do not require learning.  Otherwise
        the model is saved to the path built from ``MODEL_PATH_TEMPLATE``
        and an entry ``{"trained": True}`` is recorded for
        ``self.id_str_with_hash`` in ``manifest.json`` under ``AGENT_PATH``.

        Raises ``ValueError`` when the agent has not been trained yet.
        """
        if self.requires_learning:
            if not self.trained:
                raise ValueError("Agent must be trained before saving it.")

            manifest_file = join(AGENT_PATH, "manifest.json")
            model_file = join(
                MODEL_PATH_TEMPLATE.format(self.id_str_with_hash))

            saved_agents = read_manifest_file(manifest_file)
            saved_agents[self.id_str_with_hash] = {"trained": True}

            # The model is written before the manifest -- presumably so the
            # manifest never lists a model that is missing on disk.
            self.model.save(model_file)
            write_manifest_file(saved_agents, manifest_file)
    def test_stock_data_for_multiple_stock(self):
        """Unit test for getting stock data for several stocks at once.

        Starts from an empty stock manifest, requests data for two stock
        names, then checks that every name was recorded in the manifest and
        that the returned data has exactly those names as its columns, in
        the same order.
        """
        from_date = datetime(2016, 1, 1)
        to_date = datetime(2016, 2, 1)
        stock_names = ["GOOG", "TSLA"]
        path = "data/test"
        manifest_path = os.path.join(path, STOCK_MANIFEST_FILE_NAME)

        write_manifest_file({}, manifest_path)
        data = get_stock_data(stock_names, from_date, to_date, path)
        self.clean_csv_files(path, stock_names)
        manifest = read_manifest_file(manifest_path)
        # Reset the manifest so later tests start from an empty one.
        write_manifest_file({}, manifest_path)

        for stock_name in stock_names:
            self.assertIn(stock_name, manifest)
        # assertEqual shows both lists on failure; assertTrue on a
        # comparison would only report "False is not true".
        self.assertEqual(stock_names, data.columns.tolist())
 def test_read_new_manifest(self):
     """Test for manifest reader for non-existent manifest.

     Reading a path that does not exist is expected to yield an empty dict.
     """
     manifest = read_manifest_file("data/test/non_existent.json")
     # assertIsInstance names the offending object on failure, which a
     # comparison of two type objects does not.
     self.assertIsInstance(manifest, dict)
     self.assertEqual(0, len(manifest))