Beispiel #1
0
 def create_eval_path_collector(some_eval_env):
     """Return a ``ContextualPathCollector`` bound to ``some_eval_env``.

     ``eval_policy``, ``observation_key`` and ``context_key`` are free names
     resolved from the surrounding scope; only the environment is supplied
     by the caller.
     """
     policy_context_keys = [context_key]
     collector = ContextualPathCollector(
         some_eval_env,
         eval_policy,
         observation_key=observation_key,
         context_keys_for_policy=policy_context_keys,
     )
     return collector
Beispiel #2
0
    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        """Create the path collector for ``env`` / ``policy``.

        When the free name ``mask_conditioned`` is falsy a plain
        ``ContextualPathCollector`` is returned. Otherwise a
        ``MaskPathCollector`` is built; its rollout mask order and mask
        distribution come from ``mask_kwargs`` when the corresponding key is
        present, and from the mode-specific values of the surrounding scope
        otherwise.

        Args:
            env: Environment handed to the collector.
            policy: Policy handed to the collector.
            mode: Either ``'expl'`` or ``'eval'``.
            mask_kwargs: Optional dict that may override
                ``'rollout_mask_order'`` and/or ``'mask_distr'``.

        Raises:
            AssertionError: If ``mode`` is not ``'expl'`` or ``'eval'``.
            NotImplementedError: Same condition, when asserts are disabled
                and a mode-specific default has to be selected.
        """
        overrides = {} if mask_kwargs is None else mask_kwargs
        assert mode in ['expl', 'eval']

        # Unmasked case: nothing mode-specific to select.
        if not mask_conditioned:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )

        # An explicit override wins; otherwise use the per-mode default.
        if 'rollout_mask_order' in overrides:
            chosen_order = overrides['rollout_mask_order']
        elif mode == 'expl':
            chosen_order = rollout_mask_order_for_expl
        elif mode == 'eval':
            chosen_order = rollout_mask_order_for_eval
        else:
            raise NotImplementedError

        if 'mask_distr' in overrides:
            chosen_distr = overrides['mask_distr']
        elif mode == 'expl':
            chosen_distr = expl_mask_distr
        elif mode == 'eval':
            chosen_distr = eval_mask_distr
        else:
            raise NotImplementedError

        if mode == 'expl':
            sampler = context_distrib
        else:
            sampler = eval_context_distrib

        return MaskPathCollector(
            env,
            policy,
            observation_key=observation_key,
            context_keys_for_policy=context_keys,
            concat_context_to_obs_fn=concat_context_to_obs,
            save_env_in_snapshot=save_env_in_snapshot,
            mask_sampler=sampler,
            # The collector receives its own copy of the distribution.
            mask_distr=chosen_distr.copy(),
            mask_groups=mask_groups,
            max_path_length=max_path_length,
            rollout_mask_order=chosen_order,
            prev_subtask_weight=prev_subtask_weight,
            prev_subtasks_solved=prev_subtasks_solved,
            max_subtasks_to_focus_on=max_subtasks_to_focus_on,
            max_subtasks_per_rollout=max_subtasks_per_rollout,
        )
Beispiel #3
0
def generate_trajectories(
        snapshot_path,
        max_path_length,
        num_steps,
        save_observation_keys,
):
    """Roll out the exploration policy stored in a snapshot.

    The snapshot is loaded with ``file_type='torch'``; its exploration
    policy, environment, observation key and policy context keys are used to
    build a ``ContextualPathCollector``, which then collects new paths.

    Args:
        snapshot_path: Location handed to
            ``asset_loader.load_local_or_remote_file``.
        max_path_length: First positional argument of ``collect_new_paths``.
        num_steps: Second positional argument of ``collect_new_paths``.
        save_observation_keys: Keys to extract from every entry of
            ``full_observations`` / ``full_next_observations``.

    Returns:
        A list with one dict per collected path. Each dict holds
        ``'actions'``, ``'terminals'`` and, for every requested key ``k``,
        the arrays ``k`` and ``'next_' + k``.
    """
    ptu.set_gpu_mode(True)
    loaded = asset_loader.load_local_or_remote_file(
        snapshot_path,
        file_type='torch',
    )
    policy = loaded['exploration/policy']
    env = loaded['exploration/env']
    obs_key = loaded['exploration/observation_key']
    policy_context_keys = loaded['exploration/context_keys_for_policy']
    collector = ContextualPathCollector(
        env,
        policy,
        observation_key=obs_key,
        context_keys_for_policy=policy_context_keys,
    )
    policy.to(ptu.device)
    # Third positional argument is passed through unchanged as ``True``.
    rollouts = collector.collect_new_paths(max_path_length, num_steps, True)

    def _stack(observations, key):
        # Gather one key across a sequence of observation dicts.
        return np.array([single_obs[key] for single_obs in observations])

    trajectories = []
    for rollout in rollouts:
        record = {
            'actions': rollout['actions'],
            'terminals': rollout['terminals'],
        }
        for key in save_observation_keys:
            record[key] = _stack(rollout['full_observations'], key)
            record['next_' + key] = _stack(
                rollout['full_next_observations'], key)
        trajectories.append(record)
    return trajectories
Beispiel #4
0
 def create_path_collector(
     env,
     policy,
     mode='expl',
 ):
     """Return a ``ContextualPathCollector`` for ``env`` and ``policy``.

     ``mode`` is accepted but not used by this variant. The observation key,
     context keys and snapshot flag are free names from the surrounding
     scope.
     """
     collector_options = dict(
         observation_key=observation_key,
         context_keys_for_policy=context_keys,
         save_env_in_snapshot=save_env_in_snapshot,
     )
     return ContextualPathCollector(env, policy, **collector_options)
Beispiel #5
0
 def create_path_collector(env, policy, is_rotating):
     """Return a rotating-mask collector or a plain contextual collector.

     Both collectors are given ``[context_key, mask_key]`` as the policy's
     context keys. The rotating variant additionally requires the free names
     ``do_masking`` to be truthy and ``mask_distribution`` to equal
     ``'one_hot_masks'``.
     """
     if not is_rotating:
         return ContextualPathCollector(
             env,
             policy,
             observation_key=observation_key,
             context_keys_for_policy=[context_key, mask_key],
         )

     # Rotating masks are only allowed with one-hot masking enabled.
     assert do_masking and mask_distribution == 'one_hot_masks'
     rotating_collector = RotatingMaskingPathCollector(
         env,
         policy,
         observation_key=observation_key,
         context_keys_for_policy=[context_key, mask_key],
         mask_key=mask_key,
         mask_length=mask_dim,
         num_steps_per_mask_change=num_steps_per_mask_change,
     )
     return rotating_collector
