def test_transforms(
    augmentations_type,
    crop_size,
    dataset_type,
    num_stages,
    shorter_side,
    low_scale,
    high_scale,
    img_mean=(0.5, 0.5, 0.5),
    img_std=(0.5, 0.5, 0.5),
    img_scale=1.0 / 255,
    ignore_label=255,
):
    """Check the train/val transforms built by ``get_transforms`` on a dummy sample.

    Every per-stage training transform and the validation transform is applied
    to a freshly generated dummy image/mask pair; the outputs are checked for
    type, spatial shape and the set of segmentation classes they contain.

    Args:
        augmentations_type (str): forwarded to ``get_transforms``.
        crop_size (int): side of the square crop expected after a training
            transform.
        dataset_type (str): forwarded to ``get_transforms`` and to the sample
            pack/unpack helpers.
        num_stages (int): expected number of training transforms.
        shorter_side: broadcast to ``num_stages`` and forwarded.
        low_scale: broadcast to ``num_stages`` and forwarded.
        high_scale: broadcast to ``num_stages`` and forwarded.
        img_mean: forwarded to ``get_transforms``.
        img_std: forwarded to ``get_transforms``.
        img_scale: forwarded to ``get_transforms``.
        ignore_label (int): forwarded to ``get_transforms``; also the only
            class that may newly appear in a transformed mask.
    """
    train_transforms, val_transforms = get_transforms(
        crop_size=broadcast(crop_size, num_stages),
        shorter_side=broadcast(shorter_side, num_stages),
        low_scale=broadcast(low_scale, num_stages),
        high_scale=broadcast(high_scale, num_stages),
        # Forward the parameters rather than repeating their default literals,
        # so the transforms and the assertions below always agree.
        img_mean=img_mean,
        img_std=img_std,
        img_scale=img_scale,
        ignore_label=ignore_label,
        num_stages=num_stages,
        augmentations_type=augmentations_type,
        dataset_type=dataset_type,
    )
    assert len(train_transforms) == num_stages

    # Training transforms are flagged False, the trailing validation one True.
    for is_val, transform in zip(
        [False] * num_stages + [True], train_transforms + [val_transforms]
    ):
        image, mask = get_dummy_image_and_mask()
        sample = pack_sample(image=image, mask=mask, dataset_type=dataset_type)
        output = transform(*sample)
        image_output, mask_output = unpack_sample(
            sample=output, dataset_type=dataset_type
        )

        # Training outputs must be cropped to the requested square size.
        if not is_val:
            assert (
                image_output.shape[-2:]
                == mask_output.shape[-2:]
                == (crop_size, crop_size)
            )

        # Outputs must be torch tensors.
        assert isinstance(image_output, torch.Tensor)
        assert isinstance(mask_output, torch.Tensor)

        # No new segmentation classes may appear, except possibly ignore_label.
        uq_classes_before = np.unique(mask)
        uq_classes_after = np.unique(mask_output.numpy())
        unexpected_classes = np.setdiff1d(
            uq_classes_after, uq_classes_before.tolist() + [ignore_label]
        )
        assert len(unexpected_classes) == 0

        if is_val:
            # Validation must leave the spatial shape untouched.
            assert (
                image_output.shape[-2:]
                == image.shape[:2]
                == mask_output.shape[-2:]
                == mask.shape[:2]
            )
            # Validation must not change the set of classes at all.
            # array_equal also copes with arrays of different length, where an
            # element-wise `==` would not yield a per-element result.
            assert np.array_equal(uq_classes_before, uq_classes_after)
