def main():
    try:
        os.mkdir(args.output_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(10, 5))

    num_views_per_scene = 4
    num_generation = 2  # reduced from 4 to 2 (the remaining 2 are used for the original output)
    num_original = 2
    total_frames_per_rotation = 24

    image_shape = (3, ) + hyperparams.image_size
    blank_image = make_uint8(np.full(image_shape, 0))
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                snapshot_array = []

                observed_image_array = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array = xp.zeros((num_views_per_scene, 7),
                                                    dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [0, 255]
                images, viewpoints, original_images = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = images / 255.0
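                # Add uniform noise in [0, 1/256) to dequantize the 8-bit
                # pixel values.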
                images += np.random.uniform(
                    0, 1.0 / 256.0, size=images.shape).astype(np.float32)

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                original_images = original_images.transpose(
                    (0, 1, 4, 2, 3)).astype(np.float32)
                original_images = original_images / 255.0
                original_images += np.random.uniform(
                    0, 1.0 / 256.0,
                    size=original_images.shape).astype(np.float32)

                batch_index = 0

                # Generate images without observations
                r = xp.zeros(
                    (num_generation, hyperparams.representation_channels) +
                    hyperparams.chrz_size,
                    dtype=np.float32)

                angle_rad = 0
                current_scene_original_images_cpu = original_images[
                    batch_index]
                current_scene_original_images = to_gpu(
                    current_scene_original_images_cpu)

                kl_div_sum = 0
                kl_div_list = np.zeros(
                    (1 + num_views_per_scene, total_frames_per_rotation))

                gqn.animator.Snapshot.make_graph(
                    id='kl_div_graph',
                    pos=7,
                    graph_type='plot',
                    frame_in_rotation=total_frames_per_rotation,
                    num_of_data_per_graph=num_views_per_scene + 1,
                    trivial_settings={
                        'colors': ['red', 'blue', 'green', 'orange', 'white'],
                        'markers': ['o', 'o', 'o', 'o', 'o']
                    })
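                # One KL-divergence series is recorded per observation count
                # (0 to num_views_per_scene views), all drawn in subplot
                # position 7 of each snapshot.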

                for t in range(total_frames_per_rotation):
                    snapshot = gqn.animator.Snapshot((2, 4))
                    # grid_master = GridSpec(nrows=2, ncols=4, height_ratios=[1, 1])
                    # snapshot = gqn.animator.Snapshot(layout_settings={
                    #     'subplot_count': 8,
                    #     'grid_master': grid_master,
                    #     'subplots': [
                    #         {
                    #             'subplot_id': i + 1,
                    #             'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[i//4, i%4])
                    #         }
                    #         for i in range(8)
                    #     ]
                    # })

                    for i in [1, 2, 5, 6]:
                        snapshot.add_media(media_type='image',
                                           media_data=make_uint8(blank_image),
                                           media_position=i)
                        if i == 1:
                            snapshot.add_title(text='Observed',
                                               target_media_pos=i)

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_images = model.generate_image(
                        query_viewpoints, r, xp)

                    kl_div = gqn.math.get_KL_div(
                        to_cpu(current_scene_original_images[t]),
                        to_cpu(generated_images[0]))

                    for i in [3]:
                        snapshot.add_media(media_type='image',
                                           media_data=make_uint8(
                                               generated_images[0]),
                                           media_position=i)
                        snapshot.add_title(text='Generated',
                                           target_media_pos=i)

                    for i in [4]:
                        snapshot.add_media(
                            media_type='image',
                            media_data=make_uint8(
                                current_scene_original_images[t]),
                            media_position=i)
                        snapshot.add_title(text='Original', target_media_pos=i)

                    gqn.animator.Snapshot.add_graph_data(
                        graph_id='kl_div_graph',
                        data_id='kl_div_data_0',
                        new_data=kl_div,
                        frame_num=t,
                    )

                    snapshot.add_title(text='KL Divergence',
                                       target_media_pos=7)

                    snapshot_array.append(snapshot)

                    angle_rad += 2 * math.pi / total_frames_per_rotation

                # Generate images with observations
                for m in range(num_views_per_scene):
                    kl_div_sum = 0
                    observed_image = images[batch_index, m]
                    observed_viewpoint = viewpoints[batch_index, m]

                    observed_image_array[m] = to_gpu(observed_image)
                    observed_viewpoint_array[m] = to_gpu(observed_viewpoint)

                    r = model.compute_observation_representation(
                        observed_image_array[None, :m + 1],
                        observed_viewpoint_array[None, :m + 1])

                    # Reuse the single scene representation for every image in
                    # the generation batch.
                    r = cf.broadcast_to(r, (num_generation, ) + r.shape[1:])

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        snapshot = gqn.animator.Snapshot((2, 4))
                        # grid_master = GridSpec(nrows=2, ncols=4, height_ratios=[1, 1])
                        # snapshot = gqn.animator.Snapshot(layout_settings={
                        #     'subplot_count': 8,
                        #     'grid_master': grid_master,
                        #     'subplots': [
                        #         {
                        #             'subplot_id': i + 1,
                        #             'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[i//4, i%4])
                        #         }
                        #         for i in range(8)
                        #     ]
                        # })

                        for i, observed_image in zip([1, 2, 5, 6],
                                                     observed_image_array):
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(observed_image),
                                media_position=i)
                            if i == 1:
                                snapshot.add_title(text='Observed',
                                                   target_media_pos=i)

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images = model.generate_image(
                            query_viewpoints, r, xp)

                        kl_div = gqn.math.get_KL_div(
                            to_cpu(current_scene_original_images[t]),
                            to_cpu(generated_images[0]))

                        for i in [3]:
                            snapshot.add_media(media_type='image',
                                               media_data=make_uint8(
                                                   generated_images[0]),
                                               media_position=i)
                            snapshot.add_title(text='Generated',
                                               target_media_pos=i)

                        for i in [4]:
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(
                                    current_scene_original_images[t]),
                                media_position=i)
                            snapshot.add_title(text='Original',
                                               target_media_pos=i)

                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='kl_div_graph',
                            data_id='kl_div_data_' + str(m + 1),
                            new_data=kl_div,
                            frame_num=t,
                        )

                        snapshot.add_title(text='KL Divergence',
                                           target_media_pos=7)

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        # plt.pause(1e-8)

                        snapshot_array.append(snapshot)

                plt.subplots_adjust(left=None,
                                    bottom=None,
                                    right=None,
                                    top=None,
                                    wspace=0,
                                    hspace=0)

                anim = animation.FuncAnimation(
                    fig,
                    func_anim_upate,
                    fargs=(fig, [snapshot_array]),
                    interval=1000 / 24,  # FuncAnimation's interval is in milliseconds
                    frames=(num_views_per_scene + 1) *
                    total_frames_per_rotation)

                anim.save("{}/shepard_matzler_{}.mp4".format(
                    args.output_directory, file_number),
                          writer="ffmpeg",
                          fps=12)

                if not os.path.exists("{}/shepard_matzler_{}".format(
                        args.output_directory, file_number)):
                    os.mkdir("{}/shepard_matzler_{}".format(
                        args.output_directory, file_number))

                picData = []
                for i in range(
                    (num_views_per_scene) * total_frames_per_rotation):
                    snapshot = snapshot_array[i + total_frames_per_rotation]
                    _media = snapshot.get_subplot(3)
                    media = _media['body']
                    picData.append(media['media_data'])
                    figu = plt.figure()
                    plt.axis('off')
                    plt.imshow(media['media_data'])
                    plt.savefig("{}/shepard_matzler_{}/{}.png".format(
                        args.output_directory, file_number, i))
                    plt.close(figu)

                bigfig = plt.figure(figsize=(20, 10))
                for i in range(num_views_per_scene):
                    for j in range(total_frames_per_rotation):
                        plt.subplot(num_views_per_scene,
                                    total_frames_per_rotation,
                                    (i * total_frames_per_rotation + j + 1))
                        plt.axis('off')
                        plt.imshow(picData[i * total_frames_per_rotation + j])
                plt.savefig("{}/shepard_matzler_{}_ALL.png".format(
                    args.output_directory, file_number))
                plt.close(bigfig)

                file_number += 1
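
# rotate_query_viewpoint() is imported from elsewhere in this repository and
# is not shown above. A minimal sketch of what it plausibly computes, assuming
# a camera orbiting the origin at a fixed radius and the 7-element viewpoint
# format (x, y, z, cos(yaw), sin(yaw), cos(pitch), sin(pitch)) that appears in
# Example #4 below; the name and default radius here are hypothetical:
def rotate_query_viewpoint_sketch(angle_rad, batch_size, xp, radius=3.0):
    eye = (radius * math.sin(angle_rad), 0.0, radius * math.cos(angle_rad))
    # Yaw points the camera back at the origin; pitch stays level.
    yaw = math.atan2(-eye[0], -eye[2])
    pitch = 0.0
    viewpoint = xp.array((eye[0], eye[1], eye[2],
                          math.cos(yaw), math.sin(yaw),
                          math.cos(pitch), math.sin(pitch)),
                         dtype=np.float32)
    # One identical query per generated image.
    return xp.broadcast_to(viewpoint, (batch_size, 7))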
Example #2
def main():
    try:
        os.mkdir(args.output_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    models = []
    snapshot_dirs = os.listdir(args.snapshots_dir_path)
    for snapshot_dir in snapshot_dirs:
        snapshot_path = os.path.join(args.snapshots_dir_path, snapshot_dir)
        hyperparams = HyperParameters(snapshot_directory=snapshot_path)
        models.append(Model(hyperparams, snapshot_directory=snapshot_path))

    if using_gpu:
        for model in models:
            model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(10, 5))

    num_views_per_scene = 4
    num_generation = 2  # reduced from 4 to 2 (the remaining 2 are used for the original output)
    num_original = 2
    total_frames_per_rotation = 24

    image_shape = (3, ) + hyperparams.image_size
    blank_image = make_uint8(np.full(image_shape, 0))
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                snapshot_array = []

                observed_image_array = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array = xp.zeros((num_views_per_scene, 7),
                                                    dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [0, 255]
                images, viewpoints, original_images = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = images / 255.0
                images += np.random.uniform(
                    0, 1.0 / 256.0, size=images.shape).astype(np.float32)

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                original_images = original_images.transpose(
                    (0, 1, 4, 2, 3)).astype(np.float32)
                original_images = original_images / 255.0
                original_images += np.random.uniform(
                    0, 1.0 / 256.0,
                    size=original_images.shape).astype(np.float32)

                batch_index = 0

                # Generate images without observations
                r = xp.zeros(
                    (num_generation, hyperparams.representation_channels) +
                    hyperparams.chrz_size,
                    dtype=np.float32)

                angle_rad = 0
                current_scene_original_images_cpu = original_images[
                    batch_index]
                current_scene_original_images = to_gpu(
                    current_scene_original_images_cpu)

                for i in range(10):  # assumes exactly ten model snapshots were loaded above
                    gqn.animator.Snapshot.make_graph(
                        id='kl_div_graph_' + str(i),
                        pos=5 + i,
                        graph_type='plot',
                        frame_in_rotation=total_frames_per_rotation,
                        num_of_data_per_graph=num_views_per_scene + 1,
                        trivial_settings={
                            'colors':
                            ['red', 'blue', 'green', 'orange', 'white'],
                            'markers': ['o', 'o', 'o', 'o', 'o'],
                            'noXTicks': True,
                            'noYTicks': True,
                        })

                for t in range(total_frames_per_rotation):
                    grid_master = GridSpec(nrows=4,
                                           ncols=5,
                                           height_ratios=[1, 1, 1, 1])
                    snapshot = gqn.animator.Snapshot(
                        layout_settings={
                            'subplot_count': 14,
                            'grid_master': grid_master,
                            'subplots': [{
                                'subplot_id': i + 1,
                                'subplot': GridSpecFromSubplotSpec(
                                    nrows=2,
                                    ncols=2,
                                    subplot_spec=grid_master[i * 2:i * 2 + 2,
                                                             0:1])
                            } for i in range(2)] + [{
                                'subplot_id': i + 3,
                                'subplot': GridSpecFromSubplotSpec(
                                    nrows=1,
                                    ncols=1,
                                    subplot_spec=grid_master[i, 2])
                            } for i in range(2)] + [{
                                'subplot_id': i + 5,
                                'subplot': GridSpecFromSubplotSpec(
                                    nrows=1,
                                    ncols=1,
                                    subplot_spec=grid_master[i // 2,
                                                             3 + i % 2])
                            } for i in range(4)] + [{
                                'subplot_id': i + 9,
                                'subplot': GridSpecFromSubplotSpec(
                                    nrows=1,
                                    ncols=1,
                                    subplot_spec=grid_master[2 + i // 3,
                                                             2 + i % 3])
                            } for i in range(6)]
                        })

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_images_list = []
                    for model in models:
                        generated_images_list.append(
                            model.generate_image(query_viewpoints, r, xp))

                    sq_d_list = []
                    for generated_images in generated_images_list:
                        sq_d_list.append(
                            gqn.math.get_squared_distance(
                                to_cpu(current_scene_original_images[t]),
                                to_cpu(generated_images[0]))[0])

                    snapshot.add_media(media_position=1,
                                       media_type='image',
                                       media_data=make_uint8(
                                           generated_images[0]))
                    snapshot.add_title(target_media_pos=1, text='Generated')

                    snapshot.add_media(media_position=2,
                                       media_type='image',
                                       media_data=make_uint8(
                                           current_scene_original_images[t]))
                    snapshot.add_title(target_media_pos=2, text='Original')

                    # snapshot.add_media(media_position=3  , media_type='image', media_data=make_uint8(current_scene_original_images[t]))
                    # snapshot.add_title(target_media_pos=3, text='Original')

                    snapshot.add_media(media_position=4,
                                       media_type='image',
                                       media_data=make_uint8(blank_image))
                    snapshot.add_title(target_media_pos=4, text='Observed')

                    for i in range(10):
                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='kl_div_graph_' + str(i),
                            data_id='kl_div_data_0',
                            new_data=sq_d_list[i],
                            frame_num=t,
                        )
                    print('snap')
                    snapshot_array.append(snapshot)

                    angle_rad += 2 * math.pi / total_frames_per_rotation

                # Generate images with observations
                for m in range(num_views_per_scene):
                    kl_div_sum = 0
                    observed_image = images[batch_index, m]
                    observed_viewpoint = viewpoints[batch_index, m]

                    observed_image_array[m] = to_gpu(observed_image)
                    observed_viewpoint_array[m] = to_gpu(observed_viewpoint)

                    r_list = []
                    for i, model in enumerate(models):
                        r_list.append(
                            model.compute_observation_representation(
                                observed_image_array[None, :m + 1],
                                observed_viewpoint_array[None, :m + 1]))

                        r_list[i] = cf.broadcast_to(r_list[i],
                                                    (num_generation, ) +
                                                    r_list[i].shape[1:])

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        grid_master = GridSpec(nrows=4,
                                               ncols=5,
                                               height_ratios=[1, 1, 1, 1])
                        snapshot = gqn.animator.Snapshot(
                            layout_settings={
                                'subplot_count': 14,
                                'grid_master': grid_master,
                                'subplots': [{
                                    'subplot_id': i + 1,
                                    'subplot': GridSpecFromSubplotSpec(
                                        nrows=2,
                                        ncols=2,
                                        subplot_spec=grid_master[i * 2:i * 2 +
                                                                 2, 0:1])
                                } for i in range(2)] + [{
                                    'subplot_id': i + 3,
                                    'subplot': GridSpecFromSubplotSpec(
                                        nrows=1,
                                        ncols=1,
                                        subplot_spec=grid_master[i, 2])
                                } for i in range(2)] + [{
                                    'subplot_id': i + 5,
                                    'subplot': GridSpecFromSubplotSpec(
                                        nrows=1,
                                        ncols=1,
                                        subplot_spec=grid_master[i // 2,
                                                                 3 + i % 2])
                                } for i in range(4)] + [{
                                    'subplot_id': i + 9,
                                    'subplot': GridSpecFromSubplotSpec(
                                        nrows=1,
                                        ncols=1,
                                        subplot_spec=grid_master[2 + i // 3,
                                                                 2 + i % 3])
                                } for i in range(6)]
                            })

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images_list = []
                        for model, r in zip(models, r_list):
                            generated_images_list.append(
                                model.generate_image(query_viewpoints, r, xp))

                        sq_d_list = []
                        for i, generated_images in enumerate(
                                generated_images_list):
                            sq_d_list.append(
                                gqn.math.get_squared_distance(
                                    to_cpu(current_scene_original_images[t]),
                                    to_cpu(generated_images[0]))[0])

                        snapshot.add_media(media_position=1,
                                           media_type='image',
                                           media_data=make_uint8(
                                               generated_images[0]))
                        snapshot.add_title(target_media_pos=1,
                                           text='Generated')

                        snapshot.add_media(
                            media_position=2,
                            media_type='image',
                            media_data=make_uint8(
                                current_scene_original_images[t]))
                        snapshot.add_title(target_media_pos=2, text='Original')

                        # snapshot.add_media(media_position=3  , media_type='image', media_data=make_uint8(current_scene_original_images[t]))
                        # snapshot.add_title(target_media_pos=3, text='Original')

                        snapshot.add_media(media_position=4,
                                           media_type='image',
                                           media_data=make_uint8(
                                               observed_image_array[m]))
                        snapshot.add_title(target_media_pos=4, text='Observed')

                        # for i, kl_div in enumerate(kl_div_list):
                        for i in range(10):
                            gqn.animator.Snapshot.add_graph_data(
                                graph_id='kl_div_graph_' + str(i),
                                data_id='kl_div_data_' + str(m + 1),
                                new_data=sq_d_list[i],
                                frame_num=t,
                            )

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        # plt.pause(1e-8)

                        print('snap')
                        snapshot_array.append(snapshot)

                plt.subplots_adjust(left=None,
                                    bottom=None,
                                    right=None,
                                    top=None,
                                    wspace=0,
                                    hspace=0)

                anim = animation.FuncAnimation(
                    fig,
                    func_anim_upate,
                    fargs=(fig, [snapshot_array]),
                    interval=1000 / 24,  # FuncAnimation's interval is in milliseconds
                    frames=(num_views_per_scene + 1) *
                    total_frames_per_rotation)

                anim.save("{}/shepard_matzler_{}.mp4".format(
                    args.output_directory, file_number),
                          writer="ffmpeg",
                          fps=12)
                file_number += 1
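
# gqn.math.get_squared_distance() is not shown above. Judging from the call
# sites, it takes two float images and returns a sequence whose first element
# is a scalar distance. A hedged sketch (the helper name below is
# hypothetical):
def get_squared_distance_sketch(image_a, image_b):
    diff = np.asarray(image_a, dtype=np.float32) - np.asarray(
        image_b, dtype=np.float32)
    # Sum of squared per-pixel differences, wrapped so that [0] indexing works.
    return (float(np.sum(diff * diff)), )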
Example #3
def main():
    try:
        os.mkdir(args.snapshot_directory)
    except FileExistsError:
        pass

    np.random.seed(0)

    xp = np
    device_gpu = args.gpu_device
    device_cpu = -1
    using_gpu = device_gpu >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_directory)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_channels = args.representation_channels
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_variance
    hyperparams.pixel_sigma_f = args.final_pixel_variance
    hyperparams.save(args.snapshot_directory)
    print(hyperparams)

    model = Model(hyperparams,
                  snapshot_directory=args.snapshot_directory,
                  optimized=args.optimized)
    if using_gpu:
        model.to_gpu()

    scheduler = Scheduler(sigma_start=args.initial_pixel_variance,
                          sigma_end=args.final_pixel_variance,
                          final_num_updates=args.pixel_n,
                          snapshot_directory=args.snapshot_directory)
    print(scheduler)

    optimizer = AdamOptimizer(model.parameters,
                              mu_i=args.initial_lr,
                              mu_f=args.final_lr,
                              initial_training_step=scheduler.num_updates)
    print(optimizer)

    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        scheduler.pixel_variance**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(scheduler.pixel_variance**2),
                           dtype="float32")

    representation_shape = (args.batch_size,
                            hyperparams.representation_channels,
                            args.image_size // 4, args.image_size // 4)

    fig = plt.figure(figsize=(9, 3))
    axis_data = fig.add_subplot(1, 3, 1)
    axis_data.set_title("Data")
    axis_data.axis("off")
    axis_reconstruction = fig.add_subplot(1, 3, 2)
    axis_reconstruction.set_title("Reconstruction")
    axis_reconstruction.axis("off")
    axis_generation = fig.add_subplot(1, 3, 3)
    axis_generation.set_title("Generation")
    axis_generation.axis("off")

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        mean_elbo = 0
        total_num_batch = 0
        start_time = time.time()

        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [0, 255] before preprocess_images
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

                total_views = images.shape[1]

                # Sample number of views
                num_views = random.choice(range(1, total_views + 1))
                observation_view_indices = list(range(total_views))
                random.shuffle(observation_view_indices)
                observation_view_indices = observation_view_indices[:num_views]
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    observation_images = preprocess_images(
                        images[:, observation_view_indices])
                    observation_query = viewpoints[:, observation_view_indices]
                    representation = model.compute_observation_representation(
                        observation_images, observation_query)
                else:
                    representation = xp.zeros(representation_shape,
                                              dtype="float32")
                    representation = chainer.Variable(representation)

                # Sample query
                query_index = random.choice(range(total_views))
                query_images = preprocess_images(images[:, query_index])
                query_viewpoints = viewpoints[:, query_index]

                # Transfer to gpu if necessary
                query_images = to_device(query_images, device_gpu)
                query_viewpoints = to_device(query_viewpoints, device_gpu)

                z_t_param_array, mean_x = model.sample_z_and_x_params_from_posterior(
                    query_images, query_viewpoints, representation)

                # Compute loss
                ## KL Divergence
                loss_kld = 0
                for params in z_t_param_array:
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                    kld = gqn.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                    loss_kld += cf.sum(kld)

                ## Negative log-likelihood of generated image
                loss_nll = cf.sum(
                    gqn.functions.gaussian_negative_log_likelihood(
                        query_images, mean_x, pixel_var, pixel_ln_var))

                # Calculate the average loss value
                loss_nll = loss_nll / args.batch_size
                loss_kld = loss_kld / args.batch_size

                loss = loss_nll / scheduler.pixel_variance + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                loss_nll = float(loss_nll.data) + math.log(256.0)
                loss_kld = float(loss_kld.data)

                elbo = -(loss_nll + loss_kld)

                loss_mse = float(
                    cf.mean_squared_error(query_images, mean_x).data)

                printr(
                    "Iteration {}: Subset {} / {}: Batch {} / {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.6e} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {}  "
                    .format(iteration + 1,
                            subset_index + 1, len(dataset), batch_index + 1,
                            len(iterator), elbo, loss_nll, loss_mse, loss_kld,
                            optimizer.learning_rate, scheduler.pixel_variance,
                            current_training_step))

                scheduler.step(iteration, current_training_step)
                pixel_var[...] = scheduler.pixel_variance**2
                pixel_ln_var[...] = math.log(scheduler.pixel_variance**2)

                total_num_batch += 1
                current_training_step += 1
                mean_kld += loss_kld
                mean_nll += loss_nll
                mean_mse += loss_mse
                mean_elbo += elbo

            model.serialize(args.snapshot_directory)

            # Visualize
            if args.with_visualization:
                axis_data.imshow(make_uint8(query_images[0]),
                                 interpolation="none")
                axis_reconstruction.imshow(make_uint8(mean_x.data[0]),
                                           interpolation="none")

                with chainer.no_backprop_mode():
                    generated_x = model.generate_image(
                        query_viewpoints[None, 0], representation[None, 0])
                    axis_generation.imshow(make_uint8(generated_x[0]),
                                           interpolation="none")
                plt.pause(1e-8)

        elapsed_time = time.time() - start_time
        print(
            "\033[2KIteration {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.6e} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {} - time: {:.3f} min"
            .format(iteration + 1, mean_elbo / total_num_batch,
                    mean_nll / total_num_batch, mean_mse / total_num_batch,
                    mean_kld / total_num_batch, optimizer.learning_rate,
                    scheduler.pixel_variance, current_training_step,
                    elapsed_time / 60))
        model.serialize(args.snapshot_directory)
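
# Scheduler anneals the output-pixel standard deviation from sigma_start down
# to sigma_end over final_num_updates steps, as in the GQN paper. A minimal
# sketch of such a schedule, assuming linear annealing (the function name and
# defaults here are hypothetical; the paper anneals 2.0 -> 0.7):
def annealed_pixel_sigma(step, sigma_i=2.0, sigma_f=0.7, n=200000):
    # Linear interpolation, clamped once the final step count is reached.
    ratio = min(step / n, 1.0)
    return sigma_i + (sigma_f - sigma_i) * ratio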
Example #4
def main():
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    dataset_mean = np.load(os.path.join(args.snapshot_path, "mean.npy"))
    dataset_std = np.load(os.path.join(args.snapshot_path, "std.npy"))

    # avoid division by zero
    dataset_std += 1e-12

    screen_size = hyperparams.image_size
    camera = gqn.three.PerspectiveCamera(
        eye=(3, 1, 0),
        center=(0, 0, 0),
        up=(0, 1, 0),
        fov_rad=math.pi / 2.0,
        aspect_ratio=screen_size[0] / screen_size[1],
        z_near=0.1,
        z_far=10)

    figure = gqn.imgplot.figure()
    axes_observations = []
    axes_generations = []
    sqrt_n = math.sqrt(args.num_views_per_scene)
    axis_width = 0.5 / sqrt_n
    axis_height = 1.0 / sqrt_n
    for n in range(args.num_views_per_scene):
        axis = gqn.imgplot.image()
        x = n % sqrt_n
        y = n // sqrt_n
        figure.add(axis, x * axis_width, y * axis_height, axis_width,
                   axis_height)
        axes_observations.append(axis)
    sqrt_n = math.sqrt(args.num_generation)
    axis_width = 0.5 / sqrt_n
    axis_height = 1.0 / sqrt_n
    for n in range(args.num_generation):
        axis = gqn.imgplot.image()
        x = n % sqrt_n
        y = n // sqrt_n
        figure.add(axis, x * axis_width + 0.5, y * axis_height, axis_width,
                   axis_height)
        axes_generations.append(axis)

    window = gqn.imgplot.window(figure, (1600, 800), "Dataset")
    window.show()

    raw_observed_image = np.zeros(screen_size + (3, ), dtype="uint32")
    renderer = gqn.three.Renderer(screen_size[0], screen_size[1])

    observed_images = xp.zeros(
        (args.num_views_per_scene, 3) + screen_size, dtype="float32")
    observed_viewpoints = xp.zeros(
        (args.num_views_per_scene, 7), dtype="float32")

    with chainer.no_backprop_mode():
        while True:
            if window.closed():
                exit()

            scene, _, _ = gqn.environment.room.build_scene(
                object_names=["cube", "sphere", "cone", "cylinder", "icosahedron"],
                num_objects=random.choice([x for x in range(1, 6)]))
            renderer.set_scene(scene)

            # Generate images without observations
            r = xp.zeros(
                (
                    args.num_generation,
                    hyperparams.channels_r,
                ) + hyperparams.chrz_size,
                dtype="float32")
            total_frames = 50
            for tick in range(total_frames):
                if window.closed():
                    exit()
                query_viewpoints = generate_random_query_viewpoint(
                    tick / total_frames, xp)
                generated_images = to_cpu(
                    model.generate_image(query_viewpoints, r, xp))

                for m in range(args.num_generation):
                    if window.closed():
                        exit()
                    image = make_uint8(generated_images[m], dataset_mean,
                                       dataset_std)
                    axis = axes_generations[m]
                    axis.update(image)

            for n in range(args.num_views_per_scene):
                if window.closed():
                    exit()
                eye = (random.uniform(-3, 3), 1, random.uniform(-3, 3))
                center = (random.uniform(-3, 3), random.uniform(0, 1),
                          random.uniform(-3, 3))
                yaw = gqn.math.yaw(eye, center)
                pitch = gqn.math.pitch(eye, center)
                camera.look_at(
                    eye=eye,
                    center=center,
                    up=(0.0, 1.0, 0.0),
                )
                renderer.render(camera, raw_observed_image)

                # [0, 255] -> [-1, 1]
                observe_image = (raw_observed_image / 255.0 - 0.5) * 2.0

                # preprocess
                observe_image = (observe_image - dataset_mean) / dataset_std

                observed_images[n] = to_gpu(observe_image.transpose((2, 0, 1)))

                observed_viewpoints[n] = xp.array(
                    (eye[0], eye[1], eye[2], math.cos(yaw), math.sin(yaw),
                     math.cos(pitch), math.sin(pitch)),
                    dtype="float32")

                r = model.compute_observation_representation(
                    observed_images[None, :n + 1],
                    observed_viewpoints[None, :n + 1])

                r = cf.broadcast_to(r, (args.num_generation, ) + r.shape[1:])

                axis = axes_observations[n]
                axis.update(np.uint8(raw_observed_image))

                total_frames = 50
                for tick in range(total_frames):
                    if window.closed():
                        exit()
                    query_viewpoints = generate_random_query_viewpoint(
                        tick / total_frames, xp)
                    generated_images = to_cpu(
                        model.generate_image(query_viewpoints, r, xp))

                    for m in range(args.num_generation):
                        if window.closed():
                            exit()
                        image = make_uint8(generated_images[m], dataset_mean,
                                           dataset_std)
                        axis = axes_generations[m]
                        axis.update(image)

            raw_observed_image[...] = 0
            for axis in axes_observations:
                axis.update(np.uint8(raw_observed_image))
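
# gqn.math.yaw() and gqn.math.pitch() are defined elsewhere in this
# repository. A sketch of the usual convention, assuming yaw is measured in
# the x-z plane and pitch from the horizontal (the _sketch names are
# hypothetical):
def yaw_sketch(eye, center):
    return math.atan2(center[0] - eye[0], center[2] - eye[2])


def pitch_sketch(eye, center):
    dx, dy, dz = (center[i] - eye[i] for i in range(3))
    return math.atan2(dy, math.hypot(dx, dz))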
Example #5
def main():
    meter_train = Meter()
    assert meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    xp = np
    gpu_device = args.gpu_device
    using_gpu = gpu_device >= 0
    if using_gpu:
        cuda.get_device(gpu_device).use()
        xp = cp

    #==============================================================================
    # Dataset
    #==============================================================================
    dataset_test = Dataset(args.test_dataset_directory)

    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)
    print(hyperparams, "\n")

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    assert model.load(args.snapshot_directory, meter_train.epoch)
    if using_gpu:
        model.to_gpu()

    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler()
    assert variance_scheduler.load(args.snapshot_directory)
    print(variance_scheduler, "\n")

    #==============================================================================
    # Algorithms
    #==============================================================================
    def encode_scene(images, viewpoints):
        # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
        images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

        # Sample number of views
        total_views = images.shape[1]
        num_views = random.choice(range(1, total_views + 1))

        # Sample views
        observation_view_indices = list(range(total_views))
        random.shuffle(observation_view_indices)
        observation_view_indices = observation_view_indices[:num_views]

        observation_images = preprocess_images(
            images[:, observation_view_indices])

        observation_query = viewpoints[:, observation_view_indices]
        representation = model.compute_observation_representation(
            observation_images, observation_query)

        # Sample query view
        query_index = random.choice(range(total_views))
        query_images = preprocess_images(images[:, query_index])
        query_viewpoints = viewpoints[:, query_index]

        # Transfer to gpu if necessary
        query_images = to_device(query_images, gpu_device)
        query_viewpoints = to_device(query_viewpoints, gpu_device)

        return representation, query_images, query_viewpoints

    def estimate_ELBO(query_images, z_t_param_array, pixel_mean,
                      pixel_log_sigma):
        # KL divergence
        kl_divergence = 0
        for params_t in z_t_param_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params_t
            normal_q = chainer.distributions.Normal(mean_z_q,
                                                    log_scale=ln_var_z_q)
            normal_p = chainer.distributions.Normal(mean_z_p,
                                                    log_scale=ln_var_z_p)
            kld_t = chainer.kl_divergence(normal_q, normal_p)
            kl_divergence += cf.sum(kld_t)
        kl_divergence = kl_divergence / args.batch_size

        # Negative log-likelihood of generated image
        batch_size = query_images.shape[0]
        num_pixels_per_image = np.prod(query_images.shape[1:])
        normal = chainer.distributions.Normal(query_images,
                                              log_scale=pixel_log_sigma)

        log_px = cf.sum(normal.log_prob(pixel_mean)) / batch_size
        negative_log_likelihood = -log_px

        # Empirical ELBO
        ELBO = log_px - kl_divergence

        # https://arxiv.org/abs/1604.08772 Section.2
        # https://www.reddit.com/r/MachineLearning/comments/56m5o2/discussion_calculation_of_bitsdims/
        bits_per_pixel = -(ELBO / num_pixels_per_image -
                           np.log(256)) / np.log(2)

        return ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence

    #==============================================================================
    # Test the model
    #==============================================================================
    meter = Meter()
    pixel_log_sigma = xp.full((args.batch_size, 3) + hyperparams.image_size,
                              math.log(variance_scheduler.standard_deviation),
                              dtype="float32")

    with chainer.no_backprop_mode():
        for subset_index, subset in enumerate(dataset_test):
            iterator = Iterator(subset, batch_size=args.batch_size)
            for data_indices in iterator:
                images, viewpoints = subset[data_indices]

                # Scene encoder
                representation, query_images, query_viewpoints = encode_scene(
                    images, viewpoints)

                # Compute empirical ELBO
                (z_t_param_array,
                 pixel_mean) = model.sample_z_and_x_params_from_posterior(
                     query_images, query_viewpoints, representation)
                (ELBO, bits_per_pixel, negative_log_likelihood,
                 kl_divergence) = estimate_ELBO(query_images, z_t_param_array,
                                                pixel_mean, pixel_log_sigma)
                mean_squared_error = cf.mean_squared_error(
                    query_images, pixel_mean)

                # Logging
                meter.update(ELBO=float(ELBO.data),
                             bits_per_pixel=float(bits_per_pixel.data),
                             negative_log_likelihood=float(
                                 negative_log_likelihood.data),
                             kl_divergence=float(kl_divergence.data),
                             mean_squared_error=float(mean_squared_error.data))

            if subset_index % 100 == 0:
                print("    Subset {}/{}:".format(
                    subset_index + 1,
                    len(dataset_test),
                ))
                print("        {}".format(meter))

    print("    Test:")
    print("        {} - done in {:.3f} min".format(
        meter,
        meter.elapsed_time,
    ))
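
# Worked example of the bits/dim conversion used in estimate_ELBO() above,
# with hypothetical numbers: an ELBO of -12000 nats for one 3x64x64 image.
elbo_nats = -12000.0
num_pixels = 3 * 64 * 64  # 12288
bpp = -(elbo_nats / num_pixels - np.log(256)) / np.log(2)
# (0.9766 + 5.5452) / 0.6931, i.e. roughly 9.41 bits per pixel.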
Example #6
def main():
    try:
        os.makedirs(args.figure_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    dataset = gqn.data.Dataset(args.dataset_directory)

    meter = Meter()
    assert meter.load(args.snapshot_directory)

    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)

    model = Model(hyperparams)
    assert model.load(args.snapshot_directory, meter.epoch)

    if using_gpu:
        model.to_gpu()

    total_observations_per_scene = 4
    fps = 30

    black_color = -0.5
    image_shape = (3, ) + hyperparams.image_size
    axis_observations_image = np.zeros(
        (3, image_shape[1], total_observations_per_scene * image_shape[2]),
        dtype=np.float32)

    #==============================================================================
    # Utilities
    #==============================================================================
    def to_device(array):
        if using_gpu:
            array = cuda.to_gpu(array)
        return array

    def fill_observations_axis(observation_images):
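        # Lay the observed images side by side, centered within a strip wide
        # enough for total_observations_per_scene images.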
        axis_observations_image = np.full(
            (3, image_shape[1], total_observations_per_scene * image_shape[2]),
            black_color,
            dtype=np.float32)
        num_current_obs = len(observation_images)
        total_obs = total_observations_per_scene
        width = image_shape[2]
        x_start = width * (total_obs - num_current_obs) // 2
        for obs_image in observation_images:
            x_end = x_start + width
            axis_observations_image[:, :, x_start:x_end] = obs_image
            x_start += width
        return axis_observations_image

    def compute_camera_angle_at_frame(t):
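        # One horizontal revolution every 2 seconds; the vertical angle holds
        # at +60 degrees, ramps down to -60 degrees between t = 1.5s and 2.5s,
        # holds, then ramps back up between t = 4s and 5s.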
        horizontal_angle_rad = 2 * t * math.pi / (fps * 2) + math.pi / 4
        y_rad_top = math.pi / 3
        y_rad_bottom = -math.pi / 3
        y_rad_range = y_rad_bottom - y_rad_top
        if t < fps * 1.5:
            vertical_angle_rad = y_rad_top
        elif fps * 1.5 <= t and t < fps * 2.5:
            interp = (t - fps * 1.5) / fps
            vertical_angle_rad = y_rad_top + interp * y_rad_range
        elif fps * 2.5 <= t and t < fps * 4:
            vertical_angle_rad = y_rad_bottom
        elif fps * 4.0 <= t and t < fps * 5:
            interp = (t - fps * 4.0) / fps
            vertical_angle_rad = y_rad_bottom - interp * y_rad_range
        else:
            vertical_angle_rad = y_rad_top
        return horizontal_angle_rad, vertical_angle_rad

    def rotate_query_viewpoint(horizontal_angle_rad, vertical_angle_rad):
        camera_direction = np.array([
            math.sin(horizontal_angle_rad),  # x
            math.sin(vertical_angle_rad),  # y
            math.cos(horizontal_angle_rad),  # z
        ])
        camera_direction = args.camera_distance * camera_direction / np.linalg.norm(
            camera_direction)
        yaw, pitch = compute_yaw_and_pitch(camera_direction)
        query_viewpoints = xp.array(
            (
                camera_direction[0],
                camera_direction[1],
                camera_direction[2],
                math.cos(yaw),
                math.sin(yaw),
                math.cos(pitch),
                math.sin(pitch),
            ),
            dtype=np.float32,
        )
        query_viewpoints = xp.broadcast_to(query_viewpoints,
                                           (1, ) + query_viewpoints.shape)
        return query_viewpoints

    #==============================================================================
    # Visualization
    #==============================================================================
    plt.style.use("dark_background")
    fig = plt.figure(figsize=(6, 7))
    plt.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.95)
    # fig.suptitle("GQN")
    axis_observations = fig.add_subplot(2, 1, 1)
    axis_observations.axis("off")
    axis_observations.set_title("observations")
    axis_generation = fig.add_subplot(2, 1, 2)
    axis_generation.axis("off")
    axis_generation.set_title("neural rendering")

    #==============================================================================
    # Generating animation
    #==============================================================================
    file_number = 1
    random.seed(0)
    np.random.seed(0)

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                animation_frame_array = []

                observed_image_array = xp.full(
                    (total_observations_per_scene, ) + image_shape,
                    black_color,
                    dtype=np.float32)
                observed_viewpoint_array = xp.zeros(
                    (total_observations_per_scene, 7), dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = preprocess_images(images)

                batch_index = 0

                #------------------------------------------------------------------------------
                # Generate images with a single observation
                #------------------------------------------------------------------------------
                observation_index = 0

                # Scene encoder
                observed_image = images[batch_index, observation_index]
                observed_viewpoint = viewpoints[batch_index, observation_index]

                observed_image_array[observation_index] = to_device(
                    observed_image)
                observed_viewpoint_array[observation_index] = to_device(
                    observed_viewpoint)

                representation = model.compute_observation_representation(
                    observed_image_array[None, :observation_index + 1],
                    observed_viewpoint_array[None, :observation_index + 1])

                # Update figure
                axis_observations_image = fill_observations_axis(
                    [observed_image])

                # Rotate camera
                for t in range(fps, fps * 6):
                    artist_array = [
                        axis_observations.imshow(
                            make_uint8(axis_observations_image),
                            interpolation="none",
                            animated=True)
                    ]

                    horizontal_angle_rad, vertical_angle_rad = compute_camera_angle_at_frame(
                        t)
                    query_viewpoints = rotate_query_viewpoint(
                        horizontal_angle_rad, vertical_angle_rad)
                    generated_images = model.generate_image(
                        query_viewpoints, representation)[0]

                    artist_array.append(
                        axis_generation.imshow(make_uint8(generated_images),
                                               interpolation="none",
                                               animated=True))

                    animation_frame_array.append(artist_array)

                #------------------------------------------------------------------------------
                # Add observations
                #------------------------------------------------------------------------------
                for n in range(total_observations_per_scene):
                    # Store each new observation before re-encoding the scene;
                    # without this, indices 1.. stay blank and the
                    # representation never improves
                    observed_image_array[n] = to_device(
                        images[batch_index, n])
                    observed_viewpoint_array[n] = to_device(
                        viewpoints[batch_index, n])

                    axis_observations_image = fill_observations_axis(
                        images[batch_index, :n + 1])

                    # Scene encoder
                    representation = model.compute_observation_representation(
                        observed_image_array[None, :n + 1],
                        observed_viewpoint_array[None, :n + 1])

                    for t in range(fps // 2):
                        artist_array = [
                            axis_observations.imshow(
                                make_uint8(axis_observations_image),
                                interpolation="none",
                                animated=True)
                        ]

                        horizontal_angle_rad, vertical_angle_rad = compute_camera_angle_at_frame(
                            0)
                        query_viewpoints = rotate_query_viewpoint(
                            horizontal_angle_rad, vertical_angle_rad)
                        generated_images = model.generate_image(
                            query_viewpoints, representation)[0]

                        artist_array.append(
                            axis_generation.imshow(
                                make_uint8(generated_images),
                                interpolation="none",
                                animated=True))

                        animation_frame_array.append(artist_array)

                #------------------------------------------------------------------------------
                # Generate images with all observations
                #------------------------------------------------------------------------------
                # Scene encoder
                representation = model.compute_observation_representation(
                    observed_image_array[None, :total_observations_per_scene],
                    observed_viewpoint_array[None, :total_observations_per_scene])
                # Rotate camera
                for t in range(0, fps * 6):
                    artist_array = [
                        axis_observations.imshow(
                            make_uint8(axis_observations_image),
                            interpolation="none",
                            animated=True)
                    ]

                    horizontal_angle_rad, vertical_angle_rad = compute_camera_angle_at_frame(
                        t)
                    query_viewpoints = rotate_query_viewpoint(
                        horizontal_angle_rad, vertical_angle_rad)
                    generated_images = model.generate_image(
                        query_viewpoints, representation)[0]

                    artist_array.append(
                        axis_generation.imshow(make_uint8(generated_images),
                                               interpolation="none",
                                               animated=True))

                    animation_frame_array.append(artist_array)

                #------------------------------------------------------------------------------
                # Write to file
                #------------------------------------------------------------------------------
                anim = animation.ArtistAnimation(fig,
                                                 animation_frame_array,
                                                 interval=1000 / fps,  # ms
                                                 blit=True,
                                                 repeat_delay=0)

                # anim.save(
                #     "{}/shepard_matzler_observations_{}.gif".format(
                #         args.figure_directory, file_number),
                #     writer="imagemagick",
                #     fps=fps)
                anim.save("{}/shepard_matzler_observations_{}.mp4".format(
                    args.figure_directory, file_number),
                          writer="ffmpeg",
                          fps=fps)

                file_number += 1
def main():
    try:
        os.mkdir(args.output_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset_train = gqn.data.Dataset(args.dataset_path)

    if args.snapshot_path is not None:
        assert args.snapshot_path_1 is None, 'snapshot_path already specified'
        assert args.snapshot_path_2 is None, 'snapshot_path already specified'
        hyperparams_1 = HyperParameters(snapshot_directory=args.snapshot_path)
        model_1 = Model(hyperparams_1, snapshot_directory=args.snapshot_path)
        hyperparams_2 = HyperParameters(snapshot_directory=args.snapshot_path)
        model_2 = Model(hyperparams_2, snapshot_directory=args.snapshot_path)
    elif args.snapshot_path_1 is not None and args.snapshot_path_2 is not None:
        hyperparams_1 = HyperParameters(snapshot_directory=args.snapshot_path_1)
        model_1 = Model(hyperparams_1, snapshot_directory=args.snapshot_path_1)
        hyperparams_2 = HyperParameters(snapshot_directory=args.snapshot_path_2)
        model_2 = Model(hyperparams_2, snapshot_directory=args.snapshot_path_2)
    else:
        raise ValueError('snapshot path incorrectly specified')

    if using_gpu:
        model_1.to_gpu()
        model_2.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(13, 8))

    num_views_per_scene = 6
    num_generation = 2
    total_frames_per_rotation = 24

    image_shape = (3, ) + hyperparams_1.image_size
    blank_image = make_uint8(np.full(image_shape, 0))
    file_number = 1

    with chainer.no_backprop_mode():
        for i, subset in enumerate(dataset_train):
            iterator_train = gqn.data.Iterator(subset, batch_size=1)

            for j, data_indices in enumerate(iterator_train):
                # Skip the first scene in each subset
                if j == 0:
                    continue
                snapshot_array = []

                observed_image_array_1 = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array_1 = xp.zeros(
                    (num_views_per_scene, 7), dtype=np.float32)
                observed_image_array_2 = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array_2 = xp.zeros(
                    (num_views_per_scene, 7), dtype=np.float32)


                test_data_path = os.path.join(
                    args.dataset_path, 'test_data',
                    str(i) + '_' + str(j) + '.npy')
                if not os.path.exists(test_data_path):
                    raise FileNotFoundError(
                        'test data not found: ' + test_data_path)
                test_data = np.load(test_data_path)

                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images_1, viewpoints, original_images_1 = subset[data_indices]

                # Each test-data file stores an (images, original_images) pair
                # for a single scene; wrap both in a batch axis
                images_2, original_images_2 = test_data
                images_2 = np.array([np.array(images_2)])
                original_images_2 = np.array([np.array(original_images_2)])

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images_1 = images_1.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images_1 = images_1 / 255.0
                images_1 += np.random.uniform(
                    0, 1.0 / 256.0, size=images_1.shape).astype(np.float32)
                images_2 = images_2.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images_2 = images_2 / 255.0
                images_2 += np.random.uniform(
                    0, 1.0 / 256.0, size=images_2.shape).astype(np.float32)

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                original_images_1 = original_images_1.transpose(
                    (0, 1, 4, 2, 3)).astype(np.float32)
                original_images_1 = original_images_1 / 255.0
                original_images_1 += np.random.uniform(
                    0, 1.0 / 256.0,
                    size=original_images_1.shape).astype(np.float32)
                original_images_2 = original_images_2.transpose(
                    (0, 1, 4, 2, 3)).astype(np.float32)
                original_images_2 = original_images_2 / 255.0
                original_images_2 += np.random.uniform(
                    0, 1.0 / 256.0,
                    size=original_images_2.shape).astype(np.float32)
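                # The uniform noise of one intensity level (1/256) dequantizes
                # the 8-bit images, matching the training-time preprocessing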

                batch_index = 0

                # Generate images without observations
                r_1 = xp.zeros(
                    (
                        num_generation,
                        hyperparams_1.representation_channels,
                    ) + hyperparams_1.chrz_size,
                    dtype=np.float32)

                r_2 = xp.zeros(
                    (
                        num_generation,
                        hyperparams_2.representation_channels,
                    ) + hyperparams_2.chrz_size,
                    dtype=np.float32)
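                # With no observations, the representations are all zeros and
                # the generators render from the prior alone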

                angle_rad = 0
                current_scene_original_images_cpu_1 = original_images_1[batch_index]
                current_scene_original_images_1 = to_gpu(current_scene_original_images_cpu_1)
                current_scene_original_images_cpu_2 = original_images_2[batch_index]
                current_scene_original_images_2 = to_gpu(current_scene_original_images_cpu_2)

                gqn.animator.Snapshot.make_graph(
                    id='sq_d_graph_1',
                    pos=8,
                    graph_type='plot',
                    mode='sequential',
                    frame_in_rotation=total_frames_per_rotation,
                    num_of_data_per_graph=num_views_per_scene + 1,
                    trivial_settings={
                        'colors': ['red', 'blue', 'green', 'orange', 'white', 'cyan', 'magenta'],
                        'markers': ['', '', '', '', '', '', '']
                    }
                )

                gqn.animator.Snapshot.make_graph(
                    id='sq_d_graph_2',
                    pos=16,
                    graph_type='plot',
                    mode='sequential',
                    frame_in_rotation=total_frames_per_rotation,
                    num_of_data_per_graph=num_views_per_scene + 1,
                    trivial_settings={
                        'colors': ['red', 'blue', 'green', 'orange', 'white', 'cyan', 'magenta'],
                        'markers': ['', '', '', '', '', '', '']
                    }
                )

                gqn.animator.Snapshot.make_graph(
                    id='sq_d_avg_graph',
                    pos=17,
                    graph_type='bar',
                    mode='simultaneous',
                    frame_in_rotation=9,
                    frame_per_cycle=total_frames_per_rotation,
                    num_of_data_per_graph=2,
                    trivial_settings={
                        'colors': ['red', 'blue'],
                        'markers': ['', ''],
                        'legends': ['Train', 'Test']
                    }
                )
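                # The sq_d_graph_* plots track the per-frame squared distance
                # between rendered and ground-truth images for
                # 0..num_views_per_scene observations; sq_d_avg_graph compares
                # the per-rotation averages of the train and test models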

                sq_d_sums_1 = [0] * (num_views_per_scene + 1)
                sq_d_sums_2 = [0] * (num_views_per_scene + 1)
                for t in range(total_frames_per_rotation):
                    grid_master = GridSpec(nrows=4, ncols=9, height_ratios=[1,1,1,1])
                    grid_master.update(wspace=0.5, hspace=0.8)
                    snapshot = gqn.animator.Snapshot(unify_ylim=True, layout_settings={
                        'subplot_count': 17,
                        'grid_master': grid_master,
                        'subplots': [
                            { 'subplot_id': 1,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 0]) },
                            { 'subplot_id': 2,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 1]) },
                            { 'subplot_id': 3,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 2]) },
                            { 'subplot_id': 4,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 0]) },
                            { 'subplot_id': 5,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 1]) },
                            { 'subplot_id': 6,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 2]) },
                            { 'subplot_id': 7,   'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[0:2, 3:5]) },
                            { 'subplot_id': 8,   'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[0:2, 5:7]) },

                            { 'subplot_id': 9,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 0]) },
                            { 'subplot_id': 10,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 1]) },
                            { 'subplot_id': 11,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 2]) },
                            { 'subplot_id': 12,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 0]) },
                            { 'subplot_id': 13,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 1]) },
                            { 'subplot_id': 14,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 2]) },
                            { 'subplot_id': 15,  'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[2:4, 3:5]) },
                            { 'subplot_id': 16,  'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[2:4, 5:7]) },

                            { 'subplot_id': 17,  'subplot': GridSpecFromSubplotSpec(nrows=4, ncols=2, subplot_spec=grid_master[0:4, 7:9]) }
                        ]
                    })

                    # Don't reuse `i` here: the outer subset index `i` is
                    # needed to locate the test data on the next iteration
                    for pos in [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14]:
                        snapshot.add_media(
                            media_type='image',
                            media_data=make_uint8(blank_image),
                            media_position=pos)
                        if pos == 1:
                            snapshot.add_title(text='Train', target_media_pos=pos)
                        if pos == 9:
                            snapshot.add_title(text='Test', target_media_pos=pos)

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_images_1 = model_1.generate_image(
                        query_viewpoints, r_1, xp)
                    generated_images_2 = model_2.generate_image(
                        query_viewpoints, r_2, xp)

                    total_sq_d_1, _ = gqn.math.get_squared_distance(
                        to_cpu(current_scene_original_images_1[t]),
                        to_cpu(generated_images_1[0]))
                    sq_d_sums_1[0] += total_sq_d_1

                    total_sq_d_2, _ = gqn.math.get_squared_distance(
                        to_cpu(current_scene_original_images_2[t]),
                        to_cpu(generated_images_2[0]))
                    sq_d_sums_2[0] += total_sq_d_2

                    snapshot.add_media(
                        media_type='image',
                        media_data=make_uint8(generated_images_1[0]),
                        media_position=7)
                    snapshot.add_title(text='GQN Output', target_media_pos=7)

                    snapshot.add_media(
                        media_type='image',
                        media_data=make_uint8(generated_images_2[0]),
                        media_position=15)
                    snapshot.add_title(text='GQN Output', target_media_pos=15)

                    for pos in [8, 16]:
                        snapshot.add_title(
                            text='Squared Distance',
                            target_media_pos=pos)

                    gqn.animator.Snapshot.add_graph_data(
                        graph_id='sq_d_graph_1',
                        data_id='sq_d_data_0',
                        new_data=total_sq_d_1,
                        frame_num=t,
                    )

                    gqn.animator.Snapshot.add_graph_data(
                        graph_id='sq_d_graph_2',
                        data_id='sq_d_data_0',
                        new_data=total_sq_d_2,
                        frame_num=t,
                    )

                    if t == total_frames_per_rotation - 1:
                        sq_d_sums_1[0] /= total_frames_per_rotation
                        sq_d_sums_2[0] /= total_frames_per_rotation
                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='sq_d_avg_graph',
                            data_id='sq_d_data_0',
                            new_data=sq_d_sums_1[0],
                            frame_num=0)
                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='sq_d_avg_graph',
                            data_id='sq_d_data_1',
                            new_data=sq_d_sums_2[0],
                            frame_num=0)

                    snapshot_array.append(snapshot)

                    angle_rad += 2 * math.pi / total_frames_per_rotation

                # Generate images with observations
                for m in range(num_views_per_scene):
                    observed_image_1 = images_1[batch_index, m]
                    observed_viewpoint_1 = viewpoints[batch_index, m]
                    # The test scene reuses the viewpoints of the training scene
                    observed_image_2 = images_2[batch_index, m]
                    observed_viewpoint_2 = viewpoints[batch_index, m]

                    observed_image_array_1[m] = to_gpu(observed_image_1)
                    observed_viewpoint_array_1[m] = to_gpu(observed_viewpoint_1)
                    observed_image_array_2[m] = to_gpu(observed_image_2)
                    observed_viewpoint_array_2[m] = to_gpu(observed_viewpoint_2)

                    r_1 = model_1.compute_observation_representation(
                        observed_image_array_1[None, :m + 1],
                        observed_viewpoint_array_1[None, :m + 1])
                    r_2 = model_2.compute_observation_representation(
                        observed_image_array_2[None, :m + 1],
                        observed_viewpoint_array_2[None, :m + 1])

                    r_1 = cf.broadcast_to(r_1, (num_generation, ) + r_1.shape[1:])
                    r_2 = cf.broadcast_to(r_2, (num_generation, ) + r_2.shape[1:])
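                    # Broadcast the single-scene representation along the batch
                    # axis so num_generation images can be sampled in one call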

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        grid_master = GridSpec(nrows=4, ncols=9, height_ratios=[1,1,1,1])
                        grid_master.update(wspace=0.5, hspace=0.8)
                        snapshot = gqn.animator.Snapshot(unify_ylim=True, layout_settings={
                            'subplot_count': 17,
                            'grid_master': grid_master,
                            'subplots': [
                              { 'subplot_id': 1,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 0]) },
                              { 'subplot_id': 2,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 1]) },
                              { 'subplot_id': 3,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[0, 2]) },
                              { 'subplot_id': 4,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 0]) },
                              { 'subplot_id': 5,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 1]) },
                              { 'subplot_id': 6,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[1, 2]) },
                              { 'subplot_id': 7,   'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[0:2, 3:5]) },
                              { 'subplot_id': 8,   'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[0:2, 5:7]) },

                              { 'subplot_id': 9,   'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 0]) },
                              { 'subplot_id': 10,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 1]) },
                              { 'subplot_id': 11,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[2, 2]) },
                              { 'subplot_id': 12,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 0]) },
                              { 'subplot_id': 13,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 1]) },
                              { 'subplot_id': 14,  'subplot': GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=grid_master[3, 2]) },
                              { 'subplot_id': 15,  'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[2:4, 3:5]) },
                              { 'subplot_id': 16,  'subplot': GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=grid_master[2:4, 5:7]) },

                              { 'subplot_id': 17,  'subplot': GridSpecFromSubplotSpec(nrows=4, ncols=2, subplot_spec=grid_master[0:4, 7:9]) }
                            ]
                        })

                        for pos, observed_image in zip(
                                [1, 2, 3, 4, 5, 6], observed_image_array_1):
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(observed_image),
                                media_position=pos)
                            if pos == 1:
                                snapshot.add_title(text='Train', target_media_pos=pos)

                        for pos, observed_image in zip(
                                [9, 10, 11, 12, 13, 14], observed_image_array_2):
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(observed_image),
                                media_position=pos)
                            if pos == 9:
                                snapshot.add_title(text='Test', target_media_pos=pos)

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images_1 = model_1.generate_image(
                            query_viewpoints, r_1, xp)
                        generated_images_2 = model_2.generate_image(
                            query_viewpoints, r_2, xp)

                        total_sq_d_1, _ = gqn.math.get_squared_distance(
                            to_cpu(current_scene_original_images_1[t]),
                            to_cpu(generated_images_1[0]))
                        sq_d_sums_1[m+1] += total_sq_d_1

                        total_sq_d_2, _ = gqn.math.get_squared_distance(
                            to_cpu(current_scene_original_images_2[t]),
                            to_cpu(generated_images_2[0]))
                        sq_d_sums_2[m+1] += total_sq_d_2

                        snapshot.add_media(
                            media_type='image',
                            media_data=make_uint8(generated_images_1[0]),
                            media_position=7)
                        snapshot.add_title(text='GQN Output', target_media_pos=7)

                        snapshot.add_media(
                            media_type='image',
                            media_data=make_uint8(generated_images_2[0]),
                            media_position=15)
                        snapshot.add_title(text='GQN Output', target_media_pos=15)

                        for pos in [8, 16]:
                            snapshot.add_title(
                                text='Squared Distance',
                                target_media_pos=pos)

                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='sq_d_graph_1',
                            data_id='sq_d_data_' + str(m+1),
                            new_data=total_sq_d_1,
                            frame_num=t,
                        )

                        gqn.animator.Snapshot.add_graph_data(
                            graph_id='sq_d_graph_2',
                            data_id='sq_d_data_' + str(m+1),
                            new_data=total_sq_d_2,
                            frame_num=t,
                        )

                        if t == total_frames_per_rotation - 1:
                            sq_d_sums_1[m+1] /= total_frames_per_rotation
                            sq_d_sums_2[m+1] /= total_frames_per_rotation
                            gqn.animator.Snapshot.add_graph_data(
                                graph_id='sq_d_avg_graph',
                                data_id='sq_d_data_0',
                                new_data=sq_d_sums_1[m+1],
                                frame_num=m+1)

                            gqn.animator.Snapshot.add_graph_data(
                                graph_id='sq_d_avg_graph',
                                data_id='sq_d_data_1',
                                new_data=sq_d_sums_2[m+1],
                                frame_num=m+1)

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        # plt.pause(1e-8)

                        snapshot_array.append(snapshot)

                plt.subplots_adjust(
                    left=None,
                    bottom=None,
                    right=None,
                    top=None,
                    wspace=0,
                    hspace=0)

                anim = animation.FuncAnimation(
                    fig,
                    func_anim_upate,
                    fargs=(fig, [snapshot_array]),
                    interval=1000 / 24,  # interval is in milliseconds
                    frames=(num_views_per_scene + 1) * total_frames_per_rotation)

                anim.save(
                    "{}/shepard_matzler_{}.mp4".format(
                        args.output_directory, file_number),
                    writer="ffmpeg",
                    fps=12)
                file_number += 1
def main():
    try:
        os.mkdir(args.figure_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(15, 5))
    fig.suptitle("GQN")
    axis_observations = fig.add_subplot(1, 3, 1)
    axis_observations.axis("off")
    axis_observations.set_title("Observations")
    axis_ground_truth = fig.add_subplot(1, 3, 2)
    axis_ground_truth.axis("off")
    axis_ground_truth.set_title("Ground Truth")
    axis_reconstruction = fig.add_subplot(1, 3, 3)
    axis_reconstruction.axis("off")
    axis_reconstruction.set_title("Reconstruction")

    total_observations_per_scene = 4  # tiled as a 2x2 grid of observations
    num_observations_per_column = int(math.sqrt(total_observations_per_scene))

    black_color = -0.5
    image_shape = (3, ) + hyperparams.image_size
    axis_observations_image = np.full(
        (3, num_observations_per_column * image_shape[1],
         num_observations_per_column * image_shape[2]),
        black_color,
        dtype=np.float32)
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                animation_frame_array = []
                axis_observations_image[...] = black_color

                observed_image_array = xp.full(
                    (total_observations_per_scene, ) + image_shape,
                    black_color,
                    dtype=np.float32)
                observed_viewpoint_array = xp.zeros(
                    (total_observations_per_scene, 7), dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = preprocess_images(images)

                batch_index = 0

                # The first view that is never observed serves as the
                # ground-truth query
                query_index = total_observations_per_scene
                query_image = images[batch_index, query_index]
                query_viewpoint = to_gpu(viewpoints[None, batch_index,
                                                    query_index])

                axis_ground_truth.imshow(make_uint8(query_image),
                                         interpolation="none")

                for observation_index in range(total_observations_per_scene):
                    observed_image = images[batch_index, observation_index]
                    observed_viewpoint = viewpoints[batch_index,
                                                    observation_index]

                    observed_image_array[observation_index] = to_gpu(
                        observed_image)
                    observed_viewpoint_array[observation_index] = to_gpu(
                        observed_viewpoint)

                    representation = model.compute_observation_representation(
                        observed_image_array[None, :observation_index + 1],
                        observed_viewpoint_array[None, :observation_index + 1])

                    representation = cf.broadcast_to(representation, (1, ) +
                                                     representation.shape[1:])

                    # Update figure
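                    # Tile the new observation into the mosaic, filling it
                    # row-major: column = index % n, row = index // n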
                    x_start = image_shape[1] * (observation_index %
                                                num_observations_per_column)
                    x_end = x_start + image_shape[1]
                    y_start = image_shape[2] * (observation_index //
                                                num_observations_per_column)
                    y_end = y_start + image_shape[2]
                    axis_observations_image[:, y_start:y_end,
                                            x_start:x_end] = observed_image

                    axis_observations.imshow(
                        make_uint8(axis_observations_image),
                        interpolation="none",
                        animated=True)

                    generated_images = model.generate_image(
                        query_viewpoint, representation)[0]

                    axis_reconstruction.imshow(make_uint8(generated_images),
                                               interpolation="none")

                    plt.pause(1)
def main():
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = HyperParameters()
    model = Model(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    screen_size = hyperparams.image_size
    camera = gqn.three.PerspectiveCamera(
        eye=(3, 1, 0),
        center=(0, 0, 0),
        up=(0, 1, 0),
        fov_rad=math.pi / 2.0,
        aspect_ratio=screen_size[0] / screen_size[1],
        z_near=0.1,
        z_far=10)

    figure = gqn.imgplot.figure()
    axes_observations = []
    axes_generations = []
    sqrt_n = math.sqrt(args.num_views_per_scene)
    axis_width = 0.5 / sqrt_n
    axis_height = 1.0 / sqrt_n
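    # The left half of the window (width 0.5) tiles the observations in a
    # sqrt(n) x sqrt(n) grid; the right half shows the generated images.
    # This layout assumes num_views_per_scene and num_generation are
    # perfect squares.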
    for n in range(args.num_views_per_scene):
        axis = gqn.imgplot.image()
        x = n % sqrt_n
        y = n // sqrt_n
        figure.add(axis, x * axis_width, y * axis_height, axis_width,
                   axis_height)
        axes_observations.append(axis)
    sqrt_n = math.sqrt(args.num_generation)
    axis_width = 0.5 / sqrt_n
    axis_height = 1.0 / sqrt_n
    for n in range(args.num_generation):
        axis = gqn.imgplot.image()
        x = n % sqrt_n
        y = n // sqrt_n
        figure.add(axis, x * axis_width + 0.5, y * axis_height, axis_width,
                   axis_height)
        axes_generations.append(axis)

    window = gqn.imgplot.window(figure, (1600, 800), "Dataset")
    window.show()

    raw_observed_images = np.zeros(screen_size + (3, ), dtype="uint32")
    renderer = gqn.three.Renderer(screen_size[0], screen_size[1])

    observed_images = xp.zeros(
        (args.num_views_per_scene, 3) + screen_size, dtype="float32")
    observed_viewpoints = xp.zeros(
        (args.num_views_per_scene, 7), dtype="float32")

    with chainer.no_backprop_mode():
        while True:
            if window.closed():
                exit()

            scene, _ = gqn.environment.shepard_metzler.build_scene(
                num_blocks=7)  # random.choice(range(7, 8)) always yields 7
            renderer.set_scene(scene)

            # Generate images without observations
            r = xp.zeros(
                (
                    args.num_generation,
                    hyperparams.channels_r,
                ) + hyperparams.chrz_size,
                dtype="float32")
            total_frames = 50
            for tick in range(total_frames):
                if window.closed():
                    exit()
                query_viewpoints = generate_random_query_viewpoint(
                    tick / total_frames, xp)
                generated_images = to_cpu(
                    model.generate_image(query_viewpoints, r, xp))

                for m in range(args.num_generation):
                    if window.closed():
                        exit()
                    image = make_uint8(generated_images[m])
                    axis = axes_generations[m]
                    axis.update(image)

            for n in range(args.num_views_per_scene):
                if window.closed():
                    exit()
                rad_xz = random.uniform(0, math.pi * 2)
                rad_y = random.uniform(0, math.pi * 2)
                eye = (3.0 * math.cos(rad_xz), 3.0 * math.sin(rad_y),
                       3.0 * math.sin(rad_xz))
                center = (0, 0, 0)
                yaw = gqn.math.yaw(eye, center)
                pitch = gqn.math.pitch(eye, center)
                camera.look_at(
                    eye=eye,
                    center=center,
                    up=(0.0, 1.0, 0.0),
                )
                renderer.render(camera, raw_observed_images)

                # [0, 255] -> [-1, 1]
                observed_images[n] = to_gpu((raw_observed_images.transpose(
                    (2, 0, 1)) / 255 - 0.5) * 2.0)

                observed_viewpoints[n] = xp.array(
                    (eye[0], eye[1], eye[2], math.cos(yaw), math.sin(yaw),
                     math.cos(pitch), math.sin(pitch)),
                    dtype="float32")

                r = model.compute_observation_representation(
                    observed_images[None, :n + 1],
                    observed_viewpoints[None, :n + 1])

                r = cf.broadcast_to(r, (args.num_generation, ) + r.shape[1:])

                axis = axes_observations[n]
                axis.update(np.uint8(raw_observed_images))

                total_frames = 50
                for tick in range(total_frames):
                    if window.closed():
                        exit()
                    query_viewpoints = generate_random_query_viewpoint(
                        tick / total_frames, xp)
                    generated_images = to_cpu(
                        model.generate_image(query_viewpoints, r, xp))

                    for m in range(args.num_generation):
                        if window.closed():
                            exit()
                        image = make_uint8(generated_images[m])
                        axis = axes_generations[m]
                        axis.update(image)

            raw_observed_images[...] = 0
            for axis in axes_observations:
                axis.update(np.uint8(raw_observed_images))
def main():
    try:
        os.makedirs(args.figure_directory)
    except FileExistsError:
        pass

    # loading dataset & model
    cuda.get_device(args.gpu_device).use()
    xp = cp

    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)

    model = Model(hyperparams)
    chainer.serializers.load_hdf5(args.snapshot_file, model)
    model.to_gpu()

    total_observations_per_scene = 4
    fps = 30

    black_color = -0.5
    image_shape = (3, ) + hyperparams.image_size
    axis_observations_image = np.zeros(
        (3, image_shape[1], total_observations_per_scene * image_shape[2]),
        dtype=np.float32)

    #==============================================================================
    # Utilities
    #==============================================================================
    def read_files(directory):
        filenames = []
        for filename in os.listdir(directory):
            if filename.endswith(".h5"):
                filenames.append(filename)
        filenames.sort()

        dataset_images = []
        dataset_viewpoints = []
        for filename in filenames:
            # Open read-only and close promptly via the context manager
            with h5py.File(os.path.join(directory, filename), "r") as f:
                dataset_images.extend(list(f["images"]))
                dataset_viewpoints.extend(list(f["viewpoints"]))
        dataset_images = np.array(dataset_images)
        dataset_viewpoints = np.array(dataset_viewpoints)

        dataset = []
        for image, viewpoint in zip(dataset_images, dataset_viewpoints):
            dataset.append({'image': image, 'viewpoint': viewpoint})

        return dataset
    
    def to_device(array):
        # This script always runs on the GPU (xp = cp above)
        return cuda.to_gpu(array)

    def fill_observations_axis(observation_images):
        axis_observations_image = np.full(
            (3, image_shape[1], total_observations_per_scene * image_shape[2]),
            black_color,
            dtype=np.float32)
        num_current_obs = len(observation_images)
        total_obs = total_observations_per_scene
        width = image_shape[2]
        x_start = width * (total_obs - num_current_obs) // 2
        for obs_image in observation_images:
            x_end = x_start + width
            axis_observations_image[:, :, x_start:x_end] = obs_image
            x_start += width
        return axis_observations_image

    def compute_camera_angle_at_frame(t):
        # One full horizontal rotation every `fps * 2` frames (2 seconds)
        return t * 2 * math.pi / (fps * 2)

    def rotate_query_viewpoint(horizontal_angle_rad, camera_distance,
                               camera_position_y):
        camera_position = np.array([
            camera_distance * math.sin(horizontal_angle_rad),  # x
            camera_position_y,
            camera_distance * math.cos(horizontal_angle_rad),  # z
        ])
        center = np.array((0, camera_position_y, 0))
        camera_direction = camera_position - center
        yaw, pitch = compute_yaw_and_pitch(camera_direction)
        query_viewpoints = xp.array(
            (
                camera_position[0],
                camera_position[1],
                camera_position[2],
                math.cos(yaw),
                math.sin(yaw),
                math.cos(pitch),
                math.sin(pitch),
            ),
            dtype=np.float32,
        )
        query_viewpoints = xp.broadcast_to(query_viewpoints,
                                           (1, ) + query_viewpoints.shape)
        return query_viewpoints

    def render(representation,
               camera_distance,
               camera_position_y,
               total_frames,
               animation_frame_array,
               rotate_camera=True):
        for t in range(0, total_frames):
            artist_array = [
                axis_observations.imshow(
                    make_uint8(axis_observations_image),
                    interpolation="none",
                    animated=True)
            ]

            horizontal_angle_rad = compute_camera_angle_at_frame(t)
            if not rotate_camera:
                horizontal_angle_rad = compute_camera_angle_at_frame(0)

            query_viewpoints = rotate_query_viewpoint(
                horizontal_angle_rad, camera_distance, camera_position_y)
            generated_images = model.generate_image(query_viewpoints,
                                                    representation)[0]

            artist_array.append(
                axis_generation.imshow(
                    make_uint8(generated_images),
                    interpolation="none",
                    animated=True))

            animation_frame_array.append(artist_array)

    #==============================================================================
    # Visualization
    #==============================================================================
    plt.style.use("dark_background")
    fig = plt.figure(figsize=(6, 7))
    plt.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.95)
    # fig.suptitle("GQN")
    axis_observations = fig.add_subplot(2, 1, 1)
    axis_observations.axis("off")
    axis_observations.set_title("observations")
    axis_generation = fig.add_subplot(2, 1, 2)
    axis_generation.axis("off")
    axis_generation.set_title("neural rendering")

    #==============================================================================
    # Generating animation
    #==============================================================================
    dataset = read_files(args.dataset_directory)
    file_number = 1
    random.seed(0)
    np.random.seed(0)

    with chainer.no_backprop_mode():
        # A SerialIterator is unnecessary here; iterate the dataset directly
        for i in range(len(dataset)):
            animation_frame_array = []

            # shape: (batch, views, height, width, channels)
            images = np.array([dataset[i]["image"]])
            viewpoints = np.array([dataset[i]["viewpoint"]])
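            # Estimate the camera orbit from the data: distance = mean norm of
            # the (x, y, z) camera positions, height = mean y coordinate,
            # assuming the camera circles at roughly constant radius and height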
            camera_distance = np.mean(
                np.linalg.norm(viewpoints[:, :, :3], axis=2))
            camera_position_y = np.mean(viewpoints[:, :, 1])

            # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
            images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
            images = preprocess_images(images)

            batch_index = 0

            total_views = images.shape[1]
            random_observation_view_indices = list(range(total_views))
            random.shuffle(random_observation_view_indices)
            random_observation_view_indices = random_observation_view_indices[
                :total_observations_per_scene]

            #------------------------------------------------------------------------------
            # Observations
            #------------------------------------------------------------------------------
            observed_images = images[batch_index,
                                     random_observation_view_indices]
            observed_viewpoints = viewpoints[batch_index,
                                             random_observation_view_indices]

            observed_images = to_device(observed_images)
            observed_viewpoints = to_device(observed_viewpoints)

            #------------------------------------------------------------------------------
            # Generate images with a single observation
            #------------------------------------------------------------------------------
            # Scene encoder
            representation = model.compute_observation_representation(
                observed_images[None, :1], observed_viewpoints[None, :1])

            # Update figure
            observation_index = random_observation_view_indices[0]
            observed_image = images[batch_index, observation_index]
            axis_observations_image = fill_observations_axis(
                [observed_image])

            # Neural rendering
            render(representation, camera_distance, camera_position_y,
                   fps * 2, animation_frame_array)

            #------------------------------------------------------------------------------
            # Add observations
            #------------------------------------------------------------------------------
            for n in range(total_observations_per_scene):
                observation_indices = random_observation_view_indices[:n + 1]
                axis_observations_image = fill_observations_axis(
                    images[batch_index, observation_indices])

                # Scene encoder
                representation = model.compute_observation_representation(
                    observed_images[None, :n + 1],
                    observed_viewpoints[None, :n + 1])
                # Neural rendering
                render(
                    representation,
                    camera_distance,
                    camera_position_y,
                    fps // 2,
                    animation_frame_array,
                    rotate_camera=False)

            #------------------------------------------------------------------------------
            # Generate images with all observations
            #------------------------------------------------------------------------------
            # Scene encoder
            representation = model.compute_observation_representation(
                observed_images[None, :total_observations_per_scene],
                observed_viewpoints[None, :total_observations_per_scene])

            # Neural rendering
            render(representation, camera_distance, camera_position_y,
                    fps * 4, animation_frame_array)

            #------------------------------------------------------------------------------
            # Write to file
            #------------------------------------------------------------------------------
            anim = animation.ArtistAnimation(
                fig,
                animation_frame_array,
                interval=1000 / fps,  # interval is in milliseconds
                blit=True,
                repeat_delay=0)

            # anim.save(
            #     "{}/shepard_matzler_observations_{}.gif".format(
            #         args.figure_directory, file_number),
            #     writer="imagemagick",
            #     fps=fps)
            anim.save(
                "{}/rooms_ring_camera_observations_{}.mp4".format(
                    args.figure_directory, file_number),
                writer="ffmpeg",
                fps=fps)

            file_number += 1
def main():
    ##############################################
    # To avoid OpenMPI bug
    multiprocessing.set_start_method("forkserver")
    p = multiprocessing.Process(target=print, args=("", ))
    p.start()
    p.join()
    ##############################################

    try:
        os.mkdir(args.snapshot_directory)
    except FileExistsError:
        pass

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    print("device", device, "/", comm.size)
    cuda.get_device(device).use()
    xp = cupy

    dataset = gqn.data.Dataset(args.dataset_directory)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_u_channels = args.u_channels
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_subpixel_convolution_enabled = args.generator_subpixel_convolution_enabled
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.representation_channels = args.representation_channels
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_variance
    hyperparams.pixel_sigma_f = args.final_pixel_variance
    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        print(hyperparams)

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              communicator=comm,
                              mu_i=args.initial_lr,
                              mu_f=args.final_lr)
    if comm.rank == 0:
        print(optimizer)

    scheduler = Scheduler(sigma_start=args.initial_pixel_variance,
                          sigma_end=args.final_pixel_variance,
                          final_num_updates=args.pixel_n)
    if comm.rank == 0:
        print(scheduler)

    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        scheduler.pixel_variance**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(scheduler.pixel_variance**2),
                           dtype="float32")
    num_pixels = hyperparams.image_size[0] * hyperparams.image_size[1] * 3

    random.seed(0)
    subset_indices = list(range(len(dataset.subset_filenames)))

    representation_shape = (
        args.batch_size,
        hyperparams.representation_channels) + hyperparams.chrz_size

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        mean_elbo = 0
        total_num_batch = 0
        subset_size_per_gpu = len(subset_indices) // comm.size
        if len(subset_indices) % comm.size != 0:
            subset_size_per_gpu += 1
        start_time = time.time()

        for subset_loop in range(subset_size_per_gpu):
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = images / 255.0
                images += np.random.uniform(
                    0, 1.0 / 256.0, size=images.shape).astype(np.float32)

                total_views = images.shape[1]

                # Sample observations
                num_views = random.choice(range(total_views + 1))
                if current_training_step == 0 and num_views == 0:
                    num_views = 1  # avoid OpenMPI error

                observation_view_indices = list(range(total_views))
                random.shuffle(observation_view_indices)
                observation_view_indices = observation_view_indices[:num_views]

                if num_views > 0:
                    representation = model.compute_observation_representation(
                        images[:, observation_view_indices],
                        viewpoints[:, observation_view_indices])
                else:
                    representation = xp.zeros(representation_shape,
                                              dtype="float32")
                    representation = chainer.Variable(representation)
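                    # With zero observed views, condition on an all-zero
                    # representation wrapped in a Variable so downstream code
                    # can treat it like the encoder output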

                # Sample query
                query_index = random.choice(range(total_views))
                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # Transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                z_t_param_array, mean_x = model.sample_z_and_x_params_from_posterior(
                    query_images, query_viewpoints, representation)

                # Compute loss
                ## KL Divergence
                loss_kld = chainer.Variable(xp.zeros((), dtype=xp.float32))
                for params in z_t_param_array:
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                    kld = gqn.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                    loss_kld += cf.sum(kld)

                ## Negative log-likelihood of the generated image
                loss_nll = cf.sum(
                    gqn.functions.gaussian_negative_log_likelihood(
                        query_images, mean_x, pixel_var, pixel_ln_var))

                # Calculate the average loss value
                loss_nll = loss_nll / args.batch_size
                loss_kld = loss_kld / args.batch_size

                loss = (loss_nll / scheduler.pixel_variance) + loss_kld
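                # Scaling the NLL term by the current pixel variance
                # re-weights reconstruction against the KL term as sigma is
                # annealed by the scheduler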

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                loss_nll = float(loss_nll.data) + math.log(256.0)
                loss_kld = float(loss_kld.data)

                elbo = -(loss_nll + loss_kld)

                loss_mse = float(
                    cf.mean_squared_error(query_images, mean_x).data)

                if comm.rank == 0:
                    printr(
                        "Iteration {}: Subset {} / {}: Batch {} / {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.5f} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {}  "
                        .format(iteration + 1, subset_loop + 1,
                                subset_size_per_gpu, batch_index + 1,
                                len(iterator), elbo, loss_nll, loss_mse,
                                loss_kld, optimizer.learning_rate,
                                scheduler.pixel_variance,
                                current_training_step))

                total_num_batch += 1
                current_training_step += comm.size
                mean_kld += loss_kld
                mean_nll += loss_nll
                mean_mse += loss_mse
                mean_elbo += elbo

                scheduler.step(current_training_step)
                pixel_var[...] = scheduler.pixel_variance**2
                pixel_ln_var[...] = math.log(scheduler.pixel_variance**2)

            if comm.rank == 0:
                model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            mean_elbo /= total_num_batch
            mean_nll /= total_num_batch
            mean_mse /= total_num_batch
            mean_kld /= total_num_batch
            print(
                "\033[2KIteration {} - elbo: {:.2f} - loss: nll: {:.2f} mse: {:.5f} kld: {:.5f} - lr: {:.4e} - pixel_variance: {:.5f} - step: {} - time: {:.3f} min"
                .format(iteration + 1, mean_elbo, mean_nll, mean_mse, mean_kld,
                        optimizer.learning_rate, scheduler.pixel_variance,
                        current_training_step, elapsed_time / 60))
            model.serialize(args.snapshot_directory)
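
The loop above drives pixel-variance annealing through an external `scheduler` object (`scheduler.step()`, `scheduler.pixel_variance`). A minimal sketch of such a scheduler, assuming the linear sigma annealing that a later example spells out inline (`sigma_t = max(sf + (si - sf) * (1 - step / n), sf)`); the class name and default constants are assumptions, not the repository's actual implementation:

class PixelVarianceScheduler:
    """Linearly anneals the output-pixel sigma from sigma_i to sigma_f."""

    def __init__(self, sigma_i=2.0, sigma_f=0.7, n=200000):
        self.sigma_i = sigma_i
        self.sigma_f = sigma_f
        self.n = n
        # NOTE: despite the name, this attribute stores sigma (the standard
        # deviation); the caller squares it before writing into pixel_var.
        self.pixel_variance = sigma_i

    def step(self, training_step):
        ratio = min(training_step / self.n, 1.0)
        self.pixel_variance = self.sigma_f + (
            self.sigma_i - self.sigma_f) * (1.0 - ratio)
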
def main():
    os.makedirs(args.figure_directory, exist_ok=True)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    dataset = gqn.data.Dataset(args.dataset_directory)

    meter = Meter()
    assert meter.load(args.snapshot_directory)

    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)

    model = Model(hyperparams)
    assert model.load(args.snapshot_directory, meter.epoch)

    if using_gpu:
        model.to_gpu()

    #==============================================================================
    # Visualization
    #==============================================================================
    plt.figure(figsize=(12, 16))

    axis_observation_1 = plt.subplot2grid((4, 3), (0, 0))
    axis_observation_2 = plt.subplot2grid((4, 3), (0, 1))
    axis_observation_3 = plt.subplot2grid((4, 3), (0, 2))

    axis_predictions = plt.subplot2grid((4, 3), (1, 0), rowspan=3, colspan=3)

    axis_observation_1.axis("off")
    axis_observation_2.axis("off")
    axis_observation_3.axis("off")
    axis_predictions.set_xticks([])
    axis_predictions.set_yticks([])

    axis_observation_1.set_title("Observation 1", fontsize=22)
    axis_observation_2.set_title("Observation 2", fontsize=22)
    axis_observation_3.set_title("Observation 3", fontsize=22)

    axis_predictions.set_title("Neural Rendering", fontsize=22)
    axis_predictions.set_xlabel("Yaw", fontsize=22)
    axis_predictions.set_ylabel("Pitch", fontsize=22)

    #==============================================================================
    # Generating images
    #==============================================================================
    num_views_per_scene = 3
    num_yaw_pitch_steps = 10
    image_width, image_height = hyperparams.image_size
    # rows follow the pitch steps (height), columns follow the yaw steps (width)
    prediction_images = make_uint8(
        np.full((num_yaw_pitch_steps * image_height,
                 num_yaw_pitch_steps * image_width, 3), 0))
    file_number = 1
    random.seed(0)
    np.random.seed(0)

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]
                camera_distance = np.mean(
                    np.linalg.norm(viewpoints[:, :, :3], axis=2))

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = preprocess_images(images)

                batch_index = 0

                #------------------------------------------------------------------------------
                # Observations
                #------------------------------------------------------------------------------
                total_views = images.shape[1]
                random_observation_view_indices = list(range(total_views))
                random.shuffle(random_observation_view_indices)
                random_observation_view_indices = random_observation_view_indices[:num_views_per_scene]

                observed_images = images[:, random_observation_view_indices]
                observed_viewpoints = viewpoints[:, random_observation_view_indices]
                representation = model.compute_observation_representation(
                    observed_images, observed_viewpoints)

                axis_observation_1.imshow(
                    make_uint8(observed_images[batch_index, 0]))
                axis_observation_2.imshow(
                    make_uint8(observed_images[batch_index, 1]))
                axis_observation_3.imshow(
                    make_uint8(observed_images[batch_index, 2]))

                y_angle_rad = math.pi / 2

                for pitch_loop in range(num_yaw_pitch_steps):
                    camera_y = math.sin(y_angle_rad)
                    x_angle_rad = math.pi

                    for yaw_loop in range(num_yaw_pitch_steps):
                        camera_direction = np.array([
                            math.sin(x_angle_rad), camera_y,
                            math.cos(x_angle_rad)
                        ])
                        camera_direction = camera_distance * camera_direction / np.linalg.norm(
                            camera_direction)
                        yaw, pitch = compute_yaw_and_pitch(camera_direction)

                        query_viewpoints = xp.array(
                            (
                                camera_direction[0],
                                camera_direction[1],
                                camera_direction[2],
                                math.cos(yaw),
                                math.sin(yaw),
                                math.cos(pitch),
                                math.sin(pitch),
                            ),
                            dtype=np.float32,
                        )
                        query_viewpoints = xp.broadcast_to(
                            query_viewpoints, (1, ) + query_viewpoints.shape)

                        generated_images = model.generate_image(
                            query_viewpoints, representation)[0]

                        yi_start = pitch_loop * image_height
                        yi_end = (pitch_loop + 1) * image_height
                        xi_start = yaw_loop * image_width
                        xi_end = (yaw_loop + 1) * image_width
                        prediction_images[yi_start:yi_end,
                                          xi_start:xi_end] = make_uint8(
                                              generated_images)

                        x_angle_rad -= 2 * math.pi / num_yaw_pitch_steps
                    y_angle_rad -= math.pi / num_yaw_pitch_steps

                axis_predictions.imshow(prediction_images)

                plt.savefig("{}/shepard_metzler_predictions_{}.png".format(
                    args.figure_directory, file_number))
                file_number += 1
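
`compute_yaw_and_pitch` (used above to convert a camera direction into the angles that the 7-dimensional viewpoint encodes) is defined elsewhere in the repository. A plausible sketch under the axis conventions suggested by the surrounding code; the exact convention is an assumption:

import math

def compute_yaw_and_pitch(vec):
    # Sketch: yaw is the rotation around the vertical (y) axis in the
    # x-z plane, pitch is the elevation of the direction vector.
    x, y, z = vec
    norm = math.sqrt(x * x + y * y + z * z)
    yaw = math.atan2(x, z)
    pitch = math.asin(y / norm)
    return yaw, pitch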
Example #13
def main():
    os.makedirs(args.snapshot_directory, exist_ok=True)

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()
    xp = cp

    images = []
    files = os.listdir(args.dataset_path)
    files.sort()
    subset_size = int(math.ceil(len(files) / comm.size))
    files = deque(files)
    files.rotate(-subset_size * comm.rank)
    files = list(files)[:subset_size]
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 256
        images.append(image)

    print(comm.rank, files)

    images = np.vstack(images)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]

    # To avoid OpenMPI bug
    # multiprocessing.set_start_method("forkserver")
    # p = multiprocessing.Process(target=print, args=("", ))
    # p.start()
    # p.join()

    hyperparams = HyperParameters()
    hyperparams.chz_channels = args.chz_channels
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_share_upsampler = args.generator_share_upsampler
    hyperparams.generator_downsampler_channels = args.generator_downsampler_channels
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.inference_downsampler_channels = args.inference_downsampler_channels
    hyperparams.batch_normalization_enabled = args.enable_batch_normalization
    hyperparams.use_gru = args.use_gru
    hyperparams.no_backprop_diff_xr = args.no_backprop_diff_xr

    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)
        hyperparams.print()

    if args.use_gru:
        model = GRUModel(hyperparams,
                         snapshot_directory=args.snapshot_directory)
    else:
        model = LSTMModel(hyperparams,
                          snapshot_directory=args.snapshot_directory)
    model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr,
                              beta_1=args.adam_beta1,
                              communicator=comm)
    if comm.rank == 0:
        optimizer.print()

    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    dataset = draw.data.Dataset(images_train)
    iterator = draw.data.Iterator(dataset, batch_size=args.batch_size)

    num_updates = 0

    for iteration in range(args.training_steps):
        mean_kld = 0
        mean_nll = 0
        mean_mse = 0
        start_time = time.time()

        for batch_index, data_indices in enumerate(iterator):
            x = dataset[data_indices]
            x += np.random.uniform(0, 1 / 256, size=x.shape)
            x = to_gpu(x)

            z_t_param_array, x_param, r_t_array = model.sample_z_and_x_params_from_posterior(
                x)

            loss_kld = 0
            for params in z_t_param_array:
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
                kld = draw.nn.functions.gaussian_kl_divergence(
                    mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
                loss_kld += cf.sum(kld)

            loss_sse = 0
            for r_t in r_t_array:
                loss_sse += cf.sum(cf.squared_error(r_t, x))

            mu_x, ln_var_x = x_param

            loss_nll = cf.gaussian_nll(x, mu_x, ln_var_x)

            loss_nll /= args.batch_size
            loss_kld /= args.batch_size
            loss_sse /= args.batch_size
            loss = args.loss_beta * loss_nll + loss_kld + args.loss_alpha * loss_sse

            model.cleargrads()
            loss.backward(loss_scale=optimizer.loss_scale())
            optimizer.update(num_updates, loss_value=float(loss.array))

            num_updates += 1
            mean_kld += float(loss_kld.data)
            mean_nll += float(loss_nll.data)
            mean_mse += float(loss_sse.data) / num_pixels / (
                hyperparams.generator_generation_steps - 1)

            printr(
                "Iteration {}: Batch {} / {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e}"
                .format(
                    iteration + 1, batch_index + 1, len(iterator),
                    float(loss_nll.data) / num_pixels + math.log(256.0),
                    float(loss_sse.data) / num_pixels /
                    (hyperparams.generator_generation_steps - 1),
                    float(loss_kld.data), optimizer.learning_rate))

            if comm.rank == 0 and batch_index > 0 and batch_index % 100 == 0:
                model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            model.serialize(args.snapshot_directory)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            print(
                "\r\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - elapsed_time: {:.3f} min"
                .format(
                    iteration + 1,
                    mean_nll / len(iterator) / num_pixels + math.log(256.0),
                    mean_mse / len(iterator), mean_kld / len(iterator),
                    optimizer.learning_rate, elapsed_time / 60))
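
`gaussian_kl_divergence` receives the posterior and prior means together with their log-variances. A sketch of the standard closed form for diagonal Gaussians, compatible with the call sites above (whether the repository implements it exactly this way is an assumption):

import chainer.functions as cf

def gaussian_kl_divergence(mean_q, ln_var_q, mean_p, ln_var_p):
    # Elementwise KL(q || p) for diagonal Gaussians parameterized by mean
    # and log-variance; callers reduce the result with cf.sum.
    var_q = cf.exp(ln_var_q)
    var_p = cf.exp(ln_var_p)
    return 0.5 * (ln_var_p - ln_var_q +
                  (var_q + (mean_q - mean_p) ** 2) / var_p - 1)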
def main():
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    figure = gqn.imgplot.figure()
    axes = []
    sqrt_batch_size = int(math.sqrt(args.batch_size))
    axis_size = 1.0 / sqrt_batch_size
    for y in range(sqrt_batch_size):
        for x in range(sqrt_batch_size * 2):
            axis = gqn.imgplot.image()
            axes.append(axis)
            figure.add(axis, axis_size / 2 * x, axis_size * y, axis_size / 2,
                       axis_size)
    window = gqn.imgplot.window(figure, (1600, 800), "Reconstructed images")
    window.show()

    with chainer.no_backprop_mode():
        for _, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for data_indices in iterator:
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3))

                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    r = model.compute_observation_representation(
                        images[:, :num_views], viewpoints[:, :num_views])
                else:
                    r = np.zeros((args.batch_size, hyperparams.channels_r) +
                                 hyperparams.chrz_size,
                                 dtype="float32")
                    r = to_gpu(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # transfer to gpu
                query_viewpoints = to_gpu(query_viewpoints)

                reconstructed_images = model.generate_image(
                    query_viewpoints, r, xp)

                if window.closed():
                    exit()

                for batch_index in range(args.batch_size):
                    axis = axes[batch_index * 2 + 0]
                    image = query_images[batch_index]
                    axis.update(make_uint8(image))

                    axis = axes[batch_index * 2 + 1]
                    image = reconstructed_images[batch_index]
                    axis.update(make_uint8(image))

                time.sleep(1)
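
`make_uint8` turns a channels-first float image back into something displayable. A sketch assuming the [-1, 1] pixel range noted in the comments; the helper's real handling of Variables and value ranges may differ:

import numpy as np
from chainer.backends import cuda

def make_uint8(image):
    # Sketch: accepts a (channels, height, width) float array in [-1, 1],
    # possibly on GPU or wrapped in a chainer.Variable, and returns an
    # HWC uint8 image for plotting.
    image = cuda.to_cpu(getattr(image, "data", image))
    image = (image + 1.0) * 0.5 * 255.0
    return np.uint8(np.clip(image, 0, 255)).transpose((1, 2, 0))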
Example #15
def main():
    os.makedirs(args.output_directory, exist_ok=True)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(10, 5))

    axis_observation_array = []
    axis_observation_array.append(fig.add_subplot(2, 4, 1))
    axis_observation_array.append(fig.add_subplot(2, 4, 2))
    axis_observation_array.append(fig.add_subplot(2, 4, 5))
    axis_observation_array.append(fig.add_subplot(2, 4, 6))

    for axis in axis_observation_array:
        axis.axis("off")

    axis_generation_array = []
    axis_generation_array.append(fig.add_subplot(2, 4, 3))
    axis_generation_array.append(fig.add_subplot(2, 4, 4))
    axis_generation_array.append(fig.add_subplot(2, 4, 7))
    axis_generation_array.append(fig.add_subplot(2, 4, 8))

    for axis in axis_generation_array:
        axis.axis("off")

    num_views_per_scene = 4
    num_generation = 4
    total_frames_per_rotation = 24

    image_shape = (3, ) + hyperparams.image_size
    blank_image = np.full(image_shape, -0.5)
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                artist_frame_array = []

                observed_image_array = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array = xp.zeros((num_views_per_scene, 7),
                                                    dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = preprocess_images(images)

                batch_index = 0

                # Generate images without observations
                r = xp.zeros((
                    num_generation,
                    hyperparams.representation_channels,
                ) + hyperparams.chrz_size,
                             dtype=np.float32)

                angle_rad = 0
                for t in range(total_frames_per_rotation):
                    artist_array = []

                    for axis in axis_observation_array:
                        axis_image = axis.imshow(make_uint8(blank_image),
                                                 interpolation="none",
                                                 animated=True)
                        artist_array.append(axis_image)

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_images = model.generate_image(
                        query_viewpoints, r)

                    for j, axis in enumerate(axis_generation_array):
                        image = make_uint8(generated_images[j])
                        axis_image = axis.imshow(image,
                                                 interpolation="none",
                                                 animated=True)
                        artist_array.append(axis_image)

                    angle_rad += 2 * math.pi / total_frames_per_rotation

                    # plt.pause(1e-8)
                    axis = axis_generation_array[-1]
                    add_annotation(axis, artist_array)
                    artist_frame_array.append(artist_array)

                # Generate images with observations
                for m in range(num_views_per_scene):
                    observed_image = images[batch_index, m]
                    observed_viewpoint = viewpoints[batch_index, m]

                    observed_image_array[m] = to_gpu(observed_image)
                    observed_viewpoint_array[m] = to_gpu(observed_viewpoint)

                    r = model.compute_observation_representation(
                        observed_image_array[None, :m + 1],
                        observed_viewpoint_array[None, :m + 1])

                    r = cf.broadcast_to(r, (num_generation, ) + r.shape[1:])

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        artist_array = []

                        for axis, observed_image in zip(
                                axis_observation_array, observed_image_array):
                            axis_image = axis.imshow(
                                make_uint8(observed_image),
                                interpolation="none",
                                animated=True)
                            artist_array.append(axis_image)

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images = model.generate_image(
                            query_viewpoints, r)

                        for j in range(num_generation):
                            axis = axis_generation_array[j]
                            axis_image = axis.imshow(make_uint8(
                                generated_images[j]),
                                                     interpolation="none",
                                                     animated=True)
                            artist_array.append(axis_image)

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        # plt.pause(1e-8)

                        axis = axis_generation_array[-1]
                        add_annotation(axis, artist_array)
                        artist_frame_array.append(artist_array)

                # plt.tight_layout()
                # plt.subplots_adjust(
                #     left=None,
                #     bottom=None,
                #     right=None,
                #     top=None,
                #     wspace=0,
                #     hspace=0)
                # interval is specified in milliseconds
                anim = animation.ArtistAnimation(fig,
                                                 artist_frame_array,
                                                 interval=1000 / 24,
                                                 blit=True,
                                                 repeat_delay=0)

                anim.save("{}/rooms_ring_camera_{}.gif".format(
                    args.output_directory, file_number),
                          writer="imagemagick")
                anim.save("{}/rooms_ring_camera_{}.mp4".format(
                    args.output_directory, file_number),
                          writer="ffmpeg",
                          fps=12)
                file_number += 1
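
`rotate_query_viewpoint` builds the query viewpoints for the orbiting camera. A sketch matching the signature used in this example (a later example passes camera distance and height explicitly); the orbit radius and height here are assumptions:

import math
import numpy as np

def rotate_query_viewpoint(angle_rad, num_generation, xp):
    # Sketch: camera on a horizontal circle looking at the scene center,
    # emitting the 7-dim viewpoint (x, y, z, cos yaw, sin yaw,
    # cos pitch, sin pitch) used throughout these examples.
    eye = (3.0 * math.cos(angle_rad), 1.0, 3.0 * math.sin(angle_rad))
    center = (0.0, 0.5, 0.0)
    direction = np.array(center) - np.array(eye)
    yaw = math.atan2(direction[0], direction[2])
    pitch = math.asin(direction[1] / np.linalg.norm(direction))
    query = eye + (math.cos(yaw), math.sin(yaw), math.cos(pitch),
                   math.sin(pitch))
    query_viewpoints = xp.array(query, dtype=np.float32)
    return xp.broadcast_to(query_viewpoints,
                           (num_generation, ) + query_viewpoints.shape)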
Example #16
    return np.array(buffer_actor_step), np.array(
        buffer_learner_step), np.array(buffer_cur_size)


if __name__ == '__main__':

    # ray.init()
    ray.init(resources={"node0": 256})

    # env = gym.make(FLAGS.env_name)
    env = TradingEnv(action_scheme_id=3, obs_dim=38)

    # ------ HyperParameters ------
    opt = HyperParameters(env, FLAGS.env_name, FLAGS.exp_name, FLAGS.num_nodes,
                          FLAGS.num_workers, FLAGS.a_l_ratio,
                          FLAGS.weights_file)

    if FLAGS.recover:
        opt.recover = True
    # ------ end ------

    node_ps = []
    node_buffer = []

    for node_index in range(FLAGS.num_nodes):

        # ------ Parameter Server (ray actor) ------
        # create model to get weights and create a parameter server
        node_ps.append(
            ParameterServer._remote(args=[
Example #17
def main():
    os.makedirs(args.snapshot_directory, exist_ok=True)

    images = []
    files = os.listdir(args.dataset_path)
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 255 * 2.0 - 1.0
        images.append(image)

    images = np.vstack(images)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    # NOTE: this script overfits a single fixed batch of size args.batch_size;
    # the remaining images serve as a dev set for the visual checks below.
    images_train = images[:args.batch_size]
    images_dev = images[args.batch_size:]
    num_dev_images = len(images_dev)

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.layer_normalization_enabled = args.layer_normalization
    hyperparams.pixel_n = args.pixel_n
    hyperparams.chz_channels = args.chz_channels
    hyperparams.inference_channels_downsampler_x = args.channels_downsampler_x
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.chrz_size = (32, 32)
    hyperparams.save(args.snapshot_directory)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_directory)
    if using_gpu:
        model.to_gpu()

    optimizer = AdamOptimizer(model.parameters,
                              lr_i=args.initial_lr,
                              lr_f=args.final_lr)
    optimizer.print()

    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        sigma_t**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(sigma_t**2),
                           dtype="float32")
    num_pixels = images.shape[1] * images.shape[2] * images.shape[3]

    figure = plt.figure(figsize=(20, 4))
    axis_1 = figure.add_subplot(1, 5, 1)
    axis_2 = figure.add_subplot(1, 5, 2)
    axis_3 = figure.add_subplot(1, 5, 3)
    axis_4 = figure.add_subplot(1, 5, 4)
    axis_5 = figure.add_subplot(1, 5, 5)

    for iteration in range(args.training_steps):
        x = to_gpu(images_train)
        loss_kld = 0

        z_t_params_array, r_final = model.generate_z_params_and_x_from_posterior(
            x)
        for params in z_t_params_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params
            kld = draw.nn.functions.gaussian_kl_divergence(
                mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)
            loss_kld += cf.sum(kld)

        mean_x_enc = r_final
        negative_log_likelihood = draw.nn.functions.gaussian_negative_log_likelihood(
            x, mean_x_enc, pixel_var, pixel_ln_var)
        loss_nll = cf.sum(negative_log_likelihood)
        loss_mse = cf.mean_squared_error(mean_x_enc, x)

        loss_nll /= args.batch_size
        loss_kld /= args.batch_size
        # NOTE: the KLD term is dropped here; use `loss = loss_nll + loss_kld`
        # to train on the full ELBO.
        loss = loss_nll
        model.cleargrads()
        loss.backward()
        optimizer.update(iteration)

        sf = hyperparams.pixel_sigma_f
        si = hyperparams.pixel_sigma_i
        sigma_t = max(sf + (si - sf) * (1.0 - iteration / hyperparams.pixel_n),
                      sf)

        pixel_var[...] = sigma_t**2
        pixel_ln_var[...] = math.log(sigma_t**2)

        model.serialize(args.snapshot_directory)
        print(
            "\033[2KIteration {} - loss: nll_per_pixel: {:.6f} - mse: {:.6f} - kld: {:.6f} - lr: {:.4e} - sigma_t: {:.6f}"
            .format(iteration + 1,
                    float(loss_nll.data) / num_pixels, float(loss_mse.data),
                    float(loss_kld.data), optimizer.learning_rate, sigma_t))

        if iteration % 10 == 0:
            axis_1.imshow(make_uint8(x[0]))
            axis_2.imshow(make_uint8(mean_x_enc.data[0]))

            x_dev = images_dev[random.choice(range(num_dev_images))]
            axis_3.imshow(make_uint8(x_dev))

            with chainer.using_config("train", False), chainer.using_config(
                    "enable_backprop", False):
                x_dev = to_gpu(x_dev)[None, ...]
                _, r_final = model.generate_z_params_and_x_from_posterior(
                    x_dev)
                mean_x_enc = r_final
                axis_4.imshow(make_uint8(mean_x_enc.data[0]))

                mean_x_d = model.generate_image(batch_size=1, xp=xp)
                axis_5.imshow(make_uint8(mean_x_d[0]))

            plt.pause(0.01)
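
`gaussian_negative_log_likelihood` is called with the target, the predicted mean, and the annealed pixel variance in both linear and log form. A sketch of the elementwise term it presumably computes:

import math

def gaussian_negative_log_likelihood(x, mean, var, ln_var):
    # Elementwise -log N(x | mean, var); the caller reduces with cf.sum.
    # var and ln_var are passed separately so the annealed pixel variance
    # is computed once per step.
    return 0.5 * (math.log(2 * math.pi) + ln_var + (x - mean) ** 2 / var)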
def main():
    dataset = gqn.data.Dataset(args.dataset_path)
    sampler = gqn.data.Sampler(dataset)
    iterator = gqn.data.Iterator(sampler, batch_size=args.batch_size)

    hyperparams = HyperParameters()
    model = Model(hyperparams, hdf5_path=args.snapshot_path)
    model.to_gpu()

    figure = gqn.imgplot.figure()
    axes = []
    sqrt_batch_size = int(math.sqrt(args.batch_size))
    axis_size = 1.0 / sqrt_batch_size
    for y in range(sqrt_batch_size):
        for x in range(sqrt_batch_size * 2):
            axis = gqn.imgplot.image()
            axes.append(axis)
            figure.add(axis, axis_size / 2 * x, axis_size * y, axis_size / 2,
                       axis_size)
    window = gqn.imgplot.window(figure, (1600, 800), "Reconstructed images")
    window.show()

    with chainer.no_backprop_mode():
        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for data_indices in iterator:
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                image_size = images.shape[2:4]
                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    observed_images = images[:, :num_views]
                    observed_viewpoints = viewpoints[:, :num_views]

                    # (batch, views, height, width, channels) -> (batch * views, height, width, channels)
                    observed_images = observed_images.reshape(
                        (args.batch_size * num_views, ) +
                        observed_images.shape[2:])
                    observed_viewpoints = observed_viewpoints.reshape(
                        (args.batch_size * num_views, ) +
                        observed_viewpoints.shape[2:])

                    # (batch * views, height, width, channels) -> (batch * views, channels, height, width)
                    observed_images = observed_images.transpose((0, 3, 1, 2))

                    # transfer to gpu
                    observed_images = to_gpu(observed_images)
                    observed_viewpoints = to_gpu(observed_viewpoints)

                    r = model.representation_network.compute_r(
                        observed_images, observed_viewpoints)

                    # (batch * views, channels, height, width) -> (batch, views, channels, height, width)
                    r = r.reshape((args.batch_size, num_views) + r.shape[1:])

                    # sum element-wise across views
                    r = cf.sum(r, axis=1)
                else:
                    r = np.zeros((args.batch_size, hyperparams.channels_r) +
                                 hyperparams.chrz_size,
                                 dtype="float32")
                    r = to_gpu(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # (batch, height, width, channels) -> (batch, channels, height, width)
                query_images = query_images.transpose((0, 3, 1, 2))

                # transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                reconstructed_images = model.reconstruct_image(
                    query_images, query_viewpoints, r, xp)

                if window.closed():
                    exit()

                for batch_index in range(args.batch_size):
                    axis = axes[batch_index * 2 + 0]
                    image = query_images[batch_index]
                    axis.update(make_uint8(image))

                    axis = axes[batch_index * 2 + 1]
                    image = reconstructed_images[batch_index]
                    axis.update(make_uint8(image))

                time.sleep(1)
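
`to_gpu` appears in nearly every example. A minimal sketch assuming it mirrors how `xp` is selected from `args.gpu_device`:

from chainer.backends import cuda

def to_gpu(array):
    # Sketch: move the array to the selected GPU, or return it unchanged
    # when running on CPU. `using_gpu` is assumed to be set from
    # args.gpu_device as in the surrounding scripts.
    if using_gpu:
        return cuda.to_gpu(array)
    return array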
Example #19
def main():
    _mkdir(args.snapshot_directory)
    _mkdir(args.log_directory)

    meter_train = Meter()
    meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Workaround to fix OpenMPI bug
    #==============================================================================
    multiprocessing.set_start_method("forkserver")
    p = multiprocessing.Process(target=print, args=("", ))
    p.start()
    p.join()

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    comm = chainermn.create_communicator()
    device = comm.intra_rank
    cuda.get_device(device).use()

    def _print(*args):
        if comm.rank == 0:
            print(*args)

    _print("Using {} GPUs".format(comm.size))

    #==============================================================================
    # Dataset
    #==============================================================================
    dataset_train = Dataset(args.train_dataset_directory)
    dataset_test = None
    if args.test_dataset_directory is not None:
        dataset_test = Dataset(args.test_dataset_directory)

    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    hyperparams.num_layers = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.r_channels = args.r_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_sigma_annealing_steps = args.pixel_sigma_annealing_steps
    hyperparams.initial_pixel_sigma = args.initial_pixel_sigma
    hyperparams.final_pixel_sigma = args.final_pixel_sigma
    _print(hyperparams, "\n")

    if comm.rank == 0:
        hyperparams.save(args.snapshot_directory)

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    model.load(args.snapshot_directory, meter_train.epoch)
    model.to_gpu()

    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler(
        sigma_start=args.initial_pixel_sigma,
        sigma_end=args.final_pixel_sigma,
        final_num_updates=args.pixel_sigma_annealing_steps)
    variance_scheduler.load(args.snapshot_directory)
    _print(variance_scheduler, "\n")

    pixel_log_sigma = cp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(variance_scheduler.standard_deviation),
        dtype="float32")

    #==============================================================================
    # Logging
    #==============================================================================
    csv = DataFrame()
    csv.load(args.log_directory)

    #==============================================================================
    # Optimizer
    #==============================================================================
    optimizer = AdamOptimizer(
        model.parameters,
        initial_lr=args.initial_lr,
        final_lr=args.final_lr,
        initial_training_step=variance_scheduler.training_step)
    _print(optimizer, "\n")

    #==============================================================================
    # Algorithms
    #==============================================================================
    def encode_scene(images, viewpoints):
        # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
        images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)

        # Sample number of views
        total_views = images.shape[1]
        num_views = random.choice(range(1, total_views + 1))

        # Sample views
        observation_view_indices = list(range(total_views))
        random.shuffle(observation_view_indices)
        observation_view_indices = observation_view_indices[:num_views]

        observation_images = preprocess_images(
            images[:, observation_view_indices])

        observation_query = viewpoints[:, observation_view_indices]
        representation = model.compute_observation_representation(
            observation_images, observation_query)

        # Sample query view
        query_index = random.choice(range(total_views))
        query_images = preprocess_images(images[:, query_index])
        query_viewpoints = viewpoints[:, query_index]

        # Transfer to gpu if necessary
        query_images = cuda.to_gpu(query_images)
        query_viewpoints = cuda.to_gpu(query_viewpoints)

        return representation, query_images, query_viewpoints

    def estimate_ELBO(query_images, z_t_param_array, pixel_mean,
                      pixel_log_sigma):
        # KL divergence
        kl_divergence = 0
        for params_t in z_t_param_array:
            mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p = params_t
            # ln_var_* are log-variances; Normal's log_scale expects log-sigma
            normal_q = chainer.distributions.Normal(
                mean_z_q, log_scale=0.5 * ln_var_z_q)
            normal_p = chainer.distributions.Normal(
                mean_z_p, log_scale=0.5 * ln_var_z_p)
            kld_t = chainer.kl_divergence(normal_q, normal_p)
            kl_divergence += cf.sum(kld_t)
        kl_divergence = kl_divergence / args.batch_size

        # Negative log-likelihood of generated image
        batch_size = query_images.shape[0]
        num_pixels_per_batch = np.prod(query_images.shape[1:])
        normal = chainer.distributions.Normal(
            query_images, log_scale=pixel_log_sigma)

        log_px = cf.sum(normal.log_prob(pixel_mean)) / batch_size
        negative_log_likelihood = -log_px

        # Empirical ELBO
        ELBO = log_px - kl_divergence

        # https://arxiv.org/abs/1604.08772 Section.2
        # https://www.reddit.com/r/MachineLearning/comments/56m5o2/discussion_calculation_of_bitsdims/
        bits_per_pixel = -(ELBO / num_pixels_per_batch - np.log(256)) / np.log(
            2)

        return ELBO, bits_per_pixel, negative_log_likelihood, kl_divergence

    #==============================================================================
    # Training iterations
    #==============================================================================
    dataset_size = len(dataset_train)
    random.seed(0)
    np.random.seed(0)
    cp.random.seed(0)

    for epoch in range(args.epochs):
        _print("Epoch {}/{}:".format(
            epoch + 1,
            args.epochs,
        ))
        meter_train.next_epoch()

        subset_indices = list(range(len(dataset_train.subset_filenames)))
        subset_size_per_gpu = len(subset_indices) // comm.size
        if len(subset_indices) % comm.size != 0:
            subset_size_per_gpu += 1

        for subset_loop in range(subset_size_per_gpu):
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset_train.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                #------------------------------------------------------------------------------
                # Scene encoder
                #------------------------------------------------------------------------------
                # images.shape: (batch, views, height, width, channels)
                images, viewpoints = subset[data_indices]
                representation, query_images, query_viewpoints = encode_scene(
                    images, viewpoints)

                #------------------------------------------------------------------------------
                # Compute empirical ELBO
                #------------------------------------------------------------------------------
                # Compute distribution parameters
                (z_t_param_array,
                 pixel_mean) = model.sample_z_and_x_params_from_posterior(
                     query_images, query_viewpoints, representation)

                # Compute ELBO
                (ELBO, bits_per_pixel, negative_log_likelihood,
                 kl_divergence) = estimate_ELBO(query_images, z_t_param_array,
                                                pixel_mean, pixel_log_sigma)

                #------------------------------------------------------------------------------
                # Update parameters
                #------------------------------------------------------------------------------
                loss = -ELBO
                model.cleargrads()
                loss.backward()
                optimizer.update(meter_train.num_updates)

                #------------------------------------------------------------------------------
                # Logging
                #------------------------------------------------------------------------------
                with chainer.no_backprop_mode():
                    mean_squared_error = cf.mean_squared_error(
                        query_images, pixel_mean)
                meter_train.update(
                    ELBO=float(ELBO.data),
                    bits_per_pixel=float(bits_per_pixel.data),
                    negative_log_likelihood=float(
                        negative_log_likelihood.data),
                    kl_divergence=float(kl_divergence.data),
                    mean_squared_error=float(mean_squared_error.data))

                #------------------------------------------------------------------------------
                # Annealing
                #------------------------------------------------------------------------------
                variance_scheduler.update(meter_train.num_updates)
                pixel_log_sigma[...] = math.log(
                    variance_scheduler.standard_deviation)

            if subset_loop % 100 == 0:
                _print("    Subset {}/{}:".format(
                    subset_loop + 1,
                    subset_size_per_gpu,
                ))
                _print("        {}".format(meter_train))
                _print("        lr: {} - sigma: {}".format(
                    optimizer.learning_rate,
                    variance_scheduler.standard_deviation))

        #------------------------------------------------------------------------------
        # Validation
        #------------------------------------------------------------------------------
        meter_test = None
        if dataset_test is not None:
            meter_test = Meter()
            batch_size_test = args.batch_size * 6
            subset_indices_test = list(
                range(len(dataset_test.subset_filenames)))
            pixel_log_sigma_test = cp.full(
                (batch_size_test, 3) + hyperparams.image_size,
                math.log(variance_scheduler.standard_deviation),
                dtype="float32")

            subset_size_per_gpu = len(subset_indices_test) // comm.size

            with chainer.no_backprop_mode():
                for subset_loop in range(subset_size_per_gpu):
                    subset_index = subset_indices_test[subset_loop * comm.size
                                                       + comm.rank]
                    subset = dataset_test.read(subset_index)
                    iterator = gqn.data.Iterator(
                        subset, batch_size=batch_size_test)

                    for data_indices in iterator:
                        images, viewpoints = subset[data_indices]

                        # Scene encoder
                        representation, query_images, query_viewpoints = encode_scene(
                            images, viewpoints)

                        # Compute empirical ELBO
                        (z_t_param_array, pixel_mean
                         ) = model.sample_z_and_x_params_from_posterior(
                             query_images, query_viewpoints, representation)
                        (ELBO, bits_per_pixel, negative_log_likelihood,
                         kl_divergence) = estimate_ELBO(
                             query_images, z_t_param_array, pixel_mean,
                             pixel_log_sigma_test)
                        mean_squared_error = cf.mean_squared_error(
                            query_images, pixel_mean)

                        # Logging
                        meter_test.update(
                            ELBO=float(ELBO.data),
                            bits_per_pixel=float(bits_per_pixel.data),
                            negative_log_likelihood=float(
                                negative_log_likelihood.data),
                            kl_divergence=float(kl_divergence.data),
                            mean_squared_error=float(mean_squared_error.data))

            meter = meter_test.allreduce(comm)

            _print("    Test:")
            _print("        {} - done in {:.3f} min".format(
                meter,
                meter.elapsed_time,
            ))

        # Save snapshots at the end of every epoch (with or without a test set)
        model.save(args.snapshot_directory, meter_train.epoch)
        variance_scheduler.save(args.snapshot_directory)
        meter_train.save(args.snapshot_directory)
        csv.save(args.log_directory)

        _print("Epoch {} done in {:.3f} min".format(
            epoch + 1,
            meter_train.epoch_elapsed_time,
        ))
        _print("    {}".format(meter_train))
        _print("    lr: {} - sigma: {} - training_steps: {}".format(
            optimizer.learning_rate,
            variance_scheduler.standard_deviation,
            meter_train.num_updates,
        ))
        _print("    Time elapsed: {:.3f} min".format(
            meter_train.elapsed_time))
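
`preprocess_images` is used by `encode_scene` above. A sketch assuming the same uint8-to-[-1, 1] mapping that an earlier example applies inline (`image / 255 * 2.0 - 1.0`); whether the repository's helper also adds dequantization noise is an assumption:

import numpy as np

def preprocess_images(images):
    # Sketch: map uint8 pixels into [-1, 1] and add uniform dequantization
    # noise of one quantization step (2/256 on this scale).
    images = images.astype(np.float32) / 255 * 2.0 - 1.0
    images += np.random.uniform(0, 2.0 / 256.0,
                                size=images.shape).astype(np.float32)
    return images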
Example #20
def main():
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters()
    model = Model(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    figure = gqn.imgplot.figure()
    axes = []
    sqrt_batch_size = int(math.sqrt(args.batch_size))
    axis_size = 1.0 / sqrt_batch_size
    for y in range(sqrt_batch_size):
        for x in range(sqrt_batch_size * 2):
            axis = gqn.imgplot.image()
            axes.append(axis)
            figure.add(axis, axis_size / 2 * x, axis_size * y, axis_size / 2,
                       axis_size)
    window = gqn.imgplot.window(figure, (1600, 800), "Generated images")
    window.show()

    camera = gqn.three.PerspectiveCamera(
        eye=(3, 1, 0),
        center=(0, 0, 0),
        up=(0, 1, 0),
        fov_rad=math.pi / 4.0,
        aspect_ratio=hyperparams.image_size[0] / hyperparams.image_size[1],
        z_near=0.1,
        z_far=10)

    with chainer.no_backprop_mode():
        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for data_indices in iterator:
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                image_size = images.shape[2:4]
                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    observed_images = images[:, :num_views]
                    observed_viewpoints = viewpoints[:, :num_views]

                    # (batch, views, height, width, channels) -> (batch * views, height, width, channels)
                    observed_images = observed_images.reshape(
                        (args.batch_size * num_views, ) +
                        observed_images.shape[2:])
                    observed_viewpoints = observed_viewpoints.reshape(
                        (args.batch_size * num_views, ) +
                        observed_viewpoints.shape[2:])

                    # (batch * views, height, width, channels) -> (batch * views, channels, height, width)
                    observed_images = observed_images.transpose((0, 3, 1, 2))

                    # transfer to gpu
                    observed_images = to_gpu(observed_images)
                    observed_viewpoints = to_gpu(observed_viewpoints)

                    r = model.representation_network.compute_r(
                        observed_images, observed_viewpoints)

                    # (batch * views, channels, height, width) -> (batch, views, channels, height, width)
                    r = r.reshape((args.batch_size, num_views) + r.shape[1:])

                    # sum element-wise across views
                    r = cf.sum(r, axis=1)
                else:
                    r = np.zeros((args.batch_size, hyperparams.channels_r) +
                                 hyperparams.chrz_size,
                                 dtype="float32")
                    r = to_gpu(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # transfer to gpu
                query_viewpoints = to_gpu(query_viewpoints)

                total_frames = 100
                for tick in range(total_frames):
                    rad = math.pi * 2 * tick / total_frames

                    eye = (3.0 * math.cos(rad), 1, 3.0 * math.sin(rad))
                    center = (0.0, 0.5, 0.0)

                    yaw = gqn.math.yaw(eye, center)
                    pitch = gqn.math.pitch(eye, center)
                    camera.look_at(
                        eye=eye,
                        center=center,
                        up=(0.0, 1.0, 0.0),
                    )
                    # viewpoint format: (x, y, z, cos yaw, sin yaw, cos pitch, sin pitch)
                    query = eye + (math.cos(yaw), math.sin(yaw),
                                   math.cos(pitch), math.sin(pitch))
                    query_viewpoints[:] = xp.asarray(query)

                    generated_images = model.generate_image(
                        query_viewpoints, r, xp)

                    if window.closed():
                        exit()

                    for batch_index in range(args.batch_size):
                        axis = axes[batch_index * 2 + 0]
                        image = query_images[batch_index]
                        axis.update(make_uint8(image))

                        axis = axes[batch_index * 2 + 1]
                        image = generated_images[batch_index]
                        axis.update(make_uint8(image))
Example #21
        time1 = time2
        sample_times1 = sample_times2

        # if steps >= opt.total_epochs * opt.steps_per_epoch:
        #     exit(0)
        # if time2 - time0 > 30:
        #     exit(0)

        time.sleep(5)


if __name__ == '__main__':

    ray.init()

    opt = HyperParameters(FLAGS.env_name, FLAGS.total_epochs,
                          FLAGS.num_workers, FLAGS.a_l_ratio)

    # Create a parameter server with some random weights.
    if FLAGS.is_restore == "True":
        ps = ParameterServer.remote([], [], is_restore=True)
    else:
        net = Learner(opt, job="main")
        all_keys, all_values = net.get_weights()
        ps = ParameterServer.remote(all_keys, all_values)

    replay_buffer = ReplayBuffer.remote(obs_dim=opt.obs_dim,
                                        act_dim=opt.act_dim,
                                        size=opt.replay_size)

    # Start some training tasks.
    task_rollout = [
Example #22
def gqn_process():
    # load model
    my_gpu = args.gpu_device
    if my_gpu < 0:
        xp = np
    else:
        cuda.get_device(args.gpu_device).use()
        xp = cp
    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)

    model = Model(hyperparams)
    chainer.serializers.load_hdf5(args.snapshot_file, model)
    if my_gpu > -1:
        model.to_gpu()
    chainer.print_runtime_info()

    observed_viewpoint, observed_image, offset = data_recv.get()
    observed_viewpoint = np.expand_dims(np.expand_dims(
        np.asarray(observed_viewpoint).astype(np.float32), axis=0),
                                        axis=0)
    observed_image = np.expand_dims(np.expand_dims(
        np.asarray(observed_image).astype(np.float32), axis=0),
                                    axis=0)
    offset = np.asarray(offset)

    camera_distance = np.mean(
        np.linalg.norm(observed_viewpoint[:, :, :3], axis=2))
    camera_position_z = np.mean(observed_viewpoint[:, :, 1])

    observed_image = observed_image.transpose(
        (0, 1, 4, 2, 3)).astype(np.float32)
    observed_image = preprocess_images(observed_image)

    # create representation and generate uncertainty map of environment [1000 viewpoints?]
    total_frames = 10
    representation = model.compute_observation_representation(
        observed_image, observed_viewpoint)

    # get predictions
    highest_var = 0.0
    no_of_samples = 20
    highest_var_vp = 0
    try:
        for i in range(0, total_frames):
            horizontal_angle_rad = compute_camera_angle_at_frame(
                i, total_frames)

            query_viewpoints = rotate_query_viewpoint(horizontal_angle_rad,
                                                      camera_distance,
                                                      camera_position_z, xp)

            generated_images = xp.squeeze(
                xp.array(
                    model.generate_images(query_viewpoints, representation,
                                          no_of_samples)))
            var_image = xp.var(generated_images, axis=0)
            # var_image = chainer.backends.cuda.to_cpu(var_image)
            # grayscale
            # r,g,b = var_image
            # gray_var_image = 0.2989*r+0.5870*g+0.1140*b
            current_var = xp.mean(var_image)

            if highest_var == 0:
                highest_var = current_var
                highest_var_vp = query_viewpoints[0]
            elif current_var > highest_var:
                highest_var = current_var
                highest_var_vp = query_viewpoints[0]
    except KeyboardInterrupt:
        logging.warning('interrupt')

    # return next viewpoint and unit vector of end effector based on highest uncertainty found in the uncertainty map
    _x, _y, _z, _, _, _, _ = highest_var_vp

    _yaw, _pitch = compute_yaw_and_pitch([_x, _y, _z])
    next_viewpoint = [_x, _y, _z, _yaw, _pitch]
    next_viewpoint = [chainer.backends.cuda.to_cpu(x) for x in next_viewpoint]
    next_viewpoint = [float(x) for x in next_viewpoint]
    data_send.put(next_viewpoint)
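The loop in gqn_process above implements a simple next-best-view rule: render several samples per candidate viewpoint, score each viewpoint by the mean per-pixel variance across the samples, and move toward the highest-scoring one. A minimal NumPy sketch of that selection step, with a hypothetical generate_samples(viewpoint, n) callable standing in for model.generate_images:

import numpy as np

def select_most_uncertain_viewpoint(candidate_viewpoints, generate_samples,
                                    num_samples=20):
    # Score each candidate by the mean per-pixel variance across samples
    # and keep the viewpoint the model is least certain about.
    best_score, best_viewpoint = -np.inf, None
    for viewpoint in candidate_viewpoints:
        samples = generate_samples(viewpoint, num_samples)  # (n, c, h, w)
        score = float(np.var(samples, axis=0).mean())
        if score > best_score:
            best_score, best_viewpoint = score, viewpoint
    return best_viewpoint, best_score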
Example #23
def worker_test(ps, replay_buffer, opt):
    agent = Actor(opt, job="main")

    test_env = Wrapper(gym.make(opt.env_name), opt.obs_noise, opt.act_noise,
                       opt.reward_scale, 3)

    agent.test(ps, replay_buffer, opt, test_env)


if __name__ == '__main__':

    # ray.init(object_store_memory=1000000000, redis_max_memory=1000000000)
    ray.init()

    # ------ HyperParameters ------
    opt = HyperParameters(FLAGS.env_name, FLAGS.exp_name, FLAGS.num_workers,
                          FLAGS.a_l_ratio, FLAGS.weights_file)
    All_Parameters = copy.deepcopy(vars(opt))
    All_Parameters["wrapper"] = inspect.getsource(Wrapper)
    All_Parameters["obs_space"] = ""
    All_Parameters["act_space"] = ""

    try:
        os.makedirs(opt.save_dir)
    except OSError:
        pass
    with open(opt.save_dir + "/" + 'All_Parameters.json', 'w') as fp:
        json.dump(All_Parameters, fp, indent=4, sort_keys=True)

    # ------ end ------

    if FLAGS.weights_file:
Example #24
def main():
    try:
        os.mkdir(args.output_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(10, 5))

    num_views_per_scene = 4
    num_generation = 2  # lessened from 4 to 2 (remaining 2 used for original output)
    num_original = 2
    total_frames_per_rotation = 24

    image_shape = (3, ) + hyperparams.image_size
    blank_image = make_uint8(np.full(image_shape, 0))
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                snapshot_array = []

                observed_image_array = xp.zeros(
                    (num_views_per_scene, ) + image_shape, dtype=np.float32)
                observed_viewpoint_array = xp.zeros((num_views_per_scene, 7),
                                                    dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints, original_images = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = images / 255.0
                images += np.random.uniform(
                    0, 1.0 / 256.0, size=images.shape).astype(np.float32)

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                original_images = original_images.transpose(
                    (0, 1, 4, 2, 3)).astype(np.float32)
                original_images = original_images / 255.0
                original_images += np.random.uniform(
                    0, 1.0 / 256.0,
                    size=original_images.shape).astype(np.float32)

                batch_index = 0

                # Generate images without observations
                r = xp.zeros((
                    num_generation,
                    hyperparams.representation_channels,
                ) + hyperparams.chrz_size,
                             dtype=np.float32)

                angle_rad = 0
                current_scene_original_images_cpu = original_images[
                    batch_index]
                current_scene_original_images = to_gpu(
                    current_scene_original_images_cpu)

                for t in range(total_frames_per_rotation):
                    snapshot = gqn.animator.Snapshot((2, 4))

                    for i in [1, 2, 5, 6]:
                        snapshot.add_media(media_type='image',
                                           media_data=make_uint8(blank_image),
                                           media_position=i)

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_images = model.generate_image(
                        query_viewpoints, r, xp)

                    for i in [3, 4]:
                        snapshot.add_media(media_type='image',
                                           media_data=make_uint8(
                                               generated_images[0]),
                                           media_position=i)

                    for i in [7, 8]:
                        snapshot.add_media(
                            media_type='image',
                            media_data=make_uint8(
                                current_scene_original_images[t]),
                            media_position=i)

                    angle_rad += 2 * math.pi / total_frames_per_rotation

                    snapshot_array.append(snapshot)

                # Generate images with observations
                for m in range(num_views_per_scene):
                    observed_image = images[batch_index, m]
                    observed_viewpoint = viewpoints[batch_index, m]

                    observed_image_array[m] = to_gpu(observed_image)
                    observed_viewpoint_array[m] = to_gpu(observed_viewpoint)

                    r = model.compute_observation_representation(
                        observed_image_array[None, :m + 1],
                        observed_viewpoint_array[None, :m + 1])

                    r = cf.broadcast_to(r, (num_generation, ) + r.shape[1:])

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        snapshot = gqn.animator.Snapshot((2, 4))

                        for i, observed_image in zip([1, 2, 5, 6],
                                                     observed_image_array):
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(observed_image),
                                media_position=i)

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images = model.generate_image(
                            query_viewpoints, r, xp)

                        for i in [3, 4]:
                            snapshot.add_media(media_type='image',
                                               media_data=make_uint8(
                                                   generated_images[0]),
                                               media_position=i)

                        for i in [7, 8]:
                            snapshot.add_media(
                                media_type='image',
                                media_data=make_uint8(
                                    current_scene_original_images[t]),
                                media_position=i)

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        # plt.pause(1e-8)

                        snapshot_array.append(snapshot)

                plt.subplots_adjust(left=None,
                                    bottom=None,
                                    right=None,
                                    top=None,
                                    wspace=0,
                                    hspace=0)

                anim = animation.FuncAnimation(
                    fig,
                    func_anim_upate,
                    fargs=(fig, [snapshot_array]),
                    interval=1 / 24,
                    frames=(num_views_per_scene + 1) *
                    total_frames_per_rotation)

                anim.save("{}/shepard_matzler_{}.mp4".format(
                    args.output_directory, file_number),
                          writer="ffmpeg",
                          fps=12)

                generated_pic_list = []
                for i in range(
                    (num_views_per_scene + 1) * total_frames_per_rotation):
                    snapshot = snapshot_array[i]  # snapshot_array is a flat list of Snapshots
                    media = snapshot.get_subplot(3)
                    generated_pic_list.append(media['body'])

                for i in range(len(generated_pic_list)):
                    plt.figure()
                    plt.imshow(generated_pic_list[i])
                    plt.savefig("{}/shepard_matzler_{}_{}.jpg".format(
                        args.output_directory, file_number, i))
                    plt.close()  # free the figure before the next frame

                file_number += 1
Example #25
def main():
    try:
        os.mkdir(args.snapshot_path)
    except FileExistsError:
        pass

    comm = chainermn.create_communicator()
    device = comm.intra_rank
    print("device", device, "/", comm.size)
    cuda.get_device(device).use()
    xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.channels_chz = args.channels_chz
    hyperparams.generator_channels_u = args.channels_u
    hyperparams.inference_channels_map_x = args.channels_map_x
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    if comm.rank == 0:
        hyperparams.save(args.snapshot_path)
        hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    model.to_gpu()

    optimizer = Optimizer(
        model.parameters,
        communicator=comm,
        mu_i=args.initial_lr,
        mu_f=args.final_lr)
    if comm.rank == 0:
        optimizer.print()

    dataset_mean, dataset_std = dataset.load_mean_and_std()

    if comm.rank == 0:
        np.save(os.path.join(args.snapshot_path, "mean.npy"), dataset_mean)
        np.save(os.path.join(args.snapshot_path, "std.npy"), dataset_std)

    # avoid division by zero
    dataset_std += 1e-12

    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        sigma_t**2,
        dtype="float32")
    pixel_ln_var = xp.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(sigma_t**2),
        dtype="float32")

    random.seed(0)
    subset_indices = list(range(len(dataset.subset_filenames)))

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        total_batch = 0
        subset_size_per_gpu = len(subset_indices) // comm.size
        start_time = time.time()

        for subset_loop in range(subset_size_per_gpu):
            random.shuffle(subset_indices)
            subset_index = subset_indices[comm.rank]
            subset = dataset.read(subset_index)
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # preprocessing
                images = (images - dataset_mean) / dataset_std

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3))

                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if current_training_step == 0 and num_views == 0:
                    num_views = 1  # avoid OpenMPI error

                if num_views > 0:
                    r = model.compute_observation_representation(
                        images[:, :num_views], viewpoints[:, :num_views])
                else:
                    r = xp.zeros(
                        (args.batch_size, hyperparams.channels_r) +
                        hyperparams.chrz_size,
                        dtype="float32")
                    r = chainer.Variable(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]
                # transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                h0_gen, c0_gen, u_0, h0_enc, c0_enc = model.generate_initial_state(
                    args.batch_size, xp)

                loss_kld = 0

                hl_enc = h0_enc
                cl_enc = c0_enc
                hl_gen = h0_gen
                cl_gen = c0_gen
                ul_enc = u_0

                xq = model.inference_downsampler.downsample(query_images)

                for l in range(model.generation_steps):
                    inference_core = model.get_inference_core(l)
                    inference_posterior = model.get_inference_posterior(l)
                    generation_core = model.get_generation_core(l)
                    generation_prior = model.get_generation_prior(l)

                    h_next_enc, c_next_enc = inference_core.forward_onestep(
                        hl_gen, hl_enc, cl_enc, xq, query_viewpoints, r)

                    mean_z_q = inference_posterior.compute_mean_z(hl_enc)
                    ln_var_z_q = inference_posterior.compute_ln_var_z(hl_enc)
                    ze_l = cf.gaussian(mean_z_q, ln_var_z_q)

                    mean_z_p = generation_prior.compute_mean_z(hl_gen)
                    ln_var_z_p = generation_prior.compute_ln_var_z(hl_gen)

                    h_next_gen, c_next_gen, u_next_enc = generation_core.forward_onestep(
                        hl_gen, cl_gen, ul_enc, ze_l, query_viewpoints, r)

                    kld = gqn.nn.chainer.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)

                    loss_kld += cf.sum(kld)

                    hl_gen = h_next_gen
                    cl_gen = c_next_gen
                    ul_enc = u_next_enc
                    hl_enc = h_next_enc
                    cl_enc = c_next_enc

                mean_x = model.generation_observation.compute_mean_x(ul_enc)
                negative_log_likelihood = gqn.nn.chainer.functions.gaussian_negative_log_likelihood(
                    query_images, mean_x, pixel_var, pixel_ln_var)
                loss_nll = cf.sum(negative_log_likelihood)

                loss_nll /= args.batch_size
                loss_kld /= args.batch_size
                loss = loss_nll + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                if comm.rank == 0:
                    printr(
                        "Iteration {}: Subset {} / {}: Batch {} / {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f}".
                        format(iteration + 1, subset_loop * comm.size + 1,
                               len(dataset), batch_index + 1,
                               len(subset) // args.batch_size,
                               float(loss_nll.data), float(loss_kld.data),
                               optimizer.learning_rate, sigma_t))

                sf = hyperparams.pixel_sigma_f
                si = hyperparams.pixel_sigma_i
                sigma_t = max(
                    sf + (si - sf) *
                    (1.0 - current_training_step / hyperparams.pixel_n), sf)

                pixel_var[...] = sigma_t**2
                pixel_ln_var[...] = math.log(sigma_t**2)

                total_batch += 1
                current_training_step += comm.size
                # current_training_step += 1
                mean_kld += float(loss_kld.data)
                mean_nll += float(loss_nll.data)

            if comm.rank == 0:
                model.serialize(args.snapshot_path)

        if comm.rank == 0:
            elapsed_time = time.time() - start_time
            print(
                "\033[2KIteration {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f} - step: {} - elapsed_time: {:.3f} min".
                format(iteration + 1, mean_nll / total_batch,
                       mean_kld / total_batch, optimizer.learning_rate,
                       sigma_t, current_training_step, elapsed_time / 60))
            model.serialize(args.snapshot_path)
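The sigma_t update in the inner loop above is a linear pixel-variance annealing schedule: the observation standard deviation decays from pixel_sigma_i to pixel_sigma_f over pixel_n training steps and is then clamped. A self-contained sketch of that schedule (parameter names taken from the hyperparameters above):

import math

def annealed_pixel_sigma(step, sigma_i, sigma_f, pixel_n):
    # Linear interpolation from sigma_i down to sigma_f, clamped at sigma_f.
    sigma_t = max(sigma_f + (sigma_i - sigma_f) * (1.0 - step / pixel_n),
                  sigma_f)
    return sigma_t, sigma_t ** 2, math.log(sigma_t ** 2)

# Usage: sigma_t, var, ln_var = annealed_pixel_sigma(
#     current_training_step, hyperparams.pixel_sigma_i,
#     hyperparams.pixel_sigma_f, hyperparams.pixel_n)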
Example #26
def main():
    try:
        os.mkdir(args.figure_directory)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_path)
    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    plt.style.use("dark_background")
    fig = plt.figure(figsize=(10, 5))
    fig.suptitle("GQN")
    axis_observations = fig.add_subplot(1, 2, 1)
    axis_observations.axis("off")
    axis_observations.set_title("Observations")
    axis_generation = fig.add_subplot(1, 2, 2)
    axis_generation.axis("off")
    axis_generation.set_title("Generation")

    total_observations_per_scene = 2**2
    num_observations_per_column = int(math.sqrt(total_observations_per_scene))
    num_generation = 1
    total_frames_per_rotation = 48

    black_color = -0.5
    image_shape = (3, ) + hyperparams.image_size
    axis_observations_image = np.full(
        (3, num_observations_per_column * image_shape[1],
         num_observations_per_column * image_shape[2]),
        black_color,
        dtype=np.float32)
    file_number = 1

    with chainer.no_backprop_mode():
        for subset in dataset:
            iterator = gqn.data.Iterator(subset, batch_size=1)

            for data_indices in iterator:
                animation_frame_array = []
                axis_observations_image[...] = black_color

                observed_image_array = xp.full(
                    (total_observations_per_scene, ) + image_shape,
                    black_color,
                    dtype=np.float32)
                observed_viewpoint_array = xp.zeros(
                    (total_observations_per_scene, 7), dtype=np.float32)

                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # (batch, views, height, width, channels) -> (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
                images = preprocess_images(images)

                batch_index = 0

                # Generate images without observations
                representation = xp.zeros((
                    num_generation,
                    hyperparams.representation_channels,
                ) + (hyperparams.image_size[0] // 4,
                     hyperparams.image_size[1] // 4),
                                          dtype=np.float32)

                angle_rad = 0
                for t in range(total_frames_per_rotation):
                    artist_array = [
                        axis_observations.imshow(
                            make_uint8(axis_observations_image),
                            interpolation="none",
                            animated=True)
                    ]

                    query_viewpoints = rotate_query_viewpoint(
                        angle_rad, num_generation, xp)
                    generated_image = model.generate_image_from_zero_z(
                        query_viewpoints, representation)[0]

                    artist_array.append(
                        axis_generation.imshow(make_uint8(generated_image),
                                               interpolation="none",
                                               animated=True))

                    angle_rad += 2 * math.pi / total_frames_per_rotation
                    animation_frame_array.append(artist_array)

                # Generate images with observations
                for observation_index in range(total_observations_per_scene):
                    observed_image = images[batch_index, observation_index]
                    observed_viewpoint = viewpoints[batch_index,
                                                    observation_index]

                    observed_image_array[observation_index] = to_gpu(
                        observed_image)
                    observed_viewpoint_array[observation_index] = to_gpu(
                        observed_viewpoint)

                    representation = model.compute_observation_representation(
                        observed_image_array[None, :observation_index + 1],
                        observed_viewpoint_array[None, :observation_index + 1])

                    representation = cf.broadcast_to(representation,
                                                     (num_generation, ) +
                                                     representation.shape[1:])

                    # Update figure
                    x_start = image_shape[1] * (observation_index %
                                                num_observations_per_column)
                    x_end = x_start + image_shape[1]
                    y_start = image_shape[2] * (observation_index //
                                                num_observations_per_column)
                    y_end = y_start + image_shape[2]
                    axis_observations_image[:, y_start:y_end,
                                            x_start:x_end] = observed_image

                    angle_rad = 0
                    for t in range(total_frames_per_rotation):
                        artist_array = [
                            axis_observations.imshow(
                                make_uint8(axis_observations_image),
                                interpolation="none",
                                animated=True)
                        ]

                        query_viewpoints = rotate_query_viewpoint(
                            angle_rad, num_generation, xp)
                        generated_images = model.generate_image_from_zero_z(
                            query_viewpoints, representation)[0]

                        artist_array.append(
                            axis_generation.imshow(
                                make_uint8(generated_images),
                                interpolation="none",
                                animated=True))

                        angle_rad += 2 * math.pi / total_frames_per_rotation
                        animation_frame_array.append(artist_array)

                anim = animation.ArtistAnimation(fig,
                                                 animation_frame_array,
                                                 interval=1 / 24,
                                                 blit=True,
                                                 repeat_delay=0)

                anim.save("{}/shepard_matzler_observations_{}.gif".format(
                    args.figure_directory, file_number),
                          writer="imagemagick")
                anim.save("{}/shepard_matzler_observations_{}.mp4".format(
                    args.figure_directory, file_number),
                          writer="ffmpeg",
                          fps=12)
                file_number += 1
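rotate_query_viewpoint, used throughout these examples, encodes an orbiting camera as the 7-dimensional viewpoint vector (x, y, z, cos yaw, sin yaw, cos pitch, sin pitch). A standalone NumPy sketch of that encoding, assuming yaw is measured about the vertical axis and pitch from the horizontal plane, the same convention as the in-example helper shown in Example #29 below:

import math
import numpy as np

def orbit_viewpoint(angle_rad, camera_distance, camera_y):
    # Camera orbits the scene center at a fixed height and distance.
    position = np.array([camera_distance * math.sin(angle_rad),
                         camera_y,
                         camera_distance * math.cos(angle_rad)])
    direction = position - np.array([0.0, camera_y, 0.0])
    yaw = math.atan2(direction[0], direction[2])
    pitch = math.asin(direction[1] / np.linalg.norm(direction))
    return np.array([position[0], position[1], position[2],
                     math.cos(yaw), math.sin(yaw),
                     math.cos(pitch), math.sin(pitch)], dtype=np.float32)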
Example #27
def main():

    os.environ['CHAINER_SEED'] = str(args.seed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    _mkdir(args.snapshot_directory)
    _mkdir(args.log_directory)

    meter_train = Meter()
    meter_train.load(args.snapshot_directory)

    #==============================================================================
    # Dataset
    #==============================================================================
    def read_npy_files(directory):
        filenames = []
        files = os.listdir(os.path.join(directory, "images"))
        for filename in files:
            if filename.endswith(".npy"):
                filenames.append(filename)
        filenames.sort()
        
        dataset_images = []
        dataset_viewpoints = []
        for i in range(len(filenames)):
            images_npy_path = os.path.join(directory, "images", filenames[i])
            viewpoints_npy_path = os.path.join(directory, "viewpoints", filenames[i])
            tmp_images = np.load(images_npy_path)
            tmp_viewpoints = np.load(viewpoints_npy_path)
        
            assert tmp_images.shape[0] == tmp_viewpoints.shape[0]
            
            dataset_images.extend(tmp_images)
            dataset_viewpoints.extend(tmp_viewpoints)
        dataset_images = np.array(dataset_images)
        dataset_viewpoints = np.array(dataset_viewpoints)

        dataset = list()
        for i in range(len(dataset_images)):
            item = {'image':dataset_images[i],'viewpoint':dataset_viewpoints[i]}
            dataset.append(item)
        
        return dataset

    def read_files(directory):
        filenames = []
        files = os.listdir(directory)
        
        for filename in files:
            if filename.endswith(".h5"):
                filenames.append(filename)
        filenames.sort()
        
        dataset_images = []
        dataset_viewpoints = []
        for i in range(len(filenames)):
            F = h5py.File(os.path.join(directory, filenames[i]), "r")
            tmp_images = list(F["images"])
            tmp_viewpoints = list(F["viewpoints"])
            
            dataset_images.extend(tmp_images)
            dataset_viewpoints.extend(tmp_viewpoints)
        
        dataset_images = np.array(dataset_images)
        dataset_viewpoints = np.array(dataset_viewpoints)

        dataset = list()
        for i in range(len(dataset_images)):
            item = {'image':dataset_images[i],'viewpoint':dataset_viewpoints[i]}
            dataset.append(item)
        
        return dataset
    
    dataset_train = read_files(args.train_dataset_directory)
    # ipdb.set_trace()
    if args.test_dataset_directory is not None:
        dataset_test = read_files(args.test_dataset_directory)
    
    # ipdb.set_trace()
    
    #==============================================================================
    # Hyperparameters
    #==============================================================================
    hyperparams = HyperParameters()
    hyperparams.num_layers = args.generation_steps
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.h_channels = args.h_channels
    hyperparams.z_channels = args.z_channels
    hyperparams.u_channels = args.u_channels
    hyperparams.r_channels = args.r_channels
    hyperparams.image_size = (args.image_size, args.image_size)
    hyperparams.representation_architecture = args.representation_architecture
    hyperparams.pixel_sigma_annealing_steps = args.pixel_sigma_annealing_steps
    hyperparams.initial_pixel_sigma = args.initial_pixel_sigma
    hyperparams.final_pixel_sigma = args.final_pixel_sigma

    hyperparams.save(args.snapshot_directory)
    print(hyperparams, "\n")

    #==============================================================================
    # Model
    #==============================================================================
    model = Model(hyperparams)
    model.load(args.snapshot_directory, meter_train.epoch)
    
    #==============================================================================
    # Pixel-variance annealing
    #==============================================================================
    variance_scheduler = PixelVarianceScheduler(
        sigma_start=args.initial_pixel_sigma,
        sigma_end=args.final_pixel_sigma,
        final_num_updates=args.pixel_sigma_annealing_steps)
    variance_scheduler.load(args.snapshot_directory)
    print(variance_scheduler, "\n")

    pixel_log_sigma = np.full(
        (args.batch_size, 3) + hyperparams.image_size,
        math.log(variance_scheduler.standard_deviation),
        dtype="float32")

    #==============================================================================
    # Selecting the GPU
    #==============================================================================
    # xp = np
    # gpu_device = args.gpu_device
    # using_gpu = gpu_device >= 0
    # if using_gpu:
    #     cuda.get_device(gpu_device).use()
    #     xp = cp

    # devices = tuple([chainer.get_device(f"@cupy:{gpu}") for gpu in args.gpu_devices])
    # if any(device.xp is chainerx for device in devices):
    #     sys.stderr.write("Cannot support ChainerX devices.")
    #     sys.exit(1)

    ngpu = args.ngpu
    using_gpu = ngpu > 0
    xp = cp
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    #==============================================================================
    # Logging
    #==============================================================================
    csv = DataFrame()
    csv.load(args.log_directory)

    #==============================================================================
    # Optimizer
    #==============================================================================
    initial_training_step = 0
    # lr = compute_lr_at_step(initial_training_step) # function in GQN AdamOptimizer
    
    optimizer = chainer.optimizers.Adam(beta1=0.9, beta2=0.99, eps=1e-8) #lr is needed originally
    optimizer.setup(model)
    # optimizer = AdamOptimizer(
    #     model.parameters,
    #     initial_lr=args.initial_lr,
    #     final_lr=args.final_lr,
    #     initial_training_step=variance_scheduler.training_step)
    # )
    print(optimizer, "\n")


    #==============================================================================
    # Training iterations
    #==============================================================================
    if ngpu > 1:
        # Each iterator draws from its own split; the original passed the
        # full dataset_train to every device.
        train_iters = [
            chainer.iterators.MultiprocessIterator(
                sub_dataset,
                args.batch_size,
                n_processes=args.number_processes,
                order_sampler=chainer.iterators.ShuffleOrderSampler())
            for sub_dataset in chainer.datasets.split_dataset_n_random(
                dataset_train, len(devices))
        ]
        updater = CustomParallelUpdater(
            train_iters,
            optimizer,
            devices,
            converter=chainer.dataset.concat_examples,
            pixel_log_sigma=pixel_log_sigma)

    elif ngpu == 1:
        train_iters = chainer.iterators.SerialIterator(
            dataset_train, args.batch_size, shuffle=True)
        updater = CustomUpdater(
            train_iters,
            optimizer,
            device=0,
            converter=chainer.dataset.concat_examples,
            pixel_log_sigma=pixel_log_sigma)

    else:
        raise NotImplementedError('Implement for single gpu or cpu')

    trainer = chainer.training.Trainer(updater, (args.epochs, 'epoch'),
                                       args.snapshot_directory)
    
    trainer.extend(
        AnnealLearningRate(initial_lr=args.initial_lr,
                           final_lr=args.final_lr,
                           annealing_steps=args.pixel_sigma_annealing_steps,
                           optimizer=optimizer),
        trigger=(1, 'iteration'))

    # add information per epoch with report?
    # add learning rate annealing, snapshot saver, evaluator
    trainer.extend(extensions.LogReport())
    
    trainer.extend(
        extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}',
                            savefun=chainer.serializers.save_hdf5,
                            target=optimizer.target),
        trigger=(args.report_interval_iters, 'epoch'))
    
    trainer.extend(extensions.ProgressBar())
    reports = ['epoch', 'main/loss', 'main/bits_per_pixel', 'main/NLL', 'main/MSE']
    # Validation
    if args.test_dataset_directory is not None:
        test_iters = chainer.iterators.SerialIterator(
            dataset_test, args.batch_size * 6, repeat=False, shuffle=False)

        trainer.extend(
            Validation(test_iters, chainer.dataset.concat_examples,
                       optimizer.target, variance_scheduler, device=0))

        reports.append('validation/main/bits_per_pixel')
        reports.append('validation/main/NLL')
        reports.append('validation/main/MSE')
    reports.append('elapsed_time')

    trainer.extend(
        extensions.PrintReport(reports), trigger=(args.report_interval_iters, 'iteration')) 

    # np.random.seed(args.seed)
    # cp.random.seed(args.seed)

    trainer.run()
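AnnealLearningRate above is a project-specific Trainer extension whose definition is not shown here. A minimal sketch of an equivalent Chainer extension that linearly anneals Adam's alpha once per iteration (a hypothetical re-implementation, not the project's class):

import chainer

def make_lr_annealer(initial_lr, final_lr, annealing_steps):
    @chainer.training.make_extension(trigger=(1, 'iteration'))
    def anneal(trainer):
        # Adam exposes its learning rate through the `alpha` attribute.
        optimizer = trainer.updater.get_optimizer('main')
        t = min(trainer.updater.iteration / annealing_steps, 1.0)
        optimizer.alpha = initial_lr + (final_lr - initial_lr) * t
    return anneal

# trainer.extend(make_lr_annealer(args.initial_lr, args.final_lr,
#                                 args.pixel_sigma_annealing_steps))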
Example #28
def main():
    try:
        os.mkdir(args.snapshot_directory)
    except FileExistsError:
        pass

    images = []
    files = os.listdir(args.dataset_path)
    files.sort()
    for filename in files:
        image = np.load(os.path.join(args.dataset_path, filename))
        image = image / 255
        images.append(image)

    images = np.vstack(images)
    images = images.transpose((0, 3, 1, 2)).astype(np.float32)
    train_dev_split = 0.9
    num_images = images.shape[0]
    num_train_images = int(num_images * train_dev_split)
    num_dev_images = num_images - num_train_images
    images_train = images[:num_train_images]
    images_dev = images[num_train_images:]  # dev split is the remaining 10%

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cp

    hyperparams = HyperParameters(snapshot_directory=args.snapshot_directory)
    hyperparams.print()

    if hyperparams.use_gru:
        model = GRUModel(hyperparams,
                         snapshot_directory=args.snapshot_directory)
    else:
        model = LSTMModel(hyperparams,
                          snapshot_directory=args.snapshot_directory)
    if using_gpu:
        model.to_gpu()

    dataset = draw.data.Dataset(images_dev)
    iterator = draw.data.Iterator(dataset, batch_size=1)

    cols = hyperparams.generator_generation_steps
    figure = plt.figure(figsize=(8, 4 * cols))
    axis_1 = figure.add_subplot(cols, 3, 1)
    axis_1.set_title("Data")

    axis_rec_array = []
    for n in range(cols):
        axis_rec_array.append(figure.add_subplot(cols, 3, n * 3 + 2))

    axis_rec_array[0].set_title("Reconstruction")

    axis_gen_array = []
    for n in range(cols):
        axis_gen_array.append(figure.add_subplot(cols, 3, n * 3 + 3))

    axis_gen_array[0].set_title("Generation")

    for batch_index, data_indices in enumerate(iterator):

        with chainer.using_config("train", False), chainer.using_config(
                "enable_backprop", False):
            x = dataset[data_indices]
            x = to_gpu(x)
            axis_1.imshow(make_uint8(x[0]))

            r_t_array, x_param = model.sample_image_at_each_step_from_posterior(
                x,
                zero_variance=args.zero_variance,
                step_limit=args.step_limit)
            for r_t, axis in zip(r_t_array, axis_rec_array[:-1]):
                r_t = to_cpu(r_t)
                axis.imshow(make_uint8(r_t[0]))

            mu_x, ln_var_x = x_param
            mu_x = to_cpu(mu_x.data)
            axis_rec_array[-1].imshow(make_uint8(mu_x[0]))

            r_t_array, x_param = model.sample_image_at_each_step_from_prior(
                batch_size=1, xp=xp)
            for r_t, axis in zip(r_t_array, axis_gen_array[:-1]):
                r_t = to_cpu(r_t)
                axis.imshow(make_uint8(r_t[0]))

            mu_x, ln_var_x = x_param
            mu_x = to_cpu(mu_x.data)
            axis_gen_array[-1].imshow(make_uint8(mu_x[0]))

            plt.pause(0.01)
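make_uint8, used for every imshow call in these examples, converts a float CHW array into a displayable uint8 HWC image. Its definition is not shown; a plausible minimal version for this example's [0, 1]-scaled images (the variants that also take a dataset mean and std would first undo that normalization):

import numpy as np

def make_uint8_chw01(image):
    # Hypothetical helper: expects (channels, height, width) floats in [0, 1].
    image = np.clip(image, 0.0, 1.0)
    return (image * 255.0).astype(np.uint8).transpose((1, 2, 0))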
Example #29
def main():
    try:
        os.makedirs(args.figure_directory)
    except FileExistsError:
        pass

    #==============================================================================
    # Utilities
    #==============================================================================
    def read_files(directory):
        filenames = []
        files = os.listdir(directory)
        # ipdb.set_trace()
        for filename in files:
            if filename.endswith(".h5"):
                filenames.append(filename)
        filenames.sort()

        dataset_images = []
        dataset_viewpoints = []
        for i in range(len(filenames)):
            F = h5py.File(os.path.join(directory, filenames[i]), "r")
            tmp_images = list(F["images"])
            tmp_viewpoints = list(F["viewpoints"])

            dataset_images.extend(tmp_images)
            dataset_viewpoints.extend(tmp_viewpoints)

        # for i in range(len(filenames)):
        #     images_npy_path = os.path.join(directory, "images", filenames[i])
        #     viewpoints_npy_path = os.path.join(directory, "viewpoints", filenames[i])
        #     tmp_images = np.load(images_npy_path)
        #     tmp_viewpoints = np.load(viewpoints_npy_path)

        #     assert tmp_images.shape[0] == tmp_viewpoints.shape[0]

        #     dataset_images.extend(tmp_images)
        #     dataset_viewpoints.extend(tmp_viewpoints)
        dataset_images = np.array(dataset_images)
        dataset_viewpoints = np.array(dataset_viewpoints)

        dataset = list()
        for i in range(len(dataset_images)):
            item = {
                'image': dataset_images[i],
                'viewpoint': dataset_viewpoints[i]
            }
            dataset.append(item)

        return dataset

    def to_device(array):
        # if using_gpu:
        array = cuda.to_gpu(array)
        return array

    def fill_observations_axis(observation_images):
        axis_observations_image = np.full(
            (3, image_shape[1], total_observations_per_scene * image_shape[2]),
            black_color,
            dtype=np.float32)
        num_current_obs = len(observation_images)
        total_obs = total_observations_per_scene
        width = image_shape[2]
        x_start = width * (total_obs - num_current_obs) // 2
        for obs_image in observation_images:
            x_end = x_start + width
            axis_observations_image[:, :, x_start:x_end] = obs_image
            x_start += width
        return axis_observations_image

    def compute_camera_angle_at_frame(t):
        return t * 2 * math.pi / (fps * 2)

    def rotate_query_viewpoint(horizontal_angle_rad, camera_distance,
                               camera_position_y):
        camera_position = np.array([
            camera_distance * math.sin(horizontal_angle_rad),  # x
            camera_position_y,
            camera_distance * math.cos(horizontal_angle_rad),  # z
        ])
        center = np.array((0, camera_position_y, 0))
        camera_direction = camera_position - center
        yaw, pitch = compute_yaw_and_pitch(camera_direction)

        query_viewpoints = xp.array(
            (
                camera_position[0],
                camera_position[1],
                camera_position[2],
                math.cos(yaw),
                math.sin(yaw),
                math.cos(pitch),
                math.sin(pitch),
            ),
            dtype=np.float32,
        )
        query_viewpoints = xp.broadcast_to(query_viewpoints,
                                           (1, ) + query_viewpoints.shape)

        return query_viewpoints

    def render(representation,
               camera_distance,
               camera_position_y,
               total_frames,
               animation_frame_array,
               rotate_camera=True):

        # viewpoint_file = open('viewpoints.txt','w')
        for t in range(0, total_frames):
            artist_array = [
                axis_observations.imshow(cv2.cvtColor(
                    make_uint8(axis_observations_image), cv2.COLOR_BGR2RGB),
                                         interpolation="none",
                                         animated=True)
            ]

            horizontal_angle_rad = compute_camera_angle_at_frame(t)
            if not rotate_camera:
                horizontal_angle_rad = compute_camera_angle_at_frame(0)

            query_viewpoints = rotate_query_viewpoint(horizontal_angle_rad,
                                                      camera_distance,
                                                      camera_position_y)

            generated_images = model.generate_image(query_viewpoints,
                                                    representation)[0]
            generated_images = chainer.backends.cuda.to_cpu(generated_images)
            generated_images = make_uint8(generated_images)
            generated_images = cv2.cvtColor(generated_images,
                                            cv2.COLOR_BGR2RGB)

            artist_array.append(
                axis_generation.imshow(generated_images,
                                       interpolation="none",
                                       animated=True))

            animation_frame_array.append(artist_array)

    def render_wVar(representation,
                    camera_distance,
                    camera_position_y,
                    total_frames,
                    animation_frame_array,
                    no_of_samples,
                    rotate_camera=True,
                    wVariance=True):

        # highest_var = 0.0
        # with open("queries.txt",'w') as file_wviews, open("variance.txt",'w') as file_wvar:
        for t in range(0, total_frames):
            artist_array = [
                axis_observations.imshow(cv2.cvtColor(
                    make_uint8(axis_observations_image), cv2.COLOR_BGR2RGB),
                                         interpolation="none",
                                         animated=True)
            ]

            horizontal_angle_rad = compute_camera_angle_at_frame(t)
            if not rotate_camera:
                horizontal_angle_rad = compute_camera_angle_at_frame(0)

            query_viewpoints = rotate_query_viewpoint(horizontal_angle_rad,
                                                      camera_distance,
                                                      camera_position_y)

            # q_x, q_y, q_z, _, _, _, _ = query_viewpoints[0]

            # file_wviews.writelines("".join(str(q_x))+", "+
            #                         "".join(str(q_y))+", "+
            #                         "".join(str(q_z))+"\n")

            generated_images = cp.squeeze(
                cp.array(
                    model.generate_images(query_viewpoints, representation,
                                          no_of_samples)))
            # ipdb.set_trace()
            var_image = cp.var(generated_images, axis=0)
            mean_image = cp.mean(generated_images, axis=0)
            mean_image = make_uint8(
                np.squeeze(chainer.backends.cuda.to_cpu(mean_image)))
            mean_image_rgb = cv2.cvtColor(mean_image, cv2.COLOR_BGR2RGB)

            var_image = chainer.backends.cuda.to_cpu(var_image)

            # grayscale
            r, g, b = var_image
            gray_var_image = 0.2989 * r + 0.5870 * g + 0.1140 * b
            # thresholding Otsu's method
            # thresh = threshold_otsu(gray_var_image)
            # var_binary = gray_var_image > thresh

            ## hill climb algorthm for searching highest variance
            # cur_var = np.mean(gray_var_image)
            # if cur_var>highest_var:
            #     highest_var = cur_var

            #     if wVariance==True:
            #         print('highest variance: '+str(highest_var)+', viewpoint: '+str(query_viewpoints[0]))
            #         highest_var_vp = query_viewpoints[0]
            #         file_wvar.writelines('highest variance: '+str(highest_var)+', viewpoint: '+str(highest_var_vp)+'\n')
            #     else:
            #         pass

            artist_array.append(
                axis_generation_var.imshow(gray_var_image,
                                           cmap=plt.cm.gray,
                                           interpolation="none",
                                           animated=True))

            artist_array.append(
                axis_generation_mean.imshow(mean_image_rgb,
                                            interpolation="none",
                                            animated=True))

            animation_frame_array.append(artist_array)

            # if wVariance==True:
            #     print('final highest variance: '+str(highest_var)+', viewpoint: '+str(highest_var_vp))
            #     file_wvar.writelines('final highest variance: '+str(highest_var)+', viewpoint: '+str(highest_var_vp)+'\n')
            # else:
            #     pass

        # file_wviews.close()
        # file_wvar.close()

    # loading dataset & model
    cuda.get_device(args.gpu_device).use()
    xp = cp

    hyperparams = HyperParameters()
    assert hyperparams.load(args.snapshot_directory)

    model = Model(hyperparams)
    chainer.serializers.load_hdf5(args.snapshot_file, model)
    model.to_gpu()

    total_observations_per_scene = 4
    fps = 30

    black_color = -0.5
    image_shape = (3, ) + hyperparams.image_size
    axis_observations_image = np.zeros(
        (3, image_shape[1], total_observations_per_scene * image_shape[2]),
        dtype=np.float32)

    #==============================================================================
    # Visualization
    #==============================================================================
    plt.style.use("dark_background")
    fig = plt.figure(figsize=(6, 7))
    plt.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.95)
    # fig.suptitle("GQN")
    axis_observations = fig.add_subplot(2, 1, 1)
    axis_observations.axis("off")
    axis_observations.set_title("observations")
    axis_generation = fig.add_subplot(2, 1, 2)
    axis_generation.axis("off")
    axis_generation.set_title("Rendered Predictions")
    axis_generation_var = fig.add_subplot(2, 2, 3)
    axis_generation_var.axis("off")
    axis_generation_var.set_title("Variance Render")
    axis_generation_mean = fig.add_subplot(2, 2, 4)
    axis_generation_mean.axis("off")
    axis_generation_mean.set_title("Mean Render")

    # iterator
    dataset = read_files(args.dataset_directory)
    file_number = 1
    with chainer.no_backprop_mode():

        iterator = chainer.iterators.SerialIterator(dataset, batch_size=1)
        # ipdb.set_trace()
        for i in tqdm(range(len(iterator.dataset))):
            animation_frame_array = []
            images, viewpoints = np.array([
                iterator.dataset[i]["image"]
            ]), np.array([iterator.dataset[i]["viewpoint"]])

            camera_distance = np.mean(
                np.linalg.norm(viewpoints[:, :, :3], axis=2))
            camera_position_y = np.mean(viewpoints[:, :, 1])

            images = images.transpose((0, 1, 4, 2, 3)).astype(np.float32)
            images = preprocess_images(images)

            batch_index = 0

            total_views = images.shape[1]
            random_observation_view_indices = list(range(total_views))
            random.shuffle(random_observation_view_indices)
            random_observation_view_indices = random_observation_view_indices[
                :total_observations_per_scene]
            observed_images = images[batch_index,
                                     random_observation_view_indices]
            observed_viewpoints = viewpoints[batch_index,
                                             random_observation_view_indices]

            observed_images = to_device(observed_images)
            observed_viewpoints = to_device(observed_viewpoints)

            # Scene encoder
            representation = model.compute_observation_representation(
                observed_images[None, :1], observed_viewpoints[None, :1])

            # Update figure
            observation_index = random_observation_view_indices[0]
            observed_image = images[batch_index, observation_index]
            axis_observations_image = fill_observations_axis([observed_image])

            # Neural rendering
            # render(representation, camera_distance, camera_position_y,
            #         fps * 2, animation_frame_array)
            render_wVar(representation, camera_distance, camera_position_y,
                        fps * 2, animation_frame_array, 100)

            for n in range(total_observations_per_scene):
                observation_indices = random_observation_view_indices[:n + 1]
                axis_observations_image = fill_observations_axis(
                    images[batch_index, observation_indices])

                # Scene encoder
                representation = model.compute_observation_representation(
                    observed_images[None, :n + 1],
                    observed_viewpoints[None, :n + 1])
                # Neural rendering
                # render(representation, camera_distance, camera_position_y,
                #     fps // 2, animation_frame_array,rotate_camera=False)
                render_wVar(representation,
                            camera_distance,
                            camera_position_y,
                            fps // 2,
                            animation_frame_array,
                            100,
                            rotate_camera=False,
                            wVariance=False)

            # Scene encoder with all given observations
            representation = model.compute_observation_representation(
                observed_images[None, :total_observations_per_scene],
                observed_viewpoints[None, :total_observations_per_scene])

            # Neural rendering
            # render(representation, camera_distance, camera_position_y,
            #         fps * 6, animation_frame_array)
            render_wVar(representation, camera_distance, camera_position_y,
                        fps * 6, animation_frame_array, 100)

            anim = animation.ArtistAnimation(
                fig,
                animation_frame_array,
                interval=1 / fps,  # originally 1/fps
                blit=True,
                repeat_delay=0)

            anim.save("{}/observations_{}.gif".format(args.figure_directory,
                                                      file_number),
                      writer="imagemagick",
                      fps=10)
            # ipdb.set_trace()
            # anim.save(
            #     "{}/rooms_ring_camera_observations_{}.mp4".format(
            #         args.figure_directory, file_number),
            #     writer='ffmpeg',
            #     fps=10)

            file_number += 1
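render_wVar above condenses the per-channel variance of the sampled renderings into a single grayscale uncertainty map using standard luma weights. The reduction, as a standalone NumPy sketch:

import numpy as np

def grayscale_variance_map(samples):
    # samples: (num_samples, 3, height, width) renderings of one viewpoint.
    var_r, var_g, var_b = np.var(samples, axis=0)
    return 0.2989 * var_r + 0.5870 * var_g + 0.1140 * var_b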
Example #30
def main():
    try:
        os.mkdir(args.snapshot_path)
    except FileExistsError:
        pass

    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    dataset = gqn.data.Dataset(args.dataset_path)

    hyperparams = HyperParameters()
    hyperparams.generator_share_core = args.generator_share_core
    hyperparams.generator_share_prior = args.generator_share_prior
    hyperparams.generator_generation_steps = args.generation_steps
    hyperparams.inference_share_core = args.inference_share_core
    hyperparams.inference_share_posterior = args.inference_share_posterior
    hyperparams.pixel_n = args.pixel_n
    hyperparams.pixel_sigma_i = args.initial_pixel_sigma
    hyperparams.pixel_sigma_f = args.final_pixel_sigma
    hyperparams.save(args.snapshot_path)
    hyperparams.print()

    model = Model(hyperparams, snapshot_directory=args.snapshot_path)
    if using_gpu:
        model.to_gpu()

    optimizer = Optimizer(model.parameters,
                          mu_i=args.initial_lr,
                          mu_f=args.final_lr)
    optimizer.print()

    if args.with_visualization:
        figure = gqn.imgplot.figure()
        axis1 = gqn.imgplot.image()
        axis2 = gqn.imgplot.image()
        axis3 = gqn.imgplot.image()
        figure.add(axis1, 0, 0, 1 / 3, 1)
        figure.add(axis2, 1 / 3, 0, 1 / 3, 1)
        figure.add(axis3, 2 / 3, 0, 1 / 3, 1)
        plot = gqn.imgplot.window(
            figure, (500 * 3, 500),
            "Query image / Reconstructed image / Generated image")
        plot.show()

    sigma_t = hyperparams.pixel_sigma_i
    pixel_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                        sigma_t**2,
                        dtype="float32")
    pixel_ln_var = xp.full((args.batch_size, 3) + hyperparams.image_size,
                           math.log(sigma_t**2),
                           dtype="float32")

    dataset_mean, dataset_std = dataset.load_mean_and_std()

    np.save(os.path.join(args.snapshot_path, "mean.npy"), dataset_mean)
    np.save(os.path.join(args.snapshot_path, "std.npy"), dataset_std)

    # avoid division by zero
    dataset_std += 1e-12

    current_training_step = 0
    for iteration in range(args.training_iterations):
        mean_kld = 0
        mean_nll = 0
        total_batch = 0

        for subset_index, subset in enumerate(dataset):
            iterator = gqn.data.Iterator(subset, batch_size=args.batch_size)

            for batch_index, data_indices in enumerate(iterator):
                # shape: (batch, views, height, width, channels)
                # range: [-1, 1]
                images, viewpoints = subset[data_indices]

                # preprocessing
                images = (images - dataset_mean) / dataset_std

                # (batch, views, height, width, channels) ->  (batch, views, channels, height, width)
                images = images.transpose((0, 1, 4, 2, 3))

                total_views = images.shape[1]

                # sample number of views
                num_views = random.choice(range(total_views))
                query_index = random.choice(range(total_views))

                if num_views > 0:
                    r = model.compute_observation_representation(
                        images[:, :num_views], viewpoints[:, :num_views])
                else:
                    r = xp.zeros((args.batch_size, hyperparams.channels_r) +
                                 hyperparams.chrz_size,
                                 dtype="float32")
                    r = chainer.Variable(r)

                query_images = images[:, query_index]
                query_viewpoints = viewpoints[:, query_index]

                # transfer to gpu
                query_images = to_gpu(query_images)
                query_viewpoints = to_gpu(query_viewpoints)

                h0_gen, c0_gen, u_0, h0_enc, c0_enc = model.generate_initial_state(
                    args.batch_size, xp)

                loss_kld = 0

                hl_enc = h0_enc
                cl_enc = c0_enc
                hl_gen = h0_gen
                cl_gen = c0_gen
                ul_enc = u_0

                xq = model.inference_downsampler.downsample(query_images)

                for l in range(model.generation_steps):
                    inference_core = model.get_inference_core(l)
                    inference_posterior = model.get_inference_posterior(l)
                    generation_core = model.get_generation_core(l)
                    generation_prior = model.get_generation_prior(l)

                    h_next_enc, c_next_enc = inference_core.forward_onestep(
                        hl_gen, hl_enc, cl_enc, xq, query_viewpoints, r)

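                    # reparameterized sample from the posterior:
                    # z_l = mean_z_q + exp(0.5 * ln_var_z_q) * eps,
                    # eps ~ N(0, I)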
                    mean_z_q = inference_posterior.compute_mean_z(hl_enc)
                    ln_var_z_q = inference_posterior.compute_ln_var_z(hl_enc)
                    ze_l = cf.gaussian(mean_z_q, ln_var_z_q)

                    mean_z_p = generation_prior.compute_mean_z(hl_gen)
                    ln_var_z_p = generation_prior.compute_ln_var_z(hl_gen)

                    h_next_gen, c_next_gen, u_next_gen = generation_core.forward_onestep(
                        hl_gen, cl_gen, ul_gen, ze_l, query_viewpoints, r)

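                    # closed-form KL between diagonal Gaussians:
                    # KL(q || p) = 0.5 * sum(ln_var_p - ln_var_q
                    #     + (var_q + (mean_q - mean_p)^2) / var_p - 1)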
                    kld = gqn.nn.chainer.functions.gaussian_kl_divergence(
                        mean_z_q, ln_var_z_q, mean_z_p, ln_var_z_p)

                    loss_kld += cf.sum(kld)

                    hl_gen = h_next_gen
                    cl_gen = c_next_gen
                    ul_gen = u_next_gen
                    hl_enc = h_next_enc
                    cl_enc = c_next_enc

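                # per-pixel Gaussian negative log-likelihood with the current
                # annealed variance:
                # -ln p(x|mu) = 0.5 * (ln(2*pi*var) + (x - mu)^2 / var)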
                mean_x = model.generation_observation.compute_mean_x(ul_gen)
                negative_log_likelihood = gqn.nn.chainer.functions.gaussian_negative_log_likelihood(
                    query_images, mean_x, pixel_var, pixel_ln_var)
                loss_nll = cf.sum(negative_log_likelihood)

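                # negative ELBO per scene: reconstruction term plus the
                # accumulated KL, averaged over the batch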
                loss_nll /= args.batch_size
                loss_kld /= args.batch_size
                loss = loss_nll + loss_kld

                model.cleargrads()
                loss.backward()
                optimizer.update(current_training_step)

                if args.with_visualization and not plot.closed():
                    axis1.update(
                        make_uint8(query_images[0], dataset_mean, dataset_std))
                    axis2.update(
                        make_uint8(mean_x.data[0], dataset_mean, dataset_std))

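                    # third panel: sample an image from the prior alone,
                    # conditioned on the same viewpoint and representation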
                    with chainer.no_backprop_mode():
                        generated_x = model.generate_image(
                            query_viewpoints[None, 0], r[None, 0], xp)
                        axis3.update(
                            make_uint8(generated_x[0], dataset_mean,
                                       dataset_std))

                printr(
                    "Iteration {}: Subset {} / {}: Batch {} / {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f}"
                    .format(iteration + 1,
                            subset_index + 1, len(dataset), batch_index + 1,
                            len(iterator), float(loss_nll.data),
                            float(loss_kld.data), optimizer.learning_rate,
                            sigma_t))

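                # linear variance annealing, clamped at sigma_f:
                # sigma_t = max(sigma_f + (sigma_i - sigma_f)
                #               * (1 - step / pixel_n), sigma_f)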
                sf = hyperparams.pixel_sigma_f
                si = hyperparams.pixel_sigma_i
                sigma_t = max(
                    sf + (si - sf) *
                    (1.0 - current_training_step / hyperparams.pixel_n), sf)

                pixel_var[...] = sigma_t**2
                pixel_ln_var[...] = math.log(sigma_t**2)

                total_batch += 1
                current_training_step += 1
                mean_kld += float(loss_kld.data)
                mean_nll += float(loss_nll.data)

            model.serialize(args.snapshot_path)

        print(
            "\033[2KIteration {} - loss: nll: {:.3f} kld: {:.3f} - lr: {:.4e} - sigma_t: {:.6f} - step: {}"
            .format(iteration + 1, mean_nll / total_batch,
                    mean_kld / total_batch, optimizer.learning_rate, sigma_t,
                    current_training_step))