Beispiel #6
0
def image_based_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    cnn_kwargs,
    policy_type='tanh-normal',
    env_id=None,
    env_class=None,
    env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    reward_type='state_distance',
    env_renderer_kwargs=None,
    # Data augmentations
    apply_random_crops=False,
    random_crop_pixel_shift=4,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
):
    """Build and run a goal-conditioned SAC experiment on image observations.

    Creates an exploration and an evaluation ``ContextualEnv`` (both keyed on
    ``'image_observation'``), CNN-based Q-functions and policy, a
    ``ContextualRelabelingReplayBuffer``, a ``SACTrainer`` and a
    ``TorchBatchRLAlgorithm``, optionally registers video-saving hooks, and
    finally calls ``algorithm.train()``.

    Args:
        max_path_length: Forwarded to ``TorchBatchRLAlgorithm`` and to the
            rollout function used for videos.
        qf_kwargs: Extra keyword arguments for the ``ConcatMlp`` head of each
            Q-function.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ContextualRelabelingReplayBuffer``.
        policy_kwargs: Extra keyword arguments for the policy's
            ``MultiHeadedMlp``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        cnn_kwargs: Extra keyword arguments for every ``BasicCNN``.
        policy_type: One of ``'normal'``, ``'tanh-normal'`` or
            ``'normal-tanh-mean'``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means ``{}``.
        evaluation_goal_sampling_mode: Assigned to ``goal_sampling_mode`` of
            the evaluation state environment.
        exploration_goal_sampling_mode: Assigned to ``goal_sampling_mode`` of
            the exploration state environment.
        reward_type: ``'state_distance'`` or ``'pixel_distance'``.
        env_renderer_kwargs: Keyword arguments for the ``EnvRenderer`` that
            produces observation/goal images; falsy means ``{}``.
        apply_random_crops: If true, sampled batches are padded and cropped
            before the goal image is concatenated to the observation.
        random_crop_pixel_shift: Passed twice to ``BatchPad``.
        save_video: If true, evaluation and exploration video functions are
            appended to ``algorithm.post_train_funcs``.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``; falsy means ``{}``.
        video_renderer_kwargs: Keyword arguments for the ``EnvRenderer`` that
            produces ``'video_observation'``; falsy means ``{}``.

    Raises:
        ValueError: For an unknown ``reward_type``, renderer image format or
            ``policy_type``.
    """
    # Normalise optional keyword dictionaries.
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    # Pairs each desired-goal key with its observation key; handed to the
    # replay buffer as ``sample_context_from_obs_dict_fn``.
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal':
        'image_observation',
        'state_desired_goal':
        'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                             renderer):
        """Return ``(contextual_env, goal_distribution, reward_fn)``.

        Builds a fresh state environment, derives state and image goal
        distributions from it, and wraps an image-inserting version of the
        environment in a ``ContextualEnv``.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        # NOTE(review): 5000 is presumably the number of presampled goals;
        # confirm against PresampledDistribution.
        goal_distribution = PresampledDistribution(image_goal_distribution,
                                                   5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        # The reward is computed either from state keys or from image keys.
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    # One renderer instance is shared by the exploration and evaluation envs.
    env_renderer = EnvRenderer(**env_renderer_kwargs)
    # ``expl_reward`` is unpacked but not used below; the replay buffer is
    # built with the evaluation reward function.
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)

    action_dim = expl_env.action_space.low.size
    # Unpack the image dimensions in the axis order the renderer reports.
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

    def create_qf():
        """Return a new Q-function with its own ``BasicCNN``."""
        cnn = BasicCNN(input_width=img_width,
                       input_height=img_height,
                       input_channels=img_num_channels,
                       **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn), basic.FlattenEachParallel(),
            ConcatMlp(input_size=joint_cnn.output_size + action_dim,
                      output_size=1,
                      **qf_kwargs))

    # Four independent networks; none of them shares a CNN with the policy.
    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    # The policy gets a separate CNN instance.
    cnn = BasicCNN(input_width=img_width,
                   input_height=img_height,
                   input_channels=img_num_channels,
                   **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    # All policy types use a two-headed MLP with ``action_dim`` outputs per
    # head; they differ in the distribution wrapper and, for
    # 'normal-tanh-mean', in the output activations.
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           output_activations=['tanh', 'identity'],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

    # ``concat_context_to_obs`` is the replay buffer's batch post-processor:
    # it appends the goal image to both observation entries along axis 1.
    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            """Pad, crop and concatenate goal images onto the observations.

            Mutates and returns ``batch``.
            """
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            # ``crop`` is called once per (image, goal) pair, so the goal is
            # cropped separately for the current and the next observation.
            # NOTE(review): whether the two calls use the same offsets
            # depends on JointRandomCrop - confirm.
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(next_obs_padded,
                                                  context_padded)

            batch['observations'] = np.concatenate([obs_aug, context_aug],
                                                   axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:

        def concat_context_to_obs(batch, *args, **kwargs):
            """Concatenate goal images onto the observations.

            Mutates and returns ``batch``.
            """
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
            return batch

    # The buffer is configured with the evaluation env, goal distribution
    # and reward function.
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    # Evaluation wraps the policy in MakeDeterministic; exploration uses the
    # policy returned by create_exploration_policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        # A second renderer adds a separate 'video_observation' image on top
        # of the existing contextual envs.
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(eval_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_expl_env = InsertImageEnv(expl_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        # Re-wrap so the video envs expose the same context distribution;
        # the reward function here is a constant-zero placeholder.
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        # ``image_formats`` lists one format per entry of ``keys_to_show``,
        # in the same order.
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Beispiel #7
0
def goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    """Build and run a state-based goal-conditioned SAC experiment.

    Creates exploration and evaluation ``ContextualEnv`` instances, four
    ``ConcatMlp`` Q-functions, a ``TanhGaussianPolicy``, a relabeling replay
    buffer, a ``SACTrainer`` and a ``TorchBatchRLAlgorithm``; optionally
    registers video hooks; then calls ``algorithm.train()``.

    Args:
        max_path_length: Forwarded to the algorithm and the video rollouts.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp``.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        observation_key: Observation-dict key fed to the policy.
        desired_goal_key: Observation-dict key of the goal; also used as the
            context key.
        achieved_goal_key: Key handed to the reward function.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means ``{}``.
        evaluation_goal_sampling_mode: Set on the evaluation environment.
        exploration_goal_sampling_mode: Set on the exploration environment.
        save_video: If true, register evaluation/exploration video hooks.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``; falsy means ``{}``.
        renderer_kwargs: Keyword arguments for ``EnvRenderer``; falsy means
            ``{}``.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    renderer_kwargs = renderer_kwargs or {}

    context_key = desired_goal_key
    relabel_fn = RemapKeyFn({context_key: observation_key})

    def make_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode):
        """Return ``(contextual_env, goal_distribution, reward_fn)``."""
        base_env = get_gym_env(
            env_id, env_class=env_class, env_kwargs=env_kwargs)
        base_env.goal_sampling_mode = goal_sampling_mode
        goals = GoalDictDistributionFromMultitaskEnv(
            base_env,
            desired_goal_keys=[desired_goal_key],
        )
        achieved_from_obs = IndexIntoAchievedGoal(observation_key)
        reward = ContextualRewardFnFromMultitaskEnv(
            env=base_env,
            achieved_goal_from_observation=achieved_from_obs,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
        )
        diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
            base_env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        wrapped = ContextualEnv(
            base_env,
            context_distribution=goals,
            reward_fn=reward,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
            update_env_info_fn=delete_info,
        )
        return wrapped, goals, reward

    # The exploration env is created before the evaluation env. Its reward
    # function is not used any further.
    expl_env, expl_context_distrib, _ = make_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = make_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)

    # Networks see the observation and the goal concatenated.
    obs_part = expl_env.observation_space.spaces[observation_key].low.size
    goal_part = expl_env.observation_space.spaces[context_key].low.size
    obs_dim = obs_part + goal_part
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim, output_size=1, **qf_kwargs)

    # Created strictly in this order: qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = (create_qf() for _ in range(4))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the goal to both observation entries; mutates ``batch``."""
        current = batch['observations']
        following = batch['next_observations']
        goal = batch[context_key]
        batch['observations'] = np.concatenate([current, goal], axis=1)
        batch['next_observations'] = np.concatenate(
            [following, goal], axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=relabel_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Re-wrap the env underneath ``env`` so it also yields images."""
            inner_env = env.env
            goals_with_images = AddImageDistribution(
                env=inner_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            env_with_images = InsertImageEnv(inner_env, renderer=renderer)
            # The evaluation reward function is used for both video envs.
            return ContextualEnv(
                env_with_images,
                context_distribution=goals_with_images,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        video_hooks = [
            get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=renderer.width,
                image_format=renderer.output_image_format,
                **save_video_kwargs
            ),
            get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=renderer.width,
                image_format=renderer.output_image_format,
                **save_video_kwargs
            ),
        ]
        # Evaluation hook first, exploration hook second.
        for hook in video_hooks:
            algorithm.post_train_funcs.append(hook)

    algorithm.train()
Beispiel #8
0
def sac_on_gym_goal_env_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    observation_key='observation',
    desired_goal_key='desired_goal',
    achieved_goal_key='achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    """Build and run goal-conditioned SAC on a gym goal environment.

    Only the four Fetch environments checked below are supported, because
    the success threshold is known for those alone. The function creates
    exploration and evaluation ``ContextualEnv`` instances with a
    thresholded L2-distance reward, four ``ConcatMlp`` Q-functions, a
    ``TanhGaussianPolicy``, a relabeling replay buffer, a ``SACTrainer`` and
    a ``TorchBatchRLAlgorithm``; optionally registers video hooks; then
    calls ``algorithm.train()``.

    Args:
        max_path_length: Forwarded to the algorithm and the video rollouts.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp``.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Forwarded to ``get_gym_env``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        observation_key: Observation-dict key fed to the policy.
        desired_goal_key: Observation-dict key of the goal; also used as the
            context key.
        achieved_goal_key: Observation-dict key of the achieved goal.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means ``{}``.
        evaluation_goal_sampling_mode: Set on the evaluation environment.
        exploration_goal_sampling_mode: Set on the exploration environment.
        save_video: If true, register evaluation/exploration video hooks.
        save_video_kwargs: Extra keyword arguments for
            ``get_save_video_function``; falsy means ``{}``.
        renderer_kwargs: Keyword arguments for ``GymEnvRenderer``; falsy
            means ``{}``.

    Raises:
        TypeError: If the created environment is not one of the supported
            Fetch environments.
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    save_video_kwargs = save_video_kwargs or {}
    renderer_kwargs = renderer_kwargs or {}

    context_key = desired_goal_key
    relabel_fn = RemapKeyFn({context_key: achieved_goal_key})

    def make_contextual_env(env_id, env_class, env_kwargs,
                            goal_sampling_mode):
        """Return ``(contextual_env, goal_distribution, reward_fn)``."""
        base_env = get_gym_env(
            env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        base_env.goal_sampling_mode = goal_sampling_mode
        goals = GoalDictDistributionFromGymGoalEnv(
            base_env,
            desired_goal_key=desired_goal_key,
        )
        distance = L2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                achieved_goal_key),
            desired_goal_key=desired_goal_key,
        )
        supported_envs = (
            robotics.FetchReachEnv,
            robotics.FetchPushEnv,
            robotics.FetchPickAndPlaceEnv,
            robotics.FetchSlideEnv,
        )
        if not isinstance(base_env, supported_envs):
            raise TypeError("I don't know the success threshold of env ",
                            base_env)
        success_threshold = 0.05
        reward = ThresholdDistanceReward(distance, success_threshold)
        diagnostics = GenericGoalConditionedContextualDiagnostics(
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            success_threshold=success_threshold,
        )
        wrapped = ContextualEnv(
            base_env,
            context_distribution=goals,
            reward_fn=reward,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
            update_env_info_fn=delete_info,
        )
        return wrapped, goals, reward

    # The exploration env is created before the evaluation env. Its reward
    # function is not used any further.
    expl_env, expl_context_distrib, _ = make_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = make_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)

    # Networks see the observation and the goal concatenated.
    obs_part = expl_env.observation_space.spaces[observation_key].low.size
    goal_part = expl_env.observation_space.spaces[context_key].low.size
    obs_dim = obs_part + goal_part
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim, output_size=1, **qf_kwargs)

    # Created strictly in this order: qf1, qf2, target_qf1, target_qf2.
    qf1, qf2, target_qf1, target_qf2 = (create_qf() for _ in range(4))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim, action_dim=action_dim, **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the goal to both observation entries; mutates ``batch``."""
        current = batch['observations']
        following = batch['next_observations']
        goal = batch[context_key]
        batch['observations'] = np.concatenate([current, goal], axis=1)
        batch['next_observations'] = np.concatenate(
            [following, goal], axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys_to_save=[observation_key, achieved_goal_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=relabel_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:

        # Setting the goal like this is discouraged, but the Fetch
        # environments are designed to visualize the goal by having their
        # ``goal`` attribute set.
        def set_goal_for_visualization(env, policy, o):
            target = o[desired_goal_key]
            print(target)
            env.unwrapped.goal = target

        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
            reset_callback=set_goal_for_visualization,
        )
        renderer = GymEnvRenderer(**renderer_kwargs)

        def add_images(env, context_distribution):
            """Re-wrap the env underneath ``env`` so it also yields images."""
            env_with_images = InsertImageEnv(
                env.env,
                renderer=renderer,
                image_key='image_observation',
            )
            # The evaluation reward function is used for both video envs.
            return ContextualEnv(
                env_with_images,
                context_distribution=context_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        video_hooks = [
            get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=renderer.image_chw[1],
                image_format=renderer.output_image_format,
                keys_to_show=['image_observation'],
                **save_video_kwargs
            ),
            get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=renderer.image_chw[1],
                image_format=renderer.output_image_format,
                keys_to_show=['image_observation'],
                **save_video_kwargs
            ),
        ]
        # Evaluation hook first, exploration hook second.
        for hook in video_hooks:
            algorithm.post_train_funcs.append(hook)

    algorithm.train()
Beispiel #9
0
def rig_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_goal_key='state_desired_goal',
    state_observation_key='state_observation',
    image_goal_key='image_desired_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=48,
    pretrained_vae_path="",
    init_camera=None,
):
    """Train a goal-conditioned SAC agent on VAE-encoded image observations.

    The function loads (or trains) a VAE, builds one exploration and one
    evaluation ``ContextualEnv`` on top of it, creates the SAC networks,
    replay buffer and path collectors, and finally runs
    ``TorchBatchRLAlgorithm.train()``.

    Args:
        max_path_length: Forwarded to ``TorchBatchRLAlgorithm``.
        qf_kwargs: Extra keyword arguments for each ``ConcatMlp`` Q-function.
        sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ContextualRelabelingReplayBuffer``.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        train_vae_kwargs: Passed to ``train_vae`` when no pretrained VAE path
            is given.
        env_id, env_class, env_kwargs: Forwarded to ``get_gym_env`` (and to
            ``train_vae``) to build the underlying state environments.
        observation_key: Observation-dict key used by the policy, the reward
            function and the replay buffer.
        desired_goal_key: Key of the latent goal; used as the context key.
        state_goal_key: Key of the state-space goal sampled from the env.
        state_observation_key: Key of the state observation used by the
            goal-conditioned diagnostics.
        image_goal_key: Key under which the goal image is stored.
        exploration_policy_kwargs: Extra keyword arguments for
            ``create_exploration_policy``; ``None`` means no extras.
        evaluation_goal_sampling_mode: ``"vae_prior"`` or ``"reset_of_env"``
            for the evaluation env.
        exploration_goal_sampling_mode: Same, for the exploration env.
        save_video: When true, video-saving callbacks are appended to
            ``algorithm.post_train_funcs``.
        save_video_kwargs: Extra keyword arguments for
            ``RIGVideoSaveFunction``; falsy means no extras.
        renderer_kwargs: Extra keyword arguments for ``EnvRenderer``; falsy
            means no extras.
        imsize: Image size handed to ``train_vae`` and to the video
            callbacks.
        pretrained_vae_path: If non-empty, the VAE is loaded from this path
            instead of being trained.
        init_camera: Forwarded to ``EnvRenderer`` and ``train_vae``.

    Raises:
        NotImplementedError: If either goal sampling mode is not one of the
            two supported strings (this includes the default ``None``).
    """
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    # Only used below to look up the output image format for the videos.
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        """Build (env, latent goal distribution, reward fn) for one mode.

        ``model`` is resolved from the enclosing scope at call time, so this
        helper must only be called after the VAE has been loaded or trained.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        # Each contextual env gets its own renderer instance; a distinct name
        # avoids shadowing the outer ``renderer``.
        env_renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=env_renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            # A second env instance is used purely as a source of goals.
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=env_renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn

    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)
    context_key = desired_goal_key

    # The networks consume the observation concatenated with the goal context.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Return a fresh Q-network over (observation + context, action)."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the context to (next) observations of ``batch`` in place."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=RemapKeyFn(
            {context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        # Use the caller-supplied image size (default 48) rather than a
        # hard-coded constant so the videos match the VAE's image size.
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            goal_image_key=image_goal_key,
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
Beispiel #10
0
def encoder_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    policy_kwargs,
    # Encoder parameters
    disentangled_qf_kwargs,
    use_parallel_qf=False,
    encoder_kwargs=None,
    encoder_cnn_kwargs=None,
    qf_state_encoder_is_goal_encoder=False,
    reward_type='encoder_distance',
    reward_config=None,
    latent_dim=8,
    # Policy params
    policy_using_encoder_settings=None,
    use_separate_encoder_for_policy=True,
    # Env settings
    env_id=None,
    env_class=None,
    env_kwargs=None,
    contextual_env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    num_presampled_goals=5000,
    # Image env parameters
    use_image_observations=False,
    env_renderer_kwargs=None,
    # Video parameters
    save_video=True,
    save_debug_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    # Debugging parameters
    visualize_representation=True,
    distance_scatterplot_save_period=0,
    distance_scatterplot_initial_save_period=0,
    debug_renderer_kwargs=None,
    debug_visualization_kwargs=None,
    use_debug_trainer=False,

    # vae stuff
    train_encoder_as_vae=False,
    vae_trainer_kwargs=None,
    decoder_kwargs=None,
    encoder_backprop_settings="rl_n_vae",
    mlp_for_image_decoder=False,
    pretrain_vae=False,
    pretrain_vae_kwargs=None,

    # Should only be set by transfer_encoder_goal_conditioned_sac_experiment
    is_retraining_from_scratch=False,
    train_results=None,
    rl_csv_fname='progress.csv',
):
    if reward_config is None:
        reward_config = {}
    if encoder_cnn_kwargs is None:
        encoder_cnn_kwargs = {}
    if policy_using_encoder_settings is None:
        policy_using_encoder_settings = {}
    if debug_visualization_kwargs is None:
        debug_visualization_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if contextual_env_kwargs is None:
        contextual_env_kwargs = {}
    if encoder_kwargs is None:
        encoder_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if video_renderer_kwargs is None:
        video_renderer_kwargs = {}
    if debug_renderer_kwargs is None:
        debug_renderer_kwargs = {}

    assert (is_retraining_from_scratch != (train_results is None))

    img_observation_key = 'image_observation'
    state_observation_key = 'state_observation'
    latent_desired_goal_key = 'latent_desired_goal'
    state_desired_goal_key = 'state_desired_goal'
    img_desired_goal_key = 'image_desired_goal'

    backprop_rl_into_encoder = False
    backprop_vae_into_encoder = False
    if 'rl' in encoder_backprop_settings:
        backprop_rl_into_encoder = True
    if 'vae' in encoder_backprop_settings:
        backprop_vae_into_encoder = True

    if use_image_observations:
        env_renderer = EnvRenderer(**env_renderer_kwargs)

    def setup_env(state_env, encoder, reward_fn):
        """Wrap ``state_env`` into a ``ContextualEnv`` with encoded goals.

        Goals are drawn from ``state_env`` via
        ``GoalDictDistributionFromMultitaskEnv``. With image observations the
        distribution is additionally wrapped in ``AddImageDistribution`` and
        ``PresampledDistribution`` and the env in ``InsertImageEnv``. In both
        cases the goals are finally passed through
        ``EncodedGoalDictDistribution`` so the latent goal key is present.

        Returns:
            A ``(env, goal_distribution)`` tuple.
        """
        goals = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            goals = AddImageDistribution(
                env=state_env,
                base_distribution=goals,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goals = PresampledDistribution(goals, num_presampled_goals)
            kept_goal_keys = [state_desired_goal_key, img_desired_goal_key]
            encoder_input_key = img_desired_goal_key
        else:
            base_env = state_env
            kept_goal_keys = [state_desired_goal_key]
            encoder_input_key = state_desired_goal_key

        # Both modes finish with the same encoding wrapper; only the key that
        # is fed to the encoder and the keys that are kept differ.
        goals = EncodedGoalDictDistribution(
            goals,
            encoder=encoder,
            keys_to_keep=kept_goal_keys,
            encoder_input_key=encoder_input_key,
            encoder_output_key=latent_desired_goal_key,
        )

        diagnostics_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        contextual_env = ContextualEnv(
            base_env,
            context_distribution=goals,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[diagnostics_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return contextual_env, goals

    state_expl_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_expl_env.goal_sampling_mode = exploration_goal_sampling_mode
    state_eval_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    if use_image_observations:
        context_keys_to_save = [
            state_desired_goal_key,
            img_desired_goal_key,
            latent_desired_goal_key,
        ]
        context_key_for_rl = img_desired_goal_key
        observation_key_for_rl = img_observation_key

        def create_encoder():
            """Build the image encoder: ``BasicCNN`` -> ``Flatten`` -> MLP.

            The MLP has two heads, each of size ``latent_dim``. The returned
            ``nn.Sequential`` is tagged with ``input_size`` (flattened image
            size) and ``output_size`` (``latent_dim``).
            """
            channels, height, width = env_renderer.image_chw
            conv_net = BasicCNN(input_width=width,
                                input_height=height,
                                input_channels=channels,
                                **encoder_cnn_kwargs)
            head_mlp = MultiHeadedMlp(
                input_size=np.prod(conv_net.output_shape),
                output_sizes=[latent_dim, latent_dim],
                **encoder_kwargs)
            encoder_module = nn.Sequential(conv_net, Flatten(), head_mlp)
            encoder_module.input_size = width * height * channels
            encoder_module.output_size = latent_dim
            return encoder_module
    else:
        context_keys_to_save = [
            state_desired_goal_key, latent_desired_goal_key
        ]
        context_key_for_rl = state_desired_goal_key
        observation_key_for_rl = state_observation_key

        def create_encoder():
            """Build the state encoder, a two-headed ``ConcatMultiHeadedMlp``.

            The input size is the size of the exploration env's state
            observation space; both heads have size ``latent_dim``. The
            module is tagged with ``input_size`` and ``output_size``.
            """
            state_space = state_expl_env.observation_space.spaces[
                state_observation_key]
            state_dim = state_space.low.size
            encoder_module = ConcatMultiHeadedMlp(
                input_size=state_dim,
                output_sizes=[latent_dim, latent_dim],
                **encoder_kwargs)
            encoder_module.input_size = state_dim
            encoder_module.output_size = latent_dim
            return encoder_module

    encoder_net = create_encoder()
    if train_results:
        print("Using transfer encoder")
        encoder_net = train_results.encoder_net
    mu_encoder_net = EncoderMuFromEncoderDistribution(encoder_net)
    target_encoder_net = create_encoder()
    mu_target_encoder_net = EncoderMuFromEncoderDistribution(
        target_encoder_net)
    encoder_input_dim = encoder_net.input_size

    encoder = EncoderFromNetwork(mu_encoder_net)
    encoder.to(ptu.device)
    if reward_type == 'encoder_distance':
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'target_encoder_distance':
        target_encoder = EncoderFromNetwork(mu_target_encoder_net)
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=target_encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_expl_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key='state_achieved_goal',
        )
    else:
        raise ValueError("invalid reward type {}".format(reward_type))
    expl_env, expl_context_distrib = setup_env(state_expl_env, encoder,
                                               reward_fn)
    eval_env, eval_context_distrib = setup_env(state_eval_env, encoder,
                                               reward_fn)

    action_dim = expl_env.action_space.low.size

    def make_qf(goal_encoder):
        """Build a disentangled Q-function that uses ``goal_encoder``.

        The state encoder is always the goal encoder; a separate state
        encoder is not supported. When RL gradients must not reach the
        encoder, the encoder is wrapped in ``Detach`` first.

        Args:
            goal_encoder: Encoder module shared by the goal and state paths.

        Returns:
            A ``ParallelDisentangledMlpQf`` when ``use_parallel_qf`` is set,
            otherwise a ``DisentangledMlpQf``.

        Raises:
            RuntimeError: If ``qf_state_encoder_is_goal_encoder`` is false.
        """
        if not qf_state_encoder_is_goal_encoder:
            # Fail up front instead of first building an encoder network
            # that would be discarded by the raise.
            raise RuntimeError(
                "State encoder must be goal encoder for resuming exps")
        if not backprop_rl_into_encoder:
            goal_encoder = Detach(goal_encoder)
        state_encoder = goal_encoder
        if use_parallel_qf:
            return ParallelDisentangledMlpQf(
                goal_encoder=goal_encoder,
                state_encoder=state_encoder,
                preprocess_obs_dim=encoder_input_dim,
                action_dim=action_dim,
                post_encoder_mlp_kwargs=qf_kwargs,
                vectorized=True,
                **disentangled_qf_kwargs)
        return DisentangledMlpQf(goal_encoder=goal_encoder,
                                 state_encoder=state_encoder,
                                 preprocess_obs_dim=encoder_input_dim,
                                 action_dim=action_dim,
                                 qf_kwargs=qf_kwargs,
                                 vectorized=True,
                                 **disentangled_qf_kwargs)

    qf1 = make_qf(mu_encoder_net)
    qf2 = make_qf(mu_encoder_net)
    target_qf1 = make_qf(mu_target_encoder_net)
    target_qf2 = make_qf(mu_target_encoder_net)

    if use_separate_encoder_for_policy:
        policy_encoder = EncoderMuFromEncoderDistribution(create_encoder())
        policy_encoder_net = EncodeObsAndGoal(
            policy_encoder,
            encoder_input_dim,
            encode_state=True,
            encode_goal=True,
            detach_encoder_via_goal=False,
            detach_encoder_via_state=False,
        )
    else:
        policy_obs_encoder = mu_encoder_net
        if not backprop_rl_into_encoder:
            policy_obs_encoder = Detach(policy_obs_encoder)
        policy_encoder_net = EncodeObsAndGoal(policy_obs_encoder,
                                              encoder_input_dim,
                                              **policy_using_encoder_settings)
    obs_processor = nn.Sequential(
        policy_encoder_net, ConcatTuple(),
        MultiHeadedMlp(input_size=policy_encoder_net.output_size,
                       output_sizes=[action_dim, action_dim],
                       **policy_kwargs))
    policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the RL context to the observations of ``batch`` in place.

        ``observations`` and ``next_observations`` are replaced by their
        concatenation with ``batch[context_key_for_rl]`` along axis 1; the
        un-concatenated next observations are kept under
        ``raw_next_observations``. Extra positional and keyword arguments are
        ignored. Returns the same ``batch`` object.
        """
        current_obs = batch['observations']
        raw_next_obs = batch['next_observations']
        goal_context = batch[context_key_for_rl]
        batch.update({
            'observations': np.concatenate(
                [current_obs, goal_context], axis=1),
            'next_observations': np.concatenate(
                [raw_next_obs, goal_context], axis=1),
            'raw_next_observations': raw_next_obs,
        })
        return batch

    if use_image_observations:
        # Do this so that the context has all two/three: the state, image, and
        # encoded goal
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
                img_desired_goal_key: img_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [
            state_observation_key, img_observation_key
        ]
    else:
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [state_observation_key]

    encoder_output_dim = mu_encoder_net.output_size
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys_to_save,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_observation,
        observation_keys=ob_keys_to_save_in_buffer,
        observation_key=observation_key_for_rl,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        reward_dim=encoder_output_dim,
        **replay_buffer_kwargs)

    disentangled_trainer = DisentangedTrainer(env=expl_env,
                                              policy=policy,
                                              qf1=qf1,
                                              qf2=qf2,
                                              target_qf1=target_qf1,
                                              target_qf2=target_qf2,
                                              **sac_trainer_kwargs)

    if train_encoder_as_vae:
        assert backprop_vae_into_encoder, \
            "No point in training the vae if not backpropagating into encoder"
        if vae_trainer_kwargs is None:
            vae_trainer_kwargs = {}
        if decoder_kwargs is None:
            decoder_kwargs = invert_encoder_mlp_params(encoder_kwargs)

        # VAE training
        def make_decoder():
            """Build the VAE decoder that matches the configured encoder.

            * State observations: a plain ``Mlp`` from ``latent_dim`` to
              ``encoder_input_dim``.
            * Image observations with ``mlp_for_image_decoder``: an ``Mlp``
              followed by ``Reshape`` to the renderer's image shape.
            * Image observations otherwise: an ``Mlp`` followed by a
              ``BasicDCNN`` whose input shape is the output shape of the
              encoder's first sub-module.
            """
            if not use_image_observations:
                return Mlp(input_size=latent_dim,
                           output_size=encoder_input_dim,
                           **decoder_kwargs)

            channels, height, width = env_renderer.image_chw

            if mlp_for_image_decoder:
                flat_image_size = channels * height * width
                decoder = nn.Sequential(
                    Mlp(input_size=latent_dim,
                        output_size=flat_image_size,
                        **decoder_kwargs),
                    Reshape(channels, height, width))
                decoder.input_size = latent_dim
                decoder.output_size = flat_image_size
                return decoder

            # Mirror the encoder CNN: start from the shape it produces.
            feat_channels, feat_height, feat_width = (
                encoder_net._modules['0'].output_shape)
            flat_feature_size = feat_channels * feat_width * feat_height
            deconv_kwargs = invert_encoder_params(
                encoder_cnn_kwargs,
                channels,
            )
            deconv_net = BasicDCNN(input_width=feat_width,
                                   input_height=feat_height,
                                   input_channels=feat_channels,
                                   **deconv_kwargs)
            latent_to_features = Mlp(input_size=latent_dim,
                                     output_size=flat_feature_size,
                                     **decoder_kwargs)
            decoder = nn.Sequential(latent_to_features, deconv_net)
            decoder.input_size = latent_dim
            return decoder

        decoder_net = make_decoder()
        vae = VAE(encoder_net, decoder_net)

        vae_trainer = VAETrainer(vae=vae, **vae_trainer_kwargs)

        if pretrain_vae:
            goal_key = (img_desired_goal_key
                        if use_image_observations else state_desired_goal_key)
            vae.to(ptu.device)
            train_ae(vae_trainer,
                     expl_context_distrib,
                     **pretrain_vae_kwargs,
                     goal_key=goal_key,
                     rl_csv_fname=rl_csv_fname)
        trainers = OrderedDict()
        trainers['vae_trainer'] = vae_trainer
        trainers['disentangled_trainer'] = disentangled_trainer
        trainer = JointTrainer(trainers)
    else:
        trainer = disentangled_trainer

    if not use_image_observations and use_debug_trainer:
        # TODO: implement this for images
        debug_trainer = DebugTrainer(
            observation_space=expl_env.observation_space.
            spaces[state_observation_key],
            encoder=mu_encoder_net,
            encoder_output_dim=encoder_output_dim,
        )
        trainer = JointTrainer([trainer, debug_trainer])

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )
    exploration_policy = create_exploration_policy(policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    video_renderer = EnvRenderer(**video_renderer_kwargs)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key_for_rl,
            context_keys_for_policy=[context_key_for_rl],
        )
        if save_debug_video and not use_image_observations:
            # TODO: add visualization for image-based envs
            obj1_sweep_renderers = {
                'sweep_obj1_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }
            obj0_sweep_renderers = {
                'sweep_obj0_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }

            debugger_one = DebugEnvRenderer(encoder, 0,
                                            **debug_renderer_kwargs)

            low = eval_env.env.observation_space[
                state_observation_key].low.min()
            high = eval_env.env.observation_space[
                state_observation_key].high.max()
            y = np.linspace(low, high, num=debugger_one.image_shape[0])
            x = np.linspace(low, high, num=debugger_one.image_shape[1])
            cross = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])

            def create_shared_data_creator(obj_index):
                """Return a function that encodes a grid sweep over one slot.

                The returned ``compute_shared_data(raw_obs, env)`` reads the
                RL observation from ``raw_obs``, splits it into its first two
                entries and the remainder, replaces one of the two parts with
                every row of ``cross`` and returns ``encoder.encode`` of the
                resulting batch. ``obj_index`` 0 keeps the first part and
                sweeps the rest; 1 sweeps the first part and keeps the rest.
                Any other value raises ``ValueError`` when the returned
                function is called. ``env`` is not used.
                """
                def compute_shared_data(raw_obs, env):
                    state = raw_obs[observation_key_for_rl]
                    head_part, tail_part = state[:2], state[2:]
                    num_grid_points = cross.shape[0]
                    if obj_index == 0:
                        columns = [
                            np.repeat(head_part[None, :], num_grid_points,
                                      axis=0),
                            cross,
                        ]
                    elif obj_index == 1:
                        columns = [
                            cross,
                            np.repeat(tail_part[None, :], num_grid_points,
                                      axis=0),
                        ]
                    else:
                        raise ValueError(obj_index)
                    swept_states = np.concatenate(columns, axis=1)
                    return encoder.encode(swept_states)

                return compute_shared_data

            obj0_sweeper = create_shared_data_creator(0)
            obj1_sweeper = create_shared_data_creator(1)

            def add_images(env, base_distribution):
                """Build the contextual env used for debug-video rollouts.

                Without image observations, ``env.env`` is wrapped in
                ``InsertImageEnv`` and the goal distribution in
                ``AddImageDistribution`` (both with ``video_renderer``);
                otherwise ``env`` and ``base_distribution`` are used as given.
                The env is then wrapped in two ``InsertDebugImagesEnv`` layers
                (obj1 sweep first, obj0 sweep second) and a ``ContextualEnv``.
                """
                if not use_image_observations:
                    inner_state_env = env.env
                    goal_distribution = AddImageDistribution(
                        env=inner_state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                    debug_env = InsertImageEnv(inner_state_env,
                                               renderer=video_renderer)
                else:
                    debug_env = env
                    goal_distribution = base_distribution

                # Wrapping order matters: obj1 is the inner layer, obj0 the
                # outer one.
                sweep_layers = (
                    (obj1_sweep_renderers, obj1_sweeper),
                    (obj0_sweep_renderers, obj0_sweeper),
                )
                for sweep_renderers, sweeper in sweep_layers:
                    debug_env = InsertDebugImagesEnv(
                        debug_env,
                        sweep_renderers,
                        compute_shared_data=sweeper,
                    )
                return ContextualEnv(
                    debug_env,
                    context_distribution=goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            def get_extra_imgs(
                path,
                index_in_path,
                env,
            ):
                """Collect the debug sweep images of one step of ``path``.

                Returns the entries of
                ``path['full_observations'][index_in_path]`` for every key of
                ``obj1_sweep_renderers`` followed by every key of
                ``obj0_sweep_renderers``. ``env`` is not used.
                """
                sweep_keys = [*obj1_sweep_renderers, *obj0_sweep_renderers]
                return [
                    path['full_observations'][index_in_path][sweep_key]
                    for sweep_key in sweep_keys
                ]

            img_formats = [video_renderer.output_image_format]
            for r in obj1_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            for r in obj0_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)

            def add_images(env, base_distribution):
                if use_image_observations:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='video_observation',
                    )
                    image_goal_distribution = base_distribution
                else:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='image_observation',
                    )
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                return ContextualEnv(
                    video_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            if use_image_observations:
                keys_to_show = [
                    'image_desired_goal',
                    'image_observation',
                    'video_observation',
                ]
                image_formats = [
                    env_renderer.output_image_format,
                    env_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            else:
                keys_to_show = ['image_desired_goal', 'image_observation']
                image_formats = [
                    video_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    if visualize_representation:
        from multiworld.envs.pygame import PickAndPlaceEnv
        if not isinstance(state_eval_env, PickAndPlaceEnv):
            raise NotImplementedError()
        num_objects = (state_eval_env.observation_space[state_observation_key].
                       low.size) // 2
        state_space = state_eval_env.observation_space['state_observation']
        start_states = np.vstack([state_space.sample() for _ in range(6)])
        if use_image_observations:

            def state_to_encoder_input(state):
                goal_dict = {
                    'state_desired_goal': state,
                }
                env_state = state_eval_env.get_env_state()
                state_eval_env.set_to_goal(goal_dict)
                start_img = env_renderer(state_eval_env)
                state_eval_env.set_env_state(env_state)
                return start_img

            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder,
                    i,
                    eval_env,
                    video_renderer,
                    state_to_encoder_input=state_to_encoder_input,
                    env_renderer=env_renderer,
                    start_states=start_states,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)
        else:
            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder, i, eval_env, video_renderer,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)

    if distance_scatterplot_save_period > 0:
        algorithm.post_train_funcs.append(
            create_save_h_vs_state_distance_fn(
                distance_scatterplot_save_period,
                distance_scatterplot_initial_save_period,
                encoder,
                observation_key_for_rl,
            ))
    algorithm.train()
    train_results = dict(encoder_net=encoder_net, )
    TrainResults = namedtuple('TrainResults', sorted(train_results))
    return TrainResults(**train_results)
Beispiel #11
0
def probabilistic_goal_reaching_experiment(
    max_path_length,
    qf_kwargs,
    policy_kwargs,
    pgr_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    env_id,
    discount_factor,
    reward_type,
    # Dynamics model
    dynamics_model_version,
    dynamics_model_config,
    dynamics_delta_model_config=None,
    dynamics_adam_config=None,
    dynamics_ensemble_kwargs=None,
    # Discount model
    learn_discount_model=False,
    discount_adam_config=None,
    discount_model_config=None,
    prior_discount_weight_schedule_kwargs=None,
    # Environment
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    exploration_policy_kwargs=None,
    action_noise_scale=0.,
    num_presampled_goals=4096,
    success_threshold=0.05,
    # Video / visualization parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    plot_renderer_kwargs=None,
    eval_env_ids=None,
    # Debugging params
    visualize_dynamics=False,
    visualize_discount_model=False,
    visualize_all_plots=False,
    plot_discount=False,
    plot_reward=False,
    plot_bootstrap_value=False,
    # env specific-params
    normalize_distances_for_full_state_ant=False,
):
    """Assemble and run a probabilistic goal-reaching (PGR) experiment.

    Builds the exploration/evaluation contextual environments, a goal
    dynamics model, twin Q-functions with targets, a tanh-Gaussian policy,
    a relabeling replay buffer, the joint trainer and path collectors, then
    optionally registers video-saving callbacks and calls
    ``algorithm.train()``.

    Args:
        max_path_length: Passed to the RL algorithm and to the rollout
            function used for videos.
        qf_kwargs: Extra keyword arguments for every ``ConcatMlp`` Q-function.
        policy_kwargs: Extra keyword arguments for the policy's
            ``MultiHeadedMlp``.
        pgr_trainer_kwargs: Extra keyword arguments for ``PGRTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ContextualRelabelingReplayBuffer``.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Environment id forwarded to ``get_gym_env``; also the default
            evaluation environment when ``eval_env_ids`` is not given.
        discount_factor: Passed to ``PGRTrainer`` (as ``discount``) and to
            ``ProbabilisticGoalRewardFn``.
        reward_type: ``'sparse'`` selects a thresholded L2 reward,
            ``'negative_distance'`` a negative L2 reward; any other value is
            forwarded to ``ProbabilisticGoalRewardFn``.
        dynamics_model_version: Name of the dynamics model variant. It is
            forwarded to ``create_goal_dynamics_model`` and also decides
            whether an Adam optimizer is created for the model.
        dynamics_model_config: Forwarded to ``create_goal_dynamics_model``.
        dynamics_delta_model_config: Forwarded to
            ``create_goal_dynamics_model``; defaults to ``{}``.
        dynamics_adam_config: Keyword arguments for the dynamics model's
            Adam optimizer; defaults to ``{}``.
        dynamics_ensemble_kwargs: Forwarded as ``ensemble_model_kwargs``;
            defaults to ``{}``.
        learn_discount_model: When true, a discount model and its trainer are
            created and handed to ``PGRTrainer``.
        discount_adam_config: Keyword arguments for the discount model's Adam
            optimizer; defaults to ``{}``.
        discount_model_config: Forwarded to ``create_discount_model`` as
            ``model_kwargs``; defaults to ``{}``.
        prior_discount_weight_schedule_kwargs: When not ``None``, forwarded to
            ``create_schedule`` and the result is given to ``PGRTrainer``.
        env_class: Forwarded to ``get_gym_env``.
        env_kwargs: Forwarded to ``get_gym_env``.
        observation_key: Observation-dict key used as the policy/Q input
            state.
        desired_goal_key: Observation-dict key of the goal; it is also used
            as the context key.
        exploration_policy_kwargs: Forwarded to
            ``create_exploration_policy``; defaults to ``{}``.
        action_noise_scale: Noise scale given to the ``NoisyAction`` wrapper.
        num_presampled_goals: Number of goals given to
            ``PresampledDistribution``.
        success_threshold: Threshold used by the sparse reward and by the
            diagnostics.
        save_video: When true, video-saving callbacks are appended to
            ``algorithm.post_train_funcs``.
        save_video_kwargs: Forwarded to ``get_save_video_function``.
        video_renderer_kwargs: Keyword arguments for the video and value
            renderers.
        plot_renderer_kwargs: Keyword arguments for the scalar plot
            renderers; when empty, a copy of ``video_renderer_kwargs`` with
            ``dpi=48`` is used.
        eval_env_ids: Mapping from evaluation name to environment id;
            defaults to ``{'eval': env_id}``.
        visualize_dynamics: Add dynamics-model log-probability renderers to
            the videos.
        visualize_discount_model: Add a discount-model renderer to the videos
            (only when a discount model exists).
        visualize_all_plots: Enable the discount, reward and bootstrap-value
            plots at once.
        plot_discount: Enable the discount plot.
        plot_reward: Enable the reward plot.
        plot_bootstrap_value: Enable the bootstrap-value plot.
        normalize_distances_for_full_state_ant: For
            ``AntFullPositionGoalEnv`` only, wrap the environment in
            ``NormalizeAntFullPositionGoalEnv`` and rescale goals for
            visualization.

    Raises:
        NotImplementedError: If ``dynamics_model_version`` is not one of the
            versions handled when creating the dynamics trainer.
    """
    # Replace missing/empty configuration containers with fresh ones so that
    # nothing below has to special-case None.
    if dynamics_ensemble_kwargs is None:
        dynamics_ensemble_kwargs = {}
    if eval_env_ids is None:
        eval_env_ids = {'eval': env_id}
    if discount_model_config is None:
        discount_model_config = {}
    if dynamics_delta_model_config is None:
        dynamics_delta_model_config = {}
    if dynamics_adam_config is None:
        dynamics_adam_config = {}
    if discount_adam_config is None:
        discount_adam_config = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    if not plot_renderer_kwargs:
        plot_renderer_kwargs = video_renderer_kwargs.copy()
        plot_renderer_kwargs['dpi'] = 48
    context_key = desired_goal_key

    # A throw-away environment used only to inspect types and spaces.
    stub_env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    is_gym_env = (
        isinstance(stub_env, FetchEnv) or isinstance(stub_env, AntXYGoalEnv)
        or isinstance(stub_env, AntFullPositionGoalEnv)
        # or isinstance(stub_env, HopperFullPositionGoalEnv)
    )
    is_ant_full_pos = isinstance(stub_env, AntFullPositionGoalEnv)

    if is_gym_env or isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    else:
        achieved_goal_key = observation_key
        ob_keys_to_save_in_buffer = [observation_key]
    # TODO move all env-specific code to other file
    if isinstance(stub_env, SawyerDoorHookEnv):
        init_camera = sawyer_door_env_camera_v0
    elif isinstance(stub_env, SawyerPushAndReachXYEnv):
        init_camera = sawyer_init_camera_zoomed_in
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        init_camera = sawyer_pick_and_place_camera
    else:
        init_camera = None

    full_ob_space = stub_env.observation_space
    action_space = stub_env.action_space
    state_to_goal = StateToGoalFn(stub_env)
    dynamics_model = create_goal_dynamics_model(
        full_ob_space[observation_key],
        action_space,
        full_ob_space[achieved_goal_key],
        dynamics_model_version,
        state_to_goal,
        dynamics_model_config,
        dynamics_delta_model_config,
        ensemble_model_kwargs=dynamics_ensemble_kwargs,
    )
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_reward(_env_id,
                                      _env_class=None,
                                      _env_kwargs=None):
        """Return ``(contextual_env, goal_distribution, reward_fn)``.

        ``_env_class`` / ``_env_kwargs`` override the experiment-level
        ``env_class`` / ``env_kwargs``; when left as ``None`` the
        experiment-level values are used.
        """
        if _env_class is None:
            _env_class = env_class
        if _env_kwargs is None:
            _env_kwargs = env_kwargs
        base_env = get_gym_env(
            _env_id,
            env_class=_env_class,
            env_kwargs=_env_kwargs,
            unwrap_timed_envs=True,
        )
        if init_camera:
            base_env.initialize_camera(init_camera)
        if is_ant_full_pos and normalize_distances_for_full_state_ant:
            base_env = NormalizeAntFullPositionGoalEnv(base_env)
            normalize_env = base_env
        else:
            normalize_env = None
        env = NoisyAction(base_env, action_noise_scale)
        diag_fns = []
        if is_gym_env:
            goal_distribution = GoalDictDistributionFromGymGoalEnv(
                env,
                desired_goal_key=desired_goal_key,
            )
            diag_fns.append(
                GenericGoalConditionedContextualDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                ))
        else:
            goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            diag_fns.append(
                GoalConditionedDiagnosticsToContextualDiagnostics(
                    env.goal_conditioned_diagnostics,
                    desired_goal_key=desired_goal_key,
                    observation_key=observation_key,
                ))
        if is_ant_full_pos:
            diag_fns.append(
                AntFullPositionGoalEnvDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                    normalize_env=normalize_env,
                ))
        # if isinstance(stub_env, HopperFullPositionGoalEnv):
        #     diag_fns.append(
        #         HopperFullPositionGoalEnvDiagnostics(
        #             desired_goal_key=desired_goal_key,
        #             achieved_goal_key=achieved_goal_key,
        #             success_threshold=success_threshold,
        #         )
        #     )
        achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
        if reward_type == 'sparse':
            distance_fn = L2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
            reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        elif reward_type == 'negative_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
        else:
            reward_fn = ProbabilisticGoalRewardFn(
                dynamics_model,
                state_key=observation_key,
                context_key=context_key,
                reward_type=reward_type,
                discount_factor=discount_factor,
            )
        goal_distribution = PresampledDistribution(goal_distribution,
                                                   num_presampled_goals)
        final_env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=diag_fns,
            update_env_info_fn=delete_info,
        )
        return final_env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, reward_fn = contextual_env_distrib_reward(
        env_id,
        env_class,
        env_kwargs,
    )
    # Networks consume the observation concatenated with the goal context.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """Build one Q-network over (observation + context, action)."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    def create_policy():
        """Build the tanh-Gaussian policy over observation + context."""
        obs_processor = MultiHeadedMlp(input_size=obs_dim,
                                       output_sizes=[action_dim, action_dim],
                                       **policy_kwargs)
        return PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    policy = create_policy()

    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the context to (next) observations, keeping the originals.

        The un-concatenated arrays stay available under the ``original_*``
        keys, which the dynamics and discount trainers read.
        """
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch['original_observations'] = obs
        batch['original_next_observations'] = next_obs
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=expl_env,
        context_keys=[context_key],
        observation_keys=ob_keys_to_save_in_buffer,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    def create_trainer():
        """Return ``(joint_trainer, pgr_trainer)``.

        The joint trainer combines the optional discount-model trainer, the
        PGR trainer and the dynamics-model trainer, in that insertion order.
        """
        trainers = OrderedDict()
        if learn_discount_model:
            discount_model = create_discount_model(
                ob_space=stub_env.observation_space[observation_key],
                goal_space=stub_env.observation_space[context_key],
                action_space=stub_env.action_space,
                model_kwargs=discount_model_config)
            optimizer = optim.Adam(discount_model.parameters(),
                                   **discount_adam_config)
            discount_trainer = DiscountModelTrainer(
                discount_model,
                optimizer,
                observation_key='observations',
                next_observation_key='original_next_observations',
                goal_key=context_key,
                state_to_goal_fn=state_to_goal,
            )
            trainers['discount_trainer'] = discount_trainer
        else:
            discount_model = None
        if prior_discount_weight_schedule_kwargs is not None:
            schedule = create_schedule(**prior_discount_weight_schedule_kwargs)
        else:
            schedule = None
        pgr_trainer = PGRTrainer(env=expl_env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 discount=discount_factor,
                                 discount_model=discount_model,
                                 prior_discount_weight_schedule=schedule,
                                 **pgr_trainer_kwargs)
        trainers[''] = pgr_trainer
        # Only the learned model variants get an optimizer; the fixed
        # variants are passed to the trainer with no optimizer.
        if dynamics_model_version in {
                'learned_model',
                'learned_model_ensemble',
                'learned_model_laplace',
                'learned_model_laplace_global_variance',
                'learned_model_gaussian_global_variance',
        }:
            model_opt = optim.Adam(dynamics_model.parameters(),
                                   **dynamics_adam_config)
        elif dynamics_model_version in {
                'fixed_standard_laplace',
                'fixed_standard_gaussian',
        }:
            model_opt = None
        else:
            raise NotImplementedError()
        model_trainer = GenerativeGoalDynamicsModelTrainer(
            dynamics_model,
            model_opt,
            state_to_goal=state_to_goal,
            observation_key='original_observations',
            next_observation_key='original_next_observations',
        )
        trainers['dynamics_trainer'] = model_trainer
        return JointTrainer(trainers), pgr_trainer

    trainer, pgr_trainer = create_trainer()

    eval_policy = MakeDeterministic(policy)

    def create_eval_path_collector(some_eval_env):
        """Build an evaluation path collector for ``some_eval_env``."""
        return ContextualPathCollector(
            some_eval_env,
            eval_policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )

    path_collectors = dict()
    eval_env_name_to_env_and_context_distrib = dict()
    for name, extra_env_id in eval_env_ids.items():
        env, context_distrib, _ = contextual_env_distrib_reward(extra_env_id)
        path_collectors[name] = create_eval_path_collector(env)
        eval_env_name_to_env_and_context_distrib[name] = (env, context_distrib)
    eval_path_collector = JointPathCollector(path_collectors)
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    def get_eval_diagnostics(key_to_paths):
        """Collect per-environment diagnostics, prefixed by the env name."""
        stats = OrderedDict()
        for eval_env_name, paths in key_to_paths.items():
            env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
            stats.update(
                add_prefix(
                    env.get_diagnostics(paths),
                    eval_env_name,
                    divider='/',
                ))
            stats.update(
                add_prefix(
                    eval_util.get_generic_path_information(paths),
                    eval_env_name,
                    divider='/',
                ))
        return stats

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        evaluation_get_diagnostic_functions=[get_eval_diagnostics],
        **algo_kwargs)
    algorithm.to(ptu.device)

    if normalize_distances_for_full_state_ant and is_ant_full_pos:
        qpos_weights = expl_env.unwrapped.presampled_qpos.std(axis=0)
    else:
        qpos_weights = None

    if save_video:
        if is_gym_env:
            video_renderer = GymEnvRenderer(**video_renderer_kwargs)

            def set_goal_for_visualization(env, policy, o):
                """Copy the sampled goal into the raw env so it is drawn."""
                goal = o[desired_goal_key]
                if normalize_distances_for_full_state_ant and is_ant_full_pos:
                    # Undo the normalization before handing the goal to the
                    # unwrapped environment.
                    unnormalized_goal = goal * qpos_weights
                    env.unwrapped.goal = unnormalized_goal
                else:
                    env.unwrapped.goal = goal

            reset_callback = set_goal_for_visualization
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)
            reset_callback = None

        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
            reset_callback=reset_callback,
        )

        renderers = OrderedDict(image_observation=video_renderer, )
        state_env = expl_env.env
        state_space = state_env.observation_space[observation_key]
        low = state_space.low.min()
        high = state_space.high.max()
        # Grid of (x, y) points with one point per rendered pixel, used to
        # evaluate the models over a 2D plane.
        y = np.linspace(low, high, num=video_renderer.image_chw[1])
        x = np.linspace(low, high, num=video_renderer.image_chw[2])
        all_xy_np = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
        all_xy_torch = ptu.from_numpy(all_xy_np)
        num_states = all_xy_torch.shape[0]
        if visualize_dynamics:

            def create_dynamics_visualizer(show_prob, vary_state=False):
                """Return a function mapping (obs_dict, action) to values.

                With ``vary_state`` the model is evaluated from every grid
                state (with a zero action) at the current goal; otherwise it
                is evaluated from the current state/action at every grid
                point. ``show_prob`` switches from log-probabilities to
                probabilities.
                """
                def get_prob(obs_dict, action):
                    if vary_state:
                        # Allocate on the same device as the state grid so
                        # the model call does not mix CPU and GPU tensors.
                        action_repeated = torch.zeros(
                            (num_states, int(action_space.low.size)),
                            device=all_xy_torch.device,
                        )
                        dist = dynamics_model(all_xy_torch, action_repeated)
                        goal = ptu.from_numpy(
                            obs_dict[desired_goal_key][None])
                        log_probs = dist.log_prob(goal)
                    else:
                        obs_torch = ptu.from_numpy(
                            obs_dict[observation_key])[None]
                        action_torch = ptu.from_numpy(action)[None]
                        dist = dynamics_model(obs_torch, action_torch)
                        log_probs = dist.log_prob(all_xy_torch)
                    if show_prob:
                        return log_probs.exp()
                    return log_probs

                return get_prob

            renderers['log_prob'] = ValueRenderer(
                create_dynamics_visualizer(False), **video_renderer_kwargs)
            # renderers['prob'] = ValueRenderer(
            #     create_dynamics_visualizer(True), **video_renderer_kwargs
            # )
            renderers['log_prob_vary_state'] = ValueRenderer(
                create_dynamics_visualizer(False, vary_state=True),
                only_get_image_once_per_episode=True,
                max_out_walls=isinstance(stub_env, PickAndPlaceEnv),
                **video_renderer_kwargs)
            # renderers['prob_vary_state'] = ValueRenderer(
            #     create_dynamics_visualizer(True, vary_state=True),
            #     **video_renderer_kwargs)
        if visualize_discount_model and pgr_trainer.discount_model is not None:

            def get_discount_values(obs, action):
                """Evaluate the discount model for every grid goal."""
                obs_torch = ptu.from_numpy(obs[observation_key])[None]
                combined_obs = torch.cat(
                    [
                        obs_torch.repeat(num_states, 1),
                        all_xy_torch,
                    ],
                    dim=1,
                )
                action_torch = ptu.from_numpy(action)[None]
                action_repeated = action_torch.repeat(num_states, 1)
                return pgr_trainer.discount_model(combined_obs,
                                                  action_repeated)

            renderers['discount_model'] = ValueRenderer(
                get_discount_values,
                states_to_eval=all_xy_torch,
                **video_renderer_kwargs)
        if 'log_prob' in renderers and 'discount_model' in renderers:
            renderers['log_prob_time_discount'] = ProductRenderer(
                renderers['discount_model'], renderers['log_prob'],
                **video_renderer_kwargs)

        def get_reward(obs_dict, action, next_obs_dict):
            """Return the reward for a single transition."""
            o = batchify(obs_dict)
            a = batchify(action)
            next_o = batchify(next_obs_dict)
            reward = reward_fn(o, a, next_o, next_o)
            return reward[0]

        def get_bootstrap(obs_dict, action, next_obs_dict, return_float=True):
            """Return the trainer's bootstrap value for a single transition.

            With ``return_float`` false the raw value returned by the trainer
            is passed through instead of a numpy scalar.
            """
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            next_o_pt = ptu.from_numpy(next_obs_dict[observation_key][None])
            action_torch = ptu.from_numpy(action[None])
            bootstrap, *_ = pgr_trainer.get_bootstrap_stats(
                torch.cat((o_pt, context_pt), dim=1),
                action_torch,
                torch.cat((next_o_pt, context_pt), dim=1),
            )
            if return_float:
                return ptu.get_numpy(bootstrap)[0, 0]
            return bootstrap

        def get_discount(obs_dict, action, next_obs_dict):
            """Return the trainer's discount for a transition, clipped."""
            bootstrap = get_bootstrap(obs_dict,
                                      action,
                                      next_obs_dict,
                                      return_float=False)
            reward_np = get_reward(obs_dict, action, next_obs_dict)
            reward = ptu.from_numpy(reward_np[None, None])
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            obs = torch.cat((o_pt, context_pt), dim=1)
            actions = ptu.from_numpy(action[None])
            discount = pgr_trainer.get_discount_factor(
                bootstrap,
                reward,
                obs,
                actions,
            )
            if isinstance(discount, torch.Tensor):
                discount = ptu.get_numpy(discount)[0, 0]
            return np.clip(discount, a_min=1e-3, a_max=1)

        def create_modify_fn(
            title,
            set_params=None,
            scientific=True,
        ):
            """Return an axis callback that sets the title and formatting."""
            def modify(ax):
                ax.set_title(title)
                if set_params:
                    ax.set(**set_params)
                if scientific:
                    scaler = ScalarFormatter(useOffset=True)
                    scaler.set_powerlimits((1, 1))
                    ax.yaxis.set_major_formatter(scaler)
                    ax.ticklabel_format(axis='y', style='sci')

            return modify

        def add_left_margin(fig):
            """Widen the left margin of ``fig``."""
            fig.subplots_adjust(left=0.2)

        if visualize_all_plots or plot_discount:
            renderers['discount'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_discount,
                modify_ax_fn=create_modify_fn(
                    title='discount',
                    set_params=dict(ylim=[-0.05, 1.1], ),
                ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        if visualize_all_plots or plot_reward:
            renderers['reward'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_reward,
                modify_ax_fn=create_modify_fn(title='reward'),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_bootstrap_value:
            renderers['bootstrap-value'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_bootstrap,
                modify_ax_fn=create_modify_fn(title='bootstrap value'),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        def add_images(env, state_distribution):
            """Rewrap ``env``'s inner environment so rollouts carry images."""
            state_env = env.env
            if is_gym_env:
                goal_distribution = state_distribution
            else:
                goal_distribution = AddImageDistribution(
                    env=state_env,
                    base_distribution=state_distribution,
                    image_goal_key='image_desired_goal',
                    renderer=video_renderer,
                )
            context_env = ContextualEnv(
                state_env,
                context_distribution=goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return InsertDebugImagesEnv(
                context_env,
                renderers=renderers,
            )

        img_expl_env = add_images(expl_env, expl_context_distrib)
        # Keep exactly one image format per displayed key: the goal image
        # (and therefore its format) only exists for non-gym environments.
        imgs_to_show = list(renderers.keys())
        img_formats = [r.output_image_format for r in renderers.values()]
        if not is_gym_env:
            imgs_to_show = ['image_desired_goal'] + imgs_to_show
            img_formats = [video_renderer.output_image_format] + img_formats
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_chw[1],
            image_formats=img_formats,
            keys_to_show=imgs_to_show,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        for eval_env_name, (env, context_distrib) in (
                eval_env_name_to_env_and_context_distrib.items()):
            img_eval_env = add_images(env, context_distrib)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag=eval_env_name,
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                keys_to_show=imgs_to_show,
                **save_video_kwargs)
            algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
Beispiel #12
0
def awac_rig_experiment(
    max_path_length,
    qf_kwargs,
    trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    policy_class=TanhGaussianPolicy,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    reward_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_observation_key='state_observation',
    state_goal_key='state_desired_goal',
    image_goal_key='image_desired_goal',
    path_loader_class=MDPPathLoader,
    demo_replay_buffer_kwargs=None,
    path_loader_kwargs=None,
    env_demo_path='',
    env_offpolicy_data_path='',
    debug=False,
    epsilon=1.0,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    add_env_demos=False,
    add_env_offpolicy_data=False,
    save_paths=False,
    load_demos=False,
    pretrain_policy=False,
    pretrain_rl=False,
    save_pretrained_algorithm=False,

    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=84,
    pretrained_vae_path="",
    presampled_goals_path="",
    init_camera=None,
    qf_class=ConcatMlp,
):
    """Assemble and run an AWAC experiment on an image-wrapped contextual env.

    The function builds exploration and evaluation ``ContextualEnv`` instances,
    loads (``pretrained_vae_path``) or trains (``train_vae``) a VAE, creates
    the online and demo replay buffers, the four Q-networks, the policy, the
    path collectors, an ``AWACTrainer`` and a ``TorchBatchRLAlgorithm``.
    Depending on the flags it then registers video savers, loads demos,
    pretrains, snapshots the pretrained algorithm, and finally calls
    ``algorithm.train()``. Nothing is returned.

    Side effects on arguments:
      * ``debug=True`` overwrites several entries of ``algo_kwargs`` and
        ``trainer_kwargs`` in place with small values.
      * ``qf_kwargs`` receives ``input_size`` (``ConcatMlp``) or
        ``added_fc_input_size`` (``ConcatCNN``).
      * ``path_loader_kwargs`` receives ``model_path``, ``env`` and, when
        requested, extra ``demo_paths`` entries.
      * ``replay_buffer_kwargs`` is left untouched; the demo buffers use a
        merged copy of it and ``demo_replay_buffer_kwargs``.

    Arguments accepted but not read by the current body: ``reward_kwargs``,
    ``state_observation_key``, ``state_goal_key``, ``image_goal_key`` and
    ``epsilon``. The goal sampling modes and ``presampled_goals_path`` are
    forwarded to the env factory, which currently ignores them.
    """
    # Kwarg definitions
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if demo_replay_buffer_kwargs is None:
        demo_replay_buffer_kwargs = {}
    if path_loader_kwargs is None:
        path_loader_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    if debug:
        # Shrink everything so a full loop finishes quickly.
        max_path_length = 5
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100
        for steps_key in (
                'bc_num_pretrain_steps',
                'q_num_pretrain1_steps',
                'q_num_pretrain2_steps',
        ):
            trainer_kwargs[steps_key] = min(
                10, trainer_kwargs.get(steps_key, 0))

    # Environment wrapping
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        """Return ``(env, context_distribution, reward_fn)`` for one env.

        NOTE: ``goal_sampling_mode`` and ``presampled_goals_path`` are
        currently unused; every env gets the zero-sized "no_goal" context
        distribution and a ``GraspingRewardFn``.
        """
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        # Each env gets its own renderer instance.
        env_renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=env_renderer)

        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )
        reward_fn = GraspingRewardFn()

        env = ContextualEnv(
            img_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn

    # VAE setup
    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)
    path_loader_kwargs['model_path'] = pretrained_vae_path

    # Environment definitions
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
            presampled_goals_path))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
            presampled_goals_path))
    path_loader_kwargs['env'] = eval_env

    # AWAC: optional extra demo sources. ``setdefault`` avoids a KeyError when
    # the caller did not provide a "demo_paths" list.
    if add_env_demos:
        path_loader_kwargs.setdefault("demo_paths", []).append(env_demo_path)
    if add_env_offpolicy_data:
        path_loader_kwargs.setdefault("demo_paths", []).append(
            env_offpolicy_data_path)

    # Key setting
    context_key = desired_goal_key
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    mapper = RemapKeyFn({context_key: observation_key})
    obs_keys = [observation_key]
    cont_keys = [context_key]

    # Replay buffers
    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Append the context to (next_)observations along axis 1."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    def create_replay_buffer(buffer_kwargs):
        """Build one relabeling buffer; only ``buffer_kwargs`` varies."""
        return ContextualRelabelingReplayBuffer(
            env=eval_env,
            context_keys=cont_keys,
            observation_keys=obs_keys,
            observation_key=observation_key,
            context_distribution=expl_context_distrib,
            sample_context_from_obs_dict_fn=mapper,
            reward_fn=eval_reward,
            post_process_batch_fn=concat_context_to_obs,
            **buffer_kwargs)

    replay_buffer = create_replay_buffer(replay_buffer_kwargs)
    # Merge into a copy so the caller's ``replay_buffer_kwargs`` is not
    # modified as a side effect.
    demo_buffer_kwargs = dict(replay_buffer_kwargs)
    demo_buffer_kwargs.update(demo_replay_buffer_kwargs)
    demo_train_buffer = create_replay_buffer(demo_buffer_kwargs)
    demo_test_buffer = create_replay_buffer(demo_buffer_kwargs)

    # Neural network architecture
    def create_qf():
        """Instantiate one Q-network of ``qf_class``."""
        if qf_class is ConcatMlp:
            qf_kwargs["input_size"] = obs_dim + action_dim
        if qf_class is ConcatCNN:
            qf_kwargs["added_fc_input_size"] = action_dim
        return qf_class(output_size=1, **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )

    # Path collectors
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    # Algorithm
    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **trainer_kwargs)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)

    algorithm.to(ptu.device)

    # Video saving
    if save_video:
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    # AWAC: optional hooks, demo loading and pretraining
    if save_paths:
        # NOTE(review): the parameter shadows any module-level ``save_paths``
        # helper, so the value appended here is the argument itself. Passing
        # ``True`` registers a non-callable hook - confirm intended usage.
        algorithm.post_train_funcs.append(save_paths)

    if load_demos:
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            # reward_fn is omitted on purpose: rewards are recomputed later.
            **path_loader_kwargs)
        path_loader.load_demos()
    if pretrain_policy:
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if pretrain_rl:
        trainer.pretrain_q_with_bc_data()

    if save_pretrained_algorithm:
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        # Same snapshot written under both extensions; the context manager
        # guarantees each file handle is flushed and closed.
        for snapshot_path in (pt_path, p_path):
            with open(snapshot_path, "wb") as snapshot_file:
                torch.save(data, snapshot_file)

    algorithm.train()
