コード例 #1
0
ファイル: babi_qa.py プロジェクト: kltony/tensor2tensor
def _register_babi_problems():
  """Dynamically creates and registers a problem class per bAbi subset-task.

  For every (subset, task) combination a subclass of ``BabiQaConcat`` is built
  with ``type()``; the result is roughly equivalent to writing::

    @registry.register_problem
    class BabiQaConcatAllTasks_10k(BabiQaConcat):
      babi_task_id = 'qa0'
      babi_subset = 'en-10k'

  The generated classes are not put into the global namespace, so to access a
  class rely on the registry or on this module's REGISTERED_PROBLEMS list. The
  class above, for example, is available as

    registry.problem('babi_qa_concat_all_tasks_10k')

  i.e. the registered name is the class name changed from camel case to snake
  case. Numbers are considered lower case characters for these purposes.
  """
  # (babi subset, suffix appended to the generated class name)
  subsets = [('en', '_1k'), ('en-10k', '_10k')]
  for subset, subset_suffix in subsets:
    for problem_name, babi_task_id in six.iteritems(_problems_to_register()):
      class_name = 'BabiQaConcat' + problem_name + subset_suffix
      problem_class = type(class_name, (BabiQaConcat,), {
          'babi_task_id': babi_task_id,
          'babi_subset': subset,
      })
      registry.register_problem(problem_class)
      # `name` is read after registration; the registry is what derives it.
      REGISTERED_PROBLEMS.append(problem_class.name)
コード例 #2
0
ファイル: babi_qa.py プロジェクト: hubayirp/fabric-vsf
def _register_babi_problems():
    """Dynamically creates and registers a problem class per bAbi subset-task.

    For every (subset, task) combination a subclass of ``BabiQaConcat`` is
    built with ``type()``; the result is roughly equivalent to writing::

        @registry.register_problem
        class BabiQaConcatAllTasks_10k(BabiQaConcat):
            babi_task_id = "qa0"
            babi_subset = "en-10k"

    The generated classes are not put into the global namespace, so to access
    a class rely on the registry or on this module's REGISTERED_PROBLEMS list.
    The class above, for example, is available as

        registry.problem("babi_qa_concat_all_tasks_10k")

    i.e. the registered name is the class name changed from camel case to
    snake case. Numbers are considered lower case characters for these
    purposes.
    """
    # (babi subset, suffix appended to the generated class name)
    subsets = [("en", "_1k"), ("en-10k", "_10k")]
    for subset, subset_suffix in subsets:
        for problem_name, babi_task_id in six.iteritems(
                _problems_to_register()):
            class_name = "BabiQaConcat" + problem_name + subset_suffix
            problem_class = type(class_name, (BabiQaConcat, ), {
                "babi_task_id": babi_task_id,
                "babi_subset": subset,
            })
            registry.register_problem(problem_class)
            # `name` is read after registration; the registry is what derives it.
            REGISTERED_PROBLEMS.append(problem_class.name)
コード例 #3
0
def create_problems_for_game(game_name, clipped_reward=True):
    """Create problems for game_name and register the agent/simulated ones.

    Registers a reward-clipping gym environment named
    "T2T<CamelGameName>Deterministic-v4" and builds three problem classes on
    top of it. Only the "agent" and "simulated" classes are passed to
    ``registry.register_problem`` here; the "base" class is only returned.

    Args:
      game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
      clipped_reward: bool, whether the rewards should be clipped. False is not
        yet supported.

    Returns:
      dict of problem classes with keys ("base", "agent", "simulated").

    Raises:
      ValueError: if clipped_reward=False or game_name not in ATARI_GAMES.
    """
    if not clipped_reward:
        raise ValueError("Creating problems without clipped reward is not "
                         "yet supported.")
    if game_name not in ATARI_GAMES:
        raise ValueError("Game %s not in ATARI_GAMES" % game_name)
    # e.g. "bank_heist" -> "BankHeist"
    camel_game_name = "".join(
        [w[0].upper() + w[1:] for w in game_name.split("_")])
    env_name = "%sDeterministic-v4" % camel_game_name
    wrapped_env_name = "T2T%s" % env_name

    # Register an environment that does the reward clipping. env_name is a
    # fresh local on every call, so each lambda closes over its own value.
    gym.envs.register(
        id=wrapped_env_name,
        entry_point=lambda: gym_utils.wrapped_factory(  # pylint: disable=g-long-lambda
            env=env_name,
            reward_clipping=True))

    # Create the Random problem class (built on the wrapped environment).
    # NOTE(review): unlike the classes below, problem_cls is NOT registered
    # here even though other versions of this function register it -- confirm
    # whether that is intended before relying on registry.problem() for it.
    problem_cls = type("Gym%sRandom" % camel_game_name,
                       (GymClippedRewardRandom, ),
                       {"env_name": wrapped_env_name})

    # Create and register the WithAgent Problem class.
    with_agent_cls = type("GymDiscreteProblemWithAgentOn%s" % camel_game_name,
                          (GymRealDiscreteProblem, problem_cls), {})
    registry.register_problem(with_agent_cls)

    # Create and register the simulated Problem; its initial frames come from
    # the WithAgent problem registered just above.
    simulated_cls = type(
        "GymSimulatedDiscreteProblemWithAgentOn%s" % camel_game_name,
        (GymSimulatedDiscreteProblem, problem_cls), {
            "initial_frames_problem": with_agent_cls.name,
            "num_testing_steps": 100
        })
    registry.register_problem(simulated_cls)

    return {
        "base": problem_cls,
        "agent": with_agent_cls,
        "simulated": simulated_cls,
    }
