def optimizer(name):
    """Look up a registered optimizer by its registry key.

    Keys are expected to be snake_case. For legacy callers, two historical
    keys are translated ("SGD" -> "sgd", "RMSProp" -> "rms_prop"). Any other
    name is passed through `misc_utils.camelcase_to_snakecase`, so UpperCamelCase
    names map to snake_case. A warning is logged whenever a translation or
    conversion actually changes the name.

    Args:
      name: key the optimizer was registered under; snake_case preferred.

    Returns:
      The entry stored in `Registries.optimizers` for the resolved key.
    """
    warn_msg = ("Please update `registry.optimizer` callsite "
                "(likely due to a `HParams.optimizer` value)")
    # (legacy key, current key, warning template) for historical spellings.
    legacy_aliases = (
        ("SGD", "sgd", "'SGD' optimizer now keyed by 'sgd'. %s"),
        ("RMSProp", "rms_prop",
         "'RMSProp' optimizer now keyed by 'rms_prop'. %s"),
    )
    for legacy_key, current_key, template in legacy_aliases:
        if name == legacy_key:
            name = current_key
            tf.logging.warning(template % warn_msg)
            break
    else:
        converted = misc_utils.camelcase_to_snakecase(name)
        if converted != name:
            tf.logging.warning(
                "optimizer names now keyed by snake_case names. %s" % warn_msg)
            name = converted
    return Registries.optimizers[name]
def default_name(obj_class):
    """Compute the registry's default name for a class.

    Args:
      obj_class: the class object itself (not its name); its `__name__`
        attribute is what gets converted.

    Returns:
      The class's `__name__` converted from CamelCase to snake_case, which is
      the registry's default name for the class.
    """
    class_name = obj_class.__name__
    return misc_utils.camelcase_to_snakecase(class_name)
def default_name(class_or_fn):
    """Compute the default registration name for a class or function.

    Registries that store classes or functions use this as their naming
    function unless told otherwise.

    Args:
      class_or_fn: the class or function being registered.

    Returns:
      `class_or_fn.__name__` converted from CamelCase to snake_case.
    """
    object_name = class_or_fn.__name__
    return misc_utils.camelcase_to_snakecase(object_name)
def infer_game_name_from_filenames(data_dir, snake_case=True):
    """Infer the game name shared by the data files in `data_dir`.

    Entries of `data_dir` (non-recursive, via `os.listdir`) are matched against
    `^Gym(.*)NoFrameskip`; the captured part is the game name. Entries that do
    not match are ignored. All matching entries must agree on the game name.

    Args:
      data_dir: directory whose entries are inspected.
      snake_case: if True, convert the captured CamelCase game name to
        snake_case before returning it.

    Returns:
      The inferred game name.

    Raises:
      AssertionError: if no entry matches the pattern, or if matching entries
        yield different game names.
    """
    pattern = re.compile(r"^Gym(.*)NoFrameskip")
    # `^` without MULTILINE can only match at position 0, so each entry yields
    # at most one name. Collect the flattened names *before* asserting: the
    # per-entry result list is non-empty whenever the directory has any entry,
    # even when nothing matched, which previously led to an IndexError below.
    matches = (pattern.match(entry) for entry in os.listdir(data_dir))
    game_names = [match.group(1) for match in matches if match]
    assert game_names, "No data files found in {}".format(data_dir)
    game_name = game_names[0]
    assert all(game_name == other for other in game_names), \
        "There are multiple different game names in {}".format(data_dir)
    if snake_case:
        game_name = camelcase_to_snakecase(game_name)
    return game_name
def test_camelcase_to_snakecase(self):
    """Table-driven check of CamelCase -> snake_case conversion."""
    # Each pair is (input, expected output); checked in order.
    cases = [
        ("TypicalCamelCase", "typical_camel_case"),
        ("NumbersFuse2gether", "numbers_fuse2gether"),
        ("NumbersFuse2Gether", "numbers_fuse2_gether"),
        ("LSTMSeq2Seq", "lstm_seq2_seq"),
        ("startsLower", "starts_lower"),
        ("startsLowerCAPS", "starts_lower_caps"),
        ("CapsFUSETogether", "caps_fuse_together"),
        ("Startscap", "startscap"),
        ("STartscap", "s_tartscap"),
    ]
    for camel, expected in cases:
        self.assertEqual(expected, misc_utils.camelcase_to_snakecase(camel))
def _register_scan_problems():
    """Create and register one SCAN problem subclass per (problem, base class).

    For every entry returned by `_problems_to_register()` and each of
    `AlgorithmicSCAN` / `AlgorithmicSCANSep`, a subclass is built with
    `train_txt` / `test_txt` taken from the entry's first two items. It is then
    registered via `registry.register_problem`, and its `name` is appended to
    `REGISTERED_PROBLEMS`.
    """
    base_classes = [AlgorithmicSCAN, AlgorithmicSCANSep]
    problems = _problems_to_register()
    for problem_name, txts in six.iteritems(problems):
        for base in base_classes:
            prefix = misc_utils.camelcase_to_snakecase(base.__name__)
            attrs = {"train_txt": txts[0], "test_txt": txts[1]}
            new_class = type(f"{prefix}_{problem_name}", (base,), attrs)
            registry.register_problem(new_class)
            REGISTERED_PROBLEMS.append(new_class.name)
def name(self):
    """Return this object's class name converted to snake_case."""
    class_name = type(self).__name__
    return misc_utils.camelcase_to_snakecase(class_name)
def _register_base_optimizer(name, opt):
    """Register `opt` under the snake_case form of `name`, unless already taken.

    The registered callable takes `(learning_rate, hparams)`, ignores `hparams`,
    and returns `opt(learning_rate)`.

    Args:
      name: optimizer name; converted with `misc_utils.camelcase_to_snakecase`.
      opt: callable constructing the optimizer from a learning rate.
    """
    key = misc_utils.camelcase_to_snakecase(name)
    if key not in registry.Registries.optimizers:
        def _build(learning_rate, hparams):
            del hparams  # unused; kept for the registered two-argument signature
            return opt(learning_rate)

        registry.register_optimizer(key)(_build)
def name(cls, model_hparams, vocab_size=None):
    """Return the snake_case form of the receiver's class name.

    Args:
      cls: the receiver. NOTE(review): the parameter name suggests this is
        bound as a classmethod (decorator not visible here); an instance is
        also accepted.
      model_hparams: unused.
      vocab_size: unused.

    Returns:
      The class name converted from CamelCase to snake_case.
    """
    del model_hparams, vocab_size  # unused arg
    # If `cls` is already a class, `type(cls)` would be its metaclass (e.g.
    # `type`), producing a name like "type". Use the class itself in that case.
    # For instances, keep the original `type(cls)` behaviour.
    owner = cls if isinstance(cls, type) else type(cls)
    return misc_utils.camelcase_to_snakecase(owner.__name__)