Beispiel #13
0
    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        """Build the path collector matching the enclosing experiment setup.

        Args:
            env: environment the collector rolls out in.
            policy: policy used for the rollouts.
            mode: ``'expl'`` or ``'eval'``; selects the per-mode defaults read
                from ``mask_variant`` and which context distribution is used
                as the mask sampler.
            mask_kwargs: optional overrides for ``rollout_mask_order``,
                ``mask_distr`` and ``mask_ids``. ``None`` means no overrides.

        Returns:
            A ``MaskPathCollector`` when ``mask_conditioned`` is set,
            otherwise a ``ContextualPathCollector`` (with the context keys
            when ``contextual_mdp`` is set, without any otherwise).
        """
        if mask_kwargs is None:
            mask_kwargs = {}
        assert mode in ['expl', 'eval']

        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)

        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError

            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError

            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get('max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get('max_subtasks_per_rollout', None)

            # Kept in its own name: reusing ``mode`` here would overwrite the
            # expl/eval selector and make the mask_sampler choice below pick
            # the eval distribution even for exploration collectors.
            context_post_process_mode = mask_variant.get('context_post_process_mode', None)
            if context_post_process_mode in ['dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib if mode == 'expl' else eval_context_distrib),
                # Copy so the collector cannot modify the shared config dict.
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            # Non-contextual case: the policy sees the observation only.
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )
Beispiel #14
0
def disco_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    generate_set_for_rl_kwargs,
    # VAE parameters
    create_vae_kwargs,
    vae_trainer_kwargs,
    vae_algo_kwargs,
    data_loader_kwargs,
    generate_set_for_vae_pretraining_kwargs,
    num_ungrouped_images,
    beta_schedule_kwargs=None,
    # Oracle settings
    use_ground_truth_reward=False,
    use_onehot_set_embedding=False,
    use_dummy_model=False,
    observation_key="latent_observation",
    # RIG comparison
    rig_goal_setter_kwargs=None,
    rig=False,
    # Miscellaneous
    reward_fn_kwargs=None,
    # None-VAE Params
    env_id=None,
    env_class=None,
    env_kwargs=None,
    latent_observation_key="latent_observation",
    state_observation_key="state_observation",
    image_observation_key="image_observation",
    set_description_key="set_description",
    example_state_key="example_state",
    example_image_key="example_image",
    # Exploration
    exploration_policy_kwargs=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    """Set up and run a SAC experiment conditioned on set contexts.

    Steps, in order: create the sets, obtain a VAE (a dummy one or one
    trained with ``train_set_vae``), build the exploration and evaluation
    contextual envs, build Q-functions / policy / replay buffer / trainer /
    path collectors, wrap them in a ``TorchBatchRLAlgorithm``, optionally
    register video savers, and call ``algorithm.train()``.

    The context keys fed to the policy and Q-functions depend on the flags:
    ``rig`` uses only the mean key, ``use_onehot_set_embedding`` uses the
    set-embedding key, and otherwise the mean and covariance keys are used.
    Nothing is returned.
    """
    # Normalise optional dict arguments.
    rig_goal_setter_kwargs = (
        {} if rig_goal_setter_kwargs is None else rig_goal_setter_kwargs)
    reward_fn_kwargs = {} if reward_fn_kwargs is None else reward_fn_kwargs
    exploration_policy_kwargs = (
        {} if exploration_policy_kwargs is None else exploration_policy_kwargs)
    save_video_kwargs = save_video_kwargs or {}
    renderer_kwargs = renderer_kwargs or {}

    renderer = EnvRenderer(**renderer_kwargs)

    sets = create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        example_state_key=example_state_key,
        example_image_key=example_image_key,
        **generate_set_for_rl_kwargs,
    )

    # Model: either a stand-in VAE or one trained on the generated sets.
    if use_dummy_model:
        model = create_dummy_image_vae(img_chw=renderer.image_chw,
                                       **create_vae_kwargs)
    else:
        model = train_set_vae(
            create_vae_kwargs,
            vae_trainer_kwargs,
            vae_algo_kwargs,
            data_loader_kwargs,
            generate_set_for_vae_pretraining_kwargs,
            num_ungrouped_images,
            env_id=env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            beta_schedule_kwargs=beta_schedule_kwargs,
            sets=sets,
            renderer=renderer,
        )

    def build_contextual_env(**extra_kwargs):
        """Create a fresh gym env and wrap it; extras are appended last."""
        return contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
            **extra_kwargs
        )

    expl_env, expl_context_distrib, expl_reward = build_contextual_env()
    # Only the evaluation env is told about the RIG oracle goal setting.
    eval_env, eval_context_distrib, eval_reward = build_contextual_env(
        oracle_rig_goal=rig)

    # All context keys are stored; only a subset is shown to the RL networks.
    context_keys = [
        expl_context_distrib.mean_key,
        expl_context_distrib.covariance_key,
        expl_context_distrib.set_index_key,
        expl_context_distrib.set_embedding_key,
    ]
    if rig:
        context_keys_for_rl = [expl_context_distrib.mean_key]
    elif use_onehot_set_embedding:
        context_keys_for_rl = [expl_context_distrib.set_embedding_key]
    else:
        context_keys_for_rl = [
            expl_context_distrib.mean_key,
            expl_context_distrib.covariance_key,
        ]

    # Network input sizes: flattened observation plus every RL context.
    obs_dim = np.prod(expl_env.observation_space.spaces[observation_key].shape)
    obs_dim += sum(
        np.prod(expl_env.observation_space.spaces[rl_key].shape)
        for rl_key in context_keys_for_rl
    )
    action_dim = np.prod(expl_env.action_space.shape)

    def create_qf():
        """Return a new Q-network over concatenated (obs, action)."""
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    def concat_context_to_obs(batch, *args, **kwargs):
        """Append the RL context arrays to (next_)observations on axis 1."""
        rl_contexts = [batch[rl_key] for rl_key in context_keys_for_rl]
        current = batch["observations"]
        following = batch["next_observations"]
        batch["observations"] = np.concatenate(
            [current] + rl_contexts, axis=1)
        batch["next_observations"] = np.concatenate(
            [following] + rl_contexts, axis=1)
        return batch

    relabel_distribution = FilterKeys(expl_context_distrib, context_keys)
    # The set literal de-duplicates keys that share the same name.
    stored_observation_keys = list(
        {observation_key, state_observation_key, latent_observation_key})
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys,
        observation_keys=stored_observation_keys,
        observation_key=observation_key,
        context_distribution=relabel_distribution,
        sample_context_from_obs_dict_fn=None,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs,
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs,
    )

    # Evaluation rolls out the deterministic version of the policy.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    exploration_policy = create_exploration_policy(
        expl_env, policy, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        set_index_key = eval_context_distrib.set_index_key
        # Exploration ("train") saver is registered first, then "eval".
        for collector, video_tag in (
                (expl_path_collector, "train"),
                (eval_path_collector, "eval"),
        ):
            video_func = DisCoVideoSaveFunction(
                model,
                sets,
                collector,
                tag=video_tag,
                reconstruction_key="image_reconstruction",
                decode_set_image_key="decoded_set_prior",
                set_visualization_key="set_visualization",
                example_image_key=example_image_key,
                set_index_key=set_index_key,
                columns=len(sets),
                unnormalize=True,
                imsize=48,
                image_format=renderer.output_image_format,
                **save_video_kwargs,
            )
            algorithm.post_train_funcs.append(video_func)

    algorithm.train()
