def test_registry_has_builtin_initializers(self):
    """Check that each built-in ``torch.nn.init`` function is registered under its short name."""
    init = torch.nn.init
    # (registered name, expected torch.nn.init function), checked in this order.
    expected_initializers = [
        ("normal", init.normal),
        ("uniform", init.uniform),
        ("orthogonal", init.orthogonal),
        ("constant", init.constant),
        ("dirac", init.dirac),
        ("xavier_normal", init.xavier_normal),
        ("xavier_uniform", init.xavier_uniform),
        ("kaiming_normal", init.kaiming_normal),
        ("kaiming_uniform", init.kaiming_uniform),
        ("sparse", init.sparse),
        ("eye", init.eye),
    ]
    for registered_name, expected_fn in expected_initializers:
        assert Registry.get_initializer(registered_name) == expected_fn
def test_registry_has_builtin_token_indexers(self):
    """Check that the registry resolves the built-in token indexer names to the expected classes."""
    expected_class_names = (
        ('single_id', 'SingleIdTokenIndexer'),
        ('characters', 'TokenCharactersIndexer'),
    )
    for registered_name, class_name in expected_class_names:
        indexer_class = Registry.get_token_indexer(registered_name)
        assert indexer_class.__name__ == class_name
def test_registry_has_builtin_tokenizers(self):
    """Check that the registry resolves the built-in tokenizer names to the expected classes."""
    expected_class_names = (
        ('word', 'WordTokenizer'),
        ('character', 'CharacterTokenizer'),
    )
    for registered_name, class_name in expected_class_names:
        tokenizer_class = Registry.get_tokenizer(registered_name)
        assert tokenizer_class.__name__ == class_name
def test_registry_has_builtin_iterators(self):
    """Check that the registry resolves the built-in data iterator names to the expected classes."""
    expected_class_names = (
        ('adaptive', 'AdaptiveIterator'),
        ('basic', 'BasicIterator'),
        ('bucket', 'BucketIterator'),
    )
    for registered_name, class_name in expected_class_names:
        iterator_class = Registry.get_data_iterator(registered_name)
        assert iterator_class.__name__ == class_name
def test_registry_has_builtin_readers(self):
    """Check that the registry resolves the built-in dataset reader names to the expected classes."""
    expected_class_names = (
        ('snli', 'SnliReader'),
        ('sequence_tagging', 'SequenceTaggingDatasetReader'),
        ('language_modeling', 'LanguageModelingReader'),
        ('squad_sentence_selection', 'SquadSentenceSelectionReader'),
    )
    for registered_name, class_name in expected_class_names:
        reader_class = Registry.get_dataset_reader(registered_name)
        assert reader_class.__name__ == class_name
def test_registry_has_builtin_text_field_embedders(self):
    """Check that the name "basic" resolves to the BasicTextFieldEmbedder class."""
    embedder_class = Registry.get_text_field_embedder("basic")
    assert embedder_class.__name__ == 'BasicTextFieldEmbedder'
def test_registry_has_builtin_token_embedders(self):
    """Check that the name "embedding" resolves to the Embedding class."""
    embedder_class = Registry.get_token_embedder("embedding")
    assert embedder_class.__name__ == 'Embedding'
def test_registry_has_builtin_regularizers(self):
    """Check that the registry resolves the built-in regularizer names to the expected classes."""
    expected_class_names = (
        ('l1', 'L1Regularizer'),
        ('l2', 'L2Regularizer'),
    )
    for registered_name, class_name in expected_class_names:
        regularizer_class = Registry.get_regularizer(registered_name)
        assert regularizer_class.__name__ == class_name