示例#1
0
 def contextual_env_distrib_and_reward(
         env_id, env_class, env_kwargs, goal_sampling_mode
 ):
     """Build a goal-conditioned ContextualEnv with its goal distribution and reward.

     Returns a tuple ``(env, goal_distribution, reward_fn)`` where ``env`` is
     the wrapped ContextualEnv ready for rollouts.
     """
     base_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
     base_env.goal_sampling_mode = goal_sampling_mode
     goal_dist = GoalDictDistributionFromMultitaskEnv(
         base_env,
         desired_goal_keys=[desired_goal_key],
     )
     diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
         base_env.goal_conditioned_diagnostics,
         desired_goal_key=desired_goal_key,
         observation_key=observation_key,
     )
     reward = ContextualRewardFnFromMultitaskEnv(
         env=base_env,
         achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
         desired_goal_key=desired_goal_key,
         achieved_goal_key=achieved_goal_key,
     )
     wrapped_env = ContextualEnv(
         base_env,
         context_distribution=goal_dist,
         reward_fn=reward,
         observation_key=observation_key,
         contextual_diagnostics_fns=[diagnostics],
         update_env_info_fn=delete_info,
     )
     return wrapped_env, goal_dist, reward
示例#2
0
 def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                          renderer):
     """Wrap a state-based multitask env into an image-observation ContextualEnv.

     Image goals are rendered from presampled state goals. Returns
     ``(env, goal_distribution, reward_fn)``.
     """
     raw_env = get_gym_env(
         env_id, env_class=env_class, env_kwargs=env_kwargs)
     raw_env.goal_sampling_mode = goal_sampling_mode
     state_goal_dist = GoalDictDistributionFromMultitaskEnv(
         raw_env,
         desired_goal_keys=[state_desired_goal_key],
     )
     diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
         raw_env.goal_conditioned_diagnostics,
         desired_goal_key=state_desired_goal_key,
         observation_key=state_observation_key,
     )
     image_goal_dist = AddImageDistribution(
         env=raw_env,
         base_distribution=state_goal_dist,
         image_goal_key=img_desired_goal_key,
         renderer=renderer,
     )
     goal_dist = PresampledDistribution(image_goal_dist, 5000)
     img_env = InsertImageEnv(raw_env, renderer=renderer)
     # NOTE(review): `reward_type` is resolved from an enclosing scope, not a
     # parameter — confirm it is defined at module level.
     if reward_type == 'state_distance':
         reward_fn = ContextualRewardFnFromMultitaskEnv(
             env=raw_env,
             achieved_goal_from_observation=IndexIntoAchievedGoal(
                 'state_observation'),
             desired_goal_key=state_desired_goal_key,
             achieved_goal_key=state_achieved_goal_key,
         )
     elif reward_type == 'pixel_distance':
         reward_fn = NegativeL2Distance(
             achieved_goal_from_observation=IndexIntoAchievedGoal(
                 img_observation_key),
             desired_goal_key=img_desired_goal_key,
         )
     else:
         raise ValueError(reward_type)
     final_env = ContextualEnv(
         img_env,
         context_distribution=goal_dist,
         reward_fn=reward_fn,
         observation_key=img_observation_key,
         contextual_diagnostics_fns=[diag_fn],
         update_env_info_fn=delete_info,
     )
     return final_env, goal_dist, reward_fn
示例#3
0
def create_sets(env_id,
                env_class,
                env_kwargs,
                renderer,
                saved_filename=None,
                save_to_filename=None,
                **kwargs):
    """Load example sets from a file, or sample them fresh from the env.

    When sampling, only PickAndPlaceEnv is supported. If ``save_to_filename``
    is truthy, the resulting sets are persisted before being returned.
    """
    if saved_filename is None:
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        if not isinstance(env, PickAndPlaceEnv):
            raise NotImplementedError()
        sets = sample_pnp_sets(env, renderer, **kwargs)
    else:
        sets = asset_loader.load_local_or_remote_file(saved_filename)
    if save_to_filename:
        save(sets, save_to_filename)
    return sets
示例#4
0
 def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                       goal_sampling_mode):
     """Build a ContextualEnv around a gym GoalEnv with a thresholded reward.

     The reward is 0/1-style success based on L2 distance between achieved
     and desired goals, using a per-env success threshold.

     Returns ``(env, goal_distribution, reward_fn)``.

     Raises:
         TypeError: if the env is not one of the known Fetch envs, since its
             success threshold is unknown.
     """
     env = get_gym_env(
         env_id,
         env_class=env_class,
         env_kwargs=env_kwargs,
         unwrap_timed_envs=True,
     )
     env.goal_sampling_mode = goal_sampling_mode
     goal_distribution = GoalDictDistributionFromGymGoalEnv(
         env,
         desired_goal_key=desired_goal_key,
     )
     distance_fn = L2Distance(
         achieved_goal_from_observation=IndexIntoAchievedGoal(
             achieved_goal_key),
         desired_goal_key=desired_goal_key,
     )
     # All supported Fetch envs share the same 0.05 success threshold.
     if isinstance(env, (robotics.FetchReachEnv,
                         robotics.FetchPushEnv,
                         robotics.FetchPickAndPlaceEnv,
                         robotics.FetchSlideEnv)):
         success_threshold = 0.05
     else:
         raise TypeError(
             "I don't know the success threshold of env {}".format(env))
     reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
     diag_fn = GenericGoalConditionedContextualDiagnostics(
         desired_goal_key=desired_goal_key,
         achieved_goal_key=achieved_goal_key,
         success_threshold=success_threshold,
     )
     env = ContextualEnv(
         env,
         context_distribution=goal_distribution,
         reward_fn=reward_fn,
         observation_key=observation_key,
         contextual_diagnostics_fns=[diag_fn],
         update_env_info_fn=delete_info,
     )
     return env, goal_distribution, reward_fn
示例#5
0
def create_dataset(env_id,
                   env_class,
                   env_kwargs,
                   generate_set_kwargs,
                   num_ungrouped_images,
                   env=None,
                   renderer=None,
                   sets=None):
    """Generate train/eval set images and ungrouped images for an env.

    If ``sets`` is given, train set images are pulled from each set's
    ``example_dict['example_image']``; otherwise they are generated from the
    env. All image arrays are converted to torch tensors via ``ptu``.

    Returns ``(eval_set_imgs, renderer, set_imgs, set_imgs_iterator,
    ungrouped_imgs)``.
    """
    # Local import kept local: module-level imports are outside this view.
    import time

    env = env or get_gym_env(env_id, env_class, env_kwargs)
    renderer = renderer or EnvRenderer(output_image_format='CHW')
    print("making train set")
    start = time.time()
    if sets is None:
        set_imgs = list(pnp_util.generate_set_images(
            env, renderer, **generate_set_kwargs))
    else:
        # `example_set` avoids shadowing the builtin `set`.
        set_imgs = np.array(
            [example_set.example_dict['example_image']
             for example_set in sets])
    set_imgs = ptu.from_numpy(np.array(set_imgs))
    print("making eval set", time.time() - start)
    start = time.time()
    eval_set_imgs = pnp_util.generate_set_images(env, renderer,
                                                 **generate_set_kwargs)
    eval_set_imgs = ptu.from_numpy(np.array(list(eval_set_imgs)))
    set_imgs_iterator = set_imgs  # an array is already a valid data iterator
    print("making ungrouped images", time.time() - start)
    start = time.time()
    ungrouped_imgs = generate_images(env,
                                     renderer,
                                     num_images=num_ungrouped_images)
    ungrouped_imgs = ptu.from_numpy(np.array(list(ungrouped_imgs)))
    print("done", time.time() - start)
    return eval_set_imgs, renderer, set_imgs, set_imgs_iterator, ungrouped_imgs
