Пример #1
0
def main():
    """Run the saved FCN8 model on randomly chosen validation samples.

    At most 11 samples are processed. For each one, three files are written
    under ``<path>/image/`` (source image, ground-truth label and predicted
    label, all named after the loop counter), and the loop counter and the
    loss value are printed.
    """
    data_root = os.path.expanduser('~/codedata/seg/')
    gpu_available = torch.cuda.is_available()

    val_set = data.VOCClassSeg(root=data_root,
                               split='val.txt',
                               transform=True)

    net = models.FCN8(data_root)
    net.load('SBD.pth')
    net.eval()
    if gpu_available:
        net.cuda()

    loss_fn = utils.CrossEntropyLoss2d(size_average=False, ignore_index=255)

    sample_count = len(val_set)
    # Counters 0..10 are visited, i.e. at most 11 samples.
    for step in range(min(sample_count, 11)):
        # Samples are drawn at random (with replacement); output files are
        # named after the loop counter, not the dataset index.
        pick = random.randrange(0, sample_count)
        image, target = val_set[pick]
        tag = str(step)

        # NOTE(review): untransform presumably undoes the dataset's
        # preprocessing so the image can be written by OpenCV -- confirm.
        source_image, _ = val_set.untransform(image, target)
        cv2.imwrite(data_root + 'image/%s_src.jpg' % tag, source_image)
        utils.tool.labelTopng(target, data_root + 'image/%s_label.png' % tag)

        print(tag)

        if gpu_available:
            image, target = image.cuda(), target.cuda()
        # Add a leading batch dimension of size 1 before the forward pass.
        image = Variable(image.unsqueeze(0), volatile=True)
        target = Variable(target.unsqueeze(0), volatile=True)

        scores = net(image)
        loss = loss_fn(scores, target)
        print('loss:', loss.data[0])

        # Index of the maximum along dim 1, then drop the singleton dims.
        class_ids = scores.data.max(1)[1]
        prediction = class_ids.squeeze_(1).squeeze_(0)
        if gpu_available:
            prediction = prediction.cpu()

        utils.tool.labelTopng(prediction,
                              data_root + 'image/%s_out.png' % tag)
Пример #2
0
def evaluate():
    """Score the saved FCN8 model on the ``val.txt`` split and print metrics.

    Every sample of the validation set is run through the model one at a
    time. Ground-truth and predicted label arrays are collected and handed
    to ``utils.tool.accuracy_score``; the four returned values are scaled by
    100 and printed.
    """
    gpu_available = torch.cuda.is_available()
    data_root = os.path.expanduser('~/codedata/seg/')

    val_set = data.VOCClassSeg(root=data_root,
                               split='val.txt',
                               transform=True)
    # The loader is never iterated here; only its length is used in the
    # progress message below.
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=5)

    print('load model .....')
    net = models.FCN8(data_root)
    net.load('SBD.pth')
    if gpu_available:
        net.cuda()
    net.eval()

    ground_truths = []
    predictions = []
    for index in range(len(val_set)):
        image, target = val_set[index]

        # Add a leading batch dimension of size 1 before the forward pass.
        batch = image.unsqueeze(0)
        if gpu_available:
            batch = batch.cuda()
        batch = Variable(batch, volatile=True)

        scores = net(batch)
        # Index of the maximum along dim 1, then drop the singleton dims.
        predicted = scores.data.max(1)[1].squeeze_(1).squeeze_(0)
        if gpu_available:
            predicted = predicted.cpu()

        ground_truths.append(target.numpy())
        predictions.append(predicted.numpy())

        if index % 30 == 0:
            print('evaluate [%d/%d]' % (index, len(val_loader)))

    results = np.array(utils.tool.accuracy_score(ground_truths, predictions))
    results *= 100

    report = '''\
            Accuracy: {0}
            Accuracy Class: {1}
            Mean IU: {2}
            FWAV Accuracy: {3}'''
    print(report.format(*results))
