def export_image(obj_id):
    """Render one object from ``obj_list`` to a PNG file.

    The object stored at ``obj_list[obj_id]`` (relative to ``root_dir``) is
    loaded with ``PartNetDataset.load_object`` and handed to
    ``draw_partnet_objects``, which is asked to save the figure to
    ``chairs_with_moi/objnr_<obj_id>.png``.

    Args:
        obj_id: Index into the module-level ``obj_list``.
    """
    # Progress message with a wall-clock timestamp.
    timestamp = datetime.now().strftime("%d %b %Y - %H:%M:%S")
    print("{} - Exporting figure {}...".format(timestamp, obj_id))

    object_name = obj_list[obj_id]
    loaded_obj = PartNetDataset.load_object(
        os.path.join(root_dir, object_name))
    output_file = "chairs_with_moi/objnr_{}.png".format(obj_id)

    draw_partnet_objects(
        objects=[loaded_obj],
        object_names=[object_name],
        figsize=(9, 5),
        leafs_only=True,
        visu_edges=False,
        sem_colors_filename='../stats/semantics_colors/Chair.txt',
        save_fig=True,
        save_fig_file=output_file)
def generate_data(obj_id):
    """Run ``moi_from_graph`` on one object and return its two result fields.

    Args:
        obj_id: Index into the module-level ``obj_list``.

    Returns:
        A tuple ``(moi, hover_penalty)`` taken from the result object
        returned by ``moi_from_graph``.
    """
    print("Calculating MoI for object {}".format(obj_id))

    json_path = os.path.join(root_dir, obj_list[obj_id])
    loaded_obj = PartNetDataset.load_object(json_path)

    # `options` is a module-level configuration passed through unchanged.
    result = moi_from_graph(loaded_obj, options)
    return result.moi, result.hover_penalty
def _sample_leaf_box_point_cloud(tree):
    """Sample a point cloud from the leaf boxes of ``tree``.

    The leaf boxes (``tree.boxes(leafs_only=True)``) are flattened into one
    row each, meshed with ``utils.gen_obb_mesh`` and sampled with
    ``utils.sample_pc``.
    """
    boxes_np = torch.cat(
        [box.view(1, -1) for box in tree.boxes(leafs_only=True)],
        dim=0).cpu().numpy()
    mesh_v, mesh_f = utils.gen_obb_mesh(boxes_np)
    return utils.sample_pc(mesh_v, mesh_f)


