Esempio n. 1
0
def optimizer(name):
    """Look up a pre-registered optimizer by its registry key.

    Keys are expected in snake_case. For legacy support, "SGD" is resolved to
    "sgd" and "RMSProp" to "rms_prop"; any other name is passed through
    `misc_utils.camelcase_to_snakecase`. Each legacy resolution (and any
    name the conversion changes) logs a warning asking the caller to update.

    Args:
      name: name the optimizer was registered under. snake_case is preferred;
        other spellings are accepted for backward compatibility.

    Returns:
      The entry stored in `Registries.optimizers` under the resolved key.
    """
    warn_msg = ("Please update `registry.optimizer` callsite "
                "(likely due to a `HParams.optimizer` value)")
    # (legacy name, registry key, warning template), checked in this order.
    legacy_aliases = (
        ("SGD", "sgd", "'SGD' optimizer now keyed by 'sgd'. %s"),
        ("RMSProp", "rms_prop",
         "'RMSProp' optimizer now keyed by 'rms_prop'. %s"),
    )
    for legacy_name, key, template in legacy_aliases:
        if name == legacy_name:
            tf.logging.warning(template % warn_msg)
            return Registries.optimizers[key]

    snake_name = misc_utils.camelcase_to_snakecase(name)
    if snake_name != name:
        tf.logging.warning(
            "optimizer names now keyed by snake_case names. %s" % warn_msg)
    return Registries.optimizers[snake_name]
Esempio n. 2
0
def default_name(obj_class):
    """Convert a class (or a class name) to the registry's default name.

    The default name is the class name converted from UpperCamelCase to
    snake_case via `misc_utils.camelcase_to_snakecase`.

    Args:
      obj_class: the class to be named. A string holding the class name is
        also accepted and converted directly.

    Returns:
      The registry's default name for the class.
    """
    # Strings have no `__name__`; use them as-is so passing a bare class name
    # works as well as passing the class object itself.
    class_name = obj_class if isinstance(obj_class, str) else obj_class.__name__
    return misc_utils.camelcase_to_snakecase(class_name)
Esempio n. 3
0
def default_name(class_or_fn):
    """Return the name a class or function is registered under by default.

    Registries that hold classes or functions use this as their default
    naming function: the object's `__name__` is converted to snake_case.

    Args:
      class_or_fn: the class or function being registered.

    Returns:
      The snake_case form of `class_or_fn.__name__`.
    """
    raw_name = class_or_fn.__name__
    snake_name = misc_utils.camelcase_to_snakecase(raw_name)
    return snake_name
Esempio n. 4
0
def infer_game_name_from_filenames(data_dir, snake_case=True):
  """Infer the Gym game name shared by the data files in `data_dir`.

  Each directory entry whose name starts with "Gym<Game>NoFrameskip"
  contributes "<Game>"; entries that do not match are ignored.

  Args:
    data_dir: directory whose entries are inspected (non-recursively).
    snake_case: if True, convert the inferred name with
      `camelcase_to_snakecase` before returning it.

  Returns:
    The game name common to all matching entries.

  Raises:
    AssertionError: (when assertions are enabled) if no entry matches the
      pattern, or if matching entries disagree on the game name.
  """
  pattern = re.compile(r"^Gym(.*)NoFrameskip")
  # Flatten before the emptiness check so that a directory containing only
  # non-matching files fails the assertion below instead of raising an
  # IndexError at `game_names[0]`.
  game_names = [match for name in os.listdir(data_dir)
                for match in pattern.findall(name)]
  assert game_names, "No data files found in {}".format(data_dir)
  game_name = game_names[0]
  assert all(game_name == other for other in game_names), \
      "There are multiple different game names in {}".format(data_dir)
  if snake_case:
    game_name = camelcase_to_snakecase(game_name)
  return game_name
Esempio n. 5
0
 def test_camelcase_to_snakecase(self):
   """camelcase_to_snakecase handles digits, acronyms and leading case."""
   # (expected snake_case output, input) pairs, checked in order; the first
   # mismatch fails the test.
   cases = (
       ("typical_camel_case", "TypicalCamelCase"),
       ("numbers_fuse2gether", "NumbersFuse2gether"),
       ("numbers_fuse2_gether", "NumbersFuse2Gether"),
       ("lstm_seq2_seq", "LSTMSeq2Seq"),
       ("starts_lower", "startsLower"),
       ("starts_lower_caps", "startsLowerCAPS"),
       ("caps_fuse_together", "CapsFUSETogether"),
       ("startscap", "Startscap"),
       ("s_tartscap", "STartscap"),
   )
   for expected, camel in cases:
     self.assertEqual(expected, misc_utils.camelcase_to_snakecase(camel))
def _register_scan_problems():
    """Register one problem subclass per (problem, base class) combination.

    For every (problem name, txts) pair from `_problems_to_register()` and for
    each of `AlgorithmicSCAN` and `AlgorithmicSCANSep`, a subclass named
    "<snake_case base class name>_<problem name>" is created with
    `train_txt = txts[0]` and `test_txt = txts[1]`. It is registered via
    `registry.register_problem` and its `.name` is appended to
    `REGISTERED_PROBLEMS`.
    """
    base_classes = (AlgorithmicSCAN, AlgorithmicSCANSep)
    problems = _problems_to_register()
    for problem_name, txts in six.iteritems(problems):
        for base in base_classes:
            snake_base = misc_utils.camelcase_to_snakecase(base.__name__)
            class_name = f"{snake_base}_{problem_name}"
            attrs = {"train_txt": txts[0], "test_txt": txts[1]}
            new_class = type(class_name, (base,), attrs)
            registry.register_problem(new_class)
            REGISTERED_PROBLEMS.append(new_class.name)
Esempio n. 7
0
 def name(self):
   """Return the snake_case form of this object's class name."""
   class_name = type(self).__name__
   return misc_utils.camelcase_to_snakecase(class_name)
Esempio n. 8
0
def _register_base_optimizer(name, opt):
    """Register `opt` under the snake_case form of `name` unless already taken.

    The registered factory takes `(learning_rate, hparams)`, ignores `hparams`
    and returns `opt(learning_rate)`. If the key already exists in
    `registry.Registries.optimizers`, nothing is registered.
    """
    key = misc_utils.camelcase_to_snakecase(name)
    if key not in registry.Registries.optimizers:
        register = registry.register_optimizer(key)
        # `hparams` is accepted but unused.
        register(lambda learning_rate, hparams: opt(learning_rate))
Esempio n. 9
0
 def name(cls, model_hparams, vocab_size=None):
     """Return the snake_case form of the class name.

     Args:
       cls: the class itself (when bound as a classmethod) or an instance of
         it.
       model_hparams: unused.
       vocab_size: unused.

     Returns:
       `misc_utils.camelcase_to_snakecase` applied to the class name.
     """
     del model_hparams, vocab_size  # unused arg
     # When `cls` is already a class (classmethod usage), `type(cls)` is its
     # metaclass (usually `type`), which would name every subclass "type".
     owner = cls if isinstance(cls, type) else type(cls)
     return misc_utils.camelcase_to_snakecase(owner.__name__)