Example 1
def upload_as_parquet(df) -> Dataset:
    """ Generate a random parquet """
    suffix = rand_string(length=10)
    rand_name = f"tmp_parquet_{suffix}"
    df.write.mode("errorifexists").format("parquet").saveAsTable(rand_name)
    parquet_url = get_table_url(rand_name)
    logger.info(f"Saved parquet to {parquet_url}")
    return Dataset(parquet_url=parquet_url)
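
A minimal usage sketch (not from the source), assuming an active SparkSession and that upload_as_parquet is importable from the surrounding module; the toy rows and column names are purely illustrative.

from pyspark.sql import SparkSession

# Hypothetical usage: upload a small Spark DataFrame as a temporary
# parquet table and get back a Dataset handle pointing at it.
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 0.5), (2, 0.7)],           # toy rows, illustrative only
    ["sequence_number", "reward"],  # column names are assumptions
)
dataset = upload_as_parquet(df)
print(dataset.parquet_url)
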
Example 2
    def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
        runner = CliRunner()
        config = get_test_workflow_config(
            path_to_config=DQN_WORKFLOW_YAML, use_gpu=use_gpu
        )

        # create new altered config (for faster testing)
        with runner.isolated_filesystem():
            yaml = YAML(typ="safe")
            with open(NEW_CONFIG_NAME, "w") as f:
                yaml.dump(config, f)

            # unzip zipped parquet folder into cwd
            with zipfile.ZipFile(DQN_WORKFLOW_PARQUET_ZIP, "r") as zip_ref:
                zip_ref.extractall()

            # patch the two calls to spark
            # dataset points to the unzipped parquet folder
            # normalization points to mocked norm extracted from json
            mock_dataset = Dataset(
                parquet_url=f"file://{os.path.abspath(DQN_WORKFLOW_PARQUET_REL_PATH)}"
            )
            mock_normalization = mock_cartpole_normalization()
            with patch(
                "reagent.data.oss_data_fetcher.OssDataFetcher.query_data",
                return_value=mock_dataset,
            ), patch(
                f"{DISCRETE_DQN_BASE}.identify_normalization_parameters",
                return_value=mock_normalization,
            ):
                # call the cli test
                result = runner.invoke(
                    cli.run,
                    [
                        "reagent.workflow.training.identify_and_train_network",
                        NEW_CONFIG_NAME,
                    ],
                    catch_exceptions=False,
                )

            print(result.output)
            assert result.exit_code == 0, f"result = {result}"
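
The config-dump step in this test uses ruamel.yaml's YAML object API rather than PyYAML's module-level functions. A standalone sketch of just that step, with an illustrative config dict and file name (both assumptions, not taken from the source):

from ruamel.yaml import YAML

config = {"use_gpu": False, "num_train_epochs": 1}  # illustrative values only
yaml = YAML(typ="safe")
with open("altered_dqn_workflow.yaml", "w") as f:  # hypothetical file name
    yaml.dump(config, f)
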
Example 3
def _read_data(self,
               custom_reward_expression=None,
               gamma=None,
               multi_steps=None):
    ts = TableSpec(
        table_name=self.table_name,
        output_dataset=Dataset(parquet_url=self.parquet_url),
    )
    query_data(
        input_table_spec=ts,
        actions=["L", "R", "U", "D"],
        custom_reward_expression=custom_reward_expression,
        multi_steps=multi_steps,
        gamma=gamma,
    )
    df = self.sqlCtx.read.parquet(self.parquet_url)
    df = df.orderBy(asc("sequence_number"))
    logger.info("Read parquet dataframe")
    df.show()
    return df
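
The read-back at the end of _read_data is plain PySpark. A self-contained sketch of that step, assuming the output parquet has a sequence_number column as in the examples above; the path is hypothetical.

from pyspark.sql import SparkSession
from pyspark.sql.functions import asc

spark = SparkSession.builder.getOrCreate()
# Read the query output back and order rows by step index, as _read_data does.
df = spark.read.parquet("file:///tmp/example_output_parquet")  # hypothetical path
df = df.orderBy(asc("sequence_number"))
df.show()
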
Example 4
def upload_as_parquet(df) -> Dataset:
    """ Generate a random parquet. Fails if cannot generate a non-existent name. """

    # get a random tmp name and check if it exists
    sqlCtx = get_spark_session()
    success = False
    for _ in range(MAX_UPLOAD_PARQUET_TRIES):
        suffix = rand_string(length=UPLOAD_PARQUET_TMP_SUFFIX_LEN)
        rand_name = f"tmp_parquet_{suffix}"
        if not sqlCtx.catalog._jcatalog.tableExists(rand_name):
            success = True
            break
    if not success:
        raise Exception(f"Failed to find name after {MAX_UPLOAD_PARQUET_TRIES} tries.")

    # perform the write
    df.write.mode("errorifexists").format("parquet").saveAsTable(rand_name)
    parquet_url = get_table_url(rand_name)
    logger.info(f"Saved parquet to {parquet_url}")
    return Dataset(parquet_url=parquet_url)
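
A note on the existence check: the snippet goes through sqlCtx.catalog._jcatalog because older PySpark releases did not expose a public tableExists method. On PySpark 3.3+ the same check can likely be written against the public Catalog API; a hedged sketch, with a hypothetical table name:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Public-API form of the private _jcatalog call (available in PySpark 3.3+).
exists = spark.catalog.tableExists("tmp_parquet_abc123")  # hypothetical name
print(exists)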