コード例 #1
0
ファイル: train_rim.py プロジェクト: jonasteuwen/direct
def build_dataset_from_environment(env, datasets_config, lists_root, data_root,
                                   type_data, **kwargs):
    """Build one dataset per entry of ``datasets_config`` using the operators of ``env``.

    Parameters
    ----------
    env:
        Environment whose ``forward_operator`` and ``backward_operator`` are used to
        build the MRI transforms of every dataset.
    datasets_config:
        Sized iterable of dataset configurations. Each entry is read for ``name``,
        ``transforms``, ``text_description`` and ``kspace_context``.
    lists_root, data_root:
        Passed to ``get_filenames_for_datasets`` to select the files of each dataset;
        ``data_root`` is also passed to ``build_dataset``.
    type_data: str
        ``"training"`` or ``"validation"``. Only validation datasets receive a
        ``text_description``.
    **kwargs:
        Forwarded unchanged to ``build_dataset``.

    Returns
    -------
    list
        The built datasets, in the same order as ``datasets_config``.

    Raises
    ------
    ValueError
        If ``type_data`` is neither ``"training"`` nor ``"validation"``.
    """
    # ``type_data`` does not change per dataset, so validate it once, up front. This also
    # reports an invalid value when ``datasets_config`` is empty, and does so before any
    # transforms are built.
    if type_data not in ("validation", "training"):
        raise ValueError(
            f"Type of data needs to be either `validation` or `training`, got {type_data}."
        )

    datasets = []
    for idx, dataset_config in enumerate(datasets_config):
        transforms_config = dataset_config.transforms
        transforms = build_mri_transforms(
            forward_operator=env.forward_operator,
            backward_operator=env.backward_operator,
            mask_func=build_masking_function(**transforms_config.masking),
            crop=transforms_config.crop,
            crop_type=transforms_config.crop_type,
            image_center_crop=transforms_config.image_center_crop,
            estimate_sensitivity_maps=transforms_config.estimate_sensitivity_maps,
            pad_coils=transforms_config.pad_coils,
        )
        logger.debug(f"Transforms for {type_data}: {idx}:\n{transforms}")

        # Only give fancy names when validating.
        # TODO(jt): Perhaps this can be split up to just a description parameters, and parse config in the main func.
        text_description = None
        if type_data == "validation":
            if dataset_config.text_description:
                text_description = dataset_config.text_description
            elif len(datasets_config) > 1:
                # Positional fallback name so several validation sets can be told apart.
                text_description = f"ds{idx}"

        dataset = build_dataset(
            dataset_config.name,
            data_root,
            filenames_filter=get_filenames_for_datasets(
                dataset_config, lists_root, data_root),
            sensitivity_maps=None,
            transforms=transforms,
            text_description=text_description,
            kspace_context=dataset_config.kspace_context,
            **kwargs,
        )
        datasets.append(dataset)
        logger.info(
            f"Data size for {type_data} dataset"
            f" {dataset_config.name} ({idx + 1}/{len(datasets_config)}): {len(dataset)}."
        )

    return datasets
