Example #1
def setup_train(
    run_name,
    training_root,
    validation_root,
    base_directory,
    cfg_filename,
    force_validation,
    initialization_checkpoint,
    initial_images,
    initial_kspace,
    noise,
    device,
    num_workers,
    resume,
    machine_rank,
    mixed_precision,
    debug,
):

    env = setup_training_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )

    if initial_kspace is not None and initial_images is not None:
        raise ValueError(f"Cannot both provide initial kspace or initial images.")

    pass_dictionaries = {}
    if noise is not None:
        if not env.cfg.physics.use_noise_matrix:
            raise ValueError(
                "cfg.physics.use_noise_matrix is null, yet noise files were passed on the command line."
            )

        noise_dicts = [read_json(fn) for fn in noise]
        pass_dictionaries["loglikelihood_scaling"] = [
            parse_noise_dict(
                noise_dict,
                percentile=0.999,
                multiplier=env.cfg.physics.noise_matrix_scaling,
            )
            for noise_dict in noise_dicts
        ]

    # Create training and validation data
    # Transforms configuration
    # TODO: More ** passing...

    training_datasets = build_training_datasets_from_environment(
        env=env,
        datasets_config=env.cfg.training.datasets,
        lists_root=cfg_filename.parents[0],
        data_root=training_root,
        initial_images=None if initial_images is None else initial_images[0],
        initial_kspaces=None if initial_kspace is None else initial_kspace[0],
        pass_text_description=False,
        pass_dictionaries=pass_dictionaries,
    )
    training_data_sizes = [len(dataset) for dataset in training_datasets]
    logger.info(
        f"Training data sizes: {training_data_sizes} (sum={sum(training_data_sizes)})."
    )

    if validation_root:
        validation_data = build_training_datasets_from_environment(
            env=env,
            datasets_config=env.cfg.validation.datasets,
            lists_root=cfg_filename.parents[0],
            data_root=validation_root,
            initial_images=None if initial_images is None else initial_images[1],
            initial_kspaces=None if initial_kspace is None else initial_kspace[1],
            pass_text_description=True,
        )
    else:
        logger.info(f"No validation data.")
        validation_data = None

    # Create the optimizers
    logger.info("Building optimizers.")
    optimizer_params = [{"params": env.engine.model.parameters()}]
    for curr_model_name in env.engine.models:
        # TODO(jt): Can get learning rate from the config per additional model too.
        curr_learning_rate = env.cfg.training.lr
        logger.info(
            f"Adding model parameters of {curr_model_name} with learning rate {curr_learning_rate}."
        )
        optimizer_params.append(
            {
                "params": env.engine.models[curr_model_name].parameters(),
                "lr": curr_learning_rate,
            }
        )

    optimizer: torch.optim.Optimizer = str_to_class(
        "torch.optim", env.cfg.training.optimizer
    )(  # noqa
        optimizer_params,
        lr=env.cfg.training.lr,
        weight_decay=env.cfg.training.weight_decay,
    )  # noqa

    # Build the LR scheduler. We use a fixed step-size LR schedule, not an adaptive one.
    solver_steps = list(
        range(
            env.cfg.training.lr_step_size,
            env.cfg.training.num_iterations,
            env.cfg.training.lr_step_size,
        )
    )
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        solver_steps,
        env.cfg.training.lr_gamma,
        warmup_factor=1 / 3.0,
        warmup_iterations=env.cfg.training.lr_warmup_iter,
        warmup_method="linear",
    )

    # Just to make sure.
    torch.cuda.empty_cache()

    env.engine.train(
        optimizer,
        lr_scheduler,
        training_datasets,
        env.experiment_dir,
        validation_datasets=validation_data,
        resume=resume,
        initialization=initialization_checkpoint,
        start_with_validation=force_validation,
        num_workers=num_workers,
    )
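
The function above bundles environment setup, dataset construction, optimizer/scheduler creation, and the training call. The sketch below is a minimal, hypothetical command-line wrapper around it; the flag names and defaults are illustrative assumptions, not the project's actual CLI.

# Hypothetical CLI wrapper for setup_train -- flag names and defaults are
# illustrative assumptions, not the project's actual command-line interface.
import argparse
import pathlib


def main():
    parser = argparse.ArgumentParser(description="Launch a training run.")
    parser.add_argument("run_name")
    parser.add_argument("training_root", type=pathlib.Path)
    parser.add_argument("cfg_filename", type=pathlib.Path)
    parser.add_argument("--validation-root", type=pathlib.Path, default=None)
    parser.add_argument("--base-directory", type=pathlib.Path, default=pathlib.Path("."))
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--mixed-precision", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    setup_train(
        args.run_name,
        args.training_root,
        args.validation_root,
        args.base_directory,
        args.cfg_filename,
        force_validation=False,
        initialization_checkpoint=None,
        initial_images=None,
        initial_kspace=None,
        noise=None,
        device=args.device,
        num_workers=args.num_workers,
        resume=args.resume,
        machine_rank=0,
        mixed_precision=args.mixed_precision,
        debug=args.debug,
    )


if __name__ == "__main__":
    main()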
Example #2
    def __init__(self,
                 root: pathlib.Path,
                 dataset_description: Optional[PathOrString] = None,
                 metadata: Optional[Dict[PathOrString, Dict]] = None,
                 sensitivity_maps: Optional[PathOrString] = None,
                 extra_keys: Optional[Tuple] = None) -> None:
        """
        Initialize the dataset. The dataset can remove spike noise and empty slices.

        Parameters
        ----------
        root : pathlib.Path
            Root directory to data.
        dataset_description : pathlib.Path or str, optional
            Path to a JSON file mapping each filename to its metadata (including `num_slices`).
            If None, the root directory is parsed for h5 files instead.
        metadata : dict
            If given, this dictionary will be passed to the output transform.
        sensitivity_maps : [pathlib.Path, None]
            Path to sensitivity maps, or None.
        extra_keys : Tuple
            Add extra keys in h5 file to output.
        """
        self.logger = logging.getLogger(type(self).__name__)

        self.root = pathlib.Path(root)

        self.metadata = metadata

        self.dataset_description = dataset_description
        self.data = []

        self.volume_indices = OrderedDict()
        current_slice_number = 0  # This is required to keep track of where a volume is in the dataset
        if isinstance(dataset_description, (pathlib.Path, str)):
            examples = read_json(dataset_description)
            for filename in examples:
                num_slices = examples[filename]['num_slices']
                # ignore_slices = examples[filename].get('ignore_slices', [])
                # TODO: Slices can, and should be able to be ignored (for instance too many empty ones)
                ignore_slices = []
                for idx in range(num_slices):
                    if idx not in ignore_slices:
                        self.data.append((filename, idx))
                self.volume_indices[filename] = range(
                    current_slice_number, current_slice_number + num_slices)
                current_slice_number += num_slices

        elif not dataset_description:
            self.logger.info(
                f'No dataset description given, parsing directory {self.root} for h5 files. '
                f'It is recommended you create such a file, as this will speed up processing.'
            )
            filenames = list(self.root.glob('*.h5'))
            self.logger.info(
                f'Using {len(filenames)} h5 files in {self.root}.')

            for idx, filename in enumerate(filenames):
                # Log progress roughly every 20% and on the final file.
                if (len(filenames) < 5 or idx % (len(filenames) // 5) == 0
                        or len(filenames) == idx + 1):
                    self.logger.info(
                        f'Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.')
                with h5py.File(filename, 'r') as f:
                    num_slices = f['kspace'].shape[0]
                self.data += [(filename, idx) for idx in range(num_slices)]
                self.volume_indices[filename] = range(
                    current_slice_number, current_slice_number + num_slices)
                current_slice_number += num_slices
        else:
            raise ValueError(
                f'Expected `Path` or `str` for `dataset_description`, got {type(dataset_description)}'
            )

        self.sensitivity_maps = cast_as_path(sensitivity_maps)
        self.extra_keys = extra_keys
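
After construction, `self.data` holds one `(filename, slice_idx)` tuple per slice and `self.volume_indices` maps each volume to its contiguous range of global indices. The standalone sketch below, with made-up volume names and slice counts, reproduces the same bookkeeping and shows how a global index maps back to its volume.

# Standalone illustration of the (filename, slice) bookkeeping used above.
# The volume names and slice counts are made up for the example.
from collections import OrderedDict

volumes = {"vol_a.h5": 3, "vol_b.h5": 2}  # filename -> num_slices

data = []
volume_indices = OrderedDict()
current_slice_number = 0
for filename, num_slices in volumes.items():
    data.extend((filename, idx) for idx in range(num_slices))
    volume_indices[filename] = range(current_slice_number,
                                     current_slice_number + num_slices)
    current_slice_number += num_slices

# Map a global dataset index back to its volume.
global_idx = 4
volume = next(name for name, rng in volume_indices.items() if global_idx in rng)
assert volume == "vol_b.h5" and data[global_idx] == ("vol_b.h5", 1)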
Example #3
    def __init__(
        self,
        root: pathlib.Path,
        filenames_filter: Optional[List[PathOrString]] = None,
        dataset_description: Optional[PathOrString] = None,
        metadata: Optional[Dict[PathOrString, Dict]] = None,
        sensitivity_maps: Optional[PathOrString] = None,
        extra_keys: Optional[Tuple] = None,
        text_description: Optional[str] = None,
        kspace_context: Optional[int] = None,
    ) -> None:
        """
        Initialize the dataset. The dataset can remove spike noise and empty slices.

        Parameters
        ----------
        root : pathlib.Path
            Root directory to data.
        filenames_filter : List
            List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob
            on the root. If set, will skip searching for files in the root.
        dataset_description : pathlib.Path or str, optional
            Path to a JSON file mapping each filename to its metadata (including `num_slices`).
            If None, the root directory (or `filenames_filter`) is parsed for h5 files instead.
        metadata : dict
            If given, this dictionary will be passed to the output transform.
        sensitivity_maps : [pathlib.Path, None]
            Path to sensitivity maps, or None.
        extra_keys : Tuple
            Add extra keys in h5 file to output.
        text_description : str
            Description of dataset, can be useful for logging.
        kspace_context : int, optional
            Slice context for the k-space data; stored as 0 when not given.
        """
        self.logger = logging.getLogger(type(self).__name__)

        self.root = pathlib.Path(root)
        self.filenames_filter = filenames_filter

        self.metadata = metadata

        self.dataset_description = dataset_description
        self.text_description = text_description
        self.data = []

        self.volume_indices = OrderedDict()
        current_slice_number = 0  # This is required to keep track of where a volume is in the dataset
        if isinstance(dataset_description, (pathlib.Path, str)):
            warnings.warn(f"Untested functionality.")
            # TODO(jt): This is untested. Maybe this can even be removed, loading from SSD is very fast even for large
            # TODO(jt): datasets.
            examples = read_json(dataset_description)
            filtered_examples = 0
            for filename in examples:
                # Skip volumes that are not included in the filenames filter.
                if self.filenames_filter and filename not in self.filenames_filter:
                    filtered_examples += 1
                    continue

                num_slices = examples[filename]["num_slices"]
                # ignore_slices = examples[filename].get('ignore_slices', [])
                # TODO: Slices can, and should be able to be ignored (for instance too many empty ones)
                ignore_slices = []
                for idx in range(num_slices):
                    if idx not in ignore_slices:
                        self.data.append((filename, idx))
                self.volume_indices[filename] = range(
                    current_slice_number, current_slice_number + num_slices)
                current_slice_number += num_slices
            if filtered_examples > 0:
                self.logger.info(
                    f"Included {len(self.volume_indices)} volumes, skipped {filtered_examples}."
                )

        elif not dataset_description:
            if self.filenames_filter:
                self.logger.info(
                    f"Attempting to load {len(filenames_filter)} filenames from list."
                )
                filenames = filenames_filter
            else:
                self.logger.info(
                    f"No dataset description given, parsing directory {self.root} for h5 files. "
                    f"It is recommended you create such a file, as this will speed up processing."
                )
                filenames = list(self.root.glob("*.h5"))
            self.logger.info(
                f"Using {len(filenames)} h5 files in {self.root}.")

            for idx, filename in enumerate(filenames):
                if (len(filenames) < 5 or idx % (len(filenames) // 5) == 0
                        or len(filenames) == (idx + 1)):
                    self.logger.info(
                        f"Parsing: {(idx + 1) / len(filenames) * 100:.2f}%.")
                try:
                    with h5py.File(filename, "r") as f:
                        num_slices = f["kspace"].shape[0]
                except OSError as e:
                    self.logger.warning(
                        f"{filename} failed with OSError: {e}. Skipping...")
                    continue

                self.data += [(filename, idx) for idx in range(num_slices)]
                self.volume_indices[filename] = range(
                    current_slice_number, current_slice_number + num_slices)
                current_slice_number += num_slices
        else:
            raise ValueError(
                f"Expected `Path` or `str` for `dataset_description`, got {type(dataset_description)}"
            )
        self.sensitivity_maps = cast_as_path(sensitivity_maps)
        self.extra_keys = extra_keys

        self.kspace_context = kspace_context if kspace_context else 0

        if self.text_description:
            self.logger.info(f"Dataset description: {self.text_description}.")