Beispiel #1
0
 def _load_voc(self):
     """Build the Pascal VOC segmentation datasets used for training and testing.

     Both splits are stored under ``.voc`` and downloaded if needed. The
     'train' image set feeds the ``'train'`` entry and the 'val' image set
     feeds the ``'test'`` entry. Each gets the matching joint transform from
     ``get_transform``.

     Returns:
         dict with keys ``'train'`` and ``'test'`` mapping to the datasets.
     """
     joint_transforms = get_transform(513, 0, 0, 'voc')
     # (result key, VOC image_set) — train is built first, as before.
     split_plan = (('train', 'train'), ('test', 'val'))
     return {
         key: VOCSegmentation(root='.voc', image_set=image_set, download=True,
                              transforms=joint_transforms[key])
         for key, image_set in split_plan
     }
Beispiel #2
0
def load_data(datadir):
    """Create the Pascal VOC segmentation train/val datasets and their samplers.

    The training split gets random resize, crop, horizontal flip and colour
    jitter. The validation split only goes through ``RandomResize(base_size,
    base_size)``. Both are converted to tensors and normalised with the
    ImageNet mean/std. Loading time for each split is printed.

    Args:
        datadir: root directory passed to ``VOCSegmentation`` (downloaded if needed).

    Returns:
        Tuple ``(dataset, dataset_test, train_sampler, test_sampler)``. The train
        sampler is a ``RandomSampler`` and the test sampler a ``SequentialSampler``.
    """
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    base_size, crop_size = 320, 256
    min_size, max_size = (int(factor * base_size) for factor in (0.5, 2.0))

    print("Loading training data")
    started = time.time()
    train_pipeline = Compose([
        RandomResize(min_size, max_size),
        RandomCrop(crop_size),
        RandomHorizontalFlip(0.5),
        SampleTransform(transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.1, hue=0.02)),
        ToTensor(),
        SampleTransform(normalize),
    ])
    train_set = VOCSegmentation(datadir, image_set='train', download=True,
                                transforms=train_pipeline)
    print("Took", time.time() - started)

    print("Loading validation data")
    started = time.time()
    val_pipeline = Compose([
        RandomResize(base_size, base_size),
        ToTensor(),
        SampleTransform(normalize),
    ])
    val_set = VOCSegmentation(datadir, image_set='val', download=True,
                              transforms=val_pipeline)
    print("Took", time.time() - started)

    print("Creating data loaders")
    return (
        train_set,
        val_set,
        torch.utils.data.RandomSampler(train_set),
        torch.utils.data.SequentialSampler(val_set),
    )
Beispiel #3
0
    def __init__(
        self,
        train: bool = True,
        preproc: bool = False,
        augmentation: bool = False,
        cache: bool = False,
        size: Tuple[int, int] = (192, 256),
        mode: str = "seg20",
        onehot: bool = False,
    ):
        """Set up the VOC 2012 segmentation wrapper for one split.

        Args:
            train: use the 'train' image set when true, otherwise 'val'.
            preproc: call ``self._load_transforms()`` after the dataset is created.
            augmentation: when ``cache`` is on, selects whose ``__getitem__`` is
                memoised: the wrapped ``self.D`` if true, this object otherwise.
            cache: memoise item access with an unbounded ``lru_cache``.
            size: stored as ``self.size`` (presumably the output (H, W) — confirm).
            mode: must be one of ``self.MODES``.
            onehot: stored as ``self.onehot``.
        """
        assert mode in self.MODES, f"Mode {mode} not one of {','.join(self.MODES)}"

        # Store configuration first: ``self.classes`` / ``self.path`` are read
        # below and may be derived from it (e.g. from ``self.mode``).
        self.cache, self.train, self.preproc = cache, train, preproc
        self.augmentation, self.size = augmentation, size
        self.mode, self.onehot = mode, onehot
        self.n_classes = len(self.classes)

        self.D = VOCSegmentation(
            root=self.path,
            year="2012",
            image_set="train" if self.train else "val",
        )

        if self.preproc:
            self._load_transforms()

        if not self.cache:
            return

        # With augmentation only the raw dataset access is memoised
        # (presumably because augmented samples vary per call); otherwise
        # whole samples of this wrapper are memoised.
        # NOTE(review): assigning ``__getitem__`` on an *instance* does not
        # affect ``obj[i]`` indexing — Python looks special methods up on the
        # type — so the cache only applies where ``__getitem__`` is called
        # explicitly. Verify how callers index.
        owner = self.D if self.augmentation else self
        owner.__getitem__ = lru_cache(maxsize=None)(owner.__getitem__)
