예제 #1
0
 def get_environment(self):
   """Returns the environment, creating it on first use.

   The environment is built from `self._data_sources` and
   `self._runtime_config` with `env_config_lib.get_default_env_config()`,
   then cached in `self._env` so later calls return the same object.

   Returns:
     The environment stored in `self._env`.

   Raises:
     ValueError: If the environment has not been created yet and
       `self._data_sources` is empty.
   """
   if self._env is None:
     # Explicit check instead of `assert`: asserts are stripped under
     # `python -O`, which would let an empty source list through silently.
     if not self._data_sources:
       raise ValueError('data_sources must be non-empty.')
     self._env = env.R2REnv(
         data_sources=self._data_sources,
         runtime_config=self._runtime_config,
         env_config=env_config_lib.get_default_env_config())
   return self._env
예제 #2
0
 def get_environment(self):
     """Returns the environment, creating it on first use.

     When `self._curriculum` is truthy, a `curriculum_env.CurriculumR2REnv`
     is built using
     `curriculum_env_config_lib.get_default_curriculum_env_config(
     self._curriculum)`. Otherwise a plain `env.R2REnv` is built with
     `env_config_lib.get_default_env_config()`. Either way the result is
     cached in `self._env`.

     Returns:
       The environment stored in `self._env`.

     Raises:
       ValueError: If the environment has not been created yet and
         `self._data_sources` is empty.
     """
     if self._env is None:
         # Explicit check instead of `assert`: asserts are stripped under
         # `python -O`, which would let an empty source list through.
         if not self._data_sources:
             raise ValueError('data_sources must be non-empty.')
         if self._curriculum:
             # See actor_main.py and curriculum_env.py for the argument options.
             self._env = curriculum_env.CurriculumR2REnv(
                 data_sources=self._data_sources,
                 runtime_config=self._runtime_config,
                 curriculum_env_config=(
                     curriculum_env_config_lib
                     .get_default_curriculum_env_config(self._curriculum)))
         else:
             self._env = env.R2REnv(
                 data_sources=self._data_sources,
                 runtime_config=self._runtime_config,
                 env_config=env_config_lib.get_default_env_config())
     return self._env
예제 #3
0
def get_default_curriculum_env_config(method, env_config=None):
    """Builds the default curriculum env config.

    The result is assembled in three layers:
    1. Start from a copy of `DEFAULT_CURRICULUM_ENV_CONFIG`.
    2. Override with every entry of `env_config.values()`.
    3. Set the `method` key to the given curriculum method.

    Args:
      method: The method used in curriculum learning. It is stored under the
        `method` key of the returned config.
      env_config: Optional. The env config whose values are merged over the
        curriculum defaults. If None, the config returned by
        `env_config_lib.get_default_env_config()` is used. Default, None.

    Returns:
      A curriculum env config, as an `hparam.HParams`.
    """
    base_config = (
        env_config_lib.get_default_env_config()
        if env_config is None else env_config)

    # Env values plus the method; `method` always wins over any env entry.
    overrides = dict(base_config.values())
    overrides['method'] = method

    merged = DEFAULT_CURRICULUM_ENV_CONFIG.copy()
    merged.update(overrides)
    return hparam.HParams(**merged)
예제 #4
0
  def __init__(self, runtime_config, mode, data_sources, agent_config=None,
               env_config=None):
    """Stores the configuration and builds the agent.

    Args:
      runtime_config: Runtime configuration, stored in `self._runtime_config`.
      mode: Mode value, stored in `self._mode`.
      data_sources: Data sources, stored in `self._data_sources`.
      agent_config: Optional agent config. If falsy,
        `agent_config_lib.get_r2r_agent_config()` is used instead. Its
        optional `agent_type` attribute selects the agent class:
        'default' (also used when the attribute is absent) builds a
        `DiscriminatorAgent`, and 'v2' builds a `DiscriminatorAgentV2`.
      env_config: Optional env config. If falsy,
        `env_config_lib.get_default_env_config()` is used instead.

    Raises:
      ValueError: If `agent_type` is neither 'default' nor 'v2'.
    """
    self._runtime_config = runtime_config
    self._mode = mode
    self._data_sources = data_sources

    self._env_config = (
        env_config if env_config else env_config_lib.get_default_env_config())
    self._env = None
    self._loss_type = None
    self._eval_dict = self._get_eval_dict()

    agent_config = (
        agent_config
        if agent_config else agent_config_lib.get_r2r_agent_config())
    # Configs without an `agent_type` field fall back to the default agent.
    agent_type = getattr(agent_config, 'agent_type', 'default')
    if agent_type == 'default':
      self._agent = discriminator_agent.DiscriminatorAgent(agent_config)
    elif agent_type == 'v2':
      self._agent = discriminator_agent.DiscriminatorAgentV2(agent_config)
    else:
      # Fail fast; otherwise `self._agent` would be left unset and the error
      # would only surface later as an AttributeError.
      raise ValueError('Unsupported agent_type: %r' % (agent_type,))
예제 #5
0
파일: env.py 프로젝트: ml-lab/valan
def _get_default_env_config():
  """Returns the default env config.

  Module-private wrapper that delegates to
  `default_env_config.get_default_env_config()`.
  """
  config = default_env_config.get_default_env_config()
  return config