def test_transforms(
        augmentations_type,
        crop_size,
        dataset_type,
        num_stages,
        shorter_side,
        low_scale,
        high_scale,
        img_mean=(0.5, 0.5, 0.5),
        img_std=(0.5, 0.5, 0.5),
        img_scale=1.0 / 255,
        ignore_label=255,
):
    """Check the transforms returned by `get_transforms` on a dummy sample.

    Every training-stage transform and the validation transform is applied to
    a freshly generated dummy image/mask pair, and the following is verified:

      * there is exactly one training transform per stage;
      * training outputs are cropped to `(crop_size, crop_size)`;
      * image and mask outputs are `torch.Tensor` instances;
      * no new segmentation classes appear in the mask, except `ignore_label`;
      * the validation transform keeps the spatial size and the set of classes.

    Args:
      augmentations_type : forwarded to `get_transforms`.
      crop_size : expected side of the square training crop; broadcast to
        `num_stages` before being forwarded.
      dataset_type : forwarded to `get_transforms` and used to pack/unpack
        the sample.
      num_stages : expected number of training transforms.
      shorter_side, low_scale, high_scale : broadcast to `num_stages` and
        forwarded to `get_transforms`.
      img_mean, img_std, img_scale : normalisation parameters forwarded to
        `get_transforms`.
      ignore_label : label allowed to appear in the transformed mask;
        forwarded to `get_transforms` as well.

    """
    # Forward the normalisation parameters and the ignore label instead of
    # hard-coding them, so the assertions below and the transforms under test
    # always agree on `ignore_label`.
    train_transforms, val_transforms = get_transforms(
        crop_size=broadcast(crop_size, num_stages),
        shorter_side=broadcast(shorter_side, num_stages),
        low_scale=broadcast(low_scale, num_stages),
        high_scale=broadcast(high_scale, num_stages),
        img_mean=img_mean,
        img_std=img_std,
        img_scale=img_scale,
        ignore_label=ignore_label,
        num_stages=num_stages,
        augmentations_type=augmentations_type,
        dataset_type=dataset_type,
    )
    assert len(train_transforms) == num_stages
    # The validation transform is checked last; `list(...)` keeps the
    # concatenation valid when the training transforms come back as a tuple.
    is_val_flags = [False] * num_stages + [True]
    all_transforms = list(train_transforms) + [val_transforms]
    for is_val, transform in zip(is_val_flags, all_transforms):
        image, mask = get_dummy_image_and_mask()
        sample = pack_sample(image=image, mask=mask, dataset_type=dataset_type)
        output = transform(*sample)
        image_output, mask_output = unpack_sample(sample=output,
                                                  dataset_type=dataset_type)
        # Test shape
        if not is_val:
            assert (image_output.shape[-2:] == mask_output.shape[-2:] ==
                    (crop_size, crop_size))
        # Test that the outputs are torch tensors
        assert isinstance(image_output, torch.Tensor)
        assert isinstance(mask_output, torch.Tensor)
        # Test that there are no new segmentation classes, except possibly
        # ignore_label
        uq_classes_before = np.unique(mask)
        uq_classes_after = np.unique(mask_output.numpy())
        assert (len(
            np.setdiff1d(uq_classes_after,
                         uq_classes_before.tolist() + [ignore_label])) == 0)
        if is_val:
            # Test that for validation transformation the output shape has
            # not changed
            assert (image_output.shape[-2:] == image.shape[:2] ==
                    mask_output.shape[-2:] == mask.shape[:2])
            # Test that there were no changes to the classes at all.
            # `np.array_equal` also copes with arrays of different lengths,
            # where an element-wise `==` would raise instead of failing the
            # assertion cleanly.
            assert np.array_equal(uq_classes_before, uq_classes_after)
