Example #1
0
    def process_output(self, data, scaling_factors=None, resolution=None):
        """Turn a raw model output into a batch of images ready for evaluation or display.

        Steps, in order:

        1. If ``scaling_factors`` is given, multiply ``data`` by it, one factor per
           leading-axis entry (the factors are reshaped to ``(-1, 1, ..., 1)`` so they
           broadcast over all trailing axes).
        2. Pass the result through ``modulus_if_complex`` and drop tensor names.
        3. If the result has three axes, insert a singleton axis at position 1.
        4. If ``resolution`` is given, center-crop to ``(resolution, resolution)``.

        Parameters
        ----------
        data :
            Tensor to post-process.
        scaling_factors : optional
            Tensor of multiplicative factors; moved to ``data``'s device before use.
        resolution : optional
            Side length of the square center crop.

        Returns
        -------
        The post-processed tensor.
        """
        if scaling_factors is not None:
            # One trailing singleton axis for every non-leading axis of `data`.
            trailing_ones = (1,) * (data.dim() - 1)
            broadcastable = scaling_factors.view(-1, *trailing_ones)
            data = data * broadcastable.to(data.device)

        data = modulus_if_complex(data).rename(None)

        # A three-axis result, i.e. (batch, height, width), gets a channel axis added.
        if data.dim() == 3:
            data = data.unsqueeze(1)

        if resolution is None:
            return data

        crop_shape = (resolution, resolution)
        return center_crop(data, crop_shape).contiguous()
Example #2
0
    def log_first_training_example(self, data):
        """Log the first sample of a training batch as text and as images.

        Writes the slice number and filename of the first batch entry to the logger,
        then adds the sampling mask, the optional ACS mask, the target and the masked
        image of that entry to the event storage under ``train/...`` keys, and finally
        flushes via ``self.write_to_logs()``.

        For ``self.ndim == 3`` the masks are indexed once more along their first axis
        and the middle slice (along the axis named ``"slice"``) of the target and the
        masked image is used.

        Parameters
        ----------
        data : dict
            Batch dictionary. Reads the keys ``"slice_no"``, ``"filename"``,
            ``"sampling_mask"``, ``"target"``, ``"masked_image"`` and, when present,
            ``"acs_mask"``.

        Raises
        ------
        NotImplementedError
            If ``self.ndim`` is larger than 3.
        """
        storage = get_event_storage()
        self.logger.info(
            f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
        )

        # TODO(jt): Cleaner, loop over types of images
        first_sampling_mask = data["sampling_mask"][0][0]
        first_acs_mask = data["acs_mask"][0][0] if "acs_mask" in data else None
        first_target = data["target"][0]
        first_masked_image = data["masked_image"][0]

        if self.ndim == 3:
            first_sampling_mask = first_sampling_mask[0]
            # Compare against None explicitly: the truth value of a multi-element
            # tensor is ambiguous and raises, and an all-zero mask would be falsy.
            if first_acs_mask is not None:
                first_acs_mask = first_acs_mask[0]

            num_slices = first_target.shape[first_target.names.index("slice")]
            middle_slice = num_slices // 2
            first_target = first_target[middle_slice]
            first_masked_image = first_masked_image[middle_slice]
        elif self.ndim > 3:
            raise NotImplementedError

        storage.add_image(
            "train/mask",
            first_sampling_mask[..., 0].rename(None).unsqueeze(0),
        )
        if first_acs_mask is not None:
            storage.add_image(
                "train/acs_mask",
                first_acs_mask[..., 0].rename(None).unsqueeze(0),
            )
        storage.add_image(
            "train/target",
            normalize_image(first_target.rename(None).unsqueeze(0)),
        )
        storage.add_image(
            "train/masked_image",
            normalize_image(
                transforms.modulus_if_complex(first_masked_image)
                .rename(None)
                .unsqueeze(0)
            ),
        )
        self.write_to_logs()
