def experiment(variant):
    rdim = variant["rdim"]
    vae_paths = {
        2:
        "/home/ashvin/data/s3doodad/ashvin/vae/point2d-conv-sweep2/run0/id1/params.pkl",
        4:
        "/home/ashvin/data/s3doodad/ashvin/vae/point2d-conv-sweep2/run0/id4/params.pkl"
    }
    vae_path = vae_paths[rdim]
    vae = joblib.load(vae_path)
    print("loaded", vae_path)

    if variant['multitask']:
        env = MultitaskImagePoint2DEnv(**variant['env_kwargs'])
        env = VAEWrappedEnv(env,
                            vae,
                            use_vae_obs=True,
                            use_vae_reward=False,
                            use_vae_goals=False)
        env = MultitaskToFlatEnv(env)
    # else:
    #     env = Pusher2DEnv(**variant['env_kwargs'])
    # NOTE: the non-multitask path is disabled, so env is only defined when
    # variant['multitask'] is True.
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(env,
                    training_env=env,
                    qf1=qf1,
                    qf2=qf2,
                    policy=policy,
                    exploration_policy=exploration_policy,
                    **variant['algo_kwargs'])
    print("use_gpu", variant["use_gpu"], bool(variant["use_gpu"]))
    if variant["use_gpu"]:
        gpu_id = variant["gpu_id"]
        ptu.set_gpu_mode(True)
        ptu.set_device(gpu_id)
        algorithm.to(ptu.device)
        env._wrapped_env.vae.to(ptu.device)
    algorithm.train()
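
# Usage sketch for experiment() above. These snippets assume the usual
# module-level imports (railrl networks, exploration strategies, joblib,
# railrl.torch.pytorch_util as ptu, os.path as osp, etc.). All values below
# are illustrative assumptions, not settings from the original repo; only the
# keys are the ones the function actually reads.
def make_example_experiment_variant():
    return dict(
        rdim=4,                 # must be a key of vae_paths above (2 or 4)
        multitask=True,         # the non-multitask branch is commented out
        normalize=False,
        exploration_type='ou',  # one of 'ou', 'gaussian', 'epsilon'
        env_kwargs=dict(),
        algo_kwargs=dict(),
        use_gpu=False,
        gpu_id=0,
    )
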
Example #2
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reproj_vae_path = variant.get("reproj_vae_path", None)
    ckpt = variant.get("ckpt", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)

    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only', False)
    presampled_goals_path = variant.get('presampled_goals_path', None)

    if not do_state_exp and type(ckpt) is str:
        vae = load_local_or_remote_file(osp.join(ckpt, 'vae.pkl'))
        if vae is not None:
            from railrl.core import logger
            logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    else:
        vae = None

    if vae is None and type(vae_path) is str:
        vae = load_local_or_remote_file(osp.join(vae_path, 'vae_params.pkl'))
        from railrl.core import logger

        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    elif vae is None:
        vae = vae_path

    if type(vae) is str:
        vae = load_local_or_remote_file(vae)

    if type(reproj_vae_path) is str:
        reproj_vae = load_local_or_remote_file(osp.join(reproj_vae_path, 'vae_params.pkl'))
    else:
        reproj_vae = None

    if 'env_id' in variant:
        import gym
        # trigger registration
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        vae_env = VAEWrappedEnv(
            image_env,
            vae,
            imsize=image_env.imsize,
            decode_goals=render,
            render_goals=render,
            render_rollouts=render,
            reward_params=reward_params,
            reproj_vae=reproj_vae,
            **variant.get('vae_wrapped_env_kwargs', {})
        )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=image_env,
                    **variant['goal_generation_kwargs']
                )
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path
                ).item()
                presampled_goals = {
                    'state_desired_goal': presampled_goals['next_obs_state'],
                    'image_desired_goal': presampled_goals['next_obs'],
                }

            image_env.set_presampled_goals(presampled_goals)
            vae_env.set_presampled_goals(presampled_goals)
            print("Presampling all goals")
        else:
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs']
                )
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

        training_mode = variant.get("training_mode", "train")
        testing_mode = variant.get("testing_mode", "test")
        env.add_mode('eval', testing_mode)
        env.add_mode('train', training_mode)
        env.add_mode('relabeling', training_mode)
        # relabeling_env.disable_render()
        env.add_mode("video_vae", 'video_vae')
        env.add_mode("video_env", 'video_env')
    return env
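
# Illustrative variant for get_envs() above; a sketch only. The keys mirror
# the variant.get(...) calls in the function, and every value shown here is
# an assumption (including the env id) rather than a known-good config.
example_get_envs_variant = dict(
    env_id='Point2DLargeEnv-v1',  # assumed multiworld env id
    render=False,
    vae_path=None,          # or a directory containing vae_params.pkl
    reproj_vae_path=None,
    ckpt=None,              # or a checkpoint directory containing vae.pkl
    reward_params=dict(),
    init_camera=None,
    do_state_exp=True,      # True returns the raw env, skipping VAE wrapping
    imsize=48,
)
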
Example #3
def _disentangled_grill_her_twin_sac_experiment(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        vae_evaluation_goal_sampling_mode,
        vae_exploration_goal_sampling_mode,
        base_env_evaluation_goal_sampling_mode,
        base_env_exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        latent_dim=2,
        vae_wrapped_env_kwargs=None,
        vae_path=None,
        vae_n_vae_training_kwargs=None,
        vectorized=False,
        save_video=True,
        save_video_kwargs=None,
        have_no_disentangled_encoder=False,
        **kwargs):
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env,
                              vae,
                              imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    eval_env = VAEWrappedEnv(eval_env,
                             vae,
                             imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = FlattenMlp(input_size=obs_dim,
                         output_size=latent_dim,
                         **encoder_kwargs)

    def make_qf():
        if have_no_disentangled_encoder:
            return FlattenMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(goal_processor=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
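
# Sketch of a call to the experiment above. Every value is an assumption
# chosen to satisfy the signature, not a tuned configuration from the repo;
# the goal-sampling mode strings in particular depend on the env/VAE wrapper.
def run_disentangled_example():
    _disentangled_grill_her_twin_sac_experiment(
        max_path_length=100,
        encoder_kwargs=dict(hidden_sizes=[64, 64]),
        disentangled_qf_kwargs=dict(),
        qf_kwargs=dict(hidden_sizes=[400, 300]),
        twin_sac_trainer_kwargs=dict(),
        replay_buffer_kwargs=dict(max_size=int(1e6)),
        policy_kwargs=dict(hidden_sizes=[400, 300]),
        vae_evaluation_goal_sampling_mode='reset_of_env',
        vae_exploration_goal_sampling_mode='reset_of_env',
        base_env_evaluation_goal_sampling_mode='reset_of_env',
        base_env_exploration_goal_sampling_mode='reset_of_env',
        algo_kwargs=dict(
            num_epochs=100,
            batch_size=128,
            num_eval_steps_per_epoch=1000,
            num_expl_steps_per_train_loop=1000,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=1000,
        ),
        env_id='Point2DLargeEnv-v1',  # assumed multiworld env id
        latent_dim=2,
        vae_n_vae_training_kwargs=dict(),
        vae_wrapped_env_kwargs=dict(),
        save_video=False,
    )
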
def get_envs(variant):
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv, ConditionalVAEWrappedEnv
    from railrl.misc.asset_loader import load_local_or_remote_file
    from railrl.torch.vae.conditional_conv_vae import CVAE, CDVAE, ACE, CADVAE, DeltaCVAE

    render = variant.get('render', False)
    vae_path = variant.get("vae_path", None)
    reward_params = variant.get("reward_params", dict())
    init_camera = variant.get("init_camera", None)
    do_state_exp = variant.get("do_state_exp", False)
    presample_goals = variant.get('presample_goals', False)
    presample_image_goals_only = variant.get('presample_image_goals_only',
                                             False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))
    vae = load_local_or_remote_file(
        vae_path) if type(vae_path) is str else vae_path
    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    if not do_state_exp:
        if isinstance(env, ImageEnv):
            image_env = env
        else:
            image_env = ImageEnv(
                env,
                variant.get('imsize'),
                init_camera=init_camera,
                transpose=True,
                normalize=True,
            )
        if presample_goals:
            """
            This will fail for online-parallel as presampled_goals will not be
            serialized. Also don't use this for online-vae.
            """
            if presampled_goals_path is None:
                image_env.non_presampled_goal_img_is_garbage = True
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    env=vae_env,
                    env_id=variant.get('env_id', None),
                    **variant['goal_generation_kwargs'])
                del vae_env
            else:
                presampled_goals = load_local_or_remote_file(
                    presampled_goals_path).item()
            del image_env
            image_env = ImageEnv(env,
                                 variant.get('imsize'),
                                 init_camera=init_camera,
                                 transpose=True,
                                 normalize=True,
                                 presampled_goals=presampled_goals,
                                 **variant.get('image_env_kwargs', {}))
            vae_env = VAEWrappedEnv(image_env,
                                    vae,
                                    imsize=image_env.imsize,
                                    decode_goals=render,
                                    render_goals=render,
                                    render_rollouts=render,
                                    reward_params=reward_params,
                                    presampled_goals=presampled_goals,
                                    **variant.get('vae_wrapped_env_kwargs',
                                                  {}))
            print("Presampling all goals only")
        else:
            if type(vae) in (CVAE, CDVAE, ACE, CADVAE, DeltaCVAE):
                vae_env = ConditionalVAEWrappedEnv(
                    image_env,
                    vae,
                    imsize=image_env.imsize,
                    decode_goals=render,
                    render_goals=render,
                    render_rollouts=render,
                    reward_params=reward_params,
                    **variant.get('vae_wrapped_env_kwargs', {}))
            else:
                vae_env = VAEWrappedEnv(image_env,
                                        vae,
                                        imsize=image_env.imsize,
                                        decode_goals=render,
                                        render_goals=render,
                                        render_rollouts=render,
                                        reward_params=reward_params,
                                        **variant.get('vae_wrapped_env_kwargs',
                                                      {}))
            if presample_image_goals_only:
                presampled_goals = variant['generate_goal_dataset_fctn'](
                    image_env=vae_env.wrapped_env,
                    **variant['goal_generation_kwargs'])
                image_env.set_presampled_goals(presampled_goals)
                print("Presampling image goals only")
            else:
                print("Not using presampled goals")

        env = vae_env

    return env
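
# Variant sketch exercising the presampled-goals branch of get_envs() above.
# example_goal_dataset_fctn is a hypothetical stand-in, not repo code; a real
# generator would roll out the env and return a dict of goal arrays.
def example_goal_dataset_fctn(env=None, env_id=None, image_env=None, **kwargs):
    raise NotImplementedError("stand-in for a real goal dataset generator")

example_presample_variant = dict(
    env_id='Point2DLargeEnv-v1',   # assumed env id, illustrative
    presample_goals=True,
    presampled_goals_path=None,    # None -> generate goals on the fly
    generate_goal_dataset_fctn=example_goal_dataset_fctn,
    goal_generation_kwargs=dict(),
    imsize=48,
    do_state_exp=False,
)
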
Example #5
def grill_her_sac_experiment(variant):
    env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    init_camera = variant.get("init_camera", None)
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None

    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )

    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )

    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = FlattenMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

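    # The pickle round-trips below create independent deep copies of the
    # wrapped env, so each mode (test / train / relabeling / video) operates
    # on its own instance.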
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")


    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerSac(
        testing_env,
        training_env=training_env,
        qf=qf,
        vf=vf,
        policy=policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        qf.to(ptu.device)
        vf.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
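
# Minimal variant sketch for grill_her_sac_experiment() above. Values are
# illustrative assumptions; fill in env_class with a multiworld env class and
# vae_paths with a real checkpoint path before use.
example_sac_variant = dict(
    env_class=None,  # fill in a multiworld env class
    env_kwargs=dict(),
    render=False,
    rdim=4,
    vae_paths={'4': '/path/to/vae_params.pkl'},  # keyed by str(rdim)
    normalize=False,
    replay_kwargs=dict(max_size=int(1e6)),
    algo_kwargs=dict(),
)
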
Example #6
def grill_her_td3_experiment(variant):
    env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    init_camera = variant.get("init_camera", None)
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None

    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )

    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )

    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

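    # As in the SAC example, the pickle round-trips below create independent
    # deep copies of the wrapped env, one per mode.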
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")


    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(
        testing_env,
        training_env=training_env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
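
# grill_her_td3_experiment() reads the same variant keys as the SAC version
# above, plus an exploration config; a sketch of the extra keys (values are
# assumptions):
example_td3_variant_extras = dict(
    exploration_type='gaussian',  # 'ou', 'gaussian', or 'epsilon'
    exploration_noise=0.2,        # constant sigma / random-action probability
)
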
def experiment(variant):
    rdim = variant["rdim"]
    use_env_goals = variant["use_env_goals"]
    vae_path = variant["vae_paths"][str(rdim)]
    render = variant["render"]
    wrap_mujoco_env = variant.get("wrap_mujoco_env", False)

    # vae = torch.load(vae_path)
    # print("loaded", vae_path)

    from railrl.envs.wrappers import ImageMujocoEnv, NormalizedBoxEnv
    from railrl.images.camera import sawyer_init_camera

    env = variant["env"](**variant['env_kwargs'])
    env = NormalizedBoxEnv(ImageMujocoEnv(
        env,
        imsize=84,
        keep_prev=0,
        init_camera=sawyer_init_camera,
    ))
    if wrap_mujoco_env:
        env = ImageMujocoEnv(env, 84, camera_name="topview", transpose=True, normalize=True)


    if use_env_goals:
        track_qpos_goal = variant.get("track_qpos_goal", 0)
        env = VAEWrappedImageGoalEnv(env, vae_path, use_vae_obs=True,
                                     use_vae_reward=True, use_vae_goals=True,
                                     render_goals=render, render_rollouts=render, track_qpos_goal=track_qpos_goal)
    else:
        env = VAEWrappedEnv(env, vae_path, use_vae_obs=True,
                            use_vae_reward=True, use_vae_goals=True,
                            render_goals=render, render_rollouts=render)

    env = MultitaskToFlatEnv(env)
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(
        env,
        training_env=env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    env._wrapped_env.vae.to(ptu.device)
    algorithm.train()
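
# Variant sketch for experiment() above; every value is an illustrative
# assumption. use_env_goals selects VAEWrappedImageGoalEnv over VAEWrappedEnv.
example_image_goal_variant = dict(
    rdim=16,
    use_env_goals=True,
    vae_paths={'16': '/path/to/params.pkl'},
    render=False,
    wrap_mujoco_env=False,
    track_qpos_goal=0,
    env=None,               # fill in an env class
    env_kwargs=dict(),
    normalize=False,
    exploration_type='ou',
    algo_kwargs=dict(),
)
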
def train_vae(variant):
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv',
                              relative_to_snapshot_dir=True)

    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = variant['generate_vae_dataset_kwargs'].get(
        'dataset_path', None)
    # NOTE: despite its name, test_p is the fraction of the dataset used for
    # *training*; the first n samples become train_data, the rest test_data.
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    dataset = np.load(filename, allow_pickle=True).item()
    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {}
    test_data = {}
    for k in dataset.keys():
        train_data[k] = dataset[k][:n, :]
        test_data[k] = dataset[k][n:, :]

    # setup vae
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif variant.get('vae_type', None) == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if variant.get('vae_type', None) == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data,
                           test_data,
                           m,
                           beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get(
                'init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env,
                            m,
                            imsize=image_env.imsize,
                            decode_goals=render,
                            render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')
    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )

        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(epoch,
                                             n_recon=vis_variant.get(
                                                 'n_recon', 16))
            vae_env.dump_samples(epoch,
                                 n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                    'latent_histogram', 'latent_histogram_mu',
                    'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch,
                                              noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video
            from railrl.core import logger

            vae_env.compute_goal_encodings()

            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            variant['dump_video_kwargs']['epoch'] = epoch
            temporary_mode(vae_env,
                           mode='video_env',
                           func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env,
                               mode='video_vae',
                               func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=variant['dump_video_kwargs'])

    # train vae
    for epoch in range(variant['num_epochs']):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == variant['num_epochs'] - 1)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == variant['num_epochs'] - 1)

        t.train_epoch(epoch)
        # Earlier schedule, kept for reference:
        # if epoch % 500 == 0 or epoch == variant['num_epochs'] - 1:
        #     t.test_epoch(
        #         epoch,
        #         save_reconstruction=save_vis,
        #         save_interpolation=save_vis,
        #         save_vae=save_vae,
        #     )
        # if epoch % 200 == 0 or epoch == variant['num_epochs'] - 1:
        #     visualization_post_processing(save_video, save_video, epoch)

        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        if epoch % 300 == 0 or epoch == variant['num_epochs'] - 1:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")

    return m
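
# Sketch of a train_vae() variant; the keys mirror the function's reads and
# every value is an assumption for illustration, not a repo configuration.
example_train_vae_variant = dict(
    num_epochs=1000,
    snapshot_gap=100,
    representation_size=16,
    beta=1.0,
    imsize=48,
    generate_vae_dataset_kwargs=dict(
        env_id='Point2DLargeEnv-v1',   # assumed env id
        dataset_path='/path/to/dataset.npy',
        test_p=0.9,                    # fraction used for training (see note)
        init_camera=None,
        imsize=48,
    ),
    vae_kwargs=dict(),
    algo_kwargs=dict(),
    vis_kwargs=dict(save_period=50, save_video=False),
)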