コード例 #1
0
def different_number_of_points(encoder,
                               hyper_network,
                               x,
                               device,
                               results_dir,
                               epoch,
                               amount=5,
                               number_of_points_list=(10, 100, 1000, 2048,
                                                      10000)):
    """Decode the first ``amount`` clouds of ``x`` at several resolutions.

    Each cloud is encoded, mapped to target-network weights by
    ``hyper_network``, and the resulting ``aae.TargetNetwork`` is evaluated on
    freshly generated inputs of every size in ``number_of_points_list``.
    The real cloud, every target-network input, every output and PNG plots
    are written to ``<results_dir>/different_number_points`` (the directory
    must already exist; ``np.save`` does not create it).

    Args:
        encoder: model whose call returns a 3-tuple; the first item is used
            as the latent code.
        hyper_network: model mapping latent codes to target-network weights.
        x: batch of point clouds. Plots read ``x[k][0..2]`` as x/y/z, so it
            is presumably laid out as (batch, 3, n_points) — TODO confirm.
        device: device the target-network input is moved to.
        results_dir: root output directory.
        epoch: forwarded to ``generate_points``.
        amount: number of clouds taken from the front of ``x``.
        number_of_points_list: sizes of the target-network inputs to sample.

    Uses the module-level ``log`` and ``config``.
    """
    log.info("Different number of points")
    out_dir = join(results_dir, 'different_number_points')
    x = x[:amount]

    # Pure inference: disable autograd so the `.numpy()` calls below never see
    # tensors that require grad, whatever the caller's autograd state is.
    with torch.no_grad():
        latent, _, _ = encoder(x)
        weights_diff = hyper_network(latent)
        x = x.cpu().numpy()

        # `x[:amount]` is shorter than `amount` when the batch is small.
        for k in range(min(amount, len(x))):
            np.save(join(out_dir, f'{k}_real'), np.array(x[k]))
            fig = plot_3d_point_cloud(x[k][0],
                                      x[k][1],
                                      x[k][2],
                                      in_u_sphere=True,
                                      show=False)
            fig.savefig(join(out_dir, f'{k}_real.png'))
            plt.close(fig)

            target_network = aae.TargetNetwork(config, weights_diff[k])

            for number_of_points in number_of_points_list:
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(number_of_points,
                                                             x.shape[1]))
                x_diff = torch.transpose(
                    target_network(target_network_input.to(device)), 0,
                    1).cpu().numpy()

                # The resolution is part of the name: with a single
                # `{k}_target_network_input` file every iteration would
                # overwrite the previous input and only the last would remain.
                np.save(
                    join(out_dir,
                         f'{k}_{number_of_points}_target_network_input'),
                    np.array(target_network_input))
                np.save(join(out_dir, f'{k}_{number_of_points}'),
                        np.array(x_diff))

                fig = plot_3d_point_cloud(x_diff[0],
                                          x_diff[1],
                                          x_diff[2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(join(out_dir, f'{k}_{number_of_points}.png'))
                plt.close(fig)
コード例 #2
0
    def process_existing(pcd, cat_name, name, i, j):
        """Reconstruct one partial cloud and save every intermediate result.

        ``pcd`` goes through ``torch.from_numpy``, so it must be a NumPy
        array. It is paired with fresh Gaussian noise and passed through
        ``full_model``. The input, the noise, the reconstruction and its plot
        are written to ``<results_dir>/same_model_different_slices`` using the
        stem ``{cat_name}_{i}_{j}_{name}``.

        ``results_dir``, ``full_model``, ``mean``, ``std``, ``device`` and
        ``epoch`` come from the enclosing scope, which this excerpt does not
        show.
        """
        stem = join(results_dir, 'same_model_different_slices',
                    f'{cat_name}_{i}_{j}_{name}')

        np.save(f'{stem}_pcd', pcd)
        noise_size = full_model.get_noise_size()
        noise = torch.zeros(1, noise_size).normal_(mean=mean, std=std)
        np.save(f'{stem}_noise', noise.numpy())

        # Add a leading batch axis of size 1 before feeding the model.
        batch = torch.from_numpy(pcd).unsqueeze(0).to(device)
        output = full_model(batch, None, [1, 2048, 3], epoch, device,
                            noise=noise.to(device))
        rec = output[0].cpu().numpy()
        np.save(f'{stem}_rec', rec)

        fig = plot_3d_point_cloud(rec[0], rec[1], rec[2],
                                  in_u_sphere=True, show=False)
        fig.savefig(f'{stem}_rec.png')
        plt.close(fig)
コード例 #3
0
def reconstruction(encoder,
                   hyper_network,
                   device,
                   x,
                   results_dir,
                   epoch,
                   amount=5):
    """Encode the first ``amount`` clouds of ``x`` and decode them again.

    Each cloud's latent code is turned into target-network weights by
    ``hyper_network``. The resulting ``aae.TargetNetwork`` is evaluated on
    generated points with the same count as the input cloud. The generated
    input, the real cloud and the reconstruction are saved as ``.npy`` files,
    and both clouds are plotted, under ``<results_dir>/reconstruction``.

    ``x`` is indexed as ``x[k][0..2]`` for plotting, so it is presumably
    (batch, 3, n_points) — TODO confirm. Uses the module-level ``log`` and
    ``config``.
    """
    log.info("Reconstruction")
    out_dir = join(results_dir, 'reconstruction')

    def save_figure(cloud, filename):
        # Plot rows 0, 1, 2 of `cloud` as x, y, z and write the PNG.
        fig = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                  in_u_sphere=True, show=False)
        fig.savefig(join(out_dir, filename))
        plt.close(fig)

    x = x[:amount]
    codes, _, _ = encoder(x)
    all_weights = hyper_network(codes)
    real = x.cpu().numpy()
    input_size = (real.shape[2], real.shape[1])

    for k in range(amount):
        net = aae.TargetNetwork(config, all_weights[k])
        net_input = generate_points(config=config,
                                    epoch=epoch,
                                    size=input_size)
        rebuilt = torch.transpose(net(net_input.to(device)), 0,
                                  1).cpu().numpy()

        artifacts = (
            ('target_network_input', net_input),
            ('real', real[k]),
            ('reconstructed', rebuilt),
        )
        for suffix, data in artifacts:
            np.save(join(out_dir, f'{k}_{suffix}'), np.array(data))

        save_figure(rebuilt, f'{k}_reconstructed.png')
        save_figure(real[k], f'{k}_real.png')
コード例 #4
0
def interpolation_between_two_points(encoder,
                                     hyper_network,
                                     device,
                                     x,
                                     results_dir,
                                     epoch,
                                     amount=30,
                                     image_points=1000,
                                     transitions=21):
    """Trace a straight path through a target network's input space.

    For each of the first ``amount`` clouds the target network is built from
    its latent code and evaluated on ``image_points`` generated points. Two
    of those input points are then chosen: the ones with the smallest and the
    largest value in column 2. The network is evaluated at ``transitions``
    evenly spaced blends between them. The generated input, the
    reconstruction and the interpolated path are saved, and plotted together,
    under ``<results_dir>/points_interpolation``.

    Uses the module-level ``log`` and ``config``.
    """
    log.info("Interpolations between two points")
    x = x[:amount]
    n_dims = x.shape[1]
    out_dir = join(results_dir, 'points_interpolation')

    codes, _, _ = encoder(x)
    all_weights = hyper_network(codes)
    for k in range(amount):
        net = aae.TargetNetwork(config, all_weights[k])
        net_input = generate_points(config=config,
                                    epoch=epoch,
                                    size=(image_points, n_dims))

        # Path endpoints, each kept as a single-row (1, n_dims) tensor.
        low_idx = torch.argmin(net_input, dim=0)[2]
        start_point = net_input[low_idx][None, :]
        high_idx = torch.argmax(net_input, dim=0)[2]
        end_point = net_input[high_idx][None, :]

        rebuilt = torch.transpose(net(net_input.to(device)), 0,
                                  1).cpu().numpy()

        path = torch.zeros(transitions, n_dims)
        for step, weight in enumerate(np.linspace(0, 1, transitions)):
            blended = (1 - weight) * start_point + weight * end_point
            path[step] = net(blended.to(device))
        path = torch.transpose(path, 0, 1).cpu().numpy()

        np.save(join(out_dir, f'{k}_target_network_input'),
                np.array(net_input))
        np.save(join(out_dir, f'{k}_reconstruction'), np.array(rebuilt))
        np.save(join(out_dir, f'{k}_points_interpolation'), np.array(path))

        fig = plot_3d_point_cloud(rebuilt[0], rebuilt[1], rebuilt[2],
                                  in_u_sphere=True,
                                  show=False,
                                  x1=path[0],
                                  y1=path[1],
                                  z1=path[2])
        fig.savefig(join(out_dir, f'{k}_points_interpolation.png'))
        plt.close(fig)
コード例 #5
0
def save_plot(X, epoch, k, results_dir, t):
    """Plot rows 0-2 of ``X`` as x/y/z, save the figure and return its path.

    The PNG goes to ``<results_dir>/<epoch>_<k>_<t>.png`` and carries the
    title ``'<t>_<k> epoch: <epoch>'``.
    """
    fig_path = join(results_dir, f'{epoch}_{k}_{t}.png')
    plot_title = f'{t}_{k} epoch: {epoch}'
    figure = plot_3d_point_cloud(X[0], X[1], X[2],
                                 in_u_sphere=True,
                                 show=False,
                                 title=plot_title)
    figure.savefig(fig_path)
    plt.close(figure)
    return fig_path
コード例 #6
0
def interpolation(x,
                  encoder,
                  hyper_network,
                  device,
                  results_dir,
                  epoch,
                  amount=5,
                  transitions=10):
    """Decode clouds along straight lines between pairs of latent codes.

    Pair ``k`` uses samples ``2k`` and ``2k + 1`` of ``x``, so ``x`` must hold
    at least ``2 * amount`` clouds. For each of ``transitions`` evenly spaced
    blends of the two latent codes, target-network weights are produced by
    ``hyper_network`` and the network is evaluated on freshly generated
    points. The input, the output and a plot are saved under
    ``<results_dir>/interpolations``.

    ``x`` is indexed as ``x[None, i, :, :]``, so it is 3-D — presumably
    (batch, 3, n_points); TODO confirm. Uses the module-level ``log`` and
    ``config``.
    """
    log.info('Interpolations')
    out_dir = join(results_dir, 'interpolations')

    # Everything here is inference. Autograd must be off for the hyper and
    # target networks too, not only the encoder: otherwise their outputs
    # require grad and `.numpy()` raises.
    with torch.no_grad():
        for k in range(amount):
            x_a = x[None, 2 * k, :, :]
            x_b = x[None, 2 * k + 1, :, :]
            z_a, _, _ = encoder(x_a)
            z_b, _, _ = encoder(x_b)

            for j, alpha in enumerate(np.linspace(0, 1, transitions)):
                # Interpolate in the latent space, then decode the blend.
                z_int = (1 - alpha) * z_a + alpha * z_b
                weights_int = hyper_network(z_int)

                target_network = aae.TargetNetwork(config, weights_int[0])
                target_network_input = generate_points(config=config,
                                                       epoch=epoch,
                                                       size=(x.shape[2],
                                                             x.shape[1]))
                x_int = torch.transpose(
                    target_network(target_network_input.to(device)), 0,
                    1).cpu().numpy()

                np.save(join(out_dir, f'{k}_{j}_target_network_input'),
                        np.array(target_network_input))
                np.save(join(out_dir, f'{k}_{j}_interpolation'),
                        np.array(x_int))

                fig = plot_3d_point_cloud(x_int[0],
                                          x_int[1],
                                          x_int[2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(join(out_dir, f'{k}_{j}_interpolation.png'))
                plt.close(fig)
コード例 #7
0
def sphere(encoder,
           hyper_network,
           device,
           x,
           results_dir,
           epoch,
           amount=10,
           image_points=10240,
           start=2.0,
           end=4.0,
           transitions=21):
    """Probe target networks with scaled points on the unit sphere.

    For each of the first ``amount`` clouds, ``image_points`` unnormalized
    points are generated and fed to the cloud's target network. The points
    are then divided by their row-wise L2 norm and, for each coefficient in
    ``linspace(start, end, transitions)`` (rounded to one decimal), the
    network is evaluated on ``coeff * points``. All arrays and plots are
    saved under ``<results_dir>/sphere``.

    Uses the module-level ``log`` and ``config``.
    """
    log.info("Sphere")
    out_dir = join(results_dir, 'sphere')

    def save_figure(cloud, filename):
        # Plot rows 0, 1, 2 of `cloud` as x, y, z and write the PNG.
        fig = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                  in_u_sphere=True, show=False)
        fig.savefig(join(out_dir, filename))
        plt.close(fig)

    x = x[:amount]
    codes, _, _ = encoder(x)
    all_weights = hyper_network(codes)
    clouds = x.cpu().numpy()
    n_dims = clouds.shape[1]

    for k in range(amount):
        net = aae.TargetNetwork(config, all_weights[k])
        raw_input = generate_points(config=config,
                                    epoch=epoch,
                                    size=(image_points, n_dims),
                                    normalize_points=False)
        rebuilt = torch.transpose(net(raw_input.to(device)), 0,
                                  1).cpu().numpy()

        np.save(join(out_dir, f'{k}_real'), np.array(clouds[k]))
        np.save(join(out_dir, f'{k}_point_cloud_before_normalization'),
                np.array(raw_input))
        np.save(join(out_dir, f'{k}_reconstruction'), np.array(rebuilt))

        # Scale every point (row) to unit L2 norm.
        row_norms = torch.norm(raw_input, dim=1).view(-1, 1)
        unit_input = raw_input / row_norms
        np.save(join(out_dir, f'{k}_point_cloud_after_normalization'),
                np.array(unit_input))

        for coeff in np.linspace(start, end, num=transitions):
            coeff = round(coeff, 1)
            scaled_out = torch.transpose(
                net(unit_input.to(device) * coeff), 0, 1).cpu().numpy()
            np.save(
                join(out_dir,
                     f'{k}_output_from_target_network_for_point_cloud_after_normalization_coefficient_{coeff}'),
                np.array(scaled_out))
            save_figure(scaled_out, f'{k}_{coeff}_sphere.png')

        save_figure(clouds[k], f'{k}_real.png')
コード例 #8
0
def main(config):
    """Train the encoder / generator / critic and dump progress artifacts.

    The encoder ``E`` maps point clouds to codes, the generator ``G`` decodes
    codes back to clouds, and the discriminator ``D`` is trained as a
    WGAN-GP critic on encoder codes versus samples from the configured prior
    (``'bernoulli'`` or ``'beta'``). When ``find_latest_epoch`` reports an
    earlier epoch, models and optimizers are reloaded from
    ``<results_dir>/weights`` and training continues from the next epoch.
    After every epoch a loss plot and real / fixed-noise / reconstructed
    sample figures are written under ``<results_dir>/samples``. Weights and
    optimizer states are saved every ``config['save_frequency']`` epochs.

    Args:
        config: experiment configuration dict (seed, dataset, arch,
            optimizer settings, loss and prior choices, ...).

    Raises:
        ValueError: for an unsupported dataset, reconstruction loss or prior
            distribution.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)
    D = arch.Discriminator(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)
    D.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')
    #
    # Float Tensors
    #
    distribution = config['distribution'].lower()
    if distribution == 'bernoulli':
        p = torch.tensor(config['p']).to(device)
        sampler = Bernoulli(probs=p)
        fixed_noise = sampler.sample(torch.Size([config['batch_size'],
                                                 config['z_size']]))
    elif distribution == 'beta':
        fixed_noise_np = np.random.beta(config['z_beta_a'],
                                        config['z_beta_b'],
                                        size=(config['batch_size'],
                                              config['z_size']))
        fixed_noise = torch.tensor(fixed_noise_np).float().to(device)
    else:
        raise ValueError('Invalid distribution for binaray model.')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    D_optim = getattr(optim, config['optimizer']['D']['type'])
    D_optim = D_optim(D.parameters(),
                      **config['optimizer']['D']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))
        D.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        D_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-batch loss history kept as plain Python floats. Appending the loss
    # tensors themselves would keep every batch's autograd graph (and device
    # memory) alive for the whole run, and matplotlib cannot convert tensors
    # that require grad or live on CUDA when plotting.
    loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], []
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()
        D.train()

        total_loss_eg = 0.0
        total_loss_d = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            codes = E(X)
            if distribution == 'bernoulli':
                noise = sampler.sample(fixed_noise.size())
            elif distribution == 'beta':
                noise_np = np.random.beta(config['z_beta_a'],
                                          config['z_beta_b'],
                                          size=(config['batch_size'],
                                                config['z_size']))
                noise = torch.tensor(noise_np).float().to(device)
            synth_logit = D(codes)
            real_logit = D(noise)
            loss_d = torch.mean(synth_logit) - torch.mean(real_logit)
            # Record the pure critic term now: `loss_d` is modified in place
            # below when the penalty is added, and the penalty has its own
            # curve.
            loss_d_tot.append(loss_d.item())

            # Gradient Penalty
            alpha = torch.rand(config['batch_size'], 1).to(device)
            differences = codes - noise
            interpolates = noise + alpha * differences
            disc_interpolates = D(interpolates)

            gradients = grad(
                outputs=disc_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(disc_interpolates).to(device),
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
            gradient_penalty = ((slopes - 1) ** 2).mean()
            loss_gp = config['gp_lambda'] * gradient_penalty
            loss_gp_tot.append(loss_gp.item())
            ###

            loss_d += loss_gp

            D_optim.zero_grad()
            D.zero_grad()

            # Keep the graph: `codes` is reused for the encoder/generator step.
            loss_d.backward(retain_graph=True)
            total_loss_d += loss_d.item()
            D_optim.step()

            # EG part of training
            X_rec = G(codes)

            loss_e = torch.mean(
                config['reconstruction_coef'] *
                reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))
            loss_e_tot.append(loss_e.item())

            synth_logit = D(codes)

            loss_g = -torch.mean(synth_logit)
            loss_g_tot.append(loss_g.item())

            loss_eg = loss_e + loss_g
            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss_eg.backward()
            total_loss_eg += loss_eg.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss_D: {loss_d.item():.4f} '
                      f'(GP: {loss_gp.item(): .4f}) '
                      f'Loss_EG: {loss_eg.item():.4f} '
                      f'(REC: {loss_e.item(): .4f}) '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Loss_D: {total_loss_d / i:.4f} '
            f'Loss_EG: {total_loss_eg / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        D.eval()
        with torch.no_grad():
            fake = G(fixed_noise).data.cpu().numpy()
            X_rec = G(E(X)).data.cpu().numpy()
            X = X.data.cpu().numpy()

        plt.figure(figsize=(16, 9))
        plt.plot(loss_d_tot, 'r-', label="loss_d")
        plt.plot(loss_gp_tot, 'g-', label="loss_gp")
        plt.plot(loss_e_tot, 'b-', label="loss_e")
        plt.plot(loss_g_tot, 'k-', label="loss_g")
        plt.legend()
        plt.xlabel("Batch number")
        plt.ylabel("Loss value")
        plt.savefig(
            join(results_dir, 'samples', f'loss_plot.png'))
        plt.close()

        for k in range(5):
            fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png'))
            plt.close(fig)

        for k in range(5):
            fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2],
                                      in_u_sphere=True, show=False,
                                      title=str(epoch))
            fig.savefig(join(results_dir, 'samples',
                             f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            torch.save(D_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_Do.pth'))
コード例 #9
0
def same_model_different_slices(full_model,
                                device,
                                datasets_dict,
                                results_dir,
                                epoch,
                                amount=10,
                                slices_number=10,
                                mean=0.0,
                                std=0.015):
    """Reconstruct random slices of ground-truth clouds with ``full_model``.

    For every category in ``datasets_dict``, ``amount`` distinct sample
    indices are drawn without replacement. Each ground-truth cloud is plotted
    and saved. It is then split ``slices_number`` times by
    ``SlicedDatasetGenerator.generate_item(points, 1024)`` into two partial
    clouds, tagged ``'f'`` and ``'s'``. Each part is reconstructed with
    Gaussian noise ``N(mean, std)``, and the part, the noise, the
    reconstruction and its plot are saved under
    ``<results_dir>/same_model_different_slices``. The whole procedure runs
    under ``torch.no_grad()``.
    """
    out_dir = join(results_dir, 'same_model_different_slices')

    def _reconstruct_part(part, cat_name, tag, i, j):
        # Save the partial cloud, sample and save noise, run the model, then
        # save the reconstruction and its plot.
        prefix = join(out_dir, f'{cat_name}_{i}_{j}_{tag}')
        np.save(prefix + '_pcd', part)

        noise = torch.zeros(1, full_model.get_noise_size()).normal_(mean=mean,
                                                                    std=std)
        np.save(prefix + '_noise', noise.numpy())

        model_in = torch.from_numpy(part).unsqueeze(0).to(device)
        noise_on_device = noise.to(device)
        rec = full_model(model_in, None, [1, 2048, 3], epoch, device,
                         noise=noise_on_device)[0].cpu().numpy()
        np.save(prefix + '_rec', rec)

        fig = plot_3d_point_cloud(rec[0], rec[1], rec[2],
                                  in_u_sphere=True, show=False)
        fig.savefig(prefix + '_rec.png')
        plt.close(fig)

    with torch.no_grad():
        for cat_name, ds in datasets_dict.items():
            chosen = np.random.choice(len(ds), amount, replace=False)
            for i, idx in tqdm(enumerate(chosen), total=len(chosen)):
                _, _, cloud, _ = ds[idx]

                # Plot the transposed view; save and slice the cloud as-is.
                per_axis = cloud.T
                fig = plot_3d_point_cloud(per_axis[0], per_axis[1],
                                          per_axis[2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(out_dir, f'{cat_name}_{i}_gt.png'))
                plt.close(fig)
                np.save(join(out_dir, f'{cat_name}_{i}_gt'), cloud)

                for j in range(slices_number):
                    first, second = SlicedDatasetGenerator.generate_item(
                        cloud, 1024)
                    for tag, part in (('f', first), ('s', second)):
                        _reconstruct_part(part, cat_name, tag, i, j)
コード例 #10
0
def main(config):
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'aae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('aae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   pin_memory=True)

    pointnet = config.get('pointnet', False)
    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)

    if pointnet:
        from models.pointnet import PointNet
        encoder = PointNet(config).to(device)
        # PointNet initializes it's own weights during instance creation
    else:
        encoder = aae.Encoder(config).to(device)
        encoder.apply(weights_init)

    discriminator = aae.Discriminator(config).to(device)

    hyper_network.apply(weights_init)
    discriminator.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        if pointnet:
            from utils.metrics import chamfer_distance
            reconstruction_loss = chamfer_distance
        else:
            from losses.champfer_loss import ChamferLoss
            reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from utils.metrics import earth_mover_distance
        reconstruction_loss = earth_mover_distance
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(chain(encoder.parameters(), hyper_network.parameters()),
                                    **config['optimizer']['E_HN']['hyperparams'])

    discriminator_optimizer = getattr(optim, config['optimizer']['D']['type'])
    discriminator_optimizer = discriminator_optimizer(discriminator.parameters(),
                                                      **config['optimizer']['D']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_E.pth')))
        discriminator.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_D.pth')))

        e_hn_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        discriminator_optimizer.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_Do.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path, f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_g = np.load(join(metrics_path, f'{starting_epoch - 1:05}_G.npy')).tolist()
        losses_eg = np.load(join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
        losses_d = np.load(join(metrics_path, f'{starting_epoch - 1:05}_D.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_g = []
        losses_eg = []
        losses_d = []

    normalize_points = config['target_network_input']['normalization']['enable']
    if normalize_points:
        normalization_type = config['target_network_input']['normalization']['type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder.train()
        discriminator.train()

        total_loss_all = 0.0
        total_loss_reconstruction = 0.0
        total_loss_encoder = 0.0
        total_loss_discriminator = 0.0
        total_loss_regularization = 0.0
        for i, point_data in enumerate(points_dataloader, 1):

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            if pointnet:
                _, feature_transform, codes = encoder(X)
            else:
                codes, _, _ = encoder(X)

            # discriminator training
            noise = torch.empty(codes.shape[0], config['z_size']).normal_(mean=config['normal_mu'],
                                                                          std=config['normal_std']).to(device)
            synth_logit = discriminator(codes)
            real_logit = discriminator(noise)
            if config.get('wasserstein', True):
                loss_discriminator = torch.mean(synth_logit) - torch.mean(real_logit)

                alpha = torch.rand(codes.shape[0], 1).to(device)
                differences = codes - noise
                interpolates = noise + alpha * differences
                disc_interpolates = discriminator(interpolates)

                # gradient_penalty_function
                gradients = grad(
                    outputs=disc_interpolates,
                    inputs=interpolates,
                    grad_outputs=torch.ones_like(disc_interpolates).to(device),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
                slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1))
                gradient_penalty = ((slopes - 1) ** 2).mean()
                loss_gp = config['gradient_penalty_coef'] * gradient_penalty
                loss_discriminator += loss_gp
            else:
                # An alternative is a = -1, b = 1 iff c = 0
                a = 0.0
                b = 1.0
                loss_discriminator = 0.5 * ((real_logit - b)**2 + (synth_logit - a)**2)

            discriminator_optimizer.zero_grad()
            discriminator.zero_grad()

            loss_discriminator.backward(retain_graph=True)
            total_loss_discriminator += loss_discriminator.item()
            discriminator_optimizer.step()

            # hyper network training
            target_networks_weights = hyper_network(codes)

            X_rec = torch.zeros(X.shape).to(device)
            for j, target_network_weights in enumerate(target_networks_weights):
                target_network = aae.TargetNetwork(config, target_network_weights).to(device)

                if not config['target_network_input']['constant'] or target_network_input is None:
                    target_network_input = generate_points(config=config, epoch=epoch, size=(X.shape[2], X.shape[1]))

                X_rec[j] = torch.transpose(target_network(target_network_input.to(device)), 0, 1)

            if pointnet:
                loss_reconstruction = config['reconstruction_coef'] * \
                                      reconstruction_loss(torch.transpose(X, 1, 2).contiguous(),
                                                          torch.transpose(X_rec, 1, 2).contiguous(),
                                                          batch_size=X.shape[0]).mean()
            else:
                loss_reconstruction = torch.mean(
                    config['reconstruction_coef'] *
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

            # encoder training
            synth_logit = discriminator(codes)
            if config.get('wasserstein', True):
                loss_encoder = -torch.mean(synth_logit)
            else:
                # An alternative is c = 0 iff a = -1, b = 1
                c = 1.0
                loss_encoder = 0.5 * (synth_logit - c)**2

            if pointnet:
                regularization_loss = config['feature_regularization_coef'] * \
                                      feature_transform_regularization(feature_transform).mean()
                loss_all = loss_reconstruction + loss_encoder + regularization_loss
            else:
                loss_all = loss_reconstruction + loss_encoder

            e_hn_optimizer.zero_grad()
            encoder.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_reconstruction += loss_reconstruction.item()
            total_loss_encoder += loss_encoder.item()
            total_loss_all += loss_all.item()

            if pointnet:
                total_loss_regularization += regularization_loss.item()

        log.info(
            f'[{epoch}/{config["max_epochs"]}] '
            f'Total_Loss: {total_loss_all / i:.4f} '
            f'Loss_R: {total_loss_reconstruction / i:.4f} '
            f'Loss_E: {total_loss_encoder / i:.4f} '
            f'Loss_D: {total_loss_discriminator / i:.4f} '
            f'Time: {datetime.now() - start_epoch_time}'
        )

        if pointnet:
            log.info(f'Loss_Regularization: {total_loss_regularization / i:.4f}')

        losses_e.append(total_loss_reconstruction)
        losses_g.append(total_loss_encoder)
        losses_eg.append(total_loss_all)
        losses_d.append(total_loss_discriminator)

        #
        # Save intermediate results
        #
        if epoch % config['save_samples_frequency'] == 0:
            log.debug('Saving samples...')

            X = X.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2], in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

        if config['clean_weights_dir']:
            log.debug('Cleaning weights path: %s' % weights_path)
            shutil.rmtree(weights_path, ignore_errors=True)
            os.makedirs(weights_path, exist_ok=True)

        if epoch % config['save_weights_frequency'] == 0:
            log.debug('Saving weights and losses...')

            torch.save(hyper_network.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))
            torch.save(e_hn_optimizer.state_dict(), join(weights_path, f'{epoch:05}_EGo.pth'))
            torch.save(discriminator.state_dict(), join(weights_path, f'{epoch:05}_D.pth'))
            torch.save(discriminator_optimizer.state_dict(), join(weights_path, f'{epoch:05}_Do.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_G'), np.array(losses_g))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
            np.save(join(metrics_path, f'{epoch:05}_D'), np.array(losses_d))
コード例 #11
0
def main(config):
    """Train a point-cloud autoencoder (``Generator`` o ``Encoder``) end to end.

    Seeds the Python and torch RNGs, prepares the results directory (writing
    ``config.json`` on the first run), builds a ShapeNet dataloader and the
    ``Generator``/``Encoder`` pair from ``model.architectures.<config['arch']>``,
    and optimises both jointly on the configured reconstruction loss. If saved
    weights exist, training resumes after the latest saved epoch.

    After every epoch, up to five input / reconstruction pairs from the last
    batch are plotted into ``<results_dir>/samples``. Model and optimizer
    states are written every ``config['save_frequency']`` epochs.

    Args:
        config: Parsed experiment configuration dictionary.

    Raises:
        ValueError: If ``config['dataset']`` or ``config['reconstruction_loss']``
            names an unsupported option, or if the dataset is too small to
            form a single batch.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')
    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True)
    # With drop_last=True a dataset smaller than one batch yields no batches.
    # The epoch summary and sample plots below rely on the last batch (`i`,
    # `X`), so fail fast with a clear message instead of a NameError.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Dataset has {len(dataset)} samples, fewer than '
            f'batch_size={config["batch_size"]}; no training batch can be '
            f'formed.')

    #
    # Models
    #
    arch = import_module(f"model.architectures.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_E.pth')))

        EG_optim.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        G.train()
        E.train()

        total_loss = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            log.debug('-' * 20)

            X, _ = point_data
            X = X.to(device)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)

            X_rec = G(E(X))

            loss = torch.mean(config['reconstruction_coef'] *
                              reconstruction_loss(
                                  X.permute(0, 2, 1) + 0.5,
                                  X_rec.permute(0, 2, 1) + 0.5))

            EG_optim.zero_grad()
            E.zero_grad()
            G.zero_grad()

            loss.backward()
            total_loss += loss.item()
            EG_optim.step()

            log.debug(f'[{epoch}: ({i})] '
                      f'Loss: {loss.item():.4f} '
                      f'Time: {datetime.now() - start_epoch_time}')

        log.debug(f'[{epoch}/{config["max_epochs"]}] '
                  f'Loss: {total_loss / i:.4f} '
                  f'Time: {datetime.now() - start_epoch_time}')

        #
        # Save intermediate results
        #
        G.eval()
        E.eval()
        with torch.no_grad():
            X_rec = G(E(X)).data.cpu().numpy()
        # X still lives on `device`; the plotting helper needs host arrays.
        X_real = X.cpu().numpy()
        # Never index past the batch when batch_size < 5.
        n_samples = min(5, X_real.shape[0])

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_real[k][0],
                                      X_real[k][1],
                                      X_real[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples', f'{epoch:05}_{k}_real.png'))
            plt.close(fig)

        for k in range(n_samples):
            fig = plot_3d_point_cloud(X_rec[k][0],
                                      X_rec[k][1],
                                      X_rec[k][2],
                                      in_u_sphere=True,
                                      show=False,
                                      title=str(epoch))
            fig.savefig(
                join(results_dir, 'samples',
                     f'{epoch:05}_{k}_reconstructed.png'))
            plt.close(fig)

        if epoch % config['save_frequency'] == 0:
            torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

            torch.save(EG_optim.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))
