def upload_as_parquet(df) -> Dataset:
    """Save the dataframe as a parquet table under a random temporary name."""
    suffix = rand_string(length=10)
    rand_name = f"tmp_parquet_{suffix}"
    df.write.mode("errorifexists").format("parquet").saveAsTable(rand_name)
    parquet_url = get_table_url(rand_name)
    logger.info(f"Saved parquet to {parquet_url}")
    return Dataset(parquet_url=parquet_url)
def _test_dqn_workflow(self, use_gpu=False, use_all_avail_gpus=False):
    runner = CliRunner()
    config = get_test_workflow_config(
        path_to_config=DQN_WORKFLOW_YAML, use_gpu=use_gpu
    )

    # create new altered config (for faster testing)
    with runner.isolated_filesystem():
        yaml = YAML(typ="safe")
        with open(NEW_CONFIG_NAME, "w") as f:
            yaml.dump(config, f)

        # unzip zipped parquet folder into cwd
        with zipfile.ZipFile(DQN_WORKFLOW_PARQUET_ZIP, "r") as zip_ref:
            zip_ref.extractall()

        # patch the two calls to spark
        # dataset points to the unzipped parquet folder
        # normalization points to mocked norm extracted from json
        mock_dataset = Dataset(
            parquet_url=f"file://{os.path.abspath(DQN_WORKFLOW_PARQUET_REL_PATH)}"
        )
        mock_normalization = mock_cartpole_normalization()
        with patch(
            "reagent.data.oss_data_fetcher.OssDataFetcher.query_data",
            return_value=mock_dataset,
        ), patch(
            f"{DISCRETE_DQN_BASE}.identify_normalization_parameters",
            return_value=mock_normalization,
        ):
            # call the cli test
            result = runner.invoke(
                cli.run,
                [
                    "reagent.workflow.training.identify_and_train_network",
                    NEW_CONFIG_NAME,
                ],
                catch_exceptions=False,
            )
            print(result.output)
            assert result.exit_code == 0, f"result = {result}"
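# The test above combines two well-known pieces of machinery: Click's
# CliRunner with isolated_filesystem(), and unittest.mock.patch to stub out
# expensive Spark calls. A minimal, self-contained sketch of that pattern
# follows; the toy `run` command and `fetch_config` helper are stand-ins
# invented for illustration, not ReAgent code.
from unittest.mock import patch

import click
from click.testing import CliRunner


def fetch_config():
    # Stand-in for an expensive call (e.g. a Spark query) that tests avoid.
    raise RuntimeError("should be patched in tests")


@click.command()
def run():
    click.echo(fetch_config())


def test_run_with_patched_dependency():
    runner = CliRunner()
    # isolated_filesystem() gives the CLI a throwaway cwd to write into.
    with runner.isolated_filesystem():
        # Patch the module-level dependency so the command never calls it.
        with patch(f"{__name__}.fetch_config", return_value="ok"):
            result = runner.invoke(run, [], catch_exceptions=False)
    assert result.exit_code == 0
    assert "ok" in result.output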
def _read_data(self, custom_reward_expression=None, gamma=None, multi_steps=None):
    ts = TableSpec(
        table_name=self.table_name,
        output_dataset=Dataset(parquet_url=self.parquet_url),
    )
    query_data(
        input_table_spec=ts,
        actions=["L", "R", "U", "D"],
        custom_reward_expression=custom_reward_expression,
        multi_steps=multi_steps,
        gamma=gamma,
    )
    df = self.sqlCtx.read.parquet(self.parquet_url)
    df = df.orderBy(asc("sequence_number"))
    logger.info("Read parquet dataframe")
    df.show()
    return df
def upload_as_parquet(df) -> Dataset:
    """Save the dataframe as a parquet table under a random temporary name.

    Raises if a non-existent table name cannot be found.
    """
    # get a random tmp name and check that it doesn't already exist
    sqlCtx = get_spark_session()
    success = False
    for _ in range(MAX_UPLOAD_PARQUET_TRIES):
        suffix = rand_string(length=UPLOAD_PARQUET_TMP_SUFFIX_LEN)
        rand_name = f"tmp_parquet_{suffix}"
        if not sqlCtx.catalog._jcatalog.tableExists(rand_name):
            success = True
            break

    if not success:
        raise Exception(
            f"Failed to find name after {MAX_UPLOAD_PARQUET_TRIES} tries."
        )

    # perform the write
    df.write.mode("errorifexists").format("parquet").saveAsTable(rand_name)
    parquet_url = get_table_url(rand_name)
    logger.info(f"Saved parquet to {parquet_url}")
    return Dataset(parquet_url=parquet_url)
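# Hypothetical usage sketch for upload_as_parquet, assuming a local
# SparkSession and that the helpers it depends on (get_spark_session,
# get_table_url, Dataset, etc.) are importable; the toy dataframe below is
# invented for illustration only.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
toy_df = spark.createDataFrame(
    [(1, 0.5), (2, 0.7)], schema=["sequence_number", "reward"]
)
dataset = upload_as_parquet(toy_df)
# The returned Dataset points at the saved table, e.g. for query_data above.
print(dataset.parquet_url)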