Example #3
0
    def training_loop(self,
                      data_loader: DataLoader,
                      start_iter: int,
                      validation_data_loader: Optional[DataLoader] = None):
        """Train the model from ``start_iter`` up to ``cfg.training.num_iterations``.

        Per iteration the loop: adds names to the batch, runs ``self._do_iteration``,
        steps the optimizer and the learning-rate scheduler, zeroes the gradients,
        reduces losses and metrics over devices and records them in the event storage.
        After the first few iterations (``iter_idx > 5``) it additionally validates,
        checkpoints and writes logs at their configured intervals and on the final
        iteration.

        Parameters
        ----------
        data_loader : DataLoader
            Source of training batches. The loop ends when either this loader or the
            iteration range is exhausted, whichever comes first.
        start_iter : int
            Iteration index to start counting from.
        validation_data_loader : Optional[DataLoader]
            When given, ``self.evaluate`` is run on it every
            ``cfg.training.validation_steps`` iterations and on the last iteration.

        Notes
        -----
        On ``ProcessKilledException`` a checkpoint is saved, the logs are written and
        the process exits through ``sys.exit``.
        """
        self.logger.info(f'Local rank: {communication.get_local_rank()}.')
        self.model.train()
        loss_fns = self.build_loss()
        metric_fns = self.build_metrics()
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        # zip() pairs every batch with its iteration index and stops at the shorter of the two.
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            # Only logged for iteration 0, so not when training starts from start_iter > 0.
            if iter_idx == 0:
                self.logger.info(
                    f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
                )
                storage.add_image('train/mask', data['sampling_mask'][0, ...,
                                                                      0])
                storage.add_image(
                    'train/target',
                    normalize_image(
                        data['target'][0].rename(None).unsqueeze(0)))
                storage.add_image(
                    'train/masked_image',
                    normalize_image(
                        transforms.modulus_if_complex(
                            data['masked_image'][0]).rename(None).unsqueeze(
                                0)))
                self.write_to_logs()

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f'{e}.')
                self.checkpointer.save(
                    iter_idx)  # Save checkpoint at kill. # noqa
                self.write_to_logs(
                )  # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                sys.exit(f'Exited with exception: {e}')

            # Gradient accumulation
            # NOTE: this branch is currently a no-op; the gradient scaling, clipping and
            # norm logging below are commented out, and the optimizer steps every iteration.
            if (iter_idx +
                    1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                # TODO: Is this slow? This is a generator, so should be cheap.
                pass
                # parameter_list = self.model.parameters() if not self.mixed_precision else \
                #     amp.master_params(self.__optimizer)
                # if self.cfg.training.gradient_steps > 1:  # type: ignore
                #     for parameter in parameter_list:
                #         if parameter.grad is not None:
                #             # In-place division
                #             parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                # if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                #     torch.nn.utils.clip_grad_norm_(parameter_list, self.cfg.training.gradient_clipping)  # type: ignore
                #
                # # Gradient norm
                # if self.cfg.training.gradient_debug:  # type: ignore
                #     parameters = list(filter(lambda p: p.grad is not None, parameter_list))
                #     gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt()  # typing: ignore
                #     storage.add_scalar('gradient_norm', gradient_norm)

            self.__optimizer.step()  # type: ignore
            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            # Record the learning rate of the first parameter group only.
            storage.add_scalar('lr',
                               self.__optimizer.param_groups[0]['lr'],
                               smoothing_hint=False)

            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            # The total loss is the sum of all individual reduced loss terms.
            loss_reduced = sum(loss_dict_reduced.values())

            # Training metrics are computed on the detached output against the target.
            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data['target'].rename(None).detach().to(self.device),
                reduction='mean')
            # Skip the cross-device reduction when no metrics were produced.
            metrics_dict_reduced = communication.reduce_tensor_dict(
                metrics_dict) if metrics_dict else {}
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            # Validate every `validation_steps` iterations and on the final iteration,
            # but never during the first six iterations (iter_idx <= 5).
            if validation_data_loader is not None \
                    and iter_idx > 5\
                    and (iter_idx % self.cfg.training.validation_steps == 0 or (iter_idx + 1) == total_iter):
                val_loss_dict = self.evaluate(
                    validation_data_loader,
                    loss_fns,
                    evaluation_round=iter_idx //
                    self.cfg.training.validation_steps - 1)
                self.logger.info(f'Done evaluation at iteration {iter_idx}.')
                storage.add_scalars(**prefix_dict_keys(val_loss_dict, 'val_'),
                                    smoothing_hint=False)
                # Put the model back in training mode after evaluation.
                self.model.train()

            # Checkpoint every `checkpoint_steps` iterations and on the final iteration.
            if iter_idx > 5 and\
                    (iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0 or (iter_idx + 1) == total_iter):
                self.logger.info(f'Checkpointing at iteration: {iter_idx}.')
                self.checkpointer.save(iter_idx)

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 (iter_idx + 1) == total_iter):
                self.write_to_logs()

            storage.step()
