Exemplo n.º 1
0
    def training_loop(self,
                      data_loader: DataLoader,
                      start_iter: int,
                      validation_data_loader: Optional[DataLoader] = None):
        """Train ``self.model`` from ``start_iter`` up to ``cfg.training.num_iterations``.

        Each iteration runs ``self._do_iteration`` on one batch, steps the optimizer and the
        learning-rate scheduler, and records the learning rate, the reduced losses and the
        training metrics in the event storage. Validation, checkpointing and log flushing
        happen periodically (only once ``iter_idx > 5``) and at the final iteration.

        Args:
            data_loader: Training data loader. The loop stops when either the loader or the
                iteration range ``[start_iter, num_iterations)`` is exhausted.
            start_iter: Iteration index to start from; larger than 0 when resuming.
            validation_data_loader: Optional loader evaluated with ``self.evaluate`` every
                ``cfg.training.validation_steps`` iterations and at the final iteration.

        Raises:
            SystemExit: When ``self._do_iteration`` raises ``ProcessKilledException``; a
                checkpoint for the current iteration is saved and the logs are written first.
        """
        self.logger.info(f'Local rank: {communication.get_local_rank()}.')
        self.model.train()
        loss_fns = self.build_loss()
        metric_fns = self.build_metrics()
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        # zip stops at whichever is exhausted first: the data loader or the iteration range.
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            # Log the first training example only on a fresh start; iteration 0 is not in the
            # range when resuming with start_iter > 0.
            if iter_idx == 0:
                self.logger.info(
                    f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
                )
                storage.add_image('train/mask', data['sampling_mask'][0, ...,
                                                                      0])
                storage.add_image(
                    'train/target',
                    normalize_image(
                        data['target'][0].rename(None).unsqueeze(0)))
                storage.add_image(
                    'train/masked_image',
                    normalize_image(
                        transforms.modulus_if_complex(
                            data['masked_image'][0]).rename(None).unsqueeze(
                                0)))
                self.write_to_logs()

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f'{e}.')
                self.checkpointer.save(
                    iter_idx)  # Save checkpoint at kill. # noqa
                self.write_to_logs(
                )  # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                sys.exit(f'Exited with exception: {e}')

            # Gradient accumulation
            # NOTE(review): this branch is currently a no-op. The accumulation logic is commented out, and the
            # optimizer below is stepped and zeroed on every iteration, so cfg.training.gradient_steps has no
            # effect here.
            if (iter_idx +
                    1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                # TODO: Is this slow? This is a generator, so should be cheap.
                pass
                # parameter_list = self.model.parameters() if not self.mixed_precision else \
                #     amp.master_params(self.__optimizer)
                # if self.cfg.training.gradient_steps > 1:  # type: ignore
                #     for parameter in parameter_list:
                #         if parameter.grad is not None:
                #             # In-place division
                #             parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                # if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                #     torch.nn.utils.clip_grad_norm_(parameter_list, self.cfg.training.gradient_clipping)  # type: ignore
                #
                # # Gradient norm
                # if self.cfg.training.gradient_debug:  # type: ignore
                #     parameters = list(filter(lambda p: p.grad is not None, parameter_list))
                #     gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt()  # typing: ignore
                #     storage.add_scalar('gradient_norm', gradient_norm)

            # Optimizer and learning-rate scheduler are stepped on every iteration.
            self.__optimizer.step()  # type: ignore
            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar('lr',
                               self.__optimizer.param_groups[0]['lr'],
                               smoothing_hint=False)

            # Clear gradients after the step so the next iteration starts from zero.
            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            # Training metrics on the magnitude of the (detached) output against the target.
            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data['target'].rename(None).detach().to(self.device),
                reduction='mean')
            metrics_dict_reduced = communication.reduce_tensor_dict(
                metrics_dict) if metrics_dict else {}
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            # Validate every validation_steps iterations and at the final iteration (skipped while iter_idx <= 5).
            if validation_data_loader is not None \
                    and iter_idx > 5\
                    and (iter_idx % self.cfg.training.validation_steps == 0 or (iter_idx + 1) == total_iter):
                # By operator precedence this is (iter_idx // validation_steps) - 1.
                val_loss_dict = self.evaluate(
                    validation_data_loader,
                    loss_fns,
                    evaluation_round=iter_idx //
                    self.cfg.training.validation_steps - 1)
                self.logger.info(f'Done evaluation at iteration {iter_idx}.')
                storage.add_scalars(**prefix_dict_keys(val_loss_dict, 'val_'),
                                    smoothing_hint=False)
                # Switch back to training mode after evaluation.
                self.model.train()

            # Checkpoint every checkpoint_steps iterations and at the final iteration (skipped while iter_idx <= 5).
            if iter_idx > 5 and\
                    (iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0 or (iter_idx + 1) == total_iter):
                self.logger.info(f'Checkpointing at iteration: {iter_idx}.')
                self.checkpointer.save(iter_idx)

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 (iter_idx + 1) == total_iter):
                self.write_to_logs()

            # Presumably advances the event storage to the next iteration — confirm against get_event_storage.
            storage.step()
