Ejemplo n.º 1
0
def image_based_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    cnn_kwargs,
    policy_type='tanh-normal',
    env_id=None,
    env_class=None,
    env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    reward_type='state_distance',
    env_renderer_kwargs=None,
    # Data augmentations
    apply_random_crops=False,
    random_crop_pixel_shift=4,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
):
    """Run goal-conditioned SAC with image observations and image goals.

    Builds separate exploration and evaluation ``ContextualEnv`` instances,
    CNN-based Q-functions and policy, a ``ContextualRelabelingReplayBuffer``
    and a ``SACTrainer``, wires them into a ``TorchBatchRLAlgorithm``,
    optionally registers video-saving hooks, and finally calls
    ``algorithm.train()``.

    Args:
        max_path_length: Passed to the algorithm and to the video rollouts.
        qf_kwargs: Extra keyword arguments for each Q-function ``ConcatMlp``.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for the policy's
            ``MultiHeadedMlp``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        cnn_kwargs: Extra keyword arguments for every ``BasicCNN``.
        policy_type: ``'normal'``, ``'tanh-normal'`` or
            ``'normal-tanh-mean'``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        evaluation_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the evaluation state env.
        exploration_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the exploration state env.
        reward_type: ``'state_distance'`` or ``'pixel_distance'``.
        env_renderer_kwargs: Keyword arguments for the ``EnvRenderer`` that
            produces image observations and image goals.
        apply_random_crops: When true, sampled batches are padded with
            ``BatchPad`` and cropped with ``JointRandomCrop`` before the goal
            image is concatenated to the observation.
        random_crop_pixel_shift: Used for both padding arguments of
            ``BatchPad``.
        save_video: Whether to append video-saving functions to
            ``algorithm.post_train_funcs``.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``.
        video_renderer_kwargs: Keyword arguments for the ``EnvRenderer`` used
            for the ``'video_observation'`` frames.

    Raises:
        ValueError: If ``reward_type``, the renderer's output image format,
            or ``policy_type`` is not one of the supported values.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    env_renderer_kwargs = env_renderer_kwargs or {}
    video_renderer_kwargs = video_renderer_kwargs or {}

    image_obs_key = 'image_observation'
    image_goal_key = 'image_desired_goal'
    state_obs_key = 'state_observation'
    state_goal_key = 'state_desired_goal'
    state_achieved_key = 'state_achieved_goal'
    # Relabeling remaps each goal key to the observation key of the same
    # modality.
    sample_context_from_obs_dict_fn = RemapKeyFn({
        image_goal_key: image_obs_key,
        state_goal_key: state_obs_key,
    })

    def build_contextual_env(
        env_id, env_class, env_kwargs, goal_sampling_mode, renderer
    ):
        """Return ``(contextual_env, goal_distribution, reward_fn)``."""
        base_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs
        )
        base_env.goal_sampling_mode = goal_sampling_mode
        state_goals = GoalDictDistributionFromMultitaskEnv(
            base_env, desired_goal_keys=[state_goal_key]
        )
        state_diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
            base_env.goal_conditioned_diagnostics,
            desired_goal_key=state_goal_key,
            observation_key=state_obs_key,
        )
        image_goals = AddImageDistribution(
            env=base_env,
            base_distribution=state_goals,
            image_goal_key=image_goal_key,
            renderer=renderer,
        )
        # 5000 is presumably the number of presampled goals -- TODO confirm.
        presampled_goals = PresampledDistribution(image_goals, 5000)
        image_env = InsertImageEnv(base_env, renderer=renderer)

        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=base_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    state_obs_key
                ),
                desired_goal_key=state_goal_key,
                achieved_goal_key=state_achieved_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    image_obs_key
                ),
                desired_goal_key=image_goal_key,
            )
        else:
            raise ValueError(reward_type)

        contextual_env = ContextualEnv(
            image_env,
            context_distribution=presampled_goals,
            reward_fn=reward_fn,
            observation_key=image_obs_key,
            contextual_diagnostics_fns=[state_diagnostics],
            update_env_info_fn=delete_info,
        )
        return contextual_env, presampled_goals, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = build_contextual_env(
        env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        env_renderer,
    )
    eval_env, eval_context_distrib, eval_reward = build_contextual_env(
        env_id,
        env_class,
        env_kwargs,
        evaluation_goal_sampling_mode,
        env_renderer,
    )

    action_dim = expl_env.action_space.low.size

    # The axis order of the observation shape depends on the renderer format.
    image_format = env_renderer.output_image_format
    if image_format not in ('WHC', 'CHW'):
        raise ValueError(image_format)
    image_shape = expl_env.observation_space[image_obs_key].shape
    if image_format == 'WHC':
        img_width, img_height, img_num_channels = image_shape
    else:
        img_num_channels, img_height, img_width = image_shape

    def make_joint_cnn():
        """Create a fresh CNN wrapped to process state and goal images."""
        return ApplyConvToStateAndGoalImage(
            BasicCNN(
                input_width=img_width,
                input_height=img_height,
                input_channels=img_num_channels,
                **cnn_kwargs
            )
        )

    def create_qf():
        """Create one Q-function: joint CNN features + action -> scalar."""
        qf_cnn = make_joint_cnn()
        obs_stage = ApplyToObs(qf_cnn)
        flatten_stage = basic.FlattenEachParallel()
        mlp_stage = ConcatMlp(
            input_size=qf_cnn.output_size + action_dim,
            output_size=1,
            **qf_kwargs
        )
        return basic.MultiInputSequential(obs_stage, flatten_stage, mlp_stage)

    # Networks are created in a fixed order: Q-functions, targets, then the
    # policy's CNN.
    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy_cnn = make_joint_cnn()
    policy_obs_dim = policy_cnn.output_size

    # Each policy type differs only in the MLP head options and in the
    # distribution wrapper placed around the observation processor.
    if policy_type == 'normal':
        head_options = {}
        distribution_cls = Gaussian
    elif policy_type == 'tanh-normal':
        head_options = {}
        distribution_cls = TanhGaussian
    elif policy_type == 'normal-tanh-mean':
        head_options = dict(output_activations=['tanh', 'identity'])
        distribution_cls = Gaussian
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

    obs_processor = nn.Sequential(
        policy_cnn,
        basic.Flatten(),
        MultiHeadedMlp(
            input_size=policy_obs_dim,
            output_sizes=[action_dim, action_dim],
            **head_options,
            **policy_kwargs
        ),
    )
    policy = PolicyFromDistributionGenerator(distribution_cls(obs_processor))

    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            """Pad, jointly crop, then append the goal image to each obs."""
            raw_obs = batch['observations']
            raw_next_obs = batch['next_observations']
            raw_goal = batch[image_goal_key]

            padded_obs = pad(raw_obs)
            padded_next_obs = pad(raw_next_obs)
            padded_goal = pad(raw_goal)

            # The goal is cropped once with each observation it is paired
            # with, so there are two separate crop calls.
            cropped_obs, goal_for_obs = crop(padded_obs, padded_goal)
            cropped_next_obs, goal_for_next_obs = crop(
                padded_next_obs, padded_goal
            )

            batch['observations'] = np.concatenate(
                [cropped_obs, goal_for_obs], axis=1
            )
            batch['next_observations'] = np.concatenate(
                [cropped_next_obs, goal_for_next_obs], axis=1
            )
            return batch
    else:

        def concat_context_to_obs(batch, *args, **kwargs):
            """Append the goal image to the current and next observations."""
            raw_obs = batch['observations']
            raw_next_obs = batch['next_observations']
            raw_goal = batch[image_goal_key]
            batch['observations'] = np.concatenate(
                [raw_obs, raw_goal], axis=1
            )
            batch['next_observations'] = np.concatenate(
                [raw_next_obs, raw_goal], axis=1
            )
            return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[image_goal_key, state_goal_key],
        observation_key=image_obs_key,
        observation_keys=[image_obs_key, state_obs_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=image_obs_key,
        context_keys_for_policy=[image_goal_key],
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs
    )
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=image_obs_key,
        context_keys_for_policy=[image_goal_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        video_key = 'video_observation'
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=image_obs_key,
            context_keys_for_policy=[image_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)

        # Add a second image stream for the video frames on top of the
        # training envs, then re-wrap so contexts are still sampled.
        eval_with_frames = InsertImageEnv(
            eval_env, renderer=video_renderer, image_key=video_key
        )
        expl_with_frames = InsertImageEnv(
            expl_env, renderer=video_renderer, image_key=video_key
        )
        video_eval_env = ContextualEnv(
            eval_with_frames,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=image_obs_key,
        )
        video_expl_env = ContextualEnv(
            expl_with_frames,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=image_obs_key,
        )

        def layout_kwargs():
            """Build fresh layout arguments shared by both video functions."""
            return dict(
                imsize=video_renderer.image_shape[1],
                image_formats=[
                    env_renderer.output_image_format,
                    env_renderer.output_image_format,
                    video_renderer.output_image_format,
                ],
                keys_to_show=[image_goal_key, image_obs_key, video_key],
            )

        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            **layout_kwargs(),
            **save_video_kwargs
        )
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            **layout_kwargs(),
            **save_video_kwargs
        )

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Ejemplo n.º 2
0
def sac_on_gym_goal_env_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    observation_key='observation',
    desired_goal_key='desired_goal',
    achieved_goal_key='achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    """Run goal-conditioned SAC on a gym goal env (Fetch robotics tasks).

    Builds exploration and evaluation ``ContextualEnv`` instances with a
    thresholded L2-distance reward, MLP Q-functions and a
    ``TanhGaussianPolicy``, a ``ContextualRelabelingReplayBuffer`` and a
    ``SACTrainer``, wires them into a ``TorchBatchRLAlgorithm``, optionally
    registers video-saving hooks, and finally calls ``algorithm.train()``.

    Args:
        max_path_length: Passed to the algorithm and to the video rollouts.
        qf_kwargs: Extra keyword arguments for each Q-function ``ConcatMlp``.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        observation_key: Observation-dict key fed to the policy.
        desired_goal_key: Observation-dict key of the goal; also used as the
            context key.
        achieved_goal_key: Observation-dict key of the achieved goal, used
            for the distance reward and for relabeling.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        evaluation_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the evaluation env.
        exploration_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the exploration env.
        save_video: Whether to append video-saving functions to
            ``algorithm.post_train_funcs``.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``.
        renderer_kwargs: Keyword arguments for ``GymEnvRenderer``.

    Raises:
        TypeError: If the created env is not one of the Fetch environments
            for which a success threshold is known.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    context_key = desired_goal_key
    # Relabeled goals are taken from the achieved goal of an observation.
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    # Environments for which the success threshold below is known.
    fetch_env_classes = (
        robotics.FetchReachEnv,
        robotics.FetchPushEnv,
        robotics.FetchPickAndPlaceEnv,
        robotics.FetchSlideEnv,
    )

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        """Return ``(contextual_env, goal_distribution, reward_fn)``."""
        env = get_gym_env(
            env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        env.goal_sampling_mode = goal_sampling_mode
        goal_distribution = GoalDictDistributionFromGymGoalEnv(
            env,
            desired_goal_key=desired_goal_key,
        )
        distance_fn = L2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key, ),
            desired_goal_key=desired_goal_key,
        )
        if isinstance(env, fetch_env_classes):
            success_threshold = 0.05
        else:
            # Format the env into the message instead of passing it as a
            # second exception argument, which rendered as a tuple.
            raise TypeError(
                "I don't know the success threshold of env {}".format(env))
        reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        diag_fn = GenericGoalConditionedContextualDiagnostics(
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            success_threshold=success_threshold,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)

    # The networks see the observation concatenated with the goal.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Create one Q-function mapping (obs + goal, action) to a scalar."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the goal to the current and next observations of a batch."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys_to_save=[observation_key, achieved_goal_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(policy=policy,
                                                   env=expl_env,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:

        # Setting the goal like this is discouraged, but the Fetch
        # environments are designed to visualize the goals by setting their
        # goal parameter.
        def set_goal_for_visualization(env, policy, o):
            env.unwrapped.goal = o[desired_goal_key]

        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
            reset_callback=set_goal_for_visualization,
        )
        renderer = GymEnvRenderer(**renderer_kwargs)

        def add_images(env, context_distribution):
            """Re-wrap the inner env so observations include rendered images."""
            state_env = env.env
            img_env = InsertImageEnv(
                state_env,
                renderer=renderer,
                image_key='image_observation',
            )
            return ContextualEnv(
                img_env,
                context_distribution=context_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_chw[1],
            image_format=renderer.output_image_format,
            keys_to_show=['image_observation'],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_chw[1],
            image_format=renderer.output_image_format,
            keys_to_show=['image_observation'],
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Ejemplo n.º 3
0
def goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    """Run state-based goal-conditioned SAC on a multitask env.

    Builds exploration and evaluation ``ContextualEnv`` instances, MLP
    Q-functions and a ``TanhGaussianPolicy``, a
    ``ContextualRelabelingReplayBuffer`` and a ``SACTrainer``, wires them
    into a ``TorchBatchRLAlgorithm``, optionally registers video-saving
    hooks, and finally calls ``algorithm.train()``.

    Args:
        max_path_length: Passed to the algorithm and to the video rollouts.
        qf_kwargs: Extra keyword arguments for each Q-function ``ConcatMlp``.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        observation_key: Observation-dict key fed to the policy; relabeled
            goals are also taken from it.
        desired_goal_key: Observation-dict key of the goal; also used as the
            context key.
        achieved_goal_key: Passed to the reward function as the achieved-goal
            key.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        evaluation_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the evaluation env.
        exploration_goal_sampling_mode: Assigned to ``goal_sampling_mode`` on
            the exploration env.
        save_video: Whether to append video-saving functions to
            ``algorithm.post_train_funcs``.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``.
        renderer_kwargs: Keyword arguments for ``EnvRenderer``.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    renderer_kwargs = renderer_kwargs or {}

    context_key = desired_goal_key
    # Relabeled goals are read from the observation itself.
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: observation_key}
    )

    def contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, goal_sampling_mode
    ):
        """Return ``(contextual_env, goal_distribution, reward_fn)``."""
        base_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs
        )
        base_env.goal_sampling_mode = goal_sampling_mode
        goals = GoalDictDistributionFromMultitaskEnv(
            base_env,
            desired_goal_keys=[desired_goal_key],
        )
        reward = ContextualRewardFnFromMultitaskEnv(
            env=base_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key
            ),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
        )
        diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
            base_env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        wrapped_env = ContextualEnv(
            base_env,
            context_distribution=goals,
            reward_fn=reward,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
            update_env_info_fn=delete_info,
        )
        return wrapped_env, goals, reward

    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, exploration_goal_sampling_mode
        )
    )
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, evaluation_goal_sampling_mode
        )
    )

    # The networks see the observation concatenated with the goal.
    obs_space = expl_env.observation_space.spaces
    obs_dim = (
        obs_space[observation_key].low.size
        + obs_space[context_key].low.size
    )
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Create one Q-function mapping (obs + goal, action) to a scalar."""
        return ConcatMlp(
            input_size=obs_dim + action_dim, output_size=1, **qf_kwargs
        )

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs
    )

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the goal to the current and next observations of a batch."""
        current = batch['observations']
        following = batch['next_observations']
        goal = batch[context_key]
        batch['observations'] = np.concatenate([current, goal], axis=1)
        batch['next_observations'] = np.concatenate(
            [following, goal], axis=1
        )
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs
    )
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Re-wrap the inner env so observations and goals carry images."""
            inner_env = env.env
            image_goals = AddImageDistribution(
                env=inner_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            image_env = InsertImageEnv(inner_env, renderer=renderer)
            return ContextualEnv(
                image_env,
                context_distribution=image_goals,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)

        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Ejemplo n.º 4
0
def her_sac_experiment(
        max_path_length,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        save_video_kwargs=None,
        exploration_policy_kwargs=None,
        **kwargs
):
    """Run SAC with hindsight experience replay on a goal-conditioned env.

    Creates a training and an evaluation env, MLP Q-functions and a
    ``TanhGaussianPolicy``, an ``ObsDictRelabelingBuffer`` and a
    ``SACTrainer`` wrapped in ``HERTrainer``, wires them into a
    ``TorchBatchRLAlgorithm``, optionally registers video-saving hooks, and
    finally calls ``algorithm.train()``.

    Args:
        max_path_length: Passed to the path collectors, the algorithm and
            the video rollouts.
        qf_kwargs: Extra keyword arguments for each Q-function ``ConcatMlp``.
        twin_sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: ``goal_sampling_mode`` of the
            evaluation path collector.
        exploration_goal_sampling_mode: ``goal_sampling_mode`` of the
            exploration path collector.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        save_video: Whether to append video-saving functions to
            ``algorithm.post_train_funcs``.
        env_id: Gym id; when truthy both envs come from ``gym.make`` and
            ``env_class`` / ``env_kwargs`` are not used.
        env_class: Env constructor used when ``env_id`` is falsy.
        env_kwargs: Keyword arguments for ``env_class``.
        observation_key: Observation-dict key fed to the policy.
        desired_goal_key: Observation-dict key of the goal.
        achieved_goal_key: Observation-dict key of the achieved goal, given
            to the replay buffer.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        **kwargs: Accepted but not used in this function.

    Raises:
        AssertionError: If neither ``env_id`` nor ``env_class`` is given.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    if env_kwargs is None:
        env_kwargs = {}

    assert env_id or env_class
    if not env_id:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)
    else:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)

    # The networks see the observation concatenated with the goal.
    obs_space = train_env.observation_space.spaces
    obs_dim = (
        obs_space[observation_key].low.size
        + obs_space[desired_goal_key].low.size
    )
    action_dim = train_env.action_space.low.size

    def create_qf():
        """Create one Q-function mapping (obs + goal, action) to a scalar."""
        return ConcatMlp(
            input_size=obs_dim + action_dim, output_size=1, **qf_kwargs
        )

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=evaluation_goal_sampling_mode,
    )
    exploration_policy = create_exploration_policy(
        train_env, policy, **exploration_policy_kwargs
    )
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        exploration_policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=exploration_goal_sampling_mode,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            return_dict_obs=True,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            eval_env,
            MakeDeterministic(policy),
            tag="eval",
            **save_video_kwargs
        )
        train_video_func = get_save_video_function(
            rollout_function,
            train_env,
            exploration_policy,
            tag="expl",
            **save_video_kwargs
        )

        # Disabled buffer-plotting hooks, kept for reference:
        # algorithm.post_train_funcs.append(plot_buffer_function(
        #     save_video_period, 'state_achieved_goal'))
        # algorithm.post_train_funcs.append(plot_buffer_function(
        #     save_video_period, 'state_desired_goal'))
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)

    algorithm.train()
Ejemplo n.º 5
0
def rig_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_goal_key='state_desired_goal',
    state_observation_key='state_observation',
    image_goal_key='image_desired_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=48,
    pretrained_vae_path="",
    init_camera=None,
):
    """Train a goal-conditioned SAC agent on VAE-encoded image observations.

    A VAE is loaded from ``pretrained_vae_path`` or, when that is empty,
    trained with ``train_vae``.  One exploration and one evaluation
    environment are then built by wrapping a gym environment with an image
    renderer, the VAE encoder and a ``ContextualEnv`` whose context is a
    latent goal.  Finally a ``TorchBatchRLAlgorithm`` is assembled around a
    ``SACTrainer`` and ``train()`` is called on it.

    Args:
        max_path_length: Forwarded to ``TorchBatchRLAlgorithm``.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp`` Q-function.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ContextualRelabelingReplayBuffer``.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        train_vae_kwargs: Passed to ``train_vae`` when no pretrained VAE path
            is given.
        env_id: Passed to ``get_gym_env``.
        env_class: Passed to ``get_gym_env``.
        env_kwargs: Passed to ``get_gym_env`` (and to ``train_vae``).
        observation_key: Observation-dict key fed to the policy and used by
            the reward function and replay buffer.
        desired_goal_key: Key of the latent goal; used as the context key.
        state_goal_key: Key of the state-space goal, used when goals are
            sampled with ``"reset_of_env"``.
        state_observation_key: Key of the state observation, used for
            diagnostics with ``"reset_of_env"``.
        image_goal_key: Key under which goal images are stored.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        evaluation_goal_sampling_mode: ``"vae_prior"`` or ``"reset_of_env"``
            for the evaluation environment.
        exploration_goal_sampling_mode: Same choice for the exploration
            environment.
        save_video: Whether to register video-saving functions as post-train
            hooks.
        save_video_kwargs: Extra keyword arguments for
            ``RIGVideoSaveFunction``; falsy means no extras.
        renderer_kwargs: Extra keyword arguments for ``EnvRenderer``; falsy
            means no extras.
        imsize: Image size handed to ``train_vae`` and to the video savers.
        pretrained_vae_path: Location of a saved VAE; empty string means a
            new one is trained.
        init_camera: Passed to ``EnvRenderer`` and ``train_vae``.

    Raises:
        NotImplementedError: If a goal sampling mode is neither
            ``"vae_prior"`` nor ``"reset_of_env"`` (this includes the default
            ``None``).
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    # Only consulted further down for its ``output_image_format``; each
    # environment built below gets a renderer instance of its own.
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        """Return ``(env, latent_goal_distribution, reward_fn)`` for one mode.

        ``model`` is read from the enclosing scope; it is assigned before this
        helper is first called.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        env_renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=env_renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            # A second copy of the environment serves as the goal source.
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=env_renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn

    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)
    context_key = desired_goal_key

    # Networks consume the observation concatenated with the goal context.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Build one Q-network over (observation + context, action)."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the context to both observation arrays of ``batch``."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=RemapKeyFn(
            {context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        # Use the caller's ``imsize`` (the same value given to ``train_vae``)
        # instead of a hard-coded 48 so non-default image sizes stay in sync.
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            goal_image_key=image_goal_key,
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
Ejemplo n.º 6
0
def probabilistic_goal_reaching_experiment(
    max_path_length,
    qf_kwargs,
    policy_kwargs,
    pgr_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    env_id,
    discount_factor,
    reward_type,
    # Dynamics model
    dynamics_model_version,
    dynamics_model_config,
    dynamics_delta_model_config=None,
    dynamics_adam_config=None,
    dynamics_ensemble_kwargs=None,
    # Discount model
    learn_discount_model=False,
    discount_adam_config=None,
    discount_model_config=None,
    prior_discount_weight_schedule_kwargs=None,
    # Environment
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    exploration_policy_kwargs=None,
    action_noise_scale=0.,
    num_presampled_goals=4096,
    success_threshold=0.05,
    # Video / visualization parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    plot_renderer_kwargs=None,
    eval_env_ids=None,
    # Debugging params
    visualize_dynamics=False,
    visualize_discount_model=False,
    visualize_all_plots=False,
    plot_discount=False,
    plot_reward=False,
    plot_bootstrap_value=False,
    # env specific-params
    normalize_distances_for_full_state_ant=False,
):
    """Build and run a goal-reaching experiment trained with ``PGRTrainer``.

    The function creates a goal dynamics model, wraps the environment in a
    ``ContextualEnv`` whose context is the desired goal, builds Q-functions,
    a tanh-Gaussian policy and a relabeling replay buffer, joins the PGR
    trainer with a dynamics-model trainer (and optionally a discount-model
    trainer) in a ``JointTrainer``, and finally calls ``train()`` on a
    ``TorchBatchRLAlgorithm``.  When ``save_video`` is set, debug renderers
    and video-saving hooks are attached first.

    Args:
        max_path_length: Forwarded to the algorithm and to video rollouts.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp`` Q-function.
        policy_kwargs: Extra keyword arguments for the policy's
            ``MultiHeadedMlp``.
        pgr_trainer_kwargs: Extra keyword arguments for ``PGRTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ContextualRelabelingReplayBuffer``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Passed to ``get_gym_env``; also the default evaluation env.
        discount_factor: Passed to ``PGRTrainer`` and, for the probabilistic
            reward, to ``ProbabilisticGoalRewardFn``.
        reward_type: ``'sparse'`` selects a thresholded L2 reward,
            ``'negative_distance'`` a negative L2 reward; any other value is
            forwarded to ``ProbabilisticGoalRewardFn``.
        dynamics_model_version: Dynamics model variant.  The ``learned_*``
            variants listed in ``create_trainer`` get an Adam optimizer, the
            ``fixed_standard_*`` variants get none.
        dynamics_model_config: Passed to ``create_goal_dynamics_model``.
        dynamics_delta_model_config: Passed to ``create_goal_dynamics_model``;
            ``None`` becomes ``{}``.
        dynamics_adam_config: Keyword arguments for the dynamics Adam
            optimizer; ``None`` becomes ``{}``.
        dynamics_ensemble_kwargs: Passed as ``ensemble_model_kwargs``;
            ``None`` becomes ``{}``.
        learn_discount_model: Whether to create and train a discount model.
        discount_adam_config: Keyword arguments for the discount Adam
            optimizer; ``None`` becomes ``{}``.
        discount_model_config: Passed as ``model_kwargs`` to
            ``create_discount_model``; ``None`` becomes ``{}``.
        prior_discount_weight_schedule_kwargs: When not ``None``, keyword
            arguments for ``create_schedule``.
        env_class: Passed to ``get_gym_env``.
        env_kwargs: Passed to ``get_gym_env``.
        observation_key: Observation-dict key fed to the policy and models.
        desired_goal_key: Key of the desired goal; used as the context key.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` becomes ``{}``.
        action_noise_scale: Passed to the ``NoisyAction`` wrapper.
        num_presampled_goals: Passed to ``PresampledDistribution``.
        success_threshold: Threshold used by diagnostics and the sparse
            reward.
        save_video: Whether to attach video-saving post-train hooks.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``; falsy becomes ``{}``.
        video_renderer_kwargs: Keyword arguments for the image and value
            renderers; falsy becomes ``{}``.
        plot_renderer_kwargs: Keyword arguments for the number-plot
            renderers; when falsy, a copy of ``video_renderer_kwargs`` with
            ``dpi=48`` is used.
        eval_env_ids: Mapping from evaluation name to env id; defaults to
            ``{'eval': env_id}``.
        visualize_dynamics: Add dynamics log-probability renderers.
        visualize_discount_model: Add a discount-model renderer (only if the
            trainer has a discount model).
        visualize_all_plots: Enable the discount, reward and bootstrap plots.
        plot_discount: Enable the discount plot.
        plot_reward: Enable the reward plot.
        plot_bootstrap_value: Enable the bootstrap-value plot.
        normalize_distances_for_full_state_ant: For
            ``AntFullPositionGoalEnv`` only, wrap the env in
            ``NormalizeAntFullPositionGoalEnv``.

    Raises:
        NotImplementedError: If ``dynamics_model_version`` is not one of the
            versions handled in ``create_trainer``.
    """
    if dynamics_ensemble_kwargs is None:
        dynamics_ensemble_kwargs = {}
    if eval_env_ids is None:
        eval_env_ids = {'eval': env_id}
    if discount_model_config is None:
        discount_model_config = {}
    if dynamics_delta_model_config is None:
        dynamics_delta_model_config = {}
    if dynamics_adam_config is None:
        dynamics_adam_config = {}
    if discount_adam_config is None:
        discount_adam_config = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    if not plot_renderer_kwargs:
        plot_renderer_kwargs = video_renderer_kwargs.copy()
        plot_renderer_kwargs['dpi'] = 48
    context_key = desired_goal_key

    # Throw-away environment used to inspect types and spaces.
    stub_env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    is_gym_env = isinstance(
        stub_env,
        (
            FetchEnv,
            AntXYGoalEnv,
            AntFullPositionGoalEnv,
            # HopperFullPositionGoalEnv,
        ),
    )
    is_ant_full_pos = isinstance(stub_env, AntFullPositionGoalEnv)

    if is_gym_env or isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        # These envs expose a dedicated achieved-goal entry.
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    else:
        # Otherwise the observation itself is treated as the achieved goal.
        achieved_goal_key = observation_key
        ob_keys_to_save_in_buffer = [observation_key]
    # TODO move all env-specific code to other file
    if isinstance(stub_env, SawyerDoorHookEnv):
        init_camera = sawyer_door_env_camera_v0
    elif isinstance(stub_env, SawyerPushAndReachXYEnv):
        init_camera = sawyer_init_camera_zoomed_in
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        init_camera = sawyer_pick_and_place_camera
    else:
        init_camera = None

    full_ob_space = stub_env.observation_space
    action_space = stub_env.action_space
    state_to_goal = StateToGoalFn(stub_env)
    dynamics_model = create_goal_dynamics_model(
        full_ob_space[observation_key],
        action_space,
        full_ob_space[achieved_goal_key],
        dynamics_model_version,
        state_to_goal,
        dynamics_model_config,
        dynamics_delta_model_config,
        ensemble_model_kwargs=dynamics_ensemble_kwargs,
    )
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_reward(_env_id,
                                      _env_class=None,
                                      _env_kwargs=None):
        """Return ``(env, goal_distribution, reward_fn)`` for ``_env_id``.

        NOTE(review): ``_env_class`` and ``_env_kwargs`` are accepted but not
        used; the enclosing ``env_class`` / ``env_kwargs`` are applied to
        every environment, including the evaluation ones that are requested
        with an id only.  Confirm this is intended before changing it.
        """
        base_env = get_gym_env(
            _env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        if init_camera:
            base_env.initialize_camera(init_camera)
        if is_ant_full_pos and normalize_distances_for_full_state_ant:
            base_env = NormalizeAntFullPositionGoalEnv(base_env)
            normalize_env = base_env
        else:
            normalize_env = None
        env = NoisyAction(base_env, action_noise_scale)
        diag_fns = []
        if is_gym_env:
            goal_distribution = GoalDictDistributionFromGymGoalEnv(
                env,
                desired_goal_key=desired_goal_key,
            )
            diag_fns.append(
                GenericGoalConditionedContextualDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                ))
        else:
            goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            diag_fns.append(
                GoalConditionedDiagnosticsToContextualDiagnostics(
                    env.goal_conditioned_diagnostics,
                    desired_goal_key=desired_goal_key,
                    observation_key=observation_key,
                ))
        if is_ant_full_pos:
            diag_fns.append(
                AntFullPositionGoalEnvDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                    normalize_env=normalize_env,
                ))
        # if isinstance(stub_env, HopperFullPositionGoalEnv):
        #     diag_fns.append(
        #         HopperFullPositionGoalEnvDiagnostics(
        #             desired_goal_key=desired_goal_key,
        #             achieved_goal_key=achieved_goal_key,
        #             success_threshold=success_threshold,
        #         )
        #     )
        achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
        if reward_type == 'sparse':
            distance_fn = L2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
            reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        elif reward_type == 'negative_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
        else:
            reward_fn = ProbabilisticGoalRewardFn(
                dynamics_model,
                state_key=observation_key,
                context_key=context_key,
                reward_type=reward_type,
                discount_factor=discount_factor,
            )
        goal_distribution = PresampledDistribution(goal_distribution,
                                                   num_presampled_goals)
        final_env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=diag_fns,
            update_env_info_fn=delete_info,
        )
        return final_env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, reward_fn = contextual_env_distrib_reward(
        env_id,
        env_class,
        env_kwargs,
    )
    # Networks consume the observation concatenated with the goal context.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Build one Q-network over (observation + context, action)."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    def create_policy():
        """Build the policy from a two-headed MLP wrapped in TanhGaussian."""
        obs_processor = MultiHeadedMlp(input_size=obs_dim,
                                       output_sizes=[action_dim, action_dim],
                                       **policy_kwargs)
        return PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    policy = create_policy()

    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the context to the observations, keeping the originals.

        The un-concatenated arrays stay available under the ``original_*``
        keys, which the model trainers below read.
        """
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch['original_observations'] = obs
        batch['original_next_observations'] = next_obs
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=expl_env,
        context_keys=[context_key],
        observation_keys=ob_keys_to_save_in_buffer,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    def create_trainer():
        """Return ``(JointTrainer, PGRTrainer)`` for this experiment."""
        trainers = OrderedDict()
        if learn_discount_model:
            discount_model = create_discount_model(
                ob_space=stub_env.observation_space[observation_key],
                goal_space=stub_env.observation_space[context_key],
                action_space=stub_env.action_space,
                model_kwargs=discount_model_config)
            optimizer = optim.Adam(discount_model.parameters(),
                                   **discount_adam_config)
            discount_trainer = DiscountModelTrainer(
                discount_model,
                optimizer,
                observation_key='observations',
                next_observation_key='original_next_observations',
                goal_key=context_key,
                state_to_goal_fn=state_to_goal,
            )
            trainers['discount_trainer'] = discount_trainer
        else:
            discount_model = None
        if prior_discount_weight_schedule_kwargs is not None:
            schedule = create_schedule(**prior_discount_weight_schedule_kwargs)
        else:
            schedule = None
        pgr_trainer = PGRTrainer(env=expl_env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 discount=discount_factor,
                                 discount_model=discount_model,
                                 prior_discount_weight_schedule=schedule,
                                 **pgr_trainer_kwargs)
        trainers[''] = pgr_trainer
        if dynamics_model_version in {
                'learned_model',
                'learned_model_ensemble',
                'learned_model_laplace',
                'learned_model_laplace_global_variance',
                'learned_model_gaussian_global_variance',
        }:
            model_opt = optim.Adam(dynamics_model.parameters(),
                                   **dynamics_adam_config)
        elif dynamics_model_version in {
                'fixed_standard_laplace',
                'fixed_standard_gaussian',
        }:
            # Fixed models are given no optimizer.
            model_opt = None
        else:
            raise NotImplementedError(
                'unknown dynamics model version: %s' % dynamics_model_version)
        model_trainer = GenerativeGoalDynamicsModelTrainer(
            dynamics_model,
            model_opt,
            state_to_goal=state_to_goal,
            observation_key='original_observations',
            next_observation_key='original_next_observations',
        )
        trainers['dynamics_trainer'] = model_trainer
        return JointTrainer(trainers), pgr_trainer

    trainer, pgr_trainer = create_trainer()

    eval_policy = MakeDeterministic(policy)

    def create_eval_path_collector(some_eval_env):
        """Build a path collector that runs ``eval_policy`` in the env."""
        return ContextualPathCollector(
            some_eval_env,
            eval_policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )

    path_collectors = dict()
    eval_env_name_to_env_and_context_distrib = dict()
    for name, extra_env_id in eval_env_ids.items():
        env, context_distrib, _ = contextual_env_distrib_reward(extra_env_id)
        path_collectors[name] = create_eval_path_collector(env)
        eval_env_name_to_env_and_context_distrib[name] = (env, context_distrib)
    eval_path_collector = JointPathCollector(path_collectors)
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    def get_eval_diagnostics(key_to_paths):
        """Collect per-environment diagnostics, prefixed by the env name."""
        stats = OrderedDict()
        for eval_env_name, paths in key_to_paths.items():
            env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
            stats.update(
                add_prefix(
                    env.get_diagnostics(paths),
                    eval_env_name,
                    divider='/',
                ))
            stats.update(
                add_prefix(
                    eval_util.get_generic_path_information(paths),
                    eval_env_name,
                    divider='/',
                ))
        return stats

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        evaluation_get_diagnostic_functions=[get_eval_diagnostics],
        **algo_kwargs)
    algorithm.to(ptu.device)

    if normalize_distances_for_full_state_ant and is_ant_full_pos:
        qpos_weights = expl_env.unwrapped.presampled_qpos.std(axis=0)
    else:
        qpos_weights = None

    if save_video:
        if is_gym_env:
            video_renderer = GymEnvRenderer(**video_renderer_kwargs)

            def set_goal_for_visualization(env, policy, o):
                """Copy the sampled goal onto the unwrapped env after reset."""
                goal = o[desired_goal_key]
                if normalize_distances_for_full_state_ant and is_ant_full_pos:
                    # Undo the normalisation before handing the goal back.
                    unnormalized_goal = goal * qpos_weights
                    env.unwrapped.goal = unnormalized_goal
                else:
                    env.unwrapped.goal = goal

            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=set_goal_for_visualization,
            )
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)
            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=None,
            )

        renderers = OrderedDict(image_observation=video_renderer, )
        state_env = expl_env.env
        state_space = state_env.observation_space[observation_key]
        low = state_space.low.min()
        high = state_space.high.max()
        # Grid of (x, y) pairs, one per pixel of the rendered image; the
        # resulting array has two columns.
        y = np.linspace(low, high, num=video_renderer.image_chw[1])
        x = np.linspace(low, high, num=video_renderer.image_chw[2])
        all_xy_np = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
        all_xy_torch = ptu.from_numpy(all_xy_np)
        num_states = all_xy_torch.shape[0]
        if visualize_dynamics:

            def create_dynamics_visualizer(show_prob, vary_state=False):
                """Return a function mapping (obs_dict, action) to grid values.

                With ``vary_state`` the grid points are used as states and the
                desired goal is scored; otherwise the current state is used
                and every grid point is scored.  ``show_prob`` selects
                probabilities instead of log-probabilities.
                """
                def get_prob(obs_dict, action):
                    obs = obs_dict[observation_key]
                    obs_torch = ptu.from_numpy(obs)[None]
                    action_torch = ptu.from_numpy(action)[None]
                    if vary_state:
                        # Zero action per grid state.  ``new_zeros`` keeps
                        # the action's device/dtype and width instead of
                        # assuming a 2-D action on the default device.
                        action_repeated = action_torch.new_zeros(
                            (num_states, action_torch.shape[-1]))
                        dist = dynamics_model(all_xy_torch, action_repeated)
                        goal = ptu.from_numpy(
                            obs_dict[desired_goal_key][None])
                        log_probs = dist.log_prob(goal)
                    else:
                        dist = dynamics_model(obs_torch, action_torch)
                        log_probs = dist.log_prob(all_xy_torch)
                    if show_prob:
                        return log_probs.exp()
                    else:
                        return log_probs

                return get_prob

            renderers['log_prob'] = ValueRenderer(
                create_dynamics_visualizer(False), **video_renderer_kwargs)
            # renderers['prob'] = ValueRenderer(
            #     create_dynamics_visualizer(True), **video_renderer_kwargs
            # )
            renderers['log_prob_vary_state'] = ValueRenderer(
                create_dynamics_visualizer(False, vary_state=True),
                only_get_image_once_per_episode=True,
                max_out_walls=isinstance(stub_env, PickAndPlaceEnv),
                **video_renderer_kwargs)
            # renderers['prob_vary_state'] = ValueRenderer(
            #     create_dynamics_visualizer(True, vary_state=True),
            #     **video_renderer_kwargs)
        if visualize_discount_model and pgr_trainer.discount_model:

            def get_discount_values(obs, action):
                """Evaluate the discount model for every grid point as goal."""
                obs = obs[observation_key]
                obs_torch = ptu.from_numpy(obs)[None]
                combined_obs = torch.cat(
                    [
                        obs_torch.repeat(num_states, 1),
                        all_xy_torch,
                    ],
                    dim=1)

                action_torch = ptu.from_numpy(action)[None]
                action_repeated = action_torch.repeat(num_states, 1)
                return pgr_trainer.discount_model(combined_obs,
                                                  action_repeated)

            renderers['discount_model'] = ValueRenderer(
                get_discount_values,
                states_to_eval=all_xy_torch,
                **video_renderer_kwargs)
        if 'log_prob' in renderers and 'discount_model' in renderers:
            renderers['log_prob_time_discount'] = ProductRenderer(
                renderers['discount_model'], renderers['log_prob'],
                **video_renderer_kwargs)

        def get_reward(obs_dict, action, next_obs_dict):
            """Return the reward of a single transition."""
            o = batchify(obs_dict)
            a = batchify(action)
            next_o = batchify(next_obs_dict)
            reward = reward_fn(o, a, next_o, next_o)
            return reward[0]

        def get_bootstrap(obs_dict, action, next_obs_dict, return_float=True):
            """Return the trainer's bootstrap value for a single transition."""
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            next_o_pt = ptu.from_numpy(next_obs_dict[observation_key][None])
            action_torch = ptu.from_numpy(action[None])
            bootstrap, *_ = pgr_trainer.get_bootstrap_stats(
                torch.cat((o_pt, context_pt), dim=1),
                action_torch,
                torch.cat((next_o_pt, context_pt), dim=1),
            )
            if return_float:
                return ptu.get_numpy(bootstrap)[0, 0]
            else:
                return bootstrap

        def get_discount(obs_dict, action, next_obs_dict):
            """Return the trainer's discount factor, clipped to [1e-3, 1]."""
            bootstrap = get_bootstrap(obs_dict,
                                      action,
                                      next_obs_dict,
                                      return_float=False)
            reward_np = get_reward(obs_dict, action, next_obs_dict)
            reward = ptu.from_numpy(reward_np[None, None])
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            obs = torch.cat((o_pt, context_pt), dim=1)
            actions = ptu.from_numpy(action[None])
            discount = pgr_trainer.get_discount_factor(
                bootstrap,
                reward,
                obs,
                actions,
            )
            if isinstance(discount, torch.Tensor):
                discount = ptu.get_numpy(discount)[0, 0]
            return np.clip(discount, a_min=1e-3, a_max=1)

        def create_modify_fn(
            title,
            set_params=None,
            scientific=True,
        ):
            """Return a callback that titles and formats a matplotlib axis."""
            def modify(ax):
                ax.set_title(title)
                if set_params:
                    ax.set(**set_params)
                if scientific:
                    scaler = ScalarFormatter(useOffset=True)
                    scaler.set_powerlimits((1, 1))
                    ax.yaxis.set_major_formatter(scaler)
                    ax.ticklabel_format(axis='y', style='sci')

            return modify

        def add_left_margin(fig):
            """Widen the left margin of ``fig``."""
            fig.subplots_adjust(left=0.2)

        if visualize_all_plots or plot_discount:
            renderers['discount'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_discount,
                modify_ax_fn=create_modify_fn(
                    title='discount',
                    set_params=dict(
                        # yscale='log',
                        ylim=[-0.05, 1.1], ),
                    # scientific=False,
                ),
                modify_fig_fn=add_left_margin,
                # autoscale_y=False,
                **plot_renderer_kwargs)

        if visualize_all_plots or plot_reward:
            renderers['reward'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_reward,
                modify_ax_fn=create_modify_fn(title='reward',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_bootstrap_value:
            renderers['bootstrap-value'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_bootstrap,
                modify_ax_fn=create_modify_fn(title='bootstrap value',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        def add_images(env, state_distribution):
            """Re-wrap ``env.env`` so rollouts also carry the debug images."""
            state_env = env.env
            if is_gym_env:
                goal_distribution = state_distribution
            else:
                goal_distribution = AddImageDistribution(
                    env=state_env,
                    base_distribution=state_distribution,
                    image_goal_key='image_desired_goal',
                    renderer=video_renderer,
                )
            context_env = ContextualEnv(
                state_env,
                context_distribution=goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return InsertDebugImagesEnv(
                context_env,
                renderers=renderers,
            )

        img_expl_env = add_images(expl_env, expl_context_distrib)
        # Keep the key list and the format list parallel.  The goal image
        # (drawn by ``video_renderer`` in ``add_images``) only exists for
        # non-gym envs, so its key and its format are added together.
        # NOTE(review): assumes the video saver pairs keys and formats by
        # position - confirm against ``get_save_video_function``.
        imgs_to_show = list(renderers.keys())
        img_formats = [r.output_image_format for r in renderers.values()]
        if not is_gym_env:
            imgs_to_show = ['image_desired_goal'] + imgs_to_show
            img_formats = [video_renderer.output_image_format] + img_formats
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_chw[1],
            image_formats=img_formats,
            keys_to_show=imgs_to_show,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        for eval_env_name, (env, context_distrib) in (
                eval_env_name_to_env_and_context_distrib.items()):
            img_eval_env = add_images(env, context_distrib)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag=eval_env_name,
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                keys_to_show=imgs_to_show,
                **save_video_kwargs)
            algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
Ejemplo n.º 7
0
def awac_rig_experiment(
    max_path_length,
    qf_kwargs,
    trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    policy_class=TanhGaussianPolicy,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    reward_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_observation_key='state_observation',
    state_goal_key='state_desired_goal',
    image_goal_key='image_desired_goal',
    path_loader_class=MDPPathLoader,
    demo_replay_buffer_kwargs=None,
    path_loader_kwargs=None,
    env_demo_path='',
    env_offpolicy_data_path='',
    debug=False,
    epsilon=1.0,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    add_env_demos=False,
    add_env_offpolicy_data=False,
    save_paths=False,
    load_demos=False,
    pretrain_policy=False,
    pretrain_rl=False,
    save_pretrained_algorithm=False,

    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=84,
    pretrained_vae_path="",
    presampled_goals_path="",
    init_camera=None,
    qf_class=ConcatMlp,
):
    """Assemble and run an AWAC training pipeline on an image-wrapped env.

    The function builds exploration/evaluation ``ContextualEnv`` instances on
    top of ``InsertImageEnv``, loads (or trains) a VAE, creates the replay
    buffers, Q-functions, policy, path collectors, ``AWACTrainer`` and
    ``TorchBatchRLAlgorithm``, optionally registers video callbacks, loads
    demos, runs the pretraining phases, and finally calls
    ``algorithm.train()``.

    Args:
        max_path_length: Rollout horizon handed to the algorithm (forced to 5
            when ``debug`` is set).
        qf_kwargs: Keyword arguments for ``qf_class``. A copy is extended with
            ``input_size`` (``ConcatMlp``) or ``added_fc_input_size``
            (``ConcatCNN``).
        trainer_kwargs: Keyword arguments for ``AWACTrainer``. Modified in
            place when ``debug`` is set.
        replay_buffer_kwargs: Keyword arguments for the online replay buffer.
        policy_kwargs: Keyword arguments for ``policy_class``.
        algo_kwargs: Keyword arguments for ``TorchBatchRLAlgorithm``. Modified
            in place when ``debug`` is set.
        train_vae_kwargs: Passed to ``train_vae`` when no
            ``pretrained_vae_path`` is given.
        policy_class: Policy constructor, called with ``obs_dim`` and
            ``action_dim``.
        env_id, env_class, env_kwargs: Forwarded to ``get_gym_env``.
        reward_kwargs: Currently unused; kept for interface compatibility.
        observation_key: Observation dict key fed to the policy/buffers.
        desired_goal_key: Context key concatenated onto observations.
        state_observation_key, state_goal_key, image_goal_key, epsilon:
            Currently unused; kept for interface compatibility.
        path_loader_class: Demo loader constructor used when ``load_demos``.
        demo_replay_buffer_kwargs: Overrides layered on top of
            ``replay_buffer_kwargs`` for the two demo buffers.
        path_loader_kwargs: Keyword arguments for ``path_loader_class``. If a
            dict is supplied it is updated in place with ``model_path``,
            ``env`` and (optionally) extra ``demo_paths`` entries.
        env_demo_path, env_offpolicy_data_path: Appended to
            ``path_loader_kwargs['demo_paths']`` when the matching ``add_*``
            flag is set.
        debug: Shrinks training sizes for a quick smoke run.
        exploration_policy_kwargs: Forwarded to ``create_exploration_policy``.
        evaluation_goal_sampling_mode, exploration_goal_sampling_mode,
        presampled_goals_path: Passed to the env factory, which currently
            ignores them.
        save_paths: If truthy, the value itself is appended to
            ``algorithm.post_train_funcs`` (see the NOTE in the body).
        load_demos, pretrain_policy, pretrain_rl: Enable the corresponding
            demo-loading / pretraining steps.
        save_pretrained_algorithm: Save a snapshot after pretraining.
        save_video, save_video_kwargs: Control the ``RIGVideoSaveFunction``
            callbacks.
        renderer_kwargs, init_camera: Forwarded to ``EnvRenderer``.
        imsize: Image size passed to ``train_vae`` and the video callbacks.
        pretrained_vae_path: If non-empty, the VAE is loaded from here instead
            of being trained.
        qf_class: Q-function constructor.
    """
    # Kwarg definitions
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if demo_replay_buffer_kwargs is None:
        demo_replay_buffer_kwargs = {}
    if path_loader_kwargs is None:
        path_loader_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    if debug:
        # Tiny settings so that a full train loop finishes quickly.
        max_path_length = 5
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100
        trainer_kwargs['bc_num_pretrain_steps'] = min(
            10, trainer_kwargs.get('bc_num_pretrain_steps', 0))
        trainer_kwargs['q_num_pretrain1_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain1_steps', 0))
        trainer_kwargs['q_num_pretrain2_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain2_steps', 0))

    # Environment wrapping
    # This renderer is only consulted for its output image format when the
    # video callbacks are built; every env gets its own renderer below.
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        """Return ``(env, context_distribution, reward_fn)`` for one env.

        NOTE: ``goal_sampling_mode`` and ``presampled_goals_path`` are
        accepted but not used; the env always gets the zero-sized "no_goal"
        context distribution and a default-constructed ``GraspingRewardFn``.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        env_renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=env_renderer)

        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )
        reward_fn = GraspingRewardFn()

        env = ContextualEnv(
            img_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn

    # VAE setup
    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)
    path_loader_kwargs['model_path'] = pretrained_vae_path

    # Environment definitions
    expl_env, expl_context_distrib, _ = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        presampled_goals_path)
    eval_env, _, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        presampled_goals_path)
    path_loader_kwargs['env'] = eval_env

    # AWAC code
    # setdefault: path_loader_kwargs may have been defaulted to {} above, in
    # which case there is no "demo_paths" list to append to yet.
    if add_env_demos:
        path_loader_kwargs.setdefault("demo_paths", []).append(env_demo_path)
    if add_env_offpolicy_data:
        path_loader_kwargs.setdefault("demo_paths", []).append(
            env_offpolicy_data_path)

    # Key setting
    context_key = desired_goal_key
    obs_spaces = expl_env.observation_space.spaces
    # The networks consume the observation with the context appended.
    obs_dim = (obs_spaces[observation_key].low.size +
               obs_spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    mapper = RemapKeyFn({context_key: observation_key})
    obs_keys = [observation_key]
    cont_keys = [context_key]

    # Replay buffer
    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the context to both observation arrays of ``batch``."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    def build_replay_buffer(buffer_kwargs):
        """Create one relabeling buffer; only ``buffer_kwargs`` varies."""
        return ContextualRelabelingReplayBuffer(
            env=eval_env,
            context_keys=cont_keys,
            observation_keys=obs_keys,
            observation_key=observation_key,
            context_distribution=expl_context_distrib,
            sample_context_from_obs_dict_fn=mapper,
            reward_fn=eval_reward,
            post_process_batch_fn=concat_context_to_obs,
            **buffer_kwargs)

    replay_buffer = build_replay_buffer(replay_buffer_kwargs)
    # Layer the demo overrides on a copy so the caller's dict is untouched.
    demo_buffer_kwargs = dict(replay_buffer_kwargs)
    demo_buffer_kwargs.update(demo_replay_buffer_kwargs)
    demo_train_buffer = build_replay_buffer(demo_buffer_kwargs)
    demo_test_buffer = build_replay_buffer(demo_buffer_kwargs)

    # Neural network architecture
    def create_qf():
        """Instantiate one Q-function with the size arguments filled in."""
        kwargs = dict(qf_kwargs)  # copy: do not mutate the caller's dict
        if qf_class is ConcatMlp:
            kwargs["input_size"] = obs_dim + action_dim
        if qf_class is ConcatCNN:
            kwargs["added_fc_input_size"] = action_dim
        return qf_class(output_size=1, **kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )

    # Path collectors
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    # Algorithm
    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **trainer_kwargs)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)

    algorithm.to(ptu.device)

    # Video saving
    if save_video:
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    # AWAC code
    if save_paths:
        # NOTE(review): `save_paths` is this function's own parameter, so the
        # value appended here is whatever the caller passed (e.g. ``True``),
        # not a module-level callback. Presumably a path-saving function was
        # intended -- confirm and rename one of the two.
        algorithm.post_train_funcs.append(save_paths)

    if load_demos:
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            # reward_fn is omitted because the reward is recomputed later
            **path_loader_kwargs)
        path_loader.load_demos()
    if pretrain_policy:
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if pretrain_rl:
        trainer.pretrain_q_with_bc_data()

    if save_pretrained_algorithm:
        snapshot_dir = logger.get_snapshot_dir()
        p_path = osp.join(snapshot_dir, 'pretrain_algorithm.p')
        pt_path = osp.join(snapshot_dir, 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        # The same snapshot is written under both file names; the context
        # manager guarantees each handle is flushed and closed.
        for snapshot_path in (pt_path, p_path):
            with open(snapshot_path, "wb") as snapshot_file:
                torch.save(data, snapshot_file)

    algorithm.train()
Ejemplo n.º 8
0
def rl_context_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.sac.policies import MakeDeterministic

    preprocess_rl_variant(variant)
    max_path_length = variant['max_path_length']
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal')

    contextual_mdp = variant.get('contextual_mdp', True)
    print("contextual_mdp:", contextual_mdp)

    mask_variant = variant.get('mask_variant', {})
    mask_conditioned = mask_variant.get('mask_conditioned', False)
    print("mask_conditioned:", mask_conditioned)

    if mask_conditioned:
        assert contextual_mdp

    if 'sac' in variant['algorithm'].lower():
        rl_algo = 'sac'
    elif 'td3' in variant['algorithm'].lower():
        rl_algo = 'td3'
    else:
        raise NotImplementedError
    print("RL algorithm:", rl_algo)

    ### load the example dataset, if running checkpoints ###
    if 'ckpt' in variant:
        import os.path as osp
        example_set_variant = variant.get('example_set_variant', dict())
        example_set_variant['use_cache'] = True
        example_set_variant['cache_path'] = osp.join(variant['ckpt'], 'example_dataset.npy')

    if mask_conditioned:
        env = get_envs(variant)
        mask_format = mask_variant['param_variant']['mask_format']
        assert mask_format in ['vector', 'matrix', 'distribution', 'cond_distribution']
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        if mask_format in ['vector']:
            context_dim_for_networks = goal_dim + goal_dim
        elif mask_format in ['matrix', 'distribution', 'cond_distribution']:
            context_dim_for_networks = goal_dim + (goal_dim * goal_dim)
        else:
            raise TypeError

        if 'ckpt' in variant:
            from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
            import os.path as osp

            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'masks.npy'))
            masks = np.load(filename, allow_pickle=True)[()]
        else:
            masks = get_mask_params(
                env=env,
                example_set_variant=variant['example_set_variant'],
                param_variant=mask_variant['param_variant'],
            )

        mask_keys = list(masks.keys())
        context_keys = [desired_goal_key] + mask_keys
    else:
        context_keys = [desired_goal_key]


    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(mode='eval')

    if mask_conditioned:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + context_dim_for_networks
        )
    elif contextual_mdp:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + env.observation_space.spaces[desired_goal_key].low.size
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size

    action_dim = env.action_space.low.size

    if 'ckpt' in variant and 'ckpt_epoch' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        ckpt_epoch = variant['ckpt_epoch']
        if ckpt_epoch is not None:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = torch.load(filename)
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        if rl_algo == 'td3':
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = create_exploration_policy(
                env, policy,
                exploration_version=variant['exploration_type'],
                exploration_noise=variant['exploration_noise'],
            )
            eval_policy = policy
        elif rl_algo == 'sac':
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = policy
            eval_policy = MakeDeterministic(policy)

    post_process_mask_fn = partial(
        full_post_process_mask_fn,
        mask_conditioned=mask_conditioned,
        mask_variant=mask_variant,
        context_distrib=context_distrib,
        context_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            desired_goal_key: obs_dict[achieved_goal_key]
        }

        if mask_conditioned:
            sample_masks_for_relabeling = mask_variant.get('sample_masks_for_relabeling', True)
            if sample_masks_for_relabeling:
                batch_size = next(iter(obs_dict.values())).shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]

        return context_dict

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None, next_obs_dict=None, new_contexts=None):
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch_size = obs.shape[0]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                if not mask_variant.get('relabel_masks', True):
                    for k in mask_keys:
                        new_contexts[k] = next_obs_dict[k][:]
                    batch.update(new_contexts)
                if not mask_variant.get('relabel_goals', True):
                    new_contexts[desired_goal_key] = next_obs_dict[desired_goal_key][:]
                    batch.update(new_contexts)

                new_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(new_contexts)

            if mask_format in ['vector', 'matrix']:
                goal = batch[desired_goal_key]
                mask = batch['mask'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, mask], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, mask], axis=1)
            elif mask_format == 'distribution':
                goal = batch[desired_goal_key]
                sigma_inv = batch['mask_sigma_inv'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, sigma_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, sigma_inv], axis=1)
            elif mask_format == 'cond_distribution':
                goal = batch[desired_goal_key]
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
                if mask_variant['use_g_for_mean']:
                    mu_w_given_g = goal
                else:
                    mu_w_given_g = mu_w + np.squeeze(mu_A @ np.expand_dims(goal - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        elif contextual_mdp:
            goal = batch[desired_goal_key]
            batch['observations'] = np.concatenate([obs, goal], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, goal], axis=1)
        else:
            batch['observations'] = obs
            batch['next_observations'] = next_obs

        return batch

    if 'observation_keys' not in variant['contextual_replay_buffer_kwargs']:
        variant['contextual_replay_buffer_kwargs']['observation_keys'] = []
    observation_keys = variant['contextual_replay_buffer_kwargs']['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **variant['contextual_replay_buffer_kwargs']
    )

    if rl_algo == 'td3':
        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    elif rl_algo == 'sac':
        trainer = SACTrainer(
            env=env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['sac_trainer_kwargs']
        )

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs={},
    ):
        assert mode in ['expl', 'eval']

        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)

        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError

            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError

            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get('max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get('max_subtasks_per_rollout', None)

            mode = mask_variant.get('context_post_process_mode', None)
            if mode in ['dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib if mode=='expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy, mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        save_period = variant.get('save_video_period', 50)
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['horizon'] = max_path_length

        renderer = EnvRenderer(**variant.get('renderer_kwargs', {}))

        def add_images(env, state_distribution):
            """Build an image-producing contextual env for video logging.

            Takes the env wrapped by ``env`` (``env.env``), adds rendered
            images under ``'image_observation'``, and pairs it with a context
            distribution that adds an ``'image_desired_goal'`` entry on top of
            ``state_distribution``. The reward function, observation key and
            env-info update hook come from the enclosing scope.
            """
            base_env = env.env
            goal_distribution_with_images = AddImageDistribution(
                env=base_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            env_with_images = InsertImagesEnv(
                base_env,
                renderers={'image_observation': renderer},
            )
            return ContextualEnv(
                env_with_images,
                context_distribution=goal_distribution_with_images,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)

        if variant.get('log_eval_video', True):
            video_path_collector = create_path_collector(img_eval_env, eval_policy, mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            default_list = [
                'atomic',
                'atomic_seq',
                'cumul_seq',
                'full',
            ]
            eval_rollouts_for_videos = mask_variant.get('eval_rollouts_for_videos', default_list)
            for key in eval_rollouts_for_videos:
                assert key in default_list

            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        if variant.get('log_expl_video', True) and not variant['algo_kwargs'].get('eval_only', False):
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env, expl_policy, mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    addl_collectors = []
    addl_log_prefixes = []
    if mask_conditioned and mask_variant.get('log_mask_diagnostics', True):
        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        eval_rollouts_to_log = mask_variant.get('eval_rollouts_to_log', default_list)
        for key in eval_rollouts_to_log:
            assert key in default_list

        # atomic masks
        if 'atomic' in eval_rollouts_to_log:
            for mask_id in eval_path_collector.mask_ids:
                mask_kwargs=dict(
                    mask_ids=[mask_id],
                    mask_distr=dict(
                        atomic=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
                addl_collectors.append(collector)
            addl_log_prefixes += [
                'mask_{}/'.format(''.join(str(mask_id)))
                for mask_id in eval_path_collector.mask_ids
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs=dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            """Collect evaluation rollouts per extra collector and log them.

            For every (prefix, collector) pair, new paths are collected and
            passed to ``eval_env.get_diagnostics``. Only statistics whose name
            contains all of ``'env_infos'``, ``'final'`` and ``'Mean'`` are
            kept; they are prefixed and merged into one log. Every extra
            collector is then told the epoch ended. The argument is ignored.
            """
            from rlkit.core.logging import append_log, add_prefix, OrderedDict

            def is_final_env_info_mean(name):
                return 'env_infos' in name and 'final' in name and 'Mean' in name

            diagnostics = OrderedDict()
            for log_prefix, path_collector in zip(addl_log_prefixes,
                                                  addl_collectors):
                rollouts = path_collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                all_stats = eval_env.get_diagnostics(rollouts)
                kept_stats = OrderedDict(
                    (name, all_stats[name])
                    for name in all_stats.keys()
                    if is_final_env_info_mean(name)
                )
                append_log(diagnostics, add_prefix(kept_stats, log_prefix))

            # Reset the collectors so paths do not accumulate across epochs.
            for path_collector in addl_collectors:
                path_collector.end_epoch(0)
            return diagnostics

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)
        
    if 'ckpt' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        assert variant['algo_kwargs'].get('eval_only', False)

        def update_networks(algo, epoch):
            """Swap in the evaluation policy stored in this epoch's checkpoint.

            Does nothing when ``variant`` contains ``'ckpt_epoch'`` or when
            ``epoch`` is not a multiple of ``algo._eval_epoch_freq``.
            Otherwise ``itr_<epoch>.pkl`` under ``variant['ckpt']`` is loaded,
            its ``'evaluation/policy'`` entry is moved to ``ptu.device`` and
            installed on the evaluation collector and every extra collector.
            """
            if 'ckpt_epoch' in variant or epoch % algo._eval_epoch_freq != 0:
                return

            ckpt_file = local_path_from_s3_or_local_path(
                osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
            print("Loading ckpt from", ckpt_file)
            # NOTE(review): torch.load unpickles the file; only point
            # variant['ckpt'] at trusted checkpoints.
            snapshot = torch.load(ckpt_file)
            loaded_policy = snapshot['evaluation/policy']
            loaded_policy.to(ptu.device)
            algo.eval_data_collector._policy = loaded_policy
            for extra_collector in addl_collectors:
                extra_collector._policy = loaded_policy

        algorithm.post_train_funcs.insert(0, update_networks)

    algorithm.train()
Ejemplo n.º 9
0
def disco_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    generate_set_for_rl_kwargs,
    # VAE parameters
    create_vae_kwargs,
    vae_trainer_kwargs,
    vae_algo_kwargs,
    data_loader_kwargs,
    generate_set_for_vae_pretraining_kwargs,
    num_ungrouped_images,
    beta_schedule_kwargs=None,
    # Oracle settings
    use_ground_truth_reward=False,
    use_onehot_set_embedding=False,
    use_dummy_model=False,
    observation_key="latent_observation",
    # RIG comparison
    rig_goal_setter_kwargs=None,
    rig=False,
    # Miscellaneous
    reward_fn_kwargs=None,
    # None-VAE Params
    env_id=None,
    env_class=None,
    env_kwargs=None,
    latent_observation_key="latent_observation",
    state_observation_key="state_observation",
    image_observation_key="image_observation",
    set_description_key="set_description",
    example_state_key="example_state",
    example_image_key="example_image",
    # Exploration
    exploration_policy_kwargs=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    """Assemble and run a SAC experiment conditioned on sets of examples.

    Steps, in order: create the example sets, obtain the set VAE (a dummy
    model when ``use_dummy_model`` is true, otherwise the result of
    ``train_set_vae``), build the exploration and evaluation contextual
    environments on two separately created gym envs, create the Q-functions
    and policy, the relabeling replay buffer, the SAC trainer and the path
    collectors, optionally register video callbacks, and finally call
    ``algorithm.train()``.

    ``rig_goal_setter_kwargs``, ``reward_fn_kwargs`` and
    ``exploration_policy_kwargs`` default to empty dicts when ``None``;
    ``save_video_kwargs`` and ``renderer_kwargs`` are replaced by empty dicts
    whenever they are falsy.

    The policy and Q-functions see the observation concatenated with the
    context entries selected as follows: the mean key only when ``rig`` is
    true; otherwise the set-embedding key when ``use_onehot_set_embedding``
    is true, else the mean and covariance keys.
    """
    rig_goal_setter_kwargs = (
        {} if rig_goal_setter_kwargs is None else rig_goal_setter_kwargs
    )
    reward_fn_kwargs = {} if reward_fn_kwargs is None else reward_fn_kwargs
    exploration_policy_kwargs = (
        {} if exploration_policy_kwargs is None else exploration_policy_kwargs
    )
    save_video_kwargs = save_video_kwargs or {}
    renderer_kwargs = renderer_kwargs or {}

    renderer = EnvRenderer(**renderer_kwargs)

    sets = create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        example_state_key=example_state_key,
        example_image_key=example_image_key,
        **generate_set_for_rl_kwargs,
    )

    if not use_dummy_model:
        model = train_set_vae(
            create_vae_kwargs,
            vae_trainer_kwargs,
            vae_algo_kwargs,
            data_loader_kwargs,
            generate_set_for_vae_pretraining_kwargs,
            num_ungrouped_images,
            env_id=env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            beta_schedule_kwargs=beta_schedule_kwargs,
            sets=sets,
            renderer=renderer,
        )
    else:
        model = create_dummy_image_vae(img_chw=renderer.image_chw,
                                       **create_vae_kwargs)

    def make_contextual_env(**extra_kwargs):
        # Every call builds its own gym env, so exploration and evaluation
        # never share a state env.
        return contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
            **extra_kwargs,
        )

    expl_env, expl_context_distrib, expl_reward = make_contextual_env()
    eval_env, eval_context_distrib, eval_reward = make_contextual_env(
        oracle_rig_goal=rig,
    )

    ctx_mean_key = expl_context_distrib.mean_key
    ctx_covariance_key = expl_context_distrib.covariance_key
    ctx_set_index_key = expl_context_distrib.set_index_key
    ctx_set_embedding_key = expl_context_distrib.set_embedding_key
    # Everything stored in the replay buffer as context.
    context_keys = [
        ctx_mean_key,
        ctx_covariance_key,
        ctx_set_index_key,
        ctx_set_embedding_key,
    ]
    # The subset of the context that the policy and Q-functions consume.
    if rig:
        context_keys_for_rl = [ctx_mean_key]
    elif use_onehot_set_embedding:
        context_keys_for_rl = [ctx_set_embedding_key]
    else:
        context_keys_for_rl = [ctx_mean_key, ctx_covariance_key]

    obs_spaces = expl_env.observation_space.spaces
    obs_dim = np.prod(obs_spaces[observation_key].shape) + sum([
        np.prod(obs_spaces[k].shape) for k in context_keys_for_rl
    ])
    action_dim = np.prod(expl_env.action_space.shape)

    # Two Q-functions followed by their two targets, created in that order.
    qf1, qf2, target_qf1, target_qf2 = [
        ConcatMlp(input_size=obs_dim + action_dim,
                  output_size=1,
                  **qf_kwargs)
        for _ in range(4)
    ]

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the RL context entries to both observation fields."""
        current = batch["observations"]
        following = batch["next_observations"]
        contexts = tuple(batch[k] for k in context_keys_for_rl)
        batch["observations"] = np.concatenate((current, *contexts), axis=1)
        batch["next_observations"] = np.concatenate(
            (following, *contexts),
            axis=1,
        )
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys,
        # A set literal removes duplicates when several keys coincide.
        observation_keys=list(
            {observation_key, state_observation_key, latent_observation_key}),
        observation_key=observation_key,
        context_distribution=FilterKeys(
            expl_context_distrib,
            context_keys,
        ),
        sample_context_from_obs_dict_fn=None,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs,
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs,
    )

    collector_kwargs = dict(
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        **collector_kwargs,
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        **collector_kwargs,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        shared_video_kwargs = dict(
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=eval_context_distrib.set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
        )
        # Exploration ("train") video first, then evaluation, so the
        # callbacks are registered in that order.
        for path_collector, video_tag in (
                (expl_path_collector, "train"),
                (eval_path_collector, "eval"),
        ):
            video_func = DisCoVideoSaveFunction(
                model,
                sets,
                path_collector,
                tag=video_tag,
                **shared_video_kwargs,
                **save_video_kwargs,
            )
            algorithm.post_train_funcs.append(video_func)

    algorithm.train()
Ejemplo n.º 10
0
def masking_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    achieved_goal_key='state_achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    train_env_id=None,
    eval_env_id=None,
    do_masking=True,
    mask_key="masked_observation",
    masking_eval_steps=200,
    log_mask_diagnostics=True,
    mask_dim=None,
    masking_reward_fn=None,
    masking_for_exploration=True,
    rotate_masks_for_eval=False,
    rotate_masks_for_expl=False,
    mask_distribution=None,
    num_steps_per_mask_change=10,
    tag=None,
):
    """Run goal-conditioned SAC where every goal is paired with a mask.

    The policy and Q-functions receive the observation concatenated with the
    desired goal and the mask. The function builds the exploration and
    evaluation contextual envs, the relabeling replay buffer, the SAC
    trainer and the path collectors, optionally registers video callbacks
    and per-mask-dimension diagnostics, and then calls ``algorithm.train()``.

    Args:
        max_path_length: Rollout horizon passed to the algorithm, the video
            rollout function and the diagnostics collectors.
        qf_kwargs, sac_trainer_kwargs, replay_buffer_kwargs, policy_kwargs,
        algo_kwargs: Extra keyword arguments forwarded to the respective
            constructors.
        env_class, env_kwargs, train_env_id, eval_env_id: Forwarded to
            ``get_gym_env`` to create the training and evaluation envs.
        observation_key, desired_goal_key, achieved_goal_key: Keys into the
            env's observation dict. ``desired_goal_key`` is also the context
            key; ``achieved_goal_key`` is used by the reward function and for
            relabeling goals from stored observations.
        exploration_policy_kwargs: Forwarded to ``create_exploration_policy``;
            ``None`` means no extra arguments.
        evaluation_goal_sampling_mode, exploration_goal_sampling_mode:
            Assigned to ``env.goal_sampling_mode`` of the respective env.
        save_video: Whether to register the video-saving callbacks.
        save_video_kwargs, renderer_kwargs: Forwarded to
            ``get_save_video_function`` / ``EnvRenderer``; falsy values are
            replaced by empty dicts.
        do_masking: When false, an all-ones static mask is used everywhere
            and the multitask env's own reward is kept.
        mask_key: Key under which the mask is stored in contexts and batches.
        masking_eval_steps: Steps collected per mask for the diagnostics.
        log_mask_diagnostics: Whether to register the per-mask diagnostics.
        mask_dim: Mask length; defaults to the size of the desired-goal space
            of the training env.
        masking_reward_fn: Reward callable used when masking; it is bound
            with ``context_key`` and ``mask_key``. Defaults to
            ``default_masked_reward_fn``.
        masking_for_exploration: When false, the exploration env uses a
            static mask instead of ``mask_distribution``.
        rotate_masks_for_eval, rotate_masks_for_expl: Use a
            ``RotatingMaskingPathCollector``; requires ``do_masking`` and
            ``mask_distribution == 'one_hot_masks'``.
        mask_distribution: One of ``'one_hot_masks'``, ``'random_bit_masks'``
            or ``'all_ones'`` (checked with an assert).
        num_steps_per_mask_change: Forwarded to the rotating collector.
        tag: Not used by this function.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'

    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    # 'all_ones' is expressed as a static mask from here on.
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        """Create a masked contextual env, its goal distribution and reward."""
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        # With masking enabled the env reward is replaced by a mask-aware one.
        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    # Evaluation always uses a static (all-ones by default) mask.
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    # A static mask distribution needs a concrete mask. Use all ones, the
    # same default the env helper above applies, so that
    # mask_distribution='all_ones' also works while do_masking is true.
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=(np.ones(mask_dim)
                     if mask_distribution == 'static_mask' else None),
    )

    # Network input: observation + goal + mask.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size +
               mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        """Relabel: the achieved goal becomes the goal; the mask is kept."""
        achieved_goal = obs_dict[achieved_goal_key]
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append goal and mask to both observation fields of the batch."""
        obs = batch['observations']
        next_obs = batch['next_observations']

        context = batch[context_key]
        mask = batch[mask_key]

        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        # The achieved goal must be stored so it can be used for relabeling.
        observation_keys=[observation_key, achieved_goal_key],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        """Return a rotating-mask collector or a plain contextual one."""
        if is_rotating:
            assert do_masking and mask_distribution == 'one_hot_masks'
            return RotatingMaskingPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
                mask_key=mask_key,
                mask_length=mask_dim,
                num_steps_per_mask_change=num_steps_per_mask_change,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )

    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)

    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Wrap the env under ``env`` so rollouts also contain images."""
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of the
    # mask.
    masks = []
    collectors = []
    for mask_idx in range(mask_dim):
        mask = np.zeros(mask_dim)
        mask[mask_idx] = 1
        masks.append(mask)

    for mask in masks:
        # Without masking every diagnostics env uses the all-ones mask.
        for_dim_mask = mask if do_masking else np.ones(mask_dim)
        masked_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=for_dim_mask)

        collector = ContextualPathCollector(
            masked_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        )
        collectors.append(collector)
    # One log prefix per one-hot mask, e.g. 'mask_010/'.
    log_prefixes = [
        'mask_{}/'.format(''.join(mask.astype(int).astype(str)))
        for mask in masks
    ]

    def get_mask_diagnostics(unused):
        """Collect rollouts for every mask and return their prefixed stats."""
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)

        for collector in collectors:
            collector.end_epoch(0)
        return log

    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    algorithm.train()