Example 1
    def process_slices_for_visualization(self, visualize_slices, visualize_target):
        # Log slices.
        # Compute the difference as well, and normalize for visualization
        difference_slices = [a - b for a, b in zip(visualize_slices, visualize_target)]
        # Normalize the differences to [0, 1]: dividing by the largest absolute
        # difference maps each slice to [-1, 1], which * 0.5 + 0.5 shifts to [0, 1].
        difference_slices = [(d / np.abs(d).max()) * 0.5 + 0.5 for d in difference_slices]
        visualize_slices = [normalize_image(image) for image in visualize_slices]

        # Visualize slices, and crop to the largest volume
        visualize_slices = make_grid(
            crop_to_largest(visualize_slices + difference_slices, pad_value=0),
            nrow=self.cfg.logging.tensorboard.num_images,
            scale_each=True,
        )
        return visualize_slices
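
Example 1 relies on helpers that are not defined here: normalize_image and crop_to_largest (make_grid appears to follow the torchvision.utils.make_grid signature). A minimal sketch of plausible implementations, assuming torch image tensors shaped (..., H, W); the real helpers in the codebase may differ:

    from typing import List

    import torch
    import torch.nn.functional as F

    # Hypothetical implementations, for illustration only.
    def normalize_image(image: torch.Tensor) -> torch.Tensor:
        """Min-max normalize a single image tensor to [0, 1]."""
        image = image - image.min()
        return image / image.max().clamp(min=1e-8)

    def crop_to_largest(images: List[torch.Tensor], pad_value: float = 0.0) -> List[torch.Tensor]:
        """Pad every image to the largest height/width in the list so that
        make_grid can stack them; smaller images are padded with pad_value."""
        max_h = max(image.shape[-2] for image in images)
        max_w = max(image.shape[-1] for image in images)
        output = []
        for image in images:
            pad_h = max_h - image.shape[-2]
            pad_w = max_w - image.shape[-1]
            # F.pad takes (left, right, top, bottom) for the last two dimensions.
            output.append(F.pad(image, (0, pad_w, 0, pad_h), value=pad_value))
        return output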
Example 2
    def training_loop(
        self,
        data_loader: DataLoader,
        start_iter: int,
        validation_data_loaders: Optional[List[DataLoader]] = None,
        experiment_directory: Optional[pathlib.Path] = None,
    ):
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == start_iter:
                self.ndim = self.compute_dimensionality_from_sample(data)
                self.logger.info(f"Data dimensionality: {self.ndim}.")

            if iter_idx == 0:
                self.log_first_training_example(data)

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, a checkpoint is saved at iter_idx, the current
                # state, so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpointer.save(iter_idx)  # Save checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(-1)

            # Gradient accumulation
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.training.gradient_clipping
                    )

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        "Gradient debug is set. This will affect training performance. Only use "
                        "for debugging. This message will only be displayed once."
                    )
                    parameters = [p for p in self.model.parameters() if p.grad is not None]
                    # Total gradient norm: square root of the sum of squared per-parameter L2 norms.
                    gradient_norm = sum(
                        parameter.grad.data.norm(2) ** 2 for parameter in parameters
                    ).sqrt()  # type: ignore
                    storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()

            # Incorrect inference by mypy and pyflakes
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar(
                "lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False
            )

            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
            )
            storage.add_scalars(
                loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced
            )

            if iter_idx > 5 and (
                iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

            if validation_data_loaders is not None and iter_idx > 5 and (
                iter_idx % self.cfg.training.validation_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                for curr_dataset_name, curr_validation_data_loader in validation_data_loaders:
                    self.logger.info(
                        f"Evaluating: {curr_dataset_name}..."
                    )  # TODO(jt): Fix with better names and stuff.
                    (
                        curr_val_loss_dict,
                        curr_val_metric_dict_per_case,
                        visualize_slices,
                        visualize_target,
                    ) = self.evaluate(
                        curr_validation_data_loader,
                        loss_fns,
                        is_validation_process=True,
                    )

                    if experiment_directory:
                        # Make dictionary serializable for logging
                        serializable_val_metric_dict = {
                            k0: {k1: float(v1) for k1, v1 in v0.items()}
                            for k0, v0 in curr_val_metric_dict_per_case.items()
                        }
                        write_json(
                            experiment_directory /
                            f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                            serializable_val_metric_dict,
                        )

                    # The metric dict still needs to be reduced, as it contains values *per* sample.
                    curr_val_metric_dict = reduce_list_of_dicts(
                        list(curr_val_metric_dict_per_case.values()), mode="average"
                    )

                    key_prefix = ("val/" if not curr_dataset_name else
                                  f"val/{curr_dataset_name}/")
                    val_loss_reduced = sum(curr_val_loss_dict.values())
                    storage.add_scalars(
                        **{key_prefix + "loss": val_loss_reduced},
                        **{
                            **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                            **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                        },
                        smoothing_hint=False,
                    )
                    visualize_slices = self.process_slices_for_visualization(
                        visualize_slices, visualize_target
                    )
                    storage.add_image(f"{key_prefix}prediction", visualize_slices)

                    # Only log the ground-truth target on the first validation pass.
                    if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                        visualize_target = make_grid(
                            crop_to_largest(visualize_target, pad_value=0),
                            nrow=self.cfg.logging.tensorboard.num_images,
                            scale_each=True,
                        )
                        storage.add_image(f"{key_prefix}target",
                                          visualize_target)

                self.logger.info(
                    f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
                )
                self.model.train()

            # Log every 20 iterations, at a validation step, or at the end of training.
            if iter_idx > 5 and (
                iter_idx % 20 == 0
                or iter_idx % self.cfg.training.validation_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.write_to_logs()

            storage.step()
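
The update step in training_loop interleaves gradient accumulation, gradient clipping, and mixed precision through a torch.cuda.amp.GradScaler. The same pattern in isolation, as a minimal runnable sketch; the model, optimizer, and data below are placeholders for illustration only:

    import torch
    from torch import nn

    # Placeholder model and data, purely for illustration.
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()
    data_loader = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(8)]

    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    gradient_steps = 4  # accumulate gradients over this many iterations
    gradient_clipping = 1.0

    for iter_idx, (inputs, targets) in enumerate(data_loader):
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            loss = loss_fn(model(inputs), targets)
        # backward() on the scaled loss accumulates (scaled) gradients in .grad.
        scaler.scale(loss).backward()

        if (iter_idx + 1) % gradient_steps == 0:
            # Average the accumulated gradients so the step size does not
            # grow with the number of accumulation steps.
            for parameter in model.parameters():
                if parameter.grad is not None:
                    parameter.grad.div_(gradient_steps)
            # Unscale before clipping, so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
            scaler.update()
            optimizer.zero_grad()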
Example 3
    def validation_loop(
        self,
        validation_datasets,
        loss_fns,
        experiment_directory,
        iter_idx,
        num_workers: int = 6,
    ):
        if not validation_datasets:
            return

        storage = get_event_storage()

        data_loaders = self.build_validation_loaders(
            validation_data=validation_datasets,
            num_workers=num_workers,
        )
        for curr_dataset_name, curr_data_loader in data_loaders:
            self.logger.info(f"Evaluating: {curr_dataset_name}...")
            (
                curr_loss_dict,
                curr_metrics_per_case,
                visualize_slices,
                visualize_target,
            ) = self.evaluate(
                curr_data_loader,
                loss_fns,
                is_validation_process=True,
            )

            if experiment_directory:
                json_output_fn = (
                    experiment_directory
                    / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
                )
                json_output_fn.parent.mkdir(
                    exist_ok=True, parents=True
                )  # A / in the filename can create a folder
                if communication.is_main_process():
                    write_json(
                        json_output_fn,
                        curr_metrics_per_case,
                    )
                self.logger.info(f"Wrote per image logs to: {json_output_fn}.")

            # The metric dict still needs to be reduced, as it contains values *per* sample.
            curr_metric_dict = reduce_list_of_dicts(
                list(curr_metrics_per_case.values()), mode="average"
            )

            key_prefix = (
                "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
            )
            loss_reduced = sum(curr_loss_dict.values())
            storage.add_scalars(
                **{key_prefix + "loss": loss_reduced},
                **{
                    **prefix_dict_keys(curr_metric_dict, key_prefix),
                    **prefix_dict_keys(curr_loss_dict, key_prefix),
                },
                smoothing_hint=False,
            )
            visualize_slices = self.process_slices_for_visualization(
                visualize_slices, visualize_target
            )
            storage.add_image(f"{key_prefix}prediction", visualize_slices)

            # Only log the ground-truth target on the first validation pass.
            if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                visualize_target = make_grid(
                    crop_to_largest(visualize_target, pad_value=0),
                    nrow=self.cfg.logging.tensorboard.num_images,
                    scale_each=True,
                )
                storage.add_image(f"{key_prefix}target", visualize_target)

            self.logger.info(
                f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
            )
        self.model.train()
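
Both loops lean on two dict helpers that are not shown: reduce_list_of_dicts, which collapses per-sample metric dicts into one summary dict, and prefix_dict_keys, which namespaces keys for TensorBoard. A minimal sketch of plausible implementations, assuming the metric values support addition and division; the actual helpers may differ:

    from typing import Dict, List

    # Hypothetical sketches of the helpers used above.
    def reduce_list_of_dicts(list_of_dicts: List[Dict], mode: str = "average") -> Dict:
        """Merge dicts sharing the same keys by summing or averaging per key."""
        if not list_of_dicts:
            return {}
        if mode not in ("sum", "average"):
            raise ValueError(f"Unknown reduction mode: {mode}.")
        keys = list_of_dicts[0].keys()
        reduced = {key: sum(d[key] for d in list_of_dicts) for key in keys}
        if mode == "average":
            reduced = {key: value / len(list_of_dicts) for key, value in reduced.items()}
        return reduced

    def prefix_dict_keys(dictionary: Dict, prefix: str) -> Dict:
        """Return a copy of the dict with every key prefixed."""
        return {prefix + key: value for key, value in dictionary.items()}

For example, reduce_list_of_dicts([{"ssim": 0.8}, {"ssim": 0.9}], mode="average") yields {"ssim": 0.85} (up to float rounding), which prefix_dict_keys with a hypothetical prefix "val/knee/" turns into {"val/knee/ssim": 0.85}.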