Beispiel #15
0
def masking_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    achieved_goal_key='state_achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    train_env_id=None,
    eval_env_id=None,
    do_masking=True,
    mask_key="masked_observation",
    masking_eval_steps=200,
    log_mask_diagnostics=True,
    mask_dim=None,
    masking_reward_fn=None,
    masking_for_exploration=True,
    rotate_masks_for_eval=False,
    rotate_masks_for_expl=False,
    mask_distribution=None,
    num_steps_per_mask_change=10,
    tag=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'

    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size +
               mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']

        context = batch[context_key]
        mask = batch[mask_key]

        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        return batch

    # Relabeling replay buffer configured with the goal context and the mask
    # as its context keys. context_from_obs_dict_fn is supplied as the way to
    # derive a context from observations, and concat_context_to_obs is
    # supplied as the batch post-processing hook.
    # NOTE(review): the buffer is built from eval_env / eval_reward while the
    # trainer below receives expl_env -- presumably both envs share the same
    # spaces and reward; confirm.
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    # SAC trainer wired to the policy, the two Q-functions and their two
    # target networks created above.
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        """Return the path collector that rolls out ``policy`` in ``env``.

        When ``is_rotating`` is falsy a ContextualPathCollector is built.
        Otherwise a RotatingMaskingPathCollector is returned; that variant is
        only allowed when masking is enabled and the mask distribution is
        'one_hot_masks', which is checked with an assertion.
        """
        policy_context_keys = [context_key, mask_key]

        if not is_rotating:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=policy_context_keys,
            )

        assert do_masking and mask_distribution == 'one_hot_masks'
        return RotatingMaskingPathCollector(
            env,
            policy,
            observation_key=observation_key,
            context_keys_for_policy=policy_context_keys,
            mask_key=mask_key,
            mask_length=mask_dim,
            num_steps_per_mask_change=num_steps_per_mask_change,
        )

    # Wrap the policy for exploration; the wrapper is configured entirely by
    # exploration_policy_kwargs.
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)

    # Evaluation rolls out the policy wrapped in MakeDeterministic, while
    # exploration uses the exploration policy. Each collector independently
    # chooses whether to rotate masks.
    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    # Tie the trainer, both envs, both collectors and the replay buffer
    # together, then move the algorithm to the configured torch device.
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    # Optional video logging: prepare a rollout function and a renderer that
    # the video hooks further down in this branch rely on.
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video: the policy input is built
            # from the observation, the goal context and an all-ones mask
            # rather than the mask found in the observation dict.
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Return a ContextualEnv over ``env.env`` extended with images.

            ``state_distribution`` is wrapped in an AddImageDistribution that
            uses 'image_desired_goal' as its image goal key, and the inner
            env is wrapped in an InsertImageEnv. Both share the enclosing
            ``renderer``.
            """
            base_env = env.env
            # Build the goal distribution before the image env, matching the
            # construction order the rest of the setup was written against.
            goals_with_images = AddImageDistribution(
                env=base_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            return ContextualEnv(
                InsertImageEnv(base_env, renderer=renderer),
                context_distribution=goals_with_images,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )

        # Image-producing counterparts of the evaluation and exploration envs.
        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        # Video hook for evaluation: rolls out the policy wrapped in
        # MakeDeterministic; tagged "eval".
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        # Video hook for exploration: rolls out the exploration policy; tagged
        # "train".
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)

        # Register both hooks on the algorithm's post-train function list,
        # evaluation first.
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # Diagnostics: build one evaluation collector per mask dimension so the
    # policy can be evaluated on each individual dimension of the mask.
    # Each row of the identity matrix is the one-hot mask for one dimension.
    masks = [one_hot.copy() for one_hot in np.eye(mask_dim)]

    collectors = []
    for one_hot in masks:
        # With masking disabled every diagnostic env gets the all-ones mask.
        env_for_dim, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=one_hot if do_masking else np.ones(mask_dim),
        )
        collectors.append(
            ContextualPathCollector(
                env_for_dim,
                MakeDeterministic(policy),
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )
        )

    # Log prefix per mask, e.g. 'mask_010/' for the second of three
    # dimensions.
    log_prefixes = []
    for one_hot in masks:
        bits = ''.join(str(int(bit)) for bit in one_hot)
        log_prefixes.append('mask_{}/'.format(bits))

    def get_mask_diagnostics(unused):
        """Collect fresh paths with every per-mask collector and log them.

        For each (prefix, collector) pair, new paths are collected, summarised
        with ``eval_util.get_generic_path_information`` and merged into one
        ordered log under that prefix. The single argument is ignored.

        Returns:
            The OrderedDict holding the prefixed statistics of all collectors.
        """
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util

        diagnostics = OrderedDict()
        for mask_prefix, mask_collector in zip(log_prefixes, collectors):
            rollouts = mask_collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            path_stats = eval_util.get_generic_path_information(rollouts)
            append_log(diagnostics, add_prefix(path_stats, mask_prefix))

        # end_epoch is invoked on every collector only after all of them have
        # finished collecting.
        for mask_collector in collectors:
            mask_collector.end_epoch(0)
        return diagnostics

    # Optionally register the per-mask diagnostics with the algorithm's list
    # of evaluation diagnostic functions.
    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    # Start training.
    algorithm.train()