def main():
    """Visualize FCN8 predictions on random samples of the VOC val split.

    Loads the pretrained model ('SBD.pth'), picks up to 11 random validation
    images, writes the source image, ground-truth label map and predicted
    label map under ``<path>/image/``, and prints the per-image loss.
    """
    use_cuda = torch.cuda.is_available()
    path = os.path.expanduser('~/codedata/seg/')
    dataset = data.VOCClassSeg(root=path, split='val.txt', transform=True)

    model = models.FCN8(path)
    model.load('SBD.pth')
    model.eval()
    if use_cuda:
        model.cuda()

    criterion = utils.CrossEntropyLoss2d(size_average=False, ignore_index=255)

    # The original looped over the whole dataset and broke after i == 10;
    # iterate exactly that many times instead (11 samples, or fewer if the
    # dataset is smaller).
    for i in range(min(11, len(dataset))):
        idx = random.randrange(0, len(dataset))
        img, label = dataset[idx]
        img_name = str(i)

        # Save the un-normalized source image and the ground-truth label map.
        img_src, _ = dataset.untransform(img, label)
        cv2.imwrite(path + 'image/%s_src.jpg' % img_name, img_src)
        utils.tool.labelTopng(label, path + 'image/%s_label.png' % img_name)
        print(img_name)

        if use_cuda:
            img = img.cuda()
            label = label.cuda()
        # volatile=True: pre-0.4 PyTorch idiom for no-autograd inference;
        # kept for consistency with the rest of this file.
        img = Variable(img.unsqueeze(0), volatile=True)
        label = Variable(label.unsqueeze(0), volatile=True)

        out = model(img)
        loss = criterion(out, label)
        print('loss:', loss.data[0])

        # Argmax over the class dimension, then drop singleton dims -> (H, W).
        label = out.data.max(1)[1].squeeze_(1).squeeze_(0)
        if use_cuda:
            label = label.cpu()
        utils.tool.labelTopng(label, path + 'image/%s_out.png' % img_name)
