コード例 #1
0
ファイル: launch.py プロジェクト: rmsouza01/direct
def _distributed_worker(local_rank, main_func, world_size,
                        num_gpus_per_machine, machine_rank, dist_url, args):
    """Entry point of one spawned worker process in a distributed run.

    Joins the NCCL process group, binds the process to its local GPU, builds the
    per-machine ("local") process groups and finally calls ``main_func(*args)``.

    Args:
        local_rank: Index of this process on its machine; also used as CUDA device index.
        main_func: Callable executed once the distributed setup is complete.
        world_size: Total number of processes over all machines.
        num_gpus_per_machine: Number of processes (GPUs) launched on each machine.
        machine_rank: Index of this machine.
        dist_url: Passed as ``init_method`` to ``torch.distributed.init_process_group``.
        args: Positional arguments forwarded to ``main_func``.

    Raises:
        Exception: Whatever ``init_process_group`` raised, re-raised after logging the URL.
        AssertionError: If more processes than visible CUDA devices are requested, or if
            the local process group has already been set.
    """
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger = logging.getLogger(__name__)
    try:
        dist.init_process_group(backend='NCCL',
                                init_method=dist_url,
                                world_size=world_size,
                                rank=global_rank)
    except Exception:
        logger.error(f'Process group URL: {dist_url}')
        raise

    # Bind this process to its GPU *before* the first collective. An NCCL barrier issued
    # before `set_device` does not know which device this process uses and has to guess
    # (PyTorch warns about a "best-guess GPU"); a wrong guess can create a CUDA context on
    # another GPU or hang when global ranks do not map onto local devices that way.
    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    communication.synchronize()
    logger.info(f'Global rank {global_rank}.')
    logger.info('Synchronized GPUs.')

    # Setup the local process group (which contains ranks within the same machine).
    # `dist.new_group` must be entered by every rank, in the same order, even for groups
    # the rank is not a member of, hence the loop over all machines.
    assert communication._LOCAL_PROCESS_GROUP is None  # noqa
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            communication._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
コード例 #2
0
def setup_training_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    """Build the common environment and store its configuration for a training run.

    All arguments are forwarded unchanged to ``setup_common_environment``. Afterwards the
    main process writes the resulting configuration as YAML to
    ``<experiment_dir>/config.yaml``, and every process then calls
    ``communication.synchronize()``.

    Returns:
        The environment object produced by ``setup_common_environment``.
    """
    environment = setup_common_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )

    config_path = environment.experiment_dir / "config.yaml"
    logger.info(f"Writing configuration file to: {config_path}.")
    # A single writer avoids several processes racing on the same file.
    if communication.is_main_process():
        with open(config_path, "w") as handle:
            handle.write(OmegaConf.to_yaml(environment.cfg))
    # All processes wait here so nobody continues before the file is written.
    communication.synchronize()

    return environment
コード例 #3
0
def setup_training_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    """Create the experiment directory, build the training configuration and environment.

    The YAML file ``cfg_filename`` is loaded, its model entries are resolved with
    ``load_models_into_environment_config`` and its training/validation dataset entries
    with ``load_dataset_config``. The result is handed to ``setup_common_environment``.
    The final configuration is written to ``<base_directory>/<run_name>/config.yaml``
    when that file does not exist yet.

    Returns:
        A namedtuple ``environment(cfg, experiment_dir, forward_operator,
        backward_operator, engine)``.
    """
    experiment_dir = base_directory / run_name
    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory
        # This is required in the logging below
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)
    base_cfg, models = load_models_into_environment_config(cfg_from_file)

    # Setup everything for training
    base_cfg.training = TrainingConfig
    # Parse the proper specific config for the datasets:
    base_cfg.training.datasets = [
        load_dataset_config(dataset) for dataset in base_cfg.training.datasets
    ]
    base_cfg.validation.datasets = [
        load_dataset_config(dataset)
        for dataset in base_cfg.validation.datasets
    ]

    # Make configuration read only.
    # TODO(jt): Does not work when indexing config lists.
    # OmegaConf.set_readonly(cfg, True)

    forward_operator, backward_operator, engine, cfg = setup_common_environment(
        base_cfg,
        cfg_from_file,
        models,
        device,
        machine_rank,
        experiment_dir,
        run_name,
        cfg_filename,
        mixed_precision,
        debug,
    )

    config_file_in_project_folder = experiment_dir / "config.yaml"
    # Every rank checks for the file *before* any rank may create it; the barrier
    # guarantees that rank 0 cannot write it in between and make ranks disagree.
    config_file_existed = config_file_in_project_folder.exists()
    communication.synchronize()
    if not config_file_existed and communication.get_local_rank() == 0:
        with open(config_file_in_project_folder, "w") as f:
            f.write(OmegaConf.to_yaml(cfg))
    # Called unconditionally: a barrier skipped by only some ranks deadlocks the others.
    communication.synchronize()
    # TODO: An existing config.yaml that differs from `cfg` is currently accepted silently.
    # Consider raising "This project folder exists and has a config.yaml, yet this does not
    # match with the one the model was built with." once that check is wanted again.

    environment = namedtuple(
        "environment",
        [
            "cfg", "experiment_dir", "forward_operator", "backward_operator",
            "engine"
        ],
    )
    return environment(cfg, experiment_dir, forward_operator,
                       backward_operator, engine)