def main():
    """Dump corrupted Pascal VOC validation images, or their labels, under ``VOC-C/``.

    With ``--cn -1`` or ``--sv -1`` only the labels are written, to
    ``VOC-C/lbl/NNNN.png``. Otherwise every corrupted image is written to
    ``VOC-C/<corruption name>/<severity>/NNNN.png``.
    """
    parser = argparse.ArgumentParser(description='Dump voc c')
    parser.add_argument('--cn', type=int, default=4, metavar='N',
                        help='Corruption Number')
    parser.add_argument('--sv', type=int, default=1, metavar='N',
                        help='Severity')

    args = parser.parse_args()
    sv = args.sv
    dump_labels = args.cn == -1 or sv == -1

    if dump_labels:
        # makedirs also creates the missing 'VOC-C' parent, which os.mkdir could not.
        os.makedirs('VOC-C/lbl', exist_ok=True)
        out_dir = None
    else:
        # Only resolve the corruption (and create its folders) when images are
        # actually dumped. Previously cn == -1 silently indexed the last
        # corruption and created unused folders.
        corruption_name = corruption_tuple[args.cn].__name__
        out_dir = 'VOC-C/{}/{}'.format(corruption_name, sv)
        os.makedirs(out_dir, exist_ok=True)

    corr_val = VOCSegmentation(root='/data/datasets/',
                               transforms=ImLblCorruptTransform(sv, args.cn),
                               image_set='val')
    for n, (im, lbl) in enumerate(tqdm(corr_val)):
        if dump_labels:
            lbl.save('VOC-C/lbl/{:04d}.png'.format(n))
        else:
            save_image(im, '{}/{:04d}.png'.format(out_dir, n))
    def download_segmentation_masks():
        """Collect small binary masks from PASCAL VOC 2012 segmentation (trainval).

        Each mask is binarised (every value > 0 becomes 1.). It is kept only when
        its mean, i.e. the fraction of set pixels, is at most 0.25.

        Returns:
            list of mask tensors, in dataset order.
        """
        # The images are resized to 1 px on the short side because only the
        # masks are used below.
        voc = VOCSegmentation('.data/',
                              image_set='trainval',
                              download=True,
                              transform=transforms.Compose(
                                  [transforms.Resize(1), transforms.ToTensor()]),
                              target_transform=transforms.Compose(
                                  [transforms.ToTensor()]))

        kept = []
        for _image, masks in tqdm(DataLoader(voc, batch_size=1),
                                  desc='Loading Segmentation Masks'):
            # batch_size=1, so take the single mask out of the batch.
            single = masks[0]
            # Per the original author, boundary and inner regions have
            # different non-zero values; collapse them all to 1.
            single[single > 0.] = 1.
            if torch.mean(single) <= 0.25:
                kept.append(single)
        return kept
