Example 1
    def __init__(self,
                 p_replace,
                 p_add_non_diverse,
                 goal_buffer_size=1024,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        assert self.collection_mode != 'online-parallel', "not sure what happens to sample_goals"
        self.p_replace = p_replace
        self.p_add_non_diverse = p_add_non_diverse
        self.goal_buffer = OnlineVaeRelabelingBuffer(
            self.vae,
            max_size=goal_buffer_size,
            env=self.replay_buffer.env,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            achieved_goal_key='latent_achieved_goal',
        )
        self.env.goal_sampler = self.sample_goals
Example 2
def tdm_td3_experiment_online_vae(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.state_distance.tdm_networks import TdmQf, TdmPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized

    norm_order = env.norm_order
    # variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
    #     'norm_order'] = norm_order

    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

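    # Relabeling buffer that stores image observations and refreshes their
    # latent encodings as the VAE is trained online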
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
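    # Wire the env, observation/goal keys, and replay buffer into the nested algo kwargs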
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

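    # VAE trainer handed to the online-VAE algorithm below, which keeps
    # training the VAE during RL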
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
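    # Combine TDM-TD3 control with online VAE training in a single algorithm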
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(vae=vae,
                               vae_trainer=t,
                               **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)
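    # Optionally record evaluation rollouts to video via a post-train hook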
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
Example 3
def td3_experiment_online_vae_exploring(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

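    # Separate Q-functions and policy for the VAE-exploration TD3 learner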
    exploring_qf1 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = ConcatMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
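    # Optionally let the env sample goals directly from the replay buffer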
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)

    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(env=env,
                              training_env=env,
                              qf1=exploring_qf1,
                              qf2=exploring_qf2,
                              policy=exploring_policy,
                              exploration_policy=exploring_exploration_policy,
                              **variant['algo_kwargs'])

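    # Joint algorithm that trains the VAE online while running both the
    # control and VAE-exploration TD3 learners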
    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
Example 4
def skewfit_experiment(variant, other_variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
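    # Twin Q-networks with target copies, plus a tanh-Gaussian policy for SAC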
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'], env.vae,
                                 other_variant,
                                 **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
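    # Path collectors with separate goal-sampling modes; evaluation uses a
    # deterministic version of the policy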
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        other_variant=other_variant,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example 5
def skewfit_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer \
        import OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.vae.vae_schedules as vae_schedules

    #### Parameters for training the VAE and RIG
    env = get_envs(variant)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    replay_buffer_kwargs = variant.get(
        'replay_buffer_kwargs',
        dict(
            start_skew_epoch=10,
            max_size=int(100000),
            fraction_goals_rollout_goals=0.2,
            fraction_goals_env_goals=0.5,
            exploration_rewards_type='None',
            vae_priority_type='vae_prob',
            priority_function_kwargs=dict(
                sampling_method='importance_sampling',
                decoder_distribution='gaussian_identity_variance',
                num_latents_to_sample=10,
            ),
            power=0,
            relabeling_goal_sampling_mode='vae_prior',
        ))
    online_vae_trainer_kwargs = variant.get('online_vae_trainer_kwargs',
                                            dict(beta=20, lr=1e-3))
    max_path_length = variant.get('max_path_length', 50)
    algo_kwargs = variant.get(
        'algo_kwargs',
        dict(
            batch_size=1024,
            num_epochs=1000,
            num_eval_steps_per_epoch=500,
            num_expl_steps_per_train_loop=500,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=10000,
            vae_training_schedule=vae_schedules.custom_schedule_2,
            oracle_data=False,
            vae_save_period=50,
            parallel_vae_train=False,
        ))
    twin_sac_trainer_kwargs = variant.get(
        'twin_sac_trainer_kwargs',
        dict(
            discount=0.99,
            reward_scale=1,
            soft_target_tau=1e-3,
            target_update_period=1,  # 1
            use_automatic_entropy_tuning=True,
        ))
    ############################################################################

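    # Build the SAC networks: two Q-functions, their target copies, and a
    # tanh-Gaussian policy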
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=hidden_sizes)
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=hidden_sizes)
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=hidden_sizes)

    vae = variant['vae_model']
    # Create a replay buffer that also provides the data for online VAE training
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs)
    # Create a VAE trainer so the VAE can be updated on the fly
    vae_trainer = ConvVAETrainer(variant['vae_train_data'],
                                 variant['vae_test_data'], vae,
                                 **online_vae_trainer_kwargs)
    # Create a SACTrainer to learn the soft Q-functions and the policy
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **twin_sac_trainer_kwargs)
    trainer = HERTrainer(trainer)
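    # Evaluation rolls out a deterministic version of the policy;
    # exploration uses the stochastic one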
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant.get('evaluation_goal_sampling_mode', 'reset_of_env'),
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant.get('exploration_goal_sampling_mode', 'vae_prior'),
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **algo_kwargs)

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example 6
class DiverseGoals(OnlineVaeHerTwinSac):
    def __init__(self,
                 p_replace,
                 p_add_non_diverse,
                 goal_buffer_size=1024,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        assert self.collection_mode != 'online-parallel', "not sure what happens to sample_goals"
        self.p_replace = p_replace
        self.p_add_non_diverse = p_add_non_diverse
        self.goal_buffer = OnlineVaeRelabelingBuffer(
            self.vae,
            max_size=goal_buffer_size,
            env=self.replay_buffer.env,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            achieved_goal_key='latent_achieved_goal',
        )
        self.env.goal_sampler = self.sample_goals

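    # Keep the goal buffer's latent encodings in sync after the VAE is (re)trained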
    def _post_epoch(self, epoch):
        super()._post_epoch(epoch)
        should_train, amount_to_train = self.vae_training_schedule(epoch)
        rl_start_epoch = int(self.min_num_steps_before_training /
                             self.num_env_steps_per_epoch)
        if should_train or epoch <= (rl_start_epoch - 1):
            self.goal_buffer.refresh_latents(epoch)

    def _handle_path(self, path):
        self.handle_goal_buffer(path)
        super()._handle_path(path)

    def _handle_rollout_ending(self):
        if len(self._current_path_builder) > 0:
            path = self._current_path_builder.get_all_stacked()
            self.handle_goal_buffer(path)
        super()._handle_rollout_ending()

    def handle_goal_buffer(self, path):
        """
        Note that we only care about next_obs for goal relabeling.
        """
        next_observations = path['next_observations']
        for next_obs in next_observations:
            self.handle_goal_buffer_step(
                obs=None,
                action=None,
                rewards=None,
                terminal=None,
                next_observation=next_obs,
            )

    def set_goal_buffer_goal(self, idx, next_obs):
        """
        We only keep track of the 'image_observation' and 'latent_observation' of
        next_observation as goals are sampled based on next_observation.
        """
        self.goal_buffer._next_obs['image_observation'][idx] = \
                unnormalize_image(next_obs['image_observation'])
        self.goal_buffer._next_obs['latent_observation'][idx] = \
                next_obs['latent_observation']

    def sample_goals(self, batch_size):
        if self.goal_buffer._size == 0:
            return None
        goal_idxs = self.goal_buffer._sample_indices(batch_size)
        goals = {
            'latent_desired_goal':
            self.goal_buffer._next_obs['latent_observation'][goal_idxs],
            'image_desired_goal':
            normalize_image(
                self.goal_buffer._next_obs['image_observation'][goal_idxs])
        }
        return goals

    def handle_goal_buffer_step(self, obs, action, rewards, terminal,
                                next_observation):
        if self.goal_buffer._size < self.goal_buffer.max_size:
            self.set_goal_buffer_goal(self.goal_buffer._size, next_observation)
            self.goal_buffer._size += 1
        else:
            """
            Goal buffer is full. With prob self.p_replace, consider as a goal
            candidate.
            """
            if np.random.random() > self.p_replace:
                return
            """
            Sample random goal for goal buffer and replace if sampled goal is a
            closer neighbor of replay buffer
            """
            goal_idx = self.goal_buffer._sample_indices(1)

            candidate_goal = next_observation['latent_observation']
            goal = self.goal_buffer._next_obs['latent_observation'][goal_idx]
            goal_dist = 0.0
            candidate_dist = 0.0
            for i in range(0, self.goal_buffer._size):
                if i == goal_idx:
                    continue
                cur_goal = self.goal_buffer._next_obs['latent_observation'][i]
                candidate_dist += np.linalg.norm(candidate_goal - cur_goal)
                goal_dist += np.linalg.norm(goal - cur_goal)
            """
            Replace the sampled goal with the candidate goal if sampled goal is
            closer or if prob p_add_non_diverse
            """
            if (goal_dist < candidate_dist
                    or np.random.random() > self.p_add_non_diverse):
                self.set_goal_buffer_goal(goal_idx, next_observation)
Example 7
def skewfit_experiment(cfgs):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(cfgs)
    env = get_envs(cfgs)

    # TODO
    uniform_dataset_fn = cfgs.GENERATE_VAE_DATASET.get(
        'uniform_dataset_generator', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **cfgs.GENERATE_VAE_DATASET.generate_uniform_dataset_kwargs)
    else:
        uniform_dataset = None

    observation_key = cfgs.SKEW_FIT.get('observation_key',
                                        'latent_observation')
    desired_goal_key = cfgs.SKEW_FIT.get('desired_goal_key',
                                         'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = cfgs.Q_FUNCTION.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=cfgs.POLICY.get('hidden_sizes', [400, 300]),
    )

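    # The VAE from the VAE-wrapped env is shared by the relabeling buffer,
    # the VAE trainer, and the RL algorithm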
    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        priority_function_kwargs=cfgs.PRIORITY_FUNCTION,
        **cfgs.REPLAY_BUFFER)
    vae_trainer = ConvVAETrainer(
        cfgs.VAE_TRAINER.train_data,
        cfgs.VAE_TRAINER.test_data,
        env.vae,
        beta=cfgs.VAE_TRAINER.beta,
        lr=cfgs.VAE_TRAINER.lr,
    )

    # assert 'vae_training_schedule' not in cfgs, "Just put it in algo_kwargs"
    max_path_length = cfgs.SKEW_FIT.max_path_length
    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **cfgs.TWIN_SAC_TRAINER)
    trainer = HERTrainer(trainer)
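    # Both collectors decode sampled latent goals (decode_goals=True);
    # evaluation uses a deterministic policy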
    eval_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.evaluation_goal_sampling_mode,
        env,
        MakeDeterministic(policy),
        decode_goals=True,  # TODO check this
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.exploration_goal_sampling_mode,
        env,
        policy,
        decode_goals=True,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,  # TODO used in test vae
        max_path_length=max_path_length,
        parallel_vae_train=cfgs.VAE_TRAINER.parallel_train,
        **cfgs.ALGORITHM)

    if cfgs.SKEW_FIT.custom_goal_sampler == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()