def fixed(hyper_network,
          encoder_visible,
          X_visible,
          device,
          results_dir,
          epoch,
          fixed_number,
          z_size,
          fixed_mean,
          fixed_std,
          x_shape,
          triangulation,
          method,
          depth,
          *,
          number_of_samples=10):
    """Reconstruct ``X_visible`` from random noise plus its encoding; save all.

    ``encoder_visible(X_visible)`` is computed once. For each of
    ``number_of_samples`` rounds, noise of shape ``(fixed_number, z_size)``
    drawn from ``normal_(mean=fixed_mean, std=fixed_std)`` is concatenated with
    that encoding along dim 1 and passed to ``hyper_network``. Each resulting
    weight set builds an ``aae.TargetNetwork``, which is evaluated on points
    from ``generate_points``. The inputs, reconstructions, plots and noise rows
    are written under ``join(results_dir, 'fixed')``.

    Args:
        hyper_network: maps the concatenated (noise, encoding) tensor to an
            iterable of target-network weight sets.
        encoder_visible: applied once to ``X_visible``. Its output is
            concatenated with the noise along dim 1, so it must match
            ``fixed_number`` along dim 0.
        X_visible: batch whose ``X_visible[j][0]``, ``[1]`` and ``[2]`` are
            plotted as the reference x/y/z coordinates for sample ``j``.
        device: device the noise, target networks and their inputs are moved
            to.
        results_dir: root output directory; files go to its ``fixed``
            subdirectory.
        epoch: forwarded to ``generate_points``.
        fixed_number, z_size: shape of the noise tensor.
        fixed_mean, fixed_std: parameters of the normal noise.
        x_shape: ``(x_shape[1], x_shape[0])`` is passed as ``size`` to
            ``generate_points``.
        triangulation: if truthy, each target network is also evaluated on a
            sphere triangulation from ``utils.sphere_triangles.generate``, and
            that triangulation is pickled.
        method, depth: forwarded to ``generate`` when ``triangulation`` is set.
        number_of_samples: number of independent noise draws (keyword-only).
    """
    log.info("Fixed")

    out_dir = join(results_dir, 'fixed')
    encoder_output = encoder_visible(X_visible)
    # Only CPU copies are needed for plotting. Moving once, after encoding, is
    # equivalent to the former per-iteration ``.cpu()`` call.
    X_visible = X_visible.cpu()

    if triangulation:
        from utils.sphere_triangles import generate

    for it in range(number_of_samples):
        fixed_noise = torch.zeros(fixed_number,
                                  z_size).normal_(mean=fixed_mean,
                                                  std=fixed_std).to(device)
        weights_fixed = hyper_network(
            torch.cat((fixed_noise, encoder_output), 1))

        for j, weights in enumerate(weights_fixed):
            target_network = aae.TargetNetwork(config, weights).to(device)

            target_network_input = generate_points(config=config,
                                                   epoch=epoch,
                                                   size=(x_shape[1],
                                                         x_shape[0]))
            # Transposed so that rows 0-2 can be plotted as x/y/z below.
            fixed_rec = torch.transpose(
                target_network(target_network_input.to(device)), 0,
                1).cpu().numpy()
            np.save(join(out_dir, f'{j}_target_network_input_{it}'),
                    np.array(target_network_input))
            np.save(join(out_dir, f'{j}_fixed_reconstruction_{it}'),
                    fixed_rec)

            # NOTE(review): this filename is not joined with results_dir.
            # Confirm how pretty_plot resolves its output path.
            pretty_plot(fixed_rec[0], fixed_rec[1], fixed_rec[2],
                        X_visible[j][0], X_visible[j][1], X_visible[j][2],
                        f'pretty{j}_fixed_reconstructed_{it}.png')
            fig = plot_3d_point_cloud(fixed_rec[0],
                                      fixed_rec[1],
                                      fixed_rec[2],
                                      in_u_sphere=True,
                                      show=False,
                                      x1=X_visible[j][0],
                                      y1=X_visible[j][1],
                                      z1=X_visible[j][2])
            fig.savefig(join(out_dir, f'{j}_fixed_reconstructed_{it}.png'))
            plt.close(fig)

            if triangulation:
                # Store the generated triangulation under its own name. Rebinding
                # ``triangulation`` here would replace the caller's flag, and
                # later iterations would then test the generated object instead.
                sphere_input, sphere_triangulation = generate(method, depth)

                with open(join(out_dir, f'{j}_triangulation_{it}.pickle'),
                          'wb') as triangulation_file:
                    pickle.dump(sphere_triangulation, triangulation_file)

                sphere_rec = torch.transpose(
                    target_network(sphere_input.to(device)), 0,
                    1).cpu().numpy()
                np.save(
                    join(out_dir,
                         f'{j}_target_network_input_triangulation_{it}'),
                    np.array(sphere_input))
                np.save(
                    join(out_dir,
                         f'{j}_fixed_reconstruction_triangulation_{it}'),
                    sphere_rec)

                fig = plot_3d_point_cloud(sphere_rec[0],
                                          sphere_rec[1],
                                          sphere_rec[2],
                                          in_u_sphere=True,
                                          show=False)
                fig.savefig(
                    join(out_dir,
                         f'{j}_fixed_reconstructed_triangulation_{it}.png'))
                plt.close(fig)

            np.save(join(out_dir, f'{j}_fixed_noise_{it}'),
                    np.array(fixed_noise[j].cpu()))


