예제 #1
0
 def test_registry_has_builtin_initializers(self):
     """Check that each built-in initializer name resolves to its torch.nn.init function.

     The check is by identity: ``Registry.get_initializer(name)`` must return the
     very function object listed in the table below, not merely an equal one.
     """
     # NOTE(review): these are the non-underscore ``torch.nn.init`` names. Newer
     # PyTorch releases deprecate them in favour of the in-place ``*_`` variants
     # and may not provide some of them (e.g. ``eye``/``dirac``) at all -- keep
     # this table in sync with whatever the Registry actually registers.
     all_initializers = {
         "normal": torch.nn.init.normal,
         "uniform": torch.nn.init.uniform,
         "orthogonal": torch.nn.init.orthogonal,
         "constant": torch.nn.init.constant,
         "dirac": torch.nn.init.dirac,
         "xavier_normal": torch.nn.init.xavier_normal,
         "xavier_uniform": torch.nn.init.xavier_uniform,
         "kaiming_normal": torch.nn.init.kaiming_normal,
         "kaiming_uniform": torch.nn.init.kaiming_uniform,
         "sparse": torch.nn.init.sparse,
         "eye": torch.nn.init.eye,
     }
     for key, value in all_initializers.items():
         # Identity (not equality) is what "registered as this function" means;
         # the message names the failing key so a mismatch is easy to locate.
         assert Registry.get_initializer(key) is value, \
             "unexpected initializer registered for %r" % key
예제 #2
0
 def test_registry_has_builtin_token_indexers(self):
     """Check that the built-in token indexer names resolve to the expected classes."""
     for registered_name, class_name in (
             ('single_id', 'SingleIdTokenIndexer'),
             ('characters', 'TokenCharactersIndexer'),
     ):
         assert Registry.get_token_indexer(registered_name).__name__ == class_name
예제 #3
0
 def test_registry_has_builtin_tokenizers(self):
     """Check that the 'word' and 'character' tokenizer names resolve to the expected classes."""
     lookup = Registry.get_tokenizer
     word_cls = lookup('word')
     assert word_cls.__name__ == 'WordTokenizer'
     char_cls = lookup('character')
     assert char_cls.__name__ == 'CharacterTokenizer'
예제 #4
0
 def test_registry_has_builtin_iterators(self):
     """Check that the built-in data iterator names resolve to the expected classes."""
     for iterator_name, class_name in (
             ('adaptive', 'AdaptiveIterator'),
             ('basic', 'BasicIterator'),
             ('bucket', 'BucketIterator'),
     ):
         assert Registry.get_data_iterator(iterator_name).__name__ == class_name
예제 #5
0
 def test_registry_has_builtin_readers(self):
     """Check that the built-in dataset reader names resolve to the expected classes."""
     # (registered name, expected class __name__), checked in this order.
     expected_readers = (
         ('snli', 'SnliReader'),
         ('sequence_tagging', 'SequenceTaggingDatasetReader'),
         ('language_modeling', 'LanguageModelingReader'),
         ('squad_sentence_selection', 'SquadSentenceSelectionReader'),
     )
     for reader_name, class_name in expected_readers:
         reader_cls = Registry.get_dataset_reader(reader_name)
         assert reader_cls.__name__ == class_name
예제 #6
0
 def test_registry_has_builtin_text_field_embedders(self):
     """Check that the 'basic' text field embedder name resolves to the expected class."""
     embedder_cls = Registry.get_text_field_embedder("basic")
     assert embedder_cls.__name__ == 'BasicTextFieldEmbedder'
예제 #7
0
 def test_registry_has_builtin_token_embedders(self):
     """Check that the 'embedding' token embedder name resolves to the expected class."""
     embedder_cls = Registry.get_token_embedder("embedding")
     assert embedder_cls.__name__ == 'Embedding'
예제 #8
0
 def test_registry_has_builtin_regularizers(self):
     """Check that the 'l1' and 'l2' regularizer names resolve to the expected classes."""
     for reg_name, class_name in [('l1', 'L1Regularizer'), ('l2', 'L2Regularizer')]:
         assert Registry.get_regularizer(reg_name).__name__ == class_name