Beispiel #6
0
    def __init__(self):
        """Build class-name lookups, load the VOC train split and set up transforms.

        ``id2name`` maps the position in ``self.class_names`` to the name, and
        ``name2id`` is the reverse mapping (the last index wins for duplicate
        names).
        """
        indexed_names = list(enumerate(self.class_names))
        self.id2name = {idx: name for idx, name in indexed_names}
        self.name2id = {name: idx for idx, name in indexed_names}

        self.train_dataset = VOCSegmentation('./', image_set='train')

        self.image_trans = transforms.Compose([ToTensor()])
        # Masks: convert with ToLabel, then relabel value 255 as 21.
        self.mask_trans = transforms.Compose([ToLabel(), Relabel(255, 21)])
        self.label_trans = transforms.Compose([ToTensor()])
        self.y_trans = transforms.Compose([ToTensor()])
    def __init__(self,
                 dataset_root,
                 split,
                 download=True,
                 integrity_check=True):
        """Combine Pascal VOC and SBD samples with click/scribble annotations.

        Args:
            dataset_root: directory containing the ``VOC`` and ``SBD`` sub-folders.
            split: ``SPLIT_TRAIN`` or ``SPLIT_VALID``.
            download: allow downloading VOC. SBD is downloaded only when this is
                true and ``SBD/img`` does not exist yet.
            integrity_check: if true, load every sample once via ``self.get`` and
                print the names of samples without background or foreground clicks.

        For the train split, samples come from SBD train, SBD val and VOC train.
        On duplicate names the later source wins. Every sample whose name is in
        VOC val is dropped. The valid split uses VOC val only.

        Raises:
            AssertionError: on an unknown ``split`` or a missing scribbles archive.
        """
        assert split in (SPLIT_TRAIN, SPLIT_VALID), f'Invalid split {split}'
        self.integrity_check = integrity_check

        root_voc = os.path.join(dataset_root, 'VOC')
        root_sbd = os.path.join(dataset_root, 'SBD')

        # VOC val is needed for both splits: it *is* the valid split, and it
        # is excluded from the train split below.
        self.ds_voc_valid = VOCSegmentation(root_voc,
                                            image_set=SPLIT_VALID,
                                            download=download)

        if split == SPLIT_TRAIN:
            # NOTE(review): download=False presumably relies on the VOC download
            # above also providing the train images — confirm.
            self.ds_voc_train = VOCSegmentation(root_voc,
                                                image_set=SPLIT_TRAIN,
                                                download=False)
            self.ds_sbd_train = SBDataset(
                root_sbd,
                image_set=SPLIT_TRAIN,
                download=download
                and not os.path.isdir(os.path.join(root_sbd, 'img')))
            self.ds_sbd_valid = SBDataset(root_sbd,
                                          image_set=SPLIT_VALID,
                                          download=False)

            # name -> (source dataset, index). Later sources override earlier
            # ones on duplicate names: SBD train < SBD val < VOC train.
            self.name_to_ds_id = {}
            for ds in (self.ds_sbd_train, self.ds_sbd_valid, self.ds_voc_train):
                self.name_to_ds_id.update({
                    self._sample_name(path): (ds, i)
                    for i, path in enumerate(ds.images)
                })
            # Never train on a sample that belongs to VOC val.
            for path in self.ds_voc_valid.images:
                self.name_to_ds_id.pop(self._sample_name(path), None)
        else:
            self.name_to_ds_id = {
                self._sample_name(path): (self.ds_voc_valid, i)
                for i, path in enumerate(self.ds_voc_valid.images)
            }

        self.sample_names = sorted(self.name_to_ds_id)
        self.transforms = None

        # Annotation files live next to this module.
        module_dir = os.path.dirname(__file__)
        path_points_fg = os.path.join(module_dir, 'voc_whats_the_point.json')
        path_points_bg = os.path.join(
            module_dir, 'voc_whats_the_point_bg_from_scribbles.json')
        with open(path_points_fg, 'r', encoding='utf-8') as f:
            self.ds_clicks_fg = json.load(f)
        with open(path_points_bg, 'r', encoding='utf-8') as f:
            self.ds_clicks_bg = json.load(f)
        self.ds_scribbles_path = os.path.join(module_dir, 'voc_scribbles.zip')
        assert os.path.isfile(
            self.ds_scribbles_path
        ), f'Scribbles not found at {self.ds_scribbles_path}'
        self.cls_name_to_id = {
            name: i
            for i, name in enumerate(self.semseg_class_names)
        }
        self._semseg_class_histogram = self._compute_histogram()

        if integrity_check:
            # Keep only the warning messages, not every loaded sample. The
            # messages are printed after the progress bar so they do not
            # interleave with it.
            problems = []
            for i in tqdm(range(len(self)), desc=f'Checking "{split}" split'):
                d = self.get(i)
                if d['num_clicks_bg'] == 0:
                    problems.append(f"{d['name']} has no background clicks")
                if d['num_clicks_fg'] == 0:
                    problems.append(f"{d['name']} has no foreground clicks")
            for message in problems:
                print(message)
            self.integrity_check = False
