def run_mst_generate(args):
    """Generate skeletons in batch for every model listed in the test split.

    For each model id read from ``<dataset_folder>/test_final.txt`` this
    predicts joints, scores candidate joint pairs with the connectivity
    network, extracts a skeleton with ``primMST_symmetry`` and writes a
    rendered preview (``<id>_skel.jpg``) plus the skeleton text file
    (``<id>_skel.txt``) into ``args.res_folder``.

    :param args: parsed arguments; this function reads ``dataset_folder``,
        ``rootnet`` (root-net checkpoint path), ``bonenet`` (connectivity
        checkpoint path) and ``res_folder``, and passes ``args`` on to
        ``predict_joints``.
    :raises RuntimeError: if the MST produced no root joint.
    """
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement. ``atleast_1d`` keeps a single-id file iterable
    # (loadtxt returns a 0-d array in that case).
    test_list = np.atleast_1d(
        np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'), dtype=int))

    # map_location='cpu' lets CUDA-saved checkpoints load on CPU-only hosts;
    # load_state_dict copies the weights onto the model's device.
    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    root_checkpoint = torch.load(args.rootnet, map_location='cpu')
    root_select_model.load_state_dict(root_checkpoint['state_dict'])

    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location='cpu')
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        model_id = int(model_id)
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        with torch.no_grad():
            pair_logits, _ = connectivity_model.forward(data)
            connect_prob = torch.sigmoid(pair_logits)

        # Scatter per-pair probabilities into a dense joint x joint matrix.
        pair_idx = data.pairs.long().data.cpu().numpy()
        prob_matrix = np.zeros((data.num_joint[0], data.num_joint[0]))
        prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
        # Symmetrize (assumes each pair is listed in one direction only -- TODO confirm).
        prob_matrix = prob_matrix + prob_matrix.transpose()
        # Convert probabilities to MST edge costs; the epsilon avoids log(0).
        cost_matrix = -np.log(prob_matrix + 1e-10)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

        skel = Skel()
        parent, _, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
        # The root is the joint whose parent entry is -1.
        root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
        if root_idx is None:
            raise RuntimeError('MST for model {:d} has no root joint'.format(model_id))
        skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
        loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)

        img = show_obj_skel(mesh_filename, skel.root)
        # Reverse the channel order before handing the image to cv2.
        cv2.imwrite(
            os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)),
            img[:, :, ::-1])
        skel.save(
            os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
# Load the four pretrained networks (joint, root, bone-connectivity, skinning)
# and put them in eval mode. map_location='cpu' lets checkpoints saved on a
# GPU load on CPU-only machines; load_state_dict then copies the weights onto
# whatever device each model lives on.
print("loading all networks...")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

jointNet = JOINTNET()
jointNet.to(device)
jointNet.eval()
jointNet_checkpoint = torch.load(
    'checkpoints/gcn_meanshift/model_best.pth.tar', map_location='cpu')
jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
print(" joint prediction network loaded.")

rootNet = ROOTNET()
rootNet.to(device)
rootNet.eval()
rootNet_checkpoint = torch.load(
    'checkpoints/rootnet/model_best.pth.tar', map_location='cpu')
rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
print(" root prediction network loaded.")

boneNet = BONENET()
boneNet.to(device)
boneNet.eval()
boneNet_checkpoint = torch.load(
    'checkpoints/bonenet/model_best.pth.tar', map_location='cpu')
boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
print(" connection prediction network loaded.")

skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
skinNet_checkpoint = torch.load(
    'checkpoints/skinnet/model_best.pth.tar', map_location='cpu')
skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
skinNet.to(device)
skinNet.eval()
print(" skinning prediction network loaded.")
def main(args):
    """Train ROOTNET, tracking the best validation accuracy.

    Each epoch runs ``train`` on the training set and ``test`` on the
    validation and test sets. It steps a MultiStepLR schedule, calls
    ``save_checkpoint`` with an ``is_best`` flag based on validation accuracy,
    and logs scalars to TensorBoard. After training, ``model_best.pth.tar`` is
    reloaded from ``args.checkpoint`` and evaluated on the test set. It is
    presumably written by ``save_checkpoint`` when ``is_best``; verify against
    that helper.

    :param args: parsed arguments (checkpoint, logdir, resume, lr,
        weight_decay, train/val/test folders, batch sizes, evaluate,
        schedule, gamma, start_epoch, epochs).
    """
    best_acc = 0.0
    resumed = False

    # create checkpoint dir and log dir
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
        mkdir_p(args.checkpoint)
    if not args.resume:
        # A fresh run starts with an empty TensorBoard log directory.
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # create model
    model = ROOTNET()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a GPU-saved checkpoint resume on any device.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            resumed = True
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_loader = DataLoader(GraphDataset(root=args.train_folder), batch_size=args.train_batch,
                              shuffle=True, follow_batch=['joints'])
    val_loader = DataLoader(GraphDataset(root=args.val_folder), batch_size=args.test_batch,
                            shuffle=False, follow_batch=['joints'])
    test_loader = DataLoader(GraphDataset(root=args.test_folder), batch_size=args.test_batch,
                             shuffle=False, follow_batch=['joints'])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(test_loader, model)
        print('test_loss {:.8f}. test_acc: {:.6f}'.format(test_loss, test_acc))
        return

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
    if resumed:
        # The restored optimizer already carries the decayed LR. Align the
        # scheduler's epoch counter with the resumed epoch so milestones stay
        # absolute epoch numbers instead of restarting from zero.
        scheduler.last_epoch = args.start_epoch

    logger = SummaryWriter(log_dir=args.logdir)
    try:
        for epoch in range(args.start_epoch, args.epochs):
            lr = scheduler.get_last_lr()
            print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
            train_loss = train(train_loader, model, optimizer, args)
            val_loss, val_acc = test(val_loader, model)
            test_loss, test_acc = test(test_loader, model)
            scheduler.step()
            print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
            print('Epoch{:d}. val_loss: {:.6f}. val_acc: {:.6f}'.format(
                epoch + 1, val_loss, val_acc))
            print('Epoch{:d}. test_loss: {:.6f}. test_acc: {:.6f}'.format(
                epoch + 1, test_loss, test_acc))

            # remember best acc and save checkpoint
            is_best = val_acc > best_acc
            best_acc = max(val_acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict()
                },
                is_best,
                checkpoint=args.checkpoint)

            info = {
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_accuracy': val_acc,
                'test_loss': test_loss,
                'test_accuracy': test_acc
            }
            for tag, value in info.items():
                logger.add_scalar(tag, value, epoch + 1)
    finally:
        # Flush pending events and release the event file even on failure.
        logger.close()

    best_path = os.path.join(args.checkpoint, 'model_best.pth.tar')
    print("=> loading checkpoint '{}'".format(best_path))
    checkpoint = torch.load(best_path, map_location=device)
    best_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(best_path, best_epoch))
    test_loss, test_acc = test(test_loader, model)
    print('Best epoch:\n test_loss {:8f} test_acc {:8f}'.format(
        test_loss, test_acc))
