Example #1
def run_mst_generate(args):
    """
    generate skeleton in batch
    :param args: input folder path and data folder path
    """
    test_list = np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'),
                           dtype=int)
    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    root_checkpoint = torch.load(args.rootnet)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])
    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder,
                                     'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)
        with torch.no_grad():
            cost_matrix, _ = connectivity_model(data)
            connect_prob = torch.sigmoid(cost_matrix)
        # scatter pairwise connection probabilities into a dense matrix,
        # symmetrize it, and turn probabilities into negative-log edge costs
        pair_idx = data.pairs.long().data.cpu().numpy()
        cost_matrix = np.zeros((data.num_joint[0], data.num_joint[0]))
        cost_matrix[pair_idx[:, 0], pair_idx[:, 1]] = \
            connect_prob.data.cpu().numpy().squeeze()
        cost_matrix = cost_matrix + cost_matrix.transpose()
        cost_matrix = -np.log(cost_matrix + 1e-10)
        #cost_matrix = flip_cost_matrix(pred_joints, cost_matrix)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints,
                                                     vox)

        skel = Skel()
        parent, key, root_id = primMST_symmetry(cost_matrix, root_id,
                                                pred_joints)
        # the joint with no parent (-1) becomes the root of the skeleton tree
        for i in range(len(parent)):
            if parent[i] == -1:
                skel.root = TreeNode('root', tuple(pred_joints[i]))
                break
        loadSkel_recur(skel.root, i, None, pred_joints, parent)
        img = show_obj_skel(mesh_filename, skel.root)
        # cv2 expects BGR channel order, so reverse the RGB channels on write
        cv2.imwrite(
            os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)),
            img[:, :, ::-1])
        skel.save(
            os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
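For context, a minimal driver for run_mst_generate might look like the sketch below. The argument names mirror the attributes the function accesses (dataset_folder, rootnet, bonenet, res_folder); it is an illustration only, not the repository's actual entry point, and predict_joints may consume further attributes.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='batch skeleton generation')
    parser.add_argument('--dataset_folder', required=True,
                        help='folder with test_final.txt and obj_remesh/')
    parser.add_argument('--rootnet', required=True,
                        help='path to the trained ROOTNET checkpoint')
    parser.add_argument('--bonenet', required=True,
                        help='path to the trained PairCls (bone) checkpoint')
    parser.add_argument('--res_folder', required=True,
                        help='output folder for *_skel.jpg and *_skel.txt')
    run_mst_generate(parser.parse_args())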
Example #2
    downsample_skinning = True

    # load all weights
    print("loading all networks...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    jointNet = JOINTNET()
    jointNet.to(device)
    jointNet.eval()
    jointNet_checkpoint = torch.load(
        'checkpoints/gcn_meanshift/model_best.pth.tar', map_location='cpu')
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print("     joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar',
                                    map_location='cpu')
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print("     root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar',
                                    map_location='cpu')
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print("     connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar',
                                    map_location='cpu')
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print("     skinning prediction network loaded.")
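The four blocks above repeat one load pattern; a small helper like this sketch (the name load_network is hypothetical) would factor it out:

def load_network(model, checkpoint_path, device):
    # load weights saved under a 'state_dict' key onto CPU first,
    # then move the model to the target device and freeze it for inference
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model

# e.g.: jointNet = load_network(JOINTNET(),
#                               'checkpoints/gcn_meanshift/model_best.pth.tar',
#                               device)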
Example #3
def main(args):
    global device
    best_acc = 0.0

    # create checkpoint dir and log dir
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
    mkdir_p(args.checkpoint)
    if not args.resume:
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # create model
    model = ROOTNET()

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    train_loader = DataLoader(GraphDataset(root=args.train_folder),
                              batch_size=args.train_batch,
                              shuffle=True,
                              follow_batch=['joints'])
    val_loader = DataLoader(GraphDataset(root=args.val_folder),
                            batch_size=args.test_batch,
                            shuffle=False,
                            follow_batch=['joints'])
    test_loader = DataLoader(GraphDataset(root=args.test_folder),
                             batch_size=args.test_batch,
                             shuffle=False,
                             follow_batch=['joints'])
    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(test_loader, model)
        print('test_loss {:.8f}. test_acc: {:.6f}'.format(test_loss, test_acc))
        return
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.schedule,
                                                     gamma=args.gamma)
    logger = SummaryWriter(log_dir=args.logdir)
    for epoch in range(args.start_epoch, args.epochs):
        lr = scheduler.get_last_lr()
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
        train_loss = train(train_loader, model, optimizer, args)
        val_loss, val_acc = test(val_loader, model)
        test_loss, test_acc = test(test_loader, model)
        scheduler.step()
        print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
        print('Epoch{:d}. val_loss: {:.6f}. val_acc: {:.6f}'.format(
            epoch + 1, val_loss, val_acc))
        print('Epoch{:d}. test_loss: {:.6f}. test_acc: {:.6f}'.format(
            epoch + 1, test_loss, test_acc))

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            },
            is_best,
            checkpoint=args.checkpoint)

        info = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'test_loss': test_loss,
            'test_accuracy': test_acc
        }
        for tag, value in info.items():
            logger.add_scalar(tag, value, epoch + 1)
    print("=> loading checkpoint '{}'".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar')))
    checkpoint = torch.load(os.path.join(args.checkpoint,
                                         'model_best.pth.tar'))
    best_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar'), best_epoch))
    test_loss, test_acc = test(test_loader, model)
    print('Best epoch:\n test_loss {:8f} test_acc {:8f}'.format(
        test_loss, test_acc))
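save_checkpoint is not shown in this example. A common implementation consistent with the 'model_best.pth.tar' file that main loads afterwards is sketched below; the filename argument is an assumption:

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint='checkpoint',
                    filename='checkpoint.pth.tar'):
    # always write the latest state; additionally copy it to
    # model_best.pth.tar whenever validation accuracy improved
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))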
Example #4
def runApp(model_id, bandwidth, threshold):
    global device, data

    input_folder = "quick_start/"

    if platform == "linux" or platform == "linux2":
        # headless Linux: start a virtual framebuffer for off-screen rendering
        os.system("Xvfb :99 -screen 0 640x480x24 &")
        # os.system("export DISPLAY=:99") would only affect a child shell,
        # so set the variable in this process instead
        os.environ["DISPLAY"] = ":99"

    # downsample_skinning is used to speed up the calculation of volumetric geodesic distance
    # and to save cpu memory in skinning calculation.
    # Change to False to be more accurate but less efficient.
    downsample_skinning = True

    # load all weights
    print("loading all networks...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    jointNet = JOINTNET()
    jointNet.to(device)
    jointNet.eval()
    jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar', map_location='cpu')
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print("     joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar', map_location='cpu')
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print("     root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar', map_location='cpu')
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print("     connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar', map_location='cpu')
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print("     skinning prediction network loaded.")

    # We provide 16~17 example characters. For best results on some of them,
    # the learned bandwidth and its associated threshold need to be overridden.
    # To process other input characters, first try the learned bandwidth
    # (0.0429 in the provided model) with the default threshold 1e-5; these two
    # defaults are also used when processing all test models in batch.

    # model_id, bandwidth, threshold = "1", 0.045, 0.75e-5

    # create data used for inference
    print("creating data for model ID {:s}".format(model_id))
    mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))

    data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
                          mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
                                     mesh_filename=mesh_filename.replace("_remesh.obj", "_normalized.obj"))
    print("predicting skinning")
    pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
                                mesh_filename.replace("_remesh.obj", "_normalized.obj"),
                                subsampling=downsample_skinning)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    print("Saving result")
    if True:
        # here we use original mesh tesselation (without remeshing)
        mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
        pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
        pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
    else:
        # here we use remeshed mesh
        pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
    print("Done!")
Example #5
def predict_rig(mesh_obj,
                bandwidth,
                threshold,
                downsample_skinning=True,
                decimation=3000,
                sampling=1500):
    print("predicting rig")
    # downsample_skinning is used to speed up the calculation of volumetric geodesic distance
    # and to save cpu memory in skinning calculation.
    # Change to False to be more accurate but less efficient.

    # load all weights
    print("loading all networks...")

    # directory with the trained weights, configured in the add-on preferences
    model_dir = bpy.context.preferences.addons[
        __package__].preferences.model_path

    jointNet = JOINTNET()
    jointNet.to(device)
    jointNet.eval()
    jointNet_checkpoint = torch.load(
        os.path.join(model_dir, 'gcn_meanshift/model_best.pth.tar'))
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print("     joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load(
        os.path.join(model_dir, 'rootnet/model_best.pth.tar'))
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print("     root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load(
        os.path.join(model_dir, 'bonenet/model_best.pth.tar'))
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print("     connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load(
        os.path.join(model_dir, 'skinnet/model_best.pth.tar'))
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print("     skinning prediction network loaded.")

    data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(
        mesh_obj)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth)

    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet)
    # pred_skeleton.normalize(scale_normalize, -translation_normalize)

    print("predicting skinning")
    pred_rig = predict_skinning(data,
                                pred_skeleton,
                                skinNet,
                                surface_geodesic,
                                subsampling=downsample_skinning,
                                decimation=decimation,
                                sampling=sampling)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    # remove any existing vertex groups before the new skin weights are bound
    mesh_obj.vertex_groups.clear()

    # deselect everything so the armature is generated from a clean selection
    for obj in bpy.data.objects:
        obj.select_set(False)

    ArmatureGenerator(pred_rig, mesh_obj).generate()
    torch.cuda.empty_cache()
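Within a Blender add-on, predict_rig would typically be exposed through an operator; a minimal sketch (the idname, label, and parameter values are assumptions) could look like this:

class OBJECT_OT_predict_rig(bpy.types.Operator):
    """Hypothetical operator that rigs the active mesh object."""
    bl_idname = "object.predict_rig"
    bl_label = "Predict Rig"

    def execute(self, context):
        mesh_obj = context.active_object
        if mesh_obj is None or mesh_obj.type != 'MESH':
            self.report({'ERROR'}, "Select a mesh object first")
            return {'CANCELLED'}
        # 0.0429 / 1e-5 are the defaults quoted in Example #4
        predict_rig(mesh_obj, bandwidth=0.0429, threshold=1e-5)
        return {'FINISHED'}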