示例#1
0
def run_mst_generate(args):
    """
    Generate skeletons in batch for every model id listed in
    ``test_final.txt`` inside the dataset folder.

    For each model the joints are predicted, the pairwise connection
    probabilities are converted into a cost matrix, a spanning tree is
    extracted with ``primMST_symmetry`` and the resulting skeleton is written
    to ``args.res_folder`` as ``<id>_skel.jpg`` and ``<id>_skel.txt``.

    :param args: object providing ``dataset_folder``, ``res_folder``,
                 ``rootnet`` (root network checkpoint path) and ``bonenet``
                 (connectivity network checkpoint path)
    :raises ValueError: if the spanning tree returned for a model has no
                        root entry (no parent equal to -1)
    """
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the dtype
    # it used to alias. ``atleast_1d`` keeps the loop below working when the
    # list file holds a single id (``loadtxt`` then returns a 0-d array).
    test_list = np.atleast_1d(
        np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'),
                   dtype=int))

    # ``map_location`` lets checkpoints saved on a GPU be restored on
    # whatever ``device`` this process is actually using.
    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    root_checkpoint = torch.load(args.rootnet, map_location=device)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])

    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location=device)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder,
                                     'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        with torch.no_grad():
            # Call the module itself rather than ``forward`` so that any
            # registered hooks are honoured.
            pair_logits, _ = connectivity_model(data)
            connect_prob = torch.sigmoid(pair_logits)

        # Scatter the per-pair probabilities into a dense joint x joint
        # matrix; adding the transpose makes it symmetric.
        pair_idx = data.pairs.long().data.cpu().numpy()
        num_joints = data.num_joint[0]
        cost_matrix = np.zeros((num_joints, num_joints))
        cost_matrix[pair_idx[:, 0], pair_idx[:, 1]] = \
            connect_prob.data.cpu().numpy().squeeze()
        cost_matrix = cost_matrix + cost_matrix.transpose()
        # Probability -> cost; the epsilon avoids log(0) for empty entries.
        cost_matrix = -np.log(cost_matrix + 1e-10)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints,
                                                     vox)

        skel = Skel()
        parent, key, root_id = primMST_symmetry(cost_matrix, root_id,
                                                pred_joints)

        # The root is the first joint whose parent index is -1. Look it up
        # explicitly instead of relying on a leaked loop variable.
        root_idx = next(
            (idx for idx, par in enumerate(parent) if par == -1), None)
        if root_idx is None:
            raise ValueError(
                'no root joint found in spanning tree for model '
                '{}'.format(model_id))
        skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
        loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)

        img = show_obj_skel(mesh_filename, skel.root)
        # The channel axis is reversed before writing (presumably RGB -> BGR
        # for OpenCV; confirm against ``show_obj_skel``).
        cv2.imwrite(
            os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)),
            img[:, :, ::-1])
        skel.save(
            os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
示例#2
0
    jointNet.eval()
    jointNet_checkpoint = torch.load(
        'checkpoints/gcn_meanshift/model_best.pth.tar')
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print("     joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar')
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print("     root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar')
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print("     connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar')
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print("     skinning prediction network loaded.")

    # Here we provide 16~17 examples. For best results, we will need to override the learned bandwidth and its associated threshold
    # To process other input characters, please first try the learned bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
    # We also use these two default parameters for processing all test models in batch.
示例#3
0
def runApp(model_id, bandwidth, threshold):
    """
    Run the full rig prediction (joints, connectivity, skinning) for one
    model in the ``quick_start/`` folder and save the resulting rig.

    The rig is transferred back to the original mesh tessellation and written
    to ``quick_start/<model_id>_ori_rig.txt``.

    :param model_id: model identifier as a string (it is formatted with
                     ``{:s}`` to build the mesh file names)
    :param bandwidth: bandwidth forwarded to ``predict_joints``
    :param threshold: threshold forwarded to ``predict_joints``
    """
    global device, data

    input_folder = "quick_start/"

    if platform in ("linux", "linux2"):
        print("I am linux")
        os.system("echo Hello from the other side!")
        os.system("Xvfb :99 -screen 0 640x480x24 &")
        # ``os.system("export DISPLAY=:99")`` only changes the environment of
        # a short-lived subshell; set the variable on this process instead so
        # that it (and its children) actually use the virtual display.
        os.environ["DISPLAY"] = ":99"

    # downsample_skinning is used to speed up the calculation of volumetric
    # geodesic distance and to save cpu memory in skinning calculation.
    # Change to False to be more accurate but less efficient.
    downsample_skinning = True

    # load all weights
    print("loading all networks...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _load_weights(net, checkpoint_path):
        # Restore ``state_dict`` from the checkpoint (loaded on CPU so it
        # works without CUDA), move the network to ``device`` and switch it
        # to evaluation mode.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        net.to(device)
        net.eval()
        return net

    jointNet = _load_weights(
        JOINTNET(), 'checkpoints/gcn_meanshift/model_best.pth.tar')
    print("     joint prediction network loaded.")

    rootNet = _load_weights(
        ROOTNET(), 'checkpoints/rootnet/model_best.pth.tar')
    print("     root prediction network loaded.")

    boneNet = _load_weights(
        BONENET(), 'checkpoints/bonenet/model_best.pth.tar')
    print("     connection prediction network loaded.")

    skinNet = _load_weights(
        SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True),
        'checkpoints/skinnet/model_best.pth.tar')
    print("     skinning prediction network loaded.")

    # Here we provide 16~17 examples. For best results, we will need to
    # override the learned bandwidth and its associated threshold.
    # To process other input characters, please first try the learned
    # bandwidth (0.0429 in the provided model), and the default threshold 1e-5.
    # We also use these two default parameters for processing all test models
    # in batch.
    # Example call: model_id, bandwidth, threshold = "1", 0.045, 0.75e-5

    # create data used for inference
    print("creating data for model ID {:s}".format(model_id))
    mesh_filename = os.path.join(input_folder,
                                 '{:s}_remesh.obj'.format(model_id))
    normalized_mesh_filename = mesh_filename.replace("_remesh.obj",
                                                     "_normalized.obj")

    data, vox, surface_geodesic, translation_normalize, scale_normalize = \
        create_single_data(mesh_filename)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
                          mesh_filename=normalized_mesh_filename)
    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
                                     mesh_filename=normalized_mesh_filename)
    print("predicting skinning")
    pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
                                normalized_mesh_filename,
                                subsampling=downsample_skinning)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    print("Saving result")
    # Set to False to save the rig for the remeshed mesh instead of
    # transferring it back to the original tessellation.
    use_original_mesh = True
    if use_original_mesh:
        # here we use original mesh tessellation (without remeshing)
        mesh_filename_ori = os.path.join(input_folder,
                                         '{:s}_ori.obj'.format(model_id))
        pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename,
                                       pred_rig)
        pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
    else:
        # here we use remeshed mesh
        pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
    print("Done!")
