    # a2.scatter(ps2[:,0], ps2[:,1], ps2[:,2])

    # plt.show()

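    # Instantiate the graph-based FoldingNet autoencoder and restore pretrained weights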
    foldingnet = FoldingNet_graph()

    #foldingnet.load_state_dict(torch.load('cls_fold_512code_2500points/foldingnet_model_170.pth'))
    foldingnet.load_state_dict(
        torch.load(
            'cls_fold_512code_2500points_170_restart/foldingnet_model_150.pth')
    )

    foldingnet.cuda()

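    # Chamfer distance compares the reconstructed cloud with the input point cloud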
    chamferloss = ChamferLoss()
    chamferloss = chamferloss.cuda()
    #print(foldingnet)

    foldingnet.eval()

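    # li is presumably a pre-built list of (index, batch) pairs taken from the dataloader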
    i, data = li[1]
    points, target = data

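    # Build the per-batch k-NN graph and local covariance features used by the graph encoder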
    batch_graph, Cov = build_graph(points, opt)

    Cov = Cov.transpose(2, 1)
    Cov = Cov.cuda()

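    # Move coordinates to the channel dimension: the encoder expects (batch, 3, num_points)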
    points = points.transpose(2, 1)
    points = points.cuda()
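    # Forward pass: recon_pc is the final reconstruction, mid_pc is likely the intermediate (first-fold) output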
    recon_pc, mid_pc, _ = foldingnet(points, Cov, batch_graph)
Example No. 2
    pass

#classifier = PointNetCls(k = num_classes)
foldingnet = FoldingNet()

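# Optionally resume training from a saved checkpoint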
if opt.model != '':
    foldingnet.load_state_dict(torch.load(opt.model))

#optimizer = optim.SGD(foldingnet.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam(foldingnet.parameters(), lr=0.0001, weight_decay=1e-6)
foldingnet.cuda()

num_batch = len(dataset) // opt.batchSize  # integer number of batches per epoch

chamferloss = ChamferLoss()
chamferloss.cuda()

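# Track wall-clock time and running losses for later logging/plotting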
start_time = time.time()
time_p, loss_p, loss_m = [], [], []

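# Main training loop: reconstruct each batch and minimise the Chamfer loss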
for epoch in range(opt.nepoch):
    sum_loss = 0
    sum_step = 0
    sum_mid_loss = 0
    for i, data in enumerate(dataloader, 0):
        points, target = data

        #print(points.size())

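        # Wrap in (legacy) Variables and move coordinates to the channel dimension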
        points, target = Variable(points), Variable(target[:, 0])
        points = points.transpose(2, 1)