def main(config):
    """Train encoder ``E`` / generator ``G`` against a latent critic ``D``.

    ``D`` is trained WGAN-GP style to separate encoder codes ``E(X)`` from
    prior samples (Bernoulli or Beta, chosen by ``config['distribution']``),
    with a gradient penalty on interpolates between the two.  ``E`` and ``G``
    are then trained on a scaled reconstruction loss plus the negated critic
    score of the codes.  Training resumes after the latest epoch found by
    ``find_latest_epoch(results_dir)``.

    Args:
        config: experiment configuration dict (seed, dataset, architecture,
            reconstruction loss, prior, optimizer and checkpoint settings).

    Raises:
        ValueError: on an unsupported dataset, reconstruction loss or prior
            distribution.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    # drop_last=True: noise and GP alpha below are sized by config['batch_size'].
    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(torch.Size([config['batch_size'],
                                                 config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(config['z_beta_a'],
                                        config['z_beta_b'],
                                        size=(config['batch_size'],
                                              config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        raise ValueError('Invalid distribution for binaray model.')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    if starting_epoch > 1:
        # Resume from the previous epoch's checkpoint.  map_location lets
        # checkpoints saved on another device load onto the selected one.
        prev_epoch = starting_epoch - 1
        G.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_G.pth'), map_location=device))
        E.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_E.pth'), map_location=device))
        D.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_D.pth'), map_location=device))

        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_Do.pth'), map_location=device))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_EGo.pth'), map_location=device))

    # Per-batch loss history as plain floats.  Storing the loss tensors would
    # keep every batch's autograd graph (and its device memory) alive, and
    # matplotlib cannot plot tensors that require grad.
    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)
            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(config['z_beta_a'],
                                          config['z_beta_b'],
                                          size=(config['batch_size'],
                                                config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)
            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)

            # Gradient Penalty
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())

            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()

            # retain_graph: the encoder graph behind `codes` is reused below.
            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            # Recorded after the GP is added, matching the logged Loss_D.
            loss_d_tot.append(loss_d.item())
            D_optim.step()

            # EG part of training
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())

            synth_logit = D(codes)

            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())

            loss_eg = loss_e + loss_g
            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            # Samples from the fixed prior noise and reconstructions of the
            # last batch of the epoch.
            fake = G(fixed_noise).data.cpu().numpy()
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

        plt.figure(figsize=(16, 9))
        plt.plot(loss_d_tot, 'r-', label="loss_d")
        plt.plot(loss_gp_tot, 'g-', label="loss_gp")
        plt.plot(loss_e_tot, 'b-', label="loss_e")
        plt.plot(loss_g_tot, 'k-', label="loss_g")
        plt.legend()
        plt.xlabel("Batch number")
        plt.ylabel("Loss value")
        plt.savefig(join(results_dir, 'samples', 'loss_plot.png'))
        plt.close()

        # Plot at most 5 samples; guards against batch sizes below 5.
        n_samples = min(5, X.shape[0], fake.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
def main(config):
    """Load the latest trained checkpoint and run the configured experiments.

    Restores the hyper network and the pocket/visible encoders from the
    newest epoch in ``get_weights_dir(config)``.  It evaluates the
    reconstruction + KLD loss over the whole dataset without gradients and
    gathers every batch of the ``'non-visible'`` point clouds into one tensor.
    It then runs each experiment whose ``config['experiments'][...]['execute']``
    flag is set.

    Args:
        config: experiment configuration dict.

    Raises:
        ValueError: on an unsupported dataset or reconstruction loss.
        SystemExit: if no checkpoint epoch is found.
    """
    import sys

    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, config['arch'], 'experiments',
                                      dirs_to_create=[
                                          'interpolations', 'sphere',
                                          'points_interpolation',
                                          'different_number_points', 'fixed',
                                          'reconstruction', 'sphere_triangles',
                                          'sphere_triangles_interpolation'
                                      ])
    weights_path = get_weights_dir(config)
    epoch = find_latest_epoch(weights_path)

    if not epoch:
        print("Invalid 'weights_path' in configuration")
        # sys.exit, not the site-injected `exit`, which is not always present.
        sys.exit(1)

    setup_logging(results_dir)
    # Rebinds the module-level `log` (presumably used by the experiment helpers).
    global log
    log = logging.getLogger('aae')

    if not exists(join(results_dir, 'experiment_config.json')):
        with open(join(results_dir, 'experiment_config.json'), mode='w') as f:
            json.dump(config, f)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=64,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True, pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    log.info("Weights for epoch: %s" % epoch)

    log.info("Loading weights...")
    # map_location: allow checkpoints saved on another device (e.g. a GPU)
    # to be loaded onto the selected device.
    hyper_network.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_G.pth'),
                   map_location=device))
    encoder_pocket.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EP.pth'),
                   map_location=device))
    encoder_visible.load_state_dict(
        torch.load(join(weights_path, f'{epoch:05}_EV.pth'),
                   map_location=device))

    hyper_network.eval()
    encoder_visible.eval()
    encoder_pocket.eval()

    total_loss_eg = 0.0
    total_loss_e = 0.0
    total_loss_kld = 0.0
    x = []

    with torch.no_grad():
        for i, point_data in enumerate(points_dataloader, 1):
            # non-visible part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # get visible point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2,
                                     X_visible.dim() - 1)

            x.append(X)
            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network per sample, built from its generated weights.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                target_network_input = generate_points(
                    config=config, epoch=epoch,
                    size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            # +0.5 shift applied to both clouds before the distance.
            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # Closed-form KL term computed from mu and logvar.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_eg = loss_e + loss_kld
            total_loss_e += loss_e.item()
            total_loss_kld += loss_kld.item()
            total_loss_eg += loss_eg.item()

        log.info(f'Loss_ALL: {total_loss_eg / i:.4f} '
                 f'Loss_R: {total_loss_e / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} ')

        # take the lowest possible first dim
        # (batches may differ in point count; truncate all to the smallest
        # so they can be concatenated)
        min_dim = min(x, key=lambda X: X.shape[2]).shape[2]
        x = [X[:, :, :min_dim] for X in x]
        x = torch.cat(x)

        experiments = config['experiments']

        if experiments['interpolation']['execute']:
            interpolation(
                x, encoder_pocket, hyper_network, device, results_dir, epoch,
                experiments['interpolation']['amount'],
                experiments['interpolation']['transitions'])

        if experiments['interpolation_between_two_points']['execute']:
            interpolation_between_two_points(
                encoder_pocket, hyper_network, device, x, results_dir, epoch,
                experiments['interpolation_between_two_points']['amount'],
                experiments['interpolation_between_two_points']['image_points'],
                experiments['interpolation_between_two_points']['transitions'])

        if experiments['reconstruction']['execute']:
            reconstruction(encoder_pocket, hyper_network, device, x,
                           results_dir, epoch,
                           experiments['reconstruction']['amount'])

        if experiments['sphere']['execute']:
            sphere(encoder_pocket, hyper_network, device, x, results_dir,
                   epoch,
                   experiments['sphere']['amount'],
                   experiments['sphere']['image_points'],
                   experiments['sphere']['start'],
                   experiments['sphere']['end'],
                   experiments['sphere']['transitions'])

        if experiments['sphere_triangles']['execute']:
            sphere_triangles(
                encoder_pocket, hyper_network, device, x, results_dir,
                experiments['sphere_triangles']['amount'],
                experiments['sphere_triangles']['method'],
                experiments['sphere_triangles']['depth'],
                experiments['sphere_triangles']['start'],
                experiments['sphere_triangles']['end'],
                experiments['sphere_triangles']['transitions'])

        if experiments['sphere_triangles_interpolation']['execute']:
            sphere_triangles_interpolation(
                encoder_pocket, hyper_network, device, x, results_dir,
                experiments['sphere_triangles_interpolation']['amount'],
                experiments['sphere_triangles_interpolation']['method'],
                experiments['sphere_triangles_interpolation']['depth'],
                experiments['sphere_triangles_interpolation']['coefficient'],
                experiments['sphere_triangles_interpolation']['transitions'])

        if experiments['different_number_of_points']['execute']:
            different_number_of_points(
                encoder_pocket, hyper_network, x, device, results_dir, epoch,
                experiments['different_number_of_points']['amount'],
                experiments['different_number_of_points']['image_points'])

        if experiments['fixed']['execute']:
            # get visible element from loader (probably should be done using
            # given object for example using parser
            points_dataloader = DataLoader(dataset, batch_size=10,
                                           shuffle=True,
                                           num_workers=8,
                                           drop_last=True, pin_memory=True,
                                           collate_fn=collate_fn)
            X_visible = next(iter(points_dataloader))['visible'].to(
                device, dtype=torch.float)
            X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)

            fixed(hyper_network, encoder_visible, X_visible, device,
                  results_dir, epoch,
                  experiments['fixed']['amount'],
                  config['z_size'] // 2,
                  experiments['fixed']['mean'],
                  experiments['fixed']['std'],
                  (3, 2048),
                  experiments['fixed']['triangulation']['execute'],
                  experiments['fixed']['triangulation']['method'],
                  experiments['fixed']['triangulation']['depth'])
def main(config):
    """Train an encoder + hyper network adversarially against a latent critic.

    The encoder (PointNet when ``config['pointnet']``, otherwise
    ``aae.Encoder``) maps point clouds to codes.  The discriminator learns to
    separate those codes from Gaussian noise, using WGAN-GP when
    ``config.get('wasserstein', True)`` and a least-squares objective
    otherwise.  The hyper network turns each code into the weights of a
    per-sample ``aae.TargetNetwork`` that reconstructs the cloud from
    generated input points.  Training resumes after the latest epoch found
    in the results directory.

    Args:
        config: experiment configuration dict.

    Raises:
        ValueError: on an unsupported dataset or reconstruction loss.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes it's own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(),
                                          hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(
        discriminator.parameters(),
        **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        prev_epoch = starting_epoch - 1

        log.info("Loading weights...")
        # map_location: allow checkpoints saved on another device to load on
        # the selected one.
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_G.pth'), map_location=device))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_E.pth'), map_location=device))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_D.pth'), map_location=device))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_EGo.pth'), map_location=device))

        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_Do.pth'), map_location=device))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{prev_epoch:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{prev_epoch:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{prev_epoch:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{prev_epoch:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Generated once and reused when config['target_network_input']['constant'].
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(
                mean=config['normal_mu'], std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # gradient_penalty_function
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                # Reduced to a scalar: backward()/item() below require one.
                loss_discriminator = torch.mean(
                    0.5 * ((real_logit - b) ** 2 + (synth_logit - a) ** 2))

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            # retain_graph: the encoder graph behind `codes` is reused below.
            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch,
                                                           size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                    reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                        torch.transpose(X_rec, 1, 2).contiguous(),
                                        batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                # Reduced to a scalar: backward()/item() below require one.
                loss_encoder = torch.mean(0.5 * (synth_logit - c) ** 2)

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                    feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        # Per-epoch sums (not means) of the batch losses.
        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if epoch % config['save_weights_frequency'] == 0:
            # Clean only when new weights are about to be written.  Cleaning
            # on every epoch would delete all checkpoints on non-save epochs.
            if config['clean_weights_dir']:
                log.debug('Cleaning weights path: %s' % weights_path)
                shutil.rmtree(weights_path, ignore_errors=True)
                os.makedirs(weights_path, exist_ok=True)

            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
def main(config):
    """Train pocket/visible encoders and a hyper network to reconstruct clouds.

    For each batch, ``encoder_pocket`` encodes the ``'non-visible'`` part into
    ``(codes, mu, logvar)`` and ``encoder_visible`` encodes the ``'visible'``
    part.  Their concatenation is fed to the hyper network, which produces
    per-sample ``aae.TargetNetwork`` weights.  The target networks
    reconstruct the whole ``'cloud'``.  The loss is a scaled reconstruction
    term plus a KL term on ``(mu, logvar)``.  Training resumes after the
    latest epoch found in the results directory.

    Args:
        config: experiment configuration dict.

    Raises:
        ValueError: on an unsupported dataset or reconstruction loss.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True,
                                   collate_fn=collate_fn)

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        prev_epoch = starting_epoch - 1

        log.info("Loading weights...")
        # map_location: allow checkpoints saved on another device to load on
        # the selected one.
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{prev_epoch:05}_G.pth'),
                       map_location=device))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{prev_epoch:05}_EP.pth'),
                       map_location=device))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{prev_epoch:05}_EV.pth'),
                       map_location=device))

        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{prev_epoch:05}_EGo.pth'),
                       map_location=device))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{prev_epoch:05}_E.npy')).tolist()
        losses_kld = np.load(
            join(metrics_path, f'{prev_epoch:05}_KLD.npy')).tolist()
        losses_eg = np.load(
            join(metrics_path, f'{prev_epoch:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Generated once and reused when config['target_network_input']['constant'].
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get not visible part of point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get only visible part of point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2,
                                     X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            # One target network per sample, built from its generated weights.
            X_rec = torch.zeros(X_whole.shape).to(device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # Closed-form KL term computed from mu and logvar.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld

            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        # Per-epoch sums (not means) of the batch losses.
        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        #
        # Save intermediate results
        #
        X = X.cpu().numpy()
        X_whole = X_whole.cpu().numpy()
        X_rec = X_rec.detach().cpu().numpy()

        if epoch % config['save_frequency'] == 0:
            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples',
                         f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0],
                                          X_whole[k][1],
                                          X_whole[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

                # NOTE(review): X is the 'non-visible' part, although the
                # file is named '_visible.png' — confirm intended naming.
                fig = plot_3d_point_cloud(X[k][0],
                                          X[k][1],
                                          X[k][2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
                plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            # Clean only when new weights are about to be written.  Cleaning
            # on every epoch would delete all checkpoints on non-save epochs.
            if config['clean_weights_dir']:
                log.debug('Cleaning weights path: %s' % weights_path)
                shutil.rmtree(weights_path, ignore_errors=True)
                os.makedirs(weights_path, exist_ok=True)

            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'),
                    np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
def dist_chamfer(x, y):
    """Return the two directional minima of the pairwise distances of x and y.

    Builds a ``ChamferLoss`` module on the device returned by ``cuda_setup()``,
    asks it for ``batch_pairwise_dist(x, y)`` and reduces that matrix with
    ``min`` along dimension 1 and along dimension 2 (values only; the argmin
    indices are discarded).

    Returns:
        ``(min over dim 1, min over dim 2)`` of the pairwise-distance tensor.

    NOTE(review): which cloud each result belongs to depends on the layout of
    ``batch_pairwise_dist`` -- presumably (batch, |x|, |y|); confirm.
    """
    # Lazy, function-local imports, as in the original helper.
    from losses.champfer_loss import ChamferLoss
    from utils.util import cuda_setup

    loss_module = ChamferLoss().to(cuda_setup())
    pairwise = loss_module.batch_pairwise_dist(x, y)
    nearest_over_dim1 = pairwise.min(1)[0]
    nearest_over_dim2 = pairwise.min(2)[0]
    return nearest_over_dim1, nearest_over_dim2
def main(config):
    """Train an encoder/generator point-cloud autoencoder.

    Seeds, device, dataset, architecture (``model.architectures.<arch>``),
    reconstruction loss, optimizer and epoch count all come from ``config``.
    Training resumes from the latest checkpoint found in the results
    directory. Every epoch, up to five real/reconstructed samples from the
    last batch are plotted; every ``config['save_frequency']`` epochs the
    generator, encoder and optimizer states are saved.

    Raises:
        ValueError: on an unsupported ``config['dataset']`` or
            ``config['reconstruction_loss']``.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        # map_location lets a checkpoint written on one device be resumed on
        # whichever device was selected for this run.
        prev_epoch = starting_epoch - 1
        G.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_G.pth'), map_location=device))
        E.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_E.pth'), map_location=device))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_EGo.pth'),
            map_location=device))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            # The +0.5 shift is applied to both clouds before the loss.
            loss = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss: {total_loss / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()

        # X still lives on `device`; a CUDA tensor cannot be converted to a
        # numpy array implicitly, so move it to host memory before plotting
        # (the sibling training scripts do the same).
        X_cpu = X.cpu().numpy()
        # Never index past the batch (a batch smaller than 5 used to raise).
        n_samples = min(5, X_cpu.shape[0], X_rec.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_cpu[k][0], X_cpu[k][1], X_cpu[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
def main(config):
    """Train the vessel autoencoder on ``VesselDataset3`` (features -> points).

    ``config`` is accepted for interface compatibility but is not read.
    Settings come from the module-level globals ``results_dir``,
    ``experiment``, ``max_epochs`` and ``batch_size``. ``results_dir`` is
    rebound globally to the prepared experiment directory.

    Each batch yields ``(features, points)``. The encoder consumes the
    features, and the generator output is compared against the points with a
    0.05-scaled Chamfer loss. After each epoch, the min/max (over axis 0) of
    the last reconstructed and target batches are printed.
    """
    global results_dir, max_epochs, batch_size
    # seeds
    random.seed(2019)
    torch.manual_seed(2019)
    torch.cuda.manual_seed_all(2019)

    # setting directory to save results and weights
    results_dir = join(results_dir, experiment)
    results_dir = prepare_results_dir(results_dir, b_clean=False)
    weights_path = join(results_dir, 'weights')

    # NOTE: checkpoint resume is disabled in this variant (it would use
    # find_latest_epoch(results_dir) + 1 and load_latest_epoch(...)).

    device = cuda_setup(True, 0)
    log = logging.getLogger(__name__)

    dataset = VesselDataset3()
    points_dataloader = DataLoader(dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=8,
                                   drop_last=True, pin_memory=True)

    G = Generator().to(device)
    E = Encoder().to(device)
    G.apply(weights_init)
    E.apply(weights_init)

    reconstruction_loss = ChamferLoss().to(device)

    EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                                lr=0.0005, weight_decay=0,
                                betas=[0.9, 0.999], amsgrad=False)

    for epoch in range(0, max_epochs):
        start_epoch_time = datetime.now()
        G.train()
        E.train()
        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 0):
            features, X = point_data
            X = X.to(device)
            features = features.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(features))

            loss = torch.mean(0.05 * reconstruction_loss(
                X.permute(0, 2, 1) + 0.5, X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            print(f'[{epoch}: ({i})] '
                  f'Loss: {loss.item():.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        # enumerate() starts at 0, so i + 1 batches were processed. Dividing
        # by i overstated the mean and raised ZeroDivisionError for a single
        # batch.
        log.debug(
            f'[{epoch}/{max_epochs}] '
            f'Loss: {total_loss / (i + 1):.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(features)).data.cpu().numpy()
            X_cpu = X.cpu().numpy()
            print(X_rec.min(axis=0), X_rec.max(axis=0))
            print(X_cpu.min(axis=0), X_cpu.max(axis=0))
def main(config):
    """Train the vessel point-cloud autoencoder on ``VesselDataset2``.

    ``config`` is accepted for interface compatibility but is not read.
    Settings come from the module-level globals ``results_dir``,
    ``experiment``, ``max_epochs``, ``batch_size`` and ``save_frequency``.
    ``results_dir`` is rebound globally to the prepared experiment directory.
    Training resumes from the latest saved epoch, saves sample point clouds
    every epoch, and saves weights every ``save_frequency`` epochs.
    """
    global results_dir, max_epochs, batch_size
    # setting seeds
    random.seed(2019)
    torch.manual_seed(2019)
    torch.cuda.manual_seed_all(2019)

    # setting directory to save results and weights
    results_dir = join(results_dir, experiment)
    results_dir = prepare_results_dir(results_dir, b_clean=False)
    weights_path = join(results_dir, 'weights')

    # finding last saved epoch if it exists in the results directory
    starting_epoch = find_latest_epoch(results_dir) + 1

    # setting device for pytorch usage
    device = cuda_setup(True, 0)

    log = logging.getLogger(__name__)

    # load vessels dataset
    # NOTE(review): machine-specific absolute path; consider moving it into
    # configuration.
    dataset = VesselDataset2(
        root_dir="/home/texs/Documents/Repositorios/point_cloud_reconstruction/data/ModelNet10"
    )
    points_dataloader = DataLoader(dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=8,
                                   drop_last=True, pin_memory=True)

    # building models and initialising weights
    G = Generator().to(device)
    E = Encoder().to(device)
    G.apply(weights_init)
    E.apply(weights_init)

    # setting reconstruction loss (EMD().to(device) is the alternative)
    reconstruction_loss = ChamferLoss().to(device)

    # joint optimizer over encoder and generator parameters
    EG_optim = torch.optim.Adam(chain(E.parameters(), G.parameters()),
                                lr=0.0005, weight_decay=0,
                                betas=[0.9, 0.999], amsgrad=False)

    # loading weights if they exist in the results directory
    load_latest_epoch(E, G, EG_optim, weights_path, starting_epoch)

    # training
    # starting_epoch is 1-based (find_latest_epoch(...) + 1), so the upper
    # bound is inclusive, as in the config-driven training scripts; the old
    # bound `max_epochs` ran one epoch short.
    for epoch in range(starting_epoch, max_epochs + 1):
        start_epoch_time = datetime.now()
        G.train()
        E.train()
        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 0):
            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(0.05 * reconstruction_loss(
                X.permute(0, 2, 1) + 0.5, X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            print(f'[{epoch}: ({i})] '
                  f'Loss: {loss.item():.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        # enumerate() starts at 0, so i + 1 batches were processed. Dividing
        # by i overstated the mean and raised ZeroDivisionError for a single
        # batch.
        log.debug(f'[{epoch}/{max_epochs}] '
                  f'Loss: {total_loss / (i + 1):.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()
            X_cpu = X.cpu().numpy()

        save_point_cloud(X_cpu, X_rec, epoch, n_fig=5, results_dir=results_dir)

        if epoch % save_frequency == 0:
            save_weights(E, G, EG_optim, weights_path, epoch)
def main(config):
    """Train a variational encoder/generator on ShapeNet point clouds.

    The encoder returns ``(codes, mu, logvar)``. The objective is the mean
    reconstruction loss plus ``config['reconstruction_coef'] *
    cw_distance(mu)``; the second term is reported under the name "KLD".
    Training resumes from the latest checkpoint. Every epoch, real,
    fixed-noise and reconstructed samples are plotted. Every
    ``config['save_frequency']`` epochs, the loss history (JSON and a line
    plot) and the model/optimizer states are saved.

    Raises:
        ValueError: on an unsupported ``config['dataset']`` or
            ``config['reconstruction_loss']``.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif config['reconstruction_loss'].lower() == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        # The message lists every branch above, including `cramer_wold`.
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'`earth_mover` or `cramer_wold`, '
                         f'got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    # Fixed latent batch, reused every epoch so generated samples are
    # comparable across epochs.
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    fixed_noise = fixed_noise.to(device)

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        # map_location lets a checkpoint written on one device be resumed on
        # whichever device was selected for this run.
        prev_epoch = starting_epoch - 1
        G.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_G.pth'), map_location=device))
        E.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_E.pth'), map_location=device))
        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{prev_epoch:05}_EGo.pth'),
            map_location=device))

    # Per-epoch [mean rec, mean "KLD", mean total]. NOTE: not restored on
    # resume, so losses.json only covers epochs trained in this run.
    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []

            for i, point_data in enumerate(points_dataloader, 1):
                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                codes, mu, logvar = E(X)
                X_rec = G(codes)

                # reconstruction_coef scales only the regulariser below, not
                # this reconstruction term.
                loss_e = torch.mean(
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                loss_kld = config['reconstruction_coef'] * cw_distance(mu)

                loss_eg = loss_e + loss_kld
                EG_optim.zero_grad()
                E.zero_grad()
                G.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results
            #
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).data.cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).data.cpu().numpy()

            X_numpy = X.cpu().numpy()

            # Each loop is bounded by its array length, so batches smaller
            # than 5 no longer raise IndexError.
            for k in range(min(5, X_numpy.shape[0])):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1],
                                          X_numpy[k][2], in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(min(5, fake.shape[0])):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', 'fixed',
                         f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1],
                                          X_rec[k][2], in_u_sphere=True,
                                          show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses,
                                  columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))
                fig = df.plot.line().get_figure()
                # `k` is the last index of the plotting loop above; it is kept
                # so existing output filenames do not change.
                fig.savefig(join(results_dir, 'losses', f'{epoch:05}_{k}.png'))
                # df.plot opens a new figure on every call; close it so the
                # figures do not accumulate over the run.
                plt.close(fig)

                torch.save(G.state_dict(),
                           join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(),
                           join(weights_path, f'{epoch:05}_E.pth'))
                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
def main(config: dict):
    """Train the full completion model, or run experiments on a trained one.

    ``config['mode']`` selects the behaviour:

    * ``'training'``: resumes from the latest saved state, if any. Each epoch
      then trains, steps the LR scheduler, validates per category and writes
      sample plots; logs and plots are optionally forwarded to Telegram.
      Model/optimizer/scheduler state and loss histories are saved every
      ``state_save_frequency`` epochs, or on a new best validation loss, once
      ``epoch > min_save_epoch``.
    * ``'experiments'``: restores weights from the training results directory,
      runs every experiment in ``config['experiments']['settings']`` whose
      ``execute`` flag is set, then exits the process with status 0.

    Raises:
        ValueError: if ``config['mode']`` is neither mode.
        FileNotFoundError: in experiments mode when no saved epoch exists.
    """
    # region Setup
    seed_setup(config['setup']['seed'])

    run_mode: str = config['mode']
    result_dir_path: str = get_results_dir_path(config, run_mode)

    if run_mode == 'training':
        dirs_to_create = ('weights', 'samples', 'metrics')
        weights_path = join(result_dir_path, 'weights')
        metrics_path = join(result_dir_path, 'metrics')
    elif run_mode == 'experiments':
        # Experiments read the artefacts written by the training run.
        dirs_to_create = tuple(experiment_functions_dict.keys())
        weights_path = join(get_results_dir_path(config, 'training'), 'weights')
        metrics_path = join(get_results_dir_path(config, 'training'), 'metrics')
    else:
        raise ValueError("mode should be `training` or `experiments`")

    results_dir_setup(result_dir_path, dirs_to_create)

    with open(join(result_dir_path, 'last_config.json'), mode='w') as f:
        json.dump(config, f)

    logging_setup(result_dir_path)
    log = logging.getLogger()

    log.info(f'Current mode {run_mode}')

    if config['telegram_logger']['enable']:
        tg_log = TelegramLogger.get_logger(config['telegram_logger'])

    device = cuda_setup(config['setup']['gpu_id'])
    log.info(f'Device variable: {device}')

    reconstruction_loss = ChamferLoss().to(device)

    full_model = FullModel(config['full_model']).to(device)
    full_model.apply(weights_init)

    optimizer = getattr(optim, config['training']['optimizer']['type'])  # class
    optimizer = optimizer(full_model.parameters(),
                          **config['training']['optimizer']['hyperparams'])

    scheduler = getattr(optim.lr_scheduler,
                        config['training']['lr_scheduler']['type'])  # class
    scheduler = scheduler(optimizer,
                          **config['training']['lr_scheduler']['hyperparams'])

    log.info(f'Model {get_model_name(config)} created')

    latest_epoch = find_latest_epoch(
        result_dir_path if run_mode == "training" else weights_path)
    log.info(f'Latest epoch found: {latest_epoch}')

    if latest_epoch > 0:
        if run_mode == "training":
            latest_epoch = restore_model_state(
                weights_path, metrics_path, config['setup']['gpu_id'],
                latest_epoch, "latest", full_model, optimizer, scheduler)
        elif run_mode == "experiments":
            latest_epoch = restore_model_state(
                weights_path, metrics_path, config['setup']['gpu_id'],
                latest_epoch, config['experiments']['epoch'], full_model)
        log.info(f'Restored epoch : {latest_epoch}')
    elif run_mode == "experiments":
        # One formatted message; the old two-argument form was rendered by
        # OSError as "[Errno no weights found at ] <path>".
        raise FileNotFoundError(f"no weights found at {weights_path}")
    # endregion Setup

    train_dataset, val_dataset_dict, test_dataset_dict = get_datasets(
        config['dataset'])

    log.info(f'Dataset loaded for classes: {list(val_dataset_dict)}')

    if run_mode == 'training':
        samples_path = join(result_dir_path, 'samples')

        train_dataloader = DataLoader(
            train_dataset, pin_memory=True,
            **config['training']['dataloader']['train'])
        val_dataloaders_dict = {
            cat_name: DataLoader(cat_ds, pin_memory=True,
                                 **config['training']['dataloader']['val'])
            for cat_name, cat_ds in val_dataset_dict.items()
        }

        if latest_epoch == 0:
            # np.inf rather than np.Infinity, which NumPy 2.0 removed.
            best_epoch_loss = np.inf
            train_losses = []
            val_losses = []
        else:
            train_losses, val_losses, best_epoch_loss = restore_metrics(
                metrics_path, latest_epoch)

        for epoch in range(latest_epoch + 1,
                           config['training']['max_epoch'] + 1):
            start_epoch_time = datetime.now()
            log.debug("Epoch: %s", epoch)

            (full_model, optimizer, epoch_loss_all, epoch_loss_kld,
             epoch_loss_r, latest_existing, latest_gt, latest_rec) = train_epoch(
                epoch, full_model, optimizer, train_dataloader, device,
                reconstruction_loss, config['training']['loss_coef'])

            scheduler.step()

            train_losses.append(
                np.array([epoch_loss_all, epoch_loss_r, epoch_loss_kld]))
            log_string = (f'[{epoch}/{config["training"]["max_epoch"]}] '
                          f'Loss_ALL: {epoch_loss_all:.4f} '
                          f'Loss_R: {epoch_loss_r:.4f} '
                          f'Loss_E: {epoch_loss_kld:.4f} '
                          f'Time: {datetime.now() - start_epoch_time}')
            log.info(log_string)

            train_plots = []
            for k in range(min(5, latest_rec.shape[0])):
                train_plots.append(save_plot(latest_existing[k], epoch, k,
                                             samples_path, 'existing'))
                train_plots.append(save_plot(latest_rec[k], epoch, k,
                                             samples_path, 'reconstructed'))
                train_plots.append(save_plot(latest_gt[k].T, epoch, k,
                                             samples_path, 'gt'))
            if config['telegram_logger']['enable']:
                tg_log.log_images(train_plots[:9], log_string)

            epoch_val_losses, epoch_val_samples = val_epoch(
                epoch, full_model, device, val_dataloaders_dict,
                val_dataset_dict.keys(), reconstruction_loss,
                config['training']['loss_coef'])

            is_new_best = epoch_val_losses['total'][0] < best_epoch_loss
            if is_new_best:
                best_epoch_loss = epoch_val_losses['total'][0]

            val_losses.append(epoch_val_losses['total'])

            log_string = f'val results[{config["training"]["loss_coef"]}*our_cd]:\n'
            for k, v in epoch_val_losses.items():
                log_string += k + ': ' + str(v) + '\n'
            if is_new_best:
                log_string += "new best epoch"
            log.info(log_string)

            # One (existing, rec, gt) triple of plots per validation category.
            val_plots = []
            for cat_name, sample in epoch_val_samples.items():
                val_plots.append(save_plot(sample[0], epoch, cat_name,
                                           samples_path, 'val_existing'))
                val_plots.append(save_plot(sample[2], epoch, cat_name,
                                           samples_path, 'val_rec'))
                val_plots.append(save_plot(sample[1].T, epoch, cat_name,
                                           samples_path, 'val_gt'))

            if config['telegram_logger']['enable']:
                # Forward up to three randomly chosen triples. Plain ints
                # replace np.int (removed in NumPy 1.24); the sampling itself
                # is unchanged.
                n_triples = len(val_plots) // 3
                chosen_plot_idx = np.random.choice(
                    n_triples, min(3, n_triples), replace=False)
                plots_to_log = []
                for idx in chosen_plot_idx:
                    plots_to_log.extend(val_plots[3 * idx:3 * idx + 3])
                tg_log.log_images(plots_to_log, log_string)

            if ((epoch % config['training']['state_save_frequency'] == 0
                 or is_new_best)
                    and epoch > config['training'].get('min_save_epoch', 0)):
                torch.save(full_model.state_dict(),
                           join(weights_path, f'{epoch:05}_model.pth'))
                torch.save(optimizer.state_dict(),
                           join(weights_path, f'{epoch:05}_O.pth'))
                torch.save(scheduler.state_dict(),
                           join(weights_path, f'{epoch:05}_S.pth'))

                np.save(join(metrics_path, f'{epoch:05}_train'),
                        np.array(train_losses))
                np.save(join(metrics_path, f'{epoch:05}_val'),
                        np.array(val_losses))

                log_string = "Epoch: %s saved" % epoch
                log.debug(log_string)
                if config['telegram_logger']['enable']:
                    tg_log.log(log_string)

    elif run_mode == 'experiments':
        full_model.eval()

        with torch.no_grad():
            for experiment_name, experiment_dict in config['experiments'][
                    'settings'].items():
                # pop() removes 'execute' so the remaining keys can be passed
                # straight through as keyword arguments.
                if experiment_dict.pop('execute', False):
                    log.info(experiment_name)
                    experiment_functions_dict[experiment_name](
                        full_model, device, test_dataset_dict,
                        result_dir_path, latest_epoch, **experiment_dict)

        # Equivalent to sys.exit(0) without relying on the site-injected
        # `exit` builtin.
        raise SystemExit(0)
def evaluate_generativity(full_model: FullModel, device, datasets_dict,
                          results_dir, epoch, batch_size, num_workers,
                          mean=0.0, std=0.005):
    """Compare sampled completions with each category's ground-truth parts.

    For every category in ``datasets_dict``, the ``missing`` element of each
    sample is stacked into one reference set. Then, for each sample's
    ``existing`` element, the model is run once per reference cloud with
    fresh Gaussian noise (``mean`` / ``std``). Each output keeps the 1024
    columns with the smallest values in row 1 (presumably the y coordinate of
    a (3, N) cloud; confirm). The resulting set is scored against the
    reference set with ``compute_all_metrics`` and
    ``jsd_between_point_cloud_sets``.

    Metric values are summed, not averaged, over a category's samples. The
    per-category results are printed and written to
    ``<results_dir>/evaluate_generativity/<epoch>eval_gen_by_cat.json``.

    Args:
        full_model: model providing ``get_noise_size()`` and the call
            signature used below.
        device: torch device for all tensors.
        datasets_dict: mapping of category name to dataset; each item
            unpacks into four elements (existing, missing, _, _).
        results_dir: root directory for the JSON output.
        epoch: passed to the model and used in the output filename.
        batch_size: forwarded to ``compute_all_metrics``.
        num_workers: DataLoader worker count.
        mean: mean of the sampling noise.
        std: standard deviation of the sampling noise.
    """
    loaders = {
        name: DataLoader(ds, pin_memory=True, batch_size=1,
                         num_workers=num_workers)
        for name, ds in datasets_dict.items()
    }
    chamfer = ChamferLoss().to(device)

    with torch.no_grad():
        results = {}

        for name, loader in loaders.items():
            # Reference set: every sample's `missing` part, in loader order.
            reference = torch.cat(
                [missing.to(device) for _, missing, _, _ in loader]
            ).contiguous()

            per_category = {}

            for existing, _, _, _ in tqdm(loader, total=len(loader)):
                existing = existing.to(device)

                # One sampled completion per reference cloud.
                generated = []
                for _ in range(len(reference)):
                    noise = torch.zeros(
                        1, full_model.get_noise_size()).normal_(
                            mean=mean, std=std).to(device)
                    output = full_model(existing, None, [1, 2048, 3], epoch,
                                        device, noise=noise)
                    cloud = output.cpu().detach().numpy()[0]
                    keep = cloud[1].argsort()[:1024]
                    generated.append(
                        torch.from_numpy(cloud.T[keep]).unsqueeze(0).to(device))
                generated = torch.cat(generated)

                metrics = compute_all_metrics(generated, reference,
                                              batch_size, chamfer)
                for metric_name, value in metrics.items():
                    per_category[metric_name] = (
                        per_category.get(metric_name, 0.0) + value.item())
                per_category['jsd'] = per_category.get(
                    'jsd', 0.0) + jsd_between_point_cloud_sets(
                        generated.cpu().detach().numpy(),
                        reference.cpu().numpy())

            results[name] = per_category
            print(name, per_category)

    out_file = join(results_dir, 'evaluate_generativity',
                    str(epoch) + 'eval_gen_by_cat.json')
    with open(out_file, mode='w') as f:
        json.dump(results, f)