Example #4
0
    def training_loop(
        self,
        data_loader: DataLoader,
        start_iter: int,
        validation_data_loaders: Optional[List[DataLoader]] = None,
        experiment_directory: Optional[pathlib.Path] = None,
    ):
        """Train the model from ``start_iter`` up to ``cfg.training.num_iterations``.

        Per iteration the loop: adds names to the batch, runs ``self._do_iteration``,
        performs an optimizer step through the gradient scaler every
        ``cfg.training.gradient_steps`` iterations (optionally with gradient clipping
        and gradient-norm logging), steps the learning-rate scheduler, zeroes the
        gradients, and records reduced losses and metrics in the event storage. After
        the first few iterations (``iter_idx > 5``) it checkpoints, validates and
        writes logs at their configured intervals and on the final iteration.

        Parameters
        ----------
        data_loader : DataLoader
            Source of training batches. The loop ends when either this loader or the
            iteration range is exhausted, whichever comes first.
        start_iter : int
            Iteration index to start counting from.
        validation_data_loaders : optional
            Iterated as ``(dataset_name, data_loader)`` pairs; each loader is passed to
            ``self.evaluate`` at every validation step.
        experiment_directory : Optional[pathlib.Path]
            When given, per-case validation metrics are written there as
            ``metrics_val_<dataset_name>_<iter_idx>.json``.

        Notes
        -----
        On ``ProcessKilledException`` a checkpoint is saved, the logs are written and
        the process exits with ``sys.exit(-1)``.
        """
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        # zip() pairs every batch with its iteration index and stops at the shorter of the two.
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == start_iter:
                self.ndim = self.compute_dimensionality_from_sample(data)
                self.logger.info(f"Data dimensionality: {self.ndim}.")

            # Only logged for iteration 0, so not when training starts from start_iter > 0.
            if iter_idx == 0:
                self.log_first_training_example(data)

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpointer.save(iter_idx)  # Save checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(-1)

            # Gradient accumulation
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.cfg.training.gradient_clipping,
                    )

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        f"Gradient debug set. This will affect training performance. Only use for debugging."
                        f"This message will only be displayed once."
                    )
                    gradients = [
                        parameter.grad.detach()
                        for parameter in self.model.parameters()
                        if parameter.grad is not None
                    ]
                    # Reduce every gradient to a scalar before summing: parameters have
                    # different shapes, so their squared gradients cannot be added
                    # element-wise. Nothing is logged when no gradient exists.
                    if gradients:
                        gradient_norm = torch.sqrt(
                            sum(gradient.pow(2).sum() for gradient in gradients)
                        )
                        storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar(
                "lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False
            )

            # NOTE(review): gradients are zeroed on every iteration while the optimizer
            # only steps every `gradient_steps` iterations. If `_do_iteration` runs the
            # backward pass, gradients of earlier micro-batches are discarded before the
            # step - confirm this is intended.
            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            # Skip the cross-device reduction when no metrics were produced.
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
            )
            storage.add_scalars(
                loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced
            )

            # Checkpoint every `checkpoint_steps` iterations and on the final iteration,
            # but never during the first six iterations (iter_idx <= 5).
            if iter_idx > 5 and (
                iter_idx % self.cfg.training.checkpointer.checkpoint_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

            if (
                validation_data_loaders is not None
                and iter_idx > 5
                and (
                    iter_idx % self.cfg.training.validation_steps == 0
                    or (iter_idx + 1) == total_iter
                )
            ):
                for (
                    curr_dataset_name,
                    curr_validation_data_loader,
                ) in validation_data_loaders:
                    self.logger.info(
                        f"Evaluating: {curr_dataset_name}..."
                    )  # TODO(jt): Fix with better names and stuff.
                    (
                        curr_val_loss_dict,
                        curr_val_metric_dict_per_case,
                        visualize_slices,
                        visualize_target,
                    ) = self.evaluate(
                        curr_validation_data_loader,
                        loss_fns,
                        is_validation_process=True,
                    )

                    if experiment_directory:
                        # Make dictionary serializable for logging
                        serializable_val_metric_dict = {
                            k0: {k1: float(v1) for k1, v1 in v0.items()}
                            for k0, v0 in curr_val_metric_dict_per_case.items()
                        }
                        write_json(
                            experiment_directory
                            / f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                            serializable_val_metric_dict,
                        )

                    # Metric dict still needs to be reduced as it gives values *per* data
                    curr_val_metric_dict = reduce_list_of_dicts(
                        list(curr_val_metric_dict_per_case.values()),
                        mode="average",
                    )

                    key_prefix = (
                        "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
                    )
                    val_loss_reduced = sum(curr_val_loss_dict.values())
                    storage.add_scalars(
                        **{key_prefix + "loss": val_loss_reduced},
                        **{
                            **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                            **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                        },
                        smoothing_hint=False,
                    )
                    visualize_slices = self.process_slices_for_visualization(
                        visualize_slices, visualize_target
                    )
                    storage.add_image(f"{key_prefix}prediction", visualize_slices)

                    # The target grid is only stored for the first validation round.
                    if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                        visualize_target = make_grid(
                            crop_to_largest(visualize_target, pad_value=0),
                            nrow=self.cfg.tensorboard.num_images,
                            scale_each=True,
                        )
                        storage.add_image(f"{key_prefix}target", visualize_target)

                    # Logged per dataset, inside the loop, so the name is always bound
                    # (also for an empty list) and every dataset is reported.
                    self.logger.info(
                        f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
                    )
                # Put the model back in training mode after evaluation.
                self.model.train()

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (
                iter_idx % 20 == 0
                or iter_idx % self.cfg.training.validation_steps == 0
                or (iter_idx + 1) == total_iter
            ):
                self.write_to_logs()

            storage.step()
Example #5
0
    def training_loop(
        self,
        training_datasets: List,  # TODO(jt): Improve typing
        start_iter: int,
        validation_datasets: Optional[List] = None,
        experiment_directory: Optional[pathlib.Path] = None,
        num_workers: int = 6,
        start_with_validation: bool = False,
    ):
        """Train the model from ``start_iter`` up to ``cfg.training.num_iterations``.

        Concatenates ``training_datasets``, builds a random batch sampler and a data
        loader for them, and then iterates. Per iteration the loop: adds names to the
        batch, runs ``self._do_iteration``, performs an optimizer step through the
        gradient scaler every ``cfg.training.gradient_steps`` iterations (optionally
        with gradient clipping and gradient-norm logging), steps the learning-rate
        scheduler, zeroes the gradients, records reduced losses and metrics in the
        event storage, and finally delegates checkpointing, log writing and
        validation to the ``*_at_interval`` helpers.

        Parameters
        ----------
        training_datasets : List
            Datasets to train on; ``self.ndim`` is taken from the first one.
        start_iter : int
            Iteration index to start counting from.
        validation_datasets : Optional[List]
            Passed on to ``self.validation_loop``.
        experiment_directory : Optional[pathlib.Path]
            Passed on to ``self.validation_loop``.
        num_workers : int
            Number of workers for the training loader and the validation loop.
        start_with_validation : bool
            When set, validation is run once before the first training iteration.

        Raises
        ------
        TrainingException
            After a fourth consecutive out-of-memory error (three batches were
            skipped without a successful iteration in between).
        RuntimeError
            Any other ``RuntimeError`` raised by ``self._do_iteration`` is re-raised
            after checkpointing.

        Notes
        -----
        On ``ProcessKilledException`` or ``TrainingException`` from
        ``self._do_iteration`` a checkpoint and the logs are written and the process
        exits with ``sys.exit(-1)``.
        """
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        regularizer_fns = self.build_regularizers(self.cfg.training.regularizers)
        storage = get_event_storage()

        self.ndim = training_datasets[0].ndim
        self.logger.info(f"Data dimensionality: {self.ndim}.")

        training_data = ConcatDataset(training_datasets)

        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )

        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random"
        )
        data_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        # Convenient shorthand
        validation_func = functools.partial(
            self.validation_loop,
            validation_datasets,
            loss_fns,
            experiment_directory,
            num_workers=num_workers,
        )

        total_iter = self.cfg.training.num_iterations  # noqa
        # Number of consecutive out-of-memory failures; reset after a successful iteration.
        fail_counter = 0
        # zip() pairs every batch with its iteration index and stops at the shorter of the two.
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            # Only logged for iteration 0, so not when training starts from start_iter > 0.
            if iter_idx == 0:
                self.log_first_training_example_and_model(data)

            if start_with_validation and iter_idx == start_iter:
                self.logger.info(f"Starting with validation at iteration: {iter_idx}.")
                validation_func(iter_idx)
            try:
                iteration_output = self._do_iteration(
                    data, loss_fns, regularizer_fns=regularizer_fns
                )
                output = iteration_output.output_image
                loss_dict = iteration_output.data_dict
            except (ProcessKilledException, TrainingException) as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpoint_and_write_to_logs(iter_idx)
                sys.exit(-1)
            except RuntimeError as e:
                # Maybe string can change
                if "out of memory" in str(e):
                    if fail_counter == 3:
                        self.checkpoint_and_write_to_logs(iter_idx)
                        raise TrainingException(
                            f"OOM, could not recover after 3 tries: {e}."
                        ) from e
                    fail_counter += 1
                    self.logger.info(
                        f"OOM Error: {e}. Skipping batch. Retry {fail_counter}/3."
                    )
                    # Drop partial gradients and free cached memory before the next batch.
                    self.__optimizer.zero_grad()
                    gc.collect()
                    torch.cuda.empty_cache()
                    continue
                self.checkpoint_and_write_to_logs(iter_idx)
                self.logger.info(f"Cannot recover from exception {e}. Exiting.")
                # Re-raise the original exception so its concrete type and traceback
                # are kept, not replaced by a new RuntimeError wrapper.
                raise

            if fail_counter > 0:
                self.logger.info("Recovered from OOM, skipped batch.")
            fail_counter = 0
            # Gradient accumulation
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.training.gradient_clipping
                    )

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        f"Gradient debug set. This will affect training performance. Only use for debugging."
                        f"This message will only be displayed once."
                    )
                    gradients = [
                        parameter.grad.detach()
                        for parameter in self.model.parameters()
                        if parameter.grad is not None
                    ]
                    # Reduce every gradient to a scalar before summing: parameters have
                    # different shapes, so their squared gradients cannot be added
                    # element-wise. Nothing is logged when no gradient exists.
                    if gradients:
                        gradient_norm = torch.sqrt(
                            sum(gradient.pow(2).sum() for gradient in gradients)
                        )
                        storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar(
                "lr", self.__optimizer.param_groups[0]["lr"], smoothing_hint=False
            )

            # NOTE(review): gradients are zeroed on every iteration while the optimizer
            # only steps every `gradient_steps` iterations. If `_do_iteration` runs the
            # backward pass, gradients of earlier micro-batches are discarded before the
            # step - confirm this is intended.
            self.__optimizer.zero_grad()  # type: ignore

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                T.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            # Skip the cross-device reduction when no metrics were produced.
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict) if metrics_dict else {}
            )
            storage.add_scalars(
                loss=loss_reduced, **loss_dict_reduced, **metrics_dict_reduced
            )

            self.checkpoint_model_at_interval(iter_idx, total_iter)
            self.write_to_logs_at_interval(iter_idx, total_iter)
            self.validate_model_at_interval(validation_func, iter_idx, total_iter)

            storage.step()