def test_mean_iou():
    """Check ``mean_iou`` on a six-node, two-class example.

    Without ``batch``: class 0 has intersection 2 / union 5 (IoU 0.4) and
    class 1 has intersection 1 / union 4 (IoU 0.25), so the expected result
    is their average.

    With ``batch`` splitting the nodes into a graph of four nodes and a graph
    of two nodes, one value per graph is expected: 1/3 for both classes in
    the first graph, and (1/2 + 0) / 2 = 0.25 in the second.
    """
    predicted = torch.tensor([0, 0, 1, 1, 0, 1])
    labels = torch.tensor([0, 1, 0, 1, 0, 0])

    overall = mean_iou(predicted, labels, num_classes=2)
    assert overall == (0.4 + 0.25) / 2

    # Nodes 0-3 belong to graph 0, nodes 4-5 to graph 1.
    graph_index = torch.tensor([0, 0, 0, 0, 1, 1])
    per_graph = mean_iou(predicted, labels, num_classes=2, batch=graph_index)
    assert per_graph[0] == (1 / 3 + 1 / 3) / 2
    assert per_graph[1] == 0.25
Example #2
0
def test(loader):
    """Evaluate the global ``model`` on every batch yielded by ``loader``.

    Each batch is moved to ``device``, run through ``model`` without gradient
    tracking, and its predicted class is taken as the arg-max over dim 1 of
    the model output.

    Returns:
        A tuple ``(accuracy, miou)``:
        * ``accuracy``: correctly predicted nodes divided by the total node
          count over all batches;
        * ``miou``: the mean of all values returned by ``mean_iou``,
          concatenated across batches (presumably one value per graph,
          since ``data.batch`` is passed).
    """
    model.eval()

    n_correct = n_seen = 0
    per_batch_iou = []
    for batch_data in loader:
        batch_data = batch_data.to(device)
        with torch.no_grad():
            logits = model(batch_data)
        _, predicted = logits.max(dim=1)  # indices of the max = predicted class
        n_correct += predicted.eq(batch_data.y).sum().item()
        per_batch_iou.append(
            mean_iou(predicted, batch_data.y, test_dataset.num_classes,
                     batch_data.batch))
        n_seen += batch_data.num_nodes

    # Accuracy is computed before the IoU concatenation, matching the original
    # evaluation order (an empty loader fails on the division first).
    accuracy = n_correct / n_seen
    miou = torch.cat(per_batch_iou, dim=0).mean().item()
    return accuracy, miou
Example #3
0
def test(loader):
    """Evaluate ``model`` on ``loader``, log the validation loss, save a checkpoint.

    Side effects:
        * appends the average of the per-batch ``F.nll_loss`` values to the
          global list ``loss_valid``;
        * calls ``save_checkpoint`` with ``{'state_dict': ...}``, the flag
          ``acc > best`` and ``exp_path``.

    Returns:
        A tuple ``(acc, miou)``: node-level accuracy over all batches, and the
        mean of all ``mean_iou`` values concatenated across batches.
    """
    model.eval()
    hits = seen = 0
    iou_chunks = []
    batch_losses = []
    for batch_data in loader:
        batch_data = batch_data.to(device)
        with torch.no_grad():
            logits = model(batch_data)
            batch_losses.append(F.nll_loss(logits, batch_data.y).item())
        _, predicted = logits.max(dim=1)
        hits += predicted.eq(batch_data.y).sum().item()
        iou_chunks.append(
            mean_iou(predicted, batch_data.y, params.nClassesTotal,
                     batch_data.batch))
        seen += batch_data.num_nodes

    # NOTE(review): this is an unweighted mean of per-batch mean losses, so a
    # smaller final batch counts as much as a full one (unlike `acc`, which
    # is node-weighted). Confirm this is the intended metric.
    loss_valid.append(sum(batch_losses) / len(batch_losses))
    acc = hits / seen
    # NOTE(review): `best` is read here but not updated in this function —
    # presumably the caller maintains it; confirm.
    save_checkpoint({'state_dict': model.state_dict()}, acc > best, exp_path)
    return acc, torch.cat(iou_chunks, dim=0).mean().item()
Example #4
0
def test(loader):
    """Evaluate ``model`` on ``loader`` and pickle each batch's predictions.

    For the ``i``-th batch, the predicted labels (``pred.cpu().numpy()``) are
    written with :mod:`pickle` to ``args.out_dir``. The file name is the base
    name of ``path_list[i]`` with every ``'graph'`` substring replaced by
    ``'face_gnn'``.

    Returns:
        A tuple ``(acc, miou)``: node-level accuracy over all batches, and the
        mean of all ``mean_iou`` values concatenated across batches.
    """
    model.eval()
    correct_nodes = total_nodes = 0
    ious = []
    for i, data in enumerate(loader):
        # NOTE(review): assumes path_list[i] names the source of the i-th
        # batch, i.e. the loader is not shuffled — confirm.
        # osp.basename handles both separators on Windows and accepts
        # PathLike entries, unlike str.split(os.sep).
        out_path = osp.join(
            args.out_dir,
            osp.basename(path_list[i]).replace('graph', 'face_gnn'))

        data = data.to(device)
        with torch.no_grad():
            out = model(data)
        # The per-batch NLL loss that used to be computed here was never read
        # or returned, so it has been dropped.
        pred = out.max(dim=1)[1]
        with open(out_path, 'wb') as file:
            pickle.dump(pred.cpu().numpy(), file)

        correct_nodes += pred.eq(data.y).sum().item()
        ious.append(mean_iou(pred, data.y, params.nClassesTotal, data.batch))
        total_nodes += data.num_nodes
    acc = correct_nodes / total_nodes
    return acc, torch.cat(ious, dim=0).mean().item()