Пример #1
0
def main():
    """Build an HTML report of Q-function vector fields for saved snapshots.

    Command-line arguments:
        folder_path: snapshot folder, resolved relative to the current
            working directory.

    For every ``(path, iter_number)`` pair yielded by ``get_path_and_iters``,
    the ``'qf'`` entry of the joblib snapshot is loaded and, for each of two
    time steps, one vector field per target position is built (x axis spans
    the WaterMaze state bounds, y axis the action bounds ``(-1, 1)``). The
    figures are added to ``<folder_path>/report.html``, which is saved at
    the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder_path", type=str)
    args = parser.parse_args()
    base_dir = Path(os.getcwd()) / args.folder_path

    path_and_iter = get_path_and_iters(base_dir)

    resolution = 20
    state_bounds = (-WaterMaze.BOUNDARY_DIST, WaterMaze.BOUNDARY_DIST)
    action_bounds = (-1, 1)
    num_target_poses = 5
    target_poses = np.linspace(*state_bounds, num_target_poses)
    # Time steps at which the Q-function is evaluated; each one yields a
    # separate list of vector fields passed to ``create_figure``.
    time_steps = (0, 24)

    report = HTMLReport(
        str(base_dir / 'report.html'), images_per_row=num_target_poses
    )

    report.add_header("test_header")
    report.add_text("test")
    for path, iter_number in path_and_iter:
        data = joblib.load(str(path))
        qf = data['qf']
        print("QF loaded from iteration %d" % iter_number)

        list_of_vector_fields = []
        # Named ``time_step`` rather than ``time`` so the stdlib ``time``
        # module name is not shadowed inside this function.
        for time_step in time_steps:
            vector_fields = []
            for target_pos in target_poses:
                qf_vector_field_eval = create_qf_derivative_eval_fnct(
                    qf, target_pos, time_step
                )
                vector_fields.append(vu.make_vector_field(
                    qf_vector_field_eval,
                    x_bounds=state_bounds,
                    y_bounds=action_bounds,
                    resolution=resolution,
                    info=dict(
                        time=time_step,
                        target_pos=target_pos,
                        title="Estimated QF and dQ/da",
                    )
                ))
            list_of_vector_fields.append(vector_fields)

        report.add_text("Iteration = {0}".format(iter_number))
        create_figure(
            report,
            target_poses,
            *list_of_vector_fields,
        )
        report.new_row()

    # Write the accumulated report to disk.
    report.save()
Пример #2
0
def main():
    """Build a retroactive HTML report from saved VAE snapshots.

    Command-line arguments:
        dir: experiment directory containing ``variant.json`` and the
            ``*.pkl`` snapshots.
        --report_name: filename of the report written inside ``dir``
            (default ``report_retroactive.html``).

    Snapshots are processed in the order given by the iteration number that
    ``append_itr`` attaches to each path. For each one the VAE samples and
    the post-skew VAE density are visualized. The per-iteration images are
    then stitched into MP4 and GIF summaries saved in ``dir`` and linked from
    the report.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dir', type=str)
    parser.add_argument('--report_name', type=str,
                        default='report_retroactive.html')

    args = parser.parse_args()
    directory = args.dir
    report_name = args.report_name

    with open(join(directory, 'variant.json')) as variant_file:
        variant = json.load(variant_file)
    skew_config = get_key_recursive(variant, 'skew_config')
    pkl_paths = glob.glob(join(directory, '*.pkl'))
    numbered_paths = append_itr(pkl_paths)
    ordered_numbered_paths = sorted(numbered_paths, key=lambda x: x[1])

    report = HTMLReport(join(directory, report_name), images_per_row=5)

    vae_heatmap_imgs = []
    sample_imgs = []
    for path, itr in ordered_numbered_paths:
        print("Processing iteration {}".format(itr))
        # NOTE: pickle can execute arbitrary code; only load trusted snapshots.
        with open(path, "rb") as snapshot_file:
            snapshot = pickle.load(snapshot_file)
        # Snapshots store the model under either 'vae' or 'p_theta'.
        vae = snapshot['vae'] if 'vae' in snapshot else snapshot['p_theta']
        vae.to('cpu')
        vae_train_data = snapshot['train_data']
        dynamics = snapshot.get('dynamics', project_square_border_np_4x4)
        report.add_header("Iteration {}".format(itr))
        vae.xy_range = ((-4, 4), (-4, 4))
        # Query once, after xy_range has been set.
        plot_ranges = vae.get_plot_ranges()
        vae_heatmap_img = visualize_vae_samples(
            itr,
            vae_train_data,
            vae,
            report,
            xlim=plot_ranges[0],
            ylim=plot_ranges[1],
            dynamics=dynamics,
        )
        sample_img = visualize_vae(
            vae,
            skew_config,
            report,
            title="Post-skew",
        )
        vae_heatmap_imgs.append(vae_heatmap_img)
        sample_imgs.append(sample_img)

    report.add_header("Summary GIFs")
    for filename, imgs in [
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        if not imgs:
            # np.stack raises on an empty list (e.g. no snapshots found).
            print("No images for {}; skipping".format(filename))
            continue
        video = np.stack(imgs)
        vwrite(join(directory, '{}.mp4'.format(filename)), video)
        gif_file_name = '{}.gif'.format(filename)
        gif(join(directory, gif_file_name), video)
        # The report lives in ``directory``, so link the GIF relative to it.
        # A directory-prefixed link breaks whenever ``directory`` is relative.
        report.add_image(gif_file_name, txt=filename, is_url=True)

    report.save()
    print("Report saved to")
    print(report.path)
Пример #3
0
def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          num_bins=5,
          weight_type='sqrt_inv_p',
          **kwargs):
    """Iteratively refit a reweighted histogram density on its own samples.

    Each epoch samples from the current histogram ``p_theta``, pushes the
    samples through ``Dynamics(projection, dynamics_noise)``, adds them to the
    training data, and fits a new histogram with per-sample weights from
    ``p_theta.compute_per_elem_weights``. Statistics are logged through
    ``logger``. Visualizations go to ``<snapshot_dir>/report.html`` plus
    PNG/MP4/GIF files in the snapshot directory.

    Args:
        dataset_generator: callable ``n -> samples`` producing the initial data.
        n_start_samples: number of initial samples from ``dataset_generator``.
        projection: projection function handed to ``Dynamics``.
        n_samples_to_add_per_epoch: samples drawn from ``p_theta`` per epoch.
        n_epochs: number of refit iterations.
        save_period: visualize at epoch 0 and every ``save_period`` epochs.
        append_all_data: if True the training data keeps growing; otherwise
            each epoch uses the original data plus only that epoch's samples.
        full_variant: optional config dict dumped into the report.
        dynamics_noise: noise argument handed to ``Dynamics``.
        num_bins: histogram resolution.
        weight_type: weighting scheme handed to ``Histogram``.
        **kwargs: accepted and ignored.
    """
    snapshot_dir = logger.get_snapshot_dir()
    report = HTMLReport(
        snapshot_dir + '/report.html',
        images_per_row=3,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data

    heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    # p_theta: the previous iteration's model; p_new: this iteration's fit.
    p_theta = Histogram(num_bins, weight_type=weight_type)
    for epoch in range(n_epochs):
        entropy = p_theta.entropy()
        tv_to_uniform = p_theta.tv_to_uniform()
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Entropy ', entropy)
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', tv_to_uniform)
        entropies.append(entropy)
        tvs_to_uniform.append(tv_to_uniform)

        samples = p_theta.sample(n_samples_to_add_per_epoch)
        empirical_samples = dynamics(samples)

        base_data = train_data if append_all_data else orig_train_data
        train_data = np.vstack((base_data, empirical_samples))

        if epoch == 0 or (epoch + 1) % save_period == 0:
            report.add_text("Epoch {}".format(epoch))
            heatmap_img = visualize_histogram(epoch, p_theta, report)
            sample_img = visualize_samples(epoch, train_data, p_theta, report,
                                           dynamics)
            heatmap_imgs.append(heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            from PIL import Image
            Image.fromarray(heatmap_img).save(
                snapshot_dir + '/heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(
                snapshot_dir + '/samples{}.png'.format(epoch))
        weights = p_theta.compute_per_elem_weights(train_data)
        p_new = Histogram(num_bins, weight_type=weight_type)
        p_new.fit(
            train_data,
            weights=weights,
        )
        p_theta = p_new
        logger.dump_tabular()
    plot_curves([
        ("Entropy", entropies),
        ("TVs to Uniform", tvs_to_uniform),
    ], report)
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    if not heatmap_imgs:
        # Nothing was visualized (n_epochs <= 0); np.stack would raise.
        return

    heatmap_video = np.stack(heatmap_imgs)
    sample_video = np.stack(sample_imgs)

    vwrite(snapshot_dir + '/heatmaps.mp4', heatmap_video)
    vwrite(snapshot_dir + '/samples.mp4', sample_video)
    try:
        gif(snapshot_dir + '/samples.gif', sample_video)
        gif(snapshot_dir + '/heatmaps.gif', heatmap_video)
        # report.html sits in snapshot_dir, so link the GIFs relative to it.
        # The report then stays valid if the directory is moved or served.
        report.add_image('samples.gif', "Samples GIF", is_url=True)
        report.add_image('heatmaps.gif', "Heatmaps GIF", is_url=True)
        report.save()
    except ImportError as e:
        # Presumably gif() relies on an optional dependency; report and move on.
        print(e)
def train(
        dataset_generator,
        n_start_samples,
        projection=project_samples_square_np,
        n_samples_to_add_per_epoch=1000,
        n_epochs=100,
        z_dim=1,
        hidden_size=32,
        save_period=10,
        append_all_data=True,
        full_variant=None,
        dynamics_noise=0,
        decoder_output_var='learned',
        num_bins=5,
        skew_config=None,
        use_perfect_samples=False,
        use_perfect_density=False,
        vae_reset_period=0,
        vae_kwargs=None,
        use_dataset_generator_first_epoch=True,
        **kwargs
):
    """Train a VAE on training data reweighted by ``prob_to_weight``.

    Each epoch:
      1. Draw new samples. They come from ``dataset_generator`` on epoch 0 when
         ``use_dataset_generator_first_epoch``. Otherwise they come from
         ``p_new`` (last epoch's weighted histogram) when
         ``use_perfect_samples``, or from the VAE.
      2. Project them through ``Dynamics(projection, dynamics_noise)`` and
         add them to the training data.
      3. Fit an unweighted histogram ``p_theta``. Compute training-point
         densities from it (``use_perfect_density``) or from the VAE, and
         convert them to weights via ``prob_to_weight(prob, skew_config)``.
      4. Fit the weighted histogram ``p_new`` and the VAE with those weights.

    Stats are logged through ``logger``, parameters are saved every epoch via
    ``logger.save_itr_params``, and visualizations go to
    ``<snapshot_dir>/report.html`` plus PNG/MP4/GIF files.

    Args:
        dataset_generator: callable ``n -> samples``.
        n_start_samples: number of initial samples from ``dataset_generator``.
        projection: projection function handed to ``Dynamics``.
        n_samples_to_add_per_epoch: new samples drawn per epoch.
        n_epochs: number of training epochs.
        z_dim, hidden_size, decoder_output_var, vae_kwargs: forwarded to
            ``get_vae``. ``vae_kwargs`` is required unless both
            ``use_perfect_density`` and ``use_perfect_samples`` are set.
        save_period: visualize at epoch 0 and every ``save_period`` epochs.
        append_all_data: if True the training data keeps growing; otherwise
            each epoch uses the original data plus only that epoch's samples.
        full_variant: optional config dict dumped into the report.
        dynamics_noise: noise argument handed to ``Dynamics``.
        num_bins: histogram resolution (also the VAE heatmap resolution).
        skew_config: required; forwarded to ``prob_to_weight`` and the
            visualizers.
        use_perfect_samples: sample from ``p_new`` instead of the VAE.
        use_perfect_density: use ``p_theta`` densities instead of the VAE's.
        vae_reset_period: if > 0, re-create the VAE whenever
            ``epoch % vae_reset_period == 0`` (including epoch 0).
        use_dataset_generator_first_epoch: see step 1.
        **kwargs: accepted and ignored.

    Raises:
        AssertionError: if ``skew_config`` is None, or ``vae_kwargs`` is None
            while it is required.
    """

    # Sanitize inputs.
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    snapshot_dir = logger.get_snapshot_dir()
    report = HTMLReport(
        snapshot_dir + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(
                    full_variant,
                    sort=True),
                indent=2,
            )
        )

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    # Pre-loop p_theta only matters if the loop never runs (used at the end).
    p_theta = Histogram(num_bins)
    # p_new persists across epochs so use_perfect_samples can sample from it.
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        elif use_perfect_samples and epoch != 0:
            # Ideally the VAE = p_new, but in practice, it won't be...
            vae_samples = p_new.sample(n_samples_to_add_per_epoch)
        else:
            vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)
        if epoch == 0 or (epoch + 1) % save_period == 0:
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config, report)
            vae_heatmap_img = visualize_vae(
                vae, skew_config, report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch, train_data, vae, report, dynamics,
            )

            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()

            from PIL import Image
            Image.fromarray(hist_heatmap_img).save(
                snapshot_dir + '/hist_heatmap{}.png'.format(epoch)
            )
            # Distinct filename so the VAE heatmap does not overwrite the
            # histogram heatmap saved just above.
            Image.fromarray(vae_heatmap_img).save(
                snapshot_dir + '/vae_heatmap{}.png'.format(epoch)
            )
            Image.fromarray(sample_img).save(
                snapshot_dir + '/samples{}.png'.format(epoch)
            )

        # Train the VAE to look like p_new. If every weight is zero, fall
        # back to uniform weights so the VAE still has data to fit.
        # NOTE(review): p_new was already fit above with the unadjusted
        # weights; confirm that is intended.
        if np.sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        p_theta_entropy = p_theta.entropy()
        p_theta_tv = p_theta.tv_to_uniform()
        entropy_gain = p_new.entropy() - p_theta_entropy

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta_entropy)
        tvs_to_uniform.append(p_theta_tv)
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])

        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta_entropy)
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta_tv)
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(epoch, {
            'vae': vae,
            'train_data': train_data,
            'vae_samples': vae_samples,
            'dynamics': dynamics,
        })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting", entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        if not imgs:
            # Nothing was visualized (n_epochs <= 0); np.stack would raise.
            continue
        video = np.stack(imgs)
        vwrite(
            snapshot_dir + '/{}.mp4'.format(filename),
            video,
        )
        # Link relative to report.html, which lives in the same directory.
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(
            snapshot_dir,
            local_gif_file_path
        )
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)
    report.save()