def test_from_dataset_respects_inclusive_embedding_file(self):
    embeddings_filename = self.TEST_DIR + "embeddings.gz"
    with gzip.open(embeddings_filename, 'wb') as embeddings_file:
        embeddings_file.write("a 1.0 2.3 -1.0\n".encode('utf-8'))
        embeddings_file.write("b 0.1 0.4 -4.0\n".encode('utf-8'))

    vocab = Vocabulary.from_dataset(self.dataset,
                                    min_count=4,
                                    pretrained_files={'tokens': embeddings_filename},
                                    only_include_pretrained_words=False)
    words = vocab.get_index_to_token_vocabulary().values()
    # In inclusive mode a token is kept if it meets the min_count threshold
    # *or* appears in the pretrained file, so 'a' (frequent) and 'b' (in the
    # embedding file) survive, while 'c' (neither) does not.
    assert 'a' in words
    assert 'b' in words
    assert 'c' not in words

    vocab = Vocabulary.from_dataset(self.dataset,
                                    min_count=-1,
                                    pretrained_files={'tokens': embeddings_filename},
                                    only_include_pretrained_words=False)
    words = vocab.get_index_to_token_vocabulary().values()
    # With min_count=-1 every token passes the count threshold.
    assert 'a' in words
    assert 'b' in words
    assert 'c' in words
def test_from_dataset_respects_min_count(self):
    vocab = Vocabulary.from_dataset(self.dataset, min_count=4)
    words = vocab.get_index_to_token_vocabulary().values()
    assert 'a' in words
    assert 'b' not in words
    assert 'c' not in words

    vocab = Vocabulary.from_dataset(self.dataset, min_count=1)
    words = vocab.get_index_to_token_vocabulary().values()
    assert 'a' in words
    assert 'b' in words
    assert 'c' in words
def ensure_model_saves_and_loads(self,
                                 model: Model,
                                 dataset_reader: DatasetReader,
                                 iterator: DataIterator = None):
    data_iterator = iterator or BasicIterator()
    dataset = dataset_reader.read(self.TRAIN_FILE)
    vocab = Vocabulary.from_dataset(dataset)
    dataset.index_instances(vocab)

    single_batch = next(data_iterator(dataset))
    single_batch = arrays_to_variables(single_batch)
    model_predictions = model.forward(**single_batch)

    torch.save(model.state_dict(), self.MODEL_FILE)
    # Note: this aliases the original model rather than constructing a fresh
    # one, so it only checks that the state dict round-trips through disk; it
    # does not exercise re-initializing a new model from the saved weights.
    loaded_model = model
    loaded_model.load_state_dict(torch.load(self.MODEL_FILE))
    loaded_model_predictions = loaded_model.forward(**single_batch)

    # Both outputs should have the same keys, and the values for those keys
    # should be close.
    for key in model_predictions.keys():
        assert_allclose(model_predictions[key].data.numpy(),
                        loaded_model_predictions[key].data.numpy())
    return model, loaded_model
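# A hedged usage sketch for the helper above: a test method on the same test
# case, reusing the SimpleTagger configuration that appears in
# simple_tagger_model() below. The method name and the Params layout here are
# illustrative assumptions, not part of the helper's contract.
def test_simple_tagger_saves_and_loads(self):
    self.write_sequence_tagging_data()
    dataset_reader = SequenceTaggingDatasetReader()
    dataset = dataset_reader.read(self.TRAIN_FILE)
    vocab = Vocabulary.from_dataset(dataset)
    params = Params({
        "text_field_embedder": {"tokens": {"type": "embedding", "embedding_dim": 5}},
        "hidden_size": 7,
        "num_layers": 2
    })
    model = SimpleTagger.from_params(vocab, params)
    self.ensure_model_saves_and_loads(model, dataset_reader)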
def simple_tagger_model() -> Callable[[JSON], JSON]:
    """Create a simple tagger model and return a function that runs it on JSON input."""
    # This is a bad hack to get the same data as the test case.
    # TODO(joelgrus): replace this
    test_case = AllenNlpTestCase()
    test_case.setUp()
    test_case.write_sequence_tagging_data()
    dataset = SequenceTaggingDatasetReader().read(test_case.TRAIN_FILE)
    vocab = Vocabulary.from_dataset(dataset)
    dataset.index_instances(vocab)

    params = Params({
        "text_field_embedder": {
            "tokens": {
                "type": "embedding",
                "embedding_dim": 5
            }
        },
        "hidden_size": 7,
        "num_layers": 2
    })
    model = SimpleTagger.from_params(vocab, params)
    tokenizer = WordTokenizer()

    def run(blob: JSON) -> JSON:
        sentence = blob.get("input", "")
        tokens = tokenizer.tokenize(sentence)
        text = TextField(tokens, token_indexers={"tokens": SingleIdTokenIndexer()})
        output = model.tag(text)

        # Convert the numpy array to a serializable list.
        output['class_probabilities'] = output['class_probabilities'].tolist()

        possible_tags = list(vocab.get_index_to_token_vocabulary("tags").values())
        return {
            'model_name': 'simple_tagger',
            'input': sentence,
            'output': output,
            'tokens': tokens,
            'possible_tags': possible_tags
        }

    return run
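# A quick sketch of calling the closure returned by simple_tagger_model().
# The "input" key comes from run()'s own contract above; the sentence and the
# __main__ guard are illustrative.
if __name__ == "__main__":
    predict = simple_tagger_model()
    result = predict({"input": "The dog ate the apple ."})
    print(result["possible_tags"])
    # 'class_probabilities' was converted to a nested list inside run().
    print(result["output"]["class_probabilities"])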
def test_saving_and_loading_works_with_byte_encoding(self):
    # We're going to set a vocabulary from a TextField using byte encoding, index it, save the
    # vocab, load the vocab, then index the text field again, and make sure we get the same
    # result.
    tokenizer = CharacterTokenizer(byte_encoding='utf-8')
    token_indexer = TokenCharactersIndexer(character_tokenizer=tokenizer)
    tokens = [Token(t) for t in ["Øyvind", "für", "汉字"]]
    text_field = TextField(tokens, {"characters": token_indexer})
    dataset = Dataset([Instance({"sentence": text_field})])
    vocab = Vocabulary.from_dataset(dataset)
    text_field.index(vocab)
    indexed_tokens = deepcopy(text_field._indexed_tokens)  # pylint: disable=protected-access

    vocab_dir = os.path.join(self.TEST_DIR, 'vocab_save')
    vocab.save_to_files(vocab_dir)
    vocab2 = Vocabulary.from_files(vocab_dir)

    text_field2 = TextField(tokens, {"characters": token_indexer})
    text_field2.index(vocab2)
    indexed_tokens2 = deepcopy(text_field2._indexed_tokens)  # pylint: disable=protected-access
    assert indexed_tokens == indexed_tokens2
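# A smaller, hedged sketch of the same save/load round trip for a vocabulary
# built from ordinary word tokens, using only calls that appear above; the
# method name and the 'vocab_save2' directory are illustrative.
def test_vocab_round_trips_through_files(self):
    text_field = TextField([Token("hello"), Token("world")],
                           {"tokens": SingleIdTokenIndexer()})
    dataset = Dataset([Instance({"sentence": text_field})])
    vocab = Vocabulary.from_dataset(dataset)
    vocab_dir = os.path.join(self.TEST_DIR, 'vocab_save2')
    vocab.save_to_files(vocab_dir)
    vocab2 = Vocabulary.from_files(vocab_dir)
    assert (vocab.get_index_to_token_vocabulary() ==
            vocab2.get_index_to_token_vocabulary())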