Exemplo n.º 1
0
                cm[a][b] = torch.sum(
                    (z[j] == a).float() * (y[j] == b).float() * D[j])
            stats[batchchoise[j]] += cm

        if i < 10:
            print(i, "/", nbbatchs, printloss)
        if i < 1000 and i % 100 == 99:
            print(i, "/", nbbatchs, printloss / 100)
            printloss = torch.zeros(1).cuda()
        if i >= 1000 and i % 300 == 299:
            print(i, "/", nbbatchs, printloss / 300)
            printloss = torch.zeros(1).cuda()

        if i % 1000 == 999:
            torch.save(net, "build/model.pth")
            perf = miniworld.perf(torch.sum(stats, dim=0))
            print(i, "perf", perf)
            if perf[0] > 92:
                print("training stops after reaching high training accuracy")
                os._exit(0)
            else:
                stats = torch.zeros(
                    (len(miniworlddataset.cities), 2, 2)).cuda()

    if i > nbbatchs * 0.1:
        loss = loss * 0.5
    if i > nbbatchs * 0.2:
        loss = loss * 0.5
    if i > nbbatchs * 0.5:
        loss = loss * 0.5
    if i > nbbatchs * 0.8:
Exemplo n.º 2
0
            z = largeforward(net, x.unsqueeze(0))
            z = globalresize(z)
            z = (z[0, 1, :, :] > z[0, 0, :, :]).float()

            for a, b in [(0, 0), (0, 1), (1, 0), (1, 1)]:
                cm[k][a][b] += torch.sum(
                    (z == a).float() * (y == b).float() * D)

            if False:
                nextI = len(os.listdir("build"))
                debug = miniworld.torchTOpil(globalresize(x))
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_x.png")
                debug = (2.0 * y - 1) * D * 127 + 127
                debug = debug.cpu().numpy()
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_y.png")
                debug = z.cpu().numpy() * 255
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_z.png")

        print("perf=", miniworld.perf(cm[k]))
        numpy.savetxt("build/tmp.txt", miniworld.perf(cm).cpu().numpy())

# Final report: one line per entry of miniworlddataset.cities, then one line
# for the confusion matrices summed over every city.
print("-------- results ----------")
for k, city in enumerate(miniworlddataset.cities):
    # cm is indexed by the city's position in miniworlddataset.cities.
    print(city, miniworld.perf(cm[k]))

# Collapse the leading (per-city) axis. Note that this rebinds the module-level
# name cm, so the per-city matrices are no longer reachable through it.
cm = cm.sum(dim=0)
print("miniworld", miniworld.perf(cm))
Exemplo n.º 3
0
            z = globalresize(z)
            z = (z[0, 1, :, :] > z[0, 0, :, :]).float()

            cm[k] += miniworld.confusion(y, z, size=size)

            if False:
                nextI = len(os.listdir("build"))
                debug = miniworld.torchTOpil(globalresize(x))
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_x.png")
                debug = y.float()
                debug = debug * 2 * (1 - miniworld.isborder(y, size=size))
                debug = debug + miniworld.isborder(y, size=size)
                debug *= 127
                debug = debug.cpu().numpy()
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_y.png")
                debug = z.cpu().numpy() * 255
                debug = PIL.Image.fromarray(numpy.uint8(debug))
                debug.save("build/" + str(nextI) + "_z.png")

        print("perf=", miniworld.perf(cm[k]))

# Score the accumulated confusion matrices, print them, and export the scores
# to the file `name` as tab-separated integers expressed in tenths of the
# score unit (score * 10).
perfs = miniworld.perf(cm)
print("miniworld", perfs[-1])
print(perfs)
# Round to the nearest tenth before casting. A bare numpy.int16(...) cast
# truncates toward zero, which always rounds scores down. It also turns a
# product that lands just below an integer (a float-representation artifact)
# into the next-lower integer.
numpy.savetxt(name,
              numpy.rint(perfs.cpu().numpy() * 10).astype(numpy.int16),
              fmt="%i",
              delimiter="\t")
Exemplo n.º 4
0
        z = (z[:, 1, :, :] > z[:, 0, :, :]).clone().detach().float()
        for j in range(batchsize):
            stats[batchchoise[j]] += miniworld.confusion(y[j], z[j])

        if i < 10:
            print(i, "/", nbbatchs, printloss)
        if i < 1000 and i % 100 == 99:
            print(i, "/", nbbatchs, printloss / 100)
            printloss = torch.zeros(1).cuda()
        if i >= 1000 and i % 300 == 299:
            print(i, "/", nbbatchs, printloss / 300)
            printloss = torch.zeros(1).cuda()

        if i % 1000 == 999:
            torch.save(net, "build/model.pth")
            perf = miniworld.perf(stats)
            print(i, "perf", perf)
            if perf[-1][0] > 95:
                print("training stops after reaching high training accuracy")
                os._exit(0)
            else:
                stats = torch.zeros((len(dataset.cities), 2, 2)).cuda()

    if i > nbbatchs * 0.1:
        loss = loss * 0.5
    if i > nbbatchs * 0.2:
        loss = loss * 0.5
    if i > nbbatchs * 0.5:
        loss = loss * 0.5
    if i > nbbatchs * 0.8:
        loss = loss * 0.5