コード例 #4
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        regularizer_fns: Optional[Dict[str, Callable]] = None,
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Run the model over ``data_loader`` and aggregate the output per volume.

        The loader must yield the slices of each volume consecutively (this requires
        ``direct.data.sampler.DistributedSequentialSampler``) and every batch may only contain slices of a single
        file. Whenever the filename changes, and after the very last slice, the collected slices are stacked into a
        volume; on validation processes the metrics are then computed on that volume.

        Args:
            data_loader: Loader over the evaluation data.
            loss_fns: Loss functions forwarded to ``self._do_iteration``.
            regularizer_fns: Regularizer functions forwarded to ``self._do_iteration``.
            crop: Unused; cropping is controlled by ``self.cfg.validation.crop``.
            is_validation_process: If True, targets are processed and per-volume metrics are computed.

        Returns:
            ``(loss_dict, reconstruction_output)`` when ``is_validation_process`` is False, otherwise
            ``(loss_dict, all_gathered_metrics, visualize_slices, visualize_target)``.

        Raises:
            ValueError: If a batch contains slices of more than one file.
        """
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        # filenames can be in the volume_indices attribute of the dataset. Both names are read when logging progress,
        # so they must be bound on every path.
        all_filenames = None
        num_for_this_process = None
        if hasattr(data_loader.dataset, "volume_indices"):
            all_filenames = list(data_loader.dataset.volume_indices.keys())
            num_for_this_process = len(
                list(data_loader.batch_sampler.sampler.volume_indices.keys()))
            self.logger.info(
                f"Reconstructing a total of {len(all_filenames)} volumes. "
                f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
            )
        filenames_seen = 0

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        extra_visualization_keys = (self.cfg.logging.log_as_image
                                    if self.cfg.logging.log_as_image else [])

        num_batches = len(data_loader)
        time_start = time.time()

        def _finish_volume(volume_filename, batch_number):
            # Stack all collected slices of `volume_filename` into a volume, compute its metrics (validation only)
            # and log progress. Validation processes free the collected slices afterwards.
            nonlocal reconstruction_output, targets_output, filenames_seen, time_start
            filenames_seen += 1
            # TODO: Stack does not support named tensors.
            volume = torch.stack([
                _[1].rename(None)
                for _ in reconstruction_output[volume_filename]
            ])
            if is_validation_process:
                target = torch.stack([
                    _[1].rename(None)
                    for _ in targets_output[volume_filename]
                ])
                curr_metrics = {
                    metric_name: metric_fn(target, volume)
                    for metric_name, metric_fn in volume_metrics.items()
                }
                val_volume_metrics[volume_filename] = curr_metrics
                # Log the center slice of the volume
                if len(visualize_slices) < self.cfg.logging.tensorboard.num_images:
                    visualize_slices.append(volume[volume.shape[0] // 2])
                    visualize_target.append(target[target.shape[0] // 2])

                # Recreate the dictionaries to free memory. Not done when not in validation,
                # as there we are actually interested in the output.
                reconstruction_output = defaultdict(list)
                targets_output = defaultdict(list)

            if all_filenames:
                log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
            else:
                log_prefix = f"{batch_number} of {len(data_loader)} slices reconstructed:"

            self.logger.info(
                f"{log_prefix} {volume_filename}"
                f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
            )
            # restart timer
            time_start = time.time()

        for iter_idx, data in enumerate(data_loader):
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data["scaling_factor"]

            resolution = self.compute_resolution(
                key=self.cfg.validation.crop,
                reconstruction_size=data.get("reconstruction_size", None),
            )

            # Compute output and loss.
            iteration_output = self._do_iteration(
                data, loss_fns, regularizer_fns=regularizer_fns)
            output = iteration_output.output_image
            loss_dict = iteration_output.data_dict

            loss_dict = detach_dict(loss_dict)
            output = output.detach()
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names()),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].detach().refine_names(*self.real_names()),
                    scaling_factors,
                    resolution=resolution,
                )
                for key in extra_visualization_keys:
                    curr_data = data[key].detach()  # noqa: F841  (placeholder, not used yet)
                    # Here we need to discover which keys are actually normalized or not
                    # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            is_last_batch = iter_idx + 1 == num_batches
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = filename  # First iteration last_filename is not set.

                # The sampling is linear, so a new filename means the previous volume is complete.
                if filename != last_filename:
                    _finish_volume(last_filename, iter_idx + 1)
                    last_filename = filename

                curr_slice = output_abs[idx].detach()
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, curr_slice.cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

                # The very last slice completes the final volume. This must run *after* the slice has been stored,
                # otherwise it would be missing from the volume and its metrics.
                if is_last_batch and idx + 1 == len(filenames):
                    _finish_volume(filename, iter_idx + 1)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO: Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))
        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO: Apply named tuples where applicable
        # TODO: Several functions have multiple output values, in many cases
        # TODO: it would be more convenient to convert this to namedtuples.
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
コード例 #5
0
ファイル: rim_engine.py プロジェクト: jonasteuwen/direct
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Run the model over ``data_loader`` and aggregate the output per volume.

        The loader must yield the slices of each volume consecutively (this requires
        ``direct.data.sampler.DistributedSequentialSampler``) and every batch may only contain slices of a single
        file. Whenever the filename changes, and after the very last slice, the collected slices are stacked into a
        volume; on validation processes the metrics are then computed on that volume.

        Args:
            data_loader: Loader over the evaluation data.
            loss_fns: Loss functions forwarded to ``self._do_iteration``.
            crop: Unused; cropping is controlled by ``self.cfg.validation.crop``.
            is_validation_process: If True, targets are processed and per-volume metrics are computed.

        Returns:
            ``(loss_dict, reconstruction_output)`` when ``is_validation_process`` is False, otherwise
            ``(loss_dict, all_gathered_metrics, visualize_slices, visualize_target)``.

        Raises:
            ValueError: If a batch contains slices of more than one file, or ``self.cfg.validation.crop`` is not
                ``"header"``, ``"training"`` or empty.
        """
        # TODO(jt): Also log other models output (e.g. sensitivity map).
        # TODO(jt): This can be simplified as the sampler now only outputs batches belonging to the same volume.
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        def _finish_volume(volume_filename):
            # Stack all collected slices of `volume_filename` into a volume and, on validation processes, compute
            # its metrics and free the collected slices.
            nonlocal reconstruction_output, targets_output
            # TODO: Stack does not support named tensors.
            volume = torch.stack([
                _[1].rename(None)
                for _ in reconstruction_output[volume_filename]
            ])
            self.logger.info(
                f"Reconstructed {volume_filename} (shape = {list(volume.shape)})."
            )
            if is_validation_process:
                target = torch.stack([
                    _[1].rename(None)
                    for _ in targets_output[volume_filename]
                ])
                curr_metrics = {
                    metric_name: metric_fn(volume, target)
                    for metric_name, metric_fn in volume_metrics.items()
                }
                val_volume_metrics[volume_filename] = curr_metrics
                # Log the center slice of the volume
                if len(visualize_slices) < self.cfg.tensorboard.num_images:
                    visualize_slices.append(
                        normalize_image(volume[volume.shape[0] // 2]))
                    visualize_target.append(
                        normalize_image(target[target.shape[0] // 2]))

                # Recreate the dictionaries to free memory. Not done when not in validation,
                # as there we are actually interested in the output.
                reconstruction_output = defaultdict(list)
                targets_output = defaultdict(list)

        num_batches = len(data_loader)
        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, num_batches)
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data.pop("scaling_factor")

            # All branches switch on the same setting, `self.cfg.validation.crop`.
            validation_crop = self.cfg.validation.crop
            if validation_crop == "header":
                # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
                # batches.
                resolution = [
                    _.cpu().numpy().tolist()
                    for _ in data["reconstruction_size"]
                ]
                # The volume sampler should give validation indices belonging to the *same* volume, so it should be
                # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y).
                resolution = [_[0] for _ in resolution][:-1]
            elif validation_crop == "training":
                resolution = self.cfg.training.loss.crop
            elif not validation_crop:
                resolution = None
            else:
                raise ValueError(
                    f"Cropping should be either set to `header` to get the values from the header or "
                    f"`training` to take the same value as training.")

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names).detach(),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].refine_names(*self.real_names).detach(),
                    scaling_factors,
                    resolution=resolution,
                )
            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            is_last_batch = iter_idx + 1 == num_batches
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = filename  # First iteration last_filename is not set.

                # The sampling is linear, so a new filename means the previous volume is complete.
                if filename != last_filename:
                    _finish_volume(last_filename)
                    last_filename = filename

                curr_slice = output_abs[idx]
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, curr_slice.cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

                # The very last slice completes the final volume. This must run *after* the slice has been stored,
                # otherwise it would be missing from the volume and its metrics.
                if is_last_batch and idx + 1 == len(filenames):
                    _finish_volume(filename)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO(jt): Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))

        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO(jt): Make named tuple
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
コード例 #6
0
def setup_common_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):
    """Create the experiment directory, build the configuration and set up the engine.

    Steps, in order:
      1. ``base_directory / run_name`` is created by local rank 0 only, followed by
         ``communication.synchronize()`` so other processes wait for it.
      2. ``cfg_filename`` is loaded and merged, top-level key by key, on top of the
         structured ``DefaultConfig``. Model entries come from
         ``load_models_into_environment_config``; dataset entries of the
         ``training``/``validation``/``inference`` sections are resolved by name through
         ``extract_names`` and ``load_dataset_config``.
      3. Logging is set up, the forward/backward operators and the models are built and
         the engine is created with ``setup_engine``.

    Args:
        run_name: Sub-directory name under ``base_directory``; also passed to ``setup_logging``.
        base_directory: Root directory for experiments; joined with ``/`` so presumably a
            ``pathlib.Path`` — TODO confirm against callers.
        cfg_filename: Path to the YAML configuration file.
        device: Passed to ``initialize_models_from_config`` and ``setup_engine``.
        machine_rank: Passed to ``setup_logging``.
        mixed_precision: Passed to ``setup_engine``.
        debug: Passed to ``setup_logging``.

    Returns:
        A namedtuple ``environment(cfg, experiment_dir, engine)``.
    """

    # NOTE(review): this is the *root* logger; nothing here silences logging. It is used
    # below, before `setup_logging` is called, to report empty config sections.
    logger = logging.getLogger()

    experiment_dir = base_directory / run_name
    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory
        # This is required in the logging below
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)

    # Load the default configs to ensure type safety
    cfg = OmegaConf.structured(DefaultConfig)

    # The main model goes to `cfg.model`; every remaining entry becomes an additional model.
    models, models_config = load_models_into_environment_config(cfg_from_file)
    cfg.model = models_config.model
    del models_config["model"]
    cfg.additional_models = models_config

    # Setup everything for training
    cfg.training = TrainingConfig
    cfg.validation = ValidationConfig
    cfg.inference = InferenceConfig

    # Copy of the file config in which dataset entries are replaced by the per-dataset
    # options returned by `extract_names` before merging.
    # NOTE(review): `.copy()` may be shallow, so nested nodes could be shared with
    # `cfg_from_file` — confirm this is harmless.
    cfg_from_file_new = cfg_from_file.copy()
    for key in cfg_from_file:
        # TODO: This does not really do a full validation.
        # BODY: This will be handled once Hydra is implemented.
        if key in ["models", "additional_models"]:  # Still handled separately
            continue

        if key in ["training", "validation", "inference"]:
            # An empty/null section is skipped and keeps the structured defaults assigned
            # above (the log message says "missing" although the key is present).
            if not cfg_from_file[key]:
                logger.info(f"key {key} missing in config.")
                continue

            if key in ["training", "validation"]:
                # These sections hold a list of datasets; each one is resolved by name.
                # NOTE(review): this appends to whatever `datasets` default the structured
                # config already holds — confirm that default is empty.
                dataset_cfg_from_file = extract_names(
                    cfg_from_file[key].datasets)
                for idx, (dataset_name,
                          dataset_config) in enumerate(dataset_cfg_from_file):
                    cfg_from_file_new[key].datasets[idx] = dataset_config
                    cfg[key].datasets.append(load_dataset_config(dataset_name))
            else:
                # `inference` holds a single dataset.
                dataset_name, dataset_config = extract_names(
                    cfg_from_file[key].dataset)
                cfg_from_file_new[key].dataset = dataset_config
                cfg[key].dataset = load_dataset_config(dataset_name)

        # Values from the file override the (structured) defaults for this section.
        cfg[key] = OmegaConf.merge(cfg[key], cfg_from_file_new[key])
    # Make configuration read only.
    # TODO(jt): Does not work when indexing config lists.
    # OmegaConf.set_readonly(cfg, True)
    setup_logging(machine_rank, experiment_dir, run_name, cfg_filename, cfg,
                  debug)
    forward_operator, backward_operator = build_operators(cfg.physics)

    model, additional_models = initialize_models_from_config(
        cfg, models, forward_operator, backward_operator, device)

    engine = setup_engine(
        cfg,
        device,
        model,
        additional_models,
        forward_operator=forward_operator,
        backward_operator=backward_operator,
        mixed_precision=mixed_precision,
    )

    environment = namedtuple(
        "environment",
        ["cfg", "experiment_dir", "engine"],
    )
    return environment(cfg, experiment_dir, engine)
