def _do_iteration(self, data: Dict[str, torch.Tensor],
                  loss_fns: Dict[str, Callable]) -> Tuple[torch.Tensor, Dict]:
    # Target is not needed in the model input
    target = data['target'].align_to('batch', 'complex', 'height', 'width').to(self.device)  # type: ignore

    # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
    input_image = data.pop('masked_image').to(self.device)  # type: ignore
    hidden_state = None
    output_image = None
    loss_dicts = []

    for rim_step in range(self.cfg.model.steps):
        reconstruction_iter, hidden_state = self.model(
            **dict_to_device(data, self.device),
            input_image=input_image,
            hidden_state=hidden_state,
        )
        # TODO: Unclear why this refining is needed.
        output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width')

        loss_dict = {k: torch.tensor([0.], dtype=target.dtype).to(self.device) for k in loss_fns.keys()}
        loss = torch.tensor([0.], device=output_image.device)

        for output_image_iter in reconstruction_iter:
            for k, v in loss_dict.items():
                loss_dict[k] = v + loss_fns[k](
                    output_image_iter.rename(None), target.rename(None), reduction='mean'
                )

        # for output_image_iter in reconstruction_iter:
        #     loss_dict = {
        #         k: v + loss_fns[k](output_image_iter.rename(None), target.rename(None), reduction='mean')
        #         for k, v in loss_dict.items()}

        loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
        loss = sum(loss_dict.values())

        if self.model.training:
            if self.mixed_precision:
                with amp.scale_loss(loss, self.__optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # type: ignore

        # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
        hidden_state = hidden_state.detach()
        input_image = output_image.detach()

        loss_dicts.append(detach_dict(loss_dict))  # Need to detach dict as this is only used for logging.

    # Add the loss dicts together over RIM steps, divide by the number of steps.
    loss_dict = reduce_list_of_dicts(loss_dicts, mode='sum', divisor=self.cfg.model.steps)
    return output_image, loss_dict
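# A minimal sketch of the two helpers assumed above, `detach_dict` and `reduce_list_of_dicts`.
# Their behaviour is inferred from how they are used here; the real implementations in
# direct.utils may differ in detail, so the names carry a `_sketch` suffix.
from typing import Dict, List, Optional

import torch


def detach_dict_sketch(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Detach every tensor so the dict can be kept around for logging without holding the graph.
    return {k: v.detach() for k, v in tensor_dict.items()}


def reduce_list_of_dicts_sketch(list_of_dicts: List[Dict[str, torch.Tensor]],
                                mode: str = "average",
                                divisor: Optional[float] = None) -> Dict[str, torch.Tensor]:
    # Key-wise sum over the list; "average" divides by the list length,
    # "sum" divides by the explicit divisor when one is given.
    if not list_of_dicts:
        return {}
    summed = {k: sum(d[k] for d in list_of_dicts) for k in list_of_dicts[0]}
    if mode == "average":
        divisor = len(list_of_dicts)
    if divisor:
        summed = {k: v / divisor for k, v in summed.items()}
    return summed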
def training_loop(
    self,
    data_loader: DataLoader,
    start_iter: int,
    validation_data_loaders: Optional[List[DataLoader]] = None,
    experiment_directory: Optional[pathlib.Path] = None,
):
    self.logger.info(f"Local rank: {communication.get_local_rank()}.")
    self.models_training_mode()

    loss_fns = self.build_loss()
    metric_fns = self.build_metrics(self.cfg.training.metrics)
    storage = get_event_storage()

    total_iter = self.cfg.training.num_iterations  # noqa
    for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
        data = AddNames()(data)
        if iter_idx == start_iter:
            self.ndim = self.compute_dimensionality_from_sample(data)
            self.logger.info(f"Data dimensionality: {self.ndim}.")

        if iter_idx == 0:
            self.log_first_training_example(data)

        try:
            output, loss_dict = self._do_iteration(data, loss_fns)
        except ProcessKilledException as e:
            # If the process is killed, the output is saved at state iter_idx, which is the current state,
            # so the computation can restart from the last iteration.
            self.logger.exception(f"Exiting with exception: {e}.")
            self.checkpointer.save(iter_idx)  # Save checkpoint at kill.  # noqa
            self.write_to_logs()  # TODO: This causes the issue that current metrics are not written,
            # and you end up with an empty line.
            sys.exit(-1)

        # Gradient accumulation
        if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
            if self.cfg.training.gradient_steps > 1:  # type: ignore
                for parameter in self.model.parameters():
                    if parameter.grad is not None:
                        # In-place division
                        parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore

            if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                self._scaler.unscale_(self.__optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.training.gradient_clipping)

            # Gradient norm
            if self.cfg.training.gradient_debug:  # type: ignore
                warnings.warn(
                    "Gradient debug set. This will affect training performance. Only use for debugging. "
                    "This message will only be displayed once.")
                parameters = list(filter(lambda p: p.grad is not None, self.model.parameters()))
                # Global L2 norm over all parameter gradients.
                gradient_norm = torch.sqrt(
                    sum(parameter.grad.data.norm(2) ** 2 for parameter in parameters))  # typing: ignore
                storage.add_scalar("train/gradient_norm", gradient_norm)

            # Same as self.__optimizer.step() for mixed precision.
            self._scaler.step(self.__optimizer)

            # Updates the scale for next iteration.
            self._scaler.update()

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar("lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False)
            self.__optimizer.zero_grad()  # type: ignore

        # Reduce the loss over all devices
        loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
        loss_reduced = sum(loss_dict_reduced.values())

        metrics_dict = evaluate_dict(
            metric_fns,
            transforms.modulus_if_complex(output.detach()).rename(None),
            data["target"].rename(None).detach().to(self.device),
            reduction="mean",
        )
        metrics_dict_reduced = communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
        storage.add_scalars(loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced)

        if iter_idx > 5 and (iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0
                             or (iter_idx + 1) == total_iter):
            self.logger.info(f"Checkpointing at iteration {iter_idx}.")
            self.checkpointer.save(iter_idx)

        if (validation_data_loaders is not None and iter_idx > 5
                and (iter_idx % self.cfg.training.validation_steps == 0 or (iter_idx + 1) == total_iter)):
            for (
                curr_dataset_name,
                curr_validation_data_loader,
            ) in validation_data_loaders:
                self.logger.info(f"Evaluating: {curr_dataset_name}...")  # TODO(jt): Fix with better names and stuff.
                (
                    curr_val_loss_dict,
                    curr_val_metric_dict_per_case,
                    visualize_slices,
                    visualize_target,
                ) = self.evaluate(
                    curr_validation_data_loader,
                    loss_fns,
                    is_validation_process=True,
                )

                if experiment_directory:
                    # Make dictionary serializable for logging
                    serializable_val_metric_dict = {
                        k0: {k1: float(v1) for k1, v1 in v0.items()}
                        for k0, v0 in curr_val_metric_dict_per_case.items()
                    }
                    write_json(
                        experiment_directory / f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                        serializable_val_metric_dict,
                    )

                # Metric dict still needs to be reduced as it gives values *per* case
                curr_val_metric_dict = reduce_list_of_dicts(
                    list(curr_val_metric_dict_per_case.values()), mode="average")

                key_prefix = "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
                val_loss_reduced = sum(curr_val_loss_dict.values())
                storage.add_scalars(
                    **{key_prefix + "loss": val_loss_reduced},
                    **{
                        **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                        **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                    },
                    smoothing_hint=False,
                )

                visualize_slices = self.process_slices_for_visualization(visualize_slices, visualize_target)
                storage.add_image(f"{key_prefix}prediction", visualize_slices)

                if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                    visualize_target = make_grid(
                        crop_to_largest(visualize_target, pad_value=0),
                        nrow=self.cfg.tensorboard.num_images,
                        scale_each=True,
                    )
                    storage.add_image(f"{key_prefix}target", visualize_target)

                self.logger.info(f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}.")
            self.model.train()

        # Log every 20 iterations, or at a validation step or at the end of training.
        if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx % self.cfg.training.validation_steps == 0
                             or (iter_idx + 1) == total_iter):
            self.write_to_logs()

        storage.step()
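# A self-contained sketch of the gradient-accumulation pattern used in the training loop above:
# backward() runs every iteration, but the optimizer only steps every `gradient_steps` iterations,
# after rescaling the accumulated gradients in place. The names `model`, `optimizer`, `data_loader`
# and `compute_loss` are placeholders for illustration, not the engine's attributes.
import torch


def accumulate_gradients_sketch(model, optimizer, data_loader, compute_loss,
                                gradient_steps: int = 4, gradient_clipping: float = 1.0):
    for iter_idx, batch in enumerate(data_loader):
        loss = compute_loss(model, batch)
        loss.backward()  # gradients accumulate in parameter.grad across iterations

        if (iter_idx + 1) % gradient_steps == 0:
            if gradient_steps > 1:
                for parameter in model.parameters():
                    if parameter.grad is not None:
                        parameter.grad.div_(gradient_steps)  # in-place rescale, as above
            if gradient_clipping > 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            optimizer.step()
            optimizer.zero_grad()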
def _do_iteration(
    self,
    data: Dict[str, torch.Tensor],
    loss_fns: Optional[Dict[str, Callable]] = None,
    regularizer_fns: Optional[Dict[str, Callable]] = None,
) -> namedtuple:
    # loss_fns can be None, e.g. during validation
    if loss_fns is None:
        loss_fns = {}

    if regularizer_fns is None:
        regularizer_fns = {}

    # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
    input_image = None
    hidden_state = None
    output_image = None

    loss_dicts = []
    regularizer_dicts = []

    data = dict_to_device(data, self.device)
    # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']
    sensitivity_map = data["sensitivity_map"]

    if "noise_model" in self.models:
        raise NotImplementedError()

    # Some things can be done with the sensitivity map here, e.g. apply a u-net
    if "sensitivity_model" in self.models:
        # Move channels to first axis
        sensitivity_map = sensitivity_map.align_to(*self.complex_names(add_coil=True))
        sensitivity_map = (self.compute_model_per_coil("sensitivity_model", sensitivity_map)
                           .refine_names(*sensitivity_map.names)
                           .align_to(*self.complex_names_complex_last(add_coil=True)))
        # Output has channel first, it is ("batch", "coil", "complex", ...)

    # The sensitivity map needs to be normalized such that
    # \sum_{i \in \text{coils}} S_i S_i^* = 1
    sensitivity_map_norm = torch.sqrt(((sensitivity_map ** 2).sum("complex")).sum("coil"))
    data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)

    if self.cfg.model.scale_loglikelihood:
        scaling_factor = 1.0 * self.cfg.model.scale_loglikelihood / (data["scaling_factor"] ** 2)
        scaling_factor = scaling_factor.reshape(-1, 1).refine_names("batch", "complex")
        self.logger.debug(f"Scaling factor is: {scaling_factor}")
    else:
        # Needs fixing.
        scaling_factor = torch.tensor([1.0]).to(sensitivity_map.device).refine_names("complex")

    for _ in range(self.cfg.model.steps):
        with autocast(enabled=self.mixed_precision):
            reconstruction_iter, hidden_state = self.model(
                **data,
                input_image=input_image,
                hidden_state=hidden_state,
                loglikelihood_scaling=scaling_factor,
            )
            # TODO: Unclear why this refining is needed.
            output_image = reconstruction_iter[-1].refine_names(*self.complex_names())

            loss_dict = {
                k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device)
                for k in loss_fns.keys()
            }
            regularizer_dict = {
                k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device)
                for k in regularizer_fns.keys()
            }

            # TODO: These loops are very similar, perhaps a partial can help here.
            for output_image_iter in reconstruction_iter:
                for k, v in loss_dict.items():
                    loss_dict[k] = v + loss_fns[k](
                        output_image_iter,
                        **data,
                        reduction="mean",
                    )
                for k, v in regularizer_dict.items():
                    regularizer_dict[k] = v + regularizer_fns[k](
                        output_image_iter,
                        **data,
                    ).rename(None)

            loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
            regularizer_dict = {k: v / len(reconstruction_iter) for k, v in regularizer_dict.items()}

            loss = sum(loss_dict.values()) + sum(regularizer_dict.values())

        if self.model.training:
            self._scaler.scale(loss).backward()

        # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
        hidden_state = hidden_state.detach()
        input_image = output_image.detach()

        loss_dicts.append(detach_dict(loss_dict))
        regularizer_dicts.append(detach_dict(regularizer_dict))  # Need to detach dict as this is only used for logging.

    # Add the loss dicts together over RIM steps, divide by the number of steps.
    loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum", divisor=self.cfg.model.steps)
    regularizer_dict = reduce_list_of_dicts(regularizer_dicts, mode="sum", divisor=self.cfg.model.steps)

    output = namedtuple(
        "do_iteration",
        ["output_image", "sensitivity_map", "data_dict"],
    )

    return output(
        output_image=output_image,
        sensitivity_map=data["sensitivity_map"],
        data_dict={**loss_dict, **regularizer_dict},
    )
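# A plain-tensor sketch (no named dimensions) of the sensitivity-map normalisation performed
# above, which enforces \sum_{i \in coils} S_i S_i^* = 1 per pixel. The
# (batch, coil, height, width, complex) layout is an assumption for illustration, and the
# epsilon stands in for whatever zero-guard T.safe_divide applies.
import torch


def normalize_sensitivity_map_sketch(sensitivity_map: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # sensitivity_map: (batch, coil, height, width, complex=2), real view of complex data.
    norm = torch.sqrt((sensitivity_map ** 2).sum(dim=-1).sum(dim=1))  # (batch, height, width)
    norm = norm.unsqueeze(1).unsqueeze(-1)  # broadcast back over the coil and complex axes
    return sensitivity_map / norm.clamp(min=eps)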
def evaluate(
    self,
    data_loader: DataLoader,
    loss_fns: Optional[Dict[str, Callable]],
    regularizer_fns: Optional[Dict[str, Callable]] = None,
    crop: Optional[str] = None,
    is_validation_process=True,
):
    self.models_to_device()
    self.models_validation_mode()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
    # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
    volume_metrics = self.build_metrics(self.cfg.validation.metrics)

    # filenames can be in the volume_indices attribute of the dataset
    if hasattr(data_loader.dataset, "volume_indices"):
        all_filenames = list(data_loader.dataset.volume_indices.keys())
        num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
        self.logger.info(
            f"Reconstructing a total of {len(all_filenames)} volumes. "
            f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()}).")
    else:
        all_filenames = None  # Checked below to pick the logging style.
        num_for_this_process = None
    filenames_seen = 0

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []
    visualizations = {}

    extra_visualization_keys = self.cfg.logging.log_as_image if self.cfg.logging.log_as_image else []

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    time_start = time.time()

    for iter_idx, data in enumerate(data_loader):
        data = AddNames()(data)
        filenames = data.pop("filename")
        if len(set(filenames)) != 1:
            raise ValueError(
                f"Expected a batch during validation to only contain filenames of one case. "
                f"Got {set(filenames)}.")

        slice_nos = data.pop("slice_no")
        scaling_factors = data["scaling_factor"]

        resolution = self.compute_resolution(
            key=self.cfg.validation.crop,
            reconstruction_size=data.get("reconstruction_size", None),
        )

        # Compute output and loss.
        iteration_output = self._do_iteration(data, loss_fns, regularizer_fns=regularizer_fns)
        output = iteration_output.output_image
        loss_dict = iteration_output.data_dict
        # sensitivity_map = iteration_output.sensitivity_map

        loss_dict = detach_dict(loss_dict)
        output = output.detach()
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names(*self.complex_names()),
            scaling_factors,
            resolution=resolution,
        )

        if is_validation_process:
            target_abs = self.process_output(
                data["target"].detach().refine_names(*self.real_names()),
                scaling_factors,
                resolution=resolution,
            )
            for key in extra_visualization_keys:
                curr_data = data[key].detach()
                # Here we need to discover which keys are actually normalized or not
                # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.
            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear.
            # For the last case we need to check if we are at the last batch *and* at the last element in the batch.
            is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
            if filename != last_filename or is_last_element_of_last_batch:
                filenames_seen += 1
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([_[1].rename(None) for _ in reconstruction_output[last_filename]])
                if is_validation_process:
                    target = torch.stack([_[1].rename(None) for _ in targets_output[last_filename]])
                    curr_metrics = {
                        metric_name: metric_fn(target, volume)
                        for metric_name, metric_fn in volume_metrics.items()
                    }
                    val_volume_metrics[last_filename] = curr_metrics

                    # Log the center slice of the volume
                    if len(visualize_slices) < self.cfg.logging.tensorboard.num_images:
                        visualize_slices.append(volume[volume.shape[0] // 2])
                        visualize_target.append(target[target.shape[0] // 2])

                    # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                    # as we are actually interested in the output
                    del targets_output
                    targets_output = defaultdict(list)
                    del reconstruction_output
                    reconstruction_output = defaultdict(list)

                if all_filenames:
                    log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
                else:
                    log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

                self.logger.info(
                    f"{log_prefix} {last_filename}"
                    f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s.")
                # restart timer
                time_start = time.time()
                last_filename = filename

            curr_slice = output_abs[idx].detach()
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            if is_validation_process:
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    communication.synchronize()
    torch.cuda.empty_cache()

    # TODO: Does not work yet with normal gather.
    all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
    if not is_validation_process:
        return loss_dict, reconstruction_output

    # TODO: Apply named tuples where applicable
    # TODO: Several functions have multiple output values, in many cases
    # TODO: it would be more convenient to convert this to namedtuples.
    return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
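# A minimal sketch of the slice-to-volume aggregation pattern used in evaluate(): slices are
# collected per filename as (slice_no, tensor) pairs and stacked once a volume is complete.
# The engine relies on the sampler emitting the slices of one volume sequentially; the
# defensive sort by slice number here is an addition for clarity, not part of the engine.
from collections import defaultdict
from typing import Dict, List, Tuple

import torch


def add_slice_sketch(outputs: Dict[str, List[Tuple[int, torch.Tensor]]],
                     filename: str, slice_no: int, curr_slice: torch.Tensor) -> None:
    outputs[filename].append((slice_no, curr_slice.cpu()))


def assemble_volume_sketch(outputs: Dict[str, List[Tuple[int, torch.Tensor]]],
                           filename: str) -> torch.Tensor:
    # Sort by slice number, then stack along a new depth axis.
    pairs = sorted(outputs.pop(filename), key=lambda pair: pair[0])
    return torch.stack([tensor for _, tensor in pairs])


# Usage: outputs = defaultdict(list); add_slice_sketch(outputs, "case_0", 3, some_slice);
# volume = assemble_volume_sketch(outputs, "case_0")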
def _do_iteration(
        self, data: Dict[str, torch.Tensor],
        loss_fns: Optional[Dict[str, Callable]]) -> Tuple[torch.Tensor, Dict]:
    # loss_fns can be None, e.g. during validation
    if loss_fns is None:
        loss_fns = {}

    # TODO(jt): Target is not needed in the model input, but in the loss computation. Keep it here for now.
    target = data["target"].align_to(*self.complex_names).to(self.device)  # type: ignore

    # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
    input_image = data.pop("masked_image").to(self.device)  # type: ignore
    hidden_state = None
    output_image = None
    loss_dicts = []

    # TODO: Target might not need to be copied.
    data = dict_to_device(data, self.device)
    # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']
    sensitivity_map = data["sensitivity_map"]

    # Some things can be done with the sensitivity map here, e.g. apply a u-net
    if "sensitivity_model" in self.models:
        sensitivity_map = self.compute_model_per_coil(self.models["sensitivity_model"], sensitivity_map)

    # The sensitivity map needs to be normalized such that
    # \sum_{i \in \text{coils}} S_i S_i^* = 1
    sensitivity_map_norm = modulus(sensitivity_map).sum("coil")
    data["sensitivity_map"] = safe_divide(sensitivity_map, sensitivity_map_norm)

    for rim_step in range(self.cfg.model.steps):
        with autocast(enabled=self.mixed_precision):
            reconstruction_iter, hidden_state = self.model(
                **data,
                input_image=input_image,
                hidden_state=hidden_state,
            )
            # TODO: Unclear why this refining is needed.
            output_image = reconstruction_iter[-1].refine_names(*self.complex_names)

            loss_dict = {
                k: torch.tensor([0.0], dtype=target.dtype).to(self.device)
                for k in loss_fns.keys()
            }
            for output_image_iter in reconstruction_iter:
                for k, v in loss_dict.items():
                    loss_dict[k] = v + loss_fns[k](
                        output_image_iter,
                        target,
                        reduction="mean",
                    )

            loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
            loss = sum(loss_dict.values())

        if self.model.training:
            self._scaler.scale(loss).backward()

        # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
        hidden_state = hidden_state.detach()
        input_image = output_image.detach()

        loss_dicts.append(detach_dict(loss_dict))  # Need to detach dict as this is only used for logging.

    # Add the loss dicts together over RIM steps, divide by the number of steps.
    loss_dict = reduce_list_of_dicts(loss_dicts, mode="sum", divisor=self.cfg.model.steps)
    return output_image, loss_dict
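# A self-contained sketch of the native PyTorch AMP pattern this version of _do_iteration relies
# on: the forward pass and loss under autocast, backward through a scaled loss, and a
# scaler-driven optimizer step. `model`, `optimizer`, `loss_fn` and `batch` are placeholders for
# illustration, not the engine's attributes.
import torch
from torch.cuda.amp import GradScaler, autocast


def amp_step_sketch(model, optimizer, loss_fn, batch, scaler: GradScaler):
    optimizer.zero_grad()
    with autocast(enabled=True):
        output = model(batch["input"])
        loss = loss_fn(output, batch["target"])
    scaler.scale(loss).backward()  # backward on the scaled loss, outside autocast
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjusts the loss scale for the next iteration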
def evaluate(
    self,
    data_loader: DataLoader,
    loss_fns: Optional[Dict[str, Callable]],
    crop: Optional[str] = None,
    is_validation_process=True,
):
    # TODO(jt): Also log other models output (e.g. sensitivity map).
    # TODO(jt): This can be simplified as the sampler now only outputs batches belonging to the same volume.
    self.models_to_device()
    self.models_validation_mode()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
    # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
    volume_metrics = self.build_metrics(self.cfg.validation.metrics)

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    for iter_idx, data in enumerate(data_loader):
        self.log_process(iter_idx, len(data_loader))
        data = AddNames()(data)
        filenames = data.pop("filename")
        if len(set(filenames)) != 1:
            raise ValueError(
                f"Expected a batch during validation to only contain filenames of one case. "
                f"Got {set(filenames)}.")

        slice_nos = data.pop("slice_no")
        scaling_factors = data.pop("scaling_factor")

        # Check if the reconstruction size is in the data
        if self.cfg.validation.crop == "header":
            # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
            # batches.
            resolution = [_.cpu().numpy().tolist() for _ in data["reconstruction_size"]]
            # The volume sampler should give validation indices belonging to the *same* volume, so it should be
            # safe taking the first element, the matrix sizes are in x,y,z (we work in z,x,y).
            resolution = [_[0] for _ in resolution][:-1]
        elif self.cfg.validation.crop == "training":
            resolution = self.cfg.training.loss.crop
        elif not self.cfg.validation.loss.crop:
            resolution = None
        else:
            raise ValueError(
                f"Cropping should be either set to `header` to get the values from the header or "
                f"`training` to take the same value as training.")

        # Compute output and loss.
        output, loss_dict = self._do_iteration(data, loss_fns)
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names(*self.complex_names).detach(),
            scaling_factors,
            resolution=resolution,
        )

        if is_validation_process:
            target_abs = self.process_output(
                data["target"].refine_names(*self.real_names).detach(),
                scaling_factors,
                resolution=resolution,
            )

        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.

            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear.
            # For the last case we need to check if we are at the last batch *and* at the last element in the batch.
            if filename != last_filename or (iter_idx + 1 == len(data_loader)
                                             and idx + 1 == len(data["target"])):
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([_[1].rename(None) for _ in reconstruction_output[last_filename]])
                self.logger.info(f"Reconstructed {last_filename} (shape = {list(volume.shape)}).")

                if is_validation_process:
                    target = torch.stack([_[1].rename(None) for _ in targets_output[last_filename]])
                    curr_metrics = {
                        metric_name: metric_fn(volume, target)
                        for metric_name, metric_fn in volume_metrics.items()
                    }
                    val_volume_metrics[last_filename] = curr_metrics

                    # Log the center slice of the volume
                    if len(visualize_slices) < self.cfg.tensorboard.num_images:
                        visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                        visualize_target.append(normalize_image(target[target.shape[0] // 2]))

                    # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                    # as we are actually interested in the output
                    del targets_output
                    targets_output = defaultdict(list)
                    del reconstruction_output
                    reconstruction_output = defaultdict(list)

                last_filename = filename

            curr_slice = output_abs[idx]
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            if is_validation_process:
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    communication.synchronize()
    torch.cuda.empty_cache()

    # TODO(jt): Does not work yet with normal gather.
    all_gathered_metrics = merge_list_of_dicts(communication.all_gather(val_volume_metrics))
    if not is_validation_process:
        return loss_dict, reconstruction_output

    # TODO(jt): Make named tuple
    return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
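# A minimal sketch of what `reduce_tensor_dict` is assumed to do before the averaged validation
# loss is logged: average every tensor in the dict across distributed processes (a no-op when a
# single process is running). The real helper in direct.utils.communication may differ, e.g. by
# summing instead of averaging.
import torch.distributed as dist


def reduce_tensor_dict_sketch(tensor_dict):
    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
        return tensor_dict
    world_size = dist.get_world_size()
    for key, tensor in tensor_dict.items():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor_dict[key] = tensor / world_size
    return tensor_dict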
def validation_loop(
    self,
    validation_datasets,
    loss_fns,
    experiment_directory,
    iter_idx,
    num_workers: int = 6,
):
    if not validation_datasets:
        return

    storage = get_event_storage()
    data_loaders = self.build_validation_loaders(
        validation_data=validation_datasets,
        num_workers=num_workers,
    )
    for curr_dataset_name, curr_data_loader in data_loaders:
        self.logger.info(f"Evaluating: {curr_dataset_name}...")
        (
            curr_loss_dict,
            curr_metrics_per_case,
            visualize_slices,
            visualize_target,
        ) = self.evaluate(
            curr_data_loader,
            loss_fns,
            is_validation_process=True,
        )

        if experiment_directory:
            json_output_fn = experiment_directory / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
            json_output_fn.parent.mkdir(exist_ok=True, parents=True)  # A / in the filename can create a folder
            if communication.is_main_process():
                write_json(
                    json_output_fn,
                    curr_metrics_per_case,
                )
            self.logger.info(f"Wrote per image logs to: {json_output_fn}.")

        # Metric dict still needs to be reduced as it gives values *per* case
        curr_metric_dict = reduce_list_of_dicts(list(curr_metrics_per_case.values()), mode="average")

        key_prefix = "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
        loss_reduced = sum(curr_loss_dict.values())
        storage.add_scalars(
            **{key_prefix + "loss": loss_reduced},
            **{
                **prefix_dict_keys(curr_metric_dict, key_prefix),
                **prefix_dict_keys(curr_loss_dict, key_prefix),
            },
            smoothing_hint=False,
        )

        visualize_slices = self.process_slices_for_visualization(visualize_slices, visualize_target)
        storage.add_image(f"{key_prefix}prediction", visualize_slices)

        if iter_idx // self.cfg.training.validation_steps - 1 == 0:
            visualize_target = make_grid(
                crop_to_largest(visualize_target, pad_value=0),
                nrow=self.cfg.logging.tensorboard.num_images,
                scale_each=True,
            )
            storage.add_image(f"{key_prefix}target", visualize_target)

        self.logger.info(f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}.")

    self.model.train()
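# A minimal sketch of the assumed `prefix_dict_keys` helper used when pushing the validation
# scalars to TensorBoard under a `val/<dataset>/` namespace; the actual helper in direct.utils
# may differ.
from typing import Dict


def prefix_dict_keys_sketch(dictionary: Dict, prefix: str) -> Dict:
    return {prefix + key: value for key, value in dictionary.items()}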
def evaluate(self,
             data_loader: DataLoader,
             loss_fns: Dict[str, Callable],
             volume_metrics: Optional[Dict[str, Callable]] = None,
             evaluation_round=0):
    self.logger.info('Evaluating...')
    self.model.eval()
    torch.cuda.empty_cache()

    # Variables required for evaluation.
    volume_metrics = volume_metrics if volume_metrics is not None else self.build_metrics()
    storage = get_event_storage()

    reconstruction_output = defaultdict(list)
    targets_output = defaultdict(list)
    val_losses = []
    val_volume_metrics = defaultdict(dict)
    last_filename = None

    # Container for the slices which can be visualized in TensorBoard.
    visualize_slices = []
    visualize_target = []

    # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
    # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
    # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
    for iter_idx, data in enumerate(data_loader):
        self.log_process(iter_idx, len(data_loader))
        data = AddNames()(data)
        filenames = data.pop('filename')
        slice_nos = data.pop('slice_no')
        scaling_factors = data.pop('scaling_factor')

        # Compute output and loss.
        output, loss_dict = self._do_iteration(data, loss_fns)
        val_losses.append(loss_dict)

        # Output is complex-valued, and has to be cropped. This holds for both output and target.
        output_abs = self.process_output(
            output.refine_names('batch', 'complex', 'height', 'width').detach(),
            scaling_factors, 320)
        target_abs = self.process_output(
            data['target'].refine_names('batch', 'height', 'width').detach(),
            scaling_factors, 320)
        del output  # Explicitly call delete to clear memory.
        # TODO: Is a hack.

        # Aggregate volumes to be able to compute the metrics on complete volumes.
        batch_counter = 0
        for idx, filename in enumerate(filenames):
            if last_filename is None:
                last_filename = filename  # First iteration last_filename is not set.

            # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
            # is linear.
            # For the last case we need to check if we are at the last batch *and* at the last element in the batch.
            if filename != last_filename or (iter_idx + 1 == len(data_loader) and idx + 1 == len(data['target'])):
                # Now we can ditch the reconstruction dict by reconstructing the volume,
                # will take too much memory otherwise.
                # TODO: Stack does not support named tensors.
                volume = torch.stack([_[1].rename(None) for _ in reconstruction_output[last_filename]])
                target = torch.stack([_[1].rename(None) for _ in targets_output[last_filename]])
                self.logger.info(f'Reconstructed {last_filename} (shape = {list(volume.shape)}).')

                curr_metrics = {
                    metric_name: metric_fn(volume, target)
                    for metric_name, metric_fn in volume_metrics.items()}
                val_volume_metrics[last_filename] = curr_metrics

                # Log the center slice of the volume
                if len(visualize_slices) < self.cfg.tensorboard.num_images:
                    visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                    # Target only needs to be logged once.
                    if evaluation_round == 0:
                        visualize_target.append(normalize_image(target[target.shape[0] // 2]))

                last_filename = filename

                # Delete outputs from memory, and recreate dictionary.
                del reconstruction_output
                del targets_output
                reconstruction_output = defaultdict(list)
                targets_output = defaultdict(list)

            curr_slice = output_abs[idx]
            slice_no = int(slice_nos[idx].numpy())

            # TODO: CPU?
            reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
            targets_output[filename].append((slice_no, target_abs[idx].cpu()))

    # Average loss dict
    loss_dict = reduce_list_of_dicts(val_losses)
    reduce_tensor_dict(loss_dict)

    # Log slices.
    visualize_slices = make_grid(visualize_slices, nrow=4, scale_each=True)
    storage.add_image('validation/prediction', visualize_slices)

    if evaluation_round == 0:
        visualize_target = make_grid(visualize_target, nrow=4, scale_each=True)
        storage.add_image('validation/target', visualize_target)

    communication.synchronize()
    torch.cuda.empty_cache()

    return loss_dict
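# A minimal sketch of what `normalize_image` is assumed to do before the centre slices are put
# into the TensorBoard grid: rescale a single 2D slice to the [0, 1] range. The real helper may
# clip percentiles or behave differently.
import torch


def normalize_image_sketch(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    image = image - image.min()
    return image / (image.max() + eps)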