コード例 #2
0
def setup_inference(run_name, data_root, base_directory, output_directory,
                    cfg_filename, checkpoint, masks, device, num_workers,
                    machine_rank, validation_sm):
    """Predict the first validation dataset with precomputed masks and write HDF5 files.

    Every ``*.npy`` file in ``masks`` is loaded and keyed by the volume name it belongs
    to (``<stem>.h5``). The masks are applied by ``CreateSamplingMask`` ahead of the
    regular MRI transforms. Only the first entry of ``cfg.validation.datasets`` is
    predicted.

    Parameters
    ----------
    run_name, base_directory, cfg_filename, device, machine_rank:
        Forwarded to ``setup_environment``.
    data_root:
        Root directory passed to ``build_dataset``.
    output_directory:
        Path-like directory (created if missing) that receives one ``.h5`` file per
        predicted volume, each holding a ``"reconstruction"`` dataset.
    checkpoint:
        Checkpoint number forwarded to ``engine.predict``.
    masks:
        Path-like directory containing the ``*.npy`` sampling masks.
    num_workers: int
        Number of workers forwarded to ``engine.predict``.
    validation_sm:
        Sensitivity maps forwarded to ``build_dataset``.
    """
    # TODO(jt): This is a duplicate line, check how this can be merged with train_rim.py
    # TODO(jt): Log elsewhere than for training.
    # TODO(jt): Logging is different when having multiple processes.
    # TODO(jt): This can be merged with run_rim.py
    cfg, experiment_directory, forward_operator, backward_operator, engine = setup_environment(
        run_name, base_directory, cfg_filename, device, machine_rank)

    # Key each mask by the volume it belongs to. `stem` strips only the trailing `.npy`;
    # `str.replace` would also rewrite any `.npy` occurring earlier in the file name.
    logger.info('Loading masks...')
    masks_dict = {
        mask_file.stem + '.h5': np.load(mask_file)
        for mask_file in masks.glob('*.npy')
    }
    logger.info(f'Loaded {len(masks_dict)} masks.')

    # Don't add the mask func, add it separately
    mri_transforms = build_mri_transforms(
        forward_operator=forward_operator,
        backward_operator=backward_operator,
        mask_func=None,
        crop=(320, 320),  # (kp) Cropping needed for fastmri testing
        image_center_crop=True,
        estimate_sensitivity_maps=cfg.training.dataset.transforms.estimate_sensitivity_maps,
    )
    mri_transforms = Compose([CreateSamplingMask(masks_dict), mri_transforms])

    # Trigger cudnn benchmark when the number of different input shapes is small.
    torch.backends.cudnn.benchmark = True

    # TODO(jt): batches should have constant shapes! This works for Calgary Campinas because they are all with 256
    # slices.
    if len(cfg.validation.datasets) > 1:
        logger.warning('Multiple datasets given. Will only predict the first.')

    data = build_dataset(cfg.validation.datasets[0].name,
                         data_root,
                         sensitivity_maps=validation_sm,
                         transforms=mri_transforms)
    logger.info(f'Inference data size: {len(data)}.')

    # Just to make sure.
    torch.cuda.empty_cache()

    # Run prediction
    output = engine.predict(data,
                            experiment_directory,
                            checkpoint_number=checkpoint,
                            num_workers=num_workers)

    output_directory.mkdir(exist_ok=True, parents=True)

    # TODO(jt): Perhaps aggregation to the main process would be most optimal here before writing.
    num_volumes = len(output)
    for idx, filename in enumerate(output):
        output_path = output_directory / filename
        logger.info(f'({idx + 1}/{num_volumes}): Writing {output_path}...')
        # Element [1] of each entry in output[filename] is the per-slice prediction
        # (presumably (slice_index, prediction) pairs — verify against engine.predict).
        # The stacked output has shape (depth, 1, height, width); the singleton axis is dropped.
        # `np.float` was a deprecated alias of the builtin float (removed in NumPy 1.24);
        # `np.float64` is the same dtype and works on current NumPy.
        reconstruction = torch.stack([
            slice_output[1].rename(None) for slice_output in output[filename]
        ]).numpy()[:, 0, ...].astype(np.float64)

        with h5py.File(output_path, 'w') as f:
            f.create_dataset('reconstruction', data=reconstruction)
