def __init__(self, cfg, optimizer=None):
    """Build the update model: PointNet++ / image encoders, displacement net,
    2D-projection supervision modules and the EMD loss.

    Args:
        cfg: configuration object; cfg.CONST.DEVICE is used as the
            DataParallel device-id list.
        optimizer: optional factory called as ``optimizer(self.parameters())``
            (e.g. a ``functools.partial`` over a ``torch.optim`` class).
    """
    super().__init__()
    self.cfg = cfg
    self.pointnet2 = PointNet2(cfg)
    self.img_enc = ImgEncoder(cfg)
    self.displacement_net = DisplacementNet(cfg)
    # Optimizer is built from the factory so it captures this module's params.
    self.optimizer = None if optimizer is None else optimizer(self.parameters())
    # 2D supervision part
    self.projector = Projector(self.cfg)
    # proj loss
    self.proj_loss = ProjectLoss(self.cfg)
    # emd loss
    self.emd = EMD()
    if torch.cuda.is_available():
        # Wrap every submodule (losses included) in DataParallel and move
        # the whole model to GPU.  NOTE(review): losses wrapped in
        # DataParallel presumably return per-device values — confirm callers
        # reduce them.
        self.img_enc = torch.nn.DataParallel(self.img_enc, device_ids=cfg.CONST.DEVICE).cuda()
        self.pointnet2 = torch.nn.DataParallel(self.pointnet2, device_ids=cfg.CONST.DEVICE).cuda()
        self.projector = torch.nn.DataParallel(self.projector, device_ids=cfg.CONST.DEVICE).cuda()
        self.proj_loss = torch.nn.DataParallel(self.proj_loss, device_ids=cfg.CONST.DEVICE).cuda()
        self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
        self.cuda()
def __init__(self, cfg, optimizer, scheduler):
    """Build the two-stage model: GraphX generator followed by an EdgeConv refiner.

    Args:
        cfg: configuration object; reads cfg.GRAPHX.*, cfg.CONST.* and
            cfg.REFINE.* sections.
        optimizer: factory called as ``optimizer(self.parameters())``, or None.
        scheduler: factory called as ``scheduler(self.optimizer)``, or None.
    """
    super().__init__()
    self.cfg = cfg

    # Init the generator and refiner
    self.generator = GRAPHX_Generator(cfg=cfg,
                                      in_channels=3,
                                      in_instances=cfg.GRAPHX.NUM_INIT_POINTS,
                                      use_graphx=cfg.GRAPHX.USE_GRAPHX)
    self.refiner = EdgeConv_Refiner(cfg=cfg,
                                    num_points=cfg.CONST.NUM_POINTS,
                                    use_SElayer=True)

    self.optimizer = None if optimizer is None else optimizer(self.parameters())
    # BUG FIX: the original condition was `scheduler or optimizer is None`,
    # which tests the truthiness of `scheduler` itself — any provided
    # scheduler factory (truthy) made the whole expression truthy, so
    # self.scheduler was ALWAYS None.  The intent is: only build a scheduler
    # when both a scheduler factory and an optimizer exist.
    self.scheduler = (None if scheduler is None or optimizer is None
                      else scheduler(self.optimizer))

    self.emd = EMD()

    if torch.cuda.is_available():
        self.generator = torch.nn.DataParallel(self.generator, device_ids=cfg.CONST.DEVICE).cuda()
        self.refiner = torch.nn.DataParallel(self.refiner, device_ids=cfg.CONST.DEVICE).cuda()
        self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
        self.cuda()

    # Load pretrained generator
    if cfg.REFINE.USE_PRETRAIN_GENERATOR:
        print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.REFINE.GENERATOR_WEIGHTS))
        checkpoint = torch.load(cfg.REFINE.GENERATOR_WEIGHTS)
        self.generator.load_state_dict(checkpoint['net'])
        print('Best Epoch: %d' % (checkpoint['epoch_idx']))
