def make_data_loader(args, **kwargs):
    """Build train/val/test DataLoaders for the dataset named in ``args.dataset``.

    Args:
        args: namespace providing at least ``dataset``, ``batch_size`` and,
            for 'pascal', ``use_sbd``.
        **kwargs: forwarded verbatim to every DataLoader
            (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, num_class)``;
        ``test_loader`` is None for datasets without a test split
        (pascal, tt100k, coco).

    Raises:
        NotImplementedError: for an unrecognized ``args.dataset``.
    """
    def _loaders(train_set, val_set, test_set=None):
        # Shared construction for every dataset: only the training
        # split is shuffled; num_class is read off the train set.
        train_loader = DataLoader(train_set, batch_size=args.batch_size,
                                  shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size,
                                shuffle=False, **kwargs)
        test_loader = None
        if test_set is not None:
            test_loader = DataLoader(test_set, batch_size=args.batch_size,
                                     shuffle=False, **kwargs)
        return train_loader, val_loader, test_loader, train_set.NUM_CLASSES

    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            # Augment PASCAL train with SBD, excluding images already in val.
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train],
                                               excluded=[val_set])
        return _loaders(train_set, val_set)
    elif args.dataset == 'tt100k':
        return _loaders(tt100k.TT100KSegmentation(args, split='train'),
                        tt100k.TT100KSegmentation(args, split='val'))
    elif args.dataset == 'cityscapes':
        return _loaders(cityscapes.CityscapesSegmentation(args, split='train'),
                        cityscapes.CityscapesSegmentation(args, split='val'),
                        cityscapes.CityscapesSegmentation(args, split='test'))
    elif args.dataset == 'coco':
        return _loaders(coco.COCOSegmentation(args, split='train'),
                        coco.COCOSegmentation(args, split='val'))
    else:
        raise NotImplementedError
def __init__(self, args):
    """Set up TT100K evaluation: val loader, DeepLab model, and checkpoint.

    Args:
        args: namespace providing ``test_batch_size``, ``workers``,
            ``backbone``, ``out_stride``, ``cuda``, ``gpu_ids`` and
            ``weight`` (path to a checkpoint file).

    Raises:
        ValueError: if ``args.weight`` is None.
        RuntimeError: if ``args.weight`` does not name an existing file.
    """
    self.args = args

    # Define Dataloader — validation split only; never shuffled for eval.
    test_set = tt100k.TT100KSegmentation(args, split='val')
    self.nclass = test_set.NUM_CLASSES
    self.test_loader = DataLoader(test_set,
                                  batch_size=args.test_batch_size,
                                  shuffle=False,
                                  num_workers=args.workers)

    # Define network
    self.model = DeepLab(num_classes=self.nclass,
                         backbone=args.backbone,
                         output_stride=args.out_stride)

    # Using cuda
    if args.cuda:
        # NOTE(review): set_device expects a single device index — confirm
        # args.gpu_ids is an int here, not a list.
        torch.cuda.set_device(self.args.gpu_ids)
        self.model = self.model.cuda()

    # Load pretrained weights. Explicit raise instead of `assert` so the
    # check still runs under `python -O` (asserts are stripped there).
    if args.weight is None:
        raise ValueError("args.weight must point to a checkpoint file")
    if not os.path.isfile(args.weight):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.weight))
    checkpoint = torch.load(args.weight)
    self.model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}'".format(args.weight))

    self.show = False
    self.outdir = 'run/mask'
    if not self.show:
        # When writing masks to disk, start from a clean output directory.
        if os.path.exists(self.outdir):
            shutil.rmtree(self.outdir)
        os.makedirs(self.outdir)