def get_transforms(
    crop_size,
    shorter_side,
    low_scale,
    high_scale,
    img_mean,
    img_std,
    img_scale,
    ignore_label,
    num_stages,
    augmentations_type,
    dataset_type,
):
    """Build training and validation transforms with the chosen backend.

    The per-stage parameters are first broadcast to ``num_stages`` and then
    handed, together with the remaining arguments, to the builder selected by
    ``augmentations_type``.

    Args:
      crop_size (int) : square crop to apply during the training.
      shorter_side (int) : parameter of the shorter_side resize transformation.
      low_scale (float) : lowest scale ratio for augmentations.
      high_scale (float) : highest scale ratio for augmentations.
      img_mean (list of float) : image mean.
      img_std (list of float) : image standard deviation.
      img_scale (float) : image scale.
      ignore_label (int) : label to pad segmentation masks with.
      num_stages (int): length the per-stage parameters are broadcast to.
      augmentations_type (str): "densetorch" or "albumentations".
      dataset_type (str): forwarded to the builder; needed to correctly wrap
                          transformations for the dataset in use.

    Returns:
      train_transforms, val_transforms

    Raises:
      ValueError: if ``augmentations_type`` is not a known backend.
    """
    crop_size = broadcast(crop_size, num_stages)
    shorter_side = broadcast(shorter_side, num_stages)
    low_scale = broadcast(low_scale, num_stages)
    high_scale = broadcast(high_scale, num_stages)

    use_densetorch = augmentations_type == "densetorch"
    if not use_densetorch and augmentations_type != "albumentations":
        raise ValueError(f"Unknown augmentations type {augmentations_type}")

    build_transforms = (
        densetorch_transforms if use_densetorch else albumentations_transforms
    )
    builder_kwargs = {
        "crop_size": crop_size,
        "shorter_side": shorter_side,
        "low_scale": low_scale,
        "high_scale": high_scale,
        "img_mean": img_mean,
        "img_std": img_std,
        "img_scale": img_scale,
        "ignore_label": ignore_label,
        "num_stages": num_stages,
        "dataset_type": dataset_type,
    }
    return build_transforms(**builder_kwargs)
def get_datasets(
    train_dir,
    val_dir,
    train_list_path,
    val_list_path,
    train_transforms,
    val_transforms,
    masks_names,
    dataset_type,
    stage_names,
    train_download,
    val_download,
):
    """Create the datasets through the builder matching ``dataset_type``.

    ``train_dir``, ``train_list_path``, ``train_download`` and ``stage_names``
    are broadcast to the number of training transforms before being passed on.

    Args:
      train_dir: training directory (or directories), broadcast per stage.
      val_dir: validation directory, forwarded as is.
      train_list_path: training list path(s), broadcast per stage.
      val_list_path: validation list path, forwarded as is.
      train_transforms: sized collection of training transforms; its length
                        sets the broadcast target.
      val_transforms: validation transforms, forwarded as is.
      masks_names: forwarded as is.
      dataset_type (str): "densetorch" or "torchvision".
      stage_names: broadcast per stage.
      train_download: broadcast per stage.
      val_download: forwarded as is.

    Returns:
      Whatever the selected dataset builder returns.

    Raises:
      ValueError: if ``dataset_type`` is not a known backend.
    """
    num_train_stages = len(train_transforms)
    train_dir, train_list_path, train_download, stage_names = (
        broadcast(per_stage_value, num_train_stages)
        for per_stage_value in (
            train_dir,
            train_list_path,
            train_download,
            stage_names,
        )
    )

    if dataset_type not in ("densetorch", "torchvision"):
        raise ValueError(f"Unknown dataset type {dataset_type}")

    build_datasets = (
        densetorch_dataset if dataset_type == "densetorch" else torchvision_dataset
    )
    # The builders are called positionally, so the order below is significant.
    return build_datasets(
        train_dir,
        val_dir,
        train_list_path,
        val_list_path,
        train_transforms,
        val_transforms,
        masks_names,
        stage_names,
        train_download,
        val_download,
    )