Beispiel #8
0
def main(args):
    """Train or evaluate a semantic segmentation model on Pascal VOC 2012.

    Depending on flags in ``args``, it does one of the following:
    - plots a batch of training samples;
    - plots predictions;
    - runs evaluation only;
    - runs the LR finder;
    - checks that a batch can be overfitted;
    - trains for ``args.epochs`` epochs, optionally logging to W&B.

    Args:
        args: parsed command-line namespace (data_path, batch_size, workers,
            source, arch, loss, opt, lr, ...; see the attribute reads below).

    Raises:
        ValueError: if ``args.source``, ``args.loss`` or ``args.opt`` is not a
            supported choice. Previously this surfaced later as a NameError.
        ValueError: if training samples or predictions are requested while
            ``args.test_only`` prevents the training loader from being built.
    """
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    interpolation_mode = InterpolationMode.BILINEAR

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size,
                                                     interpolation_mode),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        if train_loader is None:
            raise ValueError("show_samples needs the training set, "
                             "which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    # The LR finder and the setup check only need the training data.
    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size),
                                             interpolation_mode),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    num_classes = len(VOC_CLASSES)

    # Model setup
    source = args.source.lower()
    if source == 'holocron':
        model = segmentation.__dict__[args.arch](args.pretrained,
                                                 num_classes=num_classes)
    elif source == 'torchvision':
        model = tv_segmentation.__dict__[args.arch](
            args.pretrained, num_classes=num_classes)
    else:
        raise ValueError(f"Unsupported model source: {args.source}")

    # Loss setup
    loss_weight = None
    if isinstance(args.bg_factor, float) and args.bg_factor != 1:
        # Re-weight class index 0 (presumably background) against the others.
        loss_weight = torch.ones(num_classes)
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight,
                                        ignore_index=255,
                                        label_smoothing=args.label_smoothing)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255,
                                                  xi=3)
    else:
        raise ValueError(f"Unsupported loss: {args.loss}")

    # Optimizer setup
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.opt}")

    def log_wb(metrics):
        # Forward end-of-epoch metrics to W&B only when it is enabled.
        if args.wb:
            wandb.log(metrics)

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=num_classes,
                                  amp=args.amp,
                                  on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        if train_loader is None:
            raise ValueError("show_preds needs the training set, "
                             "which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        norm_weight_decay=args.norm_weight_decay,
                        num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-semantic-segmentation",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "source": args.source,
                             "input_size": crop_size,
                             "optimizer": args.opt,
                             "dataset": "Pascal VOC2012 Segmentation",
                             "loss": args.loss,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
Beispiel #9
0
    def __init__(self, args, val=False, query=False):
        """Wrap Pascal VOC 2012 (or the augmented set) and sample initial labelled pixels.

        Args:
            args: experiment config. Reads dir_root, experim_name, ignore_index,
                size_base, size_crop, stride_total, use_augmented_dataset,
                dir_augmented_dataset, dir_dataset, augmentations, mean, std,
                n_pixels_by_us, n_init_pixels and seed.
            val: use the 'val' split and skip the initial query generation.
            query: disable the random_scale, crop and random_hflip geometric
                augmentations for this dataset.

        When ``n_pixels_by_us != 0`` and ``val`` is false, a boolean query mask
        of the base size is drawn per image. It contains up to the initial
        per-image budget of randomly chosen non-void (!= 255) pixels. The list
        of masks is pickled to ``{dir_dataset}/init_labelled_pixels_{seed}.pkl``
        and to ``{dir_checkpoints}/0_query/label.pkl``.

        Raises:
            NotImplementedError: if queries are requested but the resolved
                per-image budget is not positive.
        """
        super(VOC2012Segmentation, self).__init__()
        self.dir_checkpoints = f"{args.dir_root}/checkpoints/{args.experim_name}"
        self.ignore_index = args.ignore_index
        self.size_base = args.size_base
        self.size_crop = (args.size_crop, args.size_crop)
        self.stride_total = args.stride_total

        if args.use_augmented_dataset and not val:
            self.voc = AugmentedVOC(args.dir_augmented_dataset)
        else:
            self.voc = VOCSegmentation(f"{args.dir_dataset}", image_set='val' if val else 'train', download=False)
        print("# images:", len(self.voc))

        self.geometric_augmentations = args.augmentations["geometric"]
        self.photometric_augmentations = args.augmentations["photometric"]
        self.normalize = Normalize(mean=args.mean, std=args.std)
        if query:
            # Copy before disabling. Otherwise the shared args.augmentations
            # dict, and every other dataset built from it, would lose these
            # augmentations too.
            self.geometric_augmentations = dict(self.geometric_augmentations)
            self.geometric_augmentations["random_scale"] = False
            self.geometric_augmentations["crop"] = False
            self.geometric_augmentations["random_hflip"] = False

        if self.geometric_augmentations["crop"]:
            self.mean = tuple((np.array(args.mean) * 255.0).astype(np.uint8).tolist())

        # generate initial queries
        n_pixels_per_img = args.n_pixels_by_us
        init_n_pixels = args.n_init_pixels if args.n_init_pixels > 0 else n_pixels_per_img

        self.queries, self.n_pixels_total = None, -1
        path_queries = f"{args.dir_dataset}/init_labelled_pixels_{args.seed}.pkl"
        if n_pixels_per_img != 0 and not val:
            os.makedirs(f"{self.dir_checkpoints}/0_query", exist_ok=True)
            if init_n_pixels <= 0:
                # Only a positive per-image pixel budget is supported.
                raise NotImplementedError
            n_pixels_total = 0

            list_queries = list()
            for i in tqdm(range(len(self.voc))):
                label = self.voc[i][1]
                w, h = label.size

                # Queries are generated at the base size (longer side).
                h, w = self._compute_base_size(h, w)

                # NEAREST keeps label ids intact when resizing.
                label = label.resize((w, h), Image.NEAREST)
                label_flatten = np.asarray(label, dtype=np.int32).flatten()

                # Candidate pixels: everything except void (255) pixels, e.g.
                # the object boundaries in the original labels.
                ind_non_void_pixels = np.flatnonzero(label_flatten != 255)

                # Per-image budget, capped when an image has too few non-void
                # pixels. A local is used so the cap cannot leak into the next
                # image. Previously a 0 cap made the next image fully labelled.
                n_pick = min(init_n_pixels, len(ind_non_void_pixels))
                ind_chosen_pixels = np.random.choice(ind_non_void_pixels, n_pick, replace=False)

                # Builtin bool: the np.bool alias was removed in NumPy 1.24.
                queries_flat = np.zeros(h * w, dtype=bool)
                queries_flat[ind_chosen_pixels] = True
                queries = queries_flat.reshape((h, w))

                list_queries.append(queries)
                n_pixels_total += queries.sum()
            with open(path_queries, 'wb') as f:
                pkl.dump(list_queries, f)

            # Note that images of voc dataset vary from image to image thus can't use np.stack().
            self.queries = list_queries
            with open(f"{self.dir_checkpoints}/0_query/label.pkl", 'wb') as f:
                pkl.dump(self.queries, f)

            self.n_pixels_total = n_pixels_total
            print("# labelled pixels used for training:", n_pixels_total)
        self.val, self.query = val, query
Beispiel #10
0
def main(train_args, model):
    """Train ``model`` on Pascal VOC 2012 'train' and validate on 'val' every epoch.

    Args:
        train_args: dict of hyper-parameters. Reads 'lr', 'weight_decay',
            'momentum' and 'epoch_num'; initialises 'best_record'.
        model: network to train; moved to the module-level ``device``.

    Returns:
        The images returned by ``validate`` for the last epoch, or ``None``
        when ``train_args['epoch_num'] < 1``. Previously that case raised
        UnboundLocalError.
    """
    print(train_args)
    # expanduser('~') equals $HOME when it is set and also works when it is not.
    dset_path = os.path.join(os.path.abspath(os.path.expanduser('~')), 'datasets')

    net = model.to(device)

    train_args['best_record'] = {'epoch': 0,
                                 'val_loss': 1e10,
                                 'acc': 0, 'acc_cls': 0,
                                 'mean_iu': 0,
                                 'fwavacc': 0}

    net.train()

    mean_std = ([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225])

    input_transform = transforms.Compose([
        transforms.Pad(200),
        transforms.CenterCrop(320),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])

    # Same geometry as input_transform, applied to the label mask.
    # NOTE(review): Pad fills with 0, so padded mask pixels are labelled as
    # class 0 rather than the ignored 255 — confirm this is intended.
    target_transform = transforms.Compose([
        transforms.Pad(200),
        transforms.CenterCrop(320),
        MaskToTensor()])

    restore_transform = transforms.Compose([
        DeNormalize(*mean_std),
        transforms.ToPILImage(),
    ])

    visualize = transforms.Compose([
        transforms.Resize(400),
        transforms.CenterCrop(400),
        transforms.ToTensor()
    ])

    train_set = VOCSegmentation(root=dset_path,
                                image_set='train',
                                transform=input_transform,
                                target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=1,
                              num_workers=4,
                              shuffle=True)
    val_set = VOCSegmentation(root=dset_path,
                              image_set='val',
                              transform=input_transform,
                              target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss(ignore_index=255, reduction='mean').to(device)

    # Biases get twice the learning rate and no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    imges = None
    for epoch in range(1, train_args['epoch_num'] + 1):
        train(train_loader,
              net,
              criterion,
              optimizer,
              epoch,
              train_args)
        _val_loss, imges = validate(val_loader,
                                    net,
                                    criterion,
                                    optimizer,
                                    epoch,
                                    train_args,
                                    restore_transform,
                                    visualize)
    return imges
Beispiel #11
0
def main(train_args, model):
    """Train ``model`` on Pascal VOC 2012, optionally resuming from a snapshot.

    Args:
        train_args: dict of hyper-parameters. Reads 'snapshot', 'lr',
            'weight_decay', 'momentum', 'lr_patience' and 'epoch_num';
            sets 'best_record'.
        model: network to train (moved to CUDA).

    When 'snapshot' is non-empty, the model and optimizer states are loaded
    from ``ckpt_path/exp_name``. The snapshot name is split on '_', and fields
    1, 3, 5, 7, 9 and 11 are parsed as epoch, val_loss, acc, acc_cls, mean_iu
    and fwavacc.

    Returns:
        Images returned by ``validate`` for the last epoch, or ``None`` if no
        epoch ran.
    """
    print(train_args)

    net = model.cuda()
    resuming = bool(train_args['snapshot'])

    if not resuming:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        standard_transforms.Pad(200),
        standard_transforms.CenterCrop(320),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    # Same geometry as input_transform, applied to the label mask.
    train_transform = standard_transforms.Compose([
        standard_transforms.Pad(200),
        standard_transforms.CenterCrop(320),
        extended_transforms.MaskToTensor()])

    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])

    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = VOCSegmentation(root='./', image_set='train', transform=input_transform, target_transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True)
    val_set = VOCSegmentation(root='./', image_set='val', transform=input_transform, target_transform=train_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)

    # reduction='sum' is the non-deprecated spelling of size_average=False.
    criterion = CrossEntropyLoss(reduction='sum', ignore_index=255).cuda()

    # Biases get twice the learning rate and no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if resuming:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        # Override the restored learning rates with the current configuration.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    imges = None
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss, imges = validate(val_loader, net, criterion, optimizer, epoch, train_args, restore_transform, visualize)
        scheduler.step(val_loss)
    return imges
