Пример #1
0
def trainaccuracy():
    """Evaluate ``net`` on the ``earlystopping`` loader and dump per-sample PNGs.

    For every sample, three images are written to ``build/``:
    ``<k>_x.png`` (input), ``<k>_y.png`` (target mask) and ``<k>_z.png``
    (predicted mask). ``<k>`` is a running sample index across all batches.

    Relies on module-level names: ``net``, ``earlystopping``, ``device``,
    ``dataloader``, ``confusion_matrix``, ``PIL``, ``np`` and ``torch``.

    Returns:
        The top-left 2x2 block (classes 0 and 1) of the 3-class confusion
        matrix accumulated over the whole loader. ``confusion_matrix`` is
        called with predictions first and targets second, so entry ``[i, j]``
        counts pixels predicted as class ``i`` whose target is class ``j``.
        This is transposed with respect to scikit-learn's
        ``(y_true, y_pred)`` convention.
    """
    import os

    # PIL's save() raises FileNotFoundError when the directory is missing.
    os.makedirs("build", exist_ok=True)

    cm = np.zeros((3, 3), dtype=int)
    net.eval()
    # Running index used in file names. The previous ``i * 16 + j`` assumed
    # a batch size of 16 and overwrote or skipped files for any other size.
    sample_idx = 0
    with torch.no_grad():
        for inputs, targets in earlystopping:
            inputs, targets = inputs.to(device), targets.to(device)

            targets = dataloader.convertIn3class(targets)

            outputs = net(inputs)
            _, pred = outputs.max(1)

            # Move dim 1 to the end, e.g. (N, C, H, W) -> (N, H, W, C),
            # so each sample is in the layout PIL expects. Identical to the
            # former transpose(1, 2) followed by transpose(2, 3).
            images = inputs.permute(0, 2, 3, 1).cpu().numpy()
            # Transfer to host once per batch rather than once per sample.
            targets_np = targets.cpu().numpy()
            pred_np = pred.cpu().numpy()

            batch_len = pred_np.shape[0]
            for j in range(batch_len):
                stem = "build/" + str(sample_idx + j)
                PIL.Image.fromarray(np.uint8(images[j])).save(stem + "_x.png")
                # Class ids are multiplied by 125 so they show up as distinct
                # grey levels (presumably ids 0..2 -> 0, 125, 250).
                PIL.Image.fromarray(np.uint8(targets_np[j]) * 125).save(
                    stem + "_y.png"
                )
                PIL.Image.fromarray(np.uint8(pred_np[j]) * 125).save(
                    stem + "_z.png"
                )
            sample_idx += batch_len

            # One matrix over the flattened batch equals the sum of the
            # per-sample matrices, because the label set is fixed.
            cm += confusion_matrix(
                pred_np.flatten(),
                targets_np.flatten(),
                labels=[0, 1, 2],
            )
    # Class 2 is dropped; only classes 0 and 1 are reported.
    return cm[0:2, 0:2]
