def td3_experiment_online_vae_exploring(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
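
    # Build the VAE-wrapped, goal-conditioned env and exploration strategy from
    # the variant; preprocess_rl_variant, get_envs, get_exploration_strategy,
    # and get_video_save_func are expected to be module-level helpers defined
    # alongside this function in the launcher file.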
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

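    # A second, independent TD3 learner ("exploring") with its own Q-functions
    # and policy; it reuses the same exploration strategy as the control learner.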
    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

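    # The relabeling buffer shares the env's VAE so goals can be relabeled in
    # latent space and kept consistent with the VAE as it is trained online.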
    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    # Default to an empty dict so the kwargs unpacking below still works when
    # vae_trainer_kwargs is omitted from the variant.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs', {})
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)

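    # Two TD3 instances built from the same env and algo_kwargs: one for control
    # and one driven by the exploring policy.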
    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(env=env,
                              training_env=env,
                              qf1=exploring_qf1,
                              qf2=exploring_qf2,
                              policy=exploring_policy,
                              exploration_policy=exploring_exploration_policy,
                              **variant['algo_kwargs'])

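    # The joint algorithm interleaves the two TD3 learners with online VAE
    # training, logging each learner under its respective prefix.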
    assert 'vae_training_schedule' not in variant, \
        "Pass vae_training_schedule through joint_algo_kwargs instead"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
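
    # Optionally register a post-training hook that records evaluation rollouts
    # with the policy in eval mode.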
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()


def grill_her_td3_experiment(variant):
    # NOTE: the imports below are assumed to mirror the module-level imports of
    # the surrounding launcher file; exact module paths may differ between
    # rlkit versions.
    import os.path as osp
    import pickle

    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from rlkit.core import logger
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.envs.vae_wrapper import VAEWrappedEnv
    from rlkit.envs.wrappers import NormalizedBoxEnv
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
    from rlkit.exploration_strategies.ou_strategy import OUStrategy
    from rlkit.torch.her.her import HerTd3
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.util.video import dump_video

    env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

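    # Use the provided init_camera callable when given; otherwise fall back to
    # the named "topview" camera.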
    init_camera = variant.get("init_camera", None)
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None

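    # Wrap the raw env so observations are rendered as 84x84 images (transposed
    # and normalized by ImageEnv).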
    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )

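    # Wrap with the VAE loaded from vae_path so observations and goals are
    # encoded into latent space; goal decoding/rendering is only enabled when
    # render is set.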
    env = VAEWrappedEnv(env,
                        vae_path,
                        decode_goals=render,
                        render_goals=render,
                        render_rollouts=render,
                        reward_params=reward_params,
                        **variant.get('vae_wrapped_env_kwargs', {}))

    if variant['normalize']:
        env = NormalizedBoxEnv(env)
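
    # Select the exploration strategy: OU noise, constant-sigma Gaussian noise,
    # or epsilon-greedy random actions.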
    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise ValueError("Invalid exploration type: " + exploration_type)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

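    # Create independent copies of the wrapped env via pickle round-trips (deep
    # copies), each put into its own mode: evaluation, training, offline
    # relabeling (rendering disabled), and the two video modes.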
    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")

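    # HER-style relabeling buffer keyed on the latent observation and goal
    # entries of the observation dict.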
    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(testing_env,
                       training_env=training_env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       render=render,
                       render_during_eval=render,
                       observation_key=observation_key,
                       desired_goal_key=desired_goal_key,
                       **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

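    # After training, dump final rollout videos: one in the goal-image env mode
    # and one in the VAE-decoded mode.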
    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)