def compute_gen_cd_numbers(in_dir,
                           data_path,
                           object_list,
                           shapediff_topk,
                           shapediff_metric,
                           self_is_neighbor,
                           tot_shape,
                           tot_gen=100):
    """Compute chamfer-distance quality/coverage of generated shapes.

    For each of the first ``tot_shape`` dataset entries, the ``shapediff_topk``
    neighbor shapes are compared against ``tot_gen`` generated shapes read
    from ``<in_dir>/<obj_name>/obj2-XXX.json``.  The resulting
    ``(shapediff_topk, tot_gen)`` distance matrix is saved next to the
    generated shapes as ``cd_stats.npy``; the dataset-level means are printed
    and written to ``<in_dir>/neighbor_<shapediff_metric>_cd_stats.txt``.

    Args:
        in_dir: Directory holding one sub-directory of generated shapes per
            object name.
        data_path: Passed to ``PartNetShapeDiffDataset``.
        object_list: Passed to ``PartNetShapeDiffDataset``.
        shapediff_topk: Number of neighbors per shape.
        shapediff_metric: Neighbor metric name; also used in the output file
            name.
        self_is_neighbor: Passed to ``PartNetShapeDiffDataset``.
        tot_shape: Number of dataset entries to evaluate (must be > 0, the
            means are divided by it).
        tot_gen: Number of generated shapes per object (default 100, the
            value that used to be hard-coded).
    """
    chamfer_loss = ChamferDistance()

    data_features = [
        'object', 'name', 'neighbor_diffs', 'neighbor_objs', 'neighbor_names'
    ]
    dataset = PartNetShapeDiffDataset(data_path, object_list, data_features,
                                      shapediff_topk, shapediff_metric,
                                      self_is_neighbor)

    bar = ProgressBar()
    quality = 0.0
    coverage = 0.0
    for shape_idx in bar(range(tot_shape)):
        # Only the name and the neighbor trees are needed here.
        _, obj_name, _, neighbor_objs, _ = dataset[shape_idx]

        mat = np.zeros((shapediff_topk, tot_gen), dtype=np.float32)

        # One sampled point cloud per neighbor, stacked along a new axis 0.
        gt_pcs = np.stack([
            _sample_leaf_box_point_cloud(neighbor_objs[ni])
            for ni in range(shapediff_topk)
        ], axis=0)
        gt_pcs = torch.from_numpy(gt_pcs).float().cuda()

        # Distinct loop variable: the original reused `i` (and `obj`) from
        # the enclosing loop.
        for gen_idx in range(tot_gen):
            gen_obj = PartNetDataset.load_object(
                os.path.join(in_dir, obj_name, 'obj2-%03d.json' % gen_idx))
            gen_pc = _sample_leaf_box_point_cloud(gen_obj)
            # Repeat the generated cloud once per neighbor so all neighbors
            # are compared in a single chamfer call.
            gen_pc = np.tile(np.expand_dims(gen_pc, axis=0),
                             [shapediff_topk, 1, 1])
            gen_pc = torch.from_numpy(gen_pc).float().cuda()
            d1, d2 = chamfer_loss(gt_pcs, gen_pc)
            mat[:, gen_idx] = (d1.sqrt().mean(dim=1) +
                               d2.sqrt().mean(dim=1)).cpu().numpy() / 2

        # quality: for every generated shape, distance to its closest neighbor
        # coverage: for every neighbor, distance to its closest generated shape
        quality += mat.min(axis=0).mean()
        coverage += mat.min(axis=1).mean()
        np.save(os.path.join(in_dir, obj_name, 'cd_stats.npy'), mat)

    quality /= tot_shape
    coverage /= tot_shape
    print('mean cd quality: %.5f' % quality)
    print('mean cd coverage: %.5f' % coverage)
    print('q + c: %.5f' % (quality + coverage))
    with open(
            os.path.join(in_dir,
                         'neighbor_%s_cd_stats.txt' % shapediff_metric),
            'w') as fout:
        fout.write('mean cd quality: %.5f\n' % quality)
        fout.write('mean cd coverage: %.5f\n' % coverage)
        fout.write('q + c: %.5f\n' % (quality + coverage))
