def _load_voc(self):
    trans = get_transform(513, 0, 0, 'voc')
    trainset = VOCSegmentation(root='.voc', image_set='train', download=True,
                               transforms=trans['train'])
    testset = VOCSegmentation(root='.voc', image_set='val', download=True,
                              transforms=trans['test'])
    return {'train': trainset, 'test': testset}
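# A minimal usage sketch (the instance name `datamodule` and the batch sizes are
# assumptions, not from the source): wrap the returned splits in DataLoaders.
from torch.utils.data import DataLoader

splits = datamodule._load_voc()  # `datamodule` is an instance of the enclosing class
train_loader = DataLoader(splits['train'], batch_size=8, shuffle=True, num_workers=4)
test_loader = DataLoader(splits['test'], batch_size=1, shuffle=False, num_workers=4)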
def load_data(datadir):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size = int(0.5 * base_size)
    max_size = int(2.0 * base_size)

    print("Loading training data")
    st = time.time()
    dataset = VOCSegmentation(datadir, image_set='train', download=True,
                              transforms=Compose([
                                  RandomResize(min_size, max_size),
                                  RandomCrop(crop_size),
                                  RandomHorizontalFlip(0.5),
                                  SampleTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                                                         saturation=0.1, hue=0.02)),
                                  ToTensor(),
                                  SampleTransform(normalize)
                              ]))
    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    dataset_test = VOCSegmentation(datadir, image_set='val', download=True,
                                   transforms=Compose([
                                       RandomResize(base_size, base_size),
                                       ToTensor(),
                                       SampleTransform(normalize)
                                   ]))
    print("Took", time.time() - st)

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
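# Usage sketch: build DataLoaders from the datasets and samplers returned above
# (the batch sizes and worker count are assumptions).
dataset, dataset_test, train_sampler, test_sampler = load_data('./voc')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=train_sampler,
                                           num_workers=4, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1, sampler=test_sampler,
                                          num_workers=4)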
def __init__(
    self,
    train: bool = True,
    preproc: bool = False,
    augmentation: bool = False,
    cache: bool = False,
    size: Tuple[int, int] = (192, 256),
    mode: str = "seg20",
    onehot: bool = False,
):
    assert mode in self.MODES, f"Mode {mode} not one of {','.join(self.MODES)}"
    self.cache = cache
    self.train = train
    self.preproc = preproc
    self.augmentation = augmentation
    self.size = size
    self.mode = mode
    self.onehot = onehot
    self.n_classes = len(self.classes)

    image_set = "train" if self.train else "val"
    self.D = VOCSegmentation(root=self.path, year="2012", image_set=image_set)

    if self.preproc:
        self._load_transforms()

    if self.cache:
        if self.augmentation:
            # Cache the raw dataset so random augmentations stay fresh per call.
            # Caveat: Python resolves dunder methods on the type, so `self.D[i]`
            # bypasses this instance attribute; the cache is only hit via an
            # explicit `self.D.__getitem__(i)` call.
            self.D.__getitem__ = lru_cache(maxsize=None)(self.D.__getitem__)
        else:
            self.__getitem__ = lru_cache(maxsize=None)(self.__getitem__)
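# A hedged alternative sketch (names are illustrative, not from the source): caching
# through a regular method is honoured by normal indexing, since __getitem__ stays
# defined on the class rather than on the instance.
from functools import lru_cache
from torchvision.datasets import VOCSegmentation

class CachedVOC:
    def __init__(self, root):
        self.D = VOCSegmentation(root=root, year="2012", image_set="val")

    @lru_cache(maxsize=None)
    def _cached_get(self, index):
        # Note: lru_cache on a method keeps `self` alive for the cache's lifetime.
        return self.D[index]

    def __getitem__(self, index):
        return self._cached_get(index)

    def __len__(self):
        return len(self.D)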
def main():
    parser = argparse.ArgumentParser(description='Dump VOC-C')
    parser.add_argument('--cn', type=int, default=4, metavar='N', help='Corruption number')
    parser.add_argument('--sv', type=int, default=1, metavar='N', help='Severity')
    args = parser.parse_args()

    sv = args.sv
    corruption_name = corruption_tuple[args.cn].__name__

    # cn == -1 or sv == -1 means "dump the clean labels" instead of corrupted images
    if args.cn == -1 or sv == -1:
        if not os.path.isdir('VOC-C/lbl'):
            os.mkdir('VOC-C/lbl')
    else:
        if not os.path.isdir('VOC-C/{}'.format(corruption_name)):
            os.mkdir('VOC-C/{}'.format(corruption_name))
        if not os.path.isdir('VOC-C/{}/{}'.format(corruption_name, sv)):
            os.mkdir('VOC-C/{}/{}'.format(corruption_name, sv))

    corr_val = VOCSegmentation(root='/data/datasets/', image_set='val',
                               transforms=ImLblCorruptTransform(sv, args.cn))
    for n, (im, lbl) in enumerate(tqdm(corr_val)):
        if args.cn == -1 or sv == -1:
            lbl.save('VOC-C/lbl/{:04d}.png'.format(n))
        else:
            save_image(im, 'VOC-C/{}/{}/{:04d}.png'.format(corruption_name, sv, n))
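# Entry point plus example invocations (the script name is hypothetical):
#   python dump_voc_c.py --cn 4 --sv 3    # dump severity-3 images for corruption #4
#   python dump_voc_c.py --cn -1 --sv -1  # dump the clean ground-truth labels once
if __name__ == '__main__':
    main()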
def download_segmentation_masks():
    """Obtains random masks from the PASCAL VOC 2012 (Segmentation) dataset."""
    segmentation_masks = []

    # Initialize the PASCAL VOC 2012 dataset for segmentation.
    # The input images are discarded below, so they are shrunk aggressively.
    input_transform = transforms.Compose([transforms.Resize(1), transforms.ToTensor()])
    target_transform = transforms.Compose([transforms.ToTensor()])
    dataset = VOCSegmentation('.data/', image_set='trainval', download=True,
                              transform=input_transform, target_transform=target_transform)
    loader = DataLoader(dataset, batch_size=1)

    for batch in tqdm(loader, desc='Loading Segmentation Masks'):
        _, mask = batch
        # Get the first (and only) example from the batch
        mask = mask[0]
        # NOTE: the masks have boundaries of 1. and inner regions of 0.5; make it all 1.
        mask[mask > 0.] = 1.
        # Only keep masks that span up to 1/4 of the image
        if torch.mean(mask) <= 0.25:
            segmentation_masks.append(mask)

    return segmentation_masks
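# Usage sketch: the masks keep their native resolutions, so inspect them individually.
masks = download_segmentation_masks()
print(f"Kept {len(masks)} masks; first covers {masks[0].mean().item():.1%} of its image")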
def __init__(self):
    self.id2name = {}
    self.name2id = {}
    for idx, name in enumerate(self.class_names):
        self.id2name[idx] = name
        self.name2id[name] = idx

    self.train_dataset = VOCSegmentation('./', image_set='train')
    self.image_trans = transforms.Compose([
        ToTensor(),
    ])
    self.mask_trans = transforms.Compose([
        ToLabel(),
        Relabel(255, 21)  # change 255 to 21
    ])
    self.label_trans = transforms.Compose([ToTensor()])
    self.y_trans = transforms.Compose([ToTensor()])
def __init__(self, dataset_root, split, download=True, integrity_check=True):
    assert split in (SPLIT_TRAIN, SPLIT_VALID), f'Invalid split {split}'
    self.integrity_check = integrity_check

    root_voc = os.path.join(dataset_root, 'VOC')
    root_sbd = os.path.join(dataset_root, 'SBD')

    self.ds_voc_valid = VOCSegmentation(root_voc, image_set=SPLIT_VALID, download=download)
    if split == SPLIT_TRAIN:
        self.ds_voc_train = VOCSegmentation(root_voc, image_set=SPLIT_TRAIN, download=False)
        self.ds_sbd_train = SBDataset(
            root_sbd, image_set=SPLIT_TRAIN,
            download=download and not os.path.isdir(os.path.join(root_sbd, 'img')))
        self.ds_sbd_valid = SBDataset(root_sbd, image_set=SPLIT_VALID, download=False)

        # Merge SBD and VOC train samples; VOC annotations take precedence over SBD ones.
        self.name_to_ds_id = {
            self._sample_name(path): (self.ds_sbd_train, i)
            for i, path in enumerate(self.ds_sbd_train.images)
        }
        self.name_to_ds_id.update({
            self._sample_name(path): (self.ds_sbd_valid, i)
            for i, path in enumerate(self.ds_sbd_valid.images)
        })
        self.name_to_ds_id.update({
            self._sample_name(path): (self.ds_voc_train, i)
            for i, path in enumerate(self.ds_voc_train.images)
        })
        # Exclude VOC validation samples from the merged training set.
        for path in self.ds_voc_valid.images:
            name = self._sample_name(path)
            self.name_to_ds_id.pop(name, None)
    else:
        self.name_to_ds_id = {
            self._sample_name(path): (self.ds_voc_valid, i)
            for i, path in enumerate(self.ds_voc_valid.images)
        }

    self.sample_names = list(sorted(self.name_to_ds_id.keys()))
    self.transforms = None

    dirname = os.path.dirname(__file__)
    path_points_fg = os.path.join(dirname, 'voc_whats_the_point.json')
    path_points_bg = os.path.join(dirname, 'voc_whats_the_point_bg_from_scribbles.json')
    with open(path_points_fg, 'r') as f:
        self.ds_clicks_fg = json.load(f)
    with open(path_points_bg, 'r') as f:
        self.ds_clicks_bg = json.load(f)

    self.ds_scribbles_path = os.path.join(dirname, 'voc_scribbles.zip')
    assert os.path.isfile(self.ds_scribbles_path), \
        f'Scribbles not found at {self.ds_scribbles_path}'

    self.cls_name_to_id = {name: i for i, name in enumerate(self.semseg_class_names)}
    self._semseg_class_histogram = self._compute_histogram()

    if integrity_check:
        results = []
        for i in tqdm(range(len(self)), desc=f'Checking "{split}" split'):
            results.append(self.get(i))
        for d in results:
            if d['num_clicks_bg'] == 0:
                print(d['name'], 'has no background clicks')
            if d['num_clicks_fg'] == 0:
                print(d['name'], 'has no foreground clicks')
        self.integrity_check = False
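# A hedged sketch of the _sample_name helper assumed above (hypothetical, not from
# the source): VOC/SBD image paths end in '<name>.jpg', and the basename without
# extension serves as the key shared across both datasets.
@staticmethod
def _sample_name(path):
    return os.path.splitext(os.path.basename(path))[0]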
def main(args):
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)
    interpolation_mode = InterpolationMode.BILINEAR

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path, image_set='train', download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size, interpolation_mode),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(T.ColorJitter(brightness=0.3, contrast=0.3,
                                                                     saturation=0.1, hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, drop_last=True,
            sampler=RandomSampler(train_set), num_workers=args.workers,
            pin_memory=True, worker_init_fn=worker_init_fn)
        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path, image_set='val', download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size), interpolation_mode),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=args.batch_size, drop_last=False,
            sampler=SequentialSampler(val_set), num_workers=args.workers,
            pin_memory=True, worker_init_fn=worker_init_fn)
        print(f"Validation set loaded in {time.time() - st:.2f}s "
              f"({len(val_set)} samples in {len(val_loader)} batches)")

    if args.source.lower() == 'holocron':
        model = segmentation.__dict__[args.arch](args.pretrained, num_classes=len(VOC_CLASSES))
    elif args.source.lower() == 'torchvision':
        model = tv_segmentation.__dict__[args.arch](args.pretrained, num_classes=len(VOC_CLASSES))
    else:
        raise ValueError(f"Unknown model source: {args.source}")

    # Loss setup
    loss_weight = None
    if isinstance(args.bg_factor, float) and args.bg_factor != 1:
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255,
                                        label_smoothing=args.label_smoothing)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight, ignore_index=255, xi=3)

    # Optimizer setup
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params, args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params, args.lr, betas=(0.95, 0.99),
                                         eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params, args.lr, betas=(0.95, 0.99),
                                         eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params, args.lr, betas=(0.95, 0.99),
                                             eps=1e-6, weight_decay=args.weight_decay)

    log_wb = (lambda metrics: wandb.log(metrics)) if args.wb else None
    trainer = SegmentationTrainer(model, train_loader, val_loader, criterion, optimizer,
                                  args.device, args.output_file,
                                  num_classes=len(VOC_CLASSES), amp=args.amp,
                                  on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(f"Validation loss: {eval_metrics['val_loss']:.4} "
              f"(Mean IoU: {eval_metrics['mean_iou']:.2%})")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, norm_weight_decay=args.norm_weight_decay,
                        num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until, args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:
        run = wandb.init(
            name=exp_name,
            project="holocron-semantic-segmentation",
            config={
                "learning_rate": args.lr,
                "scheduler": args.sched,
                "weight_decay": args.weight_decay,
                "epochs": args.epochs,
                "batch_size": args.batch_size,
                "architecture": args.arch,
                "source": args.source,
                "input_size": 256,
                "optimizer": args.opt,
                "dataset": "Pascal VOC2012 Segmentation",
                "loss": args.loss,
            })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
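# A hypothetical invocation of the training script above (the script name, the
# positional arch argument, and the flag spellings are inferred from the argparse
# attributes used, e.g. args.data_path -> --data-path):
#   python train.py <arch> --source holocron --data-path ./voc \
#       --batch-size 8 --epochs 20 --opt radam --loss crossentropy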
def __init__(self, args, val=False, query=False):
    super(VOC2012Segmentation, self).__init__()
    self.dir_checkpoints = f"{args.dir_root}/checkpoints/{args.experim_name}"
    self.ignore_index = args.ignore_index
    self.size_base = args.size_base
    self.size_crop = (args.size_crop, args.size_crop)
    self.stride_total = args.stride_total

    if args.use_augmented_dataset and not val:
        self.voc = AugmentedVOC(args.dir_augmented_dataset)
    else:
        self.voc = VOCSegmentation(f"{args.dir_dataset}",
                                   image_set='val' if val else 'train',
                                   download=False)
    print("# images:", len(self.voc))

    self.geometric_augmentations = args.augmentations["geometric"]
    self.photometric_augmentations = args.augmentations["photometric"]
    self.normalize = Normalize(mean=args.mean, std=args.std)
    if query:
        self.geometric_augmentations["random_scale"] = False
        self.geometric_augmentations["crop"] = False
        self.geometric_augmentations["random_hflip"] = False

    if self.geometric_augmentations["crop"]:
        self.mean = tuple((np.array(args.mean) * 255.0).astype(np.uint8).tolist())

    # Generate initial queries (labelled pixels).
    n_pixels_per_img = args.n_pixels_by_us
    init_n_pixels = args.n_init_pixels if args.n_init_pixels > 0 else n_pixels_per_img

    self.queries, self.n_pixels_total = None, -1
    path_queries = f"{args.dir_dataset}/init_labelled_pixels_{args.seed}.pkl"
    if n_pixels_per_img != 0 and not val:
        os.makedirs(f"{self.dir_checkpoints}/0_query", exist_ok=True)
        n_pixels_total = 0
        list_queries = list()
        for i in tqdm(range(len(self.voc))):
            label = self.voc[i][1]
            w, h = label.size
            if n_pixels_per_img == 0:
                n_pixels_per_img = h * w
            elif n_pixels_per_img != 0 and init_n_pixels > 0:
                n_pixels_per_img = init_n_pixels
            else:
                raise NotImplementedError

            # Generate queries whose size is set to base_size (longer side), i.e. 400 by default.
            h, w = self._compute_base_size(h, w)
            queries_flat = np.zeros((h * w), dtype=bool)  # np.bool is deprecated; use the builtin

            # Filter void pixels - the boundary pixels in the original labels (5 pixels thick).
            # Note that the downsampling method should be Image.NEAREST for label maps.
            label = label.resize((w, h), Image.NEAREST)
            label = np.asarray(label, dtype=np.int32)
            label_flatten = label.flatten()
            ind_void_pixels = np.where(label_flatten == 255)[0]
            ind_non_void_pixels = np.setdiff1d(range(len(queries_flat)), ind_void_pixels)
            assert len(ind_non_void_pixels) <= len(queries_flat)

            # For the very rare case where there are too few non-void pixels to sample from.
            if len(ind_non_void_pixels) < n_pixels_per_img:
                n_pixels_per_img = len(ind_non_void_pixels)

            ind_chosen_pixels = np.random.choice(ind_non_void_pixels, n_pixels_per_img,
                                                 replace=False)
            queries_flat[ind_chosen_pixels] = True
            queries = queries_flat.reshape((h, w))
            list_queries.append(queries)
            n_pixels_total += queries.sum()

        with open(path_queries, 'wb') as f:
            pkl.dump(list_queries, f)

        # Image sizes vary across the VOC dataset, so np.stack() can't be used here.
        self.queries = list_queries
        with open(f"{self.dir_checkpoints}/0_query/label.pkl", 'wb') as f:
            pkl.dump(self.queries, f)
        self.n_pixels_total = n_pixels_total
        print("# labelled pixels used for training:", n_pixels_total)

    self.val, self.query = val, query
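# A hedged sketch of what _compute_base_size might look like (hypothetical, not from
# the source): rescale (h, w) so the longer side equals self.size_base (400 by default)
# while keeping the aspect ratio, as the comment in the loop above describes.
def _compute_base_size(self, h, w):
    if w >= h:
        return int(h * self.size_base / w), self.size_base
    return self.size_base, int(w * self.size_base / h)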
def main(train_args, model):
    print(train_args)
    dset_path = os.path.join(os.path.abspath(os.environ["HOME"]), 'datasets')
    net = model.to(device)
    train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0,
                                 'mean_iu': 0, 'fwavacc': 0}
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = transforms.Compose([
        transforms.Pad(200),
        transforms.CenterCrop(320),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    train_transform = transforms.Compose([
        transforms.Pad(200),
        transforms.CenterCrop(320),
        MaskToTensor()
    ])
    restore_transform = transforms.Compose([
        DeNormalize(*mean_std),
        transforms.ToPILImage(),
    ])
    visualize = transforms.Compose([
        transforms.Resize(400),
        transforms.CenterCrop(400),
        transforms.ToTensor()
    ])

    train_set = VOCSegmentation(root=dset_path, image_set='train',
                                transform=input_transform, target_transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True)
    val_set = VOCSegmentation(root=dset_path, image_set='val',
                              transform=input_transform, target_transform=train_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)

    criterion = CrossEntropyLoss(ignore_index=255, reduction='mean').to(device)
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    for epoch in range(1, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss, imges = validate(val_loader, net, criterion, optimizer, epoch, train_args,
                                   restore_transform, visualize)
    return imges
def main(train_args, model):
    print(train_args)
    net = model.cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0,
                                     'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # The snapshot filename encodes the metrics at its odd-indexed '_'-separated fields.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]),
                                     'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]),
                                     'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]),
                                     'fwavacc': float(split_snapshot[11])}
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.Pad(200),
        standard_transforms.CenterCrop(320),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    train_transform = standard_transforms.Compose([
        standard_transforms.Pad(200),
        standard_transforms.CenterCrop(320),
        extended_transforms.MaskToTensor()
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = VOCSegmentation(root='./', image_set='train',
                                transform=input_transform, target_transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=4, shuffle=True)
    val_set = VOCSegmentation(root='./', image_set='val',
                              transform=input_transform, target_transform=train_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=4, shuffle=False)

    # criterion = CrossEntropyLoss().cuda()
    # size_average is deprecated; reduction='sum' is the modern equivalent of size_average=False.
    criterion = CrossEntropyLoss(reduction='sum', ignore_index=255).cuda()

    # Biases get twice the base learning rate and no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])
    # An Adam alternative was also tried:
    # optimizer = optim.Adam([
    #     {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
    #      'lr': 2 * train_args['lr']},
    #     {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
    #      'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    # ], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name,
                                                          'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    # check_mkdir(ckpt_path)
    # check_mkdir(os.path.join(ckpt_path, exp_name))
    # open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'],
                                  min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss, imges = validate(val_loader, net, criterion, optimizer, epoch, train_args,
                                   restore_transform, visualize)
        scheduler.step(val_loss)
    return imges
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing

"""Let's put everything together in a `detect` function:"""

"""## Loading Pascal VOC 2012 dataset

Before we start, let's download the Pascal VOC validation set from [here](https://oc.embl.de/index.php/s/bkBUhSajTPP0lUP) and save it in your Google Drive. The archive is 2GB in size, so it will take a while.

After the ZIP file has been successfully uploaded to your Google Drive, mount your Drive following [the instructions](https://colab.research.google.com/github/constantinpape/training-deep-learning-models-for-vison/blob/master/exercises/mount-gdrive-in-colab.ipynb) and unzip the archive.
"""

"""Let's create the Pascal VOC loader from the `torchvision` package and show some images with the ground truth segmentation masks."""

root_dir = "./PascalVOC2012"
voc_dataset = VOCSegmentation(root_dir, year='2012', image_set='trainval', download=False)

"""Before we move on, let's define the 20 object classes available in the Pascal VOC dataset."""

# Pascal VOC classes, modified to match the COCO classes, i.e. the following 4 class names were mapped:
# aeroplane -> airplane
# diningtable -> dining table
# motorbike -> motorcycle
# sofa -> couch
# tvmonitor -> tv

"""For the exercises we will need a helper function which extracts the bounding boxes around the individual instances given the ground truth semantic mask (see the sketch after the loop below)."""

"""Visualize the bounding boxes on a given image from the Pascal VOC dataset."""

indexes = torch.randint(0, len(voc_dataset), (20, ))
for index, i in enumerate(indexes):
    fig = plt.figure(figsize=(8, 8))
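# A hedged sketch of such a bounding-box helper (illustrative, not the notebook's
# exact implementation): it derives one box per class present in the semantic mask;
# the notebook's version may additionally separate touching instances.
import numpy as np

def extract_bounding_boxes(mask):
    """Return {class_id: (xmin, ymin, xmax, ymax)} from a PIL semantic mask."""
    arr = np.asarray(mask)
    boxes = {}
    for cls in np.unique(arr):
        if cls in (0, 255):  # skip background and void boundary labels
            continue
        ys, xs = np.nonzero(arr == cls)
        boxes[int(cls)] = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
    return boxes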
def main(args):
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path, image_set='train', download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(T.ColorJitter(brightness=0.3, contrast=0.3,
                                                                     saturation=0.1, hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, drop_last=True,
            sampler=RandomSampler(train_set), num_workers=args.workers,
            pin_memory=True, worker_init_fn=worker_init_fn)
        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path, image_set='val', download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size)),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))
        val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=args.batch_size, drop_last=False,
            sampler=SequentialSampler(val_set), num_workers=args.workers,
            pin_memory=True, worker_init_fn=worker_init_fn)
        print(f"Validation set loaded in {time.time() - st:.2f}s "
              f"({len(val_set)} samples in {len(val_loader)} batches)")

    model = segmentation.__dict__[args.model](
        args.pretrained,
        not args.pretrained,
        num_classes=len(VOC_CLASSES),
    )

    # Loss setup
    loss_weight = None
    if isinstance(args.bg_factor, float):
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'label_smoothing':
        criterion = holocron.nn.LabelSmoothingCrossEntropy(weight=loss_weight, ignore_index=255)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight, ignore_index=255)

    # Optimizer setup
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params, args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params, args.lr, betas=(0.95, 0.99),
                                     eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params, args.lr, betas=(0.95, 0.99),
                                         eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params, args.lr, betas=(0.95, 0.99),
                                         eps=1e-6, weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params, args.lr, betas=(0.95, 0.99),
                                             eps=1e-6, weight_decay=args.weight_decay)

    trainer = SegmentationTrainer(model, train_loader, val_loader, criterion, optimizer,
                                  args.device, args.output_file, num_classes=len(VOC_CLASSES))
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(f"Validation loss: {eval_metrics['val_loss']:.4} "
              f"(Mean IoU: {eval_metrics['mean_iou']:.2%})")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until, args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
def __init__(self, args):
    # TODO: augmentation.
    t_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.args = args
    self.get_palette = get_palette(256)

    # Saver
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    # Tensorboard
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    # Dataloader
    kwargs = {'num_workers': args.num_workers, 'pin_memory': True}

    # TODO: dataset download
    # self.train_loader, self.val_loader, self.test_loader, self.nclass = \
    #     get_pascalvoc(args, base_dir=args.pascal_dataset_path, transforms_train=t)
    t = trainsforms_default()
    train_set = VOCSegmentation(root='./dataset/', year='2012', image_set='train',
                                download=False, transform=t, target_transform=t_val)
    val_set = VOCSegmentation(root='./dataset/', year='2012', image_set='val',
                              download=False, transform=t, target_transform=t_val)

    # Dataset
    self.train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                   num_workers=args.num_workers, drop_last=True, pin_memory=True)
    self.val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.num_workers, drop_last=True, pin_memory=True)

    # Network
    self.model = deeplabV3plus(backbone=args.backbone,
                               output_stride=args.out_stride,
                               # num_classes=self.nclass,
                               num_classes=21,
                               sync_bn=args.sync_bn,
                               freeze_bn=args.freeze_bn).to(self.device)
    train_params = [{'params': self.model.get_1x_lr_params(), 'lr': args.lr},
                    {'params': self.model.get_10x_lr_params(), 'lr': args.lr * 10}]

    # Optimizer
    self.optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                     weight_decay=args.weight_decay, nesterov=args.nesterov)

    # Criterion
    # Whether to use class-balanced weights.
    if args.use_balanced_weights:
        pass  # TODO:
    else:
        weight = None
    self.criterion = SegmentationLosses(weight=weight).build_loss(mode=args.loss_type)

    # CUDA
    if args.data_parallel:
        self.model = torch.nn.DataParallel(self.model)
        patch_replication_callback(self.model)

    # Evaluator
    self.evaluator = Evaluator(21)

    # LR scheduler
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))

    # Resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError('no checkpoint found at: {}'.format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        self.model.module.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print('Loaded checkpoint: {} (epoch: {})'.format(args.resume, checkpoint['epoch']))
    if args.ft:
        args.start_epoch = 0
def load_data(dataset, path, batch_size=64, normalize=False):
    if normalize:
        # Wasserstein BiGAN is trained on normalized data.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        # BiGAN is trained on unnormalized data (see Dumoulin et al., ICLR 2016).
        transform = transforms.ToTensor()

    if dataset == 'svhn':
        train_set = SVHN(path, split='extra', transform=transform, download=True)
        val_set = SVHN(path, split='test', transform=transform, download=True)
    elif dataset == 'stl10':
        train_set = STL10(path, split='train', transform=transform, download=True)
        val_set = STL10(path, split='test', transform=transform, download=True)
    elif dataset == 'cifar10':
        train_set = CIFAR10(path, train=True, transform=transform, download=True)
        val_set = CIFAR10(path, train=False, transform=transform, download=True)
    elif dataset == 'cifar100':
        train_set = CIFAR100(path, train=True, transform=transform, download=True)
        val_set = CIFAR100(path, train=False, transform=transform, download=True)
    elif dataset == 'VOC07':
        train_set = VOCSegmentation(path, image_set='train', year='2007',
                                    transform=transform, download=True)
        val_set = VOCSegmentation(path, image_set='val', year='2007',
                                  transform=transform, download=True)
    elif dataset == 'VOC10':
        train_set = VOCSegmentation(path, image_set='train', year='2010',
                                    transform=transform, download=True)
        val_set = VOCSegmentation(path, image_set='val', year='2010',
                                  transform=transform, download=True)

    train_loader = data.DataLoader(train_set, batch_size, shuffle=True, num_workers=12)
    val_loader = data.DataLoader(val_set, 1, shuffle=False, num_workers=1, pin_memory=True)
    return train_loader, val_loader
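# Usage sketch: CIFAR-10 loaders with unnormalized (BiGAN-style) inputs.
# Note that the VOC options return (image, mask) pairs of varying sizes, so the
# default collate only batches them cleanly with batch_size=1.
train_loader, val_loader = load_data('cifar10', './data', batch_size=64, normalize=False)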
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_seg = transforms.Compose([
    # transforms.RandomCrop(128, 64),  # This can cause very low accuracy, pay attention!!!
    transforms.ToTensor(),
])

# Load the PASCAL VOC 2012 Segmentation dataset.
# Note that ToTensor as target_transform rescales the uint8 label map to [0, 1].
seg_dataset = VOCSegmentation('~/DeLightCMU/CVPR-Prep/Non-local_pytorch/data', year="2012",
                              image_set='train', download=False,
                              transform=transform_seg, target_transform=transform_seg)
print('VOCSeg ends.')
seg_loader = DataLoader(seg_dataset, batch_size=1)
print('seg_loader ends.')

# input_num = 0
# for input, target in seg_loader:
#     # print('for loop.')
#     print(input.size(), target.size())
#     input_num = input_num + 1
# print('input_num: ', input_num)
# exit(-1)

# for i, data in enumerate(seg_loader):