def main(config):
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Noise distribution
    #
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(torch.Size([config['batch_size'],
                                                 config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(config['z_beta_a'], config['z_beta_b'],
                                        size=(config['batch_size'],
                                              config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        raise ValueError('Invalid distribution for binary model.')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])
    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        D.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))
        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)
            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(config['z_beta_a'],
                                          config['z_beta_b'],
                                          size=(config['batch_size'],
                                                config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)

            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)
            loss_d_tot.append(loss_d.item())

            # Gradient Penalty
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())

            ###
            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()

            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            D_optim.step()

            # EG part of training
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())

            synth_logit = D(codes)
            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())

            loss_eg = loss_e + loss_g

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            fake = G(fixed_noise).data.cpu().numpy()
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

        plt.figure(figsize=(16, 9))
        plt.plot(loss_d_tot, 'r-', label="loss_d")
        plt.plot(loss_gp_tot, 'g-', label="loss_gp")
        plt.plot(loss_e_tot, 'b-', label="loss_e")
        plt.plot(loss_g_tot, 'k-', label="loss_g")
        plt.legend()
        plt.xlabel("Batch number")
        plt.ylabel("Loss value")
        plt.savefig(join(results_dir, 'samples', 'loss_plot.png'))
        plt.close()

        for k in range(5):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_fixed.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
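
# A minimal, self-contained sketch of the WGAN-GP penalty term used in the
# discriminator update above, factored out for reference. `critic` is assumed
# to map latent codes of shape [B, z_size] to logits of shape [B, 1]; the
# helper name and its arguments are illustrative, not part of the original
# codebase.
import torch
from torch.autograd import grad


def gradient_penalty(critic, real_z, fake_z, gp_lambda=10.0):
    # Interpolate between prior samples and encoded codes.
    alpha = torch.rand(real_z.size(0), 1, device=real_z.device)
    interpolates = real_z + alpha * (fake_z - real_z)
    interpolates.requires_grad_(True)
    logits = critic(interpolates)
    gradients = grad(outputs=logits, inputs=interpolates,
                     grad_outputs=torch.ones_like(logits),
                     create_graph=True, retain_graph=True,
                     only_inputs=True)[0]
    # Penalise deviation of the gradient norm from 1.
    slopes = gradients.norm(2, dim=1)
    return gp_lambda * ((slopes - 1) ** 2).mean()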
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, config['arch'], 'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        exit(1)

    setup_logging(results_dir)
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=64, shuffle=True,
                                   num_workers=8, drop_last=True,
                                   pin_memory=True, collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth')))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth')))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth')))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            x.append(X)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)
                target_network_input = generate_points(
                    config=config, epoch=epoch,
                    size=(X_whole.shape[2], X_whole.shape[1]))
                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # take the lowest possible first dim
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        if config['experiments']['interpolation']['execute']:
            interpolation(x, encoder_pocket, hyper_network, device,
                          results_dir, epoch,
                          config['experiments']['interpolation']['amount'],
                          config['experiments']['interpolation']['transitions'])

        if config['experiments']['interpolation_between_two_points']['execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                config['experiments']['interpolation_between_two_points']['amount'],
                config['experiments']['interpolation_between_two_points']['image_points'],
                config['experiments']['interpolation_between_two_points']['transitions'])

        if config['experiments']['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           config['experiments']['reconstruction']['amount'])

        if config['experiments']['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch,
                   config['experiments']['sphere']['amount'],
                   config['experiments']['sphere']['image_points'],
                   config['experiments']['sphere']['start'],
                   config['experiments']['sphere']['end'],
                   config['experiments']['sphere']['transitions'])

        if config['experiments']['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles']['amount'],
                config['experiments']['sphere_triangles']['method'],
                config['experiments']['sphere_triangles']['depth'],
                config['experiments']['sphere_triangles']['start'],
                config['experiments']['sphere_triangles']['end'],
                config['experiments']['sphere_triangles']['transitions'])

        if config['experiments']['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                config['experiments']['sphere_triangles_interpolation']['amount'],
                config['experiments']['sphere_triangles_interpolation']['method'],
                config['experiments']['sphere_triangles_interpolation']['depth'],
                config['experiments']['sphere_triangles_interpolation']['coefficient'],
                config['experiments']['sphere_triangles_interpolation']['transitions'])

        if config['experiments']['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                config['experiments']['different_number_of_points']['amount'],
                config['experiments']['different_number_of_points']['image_points'])

        if config['experiments']['fixed']['execute']:
            # get a visible sample from the loader (ideally this would come
            # from a user-specified object, e.g. passed in via a parser)
            points_dataloader = DataLoader(dataset, batch_size=10,
                                           shuffle=True, num_workers=8,
                                           drop_last=True, pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch,
                  config['experiments']['fixed']['amount'],
                  config['z_size'] // 2,
                  config['experiments']['fixed']['mean'],
                  config['experiments']['fixed']['std'],
                  (3, 2048),
                  config['experiments']['fixed']['triangulation']['execute'],
                  config['experiments']['fixed']['triangulation']['method'],
                  config['experiments']['fixed']['triangulation']['depth'])
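
# Illustrative shape of the `config['experiments']` section that the function
# above reads. The keys mirror the lookups in the code; the values shown here
# are placeholders, not the authors' settings.
EXPERIMENTS_CONFIG_EXAMPLE = {
    'interpolation': {'execute': True, 'amount': 5, 'transitions': 10},
    'interpolation_between_two_points': {'execute': False, 'amount': 5,
                                         'image_points': 1000,
                                         'transitions': 10},
    'reconstruction': {'execute': True, 'amount': 5},
    'sphere': {'execute': False, 'amount': 5, 'image_points': 1000,
               'start': 2.0, 'end': 4.0, 'transitions': 10},
    'sphere_triangles': {'execute': False, 'amount': 5, 'method': 'edge',
                         'depth': 3, 'start': 0.5, 'end': 1.5,
                         'transitions': 10},
    'sphere_triangles_interpolation': {'execute': False, 'amount': 5,
                                       'method': 'edge', 'depth': 3,
                                       'coefficient': 0.5, 'transitions': 10},
    'different_number_of_points': {'execute': False, 'amount': 5,
                                   'image_points': 1000},
    'fixed': {'execute': True, 'amount': 30, 'mean': 0.0, 'std': 0.015,
              'triangulation': {'execute': True, 'method': 'edge',
                                'depth': 2}},
}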
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes its own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(),
                                          hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(
        discriminator.parameters(),
        **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_E.pth')))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_D.pth')))
        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))
        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_Do.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path,
                                 f'{starting_epoch - 1:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)

        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(
                mean=config['normal_mu'], std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # gradient_penalty_function
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                loss_discriminator = 0.5 * ((real_logit - b) ** 2 +
                                            (synth_logit - a) ** 2).mean()

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or \
                        target_network_input is None:
                    target_network_input = generate_points(
                        config=config, epoch=epoch,
                        size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                    reconstruction_loss(
                        torch.transpose(X, 1, 2).contiguous(),
                        torch.transpose(X_rec, 1, 2).contiguous(),
                        batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                loss_encoder = 0.5 * ((synth_logit - c) ** 2).mean()

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                    feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_weights_frequency'] == 0:
            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(),
                       join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(),
                       join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
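
# A compact sketch of the two latent-space adversarial objectives toggled by
# config['wasserstein'] in the loop above: the Wasserstein critic loss (the
# gradient penalty is added separately) and the least-squares variant. The
# function names and the (a, b, c) defaults are illustrative.
import torch


def wgan_d_loss(real_logit, synth_logit):
    # Critic tries to score prior samples higher than encoded samples.
    return torch.mean(synth_logit) - torch.mean(real_logit)


def lsgan_d_loss(real_logit, synth_logit, a=0.0, b=1.0):
    # Least-squares GAN: push real logits towards b, synthetic towards a.
    return 0.5 * ((real_logit - b) ** 2 + (synth_logit - a) ** 2).mean()


def lsgan_e_loss(synth_logit, c=1.0):
    # Encoder (the generator in latent space) targets label c.
    return 0.5 * ((synth_logit - c) ** 2).mean()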
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'], config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'], config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        # from utils.metrics import earth_mover_distance
        # reconstruction_loss = earth_mover_distance
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))
        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(join(metrics_path,
                                  f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(join(metrics_path,
                                 f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)

        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get the whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or \
                        target_network_input is None:
                    target_network_input = generate_points(
                        config=config, epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld

            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        X = X.cpu().numpy()
        X_whole = X_whole.cpu().numpy()
        X_rec = X_rec.detach().cpu().numpy()

        if epoch % config['save_frequency'] == 0:
            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0], X_whole[k][1],
                                          X_whole[k][2], in_u_sphere=True,
                                          show=False, title=str(epoch))
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_real.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_frequency'] == 0:
            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'), np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
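
# The closed-form KL term used as `loss_kld` above, written out as a
# standalone helper for reference: KL(N(mu, exp(logvar)) || N(0, I)) summed
# over the batch and latent dimensions. The helper name is illustrative.
import torch


def gaussian_kl_to_standard_normal(mu, logvar):
    # 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar)
    return 0.5 * (torch.exp(logvar) + mu.pow(2) - 1 - logvar).sum()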
def main(config):
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        EG_optim.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(f'[{epoch}/{config["max_epochs"]}] '
                  f'Loss: {total_loss / i:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()

        for k in range(5):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
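
# Illustrative minimal config for the autoencoder training above. The keys
# mirror the config lookups in the function; the values are placeholders, not
# the authors' published hyperparameters.
AUTOENCODER_CONFIG_EXAMPLE = {
    'seed': 2018,
    'cuda': True,
    'gpu': 0,
    'dataset': 'shapenet',
    'data_dir': '/path/to/shapenet',
    'classes': ['chair'],
    'batch_size': 64,
    'shuffle': True,
    'num_workers': 4,
    'arch': 'autoencoder',
    'reconstruction_loss': 'chamfer',
    'reconstruction_coef': 0.05,
    'optimizer': {'EG': {'type': 'Adam',
                         'hyperparams': {'lr': 5e-4, 'betas': [0.9, 0.999]}}},
    'max_epochs': 500,
    'save_frequency': 50,
}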
def main(eval_config):
    # Load hyperparameters as they were during training
    train_results_path = join(eval_config['results_root'], eval_config['arch'],
                              eval_config['experiment_name'])
    with open(join(train_results_path, 'config.json')) as f:
        train_config = json.load(f)

    random.seed(train_config['seed'])
    torch.manual_seed(train_config['seed'])
    torch.cuda.manual_seed_all(train_config['seed'])

    setup_logging(join(train_results_path, 'results'))
    log = logging.getLogger(__name__)

    weights_path = join(train_results_path, 'weights')
    if eval_config['epoch'] == 0:
        epoch = find_latest_epoch(weights_path)
    else:
        epoch = eval_config['epoch']
    log.debug(f'Starting from epoch: {epoch}')

    device = cuda_setup(eval_config['cuda'], eval_config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = train_config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=train_config['data_dir'],
                                  classes=train_config['classes'],
                                  split='test')
    elif dataset_name == 'faust':
        from datasets.dfaust import DFaustDataset
        dataset = DFaustDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='test')
    elif dataset_name == 'mcgill':
        from datasets.mcgill import McGillDataset
        dataset = McGillDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='test')
    elif dataset_name == 'custom':
        from datasets.customdataset import CustomDataset
        dataset = CustomDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='test')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `faust`, '
                         f'`mcgill` or `custom`. Got: `{dataset_name}`')

    classes_selected = ('all' if not train_config['classes']
                        else ','.join(train_config['classes']))
    log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} '
              f'samples.')

    if 'distribution' in train_config:
        distribution = train_config['distribution']
    elif 'distribution' in eval_config:
        distribution = eval_config['distribution']
    else:
        log.warning('No distribution type specified. Assumed normal = N(0, 0.2)')
        distribution = 'normal'

    #
    # Models
    #
    arch = import_module(f"models.{eval_config['arch']}")
    E = arch.Encoder(train_config).to(device)
    G = arch.Generator(train_config).to(device)

    #
    # Load saved state
    #
    E.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_E.pth')))
    G.load_state_dict(torch.load(join(weights_path, f'{epoch:05}_G.pth')))

    E.eval()
    G.eval()

    num_samples = len(dataset.point_clouds_names_test)
    data_loader = DataLoader(dataset, batch_size=num_samples, shuffle=False,
                             num_workers=4, drop_last=False, pin_memory=True)

    # We take 3 times as many samples as there are in the test data in order
    # to perform the JSD calculation in the same manner as in the reference
    # publication.
    noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1)
    noise = noise.to(device)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    X_ = X.data.cpu().numpy()
    np.save(join(train_results_path, 'results', f'{epoch:05}_X'), X_)

    for i in range(3):
        if distribution == 'normal':
            noise.normal_(0, 0.2)
        else:
            noise_np = np.random.beta(train_config['z_beta_a'],
                                      train_config['z_beta_b'], noise.shape)
            noise = torch.tensor(noise_np).float().round().to(device)

        with torch.no_grad():
            X_g = G(noise)
        if X_g.shape[-2:] == (3, 2048):
            X_g.transpose_(1, 2)

        X_g_ = X_g.data.cpu().numpy()
        np.save(join(train_results_path, 'results', f'{epoch:05}_Xg_{i}'), X_g_)

    with torch.no_grad():
        z_e = E(X.transpose(1, 2))
        if isinstance(z_e, tuple):
            z_e = z_e[0]
        X_rec = G(z_e)
    if X_rec.shape[-2:] == (3, 2048):
        X_rec.transpose_(1, 2)

    X_rec_ = X_rec.data.cpu().numpy()
    np.save(join(train_results_path, 'results', f'{epoch:05}_Xrec'), X_rec_)
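
# A short sketch of how the arrays saved above could be loaded back for
# offline metric computation. The results path and epoch number are
# illustrative placeholders.
import numpy as np
from os.path import join

results_path = join('results', 'aae', 'experiment', 'results')  # placeholder
epoch = 400                                                     # placeholder

X_real = np.load(join(results_path, f'{epoch:05}_X.npy'))
X_rec = np.load(join(results_path, f'{epoch:05}_Xrec.npy'))
X_gen = [np.load(join(results_path, f'{epoch:05}_Xg_{i}.npy'))
         for i in range(3)]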
def main(eval_config):
    # Load hyperparameters as they were during training
    train_results_path = join(eval_config['results_root'], eval_config['arch'],
                              eval_config['experiment_name'])
    with open(join(train_results_path, 'config.json')) as f:
        train_config = json.load(f)

    random.seed(train_config['seed'])
    torch.manual_seed(train_config['seed'])
    torch.cuda.manual_seed_all(train_config['seed'])

    setup_logging(join(train_results_path, 'results'))
    log = logging.getLogger(__name__)

    log.debug('Evaluating Jensen-Shannon divergences on the validation set '
              'for all saved epochs.')

    weights_path = join(train_results_path, 'weights')

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))
    log.debug(f'Testing epochs: {epochs}')

    device = cuda_setup(eval_config['cuda'], eval_config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = train_config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=train_config['data_dir'],
                                  classes=train_config['classes'],
                                  split='valid')
    elif dataset_name == 'faust':
        from datasets.dfaust import DFaustDataset
        dataset = DFaustDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    elif dataset_name == 'mcgill':
        from datasets.mcgill import McGillDataset
        dataset = McGillDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `faust` '
                         f'or `mcgill`. Got: `{dataset_name}`')

    classes_selected = ('all' if not train_config['classes']
                        else ','.join(train_config['classes']))
    log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} '
              f'samples.')

    if 'distribution' in train_config:
        distribution = train_config['distribution']
    elif 'distribution' in eval_config:
        distribution = eval_config['distribution']
    else:
        log.warning('No distribution type specified. Assumed normal = N(0, 0.2)')
        distribution = 'normal'

    #
    # Models
    #
    arch = import_module(f"model.architectures.{train_config['arch']}")
    E = arch.Encoder(train_config).to(device)
    G = arch.Generator(train_config).to(device)

    E.eval()
    G.eval()

    num_samples = len(dataset.point_clouds_names_valid)
    data_loader = DataLoader(dataset, batch_size=num_samples, shuffle=False,
                             num_workers=4, drop_last=False, pin_memory=True)

    # We take 3 times as many samples as there are in the validation data in
    # order to perform the JSD calculation in the same manner as in the
    # reference publication.
    noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1)
    noise = noise.to(device)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    results = {}

    for epoch in reversed(epochs):
        try:
            E.load_state_dict(
                torch.load(join(weights_path, f'{epoch:05}_E.pth')))
            G.load_state_dict(
                torch.load(join(weights_path, f'{epoch:05}_G.pth')))

            start_clock = datetime.now()

            # We average the JSD computation over 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(0, 0.2)
                elif distribution == 'beta':
                    noise_np = np.random.beta(train_config['z_beta_a'],
                                              train_config['z_beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    X_g = G(noise)
                if X_g.shape[-2:] == (3, 2048):
                    X_g.transpose_(1, 2)

                jsd = jsd_between_point_cloud_sets(X, X_g, voxels=28)
                js_results.append(jsd)

            js_result = np.mean(js_results)
            log.debug(f'Epoch: {epoch} JSD: {js_result: .6f} '
                      f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            log.debug(f'Interrupted during epoch: {epoch}')
            break

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    log.debug(f"Minimum JSD at epoch {results.idxmin()['jsd']}: "
              f"{results.min()['jsd']: .6f}")
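
# A simplified, self-contained sketch of what jsd_between_point_cloud_sets
# computes: both sets of clouds are voxelised into occupancy histograms over a
# fixed grid and the Jensen-Shannon divergence between the two distributions
# is returned. This is an approximation for illustration only; the
# repository's implementation (used above with voxels=28) may differ in
# detail, e.g. in how clouds are normalised before binning.
import numpy as np
from scipy.spatial.distance import jensenshannon


def occupancy_histogram(clouds, voxels=28, bound=0.5):
    # clouds: array of shape [num_clouds, num_points, 3] in [-bound, bound].
    points = np.asarray(clouds).reshape(-1, 3)
    edges = np.linspace(-bound, bound, voxels + 1)
    hist, _ = np.histogramdd(points, bins=(edges, edges, edges))
    hist = hist.ravel()
    return hist / hist.sum()


def jsd_point_cloud_sets(sample_clouds, reference_clouds, voxels=28):
    p = occupancy_histogram(sample_clouds, voxels)
    q = occupancy_histogram(reference_clouds, voxels)
    # scipy returns the JS *distance*; square it to obtain the divergence.
    return jensenshannon(p, q, base=2) ** 2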
def main(config):
    if config['seed'] >= 0:
        random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        torch.cuda.manual_seed(config['seed'])
        np.random.seed(config['seed'])
        torch.backends.cudnn.deterministic = True
        print("random seed: ", config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)
    logging.getLogger('matplotlib.font_manager').disabled = True

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    # load dataset
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    scale = 1 / (3 * config['n_points'])

    # hyper-parameters
    valid_frequency = config["valid_frequency"]
    num_vae = config["num_vae"]
    beta_rec = config["beta_rec"]
    beta_kl = config["beta_kl"]
    beta_neg = config["beta_neg"]
    gamma_r = config["gamma_r"]

    apply_random_rotation = "rotate" in config["transforms"]
    if apply_random_rotation:
        print("applying random rotation to input shapes")

    # model
    model = SoftIntroVAE(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.chamfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'got: {config["reconstruction_loss"]}')

    prior_std = config["prior_std"]
    prior_logvar = np.log(prior_std ** 2)
    print(f'prior: N(0, {prior_std ** 2:.3f})')

    # optimizers
    optimizer_e = getattr(optim, config['optimizer']['E']['type'])
    optimizer_e = optimizer_e(model.encoder.parameters(),
                              **config['optimizer']['E']['hyperparams'])
    optimizer_d = getattr(optim, config['optimizer']['D']['type'])
    optimizer_d = optimizer_d(model.decoder.parameters(),
                              **config['optimizer']['D']['hyperparams'])

    scheduler_e = optim.lr_scheduler.MultiStepLR(optimizer_e,
                                                 milestones=[350, 450, 550],
                                                 gamma=0.5)
    scheduler_d = optim.lr_scheduler.MultiStepLR(optimizer_d,
                                                 milestones=[350, 450, 550],
                                                 gamma=0.5)

    if starting_epoch > 1:
        model.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}.pth')))
        optimizer_e.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_optim_e.pth')))
        optimizer_d.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_optim_d.pth')))

    kls_real = []
    kls_fake = []
    kls_rec = []
    rec_errs = []
    exp_elbos_f = []
    exp_elbos_r = []
    diff_kls = []
    best_res = {"epoch": 0, "jsd": None}

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        model.train()

        if epoch < num_vae:
            # plain VAE warm-up phase
            total_loss = 0.0
            pbar = tqdm(iterable=points_dataloader)
            for i, point_data in enumerate(pbar, 1):
                x, _ = point_data
                x = x.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if x.size(-1) == 3:
                    x.transpose_(x.dim() - 2, x.dim() - 1)

                x_rec, mu, logvar = model(x)

                loss_rec = reconstruction_loss(x.permute(0, 2, 1) + 0.5,
                                               x_rec.permute(0, 2, 1) + 0.5)
                while len(loss_rec.shape) > 1:
                    loss_rec = loss_rec.sum(-1)
                loss_rec = loss_rec.mean()

                loss_kl = calc_kl(logvar, mu, logvar_o=prior_logvar,
                                  reduce="mean")
                loss = beta_rec * loss_rec + beta_kl * loss_kl

                optimizer_e.zero_grad()
                optimizer_d.zero_grad()
                loss.backward()
                total_loss += loss.item()
                optimizer_e.step()
                optimizer_d.step()

                pbar.set_description_str('epoch #{}'.format(epoch))
                pbar.set_postfix(r_loss=loss_rec.data.cpu().item(),
                                 kl=loss_kl.data.cpu().item())
        else:
            batch_kls_real = []
            batch_kls_fake = []
            batch_kls_rec = []
            batch_rec_errs = []
            batch_exp_elbo_f = []
            batch_exp_elbo_r = []
            batch_diff_kls = []

            pbar = tqdm(iterable=points_dataloader)
            for i, point_data in enumerate(pbar, 1):
                x, _ = point_data
                x = x.to(device)

                # random rotation
                if apply_random_rotation:
                    angle = torch.rand(size=(x.shape[0],)) * 180
                    rotate_transform = RotateAxisAngle(angle, axis="Z",
                                                       device=device)
                    x = rotate_transform.transform_points(x)

                # change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if x.size(-1) == 3:
                    x.transpose_(x.dim() - 2, x.dim() - 1)

                noise_batch = prior_std * torch.randn(
                    size=(config['batch_size'], model.zdim)).to(device)

                # ----- update E ----- #
                for param in model.encoder.parameters():
                    param.requires_grad = True
                for param in model.decoder.parameters():
                    param.requires_grad = False

                fake = model.sample(noise_batch)

                real_mu, real_logvar = model.encode(x)
                z = reparameterize(real_mu, real_logvar)
                x_rec = model.decoder(z)

                loss_rec = reconstruction_loss(x.permute(0, 2, 1) + 0.5,
                                               x_rec.permute(0, 2, 1) + 0.5)
                while len(loss_rec.shape) > 1:
                    loss_rec = loss_rec.sum(-1)
                loss_rec = loss_rec.mean()

                loss_real_kl = calc_kl(real_logvar, real_mu,
                                       logvar_o=prior_logvar, reduce="mean")

                rec_rec, rec_mu, rec_logvar = model(x_rec.detach())
                rec_fake, fake_mu, fake_logvar = model(fake.detach())

                kl_rec = calc_kl(rec_logvar, rec_mu, logvar_o=prior_logvar,
                                 reduce="none")
                kl_fake = calc_kl(fake_logvar, fake_mu, logvar_o=prior_logvar,
                                  reduce="none")

                loss_rec_rec_e = reconstruction_loss(
                    x_rec.detach().permute(0, 2, 1) + 0.5,
                    rec_rec.permute(0, 2, 1) + 0.5)
                while len(loss_rec_rec_e.shape) > 1:
                    loss_rec_rec_e = loss_rec_rec_e.sum(-1)

                loss_rec_fake_e = reconstruction_loss(
                    fake.permute(0, 2, 1) + 0.5,
                    rec_fake.permute(0, 2, 1) + 0.5)
                while len(loss_rec_fake_e.shape) > 1:
                    loss_rec_fake_e = loss_rec_fake_e.sum(-1)

                expelbo_rec = (-2 * scale * (beta_rec * loss_rec_rec_e +
                                             beta_neg * kl_rec)).exp().mean()
                expelbo_fake = (-2 * scale * (beta_rec * loss_rec_fake_e +
                                              beta_neg * kl_fake)).exp().mean()

                loss_margin = scale * beta_kl * loss_real_kl + 0.25 * (
                    expelbo_rec + expelbo_fake)
                lossE = scale * beta_rec * loss_rec + loss_margin

                optimizer_e.zero_grad()
                lossE.backward()
                optimizer_e.step()

                # ----- update D ----- #
                for param in model.encoder.parameters():
                    param.requires_grad = False
                for param in model.decoder.parameters():
                    param.requires_grad = True

                fake = model.sample(noise_batch)

                with torch.no_grad():
                    z = reparameterize(real_mu.detach(), real_logvar.detach())
                rec = model.decoder(z.detach())

                loss_rec = reconstruction_loss(x.permute(0, 2, 1) + 0.5,
                                               rec.permute(0, 2, 1) + 0.5)
                while len(loss_rec.shape) > 1:
                    loss_rec = loss_rec.sum(-1)
                loss_rec = loss_rec.mean()

                rec_mu, rec_logvar = model.encode(rec)
                z_rec = reparameterize(rec_mu, rec_logvar)

                fake_mu, fake_logvar = model.encode(fake)
                z_fake = reparameterize(fake_mu, fake_logvar)

                rec_rec = model.decode(z_rec.detach())
                rec_fake = model.decode(z_fake.detach())

                loss_rec_rec = reconstruction_loss(
                    rec.detach().permute(0, 2, 1) + 0.5,
                    rec_rec.permute(0, 2, 1) + 0.5)
                while len(loss_rec_rec.shape) > 1:
                    loss_rec_rec = loss_rec_rec.sum(-1)
                loss_rec_rec = loss_rec_rec.mean()

                loss_fake_rec = reconstruction_loss(
                    fake.detach().permute(0, 2, 1) + 0.5,
                    rec_fake.permute(0, 2, 1) + 0.5)
                while len(loss_fake_rec.shape) > 1:
                    loss_fake_rec = loss_fake_rec.sum(-1)
                loss_fake_rec = loss_fake_rec.mean()

                lossD_rec_kl = calc_kl(rec_logvar, rec_mu,
                                       logvar_o=prior_logvar, reduce="mean")
                lossD_fake_kl = calc_kl(fake_logvar, fake_mu,
                                        logvar_o=prior_logvar, reduce="mean")

                lossD = scale * (loss_rec * beta_rec +
                                 (lossD_rec_kl + lossD_fake_kl) * 0.5 * beta_kl +
                                 gamma_r * 0.5 * beta_rec * (loss_rec_rec +
                                                             loss_fake_rec))

                optimizer_d.zero_grad()
                lossD.backward()
                optimizer_d.step()

                if torch.isnan(lossD):
                    raise SystemError("loss is NaN")

                diff_kl = -loss_real_kl.data.cpu() + lossD_fake_kl.data.cpu()
                batch_diff_kls.append(diff_kl)
                batch_kls_real.append(loss_real_kl.data.cpu().item())
                batch_kls_fake.append(lossD_fake_kl.cpu().item())
                batch_kls_rec.append(lossD_rec_kl.data.cpu().item())
                batch_rec_errs.append(loss_rec.data.cpu().item())
                batch_exp_elbo_f.append(expelbo_fake.data.cpu())
                batch_exp_elbo_r.append(expelbo_rec.data.cpu())

                pbar.set_description_str('epoch #{}'.format(epoch))
                pbar.set_postfix(r_loss=loss_rec.data.cpu().item(),
                                 kl=loss_real_kl.data.cpu().item(),
                                 diff_kl=diff_kl.item(),
                                 expelbo_f=expelbo_fake.cpu().item())

        pbar.close()
        scheduler_e.step()
        scheduler_d.step()

        if epoch > num_vae - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            kls_rec.append(np.mean(batch_kls_rec))
            rec_errs.append(np.mean(batch_rec_errs))
            exp_elbos_f.append(np.mean(batch_exp_elbo_f))
            exp_elbos_r.append(np.mean(batch_exp_elbo_r))
            diff_kls.append(np.mean(batch_diff_kls))

            # epoch summary
            print('#' * 50)
            print(f'Epoch {epoch} Summary:')
            print(f'beta_rec: {beta_rec}, beta_kl: {beta_kl}, '
                  f'beta_neg: {beta_neg}')
            print(f'rec: {rec_errs[-1]:.3f}, kl: {kls_real[-1]:.3f}, '
                  f'kl_fake: {kls_fake[-1]:.3f}, kl_rec: {kls_rec[-1]:.3f}')
            print(f'diff_kl: {diff_kls[-1]:.3f}, '
                  f'exp_elbo_f: {exp_elbos_f[-1]:.4e}, '
                  f'exp_elbo_r: {exp_elbos_r[-1]:.4e}')
            if best_res['jsd'] is not None:
                print(f'best jsd: {best_res["jsd"]}, '
                      f'epoch: {best_res["epoch"]}')
            print(f'time: {datetime.now() - start_epoch_time}')
            print('#' * 50)

        # save intermediate results
        model.eval()
        with torch.no_grad():
            noise_batch = prior_std * torch.randn(size=(5, model.zdim)).to(device)
            fake = model.sample(noise_batch).data.cpu().numpy()
            x_rec, _, _ = model(x, deterministic=True)
            x_rec = x_rec.data.cpu().numpy()

        fig = plt.figure(dpi=350)
        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 1, projection='3d')
            ax = plot_3d_point_cloud(x[k][0].data.cpu().numpy(),
                                     x[k][1].data.cpu().numpy(),
                                     x[k][2].data.cpu().numpy(),
                                     in_u_sphere=True, show=False, axis=ax,
                                     show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 6, projection='3d')
            ax = plot_3d_point_cloud(x_rec[k][0], x_rec[k][1], x_rec[k][2],
                                     in_u_sphere=True, show=False, axis=ax,
                                     show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 11, projection='3d')
            ax = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                     in_u_sphere=True, show=False, axis=ax,
                                     show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        fig.savefig(join(results_dir, 'samples', f'figure_{epoch}'))
        plt.close(fig)

        if epoch % valid_frequency == 0:
            print("calculating valid jsd...")
            model.eval()
            with torch.no_grad():
                jsd = calc_jsd_valid(model, config, prior_std=prior_std)
            print(f'epoch: {epoch}, jsd: {jsd:.4f}')

            if best_res['jsd'] is None:
                best_res['jsd'] = jsd
                best_res['epoch'] = epoch
            elif best_res['jsd'] > jsd:
                print(f'epoch: {epoch}: best jsd updated: '
                      f'{best_res["jsd"]} -> {jsd}')
                best_res['jsd'] = jsd
                best_res['epoch'] = epoch

            # save
            torch.save(model.state_dict(),
                       join(weights_path, f'{epoch:05}_jsd_{jsd:.4f}.pth'))

        if epoch % config['save_frequency'] == 0:
            torch.save(model.state_dict(),
                       join(weights_path, f'{epoch:05}.pth'))
            torch.save(optimizer_e.state_dict(),
                       join(weights_path, f'{epoch:05}_optim_e.pth'))
            torch.save(optimizer_d.state_dict(),
                       join(weights_path, f'{epoch:05}_optim_d.pth'))
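
# Sketches of the two helpers the Soft-IntroVAE loop above relies on, under
# the assumed semantics: calc_kl returns KL(N(mu, exp(logvar)) || N(0,
# exp(logvar_o))) summed over latent dimensions, averaged over the batch when
# reduce="mean", and reparameterize draws z = mu + eps * std. The repository's
# own versions may differ in signature or defaults.
import math
import torch


def calc_kl(logvar, mu, logvar_o=0.0, reduce="mean"):
    # KL between N(mu, exp(logvar)) and the prior N(0, exp(logvar_o)),
    # summed over latent dimensions.
    prior_var = math.exp(logvar_o)
    per_sample = 0.5 * (logvar_o - logvar +
                        (logvar.exp() + mu.pow(2)) / prior_var - 1).sum(-1)
    return per_sample.mean() if reduce == "mean" else per_sample


def reparameterize(mu, logvar):
    # z = mu + eps * std with eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)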
def main(config):
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif config['reconstruction_loss'].lower() == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'`earth_mover` or `cramer_wold`, '
                         f'got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    std_assumed = torch.tensor(0.2)

    fixed_noise = fixed_noise.to(device)
    std_assumed = std_assumed.to(device)

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []
            for i, point_data in enumerate(points_dataloader, 1):
                # log.debug('-' * 20)

                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                codes, mu, logvar = E(X)
                X_rec = G(codes)

                loss_e = torch.mean(
                    # config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                loss_kld = config['reconstruction_coef'] * cw_distance(mu)
                # loss_kld = -0.5 * torch.mean(
                #     1 - 2.0 * torch.log(std_assumed) + logvar -
                #     (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2))

                loss_eg = loss_e + loss_kld

                EG_optim.zero_grad()
                E.zero_grad()
                G.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                # log.debug
                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results
            #
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).data.cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).data.cpu().numpy()
                X_numpy = X.cpu().numpy()

            for k in range(5):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1],
                                          X_numpy[k][2], in_u_sphere=True,
                                          show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', 'fixed',
                                 f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(5):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses,
                                  columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))
                fig = df.plot.line().get_figure()
                fig.savefig(join(results_dir, 'losses', f'{epoch:05}.png'))

                torch.save(G.state_dict(),
                           join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(),
                           join(weights_path, f'{epoch:05}_E.pth'))
                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
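
# The commented-out alternative regulariser from the loop above, written out
# as a helper: the analytic KL between N(mu, exp(logvar)) and an assumed prior
# N(0, std_assumed^2), averaged over all elements exactly as in the commented
# code. The helper name is illustrative.
import torch


def gaussian_kld(mu, logvar, std_assumed=0.2):
    std_assumed = torch.as_tensor(std_assumed, dtype=mu.dtype,
                                  device=mu.device)
    return -0.5 * torch.mean(
        1 - 2.0 * torch.log(std_assumed) + logvar -
        (mu.pow(2) + logvar.exp()) / std_assumed.pow(2))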