Пример #2
0
def trainaccuracy():
    """Accumulate a pixel-level confusion matrix of ``net`` over ``earlystopping``.

    Targets are remapped with ``dataloader.convertIn3class`` before being
    compared. Predictions are the argmax over dim 1 of the network output.

    Relies on module-level names: ``net``, ``earlystopping``, ``device``,
    ``dataloader``, ``confusion_matrix``, ``np`` and ``torch``.

    Returns:
        The top-left 2x2 block (classes 0 and 1) of the 3-class matrix.
        Predictions are passed as the first argument of ``confusion_matrix``,
        so entry ``[i, j]`` counts pixels predicted ``i`` with target ``j``.
    """
    confusion = np.zeros((3, 3), dtype=int)
    net.eval()
    with torch.no_grad():
        for batch_inputs, batch_targets in earlystopping:
            batch_inputs = batch_inputs.to(device)
            batch_targets = dataloader.convertIn3class(batch_targets.to(device))

            predicted = net(batch_inputs).max(1)[1]

            # With a fixed label set, one matrix over the flattened batch
            # equals the sum of the per-sample matrices.
            confusion += confusion_matrix(
                predicted.cpu().numpy().flatten(),
                batch_targets.cpu().numpy().flatten(),
                labels=[0, 1, 2],
            )
    return confusion[:2, :2]
        # Collect the patches returned by select_rootcentredpatch for every
        # image of this town. Each element is unpacked below as an
        # (image, label, mask) triple.
        # NOTE(review): this fragment sits inside an outer loop over `town`
        # whose header is not visible here.
        XYA = []
        for i in range(miniworld.data[town].nbImages):
            imageraw, label = miniworld.data[town].getImageAndLabel(i)
            tmp = select_rootcentredpatch(imageraw, label)
            if tmp is not None:
                # NOTE(review): the helper is called a second time instead of
                # reusing `tmp`. This doubles the work, and could append a
                # different result if the helper is non-deterministic.
                # `XYA += tmp` looks intended — confirm.
                XYA += select_rootcentredpatch(imageraw, label)

        # Skip towns that produced no patch at all.
        if XYA == []:
            continue
        print(town)
        # pytorch
        # Stack images into one tensor. axes=(2, 0, 1) moves the last axis
        # first (presumably HWC -> CHW; TODO confirm).
        X = torch.stack(
            [torch.Tensor(np.transpose(x, axes=(2, 0, 1))).cpu() for x, _, _ in XYA]
        )
        Y = torch.stack([torch.from_numpy(y).long().cpu() for _, y, _ in XYA])
        Y = dataloader.convertIn3class(Y)
        A = torch.stack([torch.from_numpy(a).long().cpu() for _, _, a in XYA])
        A = dataloader.convertIn3class(A)  # remove border of the roof
        # Element-wise min with 1: every value above 1 becomes 1. This is
        # presumably meant to turn the 3-class map into a 0/1 mask.
        A = torch.min(A.float(), torch.ones(A.shape).float()).long()
        # Drop the list of numpy patches now that they are stacked.
        del XYA

        XYAtensor = torch.utils.data.TensorDataset(X, Y, A)
        # Fixed order (shuffle=False), batches of 16.
        pytorchloader = torch.utils.data.DataLoader(
            XYAtensor, batch_size=16, shuffle=False, num_workers=2
        )

        # Accumulator. Its contents are filled in code not visible here.
        ZXaZa = []
        for inputs, targets, masks in pytorchloader:
            # NOTE(review): hard-codes .cuda(), while the other snippets use
            # .to(device).
            inputs, targets, masks = inputs.cuda(), targets.cuda(), masks.cuda()
            # Inference only: no autograd graph is built.
            with torch.no_grad():
                preds = net(inputs)
Пример #4
0
# Training hyper-parameters.
nbepoch = 150
batchsize = 32

for epoch in range(nbepoch):
    print("epoch=", epoch, "/", nbepoch)

    # Draw a fresh random set of tiles each epoch. The arguments are
    # presumably (number of tiles, tile size, batch size) — TODO confirm
    # against miniworld.getrandomtiles.
    XY = miniworld.getrandomtiles(5000, 128, batchsize)
    for x, y in XY:
        x, y = x.to(device), y.to(device)

        preds = net(x)
        # Append one all-zero channel along dim 1, so preds gains an extra
        # class channel scored 0. This is presumably because net outputs 2
        # channels while the criterion is given 3 — TODO confirm.
        tmp = torch.zeros(preds.shape[0], 1, preds.shape[2], preds.shape[3])
        tmp = tmp.to(device)
        preds = torch.cat([preds, tmp], dim=1)

        yy = dataloader.convertIn3class(y)

        # Map labels to a sign: 0 -> -1, 1 -> +1.
        ypm = y * 2 - 1
        # Signed margin: channel-1 score minus channel-0 score.
        predspm = preds[:, 1, :, :] - preds[:, 0, :, :]
        # 1 where convertIn3class left the label unchanged, 0 elsewhere.
        # Pixels with 0 contribute nothing to the hinge term below.
        one_no_border = (y == yy).long()

        # NOTE(review): these shape checks are stripped when Python runs
        # with -O.
        assert ypm.shape == predspm.shape
        assert one_no_border.shape == predspm.shape

        # Penalise pixels whose margin has the wrong sign: relu(-y * f).
        hingeloss = torch.sum(
            torch.nn.functional.relu(-one_no_border * ypm * predspm))
        # can not be 0 due to the other criterion
        # ensure a linear penalty of the error rather than exponential one

        # Total loss is the criterion on logits scaled by 1000, plus the
        # hinge sum divided by yy's first three dims (presumably
        # batch * H * W, i.e. a per-pixel mean).
        loss = (criterion(preds * 1000, yy) +
                hingeloss / yy.shape[0] / yy.shape[1] / yy.shape[2])