def sphere_triangles(encoder, hyper_network, device, x, results_dir, amount,
                     method, depth, start, end, transitions):
    """Decode the first ``amount`` samples of ``x`` onto a triangulated sphere.

    The first element of the triple returned by ``encoder`` is fed to
    ``hyper_network``. Each resulting weight set builds an
    ``aae.TargetNetwork``, which is evaluated on the vertices returned by
    ``utils.sphere_triangles.generate(method, depth)``. Those vertices are then
    scaled by each of ``transitions`` evenly spaced coefficients in
    ``[start, end]`` (each rounded to 3 decimals) and evaluated again.

    The real clouds, inputs, reconstructions, pickled triangulations and plots
    are written under ``join(results_dir, 'sphere_triangles')``.
    """
    from utils.sphere_triangles import generate
    log.info("Sphere triangles")
    subset = x[:amount]

    codes, _, _ = encoder(subset)
    weights_sphere = hyper_network(codes)
    real_clouds = subset.cpu().numpy()
    out_dir = join(results_dir, 'sphere_triangles')

    def _run(net, points):
        # Evaluate ``net`` on ``points`` (moved to ``device``), swap dims 0
        # and 1, and return the result as a NumPy array.
        return torch.transpose(net(points.to(device)), 0, 1).cpu().numpy()

    def _plot(cloud, filename):
        # Plot rows 0-2 of ``cloud`` as x/y/z and save the figure in out_dir.
        figure = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                     in_u_sphere=True, show=False)
        figure.savefig(join(out_dir, filename))
        plt.close(figure)

    for idx in range(amount):
        net = aae.TargetNetwork(config, weights_sphere[idx])
        sphere_points, triangles = generate(method, depth)
        reconstruction = _run(net, sphere_points)

        np.save(join(out_dir, f'{idx}_real'), np.array(real_clouds[idx]))
        np.save(join(out_dir, f'{idx}_point_cloud'), np.array(sphere_points))
        np.save(join(out_dir, f'{idx}_reconstruction'),
                np.array(reconstruction))

        pickle_path = join(out_dir, f'{idx}_triangulation.pickle')
        with open(pickle_path, 'wb') as handle:
            pickle.dump(triangles, handle)

        _plot(reconstruction, f'{idx}_reconstructed.png')

        # Re-run the same network on uniformly scaled copies of the sphere.
        for raw_coefficient in np.linspace(start, end, num=transitions):
            coefficient = round(raw_coefficient, 3)
            scaled_points = sphere_points * coefficient
            scaled_reconstruction = _run(net, scaled_points)

            np.save(
                join(out_dir, f'{idx}_point_cloud_coefficient_{coefficient}'),
                np.array(scaled_points))
            np.save(
                join(out_dir,
                     f'{idx}_reconstruction_coefficient_{coefficient}'),
                scaled_reconstruction)

            _plot(scaled_reconstruction,
                  f'{idx}_{coefficient}_reconstructed.png')

        _plot(real_clouds[idx], f'{idx}_real.png')