コード例 #7
0
    def evaluate(self,
                 data_loader: DataLoader,
                 loss_fns: Dict[str, Callable],
                 volume_metrics: Optional[Dict[str, Callable]] = None,
                 evaluation_round=0):
        """Evaluate the model over ``data_loader``, computing metrics per volume and logging images.

        The loader must yield the slices of each volume consecutively. Whenever the filename changes, and after the
        very last slice, the collected slices are stacked into a volume and ``volume_metrics`` are computed on it.
        Center slices of the first volumes are written to the event storage as image grids.

        Args:
            data_loader: Loader over the evaluation data.
            loss_fns: Loss functions forwarded to ``self._do_iteration``.
            volume_metrics: Metric functions ``fn(volume, target)``; ``self.build_metrics()`` when None.
            evaluation_round: The target images are only logged when this is 0.

        Returns:
            The averaged (and reduced) loss dictionary.
        """
        self.logger.info('Evaluating...')
        self.model.eval()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        volume_metrics = volume_metrics if volume_metrics is not None else self.build_metrics()
        storage = get_event_storage()

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        def _finish_volume(volume_filename):
            # Stack the collected slices of `volume_filename`, compute its metrics and free the collected slices.
            nonlocal reconstruction_output, targets_output
            # TODO: Stack does not support named tensors.
            volume = torch.stack([_[1].rename(None) for _ in reconstruction_output[volume_filename]])
            target = torch.stack([_[1].rename(None) for _ in targets_output[volume_filename]])
            self.logger.info(f'Reconstructed {volume_filename} (shape = {list(volume.shape)}).')
            curr_metrics = {
                metric_name: metric_fn(volume, target) for metric_name, metric_fn in volume_metrics.items()}
            val_volume_metrics[volume_filename] = curr_metrics

            # Log the center slice of the volume
            if len(visualize_slices) < self.cfg.tensorboard.num_images:
                visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                # Target only needs to be logged once.
                if evaluation_round == 0:
                    visualize_target.append(normalize_image(target[target.shape[0] // 2]))

            # Delete outputs from memory, and recreate dictionary.
            reconstruction_output = defaultdict(list)
            targets_output = defaultdict(list)

        num_batches = len(data_loader)
        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, num_batches)
            data = AddNames()(data)
            filenames = data.pop('filename')
            slice_nos = data.pop('slice_no')
            scaling_factors = data.pop('scaling_factor')

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names('batch', 'complex', 'height', 'width').detach(), scaling_factors, 320)
            target_abs = self.process_output(
                data['target'].refine_names('batch', 'height', 'width').detach(), scaling_factors, 320)
            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            is_last_batch = iter_idx + 1 == num_batches
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = filename  # First iteration last_filename is not set.

                # The sampling is linear, so a new filename means the previous volume is complete.
                if filename != last_filename:
                    _finish_volume(last_filename)
                    last_filename = filename

                curr_slice = output_abs[idx]
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append((slice_no, curr_slice.cpu()))
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

                # The very last slice completes the final volume. This must run *after* the slice has been stored,
                # otherwise it would be missing from the volume and its metrics.
                if is_last_batch and idx + 1 == len(filenames):
                    _finish_volume(filename)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        # Log slices. `make_grid` stacks its list input, which fails for an empty list (no completed volume).
        if visualize_slices:
            visualize_slices = make_grid(visualize_slices, nrow=4, scale_each=True)
            storage.add_image('validation/prediction', visualize_slices)

        if evaluation_round == 0 and visualize_target:
            visualize_target = make_grid(visualize_target, nrow=4, scale_each=True)
            storage.add_image('validation/target', visualize_target)

        communication.synchronize()
        torch.cuda.empty_cache()

        return loss_dict
コード例 #8
0
ファイル: train_rim.py プロジェクト: rmsouza01/direct
def setup(run_name, training_root, validation_root, base_directory,
          cfg_filename, device, num_workers, resume, machine_rank):
    """Prepare a RIM training run and start it.

    Creates ``base_directory / run_name``, builds the configuration from the defaults,
    the model config class named by ``model_name`` in the YAML file and the file itself,
    sets up logging, then builds the model, datasets, optimizer and LR scheduler. Finally
    ``RIMEngine.train`` is called. Exits the process with status -1 when no configuration
    class exists for the requested model.
    """
    experiment_dir = base_directory / run_name
    local_rank = communication.get_local_rank()
    is_local_leader = local_rank == 0

    # A single process per machine creates the directory; the log files below live in it.
    if is_local_leader:
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # The model name in the YAML file determines which model config class is loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)
    requested_model = cfg_from_file.model_name
    config_class_name = requested_model + 'Config'
    try:
        model_cfg = str_to_class(f'direct.nn.{requested_model.lower()}.config', config_class_name)
    except (AttributeError, ModuleNotFoundError) as err:
        logger.error(f'Model configuration does not exist for {requested_model} (err = {err}).')
        sys.exit(-1)

    # Typed defaults first, then the values from the file on top of them.
    defaults = OmegaConf.merge(
        OmegaConf.structured(DefaultConfig),
        {'model': model_cfg, 'training': TrainingConfig()},
    )
    cfg = OmegaConf.merge(defaults, cfg_from_file)

    # Logging: one file per (machine, local rank); stdout only for local rank 0 unless debugging.
    log_file = experiment_dir / f'log_{machine_rank}_{local_rank}.txt'
    direct.utils.logging.setup(
        use_stdout=is_local_leader or cfg.debug,
        filename=log_file,
        log_level='DEBUG' if cfg.debug else 'INFO',
    )
    logger.info(f'Machine rank: {machine_rank}.')
    logger.info(f'Local rank: {local_rank}.')
    logger.info(f'Logging: {log_file}.')
    logger.info(f'Saving to: {experiment_dir}.')
    logger.info(f'Run name: {run_name}.')
    logger.info(f'Config file: {cfg_filename}.')
    logger.info(f'Python version: {sys.version}.')
    logger.info(f'PyTorch version: {torch.__version__}.')  # noqa
    logger.info(f'CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.')
    logger.info(f'Configuration: {pformat(dict(cfg))}.')

    # Model.
    logger.info('Building model.')
    model = MRIReconstruction(2, **cfg.model).to(device)
    num_parameters = sum(param.numel() for param in model.parameters())
    logger.info(f'Number of parameters: {num_parameters} ({num_parameters / 10.0**3:.2f}k).')
    logger.debug(model)

    # Training and validation data.
    train_mask_func, val_mask_func = build_masking_functions(**cfg.masking)
    train_transforms, val_transforms = build_mri_transforms(
        train_mask_func,
        val_mask_func=val_mask_func,
        crop=cfg.dataset.transforms.crop,
    )
    training_data, validation_data = build_datasets(
        cfg.dataset.name,
        training_root,
        train_sensitivity_maps=None,
        train_transforms=train_transforms,
        validation_root=validation_root,
        val_sensitivity_maps=None,
        val_transforms=val_transforms,
    )

    # Optimizer, looked up by name in `torch.optim`.
    logger.info('Building optimizers.')
    training_cfg = cfg.training
    optimizer_class = str_to_class('torch.optim', training_cfg.optimizer)
    optimizer: torch.optim.Optimizer = optimizer_class(  # noqa
        model.parameters(), lr=training_cfg.lr, weight_decay=training_cfg.weight_decay)

    # Fixed step-size LR schedule (no adaptive schedule): decay every `lr_step_size` iterations.
    step_size = training_cfg.lr_step_size
    solver_steps = list(range(step_size, training_cfg.num_iterations, step_size))
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        solver_steps,
        training_cfg.lr_gamma,
        warmup_factor=1 / 3.,
        warmup_iters=training_cfg.lr_warmup_iter,
        warmup_method='linear',
    )

    torch.cuda.empty_cache()  # Just to make sure.

    engine = RIMEngine(cfg, model, device=device)
    engine.train(
        optimizer, lr_scheduler, training_data, experiment_dir,
        validation_data=validation_data, resume=resume, num_workers=num_workers)