Exemplo n.º 2
0
    def training_loop(
        self,
        data_loader: DataLoader,
        start_iter: int,
        validation_data_loaders: Optional[List[DataLoader]] = None,
        experiment_directory: Optional[pathlib.Path] = None,
    ):
        """Train the models from ``start_iter`` up to ``cfg.training.num_iterations``.

        Every iteration computes the loss (and backward pass) with ``self._do_iteration``.
        Every ``cfg.training.gradient_steps`` iterations the accumulated gradients are
        averaged, optionally clipped, applied through the mixed-precision grad scaler and then
        cleared. Losses, metrics and the learning rate are logged to the event storage.
        Checkpointing, validation and log flushing happen periodically (once ``iter_idx > 5``)
        and at the final iteration.

        Args:
            data_loader: Training data loader. The loop ends when either the loader or the
                iteration range ``[start_iter, num_iterations)`` is exhausted.
            start_iter: Iteration to start (or resume) from.
            validation_data_loaders: Optional iterable of ``(dataset_name, data_loader)`` pairs.
                Despite the annotation, each element is unpacked into a name and a loader.
                Every pair is evaluated every ``cfg.training.validation_steps`` iterations and
                at the final iteration.
            experiment_directory: If given, per-case validation metrics are written there as
                ``metrics_val_{dataset_name}_{iter_idx}.json``.

        Raises:
            SystemExit: If ``self._do_iteration`` raises ``ProcessKilledException``; a
                checkpoint for the current iteration is saved and the logs are written first.
        """
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == start_iter:
                self.ndim = self.compute_dimensionality_from_sample(data)
                self.logger.info(f"Data dimensionality: {self.ndim}.")

            if iter_idx == 0:
                self.log_first_training_example(data)

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpointer.save(iter_idx)  # Save checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(-1)

            # Gradient accumulation: gradients of `gradient_steps` consecutive iterations accumulate and are only
            # applied (and cleared) on every `gradient_steps`-th iteration.
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division turns the accumulated sum into an average.
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    # Unscale first so that the clipping threshold applies to the true gradient magnitudes.
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.cfg.training.gradient_clipping)

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        "Gradient debug set. This will affect training performance. Only use for debugging. "
                        "This message will only be displayed once.")
                    # Total L2 norm over all parameters, computed the same way as clip_grad_norm_. Summing the
                    # element-wise squared gradients directly would broadcast tensors of different shapes.
                    # NOTE(review): unless clipping unscaled them above, these gradients still include the
                    # grad-scaler factor.
                    parameters = [
                        parameter for parameter in self.model.parameters()
                        if parameter.grad is not None
                    ]
                    if parameters:
                        gradient_norm = torch.norm(
                            torch.stack([
                                parameter.grad.detach().norm(2)
                                for parameter in parameters
                            ]), 2)
                        storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()
                # Clear gradients only after they have been applied. Zeroing them on every iteration would
                # discard the gradients of the intermediate accumulation steps.
                self.__optimizer.zero_grad()  # type: ignore

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar("lr",
                               self.__optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict)
                if metrics_dict else {})
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            if iter_idx > 5 and (
                    iter_idx % self.cfg.training.checkpointer.checkpoint_steps
                    == 0 or (iter_idx + 1) == total_iter):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

            if (validation_data_loaders is not None and iter_idx > 5
                    and (iter_idx % self.cfg.training.validation_steps == 0 or
                         (iter_idx + 1) == total_iter)):
                for curr_dataset_name, curr_validation_data_loader in validation_data_loaders:
                    self.logger.info(
                        f"Evaluating: {curr_dataset_name}..."
                    )  # TODO(jt): Fix with better names and stuff.
                    (
                        curr_val_loss_dict,
                        curr_val_metric_dict_per_case,
                        visualize_slices,
                        visualize_target,
                    ) = self.evaluate(
                        curr_validation_data_loader,
                        loss_fns,
                        is_validation_process=True,
                    )

                    if experiment_directory:
                        # Make dictionary serializable for logging
                        serializable_val_metric_dict = {
                            k0: {k1: float(v1)
                                 for k1, v1 in v0.items()}
                            for k0, v0 in
                            curr_val_metric_dict_per_case.items()
                        }
                        write_json(
                            experiment_directory /
                            f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                            serializable_val_metric_dict,
                        )

                    # Metric dict still needs to be reduced as it gives values *per* data
                    curr_val_metric_dict = reduce_list_of_dicts(
                        list(curr_val_metric_dict_per_case.values()),
                        mode="average")

                    key_prefix = ("val/" if not curr_dataset_name else
                                  f"val/{curr_dataset_name}/")
                    val_loss_reduced = sum(curr_val_loss_dict.values())
                    storage.add_scalars(
                        **{key_prefix + "loss": val_loss_reduced},
                        **{
                            **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                            **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                        },
                        smoothing_hint=False,
                    )
                    visualize_slices = self.process_slices_for_visualization(
                        visualize_slices, visualize_target)
                    storage.add_image(f"{key_prefix}prediction",
                                      visualize_slices)

                    # Targets do not change between validation rounds, so they are only logged in the first round.
                    if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                        visualize_target = make_grid(
                            crop_to_largest(visualize_target, pad_value=0),
                            nrow=self.cfg.tensorboard.num_images,
                            scale_each=True,
                        )
                        storage.add_image(f"{key_prefix}target",
                                          visualize_target)

                    # Logged per dataset, so every dataset is reported (and an empty list cannot hit a NameError).
                    self.logger.info(
                        f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
                    )
                # evaluate() puts all models in validation mode; restore training mode for all of them,
                # matching the call at the start of this loop.
                self.models_training_mode()

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 (iter_idx + 1) == total_iter):
                self.write_to_logs()

            storage.step()