Beispiel #12
0
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing
"""Let's put everything together in a `detect` function:"""
"""## Loading Pascal VOC 2012 dataset
Before we start, download the Pascal VOC validation set from [here](https://oc.embl.de/index.php/s/bkBUhSajTPP0lUP) and save it in your Google Drive. The archive is 2 GB in size, so this will take a while.

After the ZIP file has been successfully uploaded to your Google Drive, mount your Drive following [the instructions](https://colab.research.google.com/github/constantinpape/training-deep-learning-models-for-vison/blob/master/exercises/mount-gdrive-in-colab.ipynb) and unzip the archive.
"""
"""Let's create the Pascal VOC loader from the `torchvision` package and show some images with their ground-truth segmentation masks."""

root_dir = "./PascalVOC2012"

# Pascal VOC 2012 'trainval' segmentation split, read from the local copy
# under root_dir (no download is attempted).
voc_dataset = VOCSegmentation(root=root_dir, year='2012', image_set='trainval', download=False)
"""Before we move on, let's define the 20 object classes available in the Pascal VOC dataset."""

# Pascal VOC classes, modified to match the COCO class names, i.e. the following 5 class names were mapped:
# aeroplane -> airplane
# diningtable -> dining table
# motorbike -> motorcycle
# sofa -> couch
# tvmonitor -> tv
"""For the exercises we will need a helper function which extracts the bounding boxes around the individual instances given the ground truth semantic mask."""
"""Visualize the bounding boxes on a given image from the Pascal VOC dataset"""