def _register_scan_problems():
    """Registers a SCAN problem subclass for every (suffix, base class) pair.

    ``_problems_to_register()`` maps a problem-name suffix to a sequence whose
    first two items become the ``train_txt`` and ``test_txt`` class
    attributes. For each suffix and each of AlgorithmicSCAN /
    AlgorithmicSCANSep, a subclass named "<snake_case base name>_<suffix>" is
    created, registered, and its registry name appended to
    REGISTERED_PROBLEMS.
    """
    base_classes = (AlgorithmicSCAN, AlgorithmicSCANSep)
    for suffix, txts in six.iteritems(_problems_to_register()):
        for base in base_classes:
            snake_base = misc_utils.camelcase_to_snakecase(base.__name__)
            attrs = {"train_txt": txts[0], "test_txt": txts[1]}
            new_cls = type(f"{snake_base}_{suffix}", (base, ), attrs)
            registry.register_problem(new_cls)
            REGISTERED_PROBLEMS.append(new_cls.name)
コード例 #5
0
ファイル: gym_env.py プロジェクト: yuhonghong66/tensor2tensor
def register_game(game_name, game_mode="NoFrameskip-v4"):
    """Builds a T2TGymEnv problem class for one Atari game and registers it.

    Args:
      game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
      game_mode: the frame skip and sticky keys config; must be one of
        ATARI_GAME_MODES.

    Raises:
      ValueError: if game_name is not in ATARI_GAMES or game_mode is not in
        ATARI_GAME_MODES.
    """
    if game_name not in ATARI_GAMES:
        raise ValueError("Game %s not in ATARI_GAMES" % game_name)
    if game_mode not in ATARI_GAME_MODES:
        raise ValueError("Unknown ATARI game mode: %s." % game_mode)
    # Camel-cased game name with the mode appended, used as base_env_name.
    env_id = misc_utils.snakecase_to_camelcase(game_name) + game_mode
    class_name = "Gym%sRandom" % env_id
    registry.register_problem(
        type(class_name, (T2TGymEnv, ), {"base_env_name": env_id}))
コード例 #6
0
ファイル: gym_env.py プロジェクト: qixiuai/tensor2tensor
def register_game(game_name, game_mode="Deterministic-v4"):
  """Builds a T2TGymEnv problem class for one Atari game and registers it.

  Args:
    game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
    game_mode: the frame skip and sticky keys config; must be one of
      ATARI_GAME_MODES.

  Raises:
    ValueError: if game_name is not in ATARI_GAMES or game_mode is not in
      ATARI_GAME_MODES.
  """
  if game_name not in ATARI_GAMES:
    raise ValueError("Game %s not in ATARI_GAMES" % game_name)
  if game_mode not in ATARI_GAME_MODES:
    raise ValueError("Unknown ATARI game mode: %s." % game_mode)
  # Camel-cased game name with the mode appended, used as base_env_name.
  env_id = camel_case_name(game_name) + game_mode
  class_name = "Gym%sRandom" % env_id
  registry.register_problem(
      type(class_name, (T2TGymEnv,), {"base_env_name": env_id}))
