def testEpisodicEnvWrapperSimple(self):
  num_envs = 10
  vec_env = HackDummyVecEnv([self.EnvFactory] * num_envs)

  embedding_size = 16
  vec_episodic_memory = [episodic_memory.EpisodicMemory(
      capacity=1000,
      observation_shape=[embedding_size],
      observation_compare_fn=embedding_similarity)
                         for _ in range(num_envs)]

  # Random linear projection of flattened observations into embedding space.
  mat = np.random.normal(size=[28 * 28 * 3, embedding_size])
  observation_embedding = lambda x, m=mat: linear_embedding(m, x)

  target_image_shape = [14, 14, 1]
  env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper(
      vec_env, vec_episodic_memory, observation_embedding,
      target_image_shape)

  observations = env_wrapper.reset()
  self.assertAllEqual([num_envs] + target_image_shape, observations.shape)

  dummy_actions = [1] * num_envs
  for _ in range(100):
    previous_mem_length = [len(mem) for mem in vec_episodic_memory]
    observations, unused_rewards, dones, unused_infos = (
        env_wrapper.step(dummy_actions))
    current_mem_length = [len(mem) for mem in vec_episodic_memory]

    self.assertAllEqual([num_envs] + target_image_shape, observations.shape)
    for k in range(num_envs):
      if dones[k]:
        # When an episode ends, the memory is reset and holds only the first
        # observation of the new episode.
        self.assertEqual(1, current_mem_length[k])
      else:
        # Otherwise the memory can only grow (or stay the same size).
        self.assertGreaterEqual(current_mem_length[k],
                                previous_mem_length[k])
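
# A minimal sketch of the module-level helpers the test above assumes but
# which are not shown in this excerpt. The names `linear_embedding` and
# `embedding_similarity` come from the test; their bodies, signatures and the
# exact similarity formula below are assumptions for illustration only.
def linear_embedding(m, x):
  # Flatten each observation in the batch and project it through the fixed
  # random matrix `m` (shape [28 * 28 * 3, embedding_size]).
  return np.matmul(np.reshape(x, [x.shape[0], -1]), m)


def embedding_similarity(x, y):
  # Cosine similarity rescaled to [0, 1]; 1.0 means identical directions.
  eps = 1e-6
  s = np.sum(x * y, axis=-1)
  s /= (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1) + eps)
  return 0.5 * (s + 1.0)
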
def create_environments(env_name,
                        num_envs,
                        r_network_weights_path=None,
                        dmlab_homepath='',
                        action_set='',
                        base_seed=123,
                        scale_task_reward_for_eval=1.0,
                        scale_surrogate_reward_for_eval=0.0,
                        online_r_training=False,
                        environment_engine='atari',
                        r_network_weights_store_path=''):
  """Creates environments with an R-network-based curiosity reward.

  Args:
    env_name: Name of the DMLab environment.
    num_envs: Number of parallel environments to spawn.
    r_network_weights_path: Path to the weights of the R-network.
    dmlab_homepath: Path to the DMLab MPM. Required when running on borg.
    action_set: One of {'small', 'nofire', ''}. Which action set to use.
    base_seed: Each environment will use base_seed+env_index as its seed.
    scale_task_reward_for_eval: Scale of the task reward to be used for
      valid/test environments.
    scale_surrogate_reward_for_eval: Scale of the surrogate reward to be used
      for valid/test environments.
    online_r_training: Whether to enable online training of the R-network.
    environment_engine: Either 'dmlab', 'atari' or 'parkour'.
    r_network_weights_store_path: Directory where to store R-network
      checkpoints generated during online training of the R-network.

  Returns:
    Wrapped train, valid and test environments with curiosity.
  """
  # Environments without intrinsic exploration rewards.
  # pylint: disable=g-long-lambda
  create_dmlab_single_env = functools.partial(create_single_env,
                                              dmlab_homepath=dmlab_homepath,
                                              action_set=action_set)

  if environment_engine == 'dmlab':
    create_env_fn = create_dmlab_single_env
    is_atari_environment = False
  elif environment_engine == 'atari':
    create_env_fn = create_single_atari_env
    is_atari_environment = True
  else:
    raise ValueError('Unknown env engine {}'.format(environment_engine))

  # WARNING: python processes are not really compatible with other google3
  # code, which can lead to deadlock. See go/g3process. This is why you can
  # use ThreadedVecEnv.
  VecEnvClass = (subproc_vec_env.SubprocVecEnv
                 if FLAGS.vec_env_class == 'SubprocVecEnv'
                 else threaded_vec_env.ThreadedVecEnv)

  vec_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i,
                                  use_monitor=True, split='train'))
      for i in range(num_envs)
  ])
  valid_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i,
                                  use_monitor=False, split='valid'))
      for i in range(num_envs)
  ])
  test_env = VecEnvClass([
      (lambda _i=i: create_env_fn(env_name, base_seed + _i,
                                  use_monitor=False, split='test'))
      for i in range(num_envs)
  ])
  # pylint: enable=g-long-lambda

  # Size of states when stored in the memory.
  embedding_size = models.EMBEDDING_DIM

  # An empty string is equivalent to no R-network checkpoint.
  if not r_network_weights_path:
    r_network_weights_path = None

  r_net = r_network.RNetwork(
      (84, 84, 4) if is_atari_environment else Const.OBSERVATION_SHAPE,
      r_network_weights_path)

  # Only for online training do we need to train the R-network.
  r_network_trainer = None
  if online_r_training:
    r_network_trainer = r_network_training.RNetworkTrainer(
        r_net._r_network,  # pylint: disable=protected-access
        checkpoint_dir=r_network_weights_store_path)

  # Creates the episodic memory that is attached to each of those envs.
  vec_episodic_memory = [
      episodic_memory.EpisodicMemory(
          observation_shape=[embedding_size],
          observation_compare_fn=r_net.embedding_similarity)
      for _ in range(num_envs)
  ]

  # Images are resized to a smaller target shape to make training faster.
  # Note: using color images with DMLab makes it much easier to train a
  # policy, so there is no conversion to grayscale.
  target_image_shape = [84, 84, 4 if is_atari_environment else 3]
  env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper(
      vec_env, vec_episodic_memory, r_net.embed_observation,
      target_image_shape)
  if r_network_trainer is not None:
    env_wrapper.add_observer(r_network_trainer)

  valid_env_wrapper, test_env_wrapper = (
      curiosity_env_wrapper.CuriosityEnvWrapper(
          env, vec_episodic_memory, r_net.embed_observation,
          target_image_shape,
          exploration_reward=('none' if (is_atari_environment or
                                         environment_engine == 'parkour')
                              else 'oracle'),
          scale_task_reward=scale_task_reward_for_eval,
          scale_surrogate_reward=scale_surrogate_reward_for_eval)
      for env in [valid_env, test_env])

  return env_wrapper, valid_env_wrapper, test_env_wrapper
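
# A minimal usage sketch (not part of the original module) showing how the
# three wrappers returned above might be consumed. The environment name and
# the checkpoint path are placeholders, and calling this directly assumes the
# absl flags used above (e.g. FLAGS.vec_env_class) have already been parsed.
def example_usage():
  train_env, valid_env, test_env = create_environments(
      env_name='BreakoutNoFrameskip-v4',  # Placeholder Atari level name.
      num_envs=8,
      r_network_weights_path='/path/to/r_network_weights.h5',  # Placeholder.
      environment_engine='atari')
  observations = train_env.reset()
  # Observations come back already resized to target_image_shape,
  # i.e. [num_envs, 84, 84, 4] for Atari.
  assert list(observations.shape) == [8, 84, 84, 4]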