def __init__(self, cfg, in_channels, in_instances, activation=nn.ReLU(),
             optimizer=None, scheduler=None, use_graphx=True, **kwargs):
    """GraphX generator: CNN image encoder + point-cloud encoder + deform decoder.

    Args:
        cfg: configuration object; cfg.CONST.DEVICE is the DataParallel
            device-id list.
        in_channels: number of input image channels for CNN18Encoder.
        in_instances: number of initial point instances for the decoder.
        activation: activation module shared by the sub-networks.
        optimizer, scheduler: NOTE(review): accepted but never used in this
            constructor — presumably handled by a wrapping model; confirm.
        use_graphx: select PointCloudGraphXDecoder over PointCloudDecoder.
        **kwargs: stored verbatim on self.kwargs for later use.
    """
    super().__init__()
    self.cfg = cfg

    # Graphx
    self.img_enc = CNN18Encoder(in_channels, activation)
    # One out-channel count per CNN block; these sizes drive the PC encoder
    # and the decoder's input width below.
    out_features = [block[-2].out_channels for block in self.img_enc.children()]
    self.pc_enc = PointCloudEncoder(3, out_features, cat_pc=True, use_adain=True,
                                    use_proj=True, activation=activation)
    deform_net = PointCloudGraphXDecoder if use_graphx else PointCloudDecoder
    # Input width: image features are used twice (2 * sum) plus 3 xyz coords.
    self.pc = deform_net(2 * sum(out_features) + 3, in_instances=in_instances,
                         activation=activation)
    self.kwargs = kwargs
    # emd loss
    self.emd = EMD()

    if torch.cuda.is_available():
        self.img_enc = torch.nn.DataParallel(self.img_enc, device_ids=cfg.CONST.DEVICE).cuda()
        self.pc_enc = torch.nn.DataParallel(self.pc_enc, device_ids=cfg.CONST.DEVICE).cuda()
        self.pc = torch.nn.DataParallel(self.pc, device_ids=cfg.CONST.DEVICE).cuda()
        self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
        self.cuda()
def __init__(self, cfg, in_channels, activation=nn.ReLU(), optimizer=None):
    """Build the updater: image encoder, PC transform/projection, and a
    config-selected point-cloud encoder + displacement decoder.

    Args:
        cfg: configuration object; reads cfg.UPDATER.PC_ENCODE_MODULE
            ('Pointnet++' | 'EdgeRes'), cfg.UPDATER.PC_DECODE_MODULE
            ('Linear' | 'Graphx') and cfg.CONST.DEVICE.
        in_channels: number of input image channels for CNN18Encoder.
        activation: activation module for the image encoder.
            NOTE(review): a module instance as default is shared across
            calls — kept for interface compatibility.
        optimizer: optional factory called as ``optimizer(self.parameters())``.

    Raises:
        ValueError: if either UPDATER module name is not one of the
            supported options.
    """
    super().__init__()
    self.cfg = cfg

    self.img_enc = CNN18Encoder(in_channels, activation)
    self.transform_pc = TransformPC(cfg)
    self.feature_projection = FeatureProjection(cfg)

    if cfg.UPDATER.PC_ENCODE_MODULE == 'Pointnet++':
        self.pc_encode = PointNet2(cfg)
    elif cfg.UPDATER.PC_ENCODE_MODULE == 'EdgeRes':
        self.pc_encode = EdgeRes(use_SElayer=True)
    else:
        # ROBUSTNESS FIX: the original silently left self.pc_encode unset on
        # an unknown config value, crashing later with an opaque
        # AttributeError; fail fast with an actionable message instead.
        raise ValueError("Invalid UPDATER.PC_ENCODE_MODULE: expected "
                         "'Pointnet++' or 'EdgeRes', got %r"
                         % (cfg.UPDATER.PC_ENCODE_MODULE,))

    if cfg.UPDATER.PC_DECODE_MODULE == 'Linear':
        self.displacement_net = LinearDisplacementNet(cfg)
    elif cfg.UPDATER.PC_DECODE_MODULE == 'Graphx':
        self.displacement_net = GraphxDisplacementNet(cfg)
    else:
        # Same fail-fast treatment for the decoder selector.
        raise ValueError("Invalid UPDATER.PC_DECODE_MODULE: expected "
                         "'Linear' or 'Graphx', got %r"
                         % (cfg.UPDATER.PC_DECODE_MODULE,))

    self.optimizer = None if optimizer is None else optimizer(self.parameters())

    # emd loss
    self.emd = EMD()

    if torch.cuda.is_available():
        self.img_enc = torch.nn.DataParallel(self.img_enc, device_ids=cfg.CONST.DEVICE).cuda()
        self.transform_pc = torch.nn.DataParallel(self.transform_pc, device_ids=cfg.CONST.DEVICE).cuda()
        self.feature_projection = torch.nn.DataParallel(self.feature_projection, device_ids=cfg.CONST.DEVICE).cuda()
        self.pc_encode = torch.nn.DataParallel(self.pc_encode, device_ids=cfg.CONST.DEVICE).cuda()
        self.displacement_net = torch.nn.DataParallel(self.displacement_net, device_ids=cfg.CONST.DEVICE).cuda()
        self.emd = torch.nn.DataParallel(self.emd, device_ids=cfg.CONST.DEVICE).cuda()
        self.cuda()
