Example #1
def analyze_from_vae(snapshot_path,
                     latent_observation_key='latent_observation',
                     mean_key='latent_mean',
                     covariance_key='latent_covariance',
                     image_observation_key='image_observation',
                     **kwargs):
    data = torch.load(open(snapshot_path, "rb"))
    variant_path = snapshot_path.replace('params.pt', 'variant.json')
    print_settings(variant_path)
    vae = data['trainer/vae']
    state_env = gym.make('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1')
    renderer = EnvRenderer()
    sets = make_custom_sets(state_env, renderer)
    reward_fn, _ = rewards.create_normal_likelihood_reward_fns(
        latent_observation_key=latent_observation_key,
        mean_key=mean_key,
        covariance_key=covariance_key,
        reward_fn_kwargs=dict(
            drop_log_det_term=True,
            sqrt_reward=True,
        ),
    )

    img_env = InsertImageEnv(state_env, renderer=renderer)
    env = DictEncoderWrappedEnv(
        img_env,
        vae,
        encoder_input_key='image_observation',
        encoder_output_remapping={'posterior_mean': 'latent_observation'},
    )
    analyze(sets, vae, env, **kwargs)
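A minimal usage sketch (the path is a hypothetical placeholder): the function expects a params.pt snapshot with a variant.json next to it, as the path replacement above implies, and any extra keyword arguments are forwarded to analyze().

# Hypothetical snapshot produced by a VAE training run; adjust the path as needed.
analyze_from_vae('/path/to/run/params.pt')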
Example #2
def main():
    register_all_envs()
    # env = PickAndPlaceEnv(
    #     # Environment dynamics
    #     action_scale=1.0,
    #     boundary_dist=4,
    #     ball_radius=1.5,
    #     object_radius=1.,
    #     ball_visual_radius=1.5,
    #     object_visual_radius=1.,
    #     min_grab_distance=1.,
    #     walls=None,
    #     # Rewards
    #     action_l2norm_penalty=0,
    #     reward_type="dense",
    #     success_threshold=0.60,
    #     # Reset settings
    #     fixed_goal=None,
    #     # Visualization settings
    #     images_are_rgb=True,
    #     render_dt_msec=0,
    #     render_onscreen=False,
    #     render_size=84,
    #     show_goal=False,
    #     goal_samplers=None,
    #     goal_sampling_mode='random',
    #     num_presampled_goals=10000,
    #     object_reward_only=False,
    #
    #     init_position_strategy='random',
    #     num_objects=1,
    # )
    env = gym.make('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1')

    renderer = EnvRenderer(
        output_image_format='CHW',
        width=28,
        height=28,
    )
    import cv2  # only needed if the imshow preview below is uncommented
    n = 12800
    imgs = []
    for _ in range(n):
        env.reset()
        img = renderer(env)
        # cv2.imshow('img', img.transpose())
        # cv2.waitKey(100)
        imgs.append(img)
    imgs = np.array(imgs)
    np.save(
        '/home/vitchyr/mnt/log/manual-upload/sets/OneObject-PickAndPlace-BigBall-RandomInit-2D-v1-ungrouped-train-28x28.npy',
        imgs,
    )
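As a rough follow-up sketch (not from the source), the saved array can be read back with plain numpy; the filename is assumed to be a local copy of the file written above, and the (N, C, 28, 28) CHW shape follows from the renderer configuration.

import numpy as np

# Hypothetical local copy of the array saved above; shape is assumed to be
# (12800, C, 28, 28) given the CHW renderer with width=height=28.
imgs = np.load('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1-ungrouped-train-28x28.npy')
flat = imgs.reshape(imgs.shape[0], -1).astype(np.float32)
if flat.max() > 1.0:  # normalize to [0, 1] only if pixels were stored as 0-255
    flat /= 255.0
print(imgs.shape, flat.shape)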
Example #3
def get_video_func(
    env,
    policy,
    tag,
):
    # Note: renderer_kwargs, variant, observation_key, desired_goal_key, rf,
    # and save_video_kwargs are not defined in this excerpt; in the original
    # they come from the surrounding experiment scope.
    renderer = EnvRenderer(**renderer_kwargs)
    state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    image_goal_distribution = AddImageDistribution(
        env=env,
        base_distribution=state_goal_distribution,
        image_goal_key="image_desired_goal",
        renderer=renderer,
    )
    img_env = InsertImageEnv(env, renderer=renderer)
    rollout_function = partial(
        rf.multitask_rollout,
        max_path_length=variant["max_path_length"],
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        return_dict_obs=True,
    )
    reward_fn = ContextualRewardFnFromMultitaskEnv(
        env=env,
        achieved_goal_from_observation=IndexIntoAchievedGoal(
            observation_key),
        desired_goal_key=desired_goal_key,
        achieved_goal_key="state_achieved_goal",
    )
    contextual_env = ContextualEnv(
        img_env,
        context_distribution=image_goal_distribution,
        reward_fn=reward_fn,
        observation_key=observation_key,
    )
    video_func = get_save_video_function(
        rollout_function,
        contextual_env,
        policy,
        tag=tag,
        imsize=renderer.width,
        image_format="CWH",
        **save_video_kwargs,
    )
    return video_func
Example #4
def create_dataset(env_kwargs, generate_set_kwargs, num_ungrouped_images):
    env = create_env(**env_kwargs)
    renderer = EnvRenderer(output_image_format='CHW')
    import time
    print("making train set")
    start = time.time()
    set_imgs = generate_set_images(env, renderer, **generate_set_kwargs)
    set_imgs = ptu.from_numpy(np.array(list(set_imgs)))
    print("making eval set", time.time() - start)
    start = time.time()
    eval_set_imgs = generate_set_images(env, renderer, **generate_set_kwargs)
    eval_set_imgs = ptu.from_numpy(np.array(list(eval_set_imgs)))
    set_imgs_iterator = set_imgs  # an array is already a valid data iterator
    print("making ungrouped images", time.time() - start)
    start = time.time()
    ungrouped_imgs = generate_images(env,
                                     renderer,
                                     num_images=num_ungrouped_images)
    ungrouped_imgs = ptu.from_numpy(np.array(list(ungrouped_imgs)))
    print("done", time.time() - start)
    return eval_set_imgs, renderer, set_imgs, set_imgs_iterator, ungrouped_imgs
Example #5
def analyze_from_vae(
        snapshot_path,
        latent_observation_key='latent_observation',
        mean_key='latent_mean',
        covariance_key='latent_covariance',
        image_observation_key='image_observation',
        **kwargs
):
    data = torch.load(open(snapshot_path, "rb"))
    variant_path = snapshot_path.replace('params.pt', 'variant.json')
    vae = data['trainer/vae']
    state_env = gym.make('OneObject-PickAndPlace-BigBall-RandomInit-2D-v1')
    renderer = EnvRenderer()
    sets = make_custom_sets(state_env, renderer)
    reward_fn, _ = rewards.create_normal_likelihood_reward_fns(
        latent_observation_key=latent_observation_key,
        mean_key=mean_key,
        covariance_key=covariance_key,
        reward_fn_kwargs=dict(
            drop_log_det_term=True,
            sqrt_reward=True,
        ),
    )
    save_reward_visualizations(sets, vae, state_env, renderer, **kwargs)
Example #6
def create_dataset(env_id,
                   env_class,
                   env_kwargs,
                   generate_set_kwargs,
                   num_ungrouped_images,
                   env=None,
                   renderer=None,
                   sets=None):
    # env = env or create_env(**env_kwargs)
    env = env or get_gym_env(env_id, env_class, env_kwargs)
    renderer = renderer or EnvRenderer(output_image_format='CHW')
    import time
    print("making train set")
    start = time.time()
    if sets is None:
        set_imgs = pnp_util.generate_set_images(env, renderer,
                                                **generate_set_kwargs)
        set_imgs = list(set_imgs)
    else:
        set_imgs = np.array(
            [s.example_dict['example_image'] for s in sets])
    set_imgs = ptu.from_numpy(np.array(set_imgs))
    print("making eval set", time.time() - start)
    start = time.time()
    eval_set_imgs = pnp_util.generate_set_images(env, renderer,
                                                 **generate_set_kwargs)
    eval_set_imgs = ptu.from_numpy(np.array(list(eval_set_imgs)))
    set_imgs_iterator = set_imgs  # an array is already a valid data iterator
    print("making ungrouped images", time.time() - start)
    start = time.time()
    ungrouped_imgs = generate_images(env,
                                     renderer,
                                     num_images=num_ungrouped_images)
    ungrouped_imgs = ptu.from_numpy(np.array(list(ungrouped_imgs)))
    print("done", time.time() - start)
    return eval_set_imgs, renderer, set_imgs, set_imgs_iterator, ungrouped_imgs
Example #7
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        # encoded_env = EncoderWrappedEnv(
        #     img_env,
        #     model,
        #     dict(image_observation="latent_observation", ),
        # )
        # if goal_sampling_mode == "vae_prior":
        #     latent_goal_distribution = PriorDistribution(
        #         model.representation_size,
        #         desired_goal_key,
        #     )
        #     diagnostics = StateImageGoalDiagnosticsFn({}, )
        # elif goal_sampling_mode == "presampled":
        #     diagnostics = state_env.get_contextual_diagnostics
        #     image_goal_distribution = PresampledPathDistribution(
        #         presampled_goals_path,
        #     )

        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        # elif goal_sampling_mode == "reset_of_env":
        #     state_goal_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        #     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        #         state_goal_env,
        #         desired_goal_keys=[state_goal_key],
        #     )
        #     image_goal_distribution = AddImageDistribution(
        #         env=state_env,
        #         base_distribution=state_goal_distribution,
        #         image_goal_key=image_goal_key,
        #         renderer=renderer,
        #     )
        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        #     no_goal_distribution = PriorDistribution(
        #         representation_size=0,
        #         key="no_goal",
        #     )
        #     diagnostics = state_goal_env.get_contextual_diagnostics
        # else:
        #     error
        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )

        reward_fn = GraspingRewardFn(
            # img_env, # state_env,
            # observation_key=observation_key,
            # desired_goal_key=desired_goal_key,
            # **reward_kwargs
        )

        env = ContextualEnv(
            img_env,  # state_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn
Example #8
def goal_conditioned_sac_experiment(
        max_path_length,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        exploration_policy_kwargs=None,
        evaluation_goal_sampling_mode=None,
        exploration_goal_sampling_mode=None,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}
    context_key = desired_goal_key
    sample_context_from_obs_dict_fn = RemapKeyFn({context_key: observation_key})

    def contextual_env_distrib_and_reward(
            env_id, env_class, env_kwargs, goal_sampling_mode
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
        )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn


    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode
    )

    obs_dim = (
            expl_env.observation_space.spaces[observation_key].low.size
            + expl_env.observation_space.spaces[context_key].low.size
    )
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )
    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context], axis=1)
        return batch
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )
    exploration_policy = create_exploration_policy(
        policy=policy, env=expl_env, **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.width,
            image_format=renderer.output_image_format,
            **save_video_kwargs
        )

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
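Since the experiment above is configured entirely through keyword arguments, here is a hedged sketch of one possible call. Every concrete value (the env id, network sizes, and trainer/algorithm settings) is an illustrative assumption rather than a setting taken from the source; the nested dicts are forwarded to ConcatMlp, TanhGaussianPolicy, SACTrainer, ContextualRelabelingReplayBuffer, and TorchBatchRLAlgorithm respectively.

# Illustrative settings only; every value below is an assumption chosen to show
# which knobs the function exposes, not a recommended or original configuration.
example_variant = dict(
    env_id='<your-goal-conditioned-env-id>',  # hypothetical registered env id
    max_path_length=50,
    qf_kwargs=dict(hidden_sizes=[256, 256]),
    policy_kwargs=dict(hidden_sizes=[256, 256]),
    sac_trainer_kwargs=dict(discount=0.99, soft_target_tau=5e-3),
    replay_buffer_kwargs=dict(max_size=int(1e6)),
    algo_kwargs=dict(num_epochs=100, batch_size=256),
    save_video=False,
)
# goal_conditioned_sac_experiment(**example_variant)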
Example #9
def image_based_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    cnn_kwargs,
    policy_type='tanh-normal',
    env_id=None,
    env_class=None,
    env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    reward_type='state_distance',
    env_renderer_kwargs=None,
    # Data augmentations
    apply_random_crops=False,
    random_crop_pixel_shift=4,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not env_renderer_kwargs:
        env_renderer_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    img_observation_key = 'image_observation'
    img_desired_goal_key = 'image_desired_goal'
    state_observation_key = 'state_observation'
    state_desired_goal_key = 'state_desired_goal'
    state_achieved_goal_key = 'state_achieved_goal'
    sample_context_from_obs_dict_fn = RemapKeyFn({
        'image_desired_goal':
        'image_observation',
        'state_desired_goal':
        'state_observation',
    })

    def setup_contextual_env(env_id, env_class, env_kwargs, goal_sampling_mode,
                             renderer):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        state_env.goal_sampling_mode = goal_sampling_mode
        state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        image_goal_distribution = AddImageDistribution(
            env=state_env,
            base_distribution=state_goal_distribution,
            image_goal_key=img_desired_goal_key,
            renderer=renderer,
        )
        goal_distribution = PresampledDistribution(image_goal_distribution,
                                                   5000)
        img_env = InsertImageEnv(state_env, renderer=renderer)
        if reward_type == 'state_distance':
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=state_env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    'state_observation'),
                desired_goal_key=state_desired_goal_key,
                achieved_goal_key=state_achieved_goal_key,
            )
        elif reward_type == 'pixel_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    img_observation_key),
                desired_goal_key=img_desired_goal_key,
            )
        else:
            raise ValueError(reward_type)
        env = ContextualEnv(
            img_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=img_observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, goal_distribution, reward_fn

    env_renderer = EnvRenderer(**env_renderer_kwargs)
    expl_env, expl_context_distrib, expl_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        env_renderer)
    eval_env, eval_context_distrib, eval_reward = setup_contextual_env(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        env_renderer)

    action_dim = expl_env.action_space.low.size
    if env_renderer.output_image_format == 'WHC':
        img_width, img_height, img_num_channels = (
            expl_env.observation_space[img_observation_key].shape)
    elif env_renderer.output_image_format == 'CHW':
        img_num_channels, img_height, img_width = (
            expl_env.observation_space[img_observation_key].shape)
    else:
        raise ValueError(env_renderer.output_image_format)

    def create_qf():
        cnn = BasicCNN(input_width=img_width,
                       input_height=img_height,
                       input_channels=img_num_channels,
                       **cnn_kwargs)
        joint_cnn = ApplyConvToStateAndGoalImage(cnn)
        return basic.MultiInputSequential(
            ApplyToObs(joint_cnn), basic.FlattenEachParallel(),
            ConcatMlp(input_size=joint_cnn.output_size + action_dim,
                      output_size=1,
                      **qf_kwargs))

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()
    cnn = BasicCNN(input_width=img_width,
                   input_height=img_height,
                   input_channels=img_num_channels,
                   **cnn_kwargs)
    joint_cnn = ApplyConvToStateAndGoalImage(cnn)
    policy_obs_dim = joint_cnn.output_size
    if policy_type == 'normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    elif policy_type == 'tanh-normal':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))
    elif policy_type == 'normal-tanh-mean':
        obs_processor = nn.Sequential(
            joint_cnn, basic.Flatten(),
            MultiHeadedMlp(input_size=policy_obs_dim,
                           output_sizes=[action_dim, action_dim],
                           output_activations=['tanh', 'identity'],
                           **policy_kwargs))
        policy = PolicyFromDistributionGenerator(Gaussian(obs_processor))
    else:
        raise ValueError("Unknown policy type: {}".format(policy_type))

    if apply_random_crops:
        pad = BatchPad(
            env_renderer.output_image_format,
            random_crop_pixel_shift,
            random_crop_pixel_shift,
        )
        crop = JointRandomCrop(
            env_renderer.output_image_format,
            env_renderer.image_shape,
        )

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            obs_padded = pad(obs)
            next_obs_padded = pad(next_obs)
            context_padded = pad(context)
            obs_aug, context_aug = crop(obs_padded, context_padded)
            next_obs_aug, next_context_aug = crop(next_obs_padded,
                                                  context_padded)

            batch['observations'] = np.concatenate([obs_aug, context_aug],
                                                   axis=1)
            batch['next_observations'] = np.concatenate(
                [next_obs_aug, next_context_aug], axis=1)
            return batch
    else:

        def concat_context_to_obs(batch, *args, **kwargs):
            obs = batch['observations']
            next_obs = batch['next_observations']
            context = batch[img_desired_goal_key]
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
            return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[img_desired_goal_key, state_desired_goal_key],
        observation_key=img_observation_key,
        observation_keys=[img_observation_key, state_observation_key],
        context_distribution=eval_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=img_observation_key,
        context_keys_for_policy=[img_desired_goal_key],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=img_observation_key,
            context_keys_for_policy=[img_desired_goal_key],
        )
        video_renderer = EnvRenderer(**video_renderer_kwargs)
        video_eval_env = InsertImageEnv(eval_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_expl_env = InsertImageEnv(expl_env,
                                        renderer=video_renderer,
                                        image_key='video_observation')
        video_eval_env = ContextualEnv(
            video_eval_env,
            context_distribution=eval_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        video_expl_env = ContextualEnv(
            video_expl_env,
            context_distribution=expl_env.context_distribution,
            reward_fn=lambda *_: np.array([0]),
            observation_key=img_observation_key,
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            video_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            video_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_shape[1],
            image_formats=[
                env_renderer.output_image_format,
                env_renderer.output_image_format,
                video_renderer.output_image_format,
            ],
            keys_to_show=[
                'image_desired_goal', 'image_observation', 'video_observation'
            ],
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    algorithm.train()
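When apply_random_crops is enabled above, observations and goal images are padded and then cropped with shared random offsets (via the repo's BatchPad and JointRandomCrop helpers). The following is an independent, numpy-only sketch of that idea under the assumption of CHW image batches; it is not the repo's implementation.

import numpy as np

def pad_and_joint_random_crop(obs, goal, shift=4):
    """Pad CHW image batches by `shift` pixels and crop both batches back to
    the original size using the same per-sample random offsets."""
    n, c, h, w = obs.shape
    pad = ((0, 0), (0, 0), (shift, shift), (shift, shift))
    obs_p = np.pad(obs, pad, mode='edge')
    goal_p = np.pad(goal, pad, mode='edge')
    tops = np.random.randint(0, 2 * shift + 1, size=n)
    lefts = np.random.randint(0, 2 * shift + 1, size=n)
    obs_out = np.empty_like(obs)
    goal_out = np.empty_like(goal)
    for i, (t, l) in enumerate(zip(tops, lefts)):
        obs_out[i] = obs_p[i, :, t:t + h, l:l + w]
        goal_out[i] = goal_p[i, :, t:t + h, l:l + w]
    return obs_out, goal_out

# Example: augment a batch of 8 random 3x28x28 "images" and their goal images.
obs = np.random.rand(8, 3, 28, 28).astype(np.float32)
goal = np.random.rand(8, 3, 28, 28).astype(np.float32)
obs_aug, goal_aug = pad_and_joint_random_crop(obs, goal)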
Example #10
def rig_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_goal_key='state_desired_goal',
    state_observation_key='state_observation',
    image_goal_key='image_desired_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=48,
    pretrained_vae_path="",
    init_camera=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn

    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)
    context_key = desired_goal_key

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key],
        observation_keys=[observation_key],
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=RemapKeyFn(
            {context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            goal_image_key=image_goal_key,
            decode_goal_image_key="image_decoded_goal",
            reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            # max_path_length=200,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
Example #11
def get_img_env(env):
    renderer = EnvRenderer(**variant["renderer_kwargs"])
    img_env = InsertImageEnv(GymToMultiEnv(env), renderer=renderer)
Example #12
def encoder_goal_conditioned_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    policy_kwargs,
    # Encoder parameters
    disentangled_qf_kwargs,
    use_parallel_qf=False,
    encoder_kwargs=None,
    encoder_cnn_kwargs=None,
    qf_state_encoder_is_goal_encoder=False,
    reward_type='encoder_distance',
    reward_config=None,
    latent_dim=8,
    # Policy params
    policy_using_encoder_settings=None,
    use_separate_encoder_for_policy=True,
    # Env settings
    env_id=None,
    env_class=None,
    env_kwargs=None,
    contextual_env_kwargs=None,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    num_presampled_goals=5000,
    # Image env parameters
    use_image_observations=False,
    env_renderer_kwargs=None,
    # Video parameters
    save_video=True,
    save_debug_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    # Debugging parameters
    visualize_representation=True,
    distance_scatterplot_save_period=0,
    distance_scatterplot_initial_save_period=0,
    debug_renderer_kwargs=None,
    debug_visualization_kwargs=None,
    use_debug_trainer=False,

    # vae stuff
    train_encoder_as_vae=False,
    vae_trainer_kwargs=None,
    decoder_kwargs=None,
    encoder_backprop_settings="rl_n_vae",
    mlp_for_image_decoder=False,
    pretrain_vae=False,
    pretrain_vae_kwargs=None,

    # Should only be set by transfer_encoder_goal_conditioned_sac_experiment
    is_retraining_from_scratch=False,
    train_results=None,
    rl_csv_fname='progress.csv',
):
    if reward_config is None:
        reward_config = {}
    if encoder_cnn_kwargs is None:
        encoder_cnn_kwargs = {}
    if policy_using_encoder_settings is None:
        policy_using_encoder_settings = {}
    if debug_visualization_kwargs is None:
        debug_visualization_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if contextual_env_kwargs is None:
        contextual_env_kwargs = {}
    if encoder_kwargs is None:
        encoder_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if video_renderer_kwargs is None:
        video_renderer_kwargs = {}
    if debug_renderer_kwargs is None:
        debug_renderer_kwargs = {}

    assert (is_retraining_from_scratch != (train_results is None))

    img_observation_key = 'image_observation'
    state_observation_key = 'state_observation'
    latent_desired_goal_key = 'latent_desired_goal'
    state_desired_goal_key = 'state_desired_goal'
    img_desired_goal_key = 'image_desired_goal'

    backprop_rl_into_encoder = False
    backprop_vae_into_encoder = False
    if 'rl' in encoder_backprop_settings:
        backprop_rl_into_encoder = True
    if 'vae' in encoder_backprop_settings:
        backprop_vae_into_encoder = True

    if use_image_observations:
        env_renderer = EnvRenderer(**env_renderer_kwargs)

    def setup_env(state_env, encoder, reward_fn):
        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            state_env,
            desired_goal_keys=[state_desired_goal_key],
        )
        if use_image_observations:
            goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=goal_distribution,
                image_goal_key=img_desired_goal_key,
                renderer=env_renderer,
            )
            base_env = InsertImageEnv(state_env, renderer=env_renderer)
            goal_distribution = PresampledDistribution(goal_distribution,
                                                       num_presampled_goals)
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
                encoder_input_key=img_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        else:
            base_env = state_env
            goal_distribution = EncodedGoalDictDistribution(
                goal_distribution,
                encoder=encoder,
                keys_to_keep=[state_desired_goal_key],
                encoder_input_key=state_desired_goal_key,
                encoder_output_key=latent_desired_goal_key,
            )
        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            state_env.goal_conditioned_diagnostics,
            desired_goal_key=state_desired_goal_key,
            observation_key=state_observation_key,
        )
        env = ContextualEnv(
            base_env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            contextual_diagnostics_fns=[state_diag_fn],
            update_env_info_fn=delete_info,
            **contextual_env_kwargs,
        )
        return env, goal_distribution

    state_expl_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_expl_env.goal_sampling_mode = exploration_goal_sampling_mode
    state_eval_env = get_gym_env(env_id,
                                 env_class=env_class,
                                 env_kwargs=env_kwargs)
    state_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    if use_image_observations:
        context_keys_to_save = [
            state_desired_goal_key,
            img_desired_goal_key,
            latent_desired_goal_key,
        ]
        context_key_for_rl = img_desired_goal_key
        observation_key_for_rl = img_observation_key

        def create_encoder():
            img_num_channels, img_height, img_width = env_renderer.image_chw
            cnn = BasicCNN(input_width=img_width,
                           input_height=img_height,
                           input_channels=img_num_channels,
                           **encoder_cnn_kwargs)
            cnn_output_size = np.prod(cnn.output_shape)
            mlp = MultiHeadedMlp(input_size=cnn_output_size,
                                 output_sizes=[latent_dim, latent_dim],
                                 **encoder_kwargs)
            enc = nn.Sequential(cnn, Flatten(), mlp)
            enc.input_size = img_width * img_height * img_num_channels
            enc.output_size = latent_dim
            return enc
    else:
        context_keys_to_save = [
            state_desired_goal_key, latent_desired_goal_key
        ]
        context_key_for_rl = state_desired_goal_key
        observation_key_for_rl = state_observation_key

        def create_encoder():
            in_dim = (state_expl_env.observation_space.
                      spaces[state_observation_key].low.size)
            enc = ConcatMultiHeadedMlp(input_size=in_dim,
                                       output_sizes=[latent_dim, latent_dim],
                                       **encoder_kwargs)
            enc.input_size = in_dim
            enc.output_size = latent_dim
            return enc

    encoder_net = create_encoder()
    if train_results:
        print("Using transfer encoder")
        encoder_net = train_results.encoder_net
    mu_encoder_net = EncoderMuFromEncoderDistribution(encoder_net)
    target_encoder_net = create_encoder()
    mu_target_encoder_net = EncoderMuFromEncoderDistribution(
        target_encoder_net)
    encoder_input_dim = encoder_net.input_size

    encoder = EncoderFromNetwork(mu_encoder_net)
    encoder.to(ptu.device)
    if reward_type == 'encoder_distance':
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'target_encoder_distance':
        target_encoder = EncoderFromNetwork(mu_target_encoder_net)
        reward_fn = EncoderRewardFnFromMultitaskEnv(
            encoder=target_encoder,
            next_state_encoder_input_key=observation_key_for_rl,
            context_key=latent_desired_goal_key,
            **reward_config,
        )
    elif reward_type == 'state_distance':
        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=state_expl_env,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                'state_observation'),
            desired_goal_key=state_desired_goal_key,
            achieved_goal_key='state_achieved_goal',
        )
    else:
        raise ValueError("invalid reward type {}".format(reward_type))
    expl_env, expl_context_distrib = setup_env(state_expl_env, encoder,
                                               reward_fn)
    eval_env, eval_context_distrib = setup_env(state_eval_env, encoder,
                                               reward_fn)

    action_dim = expl_env.action_space.low.size

    def make_qf(goal_encoder):
        if not backprop_rl_into_encoder:
            goal_encoder = Detach(goal_encoder)
        if qf_state_encoder_is_goal_encoder:
            state_encoder = goal_encoder
        else:
            state_encoder = EncoderMuFromEncoderDistribution(create_encoder())
            raise RuntimeError(
                "State encoder must be goal encoder for resuming exps")
        if use_parallel_qf:
            return ParallelDisentangledMlpQf(
                goal_encoder=goal_encoder,
                state_encoder=state_encoder,
                preprocess_obs_dim=encoder_input_dim,
                action_dim=action_dim,
                post_encoder_mlp_kwargs=qf_kwargs,
                vectorized=True,
                **disentangled_qf_kwargs)
        else:
            return DisentangledMlpQf(goal_encoder=goal_encoder,
                                     state_encoder=state_encoder,
                                     preprocess_obs_dim=encoder_input_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=True,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf(mu_encoder_net)
    qf2 = make_qf(mu_encoder_net)
    target_qf1 = make_qf(mu_target_encoder_net)
    target_qf2 = make_qf(mu_target_encoder_net)

    if use_separate_encoder_for_policy:
        policy_encoder = EncoderMuFromEncoderDistribution(create_encoder())
        policy_encoder_net = EncodeObsAndGoal(
            policy_encoder,
            encoder_input_dim,
            encode_state=True,
            encode_goal=True,
            detach_encoder_via_goal=False,
            detach_encoder_via_state=False,
        )
    else:
        policy_obs_encoder = mu_encoder_net
        if not backprop_rl_into_encoder:
            policy_obs_encoder = Detach(policy_obs_encoder)
        policy_encoder_net = EncodeObsAndGoal(policy_obs_encoder,
                                              encoder_input_dim,
                                              **policy_using_encoder_settings)
    obs_processor = nn.Sequential(
        policy_encoder_net, ConcatTuple(),
        MultiHeadedMlp(input_size=policy_encoder_net.output_size,
                       output_sizes=[action_dim, action_dim],
                       **policy_kwargs))
    policy = PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key_for_rl]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        batch['raw_next_observations'] = next_obs
        return batch

    if use_image_observations:
        # Sample contexts so that they contain all of the relevant goals: the
        # state goal, the image goal, and the encoded (latent) goal.
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
                img_desired_goal_key: img_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key, img_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [
            state_observation_key, img_observation_key
        ]
    else:
        sample_context_from_observation = compose(
            RemapKeyFn({
                state_desired_goal_key: state_observation_key,
            }),
            ReEncoderAchievedStateFn(
                encoder=encoder,
                encoder_input_key=context_key_for_rl,
                encoder_output_key=latent_desired_goal_key,
                keys_to_keep=[state_desired_goal_key],
            ),
        )
        ob_keys_to_save_in_buffer = [state_observation_key]

    encoder_output_dim = mu_encoder_net.output_size
    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys_to_save,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_observation,
        observation_keys=ob_keys_to_save_in_buffer,
        observation_key=observation_key_for_rl,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        reward_dim=encoder_output_dim,
        **replay_buffer_kwargs)

    disentangled_trainer = DisentangedTrainer(env=expl_env,
                                              policy=policy,
                                              qf1=qf1,
                                              qf2=qf2,
                                              target_qf1=target_qf1,
                                              target_qf2=target_qf2,
                                              **sac_trainer_kwargs)

    if train_encoder_as_vae:
        assert backprop_vae_into_encoder, \
            "No point in training the vae if not backpropagating into encoder"
        if vae_trainer_kwargs is None:
            vae_trainer_kwargs = {}
        if decoder_kwargs is None:
            decoder_kwargs = invert_encoder_mlp_params(encoder_kwargs)

        # VAE training
        def make_decoder():
            if use_image_observations:
                img_num_channels, img_height, img_width = env_renderer.image_chw
                if not mlp_for_image_decoder:
                    dcnn_in_channels, dcnn_in_height, dcnn_in_width = (
                        encoder_net._modules['0'].output_shape)
                    dcnn_input_size = (dcnn_in_channels * dcnn_in_width *
                                       dcnn_in_height)
                    decoder_cnn_kwargs = invert_encoder_params(
                        encoder_cnn_kwargs,
                        img_num_channels,
                    )
                    dcnn = BasicDCNN(input_width=dcnn_in_width,
                                     input_height=dcnn_in_height,
                                     input_channels=dcnn_in_channels,
                                     **decoder_cnn_kwargs)
                    mlp = Mlp(input_size=latent_dim,
                              output_size=dcnn_input_size,
                              **decoder_kwargs)
                    dec = nn.Sequential(mlp, dcnn)
                    dec.input_size = latent_dim
                else:
                    dec = nn.Sequential(
                        Mlp(input_size=latent_dim,
                            output_size=img_num_channels * img_height *
                            img_width,
                            **decoder_kwargs),
                        Reshape(img_num_channels, img_height, img_width))
                    dec.input_size = latent_dim
                    dec.output_size = img_num_channels * img_height * img_width
                return dec
            else:
                return Mlp(input_size=latent_dim,
                           output_size=encoder_input_dim,
                           **decoder_kwargs)

        decoder_net = make_decoder()
        vae = VAE(encoder_net, decoder_net)

        vae_trainer = VAETrainer(vae=vae, **vae_trainer_kwargs)

        if pretrain_vae:
            goal_key = (img_desired_goal_key
                        if use_image_observations else state_desired_goal_key)
            vae.to(ptu.device)
            train_ae(vae_trainer,
                     expl_context_distrib,
                     **pretrain_vae_kwargs,
                     goal_key=goal_key,
                     rl_csv_fname=rl_csv_fname)
        trainers = OrderedDict()
        trainers['vae_trainer'] = vae_trainer
        trainers['disentangled_trainer'] = disentangled_trainer
        trainer = JointTrainer(trainers)
    else:
        trainer = disentangled_trainer

    if not use_image_observations and use_debug_trainer:
        # TODO: implement this for images
        debug_trainer = DebugTrainer(
            observation_space=expl_env.observation_space.
            spaces[state_observation_key],
            encoder=mu_encoder_net,
            encoder_output_dim=encoder_output_dim,
        )
        trainer = JointTrainer([trainer, debug_trainer])

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )
    exploration_policy = create_exploration_policy(policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key_for_rl,
        context_keys_for_policy=[context_key_for_rl],
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    video_renderer = EnvRenderer(**video_renderer_kwargs)
    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key_for_rl,
            context_keys_for_policy=[context_key_for_rl],
        )
        if save_debug_video and not use_image_observations:
            # TODO: add visualization for image-based envs
            obj1_sweep_renderers = {
                'sweep_obj1_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }
            obj0_sweep_renderers = {
                'sweep_obj0_%d' % i: DebugEnvRenderer(encoder, i,
                                                      **debug_renderer_kwargs)
                for i in range(encoder_output_dim)
            }

            debugger_one = DebugEnvRenderer(encoder, 0,
                                            **debug_renderer_kwargs)

            low = eval_env.env.observation_space[
                state_observation_key].low.min()
            high = eval_env.env.observation_space[
                state_observation_key].high.max()
            y = np.linspace(low, high, num=debugger_one.image_shape[0])
            x = np.linspace(low, high, num=debugger_one.image_shape[1])
            cross = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
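            # Editor's note: `cross` is the Cartesian product of the x/y pixel grid,
            # one row per candidate 2D position. Each DebugEnvRenderer sweep below
            # encodes every grid position (with the other entity held fixed) so that
            # a latent dimension can be rendered as a heatmap over positions.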

            def create_shared_data_creator(obj_index):
                def compute_shared_data(raw_obs, env):
                    state = raw_obs[observation_key_for_rl]
                    obs = state[:2]
                    goal = state[2:]
                    if obj_index == 0:
                        new_states = np.concatenate(
                            [
                                np.repeat(obs[None, :], cross.shape[0],
                                          axis=0),
                                cross,
                            ],
                            axis=1,
                        )
                    elif obj_index == 1:
                        new_states = np.concatenate(
                            [
                                cross,
                                np.repeat(
                                    goal[None, :], cross.shape[0], axis=0),
                            ],
                            axis=1,
                        )
                    else:
                        raise ValueError(obj_index)
                    return encoder.encode(new_states)

                return compute_shared_data

            obj0_sweeper = create_shared_data_creator(0)
            obj1_sweeper = create_shared_data_creator(1)

            def add_images(env, base_distribution):
                if use_image_observations:
                    img_env = env
                    image_goal_distribution = base_distribution
                else:
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                    img_env = InsertImageEnv(state_env,
                                             renderer=video_renderer)
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj1_sweep_renderers,
                    compute_shared_data=obj1_sweeper,
                )
                img_env = InsertDebugImagesEnv(
                    img_env,
                    obj0_sweep_renderers,
                    compute_shared_data=obj0_sweeper,
                )
                return ContextualEnv(
                    img_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            def get_extra_imgs(
                path,
                index_in_path,
                env,
            ):
                return [
                    path['full_observations'][index_in_path][key]
                    for key in obj1_sweep_renderers
                ] + [
                    path['full_observations'][index_in_path][key]
                    for key in obj0_sweep_renderers
                ]

            img_formats = [video_renderer.output_image_format]
            for r in obj1_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            for r in obj0_sweep_renderers.values():
                img_formats.append(r.output_image_format)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                get_extra_imgs=get_extra_imgs,
                **save_video_kwargs)
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)

            def add_images(env, base_distribution):
                if use_image_observations:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='video_observation',
                    )
                    image_goal_distribution = base_distribution
                else:
                    video_env = InsertImageEnv(
                        env,
                        renderer=video_renderer,
                        image_key='image_observation',
                    )
                    state_env = env.env
                    image_goal_distribution = AddImageDistribution(
                        env=state_env,
                        base_distribution=base_distribution,
                        image_goal_key='image_desired_goal',
                        renderer=video_renderer,
                    )
                return ContextualEnv(
                    video_env,
                    context_distribution=image_goal_distribution,
                    reward_fn=reward_fn,
                    observation_key=observation_key_for_rl,
                    update_env_info_fn=delete_info,
                )

            img_eval_env = add_images(eval_env, eval_context_distrib)
            img_expl_env = add_images(expl_env, expl_context_distrib)

            if use_image_observations:
                keys_to_show = [
                    'image_desired_goal',
                    'image_observation',
                    'video_observation',
                ]
                image_formats = [
                    env_renderer.output_image_format,
                    env_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            else:
                keys_to_show = ['image_desired_goal', 'image_observation']
                image_formats = [
                    video_renderer.output_image_format,
                    video_renderer.output_image_format,
                ]
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                MakeDeterministic(policy),
                tag="eval",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                exploration_policy,
                tag="train",
                imsize=video_renderer.image_chw[1],
                keys_to_show=keys_to_show,
                image_formats=image_formats,
                **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)
    if visualize_representation:
        from multiworld.envs.pygame import PickAndPlaceEnv
        if not isinstance(state_eval_env, PickAndPlaceEnv):
            raise NotImplementedError()
        num_objects = (
            state_eval_env.observation_space[state_observation_key].low.size // 2)
        state_space = state_eval_env.observation_space['state_observation']
        start_states = np.vstack([state_space.sample() for _ in range(6)])
        if use_image_observations:

            def state_to_encoder_input(state):
                goal_dict = {
                    'state_desired_goal': state,
                }
                env_state = state_eval_env.get_env_state()
                state_eval_env.set_to_goal(goal_dict)
                start_img = env_renderer(state_eval_env)
                state_eval_env.set_env_state(env_state)
                return start_img

            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder,
                    i,
                    eval_env,
                    video_renderer,
                    state_to_encoder_input=state_to_encoder_input,
                    env_renderer=env_renderer,
                    start_states=start_states,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)
        else:
            for i in range(num_objects):
                visualize_representation = create_visualize_representation(
                    encoder, i, eval_env, video_renderer,
                    **debug_visualization_kwargs)
                algorithm.post_train_funcs.append(visualize_representation)

    if distance_scatterplot_save_period > 0:
        algorithm.post_train_funcs.append(
            create_save_h_vs_state_distance_fn(
                distance_scatterplot_save_period,
                distance_scatterplot_initial_save_period,
                encoder,
                observation_key_for_rl,
            ))
    algorithm.train()
    train_results = dict(encoder_net=encoder_net, )
    TrainResults = namedtuple('TrainResults', sorted(train_results))
    return TrainResults(**train_results)
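
# Editor's addition -- a minimal, self-contained sketch (not from the original example)
# of the namedtuple-from-dict pattern used above to package results. Sorting the keys
# keeps the field order deterministic regardless of dict insertion order.
from collections import namedtuple

_demo_results = dict(encoder_net='encoder stub', vae='vae stub')
DemoResults = namedtuple('DemoResults', sorted(_demo_results))
assert DemoResults(**_demo_results).encoder_net == 'encoder stub'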
Beispiel #13
0
    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)

        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        encoded_env = EncoderWrappedEnv(
            img_env,
            model,
            dict(image_observation="latent_observation", ),
        )
        if goal_sampling_mode == "vae_prior":
            latent_goal_distribution = PriorDistribution(
                model.representation_size,
                desired_goal_key,
            )
            diagnostics = StateImageGoalDiagnosticsFn({}, )
        elif goal_sampling_mode == "reset_of_env":
            state_goal_env = get_gym_env(env_id,
                                         env_class=env_class,
                                         env_kwargs=env_kwargs)
            state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
                state_goal_env,
                desired_goal_keys=[state_goal_key],
            )
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_goal_distribution,
                image_goal_key=image_goal_key,
                renderer=renderer,
            )
            latent_goal_distribution = AddLatentDistribution(
                image_goal_distribution,
                image_goal_key,
                desired_goal_key,
                model,
            )
            if hasattr(state_goal_env, 'goal_conditioned_diagnostics'):
                diagnostics = GoalConditionedDiagnosticsToContextualDiagnostics(
                    state_goal_env.goal_conditioned_diagnostics,
                    desired_goal_key=state_goal_key,
                    observation_key=state_observation_key,
                )
            else:
                diagnostics = state_goal_env.get_contextual_diagnostics
        else:
            raise NotImplementedError('unknown goal sampling method: %s' %
                                      goal_sampling_mode)

        reward_fn = DistanceRewardFn(
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

        env = ContextualEnv(
            encoded_env,
            context_distribution=latent_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, latent_goal_distribution, reward_fn
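
    # Editor's note (assumed usage, not shown in this snippet): the enclosing launcher
    # presumably calls this helper once per goal sampling mode, e.g.
    #   expl_env, expl_goal_distrib, expl_reward_fn = contextual_env_distrib_and_reward(
    #       env_id, env_class, env_kwargs, exploration_goal_sampling_mode)
    #   eval_env, eval_goal_distrib, eval_reward_fn = contextual_env_distrib_and_reward(
    #       env_id, env_class, env_kwargs, evaluation_goal_sampling_mode)
    # The variable names above are illustrative only.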
Beispiel #14
0
def probabilistic_goal_reaching_experiment(
    max_path_length,
    qf_kwargs,
    policy_kwargs,
    pgr_trainer_kwargs,
    replay_buffer_kwargs,
    algo_kwargs,
    env_id,
    discount_factor,
    reward_type,
    # Dynamics model
    dynamics_model_version,
    dynamics_model_config,
    dynamics_delta_model_config=None,
    dynamics_adam_config=None,
    dynamics_ensemble_kwargs=None,
    # Discount model
    learn_discount_model=False,
    discount_adam_config=None,
    discount_model_config=None,
    prior_discount_weight_schedule_kwargs=None,
    # Environment
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    exploration_policy_kwargs=None,
    action_noise_scale=0.,
    num_presampled_goals=4096,
    success_threshold=0.05,
    # Video / visualization parameters
    save_video=True,
    save_video_kwargs=None,
    video_renderer_kwargs=None,
    plot_renderer_kwargs=None,
    eval_env_ids=None,
    # Debugging params
    visualize_dynamics=False,
    visualize_discount_model=False,
    visualize_all_plots=False,
    plot_discount=False,
    plot_reward=False,
    plot_bootstrap_value=False,
    # env specific-params
    normalize_distances_for_full_state_ant=False,
):
    if dynamics_ensemble_kwargs is None:
        dynamics_ensemble_kwargs = {}
    if eval_env_ids is None:
        eval_env_ids = {'eval': env_id}
    if discount_model_config is None:
        discount_model_config = {}
    if dynamics_delta_model_config is None:
        dynamics_delta_model_config = {}
    if dynamics_adam_config is None:
        dynamics_adam_config = {}
    if discount_adam_config is None:
        discount_adam_config = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not video_renderer_kwargs:
        video_renderer_kwargs = {}
    if not plot_renderer_kwargs:
        plot_renderer_kwargs = video_renderer_kwargs.copy()
        plot_renderer_kwargs['dpi'] = 48
    context_key = desired_goal_key

    stub_env = get_gym_env(
        env_id,
        env_class=env_class,
        env_kwargs=env_kwargs,
        unwrap_timed_envs=True,
    )
    is_gym_env = (
        isinstance(stub_env, FetchEnv) or isinstance(stub_env, AntXYGoalEnv)
        or isinstance(stub_env, AntFullPositionGoalEnv)
        # or isinstance(stub_env, HopperFullPositionGoalEnv)
    )
    is_ant_full_pos = isinstance(stub_env, AntFullPositionGoalEnv)

    if is_gym_env:
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        achieved_goal_key = desired_goal_key.replace('desired', 'achieved')
        ob_keys_to_save_in_buffer = [observation_key, achieved_goal_key]
    else:
        achieved_goal_key = observation_key
        ob_keys_to_save_in_buffer = [observation_key]
    # TODO move all env-specific code to other file
    if isinstance(stub_env, SawyerDoorHookEnv):
        init_camera = sawyer_door_env_camera_v0
    elif isinstance(stub_env, SawyerPushAndReachXYEnv):
        init_camera = sawyer_init_camera_zoomed_in
    elif isinstance(stub_env, SawyerPickAndPlaceEnvYZ):
        init_camera = sawyer_pick_and_place_camera
    else:
        init_camera = None

    full_ob_space = stub_env.observation_space
    action_space = stub_env.action_space
    state_to_goal = StateToGoalFn(stub_env)
    dynamics_model = create_goal_dynamics_model(
        full_ob_space[observation_key],
        action_space,
        full_ob_space[achieved_goal_key],
        dynamics_model_version,
        state_to_goal,
        dynamics_model_config,
        dynamics_delta_model_config,
        ensemble_model_kwargs=dynamics_ensemble_kwargs,
    )
    sample_context_from_obs_dict_fn = RemapKeyFn(
        {context_key: achieved_goal_key})

    def contextual_env_distrib_reward(_env_id,
                                      _env_class=None,
                                      _env_kwargs=None):
        base_env = get_gym_env(
            _env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            unwrap_timed_envs=True,
        )
        if init_camera:
            base_env.initialize_camera(init_camera)
        if (isinstance(stub_env, AntFullPositionGoalEnv)
                and normalize_distances_for_full_state_ant):
            base_env = NormalizeAntFullPositionGoalEnv(base_env)
            normalize_env = base_env
        else:
            normalize_env = None
        env = NoisyAction(base_env, action_noise_scale)
        diag_fns = []
        if is_gym_env:
            goal_distribution = GoalDictDistributionFromGymGoalEnv(
                env,
                desired_goal_key=desired_goal_key,
            )
            diag_fns.append(
                GenericGoalConditionedContextualDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                ))
        else:
            goal_distribution = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            diag_fns.append(
                GoalConditionedDiagnosticsToContextualDiagnostics(
                    env.goal_conditioned_diagnostics,
                    desired_goal_key=desired_goal_key,
                    observation_key=observation_key,
                ))
        if isinstance(stub_env, AntFullPositionGoalEnv):
            diag_fns.append(
                AntFullPositionGoalEnvDiagnostics(
                    desired_goal_key=desired_goal_key,
                    achieved_goal_key=achieved_goal_key,
                    success_threshold=success_threshold,
                    normalize_env=normalize_env,
                ))
        # if isinstance(stub_env, HopperFullPositionGoalEnv):
        #     diag_fns.append(
        #         HopperFullPositionGoalEnvDiagnostics(
        #             desired_goal_key=desired_goal_key,
        #             achieved_goal_key=achieved_goal_key,
        #             success_threshold=success_threshold,
        #         )
        #     )
        achieved_from_ob = IndexIntoAchievedGoal(achieved_goal_key, )
        if reward_type == 'sparse':
            distance_fn = L2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
            reward_fn = ThresholdDistanceReward(distance_fn, success_threshold)
        elif reward_type == 'negative_distance':
            reward_fn = NegativeL2Distance(
                achieved_goal_from_observation=achieved_from_ob,
                desired_goal_key=desired_goal_key,
            )
        else:
            reward_fn = ProbabilisticGoalRewardFn(
                dynamics_model,
                state_key=observation_key,
                context_key=context_key,
                reward_type=reward_type,
                discount_factor=discount_factor,
            )
        goal_distribution = PresampledDistribution(goal_distribution,
                                                   num_presampled_goals)
        final_env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=diag_fns,
            update_env_info_fn=delete_info,
        )
        return final_env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, reward_fn = contextual_env_distrib_reward(
        env_id,
        env_class,
        env_kwargs,
    )
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    def create_policy():
        obs_processor = MultiHeadedMlp(input_size=obs_dim,
                                       output_sizes=[action_dim, action_dim],
                                       **policy_kwargs)
        return PolicyFromDistributionGenerator(TanhGaussian(obs_processor))

    policy = create_policy()
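    # Editor's note (interpretation): MultiHeadedMlp emits two heads of size
    # action_dim, which TanhGaussian presumably consumes as the mean and log-std of a
    # squashed Gaussian action distribution; PolicyFromDistributionGenerator wraps
    # that generator into a policy interface.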

    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch['original_observations'] = obs
        batch['original_next_observations'] = next_obs
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch
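
    # Editor's note (assumption about the replay-buffer contract): this hook is applied
    # to every sampled batch after goal relabeling, so concatenating the context onto
    # the observations here turns the goal-conditioned data into flat inputs for the
    # Q-functions and policy, while the originals stay available under the
    # 'original_observations' / 'original_next_observations' keys used by the dynamics
    # and discount trainers below.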

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=expl_env,
        context_keys=[context_key],
        observation_keys=ob_keys_to_save_in_buffer,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=sample_context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    def create_trainer():
        trainers = OrderedDict()
        if learn_discount_model:
            discount_model = create_discount_model(
                ob_space=stub_env.observation_space[observation_key],
                goal_space=stub_env.observation_space[context_key],
                action_space=stub_env.action_space,
                model_kwargs=discount_model_config)
            optimizer = optim.Adam(discount_model.parameters(),
                                   **discount_adam_config)
            discount_trainer = DiscountModelTrainer(
                discount_model,
                optimizer,
                observation_key='observations',
                next_observation_key='original_next_observations',
                goal_key=context_key,
                state_to_goal_fn=state_to_goal,
            )
            trainers['discount_trainer'] = discount_trainer
        else:
            discount_model = None
        if prior_discount_weight_schedule_kwargs is not None:
            schedule = create_schedule(**prior_discount_weight_schedule_kwargs)
        else:
            schedule = None
        pgr_trainer = PGRTrainer(env=expl_env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 discount=discount_factor,
                                 discount_model=discount_model,
                                 prior_discount_weight_schedule=schedule,
                                 **pgr_trainer_kwargs)
        trainers[''] = pgr_trainer
        optimizers = [
            pgr_trainer.qf1_optimizer,
            pgr_trainer.qf2_optimizer,
            pgr_trainer.alpha_optimizer,
            pgr_trainer.policy_optimizer,
        ]
        if dynamics_model_version in {
                'learned_model',
                'learned_model_ensemble',
                'learned_model_laplace',
                'learned_model_laplace_global_variance',
                'learned_model_gaussian_global_variance',
        }:
            model_opt = optim.Adam(dynamics_model.parameters(),
                                   **dynamics_adam_config)
        elif dynamics_model_version in {
                'fixed_standard_laplace',
                'fixed_standard_gaussian',
        }:
            model_opt = None
        else:
            raise NotImplementedError()
        model_trainer = GenerativeGoalDynamicsModelTrainer(
            dynamics_model,
            model_opt,
            state_to_goal=state_to_goal,
            observation_key='original_observations',
            next_observation_key='original_next_observations',
        )
        trainers['dynamics_trainer'] = model_trainer
        optimizers.append(model_opt)
        return JointTrainer(trainers), pgr_trainer

    trainer, pgr_trainer = create_trainer()

    eval_policy = MakeDeterministic(policy)

    def create_eval_path_collector(some_eval_env):
        return ContextualPathCollector(
            some_eval_env,
            eval_policy,
            observation_key=observation_key,
            context_keys_for_policy=[context_key],
        )

    path_collectors = dict()
    eval_env_name_to_env_and_context_distrib = dict()
    for name, extra_env_id in eval_env_ids.items():
        env, context_distrib, _ = contextual_env_distrib_reward(extra_env_id)
        path_collectors[name] = create_eval_path_collector(env)
        eval_env_name_to_env_and_context_distrib[name] = (env, context_distrib)
    eval_path_collector = JointPathCollector(path_collectors)
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[context_key],
    )

    def get_eval_diagnostics(key_to_paths):
        stats = OrderedDict()
        for eval_env_name, paths in key_to_paths.items():
            env, _ = eval_env_name_to_env_and_context_distrib[eval_env_name]
            stats.update(
                add_prefix(
                    env.get_diagnostics(paths),
                    eval_env_name,
                    divider='/',
                ))
            stats.update(
                add_prefix(
                    eval_util.get_generic_path_information(paths),
                    eval_env_name,
                    divider='/',
                ))
        return stats

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        evaluation_get_diagnostic_functions=[get_eval_diagnostics],
        **algo_kwargs)
    algorithm.to(ptu.device)

    if normalize_distances_for_full_state_ant and is_ant_full_pos:
        qpos_weights = expl_env.unwrapped.presampled_qpos.std(axis=0)
    else:
        qpos_weights = None

    if save_video:
        if is_gym_env:
            video_renderer = GymEnvRenderer(**video_renderer_kwargs)

            def set_goal_for_visualization(env, policy, o):
                goal = o[desired_goal_key]
                if normalize_distances_for_full_state_ant and is_ant_full_pos:
                    unnormalized_goal = goal * qpos_weights
                    env.unwrapped.goal = unnormalized_goal
                else:
                    env.unwrapped.goal = goal

            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=set_goal_for_visualization,
            )
        else:
            video_renderer = EnvRenderer(**video_renderer_kwargs)
            rollout_function = partial(
                rf.contextual_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                context_keys_for_policy=[context_key],
                reset_callback=None,
            )

        renderers = OrderedDict(image_observation=video_renderer, )
        state_env = expl_env.env
        state_space = state_env.observation_space[observation_key]
        low = state_space.low.min()
        high = state_space.high.max()
        y = np.linspace(low, high, num=video_renderer.image_chw[1])
        x = np.linspace(low, high, num=video_renderer.image_chw[2])
        all_xy_np = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
        all_xy_torch = ptu.from_numpy(all_xy_np)
        num_states = all_xy_torch.shape[0]
        if visualize_dynamics:

            def create_dynamics_visualizer(show_prob, vary_state=False):
                def get_prob(obs_dict, action):
                    obs = obs_dict['state_observation']
                    obs_torch = ptu.from_numpy(obs)[None]
                    action_torch = ptu.from_numpy(action)[None]
                    if vary_state:
                        action_repeated = torch.zeros((num_states, 2))
                        dist = dynamics_model(all_xy_torch, action_repeated)
                        goal = ptu.from_numpy(
                            obs_dict['state_desired_goal'][None])
                        log_probs = dist.log_prob(goal)
                    else:
                        dist = dynamics_model(obs_torch, action_torch)
                        log_probs = dist.log_prob(all_xy_torch)
                    if show_prob:
                        return log_probs.exp()
                    else:
                        return log_probs

                return get_prob

            renderers['log_prob'] = ValueRenderer(
                create_dynamics_visualizer(False), **video_renderer_kwargs)
            # renderers['prob'] = ValueRenderer(
            #     create_dynamics_visualizer(True), **video_renderer_kwargs
            # )
            renderers['log_prob_vary_state'] = ValueRenderer(
                create_dynamics_visualizer(False, vary_state=True),
                only_get_image_once_per_episode=True,
                max_out_walls=isinstance(stub_env, PickAndPlaceEnv),
                **video_renderer_kwargs)
            # renderers['prob_vary_state'] = ValueRenderer(
            #     create_dynamics_visualizer(True, vary_state=True),
            #     **video_renderer_kwargs)
        if visualize_discount_model and pgr_trainer.discount_model:

            def get_discount_values(obs, action):
                obs = obs['state_observation']
                obs_torch = ptu.from_numpy(obs)[None]
                combined_obs = torch.cat(
                    [obs_torch.repeat(num_states, 1), all_xy_torch], dim=1)

                action_torch = ptu.from_numpy(action)[None]
                action_repeated = action_torch.repeat(num_states, 1)
                return pgr_trainer.discount_model(combined_obs,
                                                  action_repeated)

            renderers['discount_model'] = ValueRenderer(
                get_discount_values,
                states_to_eval=all_xy_torch,
                **video_renderer_kwargs)
        if 'log_prob' in renderers and 'discount_model' in renderers:
            renderers['log_prob_time_discount'] = ProductRenderer(
                renderers['discount_model'], renderers['log_prob'],
                **video_renderer_kwargs)

        def get_reward(obs_dict, action, next_obs_dict):
            o = batchify(obs_dict)
            a = batchify(action)
            next_o = batchify(next_obs_dict)
            reward = reward_fn(o, a, next_o, next_o)
            return reward[0]

        def get_bootstrap(obs_dict, action, next_obs_dict, return_float=True):
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            next_o_pt = ptu.from_numpy(next_obs_dict[observation_key][None])
            action_torch = ptu.from_numpy(action[None])
            bootstrap, *_ = pgr_trainer.get_bootstrap_stats(
                torch.cat((o_pt, context_pt), dim=1),
                action_torch,
                torch.cat((next_o_pt, context_pt), dim=1),
            )
            if return_float:
                return ptu.get_numpy(bootstrap)[0, 0]
            else:
                return bootstrap

        def get_discount(obs_dict, action, next_obs_dict):
            bootstrap = get_bootstrap(obs_dict,
                                      action,
                                      next_obs_dict,
                                      return_float=False)
            reward_np = get_reward(obs_dict, action, next_obs_dict)
            reward = ptu.from_numpy(reward_np[None, None])
            context_pt = ptu.from_numpy(obs_dict[context_key][None])
            o_pt = ptu.from_numpy(obs_dict[observation_key][None])
            obs = torch.cat((o_pt, context_pt), dim=1)
            actions = ptu.from_numpy(action[None])
            discount = pgr_trainer.get_discount_factor(
                bootstrap,
                reward,
                obs,
                actions,
            )
            if isinstance(discount, torch.Tensor):
                discount = ptu.get_numpy(discount)[0, 0]
            return np.clip(discount, a_min=1e-3, a_max=1)

        def create_modify_fn(
            title,
            set_params=None,
            scientific=True,
        ):
            def modify(ax):
                ax.set_title(title)
                if set_params:
                    ax.set(**set_params)
                if scientific:
                    scaler = ScalarFormatter(useOffset=True)
                    scaler.set_powerlimits((1, 1))
                    ax.yaxis.set_major_formatter(scaler)
                    ax.ticklabel_format(axis='y', style='sci')

            return modify

        def add_left_margin(fig):
            fig.subplots_adjust(left=0.2)

        if visualize_all_plots or plot_discount:
            renderers['discount'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_discount,
                modify_ax_fn=create_modify_fn(
                    title='discount',
                    set_params=dict(
                        # yscale='log',
                        ylim=[-0.05, 1.1], ),
                    # scientific=False,
                ),
                modify_fig_fn=add_left_margin,
                # autoscale_y=False,
                **plot_renderer_kwargs)

        if visualize_all_plots or plot_reward:
            renderers['reward'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_reward,
                modify_ax_fn=create_modify_fn(title='reward',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)
        if visualize_all_plots or plot_bootstrap_value:
            renderers['bootstrap-value'] = DynamicNumberEnvRenderer(
                dynamic_number_fn=get_bootstrap,
                modify_ax_fn=create_modify_fn(title='bootstrap value',
                                              # scientific=False,
                                              ),
                modify_fig_fn=add_left_margin,
                **plot_renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            if is_gym_env:
                goal_distribution = state_distribution
            else:
                goal_distribution = AddImageDistribution(
                    env=state_env,
                    base_distribution=state_distribution,
                    image_goal_key='image_desired_goal',
                    renderer=video_renderer,
                )
            context_env = ContextualEnv(
                state_env,
                context_distribution=goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return InsertDebugImagesEnv(
                context_env,
                renderers=renderers,
            )

        img_expl_env = add_images(expl_env, expl_context_distrib)
        if is_gym_env:
            imgs_to_show = list(renderers.keys())
        else:
            imgs_to_show = ['image_desired_goal'] + list(renderers.keys())
        img_formats = [video_renderer.output_image_format]
        img_formats += [r.output_image_format for r in renderers.values()]
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="xplor",
            imsize=video_renderer.image_chw[1],
            image_formats=img_formats,
            keys_to_show=imgs_to_show,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)
        for eval_env_name, (env, context_distrib) in (
                eval_env_name_to_env_and_context_distrib.items()):
            img_eval_env = add_images(env, context_distrib)
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag=eval_env_name,
                imsize=video_renderer.image_chw[1],
                image_formats=img_formats,
                keys_to_show=imgs_to_show,
                **save_video_kwargs)
            algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
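
# Editor's addition -- standalone sketch of the tile/repeat idiom used above to build
# `all_xy_np` (and the debug `cross` grid in an earlier example): the Cartesian product
# of a 2D grid of positions. It is equivalent to flattening np.meshgrid; nothing below
# depends on rlkit.
import numpy as np

x = np.linspace(0.0, 1.0, num=3)
y = np.linspace(0.0, 1.0, num=2)
cross = np.transpose([np.tile(x, len(y)), np.repeat(y, len(x))])
xx, yy = np.meshgrid(x, y)
assert np.allclose(cross, np.stack([xx.ravel(), yy.ravel()], axis=1))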
Beispiel #15
0
def representation_learning_with_goal_distribution_launcher(
        max_path_length,
        contextual_replay_buffer_kwargs,
        sac_trainer_kwargs,
        algo_kwargs,
        qf_kwargs=None,
        policy_kwargs=None,
        # env settings
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='latent_observation',
        desired_goal_key='latent_desired_goal',
        achieved_goal_key='latent_achieved_goal',
        renderer_kwargs=None,
        # mask settings
        mask_variant=None,  # TODO: manually unpack this as well
        mask_conditioned=True,
        mask_format='vector',
        infer_masks=False,
        # rollout
        expl_goal_sampling_mode=None,
        eval_goal_sampling_mode=None,
        eval_rollouts_for_videos=None,
        eval_rollouts_to_log=None,
        # debugging
        log_mask_diagnostics=True,
        log_expl_video=True,
        log_eval_video=True,
        save_video=True,
        save_video_period=50,
        save_env_in_snapshot=True,
        dump_video_kwargs=None,
        # re-loading
        ckpt=None,
        ckpt_epoch=None,
        seedid=0,
):
    if eval_rollouts_to_log is None:
        eval_rollouts_to_log = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
    if renderer_kwargs is None:
        renderer_kwargs = {}
    if dump_video_kwargs is None:
        dump_video_kwargs = {}
    if eval_rollouts_for_videos is None:
        eval_rollouts_for_videos = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
    if mask_variant is None:
        mask_variant = {}
    if policy_kwargs is None:
        policy_kwargs = {}
    if qf_kwargs is None:
        qf_kwargs = {}
    context_key = desired_goal_key
    prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)

    context_post_process_mode = mask_variant.get('context_post_process_mode',
                                                 None)
    if context_post_process_mode in [
        'dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed'
    ]:
        prev_subtask_weight = 0.5
    prev_subtasks_solved = mask_variant.get('prev_subtasks_solved', False)
    max_subtasks_to_focus_on = mask_variant.get(
        'max_subtasks_to_focus_on', None)
    max_subtasks_per_rollout = mask_variant.get(
        'max_subtasks_per_rollout', None)
    mask_groups = mask_variant.get('mask_groups', None)
    rollout_mask_order_for_expl = mask_variant.get(
        'rollout_mask_order_for_expl', 'fixed')
    rollout_mask_order_for_eval = mask_variant.get(
        'rollout_mask_order_for_eval', 'fixed')
    masks = mask_variant.get('masks', None)
    idx_masks = mask_variant.get('idx_masks', None)
    matrix_masks = mask_variant.get('matrix_masks', None)
    train_mask_distr = mask_variant.get('train_mask_distr', None)
    mask_inference_variant = mask_variant.get('mask_inference_variant', {})
    mask_reward_fn = mask_variant.get('reward_fn', default_masked_reward_fn)
    expl_mask_distr = mask_variant['expl_mask_distr']
    eval_mask_distr = mask_variant['eval_mask_distr']
    use_g_for_mean = mask_variant['use_g_for_mean']
    context_post_process_frac = mask_variant.get(
        'context_post_process_frac', 0.50)
    sample_masks_for_relabeling = mask_variant.get(
        'sample_masks_for_relabeling', True)

    if mask_conditioned:
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        assert mask_format in ['vector', 'matrix', 'distribution']
        goal_dim = env.observation_space.spaces[context_key].low.size
        if mask_format == 'vector':
            mask_keys = ['mask']
            mask_dims = [(goal_dim,)]
            context_dim = goal_dim + goal_dim
        elif mask_format == 'matrix':
            mask_keys = ['mask']
            mask_dims = [(goal_dim, goal_dim)]
            context_dim = goal_dim + (goal_dim * goal_dim)
        elif mask_format == 'distribution':
            mask_keys = ['mask_mu_w', 'mask_mu_g', 'mask_mu_mat',
                         'mask_sigma_inv']
            mask_dims = [(goal_dim,), (goal_dim,), (goal_dim, goal_dim),
                         (goal_dim, goal_dim)]
            context_dim = goal_dim + (goal_dim * goal_dim)  # mu and sigma_inv
        else:
            raise NotImplementedError

        if infer_masks:
            assert mask_format == 'distribution'
            env_kwargs_copy = copy.deepcopy(env_kwargs)
            env_kwargs_copy['lite_reset'] = True
            infer_masks_env = get_gym_env(env_id, env_class=env_class,
                                          env_kwargs=env_kwargs_copy)

            masks = infer_masks_fn(
                infer_masks_env,
                idx_masks,
                mask_inference_variant,
            )

        context_keys = [context_key] + mask_keys
    else:
        context_keys = [context_key]

    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)

        if mode == 'expl':
            goal_sampling_mode = expl_goal_sampling_mode
        elif mode == 'eval':
            goal_sampling_mode = eval_goal_sampling_mode
        else:
            goal_sampling_mode = None
        if goal_sampling_mode is not None:
            env.goal_sampling_mode = goal_sampling_mode

        if mask_conditioned:
            context_distrib = MaskedGoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_keys=mask_keys,
                mask_dims=mask_dims,
                mask_format=mask_format,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                prev_subtask_weight=prev_subtask_weight,
                masks=masks,
                idx_masks=idx_masks,
                matrix_masks=matrix_masks,
                mask_distr=train_mask_distr,
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
                additional_context_keys=mask_keys,
                reward_fn=partial(
                    mask_reward_fn,
                    mask_format=mask_format,
                    use_g_for_mean=use_g_for_mean
                ),
            )
        else:
            context_distrib = GoalDictDistributionFromMultitaskEnv(
                env,
                desired_goal_keys=[desired_goal_key],
            )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(
                    achieved_goal_key),  # observation_key
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=contextual_replay_buffer_kwargs.get(
                    'observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(
        mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(
        mode='eval')

    if mask_conditioned:
        obs_dim = (
                env.observation_space.spaces[observation_key].low.size
                + context_dim
        )
    else:
        obs_dim = (
                env.observation_space.spaces[observation_key].low.size
                + env.observation_space.spaces[context_key].low.size
        )

    action_dim = env.action_space.low.size

    if ckpt:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        if ckpt_epoch is not None:
            epoch = ckpt_epoch
            filename = local_path_from_s3_or_local_path(
                osp.join(ckpt, 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(
                osp.join(ckpt, 'params.pkl'))
        print("Loading ckpt from", filename)
        # data = joblib.load(filename)
        data = torch.load(filename, map_location='cuda:1')
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        def create_qf():
            return ConcatMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **qf_kwargs
            )

        qf1 = create_qf()
        qf2 = create_qf()
        target_qf1 = create_qf()
        target_qf2 = create_qf()
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **policy_kwargs
        )
        expl_policy = policy
        eval_policy = MakeDeterministic(policy)

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            context_key: obs_dict[achieved_goal_key],  # observation_key
        }
        if mask_conditioned:
            if sample_masks_for_relabeling:
                batch_size = obs_dict[list(obs_dict.keys())[0]].shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]
        return context_dict
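
    # Editor's note (interpretation): this is the hindsight-relabeling hook -- achieved
    # goals from stored transitions become the relabeled desired goals, while masks are
    # either re-sampled from the mask distribution or copied from the rollout,
    # depending on sample_masks_for_relabeling.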

    def post_process_mask_fn(obs_dict, context_dict):
        assert mask_conditioned
        pp_context_dict = copy.deepcopy(context_dict)

        assert context_post_process_mode in [
            'prev_subtasks_solved',
            'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed',
            'atomic_to_corresp_cumul',
            None
        ]

        if context_post_process_mode in [
            'prev_subtasks_solved',
            'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed',
            'atomic_to_corresp_cumul'
        ]:
            frac = context_post_process_frac
            cumul_mask_to_indices = context_distrib.get_cumul_mask_to_indices(
                context_dict['mask']
            )
            for k in cumul_mask_to_indices:
                indices = cumul_mask_to_indices[k]
                subset = np.random.choice(len(indices),
                                          int(len(indices) * frac),
                                          replace=False)
                cumul_mask_to_indices[k] = indices[subset]
        else:
            cumul_mask_to_indices = None

        mode = context_post_process_mode
        if mode in [
            'prev_subtasks_solved', 'dilute_prev_subtasks_uniform',
            'dilute_prev_subtasks_fixed'
        ]:
            cumul_masks = list(cumul_mask_to_indices.keys())
            for i in range(1, len(cumul_masks)):
                curr_mask = cumul_masks[i]
                prev_mask = cumul_masks[i - 1]
                prev_obj_indices = np.where(np.array(prev_mask) > 0)[0]
                indices = cumul_mask_to_indices[curr_mask]
                if mode == 'prev_subtasks_solved':
                    # Use np.ix_ so the assignment writes into pp_context_dict in place
                    # (chained fancy indexing, arr[indices][:, cols] = ..., would only
                    # modify a temporary copy).
                    pp_context_dict[context_key][np.ix_(indices, prev_obj_indices)] = \
                        obs_dict[achieved_goal_key][np.ix_(indices, prev_obj_indices)]
                elif mode == 'dilute_prev_subtasks_uniform':
                    pp_context_dict['mask'][np.ix_(indices, prev_obj_indices)] = \
                        np.random.uniform(
                            size=(len(indices), len(prev_obj_indices)))
                elif mode == 'dilute_prev_subtasks_fixed':
                    pp_context_dict['mask'][np.ix_(indices, prev_obj_indices)] = 0.5
            indices_to_relabel = np.concatenate(
                list(cumul_mask_to_indices.values()))
            orig_masks = obs_dict['mask'][indices_to_relabel]
            atomic_mask_to_subindices = context_distrib.get_atomic_mask_to_indices(
                orig_masks)
            atomic_masks = list(atomic_mask_to_subindices.keys())
            cumul_masks = list(cumul_mask_to_indices.keys())
            for i in range(1, len(atomic_masks)):
                orig_atomic_mask = atomic_masks[i]
                relabeled_cumul_mask = cumul_masks[i]
                subindices = atomic_mask_to_subindices[orig_atomic_mask]
                # Index with a single integer array so the assignment writes through
                # to the stored mask array in place.
                pp_context_dict['mask'][indices_to_relabel[subindices]] = \
                    relabeled_cumul_mask

        return pp_context_dict

    # if mask_conditioned:
    #     variant['contextual_replay_buffer_kwargs']['post_process_batch_fn'] = post_process_mask_fn

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None,
                              next_obs_dict=None, new_contexts=None):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                updated_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(updated_contexts)

            if mask_format in ['vector', 'matrix']:
                assert len(mask_keys) == 1
                mask = batch[mask_keys[0]].reshape((len(context), -1))
                batch['observations'] = np.concatenate([obs, context, mask],
                                                       axis=1)
                batch['next_observations'] = np.concatenate(
                    [next_obs, context, mask], axis=1)
            elif mask_format == 'distribution':
                g = context
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
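                # Editor's note (interpretation): the branch below applies the Gaussian
                # conditional mean, mu_{w|g} = mu_w + A (g - mu_g), with mask_mu_mat
                # playing the role of Sigma_wg Sigma_gg^{-1}; the policy then sees this
                # conditional mean plus the flattened inverse covariance instead of the
                # raw goal.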
                if use_g_for_mean:
                    mu_w_given_g = g
                else:
                    mu_w_given_g = mu_w + np.squeeze(
                        mu_A @ np.expand_dims(g - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((len(context), -1))
                batch['observations'] = np.concatenate(
                    [obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate(
                    [next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        else:
            batch['observations'] = np.concatenate([obs, context], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, context],
                                                        axis=1)
        return batch

    if 'observation_keys' not in contextual_replay_buffer_kwargs:
        contextual_replay_buffer_kwargs['observation_keys'] = []
    observation_keys = contextual_replay_buffer_kwargs['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **contextual_replay_buffer_kwargs
    )

    trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs
    )

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs=None,
    ):
        if mask_kwargs is None:
            mask_kwargs = {}
        assert mode in ['expl', 'eval']
        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = rollout_mask_order_for_expl
                elif mode == 'eval':
                    rollout_mask_order = rollout_mask_order_for_eval
                else:
                    raise NotImplementedError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = expl_mask_distr
                elif mode == 'eval':
                    mask_distr = eval_mask_distr
                else:
                    raise NotImplementedError

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(
                    context_distrib if mode == 'expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_groups=mask_groups,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                prev_subtasks_solved=prev_subtasks_solved,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
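    # mask_distr is treated as a distribution over rollout modes; the video and
    # diagnostic collectors below pass dicts such as dict(atomic_seq=1.0),
    # dict(cumul_seq=1.0) or dict(full=1.0) to force a single mode for every
    # evaluation rollout.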

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy,
                                                mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )

    algorithm.to(ptu.device)

    if save_video:
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImagesEnv(state_env, renderers={
                'image_observation': renderer,
            })
            context_env = ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return context_env

        img_eval_env = add_images(eval_env, eval_context_distrib)

        if log_eval_video:
            video_path_collector = create_path_collector(img_eval_env,
                                                         eval_policy,
                                                         mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=renderer_kwargs['width'],
                image_format='CHW',
                save_video_period=save_video_period,
                horizon=max_path_length,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=renderer_kwargs['width'],
                    image_format='HWC',
                    save_video_period=save_video_period,
                    horizon=max_path_length,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        if log_expl_video:
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env,
                                                         expl_policy,
                                                         mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=renderer_kwargs['width'],
                image_format='CHW',
                save_video_period=save_video_period,
                horizon=max_path_length,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    if mask_conditioned and log_mask_diagnostics:
        collectors = []
        log_prefixes = []

        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        for key in eval_rollouts_to_log:
            assert key in default_list

        if 'atomic' in eval_rollouts_to_log:
            num_masks = len(eval_path_collector.mask_groups)
            for mask_id in range(num_masks):
                mask_kwargs = dict(
                    rollout_mask_order=[mask_id],
                    mask_distr=dict(
                        atomic_seq=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy,
                                                  mode='eval',
                                                  mask_kwargs=mask_kwargs)
                collectors.append(collector)
            log_prefixes += [
                'mask_{}/'.format(mask_id)
                for mask_id in range(num_masks)
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs = dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy,
                                              mode='eval',
                                              mask_kwargs=mask_kwargs)
            collectors.append(collector)
            log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(log_prefixes, collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    max_path_length,  # masking_eval_steps,
                    discard_incomplete_paths=True,
                )
                # old_path_info = eval_util.get_generic_path_information(paths)
                old_path_info = eval_env.get_diagnostics(paths)

                keys_to_keep = []
                for key in old_path_info.keys():
                    if ('env_infos' in key) and ('final' in key) and (
                            'Mean' in key):
                        keys_to_keep.append(key)
                path_info = OrderedDict()
                for key in keys_to_keep:
                    path_info[key] = old_path_info[key]

                generic_info = add_prefix(
                    path_info,
                    prefix,
                )
                append_log(log, generic_info)

            for collector in collectors:
                collector.end_epoch(0)
            return log

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)
    algorithm.train()
Beispiel #16
0
def masking_sac_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    env_class=None,
    env_kwargs=None,
    observation_key='state_observation',
    desired_goal_key='state_desired_goal',
    achieved_goal_key='state_achieved_goal',
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    train_env_id=None,
    eval_env_id=None,
    do_masking=True,
    mask_key="masked_observation",
    masking_eval_steps=200,
    log_mask_diagnostics=True,
    mask_dim=None,
    masking_reward_fn=None,
    masking_for_exploration=True,
    rotate_masks_for_eval=False,
    rotate_masks_for_expl=False,
    mask_distribution=None,
    num_steps_per_mask_change=10,
    tag=None,
):
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    context_key = desired_goal_key
    env = get_gym_env(train_env_id, env_class=env_class, env_kwargs=env_kwargs)
    mask_dim = (mask_dim or env.observation_space.spaces[context_key].low.size)

    if not do_masking:
        mask_distribution = 'all_ones'

    assert mask_distribution in [
        'one_hot_masks', 'random_bit_masks', 'all_ones'
    ]
    if mask_distribution == 'all_ones':
        mask_distribution = 'static_mask'

    def contextual_env_distrib_and_reward(
        env_id,
        env_class,
        env_kwargs,
        goal_sampling_mode,
        env_mask_distribution_type,
        static_mask=None,
    ):
        env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        env.goal_sampling_mode = goal_sampling_mode
        use_static_mask = (not do_masking
                           or env_mask_distribution_type == 'static_mask')
        # Default to an all-ones mask when static masks should be used but no
        # static mask was provided.
        if use_static_mask and static_mask is None:
            static_mask = np.ones(mask_dim)
        if not do_masking:
            assert env_mask_distribution_type == 'static_mask'

        goal_distribution = GoalDictDistributionFromMultitaskEnv(
            env,
            desired_goal_keys=[desired_goal_key],
        )
        goal_distribution = MaskedGoalDictDistribution(
            goal_distribution,
            mask_key=mask_key,
            mask_dim=mask_dim,
            distribution_type=env_mask_distribution_type,
            static_mask=static_mask,
        )

        reward_fn = ContextualRewardFnFromMultitaskEnv(
            env=env,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            achieved_goal_from_observation=IndexIntoAchievedGoal(
                observation_key),
        )

        if do_masking:
            if masking_reward_fn:
                reward_fn = partial(masking_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
            else:
                reward_fn = partial(default_masked_reward_fn,
                                    context_key=context_key,
                                    mask_key=mask_key)
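        # Note: when do_masking is True, the unmasked reward above is replaced.
        # masking_reward_fn / default_masked_reward_fn are assumed to score
        # goal-reaching using only the state dimensions selected by the mask in
        # batch[mask_key].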

        state_diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )

        env = ContextualEnv(
            env,
            context_distribution=goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[state_diag_fn],
        )
        return env, goal_distribution, reward_fn

    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        train_env_id,
        env_class,
        env_kwargs,
        exploration_goal_sampling_mode,
        mask_distribution if masking_for_exploration else 'static_mask',
    )
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        eval_env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        'static_mask')

    # Distribution for relabeling
    relabel_context_distrib = GoalDictDistributionFromMultitaskEnv(
        env,
        desired_goal_keys=[desired_goal_key],
    )
    relabel_context_distrib = MaskedGoalDictDistribution(
        relabel_context_distrib,
        mask_key=mask_key,
        mask_dim=mask_dim,
        distribution_type=mask_distribution,
        static_mask=None if do_masking else np.ones(mask_dim),
    )

    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size +
               mask_dim)
    action_dim = expl_env.action_space.low.size

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    def context_from_obs_dict_fn(obs_dict):
        achieved_goal = obs_dict['state_achieved_goal']
        # Should the mask be randomized for future relabeling?
        # batch_size = len(achieved_goal)
        # mask = np.random.choice(2, size=(batch_size, mask_dim))
        mask = obs_dict[mask_key]
        return {
            mask_key: mask,
            context_key: achieved_goal,
        }
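    # HER-style relabeling: the achieved state of the sampled transition
    # becomes the relabeled goal, while the mask that was active during the
    # rollout is kept as-is.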

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch['observations']
        next_obs = batch['next_observations']

        context = batch[context_key]
        mask = batch[mask_key]

        batch['observations'] = np.concatenate([obs, context, mask], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context, mask],
                                                    axis=1)
        return batch
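    # After this post-processing the network inputs are [obs | goal | mask],
    # matching the obs_dim computed above
    # (observation size + goal size + mask_dim).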

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=[context_key, mask_key],
        observation_keys=[observation_key, 'state_achieved_goal'],
        context_distribution=relabel_context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    trainer = SACTrainer(env=expl_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **sac_trainer_kwargs)

    def create_path_collector(env, policy, is_rotating):
        if is_rotating:
            assert do_masking and mask_distribution == 'one_hot_masks'
            return RotatingMaskingPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
                mask_key=mask_key,
                mask_length=mask_dim,
                num_steps_per_mask_change=num_steps_per_mask_change,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[context_key, mask_key],
            )

    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)

    eval_path_collector = create_path_collector(eval_env,
                                                MakeDeterministic(policy),
                                                rotate_masks_for_eval)
    expl_path_collector = create_path_collector(expl_env, exploration_policy,
                                                rotate_masks_for_expl)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = partial(
            rf.contextual_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
            # Eval on everything for the base video
            obs_processor=lambda o: np.hstack(
                (o[observation_key], o[context_key], np.ones(mask_dim))))
        renderer = EnvRenderer(**renderer_kwargs)

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImageEnv(state_env, renderer=renderer)
            return ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=eval_reward,
                observation_key=observation_key,
                # update_env_info_fn=DeleteOldEnvInfo(),
            )

        img_eval_env = add_images(eval_env, eval_context_distrib)
        img_expl_env = add_images(expl_env, expl_context_distrib)
        eval_video_func = get_save_video_function(
            rollout_function,
            img_eval_env,
            MakeDeterministic(policy),
            tag="eval",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)
        expl_video_func = get_save_video_function(
            rollout_function,
            img_expl_env,
            exploration_policy,
            tag="train",
            imsize=renderer.image_shape[0],
            image_format='CWH',
            **save_video_kwargs)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(expl_video_func)

    # For diagnostics, evaluate the policy on each individual dimension of the
    # mask.
    masks = []
    collectors = []
    for mask_idx in range(mask_dim):
        mask = np.zeros(mask_dim)
        mask[mask_idx] = 1
        masks.append(mask)

    for mask in masks:
        for_dim_mask = mask if do_masking else np.ones(mask_dim)
        masked_env, _, _ = contextual_env_distrib_and_reward(
            eval_env_id,
            env_class,
            env_kwargs,
            evaluation_goal_sampling_mode,
            'static_mask',
            static_mask=for_dim_mask)

        collector = ContextualPathCollector(
            masked_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            context_keys_for_policy=[context_key, mask_key],
        )
        collectors.append(collector)
    log_prefixes = [
        'mask_{}/'.format(''.join(mask.astype(int).astype(str)))
        for mask in masks
    ]
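    # Example: with mask_dim == 3 the one-hot masks are [1, 0, 0], [0, 1, 0]
    # and [0, 0, 1], giving log prefixes 'mask_100/', 'mask_010/', 'mask_001/'.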

    def get_mask_diagnostics(unused):
        from rlkit.core.logging import append_log, add_prefix, OrderedDict
        from rlkit.misc import eval_util
        log = OrderedDict()
        for prefix, collector in zip(log_prefixes, collectors):
            paths = collector.collect_new_paths(
                max_path_length,
                masking_eval_steps,
                discard_incomplete_paths=True,
            )
            generic_info = add_prefix(
                eval_util.get_generic_path_information(paths),
                prefix,
            )
            append_log(log, generic_info)

        for collector in collectors:
            collector.end_epoch(0)
        return log

    if log_mask_diagnostics:
        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)

    algorithm.train()
Beispiel #17
0
def awac_rig_experiment(
    max_path_length,
    qf_kwargs,
    trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    train_vae_kwargs,
    policy_class=TanhGaussianPolicy,
    env_id=None,
    env_class=None,
    env_kwargs=None,
    reward_kwargs=None,
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    state_observation_key='state_observation',
    state_goal_key='state_desired_goal',
    image_goal_key='image_desired_goal',
    path_loader_class=MDPPathLoader,
    demo_replay_buffer_kwargs=None,
    path_loader_kwargs=None,
    env_demo_path='',
    env_offpolicy_data_path='',
    debug=False,
    epsilon=1.0,
    exploration_policy_kwargs=None,
    evaluation_goal_sampling_mode=None,
    exploration_goal_sampling_mode=None,
    add_env_demos=False,
    add_env_offpolicy_data=False,
    save_paths=False,
    load_demos=False,
    pretrain_policy=False,
    pretrain_rl=False,
    save_pretrained_algorithm=False,

    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
    imsize=84,
    pretrained_vae_path="",
    presampled_goals_path="",
    init_camera=None,
    qf_class=ConcatMlp,
):

    #Kwarg Definitions
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if demo_replay_buffer_kwargs is None:
        demo_replay_buffer_kwargs = {}
    if path_loader_kwargs is None:
        path_loader_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    if debug:
        max_path_length = 5
        algo_kwargs['batch_size'] = 5
        algo_kwargs['num_epochs'] = 5
        algo_kwargs['num_eval_steps_per_epoch'] = 100
        algo_kwargs['num_expl_steps_per_train_loop'] = 100
        algo_kwargs['num_trains_per_train_loop'] = 10
        algo_kwargs['min_num_steps_before_training'] = 100
        trainer_kwargs['bc_num_pretrain_steps'] = min(
            10, trainer_kwargs.get('bc_num_pretrain_steps', 0))
        trainer_kwargs['q_num_pretrain1_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain1_steps', 0))
        trainer_kwargs['q_num_pretrain2_steps'] = min(
            10, trainer_kwargs.get('q_num_pretrain2_steps', 0))

    #Environment Wrapping
    renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)

    def contextual_env_distrib_and_reward(env_id, env_class, env_kwargs,
                                          goal_sampling_mode,
                                          presampled_goals_path):
        state_env = get_gym_env(env_id,
                                env_class=env_class,
                                env_kwargs=env_kwargs)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        img_env = InsertImageEnv(state_env, renderer=renderer)

        # encoded_env = EncoderWrappedEnv(
        #     img_env,
        #     model,
        #     dict(image_observation="latent_observation", ),
        # )
        # if goal_sampling_mode == "vae_prior":
        #     latent_goal_distribution = PriorDistribution(
        #         model.representation_size,
        #         desired_goal_key,
        #     )
        #     diagnostics = StateImageGoalDiagnosticsFn({}, )
        # elif goal_sampling_mode == "presampled":
        #     diagnostics = state_env.get_contextual_diagnostics
        #     image_goal_distribution = PresampledPathDistribution(
        #         presampled_goals_path,
        #     )

        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        # elif goal_sampling_mode == "reset_of_env":
        #     state_goal_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
        #     state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
        #         state_goal_env,
        #         desired_goal_keys=[state_goal_key],
        #     )
        #     image_goal_distribution = AddImageDistribution(
        #         env=state_env,
        #         base_distribution=state_goal_distribution,
        #         image_goal_key=image_goal_key,
        #         renderer=renderer,
        #     )
        #     latent_goal_distribution = AddLatentDistribution(
        #         image_goal_distribution,
        #         image_goal_key,
        #         desired_goal_key,
        #         model,
        #     )
        #     no_goal_distribution = PriorDistribution(
        #         representation_size=0,
        #         key="no_goal",
        #     )
        #     diagnostics = state_goal_env.get_contextual_diagnostics
        # else:
        #     error
        diagnostics = StateImageGoalDiagnosticsFn({}, )
        no_goal_distribution = PriorDistribution(
            representation_size=0,
            key="no_goal",
        )

        reward_fn = GraspingRewardFn(
            # img_env, # state_env,
            # observation_key=observation_key,
            # desired_goal_key=desired_goal_key,
            # **reward_kwargs
        )

        env = ContextualEnv(
            img_env,  # state_env,
            context_distribution=no_goal_distribution,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diagnostics],
        )
        return env, no_goal_distribution, reward_fn

    #VAE Setup
    if pretrained_vae_path:
        model = load_local_or_remote_file(pretrained_vae_path)
    else:
        model = train_vae(train_vae_kwargs, env_kwargs, env_id, env_class,
                          imsize, init_camera)
    path_loader_kwargs['model_path'] = pretrained_vae_path

    #Environment Definitions
    expl_env, expl_context_distrib, expl_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, exploration_goal_sampling_mode,
        presampled_goals_path)
    eval_env, eval_context_distrib, eval_reward = contextual_env_distrib_and_reward(
        env_id, env_class, env_kwargs, evaluation_goal_sampling_mode,
        presampled_goals_path)
    path_loader_kwargs['env'] = eval_env

    #AWAC Code
    if add_env_demos:
        path_loader_kwargs["demo_paths"].append(env_demo_path)
    if add_env_offpolicy_data:
        path_loader_kwargs["demo_paths"].append(env_offpolicy_data_path)

    #Key Setting
    context_key = desired_goal_key
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[context_key].low.size)
    action_dim = expl_env.action_space.low.size

    state_rewards = (reward_kwargs or {}).get('reward_type', 'dense') == 'wrapped_env'
    # if state_rewards:
    #     mapper = RemapKeyFn({context_key: observation_key, state_goal_key: state_observation_key})
    #     obs_keys = [state_observation_key, observation_key]
    #     cont_keys = [state_goal_key, context_key]
    # else:
    mapper = RemapKeyFn({context_key: observation_key})
    obs_keys = [observation_key]
    cont_keys = [context_key]

    #Replay Buffer
    def concat_context_to_obs(batch, replay_buffer, obs_dict, next_obs_dict,
                              new_contexts):
        obs = batch['observations']
        next_obs = batch['next_observations']
        context = batch[context_key]
        batch['observations'] = np.concatenate([obs, context], axis=1)
        batch['next_observations'] = np.concatenate([next_obs, context],
                                                    axis=1)
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    replay_buffer_kwargs.update(demo_replay_buffer_kwargs)
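    # Note: demo_replay_buffer_kwargs is merged in only here, so those
    # overrides affect the demo buffers below but not the main replay_buffer
    # created above.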
    demo_train_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)
    demo_test_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=cont_keys,
        observation_keys=obs_keys,
        observation_key=observation_key,
        context_distribution=expl_context_distrib,
        sample_context_from_obs_dict_fn=mapper,
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs)

    #Neural Network Architecture
    def create_qf():
        # return ConcatMlp(
        #     input_size=obs_dim + action_dim,
        #     output_size=1,
        #     **qf_kwargs
        # )
        if qf_class is ConcatMlp:
            qf_kwargs["input_size"] = obs_dim + action_dim
        if qf_class is ConcatCNN:
            qf_kwargs["added_fc_input_size"] = action_dim
        return qf_class(output_size=1, **qf_kwargs)
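    # ConcatMlp consumes the flat concatenated observation/context/action
    # vector, so it needs the full input_size; ConcatCNN is assumed to process
    # the image part through its conv stack and take the action through
    # added_fc_input_size instead.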

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs,
    )

    #Path Collectors
    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=[
            context_key,
        ],
    )

    #Algorithm
    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **trainer_kwargs)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)

    algorithm.to(ptu.device)

    #Video Saving
    if save_video:

        expl_video_func = RIGVideoSaveFunction(
            model,
            expl_path_collector,
            "train",
            # decode_goal_image_key="image_decoded_goal",
            # reconstruction_key="image_reconstruction",
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = RIGVideoSaveFunction(
            model,
            eval_path_collector,
            "eval",
            # goal_image_key=image_goal_key,
            # decode_goal_image_key="image_decoded_goal",
            # reconstruction_key="image_reconstruction",
            num_imgs=4,
            rows=2,
            columns=5,
            unnormalize=True,
            imsize=imsize,
            image_format=renderer.output_image_format,
            **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)

    #AWAC CODE
    if save_paths:
        algorithm.post_train_funcs.append(save_paths)

    if load_demos:
        path_loader = path_loader_class(
            trainer,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            # reward_fn=eval_reward, # omit reward because it's recomputed later
            **path_loader_kwargs)
        path_loader.load_demos()
    if pretrain_policy:
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if pretrain_rl:
        trainer.pretrain_q_with_bc_data()

    if save_pretrained_algorithm:
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, pt_path)
        torch.save(data, p_path)

    algorithm.train()
Beispiel #18
0
def experiment(variant):
    render = variant.get("render", False)
    debug = variant.get("debug", False)
    vae_path = variant.get("vae_path", False)

    process_args(variant)

    env_class = variant.get("env_class")
    env_kwargs = variant.get("env_kwargs")
    env_id = variant.get("env_id")
    # expl_env = env_class(**env_kwargs)
    # eval_env = env_class(**env_kwargs)
    expl_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    eval_env = get_gym_env(env_id, env_class=env_class, env_kwargs=env_kwargs)
    env = eval_env

    if variant.get('sparse_reward', False):
        expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward)
        eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward)

    if variant.get("vae_path", False):
        vae = load_local_or_remote_file(vae_path)
        variant['path_loader_kwargs']['model_path'] = vae_path
        renderer = EnvRenderer(**variant.get("renderer_kwargs", {}))
        expl_env = VQVAEWrappedEnv(InsertImageEnv(expl_env, renderer=renderer),
                                   vae,
                                   reward_params=variant.get(
                                       "reward_params", {}),
                                   **variant.get('vae_wrapped_env_kwargs', {}))
        eval_env = VQVAEWrappedEnv(InsertImageEnv(eval_env, renderer=renderer),
                                   vae,
                                   reward_params=variant.get(
                                       "reward_params", {}),
                                   **variant.get('vae_wrapped_env_kwargs', {}))
        env = eval_env
        variant['path_loader_kwargs']['env'] = env

    if variant.get('add_env_demos', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_demo_path"])

    if variant.get('add_env_offpolicy_data', False):
        variant["path_loader_kwargs"]["demo_paths"].append(
            variant["env_offpolicy_data_path"])

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env

    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')

    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = eval_env.action_space.low.size

    if hasattr(expl_env, 'info_sizes'):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()

    replay_buffer_kwargs = dict(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
    replay_buffer_kwargs.update(variant.get('replay_buffer_kwargs', dict()))
    replay_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
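    # ConcatToObsWrapper is assumed to append the listed keys (here
    # 'resampled_goals') to the observations sampled from the wrapped
    # ObsDictRelabelingBuffer, so batches match the obs_dim
    # (observation + desired goal) computed above.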
    replay_buffer_kwargs.update(
        variant.get('demo_replay_buffer_kwargs', dict()))
    demo_train_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )
    demo_test_buffer = ConcatToObsWrapper(
        ObsDictRelabelingBuffer(**replay_buffer_kwargs),
        [
            "resampled_goals",
        ],
    )

    M = variant['layer_size']
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'],
    )

    expl_policy = policy
    exploration_kwargs = variant.get('exploration_kwargs', {})
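    # A minimal exploration_kwargs sketch, assuming only the keys read below
    # ('deterministic_exploration', 'strategy', 'noise'):
    #     exploration_kwargs = dict(strategy='gauss_eps', noise=0.1)
    # Leaving 'strategy' unset keeps the raw (or deterministic) policy.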
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)

        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == 'ou':
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == 'gauss_eps':
            es = GaussianAndEpislonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            raise NotImplementedError(exploration_strategy)

    trainer = AWACTrainer(env=eval_env,
                          policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          **variant['trainer_kwargs'])
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        # The online branch needs its own eval collector (otherwise
        # eval_path_collector below is undefined); mirror the batch branch.
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    else:
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            MakeDeterministic(policy),
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            render=render,
            goal_sampling_mode=variant.get("goal_sampling_mode", None),
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    # if variant.get("save_video", False):
    #     from rlkit.visualization.video import VideoSaveFunction
    #     renderer_kwargs = variant.get("renderer_kwargs", {})
    #     save_video_kwargs = variant.get("save_video_kwargs", {})

    #     def get_video_func(
    #         env,
    #         policy,
    #         tag,
    #     ):
    #         renderer = EnvRenderer(**renderer_kwargs)
    #         state_goal_distribution = GoalDictDistributionFromMultitaskEnv(
    #             env,
    #             desired_goal_keys=[desired_goal_key],
    #         )
    #         image_goal_distribution = AddImageDistribution(
    #             env=env,
    #             base_distribution=state_goal_distribution,
    #             image_goal_key='image_desired_goal',
    #             renderer=renderer,
    #         )
    #         img_env = InsertImageEnv(env, renderer=renderer)
    #         rollout_function = partial(
    #             rf.multitask_rollout,
    #             max_path_length=variant['max_path_length'],
    #             observation_key=observation_key,
    #             desired_goal_key=desired_goal_key,
    #             return_dict_obs=True,
    #         )
    #         reward_fn = ContextualRewardFnFromMultitaskEnv(
    #             env=env,
    #             achieved_goal_from_observation=IndexIntoAchievedGoal(observation_key),
    #             desired_goal_key=desired_goal_key,
    #             achieved_goal_key="state_achieved_goal",
    #         )
    #         contextual_env = ContextualEnv(
    #             img_env,
    #             context_distribution=image_goal_distribution,
    #             reward_fn=reward_fn,
    #             observation_key=observation_key,
    #         )
    #         video_func = get_save_video_function(
    #             rollout_function,
    #             contextual_env,
    #             policy,
    #             tag=tag,
    #             imsize=renderer.width,
    #             image_format='CWH',
    #             **save_video_kwargs
    #         )
    #         return video_func
    # expl_video_func = get_video_func(expl_env, expl_policy, "expl")
    # eval_video_func = get_video_func(eval_env, MakeDeterministic(policy), "eval")
    # algorithm.post_train_funcs.append(eval_video_func)
    # algorithm.post_train_funcs.append(expl_video_func)

    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)

    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(trainer,
                                        replay_buffer=replay_buffer,
                                        demo_train_buffer=demo_train_buffer,
                                        demo_test_buffer=demo_test_buffer,
                                        **path_loader_kwargs)
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc(
            policy,
            demo_train_buffer,
            demo_test_buffer,
            trainer.bc_num_pretrain_steps,
        )
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()

    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, pt_path)
        torch.save(data, p_path)

    algorithm.train()
Beispiel #19
0
def rl_context_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.sac.policies import MakeDeterministic

    preprocess_rl_variant(variant)
    max_path_length = variant['max_path_length']
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal')

    contextual_mdp = variant.get('contextual_mdp', True)
    print("contextual_mdp:", contextual_mdp)

    mask_variant = variant.get('mask_variant', {})
    mask_conditioned = mask_variant.get('mask_conditioned', False)
    print("mask_conditioned:", mask_conditioned)

    if mask_conditioned:
        assert contextual_mdp

    if 'sac' in variant['algorithm'].lower():
        rl_algo = 'sac'
    elif 'td3' in variant['algorithm'].lower():
        rl_algo = 'td3'
    else:
        raise NotImplementedError
    print("RL algorithm:", rl_algo)

    ### load the example dataset, if running checkpoints ###
    if 'ckpt' in variant:
        import os.path as osp
        example_set_variant = variant.get('example_set_variant', dict())
        example_set_variant['use_cache'] = True
        example_set_variant['cache_path'] = osp.join(variant['ckpt'], 'example_dataset.npy')

    if mask_conditioned:
        env = get_envs(variant)
        mask_format = mask_variant['param_variant']['mask_format']
        assert mask_format in ['vector', 'matrix', 'distribution', 'cond_distribution']
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        if mask_format in ['vector']:
            context_dim_for_networks = goal_dim + goal_dim
        elif mask_format in ['matrix', 'distribution', 'cond_distribution']:
            context_dim_for_networks = goal_dim + (goal_dim * goal_dim)
        else:
            raise TypeError
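        # Example: with a 5-dim goal, the 'vector' format yields a 10-dim
        # context for the networks, while the matrix/distribution formats yield
        # 5 + 5 * 5 = 30 dims.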

        if 'ckpt' in variant:
            from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
            import os.path as osp

            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'masks.npy'))
            masks = np.load(filename, allow_pickle=True)[()]
        else:
            masks = get_mask_params(
                env=env,
                example_set_variant=variant['example_set_variant'],
                param_variant=mask_variant['param_variant'],
            )

        mask_keys = list(masks.keys())
        context_keys = [desired_goal_key] + mask_keys
    else:
        context_keys = [desired_goal_key]
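    # When mask_conditioned, masks is assumed to be a dict mapping mask
    # parameter names to arrays (e.g. keys like 'mask_sigma_inv' for the
    # distribution formats); each of those keys becomes part of the context
    # that is stored and relabeled alongside the desired goal.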


    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(mode='eval')

    if mask_conditioned:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + context_dim_for_networks
        )
    elif contextual_mdp:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + env.observation_space.spaces[desired_goal_key].low.size
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size

    action_dim = env.action_space.low.size
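    # The policy and Q-functions consume the flattened [observation, context]
    # vector, so obs_dim above already accounts for the goal (and mask) encoding.
    # The networks are either restored from a checkpoint or constructed fresh.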

    if 'ckpt' in variant and 'ckpt_epoch' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        ckpt_epoch = variant['ckpt_epoch']
        if ckpt_epoch is not None:
            filename = local_path_from_s3_or_local_path(
                osp.join(variant['ckpt'], 'itr_%d.pkl' % ckpt_epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = torch.load(filename)
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        if rl_algo == 'td3':
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = create_exploration_policy(
                env, policy,
                exploration_version=variant['exploration_type'],
                exploration_noise=variant['exploration_noise'],
            )
            eval_policy = policy
        elif rl_algo == 'sac':
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = policy
            eval_policy = MakeDeterministic(policy)

    post_process_mask_fn = partial(
        full_post_process_mask_fn,
        mask_conditioned=mask_conditioned,
        mask_variant=mask_variant,
        context_distrib=context_distrib,
        context_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )
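    # Relabeling function for the replay buffer (HER-style): the achieved goal of
    # a sampled observation is reused as the desired goal, and masks are either
    # resampled from the mask distribution or copied from the observation,
    # depending on `sample_masks_for_relabeling`.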

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            desired_goal_key: obs_dict[achieved_goal_key]
        }

        if mask_conditioned:
            sample_masks_for_relabeling = mask_variant.get('sample_masks_for_relabeling', True)
            if sample_masks_for_relabeling:
                batch_size = next(iter(obs_dict.values())).shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]

        return context_dict
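    # Flattens the context into the observation vectors of a training batch:
    # depending on `mask_format`, the goal plus a mask (vector/matrix), the
    # inverse covariance, or the conditional mean/covariance parameters are
    # appended; for a plain contextual MDP only the goal is appended.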

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None, next_obs_dict=None, new_contexts=None):
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch_size = obs.shape[0]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                if not mask_variant.get('relabel_masks', True):
                    for k in mask_keys:
                        new_contexts[k] = next_obs_dict[k][:]
                    batch.update(new_contexts)
                if not mask_variant.get('relabel_goals', True):
                    new_contexts[desired_goal_key] = next_obs_dict[desired_goal_key][:]
                    batch.update(new_contexts)

                new_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(new_contexts)

            if mask_format in ['vector', 'matrix']:
                goal = batch[desired_goal_key]
                mask = batch['mask'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, mask], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, mask], axis=1)
            elif mask_format == 'distribution':
                goal = batch[desired_goal_key]
                sigma_inv = batch['mask_sigma_inv'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, sigma_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, sigma_inv], axis=1)
            elif mask_format == 'cond_distribution':
                goal = batch[desired_goal_key]
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
                if mask_variant['use_g_for_mean']:
                    mu_w_given_g = goal
                else:
                    mu_w_given_g = mu_w + np.squeeze(mu_A @ np.expand_dims(goal - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        elif contextual_mdp:
            goal = batch[desired_goal_key]
            batch['observations'] = np.concatenate([obs, goal], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, goal], axis=1)
        else:
            batch['observations'] = obs
            batch['next_observations'] = next_obs

        return batch

    if 'observation_keys' not in variant['contextual_replay_buffer_kwargs']:
        variant['contextual_replay_buffer_kwargs']['observation_keys'] = []
    observation_keys = variant['contextual_replay_buffer_kwargs']['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)
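    # The replay buffer relabels contexts with `context_from_obs_dict_fn`,
    # recomputes rewards with `reward_fn`, and flattens each sampled batch with
    # `concat_context_to_obs` before it reaches the trainer.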

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **variant['contextual_replay_buffer_kwargs']
    )

    if rl_algo == 'td3':
        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    elif rl_algo == 'sac':
        trainer = SACTrainer(
            env=env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['sac_trainer_kwargs']
        )
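    # Path collectors: mask-conditioned runs use a MaskPathCollector with the
    # requested mask distribution and rollout order; otherwise a plain
    # ContextualPathCollector is used (with or without goal context).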

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs={},
    ):
        assert mode in ['expl', 'eval']

        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)

        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError

            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError

            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get('max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get('max_subtasks_per_rollout', None)

            # Keep the post-processing mode in its own variable so it does not
            # shadow the 'expl'/'eval' mode argument used to select the mask
            # sampler below.
            context_post_process_mode = mask_variant.get('context_post_process_mode', None)
            if context_post_process_mode in ['dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib if mode == 'expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy, mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    algorithm.to(ptu.device)
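    # Optional video logging: wrap the state-based envs with image renderers and
    # save eval/expl rollout videos (plus mask-specific eval videos when
    # mask-conditioned) every `save_video_period` epochs.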

    if variant.get("save_video", True):
        save_period = variant.get('save_video_period', 50)
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['horizon'] = max_path_length

        renderer = EnvRenderer(**variant.get('renderer_kwargs', {}))

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImagesEnv(state_env, renderers={
                'image_observation': renderer,
            })
            context_env = ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return context_env

        img_eval_env = add_images(eval_env, eval_context_distrib)

        if variant.get('log_eval_video', True):
            video_path_collector = create_path_collector(img_eval_env, eval_policy, mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            default_list = [
                'atomic',
                'atomic_seq',
                'cumul_seq',
                'full',
            ]
            eval_rollouts_for_videos = mask_variant.get('eval_rollouts_for_videos', default_list)
            for key in eval_rollouts_for_videos:
                assert key in default_list

            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_cumul",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        if variant.get('log_expl_video', True) and not variant['algo_kwargs'].get('eval_only', False):
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env, expl_policy, mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)
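    # Extra evaluation collectors for per-mask diagnostics: rollouts are run with
    # atomic, full, cumulative-sequential, and atomic-sequential mask schedules,
    # and the final env-info means are logged under separate prefixes.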

    addl_collectors = []
    addl_log_prefixes = []
    if mask_conditioned and mask_variant.get('log_mask_diagnostics', True):
        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        eval_rollouts_to_log = mask_variant.get('eval_rollouts_to_log', default_list)
        for key in eval_rollouts_to_log:
            assert key in default_list

        # atomic masks
        if 'atomic' in eval_rollouts_to_log:
            for mask_id in eval_path_collector.mask_ids:
                mask_kwargs = dict(
                    mask_ids=[mask_id],
                    mask_distr=dict(
                        atomic=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
                addl_collectors.append(collector)
            addl_log_prefixes += [
                'mask_{}/'.format(mask_id)
                for mask_id in eval_path_collector.mask_ids
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs = dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs = dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(addl_log_prefixes, addl_collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                old_path_info = eval_env.get_diagnostics(paths)

                keys_to_keep = []
                for key in old_path_info.keys():
                    if ('env_infos' in key) and ('final' in key) and ('Mean' in key):
                        keys_to_keep.append(key)
                path_info = OrderedDict()
                for key in keys_to_keep:
                    path_info[key] = old_path_info[key]

                generic_info = add_prefix(
                    path_info,
                    prefix,
                )
                append_log(log, generic_info)

            for collector in addl_collectors:
                collector.end_epoch(0)
            return log

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)
        
    if 'ckpt' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        assert variant['algo_kwargs'].get('eval_only', False)

        def update_networks(algo, epoch):
            if 'ckpt_epoch' in variant:
                return

            if epoch % algo._eval_epoch_freq == 0:
                filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
                print("Loading ckpt from", filename)
                data = torch.load(filename)
                eval_policy = data['evaluation/policy']
                eval_policy.to(ptu.device)
                algo.eval_data_collector._policy = eval_policy
                for collector in addl_collectors:
                    collector._policy = eval_policy

        algorithm.post_train_funcs.insert(0, update_networks)

    algorithm.train()
Beispiel #20
0
def disco_experiment(
    max_path_length,
    qf_kwargs,
    sac_trainer_kwargs,
    replay_buffer_kwargs,
    policy_kwargs,
    algo_kwargs,
    generate_set_for_rl_kwargs,
    # VAE parameters
    create_vae_kwargs,
    vae_trainer_kwargs,
    vae_algo_kwargs,
    data_loader_kwargs,
    generate_set_for_vae_pretraining_kwargs,
    num_ungrouped_images,
    beta_schedule_kwargs=None,
    # Oracle settings
    use_ground_truth_reward=False,
    use_onehot_set_embedding=False,
    use_dummy_model=False,
    observation_key="latent_observation",
    # RIG comparison
    rig_goal_setter_kwargs=None,
    rig=False,
    # Miscellaneous
    reward_fn_kwargs=None,
    # None-VAE Params
    env_id=None,
    env_class=None,
    env_kwargs=None,
    latent_observation_key="latent_observation",
    state_observation_key="state_observation",
    image_observation_key="image_observation",
    set_description_key="set_description",
    example_state_key="example_state",
    example_image_key="example_image",
    # Exploration
    exploration_policy_kwargs=None,
    # Video parameters
    save_video=True,
    save_video_kwargs=None,
    renderer_kwargs=None,
):
    if rig_goal_setter_kwargs is None:
        rig_goal_setter_kwargs = {}
    if reward_fn_kwargs is None:
        reward_fn_kwargs = {}
    if exploration_policy_kwargs is None:
        exploration_policy_kwargs = {}
    if not save_video_kwargs:
        save_video_kwargs = {}
    if not renderer_kwargs:
        renderer_kwargs = {}

    renderer = EnvRenderer(**renderer_kwargs)

    sets = create_sets(
        env_id,
        env_class,
        env_kwargs,
        renderer,
        example_state_key=example_state_key,
        example_image_key=example_image_key,
        **generate_set_for_rl_kwargs,
    )
    if use_dummy_model:
        model = create_dummy_image_vae(img_chw=renderer.image_chw,
                                       **create_vae_kwargs)
    else:
        model = train_set_vae(
            create_vae_kwargs,
            vae_trainer_kwargs,
            vae_algo_kwargs,
            data_loader_kwargs,
            generate_set_for_vae_pretraining_kwargs,
            num_ungrouped_images,
            env_id=env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
            beta_schedule_kwargs=beta_schedule_kwargs,
            sets=sets,
            renderer=renderer,
        )
    expl_env, expl_context_distrib, expl_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
        ))
    eval_env, eval_context_distrib, eval_reward = (
        contextual_env_distrib_and_reward(
            vae=model,
            sets=sets,
            state_env=get_gym_env(
                env_id,
                env_class=env_class,
                env_kwargs=env_kwargs,
            ),
            renderer=renderer,
            reward_fn_kwargs=reward_fn_kwargs,
            use_ground_truth_reward=use_ground_truth_reward,
            state_observation_key=state_observation_key,
            latent_observation_key=latent_observation_key,
            example_image_key=example_image_key,
            set_description_key=set_description_key,
            observation_key=observation_key,
            image_observation_key=image_observation_key,
            rig_goal_setter_kwargs=rig_goal_setter_kwargs,
            oracle_rig_goal=rig,
        ))
    context_keys = [
        expl_context_distrib.mean_key,
        expl_context_distrib.covariance_key,
        expl_context_distrib.set_index_key,
        expl_context_distrib.set_embedding_key,
    ]
    if rig:
        context_keys_for_rl = [
            expl_context_distrib.mean_key,
        ]
    else:
        if use_onehot_set_embedding:
            context_keys_for_rl = [
                expl_context_distrib.set_embedding_key,
            ]
        else:
            context_keys_for_rl = [
                expl_context_distrib.mean_key,
                expl_context_distrib.covariance_key,
            ]

    obs_dim = np.prod(expl_env.observation_space.spaces[observation_key].shape)
    obs_dim += sum([
        np.prod(expl_env.observation_space.spaces[k].shape)
        for k in context_keys_for_rl
    ])
    action_dim = np.prod(expl_env.action_space.shape)
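    # obs_dim already includes the flattened context keys, so the Q-functions and
    # policy below operate on the concatenated [obs, context] vector.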

    def create_qf():
        return ConcatMlp(input_size=obs_dim + action_dim,
                         output_size=1,
                         **qf_kwargs)

    qf1 = create_qf()
    qf2 = create_qf()
    target_qf1 = create_qf()
    target_qf2 = create_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                **policy_kwargs)
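    # Batch post-processing: append the selected context keys (set mean/covariance
    # or one-hot set embedding) to the observations so the flat networks above see
    # [obs, context] at train time, mirroring what the path collectors do at
    # rollout time.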

    def concat_context_to_obs(batch, *args, **kwargs):
        obs = batch["observations"]
        next_obs = batch["next_observations"]
        contexts = [batch[k] for k in context_keys_for_rl]
        batch["observations"] = np.concatenate((obs, *contexts), axis=1)
        batch["next_observations"] = np.concatenate(
            (next_obs, *contexts),
            axis=1,
        )
        return batch

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=eval_env,
        context_keys=context_keys,
        observation_keys=list(
            {observation_key, state_observation_key, latent_observation_key}),
        observation_key=observation_key,
        context_distribution=FilterKeys(
            expl_context_distrib,
            context_keys,
        ),
        sample_context_from_obs_dict_fn=None,
        # RemapKeyFn({context_key: observation_key}),
        reward_fn=eval_reward,
        post_process_batch_fn=concat_context_to_obs,
        **replay_buffer_kwargs,
    )
    trainer = SACTrainer(
        env=expl_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **sac_trainer_kwargs,
    )

    eval_path_collector = ContextualPathCollector(
        eval_env,
        MakeDeterministic(policy),
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )
    exploration_policy = create_exploration_policy(expl_env, policy,
                                                   **exploration_policy_kwargs)
    expl_path_collector = ContextualPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        context_keys_for_policy=context_keys_for_rl,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)
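    # DisCo video logging: saves train/eval rollout videos alongside image
    # reconstructions, decoded set priors, and example-set visualizations.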

    if save_video:
        set_index_key = eval_context_distrib.set_index_key
        expl_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            expl_path_collector,
            tag="train",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(expl_video_func)

        eval_video_func = DisCoVideoSaveFunction(
            model,
            sets,
            eval_path_collector,
            tag="eval",
            reconstruction_key="image_reconstruction",
            decode_set_image_key="decoded_set_prior",
            set_visualization_key="set_visualization",
            example_image_key=example_image_key,
            set_index_key=set_index_key,
            columns=len(sets),
            unnormalize=True,
            imsize=48,
            image_format=renderer.output_image_format,
            **save_video_kwargs,
        )
        algorithm.post_train_funcs.append(eval_video_func)

    algorithm.train()
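
# A hypothetical usage sketch (not part of the original listing): the rough shape
# of a call to disco_experiment(). All values below are illustrative
# placeholders/assumptions, not settings taken from the source experiments.
#
# disco_experiment(
#     max_path_length=100,
#     qf_kwargs=dict(hidden_sizes=[400, 300]),
#     sac_trainer_kwargs=dict(discount=0.99),
#     replay_buffer_kwargs=dict(max_size=int(1e6)),
#     policy_kwargs=dict(hidden_sizes=[400, 300]),
#     algo_kwargs=dict(
#         num_epochs=100,
#         batch_size=128,
#         num_eval_steps_per_epoch=1000,
#         num_expl_steps_per_train_loop=1000,
#         num_trains_per_train_loop=1000,
#         min_num_steps_before_training=1000,
#     ),
#     generate_set_for_rl_kwargs=dict(),                # placeholder
#     create_vae_kwargs=dict(),                         # placeholder
#     vae_trainer_kwargs=dict(),                        # placeholder
#     vae_algo_kwargs=dict(),                           # placeholder
#     data_loader_kwargs=dict(),                        # placeholder
#     generate_set_for_vae_pretraining_kwargs=dict(),   # placeholder
#     num_ungrouped_images=12800,
#     env_id='OneObject-PickAndPlace-BigBall-RandomInit-2D-v1',
#     renderer_kwargs=dict(width=48, height=48, output_image_format='CHW'),
#     save_video=True,
# )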