Example #1
 def add_images(env, base_distribution):
     if use_image_observations:
         img_env = env
         image_goal_distribution = base_distribution
     else:
         state_env = env.env
         image_goal_distribution = AddImageDistribution(
             env=state_env,
             base_distribution=base_distribution,
             image_goal_key='image_desired_goal',
             renderer=video_renderer,
         )
         img_env = InsertImageEnv(state_env, renderer=video_renderer)
     img_env = InsertDebugImagesEnv(
         img_env,
         obj1_sweep_renderers,
         compute_shared_data=obj1_sweeper,
     )
     img_env = InsertDebugImagesEnv(
         img_env,
         obj0_sweep_renderers,
         compute_shared_data=obj0_sweeper,
     )
     return ContextualEnv(
         img_env,
         context_distribution=image_goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key_for_rl,
         update_env_info_fn=delete_info,
     )
Example #2
 def contextual_env_distrib_and_reward(
         env_id, env_class, env_kwargs, goal_sampling_mode
 ):
     env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
     env.goal_sampling_mode = goal_sampling_mode
     goal_distribution = GoalDictDistributionFromMultitaskEnv(
         env,
         desired_goal_keys=[desired_goal_key],
     )
     reward_fn = ContextualRewardFnFromMultitaskEnv(
         env=env,
         achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
         desired_goal_key=desired_goal_key,
         achieved_goal_key=achieved_goal_key,
     )
     diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
         env.goal_conditioned_diagnostics,
         desired_goal_key=desired_goal_key,
         observation_key=observation_key,
     )
     env = ContextualEnv(
         env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         contextual_diagnostics_fns=[diag_fn],
         update_env_info_fn=delete_info,
     )
     return env, goal_distribution, reward_fn
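A minimal usage sketch for the helper above (a hypothetical call; exploration_goal_sampling_mode and evaluation_goal_sampling_mode are assumed to come from the surrounding experiment config, mirroring Example #13): the function is typically called twice to build matched exploration and evaluation environments.

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)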
Example #3
 def add_images(env, base_distribution):
     if use_image_observations:
         video_env = InsertImageEnv(
             env,
             renderer=video_renderer,
             image_key='video_observation',
         )
         image_goal_distribution = base_distribution
     else:
         video_env = InsertImageEnv(
             env,
             renderer=video_renderer,
             image_key='image_observation',
         )
         state_env = env.env
         image_goal_distribution = AddImageDistribution(
             env=state_env,
             base_distribution=base_distribution,
             image_goal_key='image_desired_goal',
             renderer=video_renderer,
         )
     return ContextualEnv(
         video_env,
         context_distribution=image_goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key_for_rl,
         update_env_info_fn=delete_info,
     )
Example #4
    def setup_env(state_env, encoder, reward_fn):
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(
                goal_distribution, num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
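        # Wrap the encoded goal distribution so each sampled goal also carries a mask
        # over the latent space; the 'one_hot_masks' distribution type suggests one
        # latent dimension is selected per sampled goal.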
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=latent_dim,
            distribution_type='one_hot_masks',
        )

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution
Example #5
 def add_images(env, context_distribution):
     state_env = env.env
     img_env = InsertImageEnv(
         state_env,
         renderer=renderer,
         image_key='image_observation',
     )
     return ContextualEnv(
         img_env,
         context_distribution=context_distribution,
         reward_fn=eval_reward,
         observation_key=observation_key,
         update_env_info_fn=delete_info,
     )
Example #6
 def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                          renderer):
     state_env = get_gym_env(env_id,
                             env_class=env_class,
                             env_kwargs=env_kwargs)
     state_env.goal_sampling_mode = goal_sampling_mode
     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
         state_env,
         desired_goal_keys=[state_desired_goal_key],
     )
     state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
         state_env.goal_conditioned_diagnostics,
         desired_goal_key=state_desired_goal_key,
         observation_key=state_observation_key,
     )
     image_goal_distribution = AddImageDistribution(
         env=state_env,
         base_distribution=state_goal_distribution,
         image_goal_key=img_desired_goal_key,
         renderer=renderer,
     )
     goal_distribution = PresampledDistribution(image_goal_distribution,
                                                5000)
     img_env = InsertImageEnv(state_env, renderer=renderer)
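     # The reward is computed either in state space (distance between achieved and
     # desired state goals) or directly in pixel space (negative L2 distance between
     # the observed image and the goal image).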
     if reward_type == 'state_distance':
         reward_fn = ContextualRewardFnFromMultitaskEnv(
             env=state_env,
             achieved_goal_from_observation=IndexIntoAchievedGoal(
                 'state_observation'),
             desired_goal_key=state_desired_goal_key,
             achieved_goal_key=state_achieved_goal_key,
         )
     elif reward_type == 'pixel_distance':
         reward_fn = NegativeL2Distance(
             achieved_goal_from_observation=IndexIntoAchievedGoal(
                 img_observation_key),
             desired_goal_key=img_desired_goal_key,
         )
     else:
         raise ValueError(reward_type)
     env = ContextualEnv(
         img_env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=img_observation_key,
         contextual_diagnostics_fns=[state_diag_fn],
         update_env_info_fn=delete_info,
     )
     return env, goal_distribution, reward_fn
Example #7
 def add_images(env, state_distribution):
     state_env = env.env
     image_goal_distribution = AddImageDistribution(
         env=state_env,
         base_distribution=state_distribution,
         image_goal_key='image_desired_goal',
         renderer=renderer,
     )
     img_env = InsertImageEnv(state_env, renderer=renderer)
     return ContextualEnv(
         img_env,
         context_distribution=image_goal_distribution,
         reward_fn=eval_reward,
         observation_key=observation_key,
         update_env_info_fn=delete_info,
     )
Example #8
 def get_video_func(
     env,
     policy,
     tag,
 ):
     renderer = EnvRenderer(**renderer_kwargs)
     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
         env,
         desired_goal_keys=[desired_goal_key],
     )
     image_goal_distribution = AddImageDistribution(
         env=env,
         base_distribution=state_goal_distribution,
         image_goal_key="image_desired_goal",
         renderer=renderer,
     )
     img_env = InsertImageEnv(env, renderer=renderer)
     rollout_function = partial(
         rf.multitask_rollout,
         max_path_length=variant["max_path_length"],
         observation_key=observation_key,
         desired_goal_key=desired_goal_key,
         return_dict_obs=True,
     )
     reward_fn = ContextualRewardFnFromMultitaskEnv(
         env=env,
         achieved_goal_from_observation=IndexIntoAchievedGoal(
             observation_key),
         desired_goal_key=desired_goal_key,
         achieved_goal_key="state_achieved_goal",
     )
     contextual_env = ContextualEnv(
         img_env,
         context_distribution=image_goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
     )
     video_func = get_save_video_function(
         rollout_function,
         contextual_env,
         policy,
         tag=tag,
         imsize=renderer.width,
         image_format="CWH",
         **save_video_kwargs,
     )
     return video_func
Example #9
    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = make(env_id, env_class, env_kwargs, normalize_env)
        env = GymToMultiEnv(env)
        # env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)

        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )
        contextual_reward_fn = None
        env = ContextualEnv(
            env,
            context_distribution=no_goal_distribution,
            reward_fn=contextual_reward_fn,
            observation_key=observation_key,
            # contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=None,
        )
        return env, no_goal_distribution, contextual_reward_fn
Example #10
 def add_images(env, state_distribution):
     state_env = env.env
     image_goal_distribution = AddImageDistribution(
         env=state_env,
         base_distribution=state_distribution,
         image_goal_key='image_desired_goal',
         renderer=renderer,
     )
     img_env = InsertImagesEnv(state_env,
                               renderers={
                                   'image_observation': renderer,
                               })
     context_env = ContextualEnv(
         img_env,
         context_distribution=image_goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         update_env_info_fn=None,
     )
     return context_env
Example #11
 def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                       goal_sampling_mode):
     env = get_gym_env(
         env_id,
         env_class=env_class,
         env_kwargs=env_kwargs,
         unwrap_timed_envs=True,
     )
     env.goal_sampling_mode = goal_sampling_mode
     goal_distribution = GoalDictDistributionFromGymGoalEnv(
         env,
         desired_goal_key=desired_goal_key,
     )
     distance_fn = L2Distance(
         achieved_goal_from_observation=IndexIntoAchievedGoal(
             achieved_goal_key, ),
         desired_goal_key=desired_goal_key,
     )
     if (isinstance(env, robotics.FetchReachEnv)
             or isinstance(env, robotics.FetchPushEnv)
             or isinstance(env, robotics.FetchPickAndPlaceEnv)
             or isinstance(env, robotics.FetchSlideEnv)):
         success_threshold = 0.05
     else:
         raise TypeError("I don't know the success threshold of env ", env)
     reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
     diag_fn = GenericGoalConditionedContextualDiagnostics(
         desired_goal_key=desired_goal_key,
         achieved_goal_key=achieved_goal_key,
         success_threshold=success_threshold,
     )
     env = ContextualEnv(
         env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         contextual_diagnostics_fns=[diag_fn],
         update_env_info_fn=delete_info,
     )
     return env, goal_distribution, reward_fn
Example #12
 def add_images(env, state_distribution):
     state_env = env.env
     if is_gym_env:
         goal_distribution = state_distribution
     else:
         goal_distribution = AddImageDistribution(
             env=state_env,
             base_distribution=state_distribution,
             image_goal_key='image_desired_goal',
             renderer=video_renderer,
         )
     context_env = ContextualEnv(
         state_env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         update_env_info_fn=delete_info,
     )
     return InsertDebugImagesEnv(
         context_env,
         renderers=renderers,
     )
Example #13
def image_based_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    cnn_kwargs,
    policy_type='tanh-normal',
    env_id=None,
    env_class=None,
    env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    reward_type='state_distance',
    env_renderer_kwargs=None,
    # Data augmentations
    apply_random_crops=False,
    random_crop_pixel_shift=4,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal': 'image_observation',
        'state_desired_goal': 'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                             renderer):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        goal_distribution = PresampledDistribution(image_goal_distribution,
                                                   5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

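    # Build exploration and evaluation contextual envs that share the same
    # env_renderer for the RL image observations.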
    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)

    action_dim = expl_env.action_space.low.size
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

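    # Each Q-function applies a CNN to the stacked observation/goal images (via
    # ApplyConvToStateAndGoalImage), flattens the result, and concatenates it with
    # the action before the final MLP.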
    def create_qf():
        cnn = BasicCNN(input_width=img_width,
                       input_height=img_height,
                       input_channels=img_num_channels,
                       **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn), basic.FlattenEachParallel(),
            ConcatMlp(input_size=joint_cnn.output_size + action_dim,
                      output_size=1,
                      **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    cnn = BasicCNN(input_width=img_width,
                   input_height=img_height,
                   input_channels=img_num_channels,
                   **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           output_activations=['tanh', 'identity'],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

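    # Optional augmentation: pad the observation, next-observation, and goal images,
    # then apply a joint random crop so each observation and its goal image stay
    # spatially aligned before being concatenated along the channel axis.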
    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(next_obs_padded,
                                                  context_padded)

            batch['observations'] = np.concatenate([obs_aug, context_aug],
                                                   axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
            return batch

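    # Relabeling replay buffer: sample_context_from_obs_dict_fn remaps achieved image
    # and state observations onto the desired-goal keys for relabeling, and
    # concat_context_to_obs appends the (possibly augmented) goal image to each batch.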
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

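    # Video logging: re-wrap the eval and exploration envs with a separate
    # video_renderer (the video envs use a dummy zero reward) and save rollouts
    # showing the goal image, the RL observation, and the video observation.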
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(eval_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_expl_env = InsertImageEnv(expl_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Example #14
 def contextual_env_distrib_reward(_env_id,
                                   _env_class=None,
                                   _env_kwargs=None):
     base_env = get_gym_env(
         _env_id,
         env_class=env_class,
         env_kwargs=env_kwargs,
         unwrap_timed_envs=True,
     )
     if init_camera:
         base_env.initialize_camera(init_camera)
     if (isinstance(stub_env, AntFullPositionGoalEnv)
             and normalize_distances_for_full_state_ant):
         base_env = NormalizeAntFullPositionGoalEnv(base_env)
         normalize_env = base_env
     else:
         normalize_env = None
     env = NoisyAction(base_env, action_noise_scale)
     diag_fns = []
     if is_gym_env:
         goal_distribution = GoalDictDistributionFromGymGoalEnv(
             env,
             desired_goal_key=desired_goal_key,
         )
         diag_fns.append(
             GenericGoalConditionedContextualDiagnostics(
                 desired_goal_key=desired_goal_key,
                 achieved_goal_key=achieved_goal_key,
                 success_threshold=success_threshold,
             ))
     else:
         goal_distribution = GoalDictDistributionFromMultitaskEnv(
             env,
             desired_goal_keys=[desired_goal_key],
         )
         diag_fns.append(
             GoalConditionedDiagnosticsToContextualDiagnostics(
                 env.goal_conditioned_diagnostics,
                 desired_goal_key=desired_goal_key,
                 observation_key=observation_key,
             ))
     if isinstance(stub_env, AntFullPositionGoalEnv):
         diag_fns.append(
             AntFullPositionGoalEnvDiagnostics(
                 desired_goal_key=desired_goal_key,
                 achieved_goal_key=achieved_goal_key,
                 success_threshold=success_threshold,
                 normalize_env=normalize_env,
             ))
     # if isinstance(stub_env, HopperFullPositionGoalEnv):
     #     diag_fns.append(
     #         HopperFullPositionGoalEnvDiagnostics(
     #             desired_goal_key=desired_goal_key,
     #             achieved_goal_key=achieved_goal_key,
     #             success_threshold=success_threshold,
     #         )
     #     )
     achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
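     # Reward options: a sparse reward from thresholding the L2 distance, a dense
     # negative L2 distance, or a model-based probabilistic goal reward.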
     if reward_type == 'sparse':
         distance_fn = L2Distance(
             achieved_goal_from_observation=achieved_from_ob,
             desired_goal_key=desired_goal_key,
         )
         reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
     elif reward_type == 'negative_distance':
         reward_fn = NegativeL2Distance(
             achieved_goal_from_observation=achieved_from_ob,
             desired_goal_key=desired_goal_key,
         )
     else:
         reward_fn = ProbabilisticGoalRewardFn(
             dynamics_model,
             state_key=observation_key,
             context_key=context_key,
             reward_type=reward_type,
             discount_factor=discount_factor,
         )
     goal_distribution = PresampledDistribution(goal_distribution,
                                                num_presampled_goals)
     final_env = ContextualEnv(
         env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         contextual_diagnostics_fns=diag_fns,
         update_env_info_fn=delete_info,
     )
     return final_env, goal_distribution, reward_fn
Example #15
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        # encoded_env = EncoderWrappedEnv(
        #     img_env,
        #     model,
        #     dict(image_observation="latent_observation", ),
        # )
        # if goal_sampling_mode == "vae_prior":
        #     latent_goal_distribution = PriorDistribution(
        #         model.representation_size,
        #         desired_goal_key,
        #     )
        #     diagnostics = StateImageGoalDiagnosticsFn({}, )
        # elif goal_sampling_mode == "presampled":
        #     diagnostics = state_env.get_contextual_diagnostics
        #     image_goal_distribution = PresampledPathDistribution(
        #         presampled_goals_path,
        #     )

        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        # elif goal_sampling_mode == "reset_of_env":
        #     state_goal_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        #     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        #         state_goal_env,
        #         desired_goal_keys=[state_goal_key],
        #     )
        #     image_goal_distribution = AddImageDistribution(
        #         env=state_env,
        #         base_distribution=state_goal_distribution,
        #         image_goal_key=image_goal_key,
        #         renderer=renderer,
        #     )
        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        #     no_goal_distribution = PriorDistribution(
        #         representation_size=0,
        #         key="no_goal",
        #     )
        #     diagnostics = state_goal_env.get_contextual_diagnostics
        # else:
        #     error
        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )

        reward_fn = GraspingRewardFn(
            # img_env, # state_env,
            # observation_key=observation_key,
            # desired_goal_key=desired_goal_key,
            # **reward_kwargs
        )

        env = ContextualEnv(
            img_env,  # state_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn
Example #16
    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn
Example #17
    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)

        if mode == 'expl':
            goal_sampling_mode = expl_goal_sampling_mode
        elif mode == 'eval':
            goal_sampling_mode = eval_goal_sampling_mode
        else:
            goal_sampling_mode = None
        if goal_sampling_mode is not None:
            env.goal_sampling_mode = goal_sampling_mode

        if mask_conditioned:
            context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_keys=mask_keys,
                mask_dims=mask_dims,
                mask_format=mask_format,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                prev_subtask_weight=prev_subtask_weight,
                masks=masks,
                idx_masks=idx_masks,
                matrix_masks=matrix_masks,
                mask_distr=train_mask_distr,
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
                additional_context_keys=mask_keys,
                reward_fn=partial(
                    mask_reward_fn,
                    mask_format=mask_format,
                    use_g_for_mean=use_g_for_mean
                ),
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, context_distrib, reward_fn
Example #18
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn
Example #19
    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to an all-ones mask when a static mask should be used but none
        # was provided.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

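        # Sample goals from the multitask env, then attach a mask to each goal
        # (either the fixed static_mask or one drawn from the configured
        # mask distribution).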
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn