  def testEpisodicMemory(self):
    observation_shape = [9]
    memory = episodic_memory.EpisodicMemory(
        observation_shape=observation_shape,
        observation_compare_fn=embedding_similarity,
        capacity=150)

    self.RunTest(memory, observation_shape, add_count=100)
    memory.reset()

    self.RunTest(memory, observation_shape, add_count=200)
    memory.reset()

  def testEpisodicEnvWrapperSimple(self):
    num_envs = 10
    vec_env = HackDummyVecEnv([self.EnvFactory] * num_envs)

    embedding_size = 16
    vec_episodic_memory = [episodic_memory.EpisodicMemory(
        capacity=1000,
        observation_shape=[embedding_size],
        observation_compare_fn=embedding_similarity)
                           for _ in range(num_envs)]

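    # Random projection of flattened 28x28x3 observations into the
    # 16-dimensional embedding space expected by the episodic memory.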
    mat = np.random.normal(size=[28 * 28 * 3, embedding_size])
    observation_embedding = lambda x, m=mat: linear_embedding(m, x)

    target_image_shape = [14, 14, 1]
    env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper(
        vec_env, vec_episodic_memory,
        observation_embedding,
        target_image_shape)

    observations = env_wrapper.reset()
    self.assertAllEqual([num_envs] + target_image_shape,
                        observations.shape)

    dummy_actions = [1] * num_envs
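    # Step the wrapped envs with a fixed action; within an episode, each
    # per-env memory can only grow, and when an episode ends the memory is
    # reset to hold only the first observation of the new episode.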
    for _ in range(100):
      previous_mem_length = [len(mem) for mem in vec_episodic_memory]
      observations, unused_rewards, dones, unused_infos = (
          env_wrapper.step(dummy_actions))
      current_mem_length = [len(mem) for mem in vec_episodic_memory]

      self.assertAllEqual([num_envs] + target_image_shape,
                          observations.shape)
      for k in range(num_envs):
        if dones[k]:
          self.assertEqual(1, current_mem_length[k])
        else:
          self.assertGreaterEqual(current_mem_length[k],
                                  previous_mem_length[k])
def create_environments(env_name,
                        num_envs,
                        r_network_weights_path=None,
                        dmlab_homepath='',
                        action_set='',
                        base_seed=123,
                        scale_task_reward_for_eval=1.0,
                        scale_surrogate_reward_for_eval=0.0,
                        online_r_training=False,
                        environment_engine='atari',
                        r_network_weights_store_path=''):
  """Creates a environments with R-network-based curiosity reward.

  Args:
    env_name: Name of the environment to create (a DMLab level or an Atari
      game, depending on environment_engine).
    num_envs: Number of parallel environments to spawn.
    r_network_weights_path: Path to the weights of the R-network.
    dmlab_homepath: Path to the DMLab MPM. Required when running on borg.
    action_set: One of {'small', 'nofire', ''}. Which action set to use.
    base_seed: Each environment will use base_seed+env_index as seed.
    scale_task_reward_for_eval: Scale of the task reward to be used for
      valid/test environments.
    scale_surrogate_reward_for_eval: Scale of the surrogate reward to be used
      for valid/test environments.
    online_r_training: Whether to enable online training of the R-network.
    environment_engine: One of 'dmlab', 'atari', 'parkour'.
    r_network_weights_store_path: Directory in which to store R-network
      checkpoints generated during online training of the R-network.

  Returns:
    A tuple of training, validation and test environments, each wrapped with
    the curiosity reward.
  """
  # Environments without intrinsic exploration rewards.
  # pylint: disable=g-long-lambda
  create_dmlab_single_env = functools.partial(create_single_env,
                                              dmlab_homepath=dmlab_homepath,
                                              action_set=action_set)

  if environment_engine == 'dmlab':
    create_env_fn = create_dmlab_single_env
    is_atari_environment = False
  elif environment_engine == 'atari':
    create_env_fn = create_single_atari_env
    is_atari_environment = True
  else:
    raise ValueError('Unknown env engine {}'.format(environment_engine))

  # WARNING: Python processes are not really compatible with other google3
  # code, which can lead to deadlocks. See go/g3process. This is why
  # ThreadedVecEnv is provided as an alternative.
  VecEnvClass = (subproc_vec_env.SubprocVecEnv
                 if FLAGS.vec_env_class == 'SubprocVecEnv'
                 else threaded_vec_env.ThreadedVecEnv)

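  # Note: "_i=i" binds the loop index at lambda definition time; without the
  # default argument, every lambda would capture the final value of i because
  # Python closures are late-binding.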
  vec_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i, use_monitor=True,
                                  split='train'))
      for i in range(num_envs)
  ])
  valid_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i, use_monitor=False,
                                  split='valid'))
      for i in range(num_envs)
  ])
  test_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i, use_monitor=False,
                                  split='test'))
      for i in range(num_envs)
  ])
  # pylint: enable=g-long-lambda

  # Size of states when stored in the memory.
  embedding_size = models.EMBEDDING_DIM

  if not r_network_weights_path:
    # An empty string is equivalent to no R-network checkpoint.
    r_network_weights_path = None
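  # The R-network provides the observation embedding and the reachability-based
  # similarity used by the episodic memories and the curiosity wrapper below.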
  r_net = r_network.RNetwork(
      (84, 84, 4) if is_atari_environment else Const.OBSERVATION_SHAPE,
      r_network_weights_path)

  # The R-network is trained only when online training is enabled.
  r_network_trainer = None
  if online_r_training:
    r_network_trainer = r_network_training.RNetworkTrainer(
        r_net._r_network,  # pylint: disable=protected-access
        checkpoint_dir=r_network_weights_store_path)

  # Create one episodic memory per environment.
  vec_episodic_memory = [
      episodic_memory.EpisodicMemory(
          observation_shape=[embedding_size],
          observation_compare_fn=r_net.embedding_similarity)
      for _ in range(num_envs)
  ]

  # Observations are resized to 84x84 to make training faster.
  # Note: using color images with DMLab makes it much easier to train a policy,
  # so there is no conversion to grayscale for DMLab.
  target_image_shape = [84, 84, 4 if is_atari_environment else 3]
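  # Training environments get the curiosity reward; if online R-network
  # training is enabled, the trainer observes the wrapper's transitions.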
  env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper(
      vec_env, vec_episodic_memory, r_net.embed_observation, target_image_shape)
  if r_network_trainer is not None:
    env_wrapper.add_observer(r_network_trainer)

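  # Validation/test environments: reward scales come from the *_for_eval
  # arguments, and the surrogate reward uses an oracle (or is disabled for
  # Atari/Parkour) instead of the learned R-network.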
  valid_env_wrapper, test_env_wrapper = (
      curiosity_env_wrapper.CuriosityEnvWrapper(
          env, vec_episodic_memory, r_net.embed_observation,
          target_image_shape,
          exploration_reward=('none' if (is_atari_environment or
                                         environment_engine == 'parkour')
                              else 'oracle'),
          scale_task_reward=scale_task_reward_for_eval,
          scale_surrogate_reward=scale_surrogate_reward_for_eval)
      for env in [valid_env, test_env])

  return env_wrapper, valid_env_wrapper, test_env_wrapper
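
# Illustrative usage sketch (the level name, number of environments and
# checkpoint path below are assumptions, not values defined in this module):
#
#   train_env, valid_env, test_env = create_environments(
#       env_name='explore_goal_locations_small',
#       num_envs=8,
#       r_network_weights_path='/path/to/r_network_checkpoint',
#       environment_engine='dmlab')
#   observations = train_env.reset()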