コード例 #1
0
    def visualization_post_processing(save_vis, save_video, epoch):
        """Dump VAE diagnostics and, optionally, rollout videos for ``epoch``.

        Args:
            save_vis: When falsy, nothing is dumped at all.
            save_video: When truthy (together with ``save_vis``), rollout
                videos are recorded and the direct latent-histogram dumps
                are skipped.
            epoch: Epoch number passed to the dump calls and used in the
                video file names.

        Side effects: when videos are recorded,
        ``variant['dump_video_kwargs']['epoch']`` is overwritten with
        ``epoch``.
        """
        vis_list = vis_variant.get('vis_list', [])
        if not save_vis:
            return

        # Reconstructions are only dumped for state-keyed VAE inputs.
        if vae_env.vae_input_key_prefix == 'state':
            recon_count = vis_variant.get('n_recon', 16)
            vae_env.dump_reconstructions(epoch, n_recon=recon_count)
        sample_count = vis_variant.get('n_samples', 64)
        vae_env.dump_samples(epoch, n_samples=sample_count)

        if 'latent_representation' in vis_list:
            vae_env.dump_latent_plots(epoch)

        # Any requested histogram flavour needs the histogram computed first.
        histogram_keys = ('latent_histogram', 'latent_histogram_mu',
                          'latent_histogram_2d', 'latent_histogram_mu_2d')
        if any(key in vis_list for key in histogram_keys):
            vae_env.compute_latent_histogram()

        if not save_video:
            # No video this time: write the 1-D histograms out directly.
            for key, noisy in (('latent_histogram', True),
                               ('latent_histogram_mu', False)):
                if key in vis_list:
                    vae_env.dump_latent_histogram(epoch=epoch,
                                                  noisy=noisy,
                                                  use_true_prior=True)
            return

        # Imported lazily, only when a video is actually recorded.
        from railrl.envs.vae_wrappers import temporary_mode
        from railrl.misc.video_gen import dump_video
        from railrl.core import logger

        vae_env.compute_goal_encodings()
        logdir = logger.get_snapshot_dir()
        variant['dump_video_kwargs']['epoch'] = epoch

        def _record(mode, template):
            # Record one video of random_policy rollouts with vae_env in `mode`.
            path = osp.join(logdir, template.format(epoch=epoch))
            temporary_mode(vae_env,
                           mode=mode,
                           func=dump_video,
                           args=(vae_env, random_policy, path,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])

        _record('video_env', 'video_{epoch}.mp4')
        if not vis_variant.get('save_video_env_only', True):
            _record('video_vae', 'video_{epoch}_vae.mp4')
コード例 #2
0
 def save_video(algo, epoch):
     """Record environment-view and VAE-view rollout videos of ``policy``.

     Videos are written to ``logdir`` as ``video_<epoch>_env.mp4`` and
     ``video_<epoch>_vae.mp4``. This happens every ``save_period`` epochs and
     on the final epoch.

     Args:
         algo: Training algorithm; only ``algo.num_epochs`` is read.
         epoch: Current epoch index.
     """
     # NOTE(review): assumes epochs arrive 0-based (range(num_epochs)), so the
     # last epoch is num_epochs - 1. The previous `epoch == algo.num_epochs`
     # never matched, so the final video was usually skipped. The check now
     # matches the other save_video variant, which uses num_epochs - 1.
     if epoch % save_period == 0 or epoch == algo.num_epochs - 1:
         for mode, template in (('video_env', 'video_{epoch}_env.mp4'),
                                ('video_vae', 'video_{epoch}_vae.mp4')):
             filename = osp.join(logdir, template.format(epoch=epoch))
             temporary_mode(image_env,
                            mode=mode,
                            func=dump_video,
                            args=(image_env, policy, filename,
                                  rollout_function),
                            kwargs=dump_video_kwargs)
コード例 #3
0
        def save_video(algo, epoch):
            """Update ``dump_video_kwargs`` for ``epoch`` and record periodic videos.

            On every call, ``dump_video_kwargs`` receives the epoch. When
            ``algo`` has them, it also receives ``algo.qf1`` (as ``'qf'``) and
            ``algo.vf`` (as ``'vf'``).

            Videos are recorded only when ``epoch`` is a multiple of
            ``save_period`` or equals ``algo.num_epochs - 1``:
              * the evaluation policy in ``video_env`` mode (always);
              * the exploration policy, if ``save_video_exp_policy`` is set;
              * the evaluation policy in ``video_vae`` mode, unless
                ``save_video_env_only`` (default True) is set.
            """
            dump_video_kwargs["epoch"] = epoch
            for attr_name, kwarg_name in (("qf1", "qf"), ("vf", "vf")):
                if hasattr(algo, attr_name):
                    dump_video_kwargs[kwarg_name] = getattr(algo, attr_name)

            on_schedule = epoch % save_period == 0
            if not (on_schedule or epoch == algo.num_epochs - 1):
                return

            def _record(rollout_policy, template, mode='video_env'):
                # Save one rollout video of `rollout_policy` under logdir.
                target = osp.join(logdir, template.format(epoch=epoch))
                temporary_mode(
                    image_env,
                    mode=mode,
                    func=dump_video,
                    args=(image_env, rollout_policy, target, rollout_function),
                    kwargs=dump_video_kwargs,
                )

            _record(algo.eval_policy, 'video_{epoch}.mp4')
            if vis_variant.get('save_video_exp_policy', False):
                _record(algo.exploration_policy, 'video_{epoch}_exp.mp4')
            if not vis_variant.get('save_video_env_only', True):
                _record(algo.eval_policy, 'video_{epoch}_vae.mp4',
                        mode='video_vae')