def compute_gen_sd_numbers(in_dir,
                           data_path,
                           object_list,
                           shapediff_topk,
                           shapediff_metric,
                           self_is_neighbor,
                           tot_shape,
                           tot_gen=100):
    """Compute structure-distance quality/coverage of generated shapes.

    For each of the first ``tot_shape`` dataset entries, every generated
    shape ``<in_dir>/<obj_name>/obj2-XXX.json`` is compared with every one of
    the ``shapediff_topk`` neighbors using a recursive structure distance.
    Two ``(shapediff_topk, tot_gen)`` matrices are saved per object
    (``sd_mat1_stats.npy``: distance normalised by the neighbor's box count,
    ``sd_mat2_stats.npy``: normalised by the generated shape's box count).
    Dataset-level means are printed and written to
    ``<in_dir>/neighbor_<shapediff_metric>_sd_stats.txt``.

    Args:
        in_dir: Directory holding one sub-directory of generated shapes per
            object name.
        data_path: Passed to ``PartNetShapeDiffDataset``.
        object_list: Passed to ``PartNetShapeDiffDataset``.
        shapediff_topk: Number of neighbors per shape.
        shapediff_metric: Neighbor metric name; also used in the output file
            name.
        self_is_neighbor: Passed to ``PartNetShapeDiffDataset``.
        tot_shape: Number of dataset entries to evaluate (must be > 0, the
            means are divided by it).
        tot_gen: Number of generated shapes per object (default 100, the
            value that used to be hard-coded).
    """
    chamfer_loss = ChamferDistance()
    unit_cube = torch.from_numpy(utils.load_pts('cube.pts'))

    def box_dist(box_feature, gt_box_feature):
        """Reweighted chamfer distance between two batches of boxes.

        Both inputs are batches of box parameter rows; the unit-cube points
        are transformed by each box and compared with the chamfer loss.
        """
        pred_box_pc = utils.transform_pc_batch(unit_cube, box_feature)
        pred_reweight = utils.get_surface_reweighting_batch(
            box_feature[:, 3:6], unit_cube.size(0))
        gt_box_pc = utils.transform_pc_batch(unit_cube, gt_box_feature)
        gt_reweight = utils.get_surface_reweighting_batch(
            gt_box_feature[:, 3:6], unit_cube.size(0))
        dist1, dist2 = chamfer_loss(gt_box_pc, pred_box_pc)
        # The epsilon guards against an all-zero weight row.
        loss1 = (dist1 * gt_reweight).sum(dim=1) / (gt_reweight.sum(dim=1) +
                                                    1e-12)
        loss2 = (dist2 * pred_reweight).sum(dim=1) / (
            pred_reweight.sum(dim=1) + 1e-12)
        return (loss1 + loss2) / 2

    def struct_dist(gt_node, pred_node):
        """Recursively count the boxes that differ between two sub-trees.

        Children sharing a label are matched through
        ``utils.linear_assignment`` on pairwise ``box_dist``; unmatched
        children contribute all of their boxes, matched pairs are compared
        recursively.
        """
        if gt_node.is_leaf:
            if pred_node.is_leaf:
                return 0
            return len(pred_node.boxes()) - 1
        if pred_node.is_leaf:
            return len(gt_node.boxes()) - 1

        gt_sem = {node.label for node in gt_node.children}
        pred_sem = {node.label for node in pred_node.children}
        intersect_sem = gt_sem & pred_sem

        # label -> indices of the children carrying that label
        gt_cnodes_per_sem = dict()
        for node_id, gt_cnode in enumerate(gt_node.children):
            if gt_cnode.label in intersect_sem:
                gt_cnodes_per_sem.setdefault(gt_cnode.label,
                                             []).append(node_id)

        pred_cnodes_per_sem = dict()
        for node_id, pred_cnode in enumerate(pred_node.children):
            if pred_cnode.label in intersect_sem:
                pred_cnodes_per_sem.setdefault(pred_cnode.label,
                                               []).append(node_id)

        # Parallel lists: matched_gt_idx[k] is paired with matched_pred_idx[k].
        # (The original also filled a fixed 100-entry lookup array that was
        # never read and would overflow for nodes with >100 children.)
        matched_gt_idx = []
        matched_pred_idx = []
        for sem in intersect_sem:
            gt_ids = gt_cnodes_per_sem[sem]
            pred_ids = pred_cnodes_per_sem[sem]
            gt_boxes = torch.cat(
                [gt_node.children[cid].get_box_quat() for cid in gt_ids],
                dim=0)
            pred_boxes = torch.cat(
                [pred_node.children[cid].get_box_quat() for cid in pred_ids],
                dim=0)

            num_gt = gt_boxes.size(0)
            num_pred = pred_boxes.size(0)

            if num_gt == 1 and num_pred == 1:
                # Trivial one-to-one match, no assignment problem to solve.
                cur_matched_gt_idx = [0]
                cur_matched_pred_idx = [0]
            else:
                # Build all (gt, pred) pairs and their distance matrix.
                gt_boxes_tiled = gt_boxes.unsqueeze(dim=1).repeat(
                    1, num_pred, 1)
                pred_boxes_tiled = pred_boxes.unsqueeze(dim=0).repeat(
                    num_gt, 1, 1)
                dmat = box_dist(gt_boxes_tiled.view(-1, 10),
                                pred_boxes_tiled.view(-1, 10)).view(
                                    -1, num_gt, num_pred)
                _, cur_matched_gt_idx, cur_matched_pred_idx = \
                    utils.linear_assignment(dmat)

            for k in range(len(cur_matched_gt_idx)):
                matched_gt_idx.append(gt_ids[cur_matched_gt_idx[k]])
                matched_pred_idx.append(pred_ids[cur_matched_pred_idx[k]])

        matched_gt_set = set(matched_gt_idx)
        matched_pred_set = set(matched_pred_idx)

        struct_diff = 0.0
        # Unmatched children count with every box in their sub-tree.
        for cid, child in enumerate(gt_node.children):
            if cid not in matched_gt_set:
                struct_diff += len(child.boxes())
        for cid, child in enumerate(pred_node.children):
            if cid not in matched_pred_set:
                struct_diff += len(child.boxes())
        # Matched children are compared recursively.
        for gt_id, pred_id in zip(matched_gt_idx, matched_pred_idx):
            struct_diff += struct_dist(gt_node.children[gt_id],
                                       pred_node.children[pred_id])
        return struct_diff

    # create dataset and data loader
    data_features = [
        'object', 'name', 'neighbor_diffs', 'neighbor_objs', 'neighbor_names'
    ]
    dataset = PartNetShapeDiffDataset(data_path, object_list, data_features,
                                      shapediff_topk, shapediff_metric,
                                      self_is_neighbor)

    bar = ProgressBar()
    quality = 0.0
    coverage = 0.0
    for shape_idx in bar(range(tot_shape)):
        # Only the name and the neighbor trees are needed here.
        _, obj_name, _, neighbor_objs, _ = dataset[shape_idx]

        mat1 = np.zeros((shapediff_topk, tot_gen), dtype=np.float32)
        mat2 = np.zeros((shapediff_topk, tot_gen), dtype=np.float32)
        for j in range(tot_gen):
            gen_obj = PartNetDataset.load_object(
                os.path.join(in_dir, obj_name, 'obj2-%03d.json' % j))
            for ni in range(shapediff_topk):
                sd = struct_dist(neighbor_objs[ni].root, gen_obj.root)
                mat1[ni, j] = sd / len(neighbor_objs[ni].root.boxes())
                mat2[ni, j] = sd / len(gen_obj.root.boxes())

        quality += mat2.min(axis=0).mean()
        coverage += mat1.min(axis=1).mean()
        np.save(os.path.join(in_dir, obj_name, 'sd_mat1_stats.npy'), mat1)
        np.save(os.path.join(in_dir, obj_name, 'sd_mat2_stats.npy'), mat2)

    quality /= tot_shape
    coverage /= tot_shape
    print('mean sd quality: ', quality)
    print('mean sd coverage: ', coverage)
    print('q + c: %.5f' % (quality + coverage))
    with open(
            os.path.join(in_dir,
                         'neighbor_%s_sd_stats.txt' % shapediff_metric),
            'w') as fout:
        fout.write('mean sd quality: %f\n' % quality)
        fout.write('mean sd coverage: %f\n' % coverage)
        fout.write('q + c: %.5f\n' % (quality + coverage))
