import functools
import logging
import sys

import h5py
import numpy as np
import torch

# NOTE: The module paths for these project-internal helpers are assumptions
# based on the DIRECT project layout; adjust them to match your tree.
from direct.common.subsample import build_masking_function
from direct.data.datasets import build_dataset
from direct.data.mri_transforms import build_mri_transforms
from direct.environment import setup_environment
from direct.inference import build_inference_transforms
from direct.utils import remove_keys
from direct.utils.dataset import get_filenames_for_datasets

logger = logging.getLogger(__name__)


def build_transforms_from_environment(env, dataset_config):
    # `masking` is consumed by build_masking_function; the remaining transform
    # settings are forwarded verbatim to build_mri_transforms.
    mri_transforms_func = functools.partial(
        build_mri_transforms,
        forward_operator=env.engine.forward_operator,
        backward_operator=env.engine.backward_operator,
        mask_func=build_masking_function(**dataset_config.transforms.masking),
    )
    transforms = mri_transforms_func(**remove_keys(dataset_config.transforms, "masking"))
    return transforms
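
# Illustrative sketch only: the shape of dataset config this helper expects,
# with `masking` nested inside `transforms` so it can be split off. The use of
# OmegaConf and the concrete masking values are assumptions, not a canonical
# configuration.
def _example_build_transforms(env):
    from omegaconf import OmegaConf

    dataset_config = OmegaConf.create(
        {
            "transforms": {
                # Consumed by build_masking_function (hypothetical values).
                "masking": {
                    "name": "FastMRIRandom",
                    "accelerations": [4],
                    "center_fractions": [0.08],
                },
                # Forwarded verbatim to build_mri_transforms.
                "crop": None,
                "estimate_sensitivity_maps": True,
            }
        }
    )
    return build_transforms_from_environment(env, dataset_config)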

def build_dataset_from_environment(env, datasets_config, lists_root, data_root, type_data, **kwargs):
    datasets = []
    for idx, dataset_config in enumerate(datasets_config):
        transforms = build_mri_transforms(
            forward_operator=env.forward_operator,
            backward_operator=env.backward_operator,
            mask_func=build_masking_function(**dataset_config.transforms.masking),
            crop=dataset_config.transforms.crop,
            crop_type=dataset_config.transforms.crop_type,
            image_center_crop=dataset_config.transforms.image_center_crop,
            estimate_sensitivity_maps=dataset_config.transforms.estimate_sensitivity_maps,
            pad_coils=dataset_config.transforms.pad_coils,
        )
        logger.debug(f"Transforms for {type_data}: {idx}:\n{transforms}")

        # Only give fancy names when validating.
        # TODO(jt): Perhaps this can be split up into just a description parameter,
        # and the config parsed in the main function.
        if type_data == "validation":
            if dataset_config.text_description:
                text_description = dataset_config.text_description
            else:
                text_description = f"ds{idx}" if len(datasets_config) > 1 else None
        elif type_data == "training":
            text_description = None
        else:
            raise ValueError(
                f"Type of data needs to be either `validation` or `training`, got {type_data}."
            )

        dataset = build_dataset(
            dataset_config.name,
            data_root,
            filenames_filter=get_filenames_for_datasets(dataset_config, lists_root, data_root),
            sensitivity_maps=None,
            transforms=transforms,
            text_description=text_description,
            kspace_context=dataset_config.kspace_context,
            **kwargs,
        )
        datasets.append(dataset)
        logger.info(
            f"Data size for {type_data} dataset"
            f" {dataset_config.name} ({idx + 1}/{len(datasets_config)}): {len(dataset)}."
        )

    return datasets
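
# Illustrative sketch only: typical call-site wiring of the builder above for
# the training and validation splits. The `env.cfg.*.datasets` access mirrors
# setup_inference below; the pathlib roots are hypothetical call-site details.
def _example_build_datasets(env, lists_root, data_root):
    training_datasets = build_dataset_from_environment(
        env,
        datasets_config=env.cfg.training.datasets,
        lists_root=lists_root,
        data_root=data_root,
        type_data="training",
    )
    validation_datasets = build_dataset_from_environment(
        env,
        datasets_config=env.cfg.validation.datasets,
        lists_root=lists_root,
        data_root=data_root,
        type_data="validation",
    )
    return training_datasets, validation_datasets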

def _get_transforms(validation_index, env):
    dataset_cfg = env.cfg.validation.datasets[validation_index]
    mask_func = build_masking_function(**dataset_cfg.transforms.masking)
    transforms = build_inference_transforms(env, mask_func, dataset_cfg)
    return dataset_cfg, transforms

def setup_inference(
    run_name,
    data_root,
    base_directory,
    output_directory,
    cfg_filename,
    checkpoint,
    validation_set_index,
    accelerations,
    center_fractions,
    device,
    num_workers,
    machine_rank,
    mixed_precision,
):
    # TODO(jt): This is a duplicate line, check how this can be merged with train_rim.py.
    # TODO(jt): Log elsewhere than for training.
    # TODO(jt): Logging is different when having multiple processes.
    env = setup_environment(
        run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision
    )

    # Select the validation dataset and its masking configuration.
    if len(env.cfg.validation.datasets) > 1 and not validation_set_index:
        logger.warning(
            "Multiple validation datasets given in config, yet no index is given. Will select the first."
        )
    validation_set_index = validation_set_index if validation_set_index else 0

    if accelerations or center_fractions:
        sys.exit("Overwriting of accelerations or ACS is not yet supported.")

    mask_func = build_masking_function(
        **env.cfg.validation.datasets[validation_set_index].transforms.masking
    )
    mri_transforms = build_mri_transforms(
        forward_operator=env.forward_operator,
        backward_operator=env.backward_operator,
        mask_func=mask_func,
        crop=None,  # No cropping needed for testing.
        image_center_crop=True,
        estimate_sensitivity_maps=env.cfg.training.datasets[0].transforms.estimate_sensitivity_maps,
    )

    # Trigger the cudnn benchmark when the number of different input shapes is small.
    torch.backends.cudnn.benchmark = True
    # TODO(jt): Batches should have constant shapes! This works for Calgary-Campinas
    # because all volumes have 256 slices.
    data = build_dataset(
        env.cfg.validation.datasets[validation_set_index].name,
        data_root,
        sensitivity_maps=None,
        transforms=mri_transforms,
    )
    logger.info(f"Inference data size: {len(data)}.")

    # Just to make sure.
    torch.cuda.empty_cache()

    # Run the prediction.
    output = env.engine.predict(
        data,
        env.experiment_dir,
        checkpoint_number=checkpoint,
        num_workers=num_workers,
    )

    # Create the output directory.
    output_directory.mkdir(exist_ok=True, parents=True)

    # Only relevant for the Calgary-Campinas challenge.
    # TODO(jt): This can be inferred from the configuration.
    # TODO(jt): Refactor this for v0.2.
    crop = (
        (50, -50)
        if env.cfg.validation.datasets[validation_set_index].name == "CalgaryCampinas"
        else None
    )

    # TODO(jt): Perhaps aggregation to the main process would be most optimal here before writing.
    for idx, filename in enumerate(output):
        # The output has shape (depth, 1, height, width).
        logger.info(f"({idx + 1}/{len(output)}): Writing {output_directory / filename}...")
        reconstruction = (
            torch.stack([_[1].rename(None) for _ in output[filename]])
            .numpy()[:, 0, ...]
            .astype(np.float64)  # np.float is removed in modern NumPy; it aliased the builtin float (float64).
        )
        if crop:
            reconstruction = reconstruction[slice(*crop)]

        # Only needed to fix a bug in the Calgary-Campinas training.
        if env.cfg.validation.datasets[validation_set_index].name == "CalgaryCampinas":
            reconstruction = reconstruction / np.sqrt(np.prod(reconstruction.shape[1:]))

        with h5py.File(output_directory / filename, "w") as f:
            f.create_dataset("reconstruction", data=reconstruction)
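
# Illustrative sketch only: one way to expose setup_inference as a script. The
# argument names follow the function signature above; the argparse wiring and
# defaults are assumptions, not the project's actual entry point.
if __name__ == "__main__":
    import argparse
    import pathlib

    parser = argparse.ArgumentParser(description="Run inference with a trained model.")
    parser.add_argument("run_name", type=str)
    parser.add_argument("data_root", type=pathlib.Path)
    parser.add_argument("base_directory", type=pathlib.Path)
    parser.add_argument("output_directory", type=pathlib.Path)
    parser.add_argument("--cfg-filename", type=pathlib.Path, required=True)
    parser.add_argument("--checkpoint", type=int, required=True)
    parser.add_argument("--validation-set-index", type=int, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--machine-rank", type=int, default=0)
    parser.add_argument("--mixed-precision", action="store_true")
    args = parser.parse_args()

    setup_inference(
        run_name=args.run_name,
        data_root=args.data_root,
        base_directory=args.base_directory,
        output_directory=args.output_directory,
        cfg_filename=args.cfg_filename,
        checkpoint=args.checkpoint,
        validation_set_index=args.validation_set_index,
        accelerations=None,  # Overriding is not supported yet; see setup_inference.
        center_fractions=None,
        device=args.device,
        num_workers=args.num_workers,
        machine_rank=args.machine_rank,
        mixed_precision=args.mixed_precision,
    )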