Example #1
# Imports assume torchvision's standard transforms; the original project may define its own.
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import (Compose, Normalize, RandomHorizontalFlip,
                                    RandomResizedCrop, RandomVerticalFlip,
                                    ToTensor)


class Config:  # class wrapper and name assumed; the original snippet only shows __init__
    def __init__(self):
        self.lr = 1e-5  # 0.007
        self.lr_gamma = 0.1
        self.momentum = 0.9
        self.weight_decay = 0.00004
        self.bn_mom = 0.0003
        self.power = 0.9
        self.gpus = 1
        self.batch_size = 2
        self.epochs = 10
        self.eval_display = 200
        self.display = 2
        self.num_classes = 21
        self.ckpt_step = 5000
        self.workers = 1
        self.distributed = True
        self.crop_height = 512
        self.crop_width = 512
        self.sampler = DistributedSampler

        self.log_dir = './log'
        self.log_name = 'deeplabv3+'
        self.transforms = Compose([
            RandomResizedCrop(size=(self.crop_height, self.crop_width)),
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            ToTensor(),
            Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225))
        ])
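For context, a minimal sketch of how a configuration object like this might be consumed. The placeholder TensorDataset, the distributed check, and the training-loop stub are assumptions for illustration, not part of the original example.

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

cfg = Config()

# Placeholder dataset: a real pipeline would use a segmentation dataset whose
# __getitem__ applies cfg.transforms to each sample.
dataset = TensorDataset(
    torch.zeros(8, 3, cfg.crop_height, cfg.crop_width),
    torch.zeros(8, cfg.crop_height, cfg.crop_width, dtype=torch.long))

# DistributedSampler requires torch.distributed.init_process_group() to have
# been called; fall back to plain shuffling otherwise.
use_dist = cfg.distributed and dist.is_available() and dist.is_initialized()
sampler = cfg.sampler(dataset) if use_dist else None
loader = DataLoader(dataset,
                    batch_size=cfg.batch_size,
                    sampler=sampler,
                    shuffle=(sampler is None),
                    num_workers=cfg.workers)

for images, masks in loader:
    pass  # forward/backward pass would go here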
Example #2
from torch.nn import CrossEntropyLoss
from torch.optim import SGD

# NetMode, ImagesSubfolder, ComposeTransforms and the transform classes below
# (RandomRotate, RandomSquaredCrop, RandomHorizontalFlip, RandomVerticalFlip,
# Transpose, ToTensor) are project-local; their import paths are not shown in
# the original snippet.


class Configuration:
    FOCAL_LOSS_INDICES = None
    CE_LOSS_INDICES = None
    BATCH_SIZE = 2
    CHECKPOINT = ""
    SAVE_FREQUENCY = 4
    CLASS_VALUE = -1
    CROP_SIZE = 256
    CUDA = True
    DATASET = {
        NetMode.TRAIN: "SmartRandomDataLoader",
        NetMode.VALIDATE: "DataLoaderCrop2D",
    }
    FOLDER_WITH_IMAGE_DATA = "/home/branislav/datasets/refuge"

    LEARNING_RATE = 1e-4
    LOSS = CrossEntropyLoss

    MODEL = "DeepLabV3p"
    NUM_CLASSES = 2
    NUM_WORKERS = 8
    NUMBER_OF_EPOCHS = 100
    OUTPUT = "ckpt"
    OUTPUT_FOLDER = "polyps"
    STRIDE = 0.5
    STRIDE_VAL = 0.5
    STRIDE_LIMIT = (1000, 0.5)  # THIS PREVENTS DATASET HALTING

    OPTIMALIZER = SGD
    VALIDATION_FREQUENCY = 1  # num epochs

    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4
    AUGMENTATION = ComposeTransforms([
        RandomRotate(0.6),
        RandomSquaredCrop(0.85),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        Transpose(),
        ToTensor()
    ])
    VAL_AUGMENTATION = ComposeTransforms([Transpose(), ToTensor()])
    PATH_TO_SAVED_SUBIMAGE_INFO = None  # FOLDER_WITH_IMAGE_DATA + "train/info.pkl"

    FOLDERS = {NetMode.TRAIN: "train", NetMode.VALIDATE: "train"}
    SUBFOLDERS = {
        ImagesSubfolder.IMAGES: "images/*.tif",
        ImagesSubfolder.MASKS: "mask/*.tif"
    }
    NUM_RANDOM_CROPS_PER_IMAGE = 4
    VISUALIZER = "VisualizationSaveImages"

    def serialize(self):
        """Collect every upper-case attribute whose value has a simple, printable type."""
        output = {}
        for key in list(filter(lambda x: x.isupper(), dir(self))):
            value = getattr(self, key)
            if any(
                    map(lambda type_: isinstance(value, type_),
                        [str, float, int, tuple, list, dict])):
                output[key] = str(value)
        return output

    def __str__(self):
        serialized = self.serialize()
        return "\n".join(
            [f"{key}: {value}" for key, value in serialized.items()])

    def process_mask(self, mask):
        """Binarize a mask: every non-zero label becomes foreground (1)."""
        mask[mask > 0] = 1
        return mask
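A minimal usage sketch, assuming the Configuration class above and its project-local dependencies are importable; the small numpy mask is made up for illustration.

import numpy as np

cfg = Configuration()

# __str__ uses serialize() to render every upper-case, simple-typed attribute
# as one "KEY: value" line.
print(cfg)

# process_mask() collapses all non-zero labels into a single foreground class.
mask = np.array([[0, 3],
                 [7, 0]])
print(cfg.process_mask(mask))  # -> [[0 1]
                               #     [1 0]]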
Example #3
import torch

# Compose, ToTensor, RandomScale, RandomCrop, RandomHorizontalFlip, Resize,
# ZeroPad, Normalize, LabelMap, StandardSegmentationDataset, base_voc,
# base_city and label_id_map_city are project-local joint image/label
# transforms and helpers; their import paths are not shown in the original
# snippet.


def init(batch_size_labeled, batch_size_pseudo, state, split, input_sizes,
         sets_id, std, mean, keep_scale, reverse_channels, data_set, valtiny,
         no_aug):
    # Return the data loader(s) required for the given state:
    # 0: pseudo labeling
    # 1: semi-supervised training
    # 2: fully-supervised training
    # 3: just testing

    # Strip the labeled-set division suffixes ('-r'/'-l') to get the unlabeled split name
    split_u = split.replace('-r', '').replace('-l', '')

    # Transformations (compatible with unlabeled/pseudo-labeled data).
    # NOTE: torchvision.transforms.Compose can't be used here because these
    # transforms operate on (image, label) pairs.
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            # RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomScale(min_scale=0.5, max_scale=1.5),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        if no_aug:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                Normalize(mean=mean, std=std)
            ])
        else:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                # RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomScale(min_scale=0.5, max_scale=1.5),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
        # transform_pseudo = Compose(
        #     [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
        #      Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        #      Normalize(mean=mean, std=std)])
        transform_test = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            # RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
            RandomScale(min_scale=0.5, max_scale=1.5),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std),
            LabelMap(label_id_map_city)
        ])
        if no_aug:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                Normalize(mean=mean, std=std)
            ])
        else:
            transform_train_pseudo = Compose([
                ToTensor(keep_scale=keep_scale,
                         reverse_channels=reverse_channels),
                # RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                RandomScale(min_scale=0.5, max_scale=1.5),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
        # transform_pseudo = Compose(
        #     [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
        #      Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        #      Normalize(mean=mean, std=std),
        #      LabelMap(label_id_map_city)])
        transform_test = Compose([
            ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
            Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
            Normalize(mean=mean, std=std),
            LabelMap(label_id_map_city)
        ])
    else:
        # Unknown dataset: fail fast instead of hitting a NameError on `workers` below.
        raise ValueError('Unsupported data_set: {}'.format(data_set))

    # Not the actual test set (i.e. this is the validation set)
    test_set = StandardSegmentationDataset(
        root=base,
        image_set='valtiny' if valtiny else 'val',
        transforms=transform_test,
        label_state=0,
        data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=batch_size_labeled +
                                             batch_size_pseudo,
                                             num_workers=workers,
                                             shuffle=False)

    # Testing
    if state == 3:
        return val_loader
    else:
        # Fully-supervised training
        if state == 2:
            labeled_set = StandardSegmentationDataset(
                root=base,
                image_set=(str(split) + '_labeled_' + str(sets_id)),
                transforms=transform_train,
                label_state=0,
                data_set=data_set)
            labeled_loader = torch.utils.data.DataLoader(
                dataset=labeled_set,
                batch_size=batch_size_labeled,
                num_workers=workers,
                shuffle=True)
            return labeled_loader, val_loader

        # Semi-supervised training
        elif state == 1:
            pseudo_labeled_set = StandardSegmentationDataset(
                root=base,
                data_set=data_set,
                mask_type='.npy',
                image_set=(str(split_u) + '_unlabeled_' + str(sets_id)),
                transforms=transform_train_pseudo,
                label_state=1)
            labeled_set = StandardSegmentationDataset(
                root=base,
                data_set=data_set,
                image_set=(str(split) + '_labeled_' + str(sets_id)),
                transforms=transform_train,
                label_state=0)
            pseudo_labeled_loader = torch.utils.data.DataLoader(
                dataset=pseudo_labeled_set,
                batch_size=batch_size_pseudo,
                num_workers=workers,
                shuffle=True)
            labeled_loader = torch.utils.data.DataLoader(
                dataset=labeled_set,
                batch_size=batch_size_labeled,
                num_workers=workers,
                shuffle=True)
            return labeled_loader, pseudo_labeled_loader, val_loader

        else:
            # Labeling
            unlabeled_set = StandardSegmentationDataset(
                root=base,
                data_set=data_set,
                mask_type='.npy',
                image_set=(str(split_u) + '_unlabeled_' + str(sets_id)),
                transforms=transform_test,
                label_state=2)
            unlabeled_loader = torch.utils.data.DataLoader(
                dataset=unlabeled_set,
                batch_size=batch_size_labeled,
                num_workers=workers,
                shuffle=False)
            return unlabeled_loader
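For orientation, a hypothetical call showing how init() selects its return value by state. Every argument value below is a placeholder for illustration; the split name, input sizes, and normalization statistics are not taken from the original snippet.

# state == 2: fully-supervised training -> (labeled_loader, val_loader)
# state == 1: semi-supervised training  -> (labeled_loader, pseudo_labeled_loader, val_loader)
# state == 0: pseudo labeling           -> unlabeled_loader
# state == 3: just testing              -> val_loader
labeled_loader, val_loader = init(
    batch_size_labeled=8,
    batch_size_pseudo=8,
    state=2,
    split='1-8-l',  # placeholder split name
    input_sizes=((321, 321), (505, 505), (505, 505)),
    sets_id=0,
    std=[0.229, 0.224, 0.225],
    mean=[0.485, 0.456, 0.406],
    keep_scale=False,
    reverse_channels=False,
    data_set='voc',
    valtiny=False,
    no_aug=False)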