コード例 #3
0
def setup_inference(
    run_name,
    data_root,
    base_directory,
    output_directory,
    cfg_filename,
    checkpoint,
    validation_set_index,
    accelerations,
    center_fractions,
    device,
    num_workers,
    machine_rank,
    mixed_precision
):
    """Predict one configured validation dataset and write each volume to an HDF5 file.

    Parameters
    ----------
    run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision:
        Forwarded to ``setup_environment``.
    data_root:
        Root directory passed to ``build_dataset``.
    output_directory:
        Path-like directory (created if missing) that receives one ``.h5`` file per
        predicted volume, each holding a ``"reconstruction"`` dataset.
    checkpoint:
        Checkpoint number forwarded to ``engine.predict``.
    validation_set_index:
        Index into ``cfg.validation.datasets``; ``None`` selects the first one.
    accelerations, center_fractions:
        Overrides that are not supported yet. If either is truthy, the process exits
        via ``sys.exit``.
    num_workers: int
        Number of workers forwarded to ``engine.predict``.
    """
    # TODO(jt): This is a duplicate line, check how this can be merged with train_rim.py
    # TODO(jt): Log elsewhere than for training.
    # TODO(jt): Logging is different when having multiple processes.
    env = setup_environment(
        run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision
    )

    # 0 is a valid index, so test for None explicitly. A truthiness test would warn about
    # a "missing" index even when 0 was requested on purpose.
    if validation_set_index is None:
        if len(env.cfg.validation.datasets) > 1:
            logger.warning(
                "Multiple validation datasets given in config, yet no index is given. Will select first."
            )
        validation_set_index = 0

    if accelerations or center_fractions:
        sys.exit("Overwriting of accelerations or ACS not yet supported.")

    dataset_cfg = env.cfg.validation.datasets[validation_set_index]
    is_calgary_campinas = dataset_cfg.name == "CalgaryCampinas"

    mri_transforms = build_mri_transforms(
        forward_operator=env.forward_operator,
        backward_operator=env.backward_operator,
        mask_func=build_masking_function(**dataset_cfg.transforms.masking),
        crop=None,  # No cropping needed for testing
        image_center_crop=True,
        # NOTE(review): taken from the first *training* dataset, not from dataset_cfg — confirm intent.
        estimate_sensitivity_maps=env.cfg.training.datasets[0].transforms.estimate_sensitivity_maps,
    )

    # Trigger cudnn benchmark when the number of different input shapes is small.
    torch.backends.cudnn.benchmark = True

    # TODO(jt): batches should have constant shapes! This works for Calgary Campinas because they are all with 256
    # slices.
    data = build_dataset(
        dataset_cfg.name,
        data_root,
        sensitivity_maps=None,
        transforms=mri_transforms,
    )
    logger.info(f"Inference data size: {len(data)}.")

    # Just to make sure.
    torch.cuda.empty_cache()

    # Run prediction
    output = env.engine.predict(
        data,
        env.experiment_dir,
        checkpoint_number=checkpoint,
        num_workers=num_workers,
    )

    output_directory.mkdir(exist_ok=True, parents=True)

    # Only relevant for the Calgary Campinas challenge: drop the first and last 50 slices
    # along the depth axis.
    # TODO(jt): This can be inferred from the configuration.
    # TODO(jt): Refactor this for v0.2.
    crop = (50, -50) if is_calgary_campinas else None

    # TODO(jt): Perhaps aggregation to the main process would be most optimal here before writing.
    num_volumes = len(output)
    for idx, filename in enumerate(output):
        output_path = output_directory / filename
        logger.info(f"({idx + 1}/{num_volumes}): Writing {output_path}...")
        # Element [1] of each entry in output[filename] is the per-slice prediction
        # (presumably (slice_index, prediction) pairs — verify against engine.predict).
        # The stacked output has shape (depth, 1, height, width); the singleton axis is dropped.
        # `np.float64` replaces the `np.float` alias removed in NumPy 1.24 (same dtype).
        reconstruction = (
            torch.stack([slice_output[1].rename(None) for slice_output in output[filename]])
            .numpy()[:, 0, ...]
            .astype(np.float64)
        )
        if crop:
            reconstruction = reconstruction[slice(*crop)]

        if is_calgary_campinas:
            # Only needed to fix a bug in Calgary Campinas training: divide by sqrt(height * width).
            reconstruction = reconstruction / np.sqrt(np.prod(reconstruction.shape[1:]))

        with h5py.File(output_path, "w") as f:
            f.create_dataset("reconstruction", data=reconstruction)