def __init__(self, cfg, optimizer, num_points: int = 2048, use_SElayer: bool = True):
    """Refiner built around an EdgeRes residual network plus an EMD loss.

    Args:
        cfg: configuration object; cfg.CONST.DEVICE lists DataParallel devices.
        optimizer: factory called as ``optimizer(self.parameters())``, or None.
        num_points: number of points the refiner operates on.
        use_SElayer: enable squeeze-and-excitation layers inside EdgeRes.
    """
    super().__init__()
    self.cfg = cfg
    self.num_points = num_points
    self.residual = EdgeRes(use_SElayer=use_SElayer)

    if optimizer is None:
        self.optimizer = None
    else:
        self.optimizer = optimizer(self.parameters())

    self.emd = EMD()

    # Multi-GPU wrapping, only when CUDA is present.
    if torch.cuda.is_available():
        device_ids = cfg.CONST.DEVICE
        self.residual = torch.nn.DataParallel(self.residual,
                                              device_ids=device_ids).cuda()
        self.emd = torch.nn.DataParallel(self.emd,
                                         device_ids=device_ids).cuda()
        self.cuda()
def main(config):
    """Run the experiment suite for a trained pocket-conditioned hyper-network AAE.

    Loads the latest checkpoint, evaluates reconstruction/KLD losses over the
    dataset, then executes whichever experiments are enabled in
    ``config['experiments']``.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, config['arch'], 'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)
    if not epoch:
        print("Invalid 'weights_path' in configuration")
        exit(1)

    setup_logging(results_dir)
    # `log` is a module-level name shared with the experiment helpers.
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=64, shuffle=True,
                                   num_workers=8, drop_last=True,
                                   pin_memory=True, collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)
    log.info("Loading weights...")
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    x = []  # accumulated pocket clouds, reused by the experiments below

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            # Hyper-network emits one weight vector per sample; each builds a
            # per-sample target network that decodes the full cloud.
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(
                    config=config, epoch=epoch,
                    size=(X_whole.shape[2], X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            # +0.5 shifts clouds into a positive range before the loss.
            loss_e = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()
            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # take the lowest possible first dim
        # (batches may carry different point counts; trim so torch.cat works)
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(
                x, encoder_pocket, hyper_network, device, results_dir, epoch,
                config['experiments']['interpolation']['amount'],
                config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points']['execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                config['experiments']['interpolation_between_two_points']['amount'],
                config['experiments']['interpolation_between_two_points']['image_points'],
                config['experiments']['interpolation_between_two_points']['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch,
                   config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles']['amount'],
                config['experiments']['sphere_triangles']['method'],
                config['experiments']['sphere_triangles']['depth'],
                config['experiments']['sphere_triangles']['start'],
                config['experiments']['sphere_triangles']['end'],
                config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles_interpolation']['amount'],
                config['experiments']['sphere_triangles_interpolation']['method'],
                config['experiments']['sphere_triangles_interpolation']['depth'],
                config['experiments']['sphere_triangles_interpolation']['coefficient'],
                config['experiments']['sphere_triangles_interpolation']['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                config['experiments']['different_number_of_points']['amount'],
                config['experiments']['different_number_of_points']['image_points'])

        if config['experiments']['fixed']['execute']:
            # get visible element from loader (probably should be done using
            # given object for example using parser
            points_dataloader = DataLoader(dataset, batch_size=10,
                                           shuffle=True, num_workers=8,
                                           drop_last=True, pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch,
                  config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'], (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
def main(config):
    """Train an adversarial autoencoder on point clouds with a WGAN-GP
    discriminator acting on the latent codes.

    Resumes from the latest checkpoint in the results dir, trains D and
    E/G alternately per batch, plots loss curves and sample clouds each
    epoch, and checkpoints every ``config['save_frequency']`` epochs.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(
            torch.Size([config['batch_size'], config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(
            config['z_beta_a'], config['z_beta_b'],
            size=(config['batch_size'], config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        # BUG FIX: typo 'binaray' -> 'binary' in the error message.
        raise ValueError('Invalid distribution for binary model.')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])
    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        D.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))
        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-batch loss histories for the diagnostic plot.
    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)

            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(
                    config['z_beta_a'], config['z_beta_b'],
                    size=(config['batch_size'], config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)

            # --- Discriminator step (WGAN critic loss) ---
            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)
            # BUG FIX: store detached Python floats instead of live (possibly
            # CUDA) tensors — appending the tensor kept the whole autograd
            # graph alive for the run and broke plt.plot on CUDA tensors.
            loss_d_tot.append(loss_d.item())

            # Gradient Penalty
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates, inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True, retain_graph=True, only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())  # BUG FIX: was .append(loss_gp)

            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()
            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            D_optim.step()

            # --- EG part of training ---
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())  # BUG FIX: was .append(loss_e)

            synth_logit = D(codes)
            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())  # BUG FIX: was .append(loss_g)

            loss_eg = loss_e + loss_g
            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            fake = G(fixed_noise).data.cpu().numpy()
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

            plt.figure(figsize=(16, 9))
            plt.plot(loss_d_tot, 'r-', label="loss_d")
            plt.plot(loss_gp_tot, 'g-', label="loss_gp")
            plt.plot(loss_e_tot, 'b-', label="loss_e")
            plt.plot(loss_g_tot, 'k-', label="loss_g")
            plt.legend()
            plt.xlabel("Batch number")
            # BUG FIX: second call was plt.xlabel, silently overwriting the
            # x-axis label and leaving the y axis unlabeled.
            plt.ylabel("Loss value")
            plt.savefig(join(results_dir, 'samples', f'loss_plot.png'))
            plt.close()

            for k in range(5):
                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch:05}_{k}_reconstructed.png'))
                plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(),
                       join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(),
                       join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
