import torch

# NOTE: Compose, ToTensor, ZeroPad, Resize, RandomResize, RandomCrop,
# RandomHorizontalFlip, RandomTranslation, Normalize, LabelMap,
# StandardSegmentationDataset, SegmentationLabelsDataset, and the base_* /
# label_id_map_* constants are assumed to come from the project's own
# transform/dataset modules (a custom Compose is required; see comments below).


def init_seg(input_sizes,
             std,
             mean,
             dataset,
             test_base=None,
             test_label_id_map=None,
             city_aug=0):

    if dataset == 'voc':
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes),
            Normalize(mean=mean, std=std)
        ])
    elif dataset in ('city', 'gtav', 'synthia'):  # All the same size
        if city_aug == 2:  # ERFNet and ENet
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                LabelMap(test_label_id_map)
            ])
        elif city_aug == 1:  # City big
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                Normalize(mean=mean, std=std),
                LabelMap(test_label_id_map)
            ])
        else:
            # Without this branch, transform_test would be unbound below
            raise ValueError('init_seg() only supports city_aug in (1, 2)')
    else:
        raise ValueError('Unsupported dataset: ' + str(dataset))

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=test_base,
        image_set='val',
        transforms=transform_test,
        data_set='city' if dataset in ('gtav', 'synthia') else dataset)

    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=1,
                                             num_workers=0,
                                             shuffle=False)

    # Testing
    return val_loader
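
# A minimal usage sketch for init_seg() (hypothetical values: the root path is
# a placeholder and mean/std are the common ImageNet statistics):
#
#   val_loader = init_seg(input_sizes=(512, 1024),
#                         std=[0.229, 0.224, 0.225],
#                         mean=[0.485, 0.456, 0.406],
#                         dataset='city',
#                         test_base='/path/to/cityscapes',
#                         test_label_id_map=label_id_map_city,
#                         city_aug=1)
#   for images, labels in val_loader:
#       ...  # evaluate
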
def init(batch_size, state, split, input_sizes, sets_id, std, mean, keep_scale, reverse_channels, data_set,
         valtiny, no_aug):
    # Returns data loaders (or a single data loader) depending on state:
    #   1: semi-supervised training
    #   2: fully-supervised training
    #   3: just testing

    # Transformations (compatible with unlabeled/pseudo-labeled data)
    # ! Can't use torchvision.transforms.Compose: these joint transforms take
    #   both the image and the label
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             ZeroPad(size=input_sizes[2]),
             Normalize(mean=mean, std=std)])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
    else:
        # Without this, workers and the transforms would be unbound below
        raise ValueError('Unsupported data_set: ' + str(data_set))

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(root=base, image_set='valtiny' if valtiny else 'val',
                                           transforms=transform_test, label_state=0, data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size,
                                             num_workers=workers, shuffle=False)

    # Testing
    if state == 3:
        return val_loader
    else:
        # Fully-supervised training
        if state == 2:
            labeled_set = StandardSegmentationDataset(root=base, image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0, data_set=data_set)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=batch_size,
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, val_loader

        # Semi-supervised training
        elif state == 1:
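            # label_state presumably encodes annotation type (inferred from
            # usage below): 0 = ground truth, 1 = pseudo labels, 2 = unlabeled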
            pseudo_labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                             image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                             transforms=transform_train_pseudo, label_state=1)
            reference_set = SegmentationLabelsDataset(root=base, image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                      data_set=data_set)
            reference_loader = torch.utils.data.DataLoader(dataset=reference_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)
            unlabeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                        image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                        transforms=transform_pseudo, label_state=2)
            labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                      image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0)

            unlabeled_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)

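            # Each of the next two loaders contributes half a batch, presumably
            # so a mixed labeled + pseudo-labeled batch totals batch_size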
            pseudo_labeled_loader = torch.utils.data.DataLoader(dataset=pseudo_labeled_set,
                                                                batch_size=batch_size // 2,
                                                                num_workers=workers, shuffle=True)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set,
                                                         batch_size=batch_size // 2,
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader

        else:
            # Support unsupervised learning here if that's what you want
            raise ValueError('Unsupported state: ' + str(state))
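
# A minimal usage sketch for the semi-supervised init() above (hypothetical
# values; input_sizes follows the (crop/min, max, test) convention used by the
# transforms, and split/sets_id select a pre-generated data split):
#
#   labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, \
#       reference_loader = init(batch_size=8, state=1, split=8,
#                               input_sizes=(321, 505, 505), sets_id=0,
#                               std=[0.229, 0.224, 0.225],
#                               mean=[0.485, 0.456, 0.406],
#                               keep_scale=False, reverse_channels=False,
#                               data_set='voc', valtiny=False, no_aug=False)
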

def init(batch_size, state, input_sizes, std, mean, dataset, city_aug=0):
    # Returns data loaders depending on state:
    #   1: just testing
    #   otherwise: training

    # Transformations
    # ! Can't use torchvision.transforms.Compose (joint image/label transforms)
    if dataset == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(),
            RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif dataset in ('city', 'gtav', 'synthia'):  # All the same size
        if dataset == 'city':
            base = base_city
        elif dataset == 'gtav':
            base = base_gtav
        else:
            base = base_synthia
        outlier = dataset != 'city'  # GTAV label IDs contain outliers
        workers = 8

        if city_aug == 3:  # SYNTHIA & GTAV
            if dataset == 'gtav':
                transform_train = Compose([
                    ToTensor(),
                    Resize(size_label=input_sizes[1],
                           size_image=input_sizes[1]),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_city, outlier=outlier)
                ])
            else:
                transform_train = Compose([
                    ToTensor(),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_synthia, outlier=outlier)
                ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 2:  # ERFNet
            transform_train = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[2]),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 1:  # City big
            transform_train = Compose([
                ToTensor(),
                RandomCrop(size=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        else:  # Standard city
            transform_train = Compose([
                ToTensor(),
                RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city, outlier=outlier)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
    else:
        raise ValueError('Unsupported dataset: ' + str(dataset))

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=base_city if dataset in ('gtav', 'synthia') else base,
        image_set='val',
        transforms=transform_test,
        data_set='city' if dataset in ('gtav', 'synthia') else dataset)
    # 'City big' (city_aug == 1) uses batch size 1 at evaluation
    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=1 if city_aug == 1 else batch_size,
                                             num_workers=workers,
                                             shuffle=False)

    # Testing
    if state == 1:
        return val_loader
    else:
        # Training
        train_set = StandardSegmentationDataset(
            root=base,
            image_set='trainaug' if dataset == 'voc' else 'train',
            transforms=transform_train,
            data_set=dataset)
        train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                   batch_size=batch_size,
                                                   num_workers=workers,
                                                   shuffle=True)
        return train_loader, val_loader
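
# A minimal usage sketch for the training/testing init() above (hypothetical
# values; mean/std are the common ImageNet statistics):
#
#   train_loader, val_loader = init(batch_size=8, state=2,
#                                   input_sizes=(321, 505, 505),
#                                   std=[0.229, 0.224, 0.225],
#                                   mean=[0.485, 0.456, 0.406],
#                                   dataset='voc', city_aug=0)
#   for images, labels in train_loader:
#       ...  # one training step per batch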