Exemplo n.º 3
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        regularizer_fns: Optional[Dict[str, Callable]] = None,
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Run the models over ``data_loader`` and reassemble per-slice outputs into volumes.

        The loader must yield the slices of each volume sequentially (e.g. through
        ``direct.data.sampler.DistributedSequentialSampler``), and every batch must belong to a
        single file. A volume is complete when the filename changes or when the last slice of
        the loader has been processed.

        Args:
            data_loader: Loader over the evaluation data.
            loss_fns: Loss functions passed to ``self._do_iteration``.
            regularizer_fns: Optional regularizers passed to ``self._do_iteration``.
            crop: Unused; the crop key is read from ``cfg.validation.crop``.
            is_validation_process: If True, targets are processed as well, and per-volume
                metrics and center-slice visualizations are computed.

        Returns:
            ``(loss_dict, reconstruction_output)`` when ``is_validation_process`` is False, where
            ``reconstruction_output`` maps each filename to a list of ``(slice_no, slice)``
            pairs. Otherwise ``(loss_dict, all_gathered_metrics, visualize_slices,
            visualize_target)``.

        Raises:
            ValueError: If a batch contains slices from more than one file.
        """
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        # filenames can be in the volume_indices attribute of the dataset
        if hasattr(data_loader.dataset, "volume_indices"):
            all_filenames = list(data_loader.dataset.volume_indices.keys())
            num_for_this_process = len(
                list(data_loader.batch_sampler.sampler.volume_indices.keys()))
            self.logger.info(
                f"Reconstructing a total of {len(all_filenames)} volumes. "
                f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
            )
        else:
            # Without volume indices the progress is reported in slices. all_filenames must still be defined,
            # otherwise logging a completed volume raises a NameError.
            all_filenames = []
            num_for_this_process = None
        filenames_seen = 0

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        extra_visualization_keys = (self.cfg.logging.log_as_image
                                    if self.cfg.logging.log_as_image else [])

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        time_start = time.time()

        for iter_idx, data in enumerate(data_loader):
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data["scaling_factor"]

            resolution = self.compute_resolution(
                key=self.cfg.validation.crop,
                reconstruction_size=data.get("reconstruction_size", None),
            )

            # Compute output and loss.
            iteration_output = self._do_iteration(
                data, loss_fns, regularizer_fns=regularizer_fns)
            output = iteration_output.output_image
            loss_dict = iteration_output.data_dict
            # sensitivity_map = iteration_output.sensitivity_map

            loss_dict = detach_dict(loss_dict)
            output = output.detach()
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names()),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].detach().refine_names(*self.real_names()),
                    scaling_factors,
                    resolution=resolution,
                )
                for key in extra_visualization_keys:
                    # Placeholder: the extra keys are only read for now.
                    curr_data = data[key].detach()
                    # Here we need to discover which keys are actually normalized or not
                    # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = filename  # First iteration last_filename is not set.

                # Store the current slice *before* checking for completed volumes. This way the final slice of the
                # final volume is part of its reconstruction, and a volume is never stacked while empty.
                curr_slice = output_abs[idx].detach()
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, curr_slice.cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

                # The sampling is linear: a new filename means the previous volume is complete, and the last element
                # of the last batch completes the current volume. Both can hold for the same slice (a final volume
                # with a single slice), in which case both volumes are finalized, oldest first.
                is_last_element_of_last_batch = iter_idx + 1 == len(
                    data_loader) and idx + 1 == len(data["target"])
                completed_filenames = []
                if filename != last_filename:
                    completed_filenames.append(last_filename)
                if is_last_element_of_last_batch:
                    completed_filenames.append(filename)

                for completed_filename in completed_filenames:
                    filenames_seen += 1
                    # Now we can ditch the reconstruction dict by reconstructing the volume,
                    # will take too much memory otherwise.
                    # TODO: Stack does not support named tensors.
                    volume = torch.stack([
                        _[1].rename(None)
                        for _ in reconstruction_output[completed_filename]
                    ])
                    if is_validation_process:
                        target = torch.stack([
                            _[1].rename(None)
                            for _ in targets_output[completed_filename]
                        ])
                        curr_metrics = {
                            metric_name: metric_fn(target, volume)
                            for metric_name, metric_fn in
                            volume_metrics.items()
                        }
                        val_volume_metrics[completed_filename] = curr_metrics
                        # Log the center slice of the volume
                        if (len(visualize_slices) <
                                self.cfg.logging.tensorboard.num_images):
                            visualize_slices.append(volume[volume.shape[0] //
                                                           2])
                            visualize_target.append(target[target.shape[0] //
                                                           2])

                        # Free the completed volume. Only its own entries are removed, so slices already stored for
                        # the next volume are kept. This is not done when not in validation, as the caller is then
                        # interested in the output itself.
                        del targets_output[completed_filename]
                        del reconstruction_output[completed_filename]

                    if all_filenames:
                        log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
                    else:
                        log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

                    self.logger.info(
                        f"{log_prefix} {completed_filename}"
                        f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
                    )
                    # restart timer
                    time_start = time.time()

                last_filename = filename

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        # reduce_tensor_dict returns the reduced dictionary (the training loop uses its return value as well),
        # so keep its result instead of discarding it.
        loss_dict = reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO: Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))
        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO: Apply named tuples where applicable
        # TODO: Several functions have multiple output values, in many cases
        # TODO: it would be more convenient to convert this to namedtuples.
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
Exemplo n.º 4
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Run the models over ``data_loader`` and reassemble per-slice outputs into volumes.

        The loader must yield the slices of each volume sequentially (e.g. through
        ``direct.data.sampler.DistributedSequentialSampler``), and every batch must belong to a
        single file. A volume is complete when the filename changes or when the last slice of
        the loader has been processed.

        Args:
            data_loader: Loader over the evaluation data.
            loss_fns: Loss functions passed to ``self._do_iteration``.
            crop: Unused; the crop mode is read from ``cfg.validation.crop``.
            is_validation_process: If True, targets are processed as well, and per-volume
                metrics and normalized center-slice visualizations are computed.

        Returns:
            ``(loss_dict, reconstruction_output)`` when ``is_validation_process`` is False, where
            ``reconstruction_output`` maps each filename to a list of ``(slice_no, slice)``
            pairs. Otherwise ``(loss_dict, all_gathered_metrics, visualize_slices,
            visualize_target)``.

        Raises:
            ValueError: If a batch contains slices from more than one file, or if
                ``cfg.validation.crop`` is set to something other than ``"header"``,
                ``"training"`` or a falsy value.
        """
        # TODO(jt): Also log other models output (e.g. sensitivity map).
        # TODO(jt): This can be simplified as the sampler now only outputs batches belonging to the same volume.
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, len(data_loader))
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data.pop("scaling_factor")

            # Check if reconstruction size is the data
            if self.cfg.validation.crop == "header":
                # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
                # batches.
                resolution = [
                    _.cpu().numpy().tolist()
                    for _ in data["reconstruction_size"]
                ]
                # The volume sampler should give validation indices belonging to the *same* volume, so it should be
                # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y).
                resolution = [_[0] for _ in resolution][:-1]
            elif self.cfg.validation.crop == "training":
                resolution = self.cfg.training.loss.crop
            elif not self.cfg.validation.crop:
                # No cropping requested (same key as the branches above).
                resolution = None
            else:
                raise ValueError(
                    f"Cropping should be either set to `header` to get the values from the header or "
                    f"`training` to take the same value as training.")

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names).detach(),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].refine_names(*self.real_names).detach(),
                    scaling_factors,
                    resolution=resolution,
                )
            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = filename  # First iteration last_filename is not set.

                # Store the current slice *before* checking for completed volumes. This way the final slice of the
                # final volume is part of its reconstruction, and a volume is never stacked while empty.
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, output_abs[idx].cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

                # The sampling is linear: a new filename means the previous volume is complete, and the last element
                # of the last batch completes the current volume. Both can hold for the same slice (a final volume
                # with a single slice), in which case both volumes are finalized, oldest first.
                is_last_element_of_last_batch = iter_idx + 1 == len(
                    data_loader) and idx + 1 == len(data["target"])
                completed_filenames = []
                if filename != last_filename:
                    completed_filenames.append(last_filename)
                if is_last_element_of_last_batch:
                    completed_filenames.append(filename)

                for completed_filename in completed_filenames:
                    # Now we can ditch the reconstruction dict by reconstructing the volume,
                    # will take too much memory otherwise.
                    # TODO: Stack does not support named tensors.
                    volume = torch.stack([
                        _[1].rename(None)
                        for _ in reconstruction_output[completed_filename]
                    ])
                    self.logger.info(
                        f"Reconstructed {completed_filename} (shape = {list(volume.shape)})."
                    )
                    if is_validation_process:
                        target = torch.stack([
                            _[1].rename(None)
                            for _ in targets_output[completed_filename]
                        ])
                        curr_metrics = {
                            metric_name: metric_fn(volume, target)
                            for metric_name, metric_fn in
                            volume_metrics.items()
                        }
                        val_volume_metrics[completed_filename] = curr_metrics
                        # Log the center slice of the volume
                        if len(visualize_slices
                               ) < self.cfg.tensorboard.num_images:
                            visualize_slices.append(
                                normalize_image(volume[volume.shape[0] // 2]))
                            visualize_target.append(
                                normalize_image(target[target.shape[0] // 2]))

                        # Free the completed volume. Only its own entries are removed, so slices already stored for
                        # the next volume are kept. This is not done when not in validation, as the caller is then
                        # interested in the output itself.
                        del targets_output[completed_filename]
                        del reconstruction_output[completed_filename]

                last_filename = filename

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        # reduce_tensor_dict returns the reduced dictionary (the training loop uses its return value as well),
        # so keep its result instead of discarding it.
        loss_dict = reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO(jt): Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))

        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO(jt): Make named tuple
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
Exemplo n.º 5
0
    def training_loop(
        self,
        training_datasets: List,  # TODO(jt): Improve typing
        start_iter: int,
        validation_datasets: Optional[List] = None,
        experiment_directory: Optional[pathlib.Path] = None,
        num_workers: int = 6,
        start_with_validation: bool = False,
    ):
        """Run the training loop from `start_iter` up to `cfg.training.num_iterations`.

        Parameters
        ----------
        training_datasets : List
            Datasets that are concatenated into one training set. The first dataset's `ndim` attribute is stored as
            `self.ndim`.
        start_iter : int
            Iteration index to start (or resume) from.
        validation_datasets : Optional[List]
            Passed to `self.validation_loop`.
        experiment_directory : Optional[pathlib.Path]
            Passed to `self.validation_loop`.
        num_workers : int
            Number of workers for the training and validation data loaders.
        start_with_validation : bool
            If set, run validation once before the first training iteration (at `start_iter`).

        Raises
        ------
        TrainingException
            If a CUDA out-of-memory error persists after 3 skipped batches.
        RuntimeError
            For any other unrecoverable runtime error (a checkpoint is written first).
        """
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        regularizer_fns = self.build_regularizers(self.cfg.training.regularizers)
        storage = get_event_storage()

        self.ndim = training_datasets[0].ndim
        self.logger.info(f"Data dimensionality: {self.ndim}.")

        training_data = ConcatDataset(training_datasets)

        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )

        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random"
        )
        data_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        # Convenient shorthand
        validation_func = functools.partial(
            self.validation_loop,
            validation_datasets,
            loss_fns,
            experiment_directory,
            num_workers=num_workers,
        )

        total_iter = self.cfg.training.num_iterations  # noqa
        gradient_steps = self.cfg.training.gradient_steps  # type: ignore
        fail_counter = 0
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == 0:
                self.log_first_training_example_and_model(data)

            if start_with_validation and iter_idx == start_iter:
                self.logger.info(f"Starting with validation at iteration: {iter_idx}.")
                validation_func(iter_idx)
            try:
                iteration_output = self._do_iteration(
                    data, loss_fns, regularizer_fns=regularizer_fns
                )
                output = iteration_output.output_image
                loss_dict = iteration_output.data_dict
            except (ProcessKilledException, TrainingException) as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpoint_and_write_to_logs(iter_idx)
                sys.exit(-1)
            except RuntimeError as e:
                # Detection relies on PyTorch's error message text, which may change between versions.
                if "out of memory" in str(e):
                    if fail_counter == 3:
                        self.checkpoint_and_write_to_logs(iter_idx)
                        raise TrainingException(
                            f"OOM, could not recover after 3 tries: {e}."
                        ) from e
                    fail_counter += 1
                    self.logger.info(
                        f"OOM Error: {e}. Skipping batch. Retry {fail_counter}/3."
                    )
                    # Drop partial gradients of the failed batch and release cached GPU memory before retrying.
                    self.__optimizer.zero_grad()
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue
                self.checkpoint_and_write_to_logs(iter_idx)
                self.logger.info(f"Cannot recover from exception {e}. Exiting.")
                raise RuntimeError(e) from e

            if fail_counter > 0:
                self.logger.info("Recovered from OOM, skipped batch.")
            fail_counter = 0

            # Gradient accumulation: gradients are summed over `gradient_steps` iterations and the optimizer only
            # steps (and only then clears the gradients) on every `gradient_steps`-th iteration.
            if (iter_idx + 1) % gradient_steps == 0:
                if gradient_steps > 1:
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division turns the accumulated sum into a mean.
                            parameter.grad.div_(gradient_steps)
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    # Gradients must be unscaled before clipping so the threshold applies to the true values.
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.training.gradient_clipping
                    )

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        f"Gradient debug set. This will affect training performance. Only use for debugging."
                        f"This message will only be displayed once."
                    )
                    parameters = [p for p in self.model.parameters() if p.grad is not None]
                    if parameters:
                        # Global L2 norm over all gradients (norm of the per-parameter norms).
                        # NOTE(review): when clipping is disabled the gradients have not been unscaled here, so
                        # this is presumably the norm of the loss-scaled gradients — confirm if that matters.
                        gradient_norm = torch.norm(
                            torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
                        )
                        storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()
                # Clear gradients only after a step; clearing every iteration would discard accumulated gradients.
                self.__optimizer.zero_grad()  # type: ignore

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar(
                "lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False
            )

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                T.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
            )
            storage.add_scalars(
                loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced
            )

            self.checkpoint_model_at_interval(iter_idx, total_iter)
            self.write_to_logs_at_interval(iter_idx, total_iter)
            self.validate_model_at_interval(validation_func, iter_idx, total_iter)

            storage.step()
