Ejemplo n.º 1
0
 def test_snakecase_to_camelcase(self):
   """Checks snake_case -> CamelCase conversion, including digit handling.

   Each case is (expected CamelCase output, snake_case input). The cases pin
   that a digit does not start a new word ("fuse2gether" keeps its lowercase
   "g"), while an underscore after a digit ("fuse2_gether") does capitalize
   the following letter.
   """
   cases = (
       ("TypicalCamelCase", "typical_camel_case"),
       ("NumbersFuse2gether", "numbers_fuse2gether"),
       ("NumbersFuse2Gether", "numbers_fuse2_gether"),
       ("LstmSeq2Seq", "lstm_seq2_seq"),
   )
   # Checked in order; the first mismatch fails the test, as before.
   for expected, snake_name in cases:
     self.assertEqual(expected,
                      misc_utils.snakecase_to_camelcase(snake_name))
Ejemplo n.º 2
0
def full_game_name(short_name):
  """Builds the full CamelCase game name, including the game-mode suffix.

  Args:
    short_name: snake_case game name without a mode suffix,
      e.g. "crazy_climber".

  Returns:
    The CamelCase form of `short_name` with ATARI_GAME_MODE appended,
    e.g. "CrazyClimberNoFrameskip-v4".
  """
  return misc_utils.snakecase_to_camelcase(short_name) + ATARI_GAME_MODE
Ejemplo n.º 3
0
def setup_env(hparams, batch_size, max_num_noops, rl_env_max_episode_steps=-1,
              game_mode="NoFrameskip-v4"):
  """Creates a T2TGymEnv for the game named in `hparams`.

  Args:
    hparams: hyperparameters object; this function reads `game` (snake_case
      game name, e.g. "crazy_climber"), `grayscale`, `resize_width_factor`
      and `resize_height_factor`.
    batch_size: forwarded to T2TGymEnv as `batch_size`.
    max_num_noops: forwarded to T2TGymEnv as `max_num_noops`.
    rl_env_max_episode_steps: forwarded to T2TGymEnv unchanged; defaults to
      -1 (its interpretation is up to T2TGymEnv).
    game_mode: suffix appended to the CamelCase game name to form the base
      environment name. Defaults to "NoFrameskip-v4", the value this function
      previously hard-coded, so existing callers are unaffected.

  Returns:
    The constructed T2TGymEnv (always created with maxskip_envs=True).
  """
  # e.g. "crazy_climber" -> "CrazyClimber" -> "CrazyClimberNoFrameskip-v4".
  env_name = misc_utils.snakecase_to_camelcase(hparams.game) + game_mode

  return T2TGymEnv(base_env_name=env_name,
                   batch_size=batch_size,
                   grayscale=hparams.grayscale,
                   resize_width_factor=hparams.resize_width_factor,
                   resize_height_factor=hparams.resize_height_factor,
                   rl_env_max_episode_steps=rl_env_max_episode_steps,
                   max_num_noops=max_num_noops, maxskip_envs=True)
Ejemplo n.º 4
0
def register_game(game_name, game_mode="NoFrameskip-v4"):
    """Create and register a Problem class for one Atari game.

    The generated class subclasses T2TGymEnv, is named
    "Gym<CamelCaseGame><game_mode>Random", and has its `base_env_name` class
    attribute set to the CamelCase game name followed by `game_mode`. It is
    then passed to `registry.register_problem`.

    Args:
      game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
      game_mode: the frame skip and sticky keys config; must be one of
        ATARI_GAME_MODES.

    Raises:
      ValueError: if game_name is not in ATARI_GAMES, or (checked second)
        game_mode is not in ATARI_GAME_MODES.
    """
    # Validate inputs before building anything.
    if game_name not in ATARI_GAMES:
        raise ValueError("Game %s not in ATARI_GAMES" % game_name)
    if game_mode not in ATARI_GAME_MODES:
        raise ValueError("Unknown ATARI game mode: %s." % game_mode)

    env_name = misc_utils.snakecase_to_camelcase(game_name) + game_mode
    problem_name = "Gym%sRandom" % env_name
    class_attrs = {"base_env_name": env_name}

    # Build the Problem subclass dynamically and register it.
    problem_cls = type(problem_name, (T2TGymEnv, ), class_attrs)
    registry.register_problem(problem_cls)
Ejemplo n.º 5
0
 def default_opt_name(opt_fn):
     """Return `default_name(opt_fn)` converted to CamelCase."""
     # presumably default_name yields a snake_case name for opt_fn — confirm.
     snake_name = default_name(opt_fn)
     return misc_utils.snakecase_to_camelcase(snake_name)