Example #1
    # Assumes `agent` and `agent_config` from valan.r2r are imported by the
    # enclosing module.
    def __init__(self, runtime_config, mode, data_sources):
        self._runtime_config = runtime_config
        self._mode = mode
        self._data_sources = data_sources

        # R2R agent built from the NDH (dialog history) agent config.
        self._agent = agent.R2RAgent(agent_config.get_ndh_agent_config())
        self._prob_ac = 0.5
        self._env = None
        self._loss_type = None
        self._eval_dict = self._get_eval_dict()
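For orientation, here is a minimal usage sketch of a constructor like the one above. The owning class name `NDHLearner` is hypothetical, and the argument values are illustrative; `common.RuntimeConfig` and the `'R2R_small_split'` data source are borrowed from Example #2 below.

from valan.framework import common

# Hypothetical: `NDHLearner` stands in for whatever class defines the
# __init__ above; the argument values are illustrative only.
learner = NDHLearner(
    runtime_config=common.RuntimeConfig(task_id=0, num_tasks=1),
    mode='train',
    data_sources=['R2R_small_split'])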
Example #2
# Imports and test-class wrapper assumed from the VALAN repository layout.
import os

from absl import flags
import tensorflow as tf

from valan.framework import common
from valan.framework import hparam
from valan.framework import utils
from valan.r2r import agent
from valan.r2r import agent_config
from valan.r2r import env
from valan.r2r import env_config

FLAGS = flags.FLAGS


class AgentTest(tf.test.TestCase):

  def test_call_ndh(self):
    self._agent = agent.R2RAgent(agent_config.get_ndh_agent_config())
    # os.path.join avoids depending on a trailing slash in test_srcdir.
    self.data_dir = os.path.join(FLAGS.test_srcdir, 'valan/r2r/testdata')

    self._env_config = hparam.HParams(
        problem='NDH',
        history='all',
        path_type='trusted_path',
        max_goal_room_panos=4,
        scan_base_dir=self.data_dir,
        data_base_dir=self.data_dir,
        vocab_dir=self.data_dir,
        problem_path=os.path.join(self.data_dir, 'NDH'),
        vocab_file='vocab.txt',
        images_per_pano=36,
        max_conns=14,
        image_encoding_dim=64,
        direction_encoding_dim=256,
        image_features_dir=os.path.join(self.data_dir, 'image_features'),
        instruction_len=50,
        max_agent_actions=6,
        reward_fn=env_config.RewardFunction.get_reward_fn('distance_to_goal'))

    self._runtime_config = common.RuntimeConfig(task_id=0, num_tasks=1)
    self._env = env.R2REnv(
        data_sources=['R2R_small_split'],
        runtime_config=self._runtime_config,
        env_config=self._env_config)

    env_output = self._env.reset()
    observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                        env_output.observation)
    initial_agent_state = self._agent.get_initial_state(
        observation, batch_size=1)
    # The agent always expects leading [time, batch] dimensions: add them
    # before the call, then strip them from the outputs afterwards.
    env_output = utils.add_time_batch_dim(env_output)
    agent_output, _ = self._agent(env_output, initial_agent_state)

    # The trailing 14 matches max_conns in the env config: one logit per
    # candidate connection.
    self.assertEqual(agent_output.policy_logits.shape, [1, 1, 14])
    self.assertEqual(agent_output.baseline.shape, [1, 1])

    # The rest of the test assumes `self.batch_size`, `self.time_step` and
    # `self._test_environment` are defined in the test's setUp (not shown in
    # this excerpt).
    initial_agent_state = ([
        # Random state: a pair of 512-dim tensors for each of two recurrent
        # layers, plus an auxiliary [batch, 5, 512] tensor.
        (tf.random.normal([self.batch_size, 512]),
         tf.random.normal([self.batch_size, 512])),
        (tf.random.normal([self.batch_size, 512]),
         tf.random.normal([self.batch_size, 512]))
    ], tf.random.normal([self.batch_size, 5, 512]))
    agent_output, _ = self._agent(self._test_environment, initial_agent_state)

    self.assertEqual(agent_output.policy_logits.shape,
                     [self.time_step, self.batch_size, 14])
    self.assertEqual(agent_output.baseline.shape,
                     [self.time_step, self.batch_size])
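The agent consumes and produces tensors with leading [time, batch] axes. Below is a minimal sketch of what nest-mapping helpers like `utils.add_time_batch_dim` used above (and the matching removal step the test's comment alludes to) could look like, assuming singleton axes; this is an illustration, not the actual VALAN implementation.

import tensorflow as tf

def add_time_batch_dim(nested):
  # Prepend singleton time and batch axes: [d0, ...] -> [1, 1, d0, ...].
  return tf.nest.map_structure(
      lambda t: tf.expand_dims(tf.expand_dims(t, 0), 0), nested)

def remove_time_batch_dim(nested):
  # Drop the singleton time and batch axes again: [1, 1, d0, ...] -> [d0, ...].
  return tf.nest.map_structure(lambda t: tf.squeeze(t, axis=[0, 1]), nested)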
Example #3
    # Assumes `agent`, `agent_config`, `mt_agent` and `mt_agent_config` from
    # valan are imported by the enclosing module.
    def __init__(self, runtime_config, mode, data_sources, agent_type='r2r'):
        self._runtime_config = runtime_config
        self._mode = mode
        self._data_sources = data_sources

        if agent_type.lower() == 'r2r':
            # R2R agent driven by the NDH agent config.
            self._agent = agent.R2RAgent(agent_config.get_ndh_agent_config(),
                                         mode=mode)
        elif agent_type.lower() == 'mt':
            # Multi-task agent variant.
            self._agent = mt_agent.MTEnvAgAgent(
                mt_agent_config.get_agent_config(), mode=mode)
        else:
            raise ValueError('Invalid agent_type: {}'.format(agent_type))

        self._prob_ac = 0.5
        self._env = None
        self._loss_type = None
        self._eval_dict = self._get_eval_dict()
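A short usage sketch of the agent_type dispatch above; as before, the owning class name `Learner` is hypothetical and the argument values are illustrative.

from valan.framework import common

# Hypothetical: `Learner` stands in for the class defining the __init__
# above. Passing anything other than 'r2r' or 'mt' raises ValueError.
learner = Learner(
    runtime_config=common.RuntimeConfig(task_id=0, num_tasks=1),
    mode='train',
    data_sources=['R2R_small_split'],
    agent_type='mt')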