示例#6
0
def encoder_goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        algo_kwargs,
        policy_kwargs,
        # Encoder parameters
        disentangled_qf_kwargs,
        encoder_kwargs=None,
        encoder_cnn_kwargs=None,
        qf_state_encoder_is_goal_encoder=False,
        reward_type='encoder_distance',
        reward_config=None,
        latent_dim=8,
        # Policy params
        policy_using_encoder_settings=None,
        use_separate_encoder_for_policy=True,
        # Env settings
        env_id=None,
        env_class=None,
        env_kwargs=None,
        contextual_env_kwargs=None,
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        num_presampled_goals=5000,
        # Image env parameters
        use_image_observations=False,
        env_renderer_kwargs=None,
        # Video parameters
        save_video=True,
        save_debug_video=True,
        save_video_kwargs=None,
        video_renderer_kwargs=None,
        # Debugging parameters
        visualize_representation=True,
        distance_scatterplot_save_period=0,
        distance_scatterplot_initial_save_period=0,
        debug_renderer_kwargs=None,
        debug_visualization_kwargs=None,
        use_debug_trainer=False,

        # vae stuff
        train_encoder_as_vae=False,
        vae_trainer_kwargs=None,
        decoder_kwargs=None,
        vae_to_sac_loss_scale=1.0,

        mask_key='mask_desired_goal',
        num_steps_per_mask_change=10,
):
    """Run goal-conditioned SAC with a learned encoder and masked latent goals.

    Builds exploration/evaluation ContextualEnvs whose goals are encoded into
    a ``latent_dim``-dimensional space and combined with rotating one-hot
    masks, constructs disentangled Q-functions and a Tanh-Gaussian policy on
    top of (possibly shared) encoder networks, optionally trains the encoder
    jointly as a VAE, attaches video/visualization post-train hooks, and
    finally calls ``algorithm.train()``.
    """
    # Default every optional kwargs dict to a fresh empty dict.
    if reward_config is None:
        reward_config = {}
    if encoder_cnn_kwargs is None:
        encoder_cnn_kwargs = {}
    if policy_using_encoder_settings is None:
        policy_using_encoder_settings = {}
    if debug_visualization_kwargs is None:
        debug_visualization_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if contextual_env_kwargs is None:
        contextual_env_kwargs = {}
    if encoder_kwargs is None:
        encoder_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if video_renderer_kwargs is None:
        video_renderer_kwargs = {}
    if debug_renderer_kwargs is None:
        debug_renderer_kwargs = {}

    # Observation/goal dictionary keys used throughout the experiment.
    img_observation_key = 'image_observation'
    state_observation_key = 'state_observation'
    latent_desired_goal_key = 'latent_desired_goal'
    state_desired_goal_key = 'state_desired_goal'
    img_desired_goal_key = 'image_desired_goal'


    if use_image_observations:
        env_renderer = Renderer(**env_renderer_kwargs)

    def setup_env(state_env, encoder, reward_fn):
        """Wrap a state env in a ContextualEnv with encoded, masked goals."""
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            # Render image goals from presampled state goals, then encode
            # the image goal into the latent space.
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(
                goal_distribution, num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            # State-based: encode the state goal directly.
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        # Add a one-hot mask of size latent_dim to every sampled goal.
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=latent_dim,
            distribution_type='one_hot_masks',
        )

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution

    # Raw state-based envs for exploration and evaluation.
    state_expl_env = get_gym_env(env_id, env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_expl_env.goal_sampling_mode = exploration_goal_sampling_mode
    state_eval_env = get_gym_env(env_id, env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    if use_image_observations:
        context_keys_to_save = [
            state_desired_goal_key,
            img_desired_goal_key,
            latent_desired_goal_key,
            mask_key,
        ]
        context_key_for_rl = img_desired_goal_key
        observation_key_for_rl = img_observation_key

        def create_encoder():
            """CNN + multi-headed MLP encoder over rendered images.

            The two output heads each have size latent_dim — presumably a
            distribution's mean and (log-)variance; TODO confirm against
            EncoderMuFromEncoderDistribution.
            """
            img_num_channels, img_height, img_width = env_renderer.image_chw
            cnn = BasicCNN(
                input_width=img_width,
                input_height=img_height,
                input_channels=img_num_channels,
                **encoder_cnn_kwargs
            )
            cnn_output_size = np.prod(cnn.output_shape)
            mlp = MultiHeadedMlp(
                input_size=cnn_output_size,
                output_sizes=[latent_dim, latent_dim],
                **encoder_kwargs)
            enc = nn.Sequential(cnn, Flatten(), mlp)
            enc.input_size = img_width * img_height * img_num_channels
            enc.output_size = latent_dim
            return enc
    else:
        context_keys_to_save = [state_desired_goal_key,
                                latent_desired_goal_key,
                                mask_key]
        context_key_for_rl = state_desired_goal_key
        observation_key_for_rl = state_observation_key

        def create_encoder():
            """Multi-headed MLP encoder over raw state observations."""
            in_dim = (
                state_expl_env.observation_space.spaces[state_observation_key].low.size
            )
            enc = ConcatMultiHeadedMlp(
                input_size=in_dim,
                output_sizes=[latent_dim, latent_dim],
                **encoder_kwargs
            )
            enc.input_size = in_dim
            enc.output_size = latent_dim
            return enc

    # Online and target encoder networks; mu_* variants extract the mean head.
    encoder_net = create_encoder()
    mu_encoder_net = EncoderMuFromEncoderDistribution(encoder_net)
    target_encoder_net = create_encoder()
    mu_target_encoder_net = EncoderMuFromEncoderDistribution(target_encoder_net)
    encoder_input_dim = encoder_net.input_size

    encoder = EncoderFromNetwork(mu_encoder_net)
    encoder.to(ptu.device)
    if reward_type == 'encoder_distance':
        reward_fn = MaskedEncoderRewardFnFromMultitaskEnv(
            encoder=encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            mask_key=mask_key,
            **reward_config,
        )
    elif reward_type == 'target_encoder_distance':
        # Reward computed with the target (slow-moving) encoder instead.
        target_encoder = EncoderFromNetwork(mu_target_encoder_net)
        reward_fn = MaskedEncoderRewardFnFromMultitaskEnv(
            encoder=target_encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            mask_key=mask_key,
            **reward_config,
        )
    else:
        raise ValueError("invalid reward type {}".format(reward_type))
    expl_env, expl_context_distrib = setup_env(state_expl_env, encoder, reward_fn)
    eval_env, eval_context_distrib = setup_env(state_eval_env, encoder, reward_fn)

    action_dim = expl_env.action_space.low.size

    # Identity "encoder" so the mask half of a Split input passes through
    # unchanged alongside the encoded goal.
    mask_encoder = nn.Identity()
    mask_encoder.input_size = latent_dim
    mask_encoder.output_size = latent_dim

    def make_qf(goal_encoder):
        """Build a DisentangledMlpQf whose goal encoder also consumes the mask."""
        if qf_state_encoder_is_goal_encoder:
            state_encoder = goal_encoder
        else:
            state_encoder = EncoderMuFromEncoderDistribution(create_encoder())
        # Append mask to the goal encoder.
        goal_encoder_with_mask = Split(
            module1=goal_encoder,
            module2=mask_encoder,
            split_idx=goal_encoder.input_size,
        )
        goal_encoder_with_mask.input_size = encoder_input_dim + latent_dim
        goal_encoder_with_mask.output_size = 2 * latent_dim

        goal_encoder = nn.Sequential(goal_encoder_with_mask, ConcatTuple())
        goal_encoder.input_size = goal_encoder_with_mask.input_size
        goal_encoder.output_size = goal_encoder_with_mask.output_size

        return DisentangledMlpQf(
            goal_encoder=goal_encoder,
            state_encoder=state_encoder,
            preprocess_obs_dim=encoder_input_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            vectorized=True,
            num_heads=latent_dim,
            **disentangled_qf_kwargs
        )
    # Online Q-functions share the online encoder; targets use the target encoder.
    qf1 = make_qf(mu_encoder_net)
    qf2 = make_qf(mu_encoder_net)
    target_qf1 = make_qf(mu_target_encoder_net)
    target_qf2 = make_qf(mu_target_encoder_net)

    if use_separate_encoder_for_policy:
        # Policy gets its own freshly-created encoder (no gradient sharing
        # with the Q-function encoder).
        policy_encoder = EncoderMuFromEncoderDistribution(create_encoder())
        policy_encoder_net = EncodeObsAndGoal(
            policy_encoder,
            encoder_input_dim,
            encode_state=True,
            encode_goal=True,
            detach_encoder_via_goal=False,
            detach_encoder_via_state=False,
        )
    else:
        policy_encoder_net = EncodeObsAndGoal(
            mu_encoder_net,
            encoder_input_dim,
            **policy_using_encoder_settings
        )
    policy_encoder_net = nn.Sequential(policy_encoder_net, ConcatTuple())
    policy_encoder_net.input_size = 2 * encoder_input_dim
    policy_encoder_net.output_size = 2 * latent_dim
    # Split policy input into (obs+goal | mask); mask passes through Identity.
    policy_encoder_net = Split(
        module1=policy_encoder_net,
        module2=mask_encoder,
        split_idx=policy_encoder_net.input_size,
    )
    policy_encoder_net.input_size = 2 * encoder_input_dim + latent_dim
    policy_encoder_net.output_size = 3 * latent_dim

    obs_processor = nn.Sequential(
        policy_encoder_net,
        ConcatTuple(),
        MultiHeadedMlp(
            input_size=policy_encoder_net.output_size,
            output_sizes=[action_dim, action_dim],
            **policy_kwargs
        )
    )
    policy = PolicyFromDistributionGenerator(
        TanhGaussian(obs_processor)
    )

    def concat_context_to_obs(batch, *args, **kwargs):
        """Flatten obs/next_obs and append the RL context and mask columns."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        obs = obs.reshape(len(obs), -1)
        next_obs = next_obs.reshape(len(next_obs), -1)
        context = batch[context_key_for_rl]
        context = context.reshape(len(context), -1)
        mask = batch[mask_key]
        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        batch['raw_next_observations'] = next_obs
        return batch

    if use_image_observations:
        # Do this so that the context has all two/three: the state, image, and
        # encoded goal
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
                img_desired_goal_key: img_observation_key,
                mask_key: mask_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key,
                              mask_key],
            ),
        )
        ob_keys_to_save_in_buffer = [state_observation_key, img_observation_key,
                                     mask_key]
    else:
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
                mask_key: mask_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key, mask_key],
            ),
        )
        ob_keys_to_save_in_buffer = [state_observation_key, mask_key]

    # Replay buffer relabels goals by re-encoding achieved states; rewards are
    # vectorized (one per latent dimension), hence reward_dim below.
    encoder_output_dim = mu_encoder_net.output_size
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys_to_save,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_observation,
        observation_keys=ob_keys_to_save_in_buffer,
        observation_key=observation_key_for_rl,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        reward_dim=encoder_output_dim,
        **replay_buffer_kwargs
    )

    disentangled_trainer = DisentangedTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    if train_encoder_as_vae:
        if vae_trainer_kwargs is None:
            vae_trainer_kwargs = {}
        if decoder_kwargs is None:
            decoder_kwargs = {}

        # VAE training
        def make_decoder():
            """Build the decoder; only the state (non-image) case is supported."""
            if use_image_observations:
                raise NotImplementedError
            else:
                return Mlp(
                    input_size=latent_dim,
                    output_size=encoder_input_dim,
                    **decoder_kwargs
                )

        decoder_net = make_decoder()
        vae = VAE(encoder_net, decoder_net)

        vae_trainer = VAETrainer(
            vae=vae,
            **vae_trainer_kwargs
        )
        # Joint loss: VAE loss (scaled) + SAC losses share optimizer steps.
        trainers = OrderedDict()
        trainers['vae_trainer'] = vae_trainer
        trainers['disentangled_trainer'] = disentangled_trainer
        trainer = JointLossTrainer(
            trainers,
            optimizers=[
                disentangled_trainer.qf1_optimizer,
                disentangled_trainer.qf2_optimizer,
                disentangled_trainer.alpha_optimizer,
                disentangled_trainer.policy_optimizer,
                vae_trainer.vae_optimizer,
            ],
            trainer_loss_scales={
                vae_trainer: vae_to_sac_loss_scale,
                disentangled_trainer: 1,
            }
        )
    else:
        trainer = disentangled_trainer

    if not use_image_observations and use_debug_trainer:
        # TODO: implement this for images
        debug_trainer = DebugTrainer(
            observation_space=expl_env.observation_space.spaces[
                state_observation_key
            ],
            encoder=mu_encoder_net,
            encoder_output_dim=encoder_output_dim,
        )
        trainer = JointTrainer([trainer, debug_trainer])

    # Path collectors rotate the one-hot mask every num_steps_per_mask_change steps.
    eval_path_collector = RotatingMaskingPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl, mask_key],
        mask_key=mask_key,
        mask_length=latent_dim,
        num_steps_per_mask_change=num_steps_per_mask_change,
    )

    exploration_policy = create_exploration_policy(
        policy, **exploration_policy_kwargs)
    expl_path_collector = RotatingMaskingPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl, mask_key],
        mask_key=mask_key,
        mask_length=latent_dim,
        num_steps_per_mask_change=num_steps_per_mask_change,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    video_renderer = Renderer(**video_renderer_kwargs)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key_for_rl,
            context_keys_for_policy=[context_key_for_rl],
        )
        if save_debug_video and not use_image_observations:
            # TODO: add visualization for image-based envs
            # One debug renderer per latent dimension, sweeping each object.
            obj1_sweep_renderers = {
                'sweep_obj1_%d' % i: DebugRenderer(
                    encoder, i, **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }
            obj0_sweep_renderers = {
                'sweep_obj0_%d' % i: DebugRenderer(
                    encoder, i, **debug_renderer_kwargs)
                for i in range(encoder_output_dim)

            }

            debugger_one = DebugRenderer(encoder, 0, **debug_renderer_kwargs)

            # Grid of (x, y) points covering the observation space, used to
            # sweep one object's position while the other is held fixed.
            low = eval_env.env.observation_space[state_observation_key].low.min()
            high = eval_env.env.observation_space[state_observation_key].high.max()
            y = np.linspace(low, high, num=debugger_one.image_shape[0])
            x = np.linspace(low, high, num=debugger_one.image_shape[1])
            cross = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])

            def create_shared_data_creator(obj_index):
                """Return a fn encoding states with object `obj_index` swept over the grid.

                NOTE(review): assumes the state layout is
                [obj0_xy (2), obj1_xy (2)] — confirm against the env.
                """
                def compute_shared_data(raw_obs, env):
                    state = raw_obs[observation_key_for_rl]
                    obs = state[:2]
                    goal = state[2:]
                    if obj_index == 0:
                        new_states = np.concatenate(
                            [
                                np.repeat(obs[None, :], cross.shape[0], axis=0),
                                cross,
                            ],
                            axis=1,
                        )
                    elif obj_index == 1:
                        new_states = np.concatenate(
                            [
                                cross,
                                np.repeat(goal[None, :], cross.shape[0], axis=0),
                            ],
                            axis=1,
                        )
                    else:
                        raise ValueError(obj_index)
                    return encoder.encode(new_states)
                return compute_shared_data
            obj0_sweeper = create_shared_data_creator(0)
            obj1_sweeper = create_shared_data_creator(1)

            def add_images(env, base_distribution):
                """Wrap env with image + debug-sweep rendering for video saving."""
                if use_image_observations:
                    img_env = env
                    image_goal_distribution = base_distribution
                else:
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                    img_env = InsertImageEnv(state_env, renderer=video_renderer)
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj1_sweep_renderers,
                    compute_shared_data=obj1_sweeper,
                )
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj0_sweep_renderers,
                    compute_shared_data=obj0_sweeper,
                )
                return ContextualEnv(
                    img_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            def get_extra_imgs(
                    path,
                    index_in_path,
                    env,
            ):
                """Collect the per-step debug sweep images for the video."""
                return [
                    path['full_observations'][index_in_path][key]
                    for key in obj1_sweep_renderers
                ] + [
                    path['full_observations'][index_in_path][key]
                    for key in obj0_sweep_renderers
                ]
            img_formats = [video_renderer.output_image_format]
            for r in obj1_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            for r in obj0_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs
            )
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs
            )
        else:
            video_renderer = Renderer(**video_renderer_kwargs)

            def add_images(env, base_distribution):
                """Wrap env with plain image rendering (no debug sweeps)."""
                if use_image_observations:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='video_observation',
                    )
                    image_goal_distribution = base_distribution
                else:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='image_observation',
                    )
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                return ContextualEnv(
                    video_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            if use_image_observations:
                keys_to_show = [
                    'image_desired_goal',
                    'image_observation',
                    'video_observation',
                ]
                image_formats = [
                    env_renderer.output_image_format,
                    env_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            else:
                keys_to_show = ['image_desired_goal', 'image_observation']
                image_formats = [
                    video_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs
            )
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs
            )

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    if visualize_representation:
        if use_image_observations:
            def state_to_encoder_input(state):
                """Render a given goal state to an image, restoring env state after."""
                goal_dict = {
                    'state_desired_goal': state,
                }
                env_state = state_eval_env.get_env_state()
                state_eval_env.set_to_goal(goal_dict)
                start_img = env_renderer.create_image(state_eval_env)
                state_eval_env.set_env_state(env_state)
                return start_img
            visualize_representation = create_visualize_representation(
                encoder, True, eval_env, video_renderer,
                state_to_encoder_input=state_to_encoder_input,
                env_renderer=env_renderer,
                **debug_visualization_kwargs
            )
            algorithm.post_train_funcs.append(visualize_representation)
            visualize_representation = create_visualize_representation(
                encoder, False, eval_env, video_renderer,
                state_to_encoder_input=state_to_encoder_input,
                env_renderer=env_renderer,
                **debug_visualization_kwargs
            )
            algorithm.post_train_funcs.append(visualize_representation)
        else:
            visualize_representation = create_visualize_representation(
                encoder, True, eval_env, video_renderer,
                **debug_visualization_kwargs
            )
            algorithm.post_train_funcs.append(visualize_representation)
            visualize_representation = create_visualize_representation(
                encoder, False, eval_env, video_renderer,
                **debug_visualization_kwargs
            )
            algorithm.post_train_funcs.append(visualize_representation)

    if distance_scatterplot_save_period > 0:
        algorithm.post_train_funcs.append(create_save_h_vs_state_distance_fn(
            distance_scatterplot_save_period,
            distance_scatterplot_initial_save_period,
            encoder,
            observation_key_for_rl,
        ))
    algorithm.train()
示例#7
0
def encoder_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    policy_kwargs,
    # Encoder parameters
    disentangled_qf_kwargs,
    use_parallel_qf=False,
    encoder_kwargs=None,
    encoder_cnn_kwargs=None,
    qf_state_encoder_is_goal_encoder=False,
    reward_type='encoder_distance',
    reward_config=None,
    latent_dim=8,
    # Policy params
    policy_using_encoder_settings=None,
    use_separate_encoder_for_policy=True,
    # Env settings
    env_id=None,
    env_class=None,
    env_kwargs=None,
    contextual_env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    num_presampled_goals=5000,
    # Image env parameters
    use_image_observations=False,
    env_renderer_kwargs=None,
    # Video parameters
    save_video=True,
    save_debug_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    # Debugging parameters
    visualize_representation=True,
    distance_scatterplot_save_period=0,
    distance_scatterplot_initial_save_period=0,
    debug_renderer_kwargs=None,
    debug_visualization_kwargs=None,
    use_debug_trainer=False,

    # vae stuff
    train_encoder_as_vae=False,
    vae_trainer_kwargs=None,
    decoder_kwargs=None,
    encoder_backprop_settings="rl_n_vae",
    mlp_for_image_decoder=False,
    pretrain_vae=False,
    pretrain_vae_kwargs=None,

    # Should only be set by transfer_encoder_goal_conditioned_sac_experiment
    is_retraining_from_scratch=False,
    train_results=None,
    rl_csv_fname='progress.csv',
):
    """Train goal-conditioned SAC where goals are encoded into a learned latent.

    Builds exploration/evaluation contextual envs whose goal context is the
    output of a learned encoder, wires up disentangled Q-functions and a
    (optionally encoder-sharing) policy, optionally co-trains the encoder as
    a VAE, attaches video/visualization hooks, runs training, and returns a
    ``TrainResults`` namedtuple containing the trained ``encoder_net``.

    The encoder maps either raw state observations or rendered images
    (``use_image_observations``) to a ``latent_dim``-dimensional goal latent.
    ``encoder_backprop_settings`` is matched by substring: 'rl' enables RL
    gradients into the encoder, 'vae' enables VAE gradients.

    NOTE(review): the ``visualize_representation`` parameter is shadowed
    below by local callables of the same name once the debug-visualization
    branch runs; the flag is only read once, so this is harmless but fragile.
    """
    # Replace None defaults with fresh dicts (avoids shared mutable defaults).
    if reward_config is None:
        reward_config = {}
    if encoder_cnn_kwargs is None:
        encoder_cnn_kwargs = {}
    if policy_using_encoder_settings is None:
        policy_using_encoder_settings = {}
    if debug_visualization_kwargs is None:
        debug_visualization_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if contextual_env_kwargs is None:
        contextual_env_kwargs = {}
    if encoder_kwargs is None:
        encoder_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if video_renderer_kwargs is None:
        video_renderer_kwargs = {}
    if debug_renderer_kwargs is None:
        debug_renderer_kwargs = {}

    # Exactly one of the transfer settings may be active: either we are
    # retraining from scratch (and must have train_results) or train_results
    # is absent.  NOTE(review): with the defaults (False, None) this passes.
    assert (is_retraining_from_scratch != (train_results is None))

    # Dict keys used throughout for observations and (state/image/latent) goals.
    img_observation_key = 'image_observation'
    state_observation_key = 'state_observation'
    latent_desired_goal_key = 'latent_desired_goal'
    state_desired_goal_key = 'state_desired_goal'
    img_desired_goal_key = 'image_desired_goal'

    # Substring match: e.g. "rl_n_vae" enables both gradient paths.
    backprop_rl_into_encoder = False
    backprop_vae_into_encoder = False
    if 'rl' in encoder_backprop_settings:
        backprop_rl_into_encoder = True
    if 'vae' in encoder_backprop_settings:
        backprop_vae_into_encoder = True

    if use_image_observations:
        env_renderer = EnvRenderer(**env_renderer_kwargs)

    def setup_env(state_env, encoder, reward_fn):
        """Wrap a raw multitask env into a ContextualEnv whose goal context
        includes an encoder-produced latent goal; returns (env, distribution).
        """
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            # Render image goals, presample a fixed pool, then encode them.
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(goal_distribution,
                                                       num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution

    # Separate underlying envs for exploration and evaluation, with their
    # respective goal-sampling modes.
    state_expl_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_expl_env.goal_sampling_mode = exploration_goal_sampling_mode
    state_eval_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    if use_image_observations:
        context_keys_to_save = [
            state_desired_goal_key,
            img_desired_goal_key,
            latent_desired_goal_key,
        ]
        context_key_for_rl = img_desired_goal_key
        observation_key_for_rl = img_observation_key

        def create_encoder():
            """CNN + multi-headed MLP producing (mu, log-var)-style heads of
            size latent_dim each from a rendered image."""
            img_num_channels, img_height, img_width = env_renderer.image_chw
            cnn = BasicCNN(input_width=img_width,
                           input_height=img_height,
                           input_channels=img_num_channels,
                           **encoder_cnn_kwargs)
            cnn_output_size = np.prod(cnn.output_shape)
            mlp = MultiHeadedMlp(input_size=cnn_output_size,
                                 output_sizes=[latent_dim, latent_dim],
                                 **encoder_kwargs)
            enc = nn.Sequential(cnn, Flatten(), mlp)
            enc.input_size = img_width * img_height * img_num_channels
            enc.output_size = latent_dim
            return enc
    else:
        context_keys_to_save = [
            state_desired_goal_key, latent_desired_goal_key
        ]
        context_key_for_rl = state_desired_goal_key
        observation_key_for_rl = state_observation_key

        def create_encoder():
            """Multi-headed MLP encoder over the flat state observation."""
            in_dim = (state_expl_env.observation_space.
                      spaces[state_observation_key].low.size)
            enc = ConcatMultiHeadedMlp(input_size=in_dim,
                                       output_sizes=[latent_dim, latent_dim],
                                       **encoder_kwargs)
            enc.input_size = in_dim
            enc.output_size = latent_dim
            return enc

    encoder_net = create_encoder()
    if train_results:
        # Transfer setting: reuse the encoder from a previous run.
        print("Using transfer encoder")
        encoder_net = train_results.encoder_net
    # mu_* wrappers expose only the mean head of the distribution encoder.
    mu_encoder_net = EncoderMuFromEncoderDistribution(encoder_net)
    target_encoder_net = create_encoder()
    mu_target_encoder_net = EncoderMuFromEncoderDistribution(
        target_encoder_net)
    encoder_input_dim = encoder_net.input_size

    encoder = EncoderFromNetwork(mu_encoder_net)
    encoder.to(ptu.device)
    if reward_type == 'encoder_distance':
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'target_encoder_distance':
        # Same reward, but measured through the (slower-moving) target encoder.
        target_encoder = EncoderFromNetwork(mu_target_encoder_net)
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=target_encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_expl_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key='state_achieved_goal',
        )
    else:
        raise ValueError("invalid reward type {}".format(reward_type))
    expl_env, expl_context_distrib = setup_env(state_expl_env, encoder,
                                               reward_fn)
    eval_env, eval_context_distrib = setup_env(state_eval_env, encoder,
                                               reward_fn)

    action_dim = expl_env.action_space.low.size

    def make_qf(goal_encoder):
        """Build a disentangled Q-function around the given goal encoder.

        NOTE(review): when ``qf_state_encoder_is_goal_encoder`` is False this
        always raises — the fresh state encoder on the line above the raise is
        dead code.  Confirm this restriction is intentional.
        """
        if not backprop_rl_into_encoder:
            goal_encoder = Detach(goal_encoder)
        if qf_state_encoder_is_goal_encoder:
            state_encoder = goal_encoder
        else:
            state_encoder = EncoderMuFromEncoderDistribution(create_encoder())
            raise RuntimeError(
                "State encoder must be goal encoder for resuming exps")
        if use_parallel_qf:
            return ParallelDisentangledMlpQf(
                goal_encoder=goal_encoder,
                state_encoder=state_encoder,
                preprocess_obs_dim=encoder_input_dim,
                action_dim=action_dim,
                post_encoder_mlp_kwargs=qf_kwargs,
                vectorized=True,
                **disentangled_qf_kwargs)
        else:
            return DisentangledMlpQf(goal_encoder=goal_encoder,
                                     state_encoder=state_encoder,
                                     preprocess_obs_dim=encoder_input_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=True,
                                     **disentangled_qf_kwargs)

    # Live Q-functions share the live encoder; targets share the target one.
    qf1 = make_qf(mu_encoder_net)
    qf2 = make_qf(mu_encoder_net)
    target_qf1 = make_qf(mu_target_encoder_net)
    target_qf2 = make_qf(mu_target_encoder_net)

    if use_separate_encoder_for_policy:
        policy_encoder = EncoderMuFromEncoderDistribution(create_encoder())
        policy_encoder_net = EncodeObsAndGoal(
            policy_encoder,
            encoder_input_dim,
            encode_state=True,
            encode_goal=True,
            detach_encoder_via_goal=False,
            detach_encoder_via_state=False,
        )
    else:
        policy_obs_encoder = mu_encoder_net
        if not backprop_rl_into_encoder:
            policy_obs_encoder = Detach(policy_obs_encoder)
        policy_encoder_net = EncodeObsAndGoal(policy_obs_encoder,
                                              encoder_input_dim,
                                              **policy_using_encoder_settings)
    obs_processor = nn.Sequential(
        policy_encoder_net, ConcatTuple(),
        MultiHeadedMlp(input_size=policy_encoder_net.output_size,
                       output_sizes=[action_dim, action_dim],
                       **policy_kwargs))
    policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    def concat_context_to_obs(batch, *args, **kwargs):
        """Replay post-processing: append the goal context to (next_)obs and
        stash the raw next observation for the trainer."""
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key_for_rl]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        batch['raw_next_observations'] = next_obs
        return batch

    if use_image_observations:
        # Do this so that the context has all two/three: the state, image, and
        # encoded goal
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
                img_desired_goal_key: img_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [
            state_observation_key, img_observation_key
        ]
    else:
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [state_observation_key]

    encoder_output_dim = mu_encoder_net.output_size
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys_to_save,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_observation,
        observation_keys=ob_keys_to_save_in_buffer,
        observation_key=observation_key_for_rl,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        reward_dim=encoder_output_dim,  # vectorized reward: one dim per latent
        **replay_buffer_kwargs)

    disentangled_trainer = DisentangedTrainer(env=expl_env,
                                              policy=policy,
                                              qf1=qf1,
                                              qf2=qf2,
                                              target_qf1=target_qf1,
                                              target_qf2=target_qf2,
                                              **sac_trainer_kwargs)

    if train_encoder_as_vae:
        assert backprop_vae_into_encoder, \
            "No point in training the vae if not backpropagating into encoder"
        if vae_trainer_kwargs is None:
            vae_trainer_kwargs = {}
        if decoder_kwargs is None:
            decoder_kwargs = invert_encoder_mlp_params(encoder_kwargs)

        # VAE training
        def make_decoder():
            """Build a decoder mirroring the encoder architecture (DCNN or
            MLP for images, plain MLP for states)."""
            if use_image_observations:
                img_num_channels, img_height, img_width = env_renderer.image_chw
                if not mlp_for_image_decoder:
                    # Mirror the encoder's CNN: MLP up to the CNN output
                    # volume, then a transposed CNN back to image size.
                    dcnn_in_channels, dcnn_in_height, dcnn_in_width = (
                        encoder_net._modules['0'].output_shape)
                    dcnn_input_size = (dcnn_in_channels * dcnn_in_width *
                                       dcnn_in_height)
                    decoder_cnn_kwargs = invert_encoder_params(
                        encoder_cnn_kwargs,
                        img_num_channels,
                    )
                    dcnn = BasicDCNN(input_width=dcnn_in_width,
                                     input_height=dcnn_in_height,
                                     input_channels=dcnn_in_channels,
                                     **decoder_cnn_kwargs)
                    mlp = Mlp(input_size=latent_dim,
                              output_size=dcnn_input_size,
                              **decoder_kwargs)
                    dec = nn.Sequential(mlp, dcnn)
                    dec.input_size = latent_dim
                else:
                    dec = nn.Sequential(
                        Mlp(input_size=latent_dim,
                            output_size=img_num_channels * img_height *
                            img_width,
                            **decoder_kwargs),
                        Reshape(img_num_channels, img_height, img_width))
                    dec.input_size = latent_dim
                    dec.output_size = img_num_channels * img_height * img_width
                return dec
            else:
                return Mlp(input_size=latent_dim,
                           output_size=encoder_input_dim,
                           **decoder_kwargs)

        decoder_net = make_decoder()
        vae = VAE(encoder_net, decoder_net)

        vae_trainer = VAETrainer(vae=vae, **vae_trainer_kwargs)

        if pretrain_vae:
            goal_key = (img_desired_goal_key
                        if use_image_observations else state_desired_goal_key)
            vae.to(ptu.device)
            train_ae(vae_trainer,
                     expl_context_distrib,
                     **pretrain_vae_kwargs,
                     goal_key=goal_key,
                     rl_csv_fname=rl_csv_fname)
        # Train the VAE jointly with RL.
        trainers = OrderedDict()
        trainers['vae_trainer'] = vae_trainer
        trainers['disentangled_trainer'] = disentangled_trainer
        trainer = JointTrainer(trainers)
    else:
        trainer = disentangled_trainer

    if not use_image_observations and use_debug_trainer:
        # TODO: implement this for images
        debug_trainer = DebugTrainer(
            observation_space=expl_env.observation_space.
            spaces[state_observation_key],
            encoder=mu_encoder_net,
            encoder_output_dim=encoder_output_dim,
        )
        trainer = JointTrainer([trainer, debug_trainer])

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )
    exploration_policy = create_exploration_policy(policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    video_renderer = EnvRenderer(**video_renderer_kwargs)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key_for_rl,
            context_keys_for_policy=[context_key_for_rl],
        )
        if save_debug_video and not use_image_observations:
            # TODO: add visualization for image-based envs
            # One per-latent-dimension renderer sweeping each object's position.
            obj1_sweep_renderers = {
                'sweep_obj1_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }
            obj0_sweep_renderers = {
                'sweep_obj0_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }

            debugger_one = DebugEnvRenderer(encoder, 0,
                                            **debug_renderer_kwargs)

            # Grid of (x, y) positions covering the observation space, used to
            # sweep one object while the other is held fixed.
            low = eval_env.env.observation_space[
                state_observation_key].low.min()
            high = eval_env.env.observation_space[
                state_observation_key].high.max()
            y = np.linspace(low, high, num=debugger_one.image_shape[0])
            x = np.linspace(low, high, num=debugger_one.image_shape[1])
            cross = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])

            def create_shared_data_creator(obj_index):
                # NOTE(review): assumes the state is [obj0_xy, obj1_xy]
                # (4-dim, two 2D objects) — confirm against the env.
                def compute_shared_data(raw_obs, env):
                    state = raw_obs[observation_key_for_rl]
                    obs = state[:2]
                    goal = state[2:]
                    if obj_index == 0:
                        new_states = np.concatenate(
                            [
                                np.repeat(obs[None, :], cross.shape[0],
                                          axis=0),
                                cross,
                            ],
                            axis=1,
                        )
                    elif obj_index == 1:
                        new_states = np.concatenate(
                            [
                                cross,
                                np.repeat(
                                    goal[None, :], cross.shape[0], axis=0),
                            ],
                            axis=1,
                        )
                    else:
                        raise ValueError(obj_index)
                    return encoder.encode(new_states)

                return compute_shared_data

            obj0_sweeper = create_shared_data_creator(0)
            obj1_sweeper = create_shared_data_creator(1)

            def add_images(env, base_distribution):
                """Wrap the env so rollouts carry normal + debug-sweep images."""
                if use_image_observations:
                    img_env = env
                    image_goal_distribution = base_distribution
                else:
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                    img_env = InsertImageEnv(state_env,
                                             renderer=video_renderer)
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj1_sweep_renderers,
                    compute_shared_data=obj1_sweeper,
                )
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj0_sweep_renderers,
                    compute_shared_data=obj0_sweeper,
                )
                return ContextualEnv(
                    img_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            def get_extra_imgs(
                path,
                index_in_path,
                env,
            ):
                # Pull the debug-sweep frames out of the saved observations.
                return [
                    path['full_observations'][index_in_path][key]
                    for key in obj1_sweep_renderers
                ] + [
                    path['full_observations'][index_in_path][key]
                    for key in obj0_sweep_renderers
                ]

            img_formats = [video_renderer.output_image_format]
            for r in obj1_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            for r in obj0_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)

            def add_images(env, base_distribution):
                """Wrap the env so rollouts include goal/observation images."""
                if use_image_observations:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='video_observation',
                    )
                    image_goal_distribution = base_distribution
                else:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='image_observation',
                    )
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                return ContextualEnv(
                    video_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            if use_image_observations:
                keys_to_show = [
                    'image_desired_goal',
                    'image_observation',
                    'video_observation',
                ]
                image_formats = [
                    env_renderer.output_image_format,
                    env_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            else:
                keys_to_show = ['image_desired_goal', 'image_observation']
                image_formats = [
                    video_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    if visualize_representation:
        from multiworld.envs.pygame import PickAndPlaceEnv
        if not isinstance(state_eval_env, PickAndPlaceEnv):
            raise NotImplementedError()
        # Each object contributes an (x, y) pair to the state vector.
        num_objects = (state_eval_env.observation_space[state_observation_key].
                       low.size) // 2
        state_space = state_eval_env.observation_space['state_observation']
        start_states = np.vstack([state_space.sample() for _ in range(6)])
        if use_image_observations:

            def state_to_encoder_input(state):
                """Render the env at `state` (restoring env state afterwards)."""
                goal_dict = {
                    'state_desired_goal': state,
                }
                env_state = state_eval_env.get_env_state()
                state_eval_env.set_to_goal(goal_dict)
                # NOTE(review): other call sites use
                # env_renderer.create_image(env); confirm EnvRenderer is
                # actually callable here.
                start_img = env_renderer(state_eval_env)
                state_eval_env.set_env_state(env_state)
                return start_img

            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder,
                    i,
                    eval_env,
                    video_renderer,
                    state_to_encoder_input=state_to_encoder_input,
                    env_renderer=env_renderer,
                    start_states=start_states,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)
        else:
            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder, i, eval_env, video_renderer,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)

    if distance_scatterplot_save_period > 0:
        algorithm.post_train_funcs.append(
            create_save_h_vs_state_distance_fn(
                distance_scatterplot_save_period,
                distance_scatterplot_initial_save_period,
                encoder,
                observation_key_for_rl,
            ))
    algorithm.train()
    # Return the trained encoder so transfer experiments can reuse it.
    train_results = dict(encoder_net=encoder_net, )
    TrainResults = namedtuple('TrainResults', sorted(train_results))
    return TrainResults(**train_results)
示例#8
0
 def contextual_env_distrib_reward(_env_id,
                                   _env_class=None,
                                   _env_kwargs=None):
     """Construct a contextual goal-reaching env plus its goal distribution
     and reward function.

     Returns:
         (ContextualEnv, PresampledDistribution, reward_fn) tuple.

     NOTE(review): ``_env_class`` and ``_env_kwargs`` are accepted but
     unused — the enclosing scope's ``env_class``/``env_kwargs`` are what
     actually configure the env.  Confirm this is intentional.
     """
     raw_env = get_gym_env(
         _env_id,
         env_class=env_class,
         env_kwargs=env_kwargs,
         unwrap_timed_envs=True,
     )
     if init_camera:
         raw_env.initialize_camera(init_camera)

     # Optionally normalize goal distances for the full-position ant task;
     # the normalizing wrapper is also handed to the ant diagnostics below.
     normalize_env = None
     if (isinstance(stub_env, AntFullPositionGoalEnv)
             and normalize_distances_for_full_state_ant):
         raw_env = NormalizeAntFullPositionGoalEnv(raw_env)
         normalize_env = raw_env

     env = NoisyAction(raw_env, action_noise_scale)

     diag_fns = []
     if is_gym_env:
         goal_distribution = GoalDictDistributionFromGymGoalEnv(
             env,
             desired_goal_key=desired_goal_key,
         )
         diag_fns.append(GenericGoalConditionedContextualDiagnostics(
             desired_goal_key=desired_goal_key,
             achieved_goal_key=achieved_goal_key,
             success_threshold=success_threshold,
         ))
     else:
         goal_distribution = GoalDictDistributionFromMultitaskEnv(
             env,
             desired_goal_keys=[desired_goal_key],
         )
         diag_fns.append(GoalConditionedDiagnosticsToContextualDiagnostics(
             env.goal_conditioned_diagnostics,
             desired_goal_key=desired_goal_key,
             observation_key=observation_key,
         ))
     if isinstance(stub_env, AntFullPositionGoalEnv):
         diag_fns.append(AntFullPositionGoalEnvDiagnostics(
             desired_goal_key=desired_goal_key,
             achieved_goal_key=achieved_goal_key,
             success_threshold=success_threshold,
             normalize_env=normalize_env,
         ))

     achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key)
     if reward_type == 'sparse':
         reward_fn = ThresholdDistanceReward(
             L2Distance(
                 achieved_goal_from_observation=achieved_from_ob,
                 desired_goal_key=desired_goal_key,
             ),
             success_threshold,
         )
     elif reward_type == 'negative_distance':
         reward_fn = NegativeL2Distance(
             achieved_goal_from_observation=achieved_from_ob,
             desired_goal_key=desired_goal_key,
         )
     else:
         # Any other reward_type is delegated to the model-based
         # probabilistic goal reward.
         reward_fn = ProbabilisticGoalRewardFn(
             dynamics_model,
             state_key=observation_key,
             context_key=context_key,
             reward_type=reward_type,
             discount_factor=discount_factor,
         )

     sampled_goals = PresampledDistribution(goal_distribution,
                                            num_presampled_goals)
     final_env = ContextualEnv(
         env,
         context_distribution=sampled_goals,
         reward_fn=reward_fn,
         observation_key=observation_key,
         contextual_diagnostics_fns=diag_fns,
         update_env_info_fn=delete_info,
     )
     return final_env, sampled_goals, reward_fn
示例#9
0
def probabilistic_goal_reaching_experiment(
    max_path_length,
    qf_kwargs,
    policy_kwargs,
    pgr_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    env_id,
    discount_factor,
    reward_type,
    # Dynamics model
    dynamics_model_version,
    dynamics_model_config,
    dynamics_delta_model_config=None,
    dynamics_adam_config=None,
    dynamics_ensemble_kwargs=None,
    # Discount model
    learn_discount_model=False,
    discount_adam_config=None,
    discount_model_config=None,
    prior_discount_weight_schedule_kwargs=None,
    # Environment
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    exploration_policy_kwargs=None,
    action_noise_scale=0.,
    num_presampled_goals=4096,
    success_threshold=0.05,
    # Video / visualization parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    plot_renderer_kwargs=None,
    eval_env_ids=None,
    # Debugging params
    visualize_dynamics=False,
    visualize_discount_model=False,
    visualize_all_plots=False,
    plot_discount=False,
    plot_reward=False,
    plot_bootstrap_value=False,
    # env specific-params
    normalize_distances_for_full_state_ant=False,
):
    """Launch a probabilistic goal-reaching (PGR) experiment end to end.

    Builds goal-conditioned exploration/eval environments, a SAC-style PGR
    trainer with twin Q-functions and a tanh-Gaussian policy, a relabeling
    replay buffer, a (possibly learned) goal dynamics model, an optional
    learned discount model, and — when ``save_video`` — a set of debug
    renderers saved after each epoch.  The function blocks on
    ``algorithm.train()`` and returns nothing.

    ``reward_type`` selects the reward: ``'sparse'`` (thresholded L2),
    ``'negative_distance'``, or anything else, which is forwarded to
    ``ProbabilisticGoalRewardFn`` built on the dynamics model.
    ``eval_env_ids`` maps evaluation-env names to env ids; defaults to
    ``{'eval': env_id}``.
    """
    # Normalize all optional mapping arguments so ** expansion below is safe
    # (never use mutable defaults in the signature).
    if dynamics_ensemble_kwargs is None:
        dynamics_ensemble_kwargs = {}
    if eval_env_ids is None:
        eval_env_ids = {'eval': env_id}
    if discount_model_config is None:
        discount_model_config = {}
    if dynamics_delta_model_config is None:
        dynamics_delta_model_config = {}
    if dynamics_adam_config is None:
        dynamics_adam_config = {}
    if discount_adam_config is None:
        discount_adam_config = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    if not plot_renderer_kwargs:
        # Plots reuse the video-renderer settings at a lower resolution.
        plot_renderer_kwargs = video_renderer_kwargs.copy()
        plot_renderer_kwargs['dpi'] = 48
    context_key = desired_goal_key

    # Stub instance used only for introspection (observation/action spaces and
    # env type); the actual training/eval envs are built in
    # contextual_env_distrib_reward below.
    stub_env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    is_gym_env = (
        isinstance(stub_env, FetchEnv) or isinstance(stub_env, AntXYGoalEnv)
        or isinstance(stub_env, AntFullPositionGoalEnv)
    )
    is_ant_full_pos = isinstance(stub_env, AntFullPositionGoalEnv)

    if is_gym_env or isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        # These envs expose an explicit achieved-goal entry in the obs dict.
        # (The two cases were previously duplicated, byte-identical branches.)
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    else:
        # Otherwise the full state observation doubles as the achieved goal.
        achieved_goal_key = observation_key
        ob_keys_to_save_in_buffer = [observation_key]
    # TODO move all env-specific code to other file
    if isinstance(stub_env, SawyerDoorHookEnv):
        init_camera = sawyer_door_env_camera_v0
    elif isinstance(stub_env, SawyerPushAndReachXYEnv):
        init_camera = sawyer_init_camera_zoomed_in
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        init_camera = sawyer_pick_and_place_camera
    else:
        init_camera = None

    full_ob_space = stub_env.observation_space
    action_space = stub_env.action_space
    state_to_goal = StateToGoalFn(stub_env)
    dynamics_model = create_goal_dynamics_model(
        full_ob_space[observation_key],
        action_space,
        full_ob_space[achieved_goal_key],
        dynamics_model_version,
        state_to_goal,
        dynamics_model_config,
        dynamics_delta_model_config,
        ensemble_model_kwargs=dynamics_ensemble_kwargs,
    )
    # For hindsight relabeling: a goal context can be recomputed from any
    # saved observation via its achieved-goal entry.
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_reward(_env_id,
                                      _env_class=None,
                                      _env_kwargs=None):
        """Create (contextual env, goal distribution, reward fn) for one id.

        BUGFIX: ``_env_class`` / ``_env_kwargs`` used to be accepted but
        silently ignored in favor of the enclosing ``env_class`` /
        ``env_kwargs``.  They are now honored, falling back to the closure
        values when ``None`` — which preserves the behavior of both existing
        call sites.
        """
        base_env = get_gym_env(
            _env_id,
            env_class=_env_class if _env_class is not None else env_class,
            env_kwargs=_env_kwargs if _env_kwargs is not None else env_kwargs,
            unwrap_timed_envs=True,
        )
        if init_camera:
            base_env.initialize_camera(init_camera)
        if (isinstance(stub_env, AntFullPositionGoalEnv)
                and normalize_distances_for_full_state_ant):
            base_env = NormalizeAntFullPositionGoalEnv(base_env)
            normalize_env = base_env
        else:
            normalize_env = None
        env = NoisyAction(base_env, action_noise_scale)
        diag_fns = []
        if is_gym_env:
            goal_distribution = GoalDictDistributionFromGymGoalEnv(
                env,
                desired_goal_key=desired_goal_key,
            )
            diag_fns.append(
                GenericGoalConditionedContextualDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                ))
        else:
            goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            diag_fns.append(
                GoalConditionedDiagnosticsToContextualDiagnostics(
                    env.goal_conditioned_diagnostics,
                    desired_goal_key=desired_goal_key,
                    observation_key=observation_key,
                ))
        if isinstance(stub_env, AntFullPositionGoalEnv):
            diag_fns.append(
                AntFullPositionGoalEnvDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                    normalize_env=normalize_env,
                ))
        achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key)
        if reward_type == 'sparse':
            distance_fn = L2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
            reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        elif reward_type == 'negative_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
        else:
            # Any other reward_type is interpreted by the probabilistic
            # goal-reaching reward built on the learned dynamics model.
            reward_fn = ProbabilisticGoalRewardFn(
                dynamics_model,
                state_key=observation_key,
                context_key=context_key,
                reward_type=reward_type,
                discount_factor=discount_factor,
            )
        # Presample goals once up front so per-episode sampling is cheap.
        goal_distribution = PresampledDistribution(goal_distribution,
                                                   num_presampled_goals)
        final_env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=diag_fns,
            update_env_info_fn=delete_info,
        )
        return final_env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, reward_fn = contextual_env_distrib_reward(
        env_id,
        env_class,
        env_kwargs,
    )
    # The policy/Q-functions see [observation, goal-context] concatenated.
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        """One Q-network over concatenated (obs+goal, action) input."""
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    def create_policy():
        """Tanh-Gaussian policy; the two MLP heads are mean and log-std."""
        obs_processor = MultiHeadedMlp(input_size=obs_dim,
                                       output_sizes=[action_dim, action_dim],
                                       **policy_kwargs)
        return PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    policy = create_policy()

    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        """Replay-buffer post-processing hook.

        Appends the goal context to (next_)observations while stashing the
        originals under 'original_*' keys for the dynamics/discount trainers.
        The unused arguments are part of the buffer's callback signature.
        """
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch['original_observations'] = obs
        batch['original_next_observations'] = next_obs
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=expl_env,
        context_keys=[context_key],
        observation_keys=ob_keys_to_save_in_buffer,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    def create_trainer():
        """Assemble the PGR trainer plus optional discount/dynamics trainers.

        Returns (JointTrainer over all sub-trainers, the PGR trainer itself).
        """
        trainers = OrderedDict()
        if learn_discount_model:
            discount_model = create_discount_model(
                ob_space=stub_env.observation_space[observation_key],
                goal_space=stub_env.observation_space[context_key],
                action_space=stub_env.action_space,
                model_kwargs=discount_model_config)
            optimizer = optim.Adam(discount_model.parameters(),
                                   **discount_adam_config)
            discount_trainer = DiscountModelTrainer(
                discount_model,
                optimizer,
                observation_key='observations',
                next_observation_key='original_next_observations',
                goal_key=context_key,
                state_to_goal_fn=state_to_goal,
            )
            trainers['discount_trainer'] = discount_trainer
        else:
            discount_model = None
        if prior_discount_weight_schedule_kwargs is not None:
            schedule = create_schedule(**prior_discount_weight_schedule_kwargs)
        else:
            schedule = None
        pgr_trainer = PGRTrainer(env=expl_env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 discount=discount_factor,
                                 discount_model=discount_model,
                                 prior_discount_weight_schedule=schedule,
                                 **pgr_trainer_kwargs)
        # Empty key => the PGR trainer's stats are logged without a prefix.
        trainers[''] = pgr_trainer
        optimizers = [
            pgr_trainer.qf1_optimizer,
            pgr_trainer.qf2_optimizer,
            pgr_trainer.alpha_optimizer,
            pgr_trainer.policy_optimizer,
        ]
        if dynamics_model_version in {
                'learned_model',
                'learned_model_ensemble',
                'learned_model_laplace',
                'learned_model_laplace_global_variance',
                'learned_model_gaussian_global_variance',
        }:
            model_opt = optim.Adam(dynamics_model.parameters(),
                                   **dynamics_adam_config)
        elif dynamics_model_version in {
                'fixed_standard_laplace',
                'fixed_standard_gaussian',
        }:
            # Fixed-distribution models have nothing to optimize.
            model_opt = None
        else:
            raise NotImplementedError()
        model_trainer = GenerativeGoalDynamicsModelTrainer(
            dynamics_model,
            model_opt,
            state_to_goal=state_to_goal,
            observation_key='original_observations',
            next_observation_key='original_next_observations',
        )
        trainers['dynamics_trainer'] = model_trainer
        optimizers.append(model_opt)
        return JointTrainer(trainers), pgr_trainer

    trainer, pgr_trainer = create_trainer()

    eval_policy = MakeDeterministic(policy)

    def create_eval_path_collector(some_eval_env):
        """Path collector that feeds obs + goal context to the eval policy."""
        return ContextualPathCollector(
            some_eval_env,
            eval_policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )

    # One eval environment (and collector) per named env id.
    path_collectors = dict()
    eval_env_name_to_env_and_context_distrib = dict()
    for name, extra_env_id in eval_env_ids.items():
        env, context_distrib, _ = contextual_env_distrib_reward(extra_env_id)
        path_collectors[name] = create_eval_path_collector(env)
        eval_env_name_to_env_and_context_distrib[name] = (env, context_distrib)
    eval_path_collector = JointPathCollector(path_collectors)
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    def get_eval_diagnostics(key_to_paths):
        """Aggregate per-eval-env diagnostics, prefixed by the env name."""
        stats = OrderedDict()
        for eval_env_name, paths in key_to_paths.items():
            env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
            stats.update(
                add_prefix(
                    env.get_diagnostics(paths),
                    eval_env_name,
                    divider='/',
                ))
            stats.update(
                add_prefix(
                    eval_util.get_generic_path_information(paths),
                    eval_env_name,
                    divider='/',
                ))
        return stats

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        evaluation_get_diagnostic_functions=[get_eval_diagnostics],
        **algo_kwargs)
    algorithm.to(ptu.device)

    if normalize_distances_for_full_state_ant and is_ant_full_pos:
        # Per-dimension std of the presampled qpos; used below to
        # un-normalize goals for visualization.
        qpos_weights = expl_env.unwrapped.presampled_qpos.std(axis=0)
    else:
        qpos_weights = None

    if save_video:
        if is_gym_env:
            video_renderer = GymEnvRenderer(**video_renderer_kwargs)

            def set_goal_for_visualization(env, policy, o):
                # Pin the env's goal to the one the policy is conditioned on
                # so the rendered video matches the rollout context.
                goal = o[desired_goal_key]
                if normalize_distances_for_full_state_ant and is_ant_full_pos:
                    unnormalized_goal = goal * qpos_weights
                    env.unwrapped.goal = unnormalized_goal
                else:
                    env.unwrapped.goal = goal

            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=set_goal_for_visualization,
            )
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)
            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=None,
            )

        renderers = OrderedDict(image_observation=video_renderer, )
        # Grid of candidate (x, y) states spanning the observation bounds,
        # used by the heatmap renderers below.
        # NOTE(review): this assumes the relevant state space is 2D and that
        # one scalar low/high bounds both axes — confirm for new envs.
        state_env = expl_env.env
        state_space = state_env.observation_space[observation_key]
        low = state_space.low.min()
        high = state_space.high.max()
        y = np.linspace(low, high, num=video_renderer.image_chw[1])
        x = np.linspace(low, high, num=video_renderer.image_chw[2])
        all_xy_np = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
        all_xy_torch = ptu.from_numpy(all_xy_np)
        num_states = all_xy_torch.shape[0]
        if visualize_dynamics:

            def create_dynamics_visualizer(show_prob, vary_state=False):
                """Build a (obs_dict, action) -> heat values fn.

                vary_state=False: p(candidate goal | current obs, action).
                vary_state=True: p(desired goal | grid state, zero action).
                """
                def get_prob(obs_dict, action):
                    obs = obs_dict['state_observation']
                    obs_torch = ptu.from_numpy(obs)[None]
                    action_torch = ptu.from_numpy(action)[None]
                    if vary_state:
                        action_repeated = torch.zeros((num_states, 2))
                        dist = dynamics_model(all_xy_torch, action_repeated)
                        goal = ptu.from_numpy(
                            obs_dict['state_desired_goal'][None])
                        log_probs = dist.log_prob(goal)
                    else:
                        dist = dynamics_model(obs_torch, action_torch)
                        log_probs = dist.log_prob(all_xy_torch)
                    if show_prob:
                        return log_probs.exp()
                    else:
                        return log_probs

                return get_prob

            renderers['log_prob'] = ValueRenderer(
                create_dynamics_visualizer(False), **video_renderer_kwargs)
            renderers['log_prob_vary_state'] = ValueRenderer(
                create_dynamics_visualizer(False, vary_state=True),
                only_get_image_once_per_episode=True,
                max_out_walls=isinstance(stub_env, PickAndPlaceEnv),
                **video_renderer_kwargs)
        if visualize_discount_model and pgr_trainer.discount_model:

            def get_discount_values(obs, action):
                # Evaluate the discount model at the current obs paired with
                # every grid state as candidate goal.
                obs = obs['state_observation']
                obs_torch = ptu.from_numpy(obs)[None]
                combined_obs = torch.cat([
                    obs_torch.repeat(num_states, 1),
                    all_xy_torch,
                ],
                                         dim=1)

                action_torch = ptu.from_numpy(action)[None]
                action_repeated = action_torch.repeat(num_states, 1)
                return pgr_trainer.discount_model(combined_obs,
                                                  action_repeated)

            renderers['discount_model'] = ValueRenderer(
                get_discount_values,
                states_to_eval=all_xy_torch,
                **video_renderer_kwargs)
        if 'log_prob' in renderers and 'discount_model' in renderers:
            renderers['log_prob_time_discount'] = ProductRenderer(
                renderers['discount_model'], renderers['log_prob'],
                **video_renderer_kwargs)

        def get_reward(obs_dict, action, next_obs_dict):
            """Scalar reward for one transition (batch the fn, take [0])."""
            o = batchify(obs_dict)
            a = batchify(action)
            next_o = batchify(next_obs_dict)
            # next_o is passed as both next-obs and context dict, since the
            # context keys live in the observation dicts here.
            reward = reward_fn(o, a, next_o, next_o)
            return reward[0]

        def get_bootstrap(obs_dict, action, next_obs_dict, return_float=True):
            """Bootstrap (target) value from the PGR trainer for one step."""
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            next_o_pt = ptu.from_numpy(next_obs_dict[observation_key][None])
            action_torch = ptu.from_numpy(action[None])
            bootstrap, *_ = pgr_trainer.get_bootstrap_stats(
                torch.cat((o_pt, context_pt), dim=1),
                action_torch,
                torch.cat((next_o_pt, context_pt), dim=1),
            )
            if return_float:
                return ptu.get_numpy(bootstrap)[0, 0]
            else:
                return bootstrap

        def get_discount(obs_dict, action, next_obs_dict):
            """Effective discount factor used by the trainer for one step."""
            bootstrap = get_bootstrap(obs_dict,
                                      action,
                                      next_obs_dict,
                                      return_float=False)
            reward_np = get_reward(obs_dict, action, next_obs_dict)
            reward = ptu.from_numpy(reward_np[None, None])
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            obs = torch.cat((o_pt, context_pt), dim=1)
            actions = ptu.from_numpy(action[None])
            discount = pgr_trainer.get_discount_factor(
                bootstrap,
                reward,
                obs,
                actions,
            )
            if isinstance(discount, torch.Tensor):
                discount = ptu.get_numpy(discount)[0, 0]
            # Clip away zeros so log-scale plots stay finite.
            return np.clip(discount, a_min=1e-3, a_max=1)

        def create_modify_fn(
            title,
            set_params=None,
            scientific=True,
        ):
            """Axis-styling callback factory for the number-plot renderers."""
            def modify(ax):
                ax.set_title(title)
                if set_params:
                    ax.set(**set_params)
                if scientific:
                    scaler = ScalarFormatter(useOffset=True)
                    scaler.set_powerlimits((1, 1))
                    ax.yaxis.set_major_formatter(scaler)
                    ax.ticklabel_format(axis='y', style='sci')

            return modify

        def add_left_margin(fig):
            # Leave room for scientific-notation y-axis labels.
            fig.subplots_adjust(left=0.2)

        if visualize_all_plots or plot_discount:
            renderers['discount'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_discount,
                modify_ax_fn=create_modify_fn(
                    title='discount',
                    set_params=dict(
                        ylim=[-0.05, 1.1], ),
                ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        if visualize_all_plots or plot_reward:
            renderers['reward'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_reward,
                modify_ax_fn=create_modify_fn(title='reward'),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_bootstrap_value:
            renderers['bootstrap-value'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_bootstrap,
                modify_ax_fn=create_modify_fn(title='bootstrap value'),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        def add_images(env, state_distribution):
            """Wrap env so observations carry the debug images for videos."""
            state_env = env.env
            if is_gym_env:
                goal_distribution = state_distribution
            else:
                goal_distribution = AddImageDistribution(
                    env=state_env,
                    base_distribution=state_distribution,
                    image_goal_key='image_desired_goal',
                    renderer=video_renderer,
                )
            context_env = ContextualEnv(
                state_env,
                context_distribution=goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return InsertDebugImagesEnv(
                context_env,
                renderers=renderers,
            )

        img_expl_env = add_images(expl_env, expl_context_distrib)
        if is_gym_env:
            imgs_to_show = list(renderers.keys())
        else:
            imgs_to_show = ['image_desired_goal'] + list(renderers.keys())
        img_formats = [video_renderer.output_image_format]
        img_formats += [r.output_image_format for r in renderers.values()]
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_chw[1],
            image_formats=img_formats,
            keys_to_show=imgs_to_show,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        for eval_env_name, (env, context_distrib) in (
                eval_env_name_to_env_and_context_distrib.items()):
            img_eval_env = add_images(env, context_distrib)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag=eval_env_name,
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                keys_to_show=imgs_to_show,
                **save_video_kwargs)
            algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
示例#10
0
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        """Build a goal-free, image-observation contextual env.

        Returns (contextual env, context distribution, reward fn).  Only the
        "no goal" path is wired up: the context is a zero-dimensional
        placeholder distribution and the reward comes from GraspingRewardFn,
        so ``goal_sampling_mode`` and ``presampled_goals_path`` are currently
        unused (earlier VAE/latent-goal variants were removed).
        """
        raw_env = get_gym_env(env_id,
                              env_class=env_class,
                              env_kwargs=env_kwargs)
        # Attach rendered images to observations.
        image_env = InsertImageEnv(
            raw_env,
            renderer=EnvRenderer(init_camera=init_camera, **renderer_kwargs))

        diag_fn = StateImageGoalDiagnosticsFn({})
        # Zero-dimensional "goal" context: effectively no goal conditioning.
        context_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )
        grasp_reward = GraspingRewardFn()

        wrapped = ContextualEnv(
            image_env,
            context_distribution=context_distribution,
            reward_fn=grasp_reward,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
        )
        return wrapped, context_distribution, grasp_reward
示例#11
0
def experiment(variant):
    """Run an AWAC goal-conditioned experiment configured by ``variant``.

    Pipeline:
      1. Build exploration/evaluation envs; optionally wrap them with
         sparse rewards, a VQ-VAE image encoder, reward masks, or
         observation stacking.
      2. Build goal-relabeling replay and demo buffers, twin Q-networks,
         and a policy (``TanhGaussianPolicy`` by default).
      3. Train with ``AWACTrainer`` either online (step collector) or in
         batch mode (path collectors), with optional demo loading, BC/RL
         pretraining, video saving, and algorithm snapshotting.

    Args:
        variant: dict of hyperparameters, env settings, and feature flags.
    """
    render = variant.get("render", False)
    vae_path = variant.get("vae_path", False)

    process_args(variant)

    env_class = variant.get("env_class")
    env_kwargs = variant.get("env_kwargs")
    env_id = variant.get("env_id")
    expl_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    eval_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env = eval_env

    if variant.get('sparse_reward', False):
        expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward)
        eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward)

    if vae_path:
        # Wrap both envs so observations include rendered images encoded
        # by a pretrained VQ-VAE.
        vae = load_local_or_remote_file(vae_path)
        variant['path_loader_kwargs']['model_path'] = vae_path
        renderer = EnvRenderer(**variant.get("renderer_kwargs", {}))
        reward_params = variant.get("reward_params", {})
        vae_wrapped_env_kwargs = variant.get('vae_wrapped_env_kwargs', {})
        expl_env = VQVAEWrappedEnv(
            InsertImageEnv(expl_env, renderer=renderer),
            vae,
            reward_params=reward_params,
            **vae_wrapped_env_kwargs)
        eval_env = VQVAEWrappedEnv(
            InsertImageEnv(eval_env, renderer=renderer),
            vae,
            reward_params=reward_params,
            **vae_wrapped_env_kwargs)
        env = eval_env
        variant['path_loader_kwargs']['env'] = env

    if variant.get('add_env_demos', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])

    if variant.get('add_env_offpolicy_data', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        # NOTE(review): both wrappers wrap `env` (== eval_env) rather than
        # the expl_env/eval_env built above — confirm this is intentional.
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env

    if variant.get("pretrained_algorithm_path", False):
        # Resume a previously-saved algorithm instead of building a new one.
        resume(variant)
        return

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')

    # Policy/Q-network input is the concatenation of observation and goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = eval_env.action_space.low.size

    # Goal-relabeling buffers; ConcatToObsWrapper appends the resampled
    # goals onto the observation vector so they match the policy input.
    replay_buffer_kwargs = dict(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    replay_buffer_kwargs.update(variant.get('replay_buffer_kwargs', dict()))
    replay_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
    replay_buffer_kwargs.update(
        variant.get('demo_replay_buffer_kwargs', dict()))
    demo_train_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
    demo_test_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )

    # Twin Q-networks plus targets (standard SAC/AWAC setup).
    M = variant['layer_size']
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'],
    )

    # Optionally wrap the exploration policy with a noise strategy.
    expl_policy = policy
    exploration_kwargs = variant.get('exploration_kwargs', {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)

        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == 'ou':
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == 'gauss_eps':
            es = GaussianAndEpislonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            # Fail loudly with a meaningful exception (was a bare `error`
            # name, which raised an opaque NameError).
            raise ValueError(
                'Unknown exploration strategy: {}'.format(
                    exploration_strategy))

    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **variant['trainer_kwargs'])
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        # Bug fix: this branch previously referenced eval_path_collector
        # without ever defining it, raising NameError when constructing
        # the algorithm. Mirror the batch branch's evaluation collector.
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    else:
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)

    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(trainer,
                                        replay_buffer=replay_buffer,
                                        demo_train_buffer=demo_train_buffer,
                                        demo_test_buffer=demo_test_buffer,
                                        **path_loader_kwargs)
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()

    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        # Pass paths directly so torch.save opens and closes the files
        # itself (the previous open(...) handles were never closed).
        torch.save(data, pt_path)
        torch.save(data, p_path)

    algorithm.train()
示例#12
0
def disco_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    generate_set_for_rl_kwargs,
    # VAE parameters
    create_vae_kwargs,
    vae_trainer_kwargs,
    vae_algo_kwargs,
    data_loader_kwargs,
    generate_set_for_vae_pretraining_kwargs,
    num_ungrouped_images,
    beta_schedule_kwargs=None,
    # Oracle settings
    use_ground_truth_reward=False,
    use_onehot_set_embedding=False,
    use_dummy_model=False,
    observation_key="latent_observation",
    # RIG comparison
    rig_goal_setter_kwargs=None,
    rig=False,
    # Miscellaneous
    reward_fn_kwargs=None,
    # None-VAE Params
    env_id=None,
    env_class=None,
    env_kwargs=None,
    latent_observation_key="latent_observation",
    state_observation_key="state_observation",
    image_observation_key="image_observation",
    set_description_key="set_description",
    example_state_key="example_state",
    example_image_key="example_image",
    # Exploration
    exploration_policy_kwargs=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    """Run a set-conditioned (DisCo-style) SAC experiment.

    Generates example sets, trains (or stubs) a set VAE, builds contextual
    exploration/evaluation environments whose context describes a set via
    its latent mean/covariance — or a one-hot set embedding, or just the
    mean in RIG mode — then trains SAC with a context-relabeling replay
    buffer and optionally records DisCo visualization videos.

    Notable flags:
        use_dummy_model: skip VAE training and use an untrained stub.
        use_ground_truth_reward / use_onehot_set_embedding / rig: oracle
            and comparison ablations that change reward / context choice.
    """
    # Normalize optional kwarg dicts so they can be **-expanded below.
    if rig_goal_setter_kwargs is None:
        rig_goal_setter_kwargs = {}
    if reward_fn_kwargs is None:
        reward_fn_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    renderer = EnvRenderer(**renderer_kwargs)

    # Example sets define the set-conditioned goal distributions used for
    # both VAE training and RL.
    sets = create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        example_state_key=example_state_key,
        example_image_key=example_image_key,
        **generate_set_for_rl_kwargs,
    )
    # Either stub the VAE (dummy-model ablation) or pretrain a set VAE on
    # the generated sets plus ungrouped images.
    if use_dummy_model:
        model = create_dummy_image_vae(img_chw=renderer.image_chw,
                                       **create_vae_kwargs)
    else:
        model = train_set_vae(
            create_vae_kwargs,
            vae_trainer_kwargs,
            vae_algo_kwargs,
            data_loader_kwargs,
            generate_set_for_vae_pretraining_kwargs,
            num_ungrouped_images,
            env_id=env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            beta_schedule_kwargs=beta_schedule_kwargs,
            sets=sets,
            renderer=renderer,
        )
    # Build independent contextual envs (fresh state env each) sharing the
    # same VAE and sets for exploration and evaluation.
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
        ))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
            oracle_rig_goal=rig,
        ))
    # All context keys are stored in the replay buffer; only the subset in
    # context_keys_for_rl is concatenated into the policy/Q input.
    context_keys = [
        expl_context_distrib.mean_key,
        expl_context_distrib.covariance_key,
        expl_context_distrib.set_index_key,
        expl_context_distrib.set_embedding_key,
    ]
    if rig:
        context_keys_for_rl = [
            expl_context_distrib.mean_key,
        ]
    else:
        if use_onehot_set_embedding:
            context_keys_for_rl = [
                expl_context_distrib.set_embedding_key,
            ]
        else:
            context_keys_for_rl = [
                expl_context_distrib.mean_key,
                expl_context_distrib.covariance_key,
            ]

    # Network input size = observation + all chosen context vectors.
    obs_dim = np.prod(expl_env.observation_space.spaces[observation_key].shape)
    obs_dim += sum([
        np.prod(expl_env.observation_space.spaces[k].shape)
        for k in context_keys_for_rl
    ])
    action_dim = np.prod(expl_env.action_space.shape)

    def create_qf():
        # Twin-Q helper: all four Q-networks share this architecture.
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        # Append the selected context vectors to (next_)observations so
        # sampled batches match the networks' flat input layout.
        obs = batch["observations"]
        next_obs = batch["next_observations"]
        contexts = [batch[k] for k in context_keys_for_rl]
        batch["observations"] = np.concatenate((obs, *contexts), axis=1)
        batch["next_observations"] = np.concatenate(
            (next_obs, *contexts),
            axis=1,
        )
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys,
        observation_keys=list(
            {observation_key, state_observation_key, latent_observation_key}),
        observation_key=observation_key,
        context_distribution=FilterKeys(
            expl_context_distrib,
            context_keys,
        ),
        sample_context_from_obs_dict_fn=None,
        # RemapKeyFn({context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs,
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs,
    )

    # Evaluation acts deterministically; exploration may add noise via
    # exploration_policy_kwargs.
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    # Record train/eval videos with set visualizations after each epoch.
    if save_video:
        set_index_key = eval_context_distrib.set_index_key
        expl_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            expl_path_collector,
            tag="train",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            eval_path_collector,
            tag="eval",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
示例#13
0
    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        """Build a mask-aware ContextualEnv with its goal distribution and reward.

        Relies on enclosing-scope variables not visible in this chunk
        (``do_masking``, ``mask_dim``, ``mask_key``, ``context_key``,
        ``desired_goal_key``, ``achieved_goal_key``, ``observation_key``,
        ``masking_reward_fn``) — confirm against the outer function.

        Returns:
            (env, goal_distribution, reward_fn) tuple, where env is the
            wrapped ContextualEnv.
        """
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        # A fixed mask applies when masking is disabled entirely or when the
        # caller explicitly asks for the 'static_mask' distribution type.
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        # Goals come from the multitask env, then get a mask entry attached.
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        # When masking is active, the env reward above is replaced by a
        # masked reward (custom masking_reward_fn if provided, else the
        # default); the unmasked reward_fn is discarded.
        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn
示例#14
0
def masking_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    achieved_goal_key='state_achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    train_env_id=None,
    eval_env_id=None,
    do_masking=True,
    mask_key="masked_observation",
    masking_eval_steps=200,
    log_mask_diagnostics=True,
    mask_dim=None,
    masking_reward_fn=None,
    masking_for_exploration=True,
    rotate_masks_for_eval=False,
    rotate_masks_for_expl=False,
    mask_distribution=None,
    num_steps_per_mask_change=10,
    tag=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'

    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to all ones mask if static mask isn't defined and should
        # be using static masks.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size +
               mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']

        context = batch[context_key]
        mask = batch[mask_key]

        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        """Build a path collector for ``env`` and ``policy``.

        If ``is_rotating`` is false, returns a plain contextual collector
        that conditions the policy on the goal and mask contexts. Otherwise
        returns a collector that periodically changes which one-hot mask is
        active (only valid when masking with one-hot masks is enabled).
        """
        if not is_rotating:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )
        assert do_masking and mask_distribution == 'one_hot_masks'
        return RotatingMaskingPathCollector(
            env,
            policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            mask_key=mask_key,
            mask_length=mask_dim,
            num_steps_per_mask_change=num_steps_per_mask_change,
        )

    # Wrap the raw policy for exploration (behavior depends on
    # ``exploration_policy_kwargs``).
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)

    # Evaluation uses the deterministic (mean) policy; exploration uses the
    # wrapped stochastic policy. Whether each collector rotates through
    # one-hot masks is configured separately for eval and exploration.
    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    # Batch RL loop: alternates exploration collection, SAC training, and
    # evaluation according to ``algo_kwargs``.
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    # Move all trainer networks to the configured torch device.
    algorithm.to(ptu.device)

    if save_video:
        # Rollouts used for video logging feed the policy an all-ones mask,
        # so the saved videos show behavior with every goal dimension active.
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            """Wrap a state-based contextual env so goals and observations
            carry rendered images (used for saving eval/train videos)."""
            inner_env = env.env
            goal_dist_with_images = AddImageDistribution(
                env=inner_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            image_env = InsertImageEnv(inner_env, renderer=renderer)
            contextual_img_env = ContextualEnv(
                image_env,
                context_distribution=goal_dist_with_images,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )
            return contextual_img_env

        # Image-augmented copies of the eval/exploration envs for rendering.
        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)

        # Save an eval video and a train video after each training epoch.
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of
    # the mask: one one-hot mask per dimension, each with a dedicated
    # collector that runs the deterministic policy in an env fixed to that
    # static mask.
    masks = list(np.eye(mask_dim))
    collectors = []
    for mask in masks:
        # When masking is disabled, every diagnostic env sees the no-op
        # all-ones mask instead of the one-hot mask.
        static_mask = mask if do_masking else np.ones(mask_dim)
        per_dim_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=static_mask)

        collectors.append(ContextualPathCollector(
            per_dim_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        ))
    # e.g. 'mask_100/' for the first of three mask dimensions.
    log_prefixes = ['mask_{}/'.format(''.join(m.astype(int).astype(str)))
                    for m in masks]

    def get_mask_diagnostics(unused):
        """Collect eval paths with each per-dimension mask collector and
        return their path diagnostics, keyed under the matching
        'mask_.../' prefix."""
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        diagnostics = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            mask_paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            path_info = eval_util.get_generic_path_information(mask_paths)
            append_log(diagnostics, add_prefix(path_info, prefix))

        # Reset collector state so the next epoch's diagnostics start fresh.
        for collector in collectors:
            collector.end_epoch(0)
        return diagnostics

    # Optionally report the per-mask-dimension diagnostics at each eval.
    if log_mask_diagnostics:
        # NOTE(review): hooks into a private attribute of the algorithm —
        # relies on rlkit's internal list of eval diagnostic fns; confirm it
        # still exists when upgrading rlkit.
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    # Run the full train/eval loop.
    algorithm.train()