Code Example #1
def test(param, test_npy_fn, out_ply_folder, out_img_folder, is_render_mesh=False, skip_frames=0):

    print("**********Initiate Network**********")
    model = graphAE.Model(param)
    model.cuda()

    # Restore pretrained weights if a checkpoint path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        #model.init_test_mode()

    model.train()

    print("**********Get test pcs**********", test_npy_fn)
    # Load the test meshes from the .npy file (one point cloud per mesh).
    pc_lst = np.load(test_npy_fn)
    print(pc_lst.shape[0], "meshes in total.")

    # Centralize each mesh by subtracting its per-instance mean vertex position.
    pc_lst[:, :, 0:3] -= pc_lst[:, :, 0:3].mean(1).reshape((-1, 1, 3)).repeat(param.point_num, 1)

    geo_error_avg = evaluate(param, model, pc_lst)

    print("geo error:", geo_error_avg)  # , "laplace error:", laplace_error_avg)
Code Example #2
def test(param, test_ply_fn, start_id, end_id, out_ply_folder):

    print("**********Initiate Network**********")
    model = graphAE.Model(param)
    model.cuda()

    # Restore pretrained weights if a checkpoint path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    template_plydata = PlyData.read(param.template_ply_fn)

    # Evaluate every existing ply file named <test_ply_fn><idx>.ply for idx in [start_id, end_id).
    MSE_sum = 0
    MSE_num = 0
    for i in range(start_id, end_id):
        ply_fn = test_ply_fn + "%04d" % i + ".ply"
        if not os.path.exists(ply_fn):
            continue
        pc, mean = get_pc_from_ply_fn(ply_fn)

        MSE = evaluate(param, model, "%04d" % i, pc, mean, template_plydata,
                       out_ply_folder)
        MSE_sum += MSE
        MSE_num += 1
    print("mean MSE:", MSE_sum / MSE_num)
Code Example #3
def visualize_weights(param, out_folder):
    print("**********Initiate Network**********")
    model = graphAE.Model(param, test_mode=False)
    model.cuda()

    # Restore pretrained weights if a checkpoint path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    # Visualize the learned weights as histograms written to out_folder.
    model.quantify_and_draw_w_weight_histogram(out_folder)
Code Example #4
def test(param, test_ply_fn, start_id, end_id, out_ply_folder,
         out_render_gt_folder, out_render_out_folder):

    print("**********Initiate Network**********")
    model = graphAE.Model(param)
    model.cuda()

    # Restore pretrained weights if a checkpoint path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    template_plydata = PlyData.read(param.template_ply_fn)
    faces = get_faces_from_ply(template_plydata)

    # Set up an offscreen OpenGL context for rendering ground-truth and output meshes.
    from renderer.render.gl.glcontext import create_opengl_context
    from renderer.render.gl.glcontext import destroy_opengl_context
    RESOLUTION = 1024

    glutWindow = create_opengl_context(width=RESOLUTION, height=RESOLUTION)
    render = Renderer(width=RESOLUTION, height=RESOLUTION)

    # Evaluate every existing ply file named <test_ply_fn><idx>.ply for idx in [start_id, end_id).
    MSE_sum = 0
    MSE_num = 0
    for i in range(start_id, end_id):
        ply_fn = test_ply_fn + "%06d" % i + ".ply"
        if not os.path.exists(ply_fn):
            continue
        pc, mean = get_pc_from_ply_fn(ply_fn)

        MSE = evaluate(param, model, render, "%06d" % i, pc, mean,
                       template_plydata, faces, out_ply_folder,
                       out_render_gt_folder, out_render_out_folder)
        MSE_sum += MSE
        MSE_num += 1
    print("mean MSE:", MSE_sum / MSE_num)
    destroy_opengl_context(glutWindow)
Code Example #5
def train(param):
    torch.manual_seed(0)
    np.random.seed(0)

    print("**********Initiate Network**********")
    model = graphAE.Model(param)
    model.cuda()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=param.lr, weight_decay=param.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, param.lr_decay_epoch_step, gamma=param.lr_decay)

    # Resume from a checkpoint (model and optimizer state) if a path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    model.compute_param_num()
    model.train()

    print("**********Get training ply fn list from**********", param.pcs_train)
    # Load the training and evaluation meshes from their .npy files.
    pc_lst_train = np.load(param.pcs_train)
    param.iter_per_epoch = int(len(pc_lst_train) / param.batch)
    param.end_iter = param.iter_per_epoch * param.epoch
    print("**********Get evaluating ply fn list from**********", param.pcs_evaluate)
    pc_lst_evaluate = np.load(param.pcs_evaluate)
    #print ("**********Get test ply fn list from**********", param.pcs_evaluate)
    #pc_lst_test = np.load(param.pcs_test)

    np.random.shuffle(pc_lst_train)
    np.random.shuffle(pc_lst_evaluate)

    # Centralize each mesh by subtracting its per-instance mean vertex position.
    pc_lst_train[:, :, 0:3] -= pc_lst_train[:, :, 0:3].mean(1).reshape((-1, 1, 3)).repeat(param.point_num, 1)
    pc_lst_evaluate[:, :, 0:3] -= pc_lst_evaluate[:, :, 0:3].mean(1).reshape((-1, 1, 3)).repeat(param.point_num, 1)
    #pc_lst_test[:,:,0:3] -= pc_lst_test[:,:,0:3].mean(1).reshape((-1,1,3)).repeat(param.point_num, 1)

    template_plydata = PlyData.read(param.template_ply_fn)

    print("**********Start Training**********")

    min_geo_error = 123456
    for i in range(param.start_epoch, param.epoch + 1):

        # Periodically evaluate, and save the weights whenever the geometric error improves.
        if ((i % param.evaluate_epoch == 0) and (i != 0)) or (i == param.epoch):
            print("###Evaluate", "epoch", i, "##########################")
            with torch.no_grad():
                torch.manual_seed(0)
                np.random.seed(0)
                geo_error = evaluate(param, model, pc_lst_evaluate, i, template_plydata, suffix="eval")
                if geo_error < min_geo_error:
                    min_geo_error = geo_error
                    print("###Save Weight")
                    path = param.write_weight_folder + "model_epoch%04d" % i + ".weight"
                    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, path)

        # Re-seed per epoch for reproducibility, then train one epoch and step the LR schedule.
        torch.manual_seed(i)
        np.random.seed(i)

        for j in range(param.iter_per_epoch):
            train_one_iteration(param, model, optimizer, pc_lst_train, i, j)

        scheduler.step()
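The training loop above reads its settings from the param object. As a rough orientation, here is a hypothetical stand-in (a plain SimpleNamespace) listing exactly the attributes this snippet uses; in the real project these come from the graphAE parameter/config module, and every value and path below is only a placeholder, not the project's defaults.

from types import SimpleNamespace

# Hypothetical configuration: attribute names are those read by train() above,
# but all values and paths are illustrative placeholders.
param = SimpleNamespace(
    lr=1e-3, weight_decay=0.0,                # Adam settings
    lr_decay_epoch_step=50, lr_decay=0.5,     # StepLR schedule
    read_weight_path="",                      # "" = train from scratch
    pcs_train="data/train.npy",               # training meshes (.npy)
    pcs_evaluate="data/eval.npy",             # evaluation meshes (.npy)
    batch=16, point_num=6890,                 # batch size / vertices per mesh
    start_epoch=0, epoch=200, evaluate_epoch=10,
    template_ply_fn="data/template.ply",
    write_weight_folder="weights/",
)

# train(param)                                # would run the loop above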
Code Example #6
def test(param, test_npy_fn, out_ply_folder, skip_frames=0):

    print("**********Initiate Network**********")
    model = graphAE.Model(param, test_mode=True)
    model.cuda()

    # Restore pretrained weights if a checkpoint path was given.
    if param.read_weight_path != "":
        print("load " + param.read_weight_path)
        checkpoint = torch.load(param.read_weight_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.init_test_mode()

    model.eval()

    template_plydata = PlyData.read(param.template_ply_fn)
    faces = get_faces_from_ply(template_plydata)

    pose_sum = 0
    laplace_sum = 0
    test_num = 0

    print("**********Get test pcs**********", test_npy_fn)
    # Load the test meshes from the .npy file (one point cloud per mesh).
    pc_lst = np.load(test_npy_fn)
    print(pc_lst.shape[0], "meshes in total.")

    geo_error_sum = 0
    laplace_error_sum = 0
    pc_num = len(pc_lst)
    n = 0

    while n < (pc_num - 1):

        batch = min(pc_num - n, param.batch)
        pcs = pc_lst[n:n + batch]
        height = pcs[:, :, 1].mean(1)
        pcs[:, :, 0:3] -= pcs[:, :, 0:3].mean(1).reshape((-1, 1, 3)).repeat(param.point_num, 1)  ##centralize each instance

        pcs_torch = torch.FloatTensor(pcs).cuda()
        if param.augmented_data:
            pcs_torch = Dataloader.get_augmented_pcs(pcs_torch)
        # Pad the last, smaller batch with zeros so the network always receives param.batch instances.
        if batch < param.batch:
            pcs_torch = torch.cat((pcs_torch, torch.zeros(param.batch - batch, param.point_num, 3).cuda()), 0)

        # Reconstruct the meshes and accumulate errors weighted by the true batch size.
        out_pcs_torch = model(pcs_torch)
        geo_error = model.compute_geometric_mean_euclidean_dist_error(pcs_torch[0:batch], out_pcs_torch[0:batch])
        geo_error_sum += geo_error * batch
        laplace_error_sum += model.compute_laplace_Mean_Euclidean_Error(pcs_torch[0:batch], out_pcs_torch[0:batch]) * batch
        print(n, geo_error.item())

        # Every 128 meshes, save one ground-truth/output pair as ply files,
        # restoring the original height and coloring the output by per-vertex error.
        if n % 128 == 0:
            print(height[0])
            pc_gt = np.array(pcs_torch[0].data.tolist())
            pc_gt[:, 1] += height[0]
            pc_out = np.array(out_pcs_torch[0].data.tolist())
            pc_out[:, 1] += height[0]

            diff_pc = np.sqrt(pow(pc_gt - pc_out, 2).sum(1))
            color = get_colors_from_diff_pc(diff_pc, 0, 0.02) * 255
            Dataloader.save_pc_with_color_into_ply(template_plydata, pc_out, color, out_ply_folder + "%08d" % n + "_out.ply")
            Dataloader.save_pc_into_ply(template_plydata, pc_gt, out_ply_folder + "%08d" % n + "_gt.ply")

        n = n + batch

    geo_error_avg = geo_error_sum.item() / pc_num
    laplace_error_avg = laplace_error_sum.item() / pc_num

    print("geo error:", geo_error_avg, "laplace error:", laplace_error_avg)