def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, config['arch'], 'experiments',
                                      dirs_to_create=['interpolations', 'sphere', 'points_interpolation',
                                                      'different_number_points', 'fixed', 'reconstruction',
                                                      'sphere_triangles', 'sphere_triangles_interpolation'])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        exit(1)

    setup_logging(results_dir)
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `custom` '
                         f'or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=64, shuffle=True,
                                   num_workers=8, drop_last=True,
                                   pin_memory=True, collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    hyper_network.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible (pocket) part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)
                target_network_input = generate_points(config=config, epoch=epoch,
                                                       size=(X_whole.shape[2], X_whole.shape[1]))
                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 - logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # trim every batch to the smallest number of points so they can be concatenated
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(x, encoder_pocket, hyper_network, device, results_dir, epoch,
                          config['experiments']['interpolation']['amount'],
                          config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points']['execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                config['experiments']['interpolation_between_two_points']['amount'],
                config['experiments']['interpolation_between_two_points']['image_points'],
                config['experiments']['interpolation_between_two_points']['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x, results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir, epoch,
                   config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(encoder_pocket, hyper_network, device, x, results_dir,
                             config['experiments']['sphere_triangles']['amount'],
                             config['experiments']['sphere_triangles']['method'],
                             config['experiments']['sphere_triangles']['depth'],
                             config['experiments']['sphere_triangles']['start'],
                             config['experiments']['sphere_triangles']['end'],
                             config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles_interpolation']['amount'],
                config['experiments']['sphere_triangles_interpolation']['method'],
                config['experiments']['sphere_triangles_interpolation']['depth'],
                config['experiments']['sphere_triangles_interpolation']['coefficient'],
                config['experiments']['sphere_triangles_interpolation']['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                config['experiments']['different_number_of_points']['amount'],
                config['experiments']['different_number_of_points']['image_points'])

        if config['experiments']['fixed']['execute']:
            # take the visible part from the loader (ideally this would come from a
            # user-specified object, e.g. passed in via a CLI parser)
            points_dataloader = DataLoader(dataset, batch_size=10, shuffle=True,
                                           num_workers=8, drop_last=True,
                                           pin_memory=True, collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            fixed(hyper_network, encoder_visible, X_visible, device, results_dir, epoch,
                  config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'],
                  (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
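# Entry-point sketch (not part of the original listing): the scripts in this
# repo consume a plain config dict, so a JSON config loaded from a CLI flag is
# assumed here. The `-c/--config` flag name is illustrative only.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help='path to a JSON experiment configuration')
    args = parser.parse_args()

    with open(args.config) as f:
        main(json.load(f))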
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `custom` '
                         f'or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder_visible.parameters(),
                                          encoder_pocket.parameters(),
                                          hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))
        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(join(metrics_path, f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)

        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible (pocket) part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get the whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch,
                                                           size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 - logvar).sum()

            loss_all = loss_r + loss_kld

            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        X = X.cpu().numpy()
        X_whole = X_whole.cpu().numpy()
        X_rec = X_rec.detach().cpu().numpy()

        if epoch % config['save_frequency'] == 0:
            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0], X_whole[k][1], X_whole[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_frequency'] == 0:
            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(), join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(), join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'), np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
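# For reference: `loss_kld` in the training loop above is the closed-form KL
# divergence between the encoder posterior N(mu, diag(exp(logvar))) and the
# unit Gaussian prior, KL = 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar).
# A minimal sanity check against torch.distributions (illustrative helper,
# not part of the original training code):
def _kld_sanity_check():
    from torch.distributions import Normal, kl_divergence

    mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
    closed_form = 0.5 * (torch.exp(logvar) + mu.pow(2) - 1 - logvar).sum()
    reference = kl_divergence(
        Normal(mu, (0.5 * logvar).exp()),
        Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()
    assert torch.allclose(closed_form, reference, atol=1e-4)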
def all_metrics(config, weights_path, device, epoch, jsd_value):
    from utils.metrics import compute_all_metrics
    print('All metrics')
    if epoch is None:
        print('Finding latest epoch...')
        epoch = find_latest_epoch(weights_path)
        print(f'Epoch: {epoch}')

    if jsd_value is not None:
        print(f'Best epoch selected via minimal JSD: {epoch}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'], split='test')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')
    classes_selected = ('all' if not config['classes']
                        else ','.join(config['classes']))
    print(f'Test dataset. Selected {classes_selected} classes. '
          f'Loaded {len(dataset)} samples.')

    distribution = config['metrics']['distribution']
    assert distribution in ['normal', 'beta'], 'Invalid distribution. Choose normal or beta'

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    hyper_network.eval()

    data_loader = DataLoader(dataset, batch_size=32, shuffle=False,
                             num_workers=4, drop_last=False, pin_memory=True)

    hyper_network.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_G.pth')))

    result = {}
    size = 0
    start_clock = datetime.now()
    for point_data in data_loader:
        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        with torch.no_grad():
            noise = torch.zeros(X.shape[0], config['z_size']).to(device)
            if distribution == 'normal':
                noise.normal_(config['metrics']['normal_mu'],
                              config['metrics']['normal_std'])
            elif distribution == 'beta':
                noise_np = np.random.beta(config['metrics']['beta_a'],
                                          config['metrics']['beta_b'],
                                          noise.shape)
                noise = torch.tensor(noise_np).float().round().to(device)

            target_networks_weights = hyper_network(noise)
            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)
                target_network_input = generate_points(config=config, epoch=epoch,
                                                       size=(X.shape[2], X.shape[1]))
                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            # size-weighted running mean of each metric over batches
            for k, v in compute_all_metrics(torch.transpose(X, 1, 2).contiguous(),
                                            torch.transpose(X_rec, 1, 2).contiguous(),
                                            X.shape[0]).items():
                result[k] = (size * result.get(k, 0.0) + X.shape[0] * v.item()) / (size + X.shape[0])

            size += X.shape[0]

    result['jsd'] = jsd_value

    print(f'Time: {datetime.now() - start_clock}')
    print('Result:')
    for k, v in result.items():
        print(f'{k}: {v}')
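# Note on the aggregation in `all_metrics` above: each metric is kept as a
# size-weighted running mean, so the smaller final batch (drop_last=False) is
# weighted correctly. E.g. batches of 32 and 16 samples with per-batch metric
# values 1.0 and 4.0 aggregate to (32 * 1.0 + 16 * 4.0) / 48 = 2.0, the mean
# over all 48 samples.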
def jsd(config, weights_path, device):
    print('Evaluating Jensen-Shannon divergences on validation set on all saved epochs.')

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'], split='valid')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. Got: `{dataset_name}`')
    classes_selected = ('all' if not config['classes']
                        else ','.join(config['classes']))
    print(f'Valid dataset. Selected {classes_selected} classes. '
          f'Loaded {len(dataset)} samples.')

    distribution = config['metrics']['distribution']
    assert distribution in ['normal', 'beta'], 'Invalid distribution. Choose normal or beta'

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    hyper_network.eval()

    num_samples = len(dataset.point_clouds_names_valid)
    data_loader = DataLoader(dataset, batch_size=num_samples, shuffle=False,
                             num_workers=4, drop_last=False, pin_memory=True)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    # We take 3 times as many samples as there are in test data in order to
    # perform JSD calculation in the same manner as in the reference publication
    noise = torch.zeros(3 * X.shape[0], config['z_size']).to(device)

    results = {}

    n_last_epochs = config['metrics'].get('jsd_how_many_last_epochs', -1)
    epochs = epochs[-n_last_epochs:] if n_last_epochs > 0 else epochs
    print(f'Testing epochs: {epochs}')

    for epoch in reversed(epochs):
        try:
            hyper_network.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_G.pth')))

            start_clock = datetime.now()

            # We average JSD computation from 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(config['metrics']['normal_mu'],
                                  config['metrics']['normal_std'])
                elif distribution == 'beta':
                    noise_np = np.random.beta(config['metrics']['beta_a'],
                                              config['metrics']['beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    target_networks_weights = hyper_network(noise)
                    X_rec = torch.zeros(3 * X.shape[0], X.shape[1], X.shape[2]).to(device)

                    # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                    if X_rec.size(-1) == 3:
                        X_rec.transpose_(X_rec.dim() - 2, X_rec.dim() - 1)

                    for j, target_network_weights in enumerate(target_networks_weights):
                        target_network = aae.TargetNetwork(config, target_network_weights).to(device)
                        target_network_input = generate_points(config=config, epoch=epoch,
                                                               size=(X_rec.shape[2], X_rec.shape[1]))
                        X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

                jsd = jsd_between_point_cloud_sets(X.cpu().numpy(),
                                                   torch.transpose(X_rec, 1, 2).cpu().numpy())
                js_results.append(jsd)

            js_result = np.mean(js_results)
            print(f'Epoch: {epoch} JSD: {js_result: .6f} '
                  f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            print(f'Interrupted during epoch: {epoch}')
            break

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    print(f"Minimum JSD at epoch {results.idxmin()['jsd']}: "
          f"{results.min()['jsd']: .6f}")
    return results.idxmin()['jsd'], results.min()['jsd']
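# A plausible way to chain the two evaluations above (illustrative; the
# original driver code is not shown in this listing): pick the epoch with the
# minimal validation JSD, then compute the remaining metrics at that epoch.
#
#     best_epoch, jsd_value = jsd(config, weights_path, device)
#     all_metrics(config, weights_path, device, int(best_epoch), jsd_value)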
def minimum_matching_distance(config, weights_path, device):
    from utils.metrics import EMD_CD
    print('Minimum Matching Distance (MMD) Test split')
    epoch = find_latest_epoch(weights_path)
    print(f'Last Epoch: {epoch}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'], split='test')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')
    classes_selected = ('all' if not config['classes']
                        else ','.join(config['classes']))
    print(f'Test dataset. Selected {classes_selected} classes. '
          f'Loaded {len(dataset)} samples.')

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder = aae.Encoder(config).to(device)
    hyper_network.eval()
    encoder.eval()

    num_samples = len(dataset.point_clouds_names_test)
    data_loader = DataLoader(dataset, batch_size=num_samples, shuffle=False,
                             num_workers=4, drop_last=False, pin_memory=True)

    encoder.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_E.pth')))
    hyper_network.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_G.pth')))

    result = {}

    for point_data in data_loader:
        X, _ = point_data
        X = X.to(device)

        # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
        if X.size(-1) == 3:
            X.transpose_(X.dim() - 2, X.dim() - 1)

        with torch.no_grad():
            z_a, _, _ = encoder(X)
            target_networks_weights = hyper_network(z_a)
            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)
                target_network_input = generate_points(config=config, epoch=epoch,
                                                       size=(X.shape[2], X.shape[1]))
                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            for k, v in EMD_CD(torch.transpose(X, 1, 2).contiguous(),
                               torch.transpose(X_rec, 1, 2).contiguous(),
                               X.shape[0]).items():
                result[k] = result.get(k, 0.0) + v.item()

    print(result)
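# Note: `result` in `minimum_matching_distance` accumulates the EMD_CD values
# over batches without dividing by the number of batches. Since batch_size is
# set to the size of the whole test split, the loader yields a single batch in
# practice and the printed values are the metrics themselves; with smaller
# batches they would need to be averaged.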
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes its own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(),
                                          hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(discriminator.parameters(),
                                                      **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_E.pth')))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_D.pth')))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))
        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_Do.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{starting_epoch - 1:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{starting_epoch - 1:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)

        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(
                mean=config['normal_mu'], std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                # gradient penalty (on interpolates between prior samples and codes)
                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                loss_discriminator = 0.5 * ((real_logit - b)**2 + (synth_logit - a)**2)

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch,
                                                           size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                    reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                        torch.transpose(X_rec, 1, 2).contiguous(),
                                        batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                loss_encoder = 0.5 * (synth_logit - c)**2

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                    feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_weights_frequency'] == 0:
            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
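# Background for the two discriminator objectives used in the adversarial
# training loop above: with `wasserstein` enabled, the critic loss follows
# WGAN-GP (Gulrajani et al., 2017), penalizing deviations of the critic's
# gradient norm from 1 on points interpolated between prior samples and
# encoded codes:
#
#     L_D = E[D(codes)] - E[D(noise)] + lambda * E[(||grad D(x_hat)|| - 1)^2]
#
# Otherwise the least-squares GAN formulation (Mao et al., 2017) is used, with
# real/fake targets b = 1, a = 0 for the discriminator and target c = 1 for
# the encoder, matching the in-line comments about alternative (a, b, c)
# choices.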