Пример #3
0
    def __init__(self, args):
        """Build datasets, model, loss, optimizer and scheduler from ``args``.

        Args:
            args: Options object. The attributes read here are ``mode``,
                ``epochs``, ``dataset``, ``data_path``, ``train_crop_size``,
                ``eval_crop_size``, ``stride``, ``train_batch_size``,
                ``model``, ``loss``, ``schedule_mode``, ``lr``, ``cuda``,
                ``resume``, ``finetune`` and ``init_eval``.

        Raises:
            NotImplementedError: If ``args.dataset``, ``args.model``,
                ``args.loss`` or ``args.schedule_mode`` is not one of the
                values handled below.
            AssertionError: If both ``args.resume`` and ``args.finetune``
                are given.
        """
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size

        # --- Data -----------------------------------------------------------
        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        # --- Per-dataset constants -------------------------------------------
        # NOTE(review): the two numbers are presumably the full image
        # dimensions of each dataset -- confirm against get_test_times.
        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            full_size = (6000, 6000)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            full_size = (4000, 3000)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            full_size = (4000, 3000)
        else:
            raise NotImplementedError
        self.epoch_repeat = get_test_times(full_size[0], full_size[1],
                                           self.train_crop_size,
                                           self.train_crop_size)

        # --- Model -----------------------------------------------------------
        if args.model == 'FCN':
            self.model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            self.model = models.DeepLab(num_classes=self.num_of_class,
                                        backbone='resnet')
        elif args.model == 'GCN':
            self.model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            self.model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            self.model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            self.model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError

        # --- Loss ------------------------------------------------------------
        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError

        # --- Optimizer and learning-rate schedule ----------------------------
        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode in ('miou', 'acc'):
            # mode='max': the monitored quantity is expected to increase.
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError

        self.evaluator = Evaluator(self.num_of_class)

        # The optimizer was built from the unwrapped model's parameters; the
        # wrapper is applied afterwards, so checkpoints are loaded into (and
        # therefore keyed by) the DataParallel-wrapped module.
        self.model = nn.DataParallel(self.model)

        self.cuda = args.cuda
        if self.cuda is True:
            self.model = self.model.cuda()

        # --- Optional checkpoint restore -------------------------------------
        self.resume = args.resume
        self.finetune = args.finetune
        assert self.resume is None or self.finetune is None, \
            "resume and finetune are mutually exclusive"

        checkpoint_path = (self.resume if self.resume is not None
                           else self.finetune)
        if checkpoint_path is None:
            self.start_epoch = 1
        else:
            print("Loading existing model...")
            # Without CUDA, remap every stored tensor onto the CPU.
            map_location = None if self.cuda else 'cpu'
            checkpoint = torch.load(checkpoint_path,
                                    map_location=map_location)
            self.model.load_state_dict(checkpoint['parameters'])
            if self.resume is not None:
                # Resuming also restores optimizer and scheduler state;
                # fine-tuning restores the weights only.
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            # Start from the epoch after the one stored in the checkpoint.
            self.start_epoch = checkpoint['epoch'] + 1

        if self.mode == 'train':
            # NOTE(review): self.model is already wrapped in nn.DataParallel
            # here, so the class name in the run comment is always
            # "DataParallel". Left unchanged to keep existing run names
            # stable; use self.model.module.__class__.__name__ if the
            # underlying architecture name is wanted.
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        self.model.__class__.__name__ + '_' +
                                        args.loss)
        self.init_eval = args.init_eval
Пример #4
0
# Root directory for the datasets; also passed to the model constructor.
path = os.path.expanduser('~/codedata/seg/')

print('load data....')
# Training uses the SBD split, validation the VOC 'val_val.txt' split.
train_data = data.SBDClassSeg(root=path, split='train.txt', transform=True)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=5)
val_data = data.VOCClassSeg(root=path, split='val_val.txt', transform=True)
val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=5)

print('load model.....')
model = models.FCN8(path)

# Compare by value: `is` tests object identity, which only happens to work
# for interned string literals and fails for strings built at runtime
# (e.g. parsed from the command line).
if pretrained == 'pretrain':
    # Initialise from the torchvision VGG16 pretrained weights.
    VGG16 = torchvision.models.vgg16(pretrained=True)
    model.copy_params_from_vgg16(VGG16)
elif pretrained == 'reload':
    model.load('SBD.pth')
else:
    print("no pretrained model load")

if use_cuda:
    model.cuda()

# NOTE(review): ignore_index=255 is presumably the "void" label of the
# segmentation masks, and size_average=False presumably sums the loss
# rather than averaging it -- confirm in utils.loss.CrossEntropyLoss2d.
criterion = utils.loss.CrossEntropyLoss2d(size_average=False, ignore_index=255)
optimizer = torch.optim.SGD([{