# Draw 20 random dataset indices (sampled independently, so repeats are
# possible) and open an 8x8 figure for each one.
# NOTE(review): the rest of this loop body (drawing the image and its boxes)
# appears to be missing from this excerpt — only the figure creation remains.
indexes = torch.randint(0, len(voc_dataset), (20, ))
for index, i in enumerate(indexes):
    fig = plt.figure(figsize=(8, 8))
Beispiel #13
0
def _build_criterion(loss_name, loss_weight):
    """Return the segmentation loss selected by ``loss_name``.

    Every supported loss ignores target pixels labelled 255.

    Args:
        loss_name: one of ``'crossentropy'``, ``'label_smoothing'``, ``'focal'``
            or ``'mc'``.
        loss_weight: optional per-class weight tensor, or ``None``.

    Raises:
        ValueError: if ``loss_name`` is not supported.
    """
    if loss_name == 'crossentropy':
        return nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255)
    if loss_name == 'label_smoothing':
        return holocron.nn.LabelSmoothingCrossEntropy(weight=loss_weight,
                                                      ignore_index=255)
    if loss_name == 'focal':
        return holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    if loss_name == 'mc':
        return holocron.nn.MutualChannelLoss(weight=loss_weight,
                                             ignore_index=255)
    raise ValueError(f"Unsupported loss: {loss_name!r}")


def _build_optimizer(opt_name, model_params, lr, weight_decay):
    """Return the optimizer selected by ``opt_name`` over ``model_params``.

    SGD uses momentum 0.9; every adaptive optimizer shares
    ``betas=(0.95, 0.99)`` and ``eps=1e-6``.

    Raises:
        ValueError: if ``opt_name`` is not supported.
    """
    if opt_name == 'sgd':
        return torch.optim.SGD(model_params,
                               lr,
                               momentum=0.9,
                               weight_decay=weight_decay)
    adaptive_kwargs = dict(betas=(0.95, 0.99),
                           eps=1e-6,
                           weight_decay=weight_decay)
    if opt_name == 'adam':
        return torch.optim.Adam(model_params, lr, **adaptive_kwargs)
    if opt_name == 'radam':
        return holocron.optim.RAdam(model_params, lr, **adaptive_kwargs)
    if opt_name == 'adamp':
        return holocron.optim.AdamP(model_params, lr, **adaptive_kwargs)
    if opt_name == 'adabelief':
        return holocron.optim.AdaBelief(model_params, lr, **adaptive_kwargs)
    raise ValueError(f"Unsupported optimizer: {opt_name!r}")