コード例 #7
0
def create_problems_for_game(game_name,
                             resize_height_factor=2,
                             resize_width_factor=2,
                             game_mode="Deterministic-v4"):
    """Create and register problems for game_name.

    Args:
      game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
      resize_height_factor: factor by which to resize the height of frames.
      resize_width_factor: factor by which to resize the width of frames.
      game_mode: the frame skip and sticky keys config; must be one of
        ATARI_GAME_MODES.

    Returns:
      dict of problem classes with keys ("base", "agent", "simulated"); each
      of them is passed to registry.register_problem.

    Raises:
      ValueError: if game_name is not in ATARI_GAMES or game_mode is not in
        ATARI_GAME_MODES.
    """
    if game_name not in ATARI_GAMES:
        raise ValueError("Game %s not in ATARI_GAMES" % game_name)
    if game_mode not in ATARI_GAME_MODES:
        raise ValueError("Unknown ATARI game mode: %s." % game_mode)
    # e.g. ("bank_heist", "Deterministic-v4") -> "BankHeistDeterministic-v4";
    # the same string is used as the environment name.
    camel_game_name = "".join(
        w[0].upper() + w[1:] for w in game_name.split("_")) + game_mode
    env_name = camel_game_name

    # Create and register the Random and WithAgent Problem classes
    problem_cls = type(
        "Gym%sRandom" % camel_game_name, (GymClippedRewardRandom, ), {
            "env_name": env_name,
            "resize_height_factor": resize_height_factor,
            "resize_width_factor": resize_width_factor
        })
    registry.register_problem(problem_cls)

    with_agent_cls = type("GymDiscreteProblemWithAgentOn%s" % camel_game_name,
                          (GymRealDiscreteProblem, problem_cls), {})
    registry.register_problem(with_agent_cls)

    # Create and register the simulated Problem; its initial frames come from
    # the WithAgent problem registered just above.
    simulated_cls = type(
        "GymSimulatedDiscreteProblemWithAgentOn%s" % camel_game_name,
        (GymSimulatedDiscreteProblem, problem_cls), {
            "initial_frames_problem": with_agent_cls.name,
            "num_testing_steps": 100
        })
    registry.register_problem(simulated_cls)

    return {
        "base": problem_cls,
        "agent": with_agent_cls,
        "simulated": simulated_cls,
    }
コード例 #8
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import registry
from tensor2tensor.utils.registry import *

# Adds subsections to the registries to store G2G-specific problems and hparams.
registry.Registries.g2g_problems = registry.Registry(
    "g2g_problems",
    validator=registry._problem_name_validator,
    on_set=registry._on_problem_set)
registry.Registries.g2g_hparams = registry.Registry(
    "g2g_hparams", value_transformer=registry._hparams_value_transformer)


# Defines decorators. These shadow the same-named functions brought in by the
# wildcard import from tensor2tensor.utils.registry.
def register_problem(x):
    """Registers problem ``x`` in the G2G registry and the global problem registry.

    Meant to be used as a bare ``@register_problem`` decorator; returns
    whatever ``registry.register_problem`` returns.
    """
    return registry.register_problem(
        registry.Registries.g2g_problems.register(x))


def register_hparams(x):
    """Registers hparams ``x`` in the G2G registry and the global hparams registry.

    Meant to be used as a bare ``@register_hparams`` decorator; returns
    whatever ``registry.register_hparams`` returns.
    """
    return registry.register_hparams(
        registry.Registries.g2g_hparams.register(x))


# Overrides registry queries.
def list_g2g_problems():
    """Returns the entries of the G2G problem registry, sorted."""
    return sorted(registry.Registries.g2g_problems)


list_problems = list_g2g_problems
list_all_problems = registry.list_base_problems


def list_g2g_hparams():
    """Returns the entries of the G2G hparams registry, sorted."""
    return sorted(registry.Registries.g2g_hparams)


