def init_seg(input_sizes, std, mean, dataset, test_base=None, test_label_id_map=None, city_aug=0):
    if dataset == 'voc':
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes),
            Normalize(mean=mean, std=std)
        ])
    elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia':  # All the same size
        if city_aug == 2:  # ERFNet and ENet
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                LabelMap(test_label_id_map)
            ])
        elif city_aug == 1:  # City big
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                Normalize(mean=mean, std=std),
                LabelMap(test_label_id_map)
            ])
    else:
        raise ValueError

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=test_base, image_set='val', transforms=transform_test,
        data_set='city' if dataset == 'gtav' or dataset == 'synthia' else dataset)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1, num_workers=0, shuffle=False)

    # Testing
    return val_loader
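
# A minimal usage sketch for init_seg (not part of the repo; all values below are
# hypothetical -- the real test size, mean/std and dataset root come from the
# training configuration):
#
#     val_loader = init_seg(input_sizes=(512, 1024),
#                           std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406],
#                           dataset='city', test_base='/path/to/cityscapes',
#                           test_label_id_map=label_id_map_city, city_aug=1)
#     for images, targets in val_loader:  # batch size is fixed to 1 here
#         ...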


def init(batch_size, state, split, input_sizes, sets_id, std, mean, keep_scale, reverse_channels, data_set,
         valtiny, no_aug):
    # Return data_loaders/data_loader
    # depending on what the state is
    # 1: semi-supervised training
    # 2: fully-supervised training
    # 3: just testing

    # Transformations (compatible with unlabeled data/pseudo labeled data)
    # ! Can't use torchvision.transforms.Compose
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             ZeroPad(size=input_sizes[2]),
             Normalize(mean=mean, std=std)])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
    else:
        base = ''

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(root=base, image_set='valtiny' if valtiny else 'val',
                                            transforms=transform_test, label_state=0, data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size,
                                             num_workers=workers, shuffle=False)

    # Testing
    if state == 3:
        return val_loader
    else:
        # Fully-supervised training
        if state == 2:
            labeled_set = StandardSegmentationDataset(root=base, image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0, data_set=data_set)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=batch_size,
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, val_loader

        # Semi-supervised training
        elif state == 1:
            pseudo_labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                             image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                             transforms=transform_train_pseudo, label_state=1)
            reference_set = SegmentationLabelsDataset(root=base, data_set=data_set,
                                                      image_set=(str(split) + '_unlabeled_' + str(sets_id)))
            reference_loader = torch.utils.data.DataLoader(dataset=reference_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)
            unlabeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                        image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                        transforms=transform_pseudo, label_state=2)
            labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                      image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0)
            unlabeled_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)
            pseudo_labeled_loader = torch.utils.data.DataLoader(dataset=pseudo_labeled_set,
                                                                batch_size=int(batch_size / 2),
                                                                num_workers=workers, shuffle=True)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=int(batch_size / 2),
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader
        else:
            # Support unsupervised learning here if that's what you want
            raise ValueError
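
# A minimal usage sketch for the semi-supervised init above (not part of the repo; the
# split/sets_id values are hypothetical and must match how the *_labeled_*/*_unlabeled_*
# image-set lists were generated, and mean/std/input_sizes come from the configuration):
#
#     labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader = init(
#         batch_size=8, state=1, split='1-8', input_sizes=(321, 505, 505), sets_id=0,
#         std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406],
#         keep_scale=False, reverse_channels=False, data_set='voc',
#         valtiny=False, no_aug=False)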


def init(batch_size, state, input_sizes, std, mean, dataset, city_aug=0):
    # Return data_loaders
    # depending on what the state is
    # 1: just testing
    # otherwise: training

    # Transformations
    # ! Can't use torchvision.transforms.Compose
    if dataset == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(),
            RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia':  # All the same size
        if dataset == 'city':
            base = base_city
        elif dataset == 'gtav':
            base = base_gtav
        else:
            base = base_synthia
        outlier = False if dataset == 'city' else True  # GTAV has messed-up label IDs
        workers = 8
        if city_aug == 3:  # SYNTHIA & GTAV
            if dataset == 'gtav':
                transform_train = Compose([
                    ToTensor(),
                    Resize(size_label=input_sizes[1], size_image=input_sizes[1]),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_city, outlier=outlier)
                ])
            else:
                transform_train = Compose([
                    ToTensor(),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_synthia, outlier=outlier)
                ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 2:  # ERFNet
            transform_train = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[2]),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 1:  # City big
            transform_train = Compose([
                ToTensor(),
                RandomCrop(size=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        else:  # Standard city
            transform_train = Compose([
                ToTensor(),
                RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city, outlier=outlier)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
    else:
        raise ValueError

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=base_city if dataset == 'gtav' or dataset == 'synthia' else base,
        image_set='val', transforms=transform_test,
        data_set='city' if dataset == 'gtav' or dataset == 'synthia' else dataset)
    if city_aug == 1:
        val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=1,
                                                 num_workers=workers, shuffle=False)
    else:
        val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size,
                                                 num_workers=workers, shuffle=False)

    # Testing
    if state == 1:
        return val_loader
    else:
        # Training
        train_set = StandardSegmentationDataset(
            root=base, image_set='trainaug' if dataset == 'voc' else 'train',
            transforms=transform_train, data_set=dataset)
        train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size,
                                                   num_workers=workers, shuffle=True)
        return train_loader, val_loader
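
# A minimal usage sketch for the fully-supervised init above (not part of the repo; values
# are hypothetical -- input_sizes is roughly (train crop / min resize, max resize, test size)
# and mean/std should match the pretrained backbone's statistics):
#
#     # state != 1 -> training: returns (train_loader, val_loader)
#     train_loader, val_loader = init(batch_size=8, state=0, input_sizes=(321, 505, 505),
#                                     std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406],
#                                     dataset='voc', city_aug=0)
#     # state == 1 -> just testing: returns val_loader only
#     val_loader = init(batch_size=8, state=1, input_sizes=(321, 505, 505),
#                       std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406],
#                       dataset='voc', city_aug=0)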