def runApp(model_id, bandwidth, threshold, input_folder="quick_start/",
           downsample_skinning=True, use_ori_mesh=True):
    """Predict joints, skeleton and skinning for one model and save the rig.

    Reads ``<input_folder>/<model_id>_remesh.obj`` and writes
    ``<model_id>_ori_rig.txt``, or ``<model_id>_remesh_rig.txt`` when
    ``use_ori_mesh`` is False.

    For best results on the provided examples, the learned bandwidth and its
    threshold may need overriding. For other characters, first try the
    learned bandwidth (0.0429 in the provided model) with threshold 1e-5.
    Those defaults are also used for batch processing of the test models.
    Example: ``model_id, bandwidth, threshold = "1", 0.045, 0.75e-5``.

    :param model_id: string id used to build the input/output file names.
    :param bandwidth: bandwidth passed to ``predict_joints``.
    :param threshold: threshold passed to ``predict_joints``.
    :param input_folder: folder holding the ``*_remesh.obj`` /
        ``*_normalized.obj`` / ``*_ori.obj`` inputs.
    :param downsample_skinning: speeds up the volumetric-geodesic computation
        and saves CPU memory during skinning; False is more accurate but
        slower.
    :param use_ori_mesh: transfer the rig onto the original (non-remeshed)
        mesh tessellation before saving.
    """
    global device, data
    if platform in ("linux", "linux2"):
        print("I am linux")
        os.system("echo Hello from the other side!")
        # Start a virtual X server on display :99 (presumably for off-screen
        # rendering during prediction -- confirm).
        os.system("Xvfb :99 -screen 0 640x480x24 &")
        # An `export` run through os.system only affects a throwaway subshell.
        # Setting it here makes this process and its children see the display.
        os.environ["DISPLAY"] = ":99"

    # load all weights
    print("loading all networks...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    jointNet = JOINTNET()
    jointNet.to(device)
    jointNet.eval()
    jointNet_checkpoint = torch.load('checkpoints/gcn_meanshift/model_best.pth.tar', map_location='cpu')
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print(" joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load('checkpoints/rootnet/model_best.pth.tar', map_location='cpu')
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print(" root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load('checkpoints/bonenet/model_best.pth.tar', map_location='cpu')
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print(" connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load('checkpoints/skinnet/model_best.pth.tar', map_location='cpu')
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print(" skinning prediction network loaded.")

    # create data used for inference
    print("creating data for model ID {:s}".format(model_id))
    mesh_filename = os.path.join(input_folder, '{:s}_remesh.obj'.format(model_id))
    normalized_filename = mesh_filename.replace("_remesh.obj", "_normalized.obj")
    data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_filename)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth,
                          mesh_filename=normalized_filename)
    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet,
                                     mesh_filename=normalized_filename)
    print("predicting skinning")
    pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
                                normalized_filename, subsampling=downsample_skinning)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    print("Saving result")
    if use_ori_mesh:
        # here we use original mesh tessellation (without remeshing)
        mesh_filename_ori = os.path.join(input_folder, '{:s}_ori.obj'.format(model_id))
        pred_rig = tranfer_to_ori_mesh(mesh_filename_ori, mesh_filename, pred_rig)
        pred_rig.save(mesh_filename_ori.replace('.obj', '_rig.txt'))
    else:
        # here we use remeshed mesh
        pred_rig.save(mesh_filename.replace('.obj', '_rig.txt'))
    print("Done!")