def main(config):
    """Train the pocket-conditioned hyper-network VAE.

    Encodes the pocket (non-visible) cloud, conditions on the visible cloud,
    and trains a hyper-network whose per-sample target networks reconstruct
    the whole cloud.  Resumes automatically from the latest checkpoint.
    """
    set_seed(config['seed'])
    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    # One optimizer over both encoders and the hyper-network.
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))
        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Reused across batches when target_network_input.constant is set.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get only visible part of point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get not visible part of point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            # One weight vector per sample; each instantiates a per-sample
            # target network decoding the whole cloud.
            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config, epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            # +0.5 shifts clouds into a positive range before the loss.
            loss_r = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld
            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        X = X.cpu().numpy()
        X_whole = X_whole.cpu().numpy()
        X_rec = X_rec.detach().cpu().numpy()

        if epoch % config['save_frequency'] == 0:
            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0], X_whole[k][1],
                                          X_whole[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_real.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_frequency'] == 0:
            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'),
                    np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
def main(config):
    """Train a plain point-cloud autoencoder (encoder E + generator G).

    Resumes from the latest checkpoint in the results dir, minimizes the
    configured reconstruction loss, saves sample renders every epoch and
    checkpoints every ``config['save_frequency']`` epochs.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        EG_optim.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            # +0.5 shifts clouds into a positive range before the loss.
            loss = torch.mean(config['reconstruction_coef'] *
                              reconstruction_loss(
                                  X.permute(0, 2, 1) + 0.5,
                                  X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(f'[{epoch}/{config["max_epochs"]}] '
                  f'Loss: {total_loss / i:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            # X still holds the last training batch at this point.
            X_rec = G(E(X)).data.cpu().numpy()

        for k in range(5):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(),
                       join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
def main(config):
    """Train a VAE-style point-cloud autoencoder (encoder E + generator G).

    Seeds RNGs, prepares the results directory, builds the dataset/dataloader,
    instantiates the models and reconstruction loss from ``config``, optionally
    resumes from the latest checkpoint, then runs the epoch loop: per batch it
    minimizes reconstruction loss plus a Cramer-Wold latent regularizer, and at
    each epoch end it renders sample plots and periodically saves weights,
    optimizer state and a loss curve.

    Args:
        config: dict loaded from a JSON config file; keys used include 'seed',
            'dataset', 'data_dir', 'classes', 'batch_size', 'arch', 'z_size',
            'reconstruction_loss', 'reconstruction_coef', 'max_epochs',
            'save_frequency', 'optimizer', 'cuda', 'gpu', and dataloader
            options. No value is returned; all outputs go to ``results_dir``.
    """
    # Seed every RNG source for reproducibility (dataloader shuffling,
    # weight init, fixed_noise sampling).
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    # Resume convention: epoch numbering starts at 1; find_latest_epoch
    # presumably scans results_dir for the newest saved epoch — TODO confirm.
    starting_epoch = find_latest_epoch(results_dir) + 1

    # Persist the config next to the results, but never overwrite an
    # existing one (so a resumed run keeps its original settings).
    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # NOTE(review): the message mentions `faust` but only `shapenet` is
        # actually handled above — the text looks stale.
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    # drop_last=True keeps every batch at exactly batch_size, which the
    # fixed_noise tensor below relies on.
    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    # The architecture module is chosen by name, e.g. models.<arch> must
    # expose Generator and Encoder classes taking the config dict.
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    # Pick the reconstruction criterion; each import is deferred so only the
    # selected loss implementation needs to be installed/compiled.
    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif config['reconstruction_loss'].lower() == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        # NOTE(review): message omits the accepted `cramer_wold` option.
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    # Fixed latent noise reused every epoch so the 'fixed' sample images are
    # comparable across epochs. Shape: [batch_size, z_size, 1].
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    # std_assumed is only referenced by the commented-out Gaussian-KLD
    # alternative below; kept for easy switching.
    std_assumed = torch.tensor(0.2)

    fixed_noise = fixed_noise.to(device)
    std_assumed = std_assumed.to(device)

    #
    # Optimizers
    #
    # Resolve the optimizer class by name (e.g. 'Adam') from torch.optim and
    # build it over the joint E+G parameter set.
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    # Resume models and optimizer state from the checkpoint of the epoch
    # immediately before starting_epoch.
    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-epoch mean losses; serialized to losses.json at save time.
    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []
            for i, point_data in enumerate(points_dataloader, 1):
                # log.debug('-' * 20)

                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                # In-place transpose; assumes last dim == 3 means
                # point-major layout — TODO confirm for non-3D data.
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                # Encoder returns (sampled code, mu, logvar); generator
                # decodes the sampled code back to a point cloud.
                codes, mu, logvar = E(X)
                X_rec = G(codes)

                # Reconstruction term. The +0.5 shift presumably moves
                # clouds into a positive range expected by the loss —
                # TODO confirm against the loss implementation.
                loss_e = torch.mean(
                    # config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                # Latent regularizer: Cramer-Wold distance on mu, scaled by
                # reconstruction_coef (the name notwithstanding). The
                # commented block is the standard Gaussian KLD alternative.
                loss_kld = config['reconstruction_coef'] * cw_distance(mu)
                # loss_kld = -0.5 * torch.mean(
                #     1 - 2.0 * torch.log(std_assumed) + logvar -
                #     (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

                loss_eg = loss_e + loss_kld
                # zero_grad on the optimizer covers both models; the extra
                # module-level zero_grad calls are redundant but harmless.
                EG_optim.zero_grad()
                E.zero_grad()
                G.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                # log.debug
                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            # Epoch summary in the progress bar; `i` is the final batch
            # index from the loop above, so total_loss / i is the mean.
            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results
            #
            # Renders: (a) reconstructions of the LAST training batch X,
            # (b) decodes of the fixed noise for epoch-to-epoch comparison.
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).data.cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).data.cpu().numpy()
                X_numpy = X.cpu().numpy()

            # Plot the first 5 samples of the batch; assumes batch_size >= 5
            # (guaranteed-sized batches via drop_last) — TODO confirm.
            for k in range(5):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1],
                                          X_numpy[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', 'fixed',
                         f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses,
                                  columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))
                fig = df.plot.line().get_figure()
                # NOTE(review): `k` here is the stale value (4) left over
                # from the plotting loops above — filename is effectively
                # always '{epoch:05}_4.png'; likely unintended.
                fig.savefig(join(results_dir, 'losses', f'{epoch:05}_{k}.png'))

                torch.save(G.state_dict(),
                           join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(),
                           join(weights_path, f'{epoch:05}_E.pth'))
                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
def test_net(cfg):
    """Evaluate the trained reconstruction net and view estimator on the test set.

    Builds the test dataloader with center-crop/normalize transforms, restores
    Pixel2Pointcloud, the view encoder, and the view estimator from their
    checkpoints, then iterates the test set computing Chamfer distance and EMD
    for reconstruction plus azimuth/elevation rotation error for view
    estimation. Results are printed and appended to cfg.TEST.RESULT_PATH.

    Args:
        cfg: project config object (easydict-style); reads CONST, DATASET,
            TEST, TRAIN and GRAPHX sections. Requires CUDA at runtime (losses
            and accumulators are created with .cuda()). Returns nothing.
    """
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    test_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # shuffle=False keeps a deterministic sample order for the per-batch log.
    dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST,
                                           test_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    # Set up networks
    # The parameters here need to be set in cfg
    # Optimizer/scheduler factories are required by the constructor signature
    # but unused for evaluation (weights are loaded from a checkpoint below).
    net = Pixel2Pointcloud(
        3,
        cfg.GRAPHX.NUM_INIT_POINTS,
        optimizer=lambda x: torch.optim.Adam(
            x, lr=cfg.TRAIN.LEARNING_RATE, weight_decay=cfg.TRAIN.WEIGHT_DECAY),
        scheduler=lambda x: MultiStepLR(
            x, milestones=cfg.TRAIN.MILESTONES, gamma=cfg.TRAIN.GAMMA),
        use_graphx=cfg.GRAPHX.USE_GRAPHX)

    view_encoder = Encoder(cfg)
    # Discretize the viewing sphere: azimuth covers 360 deg, elevation 180 deg,
    # in BIN_SIZE-degree classification bins.
    azi_classes, ele_classes = int(360 / cfg.CONST.BIN_SIZE), int(
        180 / cfg.CONST.BIN_SIZE)
    view_estimater = ViewEstimater(cfg,
                                   azi_classes=azi_classes,
                                   ele_classes=ele_classes)

    # Wrap in DataParallel BEFORE loading: the checkpoints below are keyed
    # with the 'module.' prefix — presumably saved from DataParallel models.
    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net).cuda()
        view_encoder = torch.nn.DataParallel(view_encoder).cuda()
        view_estimater = torch.nn.DataParallel(view_estimater).cuda()

    # Load weight
    # Load weight for encoder, decoder
    print('[INFO] %s Loading reconstruction weights from %s ...' %
          (dt.now(), cfg.TEST.RECONSTRUCTION_WEIGHTS))
    rec_checkpoint = torch.load(cfg.TEST.RECONSTRUCTION_WEIGHTS)
    net.load_state_dict(rec_checkpoint['net'])
    print('[INFO] Best reconstruction result at epoch %d ...' %
          rec_checkpoint['epoch_idx'])

    # Load weight for view encoder
    print('[INFO] %s Loading view estimation weights from %s ...' %
          (dt.now(), cfg.TEST.VIEW_ENCODER_WEIGHTS))
    view_enc_checkpoint = torch.load(cfg.TEST.VIEW_ENCODER_WEIGHTS)
    view_encoder.load_state_dict(view_enc_checkpoint['encoder_state_dict'])
    print('[INFO] Best view encode result at epoch %d ...' %
          view_enc_checkpoint['epoch_idx'])

    # Load weight for view estimater
    print('[INFO] %s Loading view estimation weights from %s ...' %
          (dt.now(), cfg.TEST.VIEW_ESTIMATION_WEIGHTS))
    view_est_checkpoint = torch.load(cfg.TEST.VIEW_ESTIMATION_WEIGHTS)
    view_estimater.load_state_dict(
        view_est_checkpoint['view_estimator_state_dict'])
    print('[INFO] Best view estimation result at epoch %d ...' %
          view_est_checkpoint['epoch_idx'])

    # Set up loss functions
    emd = EMD().cuda()
    cd = ChamferLoss().cuda()

    # Batch average meterics
    cd_distances = utils.network_utils.AverageMeter()
    emd_distances = utils.network_utils.AverageMeter()
    pointwise_emd_distances = utils.network_utils.AverageMeter()

    # Sentinel first rows so torch.cat can append below; stripped with
    # [1:, :] after the loop.
    test_preds = torch.zeros([1, 3], dtype=torch.float).cuda()
    test_ground_truth_views = torch.zeros([1, 3], dtype=torch.long).cuda()

    # Switch models to evaluation mode
    net.eval()
    view_encoder.eval()
    view_estimater.eval()

    n_batches = len(test_data_loader)

    # Testing loop
    for sample_idx, (taxonomy_names, sample_names, rendering_images,
                     init_point_clouds, ground_truth_point_clouds,
                     ground_truth_views) in enumerate(test_data_loader):
        with torch.no_grad():
            # Only one image per sample
            rendering_images = torch.squeeze(rendering_images, 1)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            init_point_clouds = utils.network_utils.var_or_cuda(
                init_point_clouds)
            ground_truth_views = utils.network_utils.var_or_cuda(
                ground_truth_views)
            ground_truth_point_clouds = utils.network_utils.var_or_cuda(
                ground_truth_point_clouds)

            #=================================================#
            #           Test the encoder, decoder             #
            #=================================================#
            # net.module.loss returns the (reduced) EMD loss and the
            # generated clouds; 'mean' selects the reduction mode.
            emd_loss, generated_point_clouds = net.module.loss(
                rendering_images, init_point_clouds,
                ground_truth_point_clouds, 'mean')

            # Compute CD, EMD
            # CD is normalized per-sample and per-point; assumes every batch
            # has exactly BATCH_SIZE samples — TODO confirm drop behavior of
            # the loader for the final partial batch.
            cd_distance = cd(generated_point_clouds, ground_truth_point_clouds
                             ) / cfg.CONST.BATCH_SIZE / cfg.CONST.NUM_POINTS
            emd_distance = emd_loss
            pointwise_emd_distance = emd_loss / cfg.CONST.NUM_POINTS

            #=================================================#
            #           Test the view estimater               #
            #=================================================#
            # output[0]:prediction of azi class
            # output[1]:prediction of ele class
            # output[2]:prediction of azi regression
            # output[3]:prediction of ele regression
            vgg_features, _ = view_encoder(rendering_images)
            output = view_estimater(vgg_features)

            #=================================================#
            #              Get predict view                   #
            #=================================================#
            preds_cls = utils.view_pred_utils.get_pred_from_cls_output(
                [output[0], output[1]])

            # Combine the class bin with its within-bin regression offset:
            # tanh()/2 bounds the offset to (-0.5, 0.5) of a bin, +0.5 moves
            # to the bin center, then scale by BIN_SIZE to get degrees.
            preds = []
            for n in range(len(preds_cls)):
                pred_delta = output[n + 2]
                delta_value = pred_delta[torch.arange(pred_delta.size(0)),
                                         preds_cls[n].long()].tanh() / 2
                preds.append((preds_cls[n].float() + delta_value + 0.5) *
                             cfg.CONST.BIN_SIZE)

            # In this experiment, expect all the object has a fixed rotaion angle 0 and with diferent view point
            # Add the zero inplane term for the rotation acc calculation
            zero_inplane = torch.zeros_like(preds[0])

            # Add fixed inplane rotation to pred view
            test_pred = torch.cat(
                (preds[0].unsqueeze(1), preds[1].unsqueeze(1),
                 zero_inplane.unsqueeze(1)), 1)

            # Add fixed inplane rotation to ground_truth_views
            zero_inplane = zero_inplane.unsqueeze(1)
            ground_truth_views = torch.cat(
                (ground_truth_views, zero_inplane.long()), 1)

            # Append loss and accuracy to average metrics
            cd_distances.update(cd_distance.item())
            emd_distances.update(emd_distance.item())
            pointwise_emd_distances.update(pointwise_emd_distance.item())

            # concatenate results and labels for view estimation
            test_preds = torch.cat((test_preds, test_pred), 0)
            test_ground_truth_views = torch.cat(
                (test_ground_truth_views, ground_truth_views), 0)

            print(
                "Test on [%d/%d] data, CD: %.4f Point EMD: %.4f Total EMD %.4f"
                % (sample_idx + 1, n_batches, cd_distance.item(),
                   pointwise_emd_distance.item(), emd_distance.item()))

    # Drop the zero sentinel rows inserted before the loop.
    test_preds = test_preds[1:, :]
    test_ground_truth_views = test_ground_truth_views[1:, :]

    # calculate the rotation errors between prediction and ground truth
    test_errs = utils.rotation_eval.rotation_err(
        test_preds, test_ground_truth_views.float()).cpu().numpy()
    # Acc_pi/6: percentage of samples with rotation error <= 30 degrees.
    Acc = 100. * np.mean(test_errs <= 30)
    Med = np.median(test_errs)

    # print result
    print("Reconstruction result:")
    print("CD result: ", cd_distances.avg)
    print("Pointwise EMD result: ", pointwise_emd_distances.avg)
    print("Total EMD result", emd_distances.avg)

    print("View estimation result:")
    print('Med_Err is %.2f, and Acc_pi/6 is %.2f \n \n' % (Med, Acc))

    # Append (mode 'a') the same summary to the persistent result log.
    logname = cfg.TEST.RESULT_PATH
    with open(logname, 'a') as f:
        f.write('Reconstruction result: \n')
        f.write("CD result: %.8f \n" % cd_distances.avg)
        f.write("Pointwise EMD result: %.8f \n" % pointwise_emd_distances.avg)
        f.write("Total EMD result: %.8f \n" % emd_distances.avg)
        f.write('View estimation result: \n')
        f.write('Med_Err is %.2f, and Acc_pi/6 is %.2f \n \n' % (Med, Acc))