예제 #1
0
 def eval_rollout_function(self):
     """Build the rollout callable used for evaluation.

     Wraps ``tdm_rollout`` through ``create_rollout_function``, binding this
     object's ``max_tau`` as the initial tau, ``cycle_taus_for_rollout`` as
     both the ``cycle_tau`` and ``decrement_tau`` options, and the stored
     observation / desired-goal keys.
     """
     rollout_options = dict(
         init_tau=self.max_tau,
         cycle_tau=self.cycle_taus_for_rollout,
         decrement_tau=self.cycle_taus_for_rollout,
         observation_key=self.observation_key,
         desired_goal_key=self.desired_goal_key,
     )
     return create_rollout_function(tdm_rollout, **rollout_options)
예제 #2
0
def make_video(args):
    """Load a pickled snapshot and dump a grid video of policy rollouts.

    Args:
        args: Namespace-like object. The attributes read here are ``pause``
            (drop into ipdb before loading), ``file`` (path of the pickled
            snapshot), ``gpu`` (truthy to enable GPU mode) and ``mode``
            (mode applied to the env when it is a ``VAEWrappedEnv``).

    Raises:
        AttributeError: If the snapshot contains no policy under ``'policy'``
            / ``'evaluation/policy'`` or no env under ``'env'`` /
            ``'evaluation/env'``.
    """
    if args.pause:
        import ipdb
        ipdb.set_trace()

    # NOTE(security): unpickling can execute arbitrary code; only load
    # snapshot files that come from a trusted source.
    # The context manager guarantees the file handle is closed even when
    # unpickling fails.
    with open(args.file, "rb") as snapshot_file:
        data = pickle.load(snapshot_file)  # joblib.load(args.file)

    if 'policy' in data:
        policy = data['policy']
    elif 'evaluation/policy' in data:
        policy = data['evaluation/policy']
    else:
        raise AttributeError(
            "snapshot has neither a 'policy' nor an 'evaluation/policy' entry")

    if 'env' in data:
        env = data['env']
    elif 'evaluation/env' in data:
        env = data['evaluation/env']
    else:
        raise AttributeError(
            "snapshot has neither an 'env' nor an 'evaluation/env' entry")

    # Roll out against the underlying env rather than the remote wrapper.
    if isinstance(env, RemoteRolloutEnv):
        env = env._wrapped_env
    print("Policy loaded")

    # Select the device first, then move the policy onto whichever device
    # ``ptu`` reports (previously duplicated in both branches).
    ptu.set_gpu_mode(bool(args.gpu))
    policy.to(ptu.device)

    if isinstance(env, VAEWrappedEnv):
        env.mode(args.mode)

    max_path_length = 100
    observation_key = 'latent_observation'
    desired_goal_key = 'latent_desired_goal'
    rollout_function = rf.create_rollout_function(
        rf.multitask_rollout,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    # Applied unconditionally, so this overrides ``args.mode`` set above.
    env.mode(env._mode_map['video_env'])

    # Short random suffix so repeated runs do not overwrite each other.
    random_id = str(uuid.uuid4()).split('-')[0]
    dump_video(
        env,
        policy,
        'rollouts_{}.mp4'.format(random_id),
        rollout_function,
        rows=3,
        columns=6,
        pad_length=0,
        pad_color=255,
        do_timer=True,
        horizon=max_path_length,
        dirname_to_save_images=None,
        subdirname="rollouts",
        imsize=48,
    )
예제 #3
0
 def __init__(
     self,
     observation_key='observation',
     desired_goal_key='desired_goal',
 ):
     """Store the goal-related keys and build the evaluation rollout.

     Args:
         observation_key: Stored on the instance and forwarded to
             ``rf.create_rollout_function`` as ``observation_key``.
         desired_goal_key: Stored on the instance and forwarded to
             ``rf.create_rollout_function`` as ``desired_goal_key``.
     """
     self.observation_key = observation_key
     self.desired_goal_key = desired_goal_key

     # Read the keys back from ``self`` so the rollout sees exactly what
     # was stored on the instance.
     key_options = {
         'observation_key': self.observation_key,
         'desired_goal_key': self.desired_goal_key,
     }
     self.eval_rollout_function = rf.create_rollout_function(
         rf.multitask_rollout, **key_options)
예제 #4
0
    def __init__(
        self,
        observation_key=None,
        desired_goal_key=None,
        rollout_goal_params=None,
    ):
        """Record the goal configuration and create the rollout functions.

        Args:
            observation_key: Stored on the instance and bound into the
                rollout function as ``observation_key``.
            desired_goal_key: Stored on the instance and bound into the
                rollout function as ``desired_goal_key``.
            rollout_goal_params: Stored as-is on the instance.
        """
        self.observation_key = observation_key
        self.desired_goal_key = desired_goal_key
        self.rollout_goal_params = rollout_goal_params
        self._rollout_goal = None

        goal_keys = {
            'observation_key': self.observation_key,
            'desired_goal_key': self.desired_goal_key,
        }
        self.train_rollout_function = create_rollout_function(
            multitask_rollout, **goal_keys)
        # Evaluation reuses the very same callable as training.
        self.eval_rollout_function = self.train_rollout_function
예제 #5
0
def grill_her_td3_experiment(variant):
    """Build and train a ``HerTd3`` algorithm from an experiment ``variant``.

    The function creates the env and exploration strategy, two Q-networks
    and a policy sized from the env's observation/goal/action spaces, a
    relabeling replay buffer, and then trains the algorithm. Unless
    ``variant['save_video']`` is falsy, a video-saving callback is appended
    to ``algorithm.post_epoch_funcs``.

    Args:
        variant: Experiment configuration mapping. Keys read here include
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``algo_kwargs`` (with nested ``td3_kwargs`` / ``her_kwargs``),
            ``render`` and the optional ``observation_key``,
            ``desired_goal_key`` and ``save_video``. ``algo_kwargs`` is
            mutated in place.
    """
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    obs_spaces = env.observation_space.spaces
    obs_dim = (obs_spaces[observation_key].low.size
               + obs_spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    critic_input_size = obs_dim + action_dim

    qf1 = FlattenMlp(
        input_size=critic_input_size,
        output_size=1,
        **variant['qf_kwargs'])
    qf2 = FlattenMlp(
        input_size=critic_input_size,
        output_size=1,
        **variant['qf_kwargs'])
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Inject the runtime objects into the (shared, mutable) algorithm kwargs.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer

    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    td3_kwargs['render'] = variant["render"]

    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_kwargs)

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_epoch_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                algorithm.eval_policy,
                variant,
            ))

    algorithm.to(ptu.device)
    env.vae.to(ptu.device)

    algorithm.train()
