Example #1
def add_images(env, base_distribution):
    # If the policy already observes images, add a separate video-resolution image
    # for logging; otherwise render both image observations and image goals from state.
    if use_image_observations:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='video_observation',
        )
        image_goal_distribution = base_distribution
    else:
        video_env = InsertImageEnv(
            env,
            renderer=video_renderer,
            image_key='image_observation',
        )
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
    return ContextualEnv(
        video_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
Example #2
def analyze_from_vae(snapshot_path,
                     latent_observation_key='latent_observation',
                     mean_key='latent_mean',
                     covariance_key='latent_covariance',
                     image_observation_key='image_observation',
                     **kwargs):
    data = torch.load(open(snapshot_path, "rb"))
    variant_path = snapshot_path.replace('params.pt', 'variant.json')
    print_settings(variant_path)
    vae = data['trainer/vae']
    state_env = gym.make('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1')
    renderer = EnvRenderer()
    sets = make_custom_sets(state_env, renderer)
    reward_fn, _ = rewards.create_normal_likelihood_reward_fns(
        latent_observation_key=latent_observation_key,
        mean_key=mean_key,
        covariance_key=covariance_key,
        reward_fn_kwargs=dict(
            drop_log_det_term=True,
            sqrt_reward=True,
        ),
    )

    img_env = InsertImageEnv(state_env, renderer=renderer)
    env = DictEncoderWrappedEnv(
        img_env,
        vae,
        encoder_input_key='image_observation',
        encoder_output_remapping={'posterior_mean': 'latent_observation'},
    )
    analyze(sets, vae, env, **kwargs)
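A minimal usage sketch, assuming the snapshot was produced by a matching training run; the path below is a placeholder, and any extra keyword arguments are forwarded to analyze().

analyze_from_vae('/path/to/run/params.pt')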
Example #3
def add_images(env, base_distribution):
    if use_image_observations:
        img_env = env
        image_goal_distribution = base_distribution
    else:
        state_env = env.env
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=base_distribution,
            image_goal_key='image_desired_goal',
            renderer=video_renderer,
        )
        img_env = InsertImageEnv(state_env, renderer=video_renderer)
    # Insert per-object debug images (sweep renderers) on top of the image env.
    img_env = InsertDebugImagesEnv(
        img_env,
        obj1_sweep_renderers,
        compute_shared_data=obj1_sweeper,
    )
    img_env = InsertDebugImagesEnv(
        img_env,
        obj0_sweep_renderers,
        compute_shared_data=obj0_sweeper,
    )
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key_for_rl,
        update_env_info_fn=delete_info,
    )
Example #4
    def setup_env(state_env, encoder, reward_fn):
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(
                goal_distribution, num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=latent_dim,
            distribution_type='one_hot_masks',
        )

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution
Example #5
def add_images(env, context_distribution):
    state_env = env.env
    img_env = InsertImageEnv(
        state_env,
        renderer=renderer,
        image_key='image_observation',
    )
    return ContextualEnv(
        img_env,
        context_distribution=context_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
Example #6
def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                         renderer):
    state_env = get_gym_env(env_id,
                            env_class=env_class,
                            env_kwargs=env_kwargs)
    state_env.goal_sampling_mode = goal_sampling_mode
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        state_env,
        desired_goal_keys=[state_desired_goal_key],
    )
    state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
        state_env.goal_conditioned_diagnostics,
        desired_goal_key=state_desired_goal_key,
        observation_key=state_observation_key,
    )
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_goal_distribution,
        image_goal_key=img_desired_goal_key,
        renderer=renderer,
    )
    goal_distribution = PresampledDistribution(image_goal_distribution,
                                               5000)
    img_env = InsertImageEnv(state_env, renderer=renderer)
    # Reward either comes from the underlying state-based multitask env or is a
    # negative L2 distance between the image observation and the image goal.
    if reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key=state_achieved_goal_key,
        )
    elif reward_type == 'pixel_distance':
        reward_fn = NegativeL2Distance(
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                img_observation_key),
            desired_goal_key=img_desired_goal_key,
        )
    else:
        raise ValueError(reward_type)
    env = ContextualEnv(
        img_env,
        context_distribution=goal_distribution,
        reward_fn=reward_fn,
        observation_key=img_observation_key,
        contextual_diagnostics_fns=[state_diag_fn],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution, reward_fn
Example #7
def add_images(env, state_distribution):
    state_env = env.env
    image_goal_distribution = AddImageDistribution(
        env=state_env,
        base_distribution=state_distribution,
        image_goal_key='image_desired_goal',
        renderer=renderer,
    )
    img_env = InsertImageEnv(state_env, renderer=renderer)
    return ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=eval_reward,
        observation_key=observation_key,
        update_env_info_fn=delete_info,
    )
Example #8
def get_video_func(
    env,
    policy,
    tag,
):
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
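A short usage sketch mirroring how this helper is wired up elsewhere in these examples: the returned functions are registered as post-train hooks on the RL algorithm. Here expl_env, eval_env, policy, expl_policy, and algorithm are assumed to come from the surrounding experiment setup.

eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), "eval")
expl_video_func = get_video_func(expl_env, expl_policy, "expl")
algorithm.post_train_funcs.append(eval_video_func)
algorithm.post_train_funcs.append(expl_video_func)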
Example #9
def image_based_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    cnn_kwargs,
    policy_type='tanh-normal',
    env_id=None,
    env_class=None,
    env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    reward_type='state_distance',
    env_renderer_kwargs=None,
    # Data augmentations
    apply_random_crops=False,
    random_crop_pixel_shift=4,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
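    # HER-style relabeling: when contexts are resampled from stored observations,
    # the image/state observations are reused as the corresponding desired goals.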
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal':
        'image_observation',
        'state_desired_goal':
        'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                             renderer):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        goal_distribution = PresampledDistribution(image_goal_distribution,
                                                   5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)

    action_dim = expl_env.action_space.low.size
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

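    # Each Q-function applies a CNN jointly to the observation and goal images,
    # flattens the features, and concatenates the action before the final MLP.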
    def create_qf():
        cnn = BasicCNN(input_width=img_width,
                       input_height=img_height,
                       input_channels=img_num_channels,
                       **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn), basic.FlattenEachParallel(),
            ConcatMlp(input_size=joint_cnn.output_size + action_dim,
                      output_size=1,
                      **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    cnn = BasicCNN(input_width=img_width,
                   input_height=img_height,
                   input_channels=img_num_channels,
                   **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           output_activations=['tanh', 'identity'],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

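    # Optional image augmentation: pad the image observations and goals, then apply
    # a joint random crop so each observation stays aligned with its goal image.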
    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(next_obs_padded,
                                                  context_padded)

            batch['observations'] = np.concatenate([obs_aug, context_aug],
                                                   axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
            return batch

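    # The replay buffer stores both image and state observations and relabels the
    # goal context keys, either from the goal distribution or from stored observations.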
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(eval_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_expl_env = InsertImageEnv(expl_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
Example #10
def get_img_env(env):
    renderer = EnvRenderer(**variant["renderer_kwargs"])
    img_env = InsertImageEnv(GymToMultiEnv(env), renderer=renderer)
    return img_env
Example #11
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
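        # Goals are sampled either from the VAE prior over latents, or from the
        # environment's own reset goals, which are rendered and encoded into latents.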
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn
Example #12
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        # encoded_env = EncoderWrappedEnv(
        #     img_env,
        #     model,
        #     dict(image_observation="latent_observation", ),
        # )
        # if goal_sampling_mode == "vae_prior":
        #     latent_goal_distribution = PriorDistribution(
        #         model.representation_size,
        #         desired_goal_key,
        #     )
        #     diagnostics = StateImageGoalDiagnosticsFn({}, )
        # elif goal_sampling_mode == "presampled":
        #     diagnostics = state_env.get_contextual_diagnostics
        #     image_goal_distribution = PresampledPathDistribution(
        #         presampled_goals_path,
        #     )

        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        # elif goal_sampling_mode == "reset_of_env":
        #     state_goal_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        #     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        #         state_goal_env,
        #         desired_goal_keys=[state_goal_key],
        #     )
        #     image_goal_distribution = AddImageDistribution(
        #         env=state_env,
        #         base_distribution=state_goal_distribution,
        #         image_goal_key=image_goal_key,
        #         renderer=renderer,
        #     )
        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        #     no_goal_distribution = PriorDistribution(
        #         representation_size=0,
        #         key="no_goal",
        #     )
        #     diagnostics = state_goal_env.get_contextual_diagnostics
        # else:
        #     error
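        # This variant skips goal conditioning: a zero-dimensional "no_goal" context
        # is paired with a grasping-specific reward function.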
        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )

        reward_fn = GraspingRewardFn(
            # img_env, # state_env,
            # observation_key=observation_key,
            # desired_goal_key=desired_goal_key,
            # **reward_kwargs
        )

        env = ContextualEnv(
            img_env,  # state_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn
Example #13
def experiment(variant):
    render = variant.get("render", False)
    debug = variant.get("debug", False)
    vae_path = variant.get("vae_path", False)

    process_args(variant)

    env_class = variant.get("env_class")
    env_kwargs = variant.get("env_kwargs")
    env_id = variant.get("env_id")
    # expl_env = env_class(**env_kwargs)
    # eval_env = env_class(**env_kwargs)
    expl_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    eval_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env = eval_env

    if variant.get('sparse_reward', False):
        expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward)
        eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward)

    if variant.get("vae_path", False):
        vae = load_local_or_remote_file(vae_path)
        variant['path_loader_kwargs']['model_path'] = vae_path
        renderer = EnvRenderer(**variant.get("renderer_kwargs", {}))
        expl_env = VQVAEWrappedEnv(InsertImageEnv(expl_env, renderer=renderer),
                                   vae,
                                   reward_params=variant.get(
                                       "reward_params", {}),
                                   **variant.get('vae_wrapped_env_kwargs', {}))
        eval_env = VQVAEWrappedEnv(InsertImageEnv(eval_env, renderer=renderer),
                                   vae,
                                   reward_params=variant.get(
                                       "reward_params", {}),
                                   **variant.get('vae_wrapped_env_kwargs', {}))
        env = eval_env
        variant['path_loader_kwargs']['env'] = env

    if variant.get('add_env_demos', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])

    if variant.get('add_env_offpolicy_data', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env

    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')

    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = eval_env.action_space.low.size

    if hasattr(expl_env, 'info_sizes'):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()

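    # The relabeling buffers are wrapped so that resampled goals get concatenated
    # onto the observation whenever a batch is sampled.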
    replay_buffer_kwargs = dict(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    replay_buffer_kwargs.update(variant.get('replay_buffer_kwargs', dict()))
    replay_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
    replay_buffer_kwargs.update(
        variant.get('demo_replay_buffer_kwargs', dict()))
    demo_train_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
    demo_test_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )

    M = variant['layer_size']
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'],
    )

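    # Optionally make the exploration policy deterministic or wrap it with action
    # noise (OU or Gaussian-epsilon), depending on exploration_kwargs.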
    expl_policy = policy
    exploration_kwargs = variant.get('exploration_kwargs', {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)

        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == 'ou':
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == 'gauss_eps':
            es = GaussianAndEpislonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            raise ValueError(exploration_strategy)

    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **variant['trainer_kwargs'])
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        render=render,
        goal_sampling_mode=variant.get("goal_sampling_mode", None),
    )
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    else:
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    # if variant.get("save_video", False):
    #     from rlkit.visualization.video import VideoSaveFunction
    #     renderer_kwargs = variant.get("renderer_kwargs", {})
    #     save_video_kwargs = variant.get("save_video_kwargs", {})

    #     def get_video_func(
    #         env,
    #         policy,
    #         tag,
    #     ):
    #         renderer = EnvRenderer(**renderer_kwargs)
    #         state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
    #             env,
    #             desired_goal_keys=[desired_goal_key],
    #         )
    #         image_goal_distribution = AddImageDistribution(
    #             env=env,
    #             base_distribution=state_goal_distribution,
    #             image_goal_key='image_desired_goal',
    #             renderer=renderer,
    #         )
    #         img_env = InsertImageEnv(env, renderer=renderer)
    #         rollout_function = partial(
    #             rf.multitask_rollout,
    #             max_path_length=variant['max_path_length'],
    #             observation_key=observation_key,
    #             desired_goal_key=desired_goal_key,
    #             return_dict_obs=True,
    #         )
    #         reward_fn = ContextualRewardFnFromMultitaskEnv(
    #             env=env,
    #             achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key="state_achieved_goal",
    #         )
    #         contextual_env = ContextualEnv(
    #             img_env,
    #             context_distribution=image_goal_distribution,
    #             reward_fn=reward_fn,
    #             observation_key=observation_key,
    #         )
    #         video_func = get_save_video_function(
    #             rollout_function,
    #             contextual_env,
    #             policy,
    #             tag=tag,
    #             imsize=renderer.width,
    #             image_format='CWH',
    #             **save_video_kwargs
    #         )
    #         return video_func
    # expl_video_func = get_video_func(expl_env, expl_policy, "expl")
    # eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), "eval")
    # algorithm.post_train_funcs.append(eval_video_func)
    # algorithm.post_train_funcs.append(expl_video_func)

    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)

    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(trainer,
                                        replay_buffer=replay_buffer,
                                        demo_train_buffer=demo_train_buffer,
                                        demo_test_buffer=demo_test_buffer,
                                        **path_loader_kwargs)
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()

    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))

    algorithm.train()
Example #14
def contextual_env_distrib_and_reward(
    vae,
    sets: typing.List[Set],
    state_env,
    renderer,
    reward_fn_kwargs,
    use_ground_truth_reward,
    state_observation_key,
    latent_observation_key,
    example_image_key,
    set_description_key,
    observation_key,
    image_observation_key,
    rig_goal_setter_kwargs,
    oracle_rig_goal=False,
):
    img_env = InsertImageEnv(state_env, renderer=renderer)
    encoded_env = EncoderWrappedEnv(
        img_env,
        vae,
        step_keys_map={image_observation_key: latent_observation_key},
    )
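    # The goal context comes either from an oracle RIG-style goal setter or from a
    # distribution over latent set parameters (mean and covariance) built from the sets.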
    if oracle_rig_goal:
        context_env_class = InitStateConditionedContextualEnv
        goal_distribution_params_distribution = (OracleRIGMeanSetter(
            sets,
            vae,
            example_image_key,
            env=state_env,
            renderer=renderer,
            cycle_for_batch_size_1=True,
            **rig_goal_setter_kwargs))
    else:
        context_env_class = ContextualEnv
        goal_distribution_params_distribution = (
            LatentGoalDictDistributionFromSet(
                sets,
                vae,
                example_image_key,
                cycle_for_batch_size_1=True,
            ))
    if use_ground_truth_reward:
        reward_fn, unbatched_reward_fn = create_ground_truth_set_rewards_fns(
            sets,
            goal_distribution_params_distribution.set_index_key,
            state_observation_key,
        )
    else:
        reward_fn, unbatched_reward_fn = create_normal_likelihood_reward_fns(
            latent_observation_key,
            goal_distribution_params_distribution.mean_key,
            goal_distribution_params_distribution.covariance_key,
            reward_fn_kwargs,
        )
    set_diagnostics = SetDiagnostics(
        set_description_key=set_description_key,
        set_index_key=goal_distribution_params_distribution.set_index_key,
        observation_key=state_observation_key,
    )
    env = context_env_class(
        encoded_env,
        context_distribution=goal_distribution_params_distribution,
        reward_fn=reward_fn,
        unbatched_reward_fn=unbatched_reward_fn,
        observation_key=observation_key,
        contextual_diagnostics_fns=[
            # goal_diagnostics,
            set_diagnostics,
        ],
        update_env_info_fn=delete_info,
    )
    return env, goal_distribution_params_distribution, reward_fn