コード例 #12
0
def fixed(hyper_network, device, results_dir, epoch, fixed_number, z_size,
          fixed_mean, fixed_std, x_shape, triangulation, method, depth):
    """Decode ``fixed_number`` random latent codes and save the results.

    Draws ``fixed_number`` codes from N(``fixed_mean``, ``fixed_std``) and
    passes them through ``hyper_network`` to get one set of target-network
    weights per code. Each target network is then applied to points from
    ``generate_points``, and the input points, the reconstruction (``.npy``)
    and a 3D plot (``.png``) are written to ``<results_dir>/fixed``.

    When ``triangulation`` is truthy, each target network is also evaluated
    on the vertices returned by ``utils.sphere_triangles.generate(method,
    depth)``. The accompanying triangulation is pickled next to those outputs.

    Args:
        hyper_network: Maps latent codes to target-network weights.
        device: Torch device the computation runs on.
        results_dir: Root output directory; files go to its ``fixed`` subdir.
        epoch: Forwarded to ``generate_points``.
        fixed_number: Number of latent codes to sample.
        z_size: Latent code dimensionality.
        fixed_mean: Mean of the sampling normal distribution.
        fixed_std: Standard deviation of the sampling normal distribution.
        x_shape: Shape used to size the generated input; ``generate_points``
            receives ``(x_shape[1], x_shape[0])``.
        triangulation: Whether to also produce triangulated reconstructions.
        method: Triangulation method passed to ``generate``.
        depth: Triangulation depth passed to ``generate``.
    """
    log.info("Fixed")

    fixed_noise = torch.zeros(fixed_number,
                              z_size).normal_(mean=fixed_mean,
                                              std=fixed_std).to(device)
    weights_fixed = hyper_network(fixed_noise)

    for j, weights in enumerate(weights_fixed):
        target_network = aae.TargetNetwork(config, weights).to(device)

        target_network_input = generate_points(config=config,
                                               epoch=epoch,
                                               size=(x_shape[1], x_shape[0]))
        # detach() lets .numpy() work even when called outside torch.no_grad().
        fixed_rec = torch.transpose(
            target_network(target_network_input.to(device)), 0,
            1).detach().cpu().numpy()
        np.save(join(results_dir, 'fixed', f'{j}_target_network_input'),
                np.array(target_network_input))
        np.save(join(results_dir, 'fixed', f'{j}_fixed_reconstruction'),
                fixed_rec)

        fig = plot_3d_point_cloud(fixed_rec[0],
                                  fixed_rec[1],
                                  fixed_rec[2],
                                  in_u_sphere=True,
                                  show=False)
        fig.savefig(join(results_dir, 'fixed', f'{j}_fixed_reconstructed.png'))
        plt.close(fig)

        if triangulation:
            from utils.sphere_triangles import generate

            # Use fresh names: rebinding `triangulation` here would overwrite
            # the caller's flag, so the `if` above would test the generated
            # mesh object on later iterations.
            triangulation_input, triangulation_mesh = generate(method, depth)

            with open(join(results_dir, 'fixed', f'{j}_triangulation.pickle'),
                      'wb') as triangulation_file:
                pickle.dump(triangulation_mesh, triangulation_file)

            fixed_rec = torch.transpose(
                target_network(triangulation_input.to(device)), 0,
                1).detach().cpu().numpy()
            np.save(
                join(results_dir, 'fixed',
                     f'{j}_target_network_input_triangulation'),
                np.array(triangulation_input))
            np.save(
                join(results_dir, 'fixed',
                     f'{j}_fixed_reconstruction_triangulation'), fixed_rec)

            fig = plot_3d_point_cloud(fixed_rec[0],
                                      fixed_rec[1],
                                      fixed_rec[2],
                                      in_u_sphere=True,
                                      show=False)
            fig.savefig(
                join(results_dir, 'fixed',
                     f'{j}_fixed_reconstructed_triangulation.png'))
            plt.close(fig)

        np.save(join(results_dir, 'fixed', f'{j}_fixed_noise'),
                np.array(fixed_noise[j].cpu()))
