Example 1
0
    def testCurriculumSorting(self):
        """Checks that every supported `sort_key` option is effective.

        For each sort key, builds a `CurriculumR2REnv` over `small_split` with
        that key set on a 'constant-100-1' curriculum config. It then checks the
        environment's ordering with the matching comparator from `name2cmp`.
        """
        # Each sort key is paired with the `name2cmp` entry whose comparator
        # verifies it. Keeping each pair in one tuple (instead of two parallel
        # lists joined by `zip`, which silently truncates) means a key cannot
        # be dropped from coverage by extending only one of the lists.
        key_cases = (
            ('path_len', curriculum_env_config_lib.KEY_PATHLEN),
            ('inst_len', curriculum_env_config_lib.KEY_INSTLEN),
            ('path_len,inst_len',
             curriculum_env_config_lib.KEY_PATHLEN_INSTLEN),
            ('inst_len,path_len',
             curriculum_env_config_lib.KEY_INSTLEN_PATHLEN),
        )
        for key_name, sort_key in key_cases:
            # subTest labels any failure with the sort key that produced it.
            with self.subTest(key_name=key_name):
                config = (
                    curriculum_env_config_lib.get_default_curriculum_env_config(
                        'constant-100-1', self._env_config))
                config.sort_key = sort_key

                self._env = curriculum_env.CurriculumR2REnv(
                    data_sources=['small_split'],
                    runtime_config=self._runtime_config,
                    curriculum_env_config=config)
                # The comparator takes two inputs and determines their
                # relative order.
                self._check_sorted(name2cmp[key_name])
Example 2
0
 def get_environment(self):
     """Returns the environment, building it lazily on first use.

     The result is cached in `self._env`; a new environment is constructed
     only while that attribute is falsy. If `self._curriculum` is truthy, a
     `curriculum_env.CurriculumR2REnv` is built with the default curriculum
     config for that curriculum. Otherwise an `env.R2REnv` is built with the
     default env config.

     Returns:
       The (possibly newly constructed) environment stored in `self._env`.

     Raises:
       AssertionError: if a new environment must be built and
         `self._data_sources` is empty (only when assertions are enabled).
     """
     if self._env:
         return self._env

     assert self._data_sources, 'data_sources must be non-empty.'
     if not self._curriculum:
         self._env = env.R2REnv(
             data_sources=self._data_sources,
             runtime_config=self._runtime_config,
             env_config=env_config_lib.get_default_env_config())
     else:
         # See actor_main.py and curriculum_env.py for the argument options.
         curriculum_config = (
             curriculum_env_config_lib.get_default_curriculum_env_config(
                 self._curriculum))
         self._env = curriculum_env.CurriculumR2REnv(
             data_sources=self._data_sources,
             runtime_config=self._runtime_config,
             curriculum_env_config=curriculum_config)
     return self._env
Example 3
0
 def testCurriculumConstantIncrement(self):
     """Checks path growth under a 'constant-1-1.5' curriculum.

     The environment starts with one path. After the i-th of the first three
     resets it must hold int(1 + 1.5 * i) paths. For the following 100 resets
     it must hold exactly 6. Path order and the 'path_len,inst_len' sorting
     are verified after every reset.
     """

     def reset_and_check(expected_len):
         # Advance the curriculum one step and verify all invariants.
         self._env.reset()
         self._check_paths_order()
         self._check_sorted(name2cmp['path_len,inst_len'])
         self.assertLen(self._env._paths, expected_len)

     config = curriculum_env_config_lib.get_default_curriculum_env_config(
         'constant-1-1.5', self._env_config)
     self._env = curriculum_env.CurriculumR2REnv(
         data_sources=['small_split'],
         runtime_config=self._runtime_config,
         curriculum_env_config=config)
     self.assertLen(self._env._paths, 1)

     # Growth phase: the path count follows 1 + 1.5 * step, truncated.
     for step in range(1, 4):
         reset_and_check(int(1 + 1.5 * step))
     # Saturated phase: the count stays at 6.
     for _ in range(100):
         reset_and_check(6)
Example 4
0
    def testCurriculumAdaptiveIncrement(self):
        """Checks path growth under an 'adaptive-1-4' curriculum.

        The environment starts with one path, and its increment must equal
        (6 - 1) / 4. After the i-th of the first three resets it must hold
        int(1 + increment * i) paths. For the following 100 resets it must
        hold exactly 6. Path order and the 'path_len,inst_len' sorting are
        verified after every reset.
        """
        config = curriculum_env_config_lib.get_default_curriculum_env_config(
            'adaptive-1-4', self._env_config)
        self._env = curriculum_env.CurriculumR2REnv(
            data_sources=['R2R_small_split'],
            runtime_config=self._runtime_config,
            curriculum_env_config=config)
        self.assertLen(self._env._paths, 1)

        # (6 - 1) / 4 == 1.25 exactly, so the same value serves both the
        # increment assertion and the expected growth below.
        expected_increment = (6. - 1.) / 4
        self.assertEqual(self._env._increment, expected_increment)

        def reset_and_check(expected_len):
            # Advance the curriculum one step and verify all invariants.
            self._env.reset()
            self._check_paths_order()
            self._check_sorted(name2cmp['path_len,inst_len'])
            self.assertLen(self._env._paths, expected_len)

        # Growth phase: the path count follows 1 + increment * step, truncated.
        for step in range(1, 4):
            reset_and_check(int(1 + expected_increment * step))
        # Saturated phase: the count stays at 6.
        for _ in range(100):
            reset_and_check(6)
Example 5
0
    def testCurriculumBuilding(self):
        """Checks path growth under a 'constant-1-1' curriculum.

        The environment starts with a single path. Each of the next five resets
        adds one more path (2 through 6). For the following 100 resets the
        count must stay at 6. Path order and the 'path_len,inst_len' sorting
        are verified after every reset.
        """

        def reset_and_check(expected_len):
            # Advance the curriculum one step and verify all invariants.
            self._env.reset()
            self._check_paths_order()
            self._check_sorted(name2cmp['path_len,inst_len'])
            self.assertLen(self._env._paths, expected_len)

        config = curriculum_env_config_lib.get_default_curriculum_env_config(
            'constant-1-1', self._env_config)
        self._env = curriculum_env.CurriculumR2REnv(
            data_sources=['small_split'],
            runtime_config=self._runtime_config,
            curriculum_env_config=config)

        # Initially only one path is placed in the environment.
        self.assertLen(self._env._paths, 1)
        # Growth phase: one extra path per reset.
        for num_paths in range(2, 7):
            reset_and_check(num_paths)
        # Saturated phase: the count stays at 6.
        for _ in range(100):
            reset_and_check(6)