def _process_dataset_param(self) -> None:
    """Fully execute the datasets held in ``self._param_space``.

    Datasets must be fully executed before being sent over to trainables.
    A valid dataset configuration in the param space looks like::

        "datasets": {
            "train_dataset": tune.grid_search([ds1, ds2]),
        },

    The actual work is delegated to ``execute_dataset``, which updates the
    given param space in place rather than returning a new one.
    """
    param_space = self._param_space
    execute_dataset(param_space)
def test_choice():
    """Lazy datasets wrapped in ``tune.choice`` are fully executed in place.

    This is the ``tune.choice`` counterpart of ``test_grid_search``. A choice
    domain exposes its candidate values through ``.categories``, so that is
    where the executed datasets are read back from.
    """
    # Build the lazy datasets through the same entry point as
    # ``test_grid_search`` (``_experimental_lazy``). The two tests previously
    # disagreed on the method name, so one of them could not run on a given
    # Dataset version.
    ds1 = gen_dataset_func()._experimental_lazy().map(lambda x: x)
    ds2 = gen_dataset_func()._experimental_lazy().map(lambda x: x)

    # Precondition: nothing has been executed yet.
    assert not ds1._plan._has_final_stage_snapshot()
    assert not ds2._plan._has_final_stage_snapshot()

    param_space = {"train_dataset": tune.choice([ds1, ds2])}
    execute_dataset(param_space)

    # execute_dataset mutates param_space; read the results back from it.
    executed_ds = param_space["train_dataset"].categories
    assert len(executed_ds) == 2
    assert executed_ds[0]._plan._has_final_stage_snapshot()
    assert executed_ds[1]._plan._has_final_stage_snapshot()
def test_grid_search():
    """Lazy datasets wrapped in ``tune.grid_search`` are fully executed in place."""
    # Two independent lazy pipelines, built in the same order as before.
    lazy_datasets = [
        gen_dataset_func()._experimental_lazy().map(lambda x: x) for _ in range(2)
    ]

    # Precondition: none of them has been executed yet.
    for lazy_ds in lazy_datasets:
        assert not lazy_ds._plan._has_final_stage_snapshot()

    param_space = {"train_dataset": tune.grid_search(list(lazy_datasets))}
    execute_dataset(param_space)

    # execute_dataset mutates param_space; the executed datasets live under
    # the "grid_search" key of the grid-search spec.
    executed = param_space["train_dataset"]["grid_search"]
    assert len(executed) == 2
    for executed_ds in executed:
        assert executed_ds._plan._has_final_stage_snapshot()