コード例 #13
0
def _sum_to_batch(loss):
    """Reduce a reconstruction loss to at most one dimension.

    Repeatedly sums over the last dimension until at most one dimension is
    left. A per-point loss of shape ``(B, N, ...)`` therefore becomes ``(B,)``.
    Losses that already have one or zero dimensions are returned unchanged.
    """
    while len(loss.shape) > 1:
        loss = loss.sum(-1)
    return loss


def main(config):
    """Train a Soft-IntroVAE on point clouds.

    If ``config['seed']`` is non-negative, the random, torch, CUDA and numpy
    RNGs are seeded and cuDNN is set to deterministic. The results directory
    is prepared (``config.json`` is written on the first run) and training
    resumes after the latest saved epoch when weights exist.

    Epochs numbered below ``config['num_vae']`` run a plain VAE step:
    reconstruction plus KL, with encoder and decoder updated together. Later
    epochs run the Soft-IntroVAE scheme: an encoder step that adds
    exponentiated-ELBO terms for reconstructions and prior samples, followed
    by a decoder step with the encoder frozen.

    After every epoch a 3x5 figure (inputs, reconstructions, samples) is
    saved to ``<results_dir>/samples``. Every ``config['valid_frequency']``
    epochs the validation JSD is computed, and a checkpoint is saved whenever
    it improves on the best seen so far. Model and optimizer states are saved
    every ``config['save_frequency']`` epochs.

    Args:
        config: Parsed experiment configuration dictionary.

    Raises:
        ValueError: If ``config['dataset']`` or ``config['reconstruction_loss']``
            names an unsupported option.
        SystemError: If the decoder loss becomes NaN.
    """
    if config['seed'] >= 0:
        random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        torch.cuda.manual_seed(config['seed'])
        np.random.seed(config['seed'])
        torch.backends.cudnn.deterministic = True
        print("random seed: ", config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger(__name__)
    logging.getLogger('matplotlib.font_manager').disabled = True

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    # load dataset
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet` or '
                         f'`faust`. Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)
    # Soft-IntroVAE losses below are multiplied by 1 / (3 * n_points).
    scale = 1 / (3 * config['n_points'])
    # hyper-parameters
    valid_frequency = config["valid_frequency"]
    num_vae = config["num_vae"]
    beta_rec = config["beta_rec"]
    beta_kl = config["beta_kl"]
    beta_neg = config["beta_neg"]
    gamma_r = config["gamma_r"]
    apply_random_rotation = "rotate" in config["transforms"]
    if apply_random_rotation:
        print("applying random rotation to input shapes")

    # model
    model = SoftIntroVAE(config).to(device)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.chamfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or '
                         f'`earth_mover`, got: {config["reconstruction_loss"]}')

    prior_std = config["prior_std"]
    prior_logvar = np.log(prior_std ** 2)
    print(f'prior: N(0, {prior_std ** 2:.3f})')

    # optimizers
    optimizer_e = getattr(optim, config['optimizer']['E']['type'])
    optimizer_e = optimizer_e(model.encoder.parameters(), **config['optimizer']['E']['hyperparams'])
    optimizer_d = getattr(optim, config['optimizer']['D']['type'])
    optimizer_d = optimizer_d(model.decoder.parameters(), **config['optimizer']['D']['hyperparams'])

    scheduler_e = optim.lr_scheduler.MultiStepLR(optimizer_e, milestones=[350, 450, 550], gamma=0.5)
    scheduler_d = optim.lr_scheduler.MultiStepLR(optimizer_d, milestones=[350, 450, 550], gamma=0.5)

    if starting_epoch > 1:
        model.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}.pth')))

        optimizer_e.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_optim_e.pth')))
        optimizer_d.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch - 1:05}_optim_d.pth')))

        # The schedulers are stepped once per finished epoch, so after epoch
        # `starting_epoch - 1` their counters must equal that value. Without
        # this, a resumed run would restart the count at 0 and shift the
        # remaining milestones by the number of epochs already trained.
        scheduler_e.last_epoch = starting_epoch - 1
        scheduler_d.last_epoch = starting_epoch - 1

    kls_real = []
    kls_fake = []
    kls_rec = []
    rec_errs = []
    exp_elbos_f = []
    exp_elbos_r = []
    diff_kls = []
    best_res = {"epoch": 0, "jsd": None}

    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()

        model.train()

        if epoch < num_vae:
            # plain VAE warm-up: encoder and decoder updated together
            total_loss = 0.0
            pbar = tqdm(iterable=points_dataloader)
            for i, point_data in enumerate(pbar, 1):
                x, _ = point_data
                x = x.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if x.size(-1) == 3:
                    x.transpose_(x.dim() - 2, x.dim() - 1)

                x_rec, mu, logvar = model(x)
                loss_rec = _sum_to_batch(
                    reconstruction_loss(x.permute(0, 2, 1) + 0.5, x_rec.permute(0, 2, 1) + 0.5)).mean()
                loss_kl = calc_kl(logvar, mu, logvar_o=prior_logvar, reduce="mean")
                loss = beta_rec * loss_rec + beta_kl * loss_kl

                optimizer_e.zero_grad()
                optimizer_d.zero_grad()
                loss.backward()
                total_loss += loss.item()
                optimizer_e.step()
                optimizer_d.step()

                pbar.set_description_str('epoch #{}'.format(epoch))
                pbar.set_postfix(r_loss=loss_rec.data.cpu().item(), kl=loss_kl.data.cpu().item())

        else:
            batch_kls_real = []
            batch_kls_fake = []
            batch_kls_rec = []
            batch_rec_errs = []
            batch_exp_elbo_f = []
            batch_exp_elbo_r = []
            batch_diff_kls = []
            pbar = tqdm(iterable=points_dataloader)
            for i, point_data in enumerate(pbar, 1):
                x, _ = point_data
                x = x.to(device)

                # random rotation
                if apply_random_rotation:
                    angle = torch.rand(size=(x.shape[0],)) * 180
                    rotate_transform = RotateAxisAngle(angle, axis="Z", device=device)
                    x = rotate_transform.transform_points(x)

                # change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if x.size(-1) == 3:
                    x.transpose_(x.dim() - 2, x.dim() - 1)

                noise_batch = prior_std * torch.randn(size=(config['batch_size'], model.zdim)).to(device)

                # ----- update E ----- #
                # only encoder parameters receive gradients in this step
                for param in model.encoder.parameters():
                    param.requires_grad = True
                for param in model.decoder.parameters():
                    param.requires_grad = False

                fake = model.sample(noise_batch)

                real_mu, real_logvar = model.encode(x)
                z = reparameterize(real_mu, real_logvar)
                x_rec = model.decoder(z)

                loss_rec = _sum_to_batch(
                    reconstruction_loss(x.permute(0, 2, 1) + 0.5, x_rec.permute(0, 2, 1) + 0.5)).mean()

                loss_real_kl = calc_kl(real_logvar, real_mu, logvar_o=prior_logvar, reduce="mean")

                rec_rec, rec_mu, rec_logvar = model(x_rec.detach())
                rec_fake, fake_mu, fake_logvar = model(fake.detach())

                kl_rec = calc_kl(rec_logvar, rec_mu, logvar_o=prior_logvar, reduce="none")
                kl_fake = calc_kl(fake_logvar, fake_mu, logvar_o=prior_logvar, reduce="none")

                # per-sample losses (not averaged): they enter exp() below
                loss_rec_rec_e = _sum_to_batch(
                    reconstruction_loss(x_rec.detach().permute(0, 2, 1) + 0.5,
                                        rec_rec.permute(0, 2, 1) + 0.5))
                loss_rec_fake_e = _sum_to_batch(
                    reconstruction_loss(fake.permute(0, 2, 1) + 0.5, rec_fake.permute(0, 2, 1) + 0.5))

                expelbo_rec = (-2 * scale * (beta_rec * loss_rec_rec_e + beta_neg * kl_rec)).exp().mean()
                expelbo_fake = (-2 * scale * (beta_rec * loss_rec_fake_e + beta_neg * kl_fake)).exp().mean()

                loss_margin = scale * beta_kl * loss_real_kl + 0.25 * (expelbo_rec + expelbo_fake)

                lossE = scale * beta_rec * loss_rec + loss_margin
                optimizer_e.zero_grad()
                lossE.backward()
                optimizer_e.step()

                # ----- update D ----- #
                # only decoder parameters receive gradients in this step
                for param in model.encoder.parameters():
                    param.requires_grad = False
                for param in model.decoder.parameters():
                    param.requires_grad = True

                fake = model.sample(noise_batch)
                with torch.no_grad():
                    z = reparameterize(real_mu.detach(), real_logvar.detach())
                rec = model.decoder(z.detach())
                loss_rec = _sum_to_batch(
                    reconstruction_loss(x.permute(0, 2, 1) + 0.5, rec.permute(0, 2, 1) + 0.5)).mean()

                rec_mu, rec_logvar = model.encode(rec)
                z_rec = reparameterize(rec_mu, rec_logvar)

                fake_mu, fake_logvar = model.encode(fake)
                z_fake = reparameterize(fake_mu, fake_logvar)

                rec_rec = model.decode(z_rec.detach())
                rec_fake = model.decode(z_fake.detach())

                # Each loss is reduced from its own value. The previous
                # copy-pasted loops reassigned `loss_rec.sum(-1)` here, which
                # replaced these terms with the plain reconstruction loss.
                loss_rec_rec = _sum_to_batch(
                    reconstruction_loss(rec.detach().permute(0, 2, 1) + 0.5,
                                        rec_rec.permute(0, 2, 1) + 0.5)).mean()
                loss_fake_rec = _sum_to_batch(
                    reconstruction_loss(fake.detach().permute(0, 2, 1) + 0.5,
                                        rec_fake.permute(0, 2, 1) + 0.5)).mean()

                lossD_rec_kl = calc_kl(rec_logvar, rec_mu, logvar_o=prior_logvar, reduce="mean")
                lossD_fake_kl = calc_kl(fake_logvar, fake_mu, logvar_o=prior_logvar, reduce="mean")

                lossD = scale * (loss_rec * beta_rec + (
                        lossD_rec_kl + lossD_fake_kl) * 0.5 * beta_kl + gamma_r * 0.5 * beta_rec * (
                                         loss_rec_rec + loss_fake_rec))

                optimizer_d.zero_grad()
                lossD.backward()
                optimizer_d.step()

                if torch.isnan(lossD):
                    raise SystemError("loss is Nan")

                diff_kl = -loss_real_kl.data.cpu() + lossD_fake_kl.data.cpu()
                batch_diff_kls.append(diff_kl)
                batch_kls_real.append(loss_real_kl.data.cpu().item())
                batch_kls_fake.append(lossD_fake_kl.cpu().item())
                batch_kls_rec.append(lossD_rec_kl.data.cpu().item())
                batch_rec_errs.append(loss_rec.data.cpu().item())
                batch_exp_elbo_f.append(expelbo_fake.data.cpu())
                batch_exp_elbo_r.append(expelbo_rec.data.cpu())

                pbar.set_description_str('epoch #{}'.format(epoch))
                pbar.set_postfix(r_loss=loss_rec.data.cpu().item(), kl=loss_real_kl.data.cpu().item(),
                                 diff_kl=diff_kl.item(), expelbo_f=expelbo_fake.cpu().item())

        pbar.close()
        scheduler_e.step()
        scheduler_d.step()
        if epoch > num_vae - 1:
            kls_real.append(np.mean(batch_kls_real))
            kls_fake.append(np.mean(batch_kls_fake))
            kls_rec.append(np.mean(batch_kls_rec))
            rec_errs.append(np.mean(batch_rec_errs))
            exp_elbos_f.append(np.mean(batch_exp_elbo_f))
            exp_elbos_r.append(np.mean(batch_exp_elbo_r))
            diff_kls.append(np.mean(batch_diff_kls))
            # epoch summary
            print('#' * 50)
            print(f'Epoch {epoch} Summary:')
            print(f'beta_rec: {beta_rec}, beta_kl: {beta_kl}, beta_neg: {beta_neg}')
            print(
                f'rec: {rec_errs[-1]:.3f}, kl: {kls_real[-1]:.3f}, kl_fake: {kls_fake[-1]:.3f}, kl_rec: {kls_rec[-1]:.3f}')
            print(
                f'diff_kl: {diff_kls[-1]:.3f}, exp_elbo_f: {exp_elbos_f[-1]:.4e}, exp_elbo_r: {exp_elbos_r[-1]:.4e}')
            if best_res['jsd'] is not None:
                print(f'best jsd: {best_res["jsd"]}, epoch: {best_res["epoch"]}')
            print(f'time: {datetime.now() - start_epoch_time}')
            print('#' * 50)
        # save intermediate results
        model.eval()
        with torch.no_grad():
            noise_batch = prior_std * torch.randn(size=(5, model.zdim)).to(device)
            fake = model.sample(noise_batch).data.cpu().numpy()
            x_rec, _, _ = model(x, deterministic=True)
            x_rec = x_rec.data.cpu().numpy()

        fig = plt.figure(dpi=350)
        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 1, projection='3d')
            ax = plot_3d_point_cloud(x[k][0].data.cpu().numpy(), x[k][1].data.cpu().numpy(),
                                     x[k][2].data.cpu().numpy(),
                                     in_u_sphere=True, show=False, axis=ax, show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 6, projection='3d')
            ax = plot_3d_point_cloud(x_rec[k][0],
                                     x_rec[k][1],
                                     x_rec[k][2],
                                     in_u_sphere=True, show=False, axis=ax, show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        for k in range(5):
            ax = fig.add_subplot(3, 5, k + 11, projection='3d')
            ax = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                     in_u_sphere=True, show=False, axis=ax, show_axis=True, s=4, color='dodgerblue')
            remove_ticks_from_ax(ax)

        fig.savefig(join(results_dir, 'samples', f'figure_{epoch}'))
        plt.close(fig)

        if epoch % valid_frequency == 0:
            print("calculating valid jsd...")
            model.eval()
            with torch.no_grad():
                jsd = calc_jsd_valid(model, config, prior_std=prior_std)
            print(f'epoch: {epoch}, jsd: {jsd:.4f}')
            if best_res['jsd'] is None:
                best_res['jsd'] = jsd
                best_res['epoch'] = epoch
            elif best_res['jsd'] > jsd:
                print(f'epoch: {epoch}: best jsd updated: {best_res["jsd"]} -> {jsd}')
                best_res['jsd'] = jsd
                best_res['epoch'] = epoch
                # save
                torch.save(model.state_dict(), join(weights_path, f'{epoch:05}_jsd_{jsd:.4f}.pth'))

        if epoch % config['save_frequency'] == 0:
            torch.save(model.state_dict(), join(weights_path, f'{epoch:05}.pth'))
            torch.save(optimizer_e.state_dict(),
                       join(weights_path, f'{epoch:05}_optim_e.pth'))
            torch.save(optimizer_d.state_dict(),
                       join(weights_path, f'{epoch:05}_optim_d.pth'))
コード例 #14
0
ファイル: train_vae.py プロジェクト: luke9642/master-thesis
def main(config):
    """Train an encoder/generator point-cloud model described by ``config``.

    The run resumes from the latest epoch found in the results directory
    (generator, encoder and optimizer state are reloaded), trains up to
    ``config['max_epochs']``, and after every epoch plots a few real, fixed-
    noise and reconstructed samples. Every ``config['save_frequency']`` epochs
    it also writes the loss history (JSON + plot) and the weights.

    Args:
        config: Experiment configuration dictionary (seed, dataset, model
            architecture, reconstruction loss, optimizer, epoch limits, ...).

    Raises:
        ValueError: If the dataset or reconstruction loss is not supported,
            or if the dataloader would yield no batches.
    """
    random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    torch.cuda.manual_seed_all(config['seed'])

    results_dir = prepare_results_dir(config)
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    else:
        # Only ShapeNet is handled by this script.
        raise ValueError(f'Invalid dataset name. Expected `shapenet`. '
                         f'Got: `{dataset_name}`')

    log.debug("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset, batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True, pin_memory=True)
    # With drop_last=True a dataset smaller than one batch yields no batches;
    # fail clearly here instead of with a NameError on `i` / `X` later.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Dataset has {len(dataset)} samples, fewer than batch_size='
            f'{config["batch_size"]}; no full batch can be formed.')

    #
    # Models
    #
    arch = import_module(f"models.{config['arch']}")
    G = arch.Generator(config).to(device)
    E = arch.Encoder(config).to(device)

    G.apply(weights_init)
    E.apply(weights_init)

    loss_name = config['reconstruction_loss'].lower()
    if loss_name == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif loss_name == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    elif loss_name == 'cramer_wold':
        from losses.cramer_wold import CWSample
        reconstruction_loss = CWSample().to(device)
    else:
        raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer`, '
                         f'`earth_mover` or `cramer_wold`, '
                         f'got: {config["reconstruction_loss"]}')

    #
    # Float Tensors
    #
    # Latent codes drawn once and reused every epoch, so generated samples
    # are comparable across epochs.
    fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1)
    fixed_noise.normal_(mean=0, std=0.2)
    fixed_noise = fixed_noise.to(device)

    #
    # Optimizers
    #
    EG_optim = getattr(optim, config['optimizer']['EG']['type'])
    EG_optim = EG_optim(chain(E.parameters(), G.parameters()),
                        **config['optimizer']['EG']['hyperparams'])

    if starting_epoch > 1:
        G.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_G.pth')))
        E.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_E.pth')))

        EG_optim.load_state_dict(torch.load(
            join(weights_path, f'{starting_epoch-1:05}_EGo.pth')))

    # Per-epoch mean losses of this run only. They are not restored on
    # resume, so losses.json covers only the epochs trained in this run.
    losses = []

    with trange(starting_epoch, config['max_epochs'] + 1) as t:
        for epoch in t:
            start_epoch_time = datetime.now()

            G.train()
            E.train()

            total_loss = 0.0
            losses_eg = []
            losses_e = []
            losses_kld = []

            for i, point_data in enumerate(points_dataloader, 1):
                X, _ = point_data
                X = X.to(device)

                # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
                if X.size(-1) == 3:
                    X.transpose_(X.dim() - 2, X.dim() - 1)

                codes, mu, logvar = E(X)
                X_rec = G(codes)

                loss_e = torch.mean(
                    reconstruction_loss(X.permute(0, 2, 1) + 0.5,
                                        X_rec.permute(0, 2, 1) + 0.5))

                # NOTE(review): `reconstruction_coef` scales the latent
                # regularizer (cw_distance on mu), not the reconstruction
                # term -- confirm this is intended.
                loss_kld = config['reconstruction_coef'] * cw_distance(mu)

                loss_eg = loss_e + loss_kld
                # EG_optim was built over all of E's and G's parameters, so
                # this clears every gradient that backward() populates.
                EG_optim.zero_grad()

                loss_eg.backward()
                total_loss += loss_eg.item()
                EG_optim.step()

                losses_e.append(loss_e.item())
                losses_kld.append(loss_kld.item())
                losses_eg.append(loss_eg.item())

                t.set_description(
                    f'[{epoch}: ({i})] '
                    f'Loss_EG: {loss_eg.item():.4f} '
                    f'(REC: {loss_e.item(): .4f}'
                    f' KLD: {loss_kld.item(): .4f})'
                    f' Time: {datetime.now() - start_epoch_time}'
                )

            t.set_description(
                f'[{epoch}/{config["max_epochs"]}] '
                f'Loss_G: {total_loss / i:.4f} '
                f'Time: {datetime.now() - start_epoch_time}'
            )

            losses.append([
                np.mean(losses_e),
                np.mean(losses_kld),
                np.mean(losses_eg)
            ])

            #
            # Save intermediate results (uses the last batch of the epoch)
            #
            G.eval()
            E.eval()
            with torch.no_grad():
                fake = G(fixed_noise).data.cpu().numpy()
                codes, _, _ = E(X)
                X_rec = G(codes).data.cpu().numpy()

            X_numpy = X.cpu().numpy()
            # Plot at most 5 samples; also safe when batch_size < 5.
            n_show = min(5, X_numpy.shape[0])
            for k in range(n_show):
                fig = plot_3d_point_cloud(X_numpy[k][0], X_numpy[k][1], X_numpy[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

            for k in range(n_show):
                fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2],
                                          in_u_sphere=True, show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', 'fixed', f'{epoch:05}_{k}_fixed.png'))
                plt.close(fig)

            for k in range(n_show):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True, show=False)
                fig.savefig(join(results_dir, 'samples',
                                 f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

            if epoch % config['save_frequency'] == 0:
                df = pd.DataFrame(losses, columns=['loss_e', 'loss_kld', 'loss_eg'])
                df.to_json(join(results_dir, 'losses', 'losses.json'))
                fig = df.plot.line().get_figure()
                # `k` is the last index of the plotting loops above; it stays
                # in the name so existing loss-plot filenames do not change.
                fig.savefig(join(results_dir, 'losses', f'{epoch:05}_{k}.png'))
                # pandas plots through pyplot; without closing, one figure
                # per save accumulates in memory for the whole run.
                plt.close(fig)

                torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth'))
                torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth'))

                torch.save(EG_optim.state_dict(),
                           join(weights_path, f'{epoch:05}_EGo.pth'))
コード例 #15
0
def sphere_triangles(encoder, hyper_network, device, x, results_dir, amount,
                     method, depth, start, end, transitions):
    """Reconstruct the first ``amount`` shapes on a triangulated sphere.

    Each shape is encoded, and ``hyper_network`` turns the code into
    target-network weights. The target network is evaluated on the points
    returned by ``generate(method, depth)``. The same points are then scaled
    by ``transitions`` evenly spaced coefficients in ``[start, end]`` (rounded
    to 3 decimals) and evaluated again.

    Real clouds, input points, reconstructions, pickled triangulations and
    PNG plots are written to ``results_dir/sphere_triangles``.
    """
    from utils.sphere_triangles import generate
    log.info("Sphere triangles")

    out_dir = join(results_dir, 'sphere_triangles')

    def evaluate(net, points):
        # Run the target network on `points`, swap the first two axes of the
        # output and return it as a numpy array.
        return torch.transpose(net(points.to(device)), 0, 1).cpu().numpy()

    def plot_to(cloud, file_name):
        # Scatter cloud[0], cloud[1], cloud[2], save the PNG and free the figure.
        figure = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                     in_u_sphere=True, show=False)
        figure.savefig(join(out_dir, file_name))
        plt.close(figure)

    subset = x[:amount]
    latent, _, _ = encoder(subset)
    target_weights = hyper_network(latent)
    real_clouds = subset.cpu().numpy()

    for idx in range(amount):
        net = aae.TargetNetwork(config, target_weights[idx])
        sphere_points, triangles = generate(method, depth)
        reconstruction = evaluate(net, sphere_points)

        np.save(join(out_dir, f'{idx}_real'), np.array(real_clouds[idx]))
        np.save(join(out_dir, f'{idx}_point_cloud'), np.array(sphere_points))
        np.save(join(out_dir, f'{idx}_reconstruction'),
                np.array(reconstruction))

        pickle_path = join(out_dir, f'{idx}_triangulation.pickle')
        with open(pickle_path, 'wb') as handle:
            pickle.dump(triangles, handle)

        plot_to(reconstruction, f'{idx}_reconstructed.png')

        for raw_coef in np.linspace(start, end, num=transitions):
            coef = round(raw_coef, 3)
            scaled_points = sphere_points * coef
            scaled_reconstruction = evaluate(net, scaled_points)

            np.save(join(out_dir, f'{idx}_point_cloud_coefficient_{coef}'),
                    np.array(scaled_points))
            np.save(join(out_dir, f'{idx}_reconstruction_coefficient_{coef}'),
                    scaled_reconstruction)

            plot_to(scaled_reconstruction, f'{idx}_{coef}_reconstructed.png')

        plot_to(real_clouds[idx], f'{idx}_real.png')
コード例 #16
0
def sphere_triangles_interpolation(encoder, hyper_network, device, x,
                                   results_dir, amount, method, depth,
                                   coefficient, transitions):
    """Interpolate between pairs of shapes in latent space on a sphere mesh.

    For ``pair`` in ``range(amount)``, samples ``x[2*pair]`` and
    ``x[2*pair + 1]`` are encoded without gradients. Their codes are linearly
    mixed for ``transitions`` evenly spaced alphas in [0, 1], and each mix is
    decoded by ``hyper_network`` into target-network weights. The network is
    evaluated on the points from ``generate(method, depth)`` and on those
    points scaled by ``coefficient``.

    Inputs, outputs, pickled triangulations and PNG plots are written to
    ``results_dir/sphere_triangles_interpolation``.
    """
    from utils.sphere_triangles import generate
    log.info("Sphere triangles interpolation")

    out_dir = join(results_dir, 'sphere_triangles_interpolation')

    def evaluate(net, points):
        # Run the target network on `points`, swap the first two axes of the
        # output and return it as a numpy array.
        return torch.transpose(net(points.to(device)), 0, 1).cpu().numpy()

    def plot_to(cloud, file_name):
        # Scatter cloud[0], cloud[1], cloud[2], save the PNG and free the figure.
        figure = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                     in_u_sphere=True, show=False)
        figure.savefig(join(out_dir, file_name))
        plt.close(figure)

    for pair in range(amount):
        source = x[None, 2 * pair, :, :]
        target = x[None, 2 * pair + 1, :, :]

        with torch.no_grad():
            z_source, _, _ = encoder(source)
            z_target, _, _ = encoder(target)

        for step, alpha in enumerate(np.linspace(0, 1, transitions)):
            # Linear interpolation in latent space, decoded into weights.
            z_mix = (1 - alpha) * z_source + alpha * z_target
            mixed_weights = hyper_network(z_mix)

            net = aae.TargetNetwork(config, mixed_weights[0])
            sphere_points, triangles = generate(method, depth)
            interpolation = evaluate(net, sphere_points)

            prefix = f'{pair}_{step}'
            np.save(join(out_dir, f'{prefix}_point_cloud'),
                    np.array(sphere_points))
            np.save(join(out_dir, f'{prefix}_interpolation'), interpolation)

            with open(join(out_dir, f'{prefix}_triangulation.pickle'),
                      'wb') as handle:
                pickle.dump(triangles, handle)

            plot_to(interpolation, f'{prefix}_interpolation.png')

            scaled_points = sphere_points * coefficient
            scaled_interpolation = evaluate(net, scaled_points)

            np.save(
                join(out_dir,
                     f'{prefix}_point_cloud_coefficient_{coefficient}'),
                np.array(scaled_points))
            np.save(
                join(out_dir,
                     f'{prefix}_interpolation_coefficient_{coefficient}'),
                scaled_interpolation)

            plot_to(scaled_interpolation,
                    f'{prefix}_{coefficient}_interpolation.png')
