Code Example #1
File: train_rim.py  Project: wdika/direct
def build_transforms_from_environment(env, dataset_config):
    """Build the MRI transforms for a dataset, given the experiment environment."""
    mri_transforms_func = functools.partial(
        build_mri_transforms,
        forward_operator=env.engine.forward_operator,
        backward_operator=env.engine.backward_operator,
        mask_func=build_masking_function(**dataset_config.transforms.masking),
    )
    # The masking configuration is consumed above; forward the remaining transform options.
    transforms = mri_transforms_func(**remove_keys(dataset_config.transforms, "masking"))
    return transforms
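A minimal usage sketch for the helper above, assuming `env` is produced by `setup_environment` (with the signature shown in Code Example #4) and that dataset configs live under `env.cfg.training.datasets`; both assumptions come from the other examples on this page, not from this snippet itself:

env = setup_environment(
    run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision
)
# Build one transform pipeline per configured training dataset.
transform_pipelines = [
    build_transforms_from_environment(env, dataset_config)
    for dataset_config in env.cfg.training.datasets
]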
Code Example #2
File: train_rim.py  Project: jonasteuwen/direct
def build_dataset_from_environment(env, datasets_config, lists_root, data_root,
                                   type_data, **kwargs):
    datasets = []
    for idx, dataset_config in enumerate(datasets_config):
        transforms = build_mri_transforms(
            forward_operator=env.forward_operator,
            backward_operator=env.backward_operator,
            mask_func=build_masking_function(**dataset_config.transforms.masking),
            crop=dataset_config.transforms.crop,
            crop_type=dataset_config.transforms.crop_type,
            image_center_crop=dataset_config.transforms.image_center_crop,
            estimate_sensitivity_maps=dataset_config.transforms.estimate_sensitivity_maps,
            pad_coils=dataset_config.transforms.pad_coils,
        )
        logger.debug(f"Transforms for {type_data}: {idx}:\n{transforms}")

        # Only give fancy names when validating
        # TODO(jt): Perhaps this can be split into just the description parameters, with the config parsed in the main func.
        if type_data == "validation":
            if dataset_config.text_description:
                text_description = dataset_config.text_description
            else:
                text_description = f"ds{idx}" if len(
                    datasets_config) > 1 else None
        elif type_data == "training":
            text_description = None
        else:
            raise ValueError(
                f"Type of data needs to be either `validation` or `training`, got {type_data}."
            )

        dataset = build_dataset(
            dataset_config.name,
            data_root,
            filenames_filter=get_filenames_for_datasets(dataset_config, lists_root, data_root),
            sensitivity_maps=None,
            transforms=transforms,
            text_description=text_description,
            kspace_context=dataset_config.kspace_context,
            **kwargs,
        )
        datasets.append(dataset)
        logger.info(
            f"Data size for {type_data} dataset"
            f" {dataset_config.name} ({idx + 1}/{len(datasets_config)}): {len(dataset)}."
        )

    return datasets
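A hypothetical call of `build_dataset_from_environment`; the roots are illustrative `pathlib.Path` objects, and `type_data` must be either `"training"` or `"validation"`, as the function itself enforces:

from pathlib import Path

training_datasets = build_dataset_from_environment(
    env,
    env.cfg.training.datasets,      # assumed config layout, mirroring Code Example #4
    lists_root=Path("lists"),       # illustrative path
    data_root=Path("/data/train"),  # illustrative path
    type_data="training",
)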
Code Example #3
File: predict_val.py  Project: mwacaan/direct
def _get_transforms(validation_index, env):
    """Return the config and inference transforms for the validation dataset at `validation_index`."""
    dataset_cfg = env.cfg.validation.datasets[validation_index]
    mask_func = build_masking_function(**dataset_cfg.transforms.masking)
    transforms = build_inference_transforms(env, mask_func, dataset_cfg)
    return dataset_cfg, transforms
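A sketch of how this helper might be driven over all configured validation sets; the loop itself is an assumption, and only the `env.cfg.validation.datasets` layout is taken from Code Example #4:

for validation_index in range(len(env.cfg.validation.datasets)):
    dataset_cfg, transforms = _get_transforms(validation_index, env)
    # `transforms` can then be passed to build_dataset for this validation set.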
Code Example #4
def setup_inference(
    run_name,
    data_root,
    base_directory,
    output_directory,
    cfg_filename,
    checkpoint,
    validation_set_index,
    accelerations,
    center_fractions,
    device,
    num_workers,
    machine_rank,
    mixed_precision
):

    # TODO(jt): This is a duplicate line, check how this can be merged with train_rim.py
    # TODO(jt): Log elsewhere than for training.
    # TODO(jt): Logging is different when having multiple processes.
    env = setup_environment(
        run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision
    )

    # Select the validation dataset and its masking configuration.
    if len(env.cfg.validation.datasets) > 1 and validation_set_index is None:
        logger.warning(
            "Multiple validation datasets given in config, yet no index is given. Will select the first."
        )
    validation_set_index = validation_set_index if validation_set_index is not None else 0

    if accelerations or center_fractions:
        sys.exit(f"Overwriting of accelerations or ACS not yet supported.")

    mask_func = build_masking_function(
        **env.cfg.validation.datasets[validation_set_index].transforms.masking
    )

    mri_transforms = build_mri_transforms(
        forward_operator=env.forward_operator,
        backward_operator=env.backward_operator,
        mask_func=mask_func,
        crop=None,  # No cropping needed for testing
        image_center_crop=True,
        estimate_sensitivity_maps=env.cfg.training.datasets[0].transforms.estimate_sensitivity_maps,
    )

    # Trigger cudnn benchmark when the number of different input shapes is small.
    torch.backends.cudnn.benchmark = True

    # TODO(jt): Batches should have constant shapes! This works for Calgary Campinas
    # because all volumes have 256 slices.
    data = build_dataset(
        env.cfg.validation.datasets[validation_set_index].name,
        data_root,
        sensitivity_maps=None,
        transforms=mri_transforms,
    )
    logger.info(f"Inference data size: {len(data)}.")

    # Free cached GPU memory before prediction, just to be safe.
    torch.cuda.empty_cache()

    # Run prediction
    output = env.engine.predict(
        data,
        env.experiment_dir,
        checkpoint_number=checkpoint,
        num_workers=num_workers,
    )

    # Create output directory
    output_directory.mkdir(exist_ok=True, parents=True)

    # Only relevant for the Calgary Campinas challenge.
    # TODO(jt): This can be inferred from the configuration.
    # TODO(jt): Refactor this for v0.2.
    crop = (
        (50, -50)
        if env.cfg.validation.datasets[validation_set_index].name == "CalgaryCampinas"
        else None
    )

    # TODO(jt): Perhaps aggregation to the main process would be most optimal here before writing.
    for idx, filename in enumerate(output):
        # The output has shape (depth, 1, height, width)
        logger.info(
            f"({idx + 1}/{len(output)}): Writing {output_directory / filename}..."
        )
        reconstruction = (
            torch.stack([sample[1].rename(None) for sample in output[filename]])
            .numpy()[:, 0, ...]
            .astype(np.float64)  # `np.float` was removed from NumPy; float64 matches the old alias
        )
        if crop:
            reconstruction = reconstruction[slice(*crop)]

        # Only needed to fix a bug in Calgary Campinas training
        if env.cfg.validation.datasets[validation_set_index].name == "CalgaryCampinas":
            reconstruction = reconstruction / np.sqrt(np.prod(reconstruction.shape[1:]))

        with h5py.File(output_directory / filename, "w") as f:
            f.create_dataset("reconstruction", data=reconstruction)
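A hypothetical invocation of `setup_inference`, matching the signature above; every literal value is illustrative, and the path arguments are assumed to be `pathlib.Path` objects because `output_directory.mkdir(...)` is called inside the function:

from pathlib import Path

setup_inference(
    run_name="rim_baseline",                      # illustrative name
    data_root=Path("/data/validation"),           # illustrative path
    base_directory=Path("/experiments"),          # illustrative path
    output_directory=Path("/experiments/preds"),  # illustrative path
    cfg_filename="config.yaml",                   # illustrative filename
    checkpoint=100000,                            # illustrative checkpoint number
    validation_set_index=0,
    accelerations=None,      # passing a value here triggers sys.exit (not yet supported)
    center_fractions=None,
    device="cuda",
    num_workers=4,
    machine_rank=0,
    mixed_precision=False,
)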