def lr_find(launch: str, model_architecture: str, device: str, dataset_type: str, out_dp: str):
    """Find the optimal LR for training with the 1-cycle policy."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)

    # build the train dataset from either raw Nifti files or the pre-converted numpy dataset
    if dataset_type == 'nifti':
        train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['train'])
    elif dataset_type == 'numpy':
        ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
        train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['train'])
    else:
        raise ValueError(
            f"`dataset_type` must be one of ['nifti', 'numpy']. Passed '{dataset_type}'"
        )

    loss_func = METRICS_DICT['NegDiceLoss']
    train_loader = DataLoaderNoAugmentations(train_dataset, batch_size=4, to_shuffle=True)

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)
    pipeline.lr_find_and_store(loss_func=loss_func, train_loader=train_loader, out_dp=out_dp)
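# Usage sketch for `lr_find` (hypothetical values; the real entry point is whatever
# CLI wraps these functions). Runs the LR range test on the pre-converted numpy
# dataset and stores the results under `out_dp`:
#
#   lr_find(
#       launch='local',
#       model_architecture='unet',   # assumed architecture key
#       device='cuda:0',
#       dataset_type='numpy',
#       out_dp='results/lr_find',    # assumed output dir
#   )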
def create_numpy_dataset(launch: str, scans_dp: str, masks_dp: str, zoom_factor: float, output_dp: str):
    """Create a numpy dataset from the initial Nifti `.nii.gz` scans to speed up training."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    scans_dp = scans_dp or data_paths.scans_dp
    masks_dp = masks_dp or data_paths.masks_dp

    numpy_data_root_dp = data_paths.get_numpy_data_root_dp(zoom_factor=zoom_factor)
    output_dp = output_dp or numpy_data_root_dp

    ds = NiftiDataset(scans_dp, masks_dp)
    ds.store_as_numpy_dataset(output_dp, zoom_factor)
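# Usage sketch for `create_numpy_dataset` (assumed values). Converts the default
# Nifti scans/masks into a downscaled numpy dataset; passing falsy paths falls
# back to the defaults from `const.DataPaths()`:
#
#   create_numpy_dataset(
#       launch='local',
#       scans_dp=None,      # fall back to data_paths.scans_dp
#       masks_dp=None,      # fall back to data_paths.masks_dp
#       zoom_factor=0.25,   # assumed downscaling factor
#       output_dp=None,     # fall back to the default numpy data root
#   )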
def segment_scans(launch: str, model_architecture: str, device: str, checkpoint_fp: str,
                  scans_dp: str, subset: str, output_dp: str, postfix: str):
    """Segment Nifti `.nii.gz` scans with an already trained model stored in a `.pth` file."""
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)

    scans_dp = scans_dp or data_paths.scans_dp

    # restrict segmentation to the validation ids if requested
    ids_list = None
    if subset == 'validation':
        split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)
        ids_list = split['valid']

    pipeline.segment_scans(
        checkpoint_fp=checkpoint_fp, scans_dp=scans_dp, ids=ids_list,
        output_dp=output_dp, postfix=postfix
    )
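# Usage sketch for `segment_scans` (assumed checkpoint path and names). Segments
# only the validation scans from the train/valid split with a trained model:
#
#   segment_scans(
#       launch='local',
#       model_architecture='unet',                  # assumed architecture key
#       device='cuda:0',
#       checkpoint_fp='results/train/cp_best.pth',  # assumed checkpoint path
#       scans_dp=None,                              # fall back to data_paths.scans_dp
#       subset='validation',
#       output_dp='results/segmented',              # assumed output dir
#       postfix='pred',                             # assumed mask filename postfix
#   )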
def train(launch: str, model_architecture: str, device: str, dataset_type: str, apply_heavy_augs: bool,
          n_epochs: int, out_dp: str, max_batches: int, initial_checkpoint_fp: str):
    """Build and train the model. Heavy augs and warm start are supported."""
    loss_func = METRICS_DICT['NegDiceLoss']
    metrics = [
        METRICS_DICT['BCELoss'],
        METRICS_DICT['NegDiceLoss'],
        METRICS_DICT['FocalLoss'],
    ]

    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    split = utils.load_split_from_yaml(const.TRAIN_VALID_SPLIT_FP)

    # build train/valid datasets from either raw Nifti files or the pre-converted numpy dataset
    if dataset_type == 'nifti':
        train_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['train'])
        valid_dataset = NiftiDataset(data_paths.scans_dp, data_paths.masks_dp, split['valid'])
    elif dataset_type == 'numpy':
        ndp = const.NumpyDataPaths(data_paths.default_numpy_dataset_dp)
        train_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['train'])
        valid_dataset = NumpyDataset(ndp.scans_dp, ndp.masks_dp, ndp.shapes_fp, split['valid'])
    else:
        raise ValueError(
            f"`dataset_type` must be one of ['nifti', 'numpy']. Passed '{dataset_type}'"
        )

    # init train data loader
    if apply_heavy_augs:
        print('\nwill apply heavy augmentations for train images')
        # set different augmentation counts for hard and general cases
        ids_hard_train = utils.get_image_ids_with_hard_cases_in_train_set(
            const.HARD_CASES_MAPPING, const.TRAIN_VALID_SPLIT_FP)
        train_dataset.set_different_aug_cnt_for_two_subsets(1, ids_hard_train, 3)
        # init loader (augmentations are applied inside the dataset)
        train_loader = DataLoaderNoAugmentations(train_dataset, batch_size=4, to_shuffle=True)
    else:
        print('\nwill apply the same augmentations for all train images')
        train_loader = DataLoaderWithAugmentations(
            train_dataset, orig_img_per_batch=2, aug_cnt=1, to_shuffle=True)

    valid_loader = DataLoaderNoAugmentations(valid_dataset, batch_size=4, to_shuffle=False)

    device_t = torch.device(device)
    pipeline = Pipeline(model_architecture=model_architecture, device=device_t)

    pipeline.train(
        train_loader=train_loader, valid_loader=valid_loader, n_epochs=n_epochs,
        loss_func=loss_func, metrics=metrics, out_dp=out_dp,
        max_batches=max_batches, initial_checkpoint_fp=initial_checkpoint_fp)
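# Usage sketch for `train` (assumed values). Trains on the numpy dataset with heavy
# augmentations for hard cases and warm-starts from an existing checkpoint; pass
# `initial_checkpoint_fp=None` to train from scratch:
#
#   train(
#       launch='local',
#       model_architecture='unet',       # assumed architecture key
#       device='cuda:0',
#       dataset_type='numpy',
#       apply_heavy_augs=True,
#       n_epochs=20,                     # assumed
#       out_dp='results/train',          # assumed output dir
#       max_batches=None,                # assumption: None disables the per-epoch batch cap
#       initial_checkpoint_fp='results/train_prev/cp_best.pth',  # assumed warm-start checkpoint
#   )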
def main(launch):
    const.set_launch_type_env_var(launch == 'local')
    data_paths = const.DataPaths()

    add_raw_masks(data_paths.masks_raw_dp, f'{data_paths.root_dp}/masks_orientation_fixed_binary')