コード例 #17
0
def main(config):
    """Train pocket/visible encoders and a hypernetwork described by ``config``.

    Each batch provides a 'non-visible' part, a 'visible' part and the whole
    'cloud'.
      * The non-visible part goes through ``PocketEncoder``, which returns
        codes, mu and logvar.
      * The visible part goes through ``VisibleEncoder``.
      * The two outputs are concatenated and fed to ``HyperNetwork``, whose
        output rows are the weights of one ``TargetNetwork`` per sample.
      * Each target network maps generated input points to a reconstruction
        of the whole cloud.

    The loss is the reconstruction loss plus a KL term on (mu, logvar).
    Training resumes from the latest epoch in the results directory. Every
    ``config['save_frequency']`` epochs, sample plots, weights, optimizer
    state and loss histories are written.

    Args:
        config: Experiment configuration dictionary.

    Raises:
        ValueError: If the dataset or reconstruction loss is not supported,
            or if the dataloader would yield no batches.
    """
    set_seed(config['seed'])

    results_dir = prepare_results_dir(config, 'vae', 'training')
    starting_epoch = find_latest_epoch(results_dir) + 1

    if not exists(join(results_dir, 'config.json')):
        with open(join(results_dir, 'config.json'), mode='w') as f:
            json.dump(config, f)

    setup_logging(results_dir)
    log = logging.getLogger('vae')

    device = cuda_setup(config['cuda'], config['gpu'])
    log.info(f'Device variable: {device}')
    if device.type == 'cuda':
        log.info(f'Current CUDA device: {torch.cuda.current_device()}')

    weights_path = join(results_dir, 'weights')
    metrics_path = join(results_dir, 'metrics')

    #
    # Dataset
    #
    dataset_name = config['dataset'].lower()
    if dataset_name == 'shapenet':
        from datasets.shapenet import ShapeNetDataset
        dataset = ShapeNetDataset(root_dir=config['data_dir'],
                                  classes=config['classes'])
    elif dataset_name == 'custom':
        dataset = TxtDataset(root_dir=config['data_dir'],
                             classes=config['classes'],
                             config=config)
    elif dataset_name == 'benchmark':
        dataset = Benchmark(root_dir=config['data_dir'],
                            classes=config['classes'],
                            config=config)
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`custom` or `benchmark`. Got: `{dataset_name}`')

    log.info("Selected {} classes. Loaded {} samples.".format(
        'all' if not config['classes'] else ','.join(config['classes']),
        len(dataset)))

    points_dataloader = DataLoader(dataset,
                                   batch_size=config['batch_size'],
                                   shuffle=config['shuffle'],
                                   num_workers=config['num_workers'],
                                   drop_last=True,
                                   pin_memory=True,
                                   collate_fn=collate_fn)
    # With drop_last=True a dataset smaller than one batch yields no batches;
    # fail clearly here instead of with a NameError on `i` / `X` later.
    if len(points_dataloader) == 0:
        raise ValueError(
            f'Dataset has {len(dataset)} samples, fewer than batch_size='
            f'{config["batch_size"]}; no full batch can be formed.')

    #
    # Models
    #
    hyper_network = aae.HyperNetwork(config, device).to(device)
    encoder_pocket = aae.PocketEncoder(config).to(device)
    encoder_visible = aae.VisibleEncoder(config).to(device)

    hyper_network.apply(weights_init)
    encoder_pocket.apply(weights_init)
    encoder_visible.apply(weights_init)

    if config['reconstruction_loss'].lower() == 'chamfer':
        from losses.champfer_loss import ChamferLoss
        reconstruction_loss = ChamferLoss().to(device)
    elif config['reconstruction_loss'].lower() == 'earth_mover':
        from losses.earth_mover_distance import EMD
        reconstruction_loss = EMD().to(device)
    else:
        raise ValueError(
            f'Invalid reconstruction loss. Accepted `chamfer` or '
            f'`earth_mover`, got: {config["reconstruction_loss"]}')

    #
    # Optimizers
    #
    e_hn_optimizer = getattr(optim, config['optimizer']['E_HN']['type'])
    e_hn_optimizer = e_hn_optimizer(
        chain(encoder_visible.parameters(), encoder_pocket.parameters(),
              hyper_network.parameters()),
        **config['optimizer']['E_HN']['hyperparams'])

    log.info("Starting epoch: %s" % starting_epoch)
    if starting_epoch > 1:
        log.info("Loading weights...")
        hyper_network.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_G.pth')))
        encoder_pocket.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EP.pth')))
        encoder_visible.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EV.pth')))

        e_hn_optimizer.load_state_dict(
            torch.load(join(weights_path, f'{starting_epoch - 1:05}_EGo.pth')))

        log.info("Loading losses...")
        losses_e = np.load(join(metrics_path,
                                f'{starting_epoch - 1:05}_E.npy')).tolist()
        losses_kld = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_KLD.npy')).tolist()
        losses_eg = np.load(
            join(metrics_path, f'{starting_epoch - 1:05}_EG.npy')).tolist()
    else:
        log.info("First epoch")
        losses_e = []
        losses_kld = []
        losses_eg = []

    if config['target_network_input']['normalization']['enable']:
        normalization_type = config['target_network_input']['normalization'][
            'type']
        assert normalization_type == 'progressive', 'Invalid normalization type'

    # Reused across samples/epochs when config['target_network_input']['constant'] is set.
    target_network_input = None
    for epoch in range(starting_epoch, config['max_epochs'] + 1):
        start_epoch_time = datetime.now()
        log.debug("Epoch: %s" % epoch)
        hyper_network.train()
        encoder_visible.train()
        encoder_pocket.train()

        total_loss_all = 0.0
        total_loss_r = 0.0
        total_loss_kld = 0.0
        for i, point_data in enumerate(points_dataloader, 1):
            # get the non-visible part of the point cloud
            X = point_data['non-visible']
            X = X.to(device, dtype=torch.float)

            # get the visible part of the point cloud
            X_visible = point_data['visible']
            X_visible = X_visible.to(device, dtype=torch.float)

            # get whole point cloud
            X_whole = point_data['cloud']
            X_whole = X_whole.to(device, dtype=torch.float)

            # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS]
            if X.size(-1) == 3:
                X.transpose_(X.dim() - 2, X.dim() - 1)
                X_visible.transpose_(X_visible.dim() - 2, X_visible.dim() - 1)
                X_whole.transpose_(X_whole.dim() - 2, X_whole.dim() - 1)

            codes, mu, logvar = encoder_pocket(X)
            mu_visible = encoder_visible(X_visible)

            target_networks_weights = hyper_network(
                torch.cat((codes, mu_visible), 1))

            X_rec = torch.zeros(X_whole.shape, device=device)
            for j, target_network_weights in enumerate(
                    target_networks_weights):
                target_network = aae.TargetNetwork(
                    config, target_network_weights).to(device)

                if not config['target_network_input'][
                        'constant'] or target_network_input is None:
                    target_network_input = generate_points(
                        config=config,
                        epoch=epoch,
                        size=(X_whole.shape[2], X_whole.shape[1]))

                X_rec[j] = torch.transpose(
                    target_network(target_network_input.to(device)), 0, 1)

            loss_r = torch.mean(config['reconstruction_coef'] *
                                reconstruction_loss(
                                    X_whole.permute(0, 2, 1) + 0.5,
                                    X_rec.permute(0, 2, 1) + 0.5))

            # KL divergence of N(mu, exp(logvar)) from N(0, 1), summed.
            loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 -
                              logvar).sum()

            loss_all = loss_r + loss_kld
            e_hn_optimizer.zero_grad()
            encoder_visible.zero_grad()
            encoder_pocket.zero_grad()
            hyper_network.zero_grad()

            loss_all.backward()
            e_hn_optimizer.step()

            total_loss_r += loss_r.item()
            total_loss_kld += loss_kld.item()
            total_loss_all += loss_all.item()

        # 'Loss_E' in this message reports the KLD term.
        log.info(f'[{epoch}/{config["max_epochs"]}] '
                 f'Loss_ALL: {total_loss_all / i:.4f} '
                 f'Loss_R: {total_loss_r / i:.4f} '
                 f'Loss_E: {total_loss_kld / i:.4f} '
                 f'Time: {datetime.now() - start_epoch_time}')

        # Per-epoch sums over batches (not means) are recorded.
        losses_e.append(total_loss_r)
        losses_kld.append(total_loss_kld)
        losses_eg.append(total_loss_all)

        if epoch % config['save_frequency'] == 0:
            #
            # Save intermediate results (from the last batch of the epoch)
            #
            X = X.cpu().numpy()
            X_whole = X_whole.cpu().numpy()
            X_rec = X_rec.detach().cpu().numpy()

            for k in range(min(5, X_rec.shape[0])):
                fig = plot_3d_point_cloud(X_rec[k][0],
                                          X_rec[k][1],
                                          X_rec[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples',
                         f'{epoch}_{k}_reconstructed.png'))
                plt.close(fig)

                fig = plot_3d_point_cloud(X_whole[k][0],
                                          X_whole[k][1],
                                          X_whole[k][2],
                                          in_u_sphere=True,
                                          show=False,
                                          title=str(epoch))
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_real.png'))
                plt.close(fig)

                # NOTE(review): X holds point_data['non-visible'] but is saved
                # as "_visible.png" -- confirm the intended label.
                fig = plot_3d_point_cloud(X[k][0],
                                          X[k][1],
                                          X[k][2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(results_dir, 'samples', f'{epoch}_{k}_visible.png'))
                plt.close(fig)

            if config['clean_weights_dir']:
                # Wipe old checkpoints only right before writing a new one.
                # Wiping every epoch left the directory empty between saves,
                # so an interrupted run had nothing to resume from.
                log.debug('Cleaning weights path: %s' % weights_path)
                shutil.rmtree(weights_path, ignore_errors=True)
                os.makedirs(weights_path, exist_ok=True)

            log.debug('Saving data...')

            torch.save(hyper_network.state_dict(),
                       join(weights_path, f'{epoch:05}_G.pth'))
            torch.save(encoder_visible.state_dict(),
                       join(weights_path, f'{epoch:05}_EV.pth'))
            torch.save(encoder_pocket.state_dict(),
                       join(weights_path, f'{epoch:05}_EP.pth'))
            torch.save(e_hn_optimizer.state_dict(),
                       join(weights_path, f'{epoch:05}_EGo.pth'))

            np.save(join(metrics_path, f'{epoch:05}_E'), np.array(losses_e))
            np.save(join(metrics_path, f'{epoch:05}_KLD'),
                    np.array(losses_kld))
            np.save(join(metrics_path, f'{epoch:05}_EG'), np.array(losses_eg))
コード例 #18
0
def fixed(hyper_network,
          encoder_visible,
          X_visible,
          device,
          results_dir,
          epoch,
          fixed_number,
          z_size,
          fixed_mean,
          fixed_std,
          x_shape,
          triangulation,
          method,
          depth,
          *,
          number_of_samples=10):
    """Sample completions from random noise conditioned on encoded visible parts.

    In each of ``number_of_samples`` rounds:
      * ``fixed_number`` noise rows of size ``z_size`` are drawn from
        N(fixed_mean, fixed_std) and concatenated (along dim 1) with
        ``encoder_visible(X_visible)``.
      * ``hyper_network`` turns every row into the weights of a
        ``TargetNetwork``.
      * Each target network is evaluated on points from ``generate_points``.
      * Inputs, reconstructions, plots and noise rows are written to
        ``results_dir/fixed``.

    Args:
        hyper_network: Maps concatenated (noise, visible encoding) rows to
            target-network weights.
        encoder_visible: Encoder applied to ``X_visible``.
        X_visible: Visible parts; ``X_visible[j][0..2]`` are plotted alongside
            reconstruction ``j``.
        device: Device for the noise and the target networks.
        results_dir: Root output directory.
        epoch: Passed through to ``generate_points``.
        fixed_number: Noise rows per round. ``torch.cat`` requires this to
            equal the number of rows in the encoder output.
        z_size: Noise dimensionality.
        fixed_mean: Mean of the normal noise.
        fixed_std: Standard deviation of the normal noise.
        x_shape: Target-network inputs have size ``(x_shape[1], x_shape[0])``.
        triangulation: If truthy, also evaluate every target network on the
            points from ``generate(method, depth)`` and pickle the
            triangulation.
        method: First argument passed to ``utils.sphere_triangles.generate``.
        depth: Second argument passed to ``utils.sphere_triangles.generate``.
        number_of_samples: Number of noise rounds.
    """
    log.info("Fixed")

    if triangulation:
        # Imported once here rather than inside the per-sample loop.
        from utils.sphere_triangles import generate

    encoder_output = encoder_visible(X_visible)
    # From here on X_visible is only used for plotting, so move it to the
    # CPU once instead of on every round.
    X_visible = X_visible.cpu()
    out_dir = join(results_dir, 'fixed')

    for it in range(number_of_samples):
        fixed_noise = torch.zeros(fixed_number,
                                  z_size).normal_(mean=fixed_mean,
                                                  std=fixed_std).to(device)
        weights_fixed = hyper_network(
            torch.cat((fixed_noise, encoder_output), 1))

        for j, weights in enumerate(weights_fixed):
            target_network = aae.TargetNetwork(config, weights).to(device)

            target_network_input = generate_points(config=config,
                                                   epoch=epoch,
                                                   size=(x_shape[1],
                                                         x_shape[0]))
            fixed_rec = torch.transpose(
                target_network(target_network_input.to(device)), 0,
                1).cpu().numpy()
            np.save(join(out_dir, f'{j}_target_network_input_{it}'),
                    np.array(target_network_input))
            np.save(join(out_dir, f'{j}_fixed_reconstruction_{it}'),
                    fixed_rec)

            # NOTE(review): this filename is not joined with results_dir;
            # where the file lands depends on pretty_plot -- confirm.
            pretty_plot(fixed_rec[0], fixed_rec[1], fixed_rec[2],
                        X_visible[j][0], X_visible[j][1], X_visible[j][2],
                        f'pretty{j}_fixed_reconstructed_{it}.png')
            fig = plot_3d_point_cloud(fixed_rec[0],
                                      fixed_rec[1],
                                      fixed_rec[2],
                                      in_u_sphere=True,
                                      show=False,
                                      x1=X_visible[j][0],
                                      y1=X_visible[j][1],
                                      z1=X_visible[j][2])

            fig.savefig(join(out_dir, f'{j}_fixed_reconstructed_{it}.png'))
            plt.close(fig)

            if triangulation:
                # Distinct local names on purpose. Rebinding `triangulation`
                # here would replace the caller's on/off flag with the
                # generated triangulation, and the check above would then
                # test that object for every later sample.
                sphere_points, sphere_triangulation = generate(method, depth)

                with open(join(out_dir, f'{j}_triangulation_{it}.pickle'),
                          'wb') as triangulation_file:
                    pickle.dump(sphere_triangulation, triangulation_file)

                sphere_rec = torch.transpose(
                    target_network(sphere_points.to(device)), 0,
                    1).cpu().numpy()
                np.save(
                    join(out_dir,
                         f'{j}_target_network_input_triangulation_{it}'),
                    np.array(sphere_points))
                np.save(
                    join(out_dir,
                         f'{j}_fixed_reconstruction_triangulation_{it}'),
                    sphere_rec)

                fig = plot_3d_point_cloud(sphere_rec[0],
                                          sphere_rec[1],
                                          sphere_rec[2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(out_dir,
                         f'{j}_fixed_reconstructed_triangulation_{it}.png'))
                plt.close(fig)

            np.save(join(out_dir, f'{j}_fixed_noise_{it}'),
                    np.array(fixed_noise[j].cpu()))
コード例 #19
0
def fixed(full_model: FullModel,
          device,
          datasets_dict,
          results_dir: str,
          epoch,
          amount=30,
          mean=0.0,
          std=0.015,
          noises_per_item=10,
          batch_size=8,
          save_plots=False,
          triangulation_config=None):
    """Save several noise-perturbed reconstructions of every item per category.

    ``results_dir/fixed`` is wiped and recreated first. For each category in
    ``datasets_dict``, the dataset is iterated in order (the DataLoader is
    not shuffled) in batches of ``batch_size``. For every batch,
    ``noises_per_item`` noise tensors ~ N(mean, std) of width
    ``full_model.get_noise_size()`` are drawn and passed to ``full_model``.

    Each reconstruction and each input item is saved as ``.npy``, plus a PNG
    when ``save_plots`` is set. Files are named by category and by the item's
    position in the dataset (``i * batch_size + k``).

    Args:
        full_model: Model called as
            ``full_model(existing, None, [B, 2048, 3], epoch, device, noise=...)``.
        device: Device for inputs and noise.
        datasets_dict: Mapping of category name to dataset. Each batch
            unpacks into four items, and the first is the model input.
        results_dir: Root output directory.
        epoch: Passed through to ``full_model``.
        amount: Currently unused.
        mean: Mean of the noise distribution.
        std: Standard deviation of the noise distribution.
        noises_per_item: Number of noise draws per batch.
        batch_size: DataLoader batch size; also used to derive file indices.
        save_plots: Also write PNG plots when True.
        triangulation_config: Currently unused. The default is now ``None``
            instead of a shared mutable dict; ``None`` corresponds to the
            previous default ``{'execute': False, 'method': 'edge',
            'depth': 2}``.
    """
    fixed_dir = join(results_dir, 'fixed')
    # Start from an empty directory so files from an earlier run don't mix in.
    shutil.rmtree(fixed_dir, ignore_errors=True)
    os.makedirs(fixed_dir, exist_ok=True)

    dataloaders_dict = {
        cat_name: DataLoader(cat_ds, pin_memory=True, batch_size=batch_size)
        for cat_name, cat_ds in datasets_dict.items()
    }
    for cat_name, dl in dataloaders_dict.items():

        for i, data in tqdm(enumerate(dl), total=len(dl)):

            existing, _, _, _ = data
            existing = existing.to(device)
            # Dataset position of the first item in this (unshuffled) batch.
            offset = i * batch_size

            for j in range(noises_per_item):
                fixed_noise = torch.zeros(existing.shape[0],
                                          full_model.get_noise_size()).normal_(
                                              mean=mean, std=std).to(device)
                reconstruction = full_model(existing,
                                            None, [existing.shape[0], 2048, 3],
                                            epoch,
                                            device,
                                            noise=fixed_noise).cpu()
                for k in range(reconstruction.shape[0]):
                    np.save(
                        join(fixed_dir,
                             f'{cat_name}_{offset + k}_{j}_reconstruction'),
                        reconstruction[k].numpy())
                    if save_plots:
                        fig = plot_3d_point_cloud(reconstruction[k][0],
                                                  reconstruction[k][1],
                                                  reconstruction[k][2],
                                                  in_u_sphere=True,
                                                  show=False)
                        fig.savefig(
                            join(fixed_dir,
                                 f'{cat_name}_{offset + k}_{j}_fixed_reconstructed.png'))
                        plt.close(fig)

            existing = existing.cpu()
            for k in range(existing.shape[0]):
                np.save(join(fixed_dir, f'{cat_name}_{offset + k}_existing'),
                        existing[k].numpy())
                if save_plots:
                    fig = plot_3d_point_cloud(existing[k][0],
                                              existing[k][1],
                                              existing[k][2],
                                              in_u_sphere=True,
                                              show=False)
                    fig.savefig(
                        join(fixed_dir,
                             f'{cat_name}_{offset + k}_existing.png'))
                    plt.close(fig)