コード例 #4
0
def setup_inference(
    run_name,
    data_root,
    base_directory,
    output_directory,
    volume_processing_func,
    cfg_filename,
    checkpoint,
    masks,
    device,
    num_workers,
    machine_rank,
):
    """Predict the first validation dataset with precomputed masks and write HDF5 files.

    Every ``*.npy`` file in ``masks`` is loaded and keyed by the volume name it belongs
    to (``<stem>.h5``). The masks are applied by ``CreateSamplingMask`` ahead of the
    regular MRI transforms. Only the first entry of ``cfg.validation.datasets`` is
    predicted.

    Parameters
    ----------
    run_name, base_directory, cfg_filename, device, machine_rank:
        Forwarded to ``setup_environment``.
    data_root:
        Root directory passed to ``build_dataset``.
    output_directory:
        Path-like directory (created if missing) that receives one ``.h5`` file per
        predicted volume, each holding a ``"reconstruction"`` dataset.
    volume_processing_func: callable
        Applied to each reconstructed volume, a NumPy array of shape
        (depth, height, width). Its return value is what gets written.
    checkpoint:
        Checkpoint number forwarded to ``engine.predict``.
    masks:
        Path-like directory containing the ``*.npy`` sampling masks.
    num_workers: int
        Number of workers forwarded to ``engine.predict``.
    """
    # TODO(jt): This is a duplicate line, check how this can be merged with train_rim.py
    # TODO(jt): Log elsewhere than for training.
    # TODO(jt): Logging is different when having multiple processes.
    # TODO(jt): This can be merged with run_rim.py
    env = setup_environment(run_name, base_directory, cfg_filename, device,
                            machine_rank)

    # Key each mask by the volume it belongs to. `stem` strips only the trailing `.npy`;
    # `str.replace` would also rewrite any `.npy` occurring earlier in the file name.
    logger.info("Loading masks...")
    masks_dict = {
        mask_file.stem + ".h5": np.load(mask_file)
        for mask_file in masks.glob("*.npy")
    }
    logger.info(f"Loaded {len(masks_dict)} masks.")

    # Don't add the mask func, add it separately
    mri_transforms = build_mri_transforms(
        forward_operator=env.forward_operator,
        backward_operator=env.backward_operator,
        mask_func=None,
        crop=None,  # No cropping needed for testing
        image_center_crop=True,
        estimate_sensitivity_maps=env.cfg.training.dataset.transforms.estimate_sensitivity_maps,
    )
    mri_transforms = Compose([CreateSamplingMask(masks_dict), mri_transforms])

    # Trigger cudnn benchmark when the number of different input shapes is small.
    torch.backends.cudnn.benchmark = True

    # TODO(jt): batches should have constant shapes! This works for Calgary Campinas because they are all with 256
    # slices.
    validation_datasets = env.cfg.validation.datasets
    if len(validation_datasets) > 1:
        logger.warning("Multiple datasets given. Will only predict the first.")

    # Use the same `datasets` key that was checked above (was `dataset`, a typo).
    data = build_dataset(
        validation_datasets[0].name,
        data_root,
        sensitivity_maps=None,
        transforms=mri_transforms,
    )
    logger.info(f"Inference data size: {len(data)}.")

    # Just to make sure.
    torch.cuda.empty_cache()

    # Run prediction
    output = env.engine.predict(
        data,
        env.experiment_dir,
        checkpoint_number=checkpoint,
        num_workers=num_workers,
    )

    output_directory.mkdir(exist_ok=True, parents=True)

    # TODO(jt): Perhaps aggregation to the main process would be most optimal here before writing.
    num_volumes = len(output)
    for idx, filename in enumerate(output):
        output_path = output_directory / filename
        logger.info(f"({idx + 1}/{num_volumes}): Writing {output_path}...")
        # Element [1] of each entry in output[filename] is the per-slice prediction
        # (presumably (slice_index, prediction) pairs — verify against engine.predict).
        # The stacked output has shape (depth, 1, height, width); the singleton axis is dropped.
        # `np.float64` replaces the `np.float` alias removed in NumPy 1.24 (same dtype).
        reconstruction = torch.stack([
            slice_output[1].rename(None) for slice_output in output[filename]
        ]).numpy()[:, 0, ...].astype(np.float64)
        reconstruction = volume_processing_func(reconstruction)

        with h5py.File(output_path, "w") as f:
            f.create_dataset("reconstruction", data=reconstruction)
