Example 1
    def setUp(self):
        super(NDHEnvTest, self).setUp()
        self.data_dir = FLAGS.test_srcdir + 'valan/r2r/testdata'

        self.reward_fn_type = 'distance_to_goal'
        self._env_config = hparam.HParams(
            problem='NDH',
            history='all',
            path_type='trusted_path',
            max_goal_room_panos=4,
            base_path=self.data_dir,
            problem_path=os.path.join(self.data_dir, 'NDH'),
            vocab_file='vocab.txt',
            images_per_pano=36,
            max_conns=14,
            image_encoding_dim=2052,
            image_features_dir=os.path.join(self.data_dir, 'image_features'),
            instruction_len=50,
            max_agent_actions=6,
            reward_fn_type=self.reward_fn_type,
            reward_fn=env_config.RewardFunction.get_reward_fn(
                self.reward_fn_type))
        self._runtime_config = common.RuntimeConfig(task_id=0, num_tasks=1)

        self._env = env_ndh.NDHEnv(data_sources=['small_split'],
                                   runtime_config=self._runtime_config,
                                   env_config=self._env_config)

        # Seed NumPy's RNG for deterministic behavior in the test.
        np.random.seed(0)
Example 2
    def get_environment(self):
        """Lazily constructs the NDH environment on first use and caches it."""
        if not self._env:
            assert self._data_sources, 'data_sources must be non-empty.'
            self._env = env_ndh.NDHEnv(
                data_sources=self._data_sources,
                runtime_config=self._runtime_config,
                env_config=env_ndh_config.get_ndh_env_config())
        return self._env
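
Example 2 relies on a lazy-initialization pattern: the environment is built only on first access and cached on the instance. Below is a minimal, self-contained sketch of that pattern with hypothetical names (it does not use the VALAN API; the real constructor is env_ndh.NDHEnv as shown above):

class LazyEnvHolder:
    """Hypothetical holder that builds an expensive object only when first needed."""

    def __init__(self, data_sources):
        self._data_sources = data_sources
        self._env = None  # Built lazily in get_environment().

    def _build_env(self):
        # Stand-in for an expensive constructor such as env_ndh.NDHEnv(...).
        return {'data_sources': list(self._data_sources)}

    def get_environment(self):
        if self._env is None:
            assert self._data_sources, 'data_sources must be non-empty.'
            self._env = self._build_env()
        return self._env


holder = LazyEnvHolder(data_sources=['small_split'])
env = holder.get_environment()            # Built here, on the first call.
assert env is holder.get_environment()    # Later calls return the cached object.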
Example 3
    def testStepToGoalRoom(self):
        self.reward_fn_type = 'distance_to_room'
        self._env_config = hparam.HParams(
            problem='NDH',
            history='all',
            path_type='trusted_path',
            max_goal_room_panos=4,
            base_path=self.data_dir,
            problem_path=os.path.join(self.data_dir, 'NDH'),
            vocab_file='vocab.txt',
            images_per_pano=36,
            max_conns=14,
            image_encoding_dim=2052,
            image_features_dir=os.path.join(self.data_dir, 'image_features'),
            instruction_len=50,
            max_agent_actions=6,
            reward_fn_type=self.reward_fn_type,
            reward_fn=env_config.RewardFunction.get_reward_fn(
                self.reward_fn_type))
        self._runtime_config = common.RuntimeConfig(task_id=0, num_tasks=1)

        self._env = env_ndh.NDHEnv(data_sources=['small_split'],
                                   runtime_config=self._runtime_config,
                                   env_config=self._env_config)

        scan_id = 0  # The test data contains only a single scan, 'gZ6f7yhEvPG'.
        _ = self._env.reset()
        golden_path = [
            'ba27da20782d4e1a825f0a133ad84da9',
            '47d8a8282c1c4a7fb3eeeacc45e9d959',  # in the goal room
            '0ee20663dfa34b438d48750ddcd7366c'  # in the goal room
        ]

        # Step through the trajectory and verify the env_output.
        for i, action in enumerate(
            [self._get_pano_id(p, scan_id) for p in golden_path]):
            expected_time_step = i + 1
            expected_heading, expected_pitch = self._env._get_heading_pitch(
                action, scan_id, expected_time_step)
            if i + 1 < len(golden_path):
                expected_oracle_action = self._get_pano_id(
                    golden_path[i + 1], scan_id)
            else:
                expected_oracle_action = constants.STOP_NODE_ID
            expected_reward = 1 if i <= 1 else 0
            env_test.verify_env_output(
                self,
                self._env.step(action),
                expected_reward=expected_reward,  # +1 while approaching the goal room, 0 once inside it.
                expected_done=False,
                expected_info='',
                expected_time_step=expected_time_step,
                expected_path_id=318,
                expected_pano_name=golden_path[i],
                expected_heading=expected_heading,
                expected_pitch=expected_pitch,
                expected_scan_id=scan_id,
                expected_oracle_action=expected_oracle_action)

        # Stop at the goal pano. Terminating the episode resets the observation
        # to the initial observation of the next episode.
        env_test.verify_env_output(
            self,
            self._env.step(constants.STOP_NODE_ID),
            expected_reward=4,  # reached goal and stopped
            expected_done=True,  # end of episode
            expected_info='',
            # The fields below describe the next episode's initial observation.
            expected_time_step=0,
            expected_path_id=1304,
            expected_pano_name='80929af5cf234ae38ac3a2a4e60e4342',
            expected_heading=6.101,
            expected_pitch=0.,
            expected_scan_id=scan_id,
            expected_oracle_action=self._get_pano_id(
                'ba27da20782d4e1a825f0a133ad84da9', scan_id))
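
Example 3 follows a common testing pattern for sequential environments: step along a known "golden" trajectory, check each transition against expected values, then verify that the stop action terminates the episode. The following is a minimal, self-contained sketch of that pattern using a hypothetical toy environment (not the VALAN test helpers):

import unittest


class CountingEnv:
    """Hypothetical toy environment: the state counts up; action 0 terminates."""

    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        if action == 0:
            return self.state, 0.0, True   # (observation, reward, done)
        self.state += 1
        return self.state, 1.0, False


class GoldenPathTest(unittest.TestCase):

    def test_golden_path(self):
        env = CountingEnv()
        env.reset()
        golden_actions = [1, 1, 1]  # Known-good trajectory.
        for i, action in enumerate(golden_actions):
            obs, reward, done = env.step(action)
            self.assertEqual(obs, i + 1)    # Expected observation at each step.
            self.assertEqual(reward, 1.0)   # Reward for following the path.
            self.assertFalse(done)
        _, _, done = env.step(0)            # The stop action ends the episode.
        self.assertTrue(done)


if __name__ == '__main__':
    unittest.main()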