# create models encoder = models.RecursiveEncoder(conf, variational=True, probabilistic=False) decoder = models.RecursiveDecoder(conf) models = [encoder, decoder] model_names = ['encoder', 'decoder'] # load pretrained model __ = utils.load_checkpoint( models=models, model_names=model_names, dirname=os.path.join(conf.model_path, conf.exp_name), epoch=conf.model_epoch, strict=True) # create dataset data_features = ['object', 'name'] dataset = PartNetDataset(conf.data_path, conf.test_dataset, data_features, load_geo=True) dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=utils.collate_feats) # send to device for m in models: m.to(device) # set models to evaluation mode for m in models: m.eval() # test over all test shapes num_batch = len(dataloader) with torch.no_grad(): for batch_ind, batch in enumerate(dataloader): obj = batch[data_features.index('object')][0]
result_dir = os.path.join(conf.result_path, conf.exp_name + '_editgen_sigma_%f' % conf.sigma) if os.path.exists(result_dir): response = input( 'Eval results for "%s" already exists, overwrite? (y/n) ' % result_dir) if response != 'y': sys.exit() shutil.rmtree(result_dir) # create a new directory to store eval results os.makedirs(result_dir) # dataset data_features = ['object', 'name'] dataset = PartNetDataset(conf.data_path, conf.test_dataset, data_features, load_geo=False) # create models encoder = models.RecursiveEncoder(conf, variational=True, probabilistic=False) decoder = models.RecursiveDecoder(conf) models = [encoder, decoder] model_names = ['encoder', 'decoder'] # load pretrained model __ = utils.load_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.ckpt_path, conf.exp_name), epoch=conf.model_epoch, strict=True)
neighbor_name = neighbor_names[ni] neighbor_fn = 'neighbor_%02d_%s' % (ni, neighbor_name) neighbor_obj = neighbor_objs[ni] neighbor_obj.to(device) neighbor_diff = neighbor_diffs[ni] neighbor_diff.to(device) with open(os.path.join(cur_res_dir, neighbor_fn+'.orig.diff'), 'w') as fout: fout.write(str(neighbor_diff)) root_code = encoder.encode_tree_diff(obj, neighbor_diff) recon_diff = decoder.decode_tree_diff(root_code, obj) with open(os.path.join(cur_res_dir, neighbor_fn+'.recon.diff'), 'w') as fout: fout.write(str(recon_diff)) recon_neighbor = Tree(Tree.apply_shape_diff(obj.root, recon_diff)) PartNetDataset.save_object(recon_neighbor, os.path.join(cur_res_dir, neighbor_fn+'.recon.json')) cd = geometry_dist(neighbor_obj, recon_neighbor) sd = struct_dist(neighbor_obj.root, recon_neighbor.root) sd = sd / len(neighbor_obj.root.boxes()) with open(os.path.join(cur_res_dir, neighbor_fn+'.stats'), 'w') as fout: fout.write('cd: %f\nsd: %f\n' % (cd, sd)) print('computing stats ...') compute_recon_numbers( in_dir=result_dir, baseline_dir=conf.baseline_dir, shapediff_topk=conf.shapediff_topk)
models = [decoder] model_names = ['decoder'] # load pretrained model __ = utils.load_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.model_path, conf.exp_name), epoch=conf.model_epoch, strict=True) # send to device for m in models: m.to(device) # set models to evaluation mode for m in models: m.eval() # generate shapes with torch.no_grad(): gen_pcs = [] gen_objs = [] for i in range(conf.num_gen): print(f'Generating {i}/{conf.num_gen} ...') code = torch.randn(1, conf.feature_size).cuda() obj = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth) output_filename = os.path.join(conf.result_path, conf.exp_name, 'object-%04d.json' % i) PartNetDataset.save_object(obj=obj, fn=output_filename)
encoder = models.RecursiveEncoder(conf, variational=True, probabilistic=False) decoder = models.RecursiveDecoder(conf) models = [encoder, decoder] model_names = ['encoder', 'decoder'] # load pretrained model __ = utils.load_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.ckpt_path, conf.exp_name), epoch=conf.model_epoch, strict=True, device=device) # create dataset and data loader data_features = ['object', 'name'] dataset = PartNetDataset(conf.data_path, conf.test_dataset, data_features) # send to device for m in models: m.to(device) # set models to evaluation mode for m in models: m.eval() # test over all test shapes with torch.no_grad(): pbar = ProgressBar() print('generating edits ...') for i in pbar(range(len(dataset))): obj, obj_name = dataset[i]
def _reset_run_dirs(conf):
    """Create fresh log and checkpoint directories for ``conf.exp_name``.

    If either directory already exists the user is asked for confirmation;
    any answer other than ``y`` exits the process, otherwise the existing
    directories are removed first.

    Returns:
        ``(log_dir, ckpt_dir)`` for this run.
    """
    log_dir = os.path.join(conf.log_path, conf.exp_name)
    ckpt_dir = os.path.join(conf.model_path, conf.exp_name)

    if os.path.exists(log_dir) or os.path.exists(ckpt_dir):
        response = input(
            'A training run named "%s" already exists, overwrite? (y/n) ' %
            (conf.exp_name))
        if response != 'y':
            sys.exit()
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        if os.path.exists(ckpt_dir):
            shutil.rmtree(ckpt_dir)

    # create directories for this run
    os.makedirs(ckpt_dir)
    os.makedirs(log_dir)
    return log_dir, ckpt_dir