예제 #6
0
 def eval_rollout_function(self):
     """Return a ``multitask_rollout``-based rollout callable for evaluation.

     The observation and desired-goal keys stored on this object are bound
     through ``create_rollout_function``.
     """
     goal_keys = {
         'observation_key': self.observation_key,
         'desired_goal_key': self.desired_goal_key,
     }
     return create_rollout_function(multitask_rollout, **goal_keys)
예제 #7
0
def _e2e_disentangled_experiment(max_path_length,
                                 encoder_kwargs,
                                 disentangled_qf_kwargs,
                                 qf_kwargs,
                                 twin_sac_trainer_kwargs,
                                 replay_buffer_kwargs,
                                 policy_kwargs,
                                 vae_evaluation_goal_sampling_mode,
                                 vae_exploration_goal_sampling_mode,
                                 base_env_evaluation_goal_sampling_mode,
                                 base_env_exploration_goal_sampling_mode,
                                 algo_kwargs,
                                 env_id=None,
                                 env_class=None,
                                 env_kwargs=None,
                                 observation_key='state_observation',
                                 desired_goal_key='state_desired_goal',
                                 achieved_goal_key='state_achieved_goal',
                                 latent_dim=2,
                                 vae_wrapped_env_kwargs=None,
                                 vae_path=None,
                                 vae_n_vae_training_kwargs=None,
                                 vectorized=False,
                                 save_video=True,
                                 save_video_kwargs=None,
                                 have_no_disentangled_encoder=False,
                                 **kwargs):
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env,
                              vae,
                              imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    eval_env = VAEWrappedEnv(eval_env,
                             vae,
                             imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = ConcatMlp(input_size=obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)

    def make_qf():
        if have_no_disentangled_encoder:
            return ConcatMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(encoder=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:

            def v_function(obs):
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
예제 #8
0
def td3_experiment(variant):
    """Build and train a goal-conditioned TD3 + HER algorithm from ``variant``.

    Creates the env (optionally wrapped in ``RewardMaskWrapper`` for
    exploration and evaluation), the Q-networks, policies and their target
    copies, a relabeling replay buffer, a ``HERTrainer`` around a TD3
    trainer, path collectors and a ``TorchBatchRLAlgorithm``; optionally
    registers a video callback; then runs training.

    Args:
        variant: Experiment configuration mapping. Required keys read here:
            ``max_path_length``, ``qf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_kwargs``, ``td3_trainer_kwargs`` and
            ``algo_kwargs``. When ``do_state_exp`` is falsy,
            ``evaluation_goal_sampling_mode`` and
            ``exploration_goal_sampling_mode`` are required as well. When
            ``use_masks`` is truthy, ``expl_mask_distribution_kwargs`` and
            ``eval_mask_distribution_kwargs`` are required. Optional keys:
            ``observation_key``, ``desired_goal_key``, ``achieved_goal_key``,
            ``use_subgoal_policy``, ``subgoal_policy_kwargs``,
            ``mask_wrapper_kwargs``, ``save_video``, ``vis_kwargs``.

    Raises:
        KeyError: If a required ``variant`` entry is missing.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)

    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    # preprocess_rl_variant(variant)
    env = get_envs(variant)
    expl_env = env
    eval_env = env
    es = get_exploration_strategy(variant, env)

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        # From here on ``env`` refers to the evaluation wrapper.
        env = eval_env

    max_path_length = variant['max_path_length']
    # Looked up once; selects between the state-based and VAE-based setup.
    do_state_exp = variant.get("do_state_exp", False)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')
    # achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)

    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])

    if variant.get("use_subgoal_policy", False):
        from rlkit.policies.timed_policy import SubgoalPolicyWrapper

        subgoal_policy_kwargs = variant.get('subgoal_policy_kwargs', {})

        policy = SubgoalPolicyWrapper(wrapped_policy=policy,
                                      env=env,
                                      episode_length=max_path_length,
                                      **subgoal_policy_kwargs)
        target_policy = SubgoalPolicyWrapper(wrapped_policy=target_policy,
                                             env=env,
                                             episode_length=max_path_length,
                                             **subgoal_policy_kwargs)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        # use_masks=variant.get("use_masks", False),
        **variant['replay_buffer_kwargs'])

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_trainer_kwargs'])
    # if variant.get("use_masks", False):
    #     from rlkit.torch.her.her import MaskedHERTrainer
    #     trainer = MaskedHERTrainer(trainer)
    # else:
    trainer = HERTrainer(trainer)
    if do_state_exp:
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
            # full_mask=True,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
        )
    else:
        # The sampling modes come from the variant. Previously a literal
        # one-element list such as ['evaluation_goal_sampling_mode'] was
        # passed, which is not a sampling mode at all.
        eval_path_collector = VAEWrappedEnvPathCollector(
            env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['exploration_goal_sampling_mode'],
        )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    # Currently only referenced by the disabled ``vis_list`` option below.
    vis_variant = variant.get('vis_kwargs', {})
    vis_list = vis_variant.get('vis_list', [])
    if variant.get("save_video", True):
        if do_state_exp:
            rollout_function = rf.create_rollout_function(
                rf.multitask_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
                # use_masks=variant.get("use_masks", False),
                # full_mask=True,
                # vis_list=vis_list,
            )
            video_func = get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        else:
            video_func = VideoSaveFunction(
                env,
                variant,
            )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    # Only the VAE-based setup has a VAE to move onto the device.
    if not do_state_exp:
        env.vae.to(ptu.device)
    algorithm.train()
예제 #9
0
def simulate_policy(args):
    """Load a snapshot, roll out its policy with ``tdm_rollout``, save a video.

    Args:
        args: Namespace-like object. The attributes read here are ``file``
            (snapshot path, loaded with joblib), ``gpu`` (truthy to run on
            the GPU), ``pause`` (drop into ipdb before the rollouts) and
            ``H`` (horizon passed to ``dump_video``); ``args`` is also
            handed to ``get_max_tau``.

    Raises:
        Exception: If the snapshot holds no policy under any known key.
    """
    data = joblib.load(args.file)

    # Take the first policy found, in order of preference.
    for policy_key in ('eval_policy', 'policy', 'exploration_policy'):
        if policy_key in data:
            policy = data[policy_key]
            break
    else:
        raise Exception("No policy found in loaded dict. Keys: {}".format(
            data.keys()))
    max_tau = get_max_tau(args)

    env = data['env']

    env.mode("video_env")
    env.decode_goals = True

    if hasattr(env, 'enable_render'):
        # some environments need to be reconfigured for visualization
        env.enable_render()

    if not args.gpu:
        # make sure everything is on the CPU
        set_gpu_mode(False)
        policy.cpu()
        if hasattr(env, "vae"):
            env.vae.cpu()
    else:
        set_gpu_mode(True)
        policy.to(ptu.device)
        if hasattr(env, "vae"):
            env.vae.to(ptu.device)

    if args.pause:
        import ipdb
        ipdb.set_trace()
    if isinstance(policy, PyTorchModule):
        # Switch the network out of training mode for the rollouts.
        policy.train(False)

    ROWS = 3
    COLUMNS = 6

    # The video is written next to the snapshot and named after it.
    dirname = osp.dirname(args.file)
    snapshot_basename = os.path.basename(args.file)
    input_file_name = os.path.splitext(snapshot_basename)[0]
    filename = osp.join(dirname, "video_{}.mp4".format(input_file_name))

    rollout_function = create_rollout_function(
        tdm_rollout,
        init_tau=max_tau,
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    paths = dump_video(
        env,
        policy,
        filename,
        rollout_function,
        ROWS=ROWS,
        COLUMNS=COLUMNS,
        horizon=args.H,
        dirname_to_save_images=dirname,
        subdirname="rollouts_" + input_file_name,
    )

    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics(paths)
    logger.dump_tabular()
예제 #10
0
def tdm_twin_sac_experiment(variant):
    """Build and train a ``TdmTwinSAC`` algorithm from ``variant``.

    After ``preprocess_rl_variant`` the function creates the env, two TDM
    Q-functions, a value function and a stochastic policy, a relabeling
    replay buffer, and the algorithm. Unless ``variant['save_video']`` is
    falsy a video callback is appended to ``algorithm.post_train_funcs``.

    Args:
        variant: Experiment configuration mapping. Keys read here include
            ``qf_kwargs``, ``vf_kwargs``, ``policy_kwargs``,
            ``replay_buffer_kwargs``, ``algo_kwargs`` (with nested
            ``base_kwargs`` / ``tdm_kwargs``), ``render`` and the optional
            ``observation_key``, ``desired_goal_key``, ``save_video`` and
            ``do_state_exp``. Several nested kwargs dicts are mutated in
            place.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.state_distance.tdm_networks import (
        TdmQf,
        TdmVf,
        StochasticTdmPolicy,
    )
    from rlkit.state_distance.tdm_twin_sac import TdmTwinSAC
    preprocess_rl_variant(variant)
    env = get_envs(variant)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Propagate env-derived settings into the network/algorithm kwargs.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    for net_kwargs_key in ('qf_kwargs', 'vf_kwargs'):
        variant[net_kwargs_key]['vectorized'] = vectorized
    for net_kwargs_key in ('qf_kwargs', 'vf_kwargs'):
        variant[net_kwargs_key]['norm_order'] = norm_order

    qf1 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs'])
    qf2 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs'])
    vf = TdmVf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        **variant['vf_kwargs'])
    policy = StochasticTdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'])

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Inject the runtime objects into the (shared, mutable) algorithm kwargs.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer

    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]

    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = TdmTwinSAC(
        env,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        policy=policy,
        **algo_kwargs)

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                algorithm.eval_policy,
                variant,
            ))

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
예제 #11
0
def tdm_td3_experiment(variant):
    """Build and train a ``TdmTd3`` algorithm from ``variant``.

    After ``preprocess_rl_variant`` the function creates the env and
    exploration strategy, two TDM Q-functions and a policy, a relabeling
    replay buffer, and the algorithm. Unless ``variant['save_video']`` is
    falsy a video callback is appended to ``algorithm.post_train_funcs``.

    Args:
        variant: Experiment configuration mapping. Keys read here include
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``algo_kwargs`` (with nested ``base_kwargs`` / ``tdm_kwargs``),
            ``render`` and the optional ``observation_key``,
            ``desired_goal_key``, ``save_video`` and ``do_state_exp``.
            Several nested kwargs dicts are mutated in place.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.core import logger
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.state_distance.tdm_td3 import TdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Propagate env-derived settings into the network/algorithm kwargs.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['norm_order'] = norm_order

    qf1 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs'])
    qf2 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs'])
    policy = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Inject the runtime objects into the (shared, mutable) algorithm kwargs.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer

    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]

    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key

    algorithm = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_kwargs)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    if variant.get("save_video", True):
        # NOTE(review): ``logdir`` is not used below; the call is kept so
        # behaviour stays unchanged.
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm.max_tau,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            ))
    algorithm.train()
예제 #12
0
def grill_her_sac_experiment(variant):
    """Build, train and (optionally) record a ``HerSac`` agent on a VAE env.

    The base env is wrapped in ``ImageEnv`` and ``VAEWrappedEnv`` (and
    ``NormalizedBoxEnv`` when ``variant['normalize']`` is truthy). Separate
    pickled copies of that env are put in testing, training, relabeling and
    two video modes. After training, two final videos are dumped unless
    ``variant['save_video']`` is falsy.

    Args:
        variant: Experiment configuration mapping. Keys read here include
            ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
            ``vae_paths``, ``normalize``, ``replay_kwargs``,
            ``algo_kwargs`` and the optional ``reward_params``,
            ``init_camera``, ``vae_wrapped_env_kwargs``,
            ``observation_key``, ``desired_goal_key``, ``hidden_sizes``,
            ``training_mode``, ``testing_mode`` and ``save_video``.
            ``algo_kwargs`` is mutated in place.
    """
    env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    # Fall back to the named top-view camera when no init function is given.
    init_camera = variant.get("init_camera", None)
    camera_name = "topview" if init_camera is None else None

    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )

    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {}))

    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    obs_spaces = env.observation_space.spaces
    obs_dim = (obs_spaces[observation_key].low.size
               + obs_spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    qf = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    vf = ConcatMlp(
        input_size=obs_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def _copy_env_in_mode(mode):
        """Return an independent pickle round-trip copy of ``env`` in ``mode``."""
        env_copy = pickle.loads(pickle.dumps(env))
        env_copy.mode(mode)
        return env_copy

    testing_env = _copy_env_in_mode(testing_mode)
    training_env = _copy_env_in_mode(training_mode)

    relabeling_env = _copy_env_in_mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = _copy_env_in_mode("video_vae")
    video_goal_env = _copy_env_in_mode("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerSac(
        testing_env,
        training_env=training_env,
        qf=qf,
        vf=vf,
        policy=policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs'])

    if ptu.gpu_enabled():
        print("using GPU")
        for module in (qf, vf, policy, algorithm):
            module.to(ptu.device)
        for vae_env in (testing_env, training_env, video_vae_env,
                        video_goal_env):
            vae_env.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        # One video per video-mode env, both rolled out with the same policy.
        env_video_path = osp.join(logdir, 'video_final_env.mp4')
        dump_video(video_goal_env, policy, env_video_path, rollout_function)
        vae_video_path = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, vae_video_path, rollout_function)
예제 #13
0
def grill_her_td3_experiment(variant):
    """Train HER + TD3 on a VAE-wrapped image environment, then dump videos.

    The environment is built from ``variant['env_class']``, wrapped in
    ``ImageEnv`` and ``VAEWrappedEnv`` (and ``NormalizedBoxEnv`` when
    ``variant['normalize']`` is truthy).  It is then copied through a pickle
    round-trip into separate testing, training, relabeling and video
    environments, each switched to its own mode.

    Args:
        variant: experiment configuration.  Keys indexed directly here are
            ``env_class``, ``env_kwargs``, ``render``, ``rdim``,
            ``vae_paths``, ``normalize``, ``exploration_type``,
            ``replay_kwargs`` and ``algo_kwargs``; all other keys are read
            with ``.get`` and the defaults visible below.
            ``variant['algo_kwargs']`` is mutated: the replay buffer is
            stored under its ``'replay_buffer'`` key.

    Raises:
        Exception: if ``variant['exploration_type']`` is not ``'ou'``,
            ``'gaussian'`` or ``'epsilon'``.
    """
    base_env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    vae_path = variant["vae_paths"][str(variant["rdim"])]
    reward_params = variant.get("reward_params", dict())

    # Without an explicit camera initialiser, use the named "topview" camera.
    init_camera = variant.get("init_camera", None)
    camera_name = "topview" if init_camera is None else None

    image_env = ImageEnv(
        base_env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )
    env = VAEWrappedEnv(
        image_env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )
    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        # max_sigma == min_sigma keeps the noise scale constant.
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Networks consume the observation concatenated with the desired goal.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size
               + spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def build_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = build_qf()
    qf2 = build_qf()
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    def clone_env(mode):
        # A pickle round-trip yields an independent copy of the wrapped env.
        duplicate = pickle.loads(pickle.dumps(env))
        duplicate.mode(mode)
        return duplicate

    testing_env = clone_env(testing_mode)
    training_env = clone_env(training_mode)
    relabeling_env = clone_env(training_mode)
    relabeling_env.disable_render()
    video_vae_env = clone_env("video_vae")
    video_goal_env = clone_env("video_env")

    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(
        testing_env,
        training_env=training_env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        # The relabeling env is not in this list; its VAE is left where it is.
        for vae_env in (testing_env, training_env,
                        video_vae_env, video_goal_env):
            vae_env.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        # One final video per video environment, goal-env first.
        final_videos = (
            (video_goal_env, 'video_final_env.mp4'),
            (video_vae_env, 'video_final_vae.mp4'),
        )
        for video_env, video_name in final_videos:
            dump_video(video_env, policy, osp.join(logdir, video_name),
                       rollout_function)
예제 #14
0
def her_td3_experiment(variant):
    """Run HER + TD3 on a goal-conditioned environment using state keys.

    Exploration and evaluation environments are created either with
    ``gym.make(variant['env_id'])`` or from ``variant['env_class']``.  Twin
    Q-functions, their targets, a policy and a target policy are built,
    wrapped in a ``TD3`` trainer inside a ``HERTrainer``, and trained with
    ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment configuration.  Keys indexed directly here are
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``trainer_kwargs`` and ``algo_kwargs``, plus either ``env_id``
            or ``env_class`` together with ``env_kwargs``.
            ``eval_env_kwargs`` (defaults to ``env_kwargs``) and
            ``save_video`` (defaults to ``False``) are optional.

    Raises:
        KeyError: if a required key is missing from ``variant`` or the
            environment's observation space lacks ``'state_observation'``
            or ``'state_desired_goal'``.
    """
    import gym

    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import \
        PolicyWrappedWithExplorationStrategy
    from rlkit.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from rlkit.launchers.launcher_util import setup_logger
    from rlkit.samplers.data_collector import GoalConditionedPathCollector
    from rlkit.torch.her.her import HERTrainer
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    import rlkit.samplers.rollout_functions as rf
    from rlkit.torch.grill.launcher import get_state_experiment_video_save_function

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )

    # Size the networks from the same keys that the path collectors and the
    # replay buffer read.  Looking up 'observation' / 'desired_goal' here
    # would give mismatched input sizes whenever those spaces differ from
    # the 'state_*' ones.
    obs_spaces = expl_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    goal_dim = obs_spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    # Q-functions take (observation, goal, action); policies (observation, goal).
    qf1 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + goal_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = get_state_experiment_video_save_function(
            rollout_function,
            eval_env,
            policy,
            variant,
        )
        # Registered as a post-epoch hook on the algorithm.
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
예제 #15
0
def _use_disentangled_encoder_distance(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        encoder_key_prefix='encoder',
        encoder_input_prefix='state',
        latent_dim=2,
        reward_mode=EncoderWrappedEnv.ENCODER_DISTANCE_REWARD,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        save_vf_heatmap=True,
        **kwargs):
    """Run HER + SAC on environments wrapped with ``EncoderWrappedEnv``.

    Raw train/eval environments come from ``gym.make(env_id)`` when
    ``env_id`` is given, otherwise from ``env_class(**env_kwargs)``.  Both
    are wrapped with ``EncoderWrappedEnv`` around a shared encoder, four
    ``DisentangledMlpQf`` networks and a ``TanhGaussianPolicy`` are built,
    and a ``TorchBatchRLAlgorithm`` is trained.

    NOTE(review): the ``ConcatMlp`` encoder built from ``encoder_kwargs`` and
    ``latent_dim`` is replaced by ``Identity()`` right after construction, so
    the encoder actually used has input and output size ``raw_obs_dim``.

    Args:
        max_path_length: passed to both path collectors, the algorithm and
            the video rollout function.
        encoder_kwargs: extra kwargs for the (discarded) ``ConcatMlp`` encoder.
        disentangled_qf_kwargs: extra kwargs for each ``DisentangledMlpQf``.
        qf_kwargs: forwarded as ``qf_kwargs`` to each ``DisentangledMlpQf``.
        sac_trainer_kwargs: extra kwargs for ``SACTrainer``.
        replay_buffer_kwargs: extra kwargs for ``ObsDictRelabelingBuffer``.
        policy_kwargs: extra kwargs for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: assigned to
            ``raw_eval_env.goal_sampling_mode``.
        exploration_goal_sampling_mode: assigned to
            ``raw_train_env.goal_sampling_mode``.
        algo_kwargs: extra kwargs for ``TorchBatchRLAlgorithm``.
        env_id: gym id; takes precedence over ``env_class`` when truthy.
        env_class: environment constructor used when ``env_id`` is falsy.
        env_kwargs: kwargs for ``env_class``; ``None`` means ``{}``.
        encoder_key_prefix: prefix of the observation / goal keys read from
            the wrapped environments.
        encoder_input_prefix: passed positionally to ``EncoderWrappedEnv``.
        latent_dim: output size of the discarded encoder and the number of
            per-dimension heatmap image keys.
        reward_mode: passed to ``EncoderWrappedEnv``; also decides
            ``vectorized``.
        save_video: when truthy, register eval/train video hooks.
        save_video_kwargs: extra kwargs for ``get_save_video_function``;
            ``None`` means ``{}``.
        save_vf_heatmap: when truthy, the rollout function post-processes
            observations with the heatmap helper.
        **kwargs: accepted but not referenced in this function.
    """
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class
    # The same flag is handed to the Q-functions, the replay buffer and the
    # heatmap helper below.
    vectorized = (
        reward_mode == EncoderWrappedEnv.VECTORIZED_ENCODER_DISTANCE_REWARD)

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        raw_train_env = gym.make(env_id)
        raw_eval_env = gym.make(env_id)
    else:
        raw_eval_env = env_class(**env_kwargs)
        raw_train_env = env_class(**env_kwargs)

    raw_train_env.goal_sampling_mode = exploration_goal_sampling_mode
    raw_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    raw_obs_dim = (
        raw_train_env.observation_space.spaces['state_observation'].low.size)
    action_dim = raw_train_env.action_space.low.size

    # NOTE(review): this network is overwritten two statements later and is
    # never used. It is left in place because building it may affect seeded
    # initialisation of the networks created afterwards -- confirm before
    # removing.
    encoder = ConcatMlp(input_size=raw_obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)
    encoder = Identity()
    encoder.input_size = raw_obs_dim
    encoder.output_size = raw_obs_dim

    # Both wrapped environments share the same np_encoder instance.
    np_encoder = EncoderFromNetwork(encoder)
    train_env = EncoderWrappedEnv(
        raw_train_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    eval_env = EncoderWrappedEnv(
        raw_eval_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    observation_key = '{}_observation'.format(encoder_key_prefix)
    desired_goal_key = '{}_desired_goal'.format(encoder_key_prefix)
    achieved_goal_key = '{}_achieved_goal'.format(encoder_key_prefix)
    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size

    def make_qf():
        # Every Q-function (including the targets) receives the same
        # encoder object.
        return DisentangledMlpQf(encoder=encoder,
                                 preprocess_obs_dim=obs_dim,
                                 action_dim=action_dim,
                                 qf_kwargs=qf_kwargs,
                                 vectorized=vectorized,
                                 **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    # Both collectors use goal_sampling_mode='env'; the evaluation /
    # exploration modes from the arguments were set on the raw envs above.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:

        def v_function(obs):
            # Evaluate qf1 at the policy's own actions for the given
            # observations, requesting the individual Q values.
            action = policy.get_actions(obs)
            obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
            return qf1(obs, action, return_individual_q_vals=True)

        add_heatmap = partial(
            add_heatmap_imgs_to_o_dict,
            v_function=v_function,
            vectorized=vectorized,
        )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        # NOTE(review): the per-dimension keys are counted with latent_dim,
        # while the encoder in use is Identity with output size raw_obs_dim
        # -- confirm these agree for the intended configurations.
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        # Registered as post-train hooks, eval first.
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
예제 #16
0
def td3_experiment_online_vae_exploring(variant):
    """Jointly train a control TD3 and an exploring TD3 with an online VAE.

    Two independent sets of networks (twin Q-functions plus policy) are
    created, one per ``TD3`` instance, and both instances are handed to
    ``OnlineVaeHerJointAlgo`` together with a ``ConvVAETrainer`` for the
    environment's VAE.  Both exploration wrappers share the same exploration
    strategy object.

    Args:
        variant: experiment configuration.  Keys indexed directly here are
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``algo_kwargs``, ``joint_algo_kwargs``, ``vae_train_data``,
            ``vae_test_data`` and ``online_vae_beta``.  Optional keys are
            ``observation_key``, ``desired_goal_key``,
            ``use_replay_buffer_goals``, ``vae_trainer_kwargs`` and
            ``save_video``.  ``variant['algo_kwargs']`` is mutated: the
            replay buffer is stored under ``'replay_buffer'``.

    Raises:
        AssertionError: if ``'vae_training_schedule'`` is a top-level key of
            ``variant``.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Network input is the observation concatenated with the desired goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    # Networks for the control algorithm.
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    # A second, separately initialised set for the exploring algorithm.
    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # Both TD3 instances below receive this same buffer via algo_kwargs.
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    # ``**None`` raises TypeError, so a missing or None entry becomes {}.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)

    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(env=env,
                              training_env=env,
                              qf1=exploring_qf1,
                              qf2=exploring_qf2,
                              policy=exploring_policy,
                              exploration_policy=exploring_exploration_policy,
                              **variant['algo_kwargs'])

    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        # Registered as a post-train hook on the joint algorithm.
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
예제 #17
0
def _disentangled_her_twin_sac_experiment_v2(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        latent_dim=2,
        save_video_kwargs=None,
        **kwargs
):
    """Run HER + SAC with ``DisentangledMlpQf`` critics sharing one encoder.

    Train/eval environments come from ``gym.make(env_id)`` when ``env_id``
    is given, otherwise from ``env_class(**env_kwargs)``.  A ``ConcatMlp``
    encoder mapping ``goal_dim`` inputs to ``latent_dim`` outputs is shared
    by both Q-functions; each target Q-function gets its own
    ``Detach(encoder)`` wrapper.

    Args:
        max_path_length: passed to the collectors, the algorithm and the
            video rollout function.
        encoder_kwargs: extra kwargs for the ``ConcatMlp`` encoder.
        disentangled_qf_kwargs: extra kwargs for each ``DisentangledMlpQf``.
        qf_kwargs: forwarded as ``qf_kwargs`` to each ``DisentangledMlpQf``.
        twin_sac_trainer_kwargs: extra kwargs for ``SACTrainer``.
        replay_buffer_kwargs: extra kwargs for ``ObsDictRelabelingBuffer``.
        policy_kwargs: extra kwargs for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: goal sampling mode of the evaluation
            path collector.
        exploration_goal_sampling_mode: goal sampling mode of the
            exploration path collector.
        algo_kwargs: extra kwargs for ``TorchBatchRLAlgorithm``.
        save_video: when truthy, register eval/train video hooks.
        env_id: gym id; takes precedence over ``env_class`` when truthy.
        env_class: environment constructor used when ``env_id`` is falsy.
        env_kwargs: kwargs for ``env_class``; ``None`` means ``{}``.
        observation_key: observation-space key for the observation.
        desired_goal_key: observation-space key for the desired goal.
        achieved_goal_key: key handed to the replay buffer.
        latent_dim: encoder output size and number of per-dimension
            heatmap image keys.
        save_video_kwargs: extra kwargs for ``get_save_video_function``;
            ``None`` means ``{}``.  Its ``'save_vf_heatmap'`` entry
            (default ``True``) toggles the heatmap post-processing.
        **kwargs: accepted but not referenced in this function.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    save_video_kwargs = {} if save_video_kwargs is None else save_video_kwargs
    env_kwargs = {} if env_kwargs is None else env_kwargs
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_spaces = train_env.observation_space.spaces
    obs_dim = train_spaces[observation_key].low.size
    goal_dim = train_spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = ConcatMlp(
        input_size=goal_dim,
        output_size=latent_dim,
        **encoder_kwargs
    )

    def build_qf(qf_encoder):
        return DisentangledMlpQf(
            encoder=qf_encoder,
            preprocess_obs_dim=obs_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            **disentangled_qf_kwargs
        )

    # The two live critics share the encoder object directly; each target
    # critic is given its own Detach wrapper around that encoder.
    qf1 = build_qf(encoder)
    qf2 = build_qf(encoder)
    target_qf1 = build_qf(Detach(encoder))
    target_qf2 = build_qf(Detach(encoder))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    trainer = HERTrainer(SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    ))

    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=evaluation_goal_sampling_mode,
    )
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=exploration_goal_sampling_mode,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        def v_function(obs):
            # Evaluate qf1 at the policy's own actions for the given
            # observations, requesting the individual Q values.
            actions = policy.get_actions(obs)
            obs_tensor = ptu.from_numpy(obs)
            action_tensor = ptu.from_numpy(actions)
            return qf1(obs_tensor, action_tensor,
                       return_individual_q_vals=True)

        add_heatmap = partial(add_heatmap_imgs_to_o_dict, v_function=v_function)
        postprocess = add_heatmap if save_vf_heatmap else None
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=postprocess,
        )

        # One combined heatmap key plus one key per latent dimension.
        img_keys = ['v_vals']
        img_keys += ['v_vals_dim_{}'.format(dim) for dim in range(latent_dim)]

        def build_video_func(video_env, video_policy, tag):
            return get_save_video_function(
                rollout_function,
                video_env,
                video_policy,
                tag=tag,
                get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
                **save_video_kwargs
            )

        eval_video_func = build_video_func(
            eval_env, MakeDeterministic(policy), "eval")
        train_video_func = build_video_func(train_env, policy, "train")

        # This decoder is built and moved to the device, but none of the
        # hooks registered below use it.
        decoder = ConcatMlp(
            input_size=obs_dim,
            output_size=obs_dim,
            hidden_sizes=[128, 128],
        )
        decoder.to(ptu.device)

        for video_func in (eval_video_func, train_video_func):
            algorithm.post_train_funcs.append(video_func)

    algorithm.train()
예제 #18
0
def tdm_td3_experiment_online_vae(variant):
    """Train TDM + TD3 with an online-trained VAE on a VAE-wrapped env.

    Builds twin ``TdmQf`` networks and a ``TdmPolicy``, an
    ``OnlineVaeRelabelingBuffer`` and a ``ConvVAETrainer`` for ``env.vae``,
    then constructs and trains ``OnlineVaeTdmTd3``.

    Args:
        variant: experiment configuration.  Keys indexed directly here are
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``vae_train_data``, ``vae_test_data``, ``online_vae_beta``,
            ``render`` and ``algo_kwargs`` (with sub-dicts
            ``online_vae_kwargs`` and ``tdm_td3_kwargs``, the latter
            holding ``td3_kwargs`` and ``tdm_kwargs``).  Optional keys are
            ``observation_key``, ``desired_goal_key``,
            ``vae_trainer_kwargs``, ``save_video`` and ``do_state_exp``.
            The nested ``tdm_td3_kwargs`` dict is mutated: ``vectorized``,
            ``observation_key`` and ``desired_goal_key`` are written into
            ``tdm_kwargs``, ``training_env`` into ``td3_kwargs`` and
            ``replay_buffer`` into ``tdm_td3_kwargs`` itself.

    Raises:
        AssertionError: if ``'vae_training_schedule'`` is a top-level key of
            ``variant``.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    # ``**None`` raises TypeError, so a missing or None entry becomes {}.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    # Vectorized mode is derived from the env's reward type and written into
    # the TDM kwargs so the algorithm and the Q-functions agree.
    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized

    # norm_order is passed to the Q-functions only, not to tdm_kwargs.
    norm_order = env.norm_order

    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # These are the dicts stored inside ``variant``; the writes below are
    # picked up when tdm_td3_kwargs is unpacked into the algorithm.
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    # Unused below; the lookup is retained so that a variant without a
    # 'render' key still raises KeyError here as it did before.
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(vae=vae,
                               vae_trainer=t,
                               **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.tdm_rollout
        rollout_function = rf.create_rollout_function(
            rollout_function,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        # Registered as a post-train hook on the algorithm.
        algorithm.post_train_funcs.append(video_func)

    # NOTE(review): these device moves repeat the ones made right after the
    # algorithm was constructed; kept unchanged.
    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
예제 #19
0
def her_sac_experiment(
        max_path_length,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        save_video_kwargs=None,
        exploration_policy_kwargs=None,
        **kwargs
):
    """Assemble a HER + SAC training setup and run it.

    Two environments (one for exploration, one for evaluation) are created
    either with ``gym.make(env_id)`` (after registering the multiworld
    environments) or by calling ``env_class(**env_kwargs)``. On top of them
    the function builds four ``ConcatMlp`` Q-networks, a
    ``TanhGaussianPolicy``, an ``ObsDictRelabelingBuffer``, a ``SACTrainer``
    wrapped in ``HERTrainer``, two ``GoalConditionedPathCollector`` objects
    and a ``TorchBatchRLAlgorithm``, moves the algorithm to ``ptu.device``
    and calls ``algorithm.train()``.

    Args:
        max_path_length: Forwarded to both path collectors, the algorithm
            and the video rollout function.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp``.
        twin_sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: ``goal_sampling_mode`` of the
            evaluation path collector.
        exploration_goal_sampling_mode: ``goal_sampling_mode`` of the
            exploration path collector.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        save_video: When truthy, "eval" and "expl" video functions are
            appended to ``algorithm.post_train_funcs``.
        env_id: Id passed to ``gym.make``; takes precedence over
            ``env_class`` when truthy.
        env_class: Callable used to build the environments when ``env_id``
            is falsy.
        env_kwargs: Keyword arguments for ``env_class``; ``None`` means
            no arguments.
        observation_key: Key of the observation entry in the dict
            observation space.
        desired_goal_key: Key of the desired-goal entry.
        achieved_goal_key: Key of the achieved-goal entry (replay buffer).
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``; any falsy value means none.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means none.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If neither ``env_id`` nor ``env_class`` is truthy.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    # Normalise the optional dict arguments.
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    env_kwargs = {} if env_kwargs is None else env_kwargs

    assert env_id or env_class
    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        # Tuple elements are evaluated left to right: train env first.
        train_env, eval_env = gym.make(env_id), gym.make(env_id)
    else:
        # Evaluation env is constructed first in this branch.
        eval_env, train_env = env_class(**env_kwargs), env_class(**env_kwargs)

    # The networks consume the observation concatenated with the goal.
    obs_dim = (
        train_env.observation_space.spaces[observation_key].low.size
        + train_env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = train_env.action_space.low.size

    def build_q_network():
        """Create one scalar-output critic over (obs, goal, action)."""
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )

    # Constructed strictly in this order: qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = [build_q_network() for _ in range(4)]
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    trainer = HERTrainer(SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    ))

    # Keyword arguments common to both path collectors.
    key_kwargs = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        goal_sampling_mode=evaluation_goal_sampling_mode,
        **key_kwargs
    )
    exploration_policy = create_exploration_policy(
        train_env, policy, **exploration_policy_kwargs)
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        exploration_policy,
        max_path_length,
        goal_sampling_mode=exploration_goal_sampling_mode,
        **key_kwargs
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            return_dict_obs=True,
        )
        video_specs = (
            ("eval", eval_env, MakeDeterministic(policy)),
            ("expl", train_env, exploration_policy),
        )
        # Build both video functions before registering either of them.
        video_funcs = [
            get_save_video_function(
                rollout_function,
                video_env,
                video_policy,
                tag=tag,
                **save_video_kwargs
            )
            for tag, video_env, video_policy in video_specs
        ]
        for video_func in video_funcs:
            algorithm.post_train_funcs.append(video_func)

    algorithm.train()
예제 #20
0
def her_td3_experiment(variant):
    """Build and run a HER + TD3 experiment described by ``variant``.

    Keys read from ``variant``:
        * ``'env_id'`` -- if present, the environment is ``gym.make(env_id)``;
          otherwise ``variant['env_class'](**variant['env_kwargs'])`` is used.
        * ``'observation_key'`` / ``'desired_goal_key'`` -- entries of the
          dict observation space used for observations and goals. They are
          also written into ``variant['algo_kwargs']['her_kwargs']`` (the
          caller's dict is mutated).
        * ``'normalize'`` (optional) -- must be falsy.
        * ``'replay_buffer_kwargs'``, ``'es_kwargs'``, ``'qf_kwargs'``,
          ``'policy_kwargs'``, ``'algo_kwargs'`` -- forwarded to the
          respective constructors.
        * ``'exploration_type'`` -- one of ``'ou'``, ``'gaussian'``,
          ``'epsilon'``.
        * ``'save_video'`` (optional, default ``False``) -- when truthy a
          video function is appended to ``algorithm.post_epoch_funcs``.

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``variant['exploration_type']`` is not recognised.
    """
    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    # The achieved-goal key is derived from the desired-goal key by name.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Size the networks from the configured keys (the same ones handed to
    # the replay buffer and the HER kwargs) rather than the hard-coded
    # 'observation' / 'desired_goal' entries, so the input sizes stay
    # correct for e.g. 'state_observation' or 'latent_observation'.
    obs_spaces = env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    # Critics take (observation, action, goal); the policy takes
    # (observation, goal).
    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       replay_buffer=replay_buffer,
                       **variant['algo_kwargs'])
    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
예제 #21
0
def her_td3_experiment(variant):
    """Build and run a HER + TD3 experiment described by ``variant``.

    All heavy dependencies are imported inside the function body.

    Keys read from ``variant``:
        * ``'env_id'`` -- if present, the environment is ``gym.make(env_id)``;
          otherwise ``variant['env_class'](**variant['env_kwargs'])`` is used.
        * ``'observation_key'`` / ``'desired_goal_key'`` -- entries of the
          dict observation space used for observations and goals. They are
          also written into ``variant['algo_kwargs']['her_kwargs']`` (the
          caller's dict is mutated).
        * ``'normalize'`` (optional) -- must be falsy.
        * ``'replay_buffer_kwargs'``, ``'es_kwargs'``, ``'qf_kwargs'``,
          ``'policy_kwargs'``, ``'algo_kwargs'`` -- forwarded to the
          respective constructors.
        * ``'exploration_type'`` -- one of ``'ou'``, ``'gaussian'``,
          ``'epsilon'``.
        * ``'save_video'`` (optional, default ``False``) -- when truthy a
          video function is appended to ``algorithm.post_epoch_funcs``.

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``variant['exploration_type']`` is not recognised.
    """
    import gym
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.exploration_strategies.gaussian_strategy import GaussianStrategy
    from rlkit.exploration_strategies.ou_strategy import OUStrategy
    from rlkit.torch.grill.launcher import get_video_save_func
    from rlkit.torch.her.her_td3 import HerTd3
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    # The achieved-goal key is derived from the desired-goal key by name.
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Size the networks from the configured keys (the same ones handed to
    # the replay buffer and the HER kwargs) rather than the hard-coded
    # 'observation' / 'desired_goal' entries, so the input sizes stay
    # correct for e.g. 'state_observation' or 'latent_observation'.
    obs_spaces = env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    # Critics take (observation, action, goal); the policy takes
    # (observation, goal).
    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       replay_buffer=replay_buffer,
                       **variant['algo_kwargs'])
    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()