def sphere_triangles_interpolation(encoder, hyper_network, device, x,
                                   results_dir, amount, method, depth,
                                   coefficient, transitions):
    """Interpolate between latent codes of sample pairs and decode on a sphere.

    For each ``pair`` in ``range(amount)``, samples ``2 * pair`` and
    ``2 * pair + 1`` of ``x`` are encoded under ``torch.no_grad()``. The first
    elements of the encoder's returned triples are then blended linearly at
    ``transitions`` evenly spaced weights in ``[0, 1]``.

    Each blend is decoded by ``hyper_network``, and its first weight set builds
    an ``aae.TargetNetwork``. That network is evaluated on a sphere
    triangulation from ``utils.sphere_triangles.generate(method, depth)``, and
    again on the same vertices scaled by ``coefficient``. Outputs go under
    ``join(results_dir, 'sphere_triangles_interpolation')``.
    """
    from utils.sphere_triangles import generate
    log.info("Sphere triangles interpolation")
    out_dir = join(results_dir, 'sphere_triangles_interpolation')

    def _run(net, points):
        # Evaluate ``net`` on ``points`` (moved to ``device``), swap dims 0
        # and 1, and return the result as a NumPy array.
        return torch.transpose(net(points.to(device)), 0, 1).cpu().numpy()

    def _plot(cloud, filename):
        # Plot rows 0-2 of ``cloud`` as x/y/z and save the figure in out_dir.
        figure = plot_3d_point_cloud(cloud[0], cloud[1], cloud[2],
                                     in_u_sphere=True, show=False)
        figure.savefig(join(out_dir, filename))
        plt.close(figure)

    for pair in range(amount):
        # ``None`` keeps a leading batch dimension of size 1.
        first = x[None, 2 * pair, :, :]
        second = x[None, 2 * pair + 1, :, :]

        with torch.no_grad():
            code_first, _, _ = encoder(first)
            code_second, _, _ = encoder(second)

        for step, alpha in enumerate(np.linspace(0, 1, transitions)):
            blended = (1 - alpha) * code_first + alpha * code_second
            weights = hyper_network(blended)

            net = aae.TargetNetwork(config, weights[0])
            sphere_points, triangles = generate(method, depth)
            decoded = _run(net, sphere_points)

            prefix = f'{pair}_{step}'
            np.save(join(out_dir, f'{prefix}_point_cloud'),
                    np.array(sphere_points))
            np.save(join(out_dir, f'{prefix}_interpolation'), decoded)

            pickle_path = join(out_dir, f'{prefix}_triangulation.pickle')
            with open(pickle_path, 'wb') as handle:
                pickle.dump(triangles, handle)

            _plot(decoded, f'{prefix}_interpolation.png')

            scaled_points = sphere_points * coefficient
            scaled_decoded = _run(net, scaled_points)

            np.save(
                join(out_dir,
                     f'{prefix}_point_cloud_coefficient_{coefficient}'),
                np.array(scaled_points))
            np.save(
                join(out_dir,
                     f'{prefix}_interpolation_coefficient_{coefficient}'),
                scaled_decoded)

            _plot(scaled_decoded, f'{prefix}_{coefficient}_interpolation.png')