示例#4
0
def predict_rig(mesh_obj,
                bandwidth,
                threshold,
                downsample_skinning=True,
                decimation=3000,
                sampling=1500):
    """
    Predict a rig (joints, connectivity, skinning) for a Blender mesh object
    and generate the corresponding armature.

    The network checkpoints are read from the ``model_path`` stored in this
    add-on's preferences. The mesh's existing vertex groups are cleared and
    all objects are deselected before the armature is generated.

    :param mesh_obj: mesh object handed to ``create_single_data`` and to
                     ``ArmatureGenerator``
    :param bandwidth: bandwidth forwarded to ``predict_joints``
    :param threshold: threshold forwarded to ``predict_joints``
    :param downsample_skinning: forwarded as ``subsampling`` to
                                ``predict_skinning``; it speeds up the
                                volumetric geodesic distance calculation and
                                saves cpu memory. Use False to be more
                                accurate but less efficient.
    :param decimation: forwarded to ``predict_skinning``
    :param sampling: forwarded to ``predict_skinning``
    """
    print("predicting rig")

    # load all weights
    print("loading all networks...")

    model_dir = bpy.context.preferences.addons[
        __package__].preferences.model_path

    def _load_weights(net, relative_path):
        # Restore ``state_dict`` from a checkpoint below ``model_dir``.
        # ``map_location='cpu'`` lets checkpoints saved on a GPU be loaded on
        # machines without CUDA; the network is moved to ``device`` afterwards.
        checkpoint = torch.load(os.path.join(model_dir, relative_path),
                                map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        net.to(device)
        net.eval()
        return net

    jointNet = _load_weights(JOINTNET(), 'gcn_meanshift/model_best.pth.tar')
    print("     joint prediction network loaded.")

    rootNet = _load_weights(ROOTNET(), 'rootnet/model_best.pth.tar')
    print("     root prediction network loaded.")

    boneNet = _load_weights(BONENET(), 'bonenet/model_best.pth.tar')
    print("     connection prediction network loaded.")

    skinNet = _load_weights(
        SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True),
        'skinnet/model_best.pth.tar')
    print("     skinning prediction network loaded.")

    data, vox, surface_geodesic, translation_normalize, scale_normalize = \
        create_single_data(mesh_obj)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth)

    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet)

    print("predicting skinning")
    pred_rig = predict_skinning(data,
                                pred_skeleton,
                                skinNet,
                                surface_geodesic,
                                subsampling=downsample_skinning,
                                decimation=decimation,
                                sampling=sampling)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    # Drop existing vertex groups and clear the selection before the
    # armature is generated.
    mesh_obj.vertex_groups.clear()

    for obj in bpy.data.objects:
        obj.select_set(False)

    ArmatureGenerator(pred_rig, mesh_obj).generate()
    # Release cached GPU memory held by PyTorch's allocator.
    torch.cuda.empty_cache()