def predict_rig(mesh_obj, bandwidth, threshold, downsample_skinning=True,
                decimation=3000, sampling=1500):
    """Predict a rig for a Blender mesh object and build an armature for it.

    Loads the four networks from the add-on's ``model_path`` preference,
    predicts joints, skeleton and skinning, then undoes the normalization
    applied by ``create_single_data``. It clears the mesh's vertex groups,
    deselects every object in ``bpy.data.objects`` and runs
    ``ArmatureGenerator(pred_rig, mesh_obj).generate()``.

    :param mesh_obj: Blender mesh object to rig.
    :param bandwidth: bandwidth passed to ``predict_joints``.
    :param threshold: threshold passed to ``predict_joints``.
    :param downsample_skinning: speeds up the volumetric-geodesic computation
        and saves CPU memory during skinning; False is more accurate but
        slower.
    :param decimation: forwarded to ``predict_skinning``.
    :param sampling: forwarded to ``predict_skinning``.
    """
    print("predicting rig")

    # load all weights
    print("loading all networks...")
    model_dir = bpy.context.preferences.addons[__package__].preferences.model_path

    # map_location='cpu' lets GPU-saved checkpoints load where CUDA is
    # unavailable; load_state_dict copies the weights onto each model's device.
    jointNet = JOINTNET()
    jointNet.to(device)
    jointNet.eval()
    jointNet_checkpoint = torch.load(
        os.path.join(model_dir, 'gcn_meanshift/model_best.pth.tar'), map_location='cpu')
    jointNet.load_state_dict(jointNet_checkpoint['state_dict'])
    print(" joint prediction network loaded.")

    rootNet = ROOTNET()
    rootNet.to(device)
    rootNet.eval()
    rootNet_checkpoint = torch.load(
        os.path.join(model_dir, 'rootnet/model_best.pth.tar'), map_location='cpu')
    rootNet.load_state_dict(rootNet_checkpoint['state_dict'])
    print(" root prediction network loaded.")

    boneNet = BONENET()
    boneNet.to(device)
    boneNet.eval()
    boneNet_checkpoint = torch.load(
        os.path.join(model_dir, 'bonenet/model_best.pth.tar'), map_location='cpu')
    boneNet.load_state_dict(boneNet_checkpoint['state_dict'])
    print(" connection prediction network loaded.")

    skinNet = SKINNET(nearest_bone=5, use_Dg=True, use_Lf=True)
    skinNet_checkpoint = torch.load(
        os.path.join(model_dir, 'skinnet/model_best.pth.tar'), map_location='cpu')
    skinNet.load_state_dict(skinNet_checkpoint['state_dict'])
    skinNet.to(device)
    skinNet.eval()
    print(" skinning prediction network loaded.")

    data, vox, surface_geodesic, translation_normalize, scale_normalize = create_single_data(mesh_obj)
    data.to(device)

    print("predicting joints")
    data = predict_joints(data, vox, jointNet, threshold, bandwidth=bandwidth)
    data.to(device)
    print("predicting connectivity")
    pred_skeleton = predict_skeleton(data, vox, rootNet, boneNet)
    print("predicting skinning")
    pred_rig = predict_skinning(data, pred_skeleton, skinNet, surface_geodesic,
                                subsampling=downsample_skinning,
                                decimation=decimation, sampling=sampling)

    # here we reverse the normalization to the original scale and position
    pred_rig.normalize(scale_normalize, -translation_normalize)

    mesh_obj.vertex_groups.clear()
    for obj in bpy.data.objects:
        obj.select_set(False)
    ArmatureGenerator(pred_rig, mesh_obj).generate()

    # Drop the last references to the networks and device-resident data first.
    # Otherwise empty_cache() cannot return their memory while they are alive.
    del jointNet, rootNet, boneNet, skinNet, data
    torch.cuda.empty_cache()