def train(conf):
    """Train the recursive encoder/decoder pair described by ``conf``.

    Sets up run directories, seeds the RNGs, builds the models (loading a
    pretrained, frozen part encoder/decoder), then alternates training
    batches with proportionally interleaved validation batches for
    ``conf.epochs`` epochs, checkpointing every ``conf.checkpoint_interval``
    steps and once more at the end.

    Args:
        conf: Run configuration.  ``conf.seed`` is overwritten with a random
            value when negative.  ``conf.epochs`` must be at least 1 because
            the final checkpoint records the last epoch index.
    """
    # load network model
    model_module = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, delete it.
    log_dir, ckpt_dir = _reset_run_dirs(conf)

    # file log; the context manager closes it even if training aborts
    with open(os.path.join(log_dir, 'train.log'), 'w') as flog:
        # set training device
        device = torch.device(conf.device)
        print(f'Using device: {conf.device}')
        flog.write(f'Using device: {conf.device}\n')

        # log the object category information
        print(f'Object Category: {conf.category}')
        flog.write(f'Object Category: {conf.category}\n')

        # control randomness
        if conf.seed < 0:
            conf.seed = random.randint(1, 10000)
        print("Random Seed: %d" % (conf.seed))
        flog.write(f'Random Seed: {conf.seed}\n')
        random.seed(conf.seed)
        np.random.seed(conf.seed)
        torch.manual_seed(conf.seed)

        # save config
        torch.save(conf, os.path.join(ckpt_dir, 'conf.pth'))

        # create models
        encoder = model_module.RecursiveEncoder(
            conf, variational=True, probabilistic=not conf.non_variational)
        decoder = model_module.RecursiveDecoder(conf)
        networks = [encoder, decoder]
        model_names = ['encoder', 'decoder']

        # load pretrained part AE/VAE
        pretrain_ckpt_dir = os.path.join(conf.model_path,
                                         conf.part_pc_exp_name)
        pretrain_ckpt_epoch = conf.part_pc_model_epoch
        print(
            f'Loading ckpt from {pretrain_ckpt_dir}: epoch {pretrain_ckpt_epoch}'
        )
        __ = utils.load_checkpoint(
            models=[
                encoder.node_encoder.part_encoder,
                decoder.node_decoder.part_decoder
            ],
            model_names=['part_pc_encoder', 'part_pc_decoder'],
            dirname=pretrain_ckpt_dir,
            epoch=pretrain_ckpt_epoch,
            strict=True)

        # set part_encoder and part_decoder BatchNorm to eval mode and
        # freeze their weights
        encoder.node_encoder.part_encoder.eval()
        for param in encoder.node_encoder.part_encoder.parameters():
            param.requires_grad = False
        decoder.node_decoder.part_decoder.eval()
        for param in decoder.node_decoder.part_decoder.parameters():
            param.requires_grad = False

        # create optimizers
        encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr)
        decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr)
        optimizers = [encoder_opt, decoder_opt]
        optimizer_names = ['encoder', 'decoder']

        # learning rate scheduler
        encoder_scheduler = torch.optim.lr_scheduler.StepLR(
            encoder_opt,
            step_size=conf.lr_decay_every,
            gamma=conf.lr_decay_by)
        decoder_scheduler = torch.optim.lr_scheduler.StepLR(
            decoder_opt,
            step_size=conf.lr_decay_every,
            gamma=conf.lr_decay_by)

        # create training and validation datasets and data loaders
        data_features = ['object']
        train_dataset = PartNetDataset(conf.data_path,
                                       conf.train_dataset,
                                       data_features,
                                       load_geo=conf.load_geo)
        valdt_dataset = PartNetDataset(conf.data_path,
                                       conf.val_dataset,
                                       data_features,
                                       load_geo=conf.load_geo)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=conf.batch_size,
            shuffle=True,
            collate_fn=utils.collate_feats)
        valdt_dataloader = torch.utils.data.DataLoader(
            valdt_dataset,
            batch_size=conf.batch_size,
            shuffle=True,
            collate_fn=utils.collate_feats)

        # create logs
        if not conf.no_console_log:
            header = ' Time Epoch Dataset Iteration Progress(%) LR LatentLoss GeoLoss CenterLoss ScaleLoss StructLoss EdgeExists KLDivLoss SymLoss AdjLoss TotalLoss'

        # The writers are always passed to forward(); default them to None so
        # running with conf.no_tb_log does not hit an unbound local.
        train_writer = None
        valdt_writer = None
        if not conf.no_tb_log:
            # https://github.com/lanpa/tensorboard-pytorch
            from tensorboardX import SummaryWriter
            train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
            valdt_writer = SummaryWriter(os.path.join(log_dir, 'val'))

        # send parameters to device
        for m in networks:
            m.to(device)
        for o in optimizers:
            utils.optimizer_to_device(o, device)

        # start training
        print("Starting training ...... ")
        flog.write('Starting training ......\n')

        start_time = time.time()

        last_checkpoint_step = None
        last_train_console_log_step = None
        last_valdt_console_log_step = None
        train_num_batch = len(train_dataloader)
        valdt_num_batch = len(valdt_dataloader)

        # train for every epoch
        for epoch in range(conf.epochs):
            if not conf.no_console_log:
                print(f'training run {conf.exp_name}')
                flog.write(f'training run {conf.exp_name}\n')
                print(header)
                flog.write(header + '\n')

            train_batches = enumerate(train_dataloader, 0)
            valdt_batches = enumerate(valdt_dataloader, 0)

            train_fraction_done, valdt_fraction_done = 0.0, 0.0
            valdt_batch_ind = -1

            # train for every batch
            for train_batch_ind, batch in train_batches:
                train_fraction_done = (train_batch_ind + 1) / train_num_batch
                train_step = epoch * train_num_batch + train_batch_ind

                log_console = not conf.no_console_log and (
                    last_train_console_log_step is None
                    or train_step - last_train_console_log_step >=
                    conf.console_log_interval)
                if log_console:
                    last_train_console_log_step = train_step

                # make sure the models are in eval mode to deactivate
                # BatchNorm for PartEncoder and PartDecoder
                # there are no other BatchNorm / Dropout in the rest of the
                # network
                for m in networks:
                    m.eval()

                # forward pass (including logging)
                total_loss = forward(batch=batch,
                                     data_features=data_features,
                                     encoder=encoder,
                                     decoder=decoder,
                                     device=device,
                                     conf=conf,
                                     is_valdt=False,
                                     step=train_step,
                                     epoch=epoch,
                                     batch_ind=train_batch_ind,
                                     num_batch=train_num_batch,
                                     start_time=start_time,
                                     log_console=log_console,
                                     log_tb=not conf.no_tb_log,
                                     tb_writer=train_writer,
                                     lr=encoder_opt.param_groups[0]['lr'],
                                     flog=flog)

                # optimize one step
                encoder_opt.zero_grad()
                decoder_opt.zero_grad()
                total_loss.backward()
                encoder_opt.step()
                decoder_opt.step()
                # Schedulers step after the optimizers (PyTorch >= 1.1
                # ordering), so the LR logged above is the LR actually used.
                encoder_scheduler.step()
                decoder_scheduler.step()

                # save checkpoint
                with torch.no_grad():
                    if last_checkpoint_step is None or \
                            train_step - last_checkpoint_step >= conf.checkpoint_interval:
                        print("Saving checkpoint ...... ", end='', flush=True)
                        flog.write("Saving checkpoint ...... ")
                        utils.save_checkpoint(models=networks,
                                              model_names=model_names,
                                              dirname=ckpt_dir,
                                              epoch=epoch,
                                              prepend_epoch=True,
                                              optimizers=optimizers,
                                              optimizer_names=optimizer_names)
                        print("DONE")
                        flog.write("DONE\n")
                        last_checkpoint_step = train_step

                # validate: consume validation batches until validation has
                # progressed through its epoch as far as training has
                while valdt_fraction_done <= train_fraction_done and \
                        valdt_batch_ind + 1 < valdt_num_batch:
                    valdt_batch_ind, batch = next(valdt_batches)
                    valdt_fraction_done = (valdt_batch_ind +
                                           1) / valdt_num_batch
                    # express the validation position on the training-step
                    # axis so both curves share one x-axis
                    valdt_step = (epoch +
                                  valdt_fraction_done) * train_num_batch - 1

                    log_console = not conf.no_console_log and (
                        last_valdt_console_log_step is None
                        or valdt_step - last_valdt_console_log_step >=
                        conf.console_log_interval)
                    if log_console:
                        last_valdt_console_log_step = valdt_step

                    # set models to evaluation mode
                    for m in networks:
                        m.eval()

                    with torch.no_grad():
                        # forward pass (including logging)
                        __ = forward(batch=batch,
                                     data_features=data_features,
                                     encoder=encoder,
                                     decoder=decoder,
                                     device=device,
                                     conf=conf,
                                     is_valdt=True,
                                     step=valdt_step,
                                     epoch=epoch,
                                     batch_ind=valdt_batch_ind,
                                     num_batch=valdt_num_batch,
                                     start_time=start_time,
                                     log_console=log_console,
                                     log_tb=not conf.no_tb_log,
                                     tb_writer=valdt_writer,
                                     lr=encoder_opt.param_groups[0]['lr'],
                                     flog=flog)

        # save the final models (`epoch` is the last value of the loop above)
        print("Saving final checkpoint ...... ", end='', flush=True)
        flog.write("Saving final checkpoint ...... ")
        utils.save_checkpoint(models=networks,
                              model_names=model_names,
                              dirname=ckpt_dir,
                              epoch=epoch,
                              prepend_epoch=False,
                              optimizers=optimizers,
                              optimizer_names=optimizer_names)
        print("DONE")
        flog.write("DONE\n")

        # flush pending TensorBoard events, if logging was enabled
        for writer in (train_writer, valdt_writer):
            if writer is not None:
                writer.close()