def evaluate():
    """Score the pretrained FCN8 model on the full VOC validation split.

    Collects ground-truth and predicted label maps for every image, then
    prints overall accuracy, per-class accuracy, mean IU and
    frequency-weighted accuracy (all scaled to percentages).
    """
    use_cuda = torch.cuda.is_available()
    path = os.path.expanduser('~/codedata/seg/')

    val_data = data.VOCClassSeg(root=path, split='val.txt', transform=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=5)

    print('load model .....')
    model = models.FCN8(path)
    model.load('SBD.pth')
    if use_cuda:
        model.cuda()
    model.eval()

    label_trues = []
    label_preds = []
    for idx in range(len(val_data)):
        img, label = val_data[idx]
        img = img.unsqueeze(0)
        if use_cuda:
            img = img.cuda()
        # Pre-0.4 PyTorch: volatile=True disables autograd for inference.
        img = Variable(img, volatile=True)

        out = model(img)
        pred = out.data.max(1)[1].squeeze_(1).squeeze_(0)
        if use_cuda:
            pred = pred.cpu()

        label_trues.append(label.numpy())
        label_preds.append(pred.numpy())
        if idx % 30 == 0:
            print('evaluate [%d/%d]' % (idx, len(val_loader)))

    # Scale every metric to a percentage before reporting.
    metrics = np.array(utils.tool.accuracy_score(label_trues, label_preds))
    metrics *= 100
    print('''\
Accuracy: {0}
Accuracy Class: {1}
Mean IU: {2}
FWAV Accuracy: {3}'''.format(*metrics))
def __init__(self, args):
    """Build trainer state from parsed command-line ``args``.

    Sets up the train/eval datasets and loaders, selects model, loss,
    optimizer and LR scheduler, wraps the model in ``nn.DataParallel``,
    and optionally restores state from a resume or finetune checkpoint.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options (mode, epochs, dataset, data_path, crop sizes,
        stride, batch size, model, loss, schedule_mode, lr, cuda,
        resume, finetune, init_eval).

    Raises
    ------
    NotImplementedError
        For an unsupported dataset, model, loss, or schedule mode.
    AssertionError
        If both ``resume`` and ``finetune`` are given.
    """
    self.args = args
    self.mode = args.mode
    self.epochs = args.epochs
    self.dataset = args.dataset
    self.data_path = args.data_path
    self.train_crop_size = args.train_crop_size
    self.eval_crop_size = args.eval_crop_size
    self.stride = args.stride
    self.batch_size = args.train_batch_size

    self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                    dataset=self.dataset,
                                    data_path=self.data_path,
                                    mode='train')
    self.train_loader = DataLoader(self.train_data,
                                   batch_size=self.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    self.eval_data = AerialDataset(dataset=self.dataset,
                                   data_path=self.data_path,
                                   mode='val')
    self.eval_loader = DataLoader(self.eval_data,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=2)

    # Per dataset: (number of classes, full-tile width, full-tile height).
    # The tile size feeds get_test_times to work out crops per epoch.
    dataset_specs = {
        'Potsdam': (6, 6000, 6000),
        'UDD5': (5, 4000, 3000),
        'UDD6': (6, 4000, 3000),
    }
    if self.dataset not in dataset_specs:
        raise NotImplementedError
    self.num_of_class, tile_w, tile_h = dataset_specs[self.dataset]
    self.epoch_repeat = get_test_times(tile_w, tile_h,
                                       self.train_crop_size,
                                       self.train_crop_size)

    # Dispatch tables instead of long if/elif chains; every factory takes
    # the class count as keyword argument.
    model_factories = {
        'FCN': models.FCN8,
        'DeepLabV3+': lambda num_classes: models.DeepLab(
            num_classes=num_classes, backbone='resnet'),
        'GCN': models.GCN,
        'UNet': models.UNet,
        'ENet': models.ENet,
        'D-LinkNet': models.DinkNet34,
    }
    if args.model not in model_factories:
        raise NotImplementedError
    self.model = model_factories[args.model](num_classes=self.num_of_class)

    loss_factories = {
        'CE': CrossEntropyLoss2d,
        'LS': LovaszSoftmax,
        'F': FocalLoss,
        'CE+D': CE_DiceLoss,
    }
    if args.loss not in loss_factories:
        raise NotImplementedError
    self.criterion = loss_factories[args.loss]()

    self.schedule_mode = args.schedule_mode
    self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
    if self.schedule_mode == 'step':
        self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                 step_size=30, gamma=0.1)
    elif self.schedule_mode in ('miou', 'acc'):
        # Plateau on the monitored metric (mIoU or accuracy), so mode='max'.
        self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', patience=10, factor=0.1)
    elif self.schedule_mode == 'poly':
        iters_per_epoch = len(self.train_loader)
        self.scheduler = Poly(self.optimizer,
                              num_epochs=args.epochs,
                              iters_per_epoch=iters_per_epoch)
    else:
        raise NotImplementedError

    self.evaluator = Evaluator(self.num_of_class)

    # Capture the concrete model class name BEFORE wrapping: afterwards
    # self.model.__class__.__name__ is always 'DataParallel', which made the
    # original SummaryWriter comment useless for telling runs apart.
    model_name = self.model.__class__.__name__
    self.model = nn.DataParallel(self.model)
    self.cuda = args.cuda
    if self.cuda:
        self.model = self.model.cuda()

    self.resume = args.resume
    self.finetune = args.finetune
    assert not (self.resume is not None and self.finetune is not None), \
        'resume and finetune are mutually exclusive'
    if self.resume is not None:
        print("Loading existing model...")
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU box.
        if self.cuda:
            checkpoint = torch.load(args.resume)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        self.model.load_state_dict(checkpoint['parameters'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.start_epoch = checkpoint['epoch'] + 1  # start from next epoch
    elif self.finetune is not None:
        print("Loading existing model...")
        if self.cuda:
            checkpoint = torch.load(args.finetune)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')
        # Finetune restores weights only, not optimizer/scheduler state.
        self.model.load_state_dict(checkpoint['parameters'])
        self.start_epoch = checkpoint['epoch'] + 1
    else:
        self.start_epoch = 1

    if self.mode == 'train':
        self.writer = SummaryWriter(comment='-' + self.dataset + '_'
                                    + model_name + '_' + args.loss)
    self.init_eval = args.init_eval
path = os.path.expanduser('~/codedata/seg/') print('load data....') train_data = data.SBDClassSeg(root=path, split='train.txt', transform=True) train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=5) val_data = data.VOCClassSeg(root=path, split='val_val.txt', transform=True) val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=5) print('load model.....') model = models.FCN8(path) # model = models.FCN8(path) if pretrained is 'pretrain': VGG16 = torchvision.models.vgg16(pretrained=True) model.copy_params_from_vgg16(VGG16) elif pretrained is 'reload': model.load('SBD.pth') else: print("no pretrained model load") if use_cuda: model.cuda() criterion = utils.loss.CrossEntropyLoss2d(size_average=False, ignore_index=255) optimizer = torch.optim.SGD([{