Code Example #1
def build_operators(cfg):
    # Get the operators
    forward_operator = str_to_class(f"direct.data.transforms",
                                    cfg.forward_operator)
    backward_operator = str_to_class(f"direct.data.transforms",
                                     cfg.backward_operator)
    return forward_operator, backward_operator
Code Example #2
def build_operators(cfg) -> Tuple[Callable, Callable]:
    # Get the operators
    forward_operator = str_to_class("direct.data.transforms",
                                    cfg.forward_operator)
    backward_operator = str_to_class("direct.data.transforms",
                                     cfg.backward_operator)
    return forward_operator, backward_operator
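
Every example on this page resolves a class from a module path and a class name through str_to_class. The helper's own implementation is not part of these listings; the sketch below is a minimal importlib-based version written purely for orientation, and the real direct.utils.str_to_class may accept additional syntax (for example, constructor arguments in parentheses as hinted at in Example #7):

import importlib

def str_to_class(module_name: str, class_name: str):
    """Return the attribute `class_name` from module `module_name` (a class, not an instance)."""
    module = importlib.import_module(module_name)  # raises ModuleNotFoundError if the module is missing
    return getattr(module, class_name)             # raises AttributeError if the attribute is missing

This matches the except (AttributeError, ModuleNotFoundError) handling used in Examples #3, #8 and #11 below.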
Code Example #3
def setup_engine(cfg,
                 device,
                 model,
                 additional_models: dict,
                 mixed_precision: bool = False):
    # Setup engine.
    # There is a bit of repetition here, but the warning provided is more descriptive
    # TODO(jt): Try to find a way to combine this with the setup above.
    model_name_short = cfg.model.model_name.split(".")[0]
    engine_name = cfg.model.model_name.split(".")[-1] + "Engine"

    try:
        engine_class = str_to_class(
            f"direct.nn.{model_name_short.lower()}.{model_name_short.lower()}_engine",
            engine_name,
        )
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(
            f"Engine does not exist for {cfg.model.model_name} (err = {e}).")
        sys.exit(-1)

    engine = engine_class(
        cfg,
        model,
        device=device,
        mixed_precision=mixed_precision,
        **additional_models,
    )
    return engine
Code Example #4
def build_datasets(dataset_name,
                   training_root: pathlib.Path,
                   train_sensitivity_maps=None,
                   train_transforms=None,
                   validation_root=None,
                   val_sensitivity_maps=None,
                   val_transforms=None):
    logger.info(f'Building dataset for {dataset_name}.')
    dataset_class: Callable = str_to_class('direct.data.datasets',
                                           dataset_name + 'Dataset')

    train_data = dataset_class(root=training_root,
                               dataset_description=None,
                               transform=train_transforms,
                               sensitivity_maps=train_sensitivity_maps,
                               pass_mask=False)
    logger.info(f'Train data size: {len(train_data)}.')

    if validation_root:
        val_data = dataset_class(root=validation_root,
                                 dataset_description=None,
                                 transform=val_transforms,
                                 sensitivity_maps=val_sensitivity_maps,
                                 pass_mask=False)

        logger.info(f'Validation data size: {len(val_data)}.')

        return train_data, val_data

    return train_data
Code Example #5
File: datasets.py, Project: jonasteuwen/direct
def build_dataset(
    dataset_name,
    root: pathlib.Path,
    filenames_filter: Optional[List[PathOrString]] = None,
    sensitivity_maps: Optional[pathlib.Path] = None,
    transforms: Optional[Any] = None,
    text_description: Optional[str] = None,
    kspace_context: Optional[int] = 0,
    **kwargs,
) -> Dataset:
    """

    Parameters
    ----------
    dataset_name : str
        Name of dataset class (without `Dataset`) in direct.data.datasets.
    root : pathlib.Path
        Root path to the data for the dataset class.
    filenames_filter : List
        List of filenames to include in the dataset, should be the same as the ones that can be derived from a glob
        on the root. If set, will skip searching for files in the root.
    sensitivity_maps : pathlib.Path
        Path to sensitivity maps.
    transforms : object
        Transformation object
    text_description : str
        Description of dataset, can be used for logging.
    kspace_context : int
        If set, output will be of shape -kspace_context:kspace_context.

    Returns
    -------
    Dataset
    """

    logger.info(f"Building dataset for: {dataset_name}.")
    dataset_class: Callable = str_to_class("direct.data.datasets",
                                           dataset_name + "Dataset")
    logger.debug(f"Dataset class: {dataset_class}.")

    dataset = dataset_class(
        root=root,
        filenames_filter=filenames_filter,
        dataset_description=None,
        transform=transforms,
        sensitivity_maps=sensitivity_maps,
        pass_mask=False,
        text_description=text_description,
        kspace_context=kspace_context,
        **kwargs,
    )

    logger.debug(f"Dataset:\n{dataset}")

    return dataset
Code Example #6
File: engine.py, Project: jonasteuwen/direct
    def build_metrics(metrics_list) -> Dict:
        if not metrics_list:
            return {}

        # _metric is added as only keys containing loss or metric are logged.
        metrics_dict = {
            curr_metric + "_metric": str_to_class("direct.functionals",
                                                  curr_metric)
            for curr_metric in metrics_list
        }
        return metrics_dict
Code Example #7
    def _build_function_class(functions_list, root_module, postfix) -> Dict:
        if not functions_list:
            return {}

        # _postfix is added as only keys containing loss, metric or reg are logged.
        functions_dict = {
            curr_func.split("(")[0]
            + f"_{postfix}": str_to_class(root_module, curr_func)
            for curr_func in functions_list
        }
        return functions_dict
Code Example #8
def load_model_from_name(model_name):
    module_path = f"direct.nn.{'.'.join([_.lower() for _ in model_name.split('.')[:-1]])}"
    module_name = model_name.split(".")[-1]
    try:
        model = str_to_class(module_path, module_name)
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(
            f"Path {module_path} for model_name {module_name} does not exist (err = {e})."
        )
        sys.exit(-1)

    return model
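
The string handling in Example #8 maps a dotted model name onto a module path under direct.nn plus a class name. The standalone snippet below replays that mapping for a hypothetical name, "subfolder.MyModel", chosen only to illustrate the lowercasing and joining; it is not a model that ships with the project:

# Hypothetical model_name, used only to show how Example #8 builds its lookup path.
model_name = "subfolder.MyModel"
module_path = f"direct.nn.{'.'.join([part.lower() for part in model_name.split('.')[:-1]])}"
module_name = model_name.split(".")[-1]
print(module_path)  # direct.nn.subfolder
print(module_name)  # MyModel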
Code Example #9
def build_masking_function(name,
                           accelerations,
                           center_fractions=None,
                           uniform_range=False,
                           **kwargs):
    MaskFunc: BaseMaskFunc = str_to_class("direct.common.subsample",
                                          name + "MaskFunc")  # noqa
    mask_func = MaskFunc(
        accelerations=accelerations,
        center_fractions=center_fractions,
        uniform_range=uniform_range,
    )

    return mask_func
Code Example #10
def build_masking_functions(name,
                            center_fractions,
                            accelerations,
                            uniform_range=False,
                            val_center_fractions=None,
                            val_accelerations=None):

    MaskFunc: BaseMaskFunc = str_to_class('direct.common.subsample',
                                          name + 'MaskFunc')  # noqa

    train_mask_func = MaskFunc(accelerations,
                               center_fractions,
                               uniform_range=uniform_range)

    val_center_fractions = center_fractions if not val_center_fractions else val_center_fractions
    val_accelerations = accelerations if not val_accelerations else val_accelerations
    val_mask_func = MaskFunc(val_accelerations,
                             val_center_fractions,
                             uniform_range=False)

    return train_mask_func, val_mask_func
Code Example #11
def load_model_config_from_name(model_name):
    """
    Load the model-specific configuration class.

    Parameters
    ----------
    model_name : path to model relative to direct.nn

    Returns
    -------
    model configuration.
    """
    module_path = f"direct.nn.{model_name.split('.')[0].lower()}.config"
    model_name += "Config"
    config_name = model_name.split(".")[-1]
    try:
        model_cfg = str_to_class(module_path, config_name)
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(
            f"Path {module_path} for config_name {config_name} does not exist (err = {e})."
        )
        sys.exit(-1)
    return model_cfg
Code Example #12
File: train_rim.py, Project: mwacaan/direct
def setup_train(
    run_name,
    training_root,
    validation_root,
    base_directory,
    cfg_filename,
    force_validation,
    initialization_checkpoint,
    initial_images,
    initial_kspace,
    noise,
    device,
    num_workers,
    resume,
    machine_rank,
    mixed_precision,
    debug,
):

    env = setup_training_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )

    if initial_kspace is not None and initial_images is not None:
        raise ValueError(f"Cannot both provide initial kspace or initial images.")

    pass_dictionaries = {}
    if noise is not None:
        if not env.cfg.physics.use_noise_matrix:
            raise ValueError(
                "cfg.physics.use_noise_matrix is not set, yet noise files were passed on the command line."
            )

        noise = [read_json(fn) for fn in noise]
        pass_dictionaries["loglikelihood_scaling"] = [
            parse_noise_dict(
                _, percentile=0.999, multiplier=env.cfg.physics.noise_matrix_scaling
            )
            for _ in noise
        ]

    # Create training and validation data
    # Transforms configuration
    # TODO: More ** passing...

    training_datasets = build_training_datasets_from_environment(
        env=env,
        datasets_config=env.cfg.training.datasets,
        lists_root=cfg_filename.parents[0],
        data_root=training_root,
        initial_images=None if initial_images is None else initial_images[0],
        initial_kspaces=None if initial_kspace is None else initial_kspace[0],
        pass_text_description=False,
        pass_dictionaries=pass_dictionaries,
    )
    training_data_sizes = [len(_) for _ in training_datasets]
    logger.info(
        f"Training data sizes: {training_data_sizes} (sum={sum(training_data_sizes)})."
    )

    if validation_root:
        validation_data = build_training_datasets_from_environment(
            env=env,
            datasets_config=env.cfg.validation.datasets,
            lists_root=cfg_filename.parents[0],
            data_root=validation_root,
            initial_images=None if initial_images is None else initial_images[1],
            initial_kspaces=None if initial_kspace is None else initial_kspace[1],
            pass_text_description=True,
        )
    else:
        logger.info(f"No validation data.")
        validation_data = None

    # Create the optimizers
    logger.info("Building optimizers.")
    optimizer_params = [{"params": env.engine.model.parameters()}]
    for curr_model_name in env.engine.models:
        # TODO(jt): Can get learning rate from the config per additional model too.
        curr_learning_rate = env.cfg.training.lr
        logger.info(
            f"Adding model parameters of {curr_model_name} with learning rate {curr_learning_rate}."
        )
        optimizer_params.append(
            {
                "params": env.engine.models[curr_model_name].parameters(),
                "lr": curr_learning_rate,
            }
        )

    optimizer: torch.optim.Optimizer = str_to_class(
        "torch.optim", env.cfg.training.optimizer
    )(  # noqa
        optimizer_params,
        lr=env.cfg.training.lr,
        weight_decay=env.cfg.training.weight_decay,
    )  # noqa

    # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
    solver_steps = list(
        range(
            env.cfg.training.lr_step_size,
            env.cfg.training.num_iterations,
            env.cfg.training.lr_step_size,
        )
    )
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        solver_steps,
        env.cfg.training.lr_gamma,
        warmup_factor=1 / 3.0,
        warmup_iterations=env.cfg.training.lr_warmup_iter,
        warmup_method="linear",
    )

    # Just to make sure.
    torch.cuda.empty_cache()

    env.engine.train(
        optimizer,
        lr_scheduler,
        training_datasets,
        env.experiment_dir,
        validation_datasets=validation_data,
        resume=resume,
        initialization=initialization_checkpoint,
        start_with_validation=force_validation,
        num_workers=num_workers,
    )
Code Example #13
def load_dataset_config(dataset):
    dataset_name = dataset.name
    dataset_config = str_to_class("direct.data.datasets_config",
                                  dataset_name + "Config")
    return dataset_config
Code Example #14
File: train_rim.py, Project: jonasteuwen/direct
def setup_train(
    run_name,
    training_root,
    validation_root,
    base_directory,
    cfg_filename,
    checkpoint,
    device,
    num_workers,
    resume,
    machine_rank,
    mixed_precision,
    debug,
):

    env = setup_training_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )

    # Create training and validation data
    # Transforms configuration
    training_datasets = build_dataset_from_environment(
        env=env,
        datasets_config=env.cfg.training.datasets,
        lists_root=cfg_filename.parents[0],
        data_root=training_root,
        type_data="training",
    )
    training_data_sizes = [len(_) for _ in training_datasets]
    logger.info(
        f"Training data sizes: {training_data_sizes} (sum={sum(training_data_sizes)})."
    )

    if validation_root:
        validation_data = build_dataset_from_environment(
            env=env,
            datasets_config=env.cfg.validation.datasets,
            lists_root=cfg_filename.parents[0],
            data_root=validation_root,
            type_data="validation",
        )
    else:
        logger.info(f"No validation data.")
        validation_data = None

    # Create the optimizers
    logger.info("Building optimizers.")
    optimizer_params = [{"params": env.engine.model.parameters()}]
    for curr_model_name in env.engine.models:
        # TODO(jt): Can get learning rate from the config per additional model too.
        curr_learning_rate = env.cfg.training.lr
        logger.info(
            f"Adding model parameters of {curr_model_name} with learning rate {curr_learning_rate}."
        )
        optimizer_params.append({
            "params":
            env.engine.models[curr_model_name].parameters(),
            "lr":
            curr_learning_rate,
        })

    optimizer: torch.optim.Optimizer = str_to_class(
        "torch.optim", env.cfg.training.optimizer)(  # noqa
            optimizer_params,
            lr=env.cfg.training.lr,
            weight_decay=env.cfg.training.weight_decay,
        )  # noqa

    # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
    solver_steps = list(
        range(
            env.cfg.training.lr_step_size,
            env.cfg.training.num_iterations,
            env.cfg.training.lr_step_size,
        ))
    lr_scheduler = WarmupMultiStepLR(
        optimizer,
        solver_steps,
        env.cfg.training.lr_gamma,
        warmup_factor=1 / 3.0,
        warmup_iterations=env.cfg.training.lr_warmup_iter,
        warmup_method="linear",
    )

    # Just to make sure.
    torch.cuda.empty_cache()

    env.engine.train(
        optimizer,
        lr_scheduler,
        training_datasets,
        env.experiment_dir,
        validation_data=validation_data,
        resume=resume,
        initialization=checkpoint,
        num_workers=num_workers,
    )
Code Example #15
    def __init__(
        self,
        num_channels=2,
        num_classes=1000,
        width_mult=1.0,
        inverted_residual_setting=None,
        round_nearest=8,
        block=None,
        norm_layer: Callable[..., Any] = None,
    ):
        """
        MobileNet V2 main class

        Parameters
        ----------
        num_channels : int
            Number of channels.
        num_classes : int
            Number of classes.
        width_mult : float
            Width multiplier - adjusts number of channels in each layer by this amount.
        inverted_residual_setting : Network structure
        round_nearest : int
            Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        block : nn.Module
            Module class specifying the inverted residual building block for MobileNet.
        norm_layer : str
            Module specifying the normalization layer to use.
        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        else:
            module_name = ".".join(str(norm_layer).split(".")[:-1])
            norm_layer = str_to_class(f"torch.{module_name}", str(norm_layer).split(".")[-1])

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                "inverted_residual_setting should be non-empty "
                f"and each element should be a 4-element list, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(num_channels, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    block(
                        input_channel,
                        output_channel,
                        stride,
                        expand_ratio=t,
                        norm_layer=norm_layer,
                    )
                )
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
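
In Example #15 above, a norm_layer supplied as a dotted string is split into a module suffix and a class name before being resolved under torch. The standalone lines below replay that split for "nn.GroupNorm", a value assumed here only for illustration:

# "nn.GroupNorm" is an assumed example value; the split mirrors the norm_layer branch above.
norm_layer = "nn.GroupNorm"
module_name = ".".join(str(norm_layer).split(".")[:-1])  # "nn"
class_name = str(norm_layer).split(".")[-1]              # "GroupNorm"
# str_to_class(f"torch.{module_name}", class_name) would then return torch.nn.GroupNorm.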
Code Example #16
File: train_rim.py, Project: rmsouza01/direct
def setup(run_name, training_root, validation_root, base_directory,
          cfg_filename, device, num_workers, resume, machine_rank):
    experiment_dir = base_directory / run_name

    if communication.get_local_rank() == 0:
        # Want to prevent multiple workers from trying to write a directory
        # This is required in the logging below
        experiment_dir.mkdir(parents=True, exist_ok=True)
    communication.synchronize()  # Ensure folders are in place.

    # Load configs from YAML file to check which model needs to be loaded.
    cfg_from_file = OmegaConf.load(cfg_filename)
    model_name = cfg_from_file.model_name + 'Config'
    try:
        model_cfg = str_to_class(f'direct.nn.{cfg_from_file.model_name.lower()}.config', model_name)
    except (AttributeError, ModuleNotFoundError) as e:
        logger.error(f'Model configuration does not exist for {cfg_from_file.model_name} (err = {e}).')
        sys.exit(-1)

    # Load the default configs to ensure type safety
    base_cfg = OmegaConf.structured(DefaultConfig)
    base_cfg = OmegaConf.merge(base_cfg, {'model': model_cfg, 'training': TrainingConfig()})
    cfg = OmegaConf.merge(base_cfg, cfg_from_file)

    # Setup logging
    log_file = experiment_dir / f'log_{machine_rank}_{communication.get_local_rank()}.txt'
    direct.utils.logging.setup(
        use_stdout=communication.get_local_rank() == 0 or cfg.debug,
        filename=log_file,
        log_level=('INFO' if not cfg.debug else 'DEBUG')
    )
    logger.info(f'Machine rank: {machine_rank}.')
    logger.info(f'Local rank: {communication.get_local_rank()}.')
    logger.info(f'Logging: {log_file}.')
    logger.info(f'Saving to: {experiment_dir}.')
    logger.info(f'Run name: {run_name}.')
    logger.info(f'Config file: {cfg_filename}.')
    logger.info(f'Python version: {sys.version}.')
    logger.info(f'PyTorch version: {torch.__version__}.')  # noqa
    logger.info(f'CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.')
    logger.info(f'Configuration: {pformat(dict(cfg))}.')

    # Create the model
    logger.info('Building model.')
    model = MRIReconstruction(2, **cfg.model).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(f'Number of parameters: {n_params} ({n_params / 10.0**3:.2f}k).')
    logger.debug(model)

    # Create training and validation data
    train_mask_func, val_mask_func = build_masking_functions(**cfg.masking)
    train_transforms, val_transforms = build_mri_transforms(
        train_mask_func, val_mask_func=val_mask_func, crop=cfg.dataset.transforms.crop)

    training_data, validation_data = build_datasets(
        cfg.dataset.name, training_root, train_sensitivity_maps=None, train_transforms=train_transforms,
        validation_root=validation_root, val_sensitivity_maps=None, val_transforms=val_transforms)

    # Create the optimizers
    logger.info('Building optimizers.')
    optimizer: torch.optim.Optimizer = str_to_class('torch.optim', cfg.training.optimizer)(  # noqa
        model.parameters(), lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
    )  # noqa

    # Build the LR scheduler, we use a fixed LR schedule step size, no adaptive training schedule.
    solver_steps = list(range(cfg.training.lr_step_size, cfg.training.num_iterations, cfg.training.lr_step_size))
    lr_scheduler = WarmupMultiStepLR(
        optimizer, solver_steps, cfg.training.lr_gamma, warmup_factor=1 / 3.,
        warmup_iters=cfg.training.lr_warmup_iter, warmup_method='linear')

    # Just to make sure.
    torch.cuda.empty_cache()

    # Setup training engine.
    engine = RIMEngine(cfg, model, device=device)

    engine.train(
        optimizer, lr_scheduler, training_data, experiment_dir,
        validation_data=validation_data, resume=resume, num_workers=num_workers)