예제 #1
0
    def infer_batch(self, image, label):
        """Run the network on one batch and return its confusion histogram.

        Args:
            image: Input batch; moved to the current CUDA device before the
                forward pass.
            label: Ground-truth labels for ``image``; also moved to CUDA.

        Returns:
            The result of ``fast_hist(label, pred)`` for this batch, where
            ``pred`` is the index of the largest logit along dim 1
            (presumably the class dimension — confirm against ``self.net``).
        """
        import torch  # local import: no module-level import is visible here

        image = image.cuda()
        label = label.cuda()
        # Pure inference: do not build the autograd graph, so intermediate
        # activations are released as soon as the forward pass finishes.
        with torch.no_grad():
            logit = self.net(image)

        pred = logit.max(dim=1)[1]

        return fast_hist(label, pred)
예제 #2
0
    def inf_batch(self, image, label):
        """Score one batch and add its confusion histogram to ``self.hist``.

        Inputs are used on whatever device they already live on (the
        ``.cuda()`` transfer is deliberately disabled). Nothing is returned;
        the running ``self.hist`` is updated in place.
        """
        with torch.no_grad():
            scores = self.net(image)

        # ``max`` over dim 1 yields (values, indices); only indices are kept.
        _, predicted = scores.max(dim=1)
        self.hist += fast_hist(label, predicted)