コード例 #5
0
ファイル: train_rim.py プロジェクト: rmsouza01/direct
def setup(run_name, training_root, validation_root, base_directory,
          cfg_filename, device, num_workers, resume, machine_rank):
    """Prepare configuration, logging, model, data and optimisation, then run RIM training.

    Parameters
    ----------
    run_name:
        Name of the run; the run directory is ``base_directory / run_name``.
    training_root, validation_root:
        Data roots passed to ``build_datasets``.
    base_directory:
        Path-like parent directory of all runs.
    cfg_filename:
        YAML configuration file. Its ``model_name`` selects the model configuration class.
    device:
        Device the model is moved to and that the engine runs on.
    num_workers: int
        Forwarded to ``engine.train``.
    resume:
        Forwarded to ``engine.train``.
    machine_rank:
        Rank of this machine; used in the log-file name.

    Exits the process with status -1 if the model configuration cannot be found.
    """
    run_dir = base_directory / run_name
    local_rank = communication.get_local_rank()
    is_main_local_process = local_rank == 0

    # Only one process per machine creates the run directory, which avoids concurrent
    # creation. Everybody waits until it exists, because the log files are written there.
    if is_main_local_process:
        run_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()

    # The YAML file names the model. The matching config class is `<model_name>Config`
    # in the module `direct.nn.<lowercased model_name>.config`.
    file_cfg = OmegaConf.load(cfg_filename)
    model_cfg_class_name = file_cfg.model_name + 'Config'
    try:
        model_cfg = str_to_class(f'direct.nn.{file_cfg.model_name.lower()}.config', model_cfg_class_name)
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(f'Model configuration does not exist for {file_cfg.model_name} (err = {e}).')
        sys.exit(-1)

    # Structured defaults first (for type safety), then the model/training dataclasses,
    # and finally the values read from the YAML file.
    defaults = OmegaConf.merge(
        OmegaConf.structured(DefaultConfig),
        {'model': model_cfg, 'training': TrainingConfig()},
    )
    cfg = OmegaConf.merge(defaults, file_cfg)

    log_file = run_dir / f'log_{machine_rank}_{local_rank}.txt'
    direct.utils.logging.setup(
        use_stdout=is_main_local_process or cfg.debug,
        filename=log_file,
        log_level='DEBUG' if cfg.debug else 'INFO',
    )
    header_messages = (
        f'Machine rank: {machine_rank}.',
        f'Local rank: {local_rank}.',
        f'Logging: {log_file}.',
        f'Saving to: {run_dir}.',
        f'Run name: {run_name}.',
        f'Config file: {cfg_filename}.',
        f'Python version: {sys.version}.',
        f'PyTorch version: {torch.__version__}.',  # noqa
        f'CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.',
        f'Configuration: {pformat(dict(cfg))}.',
    )
    for message in header_messages:
        logger.info(message)

    logger.info('Building model.')
    model = MRIReconstruction(2, **cfg.model).to(device)
    num_parameters = sum(parameter.numel() for parameter in model.parameters())
    logger.info(f'Number of parameters: {num_parameters} ({num_parameters / 10.0**3:.2f}k).')
    logger.debug(model)

    # Masking functions -> transforms -> datasets, for training and validation in parallel.
    train_mask_func, val_mask_func = build_masking_functions(**cfg.masking)
    train_transforms, val_transforms = build_mri_transforms(
        train_mask_func,
        val_mask_func=val_mask_func,
        crop=cfg.dataset.transforms.crop,
    )
    training_data, validation_data = build_datasets(
        cfg.dataset.name,
        training_root,
        train_sensitivity_maps=None,
        train_transforms=train_transforms,
        validation_root=validation_root,
        val_sensitivity_maps=None,
        val_transforms=val_transforms,
    )

    logger.info('Building optimizers.')
    training_cfg = cfg.training
    optimizer_class = str_to_class('torch.optim', training_cfg.optimizer)
    optimizer: torch.optim.Optimizer = optimizer_class(  # noqa
        model.parameters(),
        lr=training_cfg.lr,
        weight_decay=training_cfg.weight_decay,
    )

    # Fixed step-decay milestones every `lr_step_size` iterations (no adaptive schedule).
    step_size = training_cfg.lr_step_size
    milestones = list(range(step_size, training_cfg.num_iterations, step_size))
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        milestones,
        training_cfg.lr_gamma,
        warmup_factor=1 / 3.,
        warmup_iters=training_cfg.lr_warmup_iter,
        warmup_method='linear',
    )

    # Just to make sure.
    torch.cuda.empty_cache()

    engine = RIMEngine(cfg, model, device=device)
    engine.train(
        optimizer,
        lr_scheduler,
        training_data,
        run_dir,
        validation_data=validation_data,
        resume=resume,
        num_workers=num_workers,
    )