# Example #2
# 0
def get_transforms(
    crop_size,
    shorter_side,
    low_scale,
    high_scale,
    img_mean,
    img_std,
    img_scale,
    ignore_label,
    num_stages,
    augmentations_type,
    dataset_type,
):
    """
    Build training and validation transformations with the chosen backend.

    The four per-stage parameters (crop_size, shorter_side, low_scale,
    high_scale) are passed through `broadcast` together with `num_stages`
    before being handed to the backend-specific factory.

    Args:

      crop_size : square crop to apply during the training.
      shorter_side : parameter of the shorter_side resize transformation.
      low_scale : lowest scale ratio for augmentations.
      high_scale : highest scale ratio for augmentations.
      img_mean : image mean.
      img_std : image standard deviation.
      img_scale : image scale.
      ignore_label : label to pad segmentation masks with.
      num_stages : number of training stages the per-stage parameters are
        broadcast to.
      augmentations_type (str) : "densetorch" or "albumentations".
      dataset_type (str) : forwarded unchanged to the backend factory.

    Returns:
      Whatever the selected backend factory returns (train_transforms,
      val_transforms).

    Raises:
      ValueError: if `augmentations_type` is not one of the known backends.

    """
    # Per-stage parameters are expanded one by one, in the same order.
    crop_size = broadcast(crop_size, num_stages)
    shorter_side = broadcast(shorter_side, num_stages)
    low_scale = broadcast(low_scale, num_stages)
    high_scale = broadcast(high_scale, num_stages)

    # Pick the backend factory; anything unknown is rejected up front.
    if augmentations_type == "densetorch":
        build_transforms = densetorch_transforms
    elif augmentations_type == "albumentations":
        build_transforms = albumentations_transforms
    else:
        raise ValueError(f"Unknown augmentations type {augmentations_type}")

    factory_kwargs = {
        "crop_size": crop_size,
        "shorter_side": shorter_side,
        "low_scale": low_scale,
        "high_scale": high_scale,
        "img_mean": img_mean,
        "img_std": img_std,
        "img_scale": img_scale,
        "ignore_label": ignore_label,
        "num_stages": num_stages,
        "dataset_type": dataset_type,
    }
    return build_transforms(**factory_kwargs)
# Example #3
# 0
def get_datasets(
    train_dir,
    val_dir,
    train_list_path,
    val_list_path,
    train_transforms,
    val_transforms,
    masks_names,
    dataset_type,
    stage_names,
    train_download,
    val_download,
):
    """Create the datasets using the backend named by `dataset_type`.

    `train_dir`, `train_list_path`, `train_download` and `stage_names` are
    passed through `broadcast` with the number of training transforms, so
    that there is one entry per training stage. All remaining arguments are
    forwarded untouched.

    Args:
      dataset_type (str) : "densetorch" or "torchvision".

    Returns:
      Whatever the selected backend factory returns.

    Raises:
      ValueError: if `dataset_type` is not one of the known backends.

    """
    # One training transform per stage defines the broadcast length.
    num_train_stages = len(train_transforms)
    train_dir = broadcast(train_dir, num_train_stages)
    train_list_path = broadcast(train_list_path, num_train_stages)
    train_download = broadcast(train_download, num_train_stages)
    stage_names = broadcast(stage_names, num_train_stages)

    if dataset_type not in ("densetorch", "torchvision"):
        raise ValueError(f"Unknown dataset type {dataset_type}")
    build_datasets = (densetorch_dataset
                      if dataset_type == "densetorch" else torchvision_dataset)

    # The backend factories are called positionally, in this exact order.
    positional_args = (
        train_dir,
        val_dir,
        train_list_path,
        val_list_path,
        train_transforms,
        val_transforms,
        masks_names,
        stage_names,
        train_download,
        val_download,
    )
    return build_datasets(*positional_args)