예제 #3
0
파일: eval.py 프로젝트: Cynicsss/Zy3s_Seg
def eval(args):
    """Evaluate the ``./Res60000.pth`` checkpoint on CityScapes ``val`` with DDP.

    Each process evaluates its own shard (``DistributedSampler``) and
    accumulates a local confusion histogram with ``fast_hist``. The per-rank
    histograms are summed with a single ``all_reduce`` after the loop, and the
    mIoU returned by ``cal_scores`` is printed (by every rank).

    Note: with ``drop_last=False`` the ``DistributedSampler`` pads the index
    list so every rank gets the same number of samples, so a few validation
    images may be counted twice when the dataset size is not divisible by
    the number of GPUs.

    Args:
        args: Namespace providing ``local_rank`` (this process's GPU index).

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    torch.cuda.set_device(args.local_rank)
    # Every process must register its own rank; a hard-coded rank=0 makes
    # all processes claim the same rank as soon as world_size > 1.
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:{}'.format(
                                config_CS.port),
                            world_size=torch.cuda.device_count(),
                            rank=args.local_rank)

    dataset = CityScapes(mode='val')
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=4,
                            pin_memory=True,
                            drop_last=False)

    net = HighOrder(19)
    net.cuda()
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank],
                                              output_device=args.local_rank)
    # Load to CPU first; without map_location every rank would materialise
    # the checkpoint on the GPU it was saved from.
    net.load_state_dict(torch.load('./Res60000.pth', map_location='cpu'))
    net.eval()

    num = 0
    hist = 0  # becomes the (local) histogram after the first batch
    with torch.no_grad():
        # Iterating the loader directly lets StopIteration end the loop while
        # real data-loading errors still propagate.
        for image, label, _ in dataloader:
            image = image.cuda()
            label = torch.squeeze(label.cuda(), 1)

            output = net(image)
            pred = output.max(dim=1)[1]
            hist = hist + fast_hist(label, pred)
            num += 1
            if num % 50 == 0:
                print('iter :{}'.format(num))

        if num == 0:
            raise RuntimeError('validation dataloader produced no batches')

        # Reduce exactly once: reducing the running total inside the loop
        # would multiply earlier counts by world_size on every iteration.
        hist = torch.as_tensor(hist).cuda()
        dist.all_reduce(hist, dist.ReduceOp.SUM)
        hist = hist.cpu().numpy().astype(np.float32)
        miou = cal_scores(hist)

    print('miou = {}'.format(miou))
예제 #4
0
def val_segmentation(epoch, net_segmentation):
    """Run one validation pass, log loss/mIoU, and checkpoint on a new best.

    Reads the module-level ``val_seg_loader``, ``device``, ``nClasses`` and
    ``ignore_class``, and updates module-level state:

    * appends the epoch's mean ``1 - soft_iou`` loss to ``val_seg_loss``;
    * appends the epoch mIoU (first value returned by ``performMetrics``)
      to ``val_seg_iou``;
    * when that mIoU exceeds ``best_iou``, stores it in ``best_iou`` and
      saves ``{'net_segmentation': net_segmentation}`` with ``torch.save``.

    Args:
        epoch: Current epoch number (not used by the body).
        net_segmentation: Segmentation model; switched to eval mode here.
    """
    global best_iou
    global val_seg_iou
    net_segmentation.eval()

    val_seg_loss.append(0)
    hist = np.zeros((nClasses, nClasses))
    # no_grad: validation never calls backward(), so building the autograd
    # graph for every batch only wastes memory. Using tqdm as a context
    # manager guarantees the bar is closed even if a batch raises.
    with torch.no_grad(), tqdm(total=len(val_seg_loader),
                               desc='Val') as progbar:
        for batch_idx, (inputs_, targets) in enumerate(val_seg_loader):
            inputs_ = inputs_.to(device)
            targets = targets.to(device)

            outputs = net_segmentation(inputs_)

            total_loss = 1 - soft_iou(outputs, targets, ignore=ignore_class)
            val_seg_loss[-1] += total_loss.data

            _, predicted = torch.max(outputs.data, 1)
            correctLabel = targets.view(-1, targets.size()[1],
                                        targets.size()[2])
            hist += fast_hist(
                correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
                predicted.view(predicted.size(0), -1).cpu().numpy(), nClasses)

            # Running mIoU over all batches seen so far, for the progress bar.
            miou, _, _ = performMetrics(hist)
            progbar.set_description('Val (loss=%.4f, mIoU=%.4f)' %
                                    (val_seg_loss[-1] / (batch_idx + 1), miou))
            progbar.update(1)

    val_seg_loss[-1] = val_seg_loss[-1] / len(val_seg_loader)
    val_miou, _, _ = performMetrics(hist)
    val_seg_iou += [val_miou]

    if best_iou < val_miou:
        best_iou = val_miou
        print('Saving..')
        state = {'net_segmentation': net_segmentation}

        torch.save(state,
                   model_root + experiment + 'segmentation' + '.ckpt.t7')
예제 #5
0
def train_segmentation(epoch, net_segmentation, seg_optimizer):
    """Train the segmentation model for one epoch with gradient accumulation.

    For each batch, the loss ``1 - soft_iou`` is back-propagated scaled by
    ``1 / ITER_SIZE``. ``seg_optimizer`` steps after every ``ITER_SIZE``
    batches, and once more after the final batch so leftover accumulated
    gradients are not discarded.

    Reads the module-level ``train_seg_loader``, ``device``, ``nClasses``,
    ``ignore_class`` and ``ITER_SIZE``. Appends the epoch's mean (unscaled)
    loss to ``train_seg_loss`` and the epoch mIoU to ``train_seg_iou``.

    Args:
        epoch: Current epoch number (not used by the body).
        net_segmentation: Segmentation model; switched to train mode here.
        seg_optimizer: Optimizer stepped on the accumulated gradients.
    """
    global train_seg_iou
    net_segmentation.train()

    train_seg_loss.append(0)
    seg_optimizer.zero_grad()
    hist = np.zeros((nClasses, nClasses))
    num_batches = len(train_seg_loader)
    # tqdm as a context manager guarantees the bar is closed.
    with tqdm(total=num_batches, desc='Train') as progbar:
        for batch_idx, (inputs_, targets) in enumerate(train_seg_loader):
            inputs_ = inputs_.to(device)
            targets = targets.to(device)

            outputs = net_segmentation(inputs_)

            seg_loss = 1 - soft_iou(outputs, targets, ignore=ignore_class)
            # Scale only the gradient so ITER_SIZE accumulated micro-batches
            # add up to one averaged update.
            (seg_loss / ITER_SIZE).backward()

            # (batch_idx + 1) makes every step cover exactly ITER_SIZE
            # batches; the last-batch check must use the loader actually
            # being iterated.
            if ((batch_idx + 1) % ITER_SIZE == 0
                    or batch_idx == num_batches - 1):
                seg_optimizer.step()
                seg_optimizer.zero_grad()

            # Log the unscaled loss so it is on the same scale as the
            # validation loss (which is not divided by ITER_SIZE).
            train_seg_loss[-1] += seg_loss.data

            _, predicted = torch.max(outputs.data, 1)
            correctLabel = targets.view(-1, targets.size()[1],
                                        targets.size()[2])
            hist += fast_hist(
                correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
                predicted.view(predicted.size(0), -1).cpu().numpy(), nClasses)

            # Running mIoU over all batches seen so far, for the progress bar.
            miou, _, _ = performMetrics(hist)
            progbar.set_description('Train (loss=%.4f, mIoU=%.4f)' %
                                    (train_seg_loss[-1] / (batch_idx + 1),
                                     miou))
            progbar.update(1)

    train_seg_loss[-1] = train_seg_loss[-1] / num_batches
    miou, _, _ = performMetrics(hist)
    train_seg_iou += [miou]
예제 #6
0
def evaluate_segmentation(net_segmentation):
    """Evaluate on the validation split and print mIoU and pixel accuracies.

    Builds a batch-size-1 loader over the validation images with cropping
    and mirroring disabled, accumulates a confusion histogram with
    ``fast_hist``, and prints the three values returned by
    ``performMetrics`` (mIoU, pixel accuracy, frequency-weighted accuracy).

    Args:
        net_segmentation: Segmentation model; switched to eval mode here.
    """
    net_segmentation.eval()
    val_seg_loader = torch.utils.data.DataLoader(
        segmentation_data_loader(img_root=val_img_root,
                                 gt_root=val_gt_root,
                                 image_list=val_image_list,
                                 suffix=dataset,
                                 out=out,
                                 crop=False,
                                 mirror=False),
        batch_size=1,
        num_workers=8,
        shuffle=False)

    hist = np.zeros((nClasses, nClasses))
    # no_grad: pure evaluation, no autograd graph needed. tqdm as a context
    # manager guarantees the bar is closed.
    with torch.no_grad(), tqdm(total=len(val_seg_loader),
                               desc='Eval') as progbar:
        for inputs_, targets in val_seg_loader:
            inputs_ = inputs_.to(device)
            targets = targets.to(device)

            outputs = net_segmentation(inputs_)

            _, predicted = torch.max(outputs.data, 1)
            correctLabel = targets.view(-1, targets.size()[1],
                                        targets.size()[2])
            hist += fast_hist(
                correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
                predicted.view(predicted.size(0), -1).cpu().numpy(), nClasses)

            # Running mIoU over all batches seen so far, for the progress bar.
            miou, _, _ = performMetrics(hist)
            progbar.set_description('Eval (mIoU=%.4f)' % (miou))
            progbar.update(1)

    miou, p_acc, fwacc = performMetrics(hist)
    print('\n mIoU: ', miou)
    print('\n Pixel accuracy: ', p_acc)
    print('\n Frequency Weighted Pixel accuracy: ', fwacc)
예제 #7
0
def _predict_multiscale(net, image, num_classes, scales, flip):
    """Sum softmax scores of ``net`` over rescaled (and optionally flipped) inputs.

    For every factor in ``scales``, the image is bilinearly resized and passed
    through ``net``. The first element of the net's 2-tuple output is treated
    as logits and the second is ignored. The logits are resized back to the
    input's H x W and soft-maxed over dim 1. When ``flip`` is true, the image
    flipped along dim 3 is processed too, and its output is flipped back
    before accumulation.

    Args:
        net: Model whose call returns a 2-tuple with logits first.
        image: Input batch indexed as (N, C, H, W) (unpacked via ``size()``).
        num_classes: Number of logit channels produced by ``net``.
        scales: Iterable of resize factors.
        flip: Whether to add flipped-input test-time augmentation.

    Returns:
        Tensor of shape (N, num_classes, H, W) on ``image.device``.
    """
    N, _, H, W = image.size()
    preds = torch.zeros((N, num_classes, H, W), device=image.device)
    for scale in scales:
        new_hw = [int(H * scale), int(W * scale)]
        image_change = F.interpolate(image,
                                     new_hw,
                                     mode='bilinear',
                                     align_corners=True)
        variants = [(image_change, False)]
        if flip:
            variants.append((torch.flip(image_change, dims=(3, )), True))
        for net_input, flipped in variants:
            output, _ = net(net_input)
            if flipped:
                # Undo the input flip so scores line up with the original image.
                output = torch.flip(output, dims=(3, ))
            output = F.interpolate(output, (H, W),
                                   mode='bilinear',
                                   align_corners=True)
            preds += F.softmax(output, 1)
    return preds


def eval(args):
    """Evaluate the ``./GPADE20Kres50150000.pth`` checkpoint on ADE20K ``val``.

    Runs under DDP: each process evaluates its shard (``DistributedSampler``)
    with multi-scale / optional flip test-time augmentation driven by
    ``config.eval_scales`` and ``config.eval_flip``. Each process
    accumulates a local confusion histogram with ``fast_hist``, and the
    per-rank histograms are summed with a single ``all_reduce`` after the
    loop. The mIoU returned by ``cal_scores`` is printed (by every rank).

    Note: with ``drop_last=False`` the ``DistributedSampler`` pads the index
    list so every rank gets the same number of samples, so a few images may
    be counted twice when the dataset size is not divisible by the number
    of GPUs.

    Args:
        args: Namespace providing ``local_rank`` (this process's GPU index).

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:{}'.format(
                                config.port),
                            world_size=torch.cuda.device_count(),
                            rank=args.local_rank)

    dataset = ADE20K(mode='val')
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            sampler=sampler,
                            num_workers=4,
                            drop_last=False,
                            pin_memory=True)

    num_classes = 150
    net = PANet(num_classes)
    net.cuda()
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank],
                                              output_device=args.local_rank)
    net.load_state_dict(
        torch.load('./GPADE20Kres50150000.pth', map_location='cpu'))
    net.eval()

    num = 0
    hist = 0  # becomes the (local) histogram after the first batch
    with torch.no_grad():
        # Iterating the loader directly lets StopIteration end the loop while
        # real data-loading errors still propagate.
        for image, label, _ in dataloader:
            image = image.cuda()
            label = torch.squeeze(label.cuda(), 1)

            preds = _predict_multiscale(net, image, num_classes,
                                        config.eval_scales, config.eval_flip)
            pred = preds.max(dim=1)[1]
            hist = hist + fast_hist(label, pred)
            num += 1
            if num % 5 == 0:
                print('iter: {}'.format(num))

        if num == 0:
            raise RuntimeError('validation dataloader produced no batches')

        # Reduce exactly once: reducing the running total inside the loop
        # would multiply earlier counts by world_size on every iteration.
        hist = torch.as_tensor(hist).cuda()
        dist.all_reduce(hist, dist.ReduceOp.SUM)
        hist = hist.cpu().numpy().astype(np.float32)
        miou = cal_scores(hist)

    print('miou = {}'.format(miou))