Exemplo n.º 6
0
    def evaluate(self,
                 data_loader: DataLoader,
                 loss_fns: Dict[str, Callable],
                 volume_metrics: Optional[Dict[str, Callable]] = None,
                 evaluation_round=0):
        """Evaluate the model on `data_loader`, computing per-volume metrics and logging center slices.

        Slices are accumulated per filename and stacked into a volume once all of its slices have been seen.
        This requires direct.data.sampler.DistributedSequentialSampler (or equivalent). That sampler splits the data
        over the processes and yields the slices of each volume contiguously, one volume after the other.

        Parameters
        ----------
        data_loader : DataLoader
            Loader whose batches contain at least 'filename', 'slice_no', 'scaling_factor' and 'target'.
        loss_fns : Dict[str, Callable]
            Loss functions passed to `self._do_iteration`.
        volume_metrics : Optional[Dict[str, Callable]]
            Metrics evaluated as `metric_fn(volume, target)` on every reconstructed volume.
            Defaults to `self.build_metrics()`.
        evaluation_round : int
            When 0, the target center slices are also logged (they only need to be logged once).

        Returns
        -------
        dict
            The validation losses of all batches combined with `reduce_list_of_dicts`.
        """
        self.logger.info('Evaluating...')
        self.model.eval()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        volume_metrics = volume_metrics if volume_metrics is not None else self.build_metrics()
        storage = get_event_storage()

        # Per-filename lists of (slice_no, slice) for the volume currently being accumulated.
        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        def _finalize_volume(volume_filename):
            # Stack all slices of `volume_filename`, compute its metrics and queue its center slice for logging.
            # The slices are popped so that at most one volume is kept in memory.
            # TODO: Stack does not support named tensors.
            volume = torch.stack([_[1].rename(None) for _ in reconstruction_output.pop(volume_filename)])
            target = torch.stack([_[1].rename(None) for _ in targets_output.pop(volume_filename)])
            self.logger.info(f'Reconstructed {volume_filename} (shape = {list(volume.shape)}).')
            val_volume_metrics[volume_filename] = {
                metric_name: metric_fn(volume, target) for metric_name, metric_fn in volume_metrics.items()}

            # Log the center slice of the volume
            if len(visualize_slices) < self.cfg.tensorboard.num_images:
                visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                # Target only needs to be logged once.
                if evaluation_round == 0:
                    visualize_target.append(normalize_image(target[target.shape[0] // 2]))

        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, len(data_loader))
            data = AddNames()(data)
            filenames = data.pop('filename')
            slice_nos = data.pop('slice_no')
            scaling_factors = data.pop('scaling_factor')

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names('batch', 'complex', 'height', 'width').detach(), scaling_factors, 320)
            target_abs = self.process_output(
                data['target'].refine_names('batch', 'height', 'width').detach(), scaling_factors, 320)
            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            for idx, filename in enumerate(filenames):
                # Sampling is sequential, so a new filename means the previous volume is complete. Reconstruct it
                # *before* storing the current slice so no slice is lost or attributed to the wrong volume.
                if last_filename is not None and filename != last_filename:
                    _finalize_volume(last_filename)
                last_filename = filename

                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append((slice_no, output_abs[idx].cpu()))
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

        # The final volume is only known to be complete once the data loader is exhausted.
        if last_filename is not None:
            _finalize_volume(last_filename)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        # Log slices (make_grid cannot handle an empty list, e.g. for an empty loader or num_images == 0).
        if visualize_slices:
            visualize_slices = make_grid(visualize_slices, nrow=4, scale_each=True)
            storage.add_image('validation/prediction', visualize_slices)

        if evaluation_round == 0 and visualize_target:
            visualize_target = make_grid(visualize_target, nrow=4, scale_each=True)
            storage.add_image('validation/target', visualize_target)

        communication.synchronize()
        torch.cuda.empty_cache()

        return loss_dict