def main(args):
    """Train or inspect a Pascal VOC segmentation model as configured by ``args``.

    The VOC datasets are built from ``args.data_path`` with ``download=True``.
    The first matching mode runs and returns, checked in this order:
    ``show_samples`` (plot a training batch), ``show_preds`` (plot predictions
    on a training batch), ``test_only`` (evaluate once), ``lr_finder`` (LR
    search), ``check_setup`` (batch-overfitting check). Otherwise the model is
    trained for ``args.epochs`` epochs.

    Raises:
        ValueError: for an unsupported ``args.loss`` or ``args.opt``, or when
            ``show_samples`` / ``show_preds`` is requested together with
            ``test_only`` (the training set is not loaded in that case).
    """
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    # Size bounds handed to RandomResize for training augmentation.
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # train_loader is only built when test_only is off.
        if train_loader is None:
            raise ValueError("show_samples needs the training set, "
                             "which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size)),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = segmentation.__dict__[args.model](
        args.pretrained,
        not (args.pretrained),
        num_classes=len(VOC_CLASSES),
    )

    # Loss setup: optionally re-weight class 0 (presumably background, given
    # the `bg_factor` name) while every other class keeps weight 1.
    loss_weight = None
    if isinstance(args.bg_factor, float):
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    criterion = _build_criterion(args.loss, loss_weight)

    # Optimizer setup: only parameters that require gradients are optimized.
    model_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = _build_optimizer(args.opt, model_params, args.lr,
                                 args.weight_decay)

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES))
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        if train_loader is None:
            raise ValueError("show_preds needs the training set, "
                             "which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                # Use the same GPU index that was handed to the trainer.
                x = x.cuda(args.device)
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Beispiel #14
0
    def __init__(self, args):
        """Build datasets, loaders, model, optimizer, loss, evaluator and LR scheduler.

        Args:
            args: parsed configuration namespace. Attributes read here include
                ``batch_size``, ``num_workers``, ``backbone``, ``out_stride``,
                ``sync_bn``, ``freeze_bn``, ``lr``, ``momentum``, ``weight_decay``,
                ``nesterov``, ``use_balanced_weights``, ``loss_type``,
                ``data_parallel``, ``lr_scheduler``, ``epochs``, ``resume`` and
                ``ft``. ``args.start_epoch`` is overwritten when resuming a
                checkpoint (and reset to 0 when ``args.ft`` is set).

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing file.
        """
        # Number of segmentation classes (hard-coded; not derived from the dataset).
        num_classes = 21

        # TODO: augmentation.
        # Transform applied to the segmentation *targets* below.
        # NOTE(review): transforms.ToTensor() rescales 8-bit inputs to [0, 1], which
        # would turn integer class ids into fractions unless the training code
        # undoes it — confirm how targets are consumed.
        t_target = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.args = args
        self.get_palette = get_palette(256)
        # Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Tensorboard
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Datasets: Pascal VOC 2012, expected to be present under ./dataset/
        # (download is disabled). TODO: dataset download.
        t = trainsforms_default()
        train_set = VOCSegmentation(root='./dataset/', year='2012',
                                    image_set='train', download=False,
                                    transform=t, target_transform=t_target)
        val_set = VOCSegmentation(root='./dataset/', year='2012',
                                  image_set='val', download=False,
                                  transform=t, target_transform=t_target)

        # Dataloaders. Validation keeps its last partial batch so every sample is scored.
        self.train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                       num_workers=args.num_workers, drop_last=True,
                                       pin_memory=True)
        self.val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers, drop_last=False,
                                     pin_memory=True)

        # Network
        self.model = deeplabV3plus(backbone=args.backbone,
                                   output_stride=args.out_stride,
                                   num_classes=num_classes,
                                   sync_bn=args.sync_bn,
                                   freeze_bn=args.freeze_bn).to(self.device)
        # Parameters from get_1x_lr_params() train at args.lr, those from
        # get_10x_lr_params() at 10x that rate.
        train_params = [{'params': self.model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': self.model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Optimizer (per-group learning rates come from train_params).
        self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)

        # Criterion. Class-balanced weights are not implemented yet, so the loss
        # stays unweighted even when args.use_balanced_weights is set.
        weight = None
        if args.use_balanced_weights:
            pass  # TODO: compute class-balanced weights.
        self.criterion = SegmentationLosses(weight=weight).build_loss(mode=args.loss_type)

        # Multi-GPU
        if args.data_parallel:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)

        # Evaluator
        self.evaluator = Evaluator(num_classes)
        # Lr scheduler (needs the number of training iterations per epoch).
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError('no checkpoint found at: {}'.format(args.resume))
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            checkpoint = torch.load(args.resume, map_location=self.device)
            args.start_epoch = checkpoint['epoch']
            # `.module` only exists on the DataParallel wrapper; otherwise load directly.
            if isinstance(self.model, torch.nn.DataParallel):
                target_model = self.model.module
            else:
                target_model = self.model
            target_model.load_state_dict(checkpoint['state_dict'])

            # Fine-tuning starts from fresh optimizer state.
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print('Loaded checkpoint: {} (epoch: {})'.format(args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0
def load_data(dataset, path, batch_size=64, normalize=False, num_workers=12):
    """Build train/validation DataLoaders for one of the supported datasets.

    Args:
        dataset: one of ``'svhn'``, ``'stl10'``, ``'cifar10'``, ``'cifar100'``,
            ``'VOC07'`` or ``'VOC10'``.
        path: root directory the dataset is downloaded to / read from.
        batch_size: training batch size (validation always uses batch size 1).
        normalize: if True, map pixels to [-1, 1] with mean/std 0.5 per channel;
            otherwise only convert to tensors.
        num_workers: worker processes for the training loader.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    supported = ('svhn', 'stl10', 'cifar10', 'cifar100', 'VOC07', 'VOC10')
    # Validate up front: an unknown name used to surface as UnboundLocalError.
    if dataset not in supported:
        raise ValueError(f"Unknown dataset {dataset!r}; expected one of {supported}")

    if normalize:
        # Wasserstein BiGAN is trained on normalized data.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        # BiGAN is trained on unnormalized data (see Dumoulin et al. ICLR 16).
        transform = transforms.ToTensor()

    if dataset == 'svhn':
        train_set = SVHN(path, split='extra', transform=transform, download=True)
        val_set = SVHN(path, split='test', transform=transform, download=True)
    elif dataset == 'stl10':
        train_set = STL10(path, split='train', transform=transform, download=True)
        val_set = STL10(path, split='test', transform=transform, download=True)
    elif dataset == 'cifar10':
        train_set = CIFAR10(path, train=True, transform=transform, download=True)
        val_set = CIFAR10(path, train=False, transform=transform, download=True)
    elif dataset == 'cifar100':
        train_set = CIFAR100(path, train=True, transform=transform, download=True)
        val_set = CIFAR100(path, train=False, transform=transform, download=True)
    else:  # 'VOC07' or 'VOC10'
        year = '2007' if dataset == 'VOC07' else '2010'
        # NOTE(review): only the image is transformed; the segmentation mask is
        # returned untransformed and VOC images vary in size, so default batching
        # likely fails — a target_transform and a resize are probably needed. TODO confirm.
        train_set = VOCSegmentation(path,
                                    image_set='train',
                                    year=year,
                                    transform=transform,
                                    download=True)
        val_set = VOCSegmentation(path,
                                  image_set='val',
                                  year=year,
                                  transform=transform,
                                  download=True)

    train_loader = data.DataLoader(train_set,
                                   batch_size,
                                   shuffle=True,
                                   num_workers=num_workers)
    val_loader = data.DataLoader(val_set,
                                 1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True)
    return train_loader, val_loader
Beispiel #16
0
])

# Test-time image transform: convert to a tensor, then normalize each channel with a
# fixed mean/std. NOTE(review): these constants look like CIFAR-10 statistics —
# confirm they are intended here.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Transform used for both VOC images (transform=) and masks (target_transform=) below.
# NOTE(review): ToTensor() rescales 8-bit inputs to [0, 1], so integer class ids in the
# masks become fractions — confirm downstream code accounts for this.
transform_seg = transforms.Compose([
    # Random cropping is disabled: it was observed to cause very low accuracy, possibly
    # because transform and target_transform draw independent crops (misaligning image
    # and mask).
    # transforms.RandomCrop(128, 64),
    transforms.ToTensor(),
])

# Load the PASCAL VOC 2012 Segmentation 'train' split from a local copy (download=False).
seg_dataset = VOCSegmentation('~/DeLightCMU/CVPR-Prep/Non-local_pytorch/data',
                              year = "2012",
                              image_set='train',
                              download=False,
                              transform=transform_seg,
                              target_transform=transform_seg)
print('VOCSeg ends.')
# One sample per batch (presumably because VOC images differ in size and no resize is applied).
seg_loader  = DataLoader(seg_dataset, batch_size=1)
print('seg_loader ends.')

# input_num = 0
# for input, target in seg_loader:
#     # print('for loop.')
#     print(input.size(), target.size())
#     input_num = input_num + 1
# print('input_num: ', input_num)
# exit(-1)

# for i, data in enumerate(seg_loader):