# NOTE(review): the hparams listing overrides below are disabled; either
# enable them or delete them.
#list_hparams = list_g2g_hparams
#list_all_hparams = registry.list_hparams
コード例 #9
0
def create_problems_for_game(game_name,
                             resize_height_factor=2,
                             resize_width_factor=2,
                             grayscale=True,
                             game_mode="Deterministic-v4",
                             autoencoder_hparams=None):
    """Create and register problems for game_name.

    Builds, with ``type()``, a Random base problem class for the Atari game
    plus agent, autoencoder, simulated and world-model-evaluation variants
    derived from it, and passes each class to registry.register_problem.

    Args:
      game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
      resize_height_factor: factor by which to resize the height of frames.
      resize_width_factor: factor by which to resize the width of frames.
      grayscale: whether to make frames grayscale.
      game_mode: the frame skip and sticky keys config; must be one of
        ATARI_GAME_MODES.
      autoencoder_hparams: the hparams for the autoencoder; stored as the
        ``ae_hparams_set`` attribute of several of the generated classes.

    Returns:
      dict of problems with keys ("base", "agent", "simulated").

    Raises:
      ValueError: if game_name is not in ATARI_GAMES or game_mode is not in
        ATARI_GAME_MODES.
    """
    if game_name not in ATARI_GAMES:
        raise ValueError("Game %s not in ATARI_GAMES" % game_name)
    if game_mode not in ATARI_GAME_MODES:
        raise ValueError("Unknown ATARI game mode: %s." % game_mode)
    # e.g. ("bank_heist", "Deterministic-v4") -> "BankHeistDeterministic-v4";
    # the same string is used as the environment name.
    camel_game_name = "".join(
        [w[0].upper() + w[1:] for w in game_name.split("_")])
    camel_game_name += game_mode
    env_name = camel_game_name

    # Create and register the Random and WithAgent Problem classes
    problem_cls = type(
        "Gym%sRandom" % camel_game_name, (GymClippedRewardRandom, ), {
            "env_name": env_name,
            "resize_height_factor": resize_height_factor,
            "resize_width_factor": resize_width_factor,
            "grayscale": grayscale
        })
    registry.register_problem(problem_cls)

    with_agent_cls = type("GymDiscreteProblemWithAgentOn%s" % camel_game_name,
                          (GymRealDiscreteProblem, problem_cls), {})
    registry.register_problem(with_agent_cls)

    # Create and register the autoencoder-based variants; both carry
    # autoencoder_hparams as ae_hparams_set.
    with_ae_cls = type(
        "GymDiscreteProblemWithAgentOn%sWithAutoencoder" % camel_game_name,
        (GymDiscreteProblemWithAutoencoder, problem_cls),
        {"ae_hparams_set": autoencoder_hparams})
    registry.register_problem(with_ae_cls)

    ae_cls = type(
        "GymDiscreteProblemWithAgentOn%sAutoencoded" % camel_game_name,
        (GymDiscreteProblemAutoencoded, problem_cls),
        {"ae_hparams_set": autoencoder_hparams})
    registry.register_problem(ae_cls)

    # Create and register the simulated Problem; its initial frames come from
    # the WithAgent problem.
    simulated_cls = type(
        "GymSimulatedDiscreteProblemWithAgentOn%s" % camel_game_name,
        (GymSimulatedDiscreteProblem, problem_cls), {
            "initial_frames_problem": with_agent_cls.name,
            "num_testing_steps": 100
        })
    registry.register_problem(simulated_cls)

    # Autoencoded simulated Problem; its initial frames come from ae_cls.
    simulated_ae_cls = type(
        "GymSimulatedDiscreteProblemWithAgentOn%sAutoencoded" %
        camel_game_name, (GymSimulatedDiscreteProblemAutoencoded, problem_cls),
        {
            "initial_frames_problem": ae_cls.name,
            "num_testing_steps": 100,
            "ae_hparams_set": autoencoder_hparams
        })
    registry.register_problem(simulated_ae_cls)

    # Create and register the world-model evaluation Problem; its initial
    # frames come from the WithAgent problem.
    world_model_eval_cls = type(
        "GymSimulatedDiscreteProblemForWorldModelEvalWithAgentOn%s" %
        camel_game_name,
        (GymSimulatedDiscreteProblemForWorldModelEval, problem_cls), {
            "initial_frames_problem": with_agent_cls.name,
            "num_testing_steps": 100,
            "ae_hparams_set": autoencoder_hparams
        })
    registry.register_problem(world_model_eval_cls)

    # Autoencoded world-model evaluation Problem; its initial frames come
    # from ae_cls.
    world_model_eval_ae_cls = type(
        "GymSimulatedDiscreteProblemForWorldModelEvalWithAgentOn%sAutoencoded"
        % camel_game_name,
        (GymSimulatedDiscreteProblemForWorldModelEvalAutoencoded, problem_cls),
        {
            "initial_frames_problem": ae_cls.name,
            "num_testing_steps": 100,
            "ae_hparams_set": autoencoder_hparams
        })
    registry.register_problem(world_model_eval_ae_cls)