Code Example #1
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.util.io import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = (load_local_or_remote_file(vae_path)
           if type(vae_path) is str else vae_path)
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {})
                )
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs']
                )
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path
                ).item()
            del image_env
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
                presampled_goals=presampled_goals,
                **variant.get('image_env_kwargs', {})
            )
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                presampled_goals=presampled_goals,
                **variant.get('vae_wrapped_env_kwargs', {})
            )
            print("Presampling all goals only")
        else:
            vae_env = VAEWrappedEnv(
                image_env,
                vae,
                imsize=image_env.imsize,
                decode_goals=render,
                render_goals=render,
                render_rollouts=render,
                reward_params=reward_params,
                **variant.get('vae_wrapped_env_kwargs', {})
            )
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs']
                )
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    return env
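
For context, the following is a minimal, hypothetical sketch of how this factory might be called. The variant keys mirror the variant.get(...) lookups above; the env ID, image size, VAE path, and reward type are placeholder values, not values taken from the original experiments.

variant = dict(
    env_id='SawyerPushNIPSEasy-v0',   # placeholder multiworld env ID
    imsize=48,
    vae_path='path/to/vae.pkl',       # local path or URL to a trained VAE
    render=False,
    reward_params=dict(type='latent_distance'),  # assumed reward type
    vae_wrapped_env_kwargs=dict(),
)

env = get_envs(variant)   # a VAEWrappedEnv, since do_state_exp defaults to False
obs = env.reset()
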
Code Example #2
def get_envs(variant):
    import os.path as osp  # used below; assumed to be a module-level import in the original file
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reproj_vae_path = variant.get("reproj_vae_path", None)
    ckpt = variant.get("ckpt", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)

    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    if not do_state_exp and type(ckpt) is str:
        vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
        if vae is not None:
            from railrl.core import logger
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    else:
        vae = None

    if vae is None and type(vae_path) is str:
        vae = load_local_or_remote_file(osp.join(vae_path, 'vae_params.pkl'))
        from railrl.core import logger

        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    elif vae is None:
        vae = vae_path

    if type(vae) is str:
        vae = load_local_or_remote_file(vae)

    if type(reproj_vae_path) is str:
        reproj_vae = load_local_or_remote_file(
            osp.join(reproj_vae_path, 'vae_params.pkl'))
    else:
        reproj_vae = None

    if 'env_id' in variant:
        import gym
        # trigger registration
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                reproj_vae=reproj_vae,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=image_env, **variant['goal_generation_kwargs'])
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
                presampled_goals = {
                    'state_desired_goal': presampled_goals['next_obs_state'],
                    'image_desired_goal': presampled_goals['next_obs'],
                }

            image_env.set_presampled_goals(presampled_goals)
            vae_env.set_presampled_goals(presampled_goals)
            print("Presampling all goals")
        else:
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    if not do_state_exp:
        training_mode = variant.get("training_mode", "train")
        testing_mode = variant.get("testing_mode", "test")
        env.add_mode('eval', testing_mode)
        env.add_mode('train', training_mode)
        env.add_mode('relabeling', training_mode)
        # relabeling_env.disable_render()
        env.add_mode("video_vae", 'video_vae')
        env.add_mode("video_env", 'video_env')
    return env
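
The railrl variant above resolves the VAE in stages (a checkpoint directory containing vae.pkl, then a vae_path directory containing vae_params.pkl, then a raw VAE object or path) before wrapping the env and registering named modes. Below is a hedged sketch of a variant that exercises the checkpoint path; the directory layout and env ID are illustrative assumptions.

variant = dict(
    env_id='SawyerDoorHookResetFreeEnv-v1',  # placeholder registered env ID
    imsize=48,
    ckpt='path/to/experiment/checkpoint',    # directory assumed to contain vae.pkl
    vae_path=None,
    reward_params=dict(type='latent_distance'),  # assumed reward type
    training_mode='train',
    testing_mode='test',
)

env = get_envs(variant)
# The returned env now has the 'eval', 'train', 'relabeling', 'video_vae', and
# 'video_env' modes registered via add_mode above.
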
Code Example #3
def get_video_save_func(
    rollout_function,
    env,
    variant,
):
    import os.path as osp  # used below; assumed to be a module-level import in the original file
    from multiworld.core.image_env import ImageEnv
    from railrl.core import logger
    from railrl.envs.vae_wrappers import temporary_mode
    from railrl.misc.video_gen import dump_video
    logdir = logger.get_snapshot_dir()

    vis_variant = variant.get('vis_kwargs', {})
    save_period = vis_variant.get('save_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    dump_video_kwargs = variant.get("dump_video_kwargs", dict())

    vis_blacklist = vis_variant.get('vis_blacklist', [])
    dump_video_kwargs['vis_blacklist'] = vis_blacklist

    if do_state_exp:
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        if 'pick' in env.__module__:
            from multiworld.envs.mujoco.sawyer_xyz.sawyer_pick_and_place import get_image_presampled_goals
            num_goals_presampled = vis_variant.get('num_goals_presampled', 100)
            image_goals = get_image_presampled_goals(image_env,
                                                     num_goals_presampled)
            image_env.set_presampled_goals(image_goals)

        def save_video(algo, epoch):
            dump_video_kwargs["epoch"] = epoch
            if hasattr(algo, "qf1"):
                dump_video_kwargs['qf'] = algo.qf1
            if hasattr(algo, "vf"):
                dump_video_kwargs['vf'] = algo.vf

            if epoch % save_period == 0 or epoch == algo.num_epochs - 1:
                filename = osp.join(logdir,
                                    'video_{epoch}.mp4'.format(epoch=epoch))
                dump_video(image_env, algo.eval_policy, filename,
                           rollout_function, **dump_video_kwargs)

                if vis_variant.get('save_video_exp_policy', False):
                    filename = osp.join(
                        logdir, 'video_{epoch}_exp.mp4'.format(epoch=epoch))
                    dump_video(image_env, algo.exploration_policy, filename,
                               rollout_function, **dump_video_kwargs)
    else:
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize

        def save_video(algo, epoch):
            dump_video_kwargs["epoch"] = epoch
            if hasattr(algo, "qf1"):
                dump_video_kwargs['qf'] = algo.qf1
            if hasattr(algo, "vf"):
                dump_video_kwargs['vf'] = algo.vf

            if epoch % save_period == 0 or epoch == algo.num_epochs - 1:
                filename = osp.join(logdir,
                                    'video_{epoch}.mp4'.format(epoch=epoch))
                temporary_mode(image_env,
                               mode='video_env',
                               func=dump_video,
                               args=(image_env, algo.eval_policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

                if vis_variant.get('save_video_exp_policy', False):
                    filename = osp.join(
                        logdir, 'video_{epoch}_exp.mp4'.format(epoch=epoch))
                    temporary_mode(image_env,
                                   mode='video_env',
                                   func=dump_video,
                                   args=(image_env, algo.exploration_policy,
                                         filename, rollout_function),
                                   kwargs=dump_video_kwargs)

                if not vis_variant.get('save_video_env_only', True):
                    filename = osp.join(
                        logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                    temporary_mode(image_env,
                                   mode='video_vae',
                                   func=dump_video,
                                   args=(image_env, algo.eval_policy, filename,
                                         rollout_function),
                                   kwargs=dump_video_kwargs)

    return save_video
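
A sketch of how the returned callback is typically wired into a training loop. The algo object, its train_one_epoch method, and rollout_fn are hypothetical stand-ins; the only assumptions about algo are the attributes that save_video reads above (eval_policy, exploration_policy, num_epochs, and optionally qf1/vf).

def train_with_videos(algo, rollout_fn, env, variant):
    # rollout_fn and algo.train_one_epoch() are hypothetical; save_video only
    # writes files every `save_period` epochs and on the final epoch.
    save_video = get_video_save_func(rollout_fn, env, variant)
    for epoch in range(algo.num_epochs):
        algo.train_one_epoch()
        save_video(algo, epoch)
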
Code Example #4
def get_envs(variant):
    # gym, ImageEnv, VAEWrappedEnv, and load_local_or_remote_file are assumed
    # to be module-level imports in the original file; the paths below follow
    # the railrl-style examples above.
    import gym
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    vae = (load_local_or_remote_file(vae_path)
           if type(vae_path) is str else vae_path)
    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])

    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant.get('imsize'),
            init_camera=init_camera,
            transpose=True,
            normalize=True,
        )
    if presample_goals:
        """
        This will fail for online-parallel as presampled_goals will not be
        serialized. Also don't use this for online-vae.
        """
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=vae_env,
                env_id=variant.get('env_id', None),
                **variant['goal_generation_kwargs'])
            del vae_env
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        image_env = ImageEnv(env,
                             variant.get('imsize'),
                             init_camera=init_camera,
                             transpose=True,
                             normalize=True,
                             presampled_goals=presampled_goals,
                             **variant.get('image_env_kwargs', {}))
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                presampled_goals=presampled_goals,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        print("Presampling all goals only")
    else:
        vae_env = VAEWrappedEnv(image_env,
                                vae,
                                imsize=image_env.imsize,
                                decode_goals=render,
                                render_goals=render,
                                render_rollouts=render,
                                reward_params=reward_params,
                                **variant.get('vae_wrapped_env_kwargs', {}))
        if presample_image_goals_only:
            presampled_goals = variant['generate_goal_dataset_fctn'](
                image_env=vae_env.wrapped_env,
                **variant['goal_generation_kwargs'])
            image_env.set_presampled_goals(presampled_goals)
            print("Presampling image goals only")
        else:
            print("Not using presampled goals")

    env = vae_env

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")
    env.add_mode('eval', testing_mode)
    env.add_mode('train', training_mode)
    env.add_mode('relabeling', training_mode)
    # relabeling_env.disable_render()
    env.add_mode("video_vae", 'video_vae')
    env.add_mode("video_env", 'video_env')
    return env
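
When presample_goals is enabled, the variant must also supply a goal-dataset generator and its kwargs. The sketch below is purely illustrative: the generator only mirrors the call signature used above (env=..., env_id=..., **goal_generation_kwargs) and the expected return value, a dict of goal arrays that ImageEnv accepts as presampled_goals.

def generate_random_goal_dataset(env=None, env_id=None, num_goals=100):
    # Illustrative only: sample a batch of goals from the wrapped env and
    # return them as the presampled-goals dict (MultitaskEnv.sample_goals).
    return env.sample_goals(num_goals)

variant = dict(
    env_id='SawyerPushNIPSEasy-v0',   # placeholder env ID, as in the earlier sketches
    imsize=48,
    vae_path='path/to/vae.pkl',
    presample_goals=True,
    generate_goal_dataset_fctn=generate_random_goal_dataset,
    goal_generation_kwargs=dict(num_goals=100),
)
env = get_envs(variant)
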