def test_raw_input(self):
    """A single dict is passed through and a bare string is keyed by itself."""
    # Each case: (input given to flatten_to_dict, key expected in result, value under that key)
    cases = (
        ({"key1": "value1"}, "key1", "value1"),
        ("value1", "value1", "value1"),
    )
    for raw, expected_key, expected_value in cases:
        flattened = utils.flatten_to_dict(raw)
        self.assertIn(expected_key, flattened)
        self.assertEqual(flattened[expected_key], expected_value)
def test_list_input(self):
    """Lists of dicts and/or strings are merged into one flat dict, later entries winning."""
    # Each case: (list given to flatten_to_dict, expected key -> value pairs in the result).
    # The order of the cases and of the keys within each expected dict matches the order in
    # which the assertions are made.
    cases = [
        # Distinct keys from several dicts are all kept.
        ([{"key1": "value1"}, {"key2": "value2"}],
         {"key1": "value1", "key2": "value2"}),
        # A repeated key keeps the value of its last occurrence.
        ([{"key1": "value1"}, {"key2": "value2"}, {"key1": "repeated"}],
         {"key1": "repeated", "key2": "value2"}),
        # An empty list yields an empty (falsy) result.
        ([], {}),
        # Dicts and bare strings can be mixed; a bare string is keyed by itself.
        ([{"key1": "value1"}, "value2"],
         {"key1": "value1", "value2": "value2"}),
        # A list made only of bare strings.
        (["value1", "value2"],
         {"value1": "value1", "value2": "value2"}),
    ]
    for raw, expected in cases:
        flattened = utils.flatten_to_dict(raw)
        if not expected:
            self.assertFalse(flattened)
        for key, value in expected.items():
            self.assertIn(key, flattened)
            self.assertEqual(flattened[key], value)
def _download_dataset(self):
    """
    Download the task outputs from the gbasf2 project dataset.

    For each task output defined via ``self.add_to_output(<name>.root)`` a directory will be created,
    into which all files named ``name_*.root`` on the grid dataset corresponding to the project name
    will be downloaded. The download is made atomic by first downloading into a temporary directory
    and only moving it to the final location once the downloaded file names match the grid listing.

    Outputs whose local directory already contains exactly the files listed on the grid are skipped;
    the remaining outputs are still processed.

    :raises RuntimeError: if no dataset for the project exists on the grid, if the download reports
        that no file was found, if the download did not produce the expected directory, or if the
        downloaded file names differ from the grid dataset listing.
    """
    if not check_dataset_exists_on_grid(self.gbasf2_project_name, dirac_user=self.dirac_user):
        raise RuntimeError(
            f"No dataset to download under project name {self.gbasf2_project_name}"
        )
    task_output_dict = flatten_to_dict(self.task.output())
    for output_file_name, output_target in task_output_dict.items():
        output_dir_path = output_target.path
        # The output name is used as a file-name pattern on the grid, so it must be a bare file name.
        assert output_file_name == os.path.basename(output_file_name)
        output_file_stem, output_file_ext = os.path.splitext(output_file_name)
        assert output_file_ext == ".root", "gbasf2 batch only supports root outputs"

        # Build the query as an argument list instead of shlex-splitting an f-string, so that
        # whitespace or quotes in the user/project name cannot shift argument boundaries.
        dataset_query_args = [
            "--user",
            str(self.dirac_user),
            f"{self.gbasf2_project_name}/{output_file_stem}_*{output_file_ext}",
        ]

        # Get the list of files on the grid via ``gb2_ds_list`` so that it can be compared with
        # the result of the download to check whether the download was complete.
        ds_list_command = ["gb2_ds_list", *dataset_query_args]
        output_dataset_grid_filepaths = run_with_gbasf2(
            ds_list_command, capture_output=True).stdout.splitlines()
        # Ignore blank lines so they do not add a spurious "" entry to the expected file names.
        output_dataset_basenames = {
            os.path.basename(grid_path.strip())
            for grid_path in output_dataset_grid_filepaths
            if grid_path.strip()
        }

        # Skip this output if it was already downloaded. ``os.listdir`` returns a list, which never
        # compares equal to a set, so convert it first. An empty grid listing never counts as
        # "already downloaded".
        if (
            output_dataset_basenames
            and os.path.isdir(output_dir_path)
            and set(os.listdir(output_dir_path)) == output_dataset_basenames
        ):
            print(f"Dataset already exists in {output_dir_path}, skipping download.")
            # ``continue`` (not ``return``) so that the remaining outputs are still downloaded.
            continue

        # To prevent the task from being accidentally marked as complete when the gbasf2 dataset
        # download failed, first download into a temporary directory next to ``output_dir_path``.
        # The download command creates a subdirectory named after the project. Only if the
        # downloaded file names equal the grid listing is that subdirectory moved to
        # ``output_dir_path``.
        # Use the absolute path: ``dirname`` of a bare relative name is "" and makedirs("") fails.
        output_dir_parent = os.path.dirname(os.path.abspath(output_dir_path))
        os.makedirs(output_dir_parent, exist_ok=True)
        with tempfile.TemporaryDirectory(dir=output_dir_parent) as tmpdir_path:
            ds_get_command = ["gb2_ds_get", "--force", *dataset_query_args]
            print("Downloading dataset with command ", " ".join(ds_get_command))
            stdout = run_with_gbasf2(ds_get_command, cwd=tmpdir_path, capture_output=True).stdout
            print(stdout)
            if "No file found" in stdout:
                raise RuntimeError(
                    f"No output data for gbasf2 project {self.gbasf2_project_name} found."
                )
            tmp_output_dir = os.path.join(tmpdir_path, self.gbasf2_project_name)
            if not os.path.isdir(tmp_output_dir):
                raise RuntimeError(
                    f"Download of gbasf2 project {self.gbasf2_project_name} did not create "
                    f"the expected directory {tmp_output_dir}"
                )
            downloaded_dataset_basenames = set(os.listdir(tmp_output_dir))
            if downloaded_dataset_basenames != output_dataset_basenames:
                raise RuntimeError(
                    f"The downloaded files in {tmp_output_dir} are not equal to the "
                    f"dataset files for the grid project {self.gbasf2_project_name}"
                )
            print(
                f"Download of {self.gbasf2_project_name} files successful.\n"
                f"Moving output files to directory: {output_dir_path}"
            )
            if os.path.exists(output_dir_path):
                shutil.rmtree(output_dir_path)
            shutil.move(src=tmp_output_dir, dst=output_dir_path)
def _get_output_target(self, key):
    """Look up the luigi target stored under ``key`` in this task's flattened outputs.

    Raises ``KeyError`` when ``key`` is not among the outputs.
    """
    return utils.flatten_to_dict(self.output())[key]