def get_arguments(argv=None):
    """Parse the training pipeline arguments.

    Args:
      argv (list of str, optional): argument strings to parse. When ``None``
          (the default) argparse reads them from ``sys.argv[1:]``, which is
          the behaviour of a plain CLI invocation.

    Returns:
      argparse.Namespace with all parsed values. Every argument registered in
      the "stage-parser" group is replaced by
      ``broadcast(value, args.num_stages)``.
    """
    parser = argparse.ArgumentParser(
        description="Arguments for Light-Weight-RefineNet Training Pipeline"
    )

    # Common transformations
    parser.add_argument("--img-scale", type=float, default=1.0 / 255)
    parser.add_argument(
        "--img-mean", type=float, nargs=3, default=(0.485, 0.456, 0.406)
    )
    parser.add_argument(
        "--img-std", type=float, nargs=3, default=(0.229, 0.224, 0.225)
    )

    # Training augmentations
    parser.add_argument(
        "--augmentations-type",
        type=str,
        choices=["densetorch", "albumentations"],
        default="densetorch",
    )

    # Dataset
    parser.add_argument("--val-list-path", type=str, default="./data/val.nyu")
    parser.add_argument("--val-dir", type=str, default="./datasets/nyud/")
    parser.add_argument("--val-batch-size", type=int, default=1)

    # Optimisation
    parser.add_argument("--enc-optim-type", type=str, default="sgd")
    parser.add_argument("--dec-optim-type", type=str, default="sgd")
    parser.add_argument("--enc-lr", type=float, default=5e-4)
    parser.add_argument("--dec-lr", type=float, default=5e-3)
    parser.add_argument("--enc-weight-decay", type=float, default=1e-5)
    parser.add_argument("--dec-weight-decay", type=float, default=1e-5)
    parser.add_argument("--enc-momentum", type=float, default=0.9)
    parser.add_argument("--dec-momentum", type=float, default=0.9)
    parser.add_argument(
        "--enc-lr-gamma",
        type=float,
        default=0.5,
        help="Multilpy lr_enc by this value after each stage.",
    )
    parser.add_argument(
        "--dec-lr-gamma",
        type=float,
        default=0.5,
        help="Multilpy lr_dec by this value after each stage.",
    )
    parser.add_argument(
        "--enc-scheduler-type",
        type=str,
        choices=["poly", "multistep"],
        default="multistep",
    )
    parser.add_argument(
        "--dec-scheduler-type",
        type=str,
        choices=["poly", "multistep"],
        default="multistep",
    )
    parser.add_argument(
        "--ignore-label",
        type=int,
        default=255,
        help="Ignore this label in the training loss.",
    )
    parser.add_argument("--random-seed", type=int, default=42)

    # Training / validation setup
    parser.add_argument(
        "--enc-backbone",
        type=str,
        choices=["50", "101", "152", "mbv2"],
        default="50",
    )
    parser.add_argument("--enc-pretrained", type=int, choices=[0, 1], default=1)
    parser.add_argument(
        "--num-stages",
        type=int,
        default=3,
        help="Number of training stages. All other arguments with nargs='+' must "
        "have the number of arguments equal to this value. Otherwise, the given "
        "arguments will be broadcasted to have the required length.",
    )
    parser.add_argument("--num-classes", type=int, default=40)
    parser.add_argument(
        "--dataset-type",
        type=str,
        default="densetorch",
        choices=["densetorch", "torchvision"],
    )
    parser.add_argument(
        "--val-download",
        type=int,
        choices=[0, 1],
        default=0,
        help="Only used if dataset_type == torchvision.",
    )

    # Checkpointing configuration
    parser.add_argument("--ckpt-dir", type=str, default="./checkpoints/")
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="./checkpoints/checkpoint.pth.tar",
        help="Path to the checkpoint file.",
    )

    # Arguments broadcastable across training stages
    stage_parser = parser.add_argument_group("stage-parser")
    stage_parser.add_argument(
        "--crop-size", type=int, nargs="+", default=(500, 500, 500)
    )
    stage_parser.add_argument(
        "--shorter-side", type=int, nargs="+", default=(350, 350, 350)
    )
    stage_parser.add_argument(
        "--low-scale", type=float, nargs="+", default=(0.5, 0.5, 0.5)
    )
    stage_parser.add_argument(
        "--high-scale", type=float, nargs="+", default=(2.0, 2.0, 2.0)
    )
    stage_parser.add_argument(
        "--train-list-path", type=str, nargs="+", default=("./data/train.nyu",)
    )
    stage_parser.add_argument(
        "--train-dir", type=str, nargs="+", default=("./datasets/nyud/",)
    )
    stage_parser.add_argument(
        "--train-batch-size", type=int, nargs="+", default=(6, 6, 6)
    )
    stage_parser.add_argument(
        "--freeze-bn", type=int, choices=[0, 1], nargs="+", default=(1, 1, 1)
    )
    stage_parser.add_argument(
        "--epochs-per-stage", type=int, nargs="+", default=(100, 100, 100)
    )
    stage_parser.add_argument(
        "--val-every", type=int, nargs="+", default=(5, 5, 5)
    )
    stage_parser.add_argument(
        "--stage-names",
        type=str,
        nargs="+",
        choices=["SBD", "VOC"],
        default=("SBD", "VOC"),
        help="Only used if dataset_type == torchvision.",
    )
    stage_parser.add_argument(
        "--train-download",
        type=int,
        nargs="+",
        choices=[0, 1],
        default=(0, 0),
        help="Only used if dataset_type == torchvision.",
    )
    stage_parser.add_argument(
        "--grad-norm",
        type=float,
        nargs="+",
        default=(0.0,),
        help="If > 0.0, clip gradients' norm to this value.",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Broadcast all arguments in stage-parser.
    # NOTE: `_group_actions` is a private argparse attribute; it holds the
    # actions registered on this group only.
    for group_action in stage_parser._group_actions:
        argument_name = group_action.dest
        setattr(
            args,
            argument_name,
            broadcast(getattr(args, argument_name), args.num_stages),
        )
    return args