data_path = '../data/partnetdata/chair_hier' dataset_fn = 'train_no_other_less_than_10_parts.txt' out_dir = os.path.join(data_path, 'neighbors_cd', os.path.splitext(dataset_fn)[0]) if os.path.exists(out_dir): response = input('output directory "%s" already exists, overwrite? (y/n) ' % out_dir) if response != 'y': sys.exit() shutil.rmtree(out_dir) # create a new directory to store eval results os.makedirs(out_dir) # create dataset and data loader data_features = ['object', 'name'] dataset = PartNetDataset(data_path, dataset_fn, data_features, load_geo=False) # parameters device = 'cuda:0' # enumerate over all training shapes num_shape = len(dataset) n_points = 2048 print('Creating point clouds ...') pcs = np.zeros((num_shape, n_points, 3), dtype=np.float32) objs = [] names = [] bar = ProgressBar() for i in bar(range(num_shape)): obj, name = dataset[i]
all_tuples = [l.rstrip().split() for l in fin.readlines()] print(len(all_tuples)) # test over all tuples with torch.no_grad(): bar = ProgressBar() tot_cd = 0 tot_sd = 0 tot_cnt = 0 for i in bar(range(conf.num_tuples)): cur_res_dir = os.path.join(result_dir, '%06d' % i) os.mkdir(cur_res_dir) # load tuple (A, B, C, D) name_A, name_B, name_C, name_D = all_tuples[i] obj_A = PartNetDataset.load_object( os.path.join(conf.data_path1, name_A + '.json')).to(device) PartNetDataset.save_object(obj_A, os.path.join(cur_res_dir, 'obj_A.json')) obj_B = PartNetDataset.load_object( os.path.join(conf.data_path1, name_B + '.json')).to(device) PartNetDataset.save_object(obj_B, os.path.join(cur_res_dir, 'obj_B.json')) obj_C = PartNetDataset.load_object( os.path.join(conf.data_path2, name_C + '.json')).to(device) PartNetDataset.save_object(obj_C, os.path.join(cur_res_dir, 'obj_C.json')) obj_D = PartNetDataset.load_object( os.path.join(conf.data_path2, name_D + '.json')).to(device) PartNetDataset.save_object(obj_D, os.path.join(cur_res_dir, 'obj_D.json'))
model_names = ['encoder', 'decoder'] # load pretrained model __ = utils.load_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.ckpt_path, conf.exp_name), epoch=conf.model_epoch, strict=True) # set models to evaluation mode for m in models: m.eval() # test over all test shapes with torch.no_grad(): objA = PartNetDataset.load_object( os.path.join(conf.data_path, '%s.json' % conf.shapeA)) objB = PartNetDataset.load_object( os.path.join(conf.data_path, '%s.json' % conf.shapeB)) objC = PartNetDataset.load_object( os.path.join(conf.data_path, '%s.json' % conf.shapeC)) encoder.encode_tree(objA) diffAB = Tree.compute_shape_diff(objA.root, objB.root) code = encoder.encode_tree_diff(objA, diffAB) encoder.encode_tree(objC) recon_obj_diff = decoder.decode_tree_diff(code, objC) recon_obj = Tree(Tree.apply_shape_diff(objC.root, recon_obj_diff)) PartNetDataset.save_object(objA, os.path.join(result_dir, 'shapeA.json')) PartNetDataset.save_object(objB, os.path.join(result_dir, 'shapeB.json'))