def test_snakecase_to_camelcase(self):
  """Checks snakecase_to_camelcase on plain words and on words with digits.

  A digit does not start a new word on its own: "fuse2gether" becomes
  "Fuse2gether". An underscore after the digit ("fuse2_gether") does start a
  new capitalized word.
  """
  cases = [
      ("typical_camel_case", "TypicalCamelCase"),
      ("numbers_fuse2gether", "NumbersFuse2gether"),
      ("numbers_fuse2_gether", "NumbersFuse2Gether"),
      ("lstm_seq2_seq", "LstmSeq2Seq"),
  ]
  for snake_name, expected_camel in cases:
    self.assertEqual(expected_camel,
                     misc_utils.snakecase_to_camelcase(snake_name))
def full_game_name(short_name):
  """Returns the CamelCase game name followed by the mode suffix.

  Args:
    short_name: snake_case game name without the mode, e.g. "crazy_climber".

  Returns:
    The CamelCase name with ATARI_GAME_MODE appended, e.g.
    "CrazyClimberNoFrameskip-v4".
  """
  return misc_utils.snakecase_to_camelcase(short_name) + ATARI_GAME_MODE
def setup_env(hparams, batch_size, max_num_noops, rl_env_max_episode_steps=-1,
              game_mode="NoFrameskip-v4"):
  """Creates a T2TGymEnv for the game named by `hparams.game`.

  The environment id is the CamelCase form of `hparams.game` followed by
  `game_mode`. For example, a game "crazy_climber" with the default mode
  yields "CrazyClimberNoFrameskip-v4".

  Args:
    hparams: hyperparameters object. Reads `game` (snake_case game name),
      `grayscale`, `resize_width_factor` and `resize_height_factor`.
    batch_size: forwarded to T2TGymEnv as `batch_size`.
    max_num_noops: forwarded to T2TGymEnv as `max_num_noops`.
    rl_env_max_episode_steps: forwarded to T2TGymEnv as
      `rl_env_max_episode_steps`; defaults to -1.
    game_mode: suffix appended to the CamelCase game name. Defaults to
      "NoFrameskip-v4", the value this function previously hard-coded, so
      existing callers are unaffected.

  Returns:
    The constructed T2TGymEnv, created with `maxskip_envs=True`.
  """
  env_name = misc_utils.snakecase_to_camelcase(hparams.game) + game_mode
  return T2TGymEnv(
      base_env_name=env_name,
      batch_size=batch_size,
      grayscale=hparams.grayscale,
      resize_width_factor=hparams.resize_width_factor,
      resize_height_factor=hparams.resize_height_factor,
      rl_env_max_episode_steps=rl_env_max_episode_steps,
      max_num_noops=max_num_noops,
      maxskip_envs=True)
def register_game(game_name, game_mode="NoFrameskip-v4"):
  """Creates a T2TGymEnv subclass for one game and registers it as a Problem.

  The generated class is named "Gym<CamelName><game_mode>Random". Its
  `base_env_name` class attribute is set to "<CamelName><game_mode>".

  Args:
    game_name: str, one of the games in ATARI_GAMES, e.g. "bank_heist".
    game_mode: the frame skip and sticky keys config; must be one of
      ATARI_GAME_MODES.

  Raises:
    ValueError: if game_name or game_mode are wrong.
  """
  # Validate the game first, then the mode, so that a call with both
  # arguments wrong reports the unknown game.
  if game_name not in ATARI_GAMES:
    raise ValueError("Game %s not in ATARI_GAMES" % game_name)
  if game_mode not in ATARI_GAME_MODES:
    raise ValueError("Unknown ATARI game mode: %s." % game_mode)

  env_id = misc_utils.snakecase_to_camelcase(game_name) + game_mode
  problem_name = "Gym%sRandom" % env_id
  # Build the Problem subclass dynamically with the three-argument type().
  problem_cls = type(problem_name, (T2TGymEnv,), {"base_env_name": env_id})
  registry.register_problem(problem_cls)
def default_opt_name(opt_fn):
  """Returns `default_name(opt_fn)` converted from snake_case to CamelCase."""
  snake_name = default_name(opt_fn)
  return misc_utils.snakecase_to_camelcase(snake_name)