def _distributed_worker(local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args):
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger = logging.getLogger(__name__)
    try:
        dist.init_process_group(
            backend='NCCL', init_method=dist_url, world_size=world_size, rank=global_rank)
    except Exception as e:
        logger.error(f'Process group URL: {dist_url}')
        raise e

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    communication.synchronize()
    logger.info(f'Global rank {global_rank}.')
    logger.info('Synchronized GPUs.')

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
    assert communication._LOCAL_PROCESS_GROUP is None  # noqa
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            communication._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
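# For context: _distributed_worker is intended to be spawned once per local GPU. Below is a minimal
# launch sketch, assuming torch.multiprocessing and a `launch` wrapper that are not shown in this
# excerpt (names and defaults here are illustrative, not the project's actual API).
import torch.multiprocessing as mp

def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0,
           dist_url="tcp://127.0.0.1:12345", args=()):
    # Spawn one process per local GPU; mp.spawn passes the process index as the first
    # argument, which _distributed_worker interprets as local_rank.
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )
    else:
        main_func(*args)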
def setup_training_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    env = setup_common_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )

    # Write config file to experiment directory.
    config_file_in_project_folder = env.experiment_dir / "config.yaml"
    logger.info(f"Writing configuration file to: {config_file_in_project_folder}.")
    if communication.is_main_process():
        with open(config_file_in_project_folder, "w") as f:
            f.write(OmegaConf.to_yaml(env.cfg))
    communication.synchronize()

    return env
def setup_training_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    experiment_dir = base_directory / run_name

    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory.
        # This is required for the logging below.
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)
    base_cfg, models = load_models_into_environment_config(cfg_from_file)

    # Setup everything for training.
    base_cfg.training = TrainingConfig

    # Parse the proper specific config for the datasets:
    base_cfg.training.datasets = [
        load_dataset_config(dataset) for dataset in base_cfg.training.datasets
    ]
    base_cfg.validation.datasets = [
        load_dataset_config(dataset) for dataset in base_cfg.validation.datasets
    ]

    # Make configuration read only.
    # TODO(jt): Does not work when indexing config lists.
    # OmegaConf.set_readonly(cfg, True)

    forward_operator, backward_operator, engine, cfg = setup_common_environment(
        base_cfg,
        cfg_from_file,
        models,
        device,
        machine_rank,
        experiment_dir,
        run_name,
        cfg_filename,
        mixed_precision,
        debug,
    )

    # Check if the file exists in the project directory.
    config_file_in_project_folder = experiment_dir / "config.yaml"
    if config_file_in_project_folder.exists():
        if dict(OmegaConf.load(config_file_in_project_folder)) != dict(cfg):
            pass
            # raise ValueError(
            #     f"This project folder exists and has a config.yaml, "
            #     f"yet this does not match with the one the model was built with."
            # )
    else:
        if communication.get_local_rank() == 0:
            with open(config_file_in_project_folder, "w") as f:
                f.write(OmegaConf.to_yaml(cfg))
        communication.synchronize()

    environment = namedtuple(
        "environment",
        ["cfg", "experiment_dir", "forward_operator", "backward_operator", "engine"],
    )
    return environment(cfg, experiment_dir, forward_operator, backward_operator, engine)
def evaluate(
    self,
    data_loader: DataLoader,
    loss_fns: Optional[Dict[str, Callable]],
    regularizer_fns: Optional[Dict[str, Callable]] = None,
    crop: Optional[str] = None,
    is_validation_process=True,
):
    self.models_to_device()
    self.models_validation_mode()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
    # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
    volume_metrics = self.build_metrics(self.cfg.validation.metrics)

    # Filenames can be in the volume_indices attribute of the dataset.
    if hasattr(data_loader.dataset, "volume_indices"):
        all_filenames = list(data_loader.dataset.volume_indices.keys())
        num_for_this_process = len(
            list(data_loader.batch_sampler.sampler.volume_indices.keys()))
        self.logger.info(
            f"Reconstructing a total of {len(all_filenames)} volumes. "
            f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
        )
    else:
        # No volume_indices available; fall back to slice-based progress logging.
        all_filenames = None
        num_for_this_process = None
    filenames_seen = 0

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []
    visualizations = {}

    extra_visualization_keys = (
        self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else [])

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    time_start = time.time()

    for iter_idx, data in enumerate(data_loader):
        data = AddNames()(data)
        filenames = data.pop("filename")
        if len(set(filenames)) != 1:
            raise ValueError(
                f"Expected a batch during validation to only contain filenames of one case. "
                f"Got {set(filenames)}.")

        slice_nos = data.pop("slice_no")
        scaling_factors = data["scaling_factor"]

        resolution = self.compute_resolution(
            key=self.cfg.validation.crop,
            reconstruction_size=data.get("reconstruction_size", None),
        )

        # Compute output and loss.
        iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
        output = iteration_output.output_image
        loss_dict = iteration_output.data_dict
        # sensitivity_map = iteration_output.sensitivity_map

        loss_dict = detach_dict(loss_dict)
        output = output.detach()
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names(*self.complex_names()),
            scaling_factors,
            resolution=resolution,
        )

        if is_validation_process:
            target_abs = self.process_output(
                data["target"].detach().refine_names(*self.real_names()),
                scaling_factors,
                resolution=resolution,
            )
            for key in extra_visualization_keys:
                curr_data = data[key].detach()
                # Here we need to discover which keys are actually normalized or not,
                # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.

            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear. For the last case we need to check if we are at the last batch *and* at the last
            # element in the batch.
            is_last_element_of_last_batch = (
                iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"]))
            if filename != last_filename or is_last_element_of_last_batch:
                filenames_seen += 1
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # it will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([
                    _[1].rename(None) for _ in reconstruction_output[last_filename]
                ])
                if is_validation_process:
                    target = torch.stack([
                        _[1].rename(None) for _ in targets_output[last_filename]
                    ])
                    curr_metrics = {
                        metric_name: metric_fn(target, volume)
                        for metric_name, metric_fn in volume_metrics.items()
                    }
                    val_volume_metrics[last_filename] = curr_metrics

                    # Log the center slice of the volume.
                    if len(visualize_slices) < self.cfg.logging.tensorboard.num_images:
                        visualize_slices.append(volume[volume.shape[0] // 2])
                        visualize_target.append(target[target.shape[0] // 2])

                    # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                    # as we are actually interested in the output.
                    del targets_output
                    targets_output = defaultdict(list)
                    del reconstruction_output
                    reconstruction_output = defaultdict(list)

                if all_filenames:
                    log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
                else:
                    log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

                self.logger.info(
                    f"{log_prefix} {last_filename}"
                    f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
                )
                # Restart timer.
                time_start = time.time()
                last_filename = filename

            curr_slice = output_abs[idx].detach()
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            if is_validation_process:
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict.
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    communication.synchronize()
    torch.cuda.empty_cache()

    # TODO: Does not work yet with normal gather.
    all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
    if not is_validation_process:
        return loss_dict, reconstruction_output

    # TODO: Apply named tuples where applicable.
    # TODO: Several functions have multiple output values; in many cases
    # it would be more convenient to convert these to namedtuples.
    return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
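# The aggregation above assumes direct.data.sampler.DistributedSequentialSampler delivers the
# slices of each volume consecutively and in order, so stacking in arrival order recovers the
# volume. A minimal, hypothetical helper (not part of this codebase) that makes the stacking
# robust by sorting on the stored slice_no first:
def _stack_volume_sorted(slice_list):
    # slice_list is a list of (slice_no, slice_tensor) pairs as collected in
    # reconstruction_output[filename]; sort by slice index before stacking.
    slice_list = sorted(slice_list, key=lambda pair: pair[0])
    return torch.stack([curr_slice.rename(None) for _, curr_slice in slice_list])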
def evaluate(
    self,
    data_loader: DataLoader,
    loss_fns: Optional[Dict[str, Callable]],
    crop: Optional[str] = None,
    is_validation_process=True,
):
    # TODO(jt): Also log other models output (e.g. sensitivity map).
    # TODO(jt): This can be simplified as the sampler now only outputs batches belonging to the same volume.
    self.models_to_device()
    self.models_validation_mode()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
    # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
    volume_metrics = self.build_metrics(self.cfg.validation.metrics)

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    for iter_idx, data in enumerate(data_loader):
        self.log_process(iter_idx, len(data_loader))
        data = AddNames()(data)
        filenames = data.pop("filename")
        if len(set(filenames)) != 1:
            raise ValueError(
                f"Expected a batch during validation to only contain filenames of one case. "
                f"Got {set(filenames)}.")

        slice_nos = data.pop("slice_no")
        scaling_factors = data.pop("scaling_factor")

        # Check if the reconstruction size is in the data.
        if self.cfg.validation.crop == "header":
            # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1, ...), tensor(z_0, z_1, ...)] over
            # batches.
            resolution = [_.cpu().numpy().tolist() for _ in data["reconstruction_size"]]
            # The volume sampler should give validation indices belonging to the *same* volume, so it should be
            # safe taking the first element; the matrix sizes are in x,y,z (we work in z,x,y).
            resolution = [_[0] for _ in resolution][:-1]
        elif self.cfg.validation.crop == "training":
            resolution = self.cfg.training.loss.crop
        elif not self.cfg.validation.loss.crop:
            resolution = None
        else:
            raise ValueError(
                f"Cropping should be either set to `header` to get the values from the header or "
                f"`training` to take the same value as training.")

        # Compute output and loss.
        output, loss_dict = self._do_iteration(data, loss_fns)
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names(*self.complex_names).detach(),
            scaling_factors,
            resolution=resolution,
        )
        if is_validation_process:
            target_abs = self.process_output(
                data["target"].refine_names(*self.real_names).detach(),
                scaling_factors,
                resolution=resolution,
            )
        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.

            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear. For the last case we need to check if we are at the last batch *and* at the last
            # element in the batch.
            if filename != last_filename or (
                    iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])):
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # it will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([
                    _[1].rename(None) for _ in reconstruction_output[last_filename]
                ])
                self.logger.info(
                    f"Reconstructed {last_filename} (shape = {list(volume.shape)}).")

                if is_validation_process:
                    target = torch.stack([
                        _[1].rename(None) for _ in targets_output[last_filename]
                    ])
                    curr_metrics = {
                        metric_name: metric_fn(volume, target)
                        for metric_name, metric_fn in volume_metrics.items()
                    }
                    val_volume_metrics[last_filename] = curr_metrics

                    # Log the center slice of the volume.
                    if len(visualize_slices) < self.cfg.tensorboard.num_images:
                        visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                        visualize_target.append(normalize_image(target[target.shape[0] // 2]))

                    # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                    # as we are actually interested in the output.
                    del targets_output
                    targets_output = defaultdict(list)
                    del reconstruction_output
                    reconstruction_output = defaultdict(list)

                last_filename = filename

            curr_slice = output_abs[idx]
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            if is_validation_process:
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict.
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    communication.synchronize()
    torch.cuda.empty_cache()

    # TODO(jt): Does not work yet with normal gather.
    all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
    if not is_validation_process:
        return loss_dict, reconstruction_output

    # TODO(jt): Make named tuple.
    return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
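# To make the `header` branch in the cropping logic above concrete, a small worked example with
# illustrative values: after collation, a batch of two slices whose headers report a 320 x 320 x 1
# reconstruction size yields one tensor per dimension, and the branch reduces this to the in-plane
# resolution used for cropping.
import torch

reconstruction_size = [torch.tensor([320, 320]), torch.tensor([320, 320]), torch.tensor([1, 1])]
resolution = [_.cpu().numpy().tolist() for _ in reconstruction_size]  # [[320, 320], [320, 320], [1, 1]]
resolution = [_[0] for _ in resolution][:-1]  # Take the first batch element and drop z: [320, 320].
print(resolution)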
def setup_common_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    # Shut up all loggers.
    logger = logging.getLogger()

    experiment_dir = base_directory / run_name
    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory.
        # This is required for the logging below.
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)

    # Load the default configs to ensure type safety.
    cfg = OmegaConf.structured(DefaultConfig)

    models, models_config = load_models_into_environment_config(cfg_from_file)
    cfg.model = models_config.model
    del models_config["model"]
    cfg.additional_models = models_config

    # Setup everything for training.
    cfg.training = TrainingConfig
    cfg.validation = ValidationConfig
    cfg.inference = InferenceConfig

    cfg_from_file_new = cfg_from_file.copy()
    for key in cfg_from_file:
        # TODO: This does not really do a full validation.
        # BODY: This will be handled once Hydra is implemented.
        if key in ["models", "additional_models"]:  # Still handled separately.
            continue

        if key in ["training", "validation", "inference"]:
            if not cfg_from_file[key]:
                logger.info(f"key {key} missing in config.")
                continue

            if key in ["training", "validation"]:
                dataset_cfg_from_file = extract_names(cfg_from_file[key].datasets)
                for idx, (dataset_name, dataset_config) in enumerate(dataset_cfg_from_file):
                    cfg_from_file_new[key].datasets[idx] = dataset_config
                    cfg[key].datasets.append(load_dataset_config(dataset_name))
            else:
                dataset_name, dataset_config = extract_names(cfg_from_file[key].dataset)
                cfg_from_file_new[key].dataset = dataset_config
                cfg[key].dataset = load_dataset_config(dataset_name)

        cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])
    # sys.exit()
    # Make configuration read only.
    # TODO(jt): Does not work when indexing config lists.
    # OmegaConf.set_readonly(cfg, True)

    setup_logging(machine_rank, experiment_dir, run_name, cfg_filename, cfg, debug)
    forward_operator, backward_operator = build_operators(cfg.physics)

    model, additional_models = initialize_models_from_config(
        cfg, models, forward_operator, backward_operator, device)

    engine = setup_engine(
        cfg,
        device,
        model,
        additional_models,
        forward_operator=forward_operator,
        backward_operator=backward_operator,
        mixed_precision=mixed_precision,
    )

    environment = namedtuple(
        "environment",
        ["cfg", "experiment_dir", "engine"],
    )
    return environment(cfg, experiment_dir, engine)
def evaluate(self,
             data_loader: DataLoader,
             loss_fns: Dict[str, Callable],
             volume_metrics: Optional[Dict[str, Callable]] = None,
             evaluation_round=0):
    self.logger.info('Evaluating...')
    self.model.eval()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    volume_metrics = volume_metrics if volume_metrics is not None else self.build_metrics()
    storage = get_event_storage()

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    for iter_idx, data in enumerate(data_loader):
        self.log_process(iter_idx, len(data_loader))
        data = AddNames()(data)
        filenames = data.pop('filename')
        slice_nos = data.pop('slice_no')
        scaling_factors = data.pop('scaling_factor')

        # Compute output and loss.
        output, loss_dict = self._do_iteration(data, loss_fns)
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names('batch', 'complex', 'height', 'width').detach(),
            scaling_factors, 320)
        target_abs = self.process_output(
            data['target'].refine_names('batch', 'height', 'width').detach(),
            scaling_factors, 320)
        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.

            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear. For the last case we need to check if we are at the last batch *and* at the last
            # element in the batch.
            if filename != last_filename or (iter_idx + 1 == len(data_loader)
                                             and idx + 1 == len(data['target'])):
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # it will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([_[1].rename(None) for _ in reconstruction_output[last_filename]])
                target = torch.stack([_[1].rename(None) for _ in targets_output[last_filename]])
                self.logger.info(f'Reconstructed {last_filename} (shape = {list(volume.shape)}).')

                curr_metrics = {
                    metric_name: metric_fn(volume, target)
                    for metric_name, metric_fn in volume_metrics.items()}
                val_volume_metrics[last_filename] = curr_metrics

                # Log the center slice of the volume.
                if len(visualize_slices) < self.cfg.tensorboard.num_images:
                    visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                    # Target only needs to be logged once.
                    if evaluation_round == 0:
                        visualize_target.append(normalize_image(target[target.shape[0] // 2]))

                last_filename = filename

                # Delete outputs from memory, and recreate dictionary.
                del reconstruction_output
                del targets_output
                reconstruction_output = defaultdict(list)
                targets_output = defaultdict(list)

            curr_slice = output_abs[idx]
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict.
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    # Log slices.
    visualize_slices = make_grid(visualize_slices, nrow=4, scale_each=True)
    storage.add_image('validation/prediction', visualize_slices)

    if evaluation_round == 0:
        visualize_target = make_grid(visualize_target, nrow=4, scale_each=True)
        storage.add_image('validation/target', visualize_target)

    communication.synchronize()
    torch.cuda.empty_cache()

    return loss_dict
def setup(run_name, training_root, validation_root, base_directory, cfg_filename,
          device, num_workers, resume, machine_rank):
    experiment_dir = base_directory / run_name

    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory.
        # This is required for the logging below.
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)
    model_name = cfg_from_file.model_name + 'Config'
    try:
        model_cfg = str_to_class(f'direct.nn.{cfg_from_file.model_name.lower()}.config', model_name)
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(f'Model configuration does not exist for {cfg_from_file.model_name} (err = {e}).')
        sys.exit(-1)

    # Load the default configs to ensure type safety.
    base_cfg = OmegaConf.structured(DefaultConfig)
    base_cfg = OmegaConf.merge(base_cfg, {'model': model_cfg, 'training': TrainingConfig()})
    cfg = OmegaConf.merge(base_cfg, cfg_from_file)

    # Setup logging.
    log_file = experiment_dir / f'log_{machine_rank}_{communication.get_local_rank()}.txt'
    direct.utils.logging.setup(
        use_stdout=communication.get_local_rank() == 0 or cfg.debug,
        filename=log_file,
        log_level=('INFO' if not cfg.debug else 'DEBUG'),
    )
    logger.info(f'Machine rank: {machine_rank}.')
    logger.info(f'Local rank: {communication.get_local_rank()}.')
    logger.info(f'Logging: {log_file}.')
    logger.info(f'Saving to: {experiment_dir}.')
    logger.info(f'Run name: {run_name}.')
    logger.info(f'Config file: {cfg_filename}.')
    logger.info(f'Python version: {sys.version}.')
    logger.info(f'PyTorch version: {torch.__version__}.')  # noqa
    logger.info(f'CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.')
    logger.info(f'Configuration: {pformat(dict(cfg))}.')

    # Create the model.
    logger.info('Building model.')
    model = MRIReconstruction(2, **cfg.model).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f'Number of parameters: {n_params} ({n_params / 10.0**3:.2f}k).')
    logger.debug(model)

    # Create training and validation data.
    train_mask_func, val_mask_func = build_masking_functions(**cfg.masking)
    train_transforms, val_transforms = build_mri_transforms(
        train_mask_func, val_mask_func=val_mask_func, crop=cfg.dataset.transforms.crop)
    training_data, validation_data = build_datasets(
        cfg.dataset.name,
        training_root,
        train_sensitivity_maps=None,
        train_transforms=train_transforms,
        validation_root=validation_root,
        val_sensitivity_maps=None,
        val_transforms=val_transforms)

    # Create the optimizers.
    logger.info('Building optimizers.')
    optimizer: torch.optim.Optimizer = str_to_class('torch.optim', cfg.training.optimizer)(  # noqa
        model.parameters(),
        lr=cfg.training.lr,
        weight_decay=cfg.training.weight_decay,
    )  # noqa

    # Build the LR scheduler; we use a fixed LR schedule step size, no adaptive training schedule.
    solver_steps = list(range(cfg.training.lr_step_size, cfg.training.num_iterations, cfg.training.lr_step_size))
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        solver_steps,
        cfg.training.lr_gamma,
        warmup_factor=1 / 3.,
        warmup_iters=cfg.training.lr_warmup_iter,
        warmup_method='linear')

    # Just to make sure.
    torch.cuda.empty_cache()

    # Setup training engine.
    engine = RIMEngine(cfg, model, device=device)

    engine.train(
        optimizer,
        lr_scheduler,
        training_data,
        experiment_dir,
        validation_data=validation_data,
        resume=resume,
        num_workers=num_workers)
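# A hypothetical command-line entry point tying the pieces together: parse a few paths and hand
# `setup` to the per-GPU launcher sketched earlier. Argument names and defaults are assumptions
# for illustration, not the project's actual CLI.
if __name__ == "__main__":
    import argparse
    import pathlib

    parser = argparse.ArgumentParser(description="Train an MRI reconstruction model.")
    parser.add_argument("run_name", type=str)
    parser.add_argument("training_root", type=pathlib.Path)
    parser.add_argument("validation_root", type=pathlib.Path)
    parser.add_argument("base_directory", type=pathlib.Path)
    parser.add_argument("cfg_filename", type=pathlib.Path)
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    launch(
        setup,
        num_gpus_per_machine=args.num_gpus,
        args=(args.run_name, args.training_root, args.validation_root, args.base_directory,
              args.cfg_filename, "cuda", args.num_workers, args.resume, 0),
    )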