# Example #4
# 0
def fixed(hyper_network, device, results_dir, epoch, fixed_number, z_size,
          fixed_mean, fixed_std, x_shape, triangulation, method, depth):
    """Generate shapes from Gaussian noise via the hypernetwork and save them.

    Noise of shape ``(fixed_number, z_size)`` drawn from
    ``normal_(mean=fixed_mean, std=fixed_std)`` is passed to ``hyper_network``.
    Each resulting weight set builds an ``aae.TargetNetwork``, which is
    evaluated on points from ``generate_points``. The inputs, reconstructions,
    plots and the noise row used are written under
    ``join(results_dir, 'fixed')``.

    Args:
        hyper_network: maps the noise tensor to an iterable of target-network
            weight sets.
        device: device the noise, target networks and their inputs are moved
            to.
        results_dir: root output directory; files go to its ``fixed``
            subdirectory.
        epoch: forwarded to ``generate_points``.
        fixed_number, z_size: shape of the noise tensor.
        fixed_mean, fixed_std: parameters of the normal noise.
        x_shape: ``(x_shape[1], x_shape[0])`` is passed as ``size`` to
            ``generate_points``.
        triangulation: if truthy, each target network is also evaluated on a
            sphere triangulation from ``utils.sphere_triangles.generate``, and
            that triangulation is pickled.
        method, depth: forwarded to ``generate`` when ``triangulation`` is set.
    """
    log.info("Fixed")

    out_dir = join(results_dir, 'fixed')
    fixed_noise = torch.zeros(fixed_number,
                              z_size).normal_(mean=fixed_mean,
                                              std=fixed_std).to(device)
    weights_fixed = hyper_network(fixed_noise)

    if triangulation:
        from utils.sphere_triangles import generate

    for j, weights in enumerate(weights_fixed):
        target_network = aae.TargetNetwork(config, weights).to(device)

        target_network_input = generate_points(config=config,
                                               epoch=epoch,
                                               size=(x_shape[1], x_shape[0]))
        # Transposed so that rows 0-2 can be plotted as x/y/z below.
        fixed_rec = torch.transpose(
            target_network(target_network_input.to(device)), 0,
            1).cpu().numpy()
        np.save(join(out_dir, f'{j}_target_network_input'),
                np.array(target_network_input))
        np.save(join(out_dir, f'{j}_fixed_reconstruction'), fixed_rec)

        fig = plot_3d_point_cloud(fixed_rec[0],
                                  fixed_rec[1],
                                  fixed_rec[2],
                                  in_u_sphere=True,
                                  show=False)
        fig.savefig(join(out_dir, f'{j}_fixed_reconstructed.png'))
        plt.close(fig)

        if triangulation:
            # Store the generated triangulation under its own name. Rebinding
            # ``triangulation`` here would replace the caller's flag, and later
            # iterations would then test the generated object instead.
            sphere_input, sphere_triangulation = generate(method, depth)

            with open(join(out_dir, f'{j}_triangulation.pickle'),
                      'wb') as triangulation_file:
                pickle.dump(sphere_triangulation, triangulation_file)

            sphere_rec = torch.transpose(
                target_network(sphere_input.to(device)), 0,
                1).cpu().numpy()
            np.save(join(out_dir, f'{j}_target_network_input_triangulation'),
                    np.array(sphere_input))
            np.save(join(out_dir, f'{j}_fixed_reconstruction_triangulation'),
                    sphere_rec)

            fig = plot_3d_point_cloud(sphere_rec[0],
                                      sphere_rec[1],
                                      sphere_rec[2],
                                      in_u_sphere=True,
                                      show=False)
            fig.savefig(
                join(out_dir, f'{j}_fixed_reconstructed_triangulation.png'))
            plt.close(fig)

        np.save(join(out_dir, f'{j}_fixed_noise'),
                np.array(fixed_noise[j].cpu()))