Example #1
def lr_find(launch: str, model_architecture: str, device: str,
            dataset_type: str, out_dp: str):
    """Find optimal LR for training with 1-cycle policy."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)

    if dataset_type == 'nifti':
        train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp,
                                     split['train'])
    elif dataset_type == 'numpy':
        ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
        train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp,
                                     split['train'])
    else:
        raise ValueError(
            f"`dataset_type` must be one of ['nifti', 'numpy']; got '{dataset_type}'"
        )

    loss_func = METRICS_DICT['NegDiceLoss']
    train_loader = DataLoaderNoAugmentations(train_dataset,
                                             batch_size=4,
                                             to_shuffle=True)
    device_t = torch.device(device)

    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)
    pipeline.lr_find_and_store(loss_func=loss_func,
                               train_loader=train_loader,
                               out_dp=out_dp)
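
A minimal usage sketch of calling `lr_find` directly from Python, under the assumption that the module above is importable as-is; the architecture key, device string, and output directory below are placeholders, not values from the original source:

# Hypothetical invocation of the LR finder; every argument value here is illustrative.
lr_find(launch='local',              # use local data paths via const.set_launch_type_env_var
        model_architecture='unet',   # assumed architecture name; valid keys depend on Pipeline
        device='cuda:0',             # any torch.device string, e.g. 'cpu'
        dataset_type='numpy',        # 'nifti' or 'numpy'
        out_dp='results/lr_find')    # directory where the LR-finder results are stored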
Example #2
def create_numpy_dataset(launch: str, scans_dp: str, masks_dp: str,
                         zoom_factor: float, output_dp: str):
    """Create numpy dataset from initial Nifti `.nii.gz` scans to speedup the training."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    scans_dp = scans_dp or data_paths.scans_dp
    masks_dp = masks_dp or data_paths.masks_dp

    numpy_data_root_dp = data_paths.get_numpy_data_root_dp(
        zoom_factor=zoom_factor)
    output_dp = output_dp or numpy_data_root_dp

    ds = NiftiDataset(scans_dp, masks_dp)
    ds.store_as_numpy_dataset(output_dp, zoom_factor)
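
A rough sketch of how the conversion might be invoked, assuming `None` is an acceptable value for the optional paths (the function falls back to the `const.DataPaths` defaults); the zoom factor is illustrative:

# Hypothetical call; zoom_factor controls downsampling of scans before they are stored as numpy arrays.
create_numpy_dataset(launch='local',
                     scans_dp=None,      # falls back to data_paths.scans_dp
                     masks_dp=None,      # falls back to data_paths.masks_dp
                     zoom_factor=0.25,   # illustrative value
                     output_dp=None)     # falls back to the default numpy data root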
Example #3
def segment_scans(launch: str, model_architecture: str, device: str,
                  checkpoint_fp: str, scans_dp: str, subset: str,
                  output_dp: str, postfix: str):
    """Segment Nifti `.nii.gz` scans with already trained model stored in `.pth` file."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)

    scans_dp = scans_dp or data_paths.scans_dp

    ids_list = None
    if subset == 'validation':
        split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)
        ids_list = split['valid']

    pipeline.segment_scans(checkpoint_fp=checkpoint_fp,
                           scans_dp=scans_dp,
                           ids=ids_list,
                           output_dp=output_dp,
                           postfix=postfix)
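
A hedged sketch of segmenting the validation subset with a trained model; the checkpoint path, output directory, and postfix are hypothetical placeholders:

# Hypothetical invocation; only scans listed in split['valid'] are segmented when subset='validation'.
segment_scans(launch='local',
              model_architecture='unet',            # assumed architecture name
              device='cuda:0',
              checkpoint_fp='results/cp_best.pth',  # placeholder path to a trained .pth checkpoint
              scans_dp=None,                        # falls back to data_paths.scans_dp
              subset='validation',
              output_dp='results/segmented',        # placeholder output directory
              postfix='pred')                       # assumed suffix for the produced mask files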
Example #4
def train(launch: str, model_architecture: str, device: str, dataset_type: str,
          apply_heavy_augs: bool, n_epochs: int, out_dp: str, max_batches: int,
          initial_checkpoint_fp: str):
    """Build and train the model. Heavy augs and warm start are supported."""
    loss_func = METRICS_DICT['NegDiceLoss']
    metrics = [
        METRICS_DICT['BCELoss'], METRICS_DICT['NegDiceLoss'],
        METRICS_DICT['FocalLoss']
    ]

    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)

    if dataset_type == 'nifti':
        train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp,
                                     split['train'])
        valid_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp,
                                     split['valid'])
    elif dataset_type == 'numpy':
        ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
        train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp,
                                     split['train'])
        valid_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp,
                                     split['valid'])
    else:
        raise ValueError(
            f"`dataset_type` must be one of ['nifti', 'numpy']; got '{dataset_type}'"
        )

    # init train data loader
    if apply_heavy_augs:
        print('\nwill apply heavy augmentations for train images')

        # set different augmentations for hard and general cases
        ids_hard_train = utils.get_image_ids_with_hard_cases_in_train_set(
            const.HARD_CASES_MAPPING, const.TRAIN_VALID_SPLIT_FP)
        train_dataset.set_different_aug_cnt_for_two_subsets(
            1, ids_hard_train, 3)
        # init loader
        train_loader = DataLoaderNoAugmentations(train_dataset,
                                                 batch_size=4,
                                                 to_shuffle=True)
    else:
        print('\nwill apply the same augmentations for all train images')
        train_loader = DataLoaderWithAugmentations(train_dataset,
                                                   orig_img_per_batch=2,
                                                   aug_cnt=1,
                                                   to_shuffle=True)

    valid_loader = DataLoaderNoAugmentations(valid_dataset,
                                             batch_size=4,
                                             to_shuffle=False)

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)

    pipeline.train(train_loader=train_loader,
                   valid_loader=valid_loader,
                   n_epochs=n_epochs,
                   loss_func=loss_func,
                   metrics=metrics,
                   out_dp=out_dp,
                   max_batches=max_batches,
                   initial_checkpoint_fp=initial_checkpoint_fp)
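
A minimal sketch of a warm-start training run with heavy augmentations enabled; every value is a placeholder, and the `max_batches` semantics (a per-epoch cap) is an assumption, not confirmed by the source:

# Hypothetical training call; paths and hyperparameters are illustrative only.
train(launch='local',
      model_architecture='unet',                    # assumed architecture name
      device='cuda:0',
      dataset_type='numpy',                         # use the precomputed numpy dataset for speed
      apply_heavy_augs=True,                        # extra augmentations for hard cases
      n_epochs=30,                                  # illustrative value
      out_dp='results/train',                       # placeholder output directory
      max_batches=100,                              # assumed per-epoch batch cap
      initial_checkpoint_fp='results/cp_last.pth')  # placeholder checkpoint for the warm start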
Example #5
def main(launch):
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    add_raw_masks(data_paths.masks_raw_dp,
                  f'{data_paths.root_dp}/masks_orientation_fixed_binary')
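
For completeness, a hedged sketch of running this entry point; `add_raw_masks` and the directory layout come from the surrounding repository and are not shown here:

if __name__ == '__main__':
    # Hypothetical invocation; 'local' switches const.DataPaths to local directories.
    main(launch='local')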