# Example #4
# 0
def get_arguments(argv=None):
    """Parse all the arguments provided from the CLI.

    Args:
      argv (list of str, optional) : argument strings to parse instead of the
        process command line. When None (the default), argparse reads
        `sys.argv[1:]`, exactly as before this parameter existed.

    Returns:
      argparse.Namespace in which every argument registered in the
      "stage-parser" group has been passed through `broadcast` together with
      `num_stages`.

    """
    parser = argparse.ArgumentParser(
        description="Arguments for Light-Weight-RefineNet Training Pipeline")

    # Common transformations
    parser.add_argument("--img-scale", type=float, default=1.0 / 255)
    parser.add_argument("--img-mean",
                        type=float,
                        nargs=3,
                        default=(0.485, 0.456, 0.406))
    parser.add_argument("--img-std",
                        type=float,
                        nargs=3,
                        default=(0.229, 0.224, 0.225))

    # Training augmentations
    parser.add_argument(
        "--augmentations-type",
        type=str,
        choices=["densetorch", "albumentations"],
        default="densetorch",
    )

    # Dataset
    parser.add_argument(
        "--val-list-path",
        type=str,
        default="./data/val.nyu",
    )
    parser.add_argument(
        "--val-dir",
        type=str,
        default="./datasets/nyud/",
    )
    parser.add_argument("--val-batch-size", type=int, default=1)

    # Optimisation
    parser.add_argument(
        "--enc-optim-type",
        type=str,
        default="sgd",
    )
    parser.add_argument(
        "--dec-optim-type",
        type=str,
        default="sgd",
    )
    parser.add_argument(
        "--enc-lr",
        type=float,
        default=5e-4,
    )
    parser.add_argument(
        "--dec-lr",
        type=float,
        default=5e-3,
    )
    parser.add_argument(
        "--enc-weight-decay",
        type=float,
        default=1e-5,
    )
    parser.add_argument(
        "--dec-weight-decay",
        type=float,
        default=1e-5,
    )
    parser.add_argument(
        "--enc-momentum",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--dec-momentum",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--enc-lr-gamma",
        type=float,
        default=0.5,
        help="Multiply lr_enc by this value after each stage.",
    )
    parser.add_argument(
        "--dec-lr-gamma",
        type=float,
        default=0.5,
        help="Multiply lr_dec by this value after each stage.",
    )
    parser.add_argument(
        "--enc-scheduler-type",
        type=str,
        choices=["poly", "multistep"],
        default="multistep",
    )
    parser.add_argument(
        "--dec-scheduler-type",
        type=str,
        choices=["poly", "multistep"],
        default="multistep",
    )
    parser.add_argument(
        "--ignore-label",
        type=int,
        default=255,
        help="Ignore this label in the training loss.",
    )
    parser.add_argument("--random-seed", type=int, default=42)

    # Training / validation setup
    parser.add_argument("--enc-backbone",
                        type=str,
                        choices=["50", "101", "152", "mbv2"],
                        default="50")
    parser.add_argument("--enc-pretrained",
                        type=int,
                        choices=[0, 1],
                        default=1)
    parser.add_argument(
        "--num-stages",
        type=int,
        default=3,
        help=
        "Number of training stages. All other arguments with nargs='+' must "
        "have the number of arguments equal to this value. Otherwise, the given "
        "arguments will be broadcasted to have the required length.",
    )
    parser.add_argument("--num-classes", type=int, default=40)
    parser.add_argument(
        "--dataset-type",
        type=str,
        default="densetorch",
        choices=["densetorch", "torchvision"],
    )
    parser.add_argument(
        "--val-download",
        type=int,
        choices=[0, 1],
        default=0,
        help="Only used if dataset_type == torchvision.",
    )

    # Checkpointing configuration
    parser.add_argument("--ckpt-dir", type=str, default="./checkpoints/")
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="./checkpoints/checkpoint.pth.tar",
        help="Path to the checkpoint file.",
    )

    # Arguments broadcastable across training stages
    stage_parser = parser.add_argument_group("stage-parser")
    stage_parser.add_argument("--crop-size",
                              type=int,
                              nargs="+",
                              default=(
                                  500,
                                  500,
                                  500,
                              ))
    stage_parser.add_argument("--shorter-side",
                              type=int,
                              nargs="+",
                              default=(
                                  350,
                                  350,
                                  350,
                              ))
    stage_parser.add_argument("--low-scale",
                              type=float,
                              nargs="+",
                              default=(
                                  0.5,
                                  0.5,
                                  0.5,
                              ))
    stage_parser.add_argument("--high-scale",
                              type=float,
                              nargs="+",
                              default=(
                                  2.0,
                                  2.0,
                                  2.0,
                              ))
    stage_parser.add_argument("--train-list-path",
                              type=str,
                              nargs="+",
                              default=("./data/train.nyu", ))
    stage_parser.add_argument("--train-dir",
                              type=str,
                              nargs="+",
                              default=("./datasets/nyud/", ))
    stage_parser.add_argument("--train-batch-size",
                              type=int,
                              nargs="+",
                              default=(
                                  6,
                                  6,
                                  6,
                              ))
    stage_parser.add_argument("--freeze-bn",
                              type=int,
                              choices=[0, 1],
                              nargs="+",
                              default=(
                                  1,
                                  1,
                                  1,
                              ))
    stage_parser.add_argument(
        "--epochs-per-stage",
        type=int,
        nargs="+",
        default=(100, 100, 100),
    )
    stage_parser.add_argument("--val-every",
                              type=int,
                              nargs="+",
                              default=(
                                  5,
                                  5,
                                  5,
                              ))
    # NOTE(review): the defaults of --stage-names and --train-download have
    # two entries while --num-stages defaults to 3; this relies on how
    # `broadcast` treats a length mismatch - confirm.
    stage_parser.add_argument(
        "--stage-names",
        type=str,
        nargs="+",
        choices=["SBD", "VOC"],
        default=(
            "SBD",
            "VOC",
        ),
        help="Only used if dataset_type == torchvision.",
    )
    stage_parser.add_argument(
        "--train-download",
        type=int,
        nargs="+",
        choices=[0, 1],
        default=(
            0,
            0,
        ),
        help="Only used if dataset_type == torchvision.",
    )
    stage_parser.add_argument(
        "--grad-norm",
        type=float,
        nargs="+",
        default=(0.0, ),
        help="If > 0.0, clip gradients' norm to this value.",
    )
    # `argv=None` makes argparse fall back to the process command line.
    args = parser.parse_args(argv)
    # Broadcast all arguments in stage-parser.
    # NOTE(review): `_group_actions` is a private argparse attribute; it is
    # used here to enumerate the actions registered on this group only.
    for group_action in stage_parser._group_actions:
        argument_name = group_action.dest
        setattr(
            args,
            argument_name,
            broadcast(getattr(args, argument_name), args.num_stages),
        )
    return args